@@ -409,6 +409,21 @@ def cxx_toolchain_path(
409409 return compiler_path , tool_path
410410
411411
412+ def rewrite_inf_nan (
413+ data : Union [float , int , List [Any ]]
414+ ) -> Union [str , int , float , List [Any ]]:
415+ if isinstance (data , float ):
416+ if math .isnan (data ):
417+ return 'NaN'
418+ if math .isinf (data ):
419+ return ('+' if data > 0 else '-' ) + 'inf'
420+ return data
421+ elif isinstance (data , list ):
422+ return [rewrite_inf_nan (item ) for item in data ]
423+ else :
424+ return data
425+
426+
412427def write_stan_json (path : str , data : Mapping [str , Any ]) -> None :
413428 """
414429 Dump a mapping of strings to data to a JSON file.
@@ -427,9 +442,13 @@ def write_stan_json(path: str, data: Mapping[str, Any]) -> None:
427442 :param data: A mapping from strings to values. This can be a dictionary
428443 or something more exotic like an :class:`xarray.Dataset`. This will be
429444 copied before type conversion, not modified
445+
446+ :param handle_nan_inf: If enabled, perform the (Slow!) checks necessary to
447+ output NaN and inf as required for Stan
430448 """
431449 data_out = {}
432450 for key , val in data .items ():
451+ handle_nan_inf = False
433452 if val is not None :
434453 if isinstance (val , (str , bytes )) or (
435454 type (val ).__module__ != 'numpy'
@@ -440,18 +459,14 @@ def write_stan_json(path: str, data: Mapping[str, Any]) -> None:
440459 + f"write_stan_json for key '{ key } '"
441460 )
442461 try :
443- if not np .all (np .isfinite (val )):
444- raise ValueError (
445- "Input to write_stan_json has nan or infinite "
446- + f"values for key '{ key } '"
447- )
462+ handle_nan_inf = not np .all (np .isfinite (val ))
448463 except TypeError :
449464 # handles cases like val == ['hello']
450465 # pylint: disable=raise-missing-from
451466 raise ValueError (
452467 "Invalid type provided to "
453- + f"write_stan_json for key '{ key } ' "
454- + f"as part of collection { type (val )} "
468+ f"write_stan_json for key '{ key } ' "
469+ f"as part of collection { type (val )} "
455470 )
456471
457472 if type (val ).__module__ == 'numpy' :
@@ -463,6 +478,9 @@ def write_stan_json(path: str, data: Mapping[str, Any]) -> None:
463478 else :
464479 data_out [key ] = val
465480
481+ if handle_nan_inf :
482+ data_out [key ] = rewrite_inf_nan (data_out [key ])
483+
466484 with open (path , 'w' ) as fd :
467485 json .dump (data_out , fd )
468486
0 commit comments