@@ -46,83 +46,50 @@ typedef union
4646 } bits ;
4747} bfloat16_bits ;
4848
49- typedef union
50- {
51- float v ;
52- struct
53- {
54- uint32_t m :23 ;
55- uint32_t e :8 ;
56- uint32_t s :1 ;
57- } bits ;
58- } float32_bits ;
59-
60- float
61- float16to32 (bfloat16_bits f16 )
62- {
63- float32_bits f32 ;
64- f32 .bits .s = f16 .bits .s ;
65- f32 .bits .e = f16 .bits .e ;
66- f32 .bits .m = (uint32_t ) f16 .bits .m << 16 ;
67- return f32 .v ;
68- }
69-
7049int
7150main (int argc , char * argv [])
7251{
7352 int m , n , k ;
7453 int i , j , l ;
54+ int x ;
7555 int ret = 0 ;
7656 int loop = 100 ;
7757 char transA = 'N' , transB = 'N' ;
7858 float alpha = 1.0 , beta = 0.0 ;
59+ char transa = 'N' ;
60+ char transb = 'N' ;
7961
80- for (int x = 0 ; x <= loop ; x ++ )
62+ for (x = 0 ; x <= loop ; x ++ )
8163 {
8264 m = k = n = x ;
8365 float A [m * k ];
8466 float B [k * n ];
8567 float C [m * n ];
8668 bfloat16_bits AA [m * k ], BB [k * n ];
87- float DD [ m * n ], CC [m * n ];
69+ float CC [m * n ];
8870
89- for (int j = 0 ; j < m ; j ++ )
71+ for (j = 0 ; j < m ; j ++ )
9072 {
91- for (int i = 0 ; i < m ; i ++ )
73+ for (i = 0 ; i < m ; i ++ )
9274 {
93- A [j * k + i ] = ((FLOAT ) rand () / (FLOAT ) RAND_MAX ) + 0.5 ;
94- B [j * k + i ] = ((FLOAT ) rand () / (FLOAT ) RAND_MAX ) + 0.5 ;
75+ A [j * k + i ] = ((FLOAT ) rand () / (FLOAT ) RAND_MAX ) + 0.5 ;
76+ B [j * k + i ] = ((FLOAT ) rand () / (FLOAT ) RAND_MAX ) + 0.5 ;
9577 C [j * k + i ] = 0 ;
9678 AA [j * k + i ].v = * (uint32_t * ) & A [j * k + i ] >> 16 ;
9779 BB [j * k + i ].v = * (uint32_t * ) & B [j * k + i ] >> 16 ;
9880 CC [j * k + i ] = 0 ;
99- DD [j * k + i ] = 0 ;
10081 }
10182 }
10283 SGEMM (& transA , & transB , & m , & n , & k , & alpha , A ,
103- & m , B , & k , & beta , C , & m );
84+ & m , B , & k , & beta , C , & m );
10485 SHGEMM (& transA , & transB , & m , & n , & k , & alpha , AA ,
105- & m , BB , & k , & beta , CC , & m );
86+ & m , BB , & k , & beta , CC , & m );
87+
10688 for (i = 0 ; i < n ; i ++ )
107- for (j = 0 ; j < m ; j ++ )
108- for (l = 0 ; l < k ; l ++ )
109- if (fabs (CC [i * m + j ] - C [i * m + j ]) > 1.0 )
110- ret ++ ;
111- if (transA == 'N' && transB == 'N' )
112- {
113- for (i = 0 ; i < n ; i ++ )
114- for (j = 0 ; j < m ; j ++ )
115- for (l = 0 ; l < k ; l ++ )
116- {
117- DD [i * m + j ] +=
118- float16to32 (AA [l * m + j ]) * float16to32 (BB [l + k * i ]);
119- }
120- for (i = 0 ; i < n ; i ++ )
121- for (j = 0 ; j < m ; j ++ )
122- for (l = 0 ; l < k ; l ++ )
123- if (CC [i * m + j ] != DD [i * m + j ])
124- ret ++ ;
125- }
89+ for (j = 0 ; j < m ; j ++ )
90+ for (l = 0 ; l < k ; l ++ )
91+ if (fabs (CC [i * m + j ]- C [i * m + j ]) > 1.0 )
92+ ret ++ ;
12693 }
12794 if (ret != 0 )
12895 fprintf (stderr , "FATAL ERROR SHGEMM - Return code: %d\n" , ret );
0 commit comments