Skip to content

Commit b6874cf

Browse files
authored
Merge pull request #489 from stan-dev/disk-warning
Give more informative error for bad draws (highlight disk space as possible culprit)
2 parents 3d207c8 + 0058812 commit b6874cf

5 files changed

Lines changed: 71 additions & 44 deletions

File tree

cmdstanpy/utils.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -607,12 +607,15 @@ def scan_sampler_csv(path: str, is_fixed_param: bool = False) -> Dict[str, Any]:
607607
dict: Dict[str, Any] = {}
608608
lineno = 0
609609
with open(path, 'r') as fd:
610-
lineno = scan_config(fd, dict, lineno)
611-
lineno = scan_column_names(fd, dict, lineno)
612-
if not is_fixed_param:
613-
lineno = scan_warmup_iters(fd, dict, lineno)
614-
lineno = scan_hmc_params(fd, dict, lineno)
615-
lineno = scan_sampling_iters(fd, dict, lineno)
610+
try:
611+
lineno = scan_config(fd, dict, lineno)
612+
lineno = scan_column_names(fd, dict, lineno)
613+
if not is_fixed_param:
614+
lineno = scan_warmup_iters(fd, dict, lineno)
615+
lineno = scan_hmc_params(fd, dict, lineno)
616+
lineno = scan_sampling_iters(fd, dict, lineno)
617+
except ValueError as e:
618+
raise ValueError("Error in reading csv file: " + path) from e
616619
return dict
617620

618621

