Skip to content

Commit ccec477

Browse files
authored
Merge pull request #449 from stan-dev/project/1.0-cleanup
Project/1.0 cleanup
2 parents 38932eb + 6072d26 commit ccec477

5 files changed

Lines changed: 120 additions & 24 deletions

File tree

cmdstanpy/__init__.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,13 @@ def _cleanup_tmpdir() -> None:
4646
from_csv,
4747
)
4848
from .utils import set_cmdstan_path # noqa
49-
from .utils import cmdstan_path, install_cmdstan, set_make_env, write_stan_json
49+
from .utils import (
50+
cmdstan_path,
51+
install_cmdstan,
52+
set_make_env,
53+
show_versions,
54+
write_stan_json,
55+
)
5056

5157
__all__ = [
5258
'set_cmdstan_path',
@@ -61,4 +67,5 @@ def _cleanup_tmpdir() -> None:
6167
'InferenceMetadata',
6268
'from_csv',
6369
'write_stan_json',
70+
'show_versions',
6471
]

cmdstanpy/model.py

Lines changed: 34 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -344,17 +344,12 @@ def compile(
344344
)
345345
if 'PCH file' in str(e):
346346
get_logger().warning(
347-
"%s, %s",
348-
"CmdStan's precompiled header (PCH) files ",
349-
"may need to be rebuilt.",
350-
)
351-
get_logger().warning(
352-
"%s %s",
353-
"If your model failed to compile please run ",
354-
"install_cmdstan(overwrite=True).",
355-
)
356-
get_logger().warning(
357-
"If the issue persists please open a bug report"
347+
"%s",
348+
"CmdStan's precompiled header (PCH) files "
349+
"may need to be rebuilt."
350+
"If your model failed to compile please run "
351+
"install_cmdstan(overwrite=True).\nIf the "
352+
"issue persists please open a bug report",
358353
)
359354

