11"""CmdStanModel"""
22
33import io
4- import logging
54import os
65import platform
76import re
@@ -91,7 +90,6 @@ def __init__(
9190 stanc_options : Optional [Dict [str , Any ]] = None ,
9291 cpp_options : Optional [Dict [str , Any ]] = None ,
9392 user_header : Optional [str ] = None ,
94- logger : Optional [logging .Logger ] = None ,
9593 ) -> None :
9694 """
9795 Initialize object given constructor args.
@@ -113,11 +111,6 @@ def __init__(
113111 cpp_options = cpp_options ,
114112 user_header = user_header ,
115113 )
116- if logger is not None :
117- get_logger ().warning (
118- "Parameter 'logger' is deprecated."
119- " Control logging behavior via logging.getLogger('cmdstanpy')"
120- )
121114
122115 if model_name is not None :
123116 if not model_name .strip ():
@@ -1181,6 +1174,7 @@ def variational(
11811174 :param iter: Maximum number of ADVI iterations.
11821175
11831176 :param grad_samples: Number of MC draws for computing the gradient.
1177+ Default is 10. If problems arise, try doubling current value.
11841178
11851179 :param elbo_samples: Number of MC draws for estimate of ELBO.
11861180
@@ -1247,14 +1241,10 @@ def variational(
12471241
12481242 # treat failure to converge as failure
12491243 transcript_file = runset .stdout_files [dummy_chain_id ]
1250- valid = True
12511244 pat = re .compile (r'The algorithm may not have converged.' , re .M )
12521245 with open (transcript_file , 'r' ) as transcript :
12531246 contents = transcript .read ()
1254- errors = re .findall (pat , contents )
1255- if len (errors ) > 0 :
1256- valid = False
1257- if not valid :
1247+ if len (re .findall (pat , contents )) > 0 :
12581248 if require_converged :
12591249 raise RuntimeError (
12601250 'The algorithm may not have converged.\n '
@@ -1268,12 +1258,25 @@ def variational(
12681258 'Proceeding because require_converged is set to False' ,
12691259 )
12701260 if not runset ._check_retcodes ():
1271- msg = 'Error during variational inference: \n {}' . format (
1272- runset . get_err_msgs ()
1273- )
1274- msg = '{}Command and output files: \n {}' . format (
1275- msg , runset . __repr__ ()
1261+ transcript_file = runset . stdout_files [ dummy_chain_id ]
1262+ with open ( transcript_file , 'r' ) as transcript :
1263+ contents = transcript . read ( )
1264+ pat = re . compile (
1265+ r'stan::variational::normal_meanfield::calc_grad:' , re . M
12761266 )
1267+ if len (re .findall (pat , contents )) > 0 :
1268+ if grad_samples is None :
1269+ grad_samples = 10
1270+ msg = (
1271+ 'Variational algorithm gradient calculation failed. '
1272+ 'Double the value of argument "grad_samples", '
1273+ 'current value is {}.' .format (grad_samples )
1274+ )
1275+ else :
1276+ msg = (
1277+ 'Variational algorithm failed.\n '
1278+ 'Console output:\n {}' .format (contents )
1279+ )
12771280 raise RuntimeError (msg )
12781281 # pylint: disable=invalid-name
12791282 vb = CmdStanVB (runset )
0 commit comments