diff options
Diffstat (limited to 'contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp')
-rw-r--r-- | contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp | 555 |
1 files changed, 411 insertions, 144 deletions
diff --git a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp index 9f220ec003ec..aaf4ece3249a 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp @@ -99,7 +99,8 @@ static Instruction *foldSelectBinOpIdentity(SelectInst &Sel, // transform. Bail out if we can not exclude that possibility. if (isa<FPMathOperator>(BO)) if (!BO->hasNoSignedZeros() && - !cannotBeNegativeZero(Y, IC.getDataLayout(), &TLI)) + !cannotBeNegativeZero(Y, 0, + IC.getSimplifyQuery().getWithInstruction(&Sel))) return nullptr; // BO = binop Y, X @@ -201,6 +202,14 @@ static Value *foldSelectICmpAnd(SelectInst &Sel, ICmpInst *Cmp, const APInt &ValC = !TC.isZero() ? TC : FC; unsigned ValZeros = ValC.logBase2(); unsigned AndZeros = AndMask.logBase2(); + bool ShouldNotVal = !TC.isZero(); + ShouldNotVal ^= Pred == ICmpInst::ICMP_NE; + + // If we would need to create an 'and' + 'shift' + 'xor' to replace a 'select' + // + 'icmp', then this transformation would result in more instructions and + // potentially interfere with other folding. + if (CreateAnd && ShouldNotVal && ValZeros != AndZeros) + return nullptr; // Insert the 'and' instruction on the input to the truncate. if (CreateAnd) @@ -220,8 +229,6 @@ static Value *foldSelectICmpAnd(SelectInst &Sel, ICmpInst *Cmp, // Okay, now we know that everything is set up, we just don't know whether we // have a icmp_ne or icmp_eq and whether the true or false val is the zero. - bool ShouldNotVal = !TC.isZero(); - ShouldNotVal ^= Pred == ICmpInst::ICMP_NE; if (ShouldNotVal) V = Builder.CreateXor(V, ValC); @@ -484,10 +491,9 @@ Instruction *InstCombinerImpl::foldSelectOpOp(SelectInst &SI, Instruction *TI, } if (auto *TGEP = dyn_cast<GetElementPtrInst>(TI)) { auto *FGEP = cast<GetElementPtrInst>(FI); - Type *ElementType = TGEP->getResultElementType(); - return TGEP->isInBounds() && FGEP->isInBounds() - ? GetElementPtrInst::CreateInBounds(ElementType, Op0, {Op1}) - : GetElementPtrInst::Create(ElementType, Op0, {Op1}); + Type *ElementType = TGEP->getSourceElementType(); + return GetElementPtrInst::Create( + ElementType, Op0, Op1, TGEP->getNoWrapFlags() & FGEP->getNoWrapFlags()); } llvm_unreachable("Expected BinaryOperator or GEP"); return nullptr; @@ -535,19 +541,29 @@ Instruction *InstCombinerImpl::foldSelectIntoOp(SelectInst &SI, Value *TrueVal, // between 0, 1 and -1. const APInt *OOpC; bool OOpIsAPInt = match(OOp, m_APInt(OOpC)); - if (!isa<Constant>(OOp) || - (OOpIsAPInt && isSelect01(C->getUniqueInteger(), *OOpC))) { - Value *NewSel = Builder.CreateSelect(SI.getCondition(), Swapped ? C : OOp, - Swapped ? OOp : C, "", &SI); - if (isa<FPMathOperator>(&SI)) - cast<Instruction>(NewSel)->setFastMathFlags(FMF); - NewSel->takeName(TVI); - BinaryOperator *BO = - BinaryOperator::Create(TVI->getOpcode(), FalseVal, NewSel); - BO->copyIRFlags(TVI); - return BO; - } - return nullptr; + if (isa<Constant>(OOp) && + (!OOpIsAPInt || !isSelect01(C->getUniqueInteger(), *OOpC))) + return nullptr; + + // If the false value is a NaN then we have that the floating point math + // operation in the transformed code may not preserve the exact NaN + // bit-pattern -- e.g. `fadd sNaN, 0.0 -> qNaN`. + // This makes the transformation incorrect since the original program would + // have preserved the exact NaN bit-pattern. + // Avoid the folding if the false value might be a NaN. + if (isa<FPMathOperator>(&SI) && + !computeKnownFPClass(FalseVal, FMF, fcNan, &SI).isKnownNeverNaN()) + return nullptr; + + Value *NewSel = Builder.CreateSelect(SI.getCondition(), Swapped ? C : OOp, + Swapped ? OOp : C, "", &SI); + if (isa<FPMathOperator>(&SI)) + cast<Instruction>(NewSel)->setFastMathFlags(FMF); + NewSel->takeName(TVI); + BinaryOperator *BO = + BinaryOperator::Create(TVI->getOpcode(), FalseVal, NewSel); + BO->copyIRFlags(TVI); + return BO; }; if (Instruction *R = TryFoldSelectIntoOp(SI, TrueVal, FalseVal, false)) @@ -1116,7 +1132,7 @@ static Instruction *foldSelectCtlzToCttz(ICmpInst *ICI, Value *TrueVal, /// into: /// %0 = tail call i32 @llvm.cttz.i32(i32 %x, i1 false) static Value *foldSelectCttzCtlz(ICmpInst *ICI, Value *TrueVal, Value *FalseVal, - InstCombiner::BuilderTy &Builder) { + InstCombinerImpl &IC) { ICmpInst::Predicate Pred = ICI->getPredicate(); Value *CmpLHS = ICI->getOperand(0); Value *CmpRHS = ICI->getOperand(1); @@ -1158,6 +1174,9 @@ static Value *foldSelectCttzCtlz(ICmpInst *ICI, Value *TrueVal, Value *FalseVal, // Explicitly clear the 'is_zero_poison' flag. It's always valid to go from // true to false on this flag, so we can replace it for all users. II->setArgOperand(1, ConstantInt::getFalse(II->getContext())); + // A range annotation on the intrinsic may no longer be valid. + II->dropPoisonGeneratingAnnotations(); + IC.addToWorklist(II); return SelectArg; } @@ -1190,7 +1209,7 @@ static Value *canonicalizeSPF(ICmpInst &Cmp, Value *TrueVal, Value *FalseVal, match(RHS, m_NSWNeg(m_Specific(LHS))); Constant *IntMinIsPoisonC = ConstantInt::get(Type::getInt1Ty(Cmp.getContext()), IntMinIsPoison); - Instruction *Abs = + Value *Abs = IC.Builder.CreateBinaryIntrinsic(Intrinsic::abs, LHS, IntMinIsPoisonC); if (SPF == SelectPatternFlavor::SPF_NABS) @@ -1228,8 +1247,11 @@ bool InstCombinerImpl::replaceInInstruction(Value *V, Value *Old, Value *New, if (Depth == 2) return false; + assert(!isa<Constant>(Old) && "Only replace non-constant values"); + auto *I = dyn_cast<Instruction>(V); - if (!I || !I->hasOneUse() || !isSafeToSpeculativelyExecute(I)) + if (!I || !I->hasOneUse() || + !isSafeToSpeculativelyExecuteWithVariableReplaced(I)) return false; bool Changed = false; @@ -1274,22 +1296,36 @@ Instruction *InstCombinerImpl::foldSelectValueEquivalence(SelectInst &Sel, Swapped = true; } - // In X == Y ? f(X) : Z, try to evaluate f(Y) and replace the operand. - // Make sure Y cannot be undef though, as we might pick different values for - // undef in the icmp and in f(Y). Additionally, take care to avoid replacing - // X == Y ? X : Z with X == Y ? Y : Z, as that would lead to an infinite - // replacement cycle. Value *CmpLHS = Cmp.getOperand(0), *CmpRHS = Cmp.getOperand(1); - if (TrueVal != CmpLHS && - isGuaranteedNotToBeUndefOrPoison(CmpRHS, SQ.AC, &Sel, &DT)) { - if (Value *V = simplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, SQ, - /* AllowRefinement */ true)) - // Require either the replacement or the simplification result to be a - // constant to avoid infinite loops. - // FIXME: Make this check more precise. - if (isa<Constant>(CmpRHS) || isa<Constant>(V)) + auto ReplaceOldOpWithNewOp = [&](Value *OldOp, + Value *NewOp) -> Instruction * { + // In X == Y ? f(X) : Z, try to evaluate f(Y) and replace the operand. + // Take care to avoid replacing X == Y ? X : Z with X == Y ? Y : Z, as that + // would lead to an infinite replacement cycle. + // If we will be able to evaluate f(Y) to a constant, we can allow undef, + // otherwise Y cannot be undef as we might pick different values for undef + // in the icmp and in f(Y). + if (TrueVal == OldOp) + return nullptr; + + if (Value *V = simplifyWithOpReplaced(TrueVal, OldOp, NewOp, SQ, + /* AllowRefinement=*/true)) { + // Need some guarantees about the new simplified op to ensure we don't inf + // loop. + // If we simplify to a constant, replace if we aren't creating new undef. + if (match(V, m_ImmConstant()) && + isGuaranteedNotToBeUndef(V, SQ.AC, &Sel, &DT)) return replaceOperand(Sel, Swapped ? 2 : 1, V); + // If NewOp is a constant and OldOp is not replace iff NewOp doesn't + // contain and undef elements. + if (match(NewOp, m_ImmConstant()) || NewOp == V) { + if (isGuaranteedNotToBeUndef(NewOp, SQ.AC, &Sel, &DT)) + return replaceOperand(Sel, Swapped ? 2 : 1, V); + return nullptr; + } + } + // Even if TrueVal does not simplify, we can directly replace a use of // CmpLHS with CmpRHS, as long as the instruction is not used anywhere // else and is safe to speculatively execute (we may end up executing it @@ -1297,17 +1333,18 @@ Instruction *InstCombinerImpl::foldSelectValueEquivalence(SelectInst &Sel, // undefined behavior). Only do this if CmpRHS is a constant, as // profitability is not clear for other cases. // FIXME: Support vectors. - if (match(CmpRHS, m_ImmConstant()) && !match(CmpLHS, m_ImmConstant()) && - !Cmp.getType()->isVectorTy()) - if (replaceInInstruction(TrueVal, CmpLHS, CmpRHS)) + if (OldOp == CmpLHS && match(NewOp, m_ImmConstant()) && + !match(OldOp, m_Constant()) && !Cmp.getType()->isVectorTy() && + isGuaranteedNotToBeUndef(NewOp, SQ.AC, &Sel, &DT)) + if (replaceInInstruction(TrueVal, OldOp, NewOp)) return &Sel; - } - if (TrueVal != CmpRHS && - isGuaranteedNotToBeUndefOrPoison(CmpLHS, SQ.AC, &Sel, &DT)) - if (Value *V = simplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, SQ, - /* AllowRefinement */ true)) - if (isa<Constant>(CmpLHS) || isa<Constant>(V)) - return replaceOperand(Sel, Swapped ? 2 : 1, V); + return nullptr; + }; + + if (Instruction *R = ReplaceOldOpWithNewOp(CmpLHS, CmpRHS)) + return R; + if (Instruction *R = ReplaceOldOpWithNewOp(CmpRHS, CmpLHS)) + return R; auto *FalseInst = dyn_cast<Instruction>(FalseVal); if (!FalseInst) @@ -1329,7 +1366,7 @@ Instruction *InstCombinerImpl::foldSelectValueEquivalence(SelectInst &Sel, /* AllowRefinement */ false, &DropFlags) == TrueVal) { for (Instruction *I : DropFlags) { - I->dropPoisonGeneratingFlagsAndMetadata(); + I->dropPoisonGeneratingAnnotations(); Worklist.add(I); } @@ -1354,7 +1391,8 @@ Instruction *InstCombinerImpl::foldSelectValueEquivalence(SelectInst &Sel, // Also ULT predicate can also be UGT iff C0 != -1 (+invert result) // SLT predicate can also be SGT iff C2 != INT_MAX (+invert res.) static Value *canonicalizeClampLike(SelectInst &Sel0, ICmpInst &Cmp0, - InstCombiner::BuilderTy &Builder) { + InstCombiner::BuilderTy &Builder, + InstCombiner &IC) { Value *X = Sel0.getTrueValue(); Value *Sel1 = Sel0.getFalseValue(); @@ -1482,14 +1520,14 @@ static Value *canonicalizeClampLike(SelectInst &Sel0, ICmpInst &Cmp0, std::swap(ThresholdLowIncl, ThresholdHighExcl); // The fold has a precondition 1: C2 s>= ThresholdLow - auto *Precond1 = ConstantExpr::getICmp(ICmpInst::Predicate::ICMP_SGE, C2, - ThresholdLowIncl); - if (!match(Precond1, m_One())) + auto *Precond1 = ConstantFoldCompareInstOperands( + ICmpInst::Predicate::ICMP_SGE, C2, ThresholdLowIncl, IC.getDataLayout()); + if (!Precond1 || !match(Precond1, m_One())) return nullptr; // The fold has a precondition 2: C2 s<= ThresholdHigh - auto *Precond2 = ConstantExpr::getICmp(ICmpInst::Predicate::ICMP_SLE, C2, - ThresholdHighExcl); - if (!match(Precond2, m_One())) + auto *Precond2 = ConstantFoldCompareInstOperands( + ICmpInst::Predicate::ICMP_SLE, C2, ThresholdHighExcl, IC.getDataLayout()); + if (!Precond2 || !match(Precond2, m_One())) return nullptr; // If we are matching from a truncated input, we need to sext the @@ -1500,7 +1538,7 @@ static Value *canonicalizeClampLike(SelectInst &Sel0, ICmpInst &Cmp0, if (!match(ReplacementLow, m_ImmConstant(LowC)) || !match(ReplacementHigh, m_ImmConstant(HighC))) return nullptr; - const DataLayout &DL = Sel0.getModule()->getDataLayout(); + const DataLayout &DL = Sel0.getDataLayout(); ReplacementLow = ConstantFoldCastOperand(Instruction::SExt, LowC, X->getType(), DL); ReplacementHigh = @@ -1610,7 +1648,7 @@ static Instruction *foldSelectZeroOrOnes(ICmpInst *Cmp, Value *TVal, return nullptr; const APInt *CmpC; - if (!match(Cmp->getOperand(1), m_APIntAllowUndef(CmpC))) + if (!match(Cmp->getOperand(1), m_APIntAllowPoison(CmpC))) return nullptr; // (X u< 2) ? -X : -1 --> sext (X != 0) @@ -1676,6 +1714,109 @@ static Value *foldSelectInstWithICmpConst(SelectInst &SI, ICmpInst *ICI, return nullptr; } +static Instruction *foldSelectICmpEq(SelectInst &SI, ICmpInst *ICI, + InstCombinerImpl &IC) { + ICmpInst::Predicate Pred = ICI->getPredicate(); + if (!ICmpInst::isEquality(Pred)) + return nullptr; + + Value *TrueVal = SI.getTrueValue(); + Value *FalseVal = SI.getFalseValue(); + Value *CmpLHS = ICI->getOperand(0); + Value *CmpRHS = ICI->getOperand(1); + + if (Pred == ICmpInst::ICMP_NE) + std::swap(TrueVal, FalseVal); + + // Transform (X == C) ? X : Y -> (X == C) ? C : Y + // specific handling for Bitwise operation. + // x&y -> (x|y) ^ (x^y) or (x|y) & ~(x^y) + // x|y -> (x&y) | (x^y) or (x&y) ^ (x^y) + // x^y -> (x|y) ^ (x&y) or (x|y) & ~(x&y) + Value *X, *Y; + if (!match(CmpLHS, m_BitwiseLogic(m_Value(X), m_Value(Y))) || + !match(TrueVal, m_c_BitwiseLogic(m_Specific(X), m_Specific(Y)))) + return nullptr; + + const unsigned AndOps = Instruction::And, OrOps = Instruction::Or, + XorOps = Instruction::Xor, NoOps = 0; + enum NotMask { None = 0, NotInner, NotRHS }; + + auto matchFalseVal = [&](unsigned OuterOpc, unsigned InnerOpc, + unsigned NotMask) { + auto matchInner = m_c_BinOp(InnerOpc, m_Specific(X), m_Specific(Y)); + if (OuterOpc == NoOps) + return match(CmpRHS, m_Zero()) && match(FalseVal, matchInner); + + if (NotMask == NotInner) { + return match(FalseVal, m_c_BinOp(OuterOpc, m_NotForbidPoison(matchInner), + m_Specific(CmpRHS))); + } else if (NotMask == NotRHS) { + return match(FalseVal, m_c_BinOp(OuterOpc, matchInner, + m_NotForbidPoison(m_Specific(CmpRHS)))); + } else { + return match(FalseVal, + m_c_BinOp(OuterOpc, matchInner, m_Specific(CmpRHS))); + } + }; + + // (X&Y)==C ? X|Y : X^Y -> (X^Y)|C : X^Y or (X^Y)^ C : X^Y + // (X&Y)==C ? X^Y : X|Y -> (X|Y)^C : X|Y or (X|Y)&~C : X|Y + if (match(CmpLHS, m_And(m_Value(X), m_Value(Y)))) { + if (match(TrueVal, m_c_Or(m_Specific(X), m_Specific(Y)))) { + // (X&Y)==C ? X|Y : (X^Y)|C -> (X^Y)|C : (X^Y)|C -> (X^Y)|C + // (X&Y)==C ? X|Y : (X^Y)^C -> (X^Y)^C : (X^Y)^C -> (X^Y)^C + if (matchFalseVal(OrOps, XorOps, None) || + matchFalseVal(XorOps, XorOps, None)) + return IC.replaceInstUsesWith(SI, FalseVal); + } else if (match(TrueVal, m_c_Xor(m_Specific(X), m_Specific(Y)))) { + // (X&Y)==C ? X^Y : (X|Y)^ C -> (X|Y)^ C : (X|Y)^ C -> (X|Y)^ C + // (X&Y)==C ? X^Y : (X|Y)&~C -> (X|Y)&~C : (X|Y)&~C -> (X|Y)&~C + if (matchFalseVal(XorOps, OrOps, None) || + matchFalseVal(AndOps, OrOps, NotRHS)) + return IC.replaceInstUsesWith(SI, FalseVal); + } + } + + // (X|Y)==C ? X&Y : X^Y -> (X^Y)^C : X^Y or ~(X^Y)&C : X^Y + // (X|Y)==C ? X^Y : X&Y -> (X&Y)^C : X&Y or ~(X&Y)&C : X&Y + if (match(CmpLHS, m_Or(m_Value(X), m_Value(Y)))) { + if (match(TrueVal, m_c_And(m_Specific(X), m_Specific(Y)))) { + // (X|Y)==C ? X&Y: (X^Y)^C -> (X^Y)^C: (X^Y)^C -> (X^Y)^C + // (X|Y)==C ? X&Y:~(X^Y)&C ->~(X^Y)&C:~(X^Y)&C -> ~(X^Y)&C + if (matchFalseVal(XorOps, XorOps, None) || + matchFalseVal(AndOps, XorOps, NotInner)) + return IC.replaceInstUsesWith(SI, FalseVal); + } else if (match(TrueVal, m_c_Xor(m_Specific(X), m_Specific(Y)))) { + // (X|Y)==C ? X^Y : (X&Y)^C -> (X&Y)^C : (X&Y)^C -> (X&Y)^C + // (X|Y)==C ? X^Y :~(X&Y)&C -> ~(X&Y)&C :~(X&Y)&C -> ~(X&Y)&C + if (matchFalseVal(XorOps, AndOps, None) || + matchFalseVal(AndOps, AndOps, NotInner)) + return IC.replaceInstUsesWith(SI, FalseVal); + } + } + + // (X^Y)==C ? X&Y : X|Y -> (X|Y)^C : X|Y or (X|Y)&~C : X|Y + // (X^Y)==C ? X|Y : X&Y -> (X&Y)|C : X&Y or (X&Y)^ C : X&Y + if (match(CmpLHS, m_Xor(m_Value(X), m_Value(Y)))) { + if ((match(TrueVal, m_c_And(m_Specific(X), m_Specific(Y))))) { + // (X^Y)==C ? X&Y : (X|Y)^C -> (X|Y)^C + // (X^Y)==C ? X&Y : (X|Y)&~C -> (X|Y)&~C + if (matchFalseVal(XorOps, OrOps, None) || + matchFalseVal(AndOps, OrOps, NotRHS)) + return IC.replaceInstUsesWith(SI, FalseVal); + } else if (match(TrueVal, m_c_Or(m_Specific(X), m_Specific(Y)))) { + // (X^Y)==C ? (X|Y) : (X&Y)|C -> (X&Y)|C + // (X^Y)==C ? (X|Y) : (X&Y)^C -> (X&Y)^C + if (matchFalseVal(OrOps, AndOps, None) || + matchFalseVal(XorOps, AndOps, None)) + return IC.replaceInstUsesWith(SI, FalseVal); + } + } + + return nullptr; +} + /// Visit a SelectInst that has an ICmpInst as its first operand. Instruction *InstCombinerImpl::foldSelectInstWithICmp(SelectInst &SI, ICmpInst *ICI) { @@ -1689,7 +1830,7 @@ Instruction *InstCombinerImpl::foldSelectInstWithICmp(SelectInst &SI, if (Value *V = foldSelectInstWithICmpConst(SI, ICI, Builder)) return replaceInstUsesWith(SI, V); - if (Value *V = canonicalizeClampLike(SI, *ICI, Builder)) + if (Value *V = canonicalizeClampLike(SI, *ICI, Builder, *this)) return replaceInstUsesWith(SI, V); if (Instruction *NewSel = @@ -1718,6 +1859,9 @@ Instruction *InstCombinerImpl::foldSelectInstWithICmp(SelectInst &SI, } } + if (Instruction *NewSel = foldSelectICmpEq(SI, ICI, *this)) + return NewSel; + // Canonicalize a signbit condition to use zero constant by swapping: // (CmpLHS > -1) ? TV : FV --> (CmpLHS < 0) ? FV : TV // To avoid conflicts (infinite loops) with other canonicalizations, this is @@ -1803,7 +1947,7 @@ Instruction *InstCombinerImpl::foldSelectInstWithICmp(SelectInst &SI, if (Value *V = foldSelectICmpLshrAshr(ICI, TrueVal, FalseVal, Builder)) return replaceInstUsesWith(SI, V); - if (Value *V = foldSelectCttzCtlz(ICI, TrueVal, FalseVal, Builder)) + if (Value *V = foldSelectCttzCtlz(ICI, TrueVal, FalseVal, *this)) return replaceInstUsesWith(SI, V); if (Value *V = canonicalizeSaturatedSubtract(ICI, TrueVal, FalseVal, Builder)) @@ -2223,20 +2367,20 @@ static Instruction *foldSelectCmpBitcasts(SelectInst &Sel, /// operand, the result of the select will always be equal to its false value. /// For example: /// -/// %0 = cmpxchg i64* %ptr, i64 %compare, i64 %new_value seq_cst seq_cst -/// %1 = extractvalue { i64, i1 } %0, 1 -/// %2 = extractvalue { i64, i1 } %0, 0 -/// %3 = select i1 %1, i64 %compare, i64 %2 -/// ret i64 %3 +/// %cmpxchg = cmpxchg ptr %ptr, i64 %compare, i64 %new_value seq_cst seq_cst +/// %val = extractvalue { i64, i1 } %cmpxchg, 0 +/// %success = extractvalue { i64, i1 } %cmpxchg, 1 +/// %sel = select i1 %success, i64 %compare, i64 %val +/// ret i64 %sel /// -/// The returned value of the cmpxchg instruction (%2) is the original value -/// located at %ptr prior to any update. If the cmpxchg operation succeeds, %2 +/// The returned value of the cmpxchg instruction (%val) is the original value +/// located at %ptr prior to any update. If the cmpxchg operation succeeds, %val /// must have been equal to %compare. Thus, the result of the select is always -/// equal to %2, and the code can be simplified to: +/// equal to %val, and the code can be simplified to: /// -/// %0 = cmpxchg i64* %ptr, i64 %compare, i64 %new_value seq_cst seq_cst -/// %1 = extractvalue { i64, i1 } %0, 0 -/// ret i64 %1 +/// %cmpxchg = cmpxchg ptr %ptr, i64 %compare, i64 %new_value seq_cst seq_cst +/// %val = extractvalue { i64, i1 } %cmpxchg, 0 +/// ret i64 %val /// static Value *foldSelectCmpXchg(SelectInst &SI) { // A helper that determines if V is an extractvalue instruction whose @@ -2369,14 +2513,11 @@ static Instruction *foldSelectToCopysign(SelectInst &Sel, Value *FVal = Sel.getFalseValue(); Type *SelType = Sel.getType(); - if (ICmpInst::makeCmpResultType(TVal->getType()) != Cond->getType()) - return nullptr; - // Match select ?, TC, FC where the constants are equal but negated. // TODO: Generalize to handle a negated variable operand? const APFloat *TC, *FC; - if (!match(TVal, m_APFloatAllowUndef(TC)) || - !match(FVal, m_APFloatAllowUndef(FC)) || + if (!match(TVal, m_APFloatAllowPoison(TC)) || + !match(FVal, m_APFloatAllowPoison(FC)) || !abs(*TC).bitwiseIsEqual(abs(*FC))) return nullptr; @@ -2386,9 +2527,9 @@ static Instruction *foldSelectToCopysign(SelectInst &Sel, const APInt *C; bool IsTrueIfSignSet; ICmpInst::Predicate Pred; - if (!match(Cond, m_OneUse(m_ICmp(Pred, m_BitCast(m_Value(X)), m_APInt(C)))) || - !InstCombiner::isSignBitCheck(Pred, *C, IsTrueIfSignSet) || - X->getType() != SelType) + if (!match(Cond, m_OneUse(m_ICmp(Pred, m_ElementWiseBitCast(m_Value(X)), + m_APInt(C)))) || + !isSignBitCheck(Pred, *C, IsTrueIfSignSet) || X->getType() != SelType) return nullptr; // If needed, negate the value that will be the sign argument of the copysign: @@ -2423,8 +2564,8 @@ Instruction *InstCombinerImpl::foldVectorSelect(SelectInst &Sel) { if (auto *I = dyn_cast<Instruction>(V)) I->copyIRFlags(&Sel); Module *M = Sel.getModule(); - Function *F = Intrinsic::getDeclaration( - M, Intrinsic::experimental_vector_reverse, V->getType()); + Function *F = + Intrinsic::getDeclaration(M, Intrinsic::vector_reverse, V->getType()); return CallInst::Create(F, V); }; @@ -2587,7 +2728,7 @@ static Instruction *foldSelectWithSRem(SelectInst &SI, InstCombinerImpl &IC, bool TrueIfSigned = false; if (!(match(CondVal, m_ICmp(Pred, m_Value(RemRes), m_APInt(C))) && - IC.isSignBitCheck(Pred, *C, TrueIfSigned))) + isSignBitCheck(Pred, *C, TrueIfSigned))) return nullptr; // If the sign bit is not set, we have a SGE/SGT comparison, and the operands @@ -2606,7 +2747,7 @@ static Instruction *foldSelectWithSRem(SelectInst &SI, InstCombinerImpl &IC, // %cnd = icmp slt i32 %rem, 0 // %add = add i32 %rem, %n // %sel = select i1 %cnd, i32 %add, i32 %rem - if (match(TrueVal, m_Add(m_Value(RemRes), m_Value(Remainder))) && + if (match(TrueVal, m_Add(m_Specific(RemRes), m_Value(Remainder))) && match(RemRes, m_SRem(m_Value(Op), m_Specific(Remainder))) && IC.isKnownToBeAPowerOfTwo(Remainder, /*OrZero*/ true) && FalseVal == RemRes) @@ -2650,46 +2791,33 @@ static Value *foldSelectWithFrozenICmp(SelectInst &Sel, InstCombiner::BuilderTy return nullptr; } +/// Given that \p CondVal is known to be \p CondIsTrue, try to simplify \p SI. +static Value *simplifyNestedSelectsUsingImpliedCond(SelectInst &SI, + Value *CondVal, + bool CondIsTrue, + const DataLayout &DL) { + Value *InnerCondVal = SI.getCondition(); + Value *InnerTrueVal = SI.getTrueValue(); + Value *InnerFalseVal = SI.getFalseValue(); + assert(CondVal->getType() == InnerCondVal->getType() && + "The type of inner condition must match with the outer."); + if (auto Implied = isImpliedCondition(CondVal, InnerCondVal, DL, CondIsTrue)) + return *Implied ? InnerTrueVal : InnerFalseVal; + return nullptr; +} + Instruction *InstCombinerImpl::foldAndOrOfSelectUsingImpliedCond(Value *Op, SelectInst &SI, bool IsAnd) { - Value *CondVal = SI.getCondition(); - Value *A = SI.getTrueValue(); - Value *B = SI.getFalseValue(); - assert(Op->getType()->isIntOrIntVectorTy(1) && "Op must be either i1 or vector of i1."); - - std::optional<bool> Res = isImpliedCondition(Op, CondVal, DL, IsAnd); - if (!Res) + if (SI.getCondition()->getType() != Op->getType()) return nullptr; - - Value *Zero = Constant::getNullValue(A->getType()); - Value *One = Constant::getAllOnesValue(A->getType()); - - if (*Res == true) { - if (IsAnd) - // select op, (select cond, A, B), false => select op, A, false - // and op, (select cond, A, B) => select op, A, false - // if op = true implies condval = true. - return SelectInst::Create(Op, A, Zero); - else - // select op, true, (select cond, A, B) => select op, true, A - // or op, (select cond, A, B) => select op, true, A - // if op = false implies condval = true. - return SelectInst::Create(Op, One, A); - } else { - if (IsAnd) - // select op, (select cond, A, B), false => select op, B, false - // and op, (select cond, A, B) => select op, B, false - // if op = true implies condval = false. - return SelectInst::Create(Op, B, Zero); - else - // select op, true, (select cond, A, B) => select op, true, B - // or op, (select cond, A, B) => select op, true, B - // if op = false implies condval = false. - return SelectInst::Create(Op, One, B); - } + if (Value *V = simplifyNestedSelectsUsingImpliedCond(SI, Op, IsAnd, DL)) + return SelectInst::Create(Op, + IsAnd ? V : ConstantInt::getTrue(Op->getType()), + IsAnd ? ConstantInt::getFalse(Op->getType()) : V); + return nullptr; } // Canonicalize select with fcmp to fabs(). -0.0 makes this tricky. We need @@ -2772,6 +2900,36 @@ static Instruction *foldSelectWithFCmpToFabs(SelectInst &SI, } } + // Match select with (icmp slt (bitcast X to int), 0) + // or (icmp sgt (bitcast X to int), -1) + + for (bool Swap : {false, true}) { + Value *TrueVal = SI.getTrueValue(); + Value *X = SI.getFalseValue(); + + if (Swap) + std::swap(TrueVal, X); + + CmpInst::Predicate Pred; + const APInt *C; + bool TrueIfSigned; + if (!match(CondVal, + m_ICmp(Pred, m_ElementWiseBitCast(m_Specific(X)), m_APInt(C))) || + !isSignBitCheck(Pred, *C, TrueIfSigned)) + continue; + if (!match(TrueVal, m_FNeg(m_Specific(X)))) + return nullptr; + if (Swap == TrueIfSigned && !CondVal->hasOneUse() && !TrueVal->hasOneUse()) + return nullptr; + + // Fold (IsNeg ? -X : X) or (!IsNeg ? X : -X) to fabs(X) + // Fold (IsNeg ? X : -X) or (!IsNeg ? -X : X) to -fabs(X) + Value *Fabs = IC.Builder.CreateUnaryIntrinsic(Intrinsic::fabs, X, &SI); + if (Swap != TrueIfSigned) + return IC.replaceInstUsesWith(SI, Fabs); + return UnaryOperator::CreateFNegFMF(Fabs, &SI); + } + return ChangedFMF ? &SI : nullptr; } @@ -2808,17 +2966,17 @@ foldRoundUpIntegerWithPow2Alignment(SelectInst &SI, // FIXME: we could support non non-splats here. const APInt *LowBitMaskCst; - if (!match(XLowBits, m_And(m_Specific(X), m_APIntAllowUndef(LowBitMaskCst)))) + if (!match(XLowBits, m_And(m_Specific(X), m_APIntAllowPoison(LowBitMaskCst)))) return nullptr; // Match even if the AND and ADD are swapped. const APInt *BiasCst, *HighBitMaskCst; if (!match(XBiasedHighBits, - m_And(m_Add(m_Specific(X), m_APIntAllowUndef(BiasCst)), - m_APIntAllowUndef(HighBitMaskCst))) && + m_And(m_Add(m_Specific(X), m_APIntAllowPoison(BiasCst)), + m_APIntAllowPoison(HighBitMaskCst))) && !match(XBiasedHighBits, - m_Add(m_And(m_Specific(X), m_APIntAllowUndef(HighBitMaskCst)), - m_APIntAllowUndef(BiasCst)))) + m_Add(m_And(m_Specific(X), m_APIntAllowPoison(HighBitMaskCst)), + m_APIntAllowPoison(BiasCst)))) return nullptr; if (!LowBitMaskCst->isMask()) @@ -2834,7 +2992,8 @@ foldRoundUpIntegerWithPow2Alignment(SelectInst &SI, return nullptr; if (!XBiasedHighBits->hasOneUse()) { - if (*BiasCst == *LowBitMaskCst) + // We can't directly return XBiasedHighBits if it is more poisonous. + if (*BiasCst == *LowBitMaskCst && impliesPoison(XBiasedHighBits, X)) return XBiasedHighBits; return nullptr; } @@ -2856,6 +3015,32 @@ struct DecomposedSelect { }; } // namespace +/// Folds patterns like: +/// select c2 (select c1 a b) (select c1 b a) +/// into: +/// select (xor c1 c2) b a +static Instruction * +foldSelectOfSymmetricSelect(SelectInst &OuterSelVal, + InstCombiner::BuilderTy &Builder) { + + Value *OuterCond, *InnerCond, *InnerTrueVal, *InnerFalseVal; + if (!match( + &OuterSelVal, + m_Select(m_Value(OuterCond), + m_OneUse(m_Select(m_Value(InnerCond), m_Value(InnerTrueVal), + m_Value(InnerFalseVal))), + m_OneUse(m_Select(m_Deferred(InnerCond), + m_Deferred(InnerFalseVal), + m_Deferred(InnerTrueVal)))))) + return nullptr; + + if (OuterCond->getType() != InnerCond->getType()) + return nullptr; + + Value *Xor = Builder.CreateXor(InnerCond, OuterCond); + return SelectInst::Create(Xor, InnerFalseVal, InnerTrueVal); +} + /// Look for patterns like /// %outer.cond = select i1 %inner.cond, i1 %alt.cond, i1 false /// %inner.sel = select i1 %inner.cond, i8 %inner.sel.t, i8 %inner.sel.f @@ -2960,6 +3145,13 @@ Instruction *InstCombinerImpl::foldSelectOfBools(SelectInst &SI) { return BinaryOperator::CreateOr(CondVal, FalseVal); } + if (match(CondVal, m_OneUse(m_Select(m_Value(A), m_One(), m_Value(B)))) && + impliesPoison(FalseVal, B)) { + // (A || B) || C --> A || (B | C) + return replaceInstUsesWith( + SI, Builder.CreateLogicalOr(A, Builder.CreateOr(B, FalseVal))); + } + if (auto *LHS = dyn_cast<FCmpInst>(CondVal)) if (auto *RHS = dyn_cast<FCmpInst>(FalseVal)) if (Value *V = foldLogicOfFCmps(LHS, RHS, /*IsAnd*/ false, @@ -3001,6 +3193,13 @@ Instruction *InstCombinerImpl::foldSelectOfBools(SelectInst &SI) { return BinaryOperator::CreateAnd(CondVal, TrueVal); } + if (match(CondVal, m_OneUse(m_Select(m_Value(A), m_Value(B), m_Zero()))) && + impliesPoison(TrueVal, B)) { + // (A && B) && C --> A && (B & C) + return replaceInstUsesWith( + SI, Builder.CreateLogicalAnd(A, Builder.CreateAnd(B, TrueVal))); + } + if (auto *LHS = dyn_cast<FCmpInst>(CondVal)) if (auto *RHS = dyn_cast<FCmpInst>(TrueVal)) if (Value *V = foldLogicOfFCmps(LHS, RHS, /*IsAnd*/ true, @@ -3115,11 +3314,6 @@ Instruction *InstCombinerImpl::foldSelectOfBools(SelectInst &SI) { return replaceInstUsesWith(SI, Op1); } - if (auto *Op1SI = dyn_cast<SelectInst>(Op1)) - if (auto *I = foldAndOrOfSelectUsingImpliedCond(CondVal, *Op1SI, - /* IsAnd */ IsAnd)) - return I; - if (auto *ICmp0 = dyn_cast<ICmpInst>(CondVal)) if (auto *ICmp1 = dyn_cast<ICmpInst>(Op1)) if (auto *V = foldAndOrOfICmps(ICmp0, ICmp1, SI, IsAnd, @@ -3201,7 +3395,8 @@ Instruction *InstCombinerImpl::foldSelectOfBools(SelectInst &SI) { // pattern. static bool isSafeToRemoveBitCeilSelect(ICmpInst::Predicate Pred, Value *Cond0, const APInt *Cond1, Value *CtlzOp, - unsigned BitWidth) { + unsigned BitWidth, + bool &ShouldDropNUW) { // The challenge in recognizing std::bit_ceil(X) is that the operand is used // for the CTLZ proper and select condition, each possibly with some // operation like add and sub. @@ -3224,6 +3419,8 @@ static bool isSafeToRemoveBitCeilSelect(ICmpInst::Predicate Pred, Value *Cond0, ConstantRange CR = ConstantRange::makeExactICmpRegion( CmpInst::getInversePredicate(Pred), *Cond1); + ShouldDropNUW = false; + // Match the operation that's used to compute CtlzOp from CommonAncestor. If // CtlzOp == CommonAncestor, return true as no operation is needed. If a // match is found, execute the operation on CR, update CR, and return true. @@ -3237,6 +3434,7 @@ static bool isSafeToRemoveBitCeilSelect(ICmpInst::Predicate Pred, Value *Cond0, return true; } if (match(CtlzOp, m_Sub(m_APInt(C), m_Specific(CommonAncestor)))) { + ShouldDropNUW = true; CR = ConstantRange(*C).sub(CR); return true; } @@ -3306,14 +3504,20 @@ static Instruction *foldBitCeil(SelectInst &SI, IRBuilderBase &Builder) { Pred = CmpInst::getInversePredicate(Pred); } + bool ShouldDropNUW; + if (!match(FalseVal, m_One()) || !match(TrueVal, m_OneUse(m_Shl(m_One(), m_OneUse(m_Sub(m_SpecificInt(BitWidth), m_Value(Ctlz)))))) || !match(Ctlz, m_Intrinsic<Intrinsic::ctlz>(m_Value(CtlzOp), m_Zero())) || - !isSafeToRemoveBitCeilSelect(Pred, Cond0, Cond1, CtlzOp, BitWidth)) + !isSafeToRemoveBitCeilSelect(Pred, Cond0, Cond1, CtlzOp, BitWidth, + ShouldDropNUW)) return nullptr; + if (ShouldDropNUW) + cast<Instruction>(CtlzOp)->setHasNoUnsignedWrap(false); + // Build 1 << (-CTLZ & (BitWidth-1)). The negation likely corresponds to a // single hardware instruction as opposed to BitWidth - CTLZ, where BitWidth // is an integer constant. Masking with BitWidth-1 comes free on some @@ -3350,6 +3554,33 @@ static bool matchFMulByZeroIfResultEqZero(InstCombinerImpl &IC, Value *Cmp0, return false; } +/// Check whether the KnownBits of a select arm may be affected by the +/// select condition. +static bool hasAffectedValue(Value *V, SmallPtrSetImpl<Value *> &Affected, + unsigned Depth) { + if (Depth == MaxAnalysisRecursionDepth) + return false; + + // Ignore the case where the select arm itself is affected. These cases + // are handled more efficiently by other optimizations. + if (Depth != 0 && Affected.contains(V)) + return true; + + if (auto *I = dyn_cast<Instruction>(V)) { + if (isa<PHINode>(I)) { + if (Depth == MaxAnalysisRecursionDepth - 1) + return false; + Depth = MaxAnalysisRecursionDepth - 2; + } + return any_of(I->operands(), [&](Value *Op) { + return Op->getType()->isIntOrIntVectorTy() && + hasAffectedValue(Op, Affected, Depth + 1); + }); + } + + return false; +} + Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) { Value *CondVal = SI.getCondition(); Value *TrueVal = SI.getTrueValue(); @@ -3536,16 +3767,15 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) { Value *Idx = Gep->getOperand(1); if (isa<VectorType>(CondVal->getType()) && !isa<VectorType>(Idx->getType())) return nullptr; - Type *ElementType = Gep->getResultElementType(); + Type *ElementType = Gep->getSourceElementType(); Value *NewT = Idx; Value *NewF = Constant::getNullValue(Idx->getType()); if (Swap) std::swap(NewT, NewF); Value *NewSI = Builder.CreateSelect(CondVal, NewT, NewF, SI.getName() + ".idx", &SI); - if (Gep->isInBounds()) - return GetElementPtrInst::CreateInBounds(ElementType, Ptr, {NewSI}); - return GetElementPtrInst::Create(ElementType, Ptr, {NewSI}); + return GetElementPtrInst::Create(ElementType, Ptr, NewSI, + Gep->getNoWrapFlags()); }; if (auto *TrueGep = dyn_cast<GetElementPtrInst>(TrueVal)) if (auto *NewGep = SelectGepWithBase(TrueGep, FalseVal, false)) @@ -3620,12 +3850,12 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) { if (SelectInst *TrueSI = dyn_cast<SelectInst>(TrueVal)) { if (TrueSI->getCondition()->getType() == CondVal->getType()) { - // select(C, select(C, a, b), c) -> select(C, a, c) - if (TrueSI->getCondition() == CondVal) { - if (SI.getTrueValue() == TrueSI->getTrueValue()) - return nullptr; - return replaceOperand(SI, 1, TrueSI->getTrueValue()); - } + // Fold nested selects if the inner condition can be implied by the outer + // condition. + if (Value *V = simplifyNestedSelectsUsingImpliedCond( + *TrueSI, CondVal, /*CondIsTrue=*/true, DL)) + return replaceOperand(SI, 1, V); + // select(C0, select(C1, a, b), b) -> select(C0&C1, a, b) // We choose this as normal form to enable folding on the And and // shortening paths for the values (this helps getUnderlyingObjects() for @@ -3640,12 +3870,12 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) { } if (SelectInst *FalseSI = dyn_cast<SelectInst>(FalseVal)) { if (FalseSI->getCondition()->getType() == CondVal->getType()) { - // select(C, a, select(C, b, c)) -> select(C, a, c) - if (FalseSI->getCondition() == CondVal) { - if (SI.getFalseValue() == FalseSI->getFalseValue()) - return nullptr; - return replaceOperand(SI, 2, FalseSI->getFalseValue()); - } + // Fold nested selects if the inner condition can be implied by the outer + // condition. + if (Value *V = simplifyNestedSelectsUsingImpliedCond( + *FalseSI, CondVal, /*CondIsTrue=*/false, DL)) + return replaceOperand(SI, 2, V); + // select(C0, a, select(C1, a, b)) -> select(C0|C1, a, b) if (FalseSI->getTrueValue() == TrueVal && FalseSI->hasOneUse()) { Value *Or = Builder.CreateLogicalOr(CondVal, FalseSI->getCondition()); @@ -3786,6 +4016,9 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) { } } + if (Instruction *I = foldSelectOfSymmetricSelect(SI, Builder)) + return I; + if (Instruction *I = foldNestedSelects(SI, Builder)) return I; @@ -3844,5 +4077,39 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) { } } + // select Cond, !X, X -> xor Cond, X + if (CondVal->getType() == SI.getType() && isKnownInversion(FalseVal, TrueVal)) + return BinaryOperator::CreateXor(CondVal, FalseVal); + + // For vectors, this transform is only safe if the simplification does not + // look through any lane-crossing operations. For now, limit to scalars only. + if (SelType->isIntegerTy() && + (!isa<Constant>(TrueVal) || !isa<Constant>(FalseVal))) { + // Try to simplify select arms based on KnownBits implied by the condition. + CondContext CC(CondVal); + findValuesAffectedByCondition(CondVal, /*IsAssume=*/false, [&](Value *V) { + CC.AffectedValues.insert(V); + }); + SimplifyQuery Q = SQ.getWithInstruction(&SI).getWithCondContext(CC); + if (!CC.AffectedValues.empty()) { + if (!isa<Constant>(TrueVal) && + hasAffectedValue(TrueVal, CC.AffectedValues, /*Depth=*/0)) { + KnownBits Known = llvm::computeKnownBits(TrueVal, /*Depth=*/0, Q); + if (Known.isConstant()) + return replaceOperand(SI, 1, + ConstantInt::get(SelType, Known.getConstant())); + } + + CC.Invert = true; + if (!isa<Constant>(FalseVal) && + hasAffectedValue(FalseVal, CC.AffectedValues, /*Depth=*/0)) { + KnownBits Known = llvm::computeKnownBits(FalseVal, /*Depth=*/0, Q); + if (Known.isConstant()) + return replaceOperand(SI, 2, + ConstantInt::get(SelType, Known.getConstant())); + } + } + } + return nullptr; } |