Skip to content

Commit 79c1df6

Browse files
committed
changes per code review
1 parent 42b51dd commit 79c1df6

2 files changed

Lines changed: 21 additions & 29 deletions

File tree

cmdstanpy/model.py

Lines changed: 16 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -943,39 +943,35 @@ def sample(
943943
one_process_per_chain = True
944944
assert isinstance(self.exe_file, str) # make typechecker happy
945945
info_dict = self.exe_info()
946-
stan_threads = info_dict.get('STAN_THREADS')
947-
if stan_threads is not None:
948-
stan_threads = stan_threads.lower()
946+
stan_threads = info_dict.get('STAN_THREADS', 'false').lower()
949947
if (
950948
force_one_process_per_chain is None
951949
and not cmdstan_version_before(2, 28, info_dict)
952-
and stan_threads is not None
953950
and stan_threads == 'true'
954951
):
955952
one_process_per_chain = False
956953
num_threads = parallel_chains * num_threads
957954
parallel_procs = 1
958955
if force_one_process_per_chain is False:
959-
if cmdstan_version_before(2, 28, info_dict):
956+
if not cmdstan_version_before(2, 28, info_dict):
957+
one_process_per_chain = False
958+
num_threads = parallel_chains * num_threads
959+
parallel_procs = 1
960+
if stan_threads == 'false':
961+
get_logger().warning(
962+
'Stan program not compiled for threading, '
963+
'process will run chains sequentially. '
964+
'For multi-chain parallelization, recompile '
965+
'the model with argument '
966+
'"cpp_options={\'STAN_THREADS\':\'TRUE\'}.'
967+
)
968+
else:
960969
get_logger().warning(
961970
'Installed version of CmdStan cannot multi-process '
962971
'chains, will run %d processes. '
963972
'Run "install_cmdstan" to upgrade to latest version.',
964973
chains,
965974
)
966-
elif stan_threads is None or stan_threads == 'false':
967-
get_logger().warning(
968-
'Stan program not compiled for threading, '
969-
'will run %d processes. '
970-
'Recompile model and specify argument '
971-
'"cpp_options={\'STAN_THREADS\':\'TRUE\'}.',
972-
chains,
973-
)
974-
else:
975-
one_process_per_chain = False
976-
num_threads = parallel_chains * num_threads
977-
parallel_procs = 1
978-
979975
os.environ['STAN_NUM_THREADS'] = str(num_threads)
980976

981977
if show_console:
@@ -1405,8 +1401,8 @@ def _run_cmdstan(
14051401
logger_prefix = 'CmdStan'
14061402
console_prefix = ''
14071403
if runset.one_process_per_chain:
1408-
logger_prefix = 'Chain [{}]'.format(idx + runset.chain_ids[0])
1409-
console_prefix = 'Chain [{}] '.format(idx + runset.chain_ids[0])
1404+
logger_prefix = 'Chain [{}]'.format(runset.chain_ids[idx])
1405+
console_prefix = 'Chain [{}] '.format(runset.chain_ids[idx])
14101406

14111407
cmd = runset.cmd(idx)
14121408
get_logger().debug('CmdStan args: %s', cmd)

cmdstanpy/utils.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -237,17 +237,13 @@ def cmdstan_version_before(
237237
:return: True if version at or above major.minor, else False.
238238
"""
239239
cur_version = None
240-
if info is None:
240+
if info is None or 'stan_version_major' not in info:
241241
cur_version = cmdstan_version()
242242
else:
243-
if (
244-
info['stan_version_major'] is not None
245-
and info['stan_version_minor'] is not None
246-
):
247-
cur_version = (
248-
int(info['stan_version_major']),
249-
int(info['stan_version_minor']),
250-
)
243+
cur_version = (
244+
int(info['stan_version_major']),
245+
int(info['stan_version_minor']),
246+
)
251247
if cur_version is None:
252248
get_logger().info(
253249
'Cannot determine whether version is before %d.%d.', major, minor

0 commit comments

Comments
 (0)