Skip to content

Commit eacbbf9

Browse files
pcaspersjenkins
authored andcommitted
QPR-12493 use one context only
1 parent 1a42545 commit eacbbf9

2 files changed

Lines changed: 91 additions & 48 deletions

File tree

QuantExt/qle/math/openclenvironment.cpp

Lines changed: 72 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,9 @@
2828
#include <iostream>
2929
#include <thread>
3030

31-
#define MAX_N_DEV_INFO 256U
32-
#define MAX_N_PLATFORMS 4U
33-
#define MAX_BUILD_LOG 65536U
34-
#define MAX_BUILD_LOG_LOGFILE 1024U
31+
#define ORE_OPENCL_MAX_N_DEV_INFO 256U
32+
#define ORE_OPENCL_MAX_BUILD_LOG 65536U
33+
#define ORE_OPENCL_MAX_BUILD_LOG_LOGFILE 1024U
3534

3635
namespace QuantExt {
3736

@@ -261,61 +260,88 @@ class OpenClContext : public ComputeContext {
261260
std::string currentSsa_;
262261
};
263262

264-
OpenClFramework::OpenClFramework() {
263+
bool OpenClFramework::initialized_ = false;
264+
boost::shared_mutex OpenClFramework::mutex_;
265+
cl_uint OpenClFramework::nPlatforms_ = 0;
266+
std::string OpenClFramework::platformName_[ORE_OPENCL_MAX_N_PLATFORMS];
267+
std::string OpenClFramework::deviceName_[ORE_OPENCL_MAX_N_PLATFORMS][ORE_OPENCL_MAX_N_DEVICES];
268+
cl_uint OpenClFramework::nDevices_[ORE_OPENCL_MAX_N_PLATFORMS];
269+
cl_device_id OpenClFramework::devices_[ORE_OPENCL_MAX_N_PLATFORMS][ORE_OPENCL_MAX_N_DEVICES];
270+
cl_context OpenClFramework::context_[ORE_OPENCL_MAX_N_PLATFORMS][ORE_OPENCL_MAX_N_DEVICES];
271+
std::vector<std::pair<std::string, std::string>> OpenClFramework::deviceInfo_[ORE_OPENCL_MAX_N_PLATFORMS]
272+
[ORE_OPENCL_MAX_N_DEVICES];
273+
bool OpenClFramework::supportsDoublePrecision_[ORE_OPENCL_MAX_N_PLATFORMS][ORE_OPENCL_MAX_N_DEVICES];
274+
275+
void OpenClFramework::init() {
276+
boost::unique_lock<boost::shared_mutex> lock(mutex_);
277+
278+
if (initialized_)
279+
return;
280+
281+
initialized_ = true;
282+
283+
cl_platform_id platforms[ORE_OPENCL_MAX_N_PLATFORMS];
284+
clGetPlatformIDs(ORE_OPENCL_MAX_N_PLATFORMS, platforms, &nPlatforms_);
265285

266-
cl_platform_id platforms[MAX_N_PLATFORMS];
267-
cl_uint nPlatforms;
268-
clGetPlatformIDs(MAX_N_PLATFORMS, platforms, &nPlatforms);
286+
for (std::size_t p = 0; p < nPlatforms_; ++p) {
287+
char platformName[ORE_OPENCL_MAX_N_DEV_INFO];
288+
clGetPlatformInfo(platforms[p], CL_PLATFORM_NAME, ORE_OPENCL_MAX_N_DEV_INFO, platformName, NULL);
289+
clGetDeviceIDs(platforms[p], CL_DEVICE_TYPE_ALL, ORE_OPENCL_MAX_N_DEVICES, devices_[p], &nDevices_[p]);
269290

270-
for (std::size_t p = 0; p < nPlatforms; ++p) {
271-
char platformName[MAX_N_DEV_INFO];
272-
clGetPlatformInfo(platforms[p], CL_PLATFORM_NAME, MAX_N_DEV_INFO, platformName, NULL);
273-
clGetDeviceIDs(platforms[p], CL_DEVICE_TYPE_ALL, 3, devices_, &nDevices_);
291+
platformName_[p] = std::string(platformName);
274292

275-
for (std::size_t d = 0; d < nDevices_; ++d) {
276-
char deviceName[MAX_N_DEV_INFO], driverVersion[MAX_N_DEV_INFO], deviceVersion[MAX_N_DEV_INFO],
277-
deviceExtensions[MAX_N_DEV_INFO];
293+
for (std::size_t d = 0; d < nDevices_[p]; ++d) {
294+
char deviceName[ORE_OPENCL_MAX_N_DEV_INFO], driverVersion[ORE_OPENCL_MAX_N_DEV_INFO],
295+
deviceVersion[ORE_OPENCL_MAX_N_DEV_INFO], deviceExtensions[ORE_OPENCL_MAX_N_DEV_INFO];
278296
cl_device_fp_config doubleFpConfig;
279-
std::vector<std::pair<std::string, std::string>> deviceInfo;
280297

281-
clGetDeviceInfo(devices_[d], CL_DEVICE_NAME, MAX_N_DEV_INFO, &deviceName, NULL);
282-
clGetDeviceInfo(devices_[d], CL_DRIVER_VERSION, MAX_N_DEV_INFO, &driverVersion, NULL);
283-
clGetDeviceInfo(devices_[d], CL_DEVICE_VERSION, MAX_N_DEV_INFO, &deviceVersion, NULL);
284-
clGetDeviceInfo(devices_[d], CL_DEVICE_EXTENSIONS, MAX_N_DEV_INFO, &deviceExtensions, NULL);
298+
clGetDeviceInfo(devices_[p][d], CL_DEVICE_NAME, ORE_OPENCL_MAX_N_DEV_INFO, &deviceName, NULL);
299+
clGetDeviceInfo(devices_[p][d], CL_DRIVER_VERSION, ORE_OPENCL_MAX_N_DEV_INFO, &driverVersion, NULL);
300+
clGetDeviceInfo(devices_[p][d], CL_DEVICE_VERSION, ORE_OPENCL_MAX_N_DEV_INFO, &deviceVersion, NULL);
301+
clGetDeviceInfo(devices_[p][d], CL_DEVICE_EXTENSIONS, ORE_OPENCL_MAX_N_DEV_INFO, &deviceExtensions, NULL);
302+
303+
deviceInfo_[p][d].push_back(std::make_pair("device_name", std::string(deviceName)));
304+
deviceInfo_[p][d].push_back(std::make_pair("driver_version", std::string(driverVersion)));
305+
deviceInfo_[p][d].push_back(std::make_pair("device_version", std::string(deviceVersion)));
306+
deviceInfo_[p][d].push_back(std::make_pair("device_extensions", std::string(deviceExtensions)));
285307

286-
deviceInfo.push_back(std::make_pair("device_name", std::string(deviceName)));
287-
deviceInfo.push_back(std::make_pair("driver_version", std::string(driverVersion)));
288-
deviceInfo.push_back(std::make_pair("device_version", std::string(deviceVersion)));
289-
deviceInfo.push_back(std::make_pair("device_extensions", std::string(deviceExtensions)));
308+
deviceName_[p][d] = std::string(deviceName);
290309

291-
bool supportsDoublePrecision = false;
310+
supportsDoublePrecision_[p][d] = false;
292311
#if CL_VERSION_1_2
293-
clGetDeviceInfo(devices_[d], CL_DEVICE_DOUBLE_FP_CONFIG, sizeof(cl_device_fp_config), &doubleFpConfig,
312+
clGetDeviceInfo(devices_[p][d], CL_DEVICE_DOUBLE_FP_CONFIG, sizeof(cl_device_fp_config), &doubleFpConfig,
294313
NULL);
295-
deviceInfo.push_back(std::make_pair(
314+
deviceInfo_[p][d].push_back(std::make_pair(
296315
"device_double_fp_config",
297316
((doubleFpConfig & CL_FP_DENORM) ? std::string("Denorm,") : std::string()) +
298317
((doubleFpConfig & CL_FP_INF_NAN) ? std::string("InfNan,") : std::string()) +
299318
((doubleFpConfig & CL_FP_ROUND_TO_NEAREST) ? std::string("RoundNearest,") : std::string()) +
300319
((doubleFpConfig & CL_FP_ROUND_TO_ZERO) ? std::string("RoundZero,") : std::string()) +
301320
((doubleFpConfig & CL_FP_FMA) ? std::string("FMA,") : std::string()) +
302321
((doubleFpConfig & CL_FP_SOFT_FLOAT) ? std::string("SoftFloat,") : std::string())));
303-
supportsDoublePrecision = supportsDoublePrecision || (doubleFpConfig != 0);
322+
supportsDoublePrecision_[p][d] = supportsDoublePrecision_[p][d] || (doubleFpConfig != 0);
304323
#else
305-
deviceInfo.push_back(std::make_pair("device_double_fp_config", "not provided before opencl 1.2"));
306-
supportsDoublePrecision = supportsDoublePrecision || std::string(deviceExtensions).find("cl_khr_fp64");
324+
deviceInfo_[p][d].push_back(std::make_pair("device_double_fp_config", "not provided before opencl 1.2"));
325+
supportsDoublePrecision_[p][d] =
326+
supportsDoublePrecision || std::string(deviceExtensions).find("cl_khr_fp64");
307327
#endif
308328

309329
// create context
310330

311331
cl_int err;
312-
313-
context_[d] = clCreateContext(NULL, 1, &devices_[d], &errorCallback, NULL, &err);
332+
context_[p][d] = clCreateContext(NULL, 1, &devices_[p][d], &errorCallback, NULL, &err);
314333
QL_REQUIRE(err == CL_SUCCESS,
315334
"OpenClFramework::OpenClContext(): error during clCreateContext(): " << errorText(err));
335+
}
336+
}
337+
}
316338

317-
contexts_["OpenCL/" + std::string(platformName) + "/" + std::string(deviceName)] =
318-
new OpenClContext(&devices_[d], &context_[d], deviceInfo, supportsDoublePrecision);
339+
OpenClFramework::OpenClFramework() {
340+
init();
341+
for (std::size_t p = 0; p < nPlatforms_; ++p) {
342+
for (std::size_t d = 0; d < nDevices_[p]; ++d) {
343+
contexts_["OpenCL/" + platformName_[p] + "/" + deviceName_[p][d]] =
344+
new OpenClContext(&devices_[p][d], &context_[p][d], deviceInfo_[p][d], supportsDoublePrecision_[p][d]);
319345
}
320346
}
321347
}
@@ -325,9 +351,11 @@ OpenClFramework::~OpenClFramework() {
325351
delete c;
326352
}
327353
cl_int err;
328-
for (cl_uint d = 0; d < nDevices_; ++d) {
329-
if (err = clReleaseContext(context_[d]); err != CL_SUCCESS) {
330-
std::cerr << "OpenClFramework: error during clReleaseContext: " + errorText(err) << std::endl;
354+
for (cl_uint p = 0; p < nPlatforms_; ++p) {
355+
for (cl_uint d = 0; d < nDevices_[p]; ++d) {
356+
if (err = clReleaseContext(context_[p][d]); err != CL_SUCCESS) {
357+
std::cerr << "OpenClFramework: error during clReleaseContext: " + errorText(err) << std::endl;
358+
}
331359
}
332360
}
333361
}
@@ -749,11 +777,11 @@ void OpenClContext::updateVariatesPool() {
749777
"OpenClContext::updateVariatesPool(): error creating program: " << errorText(err));
750778
err = clBuildProgram(variatesProgram_, 1, device_, NULL, NULL, NULL);
751779
if (err != CL_SUCCESS) {
752-
char buffer[MAX_BUILD_LOG];
753-
clGetProgramBuildInfo(variatesProgram_, *device_, CL_PROGRAM_BUILD_LOG, MAX_BUILD_LOG * sizeof(char),
754-
buffer, NULL);
780+
char buffer[ORE_OPENCL_MAX_BUILD_LOG];
781+
clGetProgramBuildInfo(variatesProgram_, *device_, CL_PROGRAM_BUILD_LOG,
782+
ORE_OPENCL_MAX_BUILD_LOG * sizeof(char), buffer, NULL);
755783
QL_FAIL("OpenClContext::updateVariatesPool(): error during program build: "
756-
<< errorText(err) << ": " << std::string(buffer).substr(MAX_BUILD_LOG_LOGFILE));
784+
<< errorText(err) << ": " << std::string(buffer).substr(ORE_OPENCL_MAX_BUILD_LOG_LOGFILE));
757785
}
758786

759787
variatesKernelSeedInit_ = clCreateKernel(variatesProgram_, "ore_seedInitialization", &err);
@@ -1182,12 +1210,12 @@ void OpenClContext::finalizeCalculation(std::vector<double*>& output) {
11821210
<< errorText(err));
11831211
err = clBuildProgram(program_[currentId_ - 1], 1, device_, NULL, NULL, NULL);
11841212
if (err != CL_SUCCESS) {
1185-
char buffer[MAX_BUILD_LOG];
1213+
char buffer[ORE_OPENCL_MAX_BUILD_LOG];
11861214
clGetProgramBuildInfo(program_[currentId_ - 1], *device_, CL_PROGRAM_BUILD_LOG,
1187-
MAX_BUILD_LOG * sizeof(char), buffer, NULL);
1215+
ORE_OPENCL_MAX_BUILD_LOG * sizeof(char), buffer, NULL);
11881216
QL_FAIL("OpenClContext::finalizeCalculation(): error during program build for kernel '"
11891217
<< kernelName << "': " << errorText(err) << ": "
1190-
<< std::string(buffer).substr(MAX_BUILD_LOG_LOGFILE));
1218+
<< std::string(buffer).substr(ORE_OPENCL_MAX_BUILD_LOG_LOGFILE));
11911219
}
11921220
kernel_[currentId_ - 1] = clCreateKernel(program_[currentId_ - 1], kernelName.c_str(), &err);
11931221
QL_REQUIRE(err == CL_SUCCESS,

QuantExt/qle/math/openclenvironment.hpp

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@
2424

2525
#include <qle/math/computeenvironment.hpp>
2626

27+
#include <boost/thread/lock_types.hpp>
28+
#include <boost/thread/shared_mutex.hpp>
29+
2730
#include <map>
2831

2932
#ifdef ORE_ENABLE_OPENCL
@@ -34,7 +37,8 @@
3437
#endif
3538
#endif
3639

37-
#define MAX_N_DEVICES 8U
40+
#define ORE_OPENCL_MAX_N_PLATFORMS 4U
41+
#define ORE_OPENCL_MAX_N_DEVICES 8U
3842

3943
namespace QuantExt {
4044

@@ -46,10 +50,21 @@ class OpenClFramework : public ComputeFramework {
4650
ComputeContext* getContext(const std::string& deviceName) override final;
4751

4852
private:
53+
static void init();
54+
4955
std::map<std::string, ComputeContext*> contexts_;
50-
cl_uint nDevices_;
51-
cl_device_id devices_[MAX_N_DEVICES];
52-
cl_context context_[MAX_N_DEVICES];
56+
57+
static boost::shared_mutex mutex_;
58+
static bool initialized_;
59+
static cl_uint nPlatforms_;
60+
static std::string platformName_[ORE_OPENCL_MAX_N_PLATFORMS];
61+
static std::string deviceName_[ORE_OPENCL_MAX_N_PLATFORMS][ORE_OPENCL_MAX_N_DEVICES];
62+
static cl_uint nDevices_[ORE_OPENCL_MAX_N_PLATFORMS];
63+
static cl_device_id devices_[ORE_OPENCL_MAX_N_PLATFORMS][ORE_OPENCL_MAX_N_DEVICES];
64+
static cl_context context_[ORE_OPENCL_MAX_N_PLATFORMS][ORE_OPENCL_MAX_N_DEVICES];
65+
static std::vector<std::pair<std::string, std::string>> deviceInfo_[ORE_OPENCL_MAX_N_PLATFORMS]
66+
[ORE_OPENCL_MAX_N_DEVICES];
67+
static bool supportsDoublePrecision_[ORE_OPENCL_MAX_N_PLATFORMS][ORE_OPENCL_MAX_N_DEVICES];
5368
};
5469

5570
} // namespace QuantExt

0 commit comments

Comments
 (0)