Skip to content

Commit 651f2f2

Browse files
Arm backend: Enable dynamic shapes in DecomposeMeanDimPass (#18774)
With the upcoming SymInt support, enable dynamic shapes in DecomposeMeanDimPass. The avg_pool2d operator is currently not supported for dynamic shapes. Signed-off-by: Saoirse Stewart <saoirse.stewart@arm.com> Co-authored-by: Oscar Andersson <oscar.andersson@arm.com>
1 parent 1533e55 commit 651f2f2

1 file changed

Lines changed: 92 additions & 29 deletions

File tree

backends/arm/_passes/decompose_meandim_pass.py

Lines changed: 92 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,22 @@ def get_meandim_decomposition(op) -> tuple:
3535
raise RuntimeError(f"Can't get meandim decomposition for op {op}")
3636

3737

38+
def get_dynamic_meandim_decomposition(op) -> tuple:
39+
if op in (exir_ops.edge.aten.mean.dim, exir_ops.edge.aten.mean.default):
40+
return (
41+
exir_ops.edge.aten.sum.dim_IntList,
42+
exir_ops.edge.aten.mul.Tensor,
43+
exir_ops.edge.aten.full.default,
44+
exir_ops.edge.aten.reciprocal.default,
45+
exir_ops.edge.aten.expand_copy.default,
46+
)
47+
if op in (torch.ops.aten.mean.dim, torch.ops.aten.mean.default):
48+
raise NotImplementedError(
49+
"Dynamic mean.dim decomposition is not supported for torch.aten.mean."
50+
)
51+
raise RuntimeError(f"Can't get meandim decomposition for op {op}")
52+
53+
3854
def get_avgpool(op):
3955
if op in (exir_ops.edge.aten.mean.dim, exir_ops.edge.aten.mean.default):
4056
return exir_ops.edge.aten.avg_pool2d.default
@@ -103,26 +119,39 @@ def __init__(self, graph_module, tosa_spec, *args, **kwargs):
103119
self._tosa_spec, WhyNoPartitionReporter()
104120
)
105121

106-
def call_operator(self, op, args, kwargs, meta):
122+
def call_operator(self, op, args, kwargs, meta, updated=False):
107123
if op not in (
108124
exir_ops.edge.aten.mean.dim,
109125
torch.ops.aten.mean.dim,
110126
exir_ops.edge.aten.mean.default,
111127
torch.ops.aten.mean.default,
112128
) or not self.allowed_to_transform(meta):
113-
return super().call_operator(op, args, kwargs, meta)
129+
return super().call_operator(op, args, kwargs, meta, updated)
114130

115131
x = get_node_arg(args, 0)
116132
input_shape = list(x.data.shape)
117133
output_shape = list(meta["val"].shape)
134+
118135
dims_to_reduce = get_node_arg(args, 1, range(len(input_shape)))
119136
if dims_to_reduce is None:
120137
dims_to_reduce = range(len(input_shape))
138+
121139
dims_to_reduce = [dim % len(input_shape) for dim in dims_to_reduce]
122-
dims_to_reduce = [dim for dim in dims_to_reduce if input_shape[dim] != 1]
140+
141+
has_symbolic_reduce_dim = any(
142+
isinstance(input_shape[dim], torch.SymInt) for dim in dims_to_reduce
143+
)
144+
if has_symbolic_reduce_dim and get_quantization(x.node.target) is not None:
145+
raise NotImplementedError(
146+
"Quantized mean.dim with symbolic reduced dimensions is not supported"
147+
)
123148

124149
view_op = get_view(op)
125150

151+
if not has_symbolic_reduce_dim:
152+
# for static shapes we should ensure that we only keep non 1 dimensions.
153+
dims_to_reduce = [dim for dim in dims_to_reduce if input_shape[dim] != 1]
154+
126155
# Reshape to 4D
127156
if len(input_shape) != 4:
128157
new_shape = copy(input_shape)
@@ -140,26 +169,66 @@ def call_operator(self, op, args, kwargs, meta):
140169
x = self._maybe_insert_q_dq_after(x, meta)
141170

142171
# Reduce (h,w) dims by avg pool if possible
143-
x, dims_to_reduce = self._reduce_by_average_pool(op, x, dims_to_reduce, meta)
172+
if not has_symbolic_reduce_dim:
173+
x, dims_to_reduce = self._reduce_by_average_pool(
174+
op, x, dims_to_reduce, meta
175+
)
144176

