author    | Dimitry Andric <dim@FreeBSD.org> | 2023-02-11 12:38:04 +0000
committer | Dimitry Andric <dim@FreeBSD.org> | 2023-02-11 12:38:11 +0000
commit    | e3b557809604d036af6e00c60f012c2025b59a5e (patch)
tree      | 8a11ba2269a3b669601e2fd41145b174008f4da8 /llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
parent    | 08e8dd7b9db7bb4a9de26d44c1cbfd24e869c014 (diff)
Diffstat (limited to 'llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp')
-rw-r--r-- | llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp | 246
1 file changed, 215 insertions, 31 deletions
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
index b80c58183dd5..61e62adbe327 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
@@ -105,7 +105,7 @@ Instruction *InstCombinerImpl::scalarizePHI(ExtractElementInst &EI,
   // 2) Possibly more ExtractElements with the same index.
   // 3) Another operand, which will feed back into the PHI.
   Instruction *PHIUser = nullptr;
-  for (auto U : PN->users()) {
+  for (auto *U : PN->users()) {
     if (ExtractElementInst *EU = dyn_cast<ExtractElementInst>(U)) {
       if (EI.getIndexOperand() == EU->getIndexOperand())
         Extracts.push_back(EU);
@@ -171,7 +171,7 @@ Instruction *InstCombinerImpl::scalarizePHI(ExtractElementInst &EI,
     }
   }
 
-  for (auto E : Extracts)
+  for (auto *E : Extracts)
     replaceInstUsesWith(*E, scalarPHI);
 
   return &EI;
@@ -187,13 +187,12 @@ Instruction *InstCombinerImpl::foldBitcastExtElt(ExtractElementInst &Ext) {
   ElementCount NumElts =
       cast<VectorType>(Ext.getVectorOperandType())->getElementCount();
   Type *DestTy = Ext.getType();
+  unsigned DestWidth = DestTy->getPrimitiveSizeInBits();
   bool IsBigEndian = DL.isBigEndian();
 
   // If we are casting an integer to vector and extracting a portion, that is
   // a shift-right and truncate.
-  // TODO: Allow FP dest type by casting the trunc to FP?
-  if (X->getType()->isIntegerTy() && DestTy->isIntegerTy() &&
-      isDesirableIntType(X->getType()->getPrimitiveSizeInBits())) {
+  if (X->getType()->isIntegerTy()) {
     assert(isa<FixedVectorType>(Ext.getVectorOperand()->getType()) &&
            "Expected fixed vector type for bitcast from scalar integer");
 
@@ -202,10 +201,18 @@ Instruction *InstCombinerImpl::foldBitcastExtElt(ExtractElementInst &Ext) {
     // BigEndian: extelt (bitcast i32 X to v4i8), 0 -> trunc i32 (X >> 24) to i8
     if (IsBigEndian)
       ExtIndexC = NumElts.getKnownMinValue() - 1 - ExtIndexC;
-    unsigned ShiftAmountC = ExtIndexC * DestTy->getPrimitiveSizeInBits();
-    if (!ShiftAmountC || Ext.getVectorOperand()->hasOneUse()) {
-      Value *Lshr = Builder.CreateLShr(X, ShiftAmountC, "extelt.offset");
-      return new TruncInst(Lshr, DestTy);
+    unsigned ShiftAmountC = ExtIndexC * DestWidth;
+    if (!ShiftAmountC ||
+        (isDesirableIntType(X->getType()->getPrimitiveSizeInBits()) &&
+         Ext.getVectorOperand()->hasOneUse())) {
+      if (ShiftAmountC)
+        X = Builder.CreateLShr(X, ShiftAmountC, "extelt.offset");
+      if (DestTy->isFloatingPointTy()) {
+        Type *DstIntTy = IntegerType::getIntNTy(X->getContext(), DestWidth);
+        Value *Trunc = Builder.CreateTrunc(X, DstIntTy);
+        return new BitCastInst(Trunc, DestTy);
+      }
+      return new TruncInst(X, DestTy);
     }
   }
 
@@ -278,7 +285,6 @@ Instruction *InstCombinerImpl::foldBitcastExtElt(ExtractElementInst &Ext) {
     return nullptr;
 
   unsigned SrcWidth = SrcTy->getScalarSizeInBits();
-  unsigned DestWidth = DestTy->getPrimitiveSizeInBits();
   unsigned ShAmt = Chunk * DestWidth;
 
   // TODO: This limitation is more strict than necessary. We could sum the
@@ -393,6 +399,20 @@ Instruction *InstCombinerImpl::visitExtractElementInst(ExtractElementInst &EI) {
                                             SQ.getWithInstruction(&EI)))
     return replaceInstUsesWith(EI, V);
 
+  // extractelt (select %x, %vec1, %vec2), %const ->
+  //   select %x, %vec1[%const], %vec2[%const]
+  // TODO: Support constant folding of multiple select operands:
+  //   extractelt (select %x, %vec1, %vec2), (select %x, %c1, %c2)
+  // If the extractelement will for instance try to do out of bounds accesses
+  // because of the values of %c1 and/or %c2, the sequence could be optimized
+  // early. This is currently not possible because constant folding will reach
+  // an unreachable assertion if it doesn't find a constant operand.
+  if (SelectInst *SI = dyn_cast<SelectInst>(EI.getVectorOperand()))
+    if (SI->getCondition()->getType()->isIntegerTy() &&
+        isa<Constant>(EI.getIndexOperand()))
+      if (Instruction *R = FoldOpIntoSelect(EI, SI))
+        return R;
+
   // If extracting a specified index from the vector, see if we can recursively
   // find a previously computed scalar that was inserted into the vector.
   auto *IndexC = dyn_cast<ConstantInt>(Index);
@@ -850,17 +870,16 @@ Instruction *InstCombinerImpl::foldAggregateConstructionIntoAggregateReuse(
   if (NumAggElts > 2)
     return nullptr;
 
-  static constexpr auto NotFound = None;
+  static constexpr auto NotFound = std::nullopt;
   static constexpr auto FoundMismatch = nullptr;
 
   // Try to find a value of each element of an aggregate.
   // FIXME: deal with more complex, not one-dimensional, aggregate types
-  SmallVector<Optional<Instruction *>, 2> AggElts(NumAggElts, NotFound);
+  SmallVector<std::optional<Instruction *>, 2> AggElts(NumAggElts, NotFound);
 
   // Do we know values for each element of the aggregate?
   auto KnowAllElts = [&AggElts]() {
-    return all_of(AggElts,
-                  [](Optional<Instruction *> Elt) { return Elt != NotFound; });
+    return !llvm::is_contained(AggElts, NotFound);
   };
 
   int Depth = 0;
@@ -889,7 +908,7 @@ Instruction *InstCombinerImpl::foldAggregateConstructionIntoAggregateReuse(
     // Now, we may have already previously recorded the value for this element
    // of an aggregate. If we did, that means the CurrIVI will later be
     // overwritten with the already-recorded value. But if not, let's record it!
-    Optional<Instruction *> &Elt = AggElts[Indices.front()];
+    std::optional<Instruction *> &Elt = AggElts[Indices.front()];
     Elt = Elt.value_or(InsertedValue);
 
     // FIXME: should we handle chain-terminating undef base operand?
@@ -919,7 +938,7 @@ Instruction *InstCombinerImpl::foldAggregateConstructionIntoAggregateReuse(
     /// or different elements had different source aggregates.
     FoundMismatch
   };
-  auto Describe = [](Optional<Value *> SourceAggregate) {
+  auto Describe = [](std::optional<Value *> SourceAggregate) {
     if (SourceAggregate == NotFound)
       return AggregateDescription::NotFound;
     if (*SourceAggregate == FoundMismatch)
@@ -933,8 +952,8 @@ Instruction *InstCombinerImpl::foldAggregateConstructionIntoAggregateReuse(
   // If found, return the source aggregate from which the extraction was.
   // If \p PredBB is provided, does PHI translation of an \p Elt first.
   auto FindSourceAggregate =
-      [&](Instruction *Elt, unsigned EltIdx, Optional<BasicBlock *> UseBB,
-          Optional<BasicBlock *> PredBB) -> Optional<Value *> {
+      [&](Instruction *Elt, unsigned EltIdx, std::optional<BasicBlock *> UseBB,
+          std::optional<BasicBlock *> PredBB) -> std::optional<Value *> {
     // For now(?), only deal with, at most, a single level of PHI indirection.
     if (UseBB && PredBB)
       Elt = dyn_cast<Instruction>(Elt->DoPHITranslation(*UseBB, *PredBB));
@@ -961,9 +980,9 @@ Instruction *InstCombinerImpl::foldAggregateConstructionIntoAggregateReuse(
   // see if we can find appropriate source aggregate for each of the elements,
   // and see it's the same aggregate for each element. If so, return it.
   auto FindCommonSourceAggregate =
-      [&](Optional<BasicBlock *> UseBB,
-          Optional<BasicBlock *> PredBB) -> Optional<Value *> {
-    Optional<Value *> SourceAggregate;
+      [&](std::optional<BasicBlock *> UseBB,
+          std::optional<BasicBlock *> PredBB) -> std::optional<Value *> {
+    std::optional<Value *> SourceAggregate;
 
     for (auto I : enumerate(AggElts)) {
       assert(Describe(SourceAggregate) != AggregateDescription::FoundMismatch &&
@@ -975,7 +994,7 @@ Instruction *InstCombinerImpl::foldAggregateConstructionIntoAggregateReuse(
       // For this element, is there a plausible source aggregate?
      // FIXME: we could special-case undef element, IFF we know that in the
       //        source aggregate said element isn't poison.
-      Optional<Value *> SourceAggregateForElement =
+      std::optional<Value *> SourceAggregateForElement =
          FindSourceAggregate(*I.value(), I.index(), UseBB, PredBB);
 
       // Okay, what have we found? Does that correlate with previous findings?
@@ -1009,10 +1028,11 @@ Instruction *InstCombinerImpl::foldAggregateConstructionIntoAggregateReuse(
     return *SourceAggregate;
   };
 
-  Optional<Value *> SourceAggregate;
+  std::optional<Value *> SourceAggregate;
 
   // Can we find the source aggregate without looking at predecessors?
-  SourceAggregate = FindCommonSourceAggregate(/*UseBB=*/None, /*PredBB=*/None);
+  SourceAggregate = FindCommonSourceAggregate(/*UseBB=*/std::nullopt,
+                                              /*PredBB=*/std::nullopt);
   if (Describe(SourceAggregate) != AggregateDescription::NotFound) {
     if (Describe(SourceAggregate) == AggregateDescription::FoundMismatch)
       return nullptr; // Conflicting source aggregates!
@@ -1029,7 +1049,7 @@ Instruction *InstCombinerImpl::foldAggregateConstructionIntoAggregateReuse(
   // they all should be defined in the same basic block.
   BasicBlock *UseBB = nullptr;
 
-  for (const Optional<Instruction *> &I : AggElts) {
+  for (const std::optional<Instruction *> &I : AggElts) {
     BasicBlock *BB = (*I)->getParent();
     // If it's the first instruction we've encountered, record the basic block.
     if (!UseBB) {
@@ -1495,6 +1515,71 @@ static Instruction *narrowInsElt(InsertElementInst &InsElt,
   return CastInst::Create(CastOpcode, NewInsElt, InsElt.getType());
 }
 
+/// If we are inserting 2 halves of a value into adjacent elements of a vector,
+/// try to convert to a single insert with appropriate bitcasts.
+static Instruction *foldTruncInsEltPair(InsertElementInst &InsElt,
+                                        bool IsBigEndian,
+                                        InstCombiner::BuilderTy &Builder) {
+  Value *VecOp    = InsElt.getOperand(0);
+  Value *ScalarOp = InsElt.getOperand(1);
+  Value *IndexOp  = InsElt.getOperand(2);
+
+  // Pattern depends on endian because we expect lower index is inserted first.
+  // Big endian:
+  // inselt (inselt BaseVec, (trunc (lshr X, BW/2), Index0), (trunc X), Index1
+  // Little endian:
+  // inselt (inselt BaseVec, (trunc X), Index0), (trunc (lshr X, BW/2)), Index1
+  // Note: It is not safe to do this transform with an arbitrary base vector
+  //       because the bitcast of that vector to fewer/larger elements could
+  //       allow poison to spill into an element that was not poison before.
+  // TODO: Detect smaller fractions of the scalar.
+  // TODO: One-use checks are conservative.
+  auto *VTy = dyn_cast<FixedVectorType>(InsElt.getType());
+  Value *Scalar0, *BaseVec;
+  uint64_t Index0, Index1;
+  if (!VTy || (VTy->getNumElements() & 1) ||
+      !match(IndexOp, m_ConstantInt(Index1)) ||
+      !match(VecOp, m_InsertElt(m_Value(BaseVec), m_Value(Scalar0),
+                                m_ConstantInt(Index0))) ||
+      !match(BaseVec, m_Undef()))
+    return nullptr;
+
+  // The first insert must be to the index one less than this one, and
+  // the first insert must be to an even index.
+  if (Index0 + 1 != Index1 || Index0 & 1)
+    return nullptr;
+
+  // For big endian, the high half of the value should be inserted first.
+  // For little endian, the low half of the value should be inserted first.
+  Value *X;
+  uint64_t ShAmt;
+  if (IsBigEndian) {
+    if (!match(ScalarOp, m_Trunc(m_Value(X))) ||
+        !match(Scalar0, m_Trunc(m_LShr(m_Specific(X), m_ConstantInt(ShAmt)))))
+      return nullptr;
+  } else {
+    if (!match(Scalar0, m_Trunc(m_Value(X))) ||
+        !match(ScalarOp, m_Trunc(m_LShr(m_Specific(X), m_ConstantInt(ShAmt)))))
+      return nullptr;
+  }
+
+  Type *SrcTy = X->getType();
+  unsigned ScalarWidth = SrcTy->getScalarSizeInBits();
+  unsigned VecEltWidth = VTy->getScalarSizeInBits();
+  if (ScalarWidth != VecEltWidth * 2 || ShAmt != VecEltWidth)
+    return nullptr;
+
+  // Bitcast the base vector to a vector type with the source element type.
+  Type *CastTy = FixedVectorType::get(SrcTy, VTy->getNumElements() / 2);
+  Value *CastBaseVec = Builder.CreateBitCast(BaseVec, CastTy);
+
+  // Scale the insert index for a vector with half as many elements.
+  // bitcast (inselt (bitcast BaseVec), X, NewIndex)
+  uint64_t NewIndex = IsBigEndian ? Index1 / 2 : Index0 / 2;
+  Value *NewInsert = Builder.CreateInsertElement(CastBaseVec, X, NewIndex);
+  return new BitCastInst(NewInsert, VTy);
+}
+
 Instruction *InstCombinerImpl::visitInsertElementInst(InsertElementInst &IE) {
   Value *VecOp    = IE.getOperand(0);
   Value *ScalarOp = IE.getOperand(1);
@@ -1505,10 +1590,22 @@ Instruction *InstCombinerImpl::visitInsertElementInst(InsertElementInst &IE) {
     return replaceInstUsesWith(IE, V);
 
   // Canonicalize type of constant indices to i64 to simplify CSE
-  if (auto *IndexC = dyn_cast<ConstantInt>(IdxOp))
+  if (auto *IndexC = dyn_cast<ConstantInt>(IdxOp)) {
     if (auto *NewIdx = getPreferredVectorIndex(IndexC))
       return replaceOperand(IE, 2, NewIdx);
 
+    Value *BaseVec, *OtherScalar;
+    uint64_t OtherIndexVal;
+    if (match(VecOp, m_OneUse(m_InsertElt(m_Value(BaseVec),
+                                          m_Value(OtherScalar),
+                                          m_ConstantInt(OtherIndexVal)))) &&
+        !isa<Constant>(OtherScalar) && OtherIndexVal > IndexC->getZExtValue()) {
+      Value *NewIns = Builder.CreateInsertElement(BaseVec, ScalarOp, IdxOp);
+      return InsertElementInst::Create(NewIns, OtherScalar,
+                                       Builder.getInt64(OtherIndexVal));
+    }
+  }
+
   // If the scalar is bitcast and inserted into undef, do the insert in the
   // source type followed by bitcast.
   // TODO: Generalize for insert into any constant, not just undef?
@@ -1622,6 +1719,9 @@ Instruction *InstCombinerImpl::visitInsertElementInst(InsertElementInst &IE) {
   if (Instruction *Ext = narrowInsElt(IE, Builder))
     return Ext;
 
+  if (Instruction *Ext = foldTruncInsEltPair(IE, DL.isBigEndian(), Builder))
+    return Ext;
+
   return nullptr;
 }
 
@@ -1653,7 +1753,7 @@ static bool canEvaluateShuffled(Value *V, ArrayRef<int> Mask,
       // from an undefined element in an operand.
      if (llvm::is_contained(Mask, -1))
        return false;
-      LLVM_FALLTHROUGH;
+      [[fallthrough]];
    case Instruction::Add:
    case Instruction::FAdd:
    case Instruction::Sub:
@@ -1700,8 +1800,8 @@ static bool canEvaluateShuffled(Value *V, ArrayRef<int> Mask,
       // Verify that 'CI' does not occur twice in Mask. A single 'insertelement'
       // can't put an element into multiple indices.
       bool SeenOnce = false;
-      for (int i = 0, e = Mask.size(); i != e; ++i) {
-        if (Mask[i] == ElementNumber) {
+      for (int I : Mask) {
+        if (I == ElementNumber) {
           if (SeenOnce)
             return false;
           SeenOnce = true;
@@ -1957,6 +2057,56 @@ static BinopElts getAlternateBinop(BinaryOperator *BO, const DataLayout &DL) {
   return {};
 }
 
+/// A select shuffle of a select shuffle with a shared operand can be reduced
+/// to a single select shuffle. This is an obvious improvement in IR, and the
+/// backend is expected to lower select shuffles efficiently.
+static Instruction *foldSelectShuffleOfSelectShuffle(ShuffleVectorInst &Shuf) {
+  assert(Shuf.isSelect() && "Must have select-equivalent shuffle");
+
+  Value *Op0 = Shuf.getOperand(0), *Op1 = Shuf.getOperand(1);
+  SmallVector<int, 16> Mask;
+  Shuf.getShuffleMask(Mask);
+  unsigned NumElts = Mask.size();
+
+  // Canonicalize a select shuffle with common operand as Op1.
+  auto *ShufOp = dyn_cast<ShuffleVectorInst>(Op0);
+  if (ShufOp && ShufOp->isSelect() &&
+      (ShufOp->getOperand(0) == Op1 || ShufOp->getOperand(1) == Op1)) {
+    std::swap(Op0, Op1);
+    ShuffleVectorInst::commuteShuffleMask(Mask, NumElts);
+  }
+
+  ShufOp = dyn_cast<ShuffleVectorInst>(Op1);
+  if (!ShufOp || !ShufOp->isSelect() ||
+      (ShufOp->getOperand(0) != Op0 && ShufOp->getOperand(1) != Op0))
+    return nullptr;
+
+  Value *X = ShufOp->getOperand(0), *Y = ShufOp->getOperand(1);
+  SmallVector<int, 16> Mask1;
+  ShufOp->getShuffleMask(Mask1);
+  assert(Mask1.size() == NumElts && "Vector size changed with select shuffle");
+
+  // Canonicalize common operand (Op0) as X (first operand of first shuffle).
+  if (Y == Op0) {
+    std::swap(X, Y);
+    ShuffleVectorInst::commuteShuffleMask(Mask1, NumElts);
+  }
+
+  // If the mask chooses from X (operand 0), it stays the same.
+  // If the mask chooses from the earlier shuffle, the other mask value is
+  // transferred to the combined select shuffle:
+  // shuf X, (shuf X, Y, M1), M --> shuf X, Y, M'
+  SmallVector<int, 16> NewMask(NumElts);
+  for (unsigned i = 0; i != NumElts; ++i)
+    NewMask[i] = Mask[i] < (signed)NumElts ? Mask[i] : Mask1[i];
+
+  // A select mask with undef elements might look like an identity mask.
+  assert((ShuffleVectorInst::isSelectMask(NewMask) ||
+          ShuffleVectorInst::isIdentityMask(NewMask)) &&
+         "Unexpected shuffle mask");
+  return new ShuffleVectorInst(X, Y, NewMask);
+}
+
 static Instruction *foldSelectShuffleWith1Binop(ShuffleVectorInst &Shuf) {
   assert(Shuf.isSelect() && "Must have select-equivalent shuffle");
 
@@ -2061,6 +2211,9 @@ Instruction *InstCombinerImpl::foldSelectShuffle(ShuffleVectorInst &Shuf) {
     return &Shuf;
   }
 
+  if (Instruction *I = foldSelectShuffleOfSelectShuffle(Shuf))
+    return I;
+
   if (Instruction *I = foldSelectShuffleWith1Binop(Shuf))
     return I;
 
@@ -2541,6 +2694,35 @@ static Instruction *foldIdentityPaddedShuffles(ShuffleVectorInst &Shuf) {
   return new ShuffleVectorInst(X, Y, NewMask);
 }
 
+// Splatting the first element of the result of a BinOp, where any of the
+// BinOp's operands are the result of a first element splat can be simplified to
+// splatting the first element of the result of the BinOp
+Instruction *InstCombinerImpl::simplifyBinOpSplats(ShuffleVectorInst &SVI) {
+  if (!match(SVI.getOperand(1), m_Undef()) ||
+      !match(SVI.getShuffleMask(), m_ZeroMask()))
+    return nullptr;
+
+  Value *Op0 = SVI.getOperand(0);
+  Value *X, *Y;
+  if (!match(Op0, m_BinOp(m_Shuffle(m_Value(X), m_Undef(), m_ZeroMask()),
+                          m_Value(Y))) &&
+      !match(Op0, m_BinOp(m_Value(X),
+                          m_Shuffle(m_Value(Y), m_Undef(), m_ZeroMask()))))
+    return nullptr;
+  if (X->getType() != Y->getType())
+    return nullptr;
+
+  auto *BinOp = cast<BinaryOperator>(Op0);
+  if (!isSafeToSpeculativelyExecute(BinOp))
+    return nullptr;
+
+  Value *NewBO = Builder.CreateBinOp(BinOp->getOpcode(), X, Y);
+  if (auto NewBOI = dyn_cast<Instruction>(NewBO))
+    NewBOI->copyIRFlags(BinOp);
+
+  return new ShuffleVectorInst(NewBO, SVI.getShuffleMask());
+}
+
 Instruction *InstCombinerImpl::visitShuffleVectorInst(ShuffleVectorInst &SVI) {
   Value *LHS = SVI.getOperand(0);
   Value *RHS = SVI.getOperand(1);
@@ -2549,7 +2731,9 @@ Instruction *InstCombinerImpl::visitShuffleVectorInst(ShuffleVectorInst &SVI) {
                                           SVI.getType(), ShufQuery))
     return replaceInstUsesWith(SVI, V);
 
-  // Bail out for scalable vectors
+  if (Instruction *I = simplifyBinOpSplats(SVI))
+    return I;
+
   if (isa<ScalableVectorType>(LHS->getType()))
     return nullptr;
 
@@ -2694,7 +2878,7 @@ Instruction *InstCombinerImpl::visitShuffleVectorInst(ShuffleVectorInst &SVI) {
     Value *V = LHS;
     unsigned MaskElems = Mask.size();
     auto *SrcTy = cast<FixedVectorType>(V->getType());
-    unsigned VecBitWidth = SrcTy->getPrimitiveSizeInBits().getFixedSize();
+    unsigned VecBitWidth = SrcTy->getPrimitiveSizeInBits().getFixedValue();
     unsigned SrcElemBitWidth = DL.getTypeSizeInBits(SrcTy->getElementType());
     assert(SrcElemBitWidth && "vector elements must have a bitwidth");
    unsigned SrcNumElems = SrcTy->getNumElements();
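The relaxed guard in foldBitcastExtElt above now lets a floating-point destination type take the shift-and-truncate path, finishing with a bitcast from the intermediate integer. A minimal little-endian sketch of the kind of IR this enables (hypothetical example, not taken from this commit's tests):

```llvm
; Before: extract the high float of an i64 viewed as <2 x float>.
define float @ext_hi_float(i64 %x) {
  %vec = bitcast i64 %x to <2 x float>
  %elt = extractelement <2 x float> %vec, i64 1
  ret float %elt
}

; After: shift out the low 32 bits, truncate, and bitcast the bits to float.
define float @ext_hi_float(i64 %x) {
  %extelt.offset = lshr i64 %x, 32
  %t = trunc i64 %extelt.offset to i32
  %elt = bitcast i32 %t to float
  ret float %elt
}
```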
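The new FoldOpIntoSelect call in visitExtractElementInst sinks a constant-index extract into a vector select with a scalar condition. FoldOpIntoSelect generally requires at least one select arm to be a constant, so the usual payoff is that one side constant-folds; roughly (hypothetical IR, not from the commit's tests):

```llvm
; Before: constant-index extract of a select where one arm is a constant vector.
define i32 @ext_of_select(i1 %c, <4 x i32> %a) {
  %sel = select i1 %c, <4 x i32> %a, <4 x i32> <i32 10, i32 11, i32 12, i32 13>
  %elt = extractelement <4 x i32> %sel, i64 2
  ret i32 %elt
}

; After: the extract is applied to each arm; the constant arm folds to 12.
define i32 @ext_of_select(i1 %c, <4 x i32> %a) {
  %elt.a = extractelement <4 x i32> %a, i64 2
  %elt = select i1 %c, i32 %elt.a, i32 12
  ret i32 %elt
}
```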
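The new block inside visitInsertElementInst canonicalizes a one-use pair of constant-index inserts so the smaller index is inserted first, as long as the earlier scalar is not a constant. A hypothetical example:

```llvm
; Before: the larger index is inserted first.
define <4 x i32> @ins_order(<4 x i32> %v, i32 %a, i32 %b) {
  %i3 = insertelement <4 x i32> %v, i32 %a, i64 3
  %i1 = insertelement <4 x i32> %i3, i32 %b, i64 1
  ret <4 x i32> %i1
}

; After: the chain is rebuilt in increasing index order.
define <4 x i32> @ins_order(<4 x i32> %v, i32 %a, i32 %b) {
  %i1 = insertelement <4 x i32> %v, i32 %b, i64 1
  %i3 = insertelement <4 x i32> %i1, i32 %a, i64 3
  ret <4 x i32> %i3
}
```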
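foldTruncInsEltPair, added above, merges the two truncated halves of a scalar inserted into an adjacent even/odd lane pair of an undef base vector into a single wider insert plus bitcasts. A little-endian sketch (hypothetical, not from the commit's tests):

```llvm
; Before: low and high i16 halves of an i32 inserted into lanes 0 and 1.
define <4 x i16> @ins_halves(i32 %x) {
  %lo = trunc i32 %x to i16
  %sh = lshr i32 %x, 16
  %hi = trunc i32 %sh to i16
  %v0 = insertelement <4 x i16> undef, i16 %lo, i64 0
  %v1 = insertelement <4 x i16> %v0, i16 %hi, i64 1
  ret <4 x i16> %v1
}

; After: one insert of the whole i32, with the vector rebitcast around it
; (the bitcast of the undef base vector folds away).
define <4 x i16> @ins_halves(i32 %x) {
  %ins = insertelement <2 x i32> undef, i32 %x, i64 0
  %v = bitcast <2 x i32> %ins to <4 x i16>
  ret <4 x i16> %v
}
```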
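foldSelectShuffleOfSelectShuffle collapses a select shuffle whose operand is another select shuffle sharing one input into a single select shuffle: lanes taken directly from the shared operand keep their mask value, and lanes taken from the inner shuffle inherit that shuffle's choice. A hypothetical four-lane example:

```llvm
; Before: %s2 selects between %x and the select shuffle %s1, which also uses %x.
define <4 x i32> @sel_of_sel(<4 x i32> %x, <4 x i32> %y) {
  %s1 = shufflevector <4 x i32> %x, <4 x i32> %y, <4 x i32> <i32 0, i32 5, i32 2, i32 7>
  %s2 = shufflevector <4 x i32> %x, <4 x i32> %s1, <4 x i32> <i32 0, i32 1, i32 6, i32 7>
  ret <4 x i32> %s2
}

; After: one select shuffle of %x and %y; lane 2 came from %x through %s1,
; lane 3 came from %y through %s1.
define <4 x i32> @sel_of_sel(<4 x i32> %x, <4 x i32> %y) {
  %s2 = shufflevector <4 x i32> %x, <4 x i32> %y, <4 x i32> <i32 0, i32 1, i32 2, i32 7>
  ret <4 x i32> %s2
}
```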
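simplifyBinOpSplats hoists a speculatable binary operator above a lane-0 splat when one of its operands is itself a lane-0 splat, so only one shuffle survives. A hypothetical example using add:

```llvm
; Before: splat lane 0 of an add whose first operand is already a splat of %v.
define <4 x i32> @splat_of_binop(<4 x i32> %v, <4 x i32> %w) {
  %splat = shufflevector <4 x i32> %v, <4 x i32> undef, <4 x i32> zeroinitializer
  %add = add <4 x i32> %splat, %w
  %res = shufflevector <4 x i32> %add, <4 x i32> undef, <4 x i32> zeroinitializer
  ret <4 x i32> %res
}

; After: add the original operands, then splat lane 0 of the result.
define <4 x i32> @splat_of_binop(<4 x i32> %v, <4 x i32> %w) {
  %add = add <4 x i32> %v, %w
  %res = shufflevector <4 x i32> %add, <4 x i32> undef, <4 x i32> zeroinitializer
  ret <4 x i32> %res
}
```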