Skip to content

Commit 0b3b9f7

Browse files
committed
Add OpBranchConditional condition type validation
Validate that the condition operand of OpBranchConditional is OpTypeBool, as required by the SPIR-V spec. Non-bool conditions (e.g., integer values) are rejected with BranchConditionalConditionNotBool.
1 parent 56fe883 commit 0b3b9f7

3 files changed

Lines changed: 169 additions & 0 deletions

File tree

rust/spirv-tools-core/src/validation/error.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2865,6 +2865,14 @@ pub enum ValidationError {
28652865
/// The invalid result type opcode.
28662866
type_opcode: rspirv::spirv::Op,
28672867
},
2868+
/// OpBranchConditional condition operand must be OpTypeBool.
2869+
#[error("OpBranchConditional condition {condition_id} must be OpTypeBool, found {found_opcode:?}")]
2870+
BranchConditionalConditionNotBool {
2871+
/// The condition operand ID.
2872+
condition_id: Id,
2873+
/// The opcode of the condition's type.
2874+
found_opcode: rspirv::spirv::Op,
2875+
},
28682876
/// In SPIR-V 1.6 or later, BranchConditional True Label and False Label must be different.
28692877
#[error("In SPIR-V 1.6 or later, True Label and False Label must be different labels")]
28702878
BranchConditionalSameLabels {

rust/spirv-tools-core/src/validation/rules/cfg.rs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1363,6 +1363,29 @@ impl ValidationRule for BranchConditionalRule {
13631363
continue;
13641364
}
13651365

1366+
// Validate condition operand is OpTypeBool
1367+
if let Some(Operand::IdRef(cond_id)) = inst.operands.first() {
1368+
if let Ok(cond_rid) = ResultId::try_from(*cond_id) {
1369+
if let Some(cond_def) = ctx.definitions.get(&cond_rid) {
1370+
if let Some(type_id) = cond_def.result_type {
1371+
if let Ok(type_rid) = ResultId::try_from(type_id) {
1372+
if let Some(type_inst) = ctx.definitions.get(&type_rid) {
1373+
if type_inst.class.opcode != Op::TypeBool {
1374+
return Err(
1375+
ValidationError::BranchConditionalConditionNotBool {
1376+
condition_id: to_id(*cond_id),
1377+
found_opcode: type_inst.class.opcode,
1378+
}
1379+
.into(),
1380+
);
1381+
}
1382+
}
1383+
}
1384+
}
1385+
}
1386+
}
1387+
}
1388+
13661389
// Get true and false label operands (operands 1 and 2)
13671390
let true_label = match inst.operands.get(1) {
13681391
Some(Operand::IdRef(id)) => *id,

rust/spirv-tools-core/src/validation/tests/cfg.rs

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1791,3 +1791,141 @@ fn switch_with_float_selector_fails() {
17911791
"expected SwitchSelectorNotInteger, got {err:?}"
17921792
);
17931793
}
1794+
1795+
// ============================================================================
1796+
// BranchConditional Condition Type Validation
1797+
// ============================================================================
1798+
1799+
#[test]
1800+
fn branch_conditional_with_bool_condition_passes() {
1801+
let text = r#"
1802+
OpCapability Shader
1803+
OpMemoryModel Logical GLSL450
1804+
OpEntryPoint GLCompute %main "main"
1805+
OpExecutionMode %main LocalSize 1 1 1
1806+
%void = OpTypeVoid
1807+
%fn = OpTypeFunction %void
1808+
%bool = OpTypeBool
1809+
%true = OpConstantTrue %bool
1810+
%main = OpFunction %void None %fn
1811+
%entry = OpLabel
1812+
OpSelectionMerge %merge None
1813+
OpBranchConditional %true %then %else
1814+
%then = OpLabel
1815+
OpBranch %merge
1816+
%else = OpLabel
1817+
OpBranch %merge
1818+
%merge = OpLabel
1819+
OpReturn
1820+
OpFunctionEnd
1821+
"#;
1822+
assemble_and_validate(text).expect("BranchConditional with bool condition should pass");
1823+
}
1824+
1825+
#[test]
1826+
fn branch_conditional_with_int_condition_fails() {
1827+
use rspirv::binary::Assemble;
1828+
use rspirv::dr::{Instruction, Module, Operand};
1829+
1830+
let mut module = Module::new();
1831+
module.header = Some(rspirv::dr::ModuleHeader {
1832+
magic_number: rspirv::spirv::MAGIC_NUMBER,
1833+
version: (1 << 16) | (5 << 8),
1834+
generator: 0,
1835+
bound: 20,
1836+
reserved_word: 0,
1837+
});
1838+
1839+
module.capabilities.push(Instruction::new(
1840+
Op::Capability, None, None,
1841+
vec![Operand::Capability(rspirv::spirv::Capability::Shader)],
1842+
));
1843+
module.memory_model = Some(Instruction::new(
1844+
Op::MemoryModel, None, None,
1845+
vec![
1846+
Operand::AddressingModel(rspirv::spirv::AddressingModel::Logical),
1847+
Operand::MemoryModel(rspirv::spirv::MemoryModel::GLSL450),
1848+
],
1849+
));
1850+
module.entry_points.push(Instruction::new(
1851+
Op::EntryPoint, None, None,
1852+
vec![
1853+
Operand::ExecutionModel(rspirv::spirv::ExecutionModel::GLCompute),
1854+
Operand::IdRef(10),
1855+
Operand::LiteralString("main".to_string()),
1856+
],
1857+
));
1858+
module.execution_modes.push(Instruction::new(
1859+
Op::ExecutionMode, None, None,
1860+
vec![
1861+
Operand::IdRef(10),
1862+
Operand::ExecutionMode(rspirv::spirv::ExecutionMode::LocalSize),
1863+
Operand::LiteralBit32(1),
1864+
Operand::LiteralBit32(1),
1865+
Operand::LiteralBit32(1),
1866+
],
1867+
));
1868+
1869+
// %2 = OpTypeVoid
1870+
module.types_global_values.push(Instruction::new(Op::TypeVoid, None, Some(2), vec![]));
1871+
// %3 = OpTypeFunction %void
1872+
module.types_global_values.push(Instruction::new(
1873+
Op::TypeFunction, None, Some(3), vec![Operand::IdRef(2)],
1874+
));
1875+
// %4 = OpTypeInt 32 0
1876+
module.types_global_values.push(Instruction::new(
1877+
Op::TypeInt, None, Some(4), vec![Operand::LiteralBit32(32), Operand::LiteralBit32(0)],
1878+
));
1879+
// %5 = OpConstant %4 1
1880+
module.types_global_values.push(Instruction::new(
1881+
Op::Constant, Some(4), Some(5), vec![Operand::LiteralBit32(1)],
1882+
));
1883+
1884+
let mut func = rspirv::dr::Function::new();
1885+
func.def = Some(Instruction::new(
1886+
Op::Function, Some(2), Some(10),
1887+
vec![
1888+
Operand::FunctionControl(rspirv::spirv::FunctionControl::NONE),
1889+
Operand::IdRef(3),
1890+
],
1891+
));
1892+
1893+
let mut entry = rspirv::dr::Block::new();
1894+
entry.label = Some(Instruction::new(Op::Label, None, Some(11), vec![]));
1895+
entry.instructions.push(Instruction::new(
1896+
Op::SelectionMerge, None, None,
1897+
vec![Operand::IdRef(14), Operand::SelectionControl(rspirv::spirv::SelectionControl::NONE)],
1898+
));
1899+
// OpBranchConditional %5(int!) %12 %13 -- condition is integer, not bool
1900+
entry.instructions.push(Instruction::new(
1901+
Op::BranchConditional, None, None,
1902+
vec![Operand::IdRef(5), Operand::IdRef(12), Operand::IdRef(13)],
1903+
));
1904+
func.blocks.push(entry);
1905+
1906+
let mut then_block = rspirv::dr::Block::new();
1907+
then_block.label = Some(Instruction::new(Op::Label, None, Some(12), vec![]));
1908+
then_block.instructions.push(Instruction::new(Op::Branch, None, None, vec![Operand::IdRef(14)]));
1909+
func.blocks.push(then_block);
1910+
1911+
let mut else_block = rspirv::dr::Block::new();
1912+
else_block.label = Some(Instruction::new(Op::Label, None, Some(13), vec![]));
1913+
else_block.instructions.push(Instruction::new(Op::Branch, None, None, vec![Operand::IdRef(14)]));
1914+
func.blocks.push(else_block);
1915+
1916+
let mut merge = rspirv::dr::Block::new();
1917+
merge.label = Some(Instruction::new(Op::Label, None, Some(14), vec![]));
1918+
merge.instructions.push(Instruction::new(Op::Return, None, None, vec![]));
1919+
func.blocks.push(merge);
1920+
1921+
func.end = Some(Instruction::new(Op::FunctionEnd, None, None, vec![]));
1922+
module.functions.push(func);
1923+
1924+
let binary = module.assemble();
1925+
let err = validate_module(&binary, TargetEnv::Universal1_6)
1926+
.expect_err("BranchConditional with int condition should fail");
1927+
assert!(
1928+
matches!(err, ValidationError::BranchConditionalConditionNotBool { .. }),
1929+
"expected BranchConditionalConditionNotBool, got {err:?}"
1930+
);
1931+
}

0 commit comments

Comments
 (0)