Skip to content

Commit fa049d4

Browse files
authored
AVX2 STRSM kernel
1 parent dd22eb7 commit fa049d4

7 files changed

Lines changed: 1445 additions & 4 deletions

kernel/x86_64/KERNEL.HASWELL

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -77,10 +77,10 @@ ZGEMMITCOPYOBJ = zgemm_itcopy$(TSUFFIX).$(SUFFIX)
7777
ZGEMMONCOPYOBJ = zgemm_oncopy$(TSUFFIX).$(SUFFIX)
7878
ZGEMMOTCOPYOBJ = zgemm_otcopy$(TSUFFIX).$(SUFFIX)
7979

# Single-precision TRSM: use the AVX2 8x4 Haswell kernels instead of the
# generic C implementations (double precision below still uses the generic
# ones).
STRSMKERNEL_LN = strsm_kernel_8x4_haswell_LN.c
STRSMKERNEL_LT = strsm_kernel_8x4_haswell_LT.c
STRSMKERNEL_RN = strsm_kernel_8x4_haswell_RN.c
STRSMKERNEL_RT = strsm_kernel_8x4_haswell_RT.c

DTRSMKERNEL_LN = ../generic/trsm_kernel_LN.c
DTRSMKERNEL_LT = ../generic/trsm_kernel_LT.c
Lines changed: 240 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,240 @@
#include "common.h"
#include <stdint.h>
#include "strsm_kernel_8x4_haswell_L_common.h"

/*
 * AVX2 single-precision TRSM kernel, LN case, 8x4 register tiling (Haswell).
 * The LN variant walks the M dimension from the bottom up, so every tile
 * macro below first steps its pointers BACKWARDS ("subq ...") before it
 * operates.
 *
 * Asm operand map (from the COMPUTE() block at the end of this file):
 *   %0 = a_ptr   %1 = b_ptr    %2 = c_ptr      %3 = c_tmp
 *   %4 = ldc_bytes  %5 = k_cnt  %6 = K  %7 = k-kk  %8 = one[0]  %9 = zero[0]
 *   %10 = M
 * The low-level primitives (GEMM_SUM_REORDER_*, SOLVE_*_m2n*, SUBTRACT_*,
 * SAVE_b_*, save_c_*, INIT_*, GEMM_KERNEL_k1*) are defined in
 * strsm_kernel_8x4_haswell_L_common.h.  NOTE(review): their numeric
 * arguments appear to be byte offsets and ymm register indices — confirm
 * against that header.
 */

/*
 * SOLVE_LN_m{M}n{N}: back-substitute one MxN tile.  Each macro rewinds the
 * C pointer (%2) by M floats and reloads the scratch pointer (%3), reorders
 * the GEMM accumulators into solve order, solves two rows at a time
 * (SOLVE_loup_*/SOLVE_up_* with SUBTRACT_* down-dates of the not-yet-solved
 * rows), stores the solved rows into the packed B panel (SAVE_b_*), and
 * finally writes the tile back to C (save_c_*).
 */
#define SOLVE_LN_m1n4 \
  "subq $4,%2; movq %2,%3;" GEMM_SUM_REORDER_1x4(4)\
  SOLVE_m1n4(-4,4) SAVE_b_m1n4(-16,4)\
  "movq %2,%3;" save_c_m1n4(4)

#define SOLVE_LN_m1n8 \
  "subq $4,%2; movq %2,%3;" GEMM_SUM_REORDER_1x4(4) GEMM_SUM_REORDER_1x4(5)\
  SOLVE_m1n8(-4,4,5) SAVE_b_m1n8(-16,4,5)\
  "movq %2,%3;" save_c_m1n4(4) save_c_m1n4(5)

#define SOLVE_LN_m1n12 \
  "subq $4,%2; movq %2,%3;" GEMM_SUM_REORDER_1x4(4) GEMM_SUM_REORDER_1x4(5) GEMM_SUM_REORDER_1x4(6)\
  SOLVE_m1n12(-4,4,5,6) SAVE_b_m1n12(-16,4,5,6)\
  "movq %2,%3;" save_c_m1n4(4) save_c_m1n4(5) save_c_m1n4(6)

