Skip to content

Commit 0d0ba1d

Browse files
committed
Fix row-major matrix alignment and size calculations in block layout
Row-major matrices need special handling per the Vulkan spec: - Alignment: use the alignment of a virtual row vector (num_columns scalar components) instead of the column vector alignment - Size: (num_rows - 1) * matrix_stride + num_columns * scalar_size instead of num_columns * matrix_stride - Remove overly strict straddle check for row-major matrices (the Vulkan spec only defines straddle checks for vectors, not matrices) Thread is_row_major and matrix_stride through type_alignment and type_layout_size, looking up per-member majorness from decorations for struct members (matching the C++ LayoutConstraints approach).
1 parent 62f519e commit 0d0ba1d

2 files changed

Lines changed: 243 additions & 46 deletions

File tree

rust/spirv-tools-core/src/validation/rules/block_layout.rs

Lines changed: 156 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -225,8 +225,11 @@ fn collect_member_offsets(module: &Module, struct_id: ResultId) -> HashMap<Membe
225225

226226
fn type_layout_size(
227227
ty: TypeId,
228+
module: &Module,
228229
definitions: &HashMap<ResultId, rspirv::dr::Instruction>,
229230
visiting: &mut HashSet<TypeId>,
231+
is_row_major: bool,
232+
matrix_stride: u32,
230233
) -> Option<u32> {
231234
if !visiting.insert(ty) {
232235
return None;
@@ -240,36 +243,75 @@ fn type_layout_size(
240243
Op::TypeVector => {
241244
let (elem, count) = vector_info(inst);
242245
let (elem, count) = (elem?, count?);
243-
let elem_size = type_layout_size(elem, definitions, visiting)?;
246+
let elem_size = type_layout_size(elem, module, definitions, visiting, false, 0)?;
244247
Some(elem_size.saturating_mul(count))
245248
}
246249
Op::TypeMatrix => {
247-
let (column, count) = matrix_info(inst);
248-
let (column, count) = (column?, count?);
249-
let col_size = type_layout_size(column, definitions, visiting)?;
250-
Some(col_size.saturating_mul(count))
250+
let (column, num_columns) = matrix_info(inst);
251+
let (column, num_columns) = (column?, num_columns?);
252+
if matrix_stride > 0 {
253+
if is_row_major {
254+
// Row major: (num_rows - 1) * stride + num_columns * scalar_size
255+
// (C++ getSize, validate_decorations.cpp lines 364-374)
256+
let col_inst =
257+
definitions.get(&ResultId::try_from(u32::from(column)).ok()?)?;
258+
let (scalar_type, num_rows) = vector_info(col_inst);
259+
let (scalar_type, num_rows) = (scalar_type?, num_rows?);
260+
let scalar_size =
261+
type_layout_size(scalar_type, module, definitions, visiting, false, 0)?;
262+
Some(
263+
num_rows
264+
.saturating_sub(1)
265+
.saturating_mul(matrix_stride)
266+
.saturating_add(num_columns.saturating_mul(scalar_size)),
267+
)
268+
} else {
269+
// Column major: num_columns * stride
270+
// (C++ getSize, validate_decorations.cpp lines 362-363)
271+
Some(num_columns.saturating_mul(matrix_stride))
272+
}
273+
} else {
274+
// No stride info, fall back to raw computation.
275+
let col_size =
276+
type_layout_size(column, module, definitions, visiting, false, 0)?;
277+
Some(col_size.saturating_mul(num_columns))
278+
}
251279
}
252280
Op::TypeArray => {
253281
let elem = inst.operands.first().and_then(|op| match op {
254282
rspirv::dr::Operand::IdRef(id) => TypeId::try_from(*id).ok(),
255283
_ => None,
256284
})?;
257-
let elem_size = type_layout_size(elem, definitions, visiting)?;
285+
let elem_size =
286+
type_layout_size(elem, module, definitions, visiting, is_row_major, matrix_stride)?;
258287
let len = array_length(inst, definitions)?;
259288
Some(elem_size.saturating_mul(len))
260289
}
261290
Op::TypeRuntimeArray => None, // unsized
262291
Op::TypeStruct => {
263-
let mut offset: u32 = 0;
264-
for op in &inst.operands {
265-
let ty = match op {
292+
let struct_id = ResultId::try_from(u32::from(ty)).ok()?;
293+
let mut total: u32 = 0;
294+
for (idx, op) in inst.operands.iter().enumerate() {
295+
let member_ty = match op {
266296
rspirv::dr::Operand::IdRef(id) => TypeId::try_from(*id).ok()?,
267297
_ => return None,
268298
};
269-
let size = type_layout_size(ty, definitions, visiting)?;
270-
offset = offset.saturating_add(size);
299+
let member_rm =
300+
member_is_row_major(module, struct_id, MemberIndex(idx as u32));
301+
let member_ms =
302+
member_matrix_stride(module, struct_id, MemberIndex(idx as u32))
303+
.unwrap_or(0);
304+
let size = type_layout_size(
305+
member_ty,
306+
module,
307+
definitions,
308+
visiting,
309+
member_rm,
310+
member_ms,
311+
)?;
312+
total = total.saturating_add(size);
271313
}
272-
Some(offset)
314+
Some(total)
273315
}
274316
_ => None,
275317
};
@@ -295,12 +337,17 @@ fn round_up(value: u32, align: u32) -> u32 {
295337
///
296338
/// When `scalar_layout` is true, vectors use scalar element alignment instead
297339
/// of the standard 2N/4N rules.
340+
///
341+
/// When `is_row_major` is true, matrices use a virtual row-vector alignment
342+
/// (based on the number of columns) instead of column-vector alignment.
298343
fn type_alignment(
299344
ty: TypeId,
345+
module: &Module,
300346
definitions: &HashMap<ResultId, rspirv::dr::Instruction>,
301347
visiting: &mut HashSet<TypeId>,
302348
scalar_layout: bool,
303349
extended_alignment: bool,
350+
is_row_major: bool,
304351
) -> Option<u32> {
305352
if !visiting.insert(ty) {
306353
return None;
@@ -314,8 +361,15 @@ fn type_alignment(
314361
Op::TypeVector => {
315362
let (elem, count) = vector_info(inst);
316363
let (elem, count) = (elem?, count?);
317-
let elem_align =
318-
type_alignment(elem, definitions, visiting, scalar_layout, extended_alignment)?;
364+
let elem_align = type_alignment(
365+
elem,
366+
module,
367+
definitions,
368+
visiting,
369+
scalar_layout,
370+
extended_alignment,
371+
false,
372+
)?;
319373
if scalar_layout {
320374
Some(elem_align)
321375
} else {
@@ -325,11 +379,38 @@ fn type_alignment(
325379
}
326380
}
327381
Op::TypeMatrix => {
328-
// Matrix alignment follows its column vector alignment.
329-
let (column, _) = matrix_info(inst);
330-
let column = column?;
331-
let base_align =
332-
type_alignment(column, definitions, visiting, scalar_layout, extended_alignment)?;
382+
let (column, num_columns) = matrix_info(inst);
383+
let (column, num_columns) = (column?, num_columns?);
384+
let base_align = if is_row_major && !scalar_layout {
385+
// Row-major: alignment of a virtual vector of num_columns scalar
386+
// components (C++ getBaseAlignment, validate_decorations.cpp:210-219).
387+
let col_inst =
388+
definitions.get(&ResultId::try_from(u32::from(column)).ok()?)?;
389+
let (scalar_type, _) = vector_info(col_inst);
390+
let scalar_type = scalar_type?;
391+
let scalar_align = type_alignment(
392+
scalar_type,
393+
module,
394+
definitions,
395+
visiting,
396+
scalar_layout,
397+
extended_alignment,
398+
false,
399+
)?;
400+
let multiplier = if num_columns == 2 { 2 } else { 4 };
401+
scalar_align.checked_mul(multiplier)?
402+
} else {
403+
// Column-major (or scalar layout): alignment of column vector.
404+
type_alignment(
405+
column,
406+
module,
407+
definitions,
408+
visiting,
409+
scalar_layout,
410+
extended_alignment,
411+
false,
412+
)?
413+
};
333414
if extended_alignment && !scalar_layout {
334415
Some(round_up(base_align, 16))
335416
} else {
@@ -341,23 +422,43 @@ fn type_alignment(
341422
rspirv::dr::Operand::IdRef(id) => TypeId::try_from(*id).ok(),
342423
_ => None,
343424
})?;
344-
let base_align =
345-
type_alignment(elem, definitions, visiting, scalar_layout, extended_alignment)?;
425+
// Propagate is_row_major through arrays (array of matrices inherits
426+
// majorness from the struct member decoration).
427+
let base_align = type_alignment(
428+
elem,
429+
module,
430+
definitions,
431+
visiting,
432+
scalar_layout,
433+
extended_alignment,
434+
is_row_major,
435+
)?;
346436
if extended_alignment && !scalar_layout {
347437
Some(round_up(base_align, 16))
348438
} else {
349439
Some(base_align)
350440
}
351441
}
352442
Op::TypeStruct => {
443+
let struct_id = ResultId::try_from(u32::from(ty)).ok()?;
353444
let mut max_align = 1;
354-
for op in &inst.operands {
355-
let ty = match op {
445+
for (idx, op) in inst.operands.iter().enumerate() {
446+
let member_ty = match op {
356447
rspirv::dr::Operand::IdRef(id) => TypeId::try_from(*id).ok()?,
357448
_ => return None,
358449
};
359-
let align =
360-
type_alignment(ty, definitions, visiting, scalar_layout, extended_alignment)?;
450+
// Each struct member has its own RowMajor/ColMajor decoration.
451+
let member_rm =
452+
member_is_row_major(module, struct_id, MemberIndex(idx as u32));
453+
let align = type_alignment(
454+
member_ty,
455+
module,
456+
definitions,
457+
visiting,
458+
scalar_layout,
459+
extended_alignment,
460+
member_rm,
461+
)?;
361462
max_align = max_align.max(align);
362463
}
363464
if extended_alignment && !scalar_layout {
@@ -374,13 +475,14 @@ fn type_alignment(
374475

375476
fn vector_scalar_alignment(
376477
vector_inst: &rspirv::dr::Instruction,
478+
module: &Module,
377479
definitions: &HashMap<ResultId, rspirv::dr::Instruction>,
378480
) -> Option<u32> {
379481
let elem = vector_inst.operands.first().and_then(|op| match op {
380482
rspirv::dr::Operand::IdRef(id) => TypeId::try_from(*id).ok(),
381483
_ => None,
382484
})?;
383-
type_alignment(elem, definitions, &mut HashSet::new(), true, false)
485+
type_alignment(elem, module, definitions, &mut HashSet::new(), true, false, false)
384486
}
385487

386488
fn vector_info(inst: &rspirv::dr::Instruction) -> (Option<TypeId>, Option<u32>) {
@@ -580,12 +682,20 @@ fn check_struct_layout(
580682
continue;
581683
};
582684

685+
// Look up matrix majorness and stride for this member (inherited
686+
// through arrays to contained matrices, matching C++ LayoutConstraints).
687+
let is_row_major = member_is_row_major(module, struct_id, MemberIndex(member_idx));
688+
let mat_stride =
689+
member_matrix_stride(module, struct_id, MemberIndex(member_idx)).unwrap_or(0);
690+
583691
let Some(alignment) = type_alignment(
584692
member_type_id,
693+
module,
585694
definitions,
586695
&mut HashSet::new(),
587696
scalar_layout,
588697
extended_alignment,
698+
is_row_major,
589699
) else {
590700
continue;
591701
};
@@ -594,7 +704,14 @@ fn check_struct_layout(
594704
let size = if member_inst.class.opcode == Op::TypeRuntimeArray {
595705
0
596706
} else {
597-
match type_layout_size(member_type_id, definitions, &mut HashSet::new()) {
707+
match type_layout_size(
708+
member_type_id,
709+
module,
710+
definitions,
711+
&mut HashSet::new(),
712+
is_row_major,
713+
mat_stride,
714+
) {
598715
Some(s) => s,
599716
None => continue,
600717
}
@@ -614,7 +731,7 @@ fn check_struct_layout(
614731
// Offset alignment checks.
615732
if relax_block_layout && !scalar_layout && member_inst.class.opcode == Op::TypeVector {
616733
// Relaxed layout: vector offset aligned to scalar element alignment.
617-
let Some(scalar_align) = vector_scalar_alignment(member_inst, definitions) else {
734+
let Some(scalar_align) = vector_scalar_alignment(member_inst, module, definitions) else {
618735
continue;
619736
};
620737
if offset % scalar_align != 0 {
@@ -702,7 +819,7 @@ fn check_struct_layout(
702819
let (column_type, _) = matrix_info(member_inst);
703820
if let Some(col_ty) = column_type {
704821
if let Some(col_size) =
705-
type_layout_size(col_ty, definitions, &mut HashSet::new())
822+
type_layout_size(col_ty, module, definitions, &mut HashSet::new(), false, 0)
706823
{
707824
if col_size > stride {
708825
return Err(ValidationError::InvalidBlockLayout {
@@ -714,20 +831,6 @@ fn check_struct_layout(
714831
}
715832
.into());
716833
}
717-
if relax_block_layout
718-
&& !scalar_layout
719-
&& member_is_row_major(module, struct_id, MemberIndex(member_idx))
720-
&& col_size > 16
721-
&& (offset % 16).saturating_add(col_size) > 16
722-
{
723-
return Err(ValidationError::InvalidBlockLayout {
724-
struct_type: struct_id,
725-
reason:
726-
"row-major matrix straddles 16-byte boundary under relaxed layout"
727-
.to_string(),
728-
}
729-
.into());
730-
}
731834
}
732835
}
733836
}
@@ -823,9 +926,14 @@ fn check_struct_layout(
823926

824927
// Check element_size <= stride (C++ lines 700-706).
825928
if stride_val > 0 {
826-
if let Some(element_size) =
827-
type_layout_size(elem_type, definitions, &mut HashSet::new())
828-
{
929+
if let Some(element_size) = type_layout_size(
930+
elem_type,
931+
module,
932+
definitions,
933+
&mut HashSet::new(),
934+
is_row_major,
935+
mat_stride,
936+
) {
829937
if element_size > stride_val {
830938
return Err(ValidationError::InvalidBlockLayout {
831939
struct_type: struct_id,
@@ -844,10 +952,12 @@ fn check_struct_layout(
844952
array_result_id = elem_result_id;
845953
array_alignment = type_alignment(
846954
elem_type,
955+
module,
847956
definitions,
848957
&mut HashSet::new(),
849958
scalar_layout,
850959
extended_alignment,
960+
is_row_major,
851961
)
852962
.unwrap_or(1);
853963
}

0 commit comments

Comments
 (0)