Skip to content

Commit d3e6387

Browse files
committed
made typechecker happy
1 parent 0323178 commit d3e6387

1 file changed

Lines changed: 16 additions & 10 deletions

File tree

cmdstanpy/stanfit.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1262,8 +1262,12 @@ def _set_mle_attrs(self, sample_csv_0: str) -> None:
12621262
meta = scan_optimize_csv(sample_csv_0, self._save_iterations)
12631263
self._metadata = InferenceMetadata(meta)
12641264
self._column_names: Tuple[str, ...] = meta['column_names']
1265+
assert isinstance(meta['mle'], np.ndarray) # make the typechecker happy
12651266
self._mle = meta['mle']
12661267
if self._save_iterations:
1268+
assert isinstance(
1269+
meta['all_iters'], np.ndarray
1270+
) # make the typechecker happy
12671271
self._all_iters = meta['all_iters']
12681272

12691273
@property
@@ -1297,7 +1301,7 @@ def optimized_params_np(self) -> np.ndarray:
12971301
return self._mle
12981302

12991303
@property
1300-
def optimized_iterations_np(self) -> np.ndarray:
1304+
def optimized_iterations_np(self) -> Optional[np.ndarray]:
13011305
"""
13021306
Returns all saved iterations from the optimizer and final estimate
13031307
as a numpy.ndarray which contains all optimizer outputs, i.e.,
@@ -1330,7 +1334,7 @@ def optimized_params_pd(self) -> pd.DataFrame:
13301334
return pd.DataFrame([self._mle], columns=self.column_names)
13311335

13321336
@property
1333-
def optimized_iterations_pd(self) -> pd.DataFrame:
1337+
def optimized_iterations_pd(self) -> Optional[pd.DataFrame]:
13341338
"""
13351339
Returns all saved iterations from the optimizer and final estimate
13361340
as a pandas.DataFrame which contains all optimizer outputs, i.e.,
@@ -1417,26 +1421,28 @@ def stan_variable(
14171421
else:
14181422
num_rows = 1
14191423

1420-
# extract and reshape, container var
1421-
if len(col_idxs) > 0:
1424+
if len(col_idxs) > 0: # container var
14221425
dims = (num_rows,) + self._metadata.stan_vars_dims[var]
14231426
# pylint: disable=redundant-keyword-arg
14241427
if num_rows > 1:
1425-
return self._all_iters[:, col_idxs].reshape( # type: ignore
1428+
result = self._all_iters[:, col_idxs].reshape( # type: ignore
14261429
dims, order='F'
14271430
)
14281431
else:
14291432
mle = np.expand_dims(self._mle, axis=0) # hack for col indexing
1430-
return (
1433+
result = (
14311434
mle[0, col_idxs]
14321435
.reshape(dims, order='F') # type: ignore
14331436
.squeeze(axis=0)
14341437
)
1438+
else: # scalar var
1439+
if num_rows > 1:
1440+
result = self._all_iters[:, col_idxs]
1441+
else:
1442+
result = np.atleast_1d(mle[0, col_idxs])
14351443

1436-
# extract scalar var
1437-
if num_rows > 1:
1438-
return self._all_iters[:, col_idxs]
1439-
return mle[0, col_idxs]
1444+
assert isinstance(result, np.ndarray) # make the typechecker happy
1445+
return result
14401446

14411447
def stan_variables(
14421448
self, inc_iterations: bool = False

0 commit comments

Comments
 (0)