|
9 | 9 | import unittest |
10 | 10 | from test import CustomTestCase |
11 | 11 |
|
| 12 | +import pytest |
12 | 13 | from testfixtures import LogCapture, StringComparison |
13 | 14 |
|
14 | 15 | from cmdstanpy.model import CmdStanModel |
@@ -377,33 +378,32 @@ def test_model_includes_implicit(self): |
377 | 378 | model2 = CmdStanModel(stan_file=stan) |
378 | 379 | self.assertPathsEqual(model2.exe_file, exe) |
379 | 380 |
|
| 381 | + @pytest.mark.skipif( |
| 382 | + not cmdstan_version_before(2, 32), |
| 383 | + reason="Deprecated syntax removed in Stan 2.32", |
| 384 | + ) |
380 | 385 | def test_model_format(self): |
381 | | - # deprecations expire in this version |
382 | | - if cmdstan_version_before(2, 32): |
383 | | - stan = os.path.join(DATAFILES_PATH, 'format_me.stan') |
| 386 | + stan = os.path.join(DATAFILES_PATH, 'format_me.stan') |
384 | 387 |
|
385 | | - model = CmdStanModel(stan_file=stan, compile=False) |
| 388 | + model = CmdStanModel(stan_file=stan, compile=False) |
386 | 389 |
|
387 | | - sys_stdout = io.StringIO() |
388 | | - with contextlib.redirect_stdout(sys_stdout): |
389 | | - model.format_model() |
| 390 | + sys_stdout = io.StringIO() |
| 391 | + with contextlib.redirect_stdout(sys_stdout): |
| 392 | + model.format_model() |
390 | 393 |
|
391 | | - formatted = sys_stdout.getvalue() |
392 | | - self.assertIn("//", formatted) |
393 | | - self.assertNotIn("#", formatted) |
394 | | - self.assertEqual(formatted.count('('), 5) |
| 394 | + formatted = sys_stdout.getvalue() |
| 395 | + self.assertIn("//", formatted) |
| 396 | + self.assertNotIn("#", formatted) |
| 397 | + self.assertEqual(formatted.count('('), 5) |
395 | 398 |
|
396 | | - sys_stdout = io.StringIO() |
397 | | - with contextlib.redirect_stdout(sys_stdout): |
398 | | - model.format_model(canonicalize=True) |
| 399 | + sys_stdout = io.StringIO() |
| 400 | + with contextlib.redirect_stdout(sys_stdout): |
| 401 | + model.format_model(canonicalize=True) |
399 | 402 |
|
400 | | - formatted = sys_stdout.getvalue() |
401 | | - print(formatted) |
402 | | - self.assertNotIn("<-", formatted) |
403 | | - self.assertEqual(formatted.count('('), 0) |
404 | | - |
405 | | - else: |
406 | | - assert False, "Test needs to be updated for Stan 2.32" |
| 403 | + formatted = sys_stdout.getvalue() |
| 404 | + print(formatted) |
| 405 | + self.assertNotIn("<-", formatted) |
| 406 | + self.assertEqual(formatted.count('('), 0) |
407 | 407 |
|
408 | 408 |
|
409 | 409 | if __name__ == '__main__': |
|
0 commit comments