Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MultiStateReporter variable pos/vel save frequency #712

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
Open
101 changes: 71 additions & 30 deletions openmmtools/multistate/multistatereporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,12 @@ class MultiStateReporter(object):
analysis_particle_indices : tuple of ints, Optional. Default: () (empty tuple)
If specified, it will serialize positions and velocities for the specified particles, at every iteration, in the
reporter storage (.nc) file. If empty, no positions or velocities will be stored in this file for any atoms.
position_interval : int, default 1
the frequency at which to write positions relative to analysis
information, 0 would prevent information being written
velocity_interval : int, default 1
the frequency at which to write velocities relative to analysis
information, 0 would prevent information being written

Attributes
----------
Expand All @@ -113,7 +119,10 @@ class MultiStateReporter(object):
"""
def __init__(self, storage, open_mode=None,
checkpoint_interval=50, checkpoint_storage=None,
analysis_particle_indices=()):
analysis_particle_indices=(),
position_interval=1,
velocity_interval=1,
):

# Warn that API is experimental
logger.warn('Warning: The openmmtools.multistate API is experimental and may change in future releases')
Expand All @@ -136,6 +145,9 @@ def __init__(self, storage, open_mode=None,
self._checkpoint_interval = checkpoint_interval
# Cast to tuple no matter what 1-D-like input was given
self._analysis_particle_indices = tuple(analysis_particle_indices)
self._position_interval = position_interval
self._velocity_interval = velocity_interval

if open_mode is not None:
self.open(open_mode)
# TODO: Maybe we want to expose this flag to control overwriting/appending
Expand Down Expand Up @@ -202,6 +214,16 @@ def checkpoint_interval(self):
"""Returns the checkpoint interval"""
return self._checkpoint_interval

@property
def position_interval(self):
    """Interval, in iterations, at which positions are written.

    Returns the value passed to the constructor as ``position_interval``.
    A value of 0 prevents positions from being written; writes to the
    checkpoint file are made regardless of this interval.
    """
    return self._position_interval

@property
def velocity_interval(self):
    """Interval, in iterations, at which velocities are written.

    Returns the value passed to the constructor as ``velocity_interval``.
    A value of 0 prevents velocities from being written; writes to the
    checkpoint file are made regardless of this interval.
    """
    return self._velocity_interval

def storage_exists(self, skip_size=False):
"""
Check if the storage files exist on disk.
Expand Down Expand Up @@ -415,6 +437,8 @@ def _initialize_storage_file(self, ncfile, nc_name, convention):
ncfile.ConventionVersion = '0.2'
ncfile.DataUsedFor = nc_name
ncfile.CheckpointInterval = self._checkpoint_interval
ncfile.PositionInterval = self._position_interval
ncfile.VelocityInterval = self._velocity_interval

