Skip to content

Commit 2d54607

Browse files
authored
Merge pull request #692 from tillahoffmann/log-prob-sig-figs
Add `sig_figs` option for `log_prob`.
2 parents b5d7484 + 6365fda commit 2d54607

2 files changed

Lines changed: 25 additions & 3 deletions

File tree

cmdstanpy/model.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1640,6 +1640,7 @@ def log_prob(
16401640
data: Union[Mapping[str, Any], str, os.PathLike, None] = None,
16411641
*,
16421642
jacobian: bool = True,
1643+
sig_figs: Optional[int] = None,
16431644
) -> pd.DataFrame:
16441645
"""
16451646
Calculate the log probability and gradient at the given parameter
@@ -1663,6 +1664,10 @@ def log_prob(
16631664
:param jacobian: Whether or not to enable the Jacobian adjustment
16641665
for constrained parameters. Defaults to ``True``.
16651666
1667+
:param sig_figs: Numerical precision used for output CSV and text files.
1668+
Must be an integer between 1 and 18. If unspecified, the default
1669+
precision for the system file I/O is used; the usual value is 6.
1670+
16661671
:return: A pandas.DataFrame containing columns "lp__" and additional
16671672
columns for the gradient values. These gradients will be for the
16681673
unconstrained parameters of the model.
@@ -1689,6 +1694,8 @@ def log_prob(
16891694

16901695
output = os.path.join(output_dir, "output.csv")
16911696
cmd += ["output", f"file={output}"]
1697+
if sig_figs is not None:
1698+
cmd.append(f"sig_figs={sig_figs}")
16921699

16931700
get_logger().debug("Cmd: %s", str(cmd))
16941701

test/test_log_prob.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import os
55
import re
66
from test import check_present
7+
from typing import List, Optional
78

89
import numpy as np
910
import pytest
@@ -20,17 +21,31 @@
2021
BERN_BASENAME = 'bernoulli'
2122

2223

23-
def test_lp_good() -> None:
24+
@pytest.mark.parametrize("sig_figs, expected, expected_unadjusted", [
25+
(11, ["-7.0214667713","-1.188472607"], ["-5.5395901199", "-1.4903938392"]),
26+
(3, ["-7.02", "-1.19"], ["-5.54", "-1.49"]),
27+
(None, ["-7.02147", "-1.18847"], ["-5.53959", "-1.49039"])
28+
])
29+
def test_lp_good(sig_figs: Optional[int], expected: List[str],
30+
expected_unadjusted: List[str]) -> None:
2431
model = CmdStanModel(stan_file=BERN_STAN)
25-
out = model.log_prob({"theta": 0.1}, data=BERN_DATA)
32+
params = {"theta": 0.34903938392023830482}
33+
out = model.log_prob(params, data=BERN_DATA, sig_figs=sig_figs)
2634
assert "lp_" in out.columns[0]
2735

36+
# Check the number of digits.
37+
for actual, value in zip(out.values[0], expected):
38+
assert str(actual) == value
39+
2840
out_unadjusted = model.log_prob(
29-
{"theta": 0.1}, data=BERN_DATA, jacobian=False
41+
params, data=BERN_DATA, jacobian=False, sig_figs=sig_figs
3042
)
3143
assert "lp_" in out_unadjusted.columns[0]
3244
assert not np.allclose(out.to_numpy(), out_unadjusted.to_numpy())
3345

46+
for actual, value in zip(out_unadjusted.values[0], expected_unadjusted):
47+
assert str(actual) == value
48+
3449

3550
def test_lp_bad(
3651
caplog: pytest.LogCaptureFixture,

0 commit comments

Comments (0)