Diffstat (limited to 'llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp')
-rw-r--r-- | llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp | 448
1 files changed, 318 insertions, 130 deletions
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index e7d8208f94fd..661c50062223 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -98,7 +98,8 @@ static Instruction *foldSelectBinOpIdentity(SelectInst &Sel,
   // +0.0 compares equal to -0.0, and so it does not behave as required for this
   // transform. Bail out if we can not exclude that possibility.
   if (isa<FPMathOperator>(BO))
-    if (!BO->hasNoSignedZeros() && !CannotBeNegativeZero(Y, &TLI))
+    if (!BO->hasNoSignedZeros() &&
+        !cannotBeNegativeZero(Y, IC.getDataLayout(), &TLI))
       return nullptr;

   // BO = binop Y, X
@@ -386,6 +387,32 @@ Instruction *InstCombinerImpl::foldSelectOpOp(SelectInst &SI, Instruction *TI,
         return CallInst::Create(TII->getCalledFunction(), {NewSel, MatchOp});
       }
     }
+
+    // select c, (ldexp v, e0), (ldexp v, e1) -> ldexp v, (select c, e0, e1)
+    // select c, (ldexp v0, e), (ldexp v1, e) -> ldexp (select c, v0, v1), e
+    //
+    // select c, (ldexp v0, e0), (ldexp v1, e1) ->
+    //   ldexp (select c, v0, v1), (select c, e0, e1)
+    if (TII->getIntrinsicID() == Intrinsic::ldexp) {
+      Value *LdexpVal0 = TII->getArgOperand(0);
+      Value *LdexpExp0 = TII->getArgOperand(1);
+      Value *LdexpVal1 = FII->getArgOperand(0);
+      Value *LdexpExp1 = FII->getArgOperand(1);
+      if (LdexpExp0->getType() == LdexpExp1->getType()) {
+        FPMathOperator *SelectFPOp = cast<FPMathOperator>(&SI);
+        FastMathFlags FMF = cast<FPMathOperator>(TII)->getFastMathFlags();
+        FMF &= cast<FPMathOperator>(FII)->getFastMathFlags();
+        FMF |= SelectFPOp->getFastMathFlags();
+
+        Value *SelectVal = Builder.CreateSelect(Cond, LdexpVal0, LdexpVal1);
+        Value *SelectExp = Builder.CreateSelect(Cond, LdexpExp0, LdexpExp1);
+
+        CallInst *NewLdexp = Builder.CreateIntrinsic(
+            TII->getType(), Intrinsic::ldexp, {SelectVal, SelectExp});
+        NewLdexp->setFastMathFlags(FMF);
+        return replaceInstUsesWith(SI, NewLdexp);
+      }
+    }
   }

   // icmp with a common operand also can have the common operand
@@ -429,6 +456,21 @@ Instruction *InstCombinerImpl::foldSelectOpOp(SelectInst &SI, Instruction *TI,
        !OtherOpF->getType()->isVectorTy()))
     return nullptr;

+  // If we are sinking div/rem after a select, we may need to freeze the
+  // condition because div/rem may induce immediate UB with a poison operand.
+  // For example, the following transform is not safe if Cond can ever be poison
+  // because we can replace poison with zero and then we have div-by-zero that
+  // didn't exist in the original code:
+  // Cond ? x/y : x/z --> x / (Cond ? y : z)
+  auto *BO = dyn_cast<BinaryOperator>(TI);
+  if (BO && BO->isIntDivRem() && !isGuaranteedNotToBePoison(Cond)) {
+    // A udiv/urem with a common divisor is safe because UB can only occur with
+    // div-by-zero, and that would be present in the original code.
+    if (BO->getOpcode() == Instruction::SDiv ||
+        BO->getOpcode() == Instruction::SRem || MatchIsOpZero)
+      Cond = Builder.CreateFreeze(Cond);
+  }
+
   // If we reach here, they do have operations in common.
   Value *NewSI =
       Builder.CreateSelect(Cond, OtherOpT, OtherOpF, SI.getName() + ".v", &SI);
@@ -461,7 +503,7 @@ static bool isSelect01(const APInt &C1I, const APInt &C2I) {
 /// optimization.
 Instruction *InstCombinerImpl::foldSelectIntoOp(SelectInst &SI, Value *TrueVal,
                                                 Value *FalseVal) {
-  // See the comment above GetSelectFoldableOperands for a description of the
+  // See the comment above getSelectFoldableOperands for a description of the
   // transformation we are doing here.
   auto TryFoldSelectIntoOp = [&](SelectInst &SI, Value *TrueVal,
                                  Value *FalseVal,
@@ -496,7 +538,7 @@ Instruction *InstCombinerImpl::foldSelectIntoOp(SelectInst &SI, Value *TrueVal,
       if (!isa<Constant>(OOp) ||
          (OOpIsAPInt && isSelect01(C->getUniqueInteger(), *OOpC))) {
         Value *NewSel = Builder.CreateSelect(SI.getCondition(), Swapped ? C : OOp,
-                                             Swapped ? OOp : C);
+                                             Swapped ? OOp : C, "", &SI);
         if (isa<FPMathOperator>(&SI))
           cast<Instruction>(NewSel)->setFastMathFlags(FMF);
         NewSel->takeName(TVI);
@@ -569,6 +611,44 @@ static Instruction *foldSelectICmpAndAnd(Type *SelType, const ICmpInst *Cmp,
 }

 /// We want to turn:
+/// (select (icmp eq (and X, C1), 0), 0, (shl [nsw/nuw] X, C2));
+/// iff C1 is a mask and the number of its leading zeros is equal to C2
+/// into:
+/// shl X, C2
+static Value *foldSelectICmpAndZeroShl(const ICmpInst *Cmp, Value *TVal,
+                                       Value *FVal,
+                                       InstCombiner::BuilderTy &Builder) {
+  ICmpInst::Predicate Pred;
+  Value *AndVal;
+  if (!match(Cmp, m_ICmp(Pred, m_Value(AndVal), m_Zero())))
+    return nullptr;
+
+  if (Pred == ICmpInst::ICMP_NE) {
+    Pred = ICmpInst::ICMP_EQ;
+    std::swap(TVal, FVal);
+  }
+
+  Value *X;
+  const APInt *C2, *C1;
+  if (Pred != ICmpInst::ICMP_EQ ||
+      !match(AndVal, m_And(m_Value(X), m_APInt(C1))) ||
+      !match(TVal, m_Zero()) || !match(FVal, m_Shl(m_Specific(X), m_APInt(C2))))
+    return nullptr;
+
+  if (!C1->isMask() ||
+      C1->countLeadingZeros() != static_cast<unsigned>(C2->getZExtValue()))
+    return nullptr;
+
+  auto *FI = dyn_cast<Instruction>(FVal);
+  if (!FI)
+    return nullptr;
+
+  FI->setHasNoSignedWrap(false);
+  FI->setHasNoUnsignedWrap(false);
+  return FVal;
+}
+
+/// We want to turn:
 /// (select (icmp sgt x, C), lshr (X, Y), ashr (X, Y)); iff C s>= -1
 /// (select (icmp slt x, C), ashr (X, Y), lshr (X, Y)); iff C s>= 0
 /// into:
@@ -935,10 +1015,53 @@ static Value *canonicalizeSaturatedAdd(ICmpInst *Cmp, Value *TVal, Value *FVal,
   return nullptr;
 }

+/// Try to match patterns with select and subtract as absolute difference.
+static Value *foldAbsDiff(ICmpInst *Cmp, Value *TVal, Value *FVal,
+                          InstCombiner::BuilderTy &Builder) {
+  auto *TI = dyn_cast<Instruction>(TVal);
+  auto *FI = dyn_cast<Instruction>(FVal);
+  if (!TI || !FI)
+    return nullptr;
+
+  // Normalize predicate to gt/lt rather than ge/le.
+  ICmpInst::Predicate Pred = Cmp->getStrictPredicate();
+  Value *A = Cmp->getOperand(0);
+  Value *B = Cmp->getOperand(1);
+
+  // Normalize "A - B" as the true value of the select.
+  if (match(FI, m_Sub(m_Specific(A), m_Specific(B)))) {
+    std::swap(FI, TI);
+    Pred = ICmpInst::getSwappedPredicate(Pred);
+  }
+
+  // With any pair of no-wrap subtracts:
+  // (A > B) ? (A - B) : (B - A) --> abs(A - B)
+  if (Pred == CmpInst::ICMP_SGT &&
+      match(TI, m_Sub(m_Specific(A), m_Specific(B))) &&
+      match(FI, m_Sub(m_Specific(B), m_Specific(A))) &&
+      (TI->hasNoSignedWrap() || TI->hasNoUnsignedWrap()) &&
+      (FI->hasNoSignedWrap() || FI->hasNoUnsignedWrap())) {
+    // The remaining subtract is not "nuw" any more.
+    // If there's one use of the subtract (no other use than the use we are
+    // about to replace), then we know that the sub is "nsw" in this context
+    // even if it was only "nuw" before. If there's another use, then we can't
+    // add "nsw" to the existing instruction because it may not be safe in the
+    // other user's context.
+    TI->setHasNoUnsignedWrap(false);
+    if (!TI->hasNoSignedWrap())
+      TI->setHasNoSignedWrap(TI->hasOneUse());
+    return Builder.CreateBinaryIntrinsic(Intrinsic::abs, TI, Builder.getTrue());
+  }
+
+  return nullptr;
+}
+
 /// Fold the following code sequence:
 /// \code
 /// int a = ctlz(x & -x);
 // x ? 31 - a : a;
+// // or
+// x ? 31 - a : 32;
 /// \code
 ///
 /// into:
@@ -953,15 +1076,19 @@ static Instruction *foldSelectCtlzToCttz(ICmpInst *ICI, Value *TrueVal,
   if (ICI->getPredicate() == ICmpInst::ICMP_NE)
     std::swap(TrueVal, FalseVal);

+  Value *Ctlz;
   if (!match(FalseVal,
-             m_Xor(m_Deferred(TrueVal), m_SpecificInt(BitWidth - 1))))
+             m_Xor(m_Value(Ctlz), m_SpecificInt(BitWidth - 1))))
     return nullptr;

-  if (!match(TrueVal, m_Intrinsic<Intrinsic::ctlz>()))
+  if (!match(Ctlz, m_Intrinsic<Intrinsic::ctlz>()))
+    return nullptr;
+
+  if (TrueVal != Ctlz && !match(TrueVal, m_SpecificInt(BitWidth)))
     return nullptr;

   Value *X = ICI->getOperand(0);
-  auto *II = cast<IntrinsicInst>(TrueVal);
+  auto *II = cast<IntrinsicInst>(Ctlz);
   if (!match(II->getOperand(0), m_c_And(m_Specific(X), m_Neg(m_Specific(X)))))
     return nullptr;

@@ -1038,99 +1165,6 @@ static Value *foldSelectCttzCtlz(ICmpInst *ICI, Value *TrueVal, Value *FalseVal,
   return nullptr;
 }

-/// Return true if we find and adjust an icmp+select pattern where the compare
-/// is with a constant that can be incremented or decremented to match the
-/// minimum or maximum idiom.
-static bool adjustMinMax(SelectInst &Sel, ICmpInst &Cmp) {
-  ICmpInst::Predicate Pred = Cmp.getPredicate();
-  Value *CmpLHS = Cmp.getOperand(0);
-  Value *CmpRHS = Cmp.getOperand(1);
-  Value *TrueVal = Sel.getTrueValue();
-  Value *FalseVal = Sel.getFalseValue();
-
-  // We may move or edit the compare, so make sure the select is the only user.
-  const APInt *CmpC;
-  if (!Cmp.hasOneUse() || !match(CmpRHS, m_APInt(CmpC)))
-    return false;
-
-  // These transforms only work for selects of integers or vector selects of
-  // integer vectors.
-  Type *SelTy = Sel.getType();
-  auto *SelEltTy = dyn_cast<IntegerType>(SelTy->getScalarType());
-  if (!SelEltTy || SelTy->isVectorTy() != Cmp.getType()->isVectorTy())
-    return false;
-
-  Constant *AdjustedRHS;
-  if (Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_SGT)
-    AdjustedRHS = ConstantInt::get(CmpRHS->getType(), *CmpC + 1);
-  else if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_SLT)
-    AdjustedRHS = ConstantInt::get(CmpRHS->getType(), *CmpC - 1);
-  else
-    return false;
-
-  // X > C ? X : C+1 --> X < C+1 ? C+1 : X
-  // X < C ? X : C-1 --> X > C-1 ? C-1 : X
-  if ((CmpLHS == TrueVal && AdjustedRHS == FalseVal) ||
-      (CmpLHS == FalseVal && AdjustedRHS == TrueVal)) {
-    ; // Nothing to do here. Values match without any sign/zero extension.
-  }
-  // Types do not match. Instead of calculating this with mixed types, promote
-  // all to the larger type. This enables scalar evolution to analyze this
-  // expression.
-  else if (CmpRHS->getType()->getScalarSizeInBits() < SelEltTy->getBitWidth()) {
-    Constant *SextRHS = ConstantExpr::getSExt(AdjustedRHS, SelTy);
-
-    // X = sext x; x >s c ? X : C+1 --> X = sext x; X <s C+1 ? C+1 : X
-    // X = sext x; x <s c ? X : C-1 --> X = sext x; X >s C-1 ? C-1 : X
-    // X = sext x; x >u c ? X : C+1 --> X = sext x; X <u C+1 ? C+1 : X
-    // X = sext x; x <u c ? X : C-1 --> X = sext x; X >u C-1 ? C-1 : X
-    if (match(TrueVal, m_SExt(m_Specific(CmpLHS))) && SextRHS == FalseVal) {
-      CmpLHS = TrueVal;
-      AdjustedRHS = SextRHS;
-    } else if (match(FalseVal, m_SExt(m_Specific(CmpLHS))) &&
-               SextRHS == TrueVal) {
-      CmpLHS = FalseVal;
-      AdjustedRHS = SextRHS;
-    } else if (Cmp.isUnsigned()) {
-      Constant *ZextRHS = ConstantExpr::getZExt(AdjustedRHS, SelTy);
-      // X = zext x; x >u c ? X : C+1 --> X = zext x; X <u C+1 ? C+1 : X
-      // X = zext x; x <u c ? X : C-1 --> X = zext x; X >u C-1 ? C-1 : X
-      // zext + signed compare cannot be changed:
-      //    0xff <s 0x00, but 0x00ff >s 0x0000
-      if (match(TrueVal, m_ZExt(m_Specific(CmpLHS))) && ZextRHS == FalseVal) {
-        CmpLHS = TrueVal;
-        AdjustedRHS = ZextRHS;
-      } else if (match(FalseVal, m_ZExt(m_Specific(CmpLHS))) &&
-                 ZextRHS == TrueVal) {
-        CmpLHS = FalseVal;
-        AdjustedRHS = ZextRHS;
-      } else {
-        return false;
-      }
-    } else {
-      return false;
-    }
-  } else {
-    return false;
-  }
-
-  Pred = ICmpInst::getSwappedPredicate(Pred);
-  CmpRHS = AdjustedRHS;
-  std::swap(FalseVal, TrueVal);
-  Cmp.setPredicate(Pred);
-  Cmp.setOperand(0, CmpLHS);
-  Cmp.setOperand(1, CmpRHS);
-  Sel.setOperand(1, TrueVal);
-  Sel.setOperand(2, FalseVal);
-  Sel.swapProfMetadata();
-
-  // Move the compare instruction right before the select instruction. Otherwise
-  // the sext/zext value may be defined after the compare instruction uses it.
-  Cmp.moveBefore(&Sel);
-
-  return true;
-}
-
 static Instruction *canonicalizeSPF(SelectInst &Sel, ICmpInst &Cmp,
                                     InstCombinerImpl &IC) {
   Value *LHS, *RHS;
@@ -1182,8 +1216,8 @@ static Instruction *canonicalizeSPF(SelectInst &Sel, ICmpInst &Cmp,
   return nullptr;
 }

-static bool replaceInInstruction(Value *V, Value *Old, Value *New,
-                                 InstCombiner &IC, unsigned Depth = 0) {
+bool InstCombinerImpl::replaceInInstruction(Value *V, Value *Old, Value *New,
+                                            unsigned Depth) {
   // Conservatively limit replacement to two instructions upwards.
   if (Depth == 2)
     return false;
@@ -1195,10 +1229,11 @@ static bool replaceInInstruction(Value *V, Value *Old, Value *New,
   bool Changed = false;
   for (Use &U : I->operands()) {
     if (U == Old) {
-      IC.replaceUse(U, New);
+      replaceUse(U, New);
+      Worklist.add(I);
       Changed = true;
     } else {
-      Changed |= replaceInInstruction(U, Old, New, IC, Depth + 1);
+      Changed |= replaceInInstruction(U, Old, New, Depth + 1);
     }
   }
   return Changed;
@@ -1254,7 +1289,7 @@ Instruction *InstCombinerImpl::foldSelectValueEquivalence(SelectInst &Sel,
     // FIXME: Support vectors.
     if (match(CmpRHS, m_ImmConstant()) && !match(CmpLHS, m_ImmConstant()) &&
         !Cmp.getType()->isVectorTy())
-      if (replaceInInstruction(TrueVal, CmpLHS, CmpRHS, *this))
+      if (replaceInInstruction(TrueVal, CmpLHS, CmpRHS))
        return &Sel;
   }
   if (TrueVal != CmpRHS &&
@@ -1593,13 +1628,32 @@ static Instruction *foldSelectZeroOrOnes(ICmpInst *Cmp, Value *TVal,
   return nullptr;
 }

-static Value *foldSelectInstWithICmpConst(SelectInst &SI, ICmpInst *ICI) {
+static Value *foldSelectInstWithICmpConst(SelectInst &SI, ICmpInst *ICI,
+                                          InstCombiner::BuilderTy &Builder) {
   const APInt *CmpC;
   Value *V;
   CmpInst::Predicate Pred;
   if (!match(ICI, m_ICmp(Pred, m_Value(V), m_APInt(CmpC))))
     return nullptr;

+  // Match clamp away from min/max value as a max/min operation.
+  Value *TVal = SI.getTrueValue();
+  Value *FVal = SI.getFalseValue();
+  if (Pred == ICmpInst::ICMP_EQ && V == FVal) {
+    // (V == UMIN) ? UMIN+1 : V --> umax(V, UMIN+1)
+    if (CmpC->isMinValue() && match(TVal, m_SpecificInt(*CmpC + 1)))
+      return Builder.CreateBinaryIntrinsic(Intrinsic::umax, V, TVal);
+    // (V == UMAX) ? UMAX-1 : V --> umin(V, UMAX-1)
+    if (CmpC->isMaxValue() && match(TVal, m_SpecificInt(*CmpC - 1)))
+      return Builder.CreateBinaryIntrinsic(Intrinsic::umin, V, TVal);
+    // (V == SMIN) ? SMIN+1 : V --> smax(V, SMIN+1)
+    if (CmpC->isMinSignedValue() && match(TVal, m_SpecificInt(*CmpC + 1)))
+      return Builder.CreateBinaryIntrinsic(Intrinsic::smax, V, TVal);
+    // (V == SMAX) ? SMAX-1 : V --> smin(V, SMAX-1)
+    if (CmpC->isMaxSignedValue() && match(TVal, m_SpecificInt(*CmpC - 1)))
+      return Builder.CreateBinaryIntrinsic(Intrinsic::smin, V, TVal);
+  }
+
   BinaryOperator *BO;
   const APInt *C;
   CmpInst::Predicate CPred;
@@ -1632,7 +1686,7 @@ Instruction *InstCombinerImpl::foldSelectInstWithICmp(SelectInst &SI,
   if (Instruction *NewSPF = canonicalizeSPF(SI, *ICI, *this))
     return NewSPF;

-  if (Value *V = foldSelectInstWithICmpConst(SI, ICI))
+  if (Value *V = foldSelectInstWithICmpConst(SI, ICI, Builder))
     return replaceInstUsesWith(SI, V);

   if (Value *V = canonicalizeClampLike(SI, *ICI, Builder))
@@ -1642,18 +1696,17 @@ Instruction *InstCombinerImpl::foldSelectInstWithICmp(SelectInst &SI,
           tryToReuseConstantFromSelectInComparison(SI, *ICI, *this))
     return NewSel;

-  bool Changed = adjustMinMax(SI, *ICI);
-
   if (Value *V = foldSelectICmpAnd(SI, ICI, Builder))
     return replaceInstUsesWith(SI, V);

   // NOTE: if we wanted to, this is where to detect integer MIN/MAX
+  bool Changed = false;
   Value *TrueVal = SI.getTrueValue();
   Value *FalseVal = SI.getFalseValue();
   ICmpInst::Predicate Pred = ICI->getPredicate();
   Value *CmpLHS = ICI->getOperand(0);
   Value *CmpRHS = ICI->getOperand(1);
-  if (CmpRHS != CmpLHS && isa<Constant>(CmpRHS)) {
+  if (CmpRHS != CmpLHS && isa<Constant>(CmpRHS) && !isa<Constant>(CmpLHS)) {
     if (CmpLHS == TrueVal && Pred == ICmpInst::ICMP_EQ) {
       // Transform (X == C) ? X : Y -> (X == C) ? C : Y
       SI.setOperand(1, CmpRHS);
@@ -1683,7 +1736,7 @@ Instruction *InstCombinerImpl::foldSelectInstWithICmp(SelectInst &SI,

   // FIXME: This code is nearly duplicated in InstSimplify. Using/refactoring
   // decomposeBitTestICmp() might help.
-  {
+  if (TrueVal->getType()->isIntOrIntVectorTy()) {
     unsigned BitWidth =
         DL.getTypeSizeInBits(TrueVal->getType()->getScalarType());
     APInt MinSignedValue = APInt::getSignedMinValue(BitWidth);
@@ -1735,6 +1788,9 @@ Instruction *InstCombinerImpl::foldSelectInstWithICmp(SelectInst &SI,
           foldSelectICmpAndAnd(SI.getType(), ICI, TrueVal, FalseVal, Builder))
     return V;

+  if (Value *V = foldSelectICmpAndZeroShl(ICI, TrueVal, FalseVal, Builder))
+    return replaceInstUsesWith(SI, V);
+
   if (Instruction *V = foldSelectCtlzToCttz(ICI, TrueVal, FalseVal, Builder))
     return V;

@@ -1756,6 +1812,9 @@ Instruction *InstCombinerImpl::foldSelectInstWithICmp(SelectInst &SI,
   if (Value *V = canonicalizeSaturatedAdd(ICI, TrueVal, FalseVal, Builder))
     return replaceInstUsesWith(SI, V);

+  if (Value *V = foldAbsDiff(ICI, TrueVal, FalseVal, Builder))
+    return replaceInstUsesWith(SI, V);
+
   return Changed ? &SI : nullptr;
 }

@@ -2418,7 +2477,7 @@ Instruction *InstCombinerImpl::foldVectorSelect(SelectInst &Sel) {
   // in the case of a shuffle with no undefined mask elements.
   ArrayRef<int> Mask;
   if (match(TVal, m_OneUse(m_Shuffle(m_Value(X), m_Value(Y), m_Mask(Mask)))) &&
-      !is_contained(Mask, UndefMaskElem) &&
+      !is_contained(Mask, PoisonMaskElem) &&
       cast<ShuffleVectorInst>(TVal)->isSelect()) {
     if (X == FVal) {
       // select Cond, (shuf_sel X, Y), X --> shuf_sel X, (select Cond, Y, X)
@@ -2432,7 +2491,7 @@ Instruction *InstCombinerImpl::foldVectorSelect(SelectInst &Sel) {
     }
   }
   if (match(FVal, m_OneUse(m_Shuffle(m_Value(X), m_Value(Y), m_Mask(Mask)))) &&
-      !is_contained(Mask, UndefMaskElem) &&
+      !is_contained(Mask, PoisonMaskElem) &&
       cast<ShuffleVectorInst>(FVal)->isSelect()) {
     if (X == TVal) {
       // select Cond, X, (shuf_sel X, Y) --> shuf_sel X, (select Cond, X, Y)
@@ -2965,6 +3024,14 @@ Instruction *InstCombinerImpl::foldSelectOfBools(SelectInst &SI) {
   if (match(CondVal, m_Select(m_Value(A), m_Value(B), m_Zero())) &&
       match(TrueVal, m_Specific(B)) && match(FalseVal, m_Zero()))
     return replaceOperand(SI, 0, A);
+  // select a, (select ~a, true, b), false -> select a, b, false
+  if (match(TrueVal, m_c_LogicalOr(m_Not(m_Specific(CondVal)), m_Value(B))) &&
+      match(FalseVal, m_Zero()))
+    return replaceOperand(SI, 1, B);
+  // select a, true, (select ~a, b, false) -> select a, true, b
+  if (match(FalseVal, m_c_LogicalAnd(m_Not(m_Specific(CondVal)), m_Value(B))) &&
+      match(TrueVal, m_One()))
+    return replaceOperand(SI, 2, B);

   // ~(A & B) & (A | B) --> A ^ B
   if (match(&SI, m_c_LogicalAnd(m_Not(m_LogicalAnd(m_Value(A), m_Value(B))),
@@ -3077,6 +3144,134 @@ Instruction *InstCombinerImpl::foldSelectOfBools(SelectInst &SI) {
   return nullptr;
 }

+// Return true if we can safely remove the select instruction for std::bit_ceil
+// pattern.
+static bool isSafeToRemoveBitCeilSelect(ICmpInst::Predicate Pred, Value *Cond0,
+                                        const APInt *Cond1, Value *CtlzOp,
+                                        unsigned BitWidth) {
+  // The challenge in recognizing std::bit_ceil(X) is that the operand is used
+  // for the CTLZ proper and select condition, each possibly with some
+  // operation like add and sub.
+  //
+  // Our aim is to make sure that -ctlz & (BitWidth - 1) == 0 even when the
+  // select instruction would select 1, which allows us to get rid of the select
+  // instruction.
+  //
+  // To see if we can do so, we do some symbolic execution with ConstantRange.
+  // Specifically, we compute the range of values that Cond0 could take when
+  // Cond == false. Then we successively transform the range until we obtain
+  // the range of values that CtlzOp could take.
+  //
+  // Conceptually, we follow the def-use chain backward from Cond0 while
+  // transforming the range for Cond0 until we meet the common ancestor of Cond0
+  // and CtlzOp. Then we follow the def-use chain forward until we obtain the
+  // range for CtlzOp. That said, we only follow at most one ancestor from
+  // Cond0. Likewise, we only follow at most one ancestor from CtrlOp.
+
+  ConstantRange CR = ConstantRange::makeExactICmpRegion(
+      CmpInst::getInversePredicate(Pred), *Cond1);
+
+  // Match the operation that's used to compute CtlzOp from CommonAncestor. If
+  // CtlzOp == CommonAncestor, return true as no operation is needed. If a
+  // match is found, execute the operation on CR, update CR, and return true.
+  // Otherwise, return false.
+  auto MatchForward = [&](Value *CommonAncestor) {
+    const APInt *C = nullptr;
+    if (CtlzOp == CommonAncestor)
+      return true;
+    if (match(CtlzOp, m_Add(m_Specific(CommonAncestor), m_APInt(C)))) {
+      CR = CR.add(*C);
+      return true;
+    }
+    if (match(CtlzOp, m_Sub(m_APInt(C), m_Specific(CommonAncestor)))) {
+      CR = ConstantRange(*C).sub(CR);
+      return true;
+    }
+    if (match(CtlzOp, m_Not(m_Specific(CommonAncestor)))) {
+      CR = CR.binaryNot();
+      return true;
+    }
+    return false;
+  };
+
+  const APInt *C = nullptr;
+  Value *CommonAncestor;
+  if (MatchForward(Cond0)) {
+    // Cond0 is either CtlzOp or CtlzOp's parent. CR has been updated.
+  } else if (match(Cond0, m_Add(m_Value(CommonAncestor), m_APInt(C)))) {
+    CR = CR.sub(*C);
+    if (!MatchForward(CommonAncestor))
+      return false;
+    // Cond0's parent is either CtlzOp or CtlzOp's parent. CR has been updated.
+  } else {
+    return false;
+  }
+
+  // Return true if all the values in the range are either 0 or negative (if
+  // treated as signed). We do so by evaluating:
+  //
+  //   CR - 1 u>= (1 << BitWidth) - 1.
+  APInt IntMax = APInt::getSignMask(BitWidth) - 1;
+  CR = CR.sub(APInt(BitWidth, 1));
+  return CR.icmp(ICmpInst::ICMP_UGE, IntMax);
+}
+
+// Transform the std::bit_ceil(X) pattern like:
+//
+//   %dec = add i32 %x, -1
+//   %ctlz = tail call i32 @llvm.ctlz.i32(i32 %dec, i1 false)
+//   %sub = sub i32 32, %ctlz
+//   %shl = shl i32 1, %sub
+//   %ugt = icmp ugt i32 %x, 1
+//   %sel = select i1 %ugt, i32 %shl, i32 1
+//
+// into:
+//
+//   %dec = add i32 %x, -1
+//   %ctlz = tail call i32 @llvm.ctlz.i32(i32 %dec, i1 false)
+//   %neg = sub i32 0, %ctlz
+//   %masked = and i32 %ctlz, 31
+//   %shl = shl i32 1, %sub
+//
+// Note that the select is optimized away while the shift count is masked with
+// 31. We handle some variations of the input operand like std::bit_ceil(X +
+// 1).
+static Instruction *foldBitCeil(SelectInst &SI, IRBuilderBase &Builder) {
+  Type *SelType = SI.getType();
+  unsigned BitWidth = SelType->getScalarSizeInBits();
+
+  Value *FalseVal = SI.getFalseValue();
+  Value *TrueVal = SI.getTrueValue();
+  ICmpInst::Predicate Pred;
+  const APInt *Cond1;
+  Value *Cond0, *Ctlz, *CtlzOp;
+  if (!match(SI.getCondition(), m_ICmp(Pred, m_Value(Cond0), m_APInt(Cond1))))
+    return nullptr;
+
+  if (match(TrueVal, m_One())) {
+    std::swap(FalseVal, TrueVal);
+    Pred = CmpInst::getInversePredicate(Pred);
+  }
+
+  if (!match(FalseVal, m_One()) ||
+      !match(TrueVal,
+             m_OneUse(m_Shl(m_One(), m_OneUse(m_Sub(m_SpecificInt(BitWidth),
+                                                    m_Value(Ctlz)))))) ||
+      !match(Ctlz, m_Intrinsic<Intrinsic::ctlz>(m_Value(CtlzOp), m_Zero())) ||
+      !isSafeToRemoveBitCeilSelect(Pred, Cond0, Cond1, CtlzOp, BitWidth))
+    return nullptr;
+
+  // Build 1 << (-CTLZ & (BitWidth-1)). The negation likely corresponds to a
+  // single hardware instruction as opposed to BitWidth - CTLZ, where BitWidth
+  // is an integer constant. Masking with BitWidth-1 comes free on some
+  // hardware as part of the shift instruction.
+  Value *Neg = Builder.CreateNeg(Ctlz);
+  Value *Masked =
+      Builder.CreateAnd(Neg, ConstantInt::get(SelType, BitWidth - 1));
+  return BinaryOperator::Create(Instruction::Shl, ConstantInt::get(SelType, 1),
+                                Masked);
+}
+
 Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
   Value *CondVal = SI.getCondition();
   Value *TrueVal = SI.getTrueValue();
@@ -3253,6 +3448,8 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
       std::swap(NewT, NewF);
     Value *NewSI =
         Builder.CreateSelect(CondVal, NewT, NewF, SI.getName() + ".idx", &SI);
+    if (Gep->isInBounds())
+      return GetElementPtrInst::CreateInBounds(ElementType, Ptr, {NewSI});
     return GetElementPtrInst::Create(ElementType, Ptr, {NewSI});
   };
   if (auto *TrueGep = dyn_cast<GetElementPtrInst>(TrueVal))
@@ -3364,25 +3561,14 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
     }
   }

-  auto canMergeSelectThroughBinop = [](BinaryOperator *BO) {
-    // The select might be preventing a division by 0.
-    switch (BO->getOpcode()) {
-    default:
-      return true;
-    case Instruction::SRem:
-    case Instruction::URem:
-    case Instruction::SDiv:
-    case Instruction::UDiv:
-      return false;
-    }
-  };
-
   // Try to simplify a binop sandwiched between 2 selects with the same
-  // condition.
+  // condition. This is not valid for div/rem because the select might be
+  // preventing a division-by-zero.
+  // TODO: A div/rem restriction is conservative; use something like
+  //       isSafeToSpeculativelyExecute().
   // select(C, binop(select(C, X, Y), W), Z) -> select(C, binop(X, W), Z)
   BinaryOperator *TrueBO;
-  if (match(TrueVal, m_OneUse(m_BinOp(TrueBO))) &&
-      canMergeSelectThroughBinop(TrueBO)) {
+  if (match(TrueVal, m_OneUse(m_BinOp(TrueBO))) && !TrueBO->isIntDivRem()) {
     if (auto *TrueBOSI = dyn_cast<SelectInst>(TrueBO->getOperand(0))) {
       if (TrueBOSI->getCondition() == CondVal) {
         replaceOperand(*TrueBO, 0, TrueBOSI->getTrueValue());
@@ -3401,8 +3587,7 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {

   // select(C, Z, binop(select(C, X, Y), W)) -> select(C, Z, binop(Y, W))
   BinaryOperator *FalseBO;
-  if (match(FalseVal, m_OneUse(m_BinOp(FalseBO))) &&
-      canMergeSelectThroughBinop(FalseBO)) {
+  if (match(FalseVal, m_OneUse(m_BinOp(FalseBO))) && !FalseBO->isIntDivRem()) {
     if (auto *FalseBOSI = dyn_cast<SelectInst>(FalseBO->getOperand(0))) {
       if (FalseBOSI->getCondition() == CondVal) {
         replaceOperand(*FalseBO, 0, FalseBOSI->getFalseValue());
@@ -3516,5 +3701,8 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
   if (sinkNotIntoOtherHandOfLogicalOp(SI))
     return &SI;

+  if (Instruction *I = foldBitCeil(SI, Builder))
+    return I;
+
   return nullptr;
 }
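
The new ldexp combine in foldSelectOpOp rests on a simple value-level identity: selecting between two ldexp calls is the same as calling ldexp once on selected operands. A minimal standalone C++ sketch of that identity (illustration only, not code from this patch; it models the scalar math, not the fast-math-flag bookkeeping):

#include <cassert>
#include <cmath>

// select c, (ldexp v0, e0), (ldexp v1, e1)
//   == ldexp (select c, v0, v1), (select c, e0, e1)
static double selectOfLdexp(bool C, double V0, int E0, double V1, int E1) {
  return C ? std::ldexp(V0, E0) : std::ldexp(V1, E1);
}

static double ldexpOfSelects(bool C, double V0, int E0, double V1, int E1) {
  return std::ldexp(C ? V0 : V1, C ? E0 : E1);
}

int main() {
  for (bool C : {false, true})
    assert(selectOfLdexp(C, 1.5, 3, -0.75, -2) ==
           ldexpOfSelects(C, 1.5, 3, -0.75, -2));
  return 0;
}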
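For foldSelectICmpAndZeroShl, the select is redundant because whenever the masked low bits of X are all zero, the left shift already produces zero: C1 being a low-bit mask with exactly C2 leading zeros means every bit that survives the mask test is shifted out by C2. A small C++ sketch of that argument for a 32-bit value (illustrative values, not from the patch):

#include <cassert>
#include <cstdint>

// (X & Mask) == 0 ? 0 : (X << ShAmt)  ==  X << ShAmt
// when Mask is a low-bit mask whose leading-zero count equals ShAmt.
static uint32_t withSelect(uint32_t X, uint32_t Mask, unsigned ShAmt) {
  return (X & Mask) == 0 ? 0u : X << ShAmt;
}

static uint32_t withoutSelect(uint32_t X, unsigned ShAmt) {
  return X << ShAmt;
}

int main() {
  const unsigned ShAmt = 4;          // C2
  const uint32_t Mask = 0x0FFFFFFF;  // C1: isMask(), 4 leading zeros
  const uint32_t Tests[] = {0, 1, 0x10000000, 0x12345678, 0xFFFFFFFF};
  for (uint32_t X : Tests)
    assert(withSelect(X, Mask, ShAmt) == withoutSelect(X, ShAmt));
  return 0;
}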
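foldAbsDiff is the classic absolute-difference idiom: when neither subtract can wrap, (A > B) ? (A - B) : (B - A) is abs(A - B). A C++ sketch under the no-overflow assumption that the transform enforces through the nsw/nuw flags (illustration only):

#include <cassert>
#include <cstdlib>

// (A > B) ? (A - B) : (B - A)  ==  abs(A - B), assuming A - B cannot
// overflow (the IR transform requires no-wrap flags on both subtracts).
static int selectForm(int A, int B) { return A > B ? A - B : B - A; }
static int absForm(int A, int B) { return std::abs(A - B); }

int main() {
  const int Pairs[][2] = {{7, 3}, {3, 7}, {-5, 9}, {9, -5}, {4, 4}};
  for (const auto &P : Pairs)
    assert(selectForm(P[0], P[1]) == absForm(P[0], P[1]));
  return 0;
}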
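The new matching in foldSelectInstWithICmpConst treats a clamp away from an extreme value as a min/max: (V == UMIN) ? UMIN+1 : V is umax(V, UMIN+1), and symmetrically for UMAX and the signed extremes. A C++ sketch of the two unsigned cases (illustration only):

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <limits>

// (V == UMIN) ? UMIN+1 : V  ==  umax(V, UMIN+1)
// (V == UMAX) ? UMAX-1 : V  ==  umin(V, UMAX-1)
int main() {
  const uint32_t UMax = std::numeric_limits<uint32_t>::max();
  const uint32_t Tests[] = {0, 1, 2, UMax - 1, UMax};
  for (uint32_t V : Tests) {
    assert((V == 0 ? 1 : V) == std::max<uint32_t>(V, 1));
    assert((V == UMax ? UMax - 1 : V) == std::min<uint32_t>(V, UMax - 1));
  }
  return 0;
}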
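foldBitCeil removes the select that guards the degenerate bit_ceil cases by masking the negated ctlz with BitWidth-1, which collapses the shift amount to zero exactly when the select would have produced 1. A C++20 sketch (std::countl_zero) checking that the two 32-bit forms agree on a few in-range inputs (illustration only, not from the patch; it models the value equivalence, not the IR rewrite):

#include <bit>
#include <cassert>
#include <cstdint>

// Select form, roughly what a frontend emits for 32-bit std::bit_ceil(x):
//   x > 1 ? 1 << (32 - ctlz(x - 1)) : 1
static uint32_t bitCeilSelect(uint32_t X) {
  unsigned Ctlz = std::countl_zero(uint32_t(X - 1));
  return X > 1 ? 1u << (32 - Ctlz) : 1u;
}

// Branchless form after the fold: the shift amount is (-ctlz) & 31, which
// is 0 exactly when the select above would have picked 1.
static uint32_t bitCeilMasked(uint32_t X) {
  unsigned Ctlz = std::countl_zero(uint32_t(X - 1));
  return 1u << ((0u - Ctlz) & 31u);
}

int main() {
  const uint32_t Tests[] = {0, 1, 2, 3, 4, 5, 255, 256, 257, 0x40000000};
  for (uint32_t X : Tests)
    assert(bitCeilSelect(X) == bitCeilMasked(X));
  return 0;
}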