Skip to content

Commit aeb1399

Browse files
committed
Allow use of stan_variables through .
1 parent f7ffa6f commit aeb1399

8 files changed

Lines changed: 148 additions & 0 deletions

File tree

cmdstanpy/stanfit/mcmc.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,14 @@ def __repr__(self) -> str:
117117
# TODO - hamiltonian, profiling files
118118
return repr
119119

120+
def __getattr__(self, attr: str) -> np.ndarray:
121+
"""Synonymous with ``fit.stan_variable(attr)"""
122+
try:
123+
return self.stan_variable(attr)
124+
except ValueError as e:
125+
# pylint: disable=raise-missing-from
126+
raise AttributeError(*e.args)
127+
120128
@property
121129
def chains(self) -> int:
122130
"""Number of chains."""
@@ -647,6 +655,9 @@ def stan_variable(
647655
and the sample consists of 4 chains with 1000 post-warmup draws,
648656
this function will return a numpy.ndarray with shape (4000,3,3).
649657
658+
This functionaltiy is also available via a shortcut using ``.`` -
659+
writing ``fit.a`` is a synonym for ``fit.stan_variable("a")``
660+
650661
:param var: variable name
651662
652663
:param inc_warmup: When ``True`` and the warmup draws are present in
@@ -769,6 +780,14 @@ def __repr__(self) -> str:
769780
)
770781
return repr
771782

