Skip to content

Commit ee3692e

Browse files
authored
Merge pull request #540 from stan-dev/test-cleanup
Tidy up tests
2 parents ef23b3b + 490e841 commit ee3692e

8 files changed

Lines changed: 125 additions & 129 deletions

File tree

cmdstanpy/cmdstan_args.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -705,7 +705,7 @@ def __init__(
705705
self,
706706
model_name: str,
707707
model_exe: Optional[str],
708-
chain_ids: Union[List[int], None],
708+
chain_ids: Optional[List[int]],
709709
method_args: Union[
710710
SamplerArgs, OptimizeArgs, GenerateQuantitiesArgs, VariationalArgs
711711
],

cmdstanpy/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -838,7 +838,7 @@ def parse_method_vars(names: Tuple[str, ...]) -> Dict[str, Tuple[int, ...]]:
838838
if names is None:
839839
raise ValueError('missing argument "names"')
840840
# note: method vars are currently all scalar so not checking for structure
841-
return {v: tuple([k]) for (k, v) in enumerate(names) if v.endswith('__')}
841+
return {v: (k,) for (k, v) in enumerate(names) if v.endswith('__')}
842842

843843

844844
def parse_stan_vars(

test/test_cxx_installation.py

Lines changed: 41 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -3,34 +3,38 @@
33
import platform
44
import unittest
55

6+
import pytest
7+
68
from cmdstanpy import install_cxx_toolchain
79

810

911
class InstallCxxScriptTest(unittest.TestCase):
12+
@pytest.mark.skipif(
13+
platform.system() != 'Windows', reason='Windows only tests'
14+
)
1015
def test_config(self):
1116
"""Test config output."""
12-
if platform.system() != 'Windows':
13-
return
14-
else:
15-
config = install_cxx_toolchain.get_config('C:\\RTools', True)
16-
17-
config_reference = [
18-
'/SP-',
19-
'/VERYSILENT',
20-
'/SUPPRESSMSGBOXES',
21-
'/CURRENTUSER',
22-
'LANG="English"',
23-
'/DIR="RTools"',
24-
'/NOICONS',
25-
'/NORESTART',
26-
]
27-
28-
self.assertEqual(config, config_reference)
2917

18+
config = install_cxx_toolchain.get_config('C:\\RTools', True)
19+
20+
config_reference = [
21+
'/SP-',
22+
'/VERYSILENT',
23+
'/SUPPRESSMSGBOXES',
24+
'/CURRENTUSER',
25+
'LANG="English"',
26+
'/DIR="RTools"',
27+
'/NOICONS',
28+
'/NORESTART',
29+
]
30+
31+
self.assertEqual(config, config_reference)
32+
33+
@pytest.mark.skipif(
34+
platform.system() == 'Windows', reason='Windows only tests'
35+
)
3036
def test_install_not_windows(self):
3137
"""Try to install on unsupported platform."""
32-
if platform.system() == 'Windows':
33-
return
3438

3539
with self.assertRaisesRegex(
3640
NotImplementedError,
@@ -39,30 +43,30 @@ def test_install_not_windows(self):
3943
):
4044
install_cxx_toolchain.main({})
4145

46+
@pytest.mark.skipif(
47+
platform.system() != 'Windows', reason='Windows only tests'
48+
)
4249
def test_normalize_version(self):
4350
"""Test supported versions."""
44-
if platform.system() != 'Windows':
45-
return
46-
else:
47-
for ver in ['4.0', '4', '40']:
48-
self.assertEqual(
49-
install_cxx_toolchain.normalize_version(ver), '4.0'
50-
)
51-
52-
for ver in ['3.5', '35']:
53-
self.assertEqual(
54-
install_cxx_toolchain.normalize_version(ver), '3.5'
55-
)
5651

57-
def test_toolchain_name(self):
58-
"""Check toolchain name."""
59-
if platform.system() != 'Windows':
60-
return
61-
else:
52+
for ver in ['4.0', '4', '40']:
6253
self.assertEqual(
63-
install_cxx_toolchain.get_toolchain_name(), 'RTools'
54+
install_cxx_toolchain.normalize_version(ver), '4.0'
6455
)
6556

57+
for ver in ['3.5', '35']:
58+
self.assertEqual(
59+
install_cxx_toolchain.normalize_version(ver), '3.5'
60+
)
61+
62+
@pytest.mark.skipif(
63+
platform.system() != 'Windows', reason='Windows only tests'
64+
)
65+
def test_toolchain_name(self):
66+
"""Check toolchain name."""
67+
68+
self.assertEqual(install_cxx_toolchain.get_toolchain_name(), 'RTools')
69+
6670

6771
if __name__ == '__main__':
6872
unittest.main()

test/test_generate_quantities.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -418,10 +418,10 @@ def test_show_console(self):
418418
show_console=True,
419419
)
420420
console = sys_stdout.getvalue()
421-
self.assertTrue('Chain [1] method = generate' in console)
422-
self.assertTrue('Chain [2] method = generate' in console)
423-
self.assertTrue('Chain [3] method = generate' in console)
424-
self.assertTrue('Chain [4] method = generate' in console)
421+
self.assertIn('Chain [1] method = generate', console)
422+
self.assertIn('Chain [2] method = generate', console)
423+
self.assertIn('Chain [3] method = generate', console)
424+
self.assertIn('Chain [4] method = generate', console)
425425

426426
def test_complex_output(self):
427427
stan_bern = os.path.join(DATAFILES_PATH, 'bernoulli.stan')

test/test_optimize.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ def test_variable_bern(self):
192192
history_size=5,
193193
)
194194
self.assertEqual(1, len(bern_mle.metadata.stan_vars_dims))
195-
self.assertTrue('theta' in bern_mle.metadata.stan_vars_dims)
195+
self.assertIn('theta', bern_mle.metadata.stan_vars_dims)
196196
self.assertEqual(bern_mle.metadata.stan_vars_dims['theta'], ())
197197
theta = bern_mle.stan_variable(var='theta')
198198
self.assertTrue(isinstance(theta, float))
@@ -219,7 +219,7 @@ def test_variables_3d(self):
219219
history_size=5,
220220
)
221221
self.assertEqual(3, len(multidim_mle.metadata.stan_vars_dims))
222-
self.assertTrue('y_rep' in multidim_mle.metadata.stan_vars_dims)
222+
self.assertIn('y_rep', multidim_mle.metadata.stan_vars_dims)
223223
self.assertEqual(
224224
multidim_mle.metadata.stan_vars_dims['y_rep'], (5, 4, 3)
225225
)
@@ -231,11 +231,11 @@ def test_variables_3d(self):
231231
self.assertTrue(isinstance(var_frac_60, float))
232232
vars = multidim_mle.stan_variables()
233233
self.assertEqual(len(vars), len(multidim_mle.metadata.stan_vars_dims))
234-
self.assertTrue('y_rep' in vars)
234+
self.assertIn('y_rep', vars)
235235
self.assertEqual(vars['y_rep'].shape, (5, 4, 3))
236-
self.assertTrue('beta' in vars)
236+
self.assertIn('beta', vars)
237237
self.assertEqual(vars['beta'].shape, (2,))
238-
self.assertTrue('frac_60' in vars)
238+
self.assertIn('frac_60', vars)
239239
self.assertTrue(isinstance(vars['frac_60'], float))
240240

