summary | refs | log | tree | commit | diff
path: root/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
diff options
context:
space:
mode:
author Dimitry Andric <dim@FreeBSD.org> 2021-02-16 20:13:02 +0000
committer Dimitry Andric <dim@FreeBSD.org> 2021-02-16 20:13:02 +0000
commit b60736ec1405bb0a8dd40989f67ef4c93da068ab (patch)
tree 5c43fbb7c9fc45f0f87e0e6795a86267dbd12f9d /llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
parent cfca06d7963fa0909f90483b42a6d7d194d01e08 (diff)
Diffstat (limited to 'llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp')
-rw-r--r-- llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp 489
1 files changed, 271 insertions, 218 deletions
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index 17124f717af7..f26c194d31b9 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -38,6 +38,7 @@
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/KnownBits.h"
#include "llvm/Transforms/InstCombine/InstCombineWorklist.h"
+#include "llvm/Transforms/InstCombine/InstCombiner.h"
#include <cassert>
#include <utility>
@@ -46,6 +47,11 @@ using namespace PatternMatch;
#define DEBUG_TYPE "instcombine"
+/// FIXME: Enabled by default until the pattern is supported well.
+/// When disabled, visitSelectInst folds `select B, true, C` -> `or B, C` and
+/// `select B, C, false` -> `and B, C` only if impliesPoison(C, B) holds.
+static cl::opt<bool> EnableUnsafeSelectTransform(
+ "instcombine-unsafe-select-transform", cl::init(true),
+ cl::desc("Enable poison-unsafe select to and/or transform"));
+
static Value *createMinMax(InstCombiner::BuilderTy &Builder,
SelectPatternFlavor SPF, Value *A, Value *B) {
CmpInst::Predicate Pred = getMinMaxPred(SPF);
@@ -57,7 +63,7 @@ static Value *createMinMax(InstCombiner::BuilderTy &Builder,
/// constant of a binop.
static Instruction *foldSelectBinOpIdentity(SelectInst &Sel,
const TargetLibraryInfo &TLI,
- InstCombiner &IC) {
+ InstCombinerImpl &IC) {
// The select condition must be an equality compare with a constant operand.
Value *X;
Constant *C;
@@ -258,29 +264,9 @@ static unsigned getSelectFoldableOperands(BinaryOperator *I) {
}
}
-/// For the same transformation as the previous function, return the identity
-/// constant that goes into the select.
-static APInt getSelectFoldableConstant(BinaryOperator *I) {
- switch (I->getOpcode()) {
- default: llvm_unreachable("This cannot happen!");
- case Instruction::Add:
- case Instruction::Sub:
- case Instruction::Or:
- case Instruction::Xor:
- case Instruction::Shl:
- case Instruction::LShr:
- case Instruction::AShr:
- return APInt::getNullValue(I->getType()->getScalarSizeInBits());
- case Instruction::And:
- return APInt::getAllOnesValue(I->getType()->getScalarSizeInBits());
- case Instruction::Mul:
- return APInt(I->getType()->getScalarSizeInBits(), 1);
- }
-}
-
/// We have (select c, TI, FI), and we know that TI and FI have the same opcode.
-Instruction *InstCombiner::foldSelectOpOp(SelectInst &SI, Instruction *TI,
- Instruction *FI) {
+Instruction *InstCombinerImpl::foldSelectOpOp(SelectInst &SI, Instruction *TI,
+ Instruction *FI) {
// Don't break up min/max patterns. The hasOneUse checks below prevent that
// for most cases, but vector min/max with bitcasts can be transformed. If the
// one-use restrictions are eased for other patterns, we still don't want to
@@ -302,10 +288,9 @@ Instruction *InstCombiner::foldSelectOpOp(SelectInst &SI, Instruction *TI,
// The select condition may be a vector. We may only change the operand
// type if the vector width remains the same (and matches the condition).
if (auto *CondVTy = dyn_cast<VectorType>(CondTy)) {
- if (!FIOpndTy->isVectorTy())
- return nullptr;
- if (CondVTy->getNumElements() !=
- cast<VectorType>(FIOpndTy)->getNumElements())
+ if (!FIOpndTy->isVectorTy() ||
+ CondVTy->getElementCount() !=
+ cast<VectorType>(FIOpndTy)->getElementCount())
return nullptr;
// TODO: If the backend knew how to deal with casts better, we could
@@ -418,8 +403,8 @@ static bool isSelect01(const APInt &C1I, const APInt &C2I) {
/// Try to fold the select into one of the operands to allow further
/// optimization.
-Instruction *InstCombiner::foldSelectIntoOp(SelectInst &SI, Value *TrueVal,
- Value *FalseVal) {
+Instruction *InstCombinerImpl::foldSelectIntoOp(SelectInst &SI, Value *TrueVal,
+ Value *FalseVal) {
// See the comment above GetSelectFoldableOperands for a description of the
// transformation we are doing here.
if (auto *TVI = dyn_cast<BinaryOperator>(TrueVal)) {
@@ -433,14 +418,15 @@ Instruction *InstCombiner::foldSelectIntoOp(SelectInst &SI, Value *TrueVal,
}
if (OpToFold) {
- APInt CI = getSelectFoldableConstant(TVI);
+ Constant *C = ConstantExpr::getBinOpIdentity(TVI->getOpcode(),
+ TVI->getType(), true);
Value *OOp = TVI->getOperand(2-OpToFold);
// Avoid creating select between 2 constants unless it's selecting
// between 0, 1 and -1.
const APInt *OOpC;
bool OOpIsAPInt = match(OOp, m_APInt(OOpC));
- if (!isa<Constant>(OOp) || (OOpIsAPInt && isSelect01(CI, *OOpC))) {
- Value *C = ConstantInt::get(OOp->getType(), CI);
+ if (!isa<Constant>(OOp) ||
+ (OOpIsAPInt && isSelect01(C->getUniqueInteger(), *OOpC))) {
Value *NewSel = Builder.CreateSelect(SI.getCondition(), OOp, C);
NewSel->takeName(TVI);
BinaryOperator *BO = BinaryOperator::Create(TVI->getOpcode(),
@@ -464,14 +450,15 @@ Instruction *InstCombiner::foldSelectIntoOp(SelectInst &SI, Value *TrueVal,
}
if (OpToFold) {
- APInt CI = getSelectFoldableConstant(FVI);
+ Constant *C = ConstantExpr::getBinOpIdentity(FVI->getOpcode(),
+ FVI->getType(), true);
Value *OOp = FVI->getOperand(2-OpToFold);
// Avoid creating select between 2 constants unless it's selecting
// between 0, 1 and -1.
const APInt *OOpC;
bool OOpIsAPInt = match(OOp, m_APInt(OOpC));
- if (!isa<Constant>(OOp) || (OOpIsAPInt && isSelect01(CI, *OOpC))) {
- Value *C = ConstantInt::get(OOp->getType(), CI);
+ if (!isa<Constant>(OOp) ||
+ (OOpIsAPInt && isSelect01(C->getUniqueInteger(), *OOpC))) {
Value *NewSel = Builder.CreateSelect(SI.getCondition(), C, OOp);
NewSel->takeName(FVI);
BinaryOperator *BO = BinaryOperator::Create(FVI->getOpcode(),
@@ -782,25 +769,24 @@ static Value *canonicalizeSaturatedAdd(ICmpInst *Cmp, Value *TVal, Value *FVal,
// Match unsigned saturated add of 2 variables with an unnecessary 'not'.
// There are 8 commuted variants.
- // Canonicalize -1 (saturated result) to true value of the select. Just
- // swapping the compare operands is legal, because the selected value is the
- // same in case of equality, so we can interchange u< and u<=.
+ // Canonicalize -1 (saturated result) to true value of the select.
if (match(FVal, m_AllOnes())) {
std::swap(TVal, FVal);
- std::swap(Cmp0, Cmp1);
+ Pred = CmpInst::getInversePredicate(Pred);
}
if (!match(TVal, m_AllOnes()))
return nullptr;
- // Canonicalize predicate to 'ULT'.
- if (Pred == ICmpInst::ICMP_UGT) {
- Pred = ICmpInst::ICMP_ULT;
+ // Canonicalize predicate to less-than or less-or-equal-than.
+ if (Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_UGE) {
std::swap(Cmp0, Cmp1);
+ Pred = CmpInst::getSwappedPredicate(Pred);
}
- if (Pred != ICmpInst::ICMP_ULT)
+ if (Pred != ICmpInst::ICMP_ULT && Pred != ICmpInst::ICMP_ULE)
return nullptr;
// Match unsigned saturated add of 2 variables with an unnecessary 'not'.
+ // Strictness of the comparison is irrelevant.
Value *Y;
if (match(Cmp0, m_Not(m_Value(X))) &&
match(FVal, m_c_Add(m_Specific(X), m_Value(Y))) && Y == Cmp1) {
@@ -809,6 +795,7 @@ static Value *canonicalizeSaturatedAdd(ICmpInst *Cmp, Value *TVal, Value *FVal,
return Builder.CreateBinaryIntrinsic(Intrinsic::uadd_sat, X, Y);
}
// The 'not' op may be included in the sum but not the compare.
+ // Strictness of the comparison is irrelevant.
X = Cmp0;
Y = Cmp1;
if (match(FVal, m_c_Add(m_Not(m_Specific(X)), m_Specific(Y)))) {
@@ -819,7 +806,9 @@ static Value *canonicalizeSaturatedAdd(ICmpInst *Cmp, Value *TVal, Value *FVal,
Intrinsic::uadd_sat, BO->getOperand(0), BO->getOperand(1));
}
// The overflow may be detected via the add wrapping round.
- if (match(Cmp0, m_c_Add(m_Specific(Cmp1), m_Value(Y))) &&
+ // This is only valid for strict comparison!
+ if (Pred == ICmpInst::ICMP_ULT &&
+ match(Cmp0, m_c_Add(m_Specific(Cmp1), m_Value(Y))) &&
match(FVal, m_c_Add(m_Specific(Cmp1), m_Specific(Y)))) {
// ((X + Y) u< X) ? -1 : (X + Y) --> uadd.sat(X, Y)
// ((X + Y) u< Y) ? -1 : (X + Y) --> uadd.sat(X, Y)
@@ -1024,9 +1013,9 @@ static bool adjustMinMax(SelectInst &Sel, ICmpInst &Cmp) {
/// select (icmp Pred X, C1), C2, X --> select (icmp Pred' X, C2), X, C2
/// Note: if C1 != C2, this will change the icmp constant to the existing
/// constant operand of the select.
-static Instruction *
-canonicalizeMinMaxWithConstant(SelectInst &Sel, ICmpInst &Cmp,
- InstCombiner &IC) {
+static Instruction *canonicalizeMinMaxWithConstant(SelectInst &Sel,
+ ICmpInst &Cmp,
+ InstCombinerImpl &IC) {
if (!Cmp.hasOneUse() || !isa<Constant>(Cmp.getOperand(1)))
return nullptr;
@@ -1063,105 +1052,29 @@ canonicalizeMinMaxWithConstant(SelectInst &Sel, ICmpInst &Cmp,
return &Sel;
}
-/// There are many select variants for each of ABS/NABS.
-/// In matchSelectPattern(), there are different compare constants, compare
-/// predicates/operands and select operands.
-/// In isKnownNegation(), there are different formats of negated operands.
-/// Canonicalize all these variants to 1 pattern.
-/// This makes CSE more likely.
static Instruction *canonicalizeAbsNabs(SelectInst &Sel, ICmpInst &Cmp,
- InstCombiner &IC) {
+ InstCombinerImpl &IC) {
if (!Cmp.hasOneUse() || !isa<Constant>(Cmp.getOperand(1)))
return nullptr;
- // Choose a sign-bit check for the compare (likely simpler for codegen).
- // ABS: (X <s 0) ? -X : X
- // NABS: (X <s 0) ? X : -X
Value *LHS, *RHS;
SelectPatternFlavor SPF = matchSelectPattern(&Sel, LHS, RHS).Flavor;
if (SPF != SelectPatternFlavor::SPF_ABS &&
SPF != SelectPatternFlavor::SPF_NABS)
return nullptr;
- Value *TVal = Sel.getTrueValue();
- Value *FVal = Sel.getFalseValue();
- assert(isKnownNegation(TVal, FVal) &&
- "Unexpected result from matchSelectPattern");
-
- // The compare may use the negated abs()/nabs() operand, or it may use
- // negation in non-canonical form such as: sub A, B.
- bool CmpUsesNegatedOp = match(Cmp.getOperand(0), m_Neg(m_Specific(TVal))) ||
- match(Cmp.getOperand(0), m_Neg(m_Specific(FVal)));
-
- bool CmpCanonicalized = !CmpUsesNegatedOp &&
- match(Cmp.getOperand(1), m_ZeroInt()) &&
- Cmp.getPredicate() == ICmpInst::ICMP_SLT;
- bool RHSCanonicalized = match(RHS, m_Neg(m_Specific(LHS)));
-
- // Is this already canonical?
- if (CmpCanonicalized && RHSCanonicalized)
- return nullptr;
+ // Note that NSW flag can only be propagated for normal, non-negated abs!
+ bool IntMinIsPoison = SPF == SelectPatternFlavor::SPF_ABS &&
+ match(RHS, m_NSWNeg(m_Specific(LHS)));
+ Constant *IntMinIsPoisonC =
+ ConstantInt::get(Type::getInt1Ty(Sel.getContext()), IntMinIsPoison);
+ Instruction *Abs =
+ IC.Builder.CreateBinaryIntrinsic(Intrinsic::abs, LHS, IntMinIsPoisonC);
- // If RHS is not canonical but is used by other instructions, don't
- // canonicalize it and potentially increase the instruction count.
- if (!RHSCanonicalized)
- if (!(RHS->hasOneUse() || (RHS->hasNUses(2) && CmpUsesNegatedOp)))
- return nullptr;
+ if (SPF == SelectPatternFlavor::SPF_NABS)
+ return BinaryOperator::CreateNeg(Abs); // Always without NSW flag!
- // Create the canonical compare: icmp slt LHS 0.
- if (!CmpCanonicalized) {
- Cmp.setPredicate(ICmpInst::ICMP_SLT);
- Cmp.setOperand(1, ConstantInt::getNullValue(Cmp.getOperand(0)->getType()));
- if (CmpUsesNegatedOp)
- Cmp.setOperand(0, LHS);
- }
-
- // Create the canonical RHS: RHS = sub (0, LHS).
- if (!RHSCanonicalized) {
- assert(RHS->hasOneUse() && "RHS use number is not right");
- RHS = IC.Builder.CreateNeg(LHS);
- if (TVal == LHS) {
- // Replace false value.
- IC.replaceOperand(Sel, 2, RHS);
- FVal = RHS;
- } else {
- // Replace true value.
- IC.replaceOperand(Sel, 1, RHS);
- TVal = RHS;
- }
- }
-
- // If the select operands do not change, we're done.
- if (SPF == SelectPatternFlavor::SPF_NABS) {
- if (TVal == LHS)
- return &Sel;
- assert(FVal == LHS && "Unexpected results from matchSelectPattern");
- } else {
- if (FVal == LHS)
- return &Sel;
- assert(TVal == LHS && "Unexpected results from matchSelectPattern");
- }
-
- // We are swapping the select operands, so swap the metadata too.
- Sel.swapValues();
- Sel.swapProfMetadata();
- return &Sel;
-}
-
-static Value *simplifyWithOpReplaced(Value *V, Value *Op, Value *ReplaceOp,
- const SimplifyQuery &Q) {
- // If this is a binary operator, try to simplify it with the replaced op
- // because we know Op and ReplaceOp are equivalant.
- // For example: V = X + 1, Op = X, ReplaceOp = 42
- // Simplifies as: add(42, 1) --> 43
- if (auto *BO = dyn_cast<BinaryOperator>(V)) {
- if (BO->getOperand(0) == Op)
- return SimplifyBinOp(BO->getOpcode(), ReplaceOp, BO->getOperand(1), Q);
- if (BO->getOperand(1) == Op)
- return SimplifyBinOp(BO->getOpcode(), BO->getOperand(0), ReplaceOp, Q);
- }
-
- return nullptr;
+ return IC.replaceInstUsesWith(Sel, Abs);
}
/// If we have a select with an equality comparison, then we know the value in
@@ -1180,30 +1093,97 @@ static Value *simplifyWithOpReplaced(Value *V, Value *Op, Value *ReplaceOp,
///
/// We can't replace %sel with %add unless we strip away the flags.
/// TODO: Wrapping flags could be preserved in some cases with better analysis.
-static Value *foldSelectValueEquivalence(SelectInst &Sel, ICmpInst &Cmp,
- const SimplifyQuery &Q) {
+Instruction *InstCombinerImpl::foldSelectValueEquivalence(SelectInst &Sel,
+ ICmpInst &Cmp) {
if (!Cmp.isEquality())
return nullptr;
// Canonicalize the pattern to ICMP_EQ by swapping the select operands.
Value *TrueVal = Sel.getTrueValue(), *FalseVal = Sel.getFalseValue();
- if (Cmp.getPredicate() == ICmpInst::ICMP_NE)
+ bool Swapped = false;
+ if (Cmp.getPredicate() == ICmpInst::ICMP_NE) {
std::swap(TrueVal, FalseVal);
+ Swapped = true;
+ }
+
+ // In X == Y ? f(X) : Z, try to evaluate f(Y) and replace the operand.
+ // Make sure Y cannot be undef though, as we might pick different values for
+ // undef in the icmp and in f(Y). Additionally, take care to avoid replacing
+ // X == Y ? X : Z with X == Y ? Y : Z, as that would lead to an infinite
+ // replacement cycle.
+ Value *CmpLHS = Cmp.getOperand(0), *CmpRHS = Cmp.getOperand(1);
+ if (TrueVal != CmpLHS &&
+ isGuaranteedNotToBeUndefOrPoison(CmpRHS, SQ.AC, &Sel, &DT)) {
+ if (Value *V = SimplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, SQ,
+ /* AllowRefinement */ true))
+ return replaceOperand(Sel, Swapped ? 2 : 1, V);
+
+ // Even if TrueVal does not simplify, we can directly replace a use of
+ // CmpLHS with CmpRHS, as long as the instruction is not used anywhere
+ // else and is safe to speculatively execute (we may end up executing it
+ // with different operands, which should not cause side-effects or trigger
+ // undefined behavior). Only do this if CmpRHS is a constant, as
+ // profitability is not clear for other cases.
+ // FIXME: The replacement could be performed recursively.
+ if (match(CmpRHS, m_ImmConstant()) && !match(CmpLHS, m_ImmConstant()))
+ if (auto *I = dyn_cast<Instruction>(TrueVal))
+ if (I->hasOneUse() && isSafeToSpeculativelyExecute(I))
+ for (Use &U : I->operands())
+ if (U == CmpLHS) {
+ replaceUse(U, CmpRHS);
+ return &Sel;
+ }
+ }
+ if (TrueVal != CmpRHS &&
+ isGuaranteedNotToBeUndefOrPoison(CmpLHS, SQ.AC, &Sel, &DT))
+ if (Value *V = SimplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, SQ,
+ /* AllowRefinement */ true))
+ return replaceOperand(Sel, Swapped ? 2 : 1, V);
+
+ auto *FalseInst = dyn_cast<Instruction>(FalseVal);
+ if (!FalseInst)
+ return nullptr;
+
+ // InstSimplify already performed this fold if it was possible subject to
+ // current poison-generating flags. Try the transform again with
+ // poison-generating flags temporarily dropped.
+ bool WasNUW = false, WasNSW = false, WasExact = false, WasInBounds = false;
+ if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(FalseVal)) {
+ WasNUW = OBO->hasNoUnsignedWrap();
+ WasNSW = OBO->hasNoSignedWrap();
+ FalseInst->setHasNoUnsignedWrap(false);
+ FalseInst->setHasNoSignedWrap(false);
+ }
+ if (auto *PEO = dyn_cast<PossiblyExactOperator>(FalseVal)) {
+ WasExact = PEO->isExact();
+ FalseInst->setIsExact(false);
+ }
+ if (auto *GEP = dyn_cast<GetElementPtrInst>(FalseVal)) {
+ WasInBounds = GEP->isInBounds();
+ GEP->setIsInBounds(false);
+ }
// Try each equivalence substitution possibility.
// We have an 'EQ' comparison, so the select's false value will propagate.
// Example:
// (X == 42) ? 43 : (X + 1) --> (X == 42) ? (X + 1) : (X + 1) --> X + 1
- // (X == 42) ? (X + 1) : 43 --> (X == 42) ? (42 + 1) : 43 --> 43
- Value *CmpLHS = Cmp.getOperand(0), *CmpRHS = Cmp.getOperand(1);
- if (simplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, Q) == TrueVal ||
- simplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, Q) == TrueVal ||
- simplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, Q) == FalseVal ||
- simplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, Q) == FalseVal) {
- if (auto *FalseInst = dyn_cast<Instruction>(FalseVal))
- FalseInst->dropPoisonGeneratingFlags();
- return FalseVal;
+ if (SimplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, SQ,
+ /* AllowRefinement */ false) == TrueVal ||
+ SimplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, SQ,
+ /* AllowRefinement */ false) == TrueVal) {
+ return replaceInstUsesWith(Sel, FalseVal);
}
+
+ // Restore poison-generating flags if the transform did not apply.
+ if (WasNUW)
+ FalseInst->setHasNoUnsignedWrap();
+ if (WasNSW)
+ FalseInst->setHasNoSignedWrap();
+ if (WasExact)
+ FalseInst->setIsExact();
+ if (WasInBounds)
+ cast<GetElementPtrInst>(FalseInst)->setIsInBounds();
+
return nullptr;
}
@@ -1253,7 +1233,7 @@ static Instruction *canonicalizeClampLike(SelectInst &Sel0, ICmpInst &Cmp0,
APInt::getAllOnesValue(
C0->getType()->getScalarSizeInBits()))))
return nullptr; // Can't do, have all-ones element[s].
- C0 = AddOne(C0);
+ C0 = InstCombiner::AddOne(C0);
std::swap(X, Sel1);
break;
case ICmpInst::Predicate::ICMP_UGE:
@@ -1313,7 +1293,7 @@ static Instruction *canonicalizeClampLike(SelectInst &Sel0, ICmpInst &Cmp0,
APInt::getSignedMaxValue(
C2->getType()->getScalarSizeInBits()))))
return nullptr; // Can't do, have signed max element[s].
- C2 = AddOne(C2);
+ C2 = InstCombiner::AddOne(C2);
LLVM_FALLTHROUGH;
case ICmpInst::Predicate::ICMP_SGE:
// Also non-canonical, but here we don't need to change C2,
@@ -1360,7 +1340,7 @@ static Instruction *canonicalizeClampLike(SelectInst &Sel0, ICmpInst &Cmp0,
// and swap the hands of select.
static Instruction *
tryToReuseConstantFromSelectInComparison(SelectInst &Sel, ICmpInst &Cmp,
- InstCombiner &IC) {
+ InstCombinerImpl &IC) {
ICmpInst::Predicate Pred;
Value *X;
Constant *C0;
@@ -1375,7 +1355,7 @@ tryToReuseConstantFromSelectInComparison(SelectInst &Sel, ICmpInst &Cmp,
// If comparison predicate is non-canonical, then we certainly won't be able
// to make it canonical; canonicalizeCmpWithConstant() already tried.
- if (!isCanonicalPredicate(Pred))
+ if (!InstCombiner::isCanonicalPredicate(Pred))
return nullptr;
// If the [input] type of comparison and select type are different, lets abort
@@ -1403,7 +1383,8 @@ tryToReuseConstantFromSelectInComparison(SelectInst &Sel, ICmpInst &Cmp,
return nullptr;
// Check the constant we'd have with flipped-strictness predicate.
- auto FlippedStrictness = getFlippedStrictnessPredicateAndConstant(Pred, C0);
+ auto FlippedStrictness =
+ InstCombiner::getFlippedStrictnessPredicateAndConstant(Pred, C0);
if (!FlippedStrictness)
return nullptr;
@@ -1426,10 +1407,10 @@ tryToReuseConstantFromSelectInComparison(SelectInst &Sel, ICmpInst &Cmp,
}
/// Visit a SelectInst that has an ICmpInst as its first operand.
-Instruction *InstCombiner::foldSelectInstWithICmp(SelectInst &SI,
- ICmpInst *ICI) {
- if (Value *V = foldSelectValueEquivalence(SI, *ICI, SQ))
- return replaceInstUsesWith(SI, V);
+Instruction *InstCombinerImpl::foldSelectInstWithICmp(SelectInst &SI,
+ ICmpInst *ICI) {
+ if (Instruction *NewSel = foldSelectValueEquivalence(SI, *ICI))
+ return NewSel;
if (Instruction *NewSel = canonicalizeMinMaxWithConstant(SI, *ICI, *this))
return NewSel;
@@ -1579,11 +1560,11 @@ static bool canSelectOperandBeMappingIntoPredBlock(const Value *V,
/// We have an SPF (e.g. a min or max) of an SPF of the form:
/// SPF2(SPF1(A, B), C)
-Instruction *InstCombiner::foldSPFofSPF(Instruction *Inner,
- SelectPatternFlavor SPF1,
- Value *A, Value *B,
- Instruction &Outer,
- SelectPatternFlavor SPF2, Value *C) {
+Instruction *InstCombinerImpl::foldSPFofSPF(Instruction *Inner,
+ SelectPatternFlavor SPF1, Value *A,
+ Value *B, Instruction &Outer,
+ SelectPatternFlavor SPF2,
+ Value *C) {
if (Outer.getType() != Inner->getType())
return nullptr;
@@ -1900,7 +1881,7 @@ foldOverflowingAddSubSelect(SelectInst &SI, InstCombiner::BuilderTy &Builder) {
return CallInst::Create(F, {X, Y});
}
-Instruction *InstCombiner::foldSelectExtConst(SelectInst &Sel) {
+Instruction *InstCombinerImpl::foldSelectExtConst(SelectInst &Sel) {
Constant *C;
if (!match(Sel.getTrueValue(), m_Constant(C)) &&
!match(Sel.getFalseValue(), m_Constant(C)))
@@ -1966,10 +1947,11 @@ Instruction *InstCombiner::foldSelectExtConst(SelectInst &Sel) {
static Instruction *canonicalizeSelectToShuffle(SelectInst &SI) {
Value *CondVal = SI.getCondition();
Constant *CondC;
- if (!CondVal->getType()->isVectorTy() || !match(CondVal, m_Constant(CondC)))
+ auto *CondValTy = dyn_cast<FixedVectorType>(CondVal->getType());
+ if (!CondValTy || !match(CondVal, m_Constant(CondC)))
return nullptr;
- unsigned NumElts = cast<VectorType>(CondVal->getType())->getNumElements();
+ unsigned NumElts = CondValTy->getNumElements();
SmallVector<int, 16> Mask;
Mask.reserve(NumElts);
for (unsigned i = 0; i != NumElts; ++i) {
@@ -2001,8 +1983,8 @@ static Instruction *canonicalizeSelectToShuffle(SelectInst &SI) {
/// to a vector select by splatting the condition. A splat may get folded with
/// other operations in IR and having all operands of a select be vector types
/// is likely better for vector codegen.
-static Instruction *canonicalizeScalarSelectOfVecs(
- SelectInst &Sel, InstCombiner &IC) {
+static Instruction *canonicalizeScalarSelectOfVecs(SelectInst &Sel,
+ InstCombinerImpl &IC) {
auto *Ty = dyn_cast<VectorType>(Sel.getType());
if (!Ty)
return nullptr;
@@ -2015,8 +1997,8 @@ static Instruction *canonicalizeScalarSelectOfVecs(
// select (extelt V, Index), T, F --> select (splat V, Index), T, F
// Splatting the extracted condition reduces code (we could directly create a
// splat shuffle of the source vector to eliminate the intermediate step).
- unsigned NumElts = Ty->getNumElements();
- return IC.replaceOperand(Sel, 0, IC.Builder.CreateVectorSplat(NumElts, Cond));
+ return IC.replaceOperand(
+ Sel, 0, IC.Builder.CreateVectorSplat(Ty->getElementCount(), Cond));
}
/// Reuse bitcasted operands between a compare and select:
@@ -2172,7 +2154,7 @@ static Instruction *moveAddAfterMinMax(SelectPatternFlavor SPF, Value *X,
}
/// Match a sadd_sat or ssub_sat which is using min/max to clamp the value.
-Instruction *InstCombiner::matchSAddSubSat(SelectInst &MinMax1) {
+Instruction *InstCombinerImpl::matchSAddSubSat(SelectInst &MinMax1) {
Type *Ty = MinMax1.getType();
// We are looking for a tree of:
@@ -2293,34 +2275,42 @@ static Instruction *factorizeMinMaxTree(SelectPatternFlavor SPF, Value *LHS,
return SelectInst::Create(CmpABC, MinMaxOp, ThirdOp);
}
-/// Try to reduce a rotate pattern that includes a compare and select into a
-/// funnel shift intrinsic. Example:
+/// Try to reduce a funnel/rotate pattern that includes a compare and select
+/// into a funnel shift intrinsic. Example:
/// rotl32(a, b) --> (b == 0 ? a : ((a >> (32 - b)) | (a << b)))
/// --> call llvm.fshl.i32(a, a, b)
-static Instruction *foldSelectRotate(SelectInst &Sel) {
- // The false value of the select must be a rotate of the true value.
- Value *Or0, *Or1;
- if (!match(Sel.getFalseValue(), m_OneUse(m_Or(m_Value(Or0), m_Value(Or1)))))
+/// fshl32(a, b, c) --> (c == 0 ? a : ((b >> (32 - c)) | (a << c)))
+/// --> call llvm.fshl.i32(a, b, c)
+/// fshr32(a, b, c) --> (c == 0 ? b : ((a >> (32 - c)) | (b << c)))
+/// --> call llvm.fshr.i32(a, b, c)
+static Instruction *foldSelectFunnelShift(SelectInst &Sel,
+ InstCombiner::BuilderTy &Builder) {
+ // This must be a power-of-2 type for a bitmasking transform to be valid.
+ unsigned Width = Sel.getType()->getScalarSizeInBits();
+ if (!isPowerOf2_32(Width))
return nullptr;
- Value *TVal = Sel.getTrueValue();
- Value *SA0, *SA1;
- if (!match(Or0, m_OneUse(m_LogicalShift(m_Specific(TVal), m_Value(SA0)))) ||
- !match(Or1, m_OneUse(m_LogicalShift(m_Specific(TVal), m_Value(SA1)))))
+ BinaryOperator *Or0, *Or1;
+ if (!match(Sel.getFalseValue(), m_OneUse(m_Or(m_BinOp(Or0), m_BinOp(Or1)))))
return nullptr;
- auto ShiftOpcode0 = cast<BinaryOperator>(Or0)->getOpcode();
- auto ShiftOpcode1 = cast<BinaryOperator>(Or1)->getOpcode();
- if (ShiftOpcode0 == ShiftOpcode1)
+ Value *SV0, *SV1, *SA0, *SA1;
+ if (!match(Or0, m_OneUse(m_LogicalShift(m_Value(SV0),
+ m_ZExtOrSelf(m_Value(SA0))))) ||
+ !match(Or1, m_OneUse(m_LogicalShift(m_Value(SV1),
+ m_ZExtOrSelf(m_Value(SA1))))) ||
+ Or0->getOpcode() == Or1->getOpcode())
return nullptr;
- // We have one of these patterns so far:
- // select ?, TVal, (or (lshr TVal, SA0), (shl TVal, SA1))
- // select ?, TVal, (or (shl TVal, SA0), (lshr TVal, SA1))
- // This must be a power-of-2 rotate for a bitmasking transform to be valid.
- unsigned Width = Sel.getType()->getScalarSizeInBits();
- if (!isPowerOf2_32(Width))
- return nullptr;
+ // Canonicalize to or(shl(SV0, SA0), lshr(SV1, SA1)).
+ if (Or0->getOpcode() == BinaryOperator::LShr) {
+ std::swap(Or0, Or1);
+ std::swap(SV0, SV1);
+ std::swap(SA0, SA1);
+ }
+ assert(Or0->getOpcode() == BinaryOperator::Shl &&
+ Or1->getOpcode() == BinaryOperator::LShr &&
+ "Illegal or(shift,shift) pair");
// Check the shift amounts to see if they are an opposite pair.
Value *ShAmt;
@@ -2331,6 +2321,15 @@ static Instruction *foldSelectRotate(SelectInst &Sel) {
else
return nullptr;
+ // We should now have this pattern:
+ // select ?, TVal, (or (shl SV0, SA0), (lshr SV1, SA1))
+ // The false value of the select must be a funnel-shift of the true value:
+ // IsFShl -> TVal must be SV0 else TVal must be SV1.
+ bool IsFshl = (ShAmt == SA0);
+ Value *TVal = Sel.getTrueValue();
+ if ((IsFshl && TVal != SV0) || (!IsFshl && TVal != SV1))
+ return nullptr;
+
// Finally, see if the select is filtering out a shift-by-zero.
Value *Cond = Sel.getCondition();
ICmpInst::Predicate Pred;
@@ -2338,13 +2337,21 @@ static Instruction *foldSelectRotate(SelectInst &Sel) {
Pred != ICmpInst::ICMP_EQ)
return nullptr;
- // This is a rotate that avoids shift-by-bitwidth UB in a suboptimal way.
+ // If this is not a rotate then the select was blocking poison from the
+ // 'shift-by-zero' non-TVal, but a funnel shift won't - so freeze it.
+ if (SV0 != SV1) {
+ if (IsFshl && !llvm::isGuaranteedNotToBePoison(SV1))
+ SV1 = Builder.CreateFreeze(SV1);
+ else if (!IsFshl && !llvm::isGuaranteedNotToBePoison(SV0))
+ SV0 = Builder.CreateFreeze(SV0);
+ }
+
+ // This is a funnel/rotate that avoids shift-by-bitwidth UB in a suboptimal way.
// Convert to funnel shift intrinsic.
- bool IsFshl = (ShAmt == SA0 && ShiftOpcode0 == BinaryOperator::Shl) ||
- (ShAmt == SA1 && ShiftOpcode1 == BinaryOperator::Shl);
Intrinsic::ID IID = IsFshl ? Intrinsic::fshl : Intrinsic::fshr;
Function *F = Intrinsic::getDeclaration(Sel.getModule(), IID, Sel.getType());
- return IntrinsicInst::Create(F, { TVal, TVal, ShAmt });
+ ShAmt = Builder.CreateZExt(ShAmt, Sel.getType());
+ return IntrinsicInst::Create(F, { SV0, SV1, ShAmt });
}
static Instruction *foldSelectToCopysign(SelectInst &Sel,
@@ -2368,7 +2375,8 @@ static Instruction *foldSelectToCopysign(SelectInst &Sel,
bool IsTrueIfSignSet;
ICmpInst::Predicate Pred;
if (!match(Cond, m_OneUse(m_ICmp(Pred, m_BitCast(m_Value(X)), m_APInt(C)))) ||
- !isSignBitCheck(Pred, *C, IsTrueIfSignSet) || X->getType() != SelType)
+ !InstCombiner::isSignBitCheck(Pred, *C, IsTrueIfSignSet) ||
+ X->getType() != SelType)
return nullptr;
// If needed, negate the value that will be the sign argument of the copysign:
@@ -2389,7 +2397,7 @@ static Instruction *foldSelectToCopysign(SelectInst &Sel,
return CopySign;
}
-Instruction *InstCombiner::foldVectorSelect(SelectInst &Sel) {
+Instruction *InstCombinerImpl::foldVectorSelect(SelectInst &Sel) {
auto *VecTy = dyn_cast<FixedVectorType>(Sel.getType());
if (!VecTy)
return nullptr;
@@ -2469,6 +2477,10 @@ static Instruction *foldSelectToPhiImpl(SelectInst &Sel, BasicBlock *BB,
} else
return nullptr;
+ // Make sure the branches are actually different.
+ if (TrueSucc == FalseSucc)
+ return nullptr;
+
// We want to replace select %cond, %a, %b with a phi that takes value %a
// for all incoming edges that are dominated by condition `%cond == true`,
// and value %b for edges dominated by condition `%cond == false`. If %a
@@ -2515,7 +2527,33 @@ static Instruction *foldSelectToPhi(SelectInst &Sel, const DominatorTree &DT,
return nullptr;
}
-Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {
+/// Fold a select whose condition is a single-use frozen equality compare
+/// of the select's own arms (Builder is currently unused):
+///   select (freeze (icmp eq X, Y)), X, Y --> Y
+///   select (freeze (icmp ne X, Y)), X, Y --> X
+static Value *foldSelectWithFrozenICmp(SelectInst &Sel,
+                                       InstCombiner::BuilderTy &Builder) {
+  // Any other user of the freeze could observe a value that contradicts the
+  // folded result. E.g. with x = 42 and y = poison:
+  //   c = freeze(x == y)   ; c is 0 or 1
+  //   a = select c, x, y
+  //   f(a, c)              ; f(poison, 1) cannot happen, but becomes possible
+  //                        ; once a is folded to y.
+  auto *Frozen = dyn_cast<FreezeInst>(Sel.getCondition());
+  if (!Frozen || !Frozen->hasOneUse())
+    return nullptr;
+
+  Value *TVal = Sel.getTrueValue(), *FVal = Sel.getFalseValue();
+  CmpInst::Predicate Pred;
+  if (!match(Frozen->getOperand(0),
+             m_c_ICmp(Pred, m_Specific(TVal), m_Specific(FVal))))
+    return nullptr;
+  if (!ICmpInst::isEquality(Pred))
+    return nullptr;
+  return Pred == ICmpInst::ICMP_EQ ? FVal : TVal;
+}
+
+Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
Value *CondVal = SI.getCondition();
Value *TrueVal = SI.getTrueValue();
Value *FalseVal = SI.getFalseValue();
@@ -2551,38 +2589,45 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {
if (SelType->isIntOrIntVectorTy(1) &&
TrueVal->getType() == CondVal->getType()) {
- if (match(TrueVal, m_One())) {
+ if (match(TrueVal, m_One()) &&
+ (EnableUnsafeSelectTransform || impliesPoison(FalseVal, CondVal))) {
// Change: A = select B, true, C --> A = or B, C
return BinaryOperator::CreateOr(CondVal, FalseVal);
}
- if (match(TrueVal, m_Zero())) {
- // Change: A = select B, false, C --> A = and !B, C
- Value *NotCond = Builder.CreateNot(CondVal, "not." + CondVal->getName());
- return BinaryOperator::CreateAnd(NotCond, FalseVal);
- }
- if (match(FalseVal, m_Zero())) {
+ if (match(FalseVal, m_Zero()) &&
+ (EnableUnsafeSelectTransform || impliesPoison(TrueVal, CondVal))) {
// Change: A = select B, C, false --> A = and B, C
return BinaryOperator::CreateAnd(CondVal, TrueVal);
}
+
+ // select a, false, b -> select !a, b, false
+ if (match(TrueVal, m_Zero())) {
+ Value *NotCond = Builder.CreateNot(CondVal, "not." + CondVal->getName());
+ return SelectInst::Create(NotCond, FalseVal,
+ ConstantInt::getFalse(SelType));
+ }
+ // select a, b, true -> select !a, true, b
if (match(FalseVal, m_One())) {
- // Change: A = select B, C, true --> A = or !B, C
Value *NotCond = Builder.CreateNot(CondVal, "not." + CondVal->getName());
- return BinaryOperator::CreateOr(NotCond, TrueVal);
+ return SelectInst::Create(NotCond, ConstantInt::getTrue(SelType),
+ TrueVal);
}
- // select a, a, b -> a | b
- // select a, b, a -> a & b
+ // select a, a, b -> select a, true, b
if (CondVal == TrueVal)
- return BinaryOperator::CreateOr(CondVal, FalseVal);
+ return replaceOperand(SI, 1, ConstantInt::getTrue(SelType));
+ // select a, b, a -> select a, b, false
if (CondVal == FalseVal)
- return BinaryOperator::CreateAnd(CondVal, TrueVal);
+ return replaceOperand(SI, 2, ConstantInt::getFalse(SelType));
- // select a, ~a, b -> (~a) & b
- // select a, b, ~a -> (~a) | b
+ // select a, !a, b -> select !a, b, false
if (match(TrueVal, m_Not(m_Specific(CondVal))))
- return BinaryOperator::CreateAnd(TrueVal, FalseVal);
+ return SelectInst::Create(TrueVal, FalseVal,
+ ConstantInt::getFalse(SelType));
+ // select a, b, !a -> select !a, true, b
if (match(FalseVal, m_Not(m_Specific(CondVal))))
- return BinaryOperator::CreateOr(TrueVal, FalseVal);
+ return SelectInst::Create(FalseVal, ConstantInt::getTrue(SelType),
+ TrueVal);
}
// Selecting between two integer or vector splat integer constants?
@@ -2591,7 +2636,10 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {
// select i1 %c, <2 x i8> <1, 1>, <2 x i8> <0, 0>
// because that may need 3 instructions to splat the condition value:
// extend, insertelement, shufflevector.
- if (SelType->isIntOrIntVectorTy() &&
+ //
+ // Do not handle i1 TrueVal and FalseVal otherwise would result in
+ // zext/sext i1 to i1.
+ if (SelType->isIntOrIntVectorTy() && !SelType->isIntOrIntVectorTy(1) &&
CondVal->getType()->isVectorTy() == SelType->isVectorTy()) {
// select C, 1, 0 -> zext C to int
if (match(TrueVal, m_One()) && match(FalseVal, m_Zero()))
@@ -2838,8 +2886,9 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {
return replaceOperand(SI, 1, TrueSI->getTrueValue());
}
// select(C0, select(C1, a, b), b) -> select(C0&C1, a, b)
- // We choose this as normal form to enable folding on the And and shortening
- // paths for the values (this helps GetUnderlyingObjects() for example).
+ // We choose this as normal form to enable folding on the And and
+ // shortening paths for the values (this helps getUnderlyingObjects() for
+ // example).
if (TrueSI->getFalseValue() == FalseVal && TrueSI->hasOneUse()) {
Value *And = Builder.CreateAnd(CondVal, TrueSI->getCondition());
replaceOperand(SI, 0, And);
@@ -2922,7 +2971,8 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {
}
Value *NotCond;
- if (match(CondVal, m_Not(m_Value(NotCond)))) {
+ if (match(CondVal, m_Not(m_Value(NotCond))) &&
+ !InstCombiner::shouldAvoidAbsorbingNotIntoSelect(SI)) {
replaceOperand(SI, 0, NotCond);
SI.swapValues();
SI.swapProfMetadata();
@@ -2956,8 +3006,8 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {
if (Instruction *Select = foldSelectBinOpIdentity(SI, TLI, *this))
return Select;
- if (Instruction *Rot = foldSelectRotate(SI))
- return Rot;
+ if (Instruction *Funnel = foldSelectFunnelShift(SI, Builder))
+ return Funnel;
if (Instruction *Copysign = foldSelectToCopysign(SI, Builder))
return Copysign;
@@ -2965,5 +3015,8 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {
if (Instruction *PN = foldSelectToPhi(SI, DT, Builder))
return replaceInstUsesWith(SI, PN);
+ if (Value *Fr = foldSelectWithFrozenICmp(SI, Builder))
+ return replaceInstUsesWith(SI, Fr);
+
return nullptr;
}