Skip to content

Commit 5dff5a7

Browse files
committed
Add chain, iter, draw columns to draws_pd
Still needed in GQ
1 parent 1361f35 commit 5dff5a7

1 file changed

Lines changed: 24 additions & 2 deletions

File tree

cmdstanpy/stanfit/mcmc.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -615,9 +615,31 @@ def draws_pd(
615615
else:
616616
cols = list(self.column_names)
617617

618+
draws = self.draws(inc_warmup=inc_warmup)
619+
# add long-form columns for chain, iteration, draw
620+
n_draws, n_chains, _ = draws.shape
621+
chains_col = (
622+
np.repeat(np.arange(1, n_chains + 1), n_draws)
623+
.reshape(1, n_chains, n_draws)
624+
.T
625+
)
626+
iter_col = (
627+
np.tile(np.arange(1, n_draws + 1), n_chains)
628+
.reshape(1, n_chains, n_draws)
629+
.T
630+
)
631+
draw_col = (
632+
np.arange(1, (n_draws * n_chains) + 1)
633+
.reshape(1, n_chains, n_draws)
634+
.T
635+
)
636+
draws = np.concatenate([chains_col, iter_col, draw_col, draws], axis=2)
637+
638+
cols = ['chain', 'iter', 'draw'] + cols
639+
618640
return pd.DataFrame(
619-
data=flatten_chains(self.draws(inc_warmup=inc_warmup)),
620-
columns=self.column_names,
641+
data=flatten_chains(draws),
642+
columns=['chain', 'iter', 'draw'] + list(self.column_names),
621643
)[cols]
622644

623645
def draws_xr(

0 commit comments

Comments
 (0)