From 2a7897012f085a25c9b889b7db52b6d042238746 Mon Sep 17 00:00:00 2001 From: Lester Hedges Date: Mon, 15 Jun 2026 22:10:35 +0100 Subject: [PATCH 1/6] Remove per-replica stream files from repex restarts. --- src/somd2/runner/_base.py | 117 ++++++++++++++-------- src/somd2/runner/_repex.py | 87 ++++++++++++++--- src/somd2/runner/_runner.py | 184 +++++++++++++++++++++++++++++------ tests/runner/test_restart.py | 37 ++++--- 4 files changed, 320 insertions(+), 105 deletions(-) diff --git a/src/somd2/runner/_base.py b/src/somd2/runner/_base.py index 13a670e..c1ccdb7 100644 --- a/src/somd2/runner/_base.py +++ b/src/somd2/runner/_base.py @@ -1273,6 +1273,7 @@ def increment_filename(base_filename, suffix): lam = f"{lambda_value:.5f}" filenames = {} filenames["checkpoint"] = str(output_directory / f"checkpoint_{lam}.s3") + filenames["checkpoint_state"] = str(output_directory / f"checkpoint_{lam}.npz") filenames["energy_traj"] = str(output_directory / f"energy_traj_{lam}.parquet") filenames["trajectory"] = str(output_directory / f"traj_{lam}.dcd") filenames["trajectory_chunk"] = str(output_directory / f"traj_{lam}_") @@ -1628,18 +1629,22 @@ def get_last_config(output_directory): f"No config files found in {self._config.output_directory}, " "attempting to retrieve config from lambda = 0 checkpoint file." ) - try: - system_temp = _sr.stream.load( - str(self._config.output_directory / "checkpoint_0.00000.s3") - ) - except: - expdir = self._config.output_directory / "checkpoint_0.00000.s3" - _logger.error(f"Unable to load checkpoint file from {expdir}.") - raise + s3_path = self._config.output_directory / "checkpoint_0.00000.s3" + if s3_path.exists(): + try: + system_temp = _sr.stream.load(str(s3_path)) + except: + _logger.error(f"Unable to load checkpoint file from {s3_path}.") + raise + else: + self._last_config = dict(system_temp.property("config")) + config = self._config.as_dict(sire_compatible=True) + del system_temp else: - self._last_config = dict(system_temp.property("config")) - config = self._config.as_dict(sire_compatible=True) - del system_temp + raise OSError( + f"No config file found in {self._config.output_directory}. " + "Cannot validate restart config without a config.yaml file." + ) self._compare_configs(self._last_config, config) @@ -1854,6 +1859,8 @@ def _checkpoint( lambda_energy=None, lambda_grad=None, is_final_block=False, + context=None, + gcmc_sampler=None, ): """ Save a checkpoint file. @@ -1997,24 +2004,18 @@ def _checkpoint( for chunk in traj_chunks: _Path(chunk).unlink() - # Add config and lambda value to the system properties. - system.set_property( - "config", self._config.as_dict(sire_compatible=True) + # Write the checkpoint system to file. + self._write_checkpoint_system( + system, index, context=context, gcmc_sampler=gcmc_sampler ) - system.set_property("lambda", lam) - - # Delete all frames from the system. - system.delete_all_frames() - - # Stream the final system to file. - _sr.stream.save(system, self._filenames[index]["checkpoint"]) - # Create the final parquet file. - _dataframe_to_parquet( - df, - metadata=metadata, - filename=self._filenames[index]["energy_traj"], - ) + # Append the final block's energy data. If no parquet exists + # yet (e.g. checkpoint_frequency=0), create one from scratch. + _energy_traj = self._filenames[index]["energy_traj"] + if _Path(_energy_traj).exists(): + _parquet_append(_energy_traj, df.iloc[-self._energy_per_block :]) + else: + _dataframe_to_parquet(df, metadata=metadata, filename=_energy_traj) else: # Update the starting block if necessary. @@ -2034,27 +2035,23 @@ def _checkpoint( format=["DCD"], ) - # Encode the configuration and lambda value as system properties. - system.set_property( - "config", self._config.as_dict(sire_compatible=True) + # Write the checkpoint system to file. + self._write_checkpoint_system( + system, index, context=context, gcmc_sampler=gcmc_sampler ) - system.set_property("lambda", lam) - - # Delete all frames from the system. - system.delete_all_frames() - - # Stream the checkpoint to file. - _sr.stream.save(system, self._filenames[index]["checkpoint"]) # Skip parquet creation for post-equilibration checkpoints. if not is_post_equilibration: # Create the parquet file name. filename = self._filenames[index]["energy_traj"] - # Create the parquet file. - if block == self._start_block: + # At the start block of a restart, append to the existing + # parquet so that historical data is preserved. For fresh + # runs, overwrite (or create) the parquet file. + if block == self._start_block and not ( + self._is_restart and _Path(filename).exists() + ): _dataframe_to_parquet(df, metadata=metadata, filename=filename) - # Append to the parquet file. else: _parquet_append( filename, @@ -2066,6 +2063,35 @@ def _checkpoint( return index, None + def _write_checkpoint_system(self, system, index, context=None, gcmc_sampler=None): + """ + Write the system state to the checkpoint file. + + Subclasses may override this to store state differently, e.g. repex + records the simulation time in the dynamics cache pickle instead of + streaming a per-replica file. + + Parameters + ---------- + + system: :class: `System ` + The committed system to checkpoint. + + index: int + The index of the lambda window. + + context: openmm.Context, optional + The OpenMM context. Unused in the base implementation. + + gcmc_sampler: GCMCSampler, optional + The GCMC sampler. Unused in the base implementation. + """ + lam = self._lambda_values[index] + system.set_property("config", self._config.as_dict(sire_compatible=True)) + system.set_property("lambda", lam) + system.delete_all_frames() + _sr.stream.save(system, self._filenames[index]["checkpoint"]) + def _backup_checkpoint(self, index): """ Create a backup of the previous checkpoint files. @@ -2088,6 +2114,17 @@ def _backup_checkpoint(self, index): self._filenames[index]["checkpoint"], str(self._filenames[index]["checkpoint"]) + ".bak", ) + except Exception as e: + return index, e + + try: + # Backup the existing compact numpy checkpoint file, if it exists. + path = _Path(self._filenames[index]["checkpoint_state"]) + if path.exists() and path.stat().st_size > 0: + _copyfile( + self._filenames[index]["checkpoint_state"], + str(self._filenames[index]["checkpoint_state"]) + ".bak", + ) traj_filename = self._filenames[index]["trajectory"] except Exception as e: return index, e diff --git a/src/somd2/runner/_repex.py b/src/somd2/runner/_repex.py index bf73b30..8edcf5b 100644 --- a/src/somd2/runner/_repex.py +++ b/src/somd2/runner/_repex.py @@ -102,6 +102,7 @@ def __init__( self._lambdas = lambdas self._rest2_scale_factors = rest2_scale_factors self._states = _np.array(range(len(lambdas))) + self._time = None self._openmm_states = [None] * len(lambdas) self._gcmc_samplers = [None] * len(lambdas) self._gcmc_states = [None] * len(lambdas) @@ -136,8 +137,12 @@ def __setstate__(self, state): n = len(self._lambdas) if not hasattr(self, "_gcmc_stats"): self._gcmc_stats = [None] * n + if not hasattr(self, "_gcmc_states"): + self._gcmc_states = [None] * n if not hasattr(self, "_terminal_flip_stats"): self._terminal_flip_stats = [[0, 0]] * n + if not hasattr(self, "_time"): + self._time = None def __getstate__(self): """ @@ -149,6 +154,7 @@ def __getstate__(self): "_lambdas": self._lambdas, "_rest2_scale_factors": self._rest2_scale_factors, "_states": self._states, + "_time": self._time, "_openmm_states": self._openmm_states, # Don't pickle the GCMC samplers since they need to be recreated. "_gcmc_samplers": len(self._gcmc_samplers) * [None], @@ -820,7 +826,27 @@ def __init__(self, system, config): else: _logger.debug("Restarting from file") - time = self._system[0].time() + # Load the dynamics cache first so we can read the simulation time + # from it (new format). Old-format restarts with .s3 files fall + # back to reading the time from the loaded Sire system. + try: + with open(self._repex_state, "rb") as f: + self._dynamics_cache = _pickle.load(f) + except Exception as e: + _logger.error( + f"Could not load dynamics cache from {self._repex_state}: {e}" + ) + raise e + + # Derive the simulation time: prefer the value stored in the + # pickle (_time is set by the new-format _write_checkpoint_system); + # fall back to the Sire system for old-format checkpoints. + if self._dynamics_cache._time is not None and not isinstance( + self._system, list + ): + time = self._dynamics_cache._time + else: + time = self._system[0].time() # Check to see if the simulation is already complete. if self._done_file.exists(): @@ -853,15 +879,6 @@ def __init__(self, system, config): f"Restarting at time {time}, time remaining = {self._config.runtime - time}" ) - try: - with open(self._repex_state, "rb") as f: - self._dynamics_cache = _pickle.load(f) - except Exception as e: - _logger.error( - f"Could not load dynamics cache from {self._repex_state}: {e}" - ) - raise e - # Make sure the number of replicas is the same. if len(self._dynamics_cache._lambdas) != self._config.num_lambda: _logger.error( @@ -911,7 +928,14 @@ def __init__(self, system, config): # If restarting, subtract the time already run from the total runtime if self._config.restart: - time = self._system[0].time() + time = ( + self._dynamics_cache._time + if ( + self._dynamics_cache._time is not None + and not isinstance(self._system, list) + ) + else self._system[0].time() + ) self._config.runtime = str(self._config.runtime - time) # Work out the current block number. @@ -1845,6 +1869,43 @@ def _assemble_results(self, results): return matrix + def _check_restart(self): + """ + Check the output directory for a valid restart state. + + If per-replica checkpoint stream files (.s3) exist the base class is + used to load them (old format, backwards compatible). Otherwise the + repex state pickle is used and the original input system is returned + directly, since positions and velocities come from the OpenMM states + stored in the pickle. + """ + from pathlib import Path as _Path_local + + checkpoint_path = _Path_local(self._filenames[0]["checkpoint"]) + if checkpoint_path.exists(): + # Old format: load per-replica .s3 files via base class. + return super()._check_restart() + + repex_state = self._config.output_directory / "repex_state.pkl" + if not repex_state.exists(): + return False, self._system + + _logger.info( + "No checkpoint stream files found; restarting from repex state pickle." + ) + return True, self._system + + def _write_checkpoint_system(self, system, index, context=None, gcmc_sampler=None): + """ + Record the current simulation time in the dynamics cache. + + For repex, per-replica stream files are not written. The simulation + time is stored in the dynamics cache pickle instead, and positions and + velocities are already stored as compact numpy arrays in the OpenMM + state dict. + """ + self._dynamics_cache._time = system.time() + def _checkpoint(self, index, lambdas, block, num_blocks, is_final_block=False): """ Checkpoint the simulation. @@ -1886,10 +1947,6 @@ def _checkpoint(self, index, lambdas, block, num_blocks, is_final_block=False): # Commit the current system. system = dynamics.commit() - # If performing GCMC, then we need to flag the ghost waters. - if gcmc_sampler is not None: - system = gcmc_sampler._flag_ghost_waters(system) - # Get the simulation speed. speed = dynamics.time_speed() diff --git a/src/somd2/runner/_runner.py b/src/somd2/runner/_runner.py index 574eb9b..00ea69d 100644 --- a/src/somd2/runner/_runner.py +++ b/src/somd2/runner/_runner.py @@ -266,6 +266,69 @@ def run(self): # Cleanup backup files. self._cleanup() + def _check_restart(self): + """ + Check the output directory for a valid restart state. + + Detects new-format (.npz) checkpoints and falls back to the legacy + .s3 stream file format when only old checkpoints are present. + """ + from pathlib import Path as _Path + + npz_path = _Path(self._filenames[0]["checkpoint_state"]) + s3_path = _Path(self._filenames[0]["checkpoint"]) + + if npz_path.exists(): + _logger.info("Restarting from compact numpy checkpoint state.") + return True, self._system + elif s3_path.exists(): + return super()._check_restart() + else: + return False, self._system + + def _write_checkpoint_system(self, system, index, context=None, gcmc_sampler=None): + """ + Write the system state to a compact numpy checkpoint file. + + Saves positions, velocities, box vectors, simulation time, and (for + GCMC) ghost water indices to a .npz file. The legacy .s3 stream file + is not written. + + If no context is provided (should not happen in normal operation), + falls back to the base .s3 implementation. + """ + if context is None: + super()._write_checkpoint_system(system, index) + return + + import openmm.unit as _omm_unit + + state = context.getState(getPositions=True, getVelocities=True) + pos = state.getPositions(asNumpy=True).value_in_unit(_omm_unit.nanometer) + vel = state.getVelocities(asNumpy=True).value_in_unit( + _omm_unit.nanometer / _omm_unit.picosecond + ) + time_ps = system.time().to("ps") + + save_kwargs = { + "positions": pos, + "velocities": vel, + "time_ps": _np.array([time_ps]), + } + + box = state.getPeriodicBoxVectors(asNumpy=True) + if box is not None: + save_kwargs["box"] = box.value_in_unit(_omm_unit.nanometer) + + if gcmc_sampler is not None: + # water_state() returns 1 for active, 0 for ghost. + water_state = gcmc_sampler.water_state() + save_kwargs["ghost_water_indices"] = _np.where(water_state == 0)[0].astype( + _np.int32 + ) + + _np.savez(self._filenames[index]["checkpoint_state"], **save_kwargs) + def run_window(self, index): """ Run a single lamdba window. @@ -295,9 +358,16 @@ def run_window(self, index): if self._is_restart: _logger.debug(f"Restarting {_lam_sym} = {lambda_value} from file") - system = self._system[index].clone() - - time = system.time() + if isinstance(self._system, list): + # Old format: system with saved positions loaded from .s3 stream file. + system = self._system[index].clone() + time = system.time() + else: + # New format: original input system; time stored in .npz checkpoint. + system = self._system.clone() + time = _sr.u( + f"{float(_np.load(self._filenames[index]['checkpoint_state'])['time_ps'].item()):.6f} ps" + ) if time > self._config.runtime - self._config.timestep: _logger.success( f"{_lam_sym} = {lambda_value} already complete. Skipping." @@ -398,7 +468,14 @@ def _run( # Check for completion if this is a restart. if is_restart: - time = system.time() + if isinstance(self._system, list): + time = system.time() + else: + # New format: time stored in .npz, not in the Sire system. + time = _sr.u( + f"{float(_np.load(self._filenames[index]['checkpoint_state'])['time_ps'].item()):.6f} ps" + ) + system.set_time(time) if time > self._config.runtime - self._config.timestep: _logger.success( f"{_lam_sym} = {lambda_value} already complete. Skipping." @@ -644,6 +721,28 @@ def generate_lam_vals(lambda_base, increment=0.001): _logger.info(f"Writing OpenMM XML for {_lam_sym} = {lambda_value:.5f}") dynamics.to_xml(self._filenames[index]["xml"]) + # For new-format restarts, apply saved positions/velocities/box to context. + _new_format_restart = is_restart and not isinstance(self._system, list) + if _new_format_restart: + import openmm.unit as _omm_unit + + _npz_state = _np.load(self._filenames[index]["checkpoint_state"]) + dynamics.context().setPositions( + _npz_state["positions"] * _omm_unit.nanometer + ) + dynamics.context().setVelocities( + _npz_state["velocities"] * _omm_unit.nanometer / _omm_unit.picosecond + ) + if "box" in _npz_state: + from openmm import Vec3 as _Vec3 + + _box = _npz_state["box"] + dynamics.context().setPeriodicBoxVectors( + _Vec3(*_box[0]) * _omm_unit.nanometer, + _Vec3(*_box[1]) * _omm_unit.nanometer, + _Vec3(*_box[2]) * _omm_unit.nanometer, + ) + # Reset the GCMC sampler. This resets the sampling statistics and clears # the associated OpenMM forces. if gcmc_sampler is not None: @@ -655,31 +754,52 @@ def generate_lam_vals(lambda_base, increment=0.001): # If this is a restart, then we need to reset the GCMC water state # to match that of the restart system. if self._is_restart: - from openmm.unit import angstrom + if isinstance(self._system, list): + # Old format: restore ghost waters from cached indices/positions. + from openmm.unit import angstrom - gcmc_sampler.push() - try: - # First set all waters to non-ghosts. - gcmc_sampler._set_water_state( - dynamics.context(), - states=_np.ones(len(gcmc_sampler._water_indices)), - force=True, - ) + gcmc_sampler.push() + try: + # First set all waters to non-ghosts. + gcmc_sampler._set_water_state( + dynamics.context(), + states=_np.ones(len(gcmc_sampler._water_indices)), + force=True, + ) - # Now set the ghost waters. - gcmc_sampler._set_water_state( - dynamics.context(), - self._restart_ghost_waters[index], - states=_np.zeros(len(gcmc_sampler._water_indices)), - force=True, - ) - finally: - gcmc_sampler.pop() + # Now set the ghost waters. + gcmc_sampler._set_water_state( + dynamics.context(), + self._restart_ghost_waters[index], + states=_np.zeros(len(gcmc_sampler._water_indices)), + force=True, + ) + finally: + gcmc_sampler.pop() - # Finally, reset the context positions to match the restart system. - dynamics.context().setPositions( - self._restart_positions[index] * angstrom - ) + # Finally, reset the context positions to match the restart system. + dynamics.context().setPositions( + self._restart_positions[index] * angstrom + ) + else: + # New format: positions already applied; restore ghost water state. + ghost_idxs = _npz_state["ghost_water_indices"].tolist() + gcmc_sampler.push() + try: + gcmc_sampler._set_water_state( + dynamics.context(), + states=_np.ones(len(gcmc_sampler._water_indices)), + force=True, + ) + if ghost_idxs: + gcmc_sampler._set_water_state( + dynamics.context(), + ghost_idxs, + states=_np.zeros(len(gcmc_sampler._water_indices)), + force=True, + ) + finally: + gcmc_sampler.pop() # Otherwise, if we've performed equilibration, then we need to reset # the water state in the new context to match the equilibrated system. @@ -731,6 +851,8 @@ def generate_lam_vals(lambda_base, increment=0.001): speed=0.0, lambda_energy=lambda_energy, lambda_grad=lambda_grad, + context=dynamics.context(), + gcmc_sampler=gcmc_sampler, ) if error is not None: msg = ( @@ -940,10 +1062,6 @@ def generate_lam_vals(lambda_base, increment=0.001): # Commit the current system. system = dynamics.commit() - # If performing GCMC, then we need to flag the ghost waters. - if gcmc_sampler is not None: - system = gcmc_sampler._flag_ghost_waters(system) - # Record the end time. block_end = _timer() @@ -979,6 +1097,8 @@ def generate_lam_vals(lambda_base, increment=0.001): lambda_energy=lambda_energy, lambda_grad=lambda_grad, is_final_block=is_final_block, + context=dynamics.context(), + gcmc_sampler=gcmc_sampler, ) if error is not None: @@ -1097,6 +1217,8 @@ def generate_lam_vals(lambda_base, increment=0.001): lambda_energy=lambda_energy, lambda_grad=lambda_grad, is_final_block=True, + context=dynamics.context(), + gcmc_sampler=gcmc_sampler, ) # Delete all trajectory frames from the Sire system within the @@ -1300,6 +1422,8 @@ def generate_lam_vals(lambda_base, increment=0.001): lambda_energy=lambda_energy, lambda_grad=lambda_grad, is_final_block=True, + context=dynamics.context(), + gcmc_sampler=gcmc_sampler, ) if error is not None: diff --git a/tests/runner/test_restart.py b/tests/runner/test_restart.py index 6c1e2f6..fda3fcf 100644 --- a/tests/runner/test_restart.py +++ b/tests/runner/test_restart.py @@ -48,12 +48,13 @@ def test_restart(mols, request): [str(Path(tmpdir) / "system0.prm7"), str(Path(tmpdir) / "traj_0.00000.dcd")] ) - # Check that both config and lambda have been written - # as properties to the streamed checkpoint file. - checkpoint = sr.stream.load(str(Path(tmpdir) / "checkpoint_0.00000.s3")) - props = checkpoint.property_keys() - assert "config" in props - assert "lambda" in props + # Check that the compact numpy checkpoint file was written. + import numpy as np + + checkpoint_state = np.load(str(Path(tmpdir) / "checkpoint_0.00000.npz")) + assert "positions" in checkpoint_state + assert "velocities" in checkpoint_state + assert "time_ps" in checkpoint_state del runner @@ -199,26 +200,22 @@ def test_restart(mols, request): with pytest.raises(ValueError): runner_swapendstates = Runner(mols, Config(**config_diffswapendstates)) - # Need to test restart from sire checkpoint file - # this needs to be done last as it requires unlinking the config files + # Removing the config yaml should raise an OSError since the new-format + # checkpoint stores no config (the yaml is the sole validation source). for file in Path(tmpdir).glob("*.yaml"): file.unlink() - # This should work as the config is read from the lambda=0 checkpoint file - runner_noconfig = Runner(mols, Config(**config_new)) + with pytest.raises(OSError): + runner_noconfig = Runner(mols, Config(**config_new)) - # remove config again - for file in Path(tmpdir).glob("*.yaml"): - file.unlink() + # Write a config yaml with a wrong pressure value and verify restart fails. + import yaml - # Load the checkpoint file using sire and change the pressure option - sire_checkpoint = sr.stream.load(str(Path(tmpdir) / "checkpoint_0.00000.s3")) - cfg = sire_checkpoint.property("config") - cfg["pressure"] = "0.5 atm" - sire_checkpoint.set_property("config", cfg) - sr.stream.save(sire_checkpoint, str(Path(tmpdir) / "checkpoint_0.00000.s3")) + bad_config = config_new.copy() + bad_config["pressure"] = "0.5 atm" + with open(Path(tmpdir) / "config.yaml", "w") as f: + yaml.dump(bad_config, f) - # Load the new checkpoint file and make sure the restart fails with pytest.raises(ValueError): runner_badconfig = Runner(mols, Config(**config_new)) From 5187440cffab627bc0873d0c3ed80d4101997465 Mon Sep 17 00:00:00 2001 From: Lester Hedges Date: Tue, 16 Jun 2026 12:03:14 +0100 Subject: [PATCH 2/6] Report GCMC centre correctly on restart & don't pre-equilibrate on restart. --- src/somd2/runner/_repex.py | 43 +++++++++++++++++++++++++++---------- src/somd2/runner/_runner.py | 24 +++++++++++++-------- 2 files changed, 47 insertions(+), 20 deletions(-) diff --git a/src/somd2/runner/_repex.py b/src/somd2/runner/_repex.py index 8edcf5b..7d7eec5 100644 --- a/src/somd2/runner/_repex.py +++ b/src/somd2/runner/_repex.py @@ -103,6 +103,7 @@ def __init__( self._rest2_scale_factors = rest2_scale_factors self._states = _np.array(range(len(lambdas))) self._time = None + self._time_offset = None self._openmm_states = [None] * len(lambdas) self._gcmc_samplers = [None] * len(lambdas) self._gcmc_states = [None] * len(lambdas) @@ -143,6 +144,8 @@ def __setstate__(self, state): self._terminal_flip_stats = [[0, 0]] * n if not hasattr(self, "_time"): self._time = None + if not hasattr(self, "_time_offset"): + self._time_offset = None def __getstate__(self): """ @@ -304,15 +307,6 @@ def _create_dynamics( f"Created GCMC sampler for lambda {lam:.5f} on device {device}" ) - # Log the initial position of the GCMC sphere. - if self._gcmc_samplers[i]._reference is not None: - positions = _sr.io.get_coords_array(mols) - target = self._gcmc_samplers[i]._get_target_position(positions) - _logger.info( - f"Initial GCMC sphere centre for lambda {lam:.5f} on device {device}: " - f"[{target[0]:.3f}, {target[1]:.3f}, {target[2]:.3f}] A" - ) - # Create the dynamics object. try: dynamics = mols.dynamics(**dynamics_kwargs) @@ -823,6 +817,7 @@ def __init__(self, system, config): output_directory=self._config.output_directory, xml_filenames=xml_filenames, ) + else: _logger.debug("Restarting from file") @@ -848,6 +843,11 @@ def __init__(self, system, config): else: time = self._system[0].time() + # Store the absolute start time so _write_checkpoint_system can + # compute the correct absolute time across multiple restarts. + if not isinstance(self._system, list): + self._dynamics_cache._time_offset = time + # Check to see if the simulation is already complete. if self._done_file.exists(): # The runtime may have been extended beyond the previous run. @@ -922,6 +922,23 @@ def __init__(self, system, config): if self._dynamics_cache._gcmc_stats[i] is not None: gcmc_sampler.restore_stats(self._dynamics_cache._gcmc_stats[i]) + # Log the GCMC sphere centre for each replica using the actual context + # positions (accurate for both fresh runs and restarts). + import openmm.unit as _omm_unit + + for i, lam in enumerate(self._lambda_values): + dynamics, gcmc_sampler = self._dynamics_cache.get(i) + if gcmc_sampler is not None and gcmc_sampler._reference is not None: + state = dynamics.context().getState(getPositions=True) + positions = state.getPositions(asNumpy=True).value_in_unit( + _omm_unit.angstrom + ) + target = gcmc_sampler._get_target_position(positions) + _logger.info( + f"Initial GCMC sphere centre for lambda {lam:.5f}: " + f"[{target[0]:.3f}, {target[1]:.3f}, {target[2]:.3f}] A" + ) + # Conversion factor for reduced potential. kT = (_sr.units.k_boltz * self._config.temperature).to(_sr.units.kcal_per_mol) self._beta = 1.0 / kT @@ -1579,7 +1596,7 @@ def _minimise(self, index): # Get the dynamics object (and GCMC sampler). dynamics, gcmc_sampler = self._dynamics_cache.get(index) - if gcmc_sampler is not None: + if gcmc_sampler is not None and not self._is_restart: gcmc_sampler.push() try: _logger.info( @@ -1904,7 +1921,11 @@ def _write_checkpoint_system(self, system, index, context=None, gcmc_sampler=Non velocities are already stored as compact numpy arrays in the OpenMM state dict. """ - self._dynamics_cache._time = system.time() + offset = self._dynamics_cache._time_offset + elapsed = system.time() + self._dynamics_cache._time = ( + (offset + elapsed) if offset is not None else elapsed + ) def _checkpoint(self, index, lambdas, block, num_blocks, is_final_block=False): """ diff --git a/src/somd2/runner/_runner.py b/src/somd2/runner/_runner.py index 00ea69d..dcc24c0 100644 --- a/src/somd2/runner/_runner.py +++ b/src/somd2/runner/_runner.py @@ -530,15 +530,6 @@ def generate_lam_vals(lambda_base, increment=0.001): # Get the GCMC system. system = gcmc_sampler.system() - # Log the initial position of the GCMC sphere. - if gcmc_sampler._reference is not None: - positions = _sr.io.get_coords_array(system) - target = gcmc_sampler._get_target_position(positions) - _logger.info( - f"Initial GCMC sphere centre at {_lam_sym} = {lambda_value:.5f}: " - f"[{target[0]:.3f}, {target[1]:.3f}, {target[2]:.3f}] A" - ) - else: gcmc_sampler = None @@ -825,6 +816,21 @@ def generate_lam_vals(lambda_base, increment=0.001): attempted, accepted = stats["terminal_flip"] terminal_flip_sampler.reset(attempted, accepted) + # Log the GCMC sphere centre using the actual context positions + # (accurate for both fresh runs and restarts). + if gcmc_sampler is not None and gcmc_sampler._reference is not None: + import openmm.unit as _omm_unit + + state = dynamics.context().getState(getPositions=True) + positions = state.getPositions(asNumpy=True).value_in_unit( + _omm_unit.angstrom + ) + target = gcmc_sampler._get_target_position(positions) + _logger.info( + f"Initial GCMC sphere centre at {_lam_sym} = {lambda_value:.5f}: " + f"[{target[0]:.3f}, {target[1]:.3f}, {target[2]:.3f}] A" + ) + # Set the number of neighbours used for the energy calculation. # If not None, then we add one to account for the extra windows # used for finite-difference gradient analysis. From cc4cb4fa6fecf8d124c293ef1e17b077203e4acb Mon Sep 17 00:00:00 2001 From: Lester Hedges Date: Tue, 16 Jun 2026 12:10:39 +0100 Subject: [PATCH 3/6] Time is already a Sire GeneralUnit. --- src/somd2/runner/_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/somd2/runner/_runner.py b/src/somd2/runner/_runner.py index dcc24c0..08867cc 100644 --- a/src/somd2/runner/_runner.py +++ b/src/somd2/runner/_runner.py @@ -883,7 +883,7 @@ def generate_lam_vals(lambda_base, increment=0.001): # Handle the case where the runtime is less than the checkpoint frequency. if frac < 1.0: frac = 1.0 - checkpoint_frequency = _sr.u(f"{time} ps") + checkpoint_frequency = time checkpoint_interval = checkpoint_frequency.to("ns") num_blocks = int(frac) From 8dfae403c9eee2eb36ef1bfd932f3d2e16e111b0 Mon Sep 17 00:00:00 2001 From: Lester Hedges Date: Tue, 16 Jun 2026 12:18:22 +0100 Subject: [PATCH 4/6] Set the system time on restart. --- src/somd2/runner/_repex.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/somd2/runner/_repex.py b/src/somd2/runner/_repex.py index 7d7eec5..5dfc8a8 100644 --- a/src/somd2/runner/_repex.py +++ b/src/somd2/runner/_repex.py @@ -886,6 +886,11 @@ def __init__(self, system, config): f"does not match the number of replicas in the configuration ({self._config.num_lambda})." ) + # For new-format restarts, set the system time so that dynamics + # objects are initialised with the correct integrator step count. + if not isinstance(self._system, list): + self._system.set_time(time) + # Create the dynamics objects. self._dynamics_cache._create_dynamics( self._system, From b21b5652290099947bab40b627df147706b7848e Mon Sep 17 00:00:00 2001 From: Lester Hedges Date: Tue, 16 Jun 2026 12:24:29 +0100 Subject: [PATCH 5/6] Remove redundant time offset. --- src/somd2/runner/_repex.py | 14 +------------- 1 file changed, 1 insertion(+), 13 deletions(-) diff --git a/src/somd2/runner/_repex.py b/src/somd2/runner/_repex.py index 5dfc8a8..3c9ff7b 100644 --- a/src/somd2/runner/_repex.py +++ b/src/somd2/runner/_repex.py @@ -103,7 +103,6 @@ def __init__( self._rest2_scale_factors = rest2_scale_factors self._states = _np.array(range(len(lambdas))) self._time = None - self._time_offset = None self._openmm_states = [None] * len(lambdas) self._gcmc_samplers = [None] * len(lambdas) self._gcmc_states = [None] * len(lambdas) @@ -144,8 +143,6 @@ def __setstate__(self, state): self._terminal_flip_stats = [[0, 0]] * n if not hasattr(self, "_time"): self._time = None - if not hasattr(self, "_time_offset"): - self._time_offset = None def __getstate__(self): """ @@ -843,11 +840,6 @@ def __init__(self, system, config): else: time = self._system[0].time() - # Store the absolute start time so _write_checkpoint_system can - # compute the correct absolute time across multiple restarts. - if not isinstance(self._system, list): - self._dynamics_cache._time_offset = time - # Check to see if the simulation is already complete. if self._done_file.exists(): # The runtime may have been extended beyond the previous run. @@ -1926,11 +1918,7 @@ def _write_checkpoint_system(self, system, index, context=None, gcmc_sampler=Non velocities are already stored as compact numpy arrays in the OpenMM state dict. """ - offset = self._dynamics_cache._time_offset - elapsed = system.time() - self._dynamics_cache._time = ( - (offset + elapsed) if offset is not None else elapsed - ) + self._dynamics_cache._time = system.time() def _checkpoint(self, index, lambdas, block, num_blocks, is_final_block=False): """ From 519ce9225c015f46713e7fee84421e951544fa97 Mon Sep 17 00:00:00 2001 From: Lester Hedges Date: Tue, 16 Jun 2026 12:34:38 +0100 Subject: [PATCH 6/6] Only log for legacy restart path. --- src/somd2/runner/_repex.py | 5 +---- src/somd2/runner/_runner.py | 2 +- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/src/somd2/runner/_repex.py b/src/somd2/runner/_repex.py index 3c9ff7b..7da3ba3 100644 --- a/src/somd2/runner/_repex.py +++ b/src/somd2/runner/_repex.py @@ -1897,16 +1897,13 @@ def _check_restart(self): checkpoint_path = _Path_local(self._filenames[0]["checkpoint"]) if checkpoint_path.exists(): - # Old format: load per-replica .s3 files via base class. + _logger.info("Restarting from legacy stream file checkpoint.") return super()._check_restart() repex_state = self._config.output_directory / "repex_state.pkl" if not repex_state.exists(): return False, self._system - _logger.info( - "No checkpoint stream files found; restarting from repex state pickle." - ) return True, self._system def _write_checkpoint_system(self, system, index, context=None, gcmc_sampler=None): diff --git a/src/somd2/runner/_runner.py b/src/somd2/runner/_runner.py index 08867cc..fb3980f 100644 --- a/src/somd2/runner/_runner.py +++ b/src/somd2/runner/_runner.py @@ -279,9 +279,9 @@ def _check_restart(self): s3_path = _Path(self._filenames[0]["checkpoint"]) if npz_path.exists(): - _logger.info("Restarting from compact numpy checkpoint state.") return True, self._system elif s3_path.exists(): + _logger.info("Restarting from legacy stream file checkpoint.") return super()._check_restart() else: return False, self._system