Skip to content

Commit 0d648cc

Browse files
pcaspersjenkins
authored and committed
QPR-12386 add double precision support
1 parent 2b269d3 commit 0d648cc

1 file changed

Lines changed: 103 additions & 64 deletions

File tree

QuantExt/qle/math/openclenvironment.cpp

Lines changed: 103 additions & 64 deletions
Original file line number | Diff line number | Diff line change
@@ -240,7 +240,8 @@ class OpenClContext : public ComputeContext {
240240
// 2a indexed by var id
241241
std::vector<std::size_t> inputVarOffset_;
242242
std::vector<bool> inputVarIsScalar_;
243-
std::vector<float> inputVarValues_;
243+
std::vector<float> inputVarValues32_;
244+
std::vector<double> inputVarValues64_;
244245

245246
// 2b collection of variable ids
246247
std::vector<std::size_t> freedVariables_;
@@ -376,8 +377,12 @@ void OpenClContext::init() {
376377
context_ = clCreateContext(NULL, 1, &device_, NULL, NULL, &err);
377378
QL_REQUIRE(err == CL_SUCCESS, "OpenClContext::OpenClContext(): error during clCreateContext(): " << errorText(err));
378379

379-
// deprecated in open-cl version 2.0, clCreateCommandQueueWithProperties
380+
#if CL_VERSION_2_0
381+
queue_ = clCreateCommandQueueWithProperties(context_, device_, NULL, &err);
382+
#else
383+
// deprecated in cl version 2_0
380384
queue_ = clCreateCommandQueue(context_, device_, 0, &err);
385+
#endif
381386
QL_REQUIRE(err == CL_SUCCESS,
382387
"OpenClContext::OpenClContext(): error during clCreateCommandQueue(): " << errorText(err));
383388

@@ -459,7 +464,8 @@ std::pair<std::size_t, bool> OpenClContext::initiateCalculation(const std::size_
459464

460465
inputVarOffset_.clear();
461466
inputVarIsScalar_.clear();
462-
inputVarValues_.clear();
467+
inputVarValues32_.clear();
468+
inputVarValues64_.clear();
463469

464470
if (newCalc) {
465471
freedVariables_.clear();
@@ -491,8 +497,13 @@ std::size_t OpenClContext::createInputVariable(double v) {
491497
}
492498
inputVarOffset_.push_back(nextOffset);
493499
inputVarIsScalar_.push_back(true);
494-
inputVarValues_.push_back((float)std::max(std::min(v, (double)std::numeric_limits<float>::max()),
495-
-(double)std::numeric_limits<float>::max()));
500+
if (settings_.useDoublePrecision) {
501+
inputVarValues64_.push_back(v);
502+
} else {
503+
// ensure that v falls into the single precision range
504+
inputVarValues32_.push_back((float)std::max(std::min(v, (double)std::numeric_limits<float>::max()),
505+
-(double)std::numeric_limits<float>::max()));
506+
}
496507
return nVars_++;
497508
}
498509

@@ -506,9 +517,14 @@ std::size_t OpenClContext::createInputVariable(double* v) {
506517
}
507518
inputVarOffset_.push_back(nextOffset);
508519
inputVarIsScalar_.push_back(false);
509-
for (std::size_t i = 0; i < size_[currentId_ - 1]; ++i)
510-
inputVarValues_.push_back((float)std::max(std::min(v[i], (double)std::numeric_limits<float>::max()),
511-
-(double)std::numeric_limits<float>::max()));
520+
for (std::size_t i = 0; i < size_[currentId_ - 1]; ++i) {
521+
if (settings_.useDoublePrecision) {
522+
inputVarValues64_.push_back(v[i]);
523+
} else {
524+
inputVarValues32_.push_back((float)std::max(std::min(v[i], (double)std::numeric_limits<float>::max()),
525+
-(double)std::numeric_limits<float>::max()));
526+
}
527+
}
512528
return nVars_++;
513529
}
514530

@@ -574,8 +590,10 @@ std::size_t OpenClContext::applyOperation(const std::size_t randomVariableOpCode
574590

575591
// generate ssa entry
576592

593+
std::string fpTypeStr = settings_.useDoublePrecision ? "double" : "float";
594+
577595
std::string ssaLine =
578-
(resultIdNeedsDeclaration ? "float " : "") + std::string("v") + std::to_string(resultId) + " = ";
596+
(resultIdNeedsDeclaration ? fpTypeStr + " " : "") + std::string("v") + std::to_string(resultId) + " = ";
579597

580598
switch (randomVariableOpCode) {
581599
case RandomVariableOpCode::None: {
@@ -705,12 +723,17 @@ void OpenClContext::finalizeCalculation(std::vector<double*>& output) {
705723
QL_REQUIRE(output.size() == nOutputVars_[currentId_ - 1],
706724
"OpenClContext::finalizeCalculation(): output size ("
707725
<< output.size() << ") inconsistent to kernel output size (" << nOutputVars_[currentId_ - 1] << ")");
726+
QL_REQUIRE(!settings_.useDoublePrecision || supportsDoublePrecision(),
727+
"OpenClContext::finalizeCalculation(): double precision is configured for this calculation, but not "
728+
"supported by the device. Switch to single precision or use an appropriate device.");
708729

709730
boost::timer::cpu_timer timer;
710731
boost::timer::nanosecond_type timerBase;
711732

712733
// create input and output buffers
713734

735+
std::size_t fpSize = settings_.useDoublePrecision ? sizeof(double) : sizeof(float);
736+
714737
if (settings_.debug) {
715738
timerBase = timer.elapsed().wall;
716739
}
@@ -721,7 +744,7 @@ void OpenClContext::finalizeCalculation(std::vector<double*>& output) {
721744
cl_int err;
722745
cl_mem inputBuffer;
723746
if (inputBufferSize > 0) {
724-
inputBuffer = clCreateBuffer(context_, CL_MEM_READ_WRITE, sizeof(float) * inputBufferSize, NULL, &err);
747+
inputBuffer = clCreateBuffer(context_, CL_MEM_READ_WRITE, fpSize * inputBufferSize, NULL, &err);
725748
guard.mem.push_back(inputBuffer);
726749
QL_REQUIRE(err == CL_SUCCESS,
727750
"OpenClContext::finalizeCalculation(): creating input buffer fails: " << errorText(err));
@@ -730,7 +753,7 @@ void OpenClContext::finalizeCalculation(std::vector<double*>& output) {
730753
std::size_t outputBufferSize = nOutputVars_[currentId_ - 1] * size_[currentId_ - 1];
731754
cl_mem outputBuffer;
732755
if (outputBufferSize > 0) {
733-
outputBuffer = clCreateBuffer(context_, CL_MEM_READ_WRITE, sizeof(float) * outputBufferSize, NULL, &err);
756+
outputBuffer = clCreateBuffer(context_, CL_MEM_READ_WRITE, fpSize * outputBufferSize, NULL, &err);
734757
guard.mem.push_back(outputBuffer);
735758
QL_REQUIRE(err == CL_SUCCESS,
736759
"OpenClContext::finalizeCalculation(): creating output buffer fails: " << errorText(err));
@@ -744,84 +767,92 @@ void OpenClContext::finalizeCalculation(std::vector<double*>& output) {
744767

745768
if (!hasKernel_[currentId_ - 1]) {
746769

770+
std::string fpTypeStr = settings_.useDoublePrecision ? "double" : "float";
771+
std::string fpEpsStr = settings_.useDoublePrecision ? "0x1.0p-52" : "0x1.0p-23f";
772+
std::string fpSuffix = settings_.useDoublePrecision ? std::string() : "f";
773+
774+
// clang-format off
747775
const std::string includeSource =
748-
"bool ore_closeEnough(const float x, const float y) {\n"
749-
" const float tol = 42.0f * 1.1920929e-07f;\n"
750-
" float diff = fabs(x - y);\n"
751-
" if (x == 0.0f || y == 0.0f)\n"
776+
"bool ore_closeEnough(const " + fpTypeStr + " x, const " + fpTypeStr + " y) {\n"
777+
" const " + fpTypeStr + " tol = 42.0" + fpSuffix + " * " + fpEpsStr + ";\n"
778+
" " + fpTypeStr + " diff = fabs(x - y);\n"
779+
" if (x == 0.0" + fpSuffix + " || y == 0.0" + fpSuffix + ")\n"
752780
" return diff < tol * tol;\n"
753781
" return diff <= tol * fabs(x) || diff <= tol * fabs(y);\n"
754782
"}\n"
755-
"\n"
756-
"float ore_indicatorEq(const float x, const float y) { return ore_closeEnough(x, y) ? 1.0f : 0.0f; }\n\n"
757-
"float ore_indicatorGt(const float x, const float y) { return x > y && !ore_closeEnough(x, y); }\n\n"
758-
"float ore_indicatorGeq(const float x, const float y) { return x > y || ore_closeEnough(x, y); }\n\n"
759-
"float ore_invCumN(const uint x0) {\n"
760-
" const float a1_ = -3.969683028665376e+01f;\n"
761-
" const float a2_ = 2.209460984245205e+02f;\n"
762-
" const float a3_ = -2.759285104469687e+02f;\n"
763-
" const float a4_ = 1.383577518672690e+02f;\n"
764-
" const float a5_ = -3.066479806614716e+01f;\n"
765-
" const float a6_ = 2.506628277459239e+00f;\n"
766-
" const float b1_ = -5.447609879822406e+01f;\n"
767-
" const float b2_ = 1.615858368580409e+02f;\n"
768-
" const float b3_ = -1.556989798598866e+02f;\n"
769-
" const float b4_ = 6.680131188771972e+01f;\n"
770-
" const float b5_ = -1.328068155288572e+01f;\n"
771-
" const float c1_ = -7.784894002430293e-03f;\n"
772-
" const float c2_ = -3.223964580411365e-01f;\n"
773-
" const float c3_ = -2.400758277161838e+00f;\n"
774-
" const float c4_ = -2.549732539343734e+00f;\n"
775-
" const float c5_ = 4.374664141464968e+00f;\n"
776-
" const float c6_ = 2.938163982698783e+00f;\n"
777-
" const float d1_ = 7.784695709041462e-03f;\n"
778-
" const float d2_ = 3.224671290700398e-01f;\n"
779-
" const float d3_ = 2.445134137142996e+00f;\n"
780-
" const float d4_ = 3.754408661907416e+00f;\n"
781-
" const float x_low_ = 0.02425f;\n"
782-
" const float x_high_ = 1.0f - x_low_;\n"
783-
" const float x = x0 / (float)UINT_MAX;\n"
783+
"\n" +
784+
fpTypeStr + " ore_indicatorEq(const " + fpTypeStr + " x, const " + fpTypeStr + " y) "
785+
"{ return ore_closeEnough(x, y) ? 1.0" + fpSuffix + " : 0.0" + fpSuffix +"; }\n\n" +
786+
fpTypeStr + " ore_indicatorGt(const " + fpTypeStr + " x, const " + fpTypeStr + " y) " +
787+
"{ return x > y && !ore_closeEnough(x, y); }\n\n" +
788+
fpTypeStr + " ore_indicatorGeq(const " + fpTypeStr + " x, const " + fpTypeStr + " y) { return x > y || ore_closeEnough(x, y); }\n\n" +
789+
fpTypeStr + " ore_invCumN(const uint x0) {\n"
790+
" const " + fpTypeStr + " a1_ = -3.969683028665376e+01" + fpSuffix + ";\n"
791+
" const " + fpTypeStr + " a2_ = 2.209460984245205e+02" + fpSuffix + ";\n"
792+
" const " + fpTypeStr + " a3_ = -2.759285104469687e+02" + fpSuffix + ";\n"
793+
" const " + fpTypeStr + " a4_ = 1.383577518672690e+02" + fpSuffix + ";\n"
794+
" const " + fpTypeStr + " a5_ = -3.066479806614716e+01" + fpSuffix + ";\n"
795+
" const " + fpTypeStr + " a6_ = 2.506628277459239e+00" + fpSuffix + ";\n"
796+
" const " + fpTypeStr + " b1_ = -5.447609879822406e+01" + fpSuffix + ";\n"
797+
" const " + fpTypeStr + " b2_ = 1.615858368580409e+02" + fpSuffix + ";\n"
798+
" const " + fpTypeStr + " b3_ = -1.556989798598866e+02" + fpSuffix + ";\n"
799+
" const " + fpTypeStr + " b4_ = 6.680131188771972e+01" + fpSuffix + ";\n"
800+
" const " + fpTypeStr + " b5_ = -1.328068155288572e+01" + fpSuffix + ";\n"
801+
" const " + fpTypeStr + " c1_ = -7.784894002430293e-03" + fpSuffix + ";\n"
802+
" const " + fpTypeStr + " c2_ = -3.223964580411365e-01" + fpSuffix + ";\n"
803+
" const " + fpTypeStr + " c3_ = -2.400758277161838e+00" + fpSuffix + ";\n"
804+
" const " + fpTypeStr + " c4_ = -2.549732539343734e+00" + fpSuffix + ";\n"
805+
" const " + fpTypeStr + " c5_ = 4.374664141464968e+00" + fpSuffix + ";\n"
806+
" const " + fpTypeStr + " c6_ = 2.938163982698783e+00" + fpSuffix + ";\n"
807+
" const " + fpTypeStr + " d1_ = 7.784695709041462e-03" + fpSuffix + ";\n"
808+
" const " + fpTypeStr + " d2_ = 3.224671290700398e-01" + fpSuffix + ";\n"
809+
" const " + fpTypeStr + " d3_ = 2.445134137142996e+00" + fpSuffix + ";\n"
810+
" const " + fpTypeStr + " d4_ = 3.754408661907416e+00" + fpSuffix + ";\n"
811+
" const " + fpTypeStr + " x_low_ = 0.02425" + fpSuffix + ";\n"
812+
" const " + fpTypeStr + " x_high_ = 1.0" + fpSuffix + " - x_low_;\n"
813+
" const " + fpTypeStr + " x = x0 / (" + fpTypeStr + ")UINT_MAX;\n"
784814
" if (x < x_low_ || x_high_ < x) {\n"
785815
" if (x0 == UINT_MAX) {\n"
786-
" return 0x1.fffffep127f;\n"
816+
" return 0x1.fffffep127" + fpSuffix + ";\n"
787817
" } else if(x0 == 0) {\n"
788-
" return -0x1.fffffep127f;\n"
818+
" return -0x1.fffffep127" + fpSuffix + ";\n"
789819
" }\n"
790-
" float z;\n"
820+
" " + fpTypeStr + " z;\n"
791821
" if (x < x_low_) {\n"
792-
" z = sqrt(-2.0f * log(x));\n"
822+
" z = sqrt(-2.0" + fpSuffix + " * log(x));\n"
793823
" z = (((((c1_ * z + c2_) * z + c3_) * z + c4_) * z + c5_) * z + c6_) /\n"
794-
" ((((d1_ * z + d2_) * z + d3_) * z + d4_) * z + 1.0f);\n"
824+
" ((((d1_ * z + d2_) * z + d3_) * z + d4_) * z + 1.0" + fpSuffix + ");\n"
795825
" } else {\n"
796826
" z = sqrt(-2.0f * log(1.0f - x));\n"
797827
" z = -(((((c1_ * z + c2_) * z + c3_) * z + c4_) * z + c5_) * z + c6_) /\n"
798-
" ((((d1_ * z + d2_) * z + d3_) * z + d4_) * z + 1.0f);\n"
828+
" ((((d1_ * z + d2_) * z + d3_) * z + d4_) * z + 1.0" + fpSuffix + ");\n"
799829
" }\n"
800830
" return z;\n"
801831
" } else {\n"
802-
" float z = x - 0.5f;\n"
803-
" float r = z * z;\n"
832+
" " + fpTypeStr + " z = x - 0.5" + fpSuffix + ";\n"
833+
" " + fpTypeStr + " r = z * z;\n"
804834
" z = (((((a1_ * r + a2_) * r + a3_) * r + a4_) * r + a5_) * r + a6_) * z /\n"
805-
" (((((b1_ * r + b2_) * r + b3_) * r + b4_) * r + b5_) * r + 1.0f);\n"
835+
" (((((b1_ * r + b2_) * r + b3_) * r + b4_) * r + b5_) * r + 1.0" + fpSuffix +");\n"
806836
" return z;\n"
807837
" }\n"
808838
"}\n\n";
839+
// clang-format on
809840

810841
std::string kernelName =
811842
"ore_kernel_" + std::to_string(currentId_) + "_" + std::to_string(version_[currentId_ - 1]);
812843

813844
std::string kernelSource = includeSource + "__kernel void " + kernelName +
814845
"(\n"
815846
" __global uint* lcrng_mult" +
816-
(inputBufferSize > 0 ? ",\n __global float* input" : "") +
817-
(outputBufferSize > 0 ? ",\n __global float* output" : "") +
847+
(inputBufferSize > 0 ? ",\n __global " + fpTypeStr + "* input" : "") +
848+
(outputBufferSize > 0 ? ",\n __global " + fpTypeStr + "* output" : "") +
818849
") {\n"
819850
"unsigned int i = get_global_id(0);\n"
820851
"if(i < " +
821852
std::to_string(size_[currentId_ - 1]) + "U) {\n";
822853

823854
for (std::size_t i = 0; i < variateSeed_.size(); ++i) {
824-
kernelSource += " float v" + std::to_string(i + inputVarOffset_.size()) + " = ore_invCumN(" +
855+
kernelSource += " " + fpTypeStr + " v" + std::to_string(i + inputVarOffset_.size()) + " = ore_invCumN(" +
825856
std::to_string(variateSeed_[i]) + "U * lcrng_mult[i]);\n";
826857
if (settings_.debug)
827858
debugInfo_.numberOfOperations += 23 * size_[currentId_ - 1];
@@ -889,8 +920,10 @@ void OpenClContext::finalizeCalculation(std::vector<double*>& output) {
889920

890921
cl_event inputBufferEvent;
891922
if (inputBufferSize > 0) {
892-
err = clEnqueueWriteBuffer(queue_, inputBuffer, CL_FALSE, 0, sizeof(float) * inputBufferSize,
893-
&inputVarValues_[0], 0, NULL, &inputBufferEvent);
923+
err = clEnqueueWriteBuffer(queue_, inputBuffer, CL_FALSE, 0, fpSize * inputBufferSize,
924+
settings_.useDoublePrecision ? (void*)&inputVarValues64_[0]
925+
: (void*)&inputVarValues32_[0],
926+
0, NULL, &inputBufferEvent);
894927
QL_REQUIRE(err == CL_SUCCESS,
895928
"OpenClContext::finalizeCalculation(): writing to input buffer fails: " << errorText(err));
896929
}
@@ -941,22 +974,28 @@ void OpenClContext::finalizeCalculation(std::vector<double*>& output) {
941974

942975
std::vector<cl_event> outputBufferEvents;
943976
if (outputBufferSize > 0) {
944-
std::vector<std::vector<float>> outputFloat(output.size(), std::vector<float>(size_[currentId_ - 1]));
977+
std::vector<std::vector<float>> outputFloat;
978+
if (!settings_.useDoublePrecision) {
979+
outputFloat.resize(output.size(), std::vector<float>(size_[currentId_ - 1]));
980+
}
945981
for (std::size_t i = 0; i < output.size(); ++i) {
946982
outputBufferEvents.push_back(cl_event());
947983
err = clEnqueueReadBuffer(queue_, outputBuffer, CL_FALSE, i * size_[currentId_ - 1],
948-
sizeof(float) * size_[currentId_ - 1], &outputFloat[i][0], 1, &runEvent,
949-
&outputBufferEvents.back());
984+
fpSize * size_[currentId_ - 1],
985+
settings_.useDoublePrecision ? (void*)&output[i][0] : (void*)&outputFloat[i][0],
986+
1, &runEvent, &outputBufferEvents.back());
950987
QL_REQUIRE(err == CL_SUCCESS,
951988
"OpenClContext::finalizeCalculation(): writing to output buffer fails: " << errorText(err));
952989
}
953990
err = clWaitForEvents(outputBufferEvents.size(), outputBufferEvents.empty() ? nullptr : &outputBufferEvents[0]);
954991
QL_REQUIRE(
955992
err == CL_SUCCESS,
956993
"OpenClContext::finalizeCalculation(): wait for output buffer events to finish fails: " << errorText(err));
957-
// copy from float to double
958-
for (std::size_t i = 0; i < output.size(); ++i) {
959-
std::copy(outputFloat[i].begin(), outputFloat[i].end(), output[i]);
994+
if (!settings_.useDoublePrecision) {
995+
// copy from float to double
996+
for (std::size_t i = 0; i < output.size(); ++i) {
997+
std::copy(outputFloat[i].begin(), outputFloat[i].end(), output[i]);
998+
}
960999
}
9611000
}
9621001

0 commit comments

Comments
 (0)