aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp')
-rw-r--r--llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp422
1 files changed, 308 insertions, 114 deletions
diff --git a/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp b/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
index 5ccfb29b01a1..5a25f9857665 100644
--- a/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
+++ b/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
@@ -57,12 +57,14 @@
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/Triple.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/BasicAliasAnalysis.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/TypeMetadataUtils.h"
-#include "llvm/IR/CallSite.h"
+#include "llvm/Bitcode/BitcodeReader.h"
+#include "llvm/Bitcode/BitcodeWriter.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DebugLoc.h"
@@ -83,11 +85,12 @@
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/PassRegistry.h"
-#include "llvm/PassSupport.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Errc.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/FileSystem.h"
+#include "llvm/Support/GlobPattern.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Transforms/IPO.h"
#include "llvm/Transforms/IPO/FunctionAttrs.h"
@@ -115,12 +118,15 @@ static cl::opt<PassSummaryAction> ClSummaryAction(
static cl::opt<std::string> ClReadSummary(
"wholeprogramdevirt-read-summary",
- cl::desc("Read summary from given YAML file before running pass"),
+ cl::desc(
+ "Read summary from given bitcode or YAML file before running pass"),
cl::Hidden);
static cl::opt<std::string> ClWriteSummary(
"wholeprogramdevirt-write-summary",
- cl::desc("Write summary to given YAML file after running pass"),
+ cl::desc("Write summary to given bitcode or YAML file after running pass. "
+ "Output file format is deduced from extension: *.bc means writing "
+ "bitcode, otherwise YAML"),
cl::Hidden);
static cl::opt<unsigned>
@@ -134,6 +140,45 @@ static cl::opt<bool>
cl::init(false), cl::ZeroOrMore,
cl::desc("Print index-based devirtualization messages"));
+/// Provide a way to force enable whole program visibility in tests.
+/// This is needed to support legacy tests that don't contain
+/// !vcall_visibility metadata (the mere presence of type tests
+/// previously implied hidden visibility).
+cl::opt<bool>
+ WholeProgramVisibility("whole-program-visibility", cl::init(false),
+ cl::Hidden, cl::ZeroOrMore,
+ cl::desc("Enable whole program visibility"));
+
+/// Provide a way to force disable whole program for debugging or workarounds,
+/// when enabled via the linker.
+cl::opt<bool> DisableWholeProgramVisibility(
+ "disable-whole-program-visibility", cl::init(false), cl::Hidden,
+ cl::ZeroOrMore,
+ cl::desc("Disable whole program visibility (overrides enabling options)"));
+
+/// Provide a way to prevent certain functions from being devirtualized
+cl::list<std::string>
+ SkipFunctionNames("wholeprogramdevirt-skip",
+ cl::desc("Prevent function(s) from being devirtualized"),
+ cl::Hidden, cl::ZeroOrMore, cl::CommaSeparated);
+
+namespace {
+struct PatternList {
+ std::vector<GlobPattern> Patterns;
+ template <class T> void init(const T &StringList) {
+ for (const auto &S : StringList)
+ if (Expected<GlobPattern> Pat = GlobPattern::create(S))
+ Patterns.push_back(std::move(*Pat));
+ }
+ bool match(StringRef S) {
+ for (const GlobPattern &P : Patterns)
+ if (P.match(S))
+ return true;
+ return false;
+ }
+};
+} // namespace
+
// Find the minimum offset that we may store a value of size Size bits at. If
// IsAfter is set, look for an offset before the object, otherwise look for an
// offset after the object.
@@ -308,20 +353,20 @@ namespace {
// A virtual call site. VTable is the loaded virtual table pointer, and CS is
// the indirect virtual call.
struct VirtualCallSite {
- Value *VTable;
- CallSite CS;
+ Value *VTable = nullptr;
+ CallBase &CB;
// If non-null, this field points to the associated unsafe use count stored in
// the DevirtModule::NumUnsafeUsesForTypeTest map below. See the description
// of that field for details.
- unsigned *NumUnsafeUses;
+ unsigned *NumUnsafeUses = nullptr;
void
emitRemark(const StringRef OptName, const StringRef TargetName,
function_ref<OptimizationRemarkEmitter &(Function *)> OREGetter) {
- Function *F = CS.getCaller();
- DebugLoc DLoc = CS->getDebugLoc();
- BasicBlock *Block = CS.getParent();
+ Function *F = CB.getCaller();
+ DebugLoc DLoc = CB.getDebugLoc();
+ BasicBlock *Block = CB.getParent();
using namespace ore;
OREGetter(F).emit(OptimizationRemark(DEBUG_TYPE, OptName, DLoc, Block)
@@ -336,12 +381,12 @@ struct VirtualCallSite {
Value *New) {
if (RemarksEnabled)
emitRemark(OptName, TargetName, OREGetter);
- CS->replaceAllUsesWith(New);
- if (auto II = dyn_cast<InvokeInst>(CS.getInstruction())) {
- BranchInst::Create(II->getNormalDest(), CS.getInstruction());
+ CB.replaceAllUsesWith(New);
+ if (auto *II = dyn_cast<InvokeInst>(&CB)) {
+ BranchInst::Create(II->getNormalDest(), &CB);
II->getUnwindDest()->removePredecessor(II->getParent());
}
- CS->eraseFromParent();
+ CB.eraseFromParent();
// This use is no longer unsafe.
if (NumUnsafeUses)
--*NumUnsafeUses;
@@ -414,18 +459,18 @@ struct VTableSlotInfo {
// "this"), grouped by argument list.
std::map<std::vector<uint64_t>, CallSiteInfo> ConstCSInfo;
- void addCallSite(Value *VTable, CallSite CS, unsigned *NumUnsafeUses);
+ void addCallSite(Value *VTable, CallBase &CB, unsigned *NumUnsafeUses);
private:
- CallSiteInfo &findCallSiteInfo(CallSite CS);
+ CallSiteInfo &findCallSiteInfo(CallBase &CB);
};
-CallSiteInfo &VTableSlotInfo::findCallSiteInfo(CallSite CS) {
+CallSiteInfo &VTableSlotInfo::findCallSiteInfo(CallBase &CB) {
std::vector<uint64_t> Args;
- auto *CI = dyn_cast<IntegerType>(CS.getType());
- if (!CI || CI->getBitWidth() > 64 || CS.arg_empty())
+ auto *CBType = dyn_cast<IntegerType>(CB.getType());
+ if (!CBType || CBType->getBitWidth() > 64 || CB.arg_empty())
return CSInfo;
- for (auto &&Arg : make_range(CS.arg_begin() + 1, CS.arg_end())) {
+ for (auto &&Arg : make_range(CB.arg_begin() + 1, CB.arg_end())) {
auto *CI = dyn_cast<ConstantInt>(Arg);
if (!CI || CI->getBitWidth() > 64)
return CSInfo;
@@ -434,11 +479,11 @@ CallSiteInfo &VTableSlotInfo::findCallSiteInfo(CallSite CS) {
return ConstCSInfo[Args];
}
-void VTableSlotInfo::addCallSite(Value *VTable, CallSite CS,
+void VTableSlotInfo::addCallSite(Value *VTable, CallBase &CB,
unsigned *NumUnsafeUses) {
- auto &CSI = findCallSiteInfo(CS);
+ auto &CSI = findCallSiteInfo(CB);
CSI.AllCallSitesDevirted = false;
- CSI.CallSites.push_back({VTable, CS, NumUnsafeUses});
+ CSI.CallSites.push_back({VTable, CB, NumUnsafeUses});
}
struct DevirtModule {
@@ -454,6 +499,10 @@ struct DevirtModule {
IntegerType *Int32Ty;
IntegerType *Int64Ty;
IntegerType *IntPtrTy;
+ /// Sizeless array type, used for imported vtables. This provides a signal
+ /// to analyzers that these imports may alias, as they do for example
+ /// when multiple unique return values occur in the same vtable.
+ ArrayType *Int8Arr0Ty;
bool RemarksEnabled;
function_ref<OptimizationRemarkEmitter &(Function *)> OREGetter;
@@ -469,6 +518,7 @@ struct DevirtModule {
// eliminate the type check by RAUWing the associated llvm.type.test call with
// true.
std::map<CallInst *, unsigned> NumUnsafeUsesForTypeTest;
+ PatternList FunctionsToSkip;
DevirtModule(Module &M, function_ref<AAResults &(Function &)> AARGetter,
function_ref<OptimizationRemarkEmitter &(Function *)> OREGetter,
@@ -482,13 +532,17 @@ struct DevirtModule {
Int32Ty(Type::getInt32Ty(M.getContext())),
Int64Ty(Type::getInt64Ty(M.getContext())),
IntPtrTy(M.getDataLayout().getIntPtrType(M.getContext(), 0)),
+ Int8Arr0Ty(ArrayType::get(Type::getInt8Ty(M.getContext()), 0)),
RemarksEnabled(areRemarksEnabled()), OREGetter(OREGetter) {
assert(!(ExportSummary && ImportSummary));
+ FunctionsToSkip.init(SkipFunctionNames);
}
bool areRemarksEnabled();
- void scanTypeTestUsers(Function *TypeTestFunc);
+ void
+ scanTypeTestUsers(Function *TypeTestFunc,
+ DenseMap<Metadata *, std::set<TypeMemberInfo>> &TypeIdMap);
void scanTypeCheckedLoadUsers(Function *TypeCheckedLoadFunc);
void buildTypeIdentifierMap(
@@ -592,12 +646,16 @@ struct DevirtIndex {
MapVector<VTableSlotSummary, VTableSlotInfo> CallSlots;
+ PatternList FunctionsToSkip;
+
DevirtIndex(
ModuleSummaryIndex &ExportSummary,
std::set<GlobalValue::GUID> &ExportedGUIDs,
std::map<ValueInfo, std::vector<VTableSlotSummary>> &LocalWPDTargetsMap)
: ExportSummary(ExportSummary), ExportedGUIDs(ExportedGUIDs),
- LocalWPDTargetsMap(LocalWPDTargetsMap) {}
+ LocalWPDTargetsMap(LocalWPDTargetsMap) {
+ FunctionsToSkip.init(SkipFunctionNames);
+ }
bool tryFindVirtualCallTargets(std::vector<ValueInfo> &TargetsForSlot,
const TypeIdCompatibleVtableInfo TIdInfo,
@@ -702,7 +760,49 @@ PreservedAnalyses WholeProgramDevirtPass::run(Module &M,
return PreservedAnalyses::none();
}
+// Enable whole program visibility if enabled by client (e.g. linker) or
+// internal option, and not force disabled.
+static bool hasWholeProgramVisibility(bool WholeProgramVisibilityEnabledInLTO) {
+ return (WholeProgramVisibilityEnabledInLTO || WholeProgramVisibility) &&
+ !DisableWholeProgramVisibility;
+}
+
namespace llvm {
+
+/// If whole program visibility asserted, then upgrade all public vcall
+/// visibility metadata on vtable definitions to linkage unit visibility in
+/// Module IR (for regular or hybrid LTO).
+void updateVCallVisibilityInModule(Module &M,
+ bool WholeProgramVisibilityEnabledInLTO) {
+ if (!hasWholeProgramVisibility(WholeProgramVisibilityEnabledInLTO))
+ return;
+ for (GlobalVariable &GV : M.globals())
+ // Add linkage unit visibility to any variable with type metadata, which are
+ // the vtable definitions. We won't have an existing vcall_visibility
+ // metadata on vtable definitions with public visibility.
+ if (GV.hasMetadata(LLVMContext::MD_type) &&
+ GV.getVCallVisibility() == GlobalObject::VCallVisibilityPublic)
+ GV.setVCallVisibilityMetadata(GlobalObject::VCallVisibilityLinkageUnit);
+}
+
+/// If whole program visibility asserted, then upgrade all public vcall
+/// visibility metadata on vtable definition summaries to linkage unit
+/// visibility in Module summary index (for ThinLTO).
+void updateVCallVisibilityInIndex(ModuleSummaryIndex &Index,
+ bool WholeProgramVisibilityEnabledInLTO) {
+ if (!hasWholeProgramVisibility(WholeProgramVisibilityEnabledInLTO))
+ return;
+ for (auto &P : Index) {
+ for (auto &S : P.second.SummaryList) {
+ auto *GVar = dyn_cast<GlobalVarSummary>(S.get());
+ if (!GVar || GVar->vTableFuncs().empty() ||
+ GVar->getVCallVisibility() != GlobalObject::VCallVisibilityPublic)
+ continue;
+ GVar->setVCallVisibility(GlobalObject::VCallVisibilityLinkageUnit);
+ }
+ }
+}
+
void runWholeProgramDevirtOnIndex(
ModuleSummaryIndex &Summary, std::set<GlobalValue::GUID> &ExportedGUIDs,
std::map<ValueInfo, std::vector<VTableSlotSummary>> &LocalWPDTargetsMap) {
@@ -737,11 +837,27 @@ void updateIndexWPDForExports(
} // end namespace llvm
+static Error checkCombinedSummaryForTesting(ModuleSummaryIndex *Summary) {
+ // Check that summary index contains regular LTO module when performing
+ // export to prevent occasional use of index from pure ThinLTO compilation
+ // (-fno-split-lto-module). This kind of summary index is passed to
+ // DevirtIndex::run, not to DevirtModule::run used by opt/runForTesting.
+ const auto &ModPaths = Summary->modulePaths();
+ if (ClSummaryAction != PassSummaryAction::Import &&
+ ModPaths.find(ModuleSummaryIndex::getRegularLTOModuleName()) ==
+ ModPaths.end())
+ return createStringError(
+ errc::invalid_argument,
+ "combined summary should contain Regular LTO module");
+ return ErrorSuccess();
+}
+
bool DevirtModule::runForTesting(
Module &M, function_ref<AAResults &(Function &)> AARGetter,
function_ref<OptimizationRemarkEmitter &(Function *)> OREGetter,
function_ref<DominatorTree &(Function &)> LookupDomTree) {
- ModuleSummaryIndex Summary(/*HaveGVs=*/false);
+ std::unique_ptr<ModuleSummaryIndex> Summary =
+ std::make_unique<ModuleSummaryIndex>(/*HaveGVs=*/false);
// Handle the command-line summary arguments. This code is for testing
// purposes only, so we handle errors directly.
@@ -750,28 +866,41 @@ bool DevirtModule::runForTesting(
": ");
auto ReadSummaryFile =
ExitOnErr(errorOrToExpected(MemoryBuffer::getFile(ClReadSummary)));
-
- yaml::Input In(ReadSummaryFile->getBuffer());
- In >> Summary;
- ExitOnErr(errorCodeToError(In.error()));
+ if (Expected<std::unique_ptr<ModuleSummaryIndex>> SummaryOrErr =
+ getModuleSummaryIndex(*ReadSummaryFile)) {
+ Summary = std::move(*SummaryOrErr);
+ ExitOnErr(checkCombinedSummaryForTesting(Summary.get()));
+ } else {
+ // Try YAML if we've failed with bitcode.
+ consumeError(SummaryOrErr.takeError());
+ yaml::Input In(ReadSummaryFile->getBuffer());
+ In >> *Summary;
+ ExitOnErr(errorCodeToError(In.error()));
+ }
}
bool Changed =
- DevirtModule(
- M, AARGetter, OREGetter, LookupDomTree,
- ClSummaryAction == PassSummaryAction::Export ? &Summary : nullptr,
- ClSummaryAction == PassSummaryAction::Import ? &Summary : nullptr)
+ DevirtModule(M, AARGetter, OREGetter, LookupDomTree,
+ ClSummaryAction == PassSummaryAction::Export ? Summary.get()
+ : nullptr,
+ ClSummaryAction == PassSummaryAction::Import ? Summary.get()
+ : nullptr)
.run();
if (!ClWriteSummary.empty()) {
ExitOnError ExitOnErr(
"-wholeprogramdevirt-write-summary: " + ClWriteSummary + ": ");
std::error_code EC;
- raw_fd_ostream OS(ClWriteSummary, EC, sys::fs::OF_Text);
- ExitOnErr(errorCodeToError(EC));
-
- yaml::Output Out(OS);
- Out << Summary;
+ if (StringRef(ClWriteSummary).endswith(".bc")) {
+ raw_fd_ostream OS(ClWriteSummary, EC, sys::fs::OF_None);
+ ExitOnErr(errorCodeToError(EC));
+ WriteIndexToFile(*Summary, OS);
+ } else {
+ raw_fd_ostream OS(ClWriteSummary, EC, sys::fs::OF_Text);
+ ExitOnErr(errorCodeToError(EC));
+ yaml::Output Out(OS);
+ Out << *Summary;
+ }
}
return Changed;
@@ -818,6 +947,12 @@ bool DevirtModule::tryFindVirtualCallTargets(
if (!TM.Bits->GV->isConstant())
return false;
+ // We cannot perform whole program devirtualization analysis on a vtable
+ // with public LTO visibility.
+ if (TM.Bits->GV->getVCallVisibility() ==
+ GlobalObject::VCallVisibilityPublic)
+ return false;
+
Constant *Ptr = getPointerAtOffset(TM.Bits->GV->getInitializer(),
TM.Offset + ByteOffset, M);
if (!Ptr)
@@ -827,6 +962,9 @@ bool DevirtModule::tryFindVirtualCallTargets(
if (!Fn)
return false;
+ if (FunctionsToSkip.match(Fn->getName()))
+ return false;
+
// We can disregard __cxa_pure_virtual as a possible call target, as
// calls to pure virtuals are UB.
if (Fn->getName() == "__cxa_pure_virtual")
@@ -863,8 +1001,13 @@ bool DevirtIndex::tryFindVirtualCallTargets(
return false;
LocalFound = true;
}
- if (!GlobalValue::isAvailableExternallyLinkage(S->linkage()))
+ if (!GlobalValue::isAvailableExternallyLinkage(S->linkage())) {
VS = cast<GlobalVarSummary>(S->getBaseObject());
+ // We cannot perform whole program devirtualization analysis on a vtable
+ // with public LTO visibility.
+ if (VS->getVCallVisibility() == GlobalObject::VCallVisibilityPublic)
+ return false;
+ }
}
if (!VS->isLive())
continue;
@@ -887,8 +1030,8 @@ void DevirtModule::applySingleImplDevirt(VTableSlotInfo &SlotInfo,
if (RemarksEnabled)
VCallSite.emitRemark("single-impl",
TheFn->stripPointerCasts()->getName(), OREGetter);
- VCallSite.CS.setCalledFunction(ConstantExpr::getBitCast(
- TheFn, VCallSite.CS.getCalledValue()->getType()));
+ VCallSite.CB.setCalledOperand(ConstantExpr::getBitCast(
+ TheFn, VCallSite.CB.getCalledOperand()->getType()));
// This use is no longer unsafe.
if (VCallSite.NumUnsafeUses)
--*VCallSite.NumUnsafeUses;
@@ -979,7 +1122,7 @@ bool DevirtModule::trySingleImplDevirt(
AddCalls(SlotInfo, TheFnVI);
Res->TheKind = WholeProgramDevirtResolution::SingleImpl;
- Res->SingleImplName = TheFn->getName();
+ Res->SingleImplName = std::string(TheFn->getName());
return true;
}
@@ -1001,6 +1144,11 @@ bool DevirtIndex::trySingleImplDevirt(MutableArrayRef<ValueInfo> TargetsForSlot,
if (!Size)
return false;
+ // Don't devirtualize function if we're told to skip it
+ // in -wholeprogramdevirt-skip.
+ if (FunctionsToSkip.match(TheFn.name()))
+ return false;
+
// If the summary list contains multiple summaries where at least one is
// a local, give up, as we won't know which (possibly promoted) name to use.
for (auto &S : TheFn.getSummaryList())
@@ -1028,10 +1176,10 @@ bool DevirtIndex::trySingleImplDevirt(MutableArrayRef<ValueInfo> TargetsForSlot,
TheFn.name(), ExportSummary.getModuleHash(S->modulePath()));
else {
LocalWPDTargetsMap[TheFn].push_back(SlotSummary);
- Res->SingleImplName = TheFn.name();
+ Res->SingleImplName = std::string(TheFn.name());
}
} else
- Res->SingleImplName = TheFn.name();
+ Res->SingleImplName = std::string(TheFn.name());
// Name will be empty if this thin link driven off of serialized combined
// index (e.g. llvm-lto). However, WPD is not supported/invoked for the
@@ -1106,10 +1254,10 @@ void DevirtModule::applyICallBranchFunnel(VTableSlotInfo &SlotInfo,
if (CSInfo.AllCallSitesDevirted)
return;
for (auto &&VCallSite : CSInfo.CallSites) {
- CallSite CS = VCallSite.CS;
+ CallBase &CB = VCallSite.CB;
// Jump tables are only profitable if the retpoline mitigation is enabled.
- Attribute FSAttr = CS.getCaller()->getFnAttribute("target-features");
+ Attribute FSAttr = CB.getCaller()->getFnAttribute("target-features");
if (FSAttr.hasAttribute(Attribute::None) ||
!FSAttr.getValueAsString().contains("+retpoline"))
continue;
@@ -1122,42 +1270,40 @@ void DevirtModule::applyICallBranchFunnel(VTableSlotInfo &SlotInfo,
// x86_64.
std::vector<Type *> NewArgs;
NewArgs.push_back(Int8PtrTy);
- for (Type *T : CS.getFunctionType()->params())
+ for (Type *T : CB.getFunctionType()->params())
NewArgs.push_back(T);
FunctionType *NewFT =
- FunctionType::get(CS.getFunctionType()->getReturnType(), NewArgs,
- CS.getFunctionType()->isVarArg());
+ FunctionType::get(CB.getFunctionType()->getReturnType(), NewArgs,
+ CB.getFunctionType()->isVarArg());
PointerType *NewFTPtr = PointerType::getUnqual(NewFT);
- IRBuilder<> IRB(CS.getInstruction());
+ IRBuilder<> IRB(&CB);
std::vector<Value *> Args;
Args.push_back(IRB.CreateBitCast(VCallSite.VTable, Int8PtrTy));
- for (unsigned I = 0; I != CS.getNumArgOperands(); ++I)
- Args.push_back(CS.getArgOperand(I));
+ Args.insert(Args.end(), CB.arg_begin(), CB.arg_end());
- CallSite NewCS;
- if (CS.isCall())
+ CallBase *NewCS = nullptr;
+ if (isa<CallInst>(CB))
NewCS = IRB.CreateCall(NewFT, IRB.CreateBitCast(JT, NewFTPtr), Args);
else
- NewCS = IRB.CreateInvoke(
- NewFT, IRB.CreateBitCast(JT, NewFTPtr),
- cast<InvokeInst>(CS.getInstruction())->getNormalDest(),
- cast<InvokeInst>(CS.getInstruction())->getUnwindDest(), Args);
- NewCS.setCallingConv(CS.getCallingConv());
+ NewCS = IRB.CreateInvoke(NewFT, IRB.CreateBitCast(JT, NewFTPtr),
+ cast<InvokeInst>(CB).getNormalDest(),
+ cast<InvokeInst>(CB).getUnwindDest(), Args);
+ NewCS->setCallingConv(CB.getCallingConv());
- AttributeList Attrs = CS.getAttributes();
+ AttributeList Attrs = CB.getAttributes();
std::vector<AttributeSet> NewArgAttrs;
NewArgAttrs.push_back(AttributeSet::get(
M.getContext(), ArrayRef<Attribute>{Attribute::get(
M.getContext(), Attribute::Nest)}));
for (unsigned I = 0; I + 2 < Attrs.getNumAttrSets(); ++I)
NewArgAttrs.push_back(Attrs.getParamAttributes(I));
- NewCS.setAttributes(
+ NewCS->setAttributes(
AttributeList::get(M.getContext(), Attrs.getFnAttributes(),
Attrs.getRetAttributes(), NewArgAttrs));
- CS->replaceAllUsesWith(NewCS.getInstruction());
- CS->eraseFromParent();
+ CB.replaceAllUsesWith(NewCS);
+ CB.eraseFromParent();
// This use is no longer unsafe.
if (VCallSite.NumUnsafeUses)
@@ -1208,7 +1354,7 @@ void DevirtModule::applyUniformRetValOpt(CallSiteInfo &CSInfo, StringRef FnName,
for (auto Call : CSInfo.CallSites)
Call.replaceAndErase(
"uniform-ret-val", FnName, RemarksEnabled, OREGetter,
- ConstantInt::get(cast<IntegerType>(Call.CS.getType()), TheRetVal));
+ ConstantInt::get(cast<IntegerType>(Call.CB.getType()), TheRetVal));
CSInfo.markDevirt();
}
@@ -1273,7 +1419,8 @@ void DevirtModule::exportConstant(VTableSlot Slot, ArrayRef<uint64_t> Args,
Constant *DevirtModule::importGlobal(VTableSlot Slot, ArrayRef<uint64_t> Args,
StringRef Name) {
- Constant *C = M.getOrInsertGlobal(getGlobalName(Slot, Args, Name), Int8Ty);
+ Constant *C =
+ M.getOrInsertGlobal(getGlobalName(Slot, Args, Name), Int8Arr0Ty);
auto *GV = dyn_cast<GlobalVariable>(C);
if (GV)
GV->setVisibility(GlobalValue::HiddenVisibility);
@@ -1313,11 +1460,11 @@ void DevirtModule::applyUniqueRetValOpt(CallSiteInfo &CSInfo, StringRef FnName,
bool IsOne,
Constant *UniqueMemberAddr) {
for (auto &&Call : CSInfo.CallSites) {
- IRBuilder<> B(Call.CS.getInstruction());
+ IRBuilder<> B(&Call.CB);
Value *Cmp =
- B.CreateICmp(IsOne ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE,
- B.CreateBitCast(Call.VTable, Int8PtrTy), UniqueMemberAddr);
- Cmp = B.CreateZExt(Cmp, Call.CS->getType());
+ B.CreateICmp(IsOne ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE, Call.VTable,
+ B.CreateBitCast(UniqueMemberAddr, Call.VTable->getType()));
+ Cmp = B.CreateZExt(Cmp, Call.CB.getType());
Call.replaceAndErase("unique-ret-val", FnName, RemarksEnabled, OREGetter,
Cmp);
}
@@ -1381,8 +1528,8 @@ bool DevirtModule::tryUniqueRetValOpt(
void DevirtModule::applyVirtualConstProp(CallSiteInfo &CSInfo, StringRef FnName,
Constant *Byte, Constant *Bit) {
for (auto Call : CSInfo.CallSites) {
- auto *RetType = cast<IntegerType>(Call.CS.getType());
- IRBuilder<> B(Call.CS.getInstruction());
+ auto *RetType = cast<IntegerType>(Call.CB.getType());
+ IRBuilder<> B(&Call.CB);
Value *Addr =
B.CreateGEP(Int8Ty, B.CreateBitCast(Call.VTable, Int8PtrTy), Byte);
if (RetType->getBitWidth() == 1) {
@@ -1507,10 +1654,8 @@ void DevirtModule::rebuildGlobal(VTableBits &B) {
// Align the before byte array to the global's minimum alignment so that we
// don't break any alignment requirements on the global.
- MaybeAlign Alignment(B.GV->getAlignment());
- if (!Alignment)
- Alignment =
- Align(M.getDataLayout().getABITypeAlignment(B.GV->getValueType()));
+ Align Alignment = M.getDataLayout().getValueOrABITypeAlignment(
+ B.GV->getAlign(), B.GV->getValueType());
B.Before.Bytes.resize(alignTo(B.Before.Bytes.size(), Alignment));
// Before was stored in reverse order; flip it now.
@@ -1562,13 +1707,14 @@ bool DevirtModule::areRemarksEnabled() {
return false;
}
-void DevirtModule::scanTypeTestUsers(Function *TypeTestFunc) {
+void DevirtModule::scanTypeTestUsers(
+ Function *TypeTestFunc,
+ DenseMap<Metadata *, std::set<TypeMemberInfo>> &TypeIdMap) {
// Find all virtual calls via a virtual table pointer %p under an assumption
// of the form llvm.assume(llvm.type.test(%p, %md)). This indicates that %p
// points to a member of the type identifier %md. Group calls by (type ID,
// offset) pair (effectively the identity of the virtual function) and store
// to CallSlots.
- DenseSet<CallSite> SeenCallSites;
for (auto I = TypeTestFunc->use_begin(), E = TypeTestFunc->use_end();
I != E;) {
auto CI = dyn_cast<CallInst>(I->getUser());
@@ -1582,29 +1728,59 @@ void DevirtModule::scanTypeTestUsers(Function *TypeTestFunc) {
auto &DT = LookupDomTree(*CI->getFunction());
findDevirtualizableCallsForTypeTest(DevirtCalls, Assumes, CI, DT);
+ Metadata *TypeId =
+ cast<MetadataAsValue>(CI->getArgOperand(1))->getMetadata();
// If we found any, add them to CallSlots.
if (!Assumes.empty()) {
- Metadata *TypeId =
- cast<MetadataAsValue>(CI->getArgOperand(1))->getMetadata();
Value *Ptr = CI->getArgOperand(0)->stripPointerCasts();
- for (DevirtCallSite Call : DevirtCalls) {
- // Only add this CallSite if we haven't seen it before. The vtable
- // pointer may have been CSE'd with pointers from other call sites,
- // and we don't want to process call sites multiple times. We can't
- // just skip the vtable Ptr if it has been seen before, however, since
- // it may be shared by type tests that dominate different calls.
- if (SeenCallSites.insert(Call.CS).second)
- CallSlots[{TypeId, Call.Offset}].addCallSite(Ptr, Call.CS, nullptr);
- }
+ for (DevirtCallSite Call : DevirtCalls)
+ CallSlots[{TypeId, Call.Offset}].addCallSite(Ptr, Call.CB, nullptr);
}
- // We no longer need the assumes or the type test.
- for (auto Assume : Assumes)
- Assume->eraseFromParent();
- // We can't use RecursivelyDeleteTriviallyDeadInstructions here because we
- // may use the vtable argument later.
- if (CI->use_empty())
- CI->eraseFromParent();
+ auto RemoveTypeTestAssumes = [&]() {
+ // We no longer need the assumes or the type test.
+ for (auto Assume : Assumes)
+ Assume->eraseFromParent();
+ // We can't use RecursivelyDeleteTriviallyDeadInstructions here because we
+ // may use the vtable argument later.
+ if (CI->use_empty())
+ CI->eraseFromParent();
+ };
+
+ // At this point we could remove all type test assume sequences, as they
+ // were originally inserted for WPD. However, we can keep these in the
+ // code stream for later analysis (e.g. to help drive more efficient ICP
+ // sequences). They will eventually be removed by a second LowerTypeTests
+ // invocation that cleans them up. In order to do this correctly, the first
+ // LowerTypeTests invocation needs to know that they have "Unknown" type
+ // test resolution, so that they aren't treated as Unsat and lowered to
+ // False, which will break any uses on assumes. Below we remove any type
+ // test assumes that will not be treated as Unknown by LTT.
+
+ // The type test assumes will be treated by LTT as Unsat if the type id is
+ // not used on a global (in which case it has no entry in the TypeIdMap).
+ if (!TypeIdMap.count(TypeId))
+ RemoveTypeTestAssumes();
+
+ // For ThinLTO importing, we need to remove the type test assumes if this is
+ // an MDString type id without a corresponding TypeIdSummary. Any
+ // non-MDString type ids are ignored and treated as Unknown by LTT, so their
+ // type test assumes can be kept. If the MDString type id is missing a
+ // TypeIdSummary (e.g. because there was no use on a vcall, preventing the
+ // exporting phase of WPD from analyzing it), then it would be treated as
+ // Unsat by LTT and we need to remove its type test assumes here. If not
+ // used on a vcall we don't need them for later optimization use in any
+ // case.
+ else if (ImportSummary && isa<MDString>(TypeId)) {
+ const TypeIdSummary *TidSummary =
+ ImportSummary->getTypeIdSummary(cast<MDString>(TypeId)->getString());
+ if (!TidSummary)
+ RemoveTypeTestAssumes();
+ else
+ // If one was created it should not be Unsat, because if we reached here
+ // the type id was used on a global.
+ assert(TidSummary->TTRes.TheKind != TypeTestResolution::Unsat);
+ }
}
}
@@ -1680,7 +1856,7 @@ void DevirtModule::scanTypeCheckedLoadUsers(Function *TypeCheckedLoadFunc) {
if (HasNonCallUses)
++NumUnsafeUses;
for (DevirtCallSite Call : DevirtCalls) {
- CallSlots[{TypeId, Call.Offset}].addCallSite(Ptr, Call.CS,
+ CallSlots[{TypeId, Call.Offset}].addCallSite(Ptr, Call.CB,
&NumUnsafeUses);
}
@@ -1796,8 +1972,13 @@ bool DevirtModule::run() {
(!TypeCheckedLoadFunc || TypeCheckedLoadFunc->use_empty()))
return false;
+ // Rebuild type metadata into a map for easy lookup.
+ std::vector<VTableBits> Bits;
+ DenseMap<Metadata *, std::set<TypeMemberInfo>> TypeIdMap;
+ buildTypeIdentifierMap(Bits, TypeIdMap);
+
if (TypeTestFunc && AssumeFunc)
- scanTypeTestUsers(TypeTestFunc);
+ scanTypeTestUsers(TypeTestFunc, TypeIdMap);
if (TypeCheckedLoadFunc)
scanTypeCheckedLoadUsers(TypeCheckedLoadFunc);
@@ -1808,15 +1989,17 @@ bool DevirtModule::run() {
removeRedundantTypeTests();
+  // We have lowered or deleted the type intrinsics, so we will no
+ // longer have enough information to reason about the liveness of virtual
+ // function pointers in GlobalDCE.
+ for (GlobalVariable &GV : M.globals())
+ GV.eraseMetadata(LLVMContext::MD_vcall_visibility);
+
// The rest of the code is only necessary when exporting or during regular
// LTO, so we are done.
return true;
}
- // Rebuild type metadata into a map for easy lookup.
- std::vector<VTableBits> Bits;
- DenseMap<Metadata *, std::set<TypeMemberInfo>> TypeIdMap;
- buildTypeIdentifierMap(Bits, TypeIdMap);
if (TypeIdMap.empty())
return true;
@@ -1873,14 +2056,22 @@ bool DevirtModule::run() {
// function implementation at offset S.first.ByteOffset, and add to
// TargetsForSlot.
std::vector<VirtualCallTarget> TargetsForSlot;
- if (tryFindVirtualCallTargets(TargetsForSlot, TypeIdMap[S.first.TypeID],
+ WholeProgramDevirtResolution *Res = nullptr;
+ const std::set<TypeMemberInfo> &TypeMemberInfos = TypeIdMap[S.first.TypeID];
+ if (ExportSummary && isa<MDString>(S.first.TypeID) &&
+ TypeMemberInfos.size())
+ // For any type id used on a global's type metadata, create the type id
+ // summary resolution regardless of whether we can devirtualize, so that
+ // lower type tests knows the type id is not Unsat. If it was not used on
+ // a global's type metadata, the TypeIdMap entry set will be empty, and
+ // we don't want to create an entry (with the default Unknown type
+ // resolution), which can prevent detection of the Unsat.
+ Res = &ExportSummary
+ ->getOrInsertTypeIdSummary(
+ cast<MDString>(S.first.TypeID)->getString())
+ .WPDRes[S.first.ByteOffset];
+ if (tryFindVirtualCallTargets(TargetsForSlot, TypeMemberInfos,
S.first.ByteOffset)) {
- WholeProgramDevirtResolution *Res = nullptr;
- if (ExportSummary && isa<MDString>(S.first.TypeID))
- Res = &ExportSummary
- ->getOrInsertTypeIdSummary(
- cast<MDString>(S.first.TypeID)->getString())
- .WPDRes[S.first.ByteOffset];
if (!trySingleImplDevirt(ExportSummary, TargetsForSlot, S.second, Res)) {
DidVirtualConstProp |=
@@ -1893,7 +2084,7 @@ bool DevirtModule::run() {
if (RemarksEnabled)
for (const auto &T : TargetsForSlot)
if (T.WasDevirt)
- DevirtTargets[T.Fn->getName()] = T.Fn;
+ DevirtTargets[std::string(T.Fn->getName())] = T.Fn;
}
// CFI-specific: if we are exporting and any llvm.type.checked.load
@@ -1931,7 +2122,7 @@ bool DevirtModule::run() {
for (VTableBits &B : Bits)
rebuildGlobal(B);
- // We have lowered or deleted the type checked load intrinsics, so we no
+  // We have lowered or deleted the type intrinsics, so we will no
// longer have enough information to reason about the liveness of virtual
// function pointers in GlobalDCE.
for (GlobalVariable &GV : M.globals())
@@ -1994,11 +2185,14 @@ void DevirtIndex::run() {
std::vector<ValueInfo> TargetsForSlot;
auto TidSummary = ExportSummary.getTypeIdCompatibleVtableSummary(S.first.TypeID);
assert(TidSummary);
+    // Create the type id summary resolution regardless of whether we can
+ // devirtualize, so that lower type tests knows the type id is used on
+ // a global and not Unsat.
+ WholeProgramDevirtResolution *Res =
+ &ExportSummary.getOrInsertTypeIdSummary(S.first.TypeID)
+ .WPDRes[S.first.ByteOffset];
if (tryFindVirtualCallTargets(TargetsForSlot, *TidSummary,
S.first.ByteOffset)) {
- WholeProgramDevirtResolution *Res =
- &ExportSummary.getOrInsertTypeIdSummary(S.first.TypeID)
- .WPDRes[S.first.ByteOffset];
if (!trySingleImplDevirt(TargetsForSlot, S.first, S.second, Res,
DevirtTargets))