diff options
Diffstat (limited to 'lib/Transforms/InstCombine/InstCombineSelect.cpp')
-rw-r--r-- | lib/Transforms/InstCombine/InstCombineSelect.cpp | 607 |
1 files changed, 414 insertions, 193 deletions
diff --git a/lib/Transforms/InstCombine/InstCombineSelect.cpp b/lib/Transforms/InstCombine/InstCombineSelect.cpp index 6f26f7f5cd19..4867808478a3 100644 --- a/lib/Transforms/InstCombine/InstCombineSelect.cpp +++ b/lib/Transforms/InstCombine/InstCombineSelect.cpp @@ -47,93 +47,51 @@ using namespace PatternMatch; #define DEBUG_TYPE "instcombine" -static SelectPatternFlavor -getInverseMinMaxSelectPattern(SelectPatternFlavor SPF) { - switch (SPF) { - default: - llvm_unreachable("unhandled!"); - - case SPF_SMIN: - return SPF_SMAX; - case SPF_UMIN: - return SPF_UMAX; - case SPF_SMAX: - return SPF_SMIN; - case SPF_UMAX: - return SPF_UMIN; - } -} - -static CmpInst::Predicate getCmpPredicateForMinMax(SelectPatternFlavor SPF, - bool Ordered=false) { - switch (SPF) { - default: - llvm_unreachable("unhandled!"); - - case SPF_SMIN: - return ICmpInst::ICMP_SLT; - case SPF_UMIN: - return ICmpInst::ICMP_ULT; - case SPF_SMAX: - return ICmpInst::ICMP_SGT; - case SPF_UMAX: - return ICmpInst::ICMP_UGT; - case SPF_FMINNUM: - return Ordered ? FCmpInst::FCMP_OLT : FCmpInst::FCMP_ULT; - case SPF_FMAXNUM: - return Ordered ? FCmpInst::FCMP_OGT : FCmpInst::FCMP_UGT; - } -} - -static Value *generateMinMaxSelectPattern(InstCombiner::BuilderTy &Builder, - SelectPatternFlavor SPF, Value *A, - Value *B) { - CmpInst::Predicate Pred = getCmpPredicateForMinMax(SPF); - assert(CmpInst::isIntPredicate(Pred)); +static Value *createMinMax(InstCombiner::BuilderTy &Builder, + SelectPatternFlavor SPF, Value *A, Value *B) { + CmpInst::Predicate Pred = getMinMaxPred(SPF); + assert(CmpInst::isIntPredicate(Pred) && "Expected integer predicate"); return Builder.CreateSelect(Builder.CreateICmp(Pred, A, B), A, B); } -/// If one of the constants is zero (we know they can't both be) and we have an -/// icmp instruction with zero, and we have an 'and' with the non-constant value -/// and a power of two we can turn the select into a shift on the result of the -/// 'and'. /// This folds: -/// select (icmp eq (and X, C1)), C2, C3 -/// iff C1 is a power 2 and the difference between C2 and C3 is a power of 2. +/// select (icmp eq (and X, C1)), TC, FC +/// iff C1 is a power 2 and the difference between TC and FC is a power-of-2. /// To something like: -/// (shr (and (X, C1)), (log2(C1) - log2(C2-C3))) + C3 +/// (shr (and (X, C1)), (log2(C1) - log2(TC-FC))) + FC /// Or: -/// (shl (and (X, C1)), (log2(C2-C3) - log2(C1))) + C3 -/// With some variations depending if C3 is larger than C2, or the shift +/// (shl (and (X, C1)), (log2(TC-FC) - log2(C1))) + FC +/// With some variations depending if FC is larger than TC, or the shift /// isn't needed, or the bit widths don't match. -static Value *foldSelectICmpAnd(Type *SelType, const ICmpInst *IC, - APInt TrueVal, APInt FalseVal, +static Value *foldSelectICmpAnd(SelectInst &Sel, ICmpInst *Cmp, InstCombiner::BuilderTy &Builder) { - assert(SelType->isIntOrIntVectorTy() && "Not an integer select?"); + const APInt *SelTC, *SelFC; + if (!match(Sel.getTrueValue(), m_APInt(SelTC)) || + !match(Sel.getFalseValue(), m_APInt(SelFC))) + return nullptr; // If this is a vector select, we need a vector compare. - if (SelType->isVectorTy() != IC->getType()->isVectorTy()) + Type *SelType = Sel.getType(); + if (SelType->isVectorTy() != Cmp->getType()->isVectorTy()) return nullptr; Value *V; APInt AndMask; bool CreateAnd = false; - ICmpInst::Predicate Pred = IC->getPredicate(); + ICmpInst::Predicate Pred = Cmp->getPredicate(); if (ICmpInst::isEquality(Pred)) { - if (!match(IC->getOperand(1), m_Zero())) + if (!match(Cmp->getOperand(1), m_Zero())) return nullptr; - V = IC->getOperand(0); - + V = Cmp->getOperand(0); const APInt *AndRHS; if (!match(V, m_And(m_Value(), m_Power2(AndRHS)))) return nullptr; AndMask = *AndRHS; - } else if (decomposeBitTestICmp(IC->getOperand(0), IC->getOperand(1), + } else if (decomposeBitTestICmp(Cmp->getOperand(0), Cmp->getOperand(1), Pred, V, AndMask)) { assert(ICmpInst::isEquality(Pred) && "Not equality test?"); - if (!AndMask.isPowerOf2()) return nullptr; @@ -142,39 +100,58 @@ static Value *foldSelectICmpAnd(Type *SelType, const ICmpInst *IC, return nullptr; } - // If both select arms are non-zero see if we have a select of the form - // 'x ? 2^n + C : C'. Then we can offset both arms by C, use the logic - // for 'x ? 2^n : 0' and fix the thing up at the end. - APInt Offset(TrueVal.getBitWidth(), 0); - if (!TrueVal.isNullValue() && !FalseVal.isNullValue()) { - if ((TrueVal - FalseVal).isPowerOf2()) - Offset = FalseVal; - else if ((FalseVal - TrueVal).isPowerOf2()) - Offset = TrueVal; - else + // In general, when both constants are non-zero, we would need an offset to + // replace the select. This would require more instructions than we started + // with. But there's one special-case that we handle here because it can + // simplify/reduce the instructions. + APInt TC = *SelTC; + APInt FC = *SelFC; + if (!TC.isNullValue() && !FC.isNullValue()) { + // If the select constants differ by exactly one bit and that's the same + // bit that is masked and checked by the select condition, the select can + // be replaced by bitwise logic to set/clear one bit of the constant result. + if (TC.getBitWidth() != AndMask.getBitWidth() || (TC ^ FC) != AndMask) return nullptr; - - // Adjust TrueVal and FalseVal to the offset. - TrueVal -= Offset; - FalseVal -= Offset; + if (CreateAnd) { + // If we have to create an 'and', then we must kill the cmp to not + // increase the instruction count. + if (!Cmp->hasOneUse()) + return nullptr; + V = Builder.CreateAnd(V, ConstantInt::get(SelType, AndMask)); + } + bool ExtraBitInTC = TC.ugt(FC); + if (Pred == ICmpInst::ICMP_EQ) { + // If the masked bit in V is clear, clear or set the bit in the result: + // (V & AndMaskC) == 0 ? TC : FC --> (V & AndMaskC) ^ TC + // (V & AndMaskC) == 0 ? TC : FC --> (V & AndMaskC) | TC + Constant *C = ConstantInt::get(SelType, TC); + return ExtraBitInTC ? Builder.CreateXor(V, C) : Builder.CreateOr(V, C); + } + if (Pred == ICmpInst::ICMP_NE) { + // If the masked bit in V is set, set or clear the bit in the result: + // (V & AndMaskC) != 0 ? TC : FC --> (V & AndMaskC) | FC + // (V & AndMaskC) != 0 ? TC : FC --> (V & AndMaskC) ^ FC + Constant *C = ConstantInt::get(SelType, FC); + return ExtraBitInTC ? Builder.CreateOr(V, C) : Builder.CreateXor(V, C); + } + llvm_unreachable("Only expecting equality predicates"); } - // Make sure one of the select arms is a power of 2. - if (!TrueVal.isPowerOf2() && !FalseVal.isPowerOf2()) + // Make sure one of the select arms is a power-of-2. + if (!TC.isPowerOf2() && !FC.isPowerOf2()) return nullptr; // Determine which shift is needed to transform result of the 'and' into the // desired result. - const APInt &ValC = !TrueVal.isNullValue() ? TrueVal : FalseVal; + const APInt &ValC = !TC.isNullValue() ? TC : FC; unsigned ValZeros = ValC.logBase2(); unsigned AndZeros = AndMask.logBase2(); - if (CreateAnd) { - // Insert the AND instruction on the input to the truncate. + // Insert the 'and' instruction on the input to the truncate. + if (CreateAnd) V = Builder.CreateAnd(V, ConstantInt::get(V->getType(), AndMask)); - } - // If types don't match we can still convert the select by introducing a zext + // If types don't match, we can still convert the select by introducing a zext // or a trunc of the 'and'. if (ValZeros > AndZeros) { V = Builder.CreateZExtOrTrunc(V, SelType); @@ -182,19 +159,17 @@ static Value *foldSelectICmpAnd(Type *SelType, const ICmpInst *IC, } else if (ValZeros < AndZeros) { V = Builder.CreateLShr(V, AndZeros - ValZeros); V = Builder.CreateZExtOrTrunc(V, SelType); - } else + } else { V = Builder.CreateZExtOrTrunc(V, SelType); + } // Okay, now we know that everything is set up, we just don't know whether we // have a icmp_ne or icmp_eq and whether the true or false val is the zero. - bool ShouldNotVal = !TrueVal.isNullValue(); + bool ShouldNotVal = !TC.isNullValue(); ShouldNotVal ^= Pred == ICmpInst::ICMP_NE; if (ShouldNotVal) V = Builder.CreateXor(V, ValC); - // Apply an offset if needed. - if (!Offset.isNullValue()) - V = Builder.CreateAdd(V, ConstantInt::get(V->getType(), Offset)); return V; } @@ -300,12 +275,13 @@ Instruction *InstCombiner::foldSelectOpOp(SelectInst &SI, Instruction *TI, TI->getType()); } - // Only handle binary operators with one-use here. As with the cast case - // above, it may be possible to relax the one-use constraint, but that needs - // be examined carefully since it may not reduce the total number of - // instructions. - BinaryOperator *BO = dyn_cast<BinaryOperator>(TI); - if (!BO || !TI->hasOneUse() || !FI->hasOneUse()) + // Only handle binary operators (including two-operand getelementptr) with + // one-use here. As with the cast case above, it may be possible to relax the + // one-use constraint, but that needs be examined carefully since it may not + // reduce the total number of instructions. + if (TI->getNumOperands() != 2 || FI->getNumOperands() != 2 || + (!isa<BinaryOperator>(TI) && !isa<GetElementPtrInst>(TI)) || + !TI->hasOneUse() || !FI->hasOneUse()) return nullptr; // Figure out if the operations have any operands in common. @@ -342,7 +318,18 @@ Instruction *InstCombiner::foldSelectOpOp(SelectInst &SI, Instruction *TI, SI.getName() + ".v", &SI); Value *Op0 = MatchIsOpZero ? MatchOp : NewSI; Value *Op1 = MatchIsOpZero ? NewSI : MatchOp; - return BinaryOperator::Create(BO->getOpcode(), Op0, Op1); + if (auto *BO = dyn_cast<BinaryOperator>(TI)) { + return BinaryOperator::Create(BO->getOpcode(), Op0, Op1); + } + if (auto *TGEP = dyn_cast<GetElementPtrInst>(TI)) { + auto *FGEP = cast<GetElementPtrInst>(FI); + Type *ElementType = TGEP->getResultElementType(); + return TGEP->isInBounds() && FGEP->isInBounds() + ? GetElementPtrInst::CreateInBounds(ElementType, Op0, {Op1}) + : GetElementPtrInst::Create(ElementType, Op0, {Op1}); + } + llvm_unreachable("Expected BinaryOperator or GEP"); + return nullptr; } static bool isSelect01(const APInt &C1I, const APInt &C2I) { @@ -424,6 +411,47 @@ Instruction *InstCombiner::foldSelectIntoOp(SelectInst &SI, Value *TrueVal, } /// We want to turn: +/// (select (icmp eq (and X, Y), 0), (and (lshr X, Z), 1), 1) +/// into: +/// zext (icmp ne i32 (and X, (or Y, (shl 1, Z))), 0) +/// Note: +/// Z may be 0 if lshr is missing. +/// Worst-case scenario is that we will replace 5 instructions with 5 different +/// instructions, but we got rid of select. +static Instruction *foldSelectICmpAndAnd(Type *SelType, const ICmpInst *Cmp, + Value *TVal, Value *FVal, + InstCombiner::BuilderTy &Builder) { + if (!(Cmp->hasOneUse() && Cmp->getOperand(0)->hasOneUse() && + Cmp->getPredicate() == ICmpInst::ICMP_EQ && + match(Cmp->getOperand(1), m_Zero()) && match(FVal, m_One()))) + return nullptr; + + // The TrueVal has general form of: and %B, 1 + Value *B; + if (!match(TVal, m_OneUse(m_And(m_Value(B), m_One())))) + return nullptr; + + // Where %B may be optionally shifted: lshr %X, %Z. + Value *X, *Z; + const bool HasShift = match(B, m_OneUse(m_LShr(m_Value(X), m_Value(Z)))); + if (!HasShift) + X = B; + + Value *Y; + if (!match(Cmp->getOperand(0), m_c_And(m_Specific(X), m_Value(Y)))) + return nullptr; + + // ((X & Y) == 0) ? ((X >> Z) & 1) : 1 --> (X & (Y | (1 << Z))) != 0 + // ((X & Y) == 0) ? (X & 1) : 1 --> (X & (Y | 1)) != 0 + Constant *One = ConstantInt::get(SelType, 1); + Value *MaskB = HasShift ? Builder.CreateShl(One, Z) : One; + Value *FullMask = Builder.CreateOr(Y, MaskB); + Value *MaskedX = Builder.CreateAnd(X, FullMask); + Value *ICmpNeZero = Builder.CreateIsNotNull(MaskedX); + return new ZExtInst(ICmpNeZero, SelType); +} + +/// We want to turn: /// (select (icmp eq (and X, C1), 0), Y, (or Y, C2)) /// into: /// (or (shl (and X, C1), C3), Y) @@ -526,6 +554,59 @@ static Value *foldSelectICmpAndOr(const ICmpInst *IC, Value *TrueVal, return Builder.CreateOr(V, Y); } +/// Transform patterns such as: (a > b) ? a - b : 0 +/// into: ((a > b) ? a : b) - b) +/// This produces a canonical max pattern that is more easily recognized by the +/// backend and converted into saturated subtraction instructions if those +/// exist. +/// There are 8 commuted/swapped variants of this pattern. +/// TODO: Also support a - UMIN(a,b) patterns. +static Value *canonicalizeSaturatedSubtract(const ICmpInst *ICI, + const Value *TrueVal, + const Value *FalseVal, + InstCombiner::BuilderTy &Builder) { + ICmpInst::Predicate Pred = ICI->getPredicate(); + if (!ICmpInst::isUnsigned(Pred)) + return nullptr; + + // (b > a) ? 0 : a - b -> (b <= a) ? a - b : 0 + if (match(TrueVal, m_Zero())) { + Pred = ICmpInst::getInversePredicate(Pred); + std::swap(TrueVal, FalseVal); + } + if (!match(FalseVal, m_Zero())) + return nullptr; + + Value *A = ICI->getOperand(0); + Value *B = ICI->getOperand(1); + if (Pred == ICmpInst::ICMP_ULE || Pred == ICmpInst::ICMP_ULT) { + // (b < a) ? a - b : 0 -> (a > b) ? a - b : 0 + std::swap(A, B); + Pred = ICmpInst::getSwappedPredicate(Pred); + } + + assert((Pred == ICmpInst::ICMP_UGE || Pred == ICmpInst::ICMP_UGT) && + "Unexpected isUnsigned predicate!"); + + // Account for swapped form of subtraction: ((a > b) ? b - a : 0). + bool IsNegative = false; + if (match(TrueVal, m_Sub(m_Specific(B), m_Specific(A)))) + IsNegative = true; + else if (!match(TrueVal, m_Sub(m_Specific(A), m_Specific(B)))) + return nullptr; + + // If sub is used anywhere else, we wouldn't be able to eliminate it + // afterwards. + if (!TrueVal->hasOneUse()) + return nullptr; + + // All checks passed, convert to canonical unsigned saturated subtraction + // form: sub(max()). + // (a > b) ? a - b : 0 -> ((a > b) ? a : b) - b) + Value *Max = Builder.CreateSelect(Builder.CreateICmp(Pred, A, B), A, B); + return IsNegative ? Builder.CreateSub(B, Max) : Builder.CreateSub(Max, B); +} + /// Attempt to fold a cttz/ctlz followed by a icmp plus select into a single /// call to cttz/ctlz with flag 'is_zero_undef' cleared. /// @@ -687,23 +768,18 @@ canonicalizeMinMaxWithConstant(SelectInst &Sel, ICmpInst &Cmp, // Canonicalize the compare predicate based on whether we have min or max. Value *LHS, *RHS; - ICmpInst::Predicate NewPred; SelectPatternResult SPR = matchSelectPattern(&Sel, LHS, RHS); - switch (SPR.Flavor) { - case SPF_SMIN: NewPred = ICmpInst::ICMP_SLT; break; - case SPF_UMIN: NewPred = ICmpInst::ICMP_ULT; break; - case SPF_SMAX: NewPred = ICmpInst::ICMP_SGT; break; - case SPF_UMAX: NewPred = ICmpInst::ICMP_UGT; break; - default: return nullptr; - } + if (!SelectPatternResult::isMinOrMax(SPR.Flavor)) + return nullptr; // Is this already canonical? + ICmpInst::Predicate CanonicalPred = getMinMaxPred(SPR.Flavor); if (Cmp.getOperand(0) == LHS && Cmp.getOperand(1) == RHS && - Cmp.getPredicate() == NewPred) + Cmp.getPredicate() == CanonicalPred) return nullptr; // Create the canonical compare and plug it into the select. - Sel.setCondition(Builder.CreateICmp(NewPred, LHS, RHS)); + Sel.setCondition(Builder.CreateICmp(CanonicalPred, LHS, RHS)); // If the select operands did not change, we're done. if (Sel.getTrueValue() == LHS && Sel.getFalseValue() == RHS) @@ -718,6 +794,89 @@ canonicalizeMinMaxWithConstant(SelectInst &Sel, ICmpInst &Cmp, return &Sel; } +/// There are many select variants for each of ABS/NABS. +/// In matchSelectPattern(), there are different compare constants, compare +/// predicates/operands and select operands. +/// In isKnownNegation(), there are different formats of negated operands. +/// Canonicalize all these variants to 1 pattern. +/// This makes CSE more likely. +static Instruction *canonicalizeAbsNabs(SelectInst &Sel, ICmpInst &Cmp, + InstCombiner::BuilderTy &Builder) { + if (!Cmp.hasOneUse() || !isa<Constant>(Cmp.getOperand(1))) + return nullptr; + + // Choose a sign-bit check for the compare (likely simpler for codegen). + // ABS: (X <s 0) ? -X : X + // NABS: (X <s 0) ? X : -X + Value *LHS, *RHS; + SelectPatternFlavor SPF = matchSelectPattern(&Sel, LHS, RHS).Flavor; + if (SPF != SelectPatternFlavor::SPF_ABS && + SPF != SelectPatternFlavor::SPF_NABS) + return nullptr; + + Value *TVal = Sel.getTrueValue(); + Value *FVal = Sel.getFalseValue(); + assert(isKnownNegation(TVal, FVal) && + "Unexpected result from matchSelectPattern"); + + // The compare may use the negated abs()/nabs() operand, or it may use + // negation in non-canonical form such as: sub A, B. + bool CmpUsesNegatedOp = match(Cmp.getOperand(0), m_Neg(m_Specific(TVal))) || + match(Cmp.getOperand(0), m_Neg(m_Specific(FVal))); + + bool CmpCanonicalized = !CmpUsesNegatedOp && + match(Cmp.getOperand(1), m_ZeroInt()) && + Cmp.getPredicate() == ICmpInst::ICMP_SLT; + bool RHSCanonicalized = match(RHS, m_Neg(m_Specific(LHS))); + + // Is this already canonical? + if (CmpCanonicalized && RHSCanonicalized) + return nullptr; + + // If RHS is used by other instructions except compare and select, don't + // canonicalize it to not increase the instruction count. + if (!(RHS->hasOneUse() || (RHS->hasNUses(2) && CmpUsesNegatedOp))) + return nullptr; + + // Create the canonical compare: icmp slt LHS 0. + if (!CmpCanonicalized) { + Cmp.setPredicate(ICmpInst::ICMP_SLT); + Cmp.setOperand(1, ConstantInt::getNullValue(Cmp.getOperand(0)->getType())); + if (CmpUsesNegatedOp) + Cmp.setOperand(0, LHS); + } + + // Create the canonical RHS: RHS = sub (0, LHS). + if (!RHSCanonicalized) { + assert(RHS->hasOneUse() && "RHS use number is not right"); + RHS = Builder.CreateNeg(LHS); + if (TVal == LHS) { + Sel.setFalseValue(RHS); + FVal = RHS; + } else { + Sel.setTrueValue(RHS); + TVal = RHS; + } + } + + // If the select operands do not change, we're done. + if (SPF == SelectPatternFlavor::SPF_NABS) { + if (TVal == LHS) + return &Sel; + assert(FVal == LHS && "Unexpected results from matchSelectPattern"); + } else { + if (FVal == LHS) + return &Sel; + assert(TVal == LHS && "Unexpected results from matchSelectPattern"); + } + + // We are swapping the select operands, so swap the metadata too. + Sel.setTrueValue(FVal); + Sel.setFalseValue(TVal); + Sel.swapProfMetadata(); + return &Sel; +} + /// Visit a SelectInst that has an ICmpInst as its first operand. Instruction *InstCombiner::foldSelectInstWithICmp(SelectInst &SI, ICmpInst *ICI) { @@ -727,59 +886,18 @@ Instruction *InstCombiner::foldSelectInstWithICmp(SelectInst &SI, if (Instruction *NewSel = canonicalizeMinMaxWithConstant(SI, *ICI, Builder)) return NewSel; + if (Instruction *NewAbs = canonicalizeAbsNabs(SI, *ICI, Builder)) + return NewAbs; + bool Changed = adjustMinMax(SI, *ICI); + if (Value *V = foldSelectICmpAnd(SI, ICI, Builder)) + return replaceInstUsesWith(SI, V); + + // NOTE: if we wanted to, this is where to detect integer MIN/MAX ICmpInst::Predicate Pred = ICI->getPredicate(); Value *CmpLHS = ICI->getOperand(0); Value *CmpRHS = ICI->getOperand(1); - - // Transform (X >s -1) ? C1 : C2 --> ((X >>s 31) & (C2 - C1)) + C1 - // and (X <s 0) ? C2 : C1 --> ((X >>s 31) & (C2 - C1)) + C1 - // FIXME: Type and constness constraints could be lifted, but we have to - // watch code size carefully. We should consider xor instead of - // sub/add when we decide to do that. - // TODO: Merge this with foldSelectICmpAnd somehow. - if (CmpLHS->getType()->isIntOrIntVectorTy() && - CmpLHS->getType() == TrueVal->getType()) { - const APInt *C1, *C2; - if (match(TrueVal, m_APInt(C1)) && match(FalseVal, m_APInt(C2))) { - ICmpInst::Predicate Pred = ICI->getPredicate(); - Value *X; - APInt Mask; - if (decomposeBitTestICmp(CmpLHS, CmpRHS, Pred, X, Mask, false)) { - if (Mask.isSignMask()) { - assert(X == CmpLHS && "Expected to use the compare input directly"); - assert(ICmpInst::isEquality(Pred) && "Expected equality predicate"); - - if (Pred == ICmpInst::ICMP_NE) - std::swap(C1, C2); - - // This shift results in either -1 or 0. - Value *AShr = Builder.CreateAShr(X, Mask.getBitWidth() - 1); - - // Check if we can express the operation with a single or. - if (C2->isAllOnesValue()) - return replaceInstUsesWith(SI, Builder.CreateOr(AShr, *C1)); - - Value *And = Builder.CreateAnd(AShr, *C2 - *C1); - return replaceInstUsesWith(SI, Builder.CreateAdd(And, - ConstantInt::get(And->getType(), *C1))); - } - } - } - } - - { - const APInt *TrueValC, *FalseValC; - if (match(TrueVal, m_APInt(TrueValC)) && - match(FalseVal, m_APInt(FalseValC))) - if (Value *V = foldSelectICmpAnd(SI.getType(), ICI, *TrueValC, - *FalseValC, Builder)) - return replaceInstUsesWith(SI, V); - } - - // NOTE: if we wanted to, this is where to detect integer MIN/MAX - if (CmpRHS != CmpLHS && isa<Constant>(CmpRHS)) { if (CmpLHS == TrueVal && Pred == ICmpInst::ICMP_EQ) { // Transform (X == C) ? X : Y -> (X == C) ? C : Y @@ -842,16 +960,22 @@ Instruction *InstCombiner::foldSelectInstWithICmp(SelectInst &SI, } } + if (Instruction *V = + foldSelectICmpAndAnd(SI.getType(), ICI, TrueVal, FalseVal, Builder)) + return V; + if (Value *V = foldSelectICmpAndOr(ICI, TrueVal, FalseVal, Builder)) return replaceInstUsesWith(SI, V); if (Value *V = foldSelectCttzCtlz(ICI, TrueVal, FalseVal, Builder)) return replaceInstUsesWith(SI, V); + if (Value *V = canonicalizeSaturatedSubtract(ICI, TrueVal, FalseVal, Builder)) + return replaceInstUsesWith(SI, V); + return Changed ? &SI : nullptr; } - /// SI is a select whose condition is a PHI node (but the two may be in /// different blocks). See if the true/false values (V) are live in all of the /// predecessor blocks of the PHI. For example, cases like this can't be mapped: @@ -900,7 +1024,7 @@ Instruction *InstCombiner::foldSPFofSPF(Instruction *Inner, if (C == A || C == B) { // MAX(MAX(A, B), B) -> MAX(A, B) // MIN(MIN(a, b), a) -> MIN(a, b) - if (SPF1 == SPF2) + if (SPF1 == SPF2 && SelectPatternResult::isMinOrMax(SPF1)) return replaceInstUsesWith(Outer, Inner); // MAX(MIN(a, b), a) -> a @@ -992,10 +1116,10 @@ Instruction *InstCombiner::foldSPFofSPF(Instruction *Inner, if (!NotC) NotC = Builder.CreateNot(C); - Value *NewInner = generateMinMaxSelectPattern( - Builder, getInverseMinMaxSelectPattern(SPF1), NotA, NotB); - Value *NewOuter = Builder.CreateNot(generateMinMaxSelectPattern( - Builder, getInverseMinMaxSelectPattern(SPF2), NewInner, NotC)); + Value *NewInner = createMinMax(Builder, getInverseMinMaxFlavor(SPF1), NotA, + NotB); + Value *NewOuter = Builder.CreateNot( + createMinMax(Builder, getInverseMinMaxFlavor(SPF2), NewInner, NotC)); return replaceInstUsesWith(Outer, NewOuter); } @@ -1075,6 +1199,11 @@ static Instruction *foldAddSubSelect(SelectInst &SI, } Instruction *InstCombiner::foldSelectExtConst(SelectInst &Sel) { + Constant *C; + if (!match(Sel.getTrueValue(), m_Constant(C)) && + !match(Sel.getFalseValue(), m_Constant(C))) + return nullptr; + Instruction *ExtInst; if (!match(Sel.getTrueValue(), m_Instruction(ExtInst)) && !match(Sel.getFalseValue(), m_Instruction(ExtInst))) @@ -1084,20 +1213,18 @@ Instruction *InstCombiner::foldSelectExtConst(SelectInst &Sel) { if (ExtOpcode != Instruction::ZExt && ExtOpcode != Instruction::SExt) return nullptr; - // TODO: Handle larger types? That requires adjusting FoldOpIntoSelect too. + // If we are extending from a boolean type or if we can create a select that + // has the same size operands as its condition, try to narrow the select. Value *X = ExtInst->getOperand(0); Type *SmallType = X->getType(); - if (!SmallType->isIntOrIntVectorTy(1)) - return nullptr; - - Constant *C; - if (!match(Sel.getTrueValue(), m_Constant(C)) && - !match(Sel.getFalseValue(), m_Constant(C))) + Value *Cond = Sel.getCondition(); + auto *Cmp = dyn_cast<CmpInst>(Cond); + if (!SmallType->isIntOrIntVectorTy(1) && + (!Cmp || Cmp->getOperand(0)->getType() != SmallType)) return nullptr; // If the constant is the same after truncation to the smaller type and // extension to the original type, we can narrow the select. - Value *Cond = Sel.getCondition(); Type *SelType = Sel.getType(); Constant *TruncC = ConstantExpr::getTrunc(C, SmallType); Constant *ExtC = ConstantExpr::getCast(ExtOpcode, TruncC, SelType); @@ -1289,6 +1416,63 @@ static Instruction *foldSelectCmpXchg(SelectInst &SI) { return nullptr; } +/// Reduce a sequence of min/max with a common operand. +static Instruction *factorizeMinMaxTree(SelectPatternFlavor SPF, Value *LHS, + Value *RHS, + InstCombiner::BuilderTy &Builder) { + assert(SelectPatternResult::isMinOrMax(SPF) && "Expected a min/max"); + // TODO: Allow FP min/max with nnan/nsz. + if (!LHS->getType()->isIntOrIntVectorTy()) + return nullptr; + + // Match 3 of the same min/max ops. Example: umin(umin(), umin()). + Value *A, *B, *C, *D; + SelectPatternResult L = matchSelectPattern(LHS, A, B); + SelectPatternResult R = matchSelectPattern(RHS, C, D); + if (SPF != L.Flavor || L.Flavor != R.Flavor) + return nullptr; + + // Look for a common operand. The use checks are different than usual because + // a min/max pattern typically has 2 uses of each op: 1 by the cmp and 1 by + // the select. + Value *MinMaxOp = nullptr; + Value *ThirdOp = nullptr; + if (!LHS->hasNUsesOrMore(3) && RHS->hasNUsesOrMore(3)) { + // If the LHS is only used in this chain and the RHS is used outside of it, + // reuse the RHS min/max because that will eliminate the LHS. + if (D == A || C == A) { + // min(min(a, b), min(c, a)) --> min(min(c, a), b) + // min(min(a, b), min(a, d)) --> min(min(a, d), b) + MinMaxOp = RHS; + ThirdOp = B; + } else if (D == B || C == B) { + // min(min(a, b), min(c, b)) --> min(min(c, b), a) + // min(min(a, b), min(b, d)) --> min(min(b, d), a) + MinMaxOp = RHS; + ThirdOp = A; + } + } else if (!RHS->hasNUsesOrMore(3)) { + // Reuse the LHS. This will eliminate the RHS. + if (D == A || D == B) { + // min(min(a, b), min(c, a)) --> min(min(a, b), c) + // min(min(a, b), min(c, b)) --> min(min(a, b), c) + MinMaxOp = LHS; + ThirdOp = C; + } else if (C == A || C == B) { + // min(min(a, b), min(b, d)) --> min(min(a, b), d) + // min(min(a, b), min(c, b)) --> min(min(a, b), d) + MinMaxOp = LHS; + ThirdOp = D; + } + } + if (!MinMaxOp || !ThirdOp) + return nullptr; + + CmpInst::Predicate P = getMinMaxPred(SPF); + Value *CmpABC = Builder.CreateICmp(P, MinMaxOp, ThirdOp); + return SelectInst::Create(CmpABC, MinMaxOp, ThirdOp); +} + Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { Value *CondVal = SI.getCondition(); Value *TrueVal = SI.getTrueValue(); @@ -1489,7 +1673,37 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { // NOTE: if we wanted to, this is where to detect MIN/MAX } - // NOTE: if we wanted to, this is where to detect ABS + + // Canonicalize select with fcmp to fabs(). -0.0 makes this tricky. We need + // fast-math-flags (nsz) or fsub with +0.0 (not fneg) for this to work. We + // also require nnan because we do not want to unintentionally change the + // sign of a NaN value. + Value *X = FCI->getOperand(0); + FCmpInst::Predicate Pred = FCI->getPredicate(); + if (match(FCI->getOperand(1), m_AnyZeroFP()) && FCI->hasNoNaNs()) { + // (X <= +/-0.0) ? (0.0 - X) : X --> fabs(X) + // (X > +/-0.0) ? X : (0.0 - X) --> fabs(X) + if ((X == FalseVal && Pred == FCmpInst::FCMP_OLE && + match(TrueVal, m_FSub(m_PosZeroFP(), m_Specific(X)))) || + (X == TrueVal && Pred == FCmpInst::FCMP_OGT && + match(FalseVal, m_FSub(m_PosZeroFP(), m_Specific(X))))) { + Value *Fabs = Builder.CreateIntrinsic(Intrinsic::fabs, { X }, FCI); + return replaceInstUsesWith(SI, Fabs); + } + // With nsz: + // (X < +/-0.0) ? -X : X --> fabs(X) + // (X <= +/-0.0) ? -X : X --> fabs(X) + // (X > +/-0.0) ? X : -X --> fabs(X) + // (X >= +/-0.0) ? X : -X --> fabs(X) + if (FCI->hasNoSignedZeros() && + ((X == FalseVal && match(TrueVal, m_FNeg(m_Specific(X))) && + (Pred == FCmpInst::FCMP_OLT || Pred == FCmpInst::FCMP_OLE)) || + (X == TrueVal && match(FalseVal, m_FNeg(m_Specific(X))) && + (Pred == FCmpInst::FCMP_OGT || Pred == FCmpInst::FCMP_OGE)))) { + Value *Fabs = Builder.CreateIntrinsic(Intrinsic::fabs, { X }, FCI); + return replaceInstUsesWith(SI, Fabs); + } + } } // See if we are selecting two values based on a comparison of the two values. @@ -1532,7 +1746,7 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { (LHS->getType()->isFPOrFPVectorTy() && ((CmpLHS != LHS && CmpLHS != RHS) || (CmpRHS != LHS && CmpRHS != RHS)))) { - CmpInst::Predicate Pred = getCmpPredicateForMinMax(SPF, SPR.Ordered); + CmpInst::Predicate Pred = getMinMaxPred(SPF, SPR.Ordered); Value *Cmp; if (CmpInst::isIntPredicate(Pred)) { @@ -1551,6 +1765,20 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { Value *NewCast = Builder.CreateCast(CastOp, NewSI, SelType); return replaceInstUsesWith(SI, NewCast); } + + // MAX(~a, ~b) -> ~MIN(a, b) + // MIN(~a, ~b) -> ~MAX(a, b) + Value *A, *B; + if (match(LHS, m_Not(m_Value(A))) && match(RHS, m_Not(m_Value(B))) && + (LHS->getNumUses() <= 2 || RHS->getNumUses() <= 2)) { + CmpInst::Predicate InvertedPred = getInverseMinMaxPred(SPF); + Value *InvertedCmp = Builder.CreateICmp(InvertedPred, A, B); + Value *NewSel = Builder.CreateSelect(InvertedCmp, A, B); + return BinaryOperator::CreateNot(NewSel); + } + + if (Instruction *I = factorizeMinMaxTree(SPF, LHS, RHS, Builder)) + return I; } if (SPF) { @@ -1570,28 +1798,6 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { return R; } - // MAX(~a, ~b) -> ~MIN(a, b) - if ((SPF == SPF_SMAX || SPF == SPF_UMAX) && - IsFreeToInvert(LHS, LHS->hasNUses(2)) && - IsFreeToInvert(RHS, RHS->hasNUses(2))) { - // For this transform to be profitable, we need to eliminate at least two - // 'not' instructions if we're going to add one 'not' instruction. - int NumberOfNots = - (LHS->hasNUses(2) && match(LHS, m_Not(m_Value()))) + - (RHS->hasNUses(2) && match(RHS, m_Not(m_Value()))) + - (SI.hasOneUse() && match(*SI.user_begin(), m_Not(m_Value()))); - - if (NumberOfNots >= 2) { - Value *NewLHS = Builder.CreateNot(LHS); - Value *NewRHS = Builder.CreateNot(RHS); - Value *NewCmp = SPF == SPF_SMAX ? Builder.CreateICmpSLT(NewLHS, NewRHS) - : Builder.CreateICmpULT(NewLHS, NewRHS); - Value *NewSI = - Builder.CreateNot(Builder.CreateSelect(NewCmp, NewLHS, NewRHS)); - return replaceInstUsesWith(SI, NewSI); - } - } - // TODO. // ABS(-X) -> ABS(X) } @@ -1643,11 +1849,25 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { } } + auto canMergeSelectThroughBinop = [](BinaryOperator *BO) { + // The select might be preventing a division by 0. + switch (BO->getOpcode()) { + default: + return true; + case Instruction::SRem: + case Instruction::URem: + case Instruction::SDiv: + case Instruction::UDiv: + return false; + } + }; + // Try to simplify a binop sandwiched between 2 selects with the same // condition. // select(C, binop(select(C, X, Y), W), Z) -> select(C, binop(X, W), Z) BinaryOperator *TrueBO; - if (match(TrueVal, m_OneUse(m_BinOp(TrueBO)))) { + if (match(TrueVal, m_OneUse(m_BinOp(TrueBO))) && + canMergeSelectThroughBinop(TrueBO)) { if (auto *TrueBOSI = dyn_cast<SelectInst>(TrueBO->getOperand(0))) { if (TrueBOSI->getCondition() == CondVal) { TrueBO->setOperand(0, TrueBOSI->getTrueValue()); @@ -1666,7 +1886,8 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { // select(C, Z, binop(select(C, X, Y), W)) -> select(C, Z, binop(Y, W)) BinaryOperator *FalseBO; - if (match(FalseVal, m_OneUse(m_BinOp(FalseBO)))) { + if (match(FalseVal, m_OneUse(m_BinOp(FalseBO))) && + canMergeSelectThroughBinop(FalseBO)) { if (auto *FalseBOSI = dyn_cast<SelectInst>(FalseBO->getOperand(0))) { if (FalseBOSI->getCondition() == CondVal) { FalseBO->setOperand(0, FalseBOSI->getFalseValue()); |