Skip to content

Commit 2d278bb

Browse files
committed
fixed errors in code and unit tests; pinned unit test order
1 parent d3e58e3 commit 2d278bb

6 files changed

Lines changed: 12 additions & 8 deletions

File tree

cmdstanpy/model.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1320,9 +1320,8 @@ def _run_cmdstan(
13201320
fd_out.write(stdout)
13211321
if show_console:
13221322
lines = stdout.split('\n')
1323-
print(f' ****** lines to process: {len(lines)}')
13241323
for line in lines:
1325-
print('chain {}: {}'.format(idx + 1, stdout))
1324+
print('chain {}: {}'.format(idx + 1, line))
13261325

13271326
fd_out.close()
13281327
except OSError as e:

requirements-test.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ flake8
22
pylint
33
pytest
44
pytest-cov
5+
pytest-order
56
mypy
67
testfixtures
78
tqdm

test/test_generate_quantities.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -470,7 +470,7 @@ def test_show_console(self):
470470
mcmc_sample=bern_fit,
471471
show_console=True,
472472
)
473-
console = sys_stdout.getvalue()
473+
console = sys_stdout.getvalue()
474474
self.assertTrue('chain 1: method = generate' in console)
475475
self.assertTrue('chain 2: method = generate' in console)
476476
self.assertTrue('chain 3: method = generate' in console)

test/test_optimize.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -603,7 +603,7 @@ def test_show_console(self):
603603
data=jdata,
604604
show_console=True,
605605
)
606-
console = sys_stdout.getvalue()
606+
console = sys_stdout.getvalue()
607607
self.assertTrue('chain 1: method = optimize' in console)
608608

609609

test/test_sample.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -578,7 +578,7 @@ def test_show_console(self, stanfile='bernoulli.stan'):
578578
iter_sampling=100,
579579
show_console=True,
580580
)
581-
console = sys_stdout.getvalue()
581+
console = sys_stdout.getvalue()
582582
self.assertTrue('chain 1: method = sample' in console)
583583
self.assertTrue('chain 2: method = sample' in console)
584584

@@ -592,7 +592,7 @@ def test_show_progress(self, stanfile='bernoulli.stan'):
592592
bern_model.sample(
593593
data=jdata, chains=2, parallel_chains=2, show_progress=True
594594
)
595-
console = sys_stderr.getvalue()
595+
console = sys_stderr.getvalue()
596596
self.assertTrue('chain 1' in console)
597597
self.assertTrue('chain 2' in console)
598598
self.assertTrue('Sampling completed' in console)

test/test_variational.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,8 @@ def test_variational_eta_fail(self):
246246
def test_single_row_csv(self):
247247
stan = os.path.join(DATAFILES_PATH, 'matrix_var.stan')
248248
model = CmdStanModel(stan_file=stan)
249-
vb_fit = model.variational()
249+
# testing data parsing, allow non-convergence
250+
vb_fit = model.variational(require_converged=False, seed=12345)
250251
self.assertTrue(isinstance(vb_fit.stan_variable('theta'), float))
251252
z_as_ndarray = vb_fit.stan_variable(var="z")
252253
self.assertEqual(z_as_ndarray.shape, (4, 3))
@@ -261,11 +262,14 @@ def test_show_console(self):
261262

262263
sys_stdout = io.StringIO()
263264
with contextlib.redirect_stdout(sys_stdout):
265+
# testing data parsing, allow non-convergence
264266
bern_model.variational(
265267
data=jdata,
266268
show_console=True,
269+
require_converged=False,
270+
seed=12345,
267271
)
268-
console = sys_stdout.getvalue()
272+
console = sys_stdout.getvalue()
269273
self.assertTrue('chain 1: method = variational' in console)
270274

271275

0 commit comments

Comments
 (0)