|
| 1 | +#include "common.h" |
| 2 | +#include <stdint.h> |
| 3 | +#include "strsm_kernel_8x4_haswell_L_common.h" |
| 4 | + |
/*
 * TRSM "LN" (left side, lower-triangular A, non-transposed, solved from the
 * last row upward) macros for an m=1 edge strip against n=4/8/12 columns.
 * Inline-asm operand map (set up by COMPUTE below): %2 = c pointer,
 * %3 = c tmp pointer. Each macro first steps %2 back by 4 bytes (one float
 * row of C), reloads the accumulated sums, runs the solve, stores the
 * updated b panel, then writes the solved values back to C.
 * GEMM_SUM_REORDER_*, SOLVE_*, SAVE_b_* and save_c_* are defined in
 * strsm_kernel_8x4_haswell_L_common.h.
 */
#define SOLVE_LN_m1n4 \
 "subq $4,%2; movq %2,%3;" GEMM_SUM_REORDER_1x4(4)\
 SOLVE_m1n4(-4,4) SAVE_b_m1n4(-16,4)\
 "movq %2,%3;" save_c_m1n4(4)

/* Same as SOLVE_LN_m1n4 but for 8 columns (two 4-wide register groups). */
#define SOLVE_LN_m1n8 \
 "subq $4,%2; movq %2,%3;" GEMM_SUM_REORDER_1x4(4) GEMM_SUM_REORDER_1x4(5)\
 SOLVE_m1n8(-4,4,5) SAVE_b_m1n8(-16,4,5)\
 "movq %2,%3;" save_c_m1n4(4) save_c_m1n4(5)

/* Same as SOLVE_LN_m1n4 but for 12 columns (three 4-wide register groups). */
#define SOLVE_LN_m1n12 \
 "subq $4,%2; movq %2,%3;" GEMM_SUM_REORDER_1x4(4) GEMM_SUM_REORDER_1x4(5) GEMM_SUM_REORDER_1x4(6)\
 SOLVE_m1n12(-4,4,5,6) SAVE_b_m1n12(-16,4,5,6)\
 "movq %2,%3;" save_c_m1n4(4) save_c_m1n4(5) save_c_m1n4(6)
| 19 | + |
/*
 * LN solve macros for an m=2 strip, n=4/8/12. The 2x2 diagonal block is
 * solved bottom row first (SOLVE_loup_*: solve the lower row and eliminate
 * it from the row above), then the top row (SOLVE_up_*).
 * %2 steps back by 8 bytes = 2 floats per C column; the numeric macro
 * arguments are byte offsets / register indices consumed by the helpers in
 * strsm_kernel_8x4_haswell_L_common.h.
 */
#define SOLVE_LN_m2n4 \
 "subq $8,%2; movq %2,%3;" GEMM_SUM_REORDER_2x4(4,5,4)\
 SOLVE_loup_m2n4(-8,4)\
 SOLVE_up_m2n4(-16,4) SAVE_b_m2n4(-32,4)\
 "movq %2,%3;" save_c_m2n4(4)

/* m=2, n=8 variant: two register groups (4 and 5). */
#define SOLVE_LN_m2n8 \
 "subq $8,%2; movq %2,%3;" GEMM_SUM_REORDER_2x4(4,5,4) GEMM_SUM_REORDER_2x4(6,7,5)\
 SOLVE_loup_m2n8(-8,4,5)\
 SOLVE_up_m2n8(-16,4,5) SAVE_b_m2n8(-32,4,5)\
 "movq %2,%3;" save_c_m2n4(4) save_c_m2n4(5)

/* m=2, n=12 variant: three register groups (4, 5 and 6). */
#define SOLVE_LN_m2n12 \
 "subq $8,%2; movq %2,%3;" GEMM_SUM_REORDER_2x4(4,5,4) GEMM_SUM_REORDER_2x4(6,7,5) GEMM_SUM_REORDER_2x4(8,9,6)\
 SOLVE_loup_m2n12(-8,4,5,6)\
 SOLVE_up_m2n12(-16,4,5,6) SAVE_b_m2n12(-32,4,5,6)\
 "movq %2,%3;" save_c_m2n4(4) save_c_m2n4(5) save_c_m2n4(6)
