@@ -303,23 +303,29 @@ def get_err_msgs(self) -> str:
303303 self ._chain_ids [i ], fd .read ()
304304 )
305305 )
306- # pre 2.27, all msgs sent to stdout, including errors
307- if (
308- not cmdstan_version_at (2 , 27 )
309- and os .path .exists (self ._stdout_files [i ])
310- and os .stat (self ._stdout_files [i ]).st_size > 0
311- ):
312- with open (self ._stdout_files [i ], 'r' ) as fd :
313- contents = fd .read ()
314- # pattern matches initial "Exception" or "Error" msg
315- pat = re .compile (r'^E[rx].*$' , re .M )
316- errors = re .findall (pat , contents )
317- if len (errors ) > 0 :
318- msgs .append (
319- 'chain_id {}:\n \t {}\n ' .format (
320- self ._chain_ids [i ], '\n \t ' .join (errors )
321- )
322- )
306+ # pre 2.27, all sampler msgs go to stdout, including errors
307+ if self ._args .method == Method .SAMPLE :
308+ if (
309+ not cmdstan_version_at (2 , 27 )
310+ and os .path .exists (self ._stdout_files [i ])
311+ and os .stat (self ._stdout_files [i ]).st_size > 0
312+ ):
313+ with open (self ._stdout_files [i ], 'r' ) as fd :
314+ contents = fd .read ()
315+ # pattern matches initial "Exception" or "Error" msg
316+ pat = re .compile (r'^E[rx].*$' , re .M )
317+ errors = re .findall (pat , contents )
318+ if len (errors ) > 0 :
319+ msgs .append (
320+ 'chain_id {}:\n \t {}\n ' .format (
321+ self ._chain_ids [i ], '\n \t ' .join (errors )
322+ )
323+ )
324+ elif self ._args .method == Method .OPTIMIZE :
325+ msgs .append ('console log output:\n ' )
326+ with open (self ._stdout_files [0 ], 'r' ) as fd :
327+ msgs .append (fd .read ())
328+
323329 return '\n ' .join (msgs )
324330
325331 def save_csvfiles (self , dir : Optional [str ] = None ) -> None :
@@ -1240,7 +1246,9 @@ def __repr__(self) -> str:
12401246 '\n \t ' .join (self .runset .csv_files ),
12411247 '\n \t ' .join (self .runset .stdout_files ),
12421248 )
1243- # TODO - profiling files
1249+ if not self .runset ._check_retcodes ():
1250+ repr = '{}\n Warning: invalid estimate, ' .format (repr )
1251+ repr = '{} optimization failed to converge.' .format (repr )
12441252 return repr
12451253
12461254 def _set_mle_attrs (self , sample_csv_0 : str ) -> None :
@@ -1269,20 +1277,33 @@ def metadata(self) -> InferenceMetadata:
12691277 @property
12701278 def optimized_params_np (self ) -> np .ndarray :
12711279 """Returns optimized params as numpy array."""
1280+ if not self .runset ._check_retcodes ():
1281+ get_logger ().warning (
1282+ 'invalid estimate, optimization failed to converge'
1283+ )
12721284 return np .asarray (self ._mle )
12731285
12741286 @property
12751287 def optimized_params_pd (self ) -> pd .DataFrame :
12761288 """Returns optimized params as pandas DataFrame."""
1289+ if not self .runset ._check_retcodes ():
1290+ get_logger ().warning (
1291+ 'invalid estimate, optimization failed to converge'
1292+ )
12771293 return pd .DataFrame ([self ._mle ], columns = self .column_names )
12781294
12791295 @property
12801296 def optimized_params_dict (self ) -> Dict [str , float ]:
12811297 """Returns optimized params as Dict."""
1298+ if not self .runset ._check_retcodes ():
1299+ get_logger ().warning (
1300+ 'invalid estimate, optimization failed to converge'
1301+ )
12821302 return OrderedDict (zip (self .column_names , self ._mle ))
12831303
12841304 def stan_variable (
1285- self , var : Optional [str ] = None , * , name : Optional [str ] = None
1305+ self , var : Optional [str ] = None ,
1306+ check_convergence :bool = True , * , name : Optional [str ] = None
12861307 ) -> np .ndarray :
12871308 """
12881309 Return a numpy.ndarray which contains the estimates for the
@@ -1291,6 +1312,11 @@ def stan_variable(
12911312
12921313 :param var: variable name
12931314
1315+ :param check_convergence: Checks for failure to converge and
1316+ prints warning.failed to converge. ``False`` will supress
1317+ check and warning, default is ``True``.
1318+
1319+
12941320 See Also
12951321 --------
12961322 CmdStanMLE.stan_variables
@@ -1311,6 +1337,11 @@ def stan_variable(
13111337 raise ValueError ('no variable name specified.' )
13121338 if var not in self ._metadata .stan_vars_dims :
13131339 raise ValueError ('unknown variable name: {}' .format (var ))
1340+ if check_convergence and not self .runset ._check_retcodes ():
1341+ get_logger ().warning (
1342+ 'invalid estimate, optimization failed to converge'
1343+ )
1344+
13141345 col_idxs = list (self ._metadata .stan_vars_cols [var ])
13151346 vals = list (self ._mle )
13161347 xs = [vals [x ] for x in col_idxs ]
@@ -1319,21 +1350,29 @@ def stan_variable(
13191350 shape = self ._metadata .stan_vars_dims [var ]
13201351 return np .array (xs ).reshape (shape )
13211352
1322- def stan_variables (self ) -> Dict [str , np .ndarray ]:
1353+ def stan_variables (self , check_convergence : bool = True ) -> Dict [str , np .ndarray ]:
13231354 """
13241355 Return a dictionary mapping Stan program variables names
13251356 to the corresponding numpy.ndarray containing the inferred values.
13261357
1358+ :param check_convergence: Checks for failure to converge and
1359+ prints warning.failed to converge. ``False`` will supress
1360+ check and warning, default is ``True``.
1361+
13271362 See Also
13281363 --------
13291364 CmdStanMLE.stan_variable
13301365 CmdStanMCMC.stan_variables
13311366 CmdStanVB.stan_variables
13321367 CmdStanGQ.stan_variables
13331368 """
1369+ if check_convergence and not self .runset ._check_retcodes ():
1370+ get_logger ().warning (
1371+ 'invalid estimate, optimization failed to converge'
1372+ )
13341373 result = {}
13351374 for name in self ._metadata .stan_vars_dims .keys ():
1336- result [name ] = self .stan_variable (name )
1375+ result [name ] = self .stan_variable (name , False ) # don't warn twice
13371376 return result
13381377
13391378 def save_csvfiles (self , dir : Optional [str ] = None ) -> None :
0 commit comments