|
30 | 30 |
|
31 | 31 | #include "common.h" |
32 | 32 |
|
33 | | -int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, IFLOAT *A, IFLOAT *B, FLOAT *C, |
34 | | - BLASLONG ldc) { |
35 | | - // printf("m: %d, n: %d, k: %d\n", m, n, k); |
36 | | - BLASLONG padk = (k + 3) & ~3; |
37 | | - BLASLONG padm = (m + 1) & ~1; |
38 | | - BLASLONG padn = (n + 1) & ~1; |
39 | | - FLOAT *RC = (FLOAT *)calloc(padm * padn, sizeof(float)); |
40 | | - BLASLONG nldc = padm; |
41 | | - |
42 | | - IFLOAT *ptr_a = A; |
43 | | - IFLOAT *ptr_b = B; |
44 | | - FLOAT *ptr_c = RC; |
45 | | - |
46 | | - IFLOAT *ptr_a0, *ptr_a1, *ptr_a2, *ptr_a3; |
47 | | - IFLOAT *ptr_b0, *ptr_b1; |
48 | | - FLOAT *ptr_c00, *ptr_c10, *ptr_c20, *ptr_c30, *ptr_c01, *ptr_c11, *ptr_c21, *ptr_c31; |
49 | | - |
50 | | - svbfloat16_t ma0, ma1, ma2, ma3, mb0, mb1; |
51 | | - svfloat32_t mc00, mc01, mc10, mc11, mc20, mc21, mc30, mc31; |
52 | | - svbool_t pg16 = svptrue_b16(); |
53 | | - svbool_t pg32 = svptrue_b32(); |
54 | | - svfloat32_t svalpha = svdup_f32(alpha); |
55 | | - |
56 | | - uint32_t off_c[] = {0, (uint32_t)nldc, 1, (uint32_t)nldc + 1}; // 00 01 10 11 |
57 | | - svuint32_t off_vc = svld1_u32(pg32, off_c); |
58 | | - |
59 | | - for (BLASLONG j = 0; j < padn / 4; j++) { |
60 | | - ptr_c00 = ptr_c; |
61 | | - ptr_c10 = ptr_c00 + 2; |
62 | | - ptr_c20 = ptr_c10 + 2; |
63 | | - ptr_c30 = ptr_c20 + 2; |
64 | | - ptr_c01 = ptr_c + 2 * nldc; |
65 | | - ptr_c11 = ptr_c01 + 2; |
66 | | - ptr_c21 = ptr_c11 + 2; |
67 | | - ptr_c31 = ptr_c21 + 2; |
68 | | - ptr_c += 4 * nldc; |
69 | | - |
70 | | - ptr_a = A; |
71 | | - |
72 | | - for (BLASLONG i = 0; i < padm / 8; i++) { |
73 | | - ptr_a0 = ptr_a; |
74 | | - ptr_a1 = ptr_a0 + 2 * padk; |
75 | | - ptr_a2 = ptr_a1 + 2 * padk; |
76 | | - ptr_a3 = ptr_a2 + 2 * padk; |
77 | | - ptr_a += 8 * padk; |
78 | | - |
79 | | - ptr_b0 = ptr_b; |
80 | | - ptr_b1 = ptr_b0 + 2 * padk; |
81 | | - |
82 | | - mc00 = svdup_f32(0); |
83 | | - mc01 = svdup_f32(0); |
84 | | - mc10 = svdup_f32(0); |
85 | | - mc11 = svdup_f32(0); |
86 | | - mc20 = svdup_f32(0); |
87 | | - mc21 = svdup_f32(0); |
88 | | - mc30 = svdup_f32(0); |
89 | | - mc31 = svdup_f32(0); |
90 | | - |
91 | | - for (BLASLONG p = 0; p < padk / 4; p++) { |
92 | | - ma0 = svld1_bf16(pg16, (bfloat16_t *)ptr_a0); |
93 | | - ma1 = svld1_bf16(pg16, (bfloat16_t *)ptr_a1); |
94 | | - ma2 = svld1_bf16(pg16, (bfloat16_t *)ptr_a2); |
95 | | - ma3 = svld1_bf16(pg16, (bfloat16_t *)ptr_a3); |
96 | | - mb0 = svld1_bf16(pg16, (bfloat16_t *)ptr_b0); |
97 | | - mb1 = svld1_bf16(pg16, (bfloat16_t *)ptr_b1); |
98 | | - |
99 | | - mc00 = svbfmmla(mc00, ma0, mb0); |
100 | | - mc10 = svbfmmla(mc10, ma1, mb0); |
101 | | - mc20 = svbfmmla(mc20, ma2, mb0); |
102 | | - mc30 = svbfmmla(mc30, ma3, mb0); |
103 | | - mc01 = svbfmmla(mc01, ma0, mb1); |
104 | | - mc11 = svbfmmla(mc11, ma1, mb1); |
105 | | - mc21 = svbfmmla(mc21, ma2, mb1); |
106 | | - mc31 = svbfmmla(mc31, ma3, mb1); |
107 | | - |
108 | | - ptr_a0 += 8; |
109 | | - ptr_a1 += 8; |
110 | | - ptr_a2 += 8; |
111 | | - ptr_a3 += 8; |
112 | | - ptr_b0 += 8; |
113 | | - ptr_b1 += 8; |
114 | | - } |
115 | | - svst1_scatter_index(pg32, ptr_c00, off_vc, mc00); |
116 | | - svst1_scatter_index(pg32, ptr_c10, off_vc, mc10); |
117 | | - svst1_scatter_index(pg32, ptr_c20, off_vc, mc20); |
118 | | - svst1_scatter_index(pg32, ptr_c30, off_vc, mc30); |
119 | | - svst1_scatter_index(pg32, ptr_c01, off_vc, mc01); |
120 | | - svst1_scatter_index(pg32, ptr_c11, off_vc, mc11); |
121 | | - svst1_scatter_index(pg32, ptr_c21, off_vc, mc21); |
122 | | - svst1_scatter_index(pg32, ptr_c31, off_vc, mc31); |
123 | | - |
124 | | - ptr_c00 += 8; |
125 | | - ptr_c10 += 8; |
126 | | - ptr_c20 += 8; |
127 | | - ptr_c30 += 8; |
128 | | - ptr_c01 += 8; |
129 | | - ptr_c11 += 8; |
130 | | - ptr_c21 += 8; |
131 | | - ptr_c31 += 8; |
132 | | - } |
133 | | - |
134 | | - if (padm & 4) { |
135 | | - // rest 4 or 6 |
136 | | - ptr_a0 = ptr_a; |
137 | | - ptr_a1 = ptr_a0 + 2 * padk; |
138 | | - ptr_a += 4 * padk; |
139 | | - |
140 | | - ptr_b0 = ptr_b; |
141 | | - ptr_b1 = ptr_b0 + 2 * padk; |
142 | | - |
143 | | - mc00 = svdup_f32(0); |
144 | | - mc01 = svdup_f32(0); |
145 | | - mc10 = svdup_f32(0); |
146 | | - mc11 = svdup_f32(0); |
147 | | - for (BLASLONG p = 0; p < padk / 4; p++) { |
148 | | - ma0 = svld1_bf16(pg16, (bfloat16_t *)ptr_a0); |
149 | | - ma1 = svld1_bf16(pg16, (bfloat16_t *)ptr_a1); |
150 | | - mb0 = svld1_bf16(pg16, (bfloat16_t *)ptr_b0); |
151 | | - mb1 = svld1_bf16(pg16, (bfloat16_t *)ptr_b1); |
152 | | - |
153 | | - mc00 = svbfmmla(mc00, ma0, mb0); |
154 | | - mc10 = svbfmmla(mc10, ma1, mb0); |
155 | | - mc01 = svbfmmla(mc01, ma0, mb1); |
156 | | - mc11 = svbfmmla(mc11, ma1, mb1); |
157 | | - |
158 | | - ptr_a0 += 8; |
159 | | - ptr_a1 += 8; |
160 | | - ptr_b0 += 8; |
161 | | - ptr_b1 += 8; |
162 | | - } |
163 | | - svst1_scatter_index(pg32, ptr_c00, off_vc, mc00); |
164 | | - svst1_scatter_index(pg32, ptr_c10, off_vc, mc10); |
165 | | - svst1_scatter_index(pg32, ptr_c01, off_vc, mc01); |
166 | | - svst1_scatter_index(pg32, ptr_c11, off_vc, mc11); |
167 | | - |
168 | | - ptr_c00 += 4; |
169 | | - ptr_c10 += 4; |
170 | | - ptr_c01 += 4; |
171 | | - ptr_c11 += 4; |
172 | | - } |
173 | | - |
174 | | - if (padm & 2) { |
175 | | - // rest 2 |
176 | | - ptr_a0 = ptr_a; |
177 | | - |
178 | | - ptr_b0 = ptr_b; |
179 | | - ptr_b1 = ptr_b0 + 2 * padk; |
180 | | - |
181 | | - mc00 = svdup_f32(0); |
182 | | - mc01 = svdup_f32(0); |
183 | | - for (BLASLONG p = 0; p < padk / 4; p++) { |
184 | | - ma0 = svld1_bf16(pg16, (bfloat16_t *)ptr_a0); |
185 | | - mb0 = svld1_bf16(pg16, (bfloat16_t *)ptr_b0); |
186 | | - mb1 = svld1_bf16(pg16, (bfloat16_t *)ptr_b1); |
187 | | - mc00 = svbfmmla(mc00, ma0, mb0); |
188 | | - mc01 = svbfmmla(mc01, ma0, mb1); |
189 | | - ptr_a0 += 8; |
190 | | - ptr_b0 += 8; |
191 | | - ptr_b1 += 8; |
192 | | - } |
193 | | - svst1_scatter_index(pg32, ptr_c00, off_vc, mc00); |
194 | | - svst1_scatter_index(pg32, ptr_c01, off_vc, mc01); |
195 | | - ptr_c00 += 2; |
196 | | - ptr_c01 += 2; |
197 | | - } |
198 | | - |
199 | | - ptr_b += 4 * padk; |
200 | | - } |
201 | | - |
202 | | - if (padn & 2) { |
203 | | - // rest 2 |
204 | | - ptr_c00 = ptr_c; |
205 | | - ptr_c10 = ptr_c00 + 2; |
206 | | - ptr_c20 = ptr_c10 + 2; |
207 | | - ptr_c30 = ptr_c20 + 2; |
208 | | - ptr_c += 2 * nldc; |
209 | | - |
210 | | - ptr_a = A; |
211 | | - |
212 | | - for (BLASLONG i = 0; i < padm / 8; i++) { |
213 | | - ptr_a0 = ptr_a; |
214 | | - ptr_a1 = ptr_a0 + 2 * padk; |
215 | | - ptr_a2 = ptr_a1 + 2 * padk; |
216 | | - ptr_a3 = ptr_a2 + 2 * padk; |
217 | | - ptr_a += 8 * padk; |
218 | | - |
219 | | - ptr_b0 = ptr_b; |
220 | | - |
221 | | - mc00 = svdup_f32(0); |
222 | | - mc10 = svdup_f32(0); |
223 | | - mc20 = svdup_f32(0); |
224 | | - mc30 = svdup_f32(0); |
225 | | - |
226 | | - for (BLASLONG p = 0; p < padk / 4; p++) { |
227 | | - ma0 = svld1_bf16(pg16, (bfloat16_t *)ptr_a0); |
228 | | - ma1 = svld1_bf16(pg16, (bfloat16_t *)ptr_a1); |
229 | | - ma2 = svld1_bf16(pg16, (bfloat16_t *)ptr_a2); |
230 | | - ma3 = svld1_bf16(pg16, (bfloat16_t *)ptr_a3); |
231 | | - mb0 = svld1_bf16(pg16, (bfloat16_t *)ptr_b0); |
232 | | - mc00 = svbfmmla(mc00, ma0, mb0); |
233 | | - mc10 = svbfmmla(mc10, ma1, mb0); |
234 | | - mc20 = svbfmmla(mc20, ma2, mb0); |
235 | | - mc30 = svbfmmla(mc30, ma3, mb0); |
236 | | - ptr_a0 += 8; |
237 | | - ptr_a1 += 8; |
238 | | - ptr_a2 += 8; |
239 | | - ptr_a3 += 8; |
240 | | - ptr_b0 += 8; |
241 | | - } |
242 | | - svst1_scatter_index(pg32, ptr_c00, off_vc, mc00); |
243 | | - svst1_scatter_index(pg32, ptr_c10, off_vc, mc10); |
244 | | - svst1_scatter_index(pg32, ptr_c20, off_vc, mc20); |
245 | | - svst1_scatter_index(pg32, ptr_c30, off_vc, mc30); |
246 | | - ptr_c00 += 8; |
247 | | - ptr_c10 += 8; |
248 | | - ptr_c20 += 8; |
249 | | - ptr_c30 += 8; |
250 | | - } |
251 | | - |
252 | | - if (padm & 4) { |
253 | | - ptr_a0 = ptr_a; |
254 | | - ptr_a1 = ptr_a0 + 2 * padk; |
255 | | - ptr_a += 4 * padk; |
256 | | - |
257 | | - ptr_b0 = ptr_b; |
258 | | - |
259 | | - mc00 = svdup_f32(0); |
260 | | - mc10 = svdup_f32(0); |
261 | | - for (BLASLONG p = 0; p < padk / 4; p++) { |
262 | | - ma0 = svld1_bf16(pg16, (bfloat16_t *)ptr_a0); |
263 | | - ma1 = svld1_bf16(pg16, (bfloat16_t *)ptr_a1); |
264 | | - mb0 = svld1_bf16(pg16, (bfloat16_t *)ptr_b0); |
265 | | - mc00 = svbfmmla(mc00, ma0, mb0); |
266 | | - mc10 = svbfmmla(mc10, ma1, mb0); |
267 | | - ptr_a0 += 8; |
268 | | - ptr_a1 += 8; |
269 | | - ptr_b0 += 8; |
270 | | - } |
271 | | - svst1_scatter_index(pg32, ptr_c00, off_vc, mc00); |
272 | | - svst1_scatter_index(pg32, ptr_c10, off_vc, mc10); |
273 | | - ptr_c00 += 4; |
274 | | - ptr_c10 += 4; |
275 | | - } |
276 | | - |
277 | | - if (padm & 2) { |
278 | | - ptr_a0 = ptr_a; |
279 | | - ptr_a += 2 * padk; |
280 | | - ptr_b0 = ptr_b; |
281 | | - mc00 = svdup_f32(0); |
282 | | - for (BLASLONG p = 0; p < padk / 4; p++) { |
283 | | - ma0 = svld1_bf16(pg16, (bfloat16_t *)ptr_a0); |
284 | | - mb0 = svld1_bf16(pg16, (bfloat16_t *)ptr_b0); |
285 | | - mc00 = svbfmmla(mc00, ma0, mb0); |
286 | | - ptr_a0 += 8; |
287 | | - ptr_b0 += 8; |
288 | | - } |
289 | | - svst1_scatter_index(pg32, ptr_c00, off_vc, mc00); |
290 | | - ptr_c00 += 2; |
291 | | - } |
292 | | - |
293 | | - ptr_b += 2 * padk; |
294 | | - } |
295 | | - |
296 | | - FLOAT *org_c = C; |
297 | | - FLOAT *raw_c = RC; |
298 | | - FLOAT *org_c0, *raw_c0; |
299 | | - svfloat32_t org_vc0, raw_vc0; |
300 | | - for (BLASLONG j = 0; j < n; j++) { |
301 | | - org_c0 = org_c; |
302 | | - raw_c0 = raw_c; |
303 | | - org_c += ldc; |
304 | | - raw_c += nldc; |
305 | | - BLASLONG i; |
306 | | - for (i = 0; i < m / 4; i++) { |
307 | | - org_vc0 = svld1_f32(pg32, org_c0); |
308 | | - raw_vc0 = svld1_f32(pg32, raw_c0); |
309 | | - org_vc0 = svmad_z(pg32, svalpha, raw_vc0, |
310 | | - org_vc0); // alpha * raw + org, raw -> a * b |
311 | | - svst1_f32(pg32, org_c0, org_vc0); |
312 | | - org_c0 += 4; |
313 | | - raw_c0 += 4; |
314 | | - } |
315 | | - for (i = 0; i < (m & 3); i++) { |
316 | | - *org_c0 += alpha * (*raw_c0); |
317 | | - org_c0++; |
318 | | - raw_c0++; |
319 | | - } |
320 | | - } |
321 | | - |
| 33 | +#define ALPHA_ONE |
| 34 | +#include "sbgemm_kernel_8x4_neoversen2_impl.c" |
| 35 | +#undef ALPHA_ONE |
| 36 | +#include "sbgemm_kernel_8x4_neoversen2_impl.c" |
| 37 | + |
| 38 | +int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, IFLOAT *A, IFLOAT *B, |
| 39 | + FLOAT *C, BLASLONG ldc) { |
| 40 | + if (alpha == 1.0f) |
| 41 | + return sbgemm_kernel_neoversen2_alpha_one(m, n, k, alpha, A, B, C, ldc); |
| 42 | + else |
| 43 | + return sbgemm_kernel_neoversen2_alpha(m, n, k, alpha, A, B, C, ldc); |
322 | 44 | return 0; |
323 | 45 | } |
0 commit comments