@@ -319,8 +319,8 @@ def get_err_msgs(self) -> str:
319319 msgs .append (
320320 'chain_id {}:\n \t {}\n ' .format (
321321 self ._chain_ids [i ], '\n \t ' .join (errors )
322- )
323322 )
323+ )
324324 elif self ._args .method == Method .OPTIMIZE :
325325 msgs .append ('console log output:\n ' )
326326 with open (self ._stdout_files [0 ], 'r' ) as fd :
@@ -800,7 +800,7 @@ def _assemble_draws(self) -> None:
800800 line = fd .readline ().strip () # metric type
801801 line = fd .readline ().lstrip (' #\t ' )
802802 num_unconstrained_params = len (line .split (',' ))
803- if chain == 0 : # can't allocate w/o num params
803+ if chain == 0 : # can't allocate w/o num params
804804 if self .metric_type == 'diag_e' :
805805 self ._metric = np .empty (
806806 (self .chains , num_unconstrained_params ),
@@ -1235,6 +1235,13 @@ def __init__(self, runset: RunSet) -> None:
12351235 'found method {}' .format (runset .method )
12361236 )
12371237 self .runset = runset
1238+ # info from runset to be exposed
1239+ self .converged = runset ._check_retcodes ()
1240+ optimize_args = self .runset ._args .method_args
1241+ assert isinstance (
1242+ optimize_args , OptimizeArgs
1243+ ) # make the typechecker happy
1244+ self .save_iterations = optimize_args .save_iterations
12381245 self ._set_mle_attrs (runset .csv_files [0 ])
12391246
12401247 def __repr__ (self ) -> str :
@@ -1246,16 +1253,18 @@ def __repr__(self) -> str:
12461253 '\n \t ' .join (self .runset .csv_files ),
12471254 '\n \t ' .join (self .runset .stdout_files ),
12481255 )
1249- if not self .runset . _check_retcodes () :
1256+ if not self .converged :
12501257 repr = '{}\n Warning: invalid estimate, ' .format (repr )
12511258 repr = '{} optimization failed to converge.' .format (repr )
12521259 return repr
12531260
12541261 def _set_mle_attrs (self , sample_csv_0 : str ) -> None :
1255- meta = scan_optimize_csv (sample_csv_0 )
1262+ meta = scan_optimize_csv (sample_csv_0 , self . save_iterations )
12561263 self ._metadata = InferenceMetadata (meta )
12571264 self ._column_names : Tuple [str , ...] = meta ['column_names' ]
1258- self ._mle = meta ['mle' ]
1265+ self ._mle = meta ['mle' ]
1266+ if self .save_iterations :
1267+ self ._all_iters = meta ['all_iters' ]
12591268
12601269 @property
12611270 def column_names (self ) -> Tuple [str , ...]:
@@ -1276,36 +1285,90 @@ def metadata(self) -> InferenceMetadata:
12761285
12771286 @property
12781287 def optimized_params_np (self ) -> np .ndarray :
1279- """Returns optimized params as numpy array."""
1280- if not self .runset ._check_retcodes ():
1288+ """
1289+ Returns all final estimates from the optimizer as a numpy.ndarray
1290+ which contains all optimizer outputs, i.e., the value for `lp__`
1291+ as well as all Stan program variables.
1292+ """
1293+ if not self .converged :
12811294 get_logger ().warning (
1282- 'invalid estimate, optimization failed to converge'
1295+ 'Invalid estimate, optimization failed to converge. '
12831296 )
1284- # TODO: squeeze?
12851297 return self ._mle
12861298
1299+ @property
1300+ def optimized_iterations_np (self ) -> np .ndarray :
1301+ """
1302+ Returns all saved iterations from the optimizer and final estimate
1303+ as a numpy.ndarray which contains all optimizer outputs, i.e.,
1304+ the value for `lp__` as well as all Stan program variables.
1305+
1306+ """
1307+ if not self .save_iterations :
1308+ get_logger ().warning (
1309+ 'Intermediate iterations not saved because optimizer argument '
1310+ '"save_iterations=True" not specified. You must rerun '
1311+ 'the optimize method accordingly.'
1312+ )
1313+ return None
1314+ if not self .converged :
1315+ get_logger ().warning (
1316+ 'Invalid estimate, optimization failed to converge.'
1317+ )
1318+ return self ._all_iters
1319+
12871320 @property
12881321 def optimized_params_pd (self ) -> pd .DataFrame :
1289- """Returns optimized params as pandas DataFrame."""
1322+ """
1323+ Returns all final estimates from the optimizer as a pandas.DataFrame
1324+ which contains all optimizer outputs, i.e., the value for `lp__`
1325+ as well as all Stan program variables.
1326+ """
12901327 if not self .runset ._check_retcodes ():
12911328 get_logger ().warning (
1292- 'invalid estimate, optimization failed to converge'
1329+ 'Invalid estimate, optimization failed to converge. '
12931330 )
1294- return pd .DataFrame (self ._mle , columns = self .column_names )
1331+ return pd .DataFrame ([self ._mle ], columns = self .column_names )
1332+
1333+ @property
1334+ def optimized_iterations_pd (self ) -> pd .DataFrame :
1335+ """
1336+ Returns all saved iterations from the optimizer and final estimate
1337+ as a pandas.DataFrame which contains all optimizer outputs, i.e.,
1338+ the value for `lp__` as well as all Stan program variables.
1339+
1340+ """
1341+ if not self .save_iterations :
1342+ get_logger ().warning (
1343+ 'Intermediate iterations not saved because optimizer argument '
1344+ '"save_iterations=True" not specified. You must rerun '
1345+ 'the optimize method accordingly.'
1346+ )
1347+ return None
1348+ if not self .converged :
1349+ get_logger ().warning (
1350+ 'Invalid estimate, optimization failed to converge.'
1351+ )
1352+ return pd .DataFrame (self ._all_iters , columns = self .column_names )
12951353
12961354 @property
12971355 def optimized_params_dict (self ) -> Dict [str , float ]:
1298- """Returns optimized params as Dict."""
1356+ """
1357+ Returns all estimates from the optimizer, including `lp__` as a
1358+ Python Dict. Only returns estimate from final iteration.
1359+ """
12991360 if not self .runset ._check_retcodes ():
13001361 get_logger ().warning (
1301- 'invalid estimate, optimization failed to converge'
1362+ 'Invalid estimate, optimization failed to converge. '
13021363 )
1303- # TODO: return final estimate only
13041364 return OrderedDict (zip (self .column_names , self ._mle ))
13051365
13061366 def stan_variable (
1307- self , var : Optional [str ] = None ,
1308- check_convergence :bool = True , * , name : Optional [str ] = None
1367+ self ,
1368+ var : Optional [str ] = None ,
1369+ * ,
1370+ warn : bool = True ,
1371+ name : Optional [str ] = None ,
13091372 ) -> np .ndarray :
13101373 """
13111374 Return a numpy.ndarray which contains the estimates for the
@@ -1314,11 +1377,6 @@ def stan_variable(
13141377
13151378 :param var: variable name
13161379
1317- :param check_convergence: Checks for failure to converge and
1318- prints warning.failed to converge. ``False`` will supress
1319- check and warning, default is ``True``.
1320-
1321-
13221380 See Also
13231381 --------
13241382 CmdStanMLE.stan_variables
@@ -1339,43 +1397,38 @@ def stan_variable(
13391397 raise ValueError ('no variable name specified.' )
13401398 if var not in self ._metadata .stan_vars_dims :
13411399 raise ValueError ('unknown variable name: {}' .format (var ))
1342- if check_convergence and not self .runset ._check_retcodes ():
1400+ if warn and not self .runset ._check_retcodes ():
13431401 get_logger ().warning (
1344- 'invalid estimate, optimization failed to converge'
1402+ 'Invalid estimate, optimization failed to converge. '
13451403 )
13461404
13471405 col_idxs = list (self ._metadata .stan_vars_cols [var ])
1348- # TODO: return final estimate only
13491406 vals = list (self ._mle )
13501407 xs = [vals [x ] for x in col_idxs ]
13511408 shape : Tuple [int , ...] = ()
13521409 if len (col_idxs ) > 0 :
13531410 shape = self ._metadata .stan_vars_dims [var ]
13541411 return np .array (xs ).reshape (shape )
13551412
1356- def stan_variables (self , check_convergence : bool = True ) -> Dict [str , np .ndarray ]:
1413+ def stan_variables (self ) -> Dict [str , np .ndarray ]:
13571414 """
13581415 Return a dictionary mapping Stan program variables names
13591416 to the corresponding numpy.ndarray containing the inferred values.
13601417
1361- :param check_convergence: Checks for failure to converge and
1362- prints warning.failed to converge. ``False`` will supress
1363- check and warning, default is ``True``.
1364-
13651418 See Also
13661419 --------
13671420 CmdStanMLE.stan_variable
13681421 CmdStanMCMC.stan_variables
13691422 CmdStanVB.stan_variables
13701423 CmdStanGQ.stan_variables
13711424 """
1372- if check_convergence and not self .runset ._check_retcodes ():
1425+ if not self .runset ._check_retcodes ():
13731426 get_logger ().warning (
1374- 'invalid estimate, optimization failed to converge'
1427+ 'Invalid estimate, optimization failed to converge. '
13751428 )
13761429 result = {}
13771430 for name in self ._metadata .stan_vars_dims .keys ():
1378- result [name ] = self .stan_variable (name , False ) # don't warn twice
1431+ result [name ] = self .stan_variable (name , warn = False )
13791432 return result
13801433
13811434 def save_csvfiles (self , dir : Optional [str ] = None ) -> None :
@@ -2259,6 +2312,7 @@ def from_csv(
22592312 )
22602313 optimize_args = OptimizeArgs (
22612314 algorithm = config_dict ['algorithm' ],
2315+ save_iterations = config_dict ['save_iterations' ],
22622316 )
22632317 cmdstan_args = CmdStanArgs (
22642318 model_name = config_dict ['model' ],
0 commit comments