Skip to content

Commit ef23b3b

Browse files
authored
Merge pull request #539 from stan-dev/model-formatting
Model formatting
2 parents d839f02 + e5d7531 commit ef23b3b

5 files changed

Lines changed: 201 additions & 3 deletions

File tree

cmdstanpy/model.py

Lines changed: 95 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,11 @@
99
import sys
1010
from collections import OrderedDict
1111
from concurrent.futures import ThreadPoolExecutor
12+
from datetime import datetime
1213
from io import StringIO
1314
from multiprocessing import cpu_count
1415
from pathlib import Path
15-
from typing import Any, Callable, Dict, List, Mapping, Optional, Union
16+
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Union
1617

1718
import ujson as json
1819
from tqdm.auto import tqdm
@@ -39,6 +40,7 @@
3940
MaybeDictToFilePath,
4041
SanitizedOrTmpFilePath,
4142
cmdstan_path,
43+
cmdstan_version,
4244
cmdstan_version_before,
4345
do_command,
4446
get_logger,
@@ -297,6 +299,98 @@ def src_info(self) -> Dict[str, Any]:
297299
get_logger().debug(e)
298300
return result
299301

302+
def format(
303+
self,
304+
overwrite_file: bool = False,
305+
canonicalize: Union[bool, str, Iterable[str]] = False,
306+
max_line_length: int = 78,
307+
*,
308+
backup: bool = True,
309+
) -> None:
310+
"""
311+
Run stanc's auto-formatter on the model code. Either saves directly
312+
back to the file or prints for inspection
313+
314+
315+
:param overwrite_file: If True, save the updated code to disk, rather
316+
than printing it. By default False
317+
:param canonicalize: Whether or not the compiler should 'canonicalize'
318+
the Stan model, removing things like deprecated syntax. Default is
319+
False. If True, all canonicalizations are run. If it is a list of
320+
strings, those options are passed to stanc (new in Stan 2.29)
321+
:param max_line_length: Set the wrapping point for the formatter. The
322+
default value is 78, which wraps most lines by the 80th character.
323+
:param backup: If True, create a stanfile.bak backup before
324+
writing to the file. Only disable this if you're sure you have other
325+
copies of the file or are using a version control system like Git.
326+
"""
327+
if self.stan_file is None or not os.path.isfile(self.stan_file):
328+
raise ValueError("No Stan file found for this module")
329+
try:
330+
cmd = (
331+
[os.path.join(cmdstan_path(), 'bin', 'stanc' + EXTENSION)]
332+
# handle include-paths, allow-undefined etc
333+
+ self._compiler_options.compose_stanc()
334+
+ [self.stan_file]
335+
)
336+
337+
if canonicalize:
338+
if cmdstan_version_before(2, 29):
339+
if isinstance(canonicalize, bool):
340+
cmd.append('--print-canonical')
341+
else:
342+
raise ValueError(
343+
"Invalid arguments passed for current CmdStan"
344+
+ " version({})\n".format(
345+
cmdstan_version() or "Unknown"
346+
)
347+
+ "--canonicalize requires 2.29 or higher"
348+
)
349+
else:
350+
if isinstance(canonicalize, str):
351+
cmd.append('--canonicalize=' + canonicalize)
352+
elif isinstance(canonicalize, Iterable):
353+
cmd.append('--canonicalize=' + ','.join(canonicalize))
354+
else:
355+
cmd.append('--print-canonical')
356+
357+
# before 2.29, having both --print-canonical
358+
# and --auto-format printed twice
359+
if not (cmdstan_version_before(2, 29) and canonicalize):
360+
cmd.append('--auto-format')
361+
362+
if not cmdstan_version_before(2, 29):
363+
cmd.append(f'--max-line-length={max_line_length}')
364+
elif max_line_length != 78:
365+
raise ValueError(
366+
"Invalid arguments passed for current CmdStan version"
367+
+ " ({})\n".format(cmdstan_version() or "Unknown")
368+
+ "--max-line-length requires 2.29 or higher"
369+
)
370+
371+
out = subprocess.run(
372+
cmd, capture_output=True, text=True, check=True
373+
)
374+
if out.stderr:
375+
get_logger().warning(out.stderr)
376+
result = out.stdout
377+
if overwrite_file:
378+
if result:
379+
if backup:
380+
shutil.copyfile(
381+
self.stan_file,
382+
self.stan_file
383+
+ '.bak-'
384+
+ datetime.now().strftime("%Y%m%d%H%M%S"),
385+
)
386+
with (open(self.stan_file, 'w')) as file_handle:
387+
file_handle.write(result)
388+
else:
389+
print(result)
390+
391+
except (ValueError, RuntimeError) as e:
392+
raise RuntimeError("Stanc formatting failed") from e
393+
300394
@property
301395
def stanc_options(self) -> Dict[str, Union[bool, int, str]]:
302396
"""Options to stanc compilers."""

test/data/.gitignore

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,6 @@
66
# and we re-ignore hpp and exe files
77
*.hpp
88
*.exe
9-
!return_one.hpp
9+
*.testbak
10+
*.bak-*
11+
!return_one.hpp

