@@ -438,26 +438,36 @@ def get_args_and_kwargs_mixed_w8a32_conv(
438438 torch .ops .aten .permute .default ,
439439 (other_inputs [0 ], [0 , 2 , 1 ]), # NCL -> NLC
440440 )
441- assert "val" in other_inputs [0 ].meta , "Missing val metadata on input node"
442- original_val = other_inputs [0 ].meta ["val" ]
443- assert original_val .fake_mode is not None , "fake_mode is None on input node"
444- with original_val .fake_mode :
445- transposed_inputs .meta ["val" ] = torch .ops .aten .permute .default (
446- original_val , [0 , 2 , 1 ]
447- )
441+ # Propagate val metadata for transposed_inputs
442+ if "val" in other_inputs [0 ].meta :
443+ original_val = other_inputs [0 ].meta ["val" ]
444+ fake_mode = original_val .fake_mode
445+ if fake_mode is not None :
446+ with fake_mode :
447+ transposed_val = torch .ops .aten .permute .default (original_val , [0 , 2 , 1 ])
448+ transposed_inputs .meta ["val" ] = transposed_val
449+ else :
450+ transposed_inputs .meta ["val" ] = torch .ops .aten .permute .default (
451+ original_val , [0 , 2 , 1 ]
452+ )
448453 copy_node_metadata (transposed_inputs , other_inputs [0 ])
449454
450455 transposed_weights = graph_module .graph .call_function (
451456 torch .ops .aten .permute .default ,
452457 (weights_inputs [0 ], [2 , 0 , 1 ]), # NCL -> LNC
453458 )
454- assert "val" in weights_inputs [0 ].meta , "Missing val metadata on weight node"
455- original_val = weights_inputs [0 ].meta ["val" ]
456- assert original_val .fake_mode is not None , "fake_mode is None on weight node"
457- with original_val .fake_mode :
458- transposed_weights .meta ["val" ] = torch .ops .aten .permute .default (
459- original_val , [2 , 0 , 1 ]
460- )
459+ # Propagate val metadata for transposed_weights
460+ if "val" in weights_inputs [0 ].meta :
461+ original_val = weights_inputs [0 ].meta ["val" ]
462+ fake_mode = original_val .fake_mode
463+ if fake_mode is not None :
464+ with fake_mode :
465+ transposed_val = torch .ops .aten .permute .default (original_val , [2 , 0 , 1 ])
466+ transposed_weights .meta ["val" ] = transposed_val
467+ else :
468+ transposed_weights .meta ["val" ] = torch .ops .aten .permute .default (
469+ original_val , [2 , 0 , 1 ]
470+ )
461471 copy_node_metadata (transposed_weights , weights_inputs [0 ])
462472
463473 args = (
0 commit comments