@@ -1241,7 +1241,7 @@ def __init__(self, runset: RunSet) -> None:
12411241 assert isinstance (
12421242 optimize_args , OptimizeArgs
12431243 ) # make the typechecker happy
1244- self .save_iterations = optimize_args .save_iterations
1244+ self ._save_iterations = optimize_args .save_iterations
12451245 self ._set_mle_attrs (runset .csv_files [0 ])
12461246
12471247 def __repr__ (self ) -> str :
@@ -1259,11 +1259,11 @@ def __repr__(self) -> str:
12591259 return repr
12601260
12611261 def _set_mle_attrs (self , sample_csv_0 : str ) -> None :
1262- meta = scan_optimize_csv (sample_csv_0 , self .save_iterations )
1262+ meta = scan_optimize_csv (sample_csv_0 , self ._save_iterations )
12631263 self ._metadata = InferenceMetadata (meta )
12641264 self ._column_names : Tuple [str , ...] = meta ['column_names' ]
12651265 self ._mle = meta ['mle' ]
1266- if self .save_iterations :
1266+ if self ._save_iterations :
12671267 self ._all_iters = meta ['all_iters' ]
12681268
12691269 @property
@@ -1304,11 +1304,10 @@ def optimized_iterations_np(self) -> np.ndarray:
13041304 the value for `lp__` as well as all Stan program variables.
13051305
13061306 """
1307- if not self .save_iterations :
1307+ if not self ._save_iterations :
13081308 get_logger ().warning (
1309- 'Intermediate iterations not saved because optimizer argument '
1310- '"save_iterations=True" not specified. You must rerun '
1311- 'the optimize method accordingly.'
1309+ 'Intermediate iterations not saved to CSV output file. '
1310+ 'Rerun the optimize method with "save_iterations=True".'
13121311 )
13131312 return None
13141313 if not self .converged :
@@ -1338,11 +1337,10 @@ def optimized_iterations_pd(self) -> pd.DataFrame:
13381337 the value for `lp__` as well as all Stan program variables.
13391338
13401339 """
1341- if not self .save_iterations :
1340+ if not self ._save_iterations :
13421341 get_logger ().warning (
1343- 'Intermediate iterations not saved because optimizer argument '
1344- '"save_iterations=True" not specified. You must rerun '
1345- 'the optimize method accordingly.'
1342+ 'Intermediate iterations not saved to CSV output file. '
1343+ 'Rerun the optimize method with "save_iterations=True".'
13461344 )
13471345 return None
13481346 if not self .converged :
@@ -1367,6 +1365,7 @@ def stan_variable(
13671365 self ,
13681366 var : Optional [str ] = None ,
13691367 * ,
1368+ inc_iterations : bool = False ,
13701369 warn : bool = True ,
13711370 name : Optional [str ] = None ,
13721371 ) -> np .ndarray :
@@ -1377,6 +1376,11 @@ def stan_variable(
13771376
13781377 :param var: variable name
13791378
1379+ :param inc_iterations: When ``True`` and the intermediate estimates
1380+ are included in the output, i.e., the optimizer was run with
1381+ ``save_iterations=True``, then intermediate estimates are included.
1382+ Default value is ``False``.
1383+
13801384 See Also
13811385 --------
13821386 CmdStanMLE.stan_variables
@@ -1397,24 +1401,56 @@ def stan_variable(
13971401 raise ValueError ('no variable name specified.' )
13981402 if var not in self ._metadata .stan_vars_dims :
13991403 raise ValueError ('unknown variable name: {}' .format (var ))
1404+ if warn and inc_iterations and not self ._save_iterations :
1405+ get_logger ().warning (
1406+ 'Intermediate iterations not saved to CSV output file. '
1407+ 'Rerun the optimize method with "save_iterations=True".'
1408+ )
14001409 if warn and not self .runset ._check_retcodes ():
14011410 get_logger ().warning (
14021411 'Invalid estimate, optimization failed to converge.'
14031412 )
14041413
1405- col_idxs = list (self ._metadata .stan_vars_cols [var ])
1406- vals = list (self ._mle )
1407- xs = [vals [x ] for x in col_idxs ]
1408- shape : Tuple [int , ...] = ()
1414+ col_idxs = self ._metadata .stan_vars_cols [var ]
1415+ if inc_iterations and self ._save_iterations :
1416+ num_rows = self ._all_iters .shape [0 ]
1417+ else :
1418+ num_rows = 1
1419+
1420+ # extract and reshape, container var
14091421 if len (col_idxs ) > 0 :
1410- shape = self ._metadata .stan_vars_dims [var ]
1411- return np .array (xs ).reshape (shape )
1422+ dims = (num_rows ,) + self ._metadata .stan_vars_dims [var ]
1423+ # pylint: disable=redundant-keyword-arg
1424+ if num_rows > 1 :
1425+ return self ._all_iters [:, col_idxs ].reshape ( # type: ignore
1426+ dims , order = 'F'
1427+ )
1428+ else :
1429+ mle = np .expand_dims (self ._mle , axis = 0 ) # hack for col indexing
1430+ return (
1431+ mle [0 , col_idxs ]
1432+ .reshape (dims , order = 'F' ) # type: ignore
1433+ .squeeze (axis = 0 )
1434+ )
14121435
1413- def stan_variables (self ) -> Dict [str , np .ndarray ]:
1436+ # extract scalar var
1437+ if num_rows > 1 :
1438+ return self ._all_iters [:, col_idxs ]
1439+ return mle [0 , col_idxs ]
1440+
1441+ def stan_variables (
1442+ self , inc_iterations : bool = False
1443+ ) -> Dict [str , np .ndarray ]:
14141444 """
14151445 Return a dictionary mapping Stan program variables names
14161446 to the corresponding numpy.ndarray containing the inferred values.
14171447
1448+ :param inc_iterations: When ``True`` and the intermediate estimates
1449+ are included in the output, i.e., the optimizer was run with
1450+ ``save_iterations=True``, then intermediate estimates are included.
1451+ Default value is ``False``.
1452+
1453+
14181454 See Also
14191455 --------
14201456 CmdStanMLE.stan_variable
@@ -1428,7 +1464,9 @@ def stan_variables(self) -> Dict[str, np.ndarray]:
14281464 )
14291465 result = {}
14301466 for name in self ._metadata .stan_vars_dims .keys ():
1431- result [name ] = self .stan_variable (name , warn = False )
1467+ result [name ] = self .stan_variable (
1468+ name , inc_iterations = inc_iterations , warn = False
1469+ )
14321470 return result
14331471
14341472 def save_csvfiles (self , dir : Optional [str ] = None ) -> None :
0 commit comments