@@ -223,7 +223,7 @@ def on_validation_epoch_end(self):
223223 # wandb.log_artifact(output_file)
224224 return
225225
226- if self .data_mode == 'val' :
226+ if self .data_mode in [ 'val' , 'valid' ] :
227227 print (f"\n Model: { self .model .__class__ .__name__ } , Checkpoint from: { self .checkpoint } " )
228228 print (f"More details parameters and training status are in the checkpoint file." )
229229
@@ -266,7 +266,7 @@ def eval_only_step_(self, batch, res_dict):
266266 else :
267267 final_flow [~ batch ['gm0' ]] = res_dict ['flow' ] + pose_flow [~ batch ['gm0' ]]
268268
269- if self .data_mode == 'val' : # since only val we have ground truth flow to eval
269+ if self .data_mode in [ 'val' , 'valid' ] : # since only val we have ground truth flow to eval
270270 gt_flow = batch ["flow" ]
271271 v1_dict = evaluate_leaderboard (final_flow [eval_mask ], pose_flow [eval_mask ], pc0 [eval_mask ], \
272272 gt_flow [eval_mask ], batch ['flow_is_valid' ][eval_mask ], \
@@ -306,7 +306,7 @@ def run_model_wo_ground_data(self, batch):
306306 return batch , res_dict
307307
308308 def validation_step (self , batch , batch_idx ):
309- if self .data_mode in ['val' , 'test' ]:
309+ if self .data_mode in ['val' , 'test' , 'valid' ]:
310310 batch , res_dict = self .run_model_wo_ground_data (batch )
311311 self .model .timer [13 ].start ("Eval" )
312312 self .eval_only_step_ (batch , res_dict )
0 commit comments