22
33import os
44
5+ import numpy as np
6+ import pytest
7+
58import cmdstanpy
69
710HERE = os .path .dirname (os .path .abspath (__file__ ))
@@ -14,6 +17,7 @@ def test_laplace_from_csv():
1417 fit = model .laplace_sample (
1518 data = {},
1619 mode = os .path .join (DATAFILES_PATH , 'optimize' , 'rosenbrock_mle.csv' ),
20+ jacobian = False ,
1721 )
1822 assert 'x' in fit .stan_variables ()
1923 assert 'y' in fit .stan_variables ()
@@ -23,8 +27,38 @@ def test_laplace_from_csv():
2327def test_laplace_runs_opt ():
2428 model_file = os .path .join (DATAFILES_PATH , 'optimize' , 'rosenbrock.stan' )
2529 model = cmdstanpy .CmdStanModel (stan_file = model_file )
26- fit1 = model .laplace_sample (data = {}, seed = 1234 )
30+ fit1 = model .laplace_sample (data = {}, seed = 1234 , opt_args = { 'iter' : 1003 } )
2731 assert isinstance (fit1 .mode , cmdstanpy .CmdStanMLE )
2832
2933 assert fit1 .mode .metadata .cmdstan_config ['seed' ] == 1234
3034 assert fit1 ._metadata .cmdstan_config ['seed' ] == 1234
35+ assert fit1 .mode .metadata .cmdstan_config ['iter' ] == 1003
36+
37+
38+ def test_laplace_bad_jacobian_mismatch ():
39+ model_file = os .path .join (DATAFILES_PATH , 'optimize' , 'rosenbrock.stan' )
40+ model = cmdstanpy .CmdStanModel (stan_file = model_file )
41+ with pytest .raises (ValueError ):
42+ model .laplace_sample (
43+ data = {},
44+ mode = os .path .join (DATAFILES_PATH , 'optimize' , 'rosenbrock_mle.csv' ),
45+ jacobian = True ,
46+ )
47+
48+
49+ def test_laplace_outputs ():
50+ model_file = os .path .join (DATAFILES_PATH , 'optimize' , 'rosenbrock.stan' )
51+ model = cmdstanpy .CmdStanModel (stan_file = model_file )
52+ fit = model .laplace_sample (data = {}, seed = 1234 , draws = 123 )
53+
54+ variables = fit .stan_variables ()
55+ assert 'x' in variables
56+ assert 'y' in variables
57+ assert variables ['x' ].shape == (123 ,)
58+
59+ np .testing .assert_array_equal (variables ['x' ], fit .x )
60+
61+ fit_pd = fit .draws_pd ()
62+ assert 'x' in fit_pd .columns
63+ assert 'y' in fit_pd .columns
64+ assert fit_pd ['x' ].shape == (123 ,)
0 commit comments