diff options
Diffstat (limited to 'llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp')
-rw-r--r-- | llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp | 422 |
1 files changed, 308 insertions, 114 deletions
diff --git a/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp b/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp index 5ccfb29b01a1..5a25f9857665 100644 --- a/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp +++ b/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp @@ -57,12 +57,14 @@ #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/Triple.h" #include "llvm/ADT/iterator_range.h" #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/BasicAliasAnalysis.h" #include "llvm/Analysis/OptimizationRemarkEmitter.h" #include "llvm/Analysis/TypeMetadataUtils.h" -#include "llvm/IR/CallSite.h" +#include "llvm/Bitcode/BitcodeReader.h" +#include "llvm/Bitcode/BitcodeWriter.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/DebugLoc.h" @@ -83,11 +85,12 @@ #include "llvm/InitializePasses.h" #include "llvm/Pass.h" #include "llvm/PassRegistry.h" -#include "llvm/PassSupport.h" #include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" +#include "llvm/Support/Errc.h" #include "llvm/Support/Error.h" #include "llvm/Support/FileSystem.h" +#include "llvm/Support/GlobPattern.h" #include "llvm/Support/MathExtras.h" #include "llvm/Transforms/IPO.h" #include "llvm/Transforms/IPO/FunctionAttrs.h" @@ -115,12 +118,15 @@ static cl::opt<PassSummaryAction> ClSummaryAction( static cl::opt<std::string> ClReadSummary( "wholeprogramdevirt-read-summary", - cl::desc("Read summary from given YAML file before running pass"), + cl::desc( + "Read summary from given bitcode or YAML file before running pass"), cl::Hidden); static cl::opt<std::string> ClWriteSummary( "wholeprogramdevirt-write-summary", - cl::desc("Write summary to given YAML file after running pass"), + cl::desc("Write summary to given bitcode or YAML file after running pass. 
" + "Output file format is deduced from extension: *.bc means writing " + "bitcode, otherwise YAML"), cl::Hidden); static cl::opt<unsigned> @@ -134,6 +140,45 @@ static cl::opt<bool> cl::init(false), cl::ZeroOrMore, cl::desc("Print index-based devirtualization messages")); +/// Provide a way to force enable whole program visibility in tests. +/// This is needed to support legacy tests that don't contain +/// !vcall_visibility metadata (the mere presense of type tests +/// previously implied hidden visibility). +cl::opt<bool> + WholeProgramVisibility("whole-program-visibility", cl::init(false), + cl::Hidden, cl::ZeroOrMore, + cl::desc("Enable whole program visibility")); + +/// Provide a way to force disable whole program for debugging or workarounds, +/// when enabled via the linker. +cl::opt<bool> DisableWholeProgramVisibility( + "disable-whole-program-visibility", cl::init(false), cl::Hidden, + cl::ZeroOrMore, + cl::desc("Disable whole program visibility (overrides enabling options)")); + +/// Provide way to prevent certain function from being devirtualized +cl::list<std::string> + SkipFunctionNames("wholeprogramdevirt-skip", + cl::desc("Prevent function(s) from being devirtualized"), + cl::Hidden, cl::ZeroOrMore, cl::CommaSeparated); + +namespace { +struct PatternList { + std::vector<GlobPattern> Patterns; + template <class T> void init(const T &StringList) { + for (const auto &S : StringList) + if (Expected<GlobPattern> Pat = GlobPattern::create(S)) + Patterns.push_back(std::move(*Pat)); + } + bool match(StringRef S) { + for (const GlobPattern &P : Patterns) + if (P.match(S)) + return true; + return false; + } +}; +} // namespace + // Find the minimum offset that we may store a value of size Size bits at. If // IsAfter is set, look for an offset before the object, otherwise look for an // offset after the object. @@ -308,20 +353,20 @@ namespace { // A virtual call site. VTable is the loaded virtual table pointer, and CS is // the indirect virtual call. 
struct VirtualCallSite { - Value *VTable; - CallSite CS; + Value *VTable = nullptr; + CallBase &CB; // If non-null, this field points to the associated unsafe use count stored in // the DevirtModule::NumUnsafeUsesForTypeTest map below. See the description // of that field for details. - unsigned *NumUnsafeUses; + unsigned *NumUnsafeUses = nullptr; void emitRemark(const StringRef OptName, const StringRef TargetName, function_ref<OptimizationRemarkEmitter &(Function *)> OREGetter) { - Function *F = CS.getCaller(); - DebugLoc DLoc = CS->getDebugLoc(); - BasicBlock *Block = CS.getParent(); + Function *F = CB.getCaller(); + DebugLoc DLoc = CB.getDebugLoc(); + BasicBlock *Block = CB.getParent(); using namespace ore; OREGetter(F).emit(OptimizationRemark(DEBUG_TYPE, OptName, DLoc, Block) @@ -336,12 +381,12 @@ struct VirtualCallSite { Value *New) { if (RemarksEnabled) emitRemark(OptName, TargetName, OREGetter); - CS->replaceAllUsesWith(New); - if (auto II = dyn_cast<InvokeInst>(CS.getInstruction())) { - BranchInst::Create(II->getNormalDest(), CS.getInstruction()); + CB.replaceAllUsesWith(New); + if (auto *II = dyn_cast<InvokeInst>(&CB)) { + BranchInst::Create(II->getNormalDest(), &CB); II->getUnwindDest()->removePredecessor(II->getParent()); } - CS->eraseFromParent(); + CB.eraseFromParent(); // This use is no longer unsafe. if (NumUnsafeUses) --*NumUnsafeUses; @@ -414,18 +459,18 @@ struct VTableSlotInfo { // "this"), grouped by argument list. 
std::map<std::vector<uint64_t>, CallSiteInfo> ConstCSInfo; - void addCallSite(Value *VTable, CallSite CS, unsigned *NumUnsafeUses); + void addCallSite(Value *VTable, CallBase &CB, unsigned *NumUnsafeUses); private: - CallSiteInfo &findCallSiteInfo(CallSite CS); + CallSiteInfo &findCallSiteInfo(CallBase &CB); }; -CallSiteInfo &VTableSlotInfo::findCallSiteInfo(CallSite CS) { +CallSiteInfo &VTableSlotInfo::findCallSiteInfo(CallBase &CB) { std::vector<uint64_t> Args; - auto *CI = dyn_cast<IntegerType>(CS.getType()); - if (!CI || CI->getBitWidth() > 64 || CS.arg_empty()) + auto *CBType = dyn_cast<IntegerType>(CB.getType()); + if (!CBType || CBType->getBitWidth() > 64 || CB.arg_empty()) return CSInfo; - for (auto &&Arg : make_range(CS.arg_begin() + 1, CS.arg_end())) { + for (auto &&Arg : make_range(CB.arg_begin() + 1, CB.arg_end())) { auto *CI = dyn_cast<ConstantInt>(Arg); if (!CI || CI->getBitWidth() > 64) return CSInfo; @@ -434,11 +479,11 @@ CallSiteInfo &VTableSlotInfo::findCallSiteInfo(CallSite CS) { return ConstCSInfo[Args]; } -void VTableSlotInfo::addCallSite(Value *VTable, CallSite CS, +void VTableSlotInfo::addCallSite(Value *VTable, CallBase &CB, unsigned *NumUnsafeUses) { - auto &CSI = findCallSiteInfo(CS); + auto &CSI = findCallSiteInfo(CB); CSI.AllCallSitesDevirted = false; - CSI.CallSites.push_back({VTable, CS, NumUnsafeUses}); + CSI.CallSites.push_back({VTable, CB, NumUnsafeUses}); } struct DevirtModule { @@ -454,6 +499,10 @@ struct DevirtModule { IntegerType *Int32Ty; IntegerType *Int64Ty; IntegerType *IntPtrTy; + /// Sizeless array type, used for imported vtables. This provides a signal + /// to analyzers that these imports may alias, as they do for example + /// when multiple unique return values occur in the same vtable. 
+ ArrayType *Int8Arr0Ty; bool RemarksEnabled; function_ref<OptimizationRemarkEmitter &(Function *)> OREGetter; @@ -469,6 +518,7 @@ struct DevirtModule { // eliminate the type check by RAUWing the associated llvm.type.test call with // true. std::map<CallInst *, unsigned> NumUnsafeUsesForTypeTest; + PatternList FunctionsToSkip; DevirtModule(Module &M, function_ref<AAResults &(Function &)> AARGetter, function_ref<OptimizationRemarkEmitter &(Function *)> OREGetter, @@ -482,13 +532,17 @@ struct DevirtModule { Int32Ty(Type::getInt32Ty(M.getContext())), Int64Ty(Type::getInt64Ty(M.getContext())), IntPtrTy(M.getDataLayout().getIntPtrType(M.getContext(), 0)), + Int8Arr0Ty(ArrayType::get(Type::getInt8Ty(M.getContext()), 0)), RemarksEnabled(areRemarksEnabled()), OREGetter(OREGetter) { assert(!(ExportSummary && ImportSummary)); + FunctionsToSkip.init(SkipFunctionNames); } bool areRemarksEnabled(); - void scanTypeTestUsers(Function *TypeTestFunc); + void + scanTypeTestUsers(Function *TypeTestFunc, + DenseMap<Metadata *, std::set<TypeMemberInfo>> &TypeIdMap); void scanTypeCheckedLoadUsers(Function *TypeCheckedLoadFunc); void buildTypeIdentifierMap( @@ -592,12 +646,16 @@ struct DevirtIndex { MapVector<VTableSlotSummary, VTableSlotInfo> CallSlots; + PatternList FunctionsToSkip; + DevirtIndex( ModuleSummaryIndex &ExportSummary, std::set<GlobalValue::GUID> &ExportedGUIDs, std::map<ValueInfo, std::vector<VTableSlotSummary>> &LocalWPDTargetsMap) : ExportSummary(ExportSummary), ExportedGUIDs(ExportedGUIDs), - LocalWPDTargetsMap(LocalWPDTargetsMap) {} + LocalWPDTargetsMap(LocalWPDTargetsMap) { + FunctionsToSkip.init(SkipFunctionNames); + } bool tryFindVirtualCallTargets(std::vector<ValueInfo> &TargetsForSlot, const TypeIdCompatibleVtableInfo TIdInfo, @@ -702,7 +760,49 @@ PreservedAnalyses WholeProgramDevirtPass::run(Module &M, return PreservedAnalyses::none(); } +// Enable whole program visibility if enabled by client (e.g. linker) or +// internal option, and not force disabled. 
+static bool hasWholeProgramVisibility(bool WholeProgramVisibilityEnabledInLTO) { + return (WholeProgramVisibilityEnabledInLTO || WholeProgramVisibility) && + !DisableWholeProgramVisibility; +} + namespace llvm { + +/// If whole program visibility asserted, then upgrade all public vcall +/// visibility metadata on vtable definitions to linkage unit visibility in +/// Module IR (for regular or hybrid LTO). +void updateVCallVisibilityInModule(Module &M, + bool WholeProgramVisibilityEnabledInLTO) { + if (!hasWholeProgramVisibility(WholeProgramVisibilityEnabledInLTO)) + return; + for (GlobalVariable &GV : M.globals()) + // Add linkage unit visibility to any variable with type metadata, which are + // the vtable definitions. We won't have an existing vcall_visibility + // metadata on vtable definitions with public visibility. + if (GV.hasMetadata(LLVMContext::MD_type) && + GV.getVCallVisibility() == GlobalObject::VCallVisibilityPublic) + GV.setVCallVisibilityMetadata(GlobalObject::VCallVisibilityLinkageUnit); +} + +/// If whole program visibility asserted, then upgrade all public vcall +/// visibility metadata on vtable definition summaries to linkage unit +/// visibility in Module summary index (for ThinLTO). 
+void updateVCallVisibilityInIndex(ModuleSummaryIndex &Index, + bool WholeProgramVisibilityEnabledInLTO) { + if (!hasWholeProgramVisibility(WholeProgramVisibilityEnabledInLTO)) + return; + for (auto &P : Index) { + for (auto &S : P.second.SummaryList) { + auto *GVar = dyn_cast<GlobalVarSummary>(S.get()); + if (!GVar || GVar->vTableFuncs().empty() || + GVar->getVCallVisibility() != GlobalObject::VCallVisibilityPublic) + continue; + GVar->setVCallVisibility(GlobalObject::VCallVisibilityLinkageUnit); + } + } +} + void runWholeProgramDevirtOnIndex( ModuleSummaryIndex &Summary, std::set<GlobalValue::GUID> &ExportedGUIDs, std::map<ValueInfo, std::vector<VTableSlotSummary>> &LocalWPDTargetsMap) { @@ -737,11 +837,27 @@ void updateIndexWPDForExports( } // end namespace llvm +static Error checkCombinedSummaryForTesting(ModuleSummaryIndex *Summary) { + // Check that summary index contains regular LTO module when performing + // export to prevent occasional use of index from pure ThinLTO compilation + // (-fno-split-lto-module). This kind of summary index is passed to + // DevirtIndex::run, not to DevirtModule::run used by opt/runForTesting. + const auto &ModPaths = Summary->modulePaths(); + if (ClSummaryAction != PassSummaryAction::Import && + ModPaths.find(ModuleSummaryIndex::getRegularLTOModuleName()) == + ModPaths.end()) + return createStringError( + errc::invalid_argument, + "combined summary should contain Regular LTO module"); + return ErrorSuccess(); +} + bool DevirtModule::runForTesting( Module &M, function_ref<AAResults &(Function &)> AARGetter, function_ref<OptimizationRemarkEmitter &(Function *)> OREGetter, function_ref<DominatorTree &(Function &)> LookupDomTree) { - ModuleSummaryIndex Summary(/*HaveGVs=*/false); + std::unique_ptr<ModuleSummaryIndex> Summary = + std::make_unique<ModuleSummaryIndex>(/*HaveGVs=*/false); // Handle the command-line summary arguments. This code is for testing // purposes only, so we handle errors directly. 
@@ -750,28 +866,41 @@ bool DevirtModule::runForTesting( ": "); auto ReadSummaryFile = ExitOnErr(errorOrToExpected(MemoryBuffer::getFile(ClReadSummary))); - - yaml::Input In(ReadSummaryFile->getBuffer()); - In >> Summary; - ExitOnErr(errorCodeToError(In.error())); + if (Expected<std::unique_ptr<ModuleSummaryIndex>> SummaryOrErr = + getModuleSummaryIndex(*ReadSummaryFile)) { + Summary = std::move(*SummaryOrErr); + ExitOnErr(checkCombinedSummaryForTesting(Summary.get())); + } else { + // Try YAML if we've failed with bitcode. + consumeError(SummaryOrErr.takeError()); + yaml::Input In(ReadSummaryFile->getBuffer()); + In >> *Summary; + ExitOnErr(errorCodeToError(In.error())); + } } bool Changed = - DevirtModule( - M, AARGetter, OREGetter, LookupDomTree, - ClSummaryAction == PassSummaryAction::Export ? &Summary : nullptr, - ClSummaryAction == PassSummaryAction::Import ? &Summary : nullptr) + DevirtModule(M, AARGetter, OREGetter, LookupDomTree, + ClSummaryAction == PassSummaryAction::Export ? Summary.get() + : nullptr, + ClSummaryAction == PassSummaryAction::Import ? Summary.get() + : nullptr) .run(); if (!ClWriteSummary.empty()) { ExitOnError ExitOnErr( "-wholeprogramdevirt-write-summary: " + ClWriteSummary + ": "); std::error_code EC; - raw_fd_ostream OS(ClWriteSummary, EC, sys::fs::OF_Text); - ExitOnErr(errorCodeToError(EC)); - - yaml::Output Out(OS); - Out << Summary; + if (StringRef(ClWriteSummary).endswith(".bc")) { + raw_fd_ostream OS(ClWriteSummary, EC, sys::fs::OF_None); + ExitOnErr(errorCodeToError(EC)); + WriteIndexToFile(*Summary, OS); + } else { + raw_fd_ostream OS(ClWriteSummary, EC, sys::fs::OF_Text); + ExitOnErr(errorCodeToError(EC)); + yaml::Output Out(OS); + Out << *Summary; + } } return Changed; @@ -818,6 +947,12 @@ bool DevirtModule::tryFindVirtualCallTargets( if (!TM.Bits->GV->isConstant()) return false; + // We cannot perform whole program devirtualization analysis on a vtable + // with public LTO visibility. 
+ if (TM.Bits->GV->getVCallVisibility() == + GlobalObject::VCallVisibilityPublic) + return false; + Constant *Ptr = getPointerAtOffset(TM.Bits->GV->getInitializer(), TM.Offset + ByteOffset, M); if (!Ptr) @@ -827,6 +962,9 @@ bool DevirtModule::tryFindVirtualCallTargets( if (!Fn) return false; + if (FunctionsToSkip.match(Fn->getName())) + return false; + // We can disregard __cxa_pure_virtual as a possible call target, as // calls to pure virtuals are UB. if (Fn->getName() == "__cxa_pure_virtual") @@ -863,8 +1001,13 @@ bool DevirtIndex::tryFindVirtualCallTargets( return false; LocalFound = true; } - if (!GlobalValue::isAvailableExternallyLinkage(S->linkage())) + if (!GlobalValue::isAvailableExternallyLinkage(S->linkage())) { VS = cast<GlobalVarSummary>(S->getBaseObject()); + // We cannot perform whole program devirtualization analysis on a vtable + // with public LTO visibility. + if (VS->getVCallVisibility() == GlobalObject::VCallVisibilityPublic) + return false; + } } if (!VS->isLive()) continue; @@ -887,8 +1030,8 @@ void DevirtModule::applySingleImplDevirt(VTableSlotInfo &SlotInfo, if (RemarksEnabled) VCallSite.emitRemark("single-impl", TheFn->stripPointerCasts()->getName(), OREGetter); - VCallSite.CS.setCalledFunction(ConstantExpr::getBitCast( - TheFn, VCallSite.CS.getCalledValue()->getType())); + VCallSite.CB.setCalledOperand(ConstantExpr::getBitCast( + TheFn, VCallSite.CB.getCalledOperand()->getType())); // This use is no longer unsafe. 
if (VCallSite.NumUnsafeUses) --*VCallSite.NumUnsafeUses; @@ -979,7 +1122,7 @@ bool DevirtModule::trySingleImplDevirt( AddCalls(SlotInfo, TheFnVI); Res->TheKind = WholeProgramDevirtResolution::SingleImpl; - Res->SingleImplName = TheFn->getName(); + Res->SingleImplName = std::string(TheFn->getName()); return true; } @@ -1001,6 +1144,11 @@ bool DevirtIndex::trySingleImplDevirt(MutableArrayRef<ValueInfo> TargetsForSlot, if (!Size) return false; + // Don't devirtualize function if we're told to skip it + // in -wholeprogramdevirt-skip. + if (FunctionsToSkip.match(TheFn.name())) + return false; + // If the summary list contains multiple summaries where at least one is // a local, give up, as we won't know which (possibly promoted) name to use. for (auto &S : TheFn.getSummaryList()) @@ -1028,10 +1176,10 @@ bool DevirtIndex::trySingleImplDevirt(MutableArrayRef<ValueInfo> TargetsForSlot, TheFn.name(), ExportSummary.getModuleHash(S->modulePath())); else { LocalWPDTargetsMap[TheFn].push_back(SlotSummary); - Res->SingleImplName = TheFn.name(); + Res->SingleImplName = std::string(TheFn.name()); } } else - Res->SingleImplName = TheFn.name(); + Res->SingleImplName = std::string(TheFn.name()); // Name will be empty if this thin link driven off of serialized combined // index (e.g. llvm-lto). However, WPD is not supported/invoked for the @@ -1106,10 +1254,10 @@ void DevirtModule::applyICallBranchFunnel(VTableSlotInfo &SlotInfo, if (CSInfo.AllCallSitesDevirted) return; for (auto &&VCallSite : CSInfo.CallSites) { - CallSite CS = VCallSite.CS; + CallBase &CB = VCallSite.CB; // Jump tables are only profitable if the retpoline mitigation is enabled. 
- Attribute FSAttr = CS.getCaller()->getFnAttribute("target-features"); + Attribute FSAttr = CB.getCaller()->getFnAttribute("target-features"); if (FSAttr.hasAttribute(Attribute::None) || !FSAttr.getValueAsString().contains("+retpoline")) continue; @@ -1122,42 +1270,40 @@ void DevirtModule::applyICallBranchFunnel(VTableSlotInfo &SlotInfo, // x86_64. std::vector<Type *> NewArgs; NewArgs.push_back(Int8PtrTy); - for (Type *T : CS.getFunctionType()->params()) + for (Type *T : CB.getFunctionType()->params()) NewArgs.push_back(T); FunctionType *NewFT = - FunctionType::get(CS.getFunctionType()->getReturnType(), NewArgs, - CS.getFunctionType()->isVarArg()); + FunctionType::get(CB.getFunctionType()->getReturnType(), NewArgs, + CB.getFunctionType()->isVarArg()); PointerType *NewFTPtr = PointerType::getUnqual(NewFT); - IRBuilder<> IRB(CS.getInstruction()); + IRBuilder<> IRB(&CB); std::vector<Value *> Args; Args.push_back(IRB.CreateBitCast(VCallSite.VTable, Int8PtrTy)); - for (unsigned I = 0; I != CS.getNumArgOperands(); ++I) - Args.push_back(CS.getArgOperand(I)); + Args.insert(Args.end(), CB.arg_begin(), CB.arg_end()); - CallSite NewCS; - if (CS.isCall()) + CallBase *NewCS = nullptr; + if (isa<CallInst>(CB)) NewCS = IRB.CreateCall(NewFT, IRB.CreateBitCast(JT, NewFTPtr), Args); else - NewCS = IRB.CreateInvoke( - NewFT, IRB.CreateBitCast(JT, NewFTPtr), - cast<InvokeInst>(CS.getInstruction())->getNormalDest(), - cast<InvokeInst>(CS.getInstruction())->getUnwindDest(), Args); - NewCS.setCallingConv(CS.getCallingConv()); + NewCS = IRB.CreateInvoke(NewFT, IRB.CreateBitCast(JT, NewFTPtr), + cast<InvokeInst>(CB).getNormalDest(), + cast<InvokeInst>(CB).getUnwindDest(), Args); + NewCS->setCallingConv(CB.getCallingConv()); - AttributeList Attrs = CS.getAttributes(); + AttributeList Attrs = CB.getAttributes(); std::vector<AttributeSet> NewArgAttrs; NewArgAttrs.push_back(AttributeSet::get( M.getContext(), ArrayRef<Attribute>{Attribute::get( M.getContext(), Attribute::Nest)})); for 
(unsigned I = 0; I + 2 < Attrs.getNumAttrSets(); ++I) NewArgAttrs.push_back(Attrs.getParamAttributes(I)); - NewCS.setAttributes( + NewCS->setAttributes( AttributeList::get(M.getContext(), Attrs.getFnAttributes(), Attrs.getRetAttributes(), NewArgAttrs)); - CS->replaceAllUsesWith(NewCS.getInstruction()); - CS->eraseFromParent(); + CB.replaceAllUsesWith(NewCS); + CB.eraseFromParent(); // This use is no longer unsafe. if (VCallSite.NumUnsafeUses) @@ -1208,7 +1354,7 @@ void DevirtModule::applyUniformRetValOpt(CallSiteInfo &CSInfo, StringRef FnName, for (auto Call : CSInfo.CallSites) Call.replaceAndErase( "uniform-ret-val", FnName, RemarksEnabled, OREGetter, - ConstantInt::get(cast<IntegerType>(Call.CS.getType()), TheRetVal)); + ConstantInt::get(cast<IntegerType>(Call.CB.getType()), TheRetVal)); CSInfo.markDevirt(); } @@ -1273,7 +1419,8 @@ void DevirtModule::exportConstant(VTableSlot Slot, ArrayRef<uint64_t> Args, Constant *DevirtModule::importGlobal(VTableSlot Slot, ArrayRef<uint64_t> Args, StringRef Name) { - Constant *C = M.getOrInsertGlobal(getGlobalName(Slot, Args, Name), Int8Ty); + Constant *C = + M.getOrInsertGlobal(getGlobalName(Slot, Args, Name), Int8Arr0Ty); auto *GV = dyn_cast<GlobalVariable>(C); if (GV) GV->setVisibility(GlobalValue::HiddenVisibility); @@ -1313,11 +1460,11 @@ void DevirtModule::applyUniqueRetValOpt(CallSiteInfo &CSInfo, StringRef FnName, bool IsOne, Constant *UniqueMemberAddr) { for (auto &&Call : CSInfo.CallSites) { - IRBuilder<> B(Call.CS.getInstruction()); + IRBuilder<> B(&Call.CB); Value *Cmp = - B.CreateICmp(IsOne ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE, - B.CreateBitCast(Call.VTable, Int8PtrTy), UniqueMemberAddr); - Cmp = B.CreateZExt(Cmp, Call.CS->getType()); + B.CreateICmp(IsOne ? 
ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE, Call.VTable, + B.CreateBitCast(UniqueMemberAddr, Call.VTable->getType())); + Cmp = B.CreateZExt(Cmp, Call.CB.getType()); Call.replaceAndErase("unique-ret-val", FnName, RemarksEnabled, OREGetter, Cmp); } @@ -1381,8 +1528,8 @@ bool DevirtModule::tryUniqueRetValOpt( void DevirtModule::applyVirtualConstProp(CallSiteInfo &CSInfo, StringRef FnName, Constant *Byte, Constant *Bit) { for (auto Call : CSInfo.CallSites) { - auto *RetType = cast<IntegerType>(Call.CS.getType()); - IRBuilder<> B(Call.CS.getInstruction()); + auto *RetType = cast<IntegerType>(Call.CB.getType()); + IRBuilder<> B(&Call.CB); Value *Addr = B.CreateGEP(Int8Ty, B.CreateBitCast(Call.VTable, Int8PtrTy), Byte); if (RetType->getBitWidth() == 1) { @@ -1507,10 +1654,8 @@ void DevirtModule::rebuildGlobal(VTableBits &B) { // Align the before byte array to the global's minimum alignment so that we // don't break any alignment requirements on the global. - MaybeAlign Alignment(B.GV->getAlignment()); - if (!Alignment) - Alignment = - Align(M.getDataLayout().getABITypeAlignment(B.GV->getValueType())); + Align Alignment = M.getDataLayout().getValueOrABITypeAlignment( + B.GV->getAlign(), B.GV->getValueType()); B.Before.Bytes.resize(alignTo(B.Before.Bytes.size(), Alignment)); // Before was stored in reverse order; flip it now. @@ -1562,13 +1707,14 @@ bool DevirtModule::areRemarksEnabled() { return false; } -void DevirtModule::scanTypeTestUsers(Function *TypeTestFunc) { +void DevirtModule::scanTypeTestUsers( + Function *TypeTestFunc, + DenseMap<Metadata *, std::set<TypeMemberInfo>> &TypeIdMap) { // Find all virtual calls via a virtual table pointer %p under an assumption // of the form llvm.assume(llvm.type.test(%p, %md)). This indicates that %p // points to a member of the type identifier %md. Group calls by (type ID, // offset) pair (effectively the identity of the virtual function) and store // to CallSlots. 
- DenseSet<CallSite> SeenCallSites; for (auto I = TypeTestFunc->use_begin(), E = TypeTestFunc->use_end(); I != E;) { auto CI = dyn_cast<CallInst>(I->getUser()); @@ -1582,29 +1728,59 @@ void DevirtModule::scanTypeTestUsers(Function *TypeTestFunc) { auto &DT = LookupDomTree(*CI->getFunction()); findDevirtualizableCallsForTypeTest(DevirtCalls, Assumes, CI, DT); + Metadata *TypeId = + cast<MetadataAsValue>(CI->getArgOperand(1))->getMetadata(); // If we found any, add them to CallSlots. if (!Assumes.empty()) { - Metadata *TypeId = - cast<MetadataAsValue>(CI->getArgOperand(1))->getMetadata(); Value *Ptr = CI->getArgOperand(0)->stripPointerCasts(); - for (DevirtCallSite Call : DevirtCalls) { - // Only add this CallSite if we haven't seen it before. The vtable - // pointer may have been CSE'd with pointers from other call sites, - // and we don't want to process call sites multiple times. We can't - // just skip the vtable Ptr if it has been seen before, however, since - // it may be shared by type tests that dominate different calls. - if (SeenCallSites.insert(Call.CS).second) - CallSlots[{TypeId, Call.Offset}].addCallSite(Ptr, Call.CS, nullptr); - } + for (DevirtCallSite Call : DevirtCalls) + CallSlots[{TypeId, Call.Offset}].addCallSite(Ptr, Call.CB, nullptr); } - // We no longer need the assumes or the type test. - for (auto Assume : Assumes) - Assume->eraseFromParent(); - // We can't use RecursivelyDeleteTriviallyDeadInstructions here because we - // may use the vtable argument later. - if (CI->use_empty()) - CI->eraseFromParent(); + auto RemoveTypeTestAssumes = [&]() { + // We no longer need the assumes or the type test. + for (auto Assume : Assumes) + Assume->eraseFromParent(); + // We can't use RecursivelyDeleteTriviallyDeadInstructions here because we + // may use the vtable argument later. + if (CI->use_empty()) + CI->eraseFromParent(); + }; + + // At this point we could remove all type test assume sequences, as they + // were originally inserted for WPD. 
However, we can keep these in the + // code stream for later analysis (e.g. to help drive more efficient ICP + // sequences). They will eventually be removed by a second LowerTypeTests + // invocation that cleans them up. In order to do this correctly, the first + // LowerTypeTests invocation needs to know that they have "Unknown" type + // test resolution, so that they aren't treated as Unsat and lowered to + // False, which will break any uses on assumes. Below we remove any type + // test assumes that will not be treated as Unknown by LTT. + + // The type test assumes will be treated by LTT as Unsat if the type id is + // not used on a global (in which case it has no entry in the TypeIdMap). + if (!TypeIdMap.count(TypeId)) + RemoveTypeTestAssumes(); + + // For ThinLTO importing, we need to remove the type test assumes if this is + // an MDString type id without a corresponding TypeIdSummary. Any + // non-MDString type ids are ignored and treated as Unknown by LTT, so their + // type test assumes can be kept. If the MDString type id is missing a + // TypeIdSummary (e.g. because there was no use on a vcall, preventing the + // exporting phase of WPD from analyzing it), then it would be treated as + // Unsat by LTT and we need to remove its type test assumes here. If not + // used on a vcall we don't need them for later optimization use in any + // case. + else if (ImportSummary && isa<MDString>(TypeId)) { + const TypeIdSummary *TidSummary = + ImportSummary->getTypeIdSummary(cast<MDString>(TypeId)->getString()); + if (!TidSummary) + RemoveTypeTestAssumes(); + else + // If one was created it should not be Unsat, because if we reached here + // the type id was used on a global. 
+ assert(TidSummary->TTRes.TheKind != TypeTestResolution::Unsat); + } } } @@ -1680,7 +1856,7 @@ void DevirtModule::scanTypeCheckedLoadUsers(Function *TypeCheckedLoadFunc) { if (HasNonCallUses) ++NumUnsafeUses; for (DevirtCallSite Call : DevirtCalls) { - CallSlots[{TypeId, Call.Offset}].addCallSite(Ptr, Call.CS, + CallSlots[{TypeId, Call.Offset}].addCallSite(Ptr, Call.CB, &NumUnsafeUses); } @@ -1796,8 +1972,13 @@ bool DevirtModule::run() { (!TypeCheckedLoadFunc || TypeCheckedLoadFunc->use_empty())) return false; + // Rebuild type metadata into a map for easy lookup. + std::vector<VTableBits> Bits; + DenseMap<Metadata *, std::set<TypeMemberInfo>> TypeIdMap; + buildTypeIdentifierMap(Bits, TypeIdMap); + if (TypeTestFunc && AssumeFunc) - scanTypeTestUsers(TypeTestFunc); + scanTypeTestUsers(TypeTestFunc, TypeIdMap); if (TypeCheckedLoadFunc) scanTypeCheckedLoadUsers(TypeCheckedLoadFunc); @@ -1808,15 +1989,17 @@ bool DevirtModule::run() { removeRedundantTypeTests(); + // We have lowered or deleted the type intrinsics, so we will no + // longer have enough information to reason about the liveness of virtual + // function pointers in GlobalDCE. + for (GlobalVariable &GV : M.globals()) + GV.eraseMetadata(LLVMContext::MD_vcall_visibility); + // The rest of the code is only necessary when exporting or during regular // LTO, so we are done. return true; } - // Rebuild type metadata into a map for easy lookup. - std::vector<VTableBits> Bits; - DenseMap<Metadata *, std::set<TypeMemberInfo>> TypeIdMap; - buildTypeIdentifierMap(Bits, TypeIdMap); if (TypeIdMap.empty()) return true; @@ -1873,14 +2056,22 @@ bool DevirtModule::run() { // function implementation at offset S.first.ByteOffset, and add to // TargetsForSlot. 
std::vector<VirtualCallTarget> TargetsForSlot; - if (tryFindVirtualCallTargets(TargetsForSlot, TypeIdMap[S.first.TypeID], + WholeProgramDevirtResolution *Res = nullptr; + const std::set<TypeMemberInfo> &TypeMemberInfos = TypeIdMap[S.first.TypeID]; + if (ExportSummary && isa<MDString>(S.first.TypeID) && + TypeMemberInfos.size()) + // For any type id used on a global's type metadata, create the type id + // summary resolution regardless of whether we can devirtualize, so that + // lower type tests knows the type id is not Unsat. If it was not used on + // a global's type metadata, the TypeIdMap entry set will be empty, and + // we don't want to create an entry (with the default Unknown type + // resolution), which can prevent detection of the Unsat. + Res = &ExportSummary + ->getOrInsertTypeIdSummary( + cast<MDString>(S.first.TypeID)->getString()) + .WPDRes[S.first.ByteOffset]; + if (tryFindVirtualCallTargets(TargetsForSlot, TypeMemberInfos, S.first.ByteOffset)) { - WholeProgramDevirtResolution *Res = nullptr; - if (ExportSummary && isa<MDString>(S.first.TypeID)) - Res = &ExportSummary - ->getOrInsertTypeIdSummary( - cast<MDString>(S.first.TypeID)->getString()) - .WPDRes[S.first.ByteOffset]; if (!trySingleImplDevirt(ExportSummary, TargetsForSlot, S.second, Res)) { DidVirtualConstProp |= @@ -1893,7 +2084,7 @@ bool DevirtModule::run() { if (RemarksEnabled) for (const auto &T : TargetsForSlot) if (T.WasDevirt) - DevirtTargets[T.Fn->getName()] = T.Fn; + DevirtTargets[std::string(T.Fn->getName())] = T.Fn; } // CFI-specific: if we are exporting and any llvm.type.checked.load @@ -1931,7 +2122,7 @@ bool DevirtModule::run() { for (VTableBits &B : Bits) rebuildGlobal(B); - // We have lowered or deleted the type checked load intrinsics, so we no + // We have lowered or deleted the type intrinsics, so we will no // longer have enough information to reason about the liveness of virtual // function pointers in GlobalDCE. 
for (GlobalVariable &GV : M.globals()) @@ -1994,11 +2185,14 @@ void DevirtIndex::run() { std::vector<ValueInfo> TargetsForSlot; auto TidSummary = ExportSummary.getTypeIdCompatibleVtableSummary(S.first.TypeID); assert(TidSummary); + // Create the type id summary resolution regardless of whether we can + // devirtualize, so that lower type tests knows the type id is used on + // a global and not Unsat. + WholeProgramDevirtResolution *Res = + &ExportSummary.getOrInsertTypeIdSummary(S.first.TypeID) + .WPDRes[S.first.ByteOffset]; if (tryFindVirtualCallTargets(TargetsForSlot, *TidSummary, S.first.ByteOffset)) { - WholeProgramDevirtResolution *Res = - &ExportSummary.getOrInsertTypeIdSummary(S.first.TypeID) - .WPDRes[S.first.ByteOffset]; if (!trySingleImplDevirt(TargetsForSlot, S.first, S.second, Res, DevirtTargets)) |