Commit 3bf1121

docs(README): update readme based on our refactor.
chore: move voxel to encoder script. fix(dataset): index eval in the script. style(trainer): log print, eval cfg check.

1 parent 4d4eecd

10 files changed, 491 additions and 699 deletions

README.md

Lines changed: 5 additions & 5 deletions
````diff
@@ -3,8 +3,8 @@ SeFlow: A Self-Supervised Scene Flow Method in Autonomous Driving
 
 [![arXiv](https://img.shields.io/badge/arXiv-2407.01702-b31b1b?logo=arxiv&logoColor=white)](https://arxiv.org/abs/2407.01702)
 [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/seflow-a-self-supervised-scene-flow-method-in/self-supervised-scene-flow-estimation-on-1)](https://paperswithcode.com/sota/self-supervised-scene-flow-estimation-on-1?p=seflow-a-self-supervised-scene-flow-method-in)
-[poster comming soon]
-[video coming soon]
+[![poster](https://img.shields.io/badge/ICRA24|Poster-6495ed?style=flat&logo=Shotcut&logoColor=wihte)](https://hkustconnect-my.sharepoint.com/:b:/g/personal/qzhangcb_connect_ust_hk/EWyWD-tAX4xIma5U7ZQVk9cBVjsFv0Y_jAC2G7xAB-w4cg?e=c3FbMg)
+[![video](https://img.shields.io/badge/video-YouTube-FF0000?logo=youtube&logoColor=white)](https://youtu.be/fQqx2IES-VI)
 
 ![](assets/docs/seflow_arch.png)
 
````
````diff
@@ -24,9 +24,9 @@ We directly follow our previous work [code structure](https://github.com/KTH-RPL
 
 - `train.py`: Train the model and get model checkpoints. Pls remember to check the config.
 
-- `eval.py` : Evaluate the model on the validation/test set. And also upload to online leaderboard.
+- `eval.py` : Evaluate the model on the validation/test set. And also output the zip file to upload to online leaderboard.
 
-- `save.py` : For visualization of the results with a video.
+- `save.py` : Will save result into h5py file, using [tool/visualization.py] to show results with interactive window.
 
 <details> <summary>🎁 <b>One repository, All methods!</b> </summary>
 <!-- <br> -->
````
````diff
@@ -118,7 +118,7 @@ Or you can directly download the pre-trained weight from [Zenodo](https://zenodo
 
 You can also train the supervised baseline model in our paper with the following command. [Runtime: Around 10 hours in 4x A100 GPUs.]
 ```bash
-python train.py model=fastflow3d lr=2e-4 epochs=20 batch_size=16 loss_fn=ff3dLoss
+python train.py model=fastflow3d lr=4e-5 epochs=20 batch_size=16 loss_fn=ff3dLoss
 python train.py model=deflow lr=2e-4 epochs=20 batch_size=16 loss_fn=deflowLoss
 ```
 
````
assets/README.md

Lines changed: 4 additions & 8 deletions
````diff
@@ -37,21 +37,17 @@ git clone https://github.com/KTH-RPL/SeFlow.git
 mamba env create -f assets/environment.yml
 ```
 
-CUDA package (need install nvcc compiler), the compile time is around 1-5 minutes:
+CUDA package (nvcc compiler already installed through conda), the compile time is around 1-5 minutes:
 ```bash
 mamba activate seflow
-# change it if you use different cuda version (I tested 11.3, 11.4, 11.7 all works)
-export PATH=/usr/local/cuda-11.7/bin:$PATH
-export LD_LIBRARY_PATH=/usr/local/cuda-11.7/lib64:$LD_LIBRARY_PATH
-
 cd assets/cuda/mmcv && python ./setup.py install && cd ../../..
 cd assets/cuda/chamfer3D && python ./setup.py install && cd ../../..
 ```
 
 
 Checking important packages in our environment now:
 ```bash
-mamba activate deflow
+mamba activate seflow
 python -c "import torch; print(torch.__version__); print(torch.cuda.is_available()); print(torch.version.cuda)"
 python -c "import lightning.pytorch as pl; print(pl.__version__)"
 python -c "from assets.cuda.mmcv import Voxelization, DynamicScatter;print('successfully import on our lite mmcv package')"
````
````diff
@@ -78,8 +74,8 @@ python -c "from assets.cuda.chamfer3D import nnChamferDis;print('successfully im
 
 If you want to contribute to new model, here are tips you can follow:
 1. Dataloader: we believe all data could be process to `.h5`, we named as different scene and inside a scene, the key of each data is timestamp. Check [dataprocess/README.md](../dataprocess/README.md#process) for more details.
-2. Model: All model files can be found [here: scripts/network/models](../scripts/network/models). You can view deflow and fastflow3d to know how to implement a new model.
-3. Loss: All loss files can be found [here: scripts/network/loss_func.py](../scripts/network/loss_func.py). There are three loss functions already inside the file, you can add a new one following the same pattern.
+2. Model: All model files can be found [here: src/models](../src/models). You can view deflow and fastflow3d to know how to implement a new model. Don't forget to add to the `__init__.py` [file to import class](../src/models/__init__.py).
+3. Loss: All loss files can be found [here: src/lossfuncs.py](../src/lossfuncs.py). There are three loss functions already inside the file, you can add a new one following the same pattern.
 4. Training: Once you have implemented the model, you can add the model to the config file [here: conf/model](../conf/model) and train the model using the command `python train.py model=your_model_name`. One more note here may: if your res_dict from model output is different, you may need add one pattern in `def training_step` and `def validation_step`.
 
 All others like eval and vis will be changed according to the model you implemented as you follow the above steps.
````
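The contribution tips describe a register-then-resolve pattern: a new model class plus an import in `src/models/__init__.py` so that `python train.py model=your_model_name` can find it. A minimal framework-free sketch of that pattern (`MyModel` and `MODEL_REGISTRY` are hypothetical names; real models in this repo subclass `torch.nn.Module` and return a richer `res_dict`):

```python
# Sketch of the registration pattern from tips 2 and 4 above.
# "MyModel" and "MODEL_REGISTRY" are hypothetical stand-ins.

class MyModel:
    """Hypothetical stand-in for a scene-flow network in src/models."""
    def __init__(self, voxel_size: float = 0.2):
        self.voxel_size = voxel_size

    def forward(self, batch):
        # training_step/validation_step expect a res_dict with known keys
        return {"flow": None}

# In src/models/__init__.py you would import the class so the Hydra
# config value `model=mymodel` can be resolved to it by name:
MODEL_REGISTRY = {"mymodel": MyModel}

model = MODEL_REGISTRY["mymodel"](voxel_size=0.1)
res_dict = model.forward(batch=None)
```

If the keys of your `res_dict` differ from the existing models, that is exactly the case where tip 4 says `training_step`/`validation_step` need an extra pattern.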

eval.py

Lines changed: 11 additions & 2 deletions
````diff
@@ -14,12 +14,21 @@
 from torch.utils.data import DataLoader
 import lightning.pytorch as pl
 from lightning.pytorch.loggers import WandbLogger
-from omegaconf import DictConfig, OmegaConf
+from omegaconf import DictConfig
 import hydra, wandb, os, sys
 from hydra.core.hydra_config import HydraConfig
 from src.dataset import HDF5Dataset
 from src.trainer import ModelWrapper
 
+def precheck_cfg_valid(cfg):
+    if os.path.exists(cfg.dataset_path + f"/{cfg.av2_mode}") is False:
+        raise ValueError(f"Dataset {cfg.dataset_path}/{cfg.av2_mode} does not exist. Please check the path.")
+    if cfg.supervised_flag not in [True, False]:
+        raise ValueError(f"Supervised flag {cfg.supervised_flag} is not valid. Please set it to True or False.")
+    if cfg.leaderboard_version not in [1, 2]:
+        raise ValueError(f"Leaderboard version {cfg.leaderboard_version} is not valid. Please set it to 1 or 2.")
+    return cfg
+
 @hydra.main(version_base=None, config_path="conf", config_name="eval")
 def main(cfg):
     pl.seed_everything(cfg.seed, workers=True)
````
````diff
@@ -35,7 +44,7 @@ def main(cfg):
     cfg.model.update(checkpoint_params.cfg.model)
 
     mymodel = ModelWrapper.load_from_checkpoint(cfg.checkpoint, cfg=cfg, eval=True)
-    print(f"\n---LOG[eval]: Loaded model from {cfg.checkpoint}. The model is {checkpoint_params.cfg.model.name}.\n")
+    print(f"\n---LOG[eval]: Loaded model from {cfg.checkpoint}. The backbone network is {checkpoint_params.cfg.model.name}.\n")
 
     wandb_logger = WandbLogger(save_dir=output_dir,
                                entity="kth-rpl",
````
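The new `precheck_cfg_valid` guard can be exercised standalone. A minimal sketch, using `SimpleNamespace` in place of a Hydra `DictConfig` and a temporary directory in place of a real dataset root:

```python
import os
import tempfile
from types import SimpleNamespace

# Copy of the validation logic this commit adds to eval.py.
def precheck_cfg_valid(cfg):
    if os.path.exists(cfg.dataset_path + f"/{cfg.av2_mode}") is False:
        raise ValueError(f"Dataset {cfg.dataset_path}/{cfg.av2_mode} does not exist. Please check the path.")
    if cfg.supervised_flag not in [True, False]:
        raise ValueError(f"Supervised flag {cfg.supervised_flag} is not valid. Please set it to True or False.")
    if cfg.leaderboard_version not in [1, 2]:
        raise ValueError(f"Leaderboard version {cfg.leaderboard_version} is not valid. Please set it to 1 or 2.")
    return cfg

with tempfile.TemporaryDirectory() as root:
    os.makedirs(os.path.join(root, "val"))  # stand-in for <dataset_path>/val
    cfg = SimpleNamespace(dataset_path=root, av2_mode="val",
                          supervised_flag=False, leaderboard_version=2)
    assert precheck_cfg_valid(cfg) is cfg   # a valid config passes through

    cfg.leaderboard_version = 3             # only versions 1 and 2 are allowed
    raised = False
    try:
        precheck_cfg_valid(cfg)
    except ValueError:
        raised = True
    assert raised
```

Failing fast here surfaces a wrong `dataset_path` or leaderboard version before the checkpoint is loaded, rather than partway through evaluation.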

logs/.gitkeep

Whitespace-only changes.

src/dataset.py

Lines changed: 1 addition & 1 deletion
````diff
@@ -16,7 +16,7 @@
 from torch.utils.data import Dataset, DataLoader
 import h5py, os, pickle, argparse, sys
 from tqdm import tqdm
-BASE_DIR = os.path.abspath(os.path.join( os.path.dirname( __file__ ), '../..' ))
+BASE_DIR = os.path.abspath(os.path.join( os.path.dirname( __file__ ), '..' ))
 sys.path.append(BASE_DIR)
 
 def collate_fn_pad(batch):
````
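The one-line `BASE_DIR` fix tracks the refactor: `dataset.py` now sits in `src/` directly under the repository root, so a single `..` reaches the root, while the old `../..` overshoots it by one level. A quick sketch with a hypothetical file location (assuming a POSIX path layout):

```python
import os

# Hypothetical location of src/dataset.py after the refactor.
fake_file = "/repo/src/dataset.py"

# New: one ".." climbs from src/ up to the repository root.
base_dir = os.path.abspath(os.path.join(os.path.dirname(fake_file), ".."))
print(base_dir)  # /repo

# Old: "../.." made sense when the file lived one directory deeper,
# but from src/ it now climbs past the root.
old_base = os.path.abspath(os.path.join(os.path.dirname(fake_file), "../.."))
print(old_base)  # /
```

Getting this wrong would make `sys.path.append(BASE_DIR)` point outside the repo, breaking the `from src...` imports used elsewhere.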
