USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*****************************************************************************/
#include <stdio.h>
#include <stdint.h>
#include <immintrin.h>

#include "common.h"
3031
/* result = vshufps(in1, in2, imm8) across all 512 bits, via inline asm so
 * that imm8 is emitted as a true immediate (the "N" constraint).
 * Fix: the pasted definition had a space between the macro name and its
 * parameter list, which made it an object-like macro and broke expansion. */
#define _MM512_SHUFFLE_i32(result, in1, in2, imm8) \
    asm("vshufps %3, %2, %1, %0": "=v"(result): "v"(in1), "v"(in2), "N"(imm8))
34+
/* Transpose-reorder an 8x32 tile of 16-bit elements held in r0..r7
 * (one 512-bit register per row), treating vertically adjacent element
 * pairs as single 32-bit lanes.  Results land in t0..t7.  Relies on
 * r0..r7, kc, k3, idx_lo and idx_hi from the invoking scope; r0..r7 are
 * clobbered as intermediates.
 * Fix: removed the space between the macro name and its parameter list
 * (the pasted form was an object-like macro and could not expand). */
#define REORDER_8x32(t0, t1, t2, t3, t4, t5, t6, t7) { \
    __m512i v; \
    t0 = _mm512_unpacklo_epi32(r0, r1); \
    t1 = _mm512_unpackhi_epi32(r0, r1); \
    t2 = _mm512_unpacklo_epi32(r2, r3); \
    t3 = _mm512_unpackhi_epi32(r2, r3); \
    t4 = _mm512_unpacklo_epi32(r4, r5); \
    t5 = _mm512_unpackhi_epi32(r4, r5); \
    t6 = _mm512_unpacklo_epi32(r6, r7); \
    t7 = _mm512_unpackhi_epi32(r6, r7); \
    /* 0x4E swaps the 64-bit halves within each 128-bit chunk; the kc/k3 \
     * blends then interleave pairs of unpacked registers. */ \
    _MM512_SHUFFLE_i32(v, t0, t2, 0x4E); \
    r0 = _mm512_mask_blend_epi32(kc, t0, v); \
    r1 = _mm512_mask_blend_epi32(k3, t2, v); \
    _MM512_SHUFFLE_i32(v, t1, t3, 0x4E); \
    r2 = _mm512_mask_blend_epi32(kc, t1, v); \
    r3 = _mm512_mask_blend_epi32(k3, t3, v); \
    _MM512_SHUFFLE_i32(v, t4, t6, 0x4E); \
    r4 = _mm512_mask_blend_epi32(kc, t4, v); \
    r5 = _mm512_mask_blend_epi32(k3, t6, v); \
    _MM512_SHUFFLE_i32(v, t5, t7, 0x4E); \
    r6 = _mm512_mask_blend_epi32(kc, t5, v); \
    r7 = _mm512_mask_blend_epi32(k3, t7, v); \
    /* Merge 128-bit quarters across register pairs. */ \
    t0 = _mm512_permutex2var_epi32(r0, idx_lo, r4); \
    t1 = _mm512_permutex2var_epi32(r1, idx_lo, r5); \
    t2 = _mm512_permutex2var_epi32(r2, idx_lo, r6); \
    t3 = _mm512_permutex2var_epi32(r3, idx_lo, r7); \
    t4 = _mm512_permutex2var_epi32(r0, idx_hi, r4); \
    t5 = _mm512_permutex2var_epi32(r1, idx_hi, r5); \
    t6 = _mm512_permutex2var_epi32(r2, idx_hi, r6); \
    t7 = _mm512_permutex2var_epi32(r3, idx_hi, r7); \
}
66+
/* Merge the low 256-bit halves of t0<x> and t1<x> (via idx_lo2) and store
 * 64 bytes at boffset0 + x*32 elements.  Needs a __m512i `v` in scope.
 * Fix: removed the space before the parameter list (pasted form was an
 * object-like macro). */
