Skip to content

Commit e533811

Browse files
committed
more unit tests; fixed gq bug; different test model
1 parent 101960c commit e533811

6 files changed

Lines changed: 92 additions & 36 deletions

File tree

cmdstanpy/stanfit.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1983,16 +1983,26 @@ def stan_variable(
19831983
return self.mcmc_sample.stan_variable(var, inc_warmup=inc_warmup)
19841984
else: # is gq variable
19851985
self._assemble_generated_quantities()
1986-
col_idxs = self._metadata.stan_vars_cols[var]
1986+
draw1 = 0
19871987
if (
19881988
not inc_warmup
19891989
and self.mcmc_sample.metadata.cmdstan_config['save_warmup']
19901990
):
1991-
draw1 = self.mcmc_sample.num_draws_warmup * self.chains
1992-
return flatten_chains(self._draws)[ # type: ignore
1993-
draw1:, col_idxs
1994-
]
1995-
return flatten_chains(self._draws)[:, col_idxs] # type: ignore
1991+
draw1 = self.mcmc_sample.num_draws_warmup
1992+
num_draws = self.mcmc_sample.num_draws_sampling
1993+
if (
1994+
inc_warmup
1995+
and self.mcmc_sample.metadata.cmdstan_config['save_warmup']
1996+
):
1997+
num_draws += self.mcmc_sample.num_draws_warmup
1998+
dims = [num_draws * self.chains]
1999+
col_idxs = self._metadata.stan_vars_cols[var]
2000+
if len(col_idxs) > 0:
2001+
dims.extend(self._metadata.stan_vars_dims[var])
2002+
# pylint: disable=redundant-keyword-arg
2003+
return self._draws[draw1:, :, col_idxs].reshape( # type: ignore
2004+
dims, order='F'
2005+
)
19962006

19972007
def stan_variables(self, inc_warmup: bool = False) -> Dict[str, np.ndarray]:
19982008
"""

test/data/matrix_var.stan

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
transformed data {
2+
int y[10] = {0,1,0,0,0,0,0,0,0,1};
3+
}
4+
parameters {
5+
real<lower=0,upper=1> theta;
6+
}
7+
model {
8+
theta ~ beta(1,1); // uniform prior on interval 0,1
9+
y ~ bernoulli(theta);
10+
}
11+
generated quantities {
12+
# x is a 4 x 3 matrix where i,j entry == rownum
13+
matrix[4, 3] z;
14+
for (row_num in 1:4) {
15+
for (col_num in 1:3) {
16+
z[row_num, col_num] = row_num;
17+
}
18+
}
19+
}

test/test_generate_quantities.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -426,6 +426,28 @@ def test_no_xarray(self):
426426
with self.assertRaises(RuntimeError):
427427
bern_gqs.draws_xr()
428428

429+
def test_bug_455(self):
430+
stan = os.path.join(DATAFILES_PATH, 'bernoulli.stan')
431+
bern_model = CmdStanModel(stan_file=stan)
432+
jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
433+
bern_fit = bern_model.sample(
434+
data=jdata,
435+
chains=1,
436+
seed=12345,
437+
iter_sampling=1,
438+
)
439+
stan = os.path.join(DATAFILES_PATH, 'matrix_var.stan')
440+
model = CmdStanModel(stan_file=stan)
441+
gqs = model.generate_quantities(mcmc_sample=bern_fit)
442+
z_as_ndarray = gqs.stan_variable(var="z")
443+
self.assertEqual(z_as_ndarray.shape, (1, 4, 3)) # flattens chains
444+
z_as_xr = gqs.draws_xr(vars="z")
445+
self.assertEqual(z_as_xr.z.data.shape, (1, 1, 4, 3)) # keeps chains
446+
for i in range(4):
447+
for j in range(3):
448+
self.assertEqual(int(z_as_ndarray[0, i, j]), i + 1)
449+
self.assertEqual(int(z_as_xr.z.data[0, 0, i, j]), i + 1)
450+
429451

430452
if __name__ == '__main__':
431453
unittest.main()

test/test_optimize.py

Lines changed: 11 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -287,36 +287,6 @@ def test_variables_3d(self):
287287
self.assertTrue('frac_60' in vars_iters)
288288
self.assertEqual(vars_iters['frac_60'].shape, (8,))
289289

290-
def test_variables_shape(self):
291-
stan = os.path.join(DATAFILES_PATH, 'shape.stan')
292-
model = CmdStanModel(stan_file=stan)
293-
no_data = {}
294-
mle = model.optimize(
295-
seed=1239812093,
296-
algorithm='LBFGS',
297-
init_alpha=0.001,
298-
iter=100,
299-
tol_obj=1e-12,
300-
tol_rel_obj=1e4,
301-
tol_grad=1e-8,
302-
tol_rel_grad=1e7,
303-
tol_param=1e-8,
304-
history_size=5,
305-
)
306-
for var, shape in {
307-
"x": float,
308-
"a": float,
309-
"b": (2,),
310-
"c": (2,3),
311-
"d": (2,),
312-
"e":(2,3),
313-
"f": (4,2,3),
314-
"g":(4,5,2,3),
315-
}.items():
316-
if isinstance(shape, tuple):
317-
assert mle.stan_variable(var).shape == shape
318-
else:
319-
assert isinstance(mle.stan_variable(var), shape)
320290

321291
class OptimizeTest(unittest.TestCase):
322292
def test_optimize_good(self):
@@ -611,6 +581,17 @@ def test_optimize_bad(self):
611581
data=no_data, seed=1239812093, inits=None, algorithm='BFGS'
612582
)
613583

584+
def test_bug_455(self):
585+
stan = os.path.join(DATAFILES_PATH, 'matrix_var.stan')
586+
model = CmdStanModel(stan_file=stan)
587+
mle = model.optimize()
588+
self.assertTrue(isinstance(mle.stan_variable('theta'), float))
589+
z_as_ndarray = mle.stan_variable(var="z")
590+
self.assertEqual(z_as_ndarray.shape, (4, 3))
591+
for i in range(4):
592+
for j in range(3):
593+
self.assertEqual(int(z_as_ndarray[i, j]), i + 1)
594+
614595

615596
if __name__ == '__main__':
616597
unittest.main()

test/test_sample.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1668,6 +1668,19 @@ def test_no_xarray(self):
16681668
with self.assertRaises(RuntimeError):
16691669
bern_fit.draws_xr()
16701670

1671+
def test_bug_455(self):
1672+
stan = os.path.join(DATAFILES_PATH, 'matrix_var.stan')
1673+
bug_455_model = CmdStanModel(stan_file=stan)
1674+
bug_455_fit = bug_455_model.sample(iter_sampling=1, chains=1)
1675+
z_as_ndarray = bug_455_fit.stan_variable(var="z")
1676+
self.assertEqual(z_as_ndarray.shape, (1, 4, 3)) # flattens chains
1677+
z_as_xr = bug_455_fit.draws_xr(vars="z")
1678+
self.assertEqual(z_as_xr.z.data.shape, (1, 1, 4, 3)) # keeps chains
1679+
for i in range(4):
1680+
for j in range(3):
1681+
self.assertEqual(int(z_as_ndarray[0, i, j]), i + 1)
1682+
self.assertEqual(int(z_as_xr.z.data[0, 0, i, j]), i + 1)
1683+
16711684

16721685
if __name__ == '__main__':
16731686
unittest.main()

test/test_variational.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,17 @@ def test_variational_eta_fail(self):
253253
)
254254
)
255255

256+
def test_bug_455(self):
257+
stan = os.path.join(DATAFILES_PATH, 'matrix_var.stan')
258+
model = CmdStanModel(stan_file=stan)
259+
vb_fit = model.variational()
260+
self.assertTrue(isinstance(vb_fit.stan_variable('theta'), float))
261+
z_as_ndarray = vb_fit.stan_variable(var="z")
262+
self.assertEqual(z_as_ndarray.shape, (4, 3))
263+
for i in range(4):
264+
for j in range(3):
265+
self.assertEqual(int(z_as_ndarray[i, j]), i + 1)
266+
256267

257268
if __name__ == '__main__':
258269
unittest.main()

0 commit comments

Comments
 (0)