diff options
Diffstat (limited to 'lib/Transforms/Scalar/DeadStoreElimination.cpp')
| -rw-r--r-- | lib/Transforms/Scalar/DeadStoreElimination.cpp | 159 | 
1 file changed, 138 insertions, 21 deletions
| diff --git a/lib/Transforms/Scalar/DeadStoreElimination.cpp b/lib/Transforms/Scalar/DeadStoreElimination.cpp index 1ec38e56aa4c..e703014bb0e6 100644 --- a/lib/Transforms/Scalar/DeadStoreElimination.cpp +++ b/lib/Transforms/Scalar/DeadStoreElimination.cpp @@ -16,31 +16,55 @@  //===----------------------------------------------------------------------===//  #include "llvm/Transforms/Scalar/DeadStoreElimination.h" +#include "llvm/ADT/APInt.h"  #include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/STLExtras.h"  #include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallVector.h"  #include "llvm/ADT/Statistic.h" +#include "llvm/ADT/StringRef.h"  #include "llvm/Analysis/AliasAnalysis.h"  #include "llvm/Analysis/CaptureTracking.h"  #include "llvm/Analysis/GlobalsModRef.h"  #include "llvm/Analysis/MemoryBuiltins.h"  #include "llvm/Analysis/MemoryDependenceAnalysis.h" +#include "llvm/Analysis/MemoryLocation.h"  #include "llvm/Analysis/TargetLibraryInfo.h"  #include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/Argument.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/CallSite.h" +#include "llvm/IR/Constant.h"  #include "llvm/IR/Constants.h"  #include "llvm/IR/DataLayout.h"  #include "llvm/IR/Dominators.h"  #include "llvm/IR/Function.h" -#include "llvm/IR/GlobalVariable.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h"  #include "llvm/IR/Instructions.h"  #include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Intrinsics.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/PassManager.h" +#include "llvm/IR/Value.h"  #include "llvm/Pass.h" +#include "llvm/Support/Casting.h"  #include "llvm/Support/CommandLine.h"  #include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/MathExtras.h"  #include "llvm/Support/raw_ostream.h"  #include "llvm/Transforms/Scalar.h"  #include "llvm/Transforms/Utils/Local.h" +#include <algorithm> +#include <cassert> 
+#include <cstdint> +#include <cstddef> +#include <iterator>  #include <map> +#include <utility> +  using namespace llvm;  #define DEBUG_TYPE "dse" @@ -49,18 +73,23 @@ STATISTIC(NumRedundantStores, "Number of redundant stores deleted");  STATISTIC(NumFastStores, "Number of stores deleted");  STATISTIC(NumFastOther , "Number of other instrs removed");  STATISTIC(NumCompletePartials, "Number of stores dead by later partials"); +STATISTIC(NumModifiedStores, "Number of stores modified");  static cl::opt<bool>  EnablePartialOverwriteTracking("enable-dse-partial-overwrite-tracking",    cl::init(true), cl::Hidden,    cl::desc("Enable partial-overwrite tracking in DSE")); +static cl::opt<bool> +EnablePartialStoreMerging("enable-dse-partial-store-merging", +  cl::init(true), cl::Hidden, +  cl::desc("Enable partial store merging in DSE"));  //===----------------------------------------------------------------------===//  // Helper functions  //===----------------------------------------------------------------------===// -typedef std::map<int64_t, int64_t> OverlapIntervalsTy; -typedef DenseMap<Instruction *, OverlapIntervalsTy> InstOverlapIntervalsTy; +using OverlapIntervalsTy = std::map<int64_t, int64_t>; +using InstOverlapIntervalsTy = DenseMap<Instruction *, OverlapIntervalsTy>;  /// Delete this instruction.  Before we do, go through and zero out all the  /// operands of this instruction.  If any of them become dead, delete them and @@ -209,7 +238,6 @@ static bool isRemovable(Instruction *I) {      case Intrinsic::init_trampoline:        // Always safe to remove init_trampoline.        return true; -      case Intrinsic::memset:      case Intrinsic::memmove:      case Intrinsic::memcpy: @@ -224,7 +252,6 @@ static bool isRemovable(Instruction *I) {    return false;  } -  /// Returns true if the end of this instruction can be safely shortened in  /// length.  
static bool isShortenableAtTheEnd(Instruction *I) { @@ -287,14 +314,24 @@ static uint64_t getPointerSize(const Value *V, const DataLayout &DL,  }  namespace { -enum OverwriteResult { OW_Begin, OW_Complete, OW_End, OW_Unknown }; -} + +enum OverwriteResult { +  OW_Begin, +  OW_Complete, +  OW_End, +  OW_PartialEarlierWithFullLater, +  OW_Unknown +}; + +} // end anonymous namespace  /// Return 'OW_Complete' if a store to the 'Later' location completely  /// overwrites a store to the 'Earlier' location, 'OW_End' if the end of the  /// 'Earlier' location is completely overwritten by 'Later', 'OW_Begin' if the -/// beginning of the 'Earlier' location is overwritten by 'Later', or -/// 'OW_Unknown' if nothing can be determined. +/// beginning of the 'Earlier' location is overwritten by 'Later'. +/// 'OW_PartialEarlierWithFullLater' means that an earlier (big) store was +/// overwritten by a latter (smaller) store which doesn't write outside the big +/// store's memory locations. Returns 'OW_Unknown' if nothing can be determined.  static OverwriteResult isOverwrite(const MemoryLocation &Later,                                     const MemoryLocation &Earlier,                                     const DataLayout &DL, @@ -427,6 +464,19 @@ static OverwriteResult isOverwrite(const MemoryLocation &Later,      }    } +  // Check for an earlier store which writes to all the memory locations that +  // the later store writes to. +  if (EnablePartialStoreMerging && LaterOff >= EarlierOff && +      int64_t(EarlierOff + Earlier.Size) > LaterOff && +      uint64_t(LaterOff - EarlierOff) + Later.Size <= Earlier.Size) { +    DEBUG(dbgs() << "DSE: Partial overwrite an earlier load [" << EarlierOff +                 << ", " << int64_t(EarlierOff + Earlier.Size) +                 << ") by a later store [" << LaterOff << ", " +                 << int64_t(LaterOff + Later.Size) << ")\n"); +    // TODO: Maybe come up with a better name? 
+    return OW_PartialEarlierWithFullLater; +  } +    // Another interesting case is if the later store overwrites the end of the    // earlier store.    // @@ -544,11 +594,9 @@ static bool memoryIsNotModifiedBetween(Instruction *FirstI,      }      for (; BI != EI; ++BI) {        Instruction *I = &*BI; -      if (I->mayWriteToMemory() && I != SecondI) { -        auto Res = AA->getModRefInfo(I, MemLoc); -        if (Res & MRI_Mod) +      if (I->mayWriteToMemory() && I != SecondI) +        if (isModSet(AA->getModRefInfo(I, MemLoc)))            return false; -      }      }      if (B != FirstBB) {        assert(B != &FirstBB->getParent()->getEntryBlock() && @@ -772,9 +820,7 @@ static bool handleEndBlock(BasicBlock &BB, AliasAnalysis *AA,        // the call is live.        DeadStackObjects.remove_if([&](Value *I) {          // See if the call site touches the value. -        ModRefInfo A = AA->getModRefInfo(CS, I, getPointerSize(I, DL, *TLI)); - -        return A == MRI_ModRef || A == MRI_Ref; +        return isRefSet(AA->getModRefInfo(CS, I, getPointerSize(I, DL, *TLI)));        });        // If all of the allocas were clobbered by the call then we're not going @@ -840,7 +886,7 @@ static bool tryToShorten(Instruction *EarlierWrite, int64_t &EarlierOffset,    if (!IsOverwriteEnd)      LaterOffset = int64_t(LaterOffset + LaterSize); -  if (!(llvm::isPowerOf2_64(LaterOffset) && EarlierWriteAlign <= LaterOffset) && +  if (!(isPowerOf2_64(LaterOffset) && EarlierWriteAlign <= LaterOffset) &&        !((EarlierWriteAlign != 0) && LaterOffset % EarlierWriteAlign == 0))      return false; @@ -1094,6 +1140,8 @@ static bool eliminateDeadStores(BasicBlock &BB, AliasAnalysis *AA,        // If we find a write that is a) removable (i.e., non-volatile), b) is        // completely obliterated by the store to 'Loc', and c) which we know that        // 'Inst' doesn't load from, then we can remove it. 
+      // Also try to merge two stores if a later one only touches memory written +      // to by the earlier one.        if (isRemovable(DepWrite) &&            !isPossibleSelfRead(Inst, Loc, DepWrite, *TLI, *AA)) {          int64_t InstWriteOffset, DepWriteOffset; @@ -1123,6 +1171,72 @@ static bool eliminateDeadStores(BasicBlock &BB, AliasAnalysis *AA,            bool IsOverwriteEnd = (OR == OW_End);            MadeChange |= tryToShorten(DepWrite, DepWriteOffset, EarlierSize,                                      InstWriteOffset, LaterSize, IsOverwriteEnd); +        } else if (EnablePartialStoreMerging && +                   OR == OW_PartialEarlierWithFullLater) { +          auto *Earlier = dyn_cast<StoreInst>(DepWrite); +          auto *Later = dyn_cast<StoreInst>(Inst); +          if (Earlier && isa<ConstantInt>(Earlier->getValueOperand()) && +              Later && isa<ConstantInt>(Later->getValueOperand())) { +            // If the store we find is: +            //   a) partially overwritten by the store to 'Loc' +            //   b) the later store is fully contained in the earlier one and +            //   c) they both have a constant value +            // Merge the two stores, replacing the earlier store's value with a +            // merge of both values. 
+            // TODO: Deal with other constant types (vectors, etc), and probably +            // some mem intrinsics (if needed) + +            APInt EarlierValue = +                cast<ConstantInt>(Earlier->getValueOperand())->getValue(); +            APInt LaterValue = +                cast<ConstantInt>(Later->getValueOperand())->getValue(); +            unsigned LaterBits = LaterValue.getBitWidth(); +            assert(EarlierValue.getBitWidth() > LaterValue.getBitWidth()); +            LaterValue = LaterValue.zext(EarlierValue.getBitWidth()); + +            // Offset of the smaller store inside the larger store +            unsigned BitOffsetDiff = (InstWriteOffset - DepWriteOffset) * 8; +            unsigned LShiftAmount = +                DL.isBigEndian() +                    ? EarlierValue.getBitWidth() - BitOffsetDiff - LaterBits +                    : BitOffsetDiff; +            APInt Mask = +                APInt::getBitsSet(EarlierValue.getBitWidth(), LShiftAmount, +                                  LShiftAmount + LaterBits); +            // Clear the bits we'll be replacing, then OR with the smaller +            // store, shifted appropriately. 
+            APInt Merged = +                (EarlierValue & ~Mask) | (LaterValue << LShiftAmount); +            DEBUG(dbgs() << "DSE: Merge Stores:\n  Earlier: " << *DepWrite +                         << "\n  Later: " << *Inst +                         << "\n  Merged Value: " << Merged << '\n'); + +            auto *SI = new StoreInst( +                ConstantInt::get(Earlier->getValueOperand()->getType(), Merged), +                Earlier->getPointerOperand(), false, Earlier->getAlignment(), +                Earlier->getOrdering(), Earlier->getSyncScopeID(), DepWrite); + +            unsigned MDToKeep[] = {LLVMContext::MD_dbg, LLVMContext::MD_tbaa, +                                   LLVMContext::MD_alias_scope, +                                   LLVMContext::MD_noalias, +                                   LLVMContext::MD_nontemporal}; +            SI->copyMetadata(*DepWrite, MDToKeep); +            ++NumModifiedStores; + +            // Remove earlier, wider, store +            size_t Idx = InstrOrdering.lookup(DepWrite); +            InstrOrdering.erase(DepWrite); +            InstrOrdering.insert(std::make_pair(SI, Idx)); + +            // Delete the old stores and now-dead instructions that feed them. +            deleteDeadInstruction(Inst, &BBI, *MD, *TLI, IOL, &InstrOrdering); +            deleteDeadInstruction(DepWrite, &BBI, *MD, *TLI, IOL, +                                  &InstrOrdering); +            MadeChange = true; + +            // We erased DepWrite and Inst (Loc); start over. +            break; +          }          }        } @@ -1137,7 +1251,7 @@ static bool eliminateDeadStores(BasicBlock &BB, AliasAnalysis *AA,        if (DepWrite == &BB.front()) break;        // Can't look past this instruction if it might read 'Loc'. 
-      if (AA->getModRefInfo(DepWrite, Loc) & MRI_Ref) +      if (isRefSet(AA->getModRefInfo(DepWrite, Loc)))          break;        InstDep = MD->getPointerDependencyFrom(Loc, /*isLoad=*/ false, @@ -1190,9 +1304,12 @@ PreservedAnalyses DSEPass::run(Function &F, FunctionAnalysisManager &AM) {  }  namespace { +  /// A legacy pass for the legacy pass manager that wraps \c DSEPass.  class DSELegacyPass : public FunctionPass {  public: +  static char ID; // Pass identification, replacement for typeid +    DSELegacyPass() : FunctionPass(ID) {      initializeDSELegacyPassPass(*PassRegistry::getPassRegistry());    } @@ -1221,12 +1338,12 @@ public:      AU.addPreserved<GlobalsAAWrapperPass>();      AU.addPreserved<MemoryDependenceWrapperPass>();    } - -  static char ID; // Pass identification, replacement for typeid  }; +  } // end anonymous namespace  char DSELegacyPass::ID = 0; +  INITIALIZE_PASS_BEGIN(DSELegacyPass, "dse", "Dead Store Elimination", false,                        false)  INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) | 
