3737 get_logger ,
3838 scan_generated_quantities_csv ,
3939)
40- from cmdstanpy .utils .data_munging import extract_reshape
4140
4241from .mcmc import CmdStanMCMC
4342from .metadata import InferenceMetadata
@@ -242,7 +241,9 @@ def draws(
242241 ]
243242 drop_cols : List [int ] = []
244243 for dup in dups :
245- drop_cols .extend (self .previous_fit .metadata .stan_vars_cols [dup ])
244+ drop_cols .extend (
245+ self .previous_fit ._metadata .stan_vars [dup ].columns ()
246+ )
246247
247248 start_idx , _ = self ._draws_start (inc_warmup )
248249 previous_draws = self ._previous_draws (True )
@@ -324,18 +325,24 @@ def draws_pd(
324325
325326 self ._assemble_generated_quantities ()
326327
327- gq_cols = []
328- mcmc_vars = []
328+ gq_cols : List [ str ] = []
329+ mcmc_vars : List [ str ] = []
329330 if vars is not None :
330331 for var in vars_list :
331- if var in self .metadata .stan_vars_cols :
332- for idx in self .metadata .stan_vars_cols [var ]:
333- gq_cols .append (self .column_names [idx ])
332+ if var in self ._metadata .stan_vars :
333+ info = self ._metadata .stan_vars [var ]
334+ gq_cols .extend (
335+ self .column_names [info .start_idx : info .end_idx ]
336+ )
334337 elif (
335- inc_sample
336- and var in self .previous_fit .metadata .stan_vars_cols
338+ inc_sample and var in self .previous_fit ._metadata .stan_vars
337339 ):
338- mcmc_vars .append (var )
340+ info = self .previous_fit ._metadata .stan_vars [var ]
341+ mcmc_vars .extend (
342+ self .previous_fit .column_names [
343+ info .start_idx : info .end_idx
344+ ]
345+ )
339346 else :
340347 raise ValueError ('Unknown variable: {}' .format (var ))
341348 else :
@@ -463,18 +470,18 @@ def draws_xr(
463470 else :
464471 vars_list = vars
465472 for var in vars_list :
466- if var not in self .metadata . stan_vars_cols :
473+ if var not in self ._metadata . stan_vars :
467474 if inc_sample and (
468- var in self .previous_fit .metadata . stan_vars_cols
475+ var in self .previous_fit ._metadata . stan_vars
469476 ):
470477 mcmc_vars_list .append (var )
471478 dup_vars .append (var )
472479 else :
473480 raise ValueError ('Unknown variable: {}' .format (var ))
474481 else :
475- vars_list = list (self .metadata . stan_vars_cols .keys ())
482+ vars_list = list (self ._metadata . stan_vars .keys ())
476483 if inc_sample :
477- for var in self .previous_fit .metadata . stan_vars_cols .keys ():
484+ for var in self .previous_fit ._metadata . stan_vars .keys ():
478485 if var not in vars_list and var not in mcmc_vars_list :
479486 mcmc_vars_list .append (var )
480487 for var in dup_vars :
@@ -483,7 +490,7 @@ def draws_xr(
483490 self ._assemble_generated_quantities ()
484491
485492 num_draws = self .previous_fit .num_draws_sampling
486- sample_config = self .previous_fit .metadata .cmdstan_config
493+ sample_config = self .previous_fit ._metadata .cmdstan_config
487494 attrs : MutableMapping [Hashable , Any ] = {
488495 "stan_version" : f"{ sample_config ['stan_version_major' ]} ."
489496 f"{ sample_config ['stan_version_minor' ]} ."
@@ -504,23 +511,15 @@ def draws_xr(
504511 for var in vars_list :
505512 build_xarray_data (
506513 data ,
507- var ,
508- self ._metadata .stan_vars_dims [var ],
509- self ._metadata .stan_vars_cols [var ],
510- 0 ,
514+ self ._metadata .stan_vars [var ],
511515 self .draws (inc_warmup = inc_warmup ),
512- self ._metadata .stan_vars_types [var ],
513516 )
514517 if inc_sample :
515518 for var in mcmc_vars_list :
516519 build_xarray_data (
517520 data ,
518- var ,
519- self .previous_fit .metadata .stan_vars_dims [var ],
520- self .previous_fit .metadata .stan_vars_cols [var ],
521- 0 ,
521+ self .previous_fit ._metadata .stan_vars [var ],
522522 self .previous_fit .draws (inc_warmup = inc_warmup ),
523- self .previous_fit ._metadata .stan_vars_types [var ],
524523 )
525524
526525 return xr .Dataset (data , coords = coordinates , attrs = attrs ).transpose (
@@ -545,13 +544,13 @@ def stan_variable(
545544 the next M are from chain 2, and the last M elements are from chain N.
546545
547546 * If the variable is a scalar variable, the return array has shape
548- ( draws X chains, 1).
547+ ( draws * chains, 1).
549548 * If the variable is a vector, the return array has shape
550- ( draws X chains, len(vector))
549+ ( draws * chains, len(vector))
551550 * If the variable is a matrix, the return array has shape
552- ( draws X chains, size(dim 1) X size(dim 2) )
551+ ( draws * chains, size(dim 1), size(dim 2) )
553552 * If the variable is an array with N dimensions, the return array
554- has shape ( draws X chains, size(dim 1) X ... X size(dim N))
553+ has shape ( draws * chains, size(dim 1), ..., size(dim N))
555554
556555 For example, if the Stan program variable ``theta`` is a 3x3 matrix,
557556 and the sample consists of 4 chains with 1000 post-warmup draws,
@@ -573,8 +572,8 @@ def stan_variable(
573572 CmdStanMLE.stan_variable
574573 CmdStanVB.stan_variable
575574 """
576- model_var_names = self .previous_fit .metadata . stan_vars_cols .keys ()
577- gq_var_names = self .metadata . stan_vars_cols .keys ()
575+ model_var_names = self .previous_fit ._metadata . stan_vars .keys ()
576+ gq_var_names = self ._metadata . stan_vars .keys ()
578577 if not (var in model_var_names or var in gq_var_names ):
579578 raise ValueError (
580579 f'Unknown variable name: { var } \n '
@@ -588,30 +587,21 @@ def stan_variable(
588587 )
589588 elif isinstance (self .previous_fit , CmdStanMLE ):
590589 return np .atleast_1d ( # type: ignore
591- np .asarray (
592- self .previous_fit .stan_variable (
593- var , inc_iterations = inc_warmup
594- )
590+ self .previous_fit .stan_variable (
591+ var , inc_iterations = inc_warmup
595592 )
596593 )
597594 else :
598595 return np .atleast_1d ( # type: ignore
599- np . asarray ( self .previous_fit .stan_variable (var ) )
596+ self .previous_fit .stan_variable (var )
600597 )
601-
602598 # is gq variable
603599 self ._assemble_generated_quantities ()
604- draw1 , num_draws = self ._draws_start (inc_warmup )
605- dims = (num_draws * self .chains ,)
606- col_idxs = self ._metadata .stan_vars_cols [var ]
607-
608- return extract_reshape (
609- dims = dims + self ._metadata .stan_vars_dims [var ],
610- col_idxs = col_idxs ,
611- var_type = self ._metadata .stan_vars_types [var ],
612- start_row = draw1 ,
613- draws_in = self ._draws ,
614- )
600+
601+ draw1 , _ = self ._draws_start (inc_warmup )
602+ draws = flatten_chains (self ._draws [draw1 :])
603+ out : np .ndarray = self ._metadata .stan_vars [var ].extract_reshape (draws )
604+ return out
615605
616606 def stan_variables (self , inc_warmup : bool = False ) -> Dict [str , np .ndarray ]:
617607 """
@@ -630,8 +620,8 @@ def stan_variables(self, inc_warmup: bool = False) -> Dict[str, np.ndarray]:
630620 CmdStanVB.stan_variables
631621 """
632622 result = {}
633- sample_var_names = self .previous_fit .metadata . stan_vars_cols .keys ()
634- gq_var_names = self .metadata . stan_vars_cols .keys ()
623+ sample_var_names = self .previous_fit ._metadata . stan_vars .keys ()
624+ gq_var_names = self ._metadata . stan_vars .keys ()
635625 for name in gq_var_names :
636626 result [name ] = self .stan_variable (name , inc_warmup )
637627 for name in sample_var_names :
@@ -697,9 +687,9 @@ def _previous_draws(self, inc_warmup: bool) -> np.ndarray:
697687 if inc_warmup and p_fit ._save_iterations :
698688 return p_fit .optimized_iterations_np [:, None ] # type: ignore
699689
700- return np .atleast_2d (p_fit . optimized_params_np ,)[ # type: ignore
701- :, None
702- ]
690+ return np .atleast_2d ( # type: ignore
691+ p_fit . optimized_params_np ,
692+ )[:, None ]
703693 else : # CmdStanVB:
704694 if inc_warmup :
705695 return np .vstack (
0 commit comments