     is_param_node,
 )
 from executorch.backends.arm.constants import NCHW_ORDER, NNCHW_ORDER, NNNCHW_ORDER
+from executorch.backends.arm.tosa.dialect.shape import is_shape_op_node
 from executorch.exir import ExportedProgram
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass, PassResult
@@ -404,22 +405,57 @@ def remove_dim_order_kwargs(

         node.kwargs = kwargs

+    def _propagate_dim_order_to_shape_args(self, node: torch.fx.Node) -> None:
+        for arg in node.all_input_nodes:
+            if is_shape_op_node(arg):
+                # Shape nodes may get their dim_order from multiple users. Track any previously
+                # assigned dim_order to make sure all users agree on the same one; otherwise we
+                # could end up with non-deterministic dim_orders depending on traversal order.
+                has_old_dim_order = arg.meta.get("tosa_dim_order", None) is not None
+                dim_order = node.meta["tosa_dim_order"]
+                if len(dim_order) != len(arg.meta["val"]):
+                    dim_order = tuple(range(len(arg.meta["val"])))
+                if has_old_dim_order and arg.meta["tosa_dim_order"] != dim_order:
+                    raise RuntimeError(
+                        f"Conflicting dim orders {arg.meta['tosa_dim_order']} and {dim_order} for shape node {arg.name}"
+                    )
+                arg.meta["tosa_dim_order"] = dim_order
+                self._propagate_dim_order_to_shape_args(arg)
+
+    def _annotate_shape_nodes(self, graph_module: torch.fx.GraphModule) -> None:
+        for node in graph_module.graph.nodes:
+            if not self._is_ok_for_annotation(node):
+                continue
+            self._propagate_dim_order_to_shape_args(node)
+
+    def _is_ok_for_annotation(self, node: torch.fx.Node) -> bool:
+        if "val" not in node.meta:
+            return False
+        # Shape-only nodes, which produce SymInt[] rather than real tensors, are annotated
+        # separately by propagating dim order from their users, so annotate all valid nodes first.
+        if is_shape_op_node(node):
+            return False
+        # For some models a symbolic value is passed to the graph; skip it.
+        if isinstance(node.meta["val"], torch.SymInt):
+            return False
+        return True
+
     def call(self, graph_module: torch.fx.GraphModule):
         """
         Entry point for the pass: annotate spatial ranks, compute dim orders,
         insert bridging transposes, and forward to child passes.
         """
         nodes = list(graph_module.graph.nodes)
         for node in nodes:
-            if "val" not in node.meta:
+            if not self._is_ok_for_annotation(node):
                 continue
             node.meta["tosa_spatial_rank"] = self._initial_spatial_rank(node)
             self.remove_dim_order_kwargs(graph_module, node)

         self._propagate_spatial_ranks(nodes)

         for node in nodes:
-            if "val" not in node.meta:
+            if not self._is_ok_for_annotation(node):
                 continue
             node_data = get_first_fake_tensor(node).data
             spatial_rank = node.meta["tosa_spatial_rank"]
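The conflict check in _propagate_dim_order_to_shape_args is easiest to see in isolation. Below is a minimal runnable sketch of the same rule, with simplified stand-ins rather than the real APIs: a plain Node class holding a meta dict and an all_input_nodes list in place of torch.fx.Node, and a boolean flag in place of is_shape_op_node.

# Stand-ins for torch.fx nodes: only the fields the propagation rule
# reads (meta, all_input_nodes) are modeled; is_shape_op replaces
# executorch's is_shape_op_node check.
class Node:
    def __init__(self, name, rank, is_shape_op=False, inputs=()):
        self.name = name
        self.meta = {"val": [0] * rank}  # only len(meta["val"]) matters here
        self.is_shape_op = is_shape_op
        self.all_input_nodes = list(inputs)

def propagate(node):
    for arg in node.all_input_nodes:
        if not arg.is_shape_op:
            continue
        old = arg.meta.get("tosa_dim_order")
        dim_order = node.meta["tosa_dim_order"]
        if len(dim_order) != len(arg.meta["val"]):
            # Rank mismatch: fall back to the identity order for arg's rank.
            dim_order = tuple(range(len(arg.meta["val"])))
        if old is not None and old != dim_order:
            raise RuntimeError(
                f"Conflicting dim orders {old} and {dim_order} for {arg.name}"
            )
        arg.meta["tosa_dim_order"] = dim_order
        propagate(arg)  # shape ops may feed other shape ops

shape = Node("sym_size", rank=4, is_shape_op=True)
user = Node("view_copy", rank=4, inputs=[shape])
user.meta["tosa_dim_order"] = (0, 2, 3, 1)  # an NHWC user
propagate(user)
assert shape.meta["tosa_dim_order"] == (0, 2, 3, 1)

A second user carrying a different dim order for the same shape node would trip the RuntimeError instead of silently depending on traversal order.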
@@ -437,6 +473,9 @@ def call(self, graph_module: torch.fx.GraphModule):
         # Insert TOSA transposes to convert between (N)NCHW and (N)NHWC format.
         # See insert_tosa_transposes for insertion conditions.
         self.insert_tosa_transposes(graph_module)
+        # Shape nodes need special handling: they have no real tensors or dim orders, but an order
+        # must still be propagated to them so they serialize with the correct order and shapes.
+        self._annotate_shape_nodes(graph_module)
         graph_module.recompile()
         graph_module = super().call(graph_module).graph_module

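The placement of _annotate_shape_nodes after tensor-node annotation and transpose insertion is load-bearing: propagation reads tosa_dim_order from the user node, so users must already be annotated. Reusing the Node/propagate stand-ins from the sketch above:

producer = Node("conv2d", rank=4)
shape = Node("sym_size", rank=4, is_shape_op=True, inputs=[producer])
user = Node("view_copy", rank=4, inputs=[shape])

try:
    propagate(user)  # user has no "tosa_dim_order" yet
except KeyError as exc:
    print(f"annotate tensor nodes first: missing {exc}")

user.meta["tosa_dim_order"] = (0, 2, 3, 1)
propagate(user)  # now the shape node inherits the NHWC order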
@@ -450,7 +489,7 @@ def _propagate_spatial_ranks(self, nodes):
         while changed:
             changed = False
             for node in reversed(nodes):
-                if "val" not in node.meta:
+                if not self._is_ok_for_annotation(node):
                     continue
                 tensor = get_first_fake_tensor(node)
                 limit = max(tensor.dim() - 2, 0)
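For background, a dim order tuple here follows the PyTorch convention, exposed as Tensor.dim_order() since PyTorch 2.1: entry i names the logical dimension stored at physical position i, so a contiguous NCHW tensor has the identity order and a channels-last one has (0, 2, 3, 1).

import torch

x = torch.randn(1, 3, 8, 8)  # logical NCHW, contiguous
print(x.dim_order())  # (0, 1, 2, 3)

y = x.to(memory_format=torch.channels_last)
print(y.dim_order())  # (0, 2, 3, 1), i.e. NHWC in memory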