|
5 | 5 |
|
6 | 6 | from typing import Set, Type |
7 | 7 |
|
| 8 | +import executorch.backends.arm.tosa.dialect # noqa: F401 |
| 9 | +import pytest |
8 | 10 | import torch |
| 11 | +import tosa_serializer as ts |
9 | 12 | from executorch.backends.arm._passes.arm_pass import ArmPass |
| 13 | +from executorch.backends.arm._passes.to_tosa_memory_format_pass import ( |
| 14 | + ToTosaMemoryFormatPass, |
| 15 | +) |
| 16 | +from executorch.backends.arm.operators.node_visitor import get_node_visitors |
| 17 | +from executorch.backends.arm.process_node import process_call_function |
10 | 18 | from executorch.backends.arm.tosa.mapping import TosaSpecialDtype |
| 19 | +from executorch.backends.arm.tosa.specification import ( |
| 20 | + TosaLoweringContext, |
| 21 | + TosaSpecification, |
| 22 | +) |
| 23 | +from executorch.backends.test.graph_builder import GraphBuilder |
11 | 24 | from executorch.exir import to_edge |
12 | 25 | from executorch.exir.dialects._ops import ops as exir_ops |
13 | 26 | from executorch.exir.pass_base import ExportPass |
@@ -54,3 +67,76 @@ def forward(self, x): |
54 | 67 | assert const_shape_nodes |
55 | 68 | for n in const_shape_nodes: |
56 | 69 | assert n.meta[TosaSpecialDtype.meta_key()] == TosaSpecialDtype.SHAPE |
| 70 | + |
| 71 | + |
| 72 | +def _graph_module_with_unused_const_shape(): |
| 73 | + with TosaLoweringContext(TosaSpecification.create_from_string("TOSA-1.1+FP+shape")): |
| 74 | + builder = GraphBuilder() |
| 75 | + builder.call_operator(exir_ops.backend.tosa.CONST_SHAPE.default, ([1],)) |
| 76 | + live_const = builder.call_operator( |
| 77 | + exir_ops.backend.tosa.CONST_SHAPE.default, ([3],) |
| 78 | + ) |
| 79 | + builder.output([live_const]) |
| 80 | + graph_module = ExportPass().call(builder.get_graph_module()).graph_module |
| 81 | + for node in graph_module.graph.nodes: |
| 82 | + if node.op == "call_function": |
| 83 | + node.meta[TosaSpecialDtype.meta_key()] = TosaSpecialDtype.SHAPE |
| 84 | + return graph_module |
| 85 | + |
| 86 | + |
| 87 | +def _propagate_shape_dim_orders_from_users(graph_module: torch.fx.GraphModule) -> None: |
| 88 | + output_node = next(node for node in graph_module.graph.nodes if node.op == "output") |
| 89 | + output_node.meta["tosa_dim_order"] = (0,) |
| 90 | + dummy_exported = torch.export.export(torch.nn.Identity(), (torch.randn(1),)) |
| 91 | + tosa_memory_format_pass = ToTosaMemoryFormatPass(dummy_exported) |
| 92 | + tosa_memory_format_pass._propagate_dim_order_to_shape_args(output_node) |
| 93 | + |
| 94 | + |
| 95 | +def _serialize_graph_module_to_tosa(graph_module: torch.fx.GraphModule): |
| 96 | + tosa_spec = TosaSpecification.create_from_string("TOSA-1.1+FP+shape") |
| 97 | + node_visitors = get_node_visitors(None, tosa_spec) |
| 98 | + tosa_graph = ts.TosaSerializer( |
| 99 | + "", |
| 100 | + targetMajor=tosa_spec.version.major, |
| 101 | + targetMinor=tosa_spec.version.minor, |
| 102 | + targetPatch=tosa_spec.version.micro, |
| 103 | + targetDraft=True, |
| 104 | + ) |
| 105 | + |
| 106 | + for node in graph_module.graph.nodes: |
| 107 | + if node.op == "call_function": |
| 108 | + process_call_function(node, tosa_graph, node_visitors, tosa_spec) |
| 109 | + |
| 110 | + return tosa_graph |
| 111 | + |
| 112 | + |
| 113 | +def test_unused_shape_ops_miss_tosa_dim_order_and_must_be_removed_before_tosa_serialization(): |
| 114 | + graph_module = _graph_module_with_unused_const_shape() |
| 115 | + _propagate_shape_dim_orders_from_users(graph_module) |
| 116 | + |
| 117 | + const_shape_nodes = [ |
| 118 | + node |
| 119 | + for node in graph_module.graph.nodes |
| 120 | + if node.op == "call_function" |
| 121 | + and node.target == exir_ops.backend.tosa.CONST_SHAPE.default |
| 122 | + ] |
| 123 | + dead_const_shape, live_const_shape = const_shape_nodes |
| 124 | + |
| 125 | + assert dead_const_shape.users == {} |
| 126 | + assert "tosa_dim_order" not in dead_const_shape.meta |
| 127 | + assert live_const_shape.meta["tosa_dim_order"] == (0,) |
| 128 | + |
| 129 | + with pytest.raises(KeyError, match="tosa_dim_order"): |
| 130 | + _serialize_graph_module_to_tosa(graph_module) |
| 131 | + |
| 132 | + graph_module.graph.eliminate_dead_code() |
| 133 | + graph_module.recompile() |
| 134 | + |
| 135 | + remaining_const_shape = next( |
| 136 | + node |
| 137 | + for node in graph_module.graph.nodes |
| 138 | + if node.op == "call_function" |
| 139 | + and node.target == exir_ops.backend.tosa.CONST_SHAPE.default |
| 140 | + ) |
| 141 | + assert remaining_const_shape.meta["tosa_dim_order"] == (0,) |
| 142 | + assert _serialize_graph_module_to_tosa(graph_module) |
0 commit comments