Skip to content

Commit b741373

Browse files
committed
parametrize benchmarks with s-/d-/c-/z- variants
1 parent f173562 commit b741373

2 files changed

Lines changed: 73 additions & 114 deletions

File tree

benchmarks/benchmarks.py

Lines changed: 67 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -1,80 +1,42 @@
11
# Write the benchmarking functions here.
22
# See "Writing benchmarks" in the asv docs for more information.
33

4-
'''
5-
class TimeSuite:
6-
"""
7-
An example benchmark that times the performance of various kinds
8-
of iterating over dictionaries in Python.
9-
"""
10-
def setup(self):
11-
self.d = {}
12-
for x in range(500):
13-
self.d[x] = None
14-
15-
def time_keys(self):
16-
for key in self.d.keys():
17-
pass
18-
19-
def time_values(self):
20-
for value in self.d.values():
21-
pass
22-
23-
def time_range(self):
24-
d = self.d
25-
for key in range(500):
26-
d[key]
27-
28-
29-
class MemSuite:
30-
def mem_list(self):
31-
return [0] * 256
32-
'''
33-
34-
354
import numpy as np
36-
from openblas_wrap import (
37-
# level 1
38-
dnrm2, ddot, daxpy,
39-
# level 3
40-
dgemm, dsyrk,
41-
# lapack
42-
dgesv, # linalg.solve
43-
dgesdd, dgesdd_lwork, # linalg.svd
44-
dsyev, dsyev_lwork, # linalg.eigh
45-
)
5+
import openblas_wrap as ow
6+
467

478
# ### BLAS level 1 ###
489

4910
# dnrm2
5011

5112
dnrm2_sizes = [100, 1000]
5213

53-
def run_dnrm2(n, x, incx):
54-
res = dnrm2(x, n, incx=incx)
14+
def run_dnrm2(n, x, incx, func):
15+
res = func(x, n, incx=incx)
5516
return res
5617

5718

5819

5920
class Nrm2:
6021

61-
params = [100, 1000]
62-
param_names = ["size"]
22+
params = [dnrm2_sizes, ['d', 'dz']]
23+
param_names = ["size", "variant"]
6324

64-
def setup(self, n):
25+
def setup(self, n, variant):
6526
rndm = np.random.RandomState(1234)
6627
self.x = rndm.uniform(size=(n,)).astype(float)
28+
self.nrm2 = ow.get_func('nrm2', variant)
6729

68-
def time_dnrm2(self, n):
69-
run_dnrm2(n, self.x, 1)
30+
def time_dnrm2(self, n, variant):
31+
run_dnrm2(n, self.x, 1, self.nrm2)
7032

7133

7234
# ddot
7335

7436
ddot_sizes = [100, 1000]
7537

76-
def run_ddot(x, y,):
77-
res = ddot(x, y)
38+
def run_ddot(x, y, func):
39+
res = func(x, y)
7840
return res
7941

8042

@@ -86,32 +48,34 @@ def setup(self, n):
8648
rndm = np.random.RandomState(1234)
8749
self.x = np.array(rndm.uniform(size=(n,)), dtype=float)
8850
self.y = np.array(rndm.uniform(size=(n,)), dtype=float)
51+
self.func = ow.get_func('dot', 'd')
8952

9053
def time_ddot(self, n):
91-
run_ddot(self.x, self.y)
54+
run_ddot(self.x, self.y, self.func)
9255

9356

9457

9558
# daxpy
9659

9760
daxpy_sizes = [100, 1000]
9861

99-
def run_daxpy(x, y,):
100-
res = daxpy(x, y, a=2.0)
62+
def run_daxpy(x, y, func):
63+
res = func(x, y, a=2.0)
10164
return res
10265

10366

10467
class Daxpy:
105-
params = daxpy_sizes
106-
param_names = ["size"]
68+
params = [daxpy_sizes, ['s', 'd', 'c', 'z']]
69+
param_names = ["size", "variant"]
10770

108-
def setup(self, n):
71+
def setup(self, n, variant):
10972
rndm = np.random.RandomState(1234)
11073
self.x = np.array(rndm.uniform(size=(n,)), dtype=float)
11174
self.y = np.array(rndm.uniform(size=(n,)), dtype=float)
75+
self.axpy = ow.get_func('axpy', variant)
11276

