Skip to content

Commit b55b276

Browse files
committed
Merge branch 'develop' into feature/tuple-io
2 parents 96e4e21 + 107a347 commit b55b276

7 files changed

Lines changed: 265 additions & 150 deletions

File tree

cmdstanpy/cmdstan_args.py

Lines changed: 14 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -55,6 +55,7 @@ def __init__(
5555
adapt_metric_window: Optional[int] = None,
5656
adapt_step_size: Optional[int] = None,
5757
fixed_param: bool = False,
58+
num_chains: int = 1,
5859
) -> None:
5960
"""Initialize object."""
6061
self.iter_warmup = iter_warmup
@@ -73,6 +74,7 @@ def __init__(
7374
self.adapt_step_size = adapt_step_size
7475
self.fixed_param = fixed_param
7576
self.diagnostic_file = None
77+
self.num_chains = num_chains
7678

7779
def validate(self, chains: Optional[int]) -> None:
7880
"""
@@ -316,6 +318,10 @@ def validate(self, chains: Optional[int]) -> None:
316318
'Argument "adapt_step_size" must be a non-negative integer,'
317319
'found {}'.format(self.adapt_step_size)
318320
)
321+
if self.num_chains < 1 or not isinstance(
322+
self.num_chains, (int, np.integer)
323+
):
324+
raise ValueError("num_chains must be positive")
319325

320326
if self.fixed_param and (
321327
self.max_treedepth is not None
@@ -378,6 +384,8 @@ def compose(self, idx: int, cmd: List[str]) -> List[str]:
378384
cmd.append('window={}'.format(self.adapt_metric_window))
379385
if self.adapt_step_size is not None:
380386
cmd.append('term_buffer={}'.format(self.adapt_step_size))
387+
if self.num_chains > 1:
388+
cmd.append('num_chains={}'.format(self.num_chains))
381389

382390
return cmd
383391

@@ -921,8 +929,12 @@ def validate(self) -> None:
921929
)
922930
)
923931
elif isinstance(self.inits, str):
924-
if not os.path.exists(self.inits):
925-
raise ValueError('no such file {}'.format(self.inits))
932+
if not (
933+
isinstance(self.method_args, SamplerArgs)
934+
and self.method_args.num_chains > 1
935+
):
936+
if not os.path.exists(self.inits):
937+
raise ValueError('no such file {}'.format(self.inits))
926938
elif isinstance(self.inits, list):
927939
if self.chain_ids is None:
928940
raise ValueError(
@@ -948,7 +960,6 @@ def compose_command(
948960
*,
949961
diagnostic_file: Optional[str] = None,
950962
profile_file: Optional[str] = None,
951-
num_chains: Optional[int] = None,
952963
) -> List[str]:
953964
"""
954965
Compose CmdStan command for non-default arguments.
@@ -992,6 +1003,4 @@ def compose_command(
9921003
if self.sig_figs is not None:
9931004
cmd.append('sig_figs={}'.format(self.sig_figs))
9941005
cmd = self.method_args.compose(idx, cmd)
995-
if num_chains:
996-
cmd.append('num_chains={}'.format(num_chains))
9971006
return cmd

cmdstanpy/model.py

Lines changed: 105 additions & 71 deletions
Original file line number | Diff line number | Diff line change
@@ -59,7 +59,6 @@
5959
)
6060
from cmdstanpy.utils import (
6161
EXTENSION,
62-
MaybeDictToFilePath,
6362
SanitizedOrTmpFilePath,
6463
cmdstan_path,
6564
cmdstan_version,
@@ -68,6 +67,7 @@
6867
get_logger,
6968
returncode_msg,
7069
)
70+
from cmdstanpy.utils.filesystem import temp_inits, temp_single_json
7171

7272
from . import progress as progbar
7373

@@ -573,7 +573,7 @@ def optimize(
573573
self,
574574
data: Union[Mapping[str, Any], str, os.PathLike, None] = None,
575575
seed: Optional[int] = None,
576-
inits: Union[Dict[str, float], float, str, os.PathLike, None] = None,
576+
inits: Union[Mapping[str, Any], float, str, os.PathLike, None] = None,
577577
output_dir: OptionalPath = None,
578578
sig_figs: Optional[int] = None,
579579
save_profile: bool = False,
@@ -722,7 +722,9 @@ def optimize(
722722
"in CmdStan 2.32 and above."
723723
)
724724

725-
with MaybeDictToFilePath(data, inits) as (_data, _inits):
725+
with temp_single_json(data) as _data, temp_inits(
726+
inits, allow_multiple=False
727+
) as _inits:
726728
args = CmdStanArgs(
727729
self._name,
728730
self._exe_file,
@@ -766,7 +768,14 @@ def sample(
766768
threads_per_chain: Optional[int] = None,
767769
seed: Union[int, List[int], None] = None,
768770
chain_ids: Union[int, List[int], None] = None,
769-
inits: Union[Dict[str, float], float, str, List[str], None] = None,
771+
inits: Union[
772+
Mapping[str, Any],
773+
float,
774+
str,
775+
List[str],
776+
List[Mapping[str, Any]],
777+
None,
778+
] = None,
770779
iter_warmup: Optional[int] = None,
771780
iter_sampling: Optional[int] = None,
772781
save_warmup: bool = False,
@@ -1003,6 +1012,69 @@ def sample(
10031012
chains
10041013
)
10051014
)
1015+
1016+
if parallel_chains is None:
1017+
parallel_chains = max(min(cpu_count(), chains), 1)
1018+
elif parallel_chains > chains:
1019+
get_logger().info(
1020+
'Requested %u parallel_chains but only %u required, '
1021+
'will run all chains in parallel.',
1022+
parallel_chains,
1023+
chains,
1024+
)
1025+
parallel_chains = chains
1026+
elif parallel_chains < 1:
1027+
raise ValueError(
1028+
'Argument parallel_chains must be a positive integer, '
1029+
'found {}.'.format(parallel_chains)
1030+
)
1031+
if threads_per_chain is None:
1032+
threads_per_chain = 1
1033+
if threads_per_chain < 1:
1034+
raise ValueError(
1035+
'Argument threads_per_chain must be a positive integer, '
1036+
'found {}.'.format(threads_per_chain)
1037+
)
1038+
1039+
parallel_procs = parallel_chains
1040+
num_threads = threads_per_chain
1041+
one_process_per_chain = True
1042+
info_dict = self.exe_info()
1043+
stan_threads = info_dict.get('STAN_THREADS', 'false').lower()
1044+
# run multi-chain sampler unless algo is fixed_param or 1 chain
1045+
if chains == 1:
1046+
force_one_process_per_chain = True
1047+
1048+
if (
1049+
force_one_process_per_chain is None
1050+
and not cmdstan_version_before(2, 28, info_dict)
1051+
and stan_threads == 'true'
1052+
):
1053+
one_process_per_chain = False
1054+
num_threads = parallel_chains * num_threads
1055+
parallel_procs = 1
1056+
if force_one_process_per_chain is False:
1057+
if not cmdstan_version_before(2, 28, info_dict):
1058+
one_process_per_chain = False
1059+
num_threads = parallel_chains * num_threads
1060+
parallel_procs = 1
1061+
if stan_threads == 'false':
1062+
get_logger().warning(
1063+
'Stan program not compiled for threading, '
1064+
'process will run chains sequentially. '
1065+
'For multi-chain parallelization, recompile '
1066+
'the model with argument '
1067+
'"cpp_options={\'STAN_THREADS\':\'TRUE\'}.'
1068+
)
1069+
else:
1070+
get_logger().warning(
1071+
'Installed version of CmdStan cannot multi-process '
1072+
'chains, will run %d processes. '
1073+
'Run "install_cmdstan" to upgrade to latest version.',
1074+
chains,
1075+
)
1076+
os.environ['STAN_NUM_THREADS'] = str(num_threads)
1077+
10061078
if chain_ids is None:
10071079
chain_ids = [i + 1 for i in range(chains)]
10081080
else:
@@ -1014,6 +1086,13 @@ def sample(
10141086
)
10151087
chain_ids = [i + chain_ids for i in range(chains)]
10161088
else:
1089+
if not one_process_per_chain:
1090+
for i, j in zip(chain_ids, chain_ids[1:]):
1091+
if i != j - 1:
1092+
raise ValueError(
1093+
'chain_ids must be sequential list of integers,'
1094+
' found {}.'.format(chain_ids)
1095+
)
10171096
if not len(chain_ids) == chains:
10181097
raise ValueError(
10191098
'Chain_ids must correspond to number of chains'
@@ -1029,6 +1108,7 @@ def sample(
10291108
)
10301109

10311110
sampler_args = SamplerArgs(
1111+
num_chains=1 if one_process_per_chain else chains,
10321112
iter_warmup=iter_warmup,
10331113
iter_sampling=iter_sampling,
10341114
save_warmup=save_warmup,
@@ -1043,14 +1123,25 @@ def sample(
10431123
adapt_step_size=adapt_step_size,
10441124
fixed_param=fixed_param,
10451125
)
1046-
with MaybeDictToFilePath(data, inits) as (_data, _inits):
1126+
1127+
with temp_single_json(data) as _data, temp_inits(
1128+
inits, id=chain_ids[0]
1129+
) as _inits:
1130+
cmdstan_inits: Union[str, List[str], int, float, None]
1131+
if one_process_per_chain and isinstance(inits, list): # legacy
1132+
cmdstan_inits = [
1133+
f"{_inits[:-5]}_{i}.json" for i in chain_ids # type: ignore
1134+
]
1135+
else:
1136+
cmdstan_inits = _inits
1137+
10471138
args = CmdStanArgs(
10481139
self._name,
10491140
self._exe_file,
10501141
chain_ids=chain_ids,
10511142
data=_data,
10521143
seed=seed,
1053-
inits=_inits,
1144+
inits=cmdstan_inits,
10541145
output_dir=output_dir,
10551146
sig_figs=sig_figs,
10561147
save_latent_dynamics=save_latent_dynamics,
@@ -1059,67 +1150,6 @@ def sample(
10591150
refresh=refresh,
10601151
)
10611152

1062-
if parallel_chains is None:
1063-
parallel_chains = max(min(cpu_count(), chains), 1)
1064-
elif parallel_chains > chains:
1065-
get_logger().info(
1066-
'Requested %u parallel_chains but only %u required, '
1067-
'will run all chains in parallel.',
1068-
parallel_chains,
1069-
chains,
1070-
)
1071-
parallel_chains = chains
1072-
elif parallel_chains < 1:
1073-
raise ValueError(
1074-
'Argument parallel_chains must be a positive integer, '
1075-
'found {}.'.format(parallel_chains)
1076-
)
1077-
if threads_per_chain is None:
1078-
threads_per_chain = 1
1079-
if threads_per_chain < 1:
1080-
raise ValueError(
1081-
'Argument threads_per_chain must be a positive integer, '
1082-
'found {}.'.format(threads_per_chain)
1083-
)
1084-
1085-
parallel_procs = parallel_chains
1086-
num_threads = threads_per_chain
1087-
one_process_per_chain = True
1088-
info_dict = self.exe_info()
1089-
stan_threads = info_dict.get('STAN_THREADS', 'false').lower()
1090-
if chains == 1:
1091-
force_one_process_per_chain = True
1092-
1093-
if (
1094-
force_one_process_per_chain is None
1095-
and not cmdstan_version_before(2, 28, info_dict)
1096-
and stan_threads == 'true'
1097-
):
1098-
one_process_per_chain = False
1099-
num_threads = parallel_chains * num_threads
1100-
parallel_procs = 1
1101-
if force_one_process_per_chain is False:
1102-
if not cmdstan_version_before(2, 28, info_dict):
1103-
one_process_per_chain = False
1104-
num_threads = parallel_chains * num_threads
1105-
parallel_procs = 1
1106-
if stan_threads == 'false':
1107-
get_logger().warning(
1108-
'Stan program not compiled for threading, '
1109-
'process will run chains sequentially. '
1110-
'For multi-chain parallelization, recompile '
1111-
'the model with argument '
1112-
'"cpp_options={\'STAN_THREADS\':\'TRUE\'}.'
1113-
)
1114-
else:
1115-
get_logger().warning(
1116-
'Installed version of CmdStan cannot multi-process '
1117-
'chains, will run %d processes. '
1118-
'Run "install_cmdstan" to upgrade to latest version.',
1119-
chains,
1120-
)
1121-
os.environ['STAN_NUM_THREADS'] = str(num_threads)
1122-
11231153
if show_console:
11241154
show_progress = False
11251155
else:
@@ -1359,7 +1389,7 @@ def generate_quantities(
13591389
csv_files=fit_csv_files
13601390
)
13611391
generate_quantities_args.validate(chains)
1362-
with MaybeDictToFilePath(data, None) as (_data, _inits):
1392+
with temp_single_json(data) as _data:
13631393
args = CmdStanArgs(
13641394
self._name,
13651395
self._exe_file,
@@ -1534,7 +1564,9 @@ def variational(
15341564
output_samples=output_samples,
15351565
)
15361566

1537-
with MaybeDictToFilePath(data, inits) as (_data, _inits):
1567+
with temp_single_json(data) as _data, temp_inits(
1568+
inits, allow_multiple=False
1569+
) as _inits:
15381570
args = CmdStanArgs(
15391571
self._name,
15401572
self._exe_file,
@@ -1641,7 +1673,9 @@ def log_prob(
16411673
"Method 'log_prob' not available for CmdStan versions "
16421674
"before 2.31"
16431675
)
1644-
with MaybeDictToFilePath(data, params) as (_data, _params):
1676+
with temp_single_json(data) as _data, temp_single_json(
1677+
params
1678+
) as _params:
16451679
cmd = [
16461680
str(self.exe_file),
16471681
"log_prob",
@@ -1749,7 +1783,7 @@ def laplace_sample(
17491783
cmdstan_mode.runset.csv_files[0], draws, jacobian
17501784
)
17511785

1752-
with MaybeDictToFilePath(data) as (_data,):
1786+
with temp_single_json(data) as _data:
17531787
args = CmdStanArgs(
17541788
self._name,
17551789
self._exe_file,

cmdstanpy/stanfit/runset.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -179,7 +179,6 @@ def cmd(self, idx: int) -> List[str]:
179179
profile_file=self.file_path(".csv", extra="-profile")
180180
if self._args.save_profile
181181
else None,
182-
num_chains=self._chains,
183182
)
184183

185184
@property

cmdstanpy/utils/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -22,7 +22,6 @@
2222
from .command import do_command, returncode_msg
2323
from .data_munging import build_xarray_data, flatten_chains
2424
from .filesystem import (
25-
MaybeDictToFilePath,
2625
SanitizedOrTmpFilePath,
2726
create_named_text_file,
2827
pushd,
@@ -112,7 +111,6 @@ def show_versions(output: bool = True) -> str:
112111

113112
__all__ = [
114113
'EXTENSION',
115-
'MaybeDictToFilePath',
116114
'SanitizedOrTmpFilePath',
117115
'build_xarray_data',
118116
'check_sampler_csv',

0 commit comments

Comments
 (0)