@@ -197,13 +197,13 @@ def generate_gemm_kernel_inner_complex( settings, dest, M, N, vlen, a_regs ):
197197 dest .write ("ai += {M}*2;" )
198198 dest .write ()
199199
200-
201- accumulation_regs = a_regs * N * settings [ 'LMUL_ACC' ]. value
200+ # for each vector register loaded from matrix A, we require N registers to hold vector-scalar multiply-accumulate results
201+ accumulation_regs = a_regs * N
202202 dest .write ("// {a_regs} vector regs to hold A array contents, {accumulation_regs} regs to hold values accumulated over k" ,
203203 a_regs = a_regs * 2 , accumulation_regs = accumulation_regs * 2
204204 )
205205 pass_regs = (accumulation_regs + a_regs )* 2
206- tmp_regs = 32 - pass_regs
206+ tmp_regs = ( 32 // settings [ 'LMUL_ACC' ]. value ) - pass_regs
207207 if tmp_regs < 2 :
208208 raise RuntimeError ("Complex kernel would use too many registers!" )
209209
@@ -337,10 +337,12 @@ def generate_gemm_kernel( settings, OUTPUT ):
337337
338338 M = settings ['M' ].value
339339 N = settings ['N' ].value
340- vlenmax = int ( settings ['reg_width_bits' ].value / settings ['ELEN_PARAM' ].value )
340+ vlenmax = int (settings ['reg_width_bits' ].value * settings ['LMUL_ACC' ].value /
341+ settings ['ELEN_PARAM' ].value )
341342 a_regs = max (int (M / vlenmax ), 1 )
342343
343- accumulation_regs = a_regs * N * settings ['LMUL_ACC' ].value
344+ # for each vector register loaded from matrix A, we require N registers to hold vector-scalar multiply-accumulate results
345+ accumulation_regs = a_regs * N
344346 required_regs = accumulation_regs + a_regs
345347 if is_complex :
346348 required_regs = required_regs * 2 + 2
@@ -380,9 +382,9 @@ def generate_gemm_kernel( settings, OUTPUT ):
380382''' .format (tail_policy = settings ['tail_policy' ].value ))
381383
382384
383- if required_regs > 32 :
384- raise Exception ("{} vector registers needed during accumulation for unrolling {} x {}{} but only 32 are available" .format (
385- required_regs , N , M , (" with wide accumulator" if settings ['LMUL_ACC' ].value > 1 else '' )
385+ if required_regs > ( 32 // settings [ 'LMUL_ACC' ]. value ) :
386+ raise Exception ("{} vector registers needed during accumulation for unrolling {} x {}{} but only {} are available" .format (
387+ required_regs , N , M , (" with wide accumulator" if settings ['LMUL_ACC' ].value > 1 else '' ), 32 // settings [ 'LMUL_ACC' ]. value
386388 ))
387389
388390 TRMM = (settings ['op' ].value == 'trmm' )
@@ -448,7 +450,8 @@ def generate_gemm_kernel( settings, OUTPUT ):
448450def generate_M_tails ( dest , settings , M , N ):
449451 M_tail = int (M / 2 )
450452 M_tail_min = settings ['M_tail_scalar_from' ].value
451- vlenmax = int ( settings ['reg_width_bits' ].value / settings ['ELEN_PARAM' ].value )
453+ vlenmax = int (settings ['reg_width_bits' ].value * settings ['LMUL_ACC' ].value
454+ / settings ['ELEN_PARAM' ].value )
452455 TRMM = (settings ['op' ].value == 'trmm' )
453456 is_complex = settings ['complex' ].value
454457 generate_gemm_kernel_inner = generate_gemm_kernel_inner_complex if is_complex else generate_gemm_kernel_inner_real
@@ -667,4 +670,4 @@ def OUTPUT(*args, **kwargs):
667670 ERROR ("unsupported kernel type {}" .format (settings ['op' ]))
668671
669672if __name__ == "__main__" :
670- main ()
673+ main ()
0 commit comments