@@ -11,14 +11,41 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, IFLOAT *A, IFLOAT *B,
1111 __bf16 * BB = (__bf16 * )(B );
1212 __bf16 * AA = (__bf16 * )(A );
1313
14+ #ifdef BF16_WIDEN_ONE
15+ FLOAT * B_CONV = NULL ;
16+ if ((M >= 4 ) && (N >= 4 ) && (K > 0 )) {
17+ B_CONV = (FLOAT * )(malloc (K * 8 * sizeof (FLOAT )));
18+ if (!B_CONV ) return 1 ;
19+ }
20+ #endif
21+
1422 // -- MAIN PASS
1523 for (BLASLONG j = 0 ; j < N /8 ; j += 1 ) {
1624 m_top = 0 ;
1725 BLASLONG gvl = __riscv_vsetvl_e16m1 (16 );
26+ #ifdef BF16_WIDEN_ONE
27+ BLASLONG bi2 ;
28+ if (B_CONV ) {
29+ BLASLONG bi3 = 0 ;
30+ BLASLONG gvl2 ;
31+ bi2 = K * 8 ;
32+ do {
33+ gvl2 = __riscv_vsetvl_e16m4 (bi2 );
34+ vbfloat16m4_t A00 = __riscv_vle16_v_bf16m4 (& BB [bi3 + (n_top * K )], gvl2 );
35+ vfloat32m8_t A0 = __riscv_vfwcvtbf16_f_f_v_f32m8 (A00 , gvl2 );
36+ __riscv_vse32_v_f32m8 (& B_CONV [bi3 ], A0 , gvl2 );
37+ bi3 += gvl2 ;
38+ } while (bi2 -= gvl2 );
39+ }
40+ #endif
1841
1942 for (BLASLONG i = 0 ; i < M /16 ; i += 1 ) {
2043 BLASLONG ai = m_top * K ;
44+ #ifdef BF16_WIDEN_ONE
45+ bi2 = 0 ;
46+ #else
2147 BLASLONG bi = n_top * K ;
48+ #endif
2249
2350 vfloat32m2_t result0 = __riscv_vfmv_v_f_f32m2 (0.0f , gvl );
2451 vfloat32m2_t result1 = __riscv_vfmv_v_f_f32m2 (0.0f , gvl );
@@ -31,15 +58,15 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, IFLOAT *A, IFLOAT *B,
3158
3259 for (BLASLONG k = 0 ; k < K ; k ++ ) {
3360#ifdef BF16_WIDEN_ONE
34- float B0 = ( float )( BB [ bi + 0 ]) ;
35- float B1 = ( float )( BB [ bi + 1 ]) ;
36- float B2 = ( float )( BB [ bi + 2 ]) ;
37- float B3 = ( float )( BB [ bi + 3 ]) ;
38- float B4 = ( float )( BB [ bi + 4 ]) ;
39- float B5 = ( float )( BB [ bi + 5 ]) ;
40- float B6 = ( float )( BB [ bi + 6 ]) ;
41- float B7 = ( float )( BB [ bi + 7 ]) ;
42- bi += 8 ;
61+ float B0 = B_CONV [ bi2 + 0 ];
62+ float B1 = B_CONV [ bi2 + 1 ];
63+ float B2 = B_CONV [ bi2 + 2 ];
64+ float B3 = B_CONV [ bi2 + 3 ];
65+ float B4 = B_CONV [ bi2 + 4 ];
66+ float B5 = B_CONV [ bi2 + 5 ];
67+ float B6 = B_CONV [ bi2 + 6 ];
68+ float B7 = B_CONV [ bi2 + 7 ];
69+ bi2 += 8 ;
4370
4471 vbfloat16m1_t A00 = __riscv_vle16_v_bf16m1 ( & AA [ai + 0 * gvl ], gvl );
4572 vfloat32m2_t A0 = __riscv_vfwcvtbf16_f_f_v_f32m2 (A00 , gvl );
@@ -117,7 +144,11 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, IFLOAT *A, IFLOAT *B,
117144 gvl = __riscv_vsetvl_e16mf2 (8 );
118145
119146 BLASLONG ai = m_top * K ;
147+ #ifdef BF16_WIDEN_ONE
148+ bi2 = 0 ;
149+ #else
120150 BLASLONG bi = n_top * K ;
151+ #endif
121152
122153 vfloat32m1_t result0 = __riscv_vfmv_v_f_f32m1 (0.0f , gvl );
123154 vfloat32m1_t result1 = __riscv_vfmv_v_f_f32m1 (0.0f , gvl );
@@ -130,15 +161,15 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, IFLOAT *A, IFLOAT *B,
130161
131162 for (BLASLONG k = 0 ; k < K ; k ++ ) {
132163#ifdef BF16_WIDEN_ONE
133- float B0 = ( float )( BB [ bi + 0 ]) ;
134- float B1 = ( float )( BB [ bi + 1 ]) ;
135- float B2 = ( float )( BB [ bi + 2 ]) ;
136- float B3 = ( float )( BB [ bi + 3 ]) ;
137- float B4 = ( float )( BB [ bi + 4 ]) ;
138- float B5 = ( float )( BB [ bi + 5 ]) ;
139- float B6 = ( float )( BB [ bi + 6 ]) ;
140- float B7 = ( float )( BB [ bi + 7 ]) ;
141- bi += 8 ;
164+ float B0 = B_CONV [ bi2 + 0 ];
165+ float B1 = B_CONV [ bi2 + 1 ];
166+ float B2 = B_CONV [ bi2 + 2 ];
167+ float B3 = B_CONV [ bi2 + 3 ];
168+ float B4 = B_CONV [ bi2 + 4 ];
169+ float B5 = B_CONV [ bi2 + 5 ];
170+ float B6 = B_CONV [ bi2 + 6 ];
171+ float B7 = B_CONV [ bi2 + 7 ];
172+ bi2 += 8 ;
142173
143174 vbfloat16mf2_t A00 = __riscv_vle16_v_bf16mf2 ( & AA [ai + 0 * gvl ], gvl );
144175 vfloat32m1_t A0 = __riscv_vfwcvtbf16_f_f_v_f32m1 (A00 , gvl );
@@ -214,7 +245,11 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, IFLOAT *A, IFLOAT *B,
214245 gvl = __riscv_vsetvl_e16mf2 (4 );
215246
216247 BLASLONG ai = m_top * K ;
248+ #ifdef BF16_WIDEN_ONE
249+ bi2 = 0 ;
250+ #else
217251 BLASLONG bi = n_top * K ;
252+ #endif
218253
219254 vfloat32m1_t result0 = __riscv_vfmv_v_f_f32m1 (0.0f , gvl );
220255 vfloat32m1_t result1 = __riscv_vfmv_v_f_f32m1 (0.0f , gvl );
@@ -227,15 +262,15 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, IFLOAT *A, IFLOAT *B,
227262
228263 for (BLASLONG k = 0 ; k < K ; ++ k ) {
229264#ifdef BF16_WIDEN_ONE
230- float B0 = ( float )( BB [ bi + 0 ]) ;
231- float B1 = ( float )( BB [ bi + 1 ]) ;
232- float B2 = ( float )( BB [ bi + 2 ]) ;
233- float B3 = ( float )( BB [ bi + 3 ]) ;
234- float B4 = ( float )( BB [ bi + 4 ]) ;
235- float B5 = ( float )( BB [ bi + 5 ]) ;
236- float B6 = ( float )( BB [ bi + 6 ]) ;
237- float B7 = ( float )( BB [ bi + 7 ]) ;
238- bi += 8 ;
265+ float B0 = B_CONV [ bi2 + 0 ];
266+ float B1 = B_CONV [ bi2 + 1 ];
267+ float B2 = B_CONV [ bi2 + 2 ];
268+ float B3 = B_CONV [ bi2 + 3 ];
269+ float B4 = B_CONV [ bi2 + 4 ];
270+ float B5 = B_CONV [ bi2 + 5 ];
271+ float B6 = B_CONV [ bi2 + 6 ];
272+ float B7 = B_CONV [ bi2 + 7 ];
273+ bi2 += 8 ;
239274
240275 vbfloat16mf4_t A00 = __riscv_vle16_v_bf16mf4 ( & AA [ai + 0 * gvl ], gvl );
241276 vfloat32m1_t A0 = __riscv_vlmul_ext_v_f32mf2_f32m1 (__riscv_vfwcvtbf16_f_f_v_f32mf2 (A00 , gvl ));
@@ -423,9 +458,29 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, IFLOAT *A, IFLOAT *B,
423458 gvl = __riscv_vsetvl_e16m1 (16 );
424459 m_top = 0 ;
425460
461+ #ifdef BF16_WIDEN_ONE
462+ BLASLONG bi2 ;
463+ if (B_CONV ) {
464+ BLASLONG bi3 = 0 ;
465+ BLASLONG gvl2 ;
466+ bi2 = K * 4 ;
467+ do {
468+ gvl2 = __riscv_vsetvl_e16m4 (bi2 );
469+ vbfloat16m4_t A00 = __riscv_vle16_v_bf16m4 (& BB [bi3 + (n_top * K )], gvl2 );
470+ vfloat32m8_t A0 = __riscv_vfwcvtbf16_f_f_v_f32m8 (A00 , gvl2 );
471+ __riscv_vse32_v_f32m8 (& B_CONV [bi3 ], A0 , gvl2 );
472+ bi3 += gvl2 ;
473+ } while (bi2 -= gvl2 );
474+ }
475+ #endif
476+
426477 for (BLASLONG i = 0 ; i < M /16 ; i += 1 ) {
427478 BLASLONG ai = m_top * K ;
479+ #ifdef BF16_WIDEN_ONE
480+ bi2 = 0 ;
481+ #else
428482 BLASLONG bi = n_top * K ;
483+ #endif
429484
430485 vfloat32m2_t result0 = __riscv_vfmv_v_f_f32m2 (0.0f , gvl );
431486 vfloat32m2_t result1 = __riscv_vfmv_v_f_f32m2 (0.0f , gvl );
@@ -434,11 +489,11 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, IFLOAT *A, IFLOAT *B,
434489
435490 for (BLASLONG k = 0 ; k < K ; k ++ ) {
436491#ifdef BF16_WIDEN_ONE
437- float B0 = ( float )( BB [ bi + 0 ]) ;
438- float B1 = ( float )( BB [ bi + 1 ]) ;
439- float B2 = ( float )( BB [ bi + 2 ]) ;
440- float B3 = ( float )( BB [ bi + 3 ]) ;
441- bi += 4 ;
492+ float B0 = B_CONV [ bi2 + 0 ];
493+ float B1 = B_CONV [ bi2 + 1 ];
494+ float B2 = B_CONV [ bi2 + 2 ];
495+ float B3 = B_CONV [ bi2 + 3 ];
496+ bi2 += 4 ;
442497
443498 vbfloat16m1_t A00 = __riscv_vle16_v_bf16m1 ( & AA [ai + 0 * gvl ], gvl );
444499 vfloat32m2_t A0 = __riscv_vfwcvtbf16_f_f_v_f32m2 (A00 , gvl );
@@ -489,7 +544,11 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, IFLOAT *A, IFLOAT *B,
489544 if ( M & 8 ) {
490545 gvl = __riscv_vsetvl_e16mf2 (8 );
491546 BLASLONG ai = m_top * K ;
547+ #ifdef BF16_WIDEN_ONE
548+ bi2 = 0 ;
549+ #else
492550 BLASLONG bi = n_top * K ;
551+ #endif
493552
494553 vfloat32m1_t result0 = __riscv_vfmv_v_f_f32m1 (0.0f , gvl );
495554 vfloat32m1_t result1 = __riscv_vfmv_v_f_f32m1 (0.0f , gvl );
@@ -498,11 +557,11 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, IFLOAT *A, IFLOAT *B,
498557
499558 for (BLASLONG k = 0 ; k < K ; k ++ ) {
500559#ifdef BF16_WIDEN_ONE
501- float B0 = ( float )( BB [ bi + 0 ]) ;
502- float B1 = ( float )( BB [ bi + 1 ]) ;
503- float B2 = ( float )( BB [ bi + 2 ]) ;
504- float B3 = ( float )( BB [ bi + 3 ]) ;
505- bi += 4 ;
560+ float B0 = B_CONV [ bi2 + 0 ];
561+ float B1 = B_CONV [ bi2 + 1 ];
562+ float B2 = B_CONV [ bi2 + 2 ];
563+ float B3 = B_CONV [ bi2 + 3 ];
564+ bi2 += 4 ;
506565
507566 vbfloat16mf2_t A00 = __riscv_vle16_v_bf16mf2 ( & AA [ai + 0 * gvl ], gvl );
508567 vfloat32m1_t A0 = __riscv_vfwcvtbf16_f_f_v_f32m1 (A00 , gvl );
@@ -554,7 +613,11 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, IFLOAT *A, IFLOAT *B,
554613 gvl = __riscv_vsetvl_e16mf2 (4 );
555614
556615 BLASLONG ai = m_top * K ;
616+ #ifdef BF16_WIDEN_ONE
617+ bi2 = 0 ;
618+ #else
557619 BLASLONG bi = n_top * K ;
620+ #endif
558621
559622 vfloat32m1_t result0 = __riscv_vfmv_v_f_f32m1 (0.0f , gvl );
560623 vfloat32m1_t result1 = __riscv_vfmv_v_f_f32m1 (0.0f , gvl );
@@ -563,11 +626,11 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, IFLOAT *A, IFLOAT *B,
563626
564627 for (BLASLONG k = 0 ; k < K ; ++ k ) {
565628#ifdef BF16_WIDEN_ONE
566- float B0 = ( float )( BB [ bi + 0 ]) ;
567- float B1 = ( float )( BB [ bi + 1 ]) ;
568- float B2 = ( float )( BB [ bi + 2 ]) ;
569- float B3 = ( float )( BB [ bi + 3 ]) ;
570- bi += 4 ;
629+ float B0 = B_CONV [ bi2 + 0 ];
630+ float B1 = B_CONV [ bi2 + 1 ];
631+ float B2 = B_CONV [ bi2 + 2 ];
632+ float B3 = B_CONV [ bi2 + 3 ];
633+ bi2 += 4 ;
571634
572635 vbfloat16mf4_t A00 = __riscv_vle16_v_bf16mf4 ( & AA [ai + 0 * gvl ], gvl );
573636 vfloat32m1_t A0 = __riscv_vlmul_ext_v_f32mf2_f32m1 (__riscv_vfwcvtbf16_f_f_v_f32mf2 (A00 , gvl ));
@@ -977,5 +1040,8 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, IFLOAT *A, IFLOAT *B,
9771040
9781041 n_top += 1 ;
9791042 }
1043+ #ifdef BF16_WIDEN_ONE
1044+ if (B_CONV ) free (B_CONV );
1045+ #endif
9801046 return 0 ;
9811047}
0 commit comments