|
4 | 4 | import os |
5 | 5 | import re |
6 | 6 | from test import check_present |
| 7 | +from typing import List, Optional |
7 | 8 |
|
8 | 9 | import numpy as np |
9 | 10 | import pytest |
|
20 | 21 | BERN_BASENAME = 'bernoulli' |
21 | 22 |
|
22 | 23 |
|
23 | | -def test_lp_good() -> None: |
| 24 | +@pytest.mark.parametrize("sig_figs, expected, expected_unadjusted", [ |
| 25 | + (11, ["-7.0214667713","-1.188472607"], ["-5.5395901199", "-1.4903938392"]), |
| 26 | + (3, ["-7.02", "-1.19"], ["-5.54", "-1.49"]), |
| 27 | + (None, ["-7.02147", "-1.18847"], ["-5.53959", "-1.49039"]) |
| 28 | +]) |
| 29 | +def test_lp_good(sig_figs: Optional[int], expected: List[str], |
| 30 | + expected_unadjusted: List[str]) -> None: |
24 | 31 | model = CmdStanModel(stan_file=BERN_STAN) |
25 | | - out = model.log_prob({"theta": 0.1}, data=BERN_DATA) |
| 32 | + params = {"theta": 0.34903938392023830482} |
| 33 | + out = model.log_prob(params, data=BERN_DATA, sig_figs=sig_figs) |
26 | 34 | assert "lp_" in out.columns[0] |
27 | 35 |
|
| 36 | + # Check the number of digits. |
| 37 | + for actual, value in zip(out.values[0], expected): |
| 38 | + assert str(actual) == value |
| 39 | + |
28 | 40 | out_unadjusted = model.log_prob( |
29 | | - {"theta": 0.1}, data=BERN_DATA, jacobian=False |
| 41 | + params, data=BERN_DATA, jacobian=False, sig_figs=sig_figs |
30 | 42 | ) |
31 | 43 | assert "lp_" in out_unadjusted.columns[0] |
32 | 44 | assert not np.allclose(out.to_numpy(), out_unadjusted.to_numpy()) |
33 | 45 |
|
| 46 | + for actual, value in zip(out_unadjusted.values[0], expected_unadjusted): |
| 47 | + assert str(actual) == value |
| 48 | + |
34 | 49 |
|
35 | 50 | def test_lp_bad( |
36 | 51 | caplog: pytest.LogCaptureFixture, |
|
0 commit comments