Skip to content

Commit 101960c

Browse files
committed
add shape test
1 parent fcb73a8 commit 101960c

2 files changed

Lines changed: 60 additions & 0 deletions

File tree

test/data/shape.stan

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
parameters {
2+
real z;
3+
}
4+
model {
5+
z ~ normal(0,1);
6+
}
7+
generated quantities {
8+
int x = 1;
9+
real a = 1;
10+
real b[2];
11+
real c[2,3];
12+
vector[2] d;
13+
vector[3] e[2];
14+
matrix[2,3] f[4];
15+
matrix[2,3] g[4,5];
16+
for (n in 1:2) {
17+
b[n] = 1;
18+
d[n] = 1;
19+
for (m in 1:3) {
20+
c[n,m] = n;
21+
e[n,m] = n;
22+
for (k in 1:4) {
23+
f[k, n, m] = n;
24+
for (j in 1:4) {
25+
g[k, j, n, m] = n;
26+
}
27+
}
28+
}
29+
}
30+
}

test/test_optimize.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,36 @@ def test_variables_3d(self):
287287
self.assertTrue('frac_60' in vars_iters)
288288
self.assertEqual(vars_iters['frac_60'].shape, (8,))
289289

290+
def test_variables_shape(self):
291+
stan = os.path.join(DATAFILES_PATH, 'shape.stan')
292+
model = CmdStanModel(stan_file=stan)
293+
no_data = {}
294+
mle = model.optimize(
295+
seed=1239812093,
296+
algorithm='LBFGS',
297+
init_alpha=0.001,
298+
iter=100,
299+
tol_obj=1e-12,
300+
tol_rel_obj=1e4,
301+
tol_grad=1e-8,
302+
tol_rel_grad=1e7,
303+
tol_param=1e-8,
304+
history_size=5,
305+
)
306+
for var, shape in {
307+
"x": float,
308+
"a": float,
309+
"b": (2,),
310+
"c": (2,3),
311+
"d": (2,),
312+
"e":(2,3),
313+
"f": (4,2,3),
314+
"g":(4,5,2,3),
315+
}.items():
316+
if isinstance(shape, tuple):
317+
assert mle.stan_variable(var).shape == shape
318+
else:
319+
assert isinstance(mle.stan_variable(var), shape)
290320

291321
class OptimizeTest(unittest.TestCase):
292322
def test_optimize_good(self):

0 commit comments

Comments
 (0)