Skip to content

Commit 623be66

Browse files
authored
Merge pull request #3284 from martin-frbg/potrf_potri
Add lower thresholds for multithreading in POTRF/POTRI and improve the related benchmark
2 parents 7ddc9d3 + 1b5620b commit 623be66

6 files changed

Lines changed: 57 additions & 15 deletions

File tree

benchmark/potrf.c

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -99,14 +99,15 @@ int main(int argc, char *argv[]){
9999
char *p;
100100
char btest = 'F';
101101

102-
blasint m, i, j, info, uplos=0;
103-
double flops;
102+
blasint m, i, j, l, info, uplos=0;
103+
double flops = 0.;
104104

105105
int from = 1;
106106
int to = 200;
107107
int step = 1;
108+
int loops = 1;
108109

109-
double time1;
110+
double time1, timeg;
110111

111112
argc--;argv++;
112113

@@ -119,6 +120,8 @@ int main(int argc, char *argv[]){
119120

120121
if ((p = getenv("OPENBLAS_TEST"))) btest=*p;
121122

123+
if ((p = getenv("OPENBLAS_LOOPS"))) loops=*p;
124+
122125
fprintf(stderr, "From : %3d To : %3d Step = %3d Uplo = %c\n", from, to, step,*uplo[uplos]);
123126

124127
if (( a = (FLOAT *)malloc(sizeof(FLOAT) * to * to * COMPSIZE)) == NULL){
@@ -129,19 +132,21 @@ int main(int argc, char *argv[]){
129132
fprintf(stderr,"Out of Memory!!\n");exit(1);
130133
}
131134

132-
for(m = from; m <= to; m += step){
133135

136+
for(m = from; m <= to; m += step){
137+
timeg=0.;
138+
for (l = 0; l < loops; l++) {
134139
#ifndef COMPLEX
135140
if (uplos & 1) {
136141
for (j = 0; j < m; j++) {
137142
for(i = 0; i < j; i++) a[(long)i + (long)j * (long)m] = 0.;
138-
a[(long)j + (long)j * (long)m] = ((double) rand() / (double) RAND_MAX) + 8.;
143+
a[(long)j + (long)j * (long)m] = ((double) rand() / (double) RAND_MAX) + 8.;
139144
for(i = j + 1; i < m; i++) a[(long)i + (long)j * (long)m] = ((double) rand() / (double) RAND_MAX) - 0.5;
140145
}
141146
} else {
142147
for (j = 0; j < m; j++) {
143148
for(i = 0; i < j; i++) a[(long)i + (long)j * (long)m] = ((double) rand() / (double) RAND_MAX) - 0.5;
144-
a[(long)j + (long)j * (long)m] = ((double) rand() / (double) RAND_MAX) + 8.;
149+
a[(long)j + (long)j * (long)m] = ((double) rand() / (double) RAND_MAX) + 8.;
145150
for(i = j + 1; i < m; i++) a[(long)i + (long)j * (long)m] = 0.;
146151
}
147152
}
@@ -192,8 +197,8 @@ int main(int argc, char *argv[]){
192197
exit(1);
193198
}
194199

195-
time1 = getsec();
196-
flops = COMPSIZE * COMPSIZE * (1.0/3.0 * (double)m * (double)m *(double)m +1.0/2.0* (double)m *(double)m + 1.0/6.0* (double)m) / time1 * 1.e-6;
200+
if ( btest == 'F')
201+
timeg += getsec();
197202

198203
if ( btest == 'S' )
199204
{
@@ -214,9 +219,7 @@ int main(int argc, char *argv[]){
214219
fprintf(stderr, "Potrs info = %d\n", info);
215220
exit(1);
216221
}
217-
time1 = getsec();
218-
flops = COMPSIZE * COMPSIZE * (2.0 * (double)m * (double)m *(double)m ) / time1 * 1.e-6;
219-
222+
timeg += getsec();
220223
}
221224

222225
if ( btest == 'I' )
@@ -232,11 +235,17 @@ int main(int argc, char *argv[]){
232235
fprintf(stderr, "Potri info = %d\n", info);
233236
exit(1);
234237
}
235-
236-
time1 = getsec();
237-
flops = COMPSIZE * COMPSIZE * (2.0/3.0 * (double)m * (double)m *(double)m +1.0/2.0* (double)m *(double)m + 5.0/6.0* (double)m) / time1 * 1.e-6;
238+
timeg += getsec();
238239
}
239-
240+
} // loops
241+
242+
time1 = timeg/(double)loops;
243+
if ( btest == 'F')
244+
flops = COMPSIZE * COMPSIZE * (1.0/3.0 * (double)m * (double)m *(double)m +1.0/2.0* (double)m *(double)m + 1.0/6.0* (double)m) / time1 * 1.e-6;
245+
if ( btest == 'S')
246+
flops = COMPSIZE * COMPSIZE * (2.0 * (double)m * (double)m *(double)m ) / time1 * 1.e-6;
247+
if ( btest == 'I')
248+
flops = COMPSIZE * COMPSIZE * (2.0/3.0 * (double)m * (double)m *(double)m +1.0/2.0* (double)m *(double)m + 5.0/6.0* (double)m) / time1 * 1.e-6;
240249
fprintf(stderr, "%8d : %10.2f MFlops : %10.3f Sec : Test=%c\n",m,flops ,time1,btest);
241250

242251

common_interface.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -709,6 +709,13 @@ int BLASFUNC(cpotrf)(char *, blasint *, float *, blasint *, blasint *);
709709
int BLASFUNC(zpotrf)(char *, blasint *, double *, blasint *, blasint *);
710710
int BLASFUNC(xpotrf)(char *, blasint *, xdouble *, blasint *, blasint *);
711711

712+
int BLASFUNC(spotri)(char *, blasint *, float *, blasint *, blasint *);
713+
int BLASFUNC(dpotri)(char *, blasint *, double *, blasint *, blasint *);
714+
int BLASFUNC(qpotri)(char *, blasint *, xdouble *, blasint *, blasint *);
715+
int BLASFUNC(cpotri)(char *, blasint *, float *, blasint *, blasint *);
716+
int BLASFUNC(zpotri)(char *, blasint *, double *, blasint *, blasint *);
717+
int BLASFUNC(xpotri)(char *, blasint *, xdouble *, blasint *, blasint *);
718+
712719
int BLASFUNC(spotrs)(char *, blasint *, blasint *, float *, blasint *, float *, blasint *, blasint *);
713720
int BLASFUNC(dpotrs)(char *, blasint *, blasint *, double *, blasint *, double *, blasint *, blasint *);
714721
int BLASFUNC(qpotrs)(char *, blasint *, blasint *, xdouble *, blasint *, xdouble *, blasint *, blasint *);

interface/lapack/potrf.c

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,13 @@ int NAME(char *UPLO, blasint *N, FLOAT *a, blasint *ldA, blasint *Info){
112112

113113
#ifdef SMP
114114
args.common = NULL;
115+
#ifndef DOUBLE
116+
if (args.n <128)
117+
#else
118+
if (args.n <64)
119+
#endif
120+
args.nthreads = 1;
121+
else
115122
args.nthreads = num_cpu_avail(4);
116123

117124
if (args.nthreads == 1) {

interface/lapack/potri.c

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,9 @@ int NAME(char *UPLO, blasint *N, FLOAT *a, blasint *ldA, blasint *Info){
121121

122122
#ifdef SMP
123123
args.common = NULL;
124+
if (args.n < 180)
125+
args.nthreads = 1;
126+
else
124127
args.nthreads = num_cpu_avail(4);
125128

126129
if (args.nthreads == 1) {

interface/lapack/zpotrf.c

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,13 @@ int NAME(char *UPLO, blasint *N, FLOAT *a, blasint *ldA, blasint *Info){
112112

113113
#ifdef SMP
114114
args.common = NULL;
115+
#ifndef DOUBLE
116+
if (args.n < 64)
117+
#else
118+
if (args.n < 64)
119+
#endif
120+
args.nthreads = 1;
121+
else
115122
args.nthreads = num_cpu_avail(4);
116123

117124
if (args.nthreads == 1) {

interface/lapack/zpotri.c

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,15 @@ int NAME(char *UPLO, blasint *N, FLOAT *a, blasint *ldA, blasint *Info){
121121

122122
#ifdef SMP
123123
args.nthreads = num_cpu_avail(4);
124+
#ifndef DOUBLE
125+
if (args.n < 200)
126+
#else
127+
if (args.n < 150)
128+
#endif
129+
args.nthreads=1;
130+
else
131+
#endif
132+
args.nthreads = num_cpu_avail(4);
124133

125134
if (args.nthreads == 1) {
126135
#endif

0 commit comments

Comments
 (0)