Skip to content

Commit 0ab1410

Browse files
committed
Start testing
1 parent c98a625 commit 0ab1410

3 files changed

Lines changed: 63 additions & 17 deletions

File tree

cmdstanpy/model.py

Lines changed: 26 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -298,7 +298,11 @@ def src_info(self) -> Dict[str, Any]:
298298
return result
299299

300300
def format_model(
301-
self, save: bool = False, canonicalize: Union[bool, List[str]] = False
301+
self,
302+
save: bool = False,
303+
canonicalize: Union[bool, str, List[str]] = False,
304+
*,
305+
unsafe: bool = False,
302306
) -> None:
303307
"""
304308
Run stanc's auto-formatter on the model code. Either saves directly
@@ -311,37 +315,43 @@ def format_model(
311315
the Stan model, removing things like deprecated syntax. Default is
312316
False. If True, all canonicalizations are run. If it is a list of
313317
strings, those options are passed to stanc (new in Stan 2.29)
318+
:param unsafe: If True, do not create stanfile.bak backups before
319+
writing to the file. Only do this if you're sure you have other
320+
copies of the file or are using a version control system like Git.
314321
"""
315322
if self.stan_file is None or not os.path.isfile(self.stan_file):
316323
raise ValueError("No Stan file found for this module")
317324
try:
318-
# TODO need include paths if they exist.
319-
cmd = [
320-
os.path.join('.', 'bin', 'stanc' + EXTENSION),
321-
'--auto-format',
322-
]
325+
cmd = (
326+
[os.path.join(cmdstan_path(), 'bin', 'stanc' + EXTENSION)]
327+
# handle include-paths, allow-undefined etc
328+
+ self._compiler_options.compose_stanc()
329+
+ [self.stan_file]
330+
)
331+
323332
if canonicalize:
324333
if isinstance(canonicalize, list):
325334
cmd.append('--canonicalize=' + ','.join(canonicalize))
335+
elif isinstance(canonicalize, str):
336+
cmd.append('--canonicalize=' + canonicalize)
326337
else:
327338
cmd.append('--print-canonical')
328339

329-
cmd.append(self.stan_file)
340+
if not (cmdstan_version_before(2, 29) and canonicalize):
341+
cmd.append('--auto-format')
330342

331343
out = subprocess.run(
332-
cmd,
333-
cwd=cmdstan_path(),
334-
capture_output=True,
335-
text=True,
336-
check=True,
344+
cmd, capture_output=True, text=True, check=True
337345
)
338346
if out.stderr:
339-
print(out.stderr)
347+
get_logger().warning(out.stderr)
340348
result = out.stdout
341349
if save:
342-
shutil.copyfile(self.stan_file, self.stan_file + '.bak')
343-
with (open(self.stan_file, 'w')) as file:
344-
file.write(result)
350+
if result:
351+
if not unsafe:
352+
shutil.copyfile(self.stan_file, self.stan_file + '.bak')
353+
with (open(self.stan_file, 'w')) as file:
354+
file.write(result)
345355
else:
346356
print(result)
347357

test/data/format_me.stan

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# pound-sign comment
2+
generated quantities {
3+
int x;
4+
x <- (((((3)))));
5+
}

test/test_model.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
"""CmdStanModel tests"""
22

3+
import contextlib
4+
import io
35
import logging
46
import os
57
import shutil
@@ -10,7 +12,7 @@
1012
from testfixtures import LogCapture, StringComparison
1113

1214
from cmdstanpy.model import CmdStanModel
13-
from cmdstanpy.utils import EXTENSION
15+
from cmdstanpy.utils import EXTENSION, cmdstan_version_before
1416

1517
HERE = os.path.dirname(os.path.abspath(__file__))
1618
DATAFILES_PATH = os.path.join(HERE, 'data')
@@ -34,6 +36,7 @@
3436
BERN_BASENAME = 'bernoulli'
3537

3638

39+
# pylint: disable=too-many-public-methods
3740
class CmdStanModelTest(CustomTestCase):
3841
def test_model_good(self):
3942
# compile on instantiation, override model name
@@ -374,6 +377,34 @@ def test_model_includes_implicit(self):
374377
model2 = CmdStanModel(stan_file=stan)
375378
self.assertPathsEqual(model2.exe_file, exe)
376379

380+
def test_model_format(self):
381+
# deprecations expire in this version
382+
if cmdstan_version_before(2, 32):
383+
stan = os.path.join(DATAFILES_PATH, 'format_me.stan')
384+
385+
model = CmdStanModel(stan_file=stan, compile=False)
386+
387+
sys_stdout = io.StringIO()
388+
with contextlib.redirect_stdout(sys_stdout):
389+
model.format_model()
390+
391+
formatted = sys_stdout.getvalue()
392+
self.assertIn("//", formatted)
393+
self.assertNotIn("#", formatted)
394+
self.assertEqual(formatted.count('('), 5)
395+
396+
sys_stdout = io.StringIO()
397+
with contextlib.redirect_stdout(sys_stdout):
398+
model.format_model(canonicalize=True)
399+
400+
formatted = sys_stdout.getvalue()
401+
print(formatted)
402+
self.assertNotIn("<-", formatted)
403+
self.assertEqual(formatted.count('('), 0)
404+
405+
else:
406+
assert False, "Test needs to be updated for Stan 2.32"
407+
377408

378409
if __name__ == '__main__':
379410
unittest.main()

0 commit comments

Comments
 (0)