Skip to content

Commit 6365fda

Browse files
committed
Move expected values to parameterize marker.
1 parent 025b9e7 commit 6365fda

1 file changed

Lines changed: 12 additions & 17 deletions

File tree

test/test_log_prob.py

Lines changed: 12 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import os
55
import re
66
from test import check_present
7-
from typing import Optional
7+
from typing import List, Optional
88

99
import numpy as np
1010
import pytest
@@ -21,35 +21,30 @@
2121
BERN_BASENAME = 'bernoulli'
2222

2323

24-
@pytest.mark.parametrize("sig_figs", [15, 3, None])
25-
def test_lp_good(sig_figs: Optional[int]) -> None:
24+
@pytest.mark.parametrize("sig_figs, expected, expected_unadjusted", [
25+
(11, ["-7.0214667713","-1.188472607"], ["-5.5395901199", "-1.4903938392"]),
26+
(3, ["-7.02", "-1.19"], ["-5.54", "-1.49"]),
27+
(None, ["-7.02147", "-1.18847"], ["-5.53959", "-1.49039"])
28+
])
29+
def test_lp_good(sig_figs: Optional[int], expected: List[str],
30+
expected_unadjusted: List[str]) -> None:
2631
model = CmdStanModel(stan_file=BERN_STAN)
2732
params = {"theta": 0.34903938392023830482}
2833
out = model.log_prob(params, data=BERN_DATA, sig_figs=sig_figs)
2934
assert "lp_" in out.columns[0]
3035

3136
# Check the number of digits.
32-
expected_values = {
33-
None: ["-7.02147", "-1.18847"],
34-
3: ["-7.02", "-1.19"],
35-
15: ["-7.02146677130525","-1.18847260704286"],
36-
}[sig_figs]
37-
for actual, expected in zip(out.values[0], expected_values):
38-
assert str(actual) == expected
37+
for actual, value in zip(out.values[0], expected):
38+
assert str(actual) == value
3939

4040
out_unadjusted = model.log_prob(
4141
params, data=BERN_DATA, jacobian=False, sig_figs=sig_figs
4242
)
4343
assert "lp_" in out_unadjusted.columns[0]
4444
assert not np.allclose(out.to_numpy(), out_unadjusted.to_numpy())
4545

46-
expected_values = {
47-
None: ["-5.53959", "-1.49039"],
48-
3: ["-5.54", "-1.49"],
49-
15: ["-5.53959011989073", "-1.49039383920238"],
50-
}[sig_figs]
51-
for actual, expected in zip(out_unadjusted.values[0], expected_values):
52-
assert str(actual) == expected
46+
for actual, value in zip(out_unadjusted.values[0], expected_unadjusted):
47+
assert str(actual) == value
5348

5449

5550
def test_lp_bad(

0 commit comments

Comments
 (0)