@@ -29,20 +29,117 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
2929#include <immintrin.h>
3030#include "common.h"
3131
/*
 * STORE_VEC(Bx, By, vec): store one 512-bit vector of 16-bit elements into
 * column-block By of the output panel addressed by boffset<Bx>.
 *
 * Bx  - selects the destination base pointer via token pasting: boffset0 or
 *       boffset1 (both must be in scope at the call site).
 * By  - compile-time block index (0, 1 or 2).  The effective address is
 *       boffset<Bx> + blk_size * By elements; the scalar tail code in the
 *       caller writes to exactly the same offsets (boffset0 + blk_size * 0/1/2).
 * vec - a __m512i register operand ("v" constraint).
 *
 * For By != 0 the store uses scaled addressing (%1, %2, %c3):
 * base + blk_size * (By * 2) bytes.  Elements are 2 bytes wide, so the
 * byte scale By*2 yields blk_size*By in elements; By in {1, 2} keeps the
 * scale at 2 or 4, both legal x86 address scales.  %c3 prints the constant
 * without the '$' prefix so it is usable as an address scale.
 *
 * NOTE: the asm has no output operand, so a "memory" clobber is declared to
 * tell the compiler this instruction writes memory; without it the store
 * could legally be reordered against surrounding loads/stores.
 */