#define SOLVE_LN_m2n4 \
  "subq $8,%2; movq %2,%3;" GEMM_SUM_REORDER_2x4(4,5,4)\
  SOLVE_loup_m2n4(-8,4)\
  SOLVE_up_m2n4(-16,4) SAVE_b_m2n4(-32,4)\
  "movq %2,%3;" save_c_m2n4(4)

#define SOLVE_LN_m2n8 \
  "subq $8,%2; movq %2,%3;" GEMM_SUM_REORDER_2x4(4,5,4) GEMM_SUM_REORDER_2x4(6,7,5)\
  SOLVE_loup_m2n8(-8,4,5)\
  SOLVE_up_m2n8(-16,4,5) SAVE_b_m2n8(-32,4,5)\
  "movq %2,%3;" save_c_m2n4(4) save_c_m2n4(5)

#define SOLVE_LN_m2n12 \
  "subq $8,%2; movq %2,%3;" GEMM_SUM_REORDER_2x4(4,5,4) GEMM_SUM_REORDER_2x4(6,7,5) GEMM_SUM_REORDER_2x4(8,9,6)\
  SOLVE_loup_m2n12(-8,4,5,6)\
  SOLVE_up_m2n12(-16,4,5,6) SAVE_b_m2n12(-32,4,5,6)\
  "movq %2,%3;" save_c_m2n4(4) save_c_m2n4(5) save_c_m2n4(6)

/* m=4 tiles: two 2-row passes, solving rows 3..2 then 1..0. */
#define SOLVE_LN_m4n4 \
  "subq $16,%2; movq %2,%3;" GEMM_SUM_REORDER_4x4(4,5,6,7,4,5)\
\
  SOLVE_loup_m2n4(-8,5) SUBTRACT_m2n4(-16,4)\
  SOLVE_up_m2n4(-24,5) SUBTRACT_m2n4(-32,4) SAVE_b_m2n4(-32,5)\
\
  SOLVE_loup_m2n4(-48,4)\
  SOLVE_up_m2n4(-64,4) SAVE_b_m2n4(-64,4)\
\
  "movq %2,%3;" save_c_m4n4(4,5)

#define SOLVE_LN_m4n8 \
  "subq $16,%2; movq %2,%3;" GEMM_SUM_REORDER_4x4(4,5,6,7,4,5) GEMM_SUM_REORDER_4x4(8,9,10,11,6,7)\
\
  SOLVE_loup_m2n8(-8,5,7) SUBTRACT_m2n8(-16,4,6)\
  SOLVE_up_m2n8(-24,5,7) SUBTRACT_m2n8(-32,4,6) SAVE_b_m2n8(-32,5,7)\
\
  SOLVE_loup_m2n8(-48,4,6)\
  SOLVE_up_m2n8(-64,4,6) SAVE_b_m2n8(-64,4,6)\
\
  "movq %2,%3;" save_c_m4n4(4,5) save_c_m4n4(6,7)

#define SOLVE_LN_m4n12 \
  "subq $16,%2; movq %2,%3;" GEMM_SUM_REORDER_4x4(4,5,6,7,4,5) GEMM_SUM_REORDER_4x4(8,9,10,11,6,7) GEMM_SUM_REORDER_4x4(12,13,14,15,8,9)\
\
  SOLVE_loup_m2n12(-8,5,7,9) SUBTRACT_m2n12(-16,4,6,8)\
  SOLVE_up_m2n12(-24,5,7,9) SUBTRACT_m2n12(-32,4,6,8) SAVE_b_m2n12(-32,5,7,9)\
\
  SOLVE_loup_m2n12(-48,4,6,8)\
  SOLVE_up_m2n12(-64,4,6,8) SAVE_b_m2n12(-64,4,6,8)\