| 37 | + |
/*
 * LN solve macros for an m=4 strip, n=4/8/12. The 4x4 diagonal block is
 * processed as two 2-row sub-blocks, bottom pair first: solve rows 3..2
 * (group 5) while SUBTRACT_m2n4 eliminates their contribution from rows
 * 1..0 (group 4), then solve rows 1..0. Offsets are in bytes into the
 * packed A / B panels; they grow by a full row stride between the lower
 * and upper solve of each pair.
 */
#define SOLVE_LN_m4n4 \
 "subq $16,%2; movq %2,%3;" GEMM_SUM_REORDER_4x4(4,5,6,7,4,5)\
\
 SOLVE_loup_m2n4(-8,5) SUBTRACT_m2n4(-16,4)\
 SOLVE_up_m2n4(-24,5) SUBTRACT_m2n4(-32,4) SAVE_b_m2n4(-32,5)\
\
 SOLVE_loup_m2n4(-48,4)\
 SOLVE_up_m2n4(-64,4) SAVE_b_m2n4(-64,4)\
\
 "movq %2,%3;" save_c_m4n4(4,5)

/* m=4, n=8 variant: column groups (4,5) and (6,7) solved in lockstep. */
#define SOLVE_LN_m4n8 \
 "subq $16,%2; movq %2,%3;" GEMM_SUM_REORDER_4x4(4,5,6,7,4,5) GEMM_SUM_REORDER_4x4(8,9,10,11,6,7)\
\
 SOLVE_loup_m2n8(-8,5,7) SUBTRACT_m2n8(-16,4,6)\
 SOLVE_up_m2n8(-24,5,7) SUBTRACT_m2n8(-32,4,6) SAVE_b_m2n8(-32,5,7)\
\
 SOLVE_loup_m2n8(-48,4,6)\
 SOLVE_up_m2n8(-64,4,6) SAVE_b_m2n8(-64,4,6)\
\
 "movq %2,%3;" save_c_m4n4(4,5) save_c_m4n4(6,7)

/* m=4, n=12 variant: column groups (4,5), (6,7) and (8,9). */
#define SOLVE_LN_m4n12 \
 "subq $16,%2; movq %2,%3;" GEMM_SUM_REORDER_4x4(4,5,6,7,4,5) GEMM_SUM_REORDER_4x4(8,9,10,11,6,7) GEMM_SUM_REORDER_4x4(12,13,14,15,8,9)\
\
 SOLVE_loup_m2n12(-8,5,7,9) SUBTRACT_m2n12(-16,4,6,8)\
 SOLVE_up_m2n12(-24,5,7,9) SUBTRACT_m2n12(-32,4,6,8) SAVE_b_m2n12(-32,5,7,9)\
\
 SOLVE_loup_m2n12(-48,4,6,8)\
 SOLVE_up_m2n12(-64,4,6,8) SAVE_b_m2n12(-64,4,6,8)\
\
 "movq %2,%3;" save_c_m4n4(4,5) save_c_m4n4(6,7) save_c_m4n4(8,9)
| 70 | + |
/*
 * LN solve macros for the main m=8 tile, n=4/8/12. The 8x8 diagonal block
 * is handled as four 2-row sub-blocks solved from the bottom up (register
 * groups 7, 6, 5, 4); after each pair is solved, SUBTRACT_m2n* eliminates
 * its contribution from all pairs above it. The first byte-offset argument
 * of each helper indexes the packed A panel (note the stride between
 * consecutive solves grows as we move up the triangle); SAVE_b offsets
 * walk the packed B panel in fixed -32-byte steps per pair.
 */
#define SOLVE_LN_m8n4 \
 "subq $32,%2; movq %2,%3;" GEMM_SUM_REORDER_8x4(4,5,6,7,-32)\
\
 SOLVE_loup_m2n4(-8,7) SUBTRACT_m2n4(-16,6) SUBTRACT_m2n4(-24,5) SUBTRACT_m2n4(-32,4)\
 SOLVE_up_m2n4(-40,7) SUBTRACT_m2n4(-48,6) SUBTRACT_m2n4(-56,5) SUBTRACT_m2n4(-64,4) SAVE_b_m2n4(-32,7)\
\
 SOLVE_loup_m2n4(-80,6) SUBTRACT_m2n4(-88,5) SUBTRACT_m2n4(-96,4)\
 SOLVE_up_m2n4(-112,6) SUBTRACT_m2n4(-120,5) SUBTRACT_m2n4(-128,4) SAVE_b_m2n4(-64,6)\