#define STORE_VEC(Bx, By, vec)                                               \
    do {                                                                     \
        if (By == 0)                                                         \
            asm("vmovdqu16 %0, (%1)"                                         \
                : : "v"(vec), "r"(boffset##Bx)                               \
                : "memory");                                                 \
        else                                                                 \
            asm("vmovdqu16 %0, (%1, %2, %c3)"                                \
                : : "v"(vec), "r"(boffset##Bx), "r"(blk_size), "n"(By * 2)   \
                : "memory");                                                 \
    } while (0)
35+
3236int CNAME (BLASLONG m , BLASLONG n , IFLOAT * a , BLASLONG lda , IFLOAT * b ){
3337 BLASLONG i , j ;
3438
3539 IFLOAT * boffset0 , * boffset1 ;
3640
3741 boffset0 = b ;
3842
43+ BLASLONG n24 = n - (n % 24 );
3944 BLASLONG n8 = n & ~7 ;
45+ BLASLONG m8 = m & ~7 ;
4046 BLASLONG m4 = m & ~3 ;
4147 BLASLONG m2 = m & ~1 ;
4248
43- for (j = 0 ; j < n8 ; j += 8 ) {
49+ int permute_table [] = {
50+ 0x0 , 0x1 , 0x2 , 0x3 , 0x10 , 0x11 , 0x12 , 0x13 , 0x8 , 0x9 , 0xa , 0xb , 0x18 , 0x19 , 0x1a , 0x1b ,
51+ 0x4 , 0x5 , 0x6 , 0x7 , 0x14 , 0x15 , 0x16 , 0x17 , 0xc , 0xd , 0xe , 0xf , 0x1c , 0x1d , 0x1e , 0x1f ,
52+ 0x0 , 0x1 , 0x2 , 0x3 , 0x4 , 0x5 , 0x6 , 0x7 , 0x10 , 0x11 , 0x12 , 0x13 , 0x14 , 0x15 , 0x16 , 0x17 ,
53+ 0x8 , 0x9 , 0xa , 0xb , 0xc , 0xd , 0xe , 0xf , 0x18 , 0x19 , 0x1a , 0x1b , 0x1c , 0x1d , 0x1e , 0x1f ,
54+ };
55+
56+ j = 0 ;
57+ if (n > 23 ) {
58+ /* n = 24 is the max width in current blocking setting */
59+ __m512i idx_lo_128 = _mm512_loadu_si512 (permute_table );
60+ __m512i idx_hi_128 = _mm512_loadu_si512 (permute_table + 16 );
61+ __m512i idx_lo_256 = _mm512_loadu_si512 (permute_table + 32 );
62+ __m512i idx_hi_256 = _mm512_loadu_si512 (permute_table + 48 );
63+ __mmask32 mask24 = (1UL << 24 ) - 1 ;
64+ BLASLONG blk_size = m * 4 ;
65+ BLASLONG stride = blk_size * 3 ;
66+
67+ for (; j < n24 ; j += 24 ) {
68+ boffset1 = boffset0 + stride ;
69+ for (i = 0 ; i < m8 ; i += 8 ) {
70+ __m512i r0 , r1 , r2 , r3 , r4 , r5 , r6 , r7 ;
71+ __m512i t0 , t1 , t2 , t3 , t4 , t5 , t6 , t7 ;
72+ r0 = _mm512_maskz_loadu_epi16 (mask24 , & a [(i + 0 )* lda + j ]);
73+ r1 = _mm512_maskz_loadu_epi16 (mask24 , & a [(i + 1 )* lda + j ]);
74+ r2 = _mm512_maskz_loadu_epi16 (mask24 , & a [(i + 2 )* lda + j ]);
75+ r3 = _mm512_maskz_loadu_epi16 (mask24 , & a [(i + 3 )* lda + j ]);
76+ r4 = _mm512_maskz_loadu_epi16 (mask24 , & a [(i + 4 )* lda + j ]);
77+ r5 = _mm512_maskz_loadu_epi16 (mask24 , & a [(i + 5 )* lda + j ]);
78+ r6 = _mm512_maskz_loadu_epi16 (mask24 , & a [(i + 6 )* lda + j ]);
79+ r7 = _mm512_maskz_loadu_epi16 (mask24 , & a [(i + 7 )* lda + j ]);
80+
81+ t0 = _mm512_unpacklo_epi16 (r0 , r1 );
82+ t1 = _mm512_unpackhi_epi16 (r0 , r1 );
83+ t2 = _mm512_unpacklo_epi16 (r2 , r3 );
84+ t3 = _mm512_unpackhi_epi16 (r2 , r3 );
85+ t4 = _mm512_unpacklo_epi16 (r4 , r5 );
86+ t5 = _mm512_unpackhi_epi16 (r4 , r5 );
87+ t6 = _mm512_unpacklo_epi16 (r6 , r7 );
88+ t7 = _mm512_unpackhi_epi16 (r6 , r7 );
89+
90+ r0 = _mm512_permutex2var_epi32 (t0 , idx_lo_128 , t2 );
91+ r1 = _mm512_permutex2var_epi32 (t1 , idx_lo_128 , t3 );
92+ r2 = _mm512_permutex2var_epi32 (t4 , idx_lo_128 , t6 );
93+ r3 = _mm512_permutex2var_epi32 (t5 , idx_lo_128 , t7 );
94+ r4 = _mm512_permutex2var_epi32 (t0 , idx_hi_128 , t2 );
95+ r5 = _mm512_permutex2var_epi32 (t1 , idx_hi_128 , t3 );
96+ r6 = _mm512_permutex2var_epi32 (t4 , idx_hi_128 , t6 );
97+ r7 = _mm512_permutex2var_epi32 (t5 , idx_hi_128 , t7 );
98+
99+ t0 = _mm512_permutex2var_epi32 (r0 , idx_lo_256 , r2 );
100+ t1 = _mm512_permutex2var_epi32 (r1 , idx_lo_256 , r3 );
101+ t2 = _mm512_permutex2var_epi32 (r4 , idx_lo_256 , r6 );
102+ t3 = _mm512_permutex2var_epi32 (r5 , idx_lo_256 , r7 );
103+ t4 = _mm512_permutex2var_epi32 (r0 , idx_hi_256 , r2 );
104+ t5 = _mm512_permutex2var_epi32 (r1 , idx_hi_256 , r3 );
105+
106+ STORE_VEC (0 , 0 , t0 ); STORE_VEC (0 , 1 , t1 ); STORE_VEC (0 , 2 , t2 );
107+ STORE_VEC (1 , 0 , t3 ); STORE_VEC (1 , 1 , t4 ); STORE_VEC (1 , 2 , t5 );
108+ boffset0 += 32 ;
109+ boffset1 += 32 ;
110+ }
111+ for (; i < m2 ; i += 2 ) {
112+ __m512i r0 , r1 , t0 , t1 ;
113+ r0 = _mm512_maskz_loadu_epi16 (mask24 , & a [(i + 0 )* lda + j ]);
114+ r1 = _mm512_maskz_loadu_epi16 (mask24 , & a [(i + 1 )* lda + j ]);
115+ t0 = _mm512_unpacklo_epi16 (r0 , r1 );
116+ t1 = _mm512_unpackhi_epi16 (r0 , r1 );
117+ STORE_VEC (0 , 0 , _mm512_extracti32x4_epi32 (t0 , 0 ));
118+ STORE_VEC (0 , 1 , _mm512_extracti32x4_epi32 (t1 , 0 ));
119+ STORE_VEC (0 , 2 , _mm512_extracti32x4_epi32 (t0 , 1 ));
120+ STORE_VEC (1 , 0 , _mm512_extracti32x4_epi32 (t1 , 1 ));
121+ STORE_VEC (1 , 1 , _mm512_extracti32x4_epi32 (t0 , 2 ));
122+ STORE_VEC (1 , 2 , _mm512_extracti32x4_epi32 (t1 , 2 ));
123+ boffset0 += 8 ;
124+ boffset1 += 8 ;
125+ }
126+ for (; i < m ; i ++ ) {
127+ * (uint64_t * )(boffset0 + blk_size * 0 ) = * (uint64_t * )& a [i * lda + j + 0 ];
128+ * (uint64_t * )(boffset0 + blk_size * 1 ) = * (uint64_t * )& a [i * lda + j + 4 ];
129+ * (uint64_t * )(boffset0 + blk_size * 2 ) = * (uint64_t * )& a [i * lda + j + 8 ];
130+ * (uint64_t * )(boffset1 + blk_size * 0 ) = * (uint64_t * )& a [i * lda + j + 12 ];
131+ * (uint64_t * )(boffset1 + blk_size * 1 ) = * (uint64_t * )& a [i * lda + j + 16 ];
132+ * (uint64_t * )(boffset1 + blk_size * 2 ) = * (uint64_t * )& a [i * lda + j + 20 ];
133+ boffset0 += 4 ;
134+ boffset1 += 4 ;
135+ }
136+ boffset0 += stride * 2 ;
137+ }
138+ }
139+
140+ for (; j < n8 ; j += 8 ) {
44141 boffset1 = boffset0 + m * 4 ;
45- for (i = 0 ; i < m4 ; i += 4 ) {
142+ for (i = 0 ; i < m4 ; i += 4 ) {
46143 __m128i a0 = _mm_loadu_si128 ((void * )& a [(i + 0 )* lda + j ]);
47144 __m128i a1 = _mm_loadu_si128 ((void * )& a [(i + 1 )* lda + j ]);
48145 __m128i a2 = _mm_loadu_si128 ((void * )& a [(i + 2 )* lda + j ]);
0 commit comments