@@ -95,9 +95,9 @@ def __init__(
9595 self ._save_warmup = sampler_args .save_warmup
9696 self ._sig_figs = runset ._args .sig_figs
9797 # info from CSV values, instantiated lazily
98- self ._metric = np .array (())
99- self ._step_size = np .array (())
100- self ._draws = np .array (())
98+ self ._metric : np . ndarray = np .array (())
99+ self ._step_size : np . ndarray = np .array (())
100+ self ._draws : np . ndarray = np .array (())
101101 # info from CSV initial comments and header
102102 config = self ._validate_csv_files ()
103103 self ._metadata : InferenceMetadata = InferenceMetadata (config )
@@ -231,7 +231,7 @@ def draws(
231231 CmdStanMCMC.draws_xr
232232 CmdStanGQ.draws
233233 """
234- if self ._draws .size == 0 :
234+ if self ._draws .shape == ( 0 ,) :
235235 self ._assemble_draws ()
236236
237237 if inc_warmup and not self ._save_warmup :
@@ -246,7 +246,7 @@ def draws(
246246
247247 if concat_chains :
248248 return flatten_chains (self ._draws [start_idx :, :, :])
249- return self ._draws [start_idx :, :, :] # type: ignore
249+ return self ._draws [start_idx :, :, :]
250250
251251 def _validate_csv_files (self ) -> Dict [str , Any ]:
252252 """
@@ -309,9 +309,6 @@ def _assemble_draws(self) -> None:
309309 Allocates and populates the step size, metric, and sample arrays
310310 by parsing the validated stan_csv files.
311311 """
312- if self ._draws .shape != (0 ,):
313- return
314-
315312 num_draws = self .num_draws_sampling
316313 sampling_iter_start = 0
317314 if self ._save_warmup :
@@ -527,7 +524,8 @@ def draws_pd(
527524 ' must run sampler with "save_warmup=True".'
528525 )
529526
530- self ._assemble_draws ()
527+ if self ._draws .shape == (0 ,):
528+ self ._assemble_draws ()
531529 cols = []
532530 if vars is not None :
533531 for var in set (vars_list ):
@@ -583,7 +581,8 @@ def draws_xr(
583581 else :
584582 vars_list = vars
585583
586- self ._assemble_draws ()
584+ if self ._draws .shape == (0 ,):
585+ self ._assemble_draws ()
587586
588587 num_draws = self .num_draws_sampling
589588 meta = self ._metadata .cmdstan_config
@@ -663,7 +662,8 @@ def stan_variable(
663662 raise ValueError ('No variable name specified.' )
664663 if var not in self ._metadata .stan_vars_dims :
665664 raise ValueError ('Unknown variable name: {}' .format (var ))
666- self ._assemble_draws ()
665+ if self ._draws .shape == (0 ,):
666+ self ._assemble_draws ()
667667 draw1 = 0
668668 if not inc_warmup and self ._save_warmup :
669669 draw1 = self .num_draws_warmup
@@ -675,9 +675,7 @@ def stan_variable(
675675 if len (col_idxs ) > 0 :
676676 dims .extend (self ._metadata .stan_vars_dims [var ])
677677 # pylint: disable=redundant-keyword-arg
678- return self ._draws [draw1 :, :, col_idxs ].reshape ( # type: ignore
679- dims , order = 'F'
680- )
678+ return self ._draws [draw1 :, :, col_idxs ].reshape (dims , order = 'F' )
681679
682680 def stan_variables (self ) -> Dict [str , np .ndarray ]:
683681 """
@@ -705,7 +703,8 @@ def method_variables(self) -> Dict[str, np.ndarray]:
705703 containing per-draw diagnostic values.
706704 """
707705 result = {}
708- self ._assemble_draws ()
706+ if self ._draws .shape == (0 ,):
707+ self ._assemble_draws ()
709708 for idxs in self .metadata .method_vars_cols .values ():
710709 for idx in idxs :
711710 result [self .column_names [idx ]] = self ._draws [:, :, idx ]
@@ -747,7 +746,7 @@ def __init__(
747746 )
748747 self .runset = runset
749748 self .mcmc_sample = mcmc_sample
750- self ._draws = np .array (())
749+ self ._draws : np . ndarray = np .array (())
751750 config = self ._validate_csv_files ()
752751 self ._metadata = InferenceMetadata (config )
753752
@@ -764,7 +763,7 @@ def __repr__(self) -> str:
764763 )
765764 return repr
766765
767- def _validate_csv_files (self ) -> dict :
766+ def _validate_csv_files (self ) -> Dict [ str , Any ] :
768767 """
769768 Checks that Stan CSV output files for all chains are consistent
770769 and returns dict containing config and column names.
@@ -868,7 +867,7 @@ def draws(
868867 CmdStanGQ.draws_xr
869868 CmdStanMCMC.draws
870869 """
871- if self ._draws .size == 0 :
870+ if self ._draws .shape == ( 0 ,) :
872871 self ._assemble_generated_quantities ()
873872 if (
874873 inc_warmup
@@ -909,13 +908,13 @@ def draws(
909908 if concat_chains :
910909 return flatten_chains (self ._draws [start_idx :, :, :])
911910 if inc_sample :
912- return np .dstack ( # type: ignore
911+ return np .dstack (
913912 (
914913 np .delete (self .mcmc_sample .draws (), drop_cols , axis = 1 ),
915914 self ._draws ,
916915 )
917916 )[start_idx :, :, :]
918- return self ._draws [start_idx :, :, :] # type: ignore
917+ return self ._draws [start_idx :, :, :]
919918
920919 def draws_pd (
921920 self ,
@@ -955,7 +954,8 @@ def draws_pd(
955954 'Draws from warmup iterations not available,'
956955 ' must run sampler with "save_warmup=True".'
957956 )
958- self ._assemble_generated_quantities ()
957+ if self ._draws .shape == (0 ,):
958+ self ._assemble_generated_quantities ()
959959
960960 gq_cols = []
961961 mcmc_vars = []
@@ -1076,7 +1076,8 @@ def draws_xr(
10761076 for var in dup_vars :
10771077 vars_list .remove (var )
10781078
1079- self ._assemble_generated_quantities ()
1079+ if self ._draws .shape == (0 ,):
1080+ self ._assemble_generated_quantities ()
10801081
10811082 num_draws = self .mcmc_sample .num_draws_sampling
10821083 sample_config = self .mcmc_sample .metadata .cmdstan_config
@@ -1173,7 +1174,8 @@ def stan_variable(
11731174 if var not in gq_var_names :
11741175 return self .mcmc_sample .stan_variable (var , inc_warmup = inc_warmup )
11751176 else : # is gq variable
1176- self ._assemble_generated_quantities ()
1177+ if self ._draws .shape == (0 ,):
1178+ self ._assemble_generated_quantities ()
11771179 draw1 = 0
11781180 if (
11791181 not inc_warmup
@@ -1191,9 +1193,7 @@ def stan_variable(
11911193 if len (col_idxs ) > 0 :
11921194 dims .extend (self ._metadata .stan_vars_dims [var ])
11931195 # pylint: disable=redundant-keyword-arg
1194- return self ._draws [draw1 :, :, col_idxs ].reshape ( # type: ignore
1195- dims , order = 'F'
1196- )
1196+ return self ._draws [draw1 :, :, col_idxs ].reshape (dims , order = 'F' )
11971197
11981198 def stan_variables (self , inc_warmup : bool = False ) -> Dict [str , np .ndarray ]:
11991199 """
@@ -1222,10 +1222,10 @@ def stan_variables(self, inc_warmup: bool = False) -> Dict[str, np.ndarray]:
12221222 return result
12231223
12241224 def _assemble_generated_quantities (self ) -> None :
1225- # use numpy genfromtext
1225+ # use numpy loadtxt
12261226 warmup = self .mcmc_sample .metadata .cmdstan_config ['save_warmup' ]
12271227 num_draws = self .mcmc_sample .draws (inc_warmup = warmup ).shape [0 ]
1228- gq_sample = np .empty (
1228+ gq_sample : np . ndarray = np .empty (
12291229 (num_draws , self .chains , len (self .column_names )),
12301230 dtype = float ,
12311231 order = 'F' ,
0 commit comments