Diffstat (limited to 'llvm/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp')
-rw-r--r-- | llvm/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp | 144 |
1 file changed, 95 insertions, 49 deletions
diff --git a/llvm/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp b/llvm/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp
index 7478daa2a0a5..9b81afbb4b6c 100644
--- a/llvm/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp
@@ -50,7 +50,6 @@
 #include "llvm/ADT/iterator_range.h"
 #include "llvm/Analysis/AliasAnalysis.h"
 #include "llvm/Analysis/MemoryLocation.h"
-#include "llvm/Analysis/OrderedBasicBlock.h"
 #include "llvm/Analysis/ScalarEvolution.h"
 #include "llvm/Analysis/TargetTransformInfo.h"
 #include "llvm/Analysis/ValueTracking.h"
@@ -129,22 +128,6 @@ public:
 private:
   unsigned getPointerAddressSpace(Value *I);
 
-  unsigned getAlignment(LoadInst *LI) const {
-    unsigned Align = LI->getAlignment();
-    if (Align != 0)
-      return Align;
-
-    return DL.getABITypeAlignment(LI->getType());
-  }
-
-  unsigned getAlignment(StoreInst *SI) const {
-    unsigned Align = SI->getAlignment();
-    if (Align != 0)
-      return Align;
-
-    return DL.getABITypeAlignment(SI->getValueOperand()->getType());
-  }
-
   static const unsigned MaxDepth = 3;
 
   bool isConsecutiveAccess(Value *A, Value *B);
@@ -447,20 +430,78 @@ bool Vectorizer::lookThroughComplexAddresses(Value *PtrA, Value *PtrB,
 
   // Now we need to prove that adding IdxDiff to ValA won't overflow.
   bool Safe = false;
+  auto CheckFlags = [](Instruction *I, bool Signed) {
+    BinaryOperator *BinOpI = cast<BinaryOperator>(I);
+    return (Signed && BinOpI->hasNoSignedWrap()) ||
+           (!Signed && BinOpI->hasNoUnsignedWrap());
+  };
+
   // First attempt: if OpB is an add with NSW/NUW, and OpB is IdxDiff added to
   // ValA, we're okay.
   if (OpB->getOpcode() == Instruction::Add &&
       isa<ConstantInt>(OpB->getOperand(1)) &&
-      IdxDiff.sle(cast<ConstantInt>(OpB->getOperand(1))->getSExtValue())) {
-    if (Signed)
-      Safe = cast<BinaryOperator>(OpB)->hasNoSignedWrap();
-    else
-      Safe = cast<BinaryOperator>(OpB)->hasNoUnsignedWrap();
+      IdxDiff.sle(cast<ConstantInt>(OpB->getOperand(1))->getSExtValue()) &&
+      CheckFlags(OpB, Signed))
+    Safe = true;
+
+  // Second attempt: if both OpA and OpB are adds with NSW/NUW and with the
+  // same LHS operand, we can guarantee that the transformation is safe if
+  // we can prove that OpA won't overflow when IdxDiff is added to the RHS
+  // of OpA.
+  // For example:
+  //  %tmp7 = add nsw i32 %tmp2, %v0
+  //  %tmp8 = sext i32 %tmp7 to i64
+  //  ...
+  //  %tmp11 = add nsw i32 %v0, 1
+  //  %tmp12 = add nsw i32 %tmp2, %tmp11
+  //  %tmp13 = sext i32 %tmp12 to i64
+  //
+  //  Both %tmp7 and %tmp12 have the nsw flag and the same first operand,
+  //  %tmp2. It is guaranteed that adding 1 to %tmp7 won't overflow
+  //  because %tmp11 adds 1 to %v0 and both %tmp11 and %tmp12 have the
+  //  nsw flag.
+  OpA = dyn_cast<Instruction>(ValA);
+  if (!Safe && OpA && OpA->getOpcode() == Instruction::Add &&
+      OpB->getOpcode() == Instruction::Add &&
+      OpA->getOperand(0) == OpB->getOperand(0) && CheckFlags(OpA, Signed) &&
+      CheckFlags(OpB, Signed)) {
+    Value *RHSA = OpA->getOperand(1);
+    Value *RHSB = OpB->getOperand(1);
+    Instruction *OpRHSA = dyn_cast<Instruction>(RHSA);
+    Instruction *OpRHSB = dyn_cast<Instruction>(RHSB);
+    // Match `x +nsw/nuw y` and `x +nsw/nuw (y +nsw/nuw IdxDiff)`.
+    if (OpRHSB && OpRHSB->getOpcode() == Instruction::Add &&
+        CheckFlags(OpRHSB, Signed) && isa<ConstantInt>(OpRHSB->getOperand(1))) {
+      int64_t CstVal = cast<ConstantInt>(OpRHSB->getOperand(1))->getSExtValue();
+      if (OpRHSB->getOperand(0) == RHSA && IdxDiff.getSExtValue() == CstVal)
+        Safe = true;
+    }
+    // Match `x +nsw/nuw (y +nsw/nuw -IdxDiff)` and `x +nsw/nuw y`.
+    if (OpRHSA && OpRHSA->getOpcode() == Instruction::Add &&
+        CheckFlags(OpRHSA, Signed) && isa<ConstantInt>(OpRHSA->getOperand(1))) {
+      int64_t CstVal = cast<ConstantInt>(OpRHSA->getOperand(1))->getSExtValue();
+      if (OpRHSA->getOperand(0) == RHSB && IdxDiff.getSExtValue() == -CstVal)
+        Safe = true;
+    }
+    // Match `x +nsw/nuw (y +nsw/nuw c)` and
+    // `x +nsw/nuw (y +nsw/nuw (c + IdxDiff))`.
+    if (OpRHSA && OpRHSB && OpRHSA->getOpcode() == Instruction::Add &&
+        OpRHSB->getOpcode() == Instruction::Add && CheckFlags(OpRHSA, Signed) &&
+        CheckFlags(OpRHSB, Signed) && isa<ConstantInt>(OpRHSA->getOperand(1)) &&
+        isa<ConstantInt>(OpRHSB->getOperand(1))) {
+      int64_t CstValA =
+          cast<ConstantInt>(OpRHSA->getOperand(1))->getSExtValue();
+      int64_t CstValB =
+          cast<ConstantInt>(OpRHSB->getOperand(1))->getSExtValue();
+      if (OpRHSA->getOperand(0) == OpRHSB->getOperand(0) &&
+          IdxDiff.getSExtValue() == (CstValB - CstValA))
+        Safe = true;
+    }
   }
 
   unsigned BitWidth = ValA->getType()->getScalarSizeInBits();
 
-  // Second attempt:
+  // Third attempt:
   // If all set bits of IdxDiff or any higher order bit other than the sign bit
   // are known to be zero in ValA, we can add Diff to it while guaranteeing no
   // overflow of any sort.
@@ -503,7 +544,6 @@ bool Vectorizer::lookThroughSelects(Value *PtrA, Value *PtrB,
 }
 
 void Vectorizer::reorder(Instruction *I) {
-  OrderedBasicBlock OBB(I->getParent());
   SmallPtrSet<Instruction *, 16> InstructionsToMove;
   SmallVector<Instruction *, 16> Worklist;
 
@@ -521,7 +561,7 @@ void Vectorizer::reorder(Instruction *I) {
       if (IM->getParent() != I->getParent())
        continue;
 
-      if (!OBB.dominates(IM, I)) {
+      if (!IM->comesBefore(I)) {
        InstructionsToMove.insert(IM);
        Worklist.push_back(IM);
      }
@@ -637,8 +677,6 @@ Vectorizer::getVectorizablePrefix(ArrayRef<Instruction *> Chain) {
    }
  }
 
-  OrderedBasicBlock OBB(Chain[0]->getParent());
-
   // Loop until we find an instruction in ChainInstrs that we can't vectorize.
   unsigned ChainInstrIdx = 0;
   Instruction *BarrierMemoryInstr = nullptr;
@@ -648,14 +686,14 @@ Vectorizer::getVectorizablePrefix(ArrayRef<Instruction *> Chain) {
 
    // If a barrier memory instruction was found, chain instructions that follow
    // will not be added to the valid prefix.
-    if (BarrierMemoryInstr && OBB.dominates(BarrierMemoryInstr, ChainInstr))
+    if (BarrierMemoryInstr && BarrierMemoryInstr->comesBefore(ChainInstr))
      break;
 
    // Check (in BB order) if any instruction prevents ChainInstr from being
    // vectorized. Find and store the first such "conflicting" instruction.
    for (Instruction *MemInstr : MemoryInstrs) {
      // If a barrier memory instruction was found, do not check past it.
-      if (BarrierMemoryInstr && OBB.dominates(BarrierMemoryInstr, MemInstr))
+      if (BarrierMemoryInstr && BarrierMemoryInstr->comesBefore(MemInstr))
        break;
 
      auto *MemLoad = dyn_cast<LoadInst>(MemInstr);
@@ -674,12 +712,12 @@ Vectorizer::getVectorizablePrefix(ArrayRef<Instruction *> Chain) {
      // vectorize it (the vectorized load is inserted at the location of the
      // first load in the chain).
      if (isa<StoreInst>(MemInstr) && ChainLoad &&
-          (IsInvariantLoad(ChainLoad) || OBB.dominates(ChainLoad, MemInstr)))
+          (IsInvariantLoad(ChainLoad) || ChainLoad->comesBefore(MemInstr)))
        continue;
 
      // Same case, but in reverse.
      if (MemLoad && isa<StoreInst>(ChainInstr) &&
-          (IsInvariantLoad(MemLoad) || OBB.dominates(MemLoad, ChainInstr)))
+          (IsInvariantLoad(MemLoad) || MemLoad->comesBefore(ChainInstr)))
        continue;
 
      if (!AA.isNoAlias(MemoryLocation::get(MemInstr),
@@ -705,7 +743,7 @@ Vectorizer::getVectorizablePrefix(ArrayRef<Instruction *> Chain) {
    // the basic block.
    if (IsLoadChain && BarrierMemoryInstr) {
      // The BarrierMemoryInstr is a store that precedes ChainInstr.
-      assert(OBB.dominates(BarrierMemoryInstr, ChainInstr));
+      assert(BarrierMemoryInstr->comesBefore(ChainInstr));
      break;
    }
  }
@@ -961,7 +999,7 @@ bool Vectorizer::vectorizeStoreChain(
   unsigned VecRegSize = TTI.getLoadStoreVecRegBitWidth(AS);
   unsigned VF = VecRegSize / Sz;
   unsigned ChainSize = Chain.size();
-  unsigned Alignment = getAlignment(S0);
+  Align Alignment = S0->getAlign();
 
   if (!isPowerOf2_32(Sz) || VF < 2 || ChainSize < 2) {
     InstructionsProcessed->insert(Chain.begin(), Chain.end());
@@ -992,10 +1030,10 @@ bool Vectorizer::vectorizeStoreChain(
   VectorType *VecTy;
   VectorType *VecStoreTy = dyn_cast<VectorType>(StoreTy);
   if (VecStoreTy)
-    VecTy = VectorType::get(StoreTy->getScalarType(),
-                            Chain.size() * VecStoreTy->getNumElements());
+    VecTy = FixedVectorType::get(StoreTy->getScalarType(),
+                                 Chain.size() * VecStoreTy->getNumElements());
   else
-    VecTy = VectorType::get(StoreTy, Chain.size());
+    VecTy = FixedVectorType::get(StoreTy, Chain.size());
 
   // If it's more than the max vector size or the target has a better
   // vector factor, break it into two pieces.
@@ -1019,18 +1057,20 @@ bool Vectorizer::vectorizeStoreChain(
   InstructionsProcessed->insert(Chain.begin(), Chain.end());
 
   // If the store is going to be misaligned, don't vectorize it.
-  if (accessIsMisaligned(SzInBytes, AS, Alignment)) {
+  if (accessIsMisaligned(SzInBytes, AS, Alignment.value())) {
     if (S0->getPointerAddressSpace() != DL.getAllocaAddrSpace()) {
       auto Chains = splitOddVectorElts(Chain, Sz);
       return vectorizeStoreChain(Chains.first, InstructionsProcessed) |
              vectorizeStoreChain(Chains.second, InstructionsProcessed);
     }
 
-    unsigned NewAlign = getOrEnforceKnownAlignment(S0->getPointerOperand(),
-                                                   StackAdjustedAlignment,
-                                                   DL, S0, nullptr, &DT);
-    if (NewAlign != 0)
+    Align NewAlign = getOrEnforceKnownAlignment(S0->getPointerOperand(),
+                                                Align(StackAdjustedAlignment),
+                                                DL, S0, nullptr, &DT);
+    if (NewAlign >= Alignment)
       Alignment = NewAlign;
+    else
+      return false;
   }
 
   if (!TTI.isLegalToVectorizeStoreChain(SzInBytes, Alignment, AS)) {
@@ -1112,7 +1152,7 @@ bool Vectorizer::vectorizeLoadChain(
   unsigned VecRegSize = TTI.getLoadStoreVecRegBitWidth(AS);
   unsigned VF = VecRegSize / Sz;
   unsigned ChainSize = Chain.size();
-  unsigned Alignment = getAlignment(L0);
+  Align Alignment = L0->getAlign();
 
   if (!isPowerOf2_32(Sz) || VF < 2 || ChainSize < 2) {
     InstructionsProcessed->insert(Chain.begin(), Chain.end());
@@ -1142,10 +1182,10 @@ bool Vectorizer::vectorizeLoadChain(
   VectorType *VecTy;
   VectorType *VecLoadTy = dyn_cast<VectorType>(LoadTy);
   if (VecLoadTy)
-    VecTy = VectorType::get(LoadTy->getScalarType(),
-                            Chain.size() * VecLoadTy->getNumElements());
+    VecTy = FixedVectorType::get(LoadTy->getScalarType(),
+                                 Chain.size() * VecLoadTy->getNumElements());
   else
-    VecTy = VectorType::get(LoadTy, Chain.size());
+    VecTy = FixedVectorType::get(LoadTy, Chain.size());
 
   // If it's more than the max vector size or the target has a better
   // vector factor, break it into two pieces.
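Note on the alignment hunks above (the same pattern appears again for loads in the hunk that follows): getOrEnforceKnownAlignment now returns an Align, and its result is only accepted when it is at least as strong as the alignment the access already carries; otherwise the chain is rejected instead of being vectorized with a weaker alignment. Below is a minimal sketch of that fallback, assuming the LLVM C++ API of this revision; the helper name tryRaiseAlignment and its factoring are illustrative only, not code from the patch.

// Sketch only: the shared shape of the new misaligned-access fallback.
// `tryRaiseAlignment` is a hypothetical helper, not part of this commit.
#include "llvm/ADT/Optional.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Instruction.h"
#include "llvm/Support/Alignment.h"
#include "llvm/Transforms/Utils/Local.h"

using namespace llvm;

// Try to raise the known alignment of the chain's leading access to Target
// (e.g. StackAdjustedAlignment). Returns the alignment to vectorize with, or
// None when the best provable alignment is weaker than what the access
// already has (the case where the patched code now gives up and returns
// false instead of vectorizing).
static Optional<Align> tryRaiseAlignment(Value *Ptr, Align Current,
                                         Align Target, const DataLayout &DL,
                                         Instruction *CtxI, DominatorTree &DT) {
  Align NewAlign =
      getOrEnforceKnownAlignment(Ptr, Target, DL, CtxI, /*AC=*/nullptr, &DT);
  if (NewAlign >= Current)
    return NewAlign;
  return None;
}

Before this change the load path accepted whatever alignment getOrEnforceKnownAlignment reported; with the Align type both the store and load paths keep the stronger of the two values or bail out.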
@@ -1162,15 +1202,20 @@ bool Vectorizer::vectorizeLoadChain(
   InstructionsProcessed->insert(Chain.begin(), Chain.end());
 
   // If the load is going to be misaligned, don't vectorize it.
-  if (accessIsMisaligned(SzInBytes, AS, Alignment)) {
+  if (accessIsMisaligned(SzInBytes, AS, Alignment.value())) {
     if (L0->getPointerAddressSpace() != DL.getAllocaAddrSpace()) {
       auto Chains = splitOddVectorElts(Chain, Sz);
       return vectorizeLoadChain(Chains.first, InstructionsProcessed) |
              vectorizeLoadChain(Chains.second, InstructionsProcessed);
     }
 
-    Alignment = getOrEnforceKnownAlignment(
-        L0->getPointerOperand(), StackAdjustedAlignment, DL, L0, nullptr, &DT);
+    Align NewAlign = getOrEnforceKnownAlignment(L0->getPointerOperand(),
+                                                Align(StackAdjustedAlignment),
+                                                DL, L0, nullptr, &DT);
+    if (NewAlign >= Alignment)
+      Alignment = NewAlign;
+    else
+      return false;
   }
 
   if (!TTI.isLegalToVectorizeLoadChain(SzInBytes, Alignment, AS)) {
@@ -1194,7 +1239,8 @@ bool Vectorizer::vectorizeLoadChain(
 
   Value *Bitcast = Builder.CreateBitCast(L0->getPointerOperand(),
                                          VecTy->getPointerTo(AS));
-  LoadInst *LI = Builder.CreateAlignedLoad(VecTy, Bitcast, Alignment);
+  LoadInst *LI =
+      Builder.CreateAlignedLoad(VecTy, Bitcast, MaybeAlign(Alignment));
   propagateMetadata(LI, Chain);
 
   if (VecLoadTy) {
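Back in the lookThroughComplexAddresses hunk at the top of the diff, the new "second attempt" recognizes two indices that differ by IdxDiff through adds that share an LHS and all carry nsw/nuw. Below is a compressed sketch of the first of those three matchers, written against the LLVM C++ API using PatternMatch for brevity; the helper name matchesSharedBaseAddDiff is invented here, and the committed code walks the operands explicitly instead of using matchers.

// Sketch only: the gist of the first nested-add case.
// `matchesSharedBaseAddDiff` is a hypothetical helper, not part of this commit.
#include "llvm/IR/Constants.h"
#include "llvm/IR/Operator.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/IR/Value.h"

using namespace llvm;

// Returns true when ValA is `x + y`, ValB is `x + (y + Diff)`, and every add
// carries the requested no-wrap flag (nsw when Signed, nuw otherwise), which
// is what lets the vectorizer treat the two addresses as Diff elements apart.
static bool matchesSharedBaseAddDiff(Value *ValA, Value *ValB, int64_t Diff,
                                     bool Signed) {
  auto HasFlag = [Signed](Value *V) {
    auto *BO = dyn_cast<OverflowingBinaryOperator>(V);
    return BO && (Signed ? BO->hasNoSignedWrap() : BO->hasNoUnsignedWrap());
  };
  using namespace PatternMatch;
  Value *X = nullptr, *Y = nullptr, *Inner = nullptr;
  ConstantInt *C = nullptr;
  // ValA must be `x + y` with the right wrap flag.
  if (!match(ValA, m_Add(m_Value(X), m_Value(Y))) || !HasFlag(ValA))
    return false;
  // ValB must be `x + inner`, where inner is `y + C`, all with the flag.
  if (!match(ValB, m_Add(m_Specific(X), m_Value(Inner))) || !HasFlag(ValB))
    return false;
  if (!match(Inner, m_Add(m_Specific(Y), m_ConstantInt(C))) || !HasFlag(Inner))
    return false;
  return C->getSExtValue() == Diff;
}

The other two matchers in the patch handle the symmetric case (the constant sitting on ValA's side, with IdxDiff negated) and the case where both sides carry a constant whose difference equals IdxDiff.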