author    Dimitry Andric <dim@FreeBSD.org>  2020-07-26 19:36:28 +0000
committer Dimitry Andric <dim@FreeBSD.org>  2020-07-26 19:36:28 +0000
commit    cfca06d7963fa0909f90483b42a6d7d194d01e08 (patch)
tree      209fb2a2d68f8f277793fc8df46c753d31bc853b /llvm/lib/Transforms/Scalar
parent    706b4fc47bbc608932d3b491ae19a3b9cde9497b (diff)
Diffstat (limited to 'llvm/lib/Transforms/Scalar')
-rw-r--r--  llvm/lib/Transforms/Scalar/ADCE.cpp | 22
-rw-r--r--  llvm/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp | 247
-rw-r--r--  llvm/lib/Transforms/Scalar/BDCE.cpp | 26
-rw-r--r--  llvm/lib/Transforms/Scalar/CallSiteSplitting.cpp | 156
-rw-r--r--  llvm/lib/Transforms/Scalar/ConstantHoisting.cpp | 20
-rw-r--r--  llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp | 47
-rw-r--r--  llvm/lib/Transforms/Scalar/DCE.cpp | 2
-rw-r--r--  llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp | 1230
-rw-r--r--  llvm/lib/Transforms/Scalar/DivRemPairs.cpp | 27
-rw-r--r--  llvm/lib/Transforms/Scalar/EarlyCSE.cpp | 244
-rw-r--r--  llvm/lib/Transforms/Scalar/Float2Int.cpp | 11
-rw-r--r--  llvm/lib/Transforms/Scalar/GVN.cpp | 101
-rw-r--r--  llvm/lib/Transforms/Scalar/GVNHoist.cpp | 14
-rw-r--r--  llvm/lib/Transforms/Scalar/GVNSink.cpp | 15
-rw-r--r--  llvm/lib/Transforms/Scalar/IndVarSimplify.cpp | 595
-rw-r--r--  llvm/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp | 120
-rw-r--r--  llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp | 206
-rw-r--r--  llvm/lib/Transforms/Scalar/InstSimplifyPass.cpp | 2
-rw-r--r--  llvm/lib/Transforms/Scalar/JumpThreading.cpp | 451
-rw-r--r--  llvm/lib/Transforms/Scalar/LICM.cpp | 273
-rw-r--r--  llvm/lib/Transforms/Scalar/LoopDataPrefetch.cpp | 209
-rw-r--r--  llvm/lib/Transforms/Scalar/LoopDeletion.cpp | 44
-rw-r--r--  llvm/lib/Transforms/Scalar/LoopDistribute.cpp | 21
-rw-r--r--  llvm/lib/Transforms/Scalar/LoopFuse.cpp | 109
-rw-r--r--  llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp | 139
-rw-r--r--  llvm/lib/Transforms/Scalar/LoopInstSimplify.cpp | 2
-rw-r--r--  llvm/lib/Transforms/Scalar/LoopInterchange.cpp | 65
-rw-r--r--  llvm/lib/Transforms/Scalar/LoopLoadElimination.cpp | 29
-rw-r--r--  llvm/lib/Transforms/Scalar/LoopPassManager.cpp | 13
-rw-r--r--  llvm/lib/Transforms/Scalar/LoopPredication.cpp | 57
-rw-r--r--  llvm/lib/Transforms/Scalar/LoopRerollPass.cpp | 10
-rw-r--r--  llvm/lib/Transforms/Scalar/LoopRotation.cpp | 13
-rw-r--r--  llvm/lib/Transforms/Scalar/LoopSimplifyCFG.cpp | 10
-rw-r--r--  llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp | 206
-rw-r--r--  llvm/lib/Transforms/Scalar/LoopUnrollAndJamPass.cpp | 96
-rw-r--r--  llvm/lib/Transforms/Scalar/LoopUnrollPass.cpp | 181
-rw-r--r--  llvm/lib/Transforms/Scalar/LoopUnswitch.cpp | 328
-rw-r--r--  llvm/lib/Transforms/Scalar/LoopVersioningLICM.cpp | 1
-rw-r--r--  llvm/lib/Transforms/Scalar/LowerAtomic.cpp | 13
-rw-r--r--  llvm/lib/Transforms/Scalar/LowerConstantIntrinsics.cpp | 16
-rw-r--r--  llvm/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp | 77
-rw-r--r--  llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp | 1531
-rw-r--r--  llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp | 136
-rw-r--r--  llvm/lib/Transforms/Scalar/MergedLoadStoreMotion.cpp | 10
-rw-r--r--  llvm/lib/Transforms/Scalar/NaryReassociate.cpp | 2
-rw-r--r--  llvm/lib/Transforms/Scalar/NewGVN.cpp | 18
-rw-r--r--  llvm/lib/Transforms/Scalar/PlaceSafepoints.cpp | 5
-rw-r--r--  llvm/lib/Transforms/Scalar/Reassociate.cpp | 40
-rw-r--r--  llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp | 193
-rw-r--r--  llvm/lib/Transforms/Scalar/SCCP.cpp | 1367
-rw-r--r--  llvm/lib/Transforms/Scalar/SROA.cpp | 562
-rw-r--r--  llvm/lib/Transforms/Scalar/Scalarizer.cpp | 250
-rw-r--r--  llvm/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp | 62
-rw-r--r--  llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp | 58
-rw-r--r--  llvm/lib/Transforms/Scalar/SimplifyCFGPass.cpp | 23
-rw-r--r--  llvm/lib/Transforms/Scalar/Sink.cpp | 4
-rw-r--r--  llvm/lib/Transforms/Scalar/SpeculateAroundPHIs.cpp | 18
-rw-r--r--  llvm/lib/Transforms/Scalar/SpeculativeExecution.cpp | 40
-rw-r--r--  llvm/lib/Transforms/Scalar/StructurizeCFG.cpp | 246
-rw-r--r--  llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp | 585
-rw-r--r--  llvm/lib/Transforms/Scalar/WarnMissedTransforms.cpp | 1
61 files changed, 6579 insertions, 4020 deletions
diff --git a/llvm/lib/Transforms/Scalar/ADCE.cpp b/llvm/lib/Transforms/Scalar/ADCE.cpp
index cc3d3bf7cdbf..c3709b9afffb 100644
--- a/llvm/lib/Transforms/Scalar/ADCE.cpp
+++ b/llvm/lib/Transforms/Scalar/ADCE.cpp
@@ -182,7 +182,7 @@ class AggressiveDeadCodeElimination {
/// Identify connected sections of the control flow graph which have
/// dead terminators and rewrite the control flow graph to remove them.
- void updateDeadRegions();
+ bool updateDeadRegions();
/// Set the BlockInfo::PostOrder field based on a post-order
/// numbering of the reverse control flow graph.
@@ -505,7 +505,7 @@ void AggressiveDeadCodeElimination::markLiveBranchesFromControlDependences() {
//===----------------------------------------------------------------------===//
bool AggressiveDeadCodeElimination::removeDeadInstructions() {
// Updates control and dataflow around dead blocks
- updateDeadRegions();
+ bool RegionsUpdated = updateDeadRegions();
LLVM_DEBUG({
for (Instruction &I : instructions(F)) {
@@ -556,11 +556,11 @@ bool AggressiveDeadCodeElimination::removeDeadInstructions() {
I->eraseFromParent();
}
- return !Worklist.empty();
+ return !Worklist.empty() || RegionsUpdated;
}
// A dead region is the set of dead blocks with a common live post-dominator.
-void AggressiveDeadCodeElimination::updateDeadRegions() {
+bool AggressiveDeadCodeElimination::updateDeadRegions() {
LLVM_DEBUG({
dbgs() << "final dead terminator blocks: " << '\n';
for (auto *BB : BlocksWithDeadTerminators)
@@ -570,6 +570,7 @@ void AggressiveDeadCodeElimination::updateDeadRegions() {
// Don't compute the post ordering unless we needed it.
bool HavePostOrder = false;
+ bool Changed = false;
for (auto *BB : BlocksWithDeadTerminators) {
auto &Info = BlockInfo[BB];
@@ -624,7 +625,10 @@ void AggressiveDeadCodeElimination::updateDeadRegions() {
.applyUpdates(DeletedEdges);
NumBranchesRemoved += 1;
+ Changed = true;
}
+
+ return Changed;
}
// reverse top-sort order
@@ -685,10 +689,14 @@ PreservedAnalyses ADCEPass::run(Function &F, FunctionAnalysisManager &FAM) {
return PreservedAnalyses::all();
PreservedAnalyses PA;
- PA.preserveSet<CFGAnalyses>();
+ // TODO: We could track if we have actually done CFG changes.
+ if (!RemoveControlFlowFlag)
+ PA.preserveSet<CFGAnalyses>();
+ else {
+ PA.preserve<DominatorTreeAnalysis>();
+ PA.preserve<PostDominatorTreeAnalysis>();
+ }
PA.preserve<GlobalsAA>();
- PA.preserve<DominatorTreeAnalysis>();
- PA.preserve<PostDominatorTreeAnalysis>();
return PA;
}
diff --git a/llvm/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp b/llvm/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp
index 06deaf3c4f9a..bccf94fc217f 100644
--- a/llvm/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp
+++ b/llvm/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp
@@ -15,6 +15,7 @@
//
//===----------------------------------------------------------------------===//
+#include "llvm/IR/Instructions.h"
#include "llvm/InitializePasses.h"
#define AA_NAME "alignment-from-assumptions"
#define DEBUG_TYPE AA_NAME
@@ -30,6 +31,7 @@
#include "llvm/IR/Constant.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Instruction.h"
+#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/Debug.h"
@@ -90,9 +92,9 @@ FunctionPass *llvm::createAlignmentFromAssumptionsPass() {
// to a constant. Using SCEV to compute alignment handles the case where
// DiffSCEV is a recurrence with constant start such that the aligned offset
// is constant. e.g. {16,+,32} % 32 -> 16.
-static unsigned getNewAlignmentDiff(const SCEV *DiffSCEV,
- const SCEV *AlignSCEV,
- ScalarEvolution *SE) {
+static MaybeAlign getNewAlignmentDiff(const SCEV *DiffSCEV,
+ const SCEV *AlignSCEV,
+ ScalarEvolution *SE) {
// DiffUnits = Diff % int64_t(Alignment)
const SCEV *DiffUnitsSCEV = SE->getURemExpr(DiffSCEV, AlignSCEV);
@@ -107,26 +109,30 @@ static unsigned getNewAlignmentDiff(const SCEV *DiffSCEV,
// displaced pointer has the same alignment as the aligned pointer, so
// return the alignment value.
if (!DiffUnits)
- return (unsigned)
- cast<SCEVConstant>(AlignSCEV)->getValue()->getSExtValue();
+ return cast<SCEVConstant>(AlignSCEV)->getValue()->getAlignValue();
// If the displacement is not an exact multiple, but the remainder is a
// constant, then return this remainder (but only if it is a power of 2).
uint64_t DiffUnitsAbs = std::abs(DiffUnits);
if (isPowerOf2_64(DiffUnitsAbs))
- return (unsigned) DiffUnitsAbs;
+ return Align(DiffUnitsAbs);
}
- return 0;
+ return None;
}
// There is an address given by an offset OffSCEV from AASCEV which has an
// alignment AlignSCEV. Use that information, if possible, to compute a new
// alignment for Ptr.
-static unsigned getNewAlignment(const SCEV *AASCEV, const SCEV *AlignSCEV,
- const SCEV *OffSCEV, Value *Ptr,
- ScalarEvolution *SE) {
+static Align getNewAlignment(const SCEV *AASCEV, const SCEV *AlignSCEV,
+ const SCEV *OffSCEV, Value *Ptr,
+ ScalarEvolution *SE) {
const SCEV *PtrSCEV = SE->getSCEV(Ptr);
+ // On a platform with 32-bit allocas, but 64-bit flat/global pointer sizes
+ // (*cough* AMDGPU), the effective SCEV type of AASCEV and PtrSCEV
+ // may disagree. Trunc/extend so they agree.
+ PtrSCEV = SE->getTruncateOrZeroExtend(
+ PtrSCEV, SE->getEffectiveSCEVType(AASCEV->getType()));
const SCEV *DiffSCEV = SE->getMinusSCEV(PtrSCEV, AASCEV);
// On 32-bit platforms, DiffSCEV might now have type i32 -- we've always
@@ -141,13 +147,12 @@ static unsigned getNewAlignment(const SCEV *AASCEV, const SCEV *AlignSCEV,
<< *AlignSCEV << " and offset " << *OffSCEV
<< " using diff " << *DiffSCEV << "\n");
- unsigned NewAlignment = getNewAlignmentDiff(DiffSCEV, AlignSCEV, SE);
- LLVM_DEBUG(dbgs() << "\tnew alignment: " << NewAlignment << "\n");
+ if (MaybeAlign NewAlignment = getNewAlignmentDiff(DiffSCEV, AlignSCEV, SE)) {
+ LLVM_DEBUG(dbgs() << "\tnew alignment: " << DebugStr(NewAlignment) << "\n");
+ return *NewAlignment;
+ }
- if (NewAlignment) {
- return NewAlignment;
- } else if (const SCEVAddRecExpr *DiffARSCEV =
- dyn_cast<SCEVAddRecExpr>(DiffSCEV)) {
+ if (const SCEVAddRecExpr *DiffARSCEV = dyn_cast<SCEVAddRecExpr>(DiffSCEV)) {
// The relative offset to the alignment assumption did not yield a constant,
// but we should try harder: if we assume that a is 32-byte aligned, then in
// for (i = 0; i < 1024; i += 4) r += a[i]; not all of the loads from a are
@@ -165,134 +170,67 @@ static unsigned getNewAlignment(const SCEV *AASCEV, const SCEV *AlignSCEV,
// first iteration, and also the alignment using the per-iteration delta.
// If these are the same, then use that answer. Otherwise, use the smaller
// one, but only if it divides the larger one.
- NewAlignment = getNewAlignmentDiff(DiffStartSCEV, AlignSCEV, SE);
- unsigned NewIncAlignment = getNewAlignmentDiff(DiffIncSCEV, AlignSCEV, SE);
-
- LLVM_DEBUG(dbgs() << "\tnew start alignment: " << NewAlignment << "\n");
- LLVM_DEBUG(dbgs() << "\tnew inc alignment: " << NewIncAlignment << "\n");
-
- if (!NewAlignment || !NewIncAlignment) {
- return 0;
- } else if (NewAlignment > NewIncAlignment) {
- if (NewAlignment % NewIncAlignment == 0) {
- LLVM_DEBUG(dbgs() << "\tnew start/inc alignment: " << NewIncAlignment
- << "\n");
- return NewIncAlignment;
- }
- } else if (NewIncAlignment > NewAlignment) {
- if (NewIncAlignment % NewAlignment == 0) {
- LLVM_DEBUG(dbgs() << "\tnew start/inc alignment: " << NewAlignment
- << "\n");
- return NewAlignment;
- }
- } else if (NewIncAlignment == NewAlignment) {
- LLVM_DEBUG(dbgs() << "\tnew start/inc alignment: " << NewAlignment
+ MaybeAlign NewAlignment = getNewAlignmentDiff(DiffStartSCEV, AlignSCEV, SE);
+ MaybeAlign NewIncAlignment =
+ getNewAlignmentDiff(DiffIncSCEV, AlignSCEV, SE);
+
+ LLVM_DEBUG(dbgs() << "\tnew start alignment: " << DebugStr(NewAlignment)
+ << "\n");
+ LLVM_DEBUG(dbgs() << "\tnew inc alignment: " << DebugStr(NewIncAlignment)
+ << "\n");
+
+ if (!NewAlignment || !NewIncAlignment)
+ return Align(1);
+
+ const Align NewAlign = *NewAlignment;
+ const Align NewIncAlign = *NewIncAlignment;
+ if (NewAlign > NewIncAlign) {
+ LLVM_DEBUG(dbgs() << "\tnew start/inc alignment: "
+ << DebugStr(NewIncAlign) << "\n");
+ return NewIncAlign;
+ }
+ if (NewIncAlign > NewAlign) {
+ LLVM_DEBUG(dbgs() << "\tnew start/inc alignment: " << DebugStr(NewAlign)
<< "\n");
- return NewAlignment;
+ return NewAlign;
}
+ assert(NewIncAlign == NewAlign);
+ LLVM_DEBUG(dbgs() << "\tnew start/inc alignment: " << DebugStr(NewAlign)
+ << "\n");
+ return NewAlign;
}
- return 0;
+ return Align(1);
}
bool AlignmentFromAssumptionsPass::extractAlignmentInfo(CallInst *I,
+ unsigned Idx,
Value *&AAPtr,
const SCEV *&AlignSCEV,
const SCEV *&OffSCEV) {
- // An alignment assume must be a statement about the least-significant
- // bits of the pointer being zero, possibly with some offset.
- ICmpInst *ICI = dyn_cast<ICmpInst>(I->getArgOperand(0));
- if (!ICI)
- return false;
-
- // This must be an expression of the form: x & m == 0.
- if (ICI->getPredicate() != ICmpInst::ICMP_EQ)
- return false;
-
- // Swap things around so that the RHS is 0.
- Value *CmpLHS = ICI->getOperand(0);
- Value *CmpRHS = ICI->getOperand(1);
- const SCEV *CmpLHSSCEV = SE->getSCEV(CmpLHS);
- const SCEV *CmpRHSSCEV = SE->getSCEV(CmpRHS);
- if (CmpLHSSCEV->isZero())
- std::swap(CmpLHS, CmpRHS);
- else if (!CmpRHSSCEV->isZero())
+ Type *Int64Ty = Type::getInt64Ty(I->getContext());
+ OperandBundleUse AlignOB = I->getOperandBundleAt(Idx);
+ if (AlignOB.getTagName() != "align")
return false;
-
- BinaryOperator *CmpBO = dyn_cast<BinaryOperator>(CmpLHS);
- if (!CmpBO || CmpBO->getOpcode() != Instruction::And)
- return false;
-
- // Swap things around so that the right operand of the and is a constant
- // (the mask); we cannot deal with variable masks.
- Value *AndLHS = CmpBO->getOperand(0);
- Value *AndRHS = CmpBO->getOperand(1);
- const SCEV *AndLHSSCEV = SE->getSCEV(AndLHS);
- const SCEV *AndRHSSCEV = SE->getSCEV(AndRHS);
- if (isa<SCEVConstant>(AndLHSSCEV)) {
- std::swap(AndLHS, AndRHS);
- std::swap(AndLHSSCEV, AndRHSSCEV);
- }
-
- const SCEVConstant *MaskSCEV = dyn_cast<SCEVConstant>(AndRHSSCEV);
- if (!MaskSCEV)
- return false;
-
- // The mask must have some trailing ones (otherwise the condition is
- // trivial and tells us nothing about the alignment of the left operand).
- unsigned TrailingOnes = MaskSCEV->getAPInt().countTrailingOnes();
- if (!TrailingOnes)
- return false;
-
- // Cap the alignment at the maximum with which LLVM can deal (and make sure
- // we don't overflow the shift).
- uint64_t Alignment;
- TrailingOnes = std::min(TrailingOnes,
- unsigned(sizeof(unsigned) * CHAR_BIT - 1));
- Alignment = std::min(1u << TrailingOnes, +Value::MaximumAlignment);
-
- Type *Int64Ty = Type::getInt64Ty(I->getParent()->getParent()->getContext());
- AlignSCEV = SE->getConstant(Int64Ty, Alignment);
-
- // The LHS might be a ptrtoint instruction, or it might be the pointer
- // with an offset.
- AAPtr = nullptr;
- OffSCEV = nullptr;
- if (PtrToIntInst *PToI = dyn_cast<PtrToIntInst>(AndLHS)) {
- AAPtr = PToI->getPointerOperand();
+ assert(AlignOB.Inputs.size() >= 2);
+ AAPtr = AlignOB.Inputs[0].get();
+ // TODO: Consider accumulating the offset to the base.
+ AAPtr = AAPtr->stripPointerCastsSameRepresentation();
+ AlignSCEV = SE->getSCEV(AlignOB.Inputs[1].get());
+ AlignSCEV = SE->getTruncateOrZeroExtend(AlignSCEV, Int64Ty);
+ if (AlignOB.Inputs.size() == 3)
+ OffSCEV = SE->getSCEV(AlignOB.Inputs[2].get());
+ else
OffSCEV = SE->getZero(Int64Ty);
- } else if (const SCEVAddExpr* AndLHSAddSCEV =
- dyn_cast<SCEVAddExpr>(AndLHSSCEV)) {
- // Try to find the ptrtoint; subtract it and the rest is the offset.
- for (SCEVAddExpr::op_iterator J = AndLHSAddSCEV->op_begin(),
- JE = AndLHSAddSCEV->op_end(); J != JE; ++J)
- if (const SCEVUnknown *OpUnk = dyn_cast<SCEVUnknown>(*J))
- if (PtrToIntInst *PToI = dyn_cast<PtrToIntInst>(OpUnk->getValue())) {
- AAPtr = PToI->getPointerOperand();
- OffSCEV = SE->getMinusSCEV(AndLHSAddSCEV, *J);
- break;
- }
- }
-
- if (!AAPtr)
- return false;
-
- // Sign extend the offset to 64 bits (so that it is like all of the other
- // expressions).
- unsigned OffSCEVBits = OffSCEV->getType()->getPrimitiveSizeInBits();
- if (OffSCEVBits < 64)
- OffSCEV = SE->getSignExtendExpr(OffSCEV, Int64Ty);
- else if (OffSCEVBits > 64)
- return false;
-
- AAPtr = AAPtr->stripPointerCasts();
+ OffSCEV = SE->getTruncateOrZeroExtend(OffSCEV, Int64Ty);
return true;
}
-bool AlignmentFromAssumptionsPass::processAssumption(CallInst *ACall) {
+bool AlignmentFromAssumptionsPass::processAssumption(CallInst *ACall,
+ unsigned Idx) {
Value *AAPtr;
const SCEV *AlignSCEV, *OffSCEV;
- if (!extractAlignmentInfo(ACall, AAPtr, AlignSCEV, OffSCEV))
+ if (!extractAlignmentInfo(ACall, Idx, AAPtr, AlignSCEV, OffSCEV))
return false;
// Skip ConstantPointerNull and UndefValue. Assumptions on these shouldn't
@@ -310,35 +248,38 @@ bool AlignmentFromAssumptionsPass::processAssumption(CallInst *ACall) {
continue;
if (Instruction *K = dyn_cast<Instruction>(J))
- if (isValidAssumeForContext(ACall, K, DT))
WorkList.push_back(K);
}
while (!WorkList.empty()) {
Instruction *J = WorkList.pop_back_val();
-
if (LoadInst *LI = dyn_cast<LoadInst>(J)) {
- unsigned NewAlignment = getNewAlignment(AASCEV, AlignSCEV, OffSCEV,
- LI->getPointerOperand(), SE);
-
- if (NewAlignment > LI->getAlignment()) {
- LI->setAlignment(MaybeAlign(NewAlignment));
+ if (!isValidAssumeForContext(ACall, J, DT))
+ continue;
+ Align NewAlignment = getNewAlignment(AASCEV, AlignSCEV, OffSCEV,
+ LI->getPointerOperand(), SE);
+ if (NewAlignment > LI->getAlign()) {
+ LI->setAlignment(NewAlignment);
++NumLoadAlignChanged;
}
} else if (StoreInst *SI = dyn_cast<StoreInst>(J)) {
- unsigned NewAlignment = getNewAlignment(AASCEV, AlignSCEV, OffSCEV,
- SI->getPointerOperand(), SE);
-
- if (NewAlignment > SI->getAlignment()) {
- SI->setAlignment(MaybeAlign(NewAlignment));
+ if (!isValidAssumeForContext(ACall, J, DT))
+ continue;
+ Align NewAlignment = getNewAlignment(AASCEV, AlignSCEV, OffSCEV,
+ SI->getPointerOperand(), SE);
+ if (NewAlignment > SI->getAlign()) {
+ SI->setAlignment(NewAlignment);
++NumStoreAlignChanged;
}
} else if (MemIntrinsic *MI = dyn_cast<MemIntrinsic>(J)) {
- unsigned NewDestAlignment = getNewAlignment(AASCEV, AlignSCEV, OffSCEV,
- MI->getDest(), SE);
-
- LLVM_DEBUG(dbgs() << "\tmem inst: " << NewDestAlignment << "\n";);
- if (NewDestAlignment > MI->getDestAlignment()) {
+ if (!isValidAssumeForContext(ACall, J, DT))
+ continue;
+ Align NewDestAlignment =
+ getNewAlignment(AASCEV, AlignSCEV, OffSCEV, MI->getDest(), SE);
+
+ LLVM_DEBUG(dbgs() << "\tmem inst: " << DebugStr(NewDestAlignment)
+ << "\n";);
+ if (NewDestAlignment > *MI->getDestAlign()) {
MI->setDestAlignment(NewDestAlignment);
++NumMemIntAlignChanged;
}
@@ -346,12 +287,13 @@ bool AlignmentFromAssumptionsPass::processAssumption(CallInst *ACall) {
// For memory transfers, there is also a source alignment that
// can be set.
if (MemTransferInst *MTI = dyn_cast<MemTransferInst>(MI)) {
- unsigned NewSrcAlignment = getNewAlignment(AASCEV, AlignSCEV, OffSCEV,
- MTI->getSource(), SE);
+ Align NewSrcAlignment =
+ getNewAlignment(AASCEV, AlignSCEV, OffSCEV, MTI->getSource(), SE);
- LLVM_DEBUG(dbgs() << "\tmem trans: " << NewSrcAlignment << "\n";);
+ LLVM_DEBUG(dbgs() << "\tmem trans: " << DebugStr(NewSrcAlignment)
+ << "\n";);
- if (NewSrcAlignment > MTI->getSourceAlignment()) {
+ if (NewSrcAlignment > *MTI->getSourceAlign()) {
MTI->setSourceAlignment(NewSrcAlignment);
++NumMemIntAlignChanged;
}
@@ -363,7 +305,7 @@ bool AlignmentFromAssumptionsPass::processAssumption(CallInst *ACall) {
Visited.insert(J);
for (User *UJ : J->users()) {
Instruction *K = cast<Instruction>(UJ);
- if (!Visited.count(K) && isValidAssumeForContext(ACall, K, DT))
+ if (!Visited.count(K))
WorkList.push_back(K);
}
}
@@ -390,8 +332,11 @@ bool AlignmentFromAssumptionsPass::runImpl(Function &F, AssumptionCache &AC,
bool Changed = false;
for (auto &AssumeVH : AC.assumptions())
- if (AssumeVH)
- Changed |= processAssumption(cast<CallInst>(AssumeVH));
+ if (AssumeVH) {
+ CallInst *Call = cast<CallInst>(AssumeVH);
+ for (unsigned Idx = 0; Idx < Call->getNumOperandBundles(); Idx++)
+ Changed |= processAssumption(Call, Idx);
+ }
return Changed;
}
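
Note: the rewritten extractAlignmentInfo() above no longer pattern-matches a ptrtoint/and/icmp chain; it reads an "align" operand bundle on llvm.assume whose inputs are the pointer, the alignment, and an optional offset. A rough IR sketch of the two forms, with illustrative names and values not taken from this commit:

    ; old form, matched by the removed code ("%p is 32-byte aligned"):
    %ptrint   = ptrtoint i32* %p to i64
    %maskedp  = and i64 %ptrint, 31
    %maskcond = icmp eq i64 %maskedp, 0
    call void @llvm.assume(i1 %maskcond)

    ; new operand-bundle form read by extractAlignmentInfo(Call, Idx);
    ; Inputs[0] = pointer, Inputs[1] = alignment, Inputs[2] = optional offset:
    call void @llvm.assume(i1 true) [ "align"(i32* %p, i64 32) ]
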
diff --git a/llvm/lib/Transforms/Scalar/BDCE.cpp b/llvm/lib/Transforms/Scalar/BDCE.cpp
index 0fa38fa80b17..767c7656dcfa 100644
--- a/llvm/lib/Transforms/Scalar/BDCE.cpp
+++ b/llvm/lib/Transforms/Scalar/BDCE.cpp
@@ -9,7 +9,8 @@
// This file implements the Bit-Tracking Dead Code Elimination pass. Some
// instructions (shifts, some ands, ors, etc.) kill some of their input bits.
// We track these dead bits and remove instructions that compute only these
-// dead bits.
+// dead bits. We also simplify sext that generates unused extension bits,
+// converting it to a zext.
//
//===----------------------------------------------------------------------===//
@@ -19,6 +20,7 @@
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/DemandedBits.h"
#include "llvm/Analysis/GlobalsModRef.h"
+#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/Instructions.h"
#include "llvm/InitializePasses.h"
@@ -33,6 +35,8 @@ using namespace llvm;
STATISTIC(NumRemoved, "Number of instructions removed (unused)");
STATISTIC(NumSimplified, "Number of instructions trivialized (dead bits)");
+STATISTIC(NumSExt2ZExt,
+ "Number of sign extension instructions converted to zero extension");
/// If an instruction is trivialized (dead), then the chain of users of that
/// instruction may need to be cleared of assumptions that can no longer be
@@ -102,13 +106,31 @@ static bool bitTrackingDCE(Function &F, DemandedBits &DB) {
(I.getType()->isIntOrIntVectorTy() &&
DB.getDemandedBits(&I).isNullValue() &&
wouldInstructionBeTriviallyDead(&I))) {
- salvageDebugInfoOrMarkUndef(I);
+ salvageDebugInfo(I);
Worklist.push_back(&I);
I.dropAllReferences();
Changed = true;
continue;
}
+ // Convert SExt into ZExt if none of the extension bits is required
+ if (SExtInst *SE = dyn_cast<SExtInst>(&I)) {
+ APInt Demanded = DB.getDemandedBits(SE);
+ const uint32_t SrcBitSize = SE->getSrcTy()->getScalarSizeInBits();
+ auto *const DstTy = SE->getDestTy();
+ const uint32_t DestBitSize = DstTy->getScalarSizeInBits();
+ if (Demanded.countLeadingZeros() >= (DestBitSize - SrcBitSize)) {
+ clearAssumptionsOfUsers(SE, DB);
+ IRBuilder<> Builder(SE);
+ I.replaceAllUsesWith(
+ Builder.CreateZExt(SE->getOperand(0), DstTy, SE->getName()));
+ Worklist.push_back(SE);
+ Changed = true;
+ NumSExt2ZExt++;
+ continue;
+ }
+ }
+
for (Use &U : I.operands()) {
// DemandedBits only detects dead integer uses.
if (!U->getType()->isIntOrIntVectorTy())
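
Note: the new SExt-to-ZExt conversion above fires when none of the extension bits of a sext are demanded, i.e. when Demanded.countLeadingZeros() >= DestBitSize - SrcBitSize. A small illustrative IR sketch (names and types hypothetical):

    ; only the low 16 bits of %e are demanded through the mask below, so
    ; countLeadingZeros(Demanded) = 16 >= 32 - 16 and the sext can be
    ; rewritten as a zext (the and itself is left for later passes):
    %e = sext i16 %x to i32
    %r = and i32 %e, 65535
    ; becomes
    %e.z = zext i16 %x to i32
    %r   = and i32 %e.z, 65535
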
diff --git a/llvm/lib/Transforms/Scalar/CallSiteSplitting.cpp b/llvm/lib/Transforms/Scalar/CallSiteSplitting.cpp
index e34c011b1c87..b26bd1114bd4 100644
--- a/llvm/lib/Transforms/Scalar/CallSiteSplitting.cpp
+++ b/llvm/lib/Transforms/Scalar/CallSiteSplitting.cpp
@@ -85,37 +85,36 @@ static cl::opt<unsigned>
"their cost is below DuplicationThreshold"),
cl::init(5));
-static void addNonNullAttribute(CallSite CS, Value *Op) {
+static void addNonNullAttribute(CallBase &CB, Value *Op) {
unsigned ArgNo = 0;
- for (auto &I : CS.args()) {
+ for (auto &I : CB.args()) {
if (&*I == Op)
- CS.addParamAttr(ArgNo, Attribute::NonNull);
+ CB.addParamAttr(ArgNo, Attribute::NonNull);
++ArgNo;
}
}
-static void setConstantInArgument(CallSite CS, Value *Op,
+static void setConstantInArgument(CallBase &CB, Value *Op,
Constant *ConstValue) {
unsigned ArgNo = 0;
- for (auto &I : CS.args()) {
+ for (auto &I : CB.args()) {
if (&*I == Op) {
// It is possible we have already added the non-null attribute to the
// parameter by using an earlier constraining condition.
- CS.removeParamAttr(ArgNo, Attribute::NonNull);
- CS.setArgument(ArgNo, ConstValue);
+ CB.removeParamAttr(ArgNo, Attribute::NonNull);
+ CB.setArgOperand(ArgNo, ConstValue);
}
++ArgNo;
}
}
-static bool isCondRelevantToAnyCallArgument(ICmpInst *Cmp, CallSite CS) {
+static bool isCondRelevantToAnyCallArgument(ICmpInst *Cmp, CallBase &CB) {
assert(isa<Constant>(Cmp->getOperand(1)) && "Expected a constant operand.");
Value *Op0 = Cmp->getOperand(0);
unsigned ArgNo = 0;
- for (CallSite::arg_iterator I = CS.arg_begin(), E = CS.arg_end(); I != E;
- ++I, ++ArgNo) {
+ for (auto I = CB.arg_begin(), E = CB.arg_end(); I != E; ++I, ++ArgNo) {
// Don't consider constant or arguments that are already known non-null.
- if (isa<Constant>(*I) || CS.paramHasAttr(ArgNo, Attribute::NonNull))
+ if (isa<Constant>(*I) || CB.paramHasAttr(ArgNo, Attribute::NonNull))
continue;
if (*I == Op0)
@@ -128,8 +127,8 @@ typedef std::pair<ICmpInst *, unsigned> ConditionTy;
typedef SmallVector<ConditionTy, 2> ConditionsTy;
/// If From has a conditional jump to To, add the condition to Conditions,
-/// if it is relevant to any argument at CS.
-static void recordCondition(CallSite CS, BasicBlock *From, BasicBlock *To,
+/// if it is relevant to any argument at CB.
+static void recordCondition(CallBase &CB, BasicBlock *From, BasicBlock *To,
ConditionsTy &Conditions) {
auto *BI = dyn_cast<BranchInst>(From->getTerminator());
if (!BI || !BI->isConditional())
@@ -142,38 +141,38 @@ static void recordCondition(CallSite CS, BasicBlock *From, BasicBlock *To,
ICmpInst *Cmp = cast<ICmpInst>(Cond);
if (Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_NE)
- if (isCondRelevantToAnyCallArgument(Cmp, CS))
+ if (isCondRelevantToAnyCallArgument(Cmp, CB))
Conditions.push_back({Cmp, From->getTerminator()->getSuccessor(0) == To
? Pred
: Cmp->getInversePredicate()});
}
-/// Record ICmp conditions relevant to any argument in CS following Pred's
+/// Record ICmp conditions relevant to any argument in CB following Pred's
/// single predecessors. If there are conflicting conditions along a path, like
/// x == 1 and x == 0, the first condition will be used. We stop once we reach
/// an edge to StopAt.
-static void recordConditions(CallSite CS, BasicBlock *Pred,
+static void recordConditions(CallBase &CB, BasicBlock *Pred,
ConditionsTy &Conditions, BasicBlock *StopAt) {
BasicBlock *From = Pred;
BasicBlock *To = Pred;
SmallPtrSet<BasicBlock *, 4> Visited;
while (To != StopAt && !Visited.count(From->getSinglePredecessor()) &&
(From = From->getSinglePredecessor())) {
- recordCondition(CS, From, To, Conditions);
+ recordCondition(CB, From, To, Conditions);
Visited.insert(From);
To = From;
}
}
-static void addConditions(CallSite CS, const ConditionsTy &Conditions) {
+static void addConditions(CallBase &CB, const ConditionsTy &Conditions) {
for (auto &Cond : Conditions) {
Value *Arg = Cond.first->getOperand(0);
Constant *ConstVal = cast<Constant>(Cond.first->getOperand(1));
if (Cond.second == ICmpInst::ICMP_EQ)
- setConstantInArgument(CS, Arg, ConstVal);
+ setConstantInArgument(CB, Arg, ConstVal);
else if (ConstVal->getType()->isPointerTy() && ConstVal->isNullValue()) {
assert(Cond.second == ICmpInst::ICMP_NE);
- addNonNullAttribute(CS, Arg);
+ addNonNullAttribute(CB, Arg);
}
}
}
@@ -184,17 +183,16 @@ static SmallVector<BasicBlock *, 2> getTwoPredecessors(BasicBlock *BB) {
return Preds;
}
-static bool canSplitCallSite(CallSite CS, TargetTransformInfo &TTI) {
- if (CS.isConvergent() || CS.cannotDuplicate())
+static bool canSplitCallSite(CallBase &CB, TargetTransformInfo &TTI) {
+ if (CB.isConvergent() || CB.cannotDuplicate())
return false;
// FIXME: As of now we handle only CallInst. InvokeInst could be handled
// without too much effort.
- Instruction *Instr = CS.getInstruction();
- if (!isa<CallInst>(Instr))
+ if (!isa<CallInst>(CB))
return false;
- BasicBlock *CallSiteBB = Instr->getParent();
+ BasicBlock *CallSiteBB = CB.getParent();
// Need 2 predecessors and cannot split an edge from an IndirectBrInst.
SmallVector<BasicBlock *, 2> Preds(predecessors(CallSiteBB));
if (Preds.size() != 2 || isa<IndirectBrInst>(Preds[0]->getTerminator()) ||
@@ -212,7 +210,7 @@ static bool canSplitCallSite(CallSite CS, TargetTransformInfo &TTI) {
// corresponding uses will be updated.
unsigned Cost = 0;
for (auto &InstBeforeCall :
- llvm::make_range(CallSiteBB->begin(), Instr->getIterator())) {
+ llvm::make_range(CallSiteBB->begin(), CB.getIterator())) {
Cost += TTI.getInstructionCost(&InstBeforeCall,
TargetTransformInfo::TCK_CodeSize);
if (Cost >= DuplicationThreshold)
@@ -304,24 +302,23 @@ static void copyMustTailReturn(BasicBlock *SplitBB, Instruction *CI,
/// predecessors, new call-sites with more constrained arguments will be
/// created in createCallSitesOnPredicatedArgument().
static void splitCallSite(
- CallSite CS,
+ CallBase &CB,
const SmallVectorImpl<std::pair<BasicBlock *, ConditionsTy>> &Preds,
DomTreeUpdater &DTU) {
- Instruction *Instr = CS.getInstruction();
- BasicBlock *TailBB = Instr->getParent();
- bool IsMustTailCall = CS.isMustTailCall();
+ BasicBlock *TailBB = CB.getParent();
+ bool IsMustTailCall = CB.isMustTailCall();
PHINode *CallPN = nullptr;
// `musttail` calls must be followed by optional `bitcast`, and `ret`. The
// split blocks will be terminated right after that so there're no users for
// this phi in a `TailBB`.
- if (!IsMustTailCall && !Instr->use_empty()) {
- CallPN = PHINode::Create(Instr->getType(), Preds.size(), "phi.call");
- CallPN->setDebugLoc(Instr->getDebugLoc());
+ if (!IsMustTailCall && !CB.use_empty()) {
+ CallPN = PHINode::Create(CB.getType(), Preds.size(), "phi.call");
+ CallPN->setDebugLoc(CB.getDebugLoc());
}
- LLVM_DEBUG(dbgs() << "split call-site : " << *Instr << " into \n");
+ LLVM_DEBUG(dbgs() << "split call-site : " << CB << " into \n");
assert(Preds.size() == 2 && "The ValueToValueMaps array has size 2.");
// ValueToValueMapTy is neither copy nor moveable, so we use a simple array
@@ -330,21 +327,20 @@ static void splitCallSite(
for (unsigned i = 0; i < Preds.size(); i++) {
BasicBlock *PredBB = Preds[i].first;
BasicBlock *SplitBlock = DuplicateInstructionsInSplitBetween(
- TailBB, PredBB, &*std::next(Instr->getIterator()), ValueToValueMaps[i],
+ TailBB, PredBB, &*std::next(CB.getIterator()), ValueToValueMaps[i],
DTU);
assert(SplitBlock && "Unexpected new basic block split.");
- Instruction *NewCI =
- &*std::prev(SplitBlock->getTerminator()->getIterator());
- CallSite NewCS(NewCI);
- addConditions(NewCS, Preds[i].second);
+ auto *NewCI =
+ cast<CallBase>(&*std::prev(SplitBlock->getTerminator()->getIterator()));
+ addConditions(*NewCI, Preds[i].second);
// Handle PHIs used as arguments in the call-site.
for (PHINode &PN : TailBB->phis()) {
unsigned ArgNo = 0;
- for (auto &CI : CS.args()) {
+ for (auto &CI : CB.args()) {
if (&*CI == &PN) {
- NewCS.setArgument(ArgNo, PN.getIncomingValueForBlock(SplitBlock));
+ NewCI->setArgOperand(ArgNo, PN.getIncomingValueForBlock(SplitBlock));
}
++ArgNo;
}
@@ -356,7 +352,7 @@ static void splitCallSite(
// Clone and place bitcast and return instructions before `TI`
if (IsMustTailCall)
- copyMustTailReturn(SplitBlock, Instr, NewCI);
+ copyMustTailReturn(SplitBlock, &CB, NewCI);
}
NumCallSiteSplit++;
@@ -383,7 +379,7 @@ static void splitCallSite(
// Replace users of the original call with a PHI mering call-sites split.
if (CallPN) {
CallPN->insertBefore(OriginalBegin);
- Instr->replaceAllUsesWith(CallPN);
+ CB.replaceAllUsesWith(CallPN);
}
// Remove instructions moved to split blocks from TailBB, from the duplicated
@@ -393,7 +389,7 @@ static void splitCallSite(
// instruction, so we do not end up deleting them. By using reverse-order, we
// do not introduce unnecessary PHI nodes for def-use chains from the call
// instruction to the beginning of the block.
- auto I = Instr->getReverseIterator();
+ auto I = CB.getReverseIterator();
while (I != TailBB->rend()) {
Instruction *CurrentI = &*I++;
if (!CurrentI->use_empty()) {
@@ -418,28 +414,25 @@ static void splitCallSite(
// Return true if the call-site has an argument which is a PHI with only
// constant incoming values.
-static bool isPredicatedOnPHI(CallSite CS) {
- Instruction *Instr = CS.getInstruction();
- BasicBlock *Parent = Instr->getParent();
- if (Instr != Parent->getFirstNonPHIOrDbg())
+static bool isPredicatedOnPHI(CallBase &CB) {
+ BasicBlock *Parent = CB.getParent();
+ if (&CB != Parent->getFirstNonPHIOrDbg())
return false;
- for (auto &BI : *Parent) {
- if (PHINode *PN = dyn_cast<PHINode>(&BI)) {
- for (auto &I : CS.args())
- if (&*I == PN) {
- assert(PN->getNumIncomingValues() == 2 &&
- "Unexpected number of incoming values");
- if (PN->getIncomingBlock(0) == PN->getIncomingBlock(1))
- return false;
- if (PN->getIncomingValue(0) == PN->getIncomingValue(1))
- continue;
- if (isa<Constant>(PN->getIncomingValue(0)) &&
- isa<Constant>(PN->getIncomingValue(1)))
- return true;
- }
+ for (auto &PN : Parent->phis()) {
+ for (auto &Arg : CB.args()) {
+ if (&*Arg != &PN)
+ continue;
+ assert(PN.getNumIncomingValues() == 2 &&
+ "Unexpected number of incoming values");
+ if (PN.getIncomingBlock(0) == PN.getIncomingBlock(1))
+ return false;
+ if (PN.getIncomingValue(0) == PN.getIncomingValue(1))
+ continue;
+ if (isa<Constant>(PN.getIncomingValue(0)) &&
+ isa<Constant>(PN.getIncomingValue(1)))
+ return true;
}
- break;
}
return false;
}
@@ -448,20 +441,20 @@ using PredsWithCondsTy = SmallVector<std::pair<BasicBlock *, ConditionsTy>, 2>;
// Check if any of the arguments in CS are predicated on a PHI node and return
// the set of predecessors we should use for splitting.
-static PredsWithCondsTy shouldSplitOnPHIPredicatedArgument(CallSite CS) {
- if (!isPredicatedOnPHI(CS))
+static PredsWithCondsTy shouldSplitOnPHIPredicatedArgument(CallBase &CB) {
+ if (!isPredicatedOnPHI(CB))
return {};
- auto Preds = getTwoPredecessors(CS.getInstruction()->getParent());
+ auto Preds = getTwoPredecessors(CB.getParent());
return {{Preds[0], {}}, {Preds[1], {}}};
}
// Checks if any of the arguments in CS are predicated in a predecessor and
// returns a list of predecessors with the conditions that hold on their edges
// to CS.
-static PredsWithCondsTy shouldSplitOnPredicatedArgument(CallSite CS,
+static PredsWithCondsTy shouldSplitOnPredicatedArgument(CallBase &CB,
DomTreeUpdater &DTU) {
- auto Preds = getTwoPredecessors(CS.getInstruction()->getParent());
+ auto Preds = getTwoPredecessors(CB.getParent());
if (Preds[0] == Preds[1])
return {};
@@ -470,16 +463,16 @@ static PredsWithCondsTy shouldSplitOnPredicatedArgument(CallSite CS,
// that node will be the same for all paths to the call site and splitting
// is not beneficial.
assert(DTU.hasDomTree() && "We need a DTU with a valid DT!");
- auto *CSDTNode = DTU.getDomTree().getNode(CS.getInstruction()->getParent());
+ auto *CSDTNode = DTU.getDomTree().getNode(CB.getParent());
BasicBlock *StopAt = CSDTNode ? CSDTNode->getIDom()->getBlock() : nullptr;
SmallVector<std::pair<BasicBlock *, ConditionsTy>, 2> PredsCS;
for (auto *Pred : make_range(Preds.rbegin(), Preds.rend())) {
ConditionsTy Conditions;
// Record condition on edge BB(CS) <- Pred
- recordCondition(CS, Pred, CS.getInstruction()->getParent(), Conditions);
+ recordCondition(CB, Pred, CB.getParent(), Conditions);
// Record conditions following Pred's single predecessors.
- recordConditions(CS, Pred, Conditions, StopAt);
+ recordConditions(CB, Pred, Conditions, StopAt);
PredsCS.push_back({Pred, Conditions});
}
@@ -491,19 +484,19 @@ static PredsWithCondsTy shouldSplitOnPredicatedArgument(CallSite CS,
return PredsCS;
}
-static bool tryToSplitCallSite(CallSite CS, TargetTransformInfo &TTI,
+static bool tryToSplitCallSite(CallBase &CB, TargetTransformInfo &TTI,
DomTreeUpdater &DTU) {
// Check if we can split the call site.
- if (!CS.arg_size() || !canSplitCallSite(CS, TTI))
+ if (!CB.arg_size() || !canSplitCallSite(CB, TTI))
return false;
- auto PredsWithConds = shouldSplitOnPredicatedArgument(CS, DTU);
+ auto PredsWithConds = shouldSplitOnPredicatedArgument(CB, DTU);
if (PredsWithConds.empty())
- PredsWithConds = shouldSplitOnPHIPredicatedArgument(CS);
+ PredsWithConds = shouldSplitOnPHIPredicatedArgument(CB);
if (PredsWithConds.empty())
return false;
- splitCallSite(CS, PredsWithConds, DTU);
+ splitCallSite(CB, PredsWithConds, DTU);
return true;
}
@@ -521,20 +514,19 @@ static bool doCallSiteSplitting(Function &F, TargetLibraryInfo &TLI,
// case, IE will be invalidated and we also have to check the current
// terminator.
while (II != IE && &*II != BB.getTerminator()) {
- Instruction *I = &*II++;
- CallSite CS(cast<Value>(I));
- if (!CS || isa<IntrinsicInst>(I) || isInstructionTriviallyDead(I, &TLI))
+ CallBase *CB = dyn_cast<CallBase>(&*II++);
+ if (!CB || isa<IntrinsicInst>(CB) || isInstructionTriviallyDead(CB, &TLI))
continue;
- Function *Callee = CS.getCalledFunction();
+ Function *Callee = CB->getCalledFunction();
if (!Callee || Callee->isDeclaration())
continue;
// Successful musttail call-site splits result in erased CI and erased BB.
// Check if such path is possible before attempting the splitting.
- bool IsMustTail = CS.isMustTailCall();
+ bool IsMustTail = CB->isMustTailCall();
- Changed |= tryToSplitCallSite(CS, TTI, DTU);
+ Changed |= tryToSplitCallSite(*CB, TTI, DTU);
// There're no interesting instructions after this. The call site
// itself might have been erased on splitting.
diff --git a/llvm/lib/Transforms/Scalar/ConstantHoisting.cpp b/llvm/lib/Transforms/Scalar/ConstantHoisting.cpp
index 5bfece010bec..7c14b69d658d 100644
--- a/llvm/lib/Transforms/Scalar/ConstantHoisting.cpp
+++ b/llvm/lib/Transforms/Scalar/ConstantHoisting.cpp
@@ -250,7 +250,7 @@ static void findBestInsertionSet(DominatorTree &DT, BlockFrequencyInfo &BFI,
Orders.push_back(Entry);
while (Idx != Orders.size()) {
BasicBlock *Node = Orders[Idx++];
- for (auto ChildDomNode : DT.getNode(Node)->getChildren()) {
+ for (auto ChildDomNode : DT.getNode(Node)->children()) {
if (Candidates.count(ChildDomNode->getBlock()))
Orders.push_back(ChildDomNode->getBlock());
}
@@ -363,10 +363,12 @@ void ConstantHoistingPass::collectConstantCandidates(
// instruction and operand index.
if (auto IntrInst = dyn_cast<IntrinsicInst>(Inst))
Cost = TTI->getIntImmCostIntrin(IntrInst->getIntrinsicID(), Idx,
- ConstInt->getValue(), ConstInt->getType());
+ ConstInt->getValue(), ConstInt->getType(),
+ TargetTransformInfo::TCK_SizeAndLatency);
else
Cost = TTI->getIntImmCostInst(Inst->getOpcode(), Idx, ConstInt->getValue(),
- ConstInt->getType());
+ ConstInt->getType(),
+ TargetTransformInfo::TCK_SizeAndLatency);
// Ignore cheap integer constants.
if (Cost > TargetTransformInfo::TCC_Basic) {
@@ -416,7 +418,8 @@ void ConstantHoistingPass::collectConstantCandidates(
// usually lowered to a load from constant pool. Such operation is unlikely
// to be cheaper than compute it by <Base + Offset>, which can be lowered to
// an ADD instruction or folded into Load/Store instruction.
- int Cost = TTI->getIntImmCostInst(Instruction::Add, 1, Offset, PtrIntTy);
+ int Cost = TTI->getIntImmCostInst(Instruction::Add, 1, Offset, PtrIntTy,
+ TargetTransformInfo::TCK_SizeAndLatency);
ConstCandVecType &ExprCandVec = ConstGEPCandMap[BaseGV];
ConstCandMapType::iterator Itr;
bool Inserted;
@@ -491,7 +494,7 @@ void ConstantHoistingPass::collectConstantCandidates(
// take constant variables is lower than `TargetTransformInfo::TCC_Basic`.
// So it's safe for us to collect constant candidates from all
// IntrinsicInsts.
- if (canReplaceOperandWithVariable(Inst, Idx) || isa<IntrinsicInst>(Inst)) {
+ if (canReplaceOperandWithVariable(Inst, Idx)) {
collectConstantCandidates(ConstCandMap, Inst, Idx);
}
} // end of for all operands
@@ -582,7 +585,8 @@ ConstantHoistingPass::maximizeConstantsInRange(ConstCandVecType::iterator S,
for (auto User : ConstCand->Uses) {
unsigned Opcode = User.Inst->getOpcode();
unsigned OpndIdx = User.OpndIdx;
- Cost += TTI->getIntImmCostInst(Opcode, OpndIdx, Value, Ty);
+ Cost += TTI->getIntImmCostInst(Opcode, OpndIdx, Value, Ty,
+ TargetTransformInfo::TCK_SizeAndLatency);
LLVM_DEBUG(dbgs() << "Cost: " << Cost << "\n");
for (auto C2 = S; C2 != E; ++C2) {
@@ -975,8 +979,8 @@ PreservedAnalyses ConstantHoistingPass::run(Function &F,
auto BFI = ConstHoistWithBlockFrequency
? &AM.getResult<BlockFrequencyAnalysis>(F)
: nullptr;
- auto &MAM = AM.getResult<ModuleAnalysisManagerFunctionProxy>(F).getManager();
- auto *PSI = MAM.getCachedResult<ProfileSummaryAnalysis>(*F.getParent());
+ auto &MAMProxy = AM.getResult<ModuleAnalysisManagerFunctionProxy>(F);
+ auto *PSI = MAMProxy.getCachedResult<ProfileSummaryAnalysis>(*F.getParent());
if (!runImpl(F, TTI, DT, BFI, F.getEntryBlock(), PSI))
return PreservedAnalyses::all();
diff --git a/llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp b/llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp
index 3435bc7f5eaa..cd2f4ca36f3b 100644
--- a/llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp
+++ b/llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp
@@ -22,7 +22,6 @@
#include "llvm/IR/Attributes.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/CFG.h"
-#include "llvm/IR/CallSite.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/ConstantRange.h"
#include "llvm/IR/Constants.h"
@@ -125,7 +124,7 @@ Pass *llvm::createCorrelatedValuePropagationPass() {
static bool processSelect(SelectInst *S, LazyValueInfo *LVI) {
if (S->getType()->isVectorTy()) return false;
- if (isa<Constant>(S->getOperand(0))) return false;
+ if (isa<Constant>(S->getCondition())) return false;
Constant *C = LVI->getConstant(S->getCondition(), S->getParent(), S);
if (!C) return false;
@@ -133,11 +132,7 @@ static bool processSelect(SelectInst *S, LazyValueInfo *LVI) {
ConstantInt *CI = dyn_cast<ConstantInt>(C);
if (!CI) return false;
- Value *ReplaceWith = S->getTrueValue();
- Value *Other = S->getFalseValue();
- if (!CI->isOne()) std::swap(ReplaceWith, Other);
- if (ReplaceWith == S) ReplaceWith = UndefValue::get(S->getType());
-
+ Value *ReplaceWith = CI->isOne() ? S->getTrueValue() : S->getFalseValue();
S->replaceAllUsesWith(ReplaceWith);
S->eraseFromParent();
@@ -310,9 +305,10 @@ static bool processCmp(CmpInst *Cmp, LazyValueInfo *LVI) {
// the comparison is testing local values. While LVI can sometimes reason
// about such cases, it's not its primary purpose. We do make sure to do
// the block local query for uses from terminator instructions, but that's
- // handled in the code for each terminator.
+ // handled in the code for each terminator. As an exception, we allow phi
+ // nodes, for which LVI can thread the condition into predecessors.
auto *I = dyn_cast<Instruction>(Op0);
- if (I && I->getParent() == Cmp->getParent())
+ if (I && I->getParent() == Cmp->getParent() && !isa<PHINode>(I))
return false;
LazyValueInfo::Tristate Result =
@@ -535,18 +531,18 @@ static void processSaturatingInst(SaturatingInst *SI, LazyValueInfo *LVI) {
}
/// Infer nonnull attributes for the arguments at the specified callsite.
-static bool processCallSite(CallSite CS, LazyValueInfo *LVI) {
+static bool processCallSite(CallBase &CB, LazyValueInfo *LVI) {
SmallVector<unsigned, 4> ArgNos;
unsigned ArgNo = 0;
- if (auto *WO = dyn_cast<WithOverflowInst>(CS.getInstruction())) {
+ if (auto *WO = dyn_cast<WithOverflowInst>(&CB)) {
if (WO->getLHS()->getType()->isIntegerTy() && willNotOverflow(WO, LVI)) {
processOverflowIntrinsic(WO, LVI);
return true;
}
}
- if (auto *SI = dyn_cast<SaturatingInst>(CS.getInstruction())) {
+ if (auto *SI = dyn_cast<SaturatingInst>(&CB)) {
if (SI->getType()->isIntegerTy() && willNotOverflow(SI, LVI)) {
processSaturatingInst(SI, LVI);
return true;
@@ -559,8 +555,8 @@ static bool processCallSite(CallSite CS, LazyValueInfo *LVI) {
// desireable since it may allow further optimization of that value (e.g. via
// single use rules in instcombine). Since deopt uses tend to,
// idiomatically, appear along rare conditional paths, it's reasonable likely
- // we may have a conditional fact with which LVI can fold.
- if (auto DeoptBundle = CS.getOperandBundle(LLVMContext::OB_deopt)) {
+ // we may have a conditional fact with which LVI can fold.
+ if (auto DeoptBundle = CB.getOperandBundle(LLVMContext::OB_deopt)) {
bool Progress = false;
for (const Use &ConstU : DeoptBundle->Inputs) {
Use &U = const_cast<Use&>(ConstU);
@@ -568,7 +564,7 @@ static bool processCallSite(CallSite CS, LazyValueInfo *LVI) {
if (V->getType()->isVectorTy()) continue;
if (isa<Constant>(V)) continue;
- Constant *C = LVI->getConstant(V, CS.getParent(), CS.getInstruction());
+ Constant *C = LVI->getConstant(V, CB.getParent(), &CB);
if (!C) continue;
U.set(C);
Progress = true;
@@ -577,30 +573,30 @@ static bool processCallSite(CallSite CS, LazyValueInfo *LVI) {
return true;
}
- for (Value *V : CS.args()) {
+ for (Value *V : CB.args()) {
PointerType *Type = dyn_cast<PointerType>(V->getType());
// Try to mark pointer typed parameters as non-null. We skip the
// relatively expensive analysis for constants which are obviously either
// null or non-null to start with.
- if (Type && !CS.paramHasAttr(ArgNo, Attribute::NonNull) &&
+ if (Type && !CB.paramHasAttr(ArgNo, Attribute::NonNull) &&
!isa<Constant>(V) &&
LVI->getPredicateAt(ICmpInst::ICMP_EQ, V,
ConstantPointerNull::get(Type),
- CS.getInstruction()) == LazyValueInfo::False)
+ &CB) == LazyValueInfo::False)
ArgNos.push_back(ArgNo);
ArgNo++;
}
- assert(ArgNo == CS.arg_size() && "sanity check");
+ assert(ArgNo == CB.arg_size() && "sanity check");
if (ArgNos.empty())
return false;
- AttributeList AS = CS.getAttributes();
- LLVMContext &Ctx = CS.getInstruction()->getContext();
+ AttributeList AS = CB.getAttributes();
+ LLVMContext &Ctx = CB.getContext();
AS = AS.addParamAttribute(Ctx, ArgNos,
Attribute::get(Ctx, Attribute::NonNull));
- CS.setAttributes(AS);
+ CB.setAttributes(AS);
return true;
}
@@ -793,7 +789,10 @@ static bool processAnd(BinaryOperator *BinOp, LazyValueInfo *LVI) {
if (!RHS || !RHS->getValue().isMask())
return false;
- ConstantRange LRange = LVI->getConstantRange(LHS, BB, BinOp);
+ // We can only replace the AND with LHS based on range info if the range does
+ // not include undef.
+ ConstantRange LRange =
+ LVI->getConstantRange(LHS, BB, BinOp, /*UndefAllowed=*/false);
if (!LRange.getUnsignedMax().ule(RHS->getValue()))
return false;
@@ -856,7 +855,7 @@ static bool runImpl(Function &F, LazyValueInfo *LVI, DominatorTree *DT,
break;
case Instruction::Call:
case Instruction::Invoke:
- BBChanged |= processCallSite(CallSite(II), LVI);
+ BBChanged |= processCallSite(cast<CallBase>(*II), LVI);
break;
case Instruction::SRem:
BBChanged |= processSRem(cast<BinaryOperator>(II), LVI);
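
Note: the processAnd() change above replaces an 'and' with its left-hand operand when LVI proves the mask already covers the operand's full range, and now additionally requires that range to be computed without allowing undef. An illustrative IR sketch (names hypothetical):

    ; LVI knows %x is in [0, 200)
    %x = urem i32 %n, 200
    %m = and i32 %x, 255   ; 255 is a mask and 199 u<= 255, so %m == %x
    ; %m may be replaced by %x, but only if the range for %x was computed
    ; with UndefAllowed=false, i.e. it cannot stem from an undef value.
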
diff --git a/llvm/lib/Transforms/Scalar/DCE.cpp b/llvm/lib/Transforms/Scalar/DCE.cpp
index a4b0c8df98f6..28947482e303 100644
--- a/llvm/lib/Transforms/Scalar/DCE.cpp
+++ b/llvm/lib/Transforms/Scalar/DCE.cpp
@@ -25,6 +25,7 @@
#include "llvm/Pass.h"
#include "llvm/Support/DebugCounter.h"
#include "llvm/Transforms/Scalar.h"
+#include "llvm/Transforms/Utils/AssumeBundleBuilder.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/Local.h"
using namespace llvm;
@@ -127,6 +128,7 @@ static bool DCEInstruction(Instruction *I,
return false;
salvageDebugInfo(*I);
+ salvageKnowledge(I);
// Null out all of the instruction's operands to see if any operand becomes
// dead as we go.
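
Note: the DCE change above (and a similar one in DSE below) calls salvageKnowledge() before deleting an instruction. A rough sketch of the idea, assuming the knowledge-retention option is enabled (it is off by default) and using hypothetical names; the exact bundle contents depend on what is known about the dying instruction:

    ; a dead load about to be deleted still carries alignment information:
    %v = load i32, i32* %p, align 16        ; %v has no uses
    ; salvageKnowledge() can keep such knowledge as assume bundles, e.g.:
    call void @llvm.assume(i1 true) [ "align"(i32* %p, i64 16), "nonnull"(i32* %p) ]
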
diff --git a/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp b/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp
index 1ba4aab999e1..e58db03225ee 100644
--- a/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp
+++ b/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp
@@ -18,6 +18,7 @@
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/MapVector.h"
+#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
@@ -29,17 +30,19 @@
#include "llvm/Analysis/MemoryBuiltins.h"
#include "llvm/Analysis/MemoryDependenceAnalysis.h"
#include "llvm/Analysis/MemoryLocation.h"
-#include "llvm/Analysis/OrderedBasicBlock.h"
+#include "llvm/Analysis/MemorySSA.h"
+#include "llvm/Analysis/MemorySSAUpdater.h"
+#include "llvm/Analysis/PostDominators.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/Argument.h"
#include "llvm/IR/BasicBlock.h"
-#include "llvm/IR/CallSite.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
+#include "llvm/IR/InstIterator.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
@@ -48,16 +51,19 @@
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/PassManager.h"
+#include "llvm/IR/PatternMatch.h"
#include "llvm/IR/Value.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugCounter.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Scalar.h"
+#include "llvm/Transforms/Utils/AssumeBundleBuilder.h"
#include "llvm/Transforms/Utils/Local.h"
#include <algorithm>
#include <cassert>
@@ -68,14 +74,23 @@
#include <utility>
using namespace llvm;
+using namespace PatternMatch;
#define DEBUG_TYPE "dse"
+STATISTIC(NumRemainingStores, "Number of stores remaining after DSE");
STATISTIC(NumRedundantStores, "Number of redundant stores deleted");
STATISTIC(NumFastStores, "Number of stores deleted");
STATISTIC(NumFastOther, "Number of other instrs removed");
STATISTIC(NumCompletePartials, "Number of stores dead by later partials");
STATISTIC(NumModifiedStores, "Number of stores modified");
+STATISTIC(NumNoopStores, "Number of noop stores deleted");
+STATISTIC(NumCFGChecks, "Number of stores modified");
+STATISTIC(NumCFGTries, "Number of stores modified");
+STATISTIC(NumCFGSuccess, "Number of stores modified");
+
+DEBUG_COUNTER(MemorySSACounter, "dse-memoryssa",
+ "Controls which MemoryDefs are eliminated.");
static cl::opt<bool>
EnablePartialOverwriteTracking("enable-dse-partial-overwrite-tracking",
@@ -87,6 +102,25 @@ EnablePartialStoreMerging("enable-dse-partial-store-merging",
cl::init(true), cl::Hidden,
cl::desc("Enable partial store merging in DSE"));
+static cl::opt<bool>
+ EnableMemorySSA("enable-dse-memoryssa", cl::init(false), cl::Hidden,
+ cl::desc("Use the new MemorySSA-backed DSE."));
+
+static cl::opt<unsigned>
+ MemorySSAScanLimit("dse-memoryssa-scanlimit", cl::init(100), cl::Hidden,
+ cl::desc("The number of memory instructions to scan for "
+ "dead store elimination (default = 100)"));
+
+static cl::opt<unsigned> MemorySSADefsPerBlockLimit(
+ "dse-memoryssa-defs-per-block-limit", cl::init(5000), cl::Hidden,
+ cl::desc("The number of MemoryDefs we consider as candidates to eliminated "
+ "other stores per basic block (default = 5000)"));
+
+static cl::opt<unsigned> MemorySSAPathCheckLimit(
+ "dse-memoryssa-path-check-limit", cl::init(50), cl::Hidden,
+ cl::desc("The maximum number of blocks to check when trying to prove that "
+ "all paths to an exit go through a killing block (default = 50)"));
+
//===----------------------------------------------------------------------===//
// Helper functions
//===----------------------------------------------------------------------===//
@@ -100,7 +134,7 @@ using InstOverlapIntervalsTy = DenseMap<Instruction *, OverlapIntervalsTy>;
static void
deleteDeadInstruction(Instruction *I, BasicBlock::iterator *BBI,
MemoryDependenceResults &MD, const TargetLibraryInfo &TLI,
- InstOverlapIntervalsTy &IOL, OrderedBasicBlock &OBB,
+ InstOverlapIntervalsTy &IOL,
MapVector<Instruction *, bool> &ThrowableInst,
SmallSetVector<const Value *, 16> *ValueSet = nullptr) {
SmallVector<Instruction*, 32> NowDeadInsts;
@@ -123,6 +157,7 @@ deleteDeadInstruction(Instruction *I, BasicBlock::iterator *BBI,
// Try to preserve debug information attached to the dead instruction.
salvageDebugInfo(*DeadInst);
+ salvageKnowledge(DeadInst);
// This instruction is dead, zap it, in stages. Start by removing it from
// MemDep, which needs to know the operands and needs it to be in the
@@ -143,7 +178,6 @@ deleteDeadInstruction(Instruction *I, BasicBlock::iterator *BBI,
if (ValueSet) ValueSet->remove(DeadInst);
IOL.erase(DeadInst);
- OBB.eraseInstruction(DeadInst);
if (NewIter == DeadInst->getIterator())
NewIter = DeadInst->eraseFromParent();
@@ -177,19 +211,17 @@ static bool hasAnalyzableMemoryWrite(Instruction *I,
return true;
}
}
- if (auto CS = CallSite(I)) {
- if (Function *F = CS.getCalledFunction()) {
- LibFunc LF;
- if (TLI.getLibFunc(*F, LF) && TLI.has(LF)) {
- switch (LF) {
- case LibFunc_strcpy:
- case LibFunc_strncpy:
- case LibFunc_strcat:
- case LibFunc_strncat:
- return true;
- default:
- return false;
- }
+ if (auto *CB = dyn_cast<CallBase>(I)) {
+ LibFunc LF;
+ if (TLI.getLibFunc(*CB, LF) && TLI.has(LF)) {
+ switch (LF) {
+ case LibFunc_strcpy:
+ case LibFunc_strncpy:
+ case LibFunc_strcat:
+ case LibFunc_strncat:
+ return true;
+ default:
+ return false;
}
}
}
@@ -222,10 +254,10 @@ static MemoryLocation getLocForWrite(Instruction *Inst) {
}
}
}
- if (auto CS = CallSite(Inst))
+ if (auto *CB = dyn_cast<CallBase>(Inst))
// All the supported TLI functions so far happen to have dest as their
// first argument.
- return MemoryLocation(CS.getArgument(0));
+ return MemoryLocation(CB->getArgOperand(0));
return MemoryLocation();
}
@@ -272,8 +304,8 @@ static bool isRemovable(Instruction *I) {
}
// note: only get here for calls with analyzable writes - i.e. libcalls
- if (auto CS = CallSite(I))
- return CS.getInstruction()->use_empty();
+ if (auto *CB = dyn_cast<CallBase>(I))
+ return CB->use_empty();
return false;
}
@@ -597,51 +629,82 @@ static bool isPossibleSelfRead(Instruction *Inst,
/// instruction.
static bool memoryIsNotModifiedBetween(Instruction *FirstI,
Instruction *SecondI,
- AliasAnalysis *AA) {
- SmallVector<BasicBlock *, 16> WorkList;
- SmallPtrSet<BasicBlock *, 8> Visited;
+ AliasAnalysis *AA,
+ const DataLayout &DL,
+ DominatorTree *DT) {
+ // Do a backwards scan through the CFG from SecondI to FirstI. Look for
+ // instructions which can modify the memory location accessed by SecondI.
+ //
+ // While doing the walk keep track of the address to check. It might be
+ // different in different basic blocks due to PHI translation.
+ using BlockAddressPair = std::pair<BasicBlock *, PHITransAddr>;
+ SmallVector<BlockAddressPair, 16> WorkList;
+ // Keep track of the address we visited each block with. Bail out if we
+ // visit a block with different addresses.
+ DenseMap<BasicBlock *, Value *> Visited;
+
BasicBlock::iterator FirstBBI(FirstI);
++FirstBBI;
BasicBlock::iterator SecondBBI(SecondI);
BasicBlock *FirstBB = FirstI->getParent();
BasicBlock *SecondBB = SecondI->getParent();
MemoryLocation MemLoc = MemoryLocation::get(SecondI);
+ auto *MemLocPtr = const_cast<Value *>(MemLoc.Ptr);
- // Start checking the store-block.
- WorkList.push_back(SecondBB);
+ // Start checking the SecondBB.
+ WorkList.push_back(
+ std::make_pair(SecondBB, PHITransAddr(MemLocPtr, DL, nullptr)));
bool isFirstBlock = true;
- // Check all blocks going backward until we reach the load-block.
+ // Check all blocks going backward until we reach the FirstBB.
while (!WorkList.empty()) {
- BasicBlock *B = WorkList.pop_back_val();
+ BlockAddressPair Current = WorkList.pop_back_val();
+ BasicBlock *B = Current.first;
+ PHITransAddr &Addr = Current.second;
+ Value *Ptr = Addr.getAddr();
- // Ignore instructions before LI if this is the FirstBB.
+ // Ignore instructions before FirstI if this is the FirstBB.
BasicBlock::iterator BI = (B == FirstBB ? FirstBBI : B->begin());
BasicBlock::iterator EI;
if (isFirstBlock) {
- // Ignore instructions after SI if this is the first visit of SecondBB.
+ // Ignore instructions after SecondI if this is the first visit of SecondBB.
assert(B == SecondBB && "first block is not the store block");
EI = SecondBBI;
isFirstBlock = false;
} else {
// It's not SecondBB or (in case of a loop) the second visit of SecondBB.
- // In this case we also have to look at instructions after SI.
+ // In this case we also have to look at instructions after SecondI.
EI = B->end();
}
for (; BI != EI; ++BI) {
Instruction *I = &*BI;
if (I->mayWriteToMemory() && I != SecondI)
- if (isModSet(AA->getModRefInfo(I, MemLoc)))
+ if (isModSet(AA->getModRefInfo(I, MemLoc.getWithNewPtr(Ptr))))
return false;
}
if (B != FirstBB) {
assert(B != &FirstBB->getParent()->getEntryBlock() &&
"Should not hit the entry block because SI must be dominated by LI");
for (auto PredI = pred_begin(B), PE = pred_end(B); PredI != PE; ++PredI) {
- if (!Visited.insert(*PredI).second)
+ PHITransAddr PredAddr = Addr;
+ if (PredAddr.NeedsPHITranslationFromBlock(B)) {
+ if (!PredAddr.IsPotentiallyPHITranslatable())
+ return false;
+ if (PredAddr.PHITranslateValue(B, *PredI, DT, false))
+ return false;
+ }
+ Value *TranslatedPtr = PredAddr.getAddr();
+ auto Inserted = Visited.insert(std::make_pair(*PredI, TranslatedPtr));
+ if (!Inserted.second) {
+ // We already visited this block before. If it was with a different
+ // address - bail out!
+ if (TranslatedPtr != Inserted.first->second)
+ return false;
+ // ... otherwise just skip it.
continue;
- WorkList.push_back(*PredI);
+ }
+ WorkList.push_back(std::make_pair(*PredI, PredAddr));
}
}
}
@@ -669,7 +732,7 @@ static void findUnconditionalPreds(SmallVectorImpl<BasicBlock *> &Blocks,
static bool handleFree(CallInst *F, AliasAnalysis *AA,
MemoryDependenceResults *MD, DominatorTree *DT,
const TargetLibraryInfo *TLI,
- InstOverlapIntervalsTy &IOL, OrderedBasicBlock &OBB,
+ InstOverlapIntervalsTy &IOL,
MapVector<Instruction *, bool> &ThrowableInst) {
bool MadeChange = false;
@@ -704,7 +767,7 @@ static bool handleFree(CallInst *F, AliasAnalysis *AA,
// DCE instructions only used to calculate that store.
BasicBlock::iterator BBI(Dependency);
- deleteDeadInstruction(Dependency, &BBI, *MD, *TLI, IOL, OBB,
+ deleteDeadInstruction(Dependency, &BBI, *MD, *TLI, IOL,
ThrowableInst);
++NumFastStores;
MadeChange = true;
@@ -762,7 +825,7 @@ static void removeAccessedObjects(const MemoryLocation &LoadedLoc,
static bool handleEndBlock(BasicBlock &BB, AliasAnalysis *AA,
MemoryDependenceResults *MD,
const TargetLibraryInfo *TLI,
- InstOverlapIntervalsTy &IOL, OrderedBasicBlock &OBB,
+ InstOverlapIntervalsTy &IOL,
MapVector<Instruction *, bool> &ThrowableInst) {
bool MadeChange = false;
@@ -785,7 +848,7 @@ static bool handleEndBlock(BasicBlock &BB, AliasAnalysis *AA,
// Treat byval or inalloca arguments the same, stores to them are dead at the
// end of the function.
for (Argument &AI : BB.getParent()->args())
- if (AI.hasByValOrInAllocaAttr())
+ if (AI.hasPassPointeeByValueAttr())
DeadStackObjects.insert(&AI);
const DataLayout &DL = BB.getModule()->getDataLayout();
@@ -824,7 +887,7 @@ static bool handleEndBlock(BasicBlock &BB, AliasAnalysis *AA,
<< '\n');
// DCE instructions only used to calculate that store.
- deleteDeadInstruction(Dead, &BBI, *MD, *TLI, IOL, OBB, ThrowableInst,
+ deleteDeadInstruction(Dead, &BBI, *MD, *TLI, IOL, ThrowableInst,
&DeadStackObjects);
++NumFastStores;
MadeChange = true;
@@ -836,7 +899,7 @@ static bool handleEndBlock(BasicBlock &BB, AliasAnalysis *AA,
if (isInstructionTriviallyDead(&*BBI, TLI)) {
LLVM_DEBUG(dbgs() << "DSE: Removing trivially dead instruction:\n DEAD: "
<< *&*BBI << '\n');
- deleteDeadInstruction(&*BBI, &BBI, *MD, *TLI, IOL, OBB, ThrowableInst,
+ deleteDeadInstruction(&*BBI, &BBI, *MD, *TLI, IOL, ThrowableInst,
&DeadStackObjects);
++NumFastOther;
MadeChange = true;
@@ -1043,8 +1106,8 @@ static bool eliminateNoopStore(Instruction *Inst, BasicBlock::iterator &BBI,
const DataLayout &DL,
const TargetLibraryInfo *TLI,
InstOverlapIntervalsTy &IOL,
- OrderedBasicBlock &OBB,
- MapVector<Instruction *, bool> &ThrowableInst) {
+ MapVector<Instruction *, bool> &ThrowableInst,
+ DominatorTree *DT) {
// Must be a store instruction.
StoreInst *SI = dyn_cast<StoreInst>(Inst);
if (!SI)
@@ -1054,13 +1117,14 @@ static bool eliminateNoopStore(Instruction *Inst, BasicBlock::iterator &BBI,
// then the store can be removed.
if (LoadInst *DepLoad = dyn_cast<LoadInst>(SI->getValueOperand())) {
if (SI->getPointerOperand() == DepLoad->getPointerOperand() &&
- isRemovable(SI) && memoryIsNotModifiedBetween(DepLoad, SI, AA)) {
+ isRemovable(SI) &&
+ memoryIsNotModifiedBetween(DepLoad, SI, AA, DL, DT)) {
LLVM_DEBUG(
dbgs() << "DSE: Remove Store Of Load from same pointer:\n LOAD: "
<< *DepLoad << "\n STORE: " << *SI << '\n');
- deleteDeadInstruction(SI, &BBI, *MD, *TLI, IOL, OBB, ThrowableInst);
+ deleteDeadInstruction(SI, &BBI, *MD, *TLI, IOL, ThrowableInst);
++NumRedundantStores;
return true;
}
@@ -1073,12 +1137,12 @@ static bool eliminateNoopStore(Instruction *Inst, BasicBlock::iterator &BBI,
dyn_cast<Instruction>(GetUnderlyingObject(SI->getPointerOperand(), DL));
if (UnderlyingPointer && isCallocLikeFn(UnderlyingPointer, TLI) &&
- memoryIsNotModifiedBetween(UnderlyingPointer, SI, AA)) {
+ memoryIsNotModifiedBetween(UnderlyingPointer, SI, AA, DL, DT)) {
LLVM_DEBUG(
dbgs() << "DSE: Remove null store to the calloc'ed object:\n DEAD: "
<< *Inst << "\n OBJECT: " << *UnderlyingPointer << '\n');
- deleteDeadInstruction(SI, &BBI, *MD, *TLI, IOL, OBB, ThrowableInst);
+ deleteDeadInstruction(SI, &BBI, *MD, *TLI, IOL, ThrowableInst);
++NumRedundantStores;
return true;
}
@@ -1086,13 +1150,58 @@ static bool eliminateNoopStore(Instruction *Inst, BasicBlock::iterator &BBI,
return false;
}
+static Constant *
+tryToMergePartialOverlappingStores(StoreInst *Earlier, StoreInst *Later,
+ int64_t InstWriteOffset,
+ int64_t DepWriteOffset, const DataLayout &DL,
+ AliasAnalysis *AA, DominatorTree *DT) {
+
+ if (Earlier && isa<ConstantInt>(Earlier->getValueOperand()) &&
+ DL.typeSizeEqualsStoreSize(Earlier->getValueOperand()->getType()) &&
+ Later && isa<ConstantInt>(Later->getValueOperand()) &&
+ DL.typeSizeEqualsStoreSize(Later->getValueOperand()->getType()) &&
+ memoryIsNotModifiedBetween(Earlier, Later, AA, DL, DT)) {
+ // If the store we find is:
+ // a) partially overwritten by the store to 'Loc'
+ // b) the later store is fully contained in the earlier one and
+ // c) they both have a constant value
+ // d) none of the two stores need padding
+ // Merge the two stores, replacing the earlier store's value with a
+ // merge of both values.
+ // TODO: Deal with other constant types (vectors, etc), and probably
+ // some mem intrinsics (if needed)
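+ //
+ // Illustrative example (little-endian; the values are assumptions, not
+ // from a test case): Earlier = store i32 0x11223344, Later = store i16
+ // 0xAABB at byte offset 2. Then BitOffsetDiff = 16, LShiftAmount = 16,
+ // Mask covers bits [16,32), and
+ //   Merged = (0x11223344 & 0x0000FFFF) | (0xAABB << 16) = 0xAABB3344.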
+
+ APInt EarlierValue =
+ cast<ConstantInt>(Earlier->getValueOperand())->getValue();
+ APInt LaterValue = cast<ConstantInt>(Later->getValueOperand())->getValue();
+ unsigned LaterBits = LaterValue.getBitWidth();
+ assert(EarlierValue.getBitWidth() > LaterValue.getBitWidth());
+ LaterValue = LaterValue.zext(EarlierValue.getBitWidth());
+
+ // Offset of the smaller store inside the larger store
+ unsigned BitOffsetDiff = (InstWriteOffset - DepWriteOffset) * 8;
+ unsigned LShiftAmount = DL.isBigEndian() ? EarlierValue.getBitWidth() -
+ BitOffsetDiff - LaterBits
+ : BitOffsetDiff;
+ APInt Mask = APInt::getBitsSet(EarlierValue.getBitWidth(), LShiftAmount,
+ LShiftAmount + LaterBits);
+ // Clear the bits we'll be replacing, then OR with the smaller
+ // store, shifted appropriately.
+ APInt Merged = (EarlierValue & ~Mask) | (LaterValue << LShiftAmount);
+ LLVM_DEBUG(dbgs() << "DSE: Merge Stores:\n Earlier: " << *Earlier
+ << "\n Later: " << *Later
+ << "\n Merged Value: " << Merged << '\n');
+ return ConstantInt::get(Earlier->getValueOperand()->getType(), Merged);
+ }
+ return nullptr;
+}
+
static bool eliminateDeadStores(BasicBlock &BB, AliasAnalysis *AA,
MemoryDependenceResults *MD, DominatorTree *DT,
const TargetLibraryInfo *TLI) {
const DataLayout &DL = BB.getModule()->getDataLayout();
bool MadeChange = false;
- OrderedBasicBlock OBB(&BB);
MapVector<Instruction *, bool> ThrowableInst;
// A map of interval maps representing partially-overwritten value parts.
@@ -1102,7 +1211,7 @@ static bool eliminateDeadStores(BasicBlock &BB, AliasAnalysis *AA,
for (BasicBlock::iterator BBI = BB.begin(), BBE = BB.end(); BBI != BBE; ) {
// Handle 'free' calls specially.
if (CallInst *F = isFreeCall(&*BBI, TLI)) {
- MadeChange |= handleFree(F, AA, MD, DT, TLI, IOL, OBB, ThrowableInst);
+ MadeChange |= handleFree(F, AA, MD, DT, TLI, IOL, ThrowableInst);
// Increment BBI after handleFree has potentially deleted instructions.
// This ensures we maintain a valid iterator.
++BBI;
@@ -1121,14 +1230,14 @@ static bool eliminateDeadStores(BasicBlock &BB, AliasAnalysis *AA,
continue;
// eliminateNoopStore will update in iterator, if necessary.
- if (eliminateNoopStore(Inst, BBI, AA, MD, DL, TLI, IOL, OBB,
- ThrowableInst)) {
+ if (eliminateNoopStore(Inst, BBI, AA, MD, DL, TLI, IOL,
+ ThrowableInst, DT)) {
MadeChange = true;
continue;
}
// If we find something that writes memory, get its memory dependence.
- MemDepResult InstDep = MD->getDependency(Inst, &OBB);
+ MemDepResult InstDep = MD->getDependency(Inst);
// Ignore any store where we can't find a local dependence.
// FIXME: cross-block DSE would be fun. :)
@@ -1179,7 +1288,7 @@ static bool eliminateDeadStores(BasicBlock &BB, AliasAnalysis *AA,
// If the underlying object is a non-escaping memory allocation, any store
// to it is dead along the unwind edge. Otherwise, we need to preserve
// the store.
- if (LastThrowing && OBB.dominates(DepWrite, LastThrowing)) {
+ if (LastThrowing && DepWrite->comesBefore(LastThrowing)) {
const Value* Underlying = GetUnderlyingObject(DepLoc.Ptr, DL);
bool IsStoreDeadOnUnwind = isa<AllocaInst>(Underlying);
if (!IsStoreDeadOnUnwind) {
@@ -1210,13 +1319,13 @@ static bool eliminateDeadStores(BasicBlock &BB, AliasAnalysis *AA,
<< "\n KILLER: " << *Inst << '\n');
// Delete the store and now-dead instructions that feed it.
- deleteDeadInstruction(DepWrite, &BBI, *MD, *TLI, IOL, OBB,
+ deleteDeadInstruction(DepWrite, &BBI, *MD, *TLI, IOL,
ThrowableInst);
++NumFastStores;
MadeChange = true;
// We erased DepWrite; start over.
- InstDep = MD->getDependency(Inst, &OBB);
+ InstDep = MD->getDependency(Inst);
continue;
} else if ((OR == OW_End && isShortenableAtTheEnd(DepWrite)) ||
((OR == OW_Begin &&
@@ -1234,53 +1343,12 @@ static bool eliminateDeadStores(BasicBlock &BB, AliasAnalysis *AA,
OR == OW_PartialEarlierWithFullLater) {
auto *Earlier = dyn_cast<StoreInst>(DepWrite);
auto *Later = dyn_cast<StoreInst>(Inst);
- if (Earlier && isa<ConstantInt>(Earlier->getValueOperand()) &&
- DL.typeSizeEqualsStoreSize(
- Earlier->getValueOperand()->getType()) &&
- Later && isa<ConstantInt>(Later->getValueOperand()) &&
- DL.typeSizeEqualsStoreSize(
- Later->getValueOperand()->getType()) &&
- memoryIsNotModifiedBetween(Earlier, Later, AA)) {
- // If the store we find is:
- // a) partially overwritten by the store to 'Loc'
- // b) the later store is fully contained in the earlier one and
- // c) they both have a constant value
- // d) none of the two stores need padding
- // Merge the two stores, replacing the earlier store's value with a
- // merge of both values.
- // TODO: Deal with other constant types (vectors, etc), and probably
- // some mem intrinsics (if needed)
-
- APInt EarlierValue =
- cast<ConstantInt>(Earlier->getValueOperand())->getValue();
- APInt LaterValue =
- cast<ConstantInt>(Later->getValueOperand())->getValue();
- unsigned LaterBits = LaterValue.getBitWidth();
- assert(EarlierValue.getBitWidth() > LaterValue.getBitWidth());
- LaterValue = LaterValue.zext(EarlierValue.getBitWidth());
-
- // Offset of the smaller store inside the larger store
- unsigned BitOffsetDiff = (InstWriteOffset - DepWriteOffset) * 8;
- unsigned LShiftAmount =
- DL.isBigEndian()
- ? EarlierValue.getBitWidth() - BitOffsetDiff - LaterBits
- : BitOffsetDiff;
- APInt Mask =
- APInt::getBitsSet(EarlierValue.getBitWidth(), LShiftAmount,
- LShiftAmount + LaterBits);
- // Clear the bits we'll be replacing, then OR with the smaller
- // store, shifted appropriately.
- APInt Merged =
- (EarlierValue & ~Mask) | (LaterValue << LShiftAmount);
- LLVM_DEBUG(dbgs() << "DSE: Merge Stores:\n Earlier: " << *DepWrite
- << "\n Later: " << *Inst
- << "\n Merged Value: " << Merged << '\n');
-
+ if (Constant *C = tryToMergePartialOverlappingStores(
+ Earlier, Later, InstWriteOffset, DepWriteOffset, DL, AA,
+ DT)) {
auto *SI = new StoreInst(
- ConstantInt::get(Earlier->getValueOperand()->getType(), Merged),
- Earlier->getPointerOperand(), false,
- MaybeAlign(Earlier->getAlignment()), Earlier->getOrdering(),
- Earlier->getSyncScopeID(), DepWrite);
+ C, Earlier->getPointerOperand(), false, Earlier->getAlign(),
+ Earlier->getOrdering(), Earlier->getSyncScopeID(), DepWrite);
unsigned MDToKeep[] = {LLVMContext::MD_dbg, LLVMContext::MD_tbaa,
LLVMContext::MD_alias_scope,
@@ -1289,13 +1357,10 @@ static bool eliminateDeadStores(BasicBlock &BB, AliasAnalysis *AA,
SI->copyMetadata(*DepWrite, MDToKeep);
++NumModifiedStores;
- // Remove earlier, wider, store
- OBB.replaceInstruction(DepWrite, SI);
-
// Delete the old stores and now-dead instructions that feed them.
- deleteDeadInstruction(Inst, &BBI, *MD, *TLI, IOL, OBB,
+ deleteDeadInstruction(Inst, &BBI, *MD, *TLI, IOL,
ThrowableInst);
- deleteDeadInstruction(DepWrite, &BBI, *MD, *TLI, IOL, OBB,
+ deleteDeadInstruction(DepWrite, &BBI, *MD, *TLI, IOL,
ThrowableInst);
MadeChange = true;
@@ -1331,7 +1396,7 @@ static bool eliminateDeadStores(BasicBlock &BB, AliasAnalysis *AA,
// If this block ends in a return, unwind, or unreachable, all allocas are
// dead at its end, which means stores to them are also dead.
if (BB.getTerminator()->getNumSuccessors() == 0)
- MadeChange |= handleEndBlock(BB, AA, MD, TLI, IOL, OBB, ThrowableInst);
+ MadeChange |= handleEndBlock(BB, AA, MD, TLI, IOL, ThrowableInst);
return MadeChange;
}
@@ -1349,22 +1414,913 @@ static bool eliminateDeadStores(Function &F, AliasAnalysis *AA,
return MadeChange;
}
+namespace {
+//=============================================================================
+// MemorySSA backed dead store elimination.
+//
+// The code below implements dead store elimination using MemorySSA. It uses
+// the following general approach: given a MemoryDef, walk upwards to find
+// clobbering MemoryDefs that may be killed by the starting def. Then check
+// that there are no uses that may read the location of the original MemoryDef
+// in between both MemoryDefs. A bit more concretely:
+//
+// For all MemoryDefs StartDef:
+// 1. Get the next dominating clobbering MemoryDef (DomAccess) by walking
+// upwards.
+// 2. Check that there are no reads between DomAccess and the StartDef by
+// checking all uses starting at DomAccess and walking until we see StartDef.
+// 3. For each found DomDef, check that:
+// 1. There are no barrier instructions between DomDef and StartDef (like
+// throws or stores with ordering constraints).
+// 2. StartDef is executed whenever DomDef is executed.
+// 3. StartDef completely overwrites DomDef.
+// 4. Erase DomDef from the function and MemorySSA.
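+//
+// A minimal sketch of the pattern this handles (illustrative only):
+//   store i32 0, i32* %p        ; DomDef, no reads of %p in between
+//   store i32 1, i32* %p        ; StartDef, completely overwrites DomDef
+// Here the first store is dead and can be erased from the IR and MemorySSA.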
+
+// Returns true if \p M is an intrinsic that does not read or write memory.
+bool isNoopIntrinsic(MemoryUseOrDef *M) {
+ if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(M->getMemoryInst())) {
+ switch (II->getIntrinsicID()) {
+ case Intrinsic::lifetime_start:
+ case Intrinsic::lifetime_end:
+ case Intrinsic::invariant_end:
+ case Intrinsic::launder_invariant_group:
+ case Intrinsic::assume:
+ return true;
+ case Intrinsic::dbg_addr:
+ case Intrinsic::dbg_declare:
+ case Intrinsic::dbg_label:
+ case Intrinsic::dbg_value:
+ llvm_unreachable("Intrinsic should not be modeled in MemorySSA");
+ default:
+ return false;
+ }
+ }
+ return false;
+}
+
+// Check if we can ignore \p D for DSE.
+bool canSkipDef(MemoryDef *D, bool DefVisibleToCaller) {
+ Instruction *DI = D->getMemoryInst();
+ // Calls that only access inaccessible memory cannot read or write any memory
+ // locations we consider for elimination.
+ if (auto *CB = dyn_cast<CallBase>(DI))
+ if (CB->onlyAccessesInaccessibleMemory())
+ return true;
+
+ // We can eliminate stores to locations not visible to the caller across
+ // throwing instructions.
+ if (DI->mayThrow() && !DefVisibleToCaller)
+ return true;
+
+ // We can remove the dead stores, irrespective of the fence and its ordering
+ // (release/acquire/seq_cst). Fences only constrain the ordering of
+ // already visible stores; they do not make a store visible to other
+ // threads. So, skipping over a fence does not change a store from being
+ // dead.
+ if (isa<FenceInst>(DI))
+ return true;
+
+ // Skip intrinsics that do not really read or modify memory.
+ if (isNoopIntrinsic(D))
+ return true;
+
+ return false;
+}
+
+struct DSEState {
+ Function &F;
+ AliasAnalysis &AA;
+ MemorySSA &MSSA;
+ DominatorTree &DT;
+ PostDominatorTree &PDT;
+ const TargetLibraryInfo &TLI;
+
+ // All MemoryDefs that potentially could kill other MemDefs.
+ SmallVector<MemoryDef *, 64> MemDefs;
+ // Any that should be skipped as they are already deleted
+ SmallPtrSet<MemoryAccess *, 4> SkipStores;
+ // Keep track of all of the objects that are invisible to the caller before
+ // the function returns.
+ SmallPtrSet<const Value *, 16> InvisibleToCallerBeforeRet;
+ // Keep track of all of the objects that are invisible to the caller after
+ // the function returns.
+ SmallPtrSet<const Value *, 16> InvisibleToCallerAfterRet;
+ // Keep track of blocks with throwing instructions not modeled in MemorySSA.
+ SmallPtrSet<BasicBlock *, 16> ThrowingBlocks;
+ // Post-order numbers for each basic block. Used to figure out if memory
+ // accesses are executed before another access.
+ DenseMap<BasicBlock *, unsigned> PostOrderNumbers;
+
+ /// Keep track of instructions (partly) overlapping with killing MemoryDefs per
+ /// basic block.
+ DenseMap<BasicBlock *, InstOverlapIntervalsTy> IOLs;
+
+ DSEState(Function &F, AliasAnalysis &AA, MemorySSA &MSSA, DominatorTree &DT,
+ PostDominatorTree &PDT, const TargetLibraryInfo &TLI)
+ : F(F), AA(AA), MSSA(MSSA), DT(DT), PDT(PDT), TLI(TLI) {}
+
+ static DSEState get(Function &F, AliasAnalysis &AA, MemorySSA &MSSA,
+ DominatorTree &DT, PostDominatorTree &PDT,
+ const TargetLibraryInfo &TLI) {
+ DSEState State(F, AA, MSSA, DT, PDT, TLI);
+ // Collect blocks with throwing instructions not modeled in MemorySSA and
+ // alloc-like objects.
+ unsigned PO = 0;
+ for (BasicBlock *BB : post_order(&F)) {
+ State.PostOrderNumbers[BB] = PO++;
+ for (Instruction &I : *BB) {
+ MemoryAccess *MA = MSSA.getMemoryAccess(&I);
+ if (I.mayThrow() && !MA)
+ State.ThrowingBlocks.insert(I.getParent());
+
+ auto *MD = dyn_cast_or_null<MemoryDef>(MA);
+ if (MD && State.MemDefs.size() < MemorySSADefsPerBlockLimit &&
+ (State.getLocForWriteEx(&I) || State.isMemTerminatorInst(&I)))
+ State.MemDefs.push_back(MD);
+
+ // Track whether alloca and alloca-like objects are visible in the
+ // caller before and after the function returns. Alloca objects are
+ // invalid in the caller, so they are neither visible before nor after
+ // the function returns.
+ if (isa<AllocaInst>(&I)) {
+ State.InvisibleToCallerBeforeRet.insert(&I);
+ State.InvisibleToCallerAfterRet.insert(&I);
+ }
+
+ // For alloca-like objects we need to check if they are captured before
+ // the function returns and if the return might capture the object.
+ if (isAllocLikeFn(&I, &TLI)) {
+ bool CapturesBeforeRet = PointerMayBeCaptured(&I, false, true);
+ if (!CapturesBeforeRet) {
+ State.InvisibleToCallerBeforeRet.insert(&I);
+ if (!PointerMayBeCaptured(&I, true, false))
+ State.InvisibleToCallerAfterRet.insert(&I);
+ }
+ }
+ }
+ }
+
+ // Treat byval or inalloca arguments the same as Allocas, stores to them are
+ // dead at the end of the function.
+ for (Argument &AI : F.args())
+ if (AI.hasPassPointeeByValueAttr()) {
+ // For byval, the caller doesn't know the address of the allocation.
+ if (AI.hasByValAttr())
+ State.InvisibleToCallerBeforeRet.insert(&AI);
+ State.InvisibleToCallerAfterRet.insert(&AI);
+ }
+
+ return State;
+ }
+
+ Optional<MemoryLocation> getLocForWriteEx(Instruction *I) const {
+ if (!I->mayWriteToMemory())
+ return None;
+
+ if (auto *MTI = dyn_cast<AnyMemIntrinsic>(I))
+ return {MemoryLocation::getForDest(MTI)};
+
+ if (auto *CB = dyn_cast<CallBase>(I)) {
+ LibFunc LF;
+ if (TLI.getLibFunc(*CB, LF) && TLI.has(LF)) {
+ switch (LF) {
+ case LibFunc_strcpy:
+ case LibFunc_strncpy:
+ case LibFunc_strcat:
+ case LibFunc_strncat:
+ return {MemoryLocation(CB->getArgOperand(0))};
+ default:
+ break;
+ }
+ }
+ return None;
+ }
+
+ return MemoryLocation::getOrNone(I);
+ }
+
+ /// Returns true if \p UseInst completely overwrites \p DefLoc.
+ bool isCompleteOverwrite(MemoryLocation DefLoc, Instruction *UseInst) const {
+ // UseInst has a MemoryDef associated in MemorySSA. It's possible for a
+ // MemoryDef to not write to memory, e.g. a volatile load is modeled as a
+ // MemoryDef.
+ if (!UseInst->mayWriteToMemory())
+ return false;
+
+ if (auto *CB = dyn_cast<CallBase>(UseInst))
+ if (CB->onlyAccessesInaccessibleMemory())
+ return false;
+
+ int64_t InstWriteOffset, DepWriteOffset;
+ auto CC = getLocForWriteEx(UseInst);
+ InstOverlapIntervalsTy IOL;
+
+ const DataLayout &DL = F.getParent()->getDataLayout();
+
+ return CC &&
+ isOverwrite(*CC, DefLoc, DL, TLI, DepWriteOffset, InstWriteOffset,
+ UseInst, IOL, AA, &F) == OW_Complete;
+ }
+
+ /// Returns true if \p Def is not read before returning from the function.
+ bool isWriteAtEndOfFunction(MemoryDef *Def) {
+ LLVM_DEBUG(dbgs() << " Check if def " << *Def << " ("
+ << *Def->getMemoryInst()
+ << ") is at the end of the function\n");
+
+ auto MaybeLoc = getLocForWriteEx(Def->getMemoryInst());
+ if (!MaybeLoc) {
+ LLVM_DEBUG(dbgs() << " ... could not get location for write.\n");
+ return false;
+ }
+
+ SmallVector<MemoryAccess *, 4> WorkList;
+ SmallPtrSet<MemoryAccess *, 8> Visited;
+ auto PushMemUses = [&WorkList, &Visited](MemoryAccess *Acc) {
+ if (!Visited.insert(Acc).second)
+ return;
+ for (Use &U : Acc->uses())
+ WorkList.push_back(cast<MemoryAccess>(U.getUser()));
+ };
+ PushMemUses(Def);
+ for (unsigned I = 0; I < WorkList.size(); I++) {
+ if (WorkList.size() >= MemorySSAScanLimit) {
+ LLVM_DEBUG(dbgs() << " ... hit exploration limit.\n");
+ return false;
+ }
+
+ MemoryAccess *UseAccess = WorkList[I];
+ if (isa<MemoryPhi>(UseAccess)) {
+ PushMemUses(UseAccess);
+ continue;
+ }
+
+ // TODO: Checking for aliasing is expensive. Consider reducing the amount
+ // of times this is called and/or caching it.
+ Instruction *UseInst = cast<MemoryUseOrDef>(UseAccess)->getMemoryInst();
+ if (isReadClobber(*MaybeLoc, UseInst)) {
+ LLVM_DEBUG(dbgs() << " ... hit read clobber " << *UseInst << ".\n");
+ return false;
+ }
+
+ if (MemoryDef *UseDef = dyn_cast<MemoryDef>(UseAccess))
+ PushMemUses(UseDef);
+ }
+ return true;
+ }
+
+ /// If \p I is a memory terminator like llvm.lifetime.end or free, return a
+ /// pair with the MemoryLocation terminated by \p I and a boolean flag
+ /// indicating whether \p I is a free-like call.
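+ /// For example (illustrative), both of these terminate the lifetime of %p:
+ ///   call void @llvm.lifetime.end.p0i8(i64 4, i8* %p)
+ ///   call void @free(i8* %p)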
+ Optional<std::pair<MemoryLocation, bool>>
+ getLocForTerminator(Instruction *I) const {
+ uint64_t Len;
+ Value *Ptr;
+ if (match(I, m_Intrinsic<Intrinsic::lifetime_end>(m_ConstantInt(Len),
+ m_Value(Ptr))))
+ return {std::make_pair(MemoryLocation(Ptr, Len), false)};
+
+ if (auto *CB = dyn_cast<CallBase>(I)) {
+ if (isFreeCall(I, &TLI))
+ return {std::make_pair(MemoryLocation(CB->getArgOperand(0)), true)};
+ }
+
+ return None;
+ }
+
+ /// Returns true if \p I is a memory terminator instruction like
+ /// llvm.lifetime.end or free.
+ bool isMemTerminatorInst(Instruction *I) const {
+ IntrinsicInst *II = dyn_cast<IntrinsicInst>(I);
+ return (II && II->getIntrinsicID() == Intrinsic::lifetime_end) ||
+ isFreeCall(I, &TLI);
+ }
+
+ /// Returns true if \p MaybeTerm is a memory terminator for the same
+ /// underlying object as \p DefLoc.
+ bool isMemTerminator(MemoryLocation DefLoc, Instruction *MaybeTerm) const {
+ Optional<std::pair<MemoryLocation, bool>> MaybeTermLoc =
+ getLocForTerminator(MaybeTerm);
+
+ if (!MaybeTermLoc)
+ return false;
+
+ // If the terminator is a free-like call, all accesses to the underlying
+ // object can be considered terminated.
+ if (MaybeTermLoc->second) {
+ DataLayout DL = MaybeTerm->getParent()->getModule()->getDataLayout();
+ DefLoc = MemoryLocation(GetUnderlyingObject(DefLoc.Ptr, DL));
+ }
+ return AA.isMustAlias(MaybeTermLoc->first, DefLoc);
+ }
+
+ // Returns true if \p UseInst may read from \p DefLoc.
+ bool isReadClobber(MemoryLocation DefLoc, Instruction *UseInst) const {
+ if (!UseInst->mayReadFromMemory())
+ return false;
+
+ if (auto *CB = dyn_cast<CallBase>(UseInst))
+ if (CB->onlyAccessesInaccessibleMemory())
+ return false;
+
+ ModRefInfo MR = AA.getModRefInfo(UseInst, DefLoc);
+ // If necessary, perform additional analysis.
+ if (isRefSet(MR))
+ MR = AA.callCapturesBefore(UseInst, DefLoc, &DT);
+ return isRefSet(MR);
+ }
+
+ // Find a MemoryDef writing to \p DefLoc and dominating \p Current, with no
+ // read access between them or on any other path to a function exit block if
+ // \p DefLoc is not accessible after the function returns. If there is no such
+ // MemoryDef, return None. The returned value may not (completely) overwrite
+ // \p DefLoc. Currently we bail out when we encounter an aliasing MemoryUse
+ // (read).
+ Optional<MemoryAccess *>
+ getDomMemoryDef(MemoryDef *KillingDef, MemoryAccess *Current,
+ MemoryLocation DefLoc, bool DefVisibleToCallerBeforeRet,
+ bool DefVisibleToCallerAfterRet, int &ScanLimit) const {
+ MemoryAccess *DomAccess;
+ bool StepAgain;
+ LLVM_DEBUG(dbgs() << " trying to get dominating access for " << *Current
+ << "\n");
+ // Find the next clobbering Mod access for DefLoc, starting at Current.
+ do {
+ StepAgain = false;
+ // Reached TOP.
+ if (MSSA.isLiveOnEntryDef(Current))
+ return None;
+
+ if (isa<MemoryPhi>(Current)) {
+ DomAccess = Current;
+ break;
+ }
+ MemoryUseOrDef *CurrentUD = cast<MemoryUseOrDef>(Current);
+ // Look for accesses that clobber DefLoc.
+ DomAccess = MSSA.getSkipSelfWalker()->getClobberingMemoryAccess(CurrentUD,
+ DefLoc);
+ if (MSSA.isLiveOnEntryDef(DomAccess))
+ return None;
+
+ if (isa<MemoryPhi>(DomAccess))
+ break;
+
+ // Check if we can skip DomDef for DSE.
+ MemoryDef *DomDef = dyn_cast<MemoryDef>(DomAccess);
+ if (DomDef && canSkipDef(DomDef, DefVisibleToCallerBeforeRet)) {
+ StepAgain = true;
+ Current = DomDef->getDefiningAccess();
+ }
+
+ } while (StepAgain);
+
+ // Accesses to objects accessible after the function returns can only be
+ // eliminated if the access is killed along all paths to the exit. Collect
+ // the blocks with killing (=completely overwriting MemoryDefs) and check if
+ // they cover all paths from DomAccess to any function exit.
+ SmallPtrSet<BasicBlock *, 16> KillingBlocks = {KillingDef->getBlock()};
+ LLVM_DEBUG({
+ dbgs() << " Checking for reads of " << *DomAccess;
+ if (isa<MemoryDef>(DomAccess))
+ dbgs() << " (" << *cast<MemoryDef>(DomAccess)->getMemoryInst() << ")\n";
+ else
+ dbgs() << "\n";
+ });
+
+ SmallSetVector<MemoryAccess *, 32> WorkList;
+ auto PushMemUses = [&WorkList](MemoryAccess *Acc) {
+ for (Use &U : Acc->uses())
+ WorkList.insert(cast<MemoryAccess>(U.getUser()));
+ };
+ PushMemUses(DomAccess);
+
+ // Check if DomDef may be read.
+ for (unsigned I = 0; I < WorkList.size(); I++) {
+ MemoryAccess *UseAccess = WorkList[I];
+
+ LLVM_DEBUG(dbgs() << " " << *UseAccess);
+ if (--ScanLimit == 0) {
+ LLVM_DEBUG(dbgs() << "\n ... hit scan limit\n");
+ return None;
+ }
+
+ if (isa<MemoryPhi>(UseAccess)) {
+ LLVM_DEBUG(dbgs() << "\n ... adding PHI uses\n");
+ PushMemUses(UseAccess);
+ continue;
+ }
+
+ Instruction *UseInst = cast<MemoryUseOrDef>(UseAccess)->getMemoryInst();
+ LLVM_DEBUG(dbgs() << " (" << *UseInst << ")\n");
+
+ if (isNoopIntrinsic(cast<MemoryUseOrDef>(UseAccess))) {
+ LLVM_DEBUG(dbgs() << " ... adding uses of intrinsic\n");
+ PushMemUses(UseAccess);
+ continue;
+ }
+
+ // A memory terminator kills all preceding MemoryDefs and all succeeding
+ // MemoryAccesses. We do not have to check its users.
+ if (isMemTerminator(DefLoc, UseInst))
+ continue;
+
+ // Uses which may read the original MemoryDef mean we cannot eliminate the
+ // original MD. Stop walk.
+ if (isReadClobber(DefLoc, UseInst)) {
+ LLVM_DEBUG(dbgs() << " ... found read clobber\n");
+ return None;
+ }
+
+ // For the KillingDef and DomAccess we only have to check if it reads the
+ // memory location.
+ // TODO: It would probably be better to check for self-reads before
+ // calling the function.
+ if (KillingDef == UseAccess || DomAccess == UseAccess) {
+ LLVM_DEBUG(dbgs() << " ... skipping killing def/dom access\n");
+ continue;
+ }
+
+ // Check all uses for MemoryDefs, except for defs completely overwriting
+ // the original location. Otherwise we have to check uses of *all*
+ // MemoryDefs we discover, including non-aliasing ones. Otherwise we might
+ // miss cases like the following
+ // 1 = Def(LoE) ; <----- DomDef stores [0,1]
+ // 2 = Def(1) ; (2, 1) = NoAlias, stores [2,3]
+ // Use(2) ; MayAlias 2 *and* 1, loads [0, 3].
+ // (The Use points to the *first* Def it may alias)
+ // 3 = Def(1) ; <---- Current (3, 2) = NoAlias, (3,1) = MayAlias,
+ // stores [0,1]
+ if (MemoryDef *UseDef = dyn_cast<MemoryDef>(UseAccess)) {
+ if (isCompleteOverwrite(DefLoc, UseInst)) {
+ if (DefVisibleToCallerAfterRet && UseAccess != DomAccess) {
+ BasicBlock *MaybeKillingBlock = UseInst->getParent();
+ if (PostOrderNumbers.find(MaybeKillingBlock)->second <
+ PostOrderNumbers.find(DomAccess->getBlock())->second) {
+
+ LLVM_DEBUG(dbgs() << " ... found killing block "
+ << MaybeKillingBlock->getName() << "\n");
+ KillingBlocks.insert(MaybeKillingBlock);
+ }
+ }
+ } else
+ PushMemUses(UseDef);
+ }
+ }
+
+ // For accesses to locations visible after the function returns, make sure
+ // that the location is killed (=overwritten) along all paths from DomAccess
+ // to the exit.
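+ //
+ // Sketch of the case this guards against (illustrative, assuming a global
+ // @g that remains visible after the function returns):
+ //   entry:  store i32 0, i32* @g      ; DomAccess
+ //           br i1 %c, label %kill, label %exit
+ //   kill:   store i32 1, i32* @g      ; killing block
+ //           ret void
+ //   exit:   ret void                  ; @g left un-overwritten
+ // The killing block does not post-dominate DomAccess, so the store in
+ // entry must be kept.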
+ if (DefVisibleToCallerAfterRet) {
+ assert(!KillingBlocks.empty() &&
+ "Expected at least a single killing block");
+ // Find the common post-dominator of all killing blocks.
+ BasicBlock *CommonPred = *KillingBlocks.begin();
+ for (auto I = std::next(KillingBlocks.begin()), E = KillingBlocks.end();
+ I != E; I++) {
+ if (!CommonPred)
+ break;
+ CommonPred = PDT.findNearestCommonDominator(CommonPred, *I);
+ }
+
+ // If CommonPred is in the set of killing blocks, just check if it
+ // post-dominates DomAccess.
+ if (KillingBlocks.count(CommonPred)) {
+ if (PDT.dominates(CommonPred, DomAccess->getBlock()))
+ return {DomAccess};
+ return None;
+ }
+
+ // If the common post-dominator does not post-dominate DomAccess, there
+ // is a path from DomAccess to an exit not going through a killing block.
+ if (PDT.dominates(CommonPred, DomAccess->getBlock())) {
+ SetVector<BasicBlock *> WorkList;
+
+ // DomAccess's post-order number provides an upper bound of the blocks
+ // on a path starting at DomAccess.
+ unsigned UpperBound =
+ PostOrderNumbers.find(DomAccess->getBlock())->second;
+
+ // If CommonPred is null, there are multiple exits from the function.
+ // They all have to be added to the worklist.
+ if (CommonPred)
+ WorkList.insert(CommonPred);
+ else
+ for (BasicBlock *R : PDT.roots())
+ WorkList.insert(R);
+
+ NumCFGTries++;
+ // Check if all paths starting from an exit node go through one of the
+ // killing blocks before reaching DomAccess.
+ for (unsigned I = 0; I < WorkList.size(); I++) {
+ NumCFGChecks++;
+ BasicBlock *Current = WorkList[I];
+ if (KillingBlocks.count(Current))
+ continue;
+ if (Current == DomAccess->getBlock())
+ return None;
+
+ // DomAccess is reachable from the entry, so we don't have to explore
+ // unreachable blocks further.
+ if (!DT.isReachableFromEntry(Current))
+ continue;
+
+ unsigned CPO = PostOrderNumbers.find(Current)->second;
+ // Current block is not on a path starting at DomAccess.
+ if (CPO > UpperBound)
+ continue;
+ for (BasicBlock *Pred : predecessors(Current))
+ WorkList.insert(Pred);
+
+ if (WorkList.size() >= MemorySSAPathCheckLimit)
+ return None;
+ }
+ NumCFGSuccess++;
+ return {DomAccess};
+ }
+ return None;
+ }
+
+ // No aliasing MemoryUses of DomAccess found, DomAccess is potentially dead.
+ return {DomAccess};
+ }
+
+ // Delete dead memory defs
+ void deleteDeadInstruction(Instruction *SI) {
+ MemorySSAUpdater Updater(&MSSA);
+ SmallVector<Instruction *, 32> NowDeadInsts;
+ NowDeadInsts.push_back(SI);
+ --NumFastOther;
+
+ while (!NowDeadInsts.empty()) {
+ Instruction *DeadInst = NowDeadInsts.pop_back_val();
+ ++NumFastOther;
+
+ // Try to preserve debug information attached to the dead instruction.
+ salvageDebugInfo(*DeadInst);
+ salvageKnowledge(DeadInst);
+
+ // Remove the Instruction from MSSA.
+ if (MemoryAccess *MA = MSSA.getMemoryAccess(DeadInst)) {
+ if (MemoryDef *MD = dyn_cast<MemoryDef>(MA)) {
+ SkipStores.insert(MD);
+ }
+ Updater.removeMemoryAccess(MA);
+ }
+
+ auto I = IOLs.find(DeadInst->getParent());
+ if (I != IOLs.end())
+ I->second.erase(DeadInst);
+ // Remove its operands
+ for (Use &O : DeadInst->operands())
+ if (Instruction *OpI = dyn_cast<Instruction>(O)) {
+ O = nullptr;
+ if (isInstructionTriviallyDead(OpI, &TLI))
+ NowDeadInsts.push_back(OpI);
+ }
+
+ DeadInst->eraseFromParent();
+ }
+ }
+
+ // Check for any extra throws between SI and NI that block DSE. This only
+ // checks extra maythrows (those that aren't MemoryDefs). MemoryDefs that may
+ // throw are handled during the walk from one def to the next.
+ bool mayThrowBetween(Instruction *SI, Instruction *NI,
+ const Value *SILocUnd) const {
+ // First see if we can ignore it by using the fact that SI is an
+ // alloca/alloca like object that is not visible to the caller during
+ // execution of the function.
+ if (SILocUnd && InvisibleToCallerBeforeRet.count(SILocUnd))
+ return false;
+
+ if (SI->getParent() == NI->getParent())
+ return ThrowingBlocks.count(SI->getParent());
+ return !ThrowingBlocks.empty();
+ }
+
+ // Check if \p NI acts as a DSE barrier for \p SI. The following instructions
+ // act as barriers:
+ // * A memory instruction that may throw and \p SI accesses a non-stack
+ // object.
+ // * Atomic stores stronger than monotonic.
+ bool isDSEBarrier(const Value *SILocUnd, Instruction *NI) const {
+ // If NI may throw it acts as a barrier, unless SI accesses an alloca or
+ // alloca-like object that does not escape.
+ if (NI->mayThrow() && !InvisibleToCallerBeforeRet.count(SILocUnd))
+ return true;
+
+ // If NI is an atomic load/store stronger than monotonic, do not try to
+ // eliminate/reorder it.
+ if (NI->isAtomic()) {
+ if (auto *LI = dyn_cast<LoadInst>(NI))
+ return isStrongerThanMonotonic(LI->getOrdering());
+ if (auto *SI = dyn_cast<StoreInst>(NI))
+ return isStrongerThanMonotonic(SI->getOrdering());
+ llvm_unreachable("other instructions should be skipped in MemorySSA");
+ }
+ return false;
+ }
+
+ /// Eliminate writes to objects that are not visible in the caller and are not
+ /// accessed before returning from the function.
+ bool eliminateDeadWritesAtEndOfFunction() {
+ const DataLayout &DL = F.getParent()->getDataLayout();
+ bool MadeChange = false;
+ LLVM_DEBUG(
+ dbgs()
+ << "Trying to eliminate MemoryDefs at the end of the function\n");
+ for (int I = MemDefs.size() - 1; I >= 0; I--) {
+ MemoryDef *Def = MemDefs[I];
+ if (SkipStores.find(Def) != SkipStores.end() ||
+ !isRemovable(Def->getMemoryInst()))
+ continue;
+
+ // TODO: Consider doing the underlying object check first, if it is
+ // beneficial compile-time wise.
+ if (isWriteAtEndOfFunction(Def)) {
+ Instruction *DefI = Def->getMemoryInst();
+ // See through pointer-to-pointer bitcasts
+ SmallVector<const Value *, 4> Pointers;
+ GetUnderlyingObjects(getLocForWriteEx(DefI)->Ptr, Pointers, DL);
+
+ LLVM_DEBUG(dbgs() << " ... MemoryDef is not accessed until the end "
+ "of the function\n");
+ bool CanKill = true;
+ for (const Value *Pointer : Pointers) {
+ if (!InvisibleToCallerAfterRet.count(Pointer)) {
+ CanKill = false;
+ break;
+ }
+ }
+
+ if (CanKill) {
+ deleteDeadInstruction(DefI);
+ ++NumFastStores;
+ MadeChange = true;
+ }
+ }
+ }
+ return MadeChange;
+ }
+
+ /// \returns true if \p Def is a no-op store, either because it
+ /// directly stores back a loaded value or stores zero to a calloced object.
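+ /// For example (an illustrative sketch, not a specific test case):
+ ///   %v = load i32, i32* %p
+ ///   store i32 %v, i32* %p     ; no-op if both share the defining access
+ /// or
+ ///   %m = call i8* @calloc(i64 1, i64 4)
+ ///   store i8 0, i8* %m        ; no-op, calloc already zero-initializes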
+ bool storeIsNoop(MemoryDef *Def, MemoryLocation DefLoc, const Value *DefUO) {
+ StoreInst *Store = dyn_cast<StoreInst>(Def->getMemoryInst());
+ if (!Store)
+ return false;
+
+ if (auto *LoadI = dyn_cast<LoadInst>(Store->getOperand(0))) {
+ if (LoadI->getPointerOperand() == Store->getOperand(1)) {
+ auto *LoadAccess = MSSA.getMemoryAccess(LoadI)->getDefiningAccess();
+ // If both accesses share the same defining access, no instructions
+ // between them can modify the memory location.
+ return LoadAccess == Def->getDefiningAccess();
+ }
+ }
+
+ Constant *StoredConstant = dyn_cast<Constant>(Store->getOperand(0));
+ if (StoredConstant && StoredConstant->isNullValue()) {
+ auto *DefUOInst = dyn_cast<Instruction>(DefUO);
+ if (DefUOInst && isCallocLikeFn(DefUOInst, &TLI)) {
+ auto *UnderlyingDef = cast<MemoryDef>(MSSA.getMemoryAccess(DefUOInst));
+ // If UnderlyingDef is the clobbering access of Def, no instructions
+ // between them can modify the memory location.
+ auto *ClobberDef =
+ MSSA.getSkipSelfWalker()->getClobberingMemoryAccess(Def);
+ return UnderlyingDef == ClobberDef;
+ }
+ }
+ return false;
+ }
+};
+
+bool eliminateDeadStoresMemorySSA(Function &F, AliasAnalysis &AA,
+ MemorySSA &MSSA, DominatorTree &DT,
+ PostDominatorTree &PDT,
+ const TargetLibraryInfo &TLI) {
+ const DataLayout &DL = F.getParent()->getDataLayout();
+ bool MadeChange = false;
+
+ DSEState State = DSEState::get(F, AA, MSSA, DT, PDT, TLI);
+ // For each store:
+ for (unsigned I = 0; I < State.MemDefs.size(); I++) {
+ MemoryDef *KillingDef = State.MemDefs[I];
+ if (State.SkipStores.count(KillingDef))
+ continue;
+ Instruction *SI = KillingDef->getMemoryInst();
+
+ auto MaybeSILoc = State.getLocForWriteEx(SI);
+ if (State.isMemTerminatorInst(SI))
+ MaybeSILoc = State.getLocForTerminator(SI).map(
+ [](const std::pair<MemoryLocation, bool> &P) { return P.first; });
+ else
+ MaybeSILoc = State.getLocForWriteEx(SI);
+
+ if (!MaybeSILoc) {
+ LLVM_DEBUG(dbgs() << "Failed to find analyzable write location for "
+ << *SI << "\n");
+ continue;
+ }
+ MemoryLocation SILoc = *MaybeSILoc;
+ assert(SILoc.Ptr && "SILoc should not be null");
+ const Value *SILocUnd = GetUnderlyingObject(SILoc.Ptr, DL);
+
+ // Check if the store is a no-op.
+ if (isRemovable(SI) && State.storeIsNoop(KillingDef, SILoc, SILocUnd)) {
+ LLVM_DEBUG(dbgs() << "DSE: Remove No-Op Store:\n DEAD: " << *SI << '\n');
+ State.deleteDeadInstruction(SI);
+ NumNoopStores++;
+ MadeChange = true;
+ continue;
+ }
+
+ Instruction *DefObj =
+ const_cast<Instruction *>(dyn_cast<Instruction>(SILocUnd));
+ bool DefVisibleToCallerBeforeRet =
+ !State.InvisibleToCallerBeforeRet.count(SILocUnd);
+ bool DefVisibleToCallerAfterRet =
+ !State.InvisibleToCallerAfterRet.count(SILocUnd);
+ if (DefObj && isAllocLikeFn(DefObj, &TLI)) {
+ if (DefVisibleToCallerBeforeRet)
+ DefVisibleToCallerBeforeRet =
+ PointerMayBeCapturedBefore(DefObj, false, true, SI, &DT);
+ }
+
+ MemoryAccess *Current = KillingDef;
+ LLVM_DEBUG(dbgs() << "Trying to eliminate MemoryDefs killed by "
+ << *KillingDef << " (" << *SI << ")\n");
+
+ int ScanLimit = MemorySSAScanLimit;
+ // Worklist of MemoryAccesses that may be killed by KillingDef.
+ SetVector<MemoryAccess *> ToCheck;
+ ToCheck.insert(KillingDef->getDefiningAccess());
+
+ // Check if MemoryAccesses in the worklist are killed by KillingDef.
+ for (unsigned I = 0; I < ToCheck.size(); I++) {
+ Current = ToCheck[I];
+ if (State.SkipStores.count(Current))
+ continue;
+
+ Optional<MemoryAccess *> Next = State.getDomMemoryDef(
+ KillingDef, Current, SILoc, DefVisibleToCallerBeforeRet,
+ DefVisibleToCallerAfterRet, ScanLimit);
+
+ if (!Next) {
+ LLVM_DEBUG(dbgs() << " finished walk\n");
+ continue;
+ }
+
+ MemoryAccess *DomAccess = *Next;
+ LLVM_DEBUG(dbgs() << " Checking if we can kill " << *DomAccess);
+ if (isa<MemoryPhi>(DomAccess)) {
+ LLVM_DEBUG(dbgs() << "\n ... adding incoming values to worklist\n");
+ for (Value *V : cast<MemoryPhi>(DomAccess)->incoming_values()) {
+ MemoryAccess *IncomingAccess = cast<MemoryAccess>(V);
+ BasicBlock *IncomingBlock = IncomingAccess->getBlock();
+ BasicBlock *PhiBlock = DomAccess->getBlock();
+
+ // We only consider incoming MemoryAccesses that come before the
+ // MemoryPhi. Otherwise we could discover candidates that do not
+ // strictly dominate our starting def.
+ if (State.PostOrderNumbers[IncomingBlock] >
+ State.PostOrderNumbers[PhiBlock])
+ ToCheck.insert(IncomingAccess);
+ }
+ continue;
+ }
+ MemoryDef *NextDef = dyn_cast<MemoryDef>(DomAccess);
+ Instruction *NI = NextDef->getMemoryInst();
+ LLVM_DEBUG(dbgs() << " (" << *NI << ")\n");
+
+ // Before we try to remove anything, check for any extra throwing
+ // instructions that block us from DSEing
+ if (State.mayThrowBetween(SI, NI, SILocUnd)) {
+ LLVM_DEBUG(dbgs() << " ... skip, may throw!\n");
+ break;
+ }
+
+ // Check for anything that looks like it will be a barrier to further
+ // removal
+ if (State.isDSEBarrier(SILocUnd, NI)) {
+ LLVM_DEBUG(dbgs() << " ... skip, barrier\n");
+ continue;
+ }
+
+ ToCheck.insert(NextDef->getDefiningAccess());
+
+ if (!hasAnalyzableMemoryWrite(NI, TLI)) {
+ LLVM_DEBUG(dbgs() << " ... skip, cannot analyze def\n");
+ continue;
+ }
+
+ if (!isRemovable(NI)) {
+ LLVM_DEBUG(dbgs() << " ... skip, cannot remove def\n");
+ continue;
+ }
+
+ if (!DebugCounter::shouldExecute(MemorySSACounter))
+ continue;
+
+ MemoryLocation NILoc = *State.getLocForWriteEx(NI);
+
+ if (State.isMemTerminatorInst(SI)) {
+ const Value *NIUnd = GetUnderlyingObject(NILoc.Ptr, DL);
+ if (!SILocUnd || SILocUnd != NIUnd)
+ continue;
+ LLVM_DEBUG(dbgs() << "DSE: Remove Dead Store:\n DEAD: " << *NI
+ << "\n KILLER: " << *SI << '\n');
+ State.deleteDeadInstruction(NI);
+ ++NumFastStores;
+ MadeChange = true;
+ } else {
+ // Check if NI overwrites SI.
+ int64_t InstWriteOffset, DepWriteOffset;
+ auto Iter = State.IOLs.insert(
+ std::make_pair<BasicBlock *, InstOverlapIntervalsTy>(
+ NI->getParent(), InstOverlapIntervalsTy()));
+ auto &IOL = Iter.first->second;
+ OverwriteResult OR = isOverwrite(SILoc, NILoc, DL, TLI, DepWriteOffset,
+ InstWriteOffset, NI, IOL, AA, &F);
+
+ if (EnablePartialStoreMerging && OR == OW_PartialEarlierWithFullLater) {
+ auto *Earlier = dyn_cast<StoreInst>(NI);
+ auto *Later = dyn_cast<StoreInst>(SI);
+ if (Constant *Merged = tryToMergePartialOverlappingStores(
+ Earlier, Later, InstWriteOffset, DepWriteOffset, DL, &AA,
+ &DT)) {
+
+ // Update stored value of earlier store to merged constant.
+ Earlier->setOperand(0, Merged);
+ ++NumModifiedStores;
+ MadeChange = true;
+
+ // Remove later store and remove any outstanding overlap intervals
+ // for the updated store.
+ State.deleteDeadInstruction(Later);
+ auto I = State.IOLs.find(Earlier->getParent());
+ if (I != State.IOLs.end())
+ I->second.erase(Earlier);
+ break;
+ }
+ }
+
+ if (OR == OW_Complete) {
+ LLVM_DEBUG(dbgs() << "DSE: Remove Dead Store:\n DEAD: " << *NI
+ << "\n KILLER: " << *SI << '\n');
+ State.deleteDeadInstruction(NI);
+ ++NumFastStores;
+ MadeChange = true;
+ }
+ }
+ }
+ }
+
+ if (EnablePartialOverwriteTracking)
+ for (auto &KV : State.IOLs)
+ MadeChange |= removePartiallyOverlappedStores(&AA, DL, KV.second);
+
+ MadeChange |= State.eliminateDeadWritesAtEndOfFunction();
+ return MadeChange;
+}
+} // end anonymous namespace
+
//===----------------------------------------------------------------------===//
// DSE Pass
//===----------------------------------------------------------------------===//
PreservedAnalyses DSEPass::run(Function &F, FunctionAnalysisManager &AM) {
- AliasAnalysis *AA = &AM.getResult<AAManager>(F);
- DominatorTree *DT = &AM.getResult<DominatorTreeAnalysis>(F);
- MemoryDependenceResults *MD = &AM.getResult<MemoryDependenceAnalysis>(F);
- const TargetLibraryInfo *TLI = &AM.getResult<TargetLibraryAnalysis>(F);
+ AliasAnalysis &AA = AM.getResult<AAManager>(F);
+ const TargetLibraryInfo &TLI = AM.getResult<TargetLibraryAnalysis>(F);
+ DominatorTree &DT = AM.getResult<DominatorTreeAnalysis>(F);
+
+ bool Changed = false;
+ if (EnableMemorySSA) {
+ MemorySSA &MSSA = AM.getResult<MemorySSAAnalysis>(F).getMSSA();
+ PostDominatorTree &PDT = AM.getResult<PostDominatorTreeAnalysis>(F);
- if (!eliminateDeadStores(F, AA, MD, DT, TLI))
+ Changed = eliminateDeadStoresMemorySSA(F, AA, MSSA, DT, PDT, TLI);
+ } else {
+ MemoryDependenceResults &MD = AM.getResult<MemoryDependenceAnalysis>(F);
+
+ Changed = eliminateDeadStores(F, &AA, &MD, &DT, &TLI);
+ }
+
+#ifdef LLVM_ENABLE_STATS
+ if (AreStatisticsEnabled())
+ for (auto &I : instructions(F))
+ NumRemainingStores += isa<StoreInst>(&I);
+#endif
+
+ if (!Changed)
return PreservedAnalyses::all();
PreservedAnalyses PA;
PA.preserveSet<CFGAnalyses>();
PA.preserve<GlobalsAA>();
- PA.preserve<MemoryDependenceAnalysis>();
+ if (EnableMemorySSA)
+ PA.preserve<MemorySSAAnalysis>();
+ else
+ PA.preserve<MemoryDependenceAnalysis>();
return PA;
}
@@ -1383,25 +2339,51 @@ public:
if (skipFunction(F))
return false;
- DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
- AliasAnalysis *AA = &getAnalysis<AAResultsWrapperPass>().getAAResults();
- MemoryDependenceResults *MD =
- &getAnalysis<MemoryDependenceWrapperPass>().getMemDep();
- const TargetLibraryInfo *TLI =
- &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
+ AliasAnalysis &AA = getAnalysis<AAResultsWrapperPass>().getAAResults();
+ DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
+ const TargetLibraryInfo &TLI =
+ getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
+
+ bool Changed = false;
+ if (EnableMemorySSA) {
+ MemorySSA &MSSA = getAnalysis<MemorySSAWrapperPass>().getMSSA();
+ PostDominatorTree &PDT =
+ getAnalysis<PostDominatorTreeWrapperPass>().getPostDomTree();
+
+ Changed = eliminateDeadStoresMemorySSA(F, AA, MSSA, DT, PDT, TLI);
+ } else {
+ MemoryDependenceResults &MD =
+ getAnalysis<MemoryDependenceWrapperPass>().getMemDep();
+
+ Changed = eliminateDeadStores(F, &AA, &MD, &DT, &TLI);
+ }
- return eliminateDeadStores(F, AA, MD, DT, TLI);
+#ifdef LLVM_ENABLE_STATS
+ if (AreStatisticsEnabled())
+ for (auto &I : instructions(F))
+ NumRemainingStores += isa<StoreInst>(&I);
+#endif
+
+ return Changed;
}
void getAnalysisUsage(AnalysisUsage &AU) const override {
AU.setPreservesCFG();
- AU.addRequired<DominatorTreeWrapperPass>();
AU.addRequired<AAResultsWrapperPass>();
- AU.addRequired<MemoryDependenceWrapperPass>();
AU.addRequired<TargetLibraryInfoWrapperPass>();
- AU.addPreserved<DominatorTreeWrapperPass>();
AU.addPreserved<GlobalsAAWrapperPass>();
- AU.addPreserved<MemoryDependenceWrapperPass>();
+ AU.addRequired<DominatorTreeWrapperPass>();
+ AU.addPreserved<DominatorTreeWrapperPass>();
+
+ if (EnableMemorySSA) {
+ AU.addRequired<PostDominatorTreeWrapperPass>();
+ AU.addRequired<MemorySSAWrapperPass>();
+ AU.addPreserved<PostDominatorTreeWrapperPass>();
+ AU.addPreserved<MemorySSAWrapperPass>();
+ } else {
+ AU.addRequired<MemoryDependenceWrapperPass>();
+ AU.addPreserved<MemoryDependenceWrapperPass>();
+ }
}
};
@@ -1412,8 +2394,10 @@ char DSELegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(DSELegacyPass, "dse", "Dead Store Elimination", false,
false)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(PostDominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass)
INITIALIZE_PASS_DEPENDENCY(GlobalsAAWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(MemorySSAWrapperPass)
INITIALIZE_PASS_DEPENDENCY(MemoryDependenceWrapperPass)
INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
INITIALIZE_PASS_END(DSELegacyPass, "dse", "Dead Store Elimination", false,
diff --git a/llvm/lib/Transforms/Scalar/DivRemPairs.cpp b/llvm/lib/Transforms/Scalar/DivRemPairs.cpp
index 132dfc8f6da1..d44a5979a8b2 100644
--- a/llvm/lib/Transforms/Scalar/DivRemPairs.cpp
+++ b/llvm/lib/Transforms/Scalar/DivRemPairs.cpp
@@ -17,6 +17,7 @@
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/Analysis/TargetTransformInfo.h"
+#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/PatternMatch.h"
@@ -71,6 +72,7 @@ static llvm::Optional<ExpandedMatch> matchExpandedRem(Instruction &I) {
return M;
}
+namespace {
/// A thin wrapper to store two values that we matched as div-rem pair.
/// We want this extra indirection to avoid dealing with RAUW'ing the map keys.
struct DivRemPairWorklistEntry {
@@ -111,6 +113,7 @@ struct DivRemPairWorklistEntry {
}
}
};
+} // namespace
using DivRemWorklistTy = SmallVector<DivRemPairWorklistEntry, 4>;
/// Find matching pairs of integer div/rem ops (they have the same numerator,
@@ -218,6 +221,7 @@ static bool optimizeDivRem(Function &F, const TargetTransformInfo &TTI,
NumRecomposed++;
// Note that we have left ((X / Y) * Y) around.
// If it had other uses we could rewrite it as X - X % Y
+ Changed = true;
}
assert((!E.isRemExpanded() || !HasDivRemOp) &&
@@ -301,6 +305,29 @@ static bool optimizeDivRem(Function &F, const TargetTransformInfo &TTI,
Mul->insertAfter(RemInst);
Sub->insertAfter(Mul);
+ // If X can be undef, X should be frozen first.
+ // For example, let's assume that Y = 1 & X = undef:
+ // %div = sdiv undef, 1 // %div = undef
+ // %rem = srem undef, 1 // %rem = 0
+ // =>
+ // %div = sdiv undef, 1 // %div = undef
+ // %mul = mul %div, 1 // %mul = undef
+ // %rem = sub %x, %mul // %rem = undef - undef = undef
+ // If X is not frozen, %rem becomes undef after transformation.
+ // TODO: We need an undef-specific checking function in ValueTracking
+ if (!isGuaranteedNotToBeUndefOrPoison(X, DivInst, &DT)) {
+ auto *FrX = new FreezeInst(X, X->getName() + ".frozen", DivInst);
+ DivInst->setOperand(0, FrX);
+ Sub->setOperand(0, FrX);
+ }
+ // Same for Y. If X = 1 and Y = (undef | 1), %rem in src is either 1 or 0,
+ // but %rem in tgt can be one of many integer values.
+ if (!isGuaranteedNotToBeUndefOrPoison(Y, DivInst, &DT)) {
+ auto *FrY = new FreezeInst(Y, Y->getName() + ".frozen", DivInst);
+ DivInst->setOperand(1, FrY);
+ Mul->setOperand(1, FrY);
+ }
+
// Now kill the explicit remainder. We have replaced it with:
// (sub X, (mul (div X, Y), Y)
Sub->setName(RemInst->getName() + ".decomposed");
diff --git a/llvm/lib/Transforms/Scalar/EarlyCSE.cpp b/llvm/lib/Transforms/Scalar/EarlyCSE.cpp
index 40c1ba88354f..ddfc8555b0a0 100644
--- a/llvm/lib/Transforms/Scalar/EarlyCSE.cpp
+++ b/llvm/lib/Transforms/Scalar/EarlyCSE.cpp
@@ -41,6 +41,7 @@
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/PassManager.h"
#include "llvm/IR/PatternMatch.h"
+#include "llvm/IR/Statepoint.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Use.h"
#include "llvm/IR/Value.h"
@@ -54,6 +55,7 @@
#include "llvm/Support/RecyclingAllocator.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Scalar.h"
+#include "llvm/Transforms/Utils/AssumeBundleBuilder.h"
#include "llvm/Transforms/Utils/GuardUtils.h"
#include "llvm/Transforms/Utils/Local.h"
#include <cassert>
@@ -114,7 +116,7 @@ struct SimpleValue {
isa<CmpInst>(Inst) || isa<SelectInst>(Inst) ||
isa<ExtractElementInst>(Inst) || isa<InsertElementInst>(Inst) ||
isa<ShuffleVectorInst>(Inst) || isa<ExtractValueInst>(Inst) ||
- isa<InsertValueInst>(Inst);
+ isa<InsertValueInst>(Inst) || isa<FreezeInst>(Inst);
}
};
@@ -152,13 +154,50 @@ static bool matchSelectWithOptionalNotCond(Value *V, Value *&Cond, Value *&A,
std::swap(A, B);
}
- // Set flavor if we find a match, or set it to unknown otherwise; in
- // either case, return true to indicate that this is a select we can
- // process.
- if (auto *CmpI = dyn_cast<ICmpInst>(Cond))
- Flavor = matchDecomposedSelectPattern(CmpI, A, B, A, B).Flavor;
- else
- Flavor = SPF_UNKNOWN;
+ // Match canonical forms of abs/nabs/min/max. We are not using ValueTracking's
+ // more powerful matchSelectPattern() because it may rely on instruction flags
+ // such as "nsw". That would be incompatible with the current hashing
+ // mechanism that may remove flags to increase the likelihood of CSE.
+
+ // These are the canonical forms of abs(X) and nabs(X) created by instcombine:
+ // %N = sub i32 0, %X
+ // %C = icmp slt i32 %X, 0
+ // %ABS = select i1 %C, i32 %N, i32 %X
+ //
+ // %N = sub i32 0, %X
+ // %C = icmp slt i32 %X, 0
+ // %NABS = select i1 %C, i32 %X, i32 %N
+ Flavor = SPF_UNKNOWN;
+ CmpInst::Predicate Pred;
+ if (match(Cond, m_ICmp(Pred, m_Specific(B), m_ZeroInt())) &&
+ Pred == ICmpInst::ICMP_SLT && match(A, m_Neg(m_Specific(B)))) {
+ // ABS: B < 0 ? -B : B
+ Flavor = SPF_ABS;
+ return true;
+ }
+ if (match(Cond, m_ICmp(Pred, m_Specific(A), m_ZeroInt())) &&
+ Pred == ICmpInst::ICMP_SLT && match(B, m_Neg(m_Specific(A)))) {
+ // NABS: A < 0 ? A : -A
+ Flavor = SPF_NABS;
+ return true;
+ }
+
+ if (!match(Cond, m_ICmp(Pred, m_Specific(A), m_Specific(B)))) {
+ // Check for commuted variants of min/max by swapping predicate.
+ // If we do not match the standard or commuted patterns, this is not a
+ // recognized form of min/max, but it is still a select, so return true.
+ if (!match(Cond, m_ICmp(Pred, m_Specific(B), m_Specific(A))))
+ return true;
+ Pred = ICmpInst::getSwappedPredicate(Pred);
+ }
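+ // Illustrative example of a commuted form handled here (names made up):
+ //   %C = icmp ugt i32 %B, %A
+ //   %S = select i1 %C, i32 %A, i32 %B
+ // The predicate is swapped to ult, so %S is classified as SPF_UMIN of A, B.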
+
+ switch (Pred) {
+ case CmpInst::ICMP_UGT: Flavor = SPF_UMAX; break;
+ case CmpInst::ICMP_ULT: Flavor = SPF_UMIN; break;
+ case CmpInst::ICMP_SGT: Flavor = SPF_SMAX; break;
+ case CmpInst::ICMP_SLT: Flavor = SPF_SMIN; break;
+ default: break;
+ }
return true;
}
@@ -231,6 +270,9 @@ static unsigned getHashValueImpl(SimpleValue Val) {
if (CastInst *CI = dyn_cast<CastInst>(Inst))
return hash_combine(CI->getOpcode(), CI->getType(), CI->getOperand(0));
+ if (FreezeInst *FI = dyn_cast<FreezeInst>(Inst))
+ return hash_combine(FI->getOpcode(), FI->getOperand(0));
+
if (const ExtractValueInst *EVI = dyn_cast<ExtractValueInst>(Inst))
return hash_combine(EVI->getOpcode(), EVI->getOperand(0),
hash_combine_range(EVI->idx_begin(), EVI->idx_end()));
@@ -242,7 +284,8 @@ static unsigned getHashValueImpl(SimpleValue Val) {
assert((isa<CallInst>(Inst) || isa<GetElementPtrInst>(Inst) ||
isa<ExtractElementInst>(Inst) || isa<InsertElementInst>(Inst) ||
- isa<ShuffleVectorInst>(Inst) || isa<UnaryOperator>(Inst)) &&
+ isa<ShuffleVectorInst>(Inst) || isa<UnaryOperator>(Inst) ||
+ isa<FreezeInst>(Inst)) &&
"Invalid/unknown instruction");
// Mix in the opcode.
@@ -414,6 +457,14 @@ template <> struct DenseMapInfo<CallValue> {
unsigned DenseMapInfo<CallValue>::getHashValue(CallValue Val) {
Instruction *Inst = Val.Inst;
+
+ // gc.relocate is a 'special' call: its second and third operands are
+ // not real values, but indices into statepoint's argument list.
+ // Get values they point to.
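+ // For example (illustrative IR, names are made up):
+ //   %rel = call i8 addrspace(1)* @llvm.experimental.gc.relocate.p1i8(
+ //              token %statepoint_token, i32 7, i32 7)
+ // Two such calls hash and compare equal based on their token, base and
+ // derived pointers rather than on the raw index operands.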
+ if (const GCRelocateInst *GCR = dyn_cast<GCRelocateInst>(Inst))
+ return hash_combine(GCR->getOpcode(), GCR->getOperand(0),
+ GCR->getBasePtr(), GCR->getDerivedPtr());
+
// Hash all of the operands as pointers and mix in the opcode.
return hash_combine(
Inst->getOpcode(),
@@ -424,6 +475,14 @@ bool DenseMapInfo<CallValue>::isEqual(CallValue LHS, CallValue RHS) {
Instruction *LHSI = LHS.Inst, *RHSI = RHS.Inst;
if (LHS.isSentinel() || RHS.isSentinel())
return LHSI == RHSI;
+
+ // See comment above in `getHashValue()`.
+ if (const GCRelocateInst *GCR1 = dyn_cast<GCRelocateInst>(LHSI))
+ if (const GCRelocateInst *GCR2 = dyn_cast<GCRelocateInst>(RHSI))
+ return GCR1->getOperand(0) == GCR2->getOperand(0) &&
+ GCR1->getBasePtr() == GCR2->getBasePtr() &&
+ GCR1->getDerivedPtr() == GCR2->getDerivedPtr();
+
return LHSI->isIdenticalTo(RHSI);
}
@@ -561,8 +620,8 @@ private:
public:
StackNode(ScopedHTType &AvailableValues, LoadHTType &AvailableLoads,
InvariantHTType &AvailableInvariants, CallHTType &AvailableCalls,
- unsigned cg, DomTreeNode *n, DomTreeNode::iterator child,
- DomTreeNode::iterator end)
+ unsigned cg, DomTreeNode *n, DomTreeNode::const_iterator child,
+ DomTreeNode::const_iterator end)
: CurrentGeneration(cg), ChildGeneration(cg), Node(n), ChildIter(child),
EndIter(end),
Scopes(AvailableValues, AvailableLoads, AvailableInvariants,
@@ -576,7 +635,7 @@ private:
unsigned childGeneration() { return ChildGeneration; }
void childGeneration(unsigned generation) { ChildGeneration = generation; }
DomTreeNode *node() { return Node; }
- DomTreeNode::iterator childIter() { return ChildIter; }
+ DomTreeNode::const_iterator childIter() { return ChildIter; }
DomTreeNode *nextChild() {
DomTreeNode *child = *ChildIter;
@@ -584,7 +643,7 @@ private:
return child;
}
- DomTreeNode::iterator end() { return EndIter; }
+ DomTreeNode::const_iterator end() { return EndIter; }
bool isProcessed() { return Processed; }
void process() { Processed = true; }
@@ -592,8 +651,8 @@ private:
unsigned CurrentGeneration;
unsigned ChildGeneration;
DomTreeNode *Node;
- DomTreeNode::iterator ChildIter;
- DomTreeNode::iterator EndIter;
+ DomTreeNode::const_iterator ChildIter;
+ DomTreeNode::const_iterator EndIter;
NodeScope Scopes;
bool Processed = false;
};
@@ -716,7 +775,7 @@ private:
bool isSameMemGeneration(unsigned EarlierGeneration, unsigned LaterGeneration,
Instruction *EarlierInst, Instruction *LaterInst);
- void removeMSSA(Instruction *Inst) {
+ void removeMSSA(Instruction &Inst) {
if (!MSSA)
return;
if (VerifyMemorySSA)
@@ -727,7 +786,7 @@ private:
// is handled by MemorySSA when passing OptimizePhis = true to
// removeMemoryAccess. The non-optimized MemoryUse case is lazily updated
// by MemorySSA's getClobberingMemoryAccess.
- MSSAUpdater->removeMemoryAccess(Inst, true);
+ MSSAUpdater->removeMemoryAccess(&Inst, true);
}
};
@@ -897,20 +956,19 @@ bool EarlyCSE::processNode(DomTreeNode *Node) {
// See if any instructions in the block can be eliminated. If so, do it. If
// not, add them to AvailableValues.
- for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E;) {
- Instruction *Inst = &*I++;
-
+ for (Instruction &Inst : make_early_inc_range(BB->getInstList())) {
// Dead instructions should just be removed.
- if (isInstructionTriviallyDead(Inst, &TLI)) {
- LLVM_DEBUG(dbgs() << "EarlyCSE DCE: " << *Inst << '\n');
+ if (isInstructionTriviallyDead(&Inst, &TLI)) {
+ LLVM_DEBUG(dbgs() << "EarlyCSE DCE: " << Inst << '\n');
if (!DebugCounter::shouldExecute(CSECounter)) {
LLVM_DEBUG(dbgs() << "Skipping due to debug counter\n");
continue;
}
- salvageDebugInfoOrMarkUndef(*Inst);
+ salvageKnowledge(&Inst, &AC);
+ salvageDebugInfo(Inst);
removeMSSA(Inst);
- Inst->eraseFromParent();
+ Inst.eraseFromParent();
Changed = true;
++NumSimplify;
continue;
@@ -920,21 +978,21 @@ bool EarlyCSE::processNode(DomTreeNode *Node) {
// they're marked as such to ensure preservation of control dependencies),
// and this pass will not bother with its removal. However, we should mark
// its condition as true for all dominated blocks.
- if (match(Inst, m_Intrinsic<Intrinsic::assume>())) {
+ if (match(&Inst, m_Intrinsic<Intrinsic::assume>())) {
auto *CondI =
- dyn_cast<Instruction>(cast<CallInst>(Inst)->getArgOperand(0));
+ dyn_cast<Instruction>(cast<CallInst>(Inst).getArgOperand(0));
if (CondI && SimpleValue::canHandle(CondI)) {
- LLVM_DEBUG(dbgs() << "EarlyCSE considering assumption: " << *Inst
+ LLVM_DEBUG(dbgs() << "EarlyCSE considering assumption: " << Inst
<< '\n');
AvailableValues.insert(CondI, ConstantInt::getTrue(BB->getContext()));
} else
- LLVM_DEBUG(dbgs() << "EarlyCSE skipping assumption: " << *Inst << '\n');
+ LLVM_DEBUG(dbgs() << "EarlyCSE skipping assumption: " << Inst << '\n');
continue;
}
// Skip sideeffect intrinsics, for the same reason as assume intrinsics.
- if (match(Inst, m_Intrinsic<Intrinsic::sideeffect>())) {
- LLVM_DEBUG(dbgs() << "EarlyCSE skipping sideeffect: " << *Inst << '\n');
+ if (match(&Inst, m_Intrinsic<Intrinsic::sideeffect>())) {
+ LLVM_DEBUG(dbgs() << "EarlyCSE skipping sideeffect: " << Inst << '\n');
continue;
}
@@ -951,21 +1009,21 @@ bool EarlyCSE::processNode(DomTreeNode *Node) {
// store 40, i8* p
// We can DSE the store to 30, since the store 40 to invariant location p
// causes undefined behaviour.
- if (match(Inst, m_Intrinsic<Intrinsic::invariant_start>())) {
+ if (match(&Inst, m_Intrinsic<Intrinsic::invariant_start>())) {
// If there are any uses, the scope might end.
- if (!Inst->use_empty())
+ if (!Inst.use_empty())
continue;
- auto *CI = cast<CallInst>(Inst);
- MemoryLocation MemLoc = MemoryLocation::getForArgument(CI, 1, TLI);
+ MemoryLocation MemLoc =
+ MemoryLocation::getForArgument(&cast<CallInst>(Inst), 1, TLI);
// Don't start a scope if we already have a better one pushed
if (!AvailableInvariants.count(MemLoc))
AvailableInvariants.insert(MemLoc, CurrentGeneration);
continue;
}
- if (isGuard(Inst)) {
+ if (isGuard(&Inst)) {
if (auto *CondI =
- dyn_cast<Instruction>(cast<CallInst>(Inst)->getArgOperand(0))) {
+ dyn_cast<Instruction>(cast<CallInst>(Inst).getArgOperand(0))) {
if (SimpleValue::canHandle(CondI)) {
// Do we already know the actual value of this condition?
if (auto *KnownCond = AvailableValues.lookup(CondI)) {
@@ -973,14 +1031,15 @@ bool EarlyCSE::processNode(DomTreeNode *Node) {
if (isa<ConstantInt>(KnownCond) &&
cast<ConstantInt>(KnownCond)->isOne()) {
LLVM_DEBUG(dbgs()
- << "EarlyCSE removing guard: " << *Inst << '\n');
+ << "EarlyCSE removing guard: " << Inst << '\n');
+ salvageKnowledge(&Inst, &AC);
removeMSSA(Inst);
- Inst->eraseFromParent();
+ Inst.eraseFromParent();
Changed = true;
continue;
} else
// Use the known value if it wasn't true.
- cast<CallInst>(Inst)->setArgOperand(0, KnownCond);
+ cast<CallInst>(Inst).setArgOperand(0, KnownCond);
}
// The condition we're on guarding here is true for all dominated
// locations.
@@ -997,20 +1056,21 @@ bool EarlyCSE::processNode(DomTreeNode *Node) {
// If the instruction can be simplified (e.g. X+0 = X) then replace it with
// its simpler value.
- if (Value *V = SimplifyInstruction(Inst, SQ)) {
- LLVM_DEBUG(dbgs() << "EarlyCSE Simplify: " << *Inst << " to: " << *V
+ if (Value *V = SimplifyInstruction(&Inst, SQ)) {
+ LLVM_DEBUG(dbgs() << "EarlyCSE Simplify: " << Inst << " to: " << *V
<< '\n');
if (!DebugCounter::shouldExecute(CSECounter)) {
LLVM_DEBUG(dbgs() << "Skipping due to debug counter\n");
} else {
bool Killed = false;
- if (!Inst->use_empty()) {
- Inst->replaceAllUsesWith(V);
+ if (!Inst.use_empty()) {
+ Inst.replaceAllUsesWith(V);
Changed = true;
}
- if (isInstructionTriviallyDead(Inst, &TLI)) {
+ if (isInstructionTriviallyDead(&Inst, &TLI)) {
+ salvageKnowledge(&Inst, &AC);
removeMSSA(Inst);
- Inst->eraseFromParent();
+ Inst.eraseFromParent();
Changed = true;
Killed = true;
}
@@ -1022,31 +1082,32 @@ bool EarlyCSE::processNode(DomTreeNode *Node) {
}
// If this is a simple instruction that we can value number, process it.
- if (SimpleValue::canHandle(Inst)) {
+ if (SimpleValue::canHandle(&Inst)) {
// See if the instruction has an available value. If so, use it.
- if (Value *V = AvailableValues.lookup(Inst)) {
- LLVM_DEBUG(dbgs() << "EarlyCSE CSE: " << *Inst << " to: " << *V
+ if (Value *V = AvailableValues.lookup(&Inst)) {
+ LLVM_DEBUG(dbgs() << "EarlyCSE CSE: " << Inst << " to: " << *V
<< '\n');
if (!DebugCounter::shouldExecute(CSECounter)) {
LLVM_DEBUG(dbgs() << "Skipping due to debug counter\n");
continue;
}
if (auto *I = dyn_cast<Instruction>(V))
- I->andIRFlags(Inst);
- Inst->replaceAllUsesWith(V);
+ I->andIRFlags(&Inst);
+ Inst.replaceAllUsesWith(V);
+ salvageKnowledge(&Inst, &AC);
removeMSSA(Inst);
- Inst->eraseFromParent();
+ Inst.eraseFromParent();
Changed = true;
++NumCSE;
continue;
}
// Otherwise, just remember that this value is available.
- AvailableValues.insert(Inst, Inst);
+ AvailableValues.insert(&Inst, &Inst);
continue;
}
- ParseMemoryInst MemInst(Inst, TTI);
+ ParseMemoryInst MemInst(&Inst, TTI);
// If this is a non-volatile load, process it.
if (MemInst.isValid() && MemInst.isLoad()) {
// (conservatively) we can't peek past the ordering implied by this
@@ -1062,7 +1123,7 @@ bool EarlyCSE::processNode(DomTreeNode *Node) {
// We conservatively treat the invariant_load as that moment. If we
// pass an invariant load after already establishing a scope, don't
// restart it since we want to preserve the earliest point seen.
- auto MemLoc = MemoryLocation::get(Inst);
+ auto MemLoc = MemoryLocation::get(&Inst);
if (!AvailableInvariants.count(MemLoc))
AvailableInvariants.insert(MemLoc, CurrentGeneration);
}
@@ -1081,21 +1142,22 @@ bool EarlyCSE::processNode(DomTreeNode *Node) {
!MemInst.isVolatile() && MemInst.isUnordered() &&
// We can't replace an atomic load with one which isn't also atomic.
InVal.IsAtomic >= MemInst.isAtomic() &&
- (isOperatingOnInvariantMemAt(Inst, InVal.Generation) ||
+ (isOperatingOnInvariantMemAt(&Inst, InVal.Generation) ||
isSameMemGeneration(InVal.Generation, CurrentGeneration,
- InVal.DefInst, Inst))) {
- Value *Op = getOrCreateResult(InVal.DefInst, Inst->getType());
+ InVal.DefInst, &Inst))) {
+ Value *Op = getOrCreateResult(InVal.DefInst, Inst.getType());
if (Op != nullptr) {
- LLVM_DEBUG(dbgs() << "EarlyCSE CSE LOAD: " << *Inst
+ LLVM_DEBUG(dbgs() << "EarlyCSE CSE LOAD: " << Inst
<< " to: " << *InVal.DefInst << '\n');
if (!DebugCounter::shouldExecute(CSECounter)) {
LLVM_DEBUG(dbgs() << "Skipping due to debug counter\n");
continue;
}
- if (!Inst->use_empty())
- Inst->replaceAllUsesWith(Op);
+ if (!Inst.use_empty())
+ Inst.replaceAllUsesWith(Op);
+ salvageKnowledge(&Inst, &AC);
removeMSSA(Inst);
- Inst->eraseFromParent();
+ Inst.eraseFromParent();
Changed = true;
++NumCSELoad;
continue;
@@ -1103,10 +1165,10 @@ bool EarlyCSE::processNode(DomTreeNode *Node) {
}
// Otherwise, remember that we have this instruction.
- AvailableLoads.insert(
- MemInst.getPointerOperand(),
- LoadValue(Inst, CurrentGeneration, MemInst.getMatchingId(),
- MemInst.isAtomic()));
+ AvailableLoads.insert(MemInst.getPointerOperand(),
+ LoadValue(&Inst, CurrentGeneration,
+ MemInst.getMatchingId(),
+ MemInst.isAtomic()));
LastStore = nullptr;
continue;
}
@@ -1117,36 +1179,36 @@ bool EarlyCSE::processNode(DomTreeNode *Node) {
// may override this (e.g. so that a store intrinsic does not read from
// memory, and thus will be treated the same as a regular store for
// commoning purposes).
- if ((Inst->mayReadFromMemory() || Inst->mayThrow()) &&
+ if ((Inst.mayReadFromMemory() || Inst.mayThrow()) &&
!(MemInst.isValid() && !MemInst.mayReadFromMemory()))
LastStore = nullptr;
// If this is a read-only call, process it.
- if (CallValue::canHandle(Inst)) {
+ if (CallValue::canHandle(&Inst)) {
// If we have an available version of this call, and if it is the right
// generation, replace this instruction.
- std::pair<Instruction *, unsigned> InVal = AvailableCalls.lookup(Inst);
+ std::pair<Instruction *, unsigned> InVal = AvailableCalls.lookup(&Inst);
if (InVal.first != nullptr &&
isSameMemGeneration(InVal.second, CurrentGeneration, InVal.first,
- Inst)) {
- LLVM_DEBUG(dbgs() << "EarlyCSE CSE CALL: " << *Inst
+ &Inst)) {
+ LLVM_DEBUG(dbgs() << "EarlyCSE CSE CALL: " << Inst
<< " to: " << *InVal.first << '\n');
if (!DebugCounter::shouldExecute(CSECounter)) {
LLVM_DEBUG(dbgs() << "Skipping due to debug counter\n");
continue;
}
- if (!Inst->use_empty())
- Inst->replaceAllUsesWith(InVal.first);
+ if (!Inst.use_empty())
+ Inst.replaceAllUsesWith(InVal.first);
+ salvageKnowledge(&Inst, &AC);
removeMSSA(Inst);
- Inst->eraseFromParent();
+ Inst.eraseFromParent();
Changed = true;
++NumCSECall;
continue;
}
// Otherwise, remember that we have this instruction.
- AvailableCalls.insert(
- Inst, std::pair<Instruction *, unsigned>(Inst, CurrentGeneration));
+ AvailableCalls.insert(&Inst, std::make_pair(&Inst, CurrentGeneration));
continue;
}
@@ -1155,9 +1217,9 @@ bool EarlyCSE::processNode(DomTreeNode *Node) {
// result, we don't need to consider it as writing to memory and don't need
// to advance the generation. We do need to prevent DSE across the fence,
// but that's handled above.
- if (FenceInst *FI = dyn_cast<FenceInst>(Inst))
+ if (auto *FI = dyn_cast<FenceInst>(&Inst))
if (FI->getOrdering() == AtomicOrdering::Release) {
- assert(Inst->mayReadFromMemory() && "relied on to prevent DSE above");
+ assert(Inst.mayReadFromMemory() && "relied on to prevent DSE above");
continue;
}
@@ -1169,13 +1231,13 @@ bool EarlyCSE::processNode(DomTreeNode *Node) {
if (MemInst.isValid() && MemInst.isStore()) {
LoadValue InVal = AvailableLoads.lookup(MemInst.getPointerOperand());
if (InVal.DefInst &&
- InVal.DefInst == getOrCreateResult(Inst, InVal.DefInst->getType()) &&
+ InVal.DefInst == getOrCreateResult(&Inst, InVal.DefInst->getType()) &&
InVal.MatchingId == MemInst.getMatchingId() &&
// We don't yet handle removing stores with ordering of any kind.
!MemInst.isVolatile() && MemInst.isUnordered() &&
- (isOperatingOnInvariantMemAt(Inst, InVal.Generation) ||
+ (isOperatingOnInvariantMemAt(&Inst, InVal.Generation) ||
isSameMemGeneration(InVal.Generation, CurrentGeneration,
- InVal.DefInst, Inst))) {
+ InVal.DefInst, &Inst))) {
// It is okay to have a LastStore to a different pointer here if MemorySSA
// tells us that the load and store are from the same memory generation.
// In that case, LastStore should keep its present value since we're
@@ -1185,13 +1247,14 @@ bool EarlyCSE::processNode(DomTreeNode *Node) {
MemInst.getPointerOperand() ||
MSSA) &&
"can't have an intervening store if not using MemorySSA!");
- LLVM_DEBUG(dbgs() << "EarlyCSE DSE (writeback): " << *Inst << '\n');
+ LLVM_DEBUG(dbgs() << "EarlyCSE DSE (writeback): " << Inst << '\n');
if (!DebugCounter::shouldExecute(CSECounter)) {
LLVM_DEBUG(dbgs() << "Skipping due to debug counter\n");
continue;
}
+ salvageKnowledge(&Inst, &AC);
removeMSSA(Inst);
- Inst->eraseFromParent();
+ Inst.eraseFromParent();
Changed = true;
++NumDSE;
// We can avoid incrementing the generation count since we were able
@@ -1203,7 +1266,7 @@ bool EarlyCSE::processNode(DomTreeNode *Node) {
// Okay, this isn't something we can CSE at all. Check to see if it is
// something that could modify memory. If so, our available memory values
// cannot be used so bump the generation count.
- if (Inst->mayWriteToMemory()) {
+ if (Inst.mayWriteToMemory()) {
++CurrentGeneration;
if (MemInst.isValid() && MemInst.isStore()) {
@@ -1221,11 +1284,12 @@ bool EarlyCSE::processNode(DomTreeNode *Node) {
"Violated invariant");
if (LastStoreMemInst.isMatchingMemLoc(MemInst)) {
LLVM_DEBUG(dbgs() << "EarlyCSE DEAD STORE: " << *LastStore
- << " due to: " << *Inst << '\n');
+ << " due to: " << Inst << '\n');
if (!DebugCounter::shouldExecute(CSECounter)) {
LLVM_DEBUG(dbgs() << "Skipping due to debug counter\n");
} else {
- removeMSSA(LastStore);
+ salvageKnowledge(&Inst, &AC);
+ removeMSSA(*LastStore);
LastStore->eraseFromParent();
Changed = true;
++NumDSE;
@@ -1240,10 +1304,10 @@ bool EarlyCSE::processNode(DomTreeNode *Node) {
// version of the pointer. It is safe to forward from volatile stores
// to non-volatile loads, so we don't have to check for volatility of
// the store.
- AvailableLoads.insert(
- MemInst.getPointerOperand(),
- LoadValue(Inst, CurrentGeneration, MemInst.getMatchingId(),
- MemInst.isAtomic()));
+ AvailableLoads.insert(MemInst.getPointerOperand(),
+ LoadValue(&Inst, CurrentGeneration,
+ MemInst.getMatchingId(),
+ MemInst.isAtomic()));
// Remember that this was the last unordered store we saw for DSE. We
// don't yet handle DSE on ordered or volatile stores since we don't
@@ -1253,7 +1317,7 @@ bool EarlyCSE::processNode(DomTreeNode *Node) {
// it's not clear this is a profitable transform. Another option would
// be to merge the ordering with that of the post dominating store.
if (MemInst.isUnordered() && !MemInst.isVolatile())
- LastStore = Inst;
+ LastStore = &Inst;
else
LastStore = nullptr;
}
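
The processNode loop above now iterates with make_early_inc_range so the current instruction can be erased inside the body. A minimal illustration of that idiom in plain C++ (std::list standing in for the instruction list) is:

#include <iostream>
#include <list>

int main() {
  std::list<int> Insts = {1, 2, 3, 4, 5, 6};
  for (auto It = Insts.begin(), E = Insts.end(); It != E;) {
    auto Cur = It++;        // advance before the body runs
    if (*Cur % 2 == 0)      // pretend even values are trivially dead
      Insts.erase(Cur);     // erasing Cur cannot invalidate It
  }
  for (int V : Insts)
    std::cout << V << ' ';  // prints: 1 3 5
  std::cout << '\n';
  return 0;
}

The old code expressed the same thing manually with `Instruction *Inst = &*I++;`; the range helper just packages that advance-before-use step.
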
diff --git a/llvm/lib/Transforms/Scalar/Float2Int.cpp b/llvm/lib/Transforms/Scalar/Float2Int.cpp
index af223cc837f2..83f4c402ed4d 100644
--- a/llvm/lib/Transforms/Scalar/Float2Int.cpp
+++ b/llvm/lib/Transforms/Scalar/Float2Int.cpp
@@ -120,8 +120,7 @@ static Instruction::BinaryOps mapBinOpcode(unsigned Opcode) {
// Find the roots - instructions that convert from the FP domain to
// integer domain.
-void Float2IntPass::findRoots(Function &F, const DominatorTree &DT,
- SmallPtrSet<Instruction*,8> &Roots) {
+void Float2IntPass::findRoots(Function &F, const DominatorTree &DT) {
for (BasicBlock &BB : F) {
// Unreachable code can take on strange forms that we are not prepared to
// handle. For example, an instruction may have itself as an operand.
@@ -184,7 +183,7 @@ ConstantRange Float2IntPass::validateRange(ConstantRange R) {
// Breadth-first walk of the use-def graph; determine the set of nodes
// we care about and eagerly determine if some of them are poisonous.
-void Float2IntPass::walkBackwards(const SmallPtrSetImpl<Instruction*> &Roots) {
+void Float2IntPass::walkBackwards() {
std::deque<Instruction*> Worklist(Roots.begin(), Roots.end());
while (!Worklist.empty()) {
Instruction *I = Worklist.back();
@@ -327,7 +326,7 @@ void Float2IntPass::walkForwards() {
APFloat NewF = F;
auto Res = NewF.roundToIntegral(APFloat::rmNearestTiesToEven);
- if (Res != APFloat::opOK || NewF.compare(F) != APFloat::cmpEqual) {
+ if (Res != APFloat::opOK || NewF != F) {
seen(I, badRange());
Abort = true;
break;
@@ -525,9 +524,9 @@ bool Float2IntPass::runImpl(Function &F, const DominatorTree &DT) {
Ctx = &F.getParent()->getContext();
- findRoots(F, DT, Roots);
+ findRoots(F, DT);
- walkBackwards(Roots);
+ walkBackwards();
walkForwards();
bool Modified = validateAndTransform();
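
The Float2Int change above simplifies the exactness test to `NewF != F` after rounding to integral. The underlying check is simply whether the constant round-trips through the nearest integer; a hedged standalone equivalent in plain C++ (using std::nearbyint under the default round-to-nearest mode, not the APFloat API) is:

#include <cmath>
#include <cstdio>

// Returns true when rounding F to the nearest integer reproduces F exactly,
// i.e. the constant is already an integer and may move to the integer domain.
static bool isExactInteger(double F) {
  return std::nearbyint(F) == F;
}

int main() {
  std::printf("3.0  -> %d\n", isExactInteger(3.0));  // 1: usable as-is
  std::printf("3.25 -> %d\n", isExactInteger(3.25)); // 0: the pass must bail
  return 0;
}
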
diff --git a/llvm/lib/Transforms/Scalar/GVN.cpp b/llvm/lib/Transforms/Scalar/GVN.cpp
index 1e6aab14e7b4..b16f8591b5a4 100644
--- a/llvm/lib/Transforms/Scalar/GVN.cpp
+++ b/llvm/lib/Transforms/Scalar/GVN.cpp
@@ -26,6 +26,7 @@
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
+#include "llvm/Analysis/AssumeBundleQueries.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/CFG.h"
@@ -42,7 +43,6 @@
#include "llvm/Config/llvm-config.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/BasicBlock.h"
-#include "llvm/IR/CallSite.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
@@ -72,6 +72,7 @@
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Utils.h"
+#include "llvm/Transforms/Utils/AssumeBundleBuilder.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/Local.h"
#include "llvm/Transforms/Utils/SSAUpdater.h"
@@ -97,10 +98,11 @@ STATISTIC(NumGVNSimpl, "Number of instructions simplified");
STATISTIC(NumGVNEqProp, "Number of equalities propagated");
STATISTIC(NumPRELoad, "Number of loads PRE'd");
-static cl::opt<bool> EnablePRE("enable-pre",
- cl::init(true), cl::Hidden);
-static cl::opt<bool> EnableLoadPRE("enable-load-pre", cl::init(true));
-static cl::opt<bool> EnableMemDep("enable-gvn-memdep", cl::init(true));
+static cl::opt<bool> GVNEnablePRE("enable-pre", cl::init(true), cl::Hidden);
+static cl::opt<bool> GVNEnableLoadPRE("enable-load-pre", cl::init(true));
+static cl::opt<bool> GVNEnableLoadInLoopPRE("enable-load-in-loop-pre",
+ cl::init(true));
+static cl::opt<bool> GVNEnableMemDep("enable-gvn-memdep", cl::init(true));
// Maximum allowed recursion depth.
static cl::opt<uint32_t>
@@ -113,8 +115,8 @@ static cl::opt<uint32_t> MaxNumDeps(
struct llvm::GVN::Expression {
uint32_t opcode;
- Type *type = nullptr;
bool commutative = false;
+ Type *type = nullptr;
SmallVector<uint32_t, 4> varargs;
Expression(uint32_t o = ~2U) : opcode(o) {}
@@ -288,7 +290,7 @@ GVN::Expression GVN::ValueTable::createExpr(Instruction *I) {
e.commutative = true;
}
- if (CmpInst *C = dyn_cast<CmpInst>(I)) {
+ if (auto *C = dyn_cast<CmpInst>(I)) {
// Sort the operand value numbers so x<y and y>x get the same value number.
CmpInst::Predicate Predicate = C->getPredicate();
if (e.varargs[0] > e.varargs[1]) {
@@ -297,10 +299,11 @@ GVN::Expression GVN::ValueTable::createExpr(Instruction *I) {
}
e.opcode = (C->getOpcode() << 8) | Predicate;
e.commutative = true;
- } else if (InsertValueInst *E = dyn_cast<InsertValueInst>(I)) {
- for (InsertValueInst::idx_iterator II = E->idx_begin(), IE = E->idx_end();
- II != IE; ++II)
- e.varargs.push_back(*II);
+ } else if (auto *E = dyn_cast<InsertValueInst>(I)) {
+ e.varargs.append(E->idx_begin(), E->idx_end());
+ } else if (auto *SVI = dyn_cast<ShuffleVectorInst>(I)) {
+ ArrayRef<int> ShuffleMask = SVI->getShuffleMask();
+ e.varargs.append(ShuffleMask.begin(), ShuffleMask.end());
}
return e;
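
The createExpr hunk above keeps compares commutative-friendly by sorting the two operand value numbers and swapping the predicate when necessary. A self-contained sketch of that canonicalization (hypothetical Pred enum, not the CmpInst API):

#include <cassert>
#include <cstdint>
#include <tuple>
#include <utility>

enum Pred { LT, GT, LE, GE, EQ, NE };

static Pred swappedPred(Pred P) {
  switch (P) {
  case LT: return GT;
  case GT: return LT;
  case LE: return GE;
  case GE: return LE;
  default: return P; // EQ and NE are already symmetric
  }
}

// A and B stand in for the operand value numbers (e.varargs[0], e.varargs[1]).
static std::tuple<Pred, uint32_t, uint32_t>
canonicalCmp(Pred P, uint32_t A, uint32_t B) {
  if (A > B) {
    std::swap(A, B);
    P = swappedPred(P);
  }
  return {P, A, B};
}

int main() {
  // "%x < %y" (value numbers 5, 3) and "%y > %x" (value numbers 3, 5)
  // canonicalize to the same expression and so get one value number.
  assert(canonicalCmp(LT, 5, 3) == canonicalCmp(GT, 3, 5));
  return 0;
}
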
@@ -530,6 +533,7 @@ uint32_t GVN::ValueTable::lookupOrAdd(Value *V) {
case Instruction::AddrSpaceCast:
case Instruction::BitCast:
case Instruction::Select:
+ case Instruction::Freeze:
case Instruction::ExtractElement:
case Instruction::InsertElement:
case Instruction::ShuffleVector:
@@ -610,6 +614,22 @@ void GVN::ValueTable::verifyRemoved(const Value *V) const {
// GVN Pass
//===----------------------------------------------------------------------===//
+bool GVN::isPREEnabled() const {
+ return Options.AllowPRE.getValueOr(GVNEnablePRE);
+}
+
+bool GVN::isLoadPREEnabled() const {
+ return Options.AllowLoadPRE.getValueOr(GVNEnableLoadPRE);
+}
+
+bool GVN::isLoadInLoopPREEnabled() const {
+ return Options.AllowLoadInLoopPRE.getValueOr(GVNEnableLoadInLoopPRE);
+}
+
+bool GVN::isMemDepEnabled() const {
+ return Options.AllowMemDep.getValueOr(GVNEnableMemDep);
+}
+
PreservedAnalyses GVN::run(Function &F, FunctionAnalysisManager &AM) {
// FIXME: The order of evaluation of these 'getResult' calls is very
// significant! Re-ordering these variables will cause GVN when run alone to
@@ -619,10 +639,11 @@ PreservedAnalyses GVN::run(Function &F, FunctionAnalysisManager &AM) {
auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
auto &TLI = AM.getResult<TargetLibraryAnalysis>(F);
auto &AA = AM.getResult<AAManager>(F);
- auto &MemDep = AM.getResult<MemoryDependenceAnalysis>(F);
+ auto *MemDep =
+ isMemDepEnabled() ? &AM.getResult<MemoryDependenceAnalysis>(F) : nullptr;
auto *LI = AM.getCachedResult<LoopAnalysis>(F);
auto &ORE = AM.getResult<OptimizationRemarkEmitterAnalysis>(F);
- bool Changed = runImpl(F, AC, DT, TLI, AA, &MemDep, LI, &ORE);
+ bool Changed = runImpl(F, AC, DT, TLI, AA, MemDep, LI, &ORE);
if (!Changed)
return PreservedAnalyses::all();
PreservedAnalyses PA;
@@ -927,6 +948,7 @@ bool GVN::AnalyzeLoadAvailability(LoadInst *LI, MemDepResult DepInfo,
// Loading the allocation -> undef.
if (isa<AllocaInst>(DepInst) || isMallocLikeFn(DepInst, TLI) ||
+ isAlignedAllocLikeFn(DepInst, TLI) ||
// Loading immediately after lifetime begin -> undef.
isLifetimeStart(DepInst)) {
Res = AvailableValue::get(UndefValue::get(LI->getType()));
@@ -1245,7 +1267,7 @@ bool GVN::PerformLoadPRE(LoadInst *LI, AvailValInBlkVect &ValuesPerBlock,
auto *NewLoad = new LoadInst(
LI->getType(), LoadPtr, LI->getName() + ".pre", LI->isVolatile(),
- MaybeAlign(LI->getAlignment()), LI->getOrdering(), LI->getSyncScopeID(),
+ LI->getAlign(), LI->getOrdering(), LI->getSyncScopeID(),
UnavailablePred->getTerminator());
NewLoad->setDebugLoc(LI->getDebugLoc());
@@ -1383,7 +1405,10 @@ bool GVN::processNonLocalLoad(LoadInst *LI) {
}
// Step 4: Eliminate partial redundancy.
- if (!EnablePRE || !EnableLoadPRE)
+ if (!isPREEnabled() || !isLoadPREEnabled())
+ return false;
+ if (!isLoadInLoopPREEnabled() && this->LI &&
+ this->LI->getLoopFor(LI->getParent()))
return false;
return PerformLoadPRE(LI, ValuesPerBlock, UnavailableBlocks);
@@ -1428,7 +1453,7 @@ static bool impliesEquivalanceIfFalse(CmpInst* Cmp) {
Value *LHS = Cmp->getOperand(0);
Value *RHS = Cmp->getOperand(1);
// If we can prove either side non-zero, then equality must imply
- // equivalence.
+ // equivalence.
// FIXME: We should do this optimization if 'no signed zeros' is
// applicable via an instruction-level fast-math-flag or some other
// indicator that relaxed FP semantics are being used.
@@ -1465,7 +1490,8 @@ bool GVN::processAssumeIntrinsic(IntrinsicInst *IntrinsicI) {
Constant::getNullValue(Int8Ty->getPointerTo()),
IntrinsicI);
}
- markInstructionForDeletion(IntrinsicI);
+ if (isAssumeWithEmptyBundle(*IntrinsicI))
+ markInstructionForDeletion(IntrinsicI);
return false;
} else if (isa<Constant>(V)) {
// If it's not false, and constant, it must evaluate to true. This means our
@@ -1493,10 +1519,10 @@ bool GVN::processAssumeIntrinsic(IntrinsicInst *IntrinsicI) {
// If we find an equality fact, canonicalize all dominated uses in this block
// to one of the two values. We heuristically choose the "oldest" of the
// two where age is determined by value number. (Note that propagateEquality
- // above handles the cross block case.)
- //
+ // above handles the cross block case.)
+ //
// Key cases to cover are:
- // 1)
+ // 1)
// %cmp = fcmp oeq float 3.000000e+00, %0 ; const on lhs could happen
// call void @llvm.assume(i1 %cmp)
// ret float %0 ; will change it to ret float 3.000000e+00
@@ -1537,7 +1563,7 @@ bool GVN::processAssumeIntrinsic(IntrinsicInst *IntrinsicI) {
<< *CmpLHS << " with "
<< *CmpRHS << " in block "
<< IntrinsicI->getParent()->getName() << "\n");
-
+
// Setup the replacement map - this handles uses within the same block
if (hasUsersIn(CmpLHS, IntrinsicI->getParent()))
@@ -1710,7 +1736,8 @@ uint32_t GVN::ValueTable::phiTranslateImpl(const BasicBlock *Pred,
// instead of value numbers. Those index numbers should not be
// translated.
if ((i > 1 && Exp.opcode == Instruction::InsertValue) ||
- (i > 0 && Exp.opcode == Instruction::ExtractValue))
+ (i > 0 && Exp.opcode == Instruction::ExtractValue) ||
+ (i > 1 && Exp.opcode == Instruction::ShuffleVector))
continue;
Exp.varargs[i] = phiTranslate(Pred, PhiBlock, Exp.varargs[i], Gvn);
}
@@ -1802,7 +1829,7 @@ void GVN::assignBlockRPONumber(Function &F) {
bool GVN::replaceOperandsForInBlockEquality(Instruction *Instr) const {
bool Changed = false;
for (unsigned OpNum = 0; OpNum < Instr->getNumOperands(); ++OpNum) {
- Value *Operand = Instr->getOperand(OpNum);
+ Value *Operand = Instr->getOperand(OpNum);
auto it = ReplaceOperandsWithMap.find(Operand);
if (it != ReplaceOperandsWithMap.end()) {
LLVM_DEBUG(dbgs() << "GVN replacing: " << *Operand << " with "
@@ -1922,7 +1949,7 @@ bool GVN::propagateEquality(Value *LHS, Value *RHS, const BasicBlockEdge &Root,
// If "A == B" is known true, or "A != B" is known false, then replace
// A with B everywhere in the scope. For floating point operations, we
- // have to be careful since equality does not always imply equivalance.
+ // have to be careful since equality does not always imply equivalance.
if ((isKnownTrue && impliesEquivalanceIfTrue(Cmp)) ||
(isKnownFalse && impliesEquivalanceIfFalse(Cmp)))
Worklist.push_back(std::make_pair(Op0, Op1));
@@ -2117,7 +2144,7 @@ bool GVN::runImpl(Function &F, AssumptionCache &RunAC, DominatorTree &RunDT,
TLI = &RunTLI;
VN.setAliasAnalysis(&RunAA);
MD = RunMD;
- ImplicitControlFlowTracking ImplicitCFT(DT);
+ ImplicitControlFlowTracking ImplicitCFT;
ICF = &ImplicitCFT;
this->LI = LI;
VN.setMemDep(MD);
@@ -2148,7 +2175,7 @@ bool GVN::runImpl(Function &F, AssumptionCache &RunAC, DominatorTree &RunDT,
++Iteration;
}
- if (EnablePRE) {
+ if (isPREEnabled()) {
// Fabricate val-num for dead-code in order to suppress assertion in
// performPRE().
assignValNumForDeadCode();
@@ -2206,6 +2233,7 @@ bool GVN::processBlock(BasicBlock *BB) {
for (auto *I : InstrsToErase) {
assert(I->getParent() == BB && "Removing instruction from wrong block?");
LLVM_DEBUG(dbgs() << "GVN removed: " << *I << '\n');
+ salvageKnowledge(I, AC);
salvageDebugInfo(*I);
if (MD) MD->removeInstruction(I);
LLVM_DEBUG(verifyRemoved(I));
@@ -2478,8 +2506,11 @@ bool GVN::performPRE(Function &F) {
/// Split the critical edge connecting the given two blocks, and return
/// the block inserted to the critical edge.
BasicBlock *GVN::splitCriticalEdges(BasicBlock *Pred, BasicBlock *Succ) {
- BasicBlock *BB =
- SplitCriticalEdge(Pred, Succ, CriticalEdgeSplittingOptions(DT, LI));
+ // GVN does not require loop-simplify, so do not try to preserve it if that
+ // is not possible.
+ BasicBlock *BB = SplitCriticalEdge(
+ Pred, Succ,
+ CriticalEdgeSplittingOptions(DT, LI).unsetPreserveLoopSimplify());
if (MD)
MD->invalidateCachedPredecessors();
InvalidBlockRPONumbers = true;
@@ -2682,8 +2713,8 @@ class llvm::gvn::GVNLegacyPass : public FunctionPass {
public:
static char ID; // Pass identification, replacement for typeid
- explicit GVNLegacyPass(bool NoMemDepAnalysis = !EnableMemDep)
- : FunctionPass(ID), NoMemDepAnalysis(NoMemDepAnalysis) {
+ explicit GVNLegacyPass(bool NoMemDepAnalysis = !GVNEnableMemDep)
+ : FunctionPass(ID), Impl(GVNOptions().setMemDep(!NoMemDepAnalysis)) {
initializeGVNLegacyPassPass(*PassRegistry::getPassRegistry());
}
@@ -2698,9 +2729,9 @@ public:
getAnalysis<DominatorTreeWrapperPass>().getDomTree(),
getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F),
getAnalysis<AAResultsWrapperPass>().getAAResults(),
- NoMemDepAnalysis
- ? nullptr
- : &getAnalysis<MemoryDependenceWrapperPass>().getMemDep(),
+ Impl.isMemDepEnabled()
+ ? &getAnalysis<MemoryDependenceWrapperPass>().getMemDep()
+ : nullptr,
LIWP ? &LIWP->getLoopInfo() : nullptr,
&getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE());
}
@@ -2710,7 +2741,7 @@ public:
AU.addRequired<DominatorTreeWrapperPass>();
AU.addRequired<TargetLibraryInfoWrapperPass>();
AU.addRequired<LoopInfoWrapperPass>();
- if (!NoMemDepAnalysis)
+ if (Impl.isMemDepEnabled())
AU.addRequired<MemoryDependenceWrapperPass>();
AU.addRequired<AAResultsWrapperPass>();
@@ -2718,12 +2749,10 @@ public:
AU.addPreserved<GlobalsAAWrapperPass>();
AU.addPreserved<TargetLibraryInfoWrapperPass>();
AU.addPreserved<LoopInfoWrapperPass>();
- AU.addPreservedID(LoopSimplifyID);
AU.addRequired<OptimizationRemarkEmitterWrapperPass>();
}
private:
- bool NoMemDepAnalysis;
GVN Impl;
};
diff --git a/llvm/lib/Transforms/Scalar/GVNHoist.cpp b/llvm/lib/Transforms/Scalar/GVNHoist.cpp
index e1796f6bf05a..9c4cdf2feb56 100644
--- a/llvm/lib/Transforms/Scalar/GVNHoist.cpp
+++ b/llvm/lib/Transforms/Scalar/GVNHoist.cpp
@@ -890,18 +890,16 @@ private:
void updateAlignment(Instruction *I, Instruction *Repl) {
if (auto *ReplacementLoad = dyn_cast<LoadInst>(Repl)) {
- ReplacementLoad->setAlignment(MaybeAlign(std::min(
- ReplacementLoad->getAlignment(), cast<LoadInst>(I)->getAlignment())));
+ ReplacementLoad->setAlignment(
+ std::min(ReplacementLoad->getAlign(), cast<LoadInst>(I)->getAlign()));
++NumLoadsRemoved;
} else if (auto *ReplacementStore = dyn_cast<StoreInst>(Repl)) {
- ReplacementStore->setAlignment(
- MaybeAlign(std::min(ReplacementStore->getAlignment(),
- cast<StoreInst>(I)->getAlignment())));
+ ReplacementStore->setAlignment(std::min(ReplacementStore->getAlign(),
+ cast<StoreInst>(I)->getAlign()));
++NumStoresRemoved;
} else if (auto *ReplacementAlloca = dyn_cast<AllocaInst>(Repl)) {
- ReplacementAlloca->setAlignment(
- MaybeAlign(std::max(ReplacementAlloca->getAlignment(),
- cast<AllocaInst>(I)->getAlignment())));
+ ReplacementAlloca->setAlignment(std::max(
+ ReplacementAlloca->getAlign(), cast<AllocaInst>(I)->getAlign()));
} else if (isa<CallInst>(Repl)) {
++NumCallsRemoved;
}
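
The updateAlignment changes above move from raw getAlignment() integers to the Align API, but the merging rule is unchanged: a replacement load or store may only claim the minimum of the two alignments, while a shared alloca must satisfy its strictest user and so takes the maximum. A small plain-C++ sketch of that rule (hypothetical structs, not the LLVM classes):

#include <algorithm>
#include <cassert>
#include <cstdint>

struct MemAccess { uint64_t AlignBytes; };
struct Alloca    { uint64_t AlignBytes; };

static void mergeInto(MemAccess &Repl, const MemAccess &Removed) {
  // The survivor can only promise the weaker of the two guarantees.
  Repl.AlignBytes = std::min(Repl.AlignBytes, Removed.AlignBytes);
}
static void mergeInto(Alloca &Repl, const Alloca &Removed) {
  // A shared slot must meet the most demanding requirement.
  Repl.AlignBytes = std::max(Repl.AlignBytes, Removed.AlignBytes);
}

int main() {
  MemAccess Load{16}, HoistedOver{8};
  mergeInto(Load, HoistedOver);
  assert(Load.AlignBytes == 8);

  Alloca Slot{4}, OtherUse{16};
  mergeInto(Slot, OtherUse);
  assert(Slot.AlignBytes == 16);
  return 0;
}
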
diff --git a/llvm/lib/Transforms/Scalar/GVNSink.cpp b/llvm/lib/Transforms/Scalar/GVNSink.cpp
index 6d0a4975e266..dfb4b7e038ba 100644
--- a/llvm/lib/Transforms/Scalar/GVNSink.cpp
+++ b/llvm/lib/Transforms/Scalar/GVNSink.cpp
@@ -350,6 +350,7 @@ using ModelledPHISet = DenseSet<ModelledPHI, DenseMapInfo<ModelledPHI>>;
class InstructionUseExpr : public GVNExpression::BasicExpression {
unsigned MemoryUseOrder = -1;
bool Volatile = false;
+ ArrayRef<int> ShuffleMask;
public:
InstructionUseExpr(Instruction *I, ArrayRecycler<Value *> &R,
@@ -359,6 +360,9 @@ public:
setOpcode(I->getOpcode());
setType(I->getType());
+ if (ShuffleVectorInst *SVI = dyn_cast<ShuffleVectorInst>(I))
+ ShuffleMask = SVI->getShuffleMask().copy(A);
+
for (auto &U : I->uses())
op_push_back(U.getUser());
llvm::sort(op_begin(), op_end());
@@ -369,12 +373,12 @@ public:
hash_code getHashValue() const override {
return hash_combine(GVNExpression::BasicExpression::getHashValue(),
- MemoryUseOrder, Volatile);
+ MemoryUseOrder, Volatile, ShuffleMask);
}
template <typename Function> hash_code getHashValue(Function MapFn) {
- hash_code H =
- hash_combine(getOpcode(), getType(), MemoryUseOrder, Volatile);
+ hash_code H = hash_combine(getOpcode(), getType(), MemoryUseOrder, Volatile,
+ ShuffleMask);
for (auto *V : operands())
H = hash_combine(H, MapFn(V));
return H;
@@ -475,6 +479,7 @@ public:
case Instruction::PtrToInt:
case Instruction::IntToPtr:
case Instruction::BitCast:
+ case Instruction::AddrSpaceCast:
case Instruction::Select:
case Instruction::ExtractElement:
case Instruction::InsertElement:
@@ -576,7 +581,7 @@ public:
private:
ValueTable VN;
- bool isInstructionBlacklisted(Instruction *I) {
+ bool shouldAvoidSinkingInstruction(Instruction *I) {
// These instructions may change or break semantics if moved.
if (isa<PHINode>(I) || I->isEHPad() || isa<AllocaInst>(I) ||
I->getType()->isTokenTy())
@@ -668,7 +673,7 @@ Optional<SinkingInstructionCandidate> GVNSink::analyzeInstructionForSinking(
NewInsts.push_back(I);
}
for (auto *I : NewInsts)
- if (isInstructionBlacklisted(I))
+ if (shouldAvoidSinkingInstruction(I))
return None;
// If we've restricted the incoming blocks, restrict all needed PHIs also
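
The GVNSink diff above folds the shuffle mask into InstructionUseExpr's hash. The general point is that every field the sinking decision depends on must feed the hash, or expressions that differ only in that field hash identically, which defeats the purpose of distinguishing them. A hypothetical, self-contained illustration:

#include <cstddef>
#include <functional>
#include <vector>

struct Expr {
  unsigned Opcode;
  std::vector<int> Operands;
  std::vector<int> ShuffleMask; // empty for non-shuffle instructions
};

static std::size_t hashExpr(const Expr &E) {
  std::size_t H = std::hash<unsigned>()(E.Opcode);
  auto Mix = [&H](int V) {
    H ^= std::hash<int>()(V) + 0x9e3779b9 + (H << 6) + (H >> 2);
  };
  for (int Op : E.Operands)
    Mix(Op);
  for (int M : E.ShuffleMask)
    Mix(M); // without this loop the two expressions below hash identically
  return H;
}

int main() {
  Expr A{30, {1, 2}, {0, 2, 1, 3}};
  Expr B{30, {1, 2}, {0, 1, 2, 3}};
  return hashExpr(A) != hashExpr(B) ? 0 : 1;
}
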
diff --git a/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp b/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp
index d8d7acae5c9f..0f36c3f772e6 100644
--- a/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp
+++ b/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp
@@ -38,8 +38,9 @@
#include "llvm/ADT/iterator_range.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/LoopPass.h"
+#include "llvm/Analysis/MemorySSA.h"
+#include "llvm/Analysis/MemorySSAUpdater.h"
#include "llvm/Analysis/ScalarEvolution.h"
-#include "llvm/Analysis/ScalarEvolutionExpander.h"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
@@ -81,6 +82,7 @@
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/Local.h"
#include "llvm/Transforms/Utils/LoopUtils.h"
+#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
#include "llvm/Transforms/Utils/SimplifyIndVar.h"
#include <cassert>
#include <cstdint>
@@ -100,10 +102,10 @@ STATISTIC(NumElimIV , "Number of congruent IVs eliminated");
// implement a strong expression equivalence checker in SCEV. Until then, we
// use the verify-indvars flag, which may assert in some cases.
static cl::opt<bool> VerifyIndvars(
- "verify-indvars", cl::Hidden,
- cl::desc("Verify the ScalarEvolution result after running indvars"));
-
-enum ReplaceExitVal { NeverRepl, OnlyCheapRepl, NoHardUse, AlwaysRepl };
+ "verify-indvars", cl::Hidden,
+ cl::desc("Verify the ScalarEvolution result after running indvars. Has no "
+ "effect in release builds. (Note: this adds additional SCEV "
+ "queries potentially changing the analysis result)"));
static cl::opt<ReplaceExitVal> ReplaceExitValue(
"replexitval", cl::Hidden, cl::init(OnlyCheapRepl),
@@ -140,11 +142,10 @@ class IndVarSimplify {
const DataLayout &DL;
TargetLibraryInfo *TLI;
const TargetTransformInfo *TTI;
+ std::unique_ptr<MemorySSAUpdater> MSSAU;
SmallVector<WeakTrackingVH, 16> DeadInsts;
- bool isValidRewrite(Value *FromVal, Value *ToVal);
-
bool handleFloatingPointIV(Loop *L, PHINode *PH);
bool rewriteNonIntegerIVs(Loop *L);
@@ -155,10 +156,7 @@ class IndVarSimplify {
/// iterations of the loop run when that is unobservable.
bool predicateLoopExits(Loop *L, SCEVExpander &Rewriter);
- bool canLoopBeDeleted(Loop *L, SmallVector<RewritePhi, 8> &RewritePhiSet);
- bool rewriteLoopExitValues(Loop *L, SCEVExpander &Rewriter);
bool rewriteFirstIterationLoopExitValues(Loop *L);
- bool hasHardUserWithinLoop(const Loop *L, const Instruction *I) const;
bool linearFunctionTestReplace(Loop *L, BasicBlock *ExitingBB,
const SCEV *ExitCount,
@@ -169,66 +167,17 @@ class IndVarSimplify {
public:
IndVarSimplify(LoopInfo *LI, ScalarEvolution *SE, DominatorTree *DT,
const DataLayout &DL, TargetLibraryInfo *TLI,
- TargetTransformInfo *TTI)
- : LI(LI), SE(SE), DT(DT), DL(DL), TLI(TLI), TTI(TTI) {}
+ TargetTransformInfo *TTI, MemorySSA *MSSA)
+ : LI(LI), SE(SE), DT(DT), DL(DL), TLI(TLI), TTI(TTI) {
+ if (MSSA)
+ MSSAU = std::make_unique<MemorySSAUpdater>(MSSA);
+ }
bool run(Loop *L);
};
} // end anonymous namespace
-/// Return true if the SCEV expansion generated by the rewriter can replace the
-/// original value. SCEV guarantees that it produces the same value, but the way
-/// it is produced may be illegal IR. Ideally, this function will only be
-/// called for verification.
-bool IndVarSimplify::isValidRewrite(Value *FromVal, Value *ToVal) {
- // If an SCEV expression subsumed multiple pointers, its expansion could
- // reassociate the GEP changing the base pointer. This is illegal because the
- // final address produced by a GEP chain must be inbounds relative to its
- // underlying object. Otherwise basic alias analysis, among other things,
- // could fail in a dangerous way. Ultimately, SCEV will be improved to avoid
- // producing an expression involving multiple pointers. Until then, we must
- // bail out here.
- //
- // Retrieve the pointer operand of the GEP. Don't use GetUnderlyingObject
- // because it understands lcssa phis while SCEV does not.
- Value *FromPtr = FromVal;
- Value *ToPtr = ToVal;
- if (auto *GEP = dyn_cast<GEPOperator>(FromVal)) {
- FromPtr = GEP->getPointerOperand();
- }
- if (auto *GEP = dyn_cast<GEPOperator>(ToVal)) {
- ToPtr = GEP->getPointerOperand();
- }
- if (FromPtr != FromVal || ToPtr != ToVal) {
- // Quickly check the common case
- if (FromPtr == ToPtr)
- return true;
-
- // SCEV may have rewritten an expression that produces the GEP's pointer
- // operand. That's ok as long as the pointer operand has the same base
- // pointer. Unlike GetUnderlyingObject(), getPointerBase() will find the
- // base of a recurrence. This handles the case in which SCEV expansion
- // converts a pointer type recurrence into a nonrecurrent pointer base
- // indexed by an integer recurrence.
-
- // If the GEP base pointer is a vector of pointers, abort.
- if (!FromPtr->getType()->isPointerTy() || !ToPtr->getType()->isPointerTy())
- return false;
-
- const SCEV *FromBase = SE->getPointerBase(SE->getSCEV(FromPtr));
- const SCEV *ToBase = SE->getPointerBase(SE->getSCEV(ToPtr));
- if (FromBase == ToBase)
- return true;
-
- LLVM_DEBUG(dbgs() << "INDVARS: GEP rewrite bail out " << *FromBase
- << " != " << *ToBase << "\n");
-
- return false;
- }
- return true;
-}
-
/// Determine the insertion point for this user. By default, insert immediately
/// before the user. SCEVExpander or LICM will hoist loop invariants out of the
/// loop. For PHI nodes, there may be multiple uses, so compute the nearest
@@ -477,11 +426,11 @@ bool IndVarSimplify::handleFloatingPointIV(Loop *L, PHINode *PN) {
// new comparison.
NewCompare->takeName(Compare);
Compare->replaceAllUsesWith(NewCompare);
- RecursivelyDeleteTriviallyDeadInstructions(Compare, TLI);
+ RecursivelyDeleteTriviallyDeadInstructions(Compare, TLI, MSSAU.get());
// Delete the old floating point increment.
Incr->replaceAllUsesWith(UndefValue::get(Incr->getType()));
- RecursivelyDeleteTriviallyDeadInstructions(Incr, TLI);
+ RecursivelyDeleteTriviallyDeadInstructions(Incr, TLI, MSSAU.get());
// If the FP induction variable still has uses, this is because something else
// in the loop uses its value. In order to canonicalize the induction
@@ -494,7 +443,7 @@ bool IndVarSimplify::handleFloatingPointIV(Loop *L, PHINode *PN) {
Value *Conv = new SIToFPInst(NewPHI, PN->getType(), "indvar.conv",
&*PN->getParent()->getFirstInsertionPt());
PN->replaceAllUsesWith(Conv);
- RecursivelyDeleteTriviallyDeadInstructions(PN, TLI);
+ RecursivelyDeleteTriviallyDeadInstructions(PN, TLI, MSSAU.get());
}
return true;
}
@@ -522,222 +471,6 @@ bool IndVarSimplify::rewriteNonIntegerIVs(Loop *L) {
return Changed;
}
-namespace {
-
-// Collect information about PHI nodes which can be transformed in
-// rewriteLoopExitValues.
-struct RewritePhi {
- PHINode *PN;
-
- // Ith incoming value.
- unsigned Ith;
-
- // Exit value after expansion.
- Value *Val;
-
- // High Cost when expansion.
- bool HighCost;
-
- RewritePhi(PHINode *P, unsigned I, Value *V, bool H)
- : PN(P), Ith(I), Val(V), HighCost(H) {}
-};
-
-} // end anonymous namespace
-
-//===----------------------------------------------------------------------===//
-// rewriteLoopExitValues - Optimize IV users outside the loop.
-// As a side effect, reduces the amount of IV processing within the loop.
-//===----------------------------------------------------------------------===//
-
-bool IndVarSimplify::hasHardUserWithinLoop(const Loop *L, const Instruction *I) const {
- SmallPtrSet<const Instruction *, 8> Visited;
- SmallVector<const Instruction *, 8> WorkList;
- Visited.insert(I);
- WorkList.push_back(I);
- while (!WorkList.empty()) {
- const Instruction *Curr = WorkList.pop_back_val();
- // This use is outside the loop, nothing to do.
- if (!L->contains(Curr))
- continue;
- // Do we assume it is a "hard" use which will not be eliminated easily?
- if (Curr->mayHaveSideEffects())
- return true;
- // Otherwise, add all its users to worklist.
- for (auto U : Curr->users()) {
- auto *UI = cast<Instruction>(U);
- if (Visited.insert(UI).second)
- WorkList.push_back(UI);
- }
- }
- return false;
-}
-
-/// Check to see if this loop has a computable loop-invariant execution count.
-/// If so, this means that we can compute the final value of any expressions
-/// that are recurrent in the loop, and substitute the exit values from the loop
-/// into any instructions outside of the loop that use the final values of the
-/// current expressions.
-///
-/// This is mostly redundant with the regular IndVarSimplify activities that
-/// happen later, except that it's more powerful in some cases, because it's
-/// able to brute-force evaluate arbitrary instructions as long as they have
-/// constant operands at the beginning of the loop.
-bool IndVarSimplify::rewriteLoopExitValues(Loop *L, SCEVExpander &Rewriter) {
- // Check a pre-condition.
- assert(L->isRecursivelyLCSSAForm(*DT, *LI) &&
- "Indvars did not preserve LCSSA!");
-
- SmallVector<BasicBlock*, 8> ExitBlocks;
- L->getUniqueExitBlocks(ExitBlocks);
-
- SmallVector<RewritePhi, 8> RewritePhiSet;
- // Find all values that are computed inside the loop, but used outside of it.
- // Because of LCSSA, these values will only occur in LCSSA PHI Nodes. Scan
- // the exit blocks of the loop to find them.
- for (BasicBlock *ExitBB : ExitBlocks) {
- // If there are no PHI nodes in this exit block, then no values defined
- // inside the loop are used on this path, skip it.
- PHINode *PN = dyn_cast<PHINode>(ExitBB->begin());
- if (!PN) continue;
-
- unsigned NumPreds = PN->getNumIncomingValues();
-
- // Iterate over all of the PHI nodes.
- BasicBlock::iterator BBI = ExitBB->begin();
- while ((PN = dyn_cast<PHINode>(BBI++))) {
- if (PN->use_empty())
- continue; // dead use, don't replace it
-
- if (!SE->isSCEVable(PN->getType()))
- continue;
-
- // It's necessary to tell ScalarEvolution about this explicitly so that
- // it can walk the def-use list and forget all SCEVs, as it may not be
- // watching the PHI itself. Once the new exit value is in place, there
- // may not be a def-use connection between the loop and every instruction
- // which got a SCEVAddRecExpr for that loop.
- SE->forgetValue(PN);
-
- // Iterate over all of the values in all the PHI nodes.
- for (unsigned i = 0; i != NumPreds; ++i) {
- // If the value being merged in is not integer or is not defined
- // in the loop, skip it.
- Value *InVal = PN->getIncomingValue(i);
- if (!isa<Instruction>(InVal))
- continue;
-
- // If this pred is for a subloop, not L itself, skip it.
- if (LI->getLoopFor(PN->getIncomingBlock(i)) != L)
- continue; // The Block is in a subloop, skip it.
-
- // Check that InVal is defined in the loop.
- Instruction *Inst = cast<Instruction>(InVal);
- if (!L->contains(Inst))
- continue;
-
- // Okay, this instruction has a user outside of the current loop
- // and varies predictably *inside* the loop. Evaluate the value it
- // contains when the loop exits, if possible. We prefer to start with
- // expressions which are true for all exits (so as to maximize
- // expression reuse by the SCEVExpander), but resort to per-exit
- // evaluation if that fails.
- const SCEV *ExitValue = SE->getSCEVAtScope(Inst, L->getParentLoop());
- if (isa<SCEVCouldNotCompute>(ExitValue) ||
- !SE->isLoopInvariant(ExitValue, L) ||
- !isSafeToExpand(ExitValue, *SE)) {
- // TODO: This should probably be sunk into SCEV in some way; maybe a
- // getSCEVForExit(SCEV*, L, ExitingBB)? It can be generalized for
- // most SCEV expressions and other recurrence types (e.g. shift
- // recurrences). Is there existing code we can reuse?
- const SCEV *ExitCount = SE->getExitCount(L, PN->getIncomingBlock(i));
- if (isa<SCEVCouldNotCompute>(ExitCount))
- continue;
- if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(SE->getSCEV(Inst)))
- if (AddRec->getLoop() == L)
- ExitValue = AddRec->evaluateAtIteration(ExitCount, *SE);
- if (isa<SCEVCouldNotCompute>(ExitValue) ||
- !SE->isLoopInvariant(ExitValue, L) ||
- !isSafeToExpand(ExitValue, *SE))
- continue;
- }
-
- // Computing the value outside of the loop brings no benefit if it is
- // definitely used inside the loop in a way which can not be optimized
- // away. Avoid doing so unless we know we have a value which computes
- // the ExitValue already. TODO: This should be merged into SCEV
- // expander to leverage its knowledge of existing expressions.
- if (ReplaceExitValue != AlwaysRepl &&
- !isa<SCEVConstant>(ExitValue) && !isa<SCEVUnknown>(ExitValue) &&
- hasHardUserWithinLoop(L, Inst))
- continue;
-
- bool HighCost = Rewriter.isHighCostExpansion(ExitValue, L, Inst);
- Value *ExitVal = Rewriter.expandCodeFor(ExitValue, PN->getType(), Inst);
-
- LLVM_DEBUG(dbgs() << "INDVARS: RLEV: AfterLoopVal = " << *ExitVal
- << '\n'
- << " LoopVal = " << *Inst << "\n");
-
- if (!isValidRewrite(Inst, ExitVal)) {
- DeadInsts.push_back(ExitVal);
- continue;
- }
-
-#ifndef NDEBUG
- // If we reuse an instruction from a loop which is neither L nor one of
- // its containing loops, we end up breaking LCSSA form for this loop by
- // creating a new use of its instruction.
- if (auto *ExitInsn = dyn_cast<Instruction>(ExitVal))
- if (auto *EVL = LI->getLoopFor(ExitInsn->getParent()))
- if (EVL != L)
- assert(EVL->contains(L) && "LCSSA breach detected!");
-#endif
-
- // Collect all the candidate PHINodes to be rewritten.
- RewritePhiSet.emplace_back(PN, i, ExitVal, HighCost);
- }
- }
- }
-
- bool LoopCanBeDel = canLoopBeDeleted(L, RewritePhiSet);
-
- bool Changed = false;
- // Transformation.
- for (const RewritePhi &Phi : RewritePhiSet) {
- PHINode *PN = Phi.PN;
- Value *ExitVal = Phi.Val;
-
- // Only do the rewrite when the ExitValue can be expanded cheaply.
- // If LoopCanBeDel is true, rewrite exit value aggressively.
- if (ReplaceExitValue == OnlyCheapRepl && !LoopCanBeDel && Phi.HighCost) {
- DeadInsts.push_back(ExitVal);
- continue;
- }
-
- Changed = true;
- ++NumReplaced;
- Instruction *Inst = cast<Instruction>(PN->getIncomingValue(Phi.Ith));
- PN->setIncomingValue(Phi.Ith, ExitVal);
-
- // If this instruction is dead now, delete it. Don't do it now to avoid
- // invalidating iterators.
- if (isInstructionTriviallyDead(Inst, TLI))
- DeadInsts.push_back(Inst);
-
- // Replace PN with ExitVal if that is legal and does not break LCSSA.
- if (PN->getNumIncomingValues() == 1 &&
- LI->replacementPreservesLCSSAForm(PN, ExitVal)) {
- PN->replaceAllUsesWith(ExitVal);
- PN->eraseFromParent();
- }
- }
-
- // The insertion point instruction may have been deleted; clear it out
- // so that the rewriter doesn't trip over it later.
- Rewriter.clearInsertPoint();
- return Changed;
-}
-
//===---------------------------------------------------------------------===//
// rewriteFirstIterationLoopExitValues: Rewrite loop exit values if we know
// they will exit at the first iteration.
@@ -813,61 +546,6 @@ bool IndVarSimplify::rewriteFirstIterationLoopExitValues(Loop *L) {
return MadeAnyChanges;
}
-/// Check whether it is possible to delete the loop after rewriting exit
-/// value. If it is possible, ignore ReplaceExitValue and do rewriting
-/// aggressively.
-bool IndVarSimplify::canLoopBeDeleted(
- Loop *L, SmallVector<RewritePhi, 8> &RewritePhiSet) {
- BasicBlock *Preheader = L->getLoopPreheader();
- // If there is no preheader, the loop will not be deleted.
- if (!Preheader)
- return false;
-
- // In LoopDeletion pass Loop can be deleted when ExitingBlocks.size() > 1.
- // We obviate multiple ExitingBlocks case for simplicity.
- // TODO: If we see testcase with multiple ExitingBlocks can be deleted
- // after exit value rewriting, we can enhance the logic here.
- SmallVector<BasicBlock *, 4> ExitingBlocks;
- L->getExitingBlocks(ExitingBlocks);
- SmallVector<BasicBlock *, 8> ExitBlocks;
- L->getUniqueExitBlocks(ExitBlocks);
- if (ExitBlocks.size() != 1 || ExitingBlocks.size() != 1)
- return false;
-
- BasicBlock *ExitBlock = ExitBlocks[0];
- BasicBlock::iterator BI = ExitBlock->begin();
- while (PHINode *P = dyn_cast<PHINode>(BI)) {
- Value *Incoming = P->getIncomingValueForBlock(ExitingBlocks[0]);
-
- // If the Incoming value of P is found in RewritePhiSet, we know it
- // could be rewritten to use a loop invariant value in transformation
- // phase later. Skip it in the loop invariant check below.
- bool found = false;
- for (const RewritePhi &Phi : RewritePhiSet) {
- unsigned i = Phi.Ith;
- if (Phi.PN == P && (Phi.PN)->getIncomingValue(i) == Incoming) {
- found = true;
- break;
- }
- }
-
- Instruction *I;
- if (!found && (I = dyn_cast<Instruction>(Incoming)))
- if (!L->hasLoopInvariantOperands(I))
- return false;
-
- ++BI;
- }
-
- for (auto *BB : L->blocks())
- if (llvm::any_of(*BB, [](Instruction &I) {
- return I.mayHaveSideEffects();
- }))
- return false;
-
- return true;
-}
-
//===----------------------------------------------------------------------===//
// IV Widening - Extend the width of an IV to cover its widest uses.
//===----------------------------------------------------------------------===//
@@ -1060,8 +738,8 @@ protected:
Instruction *widenIVUse(NarrowIVDefUse DU, SCEVExpander &Rewriter);
bool widenLoopCompare(NarrowIVDefUse DU);
- bool widenWithVariantLoadUse(NarrowIVDefUse DU);
- void widenWithVariantLoadUseCodegen(NarrowIVDefUse DU);
+ bool widenWithVariantUse(NarrowIVDefUse DU);
+ void widenWithVariantUseCodegen(NarrowIVDefUse DU);
void pushNarrowIVUsers(Instruction *NarrowDef, Instruction *WideDef);
};
@@ -1399,20 +1077,27 @@ bool WidenIV::widenLoopCompare(NarrowIVDefUse DU) {
return true;
}
-/// If the narrow use is an instruction whose two operands are the defining
-/// instruction of DU and a load instruction, then we have the following:
-/// if the load is hoisted outside the loop, then we do not reach this function
-/// as scalar evolution analysis works fine in widenIVUse with variables
-/// hoisted outside the loop and efficient code is subsequently generated by
-/// not emitting truncate instructions. But when the load is not hoisted
-/// (whether due to limitation in alias analysis or due to a true legality),
-/// then scalar evolution can not proceed with loop variant values and
-/// inefficient code is generated. This function handles the non-hoisted load
-/// special case by making the optimization generate the same type of code for
-/// hoisted and non-hoisted load (widen use and eliminate sign extend
-/// instruction). This special case is important especially when the induction
-/// variables are affecting addressing mode in code generation.
-bool WidenIV::widenWithVariantLoadUse(NarrowIVDefUse DU) {
+// widenIVUse avoids generating a trunc by evaluating the use as an AddRec;
+// this will not work when:
+// 1) SCEV traces back to an instruction inside the loop that SCEV cannot
+// expand, e.g. add %indvar, (load %addr)
+// 2) SCEV finds a loop variant, e.g. add %indvar, %loopvariant
+// While SCEV fails to avoid the trunc, we can still try an instruction
+// combining approach to prove that the trunc is not required. This can be
+// further extended with other instruction combining checks, but for now we
+// handle the following case ("sub" can be "add" or "mul", and "nsw + sext"
+// can be "nuw + zext"):
+//
+// Src:
+// %c = sub nsw %b, %indvar
+// %d = sext %c to i64
+// Dst:
+// %indvar.ext1 = sext %indvar to i64
+// %m = sext %b to i64
+// %d = sub nsw i64 %m, %indvar.ext1
+// Therefore, as long as the result of the add/sub/mul is extended to the wide
+// type, no trunc is required regardless of how %b is generated. This pattern
+// is common when calculating addresses on 64-bit architectures.
+bool WidenIV::widenWithVariantUse(NarrowIVDefUse DU) {
Instruction *NarrowUse = DU.NarrowUse;
Instruction *NarrowDef = DU.NarrowDef;
Instruction *WideDef = DU.WideDef;
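
The Src/Dst comment above relies on the fact that sign extension distributes over an nsw subtraction, so widening the operands first gives the same 64-bit result as subtracting narrow and then extending. A quick numeric check of that identity in plain C++ (concrete values chosen so the i32 subtraction cannot overflow):

#include <cassert>
#include <cstdint>

int main() {
  int32_t B = 100000, IndVar = 12345;       // %b and %indvar, no signed overflow
  int32_t C = B - IndVar;                   // %c = sub nsw i32 %b, %indvar
  int64_t NarrowThenExt = (int64_t)C;       // %d = sext i32 %c to i64
  int64_t Widened = (int64_t)B - (int64_t)IndVar; // sub nsw i64 %m, %indvar.ext1
  assert(NarrowThenExt == Widened);         // identical: no trunc required
  return 0;
}
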
@@ -1443,12 +1128,6 @@ bool WidenIV::widenWithVariantLoadUse(NarrowIVDefUse DU) {
else
return false;
- // We are interested in the other operand being a load instruction.
- // But, we should look into relaxing this restriction later on.
- auto *I = dyn_cast<Instruction>(NarrowUse->getOperand(ExtendOperIdx));
- if (I && I->getOpcode() != Instruction::Load)
- return false;
-
// Verifying that Defining operand is an AddRec
const SCEV *Op1 = SE->getSCEV(WideDef);
const SCEVAddRecExpr *AddRecOp1 = dyn_cast<SCEVAddRecExpr>(Op1);
@@ -1480,9 +1159,9 @@ bool WidenIV::widenWithVariantLoadUse(NarrowIVDefUse DU) {
return true;
}
-/// Special Case for widening with variant Loads (see
-/// WidenIV::widenWithVariantLoadUse). This is the code generation part.
-void WidenIV::widenWithVariantLoadUseCodegen(NarrowIVDefUse DU) {
+/// Special case for widening with a loop-variant operand (see
+/// WidenIV::widenWithVariantUse). This is the code generation part.
+void WidenIV::widenWithVariantUseCodegen(NarrowIVDefUse DU) {
Instruction *NarrowUse = DU.NarrowUse;
Instruction *NarrowDef = DU.NarrowDef;
Instruction *WideDef = DU.WideDef;
@@ -1508,33 +1187,22 @@ void WidenIV::widenWithVariantLoadUseCodegen(NarrowIVDefUse DU) {
Builder.Insert(WideBO);
WideBO->copyIRFlags(NarrowBO);
- if (ExtKind == SignExtended)
- ExtendKindMap[NarrowUse] = SignExtended;
- else
- ExtendKindMap[NarrowUse] = ZeroExtended;
+ assert(ExtKind != Unknown && "Unknown ExtKind not handled");
- // Update the Use.
- if (ExtKind == SignExtended) {
- for (Use &U : NarrowUse->uses()) {
- SExtInst *User = dyn_cast<SExtInst>(U.getUser());
- if (User && User->getType() == WideType) {
- LLVM_DEBUG(dbgs() << "INDVARS: eliminating " << *User << " replaced by "
- << *WideBO << "\n");
- ++NumElimExt;
- User->replaceAllUsesWith(WideBO);
- DeadInsts.emplace_back(User);
- }
- }
- } else { // ExtKind == ZeroExtended
- for (Use &U : NarrowUse->uses()) {
- ZExtInst *User = dyn_cast<ZExtInst>(U.getUser());
- if (User && User->getType() == WideType) {
- LLVM_DEBUG(dbgs() << "INDVARS: eliminating " << *User << " replaced by "
- << *WideBO << "\n");
- ++NumElimExt;
- User->replaceAllUsesWith(WideBO);
- DeadInsts.emplace_back(User);
- }
+ ExtendKindMap[NarrowUse] = ExtKind;
+
+ for (Use &U : NarrowUse->uses()) {
+ Instruction *User = nullptr;
+ if (ExtKind == SignExtended)
+ User = dyn_cast<SExtInst>(U.getUser());
+ else
+ User = dyn_cast<ZExtInst>(U.getUser());
+ if (User && User->getType() == WideType) {
+ LLVM_DEBUG(dbgs() << "INDVARS: eliminating " << *User << " replaced by "
+ << *WideBO << "\n");
+ ++NumElimExt;
+ User->replaceAllUsesWith(WideBO);
+ DeadInsts.emplace_back(User);
}
}
}
@@ -1641,8 +1309,8 @@ Instruction *WidenIV::widenIVUse(NarrowIVDefUse DU, SCEVExpander &Rewriter) {
// in WideAddRec.first does not indicate a polynomial induction expression.
// In that case, look at the operands of the use instruction to determine
// if we can still widen the use instead of truncating its operand.
- if (widenWithVariantLoadUse(DU)) {
- widenWithVariantLoadUseCodegen(DU);
+ if (widenWithVariantUse(DU)) {
+ widenWithVariantUseCodegen(DU);
return nullptr;
}
@@ -1992,8 +1660,8 @@ bool IndVarSimplify::simplifyAndExtend(Loop *L,
// Information about sign/zero extensions of CurrIV.
IndVarSimplifyVisitor Visitor(CurrIV, SE, TTI, DT);
- Changed |=
- simplifyUsersOfIV(CurrIV, SE, DT, LI, DeadInsts, Rewriter, &Visitor);
+ Changed |= simplifyUsersOfIV(CurrIV, SE, DT, LI, TTI, DeadInsts, Rewriter,
+ &Visitor);
if (Visitor.WI.WidestNativeType) {
WideIVs.push_back(Visitor.WI);
@@ -2017,7 +1685,7 @@ bool IndVarSimplify::simplifyAndExtend(Loop *L,
/// Given a Value which is hoped to be part of an add recurrence in the given
/// loop, return the associated Phi node if so. Otherwise, return null. Note
-/// that this is less general than SCEVs AddRec checking.
+/// that this is less general than SCEV's AddRec checking.
static PHINode *getLoopPhiForCounter(Value *IncV, Loop *L) {
Instruction *IncI = dyn_cast<Instruction>(IncV);
if (!IncI)
@@ -2079,7 +1747,7 @@ static bool needsLFTR(Loop *L, BasicBlock *ExitingBB) {
BranchInst *BI = cast<BranchInst>(ExitingBB->getTerminator());
if (L->isLoopInvariant(BI->getCondition()))
return false;
-
+
// Do LFTR to simplify the exit condition to an ICMP.
ICmpInst *Cond = dyn_cast<ICmpInst>(BI->getCondition());
if (!Cond)
@@ -2122,9 +1790,9 @@ static bool needsLFTR(Loop *L, BasicBlock *ExitingBB) {
/// actually poison. This can be used to assess whether a new use of Root can
/// be added at a location which is control equivalent with OnPathTo (such as
/// immediately before it) without introducing UB which didn't previously
-/// exist. Note that a false result conveys no information.
+/// exist. Note that a false result conveys no information.
static bool mustExecuteUBIfPoisonOnPathTo(Instruction *Root,
- Instruction *OnPathTo,
+ Instruction *OnPathTo,
DominatorTree *DT) {
// Basic approach is to assume Root is poison, propagate poison forward
// through all users we can easily track, and then check whether any of those
@@ -2142,10 +1810,10 @@ static bool mustExecuteUBIfPoisonOnPathTo(Instruction *Root,
// If we know this must trigger UB on a path leading our target.
if (mustTriggerUB(I, KnownPoison) && DT->dominates(I, OnPathTo))
return true;
-
+
// If we can't analyze propagation through this instruction, just skip it
// and transitive users. Safe as false is a conservative result.
- if (!propagatesFullPoison(I) && I != Root)
+ if (!propagatesPoison(I) && I != Root)
continue;
if (KnownPoison.insert(I).second)
@@ -2154,7 +1822,7 @@ static bool mustExecuteUBIfPoisonOnPathTo(Instruction *Root,
}
// Might be non-UB, or might have a path we couldn't prove must execute on
- // way to exiting bb.
+ // way to exiting bb.
return false;
}
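The approach described above is a small forward walk over the use graph; the sketch below restates only that worklist shape on a toy node type, where PropagatesPoison, TriggersUB, and OnPathToTarget stand in for propagatesPoison, mustTriggerUB, and the dominance query. It is an illustration, not LLVM code.

// Illustrative sketch of the worklist in mustExecuteUBIfPoisonOnPathTo.
#include <set>
#include <vector>

struct Node {
  bool PropagatesPoison;        // stands in for propagatesPoison(I)
  bool TriggersUB;              // stands in for mustTriggerUB(I, KnownPoison)
  bool OnPathToTarget;          // stands in for DT->dominates(I, OnPathTo)
  std::vector<Node *> Users;
};

static bool mustExecuteUBIfPoison(Node *Root) {
  std::set<Node *> KnownPoison = {Root};
  std::vector<Node *> Worklist = {Root};
  while (!Worklist.empty()) {
    Node *N = Worklist.back();
    Worklist.pop_back();
    // A known-poison value reaching an operation that must trigger UB on a
    // path to the target proves the claim.
    if (N->TriggersUB && N->OnPathToTarget)
      return true;
    // If poison does not propagate through N, skip it and its transitive
    // users; returning false later is the conservative answer.
    if (!N->PropagatesPoison && N != Root)
      continue;
    for (Node *U : N->Users)
      if (KnownPoison.insert(U).second)
        Worklist.push_back(U);
  }
  return false; // a false result conveys no information
}

int main() {
  Node Div = {false, true, true, {}};      // e.g. a division that faults on poison
  Node Add = {true, false, false, {&Div}}; // poison propagates through the add
  return mustExecuteUBIfPoison(&Add) ? 0 : 1;
}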
@@ -2221,7 +1889,7 @@ static bool isLoopCounter(PHINode* Phi, Loop *L,
ScalarEvolution *SE) {
assert(Phi->getParent() == L->getHeader());
assert(L->getLoopLatch());
-
+
if (!SE->isSCEVable(Phi->getType()))
return false;
@@ -2282,7 +1950,7 @@ static PHINode *FindLoopCounter(Loop *L, BasicBlock *ExitingBB,
if (!hasConcreteDef(Phi)) {
// We explicitly allow unknown phis as long as they are already used by
// the loop exit test. This is legal since performing LFTR could not
- // increase the number of undef users.
+ // increase the number of undef users.
Value *IncPhi = Phi->getIncomingValueForBlock(LatchBlock);
if (!isLoopExitTestBasedOn(Phi, ExitingBB) &&
!isLoopExitTestBasedOn(IncPhi, ExitingBB))
@@ -2300,7 +1968,7 @@ static PHINode *FindLoopCounter(Loop *L, BasicBlock *ExitingBB,
if (!Phi->getType()->isIntegerTy() &&
!mustExecuteUBIfPoisonOnPathTo(Phi, ExitingBB->getTerminator(), DT))
continue;
-
+
const SCEV *Init = AR->getStart();
if (BestPhi && !AlmostDeadIV(BestPhi, LatchBlock, Cond)) {
@@ -2506,14 +2174,14 @@ linearFunctionTestReplace(Loop *L, BasicBlock *ExitingBB,
// reasoning as from SimplifyIndvar::eliminateTrunc to see if we can extend
// the other side of the comparison instead. We still evaluate the limit
// in the narrower bitwidth, we just prefer a zext/sext outside the loop to
- // a truncate within it.
+ // a truncate within it.
bool Extended = false;
const SCEV *IV = SE->getSCEV(CmpIndVar);
const SCEV *TruncatedIV = SE->getTruncateExpr(SE->getSCEV(CmpIndVar),
ExitCnt->getType());
const SCEV *ZExtTrunc =
SE->getZeroExtendExpr(TruncatedIV, CmpIndVar->getType());
-
+
if (ZExtTrunc == IV) {
Extended = true;
ExitCnt = Builder.CreateZExt(ExitCnt, IndVar->getType(),
@@ -2531,7 +2199,7 @@ linearFunctionTestReplace(Loop *L, BasicBlock *ExitingBB,
if (Extended) {
bool Discard;
L->makeLoopInvariant(ExitCnt, Discard);
- } else
+ } else
CmpIndVar = Builder.CreateTrunc(CmpIndVar, ExitCnt->getType(),
"lftr.wideiv");
}
@@ -2551,7 +2219,7 @@ linearFunctionTestReplace(Loop *L, BasicBlock *ExitingBB,
// update the branch to use the new comparison; in the common case this
// will make old comparison dead.
BI->setCondition(Cond);
- DeadInsts.push_back(OrigCond);
+ DeadInsts.emplace_back(OrigCond);
++NumLFTR;
return true;
@@ -2685,11 +2353,10 @@ bool IndVarSimplify::optimizeLoopExits(Loop *L, SCEVExpander &Rewriter) {
L->getExitingBlocks(ExitingBlocks);
// Remove all exits which aren't both rewriteable and analyzeable.
- auto NewEnd = llvm::remove_if(ExitingBlocks,
- [&](BasicBlock *ExitingBB) {
+ auto NewEnd = llvm::remove_if(ExitingBlocks, [&](BasicBlock *ExitingBB) {
// If our exiting block exits multiple loops, we can only rewrite the
// innermost one. Otherwise, we're changing how many times the innermost
- // loop runs before it exits.
+ // loop runs before it exits.
if (LI->getLoopFor(ExitingBB) != L)
return true;
@@ -2701,18 +2368,18 @@ bool IndVarSimplify::optimizeLoopExits(Loop *L, SCEVExpander &Rewriter) {
// If already constant, nothing to do.
if (isa<Constant>(BI->getCondition()))
return true;
-
+
const SCEV *ExitCount = SE->getExitCount(L, ExitingBB);
if (isa<SCEVCouldNotCompute>(ExitCount))
return true;
return false;
- });
+ });
ExitingBlocks.erase(NewEnd, ExitingBlocks.end());
if (ExitingBlocks.empty())
return false;
-
- // Get a symbolic upper bound on the loop backedge taken count.
+
+ // Get a symbolic upper bound on the loop backedge taken count.
const SCEV *MaxExitCount = getMaxBackedgeTakenCount(*SE, *DT, L);
if (isa<SCEVCouldNotCompute>(MaxExitCount))
return false;
@@ -2720,11 +2387,12 @@ bool IndVarSimplify::optimizeLoopExits(Loop *L, SCEVExpander &Rewriter) {
// Visit our exit blocks in order of dominance. We know from the fact that
// all exits (left) are analyzeable that there must be a total dominance order
// between them as each must dominate the latch. The visit order only
- // matters for the provably equal case.
+ // matters for the provably equal case.
llvm::sort(ExitingBlocks,
[&](BasicBlock *A, BasicBlock *B) {
// std::sort sorts in ascending order, so we want the inverse of
// the normal dominance relation.
+ if (A == B) return false;
if (DT->properlyDominates(A, B)) return true;
if (DT->properlyDominates(B, A)) return false;
llvm_unreachable("expected total dominance order!");
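The added `if (A == B) return false;` keeps the comparator a strict weak ordering: llvm::sort, like std::sort, requires comp(x, x) to be false, and without the early-out a self-comparison would fall through to the unreachable. A stand-alone illustration of that requirement:

// Minimal demonstration: a sort comparator must be irreflexive
// (comp(x, x) == false); the early-out mirrors the "A == B" guard above.
#include <algorithm>
#include <cassert>
#include <vector>

int main() {
  std::vector<int> Xs = {4, 1, 3, 1};
  std::sort(Xs.begin(), Xs.end(), [](int A, int B) {
    if (A == B)
      return false; // never claim "A sorts before A" (or before an equal B)
    return A < B;   // total order on distinct values
  });
  assert(std::is_sorted(Xs.begin(), Xs.end()));
  return 0;
}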
@@ -2734,7 +2402,7 @@ bool IndVarSimplify::optimizeLoopExits(Loop *L, SCEVExpander &Rewriter) {
assert(DT->dominates(ExitingBlocks[i-1], ExitingBlocks[i]));
}
#endif
-
+
auto FoldExit = [&](BasicBlock *ExitingBB, bool IsTaken) {
BranchInst *BI = cast<BranchInst>(ExitingBB->getTerminator());
bool ExitIfTrue = !L->contains(*succ_begin(ExitingBB));
@@ -2743,7 +2411,7 @@ bool IndVarSimplify::optimizeLoopExits(Loop *L, SCEVExpander &Rewriter) {
IsTaken ? ExitIfTrue : !ExitIfTrue);
BI->setCondition(NewCond);
if (OldCond->use_empty())
- DeadInsts.push_back(OldCond);
+ DeadInsts.emplace_back(OldCond);
};
bool Changed = false;
@@ -2751,7 +2419,7 @@ bool IndVarSimplify::optimizeLoopExits(Loop *L, SCEVExpander &Rewriter) {
for (BasicBlock *ExitingBB : ExitingBlocks) {
const SCEV *ExitCount = SE->getExitCount(L, ExitingBB);
assert(!isa<SCEVCouldNotCompute>(ExitCount) && "checked above");
-
+
// If we know we'd exit on the first iteration, rewrite the exit to
// reflect this. This does not imply the loop must exit through this
// exit; there may be an earlier one taken on the first iteration.
@@ -2769,13 +2437,13 @@ bool IndVarSimplify::optimizeLoopExits(Loop *L, SCEVExpander &Rewriter) {
if (!ExitCount->getType()->isIntegerTy() ||
!MaxExitCount->getType()->isIntegerTy())
continue;
-
+
Type *WiderType =
SE->getWiderType(MaxExitCount->getType(), ExitCount->getType());
ExitCount = SE->getNoopOrZeroExtend(ExitCount, WiderType);
MaxExitCount = SE->getNoopOrZeroExtend(MaxExitCount, WiderType);
assert(MaxExitCount->getType() == ExitCount->getType());
-
+
// Can we prove that some other exit must be taken strictly before this
// one?
if (SE->isLoopEntryGuardedByCond(L, CmpInst::ICMP_ULT,
@@ -2788,7 +2456,7 @@ bool IndVarSimplify::optimizeLoopExits(Loop *L, SCEVExpander &Rewriter) {
// As we run, keep track of which exit counts we've encountered. If we
// find a duplicate, we've found an exit which would have exited on the
// exiting iteration, but (from the visit order) strictly follows another
- // which does the same and is thus dead.
+ // which does the same and is thus dead.
if (!DominatingExitCounts.insert(ExitCount).second) {
FoldExit(ExitingBB, false);
Changed = true;
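Since DominatingExitCounts is only consulted through insert(...).second, the first exit seen with a given count claims it, and every later exit with the same count is folded as never taken, as the comment above describes. A toy version of that idiom, with strings standing in for the SCEV exit counts:

// Sketch: dominance-ordered exits; a repeated exit count marks a dead exit.
#include <cassert>
#include <set>
#include <string>
#include <vector>

int main() {
  std::vector<std::string> ExitCounts = {"n-1", "m", "n-1"}; // visit order
  std::set<std::string> Seen;
  std::vector<bool> DeadExit;
  for (const std::string &EC : ExitCounts)
    DeadExit.push_back(!Seen.insert(EC).second); // duplicate => provably dead
  assert(!DeadExit[0] && !DeadExit[1] && DeadExit[2]);
  return 0;
}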
@@ -2809,22 +2477,20 @@ bool IndVarSimplify::predicateLoopExits(Loop *L, SCEVExpander &Rewriter) {
SmallVector<BasicBlock*, 16> ExitingBlocks;
L->getExitingBlocks(ExitingBlocks);
- bool Changed = false;
-
// Finally, see if we can rewrite our exit conditions into a loop invariant
- // form. If we have a read-only loop, and we can tell that we must exit down
+ // form. If we have a read-only loop, and we can tell that we must exit down
// a path which does not need any of the values computed within the loop, we
// can rewrite the loop to exit on the first iteration. Note that this
// doesn't either a) tell us the loop exits on the first iteration (unless
// *all* exits are predicateable) or b) tell us *which* exit might be taken.
// This transformation looks a lot like a restricted form of dead loop
// elimination, but restricted to read-only loops and without necessarily
- // needing to kill the loop entirely.
+ // needing to kill the loop entirely.
if (!LoopPredication)
- return Changed;
+ return false;
if (!SE->hasLoopInvariantBackedgeTakenCount(L))
- return Changed;
+ return false;
// Note: ExactBTC is the exact backedge taken count *iff* the loop exits
// through *explicit* control flow. We have to eliminate the possibility of
@@ -2833,16 +2499,16 @@ bool IndVarSimplify::predicateLoopExits(Loop *L, SCEVExpander &Rewriter) {
if (isa<SCEVCouldNotCompute>(ExactBTC) ||
!SE->isLoopInvariant(ExactBTC, L) ||
!isSafeToExpand(ExactBTC, *SE))
- return Changed;
+ return false;
// If we end up with a pointer exit count, bail. It may be unsized.
if (!ExactBTC->getType()->isIntegerTy())
- return Changed;
+ return false;
auto BadExit = [&](BasicBlock *ExitingBB) {
// If our exiting block exits multiple loops, we can only rewrite the
// innermost one. Otherwise, we're changing how many times the innermost
- // loop runs before it exits.
+ // loop runs before it exits.
if (LI->getLoopFor(ExitingBB) != L)
return true;
@@ -2897,18 +2563,18 @@ bool IndVarSimplify::predicateLoopExits(Loop *L, SCEVExpander &Rewriter) {
// is complicated and we choose not to for now.
for (unsigned i = 1; i < ExitingBlocks.size(); i++)
if (!DT->dominates(ExitingBlocks[i-1], ExitingBlocks[i]))
- return Changed;
+ return false;
// Given our sorted total order, we know that exit[j] must be evaluated
// after all exit[i] such j > i.
for (unsigned i = 0, e = ExitingBlocks.size(); i < e; i++)
if (BadExit(ExitingBlocks[i])) {
- ExitingBlocks.resize(i);
+ ExitingBlocks.resize(i);
break;
}
if (ExitingBlocks.empty())
- return Changed;
+ return false;
// We rely on not being able to reach an exiting block on a later iteration
// than its statically computed exit count. The implementation of
@@ -2930,8 +2596,9 @@ bool IndVarSimplify::predicateLoopExits(Loop *L, SCEVExpander &Rewriter) {
for (auto &I : *BB)
// TODO:isGuaranteedToTransfer
if (I.mayHaveSideEffects() || I.mayThrow())
- return Changed;
+ return false;
+ bool Changed = false;
// Finally, do the actual predication for all predicatable blocks. A couple
// of notes here:
// 1) We don't bother to constant fold dominated exits with identical exit
@@ -2970,7 +2637,7 @@ bool IndVarSimplify::predicateLoopExits(Loop *L, SCEVExpander &Rewriter) {
Value *OldCond = BI->getCondition();
BI->setCondition(NewCond);
if (OldCond->use_empty())
- DeadInsts.push_back(OldCond);
+ DeadInsts.emplace_back(OldCond);
Changed = true;
}
@@ -2985,7 +2652,6 @@ bool IndVarSimplify::run(Loop *L) {
// We need (and expect!) the incoming loop to be in LCSSA.
assert(L->isRecursivelyLCSSAForm(*DT, *LI) &&
"LCSSA required to run indvars!");
- bool Changed = false;
// If LoopSimplify form is not available, stay out of trouble. Some notes:
// - LSR currently only supports LoopSimplify-form loops. Indvars'
@@ -3001,9 +2667,15 @@ bool IndVarSimplify::run(Loop *L) {
#ifndef NDEBUG
// Used below for a consistency check only
- const SCEV *BackedgeTakenCount = SE->getBackedgeTakenCount(L);
+ // Note: Since the result returned by ScalarEvolution may depend on the order
+ // in which previous results are added to its cache, the call to
+ // getBackedgeTakenCount() may change following SCEV queries.
+ const SCEV *BackedgeTakenCount;
+ if (VerifyIndvars)
+ BackedgeTakenCount = SE->getBackedgeTakenCount(L);
#endif
+ bool Changed = false;
// If there are any floating-point recurrences, attempt to
// transform them to use integer recurrences.
Changed |= rewriteNonIntegerIVs(L);
@@ -3027,8 +2699,13 @@ bool IndVarSimplify::run(Loop *L) {
// that are recurrent in the loop, and substitute the exit values from the
// loop into any instructions outside of the loop that use the final values
// of the current expressions.
- if (ReplaceExitValue != NeverRepl)
- Changed |= rewriteLoopExitValues(L, Rewriter);
+ if (ReplaceExitValue != NeverRepl) {
+ if (int Rewrites = rewriteLoopExitValues(L, LI, TLI, SE, TTI, Rewriter, DT,
+ ReplaceExitValue, DeadInsts)) {
+ NumReplaced += Rewrites;
+ Changed = true;
+ }
+ }
// Eliminate redundant IV cycles.
NumElimIV += Rewriter.replaceCongruentIVs(L, DT, DeadInsts);
@@ -3039,7 +2716,7 @@ bool IndVarSimplify::run(Loop *L) {
// Given we've changed exit counts, notify SCEV
SE->forgetLoop(L);
}
-
+
// Try to form loop invariant tests for loop exits by changing how many
// iterations of the loop run when that is unobservable.
if (predicateLoopExits(L, Rewriter)) {
@@ -3049,8 +2726,11 @@ bool IndVarSimplify::run(Loop *L) {
}
// If we have a trip count expression, rewrite the loop's exit condition
- // using it.
+ // using it.
if (!DisableLFTR) {
+ BasicBlock *PreHeader = L->getLoopPreheader();
+ BranchInst *PreHeaderBR = cast<BranchInst>(PreHeader->getTerminator());
+
SmallVector<BasicBlock*, 16> ExitingBlocks;
L->getExitingBlocks(ExitingBlocks);
for (BasicBlock *ExitingBB : ExitingBlocks) {
@@ -3060,10 +2740,10 @@ bool IndVarSimplify::run(Loop *L) {
// If our exiting block exits multiple loops, we can only rewrite the
// innermost one. Otherwise, we're changing how many times the innermost
- // loop runs before it exits.
+ // loop runs before it exits.
if (LI->getLoopFor(ExitingBB) != L)
continue;
-
+
if (!needsLFTR(L, ExitingBB))
continue;
@@ -3077,14 +2757,15 @@ bool IndVarSimplify::run(Loop *L) {
// until stable to handle cases like this better.
if (ExitCount->isZero())
continue;
-
+
PHINode *IndVar = FindLoopCounter(L, ExitingBB, ExitCount, SE, DT);
if (!IndVar)
continue;
-
+
// Avoid high cost expansions. Note: This heuristic is questionable in
- // that our definition of "high cost" is not exactly principled.
- if (Rewriter.isHighCostExpansion(ExitCount, L))
+ // that our definition of "high cost" is not exactly principled.
+ if (Rewriter.isHighCostExpansion(ExitCount, L, SCEVCheapExpansionBudget,
+ TTI, PreHeaderBR))
continue;
// Check preconditions for proper SCEVExpander operation. SCEV does not
@@ -3092,7 +2773,7 @@ bool IndVarSimplify::run(Loop *L) {
// any pass that uses the SCEVExpander must do it. This does not work
// well for loop passes because SCEVExpander makes assumptions about
// all loops, while LoopPassManager only forces the current loop to be
- // simplified.
+ // simplified.
//
// FIXME: SCEV expansion has no way to bail out, so the caller must
// explicitly check any assumptions made by SCEV. Brittle.
@@ -3113,7 +2794,8 @@ bool IndVarSimplify::run(Loop *L) {
while (!DeadInsts.empty())
if (Instruction *Inst =
dyn_cast_or_null<Instruction>(DeadInsts.pop_back_val()))
- Changed |= RecursivelyDeleteTriviallyDeadInstructions(Inst, TLI);
+ Changed |=
+ RecursivelyDeleteTriviallyDeadInstructions(Inst, TLI, MSSAU.get());
// The Rewriter may not be used from this point on.
@@ -3127,7 +2809,7 @@ bool IndVarSimplify::run(Loop *L) {
Changed |= rewriteFirstIterationLoopExitValues(L);
// Clean up dead instructions.
- Changed |= DeleteDeadPHIs(L->getHeader(), TLI);
+ Changed |= DeleteDeadPHIs(L->getHeader(), TLI, MSSAU.get());
// Check a post-condition.
assert(L->isRecursivelyLCSSAForm(*DT, *LI) &&
@@ -3150,6 +2832,8 @@ bool IndVarSimplify::run(Loop *L) {
assert(!SE->isKnownPredicate(ICmpInst::ICMP_ULT, BackedgeTakenCount,
NewBECount) && "indvars must preserve SCEV");
}
+ if (VerifyMemorySSA && MSSAU)
+ MSSAU->getMemorySSA()->verifyMemorySSA();
#endif
return Changed;
@@ -3161,12 +2845,14 @@ PreservedAnalyses IndVarSimplifyPass::run(Loop &L, LoopAnalysisManager &AM,
Function *F = L.getHeader()->getParent();
const DataLayout &DL = F->getParent()->getDataLayout();
- IndVarSimplify IVS(&AR.LI, &AR.SE, &AR.DT, DL, &AR.TLI, &AR.TTI);
+ IndVarSimplify IVS(&AR.LI, &AR.SE, &AR.DT, DL, &AR.TLI, &AR.TTI, AR.MSSA);
if (!IVS.run(&L))
return PreservedAnalyses::all();
auto PA = getLoopPassPreservedAnalyses();
PA.preserveSet<CFGAnalyses>();
+ if (AR.MSSA)
+ PA.preserve<MemorySSAAnalysis>();
return PA;
}
@@ -3191,13 +2877,18 @@ struct IndVarSimplifyLegacyPass : public LoopPass {
auto *TTIP = getAnalysisIfAvailable<TargetTransformInfoWrapperPass>();
auto *TTI = TTIP ? &TTIP->getTTI(*L->getHeader()->getParent()) : nullptr;
const DataLayout &DL = L->getHeader()->getModule()->getDataLayout();
+ auto *MSSAAnalysis = getAnalysisIfAvailable<MemorySSAWrapperPass>();
+ MemorySSA *MSSA = nullptr;
+ if (MSSAAnalysis)
+ MSSA = &MSSAAnalysis->getMSSA();
- IndVarSimplify IVS(LI, SE, DT, DL, TLI, TTI);
+ IndVarSimplify IVS(LI, SE, DT, DL, TLI, TTI, MSSA);
return IVS.run(L);
}
void getAnalysisUsage(AnalysisUsage &AU) const override {
AU.setPreservesCFG();
+ AU.addPreserved<MemorySSAWrapperPass>();
getLoopAnalysisUsage(AU);
}
};
diff --git a/llvm/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp b/llvm/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp
index 58469749600e..30e4822b6769 100644
--- a/llvm/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp
+++ b/llvm/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp
@@ -47,6 +47,7 @@
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/None.h"
#include "llvm/ADT/Optional.h"
+#include "llvm/ADT/PriorityWorklist.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
@@ -55,8 +56,8 @@
#include "llvm/Analysis/LoopAnalysisManager.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/LoopPass.h"
+#include "llvm/Analysis/PostDominators.h"
#include "llvm/Analysis/ScalarEvolution.h"
-#include "llvm/Analysis/ScalarEvolutionExpander.h"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/CFG.h"
@@ -87,6 +88,7 @@
#include "llvm/Transforms/Utils/Cloning.h"
#include "llvm/Transforms/Utils/LoopSimplify.h"
#include "llvm/Transforms/Utils/LoopUtils.h"
+#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
#include "llvm/Transforms/Utils/ValueMapper.h"
#include <algorithm>
#include <cassert>
@@ -242,20 +244,25 @@ public:
bool run(Loop *L, function_ref<void(Loop *, bool)> LPMAddNewLoop);
};
-class IRCELegacyPass : public LoopPass {
+class IRCELegacyPass : public FunctionPass {
public:
static char ID;
- IRCELegacyPass() : LoopPass(ID) {
+ IRCELegacyPass() : FunctionPass(ID) {
initializeIRCELegacyPassPass(*PassRegistry::getPassRegistry());
}
void getAnalysisUsage(AnalysisUsage &AU) const override {
AU.addRequired<BranchProbabilityInfoWrapperPass>();
- getLoopAnalysisUsage(AU);
+ AU.addRequired<DominatorTreeWrapperPass>();
+ AU.addPreserved<DominatorTreeWrapperPass>();
+ AU.addRequired<LoopInfoWrapperPass>();
+ AU.addPreserved<LoopInfoWrapperPass>();
+ AU.addRequired<ScalarEvolutionWrapperPass>();
+ AU.addPreserved<ScalarEvolutionWrapperPass>();
}
- bool runOnLoop(Loop *L, LPPassManager &LPM) override;
+ bool runOnFunction(Function &F) override;
};
} // end anonymous namespace
@@ -265,7 +272,9 @@ char IRCELegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(IRCELegacyPass, "irce",
"Inductive range check elimination", false, false)
INITIALIZE_PASS_DEPENDENCY(BranchProbabilityInfoWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(LoopPass)
+INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
INITIALIZE_PASS_END(IRCELegacyPass, "irce", "Inductive range check elimination",
false, false)
@@ -866,7 +875,14 @@ LoopStructure::parseLoopStructure(ScalarEvolution &SE,
const SCEV *IndVarStart = SE.getAddExpr(StartNext, Addend);
const SCEV *Step = SE.getSCEV(StepCI);
- ConstantInt *One = ConstantInt::get(IndVarTy, 1);
+ const SCEV *FixedRightSCEV = nullptr;
+
+ // If RightValue resides within the loop (but is still loop invariant),
+ // regenerate it in the preheader.
+ if (auto *I = dyn_cast<Instruction>(RightValue))
+ if (L.contains(I->getParent()))
+ FixedRightSCEV = RightSCEV;
+
if (IsIncreasing) {
bool DecreasedRightValueByOne = false;
if (StepCI->isOne()) {
@@ -928,10 +944,9 @@ LoopStructure::parseLoopStructure(ScalarEvolution &SE,
if (LatchBrExitIdx == 0) {
// We need to increase the right value unless we have already decreased
// it virtually when we replaced EQ with SGT.
- if (!DecreasedRightValueByOne) {
- IRBuilder<> B(Preheader->getTerminator());
- RightValue = B.CreateAdd(RightValue, One);
- }
+ if (!DecreasedRightValueByOne)
+ FixedRightSCEV =
+ SE.getAddExpr(RightSCEV, SE.getOne(RightSCEV->getType()));
} else {
assert(!DecreasedRightValueByOne &&
"Right value can be decreased only for LatchBrExitIdx == 0!");
@@ -995,10 +1010,9 @@ LoopStructure::parseLoopStructure(ScalarEvolution &SE,
if (LatchBrExitIdx == 0) {
// We need to decrease the right value unless we have already increased
// it virtually when we replaced EQ with SLT.
- if (!IncreasedRightValueByOne) {
- IRBuilder<> B(Preheader->getTerminator());
- RightValue = B.CreateSub(RightValue, One);
- }
+ if (!IncreasedRightValueByOne)
+ FixedRightSCEV =
+ SE.getMinusSCEV(RightSCEV, SE.getOne(RightSCEV->getType()));
} else {
assert(!IncreasedRightValueByOne &&
"Right value can be increased only for LatchBrExitIdx == 0!");
@@ -1012,9 +1026,14 @@ LoopStructure::parseLoopStructure(ScalarEvolution &SE,
assert(!L.contains(LatchExit) && "expected an exit block!");
const DataLayout &DL = Preheader->getModule()->getDataLayout();
- Value *IndVarStartV =
- SCEVExpander(SE, DL, "irce")
- .expandCodeFor(IndVarStart, IndVarTy, Preheader->getTerminator());
+ SCEVExpander Expander(SE, DL, "irce");
+ Instruction *Ins = Preheader->getTerminator();
+
+ if (FixedRightSCEV)
+ RightValue =
+ Expander.expandCodeFor(FixedRightSCEV, FixedRightSCEV->getType(), Ins);
+
+ Value *IndVarStartV = Expander.expandCodeFor(IndVarStart, IndVarTy, Ins);
IndVarStartV->setName("indvar.start");
LoopStructure Result;
@@ -1747,27 +1766,41 @@ IntersectUnsignedRange(ScalarEvolution &SE,
return Ret;
}
-PreservedAnalyses IRCEPass::run(Loop &L, LoopAnalysisManager &AM,
- LoopStandardAnalysisResults &AR,
- LPMUpdater &U) {
- Function *F = L.getHeader()->getParent();
- const auto &FAM =
- AM.getResult<FunctionAnalysisManagerLoopProxy>(L, AR).getManager();
- auto *BPI = FAM.getCachedResult<BranchProbabilityAnalysis>(*F);
- InductiveRangeCheckElimination IRCE(AR.SE, BPI, AR.DT, AR.LI);
- auto LPMAddNewLoop = [&U](Loop *NL, bool IsSubloop) {
+PreservedAnalyses IRCEPass::run(Function &F, FunctionAnalysisManager &AM) {
+ auto &SE = AM.getResult<ScalarEvolutionAnalysis>(F);
+ auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
+ auto &BPI = AM.getResult<BranchProbabilityAnalysis>(F);
+ LoopInfo &LI = AM.getResult<LoopAnalysis>(F);
+
+ InductiveRangeCheckElimination IRCE(SE, &BPI, DT, LI);
+
+ bool Changed = false;
+
+ for (const auto &L : LI) {
+ Changed |= simplifyLoop(L, &DT, &LI, &SE, nullptr, nullptr,
+ /*PreserveLCSSA=*/false);
+ Changed |= formLCSSARecursively(*L, DT, &LI, &SE);
+ }
+
+ SmallPriorityWorklist<Loop *, 4> Worklist;
+ appendLoopsToWorklist(LI, Worklist);
+ auto LPMAddNewLoop = [&Worklist](Loop *NL, bool IsSubloop) {
if (!IsSubloop)
- U.addSiblingLoops(NL);
+ appendLoopsToWorklist(*NL, Worklist);
};
- bool Changed = IRCE.run(&L, LPMAddNewLoop);
+
+ while (!Worklist.empty()) {
+ Loop *L = Worklist.pop_back_val();
+ Changed |= IRCE.run(L, LPMAddNewLoop);
+ }
+
if (!Changed)
return PreservedAnalyses::all();
-
return getLoopPassPreservedAnalyses();
}
-bool IRCELegacyPass::runOnLoop(Loop *L, LPPassManager &LPM) {
- if (skipLoop(L))
+bool IRCELegacyPass::runOnFunction(Function &F) {
+ if (skipFunction(F))
return false;
ScalarEvolution &SE = getAnalysis<ScalarEvolutionWrapperPass>().getSE();
@@ -1776,10 +1809,27 @@ bool IRCELegacyPass::runOnLoop(Loop *L, LPPassManager &LPM) {
auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
InductiveRangeCheckElimination IRCE(SE, &BPI, DT, LI);
- auto LPMAddNewLoop = [&LPM](Loop *NL, bool /* IsSubLoop */) {
- LPM.addLoop(*NL);
+
+ bool Changed = false;
+
+ for (const auto &L : LI) {
+ Changed |= simplifyLoop(L, &DT, &LI, &SE, nullptr, nullptr,
+ /*PreserveLCSSA=*/false);
+ Changed |= formLCSSARecursively(*L, DT, &LI, &SE);
+ }
+
+ SmallPriorityWorklist<Loop *, 4> Worklist;
+ appendLoopsToWorklist(LI, Worklist);
+ auto LPMAddNewLoop = [&](Loop *NL, bool IsSubloop) {
+ if (!IsSubloop)
+ appendLoopsToWorklist(*NL, Worklist);
};
- return IRCE.run(L, LPMAddNewLoop);
+
+ while (!Worklist.empty()) {
+ Loop *L = Worklist.pop_back_val();
+ Changed |= IRCE.run(L, LPMAddNewLoop);
+ }
+ return Changed;
}
bool InductiveRangeCheckElimination::run(
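Both the new-PM and legacy entry points above now share the same driver shape: canonicalize each top-level loop, seed a worklist from LoopInfo, and let the transform re-queue newly created sibling loops through the callback. The generic sketch below shows only that driver shape, with a toy Loop type and a plain vector in place of SmallPriorityWorklist / appendLoopsToWorklist; the loop names are invented.

// Sketch of the worklist driver shared by IRCEPass::run and the legacy pass.
#include <functional>
#include <string>
#include <vector>

struct Loop { std::string Name; };

// Stand-in for InductiveRangeCheckElimination::run: may split a loop and
// report the new sibling through the callback.
static bool runIRCEOnLoop(const Loop &L,
                          const std::function<void(Loop, bool)> &AddLoop) {
  if (L.Name == "outer") {
    AddLoop({"outer.postloop"}, /*IsSubloop=*/false); // re-queued and visited
    return true;
  }
  return false;
}

int main() {
  std::vector<Loop> Worklist = {{"inner"}, {"outer"}}; // seeded from LoopInfo
  auto AddLoop = [&Worklist](Loop NL, bool IsSubloop) {
    if (!IsSubloop) // subloops are reached while visiting their parent
      Worklist.push_back(NL);
  };
  bool Changed = false;
  while (!Worklist.empty()) {
    Loop L = Worklist.back();
    Worklist.pop_back();
    Changed |= runIRCEOnLoop(L, AddLoop);
  }
  return Changed ? 0 : 1;
}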
diff --git a/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp b/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp
index dfb1b6bfb739..db9cc58bbfc4 100644
--- a/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp
+++ b/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp
@@ -96,7 +96,6 @@
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/TargetTransformInfo.h"
-#include "llvm/Transforms/Utils/Local.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/Constants.h"
@@ -116,11 +115,13 @@
#include "llvm/IR/ValueHandle.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
+#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Compiler.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Scalar.h"
+#include "llvm/Transforms/Utils/Local.h"
#include "llvm/Transforms/Utils/ValueMapper.h"
#include <cassert>
#include <iterator>
@@ -132,16 +133,23 @@
using namespace llvm;
+static cl::opt<bool> AssumeDefaultIsFlatAddressSpace(
+ "assume-default-is-flat-addrspace", cl::init(false), cl::ReallyHidden,
+ cl::desc("The default address space is assumed as the flat address space. "
+ "This is mainly for test purpose."));
+
static const unsigned UninitializedAddressSpace =
std::numeric_limits<unsigned>::max();
namespace {
using ValueToAddrSpaceMapTy = DenseMap<const Value *, unsigned>;
+using PostorderStackTy = llvm::SmallVector<PointerIntPair<Value *, 1, bool>, 4>;
/// InferAddressSpaces
class InferAddressSpaces : public FunctionPass {
const TargetTransformInfo *TTI = nullptr;
+ const DataLayout *DL = nullptr;
/// Target specific address space which uses of should be replaced if
/// possible.
@@ -174,6 +182,11 @@ private:
bool isSafeToCastConstAddrSpace(Constant *C, unsigned NewAS) const;
+ Value *cloneInstructionWithNewAddressSpace(
+ Instruction *I, unsigned NewAddrSpace,
+ const ValueToValueMapTy &ValueWithNewAddrSpace,
+ SmallVectorImpl<const Use *> *UndefUsesToFix) const;
+
// Changes the flat address expressions in function F to point to specific
// address spaces if InferredAddrSpace says so. Postorder is the postorder of
// all flat expressions in the use-def graph of function F.
@@ -182,15 +195,14 @@ private:
const ValueToAddrSpaceMapTy &InferredAddrSpace, Function *F) const;
void appendsFlatAddressExpressionToPostorderStack(
- Value *V, std::vector<std::pair<Value *, bool>> &PostorderStack,
- DenseSet<Value *> &Visited) const;
+ Value *V, PostorderStackTy &PostorderStack,
+ DenseSet<Value *> &Visited) const;
bool rewriteIntrinsicOperands(IntrinsicInst *II,
Value *OldV, Value *NewV) const;
- void collectRewritableIntrinsicOperands(
- IntrinsicInst *II,
- std::vector<std::pair<Value *, bool>> &PostorderStack,
- DenseSet<Value *> &Visited) const;
+ void collectRewritableIntrinsicOperands(IntrinsicInst *II,
+ PostorderStackTy &PostorderStack,
+ DenseSet<Value *> &Visited) const;
std::vector<WeakTrackingVH> collectFlatAddressExpressions(Function &F) const;
@@ -214,24 +226,65 @@ void initializeInferAddressSpacesPass(PassRegistry &);
INITIALIZE_PASS(InferAddressSpaces, DEBUG_TYPE, "Infer address spaces",
false, false)
+// Check whether that's a no-op pointer bitcast using a pair of
+// `ptrtoint`/`inttoptr` due to the missing no-op pointer bitcast over
+// different address spaces.
+static bool isNoopPtrIntCastPair(const Operator *I2P, const DataLayout &DL,
+ const TargetTransformInfo *TTI) {
+ assert(I2P->getOpcode() == Instruction::IntToPtr);
+ auto *P2I = dyn_cast<Operator>(I2P->getOperand(0));
+ if (!P2I || P2I->getOpcode() != Instruction::PtrToInt)
+ return false;
+ // Check it's really safe to treat that pair of `ptrtoint`/`inttoptr` as a
+ // no-op cast. Besides checking both of them are no-op casts, as the
+ // reinterpreted pointer may be used in other pointer arithmetic, we also
+ // need to double-check that through the target-specific hook. That ensures
+ // the underlying target also agrees that's a no-op address space cast and
+ // pointer bits are preserved.
+ // The current IR spec doesn't have clear rules on address space casts,
+ // especially a clear definition for pointer bits in non-default address
+ // spaces. It would be undefined if that pointer is dereferenced after an
+ // invalid reinterpret cast. Also, due to the unclearness for the meaning of
+ // bits in non-default address spaces in the current spec, the pointer
+ // arithmetic may also be undefined after invalid pointer reinterpret cast.
+ // However, as we confirm through the target hooks that it's a no-op
+ // addrspacecast, it doesn't matter since the bits should be the same.
+ return CastInst::isNoopCast(Instruction::CastOps(I2P->getOpcode()),
+ I2P->getOperand(0)->getType(), I2P->getType(),
+ DL) &&
+ CastInst::isNoopCast(Instruction::CastOps(P2I->getOpcode()),
+ P2I->getOperand(0)->getType(), P2I->getType(),
+ DL) &&
+ TTI->isNoopAddrSpaceCast(
+ P2I->getOperand(0)->getType()->getPointerAddressSpace(),
+ I2P->getType()->getPointerAddressSpace());
+}
+
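The pattern the new predicate accepts is an inttoptr fed directly by a ptrtoint of the original pointer, with both casts no-op sized and TTI confirming the address-space round trip preserves the pointer bits. The snippet below is a hypothetical, stand-alone illustration that merely constructs that IR shape with IRBuilder; the module and function names are invented here.

// Builds:  %i = ptrtoint i8* %p to i64
//          %q = inttoptr i64 %i to i8*
// which isNoopPtrIntCastPair can treat as a no-op cast pair when the target
// agrees and the integer type is pointer sized.
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/raw_ostream.h"
using namespace llvm;

int main() {
  LLVMContext Ctx;
  Module M("ptr-int-roundtrip", Ctx);
  Type *I8Ptr = Type::getInt8PtrTy(Ctx, /*AddressSpace=*/0);
  Function *F = Function::Create(FunctionType::get(I8Ptr, {I8Ptr}, false),
                                 Function::ExternalLinkage, "roundtrip", M);
  IRBuilder<> B(BasicBlock::Create(Ctx, "entry", F));
  Value *P = F->getArg(0);
  Value *I = B.CreatePtrToInt(P, Type::getInt64Ty(Ctx)); // ptrtoint
  Value *Q = B.CreateIntToPtr(I, I8Ptr);                 // inttoptr
  B.CreateRet(Q);
  M.print(outs(), nullptr);
  return 0;
}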
// Returns true if V is an address expression.
// TODO: Currently, we consider only phi, bitcast, addrspacecast, and
// getelementptr operators.
-static bool isAddressExpression(const Value &V) {
- if (!isa<Operator>(V))
+static bool isAddressExpression(const Value &V, const DataLayout &DL,
+ const TargetTransformInfo *TTI) {
+ const Operator *Op = dyn_cast<Operator>(&V);
+ if (!Op)
return false;
- const Operator &Op = cast<Operator>(V);
- switch (Op.getOpcode()) {
+ switch (Op->getOpcode()) {
case Instruction::PHI:
- assert(Op.getType()->isPointerTy());
+ assert(Op->getType()->isPointerTy());
return true;
case Instruction::BitCast:
case Instruction::AddrSpaceCast:
case Instruction::GetElementPtr:
return true;
case Instruction::Select:
- return Op.getType()->isPointerTy();
+ return Op->getType()->isPointerTy();
+ case Instruction::Call: {
+ const IntrinsicInst *II = dyn_cast<IntrinsicInst>(&V);
+ return II && II->getIntrinsicID() == Intrinsic::ptrmask;
+ }
+ case Instruction::IntToPtr:
+ return isNoopPtrIntCastPair(Op, DL, TTI);
default:
return false;
}
@@ -240,7 +293,9 @@ static bool isAddressExpression(const Value &V) {
// Returns the pointer operands of V.
//
// Precondition: V is an address expression.
-static SmallVector<Value *, 2> getPointerOperands(const Value &V) {
+static SmallVector<Value *, 2>
+getPointerOperands(const Value &V, const DataLayout &DL,
+ const TargetTransformInfo *TTI) {
const Operator &Op = cast<Operator>(V);
switch (Op.getOpcode()) {
case Instruction::PHI: {
@@ -254,12 +309,22 @@ static SmallVector<Value *, 2> getPointerOperands(const Value &V) {
return {Op.getOperand(0)};
case Instruction::Select:
return {Op.getOperand(1), Op.getOperand(2)};
+ case Instruction::Call: {
+ const IntrinsicInst &II = cast<IntrinsicInst>(Op);
+ assert(II.getIntrinsicID() == Intrinsic::ptrmask &&
+ "unexpected intrinsic call");
+ return {II.getArgOperand(0)};
+ }
+ case Instruction::IntToPtr: {
+ assert(isNoopPtrIntCastPair(&Op, DL, TTI));
+ auto *P2I = cast<Operator>(Op.getOperand(0));
+ return {P2I->getOperand(0)};
+ }
default:
llvm_unreachable("Unexpected instruction type.");
}
}
-// TODO: Move logic to TTI?
bool InferAddressSpaces::rewriteIntrinsicOperands(IntrinsicInst *II,
Value *OldV,
Value *NewV) const {
@@ -275,16 +340,26 @@ bool InferAddressSpaces::rewriteIntrinsicOperands(IntrinsicInst *II,
II->setCalledFunction(NewDecl);
return true;
}
- default:
- return TTI->rewriteIntrinsicWithAddressSpace(II, OldV, NewV);
+ case Intrinsic::ptrmask:
+ // This is handled as an address expression, not as a use memory operation.
+ return false;
+ default: {
+ Value *Rewrite = TTI->rewriteIntrinsicWithAddressSpace(II, OldV, NewV);
+ if (!Rewrite)
+ return false;
+ if (Rewrite != II)
+ II->replaceAllUsesWith(Rewrite);
+ return true;
+ }
}
}
void InferAddressSpaces::collectRewritableIntrinsicOperands(
- IntrinsicInst *II, std::vector<std::pair<Value *, bool>> &PostorderStack,
+ IntrinsicInst *II, PostorderStackTy &PostorderStack,
DenseSet<Value *> &Visited) const {
auto IID = II->getIntrinsicID();
switch (IID) {
+ case Intrinsic::ptrmask:
case Intrinsic::objectsize:
appendsFlatAddressExpressionToPostorderStack(II->getArgOperand(0),
PostorderStack, Visited);
@@ -305,7 +380,7 @@ void InferAddressSpaces::collectRewritableIntrinsicOperands(
// If V is an unvisited flat address expression, appends V to PostorderStack
// and marks it as visited.
void InferAddressSpaces::appendsFlatAddressExpressionToPostorderStack(
- Value *V, std::vector<std::pair<Value *, bool>> &PostorderStack,
+ Value *V, PostorderStackTy &PostorderStack,
DenseSet<Value *> &Visited) const {
assert(V->getType()->isPointerTy());
@@ -313,21 +388,21 @@ void InferAddressSpaces::appendsFlatAddressExpressionToPostorderStack(
// expressions.
if (ConstantExpr *CE = dyn_cast<ConstantExpr>(V)) {
// TODO: Look in non-address parts, like icmp operands.
- if (isAddressExpression(*CE) && Visited.insert(CE).second)
- PostorderStack.push_back(std::make_pair(CE, false));
+ if (isAddressExpression(*CE, *DL, TTI) && Visited.insert(CE).second)
+ PostorderStack.emplace_back(CE, false);
return;
}
- if (isAddressExpression(*V) &&
+ if (isAddressExpression(*V, *DL, TTI) &&
V->getType()->getPointerAddressSpace() == FlatAddrSpace) {
if (Visited.insert(V).second) {
- PostorderStack.push_back(std::make_pair(V, false));
+ PostorderStack.emplace_back(V, false);
Operator *Op = cast<Operator>(V);
for (unsigned I = 0, E = Op->getNumOperands(); I != E; ++I) {
if (ConstantExpr *CE = dyn_cast<ConstantExpr>(Op->getOperand(I))) {
- if (isAddressExpression(*CE) && Visited.insert(CE).second)
+ if (isAddressExpression(*CE, *DL, TTI) && Visited.insert(CE).second)
PostorderStack.emplace_back(CE, false);
}
}
@@ -341,7 +416,7 @@ std::vector<WeakTrackingVH>
InferAddressSpaces::collectFlatAddressExpressions(Function &F) const {
// This function implements a non-recursive postorder traversal of a partial
// use-def graph of function F.
- std::vector<std::pair<Value *, bool>> PostorderStack;
+ PostorderStackTy PostorderStack;
// The set of visited expressions.
DenseSet<Value *> Visited;
@@ -383,23 +458,27 @@ InferAddressSpaces::collectFlatAddressExpressions(Function &F) const {
} else if (auto *ASC = dyn_cast<AddrSpaceCastInst>(&I)) {
if (!ASC->getType()->isVectorTy())
PushPtrOperand(ASC->getPointerOperand());
+ } else if (auto *I2P = dyn_cast<IntToPtrInst>(&I)) {
+ if (isNoopPtrIntCastPair(cast<Operator>(I2P), *DL, TTI))
+ PushPtrOperand(
+ cast<PtrToIntInst>(I2P->getOperand(0))->getPointerOperand());
}
}
std::vector<WeakTrackingVH> Postorder; // The resultant postorder.
while (!PostorderStack.empty()) {
- Value *TopVal = PostorderStack.back().first;
+ Value *TopVal = PostorderStack.back().getPointer();
// If the operands of the expression on the top are already explored,
// adds that expression to the resultant postorder.
- if (PostorderStack.back().second) {
+ if (PostorderStack.back().getInt()) {
if (TopVal->getType()->getPointerAddressSpace() == FlatAddrSpace)
Postorder.push_back(TopVal);
PostorderStack.pop_back();
continue;
}
// Otherwise, adds its operands to the stack and explores them.
- PostorderStack.back().second = true;
- for (Value *PtrOperand : getPointerOperands(*TopVal)) {
+ PostorderStack.back().setInt(true);
+ for (Value *PtrOperand : getPointerOperands(*TopVal, *DL, TTI)) {
appendsFlatAddressExpressionToPostorderStack(PtrOperand, PostorderStack,
Visited);
}
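The PointerIntPair stack is the usual two-phase iterative postorder: a value is pushed with its bit clear, and when it comes back to the top with the bit set, all of its pointer operands have already been emitted, so it can be appended to the postorder itself. A stand-alone sketch of that traversal over a toy operand graph (Expr and its operand list are assumptions of the illustration):

// Iterative postorder with an "operands already expanded" flag, mirroring
// PostorderStackTy / collectFlatAddressExpressions.
#include <cassert>
#include <set>
#include <string>
#include <utility>
#include <vector>

struct Expr {
  std::string Name;
  std::vector<Expr *> Operands;
};

static std::vector<Expr *> postorder(Expr *Root) {
  std::vector<std::pair<Expr *, bool>> Stack = {{Root, false}};
  std::set<Expr *> Visited = {Root};
  std::vector<Expr *> Order;
  while (!Stack.empty()) {
    Expr *Top = Stack.back().first;  // copy before any push invalidates it
    if (Stack.back().second) {       // operands done: emit the value itself
      Order.push_back(Top);
      Stack.pop_back();
      continue;
    }
    Stack.back().second = true;      // the second visit will emit it
    for (Expr *Op : Top->Operands)
      if (Visited.insert(Op).second)
        Stack.push_back({Op, false});
  }
  return Order;
}

int main() {
  Expr A = {"a", {}}, B = {"b", {}}, GEP = {"gep", {&A, &B}};
  std::vector<Expr *> Order = postorder(&GEP);
  assert(Order.back() == &GEP);      // operands come before their user
  return 0;
}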
@@ -438,10 +517,13 @@ static Value *operandWithNewAddressSpaceOrCreateUndef(
// Note that we do not necessarily clone `I`, e.g., if it is an addrspacecast
// from a pointer whose type already matches. Therefore, this function returns a
// Value* instead of an Instruction*.
-static Value *cloneInstructionWithNewAddressSpace(
+//
+// This may also return nullptr in the case the instruction could not be
+// rewritten.
+Value *InferAddressSpaces::cloneInstructionWithNewAddressSpace(
Instruction *I, unsigned NewAddrSpace,
const ValueToValueMapTy &ValueWithNewAddrSpace,
- SmallVectorImpl<const Use *> *UndefUsesToFix) {
+ SmallVectorImpl<const Use *> *UndefUsesToFix) const {
Type *NewPtrType =
I->getType()->getPointerElementType()->getPointerTo(NewAddrSpace);
@@ -456,6 +538,23 @@ static Value *cloneInstructionWithNewAddressSpace(
return Src;
}
+ if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) {
+ // Technically the intrinsic ID is a pointer typed argument, so specially
+ // handle calls early.
+ assert(II->getIntrinsicID() == Intrinsic::ptrmask);
+ Value *NewPtr = operandWithNewAddressSpaceOrCreateUndef(
+ II->getArgOperandUse(0), NewAddrSpace, ValueWithNewAddrSpace,
+ UndefUsesToFix);
+ Value *Rewrite =
+ TTI->rewriteIntrinsicWithAddressSpace(II, II->getArgOperand(0), NewPtr);
+ if (Rewrite) {
+ assert(Rewrite != II && "cannot modify this pointer operation in place");
+ return Rewrite;
+ }
+
+ return nullptr;
+ }
+
// Computes the converted pointer operands.
SmallVector<Value *, 4> NewPointerOperands;
for (const Use &OperandUse : I->operands()) {
@@ -492,6 +591,14 @@ static Value *cloneInstructionWithNewAddressSpace(
assert(I->getType()->isPointerTy());
return SelectInst::Create(I->getOperand(0), NewPointerOperands[1],
NewPointerOperands[2], "", nullptr, I);
+ case Instruction::IntToPtr: {
+ assert(isNoopPtrIntCastPair(cast<Operator>(I), *DL, TTI));
+ Value *Src = cast<Operator>(I->getOperand(0))->getOperand(0);
+ assert(Src->getType()->getPointerAddressSpace() == NewAddrSpace);
+ if (Src->getType() != NewPtrType)
+ return new BitCastInst(Src, NewPtrType);
+ return Src;
+ }
default:
llvm_unreachable("Unexpected opcode");
}
@@ -501,8 +608,9 @@ static Value *cloneInstructionWithNewAddressSpace(
// constant expression `CE` with its operands replaced as specified in
// ValueWithNewAddrSpace.
static Value *cloneConstantExprWithNewAddressSpace(
- ConstantExpr *CE, unsigned NewAddrSpace,
- const ValueToValueMapTy &ValueWithNewAddrSpace) {
+ ConstantExpr *CE, unsigned NewAddrSpace,
+ const ValueToValueMapTy &ValueWithNewAddrSpace, const DataLayout *DL,
+ const TargetTransformInfo *TTI) {
Type *TargetType =
CE->getType()->getPointerElementType()->getPointerTo(NewAddrSpace);
@@ -533,6 +641,13 @@ static Value *cloneConstantExprWithNewAddressSpace(
}
}
+ if (CE->getOpcode() == Instruction::IntToPtr) {
+ assert(isNoopPtrIntCastPair(cast<Operator>(CE), *DL, TTI));
+ Constant *Src = cast<ConstantExpr>(CE->getOperand(0))->getOperand(0);
+ assert(Src->getType()->getPointerAddressSpace() == NewAddrSpace);
+ return ConstantExpr::getBitCast(Src, TargetType);
+ }
+
// Computes the operands of the new constant expression.
bool IsNew = false;
SmallVector<Constant *, 4> NewOperands;
@@ -550,7 +665,7 @@ static Value *cloneConstantExprWithNewAddressSpace(
}
if (auto CExpr = dyn_cast<ConstantExpr>(Operand))
if (Value *NewOperand = cloneConstantExprWithNewAddressSpace(
- CExpr, NewAddrSpace, ValueWithNewAddrSpace)) {
+ CExpr, NewAddrSpace, ValueWithNewAddrSpace, DL, TTI)) {
IsNew = true;
NewOperands.push_back(cast<Constant>(NewOperand));
continue;
@@ -585,13 +700,13 @@ Value *InferAddressSpaces::cloneValueWithNewAddressSpace(
const ValueToValueMapTy &ValueWithNewAddrSpace,
SmallVectorImpl<const Use *> *UndefUsesToFix) const {
// All values in Postorder are flat address expressions.
- assert(isAddressExpression(*V) &&
+ assert(isAddressExpression(*V, *DL, TTI) &&
V->getType()->getPointerAddressSpace() == FlatAddrSpace);
if (Instruction *I = dyn_cast<Instruction>(V)) {
Value *NewV = cloneInstructionWithNewAddressSpace(
I, NewAddrSpace, ValueWithNewAddrSpace, UndefUsesToFix);
- if (Instruction *NewI = dyn_cast<Instruction>(NewV)) {
+ if (Instruction *NewI = dyn_cast_or_null<Instruction>(NewV)) {
if (NewI->getParent() == nullptr) {
NewI->insertBefore(I);
NewI->takeName(I);
@@ -601,7 +716,7 @@ Value *InferAddressSpaces::cloneValueWithNewAddressSpace(
}
return cloneConstantExprWithNewAddressSpace(
- cast<ConstantExpr>(V), NewAddrSpace, ValueWithNewAddrSpace);
+ cast<ConstantExpr>(V), NewAddrSpace, ValueWithNewAddrSpace, DL, TTI);
}
// Defines the join operation on the address space lattice (see the file header
@@ -625,6 +740,10 @@ bool InferAddressSpaces::runOnFunction(Function &F) {
return false;
TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
+ DL = &F.getParent()->getDataLayout();
+
+ if (AssumeDefaultIsFlatAddressSpace)
+ FlatAddrSpace = 0;
if (FlatAddrSpace == UninitializedAddressSpace) {
FlatAddrSpace = TTI->getFlatAddressSpace();
@@ -729,7 +848,7 @@ Optional<unsigned> InferAddressSpaces::updateAddressSpace(
else
NewAS = joinAddressSpaces(Src0AS, Src1AS);
} else {
- for (Value *PtrOperand : getPointerOperands(V)) {
+ for (Value *PtrOperand : getPointerOperands(V, *DL, TTI)) {
auto I = InferredAddrSpace.find(PtrOperand);
unsigned OperandAS = I != InferredAddrSpace.end() ?
I->second : PtrOperand->getType()->getPointerAddressSpace();
@@ -879,8 +998,10 @@ bool InferAddressSpaces::rewriteWithNewAddressSpaces(
for (Value* V : Postorder) {
unsigned NewAddrSpace = InferredAddrSpace.lookup(V);
if (V->getType()->getPointerAddressSpace() != NewAddrSpace) {
- ValueWithNewAddrSpace[V] = cloneValueWithNewAddressSpace(
- V, NewAddrSpace, ValueWithNewAddrSpace, &UndefUsesToFix);
+ Value *New = cloneValueWithNewAddressSpace(
+ V, NewAddrSpace, ValueWithNewAddrSpace, &UndefUsesToFix);
+ if (New)
+ ValueWithNewAddrSpace[V] = New;
}
}
@@ -890,7 +1011,10 @@ bool InferAddressSpaces::rewriteWithNewAddressSpaces(
// Fixes all the undef uses generated by cloneInstructionWithNewAddressSpace.
for (const Use *UndefUse : UndefUsesToFix) {
User *V = UndefUse->getUser();
- User *NewV = cast<User>(ValueWithNewAddrSpace.lookup(V));
+ User *NewV = cast_or_null<User>(ValueWithNewAddrSpace.lookup(V));
+ if (!NewV)
+ continue;
+
unsigned OperandNo = UndefUse->getOperandNo();
assert(isa<UndefValue>(NewV->getOperand(OperandNo)));
NewV->setOperand(OperandNo, ValueWithNewAddrSpace.lookup(UndefUse->get()));
diff --git a/llvm/lib/Transforms/Scalar/InstSimplifyPass.cpp b/llvm/lib/Transforms/Scalar/InstSimplifyPass.cpp
index e8bbf2936da6..e87b622ab19f 100644
--- a/llvm/lib/Transforms/Scalar/InstSimplifyPass.cpp
+++ b/llvm/lib/Transforms/Scalar/InstSimplifyPass.cpp
@@ -40,7 +40,7 @@ static bool runImpl(Function &F, const SimplifyQuery &SQ,
if (!SQ.DT->isReachableFromEntry(&BB))
continue;
- SmallVector<Instruction *, 8> DeadInstsInBB;
+ SmallVector<WeakTrackingVH, 8> DeadInstsInBB;
for (Instruction &I : BB) {
// The first time through the loop, ToSimplify is empty and we try to
// simplify all instructions. On later iterations, ToSimplify is not
diff --git a/llvm/lib/Transforms/Scalar/JumpThreading.cpp b/llvm/lib/Transforms/Scalar/JumpThreading.cpp
index 98c2fcb3dae0..9d0500419a7f 100644
--- a/llvm/lib/Transforms/Scalar/JumpThreading.cpp
+++ b/llvm/lib/Transforms/Scalar/JumpThreading.cpp
@@ -13,6 +13,7 @@
#include "llvm/Transforms/Scalar/JumpThreading.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallPtrSet.h"
@@ -170,7 +171,7 @@ FunctionPass *llvm::createJumpThreadingPass(int Threshold) {
}
JumpThreadingPass::JumpThreadingPass(int T) {
- BBDupThreshold = (T == -1) ? BBDuplicateThreshold : unsigned(T);
+ DefaultBBDupThreshold = (T == -1) ? BBDuplicateThreshold : unsigned(T);
}
// Update branch probability information according to conditional
@@ -213,11 +214,16 @@ static void updatePredecessorProfileMetadata(PHINode *PN, BasicBlock *BB) {
if (!CondBr)
return;
- BranchProbability BP;
uint64_t TrueWeight, FalseWeight;
if (!CondBr->extractProfMetadata(TrueWeight, FalseWeight))
return;
+ if (TrueWeight + FalseWeight == 0)
+ // Zero branch_weights do not give a hint for getting branch probabilities.
+ // Technically it would result in division by zero denominator, which is
+ // TrueWeight + FalseWeight.
+ return;
+
// Returns the outgoing edge of the dominating predecessor block
// that leads to the PhiNode's incoming block:
auto GetPredOutEdge =
@@ -252,10 +258,11 @@ static void updatePredecessorProfileMetadata(PHINode *PN, BasicBlock *BB) {
if (!CI || !CI->getType()->isIntegerTy(1))
continue;
- BP = (CI->isOne() ? BranchProbability::getBranchProbability(
- TrueWeight, TrueWeight + FalseWeight)
- : BranchProbability::getBranchProbability(
- FalseWeight, TrueWeight + FalseWeight));
+ BranchProbability BP =
+ (CI->isOne() ? BranchProbability::getBranchProbability(
+ TrueWeight, TrueWeight + FalseWeight)
+ : BranchProbability::getBranchProbability(
+ FalseWeight, TrueWeight + FalseWeight));
auto PredOutEdge = GetPredOutEdge(PN->getIncomingBlock(i), BB);
if (!PredOutEdge.first)
@@ -298,8 +305,6 @@ bool JumpThreading::runOnFunction(Function &F) {
if (skipFunction(F))
return false;
auto TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
- // Get DT analysis before LVI. When LVI is initialized it conditionally adds
- // DT if it's available.
auto DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
auto LVI = &getAnalysis<LazyValueInfoWrapperPass>().getLVI();
auto AA = &getAnalysis<AAResultsWrapperPass>().getAAResults();
@@ -316,7 +321,7 @@ bool JumpThreading::runOnFunction(Function &F) {
std::move(BFI), std::move(BPI));
if (PrintLVIAfterJumpThreading) {
dbgs() << "LVI for function '" << F.getName() << "':\n";
- LVI->printLVI(F, *DT, dbgs());
+ LVI->printLVI(F, DTU.getDomTree(), dbgs());
}
return Changed;
}
@@ -324,8 +329,6 @@ bool JumpThreading::runOnFunction(Function &F) {
PreservedAnalyses JumpThreadingPass::run(Function &F,
FunctionAnalysisManager &AM) {
auto &TLI = AM.getResult<TargetLibraryAnalysis>(F);
- // Get DT analysis before LVI. When LVI is initialized it conditionally adds
- // DT if it's available.
auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
auto &LVI = AM.getResult<LazyValueAnalysis>(F);
auto &AA = AM.getResult<AAManager>(F);
@@ -374,6 +377,15 @@ bool JumpThreadingPass::runImpl(Function &F, TargetLibraryInfo *TLI_,
BFI = std::move(BFI_);
}
+ // Reduce the number of instructions duplicated when optimizing strictly for
+ // size.
+ if (BBDuplicateThreshold.getNumOccurrences())
+ BBDupThreshold = BBDuplicateThreshold;
+ else if (F.hasFnAttribute(Attribute::MinSize))
+ BBDupThreshold = 3;
+ else
+ BBDupThreshold = DefaultBBDupThreshold;
+
// JumpThreading must not process blocks unreachable from entry. It's a
// waste of compute time and can potentially lead to hangs.
SmallPtrSet<BasicBlock *, 16> Unreachable;
@@ -396,6 +408,12 @@ bool JumpThreadingPass::runImpl(Function &F, TargetLibraryInfo *TLI_,
continue;
while (ProcessBlock(&BB)) // Thread all of the branches we can over BB.
Changed = true;
+
+ // Jump threading may have introduced redundant debug values into BB
+ // which should be removed.
+ if (Changed)
+ RemoveRedundantDbgInstrs(&BB);
+
// Stop processing BB if it's the entry or is now deleted. The following
// routines attempt to eliminate BB and locating a suitable replacement
// for the entry is non-trivial.
@@ -418,26 +436,27 @@ bool JumpThreadingPass::runImpl(Function &F, TargetLibraryInfo *TLI_,
// ProcessBlock doesn't thread BBs with unconditional TIs. However, if BB
// is "almost empty", we attempt to merge BB with its sole successor.
auto *BI = dyn_cast<BranchInst>(BB.getTerminator());
- if (BI && BI->isUnconditional() &&
- // The terminator must be the only non-phi instruction in BB.
- BB.getFirstNonPHIOrDbg()->isTerminator() &&
- // Don't alter Loop headers and latches to ensure another pass can
- // detect and transform nested loops later.
- !LoopHeaders.count(&BB) && !LoopHeaders.count(BI->getSuccessor(0)) &&
- TryToSimplifyUncondBranchFromEmptyBlock(&BB, DTU)) {
- // BB is valid for cleanup here because we passed in DTU. F remains
- // BB's parent until a DTU->getDomTree() event.
- LVI->eraseBlock(&BB);
- Changed = true;
+ if (BI && BI->isUnconditional()) {
+ BasicBlock *Succ = BI->getSuccessor(0);
+ if (
+ // The terminator must be the only non-phi instruction in BB.
+ BB.getFirstNonPHIOrDbg()->isTerminator() &&
+ // Don't alter Loop headers and latches to ensure another pass can
+ // detect and transform nested loops later.
+ !LoopHeaders.count(&BB) && !LoopHeaders.count(Succ) &&
+ TryToSimplifyUncondBranchFromEmptyBlock(&BB, DTU)) {
+ RemoveRedundantDbgInstrs(Succ);
+ // BB is valid for cleanup here because we passed in DTU. F remains
+ // BB's parent until a DTU->getDomTree() event.
+ LVI->eraseBlock(&BB);
+ Changed = true;
+ }
}
}
EverChanged |= Changed;
} while (Changed);
LoopHeaders.clear();
- // Flush only the Dominator Tree.
- DTU->getDomTree();
- LVI->enableDT();
return EverChanged;
}
@@ -592,20 +611,19 @@ static Constant *getKnownConstant(Value *Val, ConstantPreference Preference) {
/// This returns true if there were any known values.
bool JumpThreadingPass::ComputeValueKnownInPredecessorsImpl(
Value *V, BasicBlock *BB, PredValueInfo &Result,
- ConstantPreference Preference,
- DenseSet<std::pair<Value *, BasicBlock *>> &RecursionSet,
+ ConstantPreference Preference, DenseSet<Value *> &RecursionSet,
Instruction *CxtI) {
// This method walks up use-def chains recursively. Because of this, we could
// get into an infinite loop going around loops in the use-def chain. To
// prevent this, keep track of what (value, block) pairs we've already visited
// and terminate the search if we loop back to them
- if (!RecursionSet.insert(std::make_pair(V, BB)).second)
+ if (!RecursionSet.insert(V).second)
return false;
// If V is a constant, then it is known in all predecessors.
if (Constant *KC = getKnownConstant(V, Preference)) {
for (BasicBlock *Pred : predecessors(BB))
- Result.push_back(std::make_pair(KC, Pred));
+ Result.emplace_back(KC, Pred);
return !Result.empty();
}
@@ -627,17 +645,12 @@ bool JumpThreadingPass::ComputeValueKnownInPredecessorsImpl(
// able to handle value inequalities better, for example if the compare is
// "X < 4" and "X < 3" is known true but "X < 4" itself is not available.
// Perhaps getConstantOnEdge should be smart enough to do this?
-
- if (DTU->hasPendingDomTreeUpdates())
- LVI->disableDT();
- else
- LVI->enableDT();
for (BasicBlock *P : predecessors(BB)) {
// If the value is known by LazyValueInfo to be a constant in a
// predecessor, use that information to try to thread this block.
Constant *PredCst = LVI->getConstantOnEdge(V, P, BB, CxtI);
if (Constant *KC = getKnownConstant(PredCst, Preference))
- Result.push_back(std::make_pair(KC, P));
+ Result.emplace_back(KC, P);
}
return !Result.empty();
@@ -645,20 +658,16 @@ bool JumpThreadingPass::ComputeValueKnownInPredecessorsImpl(
/// If I is a PHI node, then we know the incoming values for any constants.
if (PHINode *PN = dyn_cast<PHINode>(I)) {
- if (DTU->hasPendingDomTreeUpdates())
- LVI->disableDT();
- else
- LVI->enableDT();
for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
Value *InVal = PN->getIncomingValue(i);
if (Constant *KC = getKnownConstant(InVal, Preference)) {
- Result.push_back(std::make_pair(KC, PN->getIncomingBlock(i)));
+ Result.emplace_back(KC, PN->getIncomingBlock(i));
} else {
Constant *CI = LVI->getConstantOnEdge(InVal,
PN->getIncomingBlock(i),
BB, CxtI);
if (Constant *KC = getKnownConstant(CI, Preference))
- Result.push_back(std::make_pair(KC, PN->getIncomingBlock(i)));
+ Result.emplace_back(KC, PN->getIncomingBlock(i));
}
}
@@ -757,7 +766,7 @@ bool JumpThreadingPass::ComputeValueKnownInPredecessorsImpl(
Constant *Folded = ConstantExpr::get(BO->getOpcode(), V, CI);
if (Constant *KC = getKnownConstant(Folded, WantInteger))
- Result.push_back(std::make_pair(KC, LHSVal.second));
+ Result.emplace_back(KC, LHSVal.second);
}
}
@@ -779,10 +788,6 @@ bool JumpThreadingPass::ComputeValueKnownInPredecessorsImpl(
const DataLayout &DL = PN->getModule()->getDataLayout();
// We can do this simplification if any comparisons fold to true or false.
// See if any do.
- if (DTU->hasPendingDomTreeUpdates())
- LVI->disableDT();
- else
- LVI->enableDT();
for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
BasicBlock *PredBB = PN->getIncomingBlock(i);
Value *LHS, *RHS;
@@ -813,7 +818,7 @@ bool JumpThreadingPass::ComputeValueKnownInPredecessorsImpl(
}
if (Constant *KC = getKnownConstant(Res, WantInteger))
- Result.push_back(std::make_pair(KC, PredBB));
+ Result.emplace_back(KC, PredBB);
}
return !Result.empty();
@@ -826,10 +831,6 @@ bool JumpThreadingPass::ComputeValueKnownInPredecessorsImpl(
if (!isa<Instruction>(CmpLHS) ||
cast<Instruction>(CmpLHS)->getParent() != BB) {
- if (DTU->hasPendingDomTreeUpdates())
- LVI->disableDT();
- else
- LVI->enableDT();
for (BasicBlock *P : predecessors(BB)) {
// If the value is known by LazyValueInfo to be a constant in a
// predecessor, use that information to try to thread this block.
@@ -840,7 +841,7 @@ bool JumpThreadingPass::ComputeValueKnownInPredecessorsImpl(
continue;
Constant *ResC = ConstantInt::get(CmpType, Res);
- Result.push_back(std::make_pair(ResC, P));
+ Result.emplace_back(ResC, P);
}
return !Result.empty();
@@ -858,10 +859,6 @@ bool JumpThreadingPass::ComputeValueKnownInPredecessorsImpl(
match(CmpLHS, m_Add(m_Value(AddLHS), m_ConstantInt(AddConst)))) {
if (!isa<Instruction>(AddLHS) ||
cast<Instruction>(AddLHS)->getParent() != BB) {
- if (DTU->hasPendingDomTreeUpdates())
- LVI->disableDT();
- else
- LVI->enableDT();
for (BasicBlock *P : predecessors(BB)) {
// If the value is known by LazyValueInfo to be a ConstantRange in
// a predecessor, use that information to try to thread this
@@ -883,7 +880,7 @@ bool JumpThreadingPass::ComputeValueKnownInPredecessorsImpl(
else
continue;
- Result.push_back(std::make_pair(ResC, P));
+ Result.emplace_back(ResC, P);
}
return !Result.empty();
@@ -901,7 +898,7 @@ bool JumpThreadingPass::ComputeValueKnownInPredecessorsImpl(
Constant *V = LHSVal.first;
Constant *Folded = ConstantExpr::getCompare(Pred, V, CmpConst);
if (Constant *KC = getKnownConstant(Folded, WantInteger))
- Result.push_back(std::make_pair(KC, LHSVal.second));
+ Result.emplace_back(KC, LHSVal.second);
}
return !Result.empty();
@@ -935,7 +932,7 @@ bool JumpThreadingPass::ComputeValueKnownInPredecessorsImpl(
// See if the select has a known constant value for this predecessor.
if (Constant *Val = KnownCond ? TrueVal : FalseVal)
- Result.push_back(std::make_pair(Val, C.second));
+ Result.emplace_back(Val, C.second);
}
return !Result.empty();
@@ -943,14 +940,10 @@ bool JumpThreadingPass::ComputeValueKnownInPredecessorsImpl(
}
// If all else fails, see if LVI can figure out a constant value for us.
- if (DTU->hasPendingDomTreeUpdates())
- LVI->disableDT();
- else
- LVI->enableDT();
Constant *CI = LVI->getConstant(V, BB, CxtI);
if (Constant *KC = getKnownConstant(CI, Preference)) {
for (BasicBlock *Pred : predecessors(BB))
- Result.push_back(std::make_pair(KC, Pred));
+ Result.emplace_back(KC, Pred);
}
return !Result.empty();
@@ -1106,10 +1099,6 @@ bool JumpThreadingPass::ProcessBlock(BasicBlock *BB) {
// threading is concerned.
assert(CondBr->isConditional() && "Threading on unconditional terminator");
- if (DTU->hasPendingDomTreeUpdates())
- LVI->disableDT();
- else
- LVI->enableDT();
LazyValueInfo::Tristate Ret =
LVI->getPredicateAt(CondCmp->getPredicate(), CondCmp->getOperand(0),
CondConst, CondBr);
@@ -1363,7 +1352,7 @@ bool JumpThreadingPass::SimplifyPartiallyRedundantLoad(LoadInst *LoadI) {
// If so, this load is partially redundant. Remember this info so that we
// can create a PHI node.
- AvailablePreds.push_back(std::make_pair(PredBB, PredAvailable));
+ AvailablePreds.emplace_back(PredBB, PredAvailable);
}
// If the loaded value isn't available in any predecessor, it isn't partially
@@ -1430,14 +1419,14 @@ bool JumpThreadingPass::SimplifyPartiallyRedundantLoad(LoadInst *LoadI) {
"Can't handle critical edge here!");
LoadInst *NewVal = new LoadInst(
LoadI->getType(), LoadedPtr->DoPHITranslation(LoadBB, UnavailablePred),
- LoadI->getName() + ".pr", false, MaybeAlign(LoadI->getAlignment()),
+ LoadI->getName() + ".pr", false, LoadI->getAlign(),
LoadI->getOrdering(), LoadI->getSyncScopeID(),
UnavailablePred->getTerminator());
NewVal->setDebugLoc(LoadI->getDebugLoc());
if (AATags)
NewVal->setAAMetadata(AATags);
- AvailablePreds.push_back(std::make_pair(UnavailablePred, NewVal));
+ AvailablePreds.emplace_back(UnavailablePred, NewVal);
}
// Now we know that each predecessor of this block has a value in
@@ -1496,56 +1485,70 @@ FindMostPopularDest(BasicBlock *BB,
// explicitly choose to ignore 'undef' destinations. We prefer to thread
// blocks with known and real destinations to threading undef. We'll handle
// them later if interesting.
- DenseMap<BasicBlock*, unsigned> DestPopularity;
+ MapVector<BasicBlock *, unsigned> DestPopularity;
+
+ // Populate DestPopularity with the successors in the order they appear in the
+ // successor list. This way, we ensure determinism by iterating it in the
+ // same order in std::max_element below. We map nullptr to 0 so that we can
+ // return nullptr when PredToDestList contains nullptr only.
+ DestPopularity[nullptr] = 0;
+ for (auto *SuccBB : successors(BB))
+ DestPopularity[SuccBB] = 0;
+
for (const auto &PredToDest : PredToDestList)
if (PredToDest.second)
DestPopularity[PredToDest.second]++;
- if (DestPopularity.empty())
- return nullptr;
-
// Find the most popular dest.
- DenseMap<BasicBlock*, unsigned>::iterator DPI = DestPopularity.begin();
- BasicBlock *MostPopularDest = DPI->first;
- unsigned Popularity = DPI->second;
- SmallVector<BasicBlock*, 4> SamePopularity;
-
- for (++DPI; DPI != DestPopularity.end(); ++DPI) {
- // If the popularity of this entry isn't higher than the popularity we've
- // seen so far, ignore it.
- if (DPI->second < Popularity)
- ; // ignore.
- else if (DPI->second == Popularity) {
- // If it is the same as what we've seen so far, keep track of it.
- SamePopularity.push_back(DPI->first);
- } else {
- // If it is more popular, remember it.
- SamePopularity.clear();
- MostPopularDest = DPI->first;
- Popularity = DPI->second;
- }
+ using VT = decltype(DestPopularity)::value_type;
+ auto MostPopular = std::max_element(
+ DestPopularity.begin(), DestPopularity.end(),
+ [](const VT &L, const VT &R) { return L.second < R.second; });
+
+ // Okay, we have finally picked the most popular destination.
+ return MostPopular->first;
+}
+
+// Try to evaluate the value of V when the control flows from PredPredBB to
+// BB->getSinglePredecessor() and then on to BB.
+Constant *JumpThreadingPass::EvaluateOnPredecessorEdge(BasicBlock *BB,
+ BasicBlock *PredPredBB,
+ Value *V) {
+ BasicBlock *PredBB = BB->getSinglePredecessor();
+ assert(PredBB && "Expected a single predecessor");
+
+ if (Constant *Cst = dyn_cast<Constant>(V)) {
+ return Cst;
}
- // Okay, now we know the most popular destination. If there is more than one
- // destination, we need to determine one. This is arbitrary, but we need
- // to make a deterministic decision. Pick the first one that appears in the
- // successor list.
- if (!SamePopularity.empty()) {
- SamePopularity.push_back(MostPopularDest);
- Instruction *TI = BB->getTerminator();
- for (unsigned i = 0; ; ++i) {
- assert(i != TI->getNumSuccessors() && "Didn't find any successor!");
+ // Consult LVI if V is not an instruction in BB or PredBB.
+ Instruction *I = dyn_cast<Instruction>(V);
+ if (!I || (I->getParent() != BB && I->getParent() != PredBB)) {
+ return LVI->getConstantOnEdge(V, PredPredBB, PredBB, nullptr);
+ }
- if (!is_contained(SamePopularity, TI->getSuccessor(i)))
- continue;
+ // Look into a PHI argument.
+ if (PHINode *PHI = dyn_cast<PHINode>(V)) {
+ if (PHI->getParent() == PredBB)
+ return dyn_cast<Constant>(PHI->getIncomingValueForBlock(PredPredBB));
+ return nullptr;
+ }
- MostPopularDest = TI->getSuccessor(i);
- break;
+ // If we have a CmpInst, try to fold it for each incoming edge into PredBB.
+ if (CmpInst *CondCmp = dyn_cast<CmpInst>(V)) {
+ if (CondCmp->getParent() == BB) {
+ Constant *Op0 =
+ EvaluateOnPredecessorEdge(BB, PredPredBB, CondCmp->getOperand(0));
+ Constant *Op1 =
+ EvaluateOnPredecessorEdge(BB, PredPredBB, CondCmp->getOperand(1));
+ if (Op0 && Op1) {
+ return ConstantExpr::getCompare(CondCmp->getPredicate(), Op0, Op1);
+ }
}
+ return nullptr;
}
- // Okay, we have finally picked the most popular destination.
- return MostPopularDest;
+ return nullptr;
}
bool JumpThreadingPass::ProcessThreadableEdges(Value *Cond, BasicBlock *BB,
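The rewritten FindMostPopularDest above relies on std::max_element returning the first of several equally popular entries, which is why DestPopularity is pre-populated in successor order. A small self-contained sketch of that tie-breaking behaviour (a plain std::vector stands in for the MapVector; names are invented):

#include <algorithm>
#include <cstdio>
#include <utility>
#include <vector>

int main() {
  // Insertion order stands in for successor order; counts stand in for the
  // popularity values. std::max_element returns the *first* maximal element,
  // so ties resolve to the earliest-inserted successor, deterministically.
  std::vector<std::pair<const char *, unsigned>> DestPopularity = {
      {"succ0", 2}, {"succ1", 2}, {"succ2", 1}};
  auto MostPopular = std::max_element(
      DestPopularity.begin(), DestPopularity.end(),
      [](const auto &L, const auto &R) { return L.second < R.second; });
  std::printf("%s\n", MostPopular->first); // prints "succ0"
  return 0;
}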
@@ -1557,8 +1560,12 @@ bool JumpThreadingPass::ProcessThreadableEdges(Value *Cond, BasicBlock *BB,
return false;
PredValueInfoTy PredValues;
- if (!ComputeValueKnownInPredecessors(Cond, BB, PredValues, Preference, CxtI))
- return false;
+ if (!ComputeValueKnownInPredecessors(Cond, BB, PredValues, Preference,
+ CxtI)) {
+ // We don't have known values in predecessors. See if we can thread through
+ // BB and its sole predecessor.
+ return MaybeThreadThroughTwoBasicBlocks(BB, Cond);
+ }
assert(!PredValues.empty() &&
"ComputeValueKnownInPredecessors returned true with no values");
@@ -1624,7 +1631,7 @@ bool JumpThreadingPass::ProcessThreadableEdges(Value *Cond, BasicBlock *BB,
isa<CallBrInst>(Pred->getTerminator()))
continue;
- PredToDestList.push_back(std::make_pair(Pred, DestBB));
+ PredToDestList.emplace_back(Pred, DestBB);
}
// If all edges were unthreadable, we fail.
@@ -2015,6 +2022,205 @@ JumpThreadingPass::CloneInstructions(BasicBlock::iterator BI,
return ValueMapping;
}
+/// Attempt to thread through two successive basic blocks.
+bool JumpThreadingPass::MaybeThreadThroughTwoBasicBlocks(BasicBlock *BB,
+ Value *Cond) {
+ // Consider:
+ //
+ // PredBB:
+ // %var = phi i32* [ null, %bb1 ], [ @a, %bb2 ]
+ // %tobool = icmp eq i32 %cond, 0
+ // br i1 %tobool, label %BB, label ...
+ //
+ // BB:
+ // %cmp = icmp eq i32* %var, null
+ // br i1 %cmp, label ..., label ...
+ //
+ // We don't know the value of %var at BB even if we know which incoming edge
+ // we take to BB. However, once we duplicate PredBB for each of its incoming
+ // edges (say, PredBB1 and PredBB2), we know the value of %var in each copy of
+ // PredBB. Then we can thread edges PredBB1->BB and PredBB2->BB through BB.
+
+ // Require that BB end with a Branch for simplicity.
+ BranchInst *CondBr = dyn_cast<BranchInst>(BB->getTerminator());
+ if (!CondBr)
+ return false;
+
+ // BB must have exactly one predecessor.
+ BasicBlock *PredBB = BB->getSinglePredecessor();
+ if (!PredBB)
+ return false;
+
+ // Require that PredBB end with a conditional Branch. If PredBB ends with an
+ // unconditional branch, we should be merging PredBB and BB instead. For
+ // simplicity, we don't deal with a switch.
+ BranchInst *PredBBBranch = dyn_cast<BranchInst>(PredBB->getTerminator());
+ if (!PredBBBranch || PredBBBranch->isUnconditional())
+ return false;
+
+ // If PredBB has exactly one incoming edge, we don't gain anything by copying
+ // PredBB.
+ if (PredBB->getSinglePredecessor())
+ return false;
+
+ // Don't thread through PredBB if it contains a successor edge to itself, in
+ // which case we would loop forever. Suppose we are threading an edge from
+ // PredPredBB through PredBB and BB to SuccBB with PredBB containing a
+ // successor edge to itself. If we allowed jump threading in this case, we
+ // could duplicate PredBB and BB as, say, PredBB.thread and BB.thread. Since
+ // PredBB.thread has a successor edge to PredBB, we would immediately come up
+ // with another jump threading opportunity from PredBB.thread through PredBB
+ // and BB to SuccBB. This jump threading would repeatedly occur. That is, we
+ // would keep peeling one iteration from PredBB.
+ if (llvm::is_contained(successors(PredBB), PredBB))
+ return false;
+
+ // Don't thread across a loop header.
+ if (LoopHeaders.count(PredBB))
+ return false;
+
+ // Avoid complication with duplicating EH pads.
+ if (PredBB->isEHPad())
+ return false;
+
+ // Find a predecessor that we can thread. For simplicity, we only consider a
+ // successor edge out of BB to which we thread exactly one incoming edge into
+ // PredBB.
+ unsigned ZeroCount = 0;
+ unsigned OneCount = 0;
+ BasicBlock *ZeroPred = nullptr;
+ BasicBlock *OnePred = nullptr;
+ for (BasicBlock *P : predecessors(PredBB)) {
+ if (ConstantInt *CI = dyn_cast_or_null<ConstantInt>(
+ EvaluateOnPredecessorEdge(BB, P, Cond))) {
+ if (CI->isZero()) {
+ ZeroCount++;
+ ZeroPred = P;
+ } else if (CI->isOne()) {
+ OneCount++;
+ OnePred = P;
+ }
+ }
+ }
+
+ // Disregard complicated cases where we have to thread multiple edges.
+ BasicBlock *PredPredBB;
+ if (ZeroCount == 1) {
+ PredPredBB = ZeroPred;
+ } else if (OneCount == 1) {
+ PredPredBB = OnePred;
+ } else {
+ return false;
+ }
+
+ BasicBlock *SuccBB = CondBr->getSuccessor(PredPredBB == ZeroPred);
+
+ // If we threaded to the same block we come from, we would loop forever.
+ if (SuccBB == BB) {
+ LLVM_DEBUG(dbgs() << " Not threading across BB '" << BB->getName()
+ << "' - would thread to self!\n");
+ return false;
+ }
+
+ // If threading this would thread across a loop header, don't thread the edge.
+ // See the comments above FindLoopHeaders for justifications and caveats.
+ if (LoopHeaders.count(BB) || LoopHeaders.count(SuccBB)) {
+ LLVM_DEBUG({
+ bool BBIsHeader = LoopHeaders.count(BB);
+ bool SuccIsHeader = LoopHeaders.count(SuccBB);
+ dbgs() << " Not threading across "
+ << (BBIsHeader ? "loop header BB '" : "block BB '")
+ << BB->getName() << "' to dest "
+ << (SuccIsHeader ? "loop header BB '" : "block BB '")
+ << SuccBB->getName()
+ << "' - it might create an irreducible loop!\n";
+ });
+ return false;
+ }
+
+ // Compute the cost of duplicating BB and PredBB.
+ unsigned BBCost =
+ getJumpThreadDuplicationCost(BB, BB->getTerminator(), BBDupThreshold);
+ unsigned PredBBCost = getJumpThreadDuplicationCost(
+ PredBB, PredBB->getTerminator(), BBDupThreshold);
+
+ // Give up if costs are too high. We need to check BBCost and PredBBCost
+ // individually before checking their sum because getJumpThreadDuplicationCost
+ // returns (unsigned)~0 for those basic blocks that cannot be duplicated.
+ if (BBCost > BBDupThreshold || PredBBCost > BBDupThreshold ||
+ BBCost + PredBBCost > BBDupThreshold) {
+ LLVM_DEBUG(dbgs() << " Not threading BB '" << BB->getName()
+ << "' - Cost is too high: " << PredBBCost
+ << " for PredBB, " << BBCost << "for BB\n");
+ return false;
+ }
+
+ // Now we are ready to duplicate PredBB.
+ ThreadThroughTwoBasicBlocks(PredPredBB, PredBB, BB, SuccBB);
+ return true;
+}
+
+void JumpThreadingPass::ThreadThroughTwoBasicBlocks(BasicBlock *PredPredBB,
+ BasicBlock *PredBB,
+ BasicBlock *BB,
+ BasicBlock *SuccBB) {
+ LLVM_DEBUG(dbgs() << " Threading through '" << PredBB->getName() << "' and '"
+ << BB->getName() << "'\n");
+
+ BranchInst *CondBr = cast<BranchInst>(BB->getTerminator());
+ BranchInst *PredBBBranch = cast<BranchInst>(PredBB->getTerminator());
+
+ BasicBlock *NewBB =
+ BasicBlock::Create(PredBB->getContext(), PredBB->getName() + ".thread",
+ PredBB->getParent(), PredBB);
+ NewBB->moveAfter(PredBB);
+
+ // Set the block frequency of NewBB.
+ if (HasProfileData) {
+ auto NewBBFreq = BFI->getBlockFreq(PredPredBB) *
+ BPI->getEdgeProbability(PredPredBB, PredBB);
+ BFI->setBlockFreq(NewBB, NewBBFreq.getFrequency());
+ }
+
+ // We are going to have to map operands from the original PredBB to the new
+ // copy of the block 'NewBB'. If there are PHI nodes in PredBB, evaluate them
+ // to account for entry from PredPredBB.
+ DenseMap<Instruction *, Value *> ValueMapping =
+ CloneInstructions(PredBB->begin(), PredBB->end(), NewBB, PredPredBB);
+
+ // Update the terminator of PredPredBB to jump to NewBB instead of PredBB.
+ // This removes PredPredBB as a predecessor of PredBB, which requires us to
+ // simplify any PHI nodes in PredBB.
+ Instruction *PredPredTerm = PredPredBB->getTerminator();
+ for (unsigned i = 0, e = PredPredTerm->getNumSuccessors(); i != e; ++i)
+ if (PredPredTerm->getSuccessor(i) == PredBB) {
+ PredBB->removePredecessor(PredPredBB, true);
+ PredPredTerm->setSuccessor(i, NewBB);
+ }
+
+ AddPHINodeEntriesForMappedBlock(PredBBBranch->getSuccessor(0), PredBB, NewBB,
+ ValueMapping);
+ AddPHINodeEntriesForMappedBlock(PredBBBranch->getSuccessor(1), PredBB, NewBB,
+ ValueMapping);
+
+ DTU->applyUpdatesPermissive(
+ {{DominatorTree::Insert, NewBB, CondBr->getSuccessor(0)},
+ {DominatorTree::Insert, NewBB, CondBr->getSuccessor(1)},
+ {DominatorTree::Insert, PredPredBB, NewBB},
+ {DominatorTree::Delete, PredPredBB, PredBB}});
+
+ UpdateSSA(PredBB, NewBB, ValueMapping);
+
+ // Clean up things like PHI nodes with single operands, dead instructions,
+ // etc.
+ SimplifyInstructionsInBlock(NewBB, TLI);
+ SimplifyInstructionsInBlock(PredBB, TLI);
+
+ SmallVector<BasicBlock *, 1> PredsToFactor;
+ PredsToFactor.push_back(NewBB);
+ ThreadEdge(BB, PredsToFactor, SuccBB);
+}
+
/// TryThreadEdge - Thread an edge if it's safe and profitable to do so.
bool JumpThreadingPass::TryThreadEdge(
BasicBlock *BB, const SmallVectorImpl<BasicBlock *> &PredBBs,
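MaybeThreadThroughTwoBasicBlocks above compares BBCost and PredBBCost against the threshold individually before comparing their sum because the cost function uses (unsigned)~0 as a "cannot duplicate" sentinel. A tiny sketch (with made-up values) of the unsigned wraparound the individual checks guard against:

#include <cstdio>

int main() {
  unsigned Threshold = 6;   // stands in for BBDupThreshold
  unsigned BBCost = ~0u;    // sentinel: this block cannot be duplicated
  unsigned PredBBCost = 2;
  // Unsigned addition wraps around, so the sum alone would look "cheap".
  std::printf("sum = %u\n", BBCost + PredBBCost);                           // prints 1
  std::printf("sum check rejects: %d\n", BBCost + PredBBCost > Threshold);  // 0
  std::printf("individual check rejects: %d\n", BBCost > Threshold);        // 1
  return 0;
}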
@@ -2078,10 +2284,6 @@ void JumpThreadingPass::ThreadEdge(BasicBlock *BB,
<< "' to '" << SuccBB->getName()
<< ", across block:\n " << *BB << "\n");
- if (DTU->hasPendingDomTreeUpdates())
- LVI->disableDT();
- else
- LVI->enableDT();
LVI->threadEdge(PredBB, BB, SuccBB);
BasicBlock *NewBB = BasicBlock::Create(BB->getContext(),
@@ -2246,8 +2448,7 @@ void JumpThreadingPass::UpdateBlockFreqAndEdgeWeight(BasicBlock *PredBB,
}
// Update edge probabilities in BPI.
- for (int I = 0, E = BBSuccProbs.size(); I < E; I++)
- BPI->setEdgeProbability(BB, I, BBSuccProbs[I]);
+ BPI->setEdgeProbability(BB, BBSuccProbs);
// Update the profile metadata as well.
//
@@ -2524,10 +2725,6 @@ bool JumpThreadingPass::TryToUnfoldSelect(CmpInst *CondCmp, BasicBlock *BB) {
// Now check if one of the select values would allow us to constant fold the
// terminator in BB. We don't do the transform if both sides fold, those
// cases will be threaded in any case.
- if (DTU->hasPendingDomTreeUpdates())
- LVI->disableDT();
- else
- LVI->enableDT();
LazyValueInfo::Tristate LHSFolds =
LVI->getPredicateOnEdge(CondCmp->getPredicate(), SI->getOperand(1),
CondRHS, Pred, BB, CondCmp);
@@ -2565,6 +2762,16 @@ bool JumpThreadingPass::TryToUnfoldSelect(CmpInst *CondCmp, BasicBlock *BB) {
/// select is not jump-threaded, it will be folded again in the later
/// optimizations.
bool JumpThreadingPass::TryToUnfoldSelectInCurrBB(BasicBlock *BB) {
+ // This transform can introduce UB (a conditional branch that depends on a
+ // poison value) that was not present in the original program. See
+ // @TryToUnfoldSelectInCurrBB test in test/Transforms/JumpThreading/select.ll.
+ // Disable this transform under MemorySanitizer.
+ // FIXME: either delete it or replace with a valid transform. This issue is
+ // not limited to MemorySanitizer (but has only been observed as an MSan false
+ // positive in practice so far).
+ if (BB->getParent()->hasFnAttribute(Attribute::SanitizeMemory))
+ return false;
+
// If threading this would thread across a loop header, don't thread the edge.
// See the comments above FindLoopHeaders for justifications and caveats.
if (LoopHeaders.count(BB))
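For context on the transform being disabled under MSan above: TryToUnfoldSelectInCurrBB turns a select that feeds a branch into explicit control flow so each arm becomes threadable. A loose hand-written C++ analogue of the before/after shapes (names are invented, and C++ has no notion of poison, so this only sketches the structural change behind the branch-on-poison concern):

#include <cstdio>

int beforeUnfold(bool C, int A, int B) {
  int X = C ? A : B;        // the "select"
  return X > 0 ? 1 : 2;     // a branch on a value derived from the select
}

int afterUnfold(bool C, int A, int B) {
  if (C)                    // branch on the select condition instead,
    return A > 0 ? 1 : 2;   // duplicating the user into each arm
  return B > 0 ? 1 : 2;
}

int main() {
  std::printf("%d %d\n", beforeUnfold(true, 5, -1), afterUnfold(true, 5, -1));
  return 0;
}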
diff --git a/llvm/lib/Transforms/Scalar/LICM.cpp b/llvm/lib/Transforms/Scalar/LICM.cpp
index 8c33045c2380..1a22edaf8726 100644
--- a/llvm/lib/Transforms/Scalar/LICM.cpp
+++ b/llvm/lib/Transforms/Scalar/LICM.cpp
@@ -46,6 +46,7 @@
#include "llvm/Analysis/MemoryBuiltins.h"
#include "llvm/Analysis/MemorySSA.h"
#include "llvm/Analysis/MemorySSAUpdater.h"
+#include "llvm/Analysis/MustExecute.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Analysis/ScalarEvolutionAliasAnalysis.h"
@@ -69,6 +70,7 @@
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Scalar/LoopPassManager.h"
+#include "llvm/Transforms/Utils/AssumeBundleBuilder.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/Local.h"
#include "llvm/Transforms/Utils/LoopUtils.h"
@@ -151,11 +153,11 @@ static bool isSafeToExecuteUnconditionally(Instruction &Inst,
const Instruction *CtxI = nullptr);
static bool pointerInvalidatedByLoop(MemoryLocation MemLoc,
AliasSetTracker *CurAST, Loop *CurLoop,
- AliasAnalysis *AA);
+ AAResults *AA);
static bool pointerInvalidatedByLoopWithMSSA(MemorySSA *MSSA, MemoryUse *MU,
Loop *CurLoop,
SinkAndHoistLICMFlags &Flags);
-static Instruction *CloneInstructionInExitBlock(
+static Instruction *cloneInstructionInExitBlock(
Instruction &I, BasicBlock &ExitBlock, PHINode &PN, const LoopInfo *LI,
const LoopSafetyInfo *SafetyInfo, MemorySSAUpdater *MSSAU);
@@ -168,27 +170,24 @@ static void moveInstructionBefore(Instruction &I, Instruction &Dest,
namespace {
struct LoopInvariantCodeMotion {
- using ASTrackerMapTy = DenseMap<Loop *, std::unique_ptr<AliasSetTracker>>;
- bool runOnLoop(Loop *L, AliasAnalysis *AA, LoopInfo *LI, DominatorTree *DT,
+ bool runOnLoop(Loop *L, AAResults *AA, LoopInfo *LI, DominatorTree *DT,
TargetLibraryInfo *TLI, TargetTransformInfo *TTI,
ScalarEvolution *SE, MemorySSA *MSSA,
- OptimizationRemarkEmitter *ORE, bool DeleteAST);
+ OptimizationRemarkEmitter *ORE);
- ASTrackerMapTy &getLoopToAliasSetMap() { return LoopToAliasSetMap; }
LoopInvariantCodeMotion(unsigned LicmMssaOptCap,
unsigned LicmMssaNoAccForPromotionCap)
: LicmMssaOptCap(LicmMssaOptCap),
LicmMssaNoAccForPromotionCap(LicmMssaNoAccForPromotionCap) {}
private:
- ASTrackerMapTy LoopToAliasSetMap;
unsigned LicmMssaOptCap;
unsigned LicmMssaNoAccForPromotionCap;
std::unique_ptr<AliasSetTracker>
- collectAliasInfoForLoop(Loop *L, LoopInfo *LI, AliasAnalysis *AA);
+ collectAliasInfoForLoop(Loop *L, LoopInfo *LI, AAResults *AA);
std::unique_ptr<AliasSetTracker>
- collectAliasInfoForLoopWithMSSA(Loop *L, AliasAnalysis *AA,
+ collectAliasInfoForLoopWithMSSA(Loop *L, AAResults *AA,
MemorySSAUpdater *MSSAU);
};
@@ -202,13 +201,8 @@ struct LegacyLICMPass : public LoopPass {
}
bool runOnLoop(Loop *L, LPPassManager &LPM) override {
- if (skipLoop(L)) {
- // If we have run LICM on a previous loop but now we are skipping
- // (because we've hit the opt-bisect limit), we need to clear the
- // loop alias information.
- LICM.getLoopToAliasSetMap().clear();
+ if (skipLoop(L))
return false;
- }
auto *SE = getAnalysisIfAvailable<ScalarEvolutionWrapperPass>();
MemorySSA *MSSA = EnableMSSALoopDependency
@@ -226,7 +220,7 @@ struct LegacyLICMPass : public LoopPass {
*L->getHeader()->getParent()),
&getAnalysis<TargetTransformInfoWrapperPass>().getTTI(
*L->getHeader()->getParent()),
- SE ? &SE->getSE() : nullptr, MSSA, &ORE, false);
+ SE ? &SE->getSE() : nullptr, MSSA, &ORE);
}
/// This transformation requires natural loop information & requires that
@@ -244,53 +238,21 @@ struct LegacyLICMPass : public LoopPass {
getLoopAnalysisUsage(AU);
}
- using llvm::Pass::doFinalization;
-
- bool doFinalization() override {
- auto &AliasSetMap = LICM.getLoopToAliasSetMap();
- // All loops in the AliasSetMap should be cleaned up already. The only case
- // where we fail to do so is if an outer loop gets deleted before LICM
- // visits it.
- assert(all_of(AliasSetMap,
- [](LoopInvariantCodeMotion::ASTrackerMapTy::value_type &KV) {
- return !KV.first->getParentLoop();
- }) &&
- "Didn't free loop alias sets");
- AliasSetMap.clear();
- return false;
- }
-
private:
LoopInvariantCodeMotion LICM;
-
- /// cloneBasicBlockAnalysis - Simple Analysis hook. Clone alias set info.
- void cloneBasicBlockAnalysis(BasicBlock *From, BasicBlock *To,
- Loop *L) override;
-
- /// deleteAnalysisValue - Simple Analysis hook. Delete value V from alias
- /// set.
- void deleteAnalysisValue(Value *V, Loop *L) override;
-
- /// Simple Analysis hook. Delete loop L from alias set map.
- void deleteAnalysisLoop(Loop *L) override;
};
} // namespace
PreservedAnalyses LICMPass::run(Loop &L, LoopAnalysisManager &AM,
LoopStandardAnalysisResults &AR, LPMUpdater &) {
- const auto &FAM =
- AM.getResult<FunctionAnalysisManagerLoopProxy>(L, AR).getManager();
- Function *F = L.getHeader()->getParent();
-
- auto *ORE = FAM.getCachedResult<OptimizationRemarkEmitterAnalysis>(*F);
- // FIXME: This should probably be optional rather than required.
- if (!ORE)
- report_fatal_error("LICM: OptimizationRemarkEmitterAnalysis not "
- "cached at a higher level");
+ // For the new PM, we also can't use OptimizationRemarkEmitter as an analysis
+ // pass. Function analyses need to be preserved across loop transformations
+ // but ORE cannot be preserved (see comment before the pass definition).
+ OptimizationRemarkEmitter ORE(L.getHeader()->getParent());
LoopInvariantCodeMotion LICM(LicmMssaOptCap, LicmMssaNoAccForPromotionCap);
if (!LICM.runOnLoop(&L, &AR.AA, &AR.LI, &AR.DT, &AR.TLI, &AR.TTI, &AR.SE,
- AR.MSSA, ORE, true))
+ AR.MSSA, &ORE))
return PreservedAnalyses::all();
auto PA = getLoopPassPreservedAnalyses();
@@ -322,13 +284,10 @@ Pass *llvm::createLICMPass(unsigned LicmMssaOptCap,
/// Hoist expressions out of the specified loop. Note that alias info for inner
/// loops is not preserved, so it is not a good idea to run LICM multiple
/// times on one loop.
-/// We should delete AST for inner loops in the new pass manager to avoid
-/// memory leak.
-///
bool LoopInvariantCodeMotion::runOnLoop(
- Loop *L, AliasAnalysis *AA, LoopInfo *LI, DominatorTree *DT,
+ Loop *L, AAResults *AA, LoopInfo *LI, DominatorTree *DT,
TargetLibraryInfo *TLI, TargetTransformInfo *TTI, ScalarEvolution *SE,
- MemorySSA *MSSA, OptimizationRemarkEmitter *ORE, bool DeleteAST) {
+ MemorySSA *MSSA, OptimizationRemarkEmitter *ORE) {
bool Changed = false;
assert(L->isLCSSAForm(*DT) && "Loop is not in LCSSA form.");
@@ -372,7 +331,7 @@ bool LoopInvariantCodeMotion::runOnLoop(
BasicBlock *Preheader = L->getLoopPreheader();
// Compute loop safety information.
- ICFLoopSafetyInfo SafetyInfo(DT);
+ ICFLoopSafetyInfo SafetyInfo;
SafetyInfo.computeLoopSafetyInfo(L);
// We want to visit all of the instructions in this loop... that are not parts
@@ -476,11 +435,6 @@ bool LoopInvariantCodeMotion::runOnLoop(
assert((!L->getParentLoop() || L->getParentLoop()->isLCSSAForm(*DT)) &&
"Parent loop not left in LCSSA form after LICM!");
- // If this loop is nested inside of another one, save the alias information
- // for when we process the outer loop.
- if (!MSSAU.get() && CurAST.get() && L->getParentLoop() && !DeleteAST)
- LoopToAliasSetMap[L] = std::move(CurAST);
-
if (MSSAU.get() && VerifyMemorySSA)
MSSAU->getMemorySSA()->verifyMemorySSA();
@@ -494,7 +448,7 @@ bool LoopInvariantCodeMotion::runOnLoop(
/// first order w.r.t the DominatorTree. This allows us to visit uses before
/// definitions, allowing us to sink a loop body in one pass without iteration.
///
-bool llvm::sinkRegion(DomTreeNode *N, AliasAnalysis *AA, LoopInfo *LI,
+bool llvm::sinkRegion(DomTreeNode *N, AAResults *AA, LoopInfo *LI,
DominatorTree *DT, TargetLibraryInfo *TLI,
TargetTransformInfo *TTI, Loop *CurLoop,
AliasSetTracker *CurAST, MemorySSAUpdater *MSSAU,
@@ -529,6 +483,7 @@ bool llvm::sinkRegion(DomTreeNode *N, AliasAnalysis *AA, LoopInfo *LI,
// used in the loop, instead, just delete it.
if (isInstructionTriviallyDead(&I, TLI)) {
LLVM_DEBUG(dbgs() << "LICM deleting dead inst: " << I << '\n');
+ salvageKnowledge(&I);
salvageDebugInfo(I);
++II;
eraseInstruction(I, *SafetyInfo, CurAST, MSSAU);
@@ -542,13 +497,14 @@ bool llvm::sinkRegion(DomTreeNode *N, AliasAnalysis *AA, LoopInfo *LI,
// operands of the instruction are loop invariant.
//
bool FreeInLoop = false;
- if (isNotUsedOrFreeInLoop(I, CurLoop, SafetyInfo, TTI, FreeInLoop) &&
+ if (!I.mayHaveSideEffects() &&
+ isNotUsedOrFreeInLoop(I, CurLoop, SafetyInfo, TTI, FreeInLoop) &&
canSinkOrHoistInst(I, AA, DT, CurLoop, CurAST, MSSAU, true, &Flags,
- ORE) &&
- !I.mayHaveSideEffects()) {
+ ORE)) {
if (sink(I, LI, DT, CurLoop, SafetyInfo, MSSAU, ORE)) {
if (!FreeInLoop) {
++II;
+ salvageDebugInfo(I);
eraseInstruction(I, *SafetyInfo, CurAST, MSSAU);
}
Changed = true;
@@ -790,47 +746,12 @@ public:
};
} // namespace
-
-/// Return true if we know how to rewrite all uses of the given alloca after
-/// hoisting it out of the loop. The main concerns are a) potential captures
-/// and b) invariant.start markers which don't capture, but are no longer
-/// valid w/o a corresponding invariant.end.
-static bool canRewriteUsesOfAlloca(AllocaInst &AI) {
- // TODO: This looks a lot like capture tracking, but we need to remove any
- // invariant starts if we extend the lifetime of the alloca by hoisting it.
- // We should probably refactor capture tracking into a form which allows us
- // to reuse the relevant bits and remove the duplicated logic here.
-
- SmallVector<Use *, 16> Worklist;
- for (Use &U : AI.uses())
- Worklist.push_back(&U);
-
- unsigned NumUsesExplored = 0;
- while (!Worklist.empty()) {
- Use *U = Worklist.pop_back_val();
- Instruction *I = cast<Instruction>(U->getUser());
- NumUsesExplored++;
- if (NumUsesExplored > DefaultMaxUsesToExplore)
- return false;
- // Non capturing, terminating uses
- if (isa<LoadInst>(I) ||
- (isa<StoreInst>(I) && U->getOperandNo() == 1))
- continue;
- // Non capturing, non-terminating
- if (!isa<BitCastInst>(I) && !isa<GetElementPtrInst>(I))
- return false;
- for (Use &U : I->uses())
- Worklist.push_back(&U);
- }
- return true;
-}
-
/// Walk the specified region of the CFG (defined by all blocks dominated by
/// the specified block, and that are in the current loop) in depth first
/// order w.r.t the DominatorTree. This allows us to visit definitions before
/// uses, allowing us to hoist a loop body in one pass without iteration.
///
-bool llvm::hoistRegion(DomTreeNode *N, AliasAnalysis *AA, LoopInfo *LI,
+bool llvm::hoistRegion(DomTreeNode *N, AAResults *AA, LoopInfo *LI,
DominatorTree *DT, TargetLibraryInfo *TLI, Loop *CurLoop,
AliasSetTracker *CurAST, MemorySSAUpdater *MSSAU,
ScalarEvolution *SE, ICFLoopSafetyInfo *SafetyInfo,
@@ -901,9 +822,8 @@ bool llvm::hoistRegion(DomTreeNode *N, AliasAnalysis *AA, LoopInfo *LI,
// Attempt to hoist floating-point division out of the loop by
// converting it to a reciprocal multiplication.
- if (I.getOpcode() == Instruction::FDiv &&
- CurLoop->isLoopInvariant(I.getOperand(1)) &&
- I.hasAllowReciprocal()) {
+ if (I.getOpcode() == Instruction::FDiv && I.hasAllowReciprocal() &&
+ CurLoop->isLoopInvariant(I.getOperand(1))) {
auto Divisor = I.getOperand(1);
auto One = llvm::ConstantFP::get(Divisor->getType(), 1.0);
auto ReciprocalDivisor = BinaryOperator::CreateFDiv(One, Divisor);
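The hunk above reorders the checks guarding LICM's reciprocal rewrite: when the divisor is loop invariant and the instruction allows reciprocal fast-math, the division is replaced by a multiplication with a hoisted 1.0/divisor. A hand-written C++ analogue of the effect (accepting the accuracy trade-off the fast-math flag already permits):

#include <cstdio>
#include <vector>

double scaleSum(const std::vector<double> &Xs, double Divisor) {
  const double Reciprocal = 1.0 / Divisor; // computed once, outside the loop
  double Sum = 0.0;
  for (double X : Xs)
    Sum += X * Reciprocal;                 // was: X / Divisor
  return Sum;
}

int main() {
  std::printf("%f\n", scaleSum({2.0, 4.0, 6.0}, 2.0)); // 6.000000
  return 0;
}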
@@ -945,16 +865,6 @@ bool llvm::hoistRegion(DomTreeNode *N, AliasAnalysis *AA, LoopInfo *LI,
continue;
}
- if (isa<AllocaInst>(&I) &&
- SafetyInfo->isGuaranteedToExecute(I, DT, CurLoop) &&
- canRewriteUsesOfAlloca(cast<AllocaInst>(I))) {
- hoist(I, DT, CurLoop, CFH.getOrCreateHoistedBlock(BB), SafetyInfo,
- MSSAU, SE, ORE);
- HoistedInstructions.push_back(&I);
- Changed = true;
- continue;
- }
-
if (PHINode *PN = dyn_cast<PHINode>(&I)) {
if (CFH.canHoistPHI(PN)) {
// Redirect incoming blocks first to ensure that we create hoisted
@@ -1081,12 +991,12 @@ namespace {
bool isHoistableAndSinkableInst(Instruction &I) {
// Only these instructions are hoistable/sinkable.
return (isa<LoadInst>(I) || isa<StoreInst>(I) || isa<CallInst>(I) ||
- isa<FenceInst>(I) || isa<CastInst>(I) ||
- isa<UnaryOperator>(I) || isa<BinaryOperator>(I) ||
- isa<SelectInst>(I) || isa<GetElementPtrInst>(I) || isa<CmpInst>(I) ||
+ isa<FenceInst>(I) || isa<CastInst>(I) || isa<UnaryOperator>(I) ||
+ isa<BinaryOperator>(I) || isa<SelectInst>(I) ||
+ isa<GetElementPtrInst>(I) || isa<CmpInst>(I) ||
isa<InsertElementInst>(I) || isa<ExtractElementInst>(I) ||
isa<ShuffleVectorInst>(I) || isa<ExtractValueInst>(I) ||
- isa<InsertValueInst>(I));
+ isa<InsertValueInst>(I) || isa<FreezeInst>(I));
}
/// Return true if all of the alias sets within this AST are known not to
/// contain a Mod, or if MSSA knows there are no MemoryDefs in the loop.
@@ -1198,11 +1108,11 @@ bool llvm::canSinkOrHoistInst(Instruction &I, AAResults *AA, DominatorTree *DT,
FunctionModRefBehavior Behavior = AA->getModRefBehavior(CI);
if (Behavior == FMRB_DoesNotAccessMemory)
return true;
- if (AliasAnalysis::onlyReadsMemory(Behavior)) {
+ if (AAResults::onlyReadsMemory(Behavior)) {
// A readonly argmemonly function only reads from memory pointed to by
// its arguments with arbitrary offsets. If we can prove there are no
// writes to this memory in the loop, we can hoist or sink.
- if (AliasAnalysis::onlyAccessesArgPointees(Behavior)) {
+ if (AAResults::onlyAccessesArgPointees(Behavior)) {
// TODO: expand to writeable arguments
for (Value *Op : CI->arg_operands())
if (Op->getType()->isPointerTy()) {
@@ -1351,7 +1261,8 @@ static bool isFreeInLoop(const Instruction &I, const Loop *CurLoop,
const TargetTransformInfo *TTI) {
if (const GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(&I)) {
- if (TTI->getUserCost(GEP) != TargetTransformInfo::TCC_Free)
+ if (TTI->getUserCost(GEP, TargetTransformInfo::TCK_SizeAndLatency) !=
+ TargetTransformInfo::TCC_Free)
return false;
// For a GEP, we cannot simply use getUserCost because currently it
// optimistically assumes that a GEP will fold into addressing mode
@@ -1366,7 +1277,8 @@ static bool isFreeInLoop(const Instruction &I, const Loop *CurLoop,
}
return true;
} else
- return TTI->getUserCost(&I) == TargetTransformInfo::TCC_Free;
+ return TTI->getUserCost(&I, TargetTransformInfo::TCK_SizeAndLatency) ==
+ TargetTransformInfo::TCC_Free;
}
/// Return true if the only users of this instruction are outside of
@@ -1407,7 +1319,7 @@ static bool isNotUsedOrFreeInLoop(const Instruction &I, const Loop *CurLoop,
return true;
}
-static Instruction *CloneInstructionInExitBlock(
+static Instruction *cloneInstructionInExitBlock(
Instruction &I, BasicBlock &ExitBlock, PHINode &PN, const LoopInfo *LI,
const LoopSafetyInfo *SafetyInfo, MemorySSAUpdater *MSSAU) {
Instruction *New;
@@ -1520,7 +1432,7 @@ static Instruction *sinkThroughTriviallyReplaceablePHI(
if (It != SunkCopies.end())
New = It->second;
else
- New = SunkCopies[ExitBlock] = CloneInstructionInExitBlock(
+ New = SunkCopies[ExitBlock] = cloneInstructionInExitBlock(
*I, *ExitBlock, *TPN, LI, SafetyInfo, MSSAU);
return New;
}
@@ -1537,7 +1449,8 @@ static bool canSplitPredecessors(PHINode *PN, LoopSafetyInfo *SafetyInfo) {
return false;
for (pred_iterator PI = pred_begin(BB), E = pred_end(BB); PI != E; ++PI) {
BasicBlock *BBPred = *PI;
- if (isa<IndirectBrInst>(BBPred->getTerminator()))
+ if (isa<IndirectBrInst>(BBPred->getTerminator()) ||
+ isa<CallBrInst>(BBPred->getTerminator()))
return false;
}
return true;
@@ -1857,7 +1770,7 @@ public:
StoreInst *NewSI = new StoreInst(LiveInValue, Ptr, InsertPos);
if (UnorderedAtomic)
NewSI->setOrdering(AtomicOrdering::Unordered);
- NewSI->setAlignment(MaybeAlign(Alignment));
+ NewSI->setAlignment(Align(Alignment));
NewSI->setDebugLoc(DL);
if (AATags)
NewSI->setAAMetadata(AATags);
@@ -1981,7 +1894,7 @@ bool llvm::promoteLoopAccessesToScalars(
// We start with an alignment of one and try to find instructions that allow
// us to prove better alignment.
- unsigned Alignment = 1;
+ Align Alignment;
// Keep track of which types of access we see
bool SawUnorderedAtomic = false;
bool SawNotAtomic = false;
@@ -2029,10 +1942,7 @@ bool llvm::promoteLoopAccessesToScalars(
SawUnorderedAtomic |= Load->isAtomic();
SawNotAtomic |= !Load->isAtomic();
- unsigned InstAlignment = Load->getAlignment();
- if (!InstAlignment)
- InstAlignment =
- MDL.getABITypeAlignment(Load->getType());
+ Align InstAlignment = Load->getAlign();
// Note that proving a load safe to speculate requires proving
// sufficient alignment at the target location. Proving it guaranteed
@@ -2060,10 +1970,7 @@ bool llvm::promoteLoopAccessesToScalars(
// already know that promotion is safe, since it may have higher
// alignment than any other guaranteed stores, in which case we can
// raise the alignment on the promoted store.
- unsigned InstAlignment = Store->getAlignment();
- if (!InstAlignment)
- InstAlignment =
- MDL.getABITypeAlignment(Store->getValueOperand()->getType());
+ Align InstAlignment = Store->getAlign();
if (!DereferenceableInPH || !SafeToInsertStore ||
(InstAlignment > Alignment)) {
@@ -2090,8 +1997,7 @@ bool llvm::promoteLoopAccessesToScalars(
if (!DereferenceableInPH) {
DereferenceableInPH = isDereferenceableAndAlignedPointer(
Store->getPointerOperand(), Store->getValueOperand()->getType(),
- MaybeAlign(Store->getAlignment()), MDL,
- Preheader->getTerminator(), DT);
+ Store->getAlign(), MDL, Preheader->getTerminator(), DT);
}
} else
return false; // Not a load or store.
@@ -2156,18 +2062,19 @@ bool llvm::promoteLoopAccessesToScalars(
});
++NumPromoted;
- // Grab a debug location for the inserted loads/stores; given that the
- // inserted loads/stores have little relation to the original loads/stores,
- // this code just arbitrarily picks a location from one, since any debug
- // location is better than none.
- DebugLoc DL = LoopUses[0]->getDebugLoc();
+ // Look at all the loop uses, and try to merge their locations.
+ std::vector<const DILocation *> LoopUsesLocs;
+ for (auto U : LoopUses)
+ LoopUsesLocs.push_back(U->getDebugLoc().get());
+ auto DL = DebugLoc(DILocation::getMergedLocations(LoopUsesLocs));
// We use the SSAUpdater interface to insert phi nodes as required.
SmallVector<PHINode *, 16> NewPHIs;
SSAUpdater SSA(&NewPHIs);
LoopPromoter Promoter(SomePtr, LoopUses, SSA, PointerMustAliases, ExitBlocks,
InsertPts, MSSAInsertPts, PIC, *CurAST, MSSAU, *LI, DL,
- Alignment, SawUnorderedAtomic, AATags, *SafetyInfo);
+ Alignment.value(), SawUnorderedAtomic, AATags,
+ *SafetyInfo);
// Set up the preheader to have a definition of the value. It is the live-out
// value from the preheader that uses in the loop will use.
@@ -2176,8 +2083,8 @@ bool llvm::promoteLoopAccessesToScalars(
SomePtr->getName() + ".promoted", Preheader->getTerminator());
if (SawUnorderedAtomic)
PreheaderLoad->setOrdering(AtomicOrdering::Unordered);
- PreheaderLoad->setAlignment(MaybeAlign(Alignment));
- PreheaderLoad->setDebugLoc(DL);
+ PreheaderLoad->setAlignment(Alignment);
+ PreheaderLoad->setDebugLoc(DebugLoc());
if (AATags)
PreheaderLoad->setAAMetadata(AATags);
SSA.AddAvailableValue(Preheader, PreheaderLoad);
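The hunks above are part of promoteLoopAccessesToScalars, which keeps a loop-invariant memory location in a register for the duration of the loop, with a single preheader load and a store on exit. A hand-written analogue of the before/after, with invented names and ignoring the aliasing/alignment legality checks the pass actually performs:

#include <cstdio>

// Before promotion the loop would load and store *Counter on every iteration;
// after promotion the location lives in a local for the whole loop.
void countUp(int *Counter, int N) {
  int Tmp = *Counter;        // single load in the preheader ("<ptr>.promoted")
  for (int I = 0; I < N; ++I)
    Tmp += 1;                // was: *Counter += 1;
  *Counter = Tmp;            // single store on loop exit
}

int main() {
  int C = 3;
  countUp(&C, 5);
  std::printf("%d\n", C); // 8
  return 0;
}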
@@ -2206,41 +2113,13 @@ bool llvm::promoteLoopAccessesToScalars(
/// Returns an owning pointer to an alias set which incorporates aliasing info
/// from L and all subloops of L.
-/// FIXME: In new pass manager, there is no helper function to handle loop
-/// analysis such as cloneBasicBlockAnalysis, so the AST needs to be recomputed
-/// from scratch for every loop. Hook up with the helper functions when
-/// available in the new pass manager to avoid redundant computation.
std::unique_ptr<AliasSetTracker>
LoopInvariantCodeMotion::collectAliasInfoForLoop(Loop *L, LoopInfo *LI,
- AliasAnalysis *AA) {
- std::unique_ptr<AliasSetTracker> CurAST;
- SmallVector<Loop *, 4> RecomputeLoops;
- for (Loop *InnerL : L->getSubLoops()) {
- auto MapI = LoopToAliasSetMap.find(InnerL);
- // If the AST for this inner loop is missing it may have been merged into
- // some other loop's AST and then that loop unrolled, and so we need to
- // recompute it.
- if (MapI == LoopToAliasSetMap.end()) {
- RecomputeLoops.push_back(InnerL);
- continue;
- }
- std::unique_ptr<AliasSetTracker> InnerAST = std::move(MapI->second);
+ AAResults *AA) {
+ auto CurAST = std::make_unique<AliasSetTracker>(*AA);
- if (CurAST) {
- // What if InnerLoop was modified by other passes ?
- // Once we've incorporated the inner loop's AST into ours, we don't need
- // the subloop's anymore.
- CurAST->add(*InnerAST);
- } else {
- CurAST = std::move(InnerAST);
- }
- LoopToAliasSetMap.erase(MapI);
- }
- if (!CurAST)
- CurAST = std::make_unique<AliasSetTracker>(*AA);
-
- // Add everything from the sub loops that are no longer directly available.
- for (Loop *InnerL : RecomputeLoops)
+ // Add everything from all the sub loops.
+ for (Loop *InnerL : L->getSubLoops())
for (BasicBlock *BB : InnerL->blocks())
CurAST->add(*BB);
@@ -2254,46 +2133,16 @@ LoopInvariantCodeMotion::collectAliasInfoForLoop(Loop *L, LoopInfo *LI,
std::unique_ptr<AliasSetTracker>
LoopInvariantCodeMotion::collectAliasInfoForLoopWithMSSA(
- Loop *L, AliasAnalysis *AA, MemorySSAUpdater *MSSAU) {
+ Loop *L, AAResults *AA, MemorySSAUpdater *MSSAU) {
auto *MSSA = MSSAU->getMemorySSA();
auto CurAST = std::make_unique<AliasSetTracker>(*AA, MSSA, L);
CurAST->addAllInstructionsInLoopUsingMSSA();
return CurAST;
}
-/// Simple analysis hook. Clone alias set info.
-///
-void LegacyLICMPass::cloneBasicBlockAnalysis(BasicBlock *From, BasicBlock *To,
- Loop *L) {
- auto ASTIt = LICM.getLoopToAliasSetMap().find(L);
- if (ASTIt == LICM.getLoopToAliasSetMap().end())
- return;
-
- ASTIt->second->copyValue(From, To);
-}
-
-/// Simple Analysis hook. Delete value V from alias set
-///
-void LegacyLICMPass::deleteAnalysisValue(Value *V, Loop *L) {
- auto ASTIt = LICM.getLoopToAliasSetMap().find(L);
- if (ASTIt == LICM.getLoopToAliasSetMap().end())
- return;
-
- ASTIt->second->deleteValue(V);
-}
-
-/// Simple Analysis hook. Delete value L from alias set map.
-///
-void LegacyLICMPass::deleteAnalysisLoop(Loop *L) {
- if (!LICM.getLoopToAliasSetMap().count(L))
- return;
-
- LICM.getLoopToAliasSetMap().erase(L);
-}
-
static bool pointerInvalidatedByLoop(MemoryLocation MemLoc,
AliasSetTracker *CurAST, Loop *CurLoop,
- AliasAnalysis *AA) {
+ AAResults *AA) {
// First check to see if any of the basic blocks in CurLoop invalidate *V.
bool isInvalidatedAccordingToAST = CurAST->getAliasSetFor(MemLoc).isMod();
diff --git a/llvm/lib/Transforms/Scalar/LoopDataPrefetch.cpp b/llvm/lib/Transforms/Scalar/LoopDataPrefetch.cpp
index ab65f56d088f..687e14d6d7d2 100644
--- a/llvm/lib/Transforms/Scalar/LoopDataPrefetch.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopDataPrefetch.cpp
@@ -21,7 +21,6 @@
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/ScalarEvolution.h"
-#include "llvm/Analysis/ScalarEvolutionExpander.h"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/CFG.h"
@@ -32,6 +31,7 @@
#include "llvm/Support/Debug.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
+#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
#include "llvm/Transforms/Utils/ValueMapper.h"
using namespace llvm;
@@ -61,10 +61,10 @@ namespace {
/// Loop prefetch implementation class.
class LoopDataPrefetch {
public:
- LoopDataPrefetch(AssumptionCache *AC, LoopInfo *LI, ScalarEvolution *SE,
- const TargetTransformInfo *TTI,
+ LoopDataPrefetch(AssumptionCache *AC, DominatorTree *DT, LoopInfo *LI,
+ ScalarEvolution *SE, const TargetTransformInfo *TTI,
OptimizationRemarkEmitter *ORE)
- : AC(AC), LI(LI), SE(SE), TTI(TTI), ORE(ORE) {}
+ : AC(AC), DT(DT), LI(LI), SE(SE), TTI(TTI), ORE(ORE) {}
bool run();
@@ -73,12 +73,16 @@ private:
/// Check if the stride of the accesses is large enough to
/// warrant a prefetch.
- bool isStrideLargeEnough(const SCEVAddRecExpr *AR);
+ bool isStrideLargeEnough(const SCEVAddRecExpr *AR, unsigned TargetMinStride);
- unsigned getMinPrefetchStride() {
+ unsigned getMinPrefetchStride(unsigned NumMemAccesses,
+ unsigned NumStridedMemAccesses,
+ unsigned NumPrefetches,
+ bool HasCall) {
if (MinPrefetchStride.getNumOccurrences() > 0)
return MinPrefetchStride;
- return TTI->getMinPrefetchStride();
+ return TTI->getMinPrefetchStride(NumMemAccesses, NumStridedMemAccesses,
+ NumPrefetches, HasCall);
}
unsigned getPrefetchDistance() {
@@ -93,7 +97,14 @@ private:
return TTI->getMaxPrefetchIterationsAhead();
}
+ bool doPrefetchWrites() {
+ if (PrefetchWrites.getNumOccurrences() > 0)
+ return PrefetchWrites;
+ return TTI->enableWritePrefetching();
+ }
+
AssumptionCache *AC;
+ DominatorTree *DT;
LoopInfo *LI;
ScalarEvolution *SE;
const TargetTransformInfo *TTI;
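getMinPrefetchStride and doPrefetchWrites above follow the same pattern: an explicit command-line value wins, otherwise the target hook decides. A minimal sketch of that pattern outside LLVM (a std::optional stands in for cl::opt's getNumOccurrences check; the heuristic and numbers are invented):

#include <cstdio>
#include <optional>

static std::optional<unsigned> MinPrefetchStrideFlag; // unset: no flag given

static unsigned targetMinPrefetchStride(unsigned NumMemAccesses, bool HasCall) {
  // Stand-in heuristic: demand a larger stride when calls are present.
  (void)NumMemAccesses;
  return HasCall ? 2048 : 1024;
}

static unsigned getMinPrefetchStride(unsigned NumMemAccesses, bool HasCall) {
  if (MinPrefetchStrideFlag)          // explicit flag wins
    return *MinPrefetchStrideFlag;
  return targetMinPrefetchStride(NumMemAccesses, HasCall);
}

int main() {
  std::printf("%u\n", getMinPrefetchStride(8, false)); // 1024
  return 0;
}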
@@ -110,6 +121,7 @@ public:
void getAnalysisUsage(AnalysisUsage &AU) const override {
AU.addRequired<AssumptionCacheTracker>();
+ AU.addRequired<DominatorTreeWrapperPass>();
AU.addPreserved<DominatorTreeWrapperPass>();
AU.addRequired<LoopInfoWrapperPass>();
AU.addPreserved<LoopInfoWrapperPass>();
@@ -138,8 +150,8 @@ FunctionPass *llvm::createLoopDataPrefetchPass() {
return new LoopDataPrefetchLegacyPass();
}
-bool LoopDataPrefetch::isStrideLargeEnough(const SCEVAddRecExpr *AR) {
- unsigned TargetMinStride = getMinPrefetchStride();
+bool LoopDataPrefetch::isStrideLargeEnough(const SCEVAddRecExpr *AR,
+ unsigned TargetMinStride) {
// If any stride is acceptable to the target, there is no need to check.
if (TargetMinStride <= 1)
return true;
@@ -156,6 +168,7 @@ bool LoopDataPrefetch::isStrideLargeEnough(const SCEVAddRecExpr *AR) {
PreservedAnalyses LoopDataPrefetchPass::run(Function &F,
FunctionAnalysisManager &AM) {
+ DominatorTree *DT = &AM.getResult<DominatorTreeAnalysis>(F);
LoopInfo *LI = &AM.getResult<LoopAnalysis>(F);
ScalarEvolution *SE = &AM.getResult<ScalarEvolutionAnalysis>(F);
AssumptionCache *AC = &AM.getResult<AssumptionAnalysis>(F);
@@ -163,7 +176,7 @@ PreservedAnalyses LoopDataPrefetchPass::run(Function &F,
&AM.getResult<OptimizationRemarkEmitterAnalysis>(F);
const TargetTransformInfo *TTI = &AM.getResult<TargetIRAnalysis>(F);
- LoopDataPrefetch LDP(AC, LI, SE, TTI, ORE);
+ LoopDataPrefetch LDP(AC, DT, LI, SE, TTI, ORE);
bool Changed = LDP.run();
if (Changed) {
@@ -180,6 +193,7 @@ bool LoopDataPrefetchLegacyPass::runOnFunction(Function &F) {
if (skipFunction(F))
return false;
+ DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
LoopInfo *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
ScalarEvolution *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
AssumptionCache *AC =
@@ -189,7 +203,7 @@ bool LoopDataPrefetchLegacyPass::runOnFunction(Function &F) {
const TargetTransformInfo *TTI =
&getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
- LoopDataPrefetch LDP(AC, LI, SE, TTI, ORE);
+ LoopDataPrefetch LDP(AC, DT, LI, SE, TTI, ORE);
return LDP.run();
}
@@ -210,6 +224,49 @@ bool LoopDataPrefetch::run() {
return MadeChange;
}
+/// A record for a potential prefetch made during the initial scan of the
+/// loop. This is used to let a single prefetch target multiple memory accesses.
+struct Prefetch {
+ /// The address formula for this prefetch as returned by ScalarEvolution.
+ const SCEVAddRecExpr *LSCEVAddRec;
+ /// The point of insertion for the prefetch instruction.
+ Instruction *InsertPt;
+ /// True if targeting a write memory access.
+ bool Writes;
+ /// The (first seen) prefetched instruction.
+ Instruction *MemI;
+
+ /// Constructor to create a new Prefetch for \p I.
+ Prefetch(const SCEVAddRecExpr *L, Instruction *I)
+ : LSCEVAddRec(L), InsertPt(nullptr), Writes(false), MemI(nullptr) {
+ addInstruction(I);
+ }
+
+ /// Add the instruction \p I to this prefetch. If it's not the first
+ /// one, 'InsertPt' and 'Writes' will be updated as required.
+ /// \param PtrDiff the known constant address difference to the first added
+ /// instruction.
+ void addInstruction(Instruction *I, DominatorTree *DT = nullptr,
+ int64_t PtrDiff = 0) {
+ if (!InsertPt) {
+ MemI = I;
+ InsertPt = I;
+ Writes = isa<StoreInst>(I);
+ } else {
+ BasicBlock *PrefBB = InsertPt->getParent();
+ BasicBlock *InsBB = I->getParent();
+ if (PrefBB != InsBB) {
+ BasicBlock *DomBB = DT->findNearestCommonDominator(PrefBB, InsBB);
+ if (DomBB != PrefBB)
+ InsertPt = DomBB->getTerminator();
+ }
+
+ if (isa<StoreInst>(I) && PtrDiff == 0)
+ Writes = true;
+ }
+ }
+};
+
bool LoopDataPrefetch::runOnLoop(Loop *L) {
bool MadeChange = false;
@@ -222,15 +279,22 @@ bool LoopDataPrefetch::runOnLoop(Loop *L) {
// Calculate the number of iterations ahead to prefetch
CodeMetrics Metrics;
+ bool HasCall = false;
for (const auto BB : L->blocks()) {
// If the loop already has prefetches, then assume that the user knows
// what they are doing and don't add any more.
- for (auto &I : *BB)
- if (CallInst *CI = dyn_cast<CallInst>(&I))
- if (Function *F = CI->getCalledFunction())
+ for (auto &I : *BB) {
+ if (isa<CallInst>(&I) || isa<InvokeInst>(&I)) {
+ if (const Function *F = cast<CallBase>(I).getCalledFunction()) {
if (F->getIntrinsicID() == Intrinsic::prefetch)
return MadeChange;
-
+ if (TTI->isLoweredToCall(F))
+ HasCall = true;
+ } else { // indirect call.
+ HasCall = true;
+ }
+ }
+ }
Metrics.analyzeBasicBlock(BB, *TTI, EphValues);
}
unsigned LoopSize = Metrics.NumInsts;
@@ -244,12 +308,14 @@ bool LoopDataPrefetch::runOnLoop(Loop *L) {
if (ItersAhead > getMaxPrefetchIterationsAhead())
return MadeChange;
- LLVM_DEBUG(dbgs() << "Prefetching " << ItersAhead
- << " iterations ahead (loop size: " << LoopSize << ") in "
- << L->getHeader()->getParent()->getName() << ": " << *L);
+ unsigned ConstantMaxTripCount = SE->getSmallConstantMaxTripCount(L);
+ if (ConstantMaxTripCount && ConstantMaxTripCount < ItersAhead + 1)
+ return MadeChange;
- SmallVector<std::pair<Instruction *, const SCEVAddRecExpr *>, 16> PrefLoads;
- for (const auto BB : L->blocks()) {
+ unsigned NumMemAccesses = 0;
+ unsigned NumStridedMemAccesses = 0;
+ SmallVector<Prefetch, 16> Prefetches;
+ for (const auto BB : L->blocks())
for (auto &I : *BB) {
Value *PtrValue;
Instruction *MemI;
@@ -258,7 +324,7 @@ bool LoopDataPrefetch::runOnLoop(Loop *L) {
MemI = LMemI;
PtrValue = LMemI->getPointerOperand();
} else if (StoreInst *SMemI = dyn_cast<StoreInst>(&I)) {
- if (!PrefetchWrites) continue;
+ if (!doPrefetchWrites()) continue;
MemI = SMemI;
PtrValue = SMemI->getPointerOperand();
} else continue;
@@ -266,7 +332,7 @@ bool LoopDataPrefetch::runOnLoop(Loop *L) {
unsigned PtrAddrSpace = PtrValue->getType()->getPointerAddressSpace();
if (PtrAddrSpace)
continue;
-
+ NumMemAccesses++;
if (L->isLoopInvariant(PtrValue))
continue;
@@ -274,62 +340,79 @@ bool LoopDataPrefetch::runOnLoop(Loop *L) {
const SCEVAddRecExpr *LSCEVAddRec = dyn_cast<SCEVAddRecExpr>(LSCEV);
if (!LSCEVAddRec)
continue;
+ NumStridedMemAccesses++;
- // Check if the stride of the accesses is large enough to warrant a
- // prefetch.
- if (!isStrideLargeEnough(LSCEVAddRec))
- continue;
-
- // We don't want to double prefetch individual cache lines. If this load
- // is known to be within one cache line of some other load that has
- // already been prefetched, then don't prefetch this one as well.
+ // We don't want to double prefetch individual cache lines. If this
+ // access is known to be within one cache line of some other one that
+ // has already been prefetched, then don't prefetch this one as well.
bool DupPref = false;
- for (const auto &PrefLoad : PrefLoads) {
- const SCEV *PtrDiff = SE->getMinusSCEV(LSCEVAddRec, PrefLoad.second);
+ for (auto &Pref : Prefetches) {
+ const SCEV *PtrDiff = SE->getMinusSCEV(LSCEVAddRec, Pref.LSCEVAddRec);
if (const SCEVConstant *ConstPtrDiff =
dyn_cast<SCEVConstant>(PtrDiff)) {
int64_t PD = std::abs(ConstPtrDiff->getValue()->getSExtValue());
if (PD < (int64_t) TTI->getCacheLineSize()) {
+ Pref.addInstruction(MemI, DT, PD);
DupPref = true;
break;
}
}
}
- if (DupPref)
- continue;
+ if (!DupPref)
+ Prefetches.push_back(Prefetch(LSCEVAddRec, MemI));
+ }
- const SCEV *NextLSCEV = SE->getAddExpr(LSCEVAddRec, SE->getMulExpr(
- SE->getConstant(LSCEVAddRec->getType(), ItersAhead),
- LSCEVAddRec->getStepRecurrence(*SE)));
- if (!isSafeToExpand(NextLSCEV, *SE))
- continue;
+ unsigned TargetMinStride =
+ getMinPrefetchStride(NumMemAccesses, NumStridedMemAccesses,
+ Prefetches.size(), HasCall);
- PrefLoads.push_back(std::make_pair(MemI, LSCEVAddRec));
-
- Type *I8Ptr = Type::getInt8PtrTy(BB->getContext(), PtrAddrSpace);
- SCEVExpander SCEVE(*SE, I.getModule()->getDataLayout(), "prefaddr");
- Value *PrefPtrValue = SCEVE.expandCodeFor(NextLSCEV, I8Ptr, MemI);
-
- IRBuilder<> Builder(MemI);
- Module *M = BB->getParent()->getParent();
- Type *I32 = Type::getInt32Ty(BB->getContext());
- Function *PrefetchFunc = Intrinsic::getDeclaration(
- M, Intrinsic::prefetch, PrefPtrValue->getType());
- Builder.CreateCall(
- PrefetchFunc,
- {PrefPtrValue,
- ConstantInt::get(I32, MemI->mayReadFromMemory() ? 0 : 1),
- ConstantInt::get(I32, 3), ConstantInt::get(I32, 1)});
- ++NumPrefetches;
- LLVM_DEBUG(dbgs() << " Access: " << *PtrValue << ", SCEV: " << *LSCEV
- << "\n");
- ORE->emit([&]() {
- return OptimizationRemark(DEBUG_TYPE, "Prefetched", MemI)
- << "prefetched memory access";
+ LLVM_DEBUG(dbgs() << "Prefetching " << ItersAhead
+ << " iterations ahead (loop size: " << LoopSize << ") in "
+ << L->getHeader()->getParent()->getName() << ": " << *L);
+ LLVM_DEBUG(dbgs() << "Loop has: "
+ << NumMemAccesses << " memory accesses, "
+ << NumStridedMemAccesses << " strided memory accesses, "
+ << Prefetches.size() << " potential prefetch(es), "
+ << "a minimum stride of " << TargetMinStride << ", "
+ << (HasCall ? "calls" : "no calls") << ".\n");
+
+ for (auto &P : Prefetches) {
+ // Check if the stride of the accesses is large enough to warrant a
+ // prefetch.
+ if (!isStrideLargeEnough(P.LSCEVAddRec, TargetMinStride))
+ continue;
+
+ const SCEV *NextLSCEV = SE->getAddExpr(P.LSCEVAddRec, SE->getMulExpr(
+ SE->getConstant(P.LSCEVAddRec->getType(), ItersAhead),
+ P.LSCEVAddRec->getStepRecurrence(*SE)));
+ if (!isSafeToExpand(NextLSCEV, *SE))
+ continue;
+
+ BasicBlock *BB = P.InsertPt->getParent();
+ Type *I8Ptr = Type::getInt8PtrTy(BB->getContext(), 0/*PtrAddrSpace*/);
+ SCEVExpander SCEVE(*SE, BB->getModule()->getDataLayout(), "prefaddr");
+ Value *PrefPtrValue = SCEVE.expandCodeFor(NextLSCEV, I8Ptr, P.InsertPt);
+
+ IRBuilder<> Builder(P.InsertPt);
+ Module *M = BB->getParent()->getParent();
+ Type *I32 = Type::getInt32Ty(BB->getContext());
+ Function *PrefetchFunc = Intrinsic::getDeclaration(
+ M, Intrinsic::prefetch, PrefPtrValue->getType());
+ Builder.CreateCall(
+ PrefetchFunc,
+ {PrefPtrValue,
+ ConstantInt::get(I32, P.Writes),
+ ConstantInt::get(I32, 3), ConstantInt::get(I32, 1)});
+ ++NumPrefetches;
+ LLVM_DEBUG(dbgs() << " Access: "
+ << *P.MemI->getOperand(isa<LoadInst>(P.MemI) ? 0 : 1)
+ << ", SCEV: " << *P.LSCEVAddRec << "\n");
+ ORE->emit([&]() {
+ return OptimizationRemark(DEBUG_TYPE, "Prefetched", P.MemI)
+ << "prefetched memory access";
});
- MadeChange = true;
- }
+ MadeChange = true;
}
return MadeChange;
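The de-duplication loop above folds accesses that land within one cache line into an existing Prefetch record rather than issuing another prefetch. A simplified standalone model of that logic using plain integer addresses (cache-line size and offsets are invented):

#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <vector>

int main() {
  const int64_t CacheLineSize = 64;
  const int64_t Accesses[] = {0, 16, 128};   // byte offsets; 0 and 16 share a line
  std::vector<int64_t> Prefetches;           // one entry per issued prefetch
  for (int64_t A : Accesses) {
    bool Dup = false;
    for (int64_t P : Prefetches)
      if (std::llabs(A - P) < CacheLineSize) {
        Dup = true;                          // fold into the existing prefetch
        break;
      }
    if (!Dup)
      Prefetches.push_back(A);
  }
  std::printf("%zu prefetches\n", Prefetches.size()); // 2
  return 0;
}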
diff --git a/llvm/lib/Transforms/Scalar/LoopDeletion.cpp b/llvm/lib/Transforms/Scalar/LoopDeletion.cpp
index 2451572d6171..be209d34be42 100644
--- a/llvm/lib/Transforms/Scalar/LoopDeletion.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopDeletion.cpp
@@ -18,6 +18,8 @@
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/Analysis/LoopPass.h"
+#include "llvm/Analysis/MemorySSA.h"
+#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
@@ -134,7 +136,9 @@ static bool isLoopNeverExecuted(Loop *L) {
/// is unable to delete it due to hoisting trivially loop invariant
/// instructions out of the loop.
static LoopDeletionResult deleteLoopIfDead(Loop *L, DominatorTree &DT,
- ScalarEvolution &SE, LoopInfo &LI) {
+ ScalarEvolution &SE, LoopInfo &LI,
+ MemorySSA *MSSA,
+ OptimizationRemarkEmitter &ORE) {
assert(L->isLCSSAForm(DT) && "Expected LCSSA!");
// We can only remove the loop if there is a preheader that we can branch from
@@ -164,7 +168,12 @@ static LoopDeletionResult deleteLoopIfDead(Loop *L, DominatorTree &DT,
std::fill(P.incoming_values().begin(), P.incoming_values().end(),
UndefValue::get(P.getType()));
}
- deleteDeadLoop(L, &DT, &SE, &LI);
+ ORE.emit([&]() {
+ return OptimizationRemark(DEBUG_TYPE, "NeverExecutes", L->getStartLoc(),
+ L->getHeader())
+ << "Loop deleted because it never executes";
+ });
+ deleteDeadLoop(L, &DT, &SE, &LI, MSSA);
++NumDeleted;
return LoopDeletionResult::Deleted;
}
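deleteLoopIfDead now reports deletions through ORE.emit with a lambda, so the remark is only constructed when remarks are actually consumed. A small sketch of that lazy-emission pattern with a stand-in emitter (nothing here is LLVM API):

#include <cstdio>
#include <string>

struct RemarkEmitter {
  bool Enabled;
  // The remark text is only built when remarks are enabled, because emit()
  // takes a callable instead of an already-constructed remark.
  template <typename CallableT> void emit(CallableT &&MakeRemark) {
    if (Enabled)
      std::puts(MakeRemark().c_str());
  }
};

int main() {
  RemarkEmitter ORE{true};
  ORE.emit([&] { return std::string("Loop deleted because it never executes"); });
  return 0;
}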
@@ -200,7 +209,12 @@ static LoopDeletionResult deleteLoopIfDead(Loop *L, DominatorTree &DT,
}
LLVM_DEBUG(dbgs() << "Loop is invariant, delete it!");
- deleteDeadLoop(L, &DT, &SE, &LI);
+ ORE.emit([&]() {
+ return OptimizationRemark(DEBUG_TYPE, "Invariant", L->getStartLoc(),
+ L->getHeader())
+ << "Loop deleted because it is invariant";
+ });
+ deleteDeadLoop(L, &DT, &SE, &LI, MSSA);
++NumDeleted;
return LoopDeletionResult::Deleted;
@@ -212,15 +226,22 @@ PreservedAnalyses LoopDeletionPass::run(Loop &L, LoopAnalysisManager &AM,
LLVM_DEBUG(dbgs() << "Analyzing Loop for deletion: ");
LLVM_DEBUG(L.dump());
- std::string LoopName = L.getName();
- auto Result = deleteLoopIfDead(&L, AR.DT, AR.SE, AR.LI);
+ std::string LoopName = std::string(L.getName());
+ // For the new PM, we can't use OptimizationRemarkEmitter as an analysis
+ // pass. Function analyses need to be preserved across loop transformations
+ // but ORE cannot be preserved (see comment before the pass definition).
+ OptimizationRemarkEmitter ORE(L.getHeader()->getParent());
+ auto Result = deleteLoopIfDead(&L, AR.DT, AR.SE, AR.LI, AR.MSSA, ORE);
if (Result == LoopDeletionResult::Unmodified)
return PreservedAnalyses::all();
if (Result == LoopDeletionResult::Deleted)
Updater.markLoopAsDeleted(L, LoopName);
- return getLoopPassPreservedAnalyses();
+ auto PA = getLoopPassPreservedAnalyses();
+ if (AR.MSSA)
+ PA.preserve<MemorySSAAnalysis>();
+ return PA;
}
namespace {
@@ -235,6 +256,7 @@ public:
bool runOnLoop(Loop *L, LPPassManager &) override;
void getAnalysisUsage(AnalysisUsage &AU) const override {
+ AU.addPreserved<MemorySSAWrapperPass>();
getLoopAnalysisUsage(AU);
}
};
@@ -255,11 +277,19 @@ bool LoopDeletionLegacyPass::runOnLoop(Loop *L, LPPassManager &LPM) {
DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
ScalarEvolution &SE = getAnalysis<ScalarEvolutionWrapperPass>().getSE();
LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
+ auto *MSSAAnalysis = getAnalysisIfAvailable<MemorySSAWrapperPass>();
+ MemorySSA *MSSA = nullptr;
+ if (MSSAAnalysis)
+ MSSA = &MSSAAnalysis->getMSSA();
+ // For the old PM, we can't use OptimizationRemarkEmitter as an analysis
+ // pass. Function analyses need to be preserved across loop transformations
+ // but ORE cannot be preserved (see comment before the pass definition).
+ OptimizationRemarkEmitter ORE(L->getHeader()->getParent());
LLVM_DEBUG(dbgs() << "Analyzing Loop for deletion: ");
LLVM_DEBUG(L->dump());
- LoopDeletionResult Result = deleteLoopIfDead(L, DT, SE, LI);
+ LoopDeletionResult Result = deleteLoopIfDead(L, DT, SE, LI, MSSA, ORE);
if (Result == LoopDeletionResult::Deleted)
LPM.markLoopAsDeleted(*L);
diff --git a/llvm/lib/Transforms/Scalar/LoopDistribute.cpp b/llvm/lib/Transforms/Scalar/LoopDistribute.cpp
index 8e04e6e0ffe8..7867a5468891 100644
--- a/llvm/lib/Transforms/Scalar/LoopDistribute.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopDistribute.cpp
@@ -789,12 +789,6 @@ public:
// instructions to partitions.
Partitions.setupPartitionIdOnInstructions();
- // To keep things simple have an empty preheader before we version or clone
- // the loop. (Also split if this has no predecessor, i.e. entry, because we
- // rely on PH having a predecessor.)
- if (!PH->getSinglePredecessor() || &*PH->begin() != PH->getTerminator())
- SplitBlock(PH, PH->getTerminator(), DT, LI);
-
// If we need run-time checks, version the loop now.
auto PtrToPartition = Partitions.computePartitionSetForPointers(*LAI);
const auto *RtPtrChecking = LAI->getRuntimePointerChecking();
@@ -807,6 +801,12 @@ public:
"may not insert runtime check with convergent operation");
}
+ // To keep things simple, have an empty preheader before we version or clone
+ // the loop. (Also split if this has no predecessor, i.e. entry, because we
+ // rely on PH having a predecessor.)
+ if (!PH->getSinglePredecessor() || &*PH->begin() != PH->getTerminator())
+ SplitBlock(PH, PH->getTerminator(), DT, LI);
+
if (!Pred.isAlwaysTrue() || !Checks.empty()) {
assert(!LAI->hasConvergentOp() && "inserting illegal loop versioning");
@@ -903,15 +903,14 @@ private:
/// \p PtrToPartition contains the partition number for pointers. Partition
/// number -1 means that the pointer is used in multiple partitions. In this
/// case we can't safely omit the check.
- SmallVector<RuntimePointerChecking::PointerCheck, 4>
- includeOnlyCrossPartitionChecks(
- const SmallVectorImpl<RuntimePointerChecking::PointerCheck> &AllChecks,
+ SmallVector<RuntimePointerCheck, 4> includeOnlyCrossPartitionChecks(
+ const SmallVectorImpl<RuntimePointerCheck> &AllChecks,
const SmallVectorImpl<int> &PtrToPartition,
const RuntimePointerChecking *RtPtrChecking) {
- SmallVector<RuntimePointerChecking::PointerCheck, 4> Checks;
+ SmallVector<RuntimePointerCheck, 4> Checks;
copy_if(AllChecks, std::back_inserter(Checks),
- [&](const RuntimePointerChecking::PointerCheck &Check) {
+ [&](const RuntimePointerCheck &Check) {
for (unsigned PtrIdx1 : Check.first->Members)
for (unsigned PtrIdx2 : Check.second->Members)
// Only include this check if there is a pair of pointers
diff --git a/llvm/lib/Transforms/Scalar/LoopFuse.cpp b/llvm/lib/Transforms/Scalar/LoopFuse.cpp
index e1738f08eb23..20edc8699d79 100644
--- a/llvm/lib/Transforms/Scalar/LoopFuse.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopFuse.cpp
@@ -86,11 +86,15 @@ STATISTIC(UnknownTripCount, "Loop has unknown trip count");
STATISTIC(UncomputableTripCount, "SCEV cannot compute trip count of loop");
STATISTIC(NonEqualTripCount, "Loop trip counts are not the same");
STATISTIC(NonAdjacent, "Loops are not adjacent");
-STATISTIC(NonEmptyPreheader, "Loop has a non-empty preheader");
+STATISTIC(
+ NonEmptyPreheader,
+ "Loop has a non-empty preheader with instructions that cannot be moved");
STATISTIC(FusionNotBeneficial, "Fusion is not beneficial");
STATISTIC(NonIdenticalGuards, "Candidates have different guards");
-STATISTIC(NonEmptyExitBlock, "Candidate has a non-empty exit block");
-STATISTIC(NonEmptyGuardBlock, "Candidate has a non-empty guard block");
+STATISTIC(NonEmptyExitBlock, "Candidate has a non-empty exit block with "
+ "instructions that cannot be moved");
+STATISTIC(NonEmptyGuardBlock, "Candidate has a non-empty guard block with "
+ "instructions that cannot be moved");
STATISTIC(NotRotated, "Candidate is not rotated");
enum FusionDependenceAnalysisChoice {
@@ -738,33 +742,40 @@ private:
continue;
}
- // The following three checks look for empty blocks in FC0 and FC1. If
- // any of these blocks are non-empty, we do not fuse. This is done
- // because we currently do not have the safety checks to determine if
- // it is safe to move the blocks past other blocks in the loop. Once
- // these checks are added, these conditions can be relaxed.
- if (!isEmptyPreheader(*FC1)) {
- LLVM_DEBUG(dbgs() << "Fusion candidate does not have empty "
- "preheader. Not fusing.\n");
+ if (!isSafeToMoveBefore(*FC1->Preheader,
+ *FC0->Preheader->getTerminator(), DT, &PDT,
+ &DI)) {
+ LLVM_DEBUG(dbgs() << "Fusion candidate contains unsafe "
+ "instructions in preheader. Not fusing.\n");
reportLoopFusion<OptimizationRemarkMissed>(*FC0, *FC1,
NonEmptyPreheader);
continue;
}
- if (FC0->GuardBranch && !isEmptyExitBlock(*FC0)) {
- LLVM_DEBUG(dbgs() << "Fusion candidate does not have empty exit "
- "block. Not fusing.\n");
- reportLoopFusion<OptimizationRemarkMissed>(*FC0, *FC1,
- NonEmptyExitBlock);
- continue;
- }
+ if (FC0->GuardBranch) {
+ assert(FC1->GuardBranch && "Expecting valid FC1 guard branch");
+
+ if (!isSafeToMoveBefore(*FC0->ExitBlock,
+ *FC1->ExitBlock->getFirstNonPHIOrDbg(), DT,
+ &PDT, &DI)) {
+ LLVM_DEBUG(dbgs() << "Fusion candidate contains unsafe "
+ "instructions in exit block. Not fusing.\n");
+ reportLoopFusion<OptimizationRemarkMissed>(*FC0, *FC1,
+ NonEmptyExitBlock);
+ continue;
+ }
- if (FC1->GuardBranch && !isEmptyGuardBlock(*FC1)) {
- LLVM_DEBUG(dbgs() << "Fusion candidate does not have empty guard "
- "block. Not fusing.\n");
- reportLoopFusion<OptimizationRemarkMissed>(*FC0, *FC1,
- NonEmptyGuardBlock);
- continue;
+ if (!isSafeToMoveBefore(
+ *FC1->GuardBranch->getParent(),
+ *FC0->GuardBranch->getParent()->getTerminator(), DT, &PDT,
+ &DI)) {
+ LLVM_DEBUG(dbgs()
+ << "Fusion candidate contains unsafe "
+ "instructions in guard block. Not fusing.\n");
+ reportLoopFusion<OptimizationRemarkMissed>(*FC0, *FC1,
+ NonEmptyGuardBlock);
+ continue;
+ }
}
// Check the dependencies across the loops and do not fuse if it would
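Instead of insisting that the preheader, exit and guard blocks be empty, the checks above now ask isSafeToMoveBefore whether every instruction in those blocks can be relocated past the fusion point (the real query also consults the dominator/post-dominator trees and DependenceInfo). A deliberately simplified toy of what such a query has to establish, not LLVM's actual analysis:

#include <vector>

// Highly simplified instruction model: an instruction can be moved past other
// code only if it has no visible side effects and does not read memory that
// the intervening code might write.
struct ToyInst {
  bool WritesMemory = false;
  bool ReadsMemory = false;
  bool HasSideEffects = false;
};

bool canMoveBlockPast(const std::vector<ToyInst> &Block,
                      bool InterveningCodeWritesMemory) {
  for (const ToyInst &I : Block) {
    if (I.WritesMemory || I.HasSideEffects)
      return false; // relocating would reorder visible effects
    if (I.ReadsMemory && InterveningCodeWritesMemory)
      return false; // the loaded value could change across the move
  }
  return true;
}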
@@ -1075,38 +1086,6 @@ private:
return (FC1.GuardBranch->getSuccessor(1) == FC1.Preheader);
}
- /// Check that the guard for \p FC *only* contains the cmp/branch for the
- /// guard.
- /// Once we are able to handle intervening code, any code in the guard block
- /// for FC1 will need to be treated as intervening code and checked whether
- /// it can safely move around the loops.
- bool isEmptyGuardBlock(const FusionCandidate &FC) const {
- assert(FC.GuardBranch && "Expecting a fusion candidate with guard branch.");
- if (auto *CmpInst = dyn_cast<Instruction>(FC.GuardBranch->getCondition())) {
- auto *GuardBlock = FC.GuardBranch->getParent();
- // If the generation of the cmp value is in GuardBlock, then the size of
- // the guard block should be 2 (cmp + branch). If the generation of the
- // cmp value is in a different block, then the size of the guard block
- // should only be 1.
- if (CmpInst->getParent() == GuardBlock)
- return GuardBlock->size() == 2;
- else
- return GuardBlock->size() == 1;
- }
-
- return false;
- }
-
- bool isEmptyPreheader(const FusionCandidate &FC) const {
- assert(FC.Preheader && "Expecting a valid preheader");
- return FC.Preheader->size() == 1;
- }
-
- bool isEmptyExitBlock(const FusionCandidate &FC) const {
- assert(FC.ExitBlock && "Expecting a valid exit block");
- return FC.ExitBlock->size() == 1;
- }
-
/// Simplify the condition of the latch branch of \p FC to true, when both of
/// its successors are the same.
void simplifyLatchBranch(const FusionCandidate &FC) const {
@@ -1123,7 +1102,7 @@ private:
  /// Move instructions from FC0.Latch to FC1.Latch. If FC0.Latch has a unique
/// successor, then merge FC0.Latch with its unique successor.
void mergeLatch(const FusionCandidate &FC0, const FusionCandidate &FC1) {
- moveInstsBottomUp(*FC0.Latch, *FC1.Latch, DT, PDT, DI);
+ moveInstructionsToTheBeginning(*FC0.Latch, *FC1.Latch, DT, PDT, DI);
if (BasicBlock *Succ = FC0.Latch->getUniqueSuccessor()) {
MergeBlockIntoPredecessor(Succ, &DTU, &LI);
DTU.flush();
@@ -1166,6 +1145,10 @@ private:
LLVM_DEBUG(dbgs() << "Fusion Candidate 0: \n"; FC0.dump();
dbgs() << "Fusion Candidate 1: \n"; FC1.dump(););
+ // Move instructions from the preheader of FC1 to the end of the preheader
+ // of FC0.
+ moveInstructionsToTheEnd(*FC1.Preheader, *FC0.Preheader, DT, PDT, DI);
+
// Fusing guarded loops is handled slightly differently than non-guarded
// loops and has been broken out into a separate method instead of trying to
// intersperse the logic within a single method.
@@ -1382,6 +1365,14 @@ private:
BasicBlock *FC0NonLoopBlock = FC0.getNonLoopBlock();
BasicBlock *FC1NonLoopBlock = FC1.getNonLoopBlock();
+ // Move instructions from the exit block of FC0 to the beginning of the exit
+ // block of FC1.
+ moveInstructionsToTheBeginning(*FC0.ExitBlock, *FC1.ExitBlock, DT, PDT, DI);
+
+ // Move instructions from the guard block of FC1 to the end of the guard
+ // block of FC0.
+ moveInstructionsToTheEnd(*FC1GuardBlock, *FC0GuardBlock, DT, PDT, DI);
+
assert(FC0NonLoopBlock == FC1GuardBlock && "Loops are not adjacent");
SmallVector<DominatorTree::UpdateType, 8> TreeUpdates;
@@ -1394,6 +1385,7 @@ private:
// Thus, one path from the guard goes to the preheader for FC0 (and thus
// executes the new fused loop) and the other path goes to the NonLoopBlock
// for FC1 (where FC1 guard would have gone if FC1 was not executed).
+ FC1NonLoopBlock->replacePhiUsesWith(FC1GuardBlock, FC0GuardBlock);
FC0.GuardBranch->replaceUsesOfWith(FC0NonLoopBlock, FC1NonLoopBlock);
FC0.ExitBlock->getTerminator()->replaceUsesOfWith(FC1GuardBlock,
FC1.Header);
@@ -1545,7 +1537,10 @@ private:
// Update DT/PDT
DTU.applyUpdates(TreeUpdates);
+ LI.removeBlock(FC1GuardBlock);
LI.removeBlock(FC1.Preheader);
+ LI.removeBlock(FC0.ExitBlock);
+ DTU.deleteBB(FC1GuardBlock);
DTU.deleteBB(FC1.Preheader);
DTU.deleteBB(FC0.ExitBlock);
DTU.flush();
diff --git a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
index b77843d7cd71..3cb4df12e9b0 100644
--- a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
@@ -51,9 +51,11 @@
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/LoopPass.h"
#include "llvm/Analysis/MemoryLocation.h"
+#include "llvm/Analysis/MemorySSA.h"
+#include "llvm/Analysis/MemorySSAUpdater.h"
+#include "llvm/Analysis/MustExecute.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/ScalarEvolution.h"
-#include "llvm/Analysis/ScalarEvolutionExpander.h"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
@@ -91,6 +93,7 @@
#include "llvm/Transforms/Utils/BuildLibCalls.h"
#include "llvm/Transforms/Utils/Local.h"
#include "llvm/Transforms/Utils/LoopUtils.h"
+#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
#include <algorithm>
#include <cassert>
#include <cstdint>
@@ -123,15 +126,19 @@ class LoopIdiomRecognize {
const DataLayout *DL;
OptimizationRemarkEmitter &ORE;
bool ApplyCodeSizeHeuristics;
+ std::unique_ptr<MemorySSAUpdater> MSSAU;
public:
explicit LoopIdiomRecognize(AliasAnalysis *AA, DominatorTree *DT,
LoopInfo *LI, ScalarEvolution *SE,
TargetLibraryInfo *TLI,
- const TargetTransformInfo *TTI,
+ const TargetTransformInfo *TTI, MemorySSA *MSSA,
const DataLayout *DL,
OptimizationRemarkEmitter &ORE)
- : AA(AA), DT(DT), LI(LI), SE(SE), TLI(TLI), TTI(TTI), DL(DL), ORE(ORE) {}
+ : AA(AA), DT(DT), LI(LI), SE(SE), TLI(TLI), TTI(TTI), DL(DL), ORE(ORE) {
+ if (MSSA)
+ MSSAU = std::make_unique<MemorySSAUpdater>(MSSA);
+ }
bool runOnLoop(Loop *L);
@@ -224,13 +231,17 @@ public:
&getAnalysis<TargetTransformInfoWrapperPass>().getTTI(
*L->getHeader()->getParent());
const DataLayout *DL = &L->getHeader()->getModule()->getDataLayout();
+ auto *MSSAAnalysis = getAnalysisIfAvailable<MemorySSAWrapperPass>();
+ MemorySSA *MSSA = nullptr;
+ if (MSSAAnalysis)
+ MSSA = &MSSAAnalysis->getMSSA();
// For the old PM, we can't use OptimizationRemarkEmitter as an analysis
// pass. Function analyses need to be preserved across loop transformations
// but ORE cannot be preserved (see comment before the pass definition).
OptimizationRemarkEmitter ORE(L->getHeader()->getParent());
- LoopIdiomRecognize LIR(AA, DT, LI, SE, TLI, TTI, DL, ORE);
+ LoopIdiomRecognize LIR(AA, DT, LI, SE, TLI, TTI, MSSA, DL, ORE);
return LIR.runOnLoop(L);
}
@@ -239,6 +250,7 @@ public:
void getAnalysisUsage(AnalysisUsage &AU) const override {
AU.addRequired<TargetLibraryInfoWrapperPass>();
AU.addRequired<TargetTransformInfoWrapperPass>();
+ AU.addPreserved<MemorySSAWrapperPass>();
getLoopAnalysisUsage(AU);
}
};
@@ -252,23 +264,20 @@ PreservedAnalyses LoopIdiomRecognizePass::run(Loop &L, LoopAnalysisManager &AM,
LPMUpdater &) {
const auto *DL = &L.getHeader()->getModule()->getDataLayout();
- const auto &FAM =
- AM.getResult<FunctionAnalysisManagerLoopProxy>(L, AR).getManager();
- Function *F = L.getHeader()->getParent();
-
- auto *ORE = FAM.getCachedResult<OptimizationRemarkEmitterAnalysis>(*F);
- // FIXME: This should probably be optional rather than required.
- if (!ORE)
- report_fatal_error(
- "LoopIdiomRecognizePass: OptimizationRemarkEmitterAnalysis not cached "
- "at a higher level");
+ // For the new PM, we also can't use OptimizationRemarkEmitter as an analysis
+ // pass. Function analyses need to be preserved across loop transformations
+ // but ORE cannot be preserved (see comment before the pass definition).
+ OptimizationRemarkEmitter ORE(L.getHeader()->getParent());
- LoopIdiomRecognize LIR(&AR.AA, &AR.DT, &AR.LI, &AR.SE, &AR.TLI, &AR.TTI, DL,
- *ORE);
+ LoopIdiomRecognize LIR(&AR.AA, &AR.DT, &AR.LI, &AR.SE, &AR.TLI, &AR.TTI,
+ AR.MSSA, DL, ORE);
if (!LIR.runOnLoop(&L))
return PreservedAnalyses::all();
- return getLoopPassPreservedAnalyses();
+ auto PA = getLoopPassPreservedAnalyses();
+ if (AR.MSSA)
+ PA.preserve<MemorySSAAnalysis>();
+ return PA;
}
INITIALIZE_PASS_BEGIN(LoopIdiomRecognizeLegacyPass, "loop-idiom",
@@ -339,14 +348,14 @@ bool LoopIdiomRecognize::runOnCountableLoop() {
<< "] Countable Loop %" << CurLoop->getHeader()->getName()
<< "\n");
- bool MadeChange = false;
-
// The following transforms hoist stores/memsets into the loop pre-header.
- // Give up if the loop has instructions may throw.
+ // Give up if the loop has instructions that may throw.
SimpleLoopSafetyInfo SafetyInfo;
SafetyInfo.computeLoopSafetyInfo(CurLoop);
if (SafetyInfo.anyBlockMayThrow())
- return MadeChange;
+ return false;
+
+ bool MadeChange = false;
// Scan all the blocks in the loop that are not in subloops.
for (auto *BB : CurLoop->getBlocks()) {
@@ -968,11 +977,17 @@ bool LoopIdiomRecognize::processLoopStridedStore(
Value *PatternPtr = ConstantExpr::getBitCast(GV, Int8PtrTy);
NewCall = Builder.CreateCall(MSP, {BasePtr, PatternPtr, NumBytes});
}
+ NewCall->setDebugLoc(TheStore->getDebugLoc());
+
+ if (MSSAU) {
+ MemoryAccess *NewMemAcc = MSSAU->createMemoryAccessInBB(
+ NewCall, nullptr, NewCall->getParent(), MemorySSA::BeforeTerminator);
+ MSSAU->insertDef(cast<MemoryDef>(NewMemAcc), true);
+ }
LLVM_DEBUG(dbgs() << " Formed memset: " << *NewCall << "\n"
<< " from store to: " << *Ev << " at: " << *TheStore
<< "\n");
- NewCall->setDebugLoc(TheStore->getDebugLoc());
ORE.emit([&]() {
return OptimizationRemark(DEBUG_TYPE, "ProcessLoopStridedStore",
@@ -984,12 +999,40 @@ bool LoopIdiomRecognize::processLoopStridedStore(
// Okay, the memset has been formed. Zap the original store and anything that
// feeds into it.
- for (auto *I : Stores)
+ for (auto *I : Stores) {
+ if (MSSAU)
+ MSSAU->removeMemoryAccess(I, true);
deleteDeadInstruction(I);
+ }
+ if (MSSAU && VerifyMemorySSA)
+ MSSAU->getMemorySSA()->verifyMemorySSA();
++NumMemSet;
return true;
}
+class ExpandedValuesCleaner {
+ SCEVExpander &Expander;
+ TargetLibraryInfo *TLI;
+ SmallVector<Value *, 4> ExpandedValues;
+ bool Commit = false;
+
+public:
+ ExpandedValuesCleaner(SCEVExpander &Expander, TargetLibraryInfo *TLI)
+ : Expander(Expander), TLI(TLI) {}
+
+ void add(Value *V) { ExpandedValues.push_back(V); }
+
+ void commit() { Commit = true; }
+
+ ~ExpandedValuesCleaner() {
+ if (!Commit) {
+ Expander.clear();
+ for (auto *V : ExpandedValues)
+ RecursivelyDeleteTriviallyDeadInstructions(V, TLI);
+ }
+ }
+};
+
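ExpandedValuesCleaner is a commit-or-rollback scope guard: each SCEV expansion is registered with add(), and unless commit() is reached the destructor clears the expander and deletes the now-dead expansions, which is what lets the early "return false" paths below drop their manual cleanup. A generic sketch of the same idiom using only the standard library (RollbackGuard and tryTransform are illustrative names, not LLVM API):

#include <functional>
#include <utility>
#include <vector>

// Cleanup actions registered with add() run in the destructor unless commit()
// was called, so every early-exit path is rolled back automatically.
class RollbackGuard {
  std::vector<std::function<void()>> Cleanups;
  bool Committed = false;

public:
  void add(std::function<void()> Fn) { Cleanups.push_back(std::move(Fn)); }
  void commit() { Committed = true; }
  ~RollbackGuard() {
    if (Committed)
      return;
    for (auto It = Cleanups.rbegin(); It != Cleanups.rend(); ++It)
      (*It)();
  }
};

bool tryTransform(bool Profitable) {
  RollbackGuard Guard;
  // ...speculatively create something, and register how to undo it...
  Guard.add([] { /* delete the speculatively created values */ });
  if (!Profitable)
    return false; // destructor rolls the speculative work back
  Guard.commit(); // keep the new code
  return true;
}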
/// If the stored value is a strided load in the same loop with the same stride
/// this may be transformable into a memcpy. This kicks in for stuff like
/// for (i) A[i] = B[i];
@@ -1020,6 +1063,8 @@ bool LoopIdiomRecognize::processLoopStoreOfLoopLoad(StoreInst *SI,
IRBuilder<> Builder(Preheader->getTerminator());
SCEVExpander Expander(*SE, *DL, "loop-idiom");
+ ExpandedValuesCleaner EVC(Expander, TLI);
+
const SCEV *StrStart = StoreEv->getStart();
unsigned StrAS = SI->getPointerAddressSpace();
Type *IntIdxTy = Builder.getIntNTy(DL->getIndexSizeInBits(StrAS));
@@ -1036,16 +1081,13 @@ bool LoopIdiomRecognize::processLoopStoreOfLoopLoad(StoreInst *SI,
// checking everything.
Value *StoreBasePtr = Expander.expandCodeFor(
StrStart, Builder.getInt8PtrTy(StrAS), Preheader->getTerminator());
+ EVC.add(StoreBasePtr);
SmallPtrSet<Instruction *, 1> Stores;
Stores.insert(SI);
if (mayLoopAccessLocation(StoreBasePtr, ModRefInfo::ModRef, CurLoop, BECount,
- StoreSize, *AA, Stores)) {
- Expander.clear();
- // If we generated new code for the base pointer, clean up.
- RecursivelyDeleteTriviallyDeadInstructions(StoreBasePtr, TLI);
+ StoreSize, *AA, Stores))
return false;
- }
const SCEV *LdStart = LoadEv->getStart();
unsigned LdAS = LI->getPointerAddressSpace();
@@ -1058,15 +1100,11 @@ bool LoopIdiomRecognize::processLoopStoreOfLoopLoad(StoreInst *SI,
// mutated by the loop.
Value *LoadBasePtr = Expander.expandCodeFor(
LdStart, Builder.getInt8PtrTy(LdAS), Preheader->getTerminator());
+ EVC.add(LoadBasePtr);
if (mayLoopAccessLocation(LoadBasePtr, ModRefInfo::Mod, CurLoop, BECount,
- StoreSize, *AA, Stores)) {
- Expander.clear();
- // If we generated new code for the base pointer, clean up.
- RecursivelyDeleteTriviallyDeadInstructions(LoadBasePtr, TLI);
- RecursivelyDeleteTriviallyDeadInstructions(StoreBasePtr, TLI);
+ StoreSize, *AA, Stores))
return false;
- }
if (avoidLIRForMultiBlockLoop())
return false;
@@ -1078,6 +1116,7 @@ bool LoopIdiomRecognize::processLoopStoreOfLoopLoad(StoreInst *SI,
Value *NumBytes =
Expander.expandCodeFor(NumBytesS, IntIdxTy, Preheader->getTerminator());
+ EVC.add(NumBytes);
CallInst *NewCall = nullptr;
// Check whether to generate an unordered atomic memcpy:
@@ -1089,8 +1128,9 @@ bool LoopIdiomRecognize::processLoopStoreOfLoopLoad(StoreInst *SI,
else {
// We cannot allow unaligned ops for unordered load/store, so reject
// anything where the alignment isn't at least the element size.
- unsigned Align = std::min(SI->getAlignment(), LI->getAlignment());
- if (Align < StoreSize)
+ const Align StoreAlign = SI->getAlign();
+ const Align LoadAlign = LI->getAlign();
+ if (StoreAlign < StoreSize || LoadAlign < StoreSize)
return false;
// If the element.atomic memcpy is not lowered into explicit
@@ -1104,11 +1144,17 @@ bool LoopIdiomRecognize::processLoopStoreOfLoopLoad(StoreInst *SI,
// Note that unordered atomic loads/stores are *required* by the spec to
// have an alignment but non-atomic loads/stores may not.
NewCall = Builder.CreateElementUnorderedAtomicMemCpy(
- StoreBasePtr, SI->getAlignment(), LoadBasePtr, LI->getAlignment(),
- NumBytes, StoreSize);
+ StoreBasePtr, StoreAlign, LoadBasePtr, LoadAlign, NumBytes,
+ StoreSize);
}
NewCall->setDebugLoc(SI->getDebugLoc());
+ if (MSSAU) {
+ MemoryAccess *NewMemAcc = MSSAU->createMemoryAccessInBB(
+ NewCall, nullptr, NewCall->getParent(), MemorySSA::BeforeTerminator);
+ MSSAU->insertDef(cast<MemoryDef>(NewMemAcc), true);
+ }
+
LLVM_DEBUG(dbgs() << " Formed memcpy: " << *NewCall << "\n"
<< " from load ptr=" << *LoadEv << " at: " << *LI << "\n"
<< " from store ptr=" << *StoreEv << " at: " << *SI
@@ -1124,8 +1170,13 @@ bool LoopIdiomRecognize::processLoopStoreOfLoopLoad(StoreInst *SI,
// Okay, the memcpy has been formed. Zap the original store and anything that
// feeds into it.
+ if (MSSAU)
+ MSSAU->removeMemoryAccess(SI, true);
deleteDeadInstruction(SI);
+ if (MSSAU && VerifyMemorySSA)
+ MSSAU->getMemorySSA()->verifyMemorySSA();
++NumMemCpy;
+ EVC.commit();
return true;
}
@@ -1502,18 +1553,20 @@ bool LoopIdiomRecognize::recognizeAndInsertFFS() {
// %inc = add nsw %i.0, 1
// br i1 %tobool
- const Value *Args[] =
- {InitX, ZeroCheck ? ConstantInt::getTrue(InitX->getContext())
- : ConstantInt::getFalse(InitX->getContext())};
+ const Value *Args[] = {
+ InitX, ZeroCheck ? ConstantInt::getTrue(InitX->getContext())
+ : ConstantInt::getFalse(InitX->getContext())};
// @llvm.dbg doesn't count as they have no semantic effect.
auto InstWithoutDebugIt = CurLoop->getHeader()->instructionsWithoutDebug();
uint32_t HeaderSize =
std::distance(InstWithoutDebugIt.begin(), InstWithoutDebugIt.end());
+ IntrinsicCostAttributes Attrs(IntrinID, InitX->getType(), Args);
+ int Cost =
+ TTI->getIntrinsicInstrCost(Attrs, TargetTransformInfo::TCK_SizeAndLatency);
if (HeaderSize != IdiomCanonicalSize &&
- TTI->getIntrinsicCost(IntrinID, InitX->getType(), Args) >
- TargetTransformInfo::TCC_Basic)
+ Cost > TargetTransformInfo::TCC_Basic)
return false;
transformLoopToCountable(IntrinID, PH, CntInst, CntPhi, InitX, DefX,
diff --git a/llvm/lib/Transforms/Scalar/LoopInstSimplify.cpp b/llvm/lib/Transforms/Scalar/LoopInstSimplify.cpp
index 901204181a7c..3153a8721193 100644
--- a/llvm/lib/Transforms/Scalar/LoopInstSimplify.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopInstSimplify.cpp
@@ -68,7 +68,7 @@ static bool simplifyLoopInst(Loop &L, DominatorTree &DT, LoopInfo &LI,
// While simplifying we may discover dead code or cause code to become dead.
// Keep track of all such instructions and we will delete them at the end.
- SmallVector<Instruction *, 8> DeadInsts;
+ SmallVector<WeakTrackingVH, 8> DeadInsts;
// First we want to create an RPO traversal of the loop body. By processing in
// RPO we can ensure that definitions are processed prior to uses (for non PHI
diff --git a/llvm/lib/Transforms/Scalar/LoopInterchange.cpp b/llvm/lib/Transforms/Scalar/LoopInterchange.cpp
index 6ce2d06058cf..7787c0bccd4c 100644
--- a/llvm/lib/Transforms/Scalar/LoopInterchange.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopInterchange.cpp
@@ -412,7 +412,6 @@ public:
private:
bool adjustLoopLinks();
- void adjustLoopPreheaders();
bool adjustLoopBranches();
Loop *OuterLoop;
@@ -580,6 +579,12 @@ struct LoopInterchange : public LoopPass {
LIT.transform();
LLVM_DEBUG(dbgs() << "Loops interchanged.\n");
LoopsInterchanged++;
+
+ assert(InnerLoop->isLCSSAForm(*DT) &&
+ "Inner loop not left in LCSSA form after loop interchange!");
+ assert(OuterLoop->isLCSSAForm(*DT) &&
+ "Outer loop not left in LCSSA form after loop interchange!");
+
return true;
}
};
@@ -689,7 +694,7 @@ bool LoopInterchangeLegality::findInductionAndReductions(
// PHIs in inner loops need to be part of a reduction in the outer loop,
// discovered when checking the PHIs of the outer loop earlier.
if (!InnerLoop) {
- if (OuterInnerReductions.find(&PHI) == OuterInnerReductions.end()) {
+ if (!OuterInnerReductions.count(&PHI)) {
LLVM_DEBUG(dbgs() << "Inner loop PHI is not part of reductions "
"across the outer loop.\n");
return false;
@@ -903,8 +908,8 @@ areInnerLoopExitPHIsSupported(Loop *InnerL, Loop *OuterL,
return false;
if (any_of(PHI.users(), [&Reductions, OuterL](User *U) {
PHINode *PN = dyn_cast<PHINode>(U);
- return !PN || (Reductions.find(PN) == Reductions.end() &&
- OuterL->contains(PN->getParent()));
+ return !PN ||
+ (!Reductions.count(PN) && OuterL->contains(PN->getParent()));
})) {
return false;
}
@@ -1319,6 +1324,23 @@ static void moveBBContents(BasicBlock *FromBB, Instruction *InsertBefore) {
FromBB->getTerminator()->getIterator());
}
+/// Swap instructions between \p BB1 and \p BB2 but keep terminators intact.
+static void swapBBContents(BasicBlock *BB1, BasicBlock *BB2) {
+ // Save all non-terminator instructions of BB1 into TempInstrs and unlink them
+ // from BB1 afterwards.
+ auto Iter = map_range(*BB1, [](Instruction &I) { return &I; });
+ SmallVector<Instruction *, 4> TempInstrs(Iter.begin(), std::prev(Iter.end()));
+ for (Instruction *I : TempInstrs)
+ I->removeFromParent();
+
+ // Move instructions from BB2 to BB1.
+ moveBBContents(BB2, BB1->getTerminator());
+
+ // Move instructions from TempInstrs to BB2.
+ for (Instruction *I : TempInstrs)
+ I->insertBefore(BB2->getTerminator());
+}
+
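swapBBContents exchanges every non-terminator instruction between the two blocks while both terminators stay put, which is what the interchange needs once the preheaders trade places: the old inner preheader's code used to run inside the outer loop, and vice versa. A self-contained analogue over std::list, where the last element stands in for the block terminator (ToyBlock is an illustrative name; each block is assumed to contain at least its terminator):

#include <iterator>
#include <list>

using ToyBlock = std::list<int>;

void swapToyBlockContents(ToyBlock &BB1, ToyBlock &BB2) {
  ToyBlock Temp;
  // Unlink everything except the terminator from BB1.
  Temp.splice(Temp.begin(), BB1, BB1.begin(), std::prev(BB1.end()));
  // Move BB2's non-terminator instructions in front of BB1's terminator.
  BB1.splice(std::prev(BB1.end()), BB2, BB2.begin(), std::prev(BB2.end()));
  // Re-insert BB1's old instructions in front of BB2's terminator.
  BB2.splice(std::prev(BB2.end()), Temp);
}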
// Update BI to jump to NewBB instead of OldBB. Records updates to the
// dominator tree in DTUpdates. If \p MustUpdateOnce is true, assert that
// \p OldBB is exactly once in BI's successor list.
@@ -1560,13 +1582,11 @@ bool LoopInterchangeTransform::adjustLoopBranches() {
  // outer loop and all that remains to do is updating the incoming blocks.
for (PHINode *PHI : OuterLoopPHIs) {
PHI->moveBefore(InnerLoopHeader->getFirstNonPHI());
- assert(OuterInnerReductions.find(PHI) != OuterInnerReductions.end() &&
- "Expected a reduction PHI node");
+ assert(OuterInnerReductions.count(PHI) && "Expected a reduction PHI node");
}
for (PHINode *PHI : InnerLoopPHIs) {
PHI->moveBefore(OuterLoopHeader->getFirstNonPHI());
- assert(OuterInnerReductions.find(PHI) != OuterInnerReductions.end() &&
- "Expected a reduction PHI node");
+ assert(OuterInnerReductions.count(PHI) && "Expected a reduction PHI node");
}
// Update the incoming blocks for moved PHI nodes.
@@ -1578,30 +1598,17 @@ bool LoopInterchangeTransform::adjustLoopBranches() {
return true;
}
-void LoopInterchangeTransform::adjustLoopPreheaders() {
- // We have interchanged the preheaders so we need to interchange the data in
- // the preheader as well.
- // This is because the content of inner preheader was previously executed
- // inside the outer loop.
- BasicBlock *OuterLoopPreHeader = OuterLoop->getLoopPreheader();
- BasicBlock *InnerLoopPreHeader = InnerLoop->getLoopPreheader();
- BasicBlock *OuterLoopHeader = OuterLoop->getHeader();
- BranchInst *InnerTermBI =
- cast<BranchInst>(InnerLoopPreHeader->getTerminator());
-
- // These instructions should now be executed inside the loop.
- // Move instruction into a new block after outer header.
- moveBBContents(InnerLoopPreHeader, OuterLoopHeader->getTerminator());
- // These instructions were not executed previously in the loop so move them to
- // the older inner loop preheader.
- moveBBContents(OuterLoopPreHeader, InnerTermBI);
-}
-
bool LoopInterchangeTransform::adjustLoopLinks() {
// Adjust all branches in the inner and outer loop.
bool Changed = adjustLoopBranches();
- if (Changed)
- adjustLoopPreheaders();
+ if (Changed) {
+ // We have interchanged the preheaders so we need to interchange the data in
+ // the preheaders as well. This is because the content of the inner
+ // preheader was previously executed inside the outer loop.
+ BasicBlock *OuterLoopPreHeader = OuterLoop->getLoopPreheader();
+ BasicBlock *InnerLoopPreHeader = InnerLoop->getLoopPreheader();
+ swapBBContents(OuterLoopPreHeader, InnerLoopPreHeader);
+ }
return Changed;
}
diff --git a/llvm/lib/Transforms/Scalar/LoopLoadElimination.cpp b/llvm/lib/Transforms/Scalar/LoopLoadElimination.cpp
index 4e1b4e87ebc9..4412b3079461 100644
--- a/llvm/lib/Transforms/Scalar/LoopLoadElimination.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopLoadElimination.cpp
@@ -38,7 +38,6 @@
#include "llvm/Analysis/MemorySSA.h"
#include "llvm/Analysis/ProfileSummaryInfo.h"
#include "llvm/Analysis/ScalarEvolution.h"
-#include "llvm/Analysis/ScalarEvolutionExpander.h"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
@@ -58,6 +57,7 @@
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils.h"
#include "llvm/Transforms/Utils/LoopVersioning.h"
+#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
#include "llvm/Transforms/Utils/SizeOpts.h"
#include <algorithm>
#include <cassert>
@@ -377,7 +377,7 @@ public:
/// Determine the pointer alias checks to prove that there are no
/// intervening stores.
- SmallVector<RuntimePointerChecking::PointerCheck, 4> collectMemchecks(
+ SmallVector<RuntimePointerCheck, 4> collectMemchecks(
const SmallVectorImpl<StoreToLoadForwardingCandidate> &Candidates) {
SmallPtrSet<Value *, 4> PtrsWrittenOnFwdingPath =
@@ -391,10 +391,10 @@ public:
std::mem_fn(&StoreToLoadForwardingCandidate::getLoadPtr));
const auto &AllChecks = LAI.getRuntimePointerChecking()->getChecks();
- SmallVector<RuntimePointerChecking::PointerCheck, 4> Checks;
+ SmallVector<RuntimePointerCheck, 4> Checks;
copy_if(AllChecks, std::back_inserter(Checks),
- [&](const RuntimePointerChecking::PointerCheck &Check) {
+ [&](const RuntimePointerCheck &Check) {
for (auto PtrIdx1 : Check.first->Members)
for (auto PtrIdx2 : Check.second->Members)
if (needsChecking(PtrIdx1, PtrIdx2, PtrsWrittenOnFwdingPath,
@@ -432,12 +432,12 @@ public:
Value *Ptr = Cand.Load->getPointerOperand();
auto *PtrSCEV = cast<SCEVAddRecExpr>(PSE.getSCEV(Ptr));
auto *PH = L->getLoopPreheader();
+ assert(PH && "Preheader should exist!");
Value *InitialPtr = SEE.expandCodeFor(PtrSCEV->getStart(), Ptr->getType(),
PH->getTerminator());
Value *Initial = new LoadInst(
Cand.Load->getType(), InitialPtr, "load_initial",
- /* isVolatile */ false, MaybeAlign(Cand.Load->getAlignment()),
- PH->getTerminator());
+ /* isVolatile */ false, Cand.Load->getAlign(), PH->getTerminator());
PHINode *PHI = PHINode::Create(Initial->getType(), 2, "store_forwarded",
&L->getHeader()->front());
@@ -520,8 +520,7 @@ public:
// Check intervening may-alias stores. These need runtime checks for alias
// disambiguation.
- SmallVector<RuntimePointerChecking::PointerCheck, 4> Checks =
- collectMemchecks(Candidates);
+ SmallVector<RuntimePointerCheck, 4> Checks = collectMemchecks(Candidates);
// Too many checks are likely to outweigh the benefits of forwarding.
if (Checks.size() > Candidates.size() * CheckPerElim) {
@@ -535,6 +534,11 @@ public:
return false;
}
+ if (!L->isLoopSimplifyForm()) {
+    LLVM_DEBUG(dbgs() << "Loop is not in loop-simplify form");
+ return false;
+ }
+
if (!Checks.empty() || !LAI.getPSE().getUnionPredicate().isAlwaysTrue()) {
if (LAI.hasConvergentOp()) {
LLVM_DEBUG(dbgs() << "Versioning is needed but not allowed with "
@@ -554,11 +558,6 @@ public:
return false;
}
- if (!L->isLoopSimplifyForm()) {
-    LLVM_DEBUG(dbgs() << "Loop is not in loop-simplify form");
- return false;
- }
-
// Point of no-return, start the transformation. First, version the loop
// if necessary.
@@ -697,8 +696,8 @@ PreservedAnalyses LoopLoadEliminationPass::run(Function &F,
auto &TLI = AM.getResult<TargetLibraryAnalysis>(F);
auto &AA = AM.getResult<AAManager>(F);
auto &AC = AM.getResult<AssumptionAnalysis>(F);
- auto &MAM = AM.getResult<ModuleAnalysisManagerFunctionProxy>(F).getManager();
- auto *PSI = MAM.getCachedResult<ProfileSummaryAnalysis>(*F.getParent());
+ auto &MAMProxy = AM.getResult<ModuleAnalysisManagerFunctionProxy>(F);
+ auto *PSI = MAMProxy.getCachedResult<ProfileSummaryAnalysis>(*F.getParent());
auto *BFI = (PSI && PSI->hasProfileSummary()) ?
&AM.getResult<BlockFrequencyAnalysis>(F) : nullptr;
MemorySSA *MSSA = EnableMSSALoopDependency
diff --git a/llvm/lib/Transforms/Scalar/LoopPassManager.cpp b/llvm/lib/Transforms/Scalar/LoopPassManager.cpp
index f3bfbd3564ab..98889a9df116 100644
--- a/llvm/lib/Transforms/Scalar/LoopPassManager.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopPassManager.cpp
@@ -6,6 +6,7 @@
//
//===----------------------------------------------------------------------===//
+#include "llvm/Support/TimeProfiler.h"
#include "llvm/Transforms/Scalar/LoopPassManager.h"
#include "llvm/Analysis/LoopInfo.h"
@@ -33,15 +34,19 @@ PassManager<Loop, LoopAnalysisManager, LoopStandardAnalysisResults &,
// instrumenting callbacks for the passes later.
PassInstrumentation PI = AM.getResult<PassInstrumentationAnalysis>(L, AR);
for (auto &Pass : Passes) {
- if (DebugLogging)
- dbgs() << "Running pass: " << Pass->name() << " on " << L;
-
// Check the PassInstrumentation's BeforePass callbacks before running the
// pass, skip its execution completely if asked to (callback returns false).
if (!PI.runBeforePass<Loop>(*Pass, L))
continue;
- PreservedAnalyses PassPA = Pass->run(L, AM, AR, U);
+ if (DebugLogging)
+ dbgs() << "Running pass: " << Pass->name() << " on " << L;
+
+ PreservedAnalyses PassPA;
+ {
+ TimeTraceScope TimeScope(Pass->name(), L.getName());
+ PassPA = Pass->run(L, AM, AR, U);
+ }
// do not pass deleted Loop into the instrumentation
if (U.skipCurrentLoop())
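The change above wraps each Pass->run call in a TimeTraceScope so that time-trace output attributes the cost to the individual loop pass, keyed by pass name and loop name. The same RAII shape, reduced to a standalone timer with std::chrono (ScopedTimer is an illustrative stand-in, not the real TimeTraceScope, which emits Chrome trace events rather than printing):

#include <chrono>
#include <cstdio>
#include <string>
#include <utility>

// Measures how long one scope takes and reports it when the scope ends.
class ScopedTimer {
  std::string Label;
  std::chrono::steady_clock::time_point Start;

public:
  explicit ScopedTimer(std::string L)
      : Label(std::move(L)), Start(std::chrono::steady_clock::now()) {}
  ~ScopedTimer() {
    auto End = std::chrono::steady_clock::now();
    auto Us =
        std::chrono::duration_cast<std::chrono::microseconds>(End - Start);
    std::printf("%s took %lld us\n", Label.c_str(),
                static_cast<long long>(Us.count()));
  }
};

// Usage mirroring the hunk above: the timer covers exactly the run call.
// {
//   ScopedTimer T(PassName + " on " + LoopName);
//   PassPA = Pass->run(L, AM, AR, U);
// }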
diff --git a/llvm/lib/Transforms/Scalar/LoopPredication.cpp b/llvm/lib/Transforms/Scalar/LoopPredication.cpp
index 1a42f6b23443..edde22d6708f 100644
--- a/llvm/lib/Transforms/Scalar/LoopPredication.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopPredication.cpp
@@ -184,7 +184,6 @@
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/LoopPass.h"
#include "llvm/Analysis/ScalarEvolution.h"
-#include "llvm/Analysis/ScalarEvolutionExpander.h"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/GlobalValue.h"
@@ -199,6 +198,7 @@
#include "llvm/Transforms/Utils/GuardUtils.h"
#include "llvm/Transforms/Utils/Local.h"
#include "llvm/Transforms/Utils/LoopUtils.h"
+#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
#define DEBUG_TYPE "loop-predication"
@@ -268,7 +268,7 @@ class LoopPredication {
/// Return an insertion point suitable for inserting a safe to speculate
/// instruction whose only user will be 'User' which has operands 'Ops'. A
  /// trivial result would be at the User itself, but we try to return a
- /// loop invariant location if possible.
+ /// loop invariant location if possible.
Instruction *findInsertPt(Instruction *User, ArrayRef<Value*> Ops);
/// Same as above, *except* that this uses the SCEV definition of invariant
/// which is that an expression *can be made* invariant via SCEVExpander.
@@ -278,7 +278,7 @@ class LoopPredication {
/// Return true if the value is known to produce a single fixed value across
/// all iterations on which it executes. Note that this does not imply
- /// speculation safety. That must be established seperately.
+ /// speculation safety. That must be established separately.
bool isLoopInvariantValue(const SCEV* S);
Value *expandCheck(SCEVExpander &Expander, Instruction *Guard,
@@ -342,7 +342,7 @@ public:
};
char LoopPredicationLegacyPass::ID = 0;
-} // end namespace llvm
+} // end namespace
INITIALIZE_PASS_BEGIN(LoopPredicationLegacyPass, "loop-predication",
"Loop predication", false, false)
@@ -358,11 +358,12 @@ Pass *llvm::createLoopPredicationPass() {
PreservedAnalyses LoopPredicationPass::run(Loop &L, LoopAnalysisManager &AM,
LoopStandardAnalysisResults &AR,
LPMUpdater &U) {
- const auto &FAM =
- AM.getResult<FunctionAnalysisManagerLoopProxy>(L, AR).getManager();
Function *F = L.getHeader()->getParent();
- auto *BPI = FAM.getCachedResult<BranchProbabilityAnalysis>(*F);
- LoopPredication LP(&AR.AA, &AR.DT, &AR.SE, &AR.LI, BPI);
+ // For the new PM, we also can't use BranchProbabilityInfo as an analysis
+ // pass. Function analyses need to be preserved across loop transformations
+ // but BPI is not preserved, hence a newly built one is needed.
+ BranchProbabilityInfo BPI(*F, AR.LI, &AR.TLI);
+ LoopPredication LP(&AR.AA, &AR.DT, &AR.SE, &AR.LI, &BPI);
if (!LP.runOnLoop(&L))
return PreservedAnalyses::all();
@@ -397,7 +398,7 @@ LoopPredication::parseLoopICmp(ICmpInst *ICI) {
}
Value *LoopPredication::expandCheck(SCEVExpander &Expander,
- Instruction *Guard,
+ Instruction *Guard,
ICmpInst::Predicate Pred, const SCEV *LHS,
const SCEV *RHS) {
Type *Ty = LHS->getType();
@@ -521,7 +522,7 @@ Instruction *LoopPredication::findInsertPt(Instruction *Use,
return Preheader->getTerminator();
}
-bool LoopPredication::isLoopInvariantValue(const SCEV* S) {
+bool LoopPredication::isLoopInvariantValue(const SCEV* S) {
// Handling expressions which produce invariant results, but *haven't* yet
// been removed from the loop serves two important purposes.
// 1) Most importantly, it resolves a pass ordering cycle which would
@@ -534,12 +535,12 @@ bool LoopPredication::isLoopInvariantValue(const SCEV* S) {
// much more obviously in the IR. Otherwise, the cost modeling for other
// transforms would end up needing to duplicate all of this logic to model a
// check which becomes predictable based on a modeled peel or unswitch.
- //
+ //
// The cost of doing so in the worst case is an extra fill from the stack in
// the loop to materialize the loop invariant test value instead of checking
  // against the original IV which is presumably in a register inside the loop.
  // Such cases are presumably rare, and hint at missing opportunities for
- // other passes.
+ // other passes.
if (SE->isLoopInvariant(S, L))
// Note: This the SCEV variant, so the original Value* may be within the
@@ -547,7 +548,7 @@ bool LoopPredication::isLoopInvariantValue(const SCEV* S) {
return true;
// Handle a particular important case which SCEV doesn't yet know about which
- // shows up in range checks on arrays with immutable lengths.
+ // shows up in range checks on arrays with immutable lengths.
// TODO: This should be sunk inside SCEV.
if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S))
if (const auto *LI = dyn_cast<LoadInst>(U->getValue()))
@@ -574,7 +575,7 @@ Optional<Value *> LoopPredication::widenICmpRangeCheckIncrementingLoop(
const SCEV *LatchLimit = LatchCheck.Limit;
// Subtlety: We need all the values to be *invariant* across all iterations,
// but we only need to check expansion safety for those which *aren't*
- // already guaranteed to dominate the guard.
+ // already guaranteed to dominate the guard.
if (!isLoopInvariantValue(GuardStart) ||
!isLoopInvariantValue(GuardLimit) ||
!isLoopInvariantValue(LatchStart) ||
@@ -598,7 +599,7 @@ Optional<Value *> LoopPredication::widenICmpRangeCheckIncrementingLoop(
LLVM_DEBUG(dbgs() << "LHS: " << *LatchLimit << "\n");
LLVM_DEBUG(dbgs() << "RHS: " << *RHS << "\n");
LLVM_DEBUG(dbgs() << "Pred: " << LimitCheckPred << "\n");
-
+
auto *LimitCheck =
expandCheck(Expander, Guard, LimitCheckPred, LatchLimit, RHS);
auto *FirstIterationCheck = expandCheck(Expander, Guard, RangeCheck.Pred,
@@ -617,7 +618,7 @@ Optional<Value *> LoopPredication::widenICmpRangeCheckDecrementingLoop(
const SCEV *LatchLimit = LatchCheck.Limit;
// Subtlety: We need all the values to be *invariant* across all iterations,
// but we only need to check expansion safety for those which *aren't*
- // already guaranteed to dominate the guard.
+ // already guaranteed to dominate the guard.
if (!isLoopInvariantValue(GuardStart) ||
!isLoopInvariantValue(GuardLimit) ||
!isLoopInvariantValue(LatchStart) ||
@@ -658,7 +659,7 @@ Optional<Value *> LoopPredication::widenICmpRangeCheckDecrementingLoop(
static void normalizePredicate(ScalarEvolution *SE, Loop *L,
LoopICmp& RC) {
// LFTR canonicalizes checks to the ICMP_NE/EQ form; normalize back to the
- // ULT/UGE form for ease of handling by our caller.
+ // ULT/UGE form for ease of handling by our caller.
if (ICmpInst::isEquality(RC.Pred) &&
RC.IV->getStepRecurrence(*SE)->isOne() &&
SE->isKnownPredicate(ICmpInst::ICMP_ULE, RC.IV->getStart(), RC.Limit))
@@ -1020,17 +1021,6 @@ static const SCEV *getMinAnalyzeableBackedgeTakenCount(ScalarEvolution &SE,
return SE.getUMinFromMismatchedTypes(ExitCounts);
}
-/// Return true if we can be fairly sure that executing block BB will probably
-/// lead to executing an __llvm_deoptimize. This is a profitability heuristic,
-/// not a legality constraint.
-static bool isVeryLikelyToDeopt(BasicBlock *BB) {
- while (BB->getUniqueSuccessor())
- // Will skip side effects, that's okay
- BB = BB->getUniqueSuccessor();
-
- return BB->getTerminatingDeoptimizeCall();
-}
-
/// This implements an analogous, but entirely distinct transform from the main
/// loop predication transform. This one is phrased in terms of using a
/// widenable branch *outside* the loop to allow us to simplify loop exits in a
@@ -1054,7 +1044,7 @@ bool LoopPredication::predicateLoopExits(Loop *L, SCEVExpander &Rewriter) {
// inserting a branch on the value which can be either poison or undef. In
// this case, the branch can legally go either way; we just need to avoid
// introducing UB. This is achieved through the use of the freeze
- // instruction.
+ // instruction.
SmallVector<BasicBlock *, 16> ExitingBlocks;
L->getExitingBlocks(ExitingBlocks);
@@ -1082,7 +1072,7 @@ bool LoopPredication::predicateLoopExits(Loop *L, SCEVExpander &Rewriter) {
// analyzeable after dropping widenability.
{
bool Invalidate = false;
-
+
for (auto *ExitingBB : ExitingBlocks) {
if (LI->getLoopFor(ExitingBB) != L)
continue;
@@ -1150,10 +1140,13 @@ bool LoopPredication::predicateLoopExits(Loop *L, SCEVExpander &Rewriter) {
const bool ExitIfTrue = !L->contains(*succ_begin(ExitingBB));
BasicBlock *ExitBB = BI->getSuccessor(ExitIfTrue ? 0 : 1);
- if (!isVeryLikelyToDeopt(ExitBB))
- // Profitability: indicator of rarely/never taken exit
+ if (!ExitBB->getPostdominatingDeoptimizeCall())
continue;
+ /// Here we can be fairly sure that executing this exit will most likely
+ /// lead to executing llvm.experimental.deoptimize.
+ /// This is a profitability heuristic, not a legality constraint.
+
// If we found a widenable exit condition, do two things:
// 1) fold the widened exit test into the widenable condition
// 2) fold the branch to untaken - avoids infinite looping
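The removed isVeryLikelyToDeopt simply walked the chain of unique successors and asked whether it ends in a deoptimizing call; the patch replaces it with the BasicBlock utility getPostdominatingDeoptimizeCall, which packages that kind of query. A toy version of the removed walk over a minimal CFG node (Node is an illustrative type; the sketch assumes the unique-successor chain is acyclic, as the original code did):

// Toy CFG node: at most one unique successor, plus a flag recording whether
// the block ends in a deoptimizing call.
struct Node {
  Node *UniqueSucc = nullptr;
  bool EndsInDeopt = false;
};

// Profitability heuristic, not a legality check: if every path from BB funnels
// into a deoptimizing block, taking this exit will very likely deoptimize.
bool veryLikelyToDeopt(const Node *BB) {
  while (BB->UniqueSucc)
    BB = BB->UniqueSucc; // skips side effects along the way, as the old code noted
  return BB->EndsInDeopt;
}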
diff --git a/llvm/lib/Transforms/Scalar/LoopRerollPass.cpp b/llvm/lib/Transforms/Scalar/LoopRerollPass.cpp
index da13a342ae12..3542d0a4ee73 100644
--- a/llvm/lib/Transforms/Scalar/LoopRerollPass.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopRerollPass.cpp
@@ -24,7 +24,6 @@
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/LoopPass.h"
#include "llvm/Analysis/ScalarEvolution.h"
-#include "llvm/Analysis/ScalarEvolutionExpander.h"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/ValueTracking.h"
@@ -55,6 +54,7 @@
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/Local.h"
#include "llvm/Transforms/Utils/LoopUtils.h"
+#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
#include <cassert>
#include <cstddef>
#include <cstdint>
@@ -880,6 +880,12 @@ bool LoopReroll::DAGRootTracker::validateRootSet(DAGRootSet &DRS) {
if (DRS.Roots.empty())
return false;
+ // If the value of the base instruction is used outside the loop, we cannot
+  // reroll the loop. Checking the other root instructions is unnecessary because
+ // they don't match any base instructions if their values are used outside.
+ if (hasUsesOutsideLoop(DRS.BaseInst, L))
+ return false;
+
// Consider a DAGRootSet with N-1 roots (so N different values including
// BaseInst).
// Define d = Roots[0] - BaseInst, which should be the same as
@@ -1126,7 +1132,7 @@ static bool isIgnorableInst(const Instruction *I) {
case Intrinsic::annotation:
case Intrinsic::ptr_annotation:
case Intrinsic::var_annotation:
- // TODO: the following intrinsics may also be whitelisted:
+ // TODO: the following intrinsics may also be allowed:
// lifetime_start, lifetime_end, invariant_start, invariant_end
return true;
}
diff --git a/llvm/lib/Transforms/Scalar/LoopRotation.cpp b/llvm/lib/Transforms/Scalar/LoopRotation.cpp
index 0868e742f4ee..f92566ba77ce 100644
--- a/llvm/lib/Transforms/Scalar/LoopRotation.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopRotation.cpp
@@ -81,10 +81,8 @@ public:
void getAnalysisUsage(AnalysisUsage &AU) const override {
AU.addRequired<AssumptionCacheTracker>();
AU.addRequired<TargetTransformInfoWrapperPass>();
- if (EnableMSSALoopDependency) {
- AU.addRequired<MemorySSAWrapperPass>();
+ if (EnableMSSALoopDependency)
AU.addPreserved<MemorySSAWrapperPass>();
- }
getLoopAnalysisUsage(AU);
}
@@ -101,15 +99,18 @@ public:
const SimplifyQuery SQ = getBestSimplifyQuery(*this, F);
Optional<MemorySSAUpdater> MSSAU;
if (EnableMSSALoopDependency) {
- MemorySSA *MSSA = &getAnalysis<MemorySSAWrapperPass>().getMSSA();
- MSSAU = MemorySSAUpdater(MSSA);
+ // Not requiring MemorySSA and getting it only if available will split
+ // the loop pass pipeline when LoopRotate is being run first.
+ auto *MSSAA = getAnalysisIfAvailable<MemorySSAWrapperPass>();
+ if (MSSAA)
+ MSSAU = MemorySSAUpdater(&MSSAA->getMSSA());
}
return LoopRotation(L, LI, TTI, AC, &DT, &SE,
MSSAU.hasValue() ? MSSAU.getPointer() : nullptr, SQ,
false, MaxHeaderSize, false);
}
};
-}
+} // end namespace
char LoopRotateLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(LoopRotateLegacyPass, "loop-rotate", "Rotate Loops",
diff --git a/llvm/lib/Transforms/Scalar/LoopSimplifyCFG.cpp b/llvm/lib/Transforms/Scalar/LoopSimplifyCFG.cpp
index b27e65e0adb7..031e5b9c1d2c 100644
--- a/llvm/lib/Transforms/Scalar/LoopSimplifyCFG.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopSimplifyCFG.cpp
@@ -23,6 +23,7 @@
#include "llvm/Analysis/DomTreeUpdater.h"
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/Analysis/LoopIterator.h"
#include "llvm/Analysis/LoopPass.h"
#include "llvm/Analysis/MemorySSA.h"
#include "llvm/Analysis/MemorySSAUpdater.h"
@@ -30,6 +31,7 @@
#include "llvm/Analysis/ScalarEvolutionAliasAnalysis.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/Dominators.h"
+#include "llvm/IR/IRBuilder.h"
#include "llvm/InitializePasses.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Transforms/Scalar.h"
@@ -673,13 +675,13 @@ static bool mergeBlocksIntoPredecessors(Loop &L, DominatorTree &DT,
static bool simplifyLoopCFG(Loop &L, DominatorTree &DT, LoopInfo &LI,
ScalarEvolution &SE, MemorySSAUpdater *MSSAU,
- bool &isLoopDeleted) {
+ bool &IsLoopDeleted) {
bool Changed = false;
// Constant-fold terminators with known constant conditions.
- Changed |= constantFoldTerminators(L, DT, LI, SE, MSSAU, isLoopDeleted);
+ Changed |= constantFoldTerminators(L, DT, LI, SE, MSSAU, IsLoopDeleted);
- if (isLoopDeleted)
+ if (IsLoopDeleted)
return true;
// Eliminate unconditional branches by merging blocks into their predecessors.
@@ -752,7 +754,7 @@ public:
getLoopAnalysisUsage(AU);
}
};
-}
+} // end namespace
char LoopSimplifyCFGLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(LoopSimplifyCFGLegacyPass, "loop-simplifycfg",
diff --git a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
index e9f368628a08..cf02ef1e83f3 100644
--- a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
@@ -65,12 +65,14 @@
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/iterator_range.h"
+#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/IVUsers.h"
#include "llvm/Analysis/LoopAnalysisManager.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/LoopPass.h"
+#include "llvm/Analysis/MemorySSA.h"
+#include "llvm/Analysis/MemorySSAUpdater.h"
#include "llvm/Analysis/ScalarEvolution.h"
-#include "llvm/Analysis/ScalarEvolutionExpander.h"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
#include "llvm/Analysis/ScalarEvolutionNormalization.h"
#include "llvm/Analysis/TargetTransformInfo.h"
@@ -109,6 +111,7 @@
#include "llvm/Transforms/Utils.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/Local.h"
+#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
#include <algorithm>
#include <cassert>
#include <cstddef>
@@ -807,9 +810,14 @@ static bool isAddressUse(const TargetTransformInfo &TTI,
switch (II->getIntrinsicID()) {
case Intrinsic::memset:
case Intrinsic::prefetch:
+ case Intrinsic::masked_load:
if (II->getArgOperand(0) == OperandVal)
isAddress = true;
break;
+ case Intrinsic::masked_store:
+ if (II->getArgOperand(1) == OperandVal)
+ isAddress = true;
+ break;
case Intrinsic::memmove:
case Intrinsic::memcpy:
if (II->getArgOperand(0) == OperandVal ||
@@ -859,6 +867,15 @@ static MemAccessTy getAccessType(const TargetTransformInfo &TTI,
AccessTy.AddrSpace = OperandVal->getType()->getPointerAddressSpace();
AccessTy.MemTy = OperandVal->getType();
break;
+ case Intrinsic::masked_load:
+ AccessTy.AddrSpace =
+ II->getArgOperand(0)->getType()->getPointerAddressSpace();
+ break;
+ case Intrinsic::masked_store:
+ AccessTy.MemTy = II->getOperand(0)->getType();
+ AccessTy.AddrSpace =
+ II->getArgOperand(1)->getType()->getPointerAddressSpace();
+ break;
default: {
MemIntrinsicInfo IntrInfo;
if (TTI.getTgtMemIntrinsic(II, IntrInfo) && IntrInfo.PtrVal) {
@@ -962,33 +979,6 @@ static bool isHighCostExpansion(const SCEV *S,
return true;
}
-/// If any of the instructions in the specified set are trivially dead, delete
-/// them and see if this makes any of their operands subsequently dead.
-static bool
-DeleteTriviallyDeadInstructions(SmallVectorImpl<WeakTrackingVH> &DeadInsts) {
- bool Changed = false;
-
- while (!DeadInsts.empty()) {
- Value *V = DeadInsts.pop_back_val();
- Instruction *I = dyn_cast_or_null<Instruction>(V);
-
- if (!I || !isInstructionTriviallyDead(I))
- continue;
-
- for (Use &O : I->operands())
- if (Instruction *U = dyn_cast<Instruction>(O)) {
- O = nullptr;
- if (U->use_empty())
- DeadInsts.emplace_back(U);
- }
-
- I->eraseFromParent();
- Changed = true;
- }
-
- return Changed;
-}
-
namespace {
class LSRUse;
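The hand-rolled worklist above is dropped in favour of the shared RecursivelyDeleteTriviallyDeadInstructionsPermissive helper, which, as the later hunks show, also threads through TLI and the MemorySSAUpdater. For reference, a standalone sketch of the worklist idea the removed function implemented (Inst is a toy stand-in for llvm::Instruction, not the real class):

#include <vector>

// Toy SSA instruction: trivially dead when it has no users and no side
// effects. Deleting it may make its operands dead in turn, so keep a worklist.
struct Inst {
  std::vector<Inst *> Operands;
  unsigned NumUses = 0;
  bool HasSideEffects = false;
  bool Deleted = false;
};

bool deleteTriviallyDead(std::vector<Inst *> Worklist) {
  bool Changed = false;
  while (!Worklist.empty()) {
    Inst *I = Worklist.back();
    Worklist.pop_back();
    if (I->Deleted || I->NumUses != 0 || I->HasSideEffects)
      continue;
    for (Inst *Op : I->Operands) {
      // Dropping this use may leave the operand dead as well.
      if (--Op->NumUses == 0)
        Worklist.push_back(Op);
    }
    I->Operands.clear();
    I->Deleted = true; // stand-in for eraseFromParent()
    Changed = true;
  }
  return Changed;
}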
@@ -1242,7 +1232,7 @@ void Cost::RateRegister(const Formula &F, const SCEV *Reg,
// for now LSR only handles innermost loops).
if (AR->getLoop() != L) {
     // If the AddRec exists, consider its register free and leave it alone.
- if (isExistingPhi(AR, *SE))
+ if (isExistingPhi(AR, *SE) && !TTI->shouldFavorPostInc())
return;
// It is bad to allow LSR for current loop to add induction variables
@@ -1913,9 +1903,10 @@ class LSRInstance {
DominatorTree &DT;
LoopInfo &LI;
AssumptionCache &AC;
- TargetLibraryInfo &LibInfo;
+ TargetLibraryInfo &TLI;
const TargetTransformInfo &TTI;
Loop *const L;
+ MemorySSAUpdater *MSSAU;
bool FavorBackedgeIndex = false;
bool Changed = false;
@@ -2018,6 +2009,7 @@ class LSRInstance {
void NarrowSearchSpaceByCollapsingUnrolledCode();
void NarrowSearchSpaceByRefilteringUndesirableDedicatedRegisters();
void NarrowSearchSpaceByFilterFormulaWithSameScaledReg();
+ void NarrowSearchSpaceByFilterPostInc();
void NarrowSearchSpaceByDeletingCostlyFormulas();
void NarrowSearchSpaceByPickingWinnerRegs();
void NarrowSearchSpaceUsingHeuristics();
@@ -2053,7 +2045,7 @@ class LSRInstance {
public:
LSRInstance(Loop *L, IVUsers &IU, ScalarEvolution &SE, DominatorTree &DT,
LoopInfo &LI, const TargetTransformInfo &TTI, AssumptionCache &AC,
- TargetLibraryInfo &LibInfo);
+ TargetLibraryInfo &TLI, MemorySSAUpdater *MSSAU);
bool getChanged() const { return Changed; }
@@ -2830,9 +2822,10 @@ bool IVChain::isProfitableIncrement(const SCEV *OperExpr,
/// increments can be computed in fewer registers when chained.
///
/// TODO: Consider IVInc free if it's already used in another chains.
-static bool
-isProfitableChain(IVChain &Chain, SmallPtrSetImpl<Instruction*> &Users,
- ScalarEvolution &SE) {
+static bool isProfitableChain(IVChain &Chain,
+ SmallPtrSetImpl<Instruction *> &Users,
+ ScalarEvolution &SE,
+ const TargetTransformInfo &TTI) {
if (StressIVChain)
return true;
@@ -2861,7 +2854,14 @@ isProfitableChain(IVChain &Chain, SmallPtrSetImpl<Instruction*> &Users,
unsigned NumConstIncrements = 0;
unsigned NumVarIncrements = 0;
unsigned NumReusedIncrements = 0;
+
+ if (TTI.isProfitableLSRChainElement(Chain.Incs[0].UserInst))
+ return true;
+
for (const IVInc &Inc : Chain) {
+ if (TTI.isProfitableLSRChainElement(Inc.UserInst))
+ return true;
+
if (Inc.IncExpr->isZero())
continue;
@@ -3092,7 +3092,7 @@ void LSRInstance::CollectChains() {
for (unsigned UsersIdx = 0, NChains = IVChainVec.size();
UsersIdx < NChains; ++UsersIdx) {
if (!isProfitableChain(IVChainVec[UsersIdx],
- ChainUsersVec[UsersIdx].FarUsers, SE))
+ ChainUsersVec[UsersIdx].FarUsers, SE, TTI))
continue;
// Preserve the chain at UsesIdx.
if (ChainIdx != UsersIdx)
@@ -3212,7 +3212,8 @@ void LSRInstance::GenerateIVChain(const IVChain &Chain, SCEVExpander &Rewriter,
IVOper = Builder.CreateTruncOrBitCast(IVOper, OperTy, "lsr.chain");
}
Inc.UserInst->replaceUsesOfWith(Inc.IVOperand, IVOper);
- DeadInsts.emplace_back(Inc.IVOperand);
+ if (auto *OperandIsInstr = dyn_cast<Instruction>(Inc.IVOperand))
+ DeadInsts.emplace_back(OperandIsInstr);
}
// If LSR created a new, wider phi, we may also replace its postinc. We only
// do this if we also found a wide value for the head of the chain.
@@ -3240,7 +3241,7 @@ void LSRInstance::GenerateIVChain(const IVChain &Chain, SCEVExpander &Rewriter,
void LSRInstance::CollectFixupsAndInitialFormulae() {
BranchInst *ExitBranch = nullptr;
- bool SaveCmp = TTI.canSaveCmp(L, &ExitBranch, &SE, &LI, &DT, &AC, &LibInfo);
+ bool SaveCmp = TTI.canSaveCmp(L, &ExitBranch, &SE, &LI, &DT, &AC, &TLI);
for (const IVStrideUse &U : IU) {
Instruction *UserInst = U.getUser();
@@ -3553,9 +3554,6 @@ static bool mayUsePostIncMode(const TargetTransformInfo &TTI,
const SCEV *LoopStep = AR->getStepRecurrence(SE);
if (!isa<SCEVConstant>(LoopStep))
return false;
- if (LU.AccessTy.getType()->getScalarSizeInBits() !=
- LoopStep->getType()->getScalarSizeInBits())
- return false;
// Check if a post-indexed load/store can be used.
if (TTI.isIndexedLoadLegal(TTI.MIM_PostInc, AR->getType()) ||
TTI.isIndexedStoreLegal(TTI.MIM_PostInc, AR->getType())) {
@@ -4673,6 +4671,54 @@ void LSRInstance::NarrowSearchSpaceByFilterFormulaWithSameScaledReg() {
});
}
+/// If we are over the complexity limit, filter out any post-inc preferring
+/// variables to only post-inc values.
+void LSRInstance::NarrowSearchSpaceByFilterPostInc() {
+ if (!TTI.shouldFavorPostInc())
+ return;
+ if (EstimateSearchSpaceComplexity() < ComplexityLimit)
+ return;
+
+ LLVM_DEBUG(dbgs() << "The search space is too complex.\n"
+ "Narrowing the search space by choosing the lowest "
+ "register Formula for PostInc Uses.\n");
+
+ for (size_t LUIdx = 0, NumUses = Uses.size(); LUIdx != NumUses; ++LUIdx) {
+ LSRUse &LU = Uses[LUIdx];
+
+ if (LU.Kind != LSRUse::Address)
+ continue;
+ if (!TTI.isIndexedLoadLegal(TTI.MIM_PostInc, LU.AccessTy.getType()) &&
+ !TTI.isIndexedStoreLegal(TTI.MIM_PostInc, LU.AccessTy.getType()))
+ continue;
+
+ size_t MinRegs = std::numeric_limits<size_t>::max();
+ for (const Formula &F : LU.Formulae)
+ MinRegs = std::min(F.getNumRegs(), MinRegs);
+
+ bool Any = false;
+ for (size_t FIdx = 0, NumForms = LU.Formulae.size(); FIdx != NumForms;
+ ++FIdx) {
+ Formula &F = LU.Formulae[FIdx];
+ if (F.getNumRegs() > MinRegs) {
+ LLVM_DEBUG(dbgs() << " Filtering out formula "; F.print(dbgs());
+ dbgs() << "\n");
+ LU.DeleteFormula(F);
+ --FIdx;
+ --NumForms;
+ Any = true;
+ }
+ }
+ if (Any)
+ LU.RecomputeRegs(LUIdx, RegUses);
+
+ if (EstimateSearchSpaceComplexity() < ComplexityLimit)
+ break;
+ }
+
+ LLVM_DEBUG(dbgs() << "After pre-selection:\n"; print_uses(dbgs()));
+}
+
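NarrowSearchSpaceByFilterPostInc prunes the search space when post-increment addressing is preferred: for each address use that can use indexed loads or stores it finds the smallest register count among the remaining formulae and deletes every formula that needs more. The same pruning over a toy formula type (ToyFormula is illustrative; the real code also recomputes register uses and stops once the complexity estimate drops below the limit):

#include <algorithm>
#include <cstddef>
#include <vector>

// All we track here is how many registers a formula needs.
struct ToyFormula {
  std::size_t NumRegs;
};

void keepOnlyMinRegisterFormulae(std::vector<ToyFormula> &Formulae) {
  if (Formulae.empty())
    return;
  std::size_t MinRegs = Formulae.front().NumRegs;
  for (const ToyFormula &F : Formulae)
    MinRegs = std::min(MinRegs, F.NumRegs);
  // Erase every formula that ties worse than the minimum.
  Formulae.erase(std::remove_if(Formulae.begin(), Formulae.end(),
                                [&](const ToyFormula &F) {
                                  return F.NumRegs > MinRegs;
                                }),
                 Formulae.end());
}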
 /// The function deletes formulas with a high expected register count.
/// Assuming we don't know the value of each formula (already delete
/// all inefficient), generate probability of not selecting for each
@@ -4883,6 +4929,7 @@ void LSRInstance::NarrowSearchSpaceUsingHeuristics() {
NarrowSearchSpaceByRefilteringUndesirableDedicatedRegisters();
if (FilterSameScaledReg)
NarrowSearchSpaceByFilterFormulaWithSameScaledReg();
+ NarrowSearchSpaceByFilterPostInc();
if (LSRExpNarrow)
NarrowSearchSpaceByDeletingCostlyFormulas();
else
@@ -4923,19 +4970,24 @@ void LSRInstance::SolveRecurse(SmallVectorImpl<const Formula *> &Solution,
// Ignore formulae which may not be ideal in terms of register reuse of
// ReqRegs. The formula should use all required registers before
// introducing new ones.
- int NumReqRegsToFind = std::min(F.getNumRegs(), ReqRegs.size());
- for (const SCEV *Reg : ReqRegs) {
- if ((F.ScaledReg && F.ScaledReg == Reg) ||
- is_contained(F.BaseRegs, Reg)) {
- --NumReqRegsToFind;
- if (NumReqRegsToFind == 0)
- break;
+ // This can sometimes (notably when trying to favour postinc) lead to
+      // sub-optimal decisions. There it is best left to the cost modelling to
+ // get correct.
+ if (!TTI.shouldFavorPostInc() || LU.Kind != LSRUse::Address) {
+ int NumReqRegsToFind = std::min(F.getNumRegs(), ReqRegs.size());
+ for (const SCEV *Reg : ReqRegs) {
+ if ((F.ScaledReg && F.ScaledReg == Reg) ||
+ is_contained(F.BaseRegs, Reg)) {
+ --NumReqRegsToFind;
+ if (NumReqRegsToFind == 0)
+ break;
+ }
+ }
+ if (NumReqRegsToFind != 0) {
+ // If none of the formulae satisfied the required registers, then we could
+ // clear ReqRegs and try again. Currently, we simply give up in this case.
+ continue;
}
- }
- if (NumReqRegsToFind != 0) {
- // If none of the formulae satisfied the required registers, then we could
- // clear ReqRegs and try again. Currently, we simply give up in this case.
- continue;
}
// Evaluate the cost of the current formula. If it's already worse than
@@ -5268,7 +5320,8 @@ Value *LSRInstance::Expand(const LSRUse &LU, const LSRFixup &LF,
// form, update the ICmp's other operand.
if (LU.Kind == LSRUse::ICmpZero) {
ICmpInst *CI = cast<ICmpInst>(LF.UserInst);
- DeadInsts.emplace_back(CI->getOperand(1));
+ if (auto *OperandIsInstr = dyn_cast<Instruction>(CI->getOperand(1)))
+ DeadInsts.emplace_back(OperandIsInstr);
assert(!F.BaseGV && "ICmp does not support folding a global value and "
"a scale at the same time!");
if (F.Scale == -1) {
@@ -5449,7 +5502,8 @@ void LSRInstance::Rewrite(const LSRUse &LU, const LSRFixup &LF,
LF.UserInst->replaceUsesOfWith(LF.OperandValToReplace, FullV);
}
- DeadInsts.emplace_back(LF.OperandValToReplace);
+ if (auto *OperandIsInstr = dyn_cast<Instruction>(LF.OperandValToReplace))
+ DeadInsts.emplace_back(OperandIsInstr);
}
/// Rewrite all the fixup locations with new values, following the chosen
@@ -5490,16 +5544,17 @@ void LSRInstance::ImplementSolution(
// instructions.
Rewriter.clear();
- Changed |= DeleteTriviallyDeadInstructions(DeadInsts);
+ Changed |= RecursivelyDeleteTriviallyDeadInstructionsPermissive(DeadInsts,
+ &TLI, MSSAU);
}
LSRInstance::LSRInstance(Loop *L, IVUsers &IU, ScalarEvolution &SE,
DominatorTree &DT, LoopInfo &LI,
const TargetTransformInfo &TTI, AssumptionCache &AC,
- TargetLibraryInfo &LibInfo)
- : IU(IU), SE(SE), DT(DT), LI(LI), AC(AC), LibInfo(LibInfo), TTI(TTI), L(L),
- FavorBackedgeIndex(EnableBackedgeIndexing &&
- TTI.shouldFavorBackedgeIndex(L)) {
+ TargetLibraryInfo &TLI, MemorySSAUpdater *MSSAU)
+ : IU(IU), SE(SE), DT(DT), LI(LI), AC(AC), TLI(TLI), TTI(TTI), L(L),
+ MSSAU(MSSAU), FavorBackedgeIndex(EnableBackedgeIndexing &&
+ TTI.shouldFavorBackedgeIndex(L)) {
// If LoopSimplify form is not available, stay out of trouble.
if (!L->isLoopSimplifyForm())
return;
@@ -5702,21 +5757,26 @@ void LoopStrengthReduce::getAnalysisUsage(AnalysisUsage &AU) const {
AU.addRequired<IVUsersWrapperPass>();
AU.addPreserved<IVUsersWrapperPass>();
AU.addRequired<TargetTransformInfoWrapperPass>();
+ AU.addPreserved<MemorySSAWrapperPass>();
}
static bool ReduceLoopStrength(Loop *L, IVUsers &IU, ScalarEvolution &SE,
DominatorTree &DT, LoopInfo &LI,
const TargetTransformInfo &TTI,
- AssumptionCache &AC,
- TargetLibraryInfo &LibInfo) {
+ AssumptionCache &AC, TargetLibraryInfo &TLI,
+ MemorySSA *MSSA) {
bool Changed = false;
+ std::unique_ptr<MemorySSAUpdater> MSSAU;
+ if (MSSA)
+ MSSAU = std::make_unique<MemorySSAUpdater>(MSSA);
// Run the main LSR transformation.
- Changed |= LSRInstance(L, IU, SE, DT, LI, TTI, AC, LibInfo).getChanged();
+ Changed |=
+ LSRInstance(L, IU, SE, DT, LI, TTI, AC, TLI, MSSAU.get()).getChanged();
// Remove any extra phis created by processing inner loops.
- Changed |= DeleteDeadPHIs(L->getHeader());
+ Changed |= DeleteDeadPHIs(L->getHeader(), &TLI, MSSAU.get());
if (EnablePhiElim && L->isLoopSimplifyForm()) {
SmallVector<WeakTrackingVH, 16> DeadInsts;
const DataLayout &DL = L->getHeader()->getModule()->getDataLayout();
@@ -5727,8 +5787,9 @@ static bool ReduceLoopStrength(Loop *L, IVUsers &IU, ScalarEvolution &SE,
unsigned numFolded = Rewriter.replaceCongruentIVs(L, &DT, DeadInsts, &TTI);
if (numFolded) {
Changed = true;
- DeleteTriviallyDeadInstructions(DeadInsts);
- DeleteDeadPHIs(L->getHeader());
+ RecursivelyDeleteTriviallyDeadInstructionsPermissive(DeadInsts, &TLI,
+ MSSAU.get());
+ DeleteDeadPHIs(L->getHeader(), &TLI, MSSAU.get());
}
}
return Changed;
@@ -5746,19 +5807,26 @@ bool LoopStrengthReduce::runOnLoop(Loop *L, LPPassManager & /*LPM*/) {
*L->getHeader()->getParent());
auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache(
*L->getHeader()->getParent());
- auto &LibInfo = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(
+ auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(
*L->getHeader()->getParent());
- return ReduceLoopStrength(L, IU, SE, DT, LI, TTI, AC, LibInfo);
+ auto *MSSAAnalysis = getAnalysisIfAvailable<MemorySSAWrapperPass>();
+ MemorySSA *MSSA = nullptr;
+ if (MSSAAnalysis)
+ MSSA = &MSSAAnalysis->getMSSA();
+ return ReduceLoopStrength(L, IU, SE, DT, LI, TTI, AC, TLI, MSSA);
}
PreservedAnalyses LoopStrengthReducePass::run(Loop &L, LoopAnalysisManager &AM,
LoopStandardAnalysisResults &AR,
LPMUpdater &) {
if (!ReduceLoopStrength(&L, AM.getResult<IVUsersAnalysis>(L, AR), AR.SE,
- AR.DT, AR.LI, AR.TTI, AR.AC, AR.TLI))
+ AR.DT, AR.LI, AR.TTI, AR.AC, AR.TLI, AR.MSSA))
return PreservedAnalyses::all();
- return getLoopPassPreservedAnalyses();
+ auto PA = getLoopPassPreservedAnalyses();
+ if (AR.MSSA)
+ PA.preserve<MemorySSAAnalysis>();
+ return PA;
}
char LoopStrengthReduce::ID = 0;
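
Both pass-manager entry points treat MemorySSA as optional in the same way; a compact sketch of the two lookups side by side (a restatement of the shapes shown above, not additional code from the commit):

    // Legacy PM: use MemorySSA only if some earlier pass scheduled it.
    MemorySSA *MSSA = nullptr;
    if (auto *MSSAAnalysis = getAnalysisIfAvailable<MemorySSAWrapperPass>())
      MSSA = &MSSAAnalysis->getMSSA();

    // New PM: report preservation only for what was actually kept up to date.
    auto PA = getLoopPassPreservedAnalyses();
    if (AR.MSSA)
      PA.preserve<MemorySSAAnalysis>();
    return PA;
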
diff --git a/llvm/lib/Transforms/Scalar/LoopUnrollAndJamPass.cpp b/llvm/lib/Transforms/Scalar/LoopUnrollAndJamPass.cpp
index 92ad8dafa5ab..285cba6ee205 100644
--- a/llvm/lib/Transforms/Scalar/LoopUnrollAndJamPass.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopUnrollAndJamPass.cpp
@@ -11,8 +11,10 @@
//===----------------------------------------------------------------------===//
#include "llvm/Transforms/Scalar/LoopUnrollAndJamPass.h"
+#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/None.h"
-#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/Optional.h"
+#include "llvm/ADT/PriorityWorklist.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Analysis/AssumptionCache.h"
@@ -20,37 +22,36 @@
#include "llvm/Analysis/DependenceAnalysis.h"
#include "llvm/Analysis/LoopAnalysisManager.h"
#include "llvm/Analysis/LoopInfo.h"
-#include "llvm/Analysis/LoopPass.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/BasicBlock.h"
-#include "llvm/IR/CFG.h"
-#include "llvm/IR/Constant.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
-#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
-#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Metadata.h"
#include "llvm/IR/PassManager.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
+#include "llvm/PassRegistry.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Compiler.h"
#include "llvm/Support/Debug.h"
-#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Scalar.h"
-#include "llvm/Transforms/Scalar/LoopPassManager.h"
-#include "llvm/Transforms/Utils.h"
+#include "llvm/Transforms/Utils/LoopSimplify.h"
#include "llvm/Transforms/Utils/LoopUtils.h"
#include "llvm/Transforms/Utils/UnrollLoop.h"
-#include <algorithm>
#include <cassert>
#include <cstdint>
-#include <string>
+#include <vector>
+
+namespace llvm {
+class Instruction;
+class Value;
+} // namespace llvm
using namespace llvm;
@@ -91,7 +92,7 @@ static cl::opt<unsigned> PragmaUnrollAndJamThreshold(
// Returns the loop hint metadata node with the given name (for example,
// "llvm.loop.unroll.count"). If no such metadata node exists, then nullptr is
// returned.
-static MDNode *GetUnrollMetadataForLoop(const Loop *L, StringRef Name) {
+static MDNode *getUnrollMetadataForLoop(const Loop *L, StringRef Name) {
if (MDNode *LoopID = L->getLoopID())
return GetUnrollMetadata(LoopID, Name);
return nullptr;
@@ -99,14 +100,14 @@ static MDNode *GetUnrollMetadataForLoop(const Loop *L, StringRef Name) {
// Returns true if the loop has any metadata starting with Prefix. For example a
// Prefix of "llvm.loop.unroll." returns true if we have any unroll metadata.
-static bool HasAnyUnrollPragma(const Loop *L, StringRef Prefix) {
+static bool hasAnyUnrollPragma(const Loop *L, StringRef Prefix) {
if (MDNode *LoopID = L->getLoopID()) {
// First operand should refer to the loop id itself.
assert(LoopID->getNumOperands() > 0 && "requires at least one operand");
assert(LoopID->getOperand(0) == LoopID && "invalid loop id");
- for (unsigned i = 1, e = LoopID->getNumOperands(); i < e; ++i) {
- MDNode *MD = dyn_cast<MDNode>(LoopID->getOperand(i));
+ for (unsigned I = 1, E = LoopID->getNumOperands(); I < E; ++I) {
+ MDNode *MD = dyn_cast<MDNode>(LoopID->getOperand(I));
if (!MD)
continue;
@@ -122,14 +123,14 @@ static bool HasAnyUnrollPragma(const Loop *L, StringRef Prefix) {
}
// Returns true if the loop has an unroll_and_jam(enable) pragma.
-static bool HasUnrollAndJamEnablePragma(const Loop *L) {
- return GetUnrollMetadataForLoop(L, "llvm.loop.unroll_and_jam.enable");
+static bool hasUnrollAndJamEnablePragma(const Loop *L) {
+ return getUnrollMetadataForLoop(L, "llvm.loop.unroll_and_jam.enable");
}
// If loop has an unroll_and_jam_count pragma return the (necessarily
// positive) value from the pragma. Otherwise return 0.
-static unsigned UnrollAndJamCountPragmaValue(const Loop *L) {
- MDNode *MD = GetUnrollMetadataForLoop(L, "llvm.loop.unroll_and_jam.count");
+static unsigned unrollAndJamCountPragmaValue(const Loop *L) {
+ MDNode *MD = getUnrollMetadataForLoop(L, "llvm.loop.unroll_and_jam.count");
if (MD) {
assert(MD->getNumOperands() == 2 &&
"Unroll count hint metadata should have two operands.");
@@ -157,7 +158,8 @@ static bool computeUnrollAndJamCount(
const SmallPtrSetImpl<const Value *> &EphValues,
OptimizationRemarkEmitter *ORE, unsigned OuterTripCount,
unsigned OuterTripMultiple, unsigned OuterLoopSize, unsigned InnerTripCount,
- unsigned InnerLoopSize, TargetTransformInfo::UnrollingPreferences &UP) {
+ unsigned InnerLoopSize, TargetTransformInfo::UnrollingPreferences &UP,
+ TargetTransformInfo::PeelingPreferences &PP) {
// First up use computeUnrollCount from the loop unroller to get a count
// for unrolling the outer loop, plus any loops requiring explicit
// unrolling we leave to the unroller. This uses UP.Threshold /
@@ -167,7 +169,8 @@ static bool computeUnrollAndJamCount(
bool UseUpperBound = false;
bool ExplicitUnroll = computeUnrollCount(
L, TTI, DT, LI, SE, EphValues, ORE, OuterTripCount, MaxTripCount,
- /*MaxOrZero*/ false, OuterTripMultiple, OuterLoopSize, UP, UseUpperBound);
+ /*MaxOrZero*/ false, OuterTripMultiple, OuterLoopSize, UP, PP,
+ UseUpperBound);
if (ExplicitUnroll || UseUpperBound) {
// If the user explicitly set the loop as unrolled, don't UnJ it. Leave it
// for the unroller instead.
@@ -190,7 +193,7 @@ static bool computeUnrollAndJamCount(
}
// Check for unroll_and_jam pragmas
- unsigned PragmaCount = UnrollAndJamCountPragmaValue(L);
+ unsigned PragmaCount = unrollAndJamCountPragmaValue(L);
if (PragmaCount > 0) {
UP.Count = PragmaCount;
UP.Runtime = true;
@@ -202,7 +205,7 @@ static bool computeUnrollAndJamCount(
return true;
}
- bool PragmaEnableUnroll = HasUnrollAndJamEnablePragma(L);
+ bool PragmaEnableUnroll = hasUnrollAndJamEnablePragma(L);
bool ExplicitUnrollAndJamCount = PragmaCount > 0 || UserUnrollCount;
bool ExplicitUnrollAndJam = PragmaEnableUnroll || ExplicitUnrollAndJamCount;
@@ -279,24 +282,11 @@ tryToUnrollAndJamLoop(Loop *L, DominatorTree &DT, LoopInfo *LI,
ScalarEvolution &SE, const TargetTransformInfo &TTI,
AssumptionCache &AC, DependenceInfo &DI,
OptimizationRemarkEmitter &ORE, int OptLevel) {
- // Quick checks of the correct loop form
- if (!L->isLoopSimplifyForm() || L->getSubLoops().size() != 1)
- return LoopUnrollResult::Unmodified;
- Loop *SubLoop = L->getSubLoops()[0];
- if (!SubLoop->isLoopSimplifyForm())
- return LoopUnrollResult::Unmodified;
-
- BasicBlock *Latch = L->getLoopLatch();
- BasicBlock *Exit = L->getExitingBlock();
- BasicBlock *SubLoopLatch = SubLoop->getLoopLatch();
- BasicBlock *SubLoopExit = SubLoop->getExitingBlock();
-
- if (Latch != Exit || SubLoopLatch != SubLoopExit)
- return LoopUnrollResult::Unmodified;
-
TargetTransformInfo::UnrollingPreferences UP =
gatherUnrollingPreferences(L, SE, TTI, nullptr, nullptr, OptLevel, None,
- None, None, None, None, None, None, None);
+ None, None, None, None, None);
+ TargetTransformInfo::PeelingPreferences PP =
+ gatherPeelingPreferences(L, SE, TTI, None, None);
if (AllowUnrollAndJam.getNumOccurrences() > 0)
UP.UnrollAndJam = AllowUnrollAndJam;
if (UnrollAndJamThreshold.getNumOccurrences() > 0)
@@ -317,13 +307,13 @@ tryToUnrollAndJamLoop(Loop *L, DominatorTree &DT, LoopInfo *LI,
// the unroller, so long as it does not explicitly have unroll_and_jam
// metadata. This means #pragma nounroll will disable unroll and jam as well
// as unrolling
- if (HasAnyUnrollPragma(L, "llvm.loop.unroll.") &&
- !HasAnyUnrollPragma(L, "llvm.loop.unroll_and_jam.")) {
+ if (hasAnyUnrollPragma(L, "llvm.loop.unroll.") &&
+ !hasAnyUnrollPragma(L, "llvm.loop.unroll_and_jam.")) {
LLVM_DEBUG(dbgs() << " Disabled due to pragma.\n");
return LoopUnrollResult::Unmodified;
}
- if (!isSafeToUnrollAndJam(L, SE, DT, DI)) {
+ if (!isSafeToUnrollAndJam(L, SE, DT, DI, *LI)) {
LLVM_DEBUG(dbgs() << " Disabled due to not being safe.\n");
return LoopUnrollResult::Unmodified;
}
@@ -334,6 +324,7 @@ tryToUnrollAndJamLoop(Loop *L, DominatorTree &DT, LoopInfo *LI,
bool Convergent;
SmallPtrSet<const Value *, 32> EphValues;
CodeMetrics::collectEphemeralValues(L, &AC, EphValues);
+ Loop *SubLoop = L->getSubLoops()[0];
unsigned InnerLoopSize =
ApproximateLoopSize(SubLoop, NumInlineCandidates, NotDuplicatable,
Convergent, TTI, EphValues, UP.BEInsns);
@@ -371,6 +362,8 @@ tryToUnrollAndJamLoop(Loop *L, DominatorTree &DT, LoopInfo *LI,
SubLoop->setLoopID(NewInnerEpilogueLoopID.getValue());
// Find trip count and trip multiple
+ BasicBlock *Latch = L->getLoopLatch();
+ BasicBlock *SubLoopLatch = SubLoop->getLoopLatch();
unsigned OuterTripCount = SE.getSmallConstantTripCount(L, Latch);
unsigned OuterTripMultiple = SE.getSmallConstantTripMultiple(L, Latch);
unsigned InnerTripCount = SE.getSmallConstantTripCount(SubLoop, SubLoopLatch);
@@ -378,7 +371,7 @@ tryToUnrollAndJamLoop(Loop *L, DominatorTree &DT, LoopInfo *LI,
// Decide if, and by how much, to unroll
bool IsCountSetExplicitly = computeUnrollAndJamCount(
L, SubLoop, TTI, DT, LI, SE, EphValues, &ORE, OuterTripCount,
- OuterTripMultiple, OuterLoopSize, InnerTripCount, InnerLoopSize, UP);
+ OuterTripMultiple, OuterLoopSize, InnerTripCount, InnerLoopSize, UP, PP);
if (UP.Count <= 1)
return LoopUnrollResult::Unmodified;
// Unroll factor (Count) must be less or equal to TripCount.
@@ -388,7 +381,7 @@ tryToUnrollAndJamLoop(Loop *L, DominatorTree &DT, LoopInfo *LI,
Loop *EpilogueOuterLoop = nullptr;
LoopUnrollResult UnrollResult = UnrollAndJamLoop(
L, UP.Count, OuterTripCount, OuterTripMultiple, UP.UnrollRemainder, LI,
- &SE, &DT, &AC, &ORE, &EpilogueOuterLoop);
+ &SE, &DT, &AC, &TTI, &ORE, &EpilogueOuterLoop);
// Assign new loop attributes.
if (EpilogueOuterLoop) {
@@ -435,22 +428,23 @@ static bool tryToUnrollAndJamLoop(Function &F, DominatorTree &DT, LoopInfo &LI,
int OptLevel) {
bool DidSomething = false;
- // The loop unroll and jam pass requires loops to be in simplified form, and also needs LCSSA.
- // Since simplification may add new inner loops, it has to run before the
- // legality and profitability checks. This means running the loop unroll and jam pass
- // will simplify all loops, regardless of whether anything end up being
- // unroll and jammed.
+ // The loop unroll and jam pass requires loops to be in simplified form, and
+ // also needs LCSSA. Since simplification may add new inner loops, it has to
+ // run before the legality and profitability checks. This means running the
+ // loop unroll and jam pass will simplify all loops, regardless of whether
+ // anything ends up being unroll and jammed.
for (auto &L : LI) {
DidSomething |=
simplifyLoop(L, &DT, &LI, &SE, &AC, nullptr, false /* PreserveLCSSA */);
DidSomething |= formLCSSARecursively(*L, DT, &LI, &SE);
}
+ // Add the loop nests in the reverse order of LoopInfo. See method
+ // declaration.
SmallPriorityWorklist<Loop *, 4> Worklist;
- internal::appendLoopsToWorklist(reverse(LI), Worklist);
+ appendLoopsToWorklist(LI, Worklist);
while (!Worklist.empty()) {
Loop *L = Worklist.pop_back_val();
- formLCSSA(*L, DT, &LI, &SE);
LoopUnrollResult Result =
tryToUnrollAndJamLoop(L, DT, &LI, SE, TTI, AC, DI, ORE, OptLevel);
if (Result != LoopUnrollResult::Unmodified)
diff --git a/llvm/lib/Transforms/Scalar/LoopUnrollPass.cpp b/llvm/lib/Transforms/Scalar/LoopUnrollPass.cpp
index 4c2b079c6bb5..87f40bb7ba85 100644
--- a/llvm/lib/Transforms/Scalar/LoopUnrollPass.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopUnrollPass.cpp
@@ -154,6 +154,10 @@ static cl::opt<bool>
cl::desc("Allows loops to be peeled when the dynamic "
"trip count is known to be low."));
+static cl::opt<bool> UnrollAllowLoopNestsPeeling(
+ "unroll-allow-loop-nests-peeling", cl::init(false), cl::Hidden,
+ cl::desc("Allows loop nests to be peeled."));
+
static cl::opt<bool> UnrollUnrollRemainder(
"unroll-remainder", cl::Hidden,
cl::desc("Allow the loop remainder to be unrolled."));
@@ -167,6 +171,16 @@ static cl::opt<bool> UnrollRevisitChildLoops(
"This shouldn't typically be needed as child loops (or their "
"clones) were already visited."));
+static cl::opt<unsigned> UnrollThresholdAggressive(
+ "unroll-threshold-aggressive", cl::init(300), cl::Hidden,
+ cl::desc("Threshold (max size of unrolled loop) to use in aggressive (O3) "
+ "optimizations"));
+static cl::opt<unsigned>
+ UnrollThresholdDefault("unroll-threshold-default", cl::init(150),
+ cl::Hidden,
+ cl::desc("Default threshold (max size of unrolled "
+ "loop), used in all but O3 optimizations"));
+
/// A magic value for use with the Threshold parameter to indicate
/// that the loop unroll should be performed regardless of how much
/// code expansion would result.
@@ -179,19 +193,17 @@ TargetTransformInfo::UnrollingPreferences llvm::gatherUnrollingPreferences(
BlockFrequencyInfo *BFI, ProfileSummaryInfo *PSI, int OptLevel,
Optional<unsigned> UserThreshold, Optional<unsigned> UserCount,
Optional<bool> UserAllowPartial, Optional<bool> UserRuntime,
- Optional<bool> UserUpperBound, Optional<bool> UserAllowPeeling,
- Optional<bool> UserAllowProfileBasedPeeling,
- Optional<unsigned> UserFullUnrollMaxCount) {
+ Optional<bool> UserUpperBound, Optional<unsigned> UserFullUnrollMaxCount) {
TargetTransformInfo::UnrollingPreferences UP;
// Set up the defaults
- UP.Threshold = OptLevel > 2 ? 300 : 150;
+ UP.Threshold =
+ OptLevel > 2 ? UnrollThresholdAggressive : UnrollThresholdDefault;
UP.MaxPercentThresholdBoost = 400;
UP.OptSizeThreshold = 0;
UP.PartialThreshold = 150;
UP.PartialOptSizeThreshold = 0;
UP.Count = 0;
- UP.PeelCount = 0;
UP.DefaultUnrollRuntimeCount = 8;
UP.MaxCount = std::numeric_limits<unsigned>::max();
UP.FullUnrollMaxCount = std::numeric_limits<unsigned>::max();
@@ -203,10 +215,9 @@ TargetTransformInfo::UnrollingPreferences llvm::gatherUnrollingPreferences(
UP.AllowExpensiveTripCount = false;
UP.Force = false;
UP.UpperBound = false;
- UP.AllowPeeling = true;
UP.UnrollAndJam = false;
- UP.PeelProfiledIterations = true;
UP.UnrollAndJamInnerLoopThreshold = 60;
+ UP.MaxIterationsCountToAnalyze = UnrollMaxIterationsCountToAnalyze;
// Override with any target specific settings
TTI.getUnrollingPreferences(L, SE, UP);
@@ -232,8 +243,6 @@ TargetTransformInfo::UnrollingPreferences llvm::gatherUnrollingPreferences(
UP.MaxCount = UnrollMaxCount;
if (UnrollFullMaxCount.getNumOccurrences() > 0)
UP.FullUnrollMaxCount = UnrollFullMaxCount;
- if (UnrollPeelCount.getNumOccurrences() > 0)
- UP.PeelCount = UnrollPeelCount;
if (UnrollAllowPartial.getNumOccurrences() > 0)
UP.Partial = UnrollAllowPartial;
if (UnrollAllowRemainder.getNumOccurrences() > 0)
@@ -242,10 +251,10 @@ TargetTransformInfo::UnrollingPreferences llvm::gatherUnrollingPreferences(
UP.Runtime = UnrollRuntime;
if (UnrollMaxUpperBound == 0)
UP.UpperBound = false;
- if (UnrollAllowPeeling.getNumOccurrences() > 0)
- UP.AllowPeeling = UnrollAllowPeeling;
if (UnrollUnrollRemainder.getNumOccurrences() > 0)
UP.UnrollRemainder = UnrollUnrollRemainder;
+ if (UnrollMaxIterationsCountToAnalyze.getNumOccurrences() > 0)
+ UP.MaxIterationsCountToAnalyze = UnrollMaxIterationsCountToAnalyze;
// Apply user values provided by argument
if (UserThreshold.hasValue()) {
@@ -260,16 +269,45 @@ TargetTransformInfo::UnrollingPreferences llvm::gatherUnrollingPreferences(
UP.Runtime = *UserRuntime;
if (UserUpperBound.hasValue())
UP.UpperBound = *UserUpperBound;
- if (UserAllowPeeling.hasValue())
- UP.AllowPeeling = *UserAllowPeeling;
- if (UserAllowProfileBasedPeeling.hasValue())
- UP.PeelProfiledIterations = *UserAllowProfileBasedPeeling;
if (UserFullUnrollMaxCount.hasValue())
UP.FullUnrollMaxCount = *UserFullUnrollMaxCount;
return UP;
}
+TargetTransformInfo::PeelingPreferences
+llvm::gatherPeelingPreferences(Loop *L, ScalarEvolution &SE,
+ const TargetTransformInfo &TTI,
+ Optional<bool> UserAllowPeeling,
+ Optional<bool> UserAllowProfileBasedPeeling) {
+ TargetTransformInfo::PeelingPreferences PP;
+
+ // Default values
+ PP.PeelCount = 0;
+ PP.AllowPeeling = true;
+ PP.AllowLoopNestsPeeling = false;
+ PP.PeelProfiledIterations = true;
+
+ // Get Target Specifc Values
+ TTI.getPeelingPreferences(L, SE, PP);
+
+ // User Specified Values using cl::opt
+ if (UnrollPeelCount.getNumOccurrences() > 0)
+ PP.PeelCount = UnrollPeelCount;
+ if (UnrollAllowPeeling.getNumOccurrences() > 0)
+ PP.AllowPeeling = UnrollAllowPeeling;
+ if (UnrollAllowLoopNestsPeeling.getNumOccurrences() > 0)
+ PP.AllowLoopNestsPeeling = UnrollAllowLoopNestsPeeling;
+
+ // User Specifed values provided by argument
+ if (UserAllowPeeling.hasValue())
+ PP.AllowPeeling = *UserAllowPeeling;
+ if (UserAllowProfileBasedPeeling.hasValue())
+ PP.PeelProfiledIterations = *UserAllowProfileBasedPeeling;
+
+ return PP;
+}
+
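
With peeling split out of UnrollingPreferences, callers pair the two gatherers and pass both results through computeUnrollCount. A minimal sketch of a call site after this change, with None selecting the defaults set above:

    TargetTransformInfo::UnrollingPreferences UP = gatherUnrollingPreferences(
        L, SE, TTI, /*BFI=*/nullptr, /*PSI=*/nullptr, OptLevel,
        /*UserThreshold=*/None, /*UserCount=*/None, /*UserAllowPartial=*/None,
        /*UserRuntime=*/None, /*UserUpperBound=*/None,
        /*UserFullUnrollMaxCount=*/None);
    TargetTransformInfo::PeelingPreferences PP =
        gatherPeelingPreferences(L, SE, TTI, /*UserAllowPeeling=*/None,
                                 /*UserAllowProfileBasedPeeling=*/None);
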
namespace {
/// A struct to densely store the state of an instruction after unrolling at
@@ -335,11 +373,12 @@ struct EstimatedUnrollCost {
static Optional<EstimatedUnrollCost> analyzeLoopUnrollCost(
const Loop *L, unsigned TripCount, DominatorTree &DT, ScalarEvolution &SE,
const SmallPtrSetImpl<const Value *> &EphValues,
- const TargetTransformInfo &TTI, unsigned MaxUnrolledLoopSize) {
+ const TargetTransformInfo &TTI, unsigned MaxUnrolledLoopSize,
+ unsigned MaxIterationsCountToAnalyze) {
// We want to be able to scale offsets by the trip count and add more offsets
// to them without checking for overflows, and we already don't want to
// analyze *massive* trip counts, so we force the max to be reasonably small.
- assert(UnrollMaxIterationsCountToAnalyze <
+ assert(MaxIterationsCountToAnalyze <
(unsigned)(std::numeric_limits<int>::max() / 2) &&
"The unroll iterations max is too large!");
@@ -349,8 +388,7 @@ static Optional<EstimatedUnrollCost> analyzeLoopUnrollCost(
return None;
// Don't simulate loops with a big or unknown tripcount
- if (!UnrollMaxIterationsCountToAnalyze || !TripCount ||
- TripCount > UnrollMaxIterationsCountToAnalyze)
+ if (!TripCount || TripCount > MaxIterationsCountToAnalyze)
return None;
SmallSetVector<BasicBlock *, 16> BBWorklist;
@@ -428,7 +466,7 @@ static Optional<EstimatedUnrollCost> analyzeLoopUnrollCost(
// First accumulate the cost of this instruction.
if (!Cost.IsFree) {
- UnrolledCost += TTI.getUserCost(I);
+ UnrolledCost += TTI.getUserCost(I, TargetTransformInfo::TCK_CodeSize);
LLVM_DEBUG(dbgs() << "Adding cost of instruction (iteration "
<< Iteration << "): ");
LLVM_DEBUG(I->dump());
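
getUserCost now takes an explicit cost kind, and the unroll analysis asks for code size. The same query outside the visitor, as a sketch (BB and TTI as in the surrounding function):

    int BlockCodeSize = 0;
    for (Instruction &I : *BB)
      // TCK_CodeSize makes the size intent explicit; other kinds such as
      // TCK_RecipThroughput answer different questions.
      BlockCodeSize += TTI.getUserCost(&I, TargetTransformInfo::TCK_CodeSize);
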
@@ -521,7 +559,7 @@ static Optional<EstimatedUnrollCost> analyzeLoopUnrollCost(
// Track this instruction's expected baseline cost when executing the
// rolled loop form.
- RolledDynamicCost += TTI.getUserCost(&I);
+ RolledDynamicCost += TTI.getUserCost(&I, TargetTransformInfo::TCK_CodeSize);
// Visit the instruction to analyze its loop cost after unrolling,
// and if the visitor returns true, mark the instruction as free after
@@ -665,32 +703,32 @@ unsigned llvm::ApproximateLoopSize(
// Returns the loop hint metadata node with the given name (for example,
// "llvm.loop.unroll.count"). If no such metadata node exists, then nullptr is
// returned.
-static MDNode *GetUnrollMetadataForLoop(const Loop *L, StringRef Name) {
+static MDNode *getUnrollMetadataForLoop(const Loop *L, StringRef Name) {
if (MDNode *LoopID = L->getLoopID())
return GetUnrollMetadata(LoopID, Name);
return nullptr;
}
// Returns true if the loop has an unroll(full) pragma.
-static bool HasUnrollFullPragma(const Loop *L) {
- return GetUnrollMetadataForLoop(L, "llvm.loop.unroll.full");
+static bool hasUnrollFullPragma(const Loop *L) {
+ return getUnrollMetadataForLoop(L, "llvm.loop.unroll.full");
}
// Returns true if the loop has an unroll(enable) pragma. This metadata is used
// for both "#pragma unroll" and "#pragma clang loop unroll(enable)" directives.
-static bool HasUnrollEnablePragma(const Loop *L) {
- return GetUnrollMetadataForLoop(L, "llvm.loop.unroll.enable");
+static bool hasUnrollEnablePragma(const Loop *L) {
+ return getUnrollMetadataForLoop(L, "llvm.loop.unroll.enable");
}
// Returns true if the loop has a runtime unroll(disable) pragma.
-static bool HasRuntimeUnrollDisablePragma(const Loop *L) {
- return GetUnrollMetadataForLoop(L, "llvm.loop.unroll.runtime.disable");
+static bool hasRuntimeUnrollDisablePragma(const Loop *L) {
+ return getUnrollMetadataForLoop(L, "llvm.loop.unroll.runtime.disable");
}
// If loop has an unroll_count pragma return the (necessarily
// positive) value from the pragma. Otherwise return 0.
-static unsigned UnrollCountPragmaValue(const Loop *L) {
- MDNode *MD = GetUnrollMetadataForLoop(L, "llvm.loop.unroll.count");
+static unsigned unrollCountPragmaValue(const Loop *L) {
+ MDNode *MD = getUnrollMetadataForLoop(L, "llvm.loop.unroll.count");
if (MD) {
assert(MD->getNumOperands() == 2 &&
"Unroll count hint metadata should have two operands.");
@@ -740,7 +778,8 @@ bool llvm::computeUnrollCount(
ScalarEvolution &SE, const SmallPtrSetImpl<const Value *> &EphValues,
OptimizationRemarkEmitter *ORE, unsigned &TripCount, unsigned MaxTripCount,
bool MaxOrZero, unsigned &TripMultiple, unsigned LoopSize,
- TargetTransformInfo::UnrollingPreferences &UP, bool &UseUpperBound) {
+ TargetTransformInfo::UnrollingPreferences &UP,
+ TargetTransformInfo::PeelingPreferences &PP, bool &UseUpperBound) {
// Check for explicit Count.
// 1st priority is unroll count set by "unroll-count" option.
@@ -754,7 +793,7 @@ bool llvm::computeUnrollCount(
}
// 2nd priority is unroll count set by pragma.
- unsigned PragmaCount = UnrollCountPragmaValue(L);
+ unsigned PragmaCount = unrollCountPragmaValue(L);
if (PragmaCount > 0) {
UP.Count = PragmaCount;
UP.Runtime = true;
@@ -764,14 +803,14 @@ bool llvm::computeUnrollCount(
getUnrolledLoopSize(LoopSize, UP) < PragmaUnrollThreshold)
return true;
}
- bool PragmaFullUnroll = HasUnrollFullPragma(L);
+ bool PragmaFullUnroll = hasUnrollFullPragma(L);
if (PragmaFullUnroll && TripCount != 0) {
UP.Count = TripCount;
if (getUnrolledLoopSize(LoopSize, UP) < PragmaUnrollThreshold)
return false;
}
- bool PragmaEnableUnroll = HasUnrollEnablePragma(L);
+ bool PragmaEnableUnroll = hasUnrollEnablePragma(L);
bool ExplicitUnroll = PragmaCount > 0 || PragmaFullUnroll ||
PragmaEnableUnroll || UserUnrollCount;
@@ -827,7 +866,8 @@ bool llvm::computeUnrollCount(
// To check that, run additional analysis on the loop.
if (Optional<EstimatedUnrollCost> Cost = analyzeLoopUnrollCost(
L, FullUnrollTripCount, DT, SE, EphValues, TTI,
- UP.Threshold * UP.MaxPercentThresholdBoost / 100)) {
+ UP.Threshold * UP.MaxPercentThresholdBoost / 100,
+ UP.MaxIterationsCountToAnalyze)) {
unsigned Boost =
getFullUnrollBoostingFactor(*Cost, UP.MaxPercentThresholdBoost);
if (Cost->UnrolledCost < UP.Threshold * Boost / 100) {
@@ -841,8 +881,8 @@ bool llvm::computeUnrollCount(
}
// 4th priority is loop peeling.
- computePeelCount(L, LoopSize, UP, TripCount, SE);
- if (UP.PeelCount) {
+ computePeelCount(L, LoopSize, UP, PP, TripCount, SE);
+ if (PP.PeelCount) {
UP.Runtime = false;
UP.Count = 1;
return ExplicitUnroll;
@@ -925,7 +965,7 @@ bool llvm::computeUnrollCount(
// 6th priority is runtime unrolling.
// Don't unroll a runtime trip count loop when it is disabled.
- if (HasRuntimeUnrollDisablePragma(L)) {
+ if (hasRuntimeUnrollDisablePragma(L)) {
UP.Count = 0;
return false;
}
@@ -1045,8 +1085,9 @@ static LoopUnrollResult tryToUnrollLoop(
TargetTransformInfo::UnrollingPreferences UP = gatherUnrollingPreferences(
L, SE, TTI, BFI, PSI, OptLevel, ProvidedThreshold, ProvidedCount,
ProvidedAllowPartial, ProvidedRuntime, ProvidedUpperBound,
- ProvidedAllowPeeling, ProvidedAllowProfileBasedPeeling,
ProvidedFullUnrollMaxCount);
+ TargetTransformInfo::PeelingPreferences PP = gatherPeelingPreferences(
+ L, SE, TTI, ProvidedAllowPeeling, ProvidedAllowProfileBasedPeeling);
// Exit early if unrolling is disabled. For OptForSize, we pick the loop size
// as threshold later on.
@@ -1120,7 +1161,7 @@ static LoopUnrollResult tryToUnrollLoop(
bool UseUpperBound = false;
bool IsCountSetExplicitly = computeUnrollCount(
L, TTI, DT, LI, SE, EphValues, &ORE, TripCount, MaxTripCount, MaxOrZero,
- TripMultiple, LoopSize, UP, UseUpperBound);
+ TripMultiple, LoopSize, UP, PP, UseUpperBound);
if (!UP.Count)
return LoopUnrollResult::Unmodified;
// Unroll factor (Count) must be less or equal to TripCount.
@@ -1135,9 +1176,9 @@ static LoopUnrollResult tryToUnrollLoop(
LoopUnrollResult UnrollResult = UnrollLoop(
L,
{UP.Count, TripCount, UP.Force, UP.Runtime, UP.AllowExpensiveTripCount,
- UseUpperBound, MaxOrZero, TripMultiple, UP.PeelCount, UP.UnrollRemainder,
+ UseUpperBound, MaxOrZero, TripMultiple, PP.PeelCount, UP.UnrollRemainder,
ForgetAllSCEV},
- LI, &SE, &DT, &AC, &ORE, PreserveLCSSA, &RemainderLoop);
+ LI, &SE, &DT, &AC, &TTI, &ORE, PreserveLCSSA, &RemainderLoop);
if (UnrollResult == LoopUnrollResult::Unmodified)
return LoopUnrollResult::Unmodified;
@@ -1167,7 +1208,7 @@ static LoopUnrollResult tryToUnrollLoop(
// If the loop was peeled, we already "used up" the profile information
// we had, so we don't want to unroll or peel again.
if (UnrollResult != LoopUnrollResult::FullyUnrolled &&
- (IsCountSetExplicitly || (UP.PeelProfiledIterations && UP.PeelCount)))
+ (IsCountSetExplicitly || (PP.PeelProfiledIterations && PP.PeelCount)))
L->setLoopAlreadyUnrolled();
return UnrollResult;
@@ -1296,16 +1337,10 @@ Pass *llvm::createSimpleLoopUnrollPass(int OptLevel, bool OnlyWhenForced,
PreservedAnalyses LoopFullUnrollPass::run(Loop &L, LoopAnalysisManager &AM,
LoopStandardAnalysisResults &AR,
LPMUpdater &Updater) {
- const auto &FAM =
- AM.getResult<FunctionAnalysisManagerLoopProxy>(L, AR).getManager();
- Function *F = L.getHeader()->getParent();
-
- auto *ORE = FAM.getCachedResult<OptimizationRemarkEmitterAnalysis>(*F);
- // FIXME: This should probably be optional rather than required.
- if (!ORE)
- report_fatal_error(
- "LoopFullUnrollPass: OptimizationRemarkEmitterAnalysis not "
- "cached at a higher level");
+ // For the new PM, we can't use OptimizationRemarkEmitter as an analysis
+ // pass. Function analyses need to be preserved across loop transformations
+ // but ORE cannot be preserved (see comment before the pass definition).
+ OptimizationRemarkEmitter ORE(L.getHeader()->getParent());
// Keep track of the previous loop structure so we can identify new loops
// created by unrolling.
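
Building the OptimizationRemarkEmitter directly on the function (as above) avoids requiring a cached function analysis from inside a loop pass; an emitter constructed this way has no BlockFrequencyInfo, so its remarks carry no hotness. A sketch of emitting through it (the remark name is illustrative; DEBUG_TYPE is the file's usual pass-name macro):

    OptimizationRemarkEmitter ORE(L.getHeader()->getParent());
    ORE.emit([&]() {
      return OptimizationRemark(DEBUG_TYPE, "FullyUnrolled", L.getStartLoc(),
                                L.getHeader())
             << "loop fully unrolled";
    });
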
@@ -1316,9 +1351,9 @@ PreservedAnalyses LoopFullUnrollPass::run(Loop &L, LoopAnalysisManager &AM,
else
OldLoops.insert(AR.LI.begin(), AR.LI.end());
- std::string LoopName = L.getName();
+ std::string LoopName = std::string(L.getName());
- bool Changed = tryToUnrollLoop(&L, AR.DT, &AR.LI, AR.SE, AR.TTI, AR.AC, *ORE,
+ bool Changed = tryToUnrollLoop(&L, AR.DT, &AR.LI, AR.SE, AR.TTI, AR.AC, ORE,
/*BFI*/ nullptr, /*PSI*/ nullptr,
/*PreserveLCSSA*/ true, OptLevel,
OnlyWhenForced, ForgetSCEV, /*Count*/ None,
@@ -1384,30 +1419,6 @@ PreservedAnalyses LoopFullUnrollPass::run(Loop &L, LoopAnalysisManager &AM,
return getLoopPassPreservedAnalyses();
}
-template <typename RangeT>
-static SmallVector<Loop *, 8> appendLoopsToWorklist(RangeT &&Loops) {
- SmallVector<Loop *, 8> Worklist;
- // We use an internal worklist to build up the preorder traversal without
- // recursion.
- SmallVector<Loop *, 4> PreOrderLoops, PreOrderWorklist;
-
- for (Loop *RootL : Loops) {
- assert(PreOrderLoops.empty() && "Must start with an empty preorder walk.");
- assert(PreOrderWorklist.empty() &&
- "Must start with an empty preorder walk worklist.");
- PreOrderWorklist.push_back(RootL);
- do {
- Loop *L = PreOrderWorklist.pop_back_val();
- PreOrderWorklist.append(L->begin(), L->end());
- PreOrderLoops.push_back(L);
- } while (!PreOrderWorklist.empty());
-
- Worklist.append(PreOrderLoops.begin(), PreOrderLoops.end());
- PreOrderLoops.clear();
- }
- return Worklist;
-}
-
PreservedAnalyses LoopUnrollPass::run(Function &F,
FunctionAnalysisManager &AM) {
auto &SE = AM.getResult<ScalarEvolutionAnalysis>(F);
@@ -1421,10 +1432,9 @@ PreservedAnalyses LoopUnrollPass::run(Function &F,
if (auto *LAMProxy = AM.getCachedResult<LoopAnalysisManagerFunctionProxy>(F))
LAM = &LAMProxy->getManager();
- const ModuleAnalysisManager &MAM =
- AM.getResult<ModuleAnalysisManagerFunctionProxy>(F).getManager();
+ auto &MAMProxy = AM.getResult<ModuleAnalysisManagerFunctionProxy>(F);
ProfileSummaryInfo *PSI =
- MAM.getCachedResult<ProfileSummaryAnalysis>(*F.getParent());
+ MAMProxy.getCachedResult<ProfileSummaryAnalysis>(*F.getParent());
auto *BFI = (PSI && PSI->hasProfileSummary()) ?
&AM.getResult<BlockFrequencyAnalysis>(F) : nullptr;
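
The profile summary is now fetched through the module proxy's cache: from a function pass, outer-scope (module) analyses may only be read if they are already cached, never computed on demand. The lookup in isolation, as a sketch:

    auto &MAMProxy = AM.getResult<ModuleAnalysisManagerFunctionProxy>(F);
    ProfileSummaryInfo *PSI =
        MAMProxy.getCachedResult<ProfileSummaryAnalysis>(*F.getParent());
    // PSI is null unless an earlier module-level pass computed the summary.
    if (PSI && PSI->hasProfileSummary()) {
      // profile-guided decisions, e.g. requesting BlockFrequencyAnalysis, go here
    }
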
@@ -1441,7 +1451,10 @@ PreservedAnalyses LoopUnrollPass::run(Function &F,
Changed |= formLCSSARecursively(*L, DT, &LI, &SE);
}
- SmallVector<Loop *, 8> Worklist = appendLoopsToWorklist(LI);
+ // Add the loop nests in the reverse order of LoopInfo. See method
+ // declaration.
+ SmallPriorityWorklist<Loop *, 4> Worklist;
+ appendLoopsToWorklist(LI, Worklist);
while (!Worklist.empty()) {
// Because the LoopInfo stores the loops in RPO, we walk the worklist
@@ -1459,7 +1472,7 @@ PreservedAnalyses LoopUnrollPass::run(Function &F,
Optional<bool> LocalAllowPeeling = UnrollOpts.AllowPeeling;
if (PSI && PSI->hasHugeWorkingSetSize())
LocalAllowPeeling = false;
- std::string LoopName = L.getName();
+ std::string LoopName = std::string(L.getName());
// The API here is quite complex to call and we allow to select some
// flavors of unrolling during construction time (by setting UnrollOpts).
LoopUnrollResult Result = tryToUnrollLoop(
diff --git a/llvm/lib/Transforms/Scalar/LoopUnswitch.cpp b/llvm/lib/Transforms/Scalar/LoopUnswitch.cpp
index 915e053704b2..645a89bbd0ff 100644
--- a/llvm/lib/Transforms/Scalar/LoopUnswitch.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopUnswitch.cpp
@@ -38,11 +38,11 @@
#include "llvm/Analysis/LoopPass.h"
#include "llvm/Analysis/MemorySSA.h"
#include "llvm/Analysis/MemorySSAUpdater.h"
+#include "llvm/Analysis/MustExecute.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/BasicBlock.h"
-#include "llvm/IR/CallSite.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
@@ -158,7 +158,7 @@ namespace {
// Returns true if another unswitching could be done within the cost
// threshold.
- bool CostAllowsUnswitching();
+ bool costAllowsUnswitching();
// Clone all loop-unswitch related loop properties.
// Redistribute unswitching quotas.
@@ -173,20 +173,20 @@ namespace {
AssumptionCache *AC;
// Used to check if second loop needs processing after
- // RewriteLoopBodyWithConditionConstant rewrites first loop.
+ // rewriteLoopBodyWithConditionConstant rewrites first loop.
std::vector<Loop*> LoopProcessWorklist;
LUAnalysisCache BranchesInfo;
bool OptimizeForSize;
- bool redoLoop = false;
+ bool RedoLoop = false;
- Loop *currentLoop = nullptr;
+ Loop *CurrentLoop = nullptr;
DominatorTree *DT = nullptr;
MemorySSA *MSSA = nullptr;
std::unique_ptr<MemorySSAUpdater> MSSAU;
- BasicBlock *loopHeader = nullptr;
- BasicBlock *loopPreheader = nullptr;
+ BasicBlock *LoopHeader = nullptr;
+ BasicBlock *LoopPreheader = nullptr;
bool SanitizeMemory;
SimpleLoopSafetyInfo SafetyInfo;
@@ -198,15 +198,15 @@ namespace {
// NewBlocks contained cloned copy of basic blocks from LoopBlocks.
std::vector<BasicBlock*> NewBlocks;
- bool hasBranchDivergence;
+ bool HasBranchDivergence;
public:
static char ID; // Pass ID, replacement for typeid
- explicit LoopUnswitch(bool Os = false, bool hasBranchDivergence = false)
+ explicit LoopUnswitch(bool Os = false, bool HasBranchDivergence = false)
: LoopPass(ID), OptimizeForSize(Os),
- hasBranchDivergence(hasBranchDivergence) {
- initializeLoopUnswitchPass(*PassRegistry::getPassRegistry());
+ HasBranchDivergence(HasBranchDivergence) {
+ initializeLoopUnswitchPass(*PassRegistry::getPassRegistry());
}
bool runOnLoop(Loop *L, LPPassManager &LPM) override;
@@ -223,48 +223,46 @@ namespace {
AU.addRequired<MemorySSAWrapperPass>();
AU.addPreserved<MemorySSAWrapperPass>();
}
- if (hasBranchDivergence)
+ if (HasBranchDivergence)
AU.addRequired<LegacyDivergenceAnalysis>();
getLoopAnalysisUsage(AU);
}
private:
- void releaseMemory() override {
- BranchesInfo.forgetLoop(currentLoop);
- }
+ void releaseMemory() override { BranchesInfo.forgetLoop(CurrentLoop); }
void initLoopData() {
- loopHeader = currentLoop->getHeader();
- loopPreheader = currentLoop->getLoopPreheader();
+ LoopHeader = CurrentLoop->getHeader();
+ LoopPreheader = CurrentLoop->getLoopPreheader();
}
/// Split all of the edges from inside the loop to their exit blocks.
/// Update the appropriate Phi nodes as we do so.
- void SplitExitEdges(Loop *L,
+ void splitExitEdges(Loop *L,
const SmallVectorImpl<BasicBlock *> &ExitBlocks);
- bool TryTrivialLoopUnswitch(bool &Changed);
+ bool tryTrivialLoopUnswitch(bool &Changed);
- bool UnswitchIfProfitable(Value *LoopCond, Constant *Val,
+ bool unswitchIfProfitable(Value *LoopCond, Constant *Val,
Instruction *TI = nullptr);
- void UnswitchTrivialCondition(Loop *L, Value *Cond, Constant *Val,
+ void unswitchTrivialCondition(Loop *L, Value *Cond, Constant *Val,
BasicBlock *ExitBlock, Instruction *TI);
- void UnswitchNontrivialCondition(Value *LIC, Constant *OnVal, Loop *L,
+ void unswitchNontrivialCondition(Value *LIC, Constant *OnVal, Loop *L,
Instruction *TI);
- void RewriteLoopBodyWithConditionConstant(Loop *L, Value *LIC,
- Constant *Val, bool isEqual);
+ void rewriteLoopBodyWithConditionConstant(Loop *L, Value *LIC,
+ Constant *Val, bool IsEqual);
- void EmitPreheaderBranchOnCondition(Value *LIC, Constant *Val,
+ void emitPreheaderBranchOnCondition(Value *LIC, Constant *Val,
BasicBlock *TrueDest,
BasicBlock *FalseDest,
BranchInst *OldBranch, Instruction *TI);
- void SimplifyCode(std::vector<Instruction*> &Worklist, Loop *L);
+ void simplifyCode(std::vector<Instruction *> &Worklist, Loop *L);
/// Given that the Invariant is not equal to Val, simplify instructions
/// in the loop.
- Value *SimplifyInstructionWithNotEqual(Instruction *Inst, Value *Invariant,
+ Value *simplifyInstructionWithNotEqual(Instruction *Inst, Value *Invariant,
Constant *Val);
};
@@ -347,7 +345,7 @@ bool LUAnalysisCache::isUnswitched(const SwitchInst *SI, const Value *V) {
return (*CurLoopInstructions)[SI].count(V);
}
-bool LUAnalysisCache::CostAllowsUnswitching() {
+bool LUAnalysisCache::costAllowsUnswitching() {
return CurrentLoopProperties->CanBeUnswitchedCount > 0;
}
@@ -396,8 +394,8 @@ INITIALIZE_PASS_DEPENDENCY(MemorySSAWrapperPass)
INITIALIZE_PASS_END(LoopUnswitch, "loop-unswitch", "Unswitch loops",
false, false)
-Pass *llvm::createLoopUnswitchPass(bool Os, bool hasBranchDivergence) {
- return new LoopUnswitch(Os, hasBranchDivergence);
+Pass *llvm::createLoopUnswitchPass(bool Os, bool HasBranchDivergence) {
+ return new LoopUnswitch(Os, HasBranchDivergence);
}
/// Operator chain lattice.
@@ -411,15 +409,15 @@ enum OperatorChain {
/// Cond is a condition that occurs in L. If it is invariant in the loop, or has
/// an invariant piece, return the invariant. Otherwise, return null.
//
-/// NOTE: FindLIVLoopCondition will not return a partial LIV by walking up a
-/// mixed operator chain, as we can not reliably find a value which will simplify
-/// the operator chain. If the chain is AND-only or OR-only, we can use 0 or ~0
-/// to simplify the chain.
+/// NOTE: findLIVLoopCondition will not return a partial LIV by walking up a
+/// mixed operator chain, as we can not reliably find a value which will
+/// simplify the operator chain. If the chain is AND-only or OR-only, we can use
+/// 0 or ~0 to simplify the chain.
///
/// NOTE: In case a partial LIV and a mixed operator chain, we may be able to
/// simplify the condition itself to a loop variant condition, but at the
/// cost of creating an entirely new loop.
-static Value *FindLIVLoopCondition(Value *Cond, Loop *L, bool &Changed,
+static Value *findLIVLoopCondition(Value *Cond, Loop *L, bool &Changed,
OperatorChain &ParentChain,
DenseMap<Value *, Value *> &Cache,
MemorySSAUpdater *MSSAU) {
@@ -479,7 +477,7 @@ static Value *FindLIVLoopCondition(Value *Cond, Loop *L, bool &Changed,
// If either the left or right side is invariant, we can unswitch on this,
// which will cause the branch to go away in one loop and the condition to
// simplify in the other one.
- if (Value *LHS = FindLIVLoopCondition(BO->getOperand(0), L, Changed,
+ if (Value *LHS = findLIVLoopCondition(BO->getOperand(0), L, Changed,
ParentChain, Cache, MSSAU)) {
Cache[Cond] = LHS;
return LHS;
@@ -487,7 +485,7 @@ static Value *FindLIVLoopCondition(Value *Cond, Loop *L, bool &Changed,
// We did not manage to find a partial LIV in operand(0). Backtrack and try
// operand(1).
ParentChain = NewChain;
- if (Value *RHS = FindLIVLoopCondition(BO->getOperand(1), L, Changed,
+ if (Value *RHS = findLIVLoopCondition(BO->getOperand(1), L, Changed,
ParentChain, Cache, MSSAU)) {
Cache[Cond] = RHS;
return RHS;
@@ -503,11 +501,11 @@ static Value *FindLIVLoopCondition(Value *Cond, Loop *L, bool &Changed,
/// an invariant piece, return the invariant along with the operator chain type.
/// Otherwise, return null.
static std::pair<Value *, OperatorChain>
-FindLIVLoopCondition(Value *Cond, Loop *L, bool &Changed,
+findLIVLoopCondition(Value *Cond, Loop *L, bool &Changed,
MemorySSAUpdater *MSSAU) {
DenseMap<Value *, Value *> Cache;
OperatorChain OpChain = OC_OpChainNone;
- Value *FCond = FindLIVLoopCondition(Cond, L, Changed, OpChain, Cache, MSSAU);
+ Value *FCond = findLIVLoopCondition(Cond, L, Changed, OpChain, Cache, MSSAU);
// In case we do find a LIV, it can not be obtained by walking up a mixed
// operator chain.
@@ -516,22 +514,22 @@ FindLIVLoopCondition(Value *Cond, Loop *L, bool &Changed,
return {FCond, OpChain};
}
-bool LoopUnswitch::runOnLoop(Loop *L, LPPassManager &LPM_Ref) {
+bool LoopUnswitch::runOnLoop(Loop *L, LPPassManager &LPMRef) {
if (skipLoop(L))
return false;
AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(
*L->getHeader()->getParent());
LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
- LPM = &LPM_Ref;
+ LPM = &LPMRef;
DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
if (EnableMSSALoopDependency) {
MSSA = &getAnalysis<MemorySSAWrapperPass>().getMSSA();
MSSAU = std::make_unique<MemorySSAUpdater>(MSSA);
assert(DT && "Cannot update MemorySSA without a valid DomTree.");
}
- currentLoop = L;
- Function *F = currentLoop->getHeader()->getParent();
+ CurrentLoop = L;
+ Function *F = CurrentLoop->getHeader()->getParent();
SanitizeMemory = F->hasFnAttribute(Attribute::SanitizeMemory);
if (SanitizeMemory)
@@ -542,12 +540,12 @@ bool LoopUnswitch::runOnLoop(Loop *L, LPPassManager &LPM_Ref) {
bool Changed = false;
do {
- assert(currentLoop->isLCSSAForm(*DT));
+ assert(CurrentLoop->isLCSSAForm(*DT));
if (MSSA && VerifyMemorySSA)
MSSA->verifyMemorySSA();
- redoLoop = false;
+ RedoLoop = false;
Changed |= processCurrentLoop();
- } while(redoLoop);
+ } while (RedoLoop);
if (MSSA && VerifyMemorySSA)
MSSA->verifyMemorySSA();
@@ -560,7 +558,7 @@ bool LoopUnswitch::runOnLoop(Loop *L, LPPassManager &LPM_Ref) {
bool LoopUnswitch::isUnreachableDueToPreviousUnswitching(BasicBlock *BB) {
auto *Node = DT->getNode(BB)->getIDom();
BasicBlock *DomBB = Node->getBlock();
- while (currentLoop->contains(DomBB)) {
+ while (CurrentLoop->contains(DomBB)) {
BranchInst *BInst = dyn_cast<BranchInst>(DomBB->getTerminator());
Node = DT->getNode(DomBB)->getIDom();
@@ -591,7 +589,7 @@ bool LoopUnswitch::isUnreachableDueToPreviousUnswitching(BasicBlock *BB) {
/// causing problems. Detail could be found in PR31652. Note if the
/// func returns true, it is unsafe. But if it is false, it doesn't mean
/// it is necessarily safe.
-static bool EqualityPropUnSafe(Value &LoopCond) {
+static bool equalityPropUnSafe(Value &LoopCond) {
ICmpInst *CI = dyn_cast<ICmpInst>(&LoopCond);
if (!CI || !CI->isEquality())
return false;
@@ -601,7 +599,7 @@ static bool EqualityPropUnSafe(Value &LoopCond) {
if (isa<UndefValue>(LHS) || isa<UndefValue>(RHS))
return true;
- auto hasUndefInPHI = [](PHINode &PN) {
+ auto HasUndefInPHI = [](PHINode &PN) {
for (Value *Opd : PN.incoming_values()) {
if (isa<UndefValue>(Opd))
return true;
@@ -610,10 +608,10 @@ static bool EqualityPropUnSafe(Value &LoopCond) {
};
PHINode *LPHI = dyn_cast<PHINode>(LHS);
PHINode *RPHI = dyn_cast<PHINode>(RHS);
- if ((LPHI && hasUndefInPHI(*LPHI)) || (RPHI && hasUndefInPHI(*RPHI)))
+ if ((LPHI && HasUndefInPHI(*LPHI)) || (RPHI && HasUndefInPHI(*RPHI)))
return true;
- auto hasUndefInSelect = [](SelectInst &SI) {
+ auto HasUndefInSelect = [](SelectInst &SI) {
if (isa<UndefValue>(SI.getTrueValue()) ||
isa<UndefValue>(SI.getFalseValue()))
return true;
@@ -621,7 +619,7 @@ static bool EqualityPropUnSafe(Value &LoopCond) {
};
SelectInst *LSI = dyn_cast<SelectInst>(LHS);
SelectInst *RSI = dyn_cast<SelectInst>(RHS);
- if ((LSI && hasUndefInSelect(*LSI)) || (RSI && hasUndefInSelect(*RSI)))
+ if ((LSI && HasUndefInSelect(*LSI)) || (RSI && HasUndefInSelect(*RSI)))
return true;
return false;
}
@@ -633,35 +631,36 @@ bool LoopUnswitch::processCurrentLoop() {
initLoopData();
// If LoopSimplify was unable to form a preheader, don't do any unswitching.
- if (!loopPreheader)
+ if (!LoopPreheader)
return false;
// Loops with indirectbr cannot be cloned.
- if (!currentLoop->isSafeToClone())
+ if (!CurrentLoop->isSafeToClone())
return false;
// Without dedicated exits, splitting the exit edge may fail.
- if (!currentLoop->hasDedicatedExits())
+ if (!CurrentLoop->hasDedicatedExits())
return false;
- LLVMContext &Context = loopHeader->getContext();
+ LLVMContext &Context = LoopHeader->getContext();
// Analyze loop cost, and stop unswitching if loop content can not be duplicated.
if (!BranchesInfo.countLoop(
- currentLoop, getAnalysis<TargetTransformInfoWrapperPass>().getTTI(
- *currentLoop->getHeader()->getParent()),
+ CurrentLoop,
+ getAnalysis<TargetTransformInfoWrapperPass>().getTTI(
+ *CurrentLoop->getHeader()->getParent()),
AC))
return false;
// Try trivial unswitch first before loop over other basic blocks in the loop.
- if (TryTrivialLoopUnswitch(Changed)) {
+ if (tryTrivialLoopUnswitch(Changed)) {
return true;
}
// Do not do non-trivial unswitch while optimizing for size.
// FIXME: Use Function::hasOptSize().
if (OptimizeForSize ||
- loopHeader->getParent()->hasFnAttribute(Attribute::OptimizeForSize))
+ LoopHeader->getParent()->hasFnAttribute(Attribute::OptimizeForSize))
return false;
// Run through the instructions in the loop, keeping track of three things:
@@ -680,11 +679,12 @@ bool LoopUnswitch::processCurrentLoop() {
SmallVector<IntrinsicInst *, 4> Guards;
- for (const auto BB : currentLoop->blocks()) {
+ for (const auto BB : CurrentLoop->blocks()) {
for (auto &I : *BB) {
- auto CS = CallSite(&I);
- if (!CS) continue;
- if (CS.isConvergent())
+ auto *CB = dyn_cast<CallBase>(&I);
+ if (!CB)
+ continue;
+ if (CB->isConvergent())
return false;
if (auto *II = dyn_cast<InvokeInst>(&I))
if (!II->getUnwindDest()->canSplitPredecessors())
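
CallSite is replaced by CallBase, the common base class of CallInst, InvokeInst and CallBrInst, so one dyn_cast now covers every call-like instruction. A self-contained sketch of the legality check the loop above performs (the helper name is illustrative):

    static bool containsConvergentCalls(const BasicBlock &BB) {
      for (const Instruction &I : BB)
        if (const auto *CB = dyn_cast<CallBase>(&I))
          // Unswitching clones the loop and introduces new control
          // dependencies, which is not legal around convergent calls.
          if (CB->isConvergent())
            return true;
      return false;
    }
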
@@ -696,11 +696,11 @@ bool LoopUnswitch::processCurrentLoop() {
}
for (IntrinsicInst *Guard : Guards) {
- Value *LoopCond = FindLIVLoopCondition(Guard->getOperand(0), currentLoop,
+ Value *LoopCond = findLIVLoopCondition(Guard->getOperand(0), CurrentLoop,
Changed, MSSAU.get())
.first;
if (LoopCond &&
- UnswitchIfProfitable(LoopCond, ConstantInt::getTrue(Context))) {
+ unswitchIfProfitable(LoopCond, ConstantInt::getTrue(Context))) {
// NB! Unswitching (if successful) could have erased some of the
// instructions in Guards leaving dangling pointers there. This is fine
// because we're returning now, and won't look at Guards again.
@@ -712,8 +712,9 @@ bool LoopUnswitch::processCurrentLoop() {
// Loop over all of the basic blocks in the loop. If we find an interior
// block that is branching on a loop-invariant condition, we can unswitch this
// loop.
- for (Loop::block_iterator I = currentLoop->block_begin(),
- E = currentLoop->block_end(); I != E; ++I) {
+ for (Loop::block_iterator I = CurrentLoop->block_begin(),
+ E = CurrentLoop->block_end();
+ I != E; ++I) {
Instruction *TI = (*I)->getTerminator();
// Unswitching on a potentially uninitialized predicate is not
@@ -723,7 +724,7 @@ bool LoopUnswitch::processCurrentLoop() {
// This is a workaround for the discrepancy between LLVM IR and MSan
// semantics. See PR28054 for more details.
if (SanitizeMemory &&
- !SafetyInfo.isGuaranteedToExecute(*TI, DT, currentLoop))
+ !SafetyInfo.isGuaranteedToExecute(*TI, DT, CurrentLoop))
continue;
if (BranchInst *BI = dyn_cast<BranchInst>(TI)) {
@@ -738,11 +739,11 @@ bool LoopUnswitch::processCurrentLoop() {
if (BI->isConditional()) {
// See if this, or some part of it, is loop invariant. If so, we can
// unswitch on it if we desire.
- Value *LoopCond = FindLIVLoopCondition(BI->getCondition(), currentLoop,
+ Value *LoopCond = findLIVLoopCondition(BI->getCondition(), CurrentLoop,
Changed, MSSAU.get())
.first;
- if (LoopCond && !EqualityPropUnSafe(*LoopCond) &&
- UnswitchIfProfitable(LoopCond, ConstantInt::getTrue(Context), TI)) {
+ if (LoopCond && !equalityPropUnSafe(*LoopCond) &&
+ unswitchIfProfitable(LoopCond, ConstantInt::getTrue(Context), TI)) {
++NumBranches;
return true;
}
@@ -752,7 +753,7 @@ bool LoopUnswitch::processCurrentLoop() {
Value *LoopCond;
OperatorChain OpChain;
std::tie(LoopCond, OpChain) =
- FindLIVLoopCondition(SC, currentLoop, Changed, MSSAU.get());
+ findLIVLoopCondition(SC, CurrentLoop, Changed, MSSAU.get());
unsigned NumCases = SI->getNumCases();
if (LoopCond && NumCases) {
@@ -796,7 +797,7 @@ bool LoopUnswitch::processCurrentLoop() {
if (!UnswitchVal)
continue;
- if (UnswitchIfProfitable(LoopCond, UnswitchVal)) {
+ if (unswitchIfProfitable(LoopCond, UnswitchVal)) {
++NumSwitches;
// In case of a full LIV, UnswitchVal is the value we unswitched out.
// In case of a partial LIV, we only unswitch when its an AND-chain
@@ -812,11 +813,11 @@ bool LoopUnswitch::processCurrentLoop() {
for (BasicBlock::iterator BBI = (*I)->begin(), E = (*I)->end();
BBI != E; ++BBI)
if (SelectInst *SI = dyn_cast<SelectInst>(BBI)) {
- Value *LoopCond = FindLIVLoopCondition(SI->getCondition(), currentLoop,
+ Value *LoopCond = findLIVLoopCondition(SI->getCondition(), CurrentLoop,
Changed, MSSAU.get())
.first;
- if (LoopCond && UnswitchIfProfitable(LoopCond,
- ConstantInt::getTrue(Context))) {
+ if (LoopCond &&
+ unswitchIfProfitable(LoopCond, ConstantInt::getTrue(Context))) {
++NumSelects;
return true;
}
@@ -875,62 +876,38 @@ static BasicBlock *isTrivialLoopExitBlock(Loop *L, BasicBlock *BB) {
return nullptr;
}
-/// We have found that we can unswitch currentLoop when LoopCond == Val to
+/// We have found that we can unswitch CurrentLoop when LoopCond == Val to
/// simplify the loop. If we decide that this is profitable,
/// unswitch the loop, reprocess the pieces, then return true.
-bool LoopUnswitch::UnswitchIfProfitable(Value *LoopCond, Constant *Val,
+bool LoopUnswitch::unswitchIfProfitable(Value *LoopCond, Constant *Val,
Instruction *TI) {
// Check to see if it would be profitable to unswitch current loop.
- if (!BranchesInfo.CostAllowsUnswitching()) {
+ if (!BranchesInfo.costAllowsUnswitching()) {
LLVM_DEBUG(dbgs() << "NOT unswitching loop %"
- << currentLoop->getHeader()->getName()
+ << CurrentLoop->getHeader()->getName()
<< " at non-trivial condition '" << *Val
<< "' == " << *LoopCond << "\n"
<< ". Cost too high.\n");
return false;
}
- if (hasBranchDivergence &&
+ if (HasBranchDivergence &&
getAnalysis<LegacyDivergenceAnalysis>().isDivergent(LoopCond)) {
LLVM_DEBUG(dbgs() << "NOT unswitching loop %"
- << currentLoop->getHeader()->getName()
+ << CurrentLoop->getHeader()->getName()
<< " at non-trivial condition '" << *Val
<< "' == " << *LoopCond << "\n"
<< ". Condition is divergent.\n");
return false;
}
- UnswitchNontrivialCondition(LoopCond, Val, currentLoop, TI);
+ unswitchNontrivialCondition(LoopCond, Val, CurrentLoop, TI);
return true;
}
-/// Recursively clone the specified loop and all of its children,
-/// mapping the blocks with the specified map.
-static Loop *CloneLoop(Loop *L, Loop *PL, ValueToValueMapTy &VM,
- LoopInfo *LI, LPPassManager *LPM) {
- Loop &New = *LI->AllocateLoop();
- if (PL)
- PL->addChildLoop(&New);
- else
- LI->addTopLevelLoop(&New);
- LPM->addLoop(New);
-
- // Add all of the blocks in L to the new loop.
- for (Loop::block_iterator I = L->block_begin(), E = L->block_end();
- I != E; ++I)
- if (LI->getLoopFor(*I) == L)
- New.addBasicBlockToLoop(cast<BasicBlock>(VM[*I]), *LI);
-
- // Add all of the subloops to the new loop.
- for (Loop *I : *L)
- CloneLoop(I, &New, VM, LI, LPM);
-
- return &New;
-}
-
/// Emit a conditional branch on two values if LIC == Val, branch to TrueDst,
/// otherwise branch to FalseDest. Insert the code immediately before OldBranch
/// and remove (but not erase!) it from the function.
-void LoopUnswitch::EmitPreheaderBranchOnCondition(Value *LIC, Constant *Val,
+void LoopUnswitch::emitPreheaderBranchOnCondition(Value *LIC, Constant *Val,
BasicBlock *TrueDest,
BasicBlock *FalseDest,
BranchInst *OldBranch,
@@ -997,11 +974,11 @@ void LoopUnswitch::EmitPreheaderBranchOnCondition(Value *LIC, Constant *Val,
/// that doesn't execute its body has no side-effects), unswitch it. This
/// doesn't involve any code duplication, just moving the conditional branch
/// outside of the loop and updating loop info.
-void LoopUnswitch::UnswitchTrivialCondition(Loop *L, Value *Cond, Constant *Val,
+void LoopUnswitch::unswitchTrivialCondition(Loop *L, Value *Cond, Constant *Val,
BasicBlock *ExitBlock,
Instruction *TI) {
LLVM_DEBUG(dbgs() << "loop-unswitch: Trivial-Unswitch loop %"
- << loopHeader->getName() << " [" << L->getBlocks().size()
+ << LoopHeader->getName() << " [" << L->getBlocks().size()
<< " blocks] in Function "
<< L->getHeader()->getParent()->getName()
<< " on cond: " << *Val << " == " << *Cond << "\n");
@@ -1011,9 +988,9 @@ void LoopUnswitch::UnswitchTrivialCondition(Loop *L, Value *Cond, Constant *Val,
SEWP->getSE().forgetTopmostLoop(L);
// First step, split the preheader, so that we know that there is a safe place
- // to insert the conditional branch. We will change loopPreheader to have a
+ // to insert the conditional branch. We will change LoopPreheader to have a
// conditional branch on Cond.
- BasicBlock *NewPH = SplitEdge(loopPreheader, loopHeader, DT, LI, MSSAU.get());
+ BasicBlock *NewPH = SplitEdge(LoopPreheader, LoopHeader, DT, LI, MSSAU.get());
// Now that we have a place to insert the conditional branch, create a place
// to branch to: this is the exit block out of the loop that we should
@@ -1029,22 +1006,21 @@ void LoopUnswitch::UnswitchTrivialCondition(Loop *L, Value *Cond, Constant *Val,
// Okay, now we have a position to branch from and a position to branch to,
// insert the new conditional branch.
- auto *OldBranch = dyn_cast<BranchInst>(loopPreheader->getTerminator());
+ auto *OldBranch = dyn_cast<BranchInst>(LoopPreheader->getTerminator());
assert(OldBranch && "Failed to split the preheader");
- EmitPreheaderBranchOnCondition(Cond, Val, NewExit, NewPH, OldBranch, TI);
- LPM->deleteSimpleAnalysisValue(OldBranch, L);
+ emitPreheaderBranchOnCondition(Cond, Val, NewExit, NewPH, OldBranch, TI);
- // EmitPreheaderBranchOnCondition removed the OldBranch from the function.
+ // emitPreheaderBranchOnCondition removed the OldBranch from the function.
// Delete it, as it is no longer needed.
delete OldBranch;
// We need to reprocess this loop, it could be unswitched again.
- redoLoop = true;
+ RedoLoop = true;
// Now that we know that the loop is never entered when this condition is a
// particular value, rewrite the loop with this info. We know that this will
// at least eliminate the old branch.
- RewriteLoopBodyWithConditionConstant(L, Cond, Val, false);
+ rewriteLoopBodyWithConditionConstant(L, Cond, Val, /*IsEqual=*/false);
++NumTrivial;
}
@@ -1055,8 +1031,8 @@ void LoopUnswitch::UnswitchTrivialCondition(Loop *L, Value *Cond, Constant *Val,
/// produces no code duplications (equivalently, it produces a simpler loop and
/// a new empty loop, which gets deleted). Therefore always unswitch trivial
/// condition.
-bool LoopUnswitch::TryTrivialLoopUnswitch(bool &Changed) {
- BasicBlock *CurrentBB = currentLoop->getHeader();
+bool LoopUnswitch::tryTrivialLoopUnswitch(bool &Changed) {
+ BasicBlock *CurrentBB = CurrentLoop->getHeader();
Instruction *CurrentTerm = CurrentBB->getTerminator();
LLVMContext &Context = CurrentBB->getContext();
@@ -1081,7 +1057,7 @@ bool LoopUnswitch::TryTrivialLoopUnswitch(bool &Changed) {
// we can not reach any trivial condition candidates (unfoldable
// branch instructions or switch instructions) and no unswitch
// can happen. Exit and return false.
- if (!currentLoop->contains(CurrentBB) || !Visited.insert(CurrentBB).second)
+ if (!CurrentLoop->contains(CurrentBB) || !Visited.insert(CurrentBB).second)
return false;
// Check if this loop will execute any side-effecting instructions (e.g.
@@ -1128,7 +1104,7 @@ bool LoopUnswitch::TryTrivialLoopUnswitch(bool &Changed) {
if (!BI->isConditional())
return false;
- Value *LoopCond = FindLIVLoopCondition(BI->getCondition(), currentLoop,
+ Value *LoopCond = findLIVLoopCondition(BI->getCondition(), CurrentLoop,
Changed, MSSAU.get())
.first;
@@ -1141,11 +1117,11 @@ bool LoopUnswitch::TryTrivialLoopUnswitch(bool &Changed) {
// exit through a unique exit block without having any
// side-effects. If so, determine the value of Cond that causes
// it to do this.
- if ((LoopExitBB = isTrivialLoopExitBlock(currentLoop,
- BI->getSuccessor(0)))) {
+ if ((LoopExitBB =
+ isTrivialLoopExitBlock(CurrentLoop, BI->getSuccessor(0)))) {
CondVal = ConstantInt::getTrue(Context);
- } else if ((LoopExitBB = isTrivialLoopExitBlock(currentLoop,
- BI->getSuccessor(1)))) {
+ } else if ((LoopExitBB =
+ isTrivialLoopExitBlock(CurrentLoop, BI->getSuccessor(1)))) {
CondVal = ConstantInt::getFalse(Context);
}
@@ -1154,16 +1130,16 @@ bool LoopUnswitch::TryTrivialLoopUnswitch(bool &Changed) {
if (!LoopExitBB || isa<PHINode>(LoopExitBB->begin()))
return false; // Can't handle this.
- if (EqualityPropUnSafe(*LoopCond))
+ if (equalityPropUnSafe(*LoopCond))
return false;
- UnswitchTrivialCondition(currentLoop, LoopCond, CondVal, LoopExitBB,
+ unswitchTrivialCondition(CurrentLoop, LoopCond, CondVal, LoopExitBB,
CurrentTerm);
++NumBranches;
return true;
} else if (SwitchInst *SI = dyn_cast<SwitchInst>(CurrentTerm)) {
// If this isn't switching on an invariant condition, we can't unswitch it.
- Value *LoopCond = FindLIVLoopCondition(SI->getCondition(), currentLoop,
+ Value *LoopCond = findLIVLoopCondition(SI->getCondition(), CurrentLoop,
Changed, MSSAU.get())
.first;
@@ -1181,7 +1157,7 @@ bool LoopUnswitch::TryTrivialLoopUnswitch(bool &Changed) {
for (auto Case : SI->cases()) {
BasicBlock *LoopExitCandidate;
if ((LoopExitCandidate =
- isTrivialLoopExitBlock(currentLoop, Case.getCaseSuccessor()))) {
+ isTrivialLoopExitBlock(CurrentLoop, Case.getCaseSuccessor()))) {
// Okay, we found a trivial case, remember the value that is trivial.
ConstantInt *CaseVal = Case.getCaseValue();
@@ -1200,7 +1176,7 @@ bool LoopUnswitch::TryTrivialLoopUnswitch(bool &Changed) {
if (!LoopExitBB || isa<PHINode>(LoopExitBB->begin()))
return false; // Can't handle this.
- UnswitchTrivialCondition(currentLoop, LoopCond, CondVal, LoopExitBB,
+ unswitchTrivialCondition(CurrentLoop, LoopCond, CondVal, LoopExitBB,
nullptr);
// We are only unswitching full LIV.
@@ -1213,11 +1189,11 @@ bool LoopUnswitch::TryTrivialLoopUnswitch(bool &Changed) {
/// Split all of the edges from inside the loop to their exit blocks.
/// Update the appropriate Phi nodes as we do so.
-void LoopUnswitch::SplitExitEdges(Loop *L,
- const SmallVectorImpl<BasicBlock *> &ExitBlocks){
+void LoopUnswitch::splitExitEdges(
+ Loop *L, const SmallVectorImpl<BasicBlock *> &ExitBlocks) {
- for (unsigned i = 0, e = ExitBlocks.size(); i != e; ++i) {
- BasicBlock *ExitBlock = ExitBlocks[i];
+ for (unsigned I = 0, E = ExitBlocks.size(); I != E; ++I) {
+ BasicBlock *ExitBlock = ExitBlocks[I];
SmallVector<BasicBlock *, 4> Preds(pred_begin(ExitBlock),
pred_end(ExitBlock));
@@ -1231,11 +1207,11 @@ void LoopUnswitch::SplitExitEdges(Loop *L,
/// We determined that the loop is profitable to unswitch when LIC equal Val.
/// Split it into loop versions and test the condition outside of either loop.
/// Return the loops created as Out1/Out2.
-void LoopUnswitch::UnswitchNontrivialCondition(Value *LIC, Constant *Val,
+void LoopUnswitch::unswitchNontrivialCondition(Value *LIC, Constant *Val,
Loop *L, Instruction *TI) {
- Function *F = loopHeader->getParent();
+ Function *F = LoopHeader->getParent();
LLVM_DEBUG(dbgs() << "loop-unswitch: Unswitching loop %"
- << loopHeader->getName() << " [" << L->getBlocks().size()
+ << LoopHeader->getName() << " [" << L->getBlocks().size()
<< " blocks] in Function " << F->getName() << " when '"
<< *Val << "' == " << *LIC << "\n");
@@ -1253,7 +1229,7 @@ void LoopUnswitch::UnswitchNontrivialCondition(Value *LIC, Constant *Val,
// First step, split the preheader and exit blocks, and add these blocks to
// the LoopBlocks list.
BasicBlock *NewPreheader =
- SplitEdge(loopPreheader, loopHeader, DT, LI, MSSAU.get());
+ SplitEdge(LoopPreheader, LoopHeader, DT, LI, MSSAU.get());
LoopBlocks.push_back(NewPreheader);
// We want the loop to come after the preheader, but before the exit blocks.
@@ -1264,7 +1240,7 @@ void LoopUnswitch::UnswitchNontrivialCondition(Value *LIC, Constant *Val,
// Split all of the edges from inside the loop to their exit blocks. Update
// the appropriate Phi nodes as we do so.
- SplitExitEdges(L, ExitBlocks);
+ splitExitEdges(L, ExitBlocks);
// The exit blocks may have been changed due to edge splitting, recompute.
ExitBlocks.clear();
@@ -1278,12 +1254,11 @@ void LoopUnswitch::UnswitchNontrivialCondition(Value *LIC, Constant *Val,
// the instructions and blocks.
NewBlocks.reserve(LoopBlocks.size());
ValueToValueMapTy VMap;
- for (unsigned i = 0, e = LoopBlocks.size(); i != e; ++i) {
- BasicBlock *NewBB = CloneBasicBlock(LoopBlocks[i], VMap, ".us", F);
+ for (unsigned I = 0, E = LoopBlocks.size(); I != E; ++I) {
+ BasicBlock *NewBB = CloneBasicBlock(LoopBlocks[I], VMap, ".us", F);
NewBlocks.push_back(NewBB);
- VMap[LoopBlocks[i]] = NewBB; // Keep the BB mapping.
- LPM->cloneBasicBlockSimpleAnalysis(LoopBlocks[i], NewBB, L);
+ VMap[LoopBlocks[I]] = NewBB; // Keep the BB mapping.
}
// Splice the newly inserted blocks into the function right before the
@@ -1293,7 +1268,7 @@ void LoopUnswitch::UnswitchNontrivialCondition(Value *LIC, Constant *Val,
NewBlocks[0]->getIterator(), F->end());
// Now we create the new Loop object for the versioned loop.
- Loop *NewLoop = CloneLoop(L, L->getParentLoop(), VMap, LI, LPM);
+ Loop *NewLoop = cloneLoop(L, L->getParentLoop(), VMap, LI, LPM);
// Recalculate unswitching quota, inherit simplified switches info for NewBB,
// Probably clone more loop-unswitch related loop properties.
@@ -1306,10 +1281,10 @@ void LoopUnswitch::UnswitchNontrivialCondition(Value *LIC, Constant *Val,
ParentLoop->addBasicBlockToLoop(NewBlocks[0], *LI);
}
- for (unsigned i = 0, e = ExitBlocks.size(); i != e; ++i) {
- BasicBlock *NewExit = cast<BasicBlock>(VMap[ExitBlocks[i]]);
+ for (unsigned EBI = 0, EBE = ExitBlocks.size(); EBI != EBE; ++EBI) {
+ BasicBlock *NewExit = cast<BasicBlock>(VMap[ExitBlocks[EBI]]);
// The new exit block should be in the same loop as the old one.
- if (Loop *ExitBBLoop = LI->getLoopFor(ExitBlocks[i]))
+ if (Loop *ExitBBLoop = LI->getLoopFor(ExitBlocks[EBI]))
ExitBBLoop->addBasicBlockToLoop(NewExit, *LI);
assert(NewExit->getTerminator()->getNumSuccessors() == 1 &&
@@ -1319,7 +1294,7 @@ void LoopUnswitch::UnswitchNontrivialCondition(Value *LIC, Constant *Val,
// If the successor of the exit block had PHI nodes, add an entry for
// NewExit.
for (PHINode &PN : ExitSucc->phis()) {
- Value *V = PN.getIncomingValueForBlock(ExitBlocks[i]);
+ Value *V = PN.getIncomingValueForBlock(ExitBlocks[EBI]);
ValueToValueMapTy::iterator It = VMap.find(V);
if (It != VMap.end()) V = It->second;
PN.addIncoming(V, NewExit);
@@ -1340,8 +1315,8 @@ void LoopUnswitch::UnswitchNontrivialCondition(Value *LIC, Constant *Val,
}
// Rewrite the code to refer to itself.
- for (unsigned i = 0, e = NewBlocks.size(); i != e; ++i) {
- for (Instruction &I : *NewBlocks[i]) {
+ for (unsigned NBI = 0, NBE = NewBlocks.size(); NBI != NBE; ++NBI) {
+ for (Instruction &I : *NewBlocks[NBI]) {
RemapInstruction(&I, VMap,
RF_NoModuleLevelChanges | RF_IgnoreMissingLocals);
if (auto *II = dyn_cast<IntrinsicInst>(&I))
@@ -1351,7 +1326,7 @@ void LoopUnswitch::UnswitchNontrivialCondition(Value *LIC, Constant *Val,
}
// Rewrite the original preheader to select between versions of the loop.
- BranchInst *OldBR = cast<BranchInst>(loopPreheader->getTerminator());
+ BranchInst *OldBR = cast<BranchInst>(LoopPreheader->getTerminator());
assert(OldBR->isUnconditional() && OldBR->getSuccessor(0) == LoopBlocks[0] &&
"Preheader splitting did not work correctly!");
@@ -1364,9 +1339,8 @@ void LoopUnswitch::UnswitchNontrivialCondition(Value *LIC, Constant *Val,
}
// Emit the new branch that selects between the two versions of this loop.
- EmitPreheaderBranchOnCondition(LIC, Val, NewBlocks[0], LoopBlocks[0], OldBR,
+ emitPreheaderBranchOnCondition(LIC, Val, NewBlocks[0], LoopBlocks[0], OldBR,
TI);
- LPM->deleteSimpleAnalysisValue(OldBR, L);
if (MSSAU) {
// Update MemoryPhis in Exit blocks.
MSSAU->updateExitBlocksForClonedLoop(ExitBlocks, VMap, *DT);
@@ -1375,11 +1349,11 @@ void LoopUnswitch::UnswitchNontrivialCondition(Value *LIC, Constant *Val,
}
// The OldBr was replaced by a new one and removed (but not erased) by
- // EmitPreheaderBranchOnCondition. It is no longer needed, so delete it.
+ // emitPreheaderBranchOnCondition. It is no longer needed, so delete it.
delete OldBR;
LoopProcessWorklist.push_back(NewLoop);
- redoLoop = true;
+ RedoLoop = true;
// Keep a WeakTrackingVH holding onto LIC. If the first call to
// RewriteLoopBody
@@ -1390,22 +1364,23 @@ void LoopUnswitch::UnswitchNontrivialCondition(Value *LIC, Constant *Val,
// Now we rewrite the original code to know that the condition is true and the
// new code to know that the condition is false.
- RewriteLoopBodyWithConditionConstant(L, LIC, Val, false);
+ rewriteLoopBodyWithConditionConstant(L, LIC, Val, /*IsEqual=*/false);
// It's possible that simplifying one loop could cause the other to be
// changed to another value or a constant. If its a constant, don't simplify
// it.
if (!LoopProcessWorklist.empty() && LoopProcessWorklist.back() == NewLoop &&
LICHandle && !isa<Constant>(LICHandle))
- RewriteLoopBodyWithConditionConstant(NewLoop, LICHandle, Val, true);
+ rewriteLoopBodyWithConditionConstant(NewLoop, LICHandle, Val,
+ /*IsEqual=*/true);
if (MSSA && VerifyMemorySSA)
MSSA->verifyMemorySSA();
}
/// Remove all instances of I from the worklist vector specified.
-static void RemoveFromWorklist(Instruction *I,
- std::vector<Instruction*> &Worklist) {
+static void removeFromWorklist(Instruction *I,
+ std::vector<Instruction *> &Worklist) {
Worklist.erase(std::remove(Worklist.begin(), Worklist.end(), I),
Worklist.end());
@@ -1413,7 +1388,7 @@ static void RemoveFromWorklist(Instruction *I,
/// When we find that I really equals V, remove I from the
/// program, replacing all uses with V and update the worklist.
-static void ReplaceUsesOfWith(Instruction *I, Value *V,
+static void replaceUsesOfWith(Instruction *I, Value *V,
std::vector<Instruction *> &Worklist, Loop *L,
LPPassManager *LPM, MemorySSAUpdater *MSSAU) {
LLVM_DEBUG(dbgs() << "Replace with '" << *V << "': " << *I << "\n");
@@ -1426,8 +1401,7 @@ static void ReplaceUsesOfWith(Instruction *I, Value *V,
// Add users to the worklist which may be simplified now.
for (User *U : I->users())
Worklist.push_back(cast<Instruction>(U));
- LPM->deleteSimpleAnalysisValue(I, L);
- RemoveFromWorklist(I, Worklist);
+ removeFromWorklist(I, Worklist);
I->replaceAllUsesWith(V);
if (!I->mayHaveSideEffects()) {
if (MSSAU)
@@ -1440,7 +1414,7 @@ static void ReplaceUsesOfWith(Instruction *I, Value *V,
/// We know either that the value LIC has the value specified by Val in the
/// specified loop, or we know it does NOT have that value.
/// Rewrite any uses of LIC or of properties correlated to it.
-void LoopUnswitch::RewriteLoopBodyWithConditionConstant(Loop *L, Value *LIC,
+void LoopUnswitch::rewriteLoopBodyWithConditionConstant(Loop *L, Value *LIC,
Constant *Val,
bool IsEqual) {
assert(!isa<Constant>(LIC) && "Why are we unswitching on a constant?");
@@ -1478,7 +1452,7 @@ void LoopUnswitch::RewriteLoopBodyWithConditionConstant(Loop *L, Value *LIC,
for (Instruction *UI : Worklist)
UI->replaceUsesOfWith(LIC, Replacement);
- SimplifyCode(Worklist, L);
+ simplifyCode(Worklist, L);
return;
}
@@ -1492,7 +1466,7 @@ void LoopUnswitch::RewriteLoopBodyWithConditionConstant(Loop *L, Value *LIC,
// At this point, we know LIC is definitely not Val. Try to use some simple
// logic to simplify the user w.r.t. to the context.
- if (Value *Replacement = SimplifyInstructionWithNotEqual(UI, LIC, Val)) {
+ if (Value *Replacement = simplifyInstructionWithNotEqual(UI, LIC, Val)) {
if (LI->replacementPreservesLCSSAForm(UI, Replacement)) {
// This in-loop instruction has been simplified w.r.t. its context,
// i.e. LIC != Val, make sure we propagate its replacement value to
@@ -1506,7 +1480,7 @@ void LoopUnswitch::RewriteLoopBodyWithConditionConstant(Loop *L, Value *LIC,
}
}
- // This is a LIC user, push it into the worklist so that SimplifyCode can
+ // This is a LIC user, push it into the worklist so that simplifyCode can
// attempt to simplify it.
Worklist.push_back(UI);
@@ -1568,7 +1542,7 @@ void LoopUnswitch::RewriteLoopBodyWithConditionConstant(Loop *L, Value *LIC,
DT->addNewBlock(Abort, NewSISucc);
}
- SimplifyCode(Worklist, L);
+ simplifyCode(Worklist, L);
}
/// Now that we have simplified some instructions in the loop, walk over it and
@@ -1579,7 +1553,7 @@ void LoopUnswitch::RewriteLoopBodyWithConditionConstant(Loop *L, Value *LIC,
/// FIXME: When the loop optimizer is more mature, separate this out to a new
/// pass.
///
-void LoopUnswitch::SimplifyCode(std::vector<Instruction*> &Worklist, Loop *L) {
+void LoopUnswitch::simplifyCode(std::vector<Instruction *> &Worklist, Loop *L) {
const DataLayout &DL = L->getHeader()->getModule()->getDataLayout();
while (!Worklist.empty()) {
Instruction *I = Worklist.back();
@@ -1593,8 +1567,7 @@ void LoopUnswitch::SimplifyCode(std::vector<Instruction*> &Worklist, Loop *L) {
for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i)
if (Instruction *Use = dyn_cast<Instruction>(I->getOperand(i)))
Worklist.push_back(Use);
- LPM->deleteSimpleAnalysisValue(I, L);
- RemoveFromWorklist(I, Worklist);
+ removeFromWorklist(I, Worklist);
if (MSSAU)
MSSAU->removeMemoryAccess(I);
I->eraseFromParent();
@@ -1607,7 +1580,7 @@ void LoopUnswitch::SimplifyCode(std::vector<Instruction*> &Worklist, Loop *L) {
// 'false'. TODO: update the domtree properly so we can pass it here.
if (Value *V = SimplifyInstruction(I, DL))
if (LI->replacementPreservesLCSSAForm(I, V)) {
- ReplaceUsesOfWith(I, V, Worklist, L, LPM, MSSAU.get());
+ replaceUsesOfWith(I, V, Worklist, L, LPM, MSSAU.get());
continue;
}
@@ -1624,9 +1597,7 @@ void LoopUnswitch::SimplifyCode(std::vector<Instruction*> &Worklist, Loop *L) {
assert(SinglePred == Pred && "CFG broken");
// Make the LPM and Worklist updates specific to LoopUnswitch.
- LPM->deleteSimpleAnalysisValue(BI, L);
- RemoveFromWorklist(BI, Worklist);
- LPM->deleteSimpleAnalysisValue(Succ, L);
+ removeFromWorklist(BI, Worklist);
auto SuccIt = Succ->begin();
while (PHINode *PN = dyn_cast<PHINode>(SuccIt++)) {
for (unsigned It = 0, E = PN->getNumOperands(); It != E; ++It)
@@ -1634,8 +1605,7 @@ void LoopUnswitch::SimplifyCode(std::vector<Instruction*> &Worklist, Loop *L) {
Worklist.push_back(Use);
for (User *U : PN->users())
Worklist.push_back(cast<Instruction>(U));
- LPM->deleteSimpleAnalysisValue(PN, L);
- RemoveFromWorklist(PN, Worklist);
+ removeFromWorklist(PN, Worklist);
++NumSimplify;
}
// Merge the block and make the remaining analyses updates.
@@ -1652,7 +1622,7 @@ void LoopUnswitch::SimplifyCode(std::vector<Instruction*> &Worklist, Loop *L) {
/// Simple simplifications we can do given the information that Cond is
/// definitely not equal to Val.
-Value *LoopUnswitch::SimplifyInstructionWithNotEqual(Instruction *Inst,
+Value *LoopUnswitch::simplifyInstructionWithNotEqual(Instruction *Inst,
Value *Invariant,
Constant *Val) {
// icmp eq cond, val -> false
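A minimal source-level sketch of the transformation performed by unswitchNontrivialCondition, assuming a simple loop guarded by a loop-invariant flag (illustrative C++ only, not code from this pass): the loop is cloned and the invariant test is hoisted into the preheader, selecting between the two specialized loop versions.

// Before: the branch on the loop-invariant condition sits inside the loop.
void before(int *A, int N, bool Flag) {
  for (int I = 0; I < N; ++I) {
    if (Flag)        // loop-invariant
      A[I] += 1;
    else
      A[I] -= 1;
  }
}

// After non-trivial unswitching: two loop versions, selected once up front.
void after(int *A, int N, bool Flag) {
  if (Flag) {
    for (int I = 0; I < N; ++I)
      A[I] += 1;
  } else {
    for (int I = 0; I < N; ++I)
      A[I] -= 1;
  }
}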
diff --git a/llvm/lib/Transforms/Scalar/LoopVersioningLICM.cpp b/llvm/lib/Transforms/Scalar/LoopVersioningLICM.cpp
index 7b9af527d444..06b684ef1e70 100644
--- a/llvm/lib/Transforms/Scalar/LoopVersioningLICM.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopVersioningLICM.cpp
@@ -69,7 +69,6 @@
#include "llvm/Analysis/LoopPass.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/ScalarEvolution.h"
-#include "llvm/IR/CallSite.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Instruction.h"
diff --git a/llvm/lib/Transforms/Scalar/LowerAtomic.cpp b/llvm/lib/Transforms/Scalar/LowerAtomic.cpp
index ab7b85e89e7b..d1f67b355b19 100644
--- a/llvm/lib/Transforms/Scalar/LowerAtomic.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerAtomic.cpp
@@ -117,18 +117,17 @@ static bool LowerStoreInst(StoreInst *SI) {
static bool runOnBasicBlock(BasicBlock &BB) {
bool Changed = false;
- for (BasicBlock::iterator DI = BB.begin(), DE = BB.end(); DI != DE;) {
- Instruction *Inst = &*DI++;
- if (FenceInst *FI = dyn_cast<FenceInst>(Inst))
+ for (Instruction &Inst : make_early_inc_range(BB)) {
+ if (FenceInst *FI = dyn_cast<FenceInst>(&Inst))
Changed |= LowerFenceInst(FI);
- else if (AtomicCmpXchgInst *CXI = dyn_cast<AtomicCmpXchgInst>(Inst))
+ else if (AtomicCmpXchgInst *CXI = dyn_cast<AtomicCmpXchgInst>(&Inst))
Changed |= LowerAtomicCmpXchgInst(CXI);
- else if (AtomicRMWInst *RMWI = dyn_cast<AtomicRMWInst>(Inst))
+ else if (AtomicRMWInst *RMWI = dyn_cast<AtomicRMWInst>(&Inst))
Changed |= LowerAtomicRMWInst(RMWI);
- else if (LoadInst *LI = dyn_cast<LoadInst>(Inst)) {
+ else if (LoadInst *LI = dyn_cast<LoadInst>(&Inst)) {
if (LI->isAtomic())
LowerLoadInst(LI);
- } else if (StoreInst *SI = dyn_cast<StoreInst>(Inst)) {
+ } else if (StoreInst *SI = dyn_cast<StoreInst>(&Inst)) {
if (SI->isAtomic())
LowerStoreInst(SI);
}
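The switch to make_early_inc_range above is the usual LLVM idiom for mutating a block while walking it; a minimal sketch of the equivalent hand-written loop, assuming the same LLVM types as the surrounding code:

// The iterator is advanced before the body runs, so the body may lower and
// erase the current instruction without invalidating the traversal.
for (BasicBlock::iterator It = BB.begin(), End = BB.end(); It != End;) {
  Instruction &Inst = *It++; // step past Inst before it can be erased
  // ... dispatch on the instruction kind and lower it ...
}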
diff --git a/llvm/lib/Transforms/Scalar/LowerConstantIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerConstantIntrinsics.cpp
index 21c6c32e8e02..fddf28c281fc 100644
--- a/llvm/lib/Transforms/Scalar/LowerConstantIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerConstantIntrinsics.cpp
@@ -13,7 +13,9 @@
#include "llvm/Transforms/Scalar/LowerConstantIntrinsics.h"
#include "llvm/ADT/PostOrderIterator.h"
+#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/Statistic.h"
+#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/Analysis/InstructionSimplify.h"
#include "llvm/Analysis/MemoryBuiltins.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
@@ -135,8 +137,12 @@ static bool lowerConstantIntrinsics(Function &F, const TargetLibraryInfo *TLI) {
PreservedAnalyses
LowerConstantIntrinsicsPass::run(Function &F, FunctionAnalysisManager &AM) {
- if (lowerConstantIntrinsics(F, AM.getCachedResult<TargetLibraryAnalysis>(F)))
- return PreservedAnalyses::none();
+ if (lowerConstantIntrinsics(F,
+ AM.getCachedResult<TargetLibraryAnalysis>(F))) {
+ PreservedAnalyses PA;
+ PA.preserve<GlobalsAA>();
+ return PA;
+ }
return PreservedAnalyses::all();
}
@@ -145,7 +151,7 @@ namespace {
/// Legacy pass for lowering is.constant intrinsics out of the IR.
///
/// When this pass is run over a function it converts is.constant intrinsics
-/// into 'true' or 'false'. This is completements the normal constand folding
+/// into 'true' or 'false'. This complements the normal constant folding
/// to 'true' as part of Instruction Simplify passes.
class LowerConstantIntrinsics : public FunctionPass {
public:
@@ -159,6 +165,10 @@ public:
const TargetLibraryInfo *TLI = TLIP ? &TLIP->getTLI(F) : nullptr;
return lowerConstantIntrinsics(F, TLI);
}
+
+ void getAnalysisUsage(AnalysisUsage &AU) const override {
+ AU.addPreserved<GlobalsAAWrapperPass>();
+ }
};
} // namespace
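For context, a minimal sketch of the source construct that reaches this pass as llvm.is.constant, assuming Clang's __builtin_constant_p builtin (names are illustrative):

int f(int X) {
  // Emitted as a call to llvm.is.constant.i32; any calls that survive earlier
  // optimization are folded to 'true' or 'false' by this pass.
  if (__builtin_constant_p(X))
    return X * 2;   // taken only when X folds to a constant
  return X + 1;
}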
diff --git a/llvm/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp b/llvm/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp
index 53671c7bc3d1..0fe7dd9cfb39 100644
--- a/llvm/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp
@@ -55,13 +55,35 @@ static cl::opt<uint32_t> UnlikelyBranchWeight(
"unlikely-branch-weight", cl::Hidden, cl::init(1),
cl::desc("Weight of the branch unlikely to be taken (default = 1)"));
+static std::tuple<uint32_t, uint32_t>
+getBranchWeight(Intrinsic::ID IntrinsicID, CallInst *CI, int BranchCount) {
+ if (IntrinsicID == Intrinsic::expect) {
+ // __builtin_expect
+ return std::make_tuple(LikelyBranchWeight.getValue(),
+ UnlikelyBranchWeight.getValue());
+ } else {
+ // __builtin_expect_with_probability
+ assert(CI->getNumOperands() >= 3 &&
+ "expect with probability must have 3 arguments");
+ ConstantFP *Confidence = dyn_cast<ConstantFP>(CI->getArgOperand(2));
+ double TrueProb = Confidence->getValueAPF().convertToDouble();
+ assert((TrueProb >= 0.0 && TrueProb <= 1.0) &&
+ "probability value must be in the range [0.0, 1.0]");
+ double FalseProb = (1.0 - TrueProb) / (BranchCount - 1);
+ uint32_t LikelyBW = ceil((TrueProb * (double)(INT32_MAX - 1)) + 1.0);
+ uint32_t UnlikelyBW = ceil((FalseProb * (double)(INT32_MAX - 1)) + 1.0);
+ return std::make_tuple(LikelyBW, UnlikelyBW);
+ }
+}
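// Worked instance of the computation above, assuming
// __builtin_expect_with_probability(cond, 1, 0.9) on a two-way branch:
// TrueProb = 0.9 and FalseProb = (1.0 - 0.9) / (2 - 1) = 0.1, giving
// approximately
//   LikelyBW   = ceil(0.9 * (INT32_MAX - 1) + 1.0) ~= 1932735283
//   UnlikelyBW = ceil(0.1 * (INT32_MAX - 1) + 1.0) ~= 214748366
// i.e. roughly a 9:1 split in the resulting branch_weights metadata.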
+
static bool handleSwitchExpect(SwitchInst &SI) {
CallInst *CI = dyn_cast<CallInst>(SI.getCondition());
if (!CI)
return false;
Function *Fn = CI->getCalledFunction();
- if (!Fn || Fn->getIntrinsicID() != Intrinsic::expect)
+ if (!Fn || (Fn->getIntrinsicID() != Intrinsic::expect &&
+ Fn->getIntrinsicID() != Intrinsic::expect_with_probability))
return false;
Value *ArgValue = CI->getArgOperand(0);
@@ -71,15 +93,19 @@ static bool handleSwitchExpect(SwitchInst &SI) {
SwitchInst::CaseHandle Case = *SI.findCaseValue(ExpectedValue);
unsigned n = SI.getNumCases(); // +1 for default case.
- SmallVector<uint32_t, 16> Weights(n + 1, UnlikelyBranchWeight);
+ uint32_t LikelyBranchWeightVal, UnlikelyBranchWeightVal;
+ std::tie(LikelyBranchWeightVal, UnlikelyBranchWeightVal) =
+ getBranchWeight(Fn->getIntrinsicID(), CI, n + 1);
+
+ SmallVector<uint32_t, 16> Weights(n + 1, UnlikelyBranchWeightVal);
uint64_t Index = (Case == *SI.case_default()) ? 0 : Case.getCaseIndex() + 1;
- Weights[Index] = LikelyBranchWeight;
+ Weights[Index] = LikelyBranchWeightVal;
- SI.setMetadata(
- LLVMContext::MD_misexpect,
- MDBuilder(CI->getContext())
- .createMisExpect(Index, LikelyBranchWeight, UnlikelyBranchWeight));
+ SI.setMetadata(LLVMContext::MD_misexpect,
+ MDBuilder(CI->getContext())
+ .createMisExpect(Index, LikelyBranchWeightVal,
+ UnlikelyBranchWeightVal));
SI.setCondition(ArgValue);
misexpect::checkFrontendInstrumentation(SI);
@@ -223,15 +249,18 @@ static void handlePhiDef(CallInst *Expect) {
return true;
return false;
};
+ uint32_t LikelyBranchWeightVal, UnlikelyBranchWeightVal;
+ std::tie(LikelyBranchWeightVal, UnlikelyBranchWeightVal) = getBranchWeight(
+ Expect->getCalledFunction()->getIntrinsicID(), Expect, 2);
if (IsOpndComingFromSuccessor(BI->getSuccessor(1)))
- BI->setMetadata(
- LLVMContext::MD_prof,
- MDB.createBranchWeights(LikelyBranchWeight, UnlikelyBranchWeight));
+ BI->setMetadata(LLVMContext::MD_prof,
+ MDB.createBranchWeights(LikelyBranchWeightVal,
+ UnlikelyBranchWeightVal));
else if (IsOpndComingFromSuccessor(BI->getSuccessor(0)))
- BI->setMetadata(
- LLVMContext::MD_prof,
- MDB.createBranchWeights(UnlikelyBranchWeight, LikelyBranchWeight));
+ BI->setMetadata(LLVMContext::MD_prof,
+ MDB.createBranchWeights(UnlikelyBranchWeightVal,
+ LikelyBranchWeightVal));
}
}
@@ -277,7 +306,8 @@ template <class BrSelInst> static bool handleBrSelExpect(BrSelInst &BSI) {
}
Function *Fn = CI->getCalledFunction();
- if (!Fn || Fn->getIntrinsicID() != Intrinsic::expect)
+ if (!Fn || (Fn->getIntrinsicID() != Intrinsic::expect &&
+ Fn->getIntrinsicID() != Intrinsic::expect_with_probability))
return false;
Value *ArgValue = CI->getArgOperand(0);
@@ -289,13 +319,21 @@ template <class BrSelInst> static bool handleBrSelExpect(BrSelInst &BSI) {
MDNode *Node;
MDNode *ExpNode;
+ uint32_t LikelyBranchWeightVal, UnlikelyBranchWeightVal;
+ std::tie(LikelyBranchWeightVal, UnlikelyBranchWeightVal) =
+ getBranchWeight(Fn->getIntrinsicID(), CI, 2);
+
if ((ExpectedValue->getZExtValue() == ValueComparedTo) ==
(Predicate == CmpInst::ICMP_EQ)) {
- Node = MDB.createBranchWeights(LikelyBranchWeight, UnlikelyBranchWeight);
- ExpNode = MDB.createMisExpect(0, LikelyBranchWeight, UnlikelyBranchWeight);
+ Node =
+ MDB.createBranchWeights(LikelyBranchWeightVal, UnlikelyBranchWeightVal);
+ ExpNode =
+ MDB.createMisExpect(0, LikelyBranchWeightVal, UnlikelyBranchWeightVal);
} else {
- Node = MDB.createBranchWeights(UnlikelyBranchWeight, LikelyBranchWeight);
- ExpNode = MDB.createMisExpect(1, LikelyBranchWeight, UnlikelyBranchWeight);
+ Node =
+ MDB.createBranchWeights(UnlikelyBranchWeightVal, LikelyBranchWeightVal);
+ ExpNode =
+ MDB.createMisExpect(1, LikelyBranchWeightVal, UnlikelyBranchWeightVal);
}
BSI.setMetadata(LLVMContext::MD_misexpect, ExpNode);
@@ -347,7 +385,8 @@ static bool lowerExpectIntrinsic(Function &F) {
}
Function *Fn = CI->getCalledFunction();
- if (Fn && Fn->getIntrinsicID() == Intrinsic::expect) {
+ if (Fn && (Fn->getIntrinsicID() == Intrinsic::expect ||
+ Fn->getIntrinsicID() == Intrinsic::expect_with_probability)) {
// Before erasing the llvm.expect, walk backward to find
// phi that define llvm.expect's first arg, and
// infer branch probability:
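A minimal sketch of the builtin that lowers to the new llvm.expect.with.probability intrinsic, assuming Clang's __builtin_expect_with_probability (the third argument is the probability that the first argument equals the second):

int classify(int Err) {
  // Hint that Err == 0 holds at this site with probability 0.99.
  if (__builtin_expect_with_probability(Err == 0, 1, 0.99))
    return 0;   // likely path
  return -Err;  // unlikely path
}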
diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index 0ff6ee8bcfcc..90314b17b5e2 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -9,8 +9,11 @@
// Lower matrix intrinsics to vector operations.
//
// TODO:
-// * Implement multiply & add fusion
-// * Add remark, summarizing the available matrix optimization opportunities.
+// * Improve fusion:
+// * Support more cases, e.g. multiply-add, multiply-sub, operands/results
+// transposed.
+// * Improve cost-modeling, e.g. choose different number of rows/columns
+// for tiles, consider cost of copies on alias.
//
//===----------------------------------------------------------------------===//
@@ -18,10 +21,15 @@
#include "llvm/ADT/GraphTraits.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/SmallVector.h"
+#include "llvm/Analysis/AliasAnalysis.h"
+#include "llvm/Analysis/DomTreeUpdater.h"
+#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/TargetTransformInfo.h"
+#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/DataLayout.h"
+#include "llvm/IR/DebugInfoMetadata.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
@@ -29,30 +37,69 @@
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
+#include "llvm/Support/Alignment.h"
+#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Transforms/Scalar.h"
+#include "llvm/Transforms/Utils/BasicBlockUtils.h"
using namespace llvm;
using namespace PatternMatch;
#define DEBUG_TYPE "lower-matrix-intrinsics"
-static cl::opt<bool> EnableShapePropagation("matrix-propagate-shape",
- cl::init(true));
-
+static cl::opt<bool> EnableShapePropagation(
+ "matrix-propagate-shape", cl::init(true), cl::Hidden,
+ cl::desc("Enable/disable shape propagation from matrix intrinsics to other "
+ "instructions."));
+
+static cl::opt<bool>
+ FuseMatrix("fuse-matrix", cl::init(true), cl::Hidden,
+ cl::desc("Enable/disable fusing matrix instructions."));
+// TODO: Allow and use non-square tiles.
+static cl::opt<unsigned> TileSize(
+ "fuse-matrix-tile-size", cl::init(4), cl::Hidden,
+ cl::desc(
+ "Tile size for matrix instruction fusion using square-shaped tiles."));
+static cl::opt<bool> ForceFusion(
+ "force-fuse-matrix", cl::init(false), cl::Hidden,
+ cl::desc("Force matrix instruction fusion even if not profitable."));
static cl::opt<bool> AllowContractEnabled(
"matrix-allow-contract", cl::init(false), cl::Hidden,
cl::desc("Allow the use of FMAs if available and profitable. This may "
"result in different results, due to less rounding error."));
+enum class MatrixLayoutTy { ColumnMajor, RowMajor };
+
+static cl::opt<MatrixLayoutTy> MatrixLayout(
+ "matrix-default-layout", cl::init(MatrixLayoutTy::ColumnMajor),
+ cl::desc("Sets the default matrix layout"),
+ cl::values(clEnumValN(MatrixLayoutTy::ColumnMajor, "column-major",
+ "Use column-major layout"),
+ clEnumValN(MatrixLayoutTy::RowMajor, "row-major",
+ "Use row-major layout")));
+
+/// Helper function to either return Scope, if it is a subprogram or the
+/// attached subprogram for a local scope.
+static DISubprogram *getSubprogram(DIScope *Scope) {
+ if (auto *Subprogram = dyn_cast<DISubprogram>(Scope))
+ return Subprogram;
+ return cast<DILocalScope>(Scope)->getSubprogram();
+}
+
namespace {
-// Given an element poitner \p BasePtr to the start of a (sub) matrix, compute
-// the start address of column \p Col with type (\p EltType x \p NumRows)
-// assuming \p Stride elements between start two consecutive columns.
-// \p Stride must be >= \p NumRows.
+// Given an element pointer \p BasePtr to the start of a (sub) matrix, compute
+// the start address of vector \p VecIdx with type (\p EltType x \p NumElements)
+// assuming \p Stride elements between the start of two consecutive vectors.
+// \p Stride must be >= \p NumElements.
+// For column-major matrixes, the function computes the address of a column
+// vector and \p NumElements must be set to the number of elements in a column
+// (= number of rows of the matrix). For row-major matrixes, the function
+// computes the address of a row vector and \p NumElements must be set to the
+// number of elements in a row (= number of columns of the matrix).
//
-// Consider a 4x4 matrix like below
+// Consider a 4x4 matrix in column-major layout like below
//
// 0 1 2 3
// 0 v_0_0 v_0_1 v_0_2 v_0_3
@@ -62,14 +109,14 @@ namespace {
// To compute the column addresses for a 2x3 sub-matrix at row 1 and column 1,
// we need a pointer to the first element of the submatrix as base pointer.
-// Then we can use computeColumnAddr to compute the addresses for the columns
+// Then we can use computeVectorAddr to compute the addresses for the columns
// of the sub-matrix.
//
-// Column 0: computeColumnAddr(Base, 0 (column), 4 (stride), 2 (num rows), ..)
+// Column 0: computeVectorAddr(Base, 0 (column), 4 (stride), 2 (num rows), ..)
// -> just returns Base
-// Column 1: computeColumnAddr(Base, 1 (column), 4 (stride), 2 (num rows), ..)
+// Column 1: computeVectorAddr(Base, 1 (column), 4 (stride), 2 (num rows), ..)
// -> returns Base + (1 * 4)
-// Column 2: computeColumnAddr(Base, 2 (column), 4 (stride), 2 (num rows), ..)
+// Column 2: computeVectorAddr(Base, 2 (column), 4 (stride), 2 (num rows), ..)
// -> returns Base + (2 * 4)
//
// The graphic below illustrates the number of elements in a column (marked
@@ -82,30 +129,30 @@ namespace {
// v_2_0 |v_2_1 |v_2_2 |v_2_3
// v_3_0 {v_3_1 {v_3_2 v_3_3
//
-Value *computeColumnAddr(Value *BasePtr, Value *Col, Value *Stride,
- unsigned NumRows, Type *EltType,
+Value *computeVectorAddr(Value *BasePtr, Value *VecIdx, Value *Stride,
+ unsigned NumElements, Type *EltType,
IRBuilder<> &Builder) {
assert((!isa<ConstantInt>(Stride) ||
- cast<ConstantInt>(Stride)->getZExtValue() >= NumRows) &&
- "Stride must be >= the number of rows.");
+ cast<ConstantInt>(Stride)->getZExtValue() >= NumElements) &&
+ "Stride must be >= the number of elements in the result vector.");
unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace();
- // Compute the start of the column with index Col as Col * Stride.
- Value *ColumnStart = Builder.CreateMul(Col, Stride, "col.start");
+ // Compute the start of the vector with index VecIdx as VecIdx * Stride.
+ Value *VecStart = Builder.CreateMul(VecIdx, Stride, "vec.start");
- // Get pointer to the start of the selected column. Skip GEP creation,
- // if we select column 0.
- if (isa<ConstantInt>(ColumnStart) && cast<ConstantInt>(ColumnStart)->isZero())
- ColumnStart = BasePtr;
+ // Get pointer to the start of the selected vector. Skip GEP creation,
+ // if we select vector 0.
+ if (isa<ConstantInt>(VecStart) && cast<ConstantInt>(VecStart)->isZero())
+ VecStart = BasePtr;
else
- ColumnStart = Builder.CreateGEP(EltType, BasePtr, ColumnStart, "col.gep");
+ VecStart = Builder.CreateGEP(EltType, BasePtr, VecStart, "vec.gep");
- // Cast elementwise column start pointer to a pointer to a column
- // (EltType x NumRows)*.
- Type *ColumnType = VectorType::get(EltType, NumRows);
- Type *ColumnPtrType = PointerType::get(ColumnType, AS);
- return Builder.CreatePointerCast(ColumnStart, ColumnPtrType, "col.cast");
+ // Cast elementwise vector start pointer to a pointer to a vector
+ // (EltType x NumElements)*.
+ auto *VecType = FixedVectorType::get(EltType, NumElements);
+ Type *VecPtrType = PointerType::get(VecType, AS);
+ return Builder.CreatePointerCast(VecStart, VecPtrType, "vec.cast");
}
/// LowerMatrixIntrinsics contains the methods used to lower matrix intrinsics.
@@ -113,15 +160,16 @@ Value *computeColumnAddr(Value *BasePtr, Value *Col, Value *Stride,
/// Currently, the lowering for each matrix intrinsic is done as follows:
/// 1. Propagate the shape information from intrinsics to connected
/// instructions.
-/// 2. Lower instructions with shape information.
+/// 2. Lower instructions with shape information (assuming column-major layout).
+/// The lowering works similarly using row-major layout.
/// 2.1. Get column vectors for each argument. If we already lowered the
/// definition of an argument, use the produced column vectors directly.
/// If not, split the operand vector containing an embedded matrix into
/// a set of column vectors,
-/// 2.2. Lower the instruction in terms of columnwise operations, which yields
-/// a set of column vectors containing result matrix. Note that we lower
-/// all instructions that have shape information. Besides the intrinsics,
-/// this includes stores for example.
+/// 2.2. Lower the instruction in terms of column major operations, which
+/// yields a set of column vectors containing result matrix. Note that we
+/// lower all instructions that have shape information. Besides the
+/// intrinsics, this includes stores for example.
/// 2.3. Update uses of the lowered instruction. If we have shape information
/// for a user, there is nothing to do, as we will look up the result
/// column matrix when lowering the user. For other uses, we embed the
@@ -134,42 +182,157 @@ class LowerMatrixIntrinsics {
Function &Func;
const DataLayout &DL;
const TargetTransformInfo &TTI;
+ AliasAnalysis &AA;
+ DominatorTree &DT;
+ LoopInfo &LI;
+ OptimizationRemarkEmitter &ORE;
+
+ /// Contains estimates of the number of operations (loads, stores, compute) required to lower a matrix operation.
+ struct OpInfoTy {
+ /// Number of stores emitted to generate this matrix.
+ unsigned NumStores = 0;
+ /// Number of loads emitted to generate this matrix.
+ unsigned NumLoads = 0;
+ /// Number of compute operations emitted to generate this matrix.
+ unsigned NumComputeOps = 0;
+
+ OpInfoTy &operator+=(const OpInfoTy &RHS) {
+ NumStores += RHS.NumStores;
+ NumLoads += RHS.NumLoads;
+ NumComputeOps += RHS.NumComputeOps;
+ return *this;
+ }
+ };
+
+ /// Wrapper class representing a matrix as a set of vectors, either in row or
+ /// column major layout. All vectors must have the same vector type.
+ class MatrixTy {
+ SmallVector<Value *, 16> Vectors;
+
+ OpInfoTy OpInfo;
- /// Wrapper class representing a matrix as a set of column vectors.
- /// All column vectors must have the same vector type.
- class ColumnMatrixTy {
- SmallVector<Value *, 16> Columns;
+ bool IsColumnMajor = true;
public:
- ColumnMatrixTy() : Columns() {}
- ColumnMatrixTy(ArrayRef<Value *> Cols)
- : Columns(Cols.begin(), Cols.end()) {}
+ MatrixTy()
+ : Vectors(),
+ IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
+ MatrixTy(ArrayRef<Value *> Vectors)
+ : Vectors(Vectors.begin(), Vectors.end()),
+ IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
+ MatrixTy(unsigned NumRows, unsigned NumColumns, Type *EltTy)
+ : IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {
+
+ unsigned D = isColumnMajor() ? NumColumns : NumRows;
+ for (unsigned J = 0; J < D; ++J)
+ addVector(UndefValue::get(FixedVectorType::get(
+ EltTy, isColumnMajor() ? NumRows : NumColumns)));
+ }
+
+ Value *getVector(unsigned i) const { return Vectors[i]; }
+ Value *getColumn(unsigned i) const {
+ assert(isColumnMajor() && "only supported for column-major matrixes");
+ return Vectors[i];
+ }
+ Value *getRow(unsigned i) const {
+ assert(!isColumnMajor() && "only supported for row-major matrixes");
+ return Vectors[i];
+ }
- Value *getColumn(unsigned i) const { return Columns[i]; }
+ void setVector(unsigned i, Value *V) { Vectors[i] = V; }
- void setColumn(unsigned i, Value *V) { Columns[i] = V; }
+ Type *getElementType() { return getVectorTy()->getElementType(); }
- size_t getNumColumns() const { return Columns.size(); }
- size_t getNumRows() const {
- assert(Columns.size() > 0 && "Cannot call getNumRows without columns");
- return cast<VectorType>(Columns[0]->getType())->getNumElements();
+ unsigned getNumVectors() const {
+ if (isColumnMajor())
+ return getNumColumns();
+ return getNumRows();
}
- const SmallVectorImpl<Value *> &getColumnVectors() const { return Columns; }
+ unsigned getNumColumns() const {
+ if (isColumnMajor())
+ return Vectors.size();
+ else {
+ assert(Vectors.size() > 0 && "Cannot call getNumRows without columns");
+ return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements();
+ }
+ }
+ unsigned getNumRows() const {
+ if (isColumnMajor()) {
+ assert(Vectors.size() > 0 && "Cannot call getNumRows without columns");
+ return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements();
+ } else
+ return Vectors.size();
+ }
- SmallVectorImpl<Value *> &getColumnVectors() { return Columns; }
+ void addVector(Value *V) { Vectors.push_back(V); }
+ VectorType *getColumnTy() {
+ assert(isColumnMajor() && "only supported for column-major matrixes");
+ return getVectorTy();
+ }
- void addColumn(Value *V) { Columns.push_back(V); }
+ VectorType *getVectorTy() {
+ return cast<VectorType>(Vectors[0]->getType());
+ }
iterator_range<SmallVector<Value *, 8>::iterator> columns() {
- return make_range(Columns.begin(), Columns.end());
+ assert(isColumnMajor() &&
+ "columns() only supported for column-major matrixes");
+ return make_range(Vectors.begin(), Vectors.end());
}
- /// Embed the columns of the matrix into a flat vector by concatenating
+ iterator_range<SmallVector<Value *, 8>::iterator> vectors() {
+ return make_range(Vectors.begin(), Vectors.end());
+ }
+
+ /// Embed the vectors of the matrix into a flat vector by concatenating
/// them.
Value *embedInVector(IRBuilder<> &Builder) const {
- return Columns.size() == 1 ? Columns[0]
- : concatenateVectors(Builder, Columns);
+ return Vectors.size() == 1 ? Vectors[0]
+ : concatenateVectors(Builder, Vectors);
+ }
+
+ MatrixTy &addNumLoads(unsigned N) {
+ OpInfo.NumLoads += N;
+ return *this;
+ }
+
+ void setNumLoads(unsigned N) { OpInfo.NumLoads = N; }
+
+ MatrixTy &addNumStores(unsigned N) {
+ OpInfo.NumStores += N;
+ return *this;
+ }
+
+ MatrixTy &addNumComputeOps(unsigned N) {
+ OpInfo.NumComputeOps += N;
+ return *this;
+ }
+
+ unsigned getNumStores() const { return OpInfo.NumStores; }
+ unsigned getNumLoads() const { return OpInfo.NumLoads; }
+ unsigned getNumComputeOps() const { return OpInfo.NumComputeOps; }
+
+ const OpInfoTy &getOpInfo() const { return OpInfo; }
+
+ bool isColumnMajor() const { return IsColumnMajor; }
+
+ unsigned getStride() const {
+ if (isColumnMajor())
+ return getNumRows();
+ return getNumColumns();
+ }
+
+ /// Extract a vector of \p NumElts starting at index (\p I, \p J). If the
+ /// matrix is column-major, the result vector is extracted from a column
+ /// vector, otherwise from a row vector.
+ Value *extractVector(unsigned I, unsigned J, unsigned NumElts,
+ IRBuilder<> &Builder) const {
+ Value *Vec = isColumnMajor() ? getColumn(J) : getRow(I);
+ Value *Undef = UndefValue::get(Vec->getType());
+ return Builder.CreateShuffleVector(
+ Vec, Undef, createSequentialMask(isColumnMajor() ? I : J, NumElts, 0),
+ "block");
}
};
@@ -177,12 +340,15 @@ class LowerMatrixIntrinsics {
unsigned NumRows;
unsigned NumColumns;
+ bool IsColumnMajor;
+
ShapeInfo(unsigned NumRows = 0, unsigned NumColumns = 0)
- : NumRows(NumRows), NumColumns(NumColumns) {}
+ : NumRows(NumRows), NumColumns(NumColumns),
+ IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
ShapeInfo(Value *NumRows, Value *NumColumns)
- : NumRows(cast<ConstantInt>(NumRows)->getZExtValue()),
- NumColumns(cast<ConstantInt>(NumColumns)->getZExtValue()) {}
+ : ShapeInfo(cast<ConstantInt>(NumRows)->getZExtValue(),
+ cast<ConstantInt>(NumColumns)->getZExtValue()) {}
bool operator==(const ShapeInfo &other) {
return NumRows == other.NumRows && NumColumns == other.NumColumns;
@@ -195,12 +361,24 @@ class LowerMatrixIntrinsics {
assert(NumRows == 0 || NumColumns != 0);
return NumRows != 0;
}
+
+ unsigned getStride() const {
+ if (IsColumnMajor)
+ return NumRows;
+ return NumColumns;
+ }
+
+ unsigned getNumVectors() const {
+ if (IsColumnMajor)
+ return NumColumns;
+ return NumRows;
+ }
};
/// Maps instructions to their shape information. The shape information
/// describes the shape to be used while lowering. This matches the shape of
/// the result value of the instruction, with the only exceptions being store
- /// instructions and the matrix_columnwise_store intrinsics. For those, the
+ /// instructions and the matrix_column_major_store intrinsics. For those, the
/// shape information indicates that those instructions should be lowered
/// using shape information as well.
DenseMap<Value *, ShapeInfo> ShapeMap;
@@ -211,31 +389,49 @@ class LowerMatrixIntrinsics {
SmallVector<Instruction *, 16> ToRemove;
/// Map from instructions to their produced column matrix.
- DenseMap<Value *, ColumnMatrixTy> Inst2ColumnMatrix;
+ MapVector<Value *, MatrixTy> Inst2ColumnMatrix;
public:
- LowerMatrixIntrinsics(Function &F, TargetTransformInfo &TTI)
- : Func(F), DL(F.getParent()->getDataLayout()), TTI(TTI) {}
+ LowerMatrixIntrinsics(Function &F, TargetTransformInfo &TTI,
+ AliasAnalysis &AA, DominatorTree &DT, LoopInfo &LI,
+ OptimizationRemarkEmitter &ORE)
+ : Func(F), DL(F.getParent()->getDataLayout()), TTI(TTI), AA(AA), DT(DT),
+ LI(LI), ORE(ORE) {}
+
+ unsigned getNumOps(Type *VT) {
+ assert(isa<VectorType>(VT) && "Expected vector type");
+ return getNumOps(VT->getScalarType(),
+ cast<FixedVectorType>(VT)->getNumElements());
+ }
- /// Return the set of column vectors that a matrix value is lowered to.
+ //
+ /// Return the estimated number of vector ops required for an operation on
+ /// \p VT * N.
+ unsigned getNumOps(Type *ST, unsigned N) {
+ return std::ceil((ST->getPrimitiveSizeInBits() * N).getFixedSize() /
+ double(TTI.getRegisterBitWidth(true)));
+ }
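// Illustration, assuming 128-bit vector registers: an operation on
// <16 x float> is estimated as ceil(32 * 16 / 128.0) = 4 vector ops, and an
// operation on <4 x double> as ceil(64 * 4 / 128.0) = 2.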
+
+ /// Return the set of vectors that a matrix value is lowered to.
///
- /// If we lowered \p MatrixVal, just return the cache result column matrix.
- /// Otherwie split the flat vector \p MatrixVal containing a matrix with
- /// shape \p SI into column vectors.
- ColumnMatrixTy getMatrix(Value *MatrixVal, const ShapeInfo &SI,
- IRBuilder<> Builder) {
+ /// If we lowered \p MatrixVal, just return the cache result matrix. Otherwise
+ /// split the flat vector \p MatrixVal containing a matrix with shape \p SI
+ /// into vectors.
+ MatrixTy getMatrix(Value *MatrixVal, const ShapeInfo &SI,
+ IRBuilder<> &Builder) {
VectorType *VType = dyn_cast<VectorType>(MatrixVal->getType());
assert(VType && "MatrixVal must be a vector type");
- assert(VType->getNumElements() == SI.NumRows * SI.NumColumns &&
+ assert(cast<FixedVectorType>(VType)->getNumElements() ==
+ SI.NumRows * SI.NumColumns &&
"The vector size must match the number of matrix elements");
// Check if we lowered MatrixVal using shape information. In that case,
- // return the existing column matrix, if it matches the requested shape
+ // return the existing matrix, if it matches the requested shape
// information. If there is a mis-match, embed the result in a flat
// vector and split it later.
auto Found = Inst2ColumnMatrix.find(MatrixVal);
if (Found != Inst2ColumnMatrix.end()) {
- ColumnMatrixTy &M = Found->second;
+ MatrixTy &M = Found->second;
// Return the found matrix, if its shape matches the requested shape
// information
if (SI.NumRows == M.getNumRows() && SI.NumColumns == M.getNumColumns())
@@ -247,10 +443,12 @@ public:
// Otherwise split MatrixVal.
SmallVector<Value *, 16> SplitVecs;
Value *Undef = UndefValue::get(VType);
- for (unsigned MaskStart = 0; MaskStart < VType->getNumElements();
- MaskStart += SI.NumRows) {
- Constant *Mask = createSequentialMask(Builder, MaskStart, SI.NumRows, 0);
- Value *V = Builder.CreateShuffleVector(MatrixVal, Undef, Mask, "split");
+ for (unsigned MaskStart = 0;
+ MaskStart < cast<FixedVectorType>(VType)->getNumElements();
+ MaskStart += SI.getStride()) {
+ Value *V = Builder.CreateShuffleVector(
+ MatrixVal, Undef, createSequentialMask(MaskStart, SI.getStride(), 0),
+ "split");
SplitVecs.push_back(V);
}
@@ -308,8 +506,8 @@ public:
switch (II->getIntrinsicID()) {
case Intrinsic::matrix_multiply:
case Intrinsic::matrix_transpose:
- case Intrinsic::matrix_columnwise_load:
- case Intrinsic::matrix_columnwise_store:
+ case Intrinsic::matrix_column_major_load:
+ case Intrinsic::matrix_column_major_store:
return true;
default:
return false;
@@ -348,13 +546,13 @@ public:
m_Value(MatrixA), m_Value(M), m_Value(N)))) {
// Flip dimensions.
Propagate = setShapeInfo(Inst, {N, M});
- } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_columnwise_store>(
+ } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_column_major_store>(
m_Value(MatrixA), m_Value(), m_Value(),
- m_Value(M), m_Value(N)))) {
+ m_Value(), m_Value(M), m_Value(N)))) {
Propagate = setShapeInfo(Inst, {N, M});
- } else if (match(Inst,
- m_Intrinsic<Intrinsic::matrix_columnwise_load>(
- m_Value(), m_Value(), m_Value(M), m_Value(N)))) {
+ } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_column_major_load>(
+ m_Value(), m_Value(), m_Value(), m_Value(M),
+ m_Value(N)))) {
Propagate = setShapeInfo(Inst, {M, N});
} else if (match(Inst, m_Store(m_Value(MatrixA), m_Value()))) {
auto OpShape = ShapeMap.find(MatrixA);
@@ -426,14 +624,14 @@ public:
// Flip dimensions.
if (setShapeInfo(MatrixA, {M, N}))
pushInstruction(MatrixA, WorkList);
- } else if (match(V, m_Intrinsic<Intrinsic::matrix_columnwise_store>(
- m_Value(MatrixA), m_Value(), m_Value(),
+ } else if (match(V, m_Intrinsic<Intrinsic::matrix_column_major_store>(
+ m_Value(MatrixA), m_Value(), m_Value(), m_Value(),
m_Value(M), m_Value(N)))) {
if (setShapeInfo(MatrixA, {M, N})) {
pushInstruction(MatrixA, WorkList);
}
} else if (isa<LoadInst>(V) ||
- match(V, m_Intrinsic<Intrinsic::matrix_columnwise_load>())) {
+ match(V, m_Intrinsic<Intrinsic::matrix_column_major_load>())) {
// Nothing to do, no matrix input.
} else if (isa<StoreInst>(V)) {
// Nothing to do. We forward-propagated to this so we would just
@@ -472,8 +670,8 @@ public:
switch (II->getIntrinsicID()) {
case Intrinsic::matrix_multiply:
case Intrinsic::matrix_transpose:
- case Intrinsic::matrix_columnwise_load:
- case Intrinsic::matrix_columnwise_store:
+ case Intrinsic::matrix_column_major_load:
+ case Intrinsic::matrix_column_major_store:
WorkList.push_back(&Inst);
break;
default:
@@ -487,45 +685,57 @@ public:
}
}
- ReversePostOrderTraversal<Function *> RPOT(&Func);
bool Changed = false;
- for (auto *BB : RPOT) {
- for (Instruction &Inst : make_early_inc_range(*BB)) {
- IRBuilder<> Builder(&Inst);
-
- if (CallInst *CInst = dyn_cast<CallInst>(&Inst))
- Changed |= VisitCallInst(CInst);
-
- Value *Op1;
- Value *Op2;
- if (auto *BinOp = dyn_cast<BinaryOperator>(&Inst))
- Changed |= VisitBinaryOperator(BinOp);
- if (match(&Inst, m_Load(m_Value(Op1))))
- Changed |= VisitLoad(&Inst, Op1, Builder);
- else if (match(&Inst, m_Store(m_Value(Op1), m_Value(Op2))))
- Changed |= VisitStore(&Inst, Op1, Op2, Builder);
+ SmallVector<CallInst *, 16> MaybeFusableInsts;
+ SmallVector<Instruction *, 16> MatrixInsts;
+
+ // First, collect all instructions with shape information and candidates for
+ // fusion (currently only matrix multiplies).
+ ReversePostOrderTraversal<Function *> RPOT(&Func);
+ for (auto *BB : RPOT)
+ for (Instruction &I : *BB) {
+ if (ShapeMap.find(&I) == ShapeMap.end())
+ continue;
+ if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>()))
+ MaybeFusableInsts.push_back(cast<CallInst>(&I));
+ MatrixInsts.push_back(&I);
}
+
+ // Second, try to fuse candidates.
+ SmallPtrSet<Instruction *, 16> FusedInsts;
+ for (CallInst *CI : MaybeFusableInsts)
+ LowerMatrixMultiplyFused(CI, FusedInsts);
+ Changed = !FusedInsts.empty();
+
+ // Third, lower remaining instructions with shape information.
+ for (Instruction *Inst : MatrixInsts) {
+ if (FusedInsts.count(Inst))
+ continue;
+
+ IRBuilder<> Builder(Inst);
+
+ if (CallInst *CInst = dyn_cast<CallInst>(Inst))
+ Changed |= VisitCallInst(CInst);
+
+ Value *Op1;
+ Value *Op2;
+ if (auto *BinOp = dyn_cast<BinaryOperator>(Inst))
+ Changed |= VisitBinaryOperator(BinOp);
+ if (match(Inst, m_Load(m_Value(Op1))))
+ Changed |= VisitLoad(cast<LoadInst>(Inst), Op1, Builder);
+ else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2))))
+ Changed |= VisitStore(cast<StoreInst>(Inst), Op1, Op2, Builder);
}
+ RemarkGenerator RemarkGen(Inst2ColumnMatrix, ORE, Func);
+ RemarkGen.emitRemarks();
+
for (Instruction *Inst : reverse(ToRemove))
Inst->eraseFromParent();
return Changed;
}
- LoadInst *createColumnLoad(Value *ColumnPtr, Type *EltType,
- IRBuilder<> Builder) {
- unsigned Align = DL.getABITypeAlignment(EltType);
- return Builder.CreateAlignedLoad(ColumnPtr, Align, "col.load");
- }
-
- StoreInst *createColumnStore(Value *ColumnValue, Value *ColumnPtr,
- Type *EltType, IRBuilder<> Builder) {
- unsigned Align = DL.getABITypeAlignment(EltType);
- return Builder.CreateAlignedStore(ColumnValue, ColumnPtr, Align);
- }
-
-
/// Turns \p BasePtr into an elementwise pointer to \p EltType.
Value *createElementPtr(Value *BasePtr, Type *EltType, IRBuilder<> &Builder) {
unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace();
@@ -545,11 +755,11 @@ public:
case Intrinsic::matrix_transpose:
LowerTranspose(Inst);
break;
- case Intrinsic::matrix_columnwise_load:
- LowerColumnwiseLoad(Inst);
+ case Intrinsic::matrix_column_major_load:
+ LowerColumnMajorLoad(Inst);
break;
- case Intrinsic::matrix_columnwise_store:
- LowerColumnwiseStore(Inst);
+ case Intrinsic::matrix_column_major_store:
+ LowerColumnMajorStore(Inst);
break;
default:
return false;
@@ -557,108 +767,200 @@ public:
return true;
}
- void LowerLoad(Instruction *Inst, Value *Ptr, Value *Stride,
- ShapeInfo Shape) {
- IRBuilder<> Builder(Inst);
- auto VType = cast<VectorType>(Inst->getType());
+ /// Compute the alignment for a column/row \p Idx with \p Stride between them.
+ /// The address at \p Idx == 0 has alignment \p A. If \p Stride is a
+ /// ConstantInt, reduce the initial alignment based on the byte offset. For
+ /// non-ConstantInt strides, return the common alignment of the initial
+ /// alignment and the element size in bytes.
+ Align getAlignForIndex(unsigned Idx, Value *Stride, Type *ElementTy,
+ MaybeAlign A) const {
+ Align InitialAlign = DL.getValueOrABITypeAlignment(A, ElementTy);
+ if (Idx == 0)
+ return InitialAlign;
+
+ TypeSize ElementSizeInBits = DL.getTypeSizeInBits(ElementTy);
+ if (auto *ConstStride = dyn_cast<ConstantInt>(Stride)) {
+ uint64_t StrideInBytes =
+ ConstStride->getZExtValue() * ElementSizeInBits / 8;
+ return commonAlignment(InitialAlign, Idx * StrideInBytes);
+ }
+ return commonAlignment(InitialAlign, ElementSizeInBits / 8);
+ }
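// Illustration with assumed values: for double elements (8 bytes), an
// initial alignment of 16 and a constant stride of 5 elements (40 bytes),
// vector 0 keeps align 16, vector 1 starts at byte offset 40 and gets
// commonAlignment(16, 40) == 8, and vector 2 (offset 80) gets
// commonAlignment(16, 80) == 16 again.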
+
+ /// Load a matrix with \p Shape starting at \p Ptr and using \p Stride between
+ /// vectors.
+ MatrixTy loadMatrix(Type *Ty, Value *Ptr, MaybeAlign MAlign, Value *Stride,
+ bool IsVolatile, ShapeInfo Shape, IRBuilder<> &Builder) {
+ auto VType = cast<VectorType>(Ty);
Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder);
- ColumnMatrixTy Result;
- // Distance between start of one column and the start of the next
- for (unsigned C = 0, E = Shape.NumColumns; C < E; ++C) {
- Value *GEP =
- computeColumnAddr(EltPtr, Builder.getInt32(C), Stride, Shape.NumRows,
- VType->getElementType(), Builder);
- Value *Column = createColumnLoad(GEP, VType->getElementType(), Builder);
- Result.addColumn(Column);
+ MatrixTy Result;
+ for (unsigned I = 0, E = Shape.getNumVectors(); I < E; ++I) {
+ Value *GEP = computeVectorAddr(EltPtr, Builder.getInt64(I), Stride,
+ Shape.getStride(), VType->getElementType(),
+ Builder);
+ Value *Vector = Builder.CreateAlignedLoad(
+ GEP, getAlignForIndex(I, Stride, VType->getElementType(), MAlign),
+ IsVolatile, "col.load");
+
+ Result.addVector(Vector);
}
+ return Result.addNumLoads(getNumOps(Result.getVectorTy()) *
+ Result.getNumVectors());
+ }
- finalizeLowering(Inst, Result, Builder);
+ /// Loads a sub-matrix with shape \p ResultShape from a \p R x \p C matrix,
+ /// starting at \p MatrixPtr[I][J].
+ MatrixTy loadMatrix(Value *MatrixPtr, MaybeAlign Align, bool IsVolatile,
+ ShapeInfo MatrixShape, Value *I, Value *J,
+ ShapeInfo ResultShape, Type *EltTy,
+ IRBuilder<> &Builder) {
+
+ Value *Offset = Builder.CreateAdd(
+ Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I);
+
+ unsigned AS = cast<PointerType>(MatrixPtr->getType())->getAddressSpace();
+ Value *EltPtr =
+ Builder.CreatePointerCast(MatrixPtr, PointerType::get(EltTy, AS));
+ Value *TileStart = Builder.CreateGEP(EltTy, EltPtr, Offset);
+ auto *TileTy = FixedVectorType::get(EltTy, ResultShape.NumRows *
+ ResultShape.NumColumns);
+ Type *TilePtrTy = PointerType::get(TileTy, AS);
+ Value *TilePtr =
+ Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast");
+
+ return loadMatrix(TileTy, TilePtr, Align,
+ Builder.getInt64(MatrixShape.getStride()), IsVolatile,
+ ResultShape, Builder);
+ }
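// Illustration with assumed sizes: for a column-major 8 x 8 matrix
// (stride 8), a tile whose top-left element sits at row I = 2, column J = 4
// starts at element offset J * 8 + I = 34, and its columns are then loaded
// 8 elements apart using the enclosing matrix's stride.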
+
+ /// Lower a load instruction with shape information.
+ void LowerLoad(Instruction *Inst, Value *Ptr, MaybeAlign Align, Value *Stride,
+ bool IsVolatile, ShapeInfo Shape) {
+ IRBuilder<> Builder(Inst);
+ finalizeLowering(Inst,
+ loadMatrix(Inst->getType(), Ptr, Align, Stride, IsVolatile,
+ Shape, Builder),
+ Builder);
}
- /// Lowers llvm.matrix.columnwise.load.
+ /// Lowers llvm.matrix.column.major.load.
///
/// The intrinsic loads a matrix from memory using a stride between columns.
- void LowerColumnwiseLoad(CallInst *Inst) {
+ void LowerColumnMajorLoad(CallInst *Inst) {
+ assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
+ "Intrinsic only supports column-major layout!");
Value *Ptr = Inst->getArgOperand(0);
Value *Stride = Inst->getArgOperand(1);
- LowerLoad(Inst, Ptr, Stride,
- {Inst->getArgOperand(2), Inst->getArgOperand(3)});
+ LowerLoad(Inst, Ptr, Inst->getParamAlign(0), Stride,
+ cast<ConstantInt>(Inst->getArgOperand(2))->isOne(),
+ {Inst->getArgOperand(3), Inst->getArgOperand(4)});
}
- void LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr, Value *Stride,
- ShapeInfo Shape) {
- IRBuilder<> Builder(Inst);
- auto VType = cast<VectorType>(Matrix->getType());
+ /// Stores a sub-matrix \p StoreVal into the \p R x \p C matrix starting at \p
+ /// MatrixPtr[I][J].
+ void storeMatrix(const MatrixTy &StoreVal, Value *MatrixPtr,
+ MaybeAlign MAlign, bool IsVolatile, ShapeInfo MatrixShape,
+ Value *I, Value *J, Type *EltTy, IRBuilder<> &Builder) {
+ Value *Offset = Builder.CreateAdd(
+ Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I);
+
+ unsigned AS = cast<PointerType>(MatrixPtr->getType())->getAddressSpace();
+ Value *EltPtr =
+ Builder.CreatePointerCast(MatrixPtr, PointerType::get(EltTy, AS));
+ Value *TileStart = Builder.CreateGEP(EltTy, EltPtr, Offset);
+ auto *TileTy = FixedVectorType::get(EltTy, StoreVal.getNumRows() *
+ StoreVal.getNumColumns());
+ Type *TilePtrTy = PointerType::get(TileTy, AS);
+ Value *TilePtr =
+ Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast");
+
+ storeMatrix(TileTy, StoreVal, TilePtr, MAlign,
+ Builder.getInt64(MatrixShape.getStride()), IsVolatile, Builder);
+ }
+
+ /// Store matrix \p StoreVal starting at \p Ptr and using \p Stride between
+ /// vectors.
+ MatrixTy storeMatrix(Type *Ty, MatrixTy StoreVal, Value *Ptr,
+ MaybeAlign MAlign, Value *Stride, bool IsVolatile,
+ IRBuilder<> &Builder) {
+ auto VType = cast<VectorType>(Ty);
Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder);
- auto LM = getMatrix(Matrix, Shape, Builder);
- for (auto C : enumerate(LM.columns())) {
- Value *GEP =
- computeColumnAddr(EltPtr, Builder.getInt32(C.index()), Stride,
- Shape.NumRows, VType->getElementType(), Builder);
- createColumnStore(C.value(), GEP, VType->getElementType(), Builder);
+ for (auto Vec : enumerate(StoreVal.vectors())) {
+ Value *GEP = computeVectorAddr(EltPtr, Builder.getInt64(Vec.index()),
+ Stride, StoreVal.getStride(),
+ VType->getElementType(), Builder);
+ Builder.CreateAlignedStore(Vec.value(), GEP,
+ getAlignForIndex(Vec.index(), Stride,
+ VType->getElementType(),
+ MAlign),
+ IsVolatile);
}
+ return MatrixTy().addNumStores(getNumOps(StoreVal.getVectorTy()) *
+ StoreVal.getNumVectors());
+ }
- ToRemove.push_back(Inst);
+ /// Lower a store instruction with shape information.
+ void LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr, MaybeAlign A,
+ Value *Stride, bool IsVolatile, ShapeInfo Shape) {
+ IRBuilder<> Builder(Inst);
+ auto StoreVal = getMatrix(Matrix, Shape, Builder);
+ finalizeLowering(Inst,
+ storeMatrix(Matrix->getType(), StoreVal, Ptr, A, Stride,
+ IsVolatile, Builder),
+ Builder);
}
- /// Lowers llvm.matrix.columnwise.store.
+ /// Lowers llvm.matrix.column.major.store.
///
/// The intrinsic store a matrix back memory using a stride between columns.
- void LowerColumnwiseStore(CallInst *Inst) {
+ void LowerColumnMajorStore(CallInst *Inst) {
+ assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
+ "Intrinsic only supports column-major layout!");
Value *Matrix = Inst->getArgOperand(0);
Value *Ptr = Inst->getArgOperand(1);
Value *Stride = Inst->getArgOperand(2);
- LowerStore(Inst, Matrix, Ptr, Stride,
- {Inst->getArgOperand(3), Inst->getArgOperand(4)});
- }
-
- /// Extract a column vector of \p NumElts starting at index (\p I, \p J) from
- /// the matrix \p LM represented as a vector of column vectors.
- Value *extractVector(const ColumnMatrixTy &LM, unsigned I, unsigned J,
- unsigned NumElts, IRBuilder<> Builder) {
- Value *Col = LM.getColumn(J);
- Value *Undef = UndefValue::get(Col->getType());
- Constant *Mask = createSequentialMask(Builder, I, NumElts, 0);
- return Builder.CreateShuffleVector(Col, Undef, Mask, "block");
+ LowerStore(Inst, Matrix, Ptr, Inst->getParamAlign(1), Stride,
+ cast<ConstantInt>(Inst->getArgOperand(3))->isOne(),
+ {Inst->getArgOperand(4), Inst->getArgOperand(5)});
}
// Set elements I..I+NumElts-1 to Block
Value *insertVector(Value *Col, unsigned I, Value *Block,
- IRBuilder<> Builder) {
+ IRBuilder<> &Builder) {
// First, bring Block to the same size as Col
unsigned BlockNumElts =
- cast<VectorType>(Block->getType())->getNumElements();
- unsigned NumElts = cast<VectorType>(Col->getType())->getNumElements();
+ cast<FixedVectorType>(Block->getType())->getNumElements();
+ unsigned NumElts = cast<FixedVectorType>(Col->getType())->getNumElements();
assert(NumElts >= BlockNumElts && "Too few elements for current block");
- Value *ExtendMask =
- createSequentialMask(Builder, 0, BlockNumElts, NumElts - BlockNumElts);
Value *Undef = UndefValue::get(Block->getType());
- Block = Builder.CreateShuffleVector(Block, Undef, ExtendMask);
+ Block = Builder.CreateShuffleVector(
+ Block, Undef,
+ createSequentialMask(0, BlockNumElts, NumElts - BlockNumElts));
// If Col is 7 long and I is 2 and BlockNumElts is 2 the mask is: 0, 1, 7,
// 8, 4, 5, 6
- SmallVector<Constant *, 16> Mask;
+ SmallVector<int, 16> Mask;
unsigned i;
for (i = 0; i < I; i++)
- Mask.push_back(Builder.getInt32(i));
+ Mask.push_back(i);
- unsigned VecNumElts = cast<VectorType>(Col->getType())->getNumElements();
+ unsigned VecNumElts =
+ cast<FixedVectorType>(Col->getType())->getNumElements();
for (; i < I + BlockNumElts; i++)
- Mask.push_back(Builder.getInt32(i - I + VecNumElts));
+ Mask.push_back(i - I + VecNumElts);
for (; i < VecNumElts; i++)
- Mask.push_back(Builder.getInt32(i));
-
- Value *MaskVal = ConstantVector::get(Mask);
+ Mask.push_back(i);
- return Builder.CreateShuffleVector(Col, Block, MaskVal);
+ return Builder.CreateShuffleVector(Col, Block, Mask);
}
Value *createMulAdd(Value *Sum, Value *A, Value *B, bool UseFPOp,
- IRBuilder<> &Builder, bool AllowContraction) {
-
+ IRBuilder<> &Builder, bool AllowContraction,
+ unsigned &NumComputeOps) {
+ NumComputeOps += getNumOps(A->getType());
if (!Sum)
return UseFPOp ? Builder.CreateFMul(A, B) : Builder.CreateMul(A, B);
@@ -666,14 +968,16 @@ public:
if (AllowContraction) {
// Use fmuladd for floating point operations and let the backend decide
// if that's profitable.
- Value *FMulAdd = Intrinsic::getDeclaration(
+ Function *FMulAdd = Intrinsic::getDeclaration(
Func.getParent(), Intrinsic::fmuladd, A->getType());
return Builder.CreateCall(FMulAdd, {A, B, Sum});
}
+ NumComputeOps += getNumOps(A->getType());
Value *Mul = Builder.CreateFMul(A, B);
return Builder.CreateFAdd(Sum, Mul);
}
+ NumComputeOps += getNumOps(A->getType());
Value *Mul = Builder.CreateMul(A, B);
return Builder.CreateAdd(Sum, Mul);
}
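
A minimal standalone illustration (hypothetical values, not part of the patch) of why contraction is a semantic choice and not just a lowering detail: a fused multiply-add rounds once, while a separate multiply and add round twice. The sketch uses plain C++ std::fma rather than the llvm.fmuladd intrinsic that createMulAdd emits.

// Hypothetical example: fused vs. split multiply-add.
#include <cmath>
#include <cstdio>

int main() {
  double A = 1.0 + 1.0 / (1 << 30); // 1 + 2^-30, exactly representable
  double B = 1.0 - 1.0 / (1 << 30); // 1 - 2^-30, exactly representable
  double Sum = -1.0;
  double Fused = std::fma(A, B, Sum); // rounds once: keeps the -2^-60 term
  double Split = A * B + Sum;         // A*B rounds to 1.0, so the result is 0.0
  std::printf("fused=%g split=%g\n", Fused, Split);
  return 0;
}
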
@@ -683,7 +987,7 @@ public:
/// cached value when they are lowered. For other users, \p Matrix is
/// flattened and the uses are updated to use it. Also marks \p Inst for
/// deletion.
- void finalizeLowering(Instruction *Inst, ColumnMatrixTy Matrix,
+ void finalizeLowering(Instruction *Inst, MatrixTy Matrix,
IRBuilder<> &Builder) {
Inst2ColumnMatrix.insert(std::make_pair(Inst, Matrix));
@@ -699,6 +1003,294 @@ public:
}
}
+ /// Compute \p Result += \p A * \p B for input matrices with left-associating
+ /// addition.
+ void emitMatrixMultiply(MatrixTy &Result, const MatrixTy &A,
+ const MatrixTy &B, bool AllowContraction,
+ IRBuilder<> &Builder, bool isTiled) {
+ const unsigned VF = std::max<unsigned>(
+ TTI.getRegisterBitWidth(true) /
+ Result.getElementType()->getPrimitiveSizeInBits().getFixedSize(),
+ 1U);
+ unsigned R = Result.getNumRows();
+ unsigned C = Result.getNumColumns();
+ unsigned M = A.getNumColumns();
+
+ bool IsFP = Result.getElementType()->isFloatingPointTy();
+ assert(A.isColumnMajor() == B.isColumnMajor() &&
+ Result.isColumnMajor() == A.isColumnMajor() &&
+ "operands must agree on matrix layout");
+ unsigned NumComputeOps = 0;
+ if (A.isColumnMajor()) {
+ // Multiply columns from the first operand with scalars from the second
+ // operand. Then move along the K axes and accumulate the columns. With
+ // this the adds can be vectorized without reassociation.
+ for (unsigned J = 0; J < C; ++J) {
+ unsigned BlockSize = VF;
+ // If Result is zero, we don't need to accumulate in the K==0 iteration.
+ bool isSumZero = isa<ConstantAggregateZero>(Result.getColumn(J));
+
+ for (unsigned I = 0; I < R; I += BlockSize) {
+ // Gradually lower the vectorization factor to cover the remainder.
+ while (I + BlockSize > R)
+ BlockSize /= 2;
+
+ Value *Sum = isTiled ? Result.extractVector(I, J, BlockSize, Builder)
+ : nullptr;
+ for (unsigned K = 0; K < M; ++K) {
+ Value *L = A.extractVector(I, K, BlockSize, Builder);
+ Value *RH = Builder.CreateExtractElement(B.getColumn(J), K);
+ Value *Splat = Builder.CreateVectorSplat(BlockSize, RH, "splat");
+ Sum = createMulAdd(isSumZero && K == 0 ? nullptr : Sum, L, Splat,
+ Result.getElementType()->isFloatingPointTy(),
+ Builder, AllowContraction, NumComputeOps);
+ }
+ Result.setVector(J,
+ insertVector(Result.getVector(J), I, Sum, Builder));
+ }
+ }
+ } else {
+ // Multiply rows from the second operand with scalars from the first
+ // operand. Then move along the K axes and accumulate the rows. With this
+ // the adds can be vectorized without reassociation.
+ for (unsigned I = 0; I < R; ++I) {
+ unsigned BlockSize = VF;
+ bool isSumZero = isa<ConstantAggregateZero>(Result.getRow(I));
+ for (unsigned J = 0; J < C; J += BlockSize) {
+ // Gradually lower the vectorization factor to cover the remainder.
+ while (J + BlockSize > C)
+ BlockSize /= 2;
+
+ Value *Sum = nullptr;
+ for (unsigned K = 0; K < M; ++K) {
+ Value *R = B.extractVector(K, J, BlockSize, Builder);
+ Value *LH = Builder.CreateExtractElement(A.getVector(I), K);
+ Value *Splat = Builder.CreateVectorSplat(BlockSize, LH, "splat");
+ Sum = createMulAdd(isSumZero && K == 0 ? nullptr : Sum, Splat, R,
+ IsFP, Builder, AllowContraction, NumComputeOps);
+ }
+ Result.setVector(I,
+ insertVector(Result.getVector(I), J, Sum, Builder));
+ }
+ }
+ }
+ Result.addNumComputeOps(NumComputeOps);
+ }
+
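
As a reading aid, here is a minimal scalar sketch (hypothetical, not from the patch) of the accumulation order used by the column-major branch above: for each result column J, the K axis is walked and column K of A, scaled by the scalar B(K, J), is accumulated. The pass performs the same loop nest, but on BlockSize-wide vector chunks of each column and through the createMulAdd helper.

// Hypothetical scalar reference, assuming column-major storage.
// A is R x M, B is M x C, Result is R x C and must start out zeroed.
#include <vector>

void matmulColumnMajor(const std::vector<double> &A, const std::vector<double> &B,
                       std::vector<double> &Result, unsigned R, unsigned M,
                       unsigned C) {
  for (unsigned J = 0; J < C; ++J)       // result column
    for (unsigned K = 0; K < M; ++K) {   // accumulate along K
      double Scalar = B[J * M + K];      // B(K, J)
      for (unsigned I = 0; I < R; ++I)   // whole column of A (vectorized in the pass)
        Result[J * R + I] += A[K * R + I] * Scalar;
    }
}
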
+ /// Ensure that the memory in \p Load does not alias \p Store by potentially
+  /// copying it to a new location. The new location, or otherwise the original
+  /// location, is returned.
+ Value *getNonAliasingPointer(LoadInst *Load, StoreInst *Store,
+ CallInst *MatMul) {
+ MemoryLocation StoreLoc = MemoryLocation::get(Store);
+ MemoryLocation LoadLoc = MemoryLocation::get(Load);
+
+ AliasResult LdAliased = AA.alias(LoadLoc, StoreLoc);
+
+ // If we can statically determine noalias we're good.
+ if (!LdAliased)
+ return Load->getPointerOperand();
+
+ // Create code to check if the memory locations of the Load and Store
+ // overlap and if they do, copy Load's operand to a new buffer.
+
+    // First, create new blocks for the 2nd part of the check and the copy.
+ BasicBlock *Check0 = MatMul->getParent();
+ // FIXME: Use lazy DTU and update SplitBlock to accept a DTU instead of a
+ // DT. Manually collect dominator tree updates, to avoid unnecessary work,
+ // as we adjust Check0 and Check1's branches.
+ SmallVector<DominatorTree::UpdateType, 4> DTUpdates;
+ for (BasicBlock *Succ : successors(Check0))
+ DTUpdates.push_back({DT.Delete, Check0, Succ});
+
+ BasicBlock *Check1 = SplitBlock(MatMul->getParent(), MatMul, nullptr, &LI,
+ nullptr, "alias_cont");
+ BasicBlock *Copy =
+ SplitBlock(MatMul->getParent(), MatMul, nullptr, &LI, nullptr, "copy");
+ BasicBlock *Fusion = SplitBlock(MatMul->getParent(), MatMul, nullptr, &LI,
+ nullptr, "no_alias");
+
+ // Check if the loaded memory location begins before the end of the store
+ // location. If the condition holds, they might overlap, otherwise they are
+ // guaranteed to not overlap.
+ IRBuilder<> Builder(MatMul);
+ Check0->getTerminator()->eraseFromParent();
+ Builder.SetInsertPoint(Check0);
+ Type *IntPtrTy = Builder.getIntPtrTy(Load->getModule()->getDataLayout());
+ Value *StoreBegin = Builder.CreatePtrToInt(
+ const_cast<Value *>(StoreLoc.Ptr), IntPtrTy, "store.begin");
+ Value *StoreEnd = Builder.CreateAdd(
+ StoreBegin, ConstantInt::get(IntPtrTy, StoreLoc.Size.getValue()),
+ "store.end", true, true);
+ Value *LoadBegin = Builder.CreatePtrToInt(const_cast<Value *>(LoadLoc.Ptr),
+ IntPtrTy, "load.begin");
+ Builder.CreateCondBr(Builder.CreateICmpULT(LoadBegin, StoreEnd), Check1,
+ Fusion);
+
+ // Check if the store begins before the end of the load location. If the
+ // condition holds, they alias, otherwise they are guaranteed to not
+ // overlap.
+ Check1->getTerminator()->eraseFromParent();
+ Builder.SetInsertPoint(Check1, Check1->begin());
+ Value *LoadEnd = Builder.CreateAdd(
+ LoadBegin, ConstantInt::get(IntPtrTy, LoadLoc.Size.getValue()),
+ "load.end", true, true);
+ Builder.CreateCondBr(Builder.CreateICmpULT(StoreBegin, LoadEnd), Copy,
+ Fusion);
+
+ // Copy load operand to new alloca.
+ Builder.SetInsertPoint(Copy, Copy->begin());
+ AllocaInst *NewLd =
+ Builder.CreateAlloca(Load->getType(), Load->getPointerAddressSpace());
+ Builder.CreateMemCpy(NewLd, NewLd->getAlign(),
+ Load->getPointerOperand(), Load->getAlign(),
+ LoadLoc.Size.getValue());
+ Builder.SetInsertPoint(Fusion, Fusion->begin());
+ PHINode *PHI = Builder.CreatePHI(Load->getPointerOperandType(), 3);
+ PHI->addIncoming(Load->getPointerOperand(), Check0);
+ PHI->addIncoming(Load->getPointerOperand(), Check1);
+ PHI->addIncoming(NewLd, Copy);
+
+ // Adjust DT.
+ DTUpdates.push_back({DT.Insert, Check0, Check1});
+ DTUpdates.push_back({DT.Insert, Check0, Fusion});
+ DTUpdates.push_back({DT.Insert, Check1, Copy});
+ DTUpdates.push_back({DT.Insert, Check1, Fusion});
+ DT.applyUpdates(DTUpdates);
+ return PHI;
+ }
+
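
The two emitted branches implement the standard interval-overlap test; a small standalone sketch (hypothetical helper, not emitted by the pass) of the condition that is materialized as IR above:

// Hypothetical sketch: [LoadBegin, LoadEnd) and [StoreBegin, StoreEnd) can only
// overlap if each region starts before the other one ends. The pass emits the
// same two comparisons as branches (Check0 and Check1) and copies the load's
// memory to a fresh alloca only when both hold.
#include <cstdint>

bool mayOverlap(uintptr_t LoadBegin, uintptr_t LoadSize, uintptr_t StoreBegin,
                uintptr_t StoreSize) {
  uintptr_t LoadEnd = LoadBegin + LoadSize;
  uintptr_t StoreEnd = StoreBegin + StoreSize;
  return LoadBegin < StoreEnd && StoreBegin < LoadEnd;
}
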
+ bool isFusionProfitable(CallInst *MatMul) {
+ if (ForceFusion)
+ return true;
+
+ ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
+ ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
+
+ const unsigned R = LShape.NumRows;
+ const unsigned C = RShape.NumColumns;
+ const unsigned M = LShape.NumColumns;
+ auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
+
+ const unsigned VF =
+ std::max<unsigned>(TTI.getRegisterBitWidth(true) /
+ EltType->getPrimitiveSizeInBits().getFixedSize(),
+ 1U);
+
+ // Cost model for tiling
+ //
+ // For tiling to be beneficial, we need reuse either along the R or
+ // the C axis. We vectorize along the R axis so that means at least
+ // 3 elements.
+ // TODO: Also consider cost of copying if operands alias.
+ if (R <= VF && C == 1)
+ return false;
+ // Then we need enough elements to exceed the number of vector
+ // registers we have. Note that this is an oversimplification since
+ // fusing also takes some extra loads which may exceed the number of
+ // reloads necessary.
+ unsigned Op0Regs = (R + VF - 1) / VF * M;
+ unsigned Op1Regs = (M + VF - 1) / VF * C;
+ return Op0Regs + Op1Regs > TTI.getNumberOfRegisters(true);
+ }
+
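
To make the register estimate concrete, a hypothetical example with assumed numbers (not taken from the patch): with 256-bit vector registers and float elements, VF = 256 / 32 = 8; for a 16x16 times 16x16 multiply, Op0Regs = ceil(16/8) * 16 = 32 and Op1Regs = 32, so 64 vector registers would be needed and fusion is considered profitable on typical targets. A standalone sketch of the same estimate:

// Hypothetical standalone version of the cost check above.
unsigned vectorsPerOperand(unsigned Rows, unsigned Cols, unsigned VF) {
  // Column-major: Cols columns, each split into ceil(Rows / VF) vector registers.
  return (Rows + VF - 1) / VF * Cols;
}

bool tilingLooksProfitable(unsigned R, unsigned M, unsigned C, unsigned VF,
                           unsigned NumVectorRegs) {
  if (R <= VF && C == 1) // no reuse along R or C, as in the pass
    return false;
  return vectorsPerOperand(R, M, VF) + vectorsPerOperand(M, C, VF) > NumVectorRegs;
}
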
+ MatrixTy getZeroMatrix(Type *EltType, unsigned R, unsigned C) {
+ MatrixTy Res;
+ auto *ColumType = FixedVectorType::get(EltType, R);
+ for (unsigned I = 0; I < C; ++I)
+ Res.addVector(ConstantAggregateZero::get(ColumType));
+ return Res;
+ }
+
+ void emitSIMDTiling(CallInst *MatMul, LoadInst *LoadOp0, LoadInst *LoadOp1,
+ StoreInst *Store,
+ SmallPtrSetImpl<Instruction *> &FusedInsts) {
+ assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
+ "Tiling only supported for column-major matrixes at the moment!");
+ if (!isFusionProfitable(MatMul))
+ return;
+
+ ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
+ ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
+
+ const unsigned R = LShape.NumRows;
+ const unsigned C = RShape.NumColumns;
+ const unsigned M = LShape.NumColumns;
+ auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
+
+ Value *APtr = getNonAliasingPointer(LoadOp0, Store, MatMul);
+ Value *BPtr = getNonAliasingPointer(LoadOp1, Store, MatMul);
+ Value *CPtr = Store->getPointerOperand();
+
+ bool AllowContract = AllowContractEnabled || (isa<FPMathOperator>(MatMul) &&
+ MatMul->hasAllowContract());
+ IRBuilder<> Builder(Store);
+ for (unsigned J = 0; J < C; J += TileSize)
+ for (unsigned I = 0; I < R; I += TileSize) {
+ const unsigned TileR = std::min(R - I, unsigned(TileSize));
+ const unsigned TileC = std::min(C - J, unsigned(TileSize));
+ MatrixTy Res = getZeroMatrix(EltType, TileR, TileC);
+
+ for (unsigned K = 0; K < M; K += TileSize) {
+ const unsigned TileM = std::min(M - K, unsigned(TileSize));
+ MatrixTy A =
+ loadMatrix(APtr, LoadOp0->getAlign(), LoadOp0->isVolatile(),
+ LShape, Builder.getInt64(I), Builder.getInt64(K),
+ {TileR, TileM}, EltType, Builder);
+ MatrixTy B =
+ loadMatrix(BPtr, LoadOp1->getAlign(), LoadOp1->isVolatile(),
+ RShape, Builder.getInt64(K), Builder.getInt64(J),
+ {TileM, TileC}, EltType, Builder);
+ emitMatrixMultiply(Res, A, B, AllowContract, Builder, true);
+ }
+ storeMatrix(Res, CPtr, Store->getAlign(), Store->isVolatile(), {R, M},
+ Builder.getInt64(I), Builder.getInt64(J), EltType, Builder);
+ }
+
+ // Mark eliminated instructions as fused and remove them.
+ FusedInsts.insert(Store);
+ FusedInsts.insert(MatMul);
+ Store->eraseFromParent();
+ MatMul->eraseFromParent();
+ if (LoadOp0->hasNUses(0)) {
+ FusedInsts.insert(LoadOp0);
+ LoadOp0->eraseFromParent();
+ }
+ if (LoadOp1->hasNUses(0)) {
+ FusedInsts.insert(LoadOp1);
+ LoadOp1->eraseFromParent();
+ }
+ }
+
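
A compact scalar model of the tiling structure above (hypothetical, assuming column-major buffers and a Result that starts out zeroed): each (I, J) tile of the result stays live while the K dimension is walked in TileSize chunks, which is what lets the tile stay in vector registers in the lowered code.

// Hypothetical scalar model of emitSIMDTiling's loop nest.
#include <algorithm>

void tiledMatmul(const double *A, const double *B, double *Result, unsigned R,
                 unsigned M, unsigned C, unsigned TileSize) {
  for (unsigned J = 0; J < C; J += TileSize)
    for (unsigned I = 0; I < R; I += TileSize)
      for (unsigned K = 0; K < M; K += TileSize)
        // Accumulate one tile of the result.
        for (unsigned JJ = J; JJ < std::min(J + TileSize, C); ++JJ)
          for (unsigned II = I; II < std::min(I + TileSize, R); ++II)
            for (unsigned KK = K; KK < std::min(K + TileSize, M); ++KK)
              Result[JJ * R + II] += A[KK * R + II] * B[JJ * M + KK];
}
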
+ /// Try to lower matrix multiply chains by fusing operations.
+ ///
+ /// Currently we only lower {ld, ld} -> matmul -> st chains.
+  ///
+ /// No need to return a MatrixTy object for the result of the operation, since
+ /// the single store user will be lowered as part of this. Instructions that
+ /// are completely eliminated by fusion are added to \p FusedInsts.
+ void LowerMatrixMultiplyFused(CallInst *MatMul,
+ SmallPtrSetImpl<Instruction *> &FusedInsts) {
+ if (!FuseMatrix || !MatMul->hasOneUse() ||
+ MatrixLayout != MatrixLayoutTy::ColumnMajor)
+ return;
+
+ auto *LoadOp0 = dyn_cast<LoadInst>(MatMul->getOperand(0));
+ auto *LoadOp1 = dyn_cast<LoadInst>(MatMul->getOperand(1));
+ auto *Store = dyn_cast<StoreInst>(*MatMul->user_begin());
+ if (LoadOp0 && LoadOp1 && Store) {
+ // The store address must dominate the MatMul instruction, otherwise
+ // we create invalid IR.
+ // FIXME: See if we can hoist the store address computation.
+ auto *AddrI = dyn_cast<Instruction>(Store->getOperand(1));
+ if (AddrI && (!DT.dominates(AddrI, MatMul)))
+ return;
+
+ emitSIMDTiling(MatMul, LoadOp0, LoadOp1, Store, FusedInsts);
+ return;
+ }
+ }
+
/// Lowers llvm.matrix.multiply.
void LowerMultiply(CallInst *MatMul) {
IRBuilder<> Builder(MatMul);
@@ -706,97 +1298,80 @@ public:
ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
- const ColumnMatrixTy &Lhs =
- getMatrix(MatMul->getArgOperand(0), LShape, Builder);
- const ColumnMatrixTy &Rhs =
- getMatrix(MatMul->getArgOperand(1), RShape, Builder);
+ const MatrixTy &Lhs = getMatrix(MatMul->getArgOperand(0), LShape, Builder);
+ const MatrixTy &Rhs = getMatrix(MatMul->getArgOperand(1), RShape, Builder);
const unsigned R = LShape.NumRows;
- const unsigned M = LShape.NumColumns;
const unsigned C = RShape.NumColumns;
- assert(M == RShape.NumRows);
+ assert(LShape.NumColumns == RShape.NumRows);
// Initialize the output
- ColumnMatrixTy Result;
- for (unsigned J = 0; J < C; ++J)
- Result.addColumn(UndefValue::get(VectorType::get(EltType, R)));
-
- const unsigned VF = std::max(TTI.getRegisterBitWidth(true) /
- EltType->getPrimitiveSizeInBits(),
- uint64_t(1));
+ MatrixTy Result(R, C, EltType);
bool AllowContract = AllowContractEnabled || (isa<FPMathOperator>(MatMul) &&
MatMul->hasAllowContract());
- // Multiply columns from the first operand with scalars from the second
- // operand. Then move along the K axes and accumulate the columns. With
- // this the adds can be vectorized without reassociation.
- for (unsigned J = 0; J < C; ++J) {
- unsigned BlockSize = VF;
- for (unsigned I = 0; I < R; I += BlockSize) {
- // Gradually lower the vectorization factor to cover the remainder.
- while (I + BlockSize > R)
- BlockSize /= 2;
-
- Value *Sum = nullptr;
- for (unsigned K = 0; K < M; ++K) {
- Value *L = extractVector(Lhs, I, K, BlockSize, Builder);
- Value *RH = Builder.CreateExtractElement(Rhs.getColumn(J), K);
- Value *Splat = Builder.CreateVectorSplat(BlockSize, RH, "splat");
- Sum = createMulAdd(Sum, L, Splat, EltType->isFloatingPointTy(),
- Builder, AllowContract);
- }
- Result.setColumn(J, insertVector(Result.getColumn(J), I, Sum, Builder));
- }
- }
+ emitMatrixMultiply(Result, Lhs, Rhs, AllowContract, Builder, false);
finalizeLowering(MatMul, Result, Builder);
}
/// Lowers llvm.matrix.transpose.
void LowerTranspose(CallInst *Inst) {
- ColumnMatrixTy Result;
+ MatrixTy Result;
IRBuilder<> Builder(Inst);
Value *InputVal = Inst->getArgOperand(0);
VectorType *VectorTy = cast<VectorType>(InputVal->getType());
ShapeInfo ArgShape(Inst->getArgOperand(1), Inst->getArgOperand(2));
- ColumnMatrixTy InputMatrix = getMatrix(InputVal, ArgShape, Builder);
-
- for (unsigned Row = 0; Row < ArgShape.NumRows; ++Row) {
- // Build a single column vector for this row. First initialize it.
- Value *ResultColumn = UndefValue::get(
- VectorType::get(VectorTy->getElementType(), ArgShape.NumColumns));
-
- // Go through the elements of this row and insert it into the resulting
- // column vector.
- for (auto C : enumerate(InputMatrix.columns())) {
- Value *Elt = Builder.CreateExtractElement(C.value(), Row);
- // We insert at index Column since that is the row index after the
- // transpose.
- ResultColumn =
- Builder.CreateInsertElement(ResultColumn, Elt, C.index());
+ MatrixTy InputMatrix = getMatrix(InputVal, ArgShape, Builder);
+
+ const unsigned NewNumVecs =
+ InputMatrix.isColumnMajor() ? ArgShape.NumRows : ArgShape.NumColumns;
+ const unsigned NewNumElts =
+ InputMatrix.isColumnMajor() ? ArgShape.NumColumns : ArgShape.NumRows;
+
+ for (unsigned I = 0; I < NewNumVecs; ++I) {
+ // Build a single result vector. First initialize it.
+ Value *ResultVector = UndefValue::get(
+ FixedVectorType::get(VectorTy->getElementType(), NewNumElts));
+      // Go through the old elements and insert them into the resulting vector.
+ for (auto J : enumerate(InputMatrix.vectors())) {
+ Value *Elt = Builder.CreateExtractElement(J.value(), I);
+ // Row and column indices are transposed.
+ ResultVector =
+ Builder.CreateInsertElement(ResultVector, Elt, J.index());
}
- Result.addColumn(ResultColumn);
+ Result.addVector(ResultVector);
}
- finalizeLowering(Inst, Result, Builder);
+ // TODO: Improve estimate of operations needed for transposes. Currently we
+ // just count the insertelement/extractelement instructions, but do not
+ // account for later simplifications/combines.
+ finalizeLowering(
+ Inst,
+ Result.addNumComputeOps(2 * ArgShape.NumRows * ArgShape.NumColumns),
+ Builder);
}
/// Lower load instructions, if shape information is available.
- bool VisitLoad(Instruction *Inst, Value *Ptr, IRBuilder<> &Builder) {
+ bool VisitLoad(LoadInst *Inst, Value *Ptr, IRBuilder<> &Builder) {
auto I = ShapeMap.find(Inst);
if (I == ShapeMap.end())
return false;
- LowerLoad(Inst, Ptr, Builder.getInt32(I->second.NumRows), I->second);
+ LowerLoad(Inst, Ptr, Inst->getAlign(),
+ Builder.getInt64(I->second.getStride()), Inst->isVolatile(),
+ I->second);
return true;
}
- bool VisitStore(Instruction *Inst, Value *StoredVal, Value *Ptr,
+ bool VisitStore(StoreInst *Inst, Value *StoredVal, Value *Ptr,
IRBuilder<> &Builder) {
auto I = ShapeMap.find(StoredVal);
if (I == ShapeMap.end())
return false;
- LowerStore(Inst, StoredVal, Ptr, Builder.getInt32(I->second.NumRows), I->second);
+ LowerStore(Inst, StoredVal, Ptr, Inst->getAlign(),
+ Builder.getInt64(I->second.getStride()), Inst->isVolatile(),
+ I->second);
return true;
}
@@ -812,12 +1387,15 @@ public:
IRBuilder<> Builder(Inst);
ShapeInfo &Shape = I->second;
- ColumnMatrixTy LoweredLhs = getMatrix(Lhs, Shape, Builder);
- ColumnMatrixTy LoweredRhs = getMatrix(Rhs, Shape, Builder);
+ MatrixTy Result;
+ MatrixTy A = getMatrix(Lhs, Shape, Builder);
+ MatrixTy B = getMatrix(Rhs, Shape, Builder);
+ assert(A.isColumnMajor() == B.isColumnMajor() &&
+ Result.isColumnMajor() == A.isColumnMajor() &&
+ "operands must agree on matrix layout");
- // Add each column and store the result back into the opmapping
- ColumnMatrixTy Result;
- auto BuildColumnOp = [&Builder, Inst](Value *LHS, Value *RHS) {
+ // Helper to perform binary op on vectors.
+ auto BuildVectorOp = [&Builder, Inst](Value *LHS, Value *RHS) {
switch (Inst->getOpcode()) {
case Instruction::Add:
return Builder.CreateAdd(LHS, RHS);
@@ -835,20 +1413,462 @@ public:
llvm_unreachable("Unsupported binary operator for matrix");
}
};
- for (unsigned C = 0; C < Shape.NumColumns; ++C)
- Result.addColumn(
- BuildColumnOp(LoweredLhs.getColumn(C), LoweredRhs.getColumn(C)));
- finalizeLowering(Inst, Result, Builder);
+ for (unsigned I = 0; I < Shape.getNumVectors(); ++I)
+ Result.addVector(BuildVectorOp(A.getVector(I), B.getVector(I)));
+
+ finalizeLowering(Inst,
+ Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
+ Result.getNumVectors()),
+ Builder);
return true;
}
+
+ /// Helper to linearize a matrix expression tree into a string. Currently
+  /// matrix expressions are linearized by starting at an expression leaf and
+ /// linearizing bottom up.
+ struct ExprLinearizer {
+ unsigned LengthToBreak = 100;
+ std::string Str;
+ raw_string_ostream Stream;
+ unsigned LineLength = 0;
+ const DataLayout &DL;
+
+ /// Mapping from instructions to matrixes. It is used to identify
+ /// matrix instructions.
+ const MapVector<Value *, MatrixTy> &Inst2Matrix;
+
+ /// Mapping from values to the leaves of all expressions that the value is
+ /// part of.
+ const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared;
+
+ /// Set of matrix expressions in the scope of a given DISubprogram.
+ const SmallSetVector<Value *, 32> &ExprsInSubprogram;
+
+ /// Leaf node of the expression to linearize.
+ Value *Leaf;
+
+ /// Used to keep track of sub-expressions that get reused while linearizing
+ /// the expression. Re-used sub-expressions are marked as (reused).
+ SmallPtrSet<Value *, 8> ReusedExprs;
+
+ ExprLinearizer(const DataLayout &DL,
+ const MapVector<Value *, MatrixTy> &Inst2Matrix,
+ const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
+ const SmallSetVector<Value *, 32> &ExprsInSubprogram,
+ Value *Leaf)
+ : Str(), Stream(Str), DL(DL), Inst2Matrix(Inst2Matrix), Shared(Shared),
+ ExprsInSubprogram(ExprsInSubprogram), Leaf(Leaf) {}
+
+ void indent(unsigned N) {
+ LineLength += N;
+ for (unsigned i = 0; i < N; i++)
+ Stream << " ";
+ }
+
+ void lineBreak() {
+ Stream << "\n";
+ LineLength = 0;
+ }
+
+ void maybeIndent(unsigned Indent) {
+ if (LineLength >= LengthToBreak)
+ lineBreak();
+
+ if (LineLength == 0)
+ indent(Indent);
+ }
+
+ void write(StringRef S) {
+ LineLength += S.size();
+ Stream << S;
+ }
+
+ Value *getUnderlyingObjectThroughLoads(Value *V) {
+ if (Value *Ptr = getPointerOperand(V))
+ return getUnderlyingObjectThroughLoads(Ptr);
+ else if (V->getType()->isPointerTy())
+ return GetUnderlyingObject(V, DL);
+ return V;
+ }
+
+ /// Returns true if \p V is a matrix value in the given subprogram.
+ bool isMatrix(Value *V) const { return ExprsInSubprogram.count(V); }
+
+  /// If \p V is a matrix value, print its shape as NumRows x NumColumns to
+ /// \p SS.
+ void prettyPrintMatrixType(Value *V, raw_string_ostream &SS) {
+ auto M = Inst2Matrix.find(V);
+ if (M == Inst2Matrix.end())
+ SS << "unknown";
+ else {
+ SS << M->second.getNumRows();
+ SS << "x";
+ SS << M->second.getNumColumns();
+ }
+ }
+
+ /// Write the called function name. Handles calls to llvm.matrix.*
+ /// specially: we write the name, followed by the dimensions of the input
+ /// matrixes, followed by the scalar type name.
+ void writeFnName(CallInst *CI) {
+ if (!CI->getCalledFunction())
+ write("<no called fn>");
+ else {
+ StringRef Name = CI->getCalledFunction()->getName();
+ if (!Name.startswith("llvm.matrix")) {
+ write(Name);
+ return;
+ }
+ IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI);
+ write(StringRef(Intrinsic::getName(II->getIntrinsicID(), {}))
+ .drop_front(StringRef("llvm.matrix.").size()));
+ write(".");
+ std::string Tmp = "";
+ raw_string_ostream SS(Tmp);
+
+ switch (II->getIntrinsicID()) {
+ case Intrinsic::matrix_multiply:
+ prettyPrintMatrixType(II->getOperand(0), SS);
+ SS << ".";
+ prettyPrintMatrixType(II->getOperand(1), SS);
+ SS << "." << *II->getType()->getScalarType();
+ break;
+ case Intrinsic::matrix_transpose:
+ prettyPrintMatrixType(II->getOperand(0), SS);
+ SS << "." << *II->getType()->getScalarType();
+ break;
+ case Intrinsic::matrix_column_major_load:
+ prettyPrintMatrixType(II, SS);
+ SS << "." << *II->getType()->getScalarType();
+ break;
+ case Intrinsic::matrix_column_major_store:
+ prettyPrintMatrixType(II->getOperand(0), SS);
+ SS << "." << *II->getOperand(0)->getType()->getScalarType();
+ break;
+ default:
+ llvm_unreachable("Unhandled case");
+ }
+ SS.flush();
+ write(Tmp);
+ }
+ }
+
+ unsigned getNumShapeArgs(CallInst *CI) const {
+ if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI)) {
+ switch (II->getIntrinsicID()) {
+ case Intrinsic::matrix_multiply:
+ return 3;
+ case Intrinsic::matrix_transpose:
+ return 2;
+ case Intrinsic::matrix_column_major_load:
+ case Intrinsic::matrix_column_major_store:
+ return 3;
+ default:
+ return 0;
+ }
+ }
+ return 0;
+ }
+
+  /// Special printing for values: for pointers, we print whether they refer to
+  /// an external (function) address or a stack address; for other values we
+  /// either print the constant or "scalar"/"matrix".
+ void write(Value *V) {
+ V = getUnderlyingObjectThroughLoads(V);
+ if (V->getType()->isPointerTy()) {
+ if (isa<AllocaInst>(V)) {
+ Stream << "stack addr";
+ LineLength += StringRef("stack addr").size();
+ } else {
+ Stream << "addr";
+ LineLength += StringRef("addr").size();
+ }
+ if (!V->getName().empty()) {
+ Stream << " %" << V->getName() << "";
+ LineLength += V->getName().size() + 2;
+ }
+ return;
+ }
+
+ std::string Tmp;
+ raw_string_ostream TmpStream(Tmp);
+
+ if (auto *CI = dyn_cast<ConstantInt>(V))
+ TmpStream << CI->getValue();
+ else if (isa<Constant>(V))
+ TmpStream << "constant";
+ else {
+ if (isMatrix(V))
+ TmpStream << "matrix";
+ else
+ TmpStream << "scalar";
+ }
+ TmpStream.flush();
+ Tmp = std::string(StringRef(Tmp).trim());
+ LineLength += Tmp.size();
+ Stream << Tmp;
+ }
+
+ /// Linearize expression \p Expr starting at an indentation of \p Indent.
+ /// Expressions that are re-used multiple times are prefixed with (reused)
+ /// at the re-used root instruction.
+ void linearizeExpr(Value *Expr, unsigned Indent, bool ParentReused,
+ bool ParentShared) {
+ auto *I = cast<Instruction>(Expr);
+ maybeIndent(Indent);
+ SmallVector<Value *, 8> Ops;
+
+ // Is Expr shared with other expression leaves?
+ bool ExprShared = false;
+
+ // Deal with shared subtrees. Mark them as shared, if required.
+ if (!ParentShared) {
+ auto SI = Shared.find(Expr);
+ assert(SI != Shared.end() && SI->second.count(Leaf));
+
+ for (Value *S : SI->second) {
+ if (S == Leaf)
+ continue;
+ DebugLoc DL = cast<Instruction>(S)->getDebugLoc();
+ write("shared with remark at line " + std::to_string(DL.getLine()) +
+ " column " + std::to_string(DL.getCol()) + " (");
+ }
+ ExprShared = SI->second.size() > 1;
+ }
+
+ bool Reused = !ReusedExprs.insert(Expr).second;
+ if (Reused && !ParentReused)
+ write("(reused) ");
+
+ if (auto *CI = dyn_cast<CallInst>(I)) {
+ writeFnName(CI);
+
+ Ops.append(CI->arg_begin(), CI->arg_end() - getNumShapeArgs(CI));
+ } else if (isa<BitCastInst>(Expr)) {
+ // Special case bitcasts, which are used to materialize matrixes from
+ // non-matrix ops.
+ write("matrix");
+ return;
+ } else {
+ Ops.append(I->value_op_begin(), I->value_op_end());
+ write(std::string(I->getOpcodeName()));
+ }
+
+ write(std::string("("));
+
+ unsigned NumOpsToBreak = 1;
+ if (match(Expr, m_Intrinsic<Intrinsic::matrix_column_major_load>()))
+ NumOpsToBreak = 2;
+
+ for (Value *Op : Ops) {
+ if (Ops.size() > NumOpsToBreak)
+ lineBreak();
+
+ maybeIndent(Indent + 1);
+ if (isMatrix(Op))
+ linearizeExpr(Op, Indent + 1, Reused, ExprShared);
+ else
+ write(Op);
+ if (Op != Ops.back())
+ write(", ");
+ }
+
+ write(")");
+ }
+
+ const std::string &getResult() {
+ Stream.flush();
+ return Str;
+ }
+ };
+
+ /// Generate remarks for matrix operations in a function. To generate remarks
+ /// for matrix expressions, the following approach is used:
+ /// 1. Use the inlined-at debug information to group matrix operations to the
+ /// DISubprograms they are contained in.
+ /// 2. Collect leaves of matrix expressions (done in
+ /// RemarkGenerator::getExpressionLeaves) for each subprogram - expression
+  ///    mapping. Leaves are lowered matrix instructions without other matrix
+  ///    users (like stores) in the current subprogram.
+  /// 3. For each leaf, create a remark containing a linearized version of the
+ /// matrix expression. The expression is linearized by a recursive
+ /// bottom-up traversal of the matrix operands, starting at a leaf. Note
+ /// that multiple leaves can share sub-expressions. Shared subexpressions
+ /// are explicitly marked as shared().
+ struct RemarkGenerator {
+ const MapVector<Value *, MatrixTy> &Inst2Matrix;
+ OptimizationRemarkEmitter &ORE;
+ Function &Func;
+ const DataLayout &DL;
+
+ RemarkGenerator(const MapVector<Value *, MatrixTy> &Inst2Matrix,
+ OptimizationRemarkEmitter &ORE, Function &Func)
+ : Inst2Matrix(Inst2Matrix), ORE(ORE), Func(Func),
+ DL(Func.getParent()->getDataLayout()) {}
+
+ /// Return all leaves of the expressions in \p ExprsInSubprogram. Those are
+ /// instructions in Inst2Matrix returning void or without any users in
+ /// \p ExprsInSubprogram. Currently that should only include stores.
+ SmallVector<Value *, 4>
+ getExpressionLeaves(const SmallSetVector<Value *, 32> &ExprsInSubprogram) {
+ SmallVector<Value *, 4> Leaves;
+ for (auto *Expr : ExprsInSubprogram)
+ if (Expr->getType()->isVoidTy() ||
+ !any_of(Expr->users(), [&ExprsInSubprogram](User *U) {
+ return ExprsInSubprogram.count(U);
+ }))
+ Leaves.push_back(Expr);
+ return Leaves;
+ }
+
+ /// Recursively traverse expression \p V starting at \p Leaf and add \p Leaf
+ /// to all visited expressions in \p Shared. Limit the matrix operations to
+ /// the ones in \p ExprsInSubprogram.
+ void collectSharedInfo(Value *Leaf, Value *V,
+ const SmallSetVector<Value *, 32> &ExprsInSubprogram,
+ DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) {
+
+ if (!ExprsInSubprogram.count(V))
+ return;
+
+ auto I = Shared.insert({V, {}});
+ I.first->second.insert(Leaf);
+
+ for (Value *Op : cast<Instruction>(V)->operand_values())
+ collectSharedInfo(Leaf, Op, ExprsInSubprogram, Shared);
+ return;
+ }
+
+ /// Calculate the number of exclusive and shared op counts for expression
+  /// starting at \p Root. Expressions used multiple times are counted once.
+ /// Limit the matrix operations to the ones in \p ExprsInSubprogram.
+ std::pair<OpInfoTy, OpInfoTy>
+ sumOpInfos(Value *Root, SmallPtrSetImpl<Value *> &ReusedExprs,
+ const SmallSetVector<Value *, 32> &ExprsInSubprogram,
+ DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) const {
+ if (!ExprsInSubprogram.count(Root))
+ return {};
+
+ // Already counted this expression. Stop.
+ if (!ReusedExprs.insert(Root).second)
+ return {};
+
+ OpInfoTy SharedCount;
+ OpInfoTy Count;
+
+ auto I = Shared.find(Root);
+ auto CM = Inst2Matrix.find(Root);
+ if (I->second.size() == 1)
+ Count = CM->second.getOpInfo();
+ else
+ SharedCount = CM->second.getOpInfo();
+
+ for (Value *Op : cast<Instruction>(Root)->operand_values()) {
+ auto C = sumOpInfos(Op, ReusedExprs, ExprsInSubprogram, Shared);
+ Count += C.first;
+ SharedCount += C.second;
+ }
+ return {Count, SharedCount};
+ }
+
+ void emitRemarks() {
+ if (!ORE.allowExtraAnalysis(DEBUG_TYPE))
+ return;
+
+    // Map matrix operations to their containing subprograms by traversing
+ // the inlinedAt chain. If the function does not have a DISubprogram, we
+ // only map them to the containing function.
+ MapVector<DISubprogram *, SmallVector<Value *, 8>> Subprog2Exprs;
+ for (auto &KV : Inst2Matrix) {
+ if (Func.getSubprogram()) {
+ auto *I = cast<Instruction>(KV.first);
+ DILocation *Context = I->getDebugLoc();
+ while (Context) {
+ auto I =
+ Subprog2Exprs.insert({getSubprogram(Context->getScope()), {}});
+ I.first->second.push_back(KV.first);
+ Context = DebugLoc(Context).getInlinedAt();
+ }
+ } else {
+ auto I = Subprog2Exprs.insert({nullptr, {}});
+ I.first->second.push_back(KV.first);
+ }
+ }
+ for (auto &KV : Subprog2Exprs) {
+ SmallSetVector<Value *, 32> ExprsInSubprogram(KV.second.begin(),
+ KV.second.end());
+ auto Leaves = getExpressionLeaves(ExprsInSubprogram);
+
+ DenseMap<Value *, SmallPtrSet<Value *, 2>> Shared;
+ for (Value *Leaf : Leaves)
+ collectSharedInfo(Leaf, Leaf, ExprsInSubprogram, Shared);
+
+ // Generate remarks for each leaf.
+ for (auto *L : Leaves) {
+
+ DebugLoc Loc = cast<Instruction>(L)->getDebugLoc();
+ DILocation *Context = cast<Instruction>(L)->getDebugLoc();
+ while (Context) {
+ if (getSubprogram(Context->getScope()) == KV.first) {
+ Loc = Context;
+ break;
+ }
+ Context = DebugLoc(Context).getInlinedAt();
+ }
+
+ SmallPtrSet<Value *, 8> ReusedExprs;
+ OpInfoTy Counts, SharedCounts;
+ std::tie(Counts, SharedCounts) =
+ sumOpInfos(L, ReusedExprs, ExprsInSubprogram, Shared);
+
+ OptimizationRemark Rem(DEBUG_TYPE, "matrix-lowered", Loc,
+ cast<Instruction>(L)->getParent());
+
+ Rem << "Lowered with ";
+ Rem << ore::NV("NumStores", Counts.NumStores) << " stores, "
+ << ore::NV("NumLoads", Counts.NumLoads) << " loads, "
+ << ore::NV("NumComputeOps", Counts.NumComputeOps)
+ << " compute ops";
+
+ if (SharedCounts.NumStores > 0 || SharedCounts.NumLoads > 0 ||
+ SharedCounts.NumComputeOps > 0) {
+ Rem << ",\nadditionally "
+ << ore::NV("NumStores", SharedCounts.NumStores) << " stores, "
+ << ore::NV("NumLoads", SharedCounts.NumLoads) << " loads, "
+ << ore::NV("NumFPOps", SharedCounts.NumComputeOps)
+ << " compute ops"
+ << " are shared with other expressions";
+ }
+
+ Rem << ("\n" + linearize(L, Shared, ExprsInSubprogram, DL));
+ ORE.emit(Rem);
+ }
+ }
+ }
+
+ std::string
+ linearize(Value *L,
+ const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
+ const SmallSetVector<Value *, 32> &ExprsInSubprogram,
+ const DataLayout &DL) {
+ ExprLinearizer Lin(DL, Inst2Matrix, Shared, ExprsInSubprogram, L);
+ Lin.linearizeExpr(L, 0, false, false);
+ return Lin.getResult();
+ }
+ };
};
} // namespace
PreservedAnalyses LowerMatrixIntrinsicsPass::run(Function &F,
FunctionAnalysisManager &AM) {
auto &TTI = AM.getResult<TargetIRAnalysis>(F);
- LowerMatrixIntrinsics LMT(F, TTI);
+ auto &ORE = AM.getResult<OptimizationRemarkEmitterAnalysis>(F);
+ auto &AA = AM.getResult<AAManager>(F);
+ auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
+ auto &LI = AM.getResult<LoopAnalysis>(F);
+
+ LowerMatrixIntrinsics LMT(F, TTI, AA, DT, LI, ORE);
if (LMT.Visit()) {
PreservedAnalyses PA;
PA.preserveSet<CFGAnalyses>();
@@ -869,15 +1889,24 @@ public:
}
bool runOnFunction(Function &F) override {
- auto *TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
- LowerMatrixIntrinsics LMT(F, *TTI);
+ auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
+ auto &ORE = getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE();
+ auto &AA = getAnalysis<AAResultsWrapperPass>().getAAResults();
+ auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
+ auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
+ LowerMatrixIntrinsics LMT(F, TTI, AA, DT, LI, ORE);
bool C = LMT.Visit();
return C;
}
void getAnalysisUsage(AnalysisUsage &AU) const override {
AU.addRequired<TargetTransformInfoWrapperPass>();
- AU.setPreservesCFG();
+ AU.addRequired<OptimizationRemarkEmitterWrapperPass>();
+ AU.addRequired<AAResultsWrapperPass>();
+ AU.addRequired<DominatorTreeWrapperPass>();
+ AU.addPreserved<DominatorTreeWrapperPass>();
+ AU.addRequired<LoopInfoWrapperPass>();
+ AU.addPreserved<LoopInfoWrapperPass>();
}
};
} // namespace
@@ -886,6 +1915,10 @@ static const char pass_name[] = "Lower the matrix intrinsics";
char LowerMatrixIntrinsicsLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name,
false, false)
+INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
INITIALIZE_PASS_END(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name,
false, false)
diff --git a/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp b/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp
index c24fa40860eb..4b4196edc12b 100644
--- a/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp
+++ b/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp
@@ -27,7 +27,6 @@
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/Argument.h"
#include "llvm/IR/BasicBlock.h"
-#include "llvm/IR/CallSite.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DerivedTypes.h"
@@ -173,8 +172,8 @@ public:
void addStore(int64_t OffsetFromFirst, StoreInst *SI) {
int64_t StoreSize = DL.getTypeStoreSize(SI->getOperand(0)->getType());
- addRange(OffsetFromFirst, StoreSize,
- SI->getPointerOperand(), SI->getAlignment(), SI);
+ addRange(OffsetFromFirst, StoreSize, SI->getPointerOperand(),
+ SI->getAlign().value(), SI);
}
void addMemSet(int64_t OffsetFromFirst, MemSetInst *MSI) {
@@ -387,13 +386,8 @@ Instruction *MemCpyOptPass::tryMergingIntoMemset(Instruction *StartInst,
// Get the starting pointer of the block.
StartPtr = Range.StartPtr;
- // Determine alignment
- const Align Alignment = DL.getValueOrABITypeAlignment(
- MaybeAlign(Range.Alignment),
- cast<PointerType>(StartPtr->getType())->getElementType());
-
AMemSet = Builder.CreateMemSet(StartPtr, ByteVal, Range.End - Range.Start,
- Alignment);
+ MaybeAlign(Range.Alignment));
LLVM_DEBUG(dbgs() << "Replace stores:\n"; for (Instruction *SI
: Range.TheStores) dbgs()
<< *SI << '\n';
@@ -413,23 +407,6 @@ Instruction *MemCpyOptPass::tryMergingIntoMemset(Instruction *StartInst,
return AMemSet;
}
-static Align findStoreAlignment(const DataLayout &DL, const StoreInst *SI) {
- return DL.getValueOrABITypeAlignment(MaybeAlign(SI->getAlignment()),
- SI->getOperand(0)->getType());
-}
-
-static Align findLoadAlignment(const DataLayout &DL, const LoadInst *LI) {
- return DL.getValueOrABITypeAlignment(MaybeAlign(LI->getAlignment()),
- LI->getType());
-}
-
-static Align findCommonAlignment(const DataLayout &DL, const StoreInst *SI,
- const LoadInst *LI) {
- Align StoreAlign = findStoreAlignment(DL, SI);
- Align LoadAlign = findLoadAlignment(DL, LI);
- return commonAlignment(StoreAlign, LoadAlign);
-}
-
// This method tries to lift a store instruction before position P.
// It will lift the store and its argument, plus anything that
// may alias with these.
@@ -585,12 +562,12 @@ bool MemCpyOptPass::processStore(StoreInst *SI, BasicBlock::iterator &BBI) {
Instruction *M;
if (UseMemMove)
M = Builder.CreateMemMove(
- SI->getPointerOperand(), findStoreAlignment(DL, SI),
- LI->getPointerOperand(), findLoadAlignment(DL, LI), Size);
+ SI->getPointerOperand(), SI->getAlign(),
+ LI->getPointerOperand(), LI->getAlign(), Size);
else
M = Builder.CreateMemCpy(
- SI->getPointerOperand(), findStoreAlignment(DL, SI),
- LI->getPointerOperand(), findLoadAlignment(DL, LI), Size);
+ SI->getPointerOperand(), SI->getAlign(),
+ LI->getPointerOperand(), LI->getAlign(), Size);
LLVM_DEBUG(dbgs() << "Promoting " << *LI << " to " << *SI << " => "
<< *M << "\n");
@@ -642,7 +619,7 @@ bool MemCpyOptPass::processStore(StoreInst *SI, BasicBlock::iterator &BBI) {
LI, SI->getPointerOperand()->stripPointerCasts(),
LI->getPointerOperand()->stripPointerCasts(),
DL.getTypeStoreSize(SI->getOperand(0)->getType()),
- findCommonAlignment(DL, SI, LI).value(), C);
+ commonAlignment(SI->getAlign(), LI->getAlign()), C);
if (changed) {
MD->removeInstruction(SI);
SI->eraseFromParent();
@@ -675,11 +652,9 @@ bool MemCpyOptPass::processStore(StoreInst *SI, BasicBlock::iterator &BBI) {
auto *T = V->getType();
if (T->isAggregateType()) {
uint64_t Size = DL.getTypeStoreSize(T);
- const Align MA =
- DL.getValueOrABITypeAlignment(MaybeAlign(SI->getAlignment()), T);
IRBuilder<> Builder(SI);
- auto *M =
- Builder.CreateMemSet(SI->getPointerOperand(), ByteVal, Size, MA);
+ auto *M = Builder.CreateMemSet(SI->getPointerOperand(), ByteVal, Size,
+ SI->getAlign());
LLVM_DEBUG(dbgs() << "Promoting " << *SI << " to " << *M << "\n");
@@ -713,7 +688,7 @@ bool MemCpyOptPass::processMemSet(MemSetInst *MSI, BasicBlock::iterator &BBI) {
/// the call write its result directly into the destination of the memcpy.
bool MemCpyOptPass::performCallSlotOptzn(Instruction *cpy, Value *cpyDest,
Value *cpySrc, uint64_t cpyLen,
- unsigned cpyAlign, CallInst *C) {
+ Align cpyAlign, CallInst *C) {
// The general transformation to keep in mind is
//
// call @func(..., src, ...)
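
The surrounding comment describes the call slot optimization; a hypothetical C++-level before/after (the struct and function names are invented for illustration) of the rewrite this function performs when all of its legality checks pass:

// Hypothetical illustration only.
#include <cstring>

struct S { int Data[16]; };
void compute(S *Out); // assumed to fully initialize *Out and not capture it

void before(S *Dest) {
  S Tmp;
  compute(&Tmp);
  std::memcpy(Dest, &Tmp, sizeof(S)); // Tmp exists only to feed this copy
}

void after(S *Dest) { // what the transformation effectively produces
  compute(Dest);      // the callee writes straight into the memcpy destination
}
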
@@ -733,10 +708,6 @@ bool MemCpyOptPass::performCallSlotOptzn(Instruction *cpy, Value *cpyDest,
if (F->isIntrinsic() && F->getIntrinsicID() == Intrinsic::lifetime_start)
return false;
- // Deliberately get the source and destination with bitcasts stripped away,
- // because we'll need to do type comparisons based on the underlying type.
- CallSite CS(C);
-
// Require that src be an alloca. This simplifies the reasoning considerably.
AllocaInst *srcAlloca = dyn_cast<AllocaInst>(cpySrc);
if (!srcAlloca)
@@ -795,9 +766,7 @@ bool MemCpyOptPass::performCallSlotOptzn(Instruction *cpy, Value *cpyDest,
}
// Check that dest points to memory that is at least as aligned as src.
- unsigned srcAlign = srcAlloca->getAlignment();
- if (!srcAlign)
- srcAlign = DL.getABITypeAlignment(srcAlloca->getAllocatedType());
+ Align srcAlign = srcAlloca->getAlign();
bool isDestSufficientlyAligned = srcAlign <= cpyAlign;
// If dest is not aligned enough and we can't increase its alignment then
// bail out.
@@ -836,8 +805,8 @@ bool MemCpyOptPass::performCallSlotOptzn(Instruction *cpy, Value *cpyDest,
// Check that src isn't captured by the called function since the
// transformation can cause aliasing issues in that case.
- for (unsigned i = 0, e = CS.arg_size(); i != e; ++i)
- if (CS.getArgument(i) == cpySrc && !CS.doesNotCapture(i))
+ for (unsigned ArgI = 0, E = C->arg_size(); ArgI != E; ++ArgI)
+ if (C->getArgOperand(ArgI) == cpySrc && !C->doesNotCapture(ArgI))
return false;
// Since we're changing the parameter to the callsite, we need to make sure
@@ -864,25 +833,26 @@ bool MemCpyOptPass::performCallSlotOptzn(Instruction *cpy, Value *cpyDest,
if (cpySrc->getType()->getPointerAddressSpace() !=
cpyDest->getType()->getPointerAddressSpace())
return false;
- for (unsigned i = 0; i < CS.arg_size(); ++i)
- if (CS.getArgument(i)->stripPointerCasts() == cpySrc &&
+ for (unsigned ArgI = 0; ArgI < C->arg_size(); ++ArgI)
+ if (C->getArgOperand(ArgI)->stripPointerCasts() == cpySrc &&
cpySrc->getType()->getPointerAddressSpace() !=
- CS.getArgument(i)->getType()->getPointerAddressSpace())
+ C->getArgOperand(ArgI)->getType()->getPointerAddressSpace())
return false;
// All the checks have passed, so do the transformation.
bool changedArgument = false;
- for (unsigned i = 0; i < CS.arg_size(); ++i)
- if (CS.getArgument(i)->stripPointerCasts() == cpySrc) {
+ for (unsigned ArgI = 0; ArgI < C->arg_size(); ++ArgI)
+ if (C->getArgOperand(ArgI)->stripPointerCasts() == cpySrc) {
Value *Dest = cpySrc->getType() == cpyDest->getType() ? cpyDest
: CastInst::CreatePointerCast(cpyDest, cpySrc->getType(),
cpyDest->getName(), C);
changedArgument = true;
- if (CS.getArgument(i)->getType() == Dest->getType())
- CS.setArgument(i, Dest);
+ if (C->getArgOperand(ArgI)->getType() == Dest->getType())
+ C->setArgOperand(ArgI, Dest);
else
- CS.setArgument(i, CastInst::CreatePointerCast(Dest,
- CS.getArgument(i)->getType(), Dest->getName(), C));
+ C->setArgOperand(ArgI, CastInst::CreatePointerCast(
+ Dest, C->getArgOperand(ArgI)->getType(),
+ Dest->getName(), C));
}
if (!changedArgument)
@@ -891,7 +861,7 @@ bool MemCpyOptPass::performCallSlotOptzn(Instruction *cpy, Value *cpyDest,
// If the destination wasn't sufficiently aligned then increase its alignment.
if (!isDestSufficientlyAligned) {
assert(isa<AllocaInst>(cpyDest) && "Can only increase alloca alignment!");
- cast<AllocaInst>(cpyDest)->setAlignment(MaybeAlign(srcAlign));
+ cast<AllocaInst>(cpyDest)->setAlignment(srcAlign);
}
// Drop any cached information about the call, because we may have changed
@@ -1127,15 +1097,16 @@ bool MemCpyOptPass::performMemCpyToMemSetOptzn(MemCpyInst *MemCpy,
/// B to be a memcpy from X to Z (or potentially a memmove, depending on
/// circumstances). This allows later passes to remove the first memcpy
/// altogether.
-bool MemCpyOptPass::processMemCpy(MemCpyInst *M) {
+bool MemCpyOptPass::processMemCpy(MemCpyInst *M, BasicBlock::iterator &BBI) {
// We can only optimize non-volatile memcpy's.
if (M->isVolatile()) return false;
// If the source and destination of the memcpy are the same, then zap it.
if (M->getSource() == M->getDest()) {
+ ++BBI;
MD->removeInstruction(M);
M->eraseFromParent();
- return false;
+ return true;
}
// If copying from a constant, try to turn the memcpy into a memset.
@@ -1176,10 +1147,10 @@ bool MemCpyOptPass::processMemCpy(MemCpyInst *M) {
if (CallInst *C = dyn_cast<CallInst>(DepInfo.getInst())) {
// FIXME: Can we pass in either of dest/src alignment here instead
// of conservatively taking the minimum?
- unsigned Align = MinAlign(M->getDestAlignment(), M->getSourceAlignment());
+ Align Alignment = std::min(M->getDestAlign().valueOrOne(),
+ M->getSourceAlign().valueOrOne());
if (performCallSlotOptzn(M, M->getDest(), M->getSource(),
- CopySize->getZExtValue(), Align,
- C)) {
+ CopySize->getZExtValue(), Alignment, C)) {
MD->removeInstruction(M);
M->eraseFromParent();
return true;
@@ -1247,15 +1218,15 @@ bool MemCpyOptPass::processMemMove(MemMoveInst *M) {
}
/// This is called on every byval argument in call sites.
-bool MemCpyOptPass::processByValArgument(CallSite CS, unsigned ArgNo) {
- const DataLayout &DL = CS.getCaller()->getParent()->getDataLayout();
+bool MemCpyOptPass::processByValArgument(CallBase &CB, unsigned ArgNo) {
+ const DataLayout &DL = CB.getCaller()->getParent()->getDataLayout();
// Find out what feeds this byval argument.
- Value *ByValArg = CS.getArgument(ArgNo);
+ Value *ByValArg = CB.getArgOperand(ArgNo);
Type *ByValTy = cast<PointerType>(ByValArg->getType())->getElementType();
uint64_t ByValSize = DL.getTypeAllocSize(ByValTy);
MemDepResult DepInfo = MD->getPointerDependencyFrom(
MemoryLocation(ByValArg, LocationSize::precise(ByValSize)), true,
- CS.getInstruction()->getIterator(), CS.getInstruction()->getParent());
+ CB.getIterator(), CB.getParent());
if (!DepInfo.isClobber())
return false;
@@ -1274,16 +1245,17 @@ bool MemCpyOptPass::processByValArgument(CallSite CS, unsigned ArgNo) {
// Get the alignment of the byval. If the call doesn't specify the alignment,
// then it is some target specific value that we can't know.
- unsigned ByValAlign = CS.getParamAlignment(ArgNo);
- if (ByValAlign == 0) return false;
+ MaybeAlign ByValAlign = CB.getParamAlign(ArgNo);
+ if (!ByValAlign) return false;
// If it is greater than the memcpy, then we check to see if we can force the
// source of the memcpy to the alignment we need. If we fail, we bail out.
AssumptionCache &AC = LookupAssumptionCache();
DominatorTree &DT = LookupDomTree();
- if (MDep->getSourceAlignment() < ByValAlign &&
- getOrEnforceKnownAlignment(MDep->getSource(), ByValAlign, DL,
- CS.getInstruction(), &AC, &DT) < ByValAlign)
+ MaybeAlign MemDepAlign = MDep->getSourceAlign();
+ if ((!MemDepAlign || *MemDepAlign < *ByValAlign) &&
+ getOrEnforceKnownAlignment(MDep->getSource(), ByValAlign, DL, &CB, &AC,
+ &DT) < *ByValAlign)
return false;
// The address space of the memcpy source must match the byval argument
@@ -1302,21 +1274,25 @@ bool MemCpyOptPass::processByValArgument(CallSite CS, unsigned ArgNo) {
// not just the defining memcpy.
MemDepResult SourceDep = MD->getPointerDependencyFrom(
MemoryLocation::getForSource(MDep), false,
- CS.getInstruction()->getIterator(), MDep->getParent());
+ CB.getIterator(), MDep->getParent());
if (!SourceDep.isClobber() || SourceDep.getInst() != MDep)
return false;
Value *TmpCast = MDep->getSource();
- if (MDep->getSource()->getType() != ByValArg->getType())
- TmpCast = new BitCastInst(MDep->getSource(), ByValArg->getType(),
- "tmpcast", CS.getInstruction());
+ if (MDep->getSource()->getType() != ByValArg->getType()) {
+ BitCastInst *TmpBitCast = new BitCastInst(MDep->getSource(), ByValArg->getType(),
+ "tmpcast", &CB);
+ // Set the tmpcast's DebugLoc to MDep's
+ TmpBitCast->setDebugLoc(MDep->getDebugLoc());
+ TmpCast = TmpBitCast;
+ }
LLVM_DEBUG(dbgs() << "MemCpyOptPass: Forwarding memcpy to byval:\n"
<< " " << *MDep << "\n"
- << " " << *CS.getInstruction() << "\n");
+ << " " << CB << "\n");
// Otherwise we're good! Update the byval argument.
- CS.setArgument(ArgNo, TmpCast);
+ CB.setArgOperand(ArgNo, TmpCast);
++NumMemCpyInstr;
return true;
}
@@ -1347,13 +1323,13 @@ bool MemCpyOptPass::iterateOnFunction(Function &F) {
else if (MemSetInst *M = dyn_cast<MemSetInst>(I))
RepeatInstruction = processMemSet(M, BI);
else if (MemCpyInst *M = dyn_cast<MemCpyInst>(I))
- RepeatInstruction = processMemCpy(M);
+ RepeatInstruction = processMemCpy(M, BI);
else if (MemMoveInst *M = dyn_cast<MemMoveInst>(I))
RepeatInstruction = processMemMove(M);
- else if (auto CS = CallSite(I)) {
- for (unsigned i = 0, e = CS.arg_size(); i != e; ++i)
- if (CS.isByValArgument(i))
- MadeChange |= processByValArgument(CS, i);
+ else if (auto *CB = dyn_cast<CallBase>(I)) {
+ for (unsigned i = 0, e = CB->arg_size(); i != e; ++i)
+ if (CB->isByValArgument(i))
+ MadeChange |= processByValArgument(*CB, i);
}
// Reprocess the instruction if desired.
diff --git a/llvm/lib/Transforms/Scalar/MergedLoadStoreMotion.cpp b/llvm/lib/Transforms/Scalar/MergedLoadStoreMotion.cpp
index 6b0d0202d9bb..69aa0cebe170 100644
--- a/llvm/lib/Transforms/Scalar/MergedLoadStoreMotion.cpp
+++ b/llvm/lib/Transforms/Scalar/MergedLoadStoreMotion.cpp
@@ -354,15 +354,11 @@ bool MergedLoadStoreMotion::run(Function &F, AliasAnalysis &AA) {
// optimization opportunities.
// This loop doesn't care about newly inserted/split blocks
// since they never will be diamond heads.
- for (Function::iterator FI = F.begin(), FE = F.end(); FI != FE;) {
- BasicBlock *BB = &*FI++;
-
+ for (BasicBlock &BB : make_early_inc_range(F))
// Hoist equivalent loads and sink stores
// outside diamonds when possible
- if (isDiamondHead(BB)) {
- Changed |= mergeStores(BB);
- }
- }
+ if (isDiamondHead(&BB))
+ Changed |= mergeStores(&BB);
return Changed;
}
diff --git a/llvm/lib/Transforms/Scalar/NaryReassociate.cpp b/llvm/lib/Transforms/Scalar/NaryReassociate.cpp
index bba9082e31b2..4e010f8704d0 100644
--- a/llvm/lib/Transforms/Scalar/NaryReassociate.cpp
+++ b/llvm/lib/Transforms/Scalar/NaryReassociate.cpp
@@ -213,7 +213,7 @@ bool NaryReassociatePass::runImpl(Function &F, AssumptionCache *AC_,
return Changed;
}
-// Whitelist the instruction types NaryReassociate handles for now.
+// Explicitly list the instruction types NaryReassociate handles for now.
static bool isPotentiallyNaryReassociable(Instruction *I) {
switch (I->getOpcode()) {
case Instruction::Add:
diff --git a/llvm/lib/Transforms/Scalar/NewGVN.cpp b/llvm/lib/Transforms/Scalar/NewGVN.cpp
index 6a643480f312..0ed1773373a7 100644
--- a/llvm/lib/Transforms/Scalar/NewGVN.cpp
+++ b/llvm/lib/Transforms/Scalar/NewGVN.cpp
@@ -106,6 +106,7 @@
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Scalar/GVNExpression.h"
+#include "llvm/Transforms/Utils/AssumeBundleBuilder.h"
#include "llvm/Transforms/Utils/Local.h"
#include "llvm/Transforms/Utils/PredicateInfo.h"
#include "llvm/Transforms/Utils/VNCoercion.h"
@@ -495,6 +496,7 @@ class NewGVN {
AliasAnalysis *AA = nullptr;
MemorySSA *MSSA = nullptr;
MemorySSAWalker *MSSAWalker = nullptr;
+ AssumptionCache *AC = nullptr;
const DataLayout &DL;
std::unique_ptr<PredicateInfo> PredInfo;
@@ -658,7 +660,7 @@ public:
NewGVN(Function &F, DominatorTree *DT, AssumptionCache *AC,
TargetLibraryInfo *TLI, AliasAnalysis *AA, MemorySSA *MSSA,
const DataLayout &DL)
- : F(F), DT(DT), TLI(TLI), AA(AA), MSSA(MSSA), DL(DL),
+ : F(F), DT(DT), TLI(TLI), AA(AA), MSSA(MSSA), AC(AC), DL(DL),
PredInfo(std::make_unique<PredicateInfo>(F, *DT, *AC)),
SQ(DL, TLI, DT, AC, /*CtxI=*/nullptr, /*UseInstrInfo=*/false) {}
@@ -898,7 +900,7 @@ bool NewGVN::isBackedge(BasicBlock *From, BasicBlock *To) const {
#ifndef NDEBUG
static std::string getBlockName(const BasicBlock *B) {
- return DOTGraphTraits<const Function *>::getSimpleNodeLabel(B, nullptr);
+ return DOTGraphTraits<DOTFuncInfo *>::getSimpleNodeLabel(B, nullptr);
}
#endif
@@ -1334,8 +1336,6 @@ LoadExpression *NewGVN::createLoadExpression(Type *LoadType, Value *PointerOp,
  // Give stores and loads the same opcode so they value number together.
E->setOpcode(0);
E->op_push_back(PointerOp);
- if (LI)
- E->setAlignment(MaybeAlign(LI->getAlignment()));
// TODO: Value number heap versions. We may be able to discover
  // things alias analysis can't on its own (i.e. that a store and a
@@ -1470,7 +1470,8 @@ NewGVN::performSymbolicLoadCoercion(Type *LoadType, Value *LoadPtr,
// undef value. This can happen when loading for a fresh allocation with no
// intervening stores, for example. Note that this is only true in the case
// that the result of the allocation is pointer equal to the load ptr.
- if (isa<AllocaInst>(DepInst) || isMallocLikeFn(DepInst, TLI)) {
+ if (isa<AllocaInst>(DepInst) || isMallocLikeFn(DepInst, TLI) ||
+ isAlignedAllocLikeFn(DepInst, TLI)) {
return createConstantExpression(UndefValue::get(LoadType));
}
// If this load occurs either right after a lifetime begin,
@@ -2030,10 +2031,12 @@ NewGVN::performSymbolicEvaluation(Value *V,
case Instruction::Select:
case Instruction::ExtractElement:
case Instruction::InsertElement:
- case Instruction::ShuffleVector:
case Instruction::GetElementPtr:
E = createExpression(I);
break;
+ case Instruction::ShuffleVector:
+ // FIXME: Add support for shufflevector to createExpression.
+ return nullptr;
default:
return nullptr;
}
@@ -3433,7 +3436,7 @@ bool NewGVN::runGVN() {
// Sort dominator tree children arrays into RPO.
for (auto &B : RPOT) {
auto *Node = DT->getNode(B);
- if (Node->getChildren().size() > 1)
+ if (Node->getNumChildren() > 1)
llvm::sort(Node->begin(), Node->end(),
[&](const DomTreeNode *A, const DomTreeNode *B) {
return RPOOrdering[A] < RPOOrdering[B];
@@ -3693,6 +3696,7 @@ void NewGVN::deleteInstructionsInBlock(BasicBlock *BB) {
Inst.replaceAllUsesWith(UndefValue::get(Inst.getType()));
if (isa<LandingPadInst>(Inst))
continue;
+ salvageKnowledge(&Inst, AC);
Inst.eraseFromParent();
++NumGVNInstrDeleted;
diff --git a/llvm/lib/Transforms/Scalar/PlaceSafepoints.cpp b/llvm/lib/Transforms/Scalar/PlaceSafepoints.cpp
index 5c4a89977c38..4553b23532f2 100644
--- a/llvm/lib/Transforms/Scalar/PlaceSafepoints.cpp
+++ b/llvm/lib/Transforms/Scalar/PlaceSafepoints.cpp
@@ -189,7 +189,8 @@ static bool needsStatepoint(CallBase *Call, const TargetLibraryInfo &TLI) {
return false;
}
- return !(isStatepoint(Call) || isGCRelocate(Call) || isGCResult(Call));
+ return !(isa<GCStatepointInst>(Call) || isa<GCRelocateInst>(Call) ||
+ isa<GCResultInst>(Call));
}
/// Returns true if this loop is known to contain a call safepoint which
@@ -650,7 +651,7 @@ InsertSafepointPoll(Instruction *InsertBefore,
// Do the actual inlining
InlineFunctionInfo IFI;
- bool InlineStatus = InlineFunction(PollCall, IFI);
+ bool InlineStatus = InlineFunction(*PollCall, IFI).isSuccess();
assert(InlineStatus && "inline must succeed");
(void)InlineStatus; // suppress warning in release-asserts
diff --git a/llvm/lib/Transforms/Scalar/Reassociate.cpp b/llvm/lib/Transforms/Scalar/Reassociate.cpp
index 41940e980faa..ba7f367267fe 100644
--- a/llvm/lib/Transforms/Scalar/Reassociate.cpp
+++ b/llvm/lib/Transforms/Scalar/Reassociate.cpp
@@ -29,6 +29,7 @@
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
+#include "llvm/Analysis/BasicAliasAnalysis.h"
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/Argument.h"
@@ -254,15 +255,15 @@ static BinaryOperator *CreateMul(Value *S1, Value *S2, const Twine &Name,
}
}
-static BinaryOperator *CreateNeg(Value *S1, const Twine &Name,
- Instruction *InsertBefore, Value *FlagsOp) {
+static Instruction *CreateNeg(Value *S1, const Twine &Name,
+ Instruction *InsertBefore, Value *FlagsOp) {
if (S1->getType()->isIntOrIntVectorTy())
return BinaryOperator::CreateNeg(S1, Name, InsertBefore);
- else {
- BinaryOperator *Res = BinaryOperator::CreateFNeg(S1, Name, InsertBefore);
- Res->setFastMathFlags(cast<FPMathOperator>(FlagsOp)->getFastMathFlags());
- return Res;
- }
+
+ if (auto *FMFSource = dyn_cast<Instruction>(FlagsOp))
+ return UnaryOperator::CreateFNegFMF(S1, FMFSource, Name, InsertBefore);
+
+ return UnaryOperator::CreateFNeg(S1, Name, InsertBefore);
}
/// Replace 0-X with X*-1.
@@ -914,7 +915,7 @@ static Value *NegateValue(Value *V, Instruction *BI,
// Insert a 'neg' instruction that subtracts the value from zero to get the
// negation.
- BinaryOperator *NewNeg = CreateNeg(V, V->getName() + ".neg", BI, BI);
+ Instruction *NewNeg = CreateNeg(V, V->getName() + ".neg", BI, BI);
ToRedo.insert(NewNeg);
return NewNeg;
}
@@ -975,7 +976,8 @@ static BinaryOperator *BreakUpSubtract(Instruction *Sub,
/// this into a multiply by a constant to assist with further reassociation.
static BinaryOperator *ConvertShiftToMul(Instruction *Shl) {
Constant *MulCst = ConstantInt::get(Shl->getType(), 1);
- MulCst = ConstantExpr::getShl(MulCst, cast<Constant>(Shl->getOperand(1)));
+ auto *SA = cast<ConstantInt>(Shl->getOperand(1));
+ MulCst = ConstantExpr::getShl(MulCst, SA);
BinaryOperator *Mul =
BinaryOperator::CreateMul(Shl->getOperand(0), MulCst, "", Shl);
@@ -988,10 +990,12 @@ static BinaryOperator *ConvertShiftToMul(Instruction *Shl) {
// We can safely preserve the nuw flag in all cases. It's also safe to turn a
// nuw nsw shl into a nuw nsw mul. However, nsw in isolation requires special
- // handling.
+ // handling. It can be preserved as long as we're not left shifting by
+ // bitwidth - 1.
bool NSW = cast<BinaryOperator>(Shl)->hasNoSignedWrap();
bool NUW = cast<BinaryOperator>(Shl)->hasNoUnsignedWrap();
- if (NSW && NUW)
+ unsigned BitWidth = Shl->getType()->getIntegerBitWidth();
+ if (NSW && (NUW || SA->getValue().ult(BitWidth - 1)))
Mul->setHasNoSignedWrap(true);
Mul->setHasNoUnsignedWrap(NUW);
return Mul;
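The nsw comment above can be checked with a small worked example; the following standalone sketch (plain C++ with an 8-bit type, not LLVM code) shows why an nsw-only shift by bitwidth - 1 cannot keep nsw once rewritten as a multiply:

#include <cstdint>
#include <cstdio>

// Why "shl nsw x, 7" on i8 must drop nsw when rewritten as a multiply: the
// multiplier 1 << 7 == 128 wraps to -128 once materialized as an i8 constant.
int main() {
  int8_t X = -1;
  int ShlResult = int(X) * 128;   // shl x, 7 == -128, representable in i8,
                                  // so the original shift did not signed-wrap
  int MulResult = int(X) * -128;  // mul x, -128 == 128, NOT representable in
                                  // i8, so the multiply would signed-wrap
  std::printf("shl: %d (fits i8)  mul: %d (overflows i8)\n", ShlResult,
              MulResult);
  return 0;
}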
@@ -1076,7 +1080,7 @@ Value *ReassociatePass::RemoveFactorFromExpression(Value *V, Value *Factor) {
const APFloat &F1 = FC1->getValueAPF();
APFloat F2(FC2->getValueAPF());
F2.changeSign();
- if (F1.compare(F2) == APFloat::cmpEqual) {
+ if (F1 == F2) {
FoundFactor = NeedsNegate = true;
Factors.erase(Factors.begin() + i);
break;
@@ -1721,7 +1725,7 @@ static bool collectMultiplyFactors(SmallVectorImpl<ValueEntry> &Ops,
}
/// Build a tree of multiplies, computing the product of Ops.
-static Value *buildMultiplyTree(IRBuilder<> &Builder,
+static Value *buildMultiplyTree(IRBuilderBase &Builder,
SmallVectorImpl<Value*> &Ops) {
if (Ops.size() == 1)
return Ops.back();
@@ -1744,7 +1748,7 @@ static Value *buildMultiplyTree(IRBuilder<> &Builder,
/// DAG of multiplies to compute the final product, and return that product
/// value.
Value *
-ReassociatePass::buildMinimalMultiplyDAG(IRBuilder<> &Builder,
+ReassociatePass::buildMinimalMultiplyDAG(IRBuilderBase &Builder,
SmallVectorImpl<Factor> &Factors) {
assert(Factors[0].Power);
SmallVector<Value *, 4> OuterProduct;
@@ -1899,7 +1903,7 @@ void ReassociatePass::RecursivelyEraseDeadInsts(Instruction *I,
ValueRankMap.erase(I);
Insts.remove(I);
RedoInsts.remove(I);
- llvm::salvageDebugInfoOrMarkUndef(*I);
+ llvm::salvageDebugInfo(*I);
I->eraseFromParent();
for (auto Op : Ops)
if (Instruction *OpInst = dyn_cast<Instruction>(Op))
@@ -1916,7 +1920,7 @@ void ReassociatePass::EraseInst(Instruction *I) {
// Erase the dead instruction.
ValueRankMap.erase(I);
RedoInsts.remove(I);
- llvm::salvageDebugInfoOrMarkUndef(*I);
+ llvm::salvageDebugInfo(*I);
I->eraseFromParent();
// Optimize its operands.
SmallPtrSet<Instruction *, 8> Visited; // Detect self-referential nodes.
@@ -2457,6 +2461,8 @@ PreservedAnalyses ReassociatePass::run(Function &F, FunctionAnalysisManager &) {
if (MadeChange) {
PreservedAnalyses PA;
PA.preserveSet<CFGAnalyses>();
+ PA.preserve<AAManager>();
+ PA.preserve<BasicAA>();
PA.preserve<GlobalsAA>();
return PA;
}
@@ -2487,6 +2493,8 @@ namespace {
void getAnalysisUsage(AnalysisUsage &AU) const override {
AU.setPreservesCFG();
+ AU.addPreserved<AAResultsWrapperPass>();
+ AU.addPreserved<BasicAAWrapperPass>();
AU.addPreserved<GlobalsAAWrapperPass>();
}
};
diff --git a/llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp b/llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp
index b242f100faff..dc2ad14ae61e 100644
--- a/llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp
+++ b/llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp
@@ -271,7 +271,7 @@ struct PartiallyConstructedSafepointRecord {
/// The *new* gc.statepoint instruction itself. This produces the token
/// that normal path gc.relocates and the gc.result are tied to.
- Instruction *StatepointToken;
+ GCStatepointInst *StatepointToken;
/// Instruction to which exceptional gc relocates are attached
/// Makes it easier to iterate through them during relocationViaAlloca.
@@ -381,14 +381,19 @@ static void analyzeParsePointLiveness(
dbgs() << " " << V->getName() << " " << *V << "\n";
}
if (PrintLiveSetSize) {
- dbgs() << "Safepoint For: " << Call->getCalledValue()->getName() << "\n";
+ dbgs() << "Safepoint For: " << Call->getCalledOperand()->getName() << "\n";
dbgs() << "Number live values: " << LiveSet.size() << "\n";
}
Result.LiveSet = LiveSet;
}
+// Returns true if V is a knownBaseResult.
static bool isKnownBaseResult(Value *V);
+// Returns true if V is a BaseResult that already exists in the IR, i.e. it is
+// not created by the findBasePointers algorithm.
+static bool isOriginalBaseResult(Value *V);
+
namespace {
/// A single base defining value - An immediate base defining value for an
@@ -633,15 +638,20 @@ static Value *findBaseOrBDV(Value *I, DefiningValueMapTy &Cache) {
return Def;
}
+/// This value is a base pointer that is not generated by RS4GC, i.e. it already
+/// exists in the code.
+static bool isOriginalBaseResult(Value *V) {
+ // no recursion possible
+ return !isa<PHINode>(V) && !isa<SelectInst>(V) &&
+ !isa<ExtractElementInst>(V) && !isa<InsertElementInst>(V) &&
+ !isa<ShuffleVectorInst>(V);
+}
+
/// Given the result of a call to findBaseDefiningValue, or findBaseOrBDV,
/// is it known to be a base pointer? Or do we need to continue searching.
static bool isKnownBaseResult(Value *V) {
- if (!isa<PHINode>(V) && !isa<SelectInst>(V) &&
- !isa<ExtractElementInst>(V) && !isa<InsertElementInst>(V) &&
- !isa<ShuffleVectorInst>(V)) {
- // no recursion possible
+ if (isOriginalBaseResult(V))
return true;
- }
if (isa<Instruction>(V) &&
cast<Instruction>(V)->getMetadata("is_base_value")) {
// This is a previously inserted base phi or select. We know
@@ -653,6 +663,12 @@ static bool isKnownBaseResult(Value *V) {
return false;
}
+// Returns true if First and Second values are both scalar or both vector.
+static bool areBothVectorOrScalar(Value *First, Value *Second) {
+ return isa<VectorType>(First->getType()) ==
+ isa<VectorType>(Second->getType());
+}
+
namespace {
/// Models the state of a single base defining value in the findBasePointer
@@ -762,7 +778,7 @@ static BDVState meetBDVState(const BDVState &LHS, const BDVState &RHS) {
static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache) {
Value *Def = findBaseOrBDV(I, Cache);
- if (isKnownBaseResult(Def))
+ if (isKnownBaseResult(Def) && areBothVectorOrScalar(Def, I))
return Def;
// Here's the rough algorithm:
@@ -810,13 +826,16 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache) {
States.insert({Def, BDVState()});
while (!Worklist.empty()) {
Value *Current = Worklist.pop_back_val();
- assert(!isKnownBaseResult(Current) && "why did it get added?");
+ assert(!isOriginalBaseResult(Current) && "why did it get added?");
auto visitIncomingValue = [&](Value *InVal) {
Value *Base = findBaseOrBDV(InVal, Cache);
- if (isKnownBaseResult(Base))
+ if (isKnownBaseResult(Base) && areBothVectorOrScalar(Base, InVal))
// Known bases won't need new instructions introduced and can be
- // ignored safely
+ // ignored safely. However, this can only be done when InVal and Base
+ // are both scalar or both vector. Otherwise, we need to find a
+ // correct BDV for InVal, by creating an entry in the lattice
+ // (States).
return;
assert(isExpectedBDVType(Base) && "the only non-base values "
"we see should be base defining values");
@@ -853,10 +872,10 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache) {
// Return a phi state for a base defining value. We'll generate a new
// base state for known bases and expect to find a cached state otherwise.
- auto getStateForBDV = [&](Value *baseValue) {
- if (isKnownBaseResult(baseValue))
- return BDVState(baseValue);
- auto I = States.find(baseValue);
+ auto GetStateForBDV = [&](Value *BaseValue, Value *Input) {
+ if (isKnownBaseResult(BaseValue) && areBothVectorOrScalar(BaseValue, Input))
+ return BDVState(BaseValue);
+ auto I = States.find(BaseValue);
assert(I != States.end() && "lookup failed!");
return I->second;
};
@@ -873,13 +892,18 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache) {
// much faster.
for (auto Pair : States) {
Value *BDV = Pair.first;
- assert(!isKnownBaseResult(BDV) && "why did it get added?");
+ // Only values that do not have known bases or those that have differing
+ // type (scalar versus vector) from a possible known base should be in the
+ // lattice.
+ assert((!isKnownBaseResult(BDV) ||
+ !areBothVectorOrScalar(BDV, Pair.second.getBaseValue())) &&
+ "why did it get added?");
// Given an input value for the current instruction, return a BDVState
// instance which represents the BDV of that value.
auto getStateForInput = [&](Value *V) mutable {
Value *BDV = findBaseOrBDV(V, Cache);
- return getStateForBDV(BDV);
+ return GetStateForBDV(BDV, V);
};
BDVState NewState;
@@ -926,20 +950,26 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache) {
}
#endif
- // Insert Phis for all conflicts
- // TODO: adjust naming patterns to avoid this order of iteration dependency
+ // Handle all instructions that have a vector BDV, but the instruction itself
+ // is of scalar type.
for (auto Pair : States) {
Instruction *I = cast<Instruction>(Pair.first);
BDVState State = Pair.second;
- assert(!isKnownBaseResult(I) && "why did it get added?");
+ auto *BaseValue = State.getBaseValue();
+ // Only values that do not have known bases or those that have differing
+ // type (scalar versus vector) from a possible known base should be in the
+ // lattice.
+ assert((!isKnownBaseResult(I) || !areBothVectorOrScalar(I, BaseValue)) &&
+ "why did it get added?");
assert(!State.isUnknown() && "Optimistic algorithm didn't complete!");
+ if (!State.isBase() || !isa<VectorType>(BaseValue->getType()))
+ continue;
// extractelement instructions are a bit special in that we may need to
// insert an extract even when we know an exact base for the instruction.
// The problem is that we need to convert from a vector base to a scalar
// base for the particular index we're interested in.
- if (State.isBase() && isa<ExtractElementInst>(I) &&
- isa<VectorType>(State.getBaseValue()->getType())) {
+ if (isa<ExtractElementInst>(I)) {
auto *EE = cast<ExtractElementInst>(I);
// TODO: In many cases, the new instruction is just EE itself. We should
// exploit this, but can't do it here since it would break the invariant
@@ -948,7 +978,27 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache) {
State.getBaseValue(), EE->getIndexOperand(), "base_ee", EE);
BaseInst->setMetadata("is_base_value", MDNode::get(I->getContext(), {}));
States[I] = BDVState(BDVState::Base, BaseInst);
+ } else if (!isa<VectorType>(I->getType())) {
+ // We need to handle cases that have a vector base but the instruction is
+ // a scalar type (these could be phis or selects or any instructions that
+ // are of scalar type, but the base can be a vector type). We
+ // conservatively set this as conflict. Setting the base value for these
+ // conflicts is handled in the next loop which traverses States.
+ States[I] = BDVState(BDVState::Conflict);
}
+ }
+
+ // Insert Phis for all conflicts
+ // TODO: adjust naming patterns to avoid this order of iteration dependency
+ for (auto Pair : States) {
+ Instruction *I = cast<Instruction>(Pair.first);
+ BDVState State = Pair.second;
+ // Only values that do not have known bases or those that have differing
+ // type (scalar versus vector) from a possible known base should be in the
+ // lattice.
+ assert((!isKnownBaseResult(I) || !areBothVectorOrScalar(I, State.getBaseValue())) &&
+ "why did it get added?");
+ assert(!State.isUnknown() && "Optimistic algorithm didn't complete!");
// Since we're joining a vector and scalar base, they can never be the
// same. As a result, we should always see insert element having reached
@@ -987,7 +1037,7 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache) {
auto *SV = cast<ShuffleVectorInst>(I);
UndefValue *VecUndef = UndefValue::get(SV->getOperand(0)->getType());
std::string Name = suffixed_name_or(I, ".base", "base_sv");
- return new ShuffleVectorInst(VecUndef, VecUndef, SV->getOperand(2),
+ return new ShuffleVectorInst(VecUndef, VecUndef, SV->getShuffleMask(),
Name, SV);
}
};
@@ -1008,7 +1058,7 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache) {
auto getBaseForInput = [&](Value *Input, Instruction *InsertPt) {
Value *BDV = findBaseOrBDV(Input, Cache);
Value *Base = nullptr;
- if (isKnownBaseResult(BDV)) {
+ if (isKnownBaseResult(BDV) && areBothVectorOrScalar(BDV, Input)) {
Base = BDV;
} else {
// Either conflict or base.
@@ -1029,7 +1079,12 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache) {
Instruction *BDV = cast<Instruction>(Pair.first);
BDVState State = Pair.second;
- assert(!isKnownBaseResult(BDV) && "why did it get added?");
+ // Only values that do not have known bases or those that have differing
+ // type (scalar versus vector) from a possible known base should be in the
+ // lattice.
+ assert((!isKnownBaseResult(BDV) ||
+ !areBothVectorOrScalar(BDV, State.getBaseValue())) &&
+ "why did it get added?");
assert(!State.isUnknown() && "Optimistic algorithm didn't complete!");
if (!State.isConflict())
continue;
@@ -1119,7 +1174,11 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache) {
auto *BDV = Pair.first;
Value *Base = Pair.second.getBaseValue();
assert(BDV && Base);
- assert(!isKnownBaseResult(BDV) && "why did it get added?");
+ // Only values that do not have known bases or those that have differing
+ // type (scalar versus vector) from a possible known base should be in the
+ // lattice.
+ assert((!isKnownBaseResult(BDV) || !areBothVectorOrScalar(BDV, Base)) &&
+ "why did it get added?");
LLVM_DEBUG(
dbgs() << "Updating base value cache"
@@ -1238,7 +1297,8 @@ normalizeForInvokeSafepoint(BasicBlock *BB, BasicBlock *InvokeParent,
// Create new attribute set containing only attributes which can be transferred
// from original call to the safepoint.
-static AttributeList legalizeCallAttributes(AttributeList AL) {
+static AttributeList legalizeCallAttributes(LLVMContext &Ctx,
+ AttributeList AL) {
if (AL.isEmpty())
return AL;
@@ -1252,7 +1312,6 @@ static AttributeList legalizeCallAttributes(AttributeList AL) {
}
// Just skip parameter and return attributes for now
- LLVMContext &Ctx = AL.getContext();
return AttributeList::get(Ctx, AttributeList::FunctionIndex,
AttributeSet::get(Ctx, FnAttrs));
}
@@ -1261,16 +1320,14 @@ static AttributeList legalizeCallAttributes(AttributeList AL) {
/// statepoint.
/// Inputs:
/// liveVariables - list of variables to be relocated.
-/// liveStart - index of the first live variable.
/// basePtrs - base pointers.
/// statepointToken - statepoint instruction to which relocates should be
/// bound.
/// Builder - Llvm IR builder to be used to construct new calls.
static void CreateGCRelocates(ArrayRef<Value *> LiveVariables,
- const int LiveStart,
ArrayRef<Value *> BasePtrs,
Instruction *StatepointToken,
- IRBuilder<> Builder) {
+ IRBuilder<> &Builder) {
if (LiveVariables.empty())
return;
@@ -1295,7 +1352,8 @@ static void CreateGCRelocates(ArrayRef<Value *> LiveVariables,
auto AS = Ty->getScalarType()->getPointerAddressSpace();
Type *NewTy = Type::getInt8PtrTy(M->getContext(), AS);
if (auto *VT = dyn_cast<VectorType>(Ty))
- NewTy = VectorType::get(NewTy, VT->getNumElements());
+ NewTy = FixedVectorType::get(NewTy,
+ cast<FixedVectorType>(VT)->getNumElements());
return Intrinsic::getDeclaration(M, Intrinsic::experimental_gc_relocate,
{NewTy});
};
@@ -1307,9 +1365,8 @@ static void CreateGCRelocates(ArrayRef<Value *> LiveVariables,
for (unsigned i = 0; i < LiveVariables.size(); i++) {
// Generate the gc.relocate call and save the result
- Value *BaseIdx =
- Builder.getInt32(LiveStart + FindIndex(LiveVariables, BasePtrs[i]));
- Value *LiveIdx = Builder.getInt32(LiveStart + i);
+ Value *BaseIdx = Builder.getInt32(FindIndex(LiveVariables, BasePtrs[i]));
+ Value *LiveIdx = Builder.getInt32(i);
Type *Ty = LiveVariables[i]->getType();
if (!TypeToDeclMap.count(Ty))
@@ -1431,12 +1488,14 @@ makeStatepointExplicitImpl(CallBase *Call, /* to replace */
uint32_t Flags = uint32_t(StatepointFlags::None);
ArrayRef<Use> CallArgs(Call->arg_begin(), Call->arg_end());
- ArrayRef<Use> DeoptArgs = GetDeoptBundleOperands(Call);
- ArrayRef<Use> TransitionArgs;
- if (auto TransitionBundle =
- Call->getOperandBundle(LLVMContext::OB_gc_transition)) {
+ Optional<ArrayRef<Use>> DeoptArgs;
+ if (auto Bundle = Call->getOperandBundle(LLVMContext::OB_deopt))
+ DeoptArgs = Bundle->Inputs;
+ Optional<ArrayRef<Use>> TransitionArgs;
+ if (auto Bundle = Call->getOperandBundle(LLVMContext::OB_gc_transition)) {
+ TransitionArgs = Bundle->Inputs;
+ // TODO: This flag no longer serves a purpose and can be removed later
Flags |= uint32_t(StatepointFlags::GCTransition);
- TransitionArgs = TransitionBundle->Inputs;
}
// Instead of lowering calls to @llvm.experimental.deoptimize as normal calls
@@ -1459,7 +1518,7 @@ makeStatepointExplicitImpl(CallBase *Call, /* to replace */
assert(DeoptLowering.equals("live-through") && "Unsupported value!");
}
- Value *CallTarget = Call->getCalledValue();
+ Value *CallTarget = Call->getCalledOperand();
if (Function *F = dyn_cast<Function>(CallTarget)) {
if (F->getIntrinsicID() == Intrinsic::experimental_deoptimize) {
// Calls to llvm.experimental.deoptimize are lowered to calls to the
@@ -1485,7 +1544,7 @@ makeStatepointExplicitImpl(CallBase *Call, /* to replace */
}
// Create the statepoint given all the arguments
- Instruction *Token = nullptr;
+ GCStatepointInst *Token = nullptr;
if (auto *CI = dyn_cast<CallInst>(Call)) {
CallInst *SPCall = Builder.CreateGCStatepointCall(
StatepointID, NumPatchBytes, CallTarget, Flags, CallArgs,
@@ -1498,9 +1557,10 @@ makeStatepointExplicitImpl(CallBase *Call, /* to replace */
// function attributes. If we can handle this set of attributes, set up
// function attrs directly on the statepoint and return attrs later for the
// gc_result intrinsic.
- SPCall->setAttributes(legalizeCallAttributes(CI->getAttributes()));
+ SPCall->setAttributes(
+ legalizeCallAttributes(CI->getContext(), CI->getAttributes()));
- Token = SPCall;
+ Token = cast<GCStatepointInst>(SPCall);
// Put the following gc_result and gc_relocate calls immediately after
// the old call (which we're about to delete)
@@ -1524,9 +1584,10 @@ makeStatepointExplicitImpl(CallBase *Call, /* to replace */
// function attributes. If we can handle this set of attributes, set up
// function attrs directly on the statepoint and return attrs later for the
// gc_result intrinsic.
- SPInvoke->setAttributes(legalizeCallAttributes(II->getAttributes()));
+ SPInvoke->setAttributes(
+ legalizeCallAttributes(II->getContext(), II->getAttributes()));
- Token = SPInvoke;
+ Token = cast<GCStatepointInst>(SPInvoke);
// Generate gc relocates in exceptional path
BasicBlock *UnwindBlock = II->getUnwindDest();
@@ -1541,9 +1602,7 @@ makeStatepointExplicitImpl(CallBase *Call, /* to replace */
Instruction *ExceptionalToken = UnwindBlock->getLandingPadInst();
Result.UnwindToken = ExceptionalToken;
- const unsigned LiveStartIdx = Statepoint(Token).gcArgsStartIdx();
- CreateGCRelocates(LiveVariables, LiveStartIdx, BasePtrs, ExceptionalToken,
- Builder);
+ CreateGCRelocates(LiveVariables, BasePtrs, ExceptionalToken, Builder);
// Generate gc relocates and returns for normal block
BasicBlock *NormalDest = II->getNormalDest();
@@ -1589,8 +1648,7 @@ makeStatepointExplicitImpl(CallBase *Call, /* to replace */
Result.StatepointToken = Token;
// Second, create a gc.relocate for every live variable
- const unsigned LiveStartIdx = Statepoint(Token).gcArgsStartIdx();
- CreateGCRelocates(LiveVariables, LiveStartIdx, BasePtrs, Token, Builder);
+ CreateGCRelocates(LiveVariables, BasePtrs, Token, Builder);
}
// Replace an existing gc.statepoint with a new one and a set of gc.relocates
@@ -1651,8 +1709,8 @@ insertRelocationStores(iterator_range<Value::user_iterator> GCRelocs,
cast<AllocaInst>(Alloca)->getAllocatedType(),
suffixed_name_or(Relocate, ".casted", ""));
- StoreInst *Store = new StoreInst(CastedRelocatedValue, Alloca);
- Store->insertAfter(cast<Instruction>(CastedRelocatedValue));
+ new StoreInst(CastedRelocatedValue, Alloca,
+ cast<Instruction>(CastedRelocatedValue)->getNextNode());
#ifndef NDEBUG
VisitedLiveValues.insert(OriginalValue);
@@ -1674,8 +1732,8 @@ static void insertRematerializationStores(
"Can not find alloca for rematerialized value");
Value *Alloca = AllocaMap[OriginalValue];
- StoreInst *Store = new StoreInst(RematerializedValue, Alloca);
- Store->insertAfter(RematerializedValue);
+ new StoreInst(RematerializedValue, Alloca,
+ RematerializedValue->getNextNode());
#ifndef NDEBUG
VisitedLiveValues.insert(OriginalValue);
@@ -1780,8 +1838,7 @@ static void relocationViaAlloca(
for (auto *AI : ToClobber) {
auto PT = cast<PointerType>(AI->getAllocatedType());
Constant *CPN = ConstantPointerNull::get(PT);
- StoreInst *Store = new StoreInst(CPN, AI);
- Store->insertBefore(IP);
+ new StoreInst(CPN, AI, IP);
}
};
@@ -1843,7 +1900,8 @@ static void relocationViaAlloca(
// Emit store for the initial gc value. Store must be inserted after load,
// otherwise store will be in alloca's use list and an extra load will be
// inserted before it.
- StoreInst *Store = new StoreInst(Def, Alloca);
+ StoreInst *Store = new StoreInst(Def, Alloca, /*volatile*/ false,
+ DL.getABITypeAlign(Def->getType()));
if (Instruction *Inst = dyn_cast<Instruction>(Def)) {
if (InvokeInst *Invoke = dyn_cast<InvokeInst>(Inst)) {
// InvokeInst is a terminator so the store needs to be inserted into its
@@ -1966,7 +2024,9 @@ chainToBasePointerCost(SmallVectorImpl<Instruction*> &Chain,
"non noop cast is found during rematerialization");
Type *SrcTy = CI->getOperand(0)->getType();
- Cost += TTI.getCastInstrCost(CI->getOpcode(), CI->getType(), SrcTy, CI);
+ Cost += TTI.getCastInstrCost(CI->getOpcode(), CI->getType(), SrcTy,
+ TargetTransformInfo::TCK_SizeAndLatency,
+ CI);
} else if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Instr)) {
// Cost of the address calculation
@@ -2344,9 +2404,8 @@ static bool insertParsePoints(Function &F, DominatorTree &DT,
// That Value* no longer exists and we need to use the new gc_result.
// Thankfully, the live set is embedded in the statepoint (and updated), so
// we just grab that.
- Statepoint Statepoint(Info.StatepointToken);
- Live.insert(Live.end(), Statepoint.gc_args_begin(),
- Statepoint.gc_args_end());
+ Live.insert(Live.end(), Info.StatepointToken->gc_args_begin(),
+ Info.StatepointToken->gc_args_end());
#ifndef NDEBUG
// Do some basic sanity checks on our liveness results before performing
// relocation. Relocation can and will turn mistakes in liveness results
@@ -2354,7 +2413,7 @@ static bool insertParsePoints(Function &F, DominatorTree &DT,
// TODO: It would be nice to test consistency as well
assert(DT.isReachableFromEntry(Info.StatepointToken->getParent()) &&
"statepoint must be reachable or liveness is meaningless");
- for (Value *V : Statepoint.gc_args()) {
+ for (Value *V : Info.StatepointToken->gc_args()) {
if (!isa<Instruction>(V))
// Non-instruction values trivially dominate all possible uses
continue;
@@ -2523,7 +2582,7 @@ bool RewriteStatepointsForGC::runOnFunction(Function &F, DominatorTree &DT,
auto NeedsRewrite = [&TLI](Instruction &I) {
if (const auto *Call = dyn_cast<CallBase>(&I))
- return !callsGCLeafFunction(Call, TLI) && !isStatepoint(Call);
+ return !callsGCLeafFunction(Call, TLI) && !isa<GCStatepointInst>(Call);
return false;
};
@@ -2608,10 +2667,10 @@ bool RewriteStatepointsForGC::runOnFunction(Function &F, DominatorTree &DT,
unsigned VF = 0;
for (unsigned i = 0; i < I.getNumOperands(); i++)
- if (I.getOperand(i)->getType()->isVectorTy()) {
+ if (auto *OpndVTy = dyn_cast<VectorType>(I.getOperand(i)->getType())) {
assert(VF == 0 ||
- VF == I.getOperand(i)->getType()->getVectorNumElements());
- VF = I.getOperand(i)->getType()->getVectorNumElements();
+ VF == cast<FixedVectorType>(OpndVTy)->getNumElements());
+ VF = cast<FixedVectorType>(OpndVTy)->getNumElements();
}
// It's the vector to scalar traversal through the pointer operand which
diff --git a/llvm/lib/Transforms/Scalar/SCCP.cpp b/llvm/lib/Transforms/Scalar/SCCP.cpp
index e696ea83a300..5ebd3b71fe78 100644
--- a/llvm/lib/Transforms/Scalar/SCCP.cpp
+++ b/llvm/lib/Transforms/Scalar/SCCP.cpp
@@ -27,12 +27,13 @@
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/ConstantFolding.h"
+#include "llvm/Analysis/DomTreeUpdater.h"
#include "llvm/Analysis/GlobalsModRef.h"
+#include "llvm/Analysis/InstructionSimplify.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/ValueLattice.h"
#include "llvm/Analysis/ValueLatticeUtils.h"
#include "llvm/IR/BasicBlock.h"
-#include "llvm/IR/CallSite.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
@@ -67,123 +68,44 @@ using namespace llvm;
STATISTIC(NumInstRemoved, "Number of instructions removed");
STATISTIC(NumDeadBlocks , "Number of basic blocks unreachable");
+STATISTIC(NumInstReplaced,
+ "Number of instructions replaced with (simpler) instruction");
STATISTIC(IPNumInstRemoved, "Number of instructions removed by IPSCCP");
STATISTIC(IPNumArgsElimed ,"Number of arguments constant propagated by IPSCCP");
STATISTIC(IPNumGlobalConst, "Number of globals found to be constant by IPSCCP");
-
+STATISTIC(
+ IPNumInstReplaced,
+ "Number of instructions replaced with (simpler) instruction by IPSCCP");
+
+// The maximum number of range extensions allowed for operations requiring
+// widening.
+static const unsigned MaxNumRangeExtensions = 10;
+
+/// Returns MergeOptions with MaxWidenSteps set to MaxNumRangeExtensions.
+static ValueLatticeElement::MergeOptions getMaxWidenStepsOpts() {
+ return ValueLatticeElement::MergeOptions().setMaxWidenSteps(
+ MaxNumRangeExtensions);
+}
namespace {
-/// LatticeVal class - This class represents the different lattice values that
-/// an LLVM value may occupy. It is a simple class with value semantics.
-///
-class LatticeVal {
- enum LatticeValueTy {
- /// unknown - This LLVM Value has no known value yet.
- unknown,
-
- /// constant - This LLVM Value has a specific constant value.
- constant,
-
- /// forcedconstant - This LLVM Value was thought to be undef until
- /// ResolvedUndefsIn. This is treated just like 'constant', but if merged
- /// with another (different) constant, it goes to overdefined, instead of
- /// asserting.
- forcedconstant,
-
- /// overdefined - This instruction is not known to be constant, and we know
- /// it has a value.
- overdefined
- };
-
- /// Val: This stores the current lattice value along with the Constant* for
- /// the constant if this is a 'constant' or 'forcedconstant' value.
- PointerIntPair<Constant *, 2, LatticeValueTy> Val;
-
- LatticeValueTy getLatticeValue() const {
- return Val.getInt();
- }
-
-public:
- LatticeVal() : Val(nullptr, unknown) {}
-
- bool isUnknown() const { return getLatticeValue() == unknown; }
-
- bool isConstant() const {
- return getLatticeValue() == constant || getLatticeValue() == forcedconstant;
- }
-
- bool isOverdefined() const { return getLatticeValue() == overdefined; }
-
- Constant *getConstant() const {
- assert(isConstant() && "Cannot get the constant of a non-constant!");
- return Val.getPointer();
- }
-
- /// markOverdefined - Return true if this is a change in status.
- bool markOverdefined() {
- if (isOverdefined())
- return false;
-
- Val.setInt(overdefined);
- return true;
- }
-
- /// markConstant - Return true if this is a change in status.
- bool markConstant(Constant *V) {
- if (getLatticeValue() == constant) { // Constant but not forcedconstant.
- assert(getConstant() == V && "Marking constant with different value");
- return false;
- }
-
- if (isUnknown()) {
- Val.setInt(constant);
- assert(V && "Marking constant with NULL");
- Val.setPointer(V);
- } else {
- assert(getLatticeValue() == forcedconstant &&
- "Cannot move from overdefined to constant!");
- // Stay at forcedconstant if the constant is the same.
- if (V == getConstant()) return false;
-
- // Otherwise, we go to overdefined. Assumptions made based on the
- // forced value are possibly wrong. Assuming this is another constant
- // could expose a contradiction.
- Val.setInt(overdefined);
- }
- return true;
- }
-
- /// getConstantInt - If this is a constant with a ConstantInt value, return it
- /// otherwise return null.
- ConstantInt *getConstantInt() const {
- if (isConstant())
- return dyn_cast<ConstantInt>(getConstant());
- return nullptr;
- }
-
- /// getBlockAddress - If this is a constant with a BlockAddress value, return
- /// it, otherwise return null.
- BlockAddress *getBlockAddress() const {
- if (isConstant())
- return dyn_cast<BlockAddress>(getConstant());
- return nullptr;
- }
-
- void markForcedConstant(Constant *V) {
- assert(isUnknown() && "Can't force a defined value!");
- Val.setInt(forcedconstant);
- Val.setPointer(V);
- }
+// Helper to check if \p LV is either a constant or a constant
+// range with a single element. This should cover exactly the same cases as the
+// old ValueLatticeElement::isConstant() and is intended to be used in the
+// transition to ValueLatticeElement.
+bool isConstant(const ValueLatticeElement &LV) {
+ return LV.isConstant() ||
+ (LV.isConstantRange() && LV.getConstantRange().isSingleElement());
+}
- ValueLatticeElement toValueLattice() const {
- if (isOverdefined())
- return ValueLatticeElement::getOverdefined();
- if (isConstant())
- return ValueLatticeElement::get(getConstant());
- return ValueLatticeElement();
- }
-};
+// Helper to check if \p LV is either overdefined or a constant range with more
+// than a single element. This should cover exactly the same cases as the old
+// ValueLatticeElement::isOverdefined() and is intended to be used in the
+// transition to ValueLatticeElement.
+bool isOverdefined(const ValueLatticeElement &LV) {
+ return LV.isOverdefined() ||
+ (LV.isConstantRange() && !LV.getConstantRange().isSingleElement());
+}
//===----------------------------------------------------------------------===//
//
@@ -194,28 +116,28 @@ class SCCPSolver : public InstVisitor<SCCPSolver> {
const DataLayout &DL;
std::function<const TargetLibraryInfo &(Function &)> GetTLI;
SmallPtrSet<BasicBlock *, 8> BBExecutable; // The BBs that are executable.
- DenseMap<Value *, LatticeVal> ValueState; // The state each value is in.
- // The state each parameter is in.
- DenseMap<Value *, ValueLatticeElement> ParamState;
+ DenseMap<Value *, ValueLatticeElement>
+ ValueState; // The state each value is in.
/// StructValueState - This maintains ValueState for values that have
/// StructType, for example for formal arguments, calls, insertelement, etc.
- DenseMap<std::pair<Value *, unsigned>, LatticeVal> StructValueState;
+ DenseMap<std::pair<Value *, unsigned>, ValueLatticeElement> StructValueState;
/// GlobalValue - If we are tracking any values for the contents of a global
/// variable, we keep a mapping from the constant accessor to the element of
/// the global, to the currently known value. If the value becomes
/// overdefined, its entry is simply removed from this map.
- DenseMap<GlobalVariable *, LatticeVal> TrackedGlobals;
+ DenseMap<GlobalVariable *, ValueLatticeElement> TrackedGlobals;
/// TrackedRetVals - If we are tracking arguments into and the return
/// value out of a function, it will have an entry in this map, indicating
/// what the known return value for the function is.
- MapVector<Function *, LatticeVal> TrackedRetVals;
+ MapVector<Function *, ValueLatticeElement> TrackedRetVals;
/// TrackedMultipleRetVals - Same as TrackedRetVals, but used for functions
/// that return multiple values.
- MapVector<std::pair<Function *, unsigned>, LatticeVal> TrackedMultipleRetVals;
+ MapVector<std::pair<Function *, unsigned>, ValueLatticeElement>
+ TrackedMultipleRetVals;
/// MRVFunctionsTracked - Each function in TrackedMultipleRetVals is
/// represented here for efficient lookup.
@@ -251,6 +173,8 @@ class SCCPSolver : public InstVisitor<SCCPSolver> {
DenseMap<Function *, AnalysisResultsForFn> AnalysisResults;
DenseMap<Value *, SmallPtrSet<User *, 2>> AdditionalUsers;
+ LLVMContext &Ctx;
+
public:
void addAnalysis(Function &F, AnalysisResultsForFn A) {
AnalysisResults.insert({&F, std::move(A)});
@@ -270,8 +194,9 @@ public:
}
SCCPSolver(const DataLayout &DL,
- std::function<const TargetLibraryInfo &(Function &)> GetTLI)
- : DL(DL), GetTLI(std::move(GetTLI)) {}
+ std::function<const TargetLibraryInfo &(Function &)> GetTLI,
+ LLVMContext &Ctx)
+ : DL(DL), GetTLI(std::move(GetTLI)), Ctx(Ctx) {}
/// MarkBlockExecutable - This method can be used by clients to mark all of
/// the blocks that are known to be intrinsically live in the processed unit.
@@ -292,7 +217,7 @@ public:
void TrackValueOfGlobalVariable(GlobalVariable *GV) {
// We only track the contents of scalar globals.
if (GV->getValueType()->isSingleValueType()) {
- LatticeVal &IV = TrackedGlobals[GV];
+ ValueLatticeElement &IV = TrackedGlobals[GV];
if (!isa<UndefValue>(GV->getInitializer()))
IV.markConstant(GV->getInitializer());
}
@@ -306,10 +231,10 @@ public:
if (auto *STy = dyn_cast<StructType>(F->getReturnType())) {
MRVFunctionsTracked.insert(F);
for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i)
- TrackedMultipleRetVals.insert(std::make_pair(std::make_pair(F, i),
- LatticeVal()));
+ TrackedMultipleRetVals.insert(
+ std::make_pair(std::make_pair(F, i), ValueLatticeElement()));
} else
- TrackedRetVals.insert(std::make_pair(F, LatticeVal()));
+ TrackedRetVals.insert(std::make_pair(F, ValueLatticeElement()));
}
/// AddMustTailCallee - If the SCCP solver finds that this function is called
@@ -352,8 +277,8 @@ public:
// block to the 'To' basic block is currently feasible.
bool isEdgeFeasible(BasicBlock *From, BasicBlock *To);
- std::vector<LatticeVal> getStructLatticeValueFor(Value *V) const {
- std::vector<LatticeVal> StructValues;
+ std::vector<ValueLatticeElement> getStructLatticeValueFor(Value *V) const {
+ std::vector<ValueLatticeElement> StructValues;
auto *STy = dyn_cast<StructType>(V->getType());
assert(STy && "getStructLatticeValueFor() can be called only on structs");
for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i) {
@@ -364,23 +289,26 @@ public:
return StructValues;
}
- const LatticeVal &getLatticeValueFor(Value *V) const {
+ void removeLatticeValueFor(Value *V) { ValueState.erase(V); }
+
+ const ValueLatticeElement &getLatticeValueFor(Value *V) const {
assert(!V->getType()->isStructTy() &&
"Should use getStructLatticeValueFor");
- DenseMap<Value *, LatticeVal>::const_iterator I = ValueState.find(V);
+ DenseMap<Value *, ValueLatticeElement>::const_iterator I =
+ ValueState.find(V);
assert(I != ValueState.end() &&
"V not found in ValueState nor Paramstate map!");
return I->second;
}
/// getTrackedRetVals - Get the inferred return value map.
- const MapVector<Function*, LatticeVal> &getTrackedRetVals() {
+ const MapVector<Function *, ValueLatticeElement> &getTrackedRetVals() {
return TrackedRetVals;
}
/// getTrackedGlobals - Get and return the set of inferred initializers for
/// global variables.
- const DenseMap<GlobalVariable*, LatticeVal> &getTrackedGlobals() {
+ const DenseMap<GlobalVariable *, ValueLatticeElement> &getTrackedGlobals() {
return TrackedGlobals;
}
@@ -407,32 +335,59 @@ public:
}
// isStructLatticeConstant - Return true if all the lattice values
- // corresponding to elements of the structure are not overdefined,
+ // corresponding to elements of the structure are constants,
// false otherwise.
bool isStructLatticeConstant(Function *F, StructType *STy) {
for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i) {
const auto &It = TrackedMultipleRetVals.find(std::make_pair(F, i));
assert(It != TrackedMultipleRetVals.end());
- LatticeVal LV = It->second;
- if (LV.isOverdefined())
+ ValueLatticeElement LV = It->second;
+ if (!isConstant(LV))
return false;
}
return true;
}
+ /// Helper to return a Constant if \p LV is either a constant or a constant
+ /// range with a single element.
+ Constant *getConstant(const ValueLatticeElement &LV) const {
+ if (LV.isConstant())
+ return LV.getConstant();
+
+ if (LV.isConstantRange()) {
+ auto &CR = LV.getConstantRange();
+ if (CR.getSingleElement())
+ return ConstantInt::get(Ctx, *CR.getSingleElement());
+ }
+ return nullptr;
+ }
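getConstant() above folds a one-element constant range back into a plain Constant. A standalone sketch of that idea, using an invented Interval type rather than llvm::ConstantRange:

#include <cstdio>
#include <optional>

// A range that contains exactly one value can be handed back as a constant;
// anything wider cannot. Interval is an invented stand-in for ConstantRange.
struct Interval {
  long Lo, Hi; // inclusive bounds
};

static std::optional<long> singleElement(const Interval &R) {
  if (R.Lo == R.Hi)
    return R.Lo;
  return std::nullopt;
}

int main() {
  Interval A{4, 4}, B{0, 8};
  if (auto C = singleElement(A))
    std::printf("A folds to the constant %ld\n", *C);
  if (!singleElement(B))
    std::printf("B stays a range [%ld, %ld]\n", B.Lo, B.Hi);
  return 0;
}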
+
private:
- // pushToWorkList - Helper for markConstant/markForcedConstant/markOverdefined
- void pushToWorkList(LatticeVal &IV, Value *V) {
+ ConstantInt *getConstantInt(const ValueLatticeElement &IV) const {
+ return dyn_cast_or_null<ConstantInt>(getConstant(IV));
+ }
+
+ // pushToWorkList - Helper for markConstant/markOverdefined
+ void pushToWorkList(ValueLatticeElement &IV, Value *V) {
if (IV.isOverdefined())
return OverdefinedInstWorkList.push_back(V);
InstWorkList.push_back(V);
}
+ // Helper to push \p V to the worklist, after updating it to \p IV. Also
+ // prints a debug message with the updated value.
+ void pushToWorkListMsg(ValueLatticeElement &IV, Value *V) {
+ LLVM_DEBUG(dbgs() << "updated " << IV << ": " << *V << '\n');
+ pushToWorkList(IV, V);
+ }
+
// markConstant - Make a value be marked as "constant". If the value
// is not already a constant, add it to the instruction work list so that
// the users of the instruction are updated later.
- bool markConstant(LatticeVal &IV, Value *V, Constant *C) {
- if (!IV.markConstant(C)) return false;
+ bool markConstant(ValueLatticeElement &IV, Value *V, Constant *C,
+ bool MayIncludeUndef = false) {
+ if (!IV.markConstant(C, MayIncludeUndef))
+ return false;
LLVM_DEBUG(dbgs() << "markConstant: " << *C << ": " << *V << '\n');
pushToWorkList(IV, V);
return true;
@@ -443,18 +398,10 @@ private:
return markConstant(ValueState[V], V, C);
}
- void markForcedConstant(Value *V, Constant *C) {
- assert(!V->getType()->isStructTy() && "structs should use mergeInValue");
- LatticeVal &IV = ValueState[V];
- IV.markForcedConstant(C);
- LLVM_DEBUG(dbgs() << "markForcedConstant: " << *C << ": " << *V << '\n');
- pushToWorkList(IV, V);
- }
-
// markOverdefined - Make a value be marked as "overdefined". If the
// value is not already overdefined, add it to the overdefined instruction
// work list so that the users of the instruction are updated later.
- bool markOverdefined(LatticeVal &IV, Value *V) {
+ bool markOverdefined(ValueLatticeElement &IV, Value *V) {
if (!IV.markOverdefined()) return false;
LLVM_DEBUG(dbgs() << "markOverdefined: ";
@@ -466,71 +413,59 @@ private:
return true;
}
- bool mergeInValue(LatticeVal &IV, Value *V, LatticeVal MergeWithV) {
- if (IV.isOverdefined() || MergeWithV.isUnknown())
- return false; // Noop.
- if (MergeWithV.isOverdefined())
- return markOverdefined(IV, V);
- if (IV.isUnknown())
- return markConstant(IV, V, MergeWithV.getConstant());
- if (IV.getConstant() != MergeWithV.getConstant())
- return markOverdefined(IV, V);
+ /// Merge \p MergeWithV into \p IV and push \p V to the worklist, if \p IV
+ /// changes.
+ bool mergeInValue(ValueLatticeElement &IV, Value *V,
+ ValueLatticeElement MergeWithV,
+ ValueLatticeElement::MergeOptions Opts = {
+ /*MayIncludeUndef=*/false, /*CheckWiden=*/false}) {
+ if (IV.mergeIn(MergeWithV, Opts)) {
+ pushToWorkList(IV, V);
+ LLVM_DEBUG(dbgs() << "Merged " << MergeWithV << " into " << *V << " : "
+ << IV << "\n");
+ return true;
+ }
return false;
}
- bool mergeInValue(Value *V, LatticeVal MergeWithV) {
+ bool mergeInValue(Value *V, ValueLatticeElement MergeWithV,
+ ValueLatticeElement::MergeOptions Opts = {
+ /*MayIncludeUndef=*/false, /*CheckWiden=*/false}) {
assert(!V->getType()->isStructTy() &&
"non-structs should use markConstant");
- return mergeInValue(ValueState[V], V, MergeWithV);
+ return mergeInValue(ValueState[V], V, MergeWithV, Opts);
}
- /// getValueState - Return the LatticeVal object that corresponds to the
- /// value. This function handles the case when the value hasn't been seen yet
- /// by properly seeding constants etc.
- LatticeVal &getValueState(Value *V) {
+ /// getValueState - Return the ValueLatticeElement object that corresponds to
+ /// the value. This function handles the case when the value hasn't been seen
+ /// yet by properly seeding constants etc.
+ ValueLatticeElement &getValueState(Value *V) {
assert(!V->getType()->isStructTy() && "Should use getStructValueState");
- std::pair<DenseMap<Value*, LatticeVal>::iterator, bool> I =
- ValueState.insert(std::make_pair(V, LatticeVal()));
- LatticeVal &LV = I.first->second;
+ auto I = ValueState.insert(std::make_pair(V, ValueLatticeElement()));
+ ValueLatticeElement &LV = I.first->second;
if (!I.second)
return LV; // Common case, already in the map.
- if (auto *C = dyn_cast<Constant>(V)) {
- // Undef values remain unknown.
- if (!isa<UndefValue>(V))
- LV.markConstant(C); // Constants are constant
- }
-
- // All others are underdefined by default.
- return LV;
- }
-
- ValueLatticeElement &getParamState(Value *V) {
- assert(!V->getType()->isStructTy() && "Should use getStructValueState");
-
- std::pair<DenseMap<Value*, ValueLatticeElement>::iterator, bool>
- PI = ParamState.insert(std::make_pair(V, ValueLatticeElement()));
- ValueLatticeElement &LV = PI.first->second;
- if (PI.second)
- LV = getValueState(V).toValueLattice();
+ if (auto *C = dyn_cast<Constant>(V))
+ LV.markConstant(C); // Constants are constant
+ // All others are unknown by default.
return LV;
}
- /// getStructValueState - Return the LatticeVal object that corresponds to the
- /// value/field pair. This function handles the case when the value hasn't
- /// been seen yet by properly seeding constants etc.
- LatticeVal &getStructValueState(Value *V, unsigned i) {
+ /// getStructValueState - Return the ValueLatticeElement object that
+ /// corresponds to the value/field pair. This function handles the case when
+ /// the value hasn't been seen yet by properly seeding constants etc.
+ ValueLatticeElement &getStructValueState(Value *V, unsigned i) {
assert(V->getType()->isStructTy() && "Should use getValueState");
assert(i < cast<StructType>(V->getType())->getNumElements() &&
"Invalid element #");
- std::pair<DenseMap<std::pair<Value*, unsigned>, LatticeVal>::iterator,
- bool> I = StructValueState.insert(
- std::make_pair(std::make_pair(V, i), LatticeVal()));
- LatticeVal &LV = I.first->second;
+ auto I = StructValueState.insert(
+ std::make_pair(std::make_pair(V, i), ValueLatticeElement()));
+ ValueLatticeElement &LV = I.first->second;
if (!I.second)
return LV; // Common case, already in the map.
@@ -589,9 +524,20 @@ private:
// Mark I's users as changed, including AdditionalUsers.
void markUsersAsChanged(Value *I) {
- for (User *U : I->users())
- if (auto *UI = dyn_cast<Instruction>(U))
- OperandChangedState(UI);
+ // Functions include their arguments in the use-list. Changed function
+ // values mean that the result of the function changed. We only need to
+ // update the call sites with the new function result and do not have to
+ // propagate the call arguments.
+ if (isa<Function>(I)) {
+ for (User *U : I->users()) {
+ if (auto *CB = dyn_cast<CallBase>(U))
+ handleCallResult(*CB);
+ }
+ } else {
+ for (User *U : I->users())
+ if (auto *UI = dyn_cast<Instruction>(U))
+ OperandChangedState(UI);
+ }
auto Iter = AdditionalUsers.find(I);
if (Iter != AdditionalUsers.end()) {
@@ -600,6 +546,9 @@ private:
OperandChangedState(UI);
}
}
+ void handleCallOverdefined(CallBase &CB);
+ void handleCallResult(CallBase &CB);
+ void handleCallArguments(CallBase &CB);
private:
friend class InstVisitor<SCCPSolver>;
@@ -634,20 +583,20 @@ private:
void visitGetElementPtrInst(GetElementPtrInst &I);
void visitCallInst (CallInst &I) {
- visitCallSite(&I);
+ visitCallBase(I);
}
void visitInvokeInst (InvokeInst &II) {
- visitCallSite(&II);
+ visitCallBase(II);
visitTerminator(II);
}
void visitCallBrInst (CallBrInst &CBI) {
- visitCallSite(&CBI);
+ visitCallBase(CBI);
visitTerminator(CBI);
}
- void visitCallSite (CallSite CS);
+ void visitCallBase (CallBase &CB);
void visitResumeInst (ResumeInst &I) { /*returns void*/ }
void visitUnreachableInst(UnreachableInst &I) { /*returns void*/ }
void visitFenceInst (FenceInst &I) { /*returns void*/ }
@@ -673,12 +622,12 @@ void SCCPSolver::getFeasibleSuccessors(Instruction &TI,
return;
}
- LatticeVal BCValue = getValueState(BI->getCondition());
- ConstantInt *CI = BCValue.getConstantInt();
+ ValueLatticeElement BCValue = getValueState(BI->getCondition());
+ ConstantInt *CI = getConstantInt(BCValue);
if (!CI) {
// Overdefined condition variables, and branches on unfoldable constant
// conditions, mean the branch could go either way.
- if (!BCValue.isUnknown())
+ if (!BCValue.isUnknownOrUndef())
Succs[0] = Succs[1] = true;
return;
}
@@ -699,12 +648,12 @@ void SCCPSolver::getFeasibleSuccessors(Instruction &TI,
Succs[0] = true;
return;
}
- LatticeVal SCValue = getValueState(SI->getCondition());
- ConstantInt *CI = SCValue.getConstantInt();
+ ValueLatticeElement SCValue = getValueState(SI->getCondition());
+ ConstantInt *CI = getConstantInt(SCValue);
if (!CI) { // Overdefined or unknown condition?
// All destinations are executable!
- if (!SCValue.isUnknown())
+ if (!SCValue.isUnknownOrUndef())
Succs.assign(TI.getNumSuccessors(), true);
return;
}
@@ -717,11 +666,11 @@ void SCCPSolver::getFeasibleSuccessors(Instruction &TI,
// the target as executable.
if (auto *IBR = dyn_cast<IndirectBrInst>(&TI)) {
// Casts are folded by visitCastInst.
- LatticeVal IBRValue = getValueState(IBR->getAddress());
- BlockAddress *Addr = IBRValue.getBlockAddress();
+ ValueLatticeElement IBRValue = getValueState(IBR->getAddress());
+ BlockAddress *Addr = dyn_cast_or_null<BlockAddress>(getConstant(IBRValue));
if (!Addr) { // Overdefined or unknown condition?
// All destinations are executable!
- if (!IBRValue.isUnknown())
+ if (!IBRValue.isUnknownOrUndef())
Succs.assign(TI.getNumSuccessors(), true);
return;
}
@@ -786,50 +735,43 @@ void SCCPSolver::visitPHINode(PHINode &PN) {
return (void)markOverdefined(&PN);
if (getValueState(&PN).isOverdefined())
- return; // Quick exit
+ return; // Quick exit
// Super-extra-high-degree PHI nodes are unlikely to ever be marked constant,
// and slow us down a lot. Just mark them overdefined.
if (PN.getNumIncomingValues() > 64)
return (void)markOverdefined(&PN);
+ unsigned NumActiveIncoming = 0;
+
// Look at all of the executable operands of the PHI node. If any of them
// are overdefined, the PHI becomes overdefined as well. If they are all
// constant, and they agree with each other, the PHI becomes the identical
- // constant. If they are constant and don't agree, the PHI is overdefined.
- // If there are no executable operands, the PHI remains unknown.
- Constant *OperandVal = nullptr;
+ // constant. If they are constant and don't agree, the PHI is a constant
+ // range. If there are no executable operands, the PHI remains unknown.
+ ValueLatticeElement PhiState = getValueState(&PN);
for (unsigned i = 0, e = PN.getNumIncomingValues(); i != e; ++i) {
- LatticeVal IV = getValueState(PN.getIncomingValue(i));
- if (IV.isUnknown()) continue; // Doesn't influence PHI node.
-
if (!isEdgeFeasible(PN.getIncomingBlock(i), PN.getParent()))
continue;
- if (IV.isOverdefined()) // PHI node becomes overdefined!
- return (void)markOverdefined(&PN);
-
- if (!OperandVal) { // Grab the first value.
- OperandVal = IV.getConstant();
- continue;
- }
-
- // There is already a reachable operand. If we conflict with it,
- // then the PHI node becomes overdefined. If we agree with it, we
- // can continue on.
-
- // Check to see if there are two different constants merging, if so, the PHI
- // node is overdefined.
- if (IV.getConstant() != OperandVal)
- return (void)markOverdefined(&PN);
- }
-
- // If we exited the loop, this means that the PHI node only has constant
- // arguments that agree with each other(and OperandVal is the constant) or
- // OperandVal is null because there are no defined incoming arguments. If
- // this is the case, the PHI remains unknown.
- if (OperandVal)
- markConstant(&PN, OperandVal); // Acquire operand value
+ ValueLatticeElement IV = getValueState(PN.getIncomingValue(i));
+ PhiState.mergeIn(IV);
+ NumActiveIncoming++;
+ if (PhiState.isOverdefined())
+ break;
+ }
+
+ // We allow up to 1 range extension per active incoming value and one
+ // additional extension. Note that we manually adjust the number of range
+ // extensions to match the number of active incoming values. This helps to
+ // limit multiple extensions caused by the same incoming value, if other
+ // incoming values are equal.
+ mergeInValue(&PN, PhiState,
+ ValueLatticeElement::MergeOptions().setMaxWidenSteps(
+ NumActiveIncoming + 1));
+ ValueLatticeElement &PhiStateRef = getValueState(&PN);
+ PhiStateRef.setNumRangeExtensions(
+ std::max(NumActiveIncoming, PhiStateRef.getNumRangeExtensions()));
}
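The capped widening described above can be sketched with an invented Range type (not LLVM's ValueLatticeElement): each merge may grow the range, but only a bounded number of growing merges are allowed before the value is widened to the full range:

#include <algorithm>
#include <climits>
#include <cstdio>

// Invented stand-in for a constant-range lattice value with a widening cap.
struct Range {
  int Lo = 0, Hi = 0;
  unsigned Widenings = 0;

  // Merge another range in; after more than MaxSteps growing merges, give up
  // and go to the full ("overdefined") range.
  bool mergeIn(int OLo, int OHi, unsigned MaxSteps) {
    int NewLo = std::min(Lo, OLo), NewHi = std::max(Hi, OHi);
    if (NewLo == Lo && NewHi == Hi)
      return false;                 // nothing changed
    if (++Widenings > MaxSteps) {
      Lo = INT_MIN;                 // too many extensions: widen fully
      Hi = INT_MAX;
      return true;
    }
    Lo = NewLo;
    Hi = NewHi;
    return true;
  }
};

int main() {
  Range Phi;                        // starts at [0, 0]
  const int Incoming[] = {1, 3, 9}; // three feasible incoming values
  unsigned NumActiveIncoming = 3;
  for (int V : Incoming)
    Phi.mergeIn(V, V, /*MaxSteps=*/NumActiveIncoming + 1);
  std::printf("phi range [%d, %d] after %u widenings\n", Phi.Lo, Phi.Hi,
              Phi.Widenings);
  return 0;
}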
void SCCPSolver::visitReturnInst(ReturnInst &I) {
@@ -840,8 +782,7 @@ void SCCPSolver::visitReturnInst(ReturnInst &I) {
// If we are tracking the return value of this function, merge it in.
if (!TrackedRetVals.empty() && !ResultOp->getType()->isStructTy()) {
- MapVector<Function*, LatticeVal>::iterator TFRVI =
- TrackedRetVals.find(F);
+ auto TFRVI = TrackedRetVals.find(F);
if (TFRVI != TrackedRetVals.end()) {
mergeInValue(TFRVI->second, F, getValueState(ResultOp));
return;
@@ -871,18 +812,28 @@ void SCCPSolver::visitTerminator(Instruction &TI) {
}
void SCCPSolver::visitCastInst(CastInst &I) {
- LatticeVal OpSt = getValueState(I.getOperand(0));
- if (OpSt.isOverdefined()) // Inherit overdefinedness of operand
- markOverdefined(&I);
- else if (OpSt.isConstant()) {
+ // ResolvedUndefsIn might mark I as overdefined. Bail out, even if we would
+ // discover a concrete value later.
+ if (ValueState[&I].isOverdefined())
+ return;
+
+ ValueLatticeElement OpSt = getValueState(I.getOperand(0));
+ if (Constant *OpC = getConstant(OpSt)) {
// Fold the constant as we build.
- Constant *C = ConstantFoldCastOperand(I.getOpcode(), OpSt.getConstant(),
- I.getType(), DL);
+ Constant *C = ConstantFoldCastOperand(I.getOpcode(), OpC, I.getType(), DL);
if (isa<UndefValue>(C))
return;
// Propagate constant value
markConstant(&I, C);
- }
+ } else if (OpSt.isConstantRange() && I.getDestTy()->isIntegerTy()) {
+ auto &LV = getValueState(&I);
+ ConstantRange OpRange = OpSt.getConstantRange();
+ Type *DestTy = I.getDestTy();
+ ConstantRange Res =
+ OpRange.castOp(I.getOpcode(), DL.getTypeSizeInBits(DestTy));
+ mergeInValue(LV, &I, ValueLatticeElement::getRange(Res));
+ } else if (!OpSt.isUnknownOrUndef())
+ markOverdefined(&I);
}
void SCCPSolver::visitExtractValueInst(ExtractValueInst &EVI) {
@@ -891,6 +842,11 @@ void SCCPSolver::visitExtractValueInst(ExtractValueInst &EVI) {
if (EVI.getType()->isStructTy())
return (void)markOverdefined(&EVI);
+ // ResolvedUndefsIn might mark I as overdefined. Bail out, even if we would
+ // discover a concrete value later.
+ if (ValueState[&EVI].isOverdefined())
+ return (void)markOverdefined(&EVI);
+
// If this is extracting from more than one level of struct, we don't know.
if (EVI.getNumIndices() != 1)
return (void)markOverdefined(&EVI);
@@ -898,7 +854,7 @@ void SCCPSolver::visitExtractValueInst(ExtractValueInst &EVI) {
Value *AggVal = EVI.getAggregateOperand();
if (AggVal->getType()->isStructTy()) {
unsigned i = *EVI.idx_begin();
- LatticeVal EltVal = getStructValueState(AggVal, i);
+ ValueLatticeElement EltVal = getStructValueState(AggVal, i);
mergeInValue(getValueState(&EVI), &EVI, EltVal);
} else {
// Otherwise, must be extracting from an array.
@@ -911,6 +867,11 @@ void SCCPSolver::visitInsertValueInst(InsertValueInst &IVI) {
if (!STy)
return (void)markOverdefined(&IVI);
+ // ResolvedUndefsIn might mark I as overdefined. Bail out, even if we would
+ // discover a concrete value later.
+ if (isOverdefined(ValueState[&IVI]))
+ return (void)markOverdefined(&IVI);
+
// If this has more than one index, we can't handle it, drive all results to
// undef.
if (IVI.getNumIndices() != 1)
@@ -923,7 +884,7 @@ void SCCPSolver::visitInsertValueInst(InsertValueInst &IVI) {
for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i) {
// This passes through all values that aren't the inserted element.
if (i != Idx) {
- LatticeVal EltVal = getStructValueState(Aggr, i);
+ ValueLatticeElement EltVal = getStructValueState(Aggr, i);
mergeInValue(getStructValueState(&IVI, i), &IVI, EltVal);
continue;
}
@@ -933,7 +894,7 @@ void SCCPSolver::visitInsertValueInst(InsertValueInst &IVI) {
// We don't track structs in structs.
markOverdefined(getStructValueState(&IVI, i), &IVI);
else {
- LatticeVal InVal = getValueState(Val);
+ ValueLatticeElement InVal = getValueState(Val);
mergeInValue(getStructValueState(&IVI, i), &IVI, InVal);
}
}
@@ -945,11 +906,16 @@ void SCCPSolver::visitSelectInst(SelectInst &I) {
if (I.getType()->isStructTy())
return (void)markOverdefined(&I);
- LatticeVal CondValue = getValueState(I.getCondition());
- if (CondValue.isUnknown())
+ // ResolvedUndefsIn might mark I as overdefined. Bail out, even if we would
+ // discover a concrete value later.
+ if (ValueState[&I].isOverdefined())
+ return (void)markOverdefined(&I);
+
+ ValueLatticeElement CondValue = getValueState(I.getCondition());
+ if (CondValue.isUnknownOrUndef())
return;
- if (ConstantInt *CondCB = CondValue.getConstantInt()) {
+ if (ConstantInt *CondCB = getConstantInt(CondValue)) {
Value *OpVal = CondCB->isZero() ? I.getFalseValue() : I.getTrueValue();
mergeInValue(&I, getValueState(OpVal));
return;
@@ -958,30 +924,27 @@ void SCCPSolver::visitSelectInst(SelectInst &I) {
// Otherwise, the condition is overdefined or a constant we can't evaluate.
// See if we can produce something better than overdefined based on the T/F
// value.
- LatticeVal TVal = getValueState(I.getTrueValue());
- LatticeVal FVal = getValueState(I.getFalseValue());
-
- // select ?, C, C -> C.
- if (TVal.isConstant() && FVal.isConstant() &&
- TVal.getConstant() == FVal.getConstant())
- return (void)markConstant(&I, FVal.getConstant());
-
- if (TVal.isUnknown()) // select ?, undef, X -> X.
- return (void)mergeInValue(&I, FVal);
- if (FVal.isUnknown()) // select ?, X, undef -> X.
- return (void)mergeInValue(&I, TVal);
- markOverdefined(&I);
+ ValueLatticeElement TVal = getValueState(I.getTrueValue());
+ ValueLatticeElement FVal = getValueState(I.getFalseValue());
+
+ bool Changed = ValueState[&I].mergeIn(TVal);
+ Changed |= ValueState[&I].mergeIn(FVal);
+ if (Changed)
+ pushToWorkListMsg(ValueState[&I], &I);
}
// Handle Unary Operators.
void SCCPSolver::visitUnaryOperator(Instruction &I) {
- LatticeVal V0State = getValueState(I.getOperand(0));
+ ValueLatticeElement V0State = getValueState(I.getOperand(0));
- LatticeVal &IV = ValueState[&I];
- if (IV.isOverdefined()) return;
+ ValueLatticeElement &IV = ValueState[&I];
+ // ResolvedUndefsIn might mark I as overdefined. Bail out, even if we would
+ // discover a concrete value later.
+ if (isOverdefined(IV))
+ return (void)markOverdefined(&I);
- if (V0State.isConstant()) {
- Constant *C = ConstantExpr::get(I.getOpcode(), V0State.getConstant());
+ if (isConstant(V0State)) {
+ Constant *C = ConstantExpr::get(I.getOpcode(), getConstant(V0State));
// op Y -> undef.
if (isa<UndefValue>(C))
@@ -990,7 +953,7 @@ void SCCPSolver::visitUnaryOperator(Instruction &I) {
}
// If something is undef, wait for it to resolve.
- if (!V0State.isOverdefined())
+ if (!isOverdefined(V0State))
return;
markOverdefined(&I);
@@ -998,101 +961,90 @@ void SCCPSolver::visitUnaryOperator(Instruction &I) {
// Handle Binary Operators.
void SCCPSolver::visitBinaryOperator(Instruction &I) {
- LatticeVal V1State = getValueState(I.getOperand(0));
- LatticeVal V2State = getValueState(I.getOperand(1));
+ ValueLatticeElement V1State = getValueState(I.getOperand(0));
+ ValueLatticeElement V2State = getValueState(I.getOperand(1));
- LatticeVal &IV = ValueState[&I];
- if (IV.isOverdefined()) return;
-
- if (V1State.isConstant() && V2State.isConstant()) {
- Constant *C = ConstantExpr::get(I.getOpcode(), V1State.getConstant(),
- V2State.getConstant());
- // X op Y -> undef.
- if (isa<UndefValue>(C))
- return;
- return (void)markConstant(IV, &I, C);
- }
+ ValueLatticeElement &IV = ValueState[&I];
+ if (IV.isOverdefined())
+ return;
// If something is undef, wait for it to resolve.
- if (!V1State.isOverdefined() && !V2State.isOverdefined())
+ if (V1State.isUnknownOrUndef() || V2State.isUnknownOrUndef())
return;
- // Otherwise, one of our operands is overdefined. Try to produce something
- // better than overdefined with some tricks.
- // If this is 0 / Y, it doesn't matter that the second operand is
- // overdefined, and we can replace it with zero.
- if (I.getOpcode() == Instruction::UDiv || I.getOpcode() == Instruction::SDiv)
- if (V1State.isConstant() && V1State.getConstant()->isNullValue())
- return (void)markConstant(IV, &I, V1State.getConstant());
-
- // If this is:
- // -> AND/MUL with 0
- // -> OR with -1
- // it doesn't matter that the other operand is overdefined.
- if (I.getOpcode() == Instruction::And || I.getOpcode() == Instruction::Mul ||
- I.getOpcode() == Instruction::Or) {
- LatticeVal *NonOverdefVal = nullptr;
- if (!V1State.isOverdefined())
- NonOverdefVal = &V1State;
- else if (!V2State.isOverdefined())
- NonOverdefVal = &V2State;
-
- if (NonOverdefVal) {
- if (NonOverdefVal->isUnknown())
- return;
+ if (V1State.isOverdefined() && V2State.isOverdefined())
+ return (void)markOverdefined(&I);
- if (I.getOpcode() == Instruction::And ||
- I.getOpcode() == Instruction::Mul) {
- // X and 0 = 0
- // X * 0 = 0
- if (NonOverdefVal->getConstant()->isNullValue())
- return (void)markConstant(IV, &I, NonOverdefVal->getConstant());
- } else {
- // X or -1 = -1
- if (ConstantInt *CI = NonOverdefVal->getConstantInt())
- if (CI->isMinusOne())
- return (void)markConstant(IV, &I, NonOverdefVal->getConstant());
- }
+ // If either of the operands is a constant, try to fold it to a constant.
+ // TODO: Use information from notconstant better.
+ if ((V1State.isConstant() || V2State.isConstant())) {
+ Value *V1 = isConstant(V1State) ? getConstant(V1State) : I.getOperand(0);
+ Value *V2 = isConstant(V2State) ? getConstant(V2State) : I.getOperand(1);
+ Value *R = SimplifyBinOp(I.getOpcode(), V1, V2, SimplifyQuery(DL));
+ auto *C = dyn_cast_or_null<Constant>(R);
+ if (C) {
+ // X op Y -> undef.
+ if (isa<UndefValue>(C))
+ return;
+ // Conservatively assume that the result may be based on operands that may
+ // be undef. Note that we use mergeInValue to combine the constant with
+ // the existing lattice value for I, as different constants might be found
+      // after one of the operands goes to overdefined, e.g. due to one operand
+      // being a special floating-point value.
+ ValueLatticeElement NewV;
+ NewV.markConstant(C, /*MayIncludeUndef=*/true);
+ return (void)mergeInValue(&I, NewV);
}
}
- markOverdefined(&I);
+ // Only use ranges for binary operators on integers.
+ if (!I.getType()->isIntegerTy())
+ return markOverdefined(&I);
+
+ // Try to simplify to a constant range.
+ ConstantRange A = ConstantRange::getFull(I.getType()->getScalarSizeInBits());
+ ConstantRange B = ConstantRange::getFull(I.getType()->getScalarSizeInBits());
+ if (V1State.isConstantRange())
+ A = V1State.getConstantRange();
+ if (V2State.isConstantRange())
+ B = V2State.getConstantRange();
+
+ ConstantRange R = A.binaryOp(cast<BinaryOperator>(&I)->getOpcode(), B);
+ mergeInValue(&I, ValueLatticeElement::getRange(R));
+
+ // TODO: Currently we do not exploit special values that produce something
+ // better than overdefined with an overdefined operand for vector or floating
+ // point types, like and <4 x i32> overdefined, zeroinitializer.
}
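
The range fallback above is plain ConstantRange arithmetic. As a minimal standalone sketch (not part of this patch; the 32-bit ranges are made up for illustration), this is what binaryOp computes for an add whose operands are known to lie in [0, 10) and [5, 20):

#include "llvm/ADT/APInt.h"
#include "llvm/IR/ConstantRange.h"
#include "llvm/IR/Instruction.h"
#include "llvm/Support/raw_ostream.h"

using namespace llvm;

int main() {
  ConstantRange X(APInt(32, 0), APInt(32, 10));  // first operand: [0, 10)
  ConstantRange Y(APInt(32, 5), APInt(32, 20));  // second operand: [5, 20)
  // Same call the solver makes above; for add this should yield [5, 29).
  ConstantRange Sum = X.binaryOp(Instruction::Add, Y);
  Sum.print(outs());
  outs() << "\n";
  return 0;
}
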
// Handle ICmpInst instruction.
void SCCPSolver::visitCmpInst(CmpInst &I) {
// Do not cache this lookup, getValueState calls later in the function might
// invalidate the reference.
- if (ValueState[&I].isOverdefined()) return;
+ if (isOverdefined(ValueState[&I]))
+ return (void)markOverdefined(&I);
Value *Op1 = I.getOperand(0);
Value *Op2 = I.getOperand(1);
// For parameters, use ParamState which includes constant range info if
// available.
- auto V1Param = ParamState.find(Op1);
- ValueLatticeElement V1State = (V1Param != ParamState.end())
- ? V1Param->second
- : getValueState(Op1).toValueLattice();
-
- auto V2Param = ParamState.find(Op2);
- ValueLatticeElement V2State = V2Param != ParamState.end()
- ? V2Param->second
- : getValueState(Op2).toValueLattice();
+ auto V1State = getValueState(Op1);
+ auto V2State = getValueState(Op2);
Constant *C = V1State.getCompare(I.getPredicate(), I.getType(), V2State);
if (C) {
if (isa<UndefValue>(C))
return;
- LatticeVal CV;
+ ValueLatticeElement CV;
CV.markConstant(C);
mergeInValue(&I, CV);
return;
}
// If operands are still unknown, wait for it to resolve.
- if (!V1State.isOverdefined() && !V2State.isOverdefined() &&
- !ValueState[&I].isConstant())
+ if ((V1State.isUnknownOrUndef() || V2State.isUnknownOrUndef()) &&
+ !isConstant(ValueState[&I]))
return;
markOverdefined(&I);
@@ -1101,21 +1053,26 @@ void SCCPSolver::visitCmpInst(CmpInst &I) {
// Handle getelementptr instructions. If all operands are constants then we
// can turn this into a getelementptr ConstantExpr.
void SCCPSolver::visitGetElementPtrInst(GetElementPtrInst &I) {
- if (ValueState[&I].isOverdefined()) return;
+ if (isOverdefined(ValueState[&I]))
+ return (void)markOverdefined(&I);
SmallVector<Constant*, 8> Operands;
Operands.reserve(I.getNumOperands());
for (unsigned i = 0, e = I.getNumOperands(); i != e; ++i) {
- LatticeVal State = getValueState(I.getOperand(i));
- if (State.isUnknown())
+ ValueLatticeElement State = getValueState(I.getOperand(i));
+ if (State.isUnknownOrUndef())
return; // Operands are not resolved yet.
- if (State.isOverdefined())
+ if (isOverdefined(State))
return (void)markOverdefined(&I);
- assert(State.isConstant() && "Unknown state!");
- Operands.push_back(State.getConstant());
+ if (Constant *C = getConstant(State)) {
+ Operands.push_back(C);
+ continue;
+ }
+
+ return (void)markOverdefined(&I);
}
Constant *Ptr = Operands[0];
@@ -1136,230 +1093,297 @@ void SCCPSolver::visitStoreInst(StoreInst &SI) {
return;
GlobalVariable *GV = cast<GlobalVariable>(SI.getOperand(1));
- DenseMap<GlobalVariable*, LatticeVal>::iterator I = TrackedGlobals.find(GV);
- if (I == TrackedGlobals.end() || I->second.isOverdefined()) return;
+ auto I = TrackedGlobals.find(GV);
+ if (I == TrackedGlobals.end())
+ return;
// Get the value we are storing into the global, then merge it.
- mergeInValue(I->second, GV, getValueState(SI.getOperand(0)));
+ mergeInValue(I->second, GV, getValueState(SI.getOperand(0)),
+ ValueLatticeElement::MergeOptions().setCheckWiden(false));
if (I->second.isOverdefined())
TrackedGlobals.erase(I); // No need to keep tracking this!
}
+static ValueLatticeElement getValueFromMetadata(const Instruction *I) {
+ if (MDNode *Ranges = I->getMetadata(LLVMContext::MD_range))
+ if (I->getType()->isIntegerTy())
+ return ValueLatticeElement::getRange(
+ getConstantRangeFromMetadata(*Ranges));
+ // TODO: Also handle MD_nonnull.
+ return ValueLatticeElement::getOverdefined();
+}
+
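
For reference, a minimal sketch (not part of the patch; the helper name is hypothetical) of how !range metadata on an integer load maps to a ConstantRange, mirroring getValueFromMetadata above:

#include "llvm/IR/ConstantRange.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Metadata.h"

using namespace llvm;

// Hypothetical helper: given an integer load such as
//   %v = load i32, i32* %p, !range !0   with   !0 = !{i32 0, i32 100}
// return the range [0, 100); otherwise the full range for its bit width.
static ConstantRange rangeOfIntLoad(const LoadInst &LI) {
  if (MDNode *Ranges = LI.getMetadata(LLVMContext::MD_range))
    return getConstantRangeFromMetadata(*Ranges);
  return ConstantRange::getFull(LI.getType()->getIntegerBitWidth());
}
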
// Handle load instructions. If the operand is a constant pointer to a constant
// global, we can replace the load with the loaded constant value!
void SCCPSolver::visitLoadInst(LoadInst &I) {
- // If this load is of a struct, just mark the result overdefined.
- if (I.getType()->isStructTy())
+ // If this load is of a struct or the load is volatile, just mark the result
+ // as overdefined.
+ if (I.getType()->isStructTy() || I.isVolatile())
return (void)markOverdefined(&I);
- LatticeVal PtrVal = getValueState(I.getOperand(0));
- if (PtrVal.isUnknown()) return; // The pointer is not resolved yet!
-
- LatticeVal &IV = ValueState[&I];
- if (IV.isOverdefined()) return;
+ // ResolvedUndefsIn might mark I as overdefined. Bail out, even if we would
+ // discover a concrete value later.
+ if (ValueState[&I].isOverdefined())
+ return (void)markOverdefined(&I);
- if (!PtrVal.isConstant() || I.isVolatile())
- return (void)markOverdefined(IV, &I);
+ ValueLatticeElement PtrVal = getValueState(I.getOperand(0));
+ if (PtrVal.isUnknownOrUndef())
+ return; // The pointer is not resolved yet!
- Constant *Ptr = PtrVal.getConstant();
+ ValueLatticeElement &IV = ValueState[&I];
- // load null is undefined.
- if (isa<ConstantPointerNull>(Ptr)) {
- if (NullPointerIsDefined(I.getFunction(), I.getPointerAddressSpace()))
- return (void)markOverdefined(IV, &I);
- else
- return;
- }
+ if (isConstant(PtrVal)) {
+ Constant *Ptr = getConstant(PtrVal);
- // Transform load (constant global) into the value loaded.
- if (auto *GV = dyn_cast<GlobalVariable>(Ptr)) {
- if (!TrackedGlobals.empty()) {
- // If we are tracking this global, merge in the known value for it.
- DenseMap<GlobalVariable*, LatticeVal>::iterator It =
- TrackedGlobals.find(GV);
- if (It != TrackedGlobals.end()) {
- mergeInValue(IV, &I, It->second);
+ // load null is undefined.
+ if (isa<ConstantPointerNull>(Ptr)) {
+ if (NullPointerIsDefined(I.getFunction(), I.getPointerAddressSpace()))
+ return (void)markOverdefined(IV, &I);
+ else
return;
+ }
+
+ // Transform load (constant global) into the value loaded.
+ if (auto *GV = dyn_cast<GlobalVariable>(Ptr)) {
+ if (!TrackedGlobals.empty()) {
+ // If we are tracking this global, merge in the known value for it.
+ auto It = TrackedGlobals.find(GV);
+ if (It != TrackedGlobals.end()) {
+ mergeInValue(IV, &I, It->second, getMaxWidenStepsOpts());
+ return;
+ }
}
}
- }
- // Transform load from a constant into a constant if possible.
- if (Constant *C = ConstantFoldLoadFromConstPtr(Ptr, I.getType(), DL)) {
- if (isa<UndefValue>(C))
- return;
- return (void)markConstant(IV, &I, C);
+ // Transform load from a constant into a constant if possible.
+ if (Constant *C = ConstantFoldLoadFromConstPtr(Ptr, I.getType(), DL)) {
+ if (isa<UndefValue>(C))
+ return;
+ return (void)markConstant(IV, &I, C);
+ }
}
- // Otherwise we cannot say for certain what value this load will produce.
- // Bail out.
- markOverdefined(IV, &I);
+ // Fall back to metadata.
+ mergeInValue(&I, getValueFromMetadata(&I));
}
-void SCCPSolver::visitCallSite(CallSite CS) {
- Function *F = CS.getCalledFunction();
- Instruction *I = CS.getInstruction();
+void SCCPSolver::visitCallBase(CallBase &CB) {
+ handleCallResult(CB);
+ handleCallArguments(CB);
+}
- if (auto *II = dyn_cast<IntrinsicInst>(I)) {
- if (II->getIntrinsicID() == Intrinsic::ssa_copy) {
- if (ValueState[I].isOverdefined())
+void SCCPSolver::handleCallOverdefined(CallBase &CB) {
+ Function *F = CB.getCalledFunction();
+
+ // Void return and not tracking callee, just bail.
+ if (CB.getType()->isVoidTy())
+ return;
+
+ // Always mark struct return as overdefined.
+ if (CB.getType()->isStructTy())
+ return (void)markOverdefined(&CB);
+
+ // Otherwise, if we have a single return value case, and if the function is
+ // a declaration, maybe we can constant fold it.
+ if (F && F->isDeclaration() && canConstantFoldCallTo(&CB, F)) {
+ SmallVector<Constant *, 8> Operands;
+ for (auto AI = CB.arg_begin(), E = CB.arg_end(); AI != E; ++AI) {
+ if (AI->get()->getType()->isStructTy())
+ return markOverdefined(&CB); // Can't handle struct args.
+ ValueLatticeElement State = getValueState(*AI);
+
+ if (State.isUnknownOrUndef())
+ return; // Operands are not resolved yet.
+ if (isOverdefined(State))
+ return (void)markOverdefined(&CB);
+ assert(isConstant(State) && "Unknown state!");
+ Operands.push_back(getConstant(State));
+ }
+
+ if (isOverdefined(getValueState(&CB)))
+ return (void)markOverdefined(&CB);
+
+ // If we can constant fold this, mark the result of the call as a
+ // constant.
+ if (Constant *C = ConstantFoldCall(&CB, F, Operands, &GetTLI(*F))) {
+ // call -> undef.
+ if (isa<UndefValue>(C))
return;
+ return (void)markConstant(&CB, C);
+ }
+ }
+
+ // Fall back to metadata.
+ mergeInValue(&CB, getValueFromMetadata(&CB));
+}
+
+void SCCPSolver::handleCallArguments(CallBase &CB) {
+ Function *F = CB.getCalledFunction();
+ // If this is a local function that doesn't have its address taken, mark its
+ // entry block executable and merge in the actual arguments to the call into
+ // the formal arguments of the function.
+ if (!TrackingIncomingArguments.empty() &&
+ TrackingIncomingArguments.count(F)) {
+ MarkBlockExecutable(&F->front());
+
+ // Propagate information from this call site into the callee.
+ auto CAI = CB.arg_begin();
+ for (Function::arg_iterator AI = F->arg_begin(), E = F->arg_end(); AI != E;
+ ++AI, ++CAI) {
+ // If this argument is byval, and if the function is not readonly, there
+ // will be an implicit copy formed of the input aggregate.
+ if (AI->hasByValAttr() && !F->onlyReadsMemory()) {
+ markOverdefined(&*AI);
+ continue;
+ }
+
+ if (auto *STy = dyn_cast<StructType>(AI->getType())) {
+ for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i) {
+ ValueLatticeElement CallArg = getStructValueState(*CAI, i);
+ mergeInValue(getStructValueState(&*AI, i), &*AI, CallArg,
+ getMaxWidenStepsOpts());
+ }
+ } else
+ mergeInValue(&*AI, getValueState(*CAI), getMaxWidenStepsOpts());
+ }
+ }
+}
+
+void SCCPSolver::handleCallResult(CallBase &CB) {
+ Function *F = CB.getCalledFunction();
- auto *PI = getPredicateInfoFor(I);
- if (!PI)
+ if (auto *II = dyn_cast<IntrinsicInst>(&CB)) {
+ if (II->getIntrinsicID() == Intrinsic::ssa_copy) {
+ if (ValueState[&CB].isOverdefined())
return;
- Value *CopyOf = I->getOperand(0);
- auto *PBranch = dyn_cast<PredicateBranch>(PI);
- if (!PBranch) {
- mergeInValue(ValueState[I], I, getValueState(CopyOf));
+ Value *CopyOf = CB.getOperand(0);
+ ValueLatticeElement CopyOfVal = getValueState(CopyOf);
+ auto *PI = getPredicateInfoFor(&CB);
+ assert(PI && "Missing predicate info for ssa.copy");
+
+ CmpInst *Cmp;
+ bool TrueEdge;
+ if (auto *PBranch = dyn_cast<PredicateBranch>(PI)) {
+ Cmp = dyn_cast<CmpInst>(PBranch->Condition);
+ TrueEdge = PBranch->TrueEdge;
+ } else if (auto *PAssume = dyn_cast<PredicateAssume>(PI)) {
+ Cmp = dyn_cast<CmpInst>(PAssume->Condition);
+ TrueEdge = true;
+ } else {
+ mergeInValue(ValueState[&CB], &CB, CopyOfVal);
return;
}
- Value *Cond = PBranch->Condition;
-
// Everything below relies on the condition being a comparison.
- auto *Cmp = dyn_cast<CmpInst>(Cond);
if (!Cmp) {
- mergeInValue(ValueState[I], I, getValueState(CopyOf));
+ mergeInValue(ValueState[&CB], &CB, CopyOfVal);
return;
}
+ Value *RenamedOp = PI->RenamedOp;
Value *CmpOp0 = Cmp->getOperand(0);
Value *CmpOp1 = Cmp->getOperand(1);
- if (CopyOf != CmpOp0 && CopyOf != CmpOp1) {
- mergeInValue(ValueState[I], I, getValueState(CopyOf));
+ // Bail out if neither of the operands matches RenamedOp.
+ if (CmpOp0 != RenamedOp && CmpOp1 != RenamedOp) {
+ mergeInValue(ValueState[&CB], &CB, getValueState(CopyOf));
return;
}
- if (CmpOp0 != CopyOf)
+ auto Pred = Cmp->getPredicate();
+ if (CmpOp1 == RenamedOp) {
std::swap(CmpOp0, CmpOp1);
+ Pred = Cmp->getSwappedPredicate();
+ }
- LatticeVal OriginalVal = getValueState(CopyOf);
- LatticeVal EqVal = getValueState(CmpOp1);
- LatticeVal &IV = ValueState[I];
- if (PBranch->TrueEdge && Cmp->getPredicate() == CmpInst::ICMP_EQ) {
- addAdditionalUser(CmpOp1, I);
- if (OriginalVal.isConstant())
- mergeInValue(IV, I, OriginalVal);
- else
- mergeInValue(IV, I, EqVal);
+ // Wait until CmpOp1 is resolved.
+ if (getValueState(CmpOp1).isUnknown()) {
+ addAdditionalUser(CmpOp1, &CB);
return;
}
- if (!PBranch->TrueEdge && Cmp->getPredicate() == CmpInst::ICMP_NE) {
- addAdditionalUser(CmpOp1, I);
- if (OriginalVal.isConstant())
- mergeInValue(IV, I, OriginalVal);
- else
- mergeInValue(IV, I, EqVal);
+
+ // The code below relies on PredicateInfo only inserting copies for the
+ // true branch when the branch condition is an AND and only inserting
+ // copies for the false branch when the branch condition is an OR. This
+ // ensures we can intersect the range from the condition with the range of
+ // CopyOf.
+ if (!TrueEdge)
+ Pred = CmpInst::getInversePredicate(Pred);
+
+ ValueLatticeElement CondVal = getValueState(CmpOp1);
+ ValueLatticeElement &IV = ValueState[&CB];
+ if (CondVal.isConstantRange() || CopyOfVal.isConstantRange()) {
+ auto ImposedCR =
+ ConstantRange::getFull(DL.getTypeSizeInBits(CopyOf->getType()));
+
+ // Get the range imposed by the condition.
+ if (CondVal.isConstantRange())
+ ImposedCR = ConstantRange::makeAllowedICmpRegion(
+ Pred, CondVal.getConstantRange());
+
+ // Combine range info for the original value with the new range from the
+ // condition.
+ auto CopyOfCR = CopyOfVal.isConstantRange()
+ ? CopyOfVal.getConstantRange()
+ : ConstantRange::getFull(
+ DL.getTypeSizeInBits(CopyOf->getType()));
+ auto NewCR = ImposedCR.intersectWith(CopyOfCR);
+ // If the existing information is != x, do not use the information from
+ // a chained predicate, as the != x information is more likely to be
+ // helpful in practice.
+ if (!CopyOfCR.contains(NewCR) && CopyOfCR.getSingleMissingElement())
+ NewCR = CopyOfCR;
+
+ addAdditionalUser(CmpOp1, &CB);
+      // TODO: Actually flip MayIncludeUndef for the created range to false,
+      // once most places in the optimizer respect the rule that branches on
+      // undef/poison are UB. The reason why the new range cannot be undef is
+      // as follows:
+      // The new range is based on a branch condition. That guarantees that
+      // neither of the compare operands can be undef in the branch targets,
+      // unless we have conditions that are always true/false (e.g. icmp ule
+      // i32 %a, i32_max). For the latter, an overdefined/empty range will be
+      // inferred, but the branch will get folded accordingly anyway.
+ mergeInValue(
+ IV, &CB,
+ ValueLatticeElement::getRange(NewCR, /*MayIncludeUndef=*/true));
+ return;
+ } else if (Pred == CmpInst::ICMP_EQ && CondVal.isConstant()) {
+ // For non-integer values or integer constant expressions, only
+ // propagate equal constants.
+ addAdditionalUser(CmpOp1, &CB);
+ mergeInValue(IV, &CB, CondVal);
return;
}
- return (void)mergeInValue(IV, I, getValueState(CopyOf));
+ return (void)mergeInValue(IV, &CB, CopyOfVal);
}
}
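
As a minimal standalone illustration of the ssa.copy range refinement above (not part of the patch; the values are hypothetical): if the copied value is already known to be in [0, 100) and the guarding condition was icmp ult %v, 10 taken on the true edge, the intersection narrows it to [0, 10):

#include "llvm/ADT/APInt.h"
#include "llvm/IR/ConstantRange.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/Support/raw_ostream.h"

using namespace llvm;

int main() {
  ConstantRange CopyOfCR(APInt(32, 0), APInt(32, 100)); // known range of %v
  ConstantRange CondCR(APInt(32, 10));                  // the constant operand 10
  ConstantRange Imposed =
      ConstantRange::makeAllowedICmpRegion(CmpInst::ICMP_ULT, CondCR);
  ConstantRange NewCR = Imposed.intersectWith(CopyOfCR); // expected: [0, 10)
  NewCR.print(outs());
  outs() << "\n";
  return 0;
}
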
// The common case is that we aren't tracking the callee, either because we
// are not doing interprocedural analysis or the callee is indirect, or is
// external. Handle these cases first.
- if (!F || F->isDeclaration()) {
-CallOverdefined:
- // Void return and not tracking callee, just bail.
- if (I->getType()->isVoidTy()) return;
-
- // Otherwise, if we have a single return value case, and if the function is
- // a declaration, maybe we can constant fold it.
- if (F && F->isDeclaration() && !I->getType()->isStructTy() &&
- canConstantFoldCallTo(cast<CallBase>(CS.getInstruction()), F)) {
- SmallVector<Constant*, 8> Operands;
- for (CallSite::arg_iterator AI = CS.arg_begin(), E = CS.arg_end();
- AI != E; ++AI) {
- if (AI->get()->getType()->isStructTy())
- return markOverdefined(I); // Can't handle struct args.
- LatticeVal State = getValueState(*AI);
-
- if (State.isUnknown())
- return; // Operands are not resolved yet.
- if (State.isOverdefined())
- return (void)markOverdefined(I);
- assert(State.isConstant() && "Unknown state!");
- Operands.push_back(State.getConstant());
- }
-
- if (getValueState(I).isOverdefined())
- return;
-
- // If we can constant fold this, mark the result of the call as a
- // constant.
- if (Constant *C = ConstantFoldCall(cast<CallBase>(CS.getInstruction()), F,
- Operands, &GetTLI(*F))) {
- // call -> undef.
- if (isa<UndefValue>(C))
- return;
- return (void)markConstant(I, C);
- }
- }
-
- // Otherwise, we don't know anything about this call, mark it overdefined.
- return (void)markOverdefined(I);
- }
-
- // If this is a local function that doesn't have its address taken, mark its
- // entry block executable and merge in the actual arguments to the call into
- // the formal arguments of the function.
- if (!TrackingIncomingArguments.empty() && TrackingIncomingArguments.count(F)){
- MarkBlockExecutable(&F->front());
-
- // Propagate information from this call site into the callee.
- CallSite::arg_iterator CAI = CS.arg_begin();
- for (Function::arg_iterator AI = F->arg_begin(), E = F->arg_end();
- AI != E; ++AI, ++CAI) {
- // If this argument is byval, and if the function is not readonly, there
- // will be an implicit copy formed of the input aggregate.
- if (AI->hasByValAttr() && !F->onlyReadsMemory()) {
- markOverdefined(&*AI);
- continue;
- }
-
- if (auto *STy = dyn_cast<StructType>(AI->getType())) {
- for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i) {
- LatticeVal CallArg = getStructValueState(*CAI, i);
- mergeInValue(getStructValueState(&*AI, i), &*AI, CallArg);
- }
- } else {
- // Most other parts of the Solver still only use the simpler value
- // lattice, so we propagate changes for parameters to both lattices.
- LatticeVal ConcreteArgument = getValueState(*CAI);
- bool ParamChanged =
- getParamState(&*AI).mergeIn(ConcreteArgument.toValueLattice(), DL);
- bool ValueChanged = mergeInValue(&*AI, ConcreteArgument);
- // Add argument to work list, if the state of a parameter changes but
- // ValueState does not change (because it is already overdefined there),
- // We have to take changes in ParamState into account, as it is used
- // when evaluating Cmp instructions.
- if (!ValueChanged && ParamChanged)
- pushToWorkList(ValueState[&*AI], &*AI);
- }
- }
- }
+ if (!F || F->isDeclaration())
+ return handleCallOverdefined(CB);
// If this is a single/zero retval case, see if we're tracking the function.
if (auto *STy = dyn_cast<StructType>(F->getReturnType())) {
if (!MRVFunctionsTracked.count(F))
- goto CallOverdefined; // Not tracking this callee.
+ return handleCallOverdefined(CB); // Not tracking this callee.
// If we are tracking this callee, propagate the result of the function
// into this call site.
for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i)
- mergeInValue(getStructValueState(I, i), I,
- TrackedMultipleRetVals[std::make_pair(F, i)]);
+ mergeInValue(getStructValueState(&CB, i), &CB,
+ TrackedMultipleRetVals[std::make_pair(F, i)],
+ getMaxWidenStepsOpts());
} else {
- MapVector<Function*, LatticeVal>::iterator TFRVI = TrackedRetVals.find(F);
+ auto TFRVI = TrackedRetVals.find(F);
if (TFRVI == TrackedRetVals.end())
- goto CallOverdefined; // Not tracking this callee.
+ return handleCallOverdefined(CB); // Not tracking this callee.
// If so, propagate the return value of the callee into this call result.
- mergeInValue(I, TFRVI->second);
+ mergeInValue(&CB, TFRVI->second, getMaxWidenStepsOpts());
}
}
@@ -1429,10 +1453,8 @@ void SCCPSolver::Solve() {
/// constraints on the condition of the branch, as that would impact other users
/// of the value.
///
-/// This scan also checks for values that use undefs, whose results are actually
-/// defined. For example, 'zext i8 undef to i32' should produce all zeros
-/// conservatively, as "(zext i8 X -> i32) & 0xFF00" must always return zero,
-/// even if X isn't defined.
+/// This scan also checks for values that use undefs. It conservatively marks
+/// them as overdefined.
bool SCCPSolver::ResolvedUndefsIn(Function &F) {
for (BasicBlock &BB : F) {
if (!BBExecutable.count(&BB))
@@ -1446,8 +1468,8 @@ bool SCCPSolver::ResolvedUndefsIn(Function &F) {
// Only a few things that can be structs matter for undef.
// Tracked calls must never be marked overdefined in ResolvedUndefsIn.
- if (CallSite CS = CallSite(&I))
- if (Function *F = CS.getCalledFunction())
+ if (auto *CB = dyn_cast<CallBase>(&I))
+ if (Function *F = CB->getCalledFunction())
if (MRVFunctionsTracked.count(F))
continue;
@@ -1455,19 +1477,18 @@ bool SCCPSolver::ResolvedUndefsIn(Function &F) {
// tracked as precisely as their operands.
if (isa<ExtractValueInst>(I) || isa<InsertValueInst>(I))
continue;
-
// Send the results of everything else to overdefined. We could be
// more precise than this but it isn't worth bothering.
for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i) {
- LatticeVal &LV = getStructValueState(&I, i);
- if (LV.isUnknown())
+ ValueLatticeElement &LV = getStructValueState(&I, i);
+ if (LV.isUnknownOrUndef())
markOverdefined(LV, &I);
}
continue;
}
- LatticeVal &LV = getValueState(&I);
- if (!LV.isUnknown())
+ ValueLatticeElement &LV = getValueState(&I);
+ if (!LV.isUnknownOrUndef())
continue;
// There are two reasons a call can have an undef result
@@ -1475,195 +1496,20 @@ bool SCCPSolver::ResolvedUndefsIn(Function &F) {
// 2. It could be constant-foldable.
// Because of the way we solve return values, tracked calls must
// never be marked overdefined in ResolvedUndefsIn.
- if (CallSite CS = CallSite(&I)) {
- if (Function *F = CS.getCalledFunction())
+ if (auto *CB = dyn_cast<CallBase>(&I))
+ if (Function *F = CB->getCalledFunction())
if (TrackedRetVals.count(F))
continue;
- // If the call is constant-foldable, we mark it overdefined because
- // we do not know what return values are valid.
- markOverdefined(&I);
- return true;
- }
-
- // extractvalue is safe; check here because the argument is a struct.
- if (isa<ExtractValueInst>(I))
- continue;
-
- // Compute the operand LatticeVals, for convenience below.
- // Anything taking a struct is conservatively assumed to require
- // overdefined markings.
- if (I.getOperand(0)->getType()->isStructTy()) {
- markOverdefined(&I);
- return true;
- }
- LatticeVal Op0LV = getValueState(I.getOperand(0));
- LatticeVal Op1LV;
- if (I.getNumOperands() == 2) {
- if (I.getOperand(1)->getType()->isStructTy()) {
- markOverdefined(&I);
- return true;
- }
-
- Op1LV = getValueState(I.getOperand(1));
- }
- // If this is an instructions whose result is defined even if the input is
- // not fully defined, propagate the information.
- Type *ITy = I.getType();
- switch (I.getOpcode()) {
- case Instruction::Add:
- case Instruction::Sub:
- case Instruction::Trunc:
- case Instruction::FPTrunc:
- case Instruction::BitCast:
- break; // Any undef -> undef
- case Instruction::FSub:
- case Instruction::FAdd:
- case Instruction::FMul:
- case Instruction::FDiv:
- case Instruction::FRem:
- // Floating-point binary operation: be conservative.
- if (Op0LV.isUnknown() && Op1LV.isUnknown())
- markForcedConstant(&I, Constant::getNullValue(ITy));
- else
- markOverdefined(&I);
- return true;
- case Instruction::FNeg:
- break; // fneg undef -> undef
- case Instruction::ZExt:
- case Instruction::SExt:
- case Instruction::FPToUI:
- case Instruction::FPToSI:
- case Instruction::FPExt:
- case Instruction::PtrToInt:
- case Instruction::IntToPtr:
- case Instruction::SIToFP:
- case Instruction::UIToFP:
- // undef -> 0; some outputs are impossible
- markForcedConstant(&I, Constant::getNullValue(ITy));
- return true;
- case Instruction::Mul:
- case Instruction::And:
- // Both operands undef -> undef
- if (Op0LV.isUnknown() && Op1LV.isUnknown())
- break;
- // undef * X -> 0. X could be zero.
- // undef & X -> 0. X could be zero.
- markForcedConstant(&I, Constant::getNullValue(ITy));
- return true;
- case Instruction::Or:
- // Both operands undef -> undef
- if (Op0LV.isUnknown() && Op1LV.isUnknown())
- break;
- // undef | X -> -1. X could be -1.
- markForcedConstant(&I, Constant::getAllOnesValue(ITy));
- return true;
- case Instruction::Xor:
- // undef ^ undef -> 0; strictly speaking, this is not strictly
- // necessary, but we try to be nice to people who expect this
- // behavior in simple cases
- if (Op0LV.isUnknown() && Op1LV.isUnknown()) {
- markForcedConstant(&I, Constant::getNullValue(ITy));
- return true;
- }
- // undef ^ X -> undef
- break;
- case Instruction::SDiv:
- case Instruction::UDiv:
- case Instruction::SRem:
- case Instruction::URem:
- // X / undef -> undef. No change.
- // X % undef -> undef. No change.
- if (Op1LV.isUnknown()) break;
-
- // X / 0 -> undef. No change.
- // X % 0 -> undef. No change.
- if (Op1LV.isConstant() && Op1LV.getConstant()->isZeroValue())
- break;
-
- // undef / X -> 0. X could be maxint.
- // undef % X -> 0. X could be 1.
- markForcedConstant(&I, Constant::getNullValue(ITy));
- return true;
- case Instruction::AShr:
- // X >>a undef -> undef.
- if (Op1LV.isUnknown()) break;
-
- // Shifting by the bitwidth or more is undefined.
- if (Op1LV.isConstant()) {
- if (auto *ShiftAmt = Op1LV.getConstantInt())
- if (ShiftAmt->getLimitedValue() >=
- ShiftAmt->getType()->getScalarSizeInBits())
- break;
- }
-
- // undef >>a X -> 0
- markForcedConstant(&I, Constant::getNullValue(ITy));
- return true;
- case Instruction::LShr:
- case Instruction::Shl:
- // X << undef -> undef.
- // X >> undef -> undef.
- if (Op1LV.isUnknown()) break;
-
- // Shifting by the bitwidth or more is undefined.
- if (Op1LV.isConstant()) {
- if (auto *ShiftAmt = Op1LV.getConstantInt())
- if (ShiftAmt->getLimitedValue() >=
- ShiftAmt->getType()->getScalarSizeInBits())
- break;
- }
-
- // undef << X -> 0
- // undef >> X -> 0
- markForcedConstant(&I, Constant::getNullValue(ITy));
- return true;
- case Instruction::Select:
- Op1LV = getValueState(I.getOperand(1));
- // undef ? X : Y -> X or Y. There could be commonality between X/Y.
- if (Op0LV.isUnknown()) {
- if (!Op1LV.isConstant()) // Pick the constant one if there is any.
- Op1LV = getValueState(I.getOperand(2));
- } else if (Op1LV.isUnknown()) {
- // c ? undef : undef -> undef. No change.
- Op1LV = getValueState(I.getOperand(2));
- if (Op1LV.isUnknown())
- break;
- // Otherwise, c ? undef : x -> x.
- } else {
- // Leave Op1LV as Operand(1)'s LatticeValue.
- }
-
- if (Op1LV.isConstant())
- markForcedConstant(&I, Op1LV.getConstant());
- else
- markOverdefined(&I);
- return true;
- case Instruction::Load:
+ if (isa<LoadInst>(I)) {
// A load here means one of two things: a load of undef from a global,
      // or a load from an unknown pointer. Either way, having it return undef
// is okay.
- break;
- case Instruction::ICmp:
- // X == undef -> undef. Other comparisons get more complicated.
- Op0LV = getValueState(I.getOperand(0));
- Op1LV = getValueState(I.getOperand(1));
-
- if ((Op0LV.isUnknown() || Op1LV.isUnknown()) &&
- cast<ICmpInst>(&I)->isEquality())
- break;
- markOverdefined(&I);
- return true;
- case Instruction::Call:
- case Instruction::Invoke:
- case Instruction::CallBr:
- llvm_unreachable("Call-like instructions should have be handled early");
- default:
- // If we don't know what should happen here, conservatively mark it
- // overdefined.
- markOverdefined(&I);
- return true;
+ continue;
}
+
+ markOverdefined(&I);
+ return true;
}
// Check to see if we have a branch or switch on an undefined value. If so
@@ -1672,7 +1518,7 @@ bool SCCPSolver::ResolvedUndefsIn(Function &F) {
Instruction *TI = BB.getTerminator();
if (auto *BI = dyn_cast<BranchInst>(TI)) {
if (!BI->isConditional()) continue;
- if (!getValueState(BI->getCondition()).isUnknown())
+ if (!getValueState(BI->getCondition()).isUnknownOrUndef())
continue;
    // If the input to SCCP is actually a branch on undef, fix the undef to
@@ -1700,7 +1546,7 @@ bool SCCPSolver::ResolvedUndefsIn(Function &F) {
if (IBR->getNumSuccessors() < 1)
continue;
- if (!getValueState(IBR->getAddress()).isUnknown())
+ if (!getValueState(IBR->getAddress()).isUnknownOrUndef())
continue;
    // If the input to SCCP is actually a branch on undef, fix the undef to
@@ -1724,7 +1570,8 @@ bool SCCPSolver::ResolvedUndefsIn(Function &F) {
}
if (auto *SI = dyn_cast<SwitchInst>(TI)) {
- if (!SI->getNumCases() || !getValueState(SI->getCondition()).isUnknown())
+ if (!SI->getNumCases() ||
+ !getValueState(SI->getCondition()).isUnknownOrUndef())
continue;
      // If the input to SCCP is actually a switch on undef, fix the undef to
@@ -1753,25 +1600,26 @@ bool SCCPSolver::ResolvedUndefsIn(Function &F) {
static bool tryToReplaceWithConstant(SCCPSolver &Solver, Value *V) {
Constant *Const = nullptr;
if (V->getType()->isStructTy()) {
- std::vector<LatticeVal> IVs = Solver.getStructLatticeValueFor(V);
- if (llvm::any_of(IVs,
- [](const LatticeVal &LV) { return LV.isOverdefined(); }))
+ std::vector<ValueLatticeElement> IVs = Solver.getStructLatticeValueFor(V);
+ if (any_of(IVs,
+ [](const ValueLatticeElement &LV) { return isOverdefined(LV); }))
return false;
std::vector<Constant *> ConstVals;
auto *ST = cast<StructType>(V->getType());
for (unsigned i = 0, e = ST->getNumElements(); i != e; ++i) {
- LatticeVal V = IVs[i];
- ConstVals.push_back(V.isConstant()
- ? V.getConstant()
+ ValueLatticeElement V = IVs[i];
+ ConstVals.push_back(isConstant(V)
+ ? Solver.getConstant(V)
: UndefValue::get(ST->getElementType(i)));
}
Const = ConstantStruct::get(ST, ConstVals);
} else {
- const LatticeVal &IV = Solver.getLatticeValueFor(V);
- if (IV.isOverdefined())
+ const ValueLatticeElement &IV = Solver.getLatticeValueFor(V);
+ if (isOverdefined(IV))
return false;
- Const = IV.isConstant() ? IV.getConstant() : UndefValue::get(V->getType());
+ Const =
+ isConstant(IV) ? Solver.getConstant(IV) : UndefValue::get(V->getType());
}
assert(Const && "Constant is nullptr here!");
@@ -1779,8 +1627,7 @@ static bool tryToReplaceWithConstant(SCCPSolver &Solver, Value *V) {
// unless the call itself can be removed
CallInst *CI = dyn_cast<CallInst>(V);
if (CI && CI->isMustTailCall() && !CI->isSafeToRemove()) {
- CallSite CS(CI);
- Function *F = CS.getCalledFunction();
+ Function *F = CI->getCalledFunction();
// Don't zap returns of the callee
if (F)
@@ -1798,13 +1645,49 @@ static bool tryToReplaceWithConstant(SCCPSolver &Solver, Value *V) {
return true;
}
+static bool simplifyInstsInBlock(SCCPSolver &Solver, BasicBlock &BB,
+ SmallPtrSetImpl<Value *> &InsertedValues,
+ Statistic &InstRemovedStat,
+ Statistic &InstReplacedStat) {
+ bool MadeChanges = false;
+ for (Instruction &Inst : make_early_inc_range(BB)) {
+ if (Inst.getType()->isVoidTy())
+ continue;
+ if (tryToReplaceWithConstant(Solver, &Inst)) {
+ if (Inst.isSafeToRemove())
+ Inst.eraseFromParent();
+ // Hey, we just changed something!
+ MadeChanges = true;
+ ++InstRemovedStat;
+ } else if (isa<SExtInst>(&Inst)) {
+ Value *ExtOp = Inst.getOperand(0);
+ if (isa<Constant>(ExtOp) || InsertedValues.count(ExtOp))
+ continue;
+ const ValueLatticeElement &IV = Solver.getLatticeValueFor(ExtOp);
+ if (!IV.isConstantRange(/*UndefAllowed=*/false))
+ continue;
+ if (IV.getConstantRange().isAllNonNegative()) {
+ auto *ZExt = new ZExtInst(ExtOp, Inst.getType(), "", &Inst);
+ InsertedValues.insert(ZExt);
+ Inst.replaceAllUsesWith(ZExt);
+ Solver.removeLatticeValueFor(&Inst);
+ Inst.eraseFromParent();
+ InstReplacedStat++;
+ MadeChanges = true;
+ }
+ }
+ }
+ return MadeChanges;
+}
+
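
The sext handling above rests on a simple range fact: if the operand's known range contains no negative values, sign- and zero-extension produce the same bits. A minimal sketch of that check (not part of the patch; the range is made up for illustration):

#include "llvm/ADT/APInt.h"
#include "llvm/IR/ConstantRange.h"
#include "llvm/Support/raw_ostream.h"

using namespace llvm;

int main() {
  // Suppose the solver proved the i8 operand of a sext lies in [0, 100).
  ConstantRange OpRange(APInt(8, 0), APInt(8, 100));
  if (OpRange.isAllNonNegative())
    outs() << "sext can be rewritten as zext\n";
  return 0;
}
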
// runSCCP() - Run the Sparse Conditional Constant Propagation algorithm,
// and return true if the function was modified.
static bool runSCCP(Function &F, const DataLayout &DL,
const TargetLibraryInfo *TLI) {
LLVM_DEBUG(dbgs() << "SCCP on function '" << F.getName() << "'\n");
SCCPSolver Solver(
- DL, [TLI](Function &F) -> const TargetLibraryInfo & { return *TLI; });
+ DL, [TLI](Function &F) -> const TargetLibraryInfo & { return *TLI; },
+ F.getContext());
// Mark the first block of the function as being executable.
Solver.MarkBlockExecutable(&F.front());
@@ -1827,6 +1710,7 @@ static bool runSCCP(Function &F, const DataLayout &DL,
// delete their contents now. Note that we cannot actually delete the blocks,
// as we cannot modify the CFG of the function.
+ SmallPtrSet<Value *, 32> InsertedValues;
for (BasicBlock &BB : F) {
if (!Solver.isBlockExecutable(&BB)) {
LLVM_DEBUG(dbgs() << " BasicBlock Dead:" << BB);
@@ -1838,21 +1722,8 @@ static bool runSCCP(Function &F, const DataLayout &DL,
continue;
}
- // Iterate over all of the instructions in a function, replacing them with
- // constants if we have found them to be of constant values.
- for (BasicBlock::iterator BI = BB.begin(), E = BB.end(); BI != E;) {
- Instruction *Inst = &*BI++;
- if (Inst->getType()->isVoidTy() || Inst->isTerminator())
- continue;
-
- if (tryToReplaceWithConstant(Solver, Inst)) {
- if (isInstructionTriviallyDead(Inst))
- Inst->eraseFromParent();
- // Hey, we just changed something!
- MadeChanges = true;
- ++NumInstRemoved;
- }
- }
+ MadeChanges |= simplifyInstsInBlock(Solver, BB, InsertedValues,
+ NumInstRemoved, NumInstReplaced);
}
return MadeChanges;
@@ -1942,14 +1813,15 @@ static void findReturnsToZap(Function &F,
// uses (like blockaddresses) could stick around without being
// used in the underlying IR, meaning we do not have lattice
// values for them.
- if (!CallSite(U))
+ if (!isa<CallBase>(U))
return true;
if (U->getType()->isStructTy()) {
- return all_of(
- Solver.getStructLatticeValueFor(U),
- [](const LatticeVal &LV) { return !LV.isOverdefined(); });
+ return all_of(Solver.getStructLatticeValueFor(U),
+ [](const ValueLatticeElement &LV) {
+ return !isOverdefined(LV);
+ });
}
- return !Solver.getLatticeValueFor(U).isOverdefined();
+ return !isOverdefined(Solver.getLatticeValueFor(U));
}) &&
"We can only zap functions where all live users have a concrete value");
@@ -2006,7 +1878,7 @@ bool llvm::runIPSCCP(
Module &M, const DataLayout &DL,
std::function<const TargetLibraryInfo &(Function &)> GetTLI,
function_ref<AnalysisResultsForFn(Function &)> getAnalysis) {
- SCCPSolver Solver(DL, GetTLI);
+ SCCPSolver Solver(DL, GetTLI, M.getContext());
// Loop over all functions, marking arguments to those with their addresses
// taken or that are external as overdefined.
@@ -2080,30 +1952,21 @@ bool llvm::runIPSCCP(
}
}
- for (Function::iterator BB = F.begin(), E = F.end(); BB != E; ++BB) {
- if (!Solver.isBlockExecutable(&*BB)) {
- LLVM_DEBUG(dbgs() << " BasicBlock Dead:" << *BB);
+ SmallPtrSet<Value *, 32> InsertedValues;
+ for (BasicBlock &BB : F) {
+ if (!Solver.isBlockExecutable(&BB)) {
+ LLVM_DEBUG(dbgs() << " BasicBlock Dead:" << BB);
++NumDeadBlocks;
MadeChanges = true;
- if (&*BB != &F.front())
- BlocksToErase.push_back(&*BB);
+ if (&BB != &F.front())
+ BlocksToErase.push_back(&BB);
continue;
}
- for (BasicBlock::iterator BI = BB->begin(), E = BB->end(); BI != E; ) {
- Instruction *Inst = &*BI++;
- if (Inst->getType()->isVoidTy())
- continue;
- if (tryToReplaceWithConstant(Solver, Inst)) {
- if (Inst->isSafeToRemove())
- Inst->eraseFromParent();
- // Hey, we just changed something!
- MadeChanges = true;
- ++IPNumInstRemoved;
- }
- }
+ MadeChanges |= simplifyInstsInBlock(Solver, BB, InsertedValues,
+ IPNumInstRemoved, IPNumInstReplaced);
}
DomTreeUpdater DTU = Solver.getDTU(F);
@@ -2189,10 +2052,9 @@ bool llvm::runIPSCCP(
// whether other functions are optimizable.
SmallVector<ReturnInst*, 8> ReturnsToZap;
- const MapVector<Function*, LatticeVal> &RV = Solver.getTrackedRetVals();
- for (const auto &I : RV) {
+ for (const auto &I : Solver.getTrackedRetVals()) {
Function *F = I.first;
- if (I.second.isOverdefined() || F->getReturnType()->isVoidTy())
+ if (isOverdefined(I.second) || F->getReturnType()->isVoidTy())
continue;
findReturnsToZap(*F, ReturnsToZap, Solver);
}
@@ -2213,17 +2075,16 @@ bool llvm::runIPSCCP(
  // If we inferred constant or undef values for global variables, we can
// delete the global and any stores that remain to it.
- const DenseMap<GlobalVariable*, LatticeVal> &TG = Solver.getTrackedGlobals();
- for (DenseMap<GlobalVariable*, LatticeVal>::const_iterator I = TG.begin(),
- E = TG.end(); I != E; ++I) {
- GlobalVariable *GV = I->first;
- assert(!I->second.isOverdefined() &&
- "Overdefined values should have been taken out of the map!");
+ for (auto &I : make_early_inc_range(Solver.getTrackedGlobals())) {
+ GlobalVariable *GV = I.first;
+ if (isOverdefined(I.second))
+ continue;
LLVM_DEBUG(dbgs() << "Found that GV '" << GV->getName()
<< "' is constant!\n");
while (!GV->use_empty()) {
StoreInst *SI = cast<StoreInst>(GV->user_back());
SI->eraseFromParent();
+ MadeChanges = true;
}
M.getGlobalList().erase(GV);
++IPNumGlobalConst;
diff --git a/llvm/lib/Transforms/Scalar/SROA.cpp b/llvm/lib/Transforms/Scalar/SROA.cpp
index 89916e43fce2..89f324deef9f 100644
--- a/llvm/lib/Transforms/Scalar/SROA.cpp
+++ b/llvm/lib/Transforms/Scalar/SROA.cpp
@@ -94,11 +94,6 @@
#include <utility>
#include <vector>
-#ifndef NDEBUG
-// We only use this for a debug check.
-#include <random>
-#endif
-
using namespace llvm;
using namespace llvm::sroa;
@@ -115,11 +110,6 @@ STATISTIC(NumLoadsSpeculated, "Number of loads speculated to allow promotion");
STATISTIC(NumDeleted, "Number of instructions deleted");
STATISTIC(NumVectorized, "Number of vectorized aggregates");
-/// Hidden option to enable randomly shuffling the slices to help uncover
-/// instability in their order.
-static cl::opt<bool> SROARandomShuffleSlices("sroa-random-shuffle-slices",
- cl::init(false), cl::Hidden);
-
/// Hidden option to experiment with completely strict handling of inbounds
/// GEPs.
static cl::opt<bool> SROAStrictInbounds("sroa-strict-inbounds", cl::init(false),
@@ -129,7 +119,7 @@ namespace {
/// A custom IRBuilder inserter which prefixes all names, but only in
/// Assert builds.
-class IRBuilderPrefixedInserter : public IRBuilderDefaultInserter {
+class IRBuilderPrefixedInserter final : public IRBuilderDefaultInserter {
std::string Prefix;
const Twine getNameWithPrefix(const Twine &Name) const {
@@ -139,9 +129,8 @@ class IRBuilderPrefixedInserter : public IRBuilderDefaultInserter {
public:
void SetNamePrefix(const Twine &P) { Prefix = P.str(); }
-protected:
void InsertHelper(Instruction *I, const Twine &Name, BasicBlock *BB,
- BasicBlock::iterator InsertPt) const {
+ BasicBlock::iterator InsertPt) const override {
IRBuilderDefaultInserter::InsertHelper(I, getNameWithPrefix(Name), BB,
InsertPt);
}
@@ -663,7 +652,8 @@ class AllocaSlices::SliceBuilder : public PtrUseVisitor<SliceBuilder> {
public:
SliceBuilder(const DataLayout &DL, AllocaInst &AI, AllocaSlices &AS)
: PtrUseVisitor<SliceBuilder>(DL),
- AllocSize(DL.getTypeAllocSize(AI.getAllocatedType())), AS(AS) {}
+ AllocSize(DL.getTypeAllocSize(AI.getAllocatedType()).getFixedSize()),
+ AS(AS) {}
private:
void markAsDead(Instruction &I) {
@@ -752,8 +742,10 @@ private:
// For array or vector indices, scale the index by the size of the
// type.
APInt Index = OpC->getValue().sextOrTrunc(Offset.getBitWidth());
- GEPOffset += Index * APInt(Offset.getBitWidth(),
- DL.getTypeAllocSize(GTI.getIndexedType()));
+ GEPOffset +=
+ Index *
+ APInt(Offset.getBitWidth(),
+ DL.getTypeAllocSize(GTI.getIndexedType()).getFixedSize());
}
// If this index has computed an intermediate pointer which is not
@@ -788,7 +780,7 @@ private:
LI.getPointerAddressSpace() != DL.getAllocaAddrSpace())
return PI.setAborted(&LI);
- uint64_t Size = DL.getTypeStoreSize(LI.getType());
+ uint64_t Size = DL.getTypeStoreSize(LI.getType()).getFixedSize();
return handleLoadOrStore(LI.getType(), LI, Offset, Size, LI.isVolatile());
}
@@ -803,7 +795,7 @@ private:
SI.getPointerAddressSpace() != DL.getAllocaAddrSpace())
return PI.setAborted(&SI);
- uint64_t Size = DL.getTypeStoreSize(ValOp->getType());
+ uint64_t Size = DL.getTypeStoreSize(ValOp->getType()).getFixedSize();
// If this memory access can be shown to *statically* extend outside the
    // bounds of the allocation, its behavior is undefined, so simply
@@ -1069,17 +1061,9 @@ AllocaSlices::AllocaSlices(const DataLayout &DL, AllocaInst &AI)
llvm::remove_if(Slices, [](const Slice &S) { return S.isDead(); }),
Slices.end());
-#ifndef NDEBUG
- if (SROARandomShuffleSlices) {
- std::mt19937 MT(static_cast<unsigned>(
- std::chrono::system_clock::now().time_since_epoch().count()));
- std::shuffle(Slices.begin(), Slices.end(), MT);
- }
-#endif
-
// Sort the uses. This arranges for the offsets to be in ascending order,
// and the sizes to be in descending order.
- llvm::sort(Slices);
+ std::stable_sort(Slices.begin(), Slices.end());
}
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
@@ -1200,7 +1184,7 @@ static bool isSafePHIToSpeculate(PHINode &PN) {
// TODO: Allow recursive phi users.
// TODO: Allow stores.
BasicBlock *BB = PN.getParent();
- MaybeAlign MaxAlign;
+ Align MaxAlign;
uint64_t APWidth = DL.getIndexTypeSizeInBits(PN.getType());
APInt MaxSize(APWidth, 0);
bool HaveLoad = false;
@@ -1221,8 +1205,8 @@ static bool isSafePHIToSpeculate(PHINode &PN) {
if (BBI->mayWriteToMemory())
return false;
- uint64_t Size = DL.getTypeStoreSize(LI->getType());
- MaxAlign = std::max(MaxAlign, MaybeAlign(LI->getAlignment()));
+ uint64_t Size = DL.getTypeStoreSize(LI->getType()).getFixedSize();
+ MaxAlign = std::max(MaxAlign, LI->getAlign());
MaxSize = MaxSize.ult(Size) ? APInt(APWidth, Size) : MaxSize;
HaveLoad = true;
}
@@ -1273,7 +1257,7 @@ static void speculatePHINodeLoads(PHINode &PN) {
// matter which one we get and if any differ.
AAMDNodes AATags;
SomeLoad->getAAMetadata(AATags);
- const MaybeAlign Align = MaybeAlign(SomeLoad->getAlignment());
+ Align Alignment = SomeLoad->getAlign();
// Rewrite all loads of the PN to use the new PHI.
while (!PN.use_empty()) {
@@ -1300,11 +1284,10 @@ static void speculatePHINodeLoads(PHINode &PN) {
Instruction *TI = Pred->getTerminator();
IRBuilderTy PredBuilder(TI);
- LoadInst *Load = PredBuilder.CreateLoad(
- LoadTy, InVal,
+ LoadInst *Load = PredBuilder.CreateAlignedLoad(
+ LoadTy, InVal, Alignment,
(PN.getName() + ".sroa.speculate.load." + Pred->getName()));
++NumLoadsSpeculated;
- Load->setAlignment(Align);
if (AATags)
Load->setAAMetadata(AATags);
NewPN->addIncoming(Load, Pred);
@@ -1342,10 +1325,10 @@ static bool isSafeSelectToSpeculate(SelectInst &SI) {
// absolutely (e.g. allocas) or at this point because we can see other
// accesses to it.
if (!isSafeToLoadUnconditionally(TValue, LI->getType(),
- MaybeAlign(LI->getAlignment()), DL, LI))
+ LI->getAlign(), DL, LI))
return false;
if (!isSafeToLoadUnconditionally(FValue, LI->getType(),
- MaybeAlign(LI->getAlignment()), DL, LI))
+ LI->getAlign(), DL, LI))
return false;
}
@@ -1371,8 +1354,8 @@ static void speculateSelectInstLoads(SelectInst &SI) {
NumLoadsSpeculated += 2;
// Transfer alignment and AA info if present.
- TL->setAlignment(MaybeAlign(LI->getAlignment()));
- FL->setAlignment(MaybeAlign(LI->getAlignment()));
+ TL->setAlignment(LI->getAlign());
+ FL->setAlignment(LI->getAlign());
AAMDNodes Tags;
LI->getAAMetadata(Tags);
@@ -1479,14 +1462,15 @@ static Value *getNaturalGEPRecursively(IRBuilderTy &IRB, const DataLayout &DL,
// extremely poorly defined currently. The long-term goal is to remove GEPing
// over a vector from the IR completely.
if (VectorType *VecTy = dyn_cast<VectorType>(Ty)) {
- unsigned ElementSizeInBits = DL.getTypeSizeInBits(VecTy->getScalarType());
+ unsigned ElementSizeInBits =
+ DL.getTypeSizeInBits(VecTy->getScalarType()).getFixedSize();
if (ElementSizeInBits % 8 != 0) {
// GEPs over non-multiple of 8 size vector elements are invalid.
return nullptr;
}
APInt ElementSize(Offset.getBitWidth(), ElementSizeInBits / 8);
APInt NumSkippedElements = Offset.sdiv(ElementSize);
- if (NumSkippedElements.ugt(VecTy->getNumElements()))
+ if (NumSkippedElements.ugt(cast<FixedVectorType>(VecTy)->getNumElements()))
return nullptr;
Offset -= NumSkippedElements * ElementSize;
Indices.push_back(IRB.getInt(NumSkippedElements));
@@ -1496,7 +1480,8 @@ static Value *getNaturalGEPRecursively(IRBuilderTy &IRB, const DataLayout &DL,
if (ArrayType *ArrTy = dyn_cast<ArrayType>(Ty)) {
Type *ElementTy = ArrTy->getElementType();
- APInt ElementSize(Offset.getBitWidth(), DL.getTypeAllocSize(ElementTy));
+ APInt ElementSize(Offset.getBitWidth(),
+ DL.getTypeAllocSize(ElementTy).getFixedSize());
APInt NumSkippedElements = Offset.sdiv(ElementSize);
if (NumSkippedElements.ugt(ArrTy->getNumElements()))
return nullptr;
@@ -1518,7 +1503,7 @@ static Value *getNaturalGEPRecursively(IRBuilderTy &IRB, const DataLayout &DL,
unsigned Index = SL->getElementContainingOffset(StructOffset);
Offset -= APInt(Offset.getBitWidth(), SL->getElementOffset(Index));
Type *ElementTy = STy->getElementType(Index);
- if (Offset.uge(DL.getTypeAllocSize(ElementTy)))
+ if (Offset.uge(DL.getTypeAllocSize(ElementTy).getFixedSize()))
return nullptr; // The offset points into alignment padding.
Indices.push_back(IRB.getInt32(Index));
@@ -1550,7 +1535,8 @@ static Value *getNaturalGEPWithOffset(IRBuilderTy &IRB, const DataLayout &DL,
Type *ElementTy = Ty->getElementType();
if (!ElementTy->isSized())
return nullptr; // We can't GEP through an unsized element.
- APInt ElementSize(Offset.getBitWidth(), DL.getTypeAllocSize(ElementTy));
+ APInt ElementSize(Offset.getBitWidth(),
+ DL.getTypeAllocSize(ElementTy).getFixedSize());
if (ElementSize == 0)
return nullptr; // Zero-length arrays can't help us build a natural GEP.
APInt NumSkippedElements = Offset.sdiv(ElementSize);
@@ -1681,20 +1667,8 @@ static Value *getAdjustedPtr(IRBuilderTy &IRB, const DataLayout &DL, Value *Ptr,
}
/// Compute the adjusted alignment for a load or store from an offset.
-static Align getAdjustedAlignment(Instruction *I, uint64_t Offset,
- const DataLayout &DL) {
- MaybeAlign Alignment;
- Type *Ty;
- if (auto *LI = dyn_cast<LoadInst>(I)) {
- Alignment = MaybeAlign(LI->getAlignment());
- Ty = LI->getType();
- } else if (auto *SI = dyn_cast<StoreInst>(I)) {
- Alignment = MaybeAlign(SI->getAlignment());
- Ty = SI->getValueOperand()->getType();
- } else {
- llvm_unreachable("Only loads and stores are allowed!");
- }
- return commonAlignment(DL.getValueOrABITypeAlignment(Alignment, Ty), Offset);
+static Align getAdjustedAlignment(Instruction *I, uint64_t Offset) {
+ return commonAlignment(getLoadStoreAlignment(I), Offset);
}
/// Test whether we can convert a value from the old to the new type.
@@ -1717,7 +1691,8 @@ static bool canConvertValue(const DataLayout &DL, Type *OldTy, Type *NewTy) {
return false;
}
- if (DL.getTypeSizeInBits(NewTy) != DL.getTypeSizeInBits(OldTy))
+ if (DL.getTypeSizeInBits(NewTy).getFixedSize() !=
+ DL.getTypeSizeInBits(OldTy).getFixedSize())
return false;
if (!NewTy->isSingleValueType() || !OldTy->isSingleValueType())
return false;
@@ -1728,8 +1703,15 @@ static bool canConvertValue(const DataLayout &DL, Type *OldTy, Type *NewTy) {
NewTy = NewTy->getScalarType();
if (NewTy->isPointerTy() || OldTy->isPointerTy()) {
if (NewTy->isPointerTy() && OldTy->isPointerTy()) {
- return cast<PointerType>(NewTy)->getPointerAddressSpace() ==
- cast<PointerType>(OldTy)->getPointerAddressSpace();
+ unsigned OldAS = OldTy->getPointerAddressSpace();
+ unsigned NewAS = NewTy->getPointerAddressSpace();
+ // Convert pointers if they are pointers from the same address space or
+ // different integral (not non-integral) address spaces with the same
+ // pointer size.
+ return OldAS == NewAS ||
+ (!DL.isNonIntegralAddressSpace(OldAS) &&
+ !DL.isNonIntegralAddressSpace(NewAS) &&
+ DL.getPointerSize(OldAS) == DL.getPointerSize(NewAS));
}
// We can convert integers to integral pointers, but not to non-integral
@@ -1765,36 +1747,40 @@ static Value *convertValue(const DataLayout &DL, IRBuilderTy &IRB, Value *V,
assert(!(isa<IntegerType>(OldTy) && isa<IntegerType>(NewTy)) &&
"Integer types must be the exact same to convert.");
- // See if we need inttoptr for this type pair. A cast involving both scalars
- // and vectors requires and additional bitcast.
+ // See if we need inttoptr for this type pair. May require additional bitcast.
if (OldTy->isIntOrIntVectorTy() && NewTy->isPtrOrPtrVectorTy()) {
// Expand <2 x i32> to i8* --> <2 x i32> to i64 to i8*
- if (OldTy->isVectorTy() && !NewTy->isVectorTy())
- return IRB.CreateIntToPtr(IRB.CreateBitCast(V, DL.getIntPtrType(NewTy)),
- NewTy);
-
// Expand i128 to <2 x i8*> --> i128 to <2 x i64> to <2 x i8*>
- if (!OldTy->isVectorTy() && NewTy->isVectorTy())
- return IRB.CreateIntToPtr(IRB.CreateBitCast(V, DL.getIntPtrType(NewTy)),
- NewTy);
-
- return IRB.CreateIntToPtr(V, NewTy);
+ // Expand <4 x i32> to <2 x i8*> --> <4 x i32> to <2 x i64> to <2 x i8*>
+ // Directly handle i64 to i8*
+ return IRB.CreateIntToPtr(IRB.CreateBitCast(V, DL.getIntPtrType(NewTy)),
+ NewTy);
}
- // See if we need ptrtoint for this type pair. A cast involving both scalars
- // and vectors requires and additional bitcast.
+ // See if we need ptrtoint for this type pair. May require additional bitcast.
if (OldTy->isPtrOrPtrVectorTy() && NewTy->isIntOrIntVectorTy()) {
// Expand <2 x i8*> to i128 --> <2 x i8*> to <2 x i64> to i128
- if (OldTy->isVectorTy() && !NewTy->isVectorTy())
- return IRB.CreateBitCast(IRB.CreatePtrToInt(V, DL.getIntPtrType(OldTy)),
- NewTy);
-
// Expand i8* to <2 x i32> --> i8* to i64 to <2 x i32>
- if (!OldTy->isVectorTy() && NewTy->isVectorTy())
- return IRB.CreateBitCast(IRB.CreatePtrToInt(V, DL.getIntPtrType(OldTy)),
- NewTy);
+ // Expand <2 x i8*> to <4 x i32> --> <2 x i8*> to <2 x i64> to <4 x i32>
+ // Expand i8* to i64 --> i8* to i64 to i64
+ return IRB.CreateBitCast(IRB.CreatePtrToInt(V, DL.getIntPtrType(OldTy)),
+ NewTy);
+ }
- return IRB.CreatePtrToInt(V, NewTy);
+ if (OldTy->isPtrOrPtrVectorTy() && NewTy->isPtrOrPtrVectorTy()) {
+ unsigned OldAS = OldTy->getPointerAddressSpace();
+ unsigned NewAS = NewTy->getPointerAddressSpace();
+    // To convert pointers between different address spaces (they have already
+    // been checked to be convertible, i.e. they have the same pointer size),
+    // we cannot use `bitcast` (which requires the same address space) or
+    // `addrspacecast` (which is not always a no-op cast). Instead, use a pair
+    // of no-op `ptrtoint`/`inttoptr` casts through an integer with the same
+    // bit width.
+ if (OldAS != NewAS) {
+ assert(DL.getPointerSize(OldAS) == DL.getPointerSize(NewAS));
+ return IRB.CreateIntToPtr(IRB.CreatePtrToInt(V, DL.getIntPtrType(OldTy)),
+ NewTy);
+ }
}
return IRB.CreateBitCast(V, NewTy);
@@ -1813,19 +1799,20 @@ static bool isVectorPromotionViableForSlice(Partition &P, const Slice &S,
std::max(S.beginOffset(), P.beginOffset()) - P.beginOffset();
uint64_t BeginIndex = BeginOffset / ElementSize;
if (BeginIndex * ElementSize != BeginOffset ||
- BeginIndex >= Ty->getNumElements())
+ BeginIndex >= cast<FixedVectorType>(Ty)->getNumElements())
return false;
uint64_t EndOffset =
std::min(S.endOffset(), P.endOffset()) - P.beginOffset();
uint64_t EndIndex = EndOffset / ElementSize;
- if (EndIndex * ElementSize != EndOffset || EndIndex > Ty->getNumElements())
+ if (EndIndex * ElementSize != EndOffset ||
+ EndIndex > cast<FixedVectorType>(Ty)->getNumElements())
return false;
assert(EndIndex > BeginIndex && "Empty vector!");
uint64_t NumElements = EndIndex - BeginIndex;
Type *SliceTy = (NumElements == 1)
? Ty->getElementType()
- : VectorType::get(Ty->getElementType(), NumElements);
+ : FixedVectorType::get(Ty->getElementType(), NumElements);
Type *SplitIntTy =
Type::getIntNTy(Ty->getContext(), NumElements * ElementSize * 8);
@@ -1890,7 +1877,8 @@ static VectorType *isVectorPromotionViable(Partition &P, const DataLayout &DL) {
// Return if bitcast to vectors is different for total size in bits.
if (!CandidateTys.empty()) {
VectorType *V = CandidateTys[0];
- if (DL.getTypeSizeInBits(VTy) != DL.getTypeSizeInBits(V)) {
+ if (DL.getTypeSizeInBits(VTy).getFixedSize() !=
+ DL.getTypeSizeInBits(V).getFixedSize()) {
CandidateTys.clear();
return;
}
@@ -1936,13 +1924,15 @@ static VectorType *isVectorPromotionViable(Partition &P, const DataLayout &DL) {
// they're all integer vectors. We sort by ascending number of elements.
auto RankVectorTypes = [&DL](VectorType *RHSTy, VectorType *LHSTy) {
(void)DL;
- assert(DL.getTypeSizeInBits(RHSTy) == DL.getTypeSizeInBits(LHSTy) &&
+ assert(DL.getTypeSizeInBits(RHSTy).getFixedSize() ==
+ DL.getTypeSizeInBits(LHSTy).getFixedSize() &&
"Cannot have vector types of different sizes!");
assert(RHSTy->getElementType()->isIntegerTy() &&
"All non-integer types eliminated!");
assert(LHSTy->getElementType()->isIntegerTy() &&
"All non-integer types eliminated!");
- return RHSTy->getNumElements() < LHSTy->getNumElements();
+ return cast<FixedVectorType>(RHSTy)->getNumElements() <
+ cast<FixedVectorType>(LHSTy)->getNumElements();
};
llvm::sort(CandidateTys, RankVectorTypes);
CandidateTys.erase(
@@ -1964,13 +1954,14 @@ static VectorType *isVectorPromotionViable(Partition &P, const DataLayout &DL) {
// Try each vector type, and return the one which works.
auto CheckVectorTypeForPromotion = [&](VectorType *VTy) {
- uint64_t ElementSize = DL.getTypeSizeInBits(VTy->getElementType());
+ uint64_t ElementSize =
+ DL.getTypeSizeInBits(VTy->getElementType()).getFixedSize();
// While the definition of LLVM vectors is bitpacked, we don't support sizes
// that aren't byte sized.
if (ElementSize % 8)
return false;
- assert((DL.getTypeSizeInBits(VTy) % 8) == 0 &&
+ assert((DL.getTypeSizeInBits(VTy).getFixedSize() % 8) == 0 &&
"vector size not a multiple of element size?");
ElementSize /= 8;
@@ -2000,7 +1991,7 @@ static bool isIntegerWideningViableForSlice(const Slice &S,
Type *AllocaTy,
const DataLayout &DL,
bool &WholeAllocaOp) {
- uint64_t Size = DL.getTypeStoreSize(AllocaTy);
+ uint64_t Size = DL.getTypeStoreSize(AllocaTy).getFixedSize();
uint64_t RelBegin = S.beginOffset() - AllocBeginOffset;
uint64_t RelEnd = S.endOffset() - AllocBeginOffset;
@@ -2016,7 +2007,7 @@ static bool isIntegerWideningViableForSlice(const Slice &S,
if (LI->isVolatile())
return false;
// We can't handle loads that extend past the allocated memory.
- if (DL.getTypeStoreSize(LI->getType()) > Size)
+ if (DL.getTypeStoreSize(LI->getType()).getFixedSize() > Size)
return false;
// So far, AllocaSliceRewriter does not support widening split slice tails
// in rewriteIntegerLoad.
@@ -2028,7 +2019,7 @@ static bool isIntegerWideningViableForSlice(const Slice &S,
if (!isa<VectorType>(LI->getType()) && RelBegin == 0 && RelEnd == Size)
WholeAllocaOp = true;
if (IntegerType *ITy = dyn_cast<IntegerType>(LI->getType())) {
- if (ITy->getBitWidth() < DL.getTypeStoreSizeInBits(ITy))
+ if (ITy->getBitWidth() < DL.getTypeStoreSizeInBits(ITy).getFixedSize())
return false;
} else if (RelBegin != 0 || RelEnd != Size ||
!canConvertValue(DL, AllocaTy, LI->getType())) {
@@ -2041,7 +2032,7 @@ static bool isIntegerWideningViableForSlice(const Slice &S,
if (SI->isVolatile())
return false;
// We can't handle stores that extend past the allocated memory.
- if (DL.getTypeStoreSize(ValueTy) > Size)
+ if (DL.getTypeStoreSize(ValueTy).getFixedSize() > Size)
return false;
// So far, AllocaSliceRewriter does not support widening split slice tails
// in rewriteIntegerStore.
@@ -2053,7 +2044,7 @@ static bool isIntegerWideningViableForSlice(const Slice &S,
if (!isa<VectorType>(ValueTy) && RelBegin == 0 && RelEnd == Size)
WholeAllocaOp = true;
if (IntegerType *ITy = dyn_cast<IntegerType>(ValueTy)) {
- if (ITy->getBitWidth() < DL.getTypeStoreSizeInBits(ITy))
+ if (ITy->getBitWidth() < DL.getTypeStoreSizeInBits(ITy).getFixedSize())
return false;
} else if (RelBegin != 0 || RelEnd != Size ||
!canConvertValue(DL, ValueTy, AllocaTy)) {
@@ -2084,13 +2075,13 @@ static bool isIntegerWideningViableForSlice(const Slice &S,
/// promote the resulting alloca.
static bool isIntegerWideningViable(Partition &P, Type *AllocaTy,
const DataLayout &DL) {
- uint64_t SizeInBits = DL.getTypeSizeInBits(AllocaTy);
+ uint64_t SizeInBits = DL.getTypeSizeInBits(AllocaTy).getFixedSize();
// Don't create integer types larger than the maximum bitwidth.
if (SizeInBits > IntegerType::MAX_INT_BITS)
return false;
// Don't try to handle allocas with bit-padding.
- if (SizeInBits != DL.getTypeStoreSizeInBits(AllocaTy))
+ if (SizeInBits != DL.getTypeStoreSizeInBits(AllocaTy).getFixedSize())
return false;
// We need to ensure that an integer type with the appropriate bitwidth can
@@ -2129,11 +2120,13 @@ static Value *extractInteger(const DataLayout &DL, IRBuilderTy &IRB, Value *V,
const Twine &Name) {
LLVM_DEBUG(dbgs() << " start: " << *V << "\n");
IntegerType *IntTy = cast<IntegerType>(V->getType());
- assert(DL.getTypeStoreSize(Ty) + Offset <= DL.getTypeStoreSize(IntTy) &&
+ assert(DL.getTypeStoreSize(Ty).getFixedSize() + Offset <=
+ DL.getTypeStoreSize(IntTy).getFixedSize() &&
"Element extends past full value");
uint64_t ShAmt = 8 * Offset;
if (DL.isBigEndian())
- ShAmt = 8 * (DL.getTypeStoreSize(IntTy) - DL.getTypeStoreSize(Ty) - Offset);
+ ShAmt = 8 * (DL.getTypeStoreSize(IntTy).getFixedSize() -
+ DL.getTypeStoreSize(Ty).getFixedSize() - Offset);
if (ShAmt) {
V = IRB.CreateLShr(V, ShAmt, Name + ".shift");
LLVM_DEBUG(dbgs() << " shifted: " << *V << "\n");
@@ -2158,11 +2151,13 @@ static Value *insertInteger(const DataLayout &DL, IRBuilderTy &IRB, Value *Old,
V = IRB.CreateZExt(V, IntTy, Name + ".ext");
LLVM_DEBUG(dbgs() << " extended: " << *V << "\n");
}
- assert(DL.getTypeStoreSize(Ty) + Offset <= DL.getTypeStoreSize(IntTy) &&
+ assert(DL.getTypeStoreSize(Ty).getFixedSize() + Offset <=
+ DL.getTypeStoreSize(IntTy).getFixedSize() &&
"Element store outside of alloca store");
uint64_t ShAmt = 8 * Offset;
if (DL.isBigEndian())
- ShAmt = 8 * (DL.getTypeStoreSize(IntTy) - DL.getTypeStoreSize(Ty) - Offset);
+ ShAmt = 8 * (DL.getTypeStoreSize(IntTy).getFixedSize() -
+ DL.getTypeStoreSize(Ty).getFixedSize() - Offset);
if (ShAmt) {
V = IRB.CreateShl(V, ShAmt, Name + ".shift");
LLVM_DEBUG(dbgs() << " shifted: " << *V << "\n");
@@ -2180,7 +2175,7 @@ static Value *insertInteger(const DataLayout &DL, IRBuilderTy &IRB, Value *Old,
static Value *extractVector(IRBuilderTy &IRB, Value *V, unsigned BeginIndex,
unsigned EndIndex, const Twine &Name) {
- VectorType *VecTy = cast<VectorType>(V->getType());
+ auto *VecTy = cast<FixedVectorType>(V->getType());
unsigned NumElements = EndIndex - BeginIndex;
assert(NumElements <= VecTy->getNumElements() && "Too many elements!");
@@ -2194,12 +2189,12 @@ static Value *extractVector(IRBuilderTy &IRB, Value *V, unsigned BeginIndex,
return V;
}
- SmallVector<Constant *, 8> Mask;
+ SmallVector<int, 8> Mask;
Mask.reserve(NumElements);
for (unsigned i = BeginIndex; i != EndIndex; ++i)
- Mask.push_back(IRB.getInt32(i));
- V = IRB.CreateShuffleVector(V, UndefValue::get(V->getType()),
- ConstantVector::get(Mask), Name + ".extract");
+ Mask.push_back(i);
+ V = IRB.CreateShuffleVector(V, UndefValue::get(V->getType()), Mask,
+ Name + ".extract");
LLVM_DEBUG(dbgs() << " shuffle: " << *V << "\n");
return V;
}
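
The mask change above tracks the IRBuilder update that lets CreateShuffleVector take a plain ArrayRef<int> instead of a ConstantVector of i32 constants. A hedged sketch of extracting lanes [Begin, End) with the newer overload (the function name is invented for illustration):

    #include "llvm/ADT/SmallVector.h"
    #include "llvm/IR/Constants.h"
    #include "llvm/IR/IRBuilder.h"
    using namespace llvm;

    // Pull lanes [Begin, End) out of the fixed-width vector V.
    static Value *extractLanes(IRBuilder<> &B, Value *V, unsigned Begin,
                               unsigned End) {
      SmallVector<int, 8> Mask;
      for (unsigned I = Begin; I != End; ++I)
        Mask.push_back(I);        // integer mask, no ConstantVector needed
      return B.CreateShuffleVector(V, UndefValue::get(V->getType()), Mask,
                                   "lanes");
    }
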
@@ -2218,21 +2213,23 @@ static Value *insertVector(IRBuilderTy &IRB, Value *Old, Value *V,
return V;
}
- assert(Ty->getNumElements() <= VecTy->getNumElements() &&
+ assert(cast<FixedVectorType>(Ty)->getNumElements() <=
+ cast<FixedVectorType>(VecTy)->getNumElements() &&
"Too many elements!");
- if (Ty->getNumElements() == VecTy->getNumElements()) {
+ if (cast<FixedVectorType>(Ty)->getNumElements() ==
+ cast<FixedVectorType>(VecTy)->getNumElements()) {
assert(V->getType() == VecTy && "Vector type mismatch");
return V;
}
- unsigned EndIndex = BeginIndex + Ty->getNumElements();
+ unsigned EndIndex = BeginIndex + cast<FixedVectorType>(Ty)->getNumElements();
// When inserting a smaller vector into the larger to store, we first
// use a shuffle vector to widen it with undef elements, and then
// a second shuffle vector to select between the loaded vector and the
// incoming vector.
SmallVector<Constant *, 8> Mask;
- Mask.reserve(VecTy->getNumElements());
- for (unsigned i = 0; i != VecTy->getNumElements(); ++i)
+ Mask.reserve(cast<FixedVectorType>(VecTy)->getNumElements());
+ for (unsigned i = 0; i != cast<FixedVectorType>(VecTy)->getNumElements(); ++i)
if (i >= BeginIndex && i < EndIndex)
Mask.push_back(IRB.getInt32(i - BeginIndex));
else
@@ -2242,7 +2239,7 @@ static Value *insertVector(IRBuilderTy &IRB, Value *Old, Value *V,
LLVM_DEBUG(dbgs() << " shuffle: " << *V << "\n");
Mask.clear();
- for (unsigned i = 0; i != VecTy->getNumElements(); ++i)
+ for (unsigned i = 0; i != cast<FixedVectorType>(VecTy)->getNumElements(); ++i)
Mask.push_back(IRB.getInt1(i >= BeginIndex && i < EndIndex));
V = IRB.CreateSelect(ConstantVector::get(Mask), V, Old, Name + "blend");
@@ -2325,18 +2322,20 @@ public:
NewAllocaBeginOffset(NewAllocaBeginOffset),
NewAllocaEndOffset(NewAllocaEndOffset),
NewAllocaTy(NewAI.getAllocatedType()),
- IntTy(IsIntegerPromotable
- ? Type::getIntNTy(
- NewAI.getContext(),
- DL.getTypeSizeInBits(NewAI.getAllocatedType()))
- : nullptr),
+ IntTy(
+ IsIntegerPromotable
+ ? Type::getIntNTy(NewAI.getContext(),
+ DL.getTypeSizeInBits(NewAI.getAllocatedType())
+ .getFixedSize())
+ : nullptr),
VecTy(PromotableVecTy),
ElementTy(VecTy ? VecTy->getElementType() : nullptr),
- ElementSize(VecTy ? DL.getTypeSizeInBits(ElementTy) / 8 : 0),
+ ElementSize(VecTy ? DL.getTypeSizeInBits(ElementTy).getFixedSize() / 8
+ : 0),
PHIUsers(PHIUsers), SelectUsers(SelectUsers),
IRB(NewAI.getContext(), ConstantFolder()) {
if (VecTy) {
- assert((DL.getTypeSizeInBits(ElementTy) % 8) == 0 &&
+ assert((DL.getTypeSizeInBits(ElementTy).getFixedSize() % 8) == 0 &&
"Only multiple-of-8 sized vector elements are viable");
++NumVectorized;
}
@@ -2368,7 +2367,8 @@ public:
Instruction *OldUserI = cast<Instruction>(OldUse->getUser());
IRB.SetInsertPoint(OldUserI);
IRB.SetCurrentDebugLocation(OldUserI->getDebugLoc());
- IRB.SetNamePrefix(Twine(NewAI.getName()) + "." + Twine(BeginOffset) + ".");
+ IRB.getInserter().SetNamePrefix(
+ Twine(NewAI.getName()) + "." + Twine(BeginOffset) + ".");
CanSROA &= visit(cast<Instruction>(OldUse->getUser()));
if (VecTy || IntTy)
@@ -2429,14 +2429,9 @@ private:
///
/// You can optionally pass a type to this routine and if that type's ABI
/// alignment is itself suitable, this will return zero.
- MaybeAlign getSliceAlign(Type *Ty = nullptr) {
- const MaybeAlign NewAIAlign = DL.getValueOrABITypeAlignment(
- MaybeAlign(NewAI.getAlignment()), NewAI.getAllocatedType());
- const MaybeAlign Align =
- commonAlignment(NewAIAlign, NewBeginOffset - NewAllocaBeginOffset);
- return (Ty && Align && Align->value() == DL.getABITypeAlignment(Ty))
- ? None
- : Align;
+ Align getSliceAlign() {
+ return commonAlignment(NewAI.getAlign(),
+ NewBeginOffset - NewAllocaBeginOffset);
}
unsigned getIndex(uint64_t Offset) {
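
getSliceAlign now works purely in terms of the Align class: the new alloca's alignment is always known via AllocaInst::getAlign, so the MaybeAlign/ABI-fallback logic collapses to the common alignment of the alloca alignment and the slice offset. A small sketch of that arithmetic, with made-up numbers:

    #include "llvm/Support/Alignment.h"
    #include <cassert>
    using namespace llvm;

    // commonAlignment(A, Offset) is the largest power of two dividing both
    // A.value() and Offset, so a 16-byte-aligned alloca sliced at offset 4
    // only guarantees 4-byte alignment for the slice.
    void sliceAlignExample() {
      Align AllocaAlign(16);
      Align SliceAlign = commonAlignment(AllocaAlign, /*Offset=*/4);
      assert(SliceAlign == Align(4));
      (void)SliceAlign;
    }
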
@@ -2460,7 +2455,7 @@ private:
assert(EndIndex > BeginIndex && "Empty vector!");
Value *V = IRB.CreateAlignedLoad(NewAI.getAllocatedType(), &NewAI,
- NewAI.getAlignment(), "load");
+ NewAI.getAlign(), "load");
return extractVector(IRB, V, BeginIndex, EndIndex, "vec");
}
@@ -2468,7 +2463,7 @@ private:
assert(IntTy && "We cannot insert an integer to the alloca");
assert(!LI.isVolatile());
Value *V = IRB.CreateAlignedLoad(NewAI.getAllocatedType(), &NewAI,
- NewAI.getAlignment(), "load");
+ NewAI.getAlign(), "load");
V = convertValue(DL, IRB, V, IntTy);
assert(NewBeginOffset >= NewAllocaBeginOffset && "Out of bounds offset");
uint64_t Offset = NewBeginOffset - NewAllocaBeginOffset;
@@ -2500,7 +2495,8 @@ private:
Type *TargetTy = IsSplit ? Type::getIntNTy(LI.getContext(), SliceSize * 8)
: LI.getType();
- const bool IsLoadPastEnd = DL.getTypeStoreSize(TargetTy) > SliceSize;
+ const bool IsLoadPastEnd =
+ DL.getTypeStoreSize(TargetTy).getFixedSize() > SliceSize;
bool IsPtrAdjusted = false;
Value *V;
if (VecTy) {
@@ -2513,12 +2509,14 @@ private:
(IsLoadPastEnd && NewAllocaTy->isIntegerTy() &&
TargetTy->isIntegerTy()))) {
LoadInst *NewLI = IRB.CreateAlignedLoad(NewAI.getAllocatedType(), &NewAI,
- NewAI.getAlignment(),
- LI.isVolatile(), LI.getName());
+ NewAI.getAlign(), LI.isVolatile(),
+ LI.getName());
if (AATags)
NewLI->setAAMetadata(AATags);
if (LI.isVolatile())
NewLI->setAtomic(LI.getOrdering(), LI.getSyncScopeID());
+ if (NewLI->isAtomic())
+ NewLI->setAlignment(LI.getAlign());
// Any !nonnull metadata or !range metadata on the old load is also valid
// on the new load. This is even true in some cases even when the loads
@@ -2549,9 +2547,9 @@ private:
}
} else {
Type *LTy = TargetTy->getPointerTo(AS);
- LoadInst *NewLI = IRB.CreateAlignedLoad(
- TargetTy, getNewAllocaSlicePtr(IRB, LTy), getSliceAlign(TargetTy),
- LI.isVolatile(), LI.getName());
+ LoadInst *NewLI =
+ IRB.CreateAlignedLoad(TargetTy, getNewAllocaSlicePtr(IRB, LTy),
+ getSliceAlign(), LI.isVolatile(), LI.getName());
if (AATags)
NewLI->setAAMetadata(AATags);
if (LI.isVolatile())
@@ -2566,7 +2564,7 @@ private:
assert(!LI.isVolatile());
assert(LI.getType()->isIntegerTy() &&
"Only integer type loads and stores are split");
- assert(SliceSize < DL.getTypeStoreSize(LI.getType()) &&
+ assert(SliceSize < DL.getTypeStoreSize(LI.getType()).getFixedSize() &&
"Split load isn't smaller than original load");
assert(DL.typeSizeEqualsStoreSize(LI.getType()) &&
"Non-byte-multiple bit width");
@@ -2577,7 +2575,8 @@ private:
// the computed value, and then replace the placeholder with LI, leaving
// LI only used for this computation.
Value *Placeholder = new LoadInst(
- LI.getType(), UndefValue::get(LI.getType()->getPointerTo(AS)));
+ LI.getType(), UndefValue::get(LI.getType()->getPointerTo(AS)), "",
+ false, Align(1));
V = insertInteger(DL, IRB, Placeholder, V, NewBeginOffset - BeginOffset,
"insert");
LI.replaceAllUsesWith(V);
@@ -2600,19 +2599,20 @@ private:
unsigned EndIndex = getIndex(NewEndOffset);
assert(EndIndex > BeginIndex && "Empty vector!");
unsigned NumElements = EndIndex - BeginIndex;
- assert(NumElements <= VecTy->getNumElements() && "Too many elements!");
+ assert(NumElements <= cast<FixedVectorType>(VecTy)->getNumElements() &&
+ "Too many elements!");
Type *SliceTy = (NumElements == 1)
? ElementTy
- : VectorType::get(ElementTy, NumElements);
+ : FixedVectorType::get(ElementTy, NumElements);
if (V->getType() != SliceTy)
V = convertValue(DL, IRB, V, SliceTy);
// Mix in the existing elements.
Value *Old = IRB.CreateAlignedLoad(NewAI.getAllocatedType(), &NewAI,
- NewAI.getAlignment(), "load");
+ NewAI.getAlign(), "load");
V = insertVector(IRB, Old, V, BeginIndex, "vec");
}
- StoreInst *Store = IRB.CreateAlignedStore(V, &NewAI, NewAI.getAlignment());
+ StoreInst *Store = IRB.CreateAlignedStore(V, &NewAI, NewAI.getAlign());
if (AATags)
Store->setAAMetadata(AATags);
Pass.DeadInsts.insert(&SI);
@@ -2624,16 +2624,17 @@ private:
bool rewriteIntegerStore(Value *V, StoreInst &SI, AAMDNodes AATags) {
assert(IntTy && "We cannot extract an integer from the alloca");
assert(!SI.isVolatile());
- if (DL.getTypeSizeInBits(V->getType()) != IntTy->getBitWidth()) {
+ if (DL.getTypeSizeInBits(V->getType()).getFixedSize() !=
+ IntTy->getBitWidth()) {
Value *Old = IRB.CreateAlignedLoad(NewAI.getAllocatedType(), &NewAI,
- NewAI.getAlignment(), "oldload");
+ NewAI.getAlign(), "oldload");
Old = convertValue(DL, IRB, Old, IntTy);
assert(BeginOffset >= NewAllocaBeginOffset && "Out of bounds offset");
uint64_t Offset = BeginOffset - NewAllocaBeginOffset;
V = insertInteger(DL, IRB, Old, SI.getValueOperand(), Offset, "insert");
}
V = convertValue(DL, IRB, V, NewAllocaTy);
- StoreInst *Store = IRB.CreateAlignedStore(V, &NewAI, NewAI.getAlignment());
+ StoreInst *Store = IRB.CreateAlignedStore(V, &NewAI, NewAI.getAlign());
Store->copyMetadata(SI, {LLVMContext::MD_mem_parallel_loop_access,
LLVMContext::MD_access_group});
if (AATags)
@@ -2659,7 +2660,7 @@ private:
if (AllocaInst *AI = dyn_cast<AllocaInst>(V->stripInBoundsOffsets()))
Pass.PostPromotionWorklist.insert(AI);
- if (SliceSize < DL.getTypeStoreSize(V->getType())) {
+ if (SliceSize < DL.getTypeStoreSize(V->getType()).getFixedSize()) {
assert(!SI.isVolatile());
assert(V->getType()->isIntegerTy() &&
"Only integer type loads and stores are split");
@@ -2675,7 +2676,8 @@ private:
if (IntTy && V->getType()->isIntegerTy())
return rewriteIntegerStore(V, SI, AATags);
- const bool IsStorePastEnd = DL.getTypeStoreSize(V->getType()) > SliceSize;
+ const bool IsStorePastEnd =
+ DL.getTypeStoreSize(V->getType()).getFixedSize() > SliceSize;
StoreInst *NewSI;
if (NewBeginOffset == NewAllocaBeginOffset &&
NewEndOffset == NewAllocaEndOffset &&
@@ -2695,13 +2697,13 @@ private:
}
V = convertValue(DL, IRB, V, NewAllocaTy);
- NewSI = IRB.CreateAlignedStore(V, &NewAI, NewAI.getAlignment(),
- SI.isVolatile());
+ NewSI =
+ IRB.CreateAlignedStore(V, &NewAI, NewAI.getAlign(), SI.isVolatile());
} else {
unsigned AS = SI.getPointerAddressSpace();
Value *NewPtr = getNewAllocaSlicePtr(IRB, V->getType()->getPointerTo(AS));
- NewSI = IRB.CreateAlignedStore(V, NewPtr, getSliceAlign(V->getType()),
- SI.isVolatile());
+ NewSI =
+ IRB.CreateAlignedStore(V, NewPtr, getSliceAlign(), SI.isVolatile());
}
NewSI->copyMetadata(SI, {LLVMContext::MD_mem_parallel_loop_access,
LLVMContext::MD_access_group});
@@ -2709,6 +2711,8 @@ private:
NewSI->setAAMetadata(AATags);
if (SI.isVolatile())
NewSI->setAtomic(SI.getOrdering(), SI.getSyncScopeID());
+ if (NewSI->isAtomic())
+ NewSI->setAlignment(SI.getAlign());
Pass.DeadInsts.insert(&SI);
deleteIfTriviallyDead(OldOp);
@@ -2786,9 +2790,9 @@ private:
return false;
const auto Len = C->getZExtValue();
auto *Int8Ty = IntegerType::getInt8Ty(NewAI.getContext());
- auto *SrcTy = VectorType::get(Int8Ty, Len);
+ auto *SrcTy = FixedVectorType::get(Int8Ty, Len);
return canConvertValue(DL, SrcTy, AllocaTy) &&
- DL.isLegalInteger(DL.getTypeSizeInBits(ScalarTy));
+ DL.isLegalInteger(DL.getTypeSizeInBits(ScalarTy).getFixedSize());
}();
// If this doesn't map cleanly onto the alloca type, and that type isn't
@@ -2820,16 +2824,17 @@ private:
unsigned EndIndex = getIndex(NewEndOffset);
assert(EndIndex > BeginIndex && "Empty vector!");
unsigned NumElements = EndIndex - BeginIndex;
- assert(NumElements <= VecTy->getNumElements() && "Too many elements!");
+ assert(NumElements <= cast<FixedVectorType>(VecTy)->getNumElements() &&
+ "Too many elements!");
- Value *Splat =
- getIntegerSplat(II.getValue(), DL.getTypeSizeInBits(ElementTy) / 8);
+ Value *Splat = getIntegerSplat(
+ II.getValue(), DL.getTypeSizeInBits(ElementTy).getFixedSize() / 8);
Splat = convertValue(DL, IRB, Splat, ElementTy);
if (NumElements > 1)
Splat = getVectorSplat(Splat, NumElements);
Value *Old = IRB.CreateAlignedLoad(NewAI.getAllocatedType(), &NewAI,
- NewAI.getAlignment(), "oldload");
+ NewAI.getAlign(), "oldload");
V = insertVector(IRB, Old, Splat, BeginIndex, "vec");
} else if (IntTy) {
// If this is a memset on an alloca where we can widen stores, insert the
@@ -2842,7 +2847,7 @@ private:
if (IntTy && (BeginOffset != NewAllocaBeginOffset ||
EndOffset != NewAllocaBeginOffset)) {
Value *Old = IRB.CreateAlignedLoad(NewAI.getAllocatedType(), &NewAI,
- NewAI.getAlignment(), "oldload");
+ NewAI.getAlign(), "oldload");
Old = convertValue(DL, IRB, Old, IntTy);
uint64_t Offset = NewBeginOffset - NewAllocaBeginOffset;
V = insertInteger(DL, IRB, Old, V, Offset, "insert");
@@ -2856,15 +2861,17 @@ private:
assert(NewBeginOffset == NewAllocaBeginOffset);
assert(NewEndOffset == NewAllocaEndOffset);
- V = getIntegerSplat(II.getValue(), DL.getTypeSizeInBits(ScalarTy) / 8);
+ V = getIntegerSplat(II.getValue(),
+ DL.getTypeSizeInBits(ScalarTy).getFixedSize() / 8);
if (VectorType *AllocaVecTy = dyn_cast<VectorType>(AllocaTy))
- V = getVectorSplat(V, AllocaVecTy->getNumElements());
+ V = getVectorSplat(
+ V, cast<FixedVectorType>(AllocaVecTy)->getNumElements());
V = convertValue(DL, IRB, V, AllocaTy);
}
- StoreInst *New = IRB.CreateAlignedStore(V, &NewAI, NewAI.getAlignment(),
- II.isVolatile());
+ StoreInst *New =
+ IRB.CreateAlignedStore(V, &NewAI, NewAI.getAlign(), II.isVolatile());
if (AATags)
New->setAAMetadata(AATags);
LLVM_DEBUG(dbgs() << " to: " << *New << "\n");
@@ -2919,7 +2926,8 @@ private:
bool EmitMemCpy =
!VecTy && !IntTy &&
(BeginOffset > NewAllocaBeginOffset || EndOffset < NewAllocaEndOffset ||
- SliceSize != DL.getTypeStoreSize(NewAI.getAllocatedType()) ||
+ SliceSize !=
+ DL.getTypeStoreSize(NewAI.getAllocatedType()).getFixedSize() ||
!NewAI.getAllocatedType()->isSingleValueType());
// If we're just going to emit a memcpy, the alloca hasn't changed, and the
@@ -2955,7 +2963,7 @@ private:
unsigned OffsetWidth = DL.getIndexSizeInBits(OtherAS);
APInt OtherOffset(OffsetWidth, NewBeginOffset - BeginOffset);
Align OtherAlign =
- assumeAligned(IsDest ? II.getSourceAlignment() : II.getDestAlignment());
+ (IsDest ? II.getSourceAlign() : II.getDestAlign()).valueOrOne();
OtherAlign =
commonAlignment(OtherAlign, OtherOffset.zextOrTrunc(64).getZExtValue());
@@ -3007,7 +3015,7 @@ private:
if (NumElements == 1)
OtherTy = VecTy->getElementType();
else
- OtherTy = VectorType::get(VecTy->getElementType(), NumElements);
+ OtherTy = FixedVectorType::get(VecTy->getElementType(), NumElements);
} else if (IntTy && !IsWholeAlloca) {
OtherTy = SubIntTy;
} else {
@@ -3028,11 +3036,11 @@ private:
Value *Src;
if (VecTy && !IsWholeAlloca && !IsDest) {
Src = IRB.CreateAlignedLoad(NewAI.getAllocatedType(), &NewAI,
- NewAI.getAlignment(), "load");
+ NewAI.getAlign(), "load");
Src = extractVector(IRB, Src, BeginIndex, EndIndex, "vec");
} else if (IntTy && !IsWholeAlloca && !IsDest) {
Src = IRB.CreateAlignedLoad(NewAI.getAllocatedType(), &NewAI,
- NewAI.getAlignment(), "load");
+ NewAI.getAlign(), "load");
Src = convertValue(DL, IRB, Src, IntTy);
uint64_t Offset = NewBeginOffset - NewAllocaBeginOffset;
Src = extractInteger(DL, IRB, Src, SubIntTy, Offset, "extract");
@@ -3046,11 +3054,11 @@ private:
if (VecTy && !IsWholeAlloca && IsDest) {
Value *Old = IRB.CreateAlignedLoad(NewAI.getAllocatedType(), &NewAI,
- NewAI.getAlignment(), "oldload");
+ NewAI.getAlign(), "oldload");
Src = insertVector(IRB, Old, Src, BeginIndex, "vec");
} else if (IntTy && !IsWholeAlloca && IsDest) {
Value *Old = IRB.CreateAlignedLoad(NewAI.getAllocatedType(), &NewAI,
- NewAI.getAlignment(), "oldload");
+ NewAI.getAlign(), "oldload");
Old = convertValue(DL, IRB, Old, IntTy);
uint64_t Offset = NewBeginOffset - NewAllocaBeginOffset;
Src = insertInteger(DL, IRB, Old, Src, Offset, "insert");
@@ -3115,17 +3123,12 @@ private:
Instruction *I = Uses.pop_back_val();
if (LoadInst *LI = dyn_cast<LoadInst>(I)) {
- MaybeAlign LoadAlign = DL.getValueOrABITypeAlignment(
- MaybeAlign(LI->getAlignment()), LI->getType());
- LI->setAlignment(std::min(LoadAlign, getSliceAlign()));
+ LI->setAlignment(std::min(LI->getAlign(), getSliceAlign()));
continue;
}
if (StoreInst *SI = dyn_cast<StoreInst>(I)) {
- Value *Op = SI->getOperand(0);
- MaybeAlign StoreAlign = DL.getValueOrABITypeAlignment(
- MaybeAlign(SI->getAlignment()), Op->getType());
- SI->setAlignment(std::min(StoreAlign, getSliceAlign()));
- continue;
+ SI->setAlignment(std::min(SI->getAlign(), getSliceAlign()));
+ continue;
}
assert(isa<BitCastInst>(I) || isa<AddrSpaceCastInst>(I) ||
@@ -3146,14 +3149,14 @@ private:
// as local as possible to the PHI. To do that, we re-use the location of
// the old pointer, which necessarily must be in the right position to
// dominate the PHI.
- IRBuilderTy PtrBuilder(IRB);
+ IRBuilderBase::InsertPointGuard Guard(IRB);
if (isa<PHINode>(OldPtr))
- PtrBuilder.SetInsertPoint(&*OldPtr->getParent()->getFirstInsertionPt());
+ IRB.SetInsertPoint(&*OldPtr->getParent()->getFirstInsertionPt());
else
- PtrBuilder.SetInsertPoint(OldPtr);
- PtrBuilder.SetCurrentDebugLocation(OldPtr->getDebugLoc());
+ IRB.SetInsertPoint(OldPtr);
+ IRB.SetCurrentDebugLocation(OldPtr->getDebugLoc());
- Value *NewPtr = getNewAllocaSlicePtr(PtrBuilder, OldPtr->getType());
+ Value *NewPtr = getNewAllocaSlicePtr(IRB, OldPtr->getType());
// Replace the operands which were using the old pointer.
std::replace(PN.op_begin(), PN.op_end(), cast<Value>(OldPtr), NewPtr);
@@ -3357,7 +3360,7 @@ private:
Value *GEP =
IRB.CreateInBoundsGEP(BaseTy, Ptr, GEPIndices, Name + ".gep");
LoadInst *Load =
- IRB.CreateAlignedLoad(Ty, GEP, Alignment.value(), Name + ".load");
+ IRB.CreateAlignedLoad(Ty, GEP, Alignment, Name + ".load");
if (AATags)
Load->setAAMetadata(AATags);
Agg = IRB.CreateInsertValue(Agg, Load, Indices, Name + ".insert");
@@ -3375,9 +3378,10 @@ private:
AAMDNodes AATags;
LI.getAAMetadata(AATags);
LoadOpSplitter Splitter(&LI, *U, LI.getType(), AATags,
- getAdjustedAlignment(&LI, 0, DL), DL);
+ getAdjustedAlignment(&LI, 0), DL);
Value *V = UndefValue::get(LI.getType());
Splitter.emitSplitOps(LI.getType(), V, LI.getName() + ".fca");
+ Visited.erase(&LI);
LI.replaceAllUsesWith(V);
LI.eraseFromParent();
return true;
@@ -3403,7 +3407,7 @@ private:
Value *InBoundsGEP =
IRB.CreateInBoundsGEP(BaseTy, Ptr, GEPIndices, Name + ".gep");
StoreInst *Store =
- IRB.CreateAlignedStore(ExtractValue, InBoundsGEP, Alignment.value());
+ IRB.CreateAlignedStore(ExtractValue, InBoundsGEP, Alignment);
if (AATags)
Store->setAAMetadata(AATags);
LLVM_DEBUG(dbgs() << " to: " << *Store << "\n");
@@ -3422,8 +3426,9 @@ private:
AAMDNodes AATags;
SI.getAAMetadata(AATags);
StoreOpSplitter Splitter(&SI, *U, V->getType(), AATags,
- getAdjustedAlignment(&SI, 0, DL), DL);
+ getAdjustedAlignment(&SI, 0), DL);
Splitter.emitSplitOps(V->getType(), V, V->getName() + ".fca");
+ Visited.erase(&SI);
SI.eraseFromParent();
return true;
}
@@ -3438,7 +3443,110 @@ private:
return false;
}
+ // Fold gep (select cond, ptr1, ptr2) => select cond, gep(ptr1), gep(ptr2)
+ bool foldGEPSelect(GetElementPtrInst &GEPI) {
+ if (!GEPI.hasAllConstantIndices())
+ return false;
+
+ SelectInst *Sel = cast<SelectInst>(GEPI.getPointerOperand());
+
+ LLVM_DEBUG(dbgs() << " Rewriting gep(select) -> select(gep):"
+ << "\n original: " << *Sel
+ << "\n " << GEPI);
+
+ IRBuilderTy Builder(&GEPI);
+ SmallVector<Value *, 4> Index(GEPI.idx_begin(), GEPI.idx_end());
+ bool IsInBounds = GEPI.isInBounds();
+
+ Value *True = Sel->getTrueValue();
+ Value *NTrue =
+ IsInBounds
+ ? Builder.CreateInBoundsGEP(True, Index,
+ True->getName() + ".sroa.gep")
+ : Builder.CreateGEP(True, Index, True->getName() + ".sroa.gep");
+
+ Value *False = Sel->getFalseValue();
+
+ Value *NFalse =
+ IsInBounds
+ ? Builder.CreateInBoundsGEP(False, Index,
+ False->getName() + ".sroa.gep")
+ : Builder.CreateGEP(False, Index, False->getName() + ".sroa.gep");
+
+ Value *NSel = Builder.CreateSelect(Sel->getCondition(), NTrue, NFalse,
+ Sel->getName() + ".sroa.sel");
+ Visited.erase(&GEPI);
+ GEPI.replaceAllUsesWith(NSel);
+ GEPI.eraseFromParent();
+ Instruction *NSelI = cast<Instruction>(NSel);
+ Visited.insert(NSelI);
+ enqueueUsers(*NSelI);
+
+ LLVM_DEBUG(dbgs() << "\n to: " << *NTrue
+ << "\n " << *NFalse
+ << "\n " << *NSel << '\n');
+
+ return true;
+ }
+
+ // Fold gep (phi ptr1, ptr2) => phi gep(ptr1), gep(ptr2)
+ bool foldGEPPhi(GetElementPtrInst &GEPI) {
+ if (!GEPI.hasAllConstantIndices())
+ return false;
+
+ PHINode *PHI = cast<PHINode>(GEPI.getPointerOperand());
+ if (GEPI.getParent() != PHI->getParent() ||
+ llvm::any_of(PHI->incoming_values(), [](Value *In)
+ { Instruction *I = dyn_cast<Instruction>(In);
+ return !I || isa<GetElementPtrInst>(I) || isa<PHINode>(I) ||
+ succ_empty(I->getParent()) ||
+ !I->getParent()->isLegalToHoistInto();
+ }))
+ return false;
+
+ LLVM_DEBUG(dbgs() << " Rewriting gep(phi) -> phi(gep):"
+ << "\n original: " << *PHI
+ << "\n " << GEPI
+ << "\n to: ");
+
+ SmallVector<Value *, 4> Index(GEPI.idx_begin(), GEPI.idx_end());
+ bool IsInBounds = GEPI.isInBounds();
+ IRBuilderTy PHIBuilder(GEPI.getParent()->getFirstNonPHI());
+ PHINode *NewPN = PHIBuilder.CreatePHI(GEPI.getType(),
+ PHI->getNumIncomingValues(),
+ PHI->getName() + ".sroa.phi");
+ for (unsigned I = 0, E = PHI->getNumIncomingValues(); I != E; ++I) {
+ Instruction *In = cast<Instruction>(PHI->getIncomingValue(I));
+
+ IRBuilderTy B(In->getParent(), std::next(In->getIterator()));
+ Value *NewVal = IsInBounds
+ ? B.CreateInBoundsGEP(In, Index, In->getName() + ".sroa.gep")
+ : B.CreateGEP(In, Index, In->getName() + ".sroa.gep");
+ NewPN->addIncoming(NewVal, PHI->getIncomingBlock(I));
+ }
+
+ Visited.erase(&GEPI);
+ GEPI.replaceAllUsesWith(NewPN);
+ GEPI.eraseFromParent();
+ Visited.insert(NewPN);
+ enqueueUsers(*NewPN);
+
+ LLVM_DEBUG(for (Value *In : NewPN->incoming_values())
+ dbgs() << "\n " << *In;
+ dbgs() << "\n " << *NewPN << '\n');
+
+ return true;
+ }
+
bool visitGetElementPtrInst(GetElementPtrInst &GEPI) {
+ if (isa<SelectInst>(GEPI.getPointerOperand()) &&
+ foldGEPSelect(GEPI))
+ return true;
+
+ if (isa<PHINode>(GEPI.getPointerOperand()) &&
+ foldGEPPhi(GEPI))
+ return true;
+
enqueueUsers(GEPI);
return false;
}
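
The two folds added above rewrite a GEP whose pointer operand is a select (or a same-block phi) into a select/phi of GEPs, so the use-walker sees direct, constant-index GEPs of the alloca instead of an opaque pointer select. A hedged sketch of the select case, stripped of the inbounds handling and worklist bookkeeping in the real patch:

    #include "llvm/ADT/SmallVector.h"
    #include "llvm/IR/IRBuilder.h"
    #include "llvm/IR/Instructions.h"
    using namespace llvm;

    // %g = getelementptr T, T* (select %c, %p, %q), <constant indices>
    //   ==>  %g = select %c, (gep %p, indices), (gep %q, indices)
    // Safe to duplicate because the indices are all constants.
    static Value *foldGEPOfSelect(GetElementPtrInst &GEP, SelectInst &Sel) {
      IRBuilder<> B(&GEP);
      SmallVector<Value *, 4> Idx(GEP.idx_begin(), GEP.idx_end());
      Value *T = B.CreateGEP(Sel.getTrueValue(), Idx, "t.gep");
      Value *F = B.CreateGEP(Sel.getFalseValue(), Idx, "f.gep");
      return B.CreateSelect(Sel.getCondition(), T, F, "gep.sel");
    }
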
@@ -3465,8 +3573,8 @@ static Type *stripAggregateTypeWrapping(const DataLayout &DL, Type *Ty) {
if (Ty->isSingleValueType())
return Ty;
- uint64_t AllocSize = DL.getTypeAllocSize(Ty);
- uint64_t TypeSize = DL.getTypeSizeInBits(Ty);
+ uint64_t AllocSize = DL.getTypeAllocSize(Ty).getFixedSize();
+ uint64_t TypeSize = DL.getTypeSizeInBits(Ty).getFixedSize();
Type *InnerTy;
if (ArrayType *ArrTy = dyn_cast<ArrayType>(Ty)) {
@@ -3479,8 +3587,8 @@ static Type *stripAggregateTypeWrapping(const DataLayout &DL, Type *Ty) {
return Ty;
}
- if (AllocSize > DL.getTypeAllocSize(InnerTy) ||
- TypeSize > DL.getTypeSizeInBits(InnerTy))
+ if (AllocSize > DL.getTypeAllocSize(InnerTy).getFixedSize() ||
+ TypeSize > DL.getTypeSizeInBits(InnerTy).getFixedSize())
return Ty;
return stripAggregateTypeWrapping(DL, InnerTy);
@@ -3501,17 +3609,28 @@ static Type *stripAggregateTypeWrapping(const DataLayout &DL, Type *Ty) {
/// return a type if necessary.
static Type *getTypePartition(const DataLayout &DL, Type *Ty, uint64_t Offset,
uint64_t Size) {
- if (Offset == 0 && DL.getTypeAllocSize(Ty) == Size)
+ if (Offset == 0 && DL.getTypeAllocSize(Ty).getFixedSize() == Size)
return stripAggregateTypeWrapping(DL, Ty);
- if (Offset > DL.getTypeAllocSize(Ty) ||
- (DL.getTypeAllocSize(Ty) - Offset) < Size)
+ if (Offset > DL.getTypeAllocSize(Ty).getFixedSize() ||
+ (DL.getTypeAllocSize(Ty).getFixedSize() - Offset) < Size)
return nullptr;
- if (SequentialType *SeqTy = dyn_cast<SequentialType>(Ty)) {
- Type *ElementTy = SeqTy->getElementType();
- uint64_t ElementSize = DL.getTypeAllocSize(ElementTy);
+ if (isa<ArrayType>(Ty) || isa<VectorType>(Ty)) {
+ Type *ElementTy;
+ uint64_t TyNumElements;
+ if (auto *AT = dyn_cast<ArrayType>(Ty)) {
+ ElementTy = AT->getElementType();
+ TyNumElements = AT->getNumElements();
+ } else {
+ // FIXME: This isn't right for vectors with non-byte-sized or
+ // non-power-of-two sized elements.
+ auto *VT = cast<FixedVectorType>(Ty);
+ ElementTy = VT->getElementType();
+ TyNumElements = VT->getNumElements();
+ }
+ uint64_t ElementSize = DL.getTypeAllocSize(ElementTy).getFixedSize();
uint64_t NumSkippedElements = Offset / ElementSize;
- if (NumSkippedElements >= SeqTy->getNumElements())
+ if (NumSkippedElements >= TyNumElements)
return nullptr;
Offset -= NumSkippedElements * ElementSize;
@@ -3549,7 +3668,7 @@ static Type *getTypePartition(const DataLayout &DL, Type *Ty, uint64_t Offset,
Offset -= SL->getElementOffset(Index);
Type *ElementTy = STy->getElementType(Index);
- uint64_t ElementSize = DL.getTypeAllocSize(ElementTy);
+ uint64_t ElementSize = DL.getTypeAllocSize(ElementTy).getFixedSize();
if (Offset >= ElementSize)
return nullptr; // The offset points into alignment padding.
@@ -3860,7 +3979,7 @@ bool SROA::presplitLoadsAndStores(AllocaInst &AI, AllocaSlices &AS) {
getAdjustedPtr(IRB, DL, BasePtr,
APInt(DL.getIndexSizeInBits(AS), PartOffset),
PartPtrTy, BasePtr->getName() + "."),
- getAdjustedAlignment(LI, PartOffset, DL).value(),
+ getAdjustedAlignment(LI, PartOffset),
/*IsVolatile*/ false, LI->getName());
PLoad->copyMetadata(*LI, {LLVMContext::MD_mem_parallel_loop_access,
LLVMContext::MD_access_group});
@@ -3918,7 +4037,7 @@ bool SROA::presplitLoadsAndStores(AllocaInst &AI, AllocaSlices &AS) {
getAdjustedPtr(IRB, DL, StoreBasePtr,
APInt(DL.getIndexSizeInBits(AS), PartOffset),
PartPtrTy, StoreBasePtr->getName() + "."),
- getAdjustedAlignment(SI, PartOffset, DL).value(),
+ getAdjustedAlignment(SI, PartOffset),
/*IsVolatile*/ false);
PStore->copyMetadata(*LI, {LLVMContext::MD_mem_parallel_loop_access,
LLVMContext::MD_access_group});
@@ -4003,7 +4122,7 @@ bool SROA::presplitLoadsAndStores(AllocaInst &AI, AllocaSlices &AS) {
getAdjustedPtr(IRB, DL, LoadBasePtr,
APInt(DL.getIndexSizeInBits(AS), PartOffset),
LoadPartPtrTy, LoadBasePtr->getName() + "."),
- getAdjustedAlignment(LI, PartOffset, DL).value(),
+ getAdjustedAlignment(LI, PartOffset),
/*IsVolatile*/ false, LI->getName());
}
@@ -4015,7 +4134,7 @@ bool SROA::presplitLoadsAndStores(AllocaInst &AI, AllocaSlices &AS) {
getAdjustedPtr(IRB, DL, StoreBasePtr,
APInt(DL.getIndexSizeInBits(AS), PartOffset),
StorePartPtrTy, StoreBasePtr->getName() + "."),
- getAdjustedAlignment(SI, PartOffset, DL).value(),
+ getAdjustedAlignment(SI, PartOffset),
/*IsVolatile*/ false);
// Now build a new slice for the alloca.
@@ -4117,7 +4236,7 @@ AllocaInst *SROA::rewritePartition(AllocaInst &AI, AllocaSlices &AS,
Type *SliceTy = nullptr;
const DataLayout &DL = AI.getModule()->getDataLayout();
if (Type *CommonUseTy = findCommonType(P.begin(), P.end(), P.endOffset()))
- if (DL.getTypeAllocSize(CommonUseTy) >= P.size())
+ if (DL.getTypeAllocSize(CommonUseTy).getFixedSize() >= P.size())
SliceTy = CommonUseTy;
if (!SliceTy)
if (Type *TypePartitionTy = getTypePartition(DL, AI.getAllocatedType(),
@@ -4129,7 +4248,7 @@ AllocaInst *SROA::rewritePartition(AllocaInst &AI, AllocaSlices &AS,
SliceTy = Type::getIntNTy(*C, P.size() * 8);
if (!SliceTy)
SliceTy = ArrayType::get(Type::getInt8Ty(*C), P.size());
- assert(DL.getTypeAllocSize(SliceTy) >= P.size());
+ assert(DL.getTypeAllocSize(SliceTy).getFixedSize() >= P.size());
bool IsIntegerPromotable = isIntegerWideningViable(P, SliceTy, DL);
@@ -4151,19 +4270,14 @@ AllocaInst *SROA::rewritePartition(AllocaInst &AI, AllocaSlices &AS,
// FIXME: We might want to defer PHI speculation until after here.
// FIXME: return nullptr;
} else {
- // If alignment is unspecified we fallback on the one required by the ABI
- // for this type. We also make sure the alignment is compatible with
- // P.beginOffset().
- const Align Alignment = commonAlignment(
- DL.getValueOrABITypeAlignment(MaybeAlign(AI.getAlignment()),
- AI.getAllocatedType()),
- P.beginOffset());
+ // Make sure the alignment is compatible with P.beginOffset().
+ const Align Alignment = commonAlignment(AI.getAlign(), P.beginOffset());
// If we will get at least this much alignment from the type alone, leave
// the alloca's alignment unconstrained.
- const bool IsUnconstrained = Alignment <= DL.getABITypeAlignment(SliceTy);
+ const bool IsUnconstrained = Alignment <= DL.getABITypeAlign(SliceTy);
NewAI = new AllocaInst(
SliceTy, AI.getType()->getAddressSpace(), nullptr,
- IsUnconstrained ? MaybeAlign() : Alignment,
+ IsUnconstrained ? DL.getPrefTypeAlign(SliceTy) : Alignment,
AI.getName() + ".sroa." + Twine(P.begin() - AS.begin()), &AI);
// Copy the old AI debug location over to the new one.
NewAI->setDebugLoc(AI.getDebugLoc());
@@ -4270,7 +4384,8 @@ bool SROA::splitAlloca(AllocaInst &AI, AllocaSlices &AS) {
// to be rewritten into a partition.
bool IsSorted = true;
- uint64_t AllocaSize = DL.getTypeAllocSize(AI.getAllocatedType());
+ uint64_t AllocaSize =
+ DL.getTypeAllocSize(AI.getAllocatedType()).getFixedSize();
const uint64_t MaxBitVectorSize = 1024;
if (AllocaSize <= MaxBitVectorSize) {
// If a byte boundary is included in any load or store, a slice starting or
@@ -4334,7 +4449,8 @@ bool SROA::splitAlloca(AllocaInst &AI, AllocaSlices &AS) {
Changed = true;
if (NewAI != &AI) {
uint64_t SizeOfByte = 8;
- uint64_t AllocaSize = DL.getTypeSizeInBits(NewAI->getAllocatedType());
+ uint64_t AllocaSize =
+ DL.getTypeSizeInBits(NewAI->getAllocatedType()).getFixedSize();
// Don't include any padding.
uint64_t Size = std::min(AllocaSize, P.size() * SizeOfByte);
Fragments.push_back(Fragment(NewAI, P.beginOffset() * SizeOfByte, Size));
@@ -4354,7 +4470,8 @@ bool SROA::splitAlloca(AllocaInst &AI, AllocaSlices &AS) {
auto *Expr = DbgDeclares.front()->getExpression();
auto VarSize = Var->getSizeInBits();
DIBuilder DIB(*AI.getModule(), /*AllowUnresolved*/ false);
- uint64_t AllocaSize = DL.getTypeSizeInBits(AI.getAllocatedType());
+ uint64_t AllocaSize =
+ DL.getTypeSizeInBits(AI.getAllocatedType()).getFixedSize();
for (auto Fragment : Fragments) {
// Create a fragment expression describing the new partition or reuse AI's
// expression if there is only one partition.
@@ -4442,8 +4559,9 @@ bool SROA::runOnAlloca(AllocaInst &AI) {
const DataLayout &DL = AI.getModule()->getDataLayout();
// Skip alloca forms that this analysis can't handle.
- if (AI.isArrayAllocation() || !AI.getAllocatedType()->isSized() ||
- DL.getTypeAllocSize(AI.getAllocatedType()) == 0)
+ auto *AT = AI.getAllocatedType();
+ if (AI.isArrayAllocation() || !AT->isSized() || isa<ScalableVectorType>(AT) ||
+ DL.getTypeAllocSize(AT).getFixedSize() == 0)
return false;
bool Changed = false;
@@ -4563,8 +4681,14 @@ PreservedAnalyses SROA::runImpl(Function &F, DominatorTree &RunDT,
BasicBlock &EntryBB = F.getEntryBlock();
for (BasicBlock::iterator I = EntryBB.begin(), E = std::prev(EntryBB.end());
I != E; ++I) {
- if (AllocaInst *AI = dyn_cast<AllocaInst>(I))
- Worklist.insert(AI);
+ if (AllocaInst *AI = dyn_cast<AllocaInst>(I)) {
+ if (isa<ScalableVectorType>(AI->getAllocatedType())) {
+ if (isAllocaPromotable(AI))
+ PromotableAllocas.push_back(AI);
+ } else {
+ Worklist.insert(AI);
+ }
+ }
}
bool Changed = false;
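
The last two SROA hunks make the pass skip scalable-vector allocas: runOnAlloca bails out on them (their size is a multiple of vscale, so byte-level slicing is not meaningful), while the entry-block scan sends promotable scalable allocas straight to whole-value promotion instead of the slicing worklist. A hedged sketch of that triage, with invented names:

    #include "llvm/IR/DerivedTypes.h"
    #include "llvm/IR/Instructions.h"
    #include "llvm/Transforms/Utils/PromoteMemToReg.h"
    using namespace llvm;

    enum class AllocaPlan { Skip, PromoteWholeValue, TrySlicing };

    // Illustrative classification only; the real pass keeps its own worklists.
    static AllocaPlan classifyAlloca(AllocaInst *AI) {
      Type *AT = AI->getAllocatedType();
      if (isa<ScalableVectorType>(AT))
        return isAllocaPromotable(AI) ? AllocaPlan::PromoteWholeValue
                                      : AllocaPlan::Skip;
      return AllocaPlan::TrySlicing;
    }
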
diff --git a/llvm/lib/Transforms/Scalar/Scalarizer.cpp b/llvm/lib/Transforms/Scalar/Scalarizer.cpp
index c25c6c632b8f..851bd79cd6d8 100644
--- a/llvm/lib/Transforms/Scalar/Scalarizer.cpp
+++ b/llvm/lib/Transforms/Scalar/Scalarizer.cpp
@@ -22,8 +22,8 @@
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
-#include "llvm/IR/Dominators.h"
#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstVisitor.h"
@@ -41,6 +41,7 @@
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Transforms/Scalar.h"
+#include "llvm/Transforms/Utils/Local.h"
#include <cassert>
#include <cstdint>
#include <iterator>
@@ -51,6 +52,11 @@ using namespace llvm;
#define DEBUG_TYPE "scalarizer"
+static cl::opt<bool> ScalarizeVariableInsertExtract(
+ "scalarize-variable-insert-extract", cl::init(true), cl::Hidden,
+ cl::desc("Allow the scalarizer pass to scalarize "
+ "insertelement/extractelement with variable index"));
+
// This is disabled by default because having separate loads and stores
// makes it more likely that the -combiner-alias-analysis limits will be
// reached.
@@ -156,8 +162,8 @@ struct VectorLayout {
VectorLayout() = default;
// Return the alignment of element I.
- uint64_t getElemAlign(unsigned I) {
- return MinAlign(VecAlign, I * ElemSize);
+ Align getElemAlign(unsigned I) {
+ return commonAlignment(VecAlign, I * ElemSize);
}
// The type of the vector.
@@ -167,7 +173,7 @@ struct VectorLayout {
Type *ElemTy = nullptr;
// The alignment of the vector.
- uint64_t VecAlign = 0;
+ Align VecAlign;
// The size of each element.
uint64_t ElemSize = 0;
@@ -192,6 +198,8 @@ public:
bool visitGetElementPtrInst(GetElementPtrInst &GEPI);
bool visitCastInst(CastInst &CI);
bool visitBitCastInst(BitCastInst &BCI);
+ bool visitInsertElementInst(InsertElementInst &IEI);
+ bool visitExtractElementInst(ExtractElementInst &EEI);
bool visitShuffleVectorInst(ShuffleVectorInst &SVI);
bool visitPHINode(PHINode &PHI);
bool visitLoadInst(LoadInst &LI);
@@ -203,8 +211,8 @@ private:
void gather(Instruction *Op, const ValueVector &CV);
bool canTransferMetadata(unsigned Kind);
void transferMetadataAndIRFlags(Instruction *Op, const ValueVector &CV);
- bool getVectorLayout(Type *Ty, unsigned Alignment, VectorLayout &Layout,
- const DataLayout &DL);
+ Optional<VectorLayout> getVectorLayout(Type *Ty, Align Alignment,
+ const DataLayout &DL);
bool finish();
template<typename T> bool splitUnary(Instruction &, const T &);
@@ -215,6 +223,8 @@ private:
ScatterMap Scattered;
GatherList Gathered;
+ SmallVector<WeakTrackingVH, 32> PotentiallyDeadInstrs;
+
unsigned ParallelLoopAccessMDKind;
DominatorTree *DT;
@@ -252,7 +262,7 @@ Scatterer::Scatterer(BasicBlock *bb, BasicBlock::iterator bbi, Value *v,
PtrTy = dyn_cast<PointerType>(Ty);
if (PtrTy)
Ty = PtrTy->getElementType();
- Size = Ty->getVectorNumElements();
+ Size = cast<FixedVectorType>(Ty)->getNumElements();
if (!CachePtr)
Tmp.resize(Size, nullptr);
else if (CachePtr->empty())
@@ -269,7 +279,7 @@ Value *Scatterer::operator[](unsigned I) {
return CV[I];
IRBuilder<> Builder(BB, BBI);
if (PtrTy) {
- Type *ElTy = PtrTy->getElementType()->getVectorElementType();
+ Type *ElTy = cast<VectorType>(PtrTy->getElementType())->getElementType();
if (!CV[0]) {
Type *NewPtrTy = PointerType::get(ElTy, PtrTy->getAddressSpace());
CV[0] = Builder.CreateBitCast(V, NewPtrTy, V->getName() + ".i0");
@@ -376,11 +386,6 @@ Scatterer ScalarizerVisitor::scatter(Instruction *Point, Value *V) {
// so that we can avoid creating the gathered form if all uses of Op are
// replaced with uses of CV.
void ScalarizerVisitor::gather(Instruction *Op, const ValueVector &CV) {
- // Since we're not deleting Op yet, stub out its operands, so that it
- // doesn't make anything live unnecessarily.
- for (unsigned I = 0, E = Op->getNumOperands(); I != E; ++I)
- Op->setOperand(I, UndefValue::get(Op->getOperand(I)->getType()));
-
transferMetadataAndIRFlags(Op, CV);
// If we already have a scattered form of Op (created from ExtractElements
@@ -389,13 +394,13 @@ void ScalarizerVisitor::gather(Instruction *Op, const ValueVector &CV) {
if (!SV.empty()) {
for (unsigned I = 0, E = SV.size(); I != E; ++I) {
Value *V = SV[I];
- if (V == nullptr)
+ if (V == nullptr || SV[I] == CV[I])
continue;
Instruction *Old = cast<Instruction>(V);
CV[I]->takeName(Old);
Old->replaceAllUsesWith(CV[I]);
- Old->eraseFromParent();
+ PotentiallyDeadInstrs.emplace_back(Old);
}
}
SV = CV;
@@ -434,25 +439,22 @@ void ScalarizerVisitor::transferMetadataAndIRFlags(Instruction *Op,
}
// Try to fill in Layout from Ty, returning true on success. Alignment is
-// the alignment of the vector, or 0 if the ABI default should be used.
-bool ScalarizerVisitor::getVectorLayout(Type *Ty, unsigned Alignment,
- VectorLayout &Layout, const DataLayout &DL) {
+// the alignment of the vector, or None if the ABI default should be used.
+Optional<VectorLayout>
+ScalarizerVisitor::getVectorLayout(Type *Ty, Align Alignment,
+ const DataLayout &DL) {
+ VectorLayout Layout;
// Make sure we're dealing with a vector.
Layout.VecTy = dyn_cast<VectorType>(Ty);
if (!Layout.VecTy)
- return false;
-
+ return None;
// Check that we're dealing with full-byte elements.
Layout.ElemTy = Layout.VecTy->getElementType();
if (!DL.typeSizeEqualsStoreSize(Layout.ElemTy))
- return false;
-
- if (Alignment)
- Layout.VecAlign = Alignment;
- else
- Layout.VecAlign = DL.getABITypeAlignment(Layout.VecTy);
+ return None;
+ Layout.VecAlign = Alignment;
Layout.ElemSize = DL.getTypeStoreSize(Layout.ElemTy);
- return true;
+ return Layout;
}
// Scalarize one-operand instruction I, using Split(Builder, X, Name)
@@ -463,7 +465,7 @@ bool ScalarizerVisitor::splitUnary(Instruction &I, const Splitter &Split) {
if (!VT)
return false;
- unsigned NumElems = VT->getNumElements();
+ unsigned NumElems = cast<FixedVectorType>(VT)->getNumElements();
IRBuilder<> Builder(&I);
Scatterer Op = scatter(&I, I.getOperand(0));
assert(Op.size() == NumElems && "Mismatched unary operation");
@@ -483,17 +485,19 @@ bool ScalarizerVisitor::splitBinary(Instruction &I, const Splitter &Split) {
if (!VT)
return false;
- unsigned NumElems = VT->getNumElements();
+ unsigned NumElems = cast<FixedVectorType>(VT)->getNumElements();
IRBuilder<> Builder(&I);
- Scatterer Op0 = scatter(&I, I.getOperand(0));
- Scatterer Op1 = scatter(&I, I.getOperand(1));
- assert(Op0.size() == NumElems && "Mismatched binary operation");
- assert(Op1.size() == NumElems && "Mismatched binary operation");
+ Scatterer VOp0 = scatter(&I, I.getOperand(0));
+ Scatterer VOp1 = scatter(&I, I.getOperand(1));
+ assert(VOp0.size() == NumElems && "Mismatched binary operation");
+ assert(VOp1.size() == NumElems && "Mismatched binary operation");
ValueVector Res;
Res.resize(NumElems);
- for (unsigned Elem = 0; Elem < NumElems; ++Elem)
- Res[Elem] = Split(Builder, Op0[Elem], Op1[Elem],
- I.getName() + ".i" + Twine(Elem));
+ for (unsigned Elem = 0; Elem < NumElems; ++Elem) {
+ Value *Op0 = VOp0[Elem];
+ Value *Op1 = VOp1[Elem];
+ Res[Elem] = Split(Builder, Op0, Op1, I.getName() + ".i" + Twine(Elem));
+ }
gather(&I, Res);
return true;
}
@@ -524,7 +528,7 @@ bool ScalarizerVisitor::splitCall(CallInst &CI) {
if (ID == Intrinsic::not_intrinsic || !isTriviallyScalariable(ID))
return false;
- unsigned NumElems = VT->getNumElements();
+ unsigned NumElems = cast<FixedVectorType>(VT)->getNumElements();
unsigned NumArgs = CI.getNumArgOperands();
ValueVector ScalarOperands(NumArgs);
@@ -574,26 +578,33 @@ bool ScalarizerVisitor::visitSelectInst(SelectInst &SI) {
if (!VT)
return false;
- unsigned NumElems = VT->getNumElements();
+ unsigned NumElems = cast<FixedVectorType>(VT)->getNumElements();
IRBuilder<> Builder(&SI);
- Scatterer Op1 = scatter(&SI, SI.getOperand(1));
- Scatterer Op2 = scatter(&SI, SI.getOperand(2));
- assert(Op1.size() == NumElems && "Mismatched select");
- assert(Op2.size() == NumElems && "Mismatched select");
+ Scatterer VOp1 = scatter(&SI, SI.getOperand(1));
+ Scatterer VOp2 = scatter(&SI, SI.getOperand(2));
+ assert(VOp1.size() == NumElems && "Mismatched select");
+ assert(VOp2.size() == NumElems && "Mismatched select");
ValueVector Res;
Res.resize(NumElems);
if (SI.getOperand(0)->getType()->isVectorTy()) {
- Scatterer Op0 = scatter(&SI, SI.getOperand(0));
- assert(Op0.size() == NumElems && "Mismatched select");
- for (unsigned I = 0; I < NumElems; ++I)
- Res[I] = Builder.CreateSelect(Op0[I], Op1[I], Op2[I],
+ Scatterer VOp0 = scatter(&SI, SI.getOperand(0));
+ assert(VOp0.size() == NumElems && "Mismatched select");
+ for (unsigned I = 0; I < NumElems; ++I) {
+ Value *Op0 = VOp0[I];
+ Value *Op1 = VOp1[I];
+ Value *Op2 = VOp2[I];
+ Res[I] = Builder.CreateSelect(Op0, Op1, Op2,
SI.getName() + ".i" + Twine(I));
+ }
} else {
Value *Op0 = SI.getOperand(0);
- for (unsigned I = 0; I < NumElems; ++I)
- Res[I] = Builder.CreateSelect(Op0, Op1[I], Op2[I],
+ for (unsigned I = 0; I < NumElems; ++I) {
+ Value *Op1 = VOp1[I];
+ Value *Op2 = VOp2[I];
+ Res[I] = Builder.CreateSelect(Op0, Op1, Op2,
SI.getName() + ".i" + Twine(I));
+ }
}
gather(&SI, Res);
return true;
@@ -621,7 +632,7 @@ bool ScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) {
return false;
IRBuilder<> Builder(&GEPI);
- unsigned NumElems = VT->getNumElements();
+ unsigned NumElems = cast<FixedVectorType>(VT)->getNumElements();
unsigned NumIndices = GEPI.getNumIndices();
// The base pointer might be scalar even if it's a vector GEP. In those cases,
@@ -666,7 +677,7 @@ bool ScalarizerVisitor::visitCastInst(CastInst &CI) {
if (!VT)
return false;
- unsigned NumElems = VT->getNumElements();
+ unsigned NumElems = cast<FixedVectorType>(VT)->getNumElements();
IRBuilder<> Builder(&CI);
Scatterer Op0 = scatter(&CI, CI.getOperand(0));
assert(Op0.size() == NumElems && "Mismatched cast");
@@ -685,8 +696,8 @@ bool ScalarizerVisitor::visitBitCastInst(BitCastInst &BCI) {
if (!DstVT || !SrcVT)
return false;
- unsigned DstNumElems = DstVT->getNumElements();
- unsigned SrcNumElems = SrcVT->getNumElements();
+ unsigned DstNumElems = cast<FixedVectorType>(DstVT)->getNumElements();
+ unsigned SrcNumElems = cast<FixedVectorType>(SrcVT)->getNumElements();
IRBuilder<> Builder(&BCI);
Scatterer Op0 = scatter(&BCI, BCI.getOperand(0));
ValueVector Res;
@@ -700,7 +711,7 @@ bool ScalarizerVisitor::visitBitCastInst(BitCastInst &BCI) {
// <M x t1> -> <N*M x t2>. Convert each t1 to <N x t2> and copy the
// individual elements to the destination.
unsigned FanOut = DstNumElems / SrcNumElems;
- Type *MidTy = VectorType::get(DstVT->getElementType(), FanOut);
+ auto *MidTy = FixedVectorType::get(DstVT->getElementType(), FanOut);
unsigned ResI = 0;
for (unsigned Op0I = 0; Op0I < SrcNumElems; ++Op0I) {
Value *V = Op0[Op0I];
@@ -718,7 +729,7 @@ bool ScalarizerVisitor::visitBitCastInst(BitCastInst &BCI) {
} else {
// <N*M x t1> -> <M x t2>. Convert each group of <N x t1> into a t2.
unsigned FanIn = SrcNumElems / DstNumElems;
- Type *MidTy = VectorType::get(SrcVT->getElementType(), FanIn);
+ auto *MidTy = FixedVectorType::get(SrcVT->getElementType(), FanIn);
unsigned Op0I = 0;
for (unsigned ResI = 0; ResI < DstNumElems; ++ResI) {
Value *V = UndefValue::get(MidTy);
@@ -734,12 +745,79 @@ bool ScalarizerVisitor::visitBitCastInst(BitCastInst &BCI) {
return true;
}
+bool ScalarizerVisitor::visitInsertElementInst(InsertElementInst &IEI) {
+ VectorType *VT = dyn_cast<VectorType>(IEI.getType());
+ if (!VT)
+ return false;
+
+ unsigned NumElems = cast<FixedVectorType>(VT)->getNumElements();
+ IRBuilder<> Builder(&IEI);
+ Scatterer Op0 = scatter(&IEI, IEI.getOperand(0));
+ Value *NewElt = IEI.getOperand(1);
+ Value *InsIdx = IEI.getOperand(2);
+
+ ValueVector Res;
+ Res.resize(NumElems);
+
+ if (auto *CI = dyn_cast<ConstantInt>(InsIdx)) {
+ for (unsigned I = 0; I < NumElems; ++I)
+ Res[I] = CI->getValue().getZExtValue() == I ? NewElt : Op0[I];
+ } else {
+ if (!ScalarizeVariableInsertExtract)
+ return false;
+
+ for (unsigned I = 0; I < NumElems; ++I) {
+ Value *ShouldReplace =
+ Builder.CreateICmpEQ(InsIdx, ConstantInt::get(InsIdx->getType(), I),
+ InsIdx->getName() + ".is." + Twine(I));
+ Value *OldElt = Op0[I];
+ Res[I] = Builder.CreateSelect(ShouldReplace, NewElt, OldElt,
+ IEI.getName() + ".i" + Twine(I));
+ }
+ }
+
+ gather(&IEI, Res);
+ return true;
+}
+
+bool ScalarizerVisitor::visitExtractElementInst(ExtractElementInst &EEI) {
+ VectorType *VT = dyn_cast<VectorType>(EEI.getOperand(0)->getType());
+ if (!VT)
+ return false;
+
+ unsigned NumSrcElems = cast<FixedVectorType>(VT)->getNumElements();
+ IRBuilder<> Builder(&EEI);
+ Scatterer Op0 = scatter(&EEI, EEI.getOperand(0));
+ Value *ExtIdx = EEI.getOperand(1);
+
+ if (auto *CI = dyn_cast<ConstantInt>(ExtIdx)) {
+ Value *Res = Op0[CI->getValue().getZExtValue()];
+ gather(&EEI, {Res});
+ return true;
+ }
+
+ if (!ScalarizeVariableInsertExtract)
+ return false;
+
+ Value *Res = UndefValue::get(VT->getElementType());
+ for (unsigned I = 0; I < NumSrcElems; ++I) {
+ Value *ShouldExtract =
+ Builder.CreateICmpEQ(ExtIdx, ConstantInt::get(ExtIdx->getType(), I),
+ ExtIdx->getName() + ".is." + Twine(I));
+ Value *Elt = Op0[I];
+ Res = Builder.CreateSelect(ShouldExtract, Elt, Res,
+ EEI.getName() + ".upto" + Twine(I));
+ }
+ gather(&EEI, {Res});
+ return true;
+}
+
bool ScalarizerVisitor::visitShuffleVectorInst(ShuffleVectorInst &SVI) {
VectorType *VT = dyn_cast<VectorType>(SVI.getType());
if (!VT)
return false;
- unsigned NumElems = VT->getNumElements();
+ unsigned NumElems = cast<FixedVectorType>(VT)->getNumElements();
Scatterer Op0 = scatter(&SVI, SVI.getOperand(0));
Scatterer Op1 = scatter(&SVI, SVI.getOperand(1));
ValueVector Res;
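
The visitInsertElementInst/visitExtractElementInst hunks above give the scalarizer a lowering for variable lane indices, gated by the new -scalarize-variable-insert-extract option: one icmp eq per lane feeding a chain of selects. In scalar terms the extract lowering computes the equivalent of the model below (a C++ analogy of the emitted IR, not part of the pass):

    #include <cstdint>

    // Scalar model of the select chain built for
    //   %r = extractelement <4 x float> %v, i32 %idx
    // Each iteration stands for one icmp eq + select pair; lane I wins when
    // I == Idx, and an out-of-range Idx leaves the initial (undef-like) value.
    float extractLaneModel(const float Lanes[4], uint32_t Idx) {
      float Res = 0.0f;   // stands in for the initial undef
      for (uint32_t I = 0; I < 4; ++I)
        Res = (Idx == I) ? Lanes[I] : Res;
      return Res;
    }
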
@@ -763,7 +841,7 @@ bool ScalarizerVisitor::visitPHINode(PHINode &PHI) {
if (!VT)
return false;
- unsigned NumElems = VT->getNumElements();
+ unsigned NumElems = cast<FixedVectorType>(VT)->getNumElements();
IRBuilder<> Builder(&PHI);
ValueVector Res;
Res.resize(NumElems);
@@ -789,20 +867,20 @@ bool ScalarizerVisitor::visitLoadInst(LoadInst &LI) {
if (!LI.isSimple())
return false;
- VectorLayout Layout;
- if (!getVectorLayout(LI.getType(), LI.getAlignment(), Layout,
- LI.getModule()->getDataLayout()))
+ Optional<VectorLayout> Layout = getVectorLayout(
+ LI.getType(), LI.getAlign(), LI.getModule()->getDataLayout());
+ if (!Layout)
return false;
- unsigned NumElems = Layout.VecTy->getNumElements();
+ unsigned NumElems = cast<FixedVectorType>(Layout->VecTy)->getNumElements();
IRBuilder<> Builder(&LI);
Scatterer Ptr = scatter(&LI, LI.getPointerOperand());
ValueVector Res;
Res.resize(NumElems);
for (unsigned I = 0; I < NumElems; ++I)
- Res[I] = Builder.CreateAlignedLoad(Layout.VecTy->getElementType(), Ptr[I],
- Layout.getElemAlign(I),
+ Res[I] = Builder.CreateAlignedLoad(Layout->VecTy->getElementType(), Ptr[I],
+ Align(Layout->getElemAlign(I)),
LI.getName() + ".i" + Twine(I));
gather(&LI, Res);
return true;
@@ -814,22 +892,23 @@ bool ScalarizerVisitor::visitStoreInst(StoreInst &SI) {
if (!SI.isSimple())
return false;
- VectorLayout Layout;
Value *FullValue = SI.getValueOperand();
- if (!getVectorLayout(FullValue->getType(), SI.getAlignment(), Layout,
- SI.getModule()->getDataLayout()))
+ Optional<VectorLayout> Layout = getVectorLayout(
+ FullValue->getType(), SI.getAlign(), SI.getModule()->getDataLayout());
+ if (!Layout)
return false;
- unsigned NumElems = Layout.VecTy->getNumElements();
+ unsigned NumElems = cast<FixedVectorType>(Layout->VecTy)->getNumElements();
IRBuilder<> Builder(&SI);
- Scatterer Ptr = scatter(&SI, SI.getPointerOperand());
- Scatterer Val = scatter(&SI, FullValue);
+ Scatterer VPtr = scatter(&SI, SI.getPointerOperand());
+ Scatterer VVal = scatter(&SI, FullValue);
ValueVector Stores;
Stores.resize(NumElems);
for (unsigned I = 0; I < NumElems; ++I) {
- unsigned Align = Layout.getElemAlign(I);
- Stores[I] = Builder.CreateAlignedStore(Val[I], Ptr[I], Align);
+ Value *Val = VVal[I];
+ Value *Ptr = VPtr[I];
+ Stores[I] = Builder.CreateAlignedStore(Val, Ptr, Layout->getElemAlign(I));
}
transferMetadataAndIRFlags(&SI, Stores);
return true;
@@ -852,23 +931,32 @@ bool ScalarizerVisitor::finish() {
if (!Op->use_empty()) {
// The value is still needed, so recreate it using a series of
// InsertElements.
- Type *Ty = Op->getType();
- Value *Res = UndefValue::get(Ty);
- BasicBlock *BB = Op->getParent();
- unsigned Count = Ty->getVectorNumElements();
- IRBuilder<> Builder(Op);
- if (isa<PHINode>(Op))
- Builder.SetInsertPoint(BB, BB->getFirstInsertionPt());
- for (unsigned I = 0; I < Count; ++I)
- Res = Builder.CreateInsertElement(Res, CV[I], Builder.getInt32(I),
- Op->getName() + ".upto" + Twine(I));
+ Value *Res = UndefValue::get(Op->getType());
+ if (auto *Ty = dyn_cast<VectorType>(Op->getType())) {
+ BasicBlock *BB = Op->getParent();
+ unsigned Count = cast<FixedVectorType>(Ty)->getNumElements();
+ IRBuilder<> Builder(Op);
+ if (isa<PHINode>(Op))
+ Builder.SetInsertPoint(BB, BB->getFirstInsertionPt());
+ for (unsigned I = 0; I < Count; ++I)
+ Res = Builder.CreateInsertElement(Res, CV[I], Builder.getInt32(I),
+ Op->getName() + ".upto" + Twine(I));
+ } else {
+ assert(CV.size() == 1 && Op->getType() == CV[0]->getType());
+ Res = CV[0];
+ if (Op == Res)
+ continue;
+ }
Res->takeName(Op);
Op->replaceAllUsesWith(Res);
}
- Op->eraseFromParent();
+ PotentiallyDeadInstrs.emplace_back(Op);
}
Gathered.clear();
Scattered.clear();
+
+ RecursivelyDeleteTriviallyDeadInstructionsPermissive(PotentiallyDeadInstrs);
+
return true;
}
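
The other behavioral change in Scalarizer is deferred deletion: replaced instructions are no longer erased immediately but parked in a SmallVector of WeakTrackingVH and cleaned up once at the end with RecursivelyDeleteTriviallyDeadInstructionsPermissive, which tolerates entries that have already been deleted or RAUW'd in the meantime. A hedged sketch of that idiom:

    #include "llvm/ADT/ArrayRef.h"
    #include "llvm/ADT/SmallVector.h"
    #include "llvm/IR/Instruction.h"
    #include "llvm/IR/ValueHandle.h"
    #include "llvm/Transforms/Utils/Local.h"
    using namespace llvm;

    // Batch up instructions that *might* be dead after a round of rewrites and
    // delete them in one pass; the weak handles go null (or follow RAUW) if
    // something else removes them first.
    void deferredCleanup(ArrayRef<Instruction *> MaybeDead) {
      SmallVector<WeakTrackingVH, 32> Pending(MaybeDead.begin(), MaybeDead.end());
      RecursivelyDeleteTriviallyDeadInstructionsPermissive(Pending);
    }
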
diff --git a/llvm/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp b/llvm/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp
index 2a1a040bf83e..f1d2e3c1ecfa 100644
--- a/llvm/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp
+++ b/llvm/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp
@@ -431,8 +431,10 @@ private:
bool reuniteExts(Instruction *I);
/// Find the closest dominator of <Dominatee> that is equivalent to <Key>.
- Instruction *findClosestMatchingDominator(const SCEV *Key,
- Instruction *Dominatee);
+ Instruction *findClosestMatchingDominator(
+ const SCEV *Key, Instruction *Dominatee,
+ DenseMap<const SCEV *, SmallVector<Instruction *, 2>> &DominatingExprs);
+
/// Verify F is free of dead code.
void verifyNoDeadCode(Function &F);
@@ -456,7 +458,8 @@ private:
/// multiple GEPs with a single index.
bool LowerGEP;
- DenseMap<const SCEV *, SmallVector<Instruction *, 2>> DominatingExprs;
+ DenseMap<const SCEV *, SmallVector<Instruction *, 2>> DominatingAdds;
+ DenseMap<const SCEV *, SmallVector<Instruction *, 2>> DominatingSubs;
};
} // end anonymous namespace
@@ -519,7 +522,7 @@ bool ConstantOffsetExtractor::CanTraceInto(bool SignExtended,
// sext(a + b) = sext(a) + sext(b)
// even if the addition is not marked nsw.
//
- // Leveraging this invarient, we can trace into an sext'ed inbound GEP
+ // Leveraging this invariant, we can trace into an sext'ed inbound GEP
// index if the constant offset is non-negative.
//
// Verified in @sext_add in split-gep.ll.
@@ -549,6 +552,9 @@ bool ConstantOffsetExtractor::CanTraceInto(bool SignExtended,
APInt ConstantOffsetExtractor::findInEitherOperand(BinaryOperator *BO,
bool SignExtended,
bool ZeroExtended) {
+ // Save off the current height of the chain, in case we need to restore it.
+ size_t ChainLength = UserChain.size();
+
// BO being non-negative does not shed light on whether its operands are
// non-negative. Clear the NonNegative flag here.
APInt ConstantOffset = find(BO->getOperand(0), SignExtended, ZeroExtended,
@@ -559,12 +565,22 @@ APInt ConstantOffsetExtractor::findInEitherOperand(BinaryOperator *BO,
// However, such cases are probably already handled by -instcombine,
// given this pass runs after the standard optimizations.
if (ConstantOffset != 0) return ConstantOffset;
+
+ // Reset the chain back to where it was when we started exploring this node,
+ // since visiting the LHS didn't pan out.
+ UserChain.resize(ChainLength);
+
ConstantOffset = find(BO->getOperand(1), SignExtended, ZeroExtended,
/* NonNegative */ false);
// If U is a sub operator, negate the constant offset found in the right
// operand.
if (BO->getOpcode() == Instruction::Sub)
ConstantOffset = -ConstantOffset;
+
+ // If RHS wasn't a suitable candidate either, reset the chain again.
+ if (ConstantOffset == 0)
+ UserChain.resize(ChainLength);
+
return ConstantOffset;
}
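
The ChainLength save/restore above matters because find() pushes every visited node onto UserChain as it recurses; when exploring one operand yields no constant offset, the entries pushed during that exploration have to be popped before trying the other operand (and again if both fail), so that UserChain only records the path that actually produced the constant. The generic shape of the idiom, as a small sketch:

    #include <vector>

    // Remember the depth, explore, and roll back when the exploration
    // does not pan out.
    template <typename T, typename Fn>
    bool exploreAndRollback(std::vector<T> &Chain, Fn &&Explore) {
      size_t Depth = Chain.size();
      if (Explore(Chain))
        return true;          // keep whatever Explore pushed
      Chain.resize(Depth);    // discard the speculative entries
      return false;
    }
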
@@ -688,7 +704,7 @@ Value *ConstantOffsetExtractor::removeConstOffset(unsigned ChainIndex) {
}
BinaryOperator *BO = cast<BinaryOperator>(UserChain[ChainIndex]);
- assert(BO->getNumUses() <= 1 &&
+ assert((BO->use_empty() || BO->hasOneUse()) &&
"distributeExtsAndCloneChain clones each BinaryOperator in "
"UserChain, so no one should be used more than "
"once");
@@ -1141,7 +1157,8 @@ bool SeparateConstOffsetFromGEP::runOnFunction(Function &F) {
}
Instruction *SeparateConstOffsetFromGEP::findClosestMatchingDominator(
- const SCEV *Key, Instruction *Dominatee) {
+ const SCEV *Key, Instruction *Dominatee,
+ DenseMap<const SCEV *, SmallVector<Instruction *, 2>> &DominatingExprs) {
auto Pos = DominatingExprs.find(Key);
if (Pos == DominatingExprs.end())
return nullptr;
@@ -1169,12 +1186,23 @@ bool SeparateConstOffsetFromGEP::reuniteExts(Instruction *I) {
// If Dom can't sign overflow and Dom dominates I, optimize I to sext(Dom).
// TODO: handle zext
Value *LHS = nullptr, *RHS = nullptr;
- if (match(I, m_Add(m_SExt(m_Value(LHS)), m_SExt(m_Value(RHS)))) ||
- match(I, m_Sub(m_SExt(m_Value(LHS)), m_SExt(m_Value(RHS))))) {
+ if (match(I, m_Add(m_SExt(m_Value(LHS)), m_SExt(m_Value(RHS))))) {
if (LHS->getType() == RHS->getType()) {
const SCEV *Key =
SE->getAddExpr(SE->getUnknown(LHS), SE->getUnknown(RHS));
- if (auto *Dom = findClosestMatchingDominator(Key, I)) {
+ if (auto *Dom = findClosestMatchingDominator(Key, I, DominatingAdds)) {
+ Instruction *NewSExt = new SExtInst(Dom, I->getType(), "", I);
+ NewSExt->takeName(I);
+ I->replaceAllUsesWith(NewSExt);
+ RecursivelyDeleteTriviallyDeadInstructions(I);
+ return true;
+ }
+ }
+ } else if (match(I, m_Sub(m_SExt(m_Value(LHS)), m_SExt(m_Value(RHS))))) {
+ if (LHS->getType() == RHS->getType()) {
+ const SCEV *Key =
+ SE->getAddExpr(SE->getUnknown(LHS), SE->getUnknown(RHS));
+ if (auto *Dom = findClosestMatchingDominator(Key, I, DominatingSubs)) {
Instruction *NewSExt = new SExtInst(Dom, I->getType(), "", I);
NewSExt->takeName(I);
I->replaceAllUsesWith(NewSExt);
@@ -1185,12 +1213,17 @@ bool SeparateConstOffsetFromGEP::reuniteExts(Instruction *I) {
}
// Add I to DominatingExprs if it's an add/sub that can't sign overflow.
- if (match(I, m_NSWAdd(m_Value(LHS), m_Value(RHS))) ||
- match(I, m_NSWSub(m_Value(LHS), m_Value(RHS)))) {
- if (programUndefinedIfFullPoison(I)) {
+ if (match(I, m_NSWAdd(m_Value(LHS), m_Value(RHS)))) {
+ if (programUndefinedIfPoison(I)) {
+ const SCEV *Key =
+ SE->getAddExpr(SE->getUnknown(LHS), SE->getUnknown(RHS));
+ DominatingAdds[Key].push_back(I);
+ }
+ } else if (match(I, m_NSWSub(m_Value(LHS), m_Value(RHS)))) {
+ if (programUndefinedIfPoison(I)) {
const SCEV *Key =
SE->getAddExpr(SE->getUnknown(LHS), SE->getUnknown(RHS));
- DominatingExprs[Key].push_back(I);
+ DominatingSubs[Key].push_back(I);
}
}
return false;
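
As a hedged, source-level picture of the property reuniteExts relies on (this example is not taken from the patch): when a dominating narrow add is known to be free of signed overflow, the widened sum of the sign-extended operands can be folded into one sign extension of that add.

  #include <cstdint>

  // Before: two sign extensions feed a 64-bit add.
  int64_t widenedSumBefore(int32_t a, int32_t b) {
    return static_cast<int64_t>(a) + static_cast<int64_t>(b);
  }

  // After: assuming the narrow add `dom = a + b` cannot overflow (the nsw plus
  // "poison implies undefined behavior" condition the pass checks), the result
  // above equals sext(dom), so a single extension of the dominating add is
  // enough.
  int64_t widenedSumAfter(int32_t a, int32_t b) {
    int32_t dom = a + b;              // dominating add, assumed not to overflow
    return static_cast<int64_t>(dom); // single sign extension
  }
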
@@ -1198,7 +1231,8 @@ bool SeparateConstOffsetFromGEP::reuniteExts(Instruction *I) {
bool SeparateConstOffsetFromGEP::reuniteExts(Function &F) {
bool Changed = false;
- DominatingExprs.clear();
+ DominatingAdds.clear();
+ DominatingSubs.clear();
for (const auto Node : depth_first(DT)) {
BasicBlock *BB = Node->getBlock();
for (auto I = BB->begin(); I != BB->end(); ) {
diff --git a/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp b/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp
index d7a34acb4318..6c6d6ca9cf65 100644
--- a/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp
+++ b/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp
@@ -26,7 +26,6 @@
#include "llvm/Analysis/LoopPass.h"
#include "llvm/Analysis/MemorySSA.h"
#include "llvm/Analysis/MemorySSAUpdater.h"
-#include "llvm/Analysis/Utils/Local.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/Constants.h"
@@ -36,6 +35,7 @@
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Use.h"
#include "llvm/IR/Value.h"
#include "llvm/InitializePasses.h"
@@ -182,7 +182,7 @@ static void buildPartialUnswitchConditionalBranch(BasicBlock &BB,
BasicBlock &UnswitchedSucc,
BasicBlock &NormalSucc) {
IRBuilder<> IRB(&BB);
-
+
Value *Cond = Direction ? IRB.CreateOr(Invariants) :
IRB.CreateAnd(Invariants);
IRB.CreateCondBr(Cond, Direction ? &UnswitchedSucc : &NormalSucc,
@@ -598,19 +598,36 @@ static bool unswitchTrivialSwitch(Loop &L, SwitchInst &SI, DominatorTree &DT,
auto *ParentBB = SI.getParent();
+ // The same check must be used both for the default and the exit cases. We
+ // should never leave edges from the switch instruction to a basic block that
+ // we are unswitching, hence the condition used to determine the default case
+ // needs to also be used to populate ExitCaseIndices, which is then used to
+ // remove cases from the switch.
+ auto IsTriviallyUnswitchableExitBlock = [&](BasicBlock &BBToCheck) {
+ // BBToCheck is not an exit block if it is inside loop L.
+ if (L.contains(&BBToCheck))
+ return false;
+ // BBToCheck is not trivial to unswitch if its phis aren't loop invariant.
+ if (!areLoopExitPHIsLoopInvariant(L, *ParentBB, BBToCheck))
+ return false;
+ // We do not unswitch a block that only has an unreachable statement, as
+ // it's possible this is a previously unswitched block. Only unswitch if
+ // either the terminator is not unreachable, or, if it is, it's not the only
+ // instruction in the block.
+ auto *TI = BBToCheck.getTerminator();
+ bool isUnreachable = isa<UnreachableInst>(TI);
+ return !isUnreachable ||
+ (isUnreachable && (BBToCheck.getFirstNonPHIOrDbg() != TI));
+ };
+
SmallVector<int, 4> ExitCaseIndices;
- for (auto Case : SI.cases()) {
- auto *SuccBB = Case.getCaseSuccessor();
- if (!L.contains(SuccBB) &&
- areLoopExitPHIsLoopInvariant(L, *ParentBB, *SuccBB))
+ for (auto Case : SI.cases())
+ if (IsTriviallyUnswitchableExitBlock(*Case.getCaseSuccessor()))
ExitCaseIndices.push_back(Case.getCaseIndex());
- }
BasicBlock *DefaultExitBB = nullptr;
SwitchInstProfUpdateWrapper::CaseWeightOpt DefaultCaseWeight =
SwitchInstProfUpdateWrapper::getSuccessorWeight(SI, 0);
- if (!L.contains(SI.getDefaultDest()) &&
- areLoopExitPHIsLoopInvariant(L, *ParentBB, *SI.getDefaultDest()) &&
- !isa<UnreachableInst>(SI.getDefaultDest()->getTerminator())) {
+ if (IsTriviallyUnswitchableExitBlock(*SI.getDefaultDest())) {
DefaultExitBB = SI.getDefaultDest();
} else if (ExitCaseIndices.empty())
return false;
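
Not from the patch, but as a concrete picture of the exit case the new IsTriviallyUnswitchableExitBlock lambda accepts (a loop-invariant switch condition, a case that leaves the loop, and an exit block that is not merely a leftover unreachable), here is a hypothetical source-level before/after of trivially unswitching a switch.

  // Before: every iteration re-tests the loop-invariant `mode`, and `case 0`
  // jumps straight out of the loop.
  int before(int mode, int n) {
    int sum = 0;
    for (int i = 0; i < n; ++i) {
      switch (mode) {
      case 0:
        return -1;       // exit case: destination is outside the loop
      default:
        sum += i;
      }
    }
    return sum;
  }

  // After: the exit case is hoisted in front of the loop (guarded so it only
  // fires when the loop would have executed the switch), and the loop body no
  // longer tests `mode` at all.
  int after(int mode, int n) {
    if (n > 0 && mode == 0)
      return -1;
    int sum = 0;
    for (int i = 0; i < n; ++i)
      sum += i;
    return sum;
  }
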
@@ -1557,6 +1574,11 @@ static void deleteDeadBlocksFromLoop(Loop &L,
// Check that the dominator tree has already been updated.
assert(!DT.getNode(BB) && "Should already have cleared domtree!");
LI.changeLoopFor(BB, nullptr);
+ // Drop all uses of the instructions to make sure we won't have dangling
+ // uses in other blocks.
+ for (auto &I : *BB)
+ if (!I.use_empty())
+ I.replaceAllUsesWith(UndefValue::get(I.getType()));
BB->dropAllReferences();
}
@@ -2465,7 +2487,7 @@ turnGuardIntoBranch(IntrinsicInst *GI, Loop &L,
/// unswitch candidates, making adequate predictions instead of wild guesses.
/// That requires knowing not just the number of "remaining" candidates but
/// also costs of unswitching for each of these candidates.
-static int calculateUnswitchCostMultiplier(
+static int CalculateUnswitchCostMultiplier(
Instruction &TI, Loop &L, LoopInfo &LI, DominatorTree &DT,
ArrayRef<std::pair<Instruction *, TinyPtrVector<Value *>>>
UnswitchCandidates) {
@@ -2656,11 +2678,11 @@ unswitchBestCondition(Loop &L, DominatorTree &DT, LoopInfo &LI,
if (I.getType()->isTokenTy() && I.isUsedOutsideOfBlock(BB))
return false;
- if (auto CS = CallSite(&I))
- if (CS.isConvergent() || CS.cannotDuplicate())
+ if (auto *CB = dyn_cast<CallBase>(&I))
+ if (CB->isConvergent() || CB->cannotDuplicate())
return false;
- Cost += TTI.getUserCost(&I);
+ Cost += TTI.getUserCost(&I, TargetTransformInfo::TCK_CodeSize);
}
assert(Cost >= 0 && "Must not have negative costs!");
LoopCost += Cost;
@@ -2754,7 +2776,7 @@ unswitchBestCondition(Loop &L, DominatorTree &DT, LoopInfo &LI,
// exponential behavior of loop-unswitch.
if (EnableUnswitchCostMultiplier) {
int CostMultiplier =
- calculateUnswitchCostMultiplier(TI, L, LI, DT, UnswitchCandidates);
+ CalculateUnswitchCostMultiplier(TI, L, LI, DT, UnswitchCandidates);
assert(
(CostMultiplier > 0 && CostMultiplier <= UnswitchThreshold) &&
"cost multiplier needs to be in the range of 1..UnswitchThreshold");
@@ -2868,7 +2890,7 @@ PreservedAnalyses SimpleLoopUnswitchPass::run(Loop &L, LoopAnalysisManager &AM,
// Save the current loop name in a variable so that we can report it even
// after it has been deleted.
- std::string LoopName = L.getName();
+ std::string LoopName = std::string(L.getName());
auto UnswitchCB = [&L, &U, &LoopName](bool CurrentLoopValid,
ArrayRef<Loop *> NewLoops) {
@@ -2983,10 +3005,6 @@ bool SimpleLoopUnswitchLegacyPass::runOnLoop(Loop *L, LPPassManager &LPM) {
if (MSSA && VerifyMemorySSA)
MSSA->verifyMemorySSA();
- // If anything was unswitched, also clear any cached information about this
- // loop.
- LPM.deleteSimpleAnalysisLoop(L);
-
// Historically this pass has had issues with the dominator tree so verify it
// in asserts builds.
assert(DT.verify(DominatorTree::VerificationLevel::Fast));
diff --git a/llvm/lib/Transforms/Scalar/SimplifyCFGPass.cpp b/llvm/lib/Transforms/Scalar/SimplifyCFGPass.cpp
index 623a8b711ed8..2e459c9a64d4 100644
--- a/llvm/lib/Transforms/Scalar/SimplifyCFGPass.cpp
+++ b/llvm/lib/Transforms/Scalar/SimplifyCFGPass.cpp
@@ -104,6 +104,21 @@ static bool mergeEmptyReturnBlocks(Function &F) {
continue;
}
+ // Skip merging if this would result in a CallBr instruction with a
+ // duplicate destination. FIXME: See note in CodeGenPrepare.cpp.
+ bool SkipCallBr = false;
+ for (pred_iterator PI = pred_begin(&BB), E = pred_end(&BB);
+ PI != E && !SkipCallBr; ++PI) {
+ if (auto *CBI = dyn_cast<CallBrInst>((*PI)->getTerminator()))
+ for (unsigned i = 0, e = CBI->getNumSuccessors(); i != e; ++i)
+ if (RetBlock == CBI->getSuccessor(i)) {
+ SkipCallBr = true;
+ break;
+ }
+ }
+ if (SkipCallBr)
+ continue;
+
// Otherwise, we found a duplicate return block. Merge the two.
Changed = true;
@@ -266,6 +281,14 @@ struct CFGSimplifyPass : public FunctionPass {
return false;
Options.AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
+ if (F.hasFnAttribute(Attribute::OptForFuzzing)) {
+ Options.setSimplifyCondBranch(false)
+ .setFoldTwoEntryPHINode(false);
+ } else {
+ Options.setSimplifyCondBranch(true)
+ .setFoldTwoEntryPHINode(true);
+ }
+
auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
return simplifyFunctionCFG(F, TTI, Options);
}
diff --git a/llvm/lib/Transforms/Scalar/Sink.cpp b/llvm/lib/Transforms/Scalar/Sink.cpp
index 677d86f8c7b4..48f289c8f17d 100644
--- a/llvm/lib/Transforms/Scalar/Sink.cpp
+++ b/llvm/lib/Transforms/Scalar/Sink.cpp
@@ -166,8 +166,8 @@ static bool SinkInstruction(Instruction *Inst,
// dominated by one of the successors.
// Look at all the dominated blocks and see if we can sink it in one.
DomTreeNode *DTN = DT.getNode(Inst->getParent());
- for (DomTreeNode::iterator I = DTN->begin(), E = DTN->end();
- I != E && SuccToSinkTo == nullptr; ++I) {
+ for (auto I = DTN->begin(), E = DTN->end(); I != E && SuccToSinkTo == nullptr;
+ ++I) {
BasicBlock *Candidate = (*I)->getBlock();
// A node always immediate-dominates its children on the dominator
// tree.
diff --git a/llvm/lib/Transforms/Scalar/SpeculateAroundPHIs.cpp b/llvm/lib/Transforms/Scalar/SpeculateAroundPHIs.cpp
index cd7bfb2f20dc..8258b92a716d 100644
--- a/llvm/lib/Transforms/Scalar/SpeculateAroundPHIs.cpp
+++ b/llvm/lib/Transforms/Scalar/SpeculateAroundPHIs.cpp
@@ -67,8 +67,8 @@ isSafeToSpeculatePHIUsers(PHINode &PN, DominatorTree &DT,
return false;
}
- if (auto CS = ImmutableCallSite(UI)) {
- if (CS.isConvergent() || CS.cannotDuplicate()) {
+ if (const auto *CS = dyn_cast<CallBase>(UI)) {
+ if (CS->isConvergent() || CS->cannotDuplicate()) {
LLVM_DEBUG(dbgs() << " Unsafe: convergent "
"callsite cannot de duplicated: " << *UI << '\n');
return false;
@@ -232,7 +232,8 @@ static bool isSafeAndProfitableToSpeculateAroundPHI(
continue;
int &MatCost = InsertResult.first->second.MatCost;
- MatCost = TTI.getIntImmCost(IncomingC->getValue(), IncomingC->getType());
+ MatCost = TTI.getIntImmCost(IncomingC->getValue(), IncomingC->getType(),
+ TargetTransformInfo::TCK_SizeAndLatency);
NonFreeMat |= MatCost != TTI.TCC_Free;
}
if (!NonFreeMat) {
@@ -283,12 +284,15 @@ static bool isSafeAndProfitableToSpeculateAroundPHI(
int MatCost = IncomingConstantAndCostsAndCount.second.MatCost;
int &FoldedCost = IncomingConstantAndCostsAndCount.second.FoldedCost;
if (IID)
- FoldedCost += TTI.getIntImmCostIntrin(IID, Idx, IncomingC->getValue(),
- IncomingC->getType());
+ FoldedCost +=
+ TTI.getIntImmCostIntrin(IID, Idx, IncomingC->getValue(),
+ IncomingC->getType(),
+ TargetTransformInfo::TCK_SizeAndLatency);
else
FoldedCost +=
TTI.getIntImmCostInst(UserI->getOpcode(), Idx,
- IncomingC->getValue(), IncomingC->getType());
+ IncomingC->getValue(), IncomingC->getType(),
+ TargetTransformInfo::TCK_SizeAndLatency);
// If we accumulate more folded cost for this incoming constant than
// materialized cost, then we'll regress any edge with this constant so
@@ -465,7 +469,7 @@ findProfitablePHIs(ArrayRef<PHINode *> PNs,
if (CostMapIt != SpecCostMap.end())
Cost += CostMapIt->second;
}
- Cost += TTI.getUserCost(I);
+ Cost += TTI.getUserCost(I, TargetTransformInfo::TCK_SizeAndLatency);
bool Inserted = SpecCostMap.insert({I, Cost}).second;
(void)Inserted;
assert(Inserted && "Must not re-insert a cost during the DFS!");
diff --git a/llvm/lib/Transforms/Scalar/SpeculativeExecution.cpp b/llvm/lib/Transforms/Scalar/SpeculativeExecution.cpp
index c8d899bb4871..f82a2936c762 100644
--- a/llvm/lib/Transforms/Scalar/SpeculativeExecution.cpp
+++ b/llvm/lib/Transforms/Scalar/SpeculativeExecution.cpp
@@ -65,6 +65,7 @@
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Operator.h"
#include "llvm/InitializePasses.h"
@@ -244,19 +245,35 @@ static unsigned ComputeSpeculationCost(const Instruction *I,
case Instruction::FNeg:
case Instruction::ICmp:
case Instruction::FCmp:
- return TTI.getUserCost(I);
+ return TTI.getUserCost(I, TargetTransformInfo::TCK_SizeAndLatency);
default:
- return UINT_MAX; // Disallow anything not whitelisted.
+ return UINT_MAX; // Disallow anything not explicitly listed.
}
}
bool SpeculativeExecutionPass::considerHoistingFromTo(
BasicBlock &FromBlock, BasicBlock &ToBlock) {
SmallPtrSet<const Instruction *, 8> NotHoisted;
- const auto AllPrecedingUsesFromBlockHoisted = [&NotHoisted](User *U) {
- for (Value* V : U->operand_values()) {
- if (Instruction *I = dyn_cast<Instruction>(V)) {
+ const auto AllPrecedingUsesFromBlockHoisted = [&NotHoisted](const User *U) {
+    // Debug variable intrinsics hold their value in a special operand; check it.
+ if (const auto *DVI = dyn_cast<DbgVariableIntrinsic>(U)) {
+ if (const auto *I =
+ dyn_cast_or_null<Instruction>(DVI->getVariableLocation()))
+ if (NotHoisted.count(I) == 0)
+ return true;
+ return false;
+ }
+
+  // Usually a debug label intrinsic corresponds to a label in the LLVM IR. In
+  // those cases we should not move it here.
+  // TODO: We may need special handling to detect whether the label is related
+  // to a hoisted instruction.
+ if (isa<DbgLabelInst>(U))
+ return false;
+
+ for (const Value *V : U->operand_values()) {
+ if (const Instruction *I = dyn_cast<Instruction>(V)) {
if (NotHoisted.count(I) > 0)
return false;
}
@@ -265,7 +282,8 @@ bool SpeculativeExecutionPass::considerHoistingFromTo(
};
unsigned TotalSpeculationCost = 0;
- for (auto& I : FromBlock) {
+ unsigned NotHoistedInstCount = 0;
+ for (const auto &I : FromBlock) {
const unsigned Cost = ComputeSpeculationCost(&I, *TTI);
if (Cost != UINT_MAX && isSafeToSpeculativelyExecute(&I) &&
AllPrecedingUsesFromBlockHoisted(&I)) {
@@ -273,15 +291,15 @@ bool SpeculativeExecutionPass::considerHoistingFromTo(
if (TotalSpeculationCost > SpecExecMaxSpeculationCost)
return false; // too much to hoist
} else {
- NotHoisted.insert(&I);
- if (NotHoisted.size() > SpecExecMaxNotHoisted)
+      // Debug info intrinsics should not be counted toward the threshold.
+ if (!isa<DbgInfoIntrinsic>(I))
+ NotHoistedInstCount++;
+ if (NotHoistedInstCount > SpecExecMaxNotHoisted)
return false; // too much left behind
+ NotHoisted.insert(&I);
}
}
- if (TotalSpeculationCost == 0)
- return false; // nothing to hoist
-
for (auto I = FromBlock.begin(); I != FromBlock.end();) {
// We have to increment I before moving Current as moving Current
// changes the list that I is iterating through.
diff --git a/llvm/lib/Transforms/Scalar/StructurizeCFG.cpp b/llvm/lib/Transforms/Scalar/StructurizeCFG.cpp
index 4ce4ce46f67a..c20e57b02c1a 100644
--- a/llvm/lib/Transforms/Scalar/StructurizeCFG.cpp
+++ b/llvm/lib/Transforms/Scalar/StructurizeCFG.cpp
@@ -8,13 +8,12 @@
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/MapVector.h"
-#include "llvm/ADT/PostOrderIterator.h"
+#include "llvm/ADT/SCCIterator.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/InstructionSimplify.h"
#include "llvm/Analysis/LegacyDivergenceAnalysis.h"
-#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/RegionInfo.h"
#include "llvm/Analysis/RegionIterator.h"
#include "llvm/Analysis/RegionPass.h"
@@ -34,6 +33,7 @@
#include "llvm/IR/Use.h"
#include "llvm/IR/User.h"
#include "llvm/IR/Value.h"
+#include "llvm/IR/ValueHandle.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
@@ -43,6 +43,7 @@
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils.h"
+#include "llvm/Transforms/Utils/Local.h"
#include "llvm/Transforms/Utils/SSAUpdater.h"
#include <algorithm>
#include <cassert>
@@ -88,6 +89,59 @@ using BBPredicates = DenseMap<BasicBlock *, Value *>;
using PredMap = DenseMap<BasicBlock *, BBPredicates>;
using BB2BBMap = DenseMap<BasicBlock *, BasicBlock *>;
+// A traits type intended for use in graph algorithms. The traversal starts at
+// an entry node and visits only the RegionNodes that are contained in the
+// Nodes set.
+struct SubGraphTraits {
+ using NodeRef = std::pair<RegionNode *, SmallDenseSet<RegionNode *> *>;
+ using BaseSuccIterator = GraphTraits<RegionNode *>::ChildIteratorType;
+
+ // This wraps a set of Nodes into the iterator, so we know which edges to
+ // filter out.
+ class WrappedSuccIterator
+ : public iterator_adaptor_base<
+ WrappedSuccIterator, BaseSuccIterator,
+ typename std::iterator_traits<BaseSuccIterator>::iterator_category,
+ NodeRef, std::ptrdiff_t, NodeRef *, NodeRef> {
+ SmallDenseSet<RegionNode *> *Nodes;
+
+ public:
+ WrappedSuccIterator(BaseSuccIterator It, SmallDenseSet<RegionNode *> *Nodes)
+ : iterator_adaptor_base(It), Nodes(Nodes) {}
+
+ NodeRef operator*() const { return {*I, Nodes}; }
+ };
+
+ static bool filterAll(const NodeRef &N) { return true; }
+ static bool filterSet(const NodeRef &N) { return N.second->count(N.first); }
+
+ using ChildIteratorType =
+ filter_iterator<WrappedSuccIterator, bool (*)(const NodeRef &)>;
+
+ static NodeRef getEntryNode(Region *R) {
+ return {GraphTraits<Region *>::getEntryNode(R), nullptr};
+ }
+
+ static NodeRef getEntryNode(NodeRef N) { return N; }
+
+ static iterator_range<ChildIteratorType> children(const NodeRef &N) {
+ auto *filter = N.second ? &filterSet : &filterAll;
+ return make_filter_range(
+ make_range<WrappedSuccIterator>(
+ {GraphTraits<RegionNode *>::child_begin(N.first), N.second},
+ {GraphTraits<RegionNode *>::child_end(N.first), N.second}),
+ filter);
+ }
+
+ static ChildIteratorType child_begin(const NodeRef &N) {
+ return children(N).begin();
+ }
+
+ static ChildIteratorType child_end(const NodeRef &N) {
+ return children(N).end();
+ }
+};
+
/// Finds the nearest common dominator of a set of BasicBlocks.
///
/// For every BB you add to the set, you can specify whether we "remember" the
@@ -192,11 +246,11 @@ class StructurizeCFG : public RegionPass {
LegacyDivergenceAnalysis *DA;
DominatorTree *DT;
- LoopInfo *LI;
SmallVector<RegionNode *, 8> Order;
BBSet Visited;
+ SmallVector<WeakVH, 8> AffectedPhis;
BBPhiMap DeletedPhis;
BB2BBVecMap AddedPhis;
@@ -211,13 +265,8 @@ class StructurizeCFG : public RegionPass {
void orderNodes();
- Loop *getAdjustedLoop(RegionNode *RN);
- unsigned getAdjustedLoopDepth(RegionNode *RN);
-
void analyzeLoops(RegionNode *N);
- Value *invert(Value *Condition);
-
Value *buildCondition(BranchInst *Term, unsigned Idx, bool Invert);
void gatherPredicates(RegionNode *N);
@@ -232,6 +281,8 @@ class StructurizeCFG : public RegionPass {
void setPhiValues();
+ void simplifyAffectedPhis();
+
void killTerminator(BasicBlock *BB);
void changeExit(RegionNode *Node, BasicBlock *NewExit,
@@ -279,7 +330,6 @@ public:
AU.addRequired<LegacyDivergenceAnalysis>();
AU.addRequiredID(LowerSwitchID);
AU.addRequired<DominatorTreeWrapperPass>();
- AU.addRequired<LoopInfoWrapperPass>();
AU.addPreserved<DominatorTreeWrapperPass>();
RegionPass::getAnalysisUsage(AU);
@@ -311,75 +361,60 @@ bool StructurizeCFG::doInitialization(Region *R, RGPassManager &RGM) {
return false;
}
-/// Use the exit block to determine the loop if RN is a SubRegion.
-Loop *StructurizeCFG::getAdjustedLoop(RegionNode *RN) {
- if (RN->isSubRegion()) {
- Region *SubRegion = RN->getNodeAs<Region>();
- return LI->getLoopFor(SubRegion->getExit());
- }
-
- return LI->getLoopFor(RN->getEntry());
-}
-
-/// Use the exit block to determine the loop depth if RN is a SubRegion.
-unsigned StructurizeCFG::getAdjustedLoopDepth(RegionNode *RN) {
- if (RN->isSubRegion()) {
- Region *SubR = RN->getNodeAs<Region>();
- return LI->getLoopDepth(SubR->getExit());
- }
-
- return LI->getLoopDepth(RN->getEntry());
-}
-
-/// Build up the general order of nodes
+/// Build up the general order of nodes by performing a topological sort of the
+/// parent region's nodes, while ensuring that no node of an outer cycle appears
+/// between any two nodes of an inner cycle.
void StructurizeCFG::orderNodes() {
- ReversePostOrderTraversal<Region*> RPOT(ParentRegion);
- SmallDenseMap<Loop*, unsigned, 8> LoopBlocks;
-
- // The reverse post-order traversal of the list gives us an ordering close
- // to what we want. The only problem with it is that sometimes backedges
- // for outer loops will be visited before backedges for inner loops.
- for (RegionNode *RN : RPOT) {
- Loop *Loop = getAdjustedLoop(RN);
- ++LoopBlocks[Loop];
- }
-
- unsigned CurrentLoopDepth = 0;
- Loop *CurrentLoop = nullptr;
- for (auto I = RPOT.begin(), E = RPOT.end(); I != E; ++I) {
- RegionNode *RN = cast<RegionNode>(*I);
- unsigned LoopDepth = getAdjustedLoopDepth(RN);
-
- if (is_contained(Order, *I))
- continue;
-
- if (LoopDepth < CurrentLoopDepth) {
- // Make sure we have visited all blocks in this loop before moving back to
- // the outer loop.
+ Order.resize(std::distance(GraphTraits<Region *>::nodes_begin(ParentRegion),
+ GraphTraits<Region *>::nodes_end(ParentRegion)));
+ if (Order.empty())
+ return;
- auto LoopI = I;
- while (unsigned &BlockCount = LoopBlocks[CurrentLoop]) {
- LoopI++;
- if (getAdjustedLoop(cast<RegionNode>(*LoopI)) == CurrentLoop) {
- --BlockCount;
- Order.push_back(*LoopI);
- }
+ SmallDenseSet<RegionNode *> Nodes;
+ auto EntryNode = SubGraphTraits::getEntryNode(ParentRegion);
+
+  // [begin, end) index ranges of SCCs in Order that still need processing.
+ SmallVector<std::pair<unsigned, unsigned>, 8> WorkList;
+ unsigned I = 0, E = Order.size();
+ while (true) {
+ // Run through all the SCCs in the subgraph starting with Entry.
+ for (auto SCCI =
+ scc_iterator<SubGraphTraits::NodeRef, SubGraphTraits>::begin(
+ EntryNode);
+ !SCCI.isAtEnd(); ++SCCI) {
+ auto &SCC = *SCCI;
+
+      // An SCC of size at most 2 can be reduced to an entry (the last node)
+      // and possibly one additional node. It is therefore already in order,
+      // and there is no need to add it to the work-list.
+ unsigned Size = SCC.size();
+ if (Size > 2)
+ WorkList.emplace_back(I, I + Size);
+
+ // Add the SCC nodes to the Order array.
+ for (auto &N : SCC) {
+ assert(I < E && "SCC size mismatch!");
+ Order[I++] = N.first;
}
}
+ assert(I == E && "SCC size mismatch!");
- CurrentLoop = getAdjustedLoop(RN);
- if (CurrentLoop)
- LoopBlocks[CurrentLoop]--;
+ // If there are no more SCCs to order, then we are done.
+ if (WorkList.empty())
+ break;
- CurrentLoopDepth = LoopDepth;
- Order.push_back(*I);
- }
+ std::tie(I, E) = WorkList.pop_back_val();
+
+    // Collect the set of nodes in the SCC's subgraph. These are only the
+    // possible child nodes; we do not add the entry (the last node), otherwise
+    // we would find the exact same SCC all over again.
+ Nodes.clear();
+ Nodes.insert(Order.begin() + I, Order.begin() + E - 1);
- // This pass originally used a post-order traversal and then operated on
- // the list in reverse. Now that we are using a reverse post-order traversal
- // rather than re-working the whole pass to operate on the list in order,
- // we just reverse the list and continue to operate on it in reverse.
- std::reverse(Order.begin(), Order.end());
+ // Update the entry node.
+ EntryNode.first = Order[E - 1];
+ EntryNode.second = &Nodes;
+ }
}
/// Determine the end of the loops
@@ -401,39 +436,6 @@ void StructurizeCFG::analyzeLoops(RegionNode *N) {
}
}
-/// Invert the given condition
-Value *StructurizeCFG::invert(Value *Condition) {
- // First: Check if it's a constant
- if (Constant *C = dyn_cast<Constant>(Condition))
- return ConstantExpr::getNot(C);
-
- // Second: If the condition is already inverted, return the original value
- Value *NotCondition;
- if (match(Condition, m_Not(m_Value(NotCondition))))
- return NotCondition;
-
- if (Instruction *Inst = dyn_cast<Instruction>(Condition)) {
- // Third: Check all the users for an invert
- BasicBlock *Parent = Inst->getParent();
- for (User *U : Condition->users())
- if (Instruction *I = dyn_cast<Instruction>(U))
- if (I->getParent() == Parent && match(I, m_Not(m_Specific(Condition))))
- return I;
-
- // Last option: Create a new instruction
- return BinaryOperator::CreateNot(Condition, "", Parent->getTerminator());
- }
-
- if (Argument *Arg = dyn_cast<Argument>(Condition)) {
- BasicBlock &EntryBlock = Arg->getParent()->getEntryBlock();
- return BinaryOperator::CreateNot(Condition,
- Arg->getName() + ".inv",
- EntryBlock.getTerminator());
- }
-
- llvm_unreachable("Unhandled condition to invert");
-}
-
/// Build the condition for one edge
Value *StructurizeCFG::buildCondition(BranchInst *Term, unsigned Idx,
bool Invert) {
@@ -442,7 +444,7 @@ Value *StructurizeCFG::buildCondition(BranchInst *Term, unsigned Idx,
Cond = Term->getCondition();
if (Idx != (unsigned)Invert)
- Cond = invert(Cond);
+ Cond = invertCondition(Cond);
}
return Cond;
}
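
The removed StructurizeCFG::invert is replaced by a call to invertCondition; the newly added include of llvm/Transforms/Utils/Local.h suggests that is where the shared utility lives. As an illustration only, a simplified version of the removed helper's lookup order (fold a constant, reuse an existing negation, otherwise build a fresh one) can be sketched with hypothetical plain C++ types.

  // Hypothetical expression type, not the LLVM classes.
  struct Expr {
    bool IsConstant = false;
    bool ConstantValue = false;
    Expr *NegationOf = nullptr; // non-null if this node represents "not X"
  };

  // Materialize a fresh negation node (leaked; this is only a sketch).
  static Expr *makeNot(Expr *E) {
    Expr *N = new Expr();
    N->NegationOf = E;
    return N;
  }

  static Expr *invertSketch(Expr *Cond) {
    static Expr TrueC{true, true}, FalseC{true, false};
    if (Cond->IsConstant)       // 1. constants fold directly
      return Cond->ConstantValue ? &FalseC : &TrueC;
    if (Cond->NegationOf)       // 2. "not X": return X instead of not(not X)
      return Cond->NegationOf;
    return makeNot(Cond);       // 3. otherwise build a new negation
  }
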
@@ -520,8 +522,7 @@ void StructurizeCFG::collectInfos() {
for (RegionNode *RN : reverse(Order)) {
LLVM_DEBUG(dbgs() << "Visiting: "
<< (RN->isSubRegion() ? "SubRegion with entry: " : "")
- << RN->getEntry()->getName() << " Loop Depth: "
- << LI->getLoopDepth(RN->getEntry()) << "\n");
+ << RN->getEntry()->getName() << "\n");
// Analyze all the conditions leading to a node
gatherPredicates(RN);
@@ -585,9 +586,14 @@ void StructurizeCFG::insertConditions(bool Loops) {
void StructurizeCFG::delPhiValues(BasicBlock *From, BasicBlock *To) {
PhiMap &Map = DeletedPhis[To];
for (PHINode &Phi : To->phis()) {
+ bool Recorded = false;
while (Phi.getBasicBlockIndex(From) != -1) {
Value *Deleted = Phi.removeIncomingValue(From, false);
Map[&Phi].push_back(std::make_pair(From, Deleted));
+ if (!Recorded) {
+ AffectedPhis.push_back(&Phi);
+ Recorded = true;
+ }
}
}
}
@@ -632,28 +638,29 @@ void StructurizeCFG::setPhiValues() {
for (BasicBlock *FI : From)
Phi->setIncomingValueForBlock(FI, Updater.GetValueAtEndOfBlock(FI));
+ AffectedPhis.push_back(Phi);
}
DeletedPhis.erase(To);
}
assert(DeletedPhis.empty());
- // Simplify any phis inserted by the SSAUpdater if possible
+ AffectedPhis.append(InsertedPhis.begin(), InsertedPhis.end());
+}
+
+void StructurizeCFG::simplifyAffectedPhis() {
bool Changed;
do {
Changed = false;
-
SimplifyQuery Q(Func->getParent()->getDataLayout());
Q.DT = DT;
- for (size_t i = 0; i < InsertedPhis.size(); ++i) {
- PHINode *Phi = InsertedPhis[i];
- if (Value *V = SimplifyInstruction(Phi, Q)) {
- Phi->replaceAllUsesWith(V);
- Phi->eraseFromParent();
- InsertedPhis[i] = InsertedPhis.back();
- InsertedPhis.pop_back();
- i--;
- Changed = true;
+ for (WeakVH VH : AffectedPhis) {
+ if (auto Phi = dyn_cast_or_null<PHINode>(VH)) {
+ if (auto NewValue = SimplifyInstruction(Phi, Q)) {
+ Phi->replaceAllUsesWith(NewValue);
+ Phi->eraseFromParent();
+ Changed = true;
+ }
}
}
} while (Changed);
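
For illustration only, and with the caveat that WeakVH is an LLVM value handle rather than a reference-counted pointer: the idea behind tracking AffectedPhis through handles that become null instead of dangling can be approximated in portable C++ with std::weak_ptr.

  #include <memory>
  #include <vector>

  struct Phi { int Id = 0; }; // hypothetical stand-in for PHINode

  // Handles whose object has already been destroyed lock() to null and are
  // skipped, mirroring how dyn_cast_or_null<PHINode>(VH) above skips entries
  // for phis erased by an earlier simplification.
  static void simplifyAffected(std::vector<std::weak_ptr<Phi>> &Affected) {
    for (auto &Handle : Affected)
      if (std::shared_ptr<Phi> P = Handle.lock()) {
        // ... simplify *P here; doing so may erase other tracked phis ...
        (void)P;
      }
  }
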
@@ -886,6 +893,7 @@ void StructurizeCFG::createFlow() {
BasicBlock *Exit = ParentRegion->getExit();
bool EntryDominatesExit = DT->dominates(ParentRegion->getEntry(), Exit);
+ AffectedPhis.clear();
DeletedPhis.clear();
AddedPhis.clear();
Conditions.clear();
@@ -1036,7 +1044,6 @@ bool StructurizeCFG::runOnRegion(Region *R, RGPassManager &RGM) {
ParentRegion = R;
DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
- LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
orderNodes();
collectInfos();
@@ -1044,6 +1051,7 @@ bool StructurizeCFG::runOnRegion(Region *R, RGPassManager &RGM) {
insertConditions(false);
insertConditions(true);
setPhiValues();
+ simplifyAffectedPhis();
rebuildSSA();
// Cleanup
diff --git a/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp b/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp
index 9f0ab9103d42..5bb1d54d7d12 100644
--- a/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp
+++ b/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp
@@ -64,7 +64,6 @@
#include "llvm/Analysis/PostDominators.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/CFG.h"
-#include "llvm/IR/CallSite.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DerivedTypes.h"
@@ -126,16 +125,16 @@ struct AllocaDerivedValueTracker {
switch (I->getOpcode()) {
case Instruction::Call:
case Instruction::Invoke: {
- CallSite CS(I);
+ auto &CB = cast<CallBase>(*I);
// If the alloca-derived argument is passed byval it is not an escape
// point, or a use of an alloca. Calling with byval copies the contents
// of the alloca into argument registers or stack slots, which exist
// beyond the lifetime of the current frame.
- if (CS.isArgOperand(U) && CS.isByValArgument(CS.getArgumentNo(U)))
+ if (CB.isArgOperand(U) && CB.isByValArgument(CB.getArgOperandNo(U)))
continue;
bool IsNocapture =
- CS.isDataOperand(U) && CS.doesNotCapture(CS.getDataOperandNo(U));
- callUsesLocalStack(CS, IsNocapture);
+ CB.isDataOperand(U) && CB.doesNotCapture(CB.getDataOperandNo(U));
+ callUsesLocalStack(CB, IsNocapture);
if (IsNocapture) {
// If the alloca-derived argument is passed in as nocapture, then it
// can't propagate to the call's return. That would be capturing.
@@ -168,17 +167,17 @@ struct AllocaDerivedValueTracker {
}
}
- void callUsesLocalStack(CallSite CS, bool IsNocapture) {
+ void callUsesLocalStack(CallBase &CB, bool IsNocapture) {
// Add it to the list of alloca users.
- AllocaUsers.insert(CS.getInstruction());
+ AllocaUsers.insert(&CB);
// If it's nocapture then it can't capture this alloca.
if (IsNocapture)
return;
// If it can write to memory, it can leak the alloca value.
- if (!CS.onlyReadsMemory())
- EscapePoints.insert(CS.getInstruction());
+ if (!CB.onlyReadsMemory())
+ EscapePoints.insert(&CB);
}
SmallPtrSet<Instruction *, 32> AllocaUsers;
@@ -342,7 +341,7 @@ static bool canMoveAboveCall(Instruction *I, CallInst *CI, AliasAnalysis *AA) {
const DataLayout &DL = L->getModule()->getDataLayout();
if (isModSet(AA->getModRefInfo(CI, MemoryLocation::get(L))) ||
!isSafeToLoadUnconditionally(L->getPointerOperand(), L->getType(),
- MaybeAlign(L->getAlignment()), DL, L))
+ L->getAlign(), DL, L))
return false;
}
}
@@ -355,89 +354,23 @@ static bool canMoveAboveCall(Instruction *I, CallInst *CI, AliasAnalysis *AA) {
return !is_contained(I->operands(), CI);
}
-/// Return true if the specified value is the same when the return would exit
-/// as it was when the initial iteration of the recursive function was executed.
-///
-/// We currently handle static constants and arguments that are not modified as
-/// part of the recursion.
-static bool isDynamicConstant(Value *V, CallInst *CI, ReturnInst *RI) {
- if (isa<Constant>(V)) return true; // Static constants are always dyn consts
-
- // Check to see if this is an immutable argument, if so, the value
- // will be available to initialize the accumulator.
- if (Argument *Arg = dyn_cast<Argument>(V)) {
- // Figure out which argument number this is...
- unsigned ArgNo = 0;
- Function *F = CI->getParent()->getParent();
- for (Function::arg_iterator AI = F->arg_begin(); &*AI != Arg; ++AI)
- ++ArgNo;
-
- // If we are passing this argument into call as the corresponding
- // argument operand, then the argument is dynamically constant.
- // Otherwise, we cannot transform this function safely.
- if (CI->getArgOperand(ArgNo) == Arg)
- return true;
- }
-
- // Switch cases are always constant integers. If the value is being switched
- // on and the return is only reachable from one of its cases, it's
- // effectively constant.
- if (BasicBlock *UniquePred = RI->getParent()->getUniquePredecessor())
- if (SwitchInst *SI = dyn_cast<SwitchInst>(UniquePred->getTerminator()))
- if (SI->getCondition() == V)
- return SI->getDefaultDest() != RI->getParent();
-
- // Not a constant or immutable argument, we can't safely transform.
- return false;
-}
-
-/// Check to see if the function containing the specified tail call consistently
-/// returns the same runtime-constant value at all exit points except for
-/// IgnoreRI. If so, return the returned value.
-static Value *getCommonReturnValue(ReturnInst *IgnoreRI, CallInst *CI) {
- Function *F = CI->getParent()->getParent();
- Value *ReturnedValue = nullptr;
-
- for (BasicBlock &BBI : *F) {
- ReturnInst *RI = dyn_cast<ReturnInst>(BBI.getTerminator());
- if (RI == nullptr || RI == IgnoreRI) continue;
-
- // We can only perform this transformation if the value returned is
- // evaluatable at the start of the initial invocation of the function,
- // instead of at the end of the evaluation.
- //
- Value *RetOp = RI->getOperand(0);
- if (!isDynamicConstant(RetOp, CI, RI))
- return nullptr;
-
- if (ReturnedValue && RetOp != ReturnedValue)
- return nullptr; // Cannot transform if differing values are returned.
- ReturnedValue = RetOp;
- }
- return ReturnedValue;
-}
+static bool canTransformAccumulatorRecursion(Instruction *I, CallInst *CI) {
+ if (!I->isAssociative() || !I->isCommutative())
+ return false;
-/// If the specified instruction can be transformed using accumulator recursion
-/// elimination, return the constant which is the start of the accumulator
-/// value. Otherwise return null.
-static Value *canTransformAccumulatorRecursion(Instruction *I, CallInst *CI) {
- if (!I->isAssociative() || !I->isCommutative()) return nullptr;
assert(I->getNumOperands() == 2 &&
"Associative/commutative operations should have 2 args!");
// Exactly one operand should be the result of the call instruction.
if ((I->getOperand(0) == CI && I->getOperand(1) == CI) ||
(I->getOperand(0) != CI && I->getOperand(1) != CI))
- return nullptr;
+ return false;
// The only user of this instruction we allow is a single return instruction.
if (!I->hasOneUse() || !isa<ReturnInst>(I->user_back()))
- return nullptr;
+ return false;
- // Ok, now we have to check all of the other return instructions in this
- // function. If they return non-constants or differing values, then we cannot
- // transform the function safely.
- return getCommonReturnValue(cast<ReturnInst>(I->user_back()), CI);
+ return true;
}
static Instruction *firstNonDbg(BasicBlock::iterator I) {
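
Not part of the patch, but as a concrete example of the accumulator-recursion shape that canTransformAccumulatorRecursion now merely detects (the identity seeding and return-value plumbing are handled later by insertAccumulator and the RetPN/RetKnownPN machinery): a factorial with a multiply after the recursive call, and roughly the loop it becomes.

  // Recursive form: the multiply after the self-call is associative and
  // commutative, and its only use feeds the return, so the pass can introduce
  // an accumulator.
  unsigned long long factRecursive(unsigned n) {
    if (n == 0)
      return 1;
    return n * factRecursive(n - 1);
  }

  // Roughly the shape after accumulator recursion elimination: the accumulator
  // is seeded with the operation's identity (1 for multiplication, as produced
  // by ConstantExpr::getBinOpIdentity) and the recursive call becomes a back
  // edge.
  unsigned long long factLoop(unsigned n) {
    unsigned long long acc = 1; // accumulator.tr seeded with the identity
    while (n != 0) {
      acc *= n;
      n -= 1;
    }
    return acc;
  }
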
@@ -446,11 +379,73 @@ static Instruction *firstNonDbg(BasicBlock::iterator I) {
return &*I;
}
-static CallInst *findTRECandidate(Instruction *TI,
- bool CannotTailCallElimCallsMarkedTail,
- const TargetTransformInfo *TTI) {
+namespace {
+class TailRecursionEliminator {
+ Function &F;
+ const TargetTransformInfo *TTI;
+ AliasAnalysis *AA;
+ OptimizationRemarkEmitter *ORE;
+ DomTreeUpdater &DTU;
+
+  // The fields below are shared state we want to have available when
+  // eliminating any calls in the function. Their values should be populated by
+  // createTailRecurseLoopHeader the first time we find a call we can eliminate.
+ BasicBlock *HeaderBB = nullptr;
+ SmallVector<PHINode *, 8> ArgumentPHIs;
+ bool RemovableCallsMustBeMarkedTail = false;
+
+ // PHI node to store our return value.
+ PHINode *RetPN = nullptr;
+
+ // i1 PHI node to track if we have a valid return value stored in RetPN.
+ PHINode *RetKnownPN = nullptr;
+
+  // Vector of select instructions we inserted. These selects use RetKnownPN
+ // to either propagate RetPN or select a new return value.
+ SmallVector<SelectInst *, 8> RetSelects;
+
+  // The fields below are shared state needed when performing accumulator
+  // recursion. Their values should be populated by insertAccumulator the first
+  // time we find an elimination that requires an accumulator.
+
+ // PHI node to store our current accumulated value.
+ PHINode *AccPN = nullptr;
+
+ // The instruction doing the accumulating.
+ Instruction *AccumulatorRecursionInstr = nullptr;
+
+ TailRecursionEliminator(Function &F, const TargetTransformInfo *TTI,
+ AliasAnalysis *AA, OptimizationRemarkEmitter *ORE,
+ DomTreeUpdater &DTU)
+ : F(F), TTI(TTI), AA(AA), ORE(ORE), DTU(DTU) {}
+
+ CallInst *findTRECandidate(Instruction *TI,
+ bool CannotTailCallElimCallsMarkedTail);
+
+ void createTailRecurseLoopHeader(CallInst *CI);
+
+ void insertAccumulator(Instruction *AccRecInstr);
+
+ bool eliminateCall(CallInst *CI);
+
+ bool foldReturnAndProcessPred(ReturnInst *Ret,
+ bool CannotTailCallElimCallsMarkedTail);
+
+ bool processReturningBlock(ReturnInst *Ret,
+ bool CannotTailCallElimCallsMarkedTail);
+
+ void cleanupAndFinalize();
+
+public:
+ static bool eliminate(Function &F, const TargetTransformInfo *TTI,
+ AliasAnalysis *AA, OptimizationRemarkEmitter *ORE,
+ DomTreeUpdater &DTU);
+};
+} // namespace
+
+CallInst *TailRecursionEliminator::findTRECandidate(
+ Instruction *TI, bool CannotTailCallElimCallsMarkedTail) {
BasicBlock *BB = TI->getParent();
- Function *F = BB->getParent();
if (&BB->front() == TI) // Make sure there is something before the terminator.
return nullptr;
@@ -461,7 +456,7 @@ static CallInst *findTRECandidate(Instruction *TI,
BasicBlock::iterator BBI(TI);
while (true) {
CI = dyn_cast<CallInst>(BBI);
- if (CI && CI->getCalledFunction() == F)
+ if (CI && CI->getCalledFunction() == &F)
break;
if (BBI == BB->begin())
@@ -478,16 +473,14 @@ static CallInst *findTRECandidate(Instruction *TI,
// double fabs(double f) { return __builtin_fabs(f); } // a 'fabs' call
// and disable this xform in this case, because the code generator will
// lower the call to fabs into inline code.
- if (BB == &F->getEntryBlock() &&
+ if (BB == &F.getEntryBlock() &&
firstNonDbg(BB->front().getIterator()) == CI &&
firstNonDbg(std::next(BB->begin())) == TI && CI->getCalledFunction() &&
!TTI->isLoweredToCall(CI->getCalledFunction())) {
// A single-block function with just a call and a return. Check that
// the arguments match.
- CallSite::arg_iterator I = CallSite(CI).arg_begin(),
- E = CallSite(CI).arg_end();
- Function::arg_iterator FI = F->arg_begin(),
- FE = F->arg_end();
+ auto I = CI->arg_begin(), E = CI->arg_end();
+ Function::arg_iterator FI = F.arg_begin(), FE = F.arg_end();
for (; I != E && FI != FE; ++I, ++FI)
if (*I != &*FI) break;
if (I == E && FI == FE)
@@ -497,27 +490,106 @@ static CallInst *findTRECandidate(Instruction *TI,
return CI;
}
-static bool eliminateRecursiveTailCall(
- CallInst *CI, ReturnInst *Ret, BasicBlock *&OldEntry,
- bool &TailCallsAreMarkedTail, SmallVectorImpl<PHINode *> &ArgumentPHIs,
- AliasAnalysis *AA, OptimizationRemarkEmitter *ORE, DomTreeUpdater &DTU) {
- // If we are introducing accumulator recursion to eliminate operations after
- // the call instruction that are both associative and commutative, the initial
- // value for the accumulator is placed in this variable. If this value is set
- // then we actually perform accumulator recursion elimination instead of
- // simple tail recursion elimination. If the operation is an LLVM instruction
- // (eg: "add") then it is recorded in AccumulatorRecursionInstr. If not, then
- // we are handling the case when the return instruction returns a constant C
- // which is different to the constant returned by other return instructions
- // (which is recorded in AccumulatorRecursionEliminationInitVal). This is a
- // special case of accumulator recursion, the operation being "return C".
- Value *AccumulatorRecursionEliminationInitVal = nullptr;
- Instruction *AccumulatorRecursionInstr = nullptr;
+void TailRecursionEliminator::createTailRecurseLoopHeader(CallInst *CI) {
+ HeaderBB = &F.getEntryBlock();
+ BasicBlock *NewEntry = BasicBlock::Create(F.getContext(), "", &F, HeaderBB);
+ NewEntry->takeName(HeaderBB);
+ HeaderBB->setName("tailrecurse");
+ BranchInst *BI = BranchInst::Create(HeaderBB, NewEntry);
+ BI->setDebugLoc(CI->getDebugLoc());
+
+ // If this function has self recursive calls in the tail position where some
+ // are marked tail and some are not, only transform one flavor or another.
+ // We have to choose whether we move allocas in the entry block to the new
+ // entry block or not, so we can't make a good choice for both. We make this
+ // decision here based on whether the first call we found to remove is
+ // marked tail.
+ // NOTE: We could do slightly better here in the case that the function has
+ // no entry block allocas.
+ RemovableCallsMustBeMarkedTail = CI->isTailCall();
+
+ // If this tail call is marked 'tail' and if there are any allocas in the
+ // entry block, move them up to the new entry block.
+ if (RemovableCallsMustBeMarkedTail)
+ // Move all fixed sized allocas from HeaderBB to NewEntry.
+ for (BasicBlock::iterator OEBI = HeaderBB->begin(), E = HeaderBB->end(),
+ NEBI = NewEntry->begin();
+ OEBI != E;)
+ if (AllocaInst *AI = dyn_cast<AllocaInst>(OEBI++))
+ if (isa<ConstantInt>(AI->getArraySize()))
+ AI->moveBefore(&*NEBI);
+
+ // Now that we have created a new block, which jumps to the entry
+ // block, insert a PHI node for each argument of the function.
+ // For now, we initialize each PHI to only have the real arguments
+ // which are passed in.
+ Instruction *InsertPos = &HeaderBB->front();
+ for (Function::arg_iterator I = F.arg_begin(), E = F.arg_end(); I != E; ++I) {
+ PHINode *PN =
+ PHINode::Create(I->getType(), 2, I->getName() + ".tr", InsertPos);
+ I->replaceAllUsesWith(PN); // Everyone use the PHI node now!
+ PN->addIncoming(&*I, NewEntry);
+ ArgumentPHIs.push_back(PN);
+ }
+
+  // If the function doesn't return void, create the RetPN and RetKnownPN PHI
+ // nodes to track our return value. We initialize RetPN with undef and
+ // RetKnownPN with false since we can't know our return value at function
+ // entry.
+ Type *RetType = F.getReturnType();
+ if (!RetType->isVoidTy()) {
+ Type *BoolType = Type::getInt1Ty(F.getContext());
+ RetPN = PHINode::Create(RetType, 2, "ret.tr", InsertPos);
+ RetKnownPN = PHINode::Create(BoolType, 2, "ret.known.tr", InsertPos);
+
+ RetPN->addIncoming(UndefValue::get(RetType), NewEntry);
+ RetKnownPN->addIncoming(ConstantInt::getFalse(BoolType), NewEntry);
+ }
+
+ // The entry block was changed from HeaderBB to NewEntry.
+ // The forward DominatorTree needs to be recalculated when the EntryBB is
+ // changed. In this corner-case we recalculate the entire tree.
+ DTU.recalculate(*NewEntry->getParent());
+}
+
+void TailRecursionEliminator::insertAccumulator(Instruction *AccRecInstr) {
+ assert(!AccPN && "Trying to insert multiple accumulators");
+
+ AccumulatorRecursionInstr = AccRecInstr;
+
+ // Start by inserting a new PHI node for the accumulator.
+ pred_iterator PB = pred_begin(HeaderBB), PE = pred_end(HeaderBB);
+ AccPN = PHINode::Create(F.getReturnType(), std::distance(PB, PE) + 1,
+ "accumulator.tr", &HeaderBB->front());
+
+ // Loop over all of the predecessors of the tail recursion block. For the
+ // real entry into the function we seed the PHI with the identity constant for
+ // the accumulation operation. For any other existing branches to this block
+ // (due to other tail recursions eliminated) the accumulator is not modified.
+ // Because we haven't added the branch in the current block to HeaderBB yet,
+ // it will not show up as a predecessor.
+ for (pred_iterator PI = PB; PI != PE; ++PI) {
+ BasicBlock *P = *PI;
+ if (P == &F.getEntryBlock()) {
+ Constant *Identity = ConstantExpr::getBinOpIdentity(
+ AccRecInstr->getOpcode(), AccRecInstr->getType());
+ AccPN->addIncoming(Identity, P);
+ } else {
+ AccPN->addIncoming(AccPN, P);
+ }
+ }
+
+ ++NumAccumAdded;
+}
+
+bool TailRecursionEliminator::eliminateCall(CallInst *CI) {
+ ReturnInst *Ret = cast<ReturnInst>(CI->getParent()->getTerminator());
// Ok, we found a potential tail call. We can currently only transform the
// tail call if all of the instructions between the call and the return are
// movable to above the call itself, leaving the call next to the return.
// Check that this is the case now.
+ Instruction *AccRecInstr = nullptr;
BasicBlock::iterator BBI(CI);
for (++BBI; &*BBI != Ret; ++BBI) {
if (canMoveAboveCall(&*BBI, CI, AA))
@@ -526,39 +598,16 @@ static bool eliminateRecursiveTailCall(
// If we can't move the instruction above the call, it might be because it
// is an associative and commutative operation that could be transformed
// using accumulator recursion elimination. Check to see if this is the
- // case, and if so, remember the initial accumulator value for later.
- if ((AccumulatorRecursionEliminationInitVal =
- canTransformAccumulatorRecursion(&*BBI, CI))) {
- // Yes, this is accumulator recursion. Remember which instruction
- // accumulates.
- AccumulatorRecursionInstr = &*BBI;
- } else {
- return false; // Otherwise, we cannot eliminate the tail recursion!
- }
- }
+ // case, and if so, remember which instruction accumulates for later.
+ if (AccPN || !canTransformAccumulatorRecursion(&*BBI, CI))
+ return false; // We cannot eliminate the tail recursion!
- // We can only transform call/return pairs that either ignore the return value
- // of the call and return void, ignore the value of the call and return a
- // constant, return the value returned by the tail call, or that are being
- // accumulator recursion variable eliminated.
- if (Ret->getNumOperands() == 1 && Ret->getReturnValue() != CI &&
- !isa<UndefValue>(Ret->getReturnValue()) &&
- AccumulatorRecursionEliminationInitVal == nullptr &&
- !getCommonReturnValue(nullptr, CI)) {
- // One case remains that we are able to handle: the current return
- // instruction returns a constant, and all other return instructions
- // return a different constant.
- if (!isDynamicConstant(Ret->getReturnValue(), CI, Ret))
- return false; // Current return instruction does not return a constant.
- // Check that all other return instructions return a common constant. If
- // so, record it in AccumulatorRecursionEliminationInitVal.
- AccumulatorRecursionEliminationInitVal = getCommonReturnValue(Ret, CI);
- if (!AccumulatorRecursionEliminationInitVal)
- return false;
+ // Yes, this is accumulator recursion. Remember which instruction
+ // accumulates.
+ AccRecInstr = &*BBI;
}
BasicBlock *BB = Ret->getParent();
- Function *F = BB->getParent();
using namespace ore;
ORE->emit([&]() {
@@ -568,51 +617,10 @@ static bool eliminateRecursiveTailCall(
// OK! We can transform this tail call. If this is the first one found,
// create the new entry block, allowing us to branch back to the old entry.
- if (!OldEntry) {
- OldEntry = &F->getEntryBlock();
- BasicBlock *NewEntry = BasicBlock::Create(F->getContext(), "", F, OldEntry);
- NewEntry->takeName(OldEntry);
- OldEntry->setName("tailrecurse");
- BranchInst *BI = BranchInst::Create(OldEntry, NewEntry);
- BI->setDebugLoc(CI->getDebugLoc());
-
- // If this tail call is marked 'tail' and if there are any allocas in the
- // entry block, move them up to the new entry block.
- TailCallsAreMarkedTail = CI->isTailCall();
- if (TailCallsAreMarkedTail)
- // Move all fixed sized allocas from OldEntry to NewEntry.
- for (BasicBlock::iterator OEBI = OldEntry->begin(), E = OldEntry->end(),
- NEBI = NewEntry->begin(); OEBI != E; )
- if (AllocaInst *AI = dyn_cast<AllocaInst>(OEBI++))
- if (isa<ConstantInt>(AI->getArraySize()))
- AI->moveBefore(&*NEBI);
-
- // Now that we have created a new block, which jumps to the entry
- // block, insert a PHI node for each argument of the function.
- // For now, we initialize each PHI to only have the real arguments
- // which are passed in.
- Instruction *InsertPos = &OldEntry->front();
- for (Function::arg_iterator I = F->arg_begin(), E = F->arg_end();
- I != E; ++I) {
- PHINode *PN = PHINode::Create(I->getType(), 2,
- I->getName() + ".tr", InsertPos);
- I->replaceAllUsesWith(PN); // Everyone use the PHI node now!
- PN->addIncoming(&*I, NewEntry);
- ArgumentPHIs.push_back(PN);
- }
- // The entry block was changed from OldEntry to NewEntry.
- // The forward DominatorTree needs to be recalculated when the EntryBB is
- // changed. In this corner-case we recalculate the entire tree.
- DTU.recalculate(*NewEntry->getParent());
- }
+ if (!HeaderBB)
+ createTailRecurseLoopHeader(CI);
- // If this function has self recursive calls in the tail position where some
- // are marked tail and some are not, only transform one flavor or another. We
- // have to choose whether we move allocas in the entry block to the new entry
- // block or not, so we can't make a good choice for both. NOTE: We could do
- // slightly better here in the case that the function has no entry block
- // allocas.
- if (TailCallsAreMarkedTail && !CI->isTailCall())
+ if (RemovableCallsMustBeMarkedTail && !CI->isTailCall())
return false;
// Ok, now that we know we have a pseudo-entry block WITH all of the
@@ -621,74 +629,53 @@ static bool eliminateRecursiveTailCall(
for (unsigned i = 0, e = CI->getNumArgOperands(); i != e; ++i)
ArgumentPHIs[i]->addIncoming(CI->getArgOperand(i), BB);
- // If we are introducing an accumulator variable to eliminate the recursion,
- // do so now. Note that we _know_ that no subsequent tail recursion
- // eliminations will happen on this function because of the way the
- // accumulator recursion predicate is set up.
- //
- if (AccumulatorRecursionEliminationInitVal) {
- Instruction *AccRecInstr = AccumulatorRecursionInstr;
- // Start by inserting a new PHI node for the accumulator.
- pred_iterator PB = pred_begin(OldEntry), PE = pred_end(OldEntry);
- PHINode *AccPN = PHINode::Create(
- AccumulatorRecursionEliminationInitVal->getType(),
- std::distance(PB, PE) + 1, "accumulator.tr", &OldEntry->front());
-
- // Loop over all of the predecessors of the tail recursion block. For the
- // real entry into the function we seed the PHI with the initial value,
- // computed earlier. For any other existing branches to this block (due to
- // other tail recursions eliminated) the accumulator is not modified.
- // Because we haven't added the branch in the current block to OldEntry yet,
- // it will not show up as a predecessor.
- for (pred_iterator PI = PB; PI != PE; ++PI) {
- BasicBlock *P = *PI;
- if (P == &F->getEntryBlock())
- AccPN->addIncoming(AccumulatorRecursionEliminationInitVal, P);
- else
- AccPN->addIncoming(AccPN, P);
- }
+ if (AccRecInstr) {
+ insertAccumulator(AccRecInstr);
- if (AccRecInstr) {
- // Add an incoming argument for the current block, which is computed by
- // our associative and commutative accumulator instruction.
- AccPN->addIncoming(AccRecInstr, BB);
+    // Rewrite the accumulator recursion instruction so that it no longer uses
+    // the result of the call; instead, it uses the PHI node we just
+    // inserted.
+ AccRecInstr->setOperand(AccRecInstr->getOperand(0) != CI, AccPN);
+ }
- // Next, rewrite the accumulator recursion instruction so that it does not
- // use the result of the call anymore, instead, use the PHI node we just
- // inserted.
- AccRecInstr->setOperand(AccRecInstr->getOperand(0) != CI, AccPN);
+ // Update our return value tracking
+ if (RetPN) {
+ if (Ret->getReturnValue() == CI || AccRecInstr) {
+ // Defer selecting a return value
+ RetPN->addIncoming(RetPN, BB);
+ RetKnownPN->addIncoming(RetKnownPN, BB);
} else {
- // Add an incoming argument for the current block, which is just the
- // constant returned by the current return instruction.
- AccPN->addIncoming(Ret->getReturnValue(), BB);
+      // We found a return value we want to use. Insert a select instruction to
+      // select it if we don't already know what our return value will be, and
+      // store the result in our return value PHI node.
+ SelectInst *SI = SelectInst::Create(
+ RetKnownPN, RetPN, Ret->getReturnValue(), "current.ret.tr", Ret);
+ RetSelects.push_back(SI);
+
+ RetPN->addIncoming(SI, BB);
+ RetKnownPN->addIncoming(ConstantInt::getTrue(RetKnownPN->getType()), BB);
}
- // Finally, rewrite any return instructions in the program to return the PHI
- // node instead of the "initval" that they do currently. This loop will
- // actually rewrite the return value we are destroying, but that's ok.
- for (BasicBlock &BBI : *F)
- if (ReturnInst *RI = dyn_cast<ReturnInst>(BBI.getTerminator()))
- RI->setOperand(0, AccPN);
- ++NumAccumAdded;
+ if (AccPN)
+ AccPN->addIncoming(AccRecInstr ? AccRecInstr : AccPN, BB);
}
// Now that all of the PHI nodes are in place, remove the call and
// ret instructions, replacing them with an unconditional branch.
- BranchInst *NewBI = BranchInst::Create(OldEntry, Ret);
+ BranchInst *NewBI = BranchInst::Create(HeaderBB, Ret);
NewBI->setDebugLoc(CI->getDebugLoc());
BB->getInstList().erase(Ret); // Remove return.
BB->getInstList().erase(CI); // Remove call.
- DTU.applyUpdates({{DominatorTree::Insert, BB, OldEntry}});
+ DTU.applyUpdates({{DominatorTree::Insert, BB, HeaderBB}});
++NumEliminated;
return true;
}
-static bool foldReturnAndProcessPred(
- BasicBlock *BB, ReturnInst *Ret, BasicBlock *&OldEntry,
- bool &TailCallsAreMarkedTail, SmallVectorImpl<PHINode *> &ArgumentPHIs,
- bool CannotTailCallElimCallsMarkedTail, const TargetTransformInfo *TTI,
- AliasAnalysis *AA, OptimizationRemarkEmitter *ORE, DomTreeUpdater &DTU) {
+bool TailRecursionEliminator::foldReturnAndProcessPred(
+ ReturnInst *Ret, bool CannotTailCallElimCallsMarkedTail) {
+ BasicBlock *BB = Ret->getParent();
+
bool Change = false;
// Make sure this block is a trivial return block.
@@ -711,10 +698,11 @@ static bool foldReturnAndProcessPred(
while (!UncondBranchPreds.empty()) {
BranchInst *BI = UncondBranchPreds.pop_back_val();
BasicBlock *Pred = BI->getParent();
- if (CallInst *CI = findTRECandidate(BI, CannotTailCallElimCallsMarkedTail, TTI)){
+ if (CallInst *CI =
+ findTRECandidate(BI, CannotTailCallElimCallsMarkedTail)) {
LLVM_DEBUG(dbgs() << "FOLDING: " << *BB
<< "INTO UNCOND BRANCH PRED: " << *Pred);
- ReturnInst *RI = FoldReturnIntoUncondBranch(Ret, BB, Pred, &DTU);
+ FoldReturnIntoUncondBranch(Ret, BB, Pred, &DTU);
// Cleanup: if all predecessors of BB have been eliminated by
// FoldReturnIntoUncondBranch, delete it. It is important to empty it,
@@ -723,8 +711,7 @@ static bool foldReturnAndProcessPred(
if (!BB->hasAddressTaken() && pred_begin(BB) == pred_end(BB))
DTU.deleteBB(BB);
- eliminateRecursiveTailCall(CI, RI, OldEntry, TailCallsAreMarkedTail,
- ArgumentPHIs, AA, ORE, DTU);
+ eliminateCall(CI);
++NumRetDuped;
Change = true;
}
@@ -733,23 +720,92 @@ static bool foldReturnAndProcessPred(
return Change;
}
-static bool processReturningBlock(
- ReturnInst *Ret, BasicBlock *&OldEntry, bool &TailCallsAreMarkedTail,
- SmallVectorImpl<PHINode *> &ArgumentPHIs,
- bool CannotTailCallElimCallsMarkedTail, const TargetTransformInfo *TTI,
- AliasAnalysis *AA, OptimizationRemarkEmitter *ORE, DomTreeUpdater &DTU) {
- CallInst *CI = findTRECandidate(Ret, CannotTailCallElimCallsMarkedTail, TTI);
+bool TailRecursionEliminator::processReturningBlock(
+ ReturnInst *Ret, bool CannotTailCallElimCallsMarkedTail) {
+ CallInst *CI = findTRECandidate(Ret, CannotTailCallElimCallsMarkedTail);
if (!CI)
return false;
- return eliminateRecursiveTailCall(CI, Ret, OldEntry, TailCallsAreMarkedTail,
- ArgumentPHIs, AA, ORE, DTU);
+ return eliminateCall(CI);
+}
+
+void TailRecursionEliminator::cleanupAndFinalize() {
+ // If we eliminated any tail recursions, it's possible that we inserted some
+ // silly PHI nodes which just merge an initial value (the incoming operand)
+ // with themselves. Check to see if we did and clean up our mess if so. This
+ // occurs when a function passes an argument straight through to its tail
+ // call.
+ for (PHINode *PN : ArgumentPHIs) {
+ // If the PHI Node is a dynamic constant, replace it with the value it is.
+ if (Value *PNV = SimplifyInstruction(PN, F.getParent()->getDataLayout())) {
+ PN->replaceAllUsesWith(PNV);
+ PN->eraseFromParent();
+ }
+ }
+
+ if (RetPN) {
+ if (RetSelects.empty()) {
+ // If we didn't insert any select instructions, then we know we didn't
+ // store a return value and we can remove the PHI nodes we inserted.
+ RetPN->dropAllReferences();
+ RetPN->eraseFromParent();
+
+ RetKnownPN->dropAllReferences();
+ RetKnownPN->eraseFromParent();
+
+ if (AccPN) {
+ // We need to insert a copy of our accumulator instruction before any
+ // return in the function, and return its result instead.
+ Instruction *AccRecInstr = AccumulatorRecursionInstr;
+ for (BasicBlock &BB : F) {
+ ReturnInst *RI = dyn_cast<ReturnInst>(BB.getTerminator());
+ if (!RI)
+ continue;
+
+ Instruction *AccRecInstrNew = AccRecInstr->clone();
+ AccRecInstrNew->setName("accumulator.ret.tr");
+ AccRecInstrNew->setOperand(AccRecInstr->getOperand(0) == AccPN,
+ RI->getOperand(0));
+ AccRecInstrNew->insertBefore(RI);
+ RI->setOperand(0, AccRecInstrNew);
+ }
+ }
+ } else {
+ // We need to insert a select instruction before any return left in the
+ // function to select our stored return value if we have one.
+ for (BasicBlock &BB : F) {
+ ReturnInst *RI = dyn_cast<ReturnInst>(BB.getTerminator());
+ if (!RI)
+ continue;
+
+ SelectInst *SI = SelectInst::Create(
+ RetKnownPN, RetPN, RI->getOperand(0), "current.ret.tr", RI);
+ RetSelects.push_back(SI);
+ RI->setOperand(0, SI);
+ }
+
+ if (AccPN) {
+ // We need to insert a copy of our accumulator instruction before any
+ // of the selects we inserted, and select its result instead.
+ Instruction *AccRecInstr = AccumulatorRecursionInstr;
+ for (SelectInst *SI : RetSelects) {
+ Instruction *AccRecInstrNew = AccRecInstr->clone();
+ AccRecInstrNew->setName("accumulator.ret.tr");
+ AccRecInstrNew->setOperand(AccRecInstr->getOperand(0) == AccPN,
+ SI->getFalseValue());
+ AccRecInstrNew->insertBefore(SI);
+ SI->setFalseValue(AccRecInstrNew);
+ }
+ }
+ }
+ }
}
-static bool eliminateTailRecursion(Function &F, const TargetTransformInfo *TTI,
- AliasAnalysis *AA,
- OptimizationRemarkEmitter *ORE,
- DomTreeUpdater &DTU) {
+bool TailRecursionEliminator::eliminate(Function &F,
+ const TargetTransformInfo *TTI,
+ AliasAnalysis *AA,
+ OptimizationRemarkEmitter *ORE,
+ DomTreeUpdater &DTU) {
if (F.getFnAttribute("disable-tail-calls").getValueAsString() == "true")
return false;
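
The accumulator handling in the cleanupAndFinalize hunk above is easiest to follow on a concrete function. The sketch below is purely illustrative (the function names and the loop form are invented for the example, not taken from the pass): the multiply that follows the recursive call is the pattern AccumulatorRecursionInstr tracks, and after elimination a clone of that multiply is applied to each remaining return value, which is what the loops over returns and selects above implement.

    // Accumulator recursion: the multiply after the recursive call is what
    // AccumulatorRecursionInstr refers to.
    static unsigned factorial(unsigned N) {
      if (N <= 1)
        return 1;
      return N * factorial(N - 1); // not a tail call as written
    }

    // Roughly what the pass produces: the recursion becomes a loop, the
    // running product lives in a PHI (AccPN), and a clone of the multiply
    // ("accumulator.ret.tr") is inserted before the surviving return.
    static unsigned factorial_after_tre(unsigned N) {
      unsigned Acc = 1;   // corresponds to AccPN
      while (N > 1) {
        Acc *= N;         // accumulator update on the recursive path
        --N;
      }
      return Acc * 1;     // cloned multiply applied to the base-case value
    }
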
@@ -762,17 +818,15 @@ static bool eliminateTailRecursion(Function &F, const TargetTransformInfo *TTI,
// If this function is a varargs function, we won't be able to PHI the args
// right, so don't even try to convert it...
if (F.getFunctionType()->isVarArg())
- return false;
-
- BasicBlock *OldEntry = nullptr;
- bool TailCallsAreMarkedTail = false;
- SmallVector<PHINode*, 8> ArgumentPHIs;
+ return MadeChange;
// If false, we cannot perform TRE on tail calls marked with the 'tail'
// attribute, because doing so would cause the stack size to increase (real
// TRE would deallocate variable sized allocas, TRE doesn't).
bool CanTRETailMarkedCall = canTRE(F);
+ TailRecursionEliminator TRE(F, TTI, AA, ORE, DTU);
+
// Change any tail recursive calls to loops.
//
// FIXME: The code generator produces really bad code when an 'escaping
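
The CanTRETailMarkedCall guard above exists because of a stack-growth hazard that is easy to see at the source level. The function below is an invented illustration, not the exact condition canTRE checks: a genuine tail call releases the caller's frame, including any dynamically sized alloca, before transferring control, whereas TRE's loop conversion keeps the frame alive, so the alloca would be issued again on every trip and the stack would keep growing.

    #include <cstddef>

    static int sum(const int *Buf, std::size_t N) {
      int S = 0;
      for (std::size_t I = 0; I < N; ++I)
        S += Buf[I];
      return S;
    }

    // Each activation owns a variable-sized stack buffer. A real tail call
    // frees Buf together with the frame; turning the 'tail'-marked call into
    // a loop back-edge would instead stack one buffer per iteration, which
    // is why such calls are skipped when canTRE(F) returns false.
    static int walk(std::size_t N, int Acc) {
      int *Buf = static_cast<int *>(__builtin_alloca(N * sizeof(int)));
      for (std::size_t I = 0; I < N; ++I)
        Buf[I] = static_cast<int>(I);
      if (N == 0)
        return Acc;
      return walk(N - 1, Acc + sum(Buf, N)); // tail-recursive call
    }
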
@@ -782,29 +836,14 @@ static bool eliminateTailRecursion(Function &F, const TargetTransformInfo *TTI,
for (Function::iterator BBI = F.begin(), E = F.end(); BBI != E; /*in loop*/) {
BasicBlock *BB = &*BBI++; // foldReturnAndProcessPred may delete BB.
if (ReturnInst *Ret = dyn_cast<ReturnInst>(BB->getTerminator())) {
- bool Change = processReturningBlock(Ret, OldEntry, TailCallsAreMarkedTail,
- ArgumentPHIs, !CanTRETailMarkedCall,
- TTI, AA, ORE, DTU);
+ bool Change = TRE.processReturningBlock(Ret, !CanTRETailMarkedCall);
if (!Change && BB->getFirstNonPHIOrDbg() == Ret)
- Change = foldReturnAndProcessPred(
- BB, Ret, OldEntry, TailCallsAreMarkedTail, ArgumentPHIs,
- !CanTRETailMarkedCall, TTI, AA, ORE, DTU);
+ Change = TRE.foldReturnAndProcessPred(Ret, !CanTRETailMarkedCall);
MadeChange |= Change;
}
}
- // If we eliminated any tail recursions, it's possible that we inserted some
- // silly PHI nodes which just merge an initial value (the incoming operand)
- // with themselves. Check to see if we did and clean up our mess if so. This
- // occurs when a function passes an argument straight through to its tail
- // call.
- for (PHINode *PN : ArgumentPHIs) {
- // If the PHI Node is a dynamic constant, replace it with the value it is.
- if (Value *PNV = SimplifyInstruction(PN, F.getParent()->getDataLayout())) {
- PN->replaceAllUsesWith(PNV);
- PN->eraseFromParent();
- }
- }
+ TRE.cleanupAndFinalize();
return MadeChange;
}
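
The "passes an argument straight through to its tail call" situation that the comment relocated into cleanupAndFinalize describes looks like this in source form (an invented example): Base is forwarded unchanged to every recursive call, so once the call becomes a loop back-edge the PHI created for that argument merges the incoming value with itself, and SimplifyInstruction folds it away in the cleanup loop.

    // Base is passed straight through, so the argument PHI built for it is
    // a merge of %Base with itself and simplifies back to %Base.
    static int countDown(int Base, int N) {
      if (N == 0)
        return Base;
      return countDown(Base, N - 1); // Base unchanged across the recursion
    }
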
@@ -838,7 +877,7 @@ struct TailCallElim : public FunctionPass {
// UpdateStrategy to Lazy if we find it profitable later.
DomTreeUpdater DTU(DT, PDT, DomTreeUpdater::UpdateStrategy::Eager);
- return eliminateTailRecursion(
+ return TailRecursionEliminator::eliminate(
F, &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F),
&getAnalysis<AAResultsWrapperPass>().getAAResults(),
&getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE(), DTU);
@@ -871,7 +910,7 @@ PreservedAnalyses TailCallElimPass::run(Function &F,
// UpdateStrategy based on some test results. It is feasible to switch the
// UpdateStrategy to Lazy if we find it profitable later.
DomTreeUpdater DTU(DT, PDT, DomTreeUpdater::UpdateStrategy::Eager);
- bool Changed = eliminateTailRecursion(F, &TTI, &AA, &ORE, DTU);
+ bool Changed = TailRecursionEliminator::eliminate(F, &TTI, &AA, &ORE, DTU);
if (!Changed)
return PreservedAnalyses::all();
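
For anyone who wants to drive the pass through the new pass manager entry point shown above, a minimal stand-alone driver could look like the following. This is a sketch built on the usual PassBuilder boilerplate, with the helper name runTailCallElim invented for the example; it assumes an llvm::Function obtained elsewhere and is not part of the commit.

    #include "llvm/IR/PassManager.h"
    #include "llvm/Passes/PassBuilder.h"
    #include "llvm/Transforms/Scalar/TailRecursionElimination.h"

    // Run TailCallElimPass on one function with freshly registered analyses.
    static bool runTailCallElim(llvm::Function &F) {
      llvm::LoopAnalysisManager LAM;
      llvm::FunctionAnalysisManager FAM;
      llvm::CGSCCAnalysisManager CGAM;
      llvm::ModuleAnalysisManager MAM;

      llvm::PassBuilder PB;
      PB.registerModuleAnalyses(MAM);
      PB.registerCGSCCAnalyses(CGAM);
      PB.registerFunctionAnalyses(FAM); // supplies TTI, AA, DT, PDT and ORE
      PB.registerLoopAnalyses(LAM);
      PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);

      llvm::FunctionPassManager FPM;
      FPM.addPass(llvm::TailCallElimPass());

      // areAllPreserved() is false when the pass reported a change.
      return !FPM.run(F, FAM).areAllPreserved();
    }
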
diff --git a/llvm/lib/Transforms/Scalar/WarnMissedTransforms.cpp b/llvm/lib/Transforms/Scalar/WarnMissedTransforms.cpp
index c8461fdc1608..7c81e6352dec 100644
--- a/llvm/lib/Transforms/Scalar/WarnMissedTransforms.cpp
+++ b/llvm/lib/Transforms/Scalar/WarnMissedTransforms.cpp
@@ -11,6 +11,7 @@
//===----------------------------------------------------------------------===//
#include "llvm/Transforms/Scalar/WarnMissedTransforms.h"
+#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/InitializePasses.h"
#include "llvm/Transforms/Utils/LoopUtils.h"