Skip to content

Commit 3224542

Browse files
authored
Merge pull request #450 from stan-dev/project/release-1.0-281-optimize
Project/release 1.0 281 optimize
2 parents: 8ad7a9f + 1d8a65f; commit: 3224542

9 files changed

Lines changed: 665 additions & 46 deletions

File tree

cmdstanpy/cmdstan_args.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -315,6 +315,7 @@ def __init__(
315315
algorithm: Optional[str] = None,
316316
init_alpha: Optional[float] = None,
317317
iter: Optional[int] = None,
318+
save_iterations: bool = False,
318319
tol_obj: Optional[float] = None,
319320
tol_rel_obj: Optional[float] = None,
320321
tol_grad: Optional[float] = None,
@@ -326,6 +327,7 @@ def __init__(
326327
self.algorithm = algorithm
327328
self.init_alpha = init_alpha
328329
self.iter = iter
330+
self.save_iterations = save_iterations
329331
self.tol_obj = tol_obj
330332
self.tol_rel_obj = tol_rel_obj
331333
self.tol_grad = tol_grad
@@ -457,6 +459,8 @@ def compose(self, idx: int, cmd: List[str]) -> List[str]:
457459
cmd.append('history_size={}'.format(self.history_size))
458460
if self.iter is not None:
459461
cmd.append('iter={}'.format(self.iter))
462+
if self.save_iterations:
463+
cmd.append('save_iterations=1')
460464

461465
return cmd
462466

cmdstanpy/model.py

Lines changed: 15 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -392,6 +392,8 @@ def optimize(
392392
tol_param: Optional[float] = None,
393393
history_size: Optional[int] = None,
394394
iter: Optional[int] = None,
395+
save_iterations: bool = False,
396+
require_converged: bool = True,
395397
refresh: Optional[int] = None,
396398
time_fmt: str = "%Y%m%d%H%M%S",
397399
) -> CmdStanMLE:
@@ -477,6 +479,12 @@ def optimize(
477479
478480
:param iter: Total number of iterations
479481
482+
:param save_iterations: When ``True``, save intermediate approximations
483+
to the output CSV file. Default is ``False``.
484+
485+
:param require_converged: Whether or not to raise an error if Stan
486+
reports that "The algorithm may not have converged".
487+
480488
:param refresh: Specify the number of iterations cmdstan will take
481489
between progress messages. Default value is 100.
482490
@@ -496,6 +504,7 @@ def optimize(
496504
tol_param=tol_param,
497505
history_size=history_size,
498506
iter=iter,
507+
save_iterations=save_iterations,
499508
)
500509

501510
with MaybeDictToFilePath(data, inits) as (_data, _inits):
@@ -518,11 +527,11 @@ def optimize(
518527
self._run_cmdstan(runset, dummy_chain_id)
519528

520529
if not runset._check_retcodes():
521-
msg = 'Error during optimization:\n{}'.format(runset.get_err_msgs())
522-
msg = '{}Command and output files:\n{}'.format(
523-
msg, runset.__repr__()
524-
)
525-
raise RuntimeError(msg)
530+
msg = 'Error during optimization: {}'.format(runset.get_err_msgs())
531+
if 'Line search failed' in msg and not require_converged:
532+
get_logger().warning(msg)
533+
else:
534+
raise RuntimeError(msg)
526535
mle = CmdStanMLE(runset)
527536
return mle
528537

@@ -1132,7 +1141,7 @@ def variational(
11321141
:param output_samples: Number of approximate posterior output draws
11331142
to save.
11341143
1135-
:param require_converged: Whether or not to raise an error if stan
1144+
:param require_converged: Whether or not to raise an error if Stan
11361145
reports that "The algorithm may not have converged".
11371146
11381147
:param refresh: Specify the number of iterations cmdstan will take

cmdstanpy/stanfit.py

Lines changed: 174 additions & 34 deletions
Original file line number | Diff line number | Diff line change
@@ -304,23 +304,29 @@ def get_err_msgs(self) -> str:
304304
self._chain_ids[i], fd.read()
305305
)
306306
)
307-
# pre 2.27, all msgs sent to stdout, including errors
308-
if (
309-
not cmdstan_version_at(2, 27)
310-
and os.path.exists(self._stdout_files[i])
311-
and os.stat(self._stdout_files[i]).st_size > 0
312-
):
313-
with open(self._stdout_files[i], 'r') as fd:
314-
contents = fd.read()
315-
# pattern matches initial "Exception" or "Error" msg
316-
pat = re.compile(r'^E[rx].*$', re.M)
317-
errors = re.findall(pat, contents)
318-
if len(errors) > 0:
319-
msgs.append(
320-
'chain_id {}:\n\t{}\n'.format(
321-
self._chain_ids[i], '\n\t'.join(errors)
322-
)
323-
)
307+
# pre 2.27, all sampler msgs go to stdout, including errors
308+
if self._args.method == Method.SAMPLE:
309+
if (
310+
not cmdstan_version_at(2, 27)
311+
and os.path.exists(self._stdout_files[i])
312+
and os.stat(self._stdout_files[i]).st_size > 0
313+
):
314+
with open(self._stdout_files[i], 'r') as fd:
315+
contents = fd.read()
316+
# pattern matches initial "Exception" or "Error" msg
317+
pat = re.compile(r'^E[rx].*$', re.M)
318+
errors = re.findall(pat, contents)
319+
if len(errors) > 0:
320+
msgs.append(
321+
'chain_id {}:\n\t{}\n'.format(
322+
self._chain_ids[i], '\n\t'.join(errors)
323+
)
324+
)
325+
elif self._args.method == Method.OPTIMIZE:
326+
msgs.append('console log output:\n')
327+
with open(self._stdout_files[0], 'r') as fd:
328+
msgs.append(fd.read())
329+
324330
return '\n'.join(msgs)
325331

326332
def save_csvfiles(self, dir: Optional[str] = None) -> None:
@@ -1230,6 +1236,13 @@ def __init__(self, runset: RunSet) -> None:
12301236
'found method {}'.format(runset.method)
12311237
)
12321238
self.runset = runset
1239+
# info from runset to be exposed
1240+
self.converged = runset._check_retcodes()
1241+
optimize_args = self.runset._args.method_args
1242+
assert isinstance(
1243+
optimize_args, OptimizeArgs
1244+
) # make the typechecker happy
1245+
self._save_iterations = optimize_args.save_iterations
12331246
self._set_mle_attrs(runset.csv_files[0])
12341247

12351248
def __repr__(self) -> str:
@@ -1241,14 +1254,22 @@ def __repr__(self) -> str:
12411254
'\n\t'.join(self.runset.csv_files),
12421255
'\n\t'.join(self.runset.stdout_files),
12431256
)
1244-
# TODO - profiling files
1257+
if not self.converged:
1258+
repr = '{}\n Warning: invalid estimate, '.format(repr)
1259+
repr = '{} optimization failed to converge.'.format(repr)
12451260
return repr
12461261

12471262
def _set_mle_attrs(self, sample_csv_0: str) -> None:
1248-
meta = scan_optimize_csv(sample_csv_0)
1263+
meta = scan_optimize_csv(sample_csv_0, self._save_iterations)
12491264
self._metadata = InferenceMetadata(meta)
12501265
self._column_names: Tuple[str, ...] = meta['column_names']
1251-
self._mle: List[float] = meta['mle']
1266+
assert isinstance(meta['mle'], np.ndarray) # make the typechecker happy
1267+
self._mle = meta['mle']
1268+
if self._save_iterations:
1269+
assert isinstance(
1270+
meta['all_iters'], np.ndarray
1271+
) # make the typechecker happy
1272+
self._all_iters = meta['all_iters']
12521273

12531274
@property
12541275
def column_names(self) -> Tuple[str, ...]:
@@ -1269,21 +1290,89 @@ def metadata(self) -> InferenceMetadata:
12691290

12701291
@property
12711292
def optimized_params_np(self) -> np.ndarray:
1272-
"""Returns optimized params as numpy array."""
1273-
return np.asarray(self._mle)
1293+
"""
1294+
Returns all final estimates from the optimizer as a numpy.ndarray
1295+
which contains all optimizer outputs, i.e., the value for `lp__`
1296+
as well as all Stan program variables.
1297+
"""
1298+
if not self.converged:
1299+
get_logger().warning(
1300+
'Invalid estimate, optimization failed to converge.'
1301+
)
1302+
return self._mle
1303+
1304+
@property
1305+
def optimized_iterations_np(self) -> Optional[np.ndarray]:
1306+
"""
1307+
Returns all saved iterations from the optimizer and final estimate
1308+
as a numpy.ndarray which contains all optimizer outputs, i.e.,
1309+
the value for `lp__` as well as all Stan program variables.
1310+
1311+
"""
1312+
if not self._save_iterations:
1313+
get_logger().warning(
1314+
'Intermediate iterations not saved to CSV output file. '
1315+
'Rerun the optimize method with "save_iterations=True".'
1316+
)
1317+
return None
1318+
if not self.converged:
1319+
get_logger().warning(
1320+
'Invalid estimate, optimization failed to converge.'
1321+
)
1322+
return self._all_iters
12741323

12751324
@property
12761325
def optimized_params_pd(self) -> pd.DataFrame:
1277-
"""Returns optimized params as pandas DataFrame."""
1326+
"""
1327+
Returns all final estimates from the optimizer as a pandas.DataFrame
1328+
which contains all optimizer outputs, i.e., the value for `lp__`
1329+
as well as all Stan program variables.
1330+
"""
1331+
if not self.runset._check_retcodes():
1332+
get_logger().warning(
1333+
'Invalid estimate, optimization failed to converge.'
1334+
)
12781335
return pd.DataFrame([self._mle], columns=self.column_names)
12791336

1337+
@property
1338+
def optimized_iterations_pd(self) -> Optional[pd.DataFrame]:
1339+
"""
1340+
Returns all saved iterations from the optimizer and final estimate
1341+
as a pandas.DataFrame which contains all optimizer outputs, i.e.,
1342+
the value for `lp__` as well as all Stan program variables.
1343+
1344+
"""
1345+
if not self._save_iterations:
1346+
get_logger().warning(
1347+
'Intermediate iterations not saved to CSV output file. '
1348+
'Rerun the optimize method with "save_iterations=True".'
1349+
)
1350+
return None
1351+
if not self.converged:
1352+
get_logger().warning(
1353+
'Invalid estimate, optimization failed to converge.'
1354+
)
1355+
return pd.DataFrame(self._all_iters, columns=self.column_names)
1356+
12801357
@property
12811358
def optimized_params_dict(self) -> Dict[str, float]:
1282-
"""Returns optimized params as Dict."""
1359+
"""
1360+
Returns all estimates from the optimizer, including `lp__` as a
1361+
Python Dict. Only returns estimate from final iteration.
1362+
"""
1363+
if not self.runset._check_retcodes():
1364+
get_logger().warning(
1365+
'Invalid estimate, optimization failed to converge.'
1366+
)
12831367
return OrderedDict(zip(self.column_names, self._mle))
12841368

12851369
def stan_variable(
1286-
self, var: Optional[str] = None, *, name: Optional[str] = None
1370+
self,
1371+
var: Optional[str] = None,
1372+
*,
1373+
inc_iterations: bool = False,
1374+
warn: bool = True,
1375+
name: Optional[str] = None,
12871376
) -> np.ndarray:
12881377
"""
12891378
Return a numpy.ndarray which contains the estimates for the
@@ -1292,6 +1381,11 @@ def stan_variable(
12921381
12931382
:param var: variable name
12941383
1384+
:param inc_iterations: When ``True`` and the intermediate estimates
1385+
are included in the output, i.e., the optimizer was run with
1386+
``save_iterations=True``, then intermediate estimates are included.
1387+
Default value is ``False``.
1388+
12951389
See Also
12961390
--------
12971391
CmdStanMLE.stan_variables
@@ -1312,29 +1406,74 @@ def stan_variable(
13121406
raise ValueError('no variable name specified.')
13131407
if var not in self._metadata.stan_vars_dims:
13141408
raise ValueError('unknown variable name: {}'.format(var))
1315-
col_idxs = list(self._metadata.stan_vars_cols[var])
1316-
vals = list(self._mle)
1317-
xs = [vals[x] for x in col_idxs]
1318-
shape: Tuple[int, ...] = ()
1319-
if len(col_idxs) > 0:
1320-
shape = self._metadata.stan_vars_dims[var]
1321-
return np.array(xs).reshape(shape)
1409+
if warn and inc_iterations and not self._save_iterations:
1410+
get_logger().warning(
1411+
'Intermediate iterations not saved to CSV output file. '
1412+
'Rerun the optimize method with "save_iterations=True".'
1413+
)
1414+
if warn and not self.runset._check_retcodes():
1415+
get_logger().warning(
1416+
'Invalid estimate, optimization failed to converge.'
1417+
)
13221418

1323-
def stan_variables(self) -> Dict[str, np.ndarray]:
1419+
col_idxs = self._metadata.stan_vars_cols[var]
1420+
if inc_iterations and self._save_iterations:
1421+
num_rows = self._all_iters.shape[0]
1422+
else:
1423+
num_rows = 1
1424+
1425+
if len(col_idxs) > 0: # container var
1426+
dims = (num_rows,) + self._metadata.stan_vars_dims[var]
1427+
# pylint: disable=redundant-keyword-arg
1428+
if num_rows > 1:
1429+
result = self._all_iters[:, col_idxs].reshape( # type: ignore
1430+
dims, order='F'
1431+
)
1432+
else:
1433+
mle = np.expand_dims(self._mle, axis=0) # hack for col indexing
1434+
result = (
1435+
mle[0, col_idxs]
1436+
.reshape(dims, order='F') # type: ignore
1437+
.squeeze(axis=0)
1438+
)
1439+
else: # scalar var
1440+
if num_rows > 1:
1441+
result = self._all_iters[:, col_idxs]
1442+
else:
1443+
result = np.atleast_1d(mle[0, col_idxs])
1444+
1445+
assert isinstance(result, np.ndarray) # make the typechecker happy
1446+
return result
1447+
1448+
def stan_variables(
1449+
self, inc_iterations: bool = False
1450+
) -> Dict[str, np.ndarray]:
13241451
"""
13251452
Return a dictionary mapping Stan program variables names
13261453
to the corresponding numpy.ndarray containing the inferred values.
13271454
1455+
:param inc_iterations: When ``True`` and the intermediate estimates
1456+
are included in the output, i.e., the optimizer was run with
1457+
``save_iterations=True``, then intermediate estimates are included.
1458+
Default value is ``False``.
1459+
1460+
13281461
See Also
13291462
--------
13301463
CmdStanMLE.stan_variable
13311464
CmdStanMCMC.stan_variables
13321465
CmdStanVB.stan_variables
13331466
CmdStanGQ.stan_variables
13341467
"""
1468+
if not self.runset._check_retcodes():
1469+
get_logger().warning(
1470+
'Invalid estimate, optimization failed to converge.'
1471+
)
13351472
result = {}
13361473
for name in self._metadata.stan_vars_dims.keys():
1337-
result[name] = self.stan_variable(name)
1474+
result[name] = self.stan_variable(
1475+
name, inc_iterations=inc_iterations, warn=False
1476+
)
13381477
return result
13391478

13401479
def save_csvfiles(self, dir: Optional[str] = None) -> None:
@@ -2218,6 +2357,7 @@ def from_csv(
22182357
)
22192358
optimize_args = OptimizeArgs(
22202359
algorithm=config_dict['algorithm'],
2360+
save_iterations=config_dict['save_iterations'],
22212361
)
22222362
cmdstan_args = CmdStanArgs(
22232363
model_name=config_dict['model'],

0 commit comments

Comments
 (0)