Skip to content

Commit b513445

Browse files
committed
Rename, add gq, fix tests
1 parent 5dff5a7 commit b513445

4 files changed

Lines changed: 70 additions & 47 deletions

File tree

cmdstanpy/stanfit/gq.py

Lines changed: 34 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -344,17 +344,41 @@ def draws_pd(
344344

345345
previous_draws_pd = self._previous_draws_pd(mcmc_vars, inc_warmup)
346346

347+
draws = self.draws(inc_warmup=inc_warmup)
348+
# add long-form columns for chain, iteration, draw
349+
n_draws, n_chains, _ = draws.shape
350+
chains_col = (
351+
np.repeat(np.arange(1, n_chains + 1), n_draws)
352+
.reshape(1, n_chains, n_draws)
353+
.T
354+
)
355+
iter_col = (
356+
np.tile(np.arange(1, n_draws + 1), n_chains)
357+
.reshape(1, n_chains, n_draws)
358+
.T
359+
)
360+
draw_col = (
361+
np.arange(1, (n_draws * n_chains) + 1)
362+
.reshape(1, n_chains, n_draws)
363+
.T
364+
)
365+
draws = np.concatenate([chains_col, iter_col, draw_col, draws], axis=2)
366+
367+
vars_list = ['chain__', 'iter__', 'draw__'] + vars_list
368+
if gq_cols:
369+
gq_cols = ['chain__', 'iter__', 'draw__'] + gq_cols
370+
371+
draws_pd = pd.DataFrame(
372+
data=flatten_chains(draws),
373+
columns=['chain__', 'iter__', 'draw__'] + list(self.column_names),
374+
)
375+
347376
if inc_sample and mcmc_vars:
348377
if gq_cols:
349378
return pd.concat(
350379
[
351380
previous_draws_pd,
352-
pd.DataFrame(
353-
data=flatten_chains(
354-
self.draws(inc_warmup=inc_warmup)
355-
),
356-
columns=self.column_names,
357-
)[gq_cols],
381+
draws_pd[gq_cols],
358382
],
359383
axis='columns',
360384
)[vars_list]
@@ -371,23 +395,14 @@ def draws_pd(
371395
return pd.concat(
372396
[
373397
previous_draws_pd.drop(columns=dups).reset_index(drop=True),
374-
pd.DataFrame(
375-
data=flatten_chains(self.draws(inc_warmup=inc_warmup)),
376-
columns=self.column_names,
377-
),
398+
draws_pd,
378399
],
379400
axis=1,
380401
)
381402
elif gq_cols:
382-
return pd.DataFrame(
383-
data=flatten_chains(self.draws(inc_warmup=inc_warmup)),
384-
columns=self.column_names,
385-
)[gq_cols]
386-
387-
return pd.DataFrame(
388-
data=flatten_chains(self.draws(inc_warmup=inc_warmup)),
389-
columns=self.column_names,
390-
)
403+
return draws_pd[gq_cols]
404+
405+
return draws_pd
391406

392407
@overload
393408
def draws_xr(

cmdstanpy/stanfit/mcmc.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -635,11 +635,11 @@ def draws_pd(
635635
)
636636
draws = np.concatenate([chains_col, iter_col, draw_col, draws], axis=2)
637637

638-
cols = ['chain', 'iter', 'draw'] + cols
638+
cols = ['chain__', 'iter__', 'draw__'] + cols
639639

640640
return pd.DataFrame(
641641
data=flatten_chains(draws),
642-
columns=['chain', 'iter', 'draw'] + list(self.column_names),
642+
columns=['chain__', 'iter__', 'draw__'] + list(self.column_names),
643643
)[cols]
644644

645645
def draws_xr(

test/test_generate_quantities.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -77,14 +77,16 @@ def test_from_csv_files(caplog: pytest.LogCaptureFixture) -> None:
7777
)
7878

7979
# draws_pd()
80-
assert bern_gqs.draws_pd().shape == (400, 10)
80+
assert bern_gqs.draws_pd().shape == (400, 13)
8181
assert (
8282
bern_gqs.draws_pd(inc_sample=True).shape[1]
8383
== bern_gqs.previous_fit.draws_pd().shape[1]
8484
+ bern_gqs.draws_pd().shape[1]
8585
)
8686

87-
assert list(bern_gqs.draws_pd(vars=['y_rep']).columns) == column_names
87+
assert list(bern_gqs.draws_pd(vars=['y_rep']).columns) == (
88+
["chain__", "iter__", "draw__"] + column_names
89+
)
8890

8991

9092
def test_from_csv_files_bad() -> None:
@@ -153,7 +155,7 @@ def test_from_previous_fit_draws() -> None:
153155

154156
bern_gqs = model.generate_quantities(data=jdata, previous_fit=bern_fit)
155157

156-
assert bern_gqs.draws_pd().shape == (400, 10)
158+
assert bern_gqs.draws_pd().shape == (400, 13)
157159
assert (
158160
bern_gqs.draws_pd(inc_sample=True).shape[1]
159161
== bern_gqs.previous_fit.draws_pd().shape[1]
@@ -267,14 +269,14 @@ def test_save_warmup(caplog: pytest.LogCaptureFixture) -> None:
267269
10,
268270
)
269271

270-
assert bern_gqs.draws_pd().shape == (400, 10)
271-
assert bern_gqs.draws_pd(inc_warmup=False).shape == (400, 10)
272-
assert bern_gqs.draws_pd(inc_warmup=True).shape == (800, 10)
272+
assert bern_gqs.draws_pd().shape == (400, 13)
273+
assert bern_gqs.draws_pd(inc_warmup=False).shape == (400, 13)
274+
assert bern_gqs.draws_pd(inc_warmup=True).shape == (800, 13)
273275
assert bern_gqs.draws_pd(vars=['y_rep'], inc_warmup=False).shape == (
274276
400,
275-
10,
277+
13,
276278
)
277-
assert bern_gqs.draws_pd(vars='y_rep', inc_warmup=False).shape == (400, 10)
279+
assert bern_gqs.draws_pd(vars='y_rep', inc_warmup=False).shape == (400, 13)
278280

279281
theta = bern_gqs.stan_variable(var='theta')
280282
assert theta.shape == (400,)
@@ -523,7 +525,7 @@ def test_from_optimization() -> None:
523525
assert bern_gqs.draws(inc_sample=True).shape == (1, 1, 12)
524526

525527
# draws_pd()
526-
assert bern_gqs.draws_pd().shape == (1, 10)
528+
assert bern_gqs.draws_pd().shape == (1, 13)
527529
assert (
528530
bern_gqs.draws_pd(inc_sample=True).shape[1]
529531
== bern_gqs.previous_fit.optimized_params_pd.shape[1]
@@ -665,7 +667,7 @@ def test_from_vb():
665667
assert bern_gqs.draws(inc_sample=True).shape == (1000, 1, 14)
666668

667669
# draws_pd()
668-
assert bern_gqs.draws_pd().shape == (1000, 10)
670+
assert bern_gqs.draws_pd().shape == (1000, 13)
669671
assert (
670672
bern_gqs.draws_pd(inc_sample=True).shape[1]
671673
== bern_gqs.previous_fit.variational_sample_pd.shape[1]

test/test_sample.py

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -745,18 +745,24 @@ def test_validate_good_run() -> None:
745745
draws_pd = fit.draws_pd()
746746
assert draws_pd.shape == (
747747
fit.runset.chains * fit.num_draws_sampling,
748-
len(fit.column_names),
748+
len(fit.column_names) + 3,
749749
)
750-
assert fit.draws_pd(vars=['theta']).shape == (400, 1)
751-
assert fit.draws_pd(vars=['lp__', 'theta']).shape == (400, 2)
752-
assert fit.draws_pd(vars=['theta', 'lp__']).shape == (400, 2)
753-
assert fit.draws_pd(vars='theta').shape == (400, 1)
750+
assert fit.draws_pd(vars=['theta']).shape == (400, 4)
751+
assert fit.draws_pd(vars=['lp__', 'theta']).shape == (400, 5)
752+
assert fit.draws_pd(vars=['theta', 'lp__']).shape == (400, 5)
753+
assert fit.draws_pd(vars='theta').shape == (400, 4)
754754

755755
assert list(fit.draws_pd(vars=['theta', 'lp__']).columns) == [
756+
'chain__',
757+
'iter__',
758+
'draw__',
756759
'theta',
757760
'lp__',
758761
]
759762
assert list(fit.draws_pd(vars=['lp__', 'theta']).columns) == [
763+
'chain__',
764+
'iter__',
765+
'draw__',
760766
'lp__',
761767
'theta',
762768
]
@@ -817,7 +823,7 @@ def test_validate_big_run() -> None:
817823
assert fit.step_size.shape == (2,)
818824
assert fit.metric.shape == (2, 2095)
819825
assert fit.draws().shape == (1000, 2, 2102)
820-
assert fit.draws_pd(vars=['phi']).shape == (2000, 2095)
826+
assert fit.draws_pd(vars=['phi']).shape == (2000, 2098)
821827
with raises_nested(ValueError, r'Unknown variable: gamma'):
822828
fit.draws_pd(vars=['gamma'])
823829

@@ -828,14 +834,14 @@ def test_instantiate_from_csvfiles() -> None:
828834
draws_pd = bern_fit.draws_pd()
829835
assert draws_pd.shape == (
830836
bern_fit.runset.chains * bern_fit.num_draws_sampling,
831-
len(bern_fit.column_names),
837+
len(bern_fit.column_names) + 3,
832838
)
833839
csvfiles_path = os.path.join(DATAFILES_PATH, 'runset-big')
834840
big_fit = from_csv(path=csvfiles_path)
835841
draws_pd = big_fit.draws_pd()
836842
assert draws_pd.shape == (
837843
big_fit.runset.chains * big_fit.num_draws_sampling,
838-
len(big_fit.column_names),
844+
len(big_fit.column_names) + 3,
839845
)
840846
# list
841847
csvfiles_path = os.path.join(DATAFILES_PATH, 'runset-good')
@@ -848,22 +854,22 @@ def test_instantiate_from_csvfiles() -> None:
848854
draws_pd = bern_fit.draws_pd()
849855
assert draws_pd.shape == (
850856
bern_fit.runset.chains * bern_fit.num_draws_sampling,
851-
len(bern_fit.column_names),
857+
len(bern_fit.column_names) + 3,
852858
)
853859
# single csvfile
854860
bern_fit = from_csv(path=csvfiles[0])
855861
draws_pd = bern_fit.draws_pd()
856862
assert draws_pd.shape == (
857863
bern_fit.num_draws_sampling,
858-
len(bern_fit.column_names),
864+
len(bern_fit.column_names) + 3,
859865
)
860866
# glob
861867
csvfiles_path = os.path.join(csvfiles_path, '*.csv')
862868
big_fit = from_csv(path=csvfiles_path)
863869
draws_pd = big_fit.draws_pd()
864870
assert draws_pd.shape == (
865871
big_fit.runset.chains * big_fit.num_draws_sampling,
866-
len(big_fit.column_names),
872+
len(big_fit.column_names) + 3,
867873
)
868874

869875

@@ -930,7 +936,7 @@ def test_instantiate_from_csvfiles_fail(
930936
def test_from_csv_fixed_param() -> None:
931937
csv_path = os.path.join(DATAFILES_PATH, 'fixed_param_sample.csv')
932938
fixed_param_sample = from_csv(path=csv_path)
933-
assert fixed_param_sample.draws_pd().shape == (100, 85)
939+
assert fixed_param_sample.draws_pd().shape == (100, 88)
934940

935941

936942
def test_custom_metric() -> None:
@@ -1292,14 +1298,14 @@ def test_save_warmup() -> None:
12921298
len(BERNOULLI_COLS),
12931299
)
12941300

1295-
assert bern_fit.draws_pd().shape == (200, len(BERNOULLI_COLS))
1301+
assert bern_fit.draws_pd().shape == (200, len(BERNOULLI_COLS) + 3)
12961302
assert bern_fit.draws_pd(inc_warmup=False).shape == (
12971303
200,
1298-
len(BERNOULLI_COLS),
1304+
len(BERNOULLI_COLS) + 3,
12991305
)
13001306
assert bern_fit.draws_pd(inc_warmup=True).shape == (
13011307
600,
1302-
len(BERNOULLI_COLS),
1308+
len(BERNOULLI_COLS) + 3,
13031309
)
13041310

13051311

@@ -1371,7 +1377,7 @@ def test_dont_save_warmup(caplog: pytest.LogCaptureFixture) -> None:
13711377
with caplog.at_level(logging.WARNING):
13721378
assert bern_fit.draws_pd(inc_warmup=True).shape == (
13731379
200,
1374-
len(BERNOULLI_COLS),
1380+
len(BERNOULLI_COLS) + 3,
13751381
)
13761382
check_present(
13771383
caplog,

0 commit comments

Comments
 (0)