241241
multidim_mle_iters = multidim_model.optimize(
@@ -256,11 +256,11 @@ def test_variables_3d(self):
256256
self.assertEqual(
257257
len(vars_iters), len(multidim_mle_iters.metadata.stan_vars_dims)
258258
)
259-
self.assertTrue('y_rep' in vars_iters)
259+
self.assertIn('y_rep', vars_iters)
260260
self.assertEqual(vars_iters['y_rep'].shape, (8, 5, 4, 3))
261-
self.assertTrue('beta' in vars_iters)
261+
self.assertIn('beta', vars_iters)
262262
self.assertEqual(vars_iters['beta'].shape, (8, 2))
263-
self.assertTrue('frac_60' in vars_iters)
263+
self.assertIn('frac_60', vars_iters)
264264
self.assertEqual(vars_iters['frac_60'].shape, (8,))
265265

266266

@@ -580,7 +580,7 @@ def test_show_console(self):
580580
show_console=True,
581581
)
582582
console = sys_stdout.getvalue()
583-
self.assertTrue('Chain [1] method = optimize' in console)
583+
self.assertIn('Chain [1] method = optimize', console)
584584

585585
def test_exe_only(self):
586586
stan = os.path.join(DATAFILES_PATH, 'bernoulli.stan')

test/test_sample.py

