Skip to content

Commit 2a382aa

Browse files
committed
Basic tests
1 parent 1075b31 commit 2a382aa

3 files changed

Lines changed: 67 additions & 7 deletions

File tree

cmdstanpy/model.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1709,7 +1709,6 @@ def laplace_sample(
17091709
timeout: Optional[float] = None,
17101710
opt_args: Optional[Dict[str, Any]] = None,
17111711
) -> CmdStanLaplace:
1712-
17131712
if cmdstan_version_before(2, 32, self.exe_info()):
17141713
raise ValueError(
17151714
"Method 'laplace_sample' not available for CmdStan versions "
@@ -1741,14 +1740,23 @@ def laplace_sample(
17411740
) from e
17421741
elif not isinstance(mode, CmdStanMLE):
17431742
cmdstan_mode = from_csv(mode) # type: ignore # we check below
1744-
if cmdstan_mode.runset.method != Method.OPTIMIZE:
1745-
raise ValueError(
1746-
"Mode must be a CmdStanMLE or a path to an optimize CSV"
1747-
)
17481743
else:
17491744
cmdstan_mode = mode
17501745

1751-
# TODO: jacobian warnings on mismatch
1746+
if cmdstan_mode.runset.method != Method.OPTIMIZE:
1747+
raise ValueError(
1748+
"Mode must be a CmdStanMLE or a path to an optimize CSV"
1749+
)
1750+
1751+
mode_jacobian = (
1752+
cmdstan_mode.runset._args.method_args.jacobian # type: ignore
1753+
)
1754+
if mode_jacobian != jacobian:
1755+
raise ValueError(
1756+
"Jacobian argument to optimize and laplace must match!\n"
1757+
f"Laplace was run with jacobian={jacobian},\n"
1758+
f"but optimize was run with jacobian={mode_jacobian}"
1759+
)
17521760

17531761
laplace_args = LaplaceArgs(
17541762
cmdstan_mode.runset.csv_files[0], draws, jacobian

cmdstanpy/stanfit/laplace.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,24 @@ def __repr__(self) -> str:
251251
)
252252
return rep
253253

254+
def __getattr__(self, attr: str) -> np.ndarray:
255+
"""Synonymous with ``fit.stan_variable(attr)"""
256+
if attr.startswith("_"):
257+
raise AttributeError(f"Unknown variable name {attr}")
258+
try:
259+
return self.stan_variable(attr)
260+
except ValueError as e:
261+
# pylint: disable=raise-missing-from
262+
raise AttributeError(*e.args)
263+
264+
def __getstate__(self) -> dict:
265+
# This function returns the mapping of objects to serialize with pickle.
266+
# See https://docs.python.org/3/library/pickle.html#object.__getstate__
267+
# for details. We call _assemble_draws to ensure posterior samples have
268+
# been loaded prior to serialization.
269+
self._assemble_draws()
270+
return self.__dict__
271+
254272
@property
255273
def column_names(self) -> Tuple[str, ...]:
256274
"""

test/test_laplace.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22

33
import os
44

5+
import numpy as np
6+
import pytest
7+
58
import cmdstanpy
69

710
HERE = os.path.dirname(os.path.abspath(__file__))
@@ -14,6 +17,7 @@ def test_laplace_from_csv():
1417
fit = model.laplace_sample(
1518
data={},
1619
mode=os.path.join(DATAFILES_PATH, 'optimize', 'rosenbrock_mle.csv'),
20+
jacobian=False,
1721
)
1822
assert 'x' in fit.stan_variables()
1923
assert 'y' in fit.stan_variables()
@@ -23,8 +27,38 @@ def test_laplace_from_csv():
2327
def test_laplace_runs_opt():
2428
model_file = os.path.join(DATAFILES_PATH, 'optimize', 'rosenbrock.stan')
2529
model = cmdstanpy.CmdStanModel(stan_file=model_file)
26-
fit1 = model.laplace_sample(data={}, seed=1234)
30+
fit1 = model.laplace_sample(data={}, seed=1234, opt_args={'iter': 1003})
2731
assert isinstance(fit1.mode, cmdstanpy.CmdStanMLE)
2832

2933
assert fit1.mode.metadata.cmdstan_config['seed'] == 1234
3034
assert fit1._metadata.cmdstan_config['seed'] == 1234
35+
assert fit1.mode.metadata.cmdstan_config['iter'] == 1003
36+
37+
38+
def test_laplace_bad_jacobian_mismatch():
39+
model_file = os.path.join(DATAFILES_PATH, 'optimize', 'rosenbrock.stan')
40+
model = cmdstanpy.CmdStanModel(stan_file=model_file)
41+
with pytest.raises(ValueError):
42+
model.laplace_sample(
43+
data={},
44+
mode=os.path.join(DATAFILES_PATH, 'optimize', 'rosenbrock_mle.csv'),
45+
jacobian=True,
46+
)
47+
48+
49+
def test_laplace_outputs():
50+
model_file = os.path.join(DATAFILES_PATH, 'optimize', 'rosenbrock.stan')
51+
model = cmdstanpy.CmdStanModel(stan_file=model_file)
52+
fit = model.laplace_sample(data={}, seed=1234, draws=123)
53+
54+
variables = fit.stan_variables()
55+
assert 'x' in variables
56+
assert 'y' in variables
57+
assert variables['x'].shape == (123,)
58+
59+
np.testing.assert_array_equal(variables['x'], fit.x)
60+
61+
fit_pd = fit.draws_pd()
62+
assert 'x' in fit_pd.columns
63+
assert 'y' in fit_pd.columns
64+
assert fit_pd['x'].shape == (123,)

0 commit comments

Comments
 (0)