Skip to content

Commit 56ac6b2

Browse files
committed
changes per code review
1 parent 0ebf838 commit 56ac6b2

2 files changed

Lines changed: 8 additions & 18 deletions

File tree

cmdstanpy/stanfit/mcmc.py

Lines changed: 8 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -173,13 +173,11 @@ def column_names(self) -> Tuple[str, ...]:
173173
@property
174174
def metric_type(self) -> Optional[str]:
175175
"""
176-
Metric type used for adaptation, either 'diag_e' or 'dense_e'.
176+
Metric type used for adaptation, either 'diag_e' or 'dense_e', according
177+
to CmdStan arg 'metric'.
177178
When sampler algorithm 'fixed_param' is specified, metric_type is None.
178179
"""
179-
if self._is_fixed_param:
180-
return None
181-
# cmdstan arg name
182-
return self._metadata.cmdstan_config['metric'] # type: ignore
180+
return self._metadata.cmdstan_config['metric'] if not self._is_fixed_param else None
183181

184182
@property
185183
def metric(self) -> Optional[np.ndarray]:
@@ -203,10 +201,8 @@ def step_size(self) -> Optional[np.ndarray]:
203201
Step size used by sampler for each chain.
204202
When sampler algorithm 'fixed_param' is specified, step size is None.
205203
"""
206-
if self._is_fixed_param:
207-
return None
208204
self._assemble_draws()
209-
return self._step_size
205+
return self._step_size if not self._is_fixed_param else None
210206

211207
@property
212208
def thin(self) -> int:
@@ -221,9 +217,7 @@ def divergences(self) -> Optional[np.ndarray]:
221217
Per-chain total number of post-warmup divergent iterations.
222218
When sampler algorithm 'fixed_param' is specified, returns None.
223219
"""
224-
if self._is_fixed_param:
225-
return None
226-
return self._divergences
220+
return self._divergences if not self._is_fixed_param else None
227221

228222
@property
229223
def max_treedepths(self) -> Optional[np.ndarray]:
@@ -232,9 +226,7 @@ def max_treedepths(self) -> Optional[np.ndarray]:
232226
reached the maximum allowed treedepth.
233227
When sampler algorithm 'fixed_param' is specified, returns None.
234228
"""
235-
if self._is_fixed_param:
236-
return None
237-
return self._max_treedepths
229+
return self._max_treedepths if not self._is_fixed_param else None
238230

239231
def draws(
240232
self, *, inc_warmup: bool = False, concat_chains: bool = False
@@ -298,7 +290,7 @@ def _validate_csv_files(self) -> Dict[str, Any]:
298290
save_warmup=self._save_warmup,
299291
thin=self._thin,
300292
)
301-
if 'ct_divergences' in dzero:
293+
if not self._is_fixed_param:
302294
self._divergences[i] = dzero['ct_divergences']
303295
self._max_treedepths[i] = dzero['ct_max_treedepth']
304296
else:
@@ -337,12 +329,11 @@ def _validate_csv_files(self) -> Dict[str, Any]:
337329
drest[key],
338330
)
339331
)
340-
if 'ct_divergences' in drest:
332+
if not self._is_fixed_param:
341333
self._divergences[i] = drest['ct_divergences']
342334
self._max_treedepths[i] = drest['ct_max_treedepth']
343335
return dzero
344336

345-
# pylint: disable=unused-variable
346337
def _check_sampler_diagnostics(self) -> None:
347338
"""
348339
Warn if any iterations ended in divergences or hit maxtreedepth.

test/test_sample.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1796,7 +1796,6 @@ def test_attrs(self):
17961796
with self.assertRaisesRegex(AttributeError, 'Unknown variable name:'):
17971797
dummy = fit.c
17981798

1799-
# pylint: disable=use-a-generator
18001799
def test_diagnostics(self):
18011800
# centered 8 schools hits funnel
18021801
stan = os.path.join(DATAFILES_PATH, 'eight_schools.stan')

0 commit comments

Comments
 (0)