aboutsummaryrefslogtreecommitdiff
path: root/lib/Transforms/InstCombine/InstCombineSelect.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'lib/Transforms/InstCombine/InstCombineSelect.cpp')
-rw-r--r--lib/Transforms/InstCombine/InstCombineSelect.cpp288
1 files changed, 237 insertions, 51 deletions
diff --git a/lib/Transforms/InstCombine/InstCombineSelect.cpp b/lib/Transforms/InstCombine/InstCombineSelect.cpp
index faf58a08976d..aefaf5af1750 100644
--- a/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -1,9 +1,8 @@
//===- InstCombineSelect.cpp ----------------------------------------------===//
//
-// The LLVM Compiler Infrastructure
-//
-// This file is distributed under the University of Illinois Open Source
-// License. See LICENSE.TXT for details.
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
@@ -293,6 +292,8 @@ Instruction *InstCombiner::foldSelectOpOp(SelectInst &SI, Instruction *TI,
return nullptr;
// If this is a cast from the same type, merge.
+ Value *Cond = SI.getCondition();
+ Type *CondTy = Cond->getType();
if (TI->getNumOperands() == 1 && TI->isCast()) {
Type *FIOpndTy = FI->getOperand(0)->getType();
if (TI->getOperand(0)->getType() != FIOpndTy)
@@ -300,7 +301,6 @@ Instruction *InstCombiner::foldSelectOpOp(SelectInst &SI, Instruction *TI,
// The select condition may be a vector. We may only change the operand
// type if the vector width remains the same (and matches the condition).
- Type *CondTy = SI.getCondition()->getType();
if (CondTy->isVectorTy()) {
if (!FIOpndTy->isVectorTy())
return nullptr;
@@ -327,12 +327,24 @@ Instruction *InstCombiner::foldSelectOpOp(SelectInst &SI, Instruction *TI,
// Fold this by inserting a select from the input values.
Value *NewSI =
- Builder.CreateSelect(SI.getCondition(), TI->getOperand(0),
- FI->getOperand(0), SI.getName() + ".v", &SI);
+ Builder.CreateSelect(Cond, TI->getOperand(0), FI->getOperand(0),
+ SI.getName() + ".v", &SI);
return CastInst::Create(Instruction::CastOps(TI->getOpcode()), NewSI,
TI->getType());
}
+ // Cond ? -X : -Y --> -(Cond ? X : Y)
+ Value *X, *Y;
+ if (match(TI, m_FNeg(m_Value(X))) && match(FI, m_FNeg(m_Value(Y))) &&
+ (TI->hasOneUse() || FI->hasOneUse())) {
+ Value *NewSel = Builder.CreateSelect(Cond, X, Y, SI.getName() + ".v", &SI);
+ // TODO: Remove the hack for the binop form when the unary op is optimized
+ // properly with all IR passes.
+ if (TI->getOpcode() != Instruction::FNeg)
+ return BinaryOperator::CreateFNegFMF(NewSel, cast<BinaryOperator>(TI));
+ return UnaryOperator::CreateFNeg(NewSel);
+ }
+
// Only handle binary operators (including two-operand getelementptr) with
// one-use here. As with the cast case above, it may be possible to relax the
// one-use constraint, but that needs be examined carefully since it may not
@@ -374,13 +386,12 @@ Instruction *InstCombiner::foldSelectOpOp(SelectInst &SI, Instruction *TI,
// If the select condition is a vector, the operands of the original select's
// operands also must be vectors. This may not be the case for getelementptr
// for example.
- if (SI.getCondition()->getType()->isVectorTy() &&
- (!OtherOpT->getType()->isVectorTy() ||
- !OtherOpF->getType()->isVectorTy()))
+ if (CondTy->isVectorTy() && (!OtherOpT->getType()->isVectorTy() ||
+ !OtherOpF->getType()->isVectorTy()))
return nullptr;
// If we reach here, they do have operations in common.
- Value *NewSI = Builder.CreateSelect(SI.getCondition(), OtherOpT, OtherOpF,
+ Value *NewSI = Builder.CreateSelect(Cond, OtherOpT, OtherOpF,
SI.getName() + ".v", &SI);
Value *Op0 = MatchIsOpZero ? MatchOp : NewSI;
Value *Op1 = MatchIsOpZero ? NewSI : MatchOp;
@@ -521,6 +532,46 @@ static Instruction *foldSelectICmpAndAnd(Type *SelType, const ICmpInst *Cmp,
}
/// We want to turn:
+/// (select (icmp sgt X, C), (lshr X, Y), (ashr X, Y)); iff C s>= -1
+/// (select (icmp slt X, C), (ashr X, Y), (lshr X, Y)); iff C s>= 0
+/// into:
+/// ashr (X, Y)
+static Value *foldSelectICmpLshrAshr(const ICmpInst *IC, Value *TrueVal,
+ Value *FalseVal,
+ InstCombiner::BuilderTy &Builder) {
+ ICmpInst::Predicate Pred = IC->getPredicate();
+ Value *CmpLHS = IC->getOperand(0);
+ Value *CmpRHS = IC->getOperand(1);
+ if (!CmpRHS->getType()->isIntOrIntVectorTy())
+ return nullptr;
+
+ Value *X, *Y;
+ unsigned Bitwidth = CmpRHS->getType()->getScalarSizeInBits();
+ if ((Pred != ICmpInst::ICMP_SGT ||
+ !match(CmpRHS,
+ m_SpecificInt_ICMP(ICmpInst::ICMP_SGE, APInt(Bitwidth, -1)))) &&
+ (Pred != ICmpInst::ICMP_SLT ||
+ !match(CmpRHS,
+ m_SpecificInt_ICMP(ICmpInst::ICMP_SGE, APInt(Bitwidth, 0)))))
+ return nullptr;
+
+ // Canonicalize so that ashr is in FalseVal.
+ if (Pred == ICmpInst::ICMP_SLT)
+ std::swap(TrueVal, FalseVal);
+
+ if (match(TrueVal, m_LShr(m_Value(X), m_Value(Y))) &&
+ match(FalseVal, m_AShr(m_Specific(X), m_Specific(Y))) &&
+ match(CmpLHS, m_Specific(X))) {
+ const auto *Ashr = cast<Instruction>(FalseVal);
+    // If the lshr is not exact but the ashr is, the new ashr must not be exact.
+ bool IsExact = Ashr->isExact() && cast<Instruction>(TrueVal)->isExact();
+ return Builder.CreateAShr(X, Y, IC->getName(), IsExact);
+ }
+
+ return nullptr;
+}
+
+/// We want to turn:
/// (select (icmp eq (and X, C1), 0), Y, (or Y, C2))
/// into:
/// (or (shl (and X, C1), C3), Y)
@@ -623,11 +674,7 @@ static Value *foldSelectICmpAndOr(const ICmpInst *IC, Value *TrueVal,
return Builder.CreateOr(V, Y);
}
-/// Transform patterns such as: (a > b) ? a - b : 0
-/// into: ((a > b) ? a : b) - b)
-/// This produces a canonical max pattern that is more easily recognized by the
-/// backend and converted into saturated subtraction instructions if those
-/// exist.
+/// Transform patterns such as (a > b) ? a - b : 0 into usub.sat(a, b).
/// There are 8 commuted/swapped variants of this pattern.
/// TODO: Also support a - UMIN(a,b) patterns.
static Value *canonicalizeSaturatedSubtract(const ICmpInst *ICI,
@@ -669,11 +716,73 @@ static Value *canonicalizeSaturatedSubtract(const ICmpInst *ICI,
if (!TrueVal->hasOneUse())
return nullptr;
- // All checks passed, convert to canonical unsigned saturated subtraction
- // form: sub(max()).
- // (a > b) ? a - b : 0 -> ((a > b) ? a : b) - b)
- Value *Max = Builder.CreateSelect(Builder.CreateICmp(Pred, A, B), A, B);
- return IsNegative ? Builder.CreateSub(B, Max) : Builder.CreateSub(Max, B);
+ // (a > b) ? a - b : 0 -> usub.sat(a, b)
+ // (a > b) ? b - a : 0 -> -usub.sat(a, b)
+ Value *Result = Builder.CreateBinaryIntrinsic(Intrinsic::usub_sat, A, B);
+ if (IsNegative)
+ Result = Builder.CreateNeg(Result);
+ return Result;
+}
+
+static Value *canonicalizeSaturatedAdd(ICmpInst *Cmp, Value *TVal, Value *FVal,
+ InstCombiner::BuilderTy &Builder) {
+ if (!Cmp->hasOneUse())
+ return nullptr;
+
+ // Match unsigned saturated add with constant.
+ Value *Cmp0 = Cmp->getOperand(0);
+ Value *Cmp1 = Cmp->getOperand(1);
+ ICmpInst::Predicate Pred = Cmp->getPredicate();
+ Value *X;
+ const APInt *C, *CmpC;
+ if (Pred == ICmpInst::ICMP_ULT &&
+ match(TVal, m_Add(m_Value(X), m_APInt(C))) && X == Cmp0 &&
+ match(FVal, m_AllOnes()) && match(Cmp1, m_APInt(CmpC)) && *CmpC == ~*C) {
+ // (X u< ~C) ? (X + C) : -1 --> uadd.sat(X, C)
+ return Builder.CreateBinaryIntrinsic(
+ Intrinsic::uadd_sat, X, ConstantInt::get(X->getType(), *C));
+ }
+
+ // Match unsigned saturated add of 2 variables with an unnecessary 'not'.
+ // There are 8 commuted variants.
+ // Canonicalize -1 (saturated result) to true value of the select. Just
+ // swapping the compare operands is legal, because the selected value is the
+ // same in case of equality, so we can interchange u< and u<=.
+ if (match(FVal, m_AllOnes())) {
+ std::swap(TVal, FVal);
+ std::swap(Cmp0, Cmp1);
+ }
+ if (!match(TVal, m_AllOnes()))
+ return nullptr;
+
+ // Canonicalize predicate to 'ULT'.
+ if (Pred == ICmpInst::ICMP_UGT) {
+ Pred = ICmpInst::ICMP_ULT;
+ std::swap(Cmp0, Cmp1);
+ }
+ if (Pred != ICmpInst::ICMP_ULT)
+ return nullptr;
+
+ // Match unsigned saturated add of 2 variables with an unnecessary 'not'.
+ Value *Y;
+ if (match(Cmp0, m_Not(m_Value(X))) &&
+ match(FVal, m_c_Add(m_Specific(X), m_Value(Y))) && Y == Cmp1) {
+ // (~X u< Y) ? -1 : (X + Y) --> uadd.sat(X, Y)
+ // (~X u< Y) ? -1 : (Y + X) --> uadd.sat(X, Y)
+ return Builder.CreateBinaryIntrinsic(Intrinsic::uadd_sat, X, Y);
+ }
+ // The 'not' op may be included in the sum but not the compare.
+ X = Cmp0;
+ Y = Cmp1;
+ if (match(FVal, m_c_Add(m_Not(m_Specific(X)), m_Specific(Y)))) {
+ // (X u< Y) ? -1 : (~X + Y) --> uadd.sat(~X, Y)
+ // (X u< Y) ? -1 : (Y + ~X) --> uadd.sat(Y, ~X)
+ BinaryOperator *BO = cast<BinaryOperator>(FVal);
+ return Builder.CreateBinaryIntrinsic(
+ Intrinsic::uadd_sat, BO->getOperand(0), BO->getOperand(1));
+ }
+
+ return nullptr;
}
/// Attempt to fold a cttz/ctlz followed by a icmp plus select into a single
@@ -1043,12 +1152,18 @@ Instruction *InstCombiner::foldSelectInstWithICmp(SelectInst &SI,
if (Value *V = foldSelectICmpAndOr(ICI, TrueVal, FalseVal, Builder))
return replaceInstUsesWith(SI, V);
+ if (Value *V = foldSelectICmpLshrAshr(ICI, TrueVal, FalseVal, Builder))
+ return replaceInstUsesWith(SI, V);
+
if (Value *V = foldSelectCttzCtlz(ICI, TrueVal, FalseVal, Builder))
return replaceInstUsesWith(SI, V);
if (Value *V = canonicalizeSaturatedSubtract(ICI, TrueVal, FalseVal, Builder))
return replaceInstUsesWith(SI, V);
+ if (Value *V = canonicalizeSaturatedAdd(ICI, TrueVal, FalseVal, Builder))
+ return replaceInstUsesWith(SI, V);
+
return Changed ? &SI : nullptr;
}
@@ -1496,6 +1611,43 @@ static Instruction *foldSelectCmpXchg(SelectInst &SI) {
return nullptr;
}
+static Instruction *moveAddAfterMinMax(SelectPatternFlavor SPF, Value *X,
+ Value *Y,
+ InstCombiner::BuilderTy &Builder) {
+ assert(SelectPatternResult::isMinOrMax(SPF) && "Expected min/max pattern");
+ bool IsUnsigned = SPF == SelectPatternFlavor::SPF_UMIN ||
+ SPF == SelectPatternFlavor::SPF_UMAX;
+ // TODO: If InstSimplify could fold all cases where C2 <= C1, we could change
+ // the constant value check to an assert.
+ Value *A;
+ const APInt *C1, *C2;
+ if (IsUnsigned && match(X, m_NUWAdd(m_Value(A), m_APInt(C1))) &&
+ match(Y, m_APInt(C2)) && C2->uge(*C1) && X->hasNUses(2)) {
+ // umin (add nuw A, C1), C2 --> add nuw (umin A, C2 - C1), C1
+ // umax (add nuw A, C1), C2 --> add nuw (umax A, C2 - C1), C1
+ Value *NewMinMax = createMinMax(Builder, SPF, A,
+ ConstantInt::get(X->getType(), *C2 - *C1));
+ return BinaryOperator::CreateNUW(BinaryOperator::Add, NewMinMax,
+ ConstantInt::get(X->getType(), *C1));
+ }
+
+ if (!IsUnsigned && match(X, m_NSWAdd(m_Value(A), m_APInt(C1))) &&
+ match(Y, m_APInt(C2)) && X->hasNUses(2)) {
+ bool Overflow;
+ APInt Diff = C2->ssub_ov(*C1, Overflow);
+ if (!Overflow) {
+ // smin (add nsw A, C1), C2 --> add nsw (smin A, C2 - C1), C1
+ // smax (add nsw A, C1), C2 --> add nsw (smax A, C2 - C1), C1
+ Value *NewMinMax = createMinMax(Builder, SPF, A,
+ ConstantInt::get(X->getType(), Diff));
+ return BinaryOperator::CreateNSW(BinaryOperator::Add, NewMinMax,
+ ConstantInt::get(X->getType(), *C1));
+ }
+ }
+
+ return nullptr;
+}
+
/// Reduce a sequence of min/max with a common operand.
static Instruction *factorizeMinMaxTree(SelectPatternFlavor SPF, Value *LHS,
Value *RHS,
@@ -1757,37 +1909,55 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {
// NOTE: if we wanted to, this is where to detect MIN/MAX
}
+ }
- // Canonicalize select with fcmp to fabs(). -0.0 makes this tricky. We need
- // fast-math-flags (nsz) or fsub with +0.0 (not fneg) for this to work. We
- // also require nnan because we do not want to unintentionally change the
- // sign of a NaN value.
- Value *X = FCI->getOperand(0);
- FCmpInst::Predicate Pred = FCI->getPredicate();
- if (match(FCI->getOperand(1), m_AnyZeroFP()) && FCI->hasNoNaNs()) {
- // (X <= +/-0.0) ? (0.0 - X) : X --> fabs(X)
- // (X > +/-0.0) ? X : (0.0 - X) --> fabs(X)
- if ((X == FalseVal && Pred == FCmpInst::FCMP_OLE &&
- match(TrueVal, m_FSub(m_PosZeroFP(), m_Specific(X)))) ||
- (X == TrueVal && Pred == FCmpInst::FCMP_OGT &&
- match(FalseVal, m_FSub(m_PosZeroFP(), m_Specific(X))))) {
- Value *Fabs = Builder.CreateUnaryIntrinsic(Intrinsic::fabs, X, FCI);
- return replaceInstUsesWith(SI, Fabs);
- }
- // With nsz:
- // (X < +/-0.0) ? -X : X --> fabs(X)
- // (X <= +/-0.0) ? -X : X --> fabs(X)
- // (X > +/-0.0) ? X : -X --> fabs(X)
- // (X >= +/-0.0) ? X : -X --> fabs(X)
- if (FCI->hasNoSignedZeros() &&
- ((X == FalseVal && match(TrueVal, m_FNeg(m_Specific(X))) &&
- (Pred == FCmpInst::FCMP_OLT || Pred == FCmpInst::FCMP_OLE)) ||
- (X == TrueVal && match(FalseVal, m_FNeg(m_Specific(X))) &&
- (Pred == FCmpInst::FCMP_OGT || Pred == FCmpInst::FCMP_OGE)))) {
- Value *Fabs = Builder.CreateUnaryIntrinsic(Intrinsic::fabs, X, FCI);
- return replaceInstUsesWith(SI, Fabs);
- }
- }
+ // Canonicalize select with fcmp to fabs(). -0.0 makes this tricky. We need
+ // fast-math-flags (nsz) or fsub with +0.0 (not fneg) for this to work. We
+ // also require nnan because we do not want to unintentionally change the
+ // sign of a NaN value.
+ // FIXME: These folds should test/propagate FMF from the select, not the
+ // fsub or fneg.
+ // (X <= +/-0.0) ? (0.0 - X) : X --> fabs(X)
+ Instruction *FSub;
+ if (match(CondVal, m_FCmp(Pred, m_Specific(FalseVal), m_AnyZeroFP())) &&
+ match(TrueVal, m_FSub(m_PosZeroFP(), m_Specific(FalseVal))) &&
+ match(TrueVal, m_Instruction(FSub)) && FSub->hasNoNaNs() &&
+ (Pred == FCmpInst::FCMP_OLE || Pred == FCmpInst::FCMP_ULE)) {
+ Value *Fabs = Builder.CreateUnaryIntrinsic(Intrinsic::fabs, FalseVal, FSub);
+ return replaceInstUsesWith(SI, Fabs);
+ }
+ // (X > +/-0.0) ? X : (0.0 - X) --> fabs(X)
+ if (match(CondVal, m_FCmp(Pred, m_Specific(TrueVal), m_AnyZeroFP())) &&
+ match(FalseVal, m_FSub(m_PosZeroFP(), m_Specific(TrueVal))) &&
+ match(FalseVal, m_Instruction(FSub)) && FSub->hasNoNaNs() &&
+ (Pred == FCmpInst::FCMP_OGT || Pred == FCmpInst::FCMP_UGT)) {
+ Value *Fabs = Builder.CreateUnaryIntrinsic(Intrinsic::fabs, TrueVal, FSub);
+ return replaceInstUsesWith(SI, Fabs);
+ }
+ // With nnan and nsz:
+ // (X < +/-0.0) ? -X : X --> fabs(X)
+ // (X <= +/-0.0) ? -X : X --> fabs(X)
+ Instruction *FNeg;
+ if (match(CondVal, m_FCmp(Pred, m_Specific(FalseVal), m_AnyZeroFP())) &&
+ match(TrueVal, m_FNeg(m_Specific(FalseVal))) &&
+ match(TrueVal, m_Instruction(FNeg)) &&
+ FNeg->hasNoNaNs() && FNeg->hasNoSignedZeros() &&
+ (Pred == FCmpInst::FCMP_OLT || Pred == FCmpInst::FCMP_OLE ||
+ Pred == FCmpInst::FCMP_ULT || Pred == FCmpInst::FCMP_ULE)) {
+ Value *Fabs = Builder.CreateUnaryIntrinsic(Intrinsic::fabs, FalseVal, FNeg);
+ return replaceInstUsesWith(SI, Fabs);
+ }
+ // With nnan and nsz:
+ // (X > +/-0.0) ? X : -X --> fabs(X)
+ // (X >= +/-0.0) ? X : -X --> fabs(X)
+ if (match(CondVal, m_FCmp(Pred, m_Specific(TrueVal), m_AnyZeroFP())) &&
+ match(FalseVal, m_FNeg(m_Specific(TrueVal))) &&
+ match(FalseVal, m_Instruction(FNeg)) &&
+ FNeg->hasNoNaNs() && FNeg->hasNoSignedZeros() &&
+ (Pred == FCmpInst::FCMP_OGT || Pred == FCmpInst::FCMP_OGE ||
+ Pred == FCmpInst::FCMP_UGT || Pred == FCmpInst::FCMP_UGE)) {
+ Value *Fabs = Builder.CreateUnaryIntrinsic(Intrinsic::fabs, TrueVal, FNeg);
+ return replaceInstUsesWith(SI, Fabs);
}
// See if we are selecting two values based on a comparison of the two values.
@@ -1895,11 +2065,27 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {
if (Instruction *I = moveNotAfterMinMax(RHS, LHS))
return I;
+ if (Instruction *I = moveAddAfterMinMax(SPF, LHS, RHS, Builder))
+ return I;
+
if (Instruction *I = factorizeMinMaxTree(SPF, LHS, RHS, Builder))
return I;
}
}
+ // Canonicalize select of FP values where NaN and -0.0 are not valid as
+ // minnum/maxnum intrinsics.
+ if (isa<FPMathOperator>(SI) && SI.hasNoNaNs() && SI.hasNoSignedZeros()) {
+ Value *X, *Y;
+ if (match(&SI, m_OrdFMax(m_Value(X), m_Value(Y))))
+ return replaceInstUsesWith(
+ SI, Builder.CreateBinaryIntrinsic(Intrinsic::maxnum, X, Y, &SI));
+
+ if (match(&SI, m_OrdFMin(m_Value(X), m_Value(Y))))
+ return replaceInstUsesWith(
+ SI, Builder.CreateBinaryIntrinsic(Intrinsic::minnum, X, Y, &SI));
+ }
+
// See if we can fold the select into a phi node if the condition is a select.
if (auto *PN = dyn_cast<PHINode>(SI.getCondition()))
// The true/false values have to be live in the PHI predecessor's blocks.