Skip to content

Commit 6db7f4c

Browse files
Arm backend: Propagate dim_order to TOSA-shape ops
Make sure that to_tosa_memory_format_pass propagates tosa_dim_order to TOSA-shape ops. These are special as the rank is derived from len(output.shape[0]) rather than len(output.shape). Co-authored-by: Per Åstrand <per.astrand@arm.com> Signed-off-by: Oscar Andersson <oscar.andersson@arm.com> Change-Id: Id5861e4dc018c56ca95cdbe358507dfc7f706b78
1 parent 95dcfb9 commit 6db7f4c

3 files changed

Lines changed: 51 additions & 7 deletions

File tree

backends/arm/_passes/to_tosa_memory_format_pass.py

Lines changed: 42 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
is_param_node,
1616
)
1717
from executorch.backends.arm.constants import NCHW_ORDER, NNCHW_ORDER, NNNCHW_ORDER
18+
from executorch.backends.arm.tosa.dialect.shape import is_shape_op_node
1819
from executorch.exir import ExportedProgram
1920
from executorch.exir.dialects._ops import ops as exir_ops
2021
from executorch.exir.pass_base import ExportPass, PassResult
@@ -404,22 +405,57 @@ def remove_dim_order_kwargs(
404405

405406
node.kwargs = kwargs
406407

408+
def _propagate_dim_order_to_shape_args(self, node: torch.fx.Node) -> None:
409+
for arg in node.all_input_nodes:
410+
if is_shape_op_node(arg):
411+
# Shape nodes may get their dim_order from multiple users. Keep track of the old dim_order to make sure all
412+
# users agree on the same dim_order, otherwise we may end up with non-deterministic dim_orders for
413+
# shape nodes depending on the order of user traversal.
414+
old_dim_order = arg.meta.get("tosa_dim_order", None) is not None
415+
dim_order = node.meta["tosa_dim_order"]
416+
if len(dim_order) != len(arg.meta["val"]):
417+
dim_order = tuple(range(len(arg.meta["val"])))
418+
if old_dim_order and arg.meta["tosa_dim_order"] != dim_order:
419+
raise RuntimeError(
420+
f"Conflicting dim orders {arg.meta['tosa_dim_order']} and {dim_order} for shape node {arg.name}"
421+
)
422+
arg.meta["tosa_dim_order"] = dim_order
423+
self._propagate_dim_order_to_shape_args(arg)
424+
425+
def _annotate_shape_nodes(self, graph_module: torch.fx.GraphModule) -> None:
426+
for node in graph_module.graph.nodes:
427+
if not self._is_ok_for_annotation(node):
428+
continue
429+
self._propagate_dim_order_to_shape_args(node)
430+
431+
def _is_ok_for_annotation(self, node: torch.fx.Node) -> bool:
432+
if "val" not in node.meta:
433+
return False
434+
# Shape-only nodes which produce SymInt[] rather than real tensors are annotated separately by propagating dim order from their users.
435+
# We must therefore annotate all valid nodes before propagating dim order upwards in graph.
436+
if is_shape_op_node(node):
437+
return False
438+
# For some models, a symbolic value is passed to the graph; skip it.
439+
if isinstance(node.meta["val"], torch.SymInt):
440+
return False
441+
return True
442+
407443
def call(self, graph_module: torch.fx.GraphModule):
408444
"""
409445
Entry point for the pass: annotate spatial ranks, compute dim orders,
410446
insert bridging transposes, and forward to child passes.
411447
"""
412448
nodes = list(graph_module.graph.nodes)
413449
for node in nodes:
414-
if "val" not in node.meta:
450+
if not self._is_ok_for_annotation(node):
415451
continue
416452
node.meta["tosa_spatial_rank"] = self._initial_spatial_rank(node)
417453
self.remove_dim_order_kwargs(graph_module, node)
418454

419455
self._propagate_spatial_ranks(nodes)
420456

421457
for node in nodes:
422-
if "val" not in node.meta:
458+
if not self._is_ok_for_annotation(node):
423459
continue
424460
node_data = get_first_fake_tensor(node).data
425461
spatial_rank = node.meta["tosa_spatial_rank"]
@@ -437,6 +473,9 @@ def call(self, graph_module: torch.fx.GraphModule):
437473
# Insert TOSA transposes to convert between (N)NCHW and (N)NHWC format.
438474
# See insert_tosa_transposes for insertion conditions.
439475
self.insert_tosa_transposes(graph_module)
476+
# Special handling is needed for shape nodes as they don't have real tensors or real dim orders, but the order
477+
# still needs to be propagated to them so that they can be serialized with the correct order and shapes.
478+
self._annotate_shape_nodes(graph_module)
440479
graph_module.recompile()
441480
graph_module = super().call(graph_module).graph_module
442481

@@ -450,7 +489,7 @@ def _propagate_spatial_ranks(self, nodes):
450489
while changed:
451490
changed = False
452491
for node in reversed(nodes):
453-
if "val" not in node.meta:
492+
if not self._is_ok_for_annotation(node):
454493
continue
455494
tensor = get_first_fake_tensor(node)
456495
limit = max(tensor.dim() - 2, 0)

backends/arm/operators/op_tosa_shapes.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,12 @@
1010
import torch
1111

1212
import tosa_serializer as ts # type: ignore
13-
1413
from executorch.backends.arm.operators.node_visitor import (
1514
NodeVisitor,
1615
register_node_visitor,
1716
)
1817
from executorch.backends.arm.tosa.mapping import TosaArg
18+
from executorch.backends.arm.tosa.utils import tosa_shape
1919

2020

2121
@register_node_visitor
@@ -33,10 +33,15 @@ def define_node(
3333
output: TosaArg,
3434
) -> None:
3535
shape_input = inputs[0].special
36+
rank = len(shape_input)
37+
tosa_dim_order = output.dim_order
38+
vals = tosa_shape(node.meta["val"], tosa_dim_order)
3639
tosa_graph = cast(ts.TosaSerializer, tosa_graph)
3740
tosa_graph.addConst(
38-
shape_input,
41+
[
42+
rank,
43+
],
3944
dtype=ts.DType.SHAPE,
40-
vals=node.meta["val"],
45+
vals=vals,
4146
name=output.name,
4247
)

backends/arm/tosa/mapping.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ def extract_tensor_meta(meta):
131131
special_dtype = meta.get(TosaSpecialDtype.meta_key())
132132
if special_dtype == TosaSpecialDtype.SHAPE:
133133
shape_len = len(meta["val"])
134-
return (ts.DType.SHAPE, (shape_len,), (0,))
134+
return (ts.DType.SHAPE, (shape_len,), meta["tosa_dim_order"])
135135

136136
if meta.get("val") is None:
137137
raise ValueError("Expected node.meta['val'] to be set to a FakeTensor")

0 commit comments

Comments
 (0)