\
 SOLVE_loup_m2n4(-152,5) SUBTRACT_m2n4(-160,4)\
 SOLVE_up_m2n4(-184,5) SUBTRACT_m2n4(-192,4) SAVE_b_m2n4(-96,5)\
\
 SOLVE_loup_m2n4(-224,4)\
 SOLVE_up_m2n4(-256,4) SAVE_b_m2n4(-128,4)\
\
 "movq %2,%3;" save_c_m8n4(4,5,6,7)

/* m=8, n=8 variant: register groups (4..7) and (8..11) solved in lockstep. */
#define SOLVE_LN_m8n8 \
 "subq $32,%2; movq %2,%3;" GEMM_SUM_REORDER_8x4(4,5,6,7,-32) GEMM_SUM_REORDER_8x4(8,9,10,11,-32)\
\
 SOLVE_loup_m2n8(-8,7,11) SUBTRACT_m2n8(-16,6,10) SUBTRACT_m2n8(-24,5,9) SUBTRACT_m2n8(-32,4,8)\
 SOLVE_up_m2n8(-40,7,11) SUBTRACT_m2n8(-48,6,10) SUBTRACT_m2n8(-56,5,9) SUBTRACT_m2n8(-64,4,8) SAVE_b_m2n8(-32,7,11)\
\
 SOLVE_loup_m2n8(-80,6,10) SUBTRACT_m2n8(-88,5,9) SUBTRACT_m2n8(-96,4,8)\
 SOLVE_up_m2n8(-112,6,10) SUBTRACT_m2n8(-120,5,9) SUBTRACT_m2n8(-128,4,8) SAVE_b_m2n8(-64,6,10)\
\
 SOLVE_loup_m2n8(-152,5,9) SUBTRACT_m2n8(-160,4,8)\
 SOLVE_up_m2n8(-184,5,9) SUBTRACT_m2n8(-192,4,8) SAVE_b_m2n8(-96,5,9)\
\
 SOLVE_loup_m2n8(-224,4,8)\
 SOLVE_up_m2n8(-256,4,8) SAVE_b_m2n8(-128,4,8)\
\
 "movq %2,%3;" save_c_m8n4(4,5,6,7) save_c_m8n4(8,9,10,11)

/* m=8, n=12 variant: register groups (4..7), (8..11), (12..15) — uses the
 * full ymm register file. */
#define SOLVE_LN_m8n12 \
 "subq $32,%2; movq %2,%3;" GEMM_SUM_REORDER_8x4(4,5,6,7,-32) GEMM_SUM_REORDER_8x4(8,9,10,11,-32) GEMM_SUM_REORDER_8x4(12,13,14,15,-32)\
\
 SOLVE_loup_m2n12(-8,7,11,15) SUBTRACT_m2n12(-16,6,10,14) SUBTRACT_m2n12(-24,5,9,13) SUBTRACT_m2n12(-32,4,8,12)\
 SOLVE_up_m2n12(-40,7,11,15) SUBTRACT_m2n12(-48,6,10,14) SUBTRACT_m2n12(-56,5,9,13) SUBTRACT_m2n12(-64,4,8,12) SAVE_b_m2n12(-32,7,11,15)\
\
 SOLVE_loup_m2n12(-80,6,10,14) SUBTRACT_m2n12(-88,5,9,13) SUBTRACT_m2n12(-96,4,8,12)\
 SOLVE_up_m2n12(-112,6,10,14) SUBTRACT_m2n12(-120,5,9,13) SUBTRACT_m2n12(-128,4,8,12) SAVE_b_m2n12(-64,6,10,14)\
\
 SOLVE_loup_m2n12(-152,5,9,13) SUBTRACT_m2n12(-160,4,8,12)\
 SOLVE_up_m2n12(-184,5,9,13) SUBTRACT_m2n12(-192,4,8,12) SAVE_b_m2n12(-96,5,9,13)\
\
 SOLVE_loup_m2n12(-224,4,8,12)\
 SOLVE_up_m2n12(-256,4,8,12) SAVE_b_m2n12(-128,4,8,12)\
\
 "movq %2,%3;" save_c_m8n4(4,5,6,7) save_c_m8n4(8,9,10,11) save_c_m8n4(12,13,14,15)
