@@ -409,6 +409,22 @@ def cxx_toolchain_path(
409409 return compiler_path , tool_path
410410
411411
412+ def rewrite_inf_nan (
413+ data : Union [float , int , List [Any ]]
414+ ) -> Union [str , int , float , List [Any ]]:
415+ """Replaces NaN and Infinity with string representations"""
416+ if isinstance (data , float ):
417+ if math .isnan (data ):
418+ return 'NaN'
419+ if math .isinf (data ):
420+ return ('+' if data > 0 else '-' ) + 'inf'
421+ return data
422+ elif isinstance (data , list ):
423+ return [rewrite_inf_nan (item ) for item in data ]
424+ else :
425+ return data
426+
427+
412428def write_stan_json (path : str , data : Mapping [str , Any ]) -> None :
413429 """
414430 Dump a mapping of strings to data to a JSON file.
@@ -430,6 +446,7 @@ def write_stan_json(path: str, data: Mapping[str, Any]) -> None:
430446 """
431447 data_out = {}
432448 for key , val in data .items ():
449+ handle_nan_inf = False
433450 if val is not None :
434451 if isinstance (val , (str , bytes )) or (
435452 type (val ).__module__ != 'numpy'
@@ -440,18 +457,14 @@ def write_stan_json(path: str, data: Mapping[str, Any]) -> None:
440457 + f"write_stan_json for key '{ key } '"
441458 )
442459 try :
443- if not np .all (np .isfinite (val )):
444- raise ValueError (
445- "Input to write_stan_json has nan or infinite "
446- + f"values for key '{ key } '"
447- )
460+ handle_nan_inf = not np .all (np .isfinite (val ))
448461 except TypeError :
449462 # handles cases like val == ['hello']
450463 # pylint: disable=raise-missing-from
451464 raise ValueError (
452465 "Invalid type provided to "
453- + f"write_stan_json for key '{ key } ' "
454- + f"as part of collection { type (val )} "
466+ f"write_stan_json for key '{ key } ' "
467+ f"as part of collection { type (val )} "
455468 )
456469
457470 if type (val ).__module__ == 'numpy' :
@@ -463,6 +476,9 @@ def write_stan_json(path: str, data: Mapping[str, Any]) -> None:
463476 else :
464477 data_out [key ] = val
465478
479+ if handle_nan_inf :
480+ data_out [key ] = rewrite_inf_nan (data_out [key ])
481+
466482 with open (path , 'w' ) as fd :
467483 json .dump (data_out , fd )
468484
0 commit comments