@@ -1801,48 +1801,38 @@ def test_diagnostics(self):
18011801 stan = os .path .join (DATAFILES_PATH , 'eight_schools.stan' )
18021802 model = CmdStanModel (stan_file = stan )
18031803 rdata = os .path .join (DATAFILES_PATH , 'eight_schools.data.R' )
1804- with LogCapture () as log :
1804+ with LogCapture (level = logging . WARNING ) as log :
18051805 logging .getLogger ()
18061806 fit = model .sample (
18071807 data = rdata ,
18081808 seed = 55157 ,
18091809 show_progress = False ,
18101810 show_console = False ,
18111811 )
1812-
1813- self .assertTrue (
1814- any ([a != b for a , b in zip (fit .max_treedepths , [0 , 0 , 0 , 0 ])])
1815- )
1812+ msg = log .actual ()[- 1 ][- 1 ]
18161813 self .assertTrue (
1817- any ([ a != b for a , b in zip ( fit . divergences , [ 0 , 0 , 0 , 0 ])] )
1814+ msg . startswith ( 'Some chains may have failed to converge.' )
18181815 )
1816+ self .assertFalse (np .all (fit .divergences == 0 ))
18191817
1820- log .check_present (
1821- (
1822- 'cmdstanpy' ,
1823- 'WARNING' ,
1824- 'Some chains may have failed to converge.\n '
1825- '\t Chain 1 had 10 divergent transitions (1.0%)\n '
1826- '\t Chain 2 had 143 divergent transitions (14.3%)\n '
1827- '\t Chain 3 had 5 divergent transitions (0.5%)\n '
1828- '\t Chain 4 had 4 divergent transitions (0.4%)\n '
1829- '\t Chain 4 had 6 iterations at max treedepth (0.6%)\n '
1830- '\t Use function "diagnose()" to see further information.' ,
1831- ),
1818+ with LogCapture (level = logging .WARNING ) as log :
1819+ logging .getLogger ()
1820+ fit = model .sample (
1821+ data = rdata ,
1822+ seed = 40508 ,
1823+ max_treedepth = 3 ,
18321824 )
1825+ msg = log .actual ()[- 1 ][- 1 ]
1826+ self .assertTrue ('max treedepth' in msg )
1827+ self .assertFalse (np .all (fit .max_treedepths == 0 ))
18331828
18341829 stan = os .path .join (DATAFILES_PATH , 'bernoulli.stan' )
18351830 model = CmdStanModel (stan_file = stan )
18361831 jdata = os .path .join (DATAFILES_PATH , 'bernoulli.data.json' )
18371832 fit = model .sample (
18381833 data = jdata ,
1839- chains = 2 ,
1840- parallel_chains = 2 ,
1841- seed = 12345 ,
18421834 iter_warmup = 200 ,
18431835 iter_sampling = 100 ,
1844- show_progress = False ,
1845- show_console = False ,
18461836 )
18471837 self .assertTrue (np .all (fit .divergences == 0 ))
18481838 self .assertTrue (np .all (fit .max_treedepths == 0 ))
0 commit comments