aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp')
-rw-r--r--llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp354
1 files changed, 220 insertions, 134 deletions
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index 661c50062223..2dda46986f0f 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -689,34 +689,40 @@ static Value *foldSelectICmpLshrAshr(const ICmpInst *IC, Value *TrueVal,
}
/// We want to turn:
-/// (select (icmp eq (and X, C1), 0), Y, (or Y, C2))
+/// (select (icmp eq (and X, C1), 0), Y, (BinOp Y, C2))
/// into:
-/// (or (shl (and X, C1), C3), Y)
+/// IF C2 u>= C1
+/// (BinOp Y, (shl (and X, C1), C3))
+/// ELSE
+/// (BinOp Y, (lshr (and X, C1), C3))
/// iff:
+/// 0 on the RHS is the identity value (e.g. add, xor, shl, etc.)
/// C1 and C2 are both powers of 2
/// where:
-/// C3 = Log(C2) - Log(C1)
+/// IF C2 u>= C1
+/// C3 = Log(C2) - Log(C1)
+/// ELSE
+/// C3 = Log(C1) - Log(C2)
///
/// This transform handles cases where:
/// 1. The icmp predicate is inverted
/// 2. The select operands are reversed
/// 3. The magnitude of C2 and C1 are flipped
-static Value *foldSelectICmpAndOr(const ICmpInst *IC, Value *TrueVal,
+static Value *foldSelectICmpAndBinOp(const ICmpInst *IC, Value *TrueVal,
Value *FalseVal,
InstCombiner::BuilderTy &Builder) {
// Only handle integer compares. Also, if this is a vector select, we need a
// vector compare.
if (!TrueVal->getType()->isIntOrIntVectorTy() ||
- TrueVal->getType()->isVectorTy() != IC->getType()->isVectorTy())
+ TrueVal->getType()->isVectorTy() != IC->getType()->isVectorTy())
return nullptr;
Value *CmpLHS = IC->getOperand(0);
Value *CmpRHS = IC->getOperand(1);
- Value *V;
unsigned C1Log;
- bool IsEqualZero;
bool NeedAnd = false;
+ CmpInst::Predicate Pred = IC->getPredicate();
if (IC->isEquality()) {
if (!match(CmpRHS, m_Zero()))
return nullptr;
@@ -725,49 +731,49 @@ static Value *foldSelectICmpAndOr(const ICmpInst *IC, Value *TrueVal,
if (!match(CmpLHS, m_And(m_Value(), m_Power2(C1))))
return nullptr;
- V = CmpLHS;
C1Log = C1->logBase2();
- IsEqualZero = IC->getPredicate() == ICmpInst::ICMP_EQ;
- } else if (IC->getPredicate() == ICmpInst::ICMP_SLT ||
- IC->getPredicate() == ICmpInst::ICMP_SGT) {
- // We also need to recognize (icmp slt (trunc (X)), 0) and
- // (icmp sgt (trunc (X)), -1).
- IsEqualZero = IC->getPredicate() == ICmpInst::ICMP_SGT;
- if ((IsEqualZero && !match(CmpRHS, m_AllOnes())) ||
- (!IsEqualZero && !match(CmpRHS, m_Zero())))
- return nullptr;
-
- if (!match(CmpLHS, m_OneUse(m_Trunc(m_Value(V)))))
+ } else {
+ APInt C1;
+ if (!decomposeBitTestICmp(CmpLHS, CmpRHS, Pred, CmpLHS, C1) ||
+ !C1.isPowerOf2())
return nullptr;
- C1Log = CmpLHS->getType()->getScalarSizeInBits() - 1;
+ C1Log = C1.logBase2();
NeedAnd = true;
- } else {
- return nullptr;
}
+ Value *Y, *V = CmpLHS;
+ BinaryOperator *BinOp;
const APInt *C2;
- bool OrOnTrueVal = false;
- bool OrOnFalseVal = match(FalseVal, m_Or(m_Specific(TrueVal), m_Power2(C2)));
- if (!OrOnFalseVal)
- OrOnTrueVal = match(TrueVal, m_Or(m_Specific(FalseVal), m_Power2(C2)));
-
- if (!OrOnFalseVal && !OrOnTrueVal)
+ bool NeedXor;
+ if (match(FalseVal, m_BinOp(m_Specific(TrueVal), m_Power2(C2)))) {
+ Y = TrueVal;
+ BinOp = cast<BinaryOperator>(FalseVal);
+ NeedXor = Pred == ICmpInst::ICMP_NE;
+ } else if (match(TrueVal, m_BinOp(m_Specific(FalseVal), m_Power2(C2)))) {
+ Y = FalseVal;
+ BinOp = cast<BinaryOperator>(TrueVal);
+ NeedXor = Pred == ICmpInst::ICMP_EQ;
+ } else {
return nullptr;
+ }
- Value *Y = OrOnFalseVal ? TrueVal : FalseVal;
+ // Check that 0 on RHS is identity value for this binop.
+ auto *IdentityC =
+ ConstantExpr::getBinOpIdentity(BinOp->getOpcode(), BinOp->getType(),
+ /*AllowRHSConstant*/ true);
+ if (IdentityC == nullptr || !IdentityC->isNullValue())
+ return nullptr;
unsigned C2Log = C2->logBase2();
- bool NeedXor = (!IsEqualZero && OrOnFalseVal) || (IsEqualZero && OrOnTrueVal);
bool NeedShift = C1Log != C2Log;
bool NeedZExtTrunc = Y->getType()->getScalarSizeInBits() !=
V->getType()->getScalarSizeInBits();
// Make sure we don't create more instructions than we save.
- Value *Or = OrOnFalseVal ? FalseVal : TrueVal;
- if ((NeedShift + NeedXor + NeedZExtTrunc) >
- (IC->hasOneUse() + Or->hasOneUse()))
+ if ((NeedShift + NeedXor + NeedZExtTrunc + NeedAnd) >
+ (IC->hasOneUse() + BinOp->hasOneUse()))
return nullptr;
if (NeedAnd) {
@@ -788,7 +794,7 @@ static Value *foldSelectICmpAndOr(const ICmpInst *IC, Value *TrueVal,
if (NeedXor)
V = Builder.CreateXor(V, *C2);
- return Builder.CreateOr(V, Y);
+ return Builder.CreateBinOp(BinOp->getOpcode(), Y, V);
}
/// Canonicalize a set or clear of a masked set of constant bits to
@@ -870,7 +876,7 @@ static Instruction *foldSelectZeroOrMul(SelectInst &SI, InstCombinerImpl &IC) {
auto *FalseValI = cast<Instruction>(FalseVal);
auto *FrY = IC.InsertNewInstBefore(new FreezeInst(Y, Y->getName() + ".fr"),
- *FalseValI);
+ FalseValI->getIterator());
IC.replaceOperand(*FalseValI, FalseValI->getOperand(0) == Y ? 0 : 1, FrY);
return IC.replaceInstUsesWith(SI, FalseValI);
}
@@ -1303,45 +1309,28 @@ Instruction *InstCombinerImpl::foldSelectValueEquivalence(SelectInst &Sel,
return nullptr;
// InstSimplify already performed this fold if it was possible subject to
- // current poison-generating flags. Try the transform again with
- // poison-generating flags temporarily dropped.
- bool WasNUW = false, WasNSW = false, WasExact = false, WasInBounds = false;
- if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(FalseVal)) {
- WasNUW = OBO->hasNoUnsignedWrap();
- WasNSW = OBO->hasNoSignedWrap();
- FalseInst->setHasNoUnsignedWrap(false);
- FalseInst->setHasNoSignedWrap(false);
- }
- if (auto *PEO = dyn_cast<PossiblyExactOperator>(FalseVal)) {
- WasExact = PEO->isExact();
- FalseInst->setIsExact(false);
- }
- if (auto *GEP = dyn_cast<GetElementPtrInst>(FalseVal)) {
- WasInBounds = GEP->isInBounds();
- GEP->setIsInBounds(false);
- }
+ // current poison-generating flags. Check whether dropping poison-generating
+ // flags enables the transform.
// Try each equivalence substitution possibility.
// We have an 'EQ' comparison, so the select's false value will propagate.
// Example:
// (X == 42) ? 43 : (X + 1) --> (X == 42) ? (X + 1) : (X + 1) --> X + 1
+ SmallVector<Instruction *> DropFlags;
if (simplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, SQ,
- /* AllowRefinement */ false) == TrueVal ||
+ /* AllowRefinement */ false,
+ &DropFlags) == TrueVal ||
simplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, SQ,
- /* AllowRefinement */ false) == TrueVal) {
+ /* AllowRefinement */ false,
+ &DropFlags) == TrueVal) {
+ for (Instruction *I : DropFlags) {
+ I->dropPoisonGeneratingFlagsAndMetadata();
+ Worklist.add(I);
+ }
+
return replaceInstUsesWith(Sel, FalseVal);
}
- // Restore poison-generating flags if the transform did not apply.
- if (WasNUW)
- FalseInst->setHasNoUnsignedWrap();
- if (WasNSW)
- FalseInst->setHasNoSignedWrap();
- if (WasExact)
- FalseInst->setIsExact();
- if (WasInBounds)
- cast<GetElementPtrInst>(FalseInst)->setIsInBounds();
-
return nullptr;
}
@@ -1506,8 +1495,13 @@ static Value *canonicalizeClampLike(SelectInst &Sel0, ICmpInst &Cmp0,
if (!match(ReplacementLow, m_ImmConstant(LowC)) ||
!match(ReplacementHigh, m_ImmConstant(HighC)))
return nullptr;
- ReplacementLow = ConstantExpr::getSExt(LowC, X->getType());
- ReplacementHigh = ConstantExpr::getSExt(HighC, X->getType());
+ const DataLayout &DL = Sel0.getModule()->getDataLayout();
+ ReplacementLow =
+ ConstantFoldCastOperand(Instruction::SExt, LowC, X->getType(), DL);
+ ReplacementHigh =
+ ConstantFoldCastOperand(Instruction::SExt, HighC, X->getType(), DL);
+ assert(ReplacementLow && ReplacementHigh &&
+ "Constant folding of ImmConstant cannot fail");
}
// All good, finally emit the new pattern.
@@ -1797,7 +1791,7 @@ Instruction *InstCombinerImpl::foldSelectInstWithICmp(SelectInst &SI,
if (Instruction *V = foldSelectZeroOrOnes(ICI, TrueVal, FalseVal, Builder))
return V;
- if (Value *V = foldSelectICmpAndOr(ICI, TrueVal, FalseVal, Builder))
+ if (Value *V = foldSelectICmpAndBinOp(ICI, TrueVal, FalseVal, Builder))
return replaceInstUsesWith(SI, V);
if (Value *V = foldSelectICmpLshrAshr(ICI, TrueVal, FalseVal, Builder))
@@ -2094,9 +2088,8 @@ Instruction *InstCombinerImpl::foldSelectExtConst(SelectInst &Sel) {
// If the constant is the same after truncation to the smaller type and
// extension to the original type, we can narrow the select.
Type *SelType = Sel.getType();
- Constant *TruncC = ConstantExpr::getTrunc(C, SmallType);
- Constant *ExtC = ConstantExpr::getCast(ExtOpcode, TruncC, SelType);
- if (ExtC == C && ExtInst->hasOneUse()) {
+ Constant *TruncC = getLosslessTrunc(C, SmallType, ExtOpcode);
+ if (TruncC && ExtInst->hasOneUse()) {
Value *TruncCVal = cast<Value>(TruncC);
if (ExtInst == Sel.getFalseValue())
std::swap(X, TruncCVal);
@@ -2107,23 +2100,6 @@ Instruction *InstCombinerImpl::foldSelectExtConst(SelectInst &Sel) {
return CastInst::Create(Instruction::CastOps(ExtOpcode), NewSel, SelType);
}
- // If one arm of the select is the extend of the condition, replace that arm
- // with the extension of the appropriate known bool value.
- if (Cond == X) {
- if (ExtInst == Sel.getTrueValue()) {
- // select X, (sext X), C --> select X, -1, C
- // select X, (zext X), C --> select X, 1, C
- Constant *One = ConstantInt::getTrue(SmallType);
- Constant *AllOnesOrOne = ConstantExpr::getCast(ExtOpcode, One, SelType);
- return SelectInst::Create(Cond, AllOnesOrOne, C, "", nullptr, &Sel);
- } else {
- // select X, C, (sext X) --> select X, C, 0
- // select X, C, (zext X) --> select X, C, 0
- Constant *Zero = ConstantInt::getNullValue(SelType);
- return SelectInst::Create(Cond, C, Zero, "", nullptr, &Sel);
- }
- }
-
return nullptr;
}
@@ -2561,7 +2537,7 @@ static Instruction *foldSelectToPhiImpl(SelectInst &Sel, BasicBlock *BB,
return nullptr;
}
- Builder.SetInsertPoint(&*BB->begin());
+ Builder.SetInsertPoint(BB, BB->begin());
auto *PN = Builder.CreatePHI(Sel.getType(), Inputs.size());
for (auto *Pred : predecessors(BB))
PN->addIncoming(Inputs[Pred], Pred);
@@ -2584,6 +2560,61 @@ static Instruction *foldSelectToPhi(SelectInst &Sel, const DominatorTree &DT,
return nullptr;
}
+/// Tries to reduce a pattern that arises when calculating the remainder of the
+/// Euclidean division. When the divisor is a power of two and is guaranteed not
+/// to be negative, a signed remainder can be folded with a bitwise and.
+///
+/// (x % n) < 0 ? (x % n) + n : (x % n)
+/// -> x & (n - 1)
+///
+/// The value whose sign bit is tested by the select condition must be the very
+/// same srem that appears in both select arms; otherwise the condition is
+/// unrelated to the remainder and the fold would be wrong.
+///
+/// \param SI      The select instruction to inspect.
+/// \param IC      Provides the sign-bit-check and power-of-two queries.
+/// \param Builder Used to emit the `n - 1` mask.
+/// \returns the replacement `and` created with BinaryOperator::CreateAnd, or
+///          nullptr if the pattern does not match.
+static Instruction *foldSelectWithSRem(SelectInst &SI, InstCombinerImpl &IC,
+                                       IRBuilderBase &Builder) {
+  Value *CondVal = SI.getCondition();
+  Value *TrueVal = SI.getTrueValue();
+  Value *FalseVal = SI.getFalseValue();
+
+  ICmpInst::Predicate Pred;
+  Value *Op, *RemRes, *Remainder;
+  const APInt *C;
+  bool TrueIfSigned = false;
+
+  // RemRes is bound here to the value whose sign bit the condition tests. It
+  // must not be rebound below: every later check compares against it.
+  if (!(match(CondVal, m_ICmp(Pred, m_Value(RemRes), m_APInt(C))) &&
+        IC.isSignBitCheck(Pred, *C, TrueIfSigned)))
+    return nullptr;
+
+  // If the sign bit is not set, we have a SGE/SGT comparison, and the operands
+  // of the select are inverted.
+  if (!TrueIfSigned)
+    std::swap(TrueVal, FalseVal);
+
+  // Builds `Op & (Divisor + -1)`. Op is only read after one of the srem
+  // matches below has bound it.
+  auto FoldToBitwiseAnd = [&](Value *Divisor) -> Instruction * {
+    Value *Mask = Builder.CreateAdd(
+        Divisor, Constant::getAllOnesValue(RemRes->getType()));
+    return BinaryOperator::CreateAnd(Op, Mask);
+  };
+
+  // Match the general case:
+  // %rem = srem i32 %x, %n
+  // %cnd = icmp slt i32 %rem, 0
+  // %add = add i32 %rem, %n
+  // %sel = select i1 %cnd, i32 %add, i32 %rem
+  //
+  // m_Specific (not m_Value) ties the add operand to the compared value, so a
+  // condition that tests some other value is rejected.
+  if (match(TrueVal, m_Add(m_Specific(RemRes), m_Value(Remainder))) &&
+      match(RemRes, m_SRem(m_Value(Op), m_Specific(Remainder))) &&
+      IC.isKnownToBeAPowerOfTwo(Remainder, /*OrZero*/ true) &&
+      FalseVal == RemRes)
+    return FoldToBitwiseAnd(Remainder);
+
+  // Match the case where one arm has been replaced by the constant 1:
+  // %rem = srem i32 %n, 2
+  // %cnd = icmp slt i32 %rem, 0
+  // %sel = select i1 %cnd, i32 1, i32 %rem
+  if (match(TrueVal, m_One()) &&
+      match(RemRes, m_SRem(m_Value(Op), m_SpecificInt(2))) &&
+      FalseVal == RemRes)
+    return FoldToBitwiseAnd(ConstantInt::get(RemRes->getType(), 2));
+
+  return nullptr;
+}
+
static Value *foldSelectWithFrozenICmp(SelectInst &Sel, InstCombiner::BuilderTy &Builder) {
FreezeInst *FI = dyn_cast<FreezeInst>(Sel.getCondition());
if (!FI)
@@ -2860,8 +2891,15 @@ static Instruction *foldNestedSelects(SelectInst &OuterSelVal,
std::swap(InnerSel.TrueVal, InnerSel.FalseVal);
Value *AltCond = nullptr;
- auto matchOuterCond = [OuterSel, &AltCond](auto m_InnerCond) {
- return match(OuterSel.Cond, m_c_LogicalOp(m_InnerCond, m_Value(AltCond)));
+ auto matchOuterCond = [OuterSel, IsAndVariant, &AltCond](auto m_InnerCond) {
+ // An unsimplified select condition can match both LogicalAnd and LogicalOr
+ // (select true, true, false). Since below we assume that LogicalAnd implies
+ // InnerSel match the FVal and vice versa for LogicalOr, we can't match the
+ // alternative pattern here.
+ return IsAndVariant ? match(OuterSel.Cond,
+ m_c_LogicalAnd(m_InnerCond, m_Value(AltCond)))
+ : match(OuterSel.Cond,
+ m_c_LogicalOr(m_InnerCond, m_Value(AltCond)));
};
// Finally, match the condition that was driving the outermost `select`,
@@ -3024,31 +3062,37 @@ Instruction *InstCombinerImpl::foldSelectOfBools(SelectInst &SI) {
if (match(CondVal, m_Select(m_Value(A), m_Value(B), m_Zero())) &&
match(TrueVal, m_Specific(B)) && match(FalseVal, m_Zero()))
return replaceOperand(SI, 0, A);
- // select a, (select ~a, true, b), false -> select a, b, false
- if (match(TrueVal, m_c_LogicalOr(m_Not(m_Specific(CondVal)), m_Value(B))) &&
- match(FalseVal, m_Zero()))
- return replaceOperand(SI, 1, B);
- // select a, true, (select ~a, b, false) -> select a, true, b
- if (match(FalseVal, m_c_LogicalAnd(m_Not(m_Specific(CondVal)), m_Value(B))) &&
- match(TrueVal, m_One()))
- return replaceOperand(SI, 2, B);
// ~(A & B) & (A | B) --> A ^ B
if (match(&SI, m_c_LogicalAnd(m_Not(m_LogicalAnd(m_Value(A), m_Value(B))),
m_c_LogicalOr(m_Deferred(A), m_Deferred(B)))))
return BinaryOperator::CreateXor(A, B);
- // select (~a | c), a, b -> and a, (or c, freeze(b))
- if (match(CondVal, m_c_Or(m_Not(m_Specific(TrueVal)), m_Value(C))) &&
- CondVal->hasOneUse()) {
- FalseVal = Builder.CreateFreeze(FalseVal);
- return BinaryOperator::CreateAnd(TrueVal, Builder.CreateOr(C, FalseVal));
+ // select (~a | c), a, b -> select a, (select c, true, b), false
+ if (match(CondVal,
+ m_OneUse(m_c_Or(m_Not(m_Specific(TrueVal)), m_Value(C))))) {
+ Value *OrV = Builder.CreateSelect(C, One, FalseVal);
+ return SelectInst::Create(TrueVal, OrV, Zero);
+ }
+ // select (c & b), a, b -> select b, (select ~c, true, a), false
+ if (match(CondVal, m_OneUse(m_c_And(m_Value(C), m_Specific(FalseVal))))) {
+ if (Value *NotC = getFreelyInverted(C, C->hasOneUse(), &Builder)) {
+ Value *OrV = Builder.CreateSelect(NotC, One, TrueVal);
+ return SelectInst::Create(FalseVal, OrV, Zero);
+ }
+ }
+ // select (a | c), a, b -> select a, true, (select ~c, b, false)
+ if (match(CondVal, m_OneUse(m_c_Or(m_Specific(TrueVal), m_Value(C))))) {
+ if (Value *NotC = getFreelyInverted(C, C->hasOneUse(), &Builder)) {
+ Value *AndV = Builder.CreateSelect(NotC, FalseVal, Zero);
+ return SelectInst::Create(TrueVal, One, AndV);
+ }
}
- // select (~c & b), a, b -> and b, (or freeze(a), c)
- if (match(CondVal, m_c_And(m_Not(m_Value(C)), m_Specific(FalseVal))) &&
- CondVal->hasOneUse()) {
- TrueVal = Builder.CreateFreeze(TrueVal);
- return BinaryOperator::CreateAnd(FalseVal, Builder.CreateOr(C, TrueVal));
+ // select (c & ~b), a, b -> select b, true, (select c, a, false)
+ if (match(CondVal,
+ m_OneUse(m_c_And(m_Value(C), m_Not(m_Specific(FalseVal)))))) {
+ Value *AndV = Builder.CreateSelect(C, TrueVal, Zero);
+ return SelectInst::Create(FalseVal, One, AndV);
}
if (match(FalseVal, m_Zero()) || match(TrueVal, m_One())) {
@@ -3057,7 +3101,7 @@ Instruction *InstCombinerImpl::foldSelectOfBools(SelectInst &SI) {
Value *Op1 = IsAnd ? TrueVal : FalseVal;
if (isCheckForZeroAndMulWithOverflow(CondVal, Op1, IsAnd, Y)) {
auto *FI = new FreezeInst(*Y, (*Y)->getName() + ".fr");
- InsertNewInstBefore(FI, *cast<Instruction>(Y->getUser()));
+ InsertNewInstBefore(FI, cast<Instruction>(Y->getUser())->getIterator());
replaceUse(*Y, FI);
return replaceInstUsesWith(SI, Op1);
}
@@ -3272,6 +3316,31 @@ static Instruction *foldBitCeil(SelectInst &SI, IRBuilderBase &Builder) {
Masked);
}
+/// Reports whether \p MulVal passes the checks this file requires before an
+/// fmul of it by zero is treated as zero.
+///
+/// The known floating-point class of \p MulVal is computed under \p FMF at
+/// context instruction \p CtxI. The answer is true only if the value is known
+/// never to be NaN, known never to be infinite, and either \p FMF carries nsz
+/// or the value's sign bit is known to be zero (or the value is NaN).
+bool InstCombinerImpl::fmulByZeroIsZero(Value *MulVal, FastMathFlags FMF,
+                                        const Instruction *CtxI) const {
+  KnownFPClass Known = computeKnownFPClass(MulVal, FMF, fcNegative, CtxI);
+
+  // NaN * 0 and Inf * 0 both produce NaN, so either possibility disqualifies.
+  if (!Known.isKnownNeverNaN() || !Known.isKnownNeverInfinity())
+    return false;
+
+  // Without nsz the sign of the zero result matters, so the other factor must
+  // be known not to have its sign bit set.
+  if (FMF.noSignedZeros())
+    return true;
+  return Known.signBitIsZeroOrNaN();
+}
+
+/// Checks for the "scale if equal to zero" shape: \p Cmp1 is +0.0 and
+/// \p TrueVal is an fmul with \p Cmp0 as one factor (either operand order),
+/// where the other factor satisfies InstCombinerImpl::fmulByZeroIsZero.
+///
+/// \param IC          Combiner whose fmulByZeroIsZero query is used.
+/// \param Cmp0        Compared value; must be a factor of the fmul.
+/// \param Cmp1        Value compared against; must match positive zero.
+/// \param TrueVal     Candidate fmul.
+/// \param FalseVal    Not read by this function.
+/// \param CtxI        Context instruction forwarded to the query.
+/// \param SelectIsNSZ Whether the select carries nsz; replaces the fmul's own
+///                    nsz flag for the query.
+/// \returns true if the pattern matches and the query succeeds.
+static bool matchFMulByZeroIfResultEqZero(InstCombinerImpl &IC, Value *Cmp0,
+                                          Value *Cmp1, Value *TrueVal,
+                                          Value *FalseVal, Instruction &CtxI,
+                                          bool SelectIsNSZ) {
+  if (!match(Cmp1, m_PosZeroFP()))
+    return false;
+
+  Value *OtherFactor;
+  if (!match(TrueVal, m_c_FMul(m_Specific(Cmp0), m_Value(OtherFactor))))
+    return false;
+
+  // nsz must come from the select; whatever the multiply says about nsz is
+  // ignored. nnan and ninf are still taken from the multiply for the other
+  // factor.
+  FastMathFlags MulFlags = cast<FPMathOperator>(TrueVal)->getFastMathFlags();
+  MulFlags.setNoSignedZeros(SelectIsNSZ);
+  return IC.fmulByZeroIsZero(OtherFactor, MulFlags, &CtxI);
+}
+
Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
Value *CondVal = SI.getCondition();
Value *TrueVal = SI.getTrueValue();
@@ -3303,28 +3372,6 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
ConstantInt::getFalse(CondType), SQ,
/* AllowRefinement */ true))
return replaceOperand(SI, 2, S);
-
- // Handle patterns involving sext/zext + not explicitly,
- // as simplifyWithOpReplaced() only looks past one instruction.
- Value *NotCond;
-
- // select a, sext(!a), b -> select !a, b, 0
- // select a, zext(!a), b -> select !a, b, 0
- if (match(TrueVal, m_ZExtOrSExt(m_CombineAnd(m_Value(NotCond),
- m_Not(m_Specific(CondVal))))))
- return SelectInst::Create(NotCond, FalseVal,
- Constant::getNullValue(SelType));
-
- // select a, b, zext(!a) -> select !a, 1, b
- if (match(FalseVal, m_ZExt(m_CombineAnd(m_Value(NotCond),
- m_Not(m_Specific(CondVal))))))
- return SelectInst::Create(NotCond, ConstantInt::get(SelType, 1), TrueVal);
-
- // select a, b, sext(!a) -> select !a, -1, b
- if (match(FalseVal, m_SExt(m_CombineAnd(m_Value(NotCond),
- m_Not(m_Specific(CondVal))))))
- return SelectInst::Create(NotCond, Constant::getAllOnesValue(SelType),
- TrueVal);
}
if (Instruction *R = foldSelectOfBools(SI))
@@ -3362,7 +3409,10 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
}
}
+ auto *SIFPOp = dyn_cast<FPMathOperator>(&SI);
+
if (auto *FCmp = dyn_cast<FCmpInst>(CondVal)) {
+ FCmpInst::Predicate Pred = FCmp->getPredicate();
Value *Cmp0 = FCmp->getOperand(0), *Cmp1 = FCmp->getOperand(1);
// Are we selecting a value based on a comparison of the two values?
if ((Cmp0 == TrueVal && Cmp1 == FalseVal) ||
@@ -3372,7 +3422,7 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
//
// e.g.
// (X ugt Y) ? X : Y -> (X ole Y) ? Y : X
- if (FCmp->hasOneUse() && FCmpInst::isUnordered(FCmp->getPredicate())) {
+ if (FCmp->hasOneUse() && FCmpInst::isUnordered(Pred)) {
FCmpInst::Predicate InvPred = FCmp->getInversePredicate();
IRBuilder<>::FastMathFlagGuard FMFG(Builder);
// FIXME: The FMF should propagate from the select, not the fcmp.
@@ -3383,14 +3433,47 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
return replaceInstUsesWith(SI, NewSel);
}
}
+
+ if (SIFPOp) {
+ // Fold out scale-if-equals-zero pattern.
+ //
+ // This pattern appears in code with denormal range checks after it's
+ // assumed denormals are treated as zero. This drops a canonicalization.
+
+ // TODO: Could relax the signed zero logic. We just need to know the sign
+ // of the result matches (fmul x, y has the same sign as x).
+ //
+ // TODO: Handle always-canonicalizing variant that selects some value or 1
+ // scaling factor in the fmul visitor.
+
+ // TODO: Handle ldexp too
+
+ Value *MatchCmp0 = nullptr;
+ Value *MatchCmp1 = nullptr;
+
+ // (select (fcmp [ou]eq x, 0.0), (fmul x, K), x) => x
+ // (select (fcmp [ou]ne x, 0.0), x, (fmul x, K)) => x
+ if (Pred == CmpInst::FCMP_OEQ || Pred == CmpInst::FCMP_UEQ) {
+ MatchCmp0 = FalseVal;
+ MatchCmp1 = TrueVal;
+ } else if (Pred == CmpInst::FCMP_ONE || Pred == CmpInst::FCMP_UNE) {
+ MatchCmp0 = TrueVal;
+ MatchCmp1 = FalseVal;
+ }
+
+ if (Cmp0 == MatchCmp0 &&
+ matchFMulByZeroIfResultEqZero(*this, Cmp0, Cmp1, MatchCmp1, MatchCmp0,
+ SI, SIFPOp->hasNoSignedZeros()))
+ return replaceInstUsesWith(SI, Cmp0);
+ }
}
- if (isa<FPMathOperator>(SI)) {
+ if (SIFPOp) {
// TODO: Try to forward-propagate FMF from select arms to the select.
// Canonicalize select of FP values where NaN and -0.0 are not valid as
// minnum/maxnum intrinsics.
- if (SI.hasNoNaNs() && SI.hasNoSignedZeros()) {
+ if (SIFPOp->hasNoNaNs() && SIFPOp->hasNoSignedZeros()) {
Value *X, *Y;
if (match(&SI, m_OrdFMax(m_Value(X), m_Value(Y))))
return replaceInstUsesWith(
@@ -3430,6 +3513,9 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
if (Instruction *I = foldSelectExtConst(SI))
return I;
+ if (Instruction *I = foldSelectWithSRem(SI, *this, Builder))
+ return I;
+
// Fold (select C, (gep Ptr, Idx), Ptr) -> (gep Ptr, (select C, Idx, 0))
// Fold (select C, Ptr, (gep Ptr, Idx)) -> (gep Ptr, (select C, 0, Idx))
auto SelectGepWithBase = [&](GetElementPtrInst *Gep, Value *Base,