Skip to content

Commit 0323178

Browse files
committed
added arg 'inc_iterations' to stan_variable(s)
1 parent 1a06f9c commit 0323178

4 files changed

Lines changed: 160 additions & 28 deletions

File tree

cmdstanpy/stanfit.py

Lines changed: 57 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1241,7 +1241,7 @@ def __init__(self, runset: RunSet) -> None:
12411241
assert isinstance(
12421242
optimize_args, OptimizeArgs
12431243
) # make the typechecker happy
1244-
self.save_iterations = optimize_args.save_iterations
1244+
self._save_iterations = optimize_args.save_iterations
12451245
self._set_mle_attrs(runset.csv_files[0])
12461246

12471247
def __repr__(self) -> str:
@@ -1259,11 +1259,11 @@ def __repr__(self) -> str:
12591259
return repr
12601260

12611261
def _set_mle_attrs(self, sample_csv_0: str) -> None:
1262-
meta = scan_optimize_csv(sample_csv_0, self.save_iterations)
1262+
meta = scan_optimize_csv(sample_csv_0, self._save_iterations)
12631263
self._metadata = InferenceMetadata(meta)
12641264
self._column_names: Tuple[str, ...] = meta['column_names']
12651265
self._mle = meta['mle']
1266-
if self.save_iterations:
1266+
if self._save_iterations:
12671267
self._all_iters = meta['all_iters']
12681268

