Skip to content

Commit b5cf3c3

Browse files
Arm backend: Eliminate dead shape ops (#18902)
Shape ops get their dim_order from their users. If a shape-op does not have any users, it will therefore not get a dim_order and crash during serialization. To avoid this, we can simply delete any unused shape op from the graph before serialization. cc @digantdesai @freddan80 @per @zingo @mansnils @Sebastian-Larsson @robell Signed-off-by: Oscar Andersson <oscar.andersson@arm.com>
1 parent 87e65ac commit b5cf3c3

2 files changed

Lines changed: 87 additions & 0 deletions

File tree

backends/arm/_passes/to_tosa_memory_format_pass.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -465,6 +465,7 @@ def call(self, graph_module: torch.fx.GraphModule):
465465
Entry point for the pass: annotate spatial ranks, compute dim orders,
466466
insert bridging transposes, and forward to child passes.
467467
"""
468+
graph_module.graph.eliminate_dead_code()
468469
nodes = list(graph_module.graph.nodes)
469470
for node in nodes:
470471
if not self._is_ok_for_annotation(node):

backends/arm/test/misc/test_const_shape.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,22 @@
55

66
from typing import Set, Type
77

8+
import executorch.backends.arm.tosa.dialect # noqa: F401
9+
import pytest
810
import torch
11+
import tosa_serializer as ts
912
from executorch.backends.arm._passes.arm_pass import ArmPass
13+
from executorch.backends.arm._passes.to_tosa_memory_format_pass import (
14+
ToTosaMemoryFormatPass,
15+
)
16+
from executorch.backends.arm.operators.node_visitor import get_node_visitors
17+
from executorch.backends.arm.process_node import process_call_function
1018
from executorch.backends.arm.tosa.mapping import TosaSpecialDtype
19+
from executorch.backends.arm.tosa.specification import (
20+
TosaLoweringContext,
21+
TosaSpecification,
22+
)
23+
from executorch.backends.test.graph_builder import GraphBuilder
1124
from executorch.exir import to_edge
1225
from executorch.exir.dialects._ops import ops as exir_ops
1326
from executorch.exir.pass_base import ExportPass
@@ -54,3 +67,76 @@ def forward(self, x):
5467
assert const_shape_nodes
5568
for n in const_shape_nodes:
5669
assert n.meta[TosaSpecialDtype.meta_key()] == TosaSpecialDtype.SHAPE
70+
71+
72+
def _graph_module_with_unused_const_shape():
73+
with TosaLoweringContext(TosaSpecification.create_from_string("TOSA-1.1+FP+shape")):
74+
builder = GraphBuilder()
75+
builder.call_operator(exir_ops.backend.tosa.CONST_SHAPE.default, ([1],))
76+
live_const = builder.call_operator(
77+
exir_ops.backend.tosa.CONST_SHAPE.default, ([3],)
78+
)
79+
builder.output([live_const])
80+
graph_module = ExportPass().call(builder.get_graph_module()).graph_module
81+
for node in graph_module.graph.nodes:
82+
if node.op == "call_function":
83+
node.meta[TosaSpecialDtype.meta_key()] = TosaSpecialDtype.SHAPE
84+
return graph_module
85+
86+
87+
def _propagate_shape_dim_orders_from_users(graph_module: torch.fx.GraphModule) -> None:
88+
output_node = next(node for node in graph_module.graph.nodes if node.op == "output")
89+
output_node.meta["tosa_dim_order"] = (0,)
90+
dummy_exported = torch.export.export(torch.nn.Identity(), (torch.randn(1),))
91+
tosa_memory_format_pass = ToTosaMemoryFormatPass(dummy_exported)
92+
tosa_memory_format_pass._propagate_dim_order_to_shape_args(output_node)
93+
94+
95+
def _serialize_graph_module_to_tosa(graph_module: torch.fx.GraphModule):
96+
tosa_spec = TosaSpecification.create_from_string("TOSA-1.1+FP+shape")
97+
node_visitors = get_node_visitors(None, tosa_spec)
98+
tosa_graph = ts.TosaSerializer(
99+
"",
100+
targetMajor=tosa_spec.version.major,
101+
targetMinor=tosa_spec.version.minor,
102+
targetPatch=tosa_spec.version.micro,
103+
targetDraft=True,
104+
)
105+
106+
for node in graph_module.graph.nodes:
107+
if node.op == "call_function":
108+
process_call_function(node, tosa_graph, node_visitors, tosa_spec)
109+
110+
return tosa_graph
111+
112+
113+
def test_unused_shape_ops_miss_tosa_dim_order_and_must_be_removed_before_tosa_serialization():
114+
graph_module = _graph_module_with_unused_const_shape()
115+
_propagate_shape_dim_orders_from_users(graph_module)
116+
117+
const_shape_nodes = [
118+
node
119+
for node in graph_module.graph.nodes
120+
if node.op == "call_function"
121+
and node.target == exir_ops.backend.tosa.CONST_SHAPE.default
122+
]
123+
dead_const_shape, live_const_shape = const_shape_nodes
124+
125+
assert dead_const_shape.users == {}
126+
assert "tosa_dim_order" not in dead_const_shape.meta
127+
assert live_const_shape.meta["tosa_dim_order"] == (0,)
128+
129+
with pytest.raises(KeyError, match="tosa_dim_order"):
130+
_serialize_graph_module_to_tosa(graph_module)
131+
132+
graph_module.graph.eliminate_dead_code()
133+
graph_module.recompile()
134+
135+
remaining_const_shape = next(
136+
node
137+
for node in graph_module.graph.nodes
138+
if node.op == "call_function"
139+
and node.target == exir_ops.backend.tosa.CONST_SHAPE.default
140+
)
141+
assert remaining_const_shape.meta["tosa_dim_order"] == (0,)
142+
assert _serialize_graph_module_to_tosa(graph_module)

0 commit comments

Comments
 (0)