Skip to content

Commit 1c93e6a

Browse files
authored
Merge pull request #4521 from martin-frbg/fixczsum
Fix BLAS extension kernels for SCSUM and DZSUM on x86_64 targets
2 parents 5403900 + 8f8ef34 commit 1c93e6a

8 files changed

Lines changed: 1417 additions & 0 deletions

File tree

kernel/x86_64/KERNEL

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -489,5 +489,7 @@ XGEMM3MKERNEL = xgemm3m_kernel_2x2.S
489489

490490
SSUMKERNEL = ../arm/sum.c
491491
DSUMKERNEL = ../arm/sum.c
492+
CSUMKERNEL = zsum_sse.S
493+
ZSUMKERNEL = zsum_sse2.S
492494

493495
SOMATCOPY_RT = omatcopy_rt.c

kernel/x86_64/KERNEL.SKYLAKEX

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,3 +46,5 @@ ZGEMMKERNEL = zgemm_kernel_4x2_skylakex.c
4646

4747
CASUMKERNEL = casum.c
4848
ZASUMKERNEL = zasum.c
49+
CSUMKERNEL = csum.c
50+
ZSUMKERNEL = zsum.c

kernel/x86_64/csum.c

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
#include "common.h"
2+
3+
#if defined(SKYLAKEX) || defined(COOPERLAKE) || defined(SAPPHIRERAPIDS)
4+
#include "csum_microk_skylakex-2.c"
5+
#endif
6+
7+
#ifndef HAVE_CSUM_KERNEL
8+
/*
 * Portable fallback kernel, compiled only when no optimized implementation
 * has defined HAVE_CSUM_KERNEL.
 *
 * Adds up every FLOAT of n contiguous elements starting at x, where each
 * element occupies two consecutive FLOATs (presumably the real and imaginary
 * part of a complex value -- confirm against the caller's data layout).
 *
 *   n  number of two-FLOAT elements to process
 *   x  pointer to the first FLOAT of the first element (unit stride)
 *
 * Returns the plain (not absolute-value) sum of all 2*n FLOATs.
 */
static FLOAT csum_kernel(BLASLONG n, FLOAT *x)
{
    BLASLONG i = 0;
    BLASLONG n_8 = n & -8;   /* n rounded down to a multiple of 8 */
    FLOAT *x1 = x;
    /* Four independent accumulators for the unrolled loop, one for the tail. */
    FLOAT sum0 = 0.0;
    FLOAT sum1 = 0.0;
    FLOAT sum2 = 0.0;
    FLOAT sum3 = 0.0;
    FLOAT sum4 = 0.0;

    /*
     * Unrolled part: each pass consumes 8 FLOATs, i.e. 4 elements, which is
     * why i advances by 4 while x1 advances by 8. n_8 is a multiple of 8, so
     * the loop lands on it exactly.
     */
    while (i < n_8) {
        sum0 += x1[0];
        sum1 += x1[1];
        sum2 += x1[2];
        sum3 += x1[3];

        sum0 += x1[4];
        sum1 += x1[5];
        sum2 += x1[6];
        sum3 += x1[7];

        x1 += 8;
        i += 4;
    }

    /* Remaining elements, one (two FLOATs) at a time. */
    while (i < n) {
        sum4 += (x1[0] + x1[1]);
        x1 += 2;
        i++;
    }

    return sum0 + sum1 + sum2 + sum3 + sum4;
}
45+
46+
#endif
47+
48+
/*
 * Single-threaded sum over n two-FLOAT elements of x.
 *
 *   n      number of elements
 *   x      pointer to the first FLOAT of the first element
 *   inc_x  distance between consecutive elements, counted in elements
 *
 * Returns 0 when n <= 0 or inc_x <= 0. Unit stride is delegated to
 * csum_kernel(); any other stride is handled by the scalar loop below.
 */
