diff options
Diffstat (limited to 'lib/Transforms/InstCombine/InstCombineSelect.cpp')
-rw-r--r-- | lib/Transforms/InstCombine/InstCombineSelect.cpp | 323 |
1 files changed, 189 insertions, 134 deletions
diff --git a/lib/Transforms/InstCombine/InstCombineSelect.cpp b/lib/Transforms/InstCombine/InstCombineSelect.cpp index 796b4021d273..faf58a08976d 100644 --- a/lib/Transforms/InstCombine/InstCombineSelect.cpp +++ b/lib/Transforms/InstCombine/InstCombineSelect.cpp @@ -54,34 +54,62 @@ static Value *createMinMax(InstCombiner::BuilderTy &Builder, return Builder.CreateSelect(Builder.CreateICmp(Pred, A, B), A, B); } -/// Fold -/// %A = icmp eq/ne i8 %x, 0 -/// %B = op i8 %x, %z -/// %C = select i1 %A, i8 %B, i8 %y -/// To -/// %C = select i1 %A, i8 %z, i8 %y -/// OP: binop with an identity constant -/// TODO: support for non-commutative and FP opcodes -static Instruction *foldSelectBinOpIdentity(SelectInst &Sel) { - - Value *Cond = Sel.getCondition(); - Value *X, *Z; +/// Replace a select operand based on an equality comparison with the identity +/// constant of a binop. +static Instruction *foldSelectBinOpIdentity(SelectInst &Sel, + const TargetLibraryInfo &TLI) { + // The select condition must be an equality compare with a constant operand. + Value *X; Constant *C; CmpInst::Predicate Pred; - if (!match(Cond, m_ICmp(Pred, m_Value(X), m_Constant(C))) || - !ICmpInst::isEquality(Pred)) + if (!match(Sel.getCondition(), m_Cmp(Pred, m_Value(X), m_Constant(C)))) return nullptr; - bool IsEq = Pred == ICmpInst::ICMP_EQ; - auto *BO = - dyn_cast<BinaryOperator>(IsEq ? Sel.getTrueValue() : Sel.getFalseValue()); - // TODO: support for undefs - if (BO && match(BO, m_c_BinOp(m_Specific(X), m_Value(Z))) && - ConstantExpr::getBinOpIdentity(BO->getOpcode(), X->getType()) == C) { - Sel.setOperand(IsEq ? 1 : 2, Z); - return &Sel; + bool IsEq; + if (ICmpInst::isEquality(Pred)) + IsEq = Pred == ICmpInst::ICMP_EQ; + else if (Pred == FCmpInst::FCMP_OEQ) + IsEq = true; + else if (Pred == FCmpInst::FCMP_UNE) + IsEq = false; + else + return nullptr; + + // A select operand must be a binop. + BinaryOperator *BO; + if (!match(Sel.getOperand(IsEq ? 1 : 2), m_BinOp(BO))) + return nullptr; + + // The compare constant must be the identity constant for that binop. + // If this a floating-point compare with 0.0, any zero constant will do. + Type *Ty = BO->getType(); + Constant *IdC = ConstantExpr::getBinOpIdentity(BO->getOpcode(), Ty, true); + if (IdC != C) { + if (!IdC || !CmpInst::isFPPredicate(Pred)) + return nullptr; + if (!match(IdC, m_AnyZeroFP()) || !match(C, m_AnyZeroFP())) + return nullptr; } - return nullptr; + + // Last, match the compare variable operand with a binop operand. + Value *Y; + if (!BO->isCommutative() && !match(BO, m_BinOp(m_Value(Y), m_Specific(X)))) + return nullptr; + if (!match(BO, m_c_BinOp(m_Value(Y), m_Specific(X)))) + return nullptr; + + // +0.0 compares equal to -0.0, and so it does not behave as required for this + // transform. Bail out if we can not exclude that possibility. + if (isa<FPMathOperator>(BO)) + if (!BO->hasNoSignedZeros() && !CannotBeNegativeZero(Y, &TLI)) + return nullptr; + + // BO = binop Y, X + // S = { select (cmp eq X, C), BO, ? } or { select (cmp ne X, C), ?, BO } + // => + // S = { select (cmp eq X, C), Y, ? } or { select (cmp ne X, C), ?, Y } + Sel.setOperand(IsEq ? 1 : 2, Y); + return &Sel; } /// This folds: @@ -343,13 +371,24 @@ Instruction *InstCombiner::foldSelectOpOp(SelectInst &SI, Instruction *TI, return nullptr; } + // If the select condition is a vector, the operands of the original select's + // operands also must be vectors. This may not be the case for getelementptr + // for example. + if (SI.getCondition()->getType()->isVectorTy() && + (!OtherOpT->getType()->isVectorTy() || + !OtherOpF->getType()->isVectorTy())) + return nullptr; + // If we reach here, they do have operations in common. Value *NewSI = Builder.CreateSelect(SI.getCondition(), OtherOpT, OtherOpF, SI.getName() + ".v", &SI); Value *Op0 = MatchIsOpZero ? MatchOp : NewSI; Value *Op1 = MatchIsOpZero ? NewSI : MatchOp; if (auto *BO = dyn_cast<BinaryOperator>(TI)) { - return BinaryOperator::Create(BO->getOpcode(), Op0, Op1); + BinaryOperator *NewBO = BinaryOperator::Create(BO->getOpcode(), Op0, Op1); + NewBO->copyIRFlags(TI); + NewBO->andIRFlags(FI); + return NewBO; } if (auto *TGEP = dyn_cast<GetElementPtrInst>(TI)) { auto *FGEP = cast<GetElementPtrInst>(FI); @@ -670,17 +709,18 @@ static Value *foldSelectCttzCtlz(ICmpInst *ICI, Value *TrueVal, Value *FalseVal, match(Count, m_Trunc(m_Value(V)))) Count = V; + // Check that 'Count' is a call to intrinsic cttz/ctlz. Also check that the + // input to the cttz/ctlz is used as LHS for the compare instruction. + if (!match(Count, m_Intrinsic<Intrinsic::cttz>(m_Specific(CmpLHS))) && + !match(Count, m_Intrinsic<Intrinsic::ctlz>(m_Specific(CmpLHS)))) + return nullptr; + + IntrinsicInst *II = cast<IntrinsicInst>(Count); + // Check if the value propagated on zero is a constant number equal to the // sizeof in bits of 'Count'. unsigned SizeOfInBits = Count->getType()->getScalarSizeInBits(); - if (!match(ValueOnZero, m_SpecificInt(SizeOfInBits))) - return nullptr; - - // Check that 'Count' is a call to intrinsic cttz/ctlz. Also check that the - // input to the cttz/ctlz is used as LHS for the compare instruction. - if (match(Count, m_Intrinsic<Intrinsic::cttz>(m_Specific(CmpLHS))) || - match(Count, m_Intrinsic<Intrinsic::ctlz>(m_Specific(CmpLHS)))) { - IntrinsicInst *II = cast<IntrinsicInst>(Count); + if (match(ValueOnZero, m_SpecificInt(SizeOfInBits))) { // Explicitly clear the 'undef_on_zero' flag. IntrinsicInst *NewI = cast<IntrinsicInst>(II->clone()); NewI->setArgOperand(1, ConstantInt::getFalse(NewI->getContext())); @@ -688,6 +728,12 @@ static Value *foldSelectCttzCtlz(ICmpInst *ICI, Value *TrueVal, Value *FalseVal, return Builder.CreateZExtOrTrunc(NewI, ValueOnZero->getType()); } + // If the ValueOnZero is not the bitwidth, we can at least make use of the + // fact that the cttz/ctlz result will not be used if the input is zero, so + // it's okay to relax it to undef for that case. + if (II->hasOneUse() && !match(II->getArgOperand(1), m_One())) + II->setArgOperand(1, ConstantInt::getTrue(II->getContext())); + return nullptr; } @@ -1054,11 +1100,13 @@ Instruction *InstCombiner::foldSPFofSPF(Instruction *Inner, if (C == A || C == B) { // MAX(MAX(A, B), B) -> MAX(A, B) // MIN(MIN(a, b), a) -> MIN(a, b) + // TODO: This could be done in instsimplify. if (SPF1 == SPF2 && SelectPatternResult::isMinOrMax(SPF1)) return replaceInstUsesWith(Outer, Inner); // MAX(MIN(a, b), a) -> a // MIN(MAX(a, b), a) -> a + // TODO: This could be done in instsimplify. if ((SPF1 == SPF_SMIN && SPF2 == SPF_SMAX) || (SPF1 == SPF_SMAX && SPF2 == SPF_SMIN) || (SPF1 == SPF_UMIN && SPF2 == SPF_UMAX) || @@ -1071,6 +1119,7 @@ Instruction *InstCombiner::foldSPFofSPF(Instruction *Inner, if (match(B, m_APInt(CB)) && match(C, m_APInt(CC))) { // MIN(MIN(A, 23), 97) -> MIN(A, 23) // MAX(MAX(A, 97), 23) -> MAX(A, 97) + // TODO: This could be done in instsimplify. if ((SPF1 == SPF_UMIN && CB->ule(*CC)) || (SPF1 == SPF_SMIN && CB->sle(*CC)) || (SPF1 == SPF_UMAX && CB->uge(*CC)) || @@ -1091,6 +1140,7 @@ Instruction *InstCombiner::foldSPFofSPF(Instruction *Inner, // ABS(ABS(X)) -> ABS(X) // NABS(NABS(X)) -> NABS(X) + // TODO: This could be done in instsimplify. if (SPF1 == SPF2 && (SPF1 == SPF_ABS || SPF1 == SPF_NABS)) { return replaceInstUsesWith(Outer, Inner); } @@ -1503,6 +1553,60 @@ static Instruction *factorizeMinMaxTree(SelectPatternFlavor SPF, Value *LHS, return SelectInst::Create(CmpABC, MinMaxOp, ThirdOp); } +/// Try to reduce a rotate pattern that includes a compare and select into a +/// funnel shift intrinsic. Example: +/// rotl32(a, b) --> (b == 0 ? a : ((a >> (32 - b)) | (a << b))) +/// --> call llvm.fshl.i32(a, a, b) +static Instruction *foldSelectRotate(SelectInst &Sel) { + // The false value of the select must be a rotate of the true value. + Value *Or0, *Or1; + if (!match(Sel.getFalseValue(), m_OneUse(m_Or(m_Value(Or0), m_Value(Or1))))) + return nullptr; + + Value *TVal = Sel.getTrueValue(); + Value *SA0, *SA1; + if (!match(Or0, m_OneUse(m_LogicalShift(m_Specific(TVal), m_Value(SA0)))) || + !match(Or1, m_OneUse(m_LogicalShift(m_Specific(TVal), m_Value(SA1))))) + return nullptr; + + auto ShiftOpcode0 = cast<BinaryOperator>(Or0)->getOpcode(); + auto ShiftOpcode1 = cast<BinaryOperator>(Or1)->getOpcode(); + if (ShiftOpcode0 == ShiftOpcode1) + return nullptr; + + // We have one of these patterns so far: + // select ?, TVal, (or (lshr TVal, SA0), (shl TVal, SA1)) + // select ?, TVal, (or (shl TVal, SA0), (lshr TVal, SA1)) + // This must be a power-of-2 rotate for a bitmasking transform to be valid. + unsigned Width = Sel.getType()->getScalarSizeInBits(); + if (!isPowerOf2_32(Width)) + return nullptr; + + // Check the shift amounts to see if they are an opposite pair. + Value *ShAmt; + if (match(SA1, m_OneUse(m_Sub(m_SpecificInt(Width), m_Specific(SA0))))) + ShAmt = SA0; + else if (match(SA0, m_OneUse(m_Sub(m_SpecificInt(Width), m_Specific(SA1))))) + ShAmt = SA1; + else + return nullptr; + + // Finally, see if the select is filtering out a shift-by-zero. + Value *Cond = Sel.getCondition(); + ICmpInst::Predicate Pred; + if (!match(Cond, m_OneUse(m_ICmp(Pred, m_Specific(ShAmt), m_ZeroInt()))) || + Pred != ICmpInst::ICMP_EQ) + return nullptr; + + // This is a rotate that avoids shift-by-bitwidth UB in a suboptimal way. + // Convert to funnel shift intrinsic. + bool IsFshl = (ShAmt == SA0 && ShiftOpcode0 == BinaryOperator::Shl) || + (ShAmt == SA1 && ShiftOpcode1 == BinaryOperator::Shl); + Intrinsic::ID IID = IsFshl ? Intrinsic::fshl : Intrinsic::fshr; + Function *F = Intrinsic::getDeclaration(Sel.getModule(), IID, Sel.getType()); + return IntrinsicInst::Create(F, { TVal, TVal, ShAmt }); +} + Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { Value *CondVal = SI.getCondition(); Value *TrueVal = SI.getTrueValue(); @@ -1617,31 +1721,6 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { // See if we are selecting two values based on a comparison of the two values. if (FCmpInst *FCI = dyn_cast<FCmpInst>(CondVal)) { if (FCI->getOperand(0) == TrueVal && FCI->getOperand(1) == FalseVal) { - // Transform (X == Y) ? X : Y -> Y - if (FCI->getPredicate() == FCmpInst::FCMP_OEQ) { - // This is not safe in general for floating point: - // consider X== -0, Y== +0. - // It becomes safe if either operand is a nonzero constant. - ConstantFP *CFPt, *CFPf; - if (((CFPt = dyn_cast<ConstantFP>(TrueVal)) && - !CFPt->getValueAPF().isZero()) || - ((CFPf = dyn_cast<ConstantFP>(FalseVal)) && - !CFPf->getValueAPF().isZero())) - return replaceInstUsesWith(SI, FalseVal); - } - // Transform (X une Y) ? X : Y -> X - if (FCI->getPredicate() == FCmpInst::FCMP_UNE) { - // This is not safe in general for floating point: - // consider X== -0, Y== +0. - // It becomes safe if either operand is a nonzero constant. - ConstantFP *CFPt, *CFPf; - if (((CFPt = dyn_cast<ConstantFP>(TrueVal)) && - !CFPt->getValueAPF().isZero()) || - ((CFPf = dyn_cast<ConstantFP>(FalseVal)) && - !CFPf->getValueAPF().isZero())) - return replaceInstUsesWith(SI, TrueVal); - } - // Canonicalize to use ordered comparisons by swapping the select // operands. // @@ -1660,31 +1739,6 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { // NOTE: if we wanted to, this is where to detect MIN/MAX } else if (FCI->getOperand(0) == FalseVal && FCI->getOperand(1) == TrueVal){ - // Transform (X == Y) ? Y : X -> X - if (FCI->getPredicate() == FCmpInst::FCMP_OEQ) { - // This is not safe in general for floating point: - // consider X== -0, Y== +0. - // It becomes safe if either operand is a nonzero constant. - ConstantFP *CFPt, *CFPf; - if (((CFPt = dyn_cast<ConstantFP>(TrueVal)) && - !CFPt->getValueAPF().isZero()) || - ((CFPf = dyn_cast<ConstantFP>(FalseVal)) && - !CFPf->getValueAPF().isZero())) - return replaceInstUsesWith(SI, FalseVal); - } - // Transform (X une Y) ? Y : X -> Y - if (FCI->getPredicate() == FCmpInst::FCMP_UNE) { - // This is not safe in general for floating point: - // consider X== -0, Y== +0. - // It becomes safe if either operand is a nonzero constant. - ConstantFP *CFPt, *CFPf; - if (((CFPt = dyn_cast<ConstantFP>(TrueVal)) && - !CFPt->getValueAPF().isZero()) || - ((CFPf = dyn_cast<ConstantFP>(FalseVal)) && - !CFPf->getValueAPF().isZero())) - return replaceInstUsesWith(SI, TrueVal); - } - // Canonicalize to use ordered comparisons by swapping the select // operands. // @@ -1717,7 +1771,7 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { match(TrueVal, m_FSub(m_PosZeroFP(), m_Specific(X)))) || (X == TrueVal && Pred == FCmpInst::FCMP_OGT && match(FalseVal, m_FSub(m_PosZeroFP(), m_Specific(X))))) { - Value *Fabs = Builder.CreateIntrinsic(Intrinsic::fabs, { X }, FCI); + Value *Fabs = Builder.CreateUnaryIntrinsic(Intrinsic::fabs, X, FCI); return replaceInstUsesWith(SI, Fabs); } // With nsz: @@ -1730,7 +1784,7 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { (Pred == FCmpInst::FCMP_OLT || Pred == FCmpInst::FCMP_OLE)) || (X == TrueVal && match(FalseVal, m_FNeg(m_Specific(X))) && (Pred == FCmpInst::FCMP_OGT || Pred == FCmpInst::FCMP_OGE)))) { - Value *Fabs = Builder.CreateIntrinsic(Intrinsic::fabs, { X }, FCI); + Value *Fabs = Builder.CreateUnaryIntrinsic(Intrinsic::fabs, X, FCI); return replaceInstUsesWith(SI, Fabs); } } @@ -1759,10 +1813,23 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { if (Instruction *FoldI = foldSelectIntoOp(SI, TrueVal, FalseVal)) return FoldI; - Value *LHS, *RHS, *LHS2, *RHS2; + Value *LHS, *RHS; Instruction::CastOps CastOp; SelectPatternResult SPR = matchSelectPattern(&SI, LHS, RHS, &CastOp); auto SPF = SPR.Flavor; + if (SPF) { + Value *LHS2, *RHS2; + if (SelectPatternFlavor SPF2 = matchSelectPattern(LHS, LHS2, RHS2).Flavor) + if (Instruction *R = foldSPFofSPF(cast<Instruction>(LHS), SPF2, LHS2, + RHS2, SI, SPF, RHS)) + return R; + if (SelectPatternFlavor SPF2 = matchSelectPattern(RHS, LHS2, RHS2).Flavor) + if (Instruction *R = foldSPFofSPF(cast<Instruction>(RHS), SPF2, LHS2, + RHS2, SI, SPF, LHS)) + return R; + // TODO. + // ABS(-X) -> ABS(X) + } if (SelectPatternResult::isMinOrMax(SPF)) { // Canonicalize so that @@ -1797,39 +1864,40 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { } // MAX(~a, ~b) -> ~MIN(a, b) + // MAX(~a, C) -> ~MIN(a, ~C) // MIN(~a, ~b) -> ~MAX(a, b) - Value *A, *B; - if (match(LHS, m_Not(m_Value(A))) && match(RHS, m_Not(m_Value(B))) && - (LHS->getNumUses() <= 2 || RHS->getNumUses() <= 2)) { - CmpInst::Predicate InvertedPred = getInverseMinMaxPred(SPF); - Value *InvertedCmp = Builder.CreateICmp(InvertedPred, A, B); - Value *NewSel = Builder.CreateSelect(InvertedCmp, A, B); - return BinaryOperator::CreateNot(NewSel); - } + // MIN(~a, C) -> ~MAX(a, ~C) + auto moveNotAfterMinMax = [&](Value *X, Value *Y) -> Instruction * { + Value *A; + if (match(X, m_Not(m_Value(A))) && !X->hasNUsesOrMore(3) && + !IsFreeToInvert(A, A->hasOneUse()) && + // Passing false to only consider m_Not and constants. + IsFreeToInvert(Y, false)) { + Value *B = Builder.CreateNot(Y); + Value *NewMinMax = createMinMax(Builder, getInverseMinMaxFlavor(SPF), + A, B); + // Copy the profile metadata. + if (MDNode *MD = SI.getMetadata(LLVMContext::MD_prof)) { + cast<SelectInst>(NewMinMax)->setMetadata(LLVMContext::MD_prof, MD); + // Swap the metadata if the operands are swapped. + if (X == SI.getFalseValue() && Y == SI.getTrueValue()) + cast<SelectInst>(NewMinMax)->swapProfMetadata(); + } - if (Instruction *I = factorizeMinMaxTree(SPF, LHS, RHS, Builder)) + return BinaryOperator::CreateNot(NewMinMax); + } + + return nullptr; + }; + + if (Instruction *I = moveNotAfterMinMax(LHS, RHS)) + return I; + if (Instruction *I = moveNotAfterMinMax(RHS, LHS)) return I; - } - if (SPF) { - // MAX(MAX(a, b), a) -> MAX(a, b) - // MIN(MIN(a, b), a) -> MIN(a, b) - // MAX(MIN(a, b), a) -> a - // MIN(MAX(a, b), a) -> a - // ABS(ABS(a)) -> ABS(a) - // NABS(NABS(a)) -> NABS(a) - if (SelectPatternFlavor SPF2 = matchSelectPattern(LHS, LHS2, RHS2).Flavor) - if (Instruction *R = foldSPFofSPF(cast<Instruction>(LHS),SPF2,LHS2,RHS2, - SI, SPF, RHS)) - return R; - if (SelectPatternFlavor SPF2 = matchSelectPattern(RHS, LHS2, RHS2).Flavor) - if (Instruction *R = foldSPFofSPF(cast<Instruction>(RHS),SPF2,LHS2,RHS2, - SI, SPF, LHS)) - return R; + if (Instruction *I = factorizeMinMaxTree(SPF, LHS, RHS, Builder)) + return I; } - - // TODO. - // ABS(-X) -> ABS(X) } // See if we can fold the select into a phi node if the condition is a select. @@ -1934,10 +2002,12 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { } } - if (BinaryOperator::isNot(CondVal)) { - SI.setOperand(0, BinaryOperator::getNotArgument(CondVal)); + Value *NotCond; + if (match(CondVal, m_Not(m_Value(NotCond)))) { + SI.setOperand(0, NotCond); SI.setOperand(1, FalseVal); SI.setOperand(2, TrueVal); + SI.swapProfMetadata(); return &SI; } @@ -1952,24 +2022,6 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { } } - // See if we can determine the result of this select based on a dominating - // condition. - BasicBlock *Parent = SI.getParent(); - if (BasicBlock *Dom = Parent->getSinglePredecessor()) { - auto *PBI = dyn_cast_or_null<BranchInst>(Dom->getTerminator()); - if (PBI && PBI->isConditional() && - PBI->getSuccessor(0) != PBI->getSuccessor(1) && - (PBI->getSuccessor(0) == Parent || PBI->getSuccessor(1) == Parent)) { - bool CondIsTrue = PBI->getSuccessor(0) == Parent; - Optional<bool> Implication = isImpliedCondition( - PBI->getCondition(), SI.getCondition(), DL, CondIsTrue); - if (Implication) { - Value *V = *Implication ? TrueVal : FalseVal; - return replaceInstUsesWith(SI, V); - } - } - } - // If we can compute the condition, there's no need for a select. // Like the above fold, we are attempting to reduce compile-time cost by // putting this fold here with limitations rather than in InstSimplify. @@ -1991,8 +2043,11 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { if (Instruction *Select = foldSelectCmpXchg(SI)) return Select; - if (Instruction *Select = foldSelectBinOpIdentity(SI)) + if (Instruction *Select = foldSelectBinOpIdentity(SI, TLI)) return Select; + if (Instruction *Rot = foldSelectRotate(SI)) + return Rot; + return nullptr; } |