@@ -173,13 +173,11 @@ def column_names(self) -> Tuple[str, ...]:
173173 @property
174174 def metric_type (self ) -> Optional [str ]:
175175 """
176- Metric type used for adaptation, either 'diag_e' or 'dense_e'.
176+ Metric type used for adaptation, either 'diag_e' or 'dense_e', according
177+ to CmdStan arg 'metric'.
177178 When sampler algorithm 'fixed_param' is specified, metric_type is None.
178179 """
179- if self ._is_fixed_param :
180- return None
181- # cmdstan arg name
182- return self ._metadata .cmdstan_config ['metric' ] # type: ignore
180+ return self ._metadata .cmdstan_config ['metric' ] if not self ._is_fixed_param else None
183181
184182 @property
185183 def metric (self ) -> Optional [np .ndarray ]:
@@ -203,10 +201,8 @@ def step_size(self) -> Optional[np.ndarray]:
203201 Step size used by sampler for each chain.
204202 When sampler algorithm 'fixed_param' is specified, step size is None.
205203 """
206- if self ._is_fixed_param :
207- return None
208204 self ._assemble_draws ()
209- return self ._step_size
205+ return self ._step_size if not self . _is_fixed_param else None
210206
211207 @property
212208 def thin (self ) -> int :
@@ -221,9 +217,7 @@ def divergences(self) -> Optional[np.ndarray]:
221217 Per-chain total number of post-warmup divergent iterations.
222218 When sampler algorithm 'fixed_param' is specified, returns None.
223219 """
224- if self ._is_fixed_param :
225- return None
226- return self ._divergences
220+ return self ._divergences if not self ._is_fixed_param else None
227221
228222 @property
229223 def max_treedepths (self ) -> Optional [np .ndarray ]:
@@ -232,9 +226,7 @@ def max_treedepths(self) -> Optional[np.ndarray]:
232226 reached the maximum allowed treedepth.
233227 When sampler algorithm 'fixed_param' is specified, returns None.
234228 """
235- if self ._is_fixed_param :
236- return None
237- return self ._max_treedepths
229+ return self ._max_treedepths if not self ._is_fixed_param else None
238230
239231 def draws (
240232 self , * , inc_warmup : bool = False , concat_chains : bool = False
@@ -298,7 +290,7 @@ def _validate_csv_files(self) -> Dict[str, Any]:
298290 save_warmup = self ._save_warmup ,
299291 thin = self ._thin ,
300292 )
301- if 'ct_divergences' in dzero :
293+ if not self . _is_fixed_param :
302294 self ._divergences [i ] = dzero ['ct_divergences' ]
303295 self ._max_treedepths [i ] = dzero ['ct_max_treedepth' ]
304296 else :
@@ -337,12 +329,11 @@ def _validate_csv_files(self) -> Dict[str, Any]:
337329 drest [key ],
338330 )
339331 )
340- if 'ct_divergences' in drest :
332+ if not self . _is_fixed_param :
341333 self ._divergences [i ] = drest ['ct_divergences' ]
342334 self ._max_treedepths [i ] = drest ['ct_max_treedepth' ]
343335 return dzero
344336
345- # pylint: disable=unused-variable
346337 def _check_sampler_diagnostics (self ) -> None :
347338 """
348339 Warn if any iterations ended in divergences or hit maxtreedepth.
0 commit comments