static FLOAT sum_compute(BLASLONG n, FLOAT *x, BLASLONG inc_x)
{
    FLOAT total = 0.0;
    BLASLONG step;
    BLASLONG pos;
    BLASLONG k;

    if (n <= 0 || inc_x <= 0) {
        return total;
    }

    if (inc_x == 1) {
        return csum_kernel(n, x);
    }

    /* Each element is two FLOATs wide, so the FLOAT stride is 2 * inc_x. */
    step = 2 * inc_x;
    pos = 0;
    for (k = 0; k < n; k++) {
        total += x[pos] + x[pos + 1];
        pos += step;
    }

    return total;
}
71+
72+
#if defined(SMP)
73+
/*
 * Worker callback that CNAME hands to blas_level1_thread_with_return_value().
 * It sums the (n, x, inc_x) slice it is given and writes that partial sum to
 * result[0]. The dummy* parameters exist only to match the callback signature
 * and are ignored. Always returns 0.
 */
static int sum_thread_function(BLASLONG n,
                               BLASLONG dummy0, BLASLONG dummy1, FLOAT dummy2,
                               FLOAT *x, BLASLONG inc_x,
                               FLOAT *dummy3, BLASLONG dummy4,
                               FLOAT *result, BLASLONG dummy5)
{
    FLOAT partial = sum_compute(n, x, inc_x);

    result[0] = partial;
    return 0;
}
82+
83+
extern int blas_level1_thread_with_return_value(int mode,
84+
BLASLONG m, BLASLONG n, BLASLONG k, void * alpha,
85+
void *a, BLASLONG lda,
86+
void *b, BLASLONG ldb,
87+
void *c, BLASLONG ldc,
88+
int (*function)(),
89+
int nthread);
90+
#endif
91+
92+
/*
 * Public entry point: returns the sum of all FLOAT components of n
 * two-FLOAT elements of x taken with element stride inc_x.
 *
 * Without SMP the work is done by sum_compute() directly. With SMP, inputs
 * of more than 10000 elements (and a positive stride) are split across up to
 * num_cpu_avail(1) threads, at least 10000 elements per thread, and the
 * per-thread partial sums are added together here.
 */
FLOAT CNAME(BLASLONG n, FLOAT *x, BLASLONG inc_x)
{
#if defined(SMP)
    int nthreads;
    /*
     * Placeholder for the alpha argument of the threading helper. It is
     * passed by pointer, so it is zero-initialised to make sure nothing
     * indeterminate can be read through it.
     */
    FLOAT dummy_alpha[2] = {0.0, 0.0};
#endif
    FLOAT sumf = 0.0;

#if defined(SMP)
    int num_cpu = num_cpu_avail(1);
    if (n <= 10000 || inc_x <= 0)
        nthreads = 1;
    else
        nthreads = num_cpu < n/10000 ? num_cpu : n/10000;

    if (nthreads == 1) {
        sumf = sum_compute(n, x, inc_x);
    }
    else {
        int mode, i;
        /*
         * One slot of 2 * sizeof(double) bytes per thread, the same size and
         * stride as before. Declared as double rather than char so the
         * buffer is suitably aligned for FLOAT stores and loads.
         */
        double result[MAX_CPU_NUMBER * 2];
#if !defined(DOUBLE)
        mode = BLAS_SINGLE | BLAS_COMPLEX;
#else
        mode = BLAS_DOUBLE | BLAS_COMPLEX;
#endif
        blas_level1_thread_with_return_value(mode, n, 0, 0, dummy_alpha, x, inc_x,
                NULL, 0, result, 0, (int (*)(void))sum_thread_function, nthreads);

        /* Each worker's partial sum is the first FLOAT of its slot. */
        for (i = 0; i < nthreads; i++) {
            sumf += *(FLOAT *)&result[2 * i];
        }
    }
#else
    sumf = sum_compute(n, x, inc_x);
#endif
    return(sumf);
}

0 commit comments

Comments
 (0)