Skip to content

Commit 86dfbaf

Browse files
committed
Fix NaN correctness for ordered float reflexive rules and merge_idx guard
- Disable FOrdEq(x,x), FOrdLe(x,x), FOrdGe(x,x) reflexive rules in logical.egg: IEEE 754 says NaN is not ordered with itself, so these return false for NaN inputs. Keep FOrdLt/FOrdGt reflexive rules (strict inequality is always false, correct for NaN). - Add merge_idx < header_idx safety guard in both selection and switch construct RVSDG guards: non-standard block ordering could cause empty/invalid ranges. - Update test_boolconst_float_comparison_reflexive to verify FOrdEq(x,x) is NOT folded (NaN correctness) and FOrdLt(x,x) IS folded.
1 parent 96e17dd commit 86dfbaf

3 files changed

Lines changed: 50 additions & 19 deletions

File tree

rust/spirv-tools-opt/src/direct/mod.rs

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -632,11 +632,15 @@ pub fn optimize_module_direct(module: &Module) -> Result<Module, EgglogOptError>
632632
let label_map = &func_block_labels[sel.func_idx];
633633
let merge_idx = label_map.get(&sel.merge_label).copied();
634634
let in_loop = if let Some(merge_idx) = merge_idx {
635-
// Check all blocks from header to merge (inclusive)
636-
(sel.header_block_idx..=merge_idx).any(|idx| {
637-
loop_block_set.contains(&(sel.func_idx, idx))
638-
|| continue_block_set.contains(&(sel.func_idx, idx))
639-
})
635+
// Check all blocks from header to merge (inclusive).
636+
if merge_idx < sel.header_block_idx {
637+
true // Non-standard block ordering — skip to be safe
638+
} else {
639+
(sel.header_block_idx..=merge_idx).any(|idx| {
640+
loop_block_set.contains(&(sel.func_idx, idx))
641+
|| continue_block_set.contains(&(sel.func_idx, idx))
642+
})
643+
}
640644
} else {
641645
// Can't resolve merge — skip to be safe
642646
true
@@ -754,10 +758,14 @@ pub fn optimize_module_direct(module: &Module) -> Result<Module, EgglogOptError>
754758
let label_map = &func_block_labels[sw.func_idx];
755759
let merge_idx = label_map.get(&sw.merge_label).copied();
756760
let in_loop = if let Some(merge_idx) = merge_idx {
757-
(sw.header_block_idx..=merge_idx).any(|idx| {
758-
loop_block_set.contains(&(sw.func_idx, idx))
759-
|| continue_block_set.contains(&(sw.func_idx, idx))
760-
})
761+
if merge_idx < sw.header_block_idx {
762+
true // Non-standard block ordering — skip to be safe
763+
} else {
764+
(sw.header_block_idx..=merge_idx).any(|idx| {
765+
loop_block_set.contains(&(sw.func_idx, idx))
766+
|| continue_block_set.contains(&(sw.func_idx, idx))
767+
})
768+
}
761769
} else {
762770
true
763771
};

rust/spirv-tools-opt/src/egglog_opt/tests.rs

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6557,7 +6557,9 @@ fn test_boolconst_logical_tautology() {
65576557

65586558
#[test]
65596559
fn test_boolconst_float_comparison_reflexive() {
6560-
// FOrdEq(x, x) should produce BoolConst(1)
6560+
// FOrdEq(x, x) must NOT fold to BoolConst(1) because x could be NaN.
6561+
// IEEE 754: NaN is not ordered with itself, so FOrdEq(NaN, NaN) = false.
6562+
// Only FOrdLt(x,x) and FOrdGt(x,x) are always false (correct for NaN too).
65616563
let mut egraph = create_spirv_egraph().unwrap();
65626564

65636565
egraph
@@ -6566,16 +6568,29 @@ fn test_boolconst_float_comparison_reflexive() {
65666568
r#"
65676569
(let x (FSym "f"))
65686570
(let root (FOrdEq x x))
6569-
(let expected (BoolConst 1))
6571+
(let wrong (BoolConst 1))
6572+
(let lt_root (FOrdLt x x))
6573+
(let false_val (BoolConst 0))
65706574
"#,
65716575
)
65726576
.unwrap();
65736577
egraph
65746578
.parse_and_run_program(None, "(run-schedule (repeat 10 (run)))")
65756579
.unwrap();
65766580

6577-
let check = egraph.parse_and_run_program(None, "(check (= root expected))");
6578-
assert!(check.is_ok(), "FOrdEq(x, x) should fold to BoolConst(1)");
6581+
// FOrdEq(x,x) must NOT be folded (NaN correctness)
6582+
let check = egraph.parse_and_run_program(None, "(check (= root wrong))");
6583+
assert!(
6584+
check.is_err(),
6585+
"FOrdEq(x, x) must NOT fold to BoolConst(1) — wrong for NaN"
6586+
);
6587+
6588+
// FOrdLt(x,x) IS correctly folded to BoolConst(0)
6589+
let check2 = egraph.parse_and_run_program(None, "(check (= lt_root false_val))");
6590+
assert!(
6591+
check2.is_ok(),
6592+
"FOrdLt(x, x) should fold to BoolConst(0)"
6593+
);
65796594
}
65806595

65816596
#[test]

rust/spirv-tools-opt/src/rules/logical.egg

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -459,16 +459,24 @@
459459
; =============================================================================
460460

461461
; Floating-point reflexive comparisons
462+
; DISABLED: Ordered float reflexive comparisons — wrong for NaN.
463+
; IEEE 754: NaN is NOT ordered with itself, so FOrdLe(NaN,NaN) = false,
464+
; FOrdEq(NaN,NaN) = false, etc. Only unordered comparisons return true for NaN.
465+
; The correct rules would be:
466+
; FOrdEq(x,x) = LogNot(IsNan(x))
467+
; FOrdLt(x,x) = BoolConst(0) (correct: nothing is strictly less than itself)
468+
; FOrdGt(x,x) = BoolConst(0) (correct: same)
469+
; FOrdLt/FOrdGt reflexive rules ARE correct (strict inequality is always false).
462470
(rule ((= e (FOrdLt a a)))
463471
((union e (BoolConst 0))))
464472
(rule ((= e (FOrdGt a a)))
465473
((union e (BoolConst 0))))
466-
(rule ((= e (FOrdLe a a)))
467-
((union e (BoolConst 1))))
468-
(rule ((= e (FOrdGe a a)))
469-
((union e (BoolConst 1))))
470-
(rule ((= e (FOrdEq a a)))
471-
((union e (BoolConst 1))))
474+
; (rule ((= e (FOrdLe a a)))
475+
; ((union e (BoolConst 1))))
476+
; (rule ((= e (FOrdGe a a)))
477+
; ((union e (BoolConst 1))))
478+
; (rule ((= e (FOrdEq a a)))
479+
; ((union e (BoolConst 1))))
472480
(rule ((= e (FOrdNe a a)))
473481
((union e (BoolConst 0))))
474482

0 commit comments

Comments
 (0)