@@ -1181,6 +1181,7 @@ def variational(
11811181 :param iter: Maximum number of ADVI iterations.
11821182
11831183 :param grad_samples: Number of MC draws for computing the gradient.
1184+ Default is 10. If problems arise, try doubling the current value.
11841185
11851186 :param elbo_samples: Number of MC draws for estimate of ELBO.
11861187
@@ -1247,14 +1248,10 @@ def variational(
12471248
12481249 # treat failure to converge as failure
12491250 transcript_file = runset .stdout_files [dummy_chain_id ]
1250- valid = True
12511251 pat = re .compile (r'The algorithm may not have converged.' , re .M )
12521252 with open (transcript_file , 'r' ) as transcript :
12531253 contents = transcript .read ()
1254- errors = re .findall (pat , contents )
1255- if len (errors ) > 0 :
1256- valid = False
1257- if not valid :
1254+ if len (re .findall (pat , contents )) > 0 :
12581255 if require_converged :
12591256 raise RuntimeError (
12601257 'The algorithm may not have converged.\n '
@@ -1268,12 +1265,23 @@ def variational(
12681265 'Proceeding because require_converged is set to False' ,
12691266 )
12701267 if not runset ._check_retcodes ():
1271- msg = 'Error during variational inference:\n {}' .format (
1272- runset .get_err_msgs ()
1273- )
1274- msg = '{}Command and output files:\n {}' .format (
1275- msg , runset .__repr__ ()
1276- )
1268+ transcript_file = runset .stdout_files [dummy_chain_id ]
1269+ with open (transcript_file , 'r' ) as transcript :
1270+ contents = transcript .read ()
1271+ pat = re .compile (r'stan::variational::normal_meanfield::calc_grad:' , re .M )
1272+ if len (re .findall (pat , contents )) > 0 :
1273+ if grad_samples is None :
1274+ grad_samples = 10
1275+ msg = (
1276+ 'Variational algorithm gradient calculation failed. '
1277+ 'Double the value of argument "grad_samples", '
1278+ 'current value is {}.' .format (grad_samples )
1279+ )
1280+ else :
1281+ msg = (
1282+ 'Variational algorithm failed.\n '
1283+ 'Console output:\n {}' .format (contents )
1284+ )
12771285 raise RuntimeError (msg )
12781286 # pylint: disable=invalid-name
12791287 vb = CmdStanVB (runset )
0 commit comments