Skip to content

Commit 2b1ff63

Browse files
committed
changes per code review
1 parent b63d3d4 commit 2b1ff63

5 files changed

Lines changed: 77 additions & 63 deletions

File tree

cmdstanpy/model.py

Lines changed: 45 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -434,8 +434,8 @@ def exe_info(self) -> Dict[str, str]:
434434
"""
435435
Run model with option 'info'. Parse output statements, which all
436436
have form 'key = value' into a Dict.
437-
If exe file compiled with CmdStan < 2.27, calling model with
438-
option 'info' fail and method returns None.
437+
If exe file compiled with CmdStan < 2.27, option 'info' isn't
438+
available and the method returns an empty dictionary.
439439
"""
440440
result: Dict[str, Any] = {}
441441
if self.exe_file is None:
@@ -941,19 +941,22 @@ def sample(
941941
parallel_procs = parallel_chains
942942
num_threads = threads_per_chain
943943
one_process_per_chain = True
944+
assert isinstance(self.exe_file, str) # make typechecker happy
945+
info_dict = self.exe_info()
946+
stan_threads = info_dict.get('STAN_THREADS')
947+
if stan_threads is not None:
948+
stan_threads = stan_threads.lower()
944949
if (
945950
force_one_process_per_chain is None
946-
and not cmdstan_version_before(2, 28)
951+
and not cmdstan_version_before(2, 28, info_dict)
952+
and stan_threads == 'true'
947953
):
948-
assert isinstance(self.exe_file, str) # make typechecker happy
949-
info_dict = self.exe_info()
950-
if info_dict.get('STAN_THREADS') == 'true':
951-
one_process_per_chain = False
952-
num_threads = parallel_chains * num_threads
953-
parallel_procs = 1
954+
one_process_per_chain = False
955+
num_threads = parallel_chains * num_threads
956+
parallel_procs = 1
954957
elif (
955958
force_one_process_per_chain is False
956-
and cmdstan_version_before(2, 28)
959+
and cmdstan_version_before(2, 28, info_dict)
957960
):
958961
get_logger().warning(
959962
'Installed version of CmdStan cannot multi-process chains, '
@@ -985,9 +988,7 @@ def sample(
985988
iter_total = iter_total // refresh + 2
986989

987990
progress_hook = self._wrap_sampler_progress_hook(
988-
one_process_per_chain=one_process_per_chain,
989-
chains=chains,
990-
offset=chain_ids[0],
991+
chain_ids=chain_ids,
991992
total=iter_total,
992993
)
993994
runset = RunSet(
@@ -1007,7 +1008,9 @@ def sample(
10071008
show_console=show_console,
10081009
progress_hook=progress_hook,
10091010
)
1010-
if show_progress:
1011+
if show_progress and progress_hook is not None:
1012+
progress_hook("Done", -1) # -1 == all chains finished
1013+
10111014
# advance terminal window cursor past progress bars
10121015
term_size: os.terminal_size = shutil.get_terminal_size(
10131016
fallback=(80, 24)
@@ -1381,6 +1384,7 @@ def _run_cmdstan(
13811384
Args 'show_progress' and 'show_console' allow use of progress bar,
13821385
streaming output to console, respectively.
13831386
"""
1387+
get_logger().debug('idx %d', idx)
13841388
get_logger().debug(
13851389
'running CmdStan, num_threads: %s',
13861390
str(os.environ.get('STAN_NUM_THREADS')),
@@ -1389,8 +1393,8 @@ def _run_cmdstan(
13891393
logger_prefix = 'CmdStan'
13901394
console_prefix = ''
13911395
if runset.one_process_per_chain:
1392-
logger_prefix = 'Chain [{}]'.format(idx + 1)
1393-
console_prefix = 'Chain [{}] '.format(idx + 1)
1396+
logger_prefix = 'Chain [{}]'.format(idx + runset.chain_ids[0])
1397+
console_prefix = 'Chain [{}] '.format(idx + runset.chain_ids[0])
13941398

13951399
cmd = runset.cmd(idx)
13961400
get_logger().debug('CmdStan args: %s', cmd)
@@ -1418,10 +1422,10 @@ def _run_cmdstan(
14181422
elif progress_hook is not None:
14191423
progress_hook(line, idx)
14201424

1421-
if progress_hook is not None and proc.returncode == 0:
1422-
progress_hook("Done", idx)
1423-
14241425
stdout, _ = proc.communicate()
1426+
retcode = proc.returncode
1427+
runset._set_retcode(idx, retcode)
1428+
14251429
if stdout:
14261430
fd_out.write(stdout)
14271431
if show_console:
@@ -1434,15 +1438,15 @@ def _run_cmdstan(
14341438
raise RuntimeError(msg) from e
14351439
finally:
14361440
fd_out.close()
1441+
14371442
if not show_progress:
14381443
get_logger().info('%s done processing', logger_prefix)
14391444

1440-
runset._set_retcode(idx, proc.returncode)
1441-
if proc.returncode != 0:
1442-
retcode_summary = returncode_msg(proc.returncode)
1445+
if retcode != 0:
1446+
retcode_summary = returncode_msg(retcode)
14431447
serror = ''
14441448
try:
1445-
serror = os.strerror(proc.returncode)
1449+
serror = os.strerror(retcode)
14461450
except (ArithmeticError, ValueError):
14471451
pass
14481452
get_logger().error(
@@ -1462,9 +1466,7 @@ def _run_cmdstan(
14621466
@staticmethod
14631467
@progbar.wrap_callback
14641468
def _wrap_sampler_progress_hook(
1465-
one_process_per_chain: bool,
1466-
chains: int,
1467-
offset: int,
1469+
chain_ids: List[int],
14681470
total: int,
14691471
) -> Optional[Callable[[str, int], None]]:
14701472
"""
@@ -1473,39 +1475,34 @@ def _wrap_sampler_progress_hook(
14731475
process, "Chain [id] Iteration" for multi-chain processing.
14741476
For the latter, manage array of pbars, update accordingly.
14751477
"""
1476-
do_match = chains > 1 and not one_process_per_chain
14771478
pat = re.compile(r'Chain \[(\d*)\] (Iteration.*)')
1478-
1479-
pbars: List[tqdm] = [
1480-
tqdm(
1479+
pbars: Dict[int, tqdm] = {
1480+
chain_id: tqdm(
14811481
total=total,
14821482
bar_format="{desc} |{bar}| {elapsed} {postfix[0][value]}",
14831483
postfix=[dict(value="Status")],
1484-
desc=f'chain {offset + i}',
1484+
desc=f'chain {chain_id}',
14851485
colour='yellow',
14861486
)
1487-
for i in range(chains)
1488-
]
1487+
for chain_id in chain_ids
1488+
}
14891489

14901490
def progress_hook(line: str, idx: int) -> None:
14911491
if line == "Done":
1492-
for i in range(chains):
1493-
pbars[i].postfix[0]["value"] = 'Sampling completed'
1494-
pbars[i].update(total - pbars[i].n)
1495-
pbars[i].close()
1492+
for pbar in pbars.values():
1493+
pbar.postfix[0]["value"] = 'Sampling completed'
1494+
pbar.update(total - pbar.n)
1495+
pbar.close()
14961496
else:
1497-
if do_match:
1498-
match = pat.match(line)
1499-
if match:
1500-
idx = int(match.group(1)) - offset
1501-
mline = match.group(2).strip()
1502-
else:
1503-
return
1497+
match = pat.match(line)
1498+
if match:
1499+
idx = int(match.group(1))
1500+
mline = match.group(2).strip()
1501+
elif line.startswith("Iteration"):
1502+
mline = line
1503+
idx = chain_ids[idx]
15041504
else:
1505-
if line.startswith("Iteration"):
1506-
mline = line
1507-
else:
1508-
return
1505+
return
15091506
if 'Sampling' in mline:
15101507
pbars[idx].colour = 'blue'
15111508
pbars[idx].update(1)

cmdstanpy/utils.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,9 @@ def cmdstan_version() -> Optional[Tuple[int, ...]]:
225225
return tuple(int(x) for x in splits[0:2])
226226

227227

228-
def cmdstan_version_before(major: int, minor: int) -> bool:
228+
def cmdstan_version_before(
229+
major: int, minor: int, info: Optional[Dict[str, str]] = None
230+
) -> bool:
229231
"""
230232
Check that CmdStan version is less than Major.minor version.
231233
@@ -234,7 +236,18 @@ def cmdstan_version_before(major: int, minor: int) -> bool:
234236
235237
:return: True if version at or above major.minor, else False.
236238
"""
237-
cur_version = cmdstan_version()
239+
cur_version = None
240+
if info is None:
241+
cur_version = cmdstan_version()
242+
else:
243+
if (
244+
info['stan_version_major'] is not None
245+
and info['stan_version_minor'] is not None
246+
):
247+
cur_version = (
248+
int(info['stan_version_major']),
249+
int(info['stan_version_minor']),
250+
)
238251
if cur_version is None:
239252
get_logger().info(
240253
'Cannot determine whether version is before %d.%d.', major, minor

docsrc/installation.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ can be used to override these defaults:
144144

145145
.. code-block:: bash
146146
147-
install_cmdstan -d my_local_cmdstan -v 2.20.0
147+
install_cmdstan -d my_local_cmdstan -v 2.27.0
148148
ls -F my_local_cmdstan
149149
150150
DIY Installation

docsrc/overview.rst

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,20 @@ It wraps the
99
command line interface in a small set of
1010
Python classes which provide methods to do analysis and manage the resulting
1111
set of model, data, and posterior estimates.
12-
13-
CmdStanPy is a lightweight interface in that it is designed to use minimal
14-
memory beyond what is used by CmdStan itself to do inference given
15-
and model and data.It runs and records an analysis, but the user chooses
16-
whether or not to instantiate the results in memory,
17-
thus CmdStanPy has the potential to fit more complex models
12+
It is lightweight in that it uses minimal
13+
memory beyond the memory used by CmdStan.
14+
CmdStanPy runs CmdStan, but only instantiates the resulting inference
15+
objects in memory upon request.
16+
Thus CmdStanPy has the potential to fit more complex models
1817
to larger datasets than might be possible in PyStan or RStan.
19-
It manages the set of CmdStan input and output files and provides
20-
methods and options which allow the user to save these files
21-
to a specific filepath.
22-
By default, CmdStan output files are written to a temporary directory
23-
in order to avoid filling up the user's filesystem.
18+
19+
CmdStan is a file-based interface.
20+
CmdStanPy manages the Stan program files and the CmdStan output files.
21+
The user can specify the output directory for the CmdStan outputs,
22+
otherwise the files will be written to a
23+
temporary filesystem which persists throughout the session.
24+
This allows the user to test and develop models prospectively,
25+
following the Bayesian workflow.
26+
27+
2428

test/test_sample.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -680,7 +680,7 @@ def test_show_progress(self, stanfile='bernoulli.stan'):
680680
bern_model.sample(
681681
data=jdata,
682682
chains=2,
683-
chain_ids=[6,7],
683+
chain_ids=[6, 7],
684684
iter_warmup=100,
685685
iter_sampling=100,
686686
force_one_process_per_chain=True,

0 commit comments

Comments
 (0)