# Create and initialize the global variables
nc_last_good_iter = ncfile.createVariable('last_iteration', int, 'scalar')
Expand Down Expand Up @@ -1647,35 +1671,47 @@ def _write_sampler_states_to_given_file(self, sampler_states: list, iteration: i
write_iteration = self._calculate_checkpoint_iteration(iteration)
else:
write_iteration = iteration

# write out pos/vel - if checkpointing,
# or if interval matches desired frequency
write_pos = (storage_file == 'checkpoint' or
(self._position_interval != 0
and not (write_iteration % self._position_interval)))
write_vel = (storage_file == 'checkpoint' or
(self._velocity_interval != 0
and not (write_iteration % self._velocity_interval)))

# Write the sampler state if we are on the checkpoint interval OR if told to ignore the interval
if write_iteration is not None:
# Store sampler states.
# Create a numpy array to avoid making multiple (possibly inefficient) calls to netCDF assignments
positions = np.zeros([n_replicas, n_particles, 3])
for replica_index, sampler_state in enumerate(sampler_states):
# Store positions in memory first
x = sampler_state.positions / unit.nanometers
positions[replica_index, :, :] = x[:, :]
# Store positions
storage.variables['positions'][write_iteration, :, :, :] = positions

# Create a numpy array to avoid making multiple (possibly inefficient) calls to netCDF assignments
velocities = np.zeros([n_replicas, n_particles, 3])
for replica_index, sampler_state in enumerate(sampler_states):
if sampler_state._unitless_velocities is not None:
# Store velocities in memory first
x = sampler_state.velocities / (unit.nanometer/unit.picoseconds) # _unitless_velocities
velocities[replica_index, :, :] = x[:, :]
# Store velocities
# TODO: This stores velocities as zeros if no velocities are present in the sampler state. Making restored
# sampler_state different from origin.
if 'velocities' not in storage.variables:
# create variable with expected dimensions and shape
storage.createVariable('velocities', storage.variables['positions'].dtype,
dimensions=storage.variables['positions'].dimensions)
storage.variables['velocities'][write_iteration, :, :, :] = velocities

if is_periodic:
if write_pos:
# Create a numpy array to avoid making multiple (possibly inefficient) calls to netCDF assignments
positions = np.zeros([n_replicas, n_particles, 3])
for replica_index, sampler_state in enumerate(sampler_states):
# Store positions in memory first
x = sampler_state.positions / unit.nanometers
positions[replica_index, :, :] = x[:, :]
# Store positions
storage.variables['positions'][write_iteration, :, :, :] = positions

if write_vel:
# Create a numpy array to avoid making multiple (possibly inefficient) calls to netCDF assignments
velocities = np.zeros([n_replicas, n_particles, 3])
for replica_index, sampler_state in enumerate(sampler_states):
if sampler_state._unitless_velocities is not None:
# Store velocities in memory first
x = sampler_state.velocities / (unit.nanometer/unit.picoseconds) # _unitless_velocities
velocities[replica_index, :, :] = x[:, :]
# Store velocities
# TODO: This stores velocities as zeros if no velocities are present in the sampler state. Making restored
# sampler_state different from origin.
if 'velocities' not in storage.variables:
# create variable with expected dimensions and shape
storage.createVariable('velocities', storage.variables['positions'].dtype,
dimensions=storage.variables['positions'].dimensions)
storage.variables['velocities'][write_iteration, :, :, :] = velocities

if is_periodic and write_pos:
# Store box vectors and volume.
# Allocate whole write to memory first
box_vectors = np.zeros([n_replicas, 3, 3])
Expand Down Expand Up @@ -1727,21 +1763,26 @@ def _read_sampler_states_from_given_file(self, iteration, storage_file='checkpoi
sampler_states = list()
for replica_index in range(n_replicas):
# Restore positions.
x = storage.variables['positions'][read_iteration, replica_index, :, :].astype(np.float64)
positions = unit.Quantity(x, unit.nanometers)
try:
x = storage.variables['positions'][read_iteration, replica_index, :, :].astype(np.float64)
positions = unit.Quantity(x, unit.nanometers)
except (IndexError, KeyError):
positions = np.zeros((storage.dimensions['atom'].size, # TODO: analysis_particles or atom here?
storage.dimensions['spatial'].size), dtype=np.float64)

# Restore velocities
# try-catch exception, enabling reading legacy/older serialized objects from openmmtools<0.21.3
try:
x = storage.variables['velocities'][read_iteration, replica_index, :, :].astype(np.float64)
velocities = unit.Quantity(x, unit.nanometer / unit.picoseconds)
except KeyError: # Velocities key/variable not found in serialization (openmmtools<=0.21.2)
except (IndexError, KeyError): # Velocities key/variable not found in serialization (openmmtools<=0.21.2)
# pass zeros as velocities when key is not found (<0.21.3 behavior)
velocities = np.zeros_like(positions)

if 'box_vectors' in storage.variables:
# Restore box vectors.
x = storage.variables['box_vectors'][read_iteration, replica_index, :, :].astype(np.float64)
# TODO: Are box vectors also variably saved?
box_vectors = unit.Quantity(x, unit.nanometers)
else:
box_vectors = None
Expand Down
141 changes: 131 additions & 10 deletions openmmtools/tests/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,20 +361,23 @@ class TestReporter:

@staticmethod
@contextlib.contextmanager
def temporary_reporter(
checkpoint_interval=1, checkpoint_storage=None, analysis_particle_indices=()
):

def temporary_reporter(checkpoint_interval=1, checkpoint_storage=None,
position_interval=1, velocity_interval=1,
analysis_particle_indices=()):
"""Create and initialize a reporter in a temporary directory."""
with temporary_directory() as tmp_dir_path:
storage_file = os.path.join(tmp_dir_path, "temp_dir/test_storage.nc")
assert not os.path.isfile(storage_file)
reporter = MultiStateReporter(
storage=storage_file,
open_mode="w",
checkpoint_interval=checkpoint_interval,
checkpoint_storage=checkpoint_storage,
analysis_particle_indices=analysis_particle_indices,
)

reporter = MultiStateReporter(storage=storage_file, open_mode='w',
checkpoint_interval=checkpoint_interval,
checkpoint_storage=checkpoint_storage,
analysis_particle_indices=analysis_particle_indices,
position_interval=position_interval,
velocity_interval=velocity_interval,
)

assert reporter.storage_exists(skip_size=True)
yield reporter

Expand Down Expand Up @@ -565,6 +568,124 @@ def test_write_sampler_states(self):
checkpoint_state.box_vectors / unit.nanometer,
)

def test_writer_sampler_states_pos_interval(self):
""" write positions and velocities every other frame"""
analysis_particles = (1, 2)
with self.temporary_reporter(analysis_particle_indices=analysis_particles,
position_interval=2, velocity_interval=2,
checkpoint_interval=2) as reporter:
# Create sampler states.
alanine_test = testsystems.AlanineDipeptideVacuum()
positions = alanine_test.positions
sampler_states = [mmtools.states.SamplerState(positions=positions)
for _ in range(2)]

# Check that after writing and reading, states are identical.
for iteration in range(3):
reporter.write_sampler_states(sampler_states, iteration=iteration)
reporter.write_last_iteration(iteration)

# Check first frame
restored_sampler_states = reporter.read_sampler_states(iteration=0)
for state, restored_state in zip(sampler_states, restored_sampler_states):
assert np.allclose(state.positions, restored_state.positions)
# By default stored velocities are zeros if not present in origin sampler_state
assert np.allclose(np.zeros(state.positions.shape), restored_state.velocities)
assert np.allclose(state.box_vectors / unit.nanometer, restored_state.box_vectors / unit.nanometer)
# Second frame should not have positions or velocities
restored_sampler_states = reporter.read_sampler_states(iteration=1, analysis_particles_only=True)
for state, restored_state in zip(sampler_states, restored_sampler_states):
# missing values are returned as numpy masked array
# so we check that these arrays are all masked
assert restored_state.positions._value.mask.all()
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is something that I'm not 100% sure of currently. netCDF will return a numpy masked array when you access data that isn't present (here we're saving vels every other frame, so accessing velocities here has no data). I hadn't ever encountered these, so maybe it's not the best thing to return? (or maybe it is if it's the netCDF normal return).

assert restored_state.velocities._value.mask.all()
assert restored_state.box_vectors is None # not periodic
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is something I need to double check, I'm a little confused why the checkpoint has a box and future frames don't, so this is still WIP till I figure that out


restored_sampler_states = reporter.read_sampler_states(iteration=2, analysis_particles_only=True)
for state, restored_state in zip(sampler_states, restored_sampler_states):
assert np.allclose(state.positions[analysis_particles, :], restored_state.positions)
# By default stored velocities are zeros if not present in origin sampler_state
assert np.allclose(np.zeros((2, 3)), restored_state.velocities)
assert np.allclose(state.box_vectors / unit.nanometer, restored_state.box_vectors / unit.nanometer)

def test_write_sampler_states_no_vel(self):
    """Positions are stored on every iteration, velocities never.

    The reporter is created with ``position_interval=1`` and
    ``velocity_interval=0``. Three iterations are written and each is
    read back restricted to the analysis particles.
    """
    analysis_particles = (1, 2)
    n_iterations = 3
    reporter_kwargs = dict(
        analysis_particle_indices=analysis_particles,
        position_interval=1,
        velocity_interval=0,
        checkpoint_interval=2,
    )
    with self.temporary_reporter(**reporter_kwargs) as reporter:
        # Two replicas sharing the same vacuum alanine dipeptide coordinates.
        alanine_test = testsystems.AlanineDipeptideVacuum()
        reference_positions = alanine_test.positions
        sampler_states = [
            mmtools.states.SamplerState(positions=reference_positions)
            for _ in range(2)
        ]

        # Write every iteration first; reading happens only afterwards.
        for iteration in range(n_iterations):
            reporter.write_sampler_states(sampler_states, iteration=iteration)
            reporter.write_last_iteration(iteration)

        # The same expectations hold for every stored iteration.
        for iteration in range(n_iterations):
            restored_sampler_states = reporter.read_sampler_states(
                iteration=iteration, analysis_particles_only=True
            )
            for written, restored in zip(sampler_states, restored_sampler_states):
                # Positions of the analysis particles must round-trip.
                expected = written.positions[analysis_particles, :]
                assert np.allclose(expected, restored.positions)
                # Values that were never written come back as a numpy
                # masked array, so every velocity entry must be masked.
                assert restored.velocities._value.mask.all()
                assert restored.box_vectors is None  # not periodic

def test_write_sampler_states_no_pos(self):
    """Neither positions nor velocities are stored for analysis particles.

    The reporter is created with ``position_interval=0`` and
    ``velocity_interval=0``. Three iterations are written and each is
    read back restricted to the analysis particles.
    """
    analysis_particles = (1, 2)
    n_iterations = 3
    reporter_kwargs = dict(
        analysis_particle_indices=analysis_particles,
        position_interval=0,
        velocity_interval=0,
        checkpoint_interval=2,
    )
    with self.temporary_reporter(**reporter_kwargs) as reporter:
        # Two replicas sharing the same vacuum alanine dipeptide coordinates.
        alanine_test = testsystems.AlanineDipeptideVacuum()
        reference_positions = alanine_test.positions
        sampler_states = [
            mmtools.states.SamplerState(positions=reference_positions)
            for _ in range(2)
        ]

        # Write every iteration first; reading happens only afterwards.
        for iteration in range(n_iterations):
            reporter.write_sampler_states(sampler_states, iteration=iteration)
            reporter.write_last_iteration(iteration)

        # The same expectations hold for every stored iteration.
        for iteration in range(n_iterations):
            restored_sampler_states = reporter.read_sampler_states(
                iteration=iteration, analysis_particles_only=True
            )
            for _written, restored in zip(sampler_states, restored_sampler_states):
                # Values that were never written come back as a numpy
                # masked array, so every entry must be masked.
                assert restored.positions._value.mask.all()
                assert restored.velocities._value.mask.all()
                assert restored.box_vectors is None  # not periodic

def test_analysis_particle_mismatch(self):
"""Test that previously stored analysis particles is higher priority."""
blank_analysis_particles = ()
Expand Down
Loading