Skip to content

Commit 5baaf65

Browse files
committed
Serialize complex numbers to json
1 parent 5ed096e commit 5baaf65

2 files changed

Lines changed: 23 additions & 5 deletions

File tree

cmdstanpy/utils.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
"""
44
import contextlib
55
import functools
6+
import json
67
import logging
78
import math
89
import os
@@ -30,7 +31,7 @@
3031

3132
import numpy as np
3233
import pandas as pd
33-
import ujson as json
34+
import ujson
3435
from tqdm.auto import tqdm
3536

3637
from cmdstanpy import (
@@ -439,6 +440,13 @@ def rewrite_inf_nan(
439440
return data
440441

441442

443+
def serialize_complex(c: Any) -> List[float]:
444+
if isinstance(c, complex):
445+
return [c.real, c.imag]
446+
else:
447+
raise TypeError(f"Unserializable type: {type(c)}")
448+
449+
442450
def write_stan_json(path: str, data: Mapping[str, Any]) -> None:
443451
"""
444452
Dump a mapping of strings to data to a JSON file.
@@ -494,7 +502,11 @@ def write_stan_json(path: str, data: Mapping[str, Any]) -> None:
494502
data_out[key] = rewrite_inf_nan(data_out[key])
495503

496504
with open(path, 'w') as fd:
497-
json.dump(data_out, fd)
505+
try:
506+
ujson.dump(data_out, fd)
507+
except TypeError as e:
508+
get_logger().debug(e)
509+
json.dump(data_out, fd, default=serialize_complex)
498510

499511

500512
def rload(fname: str) -> Optional[Dict[str, Union[int, float, np.ndarray]]]:
@@ -948,7 +960,7 @@ def read_metric(path: str) -> List[int]:
948960
"""
949961
if path.endswith('.json'):
950962
with open(path, 'r') as fd:
951-
metric_dict = json.load(fd)
963+
metric_dict = ujson.load(fd)
952964
if 'inv_metric' in metric_dict:
953965
dims_np: np.ndarray = np.asarray(metric_dict['inv_metric'])
954966
return list(dims_np.shape)

test/test_utils.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,14 +38,14 @@
3838
parse_method_vars,
3939
parse_rdump_value,
4040
parse_stan_vars,
41+
pushd,
4142
read_metric,
4243
rload,
4344
set_cmdstan_path,
4445
validate_cmdstan_path,
4546
validate_dir,
4647
windows_short_path,
4748
write_stan_json,
48-
pushd,
4949
)
5050

5151
HERE = os.path.dirname(os.path.abspath(__file__))
@@ -347,6 +347,13 @@ def cmp(d1, d2):
347347
with open(file_fin) as fd:
348348
cmp(json.load(fd), dict_inf_nan_exp)
349349

350+
dict_complex = {'a': np.array([np.complex64(3), 3 + 4j])}
351+
dict_complex_exp = {'a': [[3, 0], [3, 4]]}
352+
file_complex = os.path.join(_TMPDIR, 'complex.json')
353+
write_stan_json(file_complex, dict_complex)
354+
with open(file_complex) as fd:
355+
cmp(json.load(fd), dict_complex_exp)
356+
350357
def test_write_stan_json_bad(self):
351358
file_bad = os.path.join(_TMPDIR, 'bad.json')
352359

@@ -809,7 +816,6 @@ def test_exit(self):
809816

810817

811818
class PushdTest(unittest.TestCase):
812-
813819
def test_restore_cwd(self):
814820
"Ensure do_command in a different cwd restores cwd after error."
815821
cwd = os.getcwd()

0 commit comments

Comments
 (0)