@@ -35,6 +35,22 @@ def get_meandim_decomposition(op) -> tuple:
3535 raise RuntimeError (f"Can't get meandim decomposition for op { op } " )
3636
3737
def get_dynamic_meandim_decomposition(op) -> tuple:
    """Return the edge-dialect ops used to decompose a dynamic-shape mean.

    The tuple is ordered (sum, mul, full, reciprocal, expand_copy), matching
    how the caller unpacks it when building sum * reciprocal(count).

    Raises:
        NotImplementedError: for torch.ops.aten mean variants, which this
            dynamic decomposition does not handle.
        RuntimeError: for any other op.
    """
    edge_mean_ops = (
        exir_ops.edge.aten.mean.dim,
        exir_ops.edge.aten.mean.default,
    )
    torch_mean_ops = (
        torch.ops.aten.mean.dim,
        torch.ops.aten.mean.default,
    )
    if op in edge_mean_ops:
        return (
            exir_ops.edge.aten.sum.dim_IntList,
            exir_ops.edge.aten.mul.Tensor,
            exir_ops.edge.aten.full.default,
            exir_ops.edge.aten.reciprocal.default,
            exir_ops.edge.aten.expand_copy.default,
        )
    if op in torch_mean_ops:
        raise NotImplementedError(
            "Dynamic mean.dim decomposition is not supported for torch.aten.mean."
        )
    raise RuntimeError(f"Can't get meandim decomposition for op {op}")
