Skip to content

Commit c0c1012

Browse files
committed
Fix egraph explosion from expanding rewrite rules
Disable or reverse direction of rules that create more nodes than they consume, causing exponential e-graph growth in matrix-heavy shaders: - bitwise.egg: Reverse De Morgan direction to pull BitNot OUTSIDE (simplifying 3→1 nodes) instead of pushing INSIDE (expanding 1→3) - rvsdg.egg: Disable Gamma distribution INTO rules (Op(Gamma,Gamma) → Gamma(Op,Op)) which create new inner nodes and cycle with the hoisting rules that pull common factors OUT - glsl.egg: Disable Fma recognition (FAdd(FMul(a,b),c) → Fma(a,b,c)) which fires on every float multiply-add, creating massive node counts in matrix-heavy code - arithmetic.egg: Reverse Neg distribution to pull Neg OUTSIDE Add (factoring 2 Negs → 1) instead of pushing INSIDE (expanding 1 → 2) - vector.egg: Disable Dot distribution over VecAdd and VecTimesScalar distribution over VecAdd, both of which create 2+ new nodes per firing
1 parent 86dfbaf commit c0c1012

5 files changed

Lines changed: 49 additions & 47 deletions

File tree

rust/spirv-tools-opt/src/rules/arithmetic.egg

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,10 @@
3232
((union e (Neg x))))
3333
(rule ((= e (Mul (Const -1) x)))
3434
((union e (Neg x))))
35-
; Neg distribution - ONE-DIRECTIONAL: push Neg INSIDE Add/Sub, not OUTSIDE
36-
; This avoids Add(Neg x, Neg y) → Neg(Add(x, y)) → Add(Neg x, Neg y) cycles
37-
(rule ((= e (Neg (Add x y))))
38-
((union e (Add (Neg x) (Neg y)))))
35+
; Neg factoring - ONE-DIRECTIONAL: pull Neg OUTSIDE Add (simplifying 2 Negs → 1)
36+
; The reverse (pushing Neg inside) creates 2 new Neg nodes causing explosion.
37+
(rule ((= e (Add (Neg x) (Neg y))))
38+
((union e (Neg (Add x y)))))
3939
(rule ((= e (Neg (Sub x y))))
4040
((union e (Sub y x))))
4141
(rule ((= e (Add x (Neg x))))

rust/spirv-tools-opt/src/rules/bitwise.egg

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,11 @@
3333
((union e x)))
3434

3535
; De Morgan's laws - ONE-DIRECTIONAL to prevent explosion
36-
; We push BitNot INSIDE, not pull OUTSIDE (which would expand forever)
37-
(rule ((= e (BitNot (BitAnd x y))))
38-
((union e (BitOr (BitNot x) (BitNot y)))))
39-
(rule ((= e (BitNot (BitOr x y))))
40-
((union e (BitAnd (BitNot x) (BitNot y)))))
36+
; We pull BitNot OUTSIDE (simplifying 3 nodes → 1), not push INSIDE (which expands 1 → 3)
37+
(rule ((= e (BitOr (BitNot x) (BitNot y))))
38+
((union e (BitNot (BitAnd x y)))))
39+
(rule ((= e (BitAnd (BitNot x) (BitNot y))))
40+
((union e (BitNot (BitOr x y)))))
4141

