|
9 | 9 | import sys |
10 | 10 | from collections import OrderedDict |
11 | 11 | from concurrent.futures import ThreadPoolExecutor |
| 12 | +from datetime import datetime |
12 | 13 | from io import StringIO |
13 | 14 | from multiprocessing import cpu_count |
14 | 15 | from pathlib import Path |
15 | | -from typing import Any, Callable, Dict, List, Mapping, Optional, Union |
| 16 | +from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Union |
16 | 17 |
|
17 | 18 | import ujson as json |
18 | 19 | from tqdm.auto import tqdm |
|
39 | 40 | MaybeDictToFilePath, |
40 | 41 | SanitizedOrTmpFilePath, |
41 | 42 | cmdstan_path, |
| 43 | + cmdstan_version, |
42 | 44 | cmdstan_version_before, |
43 | 45 | do_command, |
44 | 46 | get_logger, |
@@ -297,6 +299,98 @@ def src_info(self) -> Dict[str, Any]: |
297 | 299 | get_logger().debug(e) |
298 | 300 | return result |
299 | 301 |
|
| 302 | + def format( |
| 303 | + self, |
| 304 | + overwrite_file: bool = False, |
| 305 | + canonicalize: Union[bool, str, Iterable[str]] = False, |
| 306 | + max_line_length: int = 78, |
| 307 | + *, |
| 308 | + backup: bool = True, |
| 309 | + ) -> None: |
| 310 | + """ |
| 311 | + Run stanc's auto-formatter on the model code. Either saves directly |
| 312 | + back to the file or prints for inspection |
| 313 | +
|
| 314 | +
|
| 315 | + :param overwrite_file: If True, save the updated code to disk, rather |
| 316 | + than printing it. By default False |
| 317 | + :param canonicalize: Whether or not the compiler should 'canonicalize' |
| 318 | + the Stan model, removing things like deprecated syntax. Default is |
| 319 | + False. If True, all canonicalizations are run. If it is a list of |
| 320 | + strings, those options are passed to stanc (new in Stan 2.29) |
| 321 | + :param max_line_length: Set the wrapping point for the formatter. The |
| 322 | + default value is 78, which wraps most lines by the 80th character. |
| 323 | + :param backup: If True, create a stanfile.bak backup before |
| 324 | + writing to the file. Only disable this if you're sure you have other |
| 325 | + copies of the file or are using a version control system like Git. |
| 326 | + """ |
| 327 | + if self.stan_file is None or not os.path.isfile(self.stan_file): |
| 328 | + raise ValueError("No Stan file found for this module") |
| 329 | + try: |
| 330 | + cmd = ( |
| 331 | + [os.path.join(cmdstan_path(), 'bin', 'stanc' + EXTENSION)] |
| 332 | + # handle include-paths, allow-undefined etc |
| 333 | + + self._compiler_options.compose_stanc() |
| 334 | + + [self.stan_file] |
| 335 | + ) |
| 336 | + |
| 337 | + if canonicalize: |
| 338 | + if cmdstan_version_before(2, 29): |
| 339 | + if isinstance(canonicalize, bool): |
| 340 | + cmd.append('--print-canonical') |
| 341 | + else: |
| 342 | + raise ValueError( |
| 343 | + "Invalid arguments passed for current CmdStan" |
| 344 | + + " version({})\n".format( |
| 345 | + cmdstan_version() or "Unknown" |
| 346 | + ) |
| 347 | + + "--canonicalize requires 2.29 or higher" |
| 348 | + ) |
| 349 | + else: |
| 350 | + if isinstance(canonicalize, str): |
| 351 | + cmd.append('--canonicalize=' + canonicalize) |
| 352 | + elif isinstance(canonicalize, Iterable): |
| 353 | + cmd.append('--canonicalize=' + ','.join(canonicalize)) |
| 354 | + else: |
| 355 | + cmd.append('--print-canonical') |
| 356 | + |
| 357 | + # before 2.29, having both --print-canonical |
| 358 | + # and --auto-format printed twice |
| 359 | + if not (cmdstan_version_before(2, 29) and canonicalize): |
| 360 | + cmd.append('--auto-format') |
| 361 | + |
| 362 | + if not cmdstan_version_before(2, 29): |
| 363 | + cmd.append(f'--max-line-length={max_line_length}') |
| 364 | + elif max_line_length != 78: |
| 365 | + raise ValueError( |
| 366 | + "Invalid arguments passed for current CmdStan version" |
| 367 | + + " ({})\n".format(cmdstan_version() or "Unknown") |
| 368 | + + "--max-line-length requires 2.29 or higher" |
| 369 | + ) |
| 370 | + |
| 371 | + out = subprocess.run( |
| 372 | + cmd, capture_output=True, text=True, check=True |
| 373 | + ) |
| 374 | + if out.stderr: |
| 375 | + get_logger().warning(out.stderr) |
| 376 | + result = out.stdout |
| 377 | + if overwrite_file: |
| 378 | + if result: |
| 379 | + if backup: |
| 380 | + shutil.copyfile( |
| 381 | + self.stan_file, |
| 382 | + self.stan_file |
| 383 | + + '.bak-' |
| 384 | + + datetime.now().strftime("%Y%m%d%H%M%S"), |
| 385 | + ) |
| 386 | + with (open(self.stan_file, 'w')) as file_handle: |
| 387 | + file_handle.write(result) |
| 388 | + else: |
| 389 | + print(result) |
| 390 | + |
| 391 | + except (ValueError, RuntimeError) as e: |
| 392 | + raise RuntimeError("Stanc formatting failed") from e |
| 393 | + |
300 | 394 | @property |
301 | 395 | def stanc_options(self) -> Dict[str, Union[bool, int, str]]: |
302 | 396 | """Options to stanc compilers.""" |
|
0 commit comments