Skip to content

Commit 6659c89

Browse files
committed
feat(loss): update seflow loss calculation.
some notes for pointing out the equation to the paper.
1 parent 3b11c10 commit 6659c89

10 files changed

Lines changed: 226 additions & 92 deletions

File tree

1_train.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,11 @@
2828

2929
@hydra.main(version_base=None, config_path="conf", config_name="config")
3030
def main(cfg):
31+
if cfg.loss_fn == 'seflowLoss' and cfg.add_seloss is None:
32+
raise ValueError("Please specify the self-supervised loss items for seflowLoss.")
3133
pl.seed_everything(cfg.seed, workers=True)
32-
output_dir = HydraConfig.get().runtime.output_dir
3334

34-
train_dataset = HDF5Dataset(cfg.train_data)
35+
train_dataset = HDF5Dataset(cfg.train_data, dufo=(cfg.loss_fn == 'seflowLoss'))
3536
train_loader = DataLoader(train_dataset,
3637
batch_size=cfg.batch_size,
3738
shuffle=True,
@@ -48,7 +49,14 @@ def main(cfg):
4849

4950
# count gpus, overwrite gpus
5051
cfg.gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0
51-
model_name = cfg.model.name
52+
53+
# only for logging on folder name.
54+
if cfg.loss_fn == 'seflowLoss':
55+
method_name = "seflow"
56+
cfg.output = cfg.output.replace("deflow", "seflow")
57+
else:
58+
method_name = cfg.model.name
59+
output_dir = HydraConfig.get().runtime.output_dir + f"/{cfg.output}"
5260
Path(os.path.join(output_dir, "checkpoints")).mkdir(parents=True, exist_ok=True)
5361

5462
cfg = DictConfig(OmegaConf.to_container(cfg, resolve=True))
@@ -57,7 +65,7 @@ def main(cfg):
5765
callbacks = [
5866
ModelCheckpoint(
5967
dirpath=os.path.join(output_dir, "checkpoints"),
60-
filename="{epoch:02d}_"+model_name,
68+
filename="{epoch:02d}_"+method_name,
6169
auto_insert_metric_name=False,
6270
monitor=cfg.model.val_monitor,
6371
mode="min",
@@ -90,6 +98,9 @@ def main(cfg):
9098
print("Initiating wandb and trainer successfully. ^V^ ")
9199
print(f"We will use {cfg.gpus} GPUs to train the model. Check the checkpoints in {output_dir} checkpoints folder.")
92100
print("Total Train Dataset Size: ", len(train_dataset))
101+
if cfg.add_seloss is not None and cfg.loss_fn == 'seflowLoss':
102+
print(f"Note: We are in **self-supervised** training now. No ground truth label is used.")
103+
print(f"We will use these loss items in {cfg.loss_fn}: {cfg.add_seloss}")
93104
print("-"*40+"\n")
94105

95106
# NOTE(Qingwen): search & check: def training_step(self, batch, batch_idx)

README.md

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ SeFlow: A Self-Supervised Scene Flow Method in Autonomous Driving
66
[poster coming soon]
77
[video coming soon]
88

9-
2024/07/05 11:35: I'm working on updating code here now. **Not fully ready yet** until Jul'15.
9+
2024/07/07 13:45: I'm working on updating code here now. **Not fully ready yet** until Jul'15.
1010

1111
Pre-trained weights for models are available in [Zenodo](https://zenodo.org/records/12632962) link. Check usage in [2. Evaluation](#2-evaluation) or [3. Visualization](#3-visualization).
1212

@@ -60,7 +60,19 @@ docker run -it --gpus all -v /dev/shm:/dev/shm -v /home/kin/data:/home/kin/data
6060

6161
Note: Preparing raw data and processing the train data only need to be run once per task. There is no need to re-run them unless you delete the processed data.
6262

63-
### Prepare raw data
63+
### Data Preparation
64+
65+
Check [dataprocess/README.md](dataprocess/README.md#argoverse-20) for downloading tips for the raw Argoverse 2 dataset.
66+
67+
If you only want a mini processed dataset to try the code quickly, we directly provide one scene inside `train` and `val`. It is already converted to `.h5` format and processed with the label data.
68+
<!-- You can download it from [Zenodo](https://zenodo.org/record/12632962) and extract it to the data folder. -->
69+
```bash
70+
# TODO: update the link later when the data is ready
71+
# wget https://zenodo.org/record/12632962/files/demo_data.zip
72+
unzip demo_data.zip -d /home/kin/data/av2
73+
```
74+
75+
#### Prepare raw data
6476

6577
Extract all data to a unified h5 format. [Runtime: normally ~10 mins to finish running the following commands on my desktop, ~45 mins on the cluster I used]
6678
```bash
@@ -69,7 +81,7 @@ python dataprocess/extract_av2.py --av2_type sensor --data_mode val --mask_dir /
6981
python dataprocess/extract_av2.py --av2_type sensor --data_mode test --mask_dir /home/kin/data/av2/3d_scene_flow
7082
```
7183

72-
### Process train data
84+
#### Process train data
7385

7486
Process the train data for self-supervised learning. Only the training split needs this step. [Runtime: normally ~15 hours on my desktop, ~3 hours on the cluster with five nodes running in parallel.]
7587

@@ -85,6 +97,13 @@ Train SeFlow needed to specify the loss function, we set the config of our best
8597
python 1_train.py model=deflow lr=2e-4 epochs=20 batch_size=16 loss_fn=seflowLoss "add_seloss={chamfer_dis: 1.0, static_flow_loss: 1.0, dynamic_chamfer_dis: 1.0, cluster_based_pc0pc1: 1.0}" "model.target.num_iters=2" "model.val_monitor=val/Dynamic/Mean"
8698
```
8799

100+
### Other Benchmark Models
101+
102+
```bash
103+
python 1_train.py model=fastflow3d lr=2e-4 epochs=20 batch_size=16 loss_fn=deflowLoss
104+
python 1_train.py model=deflow lr=2e-4 epochs=20 batch_size=16 loss_fn=ff3dLoss
105+
```
106+
88107
## 2. Evaluation
89108

90109
You can view Wandb dashboard for the training and evaluation results or upload result to online leaderboard.
@@ -95,8 +114,9 @@ Since in training, we save all hyper-parameters and model checkpoints, the only
95114
# downloaded pre-trained weight, or train by yourself
96115
wget https://zenodo.org/records/12632962/files/seflow_official.ckpt
97116

117+
# it will directly print all metrics
118+
python 2_eval.py checkpoint=/home/kin/seflow_official.ckpt av2_mode=val
98119

99-
python 2_eval.py checkpoint=/home/kin/seflow_official.ckpt av2_mode=val # it will directly prints all metric
100120
# it will output the av2_submit.zip or av2_submit_v2.zip for you to submit to leaderboard
101121
python 2_eval.py checkpoint=/home/kin/seflow_official.ckpt av2_mode=test leaderboard_version=1
102122
python 2_eval.py checkpoint=/home/kin/seflow_official.ckpt av2_mode=test leaderboard_version=2

assets/cuda/mmcv/setup.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,21 @@
1111
name='mmcv._ext',
1212
sources=[
1313
"/".join(__file__.split("/")[:-1] + ["scatter_points_cuda.cu"]),
14+
"/".join(__file__.split("/")[:-1] + ["scatter_points.cpp"]),
1415
"/".join(__file__.split("/")[:-1] + ["voxelization_cuda.cu"]),
1516
"/".join(__file__.split("/")[:-1] + ["voxelization.cpp"]),
16-
"/".join(__file__.split("/")[:-1] + ["scatter_points.cpp"]),
1717
"/".join(__file__.split("/")[:-1] + ["cudabind.cpp"]),
1818
"/".join(__file__.split("/")[:-1] + ["pybind.cpp"]),
1919

20-
]),
21-
# extra_compile_args={'cxx': ['-g'], 'nvcc': ['-O2']}
20+
],
21+
# extra_compile_args={
22+
# 'cxx': ['-std=c++17'],
23+
# 'nvcc': ['-std=c++17',
24+
# '-D__CUDA_NO_HALF_OPERATORS__',
25+
# '-D__CUDA_NO_HALF_CONVERSIONS__',
26+
# '-D__CUDA_NO_HALF2_OPERATORS__',
27+
# ],}
28+
),
2229
],
2330
cmdclass={'build_ext': BuildExtension},
2431

assets/slurm/1_train.sh

Lines changed: 17 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
#!/bin/bash
2-
#SBATCH -J deflow
3-
#SBATCH --gpus 8 -C "thin"
2+
#SBATCH -J seflow
3+
#SBATCH --gpus 4 -C "fat"
44
#SBATCH -t 3-00:00:00
55
#SBATCH --mail-type=END,FAIL
66
#SBATCH --mail-user=qingwen@kth.se
7-
#SBATCH --output /proj/berzelius-2023-154/users/x_qinzh/deflow/logs/slurm/%J_deflow.out
8-
#SBATCH --error /proj/berzelius-2023-154/users/x_qinzh/deflow/logs/slurm/%J_deflow.err
7+
#SBATCH --output /proj/berzelius-2023-154/users/x_qinzh/seflow/logs/slurm/%J_seflow.out
8+
#SBATCH --error /proj/berzelius-2023-154/users/x_qinzh/seflow/logs/slurm/%J_seflow.err
99

10-
cd /proj/berzelius-2023-154/users/x_qinzh/deflow
10+
cd /proj/berzelius-2023-154/users/x_qinzh/seflow
1111

12-
SOURCE="/proj/berzelius-2023-154/users/x_qinzh/av2/deflow_preprocess"
12+
SOURCE="/proj/berzelius-2023-154/users/x_qinzh/data/av2/seflow_preprocess"
1313
DEST="/scratch/local/av2"
1414
SUBDIRS=("sensor/train" "sensor/val")
1515

@@ -24,55 +24,14 @@ elapsed=$((end_time - start_time))
2424
echo "Copy ${SOURCE} to ${DEST} Total time: ${elapsed} seconds"
2525
echo "Start training..."
2626

27-
# ====> leaderboard model = [fastflow3d, deflow]
28-
# /proj/berzelius-2023-154/users/x_qinzh/mambaforge/envs/deflow/bin/python 1_train.py \
29-
# slurm_id=$SLURM_JOB_ID wandb_mode=online dataset_path=/scratch/local/av2/sensor \
30-
# num_workers=16 model=deflow lr=2e-6 epochs=50 batch_size=10 loss_fn=deflowLoss
31-
32-
# /proj/berzelius-2023-154/users/x_qinzh/mambaforge/envs/deflow/bin/python 1_train.py \
33-
# slurm_id=$SLURM_JOB_ID wandb_mode=online dataset_path=/scratch/local/av2/sensor \
34-
# num_workers=16 model=fastflow3d lr=2e-6 epochs=50 batch_size=16 loss_fn=ff3dLoss
35-
36-
37-
38-
39-
# ===> ablation A: iteration num [2, 4 (R), 8, 16]
40-
# /proj/berzelius-2023-154/users/x_qinzh/mambaforge/envs/deflow/bin/python 1_train.py \
41-
# slurm_id=$SLURM_JOB_ID wandb_mode=online dataset_path=/scratch/local/av2/sensor \
42-
# num_workers=16 model=deflow lr=2e-6 epochs=50 batch_size=10 loss_fn=deflowLoss "model.target.num_iters=2"
43-
44-
# /proj/berzelius-2023-154/users/x_qinzh/mambaforge/envs/deflow/bin/python 1_train.py \
45-
# slurm_id=$SLURM_JOB_ID wandb_mode=online dataset_path=/scratch/local/av2/sensor \
46-
# num_workers=16 model=deflow lr=2e-6 epochs=50 batch_size=8 loss_fn=deflowLoss "model.target.num_iters=8"
47-
48-
# /proj/berzelius-2023-154/users/x_qinzh/mambaforge/envs/deflow/bin/python 1_train.py \
49-
# slurm_id=$SLURM_JOB_ID wandb_mode=online dataset_path=/scratch/local/av2/sensor \
50-
# num_workers=16 model=deflow lr=2e-6 epochs=50 batch_size=10 loss_fn=deflowLoss "model.target.num_iters=16"
51-
52-
53-
# ===> ablation B: loss_fn --- loss_fn = [ff3dLoss (R), zeroflowLoss, deflowLoss]
54-
# /proj/berzelius-2023-154/users/x_qinzh/mambaforge/envs/deflow/bin/python 1_train.py \
55-
# slurm_id=$SLURM_JOB_ID wandb_mode=online dataset_path=/scratch/local/av2/sensor \
56-
# num_workers=16 model=fastflow3d lr=2e-6 epochs=50 batch_size=16 loss_fn=zeroflowLoss
57-
58-
# /proj/berzelius-2023-154/users/x_qinzh/mambaforge/envs/deflow/bin/python 1_train.py \
59-
# slurm_id=$SLURM_JOB_ID wandb_mode=online dataset_path=/scratch/local/av2/sensor \
60-
# num_workers=16 model=fastflow3d lr=2e-6 epochs=50 batch_size=16 loss_fn=deflowLoss
61-
62-
63-
# ===> ablation C: decoder --- model.target.decoder_option = [linear, gru] and fastflow3d resolution [0.1, 0.2 (R), 0.4]
64-
# /proj/berzelius-2023-154/users/x_qinzh/mambaforge/envs/deflow/bin/python 1_train.py \
65-
# slurm_id=$SLURM_JOB_ID wandb_mode=online dataset_path=/scratch/local/av2/sensor \
66-
# num_workers=16 model=deflow lr=2e-6 epochs=50 batch_size=10 loss_fn=ff3dLoss "model.target.decoder_option=linear"
67-
68-
# /proj/berzelius-2023-154/users/x_qinzh/mambaforge/envs/deflow/bin/python 1_train.py \
69-
# slurm_id=$SLURM_JOB_ID wandb_mode=online dataset_path=/scratch/local/av2/sensor \
70-
# num_workers=16 model=deflow lr=2e-6 epochs=50 batch_size=10 loss_fn=ff3dLoss "model.target.decoder_option=gru"
71-
72-
# /proj/berzelius-2023-154/users/x_qinzh/mambaforge/envs/deflow/bin/python 1_train.py \
73-
# slurm_id=$SLURM_JOB_ID wandb_mode=online dataset_path=/scratch/local/av2/sensor \
74-
# num_workers=16 model=fastflow3d lr=2e-6 epochs=50 batch_size=10 loss_fn=ff3dLoss "voxel_size=[0.1, 0.1, 6]"
75-
76-
# /proj/berzelius-2023-154/users/x_qinzh/mambaforge/envs/deflow/bin/python 1_train.py \
77-
# slurm_id=$SLURM_JOB_ID wandb_mode=online dataset_path=/scratch/local/av2/sensor \
78-
# num_workers=16 model=fastflow3d lr=2e-6 epochs=50 batch_size=16 loss_fn=ff3dLoss "voxel_size=[0.4, 0.4, 6]"
27+
# ====> paper model = seflow_official
28+
# /proj/berzelius-2023-154/users/x_qinzh/mambaforge/envs/seflow/bin/python 1_train.py \
29+
# slurm_id=$SLURM_JOB_ID wandb_mode=online train_data=/scratch/local/av2/sensor/train val_data=/scratch/local/av2/sensor/val \
30+
# num_workers=16 model=deflow lr=2e-6 epochs=50 batch_size=20 "model.target.num_iters=2" "model.val_monitor=val/Dynamic/Mean" \
31+
# loss_fn=seflowLoss "add_seloss={chamfer_dis: 1.0, static_flow_loss: 1.0, dynamic_chamfer_dis: 1.0, cluster_based_pc0pc1: 1.0}"
32+
33+
# ====> leaderboard model = seflow_best
34+
/proj/berzelius-2023-154/users/x_qinzh/mambaforge/envs/seflow/bin/python 1_train.py \
35+
slurm_id=$SLURM_JOB_ID wandb_mode=online train_data=/scratch/local/av2/sensor/train val_data=/scratch/local/av2/sensor/val \
36+
num_workers=16 model=deflow lr=2e-4 epochs=20 batch_size=16 "model.target.num_iters=2" "model.val_monitor=val/Dynamic/Mean" \
37+
loss_fn=seflowLoss "add_seloss={chamfer_dis: 1.0, static_flow_loss: 1.0, dynamic_chamfer_dis: 1.0, cluster_based_pc0pc1: 1.0}"

conf/config.yaml

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,8 @@ gradient_clip_val: 5.0
2727

2828
# optimizer ==> Adam
2929
lr: 2e-6
30-
loss_fn: deflowLoss # choices: [ff3dLoss, zeroflowLoss, deflowLoss, seflowLoss]
31-
add_seloss: {chamfer_dis: 1.0, static_flow_loss: 1.0, dynamic_chamfer_dis: 1.0, cluster_based_pc0pc1: 1.0}
32-
label_name: label # choices: [label, nnd_label_plus]
30+
loss_fn: seflowLoss # choices: [ff3dLoss, zeroflowLoss, deflowLoss, seflowLoss]
31+
add_seloss:
3332

3433
# log settings
3534
seed: 42069

conf/hydra/default.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
run:
2-
dir: logs/wandb/${output}
2+
dir: logs/wandb

dataprocess/extract_av2.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,8 @@ def create_eval_mask(data_mode: str, output_dir_: Path, mask_dir: str):
6767
timestamps = sorted([int(file.replace('.feather', ''))
6868
for file in os.listdir(Path(mask_dir) / f"{data_mode}-masks" / scene_id)
6969
if file.endswith('.feather')])
70+
if not os.path.exists(output_dir_ / f'{scene_id}.h5'):
71+
continue
7072
with h5py.File(output_dir_ / f'{scene_id}.h5', 'r+') as f:
7173
for ts in timestamps:
7274
key = str(ts)

scripts/network/dataloader.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,20 @@ def collate_fn_pad(batch):
4040

4141
if 'ego_motion' in batch[0]:
4242
res_dict['ego_motion'] = [batch[i]['ego_motion'] for i in range(len(batch))]
43+
44+
if 'pc0_dynamic' in batch[0]:
45+
pc0_dynamic_after_mask_ground, pc1_dynamic_after_mask_ground= [], []
46+
for i in range(len(batch)):
47+
pc0_dynamic_after_mask_ground.append(batch[i]['pc0_dynamic'][~batch[i]['gm0']])
48+
pc1_dynamic_after_mask_ground.append(batch[i]['pc1_dynamic'][~batch[i]['gm1']])
49+
pc0_dynamic_after_mask_ground = torch.nn.utils.rnn.pad_sequence(pc0_dynamic_after_mask_ground, batch_first=True, padding_value=0)
50+
pc1_dynamic_after_mask_ground = torch.nn.utils.rnn.pad_sequence(pc1_dynamic_after_mask_ground, batch_first=True, padding_value=0)
51+
res_dict['pc0_dynamic'] = pc0_dynamic_after_mask_ground
52+
res_dict['pc1_dynamic'] = pc1_dynamic_after_mask_ground
4353

4454
return res_dict
4555
class HDF5Dataset(Dataset):
46-
def __init__(self, directory, eval = False, leaderboard_version=1):
56+
def __init__(self, directory, dufo=False, eval = False, leaderboard_version=1):
4757
'''
4858
directory: the directory of the dataset
4959
eval: if True, use the eval index
@@ -55,6 +65,7 @@ def __init__(self, directory, eval = False, leaderboard_version=1):
5565
self.data_index = pickle.load(f)
5666

5767
self.eval_index = False
68+
self.dufo = dufo
5869
if eval:
5970
index_file_name = 'index_eval.pkl'
6071
if leaderboard_version == 2:
@@ -106,12 +117,12 @@ def __getitem__(self, index_):
106117

107118
key = str(timestamp)
108119
with h5py.File(os.path.join(self.directory, f'{scene_id}.h5'), 'r') as f:
109-
pc0 = torch.tensor(f[key]['lidar'][:])
120+
pc0 = torch.tensor(f[key]['lidar'][:][:,:3])
110121
gm0 = torch.tensor(f[key]['ground_mask'][:])
111122
pose0 = torch.tensor(f[key]['pose'][:])
112123

113124
next_timestamp = str(self.data_index[index_+1][1])
114-
pc1 = torch.tensor(f[next_timestamp]['lidar'][:])
125+
pc1 = torch.tensor(f[next_timestamp]['lidar'][:][:,:3])
115126
gm1 = torch.tensor(f[next_timestamp]['ground_mask'][:])
116127
pose1 = torch.tensor(f[next_timestamp]['pose'][:])
117128
# if pc0[~gm0].shape[0] == 0:
@@ -143,10 +154,15 @@ def __getitem__(self, index_):
143154
ego_motion = torch.tensor(f[key]['ego_motion'][:])
144155
res_dict['ego_motion'] = ego_motion
145156

157+
if self.dufo:
158+
res_dict['pc0_dynamic'] = torch.tensor(f[key]['label'][:].astype('int16'))
159+
res_dict['pc1_dynamic'] = torch.tensor(f[next_timestamp]['label'][:].astype('int16'))
160+
146161
if self.eval_index:
147162
# looks like v2 not follow the same rule as v1 with eval_mask provided
148163
eval_mask = torch.tensor(f[key]['eval_mask'][:]) if 'eval_mask' in f[key] else torch.ones_like(pc0[:, 0], dtype=torch.bool)
149164
res_dict['eval_mask'] = eval_mask
165+
150166
return res_dict
151167

152168
if __name__ == "__main__":

0 commit comments

Comments
 (0)