| 121 | + |
/* Register roles inside COMPUTE's asm block:
 *   r13 = k-kk (remaining GEMM depth for the current strip)
 *   r14 = b_tail (end of the current packed-B chunk)
 *   r15 = a_tail (end of the packed-A panel)
 *   r12 = K*4   (bytes per packed column; set up by COMPUTE)            */

/* Backward GEMM accumulation for an mdim x ndim tile: rewind a_tail (r15)
 * by one mdim-wide column block, copy the depth counter r13 into %5, then
 * run k-kk rank-1 updates walking %0 (A) and %1 (B) downward in memory.
 * r13 grows by mdim afterwards because the next strip (further up the
 * triangle) needs a deeper update. Numeric local labels are built as
 * "2<mdim><ndim><1|2>" so each tile size gets unique jump targets. */
#define GEMM_LN_SIMPLE(mdim,ndim) \
 "movq %%r15,%0; negq %%r12; leaq (%%r15,%%r12,"#mdim"),%%r15; negq %%r12;"\
 "movq %%r13,%5; addq $"#mdim",%%r13; movq %%r14,%1;" INIT_m##mdim##n##ndim\
 "testq %5,%5; jz 2"#mdim""#ndim"2f;"\
 "2"#mdim""#ndim"1:\n\t"\
 "subq $16,%1; subq $"#mdim"*4,%0;" GEMM_KERNEL_k1m##mdim##n##ndim "decq %5; jnz 2"#mdim""#ndim"1b;"\
 "2"#mdim""#ndim"2:\n\t"
#define GEMM_LN_m8n4 GEMM_LN_SIMPLE(8,4)
#define GEMM_LN_m8n8 GEMM_LN_SIMPLE(8,8)
/* Hand-unrolled m=8, n=12 variant of GEMM_LN_SIMPLE: eight k-steps per
 * iteration with software prefetch of A every other step; the remainder
 * (k-kk < 8) falls through to the single-step tail loop. Net effect is
 * identical to GEMM_LN_SIMPLE(8,12). */
#define GEMM_LN_m8n12 \
 "movq %%r15,%0; negq %%r12; leaq (%%r15,%%r12,8),%%r15; negq %%r12; movq %%r13,%5; addq $8,%%r13; movq %%r14,%1;" INIT_m8n12\
 "cmpq $8,%5; jb 28122f;"\
 "28121:\n\t"\
 "prefetcht0 -384(%0); subq $32,%0; subq $16,%1;" GEMM_KERNEL_k1m8n12\
 "subq $32,%0; subq $16,%1;" GEMM_KERNEL_k1m8n12\
 "prefetcht0 -384(%0); subq $32,%0; subq $16,%1;" GEMM_KERNEL_k1m8n12\
 "subq $32,%0; subq $16,%1;" GEMM_KERNEL_k1m8n12\
 "prefetcht0 -384(%0); subq $32,%0; subq $16,%1;" GEMM_KERNEL_k1m8n12\
 "subq $32,%0; subq $16,%1;" GEMM_KERNEL_k1m8n12\
 "prefetcht0 -384(%0); subq $32,%0; subq $16,%1;" GEMM_KERNEL_k1m8n12\
 "subq $32,%0; subq $16,%1;" GEMM_KERNEL_k1m8n12\
 "subq $8,%5; cmpq $8,%5; jnb 28121b;"\
 "28122:\n\t"\
 "testq %5,%5; jz 28124f;"\
 "28123:\n\t"\
 "subq $32,%0; subq $16,%1;" GEMM_KERNEL_k1m8n12 "decq %5; jnz 28123b;"\
 "28124:\n\t"