4242
; Complement rules
4343
(rule ((= e (BitAnd x (BitNot x))))

rust/spirv-tools-opt/src/rules/glsl.egg

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -302,12 +302,13 @@
302302
; Fma(a, b, c) = a*b + c (definition, for expansion if needed)
303303
; But prefer Fma when available (single rounding, faster)
304304

305-
; Recognize Fma patterns - ONE-DIRECTIONAL (prefer Fma over expanded form)
306-
; We don't expand Fma to FAdd(FMul...) since that's worse for performance
307-
(rule ((= e (FAdd (FMul a b) c)))
308-
((union e (Fma a b c))))
309-
(rule ((= e (FAdd c (FMul a b))))
310-
((union e (Fma a b c))))
305+
; DISABLED: Fma recognition — fires on every FAdd(FMul(...), ...) pattern,
306+
; creating new Fma nodes that cause exponential growth in matrix-heavy shaders.
307+
; Fma simplification rules below still work if Fma is already in the graph.
308+
; (rule ((= e (FAdd (FMul a b) c)))
309+
; ((union e (Fma a b c))))
310+
; (rule ((= e (FAdd c (FMul a b))))
311+
; ((union e (Fma a b c))))
311312

312313
; Fma with zero - ONE-DIRECTIONAL (simplify, don't expand)
313314
(rule ((= e (Fma a b (FConst c))) (= c 0.0))

rust/spirv-tools-opt/src/rules/rvsdg.egg

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -142,22 +142,21 @@
142142
; select(c, true, c) = c | true = true when c is boolean
143143
; But this is the same as select(c, 1, c) = c ? 1 : c which is just 1 when c is true, c when false
144144

145-
; Gamma distributes over operations (hoisting) - ONE-DIRECTIONAL
146-
; Push operations INTO Gamma, not out (pulling out creates Op(Gamma, Gamma) which explodes)
147-
; These rules combine two Gammas with same condition into a single Gamma
148-
; Integer operations use GammaI
149-
(rule ((= e (Add (GammaI c a b) (GammaI c x y))))
150-
((union e (GammaI c (Add a x) (Add b y)))))
151-
(rule ((= e (Sub (GammaI c a b) (GammaI c x y))))
152-
((union e (GammaI c (Sub a x) (Sub b y)))))
153-
(rule ((= e (Mul (GammaI c a b) (GammaI c x y))))
154-
((union e (GammaI c (Mul a x) (Mul b y)))))
155-
(rule ((= e (BitAnd (GammaI c a b) (GammaI c x y))))
156-
((union e (GammaI c (BitAnd a x) (BitAnd b y)))))
157-
(rule ((= e (BitOr (GammaI c a b) (GammaI c x y))))
158-
((union e (GammaI c (BitOr a x) (BitOr b y)))))
159-
(rule ((= e (BitXor (GammaI c a b) (GammaI c x y))))
160-
((union e (GammaI c (BitXor a x) (BitXor b y)))))
145+
; DISABLED: Gamma distribution INTO — creates new Op nodes inside both arms,
146+
; forming cycles with the hoisting rules below (which pull common factors OUT).
147+
; The hoisting rules (lines 166+) are the simplifying direction and are kept.
148+
; (rule ((= e (Add (GammaI c a b) (GammaI c x y))))
149+
; ((union e (GammaI c (Add a x) (Add b y)))))
150+
; (rule ((= e (Sub (GammaI c a b) (GammaI c x y))))
151+
; ((union e (GammaI c (Sub a x) (Sub b y)))))
152+
; (rule ((= e (Mul (GammaI c a b) (GammaI c x y))))
153+
; ((union e (GammaI c (Mul a x) (Mul b y)))))
154+
; (rule ((= e (BitAnd (GammaI c a b) (GammaI c x y))))
155+
; ((union e (GammaI c (BitAnd a x) (BitAnd b y)))))
156+
; (rule ((= e (BitOr (GammaI c a b) (GammaI c x y))))
157+
; ((union e (GammaI c (BitOr a x) (BitOr b y)))))
158+
; (rule ((= e (BitXor (GammaI c a b) (GammaI c x y))))
159+
; ((union e (GammaI c (BitXor a x) (BitXor b y)))))
161160

162161
; Hoist common operations out of Gamma - ONE-DIRECTIONAL
163162
; (c ? (x + a) : (x + b)) => x + (c ? a : b)
@@ -409,13 +408,14 @@
409408
; =============================================================================
410409
; Floating-Point Gamma Distribution - ONE-DIRECTIONAL
411410
; =============================================================================
412-
; Distribution INTO GammaF (combining two conditional float values)
413-
(rule ((= e (FAdd (GammaF c a b) (GammaF c x y))))
414-
((union e (GammaF c (FAdd a x) (FAdd b y)))))
415-
(rule ((= e (FSub (GammaF c a b) (GammaF c x y))))
416-
((union e (GammaF c (FSub a x) (FSub b y)))))
417-
(rule ((= e (FMul (GammaF c a b) (GammaF c x y))))
418-
((union e (GammaF c (FMul a x) (FMul b y)))))
411+
; DISABLED: Distribution INTO GammaF — creates new FOp nodes inside both arms,
412+
; forming cycles with the hoisting rules below.
413+
; (rule ((= e (FAdd (GammaF c a b) (GammaF c x y))))
414+
; ((union e (GammaF c (FAdd a x) (FAdd b y)))))
415+
; (rule ((= e (FSub (GammaF c a b) (GammaF c x y))))
416+
; ((union e (GammaF c (FSub a x) (FSub b y)))))
417+
; (rule ((= e (FMul (GammaF c a b) (GammaF c x y))))
418+
; ((union e (GammaF c (FMul a x) (FMul b y)))))
419419

420420
; Hoisting common factors OUT of GammaF - ONE-DIRECTIONAL
421421
(rule ((= e (GammaF c (FAdd x a) (FAdd x b))))

rust/spirv-tools-opt/src/rules/vector.egg

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -394,9 +394,10 @@
394394
; Vector Operations Distributivity
395395
; =============================================================================
396396

397-
; VecTimesScalar distributes over VecAdd - ONE-DIRECTIONAL (distribute, don't factor)
398-
(rule ((= e (VecTimesScalar (VecAdd a b) s)))
399-
((union e (VecAdd (VecTimesScalar a s) (VecTimesScalar b s)))))
397+
; DISABLED: VecTimesScalar distribution over VecAdd — creates 2 new VecTimesScalar
398+
; nodes per rule firing, causing exponential growth in matrix-heavy shaders.
399+
; (rule ((= e (VecTimesScalar (VecAdd a b) s)))
400+
; ((union e (VecAdd (VecTimesScalar a s) (VecTimesScalar b s)))))
400401

401402
; VecTimesScalar with scalar multiplication - ONE-DIRECTIONAL (combine scalars)
402403
; Note: s1 and s2 are Expr (from VecTimesScalar args), but Mul takes IntExpr.
@@ -430,12 +431,12 @@
430431
; Dot Product Additional Properties
431432
; =============================================================================
432433

433-
; Dot product distributes over VecAdd - ONE-DIRECTIONAL (expand Dot)
434-
; Note: Dot returns FloatExpr, so use FAdd (FloatExpr + FloatExpr -> FloatExpr)
435-
(rule ((= e (Dot (VecAdd a b) c)))
436-
((union e (FAdd (Dot a c) (Dot b c)))))
437-
(rule ((= e (Dot a (VecAdd b c))))
438-
((union e (FAdd (Dot a b) (Dot a c)))))
434+
; DISABLED: Dot distribution over VecAdd — creates 2 new Dot nodes per rule,
435+
; causing exponential growth. Matrix multiply = many Dot(VecAdd(...), ...) patterns.
436+
; (rule ((= e (Dot (VecAdd a b) c)))
437+
; ((union e (FAdd (Dot a c) (Dot b c)))))
438+
; (rule ((= e (Dot a (VecAdd b c))))
439+
; ((union e (FAdd (Dot a b) (Dot a c)))))
439440

440441
; Dot product distributes over scalar multiplication - ONE-DIRECTIONAL
441442
; Note: Dot returns FloatExpr, s is Expr (from VecTimesScalar), use FMul with ExprToFloat

0 commit comments

Comments
 (0)