aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp')
-rw-r--r--llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp162
1 files changed, 124 insertions, 38 deletions
diff --git a/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp b/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
index 487a0a4a97f7..d33258642365 100644
--- a/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
+++ b/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
@@ -58,7 +58,6 @@
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
-#include "llvm/ADT/Triple.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/BasicAliasAnalysis.h"
@@ -84,9 +83,6 @@
#include "llvm/IR/Metadata.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/ModuleSummaryIndexYAML.h"
-#include "llvm/InitializePasses.h"
-#include "llvm/Pass.h"
-#include "llvm/PassRegistry.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Errc.h"
@@ -94,6 +90,7 @@
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/GlobPattern.h"
#include "llvm/Support/MathExtras.h"
+#include "llvm/TargetParser/Triple.h"
#include "llvm/Transforms/IPO.h"
#include "llvm/Transforms/IPO/FunctionAttrs.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
@@ -259,7 +256,7 @@ wholeprogramdevirt::findLowestOffset(ArrayRef<VirtualCallTarget> Targets,
if (I < B.size())
BitsUsed |= B[I];
if (BitsUsed != 0xff)
- return (MinByte + I) * 8 + countTrailingZeros(uint8_t(~BitsUsed));
+ return (MinByte + I) * 8 + llvm::countr_zero(uint8_t(~BitsUsed));
}
} else {
// Find a free (Size/8) byte region in each member of Used.
@@ -313,9 +310,10 @@ void wholeprogramdevirt::setAfterReturnValues(
}
}
-VirtualCallTarget::VirtualCallTarget(Function *Fn, const TypeMemberInfo *TM)
+VirtualCallTarget::VirtualCallTarget(GlobalValue *Fn, const TypeMemberInfo *TM)
: Fn(Fn), TM(TM),
- IsBigEndian(Fn->getParent()->getDataLayout().isBigEndian()), WasDevirt(false) {}
+ IsBigEndian(Fn->getParent()->getDataLayout().isBigEndian()),
+ WasDevirt(false) {}
namespace {
@@ -379,6 +377,7 @@ namespace {
// conditions
// 1) All summaries are live.
// 2) All function summaries indicate it's unreachable
+// 3) There is no non-function with the same GUID (which is rare)
bool mustBeUnreachableFunction(ValueInfo TheFnVI) {
if ((!TheFnVI) || TheFnVI.getSummaryList().empty()) {
// Returns false if ValueInfo is absent, or the summary list is empty
@@ -391,12 +390,13 @@ bool mustBeUnreachableFunction(ValueInfo TheFnVI) {
// In general either all summaries should be live or all should be dead.
if (!Summary->isLive())
return false;
- if (auto *FS = dyn_cast<FunctionSummary>(Summary.get())) {
+ if (auto *FS = dyn_cast<FunctionSummary>(Summary->getBaseObject())) {
if (!FS->fflags().MustBeUnreachable)
return false;
}
- // Do nothing if a non-function has the same GUID (which is rare).
- // This is correct since non-function summaries are not relevant.
+ // Be conservative if a non-function has the same GUID (which is rare).
+ else
+ return false;
}
// All function summaries are live and all of them agree that the function is
// unreachble.
@@ -567,6 +567,10 @@ struct DevirtModule {
// optimize a call more than once.
SmallPtrSet<CallBase *, 8> OptimizedCalls;
+ // Store calls that had their ptrauth bundle removed. They are to be deleted
+ // at the end of the optimization.
+ SmallVector<CallBase *, 8> CallsWithPtrAuthBundleRemoved;
+
// This map keeps track of the number of "unsafe" uses of a loaded function
// pointer. The key is the associated llvm.type.test intrinsic call generated
// by this pass. An unsafe use is one that calls the loaded function pointer
@@ -761,7 +765,7 @@ PreservedAnalyses WholeProgramDevirtPass::run(Module &M,
return FAM.getResult<DominatorTreeAnalysis>(F);
};
if (UseCommandLine) {
- if (DevirtModule::runForTesting(M, AARGetter, OREGetter, LookupDomTree))
+ if (!DevirtModule::runForTesting(M, AARGetter, OREGetter, LookupDomTree))
return PreservedAnalyses::all();
return PreservedAnalyses::none();
}
@@ -892,8 +896,7 @@ static Error checkCombinedSummaryForTesting(ModuleSummaryIndex *Summary) {
// DevirtIndex::run, not to DevirtModule::run used by opt/runForTesting.
const auto &ModPaths = Summary->modulePaths();
if (ClSummaryAction != PassSummaryAction::Import &&
- ModPaths.find(ModuleSummaryIndex::getRegularLTOModuleName()) ==
- ModPaths.end())
+ !ModPaths.contains(ModuleSummaryIndex::getRegularLTOModuleName()))
return createStringError(
errc::invalid_argument,
"combined summary should contain Regular LTO module");
@@ -958,7 +961,7 @@ void DevirtModule::buildTypeIdentifierMap(
std::vector<VTableBits> &Bits,
DenseMap<Metadata *, std::set<TypeMemberInfo>> &TypeIdMap) {
DenseMap<GlobalVariable *, VTableBits *> GVToBits;
- Bits.reserve(M.getGlobalList().size());
+ Bits.reserve(M.global_size());
SmallVector<MDNode *, 2> Types;
for (GlobalVariable &GV : M.globals()) {
Types.clear();
@@ -1003,11 +1006,17 @@ bool DevirtModule::tryFindVirtualCallTargets(
return false;
Constant *Ptr = getPointerAtOffset(TM.Bits->GV->getInitializer(),
- TM.Offset + ByteOffset, M);
+ TM.Offset + ByteOffset, M, TM.Bits->GV);
if (!Ptr)
return false;
- auto Fn = dyn_cast<Function>(Ptr->stripPointerCasts());
+ auto C = Ptr->stripPointerCasts();
+ // Make sure this is a function or alias to a function.
+ auto Fn = dyn_cast<Function>(C);
+ auto A = dyn_cast<GlobalAlias>(C);
+ if (!Fn && A)
+ Fn = dyn_cast<Function>(A->getAliasee());
+
if (!Fn)
return false;
@@ -1024,7 +1033,11 @@ bool DevirtModule::tryFindVirtualCallTargets(
if (mustBeUnreachableFunction(Fn, ExportSummary))
continue;
- TargetsForSlot.push_back({Fn, &TM});
+ // Save the symbol used in the vtable to use as the devirtualization
+ // target.
+ auto GV = dyn_cast<GlobalValue>(C);
+ assert(GV);
+ TargetsForSlot.push_back({GV, &TM});
}
// Give up if we couldn't find any targets.
@@ -1156,6 +1169,14 @@ void DevirtModule::applySingleImplDevirt(VTableSlotInfo &SlotInfo,
// !callees metadata.
CB.setMetadata(LLVMContext::MD_prof, nullptr);
CB.setMetadata(LLVMContext::MD_callees, nullptr);
+ if (CB.getCalledOperand() &&
+ CB.getOperandBundle(LLVMContext::OB_ptrauth)) {
+ auto *NewCS =
+ CallBase::removeOperandBundle(&CB, LLVMContext::OB_ptrauth, &CB);
+ CB.replaceAllUsesWith(NewCS);
+ // Schedule for deletion at the end of pass run.
+ CallsWithPtrAuthBundleRemoved.push_back(&CB);
+ }
}
// This use is no longer unsafe.
@@ -1205,7 +1226,7 @@ bool DevirtModule::trySingleImplDevirt(
WholeProgramDevirtResolution *Res) {
// See if the program contains a single implementation of this virtual
// function.
- Function *TheFn = TargetsForSlot[0].Fn;
+ auto *TheFn = TargetsForSlot[0].Fn;
for (auto &&Target : TargetsForSlot)
if (TheFn != Target.Fn)
return false;
@@ -1379,9 +1400,20 @@ void DevirtModule::applyICallBranchFunnel(VTableSlotInfo &SlotInfo,
IsExported = true;
if (CSInfo.AllCallSitesDevirted)
return;
+
+ std::map<CallBase *, CallBase *> CallBases;
for (auto &&VCallSite : CSInfo.CallSites) {
CallBase &CB = VCallSite.CB;
+ if (CallBases.find(&CB) != CallBases.end()) {
+ // When finding devirtualizable calls, it's possible to find the same
+ // vtable passed to multiple llvm.type.test or llvm.type.checked.load
+ // calls, which can cause duplicate call sites to be recorded in
+ // [Const]CallSites. If we've already found one of these
+ // call instances, just ignore it. It will be replaced later.
+ continue;
+ }
+
// Jump tables are only profitable if the retpoline mitigation is enabled.
Attribute FSAttr = CB.getCaller()->getFnAttribute("target-features");
if (!FSAttr.isValid() ||
@@ -1428,8 +1460,7 @@ void DevirtModule::applyICallBranchFunnel(VTableSlotInfo &SlotInfo,
AttributeList::get(M.getContext(), Attrs.getFnAttrs(),
Attrs.getRetAttrs(), NewArgAttrs));
- CB.replaceAllUsesWith(NewCS);
- CB.eraseFromParent();
+ CallBases[&CB] = NewCS;
// This use is no longer unsafe.
if (VCallSite.NumUnsafeUses)
@@ -1439,6 +1470,11 @@ void DevirtModule::applyICallBranchFunnel(VTableSlotInfo &SlotInfo,
// retpoline mitigation, which would mean that they are lowered to
// llvm.type.test and therefore require an llvm.type.test resolution for the
// type identifier.
+
+ std::for_each(CallBases.begin(), CallBases.end(), [](auto &CBs) {
+ CBs.first->replaceAllUsesWith(CBs.second);
+ CBs.first->eraseFromParent();
+ });
};
Apply(SlotInfo.CSInfo);
for (auto &P : SlotInfo.ConstCSInfo)
@@ -1451,23 +1487,30 @@ bool DevirtModule::tryEvaluateFunctionsWithArgs(
// Evaluate each function and store the result in each target's RetVal
// field.
for (VirtualCallTarget &Target : TargetsForSlot) {
- if (Target.Fn->arg_size() != Args.size() + 1)
+ // TODO: Skip for now if the vtable symbol was an alias to a function,
+ // need to evaluate whether it would be correct to analyze the aliasee
+ // function for this optimization.
+ auto Fn = dyn_cast<Function>(Target.Fn);
+ if (!Fn)
+ return false;
+
+ if (Fn->arg_size() != Args.size() + 1)
return false;
Evaluator Eval(M.getDataLayout(), nullptr);
SmallVector<Constant *, 2> EvalArgs;
EvalArgs.push_back(
- Constant::getNullValue(Target.Fn->getFunctionType()->getParamType(0)));
+ Constant::getNullValue(Fn->getFunctionType()->getParamType(0)));
for (unsigned I = 0; I != Args.size(); ++I) {
- auto *ArgTy = dyn_cast<IntegerType>(
- Target.Fn->getFunctionType()->getParamType(I + 1));
+ auto *ArgTy =
+ dyn_cast<IntegerType>(Fn->getFunctionType()->getParamType(I + 1));
if (!ArgTy)
return false;
EvalArgs.push_back(ConstantInt::get(ArgTy, Args[I]));
}
Constant *RetVal;
- if (!Eval.EvaluateFunction(Target.Fn, RetVal, EvalArgs) ||
+ if (!Eval.EvaluateFunction(Fn, RetVal, EvalArgs) ||
!isa<ConstantInt>(RetVal))
return false;
Target.RetVal = cast<ConstantInt>(RetVal)->getZExtValue();
@@ -1675,8 +1718,7 @@ void DevirtModule::applyVirtualConstProp(CallSiteInfo &CSInfo, StringRef FnName,
Call.replaceAndErase("virtual-const-prop-1-bit", FnName, RemarksEnabled,
OREGetter, IsBitSet);
} else {
- Value *ValAddr = B.CreateBitCast(Addr, RetType->getPointerTo());
- Value *Val = B.CreateLoad(RetType, ValAddr);
+ Value *Val = B.CreateLoad(RetType, Addr);
NumVirtConstProp++;
Call.replaceAndErase("virtual-const-prop", FnName, RemarksEnabled,
OREGetter, Val);
@@ -1688,8 +1730,14 @@ void DevirtModule::applyVirtualConstProp(CallSiteInfo &CSInfo, StringRef FnName,
bool DevirtModule::tryVirtualConstProp(
MutableArrayRef<VirtualCallTarget> TargetsForSlot, VTableSlotInfo &SlotInfo,
WholeProgramDevirtResolution *Res, VTableSlot Slot) {
+ // TODO: Skip for now if the vtable symbol was an alias to a function,
+ // need to evaluate whether it would be correct to analyze the aliasee
+ // function for this optimization.
+ auto Fn = dyn_cast<Function>(TargetsForSlot[0].Fn);
+ if (!Fn)
+ return false;
// This only works if the function returns an integer.
- auto RetType = dyn_cast<IntegerType>(TargetsForSlot[0].Fn->getReturnType());
+ auto RetType = dyn_cast<IntegerType>(Fn->getReturnType());
if (!RetType)
return false;
unsigned BitWidth = RetType->getBitWidth();
@@ -1707,11 +1755,18 @@ bool DevirtModule::tryVirtualConstProp(
// inline all implementations of the virtual function into each call site,
// rather than using function attributes to perform local optimization.
for (VirtualCallTarget &Target : TargetsForSlot) {
- if (Target.Fn->isDeclaration() ||
- !computeFunctionBodyMemoryAccess(*Target.Fn, AARGetter(*Target.Fn))
+ // TODO: Skip for now if the vtable symbol was an alias to a function,
+ // need to evaluate whether it would be correct to analyze the aliasee
+ // function for this optimization.
+ auto Fn = dyn_cast<Function>(Target.Fn);
+ if (!Fn)
+ return false;
+
+ if (Fn->isDeclaration() ||
+ !computeFunctionBodyMemoryAccess(*Fn, AARGetter(*Fn))
.doesNotAccessMemory() ||
- Target.Fn->arg_empty() || !Target.Fn->arg_begin()->use_empty() ||
- Target.Fn->getReturnType() != RetType)
+ Fn->arg_empty() || !Fn->arg_begin()->use_empty() ||
+ Fn->getReturnType() != RetType)
return false;
}
@@ -1947,9 +2002,23 @@ void DevirtModule::scanTypeCheckedLoadUsers(Function *TypeCheckedLoadFunc) {
// This helps avoid unnecessary spills.
IRBuilder<> LoadB(
(LoadedPtrs.size() == 1 && !HasNonCallUses) ? LoadedPtrs[0] : CI);
- Value *GEP = LoadB.CreateGEP(Int8Ty, Ptr, Offset);
- Value *GEPPtr = LoadB.CreateBitCast(GEP, PointerType::getUnqual(Int8PtrTy));
- Value *LoadedValue = LoadB.CreateLoad(Int8PtrTy, GEPPtr);
+
+ Value *LoadedValue = nullptr;
+ if (TypeCheckedLoadFunc->getIntrinsicID() ==
+ Intrinsic::type_checked_load_relative) {
+ Value *GEP = LoadB.CreateGEP(Int8Ty, Ptr, Offset);
+ Value *GEPPtr = LoadB.CreateBitCast(GEP, PointerType::getUnqual(Int32Ty));
+ LoadedValue = LoadB.CreateLoad(Int32Ty, GEPPtr);
+ LoadedValue = LoadB.CreateSExt(LoadedValue, IntPtrTy);
+ GEP = LoadB.CreatePtrToInt(GEP, IntPtrTy);
+ LoadedValue = LoadB.CreateAdd(GEP, LoadedValue);
+ LoadedValue = LoadB.CreateIntToPtr(LoadedValue, Int8PtrTy);
+ } else {
+ Value *GEP = LoadB.CreateGEP(Int8Ty, Ptr, Offset);
+ Value *GEPPtr =
+ LoadB.CreateBitCast(GEP, PointerType::getUnqual(Int8PtrTy));
+ LoadedValue = LoadB.CreateLoad(Int8PtrTy, GEPPtr);
+ }
for (Instruction *LoadedPtr : LoadedPtrs) {
LoadedPtr->replaceAllUsesWith(LoadedValue);
@@ -2130,6 +2199,8 @@ bool DevirtModule::run() {
M.getFunction(Intrinsic::getName(Intrinsic::type_test));
Function *TypeCheckedLoadFunc =
M.getFunction(Intrinsic::getName(Intrinsic::type_checked_load));
+ Function *TypeCheckedLoadRelativeFunc =
+ M.getFunction(Intrinsic::getName(Intrinsic::type_checked_load_relative));
Function *AssumeFunc = M.getFunction(Intrinsic::getName(Intrinsic::assume));
// Normally if there are no users of the devirtualization intrinsics in the
@@ -2138,7 +2209,9 @@ bool DevirtModule::run() {
if (!ExportSummary &&
(!TypeTestFunc || TypeTestFunc->use_empty() || !AssumeFunc ||
AssumeFunc->use_empty()) &&
- (!TypeCheckedLoadFunc || TypeCheckedLoadFunc->use_empty()))
+ (!TypeCheckedLoadFunc || TypeCheckedLoadFunc->use_empty()) &&
+ (!TypeCheckedLoadRelativeFunc ||
+ TypeCheckedLoadRelativeFunc->use_empty()))
return false;
// Rebuild type metadata into a map for easy lookup.
@@ -2152,6 +2225,9 @@ bool DevirtModule::run() {
if (TypeCheckedLoadFunc)
scanTypeCheckedLoadUsers(TypeCheckedLoadFunc);
+ if (TypeCheckedLoadRelativeFunc)
+ scanTypeCheckedLoadUsers(TypeCheckedLoadRelativeFunc);
+
if (ImportSummary) {
for (auto &S : CallSlots)
importResolution(S.first, S.second);
@@ -2219,7 +2295,7 @@ bool DevirtModule::run() {
// For each (type, offset) pair:
bool DidVirtualConstProp = false;
- std::map<std::string, Function*> DevirtTargets;
+ std::map<std::string, GlobalValue *> DevirtTargets;
for (auto &S : CallSlots) {
// Search each of the members of the type identifier for the virtual
// function implementation at offset S.first.ByteOffset, and add to
@@ -2274,7 +2350,14 @@ bool DevirtModule::run() {
if (RemarksEnabled) {
// Generate remarks for each devirtualized function.
for (const auto &DT : DevirtTargets) {
- Function *F = DT.second;
+ GlobalValue *GV = DT.second;
+ auto F = dyn_cast<Function>(GV);
+ if (!F) {
+ auto A = dyn_cast<GlobalAlias>(GV);
+ assert(A && isa<Function>(A->getAliasee()));
+ F = dyn_cast<Function>(A->getAliasee());
+ assert(F);
+ }
using namespace ore;
OREGetter(F).emit(OptimizationRemark(DEBUG_TYPE, "Devirtualized", F)
@@ -2299,6 +2382,9 @@ bool DevirtModule::run() {
for (GlobalVariable &GV : M.globals())
GV.eraseMetadata(LLVMContext::MD_vcall_visibility);
+ for (auto *CI : CallsWithPtrAuthBundleRemoved)
+ CI->eraseFromParent();
+
return true;
}