Diffstat (limited to 'llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp')
-rw-r--r--  llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp  448
1 file changed, 318 insertions(+), 130 deletions(-)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index e7d8208f94fd..661c50062223 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -98,7 +98,8 @@ static Instruction *foldSelectBinOpIdentity(SelectInst &Sel,
// +0.0 compares equal to -0.0, and so it does not behave as required for this
// transform. Bail out if we cannot exclude that possibility.
if (isa<FPMathOperator>(BO))
- if (!BO->hasNoSignedZeros() && !CannotBeNegativeZero(Y, &TLI))
+ if (!BO->hasNoSignedZeros() &&
+ !cannotBeNegativeZero(Y, IC.getDataLayout(), &TLI))
return nullptr;
// BO = binop Y, X
@@ -386,6 +387,32 @@ Instruction *InstCombinerImpl::foldSelectOpOp(SelectInst &SI, Instruction *TI,
return CallInst::Create(TII->getCalledFunction(), {NewSel, MatchOp});
}
}
+
+ // select c, (ldexp v, e0), (ldexp v, e1) -> ldexp v, (select c, e0, e1)
+ // select c, (ldexp v0, e), (ldexp v1, e) -> ldexp (select c, v0, v1), e
+ //
+ // select c, (ldexp v0, e0), (ldexp v1, e1) ->
+ // ldexp (select c, v0, v1), (select c, e0, e1)
+ if (TII->getIntrinsicID() == Intrinsic::ldexp) {
+ Value *LdexpVal0 = TII->getArgOperand(0);
+ Value *LdexpExp0 = TII->getArgOperand(1);
+ Value *LdexpVal1 = FII->getArgOperand(0);
+ Value *LdexpExp1 = FII->getArgOperand(1);
+ if (LdexpExp0->getType() == LdexpExp1->getType()) {
+ FPMathOperator *SelectFPOp = cast<FPMathOperator>(&SI);
+ FastMathFlags FMF = cast<FPMathOperator>(TII)->getFastMathFlags();
+ FMF &= cast<FPMathOperator>(FII)->getFastMathFlags();
+ FMF |= SelectFPOp->getFastMathFlags();
+
+ Value *SelectVal = Builder.CreateSelect(Cond, LdexpVal0, LdexpVal1);
+ Value *SelectExp = Builder.CreateSelect(Cond, LdexpExp0, LdexpExp1);
+
+ CallInst *NewLdexp = Builder.CreateIntrinsic(
+ TII->getType(), Intrinsic::ldexp, {SelectVal, SelectExp});
+ NewLdexp->setFastMathFlags(FMF);
+ return replaceInstUsesWith(SI, NewLdexp);
+ }
+ }
}
// icmp with a common operand also can have the common operand
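A hand-written IR sketch of the shape the new ldexp fold targets (value names are invented for illustration, not taken from the patch's tests):

    %t   = call float @llvm.ldexp.f32.i32(float %v0, i32 %e0)
    %f   = call float @llvm.ldexp.f32.i32(float %v1, i32 %e1)
    %sel = select i1 %c, float %t, float %f
  -->
    %v.sel = select i1 %c, float %v0, float %v1
    %e.sel = select i1 %c, i32 %e0, i32 %e1
    %sel   = call float @llvm.ldexp.f32.i32(float %v.sel, i32 %e.sel)

The fast-math flags on the new call are the intersection of the flags on the two original calls, unioned with the flags on the select, exactly as the code above computes them.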
@@ -429,6 +456,21 @@ Instruction *InstCombinerImpl::foldSelectOpOp(SelectInst &SI, Instruction *TI,
!OtherOpF->getType()->isVectorTy()))
return nullptr;
+ // If we are sinking div/rem after a select, we may need to freeze the
+ // condition because div/rem may induce immediate UB with a poison operand.
+ // For example, the following transform is not safe if Cond can ever be poison
+ // because we can replace poison with zero and then we have div-by-zero that
+ // didn't exist in the original code:
+ // Cond ? x/y : x/z --> x / (Cond ? y : z)
+ auto *BO = dyn_cast<BinaryOperator>(TI);
+ if (BO && BO->isIntDivRem() && !isGuaranteedNotToBePoison(Cond)) {
+ // A udiv/urem with a common divisor is safe because UB can only occur with
+ // div-by-zero, and that would be present in the original code.
+ if (BO->getOpcode() == Instruction::SDiv ||
+ BO->getOpcode() == Instruction::SRem || MatchIsOpZero)
+ Cond = Builder.CreateFreeze(Cond);
+ }
+
// If we reach here, they do have operations in common.
Value *NewSI = Builder.CreateSelect(Cond, OtherOpT, OtherOpF,
SI.getName() + ".v", &SI);
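An IR sketch of the div/rem case that now gets a frozen condition (invented names; assumes %c is not already known to be non-poison):

    %t   = sdiv i32 %x, %y
    %f   = sdiv i32 %x, %z
    %sel = select i1 %c, i32 %t, i32 %f
  -->
    %c.fr  = freeze i1 %c
    %div.v = select i1 %c.fr, i32 %y, i32 %z
    %sel   = sdiv i32 %x, %div.v

Without the freeze, a poison %c could later be refined to pick a zero divisor that the original code never divided by.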
@@ -461,7 +503,7 @@ static bool isSelect01(const APInt &C1I, const APInt &C2I) {
/// optimization.
Instruction *InstCombinerImpl::foldSelectIntoOp(SelectInst &SI, Value *TrueVal,
Value *FalseVal) {
- // See the comment above GetSelectFoldableOperands for a description of the
+ // See the comment above getSelectFoldableOperands for a description of the
// transformation we are doing here.
auto TryFoldSelectIntoOp = [&](SelectInst &SI, Value *TrueVal,
Value *FalseVal,
@@ -496,7 +538,7 @@ Instruction *InstCombinerImpl::foldSelectIntoOp(SelectInst &SI, Value *TrueVal,
if (!isa<Constant>(OOp) ||
(OOpIsAPInt && isSelect01(C->getUniqueInteger(), *OOpC))) {
Value *NewSel = Builder.CreateSelect(SI.getCondition(), Swapped ? C : OOp,
- Swapped ? OOp : C);
+ Swapped ? OOp : C, "", &SI);
if (isa<FPMathOperator>(&SI))
cast<Instruction>(NewSel)->setFastMathFlags(FMF);
NewSel->takeName(TVI);
@@ -569,6 +611,44 @@ static Instruction *foldSelectICmpAndAnd(Type *SelType, const ICmpInst *Cmp,
}
/// We want to turn:
+/// (select (icmp eq (and X, C1), 0), 0, (shl [nsw/nuw] X, C2));
+/// iff C1 is a mask and the number of its leading zeros is equal to C2
+/// into:
+/// shl X, C2
+static Value *foldSelectICmpAndZeroShl(const ICmpInst *Cmp, Value *TVal,
+ Value *FVal,
+ InstCombiner::BuilderTy &Builder) {
+ ICmpInst::Predicate Pred;
+ Value *AndVal;
+ if (!match(Cmp, m_ICmp(Pred, m_Value(AndVal), m_Zero())))
+ return nullptr;
+
+ if (Pred == ICmpInst::ICMP_NE) {
+ Pred = ICmpInst::ICMP_EQ;
+ std::swap(TVal, FVal);
+ }
+
+ Value *X;
+ const APInt *C2, *C1;
+ if (Pred != ICmpInst::ICMP_EQ ||
+ !match(AndVal, m_And(m_Value(X), m_APInt(C1))) ||
+ !match(TVal, m_Zero()) || !match(FVal, m_Shl(m_Specific(X), m_APInt(C2))))
+ return nullptr;
+
+ if (!C1->isMask() ||
+ C1->countLeadingZeros() != static_cast<unsigned>(C2->getZExtValue()))
+ return nullptr;
+
+ auto *FI = dyn_cast<Instruction>(FVal);
+ if (!FI)
+ return nullptr;
+
+ FI->setHasNoSignedWrap(false);
+ FI->setHasNoUnsignedWrap(false);
+ return FVal;
+}
+
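A minimal instance of the pattern foldSelectICmpAndZeroShl matches, with C1 = 15 (a mask whose 4 leading zeros in i8 equal C2 = 4); the names are illustrative:

    %and = and i8 %x, 15
    %cmp = icmp eq i8 %and, 0
    %shl = shl nuw i8 %x, 4
    %sel = select i1 %cmp, i8 0, i8 %shl
  -->
    %sel = shl i8 %x, 4

When the masked low bits are zero the shift result is zero anyway, so the select is redundant; the nuw/nsw flags are dropped because they were only guaranteed on the path where the compare was false.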
+/// We want to turn:
/// (select (icmp sgt x, C), lshr (X, Y), ashr (X, Y)); iff C s>= -1
/// (select (icmp slt x, C), ashr (X, Y), lshr (X, Y)); iff C s>= 0
/// into:
@@ -935,10 +1015,53 @@ static Value *canonicalizeSaturatedAdd(ICmpInst *Cmp, Value *TVal, Value *FVal,
return nullptr;
}
+/// Try to match patterns with select and subtract as absolute difference.
+static Value *foldAbsDiff(ICmpInst *Cmp, Value *TVal, Value *FVal,
+ InstCombiner::BuilderTy &Builder) {
+ auto *TI = dyn_cast<Instruction>(TVal);
+ auto *FI = dyn_cast<Instruction>(FVal);
+ if (!TI || !FI)
+ return nullptr;
+
+ // Normalize predicate to gt/lt rather than ge/le.
+ ICmpInst::Predicate Pred = Cmp->getStrictPredicate();
+ Value *A = Cmp->getOperand(0);
+ Value *B = Cmp->getOperand(1);
+
+ // Normalize "A - B" as the true value of the select.
+ if (match(FI, m_Sub(m_Specific(A), m_Specific(B)))) {
+ std::swap(FI, TI);
+ Pred = ICmpInst::getSwappedPredicate(Pred);
+ }
+
+ // With any pair of no-wrap subtracts:
+ // (A > B) ? (A - B) : (B - A) --> abs(A - B)
+ if (Pred == CmpInst::ICMP_SGT &&
+ match(TI, m_Sub(m_Specific(A), m_Specific(B))) &&
+ match(FI, m_Sub(m_Specific(B), m_Specific(A))) &&
+ (TI->hasNoSignedWrap() || TI->hasNoUnsignedWrap()) &&
+ (FI->hasNoSignedWrap() || FI->hasNoUnsignedWrap())) {
+ // The remaining subtract is not "nuw" any more.
+ // If there's one use of the subtract (no other use than the use we are
+ // about to replace), then we know that the sub is "nsw" in this context
+ // even if it was only "nuw" before. If there's another use, then we can't
+ // add "nsw" to the existing instruction because it may not be safe in the
+ // other user's context.
+ TI->setHasNoUnsignedWrap(false);
+ if (!TI->hasNoSignedWrap())
+ TI->setHasNoSignedWrap(TI->hasOneUse());
+ return Builder.CreateBinaryIntrinsic(Intrinsic::abs, TI, Builder.getTrue());
+ }
+
+ return nullptr;
+}
+
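An IR sketch of the absolute-difference pattern foldAbsDiff recognizes (illustrative names; both subtracts carry nsw here):

    %cmp = icmp sgt i32 %a, %b
    %t   = sub nsw i32 %a, %b
    %f   = sub nsw i32 %b, %a
    %sel = select i1 %cmp, i32 %t, i32 %f
  -->
    %t   = sub nsw i32 %a, %b
    %sel = call i32 @llvm.abs.i32(i32 %t, i1 true)

The i1 true operand is llvm.abs's int-min-is-poison flag; requiring a no-wrap flag on both subtracts is what lets the code pass true here.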
/// Fold the following code sequence:
/// \code
/// int a = ctlz(x & -x);
// x ? 31 - a : a;
+// // or
+// x ? 31 - a : 32;
/// \code
///
/// into:
@@ -953,15 +1076,19 @@ static Instruction *foldSelectCtlzToCttz(ICmpInst *ICI, Value *TrueVal,
if (ICI->getPredicate() == ICmpInst::ICMP_NE)
std::swap(TrueVal, FalseVal);
+ Value *Ctlz;
if (!match(FalseVal,
- m_Xor(m_Deferred(TrueVal), m_SpecificInt(BitWidth - 1))))
+ m_Xor(m_Value(Ctlz), m_SpecificInt(BitWidth - 1))))
return nullptr;
- if (!match(TrueVal, m_Intrinsic<Intrinsic::ctlz>()))
+ if (!match(Ctlz, m_Intrinsic<Intrinsic::ctlz>()))
+ return nullptr;
+
+ if (TrueVal != Ctlz && !match(TrueVal, m_SpecificInt(BitWidth)))
return nullptr;
Value *X = ICI->getOperand(0);
- auto *II = cast<IntrinsicInst>(TrueVal);
+ auto *II = cast<IntrinsicInst>(Ctlz);
if (!match(II->getOperand(0), m_c_And(m_Specific(X), m_Neg(m_Specific(X)))))
return nullptr;
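With this change the fold also accepts the BitWidth constant in the arm taken when x is zero, roughly (illustrative i32 IR):

    %neg = sub i32 0, %x
    %lsb = and i32 %x, %neg                      ; x & -x
    %a   = call i32 @llvm.ctlz.i32(i32 %lsb, i1 false)
    %inv = xor i32 %a, 31                        ; 31 - a
    %cmp = icmp eq i32 %x, 0
    %sel = select i1 %cmp, i32 32, i32 %inv
  -->
    %sel = call i32 @llvm.cttz.i32(i32 %x, i1 false)

For nonzero %x, x & -x isolates the lowest set bit, so 31 - ctlz(x & -x) is exactly cttz(x); for %x == 0 the select already produced 32, which is what cttz returns when its zero-is-poison flag is false.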
@@ -1038,99 +1165,6 @@ static Value *foldSelectCttzCtlz(ICmpInst *ICI, Value *TrueVal, Value *FalseVal,
return nullptr;
}
-/// Return true if we find and adjust an icmp+select pattern where the compare
-/// is with a constant that can be incremented or decremented to match the
-/// minimum or maximum idiom.
-static bool adjustMinMax(SelectInst &Sel, ICmpInst &Cmp) {
- ICmpInst::Predicate Pred = Cmp.getPredicate();
- Value *CmpLHS = Cmp.getOperand(0);
- Value *CmpRHS = Cmp.getOperand(1);
- Value *TrueVal = Sel.getTrueValue();
- Value *FalseVal = Sel.getFalseValue();
-
- // We may move or edit the compare, so make sure the select is the only user.
- const APInt *CmpC;
- if (!Cmp.hasOneUse() || !match(CmpRHS, m_APInt(CmpC)))
- return false;
-
- // These transforms only work for selects of integers or vector selects of
- // integer vectors.
- Type *SelTy = Sel.getType();
- auto *SelEltTy = dyn_cast<IntegerType>(SelTy->getScalarType());
- if (!SelEltTy || SelTy->isVectorTy() != Cmp.getType()->isVectorTy())
- return false;
-
- Constant *AdjustedRHS;
- if (Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_SGT)
- AdjustedRHS = ConstantInt::get(CmpRHS->getType(), *CmpC + 1);
- else if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_SLT)
- AdjustedRHS = ConstantInt::get(CmpRHS->getType(), *CmpC - 1);
- else
- return false;
-
- // X > C ? X : C+1 --> X < C+1 ? C+1 : X
- // X < C ? X : C-1 --> X > C-1 ? C-1 : X
- if ((CmpLHS == TrueVal && AdjustedRHS == FalseVal) ||
- (CmpLHS == FalseVal && AdjustedRHS == TrueVal)) {
- ; // Nothing to do here. Values match without any sign/zero extension.
- }
- // Types do not match. Instead of calculating this with mixed types, promote
- // all to the larger type. This enables scalar evolution to analyze this
- // expression.
- else if (CmpRHS->getType()->getScalarSizeInBits() < SelEltTy->getBitWidth()) {
- Constant *SextRHS = ConstantExpr::getSExt(AdjustedRHS, SelTy);
-
- // X = sext x; x >s c ? X : C+1 --> X = sext x; X <s C+1 ? C+1 : X
- // X = sext x; x <s c ? X : C-1 --> X = sext x; X >s C-1 ? C-1 : X
- // X = sext x; x >u c ? X : C+1 --> X = sext x; X <u C+1 ? C+1 : X
- // X = sext x; x <u c ? X : C-1 --> X = sext x; X >u C-1 ? C-1 : X
- if (match(TrueVal, m_SExt(m_Specific(CmpLHS))) && SextRHS == FalseVal) {
- CmpLHS = TrueVal;
- AdjustedRHS = SextRHS;
- } else if (match(FalseVal, m_SExt(m_Specific(CmpLHS))) &&
- SextRHS == TrueVal) {
- CmpLHS = FalseVal;
- AdjustedRHS = SextRHS;
- } else if (Cmp.isUnsigned()) {
- Constant *ZextRHS = ConstantExpr::getZExt(AdjustedRHS, SelTy);
- // X = zext x; x >u c ? X : C+1 --> X = zext x; X <u C+1 ? C+1 : X
- // X = zext x; x <u c ? X : C-1 --> X = zext x; X >u C-1 ? C-1 : X
- // zext + signed compare cannot be changed:
- // 0xff <s 0x00, but 0x00ff >s 0x0000
- if (match(TrueVal, m_ZExt(m_Specific(CmpLHS))) && ZextRHS == FalseVal) {
- CmpLHS = TrueVal;
- AdjustedRHS = ZextRHS;
- } else if (match(FalseVal, m_ZExt(m_Specific(CmpLHS))) &&
- ZextRHS == TrueVal) {
- CmpLHS = FalseVal;
- AdjustedRHS = ZextRHS;
- } else {
- return false;
- }
- } else {
- return false;
- }
- } else {
- return false;
- }
-
- Pred = ICmpInst::getSwappedPredicate(Pred);
- CmpRHS = AdjustedRHS;
- std::swap(FalseVal, TrueVal);
- Cmp.setPredicate(Pred);
- Cmp.setOperand(0, CmpLHS);
- Cmp.setOperand(1, CmpRHS);
- Sel.setOperand(1, TrueVal);
- Sel.setOperand(2, FalseVal);
- Sel.swapProfMetadata();
-
- // Move the compare instruction right before the select instruction. Otherwise
- // the sext/zext value may be defined after the compare instruction uses it.
- Cmp.moveBefore(&Sel);
-
- return true;
-}
-
static Instruction *canonicalizeSPF(SelectInst &Sel, ICmpInst &Cmp,
InstCombinerImpl &IC) {
Value *LHS, *RHS;
@@ -1182,8 +1216,8 @@ static Instruction *canonicalizeSPF(SelectInst &Sel, ICmpInst &Cmp,
return nullptr;
}
-static bool replaceInInstruction(Value *V, Value *Old, Value *New,
- InstCombiner &IC, unsigned Depth = 0) {
+bool InstCombinerImpl::replaceInInstruction(Value *V, Value *Old, Value *New,
+ unsigned Depth) {
// Conservatively limit replacement to two instructions upwards.
if (Depth == 2)
return false;
@@ -1195,10 +1229,11 @@ static bool replaceInInstruction(Value *V, Value *Old, Value *New,
bool Changed = false;
for (Use &U : I->operands()) {
if (U == Old) {
- IC.replaceUse(U, New);
+ replaceUse(U, New);
+ Worklist.add(I);
Changed = true;
} else {
- Changed |= replaceInInstruction(U, Old, New, IC, Depth + 1);
+ Changed |= replaceInInstruction(U, Old, New, Depth + 1);
}
}
return Changed;
@@ -1254,7 +1289,7 @@ Instruction *InstCombinerImpl::foldSelectValueEquivalence(SelectInst &Sel,
// FIXME: Support vectors.
if (match(CmpRHS, m_ImmConstant()) && !match(CmpLHS, m_ImmConstant()) &&
!Cmp.getType()->isVectorTy())
- if (replaceInInstruction(TrueVal, CmpLHS, CmpRHS, *this))
+ if (replaceInInstruction(TrueVal, CmpLHS, CmpRHS))
return &Sel;
}
if (TrueVal != CmpRHS &&
@@ -1593,13 +1628,32 @@ static Instruction *foldSelectZeroOrOnes(ICmpInst *Cmp, Value *TVal,
return nullptr;
}
-static Value *foldSelectInstWithICmpConst(SelectInst &SI, ICmpInst *ICI) {
+static Value *foldSelectInstWithICmpConst(SelectInst &SI, ICmpInst *ICI,
+ InstCombiner::BuilderTy &Builder) {
const APInt *CmpC;
Value *V;
CmpInst::Predicate Pred;
if (!match(ICI, m_ICmp(Pred, m_Value(V), m_APInt(CmpC))))
return nullptr;
+ // Match clamp away from min/max value as a max/min operation.
+ Value *TVal = SI.getTrueValue();
+ Value *FVal = SI.getFalseValue();
+ if (Pred == ICmpInst::ICMP_EQ && V == FVal) {
+ // (V == UMIN) ? UMIN+1 : V --> umax(V, UMIN+1)
+ if (CmpC->isMinValue() && match(TVal, m_SpecificInt(*CmpC + 1)))
+ return Builder.CreateBinaryIntrinsic(Intrinsic::umax, V, TVal);
+ // (V == UMAX) ? UMAX-1 : V --> umin(V, UMAX-1)
+ if (CmpC->isMaxValue() && match(TVal, m_SpecificInt(*CmpC - 1)))
+ return Builder.CreateBinaryIntrinsic(Intrinsic::umin, V, TVal);
+ // (V == SMIN) ? SMIN+1 : V --> smax(V, SMIN+1)
+ if (CmpC->isMinSignedValue() && match(TVal, m_SpecificInt(*CmpC + 1)))
+ return Builder.CreateBinaryIntrinsic(Intrinsic::smax, V, TVal);
+ // (V == SMAX) ? SMAX-1 : V --> smin(V, SMAX-1)
+ if (CmpC->isMaxSignedValue() && match(TVal, m_SpecificInt(*CmpC - 1)))
+ return Builder.CreateBinaryIntrinsic(Intrinsic::smin, V, TVal);
+ }
+
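A small example of the new clamp-away-from-min case (illustrative names):

    %cmp = icmp eq i8 %v, 0
    %sel = select i1 %cmp, i8 1, i8 %v
  -->
    %sel = call i8 @llvm.umax.i8(i8 %v, i8 1)

The other three cases map the analogous compares against UMAX, SMIN and SMAX onto umin, smax and smin in the same way.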
BinaryOperator *BO;
const APInt *C;
CmpInst::Predicate CPred;
@@ -1632,7 +1686,7 @@ Instruction *InstCombinerImpl::foldSelectInstWithICmp(SelectInst &SI,
if (Instruction *NewSPF = canonicalizeSPF(SI, *ICI, *this))
return NewSPF;
- if (Value *V = foldSelectInstWithICmpConst(SI, ICI))
+ if (Value *V = foldSelectInstWithICmpConst(SI, ICI, Builder))
return replaceInstUsesWith(SI, V);
if (Value *V = canonicalizeClampLike(SI, *ICI, Builder))
@@ -1642,18 +1696,17 @@ Instruction *InstCombinerImpl::foldSelectInstWithICmp(SelectInst &SI,
tryToReuseConstantFromSelectInComparison(SI, *ICI, *this))
return NewSel;
- bool Changed = adjustMinMax(SI, *ICI);
-
if (Value *V = foldSelectICmpAnd(SI, ICI, Builder))
return replaceInstUsesWith(SI, V);
// NOTE: if we wanted to, this is where to detect integer MIN/MAX
+ bool Changed = false;
Value *TrueVal = SI.getTrueValue();
Value *FalseVal = SI.getFalseValue();
ICmpInst::Predicate Pred = ICI->getPredicate();
Value *CmpLHS = ICI->getOperand(0);
Value *CmpRHS = ICI->getOperand(1);
- if (CmpRHS != CmpLHS && isa<Constant>(CmpRHS)) {
+ if (CmpRHS != CmpLHS && isa<Constant>(CmpRHS) && !isa<Constant>(CmpLHS)) {
if (CmpLHS == TrueVal && Pred == ICmpInst::ICMP_EQ) {
// Transform (X == C) ? X : Y -> (X == C) ? C : Y
SI.setOperand(1, CmpRHS);
@@ -1683,7 +1736,7 @@ Instruction *InstCombinerImpl::foldSelectInstWithICmp(SelectInst &SI,
// FIXME: This code is nearly duplicated in InstSimplify. Using/refactoring
// decomposeBitTestICmp() might help.
- {
+ if (TrueVal->getType()->isIntOrIntVectorTy()) {
unsigned BitWidth =
DL.getTypeSizeInBits(TrueVal->getType()->getScalarType());
APInt MinSignedValue = APInt::getSignedMinValue(BitWidth);
@@ -1735,6 +1788,9 @@ Instruction *InstCombinerImpl::foldSelectInstWithICmp(SelectInst &SI,
foldSelectICmpAndAnd(SI.getType(), ICI, TrueVal, FalseVal, Builder))
return V;
+ if (Value *V = foldSelectICmpAndZeroShl(ICI, TrueVal, FalseVal, Builder))
+ return replaceInstUsesWith(SI, V);
+
if (Instruction *V = foldSelectCtlzToCttz(ICI, TrueVal, FalseVal, Builder))
return V;
@@ -1756,6 +1812,9 @@ Instruction *InstCombinerImpl::foldSelectInstWithICmp(SelectInst &SI,
if (Value *V = canonicalizeSaturatedAdd(ICI, TrueVal, FalseVal, Builder))
return replaceInstUsesWith(SI, V);
+ if (Value *V = foldAbsDiff(ICI, TrueVal, FalseVal, Builder))
+ return replaceInstUsesWith(SI, V);
+
return Changed ? &SI : nullptr;
}
@@ -2418,7 +2477,7 @@ Instruction *InstCombinerImpl::foldVectorSelect(SelectInst &Sel) {
// in the case of a shuffle with no undefined mask elements.
ArrayRef<int> Mask;
if (match(TVal, m_OneUse(m_Shuffle(m_Value(X), m_Value(Y), m_Mask(Mask)))) &&
- !is_contained(Mask, UndefMaskElem) &&
+ !is_contained(Mask, PoisonMaskElem) &&
cast<ShuffleVectorInst>(TVal)->isSelect()) {
if (X == FVal) {
// select Cond, (shuf_sel X, Y), X --> shuf_sel X, (select Cond, Y, X)
@@ -2432,7 +2491,7 @@ Instruction *InstCombinerImpl::foldVectorSelect(SelectInst &Sel) {
}
}
if (match(FVal, m_OneUse(m_Shuffle(m_Value(X), m_Value(Y), m_Mask(Mask)))) &&
- !is_contained(Mask, UndefMaskElem) &&
+ !is_contained(Mask, PoisonMaskElem) &&
cast<ShuffleVectorInst>(FVal)->isSelect()) {
if (X == TVal) {
// select Cond, X, (shuf_sel X, Y) --> shuf_sel X, (select Cond, X, Y)
@@ -2965,6 +3024,14 @@ Instruction *InstCombinerImpl::foldSelectOfBools(SelectInst &SI) {
if (match(CondVal, m_Select(m_Value(A), m_Value(B), m_Zero())) &&
match(TrueVal, m_Specific(B)) && match(FalseVal, m_Zero()))
return replaceOperand(SI, 0, A);
+ // select a, (select ~a, true, b), false -> select a, b, false
+ if (match(TrueVal, m_c_LogicalOr(m_Not(m_Specific(CondVal)), m_Value(B))) &&
+ match(FalseVal, m_Zero()))
+ return replaceOperand(SI, 1, B);
+ // select a, true, (select ~a, b, false) -> select a, true, b
+ if (match(FalseVal, m_c_LogicalAnd(m_Not(m_Specific(CondVal)), m_Value(B))) &&
+ match(TrueVal, m_One()))
+ return replaceOperand(SI, 2, B);
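An i1 sketch of the first new fold (invented names):

    %nota = xor i1 %a, true
    %or   = select i1 %nota, i1 true, i1 %b
    %sel  = select i1 %a, i1 %or, i1 false
  -->
    %sel  = select i1 %a, i1 %b, i1 false

When %a is true, %nota is false and the inner logical-or already evaluates to %b, so the inner select can be bypassed.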
// ~(A & B) & (A | B) --> A ^ B
if (match(&SI, m_c_LogicalAnd(m_Not(m_LogicalAnd(m_Value(A), m_Value(B))),
@@ -3077,6 +3144,134 @@ Instruction *InstCombinerImpl::foldSelectOfBools(SelectInst &SI) {
return nullptr;
}
+// Return true if we can safely remove the select instruction for std::bit_ceil
+// pattern.
+static bool isSafeToRemoveBitCeilSelect(ICmpInst::Predicate Pred, Value *Cond0,
+ const APInt *Cond1, Value *CtlzOp,
+ unsigned BitWidth) {
+ // The challenge in recognizing std::bit_ceil(X) is that the operand is used
+ // for the CTLZ proper and select condition, each possibly with some
+ // operation like add and sub.
+ //
+ // Our aim is to make sure that -ctlz & (BitWidth - 1) == 0 even when the
+ // select instruction would select 1, which allows us to get rid of the select
+ // instruction.
+ //
+ // To see if we can do so, we do some symbolic execution with ConstantRange.
+ // Specifically, we compute the range of values that Cond0 could take when
+ // Cond == false. Then we successively transform the range until we obtain
+ // the range of values that CtlzOp could take.
+ //
+ // Conceptually, we follow the def-use chain backward from Cond0 while
+ // transforming the range for Cond0 until we meet the common ancestor of Cond0
+ // and CtlzOp. Then we follow the def-use chain forward until we obtain the
+ // range for CtlzOp. That said, we only follow at most one ancestor from
+ // Cond0. Likewise, we only follow at most one ancestor from CtlzOp.
+
+ ConstantRange CR = ConstantRange::makeExactICmpRegion(
+ CmpInst::getInversePredicate(Pred), *Cond1);
+
+ // Match the operation that's used to compute CtlzOp from CommonAncestor. If
+ // CtlzOp == CommonAncestor, return true as no operation is needed. If a
+ // match is found, execute the operation on CR, update CR, and return true.
+ // Otherwise, return false.
+ auto MatchForward = [&](Value *CommonAncestor) {
+ const APInt *C = nullptr;
+ if (CtlzOp == CommonAncestor)
+ return true;
+ if (match(CtlzOp, m_Add(m_Specific(CommonAncestor), m_APInt(C)))) {
+ CR = CR.add(*C);
+ return true;
+ }
+ if (match(CtlzOp, m_Sub(m_APInt(C), m_Specific(CommonAncestor)))) {
+ CR = ConstantRange(*C).sub(CR);
+ return true;
+ }
+ if (match(CtlzOp, m_Not(m_Specific(CommonAncestor)))) {
+ CR = CR.binaryNot();
+ return true;
+ }
+ return false;
+ };
+
+ const APInt *C = nullptr;
+ Value *CommonAncestor;
+ if (MatchForward(Cond0)) {
+ // Cond0 is either CtlzOp or CtlzOp's parent. CR has been updated.
+ } else if (match(Cond0, m_Add(m_Value(CommonAncestor), m_APInt(C)))) {
+ CR = CR.sub(*C);
+ if (!MatchForward(CommonAncestor))
+ return false;
+ // Cond0's parent is either CtlzOp or CtlzOp's parent. CR has been updated.
+ } else {
+ return false;
+ }
+
+ // Return true if all the values in the range are either 0 or negative (if
+ // treated as signed). We do so by evaluating:
+ //
+ // CR - 1 u>= (1 << (BitWidth - 1)) - 1.
+ APInt IntMax = APInt::getSignMask(BitWidth) - 1;
+ CR = CR.sub(APInt(BitWidth, 1));
+ return CR.icmp(ICmpInst::ICMP_UGE, IntMax);
+}
+
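A concrete trace for the canonical i32 pattern shown in the comment below (condition icmp ugt i32 %x, 1, so Cond1 = 1, and CtlzOp = add i32 %x, -1): inverting ugt gives ule, so CR starts as [0, 2) for %x; MatchForward sees the add of -1 and shifts the range to [-1, 1); the final check subtracts 1 to get {-2, -1}, both of which, viewed as unsigned, are >= 0x7fffffff, so the function returns true and the select can be removed.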
+// Transform the std::bit_ceil(X) pattern like:
+//
+// %dec = add i32 %x, -1
+// %ctlz = tail call i32 @llvm.ctlz.i32(i32 %dec, i1 false)
+// %sub = sub i32 32, %ctlz
+// %shl = shl i32 1, %sub
+// %ugt = icmp ugt i32 %x, 1
+// %sel = select i1 %ugt, i32 %shl, i32 1
+//
+// into:
+//
+// %dec = add i32 %x, -1
+// %ctlz = tail call i32 @llvm.ctlz.i32(i32 %dec, i1 false)
+// %neg = sub i32 0, %ctlz
+ // %masked = and i32 %neg, 31
+ // %shl = shl i32 1, %masked
+//
+// Note that the select is optimized away while the shift count is masked with
+// 31. We handle some variations of the input operand like std::bit_ceil(X +
+// 1).
+static Instruction *foldBitCeil(SelectInst &SI, IRBuilderBase &Builder) {
+ Type *SelType = SI.getType();
+ unsigned BitWidth = SelType->getScalarSizeInBits();
+
+ Value *FalseVal = SI.getFalseValue();
+ Value *TrueVal = SI.getTrueValue();
+ ICmpInst::Predicate Pred;
+ const APInt *Cond1;
+ Value *Cond0, *Ctlz, *CtlzOp;
+ if (!match(SI.getCondition(), m_ICmp(Pred, m_Value(Cond0), m_APInt(Cond1))))
+ return nullptr;
+
+ if (match(TrueVal, m_One())) {
+ std::swap(FalseVal, TrueVal);
+ Pred = CmpInst::getInversePredicate(Pred);
+ }
+
+ if (!match(FalseVal, m_One()) ||
+ !match(TrueVal,
+ m_OneUse(m_Shl(m_One(), m_OneUse(m_Sub(m_SpecificInt(BitWidth),
+ m_Value(Ctlz)))))) ||
+ !match(Ctlz, m_Intrinsic<Intrinsic::ctlz>(m_Value(CtlzOp), m_Zero())) ||
+ !isSafeToRemoveBitCeilSelect(Pred, Cond0, Cond1, CtlzOp, BitWidth))
+ return nullptr;
+
+ // Build 1 << (-CTLZ & (BitWidth-1)). The negation likely corresponds to a
+ // single hardware instruction as opposed to BitWidth - CTLZ, where BitWidth
+ // is an integer constant. Masking with BitWidth-1 comes free on some
+ // hardware as part of the shift instruction.
+ Value *Neg = Builder.CreateNeg(Ctlz);
+ Value *Masked =
+ Builder.CreateAnd(Neg, ConstantInt::get(SelType, BitWidth - 1));
+ return BinaryOperator::Create(Instruction::Shl, ConstantInt::get(SelType, 1),
+ Masked);
+}
+
Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
Value *CondVal = SI.getCondition();
Value *TrueVal = SI.getTrueValue();
@@ -3253,6 +3448,8 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
std::swap(NewT, NewF);
Value *NewSI =
Builder.CreateSelect(CondVal, NewT, NewF, SI.getName() + ".idx", &SI);
+ if (Gep->isInBounds())
+ return GetElementPtrInst::CreateInBounds(ElementType, Ptr, {NewSI});
return GetElementPtrInst::Create(ElementType, Ptr, {NewSI});
};
if (auto *TrueGep = dyn_cast<GetElementPtrInst>(TrueVal))
@@ -3364,25 +3561,14 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
}
}
- auto canMergeSelectThroughBinop = [](BinaryOperator *BO) {
- // The select might be preventing a division by 0.
- switch (BO->getOpcode()) {
- default:
- return true;
- case Instruction::SRem:
- case Instruction::URem:
- case Instruction::SDiv:
- case Instruction::UDiv:
- return false;
- }
- };
-
// Try to simplify a binop sandwiched between 2 selects with the same
- // condition.
+ // condition. This is not valid for div/rem because the select might be
+ // preventing a division-by-zero.
+ // TODO: A div/rem restriction is conservative; use something like
+ // isSafeToSpeculativelyExecute().
// select(C, binop(select(C, X, Y), W), Z) -> select(C, binop(X, W), Z)
BinaryOperator *TrueBO;
- if (match(TrueVal, m_OneUse(m_BinOp(TrueBO))) &&
- canMergeSelectThroughBinop(TrueBO)) {
+ if (match(TrueVal, m_OneUse(m_BinOp(TrueBO))) && !TrueBO->isIntDivRem()) {
if (auto *TrueBOSI = dyn_cast<SelectInst>(TrueBO->getOperand(0))) {
if (TrueBOSI->getCondition() == CondVal) {
replaceOperand(*TrueBO, 0, TrueBOSI->getTrueValue());
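An IR sketch of the sandwiched-binop simplification for the true arm (illustrative; the binop must have the outer select as its only user):

    %inner = select i1 %c, i32 %x, i32 %y
    %add   = add i32 %inner, %w
    %outer = select i1 %c, i32 %add, i32 %z
  -->
    %add   = add i32 %x, %w
    %outer = select i1 %c, i32 %add, i32 %z

The value of %add only matters when %c is true, so the inner select can be resolved to its true value; div/rem stays excluded because the outer select may be the only thing preventing a division by zero.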
@@ -3401,8 +3587,7 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
// select(C, Z, binop(select(C, X, Y), W)) -> select(C, Z, binop(Y, W))
BinaryOperator *FalseBO;
- if (match(FalseVal, m_OneUse(m_BinOp(FalseBO))) &&
- canMergeSelectThroughBinop(FalseBO)) {
+ if (match(FalseVal, m_OneUse(m_BinOp(FalseBO))) && !FalseBO->isIntDivRem()) {
if (auto *FalseBOSI = dyn_cast<SelectInst>(FalseBO->getOperand(0))) {
if (FalseBOSI->getCondition() == CondVal) {
replaceOperand(*FalseBO, 0, FalseBOSI->getFalseValue());
@@ -3516,5 +3701,8 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
if (sinkNotIntoOtherHandOfLogicalOp(SI))
return &SI;
+ if (Instruction *I = foldBitCeil(SI, Builder))
+ return I;
+
return nullptr;
}