@@ -81,6 +81,16 @@ float16to32 (bfloat16_bits f16)
8181 return f32 .v ;
8282}
8383
84+ float
85+ float32to16 (float32_bits f32 )
86+ {
87+ bfloat16_bits f16 ;
88+ f16 .bits .s = f32 .bits .s ;
89+ f16 .bits .e = f32 .bits .e ;
90+ f16 .bits .m = (uint32_t ) f32 .bits .m >> 16 ;
91+ return f32 .v ;
92+ }
93+
8494int
8595main (int argc , char * argv [])
8696{
@@ -108,16 +118,16 @@ main (int argc, char *argv[])
108118 A [j * k + i ] = ((FLOAT ) rand () / (FLOAT ) RAND_MAX ) + 0.5 ;
109119 B [j * k + i ] = ((FLOAT ) rand () / (FLOAT ) RAND_MAX ) + 0.5 ;
110120 C [j * k + i ] = 0 ;
111- AA [j * k + i ].v = * ( uint32_t * ) & A [j * k + i ] >> 16 ;
112- BB [j * k + i ].v = * ( uint32_t * ) & B [j * k + i ] >> 16 ;
121+ AA [j * k + i ].v = float32to16 ( A [j * k + i ] ) ;
122+ BB [j * k + i ].v = float32to16 ( B [j * k + i ] ) ;
113123 CC [j * k + i ] = 0 ;
114124 DD [j * k + i ] = 0 ;
115125 }
116126 }
117127 SGEMM (& transA , & transB , & m , & n , & k , & alpha , A ,
118128 & m , B , & k , & beta , C , & m );
119- SBGEMM (& transA , & transB , & m , & n , & k , & alpha , AA ,
120- & m , BB , & k , & beta , CC , & m );
129+ SBGEMM (& transA , & transB , & m , & n , & k , & alpha , ( bfloat16 * ) AA ,
130+ & m , ( bfloat16 * ) BB , & k , & beta , CC , & m );
121131 for (i = 0 ; i < n ; i ++ )
122132 for (j = 0 ; j < m ; j ++ )
123133 if (fabs (CC [i * m + j ] - C [i * m + j ]) > 1.0 )
0 commit comments