Skip to content

Commit 1a06f9c

Browse files
committed
optimize - save_iterations, require_converged; unit tests
1 parent 0aed861 commit 1a06f9c

7 files changed

Lines changed: 451 additions & 39 deletions

File tree

cmdstanpy/model.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -398,6 +398,7 @@ def optimize(
398398
history_size: Optional[int] = None,
399399
iter: Optional[int] = None,
400400
save_iterations: bool = False,
401+
require_converged: bool = True,
401402
refresh: Optional[int] = None,
402403
) -> CmdStanMLE:
403404
"""
@@ -485,6 +486,9 @@ def optimize(
485486
:param save_iterations: When ``True``, save intermediate approximations
486487
to the output CSV file. Default is ``False``.
487488
489+
:param require_converged: Whether or not to raise an error if Stan
490+
reports that optimization failed with "Line search failed". Default is ``True``.
491+
488492
:param refresh: Specify the number of iterations cmdstan will take
489493
between progress messages. Default value is 100.
490494
@@ -524,7 +528,10 @@ def optimize(
524528

525529
if not runset._check_retcodes():
526530
msg = 'Error during optimization: {}'.format(runset.get_err_msgs())
527-
get_logger().warn(msg) # https://github.com/stan-dev/cmdstanr/issues/314
531+
if 'Line search failed' in msg and not require_converged:
532+
get_logger().warning(msg)
533+
else:
534+
raise RuntimeError(msg)
528535
mle = CmdStanMLE(runset)
529536
return mle
530537

@@ -1119,7 +1126,7 @@ def variational(
11191126
:param output_samples: Number of approximate posterior output draws
11201127
to save.
11211128
1122-
:param require_converged: Whether or not to raise an error if stan
1129+
:param require_converged: Whether or not to raise an error if Stan
11231130
reports that "The algorithm may not have converged".
11241131
11251132
:param refresh: Specify the number of iterations cmdstan will take

cmdstanpy/stanfit.py

Lines changed: 87 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -319,8 +319,8 @@ def get_err_msgs(self) -> str:
319319
msgs.append(
320320
'chain_id {}:\n\t{}\n'.format(
321321
self._chain_ids[i], '\n\t'.join(errors)
322-
)
323322
)
323+
)
324324
elif self._args.method == Method.OPTIMIZE:
325325
msgs.append('console log output:\n')
326326
with open(self._stdout_files[0], 'r') as fd:
@@ -800,7 +800,7 @@ def _assemble_draws(self) -> None:
800800
line = fd.readline().strip() # metric type
801801
line = fd.readline().lstrip(' #\t')
802802
num_unconstrained_params = len(line.split(','))
803-
if chain == 0: # can't allocate w/o num params
803+
if chain == 0: # can't allocate w/o num params
804804
if self.metric_type == 'diag_e':
805805
self._metric = np.empty(
806806
(self.chains, num_unconstrained_params),
@@ -1235,6 +1235,13 @@ def __init__(self, runset: RunSet) -> None:
12351235
'found method {}'.format(runset.method)
12361236
)
12371237
self.runset = runset
1238+
# info from runset to be exposed
1239+
self.converged = runset._check_retcodes()
1240+
optimize_args = self.runset._args.method_args
1241+
assert isinstance(
1242+
optimize_args, OptimizeArgs
1243+
) # make the typechecker happy
1244+
self.save_iterations = optimize_args.save_iterations
12381245
self._set_mle_attrs(runset.csv_files[0])
12391246

12401247
def __repr__(self) -> str:
@@ -1246,16 +1253,18 @@ def __repr__(self) -> str:
12461253
'\n\t'.join(self.runset.csv_files),
12471254
'\n\t'.join(self.runset.stdout_files),
12481255
)
1249-
if not self.runset._check_retcodes():
1256+
if not self.converged:
12501257
repr = '{}\n Warning: invalid estimate, '.format(repr)
12511258
repr = '{} optimization failed to converge.'.format(repr)
12521259
return repr
12531260

12541261
def _set_mle_attrs(self, sample_csv_0: str) -> None:
1255-
meta = scan_optimize_csv(sample_csv_0)
1262+
meta = scan_optimize_csv(sample_csv_0, self.save_iterations)
12561263
self._metadata = InferenceMetadata(meta)
12571264
self._column_names: Tuple[str, ...] = meta['column_names']
1258-
self._mle = meta['mle']
1265+
self._mle = meta['mle']
1266+
if self.save_iterations:
1267+
self._all_iters = meta['all_iters']
12591268

12601269
@property
12611270
def column_names(self) -> Tuple[str, ...]:
@@ -1276,36 +1285,90 @@ def metadata(self) -> InferenceMetadata:
12761285

12771286
@property
12781287
def optimized_params_np(self) -> np.ndarray:
1279-
"""Returns optimized params as numpy array."""
1280-
if not self.runset._check_retcodes():
1288+
"""
1289+
Returns all final estimates from the optimizer as a numpy.ndarray
1290+
which contains all optimizer outputs, i.e., the value for `lp__`
1291+
as well as all Stan program variables.
1292+
"""
1293+
if not self.converged:
12811294
get_logger().warning(
1282-
'invalid estimate, optimization failed to converge'
1295+
'Invalid estimate, optimization failed to converge.'
12831296
)
1284-
# TODO: squeeze?
12851297
return self._mle
12861298

1299+
@property
1300+
def optimized_iterations_np(self) -> np.ndarray:
1301+
"""
1302+
Returns all saved iterations from the optimizer and final estimate
1303+
as a numpy.ndarray which contains all optimizer outputs, i.e.,
1304+
the value for `lp__` as well as all Stan program variables.
1305+
1306+
"""
1307+
if not self.save_iterations:
1308+
get_logger().warning(
1309+
'Intermediate iterations not saved because optimizer argument '
1310+
'"save_iterations=True" not specified. You must rerun '
1311+
'the optimize method accordingly.'
1312+
)
1313+
return None
1314+
if not self.converged:
1315+
get_logger().warning(
1316+
'Invalid estimate, optimization failed to converge.'
1317+
)
1318+
return self._all_iters
1319+
12871320
@property
12881321
def optimized_params_pd(self) -> pd.DataFrame:
1289-
"""Returns optimized params as pandas DataFrame."""
1322+
"""
1323+
Returns all final estimates from the optimizer as a pandas.DataFrame
1324+
which contains all optimizer outputs, i.e., the value for `lp__`
1325+
as well as all Stan program variables.
1326+
"""
12901327
if not self.runset._check_retcodes():
12911328
get_logger().warning(
1292-
'invalid estimate, optimization failed to converge'
1329+
'Invalid estimate, optimization failed to converge.'
12931330
)
1294-
return pd.DataFrame(self._mle, columns=self.column_names)
1331+
return pd.DataFrame([self._mle], columns=self.column_names)
1332+
1333+
@property
1334+
def optimized_iterations_pd(self) -> pd.DataFrame:
1335+
"""
1336+
Returns all saved iterations from the optimizer and final estimate
1337+
as a pandas.DataFrame which contains all optimizer outputs, i.e.,
1338+
the value for `lp__` as well as all Stan program variables.
1339+
1340+
"""
1341+
if not self.save_iterations:
1342+
get_logger().warning(
1343+
'Intermediate iterations not saved because optimizer argument '
1344+
'"save_iterations=True" not specified. You must rerun '
1345+
'the optimize method accordingly.'
1346+
)
1347+
return None
1348+
if not self.converged:
1349+
get_logger().warning(
1350+
'Invalid estimate, optimization failed to converge.'
1351+
)
1352+
return pd.DataFrame(self._all_iters, columns=self.column_names)
12951353

12961354
@property
12971355
def optimized_params_dict(self) -> Dict[str, float]:
1298-
"""Returns optimized params as Dict."""
1356+
"""
1357+
Returns the optimizer's estimates, including `lp__`, as a
1358+
Python Dict. Only the estimate from the final iteration is returned.
1359+
"""
12991360
if not self.runset._check_retcodes():
13001361
get_logger().warning(
1301-
'invalid estimate, optimization failed to converge'
1362+
'Invalid estimate, optimization failed to converge.'
13021363
)
1303-
# TODO: return final estimate only
13041364
return OrderedDict(zip(self.column_names, self._mle))
13051365

13061366
def stan_variable(
1307-
self, var: Optional[str] = None,
1308-
check_convergence:bool = True, *, name: Optional[str] = None
1367+
self,
1368+
var: Optional[str] = None,
1369+
*,
1370+
warn: bool = True,
1371+
name: Optional[str] = None,
13091372
) -> np.ndarray:
13101373
"""
13111374
Return a numpy.ndarray which contains the estimates for the
@@ -1314,11 +1377,6 @@ def stan_variable(
13141377
13151378
:param var: variable name
13161379
1317-
:param check_convergence: Checks for failure to converge and
1318-
prints warning.failed to converge. ``False`` will supress
1319-
check and warning, default is ``True``.
1320-
1321-
13221380
See Also
13231381
--------
13241382
CmdStanMLE.stan_variables
@@ -1339,43 +1397,38 @@ def stan_variable(
13391397
raise ValueError('no variable name specified.')
13401398
if var not in self._metadata.stan_vars_dims:
13411399
raise ValueError('unknown variable name: {}'.format(var))
1342-
if check_convergence and not self.runset._check_retcodes():
1400+
if warn and not self.runset._check_retcodes():
13431401
get_logger().warning(
1344-
'invalid estimate, optimization failed to converge'
1402+
'Invalid estimate, optimization failed to converge.'
13451403
)
13461404

13471405
col_idxs = list(self._metadata.stan_vars_cols[var])
1348-
# TODO: return final estimate only
13491406
vals = list(self._mle)
13501407
xs = [vals[x] for x in col_idxs]
13511408
shape: Tuple[int, ...] = ()
13521409
if len(col_idxs) > 0:
13531410
shape = self._metadata.stan_vars_dims[var]
13541411
return np.array(xs).reshape(shape)
13551412

1356-
def stan_variables(self, check_convergence:bool = True) -> Dict[str, np.ndarray]:
1413+
def stan_variables(self) -> Dict[str, np.ndarray]:
13571414
"""
13581415
Return a dictionary mapping Stan program variables names
13591416
to the corresponding numpy.ndarray containing the inferred values.
13601417
1361-
:param check_convergence: Checks for failure to converge and
1362-
prints warning.failed to converge. ``False`` will supress
1363-
check and warning, default is ``True``.
1364-
13651418
See Also
13661419
--------
13671420
CmdStanMLE.stan_variable
13681421
CmdStanMCMC.stan_variables
13691422
CmdStanVB.stan_variables
13701423
CmdStanGQ.stan_variables
13711424
"""
1372-
if check_convergence and not self.runset._check_retcodes():
1425+
if not self.runset._check_retcodes():
13731426
get_logger().warning(
1374-
'invalid estimate, optimization failed to converge'
1427+
'Invalid estimate, optimization failed to converge.'
13751428
)
13761429
result = {}
13771430
for name in self._metadata.stan_vars_dims.keys():
1378-
result[name] = self.stan_variable(name, False) # don't warn twice
1431+
result[name] = self.stan_variable(name, warn=False)
13791432
return result
13801433

13811434
def save_csvfiles(self, dir: Optional[str] = None) -> None:
@@ -2259,6 +2312,7 @@ def from_csv(
22592312
)
22602313
optimize_args = OptimizeArgs(
22612314
algorithm=config_dict['algorithm'],
2315+
save_iterations=config_dict['save_iterations'],
22622316
)
22632317
cmdstan_args = CmdStanArgs(
22642318
model_name=config_dict['model'],

cmdstanpy/utils.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -589,10 +589,11 @@ def scan_optimize_csv(path: str, save_iters: bool = False) -> Dict[str, Any]:
589589
iters = 0
590590
for line in fd:
591591
iters += 1
592+
if save_iters:
593+
all_iters = np.empty(
594+
(iters, len(dict['column_names'])), dtype=float, order='F'
595+
)
592596
# rescan to capture estimates
593-
mle = np.empty(
594-
(iters, len(dict['column_names'])), dtype=float, order='F'
595-
)
596597
with open(path, 'r') as fd:
597598
for i in range(lineno):
598599
fd.readline()
@@ -605,8 +606,13 @@ def scan_optimize_csv(path: str, save_iters: bool = False) -> Dict[str, Any]:
605606
)
606607
)
607608
xs = line.split(',')
608-
mle[i, :] = [float(x) for x in xs]
609+
if save_iters:
610+
all_iters[i, :] = [float(x) for x in xs]
611+
if i == iters - 1:
612+
mle = np.array([float(x) for x in xs], dtype=float)
609613
dict['mle'] = mle
614+
if save_iters:
615+
dict['all_iters'] = all_iters
610616
return dict
611617

612618

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
J <- 8
2+
y <- c(28, 8, -3, 7, -1, 1, 18, 12)
3+
sigma <- c(15, 10, 16, 11, 9, 11, 10, 18)
4+
tau <- 25
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
data {
2+
int<lower=0> J; // number of schools
3+
real y[J]; // estimated treatment effect (school j)
4+
real<lower=0> sigma[J]; // std err of effect estimate (school j)
5+
}
6+
parameters {
7+
real mu;
8+
real theta[J];
9+
real<lower=0> tau;
10+
}
11+
model {
12+
theta ~ normal(mu, tau);
13+
y ~ normal(theta,sigma);
14+
}

0 commit comments

Comments
 (0)