Skip to content

Commit b666000

Browse files
committed
checkpointing
1 parent 3348211 commit b666000

4 files changed

Lines changed: 83 additions & 28 deletions

File tree

cmdstanpy/cmdstan_args.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,7 @@ def __init__(
315315
algorithm: Optional[str] = None,
316316
init_alpha: Optional[float] = None,
317317
iter: Optional[int] = None,
318+
save_iterations: bool = False,
318319
tol_obj: Optional[float] = None,
319320
tol_rel_obj: Optional[float] = None,
320321
tol_grad: Optional[float] = None,
@@ -326,6 +327,7 @@ def __init__(
326327
self.algorithm = algorithm
327328
self.init_alpha = init_alpha
328329
self.iter = iter
330+
self.save_iterations = save_iterations
329331
self.tol_obj = tol_obj
330332
self.tol_rel_obj = tol_rel_obj
331333
self.tol_grad = tol_grad
@@ -457,6 +459,8 @@ def compose(self, idx: int, cmd: List[str]) -> List[str]:
457459
cmd.append('history_size={}'.format(self.history_size))
458460
if self.iter is not None:
459461
cmd.append('iter={}'.format(self.iter))
462+
if self.save_iterations:
463+
cmd.append('save_iterations=1')
460464

461465
return cmd
462466

cmdstanpy/model.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -397,6 +397,7 @@ def optimize(
397397
tol_param: Optional[float] = None,
398398
history_size: Optional[int] = None,
399399
iter: Optional[int] = None,
400+
save_iterations: bool = False,
400401
refresh: Optional[int] = None,
401402
) -> CmdStanMLE:
402403
"""
@@ -481,6 +482,9 @@ def optimize(
481482
482483
:param iter: Total number of iterations
483484
485+
:param save_iterations: When ``True``, save intermediate approximations
486+
to the output CSV file. Default is ``False``.
487+
484488
:param refresh: Specify the number of iterations cmdstan will take
485489
between progress messages. Default value is 100.
486490
@@ -496,6 +500,7 @@ def optimize(
496500
tol_param=tol_param,
497501
history_size=history_size,
498502
iter=iter,
503+
save_iterations=save_iterations,
499504
)
500505

501506
with MaybeDictToFilePath(data, inits) as (_data, _inits):
@@ -518,11 +523,8 @@ def optimize(
518523
self._run_cmdstan(runset, dummy_chain_id)
519524

520525
if not runset._check_retcodes():
521-
msg = 'Error during optimization:\n{}'.format(runset.get_err_msgs())
522-
msg = '{}Command and output files:\n{}'.format(
523-
msg, runset.__repr__()
524-
)
525-
raise RuntimeError(msg)
526+
msg = 'Error during optimization: {}'.format(runset.get_err_msgs())
527+
get_logger().warn(msg) # https://github.com/stan-dev/cmdstanr/issues/314
526528
mle = CmdStanMLE(runset)
527529
return mle
528530

cmdstanpy/stanfit.py

Lines changed: 60 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -303,23 +303,29 @@ def get_err_msgs(self) -> str:
303303
self._chain_ids[i], fd.read()
304304
)
305305
)
306-
# pre 2.27, all msgs sent to stdout, including errors
307-
if (
308-
not cmdstan_version_at(2, 27)
309-
and os.path.exists(self._stdout_files[i])
310-
and os.stat(self._stdout_files[i]).st_size > 0
311-
):
312-
with open(self._stdout_files[i], 'r') as fd:
313-
contents = fd.read()
314-
# pattern matches initial "Exception" or "Error" msg
315-
pat = re.compile(r'^E[rx].*$', re.M)
316-
errors = re.findall(pat, contents)
317-
if len(errors) > 0:
318-
msgs.append(
319-
'chain_id {}:\n\t{}\n'.format(
320-
self._chain_ids[i], '\n\t'.join(errors)
321-
)
322-
)
306+
# pre 2.27, all sampler msgs go to stdout, including errors
307+
if self._args.method == Method.SAMPLE:
308+
if (
309+
not cmdstan_version_at(2, 27)
310+
and os.path.exists(self._stdout_files[i])
311+
and os.stat(self._stdout_files[i]).st_size > 0
312+
):
313+
with open(self._stdout_files[i], 'r') as fd:
314+
contents = fd.read()
315+
# pattern matches initial "Exception" or "Error" msg
316+
pat = re.compile(r'^E[rx].*$', re.M)
317+
errors = re.findall(pat, contents)
318+
if len(errors) > 0:
319+
msgs.append(
320+
'chain_id {}:\n\t{}\n'.format(
321+
self._chain_ids[i], '\n\t'.join(errors)
322+
)
323+
)
324+
elif self._args.method == Method.OPTIMIZE:
325+
msgs.append('console log output:\n')
326+
with open(self._stdout_files[0], 'r') as fd:
327+
msgs.append(fd.read())
328+
323329
return '\n'.join(msgs)
324330

325331
def save_csvfiles(self, dir: Optional[str] = None) -> None:
@@ -1240,7 +1246,9 @@ def __repr__(self) -> str:
12401246
'\n\t'.join(self.runset.csv_files),
12411247
'\n\t'.join(self.runset.stdout_files),
12421248
)
1243-
# TODO - profiling files
1249+
if not self.runset._check_retcodes():
1250+
repr = '{}\n Warning: invalid estimate, '.format(repr)
1251+
repr = '{} optimization failed to converge.'.format(repr)
12441252
return repr
12451253

12461254
def _set_mle_attrs(self, sample_csv_0: str) -> None:
@@ -1269,20 +1277,33 @@ def metadata(self) -> InferenceMetadata:
12691277
@property
12701278
def optimized_params_np(self) -> np.ndarray:
12711279
"""Returns optimized params as numpy array."""
1280+
if not self.runset._check_retcodes():
1281+
get_logger().warning(
1282+
'invalid estimate, optimization failed to converge'
1283+
)
12721284
return np.asarray(self._mle)
12731285

12741286
@property
12751287
def optimized_params_pd(self) -> pd.DataFrame:
12761288
"""Returns optimized params as pandas DataFrame."""
1289+
if not self.runset._check_retcodes():
1290+
get_logger().warning(
1291+
'invalid estimate, optimization failed to converge'
1292+
)
12771293
return pd.DataFrame([self._mle], columns=self.column_names)
12781294

12791295
@property
12801296
def optimized_params_dict(self) -> Dict[str, float]:
12811297
"""Returns optimized params as Dict."""
1298+
if not self.runset._check_retcodes():
1299+
get_logger().warning(
1300+
'invalid estimate, optimization failed to converge'
1301+
)
12821302
return OrderedDict(zip(self.column_names, self._mle))
12831303

12841304
def stan_variable(
1285-
self, var: Optional[str] = None, *, name: Optional[str] = None
1305+
self, var: Optional[str] = None,
1306+
check_convergence:bool = True, *, name: Optional[str] = None
12861307
) -> np.ndarray:
12871308
"""
12881309
Return a numpy.ndarray which contains the estimates for the
@@ -1291,6 +1312,11 @@ def stan_variable(
12911312
12921313
:param var: variable name
12931314
1315+
:param check_convergence: Checks for failure to converge and
1316+
prints a warning. ``False`` will suppress the
1317+
check and warning, default is ``True``.
1318+
1319+
12941320
See Also
12951321
--------
12961322
CmdStanMLE.stan_variables
@@ -1311,6 +1337,11 @@ def stan_variable(
13111337
raise ValueError('no variable name specified.')
13121338
if var not in self._metadata.stan_vars_dims:
13131339
raise ValueError('unknown variable name: {}'.format(var))
1340+
if check_convergence and not self.runset._check_retcodes():
1341+
get_logger().warning(
1342+
'invalid estimate, optimization failed to converge'
1343+
)
1344+
13141345
col_idxs = list(self._metadata.stan_vars_cols[var])
13151346
vals = list(self._mle)
13161347
xs = [vals[x] for x in col_idxs]
@@ -1319,21 +1350,29 @@ def stan_variable(
13191350
shape = self._metadata.stan_vars_dims[var]
13201351
return np.array(xs).reshape(shape)
13211352

1322-
def stan_variables(self) -> Dict[str, np.ndarray]:
1353+
def stan_variables(self, check_convergence:bool = True) -> Dict[str, np.ndarray]:
13231354
"""
13241355
Return a dictionary mapping Stan program variables names
13251356
to the corresponding numpy.ndarray containing the inferred values.
13261357
1358+
:param check_convergence: Checks for failure to converge and
1359+
prints a warning. ``False`` will suppress the
1360+
check and warning, default is ``True``.
1361+
13271362
See Also
13281363
--------
13291364
CmdStanMLE.stan_variable
13301365
CmdStanMCMC.stan_variables
13311366
CmdStanVB.stan_variables
13321367
CmdStanGQ.stan_variables
13331368
"""
1369+
if check_convergence and not self.runset._check_retcodes():
1370+
get_logger().warning(
1371+
'invalid estimate, optimization failed to converge'
1372+
)
13341373
result = {}
13351374
for name in self._metadata.stan_vars_dims.keys():
1336-
result[name] = self.stan_variable(name)
1375+
result[name] = self.stan_variable(name, False) # don't warn twice
13371376
return result
13381377

13391378
def save_csvfiles(self, dir: Optional[str] = None) -> None:

cmdstanpy/utils.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -578,14 +578,24 @@ def scan_sampler_csv(path: str, is_fixed_param: bool = False) -> Dict[str, Any]:
578578
return dict
579579

580580

581-
def scan_optimize_csv(path: str) -> Dict[str, Any]:
581+
def scan_optimize_csv(path: str, save_iters: bool = False) -> Dict[str, Any]:
582582
"""Process optimizer stan_csv output file line by line."""
583583
dict: Dict[str, Any] = {}
584584
lineno = 0
585+
# scan to find config, header, num saved iters
585586
with open(path, 'r') as fd:
586587
lineno = scan_config(fd, dict, lineno)
587588
lineno = scan_column_names(fd, dict, lineno)
588-
line = fd.readline().lstrip(' #\t').rstrip()
589+
iters = 0
590+
for line in fd:
591+
iters += 1
592+
# rescan to capture estimates
593+
with open(path, 'r') as fd:
594+
for i in range(lineno):
595+
fd.readline()
596+
line = fd.readline()
597+
# allocate numpy ndarray for
598+
print(line)
589599
xs = line.split(',')
590600
dict['mle'] = [float(x) for x in xs]
591601
return dict

0 commit comments

Comments
 (0)