Skip to content

Commit f6d4fe7

Browse files
committed
Fix incorrect cast from BF16 to FP32 in SBGEMM
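This change fixes a regression in SBGEMM where the kernel assumed C was BF16 and therefore unconditionally applied the BF16-to-FP32 conversion (TO_F32) to the existing C values before accumulating, resulting in incorrect outputs when beta=1. The conversion is now applied to C only when BGEMM is defined.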
1 parent 1f1fcd4 commit f6d4fe7

1 file changed

Lines changed: 15 additions & 9 deletions

File tree

kernel/generic/gemmkernel_2x2.c

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,12 @@
3030

3131
#include "conversion_macros.h"
3232

33+
#ifdef BGEMM
34+
#define C_TO_F32 TO_F32
35+
#else
36+
#define C_TO_F32
37+
#endif
38+
3339
int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alpha,IFLOAT* ba,IFLOAT* bb,FLOAT* C,BLASLONG ldc
3440
#ifdef TRMMKERNEL
3541
,BLASLONG offset
@@ -108,13 +114,13 @@ int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alpha,IFLOAT* ba,IFLOAT* bb,
108114
ptrbb = ptrbb+2;
109115
}
110116
res0 = res0*ALPHA;
111-
C0[0] = TO_OUTPUT(TO_F32(C0[0])+res0);
117+
C0[0] = TO_OUTPUT(C_TO_F32(C0[0])+res0);
112118
res1 = res1*ALPHA;
113-
C0[1] = TO_OUTPUT(TO_F32(C0[1])+res1);
119+
C0[1] = TO_OUTPUT(C_TO_F32(C0[1])+res1);
114120
res2 = res2*ALPHA;
115-
C1[0] = TO_OUTPUT(TO_F32(C1[0])+res2);
121+
C1[0] = TO_OUTPUT(C_TO_F32(C1[0])+res2);
116122
res3 = res3*ALPHA;
117-
C1[1] = TO_OUTPUT(TO_F32(C1[1])+res3);
123+
C1[1] = TO_OUTPUT(C_TO_F32(C1[1])+res3);
118124
C0 = C0+2;
119125
C1 = C1+2;
120126
}
@@ -134,9 +140,9 @@ int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alpha,IFLOAT* ba,IFLOAT* bb,
134140
ptrbb = ptrbb+2;
135141
}
136142
res0 = res0*ALPHA;
137-
C0[0] = TO_OUTPUT(TO_F32(C0[0])+res0);
143+
C0[0] = TO_OUTPUT(C_TO_F32(C0[0])+res0);
138144
res1 = res1*ALPHA;
139-
C1[0] = TO_OUTPUT(TO_F32(C1[0])+res1);
145+
C1[0] = TO_OUTPUT(C_TO_F32(C1[0])+res1);
140146
C0 = C0+1;
141147
C1 = C1+1;
142148
}
@@ -165,9 +171,9 @@ int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alpha,IFLOAT* ba,IFLOAT* bb,
165171
ptrbb = ptrbb+1;
166172
}
167173
res0 = res0*ALPHA;
168-
C0[0] = TO_OUTPUT(TO_F32(C0[0])+res0);
174+
C0[0] = TO_OUTPUT(C_TO_F32(C0[0])+res0);
169175
res1 = res1*ALPHA;
170-
C0[1] = TO_OUTPUT(TO_F32(C0[1])+res1);
176+
C0[1] = TO_OUTPUT(C_TO_F32(C0[1])+res1);
171177
C0 = C0+2;
172178
}
173179
for (i=0; i<(bm&1); i+=1)
@@ -183,7 +189,7 @@ int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alpha,IFLOAT* ba,IFLOAT* bb,
183189
ptrbb = ptrbb+1;
184190
}
185191
res0 = res0*ALPHA;
186-
C0[0] = TO_OUTPUT(TO_F32(C0[0])+res0);
192+
C0[0] = TO_OUTPUT(C_TO_F32(C0[0])+res0);
187193
C0 = C0+1;
188194
}
189195
k = (bk<<0);

0 commit comments

Comments
 (0)