99import sys
1010from collections import OrderedDict
1111from concurrent .futures import ThreadPoolExecutor
12+ from io import StringIO
1213from multiprocessing import cpu_count
1314from pathlib import Path
1415from typing import Any , Callable , Dict , List , Mapping , Optional , Union
4041 cmdstan_version_before ,
4142 do_command ,
4243 get_logger ,
43- model_info ,
4444 returncode_msg ,
4545)
4646
@@ -430,6 +430,29 @@ def compile(
430430 else :
431431 get_logger ().error ('model compilation failed' )
432432
433+ def exe_info (self ) -> Optional [Dict [str , str ]]:
434+ """
435+ Run model with option 'info'. Parse output statements, which all
436+ have form 'key = value' into a Dict.
437+ If exe file compiled with CmdStan < 2.27, calling model with
438+ option 'info' fail and method returns None.
439+ """
440+ if self .exe_file is None :
441+ return None
442+ try :
443+ info = StringIO ()
444+ do_command (cmd = [self .exe_file , 'info' ], fd_out = info )
445+ result : Dict [str , Any ] = {}
446+ lines = info .getvalue ().split ('\n ' )
447+ for line in lines :
448+ kv_pair = [x .strip () for x in line .split ('=' )]
449+ if len (kv_pair ) != 2 :
450+ continue
451+ result [kv_pair [0 ]] = kv_pair [1 ]
452+ return result
453+ except RuntimeError :
454+ return None
455+
433456 def optimize (
434457 self ,
435458 data : Union [Mapping [str , Any ], str , None ] = None ,
@@ -660,13 +683,20 @@ def sample(
660683 :param chains: Number of sampler chains, must be a positive integer.
661684
662685 :param parallel_chains: Number of processes to run in parallel. Must be
663- a positive integer. Defaults to :func:`multiprocessing.cpu_count`.
686+ a positive integer. Defaults to :func:`multiprocessing.cpu_count`,
687+ i.e., it will only run as many chains in parallel as there are
688+ cores on the machine. Note that CmdStan 2.28 and higher can run
689+ all chains in parallel providing that the model was compiled with
690+ threading support.
664691
665692 :param threads_per_chain: The number of threads to use in parallelized
666693 sections within an MCMC chain (e.g., when using the Stan functions
667694 ``reduce_sum()`` or ``map_rect()``). This will only have an effect
668- if the model was compiled with threading support. The total number
669- of threads used will be ``parallel_chains * threads_per_chain``.
695+ if the model was compiled with threading support. For such models,
696+ CmdStan version 2.28 and higher will run all chains in parallel
697+ from within a single process. The total number of threads used
698+ will be ``parallel_chains * threads_per_chain``, where the default
699+ value for parallel_chains is the number of cpus, not chains.
670700
671701 :param seed: The seed for random number generator. Must be an integer
672702 between 0 and 2^32 - 1. If unspecified,
@@ -916,7 +946,7 @@ def sample(
916946 and not cmdstan_version_before (2 , 28 )
917947 ):
918948 assert isinstance (self .exe_file , str ) # make typechecker happy
919- info_dict = model_info ( self .exe_file )
949+ info_dict = self .exe_info ( )
920950 if (
921951 info_dict is not None
922952 and info_dict ['STAN_THREADS' ] == 'true'
@@ -936,26 +966,33 @@ def sample(
936966 )
937967 os .environ ['STAN_NUM_THREADS' ] = str (num_threads )
938968
939- # progress reporting
940- iter_total = 0
941- if iter_warmup is None :
942- iter_total += _CMDSTAN_WARMUP
943- else :
944- iter_total += iter_warmup
945- if iter_sampling is None :
946- iter_total += _CMDSTAN_SAMPLING
947- else :
948- iter_total += iter_sampling
949- if refresh is None :
950- refresh = _CMDSTAN_REFRESH
951- iter_total = iter_total // refresh + 2
952-
953969 if show_console :
954970 show_progress = False
955971 else :
956972 show_progress = show_progress and progbar .allow_show_progress ()
957973 get_logger ().info ('CmdStan start procesing' )
958974
975+ progress_hook : Optional [Callable [[str , int ], None ]] = None
976+ if show_progress :
977+ iter_total = 0
978+ if iter_warmup is None :
979+ iter_total += _CMDSTAN_WARMUP
980+ else :
981+ iter_total += iter_warmup
982+ if iter_sampling is None :
983+ iter_total += _CMDSTAN_SAMPLING
984+ else :
985+ iter_total += iter_sampling
986+ if refresh is None :
987+ refresh = _CMDSTAN_REFRESH
988+ iter_total = iter_total // refresh + 2
989+
990+ progress_hook = self ._wrap_sampler_progress_hook (
991+ one_process_per_chain = one_process_per_chain ,
992+ chains = chains ,
993+ offset = chain_ids [0 ],
994+ total = iter_total ,
995+ )
959996 runset = RunSet (
960997 args = args ,
961998 chains = chains ,
@@ -971,7 +1008,7 @@ def sample(
9711008 idx = i ,
9721009 show_progress = show_progress ,
9731010 show_console = show_console ,
974- iter_total = iter_total ,
1011+ progress_hook = progress_hook ,
9751012 )
9761013 if show_progress :
9771014 # advance terminal window cursor past progress bars
@@ -1331,13 +1368,14 @@ def variational(
13311368 vb = CmdStanVB (runset )
13321369 return vb
13331370
1371+ # pylint: disable=no-self-use
13341372 def _run_cmdstan (
13351373 self ,
13361374 runset : RunSet ,
13371375 idx : int ,
13381376 show_progress : bool = False ,
13391377 show_console : bool = False ,
1340- iter_total : int = 0 ,
1378+ progress_hook : Optional [ Callable [[ str , int ], None ]] = None ,
13411379 ) -> None :
13421380 """
13431381 Helper function which encapsulates call to CmdStan.
@@ -1357,15 +1395,6 @@ def _run_cmdstan(
13571395 logger_prefix = 'Chain [{}]' .format (idx + 1 )
13581396 console_prefix = 'Chain [{}] ' .format (idx + 1 )
13591397
1360- progress_hook : Optional [Callable [[str ], None ]] = None
1361- if show_progress :
1362- progress_hook = self ._wrap_sampler_progress_hook (
1363- one_process_per_chain = runset .one_process_per_chain ,
1364- num_chains = runset .chains ,
1365- chain_id = idx + 1 ,
1366- id_offset = runset .chain_ids [0 ],
1367- total = iter_total ,
1368- )
13691398 cmd = runset .cmd (idx )
13701399 get_logger ().debug ('CmdStan args: %s' , cmd )
13711400
@@ -1390,10 +1419,10 @@ def _run_cmdstan(
13901419 if show_console :
13911420 print (f'{ console_prefix } { line } ' )
13921421 elif progress_hook is not None :
1393- progress_hook (line )
1422+ progress_hook (line , idx )
13941423
13951424 if progress_hook is not None and proc .returncode == 0 :
1396- progress_hook ("Done" )
1425+ progress_hook ("Done" , idx )
13971426
13981427 stdout , _ = proc .communicate ()
13991428 if stdout :
@@ -1437,52 +1466,52 @@ def _run_cmdstan(
14371466 @progbar .wrap_callback
14381467 def _wrap_sampler_progress_hook (
14391468 one_process_per_chain : bool ,
1440- num_chains : int ,
1441- chain_id : int ,
1442- id_offset : int ,
1469+ chains : int ,
1470+ offset : int ,
14431471 total : int ,
1444- ) -> Optional [Callable [[str ], None ]]:
1472+ ) -> Optional [Callable [[str , int ], None ]]:
14451473 """
14461474 Sets up tqdm callback for CmdStan sampler console msgs.
14471475 CmdStan progress messages start with "Iteration", for single chain
14481476 process, "Chain [id] Iteration" for multi-chain processing.
14491477 For the latter, manage array of pbars, update accordingly.
14501478 """
1451- multi_pbars = num_chains > 1 and not one_process_per_chain
1452- num_pbars = num_chains if multi_pbars else 1
1479+ do_match = chains > 1 and not one_process_per_chain
14531480 pat = re .compile (r'Chain \[(\d*)\] (Iteration.*)' )
14541481
1455- pbar : List [tqdm ] = [
1482+ pbars : List [tqdm ] = [
14561483 tqdm (
14571484 total = total ,
14581485 bar_format = "{desc} |{bar}| {elapsed} {postfix[0][value]}" ,
14591486 postfix = [dict (value = "Status" )],
1460- desc = f'chain { chain_id + i } ' ,
1487+ desc = f'chain { offset + i } ' ,
14611488 colour = 'yellow' ,
14621489 )
1463- for i in range (num_pbars )
1490+ for i in range (chains )
14641491 ]
14651492
1466- def progress_hook (line : str ) -> None :
1493+ def progress_hook (line : str , idx : int ) -> None :
14671494 if line == "Done" :
1468- for i in range (num_pbars ):
1469- pbar [i ].postfix [0 ]["value" ] = 'Sampling completed'
1470- pbar [i ].update (total - pbar [i ].n )
1471- pbar [i ].close ()
1472- elif multi_pbars :
1473- match = pat .match (line )
1474- if match :
1475- idx = int (match .group (1 )) - id_offset
1476- pline = match .group (2 ).strip ()
1477- if 'Sampling' in pline :
1478- pbar [idx ].colour = 'blue'
1479- pbar [idx ].update (1 )
1480- pbar [idx ].postfix [0 ]['value' ] = pline
1495+ for i in range (chains ):
1496+ pbars [i ].postfix [0 ]["value" ] = 'Sampling completed'
1497+ pbars [i ].update (total - pbars [i ].n )
1498+ pbars [i ].close ()
14811499 else :
1482- if line .startswith ("Iteration" ):
1483- if 'Sampling' in line :
1484- pbar [0 ].colour = 'blue'
1485- pbar [0 ].update (1 )
1486- pbar [0 ].postfix [0 ]["value" ] = line
1500+ if do_match :
1501+ match = pat .match (line )
1502+ if match :
1503+ idx = int (match .group (1 )) - offset
1504+ mline = match .group (2 ).strip ()
1505+ else :
1506+ return
1507+ else :
1508+ if line .startswith ("Iteration" ):
1509+ mline = line
1510+ else :
1511+ return
1512+ if 'Sampling' in mline :
1513+ pbars [idx ].colour = 'blue'
1514+ pbars [idx ].update (1 )
1515+ pbars [idx ].postfix [0 ]["value" ] = mline
14871516
14881517 return progress_hook
0 commit comments