Skip to content

Commit 3d207c8

Browse files
authored
Merge pull request #488 from stan-dev/nan-inf
Serialize inf/nan to json
2 parents 725ad89 + 86e7a81 commit 3d207c8

2 files changed

Lines changed: 44 additions & 15 deletions

File tree

cmdstanpy/utils.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -409,6 +409,22 @@ def cxx_toolchain_path(
409409
return compiler_path, tool_path
410410

411411

412+
def rewrite_inf_nan(
413+
data: Union[float, int, List[Any]]
414+
) -> Union[str, int, float, List[Any]]:
415+
"""Replaces NaN and Infinity with string representations"""
416+
if isinstance(data, float):
417+
if math.isnan(data):
418+
return 'NaN'
419+
if math.isinf(data):
420+
return ('+' if data > 0 else '-') + 'inf'
421+
return data
422+
elif isinstance(data, list):
423+
return [rewrite_inf_nan(item) for item in data]
424+
else:
425+
return data
426+
427+
412428
def write_stan_json(path: str, data: Mapping[str, Any]) -> None:
413429
"""
414430
Dump a mapping of strings to data to a JSON file.
@@ -430,6 +446,7 @@ def write_stan_json(path: str, data: Mapping[str, Any]) -> None:
430446
"""
431447
data_out = {}
432448
for key, val in data.items():
449+
handle_nan_inf = False
433450
if val is not None:
434451
if isinstance(val, (str, bytes)) or (
435452
type(val).__module__ != 'numpy'
@@ -440,18 +457,14 @@ def write_stan_json(path: str, data: Mapping[str, Any]) -> None:
440457
+ f"write_stan_json for key '{key}'"
441458
)
442459
try:
443-
if not np.all(np.isfinite(val)):
444-
raise ValueError(
445-
"Input to write_stan_json has nan or infinite "
446-
+ f"values for key '{key}'"
447-
)
460+
handle_nan_inf = not np.all(np.isfinite(val))
448461
except TypeError:
449462
# handles cases like val == ['hello']
450463
# pylint: disable=raise-missing-from
451464
raise ValueError(
452465
"Invalid type provided to "
453-
+ f"write_stan_json for key '{key}' "
454-
+ f"as part of collection {type(val)}"
466+
f"write_stan_json for key '{key}' "
467+
f"as part of collection {type(val)}"
455468
)
456469

457470
if type(val).__module__ == 'numpy':
@@ -463,6 +476,9 @@ def write_stan_json(path: str, data: Mapping[str, Any]) -> None:
463476
else:
464477
data_out[key] = val
465478

479+
if handle_nan_inf:
480+
data_out[key] = rewrite_inf_nan(data_out[key])
481+
466482
with open(path, 'w') as fd:
467483
json.dump(data_out, fd)
468484

test/test_utils.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -338,6 +338,27 @@ def cmp(d1, d2):
338338
with open(file_scalr) as fd:
339339
cmp(json.load(fd), dict_scalr)
340340

341+
# custom Stan serialization
342+
dict_inf_nan = {
343+
'a': np.array(
344+
[
345+
[-np.inf, np.inf, np.NaN],
346+
[-float('inf'), float('inf'), float('NaN')],
347+
[
348+
np.float32(-np.inf),
349+
np.float32(np.inf),
350+
np.float32(np.NaN),
351+
],
352+
[1e200 * -1e200, 1e220 * 1e200, -np.nan],
353+
]
354+
)
355+
}
356+
dict_inf_nan_exp = {'a': [["-inf", "+inf", "NaN"]] * 4}
357+
file_fin = os.path.join(_TMPDIR, 'inf.json')
358+
write_stan_json(file_fin, dict_inf_nan)
359+
with open(file_fin) as fd:
360+
cmp(json.load(fd), dict_inf_nan_exp)
361+
341362
def test_write_stan_json_bad(self):
342363
file_bad = os.path.join(_TMPDIR, 'bad.json')
343364

@@ -349,14 +370,6 @@ def test_write_stan_json_bad(self):
349370
with self.assertRaises(ValueError):
350371
write_stan_json(file_bad, dict_badtype_nested)
351372

352-
dict_inf = {'a': [np.inf]}
353-
with self.assertRaises(ValueError):
354-
write_stan_json(file_bad, dict_inf)
355-
356-
dict_nan = {'a': np.nan}
357-
with self.assertRaises(ValueError):
358-
write_stan_json(file_bad, dict_nan)
359-
360373

361374
class ReadStanCsvTest(unittest.TestCase):
362375
def test_check_sampler_csv_1(self):

0 commit comments

Comments
 (0)