Skip to content

Commit feed16c

Browse files
committed
Fix tests
1 parent ae910b5 commit feed16c

1 file changed

Lines changed: 3 additions & 1 deletion

File tree

test/test_generate_quantities.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ def test_from_csv_files(caplog: pytest.LogCaptureFixture) -> None:
8282
bern_gqs.draws_pd(inc_sample=True).shape[1]
8383
== bern_gqs.previous_fit.draws_pd().shape[1]
8484
+ bern_gqs.draws_pd().shape[1]
85+
- 3 # chain, iter, draw duplicates
8586
)
8687

8788
assert list(bern_gqs.draws_pd(vars=['y_rep']).columns) == (
@@ -199,11 +200,12 @@ def test_from_previous_fit_draws() -> None:
199200
bern_gqs.draws_pd(inc_sample=True).shape[1]
200201
== bern_gqs.previous_fit.draws_pd().shape[1]
201202
+ bern_gqs.draws_pd().shape[1]
203+
- 3 # duplicates of chain, iter, and draw
202204
)
203205
row1_sample_pd = bern_fit.draws_pd().iloc[0]
204206
row1_gqs_pd = bern_gqs.draws_pd().iloc[0]
205207
np.testing.assert_array_equal(
206-
pd.concat((row1_sample_pd, row1_gqs_pd), axis=0).values,
208+
pd.concat((row1_sample_pd, row1_gqs_pd), axis=0).values[3:],
207209
bern_gqs.draws_pd(inc_sample=True).iloc[0].values,
208210
)
209211
# draws_xr

0 commit comments

Comments
 (0)