Skip to content

Commit 5297c64

Browse files
committed
fix shaping issues
1 parent 3224542 commit 5297c64

4 files changed

Lines changed: 36 additions & 28 deletions

File tree

cmdstanpy/stanfit.py

Lines changed: 25 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1074,7 +1074,6 @@ def draws_xr(
10741074
0,
10751075
self.draws(inc_warmup=inc_warmup),
10761076
)
1077-
10781077
return xr.Dataset(data, coords=coordinates, attrs=attrs).transpose(
10791078
'chain', 'draw', ...
10801079
)
@@ -1373,7 +1372,7 @@ def stan_variable(
13731372
inc_iterations: bool = False,
13741373
warn: bool = True,
13751374
name: Optional[str] = None,
1376-
) -> np.ndarray:
1375+
) -> Union[np.ndarray, float]:
13771376
"""
13781377
Return a numpy.ndarray which contains the estimates for the
13791378
for the named Stan program variable where the dimensions of the
@@ -1416,33 +1415,29 @@ def stan_variable(
14161415
'Invalid estimate, optimization failed to converge.'
14171416
)
14181417

1419-
col_idxs = self._metadata.stan_vars_cols[var]
1418+
col_idxs = list(self._metadata.stan_vars_cols[var])
14201419
if inc_iterations and self._save_iterations:
14211420
num_rows = self._all_iters.shape[0]
14221421
else:
14231422
num_rows = 1
1424-
1425-
if len(col_idxs) > 0: # container var
1423+
if len(col_idxs) > 1: # container var
14261424
dims = (num_rows,) + self._metadata.stan_vars_dims[var]
14271425
# pylint: disable=redundant-keyword-arg
14281426
if num_rows > 1:
14291427
result = self._all_iters[:, col_idxs].reshape( # type: ignore
14301428
dims, order='F'
14311429
)
14321430
else:
1433-
mle = np.expand_dims(self._mle, axis=0) # hack for col indexing
1434-
result = (
1435-
mle[0, col_idxs]
1436-
.reshape(dims, order='F') # type: ignore
1437-
.squeeze(axis=0)
1438-
)
1431+
result = self._mle[col_idxs].reshape(dims[1:], order="F")
14391432
else: # scalar var
1433+
col_idx = col_idxs[0]
14401434
if num_rows > 1:
1441-
result = self._all_iters[:, col_idxs]
1435+
result = self._all_iters[:, col_idx]
14421436
else:
1443-
result = np.atleast_1d(mle[0, col_idxs])
1444-
1445-
assert isinstance(result, np.ndarray) # make the typechecker happy
1437+
result = float(self._mle[col_idx])
1438+
assert isinstance(
1439+
result, (np.ndarray, float)
1440+
) # make the typechecker happy
14461441
return result
14471442

14481443
def stan_variables(
@@ -2143,7 +2138,7 @@ def metadata(self) -> InferenceMetadata:
21432138

21442139
def stan_variable(
21452140
self, var: Optional[str] = None, *, name: Optional[str] = None
2146-
) -> np.ndarray:
2141+
) -> Union[np.ndarray, float]:
21472142
"""
21482143
Return a numpy.ndarray which contains the estimates for the
21492144
for the named Stan program variable where the dimensions of the
@@ -2172,12 +2167,16 @@ def stan_variable(
21722167
if var not in self._metadata.stan_vars_dims:
21732168
raise ValueError('Unknown variable name: {}'.format(var))
21742169
col_idxs = list(self._metadata.stan_vars_cols[var])
2175-
vals = list(self._variational_mean)
2176-
xs = [vals[x] for x in col_idxs]
21772170
shape: Tuple[int, ...] = ()
2178-
if len(col_idxs) > 0:
2171+
if len(col_idxs) > 1:
21792172
shape = self._metadata.stan_vars_dims[var]
2180-
return np.array(xs).reshape(shape)
2173+
result = np.asarray(self._variational_mean)[col_idxs].reshape(
2174+
shape, order="F"
2175+
)
2176+
else:
2177+
result = float(self._variational_mean[col_idxs[0]])
2178+
assert isinstance(result, (np.ndarray, float))
2179+
return result
21812180

21822181
def stan_variables(self) -> Dict[str, np.ndarray]:
21832182
"""
@@ -2424,7 +2423,12 @@ def build_xarray_data(
24242423
var_dims: Tuple[str, ...] = ('draw', 'chain')
24252424
if dims:
24262425
var_dims += tuple(f"{var_name}_dim_{i}" for i in range(len(dims)))
2427-
data[var_name] = (var_dims, drawset[start_row:, :, col_idxs])
2426+
data[var_name] = (
2427+
var_dims,
2428+
drawset[start_row:, :, col_idxs].reshape(
2429+
*drawset.shape[:2], *dims, order="F"
2430+
),
2431+
)
24282432
else:
24292433
data[var_name] = (
24302434
var_dims,

docsrc/conf.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -441,5 +441,7 @@ def emit(self, record):
441441
# }
442442

443443
# Makes the copying behavior on code examples cleaner by removing things like In [10]: from the text to be copied
444-
copybutton_prompt_text = r">>> |\.\.\. |\$ |In \[\d*\]: | {2,5}\.\.\.: | {5,8}: "
444+
copybutton_prompt_text = (
445+
r">>> |\.\.\. |\$ |In \[\d*\]: | {2,5}\.\.\.: | {5,8}: "
446+
)
445447
copybutton_prompt_is_regexp = True

test/test_optimize.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -208,13 +208,15 @@ def test_variable_bern(self):
208208
self.assertTrue('theta' in bern_mle.metadata.stan_vars_dims)
209209
self.assertEqual(bern_mle.metadata.stan_vars_dims['theta'], ())
210210
theta = bern_mle.stan_variable(var='theta')
211-
self.assertEqual(theta.shape, ())
211+
self.assertTrue(isinstance(theta, float))
212212
with self.assertRaises(ValueError):
213213
bern_mle.stan_variable(var='eta')
214214
with self.assertRaises(ValueError):
215215
bern_mle.stan_variable(var='lp__')
216216
with LogCapture() as log:
217-
self.assertEqual(bern_mle.stan_variable(name='theta').shape, ())
217+
self.assertTrue(
218+
isinstance(bern_mle.stan_variable(name='theta'), float)
219+
)
218220
log.check_present(
219221
(
220222
'cmdstanpy',
@@ -250,15 +252,15 @@ def test_variables_3d(self):
250252
var_beta = multidim_mle.stan_variable(var='beta')
251253
self.assertEqual(var_beta.shape, (2,)) # 1-element tuple
252254
var_frac_60 = multidim_mle.stan_variable(var='frac_60')
253-
self.assertEqual(var_frac_60.shape, ())
255+
self.assertTrue(isinstance(var_frac_60, float))
254256
vars = multidim_mle.stan_variables()
255257
self.assertEqual(len(vars), len(multidim_mle.metadata.stan_vars_dims))
256258
self.assertTrue('y_rep' in vars)
257259
self.assertEqual(vars['y_rep'].shape, (5, 4, 3))
258260
self.assertTrue('beta' in vars)
259261
self.assertEqual(vars['beta'].shape, (2,))
260262
self.assertTrue('frac_60' in vars)
261-
self.assertEqual(vars['frac_60'].shape, ())
263+
self.assertTrue(isinstance(vars['frac_60'], float))
262264

263265
multidim_mle_iters = multidim_model.optimize(
264266
data=jdata,

test/test_variational.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ def test_variables_3d(self):
126126
var_beta = multidim_variational.stan_variable(var='beta')
127127
self.assertEqual(var_beta.shape, (2,)) # 1-element tuple
128128
var_frac_60 = multidim_variational.stan_variable(var='frac_60')
129-
self.assertEqual(var_frac_60.shape, ())
129+
self.assertTrue(isinstance(var_frac_60, float))
130130
vars = multidim_variational.stan_variables()
131131
self.assertEqual(
132132
len(vars), len(multidim_variational.metadata.stan_vars_dims)
@@ -136,7 +136,7 @@ def test_variables_3d(self):
136136
self.assertTrue('beta' in vars)
137137
self.assertEqual(vars['beta'].shape, (2,))
138138
self.assertTrue('frac_60' in vars)
139-
self.assertEqual(vars['frac_60'].shape, ())
139+
self.assertTrue(isinstance(vars['frac_60'], float))
140140
with self.assertRaises(ValueError):
141141
multidim_variational.stan_variable(var='beta', name='yrep')
142142
with LogCapture() as log:

0 commit comments

Comments
 (0)