@@ -46,6 +46,27 @@ typedef union
4646 } bits ;
4747} bfloat16_bits ;
4848
/*
 * Overlay of a float with its sign (s), biased exponent (e) and
 * mantissa (m) fields, assuming float is IEEE-754 binary32.
 *
 * The order in which bit-fields are allocated inside a storage unit is
 * implementation-defined.  The m/e/s order only lines up with the
 * binary32 layout when allocation starts at the least significant bit,
 * which is the usual little-endian convention.  When the compiler
 * reports a big-endian target the declaration order is reversed so that
 * s still maps to bit 31, e to bits 30..23 and m to bits 22..0.
 * NOTE(review): this relies on the GCC/Clang byte-order macros; other
 * compilers fall back to the little-endian order.
 */
typedef union
{
  float v;
  struct
  {
#if defined(__BYTE_ORDER__) && defined(__ORDER_BIG_ENDIAN__) \
    && (__BYTE_ORDER__ == __ORDER_BIG_ENDIAN__)
    uint32_t s:1;
    uint32_t e:8;
    uint32_t m:23;
#else
    uint32_t m:23;
    uint32_t e:8;
    uint32_t s:1;
#endif
  } bits;
} float32_bits;
59+
60+ float
61+ float16to32 (bfloat16_bits f16 )
62+ {
63+ float32_bits f32 ;
64+ f32 .bits .s = f16 .bits .s ;
65+ f32 .bits .e = f16 .bits .e ;
66+ f32 .bits .m = (uint32_t ) f16 .bits .m << 16 ;
67+ return f32 .v ;
68+ }
69+
4970int
5071main (int argc , char * argv [])
5172{
@@ -56,8 +77,6 @@ main (int argc, char *argv[])
5677 int loop = 100 ;
5778 char transA = 'N' , transB = 'N' ;
5879 float alpha = 1.0 , beta = 0.0 ;
59- char transa = 'N' ;
60- char transb = 'N' ;
6180
6281 for (x = 0 ; x <= loop ; x ++ )
6382 {
@@ -66,30 +85,45 @@ main (int argc, char *argv[])
6685 float B [k * n ];
6786 float C [m * n ];
6887 bfloat16_bits AA [m * k ], BB [k * n ];
69- float CC [m * n ];
88+ float DD [ m * n ], CC [m * n ];
7089
7190 for (j = 0 ; j < m ; j ++ )
7291 {
7392 for (i = 0 ; i < m ; i ++ )
7493 {
75- A [j * k + i ] = ((FLOAT ) rand () / (FLOAT ) RAND_MAX ) + 0.5 ;
76- B [j * k + i ] = ((FLOAT ) rand () / (FLOAT ) RAND_MAX ) + 0.5 ;
94+ A [j * k + i ] = ((FLOAT ) rand () / (FLOAT ) RAND_MAX ) + 0.5 ;
95+ B [j * k + i ] = ((FLOAT ) rand () / (FLOAT ) RAND_MAX ) + 0.5 ;
7796 C [j * k + i ] = 0 ;
7897 AA [j * k + i ].v = * (uint32_t * ) & A [j * k + i ] >> 16 ;
7998 BB [j * k + i ].v = * (uint32_t * ) & B [j * k + i ] >> 16 ;
8099 CC [j * k + i ] = 0 ;
100+ DD [j * k + i ] = 0 ;
81101 }
82102 }
83103 SGEMM (& transA , & transB , & m , & n , & k , & alpha , A ,
84- & m , B , & k , & beta , C , & m );
104+ & m , B , & k , & beta , C , & m );
85105 SHGEMM (& transA , & transB , & m , & n , & k , & alpha , AA ,
86- & m , BB , & k , & beta , CC , & m );
87-
106+ & m , BB , & k , & beta , CC , & m );
88107 for (i = 0 ; i < n ; i ++ )
89- for (j = 0 ; j < m ; j ++ )
90- for (l = 0 ; l < k ; l ++ )
91- if (fabs (CC [i * m + j ]- C [i * m + j ]) > 1.0 )
92- ret ++ ;
108+ for (j = 0 ; j < m ; j ++ )
109+ for (l = 0 ; l < k ; l ++ )
110+ if (fabs (CC [i * m + j ] - C [i * m + j ]) > 1.0 )
111+ ret ++ ;
112+ if (transA == 'N' && transB == 'N' )
113+ {
114+ for (i = 0 ; i < n ; i ++ )
115+ for (j = 0 ; j < m ; j ++ )
116+ for (l = 0 ; l < k ; l ++ )
117+ {
118+ DD [i * m + j ] +=
119+ float16to32 (AA [l * m + j ]) * float16to32 (BB [l + k * i ]);
120+ }
121+ for (i = 0 ; i < n ; i ++ )
122+ for (j = 0 ; j < m ; j ++ )
123+ for (l = 0 ; l < k ; l ++ )
124+ if (CC [i * m + j ] != DD [i * m + j ])
125+ ret ++ ;
126+ }
93127 }
94128 if (ret != 0 )
95129 fprintf (stderr , "FATAL ERROR SHGEMM - Return code: %d\n" , ret );
0 commit comments