diff options
Diffstat (limited to 'lib/Transforms/IPO/CrossDSOCFI.cpp')
-rw-r--r-- | lib/Transforms/IPO/CrossDSOCFI.cpp | 118 |
1 files changed, 58 insertions, 60 deletions
diff --git a/lib/Transforms/IPO/CrossDSOCFI.cpp b/lib/Transforms/IPO/CrossDSOCFI.cpp index 5bbb7513005c6..58731eaf6e30f 100644 --- a/lib/Transforms/IPO/CrossDSOCFI.cpp +++ b/lib/Transforms/IPO/CrossDSOCFI.cpp @@ -12,7 +12,7 @@ // //===----------------------------------------------------------------------===// -#include "llvm/Transforms/IPO.h" +#include "llvm/Transforms/IPO/CrossDSOCFI.h" #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/EquivalenceClasses.h" #include "llvm/ADT/Statistic.h" @@ -30,13 +30,14 @@ #include "llvm/Pass.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/IPO.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" using namespace llvm; #define DEBUG_TYPE "cross-dso-cfi" -STATISTIC(TypeIds, "Number of unique type identifiers"); +STATISTIC(NumTypeIds, "Number of unique type identifiers"); namespace { @@ -46,13 +47,10 @@ struct CrossDSOCFI : public ModulePass { initializeCrossDSOCFIPass(*PassRegistry::getPassRegistry()); } - Module *M; MDNode *VeryLikelyWeights; - ConstantInt *extractBitSetTypeId(MDNode *MD); - void buildCFICheck(); - - bool doInitialization(Module &M) override; + ConstantInt *extractNumericTypeId(MDNode *MD); + void buildCFICheck(Module &M); bool runOnModule(Module &M) override; }; @@ -65,18 +63,10 @@ char CrossDSOCFI::ID = 0; ModulePass *llvm::createCrossDSOCFIPass() { return new CrossDSOCFI; } -bool CrossDSOCFI::doInitialization(Module &Mod) { - M = &Mod; - VeryLikelyWeights = - MDBuilder(M->getContext()).createBranchWeights((1U << 20) - 1, 1); - - return false; -} - -/// extractBitSetTypeId - Extracts TypeId from a hash-based bitset MDNode. -ConstantInt *CrossDSOCFI::extractBitSetTypeId(MDNode *MD) { +/// Extracts a numeric type identifier from an MDNode containing type metadata. +ConstantInt *CrossDSOCFI::extractNumericTypeId(MDNode *MD) { // This check excludes vtables for classes inside anonymous namespaces. - auto TM = dyn_cast<ValueAsMetadata>(MD->getOperand(0)); + auto TM = dyn_cast<ValueAsMetadata>(MD->getOperand(1)); if (!TM) return nullptr; auto C = dyn_cast_or_null<ConstantInt>(TM->getValue()); @@ -84,68 +74,63 @@ ConstantInt *CrossDSOCFI::extractBitSetTypeId(MDNode *MD) { // We are looking for i64 constants. if (C->getBitWidth() != 64) return nullptr; - // Sanity check. - auto FM = dyn_cast_or_null<ValueAsMetadata>(MD->getOperand(1)); - // Can be null if a function was removed by an optimization. - if (FM) { - auto F = dyn_cast<Function>(FM->getValue()); - // But can never be a function declaration. - assert(!F || !F->isDeclaration()); - (void)F; // Suppress unused variable warning in the no-asserts build. - } return C; } /// buildCFICheck - emits __cfi_check for the current module. -void CrossDSOCFI::buildCFICheck() { +void CrossDSOCFI::buildCFICheck(Module &M) { // FIXME: verify that __cfi_check ends up near the end of the code section, - // but before the jump slots created in LowerBitSets. - llvm::DenseSet<uint64_t> BitSetIds; - NamedMDNode *BitSetNM = M->getNamedMetadata("llvm.bitsets"); - - if (BitSetNM) - for (unsigned I = 0, E = BitSetNM->getNumOperands(); I != E; ++I) - if (ConstantInt *TypeId = extractBitSetTypeId(BitSetNM->getOperand(I))) - BitSetIds.insert(TypeId->getZExtValue()); - - LLVMContext &Ctx = M->getContext(); - Constant *C = M->getOrInsertFunction( - "__cfi_check", - FunctionType::get( - Type::getVoidTy(Ctx), - {Type::getInt64Ty(Ctx), PointerType::getUnqual(Type::getInt8Ty(Ctx))}, - false)); + // but before the jump slots created in LowerTypeTests. + llvm::DenseSet<uint64_t> TypeIds; + SmallVector<MDNode *, 2> Types; + for (GlobalObject &GO : M.global_objects()) { + Types.clear(); + GO.getMetadata(LLVMContext::MD_type, Types); + for (MDNode *Type : Types) { + // Sanity check. GO must not be a function declaration. + assert(!isa<Function>(&GO) || !cast<Function>(&GO)->isDeclaration()); + + if (ConstantInt *TypeId = extractNumericTypeId(Type)) + TypeIds.insert(TypeId->getZExtValue()); + } + } + + LLVMContext &Ctx = M.getContext(); + Constant *C = M.getOrInsertFunction( + "__cfi_check", Type::getVoidTy(Ctx), Type::getInt64Ty(Ctx), + Type::getInt8PtrTy(Ctx), Type::getInt8PtrTy(Ctx), nullptr); Function *F = dyn_cast<Function>(C); F->setAlignment(4096); auto args = F->arg_begin(); - Argument &CallSiteTypeId = *(args++); + Value &CallSiteTypeId = *(args++); CallSiteTypeId.setName("CallSiteTypeId"); - Argument &Addr = *(args++); + Value &Addr = *(args++); Addr.setName("Addr"); + Value &CFICheckFailData = *(args++); + CFICheckFailData.setName("CFICheckFailData"); assert(args == F->arg_end()); BasicBlock *BB = BasicBlock::Create(Ctx, "entry", F); + BasicBlock *ExitBB = BasicBlock::Create(Ctx, "exit", F); - BasicBlock *TrapBB = BasicBlock::Create(Ctx, "trap", F); - IRBuilder<> IRBTrap(TrapBB); - Function *TrapFn = Intrinsic::getDeclaration(M, Intrinsic::trap); - llvm::CallInst *TrapCall = IRBTrap.CreateCall(TrapFn); - TrapCall->setDoesNotReturn(); - TrapCall->setDoesNotThrow(); - IRBTrap.CreateUnreachable(); + BasicBlock *TrapBB = BasicBlock::Create(Ctx, "fail", F); + IRBuilder<> IRBFail(TrapBB); + Constant *CFICheckFailFn = M.getOrInsertFunction( + "__cfi_check_fail", Type::getVoidTy(Ctx), Type::getInt8PtrTy(Ctx), + Type::getInt8PtrTy(Ctx), nullptr); + IRBFail.CreateCall(CFICheckFailFn, {&CFICheckFailData, &Addr}); + IRBFail.CreateBr(ExitBB); - BasicBlock *ExitBB = BasicBlock::Create(Ctx, "exit", F); IRBuilder<> IRBExit(ExitBB); IRBExit.CreateRetVoid(); IRBuilder<> IRB(BB); - SwitchInst *SI = IRB.CreateSwitch(&CallSiteTypeId, TrapBB, BitSetIds.size()); - for (uint64_t TypeId : BitSetIds) { + SwitchInst *SI = IRB.CreateSwitch(&CallSiteTypeId, TrapBB, TypeIds.size()); + for (uint64_t TypeId : TypeIds) { ConstantInt *CaseTypeId = ConstantInt::get(Type::getInt64Ty(Ctx), TypeId); BasicBlock *TestBB = BasicBlock::Create(Ctx, "test", F); IRBuilder<> IRBTest(TestBB); - Function *BitsetTestFn = - Intrinsic::getDeclaration(M, Intrinsic::bitset_test); + Function *BitsetTestFn = Intrinsic::getDeclaration(&M, Intrinsic::type_test); Value *Test = IRBTest.CreateCall( BitsetTestFn, {&Addr, MetadataAsValue::get( @@ -154,13 +139,26 @@ void CrossDSOCFI::buildCFICheck() { BI->setMetadata(LLVMContext::MD_prof, VeryLikelyWeights); SI->addCase(CaseTypeId, TestBB); - ++TypeIds; + ++NumTypeIds; } } bool CrossDSOCFI::runOnModule(Module &M) { + if (skipModule(M)) + return false; + + VeryLikelyWeights = + MDBuilder(M.getContext()).createBranchWeights((1U << 20) - 1, 1); if (M.getModuleFlag("Cross-DSO CFI") == nullptr) return false; - buildCFICheck(); + buildCFICheck(M); return true; } + +PreservedAnalyses CrossDSOCFIPass::run(Module &M, AnalysisManager<Module> &AM) { + CrossDSOCFI Impl; + bool Changed = Impl.runOnModule(M); + if (!Changed) + return PreservedAnalyses::all(); + return PreservedAnalyses::none(); +} |