Skip to content

Commit fe2014d

Browse files
authored
Merge pull request #535 from stan-dev/serialize-complex
Serialize complex numbers to json
2 parents a002f80 + 5baaf65 commit fe2014d

2 files changed

Lines changed: 22 additions & 3 deletions

File tree

cmdstanpy/utils.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
"""
44
import contextlib
55
import functools
6+
import json
67
import logging
78
import math
89
import os
@@ -31,7 +32,7 @@
3132

3233
import numpy as np
3334
import pandas as pd
34-
import ujson as json
35+
import ujson
3536
from tqdm.auto import tqdm
3637

3738
from cmdstanpy import (
@@ -450,6 +451,13 @@ def rewrite_inf_nan(
450451
return data
451452

452453

454+
def serialize_complex(c: Any) -> List[float]:
455+
if isinstance(c, complex):
456+
return [c.real, c.imag]
457+
else:
458+
raise TypeError(f"Unserializable type: {type(c)}")
459+
460+
453461
def write_stan_json(path: str, data: Mapping[str, Any]) -> None:
454462
"""
455463
Dump a mapping of strings to data to a JSON file.
@@ -505,7 +513,11 @@ def write_stan_json(path: str, data: Mapping[str, Any]) -> None:
505513
data_out[key] = rewrite_inf_nan(data_out[key])
506514

507515
with open(path, 'w') as fd:
508-
json.dump(data_out, fd)
516+
try:
517+
ujson.dump(data_out, fd)
518+
except TypeError as e:
519+
get_logger().debug(e)
520+
json.dump(data_out, fd, default=serialize_complex)
509521

510522

511523
def rload(fname: str) -> Optional[Dict[str, Union[int, float, np.ndarray]]]:
@@ -977,7 +989,7 @@ def read_metric(path: str) -> List[int]:
977989
"""
978990
if path.endswith('.json'):
979991
with open(path, 'r') as fd:
980-
metric_dict = json.load(fd)
992+
metric_dict = ujson.load(fd)
981993
if 'inv_metric' in metric_dict:
982994
dims_np: np.ndarray = np.asarray(metric_dict['inv_metric'])
983995
return list(dims_np.shape)

test/test_utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -348,6 +348,13 @@ def cmp(d1, d2):
348348
with open(file_fin) as fd:
349349
cmp(json.load(fd), dict_inf_nan_exp)
350350

351+
dict_complex = {'a': np.array([np.complex64(3), 3 + 4j])}
352+
dict_complex_exp = {'a': [[3, 0], [3, 4]]}
353+
file_complex = os.path.join(_TMPDIR, 'complex.json')
354+
write_stan_json(file_complex, dict_complex)
355+
with open(file_complex) as fd:
356+
cmp(json.load(fd), dict_complex_exp)
357+
351358
def test_write_stan_json_bad(self):
352359
file_bad = os.path.join(_TMPDIR, 'bad.json')
353360

0 commit comments

Comments
 (0)