Diffstat (limited to 'llvm/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp')
 llvm/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp | 144
 1 file changed, 95 insertions(+), 49 deletions(-)
diff --git a/llvm/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp b/llvm/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp
index 7478daa2a0a5..9b81afbb4b6c 100644
--- a/llvm/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp
@@ -50,7 +50,6 @@
#include "llvm/ADT/iterator_range.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/MemoryLocation.h"
-#include "llvm/Analysis/OrderedBasicBlock.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
@@ -129,22 +128,6 @@ public:
private:
unsigned getPointerAddressSpace(Value *I);
- unsigned getAlignment(LoadInst *LI) const {
- unsigned Align = LI->getAlignment();
- if (Align != 0)
- return Align;
-
- return DL.getABITypeAlignment(LI->getType());
- }
-
- unsigned getAlignment(StoreInst *SI) const {
- unsigned Align = SI->getAlignment();
- if (Align != 0)
- return Align;
-
- return DL.getABITypeAlignment(SI->getValueOperand()->getType());
- }
-
static const unsigned MaxDepth = 3;
bool isConsecutiveAccess(Value *A, Value *B);
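
The getAlignment() helpers deleted above are superseded by the instructions' own getAlign() accessor, used later in this patch: by this point in LLVM's evolution a load or store always records a concrete alignment (the ABI alignment of the accessed type is filled in at construction when none is given), so the zero-check is no longer needed. A minimal sketch of the replacement call, not part of the patch:

    #include "llvm/IR/Instructions.h"
    #include <cstdint>

    // getAlign() never reports "unknown": a LoadInst constructed without an
    // explicit alignment already carries the ABI alignment of its type.
    uint64_t loadAlignInBytes(const llvm::LoadInst *LI) {
      return LI->getAlign().value();
    }
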
@@ -447,20 +430,78 @@ bool Vectorizer::lookThroughComplexAddresses(Value *PtrA, Value *PtrB,
// Now we need to prove that adding IdxDiff to ValA won't overflow.
bool Safe = false;
+ auto CheckFlags = [](Instruction *I, bool Signed) {
+ BinaryOperator *BinOpI = cast<BinaryOperator>(I);
+ return (Signed && BinOpI->hasNoSignedWrap()) ||
+ (!Signed && BinOpI->hasNoUnsignedWrap());
+ };
+
// First attempt: if OpB is an add with NSW/NUW, and OpB is IdxDiff added to
// ValA, we're okay.
if (OpB->getOpcode() == Instruction::Add &&
isa<ConstantInt>(OpB->getOperand(1)) &&
- IdxDiff.sle(cast<ConstantInt>(OpB->getOperand(1))->getSExtValue())) {
- if (Signed)
- Safe = cast<BinaryOperator>(OpB)->hasNoSignedWrap();
- else
- Safe = cast<BinaryOperator>(OpB)->hasNoUnsignedWrap();
+ IdxDiff.sle(cast<ConstantInt>(OpB->getOperand(1))->getSExtValue()) &&
+ CheckFlags(OpB, Signed))
+ Safe = true;
+
+ // Second attempt: if both OpA and OpB are adds with NSW/NUW and with
+ // the same LHS operand, we can guarantee that the transformation is
+ // safe if we can prove that OpA won't overflow when IdxDiff is added
+ // to the RHS of OpA.
+ // For example:
+ // %tmp7 = add nsw i32 %tmp2, %v0
+ // %tmp8 = sext i32 %tmp7 to i64
+ // ...
+ // %tmp11 = add nsw i32 %v0, 1
+ // %tmp12 = add nsw i32 %tmp2, %tmp11
+ // %tmp13 = sext i32 %tmp12 to i64
+ //
+ // Both %tmp7 and %tmp12 have the nsw flag and the same first
+ // operand, %tmp2. It's guaranteed that adding 1 to %tmp7 won't
+ // overflow because %tmp11 adds 1 to %v0 and both %tmp11 and
+ // %tmp12 have the nsw flag.
+ OpA = dyn_cast<Instruction>(ValA);
+ if (!Safe && OpA && OpA->getOpcode() == Instruction::Add &&
+ OpB->getOpcode() == Instruction::Add &&
+ OpA->getOperand(0) == OpB->getOperand(0) && CheckFlags(OpA, Signed) &&
+ CheckFlags(OpB, Signed)) {
+ Value *RHSA = OpA->getOperand(1);
+ Value *RHSB = OpB->getOperand(1);
+ Instruction *OpRHSA = dyn_cast<Instruction>(RHSA);
+ Instruction *OpRHSB = dyn_cast<Instruction>(RHSB);
+ // Match `x +nsw/nuw y` and `x +nsw/nuw (y +nsw/nuw IdxDiff)`.
+ if (OpRHSB && OpRHSB->getOpcode() == Instruction::Add &&
+ CheckFlags(OpRHSB, Signed) && isa<ConstantInt>(OpRHSB->getOperand(1))) {
+ int64_t CstVal = cast<ConstantInt>(OpRHSB->getOperand(1))->getSExtValue();
+ if (OpRHSB->getOperand(0) == RHSA && IdxDiff.getSExtValue() == CstVal)
+ Safe = true;
+ }
+ // Match `x +nsw/nuw (y +nsw/nuw -IdxDiff)` and `x +nsw/nuw y`.
+ if (OpRHSA && OpRHSA->getOpcode() == Instruction::Add &&
+ CheckFlags(OpRHSA, Signed) && isa<ConstantInt>(OpRHSA->getOperand(1))) {
+ int64_t CstVal = cast<ConstantInt>(OpRHSA->getOperand(1))->getSExtValue();
+ if (OpRHSA->getOperand(0) == RHSB && IdxDiff.getSExtValue() == -CstVal)
+ Safe = true;
+ }
+ // Match `x +nsw/nuw (y +nsw/nuw c)` and
+ // `x +nsw/nuw (y +nsw/nuw (c + IdxDiff))`.
+ if (OpRHSA && OpRHSB && OpRHSA->getOpcode() == Instruction::Add &&
+ OpRHSB->getOpcode() == Instruction::Add && CheckFlags(OpRHSA, Signed) &&
+ CheckFlags(OpRHSB, Signed) && isa<ConstantInt>(OpRHSA->getOperand(1)) &&
+ isa<ConstantInt>(OpRHSB->getOperand(1))) {
+ int64_t CstValA =
+ cast<ConstantInt>(OpRHSA->getOperand(1))->getSExtValue();
+ int64_t CstValB =
+ cast<ConstantInt>(OpRHSB->getOperand(1))->getSExtValue();
+ if (OpRHSA->getOperand(0) == OpRHSB->getOperand(0) &&
+ IdxDiff.getSExtValue() == (CstValB - CstValA))
+ Safe = true;
+ }
}
unsigned BitWidth = ValA->getType()->getScalarSizeInBits();
- // Second attempt:
+ // Third attempt:
// If all set bits of IdxDiff or any higher order bit other than the sign bit
// are known to be zero in ValA, we can add Diff to it while guaranteeing no
// overflow of any sort.
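
The second attempt's three matchers all reduce to the same arithmetic fact: if A = x + y and B = x + (y + d), and none of the adds wraps, then B is exactly d past A. A plain C++ sketch of that identity with stand-in values (illustrative only, no LLVM API):

    #include <cassert>
    #include <cstdint>

    int main() {
      int64_t x = 100, y = 7, d = 1; // stand-ins for %tmp2, %v0, IdxDiff
      int64_t A = x + y;             // %tmp7  = add nsw i32 %tmp2, %v0
      int64_t B = x + (y + d);       // %tmp12 = add nsw i32 %tmp2, %tmp11
      assert(B - A == d);            // holds whenever no add wraps
      return 0;
    }
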
@@ -503,7 +544,6 @@ bool Vectorizer::lookThroughSelects(Value *PtrA, Value *PtrB,
}
void Vectorizer::reorder(Instruction *I) {
- OrderedBasicBlock OBB(I->getParent());
SmallPtrSet<Instruction *, 16> InstructionsToMove;
SmallVector<Instruction *, 16> Worklist;
@@ -521,7 +561,7 @@ void Vectorizer::reorder(Instruction *I) {
if (IM->getParent() != I->getParent())
continue;
- if (!OBB.dominates(IM, I)) {
+ if (!IM->comesBefore(I)) {
InstructionsToMove.insert(IM);
Worklist.push_back(IM);
}
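
Instruction::comesBefore() subsumes the deleted OrderedBasicBlock throughout this patch: it lazily numbers the instructions of the parent block and compares positions, so repeated queries stay cheap until the block is mutated. A minimal sketch of the query, assuming both instructions share a basic block (comesBefore asserts this):

    #include "llvm/IR/Instruction.h"
    #include <cassert>

    // True when A is strictly earlier than B inside their shared block.
    bool precedes(const llvm::Instruction *A, const llvm::Instruction *B) {
      assert(A->getParent() == B->getParent() && "same-block query only");
      return A->comesBefore(B);
    }
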
@@ -637,8 +677,6 @@ Vectorizer::getVectorizablePrefix(ArrayRef<Instruction *> Chain) {
}
}
- OrderedBasicBlock OBB(Chain[0]->getParent());
-
// Loop until we find an instruction in ChainInstrs that we can't vectorize.
unsigned ChainInstrIdx = 0;
Instruction *BarrierMemoryInstr = nullptr;
@@ -648,14 +686,14 @@ Vectorizer::getVectorizablePrefix(ArrayRef<Instruction *> Chain) {
// If a barrier memory instruction was found, chain instructions that follow
// will not be added to the valid prefix.
- if (BarrierMemoryInstr && OBB.dominates(BarrierMemoryInstr, ChainInstr))
+ if (BarrierMemoryInstr && BarrierMemoryInstr->comesBefore(ChainInstr))
break;
// Check (in BB order) if any instruction prevents ChainInstr from being
// vectorized. Find and store the first such "conflicting" instruction.
for (Instruction *MemInstr : MemoryInstrs) {
// If a barrier memory instruction was found, do not check past it.
- if (BarrierMemoryInstr && OBB.dominates(BarrierMemoryInstr, MemInstr))
+ if (BarrierMemoryInstr && BarrierMemoryInstr->comesBefore(MemInstr))
break;
auto *MemLoad = dyn_cast<LoadInst>(MemInstr);
@@ -674,12 +712,12 @@ Vectorizer::getVectorizablePrefix(ArrayRef<Instruction *> Chain) {
// vectorize it (the vectorized load is inserted at the location of the
// first load in the chain).
if (isa<StoreInst>(MemInstr) && ChainLoad &&
- (IsInvariantLoad(ChainLoad) || OBB.dominates(ChainLoad, MemInstr)))
+ (IsInvariantLoad(ChainLoad) || ChainLoad->comesBefore(MemInstr)))
continue;
// Same case, but in reverse.
if (MemLoad && isa<StoreInst>(ChainInstr) &&
- (IsInvariantLoad(MemLoad) || OBB.dominates(MemLoad, ChainInstr)))
+ (IsInvariantLoad(MemLoad) || MemLoad->comesBefore(ChainInstr)))
continue;
if (!AA.isNoAlias(MemoryLocation::get(MemInstr),
@@ -705,7 +743,7 @@ Vectorizer::getVectorizablePrefix(ArrayRef<Instruction *> Chain) {
// the basic block.
if (IsLoadChain && BarrierMemoryInstr) {
// The BarrierMemoryInstr is a store that precedes ChainInstr.
- assert(OBB.dominates(BarrierMemoryInstr, ChainInstr));
+ assert(BarrierMemoryInstr->comesBefore(ChainInstr));
break;
}
}
@@ -961,7 +999,7 @@ bool Vectorizer::vectorizeStoreChain(
unsigned VecRegSize = TTI.getLoadStoreVecRegBitWidth(AS);
unsigned VF = VecRegSize / Sz;
unsigned ChainSize = Chain.size();
- unsigned Alignment = getAlignment(S0);
+ Align Alignment = S0->getAlign();
if (!isPowerOf2_32(Sz) || VF < 2 || ChainSize < 2) {
InstructionsProcessed->insert(Chain.begin(), Chain.end());
@@ -992,10 +1030,10 @@ bool Vectorizer::vectorizeStoreChain(
VectorType *VecTy;
VectorType *VecStoreTy = dyn_cast<VectorType>(StoreTy);
if (VecStoreTy)
- VecTy = VectorType::get(StoreTy->getScalarType(),
- Chain.size() * VecStoreTy->getNumElements());
+ VecTy = FixedVectorType::get(StoreTy->getScalarType(),
+ Chain.size() * VecStoreTy->getNumElements());
else
- VecTy = VectorType::get(StoreTy, Chain.size());
+ VecTy = FixedVectorType::get(StoreTy, Chain.size());
// If it's more than the max vector size or the target has a better
// vector factor, break it into two pieces.
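
FixedVectorType::get() replaces the old VectorType::get() overload here and in the load-chain hunk below; the explicit class documents that the vector has a fixed element count rather than a scalable one. A hedged sketch of the factory call (Ctx is an assumed LLVMContext supplied by the caller):

    #include "llvm/IR/DerivedTypes.h"

    // A fixed-width (non-scalable) vector of N x i32 -- the shape of the
    // VecTy built above for a chain of N scalar accesses.
    llvm::FixedVectorType *makeI32Vec(llvm::LLVMContext &Ctx, unsigned N) {
      return llvm::FixedVectorType::get(llvm::Type::getInt32Ty(Ctx), N);
    }
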
@@ -1019,18 +1057,20 @@ bool Vectorizer::vectorizeStoreChain(
InstructionsProcessed->insert(Chain.begin(), Chain.end());
// If the store is going to be misaligned, don't vectorize it.
- if (accessIsMisaligned(SzInBytes, AS, Alignment)) {
+ if (accessIsMisaligned(SzInBytes, AS, Alignment.value())) {
if (S0->getPointerAddressSpace() != DL.getAllocaAddrSpace()) {
auto Chains = splitOddVectorElts(Chain, Sz);
return vectorizeStoreChain(Chains.first, InstructionsProcessed) |
vectorizeStoreChain(Chains.second, InstructionsProcessed);
}
- unsigned NewAlign = getOrEnforceKnownAlignment(S0->getPointerOperand(),
- StackAdjustedAlignment,
- DL, S0, nullptr, &DT);
- if (NewAlign != 0)
+ Align NewAlign = getOrEnforceKnownAlignment(S0->getPointerOperand(),
+ Align(StackAdjustedAlignment),
+ DL, S0, nullptr, &DT);
+ if (NewAlign >= Alignment)
Alignment = NewAlign;
+ else
+ return false;
}
if (!TTI.isLegalToVectorizeStoreChain(SzInBytes, Alignment, AS)) {
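
Note the behavioral tightening in this hunk (mirrored in the load path below): previously any nonzero result from getOrEnforceKnownAlignment was accepted, while the new code keeps the enforced alignment only when it is at least as strong as the one the store already guaranteed, and otherwise gives up on the chain. llvm::Align is totally ordered by byte count, so the check is a plain comparison; a small illustrative sketch:

    #include "llvm/Support/Alignment.h"

    // Keep the stronger of two alignments; Align compares by byte value.
    llvm::Align strongest(llvm::Align Current, llvm::Align Enforced) {
      return Enforced >= Current ? Enforced : Current;
    }
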
@@ -1112,7 +1152,7 @@ bool Vectorizer::vectorizeLoadChain(
unsigned VecRegSize = TTI.getLoadStoreVecRegBitWidth(AS);
unsigned VF = VecRegSize / Sz;
unsigned ChainSize = Chain.size();
- unsigned Alignment = getAlignment(L0);
+ Align Alignment = L0->getAlign();
if (!isPowerOf2_32(Sz) || VF < 2 || ChainSize < 2) {
InstructionsProcessed->insert(Chain.begin(), Chain.end());
@@ -1142,10 +1182,10 @@ bool Vectorizer::vectorizeLoadChain(
VectorType *VecTy;
VectorType *VecLoadTy = dyn_cast<VectorType>(LoadTy);
if (VecLoadTy)
- VecTy = VectorType::get(LoadTy->getScalarType(),
- Chain.size() * VecLoadTy->getNumElements());
+ VecTy = FixedVectorType::get(LoadTy->getScalarType(),
+ Chain.size() * VecLoadTy->getNumElements());
else
- VecTy = VectorType::get(LoadTy, Chain.size());
+ VecTy = FixedVectorType::get(LoadTy, Chain.size());
// If it's more than the max vector size or the target has a better
// vector factor, break it into two pieces.
@@ -1162,15 +1202,20 @@ bool Vectorizer::vectorizeLoadChain(
InstructionsProcessed->insert(Chain.begin(), Chain.end());
// If the load is going to be misaligned, don't vectorize it.
- if (accessIsMisaligned(SzInBytes, AS, Alignment)) {
+ if (accessIsMisaligned(SzInBytes, AS, Alignment.value())) {
if (L0->getPointerAddressSpace() != DL.getAllocaAddrSpace()) {
auto Chains = splitOddVectorElts(Chain, Sz);
return vectorizeLoadChain(Chains.first, InstructionsProcessed) |
vectorizeLoadChain(Chains.second, InstructionsProcessed);
}
- Alignment = getOrEnforceKnownAlignment(
- L0->getPointerOperand(), StackAdjustedAlignment, DL, L0, nullptr, &DT);
+ Align NewAlign = getOrEnforceKnownAlignment(L0->getPointerOperand(),
+ Align(StackAdjustedAlignment),
+ DL, L0, nullptr, &DT);
+ if (NewAlign >= Alignment)
+ Alignment = NewAlign;
+ else
+ return false;
}
if (!TTI.isLegalToVectorizeLoadChain(SzInBytes, Alignment, AS)) {
@@ -1194,7 +1239,8 @@ bool Vectorizer::vectorizeLoadChain(
Value *Bitcast =
Builder.CreateBitCast(L0->getPointerOperand(), VecTy->getPointerTo(AS));
- LoadInst *LI = Builder.CreateAlignedLoad(VecTy, Bitcast, Alignment);
+ LoadInst *LI =
+ Builder.CreateAlignedLoad(VecTy, Bitcast, MaybeAlign(Alignment));
propagateMetadata(LI, Chain);
if (VecLoadTy) {
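
For reference, the CreateAlignedLoad overload used above takes a MaybeAlign, which wraps an optional alignment; wrapping a known llvm::Align in it is lossless. A minimal sketch where Builder, VecTy, Ptr, and Alignment stand in for the values in the surrounding hunk:

    #include "llvm/IR/IRBuilder.h"

    // Mirrors the call above: the builder attaches the wrapped alignment
    // to the newly created vector load.
    llvm::LoadInst *makeVecLoad(llvm::IRBuilder<> &Builder, llvm::Type *VecTy,
                                llvm::Value *Ptr, llvm::Align Alignment) {
      return Builder.CreateAlignedLoad(VecTy, Ptr, llvm::MaybeAlign(Alignment));
    }
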