Skip to content

Commit 7cf6483

Browse files
authored
Merge pull request #677 from stan-dev/feature/676-pandas-columns
Add "chain__", "iter__", and "draw__" columns to draws_pd
2 parents: b1aff80 + feed16c; commit: 7cf6483

4 files changed

Lines changed: 155 additions & 52 deletions

File tree

cmdstanpy/stanfit/gq.py

Lines changed: 36 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -344,25 +344,49 @@ def draws_pd(
344344

345345
previous_draws_pd = self._previous_draws_pd(mcmc_vars, inc_warmup)
346346

347+
draws = self.draws(inc_warmup=inc_warmup)
348+
# add long-form columns for chain, iteration, draw
349+
n_draws, n_chains, _ = draws.shape
350+
chains_col = (
351+
np.repeat(np.arange(1, n_chains + 1), n_draws)
352+
.reshape(1, n_chains, n_draws)
353+
.T
354+
)
355+
iter_col = (
356+
np.tile(np.arange(1, n_draws + 1), n_chains)
357+
.reshape(1, n_chains, n_draws)
358+
.T
359+
)
360+
draw_col = (
361+
np.arange(1, (n_draws * n_chains) + 1)
362+
.reshape(1, n_chains, n_draws)
363+
.T
364+
)
365+
draws = np.concatenate([chains_col, iter_col, draw_col, draws], axis=2)
366+
367+
vars_list = ['chain__', 'iter__', 'draw__'] + vars_list
368+
if gq_cols:
369+
gq_cols = ['chain__', 'iter__', 'draw__'] + gq_cols
370+
371+
draws_pd = pd.DataFrame(
372+
data=flatten_chains(draws),
373+
columns=['chain__', 'iter__', 'draw__'] + list(self.column_names),
374+
)
375+
347376
if inc_sample and mcmc_vars:
348377
if gq_cols:
349378
return pd.concat(
350379
[
351380
previous_draws_pd,
352-
pd.DataFrame(
353-
data=flatten_chains(
354-
self.draws(inc_warmup=inc_warmup)
355-
),
356-
columns=self.column_names,
357-
)[gq_cols],
381+
draws_pd[gq_cols],
358382
],
359383
axis='columns',
360384
)[vars_list]
361385
else:
362386
return previous_draws_pd
363387
elif inc_sample and vars is None:
364-
cols_1 = self.previous_fit.column_names
365-
cols_2 = self.column_names
388+
cols_1 = list(previous_draws_pd.columns)
389+
cols_2 = list(draws_pd.columns)
366390
dups = [
367391
item
368392
for item, count in Counter(cols_1 + cols_2).items()
@@ -371,23 +395,14 @@ def draws_pd(
371395
return pd.concat(
372396
[
373397
previous_draws_pd.drop(columns=dups).reset_index(drop=True),
374-
pd.DataFrame(
375-
data=flatten_chains(self.draws(inc_warmup=inc_warmup)),
376-
columns=self.column_names,
377-
),
398+
draws_pd,
378399
],
379400
axis=1,
380401
)
381402
elif gq_cols:
382-
return pd.DataFrame(
383-
data=flatten_chains(self.draws(inc_warmup=inc_warmup)),
384-
columns=self.column_names,
385-
)[gq_cols]
386-
387-
return pd.DataFrame(
388-
data=flatten_chains(self.draws(inc_warmup=inc_warmup)),
389-
columns=self.column_names,
390-
)
403+
return draws_pd[gq_cols]
404+
405+
return draws_pd
391406

392407
@overload
393408
def draws_xr(
@@ -657,7 +672,6 @@ def _draws_start(self, inc_warmup: bool) -> Tuple[int, int]:
657672
elif isinstance(p_fit, CmdStanMLE):
658673
num_draws = 1
659674
if p_fit._save_iterations:
660-
661675
opt_iters = len(p_fit.optimized_iterations_np) # type: ignore
662676
if inc_warmup:
663677
num_draws = opt_iters
@@ -706,7 +720,6 @@ def _previous_draws_pd(
706720
return p_fit.draws_pd(vars or None, inc_warmup=inc_warmup)
707721

708722
elif isinstance(p_fit, CmdStanMLE):
709-
710723
if inc_warmup and p_fit._save_iterations:
711724
return p_fit.optimized_iterations_pd[sel] # type: ignore
712725
else:

cmdstanpy/stanfit/mcmc.py

Lines changed: 24 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -615,9 +615,31 @@ def draws_pd(
615615
else:
616616
cols = list(self.column_names)
617617

618+
draws = self.draws(inc_warmup=inc_warmup)
619+
# add long-form columns for chain, iteration, draw
620+
n_draws, n_chains, _ = draws.shape
621+
chains_col = (
622+
np.repeat(np.arange(1, n_chains + 1), n_draws)
623+
.reshape(1, n_chains, n_draws)
624+
.T
625+
)
626+
iter_col = (
627+
np.tile(np.arange(1, n_draws + 1), n_chains)
628+
.reshape(1, n_chains, n_draws)
629+
.T
630+
)
631+
draw_col = (
632+
np.arange(1, (n_draws * n_chains) + 1)
633+
.reshape(1, n_chains, n_draws)
634+
.T
635+
)
636+
draws = np.concatenate([chains_col, iter_col, draw_col, draws], axis=2)
637+
638+
cols = ['chain__', 'iter__', 'draw__'] + cols
639+
618640
return pd.DataFrame(
619-
data=flatten_chains(self.draws(inc_warmup=inc_warmup)),
620-
columns=self.column_names,
641+
data=flatten_chains(draws),
642+
columns=['chain__', 'iter__', 'draw__'] + list(self.column_names),
621643
)[cols]
622644

623645
def draws_xr(

test/test_generate_quantities.py

Lines changed: 54 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -77,14 +77,56 @@ def test_from_csv_files(caplog: pytest.LogCaptureFixture) -> None:
7777
)
7878

7979
# draws_pd()
80-
assert bern_gqs.draws_pd().shape == (400, 10)
80+
assert bern_gqs.draws_pd().shape == (400, 13)
8181
assert (
8282
bern_gqs.draws_pd(inc_sample=True).shape[1]
8383
== bern_gqs.previous_fit.draws_pd().shape[1]
8484
+ bern_gqs.draws_pd().shape[1]
85+
- 3 # chain, iter, draw duplicates
8586
)
8687

87-
assert list(bern_gqs.draws_pd(vars=['y_rep']).columns) == column_names
88+
assert list(bern_gqs.draws_pd(vars=['y_rep']).columns) == (
89+
["chain__", "iter__", "draw__"] + column_names
90+
)
91+
92+
93+
def test_pd_xr_agreement():
94+
# fitted_params sample - list of filenames
95+
goodfiles_path = os.path.join(DATAFILES_PATH, 'runset-good', 'bern')
96+
csv_files = []
97+
for i in range(4):
98+
csv_files.append('{}-{}.csv'.format(goodfiles_path, i + 1))
99+
100+
# gq_model
101+
stan = os.path.join(DATAFILES_PATH, 'bernoulli_ppc.stan')
102+
model = CmdStanModel(stan_file=stan)
103+
jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
104+
105+
bern_gqs = model.generate_quantities(data=jdata, previous_fit=csv_files)
106+
107+
draws_pd = bern_gqs.draws_pd(inc_sample=True)
108+
draws_xr = bern_gqs.draws_xr(inc_sample=True)
109+
110+
# check that the indexing is the same between the two
111+
np.testing.assert_equal(
112+
draws_pd[draws_pd['chain__'] == 2]['y_rep[1]'],
113+
draws_xr.y_rep.sel(chain=2).isel(y_rep_dim_0=0).values,
114+
)
115+
# "draw" is 0-indexed in xarray, equiv. "iter__" is 1-indexed in pandas
116+
np.testing.assert_equal(
117+
draws_pd[draws_pd['iter__'] == 100]['y_rep[1]'],
118+
draws_xr.y_rep.sel(draw=99).isel(y_rep_dim_0=0).values,
119+
)
120+
121+
# check for included sample as well
122+
np.testing.assert_equal(
123+
draws_pd[draws_pd['chain__'] == 2]['theta'],
124+
draws_xr.theta.sel(chain=2).values,
125+
)
126+
np.testing.assert_equal(
127+
draws_pd[draws_pd['iter__'] == 100]['theta'],
128+
draws_xr.theta.sel(draw=99).values,
129+
)
88130

89131

90132
def test_from_csv_files_bad() -> None:
@@ -153,16 +195,17 @@ def test_from_previous_fit_draws() -> None:
153195

154196
bern_gqs = model.generate_quantities(data=jdata, previous_fit=bern_fit)
155197

156-
assert bern_gqs.draws_pd().shape == (400, 10)
198+
assert bern_gqs.draws_pd().shape == (400, 13)
157199
assert (
158200
bern_gqs.draws_pd(inc_sample=True).shape[1]
159201
== bern_gqs.previous_fit.draws_pd().shape[1]
160202
+ bern_gqs.draws_pd().shape[1]
203+
- 3 # duplicates of chain, iter, and draw
161204
)
162205
row1_sample_pd = bern_fit.draws_pd().iloc[0]
163206
row1_gqs_pd = bern_gqs.draws_pd().iloc[0]
164207
np.testing.assert_array_equal(
165-
pd.concat((row1_sample_pd, row1_gqs_pd), axis=0).values,
208+
pd.concat((row1_sample_pd, row1_gqs_pd), axis=0).values[3:],
166209
bern_gqs.draws_pd(inc_sample=True).iloc[0].values,
167210
)
168211
# draws_xr
@@ -267,14 +310,14 @@ def test_save_warmup(caplog: pytest.LogCaptureFixture) -> None:
267310
10,
268311
)
269312

270-
assert bern_gqs.draws_pd().shape == (400, 10)
271-
assert bern_gqs.draws_pd(inc_warmup=False).shape == (400, 10)
272-
assert bern_gqs.draws_pd(inc_warmup=True).shape == (800, 10)
313+
assert bern_gqs.draws_pd().shape == (400, 13)
314+
assert bern_gqs.draws_pd(inc_warmup=False).shape == (400, 13)
315+
assert bern_gqs.draws_pd(inc_warmup=True).shape == (800, 13)
273316
assert bern_gqs.draws_pd(vars=['y_rep'], inc_warmup=False).shape == (
274317
400,
275-
10,
318+
13,
276319
)
277-
assert bern_gqs.draws_pd(vars='y_rep', inc_warmup=False).shape == (400, 10)
320+
assert bern_gqs.draws_pd(vars='y_rep', inc_warmup=False).shape == (400, 13)
278321

279322
theta = bern_gqs.stan_variable(var='theta')
280323
assert theta.shape == (400,)
@@ -523,7 +566,7 @@ def test_from_optimization() -> None:
523566
assert bern_gqs.draws(inc_sample=True).shape == (1, 1, 12)
524567

525568
# draws_pd()
526-
assert bern_gqs.draws_pd().shape == (1, 10)
569+
assert bern_gqs.draws_pd().shape == (1, 13)
527570
assert (
528571
bern_gqs.draws_pd(inc_sample=True).shape[1]
529572
== bern_gqs.previous_fit.optimized_params_pd.shape[1]
@@ -665,7 +708,7 @@ def test_from_vb():
665708
assert bern_gqs.draws(inc_sample=True).shape == (1000, 1, 14)
666709

667710
# draws_pd()
668-
assert bern_gqs.draws_pd().shape == (1000, 10)
711+
assert bern_gqs.draws_pd().shape == (1000, 13)
669712
assert (
670713
bern_gqs.draws_pd(inc_sample=True).shape[1]
671714
== bern_gqs.previous_fit.variational_sample_pd.shape[1]

test/test_sample.py

Lines changed: 41 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -745,18 +745,24 @@ def test_validate_good_run() -> None:
745745
draws_pd = fit.draws_pd()
746746
assert draws_pd.shape == (
747747
fit.runset.chains * fit.num_draws_sampling,
748-
len(fit.column_names),
748+
len(fit.column_names) + 3,
749749
)
750-
assert fit.draws_pd(vars=['theta']).shape == (400, 1)
751-
assert fit.draws_pd(vars=['lp__', 'theta']).shape == (400, 2)
752-
assert fit.draws_pd(vars=['theta', 'lp__']).shape == (400, 2)
753-
assert fit.draws_pd(vars='theta').shape == (400, 1)
750+
assert fit.draws_pd(vars=['theta']).shape == (400, 4)
751+
assert fit.draws_pd(vars=['lp__', 'theta']).shape == (400, 5)
752+
assert fit.draws_pd(vars=['theta', 'lp__']).shape == (400, 5)
753+
assert fit.draws_pd(vars='theta').shape == (400, 4)
754754

755755
assert list(fit.draws_pd(vars=['theta', 'lp__']).columns) == [
756+
'chain__',
757+
'iter__',
758+
'draw__',
756759
'theta',
757760
'lp__',
758761
]
759762
assert list(fit.draws_pd(vars=['lp__', 'theta']).columns) == [
763+
'chain__',
764+
'iter__',
765+
'draw__',
760766
'lp__',
761767
'theta',
762768
]
@@ -817,7 +823,7 @@ def test_validate_big_run() -> None:
817823
assert fit.step_size.shape == (2,)
818824
assert fit.metric.shape == (2, 2095)
819825
assert fit.draws().shape == (1000, 2, 2102)
820-
assert fit.draws_pd(vars=['phi']).shape == (2000, 2095)
826+
assert fit.draws_pd(vars=['phi']).shape == (2000, 2098)
821827
with raises_nested(ValueError, r'Unknown variable: gamma'):
822828
fit.draws_pd(vars=['gamma'])
823829

@@ -828,14 +834,14 @@ def test_instantiate_from_csvfiles() -> None:
828834
draws_pd = bern_fit.draws_pd()
829835
assert draws_pd.shape == (
830836
bern_fit.runset.chains * bern_fit.num_draws_sampling,
831-
len(bern_fit.column_names),
837+
len(bern_fit.column_names) + 3,
832838
)
833839
csvfiles_path = os.path.join(DATAFILES_PATH, 'runset-big')
834840
big_fit = from_csv(path=csvfiles_path)
835841
draws_pd = big_fit.draws_pd()
836842
assert draws_pd.shape == (
837843
big_fit.runset.chains * big_fit.num_draws_sampling,
838-
len(big_fit.column_names),
844+
len(big_fit.column_names) + 3,
839845
)
840846
# list
841847
csvfiles_path = os.path.join(DATAFILES_PATH, 'runset-good')
@@ -848,22 +854,41 @@ def test_instantiate_from_csvfiles() -> None:
848854
draws_pd = bern_fit.draws_pd()
849855
assert draws_pd.shape == (
850856
bern_fit.runset.chains * bern_fit.num_draws_sampling,
851-
len(bern_fit.column_names),
857+
len(bern_fit.column_names) + 3,
852858
)
853859
# single csvfile
854860
bern_fit = from_csv(path=csvfiles[0])
855861
draws_pd = bern_fit.draws_pd()
856862
assert draws_pd.shape == (
857863
bern_fit.num_draws_sampling,
858-
len(bern_fit.column_names),
864+
len(bern_fit.column_names) + 3,
859865
)
860866
# glob
861867
csvfiles_path = os.path.join(csvfiles_path, '*.csv')
862868
big_fit = from_csv(path=csvfiles_path)
863869
draws_pd = big_fit.draws_pd()
864870
assert draws_pd.shape == (
865871
big_fit.runset.chains * big_fit.num_draws_sampling,
866-
len(big_fit.column_names),
872+
len(big_fit.column_names) + 3,
873+
)
874+
875+
876+
def test_pd_xr_agreement():
877+
csvfiles_path = os.path.join(DATAFILES_PATH, 'runset-good', '*.csv')
878+
bern_fit = from_csv(path=csvfiles_path)
879+
880+
draws_pd = bern_fit.draws_pd()
881+
draws_xr = bern_fit.draws_xr()
882+
883+
# check that the indexing is the same between the two
884+
np.testing.assert_equal(
885+
draws_pd[draws_pd['chain__'] == 2]['theta'],
886+
draws_xr.theta.sel(chain=2).values,
887+
)
888+
# "draw" is 0-indexed in xarray, equiv. "iter__" is 1-indexed in pandas
889+
np.testing.assert_equal(
890+
draws_pd[draws_pd['iter__'] == 100]['theta'],
891+
draws_xr.theta.sel(draw=99).values,
867892
)
868893

869894

@@ -930,7 +955,7 @@ def test_instantiate_from_csvfiles_fail(
930955
def test_from_csv_fixed_param() -> None:
931956
csv_path = os.path.join(DATAFILES_PATH, 'fixed_param_sample.csv')
932957
fixed_param_sample = from_csv(path=csv_path)
933-
assert fixed_param_sample.draws_pd().shape == (100, 85)
958+
assert fixed_param_sample.draws_pd().shape == (100, 88)
934959

935960

936961
def test_custom_metric() -> None:
@@ -1292,14 +1317,14 @@ def test_save_warmup() -> None:
12921317
len(BERNOULLI_COLS),
12931318
)
12941319

1295-
assert bern_fit.draws_pd().shape == (200, len(BERNOULLI_COLS))
1320+
assert bern_fit.draws_pd().shape == (200, len(BERNOULLI_COLS) + 3)
12961321
assert bern_fit.draws_pd(inc_warmup=False).shape == (
12971322
200,
1298-
len(BERNOULLI_COLS),
1323+
len(BERNOULLI_COLS) + 3,
12991324
)
13001325
assert bern_fit.draws_pd(inc_warmup=True).shape == (
13011326
600,
1302-
len(BERNOULLI_COLS),
1327+
len(BERNOULLI_COLS) + 3,
13031328
)
13041329

13051330

@@ -1371,7 +1396,7 @@ def test_dont_save_warmup(caplog: pytest.LogCaptureFixture) -> None:
13711396
with caplog.at_level(logging.WARNING):
13721397
assert bern_fit.draws_pd(inc_warmup=True).shape == (
13731398
200,
1374-
len(BERNOULLI_COLS),
1399+
len(BERNOULLI_COLS) + 3,
13751400
)
13761401
check_present(
13771402
caplog,

0 commit comments

Comments
 (0)