diff options
Diffstat (limited to 'lib/Target/WebAssembly/WebAssemblyFixFunctionBitcasts.cpp')
| -rw-r--r-- | lib/Target/WebAssembly/WebAssemblyFixFunctionBitcasts.cpp | 79 | 
1 files changed, 45 insertions, 34 deletions
| diff --git a/lib/Target/WebAssembly/WebAssemblyFixFunctionBitcasts.cpp b/lib/Target/WebAssembly/WebAssemblyFixFunctionBitcasts.cpp index 1a416520f97d0..b7fc65401fc48 100644 --- a/lib/Target/WebAssembly/WebAssemblyFixFunctionBitcasts.cpp +++ b/lib/Target/WebAssembly/WebAssemblyFixFunctionBitcasts.cpp @@ -1,9 +1,8 @@  //===-- WebAssemblyFixFunctionBitcasts.cpp - Fix function bitcasts --------===//  // -//                     The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception  //  //===----------------------------------------------------------------------===//  /// @@ -36,11 +35,6 @@ using namespace llvm;  #define DEBUG_TYPE "wasm-fix-function-bitcasts" -static cl::opt<bool> -    TemporaryWorkarounds("wasm-temporary-workarounds", -                         cl::desc("Apply certain temporary workarounds"), -                         cl::init(true), cl::Hidden); -  namespace {  class FixFunctionBitcasts final : public ModulePass {    StringRef getPassName() const override { @@ -70,12 +64,12 @@ ModulePass *llvm::createWebAssemblyFixFunctionBitcasts() {  // Recursively descend the def-use lists from V to find non-bitcast users of  // bitcasts of V. -static void FindUses(Value *V, Function &F, +static void findUses(Value *V, Function &F,                       SmallVectorImpl<std::pair<Use *, Function *>> &Uses,                       SmallPtrSetImpl<Constant *> &ConstantBCs) {    for (Use &U : V->uses()) { -    if (BitCastOperator *BC = dyn_cast<BitCastOperator>(U.getUser())) -      FindUses(BC, F, Uses, ConstantBCs); +    if (auto *BC = dyn_cast<BitCastOperator>(U.getUser())) +      findUses(BC, F, Uses, ConstantBCs);      else if (U.get()->getType() != F.getType()) {        CallSite CS(U.getUser());        if (!CS) @@ -87,8 +81,8 @@ static void FindUses(Value *V, Function &F,          continue;        if (isa<Constant>(U.get())) {          // Only add constant bitcasts to the list once; they get RAUW'd -        auto c = ConstantBCs.insert(cast<Constant>(U.get())); -        if (!c.second) +        auto C = ConstantBCs.insert(cast<Constant>(U.get())); +        if (!C.second)            continue;        }        Uses.push_back(std::make_pair(&U, &F)); @@ -119,7 +113,7 @@ static void FindUses(Value *V, Function &F,  // For bitcasts that involve struct types we don't know at this stage if they  // would be equivalent at the wasm level and so we can't know if we need to  // generate a wrapper. -static Function *CreateWrapper(Function *F, FunctionType *Ty) { +static Function *createWrapper(Function *F, FunctionType *Ty) {    Module *M = F->getParent();    Function *Wrapper = Function::Create(Ty, Function::PrivateLinkage, @@ -157,11 +151,11 @@ static Function *CreateWrapper(Function *F, FunctionType *Ty) {          BB->getInstList().push_back(PtrCast);          Args.push_back(PtrCast);        } else if (ArgType->isStructTy() || ParamType->isStructTy()) { -        LLVM_DEBUG(dbgs() << "CreateWrapper: struct param type in bitcast: " +        LLVM_DEBUG(dbgs() << "createWrapper: struct param type in bitcast: "                            << F->getName() << "\n");          WrapperNeeded = false;        } else { -        LLVM_DEBUG(dbgs() << "CreateWrapper: arg type mismatch calling: " +        LLVM_DEBUG(dbgs() << "createWrapper: arg type mismatch calling: "                            << F->getName() << "\n");          LLVM_DEBUG(dbgs() << "Arg[" << Args.size() << "] Expected: "                            << *ParamType << " Got: " << *ArgType << "\n"); @@ -197,11 +191,11 @@ static Function *CreateWrapper(Function *F, FunctionType *Ty) {        BB->getInstList().push_back(Cast);        ReturnInst::Create(M->getContext(), Cast, BB);      } else if (RtnType->isStructTy() || ExpectedRtnType->isStructTy()) { -      LLVM_DEBUG(dbgs() << "CreateWrapper: struct return type in bitcast: " +      LLVM_DEBUG(dbgs() << "createWrapper: struct return type in bitcast: "                          << F->getName() << "\n");        WrapperNeeded = false;      } else { -      LLVM_DEBUG(dbgs() << "CreateWrapper: return type mismatch calling: " +      LLVM_DEBUG(dbgs() << "createWrapper: return type mismatch calling: "                          << F->getName() << "\n");        LLVM_DEBUG(dbgs() << "Expected: " << *ExpectedRtnType                          << " Got: " << *RtnType << "\n"); @@ -218,15 +212,26 @@ static Function *CreateWrapper(Function *F, FunctionType *Ty) {      new UnreachableInst(M->getContext(), BB);      Wrapper->setName(F->getName() + "_bitcast_invalid");    } else if (!WrapperNeeded) { -    LLVM_DEBUG(dbgs() << "CreateWrapper: no wrapper needed: " << F->getName() +    LLVM_DEBUG(dbgs() << "createWrapper: no wrapper needed: " << F->getName()                        << "\n");      Wrapper->eraseFromParent();      return nullptr;    } -  LLVM_DEBUG(dbgs() << "CreateWrapper: " << F->getName() << "\n"); +  LLVM_DEBUG(dbgs() << "createWrapper: " << F->getName() << "\n");    return Wrapper;  } +// Test whether a main function with type FuncTy should be rewritten to have +// type MainTy. +static bool shouldFixMainFunction(FunctionType *FuncTy, FunctionType *MainTy) { +  // Only fix the main function if it's the standard zero-arg form. That way, +  // the standard cases will work as expected, and users will see signature +  // mismatches from the linker for non-standard cases. +  return FuncTy->getReturnType() == MainTy->getReturnType() && +         FuncTy->getNumParams() == 0 && +         !FuncTy->isVarArg(); +} +  bool FixFunctionBitcasts::runOnModule(Module &M) {    LLVM_DEBUG(dbgs() << "********** Fix Function Bitcasts **********\n"); @@ -237,27 +242,27 @@ bool FixFunctionBitcasts::runOnModule(Module &M) {    // Collect all the places that need wrappers.    for (Function &F : M) { -    FindUses(&F, F, Uses, ConstantBCs); +    findUses(&F, F, Uses, ConstantBCs);      // If we have a "main" function, and its type isn't      // "int main(int argc, char *argv[])", create an artificial call with it      // bitcasted to that type so that we generate a wrapper for it, so that      // the C runtime can call it. -    if (!TemporaryWorkarounds && !F.isDeclaration() && F.getName() == "main") { +    if (F.getName() == "main") {        Main = &F;        LLVMContext &C = M.getContext();        Type *MainArgTys[] = {Type::getInt32Ty(C),                              PointerType::get(Type::getInt8PtrTy(C), 0)};        FunctionType *MainTy = FunctionType::get(Type::getInt32Ty(C), MainArgTys,                                                 /*isVarArg=*/false); -      if (F.getFunctionType() != MainTy) { +      if (shouldFixMainFunction(F.getFunctionType(), MainTy)) {          LLVM_DEBUG(dbgs() << "Found `main` function with incorrect type: "                            << *F.getFunctionType() << "\n");          Value *Args[] = {UndefValue::get(MainArgTys[0]),                           UndefValue::get(MainArgTys[1])};          Value *Casted =              ConstantExpr::getBitCast(Main, PointerType::get(MainTy, 0)); -        CallMain = CallInst::Create(Casted, Args, "call_main"); +        CallMain = CallInst::Create(MainTy, Casted, Args, "call_main");          Use *UseMain = &CallMain->getOperandUse(2);          Uses.push_back(std::make_pair(UseMain, &F));        } @@ -269,8 +274,8 @@ bool FixFunctionBitcasts::runOnModule(Module &M) {    for (auto &UseFunc : Uses) {      Use *U = UseFunc.first;      Function *F = UseFunc.second; -    PointerType *PTy = cast<PointerType>(U->get()->getType()); -    FunctionType *Ty = dyn_cast<FunctionType>(PTy->getElementType()); +    auto *PTy = cast<PointerType>(U->get()->getType()); +    auto *Ty = dyn_cast<FunctionType>(PTy->getElementType());      // If the function is casted to something like i8* as a "generic pointer"      // to be later casted to something else, we can't generate a wrapper for it. @@ -280,7 +285,7 @@ bool FixFunctionBitcasts::runOnModule(Module &M) {      auto Pair = Wrappers.insert(std::make_pair(std::make_pair(F, Ty), nullptr));      if (Pair.second) -      Pair.first->second = CreateWrapper(F, Ty); +      Pair.first->second = createWrapper(F, Ty);      Function *Wrapper = Pair.first->second;      if (!Wrapper) @@ -296,14 +301,20 @@ bool FixFunctionBitcasts::runOnModule(Module &M) {    // one that gets called from startup.    if (CallMain) {      Main->setName("__original_main"); -    Function *MainWrapper = +    auto *MainWrapper =          cast<Function>(CallMain->getCalledValue()->stripPointerCasts()); -    MainWrapper->setName("main"); -    MainWrapper->setLinkage(Main->getLinkage()); -    MainWrapper->setVisibility(Main->getVisibility()); -    Main->setLinkage(Function::PrivateLinkage); -    Main->setVisibility(Function::DefaultVisibility);      delete CallMain; +    if (Main->isDeclaration()) { +      // The wrapper is not needed in this case as we don't need to export +      // it to anyone else. +      MainWrapper->eraseFromParent(); +    } else { +      // Otherwise give the wrapper the same linkage as the original main +      // function, so that it can be called from the same places. +      MainWrapper->setName("main"); +      MainWrapper->setLinkage(Main->getLinkage()); +      MainWrapper->setVisibility(Main->getVisibility()); +    }    }    return true; | 
