author    Dimitry Andric <dim@FreeBSD.org>  2023-02-11 12:38:04 +0000
committer Dimitry Andric <dim@FreeBSD.org>  2023-02-11 12:38:11 +0000
commit    e3b557809604d036af6e00c60f012c2025b59a5e
tree      8a11ba2269a3b669601e2fd41145b174008f4da8 /llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
parent    08e8dd7b9db7bb4a9de26d44c1cbfd24e869c014
Diffstat (limited to 'llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp')
-rw-r--r--  llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp | 246
1 file changed, 215 insertions, 31 deletions
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
index b80c58183dd5..61e62adbe327 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
@@ -105,7 +105,7 @@ Instruction *InstCombinerImpl::scalarizePHI(ExtractElementInst &EI,
// 2) Possibly more ExtractElements with the same index.
// 3) Another operand, which will feed back into the PHI.
Instruction *PHIUser = nullptr;
- for (auto U : PN->users()) {
+ for (auto *U : PN->users()) {
if (ExtractElementInst *EU = dyn_cast<ExtractElementInst>(U)) {
if (EI.getIndexOperand() == EU->getIndexOperand())
Extracts.push_back(EU);
@@ -171,7 +171,7 @@ Instruction *InstCombinerImpl::scalarizePHI(ExtractElementInst &EI,
}
}
- for (auto E : Extracts)
+ for (auto *E : Extracts)
replaceInstUsesWith(*E, scalarPHI);
return &EI;
@@ -187,13 +187,12 @@ Instruction *InstCombinerImpl::foldBitcastExtElt(ExtractElementInst &Ext) {
ElementCount NumElts =
cast<VectorType>(Ext.getVectorOperandType())->getElementCount();
Type *DestTy = Ext.getType();
+ unsigned DestWidth = DestTy->getPrimitiveSizeInBits();
bool IsBigEndian = DL.isBigEndian();
// If we are casting an integer to a vector and extracting a portion, that is
// a shift-right and truncate.
- // TODO: Allow FP dest type by casting the trunc to FP?
- if (X->getType()->isIntegerTy() && DestTy->isIntegerTy() &&
- isDesirableIntType(X->getType()->getPrimitiveSizeInBits())) {
+ if (X->getType()->isIntegerTy()) {
assert(isa<FixedVectorType>(Ext.getVectorOperand()->getType()) &&
"Expected fixed vector type for bitcast from scalar integer");
@@ -202,10 +201,18 @@ Instruction *InstCombinerImpl::foldBitcastExtElt(ExtractElementInst &Ext) {
// BigEndian: extelt (bitcast i32 X to v4i8), 0 -> trunc i32 (X >> 24) to i8
if (IsBigEndian)
ExtIndexC = NumElts.getKnownMinValue() - 1 - ExtIndexC;
- unsigned ShiftAmountC = ExtIndexC * DestTy->getPrimitiveSizeInBits();
- if (!ShiftAmountC || Ext.getVectorOperand()->hasOneUse()) {
- Value *Lshr = Builder.CreateLShr(X, ShiftAmountC, "extelt.offset");
- return new TruncInst(Lshr, DestTy);
+ unsigned ShiftAmountC = ExtIndexC * DestWidth;
+ if (!ShiftAmountC ||
+ (isDesirableIntType(X->getType()->getPrimitiveSizeInBits()) &&
+ Ext.getVectorOperand()->hasOneUse())) {
+ if (ShiftAmountC)
+ X = Builder.CreateLShr(X, ShiftAmountC, "extelt.offset");
+ if (DestTy->isFloatingPointTy()) {
+ Type *DstIntTy = IntegerType::getIntNTy(X->getContext(), DestWidth);
+ Value *Trunc = Builder.CreateTrunc(X, DstIntTy);
+ return new BitCastInst(Trunc, DestTy);
+ }
+ return new TruncInst(X, DestTy);
}
}
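The equivalence behind this fold can be checked on the host. The following is a minimal C++ sketch, not part of the patch; it assumes a little-endian host and uses illustrative values. Extracting element I of an integer bitcast to a vector matches a right-shift by I * DestWidth followed by a truncate, and the new floating-point path is the same truncate followed by a bitcast (on big-endian targets the element index is mirrored first, as the code above does).

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

int main() {
  uint32_t X = 0xAABBCCDD;

  // extelt (bitcast i32 X to v4i8), I  ==  trunc i32 (X >> (I * 8)) to i8
  uint8_t Vec[4];
  std::memcpy(Vec, &X, sizeof(X)); // the "bitcast" to <4 x i8>
  for (unsigned I = 0; I != 4; ++I)
    assert(Vec[I] == (uint8_t)(X >> (I * 8)));

  // The new FP path: extelt (bitcast i64 X64 to v2f32), 0
  //   ==  bitcast (trunc i64 X64 to i32) to float
  uint64_t X64 = 0x3F80000040000000ULL;
  float FVec[2];
  std::memcpy(FVec, &X64, sizeof(X64)); // "bitcast" to <2 x float>
  uint32_t Lo = (uint32_t)X64;          // trunc (shift amount 0 for element 0)
  float F;
  std::memcpy(&F, &Lo, sizeof(F));      // bitcast i32 -> float
  assert(std::memcmp(&FVec[0], &F, sizeof(F)) == 0);
  return 0;
}
```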
@@ -278,7 +285,6 @@ Instruction *InstCombinerImpl::foldBitcastExtElt(ExtractElementInst &Ext) {
return nullptr;
unsigned SrcWidth = SrcTy->getScalarSizeInBits();
- unsigned DestWidth = DestTy->getPrimitiveSizeInBits();
unsigned ShAmt = Chunk * DestWidth;
// TODO: This limitation is more strict than necessary. We could sum the
@@ -393,6 +399,20 @@ Instruction *InstCombinerImpl::visitExtractElementInst(ExtractElementInst &EI) {
SQ.getWithInstruction(&EI)))
return replaceInstUsesWith(EI, V);
+ // extractelt (select %x, %vec1, %vec2), %const ->
+ // select %x, %vec1[%const], %vec2[%const]
+ // TODO: Support constant folding of multiple select operands:
+ // extractelt (select %x, %vec1, %vec2), (select %x, %c1, %c2)
+ // If the extractelement would, for instance, do an out-of-bounds access
+ // because of the values of %c1 and/or %c2, the sequence could be optimized
+ // early. This is currently not possible because constant folding will reach
+ // an unreachable assertion if it doesn't find a constant operand.
+ if (SelectInst *SI = dyn_cast<SelectInst>(EI.getVectorOperand()))
+ if (SI->getCondition()->getType()->isIntegerTy() &&
+ isa<Constant>(EI.getIndexOperand()))
+ if (Instruction *R = FoldOpIntoSelect(EI, SI))
+ return R;
+
// If extracting a specified index from the vector, see if we can recursively
// find a previously computed scalar that was inserted into the vector.
auto *IndexC = dyn_cast<ConstantInt>(Index);
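A minimal sketch of the fold described in the comment above, using plain C++ arrays as a stand-in for vectors; this is not part of the patch, and names and values are illustrative:

```cpp
#include <array>
#include <cassert>

int main() {
  std::array<int, 4> V1{1, 2, 3, 4}, V2{5, 6, 7, 8};
  bool Cond = true;            // the scalar (i1) select condition
  constexpr unsigned Idx = 2;  // the constant extract index

  // extractelt (select %x, %vec1, %vec2), 2
  int Before = (Cond ? V1 : V2)[Idx];
  // select %x, %vec1[2], %vec2[2]
  int After = Cond ? V1[Idx] : V2[Idx];
  assert(Before == After);
  return 0;
}
```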
@@ -850,17 +870,16 @@ Instruction *InstCombinerImpl::foldAggregateConstructionIntoAggregateReuse(
if (NumAggElts > 2)
return nullptr;
- static constexpr auto NotFound = None;
+ static constexpr auto NotFound = std::nullopt;
static constexpr auto FoundMismatch = nullptr;
// Try to find a value of each element of an aggregate.
// FIXME: deal with more complex, not one-dimensional, aggregate types
- SmallVector<Optional<Instruction *>, 2> AggElts(NumAggElts, NotFound);
+ SmallVector<std::optional<Instruction *>, 2> AggElts(NumAggElts, NotFound);
// Do we know values for each element of the aggregate?
auto KnowAllElts = [&AggElts]() {
- return all_of(AggElts,
- [](Optional<Instruction *> Elt) { return Elt != NotFound; });
+ return !llvm::is_contained(AggElts, NotFound);
};
int Depth = 0;
@@ -889,7 +908,7 @@ Instruction *InstCombinerImpl::foldAggregateConstructionIntoAggregateReuse(
// Now, we may have already previously recorded the value for this element
// of an aggregate. If we did, that means the CurrIVI will later be
// overwritten with the already-recorded value. But if not, let's record it!
- Optional<Instruction *> &Elt = AggElts[Indices.front()];
+ std::optional<Instruction *> &Elt = AggElts[Indices.front()];
Elt = Elt.value_or(InsertedValue);
// FIXME: should we handle chain-terminating undef base operand?
@@ -919,7 +938,7 @@ Instruction *InstCombinerImpl::foldAggregateConstructionIntoAggregateReuse(
/// or different elements had different source aggregates.
FoundMismatch
};
- auto Describe = [](Optional<Value *> SourceAggregate) {
+ auto Describe = [](std::optional<Value *> SourceAggregate) {
if (SourceAggregate == NotFound)
return AggregateDescription::NotFound;
if (*SourceAggregate == FoundMismatch)
@@ -933,8 +952,8 @@ Instruction *InstCombinerImpl::foldAggregateConstructionIntoAggregateReuse(
// If found, return the source aggregate from which the extraction was.
// If \p PredBB is provided, does PHI translation of an \p Elt first.
auto FindSourceAggregate =
- [&](Instruction *Elt, unsigned EltIdx, Optional<BasicBlock *> UseBB,
- Optional<BasicBlock *> PredBB) -> Optional<Value *> {
+ [&](Instruction *Elt, unsigned EltIdx, std::optional<BasicBlock *> UseBB,
+ std::optional<BasicBlock *> PredBB) -> std::optional<Value *> {
// For now(?), only deal with, at most, a single level of PHI indirection.
if (UseBB && PredBB)
Elt = dyn_cast<Instruction>(Elt->DoPHITranslation(*UseBB, *PredBB));
@@ -961,9 +980,9 @@ Instruction *InstCombinerImpl::foldAggregateConstructionIntoAggregateReuse(
// see if we can find appropriate source aggregate for each of the elements,
// and see it's the same aggregate for each element. If so, return it.
auto FindCommonSourceAggregate =
- [&](Optional<BasicBlock *> UseBB,
- Optional<BasicBlock *> PredBB) -> Optional<Value *> {
- Optional<Value *> SourceAggregate;
+ [&](std::optional<BasicBlock *> UseBB,
+ std::optional<BasicBlock *> PredBB) -> std::optional<Value *> {
+ std::optional<Value *> SourceAggregate;
for (auto I : enumerate(AggElts)) {
assert(Describe(SourceAggregate) != AggregateDescription::FoundMismatch &&
@@ -975,7 +994,7 @@ Instruction *InstCombinerImpl::foldAggregateConstructionIntoAggregateReuse(
// For this element, is there a plausible source aggregate?
// FIXME: we could special-case undef element, IFF we know that in the
// source aggregate said element isn't poison.
- Optional<Value *> SourceAggregateForElement =
+ std::optional<Value *> SourceAggregateForElement =
FindSourceAggregate(*I.value(), I.index(), UseBB, PredBB);
// Okay, what have we found? Does that correlate with previous findings?
@@ -1009,10 +1028,11 @@ Instruction *InstCombinerImpl::foldAggregateConstructionIntoAggregateReuse(
return *SourceAggregate;
};
- Optional<Value *> SourceAggregate;
+ std::optional<Value *> SourceAggregate;
// Can we find the source aggregate without looking at predecessors?
- SourceAggregate = FindCommonSourceAggregate(/*UseBB=*/None, /*PredBB=*/None);
+ SourceAggregate = FindCommonSourceAggregate(/*UseBB=*/std::nullopt,
+ /*PredBB=*/std::nullopt);
if (Describe(SourceAggregate) != AggregateDescription::NotFound) {
if (Describe(SourceAggregate) == AggregateDescription::FoundMismatch)
return nullptr; // Conflicting source aggregates!
@@ -1029,7 +1049,7 @@ Instruction *InstCombinerImpl::foldAggregateConstructionIntoAggregateReuse(
// they all should be defined in the same basic block.
BasicBlock *UseBB = nullptr;
- for (const Optional<Instruction *> &I : AggElts) {
+ for (const std::optional<Instruction *> &I : AggElts) {
BasicBlock *BB = (*I)->getParent();
// If it's the first instruction we've encountered, record the basic block.
if (!UseBB) {
@@ -1495,6 +1515,71 @@ static Instruction *narrowInsElt(InsertElementInst &InsElt,
return CastInst::Create(CastOpcode, NewInsElt, InsElt.getType());
}
+/// If we are inserting 2 halves of a value into adjacent elements of a vector,
+/// try to convert to a single insert with appropriate bitcasts.
+static Instruction *foldTruncInsEltPair(InsertElementInst &InsElt,
+ bool IsBigEndian,
+ InstCombiner::BuilderTy &Builder) {
+ Value *VecOp = InsElt.getOperand(0);
+ Value *ScalarOp = InsElt.getOperand(1);
+ Value *IndexOp = InsElt.getOperand(2);
+
+ // The pattern depends on endianness because we expect the lower index to be
+ // inserted first.
+ // Big endian:
+ // inselt (inselt BaseVec, (trunc (lshr X, BW/2)), Index0), (trunc X), Index1
+ // Little endian:
+ // inselt (inselt BaseVec, (trunc X), Index0), (trunc (lshr X, BW/2)), Index1
+ // Note: It is not safe to do this transform with an arbitrary base vector
+ // because the bitcast of that vector to fewer/larger elements could
+ // allow poison to spill into an element that was not poison before.
+ // TODO: Detect smaller fractions of the scalar.
+ // TODO: One-use checks are conservative.
+ auto *VTy = dyn_cast<FixedVectorType>(InsElt.getType());
+ Value *Scalar0, *BaseVec;
+ uint64_t Index0, Index1;
+ if (!VTy || (VTy->getNumElements() & 1) ||
+ !match(IndexOp, m_ConstantInt(Index1)) ||
+ !match(VecOp, m_InsertElt(m_Value(BaseVec), m_Value(Scalar0),
+ m_ConstantInt(Index0))) ||
+ !match(BaseVec, m_Undef()))
+ return nullptr;
+
+ // The first insert must be to the index one less than this one, and
+ // the first insert must be to an even index.
+ if (Index0 + 1 != Index1 || Index0 & 1)
+ return nullptr;
+
+ // For big endian, the high half of the value should be inserted first.
+ // For little endian, the low half of the value should be inserted first.
+ Value *X;
+ uint64_t ShAmt;
+ if (IsBigEndian) {
+ if (!match(ScalarOp, m_Trunc(m_Value(X))) ||
+ !match(Scalar0, m_Trunc(m_LShr(m_Specific(X), m_ConstantInt(ShAmt)))))
+ return nullptr;
+ } else {
+ if (!match(Scalar0, m_Trunc(m_Value(X))) ||
+ !match(ScalarOp, m_Trunc(m_LShr(m_Specific(X), m_ConstantInt(ShAmt)))))
+ return nullptr;
+ }
+
+ Type *SrcTy = X->getType();
+ unsigned ScalarWidth = SrcTy->getScalarSizeInBits();
+ unsigned VecEltWidth = VTy->getScalarSizeInBits();
+ if (ScalarWidth != VecEltWidth * 2 || ShAmt != VecEltWidth)
+ return nullptr;
+
+ // Bitcast the base vector to a vector type with the source element type.
+ Type *CastTy = FixedVectorType::get(SrcTy, VTy->getNumElements() / 2);
+ Value *CastBaseVec = Builder.CreateBitCast(BaseVec, CastTy);
+
+ // Scale the insert index for a vector with half as many elements.
+ // bitcast (inselt (bitcast BaseVec), X, NewIndex)
+ uint64_t NewIndex = IsBigEndian ? Index1 / 2 : Index0 / 2;
+ Value *NewInsert = Builder.CreateInsertElement(CastBaseVec, X, NewIndex);
+ return new BitCastInst(NewInsert, VTy);
+}
+
Instruction *InstCombinerImpl::visitInsertElementInst(InsertElementInst &IE) {
Value *VecOp = IE.getOperand(0);
Value *ScalarOp = IE.getOperand(1);
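The semantics that the new foldTruncInsEltPair relies on can be modeled on the host. The sketch below is not part of the patch; it assumes a little-endian host and uses a zeroed base array as a concrete stand-in for the undef base vector that the transform requires:

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

int main() {
  uint32_t X = 0xDEADBEEF;

  // Unfolded little-endian pattern:
  // inselt (inselt Base, (trunc X), 2), (trunc (lshr X, 16)), 3
  uint16_t Lanes[4] = {};
  Lanes[2] = (uint16_t)X;         // low half to the even index
  Lanes[3] = (uint16_t)(X >> 16); // high half to the next index

  // Folded form: bitcast <4 x i16> Base to <2 x i32>, insert X at
  // NewIndex = Index0 / 2 == 1, then bitcast back.
  uint32_t Wide[2] = {};
  Wide[1] = X;
  uint16_t Folded[4];
  std::memcpy(Folded, Wide, sizeof(Wide));

  assert(std::memcmp(Lanes, Folded, sizeof(Lanes)) == 0);
  return 0;
}
```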
@@ -1505,10 +1590,22 @@ Instruction *InstCombinerImpl::visitInsertElementInst(InsertElementInst &IE) {
return replaceInstUsesWith(IE, V);
// Canonicalize type of constant indices to i64 to simplify CSE
- if (auto *IndexC = dyn_cast<ConstantInt>(IdxOp))
+ if (auto *IndexC = dyn_cast<ConstantInt>(IdxOp)) {
if (auto *NewIdx = getPreferredVectorIndex(IndexC))
return replaceOperand(IE, 2, NewIdx);
+ Value *BaseVec, *OtherScalar;
+ uint64_t OtherIndexVal;
+ if (match(VecOp, m_OneUse(m_InsertElt(m_Value(BaseVec),
+ m_Value(OtherScalar),
+ m_ConstantInt(OtherIndexVal)))) &&
+ !isa<Constant>(OtherScalar) && OtherIndexVal > IndexC->getZExtValue()) {
+ Value *NewIns = Builder.CreateInsertElement(BaseVec, ScalarOp, IdxOp);
+ return InsertElementInst::Create(NewIns, OtherScalar,
+ Builder.getInt64(OtherIndexVal));
+ }
+ }
+
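The reordering above is sound because insertelement ops with distinct constant indices commute; only the relative order of inserts to the *same* index matters. A small sketch, not part of the patch, with illustrative values:

```cpp
#include <array>
#include <cassert>

int main() {
  std::array<int, 4> Base{0, 0, 0, 0};
  int A = 7, B = 9;

  auto V1 = Base; V1[3] = B; V1[1] = A; // insert at 3, then at 1
  auto V2 = Base; V2[1] = A; V2[3] = B; // canonical: lower index first
  assert(V1 == V2);
  return 0;
}
```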
// If the scalar is bitcast and inserted into undef, do the insert in the
// source type followed by bitcast.
// TODO: Generalize for insert into any constant, not just undef?
@@ -1622,6 +1719,9 @@ Instruction *InstCombinerImpl::visitInsertElementInst(InsertElementInst &IE) {
if (Instruction *Ext = narrowInsElt(IE, Builder))
return Ext;
+ if (Instruction *Ext = foldTruncInsEltPair(IE, DL.isBigEndian(), Builder))
+ return Ext;
+
return nullptr;
}
@@ -1653,7 +1753,7 @@ static bool canEvaluateShuffled(Value *V, ArrayRef<int> Mask,
// from an undefined element in an operand.
if (llvm::is_contained(Mask, -1))
return false;
- LLVM_FALLTHROUGH;
+ [[fallthrough]];
case Instruction::Add:
case Instruction::FAdd:
case Instruction::Sub:
@@ -1700,8 +1800,8 @@ static bool canEvaluateShuffled(Value *V, ArrayRef<int> Mask,
// Verify that 'CI' does not occur twice in Mask. A single 'insertelement'
// can't put an element into multiple indices.
bool SeenOnce = false;
- for (int i = 0, e = Mask.size(); i != e; ++i) {
- if (Mask[i] == ElementNumber) {
+ for (int I : Mask) {
+ if (I == ElementNumber) {
if (SeenOnce)
return false;
SeenOnce = true;
@@ -1957,6 +2057,56 @@ static BinopElts getAlternateBinop(BinaryOperator *BO, const DataLayout &DL) {
return {};
}
+/// A select shuffle of a select shuffle with a shared operand can be reduced
+/// to a single select shuffle. This is an obvious improvement in IR, and the
+/// backend is expected to lower select shuffles efficiently.
+static Instruction *foldSelectShuffleOfSelectShuffle(ShuffleVectorInst &Shuf) {
+ assert(Shuf.isSelect() && "Must have select-equivalent shuffle");
+
+ Value *Op0 = Shuf.getOperand(0), *Op1 = Shuf.getOperand(1);
+ SmallVector<int, 16> Mask;
+ Shuf.getShuffleMask(Mask);
+ unsigned NumElts = Mask.size();
+
+ // Canonicalize a select shuffle with common operand as Op1.
+ auto *ShufOp = dyn_cast<ShuffleVectorInst>(Op0);
+ if (ShufOp && ShufOp->isSelect() &&
+ (ShufOp->getOperand(0) == Op1 || ShufOp->getOperand(1) == Op1)) {
+ std::swap(Op0, Op1);
+ ShuffleVectorInst::commuteShuffleMask(Mask, NumElts);
+ }
+
+ ShufOp = dyn_cast<ShuffleVectorInst>(Op1);
+ if (!ShufOp || !ShufOp->isSelect() ||
+ (ShufOp->getOperand(0) != Op0 && ShufOp->getOperand(1) != Op0))
+ return nullptr;
+
+ Value *X = ShufOp->getOperand(0), *Y = ShufOp->getOperand(1);
+ SmallVector<int, 16> Mask1;
+ ShufOp->getShuffleMask(Mask1);
+ assert(Mask1.size() == NumElts && "Vector size changed with select shuffle");
+
+ // Canonicalize common operand (Op0) as X (first operand of first shuffle).
+ if (Y == Op0) {
+ std::swap(X, Y);
+ ShuffleVectorInst::commuteShuffleMask(Mask1, NumElts);
+ }
+
+ // If the mask chooses from X (operand 0), it stays the same.
+ // If the mask chooses from the earlier shuffle, the other mask value is
+ // transferred to the combined select shuffle:
+ // shuf X, (shuf X, Y, M1), M --> shuf X, Y, M'
+ SmallVector<int, 16> NewMask(NumElts);
+ for (unsigned i = 0; i != NumElts; ++i)
+ NewMask[i] = Mask[i] < (signed)NumElts ? Mask[i] : Mask1[i];
+
+ // A select mask with undef elements might look like an identity mask.
+ assert((ShuffleVectorInst::isSelectMask(NewMask) ||
+ ShuffleVectorInst::isIdentityMask(NewMask)) &&
+ "Unexpected shuffle mask");
+ return new ShuffleVectorInst(X, Y, NewMask);
+}
+
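A host-level model of foldSelectShuffleOfSelectShuffle, treating select shuffles as lane-wise blends; this is not part of the patch, and the masks and values are illustrative. It checks the combined-mask rule NewMask[i] = Mask[i] < NumElts ? Mask[i] : Mask1[i] used above:

```cpp
#include <array>
#include <cassert>

int main() {
  constexpr int NumElts = 4;
  std::array<int, NumElts> X{10, 11, 12, 13}, Y{20, 21, 22, 23};
  // Select masks: lane i is either i (pick op0) or i + NumElts (pick op1).
  std::array<int, NumElts> Mask1{4, 1, 2, 7}; // inner: shuf X, Y, M1
  std::array<int, NumElts> Mask{4, 1, 6, 3};  // outer: shuf X, (inner), M

  auto Blend = [](const std::array<int, NumElts> &A,
                  const std::array<int, NumElts> &B,
                  const std::array<int, NumElts> &M) {
    std::array<int, NumElts> R{};
    for (int i = 0; i != NumElts; ++i)
      R[i] = M[i] < NumElts ? A[M[i]] : B[M[i] - NumElts];
    return R;
  };

  auto Inner = Blend(X, Y, Mask1);
  auto Before = Blend(X, Inner, Mask); // shuf X, (shuf X, Y, M1), M

  std::array<int, NumElts> NewMask{};
  for (int i = 0; i != NumElts; ++i)
    NewMask[i] = Mask[i] < NumElts ? Mask[i] : Mask1[i];
  auto After = Blend(X, Y, NewMask);   // shuf X, Y, M'

  assert(Before == After);
  return 0;
}
```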
static Instruction *foldSelectShuffleWith1Binop(ShuffleVectorInst &Shuf) {
assert(Shuf.isSelect() && "Must have select-equivalent shuffle");
@@ -2061,6 +2211,9 @@ Instruction *InstCombinerImpl::foldSelectShuffle(ShuffleVectorInst &Shuf) {
return &Shuf;
}
+ if (Instruction *I = foldSelectShuffleOfSelectShuffle(Shuf))
+ return I;
+
if (Instruction *I = foldSelectShuffleWith1Binop(Shuf))
return I;
@@ -2541,6 +2694,35 @@ static Instruction *foldIdentityPaddedShuffles(ShuffleVectorInst &Shuf) {
return new ShuffleVectorInst(X, Y, NewMask);
}
+// Splatting the first element of the result of a BinOp, where any of the
+// BinOp's operands are the result of a first-element splat, can be simplified
+// to splatting the first element of the BinOp's result.
+Instruction *InstCombinerImpl::simplifyBinOpSplats(ShuffleVectorInst &SVI) {
+ if (!match(SVI.getOperand(1), m_Undef()) ||
+ !match(SVI.getShuffleMask(), m_ZeroMask()))
+ return nullptr;
+
+ Value *Op0 = SVI.getOperand(0);
+ Value *X, *Y;
+ if (!match(Op0, m_BinOp(m_Shuffle(m_Value(X), m_Undef(), m_ZeroMask()),
+ m_Value(Y))) &&
+ !match(Op0, m_BinOp(m_Value(X),
+ m_Shuffle(m_Value(Y), m_Undef(), m_ZeroMask()))))
+ return nullptr;
+ if (X->getType() != Y->getType())
+ return nullptr;
+
+ auto *BinOp = cast<BinaryOperator>(Op0);
+ if (!isSafeToSpeculativelyExecute(BinOp))
+ return nullptr;
+
+ Value *NewBO = Builder.CreateBinOp(BinOp->getOpcode(), X, Y);
+ if (auto *NewBOI = dyn_cast<Instruction>(NewBO))
+ NewBOI->copyIRFlags(BinOp);
+
+ return new ShuffleVectorInst(NewBO, SVI.getShuffleMask());
+}
+
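A sketch of the equivalence simplifyBinOpSplats relies on; not part of the patch. It uses integer add (safe to speculate over all lanes, unlike a possibly-trapping divide, which the isSafeToSpeculativelyExecute check above excludes) and illustrative values. Every lane of the splat operand already holds X[0], so splatting lane 0 of the result is unchanged by doing the binop on the unsplatted operand:

```cpp
#include <array>
#include <cassert>

int main() {
  constexpr int N = 4;
  std::array<int, N> X{3, 8, 1, 6}, Y{10, 20, 30, 40};

  auto Splat0 = [](std::array<int, N> V) {
    V.fill(V[0]); // broadcast lane 0 (a zero shuffle mask)
    return V;
  };

  std::array<int, N> BinOp{};
  auto SX = Splat0(X);
  for (int i = 0; i != N; ++i)
    BinOp[i] = SX[i] + Y[i];   // BinOp(splat0(X), Y)
  auto Before = Splat0(BinOp); // splat lane 0 of the result

  std::array<int, N> NewBO{};
  for (int i = 0; i != N; ++i)
    NewBO[i] = X[i] + Y[i];    // BinOp(X, Y)
  auto After = Splat0(NewBO);  // splat lane 0

  assert(Before == After);
  return 0;
}
```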
Instruction *InstCombinerImpl::visitShuffleVectorInst(ShuffleVectorInst &SVI) {
Value *LHS = SVI.getOperand(0);
Value *RHS = SVI.getOperand(1);
@@ -2549,7 +2731,9 @@ Instruction *InstCombinerImpl::visitShuffleVectorInst(ShuffleVectorInst &SVI) {
SVI.getType(), ShufQuery))
return replaceInstUsesWith(SVI, V);
- // Bail out for scalable vectors
+ if (Instruction *I = simplifyBinOpSplats(SVI))
+ return I;
+
if (isa<ScalableVectorType>(LHS->getType()))
return nullptr;
@@ -2694,7 +2878,7 @@ Instruction *InstCombinerImpl::visitShuffleVectorInst(ShuffleVectorInst &SVI) {
Value *V = LHS;
unsigned MaskElems = Mask.size();
auto *SrcTy = cast<FixedVectorType>(V->getType());
- unsigned VecBitWidth = SrcTy->getPrimitiveSizeInBits().getFixedSize();
+ unsigned VecBitWidth = SrcTy->getPrimitiveSizeInBits().getFixedValue();
unsigned SrcElemBitWidth = DL.getTypeSizeInBits(SrcTy->getElementType());
assert(SrcElemBitWidth && "vector elements must have a bitwidth");
unsigned SrcNumElems = SrcTy->getNumElements();