diff options
Diffstat (limited to 'llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp')
| -rw-r--r-- | llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp | 181 | 
1 files changed, 152 insertions, 29 deletions
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp index 9fc871e49b30..05a624fde86b 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp @@ -704,16 +704,24 @@ static Value *canonicalizeSaturatedSubtract(const ICmpInst *ICI,    assert((Pred == ICmpInst::ICMP_UGE || Pred == ICmpInst::ICMP_UGT) &&           "Unexpected isUnsigned predicate!"); -  // Account for swapped form of subtraction: ((a > b) ? b - a : 0). +  // Ensure the sub is of the form: +  //  (a > b) ? a - b : 0 -> usub.sat(a, b) +  //  (a > b) ? b - a : 0 -> -usub.sat(a, b) +  // Checking for both a-b and a+(-b) as a constant.    bool IsNegative = false; -  if (match(TrueVal, m_Sub(m_Specific(B), m_Specific(A)))) +  const APInt *C; +  if (match(TrueVal, m_Sub(m_Specific(B), m_Specific(A))) || +      (match(A, m_APInt(C)) && +       match(TrueVal, m_Add(m_Specific(B), m_SpecificInt(-*C)))))      IsNegative = true; -  else if (!match(TrueVal, m_Sub(m_Specific(A), m_Specific(B)))) +  else if (!match(TrueVal, m_Sub(m_Specific(A), m_Specific(B))) && +           !(match(B, m_APInt(C)) && +             match(TrueVal, m_Add(m_Specific(A), m_SpecificInt(-*C)))))      return nullptr; -  // If sub is used anywhere else, we wouldn't be able to eliminate it -  // afterwards. -  if (!TrueVal->hasOneUse()) +  // If we are adding a negate and the sub and icmp are used anywhere else, we +  // would end up with more instructions. +  if (IsNegative && !TrueVal->hasOneUse() && !ICI->hasOneUse())      return nullptr;    // (a > b) ? a - b : 0 -> usub.sat(a, b) @@ -781,6 +789,13 @@ static Value *canonicalizeSaturatedAdd(ICmpInst *Cmp, Value *TVal, Value *FVal,      return Builder.CreateBinaryIntrinsic(          Intrinsic::uadd_sat, BO->getOperand(0), BO->getOperand(1));    } +  // The overflow may be detected via the add wrapping round. +  if (match(Cmp0, m_c_Add(m_Specific(Cmp1), m_Value(Y))) && +      match(FVal, m_c_Add(m_Specific(Cmp1), m_Specific(Y)))) { +    // ((X + Y) u< X) ? -1 : (X + Y) --> uadd.sat(X, Y) +    // ((X + Y) u< Y) ? -1 : (X + Y) --> uadd.sat(X, Y) +    return Builder.CreateBinaryIntrinsic(Intrinsic::uadd_sat, Cmp1, Y); +  }    return nullptr;  } @@ -1725,6 +1740,128 @@ static Instruction *foldAddSubSelect(SelectInst &SI,    return nullptr;  } +/// Turn X + Y overflows ? -1 : X + Y -> uadd_sat X, Y +/// And X - Y overflows ? 0 : X - Y -> usub_sat X, Y +/// Along with a number of patterns similar to: +/// X + Y overflows ? (X < 0 ? INTMIN : INTMAX) : X + Y --> sadd_sat X, Y +/// X - Y overflows ? (X > 0 ? INTMAX : INTMIN) : X - Y --> ssub_sat X, Y +static Instruction * +foldOverflowingAddSubSelect(SelectInst &SI, InstCombiner::BuilderTy &Builder) { +  Value *CondVal = SI.getCondition(); +  Value *TrueVal = SI.getTrueValue(); +  Value *FalseVal = SI.getFalseValue(); + +  WithOverflowInst *II; +  if (!match(CondVal, m_ExtractValue<1>(m_WithOverflowInst(II))) || +      !match(FalseVal, m_ExtractValue<0>(m_Specific(II)))) +    return nullptr; + +  Value *X = II->getLHS(); +  Value *Y = II->getRHS(); + +  auto IsSignedSaturateLimit = [&](Value *Limit, bool IsAdd) { +    Type *Ty = Limit->getType(); + +    ICmpInst::Predicate Pred; +    Value *TrueVal, *FalseVal, *Op; +    const APInt *C; +    if (!match(Limit, m_Select(m_ICmp(Pred, m_Value(Op), m_APInt(C)), +                               m_Value(TrueVal), m_Value(FalseVal)))) +      return false; + +    auto IsZeroOrOne = [](const APInt &C) { +      return C.isNullValue() || C.isOneValue(); +    }; +    auto IsMinMax = [&](Value *Min, Value *Max) { +      APInt MinVal = APInt::getSignedMinValue(Ty->getScalarSizeInBits()); +      APInt MaxVal = APInt::getSignedMaxValue(Ty->getScalarSizeInBits()); +      return match(Min, m_SpecificInt(MinVal)) && +             match(Max, m_SpecificInt(MaxVal)); +    }; + +    if (Op != X && Op != Y) +      return false; + +    if (IsAdd) { +      // X + Y overflows ? (X <s 0 ? INTMIN : INTMAX) : X + Y --> sadd_sat X, Y +      // X + Y overflows ? (X <s 1 ? INTMIN : INTMAX) : X + Y --> sadd_sat X, Y +      // X + Y overflows ? (Y <s 0 ? INTMIN : INTMAX) : X + Y --> sadd_sat X, Y +      // X + Y overflows ? (Y <s 1 ? INTMIN : INTMAX) : X + Y --> sadd_sat X, Y +      if (Pred == ICmpInst::ICMP_SLT && IsZeroOrOne(*C) && +          IsMinMax(TrueVal, FalseVal)) +        return true; +      // X + Y overflows ? (X >s 0 ? INTMAX : INTMIN) : X + Y --> sadd_sat X, Y +      // X + Y overflows ? (X >s -1 ? INTMAX : INTMIN) : X + Y --> sadd_sat X, Y +      // X + Y overflows ? (Y >s 0 ? INTMAX : INTMIN) : X + Y --> sadd_sat X, Y +      // X + Y overflows ? (Y >s -1 ? INTMAX : INTMIN) : X + Y --> sadd_sat X, Y +      if (Pred == ICmpInst::ICMP_SGT && IsZeroOrOne(*C + 1) && +          IsMinMax(FalseVal, TrueVal)) +        return true; +    } else { +      // X - Y overflows ? (X <s 0 ? INTMIN : INTMAX) : X - Y --> ssub_sat X, Y +      // X - Y overflows ? (X <s -1 ? INTMIN : INTMAX) : X - Y --> ssub_sat X, Y +      if (Op == X && Pred == ICmpInst::ICMP_SLT && IsZeroOrOne(*C + 1) && +          IsMinMax(TrueVal, FalseVal)) +        return true; +      // X - Y overflows ? (X >s -1 ? INTMAX : INTMIN) : X - Y --> ssub_sat X, Y +      // X - Y overflows ? (X >s -2 ? INTMAX : INTMIN) : X - Y --> ssub_sat X, Y +      if (Op == X && Pred == ICmpInst::ICMP_SGT && IsZeroOrOne(*C + 2) && +          IsMinMax(FalseVal, TrueVal)) +        return true; +      // X - Y overflows ? (Y <s 0 ? INTMAX : INTMIN) : X - Y --> ssub_sat X, Y +      // X - Y overflows ? (Y <s 1 ? INTMAX : INTMIN) : X - Y --> ssub_sat X, Y +      if (Op == Y && Pred == ICmpInst::ICMP_SLT && IsZeroOrOne(*C) && +          IsMinMax(FalseVal, TrueVal)) +        return true; +      // X - Y overflows ? (Y >s 0 ? INTMIN : INTMAX) : X - Y --> ssub_sat X, Y +      // X - Y overflows ? (Y >s -1 ? INTMIN : INTMAX) : X - Y --> ssub_sat X, Y +      if (Op == Y && Pred == ICmpInst::ICMP_SGT && IsZeroOrOne(*C + 1) && +          IsMinMax(TrueVal, FalseVal)) +        return true; +    } + +    return false; +  }; + +  Intrinsic::ID NewIntrinsicID; +  if (II->getIntrinsicID() == Intrinsic::uadd_with_overflow && +      match(TrueVal, m_AllOnes())) +    // X + Y overflows ? -1 : X + Y -> uadd_sat X, Y +    NewIntrinsicID = Intrinsic::uadd_sat; +  else if (II->getIntrinsicID() == Intrinsic::usub_with_overflow && +           match(TrueVal, m_Zero())) +    // X - Y overflows ? 0 : X - Y -> usub_sat X, Y +    NewIntrinsicID = Intrinsic::usub_sat; +  else if (II->getIntrinsicID() == Intrinsic::sadd_with_overflow && +           IsSignedSaturateLimit(TrueVal, /*IsAdd=*/true)) +    // X + Y overflows ? (X <s 0 ? INTMIN : INTMAX) : X + Y --> sadd_sat X, Y +    // X + Y overflows ? (X <s 1 ? INTMIN : INTMAX) : X + Y --> sadd_sat X, Y +    // X + Y overflows ? (X >s 0 ? INTMAX : INTMIN) : X + Y --> sadd_sat X, Y +    // X + Y overflows ? (X >s -1 ? INTMAX : INTMIN) : X + Y --> sadd_sat X, Y +    // X + Y overflows ? (Y <s 0 ? INTMIN : INTMAX) : X + Y --> sadd_sat X, Y +    // X + Y overflows ? (Y <s 1 ? INTMIN : INTMAX) : X + Y --> sadd_sat X, Y +    // X + Y overflows ? (Y >s 0 ? INTMAX : INTMIN) : X + Y --> sadd_sat X, Y +    // X + Y overflows ? (Y >s -1 ? INTMAX : INTMIN) : X + Y --> sadd_sat X, Y +    NewIntrinsicID = Intrinsic::sadd_sat; +  else if (II->getIntrinsicID() == Intrinsic::ssub_with_overflow && +           IsSignedSaturateLimit(TrueVal, /*IsAdd=*/false)) +    // X - Y overflows ? (X <s 0 ? INTMIN : INTMAX) : X - Y --> ssub_sat X, Y +    // X - Y overflows ? (X <s -1 ? INTMIN : INTMAX) : X - Y --> ssub_sat X, Y +    // X - Y overflows ? (X >s -1 ? INTMAX : INTMIN) : X - Y --> ssub_sat X, Y +    // X - Y overflows ? (X >s -2 ? INTMAX : INTMIN) : X - Y --> ssub_sat X, Y +    // X - Y overflows ? (Y <s 0 ? INTMAX : INTMIN) : X - Y --> ssub_sat X, Y +    // X - Y overflows ? (Y <s 1 ? INTMAX : INTMIN) : X - Y --> ssub_sat X, Y +    // X - Y overflows ? (Y >s 0 ? INTMIN : INTMAX) : X - Y --> ssub_sat X, Y +    // X - Y overflows ? (Y >s -1 ? INTMIN : INTMAX) : X - Y --> ssub_sat X, Y +    NewIntrinsicID = Intrinsic::ssub_sat; +  else +    return nullptr; + +  Function *F = +      Intrinsic::getDeclaration(SI.getModule(), NewIntrinsicID, SI.getType()); +  return CallInst::Create(F, {X, Y}); +} +  Instruction *InstCombiner::foldSelectExtConst(SelectInst &Sel) {    Constant *C;    if (!match(Sel.getTrueValue(), m_Constant(C)) && @@ -2296,7 +2433,9 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {    // See if we are selecting two values based on a comparison of the two values.    if (FCmpInst *FCI = dyn_cast<FCmpInst>(CondVal)) { -    if (FCI->getOperand(0) == TrueVal && FCI->getOperand(1) == FalseVal) { +    Value *Cmp0 = FCI->getOperand(0), *Cmp1 = FCI->getOperand(1); +    if ((Cmp0 == TrueVal && Cmp1 == FalseVal) || +        (Cmp0 == FalseVal && Cmp1 == TrueVal)) {        // Canonicalize to use ordered comparisons by swapping the select        // operands.        // @@ -2305,30 +2444,12 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {        if (FCI->hasOneUse() && FCmpInst::isUnordered(FCI->getPredicate())) {          FCmpInst::Predicate InvPred = FCI->getInversePredicate();          IRBuilder<>::FastMathFlagGuard FMFG(Builder); +        // FIXME: The FMF should propagate from the select, not the fcmp.          Builder.setFastMathFlags(FCI->getFastMathFlags()); -        Value *NewCond = Builder.CreateFCmp(InvPred, TrueVal, FalseVal, -                                            FCI->getName() + ".inv"); - -        return SelectInst::Create(NewCond, FalseVal, TrueVal, -                                  SI.getName() + ".p"); -      } - -      // NOTE: if we wanted to, this is where to detect MIN/MAX -    } else if (FCI->getOperand(0) == FalseVal && FCI->getOperand(1) == TrueVal){ -      // Canonicalize to use ordered comparisons by swapping the select -      // operands. -      // -      // e.g. -      // (X ugt Y) ? X : Y -> (X ole Y) ? X : Y -      if (FCI->hasOneUse() && FCmpInst::isUnordered(FCI->getPredicate())) { -        FCmpInst::Predicate InvPred = FCI->getInversePredicate(); -        IRBuilder<>::FastMathFlagGuard FMFG(Builder); -        Builder.setFastMathFlags(FCI->getFastMathFlags()); -        Value *NewCond = Builder.CreateFCmp(InvPred, FalseVal, TrueVal, +        Value *NewCond = Builder.CreateFCmp(InvPred, Cmp0, Cmp1,                                              FCI->getName() + ".inv"); - -        return SelectInst::Create(NewCond, FalseVal, TrueVal, -                                  SI.getName() + ".p"); +        Value *NewSel = Builder.CreateSelect(NewCond, FalseVal, TrueVal); +        return replaceInstUsesWith(SI, NewSel);        }        // NOTE: if we wanted to, this is where to detect MIN/MAX @@ -2391,6 +2512,8 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {    if (Instruction *Add = foldAddSubSelect(SI, Builder))      return Add; +  if (Instruction *Add = foldOverflowingAddSubSelect(SI, Builder)) +    return Add;    // Turn (select C, (op X, Y), (op X, Z)) -> (op X, (select C, Y, Z))    auto *TI = dyn_cast<Instruction>(TrueVal);  | 
