aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
diff options
context:
space:
mode:
author Dimitry Andric <dim@FreeBSD.org> 2023-12-09 13:28:42 +0000
committer Dimitry Andric <dim@FreeBSD.org> 2023-12-09 13:28:42 +0000
commit b1c73532ee8997fe5dfbeb7d223027bdf99758a0 (patch)
tree 7d6e51c294ab6719475d660217aa0c0ad0526292 /llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
parent 7fa27ce4a07f19b07799a767fc29416f3b625afb (diff)
Diffstat (limited to 'llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp')
-rw-r--r-- llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp 129
1 files changed, 93 insertions, 36 deletions
diff --git a/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp b/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
index d33258642365..85afc020dbf8 100644
--- a/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
+++ b/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
@@ -58,7 +58,6 @@
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
-#include "llvm/ADT/iterator_range.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/BasicAliasAnalysis.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
@@ -369,8 +368,6 @@ template <> struct DenseMapInfo<VTableSlotSummary> {
} // end namespace llvm
-namespace {
-
// Returns true if the function must be unreachable based on ValueInfo.
//
// In particular, identifies a function as unreachable in the following
@@ -378,7 +375,7 @@ namespace {
// 1) All summaries are live.
// 2) All function summaries indicate it's unreachable
// 3) There is no non-function with the same GUID (which is rare)
-bool mustBeUnreachableFunction(ValueInfo TheFnVI) {
+static bool mustBeUnreachableFunction(ValueInfo TheFnVI) {
if ((!TheFnVI) || TheFnVI.getSummaryList().empty()) {
// Returns false if ValueInfo is absent, or the summary list is empty
// (e.g., function declarations).
@@ -403,6 +400,7 @@ bool mustBeUnreachableFunction(ValueInfo TheFnVI) {
return true;
}
+namespace {
// A virtual call site. VTable is the loaded virtual table pointer, and CS is
// the indirect virtual call.
struct VirtualCallSite {
@@ -590,7 +588,7 @@ struct DevirtModule {
: M(M), AARGetter(AARGetter), LookupDomTree(LookupDomTree),
ExportSummary(ExportSummary), ImportSummary(ImportSummary),
Int8Ty(Type::getInt8Ty(M.getContext())),
- Int8PtrTy(Type::getInt8PtrTy(M.getContext())),
+ Int8PtrTy(PointerType::getUnqual(M.getContext())),
Int32Ty(Type::getInt32Ty(M.getContext())),
Int64Ty(Type::getInt64Ty(M.getContext())),
IntPtrTy(M.getDataLayout().getIntPtrType(M.getContext(), 0)),
@@ -776,20 +774,59 @@ PreservedAnalyses WholeProgramDevirtPass::run(Module &M,
return PreservedAnalyses::none();
}
-namespace llvm {
// Enable whole program visibility if enabled by client (e.g. linker) or
// internal option, and not force disabled.
-bool hasWholeProgramVisibility(bool WholeProgramVisibilityEnabledInLTO) {
+bool llvm::hasWholeProgramVisibility(bool WholeProgramVisibilityEnabledInLTO) {
+  // DisableWholeProgramVisibility takes precedence over both enabling sources.
return (WholeProgramVisibilityEnabledInLTO || WholeProgramVisibility) &&
!DisableWholeProgramVisibility;
}
+/// Returns true if \p TypeID is an Itanium-mangled type name (_ZTS prefix)
+/// whose type info symbol (_ZTI<name>) is reported visible to regular objects
+/// by \p IsVisibleToRegularObj.
+static bool
+typeIDVisibleToRegularObj(StringRef TypeID,
+                          function_ref<bool(StringRef)> IsVisibleToRegularObj) {
+  // Member function pointer type IDs (".virtual" suffix) are an internal
+  // construct that IsVisibleToRegularObj never knows about; the full TypeID is
+  // still present and participates in invalidation.
+  if (TypeID.ends_with(".virtual"))
+    return false;
+
+  // Only Itanium-mangled type names (_ZTS) can interact with external native
+  // files; any other TypeID belongs to a non-externally visible type. See
+  // CodeGenModule::CreateMetadataIdentifierImpl.
+  StringRef MangledName = TypeID;
+  if (!MangledName.consume_front("_ZTS"))
+    return false;
+
+  // The TypeID is keyed off the type name symbol (_ZTS), but a native object
+  // without the key function for the base type may only reference the type
+  // info (_ZTI). Query with the corresponding type info symbol to catch that.
+  const std::string TypeInfoName = ("_ZTI" + MangledName).str();
+  return IsVisibleToRegularObj(TypeInfoName);
+}
+
+/// Returns true if any !type attachment of \p GV carries a string type ID that
+/// typeIDVisibleToRegularObj reports as visible to regular objects, in which
+/// case the vcall visibility upgrade must be skipped for \p GV.
+static bool
+skipUpdateDueToValidation(GlobalVariable &GV,
+                          function_ref<bool(StringRef)> IsVisibleToRegularObj) {
+  SmallVector<MDNode *, 2> Types;
+  GV.getMetadata(LLVMContext::MD_type, Types);
+
+  // GV may have several !type attachments. Check every string type ID instead
+  // of returning the verdict for the first one: otherwise a visible type ID in
+  // a later attachment is ignored and the outcome depends on attachment order.
+  // This matches the summary-based path, where
+  // getVisibleToRegularObjVtableGUIDs excludes a vtable if any of its
+  // compatible type IDs is visible.
+  for (MDNode *Type : Types)
+    if (auto *TypeID = dyn_cast<MDString>(Type->getOperand(1).get()))
+      if (typeIDVisibleToRegularObj(TypeID->getString(),
+                                    IsVisibleToRegularObj))
+        return true;
+
+  return false;
+}
+
/// If whole program visibility asserted, then upgrade all public vcall
/// visibility metadata on vtable definitions to linkage unit visibility in
/// Module IR (for regular or hybrid LTO).
-void updateVCallVisibilityInModule(
+void llvm::updateVCallVisibilityInModule(
Module &M, bool WholeProgramVisibilityEnabledInLTO,
- const DenseSet<GlobalValue::GUID> &DynamicExportSymbols) {
+ const DenseSet<GlobalValue::GUID> &DynamicExportSymbols,
+ bool ValidateAllVtablesHaveTypeInfos,
+ function_ref<bool(StringRef)> IsVisibleToRegularObj) {
if (!hasWholeProgramVisibility(WholeProgramVisibilityEnabledInLTO))
return;
for (GlobalVariable &GV : M.globals()) {
@@ -800,13 +837,19 @@ void updateVCallVisibilityInModule(
GV.getVCallVisibility() == GlobalObject::VCallVisibilityPublic &&
// Don't upgrade the visibility for symbols exported to the dynamic
// linker, as we have no information on their eventual use.
- !DynamicExportSymbols.count(GV.getGUID()))
+ !DynamicExportSymbols.count(GV.getGUID()) &&
+ // With validation enabled, we want to exclude symbols visible to
+ // regular objects. Local symbols will be in this group due to the
+ // current implementation but those with VCallVisibilityTranslationUnit
+ // will have already been marked in clang so are unaffected.
+ !(ValidateAllVtablesHaveTypeInfos &&
+ skipUpdateDueToValidation(GV, IsVisibleToRegularObj)))
GV.setVCallVisibilityMetadata(GlobalObject::VCallVisibilityLinkageUnit);
}
}
-void updatePublicTypeTestCalls(Module &M,
- bool WholeProgramVisibilityEnabledInLTO) {
+void llvm::updatePublicTypeTestCalls(Module &M,
+ bool WholeProgramVisibilityEnabledInLTO) {
Function *PublicTypeTestFunc =
M.getFunction(Intrinsic::getName(Intrinsic::public_type_test));
if (!PublicTypeTestFunc)
@@ -832,12 +875,26 @@ void updatePublicTypeTestCalls(Module &M,
}
}
+/// Inserts into \p VisibleToRegularObjSymbols the GUID of every vtable that is
+/// compatible with a type ID reported visible to regular objects by
+/// \p IsVisibleToRegularObj. Existing entries in the set are kept.
+void llvm::getVisibleToRegularObjVtableGUIDs(
+    ModuleSummaryIndex &Index,
+    DenseSet<GlobalValue::GUID> &VisibleToRegularObjSymbols,
+    function_ref<bool(StringRef)> IsVisibleToRegularObj) {
+  for (const auto &[TypeIdName, CompatibleVtables] :
+       Index.typeIdCompatibleVtableMap()) {
+    if (!typeIDVisibleToRegularObj(TypeIdName, IsVisibleToRegularObj))
+      continue;
+    for (const TypeIdOffsetVtableInfo &VtableInfo : CompatibleVtables)
+      VisibleToRegularObjSymbols.insert(VtableInfo.VTableVI.getGUID());
+  }
+}
+
/// If whole program visibility asserted, then upgrade all public vcall
/// visibility metadata on vtable definition summaries to linkage unit
/// visibility in Module summary index (for ThinLTO).
-void updateVCallVisibilityInIndex(
+void llvm::updateVCallVisibilityInIndex(
ModuleSummaryIndex &Index, bool WholeProgramVisibilityEnabledInLTO,
- const DenseSet<GlobalValue::GUID> &DynamicExportSymbols) {
+ const DenseSet<GlobalValue::GUID> &DynamicExportSymbols,
+ const DenseSet<GlobalValue::GUID> &VisibleToRegularObjSymbols) {
if (!hasWholeProgramVisibility(WholeProgramVisibilityEnabledInLTO))
return;
for (auto &P : Index) {
@@ -850,18 +907,24 @@ void updateVCallVisibilityInIndex(
if (!GVar ||
GVar->getVCallVisibility() != GlobalObject::VCallVisibilityPublic)
continue;
+ // With validation enabled, we want to exclude symbols visible to regular
+ // objects. Local symbols will be in this group due to the current
+ // implementation but those with VCallVisibilityTranslationUnit will have
+ // already been marked in clang so are unaffected.
+ if (VisibleToRegularObjSymbols.count(P.first))
+ continue;
GVar->setVCallVisibility(GlobalObject::VCallVisibilityLinkageUnit);
}
}
}
+/// Runs DevirtIndex over the summary index \p Summary. \p ExportedGUIDs and
+/// \p LocalWPDTargetsMap are handed to DevirtIndex by non-const reference
+/// (presumably populated as outputs -- confirm against DevirtIndex).
-void runWholeProgramDevirtOnIndex(
+void llvm::runWholeProgramDevirtOnIndex(
ModuleSummaryIndex &Summary, std::set<GlobalValue::GUID> &ExportedGUIDs,
std::map<ValueInfo, std::vector<VTableSlotSummary>> &LocalWPDTargetsMap) {
DevirtIndex(Summary, ExportedGUIDs, LocalWPDTargetsMap).run();
}
-void updateIndexWPDForExports(
+void llvm::updateIndexWPDForExports(
ModuleSummaryIndex &Summary,
function_ref<bool(StringRef, ValueInfo)> isExported,
std::map<ValueInfo, std::vector<VTableSlotSummary>> &LocalWPDTargetsMap) {
@@ -887,8 +950,6 @@ void updateIndexWPDForExports(
}
}
-} // end namespace llvm
-
static Error checkCombinedSummaryForTesting(ModuleSummaryIndex *Summary) {
// Check that summary index contains regular LTO module when performing
// export to prevent occasional use of index from pure ThinLTO compilation
@@ -942,7 +1003,7 @@ bool DevirtModule::runForTesting(
ExitOnError ExitOnErr(
"-wholeprogramdevirt-write-summary: " + ClWriteSummary + ": ");
std::error_code EC;
- if (StringRef(ClWriteSummary).endswith(".bc")) {
+ if (StringRef(ClWriteSummary).ends_with(".bc")) {
raw_fd_ostream OS(ClWriteSummary, EC, sys::fs::OF_None);
ExitOnErr(errorCodeToError(EC));
writeIndexToFile(*Summary, OS);
@@ -1045,8 +1106,8 @@ bool DevirtModule::tryFindVirtualCallTargets(
}
bool DevirtIndex::tryFindVirtualCallTargets(
- std::vector<ValueInfo> &TargetsForSlot, const TypeIdCompatibleVtableInfo TIdInfo,
- uint64_t ByteOffset) {
+ std::vector<ValueInfo> &TargetsForSlot,
+ const TypeIdCompatibleVtableInfo TIdInfo, uint64_t ByteOffset) {
for (const TypeIdOffsetVtableInfo &P : TIdInfo) {
// Find a representative copy of the vtable initializer.
// We can have multiple available_externally, linkonce_odr and weak_odr
@@ -1203,7 +1264,8 @@ static bool AddCalls(VTableSlotInfo &SlotInfo, const ValueInfo &Callee) {
// to better ensure we have the opportunity to inline them.
bool IsExported = false;
auto &S = Callee.getSummaryList()[0];
- CalleeInfo CI(CalleeInfo::HotnessType::Hot, /* RelBF = */ 0);
+ CalleeInfo CI(CalleeInfo::HotnessType::Hot, /* HasTailCall = */ false,
+ /* RelBF = */ 0);
auto AddCalls = [&](CallSiteInfo &CSInfo) {
for (auto *FS : CSInfo.SummaryTypeCheckedLoadUsers) {
FS->addCall({Callee, CI});
@@ -1437,7 +1499,7 @@ void DevirtModule::applyICallBranchFunnel(VTableSlotInfo &SlotInfo,
IRBuilder<> IRB(&CB);
std::vector<Value *> Args;
- Args.push_back(IRB.CreateBitCast(VCallSite.VTable, Int8PtrTy));
+ Args.push_back(VCallSite.VTable);
llvm::append_range(Args, CB.args());
CallBase *NewCS = nullptr;
@@ -1471,10 +1533,10 @@ void DevirtModule::applyICallBranchFunnel(VTableSlotInfo &SlotInfo,
// llvm.type.test and therefore require an llvm.type.test resolution for the
// type identifier.
- std::for_each(CallBases.begin(), CallBases.end(), [](auto &CBs) {
- CBs.first->replaceAllUsesWith(CBs.second);
- CBs.first->eraseFromParent();
- });
+ for (auto &[Old, New] : CallBases) {
+ Old->replaceAllUsesWith(New);
+ Old->eraseFromParent();
+ }
};
Apply(SlotInfo.CSInfo);
for (auto &P : SlotInfo.ConstCSInfo)
@@ -1648,8 +1710,7 @@ void DevirtModule::applyUniqueRetValOpt(CallSiteInfo &CSInfo, StringRef FnName,
}
Constant *DevirtModule::getMemberAddr(const TypeMemberInfo *M) {
- Constant *C = ConstantExpr::getBitCast(M->Bits->GV, Int8PtrTy);
- return ConstantExpr::getGetElementPtr(Int8Ty, C,
+  // i8-typed GEP M->Offset bytes past M->Bits->GV; the former bitcast of the
+  // global to Int8PtrTy is no longer emitted.
+ return ConstantExpr::getGetElementPtr(Int8Ty, M->Bits->GV,
ConstantInt::get(Int64Ty, M->Offset));
}
@@ -1708,8 +1769,7 @@ void DevirtModule::applyVirtualConstProp(CallSiteInfo &CSInfo, StringRef FnName,
continue;
auto *RetType = cast<IntegerType>(Call.CB.getType());
IRBuilder<> B(&Call.CB);
- Value *Addr =
- B.CreateGEP(Int8Ty, B.CreateBitCast(Call.VTable, Int8PtrTy), Byte);
+ Value *Addr = B.CreateGEP(Int8Ty, Call.VTable, Byte);
if (RetType->getBitWidth() == 1) {
Value *Bits = B.CreateLoad(Int8Ty, Addr);
Value *BitsAndBit = B.CreateAnd(Bits, Bit);
@@ -2007,17 +2067,14 @@ void DevirtModule::scanTypeCheckedLoadUsers(Function *TypeCheckedLoadFunc) {
if (TypeCheckedLoadFunc->getIntrinsicID() ==
Intrinsic::type_checked_load_relative) {
Value *GEP = LoadB.CreateGEP(Int8Ty, Ptr, Offset);
- Value *GEPPtr = LoadB.CreateBitCast(GEP, PointerType::getUnqual(Int32Ty));
- LoadedValue = LoadB.CreateLoad(Int32Ty, GEPPtr);
+ LoadedValue = LoadB.CreateLoad(Int32Ty, GEP);
LoadedValue = LoadB.CreateSExt(LoadedValue, IntPtrTy);
GEP = LoadB.CreatePtrToInt(GEP, IntPtrTy);
LoadedValue = LoadB.CreateAdd(GEP, LoadedValue);
LoadedValue = LoadB.CreateIntToPtr(LoadedValue, Int8PtrTy);
} else {
Value *GEP = LoadB.CreateGEP(Int8Ty, Ptr, Offset);
- Value *GEPPtr =
- LoadB.CreateBitCast(GEP, PointerType::getUnqual(Int8PtrTy));
- LoadedValue = LoadB.CreateLoad(Int8PtrTy, GEPPtr);
+ LoadedValue = LoadB.CreateLoad(Int8PtrTy, GEP);
}
for (Instruction *LoadedPtr : LoadedPtrs) {