\
  "movq %2,%3;" save_c_m4n4(4,5) save_c_m4n4(6,7) save_c_m4n4(8,9)

/* m=8 tiles: four 2-row passes, solving rows 7..6, 5..4, 3..2, 1..0. */
#define SOLVE_LN_m8n4 \
  "subq $32,%2; movq %2,%3;" GEMM_SUM_REORDER_8x4(4,5,6,7,-32)\
\
  SOLVE_loup_m2n4(-8,7) SUBTRACT_m2n4(-16,6) SUBTRACT_m2n4(-24,5) SUBTRACT_m2n4(-32,4)\
  SOLVE_up_m2n4(-40,7) SUBTRACT_m2n4(-48,6) SUBTRACT_m2n4(-56,5) SUBTRACT_m2n4(-64,4) SAVE_b_m2n4(-32,7)\
\
  SOLVE_loup_m2n4(-80,6) SUBTRACT_m2n4(-88,5) SUBTRACT_m2n4(-96,4)\
  SOLVE_up_m2n4(-112,6) SUBTRACT_m2n4(-120,5) SUBTRACT_m2n4(-128,4) SAVE_b_m2n4(-64,6)\
\
  SOLVE_loup_m2n4(-152,5) SUBTRACT_m2n4(-160,4)\
  SOLVE_up_m2n4(-184,5) SUBTRACT_m2n4(-192,4) SAVE_b_m2n4(-96,5)\
\
  SOLVE_loup_m2n4(-224,4)\
  SOLVE_up_m2n4(-256,4) SAVE_b_m2n4(-128,4)\
\
  "movq %2,%3;" save_c_m8n4(4,5,6,7)

#define SOLVE_LN_m8n8 \
  "subq $32,%2; movq %2,%3;" GEMM_SUM_REORDER_8x4(4,5,6,7,-32) GEMM_SUM_REORDER_8x4(8,9,10,11,-32)\
\
  SOLVE_loup_m2n8(-8,7,11) SUBTRACT_m2n8(-16,6,10) SUBTRACT_m2n8(-24,5,9) SUBTRACT_m2n8(-32,4,8)\
  SOLVE_up_m2n8(-40,7,11) SUBTRACT_m2n8(-48,6,10) SUBTRACT_m2n8(-56,5,9) SUBTRACT_m2n8(-64,4,8) SAVE_b_m2n8(-32,7,11)\
\
  SOLVE_loup_m2n8(-80,6,10) SUBTRACT_m2n8(-88,5,9) SUBTRACT_m2n8(-96,4,8)\
  SOLVE_up_m2n8(-112,6,10) SUBTRACT_m2n8(-120,5,9) SUBTRACT_m2n8(-128,4,8) SAVE_b_m2n8(-64,6,10)\
\
  SOLVE_loup_m2n8(-152,5,9) SUBTRACT_m2n8(-160,4,8)\
  SOLVE_up_m2n8(-184,5,9) SUBTRACT_m2n8(-192,4,8) SAVE_b_m2n8(-96,5,9)\
\
  SOLVE_loup_m2n8(-224,4,8)\
  SOLVE_up_m2n8(-256,4,8) SAVE_b_m2n8(-128,4,8)\
\
  "movq %2,%3;" save_c_m8n4(4,5,6,7) save_c_m8n4(8,9,10,11)

#define SOLVE_LN_m8n12 \
  "subq $32,%2; movq %2,%3;" GEMM_SUM_REORDER_8x4(4,5,6,7,-32) GEMM_SUM_REORDER_8x4(8,9,10,11,-32) GEMM_SUM_REORDER_8x4(12,13,14,15,-32)\
\
  SOLVE_loup_m2n12(-8,7,11,15) SUBTRACT_m2n12(-16,6,10,14) SUBTRACT_m2n12(-24,5,9,13) SUBTRACT_m2n12(-32,4,8,12)\
  SOLVE_up_m2n12(-40,7,11,15) SUBTRACT_m2n12(-48,6,10,14) SUBTRACT_m2n12(-56,5,9,13) SUBTRACT_m2n12(-64,4,8,12) SAVE_b_m2n12(-32,7,11,15)\
