Skip to content

Commit c1c10cb

Browse files
authored
Merge pull request #2384 from wjc404/develop
Optimize AVX512 DGEMM (& DTRMM)
2 parents 8d2a796 + 3447d04 commit c1c10cb

8 files changed

Lines changed: 565 additions & 23 deletions

File tree

driver/level3/level3.c

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -332,13 +332,16 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n,
332332
#else
333333
for(jjs = js; jjs < js + min_j; jjs += min_jj){
334334
min_jj = min_j + js - jjs;
335-
335+
#ifdef SKYLAKEX
336+
/* the current AVX512 s/d/c/z GEMM kernel requires n>=6*GEMM_UNROLL_N to achieve best performance */
337+
if (min_jj >= 6*GEMM_UNROLL_N) min_jj = 6*GEMM_UNROLL_N;
338+
#else
336339
if (min_jj >= 3*GEMM_UNROLL_N) min_jj = 3*GEMM_UNROLL_N;
337340
else
338341
if (min_jj >= 2*GEMM_UNROLL_N) min_jj = 2*GEMM_UNROLL_N;
339342
else
340343
if (min_jj > GEMM_UNROLL_N) min_jj = GEMM_UNROLL_N;
341-
344+
#endif
342345

343346

344347
START_RPCC();

driver/level3/level3_thread.c

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -365,12 +365,16 @@ static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n,
365365
/* Split local region of B into parts */
366366
for(jjs = js; jjs < MIN(n_to, js + div_n); jjs += min_jj){
367367
min_jj = MIN(n_to, js + div_n) - jjs;
368+
#ifdef SKYLAKEX
369+
/* the current AVX512 s/d/c/z GEMM kernel requires n>=6*GEMM_UNROLL_N to achieve the best performance */
370+
if (min_jj >= 6*GEMM_UNROLL_N) min_jj = 6*GEMM_UNROLL_N;
371+
#else
368372
if (min_jj >= 3*GEMM_UNROLL_N) min_jj = 3*GEMM_UNROLL_N;
369373
else
370374
if (min_jj >= 2*GEMM_UNROLL_N) min_jj = 2*GEMM_UNROLL_N;
371375
else
372376
if (min_jj > GEMM_UNROLL_N) min_jj = GEMM_UNROLL_N;
373-
377+
#endif
374378
/* Copy part of local region of B into workspace */
375379
START_RPCC();
376380
OCOPY_OPERATION(min_l, min_jj, b, ldb, ls, jjs,

driver/level3/trmm_L.c

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -135,10 +135,14 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa, FLO
135135

136136
for(jjs = js; jjs < js + min_j; jjs += min_jj){
137137
min_jj = min_j + js - jjs;
138+
#ifdef SKYLAKEX
139+
/* the current AVX512 s/d/c/z GEMM kernel requires n>=6*GEMM_UNROLL_N to achieve the best performance */
140+
if (min_jj >= 6*GEMM_UNROLL_N) min_jj = 6*GEMM_UNROLL_N;
141+
#else
138142
if (min_jj > GEMM_UNROLL_N*3) min_jj = GEMM_UNROLL_N*3;
139143
else
140144
if (min_jj > GEMM_UNROLL_N) min_jj = GEMM_UNROLL_N;
141-
145+
#endif
142146
START_RPCC();
143147

144148
GEMM_ONCOPY(min_l, min_jj, b + (jjs * ldb) * COMPSIZE, ldb, sb + min_l * (jjs - js) * COMPSIZE);
@@ -201,10 +205,14 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa, FLO
201205

202206
for(jjs = js; jjs < js + min_j; jjs += min_jj){
203207
min_jj = min_j + js - jjs;
208+
#ifdef SKYLAKEX
209+
/* the current AVX512 s/d/c/z GEMM kernel requires n>=6*GEMM_UNROLL_N to achieve the best performance */
210+
if (min_jj >= 6*GEMM_UNROLL_N) min_jj = 6*GEMM_UNROLL_N;
211+
#else
204212
if (min_jj > GEMM_UNROLL_N*3) min_jj = GEMM_UNROLL_N*3;
205213
else
206214
if (min_jj > GEMM_UNROLL_N) min_jj = GEMM_UNROLL_N;
207-
215+
#endif
208216
START_RPCC();
209217

210218
GEMM_ONCOPY(min_l, min_jj, b + (ls + jjs * ldb) * COMPSIZE, ldb, sb + min_l * (jjs - js) * COMPSIZE);
@@ -292,10 +300,14 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa, FLO
292300

293301
for(jjs = js; jjs < js + min_j; jjs += min_jj){
294302
min_jj = min_j + js - jjs;
303+
#ifdef SKYLAKEX
304+
/* the current AVX512 s/d/c/z GEMM kernel requires n>=6*GEMM_UNROLL_N to achieve the best performance */
305+
if (min_jj >= 6*GEMM_UNROLL_N) min_jj = 6*GEMM_UNROLL_N;
306+
#else
295307
if (min_jj > GEMM_UNROLL_N*3) min_jj = GEMM_UNROLL_N*3;
296308
else
297309
if (min_jj > GEMM_UNROLL_N) min_jj = GEMM_UNROLL_N;
298-
310+
#endif
299311
START_RPCC();
300312

301313
GEMM_ONCOPY(min_l, min_jj, b + (m - min_l + jjs * ldb) * COMPSIZE, ldb,
@@ -358,10 +370,14 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa, FLO
358370

359371
for(jjs = js; jjs < js + min_j; jjs += min_jj){
360372
min_jj = min_j + js - jjs;
373+
#ifdef SKYLAKEX
374+
/* the current AVX512 s/d/c/z GEMM kernel requires n>=6*GEMM_UNROLL_N to achieve the best performance */
375+
if (min_jj >= 6*GEMM_UNROLL_N) min_jj = 6*GEMM_UNROLL_N;
376+
#else
361377
if (min_jj > GEMM_UNROLL_N*3) min_jj = GEMM_UNROLL_N*3;
362378
else
363379
if (min_jj > GEMM_UNROLL_N) min_jj = GEMM_UNROLL_N;
364-
380+
#endif
365381
START_RPCC();
366382

367383
GEMM_ONCOPY(min_l, min_jj, b + (ls - min_l + jjs * ldb) * COMPSIZE, ldb,

driver/level3/trmm_R.c

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -122,10 +122,14 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa, FLO
122122

123123
for(jjs = 0; jjs < ls - js; jjs += min_jj){
124124
min_jj = ls - js - jjs;
125+
#ifdef SKYLAKEX
126+
/* the current AVX512 s/d/c/z GEMM kernel requires n>=6*GEMM_UNROLL_N to achieve the best performance */
127+
if (min_jj >= 6*GEMM_UNROLL_N) min_jj = 6*GEMM_UNROLL_N;
128+
#else
125129
if (min_jj > GEMM_UNROLL_N*3) min_jj = GEMM_UNROLL_N*3;
126130
else
127131
if (min_jj > GEMM_UNROLL_N) min_jj = GEMM_UNROLL_N;
128-
132+
#endif
129133
#ifndef TRANSA
130134
GEMM_ONCOPY(min_l, min_jj, a + (ls + (js + jjs) * lda) * COMPSIZE, lda, sb + min_l * jjs * COMPSIZE);
131135
#else
@@ -142,10 +146,14 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa, FLO
142146

143147
for(jjs = 0; jjs < min_l; jjs += min_jj){
144148
min_jj = min_l - jjs;
149+
#ifdef SKYLAKEX
150+
/* the current AVX512 s/d/c/z GEMM kernel requires n>=6*GEMM_UNROLL_N to achieve the best performance */
151+
if (min_jj >= 6*GEMM_UNROLL_N) min_jj = 6*GEMM_UNROLL_N;
152+
#else
145153
if (min_jj > GEMM_UNROLL_N*3) min_jj = GEMM_UNROLL_N*3;
146154
else
147155
if (min_jj > GEMM_UNROLL_N) min_jj = GEMM_UNROLL_N;
148-
156+
#endif
149157
#ifndef TRANSA
150158
TRMM_OLNCOPY(min_l, min_jj, a, lda, ls, ls + jjs, sb + min_l * (ls - js + jjs) * COMPSIZE);
151159
#else
@@ -195,10 +203,14 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa, FLO
195203

196204
for(jjs = js; jjs < js + min_j; jjs += min_jj){
197205
min_jj = min_j + js - jjs;
206+
#ifdef SKYLAKEX
207+
/* the current AVX512 s/d/c/z GEMM kernel requires n>=6*GEMM_UNROLL_N to achieve the best performance */
208+
if (min_jj >= 6*GEMM_UNROLL_N) min_jj = 6*GEMM_UNROLL_N;
209+
#else
198210
if (min_jj > GEMM_UNROLL_N*3) min_jj = GEMM_UNROLL_N*3;
199211
else
200212
if (min_jj > GEMM_UNROLL_N) min_jj = GEMM_UNROLL_N;
201-
213+
#endif
202214
#ifndef TRANSA
203215
GEMM_ONCOPY(min_l, min_jj, a + (ls + jjs * lda) * COMPSIZE, lda, sb + min_l * (jjs - js) * COMPSIZE);
204216
#else
@@ -246,10 +258,14 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa, FLO
246258

247259
for(jjs = 0; jjs < min_l; jjs += min_jj){
248260
min_jj = min_l - jjs;
261+
#ifdef SKYLAKEX
262+
/* the current AVX512 s/d/c/z GEMM kernel requires n>=6*GEMM_UNROLL_N to achieve the best performance */
263+
if (min_jj >= 6*GEMM_UNROLL_N) min_jj = 6*GEMM_UNROLL_N;
264+
#else
249265
if (min_jj > GEMM_UNROLL_N*3) min_jj = GEMM_UNROLL_N*3;
250266
else
251267
if (min_jj > GEMM_UNROLL_N) min_jj = GEMM_UNROLL_N;
252-
268+
#endif
253269
#ifndef TRANSA
254270
TRMM_OUNCOPY(min_l, min_jj, a, lda, ls, ls + jjs, sb + min_l * jjs * COMPSIZE);
255271
#else
@@ -267,10 +283,14 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa, FLO
267283

268284
for(jjs = 0; jjs < js - ls - min_l; jjs += min_jj){
269285
min_jj = js - ls - min_l - jjs;
286+
#ifdef SKYLAKEX
287+
/* the current AVX512 s/d/c/z GEMM kernel requires n>=6*GEMM_UNROLL_N to achieve the best performance */
288+
if (min_jj >= 6*GEMM_UNROLL_N) min_jj = 6*GEMM_UNROLL_N;
289+
#else
270290
if (min_jj > GEMM_UNROLL_N*3) min_jj = GEMM_UNROLL_N*3;
271291
else
272292
if (min_jj > GEMM_UNROLL_N) min_jj = GEMM_UNROLL_N;
273-
293+
#endif
274294
#ifndef TRANSA
275295
GEMM_ONCOPY(min_l, min_jj, a + (ls + (ls + min_l + jjs) * lda) * COMPSIZE, lda,
276296
sb + min_l * (min_l + jjs) * COMPSIZE);
@@ -324,10 +344,14 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa, FLO
324344

325345
for(jjs = js; jjs < js + min_j; jjs += min_jj){
326346
min_jj = min_j + js - jjs;
347+
#ifdef SKYLAKEX
348+
/* the current AVX512 s/d/c/z GEMM kernel requires n>=6*GEMM_UNROLL_N to achieve the best performance */
349+
if (min_jj >= 6*GEMM_UNROLL_N) min_jj = 6*GEMM_UNROLL_N;
350+
#else
327351
if (min_jj > GEMM_UNROLL_N*3) min_jj = GEMM_UNROLL_N*3;
328352
else
329353
if (min_jj > GEMM_UNROLL_N) min_jj = GEMM_UNROLL_N;
330-
354+
#endif
331355
#ifndef TRANSA
332356
GEMM_ONCOPY(min_l, min_jj, a + (ls + (jjs - min_j) * lda) * COMPSIZE, lda, sb + min_l * (jjs - js) * COMPSIZE);
333357
#else

kernel/x86_64/KERNEL.SKYLAKEX

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,13 @@ SGEMMITCOPY = sgemm_tcopy_16_skylakex.c
77
SGEMMONCOPY = sgemm_ncopy_4_skylakex.c
88
SGEMMOTCOPY = ../generic/gemm_tcopy_4.c
99

10-
DGEMMKERNEL = dgemm_kernel_4x8_skylakex_2.c
11-
12-
DGEMMONCOPY = dgemm_ncopy_8_skylakex.c
13-
DGEMMOTCOPY = dgemm_tcopy_8_skylakex.c
10+
DGEMMKERNEL = dgemm_kernel_16x2_skylakex.c
11+
DTRMMKERNEL = dgemm_kernel_16x2_skylakex.c
12+
DGEMMINCOPY = ../generic/gemm_ncopy_16.c
13+
DGEMMITCOPY = ../generic/gemm_tcopy_16.c
14+
DGEMMONCOPY = ../generic/gemm_ncopy_2.c
15+
DGEMMOTCOPY = ../generic/gemm_tcopy_2.c
16+
DTRSMKERNEL_RN = ../generic/trsm_kernel_RN.c
1417

1518
SGEMM_BETA = sgemm_beta_skylakex.c
1619
DGEMM_BETA = dgemm_beta_skylakex.c

0 commit comments

Comments
 (0)