53+
3854def get_avgpool (op ):
3955 if op in (exir_ops .edge .aten .mean .dim , exir_ops .edge .aten .mean .default ):
4056 return exir_ops .edge .aten .avg_pool2d .default
@@ -103,26 +119,39 @@ def __init__(self, graph_module, tosa_spec, *args, **kwargs):
103119 self ._tosa_spec , WhyNoPartitionReporter ()
104120 )
105121
106- def call_operator (self , op , args , kwargs , meta ):
122+ def call_operator (self , op , args , kwargs , meta , updated = False ):
107123 if op not in (
108124 exir_ops .edge .aten .mean .dim ,
109125 torch .ops .aten .mean .dim ,
110126 exir_ops .edge .aten .mean .default ,
111127 torch .ops .aten .mean .default ,
112128 ) or not self .allowed_to_transform (meta ):
113- return super ().call_operator (op , args , kwargs , meta )
129+ return super ().call_operator (op , args , kwargs , meta , updated )
114130
115131 x = get_node_arg (args , 0 )
116132 input_shape = list (x .data .shape )
117133 output_shape = list (meta ["val" ].shape )
134+
118135 dims_to_reduce = get_node_arg (args , 1 , range (len (input_shape )))
119136 if dims_to_reduce is None :
120137 dims_to_reduce = range (len (input_shape ))
138+
121139 dims_to_reduce = [dim % len (input_shape ) for dim in dims_to_reduce ]
122- dims_to_reduce = [dim for dim in dims_to_reduce if input_shape [dim ] != 1 ]
140+
141+ has_symbolic_reduce_dim = any (
142+ isinstance (input_shape [dim ], torch .SymInt ) for dim in dims_to_reduce
143+ )
144+ if has_symbolic_reduce_dim and get_quantization (x .node .target ) is not None :
145+ raise NotImplementedError (
146+ "Quantized mean.dim with symbolic reduced dimensions is not supported"
147+ )
123148
124149 view_op = get_view (op )
125150
151+ if not has_symbolic_reduce_dim :
152+ # for static shapes we should ensure that we only keep non 1 dimensions.
153+ dims_to_reduce = [dim for dim in dims_to_reduce if input_shape [dim ] != 1 ]
154+
126155 # Reshape to 4D
127156 if len (input_shape ) != 4 :
128157 new_shape = copy (input_shape )
@@ -140,26 +169,66 @@ def call_operator(self, op, args, kwargs, meta):
140169 x = self ._maybe_insert_q_dq_after (x , meta )
141170
142171 # Reduce (h,w) dims by avg pool if possible
143- x , dims_to_reduce = self ._reduce_by_average_pool (op , x , dims_to_reduce , meta )
172+ if not has_symbolic_reduce_dim :
173+ x , dims_to_reduce = self ._reduce_by_average_pool (
174+ op , x , dims_to_reduce , meta
175+ )
144176
145177 # Reshape back to 5D if necessary
146178 if len (input_shape ) > 4 :
147- original_dims = input_shape [0 :- 3 ]
179+ original_dims = input_shape [:- 3 ]
148180 temp_shape = list (x .data .shape )[1 :]
149181 temp_shape = original_dims + temp_shape
150182 dims_to_reduce = [dim + len (original_dims ) - 1 for dim in dims_to_reduce ]
151183
152184 x = super ().call_operator (view_op , (x , temp_shape ), {}, meta , True )
153185 x = self ._maybe_insert_q_dq_after (x , meta )
154- # Reduce remaining dims by sum
155- x = self ._reduce_by_sum (op , x , dims_to_reduce , meta )
186+
187+ if has_symbolic_reduce_dim :
188+ x = self ._reduce_by_sum_symbolic (op , x , dims_to_reduce , meta )
189+ else :
190+ x = self ._reduce_by_sum (op , x , dims_to_reduce , meta )
156191
157192 # Reshape to correct output shape if necessary
158193 if list (x .data .shape ) != output_shape :
159194 x = super ().call_operator (view_op , (x , output_shape ), {}, meta , True )
160195
161196 return x
162197
198+ def _reduce_by_sum_symbolic (self , op , input_node , dims , meta ):
199+ input_shape = input_node .data .size ()
200+ reduced_shape = [input_shape [dim ] for dim in dims ]
201+
202+ sum_op , mul_op , full_op , recip_op , expand_op = (
203+ get_dynamic_meandim_decomposition (op )
204+ )
205+
206+ sum = super ().call_operator (sum_op , (input_node , dims , True ), {}, meta , True )
207+
208+ ones = super ().call_operator (
209+ full_op ,
210+ ([1 ], 1.0 ),
211+ {"dtype" : meta .data ["val" ].dtype , "device" : input_node .data .device },
212+ meta ,
213+ True ,
214+ )
215+ expanded_ones = super ().call_operator (
216+ expand_op ,
217+ (ones , reduced_shape ),
218+ {},
219+ meta ,
220+ True ,
221+ )
222+ counts = super ().call_operator (
223+ sum_op ,
224+ (expanded_ones , list (range (len (reduced_shape ))), True ),
225+ {},
226+ meta ,
227+ True ,
228+ )
229+ recip = super ().call_operator (recip_op , (counts ,), {}, meta , True )
230+ return super ().call_operator (mul_op , (sum , recip ), {}, meta , True )
231+
163232 def _reduce_by_sum (self , op , input_node , dims , meta ):
164233 if len (dims ) == 0 :
165234 return input_node
@@ -224,13 +293,9 @@ def _reduce_by_average_pool(self, op, input_node, dims, meta):
224293 if is_supported :
225294 out = super ().call_operator (avgpool_op , args , {}, meta , True )
226295 out = self ._maybe_insert_q_dq_after (out , meta )
227- return (
228- out ,
229- dims_to_reduce_by_sum ,
230- )
296+ return out , dims_to_reduce_by_sum
231297
232- else :
233- return input_node , dims
298+ return input_node , dims
234299
235300 def _maybe_insert_q_dq_after (self , op , meta ):
236301 """If the input node of op is a dequant node, insert a q-dq pair after
@@ -242,20 +307,18 @@ def _maybe_insert_q_dq_after(self, op, meta):
242307 f"Expected one input to { op .node } , got inputs { op .node .all_input_nodes } "
243308 )
244309 input_node = op .node .all_input_nodes [0 ]
245- if (quant_ops := get_quantization (input_node .target )) is not None :
246- q_op , dq_op = quant_ops
247- quant_args = list (input_node .args [1 :])
248- q_args = (op , * quant_args )
249- out = super ().call_operator (
250- q_op ,
251- q_args ,
252- kwargs = {},
253- meta = meta ,
254- updated = True ,
255- )
256- dq_args = (out , * quant_args )
257- return super ().call_operator (
258- dq_op , dq_args , kwargs = {}, meta = meta , updated = True
259- )
260- else :
310+ if (quant_ops := get_quantization (input_node .target )) is None :
261311 return op
312+
313+ q_op , dq_op = quant_ops
314+ quant_args = list (input_node .args [1 :])
315+ q_args = (op , * quant_args )
316+ out = super ().call_operator (
317+ q_op ,
318+ q_args ,
319+ kwargs = {},
320+ meta = meta ,
321+ updated = True ,
322+ )
323+ dq_args = (out , * quant_args )
324+ return super ().call_operator (dq_op , dq_args , kwargs = {}, meta = meta , updated = True )
0 commit comments