Skip to content

Commit 6d4c3f0

Browse files
committed
Merge branch 'develop' of https://github.com/stan-dev/cmdstanpy into develop
2 parents 29a9831 + 84f9745 commit 6d4c3f0

7 files changed

Lines changed: 130 additions & 36 deletions

File tree

cmdstanpy/stanfit.py

Lines changed: 43 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1074,7 +1074,6 @@ def draws_xr(
10741074
0,
10751075
self.draws(inc_warmup=inc_warmup),
10761076
)
1077-
10781077
return xr.Dataset(data, coords=coordinates, attrs=attrs).transpose(
10791078
'chain', 'draw', ...
10801079
)
@@ -1373,7 +1372,7 @@ def stan_variable(
13731372
inc_iterations: bool = False,
13741373
warn: bool = True,
13751374
name: Optional[str] = None,
1376-
) -> np.ndarray:
1375+
) -> Union[np.ndarray, float]:
13771376
"""
13781377
Return a numpy.ndarray which contains the estimates for the
13791378
for the named Stan program variable where the dimensions of the
@@ -1416,38 +1415,34 @@ def stan_variable(
14161415
'Invalid estimate, optimization failed to converge.'
14171416
)
14181417

1419-
col_idxs = self._metadata.stan_vars_cols[var]
1418+
col_idxs = list(self._metadata.stan_vars_cols[var])
14201419
if inc_iterations and self._save_iterations:
14211420
num_rows = self._all_iters.shape[0]
14221421
else:
14231422
num_rows = 1
1424-
1425-
if len(col_idxs) > 0: # container var
1423+
if len(col_idxs) > 1: # container var
14261424
dims = (num_rows,) + self._metadata.stan_vars_dims[var]
14271425
# pylint: disable=redundant-keyword-arg
14281426
if num_rows > 1:
14291427
result = self._all_iters[:, col_idxs].reshape( # type: ignore
14301428
dims, order='F'
14311429
)
14321430
else:
1433-
mle = np.expand_dims(self._mle, axis=0) # hack for col indexing
1434-
result = (
1435-
mle[0, col_idxs]
1436-
.reshape(dims, order='F') # type: ignore
1437-
.squeeze(axis=0)
1438-
)
1431+
result = self._mle[col_idxs].reshape(dims[1:], order="F")
14391432
else: # scalar var
1433+
col_idx = col_idxs[0]
14401434
if num_rows > 1:
1441-
result = self._all_iters[:, col_idxs]
1435+
result = self._all_iters[:, col_idx]
14421436
else:
1443-
result = np.atleast_1d(mle[0, col_idxs])
1444-
1445-
assert isinstance(result, np.ndarray) # make the typechecker happy
1437+
result = float(self._mle[col_idx])
1438+
assert isinstance(
1439+
result, (np.ndarray, float)
1440+
) # make the typechecker happy
14461441
return result
14471442

14481443
def stan_variables(
14491444
self, inc_iterations: bool = False
1450-
) -> Dict[str, np.ndarray]:
1445+
) -> Dict[str, Union[np.ndarray, float]]:
14511446
"""
14521447
Return a dictionary mapping Stan program variables names
14531448
to the corresponding numpy.ndarray containing the inferred values.
@@ -1988,16 +1983,26 @@ def stan_variable(
19881983
return self.mcmc_sample.stan_variable(var, inc_warmup=inc_warmup)
19891984
else: # is gq variable
19901985
self._assemble_generated_quantities()
1991-
col_idxs = self._metadata.stan_vars_cols[var]
1986+
draw1 = 0
19921987
if (
19931988
not inc_warmup
19941989
and self.mcmc_sample.metadata.cmdstan_config['save_warmup']
19951990
):
1996-
draw1 = self.mcmc_sample.num_draws_warmup * self.chains
1997-
return flatten_chains(self._draws)[ # type: ignore
1998-
draw1:, col_idxs
1999-
]
2000-
return flatten_chains(self._draws)[:, col_idxs] # type: ignore
1991+
draw1 = self.mcmc_sample.num_draws_warmup
1992+
num_draws = self.mcmc_sample.num_draws_sampling
1993+
if (
1994+
inc_warmup
1995+
and self.mcmc_sample.metadata.cmdstan_config['save_warmup']
1996+
):
1997+
num_draws += self.mcmc_sample.num_draws_warmup
1998+
dims = [num_draws * self.chains]
1999+
col_idxs = self._metadata.stan_vars_cols[var]
2000+
if len(col_idxs) > 0:
2001+
dims.extend(self._metadata.stan_vars_dims[var])
2002+
# pylint: disable=redundant-keyword-arg
2003+
return self._draws[draw1:, :, col_idxs].reshape( # type: ignore
2004+
dims, order='F'
2005+
)
20012006

20022007
def stan_variables(self, inc_warmup: bool = False) -> Dict[str, np.ndarray]:
20032008
"""
@@ -2143,7 +2148,7 @@ def metadata(self) -> InferenceMetadata:
21432148

21442149
def stan_variable(
21452150
self, var: Optional[str] = None, *, name: Optional[str] = None
2146-
) -> np.ndarray:
2151+
) -> Union[np.ndarray, float]:
21472152
"""
21482153
Return a numpy.ndarray which contains the estimates for the
21492154
for the named Stan program variable where the dimensions of the
@@ -2172,14 +2177,18 @@ def stan_variable(
21722177
if var not in self._metadata.stan_vars_dims:
21732178
raise ValueError('Unknown variable name: {}'.format(var))
21742179
col_idxs = list(self._metadata.stan_vars_cols[var])
2175-
vals = list(self._variational_mean)
2176-
xs = [vals[x] for x in col_idxs]
21772180
shape: Tuple[int, ...] = ()
2178-
if len(col_idxs) > 0:
2181+
if len(col_idxs) > 1:
21792182
shape = self._metadata.stan_vars_dims[var]
2180-
return np.array(xs).reshape(shape)
2183+
result = np.asarray(self._variational_mean)[col_idxs].reshape(
2184+
shape, order="F"
2185+
)
2186+
else:
2187+
result = float(self._variational_mean[col_idxs[0]])
2188+
assert isinstance(result, (np.ndarray, float))
2189+
return result
21812190

2182-
def stan_variables(self) -> Dict[str, np.ndarray]:
2191+
def stan_variables(self) -> Dict[str, Union[np.ndarray, float]]:
21832192
"""
21842193
Return a dictionary mapping Stan program variables names
21852194
to the corresponding numpy.ndarray containing the inferred values.
@@ -2424,7 +2433,12 @@ def build_xarray_data(
24242433
var_dims: Tuple[str, ...] = ('draw', 'chain')
24252434
if dims:
24262435
var_dims += tuple(f"{var_name}_dim_{i}" for i in range(len(dims)))
2427-
data[var_name] = (var_dims, drawset[start_row:, :, col_idxs])
2436+
data[var_name] = (
2437+
var_dims,
2438+
drawset[start_row:, :, col_idxs].reshape(
2439+
*drawset.shape[:2], *dims, order="F"
2440+
),
2441+
)
24282442
else:
24292443
data[var_name] = (
24302444
var_dims,

docsrc/conf.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -441,5 +441,7 @@ def emit(self, record):
441441
# }
442442

443443
# Makes the copying behavior on code examples cleaner by removing things like In [10]: from the text to be copied
444-
copybutton_prompt_text = r">>> |\.\.\. |\$ |In \[\d*\]: | {2,5}\.\.\.: | {5,8}: "
444+
copybutton_prompt_text = (
445+
r">>> |\.\.\. |\$ |In \[\d*\]: | {2,5}\.\.\.: | {5,8}: "
446+
)
445447
copybutton_prompt_is_regexp = True

test/data/matrix_var.stan

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
transformed data {
2+
int y[10] = {0,1,0,0,0,0,0,0,0,1};
3+
}
4+
parameters {
5+
real<lower=0,upper=1> theta;
6+
}
7+
model {
8+
theta ~ beta(1,1); // uniform prior on interval 0,1
9+
y ~ bernoulli(theta);
10+
}
11+
generated quantities {
12+
# x is a 4 x 3 matrix where i,j entry == rownum
13+
matrix[4, 3] z;
14+
for (row_num in 1:4) {
15+
for (col_num in 1:3) {
16+
z[row_num, col_num] = row_num;
17+
}
18+
}
19+
}

test/test_generate_quantities.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -426,6 +426,28 @@ def test_no_xarray(self):
426426
with self.assertRaises(RuntimeError):
427427
bern_gqs.draws_xr()
428428

429+
def test_single_row_csv(self):
430+
stan = os.path.join(DATAFILES_PATH, 'bernoulli.stan')
431+
bern_model = CmdStanModel(stan_file=stan)
432+
jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
433+
bern_fit = bern_model.sample(
434+
data=jdata,
435+
chains=1,
436+
seed=12345,
437+
iter_sampling=1,
438+
)
439+
stan = os.path.join(DATAFILES_PATH, 'matrix_var.stan')
440+
model = CmdStanModel(stan_file=stan)
441+
gqs = model.generate_quantities(mcmc_sample=bern_fit)
442+
z_as_ndarray = gqs.stan_variable(var="z")
443+
self.assertEqual(z_as_ndarray.shape, (1, 4, 3)) # flattens chains
444+
z_as_xr = gqs.draws_xr(vars="z")
445+
self.assertEqual(z_as_xr.z.data.shape, (1, 1, 4, 3)) # keeps chains
446+
for i in range(4):
447+
for j in range(3):
448+
self.assertEqual(int(z_as_ndarray[0, i, j]), i + 1)
449+
self.assertEqual(int(z_as_xr.z.data[0, 0, i, j]), i + 1)
450+
429451

430452
if __name__ == '__main__':
431453
unittest.main()

test/test_optimize.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -208,13 +208,15 @@ def test_variable_bern(self):
208208
self.assertTrue('theta' in bern_mle.metadata.stan_vars_dims)
209209
self.assertEqual(bern_mle.metadata.stan_vars_dims['theta'], ())
210210
theta = bern_mle.stan_variable(var='theta')
211-
self.assertEqual(theta.shape, ())
211+
self.assertTrue(isinstance(theta, float))
212212
with self.assertRaises(ValueError):
213213
bern_mle.stan_variable(var='eta')
214214
with self.assertRaises(ValueError):
215215
bern_mle.stan_variable(var='lp__')
216216
with LogCapture() as log:
217-
self.assertEqual(bern_mle.stan_variable(name='theta').shape, ())
217+
self.assertTrue(
218+
isinstance(bern_mle.stan_variable(name='theta'), float)
219+
)
218220
log.check_present(
219221
(
220222
'cmdstanpy',
@@ -250,15 +252,15 @@ def test_variables_3d(self):
250252
var_beta = multidim_mle.stan_variable(var='beta')
251253
self.assertEqual(var_beta.shape, (2,)) # 1-element tuple
252254
var_frac_60 = multidim_mle.stan_variable(var='frac_60')
253-
self.assertEqual(var_frac_60.shape, ())
255+
self.assertTrue(isinstance(var_frac_60, float))
254256
vars = multidim_mle.stan_variables()
255257
self.assertEqual(len(vars), len(multidim_mle.metadata.stan_vars_dims))
256258
self.assertTrue('y_rep' in vars)
257259
self.assertEqual(vars['y_rep'].shape, (5, 4, 3))
258260
self.assertTrue('beta' in vars)
259261
self.assertEqual(vars['beta'].shape, (2,))
260262
self.assertTrue('frac_60' in vars)
261-
self.assertEqual(vars['frac_60'].shape, ())
263+
self.assertTrue(isinstance(vars['frac_60'], float))
262264

263265
multidim_mle_iters = multidim_model.optimize(
264266
data=jdata,
@@ -579,6 +581,17 @@ def test_optimize_bad(self):
579581
data=no_data, seed=1239812093, inits=None, algorithm='BFGS'
580582
)
581583

584+
def test_single_row_csv(self):
585+
stan = os.path.join(DATAFILES_PATH, 'matrix_var.stan')
586+
model = CmdStanModel(stan_file=stan)
587+
mle = model.optimize()
588+
self.assertTrue(isinstance(mle.stan_variable('theta'), float))
589+
z_as_ndarray = mle.stan_variable(var="z")
590+
self.assertEqual(z_as_ndarray.shape, (4, 3))
591+
for i in range(4):
592+
for j in range(3):
593+
self.assertEqual(int(z_as_ndarray[i, j]), i + 1)
594+
582595

583596
if __name__ == '__main__':
584597
unittest.main()

test/test_sample.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1725,6 +1725,19 @@ def test_no_xarray(self):
17251725
with self.assertRaises(RuntimeError):
17261726
bern_fit.draws_xr()
17271727

1728+
def test_single_row_csv(self):
1729+
stan = os.path.join(DATAFILES_PATH, 'matrix_var.stan')
1730+
model = CmdStanModel(stan_file=stan)
1731+
fit = model.sample(iter_sampling=1, chains=1)
1732+
z_as_ndarray = fit.stan_variable(var="z")
1733+
self.assertEqual(z_as_ndarray.shape, (1, 4, 3)) # flattens chains
1734+
z_as_xr = fit.draws_xr(vars="z")
1735+
self.assertEqual(z_as_xr.z.data.shape, (1, 1, 4, 3)) # keeps chains
1736+
for i in range(4):
1737+
for j in range(3):
1738+
self.assertEqual(int(z_as_ndarray[0, i, j]), i + 1)
1739+
self.assertEqual(int(z_as_xr.z.data[0, 0, i, j]), i + 1)
1740+
17281741

17291742
if __name__ == '__main__':
17301743
unittest.main()

test/test_variational.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ def test_variables_3d(self):
126126
var_beta = multidim_variational.stan_variable(var='beta')
127127
self.assertEqual(var_beta.shape, (2,)) # 1-element tuple
128128
var_frac_60 = multidim_variational.stan_variable(var='frac_60')
129-
self.assertEqual(var_frac_60.shape, ())
129+
self.assertTrue(isinstance(var_frac_60, float))
130130
vars = multidim_variational.stan_variables()
131131
self.assertEqual(
132132
len(vars), len(multidim_variational.metadata.stan_vars_dims)
@@ -136,7 +136,7 @@ def test_variables_3d(self):
136136
self.assertTrue('beta' in vars)
137137
self.assertEqual(vars['beta'].shape, (2,))
138138
self.assertTrue('frac_60' in vars)
139-
self.assertEqual(vars['frac_60'].shape, ())
139+
self.assertTrue(isinstance(vars['frac_60'], float))
140140
with self.assertRaises(ValueError):
141141
multidim_variational.stan_variable(var='beta', name='yrep')
142142
with LogCapture() as log:
@@ -253,6 +253,17 @@ def test_variational_eta_fail(self):
253253
)
254254
)
255255

256+
def test_single_row_csv(self):
257+
stan = os.path.join(DATAFILES_PATH, 'matrix_var.stan')
258+
model = CmdStanModel(stan_file=stan)
259+
vb_fit = model.variational()
260+
self.assertTrue(isinstance(vb_fit.stan_variable('theta'), float))
261+
z_as_ndarray = vb_fit.stan_variable(var="z")
262+
self.assertEqual(z_as_ndarray.shape, (4, 3))
263+
for i in range(4):
264+
for j in range(3):
265+
self.assertEqual(int(z_as_ndarray[i, j]), i + 1)
266+
256267

257268
if __name__ == '__main__':
258269
unittest.main()

0 commit comments

Comments
 (0)