5959)
6060from cmdstanpy .utils import (
6161 EXTENSION ,
62- MaybeDictToFilePath ,
6362 SanitizedOrTmpFilePath ,
6463 cmdstan_path ,
6564 cmdstan_version ,
6867 get_logger ,
6968 returncode_msg ,
7069)
70+ from cmdstanpy .utils .filesystem import temp_inits , temp_single_json
7171
7272from . import progress as progbar
7373
@@ -573,7 +573,7 @@ def optimize(
573573 self ,
574574 data : Union [Mapping [str , Any ], str , os .PathLike , None ] = None ,
575575 seed : Optional [int ] = None ,
576- inits : Union [Dict [str , float ], float , str , os .PathLike , None ] = None ,
576+ inits : Union [Mapping [str , Any ], float , str , os .PathLike , None ] = None ,
577577 output_dir : OptionalPath = None ,
578578 sig_figs : Optional [int ] = None ,
579579 save_profile : bool = False ,
@@ -722,7 +722,9 @@ def optimize(
722722 "in CmdStan 2.32 and above."
723723 )
724724
725- with MaybeDictToFilePath (data , inits ) as (_data , _inits ):
725+ with temp_single_json (data ) as _data , temp_inits (
726+ inits , allow_multiple = False
727+ ) as _inits :
726728 args = CmdStanArgs (
727729 self ._name ,
728730 self ._exe_file ,
@@ -766,7 +768,14 @@ def sample(
766768 threads_per_chain : Optional [int ] = None ,
767769 seed : Union [int , List [int ], None ] = None ,
768770 chain_ids : Union [int , List [int ], None ] = None ,
769- inits : Union [Dict [str , float ], float , str , List [str ], None ] = None ,
771+ inits : Union [
772+ Mapping [str , Any ],
773+ float ,
774+ str ,
775+ List [str ],
776+ List [Mapping [str , Any ]],
777+ None ,
778+ ] = None ,
770779 iter_warmup : Optional [int ] = None ,
771780 iter_sampling : Optional [int ] = None ,
772781 save_warmup : bool = False ,
@@ -1003,6 +1012,69 @@ def sample(
10031012 chains
10041013 )
10051014 )
1015+
1016+ if parallel_chains is None :
1017+ parallel_chains = max (min (cpu_count (), chains ), 1 )
1018+ elif parallel_chains > chains :
1019+ get_logger ().info (
1020+ 'Requested %u parallel_chains but only %u required, '
1021+ 'will run all chains in parallel.' ,
1022+ parallel_chains ,
1023+ chains ,
1024+ )
1025+ parallel_chains = chains
1026+ elif parallel_chains < 1 :
1027+ raise ValueError (
1028+ 'Argument parallel_chains must be a positive integer, '
1029+ 'found {}.' .format (parallel_chains )
1030+ )
1031+ if threads_per_chain is None :
1032+ threads_per_chain = 1
1033+ if threads_per_chain < 1 :
1034+ raise ValueError (
1035+ 'Argument threads_per_chain must be a positive integer, '
1036+ 'found {}.' .format (threads_per_chain )
1037+ )
1038+
1039+ parallel_procs = parallel_chains
1040+ num_threads = threads_per_chain
1041+ one_process_per_chain = True
1042+ info_dict = self .exe_info ()
1043+ stan_threads = info_dict .get ('STAN_THREADS' , 'false' ).lower ()
1044+ # run multi-chain sampler unless algo is fixed_param or 1 chain
1045+ if chains == 1 :
1046+ force_one_process_per_chain = True
1047+
1048+ if (
1049+ force_one_process_per_chain is None
1050+ and not cmdstan_version_before (2 , 28 , info_dict )
1051+ and stan_threads == 'true'
1052+ ):
1053+ one_process_per_chain = False
1054+ num_threads = parallel_chains * num_threads
1055+ parallel_procs = 1
1056+ if force_one_process_per_chain is False :
1057+ if not cmdstan_version_before (2 , 28 , info_dict ):
1058+ one_process_per_chain = False
1059+ num_threads = parallel_chains * num_threads
1060+ parallel_procs = 1
1061+ if stan_threads == 'false' :
1062+ get_logger ().warning (
1063+ 'Stan program not compiled for threading, '
1064+ 'process will run chains sequentially. '
1065+ 'For multi-chain parallelization, recompile '
1066+ 'the model with argument '
1067+ '"cpp_options={\' STAN_THREADS\' :\' TRUE\' }.'
1068+ )
1069+ else :
1070+ get_logger ().warning (
1071+ 'Installed version of CmdStan cannot multi-process '
1072+ 'chains, will run %d processes. '
1073+ 'Run "install_cmdstan" to upgrade to latest version.' ,
1074+ chains ,
1075+ )
1076+ os .environ ['STAN_NUM_THREADS' ] = str (num_threads )
1077+
10061078 if chain_ids is None :
10071079 chain_ids = [i + 1 for i in range (chains )]
10081080 else :
@@ -1014,6 +1086,13 @@ def sample(
10141086 )
10151087 chain_ids = [i + chain_ids for i in range (chains )]
10161088 else :
1089+ if not one_process_per_chain :
1090+ for i , j in zip (chain_ids , chain_ids [1 :]):
1091+ if i != j - 1 :
1092+ raise ValueError (
1093+ 'chain_ids must be sequential list of integers,'
1094+ ' found {}.' .format (chain_ids )
1095+ )
10171096 if not len (chain_ids ) == chains :
10181097 raise ValueError (
10191098 'Chain_ids must correspond to number of chains'
@@ -1029,6 +1108,7 @@ def sample(
10291108 )
10301109
10311110 sampler_args = SamplerArgs (
1111+ num_chains = 1 if one_process_per_chain else chains ,
10321112 iter_warmup = iter_warmup ,
10331113 iter_sampling = iter_sampling ,
10341114 save_warmup = save_warmup ,
@@ -1043,14 +1123,25 @@ def sample(
10431123 adapt_step_size = adapt_step_size ,
10441124 fixed_param = fixed_param ,
10451125 )
1046- with MaybeDictToFilePath (data , inits ) as (_data , _inits ):
1126+
1127+ with temp_single_json (data ) as _data , temp_inits (
1128+ inits , id = chain_ids [0 ]
1129+ ) as _inits :
1130+ cmdstan_inits : Union [str , List [str ], int , float , None ]
1131+ if one_process_per_chain and isinstance (inits , list ): # legacy
1132+ cmdstan_inits = [
1133+ f"{ _inits [:- 5 ]} _{ i } .json" for i in chain_ids # type: ignore
1134+ ]
1135+ else :
1136+ cmdstan_inits = _inits
1137+
10471138 args = CmdStanArgs (
10481139 self ._name ,
10491140 self ._exe_file ,
10501141 chain_ids = chain_ids ,
10511142 data = _data ,
10521143 seed = seed ,
1053- inits = _inits ,
1144+ inits = cmdstan_inits ,
10541145 output_dir = output_dir ,
10551146 sig_figs = sig_figs ,
10561147 save_latent_dynamics = save_latent_dynamics ,
@@ -1059,67 +1150,6 @@ def sample(
10591150 refresh = refresh ,
10601151 )
10611152
1062- if parallel_chains is None :
1063- parallel_chains = max (min (cpu_count (), chains ), 1 )
1064- elif parallel_chains > chains :
1065- get_logger ().info (
1066- 'Requested %u parallel_chains but only %u required, '
1067- 'will run all chains in parallel.' ,
1068- parallel_chains ,
1069- chains ,
1070- )
1071- parallel_chains = chains
1072- elif parallel_chains < 1 :
1073- raise ValueError (
1074- 'Argument parallel_chains must be a positive integer, '
1075- 'found {}.' .format (parallel_chains )
1076- )
1077- if threads_per_chain is None :
1078- threads_per_chain = 1
1079- if threads_per_chain < 1 :
1080- raise ValueError (
1081- 'Argument threads_per_chain must be a positive integer, '
1082- 'found {}.' .format (threads_per_chain )
1083- )
1084-
1085- parallel_procs = parallel_chains
1086- num_threads = threads_per_chain
1087- one_process_per_chain = True
1088- info_dict = self .exe_info ()
1089- stan_threads = info_dict .get ('STAN_THREADS' , 'false' ).lower ()
1090- if chains == 1 :
1091- force_one_process_per_chain = True
1092-
1093- if (
1094- force_one_process_per_chain is None
1095- and not cmdstan_version_before (2 , 28 , info_dict )
1096- and stan_threads == 'true'
1097- ):
1098- one_process_per_chain = False
1099- num_threads = parallel_chains * num_threads
1100- parallel_procs = 1
1101- if force_one_process_per_chain is False :
1102- if not cmdstan_version_before (2 , 28 , info_dict ):
1103- one_process_per_chain = False
1104- num_threads = parallel_chains * num_threads
1105- parallel_procs = 1
1106- if stan_threads == 'false' :
1107- get_logger ().warning (
1108- 'Stan program not compiled for threading, '
1109- 'process will run chains sequentially. '
1110- 'For multi-chain parallelization, recompile '
1111- 'the model with argument '
1112- '"cpp_options={\' STAN_THREADS\' :\' TRUE\' }.'
1113- )
1114- else :
1115- get_logger ().warning (
1116- 'Installed version of CmdStan cannot multi-process '
1117- 'chains, will run %d processes. '
1118- 'Run "install_cmdstan" to upgrade to latest version.' ,
1119- chains ,
1120- )
1121- os .environ ['STAN_NUM_THREADS' ] = str (num_threads )
1122-
11231153 if show_console :
11241154 show_progress = False
11251155 else :
@@ -1359,7 +1389,7 @@ def generate_quantities(
13591389 csv_files = fit_csv_files
13601390 )
13611391 generate_quantities_args .validate (chains )
1362- with MaybeDictToFilePath (data , None ) as ( _data , _inits ) :
1392+ with temp_single_json (data ) as _data :
13631393 args = CmdStanArgs (
13641394 self ._name ,
13651395 self ._exe_file ,
@@ -1534,7 +1564,9 @@ def variational(
15341564 output_samples = output_samples ,
15351565 )
15361566
1537- with MaybeDictToFilePath (data , inits ) as (_data , _inits ):
1567+ with temp_single_json (data ) as _data , temp_inits (
1568+ inits , allow_multiple = False
1569+ ) as _inits :
15381570 args = CmdStanArgs (
15391571 self ._name ,
15401572 self ._exe_file ,
@@ -1641,7 +1673,9 @@ def log_prob(
16411673 "Method 'log_prob' not available for CmdStan versions "
16421674 "before 2.31"
16431675 )
1644- with MaybeDictToFilePath (data , params ) as (_data , _params ):
1676+ with temp_single_json (data ) as _data , temp_single_json (
1677+ params
1678+ ) as _params :
16451679 cmd = [
16461680 str (self .exe_file ),
16471681 "log_prob" ,
@@ -1749,7 +1783,7 @@ def laplace_sample(
17491783 cmdstan_mode .runset .csv_files [0 ], draws , jacobian
17501784 )
17511785
1752- with MaybeDictToFilePath (data ) as ( _data ,) :
1786+ with temp_single_json (data ) as _data :
17531787 args = CmdStanArgs (
17541788 self ._name ,
17551789 self ._exe_file ,
0 commit comments