@@ -1074,7 +1074,6 @@ def draws_xr(
10741074 0 ,
10751075 self .draws (inc_warmup = inc_warmup ),
10761076 )
1077-
10781077 return xr .Dataset (data , coords = coordinates , attrs = attrs ).transpose (
10791078 'chain' , 'draw' , ...
10801079 )
@@ -1373,7 +1372,7 @@ def stan_variable(
13731372 inc_iterations : bool = False ,
13741373 warn : bool = True ,
13751374 name : Optional [str ] = None ,
1376- ) -> np .ndarray :
1375+ ) -> Union [ np .ndarray , float ] :
13771376 """
13781377 Return a numpy.ndarray which contains the estimates for the
13791378 for the named Stan program variable where the dimensions of the
@@ -1416,38 +1415,34 @@ def stan_variable(
14161415 'Invalid estimate, optimization failed to converge.'
14171416 )
14181417
1419- col_idxs = self ._metadata .stan_vars_cols [var ]
1418+ col_idxs = list ( self ._metadata .stan_vars_cols [var ])
14201419 if inc_iterations and self ._save_iterations :
14211420 num_rows = self ._all_iters .shape [0 ]
14221421 else :
14231422 num_rows = 1
1424-
1425- if len (col_idxs ) > 0 : # container var
1423+ if len (col_idxs ) > 1 : # container var
14261424 dims = (num_rows ,) + self ._metadata .stan_vars_dims [var ]
14271425 # pylint: disable=redundant-keyword-arg
14281426 if num_rows > 1 :
14291427 result = self ._all_iters [:, col_idxs ].reshape ( # type: ignore
14301428 dims , order = 'F'
14311429 )
14321430 else :
1433- mle = np .expand_dims (self ._mle , axis = 0 ) # hack for col indexing
1434- result = (
1435- mle [0 , col_idxs ]
1436- .reshape (dims , order = 'F' ) # type: ignore
1437- .squeeze (axis = 0 )
1438- )
1431+ result = self ._mle [col_idxs ].reshape (dims [1 :], order = "F" )
14391432 else : # scalar var
1433+ col_idx = col_idxs [0 ]
14401434 if num_rows > 1 :
1441- result = self ._all_iters [:, col_idxs ]
1435+ result = self ._all_iters [:, col_idx ]
14421436 else :
1443- result = np .atleast_1d (mle [0 , col_idxs ])
1444-
1445- assert isinstance (result , np .ndarray ) # make the typechecker happy
1437+ result = float (self ._mle [col_idx ])
1438+ assert isinstance (
1439+ result , (np .ndarray , float )
1440+ ) # make the typechecker happy
14461441 return result
14471442
14481443 def stan_variables (
14491444 self , inc_iterations : bool = False
1450- ) -> Dict [str , np .ndarray ]:
1445+ ) -> Dict [str , Union [ np .ndarray , float ] ]:
14511446 """
14521447 Return a dictionary mapping Stan program variables names
14531448 to the corresponding numpy.ndarray containing the inferred values.
@@ -1988,16 +1983,26 @@ def stan_variable(
19881983 return self .mcmc_sample .stan_variable (var , inc_warmup = inc_warmup )
19891984 else : # is gq variable
19901985 self ._assemble_generated_quantities ()
1991- col_idxs = self . _metadata . stan_vars_cols [ var ]
1986+ draw1 = 0
19921987 if (
19931988 not inc_warmup
19941989 and self .mcmc_sample .metadata .cmdstan_config ['save_warmup' ]
19951990 ):
1996- draw1 = self .mcmc_sample .num_draws_warmup * self .chains
1997- return flatten_chains (self ._draws )[ # type: ignore
1998- draw1 :, col_idxs
1999- ]
2000- return flatten_chains (self ._draws )[:, col_idxs ] # type: ignore
1991+ draw1 = self .mcmc_sample .num_draws_warmup
1992+ num_draws = self .mcmc_sample .num_draws_sampling
1993+ if (
1994+ inc_warmup
1995+ and self .mcmc_sample .metadata .cmdstan_config ['save_warmup' ]
1996+ ):
1997+ num_draws += self .mcmc_sample .num_draws_warmup
1998+ dims = [num_draws * self .chains ]
1999+ col_idxs = self ._metadata .stan_vars_cols [var ]
2000+ if len (col_idxs ) > 0 :
2001+ dims .extend (self ._metadata .stan_vars_dims [var ])
2002+ # pylint: disable=redundant-keyword-arg
2003+ return self ._draws [draw1 :, :, col_idxs ].reshape ( # type: ignore
2004+ dims , order = 'F'
2005+ )
20012006
20022007 def stan_variables (self , inc_warmup : bool = False ) -> Dict [str , np .ndarray ]:
20032008 """
@@ -2143,7 +2148,7 @@ def metadata(self) -> InferenceMetadata:
21432148
21442149 def stan_variable (
21452150 self , var : Optional [str ] = None , * , name : Optional [str ] = None
2146- ) -> np .ndarray :
2151+ ) -> Union [ np .ndarray , float ] :
21472152 """
21482153 Return a numpy.ndarray which contains the estimates for the
21492154 for the named Stan program variable where the dimensions of the
@@ -2172,14 +2177,18 @@ def stan_variable(
21722177 if var not in self ._metadata .stan_vars_dims :
21732178 raise ValueError ('Unknown variable name: {}' .format (var ))
21742179 col_idxs = list (self ._metadata .stan_vars_cols [var ])
2175- vals = list (self ._variational_mean )
2176- xs = [vals [x ] for x in col_idxs ]
21772180 shape : Tuple [int , ...] = ()
2178- if len (col_idxs ) > 0 :
2181+ if len (col_idxs ) > 1 :
21792182 shape = self ._metadata .stan_vars_dims [var ]
2180- return np .array (xs ).reshape (shape )
2183+ result = np .asarray (self ._variational_mean )[col_idxs ].reshape (
2184+ shape , order = "F"
2185+ )
2186+ else :
2187+ result = float (self ._variational_mean [col_idxs [0 ]])
2188+ assert isinstance (result , (np .ndarray , float ))
2189+ return result
21812190
2182- def stan_variables (self ) -> Dict [str , np .ndarray ]:
2191+ def stan_variables (self ) -> Dict [str , Union [ np .ndarray , float ] ]:
21832192 """
21842193 Return a dictionary mapping Stan program variables names
21852194 to the corresponding numpy.ndarray containing the inferred values.
@@ -2424,7 +2433,12 @@ def build_xarray_data(
24242433 var_dims : Tuple [str , ...] = ('draw' , 'chain' )
24252434 if dims :
24262435 var_dims += tuple (f"{ var_name } _dim_{ i } " for i in range (len (dims )))
2427- data [var_name ] = (var_dims , drawset [start_row :, :, col_idxs ])
2436+ data [var_name ] = (
2437+ var_dims ,
2438+ drawset [start_row :, :, col_idxs ].reshape (
2439+ * drawset .shape [:2 ], * dims , order = "F"
2440+ ),
2441+ )
24282442 else :
24292443 data [var_name ] = (
24302444 var_dims ,
0 commit comments