File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -95,8 +95,10 @@ def __init__(
9595 # only valid when not is_fixed_param
9696 self ._metric : np .ndarray = np .array (())
9797 self ._step_size : np .ndarray = np .array (())
98- self ._divergences : np .ndarray = np .array (())
99- self ._max_treedepths : np .ndarray = np .array (())
98+ self ._divergences : np .ndarray = np .zeros (self .runset .chains , dtype = int )
99+ self ._max_treedepths : np .ndarray = np .zeros (
100+ self .runset .chains , dtype = int
101+ )
100102
101103 # info from CSV initial comments and header
102104 config = self ._validate_csv_files ()
@@ -285,14 +287,6 @@ def _validate_csv_files(self) -> Dict[str, Any]:
285287 Tabulates sampling iters which are divergent or at max treedepth
286288 Raises exception when inconsistencies detected.
287289 """
288- if not self ._is_fixed_param :
289- self ._divergences : np .ndarray = np .zeros (
290- self .runset .chains , dtype = int
291- )
292- self ._max_treedepths : np .ndarray = np .zeros (
293- self .runset .chains , dtype = int
294- )
295-
296290 dzero = {}
297291 for i in range (self .chains ):
298292 if i == 0 :
You can’t perform that action at this time.
0 commit comments