Skip to content

Commit ae910b5

Browse files
committed
Add tests
1 parent b513445 commit ae910b5

3 files changed

Lines changed: 60 additions & 4 deletions

File tree

cmdstanpy/stanfit/gq.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -385,8 +385,8 @@ def draws_pd(
385385
else:
386386
return previous_draws_pd
387387
elif inc_sample and vars is None:
388-
cols_1 = self.previous_fit.column_names
389-
cols_2 = self.column_names
388+
cols_1 = list(previous_draws_pd.columns)
389+
cols_2 = list(draws_pd.columns)
390390
dups = [
391391
item
392392
for item, count in Counter(cols_1 + cols_2).items()
@@ -676,7 +676,6 @@ def _draws_start(self, inc_warmup: bool) -> Tuple[int, int]:
676676
elif isinstance(p_fit, CmdStanMLE):
677677
num_draws = 1
678678
if p_fit._save_iterations:
679-
680679
opt_iters = len(p_fit.optimized_iterations_np) # type: ignore
681680
if inc_warmup:
682681
num_draws = opt_iters
@@ -725,7 +724,6 @@ def _previous_draws_pd(
725724
return p_fit.draws_pd(vars or None, inc_warmup=inc_warmup)
726725

727726
elif isinstance(p_fit, CmdStanMLE):
728-
729727
if inc_warmup and p_fit._save_iterations:
730728
return p_fit.optimized_iterations_pd[sel] # type: ignore
731729
else:

test/test_generate_quantities.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,45 @@ def test_from_csv_files(caplog: pytest.LogCaptureFixture) -> None:
8989
)
9090

9191

92+
def test_pd_xr_agreement():
93+
# fitted_params sample - list of filenames
94+
goodfiles_path = os.path.join(DATAFILES_PATH, 'runset-good', 'bern')
95+
csv_files = []
96+
for i in range(4):
97+
csv_files.append('{}-{}.csv'.format(goodfiles_path, i + 1))
98+
99+
# gq_model
100+
stan = os.path.join(DATAFILES_PATH, 'bernoulli_ppc.stan')
101+
model = CmdStanModel(stan_file=stan)
102+
jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
103+
104+
bern_gqs = model.generate_quantities(data=jdata, previous_fit=csv_files)
105+
106+
draws_pd = bern_gqs.draws_pd(inc_sample=True)
107+
draws_xr = bern_gqs.draws_xr(inc_sample=True)
108+
109+
# check that the indexing is the same between the two
110+
np.testing.assert_equal(
111+
draws_pd[draws_pd['chain__'] == 2]['y_rep[1]'],
112+
draws_xr.y_rep.sel(chain=2).isel(y_rep_dim_0=0).values,
113+
)
114+
# "draw" is 0-indexed in xarray, equiv. "iter__" is 1-indexed in pandas
115+
np.testing.assert_equal(
116+
draws_pd[draws_pd['iter__'] == 100]['y_rep[1]'],
117+
draws_xr.y_rep.sel(draw=99).isel(y_rep_dim_0=0).values,
118+
)
119+
120+
# check for included sample as well
121+
np.testing.assert_equal(
122+
draws_pd[draws_pd['chain__'] == 2]['theta'],
123+
draws_xr.theta.sel(chain=2).values,
124+
)
125+
np.testing.assert_equal(
126+
draws_pd[draws_pd['iter__'] == 100]['theta'],
127+
draws_xr.theta.sel(draw=99).values,
128+
)
129+
130+
92131
def test_from_csv_files_bad() -> None:
93132
# gq model
94133
stan = os.path.join(DATAFILES_PATH, 'bernoulli_ppc.stan')

test/test_sample.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -873,6 +873,25 @@ def test_instantiate_from_csvfiles() -> None:
873873
)
874874

875875

876+
def test_pd_xr_agreement():
877+
csvfiles_path = os.path.join(DATAFILES_PATH, 'runset-good', '*.csv')
878+
bern_fit = from_csv(path=csvfiles_path)
879+
880+
draws_pd = bern_fit.draws_pd()
881+
draws_xr = bern_fit.draws_xr()
882+
883+
# check that the indexing is the same between the two
884+
np.testing.assert_equal(
885+
draws_pd[draws_pd['chain__'] == 2]['theta'],
886+
draws_xr.theta.sel(chain=2).values,
887+
)
888+
# "draw" is 0-indexed in xarray, equiv. "iter__" is 1-indexed in pandas
889+
np.testing.assert_equal(
890+
draws_pd[draws_pd['iter__'] == 100]['theta'],
891+
draws_xr.theta.sel(draw=99).values,
892+
)
893+
894+
876895
def test_instantiate_from_csvfiles_fail(
877896
caplog: pytest.LogCaptureFixture,
878897
) -> None:

0 commit comments

Comments
 (0)