@@ -943,39 +943,35 @@ def sample(
943943 one_process_per_chain = True
944944 assert isinstance (self .exe_file , str ) # make typechecker happy
945945 info_dict = self .exe_info ()
946- stan_threads = info_dict .get ('STAN_THREADS' )
947- if stan_threads is not None :
948- stan_threads = stan_threads .lower ()
946+ stan_threads = info_dict .get ('STAN_THREADS' , 'false' ).lower ()
949947 if (
950948 force_one_process_per_chain is None
951949 and not cmdstan_version_before (2 , 28 , info_dict )
952- and stan_threads is not None
953950 and stan_threads == 'true'
954951 ):
955952 one_process_per_chain = False
956953 num_threads = parallel_chains * num_threads
957954 parallel_procs = 1
958955 if force_one_process_per_chain is False :
959- if cmdstan_version_before (2 , 28 , info_dict ):
956+ if not cmdstan_version_before (2 , 28 , info_dict ):
957+ one_process_per_chain = False
958+ num_threads = parallel_chains * num_threads
959+ parallel_procs = 1
960+ if stan_threads == 'false' :
961+ get_logger ().warning (
962+ 'Stan program not compiled for threading, '
963+ 'process will run chains sequentially. '
964+ 'For multi-chain parallelization, recompile '
965+ 'the model with argument '
966+ '"cpp_options={\' STAN_THREADS\' :\' TRUE\' }.'
967+ )
968+ else :
960969 get_logger ().warning (
961970 'Installed version of CmdStan cannot multi-process '
962971 'chains, will run %d processes. '
963972 'Run "install_cmdstan" to upgrade to latest version.' ,
964973 chains ,
965974 )
966- elif stan_threads is None or stan_threads == 'false' :
967- get_logger ().warning (
968- 'Stan program not compiled for threading, '
969- 'will run %d processes. '
970- 'Recompile model and specify argument '
971- '"cpp_options={\' STAN_THREADS\' :\' TRUE\' }.' ,
972- chains ,
973- )
974- else :
975- one_process_per_chain = False
976- num_threads = parallel_chains * num_threads
977- parallel_procs = 1
978-
979975 os .environ ['STAN_NUM_THREADS' ] = str (num_threads )
980976
981977 if show_console :
@@ -1405,8 +1401,8 @@ def _run_cmdstan(
14051401 logger_prefix = 'CmdStan'
14061402 console_prefix = ''
14071403 if runset .one_process_per_chain :
1408- logger_prefix = 'Chain [{}]' .format (idx + runset .chain_ids [0 ])
1409- console_prefix = 'Chain [{}] ' .format (idx + runset .chain_ids [0 ])
1404+ logger_prefix = 'Chain [{}]' .format (runset .chain_ids [idx ])
1405+ console_prefix = 'Chain [{}] ' .format (runset .chain_ids [idx ])
14101406
14111407 cmd = runset .cmd (idx )
14121408 get_logger ().debug ('CmdStan args: %s' , cmd )
0 commit comments