test/data/format_me.stan

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
generated quantities {
2+
array[10,10,10,10,10] matrix[100,100] a_very_long_name;
3+
int x = (((10)));
4+
int y;
5+
if (1)
6+
y = 1;
7+
else
8+
y=2;
9+
}
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# pound-sign comment
2+
generated quantities {
3+
int x;
4+
x <- (((((3)))));
5+
}

test/test_model.py

Lines changed: 89 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,21 @@
11
"""CmdStanModel tests"""
22

3+
import contextlib
4+
import io
35
import logging
46
import os
57
import shutil
68
import tempfile
79
import unittest
10+
from glob import glob
811
from test import CustomTestCase
12+
from unittest.mock import MagicMock, patch
913

14+
import pytest
1015
from testfixtures import LogCapture, StringComparison
1116

1217
from cmdstanpy.model import CmdStanModel
13-
from cmdstanpy.utils import EXTENSION
18+
from cmdstanpy.utils import EXTENSION, cmdstan_version_before
1419

1520
HERE = os.path.dirname(os.path.abspath(__file__))
1621
DATAFILES_PATH = os.path.join(HERE, 'data')
@@ -34,6 +39,7 @@
3439
BERN_BASENAME = 'bernoulli'
3540

3641

42+
# pylint: disable=too-many-public-methods
3743
class CmdStanModelTest(CustomTestCase):
3844
def test_model_good(self):
3945
# compile on instantiation, override model name
@@ -374,6 +380,88 @@ def test_model_includes_implicit(self):
374380
model2 = CmdStanModel(stan_file=stan)
375381
self.assertPathsEqual(model2.exe_file, exe)
376382

383+
@pytest.mark.skipif(
384+
not cmdstan_version_before(2, 32),
385+
reason="Deprecated syntax removed in Stan 2.32",
386+
)
387+
def test_model_format_deprecations(self):
388+
stan = os.path.join(DATAFILES_PATH, 'format_me_deprecations.stan')
389+
390+
model = CmdStanModel(stan_file=stan, compile=False)
391+
392+
sys_stdout = io.StringIO()
393+
with contextlib.redirect_stdout(sys_stdout):
394+
model.format()
395+
396+
formatted = sys_stdout.getvalue()
397+
self.assertIn("//", formatted)
398+
self.assertNotIn("#", formatted)
399+
self.assertEqual(formatted.count('('), 5)
400+
401+
sys_stdout = io.StringIO()
402+
with contextlib.redirect_stdout(sys_stdout):
403+
model.format(canonicalize=True)
404+
405+
formatted = sys_stdout.getvalue()
406+
print(formatted)
407+
self.assertNotIn("<-", formatted)
408+
self.assertEqual(formatted.count('('), 0)
409+
410+
shutil.copy(stan, stan + '.testbak')
411+
try:
412+
model.format(overwrite_file=True, canonicalize=True)
413+
self.assertEqual(len(glob(stan + '.bak-*')), 1)
414+
finally:
415+
shutil.copy(stan + '.testbak', stan)
416+
417+
@pytest.mark.skipif(
418+
cmdstan_version_before(2, 29), reason='Options only available later'
419+
)
420+
def test_model_format_options(self):
421+
stan = os.path.join(DATAFILES_PATH, 'format_me.stan')
422+
423+
model = CmdStanModel(stan_file=stan, compile=False)
424+
425+
sys_stdout = io.StringIO()
426+
with contextlib.redirect_stdout(sys_stdout):
427+
model.format(max_line_length=10)
428+
formatted = sys_stdout.getvalue()
429+
self.assertGreater(len(formatted.splitlines()), 11)
430+
431+
sys_stdout = io.StringIO()
432+
with contextlib.redirect_stdout(sys_stdout):
433+
model.format(canonicalize='braces')
434+
formatted = sys_stdout.getvalue()
435+
self.assertEqual(formatted.count('{'), 3)
436+
self.assertEqual(formatted.count('('), 4)
437+
438+
sys_stdout = io.StringIO()
439+
with contextlib.redirect_stdout(sys_stdout):
440+
model.format(canonicalize=['parentheses'])
441+
formatted = sys_stdout.getvalue()
442+
self.assertEqual(formatted.count('{'), 1)
443+
self.assertEqual(formatted.count('('), 1)
444+
445+
sys_stdout = io.StringIO()
446+
with contextlib.redirect_stdout(sys_stdout):
447+
model.format(canonicalize=True)
448+
formatted = sys_stdout.getvalue()
449+
self.assertEqual(formatted.count('{'), 3)
450+
self.assertEqual(formatted.count('('), 1)
451+
452+
@patch('cmdstanpy.utils.cmdstan_version', MagicMock(return_value=(2, 27)))
453+
def test_format_old_version(self):
454+
self.assertTrue(cmdstan_version_before(2, 28))
455+
456+
stan = os.path.join(DATAFILES_PATH, 'format_me.stan')
457+
model = CmdStanModel(stan_file=stan, compile=False)
458+
with self.assertRaisesRegexNested(RuntimeError, r"--canonicalize"):
459+
model.format(canonicalize='braces')
460+
with self.assertRaisesRegexNested(RuntimeError, r"--max-line"):
461+
model.format(max_line_length=88)
462+
463+
model.format(canonicalize=True)
464+
377465

378466
if __name__ == '__main__':
379467
unittest.main()

0 commit comments

Comments
 (0)