@@ -1074,7 +1074,6 @@ def draws_xr(
10741074 0 ,
10751075 self .draws (inc_warmup = inc_warmup ),
10761076 )
1077-
10781077 return xr .Dataset (data , coords = coordinates , attrs = attrs ).transpose (
10791078 'chain' , 'draw' , ...
10801079 )
@@ -1373,7 +1372,7 @@ def stan_variable(
13731372 inc_iterations : bool = False ,
13741373 warn : bool = True ,
13751374 name : Optional [str ] = None ,
1376- ) -> np .ndarray :
1375+ ) -> Union [ np .ndarray , float ] :
13771376 """
13781377 Return a numpy.ndarray which contains the estimates for the
13791378 for the named Stan program variable where the dimensions of the
@@ -1416,33 +1415,29 @@ def stan_variable(
14161415 'Invalid estimate, optimization failed to converge.'
14171416 )
14181417
1419- col_idxs = self ._metadata .stan_vars_cols [var ]
1418+ col_idxs = list ( self ._metadata .stan_vars_cols [var ])
14201419 if inc_iterations and self ._save_iterations :
14211420 num_rows = self ._all_iters .shape [0 ]
14221421 else :
14231422 num_rows = 1
1424-
1425- if len (col_idxs ) > 0 : # container var
1423+ if len (col_idxs ) > 1 : # container var
14261424 dims = (num_rows ,) + self ._metadata .stan_vars_dims [var ]
14271425 # pylint: disable=redundant-keyword-arg
14281426 if num_rows > 1 :
14291427 result = self ._all_iters [:, col_idxs ].reshape ( # type: ignore
14301428 dims , order = 'F'
14311429 )
14321430 else :
1433- mle = np .expand_dims (self ._mle , axis = 0 ) # hack for col indexing
1434- result = (
1435- mle [0 , col_idxs ]
1436- .reshape (dims , order = 'F' ) # type: ignore
1437- .squeeze (axis = 0 )
1438- )
1431+ result = self ._mle [col_idxs ].reshape (dims [1 :], order = "F" )
14391432 else : # scalar var
1433+ col_idx = col_idxs [0 ]
14401434 if num_rows > 1 :
1441- result = self ._all_iters [:, col_idxs ]
1435+ result = self ._all_iters [:, col_idx ]
14421436 else :
1443- result = np .atleast_1d (mle [0 , col_idxs ])
1444-
1445- assert isinstance (result , np .ndarray ) # make the typechecker happy
1437+ result = float (self ._mle [col_idx ])
1438+ assert isinstance (
1439+ result , (np .ndarray , float )
1440+ ) # make the typechecker happy
14461441 return result
14471442
14481443 def stan_variables (
@@ -2143,7 +2138,7 @@ def metadata(self) -> InferenceMetadata:
21432138
21442139 def stan_variable (
21452140 self , var : Optional [str ] = None , * , name : Optional [str ] = None
2146- ) -> np .ndarray :
2141+ ) -> Union [ np .ndarray , float ] :
21472142 """
21482143 Return a numpy.ndarray which contains the estimates for the
21492144 for the named Stan program variable where the dimensions of the
@@ -2172,12 +2167,16 @@ def stan_variable(
21722167 if var not in self ._metadata .stan_vars_dims :
21732168 raise ValueError ('Unknown variable name: {}' .format (var ))
21742169 col_idxs = list (self ._metadata .stan_vars_cols [var ])
2175- vals = list (self ._variational_mean )
2176- xs = [vals [x ] for x in col_idxs ]
21772170 shape : Tuple [int , ...] = ()
2178- if len (col_idxs ) > 0 :
2171+ if len (col_idxs ) > 1 :
21792172 shape = self ._metadata .stan_vars_dims [var ]
2180- return np .array (xs ).reshape (shape )
2173+ result = np .asarray (self ._variational_mean )[col_idxs ].reshape (
2174+ shape , order = "F"
2175+ )
2176+ else :
2177+ result = float (self ._variational_mean [col_idxs [0 ]])
2178+ assert isinstance (result , (np .ndarray , float ))
2179+ return result
21812180
21822181 def stan_variables (self ) -> Dict [str , np .ndarray ]:
21832182 """
@@ -2424,7 +2423,12 @@ def build_xarray_data(
24242423 var_dims : Tuple [str , ...] = ('draw' , 'chain' )
24252424 if dims :
24262425 var_dims += tuple (f"{ var_name } _dim_{ i } " for i in range (len (dims )))
2427- data [var_name ] = (var_dims , drawset [start_row :, :, col_idxs ])
2426+ data [var_name ] = (
2427+ var_dims ,
2428+ drawset [start_row :, :, col_idxs ].reshape (
2429+ * drawset .shape [:2 ], * dims , order = "F"
2430+ ),
2431+ )
24282432 else :
24292433 data [var_name ] = (
24302434 var_dims ,
0 commit comments