Skip to content

Commit b205926

Browse files
committed
Handle inf/nan in json
1 parent 725ad89 commit b205926

2 files changed

Lines changed: 36 additions & 15 deletions

File tree

cmdstanpy/utils.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -409,6 +409,21 @@ def cxx_toolchain_path(
409409
return compiler_path, tool_path
410410

411411

412+
def rewrite_inf_nan(
413+
data: Union[float, int, List[Any]]
414+
) -> Union[str, int, float, List[Any]]:
415+
if isinstance(data, float):
416+
if math.isnan(data):
417+
return 'NaN'
418+
if math.isinf(data):
419+
return ('+' if data > 0 else '-') + 'inf'
420+
return data
421+
elif isinstance(data, list):
422+
return [rewrite_inf_nan(item) for item in data]
423+
else:
424+
return data
425+
426+
412427
def write_stan_json(path: str, data: Mapping[str, Any]) -> None:
413428
"""
414429
Dump a mapping of strings to data to a JSON file.
@@ -427,9 +442,13 @@ def write_stan_json(path: str, data: Mapping[str, Any]) -> None:
427442
:param data: A mapping from strings to values. This can be a dictionary
428443
or something more exotic like an :class:`xarray.Dataset`. This will be
429444
copied before type conversion, not modified
445+
446+
:param handle_nan_inf: If enabled, perform the (Slow!) checks necessary to
447+
output NaN and inf as required for Stan
430448
"""
431449
data_out = {}
432450
for key, val in data.items():
451+
handle_nan_inf = False
433452
if val is not None:
434453
if isinstance(val, (str, bytes)) or (
435454
type(val).__module__ != 'numpy'
@@ -440,18 +459,14 @@ def write_stan_json(path: str, data: Mapping[str, Any]) -> None:
440459
+ f"write_stan_json for key '{key}'"
441460
)
442461
try:
443-
if not np.all(np.isfinite(val)):
444-
raise ValueError(
445-
"Input to write_stan_json has nan or infinite "
446-
+ f"values for key '{key}'"
447-
)
462+
handle_nan_inf = not np.all(np.isfinite(val))
448463
except TypeError:
449464
# handles cases like val == ['hello']
450465
# pylint: disable=raise-missing-from
451466
raise ValueError(
452467
"Invalid type provided to "
453-
+ f"write_stan_json for key '{key}' "
454-
+ f"as part of collection {type(val)}"
468+
f"write_stan_json for key '{key}' "
469+
f"as part of collection {type(val)}"
455470
)
456471

457472
if type(val).__module__ == 'numpy':
@@ -463,6 +478,9 @@ def write_stan_json(path: str, data: Mapping[str, Any]) -> None:
463478
else:
464479
data_out[key] = val
465480

481+
if handle_nan_inf:
482+
data_out[key] = rewrite_inf_nan(data_out[key])
483+
466484
with open(path, 'w') as fd:
467485
json.dump(data_out, fd)
468486

test/test_utils.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -338,6 +338,17 @@ def cmp(d1, d2):
338338
with open(file_scalr) as fd:
339339
cmp(json.load(fd), dict_scalr)
340340

341+
# custom Stan serialization
342+
343+
dict_inf_nan = {
344+
'a': np.array([[-np.inf, np.inf, np.NaN, float('NaN')]])
345+
}
346+
dict_inf_nan_exp = {'a': np.array([["-inf", "+inf", "NaN", "NaN"]])}
347+
file_fin = os.path.join(_TMPDIR, 'inf.json')
348+
write_stan_json(file_fin, dict_inf_nan)
349+
with open(file_fin) as fd:
350+
cmp(json.load(fd), dict_inf_nan_exp)
351+
341352
def test_write_stan_json_bad(self):
342353
file_bad = os.path.join(_TMPDIR, 'bad.json')
343354

@@ -349,14 +360,6 @@ def test_write_stan_json_bad(self):
349360
with self.assertRaises(ValueError):
350361
write_stan_json(file_bad, dict_badtype_nested)
351362

352-
dict_inf = {'a': [np.inf]}
353-
with self.assertRaises(ValueError):
354-
write_stan_json(file_bad, dict_inf)
355-
356-
dict_nan = {'a': np.nan}
357-
with self.assertRaises(ValueError):
358-
write_stan_json(file_bad, dict_nan)
359-
360363

361364
class ReadStanCsvTest(unittest.TestCase):
362365
def test_check_sampler_csv_1(self):

0 commit comments

Comments
 (0)