@@ -225,8 +225,11 @@ fn collect_member_offsets(module: &Module, struct_id: ResultId) -> HashMap<Membe
225225
226226fn type_layout_size (
227227 ty : TypeId ,
228+ module : & Module ,
228229 definitions : & HashMap < ResultId , rspirv:: dr:: Instruction > ,
229230 visiting : & mut HashSet < TypeId > ,
231+ is_row_major : bool ,
232+ matrix_stride : u32 ,
230233) -> Option < u32 > {
231234 if !visiting. insert ( ty) {
232235 return None ;
@@ -240,36 +243,75 @@ fn type_layout_size(
240243 Op :: TypeVector => {
241244 let ( elem, count) = vector_info ( inst) ;
242245 let ( elem, count) = ( elem?, count?) ;
243- let elem_size = type_layout_size ( elem, definitions, visiting) ?;
246+ let elem_size = type_layout_size ( elem, module , definitions, visiting, false , 0 ) ?;
244247 Some ( elem_size. saturating_mul ( count) )
245248 }
246249 Op :: TypeMatrix => {
247- let ( column, count) = matrix_info ( inst) ;
248- let ( column, count) = ( column?, count?) ;
249- let col_size = type_layout_size ( column, definitions, visiting) ?;
250- Some ( col_size. saturating_mul ( count) )
250+ let ( column, num_columns) = matrix_info ( inst) ;
251+ let ( column, num_columns) = ( column?, num_columns?) ;
252+ if matrix_stride > 0 {
253+ if is_row_major {
254+ // Row major: (num_rows - 1) * stride + num_columns * scalar_size
255+ // (C++ getSize, validate_decorations.cpp lines 364-374)
256+ let col_inst =
257+ definitions. get ( & ResultId :: try_from ( u32:: from ( column) ) . ok ( ) ?) ?;
258+ let ( scalar_type, num_rows) = vector_info ( col_inst) ;
259+ let ( scalar_type, num_rows) = ( scalar_type?, num_rows?) ;
260+ let scalar_size =
261+ type_layout_size ( scalar_type, module, definitions, visiting, false , 0 ) ?;
262+ Some (
263+ num_rows
264+ . saturating_sub ( 1 )
265+ . saturating_mul ( matrix_stride)
266+ . saturating_add ( num_columns. saturating_mul ( scalar_size) ) ,
267+ )
268+ } else {
269+ // Column major: num_columns * stride
270+ // (C++ getSize, validate_decorations.cpp lines 362-363)
271+ Some ( num_columns. saturating_mul ( matrix_stride) )
272+ }
273+ } else {
274+ // No stride info, fall back to raw computation.
275+ let col_size =
276+ type_layout_size ( column, module, definitions, visiting, false , 0 ) ?;
277+ Some ( col_size. saturating_mul ( num_columns) )
278+ }
251279 }
252280 Op :: TypeArray => {
253281 let elem = inst. operands . first ( ) . and_then ( |op| match op {
254282 rspirv:: dr:: Operand :: IdRef ( id) => TypeId :: try_from ( * id) . ok ( ) ,
255283 _ => None ,
256284 } ) ?;
257- let elem_size = type_layout_size ( elem, definitions, visiting) ?;
285+ let elem_size =
286+ type_layout_size ( elem, module, definitions, visiting, is_row_major, matrix_stride) ?;
258287 let len = array_length ( inst, definitions) ?;
259288 Some ( elem_size. saturating_mul ( len) )
260289 }
261290 Op :: TypeRuntimeArray => None , // unsized
262291 Op :: TypeStruct => {
263- let mut offset: u32 = 0 ;
264- for op in & inst. operands {
265- let ty = match op {
292+ let struct_id = ResultId :: try_from ( u32:: from ( ty) ) . ok ( ) ?;
293+ let mut total: u32 = 0 ;
294+ for ( idx, op) in inst. operands . iter ( ) . enumerate ( ) {
295+ let member_ty = match op {
266296 rspirv:: dr:: Operand :: IdRef ( id) => TypeId :: try_from ( * id) . ok ( ) ?,
267297 _ => return None ,
268298 } ;
269- let size = type_layout_size ( ty, definitions, visiting) ?;
270- offset = offset. saturating_add ( size) ;
299+ let member_rm =
300+ member_is_row_major ( module, struct_id, MemberIndex ( idx as u32 ) ) ;
301+ let member_ms =
302+ member_matrix_stride ( module, struct_id, MemberIndex ( idx as u32 ) )
303+ . unwrap_or ( 0 ) ;
304+ let size = type_layout_size (
305+ member_ty,
306+ module,
307+ definitions,
308+ visiting,
309+ member_rm,
310+ member_ms,
311+ ) ?;
312+ total = total. saturating_add ( size) ;
271313 }
272- Some ( offset )
314+ Some ( total )
273315 }
274316 _ => None ,
275317 } ;
@@ -295,12 +337,17 @@ fn round_up(value: u32, align: u32) -> u32 {
295337///
296338/// When `scalar_layout` is true, vectors use scalar element alignment instead
297339/// of the standard 2N/4N rules.
340+ ///
341+ /// When `is_row_major` is true, matrices use a virtual row-vector alignment
342+ /// (based on the number of columns) instead of column-vector alignment.
298343fn type_alignment (
299344 ty : TypeId ,
345+ module : & Module ,
300346 definitions : & HashMap < ResultId , rspirv:: dr:: Instruction > ,
301347 visiting : & mut HashSet < TypeId > ,
302348 scalar_layout : bool ,
303349 extended_alignment : bool ,
350+ is_row_major : bool ,
304351) -> Option < u32 > {
305352 if !visiting. insert ( ty) {
306353 return None ;
@@ -314,8 +361,15 @@ fn type_alignment(
314361 Op :: TypeVector => {
315362 let ( elem, count) = vector_info ( inst) ;
316363 let ( elem, count) = ( elem?, count?) ;
317- let elem_align =
318- type_alignment ( elem, definitions, visiting, scalar_layout, extended_alignment) ?;
364+ let elem_align = type_alignment (
365+ elem,
366+ module,
367+ definitions,
368+ visiting,
369+ scalar_layout,
370+ extended_alignment,
371+ false ,
372+ ) ?;
319373 if scalar_layout {
320374 Some ( elem_align)
321375 } else {
@@ -325,11 +379,38 @@ fn type_alignment(
325379 }
326380 }
327381 Op :: TypeMatrix => {
328- // Matrix alignment follows its column vector alignment.
329- let ( column, _) = matrix_info ( inst) ;
330- let column = column?;
331- let base_align =
332- type_alignment ( column, definitions, visiting, scalar_layout, extended_alignment) ?;
382+ let ( column, num_columns) = matrix_info ( inst) ;
383+ let ( column, num_columns) = ( column?, num_columns?) ;
384+ let base_align = if is_row_major && !scalar_layout {
385+ // Row-major: alignment of a virtual vector of num_columns scalar
386+ // components (C++ getBaseAlignment, validate_decorations.cpp:210-219).
387+ let col_inst =
388+ definitions. get ( & ResultId :: try_from ( u32:: from ( column) ) . ok ( ) ?) ?;
389+ let ( scalar_type, _) = vector_info ( col_inst) ;
390+ let scalar_type = scalar_type?;
391+ let scalar_align = type_alignment (
392+ scalar_type,
393+ module,
394+ definitions,
395+ visiting,
396+ scalar_layout,
397+ extended_alignment,
398+ false ,
399+ ) ?;
400+ let multiplier = if num_columns == 2 { 2 } else { 4 } ;
401+ scalar_align. checked_mul ( multiplier) ?
402+ } else {
403+ // Column-major (or scalar layout): alignment of column vector.
404+ type_alignment (
405+ column,
406+ module,
407+ definitions,
408+ visiting,
409+ scalar_layout,
410+ extended_alignment,
411+ false ,
412+ ) ?
413+ } ;
333414 if extended_alignment && !scalar_layout {
334415 Some ( round_up ( base_align, 16 ) )
335416 } else {
@@ -341,23 +422,43 @@ fn type_alignment(
341422 rspirv:: dr:: Operand :: IdRef ( id) => TypeId :: try_from ( * id) . ok ( ) ,
342423 _ => None ,
343424 } ) ?;
344- let base_align =
345- type_alignment ( elem, definitions, visiting, scalar_layout, extended_alignment) ?;
425+ // Propagate is_row_major through arrays (array of matrices inherits
426+ // majorness from the struct member decoration).
427+ let base_align = type_alignment (
428+ elem,
429+ module,
430+ definitions,
431+ visiting,
432+ scalar_layout,
433+ extended_alignment,
434+ is_row_major,
435+ ) ?;
346436 if extended_alignment && !scalar_layout {
347437 Some ( round_up ( base_align, 16 ) )
348438 } else {
349439 Some ( base_align)
350440 }
351441 }
352442 Op :: TypeStruct => {
443+ let struct_id = ResultId :: try_from ( u32:: from ( ty) ) . ok ( ) ?;
353444 let mut max_align = 1 ;
354- for op in & inst. operands {
355- let ty = match op {
445+ for ( idx , op ) in inst. operands . iter ( ) . enumerate ( ) {
446+ let member_ty = match op {
356447 rspirv:: dr:: Operand :: IdRef ( id) => TypeId :: try_from ( * id) . ok ( ) ?,
357448 _ => return None ,
358449 } ;
359- let align =
360- type_alignment ( ty, definitions, visiting, scalar_layout, extended_alignment) ?;
450+ // Each struct member has its own RowMajor/ColMajor decoration.
451+ let member_rm =
452+ member_is_row_major ( module, struct_id, MemberIndex ( idx as u32 ) ) ;
453+ let align = type_alignment (
454+ member_ty,
455+ module,
456+ definitions,
457+ visiting,
458+ scalar_layout,
459+ extended_alignment,
460+ member_rm,
461+ ) ?;
361462 max_align = max_align. max ( align) ;
362463 }
363464 if extended_alignment && !scalar_layout {
@@ -374,13 +475,14 @@ fn type_alignment(
374475
375476fn vector_scalar_alignment (
376477 vector_inst : & rspirv:: dr:: Instruction ,
478+ module : & Module ,
377479 definitions : & HashMap < ResultId , rspirv:: dr:: Instruction > ,
378480) -> Option < u32 > {
379481 let elem = vector_inst. operands . first ( ) . and_then ( |op| match op {
380482 rspirv:: dr:: Operand :: IdRef ( id) => TypeId :: try_from ( * id) . ok ( ) ,
381483 _ => None ,
382484 } ) ?;
383- type_alignment ( elem, definitions, & mut HashSet :: new ( ) , true , false )
485+ type_alignment ( elem, module , definitions, & mut HashSet :: new ( ) , true , false , false )
384486}
385487
386488fn vector_info ( inst : & rspirv:: dr:: Instruction ) -> ( Option < TypeId > , Option < u32 > ) {
@@ -580,12 +682,20 @@ fn check_struct_layout(
580682 continue ;
581683 } ;
582684
685+ // Look up matrix majorness and stride for this member (inherited
686+ // through arrays to contained matrices, matching C++ LayoutConstraints).
687+ let is_row_major = member_is_row_major ( module, struct_id, MemberIndex ( member_idx) ) ;
688+ let mat_stride =
689+ member_matrix_stride ( module, struct_id, MemberIndex ( member_idx) ) . unwrap_or ( 0 ) ;
690+
583691 let Some ( alignment) = type_alignment (
584692 member_type_id,
693+ module,
585694 definitions,
586695 & mut HashSet :: new ( ) ,
587696 scalar_layout,
588697 extended_alignment,
698+ is_row_major,
589699 ) else {
590700 continue ;
591701 } ;
@@ -594,7 +704,14 @@ fn check_struct_layout(
594704 let size = if member_inst. class . opcode == Op :: TypeRuntimeArray {
595705 0
596706 } else {
597- match type_layout_size ( member_type_id, definitions, & mut HashSet :: new ( ) ) {
707+ match type_layout_size (
708+ member_type_id,
709+ module,
710+ definitions,
711+ & mut HashSet :: new ( ) ,
712+ is_row_major,
713+ mat_stride,
714+ ) {
598715 Some ( s) => s,
599716 None => continue ,
600717 }
@@ -614,7 +731,7 @@ fn check_struct_layout(
614731 // Offset alignment checks.
615732 if relax_block_layout && !scalar_layout && member_inst. class . opcode == Op :: TypeVector {
616733 // Relaxed layout: vector offset aligned to scalar element alignment.
617- let Some ( scalar_align) = vector_scalar_alignment ( member_inst, definitions) else {
734+ let Some ( scalar_align) = vector_scalar_alignment ( member_inst, module , definitions) else {
618735 continue ;
619736 } ;
620737 if offset % scalar_align != 0 {
@@ -702,7 +819,7 @@ fn check_struct_layout(
702819 let ( column_type, _) = matrix_info ( member_inst) ;
703820 if let Some ( col_ty) = column_type {
704821 if let Some ( col_size) =
705- type_layout_size ( col_ty, definitions, & mut HashSet :: new ( ) )
822+ type_layout_size ( col_ty, module , definitions, & mut HashSet :: new ( ) , false , 0 )
706823 {
707824 if col_size > stride {
708825 return Err ( ValidationError :: InvalidBlockLayout {
@@ -714,20 +831,6 @@ fn check_struct_layout(
714831 }
715832 . into ( ) ) ;
716833 }
717- if relax_block_layout
718- && !scalar_layout
719- && member_is_row_major ( module, struct_id, MemberIndex ( member_idx) )
720- && col_size > 16
721- && ( offset % 16 ) . saturating_add ( col_size) > 16
722- {
723- return Err ( ValidationError :: InvalidBlockLayout {
724- struct_type : struct_id,
725- reason :
726- "row-major matrix straddles 16-byte boundary under relaxed layout"
727- . to_string ( ) ,
728- }
729- . into ( ) ) ;
730- }
731834 }
732835 }
733836 }
@@ -823,9 +926,14 @@ fn check_struct_layout(
823926
824927 // Check element_size <= stride (C++ lines 700-706).
825928 if stride_val > 0 {
826- if let Some ( element_size) =
827- type_layout_size ( elem_type, definitions, & mut HashSet :: new ( ) )
828- {
929+ if let Some ( element_size) = type_layout_size (
930+ elem_type,
931+ module,
932+ definitions,
933+ & mut HashSet :: new ( ) ,
934+ is_row_major,
935+ mat_stride,
936+ ) {
829937 if element_size > stride_val {
830938 return Err ( ValidationError :: InvalidBlockLayout {
831939 struct_type : struct_id,
@@ -844,10 +952,12 @@ fn check_struct_layout(
844952 array_result_id = elem_result_id;
845953 array_alignment = type_alignment (
846954 elem_type,
955+ module,
847956 definitions,
848957 & mut HashSet :: new ( ) ,
849958 scalar_layout,
850959 extended_alignment,
960+ is_row_major,
851961 )
852962 . unwrap_or ( 1 ) ;
853963 }
0 commit comments