Skip to content

Commit 8ea5ca0

Browse files
committed
lint fix
1 parent bda136f commit 8ea5ca0

2 files changed

Lines changed: 26 additions & 16 deletions

File tree

cmdstanpy/stanfit/mcmc.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -77,13 +77,13 @@ def __init__(
7777
assert isinstance(
7878
sampler_args, SamplerArgs
7979
) # make the typechecker happy
80-
self._iter_sampling:int = _CMDSTAN_SAMPLING
80+
self._iter_sampling: int = _CMDSTAN_SAMPLING
8181
if sampler_args.iter_sampling is not None:
8282
self._iter_sampling = sampler_args.iter_sampling
83-
self._iter_warmup:int = _CMDSTAN_WARMUP
83+
self._iter_warmup: int = _CMDSTAN_WARMUP
8484
if sampler_args.iter_warmup is not None:
8585
self._iter_warmup = sampler_args.iter_warmup
86-
self._thin:int = _CMDSTAN_THIN
86+
self._thin: int = _CMDSTAN_THIN
8787
if sampler_args.thin is not None:
8888
self._thin = sampler_args.thin
8989
self._is_fixed_param = sampler_args.fixed_param
@@ -286,8 +286,12 @@ def _validate_csv_files(self) -> Dict[str, Any]:
286286
Raises exception when inconsistencies detected.
287287
"""
288288
if not self._is_fixed_param:
289-
self._divergences: np.ndarray = np.zeros(self.runset.chains, dtype=int)
290-
self._max_treedepths: np.ndarray = np.zeros(self.runset.chains, dtype=int)
289+
self._divergences: np.ndarray = np.zeros(
290+
self.runset.chains, dtype=int
291+
)
292+
self._max_treedepths: np.ndarray = np.zeros(
293+
self.runset.chains, dtype=int
294+
)
291295

292296
dzero = {}
293297
for i in range(self.chains):
@@ -369,7 +373,7 @@ def _check_sampler_diagnostics(self) -> None:
369373
'Use function "diagnose()" to see further information.'
370374
)
371375
get_logger().warning('\n\t'.join(diagnostics))
372-
376+
373377
def _assemble_draws(self) -> None:
374378
"""
375379
Allocates and populates the step size, metric, and sample arrays

test/test_sample.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1808,11 +1808,15 @@ def test_diagnostics(self):
18081808
data=rdata,
18091809
seed=55157,
18101810
)
1811-
log.check_present((
1812-
'cmdstanpy',
1813-
'WARNING',
1814-
StringComparison(r'(?s).*Some chains may have failed to converge.*')
1815-
))
1811+
log.check_present(
1812+
(
1813+
'cmdstanpy',
1814+
'WARNING',
1815+
StringComparison(
1816+
r'(?s).*Some chains may have failed to converge.*'
1817+
),
1818+
)
1819+
)
18161820
self.assertFalse(np.all(fit.divergences == 0))
18171821

18181822
with LogCapture(level=logging.WARNING) as log:
@@ -1822,11 +1826,13 @@ def test_diagnostics(self):
18221826
seed=40508,
18231827
max_treedepth=3,
18241828
)
1825-
log.check_present((
1826-
'cmdstanpy',
1827-
'WARNING',
1828-
StringComparison(r'(?s).*max treedepth*')
1829-
))
1829+
log.check_present(
1830+
(
1831+
'cmdstanpy',
1832+
'WARNING',
1833+
StringComparison(r'(?s).*max treedepth*'),
1834+
)
1835+
)
18301836
self.assertFalse(np.all(fit.max_treedepths == 0))
18311837

18321838
stan = os.path.join(DATAFILES_PATH, 'bernoulli.stan')

0 commit comments

Comments
 (0)