Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 77 additions & 40 deletions src/somd2/runner/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1273,6 +1273,7 @@ def increment_filename(base_filename, suffix):
lam = f"{lambda_value:.5f}"
filenames = {}
filenames["checkpoint"] = str(output_directory / f"checkpoint_{lam}.s3")
filenames["checkpoint_state"] = str(output_directory / f"checkpoint_{lam}.npz")
filenames["energy_traj"] = str(output_directory / f"energy_traj_{lam}.parquet")
filenames["trajectory"] = str(output_directory / f"traj_{lam}.dcd")
filenames["trajectory_chunk"] = str(output_directory / f"traj_{lam}_")
Expand Down Expand Up @@ -1628,18 +1629,22 @@ def get_last_config(output_directory):
f"No config files found in {self._config.output_directory}, "
"attempting to retrieve config from lambda = 0 checkpoint file."
)
try:
system_temp = _sr.stream.load(
str(self._config.output_directory / "checkpoint_0.00000.s3")
)
except:
expdir = self._config.output_directory / "checkpoint_0.00000.s3"
_logger.error(f"Unable to load checkpoint file from {expdir}.")
raise
s3_path = self._config.output_directory / "checkpoint_0.00000.s3"
if s3_path.exists():
try:
system_temp = _sr.stream.load(str(s3_path))
except:
_logger.error(f"Unable to load checkpoint file from {s3_path}.")
raise
else:
self._last_config = dict(system_temp.property("config"))
config = self._config.as_dict(sire_compatible=True)
del system_temp
else:
self._last_config = dict(system_temp.property("config"))
config = self._config.as_dict(sire_compatible=True)
del system_temp
raise OSError(
f"No config file found in {self._config.output_directory}. "
"Cannot validate restart config without a config.yaml file."
)

self._compare_configs(self._last_config, config)

Expand Down Expand Up @@ -1854,6 +1859,8 @@ def _checkpoint(
lambda_energy=None,
lambda_grad=None,
is_final_block=False,
context=None,
gcmc_sampler=None,
):
"""
Save a checkpoint file.
Expand Down Expand Up @@ -1997,24 +2004,18 @@ def _checkpoint(
for chunk in traj_chunks:
_Path(chunk).unlink()

# Add config and lambda value to the system properties.
system.set_property(
"config", self._config.as_dict(sire_compatible=True)
# Write the checkpoint system to file.
self._write_checkpoint_system(
system, index, context=context, gcmc_sampler=gcmc_sampler
)
system.set_property("lambda", lam)

# Delete all frames from the system.
system.delete_all_frames()

# Stream the final system to file.
_sr.stream.save(system, self._filenames[index]["checkpoint"])

# Create the final parquet file.
_dataframe_to_parquet(
df,
metadata=metadata,
filename=self._filenames[index]["energy_traj"],
)
# Append the final block's energy data. If no parquet exists
# yet (e.g. checkpoint_frequency=0), create one from scratch.
_energy_traj = self._filenames[index]["energy_traj"]
if _Path(_energy_traj).exists():
_parquet_append(_energy_traj, df.iloc[-self._energy_per_block :])
else:
_dataframe_to_parquet(df, metadata=metadata, filename=_energy_traj)

else:
# Update the starting block if necessary.
Expand All @@ -2034,27 +2035,23 @@ def _checkpoint(
format=["DCD"],
)

# Encode the configuration and lambda value as system properties.
system.set_property(
"config", self._config.as_dict(sire_compatible=True)
# Write the checkpoint system to file.
self._write_checkpoint_system(
system, index, context=context, gcmc_sampler=gcmc_sampler
)
system.set_property("lambda", lam)

# Delete all frames from the system.
system.delete_all_frames()

# Stream the checkpoint to file.
_sr.stream.save(system, self._filenames[index]["checkpoint"])

# Skip parquet creation for post-equilibration checkpoints.
if not is_post_equilibration:
# Create the parquet file name.
filename = self._filenames[index]["energy_traj"]

# Create the parquet file.
if block == self._start_block:
# At the start block of a restart, append to the existing
# parquet so that historical data is preserved. For fresh
# runs, overwrite (or create) the parquet file.
if block == self._start_block and not (
self._is_restart and _Path(filename).exists()
):
_dataframe_to_parquet(df, metadata=metadata, filename=filename)
# Append to the parquet file.
else:
_parquet_append(
filename,
Expand All @@ -2066,6 +2063,35 @@ def _checkpoint(

return index, None

def _write_checkpoint_system(self, system, index, context=None, gcmc_sampler=None):
"""
Write the system state to the checkpoint file.

