Skip to content

Commit 26f38ca

Browse files
committed
more error checking on ADVI
1 parent 91d6183 commit 26f38ca

1 file changed

Lines changed: 19 additions & 11 deletions

File tree

cmdstanpy/model.py

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1181,6 +1181,7 @@ def variational(
11811181
:param iter: Maximum number of ADVI iterations.
11821182
11831183
:param grad_samples: Number of MC draws for computing the gradient.
1184+
Default is 10. If problems arise, try doubling current value.
11841185
11851186
:param elbo_samples: Number of MC draws for estimate of ELBO.
11861187
@@ -1247,14 +1248,10 @@ def variational(
12471248

12481249
# treat failure to converge as failure
12491250
transcript_file = runset.stdout_files[dummy_chain_id]
1250-
valid = True
12511251
pat = re.compile(r'The algorithm may not have converged.', re.M)
12521252
with open(transcript_file, 'r') as transcript:
12531253
contents = transcript.read()
1254-
errors = re.findall(pat, contents)
1255-
if len(errors) > 0:
1256-
valid = False
1257-
if not valid:
1254+
if len(re.findall(pat, contents)) > 0:
12581255
if require_converged:
12591256
raise RuntimeError(
12601257
'The algorithm may not have converged.\n'
@@ -1268,12 +1265,23 @@ def variational(
12681265
'Proceeding because require_converged is set to False',
12691266
)
12701267
if not runset._check_retcodes():
1271-
msg = 'Error during variational inference:\n{}'.format(
1272-
runset.get_err_msgs()
1273-
)
1274-
msg = '{}Command and output files:\n{}'.format(
1275-
msg, runset.__repr__()
1276-
)
1268+
transcript_file = runset.stdout_files[dummy_chain_id]
1269+
with open(transcript_file, 'r') as transcript:
1270+
contents = transcript.read()
1271+
pat = re.compile(r'stan::variational::normal_meanfield::calc_grad:', re.M)
1272+
if len(re.findall(pat, contents)) > 0:
1273+
if grad_samples is None:
1274+
grad_samples = 10
1275+
msg = (
1276+
'Variational algorithm gradient calculation failed. '
1277+
'Double the value of argument "grad_samples", '
1278+
'current value is {}.'.format(grad_samples)
1279+
)
1280+
else:
1281+
msg = (
1282+
'Variational algorithm failed.\n '
1283+
'Console output:\n{}'.format(contents)
1284+
)
12771285
raise RuntimeError(msg)
12781286
# pylint: disable=invalid-name
12791287
vb = CmdStanVB(runset)

0 commit comments

Comments
 (0)