#define GEMM_LN_m4n4 GEMM_LN_SIMPLE(4,4)
#define GEMM_LN_m4n8 GEMM_LN_SIMPLE(4,8)
#define GEMM_LN_m4n12 GEMM_LN_SIMPLE(4,12)
#define GEMM_LN_m2n4 GEMM_LN_SIMPLE(2,4)
#define GEMM_LN_m2n8 GEMM_LN_SIMPLE(2,8)
#define GEMM_LN_m2n12 GEMM_LN_SIMPLE(2,12)
#define GEMM_LN_m1n4 GEMM_LN_SIMPLE(1,4)
#define GEMM_LN_m1n8 GEMM_LN_SIMPLE(1,8)
#define GEMM_LN_m1n12 GEMM_LN_SIMPLE(1,12)
| 160 | + |
/* Process one ndim-wide column chunk of B/C over the full M, walking the
 * triangle from the bottom-right corner upward: the m remainder is peeled
 * as 1, then 2, then 4 rows, and the rest is handled in blocks of 8.
 * NOTE: this macro captures the surrounding locals of CNAME by name
 * (a_ptr, b_ptr, c_ptr, c_tmp, ldc_bytes, k_cnt, K, kmkkinp, one, zero, M).
 * Asm operand map: %0=a_ptr %1=b_ptr %2=c_ptr %3=c_tmp %4=ldc_bytes
 * %5=k_cnt %6=K %7=kmkkinp %8=one[0] %9=zero[0] %10=M.
 * Setup: r15 <- a_ptr (a_tail), r13 <- k-m-offset, r12 <- K*4 (salq $2),
 * r14 <- b_ptr + 4*K*4 bytes (b_tail past this chunk), r11 <- M (row
 * countdown). c_ptr is pre-advanced by M so the solve macros can index C
 * backward. Labels are prefixed with ndim to stay unique per expansion.
 * %8/%9 (one/zero) are presumably consumed by the helper macros from the
 * common header — not referenced directly here. */
#define COMPUTE(ndim) {\
  c_ptr += M;\
  __asm__ __volatile__(\
    "movq %0,%%r15; movq %7,%%r13; movq %6,%%r12; salq $2,%%r12; leaq (%1,%%r12,4),%%r14; movq %10,%%r11;"\
    "testq $1,%%r11; jz "#ndim"772f;"\
    #ndim"771:\n\t"\
    GEMM_LN_m1n##ndim SOLVE_LN_m1n##ndim "subq $1,%%r11;"\
    #ndim"772:\n\t"\
    "testq $2,%%r11; jz "#ndim"773f;"\
    GEMM_LN_m2n##ndim SOLVE_LN_m2n##ndim "subq $2,%%r11;"\
    #ndim"773:\n\t"\
    "testq $4,%%r11; jz "#ndim"774f;"\
    GEMM_LN_m4n##ndim SOLVE_LN_m4n##ndim "subq $4,%%r11;"\
    #ndim"774:\n\t"\
    "testq %%r11,%%r11; jz "#ndim"776f;"\
    #ndim"775:\n\t"\
    GEMM_LN_m8n##ndim SOLVE_LN_m8n##ndim "subq $8,%%r11; jnz "#ndim"775b;"\
    #ndim"776:\n\t"\
    "movq %%r15,%0; movq %%r14,%1; vzeroupper;"\
    :"+r"(a_ptr),"+r"(b_ptr),"+r"(c_ptr),"+r"(c_tmp),"+r"(ldc_bytes),"+r"(k_cnt):"m"(K),"m"(kmkkinp),"m"(one[0]),"m"(zero[0]),"m"(M)\
    :"r11","r12","r13","r14","r15","cc","memory",\
    "xmm0","xmm1","xmm2","xmm3","xmm4","xmm5","xmm6","xmm7","xmm8","xmm9","xmm10","xmm11","xmm12","xmm13","xmm14","xmm15");\
  a_ptr += M * K; b_ptr += (ndim-4) * K; c_ptr += ldc * ndim;\
}
| 185 | +static void solve_LN(BLASLONG m, BLASLONG n, FLOAT *a, FLOAT *b, FLOAT *c, BLASLONG ldc) { |
| 186 | + FLOAT a0, b0; |
| 187 | + int i, j, k; |
| 188 | + for (i=m-1;i>=0;i--) { |
| 189 | + a0 = a[i*m+i]; //reciprocal of the original value |
| 190 | + for (j=0;j<n;j++) { |
| 191 | + b0 = c[j*ldc+i]*a0; |
| 192 | + c[j*ldc+i] = b[i*n+j] = b0; |
| 193 | + for (k=0;k<i;k++) c[j*ldc+k] -= b0*a[i*m+k]; |
| 194 | + } |
| 195 | + } |
| 196 | +} |
/*
 * Scalar fallback for an n-chunk too narrow for the asm path (n < 4).
 * Walks the packed A panel backward from its end (a_ptr starts at sa+m*k,
 * c_ptr at C+m), peeling the m remainder as 1, 2, 4 rows and then blocks
 * of 8, mirroring the order used by the asm COMPUTE macro.
 * For each block: first apply the already-solved trailing rows as a
 * negated GEMM update (C -= A_block * B_solved over depth k-kk), then
 * back-substitute the block's own diagonal part via solve_LN.
 * kk starts at m+offset and shrinks with each solved block.
 * NOTE(review): GEMM_KERNEL_N and the meaning of `offset` come from the
 * OpenBLAS trsm driver convention — not visible in this file.
 */
