about summary refs log tree commit diff
path: root/llvm/lib/Target/WebAssembly/WebAssemblyFixFunctionBitcasts.cpp
diff options
context:
space:
mode:
author Dimitry Andric <dim@FreeBSD.org> 2021-11-19 20:06:13 +0000
committer Dimitry Andric <dim@FreeBSD.org> 2021-11-19 20:06:13 +0000
commit c0981da47d5696fe36474fcf86b4ce03ae3ff818 (patch)
tree f42add1021b9f2ac6a69ac7cf6c4499962739a45 /llvm/lib/Target/WebAssembly/WebAssemblyFixFunctionBitcasts.cpp
parent 344a3780b2e33f6ca763666c380202b18aab72a3 (diff)
Diffstat (limited to 'llvm/lib/Target/WebAssembly/WebAssemblyFixFunctionBitcasts.cpp')
-rw-r--r-- llvm/lib/Target/WebAssembly/WebAssemblyFixFunctionBitcasts.cpp 54
1 files changed, 17 insertions, 37 deletions
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyFixFunctionBitcasts.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyFixFunctionBitcasts.cpp
index 7abb6fa8905c..2a4349e02f1b 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyFixFunctionBitcasts.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyFixFunctionBitcasts.cpp
@@ -64,29 +64,21 @@ ModulePass *llvm::createWebAssemblyFixFunctionBitcasts() {
// Recursively descend the def-use lists from V to find non-bitcast users of
// bitcasts of V.
static void findUses(Value *V, Function &F,
- SmallVectorImpl<std::pair<Use *, Function *>> &Uses,
- SmallPtrSetImpl<Constant *> &ConstantBCs) {
- for (Use &U : V->uses()) {
- if (auto *BC = dyn_cast<BitCastOperator>(U.getUser()))
- findUses(BC, F, Uses, ConstantBCs);
- else if (auto *A = dyn_cast<GlobalAlias>(U.getUser()))
- findUses(A, F, Uses, ConstantBCs);
- else if (U.get()->getType() != F.getType()) {
- CallBase *CB = dyn_cast<CallBase>(U.getUser());
- if (!CB)
- // Skip uses that aren't immediately called
- continue;
+ SmallVectorImpl<std::pair<CallBase *, Function *>> &Uses) {
+ for (User *U : V->users()) {
+ if (auto *BC = dyn_cast<BitCastOperator>(U))
+ findUses(BC, F, Uses);
+ else if (auto *A = dyn_cast<GlobalAlias>(U))
+ findUses(A, F, Uses);
+ else if (auto *CB = dyn_cast<CallBase>(U)) {
Value *Callee = CB->getCalledOperand();
if (Callee != V)
// Skip calls where the function isn't the callee
continue;
- if (isa<Constant>(U.get())) {
- // Only add constant bitcasts to the list once; they get RAUW'd
- auto C = ConstantBCs.insert(cast<Constant>(U.get()));
- if (!C.second)
- continue;
- }
- Uses.push_back(std::make_pair(&U, &F));
+ if (CB->getFunctionType() == F.getValueType())
+ // Skip calls whose type already matches F's; they need no wrapper
+ continue;
+ Uses.push_back(std::make_pair(CB, &F));
}
}
}
@@ -238,8 +230,7 @@ bool FixFunctionBitcasts::runOnModule(Module &M) {
Function *Main = nullptr;
CallInst *CallMain = nullptr;
- SmallVector<std::pair<Use *, Function *>, 0> Uses;
- SmallPtrSet<Constant *, 2> ConstantBCs;
+ SmallVector<std::pair<CallBase *, Function *>, 0> Uses;
// Collect all the places that need wrappers.
for (Function &F : M) {
@@ -247,7 +238,7 @@ bool FixFunctionBitcasts::runOnModule(Module &M) {
// bitcast type difference for swiftself and swifterror.
if (F.getCallingConv() == CallingConv::Swift)
continue;
- findUses(&F, F, Uses, ConstantBCs);
+ findUses(&F, F, Uses);
// If we have a "main" function, and its type isn't
// "int main(int argc, char *argv[])", create an artificial call with it
@@ -268,8 +259,7 @@ bool FixFunctionBitcasts::runOnModule(Module &M) {
Value *Casted =
ConstantExpr::getBitCast(Main, PointerType::get(MainTy, 0));
CallMain = CallInst::Create(MainTy, Casted, Args, "call_main");
- Use *UseMain = &CallMain->getOperandUse(2);
- Uses.push_back(std::make_pair(UseMain, &F));
+ Uses.push_back(std::make_pair(CallMain, &F));
}
}
}
@@ -277,16 +267,9 @@ bool FixFunctionBitcasts::runOnModule(Module &M) {
DenseMap<std::pair<Function *, FunctionType *>, Function *> Wrappers;
for (auto &UseFunc : Uses) {
- Use *U = UseFunc.first;
+ CallBase *CB = UseFunc.first;
Function *F = UseFunc.second;
- auto *PTy = cast<PointerType>(U->get()->getType());
- auto *Ty = dyn_cast<FunctionType>(PTy->getElementType());
-
- // If the function is casted to something like i8* as a "generic pointer"
- // to be later casted to something else, we can't generate a wrapper for it.
- // Just ignore such casts for now.
- if (!Ty)
- continue;
+ FunctionType *Ty = CB->getFunctionType();
auto Pair = Wrappers.insert(std::make_pair(std::make_pair(F, Ty), nullptr));
if (Pair.second)
@@ -296,10 +279,7 @@ bool FixFunctionBitcasts::runOnModule(Module &M) {
if (!Wrapper)
continue;
- if (isa<Constant>(U->get()))
- U->get()->replaceAllUsesWith(Wrapper);
- else
- U->set(Wrapper);
+ CB->setCalledOperand(Wrapper);
}
// If we created a wrapper for main, rename the wrapper so that it's the