Skip to content

Commit bda136f

Browse files
committed
changes per code review
1 parent 54ed233 commit bda136f

3 files changed

Lines changed: 59 additions & 47 deletions

File tree

cmdstanpy/stanfit/mcmc.py

Lines changed: 34 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -95,33 +95,14 @@ def __init__(
9595
# only valid when not is_fixed_param
9696
self._metric: np.ndarray = np.array(())
9797
self._step_size: np.ndarray = np.array(())
98-
self._divergences: np.ndarray = np.zeros(self.runset.chains, int)
99-
self._max_treedepths: np.ndarray = np.zeros(self.runset.chains, int)
98+
self._divergences: np.ndarray = np.array(())
99+
self._max_treedepths: np.ndarray = np.array(())
100100

101101
# info from CSV initial comments and header
102102
config = self._validate_csv_files()
103103
self._metadata: InferenceMetadata = InferenceMetadata(config)
104-
# prelim diagnostics
105-
if np.any(self._divergences) or np.any(self._max_treedepths):
106-
diagnostics = ['Some chains may have failed to converge.']
107-
ct_iters = config['num_samples'] # pylint: disable=unused-variable
108-
for i in range(self.runset._chains):
109-
if self._divergences[i] > 0:
110-
diagnostics.append(
111-
f'Chain {i + 1} had {self._divergences[i]} '
112-
'divergent transitions '
113-
f'({((self._divergences[i]/ct_iters)*100):.1f}%)'
114-
)
115-
if self._max_treedepths[i] > 0:
116-
diagnostics.append(
117-
f'Chain {i + 1} had {self._max_treedepths[i]} '
118-
'iterations at max treedepth '
119-
f'({((self._max_treedepths[i]/ct_iters)*100):.1f}%)'
120-
)
121-
diagnostics.append(
122-
'Use function "diagnose()" to see further information.'
123-
)
124-
get_logger().warning('\n\t'.join(diagnostics))
104+
if not self._is_fixed_param:
105+
self._check_sampler_diagnostics()
125106

126107
def __repr__(self) -> str:
127108
repr = 'CmdStanMCMC: model={} chains={}{}'.format(
@@ -304,6 +285,10 @@ def _validate_csv_files(self) -> Dict[str, Any]:
304285
Tabulates sampling iters which are divergent or at max treedepth
305286
Raises exception when inconsistencies detected.
306287
"""
288+
if not self._is_fixed_param:
289+
self._divergences: np.ndarray = np.zeros(self.runset.chains, dtype=int)
290+
self._max_treedepths: np.ndarray = np.zeros(self.runset.chains, dtype=int)
291+
307292
dzero = {}
308293
for i in range(self.chains):
309294
if i == 0:
@@ -359,6 +344,32 @@ def _validate_csv_files(self) -> Dict[str, Any]:
359344
self._max_treedepths[i] = drest['ct_max_treedepth']
360345
return dzero
361346

347+
# pylint: disable=unused-variable
348+
def _check_sampler_diagnostics(self) -> None:
349+
"""
350+
Warn if any iterations ended in divergences or hit maxtreedepth.
351+
"""
352+
if np.any(self._divergences) or np.any(self._max_treedepths):
353+
diagnostics = ['Some chains may have failed to converge.']
354+
ct_iters = self.metadata.cmdstan_config['num_samples']
355+
for i in range(self.runset._chains):
356+
if self._divergences[i] > 0:
357+
diagnostics.append(
358+
f'Chain {i + 1} had {self._divergences[i]} '
359+
'divergent transitions '
360+
f'({((self._divergences[i]/ct_iters)*100):.1f}%)'
361+
)
362+
if self._max_treedepths[i] > 0:
363+
diagnostics.append(
364+
f'Chain {i + 1} had {self._max_treedepths[i]} '
365+
'iterations at max treedepth '
366+
f'({((self._max_treedepths[i]/ct_iters)*100):.1f}%)'
367+
)
368+
diagnostics.append(
369+
'Use function "diagnose()" to see further information.'
370+
)
371+
get_logger().warning('\n\t'.join(diagnostics))
372+
362373
def _assemble_draws(self) -> None:
363374
"""
364375
Allocates and populates the step size, metric, and sample arrays

cmdstanpy/utils.py

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -654,7 +654,7 @@ def scan_sampler_csv(path: str, is_fixed_param: bool = False) -> Dict[str, Any]:
654654
if not is_fixed_param:
655655
lineno = scan_warmup_iters(fd, dict, lineno)
656656
lineno = scan_hmc_params(fd, dict, lineno)
657-
lineno = scan_sampling_iters(fd, dict, lineno)
657+
lineno = scan_sampling_iters(fd, dict, lineno, is_fixed_param)
658658
except ValueError as e:
659659
raise ValueError("Error in reading csv file: " + path) from e
660660
return dict
@@ -957,25 +957,21 @@ def scan_hmc_params(
957957

958958

959959
def scan_sampling_iters(
960-
fd: TextIO, config_dict: Dict[str, Any], lineno: int
960+
fd: TextIO, config_dict: Dict[str, Any], lineno: int, is_fixed_param: bool
961961
) -> int:
962962
"""
963963
Parse sampling iteration, save number of iterations to config_dict.
964964
Also save number of divergences, max_treedepth hits
965965
"""
966966
draws_found = 0
967967
num_cols = len(config_dict['column_names'])
968-
idx_divergent = None
969-
idx_treedepth = None
970-
max_treedepth = None
971-
ct_divergences = 0
972-
ct_max_treedepth = 0
973-
try:
968+
if not is_fixed_param:
974969
idx_divergent = config_dict['column_names'].index('divergent__')
975970
idx_treedepth = config_dict['column_names'].index('treedepth__')
976971
max_treedepth = config_dict['max_depth']
977-
except ValueError:
978-
pass
972+
ct_divergences = 0
973+
ct_max_treedepth = 0
974+
979975
cur_pos = fd.tell()
980976
line = fd.readline().strip()
981977
while len(line) > 0 and not line.startswith('#'):
@@ -991,17 +987,18 @@ def scan_sampling_iters(
991987
'Try clearing up TEMP or setting output_dir to a path'
992988
' on another drive.',
993989
)
994-
if max_treedepth:
990+
cur_pos = fd.tell()
991+
line = fd.readline().strip()
992+
if not is_fixed_param:
995993
ct_divergences += int(data[idx_divergent]) # type: ignore
996994
if int(data[idx_treedepth]) == max_treedepth: # type: ignore
997995
ct_max_treedepth += 1
998-
cur_pos = fd.tell()
999-
line = fd.readline().strip()
996+
997+
fd.seek(cur_pos)
1000998
config_dict['draws_sampling'] = draws_found
1001-
if max_treedepth:
999+
if not is_fixed_param:
10021000
config_dict['ct_divergences'] = ct_divergences
10031001
config_dict['ct_max_treedepth'] = ct_max_treedepth
1004-
fd.seek(cur_pos)
10051002
return lineno
10061003

10071004

test/test_sample.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from time import time
1515

1616
import numpy as np
17-
from testfixtures import LogCapture
17+
from testfixtures import LogCapture, StringComparison
1818

1919
try:
2020
import ujson as json
@@ -488,6 +488,8 @@ def test_fixed_param_good(self):
488488
self.assertEqual(datagen_fit.metric_type, None)
489489
self.assertEqual(datagen_fit.metric, None)
490490
self.assertEqual(datagen_fit.step_size, None)
491+
self.assertEqual(datagen_fit.divergences, None)
492+
self.assertEqual(datagen_fit.max_treedepths, None)
491493

492494
for i in range(datagen_fit.runset.chains):
493495
csv_file = datagen_fit.runset.csv_files[i]
@@ -1805,13 +1807,12 @@ def test_diagnostics(self):
18051807
fit = model.sample(
18061808
data=rdata,
18071809
seed=55157,
1808-
show_progress=False,
1809-
show_console=False,
1810-
)
1811-
msg = log.actual()[-1][-1]
1812-
self.assertTrue(
1813-
msg.startswith('Some chains may have failed to converge.')
18141810
)
1811+
log.check_present((
1812+
'cmdstanpy',
1813+
'WARNING',
1814+
StringComparison(r'(?s).*Some chains may have failed to converge.*')
1815+
))
18151816
self.assertFalse(np.all(fit.divergences == 0))
18161817

18171818
with LogCapture(level=logging.WARNING) as log:
@@ -1821,8 +1822,11 @@ def test_diagnostics(self):
18211822
seed=40508,
18221823
max_treedepth=3,
18231824
)
1824-
msg = log.actual()[-1][-1]
1825-
self.assertTrue('max treedepth' in msg)
1825+
log.check_present((
1826+
'cmdstanpy',
1827+
'WARNING',
1828+
StringComparison(r'(?s).*max treedepth*')
1829+
))
18261830
self.assertFalse(np.all(fit.max_treedepths == 0))
18271831

18281832
stan = os.path.join(DATAFILES_PATH, 'bernoulli.stan')

0 commit comments

Comments
 (0)