\
  SOLVE_loup_m2n12(-80,6,10,14) SUBTRACT_m2n12(-88,5,9,13) SUBTRACT_m2n12(-96,4,8,12)\
  SOLVE_up_m2n12(-112,6,10,14) SUBTRACT_m2n12(-120,5,9,13) SUBTRACT_m2n12(-128,4,8,12) SAVE_b_m2n12(-64,6,10,14)\
\
  SOLVE_loup_m2n12(-152,5,9,13) SUBTRACT_m2n12(-160,4,8,12)\
  SOLVE_up_m2n12(-184,5,9,13) SUBTRACT_m2n12(-192,4,8,12) SAVE_b_m2n12(-96,5,9,13)\
\
  SOLVE_loup_m2n12(-224,4,8,12)\
  SOLVE_up_m2n12(-256,4,8,12) SAVE_b_m2n12(-128,4,8,12)\
\
  "movq %2,%3;" save_c_m8n4(4,5,6,7) save_c_m8n4(8,9,10,11) save_c_m8n4(12,13,14,15)

/* r13 = k-kk, r14 = b_tail, r15 = a_tail */

/*
 * GEMM_LN_SIMPLE(mdim,ndim): accumulate the trailing rank-(k-kk) update for
 * an mdim x ndim tile, one k-step per iteration, walking the packed A (%0)
 * and B (%1) panels backwards (the LN case traverses M from high to low).
 * Local asm labels are composed as 2<mdim><ndim>{1,2} so they stay unique
 * per tile size.
 */
#define GEMM_LN_SIMPLE(mdim,ndim) \
  "movq %%r15,%0; negq %%r12; leaq (%%r15,%%r12,"#mdim"),%%r15; negq %%r12;"\
  "movq %%r13,%5; addq $"#mdim",%%r13; movq %%r14,%1;" INIT_m##mdim##n##ndim\
  "testq %5,%5; jz 2"#mdim""#ndim"2f;"\
  "2"#mdim""#ndim"1:\n\t"\
  "subq $16,%1; subq $"#mdim"*4,%0;" GEMM_KERNEL_k1m##mdim##n##ndim "decq %5; jnz 2"#mdim""#ndim"1b;"\
  "2"#mdim""#ndim"2:\n\t"
#define GEMM_LN_m8n4 GEMM_LN_SIMPLE(8,4)
#define GEMM_LN_m8n8 GEMM_LN_SIMPLE(8,8)
/* Hand-unrolled (8x) variant of GEMM_LN_SIMPLE(8,12) with prefetch of the
 * A panel on every second k-step; the 28123 loop mops up the k%8 tail. */
#define GEMM_LN_m8n12 \
  "movq %%r15,%0; negq %%r12; leaq (%%r15,%%r12,8),%%r15; negq %%r12; movq %%r13,%5; addq $8,%%r13; movq %%r14,%1;" INIT_m8n12\
  "cmpq $8,%5; jb 28122f;"\
  "28121:\n\t"\
  "prefetcht0 -384(%0); subq $32,%0; subq $16,%1;" GEMM_KERNEL_k1m8n12\
  "subq $32,%0; subq $16,%1;" GEMM_KERNEL_k1m8n12\
  "prefetcht0 -384(%0); subq $32,%0; subq $16,%1;" GEMM_KERNEL_k1m8n12\
  "subq $32,%0; subq $16,%1;" GEMM_KERNEL_k1m8n12\
  "prefetcht0 -384(%0); subq $32,%0; subq $16,%1;" GEMM_KERNEL_k1m8n12\
  "subq $32,%0; subq $16,%1;" GEMM_KERNEL_k1m8n12\
  "prefetcht0 -384(%0); subq $32,%0; subq $16,%1;" GEMM_KERNEL_k1m8n12\
  "subq $32,%0; subq $16,%1;" GEMM_KERNEL_k1m8n12\
  "subq $8,%5; cmpq $8,%5; jnb 28121b;"\
  "28122:\n\t"\
  "testq %5,%5; jz 28124f;"\
  "28123:\n\t"\
  "subq $32,%0; subq $16,%1;" GEMM_KERNEL_k1m8n12 "decq %5; jnz 28123b;"\
  "28124:\n\t"
