diff options
Diffstat (limited to 'llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp')
-rw-r--r-- | llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp | 181 |
1 files changed, 152 insertions, 29 deletions
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp index 9fc871e49b30..05a624fde86b 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp @@ -704,16 +704,24 @@ static Value *canonicalizeSaturatedSubtract(const ICmpInst *ICI, assert((Pred == ICmpInst::ICMP_UGE || Pred == ICmpInst::ICMP_UGT) && "Unexpected isUnsigned predicate!"); - // Account for swapped form of subtraction: ((a > b) ? b - a : 0). + // Ensure the sub is of the form: + // (a > b) ? a - b : 0 -> usub.sat(a, b) + // (a > b) ? b - a : 0 -> -usub.sat(a, b) + // Checking for both a-b and a+(-b) as a constant. bool IsNegative = false; - if (match(TrueVal, m_Sub(m_Specific(B), m_Specific(A)))) + const APInt *C; + if (match(TrueVal, m_Sub(m_Specific(B), m_Specific(A))) || + (match(A, m_APInt(C)) && + match(TrueVal, m_Add(m_Specific(B), m_SpecificInt(-*C))))) IsNegative = true; - else if (!match(TrueVal, m_Sub(m_Specific(A), m_Specific(B)))) + else if (!match(TrueVal, m_Sub(m_Specific(A), m_Specific(B))) && + !(match(B, m_APInt(C)) && + match(TrueVal, m_Add(m_Specific(A), m_SpecificInt(-*C))))) return nullptr; - // If sub is used anywhere else, we wouldn't be able to eliminate it - // afterwards. - if (!TrueVal->hasOneUse()) + // If we are adding a negate and the sub and icmp are used anywhere else, we + // would end up with more instructions. + if (IsNegative && !TrueVal->hasOneUse() && !ICI->hasOneUse()) return nullptr; // (a > b) ? a - b : 0 -> usub.sat(a, b) @@ -781,6 +789,13 @@ static Value *canonicalizeSaturatedAdd(ICmpInst *Cmp, Value *TVal, Value *FVal, return Builder.CreateBinaryIntrinsic( Intrinsic::uadd_sat, BO->getOperand(0), BO->getOperand(1)); } + // The overflow may be detected via the add wrapping round. + if (match(Cmp0, m_c_Add(m_Specific(Cmp1), m_Value(Y))) && + match(FVal, m_c_Add(m_Specific(Cmp1), m_Specific(Y)))) { + // ((X + Y) u< X) ? -1 : (X + Y) --> uadd.sat(X, Y) + // ((X + Y) u< Y) ? -1 : (X + Y) --> uadd.sat(X, Y) + return Builder.CreateBinaryIntrinsic(Intrinsic::uadd_sat, Cmp1, Y); + } return nullptr; } @@ -1725,6 +1740,128 @@ static Instruction *foldAddSubSelect(SelectInst &SI, return nullptr; } +/// Turn X + Y overflows ? -1 : X + Y -> uadd_sat X, Y +/// And X - Y overflows ? 0 : X - Y -> usub_sat X, Y +/// Along with a number of patterns similar to: +/// X + Y overflows ? (X < 0 ? INTMIN : INTMAX) : X + Y --> sadd_sat X, Y +/// X - Y overflows ? (X > 0 ? INTMAX : INTMIN) : X - Y --> ssub_sat X, Y +static Instruction * +foldOverflowingAddSubSelect(SelectInst &SI, InstCombiner::BuilderTy &Builder) { + Value *CondVal = SI.getCondition(); + Value *TrueVal = SI.getTrueValue(); + Value *FalseVal = SI.getFalseValue(); + + WithOverflowInst *II; + if (!match(CondVal, m_ExtractValue<1>(m_WithOverflowInst(II))) || + !match(FalseVal, m_ExtractValue<0>(m_Specific(II)))) + return nullptr; + + Value *X = II->getLHS(); + Value *Y = II->getRHS(); + + auto IsSignedSaturateLimit = [&](Value *Limit, bool IsAdd) { + Type *Ty = Limit->getType(); + + ICmpInst::Predicate Pred; + Value *TrueVal, *FalseVal, *Op; + const APInt *C; + if (!match(Limit, m_Select(m_ICmp(Pred, m_Value(Op), m_APInt(C)), + m_Value(TrueVal), m_Value(FalseVal)))) + return false; + + auto IsZeroOrOne = [](const APInt &C) { + return C.isNullValue() || C.isOneValue(); + }; + auto IsMinMax = [&](Value *Min, Value *Max) { + APInt MinVal = APInt::getSignedMinValue(Ty->getScalarSizeInBits()); + APInt MaxVal = APInt::getSignedMaxValue(Ty->getScalarSizeInBits()); + return match(Min, m_SpecificInt(MinVal)) && + match(Max, m_SpecificInt(MaxVal)); + }; + + if (Op != X && Op != Y) + return false; + + if (IsAdd) { + // X + Y overflows ? (X <s 0 ? INTMIN : INTMAX) : X + Y --> sadd_sat X, Y + // X + Y overflows ? (X <s 1 ? INTMIN : INTMAX) : X + Y --> sadd_sat X, Y + // X + Y overflows ? (Y <s 0 ? INTMIN : INTMAX) : X + Y --> sadd_sat X, Y + // X + Y overflows ? (Y <s 1 ? INTMIN : INTMAX) : X + Y --> sadd_sat X, Y + if (Pred == ICmpInst::ICMP_SLT && IsZeroOrOne(*C) && + IsMinMax(TrueVal, FalseVal)) + return true; + // X + Y overflows ? (X >s 0 ? INTMAX : INTMIN) : X + Y --> sadd_sat X, Y + // X + Y overflows ? (X >s -1 ? INTMAX : INTMIN) : X + Y --> sadd_sat X, Y + // X + Y overflows ? (Y >s 0 ? INTMAX : INTMIN) : X + Y --> sadd_sat X, Y + // X + Y overflows ? (Y >s -1 ? INTMAX : INTMIN) : X + Y --> sadd_sat X, Y + if (Pred == ICmpInst::ICMP_SGT && IsZeroOrOne(*C + 1) && + IsMinMax(FalseVal, TrueVal)) + return true; + } else { + // X - Y overflows ? (X <s 0 ? INTMIN : INTMAX) : X - Y --> ssub_sat X, Y + // X - Y overflows ? (X <s -1 ? INTMIN : INTMAX) : X - Y --> ssub_sat X, Y + if (Op == X && Pred == ICmpInst::ICMP_SLT && IsZeroOrOne(*C + 1) && + IsMinMax(TrueVal, FalseVal)) + return true; + // X - Y overflows ? (X >s -1 ? INTMAX : INTMIN) : X - Y --> ssub_sat X, Y + // X - Y overflows ? (X >s -2 ? INTMAX : INTMIN) : X - Y --> ssub_sat X, Y + if (Op == X && Pred == ICmpInst::ICMP_SGT && IsZeroOrOne(*C + 2) && + IsMinMax(FalseVal, TrueVal)) + return true; + // X - Y overflows ? (Y <s 0 ? INTMAX : INTMIN) : X - Y --> ssub_sat X, Y + // X - Y overflows ? (Y <s 1 ? INTMAX : INTMIN) : X - Y --> ssub_sat X, Y + if (Op == Y && Pred == ICmpInst::ICMP_SLT && IsZeroOrOne(*C) && + IsMinMax(FalseVal, TrueVal)) + return true; + // X - Y overflows ? (Y >s 0 ? INTMIN : INTMAX) : X - Y --> ssub_sat X, Y + // X - Y overflows ? (Y >s -1 ? INTMIN : INTMAX) : X - Y --> ssub_sat X, Y + if (Op == Y && Pred == ICmpInst::ICMP_SGT && IsZeroOrOne(*C + 1) && + IsMinMax(TrueVal, FalseVal)) + return true; + } + + return false; + }; + + Intrinsic::ID NewIntrinsicID; + if (II->getIntrinsicID() == Intrinsic::uadd_with_overflow && + match(TrueVal, m_AllOnes())) + // X + Y overflows ? -1 : X + Y -> uadd_sat X, Y + NewIntrinsicID = Intrinsic::uadd_sat; + else if (II->getIntrinsicID() == Intrinsic::usub_with_overflow && + match(TrueVal, m_Zero())) + // X - Y overflows ? 0 : X - Y -> usub_sat X, Y + NewIntrinsicID = Intrinsic::usub_sat; + else if (II->getIntrinsicID() == Intrinsic::sadd_with_overflow && + IsSignedSaturateLimit(TrueVal, /*IsAdd=*/true)) + // X + Y overflows ? (X <s 0 ? INTMIN : INTMAX) : X + Y --> sadd_sat X, Y + // X + Y overflows ? (X <s 1 ? INTMIN : INTMAX) : X + Y --> sadd_sat X, Y + // X + Y overflows ? (X >s 0 ? INTMAX : INTMIN) : X + Y --> sadd_sat X, Y + // X + Y overflows ? (X >s -1 ? INTMAX : INTMIN) : X + Y --> sadd_sat X, Y + // X + Y overflows ? (Y <s 0 ? INTMIN : INTMAX) : X + Y --> sadd_sat X, Y + // X + Y overflows ? (Y <s 1 ? INTMIN : INTMAX) : X + Y --> sadd_sat X, Y + // X + Y overflows ? (Y >s 0 ? INTMAX : INTMIN) : X + Y --> sadd_sat X, Y + // X + Y overflows ? (Y >s -1 ? INTMAX : INTMIN) : X + Y --> sadd_sat X, Y + NewIntrinsicID = Intrinsic::sadd_sat; + else if (II->getIntrinsicID() == Intrinsic::ssub_with_overflow && + IsSignedSaturateLimit(TrueVal, /*IsAdd=*/false)) + // X - Y overflows ? (X <s 0 ? INTMIN : INTMAX) : X - Y --> ssub_sat X, Y + // X - Y overflows ? (X <s -1 ? INTMIN : INTMAX) : X - Y --> ssub_sat X, Y + // X - Y overflows ? (X >s -1 ? INTMAX : INTMIN) : X - Y --> ssub_sat X, Y + // X - Y overflows ? (X >s -2 ? INTMAX : INTMIN) : X - Y --> ssub_sat X, Y + // X - Y overflows ? (Y <s 0 ? INTMAX : INTMIN) : X - Y --> ssub_sat X, Y + // X - Y overflows ? (Y <s 1 ? INTMAX : INTMIN) : X - Y --> ssub_sat X, Y + // X - Y overflows ? (Y >s 0 ? INTMIN : INTMAX) : X - Y --> ssub_sat X, Y + // X - Y overflows ? (Y >s -1 ? INTMIN : INTMAX) : X - Y --> ssub_sat X, Y + NewIntrinsicID = Intrinsic::ssub_sat; + else + return nullptr; + + Function *F = + Intrinsic::getDeclaration(SI.getModule(), NewIntrinsicID, SI.getType()); + return CallInst::Create(F, {X, Y}); +} + Instruction *InstCombiner::foldSelectExtConst(SelectInst &Sel) { Constant *C; if (!match(Sel.getTrueValue(), m_Constant(C)) && @@ -2296,7 +2433,9 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { // See if we are selecting two values based on a comparison of the two values. if (FCmpInst *FCI = dyn_cast<FCmpInst>(CondVal)) { - if (FCI->getOperand(0) == TrueVal && FCI->getOperand(1) == FalseVal) { + Value *Cmp0 = FCI->getOperand(0), *Cmp1 = FCI->getOperand(1); + if ((Cmp0 == TrueVal && Cmp1 == FalseVal) || + (Cmp0 == FalseVal && Cmp1 == TrueVal)) { // Canonicalize to use ordered comparisons by swapping the select // operands. // @@ -2305,30 +2444,12 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { if (FCI->hasOneUse() && FCmpInst::isUnordered(FCI->getPredicate())) { FCmpInst::Predicate InvPred = FCI->getInversePredicate(); IRBuilder<>::FastMathFlagGuard FMFG(Builder); + // FIXME: The FMF should propagate from the select, not the fcmp. Builder.setFastMathFlags(FCI->getFastMathFlags()); - Value *NewCond = Builder.CreateFCmp(InvPred, TrueVal, FalseVal, - FCI->getName() + ".inv"); - - return SelectInst::Create(NewCond, FalseVal, TrueVal, - SI.getName() + ".p"); - } - - // NOTE: if we wanted to, this is where to detect MIN/MAX - } else if (FCI->getOperand(0) == FalseVal && FCI->getOperand(1) == TrueVal){ - // Canonicalize to use ordered comparisons by swapping the select - // operands. - // - // e.g. - // (X ugt Y) ? X : Y -> (X ole Y) ? X : Y - if (FCI->hasOneUse() && FCmpInst::isUnordered(FCI->getPredicate())) { - FCmpInst::Predicate InvPred = FCI->getInversePredicate(); - IRBuilder<>::FastMathFlagGuard FMFG(Builder); - Builder.setFastMathFlags(FCI->getFastMathFlags()); - Value *NewCond = Builder.CreateFCmp(InvPred, FalseVal, TrueVal, + Value *NewCond = Builder.CreateFCmp(InvPred, Cmp0, Cmp1, FCI->getName() + ".inv"); - - return SelectInst::Create(NewCond, FalseVal, TrueVal, - SI.getName() + ".p"); + Value *NewSel = Builder.CreateSelect(NewCond, FalseVal, TrueVal); + return replaceInstUsesWith(SI, NewSel); } // NOTE: if we wanted to, this is where to detect MIN/MAX @@ -2391,6 +2512,8 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { if (Instruction *Add = foldAddSubSelect(SI, Builder)) return Add; + if (Instruction *Add = foldOverflowingAddSubSelect(SI, Builder)) + return Add; // Turn (select C, (op X, Y), (op X, Z)) -> (op X, (select C, Y, Z)) auto *TI = dyn_cast<Instruction>(TrueVal); |