aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp')
-rw-r--r--llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp181
1 files changed, 152 insertions, 29 deletions
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index 9fc871e49b30..05a624fde86b 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -704,16 +704,24 @@ static Value *canonicalizeSaturatedSubtract(const ICmpInst *ICI,
assert((Pred == ICmpInst::ICMP_UGE || Pred == ICmpInst::ICMP_UGT) &&
"Unexpected isUnsigned predicate!");
- // Account for swapped form of subtraction: ((a > b) ? b - a : 0).
+ // Ensure the sub is of the form:
+ // (a > b) ? a - b : 0 -> usub.sat(a, b)
+ // (a > b) ? b - a : 0 -> -usub.sat(a, b)
+ // Checking for both a-b and a+(-b) as a constant.
bool IsNegative = false;
- if (match(TrueVal, m_Sub(m_Specific(B), m_Specific(A))))
+ const APInt *C;
+ if (match(TrueVal, m_Sub(m_Specific(B), m_Specific(A))) ||
+ (match(A, m_APInt(C)) &&
+ match(TrueVal, m_Add(m_Specific(B), m_SpecificInt(-*C)))))
IsNegative = true;
- else if (!match(TrueVal, m_Sub(m_Specific(A), m_Specific(B))))
+ else if (!match(TrueVal, m_Sub(m_Specific(A), m_Specific(B))) &&
+ !(match(B, m_APInt(C)) &&
+ match(TrueVal, m_Add(m_Specific(A), m_SpecificInt(-*C)))))
return nullptr;
- // If sub is used anywhere else, we wouldn't be able to eliminate it
- // afterwards.
- if (!TrueVal->hasOneUse())
+ // If we are adding a negate and the sub and icmp are used anywhere else, we
+ // would end up with more instructions.
+ if (IsNegative && !TrueVal->hasOneUse() && !ICI->hasOneUse())
return nullptr;
// (a > b) ? a - b : 0 -> usub.sat(a, b)
@@ -781,6 +789,13 @@ static Value *canonicalizeSaturatedAdd(ICmpInst *Cmp, Value *TVal, Value *FVal,
return Builder.CreateBinaryIntrinsic(
Intrinsic::uadd_sat, BO->getOperand(0), BO->getOperand(1));
}
+ // The overflow may be detected via the add wrapping round.
+ if (match(Cmp0, m_c_Add(m_Specific(Cmp1), m_Value(Y))) &&
+ match(FVal, m_c_Add(m_Specific(Cmp1), m_Specific(Y)))) {
+ // ((X + Y) u< X) ? -1 : (X + Y) --> uadd.sat(X, Y)
+ // ((X + Y) u< Y) ? -1 : (X + Y) --> uadd.sat(X, Y)
+ return Builder.CreateBinaryIntrinsic(Intrinsic::uadd_sat, Cmp1, Y);
+ }
return nullptr;
}
@@ -1725,6 +1740,128 @@ static Instruction *foldAddSubSelect(SelectInst &SI,
return nullptr;
}
+/// Turn X + Y overflows ? -1 : X + Y -> uadd_sat X, Y
+/// And X - Y overflows ? 0 : X - Y -> usub_sat X, Y
+/// Along with a number of patterns similar to:
+/// X + Y overflows ? (X < 0 ? INTMIN : INTMAX) : X + Y --> sadd_sat X, Y
+/// X - Y overflows ? (X > 0 ? INTMAX : INTMIN) : X - Y --> ssub_sat X, Y
+static Instruction *
+foldOverflowingAddSubSelect(SelectInst &SI, InstCombiner::BuilderTy &Builder) {
+ Value *CondVal = SI.getCondition();
+ Value *TrueVal = SI.getTrueValue();
+ Value *FalseVal = SI.getFalseValue();
+
+ WithOverflowInst *II;
+ if (!match(CondVal, m_ExtractValue<1>(m_WithOverflowInst(II))) ||
+ !match(FalseVal, m_ExtractValue<0>(m_Specific(II))))
+ return nullptr;
+
+ Value *X = II->getLHS();
+ Value *Y = II->getRHS();
+
+ auto IsSignedSaturateLimit = [&](Value *Limit, bool IsAdd) {
+ Type *Ty = Limit->getType();
+
+ ICmpInst::Predicate Pred;
+ Value *TrueVal, *FalseVal, *Op;
+ const APInt *C;
+ if (!match(Limit, m_Select(m_ICmp(Pred, m_Value(Op), m_APInt(C)),
+ m_Value(TrueVal), m_Value(FalseVal))))
+ return false;
+
+ auto IsZeroOrOne = [](const APInt &C) {
+ return C.isNullValue() || C.isOneValue();
+ };
+ auto IsMinMax = [&](Value *Min, Value *Max) {
+ APInt MinVal = APInt::getSignedMinValue(Ty->getScalarSizeInBits());
+ APInt MaxVal = APInt::getSignedMaxValue(Ty->getScalarSizeInBits());
+ return match(Min, m_SpecificInt(MinVal)) &&
+ match(Max, m_SpecificInt(MaxVal));
+ };
+
+ if (Op != X && Op != Y)
+ return false;
+
+ if (IsAdd) {
+ // X + Y overflows ? (X <s 0 ? INTMIN : INTMAX) : X + Y --> sadd_sat X, Y
+ // X + Y overflows ? (X <s 1 ? INTMIN : INTMAX) : X + Y --> sadd_sat X, Y
+ // X + Y overflows ? (Y <s 0 ? INTMIN : INTMAX) : X + Y --> sadd_sat X, Y
+ // X + Y overflows ? (Y <s 1 ? INTMIN : INTMAX) : X + Y --> sadd_sat X, Y
+ if (Pred == ICmpInst::ICMP_SLT && IsZeroOrOne(*C) &&
+ IsMinMax(TrueVal, FalseVal))
+ return true;
+ // X + Y overflows ? (X >s 0 ? INTMAX : INTMIN) : X + Y --> sadd_sat X, Y
+ // X + Y overflows ? (X >s -1 ? INTMAX : INTMIN) : X + Y --> sadd_sat X, Y
+ // X + Y overflows ? (Y >s 0 ? INTMAX : INTMIN) : X + Y --> sadd_sat X, Y
+ // X + Y overflows ? (Y >s -1 ? INTMAX : INTMIN) : X + Y --> sadd_sat X, Y
+ if (Pred == ICmpInst::ICMP_SGT && IsZeroOrOne(*C + 1) &&
+ IsMinMax(FalseVal, TrueVal))
+ return true;
+ } else {
+ // X - Y overflows ? (X <s 0 ? INTMIN : INTMAX) : X - Y --> ssub_sat X, Y
+ // X - Y overflows ? (X <s -1 ? INTMIN : INTMAX) : X - Y --> ssub_sat X, Y
+ if (Op == X && Pred == ICmpInst::ICMP_SLT && IsZeroOrOne(*C + 1) &&
+ IsMinMax(TrueVal, FalseVal))
+ return true;
+ // X - Y overflows ? (X >s -1 ? INTMAX : INTMIN) : X - Y --> ssub_sat X, Y
+ // X - Y overflows ? (X >s -2 ? INTMAX : INTMIN) : X - Y --> ssub_sat X, Y
+ if (Op == X && Pred == ICmpInst::ICMP_SGT && IsZeroOrOne(*C + 2) &&
+ IsMinMax(FalseVal, TrueVal))
+ return true;
+ // X - Y overflows ? (Y <s 0 ? INTMAX : INTMIN) : X - Y --> ssub_sat X, Y
+ // X - Y overflows ? (Y <s 1 ? INTMAX : INTMIN) : X - Y --> ssub_sat X, Y
+ if (Op == Y && Pred == ICmpInst::ICMP_SLT && IsZeroOrOne(*C) &&
+ IsMinMax(FalseVal, TrueVal))
+ return true;
+ // X - Y overflows ? (Y >s 0 ? INTMIN : INTMAX) : X - Y --> ssub_sat X, Y
+ // X - Y overflows ? (Y >s -1 ? INTMIN : INTMAX) : X - Y --> ssub_sat X, Y
+ if (Op == Y && Pred == ICmpInst::ICMP_SGT && IsZeroOrOne(*C + 1) &&
+ IsMinMax(TrueVal, FalseVal))
+ return true;
+ }
+
+ return false;
+ };
+
+ Intrinsic::ID NewIntrinsicID;
+ if (II->getIntrinsicID() == Intrinsic::uadd_with_overflow &&
+ match(TrueVal, m_AllOnes()))
+ // X + Y overflows ? -1 : X + Y -> uadd_sat X, Y
+ NewIntrinsicID = Intrinsic::uadd_sat;
+ else if (II->getIntrinsicID() == Intrinsic::usub_with_overflow &&
+ match(TrueVal, m_Zero()))
+ // X - Y overflows ? 0 : X - Y -> usub_sat X, Y
+ NewIntrinsicID = Intrinsic::usub_sat;
+ else if (II->getIntrinsicID() == Intrinsic::sadd_with_overflow &&
+ IsSignedSaturateLimit(TrueVal, /*IsAdd=*/true))
+ // X + Y overflows ? (X <s 0 ? INTMIN : INTMAX) : X + Y --> sadd_sat X, Y
+ // X + Y overflows ? (X <s 1 ? INTMIN : INTMAX) : X + Y --> sadd_sat X, Y
+ // X + Y overflows ? (X >s 0 ? INTMAX : INTMIN) : X + Y --> sadd_sat X, Y
+ // X + Y overflows ? (X >s -1 ? INTMAX : INTMIN) : X + Y --> sadd_sat X, Y
+ // X + Y overflows ? (Y <s 0 ? INTMIN : INTMAX) : X + Y --> sadd_sat X, Y
+ // X + Y overflows ? (Y <s 1 ? INTMIN : INTMAX) : X + Y --> sadd_sat X, Y
+ // X + Y overflows ? (Y >s 0 ? INTMAX : INTMIN) : X + Y --> sadd_sat X, Y
+ // X + Y overflows ? (Y >s -1 ? INTMAX : INTMIN) : X + Y --> sadd_sat X, Y
+ NewIntrinsicID = Intrinsic::sadd_sat;
+ else if (II->getIntrinsicID() == Intrinsic::ssub_with_overflow &&
+ IsSignedSaturateLimit(TrueVal, /*IsAdd=*/false))
+ // X - Y overflows ? (X <s 0 ? INTMIN : INTMAX) : X - Y --> ssub_sat X, Y
+ // X - Y overflows ? (X <s -1 ? INTMIN : INTMAX) : X - Y --> ssub_sat X, Y
+ // X - Y overflows ? (X >s -1 ? INTMAX : INTMIN) : X - Y --> ssub_sat X, Y
+ // X - Y overflows ? (X >s -2 ? INTMAX : INTMIN) : X - Y --> ssub_sat X, Y
+ // X - Y overflows ? (Y <s 0 ? INTMAX : INTMIN) : X - Y --> ssub_sat X, Y
+ // X - Y overflows ? (Y <s 1 ? INTMAX : INTMIN) : X - Y --> ssub_sat X, Y
+ // X - Y overflows ? (Y >s 0 ? INTMIN : INTMAX) : X - Y --> ssub_sat X, Y
+ // X - Y overflows ? (Y >s -1 ? INTMIN : INTMAX) : X - Y --> ssub_sat X, Y
+ NewIntrinsicID = Intrinsic::ssub_sat;
+ else
+ return nullptr;
+
+ Function *F =
+ Intrinsic::getDeclaration(SI.getModule(), NewIntrinsicID, SI.getType());
+ return CallInst::Create(F, {X, Y});
+}
+
Instruction *InstCombiner::foldSelectExtConst(SelectInst &Sel) {
Constant *C;
if (!match(Sel.getTrueValue(), m_Constant(C)) &&
@@ -2296,7 +2433,9 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {
// See if we are selecting two values based on a comparison of the two values.
if (FCmpInst *FCI = dyn_cast<FCmpInst>(CondVal)) {
- if (FCI->getOperand(0) == TrueVal && FCI->getOperand(1) == FalseVal) {
+ Value *Cmp0 = FCI->getOperand(0), *Cmp1 = FCI->getOperand(1);
+ if ((Cmp0 == TrueVal && Cmp1 == FalseVal) ||
+ (Cmp0 == FalseVal && Cmp1 == TrueVal)) {
// Canonicalize to use ordered comparisons by swapping the select
// operands.
//
@@ -2305,30 +2444,12 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {
if (FCI->hasOneUse() && FCmpInst::isUnordered(FCI->getPredicate())) {
FCmpInst::Predicate InvPred = FCI->getInversePredicate();
IRBuilder<>::FastMathFlagGuard FMFG(Builder);
+ // FIXME: The FMF should propagate from the select, not the fcmp.
Builder.setFastMathFlags(FCI->getFastMathFlags());
- Value *NewCond = Builder.CreateFCmp(InvPred, TrueVal, FalseVal,
- FCI->getName() + ".inv");
-
- return SelectInst::Create(NewCond, FalseVal, TrueVal,
- SI.getName() + ".p");
- }
-
- // NOTE: if we wanted to, this is where to detect MIN/MAX
- } else if (FCI->getOperand(0) == FalseVal && FCI->getOperand(1) == TrueVal){
- // Canonicalize to use ordered comparisons by swapping the select
- // operands.
- //
- // e.g.
- // (X ugt Y) ? X : Y -> (X ole Y) ? X : Y
- if (FCI->hasOneUse() && FCmpInst::isUnordered(FCI->getPredicate())) {
- FCmpInst::Predicate InvPred = FCI->getInversePredicate();
- IRBuilder<>::FastMathFlagGuard FMFG(Builder);
- Builder.setFastMathFlags(FCI->getFastMathFlags());
- Value *NewCond = Builder.CreateFCmp(InvPred, FalseVal, TrueVal,
+ Value *NewCond = Builder.CreateFCmp(InvPred, Cmp0, Cmp1,
FCI->getName() + ".inv");
-
- return SelectInst::Create(NewCond, FalseVal, TrueVal,
- SI.getName() + ".p");
+ Value *NewSel = Builder.CreateSelect(NewCond, FalseVal, TrueVal);
+ return replaceInstUsesWith(SI, NewSel);
}
// NOTE: if we wanted to, this is where to detect MIN/MAX
@@ -2391,6 +2512,8 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {
if (Instruction *Add = foldAddSubSelect(SI, Builder))
return Add;
+ if (Instruction *Add = foldOverflowingAddSubSelect(SI, Builder))
+ return Add;
// Turn (select C, (op X, Y), (op X, Z)) -> (op X, (select C, Y, Z))
auto *TI = dyn_cast<Instruction>(TrueVal);