|
4 | 4 | import os |
5 | 5 | import re |
6 | 6 | from test import check_present |
7 | | -from typing import Optional |
| 7 | +from typing import List, Optional |
8 | 8 |
|
9 | 9 | import numpy as np |
10 | 10 | import pytest |
|
21 | 21 | BERN_BASENAME = 'bernoulli' |
22 | 22 |
|
23 | 23 |
|
24 | | -@pytest.mark.parametrize("sig_figs", [15, 3, None]) |
25 | | -def test_lp_good(sig_figs: Optional[int]) -> None: |
| 24 | +@pytest.mark.parametrize("sig_figs, expected, expected_unadjusted", [ |
| 25 | + (11, ["-7.0214667713","-1.188472607"], ["-5.5395901199", "-1.4903938392"]), |
| 26 | + (3, ["-7.02", "-1.19"], ["-5.54", "-1.49"]), |
| 27 | + (None, ["-7.02147", "-1.18847"], ["-5.53959", "-1.49039"]) |
| 28 | +]) |
| 29 | +def test_lp_good(sig_figs: Optional[int], expected: List[str], |
| 30 | + expected_unadjusted: List[str]) -> None: |
26 | 31 | model = CmdStanModel(stan_file=BERN_STAN) |
27 | 32 | params = {"theta": 0.34903938392023830482} |
28 | 33 | out = model.log_prob(params, data=BERN_DATA, sig_figs=sig_figs) |
29 | 34 | assert "lp_" in out.columns[0] |
30 | 35 |
|
31 | 36 | # Check the number of digits. |
32 | | - expected_values = { |
33 | | - None: ["-7.02147", "-1.18847"], |
34 | | - 3: ["-7.02", "-1.19"], |
35 | | - 15: ["-7.02146677130525","-1.18847260704286"], |
36 | | - }[sig_figs] |
37 | | - for actual, expected in zip(out.values[0], expected_values): |
38 | | - assert str(actual) == expected |
| 37 | + for actual, value in zip(out.values[0], expected): |
| 38 | + assert str(actual) == value |
39 | 39 |
|
40 | 40 | out_unadjusted = model.log_prob( |
41 | 41 | params, data=BERN_DATA, jacobian=False, sig_figs=sig_figs |
42 | 42 | ) |
43 | 43 | assert "lp_" in out_unadjusted.columns[0] |
44 | 44 | assert not np.allclose(out.to_numpy(), out_unadjusted.to_numpy()) |
45 | 45 |
|
46 | | - expected_values = { |
47 | | - None: ["-5.53959", "-1.49039"], |
48 | | - 3: ["-5.54", "-1.49"], |
49 | | - 15: ["-5.53959011989073", "-1.49039383920238"], |
50 | | - }[sig_figs] |
51 | | - for actual, expected in zip(out_unadjusted.values[0], expected_values): |
52 | | - assert str(actual) == expected |
| 46 | + for actual, value in zip(out_unadjusted.values[0], expected_unadjusted): |
| 47 | + assert str(actual) == value |
53 | 48 |
|
54 | 49 |
|
55 | 50 | def test_lp_bad( |
|
0 commit comments