|
3 | 3 | """ |
4 | 4 | import contextlib |
5 | 5 | import functools |
| 6 | +import json |
6 | 7 | import logging |
7 | 8 | import math |
8 | 9 | import os |
|
30 | 31 |
|
31 | 32 | import numpy as np |
32 | 33 | import pandas as pd |
33 | | -import ujson as json |
| 34 | +import ujson |
34 | 35 | from tqdm.auto import tqdm |
35 | 36 |
|
36 | 37 | from cmdstanpy import ( |
@@ -439,6 +440,13 @@ def rewrite_inf_nan( |
439 | 440 | return data |
440 | 441 |
|
441 | 442 |
|
| 443 | +def serialize_complex(c: Any) -> List[float]: |
| 444 | + if isinstance(c, complex): |
| 445 | + return [c.real, c.imag] |
| 446 | + else: |
| 447 | + raise TypeError(f"Unserializable type: {type(c)}") |
| 448 | + |
| 449 | + |
442 | 450 | def write_stan_json(path: str, data: Mapping[str, Any]) -> None: |
443 | 451 | """ |
444 | 452 | Dump a mapping of strings to data to a JSON file. |
@@ -494,7 +502,11 @@ def write_stan_json(path: str, data: Mapping[str, Any]) -> None: |
494 | 502 | data_out[key] = rewrite_inf_nan(data_out[key]) |
495 | 503 |
|
496 | 504 | with open(path, 'w') as fd: |
497 | | - json.dump(data_out, fd) |
| 505 | + try: |
| 506 | + ujson.dump(data_out, fd) |
| 507 | + except TypeError as e: |
| 508 | + get_logger().debug(e) |
| 509 | + json.dump(data_out, fd, default=serialize_complex) |
498 | 510 |
|
499 | 511 |
|
500 | 512 | def rload(fname: str) -> Optional[Dict[str, Union[int, float, np.ndarray]]]: |
@@ -948,7 +960,7 @@ def read_metric(path: str) -> List[int]: |
948 | 960 | """ |
949 | 961 | if path.endswith('.json'): |
950 | 962 | with open(path, 'r') as fd: |
951 | | - metric_dict = json.load(fd) |
| 963 | + metric_dict = ujson.load(fd) |
952 | 964 | if 'inv_metric' in metric_dict: |
953 | 965 | dims_np: np.ndarray = np.asarray(metric_dict['inv_metric']) |
954 | 966 | return list(dims_np.shape) |
|
0 commit comments