Skip to content

Commit a0b1521

Browse files
committed
fixing unit tests
1 parent 65419b2 commit a0b1521

3 files changed

Lines changed: 26 additions & 19 deletions

File tree

cmdstanpy/stanfit.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -626,19 +626,21 @@ def _validate_csv_files(self) -> Dict[str, Any]:
626626
thin=self._thin,
627627
)
628628
for key in dzero:
629-
# TODO: only check args that matter for CSV parsing
629+
# check args that matter for parsing, plus name, version
630630
if (
631631
key
632-
not in [
633-
'id',
634-
'algorithm',
635-
'diagnostic_file',
636-
'metric_file',
637-
'profile_file',
638-
'stepsize',
639-
'init',
640-
'seed',
641-
'start_datetime',
632+
in [
633+
'stan_version_major',
634+
'stan_version_minor',
635+
'stan_version_patch',
636+
'stanc_version',
637+
'model',
638+
'num_samples',
639+
'num_warmup',
640+
'save_warmup',
641+
'thin',
642+
'refresh',
643+
'metric',
642644
]
643645
and dzero[key] != drest[key]
644646
):

test/conftest.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99

1010

1111
# after we have run all tests, use git to delete the built files in data/
12+
13+
1214
@pytest.fixture(scope='session', autouse=True)
1315
def cleanup_test_files():
1416
yield

test/test_sample.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from cmdstanpy.cmdstan_args import CmdStanArgs, Method, SamplerArgs
2727
from cmdstanpy.model import CmdStanModel
2828
from cmdstanpy.stanfit import CmdStanMCMC, RunSet, from_csv
29-
from cmdstanpy.utils import EXTENSION, cmdstan_version_before
29+
from cmdstanpy.utils import EXTENSION, cmdstan_version_before, model_info
3030

3131
HERE = os.path.dirname(os.path.abspath(__file__))
3232
DATAFILES_PATH = os.path.join(HERE, 'data')
@@ -440,13 +440,18 @@ def test_multi_proc_threads(self):
440440
# 2.28 compile with cpp_options={'STAN_THREADS':'true'}
441441
if not cmdstan_version_before(2, 28):
442442
logistic_stan = os.path.join(DATAFILES_PATH, 'logistic.stan')
443-
logistic_model = CmdStanModel(
444-
stan_file=logistic_stan,
445-
compile=True,
446-
cpp_options={'STAN_THREADS': 'true'},
443+
logistic_model = CmdStanModel(stan_file=logistic_stan)
444+
445+
os.remove(logistic_model.exe_file)
446+
logistic_model.compile(
447+
force=True,
448+
cpp_options={'STAN_THREADS': 'TRUE'},
447449
)
448-
logistic_data = os.path.join(DATAFILES_PATH, 'logistic.data.R')
450+
info_dict = model_info(logistic_model.exe_file)
451+
self.assertTrue(info_dict is not None)
452+
self.assertTrue('STAN_THREADS' in info_dict)
449453

454+
logistic_data = os.path.join(DATAFILES_PATH, 'logistic.data.R')
450455
with LogCapture() as log:
451456
logging.getLogger()
452457
logistic_model.sample(
@@ -598,7 +603,6 @@ def test_fixed_param_unspecified(self):
598603
datagen_fit = datagen_model.sample(
599604
iter_sampling=100, show_progress=False
600605
)
601-
print(datagen_fit)
602606
self.assertEqual(datagen_fit.step_size, None)
603607

604608
def test_bernoulli_file_with_space(self):
@@ -786,7 +790,6 @@ def test_instantiate_from_csvfiles(self):
786790
if file.endswith(".csv"):
787791
csvfiles.append(os.path.join(csvfiles_path, file))
788792
bern_fit = from_csv(path=csvfiles)
789-
print(bern_fit.metadata.method_vars_cols.keys())
790793

791794
draws_pd = bern_fit.draws_pd()
792795
self.assertEqual(

0 commit comments

Comments
 (0)