Subclasses may override this to store state differently, e.g. repex
records the simulation time in the dynamics cache pickle instead of
streaming a per-replica file.

Parameters
----------

system: :class: `System <sire.system.System>`
The committed system to checkpoint.

index: int
The index of the lambda window.

context: openmm.Context, optional
The OpenMM context. Unused in the base implementation.

gcmc_sampler: GCMCSampler, optional
The GCMC sampler. Unused in the base implementation.
"""
lam = self._lambda_values[index]
system.set_property("config", self._config.as_dict(sire_compatible=True))
system.set_property("lambda", lam)
system.delete_all_frames()
_sr.stream.save(system, self._filenames[index]["checkpoint"])

def _backup_checkpoint(self, index):
"""
Create a backup of the previous checkpoint files.
Expand All @@ -2088,6 +2114,17 @@ def _backup_checkpoint(self, index):
self._filenames[index]["checkpoint"],
str(self._filenames[index]["checkpoint"]) + ".bak",
)
except Exception as e:
return index, e

try:
# Backup the existing compact numpy checkpoint file, if it exists.
path = _Path(self._filenames[index]["checkpoint_state"])
if path.exists() and path.stat().st_size > 0:
_copyfile(
self._filenames[index]["checkpoint_state"],
str(self._filenames[index]["checkpoint_state"]) + ".bak",
)
traj_filename = self._filenames[index]["trajectory"]
except Exception as e:
return index, e
Expand Down
118 changes: 93 additions & 25 deletions src/somd2/runner/_repex.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ def __init__(
self._lambdas = lambdas
self._rest2_scale_factors = rest2_scale_factors
self._states = _np.array(range(len(lambdas)))
self._time = None
self._openmm_states = [None] * len(lambdas)
self._gcmc_samplers = [None] * len(lambdas)
self._gcmc_states = [None] * len(lambdas)
Expand Down Expand Up @@ -136,8 +137,12 @@ def __setstate__(self, state):
n = len(self._lambdas)
if not hasattr(self, "_gcmc_stats"):
self._gcmc_stats = [None] * n
if not hasattr(self, "_gcmc_states"):
self._gcmc_states = [None] * n
if not hasattr(self, "_terminal_flip_stats"):
self._terminal_flip_stats = [[0, 0]] * n
if not hasattr(self, "_time"):
self._time = None

def __getstate__(self):
"""
Expand All @@ -149,6 +154,7 @@ def __getstate__(self):
"_lambdas": self._lambdas,
"_rest2_scale_factors": self._rest2_scale_factors,
"_states": self._states,
"_time": self._time,
"_openmm_states": self._openmm_states,
# Don't pickle the GCMC samplers since they need to be recreated.
"_gcmc_samplers": len(self._gcmc_samplers) * [None],
Expand Down Expand Up @@ -298,15 +304,6 @@ def _create_dynamics(
f"Created GCMC sampler for lambda {lam:.5f} on device {device}"
)

# Log the initial position of the GCMC sphere.
if self._gcmc_samplers[i]._reference is not None:
positions = _sr.io.get_coords_array(mols)
target = self._gcmc_samplers[i]._get_target_position(positions)
_logger.info(
f"Initial GCMC sphere centre for lambda {lam:.5f} on device {device}: "
f"[{target[0]:.3f}, {target[1]:.3f}, {target[2]:.3f}] A"
)

# Create the dynamics object.
try:
dynamics = mols.dynamics(**dynamics_kwargs)
Expand Down Expand Up @@ -817,10 +814,31 @@ def __init__(self, system, config):
output_directory=self._config.output_directory,
xml_filenames=xml_filenames,
)

else:
_logger.debug("Restarting from file")

time = self._system[0].time()
# Load the dynamics cache first so we can read the simulation time
# from it (new format). Old-format restarts with .s3 files fall
# back to reading the time from the loaded Sire system.
try:
with open(self._repex_state, "rb") as f:
self._dynamics_cache = _pickle.load(f)
except Exception as e:
_logger.error(
f"Could not load dynamics cache from {self._repex_state}: {e}"
)
raise e

# Derive the simulation time: prefer the value stored in the
# pickle (_time is set by the new-format _write_checkpoint_system);
# fall back to the Sire system for old-format checkpoints.
if self._dynamics_cache._time is not None and not isinstance(
self._system, list
):
time = self._dynamics_cache._time
else:
time = self._system[0].time()

# Check to see if the simulation is already complete.
if self._done_file.exists():
Expand Down Expand Up @@ -853,22 +871,18 @@ def __init__(self, system, config):
f"Restarting at time {time}, time remaining = {self._config.runtime - time}"
)

try:
with open(self._repex_state, "rb") as f:
self._dynamics_cache = _pickle.load(f)
except Exception as e:
_logger.error(
f"Could not load dynamics cache from {self._repex_state}: {e}"
)
raise e

# Make sure the number of replicas is the same.
if len(self._dynamics_cache._lambdas) != self._config.num_lambda:
_logger.error(
f"The number of replicas in the dynamics cache ({len(self._dynamics_cache._lambdas)}) "
f"does not match the number of replicas in the configuration ({self._config.num_lambda})."
)

# For new-format restarts, set the system time so that dynamics
# objects are initialised with the correct integrator step count.
if not isinstance(self._system, list):
self._system.set_time(time)

# Create the dynamics objects.
self._dynamics_cache._create_dynamics(
self._system,
Expand Down Expand Up @@ -905,13 +919,37 @@ def __init__(self, system, config):
if self._dynamics_cache._gcmc_stats[i] is not None:
gcmc_sampler.restore_stats(self._dynamics_cache._gcmc_stats[i])

# Log the GCMC sphere centre for each replica using the actual context
# positions (accurate for both fresh runs and restarts).
import openmm.unit as _omm_unit

for i, lam in enumerate(self._lambda_values):
dynamics, gcmc_sampler = self._dynamics_cache.get(i)
if gcmc_sampler is not None and gcmc_sampler._reference is not None:
state = dynamics.context().getState(getPositions=True)
positions = state.getPositions(asNumpy=True).value_in_unit(
_omm_unit.angstrom
)
target = gcmc_sampler._get_target_position(positions)
_logger.info(
f"Initial GCMC sphere centre for lambda {lam:.5f}: "
f"[{target[0]:.3f}, {target[1]:.3f}, {target[2]:.3f}] A"
)

# Conversion factor for reduced potential.
kT = (_sr.units.k_boltz * self._config.temperature).to(_sr.units.kcal_per_mol)
self._beta = 1.0 / kT

# If restarting, subtract the time already run from the total runtime
if self._config.restart:
time = self._system[0].time()
time = (
self._dynamics_cache._time
if (
self._dynamics_cache._time is not None
and not isinstance(self._system, list)
)
else self._system[0].time()
)
self._config.runtime = str(self._config.runtime - time)

# Work out the current block number.
Expand Down Expand Up @@ -1555,7 +1593,7 @@ def _minimise(self, index):
# Get the dynamics object (and GCMC sampler).
dynamics, gcmc_sampler = self._dynamics_cache.get(index)

if gcmc_sampler is not None:
if gcmc_sampler is not None and not self._is_restart:
gcmc_sampler.push()
try:
_logger.info(
Expand Down Expand Up @@ -1845,6 +1883,40 @@ def _assemble_results(self, results):

return matrix

def _check_restart(self):
"""
Check the output directory for a valid restart state.

If per-replica checkpoint stream files (.s3) exist the base class is
used to load them (old format, backwards compatible). Otherwise the
repex state pickle is used and the original input system is returned
directly, since positions and velocities come from the OpenMM states
stored in the pickle.
"""
from pathlib import Path as _Path_local

checkpoint_path = _Path_local(self._filenames[0]["checkpoint"])
if checkpoint_path.exists():
_logger.info("Restarting from legacy stream file checkpoint.")
return super()._check_restart()

repex_state = self._config.output_directory / "repex_state.pkl"
if not repex_state.exists():
return False, self._system

return True, self._system

def _write_checkpoint_system(self, system, index, context=None, gcmc_sampler=None):
"""
Record the current simulation time in the dynamics cache.

For repex, per-replica stream files are not written. The simulation
time is stored in the dynamics cache pickle instead, and positions and
velocities are already stored as compact numpy arrays in the OpenMM
state dict.
"""
self._dynamics_cache._time = system.time()

def _checkpoint(self, index, lambdas, block, num_blocks, is_final_block=False):
"""
Checkpoint the simulation.
Expand Down Expand Up @@ -1886,10 +1958,6 @@ def _checkpoint(self, index, lambdas, block, num_blocks, is_final_block=False):
# Commit the current system.
system = dynamics.commit()

# If performing GCMC, then we need to flag the ghost waters.
if gcmc_sampler is not None:
system = gcmc_sampler._flag_ghost_waters(system)

# Get the simulation speed.
speed = dynamics.time_speed()

Expand Down
Loading
Loading