Skip to content

Commit 341039f

Browse files
GregoryComer authored and jpiat committed
Update emformer tests to avoid 0/1 specialization issue (pytorch#18850)
Summary: The tests break in some dynamo configurations. Set the minimum batch size to 2 so the model is exportable in all configurations. Reviewed By: abhinaykukkadapu Differential Revision: D100391669
1 parent d8cd9cc commit 341039f

File tree

2 files changed

+15
-2
lines changed

2 files changed

+15
-2
lines changed

backends/test/harness/tester.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,9 @@ def generate_random_inputs(self):
131131
assert isinstance(self.example_inputs[arg_idx], torch.Tensor)
132132
ex_shape = list(self.example_inputs[arg_idx].shape)
133133
dynamic_dim_spec = self.dynamic_shapes[arg_idx]
134+
if dynamic_dim_spec is None or dynamic_dim_spec == {}:
135+
input_shapes.append(torch.Size(ex_shape))
136+
continue
134137
for dim_idx, dim_spec in dynamic_dim_spec.items():
135138
assert dim_idx < len(ex_shape)
136139
if isinstance(dim_spec, torch.export.dynamic_shapes._DerivedDim):

backends/xnnpack/test/models/emformer_rnnt.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,14 +50,20 @@ def test_fp32_emformer_joiner(self):
5050

5151
def test_fp32_emformer_joiner_dynamic(self):
5252
joiner = self.Joiner()
53+
example_inputs = (
54+
torch.rand([2, 128, 1024]),
55+
torch.tensor([128]),
56+
torch.rand([2, 128, 1024]),
57+
torch.tensor([128]),
58+
)
5359
dynamic_shapes = (
5460
{0: torch.export.Dim("batch", min=1, max=4)},
5561
None,
5662
{0: torch.export.Dim("batch", min=1, max=4)},
5763
None,
5864
)
5965
(
60-
Tester(joiner, joiner.get_example_inputs(), dynamic_shapes=dynamic_shapes)
66+
Tester(joiner, example_inputs, dynamic_shapes=dynamic_shapes)
6167
.export()
6268
.to_edge_transform_and_lower()
6369
.check(["torch.ops.higher_order.executorch_call_delegate"])
@@ -117,14 +123,18 @@ def test_fp32_emformer_transcriber(self):
117123

118124
def test_fp32_emformer_transcriber_dynamic(self):
119125
transcriber = self.Transcriber()
126+
example_inputs = (
127+
torch.randn(2, 128, 80),
128+
torch.tensor([128]),
129+
)
120130
dynamic_shapes = (
121131
{0: torch.export.Dim("batch", min=1, max=4)},
122132
None,
123133
)
124134
(
125135
Tester(
126136
transcriber,
127-
transcriber.get_example_inputs(),
137+
example_inputs,
128138
dynamic_shapes=dynamic_shapes,
129139
)
130140
.export()

0 commit comments

Comments (0)