|
16 | 16 | import os, sys |
17 | 17 | BASE_DIR = os.path.abspath(os.path.join( os.path.dirname( __file__ ), '../..' )) |
18 | 18 | sys.path.append(BASE_DIR) |
| 19 | +from src.utils.av2_eval import CATEGORY_TO_INDEX, BUCKETED_METACATAGORIES |
19 | 20 |
|
| 21 | +# check: https://arxiv.org/abs/2508.17054 |
| 22 | +def deltaflowLoss(res_dict): |
| 23 | + pred = res_dict['est_flow'] |
| 24 | + gt = res_dict['gt_flow'] |
| 25 | + classes = res_dict['gt_classes'] |
| 26 | + instances = res_dict['gt_instance'] |
| 27 | + |
| 28 | + reassign_meta = torch.zeros_like(classes, dtype=torch.int, device=classes.device) |
| 29 | + for i, cats in enumerate(BUCKETED_METACATAGORIES): |
| 30 | + selected_classes_ids = [CATEGORY_TO_INDEX[cat] for cat in BUCKETED_METACATAGORIES[cats]] |
| 31 | + reassign_meta[torch.isin(classes, torch.tensor(selected_classes_ids, device=classes.device))] = i |
| 32 | + |
| 33 | + pts_loss = torch.linalg.vector_norm(pred - gt, dim=-1) |
| 34 | + speed = torch.linalg.vector_norm(gt, dim=-1) / 0.1 |
| 35 | + |
| 36 | + weight_loss = deflowLoss(res_dict)['loss'] |
| 37 | + |
| 38 | + classes_loss = 0.0 |
| 39 | + weight = [0.1, 1.0, 2.0, 2.5, 1.5] # BACKGROUND, CAR, PEDESTRIAN, WHEELED, OTHER |
| 40 | + for class_id in range(len(BUCKETED_METACATAGORIES)): |
| 41 | + mask = reassign_meta == class_id |
| 42 | + for loss_ in [0.1 * pts_loss[(speed < 0.4) & mask].mean(), |
| 43 | + 0.4 * pts_loss[(speed >= 0.4) & (speed <= 1.0) & mask].mean(), |
| 44 | + 0.5 * pts_loss[(speed > 1.0) & mask].mean()]: |
| 45 | + classes_loss += torch.nan_to_num(loss_, nan=0.0) * weight[class_id] |
| 46 | + |
| 47 | + instance_loss, cnt = 0.0, 0 |
| 48 | + if instances is not None: |
| 49 | + for instance_id in torch.unique(instances): |
| 50 | + mask = instances == instance_id |
| 51 | + reassign_meta_instance = reassign_meta[mask] |
| 52 | + class_id = torch.mode(reassign_meta_instance, 0).values.item() |
| 53 | + loss_ = pts_loss[mask].mean() |
| 54 | + if speed[mask].mean() <= 0.4: |
| 55 | + continue |
| 56 | + instance_loss += (loss_ * torch.exp(loss_) * weight[class_id]) |
| 57 | + cnt += 1 |
| 58 | + instance_loss /= (cnt if cnt > 0 else 1) |
| 59 | + return {'loss': weight_loss + classes_loss + instance_loss} |
| 60 | + |
| 61 | +# check: https://arxiv.org/abs/2401.16122 |
20 | 62 | def deflowLoss(res_dict): |
21 | 63 | pred = res_dict['est_flow'] |
22 | 64 | gt = res_dict['gt_flow'] |
|
0 commit comments