Lines changed: 39 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -441,8 +441,8 @@ def test_multi_proc_threads(self):
441441
cpp_options={'STAN_THREADS': 'TRUE'},
442442
)
443443
info_dict = logistic_model.exe_info()
444-
self.assertTrue(info_dict is not None)
445-
self.assertTrue('STAN_THREADS' in info_dict)
444+
self.assertIsNotNone(info_dict)
445+
self.assertIn('STAN_THREADS', info_dict)
446446
self.assertEqual(info_dict['STAN_THREADS'], 'true')
447447

448448
logistic_data = os.path.join(DATAFILES_PATH, 'logistic.data.R')
@@ -642,8 +642,8 @@ def test_show_console(self, stanfile='bernoulli.stan'):
642642
show_console=True,
643643
)
644644
console = sys_stdout.getvalue()
645-
self.assertTrue('Chain [1] method = sample' in console)
646-
self.assertTrue('Chain [2] method = sample' in console)
645+
self.assertIn('Chain [1] method = sample', console)
646+
self.assertIn('Chain [2] method = sample', console)
647647

648648
def test_show_progress(self, stanfile='bernoulli.stan'):
649649
stan = os.path.join(DATAFILES_PATH, stanfile)
@@ -660,9 +660,9 @@ def test_show_progress(self, stanfile='bernoulli.stan'):
660660
show_progress=True,
661661
)
662662
console = sys_stderr.getvalue()
663-
self.assertTrue('chain 1' in console)
664-
self.assertTrue('chain 2' in console)
665-
self.assertTrue('Sampling completed' in console)
663+
self.assertIn('chain 1', console)
664+
self.assertIn('chain 2', console)
665+
self.assertIn('Sampling completed', console)
666666

667667
sys_stderr = io.StringIO() # tqdm prints to stderr
668668
with contextlib.redirect_stderr(sys_stderr):
@@ -674,9 +674,9 @@ def test_show_progress(self, stanfile='bernoulli.stan'):
674674
show_progress=True,
675675
)
676676
console = sys_stderr.getvalue()
677-
self.assertTrue('chain 6' in console)
678-
self.assertTrue('chain 7' in console)
679-
self.assertTrue('Sampling completed' in console)
677+
self.assertIn('chain 6', console)
678+
self.assertIn('chain 7', console)
679+
self.assertIn('Sampling completed', console)
680680
sys_stderr = io.StringIO() # tqdm prints to stderr
681681

682682
with contextlib.redirect_stderr(sys_stderr):
@@ -690,9 +690,9 @@ def test_show_progress(self, stanfile='bernoulli.stan'):
690690
show_progress=True,
691691
)
692692
console = sys_stderr.getvalue()
693-
self.assertTrue('chain 6' in console)
694-
self.assertTrue('chain 7' in console)
695-
self.assertTrue('Sampling completed' in console)
693+
self.assertIn('chain 6', console)
694+
self.assertIn('chain 7', console)
695+
self.assertIn('Sampling completed', console)
696696

697697

698698
class CmdStanMCMCTest(CustomTestCase):
@@ -1388,7 +1388,7 @@ def test_variable_bern(self):
13881388
data=jdata, chains=2, seed=12345, iter_warmup=100, iter_sampling=100
13891389
)
13901390
self.assertEqual(1, len(bern_fit.metadata.stan_vars_dims))
1391-
self.assertTrue('theta' in bern_fit.metadata.stan_vars_dims)
1391+
self.assertIn('theta', bern_fit.metadata.stan_vars_dims)
13921392
self.assertEqual(bern_fit.metadata.stan_vars_dims['theta'], ())
13931393
self.assertEqual(bern_fit.stan_variable(var='theta').shape, (200,))
13941394
with self.assertRaises(ValueError):
@@ -1401,13 +1401,13 @@ def test_variables_2d(self):
14011401
fit = from_csv(path=csvfiles_path)
14021402
self.assertEqual(20, fit.num_draws_sampling)
14031403
self.assertEqual(8, len(fit.metadata.stan_vars_dims))
1404-
self.assertTrue('z' in fit.metadata.stan_vars_dims)
1404+
self.assertIn('z', fit.metadata.stan_vars_dims)
14051405
self.assertEqual(fit.metadata.stan_vars_dims['z'], (20, 2))
14061406
vars = fit.stan_variables()
14071407
self.assertEqual(len(vars), len(fit.metadata.stan_vars_dims))
1408-
self.assertTrue('z' in vars)
1408+
self.assertIn('z', vars)
14091409
self.assertEqual(vars['z'].shape, (20, 20, 2))
1410-
self.assertTrue('theta' in vars)
1410+
self.assertIn('theta', vars)
14111411
self.assertEqual(vars['theta'].shape, (20, 4))
14121412