#define STORE_512_LO(x) \
    v = _mm512_permutex2var_epi64(t0##x, idx_lo2, t1##x); \
    _mm512_storeu_si512(boffset0 + x*32, v);
70+
/* Merge the high 256-bit halves of t0<x> and t1<x> (via idx_hi2) and store
 * them 8 rows (8*32 elements) past the low-half stores.
 * Fix: removed the space before the parameter list. */
#define STORE_512_HI(x) \
    v = _mm512_permutex2var_epi64(t0##x, idx_hi2, t1##x); \
    _mm512_storeu_si512(boffset0 + (x + 8)*32, v);
74+
/* Masked variant of STORE_512_LO for column tails: only the remain_n
 * 32-bit lanes selected by nmask are written, and rows are packed at a
 * stride of remain_n pairs (2*remain_n elements) instead of 16.
 * Fix: removed the space before the parameter list. */
#define MASK_STORE_512_LO(x) \
    v = _mm512_permutex2var_epi64(t0##x, idx_lo2, t1##x); \
    _mm512_mask_storeu_epi32(boffset0 + 2*x*remain_n, nmask, v);
78+
/* Masked variant of STORE_512_HI: high-half merge, written 8 pair-rows
 * past the low-half stores at the remain_n stride.
 * Fix: removed the space before the parameter list. */
#define MASK_STORE_512_HI(x) \
    v = _mm512_permutex2var_epi64(t0##x, idx_hi2, t1##x); \
    _mm512_mask_storeu_epi32(boffset0 + 2*(x + 8)*remain_n, nmask, v);
82+
/* Store transposed pair-row y of tile half x (0 = low 8 output rows,
 * 1 = high 8).  x is always a literal constant, so the branch folds at
 * compile time.  Fix: removed the space before the parameter list. */
#define STORE_512(x, y) {\
    __m512i v; \
    if (x == 0) { STORE_512_LO(y); } \
    else { STORE_512_HI(y); } \
}
88+
/* Masked counterpart of STORE_512, for column tails (remain_n < 16).
 * Fix: removed the space before the parameter list. */
#define MASK_STORE_512(x, y) {\
    __m512i v; \
    if (x == 0) { MASK_STORE_512_LO(y); } \
    else { MASK_STORE_512_HI(y); } \
}
94+
/* Build one transposed 512-bit pair-row into `tail`: y == 0 selects the
 * low-half merge (idx_lo2), otherwise the high-half merge (idx_hi2).
 * Fix: removed the space before the parameter list. */
#define SET_TAIL(y, x) {\
    if (y == 0) tail = _mm512_permutex2var_epi64(t0##x, idx_lo2, t1##x); \
    else tail = _mm512_permutex2var_epi64(t0##x, idx_hi2, t1##x); \
}
99+
/* Select the leftover (odd final row) transposed pair-row into `tail`.
 * n_store pair-rows were already stored, so 0-based row n_store is the
 * remainder: cases 1..8 come from the low half (SET_TAIL(0, ...)),
 * cases 9..16 from the high half.  n_store is in [0, 15] at every call
 * site, so the switch is exhaustive; every case breaks.
 * Fix: removed the space before the (empty) parameter list. */
#define GET_TAIL() \
    switch (n_store + 1) { \
    case 16: SET_TAIL(1, 7); break; \
    case 15: SET_TAIL(1, 6); break; \
    case 14: SET_TAIL(1, 5); break; \
    case 13: SET_TAIL(1, 4); break; \
    case 12: SET_TAIL(1, 3); break; \
    case 11: SET_TAIL(1, 2); break; \
    case 10: SET_TAIL(1, 1); break; \
    case 9:  SET_TAIL(1, 0); break; \
    case 8:  SET_TAIL(0, 7); break; \
    case 7:  SET_TAIL(0, 6); break; \
    case 6:  SET_TAIL(0, 5); break; \
    case 5:  SET_TAIL(0, 4); break; \
    case 4:  SET_TAIL(0, 3); break; \
    case 3:  SET_TAIL(0, 2); break; \
    case 2:  SET_TAIL(0, 1); break; \
    case 1:  SET_TAIL(0, 0); break; \
    }
119+
120+
31121int CNAME (BLASLONG m , BLASLONG n , IFLOAT * a , BLASLONG lda , IFLOAT * b ){
122+ BLASLONG i , j ;
123+
124+ IFLOAT * boffset0 ;
125+ IFLOAT * aoffset ;
126+ IFLOAT * aoffset00 , * aoffset01 , * aoffset02 , * aoffset03 , * aoffset04 , * aoffset05 , * aoffset06 , * aoffset07 ;
127+ IFLOAT * aoffset10 , * aoffset11 , * aoffset12 , * aoffset13 , * aoffset14 , * aoffset15 , * aoffset16 , * aoffset17 ;
128+ aoffset = a ;
129+ boffset0 = b ;
130+
131+ BLASLONG n16 = n & ~15 ;
132+ BLASLONG m32 = m & ~31 ;
133+
134+ int permute_table [] = {
135+ 0x0 , 0x1 , 0x2 , 0x3 , 0x10 , 0x11 , 0x12 , 0x13 , 0x8 , 0x9 , 0xa , 0xb , 0x18 , 0x19 , 0x1a , 0x1b ,
136+ 0x4 , 0x5 , 0x6 , 0x7 , 0x14 , 0x15 , 0x16 , 0x17 , 0xc , 0xd , 0xe , 0xf , 0x1c , 0x1d , 0x1e , 0x1f ,
137+ };
138+ u_int64_t permute_table2 [] = {
139+ 0x00 , 0x01 , 0x02 , 0x03 , 8 |0x0 , 8 |0x1 , 8 |0x2 , 8 |0x3 ,
140+ 0x04 , 0x05 , 0x06 , 0x07 , 8 |0x4 , 8 |0x5 , 8 |0x6 , 8 |0x7 ,
141+ };
142+ __m512i idx_lo = _mm512_loadu_si512 (permute_table );
143+ __m512i idx_hi = _mm512_loadu_si512 (permute_table + 16 );
144+ __m512i idx_lo2 = _mm512_loadu_si512 (permute_table2 );
145+ __m512i idx_hi2 = _mm512_loadu_si512 (permute_table2 + 8 );
146+ __mmask16 kc = 0xcccc ;
147+ __mmask16 k3 = 0x3333 ;
148+ __m512i r0 , r1 , r2 , r3 , r4 , r5 , r6 , r7 ;
149+ __m512i t00 , t01 , t02 , t03 , t04 , t05 , t06 , t07 ;
150+ __m512i t10 , t11 , t12 , t13 , t14 , t15 , t16 , t17 ;
151+
152+ for (j = 0 ; j < n16 ; j += 16 ) {
153+ aoffset00 = aoffset ;
154+ aoffset01 = aoffset00 + lda ;
155+ aoffset02 = aoffset01 + lda ;
156+ aoffset03 = aoffset02 + lda ;
157+ aoffset04 = aoffset03 + lda ;
158+ aoffset05 = aoffset04 + lda ;
159+ aoffset06 = aoffset05 + lda ;
160+ aoffset07 = aoffset06 + lda ;
161+ aoffset10 = aoffset07 + lda ;
162+ aoffset11 = aoffset10 + lda ;
163+ aoffset12 = aoffset11 + lda ;
164+ aoffset13 = aoffset12 + lda ;
165+ aoffset14 = aoffset13 + lda ;
166+ aoffset15 = aoffset14 + lda ;
167+ aoffset16 = aoffset15 + lda ;
168+ aoffset17 = aoffset16 + lda ;
169+ aoffset += 16 * lda ;
170+ for (i = 0 ; i < m32 ; i += 32 ) {
171+ r0 = _mm512_loadu_si512 (aoffset00 + i );
172+ r1 = _mm512_loadu_si512 (aoffset01 + i );
173+ r2 = _mm512_loadu_si512 (aoffset02 + i );
174+ r3 = _mm512_loadu_si512 (aoffset03 + i );
175+ r4 = _mm512_loadu_si512 (aoffset04 + i );
176+ r5 = _mm512_loadu_si512 (aoffset05 + i );
177+ r6 = _mm512_loadu_si512 (aoffset06 + i );
178+ r7 = _mm512_loadu_si512 (aoffset07 + i );
179+ REORDER_8x32 (t00 , t01 , t02 , t03 , t04 , t05 , t06 , t07 );
180+ r0 = _mm512_loadu_si512 (aoffset10 + i );
181+ r1 = _mm512_loadu_si512 (aoffset11 + i );
182+ r2 = _mm512_loadu_si512 (aoffset12 + i );
183+ r3 = _mm512_loadu_si512 (aoffset13 + i );
184+ r4 = _mm512_loadu_si512 (aoffset14 + i );
185+ r5 = _mm512_loadu_si512 (aoffset15 + i );
186+ r6 = _mm512_loadu_si512 (aoffset16 + i );
187+ r7 = _mm512_loadu_si512 (aoffset17 + i );
188+ REORDER_8x32 (t10 , t11 , t12 , t13 , t14 , t15 , t16 , t17 );
189+ STORE_512 (0 , 0 ); STORE_512 (0 , 1 ); STORE_512 (0 , 2 ); STORE_512 (0 , 3 );
190+ STORE_512 (0 , 4 ); STORE_512 (0 , 5 ); STORE_512 (0 , 6 ); STORE_512 (0 , 7 );
191+ STORE_512 (1 , 0 ); STORE_512 (1 , 1 ); STORE_512 (1 , 2 ); STORE_512 (1 , 3 );
192+ STORE_512 (1 , 4 ); STORE_512 (1 , 5 ); STORE_512 (1 , 6 ); STORE_512 (1 , 7 );
193+ boffset0 += 16 * 32 ;
194+ }
195+ if (i < m ) {
196+ int remain_m = m - i ;
197+ __mmask32 mmask = (1UL << remain_m ) - 1 ;
198+ r0 = _mm512_maskz_loadu_epi16 (mmask , aoffset00 + i );
199+ r1 = _mm512_maskz_loadu_epi16 (mmask , aoffset01 + i );
200+ r2 = _mm512_maskz_loadu_epi16 (mmask , aoffset02 + i );
201+ r3 = _mm512_maskz_loadu_epi16 (mmask , aoffset03 + i );
202+ r4 = _mm512_maskz_loadu_epi16 (mmask , aoffset04 + i );
203+ r5 = _mm512_maskz_loadu_epi16 (mmask , aoffset05 + i );
204+ r6 = _mm512_maskz_loadu_epi16 (mmask , aoffset06 + i );
205+ r7 = _mm512_maskz_loadu_epi16 (mmask , aoffset07 + i );
206+ REORDER_8x32 (t00 , t01 , t02 , t03 , t04 , t05 , t06 , t07 );
207+ r0 = _mm512_maskz_loadu_epi16 (mmask , aoffset10 + i );
208+ r1 = _mm512_maskz_loadu_epi16 (mmask , aoffset11 + i );
209+ r2 = _mm512_maskz_loadu_epi16 (mmask , aoffset12 + i );
210+ r3 = _mm512_maskz_loadu_epi16 (mmask , aoffset13 + i );
211+ r4 = _mm512_maskz_loadu_epi16 (mmask , aoffset14 + i );
212+ r5 = _mm512_maskz_loadu_epi16 (mmask , aoffset15 + i );
213+ r6 = _mm512_maskz_loadu_epi16 (mmask , aoffset16 + i );
214+ r7 = _mm512_maskz_loadu_epi16 (mmask , aoffset17 + i );
215+ REORDER_8x32 (t10 , t11 , t12 , t13 , t14 , t15 , t16 , t17 );
216+ int n_store = remain_m /2 ;
217+ switch (n_store ) {
218+ case 15 : STORE_512 (1 , 6 );
219+ case 14 : STORE_512 (1 , 5 );
220+ case 13 : STORE_512 (1 , 4 );
221+ case 12 : STORE_512 (1 , 3 );
222+ case 11 : STORE_512 (1 , 2 );
223+ case 10 : STORE_512 (1 , 1 );
224+ case 9 : STORE_512 (1 , 0 );
225+ case 8 : STORE_512 (0 , 7 );
226+ case 7 : STORE_512 (0 , 6 );
227+ case 6 : STORE_512 (0 , 5 );
228+ case 5 : STORE_512 (0 , 4 );
229+ case 4 : STORE_512 (0 , 3 );
230+ case 3 : STORE_512 (0 , 2 );
231+ case 2 : STORE_512 (0 , 1 );
232+ case 1 : STORE_512 (0 , 0 );
233+ }
234+ boffset0 += n_store * 32 ;
235+ if (m & 0x1 ) {
236+ __m512i tail ;
237+ GET_TAIL ();
238+ _mm256_storeu_si256 ((void * )boffset0 , _mm512_cvtepi32_epi16 (tail ));
239+ boffset0 += 16 ;
240+ }
241+ }
32242
243+ }
244+ if (j < n ) {
245+ int remain_n = n - j ;
246+ __mmask16 nmask = (1UL << remain_n ) - 1 ;
247+ int load0 , load1 ;
248+ if (remain_n > 8 ) {
249+ load0 = 8 ;
250+ load1 = remain_n - 8 ;
251+ } else {
252+ load0 = remain_n ;
253+ load1 = 0 ;
254+ }
255+ aoffset00 = aoffset ;
256+ aoffset01 = aoffset00 + lda ;
257+ aoffset02 = aoffset01 + lda ;
258+ aoffset03 = aoffset02 + lda ;
259+ aoffset04 = aoffset03 + lda ;
260+ aoffset05 = aoffset04 + lda ;
261+ aoffset06 = aoffset05 + lda ;
262+ aoffset07 = aoffset06 + lda ;
263+ aoffset10 = aoffset07 + lda ;
264+ aoffset11 = aoffset10 + lda ;
265+ aoffset12 = aoffset11 + lda ;
266+ aoffset13 = aoffset12 + lda ;
267+ aoffset14 = aoffset13 + lda ;
268+ aoffset15 = aoffset14 + lda ;
269+ aoffset16 = aoffset15 + lda ;
270+ aoffset17 = aoffset16 + lda ;
271+ aoffset += 16 * lda ;
272+ for (i = 0 ; i < m32 ; i += 32 ) {
273+ switch (load0 ) {
274+ case 8 : r7 = _mm512_loadu_si512 (aoffset07 + i );
275+ case 7 : r6 = _mm512_loadu_si512 (aoffset06 + i );
276+ case 6 : r5 = _mm512_loadu_si512 (aoffset05 + i );
277+ case 5 : r4 = _mm512_loadu_si512 (aoffset04 + i );
278+ case 4 : r3 = _mm512_loadu_si512 (aoffset03 + i );
279+ case 3 : r2 = _mm512_loadu_si512 (aoffset02 + i );
280+ case 2 : r1 = _mm512_loadu_si512 (aoffset01 + i );
281+ case 1 : r0 = _mm512_loadu_si512 (aoffset00 + i );
282+ }
283+ REORDER_8x32 (t00 , t01 , t02 , t03 , t04 , t05 , t06 , t07 );
284+ switch (load1 ) {
285+ case 8 : r7 = _mm512_loadu_si512 (aoffset17 + i );
286+ case 7 : r6 = _mm512_loadu_si512 (aoffset16 + i );
287+ case 6 : r5 = _mm512_loadu_si512 (aoffset15 + i );
288+ case 5 : r4 = _mm512_loadu_si512 (aoffset14 + i );
289+ case 4 : r3 = _mm512_loadu_si512 (aoffset13 + i );
290+ case 3 : r2 = _mm512_loadu_si512 (aoffset12 + i );
291+ case 2 : r1 = _mm512_loadu_si512 (aoffset11 + i );
292+ case 1 : r0 = _mm512_loadu_si512 (aoffset10 + i );
293+ }
294+ REORDER_8x32 (t10 , t11 , t12 , t13 , t14 , t15 , t16 , t17 );
295+ MASK_STORE_512 (0 , 0 ); MASK_STORE_512 (0 , 1 ); MASK_STORE_512 (0 , 2 ); MASK_STORE_512 (0 , 3 );
296+ MASK_STORE_512 (0 , 4 ); MASK_STORE_512 (0 , 5 ); MASK_STORE_512 (0 , 6 ); MASK_STORE_512 (0 , 7 );
297+ MASK_STORE_512 (1 , 0 ); MASK_STORE_512 (1 , 1 ); MASK_STORE_512 (1 , 2 ); MASK_STORE_512 (1 , 3 );
298+ MASK_STORE_512 (1 , 4 ); MASK_STORE_512 (1 , 5 ); MASK_STORE_512 (1 , 6 ); MASK_STORE_512 (1 , 7 );
299+ boffset0 += remain_n * 32 ;
300+ }
301+ if (i < m ) {
302+ int remain_m = m - i ;
303+ __mmask32 mmask = (1UL << remain_m ) - 1 ;
304+ switch (load0 ) {
305+ case 8 : r7 = _mm512_maskz_loadu_epi16 (mmask , aoffset07 + i );
306+ case 7 : r6 = _mm512_maskz_loadu_epi16 (mmask , aoffset06 + i );
307+ case 6 : r5 = _mm512_maskz_loadu_epi16 (mmask , aoffset05 + i );
308+ case 5 : r4 = _mm512_maskz_loadu_epi16 (mmask , aoffset04 + i );
309+ case 4 : r3 = _mm512_maskz_loadu_epi16 (mmask , aoffset03 + i );
310+ case 3 : r2 = _mm512_maskz_loadu_epi16 (mmask , aoffset02 + i );
311+ case 2 : r1 = _mm512_maskz_loadu_epi16 (mmask , aoffset01 + i );
312+ case 1 : r0 = _mm512_maskz_loadu_epi16 (mmask , aoffset00 + i );
313+ }
314+ REORDER_8x32 (t00 , t01 , t02 , t03 , t04 , t05 , t06 , t07 );
315+ switch (load1 ) {
316+ case 8 : r7 = _mm512_maskz_loadu_epi16 (mmask , aoffset17 + i );
317+ case 7 : r6 = _mm512_maskz_loadu_epi16 (mmask , aoffset16 + i );
318+ case 6 : r5 = _mm512_maskz_loadu_epi16 (mmask , aoffset15 + i );
319+ case 5 : r4 = _mm512_maskz_loadu_epi16 (mmask , aoffset14 + i );
320+ case 4 : r3 = _mm512_maskz_loadu_epi16 (mmask , aoffset13 + i );
321+ case 3 : r2 = _mm512_maskz_loadu_epi16 (mmask , aoffset12 + i );
322+ case 2 : r1 = _mm512_maskz_loadu_epi16 (mmask , aoffset11 + i );
323+ case 1 : r0 = _mm512_maskz_loadu_epi16 (mmask , aoffset10 + i );
324+ }
325+ REORDER_8x32 (t10 , t11 , t12 , t13 , t14 , t15 , t16 , t17 );
326+ int n_store = remain_m /2 ;
327+ switch (n_store ) {
328+ case 15 : MASK_STORE_512 (1 , 6 );
329+ case 14 : MASK_STORE_512 (1 , 5 );
330+ case 13 : MASK_STORE_512 (1 , 4 );
331+ case 12 : MASK_STORE_512 (1 , 3 );
332+ case 11 : MASK_STORE_512 (1 , 2 );
333+ case 10 : MASK_STORE_512 (1 , 1 );
334+ case 9 : MASK_STORE_512 (1 , 0 );
335+ case 8 : MASK_STORE_512 (0 , 7 );
336+ case 7 : MASK_STORE_512 (0 , 6 );
337+ case 6 : MASK_STORE_512 (0 , 5 );
338+ case 5 : MASK_STORE_512 (0 , 4 );
339+ case 4 : MASK_STORE_512 (0 , 3 );
340+ case 3 : MASK_STORE_512 (0 , 2 );
341+ case 2 : MASK_STORE_512 (0 , 1 );
342+ case 1 : MASK_STORE_512 (0 , 0 );
343+ }
344+ boffset0 += n_store * remain_n * 2 ;
345+ if (m & 0x1 ) {
346+ __m512i tail ;
347+ GET_TAIL ();
348+ _mm256_mask_storeu_epi16 ((void * )boffset0 , nmask , _mm512_cvtepi32_epi16 (tail ));
349+ }
350+ }
351+ }
352+ return 0 ;
33353}