Skip to content

Commit 09c1877

Browse files
authored
Add test for SHGEMM
1 parent 4ac29b9 commit 09c1877

2 files changed

Lines changed: 157 additions & 3 deletions

File tree

test/Makefile

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,9 @@ ifeq ($(BUILD_BFLOAT16),1)
234234
BF3= test_bgemm
235235
B3 = test_sbgemm
236236
endif
237+
ifeq ($(BUILD_HFLOAT16),1)
238+
H3 = test_shgemm
239+
endif
237240
ifeq ($(BUILD_SINGLE),1)
238241
S3=sblat3
239242
endif
@@ -257,9 +260,9 @@ endif
257260

258261

259262
ifeq ($(SUPPORT_GEMM3M),1)
260-
level3: $(BF3) $(B3) $(S3) $(D3) $(C3) $(Z3) level3_3m
263+
level3: $(BF3) $(B3) $(H3) $(S3) $(D3) $(C3) $(Z3) level3_3m
261264
else
262-
level3: $(BF3) $(B3) $(S3) $(D3) $(C3) $(Z3)
265+
level3: $(BF3) $(B3) $(H3) $(S3) $(D3) $(C3) $(Z3)
263266
endif
264267

265268
ifneq ($(CROSS), 1)
@@ -454,6 +457,9 @@ test_sbgemv : compare_sgemv_sbgemv.c ../$(LIBNAME)
454457
endif
455458

456459
ifeq ($(BUILD_HFLOAT16),1)
460+
test_shgemm : compare_sgemm_shgemm.c test_helpers.h ../$(LIBNAME)
461+
$(CC) $(CLDFLAGS) -DIHFLOAT16 -o test_shgemm compare_sgemm_shgemm.c ../$(LIBNAME) $(EXTRALIB) $(CEXTRALIB)
462+
457463
test_shgemv : compare_sgemv_shgemv.c ../$(LIBNAME)
458464
$(CC) $(CLDFLAGS) -o test_shgemv compare_sgemv_shgemv.c ../$(LIBNAME) $(EXTRALIB) $(CEXTRALIB)
459465
endif
@@ -475,7 +481,7 @@ clean:
475481
@rm -f *.$(SUFFIX) *.$(PSUFFIX) gmon.$(SUFFIX)ut *.SUMM *.cxml *.exe *.pdb *.dwf \
476482
sblat1 dblat1 cblat1 zblat1 \
477483
sblat2 dblat2 cblat2 zblat2 \
478-
test_bgemm test_bgemv test_sbgemm test_sbgemv test_shgemv sblat3 dblat3 cblat3 zblat3 \
484+
test_bgemm test_bgemv test_sbgemm test_sbgemv test_shgemm test_shgemv sblat3 dblat3 cblat3 zblat3 \
479485
sblat1p dblat1p cblat1p zblat1p \
480486
sblat2p dblat2p cblat2p zblat2p \
481487
sblat3p dblat3p cblat3p zblat3p \

