Diffstat (limited to 'llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp')
-rw-r--r--  llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp | 192
1 file changed, 123 insertions(+), 69 deletions(-)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
index 77d675422966..00eece9534b0 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
@@ -168,7 +168,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
// If the high-bits of an ADD/SUB/MUL are not demanded, then we do not care
// about the high bits of the operands.
auto simplifyOperandsBasedOnUnusedHighBits = [&](APInt &DemandedFromOps) {
- unsigned NLZ = DemandedMask.countLeadingZeros();
+ unsigned NLZ = DemandedMask.countl_zero();
// Right fill the mask of bits for the operands to demand the most
// significant bit and all those below it.
DemandedFromOps = APInt::getLowBitsSet(BitWidth, BitWidth - NLZ);
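// Illustrative sketch (annotation, not part of the original patch): with
// BitWidth = 32 and DemandedMask = 0x00FF0000, NLZ = 8, so DemandedFromOps
// becomes 0x00FFFFFF -- the most significant demanded bit and all bits below.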
@@ -195,7 +195,8 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
assert(!RHSKnown.hasConflict() && "Bits known to be one AND zero?");
assert(!LHSKnown.hasConflict() && "Bits known to be one AND zero?");
- Known = LHSKnown & RHSKnown;
+ Known = analyzeKnownBitsFromAndXorOr(cast<Operator>(I), LHSKnown, RHSKnown,
+ Depth, DL, &AC, CxtI, &DT);
// If the client is only demanding bits that we know, return the known
// constant.
@@ -224,7 +225,8 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
assert(!RHSKnown.hasConflict() && "Bits known to be one AND zero?");
assert(!LHSKnown.hasConflict() && "Bits known to be one AND zero?");
- Known = LHSKnown | RHSKnown;
+ Known = analyzeKnownBitsFromAndXorOr(cast<Operator>(I), LHSKnown, RHSKnown,
+ Depth, DL, &AC, CxtI, &DT);
// If the client is only demanding bits that we know, return the known
// constant.
@@ -262,7 +264,8 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
assert(!RHSKnown.hasConflict() && "Bits known to be one AND zero?");
assert(!LHSKnown.hasConflict() && "Bits known to be one AND zero?");
- Known = LHSKnown ^ RHSKnown;
+ Known = analyzeKnownBitsFromAndXorOr(cast<Operator>(I), LHSKnown, RHSKnown,
+ Depth, DL, &AC, CxtI, &DT);
// If the client is only demanding bits that we know, return the known
// constant.
@@ -381,7 +384,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
return I;
// Only known if known in both the LHS and RHS.
- Known = KnownBits::commonBits(LHSKnown, RHSKnown);
+ Known = LHSKnown.intersectWith(RHSKnown);
break;
}
case Instruction::Trunc: {
@@ -393,7 +396,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
// The shift amount must be valid (not poison) in the narrow type, and
// it must not be greater than the high bits demanded of the result.
if (C->ult(VTy->getScalarSizeInBits()) &&
- C->ule(DemandedMask.countLeadingZeros())) {
+ C->ule(DemandedMask.countl_zero())) {
// trunc (lshr X, C) --> lshr (trunc X), C
IRBuilderBase::InsertPointGuard Guard(Builder);
Builder.SetInsertPoint(I);
@@ -508,7 +511,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
// Right fill the mask of bits for the operands to demand the most
// significant bit and all those below it.
- unsigned NLZ = DemandedMask.countLeadingZeros();
+ unsigned NLZ = DemandedMask.countl_zero();
APInt DemandedFromOps = APInt::getLowBitsSet(BitWidth, BitWidth - NLZ);
if (ShrinkDemandedConstant(I, 1, DemandedFromOps) ||
SimplifyDemandedBits(I, 1, DemandedFromOps, RHSKnown, Depth + 1))
@@ -517,7 +520,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
// If low order bits are not demanded and known to be zero in one operand,
// then we don't need to demand them from the other operand, since they
// can't cause overflow into any bits that are demanded in the result.
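// Sketch (annotation, not part of the original patch): with DemandedMask =
// 0xF0 and RHSKnown.Zero covering bits 0-3, NTZ = 4, so the low four bits
// need not be demanded from the LHS; adding known zeros there cannot carry
// into the demanded bits 4-7.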
- unsigned NTZ = (~DemandedMask & RHSKnown.Zero).countTrailingOnes();
+ unsigned NTZ = (~DemandedMask & RHSKnown.Zero).countr_one();
APInt DemandedFromLHS = DemandedFromOps;
DemandedFromLHS.clearLowBits(NTZ);
if (ShrinkDemandedConstant(I, 0, DemandedFromLHS) ||
@@ -539,7 +542,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
case Instruction::Sub: {
// Right fill the mask of bits for the operands to demand the most
// significant bit and all those below it.
- unsigned NLZ = DemandedMask.countLeadingZeros();
+ unsigned NLZ = DemandedMask.countl_zero();
APInt DemandedFromOps = APInt::getLowBitsSet(BitWidth, BitWidth - NLZ);
if (ShrinkDemandedConstant(I, 1, DemandedFromOps) ||
SimplifyDemandedBits(I, 1, DemandedFromOps, RHSKnown, Depth + 1))
@@ -548,7 +551,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
// If low order bits are not demanded and are known to be zero in RHS,
// then we don't need to demand them from LHS, since they can't cause a
// borrow from any bits that are demanded in the result.
- unsigned NTZ = (~DemandedMask & RHSKnown.Zero).countTrailingOnes();
+ unsigned NTZ = (~DemandedMask & RHSKnown.Zero).countr_one();
APInt DemandedFromLHS = DemandedFromOps;
DemandedFromLHS.clearLowBits(NTZ);
if (ShrinkDemandedConstant(I, 0, DemandedFromLHS) ||
@@ -578,10 +581,9 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
// The LSB of X*Y is set only if (X & 1) == 1 and (Y & 1) == 1.
// If we demand exactly one bit N and we have "X * (C' << N)" where C' is
// odd (has LSB set), then the left-shifted low bit of X is the answer.
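// Worked example (annotation, not part of the original patch): with
// BitWidth = 8, DemandedMask = 0b100 (CTZ = 2) and C = 20 = 0b10100
// (countr_zero() == 2), C >> 2 is odd, so bit 2 of X * C equals bit 0 of X,
// which is exactly bit 2 of X << 2.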
- unsigned CTZ = DemandedMask.countTrailingZeros();
+ unsigned CTZ = DemandedMask.countr_zero();
const APInt *C;
- if (match(I->getOperand(1), m_APInt(C)) &&
- C->countTrailingZeros() == CTZ) {
+ if (match(I->getOperand(1), m_APInt(C)) && C->countr_zero() == CTZ) {
Constant *ShiftC = ConstantInt::get(VTy, CTZ);
Instruction *Shl = BinaryOperator::CreateShl(I->getOperand(0), ShiftC);
return InsertNewInstWith(Shl, *I);
@@ -619,7 +621,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
uint64_t ShiftAmt = SA->getLimitedValue(BitWidth-1);
Value *X;
Constant *C;
- if (DemandedMask.countTrailingZeros() >= ShiftAmt &&
+ if (DemandedMask.countr_zero() >= ShiftAmt &&
match(I->getOperand(0), m_LShr(m_ImmConstant(C), m_Value(X)))) {
Constant *LeftShiftAmtC = ConstantInt::get(VTy, ShiftAmt);
Constant *NewC = ConstantExpr::getShl(C, LeftShiftAmtC);
@@ -642,29 +644,15 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
return I;
assert(!Known.hasConflict() && "Bits known to be one AND zero?");
- bool SignBitZero = Known.Zero.isSignBitSet();
- bool SignBitOne = Known.One.isSignBitSet();
- Known.Zero <<= ShiftAmt;
- Known.One <<= ShiftAmt;
- // low bits known zero.
- if (ShiftAmt)
- Known.Zero.setLowBits(ShiftAmt);
-
- // If this shift has "nsw" keyword, then the result is either a poison
- // value or has the same sign bit as the first operand.
- if (IOp->hasNoSignedWrap()) {
- if (SignBitZero)
- Known.Zero.setSignBit();
- else if (SignBitOne)
- Known.One.setSignBit();
- if (Known.hasConflict())
- return UndefValue::get(VTy);
- }
+ Known = KnownBits::shl(Known,
+ KnownBits::makeConstant(APInt(BitWidth, ShiftAmt)),
+ /* NUW */ IOp->hasNoUnsignedWrap(),
+ /* NSW */ IOp->hasNoSignedWrap());
} else {
// This is a variable shift, so we can't shift the demand mask by a known
// amount. But if we are not demanding high bits, then we are not
// demanding those bits from the pre-shifted operand either.
- if (unsigned CTLZ = DemandedMask.countLeadingZeros()) {
+ if (unsigned CTLZ = DemandedMask.countl_zero()) {
APInt DemandedFromOp(APInt::getLowBitsSet(BitWidth, BitWidth - CTLZ));
if (SimplifyDemandedBits(I, 0, DemandedFromOp, Known, Depth + 1)) {
// We can't guarantee that nsw/nuw hold after simplifying the operand.
@@ -683,11 +671,10 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
// If we are just demanding the shifted sign bit and below, then this can
// be treated as an ASHR in disguise.
- if (DemandedMask.countLeadingZeros() >= ShiftAmt) {
+ if (DemandedMask.countl_zero() >= ShiftAmt) {
// If we only want bits that already match the signbit then we don't
// need to shift.
- unsigned NumHiDemandedBits =
- BitWidth - DemandedMask.countTrailingZeros();
+ unsigned NumHiDemandedBits = BitWidth - DemandedMask.countr_zero();
unsigned SignBits =
ComputeNumSignBits(I->getOperand(0), Depth + 1, CxtI);
if (SignBits >= NumHiDemandedBits)
@@ -734,7 +721,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
// If we only want bits that already match the signbit then we don't need
// to shift.
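// E.g. (annotation, not part of the original patch): DemandedMask = 0xE0 on
// i8 gives NumHiDemandedBits = 3; if the operand already has at least 3 sign
// bits, every demanded bit is a copy of the sign bit before and after the
// shift, so the operand can be used directly.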
- unsigned NumHiDemandedBits = BitWidth - DemandedMask.countTrailingZeros();
+ unsigned NumHiDemandedBits = BitWidth - DemandedMask.countr_zero();
if (SignBits >= NumHiDemandedBits)
return I->getOperand(0);
@@ -757,7 +744,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
APInt DemandedMaskIn(DemandedMask.shl(ShiftAmt));
// If any of the high bits are demanded, we should set the sign bit as
// demanded.
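// E.g. (annotation, not part of the original patch): for ashr X, 2 on i8 with
// DemandedMask = 0xC0, bits 6 and 7 of the result are both copies of X's sign
// bit, so the sign bit must be demanded from X even though the plain shl of
// the mask discards it.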
- if (DemandedMask.countLeadingZeros() <= ShiftAmt)
+ if (DemandedMask.countl_zero() <= ShiftAmt)
DemandedMaskIn.setSignBit();
// If the shift is exact, then it does demand the low bits (and knows that
@@ -797,7 +784,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
const APInt *SA;
if (match(I->getOperand(1), m_APInt(SA))) {
// TODO: Take the demanded mask of the result into account.
- unsigned RHSTrailingZeros = SA->countTrailingZeros();
+ unsigned RHSTrailingZeros = SA->countr_zero();
APInt DemandedMaskIn =
APInt::getHighBitsSet(BitWidth, BitWidth - RHSTrailingZeros);
if (SimplifyDemandedBits(I, 0, DemandedMaskIn, LHSKnown, Depth + 1)) {
@@ -807,9 +794,8 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
return I;
}
- // Increase high zero bits from the input.
- Known.Zero.setHighBits(std::min(
- BitWidth, LHSKnown.Zero.countLeadingOnes() + RHSTrailingZeros));
+ Known = KnownBits::udiv(LHSKnown, KnownBits::makeConstant(*SA),
+ cast<BinaryOperator>(I)->isExact());
} else {
computeKnownBits(I, Known, Depth, CxtI);
}
@@ -851,25 +837,16 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
}
}
- // The sign bit is the LHS's sign bit, except when the result of the
- // remainder is zero.
- if (DemandedMask.isSignBitSet()) {
- computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1, CxtI);
- // If it's known zero, our sign bit is also zero.
- if (LHSKnown.isNonNegative())
- Known.makeNonNegative();
- }
+ computeKnownBits(I, Known, Depth, CxtI);
break;
}
case Instruction::URem: {
- KnownBits Known2(BitWidth);
APInt AllOnes = APInt::getAllOnes(BitWidth);
- if (SimplifyDemandedBits(I, 0, AllOnes, Known2, Depth + 1) ||
- SimplifyDemandedBits(I, 1, AllOnes, Known2, Depth + 1))
+ if (SimplifyDemandedBits(I, 0, AllOnes, LHSKnown, Depth + 1) ||
+ SimplifyDemandedBits(I, 1, AllOnes, RHSKnown, Depth + 1))
return I;
- unsigned Leaders = Known2.countMinLeadingZeros();
- Known.Zero = APInt::getHighBitsSet(BitWidth, Leaders) & DemandedMask;
+ Known = KnownBits::urem(LHSKnown, RHSKnown);
break;
}
case Instruction::Call: {
@@ -897,8 +874,8 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
case Intrinsic::bswap: {
// If the only bits demanded come from one byte of the bswap result,
// just shift the input byte into position to eliminate the bswap.
- unsigned NLZ = DemandedMask.countLeadingZeros();
- unsigned NTZ = DemandedMask.countTrailingZeros();
+ unsigned NLZ = DemandedMask.countl_zero();
+ unsigned NTZ = DemandedMask.countr_zero();
// Round NTZ down to the next byte. If we have 11 trailing zeros, then
// we need all the bits down to bit 8. Likewise, round NLZ. If we
@@ -935,9 +912,28 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
APInt DemandedMaskLHS(DemandedMask.lshr(ShiftAmt));
APInt DemandedMaskRHS(DemandedMask.shl(BitWidth - ShiftAmt));
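// Sketch (annotation, not part of the original patch): fshl(X, Y, 3) on i8
// computes (X << 3) | (Y lshr 5), so bit i of the result is bit i-3 of X for
// i >= 3 and bit i+5 of Y for i < 3 -- hence the lshr/shl of DemandedMask.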
- if (SimplifyDemandedBits(I, 0, DemandedMaskLHS, LHSKnown, Depth + 1) ||
- SimplifyDemandedBits(I, 1, DemandedMaskRHS, RHSKnown, Depth + 1))
- return I;
+ if (I->getOperand(0) != I->getOperand(1)) {
+ if (SimplifyDemandedBits(I, 0, DemandedMaskLHS, LHSKnown,
+ Depth + 1) ||
+ SimplifyDemandedBits(I, 1, DemandedMaskRHS, RHSKnown, Depth + 1))
+ return I;
+ } else { // fshl is a rotate
+ // Avoid converting rotate into funnel shift.
+ // Only simplify if one operand is constant.
+ LHSKnown = computeKnownBits(I->getOperand(0), Depth + 1, I);
+ if (DemandedMaskLHS.isSubsetOf(LHSKnown.Zero | LHSKnown.One) &&
+ !match(I->getOperand(0), m_SpecificInt(LHSKnown.One))) {
+ replaceOperand(*I, 0, Constant::getIntegerValue(VTy, LHSKnown.One));
+ return I;
+ }
+
+ RHSKnown = computeKnownBits(I->getOperand(1), Depth + 1, I);
+ if (DemandedMaskRHS.isSubsetOf(RHSKnown.Zero | RHSKnown.One) &&
+ !match(I->getOperand(1), m_SpecificInt(RHSKnown.One))) {
+ replaceOperand(*I, 1, Constant::getIntegerValue(VTy, RHSKnown.One));
+ return I;
+ }
+ }
Known.Zero = LHSKnown.Zero.shl(ShiftAmt) |
RHSKnown.Zero.lshr(BitWidth - ShiftAmt);
@@ -951,7 +947,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
// The lowest non-zero bit of DemandedMask is higher than the highest
// non-zero bit of C.
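// Sketch (annotation, not part of the original patch): for umax(X, 3) with
// DemandedMask = 0xF8, CTZ = 3 >= getActiveBits() = 2; whether or not X < 3,
// bits 3-7 of umax(X, 3) equal bits 3-7 of X, so X can be returned directly.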
const APInt *C;
- unsigned CTZ = DemandedMask.countTrailingZeros();
+ unsigned CTZ = DemandedMask.countr_zero();
if (match(II->getArgOperand(1), m_APInt(C)) &&
CTZ >= C->getActiveBits())
return II->getArgOperand(0);
@@ -963,9 +959,9 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
// non-one bit of C.
// This comes from using De Morgan's laws on the above umax example.
const APInt *C;
- unsigned CTZ = DemandedMask.countTrailingZeros();
+ unsigned CTZ = DemandedMask.countr_zero();
if (match(II->getArgOperand(1), m_APInt(C)) &&
- CTZ >= C->getBitWidth() - C->countLeadingOnes())
+ CTZ >= C->getBitWidth() - C->countl_one())
return II->getArgOperand(0);
break;
}
@@ -1014,6 +1010,7 @@ Value *InstCombinerImpl::SimplifyMultipleUseDemandedBits(
computeKnownBits(I->getOperand(1), RHSKnown, Depth + 1, CxtI);
computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1, CxtI);
Known = LHSKnown & RHSKnown;
+ computeKnownBitsFromAssume(I, Known, Depth, SQ.getWithInstruction(CxtI));
// If the client is only demanding bits that we know, return the known
// constant.
@@ -1033,6 +1030,7 @@ Value *InstCombinerImpl::SimplifyMultipleUseDemandedBits(
computeKnownBits(I->getOperand(1), RHSKnown, Depth + 1, CxtI);
computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1, CxtI);
Known = LHSKnown | RHSKnown;
+ computeKnownBitsFromAssume(I, Known, Depth, SQ.getWithInstruction(CxtI));
// If the client is only demanding bits that we know, return the known
// constant.
@@ -1054,6 +1052,7 @@ Value *InstCombinerImpl::SimplifyMultipleUseDemandedBits(
computeKnownBits(I->getOperand(1), RHSKnown, Depth + 1, CxtI);
computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1, CxtI);
Known = LHSKnown ^ RHSKnown;
+ computeKnownBitsFromAssume(I, Known, Depth, SQ.getWithInstruction(CxtI));
// If the client is only demanding bits that we know, return the known
// constant.
@@ -1071,7 +1070,7 @@ Value *InstCombinerImpl::SimplifyMultipleUseDemandedBits(
break;
}
case Instruction::Add: {
- unsigned NLZ = DemandedMask.countLeadingZeros();
+ unsigned NLZ = DemandedMask.countl_zero();
APInt DemandedFromOps = APInt::getLowBitsSet(BitWidth, BitWidth - NLZ);
// If an operand adds zeros to every bit below the highest demanded bit,
@@ -1084,10 +1083,13 @@ Value *InstCombinerImpl::SimplifyMultipleUseDemandedBits(
if (DemandedFromOps.isSubsetOf(LHSKnown.Zero))
return I->getOperand(1);
+ bool NSW = cast<OverflowingBinaryOperator>(I)->hasNoSignedWrap();
+ Known = KnownBits::computeForAddSub(/*Add*/ true, NSW, LHSKnown, RHSKnown);
+ computeKnownBitsFromAssume(I, Known, Depth, SQ.getWithInstruction(CxtI));
break;
}
case Instruction::Sub: {
- unsigned NLZ = DemandedMask.countLeadingZeros();
+ unsigned NLZ = DemandedMask.countl_zero();
APInt DemandedFromOps = APInt::getLowBitsSet(BitWidth, BitWidth - NLZ);
// If an operand subtracts zeros from every bit below the highest demanded
@@ -1096,6 +1098,10 @@ Value *InstCombinerImpl::SimplifyMultipleUseDemandedBits(
if (DemandedFromOps.isSubsetOf(RHSKnown.Zero))
return I->getOperand(0);
+ bool NSW = cast<OverflowingBinaryOperator>(I)->hasNoSignedWrap();
+ computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1, CxtI);
+ Known = KnownBits::computeForAddSub(/*Add*/ false, NSW, LHSKnown, RHSKnown);
+ computeKnownBitsFromAssume(I, Known, Depth, SQ.getWithInstruction(CxtI));
break;
}
case Instruction::AShr: {
@@ -1541,7 +1547,7 @@ Value *InstCombinerImpl::SimplifyDemandedVectorElts(Value *V,
// Found constant vector with single element - convert to insertelement.
if (Op && Value) {
Instruction *New = InsertElementInst::Create(
- Op, Value, ConstantInt::get(Type::getInt32Ty(I->getContext()), Idx),
+ Op, Value, ConstantInt::get(Type::getInt64Ty(I->getContext()), Idx),
Shuffle->getName());
InsertNewInstWith(New, *Shuffle);
return New;
@@ -1552,7 +1558,7 @@ Value *InstCombinerImpl::SimplifyDemandedVectorElts(Value *V,
SmallVector<int, 16> Elts;
for (unsigned i = 0; i < VWidth; ++i) {
if (UndefElts[i])
- Elts.push_back(UndefMaskElem);
+ Elts.push_back(PoisonMaskElem);
else
Elts.push_back(Shuffle->getMaskValue(i));
}
@@ -1653,7 +1659,7 @@ Value *InstCombinerImpl::SimplifyDemandedVectorElts(Value *V,
// corresponding input elements are undef.
for (unsigned OutIdx = 0; OutIdx != VWidth; ++OutIdx) {
APInt SubUndef = UndefElts2.lshr(OutIdx * Ratio).zextOrTrunc(Ratio);
- if (SubUndef.countPopulation() == Ratio)
+ if (SubUndef.popcount() == Ratio)
UndefElts.setBit(OutIdx);
}
} else {
@@ -1712,6 +1718,54 @@ Value *InstCombinerImpl::SimplifyDemandedVectorElts(Value *V,
// UB/poison potential, but that should be refined.
BinaryOperator *BO;
if (match(I, m_BinOp(BO)) && !BO->isIntDivRem() && !BO->isShift()) {
+ Value *X = BO->getOperand(0);
+ Value *Y = BO->getOperand(1);
+
+ // Look for an equivalent binop except that one operand has been shuffled.
+ // If the demand for this binop only includes elements that are the same as
+ // the other binop, then we may be able to replace this binop with a use of
+ // the earlier one.
+ //
+ // Example:
+ // %other_bo = bo (shuf X, {0}), Y
+ // %this_extracted_bo = extelt (bo X, Y), 0
+ // -->
+ // %other_bo = bo (shuf X, {0}), Y
+ // %this_extracted_bo = extelt %other_bo, 0
+ //
+ // TODO: Handle demand of an arbitrary single element or more than one
+ // element instead of just element 0.
+ // TODO: Unlike general demanded elements transforms, this should be safe
+ // for any (div/rem/shift) opcode too.
+ if (DemandedElts == 1 && !X->hasOneUse() && !Y->hasOneUse() &&
+ BO->hasOneUse()) {
+
+ auto findShufBO = [&](bool MatchShufAsOp0) -> User * {
+ // Try to use shuffle-of-operand in place of an operand:
+ // bo X, Y --> bo (shuf X), Y
+ // bo X, Y --> bo X, (shuf Y)
+ BinaryOperator::BinaryOps Opcode = BO->getOpcode();
+ Value *ShufOp = MatchShufAsOp0 ? X : Y;
+ Value *OtherOp = MatchShufAsOp0 ? Y : X;
+ for (User *U : OtherOp->users()) {
+ auto Shuf = m_Shuffle(m_Specific(ShufOp), m_Value(), m_ZeroMask());
+ if (BO->isCommutative()
+ ? match(U, m_c_BinOp(Opcode, Shuf, m_Specific(OtherOp)))
+ : MatchShufAsOp0
+ ? match(U, m_BinOp(Opcode, Shuf, m_Specific(OtherOp)))
+ : match(U, m_BinOp(Opcode, m_Specific(OtherOp), Shuf)))
+ if (DT.dominates(U, I))
+ return U;
+ }
+ return nullptr;
+ };
+
+ if (User *ShufBO = findShufBO(/* MatchShufAsOp0 */ true))
+ return ShufBO;
+ if (User *ShufBO = findShufBO(/* MatchShufAsOp0 */ false))
+ return ShufBO;
+ }
+
simplifyAndSetOp(I, 0, DemandedElts, UndefElts);
simplifyAndSetOp(I, 1, DemandedElts, UndefElts2);
@@ -1723,7 +1777,7 @@ Value *InstCombinerImpl::SimplifyDemandedVectorElts(Value *V,
// If we've proven all of the lanes undef, return an undef value.
// TODO: Intersect w/demanded lanes
if (UndefElts.isAllOnes())
- return UndefValue::get(I->getType());;
+ return UndefValue::get(I->getType());
return MadeChange ? I : nullptr;
}