summaryrefslogtreecommitdiff
path: root/lib/Transforms/Scalar/DeadStoreElimination.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'lib/Transforms/Scalar/DeadStoreElimination.cpp')
-rw-r--r--lib/Transforms/Scalar/DeadStoreElimination.cpp159
1 files changed, 138 insertions, 21 deletions
diff --git a/lib/Transforms/Scalar/DeadStoreElimination.cpp b/lib/Transforms/Scalar/DeadStoreElimination.cpp
index 1ec38e56aa4c..e703014bb0e6 100644
--- a/lib/Transforms/Scalar/DeadStoreElimination.cpp
+++ b/lib/Transforms/Scalar/DeadStoreElimination.cpp
@@ -16,31 +16,55 @@
//===----------------------------------------------------------------------===//
#include "llvm/Transforms/Scalar/DeadStoreElimination.h"
+#include "llvm/ADT/APInt.h"
#include "llvm/ADT/DenseMap.h"
-#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
+#include "llvm/ADT/StringRef.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/CaptureTracking.h"
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/Analysis/MemoryBuiltins.h"
#include "llvm/Analysis/MemoryDependenceAnalysis.h"
+#include "llvm/Analysis/MemoryLocation.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/ValueTracking.h"
+#include "llvm/IR/Argument.h"
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/CallSite.h"
+#include "llvm/IR/Constant.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
-#include "llvm/IR/GlobalVariable.h"
+#include "llvm/IR/InstrTypes.h"
+#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/PassManager.h"
+#include "llvm/IR/Value.h"
#include "llvm/Pass.h"
+#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/Local.h"
+#include <algorithm>
+#include <cassert>
+#include <cstdint>
+#include <cstddef>
+#include <iterator>
#include <map>
+#include <utility>
+
using namespace llvm;
#define DEBUG_TYPE "dse"
@@ -49,18 +73,23 @@ STATISTIC(NumRedundantStores, "Number of redundant stores deleted");
STATISTIC(NumFastStores, "Number of stores deleted");
STATISTIC(NumFastOther , "Number of other instrs removed");
STATISTIC(NumCompletePartials, "Number of stores dead by later partials");
+STATISTIC(NumModifiedStores, "Number of stores modified");
static cl::opt<bool>
EnablePartialOverwriteTracking("enable-dse-partial-overwrite-tracking",
cl::init(true), cl::Hidden,
cl::desc("Enable partial-overwrite tracking in DSE"));
+static cl::opt<bool>
+EnablePartialStoreMerging("enable-dse-partial-store-merging",
+ cl::init(true), cl::Hidden,
+ cl::desc("Enable partial store merging in DSE"));
//===----------------------------------------------------------------------===//
// Helper functions
//===----------------------------------------------------------------------===//
-typedef std::map<int64_t, int64_t> OverlapIntervalsTy;
-typedef DenseMap<Instruction *, OverlapIntervalsTy> InstOverlapIntervalsTy;
+using OverlapIntervalsTy = std::map<int64_t, int64_t>;
+using InstOverlapIntervalsTy = DenseMap<Instruction *, OverlapIntervalsTy>;
/// Delete this instruction. Before we do, go through and zero out all the
/// operands of this instruction. If any of them become dead, delete them and
@@ -209,7 +238,6 @@ static bool isRemovable(Instruction *I) {
case Intrinsic::init_trampoline:
// Always safe to remove init_trampoline.
return true;
-
case Intrinsic::memset:
case Intrinsic::memmove:
case Intrinsic::memcpy:
@@ -224,7 +252,6 @@ static bool isRemovable(Instruction *I) {
return false;
}
-
/// Returns true if the end of this instruction can be safely shortened in
/// length.
static bool isShortenableAtTheEnd(Instruction *I) {
@@ -287,14 +314,24 @@ static uint64_t getPointerSize(const Value *V, const DataLayout &DL,
}
namespace {
-enum OverwriteResult { OW_Begin, OW_Complete, OW_End, OW_Unknown };
-}
+
+enum OverwriteResult {
+ OW_Begin,
+ OW_Complete,
+ OW_End,
+ OW_PartialEarlierWithFullLater,
+ OW_Unknown
+};
+
+} // end anonymous namespace
/// Return 'OW_Complete' if a store to the 'Later' location completely
/// overwrites a store to the 'Earlier' location, 'OW_End' if the end of the
/// 'Earlier' location is completely overwritten by 'Later', 'OW_Begin' if the
-/// beginning of the 'Earlier' location is overwritten by 'Later', or
-/// 'OW_Unknown' if nothing can be determined.
+/// beginning of the 'Earlier' location is overwritten by 'Later'.
+/// 'OW_PartialEarlierWithFullLater' means that an earlier (big) store was
+/// overwritten by a latter (smaller) store which doesn't write outside the big
+/// store's memory locations. Returns 'OW_Unknown' if nothing can be determined.
static OverwriteResult isOverwrite(const MemoryLocation &Later,
const MemoryLocation &Earlier,
const DataLayout &DL,
@@ -427,6 +464,19 @@ static OverwriteResult isOverwrite(const MemoryLocation &Later,
}
}
+ // Check for an earlier store which writes to all the memory locations that
+ // the later store writes to.
+ if (EnablePartialStoreMerging && LaterOff >= EarlierOff &&
+ int64_t(EarlierOff + Earlier.Size) > LaterOff &&
+ uint64_t(LaterOff - EarlierOff) + Later.Size <= Earlier.Size) {
+ DEBUG(dbgs() << "DSE: Partial overwrite an earlier load [" << EarlierOff
+ << ", " << int64_t(EarlierOff + Earlier.Size)
+ << ") by a later store [" << LaterOff << ", "
+ << int64_t(LaterOff + Later.Size) << ")\n");
+ // TODO: Maybe come up with a better name?
+ return OW_PartialEarlierWithFullLater;
+ }
+
// Another interesting case is if the later store overwrites the end of the
// earlier store.
//
@@ -544,11 +594,9 @@ static bool memoryIsNotModifiedBetween(Instruction *FirstI,
}
for (; BI != EI; ++BI) {
Instruction *I = &*BI;
- if (I->mayWriteToMemory() && I != SecondI) {
- auto Res = AA->getModRefInfo(I, MemLoc);
- if (Res & MRI_Mod)
+ if (I->mayWriteToMemory() && I != SecondI)
+ if (isModSet(AA->getModRefInfo(I, MemLoc)))
return false;
- }
}
if (B != FirstBB) {
assert(B != &FirstBB->getParent()->getEntryBlock() &&
@@ -772,9 +820,7 @@ static bool handleEndBlock(BasicBlock &BB, AliasAnalysis *AA,
// the call is live.
DeadStackObjects.remove_if([&](Value *I) {
// See if the call site touches the value.
- ModRefInfo A = AA->getModRefInfo(CS, I, getPointerSize(I, DL, *TLI));
-
- return A == MRI_ModRef || A == MRI_Ref;
+ return isRefSet(AA->getModRefInfo(CS, I, getPointerSize(I, DL, *TLI)));
});
// If all of the allocas were clobbered by the call then we're not going
@@ -840,7 +886,7 @@ static bool tryToShorten(Instruction *EarlierWrite, int64_t &EarlierOffset,
if (!IsOverwriteEnd)
LaterOffset = int64_t(LaterOffset + LaterSize);
- if (!(llvm::isPowerOf2_64(LaterOffset) && EarlierWriteAlign <= LaterOffset) &&
+ if (!(isPowerOf2_64(LaterOffset) && EarlierWriteAlign <= LaterOffset) &&
!((EarlierWriteAlign != 0) && LaterOffset % EarlierWriteAlign == 0))
return false;
@@ -1094,6 +1140,8 @@ static bool eliminateDeadStores(BasicBlock &BB, AliasAnalysis *AA,
// If we find a write that is a) removable (i.e., non-volatile), b) is
// completely obliterated by the store to 'Loc', and c) which we know that
// 'Inst' doesn't load from, then we can remove it.
+ // Also try to merge two stores if a later one only touches memory written
+ // to by the earlier one.
if (isRemovable(DepWrite) &&
!isPossibleSelfRead(Inst, Loc, DepWrite, *TLI, *AA)) {
int64_t InstWriteOffset, DepWriteOffset;
@@ -1123,6 +1171,72 @@ static bool eliminateDeadStores(BasicBlock &BB, AliasAnalysis *AA,
bool IsOverwriteEnd = (OR == OW_End);
MadeChange |= tryToShorten(DepWrite, DepWriteOffset, EarlierSize,
InstWriteOffset, LaterSize, IsOverwriteEnd);
+ } else if (EnablePartialStoreMerging &&
+ OR == OW_PartialEarlierWithFullLater) {
+ auto *Earlier = dyn_cast<StoreInst>(DepWrite);
+ auto *Later = dyn_cast<StoreInst>(Inst);
+ if (Earlier && isa<ConstantInt>(Earlier->getValueOperand()) &&
+ Later && isa<ConstantInt>(Later->getValueOperand())) {
+ // If the store we find is:
+ // a) partially overwritten by the store to 'Loc'
+ // b) the later store is fully contained in the earlier one and
+ // c) they both have a constant value
+ // Merge the two stores, replacing the earlier store's value with a
+ // merge of both values.
+ // TODO: Deal with other constant types (vectors, etc), and probably
+ // some mem intrinsics (if needed)
+
+ APInt EarlierValue =
+ cast<ConstantInt>(Earlier->getValueOperand())->getValue();
+ APInt LaterValue =
+ cast<ConstantInt>(Later->getValueOperand())->getValue();
+ unsigned LaterBits = LaterValue.getBitWidth();
+ assert(EarlierValue.getBitWidth() > LaterValue.getBitWidth());
+ LaterValue = LaterValue.zext(EarlierValue.getBitWidth());
+
+ // Offset of the smaller store inside the larger store
+ unsigned BitOffsetDiff = (InstWriteOffset - DepWriteOffset) * 8;
+ unsigned LShiftAmount =
+ DL.isBigEndian()
+ ? EarlierValue.getBitWidth() - BitOffsetDiff - LaterBits
+ : BitOffsetDiff;
+ APInt Mask =
+ APInt::getBitsSet(EarlierValue.getBitWidth(), LShiftAmount,
+ LShiftAmount + LaterBits);
+ // Clear the bits we'll be replacing, then OR with the smaller
+ // store, shifted appropriately.
+ APInt Merged =
+ (EarlierValue & ~Mask) | (LaterValue << LShiftAmount);
+ DEBUG(dbgs() << "DSE: Merge Stores:\n Earlier: " << *DepWrite
+ << "\n Later: " << *Inst
+ << "\n Merged Value: " << Merged << '\n');
+
+ auto *SI = new StoreInst(
+ ConstantInt::get(Earlier->getValueOperand()->getType(), Merged),
+ Earlier->getPointerOperand(), false, Earlier->getAlignment(),
+ Earlier->getOrdering(), Earlier->getSyncScopeID(), DepWrite);
+
+ unsigned MDToKeep[] = {LLVMContext::MD_dbg, LLVMContext::MD_tbaa,
+ LLVMContext::MD_alias_scope,
+ LLVMContext::MD_noalias,
+ LLVMContext::MD_nontemporal};
+ SI->copyMetadata(*DepWrite, MDToKeep);
+ ++NumModifiedStores;
+
+ // Remove earlier, wider, store
+ size_t Idx = InstrOrdering.lookup(DepWrite);
+ InstrOrdering.erase(DepWrite);
+ InstrOrdering.insert(std::make_pair(SI, Idx));
+
+ // Delete the old stores and now-dead instructions that feed them.
+ deleteDeadInstruction(Inst, &BBI, *MD, *TLI, IOL, &InstrOrdering);
+ deleteDeadInstruction(DepWrite, &BBI, *MD, *TLI, IOL,
+ &InstrOrdering);
+ MadeChange = true;
+
+ // We erased DepWrite and Inst (Loc); start over.
+ break;
+ }
}
}
@@ -1137,7 +1251,7 @@ static bool eliminateDeadStores(BasicBlock &BB, AliasAnalysis *AA,
if (DepWrite == &BB.front()) break;
// Can't look past this instruction if it might read 'Loc'.
- if (AA->getModRefInfo(DepWrite, Loc) & MRI_Ref)
+ if (isRefSet(AA->getModRefInfo(DepWrite, Loc)))
break;
InstDep = MD->getPointerDependencyFrom(Loc, /*isLoad=*/ false,
@@ -1190,9 +1304,12 @@ PreservedAnalyses DSEPass::run(Function &F, FunctionAnalysisManager &AM) {
}
namespace {
+
/// A legacy pass for the legacy pass manager that wraps \c DSEPass.
class DSELegacyPass : public FunctionPass {
public:
+ static char ID; // Pass identification, replacement for typeid
+
DSELegacyPass() : FunctionPass(ID) {
initializeDSELegacyPassPass(*PassRegistry::getPassRegistry());
}
@@ -1221,12 +1338,12 @@ public:
AU.addPreserved<GlobalsAAWrapperPass>();
AU.addPreserved<MemoryDependenceWrapperPass>();
}
-
- static char ID; // Pass identification, replacement for typeid
};
+
} // end anonymous namespace
char DSELegacyPass::ID = 0;
+
INITIALIZE_PASS_BEGIN(DSELegacyPass, "dse", "Dead Store Elimination", false,
false)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)