Skip to content

Commit 9c16449

Browse files
committed
Add K-unrolling to M = 8. Other small changes.
1 parent 6d6af1d commit 9c16449

1 file changed

Lines changed: 166 additions & 23 deletions

File tree

kernel/riscv64/sgemm_kernel_16x8_zvl256b.c

Lines changed: 166 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -151,7 +151,7 @@ static void FORCEINLINE M_TAIL_ONE(BLASLONG K, const BLASLONG M, const BLASLONG
151151
B += N;
152152

153153
#ifdef GEMM_NEW_PACKING
154-
resultE = __riscv_vfmacc_vf_f32m1(resultE, A0[0 + (M & 0xE)], B0, N);
154+
resultE = __riscv_vfmacc_vf_f32m1(resultE, A0[0 + (M & 0x6)], B0, N);
155155
A0 += M;
156156
#else
157157
resultE = __riscv_vfmacc_vf_f32m1(resultE, A3[0], B0, N);
@@ -283,11 +283,11 @@ static void FORCEINLINE M_TAIL_ONE(BLASLONG K, const BLASLONG M, const BLASLONG
283283

284284
#ifdef GEMM_NEW_PACKING
285285
if (M & 2) {
286-
resultC = __riscv_vfmacc_vf_f32m1(resultC, A0[0 + (M & 0xC)], B0, N);
287-
resultD = __riscv_vfmacc_vf_f32m1(resultD, A0[1 + (M & 0xC)], B0, N);
286+
resultC = __riscv_vfmacc_vf_f32m1(resultC, A0[0 + (M & 0x4)], B0, N);
287+
resultD = __riscv_vfmacc_vf_f32m1(resultD, A0[1 + (M & 0x4)], B0, N);
288288
}
289289
if (M & 1) {
290-
resultE = __riscv_vfmacc_vf_f32m1(resultE, A0[0 + (M & 0xE)], B0, N);
290+
resultE = __riscv_vfmacc_vf_f32m1(resultE, A0[0 + (M & 0x6)], B0, N);
291291
}
292292
A0 += M;
293293
#else
@@ -302,13 +302,61 @@ static void FORCEINLINE M_TAIL_ONE(BLASLONG K, const BLASLONG M, const BLASLONG
302302
}
303303
#endif
304304
}
305-
} else if (M <= 7) {
305+
} else if (M <= 8) {
306+
vfloat32m1_t A4;
307+
306308
if (K >= 2) {
307-
vfloat32m2_t B00 = __riscv_vle32_v_f32m2(B, N * 2);
308-
B0 = __riscv_vget_v_f32m2_f32m1(B00, 0);
309-
B1 = __riscv_vget_v_f32m2_f32m1(B00, 1);
310-
B += (N * 2);
309+
vfloat32m2_t B00, A00;
310+
vfloat32m1_t A5;
311+
vfloat32m1_t resultF;
311312

313+
if (!S2) {
314+
B00 = __riscv_vle32_v_f32m2(B, N * 2);
315+
B0 = __riscv_vget_v_f32m2_f32m1(B00, 0);
316+
B1 = __riscv_vget_v_f32m2_f32m1(B00, 1);
317+
}
318+
319+
if (M == 8) {
320+
if (S2) {
321+
A00 = __riscv_vle32_v_f32m2(A0, N * 2);
322+
A4 = __riscv_vget_v_f32m2_f32m1(A00, 0);
323+
A5 = __riscv_vget_v_f32m2_f32m1(A00, 1);
324+
325+
result0 = __riscv_vfmul_vf_f32m1(A4, B[0], N);
326+
result1 = __riscv_vfmul_vf_f32m1(A4, B[1], N);
327+
result2 = __riscv_vfmul_vf_f32m1(A4, B[2], N);
328+
result3 = __riscv_vfmul_vf_f32m1(A4, B[3], N);
329+
result4 = __riscv_vfmul_vf_f32m1(A4, B[4], N);
330+
result5 = __riscv_vfmul_vf_f32m1(A4, B[5], N);
331+
result6 = __riscv_vfmul_vf_f32m1(A4, B[6], N);
332+
result7 = __riscv_vfmul_vf_f32m1(A4, B[7], N);
333+
result8 = __riscv_vfmul_vf_f32m1(A5, B[8], N);
334+
result9 = __riscv_vfmul_vf_f32m1(A5, B[9], N);
335+
resultA = __riscv_vfmul_vf_f32m1(A5, B[10], N);
336+
resultB = __riscv_vfmul_vf_f32m1(A5, B[11], N);
337+
resultC = __riscv_vfmul_vf_f32m1(A5, B[12], N);
338+
resultD = __riscv_vfmul_vf_f32m1(A5, B[13], N);
339+
resultE = __riscv_vfmul_vf_f32m1(A5, B[14], N);
340+
resultF = __riscv_vfmul_vf_f32m1(A5, B[15], N);
341+
} else {
342+
result0 = __riscv_vfmul_vf_f32m1(B0, A0[0], N);
343+
result1 = __riscv_vfmul_vf_f32m1(B0, A0[1], N);
344+
result2 = __riscv_vfmul_vf_f32m1(B0, A0[2], N);
345+
result3 = __riscv_vfmul_vf_f32m1(B0, A0[3], N);
346+
result4 = __riscv_vfmul_vf_f32m1(B0, A0[4], N);
347+
result5 = __riscv_vfmul_vf_f32m1(B0, A0[5], N);
348+
result6 = __riscv_vfmul_vf_f32m1(B0, A0[6], N);
349+
result7 = __riscv_vfmul_vf_f32m1(B0, A0[7], N);
350+
result8 = __riscv_vfmul_vf_f32m1(B1, A0[8], N);
351+
result9 = __riscv_vfmul_vf_f32m1(B1, A0[9], N);
352+
resultA = __riscv_vfmul_vf_f32m1(B1, A0[10], N);
353+
resultB = __riscv_vfmul_vf_f32m1(B1, A0[11], N);
354+
resultC = __riscv_vfmul_vf_f32m1(B1, A0[12], N);
355+
resultD = __riscv_vfmul_vf_f32m1(B1, A0[13], N);
356+
resultE = __riscv_vfmul_vf_f32m1(B1, A0[14], N);
357+
resultF = __riscv_vfmul_vf_f32m1(B1, A0[15], N);
358+
}
359+
}
312360
#ifdef GEMM_NEW_PACKING
313361
if (M & 4) {
314362
result8 = __riscv_vfmul_vf_f32m1(B0, A0[0 + (M * 0)], N);
@@ -355,14 +403,60 @@ static void FORCEINLINE M_TAIL_ONE(BLASLONG K, const BLASLONG M, const BLASLONG
355403
result6 = __riscv_vfmul_vf_f32m1(B1, A3[0 + (1 * 1)], N);
356404
A3 += (1 * 2);
357405
}
406+
if (M == 8) {
407+
A0 += (N * 2);
408+
}
358409
#endif
410+
B += (N * 2);
359411

360412
for (BLASLONG k = (K / 2); --k; ) {
361-
B00 = __riscv_vle32_v_f32m2(B, N * 2);
362-
B0 = __riscv_vget_v_f32m2_f32m1(B00, 0);
363-
B1 = __riscv_vget_v_f32m2_f32m1(B00, 1);
364-
B += (N * 2);
413+
if (!S2) {
414+
B00 = __riscv_vle32_v_f32m2(B, N * 2);
415+
B0 = __riscv_vget_v_f32m2_f32m1(B00, 0);
416+
B1 = __riscv_vget_v_f32m2_f32m1(B00, 1);
417+
}
365418

419+
if (M == 8) {
420+
if (S2) {
421+
A00 = __riscv_vle32_v_f32m2(A0, N * 2);
422+
A4 = __riscv_vget_v_f32m2_f32m1(A00, 0);
423+
A5 = __riscv_vget_v_f32m2_f32m1(A00, 1);
424+
425+
result0 = __riscv_vfmacc_vf_f32m1(result0, B[0], A4, N);
426+
result1 = __riscv_vfmacc_vf_f32m1(result1, B[1], A4, N);
427+
result2 = __riscv_vfmacc_vf_f32m1(result2, B[2], A4, N);
428+
result3 = __riscv_vfmacc_vf_f32m1(result3, B[3], A4, N);
429+
result4 = __riscv_vfmacc_vf_f32m1(result4, B[4], A4, N);
430+
result5 = __riscv_vfmacc_vf_f32m1(result5, B[5], A4, N);
431+
result6 = __riscv_vfmacc_vf_f32m1(result6, B[6], A4, N);
432+
result7 = __riscv_vfmacc_vf_f32m1(result7, B[7], A4, N);
433+
result8 = __riscv_vfmacc_vf_f32m1(result8, B[8], A5, N);
434+
result9 = __riscv_vfmacc_vf_f32m1(result9, B[9], A5, N);
435+
resultA = __riscv_vfmacc_vf_f32m1(resultA, B[10], A5, N);
436+
resultB = __riscv_vfmacc_vf_f32m1(resultB, B[11], A5, N);
437+
resultC = __riscv_vfmacc_vf_f32m1(resultC, B[12], A5, N);
438+
resultD = __riscv_vfmacc_vf_f32m1(resultD, B[13], A5, N);
439+
resultE = __riscv_vfmacc_vf_f32m1(resultE, B[14], A5, N);
440+
resultF = __riscv_vfmacc_vf_f32m1(resultF, B[15], A5, N);
441+
} else {
442+
result0 = __riscv_vfmacc_vf_f32m1(result0, A0[0], B0, N);
443+
result1 = __riscv_vfmacc_vf_f32m1(result1, A0[1], B0, N);
444+
result2 = __riscv_vfmacc_vf_f32m1(result2, A0[2], B0, N);
445+
result3 = __riscv_vfmacc_vf_f32m1(result3, A0[3], B0, N);
446+
result4 = __riscv_vfmacc_vf_f32m1(result4, A0[4], B0, N);
447+
result5 = __riscv_vfmacc_vf_f32m1(result5, A0[5], B0, N);
448+
result6 = __riscv_vfmacc_vf_f32m1(result6, A0[6], B0, N);
449+
result7 = __riscv_vfmacc_vf_f32m1(result7, A0[7], B0, N);
450+
result8 = __riscv_vfmacc_vf_f32m1(result8, A0[8], B1, N);
451+
result9 = __riscv_vfmacc_vf_f32m1(result9, A0[9], B1, N);
452+
resultA = __riscv_vfmacc_vf_f32m1(resultA, A0[10], B1, N);
453+
resultB = __riscv_vfmacc_vf_f32m1(resultB, A0[11], B1, N);
454+
resultC = __riscv_vfmacc_vf_f32m1(resultC, A0[12], B1, N);
455+
resultD = __riscv_vfmacc_vf_f32m1(resultD, A0[13], B1, N);
456+
resultE = __riscv_vfmacc_vf_f32m1(resultE, A0[14], B1, N);
457+
resultF = __riscv_vfmacc_vf_f32m1(resultF, A0[15], B1, N);
458+
}
459+
}
366460
#ifdef GEMM_NEW_PACKING
367461
if (M & 4) {
368462
result8 = __riscv_vfmacc_vf_f32m1(result8, A0[0 + (M * 0)], B0, N);
@@ -409,9 +503,23 @@ static void FORCEINLINE M_TAIL_ONE(BLASLONG K, const BLASLONG M, const BLASLONG
409503
result6 = __riscv_vfmacc_vf_f32m1(result6, A3[0 + (1 * 1)], B1, N);
410504
A3 += (1 * 2);
411505
}
506+
if (M == 8) {
507+
A0 += (N * 2);
508+
}
412509
#endif
510+
B += (N * 2);
413511
}
414512

513+
if (M == 8) {
514+
result0 = __riscv_vfadd_vv_f32m1(result0, result8, N);
515+
result1 = __riscv_vfadd_vv_f32m1(result1, result9, N);
516+
result2 = __riscv_vfadd_vv_f32m1(result2, resultA, N);
517+
result3 = __riscv_vfadd_vv_f32m1(result3, resultB, N);
518+
result4 = __riscv_vfadd_vv_f32m1(result4, resultC, N);
519+
result5 = __riscv_vfadd_vv_f32m1(result5, resultD, N);
520+
result6 = __riscv_vfadd_vv_f32m1(result6, resultE, N);
521+
result7 = __riscv_vfadd_vv_f32m1(result7, resultF, N);
522+
}
415523
if (M & 4) {
416524
result8 = __riscv_vfadd_vv_f32m1(result8, result0, N);
417525
result9 = __riscv_vfadd_vv_f32m1(result9, result1, N);
@@ -426,6 +534,16 @@ static void FORCEINLINE M_TAIL_ONE(BLASLONG K, const BLASLONG M, const BLASLONG
426534
resultE = __riscv_vfadd_vv_f32m1(resultE, result6, N);
427535
}
428536
} else {
537+
if (M == 8) {
538+
result0 = __riscv_vreinterpret_v_u32m1_f32m1(__riscv_vmv_v_x_u32m1(0, N));
539+
result1 = __riscv_vreinterpret_v_u32m1_f32m1(__riscv_vmv_v_x_u32m1(0, N));
540+
result2 = __riscv_vreinterpret_v_u32m1_f32m1(__riscv_vmv_v_x_u32m1(0, N));
541+
result3 = __riscv_vreinterpret_v_u32m1_f32m1(__riscv_vmv_v_x_u32m1(0, N));
542+
result4 = __riscv_vreinterpret_v_u32m1_f32m1(__riscv_vmv_v_x_u32m1(0, N));
543+
result5 = __riscv_vreinterpret_v_u32m1_f32m1(__riscv_vmv_v_x_u32m1(0, N));
544+
result6 = __riscv_vreinterpret_v_u32m1_f32m1(__riscv_vmv_v_x_u32m1(0, N));
545+
result7 = __riscv_vreinterpret_v_u32m1_f32m1(__riscv_vmv_v_x_u32m1(0, N));
546+
}
429547
if (M & 4) {
430548
result8 = __riscv_vreinterpret_v_u32m1_f32m1(__riscv_vmv_v_x_u32m1(0, N));
431549
result9 = __riscv_vreinterpret_v_u32m1_f32m1(__riscv_vmv_v_x_u32m1(0, N));
@@ -442,21 +560,45 @@ static void FORCEINLINE M_TAIL_ONE(BLASLONG K, const BLASLONG M, const BLASLONG
442560
}
443561

444562
if (K & 1) {
445-
B0 = __riscv_vle32_v_f32m1(B, N);
563+
if (!S2) {
564+
B0 = __riscv_vle32_v_f32m1(B, N);
565+
}
446566

567+
if (M == 8) {
568+
if (S2) {
569+
A4 = __riscv_vle32_v_f32m1(A0, N);
570+
result0 = __riscv_vfmacc_vf_f32m1(result0, B[0], A4, N);
571+
result1 = __riscv_vfmacc_vf_f32m1(result1, B[1], A4, N);
572+
result2 = __riscv_vfmacc_vf_f32m1(result2, B[2], A4, N);
573+
result3 = __riscv_vfmacc_vf_f32m1(result3, B[3], A4, N);
574+
result4 = __riscv_vfmacc_vf_f32m1(result4, B[4], A4, N);
575+
result5 = __riscv_vfmacc_vf_f32m1(result5, B[5], A4, N);
576+
result6 = __riscv_vfmacc_vf_f32m1(result6, B[6], A4, N);
577+
result7 = __riscv_vfmacc_vf_f32m1(result7, B[7], A4, N);
578+
} else {
579+
result0 = __riscv_vfmacc_vf_f32m1(result0, A0[0], B0, N);
580+
result1 = __riscv_vfmacc_vf_f32m1(result1, A0[1], B0, N);
581+
result2 = __riscv_vfmacc_vf_f32m1(result2, A0[2], B0, N);
582+
result3 = __riscv_vfmacc_vf_f32m1(result3, A0[3], B0, N);
583+
result4 = __riscv_vfmacc_vf_f32m1(result4, A0[4], B0, N);
584+
result5 = __riscv_vfmacc_vf_f32m1(result5, A0[5], B0, N);
585+
result6 = __riscv_vfmacc_vf_f32m1(result6, A0[6], B0, N);
586+
result7 = __riscv_vfmacc_vf_f32m1(result7, A0[7], B0, N);
587+
}
588+
}
447589
#ifdef GEMM_NEW_PACKING
448590
if (M & 4) {
449-
result8 = __riscv_vfmacc_vf_f32m1(result8, A0[0 + (M & 0x8)], B0, N);
450-
result9 = __riscv_vfmacc_vf_f32m1(result9, A0[1 + (M & 0x8)], B0, N);
451-
resultA = __riscv_vfmacc_vf_f32m1(resultA, A0[2 + (M & 0x8)], B0, N);
452-
resultB = __riscv_vfmacc_vf_f32m1(resultB, A0[3 + (M & 0x8)], B0, N);
591+
result8 = __riscv_vfmacc_vf_f32m1(result8, A0[0], B0, N);
592+
result9 = __riscv_vfmacc_vf_f32m1(result9, A0[1], B0, N);
593+
resultA = __riscv_vfmacc_vf_f32m1(resultA, A0[2], B0, N);
594+
resultB = __riscv_vfmacc_vf_f32m1(resultB, A0[3], B0, N);
453595
}
454596
if (M & 2) {
455-
resultC = __riscv_vfmacc_vf_f32m1(resultC, A0[0 + (M & 0xC)], B0, N);
456-
resultD = __riscv_vfmacc_vf_f32m1(resultD, A0[1 + (M & 0xC)], B0, N);
597+
resultC = __riscv_vfmacc_vf_f32m1(resultC, A0[0 + (M & 0x4)], B0, N);
598+
resultD = __riscv_vfmacc_vf_f32m1(resultD, A0[1 + (M & 0x4)], B0, N);
457599
}
458600
if (M & 1) {
459-
resultE = __riscv_vfmacc_vf_f32m1(resultE, A0[0 + (M & 0xE)], B0, N);
601+
resultE = __riscv_vfmacc_vf_f32m1(resultE, A0[0 + (M & 0x6)], B0, N);
460602
}
461603
#else
462604
if (M & 4) {
@@ -1488,8 +1630,9 @@ static void FORCEINLINE N_TAIL_ONE(BLASLONG K, BLASLONG M, const BLASLONG N, FLO
14881630
}
14891631
#endif
14901632

1491-
vfloat32m1_t A0 = __riscv_vle32_v_f32m1(A + 0, 8);
1492-
vfloat32m1_t A1 = __riscv_vle32_v_f32m1(A + 8, 8);
1633+
vfloat32m2_t A00 = __riscv_vle32_v_f32m2(A, 8 * 2);
1634+
vfloat32m1_t A0 = __riscv_vget_v_f32m2_f32m1(A00, 0);
1635+
vfloat32m1_t A1 = __riscv_vget_v_f32m2_f32m1(A00, 1);
14931636
A += 16;
14941637

14951638
vfloat32m1_t result0, result1, result2, result3, result4, result5, result6, result7;

0 commit comments

Comments
 (0)