#define GEMM_LN_m4n4 GEMM_LN_SIMPLE(4,4)
#define GEMM_LN_m4n8 GEMM_LN_SIMPLE(4,8)
#define GEMM_LN_m4n12 GEMM_LN_SIMPLE(4,12)
#define GEMM_LN_m2n4 GEMM_LN_SIMPLE(2,4)
#define GEMM_LN_m2n8 GEMM_LN_SIMPLE(2,8)
#define GEMM_LN_m2n12 GEMM_LN_SIMPLE(2,12)
#define GEMM_LN_m1n4 GEMM_LN_SIMPLE(1,4)
#define GEMM_LN_m1n8 GEMM_LN_SIMPLE(1,8)
#define GEMM_LN_m1n12 GEMM_LN_SIMPLE(1,12)

/*
 * COMPUTE(ndim): process a chunk of ndim columns over the whole M range.
 * M is peeled LSB-first (1, 2, 4, then a loop of 8) to match the backward M
 * walk.  Register roles: r11 = rows remaining, r12 = 4*K (packed-A panel
 * stride in bytes), r13 = running k-kk counter, r14 = b_tail, r15 = a_tail.
 * On exit a_ptr/b_ptr are restored from r15/r14; r14 was pre-advanced past
 * the first 4 columns of B (leaq (%1,%%r12,4)), hence the (ndim-4)*K
 * fix-up on b_ptr afterwards.
 */