783+
def __getattr__(self, attr: str) -> np.ndarray:
784+
"""Synonymous with ``fit.stan_variable(attr)"""
785+
try:
786+
return self.stan_variable(attr)
787+
except ValueError as e:
788+
# pylint: disable=raise-missing-from
789+
raise AttributeError(*e.args)
790+
772791
def _validate_csv_files(self) -> Dict[str, Any]:
773792
"""
774793
Checks that Stan CSV output files for all chains are consistent
@@ -1160,6 +1179,9 @@ def stan_variable(
11601179
and the sample consists of 4 chains with 1000 post-warmup draws,
11611180
this function will return a numpy.ndarray with shape (4000,3,3).
11621181
1182+
This functionaltiy is also available via a shortcut using ``.`` -
1183+
writing ``fit.a`` is a synonym for ``fit.stan_variable("a")``
1184+
11631185
:param var: variable name
11641186
11651187
:param inc_warmup: When ``True`` and the warmup draws are present in

cmdstanpy/stanfit/mle.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,14 @@ def __repr__(self) -> str:
5050
repr = '{} optimization failed to converge.'.format(repr)
5151
return repr
5252

53+
def __getattr__(self, attr: str) -> Union[np.ndarray, float]:
54+
"""Synonymous with ``fit.stan_variable(attr)"""
55+
try:
56+
return self.stan_variable(attr)
57+
except ValueError as e:
58+
# pylint: disable=raise-missing-from
59+
raise AttributeError(*e.args)
60+
5361
def _set_mle_attrs(self, sample_csv_0: str) -> None:
5462
meta = scan_optimize_csv(sample_csv_0, self._save_iterations)
5563
self._metadata = InferenceMetadata(meta)
@@ -165,6 +173,9 @@ def stan_variable(
165173
for the named Stan program variable where the dimensions of the
166174
numpy.ndarray match the shape of the Stan program variable.
167175
176+
This functionaltiy is also available via a shortcut using ``.`` -
177+
writing ``fit.a`` is a synonym for ``fit.stan_variable("a")``
178+
168179
:param var: variable name
169180
170181
:param inc_iterations: When ``True`` and the intermediate estimates

cmdstanpy/stanfit/vb.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,14 @@ def __repr__(self) -> str:
4141
# TODO - diagnostic, profiling files
4242
return repr
4343

44+
def __getattr__(self, attr: str) -> Union[np.ndarray, float]:
45+
"""Synonymous with ``fit.stan_variable(attr)"""
46+
try:
47+
return self.stan_variable(attr)
48+
except ValueError as e:
49+
# pylint: disable=raise-missing-from
50+
raise AttributeError(*e.args)
51+
4452
def _set_variational_attrs(self, sample_csv_0: str) -> None:
4553
meta = scan_variational_csv(sample_csv_0)
4654
self._metadata = InferenceMetadata(meta)
@@ -109,6 +117,9 @@ def stan_variable(self, var: str) -> Union[np.ndarray, float]:
109117
for the named Stan program variable where the dimensions of the
110118
numpy.ndarray match the shape of the Stan program variable.
111119
120+
This functionaltiy is also available via a shortcut using ``.`` -
121+
writing ``fit.a`` is a synonym for ``fit.stan_variable("a")``
122+
112123
:param var: variable name
113124
114125
See Also

test/data/named_output.stan

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
data {
2+
int<lower=0> N;
3+
int<lower=0,upper=1> y[N];
4+
}
5+
parameters {
6+
real<lower=0,upper=1> theta;
7+
}
8+
model {
9+
theta ~ beta(1,1); // uniform prior on interval 0,1
10+
y ~ bernoulli(theta);
11+
}
12+
13+
generated quantities {
14+
// these should be accessible via .
15+
real a = 4.5;
16+
array[3] real b = {1, 2.5, 4.5};
17+
18+
// these should not override built in properties/funs
19+
real thin = 3.5;
20+
int draws = 0;
21+
int optimized_params_np = 0;
22+
int variational_params_np = 0;
23+
}

test/test_generate_quantities.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -444,6 +444,26 @@ def test_complex_output(self):
444444
fit.draws_xr().z.isel(chain=0, draw=1).data[()], 3 + 4j
445445
)
446446

447+
def test_attrs(self):
448+
stan_bern = os.path.join(DATAFILES_PATH, 'bernoulli.stan')
449+
model_bern = CmdStanModel(stan_file=stan_bern)
450+
jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
451+
fit_sampling = model_bern.sample(chains=1, iter_sampling=10, data=jdata)
452+
453+
stan = os.path.join(DATAFILES_PATH, 'named_output.stan')
454+
model = CmdStanModel(stan_file=stan)
455+
fit = model.generate_quantities(data=jdata, mcmc_sample=fit_sampling)
456+
457+
self.assertEqual(fit.a[0], 4.5)
458+
self.assertEqual(fit.b.shape, (10, 3))
459+
self.assertEqual(fit.theta.shape, (10,))
460+
461+
fit.draws()
462+
self.assertEqual(fit.stan_variable('draws')[0], 0)
463+
464+
with self.assertRaisesRegex(AttributeError, 'Unknown variable name:'):
465+
dummy = fit.c
466+
447467

448468
if __name__ == '__main__':
449469
unittest.main()

test/test_optimize.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -609,6 +609,24 @@ def test_complex_output(self):
609609
# make sure the name 'imag' isn't magic
610610
self.assertEqual(fit.stan_variable('imag').shape, (2,))
611611

612+
def test_attrs(self):
613+
stan = os.path.join(DATAFILES_PATH, 'named_output.stan')
614+
model = CmdStanModel(stan_file=stan)
615+
jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
616+
fit = model.optimize(data=jdata)
617+
618+
self.assertEqual(fit.a, 4.5)
619+
self.assertEqual(fit.b.shape, (3,))
620+
self.assertIsInstance(fit.theta, float)
621+
622+
self.assertEqual(fit.stan_variable('thin'), 3.5)
623+
624+
self.assertIsInstance(fit.optimized_params_np, np.ndarray)
625+
self.assertEqual(fit.stan_variable('optimized_params_np'), 0)
626+
627+
with self.assertRaisesRegex(AttributeError, 'Unknown variable name:'):
628+
dummy = fit.c
629+
612630

613631
if __name__ == '__main__':
614632
unittest.main()

test/test_sample.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1776,6 +1776,25 @@ def test_complex_output(self):
17761776
fit.draws_xr().z.isel(chain=0, draw=1).data[()], 3 + 4j
17771777
)
17781778

1779+
def test_attrs(self):
1780+
stan = os.path.join(DATAFILES_PATH, 'named_output.stan')
1781+
model = CmdStanModel(stan_file=stan)
1782+
jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
1783+
fit = model.sample(chains=1, iter_sampling=10, data=jdata)
1784+
1785+
self.assertEqual(fit.a[0], 4.5)
1786+
self.assertEqual(fit.b.shape, (10, 3))
1787+
self.assertEqual(fit.theta.shape, (10,))
1788+
1789+
self.assertEqual(fit.thin, 1)
1790+
self.assertEqual(fit.stan_variable('thin')[0], 3.5)
1791+
1792+
fit.draws()
1793+
self.assertEqual(fit.stan_variable('draws')[0], 0)
1794+
1795+
with self.assertRaisesRegex(AttributeError, 'Unknown variable name:'):
1796+
dummy = fit.c
1797+
17791798

17801799
if __name__ == '__main__':
17811800
unittest.main()

test/test_variational.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import unittest
88
from math import fabs
99

10+
import numpy as np
1011
from testfixtures import LogCapture
1112

1213
from cmdstanpy.cmdstan_args import CmdStanArgs, VariationalArgs
@@ -264,6 +265,29 @@ def test_complex_output(self):
264265
# make sure the name 'imag' isn't magic
265266
self.assertEqual(fit.stan_variable('imag').shape, (2,))
266267

268+
def test_attrs(self):
269+
stan = os.path.join(DATAFILES_PATH, 'named_output.stan')
270+
model = CmdStanModel(stan_file=stan)
271+
jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
272+
fit = model.variational(
273+
data=jdata,
274+
require_converged=False,
275+
seed=12345,
276+
algorithm='meanfield',
277+
)
278+
279+
self.assertEqual(fit.a, 4.5)
280+
self.assertEqual(fit.b.shape, (3,))
281+
self.assertIsInstance(fit.theta, float)
282+
283+
self.assertEqual(fit.stan_variable('thin'), 3.5)
284+
285+
self.assertIsInstance(fit.variational_params_np, np.ndarray)
286+
self.assertEqual(fit.stan_variable('variational_params_np'), 0)
287+
288+
with self.assertRaisesRegex(AttributeError, 'Unknown variable name:'):
289+
dummy = fit.c
290+
267291

268292
if __name__ == '__main__':
269293
unittest.main()

0 commit comments

Comments
 (0)