@@ -441,8 +441,8 @@ def test_multi_proc_threads(self):
441441 cpp_options = {'STAN_THREADS' : 'TRUE' },
442442 )
443443 info_dict = logistic_model .exe_info ()
444- self .assertTrue (info_dict is not None )
445- self .assertTrue ('STAN_THREADS' in info_dict )
444+ self .assertIsNotNone (info_dict )
445+ self .assertIn ('STAN_THREADS' , info_dict )
446446 self .assertEqual (info_dict ['STAN_THREADS' ], 'true' )
447447
448448 logistic_data = os .path .join (DATAFILES_PATH , 'logistic.data.R' )
@@ -642,8 +642,8 @@ def test_show_console(self, stanfile='bernoulli.stan'):
642642 show_console = True ,
643643 )
644644 console = sys_stdout .getvalue ()
645- self .assertTrue ('Chain [1] method = sample' in console )
646- self .assertTrue ('Chain [2] method = sample' in console )
645+ self .assertIn ('Chain [1] method = sample' , console )
646+ self .assertIn ('Chain [2] method = sample' , console )
647647
648648 def test_show_progress (self , stanfile = 'bernoulli.stan' ):
649649 stan = os .path .join (DATAFILES_PATH , stanfile )
@@ -660,9 +660,9 @@ def test_show_progress(self, stanfile='bernoulli.stan'):
660660 show_progress = True ,
661661 )
662662 console = sys_stderr .getvalue ()
663- self .assertTrue ('chain 1' in console )
664- self .assertTrue ('chain 2' in console )
665- self .assertTrue ('Sampling completed' in console )
663+ self .assertIn ('chain 1' , console )
664+ self .assertIn ('chain 2' , console )
665+ self .assertIn ('Sampling completed' , console )
666666
667667 sys_stderr = io .StringIO () # tqdm prints to stderr
668668 with contextlib .redirect_stderr (sys_stderr ):
@@ -674,9 +674,9 @@ def test_show_progress(self, stanfile='bernoulli.stan'):
674674 show_progress = True ,
675675 )
676676 console = sys_stderr .getvalue ()
677- self .assertTrue ('chain 6' in console )
678- self .assertTrue ('chain 7' in console )
679- self .assertTrue ('Sampling completed' in console )
677+ self .assertIn ('chain 6' , console )
678+ self .assertIn ('chain 7' , console )
679+ self .assertIn ('Sampling completed' , console )
680680 sys_stderr = io .StringIO () # tqdm prints to stderr
681681
682682 with contextlib .redirect_stderr (sys_stderr ):
@@ -690,9 +690,9 @@ def test_show_progress(self, stanfile='bernoulli.stan'):
690690 show_progress = True ,
691691 )
692692 console = sys_stderr .getvalue ()
693- self .assertTrue ('chain 6' in console )
694- self .assertTrue ('chain 7' in console )
695- self .assertTrue ('Sampling completed' in console )
693+ self .assertIn ('chain 6' , console )
694+ self .assertIn ('chain 7' , console )
695+ self .assertIn ('Sampling completed' , console )
696696
697697
698698class CmdStanMCMCTest (CustomTestCase ):
@@ -1388,7 +1388,7 @@ def test_variable_bern(self):
13881388 data = jdata , chains = 2 , seed = 12345 , iter_warmup = 100 , iter_sampling = 100
13891389 )
13901390 self .assertEqual (1 , len (bern_fit .metadata .stan_vars_dims ))
1391- self .assertTrue ('theta' in bern_fit .metadata .stan_vars_dims )
1391+ self .assertIn ('theta' , bern_fit .metadata .stan_vars_dims )
13921392 self .assertEqual (bern_fit .metadata .stan_vars_dims ['theta' ], ())
13931393 self .assertEqual (bern_fit .stan_variable (var = 'theta' ).shape , (200 ,))
13941394 with self .assertRaises (ValueError ):
@@ -1401,13 +1401,13 @@ def test_variables_2d(self):
14011401 fit = from_csv (path = csvfiles_path )
14021402 self .assertEqual (20 , fit .num_draws_sampling )
14031403 self .assertEqual (8 , len (fit .metadata .stan_vars_dims ))
1404- self .assertTrue ('z' in fit .metadata .stan_vars_dims )
1404+ self .assertIn ('z' , fit .metadata .stan_vars_dims )
14051405 self .assertEqual (fit .metadata .stan_vars_dims ['z' ], (20 , 2 ))
14061406 vars = fit .stan_variables ()
14071407 self .assertEqual (len (vars ), len (fit .metadata .stan_vars_dims ))
1408- self .assertTrue ('z' in vars )
1408+ self .assertIn ('z' , vars )
14091409 self .assertEqual (vars ['z' ].shape , (20 , 20 , 2 ))
1410- self .assertTrue ('theta' in vars )
1410+ self .assertIn ('theta' , vars )
14111411 self .assertEqual (vars ['theta' ].shape , (20 , 4 ))
14121412
14131413 def test_variables_3d (self ):
@@ -1416,7 +1416,7 @@ def test_variables_3d(self):
14161416 fit = from_csv (path = csvfiles_path )
14171417 self .assertEqual (20 , fit .num_draws_sampling )
14181418 self .assertEqual (3 , len (fit .metadata .stan_vars_dims ))
1419- self .assertTrue ('y_rep' in fit .metadata .stan_vars_dims )
1419+ self .assertIn ('y_rep' , fit .metadata .stan_vars_dims )
14201420 self .assertEqual (fit .metadata .stan_vars_dims ['y_rep' ], (5 , 4 , 3 ))
14211421 var_y_rep = fit .stan_variable (var = 'y_rep' )
14221422 self .assertEqual (var_y_rep .shape , (20 , 5 , 4 , 3 ))
@@ -1426,11 +1426,11 @@ def test_variables_3d(self):
14261426 self .assertEqual (var_frac_60 .shape , (20 ,))
14271427 vars = fit .stan_variables ()
14281428 self .assertEqual (len (vars ), len (fit .metadata .stan_vars_dims ))
1429- self .assertTrue ('y_rep' in vars )
1429+ self .assertIn ('y_rep' , vars )
14301430 self .assertEqual (vars ['y_rep' ].shape , (20 , 5 , 4 , 3 ))
1431- self .assertTrue ('beta' in vars )
1431+ self .assertIn ('beta' , vars )
14321432 self .assertEqual (vars ['beta' ].shape , (20 , 2 ))
1433- self .assertTrue ('frac_60' in vars )
1433+ self .assertIn ('frac_60' , vars )
14341434 self .assertEqual (vars ['frac_60' ].shape , (20 ,))
14351435
14361436 def test_variables_issue_361 (self ):
@@ -1592,18 +1592,16 @@ def test_metadata(self):
15921592 fit = CmdStanMCMC (runset )
15931593 meta = fit .metadata
15941594 self .assertEqual (meta .cmdstan_config ['model' ], 'logistic_model' )
1595- col_names = tuple (
1596- [
1597- 'lp__' ,
1598- 'accept_stat__' ,
1599- 'stepsize__' ,
1600- 'treedepth__' ,
1601- 'n_leapfrog__' ,
1602- 'divergent__' ,
1603- 'energy__' ,
1604- 'beta[1]' ,
1605- 'beta[2]' ,
1606- ]
1595+ col_names = (
1596+ 'lp__' ,
1597+ 'accept_stat__' ,
1598+ 'stepsize__' ,
1599+ 'treedepth__' ,
1600+ 'n_leapfrog__' ,
1601+ 'divergent__' ,
1602+ 'energy__' ,
1603+ 'beta[1]' ,
1604+ 'beta[2]' ,
16071605 )
16081606
16091607 self .assertEqual (fit .chains , 4 )
@@ -1619,14 +1617,14 @@ def test_metadata(self):
16191617 self .assertEqual (fit .metadata .cmdstan_config ['metric' ], 'diag_e' )
16201618 self .assertAlmostEqual (fit .metadata .cmdstan_config ['delta' ], 0.80 )
16211619
1622- self .assertTrue ('n_leapfrog__' in fit .metadata .method_vars_cols )
1623- self .assertTrue ('energy__' in fit .metadata .method_vars_cols )
1624- self .assertTrue ('beta' not in fit .metadata .method_vars_cols )
1625- self .assertTrue ('energy__' not in fit .metadata .stan_vars_dims )
1626- self .assertTrue ('beta' in fit .metadata .stan_vars_dims )
1627- self .assertTrue ('beta' in fit .metadata .stan_vars_cols )
1628- self .assertEqual (fit .metadata .stan_vars_dims ['beta' ], tuple ([ 2 ] ))
1629- self .assertEqual (fit .metadata .stan_vars_cols ['beta' ], tuple ([ 7 , 8 ] ))
1620+ self .assertIn ('n_leapfrog__' , fit .metadata .method_vars_cols )
1621+ self .assertIn ('energy__' , fit .metadata .method_vars_cols )
1622+ self .assertNotIn ('beta' , fit .metadata .method_vars_cols )
1623+ self .assertNotIn ('energy__' , fit .metadata .stan_vars_dims )
1624+ self .assertIn ('beta' , fit .metadata .stan_vars_dims )
1625+ self .assertIn ('beta' , fit .metadata .stan_vars_cols )
1626+ self .assertEqual (fit .metadata .stan_vars_dims ['beta' ], ( 2 , ))
1627+ self .assertEqual (fit .metadata .stan_vars_cols ['beta' ], ( 7 , 8 ))
16301628
16311629 def test_save_latent_dynamics (self ):
16321630 stan = os .path .join (DATAFILES_PATH , 'bernoulli.stan' )
0 commit comments