Diffstat (limited to 'llvm/lib/Transforms/Vectorize/VectorCombine.cpp')
-rw-r--r-- | llvm/lib/Transforms/Vectorize/VectorCombine.cpp | 290
1 file changed, 239 insertions, 51 deletions
diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index d18bcd34620c..57b11e9414ba 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -31,10 +31,12 @@
 #include "llvm/Transforms/Utils/Local.h"
 #include "llvm/Transforms/Vectorize.h"
 
+#define DEBUG_TYPE "vector-combine"
+#include "llvm/Transforms/Utils/InstructionWorklist.h"
+
 using namespace llvm;
 using namespace llvm::PatternMatch;
 
-#define DEBUG_TYPE "vector-combine"
 STATISTIC(NumVecLoad, "Number of vector loads formed");
 STATISTIC(NumVecCmp, "Number of vector compares formed");
 STATISTIC(NumVecBO, "Number of vector binops formed");
@@ -61,8 +63,10 @@ namespace {
 class VectorCombine {
 public:
   VectorCombine(Function &F, const TargetTransformInfo &TTI,
-                const DominatorTree &DT, AAResults &AA, AssumptionCache &AC)
-      : F(F), Builder(F.getContext()), TTI(TTI), DT(DT), AA(AA), AC(AC) {}
+                const DominatorTree &DT, AAResults &AA, AssumptionCache &AC,
+                bool ScalarizationOnly)
+      : F(F), Builder(F.getContext()), TTI(TTI), DT(DT), AA(AA), AC(AC),
+        ScalarizationOnly(ScalarizationOnly) {}
 
   bool run();
 
@@ -74,12 +78,18 @@ private:
   AAResults &AA;
   AssumptionCache &AC;
 
+  /// If true only perform scalarization combines and do not introduce new
+  /// vector operations.
+  bool ScalarizationOnly;
+
+  InstructionWorklist Worklist;
+
   bool vectorizeLoadInsert(Instruction &I);
   ExtractElementInst *getShuffleExtract(ExtractElementInst *Ext0,
                                         ExtractElementInst *Ext1,
                                         unsigned PreferredExtractIndex) const;
   bool isExtractExtractCheap(ExtractElementInst *Ext0, ExtractElementInst *Ext1,
-                             unsigned Opcode,
+                             const Instruction &I,
                              ExtractElementInst *&ConvertToShuffle,
                              unsigned PreferredExtractIndex);
   void foldExtExtCmp(ExtractElementInst *Ext0, ExtractElementInst *Ext1,
@@ -92,14 +102,27 @@ private:
   bool foldExtractedCmps(Instruction &I);
   bool foldSingleElementStore(Instruction &I);
   bool scalarizeLoadExtract(Instruction &I);
+  bool foldShuffleOfBinops(Instruction &I);
+
+  void replaceValue(Value &Old, Value &New) {
+    Old.replaceAllUsesWith(&New);
+    New.takeName(&Old);
+    if (auto *NewI = dyn_cast<Instruction>(&New)) {
+      Worklist.pushUsersToWorkList(*NewI);
+      Worklist.pushValue(NewI);
+    }
+    Worklist.pushValue(&Old);
+  }
+
+  void eraseInstruction(Instruction &I) {
+    for (Value *Op : I.operands())
+      Worklist.pushValue(Op);
+    Worklist.remove(&I);
+    I.eraseFromParent();
+  }
 };
 } // namespace
 
-static void replaceValue(Value &Old, Value &New) {
-  Old.replaceAllUsesWith(&New);
-  New.takeName(&Old);
-}
-
 bool VectorCombine::vectorizeLoadInsert(Instruction &I) {
   // Match insert into fixed vector of scalar value.
   // TODO: Handle non-zero insert index.
@@ -284,12 +307,13 @@ ExtractElementInst *VectorCombine::getShuffleExtract(
 /// \p ConvertToShuffle to that extract instruction.
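The first hunks add a ScalarizationOnly switch to the combiner and give it an InstructionWorklist, so values touched by a fold are revisited instead of relying on a whole-function cleanup at the end. A minimal caller sketch (local variable names are assumed; the real call sites are in the pass wrappers at the end of this diff):

  // Run only the scalarization folds; passing false keeps the old behavior
  // of running every combine.
  VectorCombine Combiner(F, TTI, DT, AA, AC, /*ScalarizationOnly=*/true);
  bool Changed = Combiner.run();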
 bool VectorCombine::isExtractExtractCheap(ExtractElementInst *Ext0,
                                           ExtractElementInst *Ext1,
-                                          unsigned Opcode,
+                                          const Instruction &I,
                                           ExtractElementInst *&ConvertToShuffle,
                                           unsigned PreferredExtractIndex) {
   assert(isa<ConstantInt>(Ext0->getOperand(1)) &&
          isa<ConstantInt>(Ext1->getOperand(1)) &&
          "Expected constant extract indexes");
+  unsigned Opcode = I.getOpcode();
   Type *ScalarTy = Ext0->getType();
   auto *VecTy = cast<VectorType>(Ext0->getOperand(0)->getType());
   InstructionCost ScalarOpCost, VectorOpCost;
@@ -302,10 +326,11 @@ bool VectorCombine::isExtractExtractCheap(ExtractElementInst *Ext0,
   } else {
     assert((Opcode == Instruction::ICmp || Opcode == Instruction::FCmp) &&
            "Expected a compare");
-    ScalarOpCost = TTI.getCmpSelInstrCost(Opcode, ScalarTy,
-                                          CmpInst::makeCmpResultType(ScalarTy));
-    VectorOpCost = TTI.getCmpSelInstrCost(Opcode, VecTy,
-                                          CmpInst::makeCmpResultType(VecTy));
+    CmpInst::Predicate Pred = cast<CmpInst>(I).getPredicate();
+    ScalarOpCost = TTI.getCmpSelInstrCost(
+        Opcode, ScalarTy, CmpInst::makeCmpResultType(ScalarTy), Pred);
+    VectorOpCost = TTI.getCmpSelInstrCost(
+        Opcode, VecTy, CmpInst::makeCmpResultType(VecTy), Pred);
   }
 
   // Get cost estimates for the extract elements. These costs will factor into
@@ -480,8 +505,7 @@ bool VectorCombine::foldExtractExtract(Instruction &I) {
                m_InsertElt(m_Value(), m_Value(), m_ConstantInt(InsertIndex)));
 
   ExtractElementInst *ExtractToChange;
-  if (isExtractExtractCheap(Ext0, Ext1, I.getOpcode(), ExtractToChange,
-                            InsertIndex))
+  if (isExtractExtractCheap(Ext0, Ext1, I, ExtractToChange, InsertIndex))
     return false;
 
   if (ExtractToChange) {
@@ -501,6 +525,8 @@ bool VectorCombine::foldExtractExtract(Instruction &I) {
   else
     foldExtExtBinop(Ext0, Ext1, I);
 
+  Worklist.push(Ext0);
+  Worklist.push(Ext1);
   return true;
 }
 
@@ -623,8 +649,11 @@ bool VectorCombine::scalarizeBinopOrCmp(Instruction &I) {
   unsigned Opcode = I.getOpcode();
   InstructionCost ScalarOpCost, VectorOpCost;
   if (IsCmp) {
-    ScalarOpCost = TTI.getCmpSelInstrCost(Opcode, ScalarTy);
-    VectorOpCost = TTI.getCmpSelInstrCost(Opcode, VecTy);
+    CmpInst::Predicate Pred = cast<CmpInst>(I).getPredicate();
+    ScalarOpCost = TTI.getCmpSelInstrCost(
+        Opcode, ScalarTy, CmpInst::makeCmpResultType(ScalarTy), Pred);
+    VectorOpCost = TTI.getCmpSelInstrCost(
+        Opcode, VecTy, CmpInst::makeCmpResultType(VecTy), Pred);
   } else {
     ScalarOpCost = TTI.getArithmeticInstrCost(Opcode, ScalarTy);
     VectorOpCost = TTI.getArithmeticInstrCost(Opcode, VecTy);
@@ -724,7 +753,10 @@ bool VectorCombine::foldExtractedCmps(Instruction &I) {
   InstructionCost OldCost =
       TTI.getVectorInstrCost(Ext0->getOpcode(), VecTy, Index0);
   OldCost += TTI.getVectorInstrCost(Ext1->getOpcode(), VecTy, Index1);
-  OldCost += TTI.getCmpSelInstrCost(CmpOpcode, I0->getType()) * 2;
+  OldCost +=
+      TTI.getCmpSelInstrCost(CmpOpcode, I0->getType(),
+                             CmpInst::makeCmpResultType(I0->getType()), Pred) *
+      2;
   OldCost += TTI.getArithmeticInstrCost(I.getOpcode(), I.getType());
 
   // The proposed vector pattern is:
@@ -733,7 +765,8 @@ bool VectorCombine::foldExtractedCmps(Instruction &I) {
   int CheapIndex = ConvertToShuf == Ext0 ? Index1 : Index0;
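These hunks pass the instruction itself into the cost checks so the concrete compare predicate reaches TTI.getCmpSelInstrCost. A hedged sketch of the difference (it assumes a TargetTransformInfo &TTI and a CmpInst *Cmp in scope; the default predicate argument is CmpInst::BAD_ICMP_PREDICATE):

  Type *OpTy = Cmp->getOperand(0)->getType();
  // Without a predicate, the target has to price a worst-case compare.
  InstructionCost Generic = TTI.getCmpSelInstrCost(
      Cmp->getOpcode(), OpTy, CmpInst::makeCmpResultType(OpTy));
  // With the real predicate, targets can distinguish e.g. equality from
  // ordered floating-point compares.
  InstructionCost Precise = TTI.getCmpSelInstrCost(
      Cmp->getOpcode(), OpTy, CmpInst::makeCmpResultType(OpTy),
      Cmp->getPredicate());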
   int ExpensiveIndex = ConvertToShuf == Ext0 ? Index0 : Index1;
   auto *CmpTy = cast<FixedVectorType>(CmpInst::makeCmpResultType(X->getType()));
-  InstructionCost NewCost = TTI.getCmpSelInstrCost(CmpOpcode, X->getType());
+  InstructionCost NewCost = TTI.getCmpSelInstrCost(
+      CmpOpcode, X->getType(), CmpInst::makeCmpResultType(X->getType()), Pred);
   SmallVector<int, 32> ShufMask(VecTy->getNumElements(), UndefMaskElem);
   ShufMask[CheapIndex] = ExpensiveIndex;
   NewCost += TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, CmpTy,
@@ -774,18 +807,98 @@ static bool isMemModifiedBetween(BasicBlock::iterator Begin,
   });
 }
 
+/// Helper class to indicate whether a vector index can be safely scalarized and
+/// if a freeze needs to be inserted.
+class ScalarizationResult {
+  enum class StatusTy { Unsafe, Safe, SafeWithFreeze };
+
+  StatusTy Status;
+  Value *ToFreeze;
+
+  ScalarizationResult(StatusTy Status, Value *ToFreeze = nullptr)
+      : Status(Status), ToFreeze(ToFreeze) {}
+
+public:
+  ScalarizationResult(const ScalarizationResult &Other) = default;
+  ~ScalarizationResult() {
+    assert(!ToFreeze && "freeze() not called with ToFreeze being set");
+  }
+
+  static ScalarizationResult unsafe() { return {StatusTy::Unsafe}; }
+  static ScalarizationResult safe() { return {StatusTy::Safe}; }
+  static ScalarizationResult safeWithFreeze(Value *ToFreeze) {
+    return {StatusTy::SafeWithFreeze, ToFreeze};
+  }
+
+  /// Returns true if the index can be scalarize without requiring a freeze.
+  bool isSafe() const { return Status == StatusTy::Safe; }
+  /// Returns true if the index cannot be scalarized.
+  bool isUnsafe() const { return Status == StatusTy::Unsafe; }
+  /// Returns true if the index can be scalarize, but requires inserting a
+  /// freeze.
+  bool isSafeWithFreeze() const { return Status == StatusTy::SafeWithFreeze; }
+
+  /// Reset the state of Unsafe and clear ToFreze if set.
+  void discard() {
+    ToFreeze = nullptr;
+    Status = StatusTy::Unsafe;
+  }
+
+  /// Freeze the ToFreeze and update the use in \p User to use it.
+  void freeze(IRBuilder<> &Builder, Instruction &UserI) {
+    assert(isSafeWithFreeze() &&
+           "should only be used when freezing is required");
+    assert(is_contained(ToFreeze->users(), &UserI) &&
+           "UserI must be a user of ToFreeze");
+    IRBuilder<>::InsertPointGuard Guard(Builder);
+    Builder.SetInsertPoint(cast<Instruction>(&UserI));
+    Value *Frozen =
+        Builder.CreateFreeze(ToFreeze, ToFreeze->getName() + ".frozen");
+    for (Use &U : make_early_inc_range((UserI.operands())))
+      if (U.get() == ToFreeze)
+        U.set(Frozen);
+
+    ToFreeze = nullptr;
+  }
+};
+
 /// Check if it is legal to scalarize a memory access to \p VecTy at index \p
 /// Idx. \p Idx must access a valid vector element.
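ScalarizationResult replaces the old yes/no answer with three states, so callers can tell apart "safe", "unsafe", and "safe only after freezing the index". A usage sketch, mirroring how foldSingleElementStore consumes it later in this diff:

  auto Res = canScalarizeAccess(VecTy, Idx, CtxI, AC, DT);
  if (Res.isUnsafe())
    return false;                                  // index may be out of bounds
  if (Res.isSafeWithFreeze())
    Res.freeze(Builder, *cast<Instruction>(Idx));  // freeze the base of the index
  // At this point the scalar access can be emitted.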
-static bool canScalarizeAccess(FixedVectorType *VecTy, Value *Idx,
-                               Instruction *CtxI, AssumptionCache &AC) {
-  if (auto *C = dyn_cast<ConstantInt>(Idx))
-    return C->getValue().ult(VecTy->getNumElements());
+static ScalarizationResult canScalarizeAccess(FixedVectorType *VecTy,
+                                              Value *Idx, Instruction *CtxI,
+                                              AssumptionCache &AC,
+                                              const DominatorTree &DT) {
+  if (auto *C = dyn_cast<ConstantInt>(Idx)) {
+    if (C->getValue().ult(VecTy->getNumElements()))
+      return ScalarizationResult::safe();
+    return ScalarizationResult::unsafe();
+  }
 
-  APInt Zero(Idx->getType()->getScalarSizeInBits(), 0);
-  APInt MaxElts(Idx->getType()->getScalarSizeInBits(), VecTy->getNumElements());
+  unsigned IntWidth = Idx->getType()->getScalarSizeInBits();
+  APInt Zero(IntWidth, 0);
+  APInt MaxElts(IntWidth, VecTy->getNumElements());
   ConstantRange ValidIndices(Zero, MaxElts);
-  ConstantRange IdxRange = computeConstantRange(Idx, true, &AC, CtxI, 0);
-  return ValidIndices.contains(IdxRange);
+  ConstantRange IdxRange(IntWidth, true);
+
+  if (isGuaranteedNotToBePoison(Idx, &AC)) {
+    if (ValidIndices.contains(computeConstantRange(Idx, true, &AC, CtxI, &DT)))
+      return ScalarizationResult::safe();
+    return ScalarizationResult::unsafe();
+  }
+
+  // If the index may be poison, check if we can insert a freeze before the
+  // range of the index is restricted.
+  Value *IdxBase;
+  ConstantInt *CI;
+  if (match(Idx, m_And(m_Value(IdxBase), m_ConstantInt(CI)))) {
+    IdxRange = IdxRange.binaryAnd(CI->getValue());
+  } else if (match(Idx, m_URem(m_Value(IdxBase), m_ConstantInt(CI)))) {
+    IdxRange = IdxRange.urem(CI->getValue());
+  }
+
+  if (ValidIndices.contains(IdxRange))
+    return ScalarizationResult::safeWithFreeze(IdxBase);
+  return ScalarizationResult::unsafe();
 }
 
 /// The memory operation on a vector of \p ScalarType had alignment of
@@ -833,12 +946,17 @@ bool VectorCombine::foldSingleElementStore(Instruction &I) {
     // modified between, vector type matches store size, and index is inbounds.
     if (!Load->isSimple() || Load->getParent() != SI->getParent() ||
         !DL.typeSizeEqualsStoreSize(Load->getType()) ||
-        !canScalarizeAccess(VecTy, Idx, Load, AC) ||
-        SrcAddr != SI->getPointerOperand()->stripPointerCasts() ||
+        SrcAddr != SI->getPointerOperand()->stripPointerCasts())
+      return false;
+
+    auto ScalarizableIdx = canScalarizeAccess(VecTy, Idx, Load, AC, DT);
+    if (ScalarizableIdx.isUnsafe() ||
         isMemModifiedBetween(Load->getIterator(), SI->getIterator(),
                              MemoryLocation::get(SI), AA))
       return false;
 
+    if (ScalarizableIdx.isSafeWithFreeze())
+      ScalarizableIdx.freeze(Builder, *cast<Instruction>(Idx));
     Value *GEP = Builder.CreateInBoundsGEP(
         SI->getValueOperand()->getType(), SI->getPointerOperand(),
         {ConstantInt::get(Idx->getType(), 0), Idx});
@@ -849,8 +967,7 @@ bool VectorCombine::foldSingleElementStore(Instruction &I) {
         DL);
     NSI->setAlignment(ScalarOpAlignment);
     replaceValue(I, *NSI);
-    // Need erasing the store manually.
-    I.eraseFromParent();
+    eraseInstruction(I);
     return true;
   }
 
@@ -860,11 +977,10 @@ bool VectorCombine::foldSingleElementStore(Instruction &I) {
 /// Try to scalarize vector loads feeding extractelement instructions.
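canScalarizeAccess now proves in-bounds indices with ConstantRange and, when the index might be poison, reports which value to freeze instead of giving up. A small standalone sketch of the range reasoning (illustrative only, assuming a 4-element vector and an i64 index clamped by 'and %idx, 3'):

  unsigned BW = 64;
  ConstantRange Valid(APInt(BW, 0), APInt(BW, 4));   // in-bounds indices 0..3
  ConstantRange Idx(BW, /*isFullSet=*/true);         // nothing known about the base
  Idx = Idx.binaryAnd(ConstantRange(APInt(BW, 3)));  // effect of 'and ..., 3'
  bool Ok = Valid.contains(Idx);                     // true: safe once the base is frozen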
 bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
   Value *Ptr;
-  Value *Idx;
-  if (!match(&I, m_ExtractElt(m_Load(m_Value(Ptr)), m_Value(Idx))))
+  if (!match(&I, m_Load(m_Value(Ptr))))
     return false;
 
-  auto *LI = cast<LoadInst>(I.getOperand(0));
+  auto *LI = cast<LoadInst>(&I);
   const DataLayout &DL = I.getModule()->getDataLayout();
   if (LI->isVolatile() || !DL.typeSizeEqualsStoreSize(LI->getType()))
     return false;
@@ -909,8 +1025,12 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
     else if (LastCheckedInst->comesBefore(UI))
       LastCheckedInst = UI;
 
-    if (!canScalarizeAccess(FixedVT, UI->getOperand(1), &I, AC))
+    auto ScalarIdx = canScalarizeAccess(FixedVT, UI->getOperand(1), &I, AC, DT);
+    if (!ScalarIdx.isSafe()) {
+      // TODO: Freeze index if it is safe to do so.
+      ScalarIdx.discard();
       return false;
+    }
 
     auto *Index = dyn_cast<ConstantInt>(UI->getOperand(1));
     OriginalCost +=
@@ -946,6 +1066,60 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
   return true;
 }
 
+/// Try to convert "shuffle (binop), (binop)" with a shared binop operand into
+/// "binop (shuffle), (shuffle)".
+bool VectorCombine::foldShuffleOfBinops(Instruction &I) {
+  auto *VecTy = dyn_cast<FixedVectorType>(I.getType());
+  if (!VecTy)
+    return false;
+
+  BinaryOperator *B0, *B1;
+  ArrayRef<int> Mask;
+  if (!match(&I, m_Shuffle(m_OneUse(m_BinOp(B0)), m_OneUse(m_BinOp(B1)),
+                           m_Mask(Mask))) ||
+      B0->getOpcode() != B1->getOpcode() || B0->getType() != VecTy)
+    return false;
+
+  // Try to replace a binop with a shuffle if the shuffle is not costly.
+  // The new shuffle will choose from a single, common operand, so it may be
+  // cheaper than the existing two-operand shuffle.
+  SmallVector<int> UnaryMask = createUnaryMask(Mask, Mask.size());
+  Instruction::BinaryOps Opcode = B0->getOpcode();
+  InstructionCost BinopCost = TTI.getArithmeticInstrCost(Opcode, VecTy);
+  InstructionCost ShufCost = TTI.getShuffleCost(
+      TargetTransformInfo::SK_PermuteSingleSrc, VecTy, UnaryMask);
+  if (ShufCost > BinopCost)
+    return false;
+
+  // If we have something like "add X, Y" and "add Z, X", swap ops to match.
+  Value *X = B0->getOperand(0), *Y = B0->getOperand(1);
+  Value *Z = B1->getOperand(0), *W = B1->getOperand(1);
+  if (BinaryOperator::isCommutative(Opcode) && X != Z && Y != W)
+    std::swap(X, Y);
+
+  Value *Shuf0, *Shuf1;
+  if (X == Z) {
+    // shuf (bo X, Y), (bo X, W) --> bo (shuf X), (shuf Y, W)
+    Shuf0 = Builder.CreateShuffleVector(X, UnaryMask);
+    Shuf1 = Builder.CreateShuffleVector(Y, W, Mask);
+  } else if (Y == W) {
+    // shuf (bo X, Y), (bo Z, Y) --> bo (shuf X, Z), (shuf Y)
+    Shuf0 = Builder.CreateShuffleVector(X, Z, Mask);
+    Shuf1 = Builder.CreateShuffleVector(Y, UnaryMask);
+  } else {
+    return false;
+  }
+
+  Value *NewBO = Builder.CreateBinOp(Opcode, Shuf0, Shuf1);
+  // Intersect flags from the old binops.
+  if (auto *NewInst = dyn_cast<Instruction>(NewBO)) {
+    NewInst->copyIRFlags(B0);
+    NewInst->andIRFlags(B1);
+  }
+  replaceValue(I, *NewBO);
+  return true;
+}
+
 /// This is the entry point for all transforms. Pass manager differences are
 /// handled in the callers of this function.
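The load-extract scalarization is now keyed off the load itself, and the new foldShuffleOfBinops turns "shuffle (binop), (binop)" with a shared operand into "binop (shuffle), (shuffle)". Its profitability check only compares the new single-source shuffle against one of the removed binops; spelled out with assumed cost variables:

  // One binop and the original two-source shuffle survive in both forms, so:
  InstructionCost OldCost = 2 * BinopCost + TwoSrcShufCost;
  InstructionCost NewCost = BinopCost + TwoSrcShufCost + OneSrcShufCost;
  // NewCost <= OldCost exactly when OneSrcShufCost <= BinopCost, which is the
  // 'ShufCost > BinopCost' bail-out in foldShuffleOfBinops above.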
 bool VectorCombine::run() {
@@ -957,29 +1131,43 @@
     return false;
 
   bool MadeChange = false;
+  auto FoldInst = [this, &MadeChange](Instruction &I) {
+    Builder.SetInsertPoint(&I);
+    if (!ScalarizationOnly) {
+      MadeChange |= vectorizeLoadInsert(I);
+      MadeChange |= foldExtractExtract(I);
+      MadeChange |= foldBitcastShuf(I);
+      MadeChange |= foldExtractedCmps(I);
+      MadeChange |= foldShuffleOfBinops(I);
+    }
+    MadeChange |= scalarizeBinopOrCmp(I);
+    MadeChange |= scalarizeLoadExtract(I);
+    MadeChange |= foldSingleElementStore(I);
+  };
   for (BasicBlock &BB : F) {
     // Ignore unreachable basic blocks.
     if (!DT.isReachableFromEntry(&BB))
       continue;
     // Use early increment range so that we can erase instructions in loop.
     for (Instruction &I : make_early_inc_range(BB)) {
-      if (isa<DbgInfoIntrinsic>(I))
+      if (I.isDebugOrPseudoInst())
         continue;
-      Builder.SetInsertPoint(&I);
-      MadeChange |= vectorizeLoadInsert(I);
-      MadeChange |= foldExtractExtract(I);
-      MadeChange |= foldBitcastShuf(I);
-      MadeChange |= scalarizeBinopOrCmp(I);
-      MadeChange |= foldExtractedCmps(I);
-      MadeChange |= scalarizeLoadExtract(I);
-      MadeChange |= foldSingleElementStore(I);
+      FoldInst(I);
     }
   }
 
-  // We're done with transforms, so remove dead instructions.
-  if (MadeChange)
-    for (BasicBlock &BB : F)
-      SimplifyInstructionsInBlock(&BB);
+  while (!Worklist.isEmpty()) {
+    Instruction *I = Worklist.removeOne();
+    if (!I)
+      continue;
+
+    if (isInstructionTriviallyDead(I)) {
+      eraseInstruction(*I);
+      continue;
+    }
+
+    FoldInst(*I);
+  }
 
   return MadeChange;
 }
 
@@ -1014,7 +1202,7 @@ public:
     auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
     auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
     auto &AA = getAnalysis<AAResultsWrapperPass>().getAAResults();
-    VectorCombine Combiner(F, TTI, DT, AA, AC);
+    VectorCombine Combiner(F, TTI, DT, AA, AC, false);
    return Combiner.run();
   }
 };
@@ -1038,7 +1226,7 @@ PreservedAnalyses VectorCombinePass::run(Function &F,
   TargetTransformInfo &TTI = FAM.getResult<TargetIRAnalysis>(F);
   DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(F);
   AAResults &AA = FAM.getResult<AAManager>(F);
-  VectorCombine Combiner(F, TTI, DT, AA, AC);
+  VectorCombine Combiner(F, TTI, DT, AA, AC, ScalarizationOnly);
   if (!Combiner.run())
     return PreservedAnalyses::all();
   PreservedAnalyses PA;
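The driver now funnels every instruction through one FoldInst lambda, skips the vectorizing folds when ScalarizationOnly is set, and reprocesses the worklist to a fixed point instead of calling SimplifyInstructionsInBlock afterwards. A hedged sketch of how a pipeline might use the two flavors; the boolean constructor parameter of VectorCombinePass is an assumption inferred from the member used above and is not shown in this diff:

  FunctionPassManager FPM;
  // Early, cheap run that only scalarizes and never creates new vector ops.
  FPM.addPass(VectorCombinePass(/*ScalarizationOnly=*/true));
  // ... other function passes ...
  // Full combine later in the pipeline (the existing default behavior).
  FPM.addPass(VectorCombinePass());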