static void COMPUTE_EDGE_1_nchunk(BLASLONG m, BLASLONG n, FLOAT *sa, FLOAT *sb, FLOAT *C, BLASLONG ldc, BLASLONG k, BLASLONG offset) {
  BLASLONG m_count = m, kk = m+offset; FLOAT *a_ptr = sa+m*k, *c_ptr = C+m;
  if(m_count&1){
    /* step back one packed column of A and one column of C */
    a_ptr-=k; c_ptr--;
    if(k-kk>0) GEMM_KERNEL_N(1,n,k-kk,-1.0,a_ptr+kk*1,sb+kk*n,c_ptr,ldc);
    solve_LN(1,n,a_ptr+(kk-1)*1,sb+(kk-1)*n,c_ptr,ldc);
    kk -= 1;
    m_count--;
  }
  if(m_count&2){
    a_ptr-=k*2; c_ptr-=2;
    if(k-kk>0) GEMM_KERNEL_N(2,n,k-kk,-1.0,a_ptr+kk*2,sb+kk*n,c_ptr,ldc);
    solve_LN(2,n,a_ptr+(kk-2)*2,sb+(kk-2)*n,c_ptr,ldc);
    kk -= 2;
    m_count-=2;
  }
  if(m_count&4){
    a_ptr-=k*4; c_ptr-=4;
    if(k-kk>0) GEMM_KERNEL_N(4,n,k-kk,-1.0,a_ptr+kk*4,sb+kk*n,c_ptr,ldc);
    solve_LN(4,n,a_ptr+(kk-4)*4,sb+(kk-4)*n,c_ptr,ldc);
    kk -= 4;
    m_count-=4;
  }
  /* main loop: remaining rows in blocks of 8, still moving upward */
  for(;m_count>7;m_count-=8){
    a_ptr-=k*8; c_ptr-=8;
    if(k-kk>0) GEMM_KERNEL_N(8,n,k-kk,-1.0,a_ptr+kk*8,sb+kk*n,c_ptr,ldc);
    solve_LN(8,n,a_ptr+(kk-8)*8,sb+(kk-8)*n,c_ptr,ldc);
    kk -= 8;
  }
}
/*
 * Kernel entry point (CNAME is the BLAS-mangled strsm_kernel name).
 * Solves the left-side lower-triangular non-transposed case over packed
 * buffers: sa holds packed A, sb packed B; results go to C and back into
 * sb. dummy1 (the alpha slot of the common kernel signature) is not
 * referenced here. n is peeled in chunks of 12, 8, 4 through the asm
 * COMPUTE macro; n%4 remainders (2, then 1) use the scalar edge path.
 * NOTE: the local names below (a_ptr..k_cnt, one, zero, K, M, kmkkinp)
 * are captured by name inside the COMPUTE macro — do not rename them.
 */
int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT dummy1, FLOAT *sa, FLOAT *sb, FLOAT *C, BLASLONG ldc, BLASLONG offset){
  float *a_ptr = sa+m*k, *b_ptr = sb, *c_ptr = C, *c_tmp = C;
  /* broadcast constants consumed by the asm helper macros */
  float one[8] = {1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0};
  float zero[8] = {0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0};
  /* ldc in bytes for address arithmetic; kmkkinp = k-m-offset = initial k-kk */
  uint64_t ldc_bytes = (uint64_t)ldc * sizeof(float), K = (uint64_t)k, M = (uint64_t)m, kmkkinp = (uint64_t)(k-m-offset), k_cnt = 0;
  BLASLONG n_count = n;
  for(;n_count>11;n_count-=12) COMPUTE(12)
  for(;n_count>7;n_count-=8) COMPUTE(8)
  for(;n_count>3;n_count-=4) COMPUTE(4)
  for(;n_count>1;n_count-=2) { COMPUTE_EDGE_1_nchunk(m,2,sa,b_ptr,c_ptr,ldc,k,offset); b_ptr += 2*k; c_ptr += ldc*2;}
  if(n_count>0) COMPUTE_EDGE_1_nchunk(m,1,sa,b_ptr,c_ptr,ldc,k,offset);
  return 0;
}
| 240 | + |
0 commit comments