Skip to content

Commit af1dab0

Browse files
committed
hotfix(dataset): fix eval_mask in dataset.
* we only evaluated points on non-ground point. * no need print ssf_metrics since our train range is out of evaluation.
1 parent 494b916 commit af1dab0

2 files changed

Lines changed: 15 additions & 7 deletions

File tree

src/dataset.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -366,7 +366,14 @@ def __getitem__(self, index_):
366366

367367
if self.eval_index:
368368
# looks like v2 not follow the same rule as v1 with eval_mask provided
369-
data_dict['eval_mask'] = np.ones_like(data_dict['pc0'][:, 0], dtype=np.bool_) if 'eval_mask' not in f[key] else f[key]['eval_mask'][:]
369+
# data_dict['eval_mask'] = np.ones_like(data_dict['pc0'][:, 0], dtype=np.bool_) if 'eval_mask' not in f[key] else f[key]['eval_mask'][:]
370+
if 'eval_mask' in f[key]:
371+
data_dict['eval_mask'] = f[key]['eval_mask'][:]
372+
elif 'ground_mask' in f[key]:
373+
data_dict['eval_mask'] = ~f[key]['ground_mask'][:]
374+
else:
375+
data_dict['eval_mask'] = np.ones_like(data_dict['pc0'][:, 0], dtype=np.bool_)
376+
370377
if self.transform:
371378
data_dict = self.transform(data_dict)
372379
return data_dict

src/utils/eval_metric.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -342,7 +342,7 @@ def normalize(self):
342342

343343
self.epe_ssf['Mean'][motion] = np.nanmean(avg_epes)
344344

345-
def print(self):
345+
def print(self, ssf_metrics: bool = False):
346346
if not self.norm_flag:
347347
self.normalize()
348348
printed_data = []
@@ -357,8 +357,9 @@ def print(self):
357357
print("Version 2 Metric on Normalized Category-based:")
358358
print(tabulate(printed_data, headers=["Class", "Static", "Dynamic"], tablefmt='orgtbl'), "\n")
359359

360-
printed_data = []
361-
for key in self.epe_ssf:
362-
printed_data.append([key, np.around(self.epe_ssf[key]['Static'],4), np.around(self.epe_ssf[key]['Dynamic'],4), self.epe_ssf[key]["#Static"], self.epe_ssf[key]["#Dynamic"]])
363-
print("Version 3 Metric on EPE Distance-based:")
364-
print(tabulate(printed_data, headers=["Distance", "Static", "Dynamic", "#Static", "#Dynamic"], tablefmt='orgtbl'), "\n")
360+
if ssf_metrics:
361+
printed_data = []
362+
for key in self.epe_ssf:
363+
printed_data.append([key, np.around(self.epe_ssf[key]['Static'],4), np.around(self.epe_ssf[key]['Dynamic'],4), self.epe_ssf[key]["#Static"], self.epe_ssf[key]["#Dynamic"]])
364+
print("Version 3 Metric on EPE Distance-based:")
365+
print(tabulate(printed_data, headers=["Distance", "Static", "Dynamic", "#Static", "#Dynamic"], tablefmt='orgtbl'), "\n")

0 commit comments

Comments
 (0)