145177
# Reshape back to 5D if necessary
146178
if len(input_shape) > 4:
147-
original_dims = input_shape[0:-3]
179+
original_dims = input_shape[:-3]
148180
temp_shape = list(x.data.shape)[1:]
149181
temp_shape = original_dims + temp_shape
150182
dims_to_reduce = [dim + len(original_dims) - 1 for dim in dims_to_reduce]
151183

152184
x = super().call_operator(view_op, (x, temp_shape), {}, meta, True)
153185
x = self._maybe_insert_q_dq_after(x, meta)
154-
# Reduce remaining dims by sum
155-
x = self._reduce_by_sum(op, x, dims_to_reduce, meta)
186+
187+
if has_symbolic_reduce_dim:
188+
x = self._reduce_by_sum_symbolic(op, x, dims_to_reduce, meta)
189+
else:
190+
x = self._reduce_by_sum(op, x, dims_to_reduce, meta)
156191

157192
# Reshape to correct output shape if necessary
158193
if list(x.data.shape) != output_shape:
159194
x = super().call_operator(view_op, (x, output_shape), {}, meta, True)
160195

161196
return x
162197

198+
def _reduce_by_sum_symbolic(self, op, input_node, dims, meta):
199+
input_shape = input_node.data.size()
200+
reduced_shape = [input_shape[dim] for dim in dims]
201+
202+
sum_op, mul_op, full_op, recip_op, expand_op = (
203+
get_dynamic_meandim_decomposition(op)
204+
)
205+
206+
sum = super().call_operator(sum_op, (input_node, dims, True), {}, meta, True)
207+
208+
ones = super().call_operator(
209+
full_op,
210+
([1], 1.0),
211+
{"dtype": meta.data["val"].dtype, "device": input_node.data.device},
212+
meta,
213+
True,
214+
)
215+
expanded_ones = super().call_operator(
216+
expand_op,
217+
(ones, reduced_shape),
218+
{},
219+
meta,
220+
True,
221+
)
222+
counts = super().call_operator(
223+
sum_op,
224+
(expanded_ones, list(range(len(reduced_shape))), True),
225+
{},
226+
meta,
227+
True,
228+
)
229+
recip = super().call_operator(recip_op, (counts,), {}, meta, True)
230+
return super().call_operator(mul_op, (sum, recip), {}, meta, True)
231+
163232
def _reduce_by_sum(self, op, input_node, dims, meta):
164233
if len(dims) == 0:
165234
return input_node
@@ -224,13 +293,9 @@ def _reduce_by_average_pool(self, op, input_node, dims, meta):
224293
if is_supported:
225294
out = super().call_operator(avgpool_op, args, {}, meta, True)
226295
out = self._maybe_insert_q_dq_after(out, meta)
227-
return (
228-
out,
229-
dims_to_reduce_by_sum,
230-
)
296+
return out, dims_to_reduce_by_sum
231297

232-
else:
233-
return input_node, dims
298+
return input_node, dims
234299

235300
def _maybe_insert_q_dq_after(self, op, meta):
236301
"""If the input node of op is a dequant node, insert a q-dq pair after
@@ -242,20 +307,18 @@ def _maybe_insert_q_dq_after(self, op, meta):
242307
f"Expected one input to {op.node}, got inputs {op.node.all_input_nodes}"
243308
)
244309
input_node = op.node.all_input_nodes[0]
245-
if (quant_ops := get_quantization(input_node.target)) is not None:
246-
q_op, dq_op = quant_ops
247-
quant_args = list(input_node.args[1:])
248-
q_args = (op, *quant_args)
249-
out = super().call_operator(
250-
q_op,
251-
q_args,
252-
kwargs={},
253-
meta=meta,
254-
updated=True,
255-
)
256-
dq_args = (out, *quant_args)
257-
return super().call_operator(
258-
dq_op, dq_args, kwargs={}, meta=meta, updated=True
259-
)
260-
else:
310+
if (quant_ops := get_quantization(input_node.target)) is None:
261311
return op
312+
313+
q_op, dq_op = quant_ops
314+
quant_args = list(input_node.args[1:])
315+
q_args = (op, *quant_args)
316+
out = super().call_operator(
317+
q_op,
318+
q_args,
319+
kwargs={},
320+
meta=meta,
321+
updated=True,
322+
)
323+
dq_args = (out, *quant_args)
324+
return super().call_operator(dq_op, dq_args, kwargs={}, meta=meta, updated=True)

0 commit comments

Comments
 (0)