aboutsummaryrefslogtreecommitdiff
path: root/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp')
-rw-r--r-- contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp 555
1 files changed, 411 insertions, 144 deletions
diff --git a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index 9f220ec003ec..aaf4ece3249a 100644
--- a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -99,7 +99,8 @@ static Instruction *foldSelectBinOpIdentity(SelectInst &Sel,
// transform. Bail out if we can not exclude that possibility.
if (isa<FPMathOperator>(BO))
if (!BO->hasNoSignedZeros() &&
- !cannotBeNegativeZero(Y, IC.getDataLayout(), &TLI))
+ !cannotBeNegativeZero(Y, 0,
+ IC.getSimplifyQuery().getWithInstruction(&Sel)))
return nullptr;
// BO = binop Y, X
@@ -201,6 +202,14 @@ static Value *foldSelectICmpAnd(SelectInst &Sel, ICmpInst *Cmp,
const APInt &ValC = !TC.isZero() ? TC : FC;
unsigned ValZeros = ValC.logBase2();
unsigned AndZeros = AndMask.logBase2();
+ bool ShouldNotVal = !TC.isZero();
+ ShouldNotVal ^= Pred == ICmpInst::ICMP_NE;
+
+ // If we would need to create an 'and' + 'shift' + 'xor' to replace a 'select'
+ // + 'icmp', then this transformation would result in more instructions and
+ // potentially interfere with other folding.
+ if (CreateAnd && ShouldNotVal && ValZeros != AndZeros)
+ return nullptr;
// Insert the 'and' instruction on the input to the truncate.
if (CreateAnd)
@@ -220,8 +229,6 @@ static Value *foldSelectICmpAnd(SelectInst &Sel, ICmpInst *Cmp,
// Okay, now we know that everything is set up, we just don't know whether we
// have a icmp_ne or icmp_eq and whether the true or false val is the zero.
- bool ShouldNotVal = !TC.isZero();
- ShouldNotVal ^= Pred == ICmpInst::ICMP_NE;
if (ShouldNotVal)
V = Builder.CreateXor(V, ValC);
@@ -484,10 +491,9 @@ Instruction *InstCombinerImpl::foldSelectOpOp(SelectInst &SI, Instruction *TI,
}
if (auto *TGEP = dyn_cast<GetElementPtrInst>(TI)) {
auto *FGEP = cast<GetElementPtrInst>(FI);
- Type *ElementType = TGEP->getResultElementType();
- return TGEP->isInBounds() && FGEP->isInBounds()
- ? GetElementPtrInst::CreateInBounds(ElementType, Op0, {Op1})
- : GetElementPtrInst::Create(ElementType, Op0, {Op1});
+ Type *ElementType = TGEP->getSourceElementType();
+ return GetElementPtrInst::Create(
+ ElementType, Op0, Op1, TGEP->getNoWrapFlags() & FGEP->getNoWrapFlags());
}
llvm_unreachable("Expected BinaryOperator or GEP");
return nullptr;
@@ -535,19 +541,29 @@ Instruction *InstCombinerImpl::foldSelectIntoOp(SelectInst &SI, Value *TrueVal,
// between 0, 1 and -1.
const APInt *OOpC;
bool OOpIsAPInt = match(OOp, m_APInt(OOpC));
- if (!isa<Constant>(OOp) ||
- (OOpIsAPInt && isSelect01(C->getUniqueInteger(), *OOpC))) {
- Value *NewSel = Builder.CreateSelect(SI.getCondition(), Swapped ? C : OOp,
- Swapped ? OOp : C, "", &SI);
- if (isa<FPMathOperator>(&SI))
- cast<Instruction>(NewSel)->setFastMathFlags(FMF);
- NewSel->takeName(TVI);
- BinaryOperator *BO =
- BinaryOperator::Create(TVI->getOpcode(), FalseVal, NewSel);
- BO->copyIRFlags(TVI);
- return BO;
- }
- return nullptr;
+ if (isa<Constant>(OOp) &&
+ (!OOpIsAPInt || !isSelect01(C->getUniqueInteger(), *OOpC)))
+ return nullptr;
+
+ // If the false value is a NaN, the floating-point math operation in the
+ // transformed code may not preserve the exact NaN bit-pattern -- e.g.
+ // `fadd sNaN, 0.0 -> qNaN`.
+ // This makes the transformation incorrect, since the original program would
+ // have preserved the exact NaN bit-pattern.
+ // Avoid the folding if the false value might be a NaN.
+ if (isa<FPMathOperator>(&SI) &&
+ !computeKnownFPClass(FalseVal, FMF, fcNan, &SI).isKnownNeverNaN())
+ return nullptr;
+
+ Value *NewSel = Builder.CreateSelect(SI.getCondition(), Swapped ? C : OOp,
+ Swapped ? OOp : C, "", &SI);
+ if (isa<FPMathOperator>(&SI))
+ cast<Instruction>(NewSel)->setFastMathFlags(FMF);
+ NewSel->takeName(TVI);
+ BinaryOperator *BO =
+ BinaryOperator::Create(TVI->getOpcode(), FalseVal, NewSel);
+ BO->copyIRFlags(TVI);
+ return BO;
};
if (Instruction *R = TryFoldSelectIntoOp(SI, TrueVal, FalseVal, false))
@@ -1116,7 +1132,7 @@ static Instruction *foldSelectCtlzToCttz(ICmpInst *ICI, Value *TrueVal,
/// into:
/// %0 = tail call i32 @llvm.cttz.i32(i32 %x, i1 false)
static Value *foldSelectCttzCtlz(ICmpInst *ICI, Value *TrueVal, Value *FalseVal,
- InstCombiner::BuilderTy &Builder) {
+ InstCombinerImpl &IC) {
ICmpInst::Predicate Pred = ICI->getPredicate();
Value *CmpLHS = ICI->getOperand(0);
Value *CmpRHS = ICI->getOperand(1);
@@ -1158,6 +1174,9 @@ static Value *foldSelectCttzCtlz(ICmpInst *ICI, Value *TrueVal, Value *FalseVal,
// Explicitly clear the 'is_zero_poison' flag. It's always valid to go from
// true to false on this flag, so we can replace it for all users.
II->setArgOperand(1, ConstantInt::getFalse(II->getContext()));
+ // A range annotation on the intrinsic may no longer be valid.
+ II->dropPoisonGeneratingAnnotations();
+ IC.addToWorklist(II);
return SelectArg;
}
@@ -1190,7 +1209,7 @@ static Value *canonicalizeSPF(ICmpInst &Cmp, Value *TrueVal, Value *FalseVal,
match(RHS, m_NSWNeg(m_Specific(LHS)));
Constant *IntMinIsPoisonC =
ConstantInt::get(Type::getInt1Ty(Cmp.getContext()), IntMinIsPoison);
- Instruction *Abs =
+ Value *Abs =
IC.Builder.CreateBinaryIntrinsic(Intrinsic::abs, LHS, IntMinIsPoisonC);
if (SPF == SelectPatternFlavor::SPF_NABS)
@@ -1228,8 +1247,11 @@ bool InstCombinerImpl::replaceInInstruction(Value *V, Value *Old, Value *New,
if (Depth == 2)
return false;
+ assert(!isa<Constant>(Old) && "Only replace non-constant values");
+
auto *I = dyn_cast<Instruction>(V);
- if (!I || !I->hasOneUse() || !isSafeToSpeculativelyExecute(I))
+ if (!I || !I->hasOneUse() ||
+ !isSafeToSpeculativelyExecuteWithVariableReplaced(I))
return false;
bool Changed = false;
@@ -1274,22 +1296,36 @@ Instruction *InstCombinerImpl::foldSelectValueEquivalence(SelectInst &Sel,
Swapped = true;
}
- // In X == Y ? f(X) : Z, try to evaluate f(Y) and replace the operand.
- // Make sure Y cannot be undef though, as we might pick different values for
- // undef in the icmp and in f(Y). Additionally, take care to avoid replacing
- // X == Y ? X : Z with X == Y ? Y : Z, as that would lead to an infinite
- // replacement cycle.
Value *CmpLHS = Cmp.getOperand(0), *CmpRHS = Cmp.getOperand(1);
- if (TrueVal != CmpLHS &&
- isGuaranteedNotToBeUndefOrPoison(CmpRHS, SQ.AC, &Sel, &DT)) {
- if (Value *V = simplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, SQ,
- /* AllowRefinement */ true))
- // Require either the replacement or the simplification result to be a
- // constant to avoid infinite loops.
- // FIXME: Make this check more precise.
- if (isa<Constant>(CmpRHS) || isa<Constant>(V))
+ auto ReplaceOldOpWithNewOp = [&](Value *OldOp,
+ Value *NewOp) -> Instruction * {
+ // In X == Y ? f(X) : Z, try to evaluate f(Y) and replace the operand.
+ // Take care to avoid replacing X == Y ? X : Z with X == Y ? Y : Z, as that
+ // would lead to an infinite replacement cycle.
+ // If we will be able to evaluate f(Y) to a constant, we can allow undef,
+ // otherwise Y cannot be undef as we might pick different values for undef
+ // in the icmp and in f(Y).
+ if (TrueVal == OldOp)
+ return nullptr;
+
+ if (Value *V = simplifyWithOpReplaced(TrueVal, OldOp, NewOp, SQ,
+ /* AllowRefinement=*/true)) {
+ // Need some guarantees about the new simplified op to ensure we don't
+ // loop infinitely.
+ // If we simplify to a constant, replace if we aren't creating new undef.
+ if (match(V, m_ImmConstant()) &&
+ isGuaranteedNotToBeUndef(V, SQ.AC, &Sel, &DT))
return replaceOperand(Sel, Swapped ? 2 : 1, V);
+ // If NewOp is a constant and OldOp is not, replace iff NewOp doesn't
+ // contain any undef elements.
+ if (match(NewOp, m_ImmConstant()) || NewOp == V) {
+ if (isGuaranteedNotToBeUndef(NewOp, SQ.AC, &Sel, &DT))
+ return replaceOperand(Sel, Swapped ? 2 : 1, V);
+ return nullptr;
+ }
+ }
+
// Even if TrueVal does not simplify, we can directly replace a use of
// CmpLHS with CmpRHS, as long as the instruction is not used anywhere
// else and is safe to speculatively execute (we may end up executing it
@@ -1297,17 +1333,18 @@ Instruction *InstCombinerImpl::foldSelectValueEquivalence(SelectInst &Sel,
// undefined behavior). Only do this if CmpRHS is a constant, as
// profitability is not clear for other cases.
// FIXME: Support vectors.
- if (match(CmpRHS, m_ImmConstant()) && !match(CmpLHS, m_ImmConstant()) &&
- !Cmp.getType()->isVectorTy())
- if (replaceInInstruction(TrueVal, CmpLHS, CmpRHS))
+ if (OldOp == CmpLHS && match(NewOp, m_ImmConstant()) &&
+ !match(OldOp, m_Constant()) && !Cmp.getType()->isVectorTy() &&
+ isGuaranteedNotToBeUndef(NewOp, SQ.AC, &Sel, &DT))
+ if (replaceInInstruction(TrueVal, OldOp, NewOp))
return &Sel;
- }
- if (TrueVal != CmpRHS &&
- isGuaranteedNotToBeUndefOrPoison(CmpLHS, SQ.AC, &Sel, &DT))
- if (Value *V = simplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, SQ,
- /* AllowRefinement */ true))
- if (isa<Constant>(CmpLHS) || isa<Constant>(V))
- return replaceOperand(Sel, Swapped ? 2 : 1, V);
+ return nullptr;
+ };
+
+ if (Instruction *R = ReplaceOldOpWithNewOp(CmpLHS, CmpRHS))
+ return R;
+ if (Instruction *R = ReplaceOldOpWithNewOp(CmpRHS, CmpLHS))
+ return R;
auto *FalseInst = dyn_cast<Instruction>(FalseVal);
if (!FalseInst)
@@ -1329,7 +1366,7 @@ Instruction *InstCombinerImpl::foldSelectValueEquivalence(SelectInst &Sel,
/* AllowRefinement */ false,
&DropFlags) == TrueVal) {
for (Instruction *I : DropFlags) {
- I->dropPoisonGeneratingFlagsAndMetadata();
+ I->dropPoisonGeneratingAnnotations();
Worklist.add(I);
}
@@ -1354,7 +1391,8 @@ Instruction *InstCombinerImpl::foldSelectValueEquivalence(SelectInst &Sel,
// Also ULT predicate can also be UGT iff C0 != -1 (+invert result)
// SLT predicate can also be SGT iff C2 != INT_MAX (+invert res.)
static Value *canonicalizeClampLike(SelectInst &Sel0, ICmpInst &Cmp0,
- InstCombiner::BuilderTy &Builder) {
+ InstCombiner::BuilderTy &Builder,
+ InstCombiner &IC) {
Value *X = Sel0.getTrueValue();
Value *Sel1 = Sel0.getFalseValue();
@@ -1482,14 +1520,14 @@ static Value *canonicalizeClampLike(SelectInst &Sel0, ICmpInst &Cmp0,
std::swap(ThresholdLowIncl, ThresholdHighExcl);
// The fold has a precondition 1: C2 s>= ThresholdLow
- auto *Precond1 = ConstantExpr::getICmp(ICmpInst::Predicate::ICMP_SGE, C2,
- ThresholdLowIncl);
- if (!match(Precond1, m_One()))
+ auto *Precond1 = ConstantFoldCompareInstOperands(
+ ICmpInst::Predicate::ICMP_SGE, C2, ThresholdLowIncl, IC.getDataLayout());
+ if (!Precond1 || !match(Precond1, m_One()))
return nullptr;
// The fold has a precondition 2: C2 s<= ThresholdHigh
- auto *Precond2 = ConstantExpr::getICmp(ICmpInst::Predicate::ICMP_SLE, C2,
- ThresholdHighExcl);
- if (!match(Precond2, m_One()))
+ auto *Precond2 = ConstantFoldCompareInstOperands(
+ ICmpInst::Predicate::ICMP_SLE, C2, ThresholdHighExcl, IC.getDataLayout());
+ if (!Precond2 || !match(Precond2, m_One()))
return nullptr;
// If we are matching from a truncated input, we need to sext the
@@ -1500,7 +1538,7 @@ static Value *canonicalizeClampLike(SelectInst &Sel0, ICmpInst &Cmp0,
if (!match(ReplacementLow, m_ImmConstant(LowC)) ||
!match(ReplacementHigh, m_ImmConstant(HighC)))
return nullptr;
- const DataLayout &DL = Sel0.getModule()->getDataLayout();
+ const DataLayout &DL = Sel0.getDataLayout();
ReplacementLow =
ConstantFoldCastOperand(Instruction::SExt, LowC, X->getType(), DL);
ReplacementHigh =
@@ -1610,7 +1648,7 @@ static Instruction *foldSelectZeroOrOnes(ICmpInst *Cmp, Value *TVal,
return nullptr;
const APInt *CmpC;
- if (!match(Cmp->getOperand(1), m_APIntAllowUndef(CmpC)))
+ if (!match(Cmp->getOperand(1), m_APIntAllowPoison(CmpC)))
return nullptr;
// (X u< 2) ? -X : -1 --> sext (X != 0)
@@ -1676,6 +1714,109 @@ static Value *foldSelectInstWithICmpConst(SelectInst &SI, ICmpInst *ICI,
return nullptr;
}
+static Instruction *foldSelectICmpEq(SelectInst &SI, ICmpInst *ICI,
+ InstCombinerImpl &IC) {
+ ICmpInst::Predicate Pred = ICI->getPredicate();
+ if (!ICmpInst::isEquality(Pred))
+ return nullptr;
+
+ Value *TrueVal = SI.getTrueValue();
+ Value *FalseVal = SI.getFalseValue();
+ Value *CmpLHS = ICI->getOperand(0);
+ Value *CmpRHS = ICI->getOperand(1);
+
+ if (Pred == ICmpInst::ICMP_NE)
+ std::swap(TrueVal, FalseVal);
+
+ // Transform (X == C) ? X : Y -> (X == C) ? C : Y,
+ // with specific handling for bitwise operations:
+ // x&y -> (x|y) ^ (x^y) or (x|y) & ~(x^y)
+ // x|y -> (x&y) | (x^y) or (x&y) ^ (x^y)
+ // x^y -> (x|y) ^ (x&y) or (x|y) & ~(x&y)
+ Value *X, *Y;
+ if (!match(CmpLHS, m_BitwiseLogic(m_Value(X), m_Value(Y))) ||
+ !match(TrueVal, m_c_BitwiseLogic(m_Specific(X), m_Specific(Y))))
+ return nullptr;
+
+ const unsigned AndOps = Instruction::And, OrOps = Instruction::Or,
+ XorOps = Instruction::Xor, NoOps = 0;
+ enum NotMask { None = 0, NotInner, NotRHS };
+
+ auto matchFalseVal = [&](unsigned OuterOpc, unsigned InnerOpc,
+ unsigned NotMask) {
+ auto matchInner = m_c_BinOp(InnerOpc, m_Specific(X), m_Specific(Y));
+ if (OuterOpc == NoOps)
+ return match(CmpRHS, m_Zero()) && match(FalseVal, matchInner);
+
+ if (NotMask == NotInner) {
+ return match(FalseVal, m_c_BinOp(OuterOpc, m_NotForbidPoison(matchInner),
+ m_Specific(CmpRHS)));
+ } else if (NotMask == NotRHS) {
+ return match(FalseVal, m_c_BinOp(OuterOpc, matchInner,
+ m_NotForbidPoison(m_Specific(CmpRHS))));
+ } else {
+ return match(FalseVal,
+ m_c_BinOp(OuterOpc, matchInner, m_Specific(CmpRHS)));
+ }
+ };
+
+ // (X&Y)==C ? X|Y : X^Y -> (X^Y)|C : X^Y or (X^Y)^ C : X^Y
+ // (X&Y)==C ? X^Y : X|Y -> (X|Y)^C : X|Y or (X|Y)&~C : X|Y
+ if (match(CmpLHS, m_And(m_Value(X), m_Value(Y)))) {
+ if (match(TrueVal, m_c_Or(m_Specific(X), m_Specific(Y)))) {
+ // (X&Y)==C ? X|Y : (X^Y)|C -> (X^Y)|C : (X^Y)|C -> (X^Y)|C
+ // (X&Y)==C ? X|Y : (X^Y)^C -> (X^Y)^C : (X^Y)^C -> (X^Y)^C
+ if (matchFalseVal(OrOps, XorOps, None) ||
+ matchFalseVal(XorOps, XorOps, None))
+ return IC.replaceInstUsesWith(SI, FalseVal);
+ } else if (match(TrueVal, m_c_Xor(m_Specific(X), m_Specific(Y)))) {
+ // (X&Y)==C ? X^Y : (X|Y)^ C -> (X|Y)^ C : (X|Y)^ C -> (X|Y)^ C
+ // (X&Y)==C ? X^Y : (X|Y)&~C -> (X|Y)&~C : (X|Y)&~C -> (X|Y)&~C
+ if (matchFalseVal(XorOps, OrOps, None) ||
+ matchFalseVal(AndOps, OrOps, NotRHS))
+ return IC.replaceInstUsesWith(SI, FalseVal);
+ }
+ }
+
+ // (X|Y)==C ? X&Y : X^Y -> (X^Y)^C : X^Y or ~(X^Y)&C : X^Y
+ // (X|Y)==C ? X^Y : X&Y -> (X&Y)^C : X&Y or ~(X&Y)&C : X&Y
+ if (match(CmpLHS, m_Or(m_Value(X), m_Value(Y)))) {
+ if (match(TrueVal, m_c_And(m_Specific(X), m_Specific(Y)))) {
+ // (X|Y)==C ? X&Y: (X^Y)^C -> (X^Y)^C: (X^Y)^C -> (X^Y)^C
+ // (X|Y)==C ? X&Y:~(X^Y)&C ->~(X^Y)&C:~(X^Y)&C -> ~(X^Y)&C
+ if (matchFalseVal(XorOps, XorOps, None) ||
+ matchFalseVal(AndOps, XorOps, NotInner))
+ return IC.replaceInstUsesWith(SI, FalseVal);
+ } else if (match(TrueVal, m_c_Xor(m_Specific(X), m_Specific(Y)))) {
+ // (X|Y)==C ? X^Y : (X&Y)^C -> (X&Y)^C : (X&Y)^C -> (X&Y)^C
+ // (X|Y)==C ? X^Y :~(X&Y)&C -> ~(X&Y)&C :~(X&Y)&C -> ~(X&Y)&C
+ if (matchFalseVal(XorOps, AndOps, None) ||
+ matchFalseVal(AndOps, AndOps, NotInner))
+ return IC.replaceInstUsesWith(SI, FalseVal);
+ }
+ }
+
+ // (X^Y)==C ? X&Y : X|Y -> (X|Y)^C : X|Y or (X|Y)&~C : X|Y
+ // (X^Y)==C ? X|Y : X&Y -> (X&Y)|C : X&Y or (X&Y)^ C : X&Y
+ if (match(CmpLHS, m_Xor(m_Value(X), m_Value(Y)))) {
+ if ((match(TrueVal, m_c_And(m_Specific(X), m_Specific(Y))))) {
+ // (X^Y)==C ? X&Y : (X|Y)^C -> (X|Y)^C
+ // (X^Y)==C ? X&Y : (X|Y)&~C -> (X|Y)&~C
+ if (matchFalseVal(XorOps, OrOps, None) ||
+ matchFalseVal(AndOps, OrOps, NotRHS))
+ return IC.replaceInstUsesWith(SI, FalseVal);
+ } else if (match(TrueVal, m_c_Or(m_Specific(X), m_Specific(Y)))) {
+ // (X^Y)==C ? (X|Y) : (X&Y)|C -> (X&Y)|C
+ // (X^Y)==C ? (X|Y) : (X&Y)^C -> (X&Y)^C
+ if (matchFalseVal(OrOps, AndOps, None) ||
+ matchFalseVal(XorOps, AndOps, None))
+ return IC.replaceInstUsesWith(SI, FalseVal);
+ }
+ }
+
+ return nullptr;
+}
+
/// Visit a SelectInst that has an ICmpInst as its first operand.
Instruction *InstCombinerImpl::foldSelectInstWithICmp(SelectInst &SI,
ICmpInst *ICI) {
@@ -1689,7 +1830,7 @@ Instruction *InstCombinerImpl::foldSelectInstWithICmp(SelectInst &SI,
if (Value *V = foldSelectInstWithICmpConst(SI, ICI, Builder))
return replaceInstUsesWith(SI, V);
- if (Value *V = canonicalizeClampLike(SI, *ICI, Builder))
+ if (Value *V = canonicalizeClampLike(SI, *ICI, Builder, *this))
return replaceInstUsesWith(SI, V);
if (Instruction *NewSel =
@@ -1718,6 +1859,9 @@ Instruction *InstCombinerImpl::foldSelectInstWithICmp(SelectInst &SI,
}
}
+ if (Instruction *NewSel = foldSelectICmpEq(SI, ICI, *this))
+ return NewSel;
+
// Canonicalize a signbit condition to use zero constant by swapping:
// (CmpLHS > -1) ? TV : FV --> (CmpLHS < 0) ? FV : TV
// To avoid conflicts (infinite loops) with other canonicalizations, this is
@@ -1803,7 +1947,7 @@ Instruction *InstCombinerImpl::foldSelectInstWithICmp(SelectInst &SI,
if (Value *V = foldSelectICmpLshrAshr(ICI, TrueVal, FalseVal, Builder))
return replaceInstUsesWith(SI, V);
- if (Value *V = foldSelectCttzCtlz(ICI, TrueVal, FalseVal, Builder))
+ if (Value *V = foldSelectCttzCtlz(ICI, TrueVal, FalseVal, *this))
return replaceInstUsesWith(SI, V);
if (Value *V = canonicalizeSaturatedSubtract(ICI, TrueVal, FalseVal, Builder))
@@ -2223,20 +2367,20 @@ static Instruction *foldSelectCmpBitcasts(SelectInst &Sel,
/// operand, the result of the select will always be equal to its false value.
/// For example:
///
-/// %0 = cmpxchg i64* %ptr, i64 %compare, i64 %new_value seq_cst seq_cst
-/// %1 = extractvalue { i64, i1 } %0, 1
-/// %2 = extractvalue { i64, i1 } %0, 0
-/// %3 = select i1 %1, i64 %compare, i64 %2
-/// ret i64 %3
+/// %cmpxchg = cmpxchg ptr %ptr, i64 %compare, i64 %new_value seq_cst seq_cst
+/// %val = extractvalue { i64, i1 } %cmpxchg, 0
+/// %success = extractvalue { i64, i1 } %cmpxchg, 1
+/// %sel = select i1 %success, i64 %compare, i64 %val
+/// ret i64 %sel
///
-/// The returned value of the cmpxchg instruction (%2) is the original value
-/// located at %ptr prior to any update. If the cmpxchg operation succeeds, %2
+/// The returned value of the cmpxchg instruction (%val) is the original value
+/// located at %ptr prior to any update. If the cmpxchg operation succeeds, %val
/// must have been equal to %compare. Thus, the result of the select is always
-/// equal to %2, and the code can be simplified to:
+/// equal to %val, and the code can be simplified to:
///
-/// %0 = cmpxchg i64* %ptr, i64 %compare, i64 %new_value seq_cst seq_cst
-/// %1 = extractvalue { i64, i1 } %0, 0
-/// ret i64 %1
+/// %cmpxchg = cmpxchg ptr %ptr, i64 %compare, i64 %new_value seq_cst seq_cst
+/// %val = extractvalue { i64, i1 } %cmpxchg, 0
+/// ret i64 %val
///
static Value *foldSelectCmpXchg(SelectInst &SI) {
// A helper that determines if V is an extractvalue instruction whose
@@ -2369,14 +2513,11 @@ static Instruction *foldSelectToCopysign(SelectInst &Sel,
Value *FVal = Sel.getFalseValue();
Type *SelType = Sel.getType();
- if (ICmpInst::makeCmpResultType(TVal->getType()) != Cond->getType())
- return nullptr;
-
// Match select ?, TC, FC where the constants are equal but negated.
// TODO: Generalize to handle a negated variable operand?
const APFloat *TC, *FC;
- if (!match(TVal, m_APFloatAllowUndef(TC)) ||
- !match(FVal, m_APFloatAllowUndef(FC)) ||
+ if (!match(TVal, m_APFloatAllowPoison(TC)) ||
+ !match(FVal, m_APFloatAllowPoison(FC)) ||
!abs(*TC).bitwiseIsEqual(abs(*FC)))
return nullptr;
@@ -2386,9 +2527,9 @@ static Instruction *foldSelectToCopysign(SelectInst &Sel,
const APInt *C;
bool IsTrueIfSignSet;
ICmpInst::Predicate Pred;
- if (!match(Cond, m_OneUse(m_ICmp(Pred, m_BitCast(m_Value(X)), m_APInt(C)))) ||
- !InstCombiner::isSignBitCheck(Pred, *C, IsTrueIfSignSet) ||
- X->getType() != SelType)
+ if (!match(Cond, m_OneUse(m_ICmp(Pred, m_ElementWiseBitCast(m_Value(X)),
+ m_APInt(C)))) ||
+ !isSignBitCheck(Pred, *C, IsTrueIfSignSet) || X->getType() != SelType)
return nullptr;
// If needed, negate the value that will be the sign argument of the copysign:
@@ -2423,8 +2564,8 @@ Instruction *InstCombinerImpl::foldVectorSelect(SelectInst &Sel) {
if (auto *I = dyn_cast<Instruction>(V))
I->copyIRFlags(&Sel);
Module *M = Sel.getModule();
- Function *F = Intrinsic::getDeclaration(
- M, Intrinsic::experimental_vector_reverse, V->getType());
+ Function *F =
+ Intrinsic::getDeclaration(M, Intrinsic::vector_reverse, V->getType());
return CallInst::Create(F, V);
};
@@ -2587,7 +2728,7 @@ static Instruction *foldSelectWithSRem(SelectInst &SI, InstCombinerImpl &IC,
bool TrueIfSigned = false;
if (!(match(CondVal, m_ICmp(Pred, m_Value(RemRes), m_APInt(C))) &&
- IC.isSignBitCheck(Pred, *C, TrueIfSigned)))
+ isSignBitCheck(Pred, *C, TrueIfSigned)))
return nullptr;
// If the sign bit is not set, we have a SGE/SGT comparison, and the operands
@@ -2606,7 +2747,7 @@ static Instruction *foldSelectWithSRem(SelectInst &SI, InstCombinerImpl &IC,
// %cnd = icmp slt i32 %rem, 0
// %add = add i32 %rem, %n
// %sel = select i1 %cnd, i32 %add, i32 %rem
- if (match(TrueVal, m_Add(m_Value(RemRes), m_Value(Remainder))) &&
+ if (match(TrueVal, m_Add(m_Specific(RemRes), m_Value(Remainder))) &&
match(RemRes, m_SRem(m_Value(Op), m_Specific(Remainder))) &&
IC.isKnownToBeAPowerOfTwo(Remainder, /*OrZero*/ true) &&
FalseVal == RemRes)
@@ -2650,46 +2791,33 @@ static Value *foldSelectWithFrozenICmp(SelectInst &Sel, InstCombiner::BuilderTy
return nullptr;
}
+/// Given that \p CondVal is known to be \p CondIsTrue, try to simplify \p SI.
+static Value *simplifyNestedSelectsUsingImpliedCond(SelectInst &SI,
+ Value *CondVal,
+ bool CondIsTrue,
+ const DataLayout &DL) {
+ Value *InnerCondVal = SI.getCondition();
+ Value *InnerTrueVal = SI.getTrueValue();
+ Value *InnerFalseVal = SI.getFalseValue();
+ assert(CondVal->getType() == InnerCondVal->getType() &&
+ "The type of inner condition must match with the outer.");
+ if (auto Implied = isImpliedCondition(CondVal, InnerCondVal, DL, CondIsTrue))
+ return *Implied ? InnerTrueVal : InnerFalseVal;
+ return nullptr;
+}
+
Instruction *InstCombinerImpl::foldAndOrOfSelectUsingImpliedCond(Value *Op,
SelectInst &SI,
bool IsAnd) {
- Value *CondVal = SI.getCondition();
- Value *A = SI.getTrueValue();
- Value *B = SI.getFalseValue();
-
assert(Op->getType()->isIntOrIntVectorTy(1) &&
"Op must be either i1 or vector of i1.");
-
- std::optional<bool> Res = isImpliedCondition(Op, CondVal, DL, IsAnd);
- if (!Res)
+ if (SI.getCondition()->getType() != Op->getType())
return nullptr;
-
- Value *Zero = Constant::getNullValue(A->getType());
- Value *One = Constant::getAllOnesValue(A->getType());
-
- if (*Res == true) {
- if (IsAnd)
- // select op, (select cond, A, B), false => select op, A, false
- // and op, (select cond, A, B) => select op, A, false
- // if op = true implies condval = true.
- return SelectInst::Create(Op, A, Zero);
- else
- // select op, true, (select cond, A, B) => select op, true, A
- // or op, (select cond, A, B) => select op, true, A
- // if op = false implies condval = true.
- return SelectInst::Create(Op, One, A);
- } else {
- if (IsAnd)
- // select op, (select cond, A, B), false => select op, B, false
- // and op, (select cond, A, B) => select op, B, false
- // if op = true implies condval = false.
- return SelectInst::Create(Op, B, Zero);
- else
- // select op, true, (select cond, A, B) => select op, true, B
- // or op, (select cond, A, B) => select op, true, B
- // if op = false implies condval = false.
- return SelectInst::Create(Op, One, B);
- }
+ if (Value *V = simplifyNestedSelectsUsingImpliedCond(SI, Op, IsAnd, DL))
+ return SelectInst::Create(Op,
+ IsAnd ? V : ConstantInt::getTrue(Op->getType()),
+ IsAnd ? ConstantInt::getFalse(Op->getType()) : V);
+ return nullptr;
}
// Canonicalize select with fcmp to fabs(). -0.0 makes this tricky. We need
@@ -2772,6 +2900,36 @@ static Instruction *foldSelectWithFCmpToFabs(SelectInst &SI,
}
}
+ // Match select with (icmp slt (bitcast X to int), 0)
+ // or (icmp sgt (bitcast X to int), -1)
+
+ for (bool Swap : {false, true}) {
+ Value *TrueVal = SI.getTrueValue();
+ Value *X = SI.getFalseValue();
+
+ if (Swap)
+ std::swap(TrueVal, X);
+
+ CmpInst::Predicate Pred;
+ const APInt *C;
+ bool TrueIfSigned;
+ if (!match(CondVal,
+ m_ICmp(Pred, m_ElementWiseBitCast(m_Specific(X)), m_APInt(C))) ||
+ !isSignBitCheck(Pred, *C, TrueIfSigned))
+ continue;
+ if (!match(TrueVal, m_FNeg(m_Specific(X))))
+ return nullptr;
+ if (Swap == TrueIfSigned && !CondVal->hasOneUse() && !TrueVal->hasOneUse())
+ return nullptr;
+
+ // Fold (IsNeg ? -X : X) or (!IsNeg ? X : -X) to fabs(X)
+ // Fold (IsNeg ? X : -X) or (!IsNeg ? -X : X) to -fabs(X)
+ Value *Fabs = IC.Builder.CreateUnaryIntrinsic(Intrinsic::fabs, X, &SI);
+ if (Swap != TrueIfSigned)
+ return IC.replaceInstUsesWith(SI, Fabs);
+ return UnaryOperator::CreateFNegFMF(Fabs, &SI);
+ }
+
return ChangedFMF ? &SI : nullptr;
}
@@ -2808,17 +2966,17 @@ foldRoundUpIntegerWithPow2Alignment(SelectInst &SI,
// FIXME: we could support non non-splats here.
const APInt *LowBitMaskCst;
- if (!match(XLowBits, m_And(m_Specific(X), m_APIntAllowUndef(LowBitMaskCst))))
+ if (!match(XLowBits, m_And(m_Specific(X), m_APIntAllowPoison(LowBitMaskCst))))
return nullptr;
// Match even if the AND and ADD are swapped.
const APInt *BiasCst, *HighBitMaskCst;
if (!match(XBiasedHighBits,
- m_And(m_Add(m_Specific(X), m_APIntAllowUndef(BiasCst)),
- m_APIntAllowUndef(HighBitMaskCst))) &&
+ m_And(m_Add(m_Specific(X), m_APIntAllowPoison(BiasCst)),
+ m_APIntAllowPoison(HighBitMaskCst))) &&
!match(XBiasedHighBits,
- m_Add(m_And(m_Specific(X), m_APIntAllowUndef(HighBitMaskCst)),
- m_APIntAllowUndef(BiasCst))))
+ m_Add(m_And(m_Specific(X), m_APIntAllowPoison(HighBitMaskCst)),
+ m_APIntAllowPoison(BiasCst))))
return nullptr;
if (!LowBitMaskCst->isMask())
@@ -2834,7 +2992,8 @@ foldRoundUpIntegerWithPow2Alignment(SelectInst &SI,
return nullptr;
if (!XBiasedHighBits->hasOneUse()) {
- if (*BiasCst == *LowBitMaskCst)
+ // We can't directly return XBiasedHighBits if it is more poisonous.
+ if (*BiasCst == *LowBitMaskCst && impliesPoison(XBiasedHighBits, X))
return XBiasedHighBits;
return nullptr;
}
@@ -2856,6 +3015,32 @@ struct DecomposedSelect {
};
} // namespace
+/// Folds patterns like:
+/// select c2 (select c1 a b) (select c1 b a)
+/// into:
+/// select (xor c1 c2) b a
+static Instruction *
+foldSelectOfSymmetricSelect(SelectInst &OuterSelVal,
+ InstCombiner::BuilderTy &Builder) {
+
+ Value *OuterCond, *InnerCond, *InnerTrueVal, *InnerFalseVal;
+ if (!match(
+ &OuterSelVal,
+ m_Select(m_Value(OuterCond),
+ m_OneUse(m_Select(m_Value(InnerCond), m_Value(InnerTrueVal),
+ m_Value(InnerFalseVal))),
+ m_OneUse(m_Select(m_Deferred(InnerCond),
+ m_Deferred(InnerFalseVal),
+ m_Deferred(InnerTrueVal))))))
+ return nullptr;
+
+ if (OuterCond->getType() != InnerCond->getType())
+ return nullptr;
+
+ Value *Xor = Builder.CreateXor(InnerCond, OuterCond);
+ return SelectInst::Create(Xor, InnerFalseVal, InnerTrueVal);
+}
+
/// Look for patterns like
/// %outer.cond = select i1 %inner.cond, i1 %alt.cond, i1 false
/// %inner.sel = select i1 %inner.cond, i8 %inner.sel.t, i8 %inner.sel.f
@@ -2960,6 +3145,13 @@ Instruction *InstCombinerImpl::foldSelectOfBools(SelectInst &SI) {
return BinaryOperator::CreateOr(CondVal, FalseVal);
}
+ if (match(CondVal, m_OneUse(m_Select(m_Value(A), m_One(), m_Value(B)))) &&
+ impliesPoison(FalseVal, B)) {
+ // (A || B) || C --> A || (B | C)
+ return replaceInstUsesWith(
+ SI, Builder.CreateLogicalOr(A, Builder.CreateOr(B, FalseVal)));
+ }
+
if (auto *LHS = dyn_cast<FCmpInst>(CondVal))
if (auto *RHS = dyn_cast<FCmpInst>(FalseVal))
if (Value *V = foldLogicOfFCmps(LHS, RHS, /*IsAnd*/ false,
@@ -3001,6 +3193,13 @@ Instruction *InstCombinerImpl::foldSelectOfBools(SelectInst &SI) {
return BinaryOperator::CreateAnd(CondVal, TrueVal);
}
+ if (match(CondVal, m_OneUse(m_Select(m_Value(A), m_Value(B), m_Zero()))) &&
+ impliesPoison(TrueVal, B)) {
+ // (A && B) && C --> A && (B & C)
+ return replaceInstUsesWith(
+ SI, Builder.CreateLogicalAnd(A, Builder.CreateAnd(B, TrueVal)));
+ }
+
if (auto *LHS = dyn_cast<FCmpInst>(CondVal))
if (auto *RHS = dyn_cast<FCmpInst>(TrueVal))
if (Value *V = foldLogicOfFCmps(LHS, RHS, /*IsAnd*/ true,
@@ -3115,11 +3314,6 @@ Instruction *InstCombinerImpl::foldSelectOfBools(SelectInst &SI) {
return replaceInstUsesWith(SI, Op1);
}
- if (auto *Op1SI = dyn_cast<SelectInst>(Op1))
- if (auto *I = foldAndOrOfSelectUsingImpliedCond(CondVal, *Op1SI,
- /* IsAnd */ IsAnd))
- return I;
-
if (auto *ICmp0 = dyn_cast<ICmpInst>(CondVal))
if (auto *ICmp1 = dyn_cast<ICmpInst>(Op1))
if (auto *V = foldAndOrOfICmps(ICmp0, ICmp1, SI, IsAnd,
@@ -3201,7 +3395,8 @@ Instruction *InstCombinerImpl::foldSelectOfBools(SelectInst &SI) {
// pattern.
static bool isSafeToRemoveBitCeilSelect(ICmpInst::Predicate Pred, Value *Cond0,
const APInt *Cond1, Value *CtlzOp,
- unsigned BitWidth) {
+ unsigned BitWidth,
+ bool &ShouldDropNUW) {
// The challenge in recognizing std::bit_ceil(X) is that the operand is used
// for the CTLZ proper and select condition, each possibly with some
// operation like add and sub.
@@ -3224,6 +3419,8 @@ static bool isSafeToRemoveBitCeilSelect(ICmpInst::Predicate Pred, Value *Cond0,
ConstantRange CR = ConstantRange::makeExactICmpRegion(
CmpInst::getInversePredicate(Pred), *Cond1);
+ ShouldDropNUW = false;
+
// Match the operation that's used to compute CtlzOp from CommonAncestor. If
// CtlzOp == CommonAncestor, return true as no operation is needed. If a
// match is found, execute the operation on CR, update CR, and return true.
@@ -3237,6 +3434,7 @@ static bool isSafeToRemoveBitCeilSelect(ICmpInst::Predicate Pred, Value *Cond0,
return true;
}
if (match(CtlzOp, m_Sub(m_APInt(C), m_Specific(CommonAncestor)))) {
+ ShouldDropNUW = true;
CR = ConstantRange(*C).sub(CR);
return true;
}
@@ -3306,14 +3504,20 @@ static Instruction *foldBitCeil(SelectInst &SI, IRBuilderBase &Builder) {
Pred = CmpInst::getInversePredicate(Pred);
}
+ bool ShouldDropNUW;
+
if (!match(FalseVal, m_One()) ||
!match(TrueVal,
m_OneUse(m_Shl(m_One(), m_OneUse(m_Sub(m_SpecificInt(BitWidth),
m_Value(Ctlz)))))) ||
!match(Ctlz, m_Intrinsic<Intrinsic::ctlz>(m_Value(CtlzOp), m_Zero())) ||
- !isSafeToRemoveBitCeilSelect(Pred, Cond0, Cond1, CtlzOp, BitWidth))
+ !isSafeToRemoveBitCeilSelect(Pred, Cond0, Cond1, CtlzOp, BitWidth,
+ ShouldDropNUW))
return nullptr;
+ if (ShouldDropNUW)
+ cast<Instruction>(CtlzOp)->setHasNoUnsignedWrap(false);
+
// Build 1 << (-CTLZ & (BitWidth-1)). The negation likely corresponds to a
// single hardware instruction as opposed to BitWidth - CTLZ, where BitWidth
// is an integer constant. Masking with BitWidth-1 comes free on some
@@ -3350,6 +3554,33 @@ static bool matchFMulByZeroIfResultEqZero(InstCombinerImpl &IC, Value *Cmp0,
return false;
}
+/// Check whether the KnownBits of a select arm may be affected by the
+/// select condition.
+static bool hasAffectedValue(Value *V, SmallPtrSetImpl<Value *> &Affected,
+ unsigned Depth) {
+ if (Depth == MaxAnalysisRecursionDepth)
+ return false;
+
+ // Ignore the case where the select arm itself is affected. These cases
+ // are handled more efficiently by other optimizations.
+ if (Depth != 0 && Affected.contains(V))
+ return true;
+
+ if (auto *I = dyn_cast<Instruction>(V)) {
+ if (isa<PHINode>(I)) {
+ if (Depth == MaxAnalysisRecursionDepth - 1)
+ return false;
+ Depth = MaxAnalysisRecursionDepth - 2;
+ }
+ return any_of(I->operands(), [&](Value *Op) {
+ return Op->getType()->isIntOrIntVectorTy() &&
+ hasAffectedValue(Op, Affected, Depth + 1);
+ });
+ }
+
+ return false;
+}
+
Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
Value *CondVal = SI.getCondition();
Value *TrueVal = SI.getTrueValue();
@@ -3536,16 +3767,15 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
Value *Idx = Gep->getOperand(1);
if (isa<VectorType>(CondVal->getType()) && !isa<VectorType>(Idx->getType()))
return nullptr;
- Type *ElementType = Gep->getResultElementType();
+ Type *ElementType = Gep->getSourceElementType();
Value *NewT = Idx;
Value *NewF = Constant::getNullValue(Idx->getType());
if (Swap)
std::swap(NewT, NewF);
Value *NewSI =
Builder.CreateSelect(CondVal, NewT, NewF, SI.getName() + ".idx", &SI);
- if (Gep->isInBounds())
- return GetElementPtrInst::CreateInBounds(ElementType, Ptr, {NewSI});
- return GetElementPtrInst::Create(ElementType, Ptr, {NewSI});
+ return GetElementPtrInst::Create(ElementType, Ptr, NewSI,
+ Gep->getNoWrapFlags());
};
if (auto *TrueGep = dyn_cast<GetElementPtrInst>(TrueVal))
if (auto *NewGep = SelectGepWithBase(TrueGep, FalseVal, false))
@@ -3620,12 +3850,12 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
if (SelectInst *TrueSI = dyn_cast<SelectInst>(TrueVal)) {
if (TrueSI->getCondition()->getType() == CondVal->getType()) {
- // select(C, select(C, a, b), c) -> select(C, a, c)
- if (TrueSI->getCondition() == CondVal) {
- if (SI.getTrueValue() == TrueSI->getTrueValue())
- return nullptr;
- return replaceOperand(SI, 1, TrueSI->getTrueValue());
- }
+ // Fold nested selects if the inner condition can be implied by the outer
+ // condition.
+ if (Value *V = simplifyNestedSelectsUsingImpliedCond(
+ *TrueSI, CondVal, /*CondIsTrue=*/true, DL))
+ return replaceOperand(SI, 1, V);
+
// select(C0, select(C1, a, b), b) -> select(C0&C1, a, b)
// We choose this as normal form to enable folding on the And and
// shortening paths for the values (this helps getUnderlyingObjects() for
@@ -3640,12 +3870,12 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
}
if (SelectInst *FalseSI = dyn_cast<SelectInst>(FalseVal)) {
if (FalseSI->getCondition()->getType() == CondVal->getType()) {
- // select(C, a, select(C, b, c)) -> select(C, a, c)
- if (FalseSI->getCondition() == CondVal) {
- if (SI.getFalseValue() == FalseSI->getFalseValue())
- return nullptr;
- return replaceOperand(SI, 2, FalseSI->getFalseValue());
- }
+ // Fold nested selects if the inner condition can be implied by the outer
+ // condition.
+ if (Value *V = simplifyNestedSelectsUsingImpliedCond(
+ *FalseSI, CondVal, /*CondIsTrue=*/false, DL))
+ return replaceOperand(SI, 2, V);
+
// select(C0, a, select(C1, a, b)) -> select(C0|C1, a, b)
if (FalseSI->getTrueValue() == TrueVal && FalseSI->hasOneUse()) {
Value *Or = Builder.CreateLogicalOr(CondVal, FalseSI->getCondition());
@@ -3786,6 +4016,9 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
}
}
+ if (Instruction *I = foldSelectOfSymmetricSelect(SI, Builder))
+ return I;
+
if (Instruction *I = foldNestedSelects(SI, Builder))
return I;
@@ -3844,5 +4077,39 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
}
}
+ // select Cond, !X, X -> xor Cond, X
+ if (CondVal->getType() == SI.getType() && isKnownInversion(FalseVal, TrueVal))
+ return BinaryOperator::CreateXor(CondVal, FalseVal);
+
+ // For vectors, this transform is only safe if the simplification does not
+ // look through any lane-crossing operations. For now, limit to scalars only.
+ if (SelType->isIntegerTy() &&
+ (!isa<Constant>(TrueVal) || !isa<Constant>(FalseVal))) {
+ // Try to simplify select arms based on KnownBits implied by the condition.
+ CondContext CC(CondVal);
+ findValuesAffectedByCondition(CondVal, /*IsAssume=*/false, [&](Value *V) {
+ CC.AffectedValues.insert(V);
+ });
+ SimplifyQuery Q = SQ.getWithInstruction(&SI).getWithCondContext(CC);
+ if (!CC.AffectedValues.empty()) {
+ if (!isa<Constant>(TrueVal) &&
+ hasAffectedValue(TrueVal, CC.AffectedValues, /*Depth=*/0)) {
+ KnownBits Known = llvm::computeKnownBits(TrueVal, /*Depth=*/0, Q);
+ if (Known.isConstant())
+ return replaceOperand(SI, 1,
+ ConstantInt::get(SelType, Known.getConstant()));
+ }
+
+ CC.Invert = true;
+ if (!isa<Constant>(FalseVal) &&
+ hasAffectedValue(FalseVal, CC.AffectedValues, /*Depth=*/0)) {
+ KnownBits Known = llvm::computeKnownBits(FalseVal, /*Depth=*/0, Q);
+ if (Known.isConstant())
+ return replaceOperand(SI, 2,
+ ConstantInt::get(SelType, Known.getConstant()));
+ }
+ }
+ }
+
return nullptr;
}