Skip to content

Commit 7a2e357

Browse files
committed
Fix incorrect PRE hoisting when Gamma is nested inside common operation
The is_gamma_or_select check only inspected the outermost constructor of extracted Gamma terms. When egraph rules redistributed the Gamma inside a common operation (e.g. GammaI(c, x*2, x*3) → Mul(x, GammaI(c, 2, 3))), the outermost Mul was not recognized as condition-dependent, causing the hoisting to incorrectly collapse both branches to the then-branch value. Fix by recursively checking for Gamma/Select nodes anywhere in the term tree before hoisting.
1 parent 7f1bce9 commit 7a2e357

1 file changed

Lines changed: 61 additions & 16 deletions

File tree

  • rust/spirv-tools-opt/src/direct

rust/spirv-tools-opt/src/direct/mod.rs

Lines changed: 61 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1179,23 +1179,16 @@ pub fn optimize_module_direct(module: &Module) -> Result<Module, EgglogOptError>
11791179
None => continue,
11801180
};
11811181

1182-
// If the gamma has simplified to just the expression (not a Gamma/Select variant),
1183-
// it means both branches computed the same thing and it can be hoisted
1184-
let is_gamma_or_select = match parse_sexpr(&gamma_term) {
1185-
Some(Term::App { ref op, .. }) => matches!(
1186-
op.as_str(),
1187-
"Gamma"
1188-
| "GammaI"
1189-
| "GammaF"
1190-
| "GammaB"
1191-
| "Select"
1192-
| "SelectI"
1193-
| "SelectF"
1194-
| "SelectB"
1195-
),
1196-
_ => false,
1182+
// If the gamma has simplified to just the expression (no Gamma/Select anywhere),
1183+
// it means both branches computed the same thing and it can be hoisted.
1184+
// We must check recursively because egraph rules can push Gamma inside
1185+
// common operations (e.g. GammaI(c, x*2, x*3) → Mul(x, GammaI(c, 2, 3))),
1186+
// which moves the selection below the outermost constructor.
1187+
let contains_gamma_or_select = match parse_sexpr(&gamma_term) {
1188+
Some(ref term) => term_contains_gamma_or_select(term),
1189+
None => false,
11971190
};
1198-
if !is_gamma_or_select {
1191+
if !contains_gamma_or_select {
11991192
// The expression can be hoisted!
12001193
// Mark both branch IDs to become CopyObjects of the hoisted value
12011194
let result_type = ctx.id_to_type.get(&pair.then_id).copied().unwrap_or(0);
@@ -2526,6 +2519,31 @@ enum ParsedEffect {
25262519
Unreachable,
25272520
}
25282521

2522+
/// Check whether a parsed term contains a Gamma or Select node anywhere in the tree.
2523+
/// Used to determine if a branch-value pair still depends on the branch condition
2524+
/// (Gamma may have been pushed inside a common operation by egraph rules).
2525+
fn term_contains_gamma_or_select(term: &Term) -> bool {
2526+
match term {
2527+
Term::Atom(_) => false,
2528+
Term::App { op, args } => {
2529+
if matches!(
2530+
op.as_str(),
2531+
"Gamma"
2532+
| "GammaI"
2533+
| "GammaF"
2534+
| "GammaB"
2535+
| "Select"
2536+
| "SelectI"
2537+
| "SelectF"
2538+
| "SelectB"
2539+
) {
2540+
return true;
2541+
}
2542+
args.iter().any(term_contains_gamma_or_select)
2543+
}
2544+
}
2545+
}
2546+
25292547
/// Parse an extracted Effect term from egglog using the Term tree.
25302548
fn parse_effect_result(s: &str) -> Option<ParsedEffect> {
25312549
let term = parse_sexpr(s.trim())?;
@@ -3019,4 +3037,31 @@ mod tests {
30193037
fn parse_effect_bare_atom_returns_none() {
30203038
assert!(parse_effect_result("id5").is_none());
30213039
}
3040+
3041+
#[test]
3042+
fn term_contains_gamma_detects_nested_gamma() {
3043+
// Outermost Gamma → true
3044+
let t = parse_sexpr("(GammaI cond a b)").unwrap();
3045+
assert!(term_contains_gamma_or_select(&t));
3046+
3047+
// Gamma nested inside Mul → true (was the bug: outermost-only check missed this)
3048+
let t = parse_sexpr("(Mul val (GammaI cond 2 3))").unwrap();
3049+
assert!(term_contains_gamma_or_select(&t));
3050+
3051+
// Gamma nested inside Add → true
3052+
let t = parse_sexpr("(Add val (GammaI cond 100 200))").unwrap();
3053+
assert!(term_contains_gamma_or_select(&t));
3054+
3055+
// No Gamma anywhere → false (safe to hoist)
3056+
let t = parse_sexpr("(Mul val 3)").unwrap();
3057+
assert!(!term_contains_gamma_or_select(&t));
3058+
3059+
// Bare atom → false
3060+
let t = parse_sexpr("id5").unwrap();
3061+
assert!(!term_contains_gamma_or_select(&t));
3062+
3063+
// SelectI variant → true
3064+
let t = parse_sexpr("(Add x (SelectI cond a b))").unwrap();
3065+
assert!(term_contains_gamma_or_select(&t));
3066+
}
30223067
}

0 commit comments

Comments
 (0)