Skip to content

Commit 977ba76

Browse files
committed
Add sig_figs option for log_prob.
1 parent b5d7484 commit 977ba76

2 files changed

Lines changed: 31 additions & 3 deletions

File tree

cmdstanpy/model.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1640,6 +1640,7 @@ def log_prob(
16401640
data: Union[Mapping[str, Any], str, os.PathLike, None] = None,
16411641
*,
16421642
jacobian: bool = True,
1643+
sig_figs: Optional[int] = None,
16431644
) -> pd.DataFrame:
16441645
"""
16451646
Calculate the log probability and gradient at the given parameter
@@ -1663,6 +1664,11 @@ def log_prob(
16631664
:param jacobian: Whether or not to enable the Jacobian adjustment
16641665
for constrained parameters. Defaults to ``True``.
16651666
1667+
:param sig_figs: Numerical precision used for output CSV and text files.
1668+
Must be an integer between 1 and 18. If unspecified, the default
1669+
precision for the system file I/O is used; the usual value is 6.
1670+
Introduced in CmdStan-2.25.
1671+
16661672
:return: A pandas.DataFrame containing columns "lp__" and additional
16671673
columns for the gradient values. These gradients will be for the
16681674
unconstrained parameters of the model.
@@ -1689,6 +1695,8 @@ def log_prob(
16891695

16901696
output = os.path.join(output_dir, "output.csv")
16911697
cmd += ["output", f"file={output}"]
1698+
if sig_figs is not None:
1699+
cmd.append(f"sig_figs={sig_figs}")
16921700

16931701
get_logger().debug("Cmd: %s", str(cmd))
16941702

test/test_log_prob.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import os
55
import re
66
from test import check_present
7+
from typing import Optional
78

89
import numpy as np
910
import pytest
@@ -20,17 +21,36 @@
2021
BERN_BASENAME = 'bernoulli'
2122

2223

23-
def test_lp_good() -> None:
24+
@pytest.mark.parametrize("sig_figs", [15, 3, None])
25+
def test_lp_good(sig_figs: Optional[int]) -> None:
2426
model = CmdStanModel(stan_file=BERN_STAN)
25-
out = model.log_prob({"theta": 0.1}, data=BERN_DATA)
27+
params = {"theta": 0.34903938392023830482}
28+
out = model.log_prob(params, data=BERN_DATA, sig_figs=sig_figs)
2629
assert "lp_" in out.columns[0]
2730

31+
# Check the number of digits.
32+
expected_values = {
33+
None: ["-7.02147", "-1.18847"],
34+
3: ["-7.02", "-1.19"],
35+
15: ["-7.02146677130525","-1.18847260704286"],
36+
}[sig_figs]
37+
for actual, expected in zip(out.values[0], expected_values):
38+
assert str(actual) == expected
39+
2840
out_unadjusted = model.log_prob(
29-
{"theta": 0.1}, data=BERN_DATA, jacobian=False
41+
params, data=BERN_DATA, jacobian=False, sig_figs=sig_figs
3042
)
3143
assert "lp_" in out_unadjusted.columns[0]
3244
assert not np.allclose(out.to_numpy(), out_unadjusted.to_numpy())
3345

46+
expected_values = {
47+
None: ["-5.53959", "-1.49039"],
48+
3: ["-5.54", "-1.49"],
49+
15: ["-5.53959011989073", "-1.49039383920238"],
50+
}[sig_figs]
51+
for actual, expected in zip(out_unadjusted.values[0], expected_values):
52+
assert str(actual) == expected
53+
3454

3555
def test_lp_bad(
3656
caplog: pytest.LogCaptureFixture,

0 commit comments

Comments
 (0)