Skip to content

Commit 1c5a1ca

Browse files
author
sebastien.bouvard
committed
QPR-13698 Use QL::Rounding and fix gradient
1 parent 75cc528 commit 1c5a1ca

3 files changed

Lines changed: 10 additions & 12 deletions

File tree

QuantExt/qle/ad/computationgraph.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222

2323
#include <ql/errors.hpp>
2424
#include <ql/math/comparison.hpp>
25-
#include <cmath>
2625

2726
#include <boost/math/distributions/normal.hpp>
2827

@@ -330,8 +329,8 @@ std::size_t cg_frac(ComputationGraph& g, const std::size_t a, const std::string&
330329

331330
std::size_t cg_round(ComputationGraph& g, const std::size_t a, const std::size_t b, const std::string& label) {
332331
if (g.isConstant(a) && g.isConstant(b)){
333-
double factor = std::pow(10.0, g.constantValue(b));
334-
return cg_const(g, std::round(g.constantValue(a)*factor)/factor);
332+
QuantLib::Rounding rnd(g.constantValue(b), QuantLib::Rounding::Closest, 5);
333+
return cg_const(g, rnd(g.constantValue(a)));
335334
}
336335
return g.insert({a, b}, RandomVariableOpCode::Round, label);
337336
}

QuantExt/qle/math/randomvariable.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@
3838
#include <boost/functional/hash.hpp>
3939

4040
#include <map>
41-
#include <cmath>
4241

4342
// if defined, RandomVariableStats are updated (this might impact perfomance!), default is undefined
4443
// #define ENABLE_RANDOMVARIABLE_STATS
@@ -739,14 +738,14 @@ RandomVariable round(RandomVariable x, const RandomVariable& y) {
739738
if (!y.deterministic_)
740739
x.expand();
741740
if (x.deterministic()){
742-
double factor = std::pow(10, y.constantData_);
743-
x.constantData_ = std::round(x.constantData_*factor)/factor;
741+
QuantLib::Rounding rnd(y.constantData_, QuantLib::Rounding::Closest, 5);
742+
x.constantData_ = rnd(x.constantData_);
744743
}
745744
else {
746745
resumeCalcStats();
747746
for (Size i = 0; i < x.size(); ++i) {
748-
double factor = std::pow(10, y[i]);
749-
x.data_[i] = std::round(x.data_[i]*factor)/factor;
747+
QuantLib::Rounding rnd(y.constantData_, QuantLib::Rounding::Closest, 5);
748+
x.data_[i] = rnd(x.constantData_);
750749
}
751750
stopCalcStats(x.size());
752751
}

QuantExt/qle/math/randomvariable_ops.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -310,11 +310,11 @@ std::vector<RandomVariableGrad> getRandomVariableGradients(const Size size, cons
310310

311311
// Frac = 19
312312
grads.push_back([](const std::vector<const RandomVariable*>& args, const RandomVariable* v,
313-
const Size node) -> std::vector<RandomVariable> { return {QuantExt::frac(*args[0])}; });
313+
const Size node) -> std::vector<RandomVariable> { return {QuantExt::frac(RandomVariable(1))}; });
314314

315315
// Round = 20
316316
grads.push_back([](const std::vector<const RandomVariable*>& args, const RandomVariable* v,
317-
const Size node) -> std::vector<RandomVariable> { return {QuantExt::round(*args[0], *args[1])}; });
317+
const Size node) -> std::vector<RandomVariable> { return {QuantExt::round(RandomVariable(0),RandomVariable(0))}; });
318318

319319
return grads;
320320
}
@@ -386,10 +386,10 @@ std::vector<RandomVariableOpNodeRequirements> getRandomVariableOpNodeRequirement
386386
res.push_back([](const std::size_t nArgs) { return std::make_pair(std::vector<bool>(nArgs, true), true); });
387387

388388
// Frac = 19
389-
res.push_back([](const std::size_t nArgs) { return std::make_pair(std::vector<bool>(nArgs, true), false); });
389+
res.push_back([](const std::size_t nArgs) { return std::make_pair(std::vector<bool>(nArgs, false), false); });
390390

391391
// Round = 20
392-
res.push_back([](const std::size_t nArgs) { return std::make_pair(std::vector<bool>(nArgs, true), false); });
392+
res.push_back([](const std::size_t nArgs) { return std::make_pair(std::vector<bool>(nArgs, false), false); });
393393

394394
return res;
395395
}

0 commit comments

Comments
 (0)