Skip to content

Commit 4121a22

Browse files
committed
Convert BF16 values once (and vectorized).
1 parent 3356043 · commit 4121a22

2 files changed

Lines changed: 195 additions & 71 deletions

File tree

kernel/riscv64/sbgemm_kernel_16x8_zvl256b.c

Lines changed: 108 additions & 42 deletions
Original file line number | Diff line number | Diff line change
@@ -11,14 +11,41 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, IFLOAT *A, IFLOAT *B,
1111
__bf16 *BB = (__bf16 *)(B);
1212
__bf16 *AA = (__bf16 *)(A);
1313

14+
#ifdef BF16_WIDEN_ONE
15+
FLOAT *B_CONV = NULL;
16+
if ((M >= 4) && (N >= 4) && (K > 0)) {
17+
B_CONV = (FLOAT *)(malloc(K * 8 * sizeof(FLOAT)));
18+
if (!B_CONV) return 1;
19+
}
20+
#endif
21+
1422
// -- MAIN PASS
1523
for (BLASLONG j=0; j<N/8; j+=1) {
1624
m_top = 0;
1725
BLASLONG gvl = __riscv_vsetvl_e16m1(16);
26+
#ifdef BF16_WIDEN_ONE
27+
BLASLONG bi2;
28+
if (B_CONV) {
29+
BLASLONG bi3 = 0;
30+
BLASLONG gvl2;
31+
bi2 = K * 8;
32+
do {
33+
gvl2 = __riscv_vsetvl_e16m4(bi2);
34+
vbfloat16m4_t A00 = __riscv_vle16_v_bf16m4(&BB[bi3 + (n_top*K)], gvl2);
35+
vfloat32m8_t A0 = __riscv_vfwcvtbf16_f_f_v_f32m8(A00, gvl2);
36+
__riscv_vse32_v_f32m8(&B_CONV[bi3], A0, gvl2);
37+
bi3 += gvl2;
38+
} while (bi2 -= gvl2);
39+
}
40+
#endif
1841

1942
for (BLASLONG i=0; i<M/16; i+=1) {
2043
BLASLONG ai=m_top*K;
44+
#ifdef BF16_WIDEN_ONE
45+
bi2 = 0;
46+
#else
2147
BLASLONG bi=n_top*K;
48+
#endif
2249

2350
vfloat32m2_t result0 = __riscv_vfmv_v_f_f32m2(0.0f, gvl);
2451
vfloat32m2_t result1 = __riscv_vfmv_v_f_f32m2(0.0f, gvl);
@@ -31,15 +58,15 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, IFLOAT *A, IFLOAT *B,
3158

3259
for (BLASLONG k=0; k<K; k++) {
3360
#ifdef BF16_WIDEN_ONE
34-
float B0 = (float)(BB[bi+0]);
35-
float B1 = (float)(BB[bi+1]);
36-
float B2 = (float)(BB[bi+2]);
37-
float B3 = (float)(BB[bi+3]);
38-
float B4 = (float)(BB[bi+4]);
39-
float B5 = (float)(BB[bi+5]);
40-
float B6 = (float)(BB[bi+6]);
41-
float B7 = (float)(BB[bi+7]);
42-
bi += 8;
61+
float B0 = B_CONV[bi2+0];
62+
float B1 = B_CONV[bi2+1];
63+
float B2 = B_CONV[bi2+2];
64+
float B3 = B_CONV[bi2+3];
65+
float B4 = B_CONV[bi2+4];
66+
float B5 = B_CONV[bi2+5];
67+
float B6 = B_CONV[bi2+6];
68+
float B7 = B_CONV[bi2+7];
69+
bi2 += 8;
4370

4471
vbfloat16m1_t A00 = __riscv_vle16_v_bf16m1( &AA[ai+0*gvl], gvl );
4572
vfloat32m2_t A0 = __riscv_vfwcvtbf16_f_f_v_f32m2(A00, gvl);
@@ -117,7 +144,11 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, IFLOAT *A, IFLOAT *B,
117144
gvl = __riscv_vsetvl_e16mf2(8);
118145

119146
BLASLONG ai=m_top*K;
147+
#ifdef BF16_WIDEN_ONE
148+
bi2 = 0;
149+
#else
120150
BLASLONG bi=n_top*K;
151+
#endif
121152

122153
vfloat32m1_t result0 = __riscv_vfmv_v_f_f32m1(0.0f, gvl);
123154
vfloat32m1_t result1 = __riscv_vfmv_v_f_f32m1(0.0f, gvl);
@@ -130,15 +161,15 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, IFLOAT *A, IFLOAT *B,
130161

131162
for (BLASLONG k=0; k<K; k++) {
132163
#ifdef BF16_WIDEN_ONE
133-
float B0 = (float)(BB[bi+0]);
134-
float B1 = (float)(BB[bi+1]);
135-
float B2 = (float)(BB[bi+2]);
136-
float B3 = (float)(BB[bi+3]);
137-
float B4 = (float)(BB[bi+4]);
138-
float B5 = (float)(BB[bi+5]);
139-
float B6 = (float)(BB[bi+6]);
140-
float B7 = (float)(BB[bi+7]);
141-
bi += 8;
164+
float B0 = B_CONV[bi2+0];
165+
float B1 = B_CONV[bi2+1];
166+
float B2 = B_CONV[bi2+2];
167+
float B3 = B_CONV[bi2+3];
168+
float B4 = B_CONV[bi2+4];
169+
float B5 = B_CONV[bi2+5];
170+
float B6 = B_CONV[bi2+6];
171+
float B7 = B_CONV[bi2+7];
172+
bi2 += 8;
142173

143174
vbfloat16mf2_t A00 = __riscv_vle16_v_bf16mf2( &AA[ai+0*gvl], gvl );
144175
vfloat32m1_t A0 = __riscv_vfwcvtbf16_f_f_v_f32m1(A00, gvl);
@@ -214,7 +245,11 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, IFLOAT *A, IFLOAT *B,
214245
gvl = __riscv_vsetvl_e16mf2(4);
215246

216247
BLASLONG ai=m_top*K;
248+
#ifdef BF16_WIDEN_ONE
249+
bi2 = 0;
250+
#else
217251
BLASLONG bi=n_top*K;
252+
#endif
218253

219254
vfloat32m1_t result0 = __riscv_vfmv_v_f_f32m1(0.0f, gvl);
220255
vfloat32m1_t result1 = __riscv_vfmv_v_f_f32m1(0.0f, gvl);
@@ -227,15 +262,15 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, IFLOAT *A, IFLOAT *B,
227262

228263
for (BLASLONG k=0; k < K; ++k) {
229264
#ifdef BF16_WIDEN_ONE
230-
float B0 = (float)(BB[bi+0]);
231-
float B1 = (float)(BB[bi+1]);
232-
float B2 = (float)(BB[bi+2]);
233-
float B3 = (float)(BB[bi+3]);
234-
float B4 = (float)(BB[bi+4]);
235-
float B5 = (float)(BB[bi+5]);
236-
float B6 = (float)(BB[bi+6]);
237-
float B7 = (float)(BB[bi+7]);
238-
bi += 8;
265+
float B0 = B_CONV[bi2+0];
266+
float B1 = B_CONV[bi2+1];
267+
float B2 = B_CONV[bi2+2];
268+
float B3 = B_CONV[bi2+3];
269+
float B4 = B_CONV[bi2+4];
270+
float B5 = B_CONV[bi2+5];
271+
float B6 = B_CONV[bi2+6];
272+
float B7 = B_CONV[bi2+7];
273+
bi2 += 8;
239274

240275
vbfloat16mf4_t A00 = __riscv_vle16_v_bf16mf4( &AA[ai+0*gvl], gvl );
241276
vfloat32m1_t A0 = __riscv_vlmul_ext_v_f32mf2_f32m1(__riscv_vfwcvtbf16_f_f_v_f32mf2(A00, gvl));
@@ -423,9 +458,29 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, IFLOAT *A, IFLOAT *B,
423458
gvl = __riscv_vsetvl_e16m1(16);
424459
m_top = 0;
425460

461+
#ifdef BF16_WIDEN_ONE
462+
BLASLONG bi2;
463+
if (B_CONV) {
464+
BLASLONG bi3 = 0;
465+
BLASLONG gvl2;
466+
bi2 = K * 4;
467+
do {
468+
gvl2 = __riscv_vsetvl_e16m4(bi2);
469+
vbfloat16m4_t A00 = __riscv_vle16_v_bf16m4(&BB[bi3 + (n_top*K)], gvl2);
470+
vfloat32m8_t A0 = __riscv_vfwcvtbf16_f_f_v_f32m8(A00, gvl2);
471+
__riscv_vse32_v_f32m8(&B_CONV[bi3], A0, gvl2);
472+
bi3 += gvl2;
473+
} while (bi2 -= gvl2);
474+
}
475+
#endif
476+
426477
for (BLASLONG i=0; i<M/16; i+=1) {
427478
BLASLONG ai=m_top*K;
479+
#ifdef BF16_WIDEN_ONE
480+
bi2 = 0;
481+
#else
428482
BLASLONG bi=n_top*K;
483+
#endif
429484

430485
vfloat32m2_t result0 = __riscv_vfmv_v_f_f32m2(0.0f, gvl);
431486
vfloat32m2_t result1 = __riscv_vfmv_v_f_f32m2(0.0f, gvl);
@@ -434,11 +489,11 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, IFLOAT *A, IFLOAT *B,
434489

435490
for (BLASLONG k=0; k<K; k++) {
436491
#ifdef BF16_WIDEN_ONE
437-
float B0 = (float)(BB[bi+0]);
438-
float B1 = (float)(BB[bi+1]);
439-
float B2 = (float)(BB[bi+2]);
440-
float B3 = (float)(BB[bi+3]);
441-
bi += 4;
492+
float B0 = B_CONV[bi2+0];
493+
float B1 = B_CONV[bi2+1];
494+
float B2 = B_CONV[bi2+2];
495+
float B3 = B_CONV[bi2+3];
496+
bi2 += 4;
442497

443498
vbfloat16m1_t A00 = __riscv_vle16_v_bf16m1( &AA[ai+0*gvl], gvl );
444499
vfloat32m2_t A0 = __riscv_vfwcvtbf16_f_f_v_f32m2(A00, gvl);
@@ -489,7 +544,11 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, IFLOAT *A, IFLOAT *B,
489544
if ( M & 8 ) {
490545
gvl = __riscv_vsetvl_e16mf2(8);
491546
BLASLONG ai=m_top*K;
547+
#ifdef BF16_WIDEN_ONE
548+
bi2 = 0;
549+
#else
492550
BLASLONG bi=n_top*K;
551+
#endif
493552

494553
vfloat32m1_t result0 = __riscv_vfmv_v_f_f32m1(0.0f, gvl);
495554
vfloat32m1_t result1 = __riscv_vfmv_v_f_f32m1(0.0f, gvl);
@@ -498,11 +557,11 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, IFLOAT *A, IFLOAT *B,
498557

499558
for (BLASLONG k=0; k<K; k++) {
500559
#ifdef BF16_WIDEN_ONE
501-
float B0 = (float)(BB[bi+0]);
502-
float B1 = (float)(BB[bi+1]);
503-
float B2 = (float)(BB[bi+2]);
504-
float B3 = (float)(BB[bi+3]);
505-
bi += 4;
560+
float B0 = B_CONV[bi2+0];
561+
float B1 = B_CONV[bi2+1];
562+
float B2 = B_CONV[bi2+2];
563+
float B3 = B_CONV[bi2+3];
564+
bi2 += 4;
506565

507566
vbfloat16mf2_t A00 = __riscv_vle16_v_bf16mf2( &AA[ai+0*gvl], gvl );
508567
vfloat32m1_t A0 = __riscv_vfwcvtbf16_f_f_v_f32m1(A00, gvl);
@@ -554,7 +613,11 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, IFLOAT *A, IFLOAT *B,
554613
gvl = __riscv_vsetvl_e16mf2(4);
555614

556615
BLASLONG ai=m_top*K;
616+
#ifdef BF16_WIDEN_ONE
617+
bi2 = 0;
618+
#else
557619
BLASLONG bi=n_top*K;
620+
#endif
558621

559622
vfloat32m1_t result0 = __riscv_vfmv_v_f_f32m1(0.0f, gvl);
560623
vfloat32m1_t result1 = __riscv_vfmv_v_f_f32m1(0.0f, gvl);
@@ -563,11 +626,11 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, IFLOAT *A, IFLOAT *B,
563626

564627
for (BLASLONG k=0; k < K; ++k) {
565628
#ifdef BF16_WIDEN_ONE
566-
float B0 = (float)(BB[bi+0]);
567-
float B1 = (float)(BB[bi+1]);
568-
float B2 = (float)(BB[bi+2]);
569-
float B3 = (float)(BB[bi+3]);
570-
bi += 4;
629+
float B0 = B_CONV[bi2+0];
630+
float B1 = B_CONV[bi2+1];
631+
float B2 = B_CONV[bi2+2];
632+
float B3 = B_CONV[bi2+3];
633+
bi2 += 4;
571634

572635
vbfloat16mf4_t A00 = __riscv_vle16_v_bf16mf4( &AA[ai+0*gvl], gvl );
573636
vfloat32m1_t A0 = __riscv_vlmul_ext_v_f32mf2_f32m1(__riscv_vfwcvtbf16_f_f_v_f32mf2(A00, gvl));
@@ -977,5 +1040,8 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, IFLOAT *A, IFLOAT *B,
9771040

9781041
n_top += 1;
9791042
}
1043+
#ifdef BF16_WIDEN_ONE
1044+
if (B_CONV) free(B_CONV);
1045+
#endif
9801046
return 0;
9811047
}

0 commit comments

Comments
 (0)