Skip to content

Commit 361cffb

Browse files
committed
Merge branch 'develop' of https://github.com/stan-dev/cmdstanpy into develop
2 parents fa664ba + 3e483cd commit 361cffb

6 files changed

Lines changed: 119 additions & 63 deletions

File tree

.github/workflows/main.yml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name: CmdStanPy
22

33
on:
4-
push:
4+
push:
55
branches:
66
- 'develop'
77
- 'master'
@@ -28,7 +28,7 @@ jobs:
2828
strategy:
2929
matrix:
3030
os: [ubuntu-latest, macos-latest, windows-latest]
31-
python-version: [3.6, 3.7, 3.8, 3.9]
31+
python-version: [3.6, 3.7, 3.8, 3.9, "3.10"]
3232
fail-fast: false
3333
env:
3434
GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }}
@@ -65,7 +65,7 @@ jobs:
6565
if: matrix.os == 'windows-latest'
6666
run: |
6767
$whl = Get-ChildItem -Path dist -Filter *.whl | Select-Object -First 1
68-
pip install "$whl"
68+
pip install "$whl"
6969
7070
- name: Show libraries
7171
run: python -m pip freeze
@@ -101,4 +101,4 @@ jobs:
101101
- name: Submit codecov
102102
run: |
103103
cd run_tests
104-
codecov
104+
codecov

cmdstanpy/utils.py