12691269
@property
@@ -1304,11 +1304,10 @@ def optimized_iterations_np(self) -> np.ndarray:
13041304
the value for `lp__` as well as all Stan program variables.
13051305
13061306
"""
1307-
if not self.save_iterations:
1307+
if not self._save_iterations:
13081308
get_logger().warning(
1309-
'Intermediate iterations not saved because optimizer argument '
1310-
'"save_iterations=True" not specified. You must rerun '
1311-
'the optimize method accordingly.'
1309+
'Intermediate iterations not saved to CSV output file. '
1310+
'Rerun the optimize method with "save_iterations=True".'
13121311
)
13131312
return None
13141313
if not self.converged:
@@ -1338,11 +1337,10 @@ def optimized_iterations_pd(self) -> pd.DataFrame:
13381337
the value for `lp__` as well as all Stan program variables.
13391338
13401339
"""
1341-
if not self.save_iterations:
1340+
if not self._save_iterations:
13421341
get_logger().warning(
1343-
'Intermediate iterations not saved because optimizer argument '
1344-
'"save_iterations=True" not specified. You must rerun '
1345-
'the optimize method accordingly.'
1342+
'Intermediate iterations not saved to CSV output file. '
1343+
'Rerun the optimize method with "save_iterations=True".'
13461344
)
13471345
return None
13481346
if not self.converged:
@@ -1367,6 +1365,7 @@ def stan_variable(
13671365
self,
13681366
var: Optional[str] = None,
13691367
*,
1368+
inc_iterations: bool = False,
13701369
warn: bool = True,
13711370
name: Optional[str] = None,
13721371
) -> np.ndarray:
@@ -1377,6 +1376,11 @@ def stan_variable(
13771376
13781377
:param var: variable name
13791378
1379+
:param inc_iterations: When ``True`` and the optimizer was run with
1380+
``save_iterations=True``, the returned array also contains the
1381+
intermediate estimates, with saved iterations along the first axis.
1382+
Default value is ``False``.
1383+
13801384
See Also
13811385
--------
13821386
CmdStanMLE.stan_variables
@@ -1397,24 +1401,56 @@ def stan_variable(
13971401
raise ValueError('no variable name specified.')
13981402
if var not in self._metadata.stan_vars_dims:
13991403
raise ValueError('unknown variable name: {}'.format(var))
1404+
if warn and inc_iterations and not self._save_iterations:
1405+
get_logger().warning(
1406+
'Intermediate iterations not saved to CSV output file. '
1407+
'Rerun the optimize method with "save_iterations=True".'
1408+
)
14001409
if warn and not self.runset._check_retcodes():
14011410
get_logger().warning(
14021411
'Invalid estimate, optimization failed to converge.'
14031412
)
14041413

1405-
col_idxs = list(self._metadata.stan_vars_cols[var])
1406-
vals = list(self._mle)
1407-
xs = [vals[x] for x in col_idxs]
1408-
shape: Tuple[int, ...] = ()
1414+
col_idxs = self._metadata.stan_vars_cols[var]
1415+
if inc_iterations and self._save_iterations:
1416+
num_rows = self._all_iters.shape[0]
1417+
else:
1418+
num_rows = 1
1419+
1420+
# extract and reshape, container var
14091421
if len(col_idxs) > 0:
1410-
shape = self._metadata.stan_vars_dims[var]
1411-
return np.array(xs).reshape(shape)
1422+
dims = (num_rows,) + self._metadata.stan_vars_dims[var]
1423+
# pylint: disable=redundant-keyword-arg
1424+
if num_rows > 1:
1425+
return self._all_iters[:, col_idxs].reshape( # type: ignore
1426+
dims, order='F'
1427+
)
1428+
else:
1429+
mle = np.expand_dims(self._mle, axis=0) # hack for col indexing
1430+
return (
1431+
mle[0, col_idxs]
1432+
.reshape(dims, order='F') # type: ignore
1433+
.squeeze(axis=0)
1434+
)
14121435

1413-
def stan_variables(self) -> Dict[str, np.ndarray]:
1436+
# extract scalar var
1437+
if num_rows > 1:
1438+
return self._all_iters[:, col_idxs]
1439+
return mle[0, col_idxs]
1440+
1441+
def stan_variables(
1442+
self, inc_iterations: bool = False
1443+
) -> Dict[str, np.ndarray]:
14141444
"""
14151445
Return a dictionary mapping Stan program variables names
14161446
to the corresponding numpy.ndarray containing the inferred values.
14171447
1448+
:param inc_iterations: When ``True`` and the optimizer was run with
1449+
``save_iterations=True``, each returned array also contains the
1450+
intermediate estimates, with saved iterations along the first axis.
1451+
Default value is ``False``.
1452+
1453+
14181454
See Also
14191455
--------
14201456
CmdStanMLE.stan_variable
@@ -1428,7 +1464,9 @@ def stan_variables(self) -> Dict[str, np.ndarray]:
14281464
)
14291465
result = {}
14301466
for name in self._metadata.stan_vars_dims.keys():
1431-
result[name] = self.stan_variable(name, warn=False)
1467+
result[name] = self.stan_variable(
1468+
name, inc_iterations=inc_iterations, warn=False
1469+
)
14321470
return result
14331471

14341472
def save_csvfiles(self, dir: Optional[str] = None) -> None:

cmdstanpy/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -609,7 +609,7 @@ def scan_optimize_csv(path: str, save_iters: bool = False) -> Dict[str, Any]:
609609
if save_iters:
610610
all_iters[i, :] = [float(x) for x in xs]
611611
if i == iters - 1:
612-
mle = np.array([float(x) for x in xs], dtype=float)
612+
mle = np.array([float(x) for x in xs])
613613
dict['mle'] = mle
614614
if save_iters:
615615
dict['all_iters'] = all_iters
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
# stan_version_major = 2
2+
# stan_version_minor = 27
3+
# stan_version_patch = 0
4+
# model = rosenbrock_model
5+
# start_datetime = 2021-09-02 23:50:07 UTC
6+
# method = optimize
7+
# optimize
8+
# algorithm = lbfgs (Default)
9+
# lbfgs
10+
# init_alpha = 0.001 (Default)
11+
# tol_obj = 9.9999999999999998e-13 (Default)
12+
# tol_rel_obj = 10000 (Default)
13+
# tol_grad = 1e-08 (Default)
14+
# tol_rel_grad = 10000000 (Default)
15+
# tol_param = 1e-08 (Default)
16+
# history_size = 5 (Default)
17+
# iter = 2000 (Default)
18+
# save_iterations = 1
19+
# id = 0 (Default)
20+
# data
21+
# file = (Default)
22+
# init = 2 (Default)
23+
# random
24+
# seed = 12345
25+
# output
26+
# file = output.csv (Default)
27+
# diagnostic_file = (Default)
28+
# refresh = 100 (Default)
29+
# sig_figs = -1 (Default)
30+
# profile_file = profile.csv (Default)
31+
# stanc_version = stanc3 v2.27.0
32+
# stancflags =
33+
lp__,x,y
34+
-1732.6,1.98441,-0.2234
35+
-134.258,-1.32063,0.608857
36+
-17.9857,-1.03615,0.701595
37+
-5.53465,-0.93882,0.74813
38+
-3.59021,-0.888547,0.774151
39+
-3.53145,-0.878129,0.777499
40+
-3.52641,-0.875729,0.775875
41+
-3.5005,-0.86295,0.761978
42+
-3.43709,-0.832101,0.720763
43+
-2.19215,-0.451751,0.174997
44+
-2.1677,-0.456726,0.187233
45+
-1.7755,-0.228015,0.000272543
46+
-1.50767,-0.224409,0.041146
47+
-1.20549,-0.0725003,-0.0182459
48+
-1.0779,0.0346959,-0.0370183
49+
-0.86807,0.0898603,-0.0118539
50+
-0.698515,0.231494,0.0207393
51+
-0.481488,0.306423,0.0959909
52+
-0.385645,0.410491,0.148978
53+
-0.354102,0.436205,0.171239
54+
-0.214411,0.536994,0.288962
55+
-0.176819,0.612981,0.359304
56+
-0.0913081,0.710088,0.495705
57+
-0.0592561,0.757026,0.571606
58+
-0.0498401,0.824561,0.666095
59+
-0.0443989,0.826606,0.671305
60+
-0.0145855,0.887119,0.782687
61+
-0.00731916,0.920044,0.843437
62+
-0.00198537,0.955481,0.912759
63+
-0.00124069,0.97227,0.943138
64+
-0.000920772,0.975825,0.9504
65+
-6.69892e-05,0.992364,0.984492
66+
-5.87688e-06,0.998722,0.997239
67+
-8.49546e-07,0.99969,0.999466
68+
-3.31295e-08,0.999975,0.999969
69+
-5.1809e-11,1,1

test/test_optimize.py

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -113,9 +113,8 @@ def test_rosenbrock(self):
113113
(
114114
'cmdstanpy',
115115
'WARNING',
116-
'Intermediate iterations not saved because optimizer argument '
117-
'"save_iterations=True" not specified. You must rerun '
118-
'the optimize method accordingly.',
116+
'Intermediate iterations not saved to CSV output file. '
117+
'Rerun the optimize method with "save_iterations=True".',
119118
)
120119
)
121120
with LogCapture() as log:
@@ -124,9 +123,8 @@ def test_rosenbrock(self):
124123
(
125124
'cmdstanpy',
126125
'WARNING',
127-
'Intermediate iterations not saved because optimizer argument '
128-
'"save_iterations=True" not specified. You must rerun '
129-
'the optimize method accordingly.',
126+
'Intermediate iterations not saved to CSV output file. '
127+
'Rerun the optimize method with "save_iterations=True".',
130128
)
131129
)
132130

@@ -142,6 +140,9 @@ def test_rosenbrock(self):
142140
self.assertAlmostEqual(mle.optimized_params_pd['x'][0], 1, places=3)
143141
self.assertAlmostEqual(mle.optimized_params_dict['x'], 1, places=3)
144142

143+
self.assertEqual(
144+
mle.stan_variable('x', inc_iterations=True).shape, (36,)
145+
)
145146
self.assertEqual(mle.optimized_iterations_np.shape, (36, 3))
146147
self.assertNotEqual(
147148
mle.optimized_iterations_np[0, 1],
@@ -166,7 +167,7 @@ def test_eight_schools(self):
166167
self.assertIn('method=optimize', mle.__repr__())
167168
self.assertFalse(mle.converged)
168169
with LogCapture() as log:
169-
self.assertEqual(mle.optimized_params_pd.shape, (1,11))
170+
self.assertEqual(mle.optimized_params_pd.shape, (1, 11))
170171
log.check_present(
171172
(
172173
'cmdstanpy',
@@ -221,7 +222,6 @@ def test_variable_bern(self):
221222
)
222223

223224
def test_variables_3d(self):
224-
# construct fit using existing sampler output
225225
stan = os.path.join(DATAFILES_PATH, 'multidim_vars.stan')
226226
jdata = os.path.join(DATAFILES_PATH, 'logistic.data.R')
227227
multidim_model = CmdStanModel(stan_file=stan)
@@ -258,6 +258,31 @@ def test_variables_3d(self):
258258
self.assertTrue('frac_60' in vars)
259259
self.assertEqual(vars['frac_60'].shape, ())
260260

261+
multidim_mle_iters = multidim_model.optimize(
262+
data=jdata,
263+
seed=1239812093,
264+
algorithm='LBFGS',
265+
init_alpha=0.001,
266+
iter=100,
267+
tol_obj=1e-12,
268+
tol_rel_obj=1e4,
269+
tol_grad=1e-8,
270+
tol_rel_grad=1e7,
271+
tol_param=1e-8,
272+
history_size=5,
273+
save_iterations=True,
274+
)
275+
vars_iters = multidim_mle_iters.stan_variables(inc_iterations=True)
276+
self.assertEqual(
277+
len(vars_iters), len(multidim_mle_iters.metadata.stan_vars_dims)
278+
)
279+
self.assertTrue('y_rep' in vars_iters)
280+
self.assertEqual(vars_iters['y_rep'].shape, (8, 5, 4, 3))
281+
self.assertTrue('beta' in vars_iters)
282+
self.assertEqual(vars_iters['beta'].shape, (8, 2))
283+
self.assertTrue('frac_60' in vars_iters)
284+
self.assertEqual(vars_iters['frac_60'].shape, (8,))
285+
261286

262287
class OptimizeTest(unittest.TestCase):
263288
def test_optimize_good(self):

0 commit comments

Comments
 (0)