diff options
| author | Dimitry Andric <dim@FreeBSD.org> | 2023-12-09 13:28:42 +0000 |
|---|---|---|
| committer | Dimitry Andric <dim@FreeBSD.org> | 2023-12-09 13:28:42 +0000 |
| commit | b1c73532ee8997fe5dfbeb7d223027bdf99758a0 (patch) | |
| tree | 7d6e51c294ab6719475d660217aa0c0ad0526292 /llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp | |
| parent | 7fa27ce4a07f19b07799a767fc29416f3b625afb (diff) | |
Diffstat (limited to 'llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp')
| -rw-r--r-- | llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp | 129 |
1 files changed, 93 insertions, 36 deletions
diff --git a/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp b/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp index d33258642365..85afc020dbf8 100644 --- a/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp +++ b/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp @@ -58,7 +58,6 @@ #include "llvm/ADT/MapVector.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" -#include "llvm/ADT/iterator_range.h" #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/BasicAliasAnalysis.h" #include "llvm/Analysis/OptimizationRemarkEmitter.h" @@ -369,8 +368,6 @@ template <> struct DenseMapInfo<VTableSlotSummary> { } // end namespace llvm -namespace { - // Returns true if the function must be unreachable based on ValueInfo. // // In particular, identifies a function as unreachable in the following @@ -378,7 +375,7 @@ namespace { // 1) All summaries are live. // 2) All function summaries indicate it's unreachable // 3) There is no non-function with the same GUID (which is rare) -bool mustBeUnreachableFunction(ValueInfo TheFnVI) { +static bool mustBeUnreachableFunction(ValueInfo TheFnVI) { if ((!TheFnVI) || TheFnVI.getSummaryList().empty()) { // Returns false if ValueInfo is absent, or the summary list is empty // (e.g., function declarations). @@ -403,6 +400,7 @@ bool mustBeUnreachableFunction(ValueInfo TheFnVI) { return true; } +namespace { // A virtual call site. VTable is the loaded virtual table pointer, and CS is // the indirect virtual call. struct VirtualCallSite { @@ -590,7 +588,7 @@ struct DevirtModule { : M(M), AARGetter(AARGetter), LookupDomTree(LookupDomTree), ExportSummary(ExportSummary), ImportSummary(ImportSummary), Int8Ty(Type::getInt8Ty(M.getContext())), - Int8PtrTy(Type::getInt8PtrTy(M.getContext())), + Int8PtrTy(PointerType::getUnqual(M.getContext())), Int32Ty(Type::getInt32Ty(M.getContext())), Int64Ty(Type::getInt64Ty(M.getContext())), IntPtrTy(M.getDataLayout().getIntPtrType(M.getContext(), 0)), @@ -776,20 +774,59 @@ PreservedAnalyses WholeProgramDevirtPass::run(Module &M, return PreservedAnalyses::none(); } -namespace llvm { // Enable whole program visibility if enabled by client (e.g. linker) or // internal option, and not force disabled. -bool hasWholeProgramVisibility(bool WholeProgramVisibilityEnabledInLTO) { +bool llvm::hasWholeProgramVisibility(bool WholeProgramVisibilityEnabledInLTO) { return (WholeProgramVisibilityEnabledInLTO || WholeProgramVisibility) && !DisableWholeProgramVisibility; } +static bool +typeIDVisibleToRegularObj(StringRef TypeID, + function_ref<bool(StringRef)> IsVisibleToRegularObj) { + // TypeID for member function pointer type is an internal construct + // and won't exist in IsVisibleToRegularObj. The full TypeID + // will be present and participate in invalidation. + if (TypeID.ends_with(".virtual")) + return false; + + // TypeID that doesn't start with Itanium mangling (_ZTS) will be + // non-externally visible types which cannot interact with + // external native files. See CodeGenModule::CreateMetadataIdentifierImpl. + if (!TypeID.consume_front("_ZTS")) + return false; + + // TypeID is keyed off the type name symbol (_ZTS). However, the native + // object may not contain this symbol if it does not contain a key + // function for the base type and thus only contains a reference to the + // type info (_ZTI). To catch this case we query using the type info + // symbol corresponding to the TypeID. + std::string typeInfo = ("_ZTI" + TypeID).str(); + return IsVisibleToRegularObj(typeInfo); +} + +static bool +skipUpdateDueToValidation(GlobalVariable &GV, + function_ref<bool(StringRef)> IsVisibleToRegularObj) { + SmallVector<MDNode *, 2> Types; + GV.getMetadata(LLVMContext::MD_type, Types); + + for (auto Type : Types) + if (auto *TypeID = dyn_cast<MDString>(Type->getOperand(1).get())) + return typeIDVisibleToRegularObj(TypeID->getString(), + IsVisibleToRegularObj); + + return false; +} + /// If whole program visibility asserted, then upgrade all public vcall /// visibility metadata on vtable definitions to linkage unit visibility in /// Module IR (for regular or hybrid LTO). -void updateVCallVisibilityInModule( +void llvm::updateVCallVisibilityInModule( Module &M, bool WholeProgramVisibilityEnabledInLTO, - const DenseSet<GlobalValue::GUID> &DynamicExportSymbols) { + const DenseSet<GlobalValue::GUID> &DynamicExportSymbols, + bool ValidateAllVtablesHaveTypeInfos, + function_ref<bool(StringRef)> IsVisibleToRegularObj) { if (!hasWholeProgramVisibility(WholeProgramVisibilityEnabledInLTO)) return; for (GlobalVariable &GV : M.globals()) { @@ -800,13 +837,19 @@ void updateVCallVisibilityInModule( GV.getVCallVisibility() == GlobalObject::VCallVisibilityPublic && // Don't upgrade the visibility for symbols exported to the dynamic // linker, as we have no information on their eventual use. - !DynamicExportSymbols.count(GV.getGUID())) + !DynamicExportSymbols.count(GV.getGUID()) && + // With validation enabled, we want to exclude symbols visible to + // regular objects. Local symbols will be in this group due to the + // current implementation but those with VCallVisibilityTranslationUnit + // will have already been marked in clang so are unaffected. + !(ValidateAllVtablesHaveTypeInfos && + skipUpdateDueToValidation(GV, IsVisibleToRegularObj))) GV.setVCallVisibilityMetadata(GlobalObject::VCallVisibilityLinkageUnit); } } -void updatePublicTypeTestCalls(Module &M, - bool WholeProgramVisibilityEnabledInLTO) { +void llvm::updatePublicTypeTestCalls(Module &M, + bool WholeProgramVisibilityEnabledInLTO) { Function *PublicTypeTestFunc = M.getFunction(Intrinsic::getName(Intrinsic::public_type_test)); if (!PublicTypeTestFunc) @@ -832,12 +875,26 @@ void updatePublicTypeTestCalls(Module &M, } } +/// Based on typeID string, get all associated vtable GUIDS that are +/// visible to regular objects. +void llvm::getVisibleToRegularObjVtableGUIDs( + ModuleSummaryIndex &Index, + DenseSet<GlobalValue::GUID> &VisibleToRegularObjSymbols, + function_ref<bool(StringRef)> IsVisibleToRegularObj) { + for (const auto &typeID : Index.typeIdCompatibleVtableMap()) { + if (typeIDVisibleToRegularObj(typeID.first, IsVisibleToRegularObj)) + for (const TypeIdOffsetVtableInfo &P : typeID.second) + VisibleToRegularObjSymbols.insert(P.VTableVI.getGUID()); + } +} + /// If whole program visibility asserted, then upgrade all public vcall /// visibility metadata on vtable definition summaries to linkage unit /// visibility in Module summary index (for ThinLTO). -void updateVCallVisibilityInIndex( +void llvm::updateVCallVisibilityInIndex( ModuleSummaryIndex &Index, bool WholeProgramVisibilityEnabledInLTO, - const DenseSet<GlobalValue::GUID> &DynamicExportSymbols) { + const DenseSet<GlobalValue::GUID> &DynamicExportSymbols, + const DenseSet<GlobalValue::GUID> &VisibleToRegularObjSymbols) { if (!hasWholeProgramVisibility(WholeProgramVisibilityEnabledInLTO)) return; for (auto &P : Index) { @@ -850,18 +907,24 @@ void updateVCallVisibilityInIndex( if (!GVar || GVar->getVCallVisibility() != GlobalObject::VCallVisibilityPublic) continue; + // With validation enabled, we want to exclude symbols visible to regular + // objects. Local symbols will be in this group due to the current + // implementation but those with VCallVisibilityTranslationUnit will have + // already been marked in clang so are unaffected. + if (VisibleToRegularObjSymbols.count(P.first)) + continue; GVar->setVCallVisibility(GlobalObject::VCallVisibilityLinkageUnit); } } } -void runWholeProgramDevirtOnIndex( +void llvm::runWholeProgramDevirtOnIndex( ModuleSummaryIndex &Summary, std::set<GlobalValue::GUID> &ExportedGUIDs, std::map<ValueInfo, std::vector<VTableSlotSummary>> &LocalWPDTargetsMap) { DevirtIndex(Summary, ExportedGUIDs, LocalWPDTargetsMap).run(); } -void updateIndexWPDForExports( +void llvm::updateIndexWPDForExports( ModuleSummaryIndex &Summary, function_ref<bool(StringRef, ValueInfo)> isExported, std::map<ValueInfo, std::vector<VTableSlotSummary>> &LocalWPDTargetsMap) { @@ -887,8 +950,6 @@ void updateIndexWPDForExports( } } -} // end namespace llvm - static Error checkCombinedSummaryForTesting(ModuleSummaryIndex *Summary) { // Check that summary index contains regular LTO module when performing // export to prevent occasional use of index from pure ThinLTO compilation @@ -942,7 +1003,7 @@ bool DevirtModule::runForTesting( ExitOnError ExitOnErr( "-wholeprogramdevirt-write-summary: " + ClWriteSummary + ": "); std::error_code EC; - if (StringRef(ClWriteSummary).endswith(".bc")) { + if (StringRef(ClWriteSummary).ends_with(".bc")) { raw_fd_ostream OS(ClWriteSummary, EC, sys::fs::OF_None); ExitOnErr(errorCodeToError(EC)); writeIndexToFile(*Summary, OS); @@ -1045,8 +1106,8 @@ bool DevirtModule::tryFindVirtualCallTargets( } bool DevirtIndex::tryFindVirtualCallTargets( - std::vector<ValueInfo> &TargetsForSlot, const TypeIdCompatibleVtableInfo TIdInfo, - uint64_t ByteOffset) { + std::vector<ValueInfo> &TargetsForSlot, + const TypeIdCompatibleVtableInfo TIdInfo, uint64_t ByteOffset) { for (const TypeIdOffsetVtableInfo &P : TIdInfo) { // Find a representative copy of the vtable initializer. // We can have multiple available_externally, linkonce_odr and weak_odr @@ -1203,7 +1264,8 @@ static bool AddCalls(VTableSlotInfo &SlotInfo, const ValueInfo &Callee) { // to better ensure we have the opportunity to inline them. bool IsExported = false; auto &S = Callee.getSummaryList()[0]; - CalleeInfo CI(CalleeInfo::HotnessType::Hot, /* RelBF = */ 0); + CalleeInfo CI(CalleeInfo::HotnessType::Hot, /* HasTailCall = */ false, + /* RelBF = */ 0); auto AddCalls = [&](CallSiteInfo &CSInfo) { for (auto *FS : CSInfo.SummaryTypeCheckedLoadUsers) { FS->addCall({Callee, CI}); @@ -1437,7 +1499,7 @@ void DevirtModule::applyICallBranchFunnel(VTableSlotInfo &SlotInfo, IRBuilder<> IRB(&CB); std::vector<Value *> Args; - Args.push_back(IRB.CreateBitCast(VCallSite.VTable, Int8PtrTy)); + Args.push_back(VCallSite.VTable); llvm::append_range(Args, CB.args()); CallBase *NewCS = nullptr; @@ -1471,10 +1533,10 @@ void DevirtModule::applyICallBranchFunnel(VTableSlotInfo &SlotInfo, // llvm.type.test and therefore require an llvm.type.test resolution for the // type identifier. - std::for_each(CallBases.begin(), CallBases.end(), [](auto &CBs) { - CBs.first->replaceAllUsesWith(CBs.second); - CBs.first->eraseFromParent(); - }); + for (auto &[Old, New] : CallBases) { + Old->replaceAllUsesWith(New); + Old->eraseFromParent(); + } }; Apply(SlotInfo.CSInfo); for (auto &P : SlotInfo.ConstCSInfo) @@ -1648,8 +1710,7 @@ void DevirtModule::applyUniqueRetValOpt(CallSiteInfo &CSInfo, StringRef FnName, } Constant *DevirtModule::getMemberAddr(const TypeMemberInfo *M) { - Constant *C = ConstantExpr::getBitCast(M->Bits->GV, Int8PtrTy); - return ConstantExpr::getGetElementPtr(Int8Ty, C, + return ConstantExpr::getGetElementPtr(Int8Ty, M->Bits->GV, ConstantInt::get(Int64Ty, M->Offset)); } @@ -1708,8 +1769,7 @@ void DevirtModule::applyVirtualConstProp(CallSiteInfo &CSInfo, StringRef FnName, continue; auto *RetType = cast<IntegerType>(Call.CB.getType()); IRBuilder<> B(&Call.CB); - Value *Addr = - B.CreateGEP(Int8Ty, B.CreateBitCast(Call.VTable, Int8PtrTy), Byte); + Value *Addr = B.CreateGEP(Int8Ty, Call.VTable, Byte); if (RetType->getBitWidth() == 1) { Value *Bits = B.CreateLoad(Int8Ty, Addr); Value *BitsAndBit = B.CreateAnd(Bits, Bit); @@ -2007,17 +2067,14 @@ void DevirtModule::scanTypeCheckedLoadUsers(Function *TypeCheckedLoadFunc) { if (TypeCheckedLoadFunc->getIntrinsicID() == Intrinsic::type_checked_load_relative) { Value *GEP = LoadB.CreateGEP(Int8Ty, Ptr, Offset); - Value *GEPPtr = LoadB.CreateBitCast(GEP, PointerType::getUnqual(Int32Ty)); - LoadedValue = LoadB.CreateLoad(Int32Ty, GEPPtr); + LoadedValue = LoadB.CreateLoad(Int32Ty, GEP); LoadedValue = LoadB.CreateSExt(LoadedValue, IntPtrTy); GEP = LoadB.CreatePtrToInt(GEP, IntPtrTy); LoadedValue = LoadB.CreateAdd(GEP, LoadedValue); LoadedValue = LoadB.CreateIntToPtr(LoadedValue, Int8PtrTy); } else { Value *GEP = LoadB.CreateGEP(Int8Ty, Ptr, Offset); - Value *GEPPtr = - LoadB.CreateBitCast(GEP, PointerType::getUnqual(Int8PtrTy)); - LoadedValue = LoadB.CreateLoad(Int8PtrTy, GEPPtr); + LoadedValue = LoadB.CreateLoad(Int8PtrTy, GEP); } for (Instruction *LoadedPtr : LoadedPtrs) { |
