Skip to content

Commit d6827d4

Browse files
committed
stanio 0.2
1 parent 972c1e7 commit d6827d4

6 files changed

Lines changed: 15 additions & 13 deletions

File tree

.github/workflows/main.yml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,11 @@ on:
1515
required: false
1616
default: ''
1717

18+
# only run one copy per PR
19+
concurrency:
20+
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
21+
cancel-in-progress: true
22+
1823
jobs:
1924
get-cmdstan-version:
2025
# get the latest cmdstan version to use as part of the cache key

cmdstanpy/stanfit/gq.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -587,17 +587,14 @@ def stan_variable(
587587
)
588588
elif isinstance(self.previous_fit, CmdStanMLE):
589589
return np.atleast_1d( # type: ignore
590-
np.asarray(
591-
self.previous_fit.stan_variable(
592-
var, inc_iterations=inc_warmup
593-
)
590+
self.previous_fit.stan_variable(
591+
var, inc_iterations=inc_warmup
594592
)
595593
)
596594
else:
597595
return np.atleast_1d( # type: ignore
598-
np.asarray(self.previous_fit.stan_variable(var))
596+
self.previous_fit.stan_variable(var)
599597
)
600-
601598
# is gq variable
602599
self._assemble_generated_quantities()
603600

cmdstanpy/stanfit/metadata.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,12 +39,12 @@ def cmdstan_config(self) -> Dict[str, Any]:
3939
return copy.deepcopy(self._cmdstan_config)
4040

4141
@property
42-
def method_vars(self) -> Dict[str, stanio.Parameter]:
42+
def method_vars(self) -> Dict[str, stanio.Variable]:
4343
"""
4444
Method variable names always end in `__`, e.g. `lp__`.
4545
"""
4646
return self._method_vars
4747

4848
@property
49-
def stan_vars(self) -> Dict[str, stanio.Parameter]:
49+
def stan_vars(self) -> Dict[str, stanio.Variable]:
5050
return self._stan_vars

cmdstanpy/stanfit/mle.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -209,11 +209,11 @@ def stan_variable(
209209
if inc_iterations and self._save_iterations:
210210
data = self._all_iters
211211
else:
212-
data = np.atleast_2d(self._mle)
212+
data = self._mle
213213

214214
try:
215-
out: np.ndarray = (
216-
self._metadata.stan_vars[var].extract_reshape(data).squeeze()
215+
out: np.ndarray = self._metadata.stan_vars[var].extract_reshape(
216+
data
217217
)
218218
return out
219219
except KeyError:

cmdstanpy/utils/data_munging.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def flatten_chains(draws_array: np.ndarray) -> np.ndarray:
2828

2929
def build_xarray_data(
3030
data: MutableMapping[Hashable, Tuple[Tuple[str, ...], np.ndarray]],
31-
var: stanio.Parameter,
31+
var: stanio.Variable,
3232
drawset: np.ndarray,
3333
) -> None:
3434
"""

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
pandas
22
numpy>=1.21
33
tqdm
4-
stanio~=0.1.0
4+
stanio~=0.2.0

0 commit comments

Comments
 (0)