diff options
Diffstat (limited to 'llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp')
-rw-r--r-- | llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp | 162 |
1 files changed, 124 insertions, 38 deletions
diff --git a/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp b/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp index 487a0a4a97f7..d33258642365 100644 --- a/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp +++ b/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp @@ -58,7 +58,6 @@ #include "llvm/ADT/MapVector.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" -#include "llvm/ADT/Triple.h" #include "llvm/ADT/iterator_range.h" #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/BasicAliasAnalysis.h" @@ -84,9 +83,6 @@ #include "llvm/IR/Metadata.h" #include "llvm/IR/Module.h" #include "llvm/IR/ModuleSummaryIndexYAML.h" -#include "llvm/InitializePasses.h" -#include "llvm/Pass.h" -#include "llvm/PassRegistry.h" #include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Errc.h" @@ -94,6 +90,7 @@ #include "llvm/Support/FileSystem.h" #include "llvm/Support/GlobPattern.h" #include "llvm/Support/MathExtras.h" +#include "llvm/TargetParser/Triple.h" #include "llvm/Transforms/IPO.h" #include "llvm/Transforms/IPO/FunctionAttrs.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" @@ -259,7 +256,7 @@ wholeprogramdevirt::findLowestOffset(ArrayRef<VirtualCallTarget> Targets, if (I < B.size()) BitsUsed |= B[I]; if (BitsUsed != 0xff) - return (MinByte + I) * 8 + countTrailingZeros(uint8_t(~BitsUsed)); + return (MinByte + I) * 8 + llvm::countr_zero(uint8_t(~BitsUsed)); } } else { // Find a free (Size/8) byte region in each member of Used. @@ -313,9 +310,10 @@ void wholeprogramdevirt::setAfterReturnValues( } } -VirtualCallTarget::VirtualCallTarget(Function *Fn, const TypeMemberInfo *TM) +VirtualCallTarget::VirtualCallTarget(GlobalValue *Fn, const TypeMemberInfo *TM) : Fn(Fn), TM(TM), - IsBigEndian(Fn->getParent()->getDataLayout().isBigEndian()), WasDevirt(false) {} + IsBigEndian(Fn->getParent()->getDataLayout().isBigEndian()), + WasDevirt(false) {} namespace { @@ -379,6 +377,7 @@ namespace { // conditions // 1) All summaries are live. // 2) All function summaries indicate it's unreachable +// 3) There is no non-function with the same GUID (which is rare) bool mustBeUnreachableFunction(ValueInfo TheFnVI) { if ((!TheFnVI) || TheFnVI.getSummaryList().empty()) { // Returns false if ValueInfo is absent, or the summary list is empty @@ -391,12 +390,13 @@ bool mustBeUnreachableFunction(ValueInfo TheFnVI) { // In general either all summaries should be live or all should be dead. if (!Summary->isLive()) return false; - if (auto *FS = dyn_cast<FunctionSummary>(Summary.get())) { + if (auto *FS = dyn_cast<FunctionSummary>(Summary->getBaseObject())) { if (!FS->fflags().MustBeUnreachable) return false; } - // Do nothing if a non-function has the same GUID (which is rare). - // This is correct since non-function summaries are not relevant. + // Be conservative if a non-function has the same GUID (which is rare). + else + return false; } // All function summaries are live and all of them agree that the function is // unreachble. @@ -567,6 +567,10 @@ struct DevirtModule { // optimize a call more than once. SmallPtrSet<CallBase *, 8> OptimizedCalls; + // Store calls that had their ptrauth bundle removed. They are to be deleted + // at the end of the optimization. + SmallVector<CallBase *, 8> CallsWithPtrAuthBundleRemoved; + // This map keeps track of the number of "unsafe" uses of a loaded function // pointer. The key is the associated llvm.type.test intrinsic call generated // by this pass. An unsafe use is one that calls the loaded function pointer @@ -761,7 +765,7 @@ PreservedAnalyses WholeProgramDevirtPass::run(Module &M, return FAM.getResult<DominatorTreeAnalysis>(F); }; if (UseCommandLine) { - if (DevirtModule::runForTesting(M, AARGetter, OREGetter, LookupDomTree)) + if (!DevirtModule::runForTesting(M, AARGetter, OREGetter, LookupDomTree)) return PreservedAnalyses::all(); return PreservedAnalyses::none(); } @@ -892,8 +896,7 @@ static Error checkCombinedSummaryForTesting(ModuleSummaryIndex *Summary) { // DevirtIndex::run, not to DevirtModule::run used by opt/runForTesting. const auto &ModPaths = Summary->modulePaths(); if (ClSummaryAction != PassSummaryAction::Import && - ModPaths.find(ModuleSummaryIndex::getRegularLTOModuleName()) == - ModPaths.end()) + !ModPaths.contains(ModuleSummaryIndex::getRegularLTOModuleName())) return createStringError( errc::invalid_argument, "combined summary should contain Regular LTO module"); @@ -958,7 +961,7 @@ void DevirtModule::buildTypeIdentifierMap( std::vector<VTableBits> &Bits, DenseMap<Metadata *, std::set<TypeMemberInfo>> &TypeIdMap) { DenseMap<GlobalVariable *, VTableBits *> GVToBits; - Bits.reserve(M.getGlobalList().size()); + Bits.reserve(M.global_size()); SmallVector<MDNode *, 2> Types; for (GlobalVariable &GV : M.globals()) { Types.clear(); @@ -1003,11 +1006,17 @@ bool DevirtModule::tryFindVirtualCallTargets( return false; Constant *Ptr = getPointerAtOffset(TM.Bits->GV->getInitializer(), - TM.Offset + ByteOffset, M); + TM.Offset + ByteOffset, M, TM.Bits->GV); if (!Ptr) return false; - auto Fn = dyn_cast<Function>(Ptr->stripPointerCasts()); + auto C = Ptr->stripPointerCasts(); + // Make sure this is a function or alias to a function. + auto Fn = dyn_cast<Function>(C); + auto A = dyn_cast<GlobalAlias>(C); + if (!Fn && A) + Fn = dyn_cast<Function>(A->getAliasee()); + if (!Fn) return false; @@ -1024,7 +1033,11 @@ bool DevirtModule::tryFindVirtualCallTargets( if (mustBeUnreachableFunction(Fn, ExportSummary)) continue; - TargetsForSlot.push_back({Fn, &TM}); + // Save the symbol used in the vtable to use as the devirtualization + // target. + auto GV = dyn_cast<GlobalValue>(C); + assert(GV); + TargetsForSlot.push_back({GV, &TM}); } // Give up if we couldn't find any targets. @@ -1156,6 +1169,14 @@ void DevirtModule::applySingleImplDevirt(VTableSlotInfo &SlotInfo, // !callees metadata. CB.setMetadata(LLVMContext::MD_prof, nullptr); CB.setMetadata(LLVMContext::MD_callees, nullptr); + if (CB.getCalledOperand() && + CB.getOperandBundle(LLVMContext::OB_ptrauth)) { + auto *NewCS = + CallBase::removeOperandBundle(&CB, LLVMContext::OB_ptrauth, &CB); + CB.replaceAllUsesWith(NewCS); + // Schedule for deletion at the end of pass run. + CallsWithPtrAuthBundleRemoved.push_back(&CB); + } } // This use is no longer unsafe. @@ -1205,7 +1226,7 @@ bool DevirtModule::trySingleImplDevirt( WholeProgramDevirtResolution *Res) { // See if the program contains a single implementation of this virtual // function. - Function *TheFn = TargetsForSlot[0].Fn; + auto *TheFn = TargetsForSlot[0].Fn; for (auto &&Target : TargetsForSlot) if (TheFn != Target.Fn) return false; @@ -1379,9 +1400,20 @@ void DevirtModule::applyICallBranchFunnel(VTableSlotInfo &SlotInfo, IsExported = true; if (CSInfo.AllCallSitesDevirted) return; + + std::map<CallBase *, CallBase *> CallBases; for (auto &&VCallSite : CSInfo.CallSites) { CallBase &CB = VCallSite.CB; + if (CallBases.find(&CB) != CallBases.end()) { + // When finding devirtualizable calls, it's possible to find the same + // vtable passed to multiple llvm.type.test or llvm.type.checked.load + // calls, which can cause duplicate call sites to be recorded in + // [Const]CallSites. If we've already found one of these + // call instances, just ignore it. It will be replaced later. + continue; + } + // Jump tables are only profitable if the retpoline mitigation is enabled. Attribute FSAttr = CB.getCaller()->getFnAttribute("target-features"); if (!FSAttr.isValid() || @@ -1428,8 +1460,7 @@ void DevirtModule::applyICallBranchFunnel(VTableSlotInfo &SlotInfo, AttributeList::get(M.getContext(), Attrs.getFnAttrs(), Attrs.getRetAttrs(), NewArgAttrs)); - CB.replaceAllUsesWith(NewCS); - CB.eraseFromParent(); + CallBases[&CB] = NewCS; // This use is no longer unsafe. if (VCallSite.NumUnsafeUses) @@ -1439,6 +1470,11 @@ void DevirtModule::applyICallBranchFunnel(VTableSlotInfo &SlotInfo, // retpoline mitigation, which would mean that they are lowered to // llvm.type.test and therefore require an llvm.type.test resolution for the // type identifier. + + std::for_each(CallBases.begin(), CallBases.end(), [](auto &CBs) { + CBs.first->replaceAllUsesWith(CBs.second); + CBs.first->eraseFromParent(); + }); }; Apply(SlotInfo.CSInfo); for (auto &P : SlotInfo.ConstCSInfo) @@ -1451,23 +1487,30 @@ bool DevirtModule::tryEvaluateFunctionsWithArgs( // Evaluate each function and store the result in each target's RetVal // field. for (VirtualCallTarget &Target : TargetsForSlot) { - if (Target.Fn->arg_size() != Args.size() + 1) + // TODO: Skip for now if the vtable symbol was an alias to a function, + // need to evaluate whether it would be correct to analyze the aliasee + // function for this optimization. + auto Fn = dyn_cast<Function>(Target.Fn); + if (!Fn) + return false; + + if (Fn->arg_size() != Args.size() + 1) return false; Evaluator Eval(M.getDataLayout(), nullptr); SmallVector<Constant *, 2> EvalArgs; EvalArgs.push_back( - Constant::getNullValue(Target.Fn->getFunctionType()->getParamType(0))); + Constant::getNullValue(Fn->getFunctionType()->getParamType(0))); for (unsigned I = 0; I != Args.size(); ++I) { - auto *ArgTy = dyn_cast<IntegerType>( - Target.Fn->getFunctionType()->getParamType(I + 1)); + auto *ArgTy = + dyn_cast<IntegerType>(Fn->getFunctionType()->getParamType(I + 1)); if (!ArgTy) return false; EvalArgs.push_back(ConstantInt::get(ArgTy, Args[I])); } Constant *RetVal; - if (!Eval.EvaluateFunction(Target.Fn, RetVal, EvalArgs) || + if (!Eval.EvaluateFunction(Fn, RetVal, EvalArgs) || !isa<ConstantInt>(RetVal)) return false; Target.RetVal = cast<ConstantInt>(RetVal)->getZExtValue(); @@ -1675,8 +1718,7 @@ void DevirtModule::applyVirtualConstProp(CallSiteInfo &CSInfo, StringRef FnName, Call.replaceAndErase("virtual-const-prop-1-bit", FnName, RemarksEnabled, OREGetter, IsBitSet); } else { - Value *ValAddr = B.CreateBitCast(Addr, RetType->getPointerTo()); - Value *Val = B.CreateLoad(RetType, ValAddr); + Value *Val = B.CreateLoad(RetType, Addr); NumVirtConstProp++; Call.replaceAndErase("virtual-const-prop", FnName, RemarksEnabled, OREGetter, Val); @@ -1688,8 +1730,14 @@ void DevirtModule::applyVirtualConstProp(CallSiteInfo &CSInfo, StringRef FnName, bool DevirtModule::tryVirtualConstProp( MutableArrayRef<VirtualCallTarget> TargetsForSlot, VTableSlotInfo &SlotInfo, WholeProgramDevirtResolution *Res, VTableSlot Slot) { + // TODO: Skip for now if the vtable symbol was an alias to a function, + // need to evaluate whether it would be correct to analyze the aliasee + // function for this optimization. + auto Fn = dyn_cast<Function>(TargetsForSlot[0].Fn); + if (!Fn) + return false; // This only works if the function returns an integer. - auto RetType = dyn_cast<IntegerType>(TargetsForSlot[0].Fn->getReturnType()); + auto RetType = dyn_cast<IntegerType>(Fn->getReturnType()); if (!RetType) return false; unsigned BitWidth = RetType->getBitWidth(); @@ -1707,11 +1755,18 @@ bool DevirtModule::tryVirtualConstProp( // inline all implementations of the virtual function into each call site, // rather than using function attributes to perform local optimization. for (VirtualCallTarget &Target : TargetsForSlot) { - if (Target.Fn->isDeclaration() || - !computeFunctionBodyMemoryAccess(*Target.Fn, AARGetter(*Target.Fn)) + // TODO: Skip for now if the vtable symbol was an alias to a function, + // need to evaluate whether it would be correct to analyze the aliasee + // function for this optimization. + auto Fn = dyn_cast<Function>(Target.Fn); + if (!Fn) + return false; + + if (Fn->isDeclaration() || + !computeFunctionBodyMemoryAccess(*Fn, AARGetter(*Fn)) .doesNotAccessMemory() || - Target.Fn->arg_empty() || !Target.Fn->arg_begin()->use_empty() || - Target.Fn->getReturnType() != RetType) + Fn->arg_empty() || !Fn->arg_begin()->use_empty() || + Fn->getReturnType() != RetType) return false; } @@ -1947,9 +2002,23 @@ void DevirtModule::scanTypeCheckedLoadUsers(Function *TypeCheckedLoadFunc) { // This helps avoid unnecessary spills. IRBuilder<> LoadB( (LoadedPtrs.size() == 1 && !HasNonCallUses) ? LoadedPtrs[0] : CI); - Value *GEP = LoadB.CreateGEP(Int8Ty, Ptr, Offset); - Value *GEPPtr = LoadB.CreateBitCast(GEP, PointerType::getUnqual(Int8PtrTy)); - Value *LoadedValue = LoadB.CreateLoad(Int8PtrTy, GEPPtr); + + Value *LoadedValue = nullptr; + if (TypeCheckedLoadFunc->getIntrinsicID() == + Intrinsic::type_checked_load_relative) { + Value *GEP = LoadB.CreateGEP(Int8Ty, Ptr, Offset); + Value *GEPPtr = LoadB.CreateBitCast(GEP, PointerType::getUnqual(Int32Ty)); + LoadedValue = LoadB.CreateLoad(Int32Ty, GEPPtr); + LoadedValue = LoadB.CreateSExt(LoadedValue, IntPtrTy); + GEP = LoadB.CreatePtrToInt(GEP, IntPtrTy); + LoadedValue = LoadB.CreateAdd(GEP, LoadedValue); + LoadedValue = LoadB.CreateIntToPtr(LoadedValue, Int8PtrTy); + } else { + Value *GEP = LoadB.CreateGEP(Int8Ty, Ptr, Offset); + Value *GEPPtr = + LoadB.CreateBitCast(GEP, PointerType::getUnqual(Int8PtrTy)); + LoadedValue = LoadB.CreateLoad(Int8PtrTy, GEPPtr); + } for (Instruction *LoadedPtr : LoadedPtrs) { LoadedPtr->replaceAllUsesWith(LoadedValue); @@ -2130,6 +2199,8 @@ bool DevirtModule::run() { M.getFunction(Intrinsic::getName(Intrinsic::type_test)); Function *TypeCheckedLoadFunc = M.getFunction(Intrinsic::getName(Intrinsic::type_checked_load)); + Function *TypeCheckedLoadRelativeFunc = + M.getFunction(Intrinsic::getName(Intrinsic::type_checked_load_relative)); Function *AssumeFunc = M.getFunction(Intrinsic::getName(Intrinsic::assume)); // Normally if there are no users of the devirtualization intrinsics in the @@ -2138,7 +2209,9 @@ bool DevirtModule::run() { if (!ExportSummary && (!TypeTestFunc || TypeTestFunc->use_empty() || !AssumeFunc || AssumeFunc->use_empty()) && - (!TypeCheckedLoadFunc || TypeCheckedLoadFunc->use_empty())) + (!TypeCheckedLoadFunc || TypeCheckedLoadFunc->use_empty()) && + (!TypeCheckedLoadRelativeFunc || + TypeCheckedLoadRelativeFunc->use_empty())) return false; // Rebuild type metadata into a map for easy lookup. @@ -2152,6 +2225,9 @@ bool DevirtModule::run() { if (TypeCheckedLoadFunc) scanTypeCheckedLoadUsers(TypeCheckedLoadFunc); + if (TypeCheckedLoadRelativeFunc) + scanTypeCheckedLoadUsers(TypeCheckedLoadRelativeFunc); + if (ImportSummary) { for (auto &S : CallSlots) importResolution(S.first, S.second); @@ -2219,7 +2295,7 @@ bool DevirtModule::run() { // For each (type, offset) pair: bool DidVirtualConstProp = false; - std::map<std::string, Function*> DevirtTargets; + std::map<std::string, GlobalValue *> DevirtTargets; for (auto &S : CallSlots) { // Search each of the members of the type identifier for the virtual // function implementation at offset S.first.ByteOffset, and add to @@ -2274,7 +2350,14 @@ bool DevirtModule::run() { if (RemarksEnabled) { // Generate remarks for each devirtualized function. for (const auto &DT : DevirtTargets) { - Function *F = DT.second; + GlobalValue *GV = DT.second; + auto F = dyn_cast<Function>(GV); + if (!F) { + auto A = dyn_cast<GlobalAlias>(GV); + assert(A && isa<Function>(A->getAliasee())); + F = dyn_cast<Function>(A->getAliasee()); + assert(F); + } using namespace ore; OREGetter(F).emit(OptimizationRemark(DEBUG_TYPE, "Devirtualized", F) @@ -2299,6 +2382,9 @@ bool DevirtModule::run() { for (GlobalVariable &GV : M.globals()) GV.eraseMetadata(LLVMContext::MD_vcall_visibility); + for (auto *CI : CallsWithPtrAuthBundleRemoved) + CI->eraseFromParent(); + return true; } |