@@ -910,9 +913,12 @@ def scan_sampling_iters(
910913
data = line.split(',')
911914
if len(data) != num_cols:
912915
raise ValueError(
913-
'line {}: bad draw, expecting {} items, found {}'.format(
916+
'line {}: bad draw, expecting {} items, found {}\n'.format(
914917
lineno, num_cols, len(line.split(','))
915918
)
919+
+ 'This error could be caused by running out of disk space.\n'
920+
'Try clearing up TEMP or setting output_dir to a path'
921+
' on another drive.',
916922
)
917923
cur_pos = fd.tell()
918924
line = fd.readline().strip()

test/__init__.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
"""Testing utilities for CmdStanPy."""
2+
3+
import contextlib
4+
import unittest
5+
from importlib import reload
6+
7+
8+
class CustomTestCase(unittest.TestCase):
    """``unittest.TestCase`` with extra helpers shared by the test suite."""

    # pylint: disable=invalid-name
    @contextlib.contextmanager
    def assertRaisesRegexNested(self, exc, msg):
        """A version of assertRaisesRegex that checks the full traceback.

        Useful for when an exception is raised from another and you wish to
        inspect the inner exception.

        The messages of the caught exception and of every exception reachable
        through its ``__cause__`` chain are joined with newlines, and ``msg``
        is searched for in the combined text.

        :param exc: Exception class (or tuple of classes) expected from the
            body of the ``with`` statement.
        :param msg: Regular expression passed to ``assertRegex``.
        """
        with self.assertRaises(exc) as ctx:
            yield
        # Walk the explicit ``raise ... from ...`` chain, outermost first.
        messages = []
        exception = ctx.exception
        while exception is not None:
            messages.append(str(exception))
            exception = exception.__cause__
        self.assertRegex("\n".join(messages), msg)

    # pylint: disable=no-self-use
    @contextlib.contextmanager
    def without_import(self, library, module):
        """Temporarily make ``library`` un-importable and reload ``module``.

        While the context is active ``sys.modules[library]`` is ``None``, so
        ``import library`` raises ``ImportError``; ``module`` is reloaded so
        that its import-time code runs under that condition. On exit the
        ``sys.modules`` entry is restored and ``module`` is reloaded again.

        :param library: Name of the module to hide, e.g. ``'xarray'``.
        :param module: Already-imported module object to reload.
        """
        # ``import unittest`` alone does not load the ``unittest.mock``
        # submodule, so import it explicitly rather than relying on some
        # other module having imported it first.
        # pylint: disable=import-outside-toplevel
        from unittest import mock

        try:
            with mock.patch.dict('sys.modules', {library: None}):
                reload(module)
                yield
        finally:
            # Runs after ``patch.dict`` has restored ``sys.modules`` and even
            # when the body raised, so later tests see the normal module.
            reload(module)

test/test_generate_quantities.py

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import logging
77
import os
88
import unittest
9-
from importlib import reload
9+
from test import CustomTestCase
1010

1111
import numpy as np
1212
import pandas as pd
@@ -21,15 +21,7 @@
2121
DATAFILES_PATH = os.path.join(HERE, 'data')
2222

2323

24-
@contextlib.contextmanager
25-
def without_import(library, module):
26-
with unittest.mock.patch.dict('sys.modules', {library: None}):
27-
reload(module)
28-
yield
29-
reload(module)
30-
31-
32-
class GenerateQuantitiesTest(unittest.TestCase):
24+
class GenerateQuantitiesTest(CustomTestCase):
3325
def test_from_csv_files(self):
3426
# fitted_params sample - list of filenames
3527
goodfiles_path = os.path.join(DATAFILES_PATH, 'runset-good', 'bern')
@@ -357,7 +349,7 @@ def test_sample_plus_quantities_dedup(self):
357349
self.assertEqual(y_rep[0, i], bern_data['y'][i])
358350

359351
def test_no_xarray(self):
360-
with without_import('xarray', cmdstanpy.stanfit):
352+
with self.without_import('xarray', cmdstanpy.stanfit):
361353
with self.assertRaises(ImportError):
362354
# if this fails the testing framework is the problem
363355
import xarray as _ # noqa

test/test_sample.py

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99
import stat
1010
import tempfile
1111
import unittest
12-
from importlib import reload
1312
from multiprocessing import cpu_count
13+
from test import CustomTestCase
1414
from time import time
1515

1616
import numpy as np
@@ -47,14 +47,6 @@
4747
BERNOULLI_COLS = SAMPLER_STATE + ['theta']
4848

4949

50-
@contextlib.contextmanager
51-
def without_import(library, module):
52-
with unittest.mock.patch.dict('sys.modules', {library: None}):
53-
reload(module)
54-
yield
55-
reload(module)
56-
57-
5850
class SampleTest(unittest.TestCase):
5951
def test_bernoulli_good(self, stanfile='bernoulli.stan'):
6052
stan = os.path.join(DATAFILES_PATH, stanfile)
@@ -584,7 +576,7 @@ def test_show_progress(self, stanfile='bernoulli.stan'):
584576
self.assertTrue('Sampling completed' in console)
585577

586578

587-
class CmdStanMCMCTest(unittest.TestCase):
579+
class CmdStanMCMCTest(CustomTestCase):
588580
# pylint: disable=too-many-public-methods
589581
def test_validate_good_run(self):
590582
# construct fit using existing sampler output
@@ -1092,7 +1084,9 @@ def test_validate_bad_run(self):
10921084
os.path.join(DATAFILES_PATH, 'runset-bad', 'bad-hdr-bern-3.csv'),
10931085
os.path.join(DATAFILES_PATH, 'runset-bad', 'bad-hdr-bern-4.csv'),
10941086
]
1095-
with self.assertRaisesRegex(ValueError, 'CmdStan config mismatch'):
1087+
with self.assertRaisesRegexNested(
1088+
ValueError, 'CmdStan config mismatch'
1089+
):
10961090
CmdStanMCMC(runset)
10971091

10981092
# bad draws
@@ -1102,7 +1096,7 @@ def test_validate_bad_run(self):
11021096
os.path.join(DATAFILES_PATH, 'runset-bad', 'bad-draws-bern-3.csv'),
11031097
os.path.join(DATAFILES_PATH, 'runset-bad', 'bad-draws-bern-4.csv'),
11041098
]
1105-
with self.assertRaisesRegex(ValueError, 'draws'):
1099+
with self.assertRaisesRegexNested(ValueError, 'draws'):
11061100
CmdStanMCMC(runset)
11071101

