@@ -52,7 +52,7 @@ def seflowLoss(res_dict, timer=None):
5252 # NOTE(Qingwen): add in the later part on label==0
5353 static_cluster_loss = torch .tensor (0.0 , device = est_flow .device )
5454
55- # fourth item loss: same label points' flow should be same
55+ # fourth item loss: same label points' flow should be the same
5656 # timer[5][3].start("SameClusterLoss")
5757 moved_cluster_loss = torch .tensor (0.0 , device = est_flow .device )
5858 moved_cluster_norms = torch .tensor ([], device = est_flow .device )
@@ -67,13 +67,13 @@ def seflowLoss(res_dict, timer=None):
6767 if cluster_nnd .shape [0 ] <= 0 :
6868 continue
6969
70- # Eq. 8 in the paper and with truncated
71- k = min ( 10 , cluster_nnd . shape [ 0 ] )
72- top_dis , top_idx = torch . topk ( cluster_nnd , k = k , largest = True )
73- for ii in range ( k ):
74- if pc1_label [ raw_idx0 [ mask ][ top_idx [ ii ]]] > 0 and top_dis [ ii ] <= TRUNCATED_DIST :
75- break
76- max_idx = top_idx [ ii ]
70+ # Eq. 8 in the paper
71+ sorted_idxs = torch . argsort ( cluster_nnd , descending = True )
72+ nearby_label = pc1_label [ raw_idx0 [ mask ][ sorted_idxs ]] # nonzero means dynamic in label
73+ non_zero_valid_indices = torch . nonzero ( nearby_label > 0 )
74+ if non_zero_valid_indices . shape [ 0 ] <= 0 :
75+ continue
76+ max_idx = sorted_idxs [ non_zero_valid_indices . squeeze ( 1 )[ 0 ] ]
7777
7878 # Eq. 9 in the paper
7979 max_flow = pc1 [raw_idx0 [mask ][max_idx ]] - pc0 [mask ][max_idx ]
@@ -150,4 +150,4 @@ def ff3dLoss(res_dict):
150150 is_foreground_class = (classes > 0 ) # 0 is background, ref: FOREGROUND_BACKGROUND_BREAKDOWN
151151 background_scalar = is_foreground_class .float () * 0.9 + 0.1
152152 error = error * background_scalar
153- return {'loss' : error .mean ()}
153+ return {'loss' : error .mean ()}
0 commit comments