113-
def time_daxpy(self, n):
114-
run_daxpy(self.x, self.y)
77+
def time_daxpy(self, n, variant):
78+
run_daxpy(self.x, self.y, self.axpy)
11579

11680

11781

@@ -121,47 +85,49 @@ def time_daxpy(self, n):
12185

12286
gemm_sizes = [100, 1000]
12387

124-
def run_dgemm(a, b, c):
88+
def run_dgemm(a, b, c, func):
12589
alpha = 1.0
126-
res = dgemm(alpha, a, b, c=c, overwrite_c=True)
90+
res = func(alpha, a, b, c=c, overwrite_c=True)
12791
return res
12892

12993

13094
class Dgemm:
131-
params = gemm_sizes
132-
param_names = ["size"]
95+
params = [gemm_sizes, ['s', 'd', 'c', 'z']]
96+
param_names = ["size", 'variant']
13397

134-
def setup(self, n):
98+
def setup(self, n, variant):
13599
rndm = np.random.RandomState(1234)
136100
self.a = np.array(rndm.uniform(size=(n, n)), dtype=float, order='F')
137101
self.b = np.array(rndm.uniform(size=(n, n)), dtype=float, order='F')
138102
self.c = np.empty((n, n), dtype=float, order='F')
103+
self.func = ow.get_func('gemm', variant)
139104

140-
def time_dgemm(self, n):
141-
run_dgemm(self.a, self.b, self.c)
105+
def time_dgemm(self, n, variant):
106+
run_dgemm(self.a, self.b, self.c, self.func)
142107

143108

144109
# dsyrk
145110

146111
syrk_sizes = [100, 1000]
147112

148113

149-
def run_dsyrk(a, c):
150-
res = dsyrk(1.0, a, c=c, overwrite_c=True)
114+
def run_dsyrk(a, c, func):
115+
res = func(1.0, a, c=c, overwrite_c=True)
151116
return res
152117

153118

154119
class DSyrk:
155-
params = syrk_sizes
156-
param_names = ["size"]
120+
params = [syrk_sizes, ['s', 'd', 'c', 'z']]
121+
param_names = ["size", "variant"]
157122

158-
def setup(self, n):
123+
def setup(self, n, variant):
159124
rndm = np.random.RandomState(1234)
160125
self.a = np.array(rndm.uniform(size=(n, n)), dtype=float, order='F')
161126
self.c = np.empty((n, n), dtype=float, order='F')
127+
self.func = ow.get_func('syrk', variant)
162128

163-
def time_dsyrk(self, n):
164-
run_dsyrk(self.a, self.c)
129+
def time_dsyrk(self, n, variant):
130+
run_dsyrk(self.a, self.c, self.func)
165131

166132

167133
# ### LAPACK ###
@@ -171,23 +137,24 @@ def time_dsyrk(self, n):
171137
dgesv_sizes = [100, 1000]
172138

173139

174-
def run_dgesv(a, b):
175-
res = dgesv(a, b, overwrite_a=True, overwrite_b=True)
140+
def run_dgesv(a, b, func):
141+
res = func(a, b, overwrite_a=True, overwrite_b=True)
176142
return res
177143

178144

179145
class Dgesv:
180-
params = dgesv_sizes
181-
param_names = ["size"]
146+
params = [dgesv_sizes, ['s', 'd', 'c', 'z']]
147+
param_names = ["size", "variant"]
182148

183-
def setup(self, n):
149+
def setup(self, n, variant):
184150
rndm = np.random.RandomState(1234)
185151
self.a = (np.array(rndm.uniform(size=(n, n)), dtype=float, order='F') +
186152
np.eye(n, order='F'))
187153
self.b = np.array(rndm.uniform(size=(n, 1)), order='F')
154+
self.func = ow.get_func('gesv', variant)
188155

189-
def time_dgesv(self, n):
190-
run_dgesv(self.a, self.b)
156+
def time_dgesv(self, n, variant):
157+
run_dgesv(self.a, self.b, self.func)
191158

192159
# XXX: how to run asserts?
193160
# lu, piv, x, info = benchmark(run_gesv, a, b)
@@ -201,58 +168,63 @@ def time_dgesv(self, n):
201168
dgesdd_sizes = ["100, 5", "1000, 222"]
202169

