Skip to content

Commit 4d94bd0

Browse files
committed
fix(num_frames): compatible to the old model also.
1 parent d283ec8 commit 4d94bd0

3 files changed

Lines changed: 9 additions & 3 deletions

File tree

2_eval.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,11 @@ def main(cfg):
4646
trainer = pl.Trainer(logger=wandb_logger, devices=1)
4747
# NOTE(Qingwen): search & check: def eval_only_step_(self, batch, res_dict)
4848
trainer.validate(model = mymodel, \
49-
dataloaders = DataLoader(HDF5Dataset(cfg.dataset_path + f"/{cfg.av2_mode}", n_frames=checkpoint_params.cfg.num_frames, eval=True, leaderboard_version=cfg.leaderboard_version), batch_size=1, shuffle=False))
49+
dataloaders = DataLoader( \
50+
HDF5Dataset(cfg.dataset_path + f"/{cfg.av2_mode}", \
51+
n_frames=checkpoint_params.cfg.num_frames if 'num_frames' in checkpoint_params.cfg else 2, \
52+
eval=True, leaderboard_version=cfg.leaderboard_version), \
53+
batch_size=1, shuffle=False))
5054
wandb.finish()
5155

5256
if __name__ == "__main__":

3_vis.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,9 @@ def main(cfg):
4949
trainer = pl.Trainer(logger=wandb_logger, devices=1)
5050
# NOTE(Qingwen): search & check in pl_model.py : def test_step(self, batch, res_dict)
5151
trainer.test(model = mymodel, \
52-
dataloaders = DataLoader(HDF5Dataset(cfg.dataset_path, n_frames=checkpoint_params.cfg.num_frames), batch_size=1, shuffle=False))
52+
dataloaders = DataLoader(\
53+
HDF5Dataset(cfg.dataset_path, n_frames=checkpoint_params.cfg.num_frames if 'num_frames' in checkpoint_params.cfg else 2), \
54+
batch_size=1, shuffle=False))
5355
wandb.finish()
5456

5557
if __name__ == "__main__":

conf/eval.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ checkpoint: /home/kin/model_zoo/deflow.ckpt
44
av2_mode: val # [val, test]
55
save_res: False # [True, False]
66

7-
leaderboard_version: 2 # [1, 2]
7+
leaderboard_version: 1 # [1, 2]
88
supervised_flag: True # [True, False], whether you use any label from the dataset
99

1010
# no need to change

0 commit comments

Comments
 (0)