Skip to content

Commit 6e4a9ef

Browse files
committed
more unit tests
1 parent 6677875 commit 6e4a9ef

2 files changed

Lines changed: 35 additions & 0 deletions

File tree

test/test_optimize.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import io
55
import json
66
import os
7+
import shutil
78
import unittest
89

910
import numpy as np
@@ -582,5 +583,24 @@ def test_show_console(self):
582583
self.assertTrue('Chain [1] method = optimize' in console)
583584

584585

586+
def test_exe_only(self):
587+
stan = os.path.join(DATAFILES_PATH, 'bernoulli.stan')
588+
bern_model = CmdStanModel(stan_file=stan)
589+
exe_only = os.path.join(DATAFILES_PATH, 'exe_only')
590+
shutil.copyfile(bern_model.exe_file, exe_only)
591+
os.chmod(exe_only, 0o755)
592+
593+
bern2_model = CmdStanModel(exe_file=exe_only)
594+
jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
595+
mle = bern_model.optimize(data=jdata)
596+
self.assertEqual(
597+
mle.optimized_params_np[0], mle.optimized_params_dict['lp__']
598+
)
599+
self.assertEqual(
600+
mle.optimized_params_np[1], mle.optimized_params_dict['theta']
601+
)
602+
603+
604+
585605
if __name__ == '__main__':
586606
unittest.main()

test/test_variational.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import contextlib
44
import io
55
import os
6+
import shutil
67
import unittest
78
from math import fabs
89

@@ -233,5 +234,19 @@ def test_show_console(self):
233234
self.assertTrue('Chain [1] method = variational' in console)
234235

235236

237+
def test_exe_only(self):
238+
stan = os.path.join(DATAFILES_PATH, 'bernoulli.stan')
239+
bern_model = CmdStanModel(stan_file=stan)
240+
exe_only = os.path.join(DATAFILES_PATH, 'exe_only')
241+
shutil.copyfile(bern_model.exe_file, exe_only)
242+
os.chmod(exe_only, 0o755)
243+
244+
bern2_model = CmdStanModel(exe_file=exe_only)
245+
jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
246+
variational = bern2_model.variational(data=jdata, algorithm='meanfield')
247+
self.assertEqual(variational.variational_sample.shape, (1000, 4))
248+
249+
250+
236251
if __name__ == '__main__':
237252
unittest.main()

0 commit comments

Comments
 (0)