File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change 44import io
55import json
66import os
7+ import shutil
78import unittest
89
910import numpy as np
@@ -582,5 +583,24 @@ def test_show_console(self):
582583 self .assertTrue ('Chain [1] method = optimize' in console )
583584
584585
586+ def test_exe_only (self ):
587+ stan = os .path .join (DATAFILES_PATH , 'bernoulli.stan' )
588+ bern_model = CmdStanModel (stan_file = stan )
589+ exe_only = os .path .join (DATAFILES_PATH , 'exe_only' )
590+ shutil .copyfile (bern_model .exe_file , exe_only )
591+ os .chmod (exe_only , 0o755 )
592+
593+ bern2_model = CmdStanModel (exe_file = exe_only )
594+ jdata = os .path .join (DATAFILES_PATH , 'bernoulli.data.json' )
595+ mle = bern_model .optimize (data = jdata )
596+ self .assertEqual (
597+ mle .optimized_params_np [0 ], mle .optimized_params_dict ['lp__' ]
598+ )
599+ self .assertEqual (
600+ mle .optimized_params_np [1 ], mle .optimized_params_dict ['theta' ]
601+ )
602+
603+
604+
585605if __name__ == '__main__' :
586606 unittest .main ()
Original file line number Diff line number Diff line change 33import contextlib
44import io
55import os
6+ import shutil
67import unittest
78from math import fabs
89
@@ -233,5 +234,19 @@ def test_show_console(self):
233234 self .assertTrue ('Chain [1] method = variational' in console )
234235
235236
237+ def test_exe_only (self ):
238+ stan = os .path .join (DATAFILES_PATH , 'bernoulli.stan' )
239+ bern_model = CmdStanModel (stan_file = stan )
240+ exe_only = os .path .join (DATAFILES_PATH , 'exe_only' )
241+ shutil .copyfile (bern_model .exe_file , exe_only )
242+ os .chmod (exe_only , 0o755 )
243+
244+ bern2_model = CmdStanModel (exe_file = exe_only )
245+ jdata = os .path .join (DATAFILES_PATH , 'bernoulli.data.json' )
246+ variational = bern2_model .variational (data = jdata , algorithm = 'meanfield' )
247+ self .assertEqual (variational .variational_sample .shape , (1000 , 4 ))
248+
249+
250+
236251if __name__ == '__main__' :
237252 unittest .main ()
You can’t perform that action at this time.
0 commit comments