Skip to content

Commit fc2d638

Browse files
committed
Renaming and more tests
1 parent 6c9c762 commit fc2d638

4 files changed

Lines changed: 66 additions & 19 deletions

File tree

cmdstanpy/model.py

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -297,26 +297,29 @@ def src_info(self) -> Dict[str, Any]:
297297
get_logger().debug(e)
298298
return result
299299

300-
def format_model(
300+
def format(
301301
self,
302-
save: bool = False,
302+
overwrite_file: bool = False,
303303
canonicalize: Union[bool, str, Iterable[str]] = False,
304+
max_line_length: int = 78,
304305
*,
305-
unsafe: bool = False,
306+
backup: bool = True,
306307
) -> None:
307308
"""
308309
Run stanc's auto-formatter on the model code. Either saves directly
309310
back to the file or prints for inspection
310311
311312
312-
:param save: If True, save the updated code to disk, rather than
313-
printing it. By default False
313+
:param overwrite_file: If True, save the updated code to disk, rather
314+
than printing it. By default False
314315
:param canonicalize: Whether or not the compiler should 'canonicalize'
315316
the Stan model, removing things like deprecated syntax. Default is
316317
False. If True, all canonicalizations are run. If it is a list of
317318
strings, those options are passed to stanc (new in Stan 2.29)
318-
:param unsafe: If True, do not create stanfile.bak backups before
319-
writing to the file. Only do this if you're sure you have other
319+
:param max_line_length: Set the wrapping point for the formatter. The
320+
default value is 78, which wraps most lines by the 80th character.
321+
:param backup: If True, create a stanfile.bak backup before
322+
writing to the file. Only disable this if you're sure you have other
320323
copies of the file or are using a version control system like Git.
321324
"""
322325
if self.stan_file is None or not os.path.isfile(self.stan_file):
@@ -330,25 +333,32 @@ def format_model(
330333
)
331334

332335
if canonicalize:
333-
if isinstance(canonicalize, str):
336+
if cmdstan_version_before(2, 29) or isinstance(
337+
canonicalize, bool
338+
):
339+
cmd.append('--print-canonical')
340+
elif isinstance(canonicalize, str):
334341
cmd.append('--canonicalize=' + canonicalize)
335342
elif isinstance(canonicalize, Iterable):
336343
cmd.append('--canonicalize=' + ','.join(canonicalize))
337-
else:
338-
cmd.append('--print-canonical')
339344

345+
# before 2.29, having both --print-canonical
346+
# and --auto-format printed twice
340347
if not (cmdstan_version_before(2, 29) and canonicalize):
341348
cmd.append('--auto-format')
342349

350+
if not cmdstan_version_before(2, 29):
351+
cmd.append(f'--max-line-length={max_line_length}')
352+
343353
out = subprocess.run(
344354
cmd, capture_output=True, text=True, check=True
345355
)
346356
if out.stderr:
347357
get_logger().warning(out.stderr)
348358
result = out.stdout
349-
if save:
359+
if overwrite_file:
350360
if result:
351-
if not unsafe:
361+
if backup:
352362
shutil.copyfile(self.stan_file, self.stan_file + '.bak')
353363
with (open(self.stan_file, 'w')) as file_handle:
354364
file_handle.write(result)

test/data/format_me.stan

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
1-
# pound-sign comment
21
generated quantities {
3-
int x;
4-
x <- (((((3)))));
2+
array[10,10,10,10,10] matrix[100,100] a_very_long_name;
3+
int x = (((10)));
4+
int y;
5+
if (1)
6+
y = 1;
7+
else
8+
y=2;
59
}
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# pound-sign comment
2+
generated quantities {
3+
int x;
4+
x <- (((((3)))));
5+
}

test/test_model.py

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -382,14 +382,14 @@ def test_model_includes_implicit(self):
382382
not cmdstan_version_before(2, 32),
383383
reason="Deprecated syntax removed in Stan 2.32",
384384
)
385-
def test_model_format(self):
386-
stan = os.path.join(DATAFILES_PATH, 'format_me.stan')
385+
def test_model_format_deprecations(self):
386+
stan = os.path.join(DATAFILES_PATH, 'format_me_deprecations.stan')
387387

388388
model = CmdStanModel(stan_file=stan, compile=False)
389389

390390
sys_stdout = io.StringIO()
391391
with contextlib.redirect_stdout(sys_stdout):
392-
model.format_model()
392+
model.format()
393393

394394
formatted = sys_stdout.getvalue()
395395
self.assertIn("//", formatted)
@@ -398,13 +398,41 @@ def test_model_format(self):
398398

399399
sys_stdout = io.StringIO()
400400
with contextlib.redirect_stdout(sys_stdout):
401-
model.format_model(canonicalize=True)
401+
model.format(canonicalize=True)
402402

403403
formatted = sys_stdout.getvalue()
404404
print(formatted)
405405
self.assertNotIn("<-", formatted)
406406
self.assertEqual(formatted.count('('), 0)
407407

408+
@pytest.mark.skipif(
409+
cmdstan_version_before(2, 29), reason='Options only available later'
410+
)
411+
def test_model_format_options(self):
412+
stan = os.path.join(DATAFILES_PATH, 'format_me.stan')
413+
414+
model = CmdStanModel(stan_file=stan, compile=False)
415+
416+
sys_stdout = io.StringIO()
417+
with contextlib.redirect_stdout(sys_stdout):
418+
model.format(max_line_length=10)
419+
formatted = sys_stdout.getvalue()
420+
self.assertGreater(len(formatted.splitlines()), 11)
421+
422+
sys_stdout = io.StringIO()
423+
with contextlib.redirect_stdout(sys_stdout):
424+
model.format(canonicalize='braces')
425+
formatted = sys_stdout.getvalue()
426+
self.assertEqual(formatted.count('{'), 3)
427+
self.assertEqual(formatted.count('('), 4)
428+
429+
sys_stdout = io.StringIO()
430+
with contextlib.redirect_stdout(sys_stdout):
431+
model.format(canonicalize=['parentheses'])
432+
formatted = sys_stdout.getvalue()
433+
self.assertEqual(formatted.count('{'), 1)
434+
self.assertEqual(formatted.count('('), 1)
435+
408436

409437
if __name__ == '__main__':
410438
unittest.main()

0 commit comments

Comments
 (0)