11081102
# mismatch - column headers, draws
@@ -1112,7 +1106,7 @@ def test_validate_bad_run(self):
11121106
os.path.join(DATAFILES_PATH, 'runset-bad', 'bad-cols-bern-3.csv'),
11131107
os.path.join(DATAFILES_PATH, 'runset-bad', 'bad-cols-bern-4.csv'),
11141108
]
1115-
with self.assertRaisesRegex(
1109+
with self.assertRaisesRegexNested(
11161110
ValueError, 'bad draw, expecting 9 items, found 8'
11171111
):
11181112
CmdStanMCMC(runset)
@@ -1604,7 +1598,7 @@ def test_xarray_draws(self):
16041598
self.assertEqual(xr_var.theta.values.shape, (1, 100, 1))
16051599

16061600
def test_no_xarray(self):
1607-
with without_import('xarray', cmdstanpy.stanfit):
1601+
with self.without_import('xarray', cmdstanpy.stanfit):
16081602
with self.assertRaises(ImportError):
16091603
# if this fails the testing framework is the problem
16101604
import xarray as _ # noqa

test/test_utils.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import tempfile
1515
import unittest
1616
from pathlib import Path
17+
from test import CustomTestCase
1718

1819
import numpy as np
1920
import pandas as pd
@@ -371,7 +372,7 @@ def test_write_stan_json_bad(self):
371372
write_stan_json(file_bad, dict_badtype_nested)
372373

373374

374-
class ReadStanCsvTest(unittest.TestCase):
375+
class ReadStanCsvTest(CustomTestCase):
375376
def test_check_sampler_csv_1(self):
376377
csv_good = os.path.join(DATAFILES_PATH, 'bernoulli_output_1.csv')
377378
dict = check_sampler_csv(
@@ -386,13 +387,13 @@ def test_check_sampler_csv_1(self):
386387
self.assertEqual(10, dict['draws_sampling'])
387388
self.assertEqual(8, len(dict['column_names']))
388389

389-
with self.assertRaisesRegex(
390+
with self.assertRaisesRegexNested(
390391
ValueError, 'config error, expected thin = 2'
391392
):
392393
check_sampler_csv(
393394
path=csv_good, iter_warmup=100, iter_sampling=20, thin=2
394395
)
395-
with self.assertRaisesRegex(
396+
with self.assertRaisesRegexNested(
396397
ValueError, 'config error, expected save_warmup'
397398
):
398399
check_sampler_csv(
@@ -401,7 +402,7 @@ def test_check_sampler_csv_1(self):
401402
iter_sampling=10,
402403
save_warmup=True,
403404
)
404-
with self.assertRaisesRegex(ValueError, 'expected 1000 draws'):
405+
with self.assertRaisesRegexNested(ValueError, 'expected 1000 draws'):
405406
check_sampler_csv(path=csv_good, iter_warmup=100)
406407

407408
def test_check_sampler_csv_2(self):
@@ -411,34 +412,34 @@ def test_check_sampler_csv_2(self):
411412

412413
def test_check_sampler_csv_3(self):
413414
csv_bad = os.path.join(DATAFILES_PATH, 'output_bad_cols.csv')
414-
with self.assertRaisesRegex(Exception, '8 items'):
415+
with self.assertRaisesRegexNested(Exception, '8 items'):
415416
check_sampler_csv(csv_bad)
416417

417418
def test_check_sampler_csv_4(self):
418419
csv_bad = os.path.join(DATAFILES_PATH, 'output_bad_rows.csv')
419-
with self.assertRaisesRegex(Exception, 'found 9'):
420+
with self.assertRaisesRegexNested(Exception, 'found 9'):
420421
check_sampler_csv(csv_bad)
421422

422423
def test_check_sampler_csv_metric_1(self):
423424
csv_bad = os.path.join(DATAFILES_PATH, 'output_bad_metric_1.csv')
424-
with self.assertRaisesRegex(Exception, 'expecting metric'):
425+
with self.assertRaisesRegexNested(Exception, 'expecting metric'):
425426
check_sampler_csv(csv_bad)
426427

427428
def test_check_sampler_csv_metric_2(self):
428429
csv_bad = os.path.join(DATAFILES_PATH, 'output_bad_metric_2.csv')
429-
with self.assertRaisesRegex(Exception, 'invalid step size'):
430+
with self.assertRaisesRegexNested(Exception, 'invalid step size'):
430431
check_sampler_csv(csv_bad)
431432

432433
def test_check_sampler_csv_metric_3(self):
433434
csv_bad = os.path.join(DATAFILES_PATH, 'output_bad_metric_3.csv')
434-
with self.assertRaisesRegex(
435+
with self.assertRaisesRegexNested(
435436
Exception, 'invalid or missing mass matrix specification'
436437
):
437438
check_sampler_csv(csv_bad)
438439

439440
def test_check_sampler_csv_metric_4(self):
440441
csv_bad = os.path.join(DATAFILES_PATH, 'output_bad_metric_4.csv')
441-
with self.assertRaisesRegex(
442+
with self.assertRaisesRegexNested(
442443
Exception, 'invalid or missing mass matrix specification'
443444
):
444445
check_sampler_csv(csv_bad)
@@ -474,15 +475,17 @@ def test_check_sampler_csv_thin(self):
474475
self.assertEqual(dict['max_depth'], 11)
475476
self.assertEqual(dict['delta'], 0.98)
476477

477-
with self.assertRaisesRegex(ValueError, 'config error'):
478+
with self.assertRaisesRegexNested(ValueError, 'config error'):
478479
check_sampler_csv(
479480
path=csv_file,
480481
is_fixed_param=False,
481482
iter_sampling=490,
482483
iter_warmup=490,
483484
thin=9,
484485
)
485-
with self.assertRaisesRegex(ValueError, 'expected 490 draws, found 70'):
486+
with self.assertRaisesRegexNested(
487+
ValueError, 'expected 490 draws, found 70'
488+
):
486489
check_sampler_csv(
487490
path=csv_file,
488491
is_fixed_param=False,

0 commit comments

Comments
 (0)