test/compare_sgemm_shgemm.c

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
/***************************************************************************
2+
Copyright (c) 2020,2025 The OpenBLAS Project
3+
All rights reserved.
4+
Redistribution and use in source and binary forms, with or without
5+
modification, are permitted provided that the following conditions are
6+
met:
7+
1. Redistributions of source code must retain the above copyright
8+
notice, this list of conditions and the following disclaimer.
9+
2. Redistributions in binary form must reproduce the above copyright
10+
notice, this list of conditions and the following disclaimer in
11+
the documentation and/or other materials provided with the
12+
distribution.
13+
3. Neither the name of the OpenBLAS project nor the names of
14+
its contributors may be used to endorse or promote products
15+
derived from this software without specific prior written permission.
16+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17+
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18+
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19+
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
20+
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
21+
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
22+
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
23+
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
24+
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
25+
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26+
*****************************************************************************/
27+
#include <stdio.h>
28+
#include <stdint.h>
29+
#define __USE_POSIX199309
30+
#include "../common.h"
31+
32+
#include "test_helpers.h"
33+
34+
#define SGEMM BLASFUNC(sgemm)
35+
#define SHGEMM BLASFUNC(shgemm)
36+
#define SGEMV BLASFUNC(sgemv)
37+
#define SHGEMV BLASFUNC(shgemv)
38+
#define SHGEMM_LARGEST 256
39+
40+
int
41+
main (int argc, char *argv[])
42+
{
43+
blasint m, n, k;
44+
int i, j, l;
45+
blasint x, y;
46+
int ret = 0;
47+
int loop = SHGEMM_LARGEST;
48+
char transA = 'N', transB = 'N';
49+
float alpha = 1.0, beta = 0.0;
50+
51+
for (x = 0; x <= loop; x++)
52+
{
53+
if ((x > 100) && (x != SHGEMM_LARGEST)) continue;
54+
m = k = n = x;
55+
float *A = (float *)malloc_safe(m * k * sizeof(FLOAT));
56+
float *B = (float *)malloc_safe(k * n * sizeof(FLOAT));
57+
float *C = (float *)malloc_safe(m * n * sizeof(FLOAT));
58+
hfloat16 *AA = (hfloat16 *)malloc_safe(m * k * sizeof(bfloat16));
59+
hfloat16 *BB = (hfloat16 *)malloc_safe(k * n * sizeof(bfloat16));
60+
float *DD = (float *)malloc_safe(m * n * sizeof(FLOAT));
61+
float *CC = (float *)malloc_safe(m * n * sizeof(FLOAT));
62+
if ((A == NULL) || (B == NULL) || (C == NULL) || (AA == NULL) || (BB == NULL) ||
63+
(DD == NULL) || (CC == NULL))
64+
return 1;
65+
66+
for (j = 0; j < m; j++)
67+
{
68+
for (i = 0; i < k; i++)
69+
{
70+
A[j * k + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5;
71+
AA[j * k + i] = (hfloat16) A[j * k + i];
72+
}
73+
}
74+
for (j = 0; j < n; j++)
75+
{
76+
for (i = 0; i < k; i++)
77+
{
78+
B[j * k + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5;
79+
BB[j * k + i] = (hfloat16) A[j * k + i];
80+
}
81+
}
82+
for (y = 0; y < 4; y++)
83+
{
84+
if ((y == 0) || (y == 2)) {
85+
transA = 'N';
86+
} else {
87+
transA = 'T';
88+
}
89+
if ((y == 0) || (y == 1)) {
90+
transB = 'N';
91+
} else {
92+
transB = 'T';
93+
}
94+
95+
memset(CC, 0, m * n * sizeof(FLOAT));
96+
memset(DD, 0, m * n * sizeof(FLOAT));
97+
memset(C, 0, m * n * sizeof(FLOAT));
98+
99+
SGEMM (&transA, &transB, &m, &n, &k, &alpha, A,
100+
&m, B, &k, &beta, C, &m);
101+
SHGEMM (&transA, &transB, &m, &n, &k, &alpha, AA,
102+
&m, BB, &k, &beta, CC, &m);
103+
104+
for (i = 0; i < n; i++)
105+
for (j = 0; j < m; j++)
106+
{
107+
for (l = 0; l < k; l++)
108+
if (transA == 'N' && transB == 'N')
109+
{
110+
DD[i * m + j] +=
111+
(float) AA[l * m + j] * (float)BB[l + k * i];
112+
} else if (transA == 'T' && transB == 'N')
113+
{
114+
DD[i * m + j] +=
115+
(float)AA[k * j + l] * (float)BB[l + k * i];
116+
} else if (transA == 'N' && transB == 'T')
117+
{
118+
DD[i * m + j] +=
119+
(float)AA[l * m + j] * (float)BB[i + l * n];
120+
} else if (transA == 'T' && transB == 'T')
121+
{
122+
DD[i * m + j] +=
123+
(float)AA[k * j + l] * (float)BB[i + l * n];
124+
}
125+
if (!is_close(CC[i * m + j], C[i * m + j], 0.01, 0.001)) {
126+
ret++;
127+
}
128+
if (!is_close(CC[i * m + j], DD[i * m + j], 0.001, 0.0001)) {
129+
ret++;
130+
}
131+
}
132+
}
133+
free(A);
134+
free(B);
135+
free(C);
136+
free(AA);
137+
free(BB);
138+
free(DD);
139+
free(CC);
140+
}
141+
142+
if (ret != 0) {
143+
fprintf(stderr, "SHGEMM FAILURES: %d\n", ret);
144+
return 1;
145+
}
146+
147+
return ret;
148+
}

0 commit comments

Comments
 (0)