14131413
def test_variables_3d(self):
@@ -1416,7 +1416,7 @@ def test_variables_3d(self):
14161416
fit = from_csv(path=csvfiles_path)
14171417
self.assertEqual(20, fit.num_draws_sampling)
14181418
self.assertEqual(3, len(fit.metadata.stan_vars_dims))
1419-
self.assertTrue('y_rep' in fit.metadata.stan_vars_dims)
1419+
self.assertIn('y_rep', fit.metadata.stan_vars_dims)
14201420
self.assertEqual(fit.metadata.stan_vars_dims['y_rep'], (5, 4, 3))
14211421
var_y_rep = fit.stan_variable(var='y_rep')
14221422
self.assertEqual(var_y_rep.shape, (20, 5, 4, 3))
@@ -1426,11 +1426,11 @@ def test_variables_3d(self):
14261426
self.assertEqual(var_frac_60.shape, (20,))
14271427
vars = fit.stan_variables()
14281428
self.assertEqual(len(vars), len(fit.metadata.stan_vars_dims))
1429-
self.assertTrue('y_rep' in vars)
1429+
self.assertIn('y_rep', vars)
14301430
self.assertEqual(vars['y_rep'].shape, (20, 5, 4, 3))
1431-
self.assertTrue('beta' in vars)
1431+
self.assertIn('beta', vars)
14321432
self.assertEqual(vars['beta'].shape, (20, 2))
1433-
self.assertTrue('frac_60' in vars)
1433+
self.assertIn('frac_60', vars)
14341434
self.assertEqual(vars['frac_60'].shape, (20,))
14351435

14361436
def test_variables_issue_361(self):
@@ -1592,18 +1592,16 @@ def test_metadata(self):
15921592
fit = CmdStanMCMC(runset)
15931593
meta = fit.metadata
15941594
self.assertEqual(meta.cmdstan_config['model'], 'logistic_model')
1595-
col_names = tuple(
1596-
[
1597-
'lp__',
1598-
'accept_stat__',
1599-
'stepsize__',
1600-
'treedepth__',
1601-
'n_leapfrog__',
1602-
'divergent__',
1603-
'energy__',
1604-
'beta[1]',
1605-
'beta[2]',
1606-
]
1595+
col_names = (
1596+
'lp__',
1597+
'accept_stat__',
1598+
'stepsize__',
1599+
'treedepth__',
1600+
'n_leapfrog__',
1601+
'divergent__',
1602+
'energy__',
1603+
'beta[1]',
1604+
'beta[2]',
16071605
)
16081606

16091607
self.assertEqual(fit.chains, 4)
@@ -1619,14 +1617,14 @@ def test_metadata(self):
16191617
self.assertEqual(fit.metadata.cmdstan_config['metric'], 'diag_e')
16201618
self.assertAlmostEqual(fit.metadata.cmdstan_config['delta'], 0.80)
16211619

1622-
self.assertTrue('n_leapfrog__' in fit.metadata.method_vars_cols)
1623-
self.assertTrue('energy__' in fit.metadata.method_vars_cols)
1624-
self.assertTrue('beta' not in fit.metadata.method_vars_cols)
1625-
self.assertTrue('energy__' not in fit.metadata.stan_vars_dims)
1626-
self.assertTrue('beta' in fit.metadata.stan_vars_dims)
1627-
self.assertTrue('beta' in fit.metadata.stan_vars_cols)
1628-
self.assertEqual(fit.metadata.stan_vars_dims['beta'], tuple([2]))
1629-
self.assertEqual(fit.metadata.stan_vars_cols['beta'], tuple([7, 8]))
1620+
self.assertIn('n_leapfrog__', fit.metadata.method_vars_cols)
1621+
self.assertIn('energy__', fit.metadata.method_vars_cols)
1622+
self.assertNotIn('beta', fit.metadata.method_vars_cols)
1623+
self.assertNotIn('energy__', fit.metadata.stan_vars_dims)
1624+
self.assertIn('beta', fit.metadata.stan_vars_dims)
1625+
self.assertIn('beta', fit.metadata.stan_vars_cols)
1626+
self.assertEqual(fit.metadata.stan_vars_dims['beta'], (2,))
1627+
self.assertEqual(fit.metadata.stan_vars_cols['beta'], (7, 8))
16301628

16311629
def test_save_latent_dynamics(self):
16321630
stan = os.path.join(DATAFILES_PATH, 'bernoulli.stan')

0 commit comments

Comments
 (0)