diff options
Diffstat (limited to 'lib/Target/WebAssembly/WebAssemblyFixFunctionBitcasts.cpp')
-rw-r--r-- | lib/Target/WebAssembly/WebAssemblyFixFunctionBitcasts.cpp | 79 |
1 files changed, 45 insertions, 34 deletions
diff --git a/lib/Target/WebAssembly/WebAssemblyFixFunctionBitcasts.cpp b/lib/Target/WebAssembly/WebAssemblyFixFunctionBitcasts.cpp index 1a416520f97d..b7fc65401fc4 100644 --- a/lib/Target/WebAssembly/WebAssemblyFixFunctionBitcasts.cpp +++ b/lib/Target/WebAssembly/WebAssemblyFixFunctionBitcasts.cpp @@ -1,9 +1,8 @@ //===-- WebAssemblyFixFunctionBitcasts.cpp - Fix function bitcasts --------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// /// @@ -36,11 +35,6 @@ using namespace llvm; #define DEBUG_TYPE "wasm-fix-function-bitcasts" -static cl::opt<bool> - TemporaryWorkarounds("wasm-temporary-workarounds", - cl::desc("Apply certain temporary workarounds"), - cl::init(true), cl::Hidden); - namespace { class FixFunctionBitcasts final : public ModulePass { StringRef getPassName() const override { @@ -70,12 +64,12 @@ ModulePass *llvm::createWebAssemblyFixFunctionBitcasts() { // Recursively descend the def-use lists from V to find non-bitcast users of // bitcasts of V. -static void FindUses(Value *V, Function &F, +static void findUses(Value *V, Function &F, SmallVectorImpl<std::pair<Use *, Function *>> &Uses, SmallPtrSetImpl<Constant *> &ConstantBCs) { for (Use &U : V->uses()) { - if (BitCastOperator *BC = dyn_cast<BitCastOperator>(U.getUser())) - FindUses(BC, F, Uses, ConstantBCs); + if (auto *BC = dyn_cast<BitCastOperator>(U.getUser())) + findUses(BC, F, Uses, ConstantBCs); else if (U.get()->getType() != F.getType()) { CallSite CS(U.getUser()); if (!CS) @@ -87,8 +81,8 @@ static void FindUses(Value *V, Function &F, continue; if (isa<Constant>(U.get())) { // Only add constant bitcasts to the list once; they get RAUW'd - auto c = ConstantBCs.insert(cast<Constant>(U.get())); - if (!c.second) + auto C = ConstantBCs.insert(cast<Constant>(U.get())); + if (!C.second) continue; } Uses.push_back(std::make_pair(&U, &F)); @@ -119,7 +113,7 @@ static void FindUses(Value *V, Function &F, // For bitcasts that involve struct types we don't know at this stage if they // would be equivalent at the wasm level and so we can't know if we need to // generate a wrapper. -static Function *CreateWrapper(Function *F, FunctionType *Ty) { +static Function *createWrapper(Function *F, FunctionType *Ty) { Module *M = F->getParent(); Function *Wrapper = Function::Create(Ty, Function::PrivateLinkage, @@ -157,11 +151,11 @@ static Function *CreateWrapper(Function *F, FunctionType *Ty) { BB->getInstList().push_back(PtrCast); Args.push_back(PtrCast); } else if (ArgType->isStructTy() || ParamType->isStructTy()) { - LLVM_DEBUG(dbgs() << "CreateWrapper: struct param type in bitcast: " + LLVM_DEBUG(dbgs() << "createWrapper: struct param type in bitcast: " << F->getName() << "\n"); WrapperNeeded = false; } else { - LLVM_DEBUG(dbgs() << "CreateWrapper: arg type mismatch calling: " + LLVM_DEBUG(dbgs() << "createWrapper: arg type mismatch calling: " << F->getName() << "\n"); LLVM_DEBUG(dbgs() << "Arg[" << Args.size() << "] Expected: " << *ParamType << " Got: " << *ArgType << "\n"); @@ -197,11 +191,11 @@ static Function *CreateWrapper(Function *F, FunctionType *Ty) { BB->getInstList().push_back(Cast); ReturnInst::Create(M->getContext(), Cast, BB); } else if (RtnType->isStructTy() || ExpectedRtnType->isStructTy()) { - LLVM_DEBUG(dbgs() << "CreateWrapper: struct return type in bitcast: " + LLVM_DEBUG(dbgs() << "createWrapper: struct return type in bitcast: " << F->getName() << "\n"); WrapperNeeded = false; } else { - LLVM_DEBUG(dbgs() << "CreateWrapper: return type mismatch calling: " + LLVM_DEBUG(dbgs() << "createWrapper: return type mismatch calling: " << F->getName() << "\n"); LLVM_DEBUG(dbgs() << "Expected: " << *ExpectedRtnType << " Got: " << *RtnType << "\n"); @@ -218,15 +212,26 @@ static Function *CreateWrapper(Function *F, FunctionType *Ty) { new UnreachableInst(M->getContext(), BB); Wrapper->setName(F->getName() + "_bitcast_invalid"); } else if (!WrapperNeeded) { - LLVM_DEBUG(dbgs() << "CreateWrapper: no wrapper needed: " << F->getName() + LLVM_DEBUG(dbgs() << "createWrapper: no wrapper needed: " << F->getName() << "\n"); Wrapper->eraseFromParent(); return nullptr; } - LLVM_DEBUG(dbgs() << "CreateWrapper: " << F->getName() << "\n"); + LLVM_DEBUG(dbgs() << "createWrapper: " << F->getName() << "\n"); return Wrapper; } +// Test whether a main function with type FuncTy should be rewritten to have +// type MainTy. +static bool shouldFixMainFunction(FunctionType *FuncTy, FunctionType *MainTy) { + // Only fix the main function if it's the standard zero-arg form. That way, + // the standard cases will work as expected, and users will see signature + // mismatches from the linker for non-standard cases. + return FuncTy->getReturnType() == MainTy->getReturnType() && + FuncTy->getNumParams() == 0 && + !FuncTy->isVarArg(); +} + bool FixFunctionBitcasts::runOnModule(Module &M) { LLVM_DEBUG(dbgs() << "********** Fix Function Bitcasts **********\n"); @@ -237,27 +242,27 @@ bool FixFunctionBitcasts::runOnModule(Module &M) { // Collect all the places that need wrappers. for (Function &F : M) { - FindUses(&F, F, Uses, ConstantBCs); + findUses(&F, F, Uses, ConstantBCs); // If we have a "main" function, and its type isn't // "int main(int argc, char *argv[])", create an artificial call with it // bitcasted to that type so that we generate a wrapper for it, so that // the C runtime can call it. - if (!TemporaryWorkarounds && !F.isDeclaration() && F.getName() == "main") { + if (F.getName() == "main") { Main = &F; LLVMContext &C = M.getContext(); Type *MainArgTys[] = {Type::getInt32Ty(C), PointerType::get(Type::getInt8PtrTy(C), 0)}; FunctionType *MainTy = FunctionType::get(Type::getInt32Ty(C), MainArgTys, /*isVarArg=*/false); - if (F.getFunctionType() != MainTy) { + if (shouldFixMainFunction(F.getFunctionType(), MainTy)) { LLVM_DEBUG(dbgs() << "Found `main` function with incorrect type: " << *F.getFunctionType() << "\n"); Value *Args[] = {UndefValue::get(MainArgTys[0]), UndefValue::get(MainArgTys[1])}; Value *Casted = ConstantExpr::getBitCast(Main, PointerType::get(MainTy, 0)); - CallMain = CallInst::Create(Casted, Args, "call_main"); + CallMain = CallInst::Create(MainTy, Casted, Args, "call_main"); Use *UseMain = &CallMain->getOperandUse(2); Uses.push_back(std::make_pair(UseMain, &F)); } @@ -269,8 +274,8 @@ bool FixFunctionBitcasts::runOnModule(Module &M) { for (auto &UseFunc : Uses) { Use *U = UseFunc.first; Function *F = UseFunc.second; - PointerType *PTy = cast<PointerType>(U->get()->getType()); - FunctionType *Ty = dyn_cast<FunctionType>(PTy->getElementType()); + auto *PTy = cast<PointerType>(U->get()->getType()); + auto *Ty = dyn_cast<FunctionType>(PTy->getElementType()); // If the function is casted to something like i8* as a "generic pointer" // to be later casted to something else, we can't generate a wrapper for it. @@ -280,7 +285,7 @@ bool FixFunctionBitcasts::runOnModule(Module &M) { auto Pair = Wrappers.insert(std::make_pair(std::make_pair(F, Ty), nullptr)); if (Pair.second) - Pair.first->second = CreateWrapper(F, Ty); + Pair.first->second = createWrapper(F, Ty); Function *Wrapper = Pair.first->second; if (!Wrapper) @@ -296,14 +301,20 @@ bool FixFunctionBitcasts::runOnModule(Module &M) { // one that gets called from startup. if (CallMain) { Main->setName("__original_main"); - Function *MainWrapper = + auto *MainWrapper = cast<Function>(CallMain->getCalledValue()->stripPointerCasts()); - MainWrapper->setName("main"); - MainWrapper->setLinkage(Main->getLinkage()); - MainWrapper->setVisibility(Main->getVisibility()); - Main->setLinkage(Function::PrivateLinkage); - Main->setVisibility(Function::DefaultVisibility); delete CallMain; + if (Main->isDeclaration()) { + // The wrapper is not needed in this case as we don't need to export + // it to anyone else. + MainWrapper->eraseFromParent(); + } else { + // Otherwise give the wrapper the same linkage as the original main + // function, so that it can be called from the same places. + MainWrapper->setName("main"); + MainWrapper->setLinkage(Main->getLinkage()); + MainWrapper->setVisibility(Main->getVisibility()); + } } return true; |