Skip to content

Commit e7e6e33

Browse files
Refactor to use graph traversal
1 parent eb6c79d commit e7e6e33

File tree

1 file changed

+147
-141
lines changed

1 file changed

+147
-141
lines changed

src/passes/GlobalEffects.cpp

Lines changed: 147 additions & 141 deletions
Original file line numberDiff line numberDiff line change
@@ -22,168 +22,172 @@
2222
#include "ir/effects.h"
2323
#include "ir/module-utils.h"
2424
#include "pass.h"
25+
#include "support/hash.h"
2526
#include "support/unique_deferring_queue.h"
2627
#include "wasm.h"
2728

2829
namespace wasm {
2930

30-
struct GenerateGlobalEffects : public Pass {
31-
void run(Module* module) override {
32-
// First, we do a scan of each function to see what effects they have,
33-
// including which functions they call directly (so that we can compute
34-
// transitive effects later).
35-
36-
struct FuncInfo {
37-
// Effects in this function.
38-
std::optional<EffectAnalyzer> effects;
39-
40-
// Directly-called functions from this function.
41-
std::unordered_set<Name> calledFunctions;
42-
};
43-
44-
ModuleUtils::ParallelFunctionAnalysis<FuncInfo> analysis(
45-
*module, [&](Function* func, FuncInfo& funcInfo) {
46-
if (func->imported()) {
47-
// Imports can do anything, so we need to assume the worst anyhow,
48-
// which is the same as not specifying any effects for them in the
49-
// map (which we do by not setting funcInfo.effects).
50-
return;
51-
}
52-
53-
// Gather the effects.
54-
funcInfo.effects.emplace(getPassOptions(), *module, func);
55-
56-
if (funcInfo.effects->calls) {
57-
// There are calls in this function, which we will analyze in detail.
58-
// Clear the |calls| field first, and we'll handle calls of all sorts
59-
// below.
60-
funcInfo.effects->calls = false;
61-
62-
// Clear throws as well, as we are "forgetting" calls right now, and
63-
// want to forget their throwing effect as well. If we see something
64-
// else that throws, below, then we'll note that there.
65-
funcInfo.effects->throws_ = false;
66-
67-
struct CallScanner
68-
: public PostWalker<CallScanner,
69-
UnifiedExpressionVisitor<CallScanner>> {
70-
Module& wasm;
71-
PassOptions& options;
72-
FuncInfo& funcInfo;
73-
74-
CallScanner(Module& wasm, PassOptions& options, FuncInfo& funcInfo)
75-
: wasm(wasm), options(options), funcInfo(funcInfo) {}
76-
77-
void visitExpression(Expression* curr) {
78-
ShallowEffectAnalyzer effects(options, wasm, curr);
79-
if (auto* call = curr->dynCast<Call>()) {
80-
// Note the direct call.
81-
funcInfo.calledFunctions.insert(call->target);
82-
} else if (effects.calls) {
83-
// This is an indirect call of some sort, so we must assume the
84-
// worst. To do so, clear the effects, which indicates nothing
85-
// is known (so anything is possible).
86-
// TODO: We could group effects by function type etc.
87-
funcInfo.effects.reset();
88-
} else {
89-
// No call here, but update throwing if we see it. (Only do so,
90-
// however, if we have effects; if we cleared it - see before -
91-
// then we assume the worst anyhow, and have nothing to update.)
92-
if (effects.throws_ && funcInfo.effects) {
93-
funcInfo.effects->throws_ = true;
94-
}
31+
namespace {
32+
33+
struct FuncInfo {
34+
// Effects in this function.
35+
std::optional<EffectAnalyzer> effects;
36+
37+
// Directly-called functions from this function.
38+
std::unordered_set<Name> calledFunctions;
39+
};
40+
41+
std::map<Function*, FuncInfo> analyzeFuncs(Module& module,
42+
const PassOptions& passOptions) {
43+
ModuleUtils::ParallelFunctionAnalysis<FuncInfo> analysis(
44+
module, [&](Function* func, FuncInfo& funcInfo) {
45+
if (func->imported()) {
46+
// Imports can do anything, so we need to assume the worst anyhow,
47+
// which is the same as not specifying any effects for them in the
48+
// map (which we do by not setting funcInfo.effects).
49+
return;
50+
}
51+
52+
// Gather the effects.
53+
funcInfo.effects.emplace(passOptions, module, func);
54+
55+
if (funcInfo.effects->calls) {
56+
// There are calls in this function, which we will analyze in detail.
57+
// Clear the |calls| field first, and we'll handle calls of all sorts
58+
// below.
59+
funcInfo.effects->calls = false;
60+
61+
// Clear throws as well, as we are "forgetting" calls right now, and
62+
// want to forget their throwing effect as well. If we see something
63+
// else that throws, below, then we'll note that there.
64+
funcInfo.effects->throws_ = false;
65+
66+
struct CallScanner
67+
: public PostWalker<CallScanner,
68+
UnifiedExpressionVisitor<CallScanner>> {
69+
Module& wasm;
70+
const PassOptions& options;
71+
FuncInfo& funcInfo;
72+
73+
CallScanner(Module& wasm,
74+
const PassOptions& options,
75+
FuncInfo& funcInfo)
76+
: wasm(wasm), options(options), funcInfo(funcInfo) {}
77+
78+
void visitExpression(Expression* curr) {
79+
ShallowEffectAnalyzer effects(options, wasm, curr);
80+
if (auto* call = curr->dynCast<Call>()) {
81+
// Note the direct call.
82+
funcInfo.calledFunctions.insert(call->target);
83+
} else if (effects.calls) {
84+
// This is an indirect call of some sort, so we must assume the
85+
// worst. To do so, clear the effects, which indicates nothing
86+
// is known (so anything is possible).
87+
// TODO: We could group effects by function type etc.
88+
funcInfo.effects.reset();
89+
} else {
90+
// No call here, but update throwing if we see it. (Only do so,
91+
// however, if we have effects; if we cleared it - see before -
92+
// then we assume the worst anyhow, and have nothing to update.)
93+
if (effects.throws_ && funcInfo.effects) {
94+
funcInfo.effects->throws_ = true;
9595
}
9696
}
97-
};
98-
CallScanner scanner(*module, getPassOptions(), funcInfo);
99-
scanner.walkFunction(func);
100-
}
101-
});
102-
103-
// Compute the transitive closure of effects. To do so, first construct for
104-
// each function a list of the functions that it is called by (so we need to
105-
// propagate its effects to them), and then we'll construct the closure of
106-
// that.
107-
//
108-
// callers[foo] = [func that calls foo, another func that calls foo, ..]
109-
//
110-
std::unordered_map<Name, std::unordered_set<Name>> callers;
111-
112-
// Our work queue contains info about a new call pair: a call from a caller
113-
// to a called function, that is information we then apply and propagate.
114-
using CallPair = std::pair<Name, Name>; // { caller, called }
115-
UniqueDeferredQueue<CallPair> work;
116-
for (auto& [func, info] : analysis.map) {
117-
for (auto& called : info.calledFunctions) {
118-
work.push({func->name, called});
97+
}
98+
};
99+
CallScanner scanner(module, passOptions, funcInfo);
100+
scanner.walkFunction(func);
119101
}
102+
});
103+
104+
return std::move(analysis.map);
105+
}
106+
107+
// Propagate effects from callees to callers transitively
108+
// e.g. if A -> B -> C (A calls B which calls C)
109+
// Then B inherits effects from C and A inherits effects from both B and C.
110+
void propagateEffects(
111+
const Module& module,
112+
const std::unordered_map<Name, std::unordered_set<Name>>& in,
113+
std::map<Function*, FuncInfo>& funcInfos) {
114+
115+
std::unordered_set<std::pair<Name, Name>> processed;
116+
std::deque<std::pair<Name, Name>> work;
117+
118+
for (const auto& [callee, callers] : in) {
119+
for (const auto& caller : callers) {
120+
work.emplace_back(callee, caller);
121+
processed.emplace(callee, caller);
122+
}
123+
}
124+
125+
auto propagate = [&](Name callee, Name caller) {
126+
auto& callerEffects = funcInfos.at(module.getFunction(caller)).effects;
127+
const auto& calleeEffects =
128+
funcInfos.at(module.getFunction(callee)).effects;
129+
if (!callerEffects) {
130+
return;
131+
}
132+
133+
if (!calleeEffects) {
134+
callerEffects.reset();
135+
return;
120136
}
121137

122-
// Compute the transitive closure of the call graph, that is, fill out
123-
// |callers| so that it contains the list of all callers - even through a
124-
// chain - of each function.
125-
while (!work.empty()) {
126-
auto [caller, called] = work.pop();
127-
128-
// We must not already have an entry for this call (that would imply we
129-
// are doing wasted work).
130-
assert(!callers[called].contains(caller));
131-
132-
// Apply the new call information.
133-
callers[called].insert(caller);
134-
135-
// We just learned that |caller| calls |called|. It also calls
136-
// transitively, which we need to propagate to all places unaware of that
137-
// information yet.
138-
//
139-
// caller => called => called by called
140-
//
141-
auto& calledInfo = analysis.map[module->getFunction(called)];
142-
for (auto calledByCalled : calledInfo.calledFunctions) {
143-
if (!callers[calledByCalled].contains(caller)) {
144-
work.push({caller, calledByCalled});
145-
}
138+
callerEffects->mergeIn(*calleeEffects);
139+
};
140+
141+
while (!work.empty()) {
142+
auto [callee, caller] = work.back();
143+
work.pop_back();
144+
145+
if (callee == caller) {
146+
auto& callerEffects = funcInfos.at(module.getFunction(caller)).effects;
147+
if (callerEffects) {
148+
callerEffects->trap = true;
146149
}
147150
}
148151

149-
// Now that we have transitively propagated all static calls, apply that
150-
// information. First, apply infinite recursion: if a function can call
151-
// itself then it might recurse infinitely, which we consider an effect (a
152-
// trap).
153-
for (auto& [func, info] : analysis.map) {
154-
if (callers[func->name].contains(func->name)) {
155-
if (info.effects) {
156-
info.effects->trap = true;
157-
}
152+
// Even if nothing changed, we still need to keep traversing the callers
153+
// to look for a potential cycle which adds a trap affect on the above
154+
// lines.
155+
propagate(callee, caller);
156+
157+
const auto& callerCallers = in.find(caller);
158+
if (callerCallers == in.end()) {
159+
continue;
160+
}
161+
162+
for (const Name& callerCaller : callerCallers->second) {
163+
if (processed.contains({callee, callerCaller})) {
164+
continue;
158165
}
166+
167+
processed.emplace(callee, callerCaller);
168+
work.emplace_back(callee, callerCaller);
159169
}
170+
}
171+
}
172+
173+
struct GenerateGlobalEffects : public Pass {
174+
void run(Module* module) override {
175+
std::map<Function*, FuncInfo> funcInfos =
176+
analyzeFuncs(*module, getPassOptions());
160177

161-
// Next, apply function effects to their callers.
162-
for (auto& [func, info] : analysis.map) {
163-
auto& funcEffects = info.effects;
164-
165-
for (auto& caller : callers[func->name]) {
166-
auto& callerEffects = analysis.map[module->getFunction(caller)].effects;
167-
if (!callerEffects) {
168-
// Nothing is known for the caller, which is already the worst case.
169-
continue;
170-
}
171-
172-
if (!funcEffects) {
173-
// Nothing is known for the called function, which means nothing is
174-
// known for the caller either.
175-
callerEffects.reset();
176-
continue;
177-
}
178-
179-
// Add func's effects to the caller.
180-
callerEffects->mergeIn(*funcEffects);
178+
// callee : caller
179+
std::unordered_map<Name, std::unordered_set<Name>> callers;
180+
for (const auto& [func, info] : funcInfos) {
181+
for (const auto& callee : info.calledFunctions) {
182+
callers[callee].insert(func->name);
181183
}
182184
}
183185

186+
propagateEffects(*module, callers, funcInfos);
187+
184188
// Generate the final data, starting from a blank slate where nothing is
185189
// known.
186-
for (auto& [func, info] : analysis.map) {
190+
for (auto& [func, info] : funcInfos) {
187191
func->effects.reset();
188192
if (!info.effects) {
189193
continue;
@@ -202,6 +206,8 @@ struct DiscardGlobalEffects : public Pass {
202206
}
203207
};
204208

209+
} // namespace
210+
205211
Pass* createGenerateGlobalEffectsPass() { return new GenerateGlobalEffects(); }
206212

207213
Pass* createDiscardGlobalEffectsPass() { return new DiscardGlobalEffects(); }

0 commit comments

Comments
 (0)