Skip to content

Commit 7be7dc0

Browse files
committed
changes per code review, more unit tests
1 parent 5a1b9c6 commit 7be7dc0

4 files changed

Lines changed: 145 additions & 86 deletions

File tree

cmdstanpy/model.py

Lines changed: 89 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import sys
1010
from collections import OrderedDict
1111
from concurrent.futures import ThreadPoolExecutor
12+
from io import StringIO
1213
from multiprocessing import cpu_count
1314
from pathlib import Path
1415
from typing import Any, Callable, Dict, List, Mapping, Optional, Union
@@ -40,7 +41,6 @@
4041
cmdstan_version_before,
4142
do_command,
4243
get_logger,
43-
model_info,
4444
returncode_msg,
4545
)
4646

@@ -430,6 +430,29 @@ def compile(
430430
else:
431431
get_logger().error('model compilation failed')
432432

433+
def exe_info(self) -> Optional[Dict[str, str]]:
434+
"""
435+
Run model with option 'info'. Parse output statements, which all
436+
have form 'key = value' into a Dict.
437+
If exe file compiled with CmdStan < 2.27, calling model with
438+
option 'info' fail and method returns None.
439+
"""
440+
if self.exe_file is None:
441+
return None
442+
try:
443+
info = StringIO()
444+
do_command(cmd=[self.exe_file, 'info'], fd_out=info)
445+
result: Dict[str, Any] = {}
446+
lines = info.getvalue().split('\n')
447+
for line in lines:
448+
kv_pair = [x.strip() for x in line.split('=')]
449+
if len(kv_pair) != 2:
450+
continue
451+
result[kv_pair[0]] = kv_pair[1]
452+
return result
453+
except RuntimeError:
454+
return None
455+
433456
def optimize(
434457
self,
435458
data: Union[Mapping[str, Any], str, None] = None,
@@ -660,13 +683,20 @@ def sample(
660683
:param chains: Number of sampler chains, must be a positive integer.
661684
662685
:param parallel_chains: Number of processes to run in parallel. Must be
663-
a positive integer. Defaults to :func:`multiprocessing.cpu_count`.
686+
a positive integer. Defaults to :func:`multiprocessing.cpu_count`,
687+
i.e., it will only run as many chains in parallel as there are
688+
cores on the machine. Note that CmdStan 2.28 and higher can run
689+
all chains in parallel providing that the model was compiled with
690+
threading support.
664691
665692
:param threads_per_chain: The number of threads to use in parallelized
666693
sections within an MCMC chain (e.g., when using the Stan functions
667694
``reduce_sum()`` or ``map_rect()``). This will only have an effect
668-
if the model was compiled with threading support. The total number
669-
of threads used will be ``parallel_chains * threads_per_chain``.
695+
if the model was compiled with threading support. For such models,
696+
CmdStan version 2.28 and higher will run all chains in parallel
697+
from within a single process. The total number of threads used
698+
will be ``parallel_chains * threads_per_chain``, where the default
699+
value for parallel_chains is the number of cpus, not chains.
670700
671701
:param seed: The seed for random number generator. Must be an integer
672702
between 0 and 2^32 - 1. If unspecified,
@@ -916,7 +946,7 @@ def sample(
916946
and not cmdstan_version_before(2, 28)
917947
):
918948
assert isinstance(self.exe_file, str) # make typechecker happy
919-
info_dict = model_info(self.exe_file)
949+
info_dict = self.exe_info()
920950
if (
921951
info_dict is not None
922952
and info_dict['STAN_THREADS'] == 'true'
@@ -936,26 +966,33 @@ def sample(
936966
)
937967
os.environ['STAN_NUM_THREADS'] = str(num_threads)
938968

939-
# progress reporting
940-
iter_total = 0
941-
if iter_warmup is None:
942-
iter_total += _CMDSTAN_WARMUP
943-
else:
944-
iter_total += iter_warmup
945-
if iter_sampling is None:
946-
iter_total += _CMDSTAN_SAMPLING
947-
else:
948-
iter_total += iter_sampling
949-
if refresh is None:
950-
refresh = _CMDSTAN_REFRESH
951-
iter_total = iter_total // refresh + 2
952-
953969
if show_console:
954970
show_progress = False
955971
else:
956972
show_progress = show_progress and progbar.allow_show_progress()
957973
get_logger().info('CmdStan start procesing')
958974

975+
progress_hook: Optional[Callable[[str, int], None]] = None
976+
if show_progress:
977+
iter_total = 0
978+
if iter_warmup is None:
979+
iter_total += _CMDSTAN_WARMUP
980+
else:
981+
iter_total += iter_warmup
982+
if iter_sampling is None:
983+
iter_total += _CMDSTAN_SAMPLING
984+
else:
985+
iter_total += iter_sampling
986+
if refresh is None:
987+
refresh = _CMDSTAN_REFRESH
988+
iter_total = iter_total // refresh + 2
989+
990+
progress_hook = self._wrap_sampler_progress_hook(
991+
one_process_per_chain=one_process_per_chain,
992+
chains=chains,
993+
offset=chain_ids[0],
994+
total=iter_total,
995+
)
959996
runset = RunSet(
960997
args=args,
961998
chains=chains,
@@ -971,7 +1008,7 @@ def sample(
9711008
idx=i,
9721009
show_progress=show_progress,
9731010
show_console=show_console,
974-
iter_total=iter_total,
1011+
progress_hook=progress_hook,
9751012
)
9761013
if show_progress:
9771014
# advance terminal window cursor past progress bars
@@ -1331,13 +1368,14 @@ def variational(
13311368
vb = CmdStanVB(runset)
13321369
return vb
13331370

1371+
# pylint: disable=no-self-use
13341372
def _run_cmdstan(
13351373
self,
13361374
runset: RunSet,
13371375
idx: int,
13381376
show_progress: bool = False,
13391377
show_console: bool = False,
1340-
iter_total: int = 0,
1378+
progress_hook: Optional[Callable[[str, int], None]] = None,
13411379
) -> None:
13421380
"""
13431381
Helper function which encapsulates call to CmdStan.
@@ -1357,15 +1395,6 @@ def _run_cmdstan(
13571395
logger_prefix = 'Chain [{}]'.format(idx + 1)
13581396
console_prefix = 'Chain [{}] '.format(idx + 1)
13591397

1360-
progress_hook: Optional[Callable[[str], None]] = None
1361-
if show_progress:
1362-
progress_hook = self._wrap_sampler_progress_hook(
1363-
one_process_per_chain=runset.one_process_per_chain,
1364-
num_chains=runset.chains,
1365-
chain_id=idx + 1,
1366-
id_offset=runset.chain_ids[0],
1367-
total=iter_total,
1368-
)
13691398
cmd = runset.cmd(idx)
13701399
get_logger().debug('CmdStan args: %s', cmd)
13711400

@@ -1390,10 +1419,10 @@ def _run_cmdstan(
13901419
if show_console:
13911420
print(f'{console_prefix}{line}')
13921421
elif progress_hook is not None:
1393-
progress_hook(line)
1422+
progress_hook(line, idx)
13941423

13951424
if progress_hook is not None and proc.returncode == 0:
1396-
progress_hook("Done")
1425+
progress_hook("Done", idx)
13971426

13981427
stdout, _ = proc.communicate()
13991428
if stdout:
@@ -1437,52 +1466,52 @@ def _run_cmdstan(
14371466
@progbar.wrap_callback
14381467
def _wrap_sampler_progress_hook(
14391468
one_process_per_chain: bool,
1440-
num_chains: int,
1441-
chain_id: int,
1442-
id_offset: int,
1469+
chains: int,
1470+
offset: int,
14431471
total: int,
1444-
) -> Optional[Callable[[str], None]]:
1472+
) -> Optional[Callable[[str, int], None]]:
14451473
"""
14461474
Sets up tqdm callback for CmdStan sampler console msgs.
14471475
CmdStan progress messages start with "Iteration", for single chain
14481476
process, "Chain [id] Iteration" for multi-chain processing.
14491477
For the latter, manage array of pbars, update accordingly.
14501478
"""
1451-
multi_pbars = num_chains > 1 and not one_process_per_chain
1452-
num_pbars = num_chains if multi_pbars else 1
1479+
do_match = chains > 1 and not one_process_per_chain
14531480
pat = re.compile(r'Chain \[(\d*)\] (Iteration.*)')
14541481

1455-
pbar: List[tqdm] = [
1482+
pbars: List[tqdm] = [
14561483
tqdm(
14571484
total=total,
14581485
bar_format="{desc} |{bar}| {elapsed} {postfix[0][value]}",
14591486
postfix=[dict(value="Status")],
1460-
desc=f'chain {chain_id + i}',
1487+
desc=f'chain {offset + i}',
14611488
colour='yellow',
14621489
)
1463-
for i in range(num_pbars)
1490+
for i in range(chains)
14641491
]
14651492

1466-
def progress_hook(line: str) -> None:
1493+
def progress_hook(line: str, idx: int) -> None:
14671494
if line == "Done":
1468-
for i in range(num_pbars):
1469-
pbar[i].postfix[0]["value"] = 'Sampling completed'
1470-
pbar[i].update(total - pbar[i].n)
1471-
pbar[i].close()
1472-
elif multi_pbars:
1473-
match = pat.match(line)
1474-
if match:
1475-
idx = int(match.group(1)) - id_offset
1476-
pline = match.group(2).strip()
1477-
if 'Sampling' in pline:
1478-
pbar[idx].colour = 'blue'
1479-
pbar[idx].update(1)
1480-
pbar[idx].postfix[0]['value'] = pline
1495+
for i in range(chains):
1496+
pbars[i].postfix[0]["value"] = 'Sampling completed'
1497+
pbars[i].update(total - pbars[i].n)
1498+
pbars[i].close()
14811499
else:
1482-
if line.startswith("Iteration"):
1483-
if 'Sampling' in line:
1484-
pbar[0].colour = 'blue'
1485-
pbar[0].update(1)
1486-
pbar[0].postfix[0]["value"] = line
1500+
if do_match:
1501+
match = pat.match(line)
1502+
if match:
1503+
idx = int(match.group(1)) - offset
1504+
mline = match.group(2).strip()
1505+
else:
1506+
return
1507+
else:
1508+
if line.startswith("Iteration"):
1509+
mline = line
1510+
else:
1511+
return
1512+
if 'Sampling' in mline:
1513+
pbars[idx].colour = 'blue'
1514+
pbars[idx].update(1)
1515+
pbars[idx].postfix[0]["value"] = mline
14871516

14881517
return progress_hook

cmdstanpy/utils.py

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
import tempfile
1515
from collections import OrderedDict
1616
from collections.abc import Collection
17-
from io import StringIO
1817
from typing import (
1918
Any,
2019
Callable,
@@ -226,28 +225,6 @@ def cmdstan_version() -> Optional[Tuple[int, ...]]:
226225
return tuple(int(x) for x in splits[0:2])
227226

228227

229-
def model_info(model_exe: str) -> Optional[Dict[str, str]]:
230-
"""
231-
Run model with option 'info'. Parse output statements, which all
232-
have form 'key = value' into a Dict.
233-
If exe file compiled with CmdStan < 2.27, calling model with
234-
option 'info' fail and method returns None.
235-
"""
236-
try:
237-
info = StringIO()
238-
do_command(cmd=[model_exe, 'info'], fd_out=info)
239-
result: Dict[str, Any] = {}
240-
lines = info.getvalue().split('\n')
241-
for line in lines:
242-
kv_pair = [x.strip() for x in line.split('=')]
243-
if len(kv_pair) != 2:
244-
continue
245-
result[kv_pair[0]] = kv_pair[1]
246-
return result
247-
except RuntimeError:
248-
return None
249-
250-
251228
def cmdstan_version_before(major: int, minor: int) -> bool:
252229
"""
253230
Check that CmdStan version is less than Major.minor version.

test/test_model.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,24 @@ def test_cpp_options(self):
152152
self.assertEqual(cpp_opts['STAN_MPI'], 'TRUE')
153153
self.assertEqual(cpp_opts['STAN_THREADS'], 'TRUE')
154154

155+
def test_model_info(self):
156+
# copy so that parallel compile tests don't mess up exe
157+
b2_filename = os.path.join(DATAFILES_PATH, 'b2.stan')
158+
b2_file = shutil.copyfile(BERN_STAN, b2_filename)
159+
cpp_opts = {'STAN_THREADS': 'TRUE'}
160+
model = CmdStanModel(stan_file=b2_file, cpp_options=cpp_opts)
161+
if model.exe_file is not None and os.path.exists(model.exe_file):
162+
os.remove(model.exe_file)
163+
null_dict = model.exe_info()
164+
self.assertTrue(null_dict is None)
165+
166+
model.compile(force=True)
167+
info_dict = model.exe_info()
168+
self.assertTrue(info_dict is not None)
169+
self.assertEqual(info_dict['STAN_THREADS'], 'true')
170+
os.remove(model.stan_file)
171+
os.remove(model.exe_file)
172+
155173
def test_model_paths(self):
156174
# pylint: disable=unused-variable
157175
model = CmdStanModel(stan_file=BERN_STAN) # instantiates exe

test/test_sample.py

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from cmdstanpy.cmdstan_args import CmdStanArgs, Method, SamplerArgs
2727
from cmdstanpy.model import CmdStanModel
2828
from cmdstanpy.stanfit import CmdStanMCMC, RunSet, from_csv
29-
from cmdstanpy.utils import EXTENSION, cmdstan_version_before, model_info
29+
from cmdstanpy.utils import EXTENSION, cmdstan_version_before
3030

3131
HERE = os.path.dirname(os.path.abspath(__file__))
3232
DATAFILES_PATH = os.path.join(HERE, 'data')
@@ -447,9 +447,10 @@ def test_multi_proc_threads(self):
447447
force=True,
448448
cpp_options={'STAN_THREADS': 'TRUE'},
449449
)
450-
info_dict = model_info(logistic_model.exe_file)
450+
info_dict = logistic_model.exe_info()
451451
self.assertTrue(info_dict is not None)
452452
self.assertTrue('STAN_THREADS' in info_dict)
453+
self.assertEqual(info_dict['STAN_THREADS'], 'true')
453454

454455
logistic_data = os.path.join(DATAFILES_PATH, 'logistic.data.R')
455456
with LogCapture() as log:
@@ -649,13 +650,47 @@ def test_show_progress(self, stanfile='bernoulli.stan'):
649650
sys_stderr = io.StringIO() # tqdm prints to stderr
650651
with contextlib.redirect_stderr(sys_stderr):
651652
bern_model.sample(
652-
data=jdata, chains=2, parallel_chains=2, show_progress=True
653+
data=jdata,
654+
chains=2,
655+
iter_warmup=100,
656+
iter_sampling=100,
657+
show_progress=True,
653658
)
654659
console = sys_stderr.getvalue()
655660
self.assertTrue('chain 1' in console)
656661
self.assertTrue('chain 2' in console)
657662
self.assertTrue('Sampling completed' in console)
658663

664+
sys_stderr = io.StringIO() # tqdm prints to stderr
665+
with contextlib.redirect_stderr(sys_stderr):
666+
bern_model.sample(
667+
data=jdata,
668+
chains=7,
669+
iter_warmup=100,
670+
iter_sampling=100,
671+
show_progress=True,
672+
)
673+
console = sys_stderr.getvalue()
674+
self.assertTrue('chain 6' in console)
675+
self.assertTrue('chain 7' in console)
676+
self.assertTrue('Sampling completed' in console)
677+
sys_stderr = io.StringIO() # tqdm prints to stderr
678+
679+
with contextlib.redirect_stderr(sys_stderr):
680+
bern_model.sample(
681+
data=jdata,
682+
chains=2,
683+
chain_ids=[6,7],
684+
iter_warmup=100,
685+
iter_sampling=100,
686+
force_one_process_per_chain=True,
687+
show_progress=True,
688+
)
689+
console = sys_stderr.getvalue()
690+
self.assertTrue('chain 6' in console)
691+
self.assertTrue('chain 7' in console)
692+
self.assertTrue('Sampling completed' in console)
693+
659694

660695
class CmdStanMCMCTest(unittest.TestCase):
661696
# pylint: disable=too-many-public-methods

0 commit comments

Comments
 (0)