2727 * *****************************************************************************/
2828
2929#include <arm_sve.h>
30+
3031#include "common.h"
3132
32- int CNAME (BLASLONG m , BLASLONG n , BLASLONG k , FLOAT alpha , IFLOAT * A , IFLOAT * B ,
33- FLOAT * C , BLASLONG ldc ) {
33+ int CNAME (BLASLONG m , BLASLONG n , BLASLONG k , FLOAT alpha , IFLOAT * A , IFLOAT * B , FLOAT * C ,
34+ BLASLONG ldc ) {
3435 // printf("m: %d, n: %d, k: %d\n", m, n, k);
3536 BLASLONG padk = (k + 3 ) & ~3 ;
3637 BLASLONG padm = (m + 1 ) & ~1 ;
3738 BLASLONG padn = (n + 1 ) & ~1 ;
38- FLOAT * RC = (FLOAT * ) calloc (padm * padn , sizeof (float ));
39+ FLOAT * RC = (FLOAT * )calloc (padm * padn , sizeof (float ));
3940 BLASLONG nldc = padm ;
4041
4142 IFLOAT * ptr_a = A ;
@@ -52,10 +53,10 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, IFLOAT *A, IFLOAT *B,
5253 svbool_t pg32 = svptrue_b32 ();
5354 svfloat32_t svalpha = svdup_f32 (alpha );
5455
55- uint32_t off_c [] = {0 , (uint32_t ) nldc , 1 , (uint32_t ) nldc + 1 }; // 00 01 10 11
56+ uint32_t off_c [] = {0 , (uint32_t )nldc , 1 , (uint32_t )nldc + 1 }; // 00 01 10 11
5657 svuint32_t off_vc = svld1_u32 (pg32 , off_c );
5758
58- for (BLASLONG j = 0 ; j < padn / 4 ; j ++ ) {
59+ for (BLASLONG j = 0 ; j < padn / 4 ; j ++ ) {
5960 ptr_c00 = ptr_c ;
6061 ptr_c10 = ptr_c00 + 2 ;
6162 ptr_c20 = ptr_c10 + 2 ;
@@ -68,7 +69,7 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, IFLOAT *A, IFLOAT *B,
6869
6970 ptr_a = A ;
7071
71- for (BLASLONG i = 0 ; i < padm / 8 ; i ++ ) {
72+ for (BLASLONG i = 0 ; i < padm / 8 ; i ++ ) {
7273 ptr_a0 = ptr_a ;
7374 ptr_a1 = ptr_a0 + 2 * padk ;
7475 ptr_a2 = ptr_a1 + 2 * padk ;
@@ -78,18 +79,22 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, IFLOAT *A, IFLOAT *B,
7879 ptr_b0 = ptr_b ;
7980 ptr_b1 = ptr_b0 + 2 * padk ;
8081
81- mc00 = svdup_f32 (0 ); mc01 = svdup_f32 (0 );
82- mc10 = svdup_f32 (0 ); mc11 = svdup_f32 (0 );
83- mc20 = svdup_f32 (0 ); mc21 = svdup_f32 (0 );
84- mc30 = svdup_f32 (0 ); mc31 = svdup_f32 (0 );
85-
86- for (BLASLONG p = 0 ; p < padk /4 ; p ++ ) {
87- ma0 = svld1_bf16 (pg16 , (bfloat16_t * ) ptr_a0 );
88- ma1 = svld1_bf16 (pg16 , (bfloat16_t * ) ptr_a1 );
89- ma2 = svld1_bf16 (pg16 , (bfloat16_t * ) ptr_a2 );
90- ma3 = svld1_bf16 (pg16 , (bfloat16_t * ) ptr_a3 );
91- mb0 = svld1_bf16 (pg16 , (bfloat16_t * ) ptr_b0 );
92- mb1 = svld1_bf16 (pg16 , (bfloat16_t * ) ptr_b1 );
82+ mc00 = svdup_f32 (0 );
83+ mc01 = svdup_f32 (0 );
84+ mc10 = svdup_f32 (0 );
85+ mc11 = svdup_f32 (0 );
86+ mc20 = svdup_f32 (0 );
87+ mc21 = svdup_f32 (0 );
88+ mc30 = svdup_f32 (0 );
89+ mc31 = svdup_f32 (0 );
90+
91+ for (BLASLONG p = 0 ; p < padk / 4 ; p ++ ) {
92+ ma0 = svld1_bf16 (pg16 , (bfloat16_t * )ptr_a0 );
93+ ma1 = svld1_bf16 (pg16 , (bfloat16_t * )ptr_a1 );
94+ ma2 = svld1_bf16 (pg16 , (bfloat16_t * )ptr_a2 );
95+ ma3 = svld1_bf16 (pg16 , (bfloat16_t * )ptr_a3 );
96+ mb0 = svld1_bf16 (pg16 , (bfloat16_t * )ptr_b0 );
97+ mb1 = svld1_bf16 (pg16 , (bfloat16_t * )ptr_b1 );
9398
9499 mc00 = svbfmmla (mc00 , ma0 , mb0 );
95100 mc10 = svbfmmla (mc10 , ma1 , mb0 );
@@ -135,13 +140,15 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, IFLOAT *A, IFLOAT *B,
135140 ptr_b0 = ptr_b ;
136141 ptr_b1 = ptr_b0 + 2 * padk ;
137142
138- mc00 = svdup_f32 (0 ); mc01 = svdup_f32 (0 );
139- mc10 = svdup_f32 (0 ); mc11 = svdup_f32 (0 );
140- for (BLASLONG p = 0 ; p < padk /4 ; p ++ ) {
141- ma0 = svld1_bf16 (pg16 , (bfloat16_t * ) ptr_a0 );
142- ma1 = svld1_bf16 (pg16 , (bfloat16_t * ) ptr_a1 );
143- mb0 = svld1_bf16 (pg16 , (bfloat16_t * ) ptr_b0 );
144- mb1 = svld1_bf16 (pg16 , (bfloat16_t * ) ptr_b1 );
143+ mc00 = svdup_f32 (0 );
144+ mc01 = svdup_f32 (0 );
145+ mc10 = svdup_f32 (0 );
146+ mc11 = svdup_f32 (0 );
147+ for (BLASLONG p = 0 ; p < padk / 4 ; p ++ ) {
148+ ma0 = svld1_bf16 (pg16 , (bfloat16_t * )ptr_a0 );
149+ ma1 = svld1_bf16 (pg16 , (bfloat16_t * )ptr_a1 );
150+ mb0 = svld1_bf16 (pg16 , (bfloat16_t * )ptr_b0 );
151+ mb1 = svld1_bf16 (pg16 , (bfloat16_t * )ptr_b1 );
145152
146153 mc00 = svbfmmla (mc00 , ma0 , mb0 );
147154 mc10 = svbfmmla (mc10 , ma1 , mb0 );
@@ -171,11 +178,12 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, IFLOAT *A, IFLOAT *B,
171178 ptr_b0 = ptr_b ;
172179 ptr_b1 = ptr_b0 + 2 * padk ;
173180
174- mc00 = svdup_f32 (0 ); mc01 = svdup_f32 (0 );
175- for (BLASLONG p = 0 ; p < padk /4 ; p ++ ) {
176- ma0 = svld1_bf16 (pg16 , (bfloat16_t * ) ptr_a0 );
177- mb0 = svld1_bf16 (pg16 , (bfloat16_t * ) ptr_b0 );
178- mb1 = svld1_bf16 (pg16 , (bfloat16_t * ) ptr_b1 );
181+ mc00 = svdup_f32 (0 );
182+ mc01 = svdup_f32 (0 );
183+ for (BLASLONG p = 0 ; p < padk / 4 ; p ++ ) {
184+ ma0 = svld1_bf16 (pg16 , (bfloat16_t * )ptr_a0 );
185+ mb0 = svld1_bf16 (pg16 , (bfloat16_t * )ptr_b0 );
186+ mb1 = svld1_bf16 (pg16 , (bfloat16_t * )ptr_b1 );
179187 mc00 = svbfmmla (mc00 , ma0 , mb0 );
180188 mc01 = svbfmmla (mc01 , ma0 , mb1 );
181189 ptr_a0 += 8 ;
@@ -189,7 +197,6 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, IFLOAT *A, IFLOAT *B,
189197 }
190198
191199 ptr_b += 4 * padk ;
192-
193200 }
194201
195202 if (padn & 2 ) {
@@ -202,7 +209,7 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, IFLOAT *A, IFLOAT *B,
202209
203210 ptr_a = A ;
204211
205- for (BLASLONG i = 0 ; i < padm / 8 ; i ++ ) {
212+ for (BLASLONG i = 0 ; i < padm / 8 ; i ++ ) {
206213 ptr_a0 = ptr_a ;
207214 ptr_a1 = ptr_a0 + 2 * padk ;
208215 ptr_a2 = ptr_a1 + 2 * padk ;
@@ -216,12 +223,12 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, IFLOAT *A, IFLOAT *B,
216223 mc20 = svdup_f32 (0 );
217224 mc30 = svdup_f32 (0 );
218225
219- for (BLASLONG p = 0 ; p < padk / 4 ; p ++ ) {
220- ma0 = svld1_bf16 (pg16 , (bfloat16_t * ) ptr_a0 );
221- ma1 = svld1_bf16 (pg16 , (bfloat16_t * ) ptr_a1 );
222- ma2 = svld1_bf16 (pg16 , (bfloat16_t * ) ptr_a2 );
223- ma3 = svld1_bf16 (pg16 , (bfloat16_t * ) ptr_a3 );
224- mb0 = svld1_bf16 (pg16 , (bfloat16_t * ) ptr_b0 );
226+ for (BLASLONG p = 0 ; p < padk / 4 ; p ++ ) {
227+ ma0 = svld1_bf16 (pg16 , (bfloat16_t * )ptr_a0 );
228+ ma1 = svld1_bf16 (pg16 , (bfloat16_t * )ptr_a1 );
229+ ma2 = svld1_bf16 (pg16 , (bfloat16_t * )ptr_a2 );
230+ ma3 = svld1_bf16 (pg16 , (bfloat16_t * )ptr_a3 );
231+ mb0 = svld1_bf16 (pg16 , (bfloat16_t * )ptr_b0 );
225232 mc00 = svbfmmla (mc00 , ma0 , mb0 );
226233 mc10 = svbfmmla (mc10 , ma1 , mb0 );
227234 mc20 = svbfmmla (mc20 , ma2 , mb0 );
@@ -251,10 +258,10 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, IFLOAT *A, IFLOAT *B,
251258
252259 mc00 = svdup_f32 (0 );
253260 mc10 = svdup_f32 (0 );
254- for (BLASLONG p = 0 ; p < padk / 4 ; p ++ ) {
255- ma0 = svld1_bf16 (pg16 , (bfloat16_t * ) ptr_a0 );
256- ma1 = svld1_bf16 (pg16 , (bfloat16_t * ) ptr_a1 );
257- mb0 = svld1_bf16 (pg16 , (bfloat16_t * ) ptr_b0 );
261+ for (BLASLONG p = 0 ; p < padk / 4 ; p ++ ) {
262+ ma0 = svld1_bf16 (pg16 , (bfloat16_t * )ptr_a0 );
263+ ma1 = svld1_bf16 (pg16 , (bfloat16_t * )ptr_a1 );
264+ mb0 = svld1_bf16 (pg16 , (bfloat16_t * )ptr_b0 );
258265 mc00 = svbfmmla (mc00 , ma0 , mb0 );
259266 mc10 = svbfmmla (mc10 , ma1 , mb0 );
260267 ptr_a0 += 8 ;
@@ -272,9 +279,9 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, IFLOAT *A, IFLOAT *B,
272279 ptr_a += 2 * padk ;
273280 ptr_b0 = ptr_b ;
274281 mc00 = svdup_f32 (0 );
275- for (BLASLONG p = 0 ; p < padk / 4 ; p ++ ) {
276- ma0 = svld1_bf16 (pg16 , (bfloat16_t * ) ptr_a0 );
277- mb0 = svld1_bf16 (pg16 , (bfloat16_t * ) ptr_b0 );
282+ for (BLASLONG p = 0 ; p < padk / 4 ; p ++ ) {
283+ ma0 = svld1_bf16 (pg16 , (bfloat16_t * )ptr_a0 );
284+ mb0 = svld1_bf16 (pg16 , (bfloat16_t * )ptr_b0 );
278285 mc00 = svbfmmla (mc00 , ma0 , mb0 );
279286 ptr_a0 += 8 ;
280287 ptr_b0 += 8 ;
@@ -296,10 +303,11 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, IFLOAT *A, IFLOAT *B,
296303 org_c += ldc ;
297304 raw_c += nldc ;
298305 BLASLONG i ;
299- for (i = 0 ; i < m / 4 ; i ++ ) {
306+ for (i = 0 ; i < m / 4 ; i ++ ) {
300307 org_vc0 = svld1_f32 (pg32 , org_c0 );
301308 raw_vc0 = svld1_f32 (pg32 , raw_c0 );
302- org_vc0 = svmad_z (pg32 , svalpha , raw_vc0 , org_vc0 ); // alpha * raw + org, raw -> a * b
309+ org_vc0 = svmad_z (pg32 , svalpha , raw_vc0 ,
310+ org_vc0 ); // alpha * raw + org, raw -> a * b
303311 svst1_f32 (pg32 , org_c0 , org_vc0 );
304312 org_c0 += 4 ;
305313 raw_c0 += 4 ;
@@ -310,5 +318,6 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, IFLOAT *A, IFLOAT *B,
310318 raw_c0 ++ ;
311319 }
312320 }
321+
313322 return 0 ;
314323}
0 commit comments