203170

204-
def run_dgesdd(a, lwork):
205-
res = dgesdd(a, lwork=lwork, full_matrices=False, overwrite_a=False)
171+
def run_dgesdd(a, lwork, func):
172+
res = func(a, lwork=lwork, full_matrices=False, overwrite_a=False)
206173
return res
207174

208175

209176
class Dgesdd:
210-
params = dgesdd_sizes
211-
param_names = ["(m, n)"]
177+
params = [dgesdd_sizes, ['s', 'd']]
178+
param_names = ["(m, n)", "variant"]
212179

213-
def setup(self, mn):
180+
def setup(self, mn, variant):
214181
m, n = (int(x) for x in mn.split(","))
215182

216183
rndm = np.random.RandomState(1234)
217184
a = np.array(rndm.uniform(size=(m, n)), dtype=float, order='F')
218185

219-
lwork, info = dgesdd_lwork(m, n)
186+
gesdd_lwork = ow.get_func('gesdd_lwork', variant)
187+
188+
lwork, info = gesdd_lwork(m, n)
220189
lwork = int(lwork)
221190
assert info == 0
222191

223192
self.a, self.lwork = a, lwork
193+
self.func = ow.get_func('gesdd', variant)
224194

225-
def time_dgesdd(self, mn):
226-
run_dgesdd(self.a, self.lwork)
195+
def time_dgesdd(self, mn, variant):
196+
run_dgesdd(self.a, self.lwork, self.func)
227197

228198

229199
# linalg.eigh
230200

231201
dsyev_sizes = [50, 200]
232202

233203

234-
def run_dsyev(a, lwork):
235-
res = dsyev(a, lwork=lwork, overwrite_a=True)
204+
def run_dsyev(a, lwork, func):
205+
res = func(a, lwork=lwork, overwrite_a=True)
236206
return res
237207

238208

239209
class Dsyev:
240-
params = dsyev_sizes
241-
param_names = ["size"]
210+
params = [dsyev_sizes, ['s', 'd']]
211+
param_names = ["size", "variant"]
242212

243-
def setup(self, n):
213+
def setup(self, n, variant):
244214
rndm = np.random.RandomState(1234)
245215
a = rndm.uniform(size=(n, n))
246216
a = np.asarray(a + a.T, dtype=float, order='F')
247217
a_ = a.copy()
248218

249-
lwork, info = dsyev_lwork(n)
219+
syev_lwork = ow.get_func('syev_lwork', variant)
220+
lwork, info = syev_lwork(n)
250221
lwork = int(lwork)
251222
assert info == 0
252223

253224
self.a = a_
254225
self.lwork = lwork
226+
self.func = ow.get_func('syev', variant)
255227

256-
def time_dsyev(self, n):
257-
run_dsyev(self.a, self.lwork)
228+
def time_dsyev(self, n, variant):
229+
run_dsyev(self.a, self.lwork, self.func)
258230

openblas_wrap/__init__.py

Lines changed: 6 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -6,24 +6,11 @@
66
__version__ = "0.1"
77

88
import scipy_openblas32 # preload symbols. typically done in _distributor_init.py
9+
from . import _flapack
910

10-
#from scipy.linalg.blas import (
11-
from ._flapack import (
12-
# level 1
13-
scipy_dnrm2 as dnrm2,
14-
scipy_ddot as ddot,
15-
scipy_daxpy as daxpy,
16-
# level 3
17-
scipy_dgemm as dgemm,
18-
scipy_dsyrk as dsyrk,
19-
)
11+
PREFIX = 'scipy_'
2012

21-
#from scipy.linalg.lapack import (
22-
from openblas_wrap._flapack import (
23-
# linalg.solve
24-
scipy_dgesv as dgesv,
25-
# linalg.svd
26-
scipy_dgesdd as dgesdd, scipy_dgesdd_lwork as dgesdd_lwork,
27-
# linalg.eigh
28-
scipy_dsyev as dsyev, scipy_dsyev_lwork as dsyev_lwork
29-
)
13+
14+
def get_func(name, variant):
15+
"""get_func('gesv', 'c') -> cgesv etc."""
16+
return getattr(_flapack, PREFIX + variant + name)

0 commit comments

Comments
 (0)