summaryrefslogtreecommitdiff
path: root/lib/Transforms/IPO/CrossDSOCFI.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'lib/Transforms/IPO/CrossDSOCFI.cpp')
-rw-r--r--lib/Transforms/IPO/CrossDSOCFI.cpp118
1 files changed, 58 insertions, 60 deletions
diff --git a/lib/Transforms/IPO/CrossDSOCFI.cpp b/lib/Transforms/IPO/CrossDSOCFI.cpp
index 5bbb7513005c6..58731eaf6e30f 100644
--- a/lib/Transforms/IPO/CrossDSOCFI.cpp
+++ b/lib/Transforms/IPO/CrossDSOCFI.cpp
@@ -12,7 +12,7 @@
//
//===----------------------------------------------------------------------===//
-#include "llvm/Transforms/IPO.h"
+#include "llvm/Transforms/IPO/CrossDSOCFI.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/EquivalenceClasses.h"
#include "llvm/ADT/Statistic.h"
@@ -30,13 +30,14 @@
#include "llvm/Pass.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
+#include "llvm/Transforms/IPO.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
using namespace llvm;
#define DEBUG_TYPE "cross-dso-cfi"
-STATISTIC(TypeIds, "Number of unique type identifiers");
+STATISTIC(NumTypeIds, "Number of unique type identifiers");
namespace {
@@ -46,13 +47,10 @@ struct CrossDSOCFI : public ModulePass {
initializeCrossDSOCFIPass(*PassRegistry::getPassRegistry());
}
- Module *M;
MDNode *VeryLikelyWeights;
- ConstantInt *extractBitSetTypeId(MDNode *MD);
- void buildCFICheck();
-
- bool doInitialization(Module &M) override;
+ ConstantInt *extractNumericTypeId(MDNode *MD);
+ void buildCFICheck(Module &M);
bool runOnModule(Module &M) override;
};
@@ -65,18 +63,10 @@ char CrossDSOCFI::ID = 0;
ModulePass *llvm::createCrossDSOCFIPass() { return new CrossDSOCFI; }
-bool CrossDSOCFI::doInitialization(Module &Mod) {
- M = &Mod;
- VeryLikelyWeights =
- MDBuilder(M->getContext()).createBranchWeights((1U << 20) - 1, 1);
-
- return false;
-}
-
-/// extractBitSetTypeId - Extracts TypeId from a hash-based bitset MDNode.
-ConstantInt *CrossDSOCFI::extractBitSetTypeId(MDNode *MD) {
+/// Extracts a numeric type identifier from an MDNode containing type metadata.
+ConstantInt *CrossDSOCFI::extractNumericTypeId(MDNode *MD) {
// This check excludes vtables for classes inside anonymous namespaces.
- auto TM = dyn_cast<ValueAsMetadata>(MD->getOperand(0));
+ auto TM = dyn_cast<ValueAsMetadata>(MD->getOperand(1));
if (!TM)
return nullptr;
auto C = dyn_cast_or_null<ConstantInt>(TM->getValue());
@@ -84,68 +74,63 @@ ConstantInt *CrossDSOCFI::extractBitSetTypeId(MDNode *MD) {
// We are looking for i64 constants.
if (C->getBitWidth() != 64) return nullptr;
- // Sanity check.
- auto FM = dyn_cast_or_null<ValueAsMetadata>(MD->getOperand(1));
- // Can be null if a function was removed by an optimization.
- if (FM) {
- auto F = dyn_cast<Function>(FM->getValue());
- // But can never be a function declaration.
- assert(!F || !F->isDeclaration());
- (void)F; // Suppress unused variable warning in the no-asserts build.
- }
return C;
}
/// buildCFICheck - emits __cfi_check for the current module.
-void CrossDSOCFI::buildCFICheck() {
+void CrossDSOCFI::buildCFICheck(Module &M) {
// FIXME: verify that __cfi_check ends up near the end of the code section,
- // but before the jump slots created in LowerBitSets.
- llvm::DenseSet<uint64_t> BitSetIds;
- NamedMDNode *BitSetNM = M->getNamedMetadata("llvm.bitsets");
-
- if (BitSetNM)
- for (unsigned I = 0, E = BitSetNM->getNumOperands(); I != E; ++I)
- if (ConstantInt *TypeId = extractBitSetTypeId(BitSetNM->getOperand(I)))
- BitSetIds.insert(TypeId->getZExtValue());
-
- LLVMContext &Ctx = M->getContext();
- Constant *C = M->getOrInsertFunction(
- "__cfi_check",
- FunctionType::get(
- Type::getVoidTy(Ctx),
- {Type::getInt64Ty(Ctx), PointerType::getUnqual(Type::getInt8Ty(Ctx))},
- false));
+ // but before the jump slots created in LowerTypeTests.
+ llvm::DenseSet<uint64_t> TypeIds;
+ SmallVector<MDNode *, 2> Types;
+ for (GlobalObject &GO : M.global_objects()) {
+ Types.clear();
+ GO.getMetadata(LLVMContext::MD_type, Types);
+ for (MDNode *Type : Types) {
+ // Sanity check. GO must not be a function declaration.
+ assert(!isa<Function>(&GO) || !cast<Function>(&GO)->isDeclaration());
+
+ if (ConstantInt *TypeId = extractNumericTypeId(Type))
+ TypeIds.insert(TypeId->getZExtValue());
+ }
+ }
+
+ LLVMContext &Ctx = M.getContext();
+ Constant *C = M.getOrInsertFunction(
+ "__cfi_check", Type::getVoidTy(Ctx), Type::getInt64Ty(Ctx),
+ Type::getInt8PtrTy(Ctx), Type::getInt8PtrTy(Ctx), nullptr);
Function *F = dyn_cast<Function>(C);
F->setAlignment(4096);
auto args = F->arg_begin();
- Argument &CallSiteTypeId = *(args++);
+ Value &CallSiteTypeId = *(args++);
CallSiteTypeId.setName("CallSiteTypeId");
- Argument &Addr = *(args++);
+ Value &Addr = *(args++);
Addr.setName("Addr");
+ Value &CFICheckFailData = *(args++);
+ CFICheckFailData.setName("CFICheckFailData");
assert(args == F->arg_end());
BasicBlock *BB = BasicBlock::Create(Ctx, "entry", F);
+ BasicBlock *ExitBB = BasicBlock::Create(Ctx, "exit", F);
- BasicBlock *TrapBB = BasicBlock::Create(Ctx, "trap", F);
- IRBuilder<> IRBTrap(TrapBB);
- Function *TrapFn = Intrinsic::getDeclaration(M, Intrinsic::trap);
- llvm::CallInst *TrapCall = IRBTrap.CreateCall(TrapFn);
- TrapCall->setDoesNotReturn();
- TrapCall->setDoesNotThrow();
- IRBTrap.CreateUnreachable();
+ BasicBlock *TrapBB = BasicBlock::Create(Ctx, "fail", F);
+ IRBuilder<> IRBFail(TrapBB);
+ Constant *CFICheckFailFn = M.getOrInsertFunction(
+ "__cfi_check_fail", Type::getVoidTy(Ctx), Type::getInt8PtrTy(Ctx),
+ Type::getInt8PtrTy(Ctx), nullptr);
+ IRBFail.CreateCall(CFICheckFailFn, {&CFICheckFailData, &Addr});
+ IRBFail.CreateBr(ExitBB);
- BasicBlock *ExitBB = BasicBlock::Create(Ctx, "exit", F);
IRBuilder<> IRBExit(ExitBB);
IRBExit.CreateRetVoid();
IRBuilder<> IRB(BB);
- SwitchInst *SI = IRB.CreateSwitch(&CallSiteTypeId, TrapBB, BitSetIds.size());
- for (uint64_t TypeId : BitSetIds) {
+ SwitchInst *SI = IRB.CreateSwitch(&CallSiteTypeId, TrapBB, TypeIds.size());
+ for (uint64_t TypeId : TypeIds) {
ConstantInt *CaseTypeId = ConstantInt::get(Type::getInt64Ty(Ctx), TypeId);
BasicBlock *TestBB = BasicBlock::Create(Ctx, "test", F);
IRBuilder<> IRBTest(TestBB);
- Function *BitsetTestFn =
- Intrinsic::getDeclaration(M, Intrinsic::bitset_test);
+ Function *BitsetTestFn = Intrinsic::getDeclaration(&M, Intrinsic::type_test);
Value *Test = IRBTest.CreateCall(
BitsetTestFn, {&Addr, MetadataAsValue::get(
@@ -154,13 +139,26 @@ void CrossDSOCFI::buildCFICheck() {
BI->setMetadata(LLVMContext::MD_prof, VeryLikelyWeights);
SI->addCase(CaseTypeId, TestBB);
- ++TypeIds;
+ ++NumTypeIds;
}
}
bool CrossDSOCFI::runOnModule(Module &M) {
+ if (skipModule(M))
+ return false;
+
+ VeryLikelyWeights =
+ MDBuilder(M.getContext()).createBranchWeights((1U << 20) - 1, 1);
if (M.getModuleFlag("Cross-DSO CFI") == nullptr)
return false;
- buildCFICheck();
+ buildCFICheck(M);
return true;
}
+
+PreservedAnalyses CrossDSOCFIPass::run(Module &M, AnalysisManager<Module> &AM) {
+ CrossDSOCFI Impl;
+ bool Changed = Impl.runOnModule(M);
+ if (!Changed)
+ return PreservedAnalyses::all();
+ return PreservedAnalyses::none();
+}