Skip to content

Commit a14e9bd

Browse files
authored
Merge pull request #510 from qres/develop
assemble draws and generated quantities only once
2 parents 75b4d47 + 95ecd6b commit a14e9bd

1 file changed

Lines changed: 17 additions & 13 deletions

File tree

cmdstanpy/stanfit/mcmc.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,7 @@ def draws(
231231
CmdStanMCMC.draws_xr
232232
CmdStanGQ.draws
233233
"""
234-
if self._draws.size == 0:
234+
if self._draws.shape == (0,):
235235
self._assemble_draws()
236236

237237
if inc_warmup and not self._save_warmup:
@@ -309,9 +309,6 @@ def _assemble_draws(self) -> None:
309309
Allocates and populates the step size, metric, and sample arrays
310310
by parsing the validated stan_csv files.
311311
"""
312-
if self._draws.shape != (0,):
313-
return
314-
315312
num_draws = self.num_draws_sampling
316313
sampling_iter_start = 0
317314
if self._save_warmup:
@@ -527,7 +524,8 @@ def draws_pd(
527524
' must run sampler with "save_warmup=True".'
528525
)
529526

530-
self._assemble_draws()
527+
if self._draws.shape == (0,):
528+
self._assemble_draws()
531529
cols = []
532530
if vars is not None:
533531
for var in set(vars_list):
@@ -583,7 +581,8 @@ def draws_xr(
583581
else:
584582
vars_list = vars
585583

586-
self._assemble_draws()
584+
if self._draws.shape == (0,):
585+
self._assemble_draws()
587586

588587
num_draws = self.num_draws_sampling
589588
meta = self._metadata.cmdstan_config
@@ -663,7 +662,8 @@ def stan_variable(
663662
raise ValueError('No variable name specified.')
664663
if var not in self._metadata.stan_vars_dims:
665664
raise ValueError('Unknown variable name: {}'.format(var))
666-
self._assemble_draws()
665+
if self._draws.shape == (0,):
666+
self._assemble_draws()
667667
draw1 = 0
668668
if not inc_warmup and self._save_warmup:
669669
draw1 = self.num_draws_warmup
@@ -705,7 +705,8 @@ def method_variables(self) -> Dict[str, np.ndarray]:
705705
containing per-draw diagnostic values.
706706
"""
707707
result = {}
708-
self._assemble_draws()
708+
if self._draws.shape == (0,):
709+
self._assemble_draws()
709710
for idxs in self.metadata.method_vars_cols.values():
710711
for idx in idxs:
711712
result[self.column_names[idx]] = self._draws[:, :, idx]
@@ -868,7 +869,7 @@ def draws(
868869
CmdStanGQ.draws_xr
869870
CmdStanMCMC.draws
870871
"""
871-
if self._draws.size == 0:
872+
if self._draws.shape == (0,):
872873
self._assemble_generated_quantities()
873874
if (
874875
inc_warmup
@@ -955,7 +956,8 @@ def draws_pd(
955956
'Draws from warmup iterations not available,'
956957
' must run sampler with "save_warmup=True".'
957958
)
958-
self._assemble_generated_quantities()
959+
if self._draws.shape == (0,):
960+
self._assemble_generated_quantities()
959961

960962
gq_cols = []
961963
mcmc_vars = []
@@ -1076,7 +1078,8 @@ def draws_xr(
10761078
for var in dup_vars:
10771079
vars_list.remove(var)
10781080

1079-
self._assemble_generated_quantities()
1081+
if self._draws.shape == (0,):
1082+
self._assemble_generated_quantities()
10801083

10811084
num_draws = self.mcmc_sample.num_draws_sampling
10821085
sample_config = self.mcmc_sample.metadata.cmdstan_config
@@ -1173,7 +1176,8 @@ def stan_variable(
11731176
if var not in gq_var_names:
11741177
return self.mcmc_sample.stan_variable(var, inc_warmup=inc_warmup)
11751178
else: # is gq variable
1176-
self._assemble_generated_quantities()
1179+
if self._draws.shape == (0,):
1180+
self._assemble_generated_quantities()
11771181
draw1 = 0
11781182
if (
11791183
not inc_warmup
@@ -1222,7 +1226,7 @@ def stan_variables(self, inc_warmup: bool = False) -> Dict[str, np.ndarray]:
12221226
return result
12231227

12241228
def _assemble_generated_quantities(self) -> None:
1225-
# use numpy genfromtext
1229+
# use numpy loadtxt
12261230
warmup = self.mcmc_sample.metadata.cmdstan_config['save_warmup']
12271231
num_draws = self.mcmc_sample.draws(inc_warmup=warmup).shape[0]
12281232
gq_sample = np.empty(

0 commit comments

Comments
 (0)