@@ -434,8 +434,8 @@ def exe_info(self) -> Dict[str, str]:
434434 """
435435 Run model with option 'info'. Parse output statements, which all
436436 have form 'key = value' into a Dict.
437- If exe file compiled with CmdStan < 2.27, calling model with
438- option 'info' fail and method returns None .
437+ If exe file compiled with CmdStan < 2.27, option 'info' isn't
438+ available and the method returns an empty dictionary .
439439 """
440440 result : Dict [str , Any ] = {}
441441 if self .exe_file is None :
@@ -941,19 +941,22 @@ def sample(
941941 parallel_procs = parallel_chains
942942 num_threads = threads_per_chain
943943 one_process_per_chain = True
944+ assert isinstance (self .exe_file , str ) # make typechecker happy
945+ info_dict = self .exe_info ()
946+ stan_threads = info_dict .get ('STAN_THREADS' )
947+ if stan_threads is not None :
948+ stan_threads = stan_threads .lower ()
944949 if (
945950 force_one_process_per_chain is None
946- and not cmdstan_version_before (2 , 28 )
951+ and not cmdstan_version_before (2 , 28 , info_dict )
952+ and stan_threads == 'true'
947953 ):
948- assert isinstance (self .exe_file , str ) # make typechecker happy
949- info_dict = self .exe_info ()
950- if info_dict .get ('STAN_THREADS' ) == 'true' :
951- one_process_per_chain = False
952- num_threads = parallel_chains * num_threads
953- parallel_procs = 1
954+ one_process_per_chain = False
955+ num_threads = parallel_chains * num_threads
956+ parallel_procs = 1
954957 elif (
955958 force_one_process_per_chain is False
956- and cmdstan_version_before (2 , 28 )
959+ and cmdstan_version_before (2 , 28 , info_dict )
957960 ):
958961 get_logger ().warning (
959962 'Installed version of CmdStan cannot multi-process chains, '
@@ -985,9 +988,7 @@ def sample(
985988 iter_total = iter_total // refresh + 2
986989
987990 progress_hook = self ._wrap_sampler_progress_hook (
988- one_process_per_chain = one_process_per_chain ,
989- chains = chains ,
990- offset = chain_ids [0 ],
991+ chain_ids = chain_ids ,
991992 total = iter_total ,
992993 )
993994 runset = RunSet (
@@ -1007,7 +1008,9 @@ def sample(
10071008 show_console = show_console ,
10081009 progress_hook = progress_hook ,
10091010 )
1010- if show_progress :
1011+ if show_progress and progress_hook is not None :
1012+ progress_hook ("Done" , - 1 ) # -1 == all chains finished
1013+
10111014 # advance terminal window cursor past progress bars
10121015 term_size : os .terminal_size = shutil .get_terminal_size (
10131016 fallback = (80 , 24 )
@@ -1381,6 +1384,7 @@ def _run_cmdstan(
13811384 Args 'show_progress' and 'show_console' allow use of progress bar,
13821385 streaming output to console, respectively.
13831386 """
1387+ get_logger ().debug ('idx %d' , idx )
13841388 get_logger ().debug (
13851389 'running CmdStan, num_threads: %s' ,
13861390 str (os .environ .get ('STAN_NUM_THREADS' )),
@@ -1389,8 +1393,8 @@ def _run_cmdstan(
13891393 logger_prefix = 'CmdStan'
13901394 console_prefix = ''
13911395 if runset .one_process_per_chain :
1392- logger_prefix = 'Chain [{}]' .format (idx + 1 )
1393- console_prefix = 'Chain [{}] ' .format (idx + 1 )
1396+ logger_prefix = 'Chain [{}]' .format (idx + runset . chain_ids [ 0 ] )
1397+ console_prefix = 'Chain [{}] ' .format (idx + runset . chain_ids [ 0 ] )
13941398
13951399 cmd = runset .cmd (idx )
13961400 get_logger ().debug ('CmdStan args: %s' , cmd )
@@ -1418,10 +1422,10 @@ def _run_cmdstan(
14181422 elif progress_hook is not None :
14191423 progress_hook (line , idx )
14201424
1421- if progress_hook is not None and proc .returncode == 0 :
1422- progress_hook ("Done" , idx )
1423-
14241425 stdout , _ = proc .communicate ()
1426+ retcode = proc .returncode
1427+ runset ._set_retcode (idx , retcode )
1428+
14251429 if stdout :
14261430 fd_out .write (stdout )
14271431 if show_console :
@@ -1434,15 +1438,15 @@ def _run_cmdstan(
14341438 raise RuntimeError (msg ) from e
14351439 finally :
14361440 fd_out .close ()
1441+
14371442 if not show_progress :
14381443 get_logger ().info ('%s done processing' , logger_prefix )
14391444
1440- runset ._set_retcode (idx , proc .returncode )
1441- if proc .returncode != 0 :
1442- retcode_summary = returncode_msg (proc .returncode )
1445+ if retcode != 0 :
1446+ retcode_summary = returncode_msg (retcode )
14431447 serror = ''
14441448 try :
1445- serror = os .strerror (proc . returncode )
1449+ serror = os .strerror (retcode )
14461450 except (ArithmeticError , ValueError ):
14471451 pass
14481452 get_logger ().error (
@@ -1462,9 +1466,7 @@ def _run_cmdstan(
14621466 @staticmethod
14631467 @progbar .wrap_callback
14641468 def _wrap_sampler_progress_hook (
1465- one_process_per_chain : bool ,
1466- chains : int ,
1467- offset : int ,
1469+ chain_ids : List [int ],
14681470 total : int ,
14691471 ) -> Optional [Callable [[str , int ], None ]]:
14701472 """
@@ -1473,39 +1475,34 @@ def _wrap_sampler_progress_hook(
14731475 process, "Chain [id] Iteration" for multi-chain processing.
14741476 For the latter, manage array of pbars, update accordingly.
14751477 """
1476- do_match = chains > 1 and not one_process_per_chain
14771478 pat = re .compile (r'Chain \[(\d*)\] (Iteration.*)' )
1478-
1479- pbars : List [tqdm ] = [
1480- tqdm (
1479+ pbars : Dict [int , tqdm ] = {
1480+ chain_id : tqdm (
14811481 total = total ,
14821482 bar_format = "{desc} |{bar}| {elapsed} {postfix[0][value]}" ,
14831483 postfix = [dict (value = "Status" )],
1484- desc = f'chain { offset + i } ' ,
1484+ desc = f'chain { chain_id } ' ,
14851485 colour = 'yellow' ,
14861486 )
1487- for i in range ( chains )
1488- ]
1487+ for chain_id in chain_ids
1488+ }
14891489
14901490 def progress_hook (line : str , idx : int ) -> None :
14911491 if line == "Done" :
1492- for i in range ( chains ):
1493- pbars [ i ] .postfix [0 ]["value" ] = 'Sampling completed'
1494- pbars [ i ] .update (total - pbars [ i ] .n )
1495- pbars [ i ] .close ()
1492+ for pbar in pbars . values ( ):
1493+ pbar .postfix [0 ]["value" ] = 'Sampling completed'
1494+ pbar .update (total - pbar .n )
1495+ pbar .close ()
14961496 else :
1497- if do_match :
1498- match = pat . match ( line )
1499- if match :
1500- idx = int ( match .group (1 )) - offset
1501- mline = match . group ( 2 ). strip ()
1502- else :
1503- return
1497+ match = pat . match ( line )
1498+ if match :
1499+ idx = int ( match . group ( 1 ))
1500+ mline = match .group (2 ). strip ()
1501+ elif line . startswith ( "Iteration" ):
1502+ mline = line
1503+ idx = chain_ids [ idx ]
15041504 else :
1505- if line .startswith ("Iteration" ):
1506- mline = line
1507- else :
1508- return
1505+ return
15091506 if 'Sampling' in mline :
15101507 pbars [idx ].colour = 'blue'
15111508 pbars [idx ].update (1 )
0 commit comments