@@ -267,15 +267,17 @@ def _run_process(cfg, mode):
267267 final_metrics .epe_3way [key ].extend (val_list )
268268
269269 for class_idx , class_name in enumerate (metrics_obj .bucketedMatrix .class_names ):
270- for range_idx , range_bucket in enumerate (metrics_obj .bucketedMatrix .range_buckets ):
271- count = metrics_obj .bucketedMatrix .count_storage_matrix [class_idx , range_idx ]
270+ # NOTE(Qingwen): for bucketedMatrix range_buckets = speed_buckets
271+ for speed_idx , speed_bucket in enumerate (metrics_obj .bucketedMatrix .range_buckets ):
272+ count = metrics_obj .bucketedMatrix .count_storage_matrix [class_idx , speed_idx ]
272273 if count > 0 :
273- avg_epe = metrics_obj .bucketedMatrix .epe_storage_matrix [class_idx , range_idx ]
274- avg_range = metrics_obj .bucketedMatrix .range_storage_matrix [class_idx , range_idx ]
274+ avg_epe = metrics_obj .bucketedMatrix .epe_storage_matrix [class_idx , speed_idx ]
275+ avg_speed = metrics_obj .bucketedMatrix .range_storage_matrix [class_idx , speed_idx ]
275276 final_metrics .bucketedMatrix .accumulate_value (
276- class_name , range_bucket , avg_epe , avg_range , count
277+ class_name , speed_bucket , avg_epe , avg_speed , count
277278 )
278279 for class_idx , class_name in enumerate (metrics_obj .distanceMatrix .class_names ):
280+ # NOTE(Qingwen): for distanceMatrix range_buckets = distance_buckets
279281 for range_idx , range_bucket in enumerate (metrics_obj .distanceMatrix .range_buckets ):
280282 count = metrics_obj .distanceMatrix .count_storage_matrix [class_idx , range_idx ]
281283 if count > 0 :
@@ -304,7 +306,7 @@ def _spawn_wrapper(rank, world_size, cfg, mode):
304306 os .environ ['RANK' ] = str (rank )
305307 os .environ ['WORLD_SIZE' ] = str (world_size )
306308 os .environ ['MASTER_ADDR' ] = 'localhost'
307- os .environ ['MASTER_PORT' ] = cfg .get ('master_port' , ' 12355' )
309+ os .environ ['MASTER_PORT' ] = str ( cfg .get ('master_port' , 12355 ) )
308310 _run_process (cfg , mode )
309311
310312def launch_runner (cfg , mode ):
0 commit comments