Skip to content

Commit 0aed861

Browse files
committed
checkpointing; needs TODOs plus unit tests
1 parent b666000 commit 0aed861

2 files changed

Lines changed: 20 additions & 8 deletions

File tree

cmdstanpy/stanfit.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1255,7 +1255,7 @@ def _set_mle_attrs(self, sample_csv_0: str) -> None:
12551255
meta = scan_optimize_csv(sample_csv_0)
12561256
self._metadata = InferenceMetadata(meta)
12571257
self._column_names: Tuple[str, ...] = meta['column_names']
1258-
self._mle: List[float] = meta['mle']
1258+
self._mle = meta['mle']
12591259

12601260
@property
12611261
def column_names(self) -> Tuple[str, ...]:
@@ -1281,7 +1281,8 @@ def optimized_params_np(self) -> np.ndarray:
12811281
get_logger().warning(
12821282
'invalid estimate, optimization failed to converge'
12831283
)
1284-
return np.asarray(self._mle)
1284+
# TODO: squeeze?
1285+
return self._mle
12851286

12861287
@property
12871288
def optimized_params_pd(self) -> pd.DataFrame:
@@ -1290,7 +1291,7 @@ def optimized_params_pd(self) -> pd.DataFrame:
12901291
get_logger().warning(
12911292
'invalid estimate, optimization failed to converge'
12921293
)
1293-
return pd.DataFrame([self._mle], columns=self.column_names)
1294+
return pd.DataFrame(self._mle, columns=self.column_names)
12941295

12951296
@property
12961297
def optimized_params_dict(self) -> Dict[str, float]:
@@ -1299,6 +1300,7 @@ def optimized_params_dict(self) -> Dict[str, float]:
12991300
get_logger().warning(
13001301
'invalid estimate, optimization failed to converge'
13011302
)
1303+
# TODO: return final estimate only
13021304
return OrderedDict(zip(self.column_names, self._mle))
13031305

13041306
def stan_variable(
@@ -1343,6 +1345,7 @@ def stan_variable(
13431345
)
13441346

13451347
col_idxs = list(self._metadata.stan_vars_cols[var])
1348+
# TODO: return final estimate only
13461349
vals = list(self._mle)
13471350
xs = [vals[x] for x in col_idxs]
13481351
shape: Tuple[int, ...] = ()

cmdstanpy/utils.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -590,14 +590,23 @@ def scan_optimize_csv(path: str, save_iters: bool = False) -> Dict[str, Any]:
590590
for line in fd:
591591
iters += 1
592592
# rescan to capture estimates
593+
mle = np.empty(
594+
(iters, len(dict['column_names'])), dtype=float, order='F'
595+
)
593596
with open(path, 'r') as fd:
594597
for i in range(lineno):
595598
fd.readline()
596-
line = fd.readline()
597-
# allocate numpy ndarray for
598-
print(line)
599-
xs = line.split(',')
600-
dict['mle'] = [float(x) for x in xs]
599+
for i in range(iters):
600+
line = fd.readline().strip()
601+
if len(line) < 1:
602+
raise ValueError(
603+
'cannot parse CSV file {}, error at line {}'.format(
604+
path, lineno + i
605+
)
606+
)
607+
xs = line.split(',')
608+
mle[i, :] = [float(x) for x in xs]
609+
dict['mle'] = mle
601610
return dict
602611

603612

0 commit comments

Comments
 (0)