|
3 | 3 | """ |
4 | 4 | import contextlib |
5 | 5 | import functools |
| 6 | +import json |
6 | 7 | import logging |
7 | 8 | import math |
8 | 9 | import os |
|
31 | 32 |
|
32 | 33 | import numpy as np |
33 | 34 | import pandas as pd |
34 | | -import ujson as json |
| 35 | +import ujson |
35 | 36 | from tqdm.auto import tqdm |
36 | 37 |
|
37 | 38 | from cmdstanpy import ( |
@@ -450,6 +451,13 @@ def rewrite_inf_nan( |
450 | 451 | return data |
451 | 452 |
|
452 | 453 |
|
| 454 | +def serialize_complex(c: Any) -> List[float]: |
| 455 | + if isinstance(c, complex): |
| 456 | + return [c.real, c.imag] |
| 457 | + else: |
| 458 | + raise TypeError(f"Unserializable type: {type(c)}") |
| 459 | + |
| 460 | + |
453 | 461 | def write_stan_json(path: str, data: Mapping[str, Any]) -> None: |
454 | 462 | """ |
455 | 463 | Dump a mapping of strings to data to a JSON file. |
@@ -505,7 +513,11 @@ def write_stan_json(path: str, data: Mapping[str, Any]) -> None: |
505 | 513 | data_out[key] = rewrite_inf_nan(data_out[key]) |
506 | 514 |
|
507 | 515 | with open(path, 'w') as fd: |
508 | | - json.dump(data_out, fd) |
| 516 | + try: |
| 517 | + ujson.dump(data_out, fd) |
| 518 | + except TypeError as e: |
| 519 | + get_logger().debug(e) |
| 520 | + json.dump(data_out, fd, default=serialize_complex) |
509 | 521 |
|
510 | 522 |
|
511 | 523 | def rload(fname: str) -> Optional[Dict[str, Union[int, float, np.ndarray]]]: |
@@ -977,7 +989,7 @@ def read_metric(path: str) -> List[int]: |
977 | 989 | """ |
978 | 990 | if path.endswith('.json'): |
979 | 991 | with open(path, 'r') as fd: |
980 | | - metric_dict = json.load(fd) |
| 992 | + metric_dict = ujson.load(fd) |
981 | 993 | if 'inv_metric' in metric_dict: |
982 | 994 | dims_np: np.ndarray = np.asarray(metric_dict['inv_metric']) |
983 | 995 | return list(dims_np.shape) |
|
0 commit comments