Skip to content

Commit e1bd06a

Browse files
committed
More safeguards and tests
1 parent fc2d638 commit e1bd06a

2 files changed

Lines changed: 46 additions & 8 deletions

File tree

cmdstanpy/model.py

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
MaybeDictToFilePath,
4040
SanitizedOrTmpFilePath,
4141
cmdstan_path,
42+
cmdstan_version,
4243
cmdstan_version_before,
4344
do_command,
4445
get_logger,
@@ -333,14 +334,24 @@ def format(
333334
)
334335

335336
if canonicalize:
336-
if cmdstan_version_before(2, 29) or isinstance(
337-
canonicalize, bool
338-
):
339-
cmd.append('--print-canonical')
340-
elif isinstance(canonicalize, str):
341-
cmd.append('--canonicalize=' + canonicalize)
342-
elif isinstance(canonicalize, Iterable):
343-
cmd.append('--canonicalize=' + ','.join(canonicalize))
337+
if cmdstan_version_before(2, 29):
338+
if isinstance(canonicalize, bool):
339+
cmd.append('--print-canonical')
340+
else:
341+
raise ValueError(
342+
"Invalid arguments passed for current CmdStan"
343+
+ " version({})\n".format(
344+
cmdstan_version() or "Unknown"
345+
)
346+
+ "--canonicalize requires 2.29 or higher"
347+
)
348+
else:
349+
if isinstance(canonicalize, str):
350+
cmd.append('--canonicalize=' + canonicalize)
351+
elif isinstance(canonicalize, Iterable):
352+
cmd.append('--canonicalize=' + ','.join(canonicalize))
353+
else:
354+
cmd.append('--print-canonical')
344355

345356
# before 2.29, having both --print-canonical
346357
# and --auto-format printed twice
@@ -349,6 +360,12 @@ def format(
349360

350361
if not cmdstan_version_before(2, 29):
351362
cmd.append(f'--max-line-length={max_line_length}')
363+
elif max_line_length != 78:
364+
raise ValueError(
365+
"Invalid arguments passed for current CmdStan version"
366+
+ " ({})\n".format(cmdstan_version() or "Unknown")
367+
+ "--max-line-length requires 2.29 or higher"
368+
)
352369

353370
out = subprocess.run(
354371
cmd, capture_output=True, text=True, check=True

test/test_model.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import tempfile
99
import unittest
1010
from test import CustomTestCase
11+
from unittest.mock import MagicMock, patch
1112

1213
import pytest
1314
from testfixtures import LogCapture, StringComparison
@@ -433,6 +434,26 @@ def test_model_format_options(self):
433434
self.assertEqual(formatted.count('{'), 1)
434435
self.assertEqual(formatted.count('('), 1)
435436

437+
sys_stdout = io.StringIO()
438+
with contextlib.redirect_stdout(sys_stdout):
439+
model.format(canonicalize=True)
440+
formatted = sys_stdout.getvalue()
441+
self.assertEqual(formatted.count('{'), 3)
442+
self.assertEqual(formatted.count('('), 1)
443+
444+
@patch('cmdstanpy.utils.cmdstan_version', MagicMock(return_value=(2, 27)))
445+
def test_format_old_version(self):
446+
self.assertTrue(cmdstan_version_before(2, 28))
447+
448+
stan = os.path.join(DATAFILES_PATH, 'format_me.stan')
449+
model = CmdStanModel(stan_file=stan, compile=False)
450+
with self.assertRaisesRegexNested(RuntimeError, r"--canonicalize"):
451+
model.format(canonicalize='braces')
452+
with self.assertRaisesRegexNested(RuntimeError, r"--max-line"):
453+
model.format(max_line_length=88)
454+
455+
model.format(canonicalize=True)
456+
436457

437458
if __name__ == '__main__':
438459
unittest.main()

0 commit comments

Comments
 (0)