360355
compilation_failed = True
@@ -398,6 +393,7 @@ def optimize(
398393
history_size: Optional[int] = None,
399394
iter: Optional[int] = None,
400395
refresh: Optional[int] = None,
396+
time_fmt: str = "%Y%m%d%H%M%S",
401397
) -> CmdStanMLE:
402398
"""
403399
Run the specified CmdStan optimize algorithm to produce a
@@ -484,6 +480,10 @@ def optimize(
484480
:param refresh: Specify the number of iterations cmdstan will take
485481
between progress messages. Default value is 100.
486482
483+
:param time_fmt: A format string passed to
484+
:meth:`~datetime.datetime.strftime` to decide the file names for
485+
output CSVs. Defaults to "%Y%m%d%H%M%S"
486+
487487
:return: CmdStanMLE object
488488
"""
489489
optimize_args = OptimizeArgs(
@@ -514,7 +514,7 @@ def optimize(
514514
)
515515

516516
dummy_chain_id = 0
517-
runset = RunSet(args=args, chains=1)
517+
runset = RunSet(args=args, chains=1, time_fmt=time_fmt)
518518
self._run_cmdstan(runset, dummy_chain_id)
519519

520520
if not runset._check_retcodes():
@@ -555,6 +555,7 @@ def sample(
555555
save_profile: bool = False,
556556
show_progress: Union[bool, str] = False,
557557
refresh: Optional[int] = None,
558+
time_fmt: str = "%Y%m%d%H%M%S",
558559
) -> CmdStanMCMC:
559560
"""
560561
Run or more chains of the NUTS-HMC sampler to produce a set of draws
@@ -718,6 +719,10 @@ def sample(
718719
:param refresh: Specify the number of iterations cmdstan will take
719720
between progress messages. Default value is 100.
720721
722+
:param time_fmt: A format string passed to
723+
:meth:`~datetime.datetime.strftime` to decide the file names for
724+
output CSVs. Defaults to "%Y%m%d%H%M%S"
725+
721726
:return: CmdStanMCMC object
722727
"""
723728
if chains is None:
@@ -829,7 +834,9 @@ def sample(
829834
method_args=sampler_args,
830835
refresh=refresh,
831836
)
832-
runset = RunSet(args=args, chains=chains, chain_ids=chain_ids)
837+
runset = RunSet(
838+
args=args, chains=chains, chain_ids=chain_ids, time_fmt=time_fmt
839+
)
833840
pbar = None
834841
all_pbars = []
835842

@@ -899,6 +906,7 @@ def generate_quantities(
899906
gq_output_dir: Optional[str] = None,
900907
sig_figs: Optional[int] = None,
901908
refresh: Optional[int] = None,
909+
time_fmt: str = "%Y%m%d%H%M%S",
902910
) -> CmdStanGQ:
903911
"""
904912
Run CmdStan's generate_quantities method which runs the generated
@@ -950,6 +958,10 @@ def generate_quantities(
950958
:param refresh: Specify the number of iterations cmdstan will take
951959
between progress messages. Default value is 100.
952960
961+
:param time_fmt: A format string passed to
962+
:meth:`~datetime.datetime.strftime` to decide the file names for
963+
output CSVs. Defaults to "%Y%m%d%H%M%S"
964+
953965
:return: CmdStanGQ object
954966
"""
955967
if isinstance(mcmc_sample, CmdStanMCMC):
@@ -999,7 +1011,9 @@ def generate_quantities(
9991011
method_args=generate_quantities_args,
10001012
refresh=refresh,
10011013
)
1002-
runset = RunSet(args=args, chains=chains, chain_ids=chain_ids)
1014+
runset = RunSet(
1015+
args=args, chains=chains, chain_ids=chain_ids, time_fmt=time_fmt
1016+
)
10031017

10041018
parallel_chains_avail = cpu_count()
10051019
parallel_chains = max(min(parallel_chains_avail - 2, chains), 1)
@@ -1039,6 +1053,7 @@ def variational(
10391053
output_samples: Optional[int] = None,
10401054
require_converged: bool = True,
10411055
refresh: Optional[int] = None,
1056+
time_fmt: str = "%Y%m%d%H%M%S",
10421057
) -> CmdStanVB:
10431058
"""
10441059
Run CmdStan's variational inference algorithm to approximate
@@ -1123,6 +1138,10 @@ def variational(
11231138
:param refresh: Specify the number of iterations cmdstan will take
11241139
between progress messages. Default value is 100.
11251140
1141+
:param time_fmt: A format string passed to
1142+
:meth:`~datetime.datetime.strftime` to decide the file names for
1143+
output CSVs. Defaults to "%Y%m%d%H%M%S"
1144+
11261145
:return: CmdStanVB object
11271146
"""
11281147
variational_args = VariationalArgs(
@@ -1155,7 +1174,7 @@ def variational(
11551174
)
11561175

11571176
dummy_chain_id = 0
1158-
runset = RunSet(args=args, chains=1)
1177+
runset = RunSet(args=args, chains=1, time_fmt=time_fmt)
11591178
self._run_cmdstan(runset, dummy_chain_id)
11601179

11611180
# treat failure to converge as failure

cmdstanpy/stanfit.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ def __init__(
7070
chains: int = 4,
7171
chain_ids: Optional[List[int]] = None,
7272
logger: Optional[logging.Logger] = None,
73+
time_fmt: str = "%Y%m%d%H%M%S",
7374
) -> None:
7475
"""Initialize object."""
7576
self._args = args
@@ -100,7 +101,7 @@ def __init__(
100101
# prefix: ``<model_name>-<YYYYMMDDHHMM>-<chain_id>``
101102
# suffixes: ``-stdout.txt``, ``-stderr.txt``
102103
now = datetime.now()
103-
now_str = now.strftime('%Y%m%d%H%M')
104+
now_str = now.strftime(time_fmt)
104105
file_basename = '-'.join([args.model_name, now_str])
105106
if args.output_dir is not None:
106107
output_dir = args.output_dir
@@ -794,7 +795,7 @@ def _assemble_draws(self) -> None:
794795
line = fd.readline().strip() # metric type
795796
line = fd.readline().lstrip(' #\t')
796797
num_unconstrained_params = len(line.split(','))
797-
if chain == 0: # can't allocate w/o num params
798+
if chain == 0: # can't allocate w/o num params
798799
if self.metric_type == 'diag_e':
799800
self._metric = np.empty(
800801
(self.chains, num_unconstrained_params),

cmdstanpy/utils.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1027,6 +1027,68 @@ def create_named_text_file(
10271027
return path
10281028

10291029

1030+
def show_versions(output: bool = True) -> str:
1031+
"""Prints out system and dependency information for debugging"""
1032+
1033+
import importlib
1034+
import locale
1035+
import struct
1036+
1037+
deps_info = []
1038+
try:
1039+
(sysname, _, release, _, machine, processor) = platform.uname()
1040+
deps_info.extend(
1041+
[
1042+
("python", sys.version),
1043+
("python-bits", struct.calcsize("P") * 8),
1044+
("OS", f"{sysname}"),
1045+
("OS-release", f"{release}"),
1046+
("machine", f"{machine}"),
1047+
("processor", f"{processor}"),
1048+
("byteorder", f"{sys.byteorder}"),
1049+
("LC_ALL", f'{os.environ.get("LC_ALL", "None")}'),
1050+
("LANG", f'{os.environ.get("LANG", "None")}'),
1051+
("LOCALE", f"{locale.getlocale()}"),
1052+
]
1053+
)
1054+
# pylint: disable=broad-except
1055+
except Exception:
1056+
pass
1057+
1058+
try:
1059+
deps_info.append(('cmdstan_folder', cmdstan_path()))
1060+
# pylint: disable=broad-except
1061+
except Exception:
1062+
deps_info.append(('cmdstan', 'NOT FOUND'))
1063+
1064+
deps = ['cmdstanpy', 'pandas', 'xarray', 'tdqm', 'numpy', 'ujson']
1065+
for module in deps:
1066+
try:
1067+
if module in sys.modules:
1068+
mod = sys.modules[module]
1069+
else:
1070+
mod = importlib.import_module(module)
1071+
# pylint: disable=broad-except
1072+
except Exception:
1073+
deps_info.append((module, None))
1074+
else:
1075+
try:
1076+
ver = mod.__version__ # type: ignore
1077+
deps_info.append((module, ver))
1078+
# pylint: disable=broad-except
1079+
except Exception:
1080+
deps_info.append((module, "installed"))
1081+
1082+
out = 'INSTALLED VERSIONS\n---------------------\n'
1083+
for k, info in deps_info:
1084+
out += f'{k}: {info}\n'
1085+
if output:
1086+
print(out)
1087+
return " "
1088+
else:
1089+
return out
1090+
1091+
10301092
def install_cmdstan(
10311093
version: Optional[str] = None,
10321094
dir: Optional[str] = None,

test/test_utils.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -127,12 +127,19 @@ def test_set_path(self):
127127
self.assertEqual(install_version, os.environ['CMDSTAN'])
128128

129129
def test_validate_path(self):
130-
cmdstan_dir = os.path.expanduser(os.path.join('~', _DOT_CMDSTAN))
131-
if not os.path.exists(cmdstan_dir):
132-
cmdstan_dir = os.path.expanduser(os.path.join('~', _DOT_CMDSTANPY))
133-
install_version = os.path.join(
134-
cmdstan_dir, get_latest_cmdstan(cmdstan_dir)
135-
)
130+
if 'CMDSTAN' in os.environ:
131+
install_version = os.environ.get('CMDSTAN')
132+
else:
133+
cmdstan_dir = os.path.expanduser(os.path.join('~', _DOT_CMDSTAN))
134+
if not os.path.exists(cmdstan_dir):
135+
cmdstan_dir = os.path.expanduser(
136+
os.path.join('~', _DOT_CMDSTANPY)
137+
)
138+
139+
install_version = os.path.join(
140+
cmdstan_dir, get_latest_cmdstan(cmdstan_dir)
141+
)
142+
136143
set_cmdstan_path(install_version)
137144
validate_cmdstan_path(install_version)
138145
path_foo = os.path.abspath(os.path.join('releases', 'foo'))

0 commit comments

Comments
 (0)