Skip to content

Commit 96e17dd

Browse files
committed
Fix BitAnd+Add carry bug and ShrU arithmetic shift bug
- Disable mask & (x + c) rules: bits-disjoint ignores carry propagation (e.g. mask=4, c=2, x=2: (2+2)&4=4 but 2&4=0) - Fix ShrU constant folding: use u32-shr primitive for logical shift instead of egglog's >> which is i64 arithmetic shift - Fix Shl constant folding: use u32-shl for proper 32-bit semantics
1 parent 8404c2c commit 96e17dd

5 files changed

Lines changed: 38 additions & 18 deletions

File tree

rust/spirv-tools-opt/src/egglog_opt/mod.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -998,6 +998,14 @@ pub fn create_spirv_egraph() -> Result<EGraph, EgglogOptError> {
998998
);
999999
add_primitive!(&mut egraph, "u32-div" = |a: i64, b: i64| -?> i64 { u32_div(a, b) });
10001000
add_primitive!(&mut egraph, "u32-mod" = |a: i64, b: i64| -?> i64 { u32_mod(a, b) });
1001+
add_primitive!(
1002+
&mut egraph,
1003+
"u32-shr" = |a: i64, b: i64| -> i64 { u32_shr(a, b) }
1004+
);
1005+
add_primitive!(
1006+
&mut egraph,
1007+
"u32-shl" = |a: i64, b: i64| -> i64 { u32_shl(a, b) }
1008+
);
10011009

10021010
// NaN-aware float comparison primitives (FOrd* returns 0 if NaN, FUnord* returns 1 if NaN)
10031011
add_primitive!(

rust/spirv-tools-opt/src/egglog_opt/primitives/bitwise.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,22 @@ pub fn u32_mod(a: i64, b: i64) -> Option<i64> {
253253
}
254254
}
255255

256+
/// Logical (unsigned) right shift: cast to u32, shift, sign-extend back to i64.
257+
/// egglog's native `>>` is arithmetic (sign-extending), which is wrong for ShrU.
258+
pub fn u32_shr(a: i64, b: i64) -> i64 {
259+
let a = a as u32;
260+
let b = (b as u32) & 31; // mask shift amount to 0-31
261+
(a >> b) as i32 as i64
262+
}
263+
264+
/// Logical left shift: cast to u32, shift, sign-extend back to i64.
265+
/// egglog's native `<<` on i64 can produce results wider than 32 bits.
266+
pub fn u32_shl(a: i64, b: i64) -> i64 {
267+
let a = a as u32;
268+
let b = (b as u32) & 31;
269+
(a << b) as i32 as i64
270+
}
271+
256272
// =============================================================================
257273
// Type conversion primitives (cross-type: F <-> i64)
258274
// =============================================================================

rust/spirv-tools-opt/src/egglog_opt/primitives/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ pub(super) use bitwise::{
1010
funord_ge, funord_gt, funord_le, funord_lt, funord_ne, has_exact_recip, int_to_float_signed,
1111
int_to_float_unsigned, is_float_one64, is_float_zero64, is_pow2, log2_pow2, mask_superset,
1212
popcount, sdiv32, shl_clears_mask, shr_clears_mask, smod, srem32, u32_div, u32_ge, u32_gt,
13-
u32_le, u32_lt, u32_max, u32_min, u32_mod,
13+
u32_le, u32_lt, u32_max, u32_min, u32_mod, u32_shl, u32_shr,
1414
};
1515

1616
#[allow(unused_imports)]

rust/spirv-tools-opt/src/rules/constant_folding.egg

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,9 @@
2020
(rule ((= e (BitNot (Const a))))
2121
((union e (Const (not-i64 a)))))
2222
(rule ((= e (Shl (Const a) (Const b))))
23-
((union e (Const (<< a b)))))
23+
((union e (Const (u32-shl a b)))))
2424
(rule ((= e (ShrU (Const a) (Const b))))
25-
((union e (Const (>> a b)))))
25+
((union e (Const (u32-shr a b)))))
2626

2727
; Signed right shift of constants
2828
; Note: egglog's >> is arithmetic shift for i64

rust/spirv-tools-opt/src/rules/primitives.egg

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -20,21 +20,17 @@
2020
; Mask-Aware Optimizations
2121
; =============================================================================
2222

23-
; mask & (x + const) = mask & x when const has no bits in mask
24-
(rule ((= e (BitAnd (Const mask) (Add x (Const c))))
25-
(bits-disjoint mask c))
26-
((union e (BitAnd (Const mask) x))))
27-
(rule ((= e (BitAnd (Add x (Const c)) (Const mask)))
28-
(bits-disjoint mask c))
29-
((union e (BitAnd x (Const mask)))))
30-
31-
; mask & (x - const) = mask & x when const has no bits in mask
32-
(rule ((= e (BitAnd (Const mask) (Sub x (Const c))))
33-
(bits-disjoint mask c))
34-
((union e (BitAnd (Const mask) x))))
35-
(rule ((= e (BitAnd (Sub x (Const c)) (Const mask)))
36-
(bits-disjoint mask c))
37-
((union e (BitAnd x (Const mask)))))
23+
; DISABLED: mask & (x + const) = mask & x — incorrect due to carry propagation.
24+
; bits-disjoint(mask, c) only checks mask & c == 0, but adding c causes carries
25+
; that can affect bits in mask. E.g. mask=4, c=2, x=2: (2+2)&4=4 but 2&4=0.
26+
; (rule ((= e (BitAnd (Const mask) (Add x (Const c)))) (bits-disjoint mask c))
27+
; ((union e (BitAnd (Const mask) x))))
28+
; (rule ((= e (BitAnd (Add x (Const c)) (Const mask))) (bits-disjoint mask c))
29+
; ((union e (BitAnd x (Const mask)))))
30+
; (rule ((= e (BitAnd (Const mask) (Sub x (Const c)))) (bits-disjoint mask c))
31+
; ((union e (BitAnd (Const mask) x))))
32+
; (rule ((= e (BitAnd (Sub x (Const c)) (Const mask))) (bits-disjoint mask c))
33+
; ((union e (BitAnd x (Const mask)))))
3834

3935
; mask & (x | const) = mask & x when const has no bits in mask
4036
(rule ((= e (BitAnd (Const mask) (BitOr x (Const c))))

0 commit comments

Comments
 (0)