Skip to content

Commit 036c02e

Browse files
committed
Add custom taint rules from sarif file parsing
1 parent bc922f5 commit 036c02e

6 files changed

Lines changed: 116 additions & 78 deletions

File tree

include/klee/Module/SarifReport.h

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,20 +43,31 @@ template <typename T> struct adl_serializer<std::optional<T>> {
4343
} // namespace nlohmann
4444

4545
namespace klee {
46-
enum ReachWithError {
46+
enum ReachWithErrorType {
4747
DoubleFree = 0,
4848
UseAfterFree,
4949
MayBeNullPointerException, // void f(int *x) { *x = 42; } - should it error?
5050
MustBeNullPointerException, // MayBeNPE = yes, MustBeNPE = no
5151
NullCheckAfterDerefException,
5252
Reachable,
5353
None,
54+
MaybeTaint,
55+
};
56+
57+
struct ReachWithError {
58+
ReachWithErrorType type;
59+
std::optional<std::string> data;
5460

55-
TaintFormatString,
56-
TaintSensitiveData,
57-
TaintExecute,
61+
explicit ReachWithError(ReachWithErrorType type,
62+
std::optional<std::string> data = std::nullopt);
63+
64+
bool operator==(const ReachWithError &other) const;
65+
bool operator!=(const ReachWithError &other) const;
66+
bool operator<(const ReachWithError &other) const;
5867
};
5968

69+
using ReachWithErrors = std::vector<ReachWithError>;
70+
6071
const char *getErrorString(ReachWithError error);
6172
std::string getErrorsString(const std::vector<ReachWithError> &errors);
6273

lib/Core/ExecutionState.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -421,7 +421,7 @@ class ExecutionState {
421421

422422
ExprHashMap<llvm::Type *> gepExprBases;
423423

424-
mutable ReachWithError error = ReachWithError::None;
424+
mutable ReachWithError error = ReachWithError(ReachWithErrorType::None);
425425
std::atomic<HaltExecution::Reason> terminationReasonType{
426426
HaltExecution::NotHalt};
427427

lib/Core/Executor.cpp

Lines changed: 36 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1294,8 +1294,8 @@ bool mustVisitForkBranches(ref<Target> target, KInstruction *instr) {
12941294
// fork branches here
12951295
if (auto reprErrorTarget = dyn_cast<ReproduceErrorTarget>(target)) {
12961296
return reprErrorTarget->isTheSameAsIn(instr) &&
1297-
reprErrorTarget->isThatError(
1298-
ReachWithError::NullCheckAfterDerefException);
1297+
reprErrorTarget->isThatError(ReachWithError(
1298+
ReachWithErrorType::NullCheckAfterDerefException));
12991299
}
13001300
return false;
13011301
}
@@ -2674,8 +2674,9 @@ void Executor::checkNullCheckAfterDeref(ref<Expr> cond, ExecutionState &state) {
26742674
if (eqPointerCheck && eqPointerCheck->left->isZero() &&
26752675
state.resolvedPointers.count(
26762676
makePointer(eqPointerCheck->right)->getBase())) {
2677-
reportStateOnTargetError(state,
2678-
ReachWithError::NullCheckAfterDerefException);
2677+
reportStateOnTargetError(
2678+
state,
2679+
ReachWithError(ReachWithErrorType::NullCheckAfterDerefException));
26792680
}
26802681
}
26812682

@@ -2687,10 +2688,11 @@ void Executor::executeInstruction(ExecutionState &state, KInstruction *ki) {
26872688
auto target = kvp.first;
26882689
if (target->shouldFailOnThisTarget() &&
26892690
cast<ReproduceErrorTarget>(target)->isThatError(
2690-
ReachWithError::Reachable) &&
2691+
ReachWithError(ReachWithErrorType::Reachable)) &&
26912692
target->getBlock() == ki->parent &&
26922693
cast<ReproduceErrorTarget>(target)->isTheSameAsIn(ki)) {
2693-
terminateStateOnTargetError(state, ReachWithError::Reachable);
2694+
terminateStateOnTargetError(
2695+
state, ReachWithError(ReachWithErrorType::Reachable));
26942696
return;
26952697
}
26962698
}
@@ -4998,25 +5000,25 @@ void Executor::terminateStateOnTargetError(ExecutionState &state,
49985000
// Proceed with normal `terminateStateOnError` call
49995001
std::string messaget;
50005002
StateTerminationType terminationType;
5001-
switch (error) {
5002-
case ReachWithError::MayBeNullPointerException:
5003-
case ReachWithError::MustBeNullPointerException:
5003+
switch (error.type) {
5004+
case ReachWithErrorType::MayBeNullPointerException:
5005+
case ReachWithErrorType::MustBeNullPointerException:
50045006
messaget = "memory error: null pointer exception";
50055007
terminationType = StateTerminationType::Ptr;
50065008
break;
5007-
case ReachWithError::DoubleFree:
5009+
case ReachWithErrorType::DoubleFree:
50085010
messaget = "double free error";
50095011
terminationType = StateTerminationType::Ptr;
50105012
break;
5011-
case ReachWithError::UseAfterFree:
5013+
case ReachWithErrorType::UseAfterFree:
50125014
messaget = "use after free error";
50135015
terminationType = StateTerminationType::Ptr;
50145016
break;
5015-
case ReachWithError::Reachable:
5017+
case ReachWithErrorType::Reachable:
50165018
messaget = "";
50175019
terminationType = StateTerminationType::Reachable;
50185020
break;
5019-
case ReachWithError::None:
5021+
case ReachWithErrorType::None:
50205022
default:
50215023
messaget = "unspecified error";
50225024
terminationType = StateTerminationType::User;
@@ -5025,11 +5027,15 @@ void Executor::terminateStateOnTargetError(ExecutionState &state,
50255027
state, new ErrorEvent(locationOf(state), terminationType, messaget));
50265028
}
50275029

5028-
// TODO: add taint target errors to taint-annotations.json and change function
5029-
void Executor::terminateStateOnTargetTaintError(ExecutionState &state, size_t rule) {
5030-
const std::string &ruleStr = annotationsData.taintAnnotation.rules[rule];
5030+
void Executor::terminateStateOnTargetTaintError(ExecutionState &state,
5031+
size_t rule) {
5032+
if (rule >= annotationsData.taintAnnotation.rules.size()) {
5033+
terminateStateOnUserError(state, "Incorrect rule id");
5034+
}
50315035

5032-
// reportStateOnTargetError(state, rule);
5036+
const std::string &ruleStr = annotationsData.taintAnnotation.rules[rule];
5037+
reportStateOnTargetError(
5038+
state, ReachWithError(ReachWithErrorType::MaybeTaint, ruleStr));
50335039

50345040
terminateStateOnProgramError(state, ruleStr + " taint error",
50355041
StateTerminationType::Taint);
@@ -5500,8 +5506,8 @@ void Executor::executeFree(ExecutionState &state, ref<PointerExpr> address,
55005506
if (!resolveExact(*zeroPointer.second, address,
55015507
typeSystemManager->getUnknownType(), rl, "free") &&
55025508
guidanceKind == GuidanceKind::ErrorGuidance) {
5503-
terminateStateOnTargetError(*zeroPointer.second,
5504-
ReachWithError::DoubleFree);
5509+
terminateStateOnTargetError(
5510+
*zeroPointer.second, ReachWithError(ReachWithErrorType::DoubleFree));
55055511
return;
55065512
}
55075513

@@ -5566,9 +5572,9 @@ bool Executor::resolveExact(ExecutionState &estate, ref<Expr> address,
55665572
ExecutionState *bound = branches.first;
55675573
if (bound) {
55685574
auto error = isReadFromSymbolicArray(uniqueBase)
5569-
? ReachWithError::MayBeNullPointerException
5570-
: ReachWithError::MustBeNullPointerException;
5571-
terminateStateOnTargetError(*bound, error);
5575+
? ReachWithErrorType::MayBeNullPointerException
5576+
: ReachWithErrorType::MustBeNullPointerException;
5577+
terminateStateOnTargetError(*bound, ReachWithError(error));
55725578
}
55735579
if (!branches.second) {
55745580
address =
@@ -6248,9 +6254,9 @@ void Executor::executeMemoryOperation(
62486254
ExecutionState *bound = branches.first;
62496255
if (bound) {
62506256
auto error = (isReadFromSymbolicArray(base) && branches.second)
6251-
? ReachWithError::MayBeNullPointerException
6252-
: ReachWithError::MustBeNullPointerException;
6253-
terminateStateOnTargetError(*bound, error);
6257+
? ReachWithErrorType::MayBeNullPointerException
6258+
: ReachWithErrorType::MustBeNullPointerException;
6259+
terminateStateOnTargetError(*bound, ReachWithError(error));
62546260
}
62556261
if (!branches.second)
62566262
return;
@@ -6372,7 +6378,8 @@ void Executor::executeMemoryOperation(
63726378
solver->setTimeout(time::Span());
63736379

63746380
if (!success) {
6375-
terminateStateOnTargetError(*state, ReachWithError::UseAfterFree);
6381+
terminateStateOnTargetError(
6382+
*state, ReachWithError(ReachWithErrorType::UseAfterFree));
63766383
return;
63776384
}
63786385
}
@@ -6986,7 +6993,7 @@ void Executor::runFunctionAsMain(Function *f, int argc, char **argv,
69866993
auto kCallBlock = kfIt->second->entryKBlock;
69876994
forest = new TargetForest(kEntryFunction);
69886995
forest->add(ReproduceErrorTarget::create(
6989-
{ReachWithError::Reachable}, "",
6996+
{ReachWithError(ReachWithErrorType::Reachable)}, "",
69906997
ErrorLocation(kCallBlock->getFirstInstruction()), kCallBlock));
69916998
}
69926999
}
@@ -7461,9 +7468,9 @@ bool Executor::getSymbolicSolution(const ExecutionState &state, KTest &res) {
74617468
// we cannot be sure that an irreproducible state proves the presence of an
74627469
// error
74637470
if (uninitObjects.size() > 0 || state.symbolics.size() != symbolics.size()) {
7464-
state.error = ReachWithError::None;
7471+
state.error = ReachWithError(ReachWithErrorType::None);
74657472
} else if (FunctionCallReproduce != "" &&
7466-
state.error == ReachWithError::Reachable) {
7473+
state.error.type == ReachWithErrorType::Reachable) {
74677474
setHaltExecution(HaltExecution::ReachedTarget);
74687475
}
74697476

lib/Core/SpecialFunctionHandler.cpp

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1264,13 +1264,14 @@ void SpecialFunctionHandler::handleGetTaintRule(
12641264
return;
12651265
}
12661266

1267-
// ref<Expr> result = ConstantExpr::create(4, Expr::Int64);
1268-
// executor.bindLocal(target, state, result);
1267+
// TODO: now mock
1268+
ref<Expr> result = ConstantExpr::create(1, Expr::Int64);
1269+
executor.bindLocal(target, state, result);
12691270
}
12701271

1271-
void SpecialFunctionHandler::handleTaintHit(
1272-
klee::ExecutionState &state, klee::KInstruction *target,
1273-
std::vector<ref<Expr>> &arguments) {
1272+
void SpecialFunctionHandler::handleTaintHit(klee::ExecutionState &state,
1273+
klee::KInstruction *target,
1274+
std::vector<ref<Expr>> &arguments) {
12741275
if (arguments.size() != 1) {
12751276
executor.terminateStateOnUserError(
12761277
state, "Incorrect number of arguments to klee_taint_hit(size_t)");
@@ -1283,8 +1284,5 @@ void SpecialFunctionHandler::handleTaintHit(
12831284
executor.terminateStateOnUserError(
12841285
state, "Incorrect argument 0 to klee_taint_hit(size_t)");
12851286
}
1286-
1287-
klee_warning("!!!: %s\n", arguments[0]->toString().c_str());
1288-
12891287
executor.terminateStateOnTargetTaintError(state, rule);
12901288
}

lib/Core/TargetedExecutionManager.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -552,7 +552,8 @@ bool TargetedExecutionManager::reportTruePositive(ExecutionState &state,
552552

553553
atLeastOneReported = true;
554554
assert(!errorTarget->isReported);
555-
if (errorTarget->isThatError(ReachWithError::Reachable)) {
555+
if (errorTarget->isThatError(
556+
ReachWithError(ReachWithErrorType::Reachable))) {
556557
klee_warning("100.00%% %s Reachable at trace %s", getErrorString(error),
557558
errorTarget->getId().c_str());
558559
} else {

lib/Module/SarifReport.cpp

Lines changed: 56 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -55,76 +55,78 @@ tryConvertRuleJson(const std::string &ruleId, const std::string &toolName,
5555
const std::optional<Message> &errorMessage) {
5656
if (toolName == "SecB") {
5757
if ("NullDereference" == ruleId) {
58-
return {ReachWithError::MustBeNullPointerException};
58+
return {ReachWithError(ReachWithErrorType::MustBeNullPointerException)};
5959
} else if ("CheckAfterDeref" == ruleId) {
60-
return {ReachWithError::NullCheckAfterDerefException};
60+
return {ReachWithError(ReachWithErrorType::NullCheckAfterDerefException)};
6161
} else if ("DoubleFree" == ruleId) {
62-
return {ReachWithError::DoubleFree};
62+
return {ReachWithError(ReachWithErrorType::DoubleFree)};
6363
} else if ("UseAfterFree" == ruleId) {
64-
return {ReachWithError::UseAfterFree};
64+
return {ReachWithError(ReachWithErrorType::UseAfterFree)};
6565
} else if ("Reached" == ruleId) {
66-
return {ReachWithError::Reachable};
66+
return {ReachWithError(ReachWithErrorType::Reachable)};
6767
} else {
68-
return {};
68+
return {ReachWithError(ReachWithErrorType::MaybeTaint, ruleId)};
6969
}
7070
} else if (toolName == "clang") {
7171
if ("core.NullDereference" == ruleId) {
72-
return {ReachWithError::MayBeNullPointerException,
73-
ReachWithError::MustBeNullPointerException};
72+
return {ReachWithError(ReachWithErrorType::MayBeNullPointerException),
73+
ReachWithError(ReachWithErrorType::MustBeNullPointerException)};
7474
} else if ("unix.Malloc" == ruleId) {
7575
if (errorMessage.has_value()) {
7676
if (errorMessage->text == "Attempt to free released memory") {
77-
return {ReachWithError::DoubleFree};
77+
return {ReachWithError(ReachWithErrorType::DoubleFree)};
7878
} else if (errorMessage->text == "Use of memory after it is freed") {
79-
return {ReachWithError::UseAfterFree};
79+
return {ReachWithError(ReachWithErrorType::UseAfterFree)};
8080
} else {
8181
return {};
8282
}
8383
} else {
84-
return {ReachWithError::UseAfterFree, ReachWithError::DoubleFree};
84+
return {ReachWithError(ReachWithErrorType::UseAfterFree),
85+
ReachWithError(ReachWithErrorType::DoubleFree)};
8586
}
8687
} else if ("core.Reach" == ruleId) {
87-
return {ReachWithError::Reachable};
88+
return {ReachWithError(ReachWithErrorType::Reachable)};
8889
} else {
89-
return {};
90+
return {ReachWithError(ReachWithErrorType::MaybeTaint, ruleId)};
9091
}
9192
} else if (toolName == "CppCheck") {
9293
if ("nullPointer" == ruleId || "ctunullpointer" == ruleId) {
93-
return {ReachWithError::MayBeNullPointerException,
94-
ReachWithError::MustBeNullPointerException}; // TODO: check it out
94+
return {
95+
ReachWithError(ReachWithErrorType::MayBeNullPointerException),
96+
ReachWithError(
97+
ReachWithErrorType::MustBeNullPointerException)}; // TODO: check
98+
// it out
9599
} else if ("doubleFree" == ruleId) {
96-
return {ReachWithError::DoubleFree};
100+
return {ReachWithError(ReachWithErrorType::DoubleFree)};
97101
} else {
98-
return {};
102+
return {ReachWithError(ReachWithErrorType::MaybeTaint, ruleId)};
99103
}
100104
} else if (toolName == "Infer") {
101105
if ("NULL_DEREFERENCE" == ruleId || "NULLPTR_DEREFERENCE" == ruleId) {
102-
return {ReachWithError::MayBeNullPointerException,
103-
ReachWithError::MustBeNullPointerException}; // TODO: check it out
106+
return {
107+
ReachWithError(ReachWithErrorType::MayBeNullPointerException),
108+
ReachWithError(
109+
ReachWithErrorType::MustBeNullPointerException)}; // TODO: check
110+
// it out
104111
} else if ("USE_AFTER_DELETE" == ruleId || "USE_AFTER_FREE" == ruleId) {
105-
return {ReachWithError::UseAfterFree, ReachWithError::DoubleFree};
112+
return {ReachWithError(ReachWithErrorType::UseAfterFree),
113+
ReachWithError(ReachWithErrorType::DoubleFree)};
106114
} else {
107-
return {};
115+
return {ReachWithError(ReachWithErrorType::MaybeTaint, ruleId)};
108116
}
109117
} else if (toolName == "Cooddy") {
110118
if ("NULL.DEREF" == ruleId || "NULL.UNTRUSTED.DEREF" == ruleId) {
111-
return {ReachWithError::MayBeNullPointerException,
112-
ReachWithError::MustBeNullPointerException};
119+
return {ReachWithError(ReachWithErrorType::MayBeNullPointerException),
120+
ReachWithError(ReachWithErrorType::MustBeNullPointerException)};
113121
} else if ("MEM.DOUBLE.FREE" == ruleId) {
114-
return {ReachWithError::DoubleFree};
122+
return {ReachWithError(ReachWithErrorType::DoubleFree)};
115123
} else if ("MEM.USE.FREE" == ruleId) {
116-
return {ReachWithError::UseAfterFree};
117-
} else if ("SV.STR.FMT.TAINT" == ruleId) {
118-
return {ReachWithError::TaintFormatString};
119-
} else if ("TAINT.SDE" == ruleId) {
120-
return {ReachWithError::TaintSensitiveData};
121-
} else if ("TAINT.STRING.CLI" == ruleId) {
122-
return {ReachWithError::TaintExecute};
124+
return {ReachWithError(ReachWithErrorType::UseAfterFree)};
123125
} else {
124-
return {};
126+
return {ReachWithError(ReachWithErrorType::MaybeTaint, ruleId)};
125127
}
126128
} else {
127-
return {};
129+
return {ReachWithError(ReachWithErrorType::MaybeTaint, ruleId)};
128130
}
129131
}
130132

@@ -133,7 +135,7 @@ std::optional<Result> tryConvertResultJson(const ResultJson &resultJson,
133135
const std::string &id) {
134136
std::vector<ReachWithError> errors = {};
135137
if (!resultJson.ruleId.has_value()) {
136-
errors = {ReachWithError::Reachable};
138+
errors = {ReachWithError(ReachWithErrorType::Reachable)};
137139
} else {
138140
errors =
139141
tryConvertRuleJson(*resultJson.ruleId, toolName, resultJson.message);
@@ -190,8 +192,27 @@ static const char *ReachWithErrorNames[] = {
190192
"None",
191193
};
192194

195+
ReachWithError::ReachWithError(ReachWithErrorType type,
196+
std::optional<std::string> data)
197+
: type(type), data(std::move(data)) {}
198+
199+
// Two errors are equal when their categories match; for MaybeTaint the
// payload must match as well. For every other category `data` is ignored.
bool ReachWithError::operator==(const ReachWithError &other) const {
  if (type != other.type) {
    return false;
  }
  const bool payloadMatters = (type == ReachWithErrorType::MaybeTaint);
  return !payloadMatters || data == other.data;
}
205+
206+
// Inequality is defined strictly as the negation of operator==, so the two
// can never disagree.
bool ReachWithError::operator!=(const ReachWithError &other) const {
  const bool same = operator==(other);
  return !same;
}
209+
210+
// Strict weak ordering that agrees with operator==.
//
// Errors are ordered by category first. Within MaybeTaint the payload breaks
// the tie, because operator== treats MaybeTaint values with different `data`
// as distinct; ordering on `type` alone would make such values equivalent
// under `<` and they would collapse into one key in an ordered container or
// after sort + unique. For every other category `data` is ignored, exactly as
// operator== ignores it.
bool ReachWithError::operator<(const ReachWithError &other) const {
  if (type != other.type) {
    return type < other.type;
  }
  // Same category: only MaybeTaint distinguishes values by payload.
  // std::optional orders an empty optional before any engaged one.
  return type == ReachWithErrorType::MaybeTaint && data < other.data;
}
213+
193214
// Returns a static, human-readable name for the category of `error`.
//
// The name is looked up in ReachWithErrorNames by the numeric value of
// `error.type`. The lookup is bounds-checked: if the table has no entry for
// the category (NOTE(review): the table appears to end at "None", while
// MaybeTaint is declared after None — confirm), a fallback literal is
// returned instead of reading past the end of the array.
// The returned pointer refers to static storage and must not be freed.
const char *getErrorString(ReachWithError error) {
  constexpr auto nameCount =
      sizeof(ReachWithErrorNames) / sizeof(ReachWithErrorNames[0]);
  const auto index = static_cast<unsigned>(error.type);
  if (index < nameCount) {
    return ReachWithErrorNames[index];
  }
  return error.type == ReachWithErrorType::MaybeTaint ? "MaybeTaint"
                                                      : "Unknown";
}
196217

197218
std::string getErrorsString(const std::vector<ReachWithError> &errors) {

0 commit comments

Comments
 (0)