Skip to content

Commit 03a8f8b

Browse files
committed
logging test less brittle
1 parent c803469 commit 03a8f8b

1 file changed

Lines changed: 13 additions & 23 deletions

File tree

test/test_sample.py

Lines changed: 13 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1801,48 +1801,38 @@ def test_diagnostics(self):
18011801
stan = os.path.join(DATAFILES_PATH, 'eight_schools.stan')
18021802
model = CmdStanModel(stan_file=stan)
18031803
rdata = os.path.join(DATAFILES_PATH, 'eight_schools.data.R')
1804-
with LogCapture() as log:
1804+
with LogCapture(level=logging.WARNING) as log:
18051805
logging.getLogger()
18061806
fit = model.sample(
18071807
data=rdata,
18081808
seed=55157,
18091809
show_progress=False,
18101810
show_console=False,
18111811
)
1812-
1813-
self.assertTrue(
1814-
any([a != b for a, b in zip(fit.max_treedepths, [0, 0, 0, 0])])
1815-
)
1812+
msg = log.actual()[-1][-1]
18161813
self.assertTrue(
1817-
any([a != b for a, b in zip(fit.divergences, [0, 0, 0, 0])])
1814+
msg.startswith('Some chains may have failed to converge.')
18181815
)
1816+
self.assertFalse(np.all(fit.divergences == 0))
18191817

1820-
log.check_present(
1821-
(
1822-
'cmdstanpy',
1823-
'WARNING',
1824-
'Some chains may have failed to converge.\n'
1825-
'\tChain 1 had 10 divergent transitions (1.0%)\n'
1826-
'\tChain 2 had 143 divergent transitions (14.3%)\n'
1827-
'\tChain 3 had 5 divergent transitions (0.5%)\n'
1828-
'\tChain 4 had 4 divergent transitions (0.4%)\n'
1829-
'\tChain 4 had 6 iterations at max treedepth (0.6%)\n'
1830-
'\tUse function "diagnose()" to see further information.',
1831-
),
1818+
with LogCapture(level=logging.WARNING) as log:
1819+
logging.getLogger()
1820+
fit = model.sample(
1821+
data=rdata,
1822+
seed=40508,
1823+
max_treedepth=3,
18321824
)
1825+
msg = log.actual()[-1][-1]
1826+
self.assertTrue('max treedepth' in msg)
1827+
self.assertFalse(np.all(fit.max_treedepths == 0))
18331828

18341829
stan = os.path.join(DATAFILES_PATH, 'bernoulli.stan')
18351830
model = CmdStanModel(stan_file=stan)
18361831
jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
18371832
fit = model.sample(
18381833
data=jdata,
1839-
chains=2,
1840-
parallel_chains=2,
1841-
seed=12345,
18421834
iter_warmup=200,
18431835
iter_sampling=100,
1844-
show_progress=False,
1845-
show_console=False,
18461836
)
18471837
self.assertTrue(np.all(fit.divergences == 0))
18481838
self.assertTrue(np.all(fit.max_treedepths == 0))

0 commit comments

Comments
 (0)