Lines changed: 36 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -409,6 +409,22 @@ def cxx_toolchain_path(
409409
return compiler_path, tool_path
410410

411411

412+
def rewrite_inf_nan(
    data: Union[float, int, List[Any]]
) -> Union[str, int, float, List[Any]]:
    """Replaces NaN and Infinity with string representations.

    Floats that are NaN become the string ``'NaN'``; positive and
    negative infinity become ``'+inf'`` and ``'-inf'``. Lists are
    processed recursively; every other value is returned untouched.
    """
    # Recurse into lists first; each element is handled independently.
    if isinstance(data, list):
        return [rewrite_inf_nan(entry) for entry in data]
    # Only floats can carry NaN/inf; ints, strings, etc. pass through.
    if not isinstance(data, float):
        return data
    if math.isnan(data):
        return 'NaN'
    if math.isinf(data):
        return '-inf' if data < 0 else '+inf'
    return data
426+
427+
412428
def write_stan_json(path: str, data: Mapping[str, Any]) -> None:
413429
"""
414430
Dump a mapping of strings to data to a JSON file.
@@ -430,6 +446,7 @@ def write_stan_json(path: str, data: Mapping[str, Any]) -> None:
430446
"""
431447
data_out = {}
432448
for key, val in data.items():
449+
handle_nan_inf = False
433450
if val is not None:
434451
if isinstance(val, (str, bytes)) or (
435452
type(val).__module__ != 'numpy'
@@ -440,18 +457,14 @@ def write_stan_json(path: str, data: Mapping[str, Any]) -> None:
440457
+ f"write_stan_json for key '{key}'"
441458
)
442459
try:
443-
if not np.all(np.isfinite(val)):
444-
raise ValueError(
445-
"Input to write_stan_json has nan or infinite "
446-
+ f"values for key '{key}'"
447-
)
460+
handle_nan_inf = not np.all(np.isfinite(val))
448461
except TypeError:
449462
# handles cases like val == ['hello']
450463
# pylint: disable=raise-missing-from
451464
raise ValueError(
452465
"Invalid type provided to "
453-
+ f"write_stan_json for key '{key}' "
454-
+ f"as part of collection {type(val)}"
466+
f"write_stan_json for key '{key}' "
467+
f"as part of collection {type(val)}"
455468
)
456469

457470
if type(val).__module__ == 'numpy':
@@ -463,6 +476,9 @@ def write_stan_json(path: str, data: Mapping[str, Any]) -> None:
463476
else:
464477
data_out[key] = val
465478

479+
if handle_nan_inf:
480+
data_out[key] = rewrite_inf_nan(data_out[key])
481+
466482
with open(path, 'w') as fd:
467483
json.dump(data_out, fd)
468484

@@ -591,12 +607,15 @@ def scan_sampler_csv(path: str, is_fixed_param: bool = False) -> Dict[str, Any]:
591607
dict: Dict[str, Any] = {}
592608
lineno = 0
593609
with open(path, 'r') as fd:
594-
lineno = scan_config(fd, dict, lineno)
595-
lineno = scan_column_names(fd, dict, lineno)
596-
if not is_fixed_param:
597-
lineno = scan_warmup_iters(fd, dict, lineno)
598-
lineno = scan_hmc_params(fd, dict, lineno)
599-
lineno = scan_sampling_iters(fd, dict, lineno)
610+
try:
611+
lineno = scan_config(fd, dict, lineno)
612+
lineno = scan_column_names(fd, dict, lineno)
613+
if not is_fixed_param:
614+
lineno = scan_warmup_iters(fd, dict, lineno)
615+
lineno = scan_hmc_params(fd, dict, lineno)
616+
lineno = scan_sampling_iters(fd, dict, lineno)
617+
except ValueError as e:
618+
raise ValueError("Error in reading csv file: " + path) from e
600619
return dict
601620

602621

@@ -894,9 +913,12 @@ def scan_sampling_iters(
894913
data = line.split(',')
895914
if len(data) != num_cols:
896915
raise ValueError(
897-
'line {}: bad draw, expecting {} items, found {}'.format(
916+
'line {}: bad draw, expecting {} items, found {}\n'.format(
898917
lineno, num_cols, len(line.split(','))
899918
)
919+
+ 'This error could be caused by running out of disk space.\n'
920+
'Try clearing up TEMP or setting output_dir to a path'
921+
' on another drive.',
900922
)
901923
cur_pos = fd.tell()
902924
line = fd.readline().strip()

test/__init__.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
"""Testing utilities for CmdStanPy."""

import contextlib
import unittest
# NOTE: ``import unittest`` alone does not import the ``unittest.mock``
# submodule; the explicit import below is required for
# ``unittest.mock.patch`` to be reliably available.
import unittest.mock
from importlib import reload


class CustomTestCase(unittest.TestCase):
    """TestCase with helpers shared across the CmdStanPy test suite."""

    # pylint: disable=invalid-name
    @contextlib.contextmanager
    def assertRaisesRegexNested(self, exc, msg):
        """A version of assertRaisesRegex that checks the full traceback.

        Useful for when an exception is raised from another and you wish to
        inspect the inner exception.

        :param exc: exception type the enclosed body must raise
        :param msg: regex that must match somewhere in the chained
            exception messages (outermost first, joined by newlines)
        """
        with self.assertRaises(exc) as ctx:
            yield
        # Walk the __cause__ chain, accumulating each message so the
        # regex can match against any exception in the chain.
        exception = ctx.exception
        exn_string = str(ctx.exception)
        while exception.__cause__ is not None:
            exception = exception.__cause__
            exn_string += "\n" + str(exception)
        self.assertRegex(exn_string, msg)

    # pylint: disable=no-self-use
    @contextlib.contextmanager
    def without_import(self, library, module):
        """Temporarily make ``library`` unimportable and reload ``module``.

        Patching the entry in ``sys.modules`` to None makes any
        ``import library`` inside ``module`` raise ImportError; the
        final reload restores ``module`` to its normal state.
        """
        with unittest.mock.patch.dict('sys.modules', {library: None}):
            reload(module)
            yield
        reload(module)

test/test_generate_quantities.py

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import logging
77
import os
88
import unittest
9-
from importlib import reload
9+
from test import CustomTestCase
1010

1111
import numpy as np
1212
import pandas as pd
@@ -21,15 +21,7 @@
2121
DATAFILES_PATH = os.path.join(HERE, 'data')
2222

2323

24-
@contextlib.contextmanager
25-
def without_import(library, module):
26-
with unittest.mock.patch.dict('sys.modules', {library: None}):
27-
reload(module)
28-
yield
29-
reload(module)
30-
31-
32-
class GenerateQuantitiesTest(unittest.TestCase):
24+
class GenerateQuantitiesTest(CustomTestCase):
3325
def test_from_csv_files(self):
3426
# fitted_params sample - list of filenames
3527
goodfiles_path = os.path.join(DATAFILES_PATH, 'runset-good', 'bern')
@@ -357,7 +349,7 @@ def test_sample_plus_quantities_dedup(self):
357349
self.assertEqual(y_rep[0, i], bern_data['y'][i])
358350

359351
def test_no_xarray(self):
360-
with without_import('xarray', cmdstanpy.stanfit):
352+
with self.without_import('xarray', cmdstanpy.stanfit):
361353
with self.assertRaises(ImportError):
362354
# if this fails the testing framework is the problem
363355
import xarray as _ # noqa

test/test_sample.py

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99
import stat
1010
import tempfile
1111
import unittest
12-
from importlib import reload
1312
from multiprocessing import cpu_count
13+
from test import CustomTestCase
1414
from time import time
1515

1616
import numpy as np
@@ -47,14 +47,6 @@
4747
BERNOULLI_COLS = SAMPLER_STATE + ['theta']
4848

4949

50-
@contextlib.contextmanager
51-
def without_import(library, module):
52-
with unittest.mock.patch.dict('sys.modules', {library: None}):
53-
reload(module)
54-
yield
55-
reload(module)
56-
57-
5850
class SampleTest(unittest.TestCase):
5951
def test_bernoulli_good(self, stanfile='bernoulli.stan'):
6052
stan = os.path.join(DATAFILES_PATH, stanfile)
@@ -584,7 +576,7 @@ def test_show_progress(self, stanfile='bernoulli.stan'):
584576
self.assertTrue('Sampling completed' in console)
585577

586578

587-
class CmdStanMCMCTest(unittest.TestCase):
579+
class CmdStanMCMCTest(CustomTestCase):
588580
# pylint: disable=too-many-public-methods
589581
def test_validate_good_run(self):
590582
# construct fit using existing sampler output
@@ -1092,7 +1084,9 @@ def test_validate_bad_run(self):
10921084
os.path.join(DATAFILES_PATH, 'runset-bad', 'bad-hdr-bern-3.csv'),
10931085
os.path.join(DATAFILES_PATH, 'runset-bad', 'bad-hdr-bern-4.csv'),
10941086
]
1095-
with self.assertRaisesRegex(ValueError, 'CmdStan config mismatch'):
1087+
with self.assertRaisesRegexNested(
1088+
ValueError, 'CmdStan config mismatch'
1089+
):
10961090
CmdStanMCMC(runset)
10971091

10981092
# bad draws
@@ -1102,7 +1096,7 @@ def test_validate_bad_run(self):
11021096
os.path.join(DATAFILES_PATH, 'runset-bad', 'bad-draws-bern-3.csv'),
11031097
os.path.join(DATAFILES_PATH, 'runset-bad', 'bad-draws-bern-4.csv'),
11041098
]
1105-
with self.assertRaisesRegex(ValueError, 'draws'):
1099+
with self.assertRaisesRegexNested(ValueError, 'draws'):
11061100
CmdStanMCMC(runset)
11071101

11081102
# mismatch - column headers, draws
@@ -1112,7 +1106,7 @@ def test_validate_bad_run(self):
11121106
os.path.join(DATAFILES_PATH, 'runset-bad', 'bad-cols-bern-3.csv'),
11131107
os.path.join(DATAFILES_PATH, 'runset-bad', 'bad-cols-bern-4.csv'),
11141108
]
1115-
with self.assertRaisesRegex(
1109+
with self.assertRaisesRegexNested(
11161110
ValueError, 'bad draw, expecting 9 items, found 8'
11171111
):
11181112
CmdStanMCMC(runset)
@@ -1604,7 +1598,7 @@ def test_xarray_draws(self):
16041598
self.assertEqual(xr_var.theta.values.shape, (1, 100, 1))
16051599

16061600
def test_no_xarray(self):
1607-
with without_import('xarray', cmdstanpy.stanfit):
1601+
with self.without_import('xarray', cmdstanpy.stanfit):
16081602
with self.assertRaises(ImportError):
16091603
# if this fails the testing framework is the problem
16101604
import xarray as _ # noqa

0 commit comments

Comments
 (0)