Skip to content

Commit cb30495

Browse files
Arm backend: Fix dynamic conv-padding condition (#18941)
Make sure that we don't implicitly cast SymBools to bools via expressions like `if SymInt > int`. cc @digantdesai @freddan80 @per @zingo @mansnils @Sebastian-Larsson @robell Signed-off-by: Oscar Andersson <oscar.andersson@arm.com>
1 parent 9d72936 commit cb30495

3 files changed

Lines changed: 214 additions & 11 deletions

File tree

backends/arm/_passes/rewrite_conv_pass.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99

1010
import torch
1111
from executorch.backends.arm._passes import ArmPass
12-
1312
from executorch.backends.arm._passes.arm_pass_utils import (
1413
create_node,
1514
expand_around_channel,
@@ -24,9 +23,11 @@
2423
)
2524
from executorch.backends.arm.constants import HWCM_ORDER, NHWC_INVERSE_ORDER
2625
from executorch.backends.arm.tosa.mapping import TosaSpecialDtype
26+
from executorch.backends.arm.tosa.specification import get_context_shape_env
2727
from executorch.backends.transforms.utils import create_constant_placeholder
2828
from executorch.exir.dialects._ops import ops as exir_ops
2929
from executorch.exir.pass_base import ExportPass, PassResult
30+
3031
from torch.export.graph_signature import InputKind
3132

3233

@@ -46,8 +47,13 @@ def __init__(self, exported_program: torch.export.ExportedProgram, *args, **kwar
4647
# to be an integer, but tosa currently strictly require this property.
4748
# This function adjusts the pad value to meet the requirement.
4849
def _adjust_pad_if_needed(
49-
self, input_len: int, input_weight: int, stride: int, pad: int, dilation: int
50-
) -> int:
50+
self,
51+
input_len: int | torch.SymInt,
52+
input_weight: int,
53+
stride: int,
54+
pad: int | torch.SymInt,
55+
dilation: int,
56+
) -> int | torch.SymInt:
5157
"""Adjust padding to satisfy TOSA's integer output-size requirement.
5258
5359
Torch ``Conv2d`` does not require the result of
@@ -75,11 +81,16 @@ def _adjust_pad_if_needed(
7581
input_len + 2 * pad - dilation * (input_weight - 1) - 1
7682
) % stride
7783

78-
# No need to adjust
79-
if mod_remainder == 0:
80-
return pad
84+
if isinstance(mod_remainder, torch.SymInt):
85+
shape_env = get_context_shape_env()
86+
value_ranges = shape_env.bound_sympy(mod_remainder.node.expr)
87+
mod_remainder_upper = int(value_ranges.upper)
88+
if mod_remainder_upper == 0:
89+
mod_remainder = 0
90+
else:
91+
mod_remainder_upper = mod_remainder
8192

82-
if mod_remainder > pad:
93+
if mod_remainder_upper > pad:
8394
raise RuntimeError(
8495
"This case should be handled by the SizeAdjustInputPass, is it enabled?"
8596
)
@@ -319,7 +330,7 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: # noqa: C901
319330
stride,
320331
)
321332
else:
322-
pad_attr: list[int] = []
333+
pad_attr: list[int | torch.SymInt] = []
323334
for value in pad_list:
324335
pad_attr.extend(
325336
[value, value]

backends/arm/test/passes/test_insert_dynamic_padding_pass.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
)
1515
from executorch.exir import to_edge
1616
from executorch.exir.dialects._ops import ops as exir_ops
17+
from torch._export.utils import _get_shape_env_from_gm
1718
from torch.export import Dim, export
1819

1920

@@ -37,7 +38,10 @@ def test_insert_dynamic_padding():
3738
},
3839
)
3940
edge_model = to_edge(ep)
40-
with TosaLoweringContext(TosaSpecification.create_from_string("TOSA-1.1+FP+shape")):
41+
shape_env = _get_shape_env_from_gm(edge_model.exported_program().graph_module)
42+
with TosaLoweringContext(
43+
TosaSpecification.create_from_string("TOSA-1.1+FP+shape"), shape_env
44+
):
4145
edge_model = edge_model.transform(
4246
[RewriteConvPass(edge_model.exported_program())]
4347
)

backends/arm/test/passes/test_rewrite_conv_pass.py

Lines changed: 190 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

6+
import pytest
67
import torch
78
import torch.nn as nn
89
import torch.nn.functional as F
@@ -21,10 +22,15 @@
2122
DWConvsModule,
2223
)
2324
from executorch.backends.arm.test.tester.test_pipeline import PassPipeline
24-
from executorch.backends.arm.tosa.specification import TosaLoweringContext
25+
from executorch.backends.arm.tosa.specification import (
26+
TosaLoweringContext,
27+
TosaSpecification,
28+
)
2529
from executorch.backends.arm.vgf import VgfCompileSpec, VgfPartitioner
2630
from executorch.exir import EdgeCompileConfig, to_edge, to_edge_transform_and_lower
2731
from executorch.exir.dialects._ops import ops as exir_ops
32+
from torch.export import Dim, export
33+
from torch.export.exported_program import _get_shape_env
2834

2935

3036
class TinyConvReluCat(nn.Module):
@@ -95,12 +101,45 @@ def _get_call_function_node(gm: torch.fx.GraphModule, target):
95101
raise AssertionError(f"Node with target {target} not found")
96102

97103

104+
class ConvModule(torch.nn.Module):
105+
def __init__(self):
106+
super().__init__()
107+
self.conv = torch.nn.Conv2d(3, 16, kernel_size=3, stride=3, padding=0)
108+
109+
def forward(self, x: torch.Tensor) -> torch.Tensor:
110+
return self.conv(x)
111+
112+
113+
def _make_rewrite_pass(
114+
example_inputs: tuple[torch.Tensor, ...],
115+
dynamic_shapes: dict[int, object] | None = None,
116+
) -> tuple[RewriteConvPass, object, int | torch.SymInt]:
117+
if dynamic_shapes is None:
118+
ep = export(ConvModule(), example_inputs)
119+
else:
120+
ep = export(ConvModule(), example_inputs, dynamic_shapes={"x": dynamic_shapes})
121+
edge_model = to_edge(ep)
122+
gm = edge_model.exported_program().graph_module
123+
conv_node = next(
124+
n for n in gm.graph.nodes if n.target == exir_ops.edge.aten.convolution.default
125+
)
126+
input_len = conv_node.args[0].meta["val"].shape[2]
127+
return RewriteConvPass(edge_model.exported_program()), _get_shape_env(gm), input_len
128+
129+
130+
def _multiples_of_three_dynamic_shapes() -> dict[int, object]:
131+
return {
132+
2: Dim("height", min=2, max=6) * 3,
133+
3: Dim("width", min=2, max=6) * 3,
134+
}
135+
136+
98137
def test_rewrite_conv_tosa_FP():
99138
module = DWConvsModule()
100139
pipeline = PassPipeline(
101140
module, module.get_inputs(), passes_with_exported_program=[RewriteConvPass]
102141
)
103-
# We can't run TOSA backend dialect operators in eager mode
142+
# We cannot run TOSA backend dialect operators in eager mode.
104143
pipeline.pop_stage("run_method_and_compare_outputs")
105144
pipeline.run()
106145

@@ -149,3 +188,152 @@ def test_rewrite_conv_vgf_quant_infers_quantized_bias_dtype_from_inputs() -> Non
149188

150189
assert len(bias_nodes) == 1
151190
assert bias_nodes[0].meta["val"].dtype == torch.int32
191+
192+
193+
def test_rewrite_conv_dynamic_keeps_static_padding_when_symbolic_remainder_is_zero():
194+
model = ConvModule()
195+
example_inputs = (torch.randn(1, 3, 9, 12),)
196+
ep = export(
197+
model,
198+
example_inputs,
199+
dynamic_shapes={"x": _multiples_of_three_dynamic_shapes()},
200+
)
201+
edge_model = to_edge(ep)
202+
shape_env = _get_shape_env(edge_model.exported_program().graph_module)
203+
with TosaLoweringContext(
204+
TosaSpecification.create_from_string("TOSA-1.1+FP+shape"), shape_env=shape_env
205+
):
206+
edge_model = edge_model.transform(
207+
[RewriteConvPass(edge_model.exported_program())]
208+
)
209+
210+
conv_node = next(
211+
n
212+
for n in edge_model.exported_program().graph.nodes
213+
if n.target == exir_ops.backend.tosa.CONV2D.default
214+
)
215+
padding = conv_node.args[4]
216+
assert padding == [0, 0, 0, 0]
217+
assert all(not isinstance(p, torch.SymInt) for p in padding)
218+
219+
220+
def test_rewrite_conv_adjust_pad_if_needed_static_raises_before_negative_padding():
221+
rewrite_pass, _, _ = _make_rewrite_pass((torch.randn(1, 3, 9, 12),))
222+
223+
with pytest.raises(RuntimeError, match="SizeAdjustInputPass"):
224+
rewrite_pass._adjust_pad_if_needed(6, 2, 3, 0, 1)
225+
226+
227+
def test_rewrite_conv_adjust_pad_if_needed_static_positive_padding_stays_non_negative():
228+
rewrite_pass, _, _ = _make_rewrite_pass((torch.randn(1, 3, 9, 12),))
229+
230+
adjusted_pad = rewrite_pass._adjust_pad_if_needed(8, 2, 3, 2, 1)
231+
232+
assert adjusted_pad == 1
233+
234+
235+
def test_rewrite_conv_adjust_pad_if_needed_static_exact_remainder_matches_pad():
236+
rewrite_pass, _, _ = _make_rewrite_pass((torch.randn(1, 3, 9, 12),))
237+
238+
adjusted_pad = rewrite_pass._adjust_pad_if_needed(6, 1, 3, 1, 1)
239+
240+
assert adjusted_pad == 0
241+
242+
243+
def test_rewrite_conv_adjust_pad_if_needed_symbolic_exact_zero_keeps_zero_pad():
244+
rewrite_pass, shape_env, input_len = _make_rewrite_pass(
245+
(torch.randn(1, 3, 9, 12),),
246+
dynamic_shapes=_multiples_of_three_dynamic_shapes(),
247+
)
248+
249+
with TosaLoweringContext(
250+
TosaSpecification.create_from_string("TOSA-1.1+FP+shape"), shape_env=shape_env
251+
):
252+
adjusted_pad = rewrite_pass._adjust_pad_if_needed(input_len, 3, 3, 0, 1)
253+
254+
assert adjusted_pad == 0
255+
256+
257+
def test_rewrite_conv_adjust_pad_if_needed_symbolic_exact_zero_keeps_positive_pad():
258+
rewrite_pass, shape_env, input_len = _make_rewrite_pass(
259+
(torch.randn(1, 3, 9, 12),),
260+
dynamic_shapes=_multiples_of_three_dynamic_shapes(),
261+
)
262+
263+
with TosaLoweringContext(
264+
TosaSpecification.create_from_string("TOSA-1.1+FP+shape"), shape_env=shape_env
265+
):
266+
adjusted_pad = rewrite_pass._adjust_pad_if_needed(input_len, 2, 3, 1, 1)
267+
268+
assert adjusted_pad == 1
269+
270+
271+
def test_rewrite_conv_adjust_pad_if_needed_symbolic_positive_padding_range_raises_before_negative_padding():
272+
rewrite_pass, shape_env, input_len = _make_rewrite_pass(
273+
(torch.randn(1, 3, 8, 8),),
274+
dynamic_shapes={
275+
2: Dim("height", min=6, max=10),
276+
3: Dim("width", min=6, max=10),
277+
},
278+
)
279+
280+
with TosaLoweringContext(
281+
TosaSpecification.create_from_string("TOSA-1.1+FP+shape"), shape_env=shape_env
282+
):
283+
with pytest.raises(RuntimeError, match="SizeAdjustInputPass"):
284+
rewrite_pass._adjust_pad_if_needed(input_len, 2, 3, 1, 1)
285+
286+
287+
def test_rewrite_conv_symbolic_comparison_with_int_specializes_to_hint():
288+
rewrite_pass, shape_env, input_len = _make_rewrite_pass(
289+
(torch.randn(1, 3, 8, 8),),
290+
dynamic_shapes={
291+
2: Dim("height", min=6, max=10),
292+
3: Dim("width", min=6, max=10),
293+
},
294+
)
295+
296+
def unsafe_adjust(input_len, input_weight, stride, pad, dilation):
297+
mod_remainder = (
298+
input_len + 2 * pad - dilation * (input_weight - 1) - 1
299+
) % stride
300+
if mod_remainder == 0:
301+
return pad
302+
if mod_remainder > pad:
303+
raise RuntimeError("SizeAdjustInputPass")
304+
return pad - mod_remainder
305+
306+
mod_remainder = (input_len - 2) % 3
307+
value_ranges = shape_env.bound_sympy(mod_remainder.node.expr)
308+
309+
assert value_ranges.lower == 0
310+
assert value_ranges.upper == 2
311+
assert len(shape_env.guards) == 0
312+
assert unsafe_adjust(input_len, 2, 3, 0, 1) == 0
313+
assert len(shape_env.guards) == 1
314+
assert shape_env.guards[-1].expr in {
315+
(mod_remainder == 0).node.expr,
316+
(mod_remainder <= 0).node.expr,
317+
}
318+
319+
with TosaLoweringContext(
320+
TosaSpecification.create_from_string("TOSA-1.1+FP+shape"), shape_env=shape_env
321+
):
322+
with pytest.raises(RuntimeError, match="SizeAdjustInputPass"):
323+
rewrite_pass._adjust_pad_if_needed(input_len, 2, 3, 0, 1)
324+
325+
326+
def test_rewrite_conv_adjust_pad_if_needed_symbolic_zero_padding_range_raises_before_negative_padding():
327+
rewrite_pass, shape_env, input_len = _make_rewrite_pass(
328+
(torch.randn(1, 3, 8, 8),),
329+
dynamic_shapes={
330+
2: Dim("height", min=6, max=10),
331+
3: Dim("width", min=6, max=10),
332+
},
333+
)
334+
335+
with TosaLoweringContext(
336+
TosaSpecification.create_from_string("TOSA-1.1+FP+shape"), shape_env=shape_env
337+
):
338+
with pytest.raises(RuntimeError, match="SizeAdjustInputPass"):
339+
rewrite_pass._adjust_pad_if_needed(input_len, 2, 3, 0, 1)

0 commit comments

Comments (0)