Skip to content

Commit a888ee2

Browse files
committed
loss(deltaflow): add deltaflow loss.
1 parent 477c580 commit a888ee2

3 files changed

Lines changed: 47 additions & 5 deletions

File tree

README.md

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
DeltaFlow
1+
DeltaFlow: An Efficient Multi-frame Scene Flow Estimation Method
22
---
33

44
[![arXiv](https://img.shields.io/badge/arXiv-2508.17054-b31b1b?logo=arxiv&logoColor=white)](https://arxiv.org/abs/2508.17054)
@@ -14,9 +14,8 @@ Note (2025/09/18): We got accepted by NeurIPS 2025 and it's **spotlighted**!
1414
- [x] 2025/08/24: Updating train data augmentation as illustrated in the DeltaFlow paper.
1515
- [x] 2025/08/25: Updating paper preprint link.
1616
- [x] 2025/09/05: Merged the latest commit from OpenSceneFlow codebase to DeltaFlow for afterward unified merged.
17-
- [x] 2025/09/25: DeltaFlow Model python file and config file.
17+
- [x] 2025/09/25: DeltaFlow Model file, config file and loss function. Update quick training example.
1818
- [ ] pre-trained weights upload.
19-
- [ ] DeltaFlow Loss fn.
2019
- [ ] Merged into [OpenSceneFlow](https://github.com/KTH-RPL/OpenSceneFlow)
2120

2221
## Quick Run
@@ -34,7 +33,7 @@ unzip demo-data-v2.zip -d /home/kin/data/av2/h5py # to your data path
3433

3534
3. Run the training with the following command (modify the data path accordingly):
3635
```bash
37-
python train.py model=deltaflow batch_size=4 num_frames=5 voxel_size="[0.15,0.15,0.15]" point_cloud_range="[-38.4,-38.4,-3,38.4,38.4,3]" optimizer.lr=2e-4 train_data=${demo_train_data_path} val_data=${demo_val_data_path}
36+
python train.py model=deltaflow loss_fn=deltaflowLoss batch_size=4 num_frames=5 voxel_size="[0.15,0.15,0.15]" point_cloud_range="[-38.4,-38.4,-3,38.4,38.4,3]" optimizer.lr=2e-4 train_data=${demo_train_data_path} val_data=${demo_val_data_path}
3837
```
3938
### Evaluation
4039

src/lossfuncs/supervise.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,49 @@
1616
import os, sys
1717
BASE_DIR = os.path.abspath(os.path.join( os.path.dirname( __file__ ), '../..' ))
1818
sys.path.append(BASE_DIR)
19+
from src.utils.av2_eval import CATEGORY_TO_INDEX, BUCKETED_METACATAGORIES
1920

21+
# check: https://arxiv.org/abs/2508.17054
22+
def deltaflowLoss(res_dict):
23+
pred = res_dict['est_flow']
24+
gt = res_dict['gt_flow']
25+
classes = res_dict['gt_classes']
26+
instances = res_dict['gt_instance']
27+
28+
reassign_meta = torch.zeros_like(classes, dtype=torch.int, device=classes.device)
29+
for i, cats in enumerate(BUCKETED_METACATAGORIES):
30+
selected_classes_ids = [CATEGORY_TO_INDEX[cat] for cat in BUCKETED_METACATAGORIES[cats]]
31+
reassign_meta[torch.isin(classes, torch.tensor(selected_classes_ids, device=classes.device))] = i
32+
33+
pts_loss = torch.linalg.vector_norm(pred - gt, dim=-1)
34+
speed = torch.linalg.vector_norm(gt, dim=-1) / 0.1
35+
36+
weight_loss = deflowLoss(res_dict)['loss']
37+
38+
classes_loss = 0.0
39+
weight = [0.1, 1.0, 2.0, 2.5, 1.5] # BACKGROUND, CAR, PEDESTRIAN, WHEELED, OTHER
40+
for class_id in range(len(BUCKETED_METACATAGORIES)):
41+
mask = reassign_meta == class_id
42+
for loss_ in [0.1 * pts_loss[(speed < 0.4) & mask].mean(),
43+
0.4 * pts_loss[(speed >= 0.4) & (speed <= 1.0) & mask].mean(),
44+
0.5 * pts_loss[(speed > 1.0) & mask].mean()]:
45+
classes_loss += torch.nan_to_num(loss_, nan=0.0) * weight[class_id]
46+
47+
instance_loss, cnt = 0.0, 0
48+
if instances is not None:
49+
for instance_id in torch.unique(instances):
50+
mask = instances == instance_id
51+
reassign_meta_instance = reassign_meta[mask]
52+
class_id = torch.mode(reassign_meta_instance, 0).values.item()
53+
loss_ = pts_loss[mask].mean()
54+
if speed[mask].mean() <= 0.4:
55+
continue
56+
instance_loss += (loss_ * torch.exp(loss_) * weight[class_id])
57+
cnt += 1
58+
instance_loss /= (cnt if cnt > 0 else 1)
59+
return {'loss': weight_loss + classes_loss + instance_loss}
60+
61+
# check: https://arxiv.org/abs/2401.16122
2062
def deflowLoss(res_dict):
2163
pred = res_dict['est_flow']
2264
gt = res_dict['gt_flow']

src/trainer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,8 @@ def training_step(self, batch, batch_idx):
132132

133133
dict2loss = {'est_flow': est_flow[batch_id],
134134
'gt_flow': None if 'flow' not in batch else batch['flow'][batch_id][pc0_valid_from_pc2res] - pose_flow_,
135-
'gt_classes': None if 'flow_category_indices' not in batch else batch['flow_category_indices'][batch_id][pc0_valid_from_pc2res]}
135+
'gt_classes': None if 'flow_category_indices' not in batch else batch['flow_category_indices'][batch_id][pc0_valid_from_pc2res],
136+
'gt_instance': None if 'flow_instance_id' not in batch else batch['flow_instance_id'][batch_id][pc0_valid_from_pc2res],}
136137

137138
if 'pc0_dynamic' in batch:
138139
dict2loss['pc0_labels'] = batch['pc0_dynamic'][batch_id][pc0_valid_from_pc2res]

0 commit comments

Comments
 (0)