|
26 | 26 | from cmdstanpy.cmdstan_args import CmdStanArgs, Method, SamplerArgs |
27 | 27 | from cmdstanpy.model import CmdStanModel |
28 | 28 | from cmdstanpy.stanfit import CmdStanMCMC, RunSet, from_csv |
29 | | -from cmdstanpy.utils import EXTENSION, cmdstan_version_before |
| 29 | +from cmdstanpy.utils import EXTENSION, cmdstan_version_before, model_info |
30 | 30 |
|
31 | 31 | HERE = os.path.dirname(os.path.abspath(__file__)) |
32 | 32 | DATAFILES_PATH = os.path.join(HERE, 'data') |
@@ -440,13 +440,18 @@ def test_multi_proc_threads(self): |
440 | 440 | # 2.28 compile with cpp_options={'STAN_THREADS':'true'} |
441 | 441 | if not cmdstan_version_before(2, 28): |
442 | 442 | logistic_stan = os.path.join(DATAFILES_PATH, 'logistic.stan') |
443 | | - logistic_model = CmdStanModel( |
444 | | - stan_file=logistic_stan, |
445 | | - compile=True, |
446 | | - cpp_options={'STAN_THREADS': 'true'}, |
| 443 | + logistic_model = CmdStanModel(stan_file=logistic_stan) |
| 444 | + |
| 445 | + os.remove(logistic_model.exe_file) |
| 446 | + logistic_model.compile( |
| 447 | + force=True, |
| 448 | + cpp_options={'STAN_THREADS': 'TRUE'}, |
447 | 449 | ) |
448 | | - logistic_data = os.path.join(DATAFILES_PATH, 'logistic.data.R') |
| 450 | + info_dict = model_info(logistic_model.exe_file) |
| 451 | + self.assertTrue(info_dict is not None) |
| 452 | + self.assertTrue('STAN_THREADS' in info_dict) |
449 | 453 |
|
| 454 | + logistic_data = os.path.join(DATAFILES_PATH, 'logistic.data.R') |
450 | 455 | with LogCapture() as log: |
451 | 456 | logging.getLogger() |
452 | 457 | logistic_model.sample( |
@@ -598,7 +603,6 @@ def test_fixed_param_unspecified(self): |
598 | 603 | datagen_fit = datagen_model.sample( |
599 | 604 | iter_sampling=100, show_progress=False |
600 | 605 | ) |
601 | | - print(datagen_fit) |
602 | 606 | self.assertEqual(datagen_fit.step_size, None) |
603 | 607 |
|
604 | 608 | def test_bernoulli_file_with_space(self): |
@@ -786,7 +790,6 @@ def test_instantiate_from_csvfiles(self): |
786 | 790 | if file.endswith(".csv"): |
787 | 791 | csvfiles.append(os.path.join(csvfiles_path, file)) |
788 | 792 | bern_fit = from_csv(path=csvfiles) |
789 | | - print(bern_fit.metadata.method_vars_cols.keys()) |
790 | 793 |
|
791 | 794 | draws_pd = bern_fit.draws_pd() |
792 | 795 | self.assertEqual( |
|
0 commit comments