diff options
| author | Dimitry Andric <dim@FreeBSD.org> | 2021-02-16 20:13:02 +0000 |
|---|---|---|
| committer | Dimitry Andric <dim@FreeBSD.org> | 2021-02-16 20:13:02 +0000 |
| commit | b60736ec1405bb0a8dd40989f67ef4c93da068ab (patch) | |
| tree | 5c43fbb7c9fc45f0f87e0e6795a86267dbd12f9d /llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp | |
| parent | cfca06d7963fa0909f90483b42a6d7d194d01e08 (diff) | |
Diffstat (limited to 'llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp')
| -rw-r--r-- | llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp | 489 |
1 files changed, 271 insertions, 218 deletions
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp index 17124f717af7..f26c194d31b9 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp @@ -38,6 +38,7 @@ #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/KnownBits.h" #include "llvm/Transforms/InstCombine/InstCombineWorklist.h" +#include "llvm/Transforms/InstCombine/InstCombiner.h" #include <cassert> #include <utility> @@ -46,6 +47,11 @@ using namespace PatternMatch; #define DEBUG_TYPE "instcombine" +/// FIXME: Enabled by default until the pattern is supported well. +static cl::opt<bool> EnableUnsafeSelectTransform( + "instcombine-unsafe-select-transform", cl::init(true), + cl::desc("Enable poison-unsafe select to and/or transform")); + static Value *createMinMax(InstCombiner::BuilderTy &Builder, SelectPatternFlavor SPF, Value *A, Value *B) { CmpInst::Predicate Pred = getMinMaxPred(SPF); @@ -57,7 +63,7 @@ static Value *createMinMax(InstCombiner::BuilderTy &Builder, /// constant of a binop. static Instruction *foldSelectBinOpIdentity(SelectInst &Sel, const TargetLibraryInfo &TLI, - InstCombiner &IC) { + InstCombinerImpl &IC) { // The select condition must be an equality compare with a constant operand. Value *X; Constant *C; @@ -258,29 +264,9 @@ static unsigned getSelectFoldableOperands(BinaryOperator *I) { } } -/// For the same transformation as the previous function, return the identity -/// constant that goes into the select. -static APInt getSelectFoldableConstant(BinaryOperator *I) { - switch (I->getOpcode()) { - default: llvm_unreachable("This cannot happen!"); - case Instruction::Add: - case Instruction::Sub: - case Instruction::Or: - case Instruction::Xor: - case Instruction::Shl: - case Instruction::LShr: - case Instruction::AShr: - return APInt::getNullValue(I->getType()->getScalarSizeInBits()); - case Instruction::And: - return APInt::getAllOnesValue(I->getType()->getScalarSizeInBits()); - case Instruction::Mul: - return APInt(I->getType()->getScalarSizeInBits(), 1); - } -} - /// We have (select c, TI, FI), and we know that TI and FI have the same opcode. -Instruction *InstCombiner::foldSelectOpOp(SelectInst &SI, Instruction *TI, - Instruction *FI) { +Instruction *InstCombinerImpl::foldSelectOpOp(SelectInst &SI, Instruction *TI, + Instruction *FI) { // Don't break up min/max patterns. The hasOneUse checks below prevent that // for most cases, but vector min/max with bitcasts can be transformed. If the // one-use restrictions are eased for other patterns, we still don't want to @@ -302,10 +288,9 @@ Instruction *InstCombiner::foldSelectOpOp(SelectInst &SI, Instruction *TI, // The select condition may be a vector. We may only change the operand // type if the vector width remains the same (and matches the condition). if (auto *CondVTy = dyn_cast<VectorType>(CondTy)) { - if (!FIOpndTy->isVectorTy()) - return nullptr; - if (CondVTy->getNumElements() != - cast<VectorType>(FIOpndTy)->getNumElements()) + if (!FIOpndTy->isVectorTy() || + CondVTy->getElementCount() != + cast<VectorType>(FIOpndTy)->getElementCount()) return nullptr; // TODO: If the backend knew how to deal with casts better, we could @@ -418,8 +403,8 @@ static bool isSelect01(const APInt &C1I, const APInt &C2I) { /// Try to fold the select into one of the operands to allow further /// optimization. -Instruction *InstCombiner::foldSelectIntoOp(SelectInst &SI, Value *TrueVal, - Value *FalseVal) { +Instruction *InstCombinerImpl::foldSelectIntoOp(SelectInst &SI, Value *TrueVal, + Value *FalseVal) { // See the comment above GetSelectFoldableOperands for a description of the // transformation we are doing here. if (auto *TVI = dyn_cast<BinaryOperator>(TrueVal)) { @@ -433,14 +418,15 @@ Instruction *InstCombiner::foldSelectIntoOp(SelectInst &SI, Value *TrueVal, } if (OpToFold) { - APInt CI = getSelectFoldableConstant(TVI); + Constant *C = ConstantExpr::getBinOpIdentity(TVI->getOpcode(), + TVI->getType(), true); Value *OOp = TVI->getOperand(2-OpToFold); // Avoid creating select between 2 constants unless it's selecting // between 0, 1 and -1. const APInt *OOpC; bool OOpIsAPInt = match(OOp, m_APInt(OOpC)); - if (!isa<Constant>(OOp) || (OOpIsAPInt && isSelect01(CI, *OOpC))) { - Value *C = ConstantInt::get(OOp->getType(), CI); + if (!isa<Constant>(OOp) || + (OOpIsAPInt && isSelect01(C->getUniqueInteger(), *OOpC))) { Value *NewSel = Builder.CreateSelect(SI.getCondition(), OOp, C); NewSel->takeName(TVI); BinaryOperator *BO = BinaryOperator::Create(TVI->getOpcode(), @@ -464,14 +450,15 @@ Instruction *InstCombiner::foldSelectIntoOp(SelectInst &SI, Value *TrueVal, } if (OpToFold) { - APInt CI = getSelectFoldableConstant(FVI); + Constant *C = ConstantExpr::getBinOpIdentity(FVI->getOpcode(), + FVI->getType(), true); Value *OOp = FVI->getOperand(2-OpToFold); // Avoid creating select between 2 constants unless it's selecting // between 0, 1 and -1. const APInt *OOpC; bool OOpIsAPInt = match(OOp, m_APInt(OOpC)); - if (!isa<Constant>(OOp) || (OOpIsAPInt && isSelect01(CI, *OOpC))) { - Value *C = ConstantInt::get(OOp->getType(), CI); + if (!isa<Constant>(OOp) || + (OOpIsAPInt && isSelect01(C->getUniqueInteger(), *OOpC))) { Value *NewSel = Builder.CreateSelect(SI.getCondition(), C, OOp); NewSel->takeName(FVI); BinaryOperator *BO = BinaryOperator::Create(FVI->getOpcode(), @@ -782,25 +769,24 @@ static Value *canonicalizeSaturatedAdd(ICmpInst *Cmp, Value *TVal, Value *FVal, // Match unsigned saturated add of 2 variables with an unnecessary 'not'. // There are 8 commuted variants. - // Canonicalize -1 (saturated result) to true value of the select. Just - // swapping the compare operands is legal, because the selected value is the - // same in case of equality, so we can interchange u< and u<=. + // Canonicalize -1 (saturated result) to true value of the select. if (match(FVal, m_AllOnes())) { std::swap(TVal, FVal); - std::swap(Cmp0, Cmp1); + Pred = CmpInst::getInversePredicate(Pred); } if (!match(TVal, m_AllOnes())) return nullptr; - // Canonicalize predicate to 'ULT'. - if (Pred == ICmpInst::ICMP_UGT) { - Pred = ICmpInst::ICMP_ULT; + // Canonicalize predicate to less-than or less-or-equal-than. + if (Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_UGE) { std::swap(Cmp0, Cmp1); + Pred = CmpInst::getSwappedPredicate(Pred); } - if (Pred != ICmpInst::ICMP_ULT) + if (Pred != ICmpInst::ICMP_ULT && Pred != ICmpInst::ICMP_ULE) return nullptr; // Match unsigned saturated add of 2 variables with an unnecessary 'not'. + // Strictness of the comparison is irrelevant. Value *Y; if (match(Cmp0, m_Not(m_Value(X))) && match(FVal, m_c_Add(m_Specific(X), m_Value(Y))) && Y == Cmp1) { @@ -809,6 +795,7 @@ static Value *canonicalizeSaturatedAdd(ICmpInst *Cmp, Value *TVal, Value *FVal, return Builder.CreateBinaryIntrinsic(Intrinsic::uadd_sat, X, Y); } // The 'not' op may be included in the sum but not the compare. + // Strictness of the comparison is irrelevant. X = Cmp0; Y = Cmp1; if (match(FVal, m_c_Add(m_Not(m_Specific(X)), m_Specific(Y)))) { @@ -819,7 +806,9 @@ static Value *canonicalizeSaturatedAdd(ICmpInst *Cmp, Value *TVal, Value *FVal, Intrinsic::uadd_sat, BO->getOperand(0), BO->getOperand(1)); } // The overflow may be detected via the add wrapping round. - if (match(Cmp0, m_c_Add(m_Specific(Cmp1), m_Value(Y))) && + // This is only valid for strict comparison! + if (Pred == ICmpInst::ICMP_ULT && + match(Cmp0, m_c_Add(m_Specific(Cmp1), m_Value(Y))) && match(FVal, m_c_Add(m_Specific(Cmp1), m_Specific(Y)))) { // ((X + Y) u< X) ? -1 : (X + Y) --> uadd.sat(X, Y) // ((X + Y) u< Y) ? -1 : (X + Y) --> uadd.sat(X, Y) @@ -1024,9 +1013,9 @@ static bool adjustMinMax(SelectInst &Sel, ICmpInst &Cmp) { /// select (icmp Pred X, C1), C2, X --> select (icmp Pred' X, C2), X, C2 /// Note: if C1 != C2, this will change the icmp constant to the existing /// constant operand of the select. -static Instruction * -canonicalizeMinMaxWithConstant(SelectInst &Sel, ICmpInst &Cmp, - InstCombiner &IC) { +static Instruction *canonicalizeMinMaxWithConstant(SelectInst &Sel, + ICmpInst &Cmp, + InstCombinerImpl &IC) { if (!Cmp.hasOneUse() || !isa<Constant>(Cmp.getOperand(1))) return nullptr; @@ -1063,105 +1052,29 @@ canonicalizeMinMaxWithConstant(SelectInst &Sel, ICmpInst &Cmp, return &Sel; } -/// There are many select variants for each of ABS/NABS. -/// In matchSelectPattern(), there are different compare constants, compare -/// predicates/operands and select operands. -/// In isKnownNegation(), there are different formats of negated operands. -/// Canonicalize all these variants to 1 pattern. -/// This makes CSE more likely. static Instruction *canonicalizeAbsNabs(SelectInst &Sel, ICmpInst &Cmp, - InstCombiner &IC) { + InstCombinerImpl &IC) { if (!Cmp.hasOneUse() || !isa<Constant>(Cmp.getOperand(1))) return nullptr; - // Choose a sign-bit check for the compare (likely simpler for codegen). - // ABS: (X <s 0) ? -X : X - // NABS: (X <s 0) ? X : -X Value *LHS, *RHS; SelectPatternFlavor SPF = matchSelectPattern(&Sel, LHS, RHS).Flavor; if (SPF != SelectPatternFlavor::SPF_ABS && SPF != SelectPatternFlavor::SPF_NABS) return nullptr; - Value *TVal = Sel.getTrueValue(); - Value *FVal = Sel.getFalseValue(); - assert(isKnownNegation(TVal, FVal) && - "Unexpected result from matchSelectPattern"); - - // The compare may use the negated abs()/nabs() operand, or it may use - // negation in non-canonical form such as: sub A, B. - bool CmpUsesNegatedOp = match(Cmp.getOperand(0), m_Neg(m_Specific(TVal))) || - match(Cmp.getOperand(0), m_Neg(m_Specific(FVal))); - - bool CmpCanonicalized = !CmpUsesNegatedOp && - match(Cmp.getOperand(1), m_ZeroInt()) && - Cmp.getPredicate() == ICmpInst::ICMP_SLT; - bool RHSCanonicalized = match(RHS, m_Neg(m_Specific(LHS))); - - // Is this already canonical? - if (CmpCanonicalized && RHSCanonicalized) - return nullptr; + // Note that NSW flag can only be propagated for normal, non-negated abs! + bool IntMinIsPoison = SPF == SelectPatternFlavor::SPF_ABS && + match(RHS, m_NSWNeg(m_Specific(LHS))); + Constant *IntMinIsPoisonC = + ConstantInt::get(Type::getInt1Ty(Sel.getContext()), IntMinIsPoison); + Instruction *Abs = + IC.Builder.CreateBinaryIntrinsic(Intrinsic::abs, LHS, IntMinIsPoisonC); - // If RHS is not canonical but is used by other instructions, don't - // canonicalize it and potentially increase the instruction count. - if (!RHSCanonicalized) - if (!(RHS->hasOneUse() || (RHS->hasNUses(2) && CmpUsesNegatedOp))) - return nullptr; + if (SPF == SelectPatternFlavor::SPF_NABS) + return BinaryOperator::CreateNeg(Abs); // Always without NSW flag! - // Create the canonical compare: icmp slt LHS 0. - if (!CmpCanonicalized) { - Cmp.setPredicate(ICmpInst::ICMP_SLT); - Cmp.setOperand(1, ConstantInt::getNullValue(Cmp.getOperand(0)->getType())); - if (CmpUsesNegatedOp) - Cmp.setOperand(0, LHS); - } - - // Create the canonical RHS: RHS = sub (0, LHS). - if (!RHSCanonicalized) { - assert(RHS->hasOneUse() && "RHS use number is not right"); - RHS = IC.Builder.CreateNeg(LHS); - if (TVal == LHS) { - // Replace false value. - IC.replaceOperand(Sel, 2, RHS); - FVal = RHS; - } else { - // Replace true value. - IC.replaceOperand(Sel, 1, RHS); - TVal = RHS; - } - } - - // If the select operands do not change, we're done. - if (SPF == SelectPatternFlavor::SPF_NABS) { - if (TVal == LHS) - return &Sel; - assert(FVal == LHS && "Unexpected results from matchSelectPattern"); - } else { - if (FVal == LHS) - return &Sel; - assert(TVal == LHS && "Unexpected results from matchSelectPattern"); - } - - // We are swapping the select operands, so swap the metadata too. - Sel.swapValues(); - Sel.swapProfMetadata(); - return &Sel; -} - -static Value *simplifyWithOpReplaced(Value *V, Value *Op, Value *ReplaceOp, - const SimplifyQuery &Q) { - // If this is a binary operator, try to simplify it with the replaced op - // because we know Op and ReplaceOp are equivalant. - // For example: V = X + 1, Op = X, ReplaceOp = 42 - // Simplifies as: add(42, 1) --> 43 - if (auto *BO = dyn_cast<BinaryOperator>(V)) { - if (BO->getOperand(0) == Op) - return SimplifyBinOp(BO->getOpcode(), ReplaceOp, BO->getOperand(1), Q); - if (BO->getOperand(1) == Op) - return SimplifyBinOp(BO->getOpcode(), BO->getOperand(0), ReplaceOp, Q); - } - - return nullptr; + return IC.replaceInstUsesWith(Sel, Abs); } /// If we have a select with an equality comparison, then we know the value in @@ -1180,30 +1093,97 @@ static Value *simplifyWithOpReplaced(Value *V, Value *Op, Value *ReplaceOp, /// /// We can't replace %sel with %add unless we strip away the flags. /// TODO: Wrapping flags could be preserved in some cases with better analysis. -static Value *foldSelectValueEquivalence(SelectInst &Sel, ICmpInst &Cmp, - const SimplifyQuery &Q) { +Instruction *InstCombinerImpl::foldSelectValueEquivalence(SelectInst &Sel, + ICmpInst &Cmp) { if (!Cmp.isEquality()) return nullptr; // Canonicalize the pattern to ICMP_EQ by swapping the select operands. Value *TrueVal = Sel.getTrueValue(), *FalseVal = Sel.getFalseValue(); - if (Cmp.getPredicate() == ICmpInst::ICMP_NE) + bool Swapped = false; + if (Cmp.getPredicate() == ICmpInst::ICMP_NE) { std::swap(TrueVal, FalseVal); + Swapped = true; + } + + // In X == Y ? f(X) : Z, try to evaluate f(Y) and replace the operand. + // Make sure Y cannot be undef though, as we might pick different values for + // undef in the icmp and in f(Y). Additionally, take care to avoid replacing + // X == Y ? X : Z with X == Y ? Y : Z, as that would lead to an infinite + // replacement cycle. + Value *CmpLHS = Cmp.getOperand(0), *CmpRHS = Cmp.getOperand(1); + if (TrueVal != CmpLHS && + isGuaranteedNotToBeUndefOrPoison(CmpRHS, SQ.AC, &Sel, &DT)) { + if (Value *V = SimplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, SQ, + /* AllowRefinement */ true)) + return replaceOperand(Sel, Swapped ? 2 : 1, V); + + // Even if TrueVal does not simplify, we can directly replace a use of + // CmpLHS with CmpRHS, as long as the instruction is not used anywhere + // else and is safe to speculatively execute (we may end up executing it + // with different operands, which should not cause side-effects or trigger + // undefined behavior). Only do this if CmpRHS is a constant, as + // profitability is not clear for other cases. + // FIXME: The replacement could be performed recursively. + if (match(CmpRHS, m_ImmConstant()) && !match(CmpLHS, m_ImmConstant())) + if (auto *I = dyn_cast<Instruction>(TrueVal)) + if (I->hasOneUse() && isSafeToSpeculativelyExecute(I)) + for (Use &U : I->operands()) + if (U == CmpLHS) { + replaceUse(U, CmpRHS); + return &Sel; + } + } + if (TrueVal != CmpRHS && + isGuaranteedNotToBeUndefOrPoison(CmpLHS, SQ.AC, &Sel, &DT)) + if (Value *V = SimplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, SQ, + /* AllowRefinement */ true)) + return replaceOperand(Sel, Swapped ? 2 : 1, V); + + auto *FalseInst = dyn_cast<Instruction>(FalseVal); + if (!FalseInst) + return nullptr; + + // InstSimplify already performed this fold if it was possible subject to + // current poison-generating flags. Try the transform again with + // poison-generating flags temporarily dropped. + bool WasNUW = false, WasNSW = false, WasExact = false, WasInBounds = false; + if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(FalseVal)) { + WasNUW = OBO->hasNoUnsignedWrap(); + WasNSW = OBO->hasNoSignedWrap(); + FalseInst->setHasNoUnsignedWrap(false); + FalseInst->setHasNoSignedWrap(false); + } + if (auto *PEO = dyn_cast<PossiblyExactOperator>(FalseVal)) { + WasExact = PEO->isExact(); + FalseInst->setIsExact(false); + } + if (auto *GEP = dyn_cast<GetElementPtrInst>(FalseVal)) { + WasInBounds = GEP->isInBounds(); + GEP->setIsInBounds(false); + } // Try each equivalence substitution possibility. // We have an 'EQ' comparison, so the select's false value will propagate. // Example: // (X == 42) ? 43 : (X + 1) --> (X == 42) ? (X + 1) : (X + 1) --> X + 1 - // (X == 42) ? (X + 1) : 43 --> (X == 42) ? (42 + 1) : 43 --> 43 - Value *CmpLHS = Cmp.getOperand(0), *CmpRHS = Cmp.getOperand(1); - if (simplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, Q) == TrueVal || - simplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, Q) == TrueVal || - simplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, Q) == FalseVal || - simplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, Q) == FalseVal) { - if (auto *FalseInst = dyn_cast<Instruction>(FalseVal)) - FalseInst->dropPoisonGeneratingFlags(); - return FalseVal; + if (SimplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, SQ, + /* AllowRefinement */ false) == TrueVal || + SimplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, SQ, + /* AllowRefinement */ false) == TrueVal) { + return replaceInstUsesWith(Sel, FalseVal); } + + // Restore poison-generating flags if the transform did not apply. + if (WasNUW) + FalseInst->setHasNoUnsignedWrap(); + if (WasNSW) + FalseInst->setHasNoSignedWrap(); + if (WasExact) + FalseInst->setIsExact(); + if (WasInBounds) + cast<GetElementPtrInst>(FalseInst)->setIsInBounds(); + return nullptr; } @@ -1253,7 +1233,7 @@ static Instruction *canonicalizeClampLike(SelectInst &Sel0, ICmpInst &Cmp0, APInt::getAllOnesValue( C0->getType()->getScalarSizeInBits())))) return nullptr; // Can't do, have all-ones element[s]. - C0 = AddOne(C0); + C0 = InstCombiner::AddOne(C0); std::swap(X, Sel1); break; case ICmpInst::Predicate::ICMP_UGE: @@ -1313,7 +1293,7 @@ static Instruction *canonicalizeClampLike(SelectInst &Sel0, ICmpInst &Cmp0, APInt::getSignedMaxValue( C2->getType()->getScalarSizeInBits())))) return nullptr; // Can't do, have signed max element[s]. - C2 = AddOne(C2); + C2 = InstCombiner::AddOne(C2); LLVM_FALLTHROUGH; case ICmpInst::Predicate::ICMP_SGE: // Also non-canonical, but here we don't need to change C2, @@ -1360,7 +1340,7 @@ static Instruction *canonicalizeClampLike(SelectInst &Sel0, ICmpInst &Cmp0, // and swap the hands of select. static Instruction * tryToReuseConstantFromSelectInComparison(SelectInst &Sel, ICmpInst &Cmp, - InstCombiner &IC) { + InstCombinerImpl &IC) { ICmpInst::Predicate Pred; Value *X; Constant *C0; @@ -1375,7 +1355,7 @@ tryToReuseConstantFromSelectInComparison(SelectInst &Sel, ICmpInst &Cmp, // If comparison predicate is non-canonical, then we certainly won't be able // to make it canonical; canonicalizeCmpWithConstant() already tried. - if (!isCanonicalPredicate(Pred)) + if (!InstCombiner::isCanonicalPredicate(Pred)) return nullptr; // If the [input] type of comparison and select type are different, lets abort @@ -1403,7 +1383,8 @@ tryToReuseConstantFromSelectInComparison(SelectInst &Sel, ICmpInst &Cmp, return nullptr; // Check the constant we'd have with flipped-strictness predicate. - auto FlippedStrictness = getFlippedStrictnessPredicateAndConstant(Pred, C0); + auto FlippedStrictness = + InstCombiner::getFlippedStrictnessPredicateAndConstant(Pred, C0); if (!FlippedStrictness) return nullptr; @@ -1426,10 +1407,10 @@ tryToReuseConstantFromSelectInComparison(SelectInst &Sel, ICmpInst &Cmp, } /// Visit a SelectInst that has an ICmpInst as its first operand. -Instruction *InstCombiner::foldSelectInstWithICmp(SelectInst &SI, - ICmpInst *ICI) { - if (Value *V = foldSelectValueEquivalence(SI, *ICI, SQ)) - return replaceInstUsesWith(SI, V); +Instruction *InstCombinerImpl::foldSelectInstWithICmp(SelectInst &SI, + ICmpInst *ICI) { + if (Instruction *NewSel = foldSelectValueEquivalence(SI, *ICI)) + return NewSel; if (Instruction *NewSel = canonicalizeMinMaxWithConstant(SI, *ICI, *this)) return NewSel; @@ -1579,11 +1560,11 @@ static bool canSelectOperandBeMappingIntoPredBlock(const Value *V, /// We have an SPF (e.g. a min or max) of an SPF of the form: /// SPF2(SPF1(A, B), C) -Instruction *InstCombiner::foldSPFofSPF(Instruction *Inner, - SelectPatternFlavor SPF1, - Value *A, Value *B, - Instruction &Outer, - SelectPatternFlavor SPF2, Value *C) { +Instruction *InstCombinerImpl::foldSPFofSPF(Instruction *Inner, + SelectPatternFlavor SPF1, Value *A, + Value *B, Instruction &Outer, + SelectPatternFlavor SPF2, + Value *C) { if (Outer.getType() != Inner->getType()) return nullptr; @@ -1900,7 +1881,7 @@ foldOverflowingAddSubSelect(SelectInst &SI, InstCombiner::BuilderTy &Builder) { return CallInst::Create(F, {X, Y}); } -Instruction *InstCombiner::foldSelectExtConst(SelectInst &Sel) { +Instruction *InstCombinerImpl::foldSelectExtConst(SelectInst &Sel) { Constant *C; if (!match(Sel.getTrueValue(), m_Constant(C)) && !match(Sel.getFalseValue(), m_Constant(C))) @@ -1966,10 +1947,11 @@ Instruction *InstCombiner::foldSelectExtConst(SelectInst &Sel) { static Instruction *canonicalizeSelectToShuffle(SelectInst &SI) { Value *CondVal = SI.getCondition(); Constant *CondC; - if (!CondVal->getType()->isVectorTy() || !match(CondVal, m_Constant(CondC))) + auto *CondValTy = dyn_cast<FixedVectorType>(CondVal->getType()); + if (!CondValTy || !match(CondVal, m_Constant(CondC))) return nullptr; - unsigned NumElts = cast<VectorType>(CondVal->getType())->getNumElements(); + unsigned NumElts = CondValTy->getNumElements(); SmallVector<int, 16> Mask; Mask.reserve(NumElts); for (unsigned i = 0; i != NumElts; ++i) { @@ -2001,8 +1983,8 @@ static Instruction *canonicalizeSelectToShuffle(SelectInst &SI) { /// to a vector select by splatting the condition. A splat may get folded with /// other operations in IR and having all operands of a select be vector types /// is likely better for vector codegen. -static Instruction *canonicalizeScalarSelectOfVecs( - SelectInst &Sel, InstCombiner &IC) { +static Instruction *canonicalizeScalarSelectOfVecs(SelectInst &Sel, + InstCombinerImpl &IC) { auto *Ty = dyn_cast<VectorType>(Sel.getType()); if (!Ty) return nullptr; @@ -2015,8 +1997,8 @@ static Instruction *canonicalizeScalarSelectOfVecs( // select (extelt V, Index), T, F --> select (splat V, Index), T, F // Splatting the extracted condition reduces code (we could directly create a // splat shuffle of the source vector to eliminate the intermediate step). - unsigned NumElts = Ty->getNumElements(); - return IC.replaceOperand(Sel, 0, IC.Builder.CreateVectorSplat(NumElts, Cond)); + return IC.replaceOperand( + Sel, 0, IC.Builder.CreateVectorSplat(Ty->getElementCount(), Cond)); } /// Reuse bitcasted operands between a compare and select: @@ -2172,7 +2154,7 @@ static Instruction *moveAddAfterMinMax(SelectPatternFlavor SPF, Value *X, } /// Match a sadd_sat or ssub_sat which is using min/max to clamp the value. -Instruction *InstCombiner::matchSAddSubSat(SelectInst &MinMax1) { +Instruction *InstCombinerImpl::matchSAddSubSat(SelectInst &MinMax1) { Type *Ty = MinMax1.getType(); // We are looking for a tree of: @@ -2293,34 +2275,42 @@ static Instruction *factorizeMinMaxTree(SelectPatternFlavor SPF, Value *LHS, return SelectInst::Create(CmpABC, MinMaxOp, ThirdOp); } -/// Try to reduce a rotate pattern that includes a compare and select into a -/// funnel shift intrinsic. Example: +/// Try to reduce a funnel/rotate pattern that includes a compare and select +/// into a funnel shift intrinsic. Example: /// rotl32(a, b) --> (b == 0 ? a : ((a >> (32 - b)) | (a << b))) /// --> call llvm.fshl.i32(a, a, b) -static Instruction *foldSelectRotate(SelectInst &Sel) { - // The false value of the select must be a rotate of the true value. - Value *Or0, *Or1; - if (!match(Sel.getFalseValue(), m_OneUse(m_Or(m_Value(Or0), m_Value(Or1))))) +/// fshl32(a, b, c) --> (c == 0 ? a : ((b >> (32 - c)) | (a << c))) +/// --> call llvm.fshl.i32(a, b, c) +/// fshr32(a, b, c) --> (c == 0 ? b : ((a >> (32 - c)) | (b << c))) +/// --> call llvm.fshr.i32(a, b, c) +static Instruction *foldSelectFunnelShift(SelectInst &Sel, + InstCombiner::BuilderTy &Builder) { + // This must be a power-of-2 type for a bitmasking transform to be valid. + unsigned Width = Sel.getType()->getScalarSizeInBits(); + if (!isPowerOf2_32(Width)) return nullptr; - Value *TVal = Sel.getTrueValue(); - Value *SA0, *SA1; - if (!match(Or0, m_OneUse(m_LogicalShift(m_Specific(TVal), m_Value(SA0)))) || - !match(Or1, m_OneUse(m_LogicalShift(m_Specific(TVal), m_Value(SA1))))) + BinaryOperator *Or0, *Or1; + if (!match(Sel.getFalseValue(), m_OneUse(m_Or(m_BinOp(Or0), m_BinOp(Or1))))) return nullptr; - auto ShiftOpcode0 = cast<BinaryOperator>(Or0)->getOpcode(); - auto ShiftOpcode1 = cast<BinaryOperator>(Or1)->getOpcode(); - if (ShiftOpcode0 == ShiftOpcode1) + Value *SV0, *SV1, *SA0, *SA1; + if (!match(Or0, m_OneUse(m_LogicalShift(m_Value(SV0), + m_ZExtOrSelf(m_Value(SA0))))) || + !match(Or1, m_OneUse(m_LogicalShift(m_Value(SV1), + m_ZExtOrSelf(m_Value(SA1))))) || + Or0->getOpcode() == Or1->getOpcode()) return nullptr; - // We have one of these patterns so far: - // select ?, TVal, (or (lshr TVal, SA0), (shl TVal, SA1)) - // select ?, TVal, (or (shl TVal, SA0), (lshr TVal, SA1)) - // This must be a power-of-2 rotate for a bitmasking transform to be valid. - unsigned Width = Sel.getType()->getScalarSizeInBits(); - if (!isPowerOf2_32(Width)) - return nullptr; + // Canonicalize to or(shl(SV0, SA0), lshr(SV1, SA1)). + if (Or0->getOpcode() == BinaryOperator::LShr) { + std::swap(Or0, Or1); + std::swap(SV0, SV1); + std::swap(SA0, SA1); + } + assert(Or0->getOpcode() == BinaryOperator::Shl && + Or1->getOpcode() == BinaryOperator::LShr && + "Illegal or(shift,shift) pair"); // Check the shift amounts to see if they are an opposite pair. Value *ShAmt; @@ -2331,6 +2321,15 @@ static Instruction *foldSelectRotate(SelectInst &Sel) { else return nullptr; + // We should now have this pattern: + // select ?, TVal, (or (shl SV0, SA0), (lshr SV1, SA1)) + // The false value of the select must be a funnel-shift of the true value: + // IsFShl -> TVal must be SV0 else TVal must be SV1. + bool IsFshl = (ShAmt == SA0); + Value *TVal = Sel.getTrueValue(); + if ((IsFshl && TVal != SV0) || (!IsFshl && TVal != SV1)) + return nullptr; + // Finally, see if the select is filtering out a shift-by-zero. Value *Cond = Sel.getCondition(); ICmpInst::Predicate Pred; @@ -2338,13 +2337,21 @@ static Instruction *foldSelectRotate(SelectInst &Sel) { Pred != ICmpInst::ICMP_EQ) return nullptr; - // This is a rotate that avoids shift-by-bitwidth UB in a suboptimal way. + // If this is not a rotate then the select was blocking poison from the + // 'shift-by-zero' non-TVal, but a funnel shift won't - so freeze it. + if (SV0 != SV1) { + if (IsFshl && !llvm::isGuaranteedNotToBePoison(SV1)) + SV1 = Builder.CreateFreeze(SV1); + else if (!IsFshl && !llvm::isGuaranteedNotToBePoison(SV0)) + SV0 = Builder.CreateFreeze(SV0); + } + + // This is a funnel/rotate that avoids shift-by-bitwidth UB in a suboptimal way. // Convert to funnel shift intrinsic. - bool IsFshl = (ShAmt == SA0 && ShiftOpcode0 == BinaryOperator::Shl) || - (ShAmt == SA1 && ShiftOpcode1 == BinaryOperator::Shl); Intrinsic::ID IID = IsFshl ? Intrinsic::fshl : Intrinsic::fshr; Function *F = Intrinsic::getDeclaration(Sel.getModule(), IID, Sel.getType()); - return IntrinsicInst::Create(F, { TVal, TVal, ShAmt }); + ShAmt = Builder.CreateZExt(ShAmt, Sel.getType()); + return IntrinsicInst::Create(F, { SV0, SV1, ShAmt }); } static Instruction *foldSelectToCopysign(SelectInst &Sel, @@ -2368,7 +2375,8 @@ static Instruction *foldSelectToCopysign(SelectInst &Sel, bool IsTrueIfSignSet; ICmpInst::Predicate Pred; if (!match(Cond, m_OneUse(m_ICmp(Pred, m_BitCast(m_Value(X)), m_APInt(C)))) || - !isSignBitCheck(Pred, *C, IsTrueIfSignSet) || X->getType() != SelType) + !InstCombiner::isSignBitCheck(Pred, *C, IsTrueIfSignSet) || + X->getType() != SelType) return nullptr; // If needed, negate the value that will be the sign argument of the copysign: @@ -2389,7 +2397,7 @@ static Instruction *foldSelectToCopysign(SelectInst &Sel, return CopySign; } -Instruction *InstCombiner::foldVectorSelect(SelectInst &Sel) { +Instruction *InstCombinerImpl::foldVectorSelect(SelectInst &Sel) { auto *VecTy = dyn_cast<FixedVectorType>(Sel.getType()); if (!VecTy) return nullptr; @@ -2469,6 +2477,10 @@ static Instruction *foldSelectToPhiImpl(SelectInst &Sel, BasicBlock *BB, } else return nullptr; + // Make sure the branches are actually different. + if (TrueSucc == FalseSucc) + return nullptr; + // We want to replace select %cond, %a, %b with a phi that takes value %a // for all incoming edges that are dominated by condition `%cond == true`, // and value %b for edges dominated by condition `%cond == false`. If %a @@ -2515,7 +2527,33 @@ static Instruction *foldSelectToPhi(SelectInst &Sel, const DominatorTree &DT, return nullptr; } -Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { +static Value *foldSelectWithFrozenICmp(SelectInst &Sel, InstCombiner::BuilderTy &Builder) { + FreezeInst *FI = dyn_cast<FreezeInst>(Sel.getCondition()); + if (!FI) + return nullptr; + + Value *Cond = FI->getOperand(0); + Value *TrueVal = Sel.getTrueValue(), *FalseVal = Sel.getFalseValue(); + + // select (freeze(x == y)), x, y --> y + // select (freeze(x != y)), x, y --> x + // The freeze should be only used by this select. Otherwise, remaining uses of + // the freeze can observe a contradictory value. + // c = freeze(x == y) ; Let's assume that y = poison & x = 42; c is 0 or 1 + // a = select c, x, y ; + // f(a, c) ; f(poison, 1) cannot happen, but if a is folded + // ; to y, this can happen. + CmpInst::Predicate Pred; + if (FI->hasOneUse() && + match(Cond, m_c_ICmp(Pred, m_Specific(TrueVal), m_Specific(FalseVal))) && + (Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_NE)) { + return Pred == ICmpInst::ICMP_EQ ? FalseVal : TrueVal; + } + + return nullptr; +} + +Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) { Value *CondVal = SI.getCondition(); Value *TrueVal = SI.getTrueValue(); Value *FalseVal = SI.getFalseValue(); @@ -2551,38 +2589,45 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { if (SelType->isIntOrIntVectorTy(1) && TrueVal->getType() == CondVal->getType()) { - if (match(TrueVal, m_One())) { + if (match(TrueVal, m_One()) && + (EnableUnsafeSelectTransform || impliesPoison(FalseVal, CondVal))) { // Change: A = select B, true, C --> A = or B, C return BinaryOperator::CreateOr(CondVal, FalseVal); } - if (match(TrueVal, m_Zero())) { - // Change: A = select B, false, C --> A = and !B, C - Value *NotCond = Builder.CreateNot(CondVal, "not." + CondVal->getName()); - return BinaryOperator::CreateAnd(NotCond, FalseVal); - } - if (match(FalseVal, m_Zero())) { + if (match(FalseVal, m_Zero()) && + (EnableUnsafeSelectTransform || impliesPoison(TrueVal, CondVal))) { // Change: A = select B, C, false --> A = and B, C return BinaryOperator::CreateAnd(CondVal, TrueVal); } + + // select a, false, b -> select !a, b, false + if (match(TrueVal, m_Zero())) { + Value *NotCond = Builder.CreateNot(CondVal, "not." + CondVal->getName()); + return SelectInst::Create(NotCond, FalseVal, + ConstantInt::getFalse(SelType)); + } + // select a, b, true -> select !a, true, b if (match(FalseVal, m_One())) { - // Change: A = select B, C, true --> A = or !B, C Value *NotCond = Builder.CreateNot(CondVal, "not." + CondVal->getName()); - return BinaryOperator::CreateOr(NotCond, TrueVal); + return SelectInst::Create(NotCond, ConstantInt::getTrue(SelType), + TrueVal); } - // select a, a, b -> a | b - // select a, b, a -> a & b + // select a, a, b -> select a, true, b if (CondVal == TrueVal) - return BinaryOperator::CreateOr(CondVal, FalseVal); + return replaceOperand(SI, 1, ConstantInt::getTrue(SelType)); + // select a, b, a -> select a, b, false if (CondVal == FalseVal) - return BinaryOperator::CreateAnd(CondVal, TrueVal); + return replaceOperand(SI, 2, ConstantInt::getFalse(SelType)); - // select a, ~a, b -> (~a) & b - // select a, b, ~a -> (~a) | b + // select a, !a, b -> select !a, b, false if (match(TrueVal, m_Not(m_Specific(CondVal)))) - return BinaryOperator::CreateAnd(TrueVal, FalseVal); + return SelectInst::Create(TrueVal, FalseVal, + ConstantInt::getFalse(SelType)); + // select a, b, !a -> select !a, true, b if (match(FalseVal, m_Not(m_Specific(CondVal)))) - return BinaryOperator::CreateOr(TrueVal, FalseVal); + return SelectInst::Create(FalseVal, ConstantInt::getTrue(SelType), + TrueVal); } // Selecting between two integer or vector splat integer constants? @@ -2591,7 +2636,10 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { // select i1 %c, <2 x i8> <1, 1>, <2 x i8> <0, 0> // because that may need 3 instructions to splat the condition value: // extend, insertelement, shufflevector. - if (SelType->isIntOrIntVectorTy() && + // + // Do not handle i1 TrueVal and FalseVal otherwise would result in + // zext/sext i1 to i1. + if (SelType->isIntOrIntVectorTy() && !SelType->isIntOrIntVectorTy(1) && CondVal->getType()->isVectorTy() == SelType->isVectorTy()) { // select C, 1, 0 -> zext C to int if (match(TrueVal, m_One()) && match(FalseVal, m_Zero())) @@ -2838,8 +2886,9 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { return replaceOperand(SI, 1, TrueSI->getTrueValue()); } // select(C0, select(C1, a, b), b) -> select(C0&C1, a, b) - // We choose this as normal form to enable folding on the And and shortening - // paths for the values (this helps GetUnderlyingObjects() for example). + // We choose this as normal form to enable folding on the And and + // shortening paths for the values (this helps getUnderlyingObjects() for + // example). if (TrueSI->getFalseValue() == FalseVal && TrueSI->hasOneUse()) { Value *And = Builder.CreateAnd(CondVal, TrueSI->getCondition()); replaceOperand(SI, 0, And); @@ -2922,7 +2971,8 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { } Value *NotCond; - if (match(CondVal, m_Not(m_Value(NotCond)))) { + if (match(CondVal, m_Not(m_Value(NotCond))) && + !InstCombiner::shouldAvoidAbsorbingNotIntoSelect(SI)) { replaceOperand(SI, 0, NotCond); SI.swapValues(); SI.swapProfMetadata(); @@ -2956,8 +3006,8 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { if (Instruction *Select = foldSelectBinOpIdentity(SI, TLI, *this)) return Select; - if (Instruction *Rot = foldSelectRotate(SI)) - return Rot; + if (Instruction *Funnel = foldSelectFunnelShift(SI, Builder)) + return Funnel; if (Instruction *Copysign = foldSelectToCopysign(SI, Builder)) return Copysign; @@ -2965,5 +3015,8 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { if (Instruction *PN = foldSelectToPhi(SI, DT, Builder)) return replaceInstUsesWith(SI, PN); + if (Value *Fr = foldSelectWithFrozenICmp(SI, Builder)) + return replaceInstUsesWith(SI, Fr); + return nullptr; } |
