Skip to content

Commit 6677875

Browse files
committed
more unit tests, cleanup
1 parent 7338633 commit 6677875

3 files changed

Lines changed: 20 additions & 6 deletions

File tree

cmdstanpy/utils.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -557,7 +557,7 @@ def parse_rdump_value(rhs: str) -> Union[int, float, np.ndarray]:
557557

558558
def check_sampler_csv(
559559
path: str,
560-
is_fixed_param: Optional[bool] = None,
560+
is_fixed_param: bool = False,
561561
iter_sampling: Optional[int] = None,
562562
iter_warmup: Optional[int] = None,
563563
save_warmup: bool = False,
@@ -610,9 +610,7 @@ def check_sampler_csv(
610610
return meta
611611

612612

613-
def scan_sampler_csv(
614-
path: str, is_fixed_param: Optional[bool]
615-
) -> Dict[str, Any]:
613+
def scan_sampler_csv(path: str, is_fixed_param: bool = False) -> Dict[str, Any]:
616614
"""Process sampler stan_csv output file line by line."""
617615
dict: Dict[str, Any] = {}
618616
lineno = 0

test/test_model.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,12 +57,24 @@ def test_model_good(self):
5757
self.assertEqual(BERN_STAN, model.stan_file)
5858
self.assertPathsEqual(model.exe_file, BERN_EXE)
5959

60+
def test_ctor_compile_arg(self):
6061
# instantiate, don't compile
61-
os.remove(BERN_EXE)
62+
if os.path.exists(BERN_EXE):
63+
os.remove(BERN_EXE)
6264
model = CmdStanModel(stan_file=BERN_STAN, compile=False)
6365
self.assertEqual(BERN_STAN, model.stan_file)
6466
self.assertEqual(None, model.exe_file)
6567

68+
model = CmdStanModel(stan_file=BERN_STAN, compile=True)
69+
self.assertPathsEqual(model.exe_file, BERN_EXE)
70+
exe_time = os.path.getmtime(model.exe_file)
71+
72+
model = CmdStanModel(stan_file=BERN_STAN)
73+
self.assertTrue(exe_time == os.path.getmtime(model.exe_file))
74+
75+
model = CmdStanModel(stan_file=BERN_STAN, compile='force')
76+
self.assertTrue(exe_time < os.path.getmtime(model.exe_file))
77+
6678
def test_exe_only(self):
6779
model = CmdStanModel(stan_file=BERN_STAN)
6880
self.assertEqual(BERN_EXE, model.exe_file)
@@ -109,6 +121,10 @@ def test_model_bad(self):
109121
CmdStanModel(
110122
stan_file=os.path.join(DATAFILES_PATH, "external.stan")
111123
)
124+
CmdStanModel(stan_file=BERN_STAN)
125+
os.remove(BERN_EXE)
126+
with self.assertRaises(ValueError):
127+
CmdStanModel(stan_file=BERN_STAN, exe_file=BERN_EXE)
112128

113129
def test_stanc_options(self):
114130
opts = {

test/test_sample.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -605,7 +605,7 @@ def test_fixed_param_unspecified(self):
605605
os.chmod(exe_only, 0o755)
606606
datagen2_model = CmdStanModel(exe_file=exe_only)
607607
datagen2_fit = datagen2_model.sample(
608-
iter_sampling=200, show_progress=False
608+
iter_sampling=200, show_console=True
609609
)
610610
self.assertEqual(datagen2_fit.chains, 4)
611611
self.assertEqual(datagen2_fit.step_size, None)

0 commit comments

Comments
 (0)