@@ -95,9 +95,9 @@ def __init__(
9595 self ._save_warmup = sampler_args .save_warmup
9696 self ._sig_figs = runset ._args .sig_figs
9797 # info from CSV values, instantiated lazily
98- self ._metric = np .array (())
99- self ._step_size = np .array (())
100- self ._draws = np .array (())
98+ self ._metric : np . ndarray = np .array (())
99+ self ._step_size : np . ndarray = np .array (())
100+ self ._draws : np . ndarray = np .array (())
101101 # info from CSV initial comments and header
102102 config = self ._validate_csv_files ()
103103 self ._metadata : InferenceMetadata = InferenceMetadata (config )
@@ -246,7 +246,7 @@ def draws(
246246
247247 if concat_chains :
248248 return flatten_chains (self ._draws [start_idx :, :, :])
249- return self ._draws [start_idx :, :, :] # type: ignore
249+ return self ._draws [start_idx :, :, :]
250250
251251 def _validate_csv_files (self ) -> Dict [str , Any ]:
252252 """
@@ -675,9 +675,7 @@ def stan_variable(
675675 if len (col_idxs ) > 0 :
676676 dims .extend (self ._metadata .stan_vars_dims [var ])
677677 # pylint: disable=redundant-keyword-arg
678- return self ._draws [draw1 :, :, col_idxs ].reshape ( # type: ignore
679- dims , order = 'F'
680- )
678+ return self ._draws [draw1 :, :, col_idxs ].reshape (dims , order = 'F' )
681679
682680 def stan_variables (self ) -> Dict [str , np .ndarray ]:
683681 """
@@ -748,7 +746,7 @@ def __init__(
748746 )
749747 self .runset = runset
750748 self .mcmc_sample = mcmc_sample
751- self ._draws = np .array (())
749+ self ._draws : np . ndarray = np .array (())
752750 config = self ._validate_csv_files ()
753751 self ._metadata = InferenceMetadata (config )
754752
@@ -765,7 +763,7 @@ def __repr__(self) -> str:
765763 )
766764 return repr
767765
768- def _validate_csv_files (self ) -> dict :
766+ def _validate_csv_files (self ) -> Dict [ str , Any ] :
769767 """
770768 Checks that Stan CSV output files for all chains are consistent
771769 and returns dict containing config and column names.
@@ -910,13 +908,13 @@ def draws(
910908 if concat_chains :
911909 return flatten_chains (self ._draws [start_idx :, :, :])
912910 if inc_sample :
913- return np .dstack ( # type: ignore
911+ return np .dstack (
914912 (
915913 np .delete (self .mcmc_sample .draws (), drop_cols , axis = 1 ),
916914 self ._draws ,
917915 )
918916 )[start_idx :, :, :]
919- return self ._draws [start_idx :, :, :] # type: ignore
917+ return self ._draws [start_idx :, :, :]
920918
921919 def draws_pd (
922920 self ,
@@ -1195,9 +1193,7 @@ def stan_variable(
11951193 if len (col_idxs ) > 0 :
11961194 dims .extend (self ._metadata .stan_vars_dims [var ])
11971195 # pylint: disable=redundant-keyword-arg
1198- return self ._draws [draw1 :, :, col_idxs ].reshape ( # type: ignore
1199- dims , order = 'F'
1200- )
1196+ return self ._draws [draw1 :, :, col_idxs ].reshape (dims , order = 'F' )
12011197
12021198 def stan_variables (self , inc_warmup : bool = False ) -> Dict [str , np .ndarray ]:
12031199 """
@@ -1229,7 +1225,7 @@ def _assemble_generated_quantities(self) -> None:
12291225 # use numpy loadtxt
12301226 warmup = self .mcmc_sample .metadata .cmdstan_config ['save_warmup' ]
12311227 num_draws = self .mcmc_sample .draws (inc_warmup = warmup ).shape [0 ]
1232- gq_sample = np .empty (
1228+ gq_sample : np . ndarray = np .empty (
12331229 (num_draws , self .chains , len (self .column_names )),
12341230 dtype = float ,
12351231 order = 'F' ,
0 commit comments