#define COMPUTE(ndim) {\
  c_ptr += M;\
  __asm__ __volatile__(\
    "movq %0,%%r15; movq %7,%%r13; movq %6,%%r12; salq $2,%%r12; leaq (%1,%%r12,4),%%r14; movq %10,%%r11;"\
    "testq $1,%%r11; jz "#ndim"772f;"\
    #ndim"771:\n\t"\
    GEMM_LN_m1n##ndim SOLVE_LN_m1n##ndim "subq $1,%%r11;"\
    #ndim"772:\n\t"\
    "testq $2,%%r11; jz "#ndim"773f;"\
    GEMM_LN_m2n##ndim SOLVE_LN_m2n##ndim "subq $2,%%r11;"\
    #ndim"773:\n\t"\
    "testq $4,%%r11; jz "#ndim"774f;"\
    GEMM_LN_m4n##ndim SOLVE_LN_m4n##ndim "subq $4,%%r11;"\
    #ndim"774:\n\t"\
    "testq %%r11,%%r11; jz "#ndim"776f;"\
    #ndim"775:\n\t"\
    GEMM_LN_m8n##ndim SOLVE_LN_m8n##ndim "subq $8,%%r11; jnz "#ndim"775b;"\
    #ndim"776:\n\t"\
    "movq %%r15,%0; movq %%r14,%1; vzeroupper;"\
  :"+r"(a_ptr),"+r"(b_ptr),"+r"(c_ptr),"+r"(c_tmp),"+r"(ldc_bytes),"+r"(k_cnt):"m"(K),"m"(kmkkinp),"m"(one[0]),"m"(zero[0]),"m"(M)\
  :"r11","r12","r13","r14","r15","cc","memory",\
  "xmm0","xmm1","xmm2","xmm3","xmm4","xmm5","xmm6","xmm7","xmm8","xmm9","xmm10","xmm11","xmm12","xmm13","xmm14","xmm15");\
  a_ptr += M * K; b_ptr += (ndim-4) * K; c_ptr += ldc * ndim;\
}
185+
static void solve_LN(BLASLONG m, BLASLONG n, FLOAT *a, FLOAT *b, FLOAT *c, BLASLONG ldc) {
186+
FLOAT a0, b0;
187+
int i, j, k;
188+
for (i=m-1;i>=0;i--) {
189+
a0 = a[i*m+i]; //reciprocal of the original value
190+
for (j=0;j<n;j++) {
191+
b0 = c[j*ldc+i]*a0;
192+
c[j*ldc+i] = b[i*n+j] = b0;
193+
for (k=0;k<i;k++) c[j*ldc+k] -= b0*a[i*m+k];
194+
}
195+
}
196+
}
197+
static void COMPUTE_EDGE_1_nchunk(BLASLONG m, BLASLONG n, FLOAT *sa, FLOAT *sb, FLOAT *C, BLASLONG ldc, BLASLONG k, BLASLONG offset) {
198+
BLASLONG m_count = m, kk = m+offset; FLOAT *a_ptr = sa+m*k, *c_ptr = C+m;
199+
if(m_count&1){
200+
a_ptr-=k; c_ptr--;
201+
if(k-kk>0) GEMM_KERNEL_N(1,n,k-kk,-1.0,a_ptr+kk*1,sb+kk*n,c_ptr,ldc);
202+
solve_LN(1,n,a_ptr+(kk-1)*1,sb+(kk-1)*n,c_ptr,ldc);
203+
kk -= 1;
204+
m_count--;
205+
}
206+
if(m_count&2){
207+
a_ptr-=k*2; c_ptr-=2;
208+
if(k-kk>0) GEMM_KERNEL_N(2,n,k-kk,-1.0,a_ptr+kk*2,sb+kk*n,c_ptr,ldc);
209+
solve_LN(2,n,a_ptr+(kk-2)*2,sb+(kk-2)*n,c_ptr,ldc);
210+
kk -= 2;
211+
m_count-=2;
212+
}
213+
if(m_count&4){
214+
a_ptr-=k*4; c_ptr-=4;
215+
if(k-kk>0) GEMM_KERNEL_N(4,n,k-kk,-1.0,a_ptr+kk*4,sb+kk*n,c_ptr,ldc);
216+
solve_LN(4,n,a_ptr+(kk-4)*4,sb+(kk-4)*n,c_ptr,ldc);
217+
kk -= 4;
218+
m_count-=4;
219+
}
220+
for(;m_count>7;m_count-=8){
221+
a_ptr-=k*8; c_ptr-=8;
222+
if(k-kk>0) GEMM_KERNEL_N(8,n,k-kk,-1.0,a_ptr+kk*8,sb+kk*n,c_ptr,ldc);
223+
solve_LN(8,n,a_ptr+(kk-8)*8,sb+(kk-8)*n,c_ptr,ldc);
224+
kk -= 8;
225+
}
226+
}
227+
int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT dummy1, FLOAT *sa, FLOAT *sb, FLOAT *C, BLASLONG ldc, BLASLONG offset){
228+
float *a_ptr = sa+m*k, *b_ptr = sb, *c_ptr = C, *c_tmp = C;
229+
float one[8] = {1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0};
230+
float zero[8] = {0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0};
231+
uint64_t ldc_bytes = (uint64_t)ldc * sizeof(float), K = (uint64_t)k, M = (uint64_t)m, kmkkinp = (uint64_t)(k-m-offset), k_cnt = 0;
232+
BLASLONG n_count = n;
233+
for(;n_count>11;n_count-=12) COMPUTE(12)
234+
for(;n_count>7;n_count-=8) COMPUTE(8)
235+
for(;n_count>3;n_count-=4) COMPUTE(4)
236+
for(;n_count>1;n_count-=2) { COMPUTE_EDGE_1_nchunk(m,2,sa,b_ptr,c_ptr,ldc,k,offset); b_ptr += 2*k; c_ptr += ldc*2;}
237+
if(n_count>0) COMPUTE_EDGE_1_nchunk(m,1,sa,b_ptr,c_ptr,ldc,k,offset);
238+
return 0;
239+
}
240+

0 commit comments

Comments
 (0)