Diffstat (limited to 'lib/Transforms/InstCombine/InstCombineCompares.cpp')
-rw-r--r--   lib/Transforms/InstCombine/InstCombineCompares.cpp | 643
1 file changed, 416 insertions(+), 227 deletions(-)
diff --git a/lib/Transforms/InstCombine/InstCombineCompares.cpp b/lib/Transforms/InstCombine/InstCombineCompares.cpp
index b5bbb09935e2..3a4283ae5406 100644
--- a/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -1,9 +1,8 @@
 //===- InstCombineCompares.cpp --------------------------------------------===//
 //
-// The LLVM Compiler Infrastructure
-//
-// This file is distributed under the University of Illinois Open Source
-// License. See LICENSE.TXT for details.
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
 //===----------------------------------------------------------------------===//
 //
@@ -704,7 +703,10 @@ static Value *rewriteGEPAsOffset(Value *Start, Value *Base,
       continue;
 
     if (auto *CI = dyn_cast<CastInst>(Val)) {
-      NewInsts[CI] = NewInsts[CI->getOperand(0)];
+      // Don't get rid of the intermediate variable here; the store can grow
+      // the map which will invalidate the reference to the input value.
+      Value *V = NewInsts[CI->getOperand(0)];
+      NewInsts[CI] = V;
       continue;
     }
     if (auto *GEP = dyn_cast<GEPOperator>(Val)) {
@@ -1292,8 +1294,8 @@ static Instruction *processUGT_ADDCST_ADD(ICmpInst &I, Value *A, Value *B,
   // use the sadd_with_overflow intrinsic to efficiently compute both the
   // result and the overflow bit.
   Type *NewType = IntegerType::get(OrigAdd->getContext(), NewWidth);
-  Value *F = Intrinsic::getDeclaration(I.getModule(),
-                                       Intrinsic::sadd_with_overflow, NewType);
+  Function *F = Intrinsic::getDeclaration(
+      I.getModule(), Intrinsic::sadd_with_overflow, NewType);
 
   InstCombiner::BuilderTy &Builder = IC.Builder;
 
@@ -1315,14 +1317,16 @@ static Instruction *processUGT_ADDCST_ADD(ICmpInst &I, Value *A, Value *B,
   return ExtractValueInst::Create(Call, 1, "sadd.overflow");
 }
 
-// Handle (icmp sgt smin(PosA, B) 0) -> (icmp sgt B 0)
+// Handle icmp pred X, 0
 Instruction *InstCombiner::foldICmpWithZero(ICmpInst &Cmp) {
   CmpInst::Predicate Pred = Cmp.getPredicate();
-  Value *X = Cmp.getOperand(0);
+  if (!match(Cmp.getOperand(1), m_Zero()))
+    return nullptr;
 
-  if (match(Cmp.getOperand(1), m_Zero()) && Pred == ICmpInst::ICMP_SGT) {
+  // (icmp sgt smin(PosA, B) 0) -> (icmp sgt B 0)
+  if (Pred == ICmpInst::ICMP_SGT) {
     Value *A, *B;
-    SelectPatternResult SPR = matchSelectPattern(X, A, B);
+    SelectPatternResult SPR = matchSelectPattern(Cmp.getOperand(0), A, B);
     if (SPR.Flavor == SPF_SMIN) {
       if (isKnownPositive(A, DL, 0, &AC, &Cmp, &DT))
         return new ICmpInst(Pred, B, Cmp.getOperand(1));
@@ -1330,6 +1334,20 @@ Instruction *InstCombiner::foldICmpWithZero(ICmpInst &Cmp) {
       return new ICmpInst(Pred, A, Cmp.getOperand(1));
     }
   }
+
+  // Given:
+  //   icmp eq/ne (urem %x, %y), 0
+  // Iff %x has 0 or 1 bits set, and %y has at least 2 bits set, omit 'urem':
+  //   icmp eq/ne %x, 0
+  Value *X, *Y;
+  if (match(Cmp.getOperand(0), m_URem(m_Value(X), m_Value(Y))) &&
+      ICmpInst::isEquality(Pred)) {
+    KnownBits XKnown = computeKnownBits(X, 0, &Cmp);
+    KnownBits YKnown = computeKnownBits(Y, 0, &Cmp);
+    if (XKnown.countMaxPopulation() == 1 && YKnown.countMinPopulation() >= 2)
+      return new ICmpInst(Pred, X, Cmp.getOperand(1));
+  }
+
   return nullptr;
 }
 
@@ -1624,20 +1642,43 @@ Instruction *InstCombiner::foldICmpAndShift(ICmpInst &Cmp, BinaryOperator *And,
 Instruction *InstCombiner::foldICmpAndConstConst(ICmpInst &Cmp,
                                                  BinaryOperator *And,
                                                  const APInt &C1) {
+  bool isICMP_NE = Cmp.getPredicate() == ICmpInst::ICMP_NE;
+
   // For vectors: icmp ne (and X, 1), 0 --> trunc X to N x i1
   // TODO: We canonicalize to the longer form for scalars because we have
   // better analysis/folds for icmp, and codegen may be better with icmp.
-  if (Cmp.getPredicate() == CmpInst::ICMP_NE && Cmp.getType()->isVectorTy() &&
-      C1.isNullValue() && match(And->getOperand(1), m_One()))
+  if (isICMP_NE && Cmp.getType()->isVectorTy() && C1.isNullValue() &&
+      match(And->getOperand(1), m_One()))
     return new TruncInst(And->getOperand(0), Cmp.getType());
 
   const APInt *C2;
-  if (!match(And->getOperand(1), m_APInt(C2)))
+  Value *X;
+  if (!match(And, m_And(m_Value(X), m_APInt(C2))))
     return nullptr;
 
+  // Don't perform the following transforms if the AND has multiple uses
   if (!And->hasOneUse())
     return nullptr;
 
+  if (Cmp.isEquality() && C1.isNullValue()) {
+    // Restrict this fold to single-use 'and' (PR10267).
+    // Replace (and X, (1 << size(X)-1) != 0) with X s< 0
+    if (C2->isSignMask()) {
+      Constant *Zero = Constant::getNullValue(X->getType());
+      auto NewPred = isICMP_NE ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_SGE;
+      return new ICmpInst(NewPred, X, Zero);
+    }
+
+    // Restrict this fold only for single-use 'and' (PR10267).
+    // ((%x & C) == 0) --> %x u< (-C) iff (-C) is power of two.
+    if ((~(*C2) + 1).isPowerOf2()) {
+      Constant *NegBOC =
+          ConstantExpr::getNeg(cast<Constant>(And->getOperand(1)));
+      auto NewPred = isICMP_NE ? ICmpInst::ICMP_UGE : ICmpInst::ICMP_ULT;
+      return new ICmpInst(NewPred, X, NegBOC);
+    }
+  }
+
   // If the LHS is an 'and' of a truncate and we can widen the and/compare to
   // the input width without changing the value produced, eliminate the cast:
   //
@@ -1772,13 +1813,22 @@ Instruction *InstCombiner::foldICmpOrConstant(ICmpInst &Cmp, BinaryOperator *Or,
                         ConstantInt::get(V->getType(), 1));
   }
 
-  // X | C == C --> X <=u C
-  // X | C != C --> X >u C
-  //   iff C+1 is a power of 2 (C is a bitmask of the low bits)
-  if (Cmp.isEquality() && Cmp.getOperand(1) == Or->getOperand(1) &&
-      (C + 1).isPowerOf2()) {
-    Pred = (Pred == CmpInst::ICMP_EQ) ? CmpInst::ICMP_ULE : CmpInst::ICMP_UGT;
-    return new ICmpInst(Pred, Or->getOperand(0), Or->getOperand(1));
+  Value *OrOp0 = Or->getOperand(0), *OrOp1 = Or->getOperand(1);
+  if (Cmp.isEquality() && Cmp.getOperand(1) == OrOp1) {
+    // X | C == C --> X <=u C
+    // X | C != C --> X >u C
+    //   iff C+1 is a power of 2 (C is a bitmask of the low bits)
+    if ((C + 1).isPowerOf2()) {
+      Pred = (Pred == CmpInst::ICMP_EQ) ? CmpInst::ICMP_ULE : CmpInst::ICMP_UGT;
+      return new ICmpInst(Pred, OrOp0, OrOp1);
+    }
+    // More general: are all bits outside of a mask constant set or not set?
+    // X | C == C --> (X & ~C) == 0
+    // X | C != C --> (X & ~C) != 0
+    if (Or->hasOneUse()) {
+      Value *A = Builder.CreateAnd(OrOp0, ~C);
+      return new ICmpInst(Pred, A, ConstantInt::getNullValue(OrOp0->getType()));
+    }
   }
 
   if (!Cmp.isEquality() || !C.isNullValue() || !Or->hasOneUse())
@@ -1799,8 +1849,8 @@ Instruction *InstCombiner::foldICmpOrConstant(ICmpInst &Cmp, BinaryOperator *Or,
   // Are we using xors to bitwise check for a pair of (in)equalities? Convert to
   // a shorter form that has more potential to be folded even further.
   Value *X1, *X2, *X3, *X4;
-  if (match(Or->getOperand(0), m_OneUse(m_Xor(m_Value(X1), m_Value(X2)))) &&
-      match(Or->getOperand(1), m_OneUse(m_Xor(m_Value(X3), m_Value(X4))))) {
+  if (match(OrOp0, m_OneUse(m_Xor(m_Value(X1), m_Value(X2)))) &&
+      match(OrOp1, m_OneUse(m_Xor(m_Value(X3), m_Value(X4))))) {
     // ((X1 ^ X2) || (X3 ^ X4)) == 0 --> (X1 == X2) && (X3 == X4)
     // ((X1 ^ X2) || (X3 ^ X4)) != 0 --> (X1 != X2) || (X3 != X4)
    Value *Cmp12 = Builder.CreateICmp(Pred, X1, X2);
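To make the new 'urem' fold in foldICmpWithZero above concrete, here is a hypothetical IR sample (names invented, not taken from the patch's tests); %x can have at most one bit set and 3 has two bits set, so the fold's precondition holds:

  define i1 @urem_is_zero(i8 %a) {
    %x = and i8 %a, 16      ; at most one bit (bit 4) can be set
    %r = urem i8 %x, 3      ; 3 has two bits set
    %c = icmp eq i8 %r, 0   ; expected to become: icmp eq i8 %x, 0
    ret i1 %c
  }

A value with at most one bit set is a power of two or zero, and no power of two is divisible by a number with two or more bits set, so the remainder is zero exactly when %x is zero.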
@@ -1994,6 +2044,27 @@ Instruction *InstCombiner::foldICmpShlConstant(ICmpInst &Cmp,
                         And, Constant::getNullValue(ShType));
   }
 
+  // Simplify 'shl' inequality test into 'and' equality test.
+  if (Cmp.isUnsigned() && Shl->hasOneUse()) {
+    // (X l<< C2) u<=/u> C1 iff C1+1 is power of two -> X & (~C1 l>> C2) ==/!= 0
+    if ((C + 1).isPowerOf2() &&
+        (Pred == ICmpInst::ICMP_ULE || Pred == ICmpInst::ICMP_UGT)) {
+      Value *And = Builder.CreateAnd(X, (~C).lshr(ShiftAmt->getZExtValue()));
+      return new ICmpInst(Pred == ICmpInst::ICMP_ULE ? ICmpInst::ICMP_EQ
+                                                     : ICmpInst::ICMP_NE,
+                          And, Constant::getNullValue(ShType));
+    }
+    // (X l<< C2) u</u>= C1 iff C1 is power of two -> X & (-C1 l>> C2) ==/!= 0
+    if (C.isPowerOf2() &&
+        (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_UGE)) {
+      Value *And =
+          Builder.CreateAnd(X, (~(C - 1)).lshr(ShiftAmt->getZExtValue()));
+      return new ICmpInst(Pred == ICmpInst::ICMP_ULT ? ICmpInst::ICMP_EQ
+                                                     : ICmpInst::ICMP_NE,
+                          And, Constant::getNullValue(ShType));
+    }
+  }
+
   // Transform (icmp pred iM (shl iM %v, N), C)
   // -> (icmp pred i(M-N) (trunc %v iM to i(M-N)), (trunc (C>>N))
   // Transform the shl to a trunc if (trunc (C>>N)) has no loss and M-N.
@@ -2313,6 +2384,16 @@ Instruction *InstCombiner::foldICmpSubConstant(ICmpInst &Cmp,
                                                const APInt &C) {
   Value *X = Sub->getOperand(0), *Y = Sub->getOperand(1);
   ICmpInst::Predicate Pred = Cmp.getPredicate();
 
+  const APInt *C2;
+  APInt SubResult;
+
+  // (icmp P (sub nuw|nsw C2, Y), C) -> (icmp swap(P) Y, C2-C)
+  if (match(X, m_APInt(C2)) &&
+      ((Cmp.isUnsigned() && Sub->hasNoUnsignedWrap()) ||
+       (Cmp.isSigned() && Sub->hasNoSignedWrap())) &&
+      !subWithOverflow(SubResult, *C2, C, Cmp.isSigned()))
+    return new ICmpInst(Cmp.getSwappedPredicate(), Y,
+                        ConstantInt::get(Y->getType(), SubResult));
+
   // The following transforms are only worth it if the only user of the subtract
   // is the icmp.
@@ -2337,7 +2418,6 @@ Instruction *InstCombiner::foldICmpSubConstant(ICmpInst &Cmp,
     return new ICmpInst(ICmpInst::ICMP_SLE, X, Y);
   }
 
-  const APInt *C2;
   if (!match(X, m_APInt(C2)))
     return nullptr;
 
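A sketch of the first new shl rewrite above (hypothetical i8 sample; the expected output follows the masks the code computes):

  define i1 @shl_ult(i8 %x) {
    %s = shl i8 %x, 3
    %c = icmp ult i8 %s, 8   ; expected: %a = and i8 %x, 31
    ret i1 %c                ;           %c = icmp eq i8 %a, 0
  }

(%x << 3) is below 8 exactly when every bit of %x that survives the shift is zero, i.e. when %x & 31 == 0 — an 'and' plus an equality test, which tends to fold further.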
@@ -2482,20 +2562,76 @@ Instruction *InstCombiner::foldICmpSelectConstant(ICmpInst &Cmp,
   // the entire original Cmp can be simplified to a false.
   Value *Cond = Builder.getFalse();
   if (TrueWhenLessThan)
-    Cond = Builder.CreateOr(Cond, Builder.CreateICmp(ICmpInst::ICMP_SLT, OrigLHS, OrigRHS));
+    Cond = Builder.CreateOr(Cond, Builder.CreateICmp(ICmpInst::ICMP_SLT,
+                                                     OrigLHS, OrigRHS));
   if (TrueWhenEqual)
-    Cond = Builder.CreateOr(Cond, Builder.CreateICmp(ICmpInst::ICMP_EQ, OrigLHS, OrigRHS));
+    Cond = Builder.CreateOr(Cond, Builder.CreateICmp(ICmpInst::ICMP_EQ,
+                                                     OrigLHS, OrigRHS));
   if (TrueWhenGreaterThan)
-    Cond = Builder.CreateOr(Cond, Builder.CreateICmp(ICmpInst::ICMP_SGT, OrigLHS, OrigRHS));
+    Cond = Builder.CreateOr(Cond, Builder.CreateICmp(ICmpInst::ICMP_SGT,
+                                                     OrigLHS, OrigRHS));
 
   return replaceInstUsesWith(Cmp, Cond);
 }
 
-Instruction *InstCombiner::foldICmpBitCastConstant(ICmpInst &Cmp,
-                                                   BitCastInst *Bitcast,
-                                                   const APInt &C) {
+static Instruction *foldICmpBitCast(ICmpInst &Cmp,
+                                    InstCombiner::BuilderTy &Builder) {
+  auto *Bitcast = dyn_cast<BitCastInst>(Cmp.getOperand(0));
+  if (!Bitcast)
+    return nullptr;
+
+  ICmpInst::Predicate Pred = Cmp.getPredicate();
+  Value *Op1 = Cmp.getOperand(1);
+  Value *BCSrcOp = Bitcast->getOperand(0);
+
+  // Make sure the bitcast doesn't change the number of vector elements.
+  if (Bitcast->getSrcTy()->getScalarSizeInBits() ==
+          Bitcast->getDestTy()->getScalarSizeInBits()) {
+    // Zero-equality and sign-bit checks are preserved through sitofp + bitcast.
+    Value *X;
+    if (match(BCSrcOp, m_SIToFP(m_Value(X)))) {
+      // icmp  eq (bitcast (sitofp X)), 0 --> icmp  eq X, 0
+      // icmp  ne (bitcast (sitofp X)), 0 --> icmp  ne X, 0
+      // icmp slt (bitcast (sitofp X)), 0 --> icmp slt X, 0
+      // icmp sgt (bitcast (sitofp X)), 0 --> icmp sgt X, 0
+      if ((Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_SLT ||
+           Pred == ICmpInst::ICMP_NE || Pred == ICmpInst::ICMP_SGT) &&
+          match(Op1, m_Zero()))
+        return new ICmpInst(Pred, X, ConstantInt::getNullValue(X->getType()));
+
+      // icmp slt (bitcast (sitofp X)), 1 --> icmp slt X, 1
+      if (Pred == ICmpInst::ICMP_SLT && match(Op1, m_One()))
+        return new ICmpInst(Pred, X, ConstantInt::get(X->getType(), 1));
+
+      // icmp sgt (bitcast (sitofp X)), -1 --> icmp sgt X, -1
+      if (Pred == ICmpInst::ICMP_SGT && match(Op1, m_AllOnes()))
+        return new ICmpInst(Pred, X,
+                            ConstantInt::getAllOnesValue(X->getType()));
+    }
+
+    // Zero-equality checks are preserved through unsigned floating-point casts:
+    // icmp eq (bitcast (uitofp X)), 0 --> icmp eq X, 0
+    // icmp ne (bitcast (uitofp X)), 0 --> icmp ne X, 0
+    if (match(BCSrcOp, m_UIToFP(m_Value(X))))
+      if (Cmp.isEquality() && match(Op1, m_Zero()))
+        return new ICmpInst(Pred, X, ConstantInt::getNullValue(X->getType()));
+  }
+
+  // Test to see if the operands of the icmp are casted versions of other
+  // values. If the ptr->ptr cast can be stripped off both arguments, do so.
+  if (Bitcast->getType()->isPointerTy() &&
+      (isa<Constant>(Op1) || isa<BitCastInst>(Op1))) {
+    // If operand #1 is a bitcast instruction, it must also be a ptr->ptr cast
+    // so eliminate it as well.
+    if (auto *BC2 = dyn_cast<BitCastInst>(Op1))
+      Op1 = BC2->getOperand(0);
+
+    Op1 = Builder.CreateBitCast(Op1, BCSrcOp->getType());
+    return new ICmpInst(Pred, BCSrcOp, Op1);
+  }
+
+  // Folding: icmp <pred> iN X, C
+  //  where X = bitcast <M x iK> (shufflevector <M x iK> %vec, undef, SC)) to iN
+  //    and C is a splat of a K-bit pattern
@@ -2503,28 +2639,28 @@ Instruction *InstCombiner::foldICmpBitCastConstant(ICmpInst &Cmp,
   //    Into:
   //   %E = extractelement <M x iK> %vec, i32 C'
   //   icmp <pred> iK %E, trunc(C)
-  if (!Bitcast->getType()->isIntegerTy() ||
+  const APInt *C;
+  if (!match(Cmp.getOperand(1), m_APInt(C)) ||
+      !Bitcast->getType()->isIntegerTy() ||
       !Bitcast->getSrcTy()->isIntOrIntVectorTy())
     return nullptr;
 
-  Value *BCIOp = Bitcast->getOperand(0);
-  Value *Vec = nullptr;     // 1st vector arg of the shufflevector
-  Constant *Mask = nullptr; // Mask arg of the shufflevector
-  if (match(BCIOp,
+  Value *Vec;
+  Constant *Mask;
+  if (match(BCSrcOp,
            m_ShuffleVector(m_Value(Vec), m_Undef(), m_Constant(Mask)))) {
     // Check whether every element of Mask is the same constant
     if (auto *Elem = dyn_cast_or_null<ConstantInt>(Mask->getSplatValue())) {
-      auto *VecTy = cast<VectorType>(BCIOp->getType());
+      auto *VecTy = cast<VectorType>(BCSrcOp->getType());
       auto *EltTy = cast<IntegerType>(VecTy->getElementType());
-      auto Pred = Cmp.getPredicate();
-      if (C.isSplat(EltTy->getBitWidth())) {
+      if (C->isSplat(EltTy->getBitWidth())) {
         // Fold the icmp based on the value of C
        // If C is M copies of an iK sized bit pattern,
         // then:
         //   =>  %E = extractelement <N x iK> %vec, i32 Elem
         //       icmp <pred> iK %SplatVal, <pattern>
         Value *Extract = Builder.CreateExtractElement(Vec, Elem);
-        Value *NewC = ConstantInt::get(EltTy, C.trunc(EltTy->getBitWidth()));
+        Value *NewC = ConstantInt::get(EltTy, C->trunc(EltTy->getBitWidth()));
         return new ICmpInst(Pred, Extract, NewC);
       }
     }
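For the relocated sitofp/bitcast folds above, a hypothetical sample (same-width scalar types, as the code requires):

  define i1 @signbit_through_sitofp(i32 %x) {
    %f = sitofp i32 %x to float
    %b = bitcast float %f to i32
    %c = icmp slt i32 %b, 0   ; expected: icmp slt i32 %x, 0
    ret i1 %c
  }

sitofp preserves the sign and the same-width bitcast exposes the IEEE sign bit, so testing the sign of %b is the same as testing the sign of %x.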
@@ -2606,13 +2742,9 @@ Instruction *InstCombiner::foldICmpInstWithConstant(ICmpInst &Cmp) {
       return I;
   }
 
-  if (auto *BCI = dyn_cast<BitCastInst>(Cmp.getOperand(0))) {
-    if (Instruction *I = foldICmpBitCastConstant(Cmp, BCI, *C))
+  if (auto *II = dyn_cast<IntrinsicInst>(Cmp.getOperand(0)))
+    if (Instruction *I = foldICmpIntrinsicWithConstant(Cmp, II, *C))
       return I;
-  }
-
-  if (Instruction *I = foldICmpIntrinsicWithConstant(Cmp, *C))
-    return I;
 
   return nullptr;
 }
@@ -2711,24 +2843,6 @@ Instruction *InstCombiner::foldICmpBinOpEqualityWithConstant(ICmpInst &Cmp,
       if (C == *BOC && C.isPowerOf2())
         return new ICmpInst(isICMP_NE ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE,
                             BO, Constant::getNullValue(RHS->getType()));
-
-      // Don't perform the following transforms if the AND has multiple uses
-      if (!BO->hasOneUse())
-        break;
-
-      // Replace (and X, (1 << size(X)-1) != 0) with x s< 0
-      if (BOC->isSignMask()) {
-        Constant *Zero = Constant::getNullValue(BOp0->getType());
-        auto NewPred = isICMP_NE ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_SGE;
-        return new ICmpInst(NewPred, BOp0, Zero);
-      }
-
-      // ((X & ~7) == 0) --> X < 8
-      if (C.isNullValue() && (~(*BOC) + 1).isPowerOf2()) {
-        Constant *NegBOC = ConstantExpr::getNeg(cast<Constant>(BOp1));
-        auto NewPred = isICMP_NE ? ICmpInst::ICMP_UGE : ICmpInst::ICMP_ULT;
-        return new ICmpInst(NewPred, BOp0, NegBOC);
-      }
     }
     break;
   }
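The shufflevector-splat fold retained in foldICmpBitCast above can be pictured with this hypothetical sample:

  define i1 @splat_lane_cmp(<4 x i8> %v) {
    %s = shufflevector <4 x i8> %v, <4 x i8> undef, <4 x i32> <i32 1, i32 1, i32 1, i32 1>
    %b = bitcast <4 x i8> %s to i32
    %c = icmp eq i32 %b, 0   ; 0 is a splat of the 8-bit pattern 0
    ret i1 %c
  }

Since %b is four copies of lane 1 and the constant is a splat of an 8-bit pattern, the compare only needs one lane: extractelement <4 x i8> %v, i32 1, compared against i8 0.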
@@ -2756,14 +2870,10 @@
   return nullptr;
 }
 
-/// Fold an icmp with LLVM intrinsic and constant operand: icmp Pred II, C.
-Instruction *InstCombiner::foldICmpIntrinsicWithConstant(ICmpInst &Cmp,
-                                                         const APInt &C) {
-  IntrinsicInst *II = dyn_cast<IntrinsicInst>(Cmp.getOperand(0));
-  if (!II || !Cmp.isEquality())
-    return nullptr;
-
-  // Handle icmp {eq|ne} <intrinsic>, Constant.
+/// Fold an equality icmp with LLVM intrinsic and constant operand.
+Instruction *InstCombiner::foldICmpEqIntrinsicWithConstant(ICmpInst &Cmp,
+                                                           IntrinsicInst *II,
+                                                           const APInt &C) {
   Type *Ty = II->getType();
   unsigned BitWidth = C.getBitWidth();
   switch (II->getIntrinsicID()) {
@@ -2823,6 +2933,65 @@
   return nullptr;
 }
 
+/// Fold an icmp with LLVM intrinsic and constant operand: icmp Pred II, C.
+Instruction *InstCombiner::foldICmpIntrinsicWithConstant(ICmpInst &Cmp,
+                                                         IntrinsicInst *II,
+                                                         const APInt &C) {
+  if (Cmp.isEquality())
+    return foldICmpEqIntrinsicWithConstant(Cmp, II, C);
+
+  Type *Ty = II->getType();
+  unsigned BitWidth = C.getBitWidth();
+  switch (II->getIntrinsicID()) {
+  case Intrinsic::ctlz: {
+    // ctlz(0bXXXXXXXX) > 3 -> 0bXXXXXXXX < 0b00010000
+    if (Cmp.getPredicate() == ICmpInst::ICMP_UGT && C.ult(BitWidth)) {
+      unsigned Num = C.getLimitedValue();
+      APInt Limit = APInt::getOneBitSet(BitWidth, BitWidth - Num - 1);
+      return CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_ULT,
+                             II->getArgOperand(0), ConstantInt::get(Ty, Limit));
+    }
+
+    // ctlz(0bXXXXXXXX) < 3 -> 0bXXXXXXXX > 0b00011111
+    if (Cmp.getPredicate() == ICmpInst::ICMP_ULT &&
+        C.uge(1) && C.ule(BitWidth)) {
+      unsigned Num = C.getLimitedValue();
+      APInt Limit = APInt::getLowBitsSet(BitWidth, BitWidth - Num);
+      return CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_UGT,
+                             II->getArgOperand(0), ConstantInt::get(Ty, Limit));
+    }
+    break;
+  }
+  case Intrinsic::cttz: {
+    // Limit to one use to ensure we don't increase instruction count.
+    if (!II->hasOneUse())
+      return nullptr;
+
+    // cttz(0bXXXXXXXX) > 3 -> 0bXXXXXXXX & 0b00001111 == 0
+    if (Cmp.getPredicate() == ICmpInst::ICMP_UGT && C.ult(BitWidth)) {
+      APInt Mask = APInt::getLowBitsSet(BitWidth, C.getLimitedValue() + 1);
+      return CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ,
+                             Builder.CreateAnd(II->getArgOperand(0), Mask),
+                             ConstantInt::getNullValue(Ty));
+    }
+
+    // cttz(0bXXXXXXXX) < 3 -> 0bXXXXXXXX & 0b00000111 != 0
+    if (Cmp.getPredicate() == ICmpInst::ICMP_ULT &&
+        C.uge(1) && C.ule(BitWidth)) {
+      APInt Mask = APInt::getLowBitsSet(BitWidth, C.getLimitedValue());
+      return CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_NE,
+                             Builder.CreateAnd(II->getArgOperand(0), Mask),
+                             ConstantInt::getNullValue(Ty));
+    }
+    break;
+  }
+  default:
+    break;
+  }
+
+  return nullptr;
+}
+
 /// Handle icmp with constant (but not simple integer constant) RHS.
 Instruction *InstCombiner::foldICmpInstWithConstantNotInt(ICmpInst &I) {
   Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
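A hypothetical sample of the new non-equality ctlz fold (i8 width; the declare line uses the standard intrinsic mangling):

  declare i8 @llvm.ctlz.i8(i8, i1)

  define i1 @ctlz_ugt(i8 %x) {
    %lz = call i8 @llvm.ctlz.i8(i8 %x, i1 false)
    %c = icmp ugt i8 %lz, 3   ; expected: icmp ult i8 %x, 16
    ret i1 %c
  }

More than three leading zeros in an 8-bit value means the top four bits are clear, i.e. the value is below 0b00010000 — exactly the comment's "ctlz(0bXXXXXXXX) > 3 -> 0bXXXXXXXX < 0b00010000".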
@@ -2983,6 +3152,10 @@ static Value *foldICmpWithLowBitMaskedVal(ICmpInst &I,
     //  x s> x & (-1 >> y)    ->    x s> (-1 >> y)
     if (X != I.getOperand(0)) // X must be on LHS of comparison!
       return nullptr;        // Ignore the other case.
+    if (!match(M, m_Constant())) // Can not do this fold with non-constant.
+      return nullptr;
+    if (!match(M, m_NonNegative())) // Must not have any -1 vector elements.
+      return nullptr;
     DstPred = ICmpInst::Predicate::ICMP_SGT;
     break;
   case ICmpInst::Predicate::ICMP_SGE:
@@ -3009,6 +3182,10 @@ static Value *foldICmpWithLowBitMaskedVal(ICmpInst &I,
     //  x s<= x & (-1 >> y)    ->    x s<= (-1 >> y)
     if (X != I.getOperand(0)) // X must be on LHS of comparison!
       return nullptr;        // Ignore the other case.
+    if (!match(M, m_Constant())) // Can not do this fold with non-constant.
+      return nullptr;
+    if (!match(M, m_NonNegative())) // Must not have any -1 vector elements.
+      return nullptr;
     DstPred = ICmpInst::Predicate::ICMP_SLE;
     break;
   default:
@@ -3093,6 +3270,64 @@ foldICmpWithTruncSignExtendedVal(ICmpInst &I,
   return T1;
 }
 
+// Given pattern:
+//   icmp eq/ne (and ((x shift Q), (y oppositeshift K))), 0
+// we should move shifts to the same hand of 'and', i.e. rewrite as
+//   icmp eq/ne (and (x shift (Q+K)), y), 0  iff (Q+K) u< bitwidth(x)
+// We are only interested in opposite logical shifts here.
+// If we can, we want to end up creating 'lshr' shift.
+static Value *
+foldShiftIntoShiftInAnotherHandOfAndInICmp(ICmpInst &I, const SimplifyQuery SQ,
+                                           InstCombiner::BuilderTy &Builder) {
+  if (!I.isEquality() || !match(I.getOperand(1), m_Zero()) ||
+      !I.getOperand(0)->hasOneUse())
+    return nullptr;
+
+  auto m_AnyLogicalShift = m_LogicalShift(m_Value(), m_Value());
+  auto m_AnyLShr = m_LShr(m_Value(), m_Value());
+
+  // Look for an 'and' of two (opposite) logical shifts.
+  // Pick the single-use shift as XShift.
+  Value *XShift, *YShift;
+  if (!match(I.getOperand(0),
+             m_c_And(m_OneUse(m_CombineAnd(m_AnyLogicalShift, m_Value(XShift))),
+                     m_CombineAnd(m_AnyLogicalShift, m_Value(YShift)))))
+    return nullptr;
+
+  // If YShift is a single-use 'lshr', swap the shifts around.
+  if (match(YShift, m_OneUse(m_AnyLShr)))
+    std::swap(XShift, YShift);
+
+  // The shifts must be in opposite directions.
+  Instruction::BinaryOps XShiftOpcode =
+      cast<BinaryOperator>(XShift)->getOpcode();
+  if (XShiftOpcode == cast<BinaryOperator>(YShift)->getOpcode())
+    return nullptr; // Do not care about same-direction shifts here.
+
+  Value *X, *XShAmt, *Y, *YShAmt;
+  match(XShift, m_BinOp(m_Value(X), m_Value(XShAmt)));
+  match(YShift, m_BinOp(m_Value(Y), m_Value(YShAmt)));
+
+  // Can we fold (XShAmt+YShAmt) ?
+  Value *NewShAmt = SimplifyBinOp(Instruction::BinaryOps::Add, XShAmt, YShAmt,
+                                  SQ.getWithInstruction(&I));
+  if (!NewShAmt)
+    return nullptr;
+  // Is the new shift amount smaller than the bit width?
+  // FIXME: could also rely on ConstantRange.
+  unsigned BitWidth = X->getType()->getScalarSizeInBits();
+  if (!match(NewShAmt, m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_ULT,
+                                          APInt(BitWidth, BitWidth))))
+    return nullptr;
+  // All good, we can do this fold. The shift is the same that was for X.
+  Value *T0 = XShiftOpcode == Instruction::BinaryOps::LShr
+                  ? Builder.CreateLShr(X, NewShAmt)
+                  : Builder.CreateShl(X, NewShAmt);
+  Value *T1 = Builder.CreateAnd(T0, Y);
+  return Builder.CreateICmp(I.getPredicate(), T1,
+                            Constant::getNullValue(X->getType()));
+}
+
 /// Try to fold icmp (binop), X or icmp X, (binop).
 /// TODO: A large part of this logic is duplicated in InstSimplify's
 /// simplifyICmpWithBinOp(). We should be able to share that and avoid the code
@@ -3448,6 +3683,9 @@ Instruction *InstCombiner::foldICmpBinOp(ICmpInst &I) {
   if (Value *V = foldICmpWithTruncSignExtendedVal(I, Builder))
     return replaceInstUsesWith(I, V);
 
+  if (Value *V = foldShiftIntoShiftInAnotherHandOfAndInICmp(I, SQ, Builder))
+    return replaceInstUsesWith(I, V);
+
   return nullptr;
 }
@@ -3688,6 +3926,30 @@ Instruction *InstCombiner::foldICmpEquality(ICmpInst &I) {
       match(Op1, m_BitReverse(m_Value(B)))))
     return new ICmpInst(Pred, A, B);
 
+  // Canonicalize checking for a power-of-2-or-zero value:
+  // (A & (A-1)) == 0 --> ctpop(A) < 2 (two commuted variants)
+  // ((A-1) & A) != 0 --> ctpop(A) > 1 (two commuted variants)
+  if (!match(Op0, m_OneUse(m_c_And(m_Add(m_Value(A), m_AllOnes()),
+                                   m_Deferred(A)))) ||
+      !match(Op1, m_ZeroInt()))
+    A = nullptr;
+
+  // (A & -A) == A --> ctpop(A) < 2 (four commuted variants)
+  // (-A & A) != A --> ctpop(A) > 1 (four commuted variants)
+  if (match(Op0, m_OneUse(m_c_And(m_Neg(m_Specific(Op1)), m_Specific(Op1)))))
+    A = Op1;
+  else if (match(Op1,
+                 m_OneUse(m_c_And(m_Neg(m_Specific(Op0)), m_Specific(Op0)))))
+    A = Op0;
+
+  if (A) {
+    Type *Ty = A->getType();
+    CallInst *CtPop = Builder.CreateUnaryIntrinsic(Intrinsic::ctpop, A);
+    return Pred == ICmpInst::ICMP_EQ
+               ? new ICmpInst(ICmpInst::ICMP_ULT, CtPop, ConstantInt::get(Ty, 2))
+               : new ICmpInst(ICmpInst::ICMP_UGT, CtPop, ConstantInt::get(Ty, 1));
+  }
+
   return nullptr;
 }
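The power-of-two canonicalization added to foldICmpEquality above, on a hypothetical sample:

  define i1 @is_pow2_or_zero(i32 %a) {
    %m = add i32 %a, -1
    %t = and i32 %m, %a
    %c = icmp eq i32 %t, 0   ; expected: %p = call i32 @llvm.ctpop.i32(i32 %a)
    ret i1 %c                ;           %c = icmp ult i32 %p, 2
  }

Clearing the lowest set bit ((%a - 1) & %a) yields zero only for 0 and powers of two, which is the same condition as "population count is less than 2".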
@@ -3698,7 +3960,6 @@ Instruction *InstCombiner::foldICmpWithCastAndCast(ICmpInst &ICmp) {
   Value *LHSCIOp = LHSCI->getOperand(0);
   Type *SrcTy = LHSCIOp->getType();
   Type *DestTy = LHSCI->getType();
-  Value *RHSCIOp;
 
   // Turn icmp (ptrtoint x), (ptrtoint/c) into a compare of the input if the
   // integer type is the same size as the pointer type.
@@ -3740,7 +4001,7 @@ Instruction *InstCombiner::foldICmpWithCastAndCast(ICmpInst &ICmp) {
 
   if (auto *CI = dyn_cast<CastInst>(ICmp.getOperand(1))) {
     // Not an extension from the same type?
-    RHSCIOp = CI->getOperand(0);
+    Value *RHSCIOp = CI->getOperand(0);
     if (RHSCIOp->getType() != LHSCIOp->getType())
       return nullptr;
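The two hunks above only narrow RHSCIOp's scope; the fold the surrounding (unchanged) code performs is the usual cast-stripping one, e.g. this hypothetical sample:

  define i1 @cmp_zext_zext(i8 %a, i8 %b) {
    %za = zext i8 %a to i32
    %zb = zext i8 %b to i32
    %c = icmp ult i32 %za, %zb   ; expected: icmp ult i8 %a, %b
    ret i1 %c
  }

Zero extension preserves unsigned order, so the compare can run on the narrow operands.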
@@ -3813,104 +4074,83 @@
   return BinaryOperator::CreateNot(Result);
 }
 
-bool InstCombiner::OptimizeOverflowCheck(OverflowCheckFlavor OCF, Value *LHS,
-                                         Value *RHS, Instruction &OrigI,
-                                         Value *&Result, Constant *&Overflow) {
+static bool isNeutralValue(Instruction::BinaryOps BinaryOp, Value *RHS) {
+  switch (BinaryOp) {
+  default:
+    llvm_unreachable("Unsupported binary op");
+  case Instruction::Add:
+  case Instruction::Sub:
+    return match(RHS, m_Zero());
+  case Instruction::Mul:
+    return match(RHS, m_One());
+  }
+}
+
+OverflowResult InstCombiner::computeOverflow(
+    Instruction::BinaryOps BinaryOp, bool IsSigned,
+    Value *LHS, Value *RHS, Instruction *CxtI) const {
+  switch (BinaryOp) {
+  default:
+    llvm_unreachable("Unsupported binary op");
+  case Instruction::Add:
+    if (IsSigned)
+      return computeOverflowForSignedAdd(LHS, RHS, CxtI);
+    else
+      return computeOverflowForUnsignedAdd(LHS, RHS, CxtI);
+  case Instruction::Sub:
+    if (IsSigned)
+      return computeOverflowForSignedSub(LHS, RHS, CxtI);
+    else
+      return computeOverflowForUnsignedSub(LHS, RHS, CxtI);
+  case Instruction::Mul:
+    if (IsSigned)
+      return computeOverflowForSignedMul(LHS, RHS, CxtI);
+    else
+      return computeOverflowForUnsignedMul(LHS, RHS, CxtI);
+  }
+}
+
+bool InstCombiner::OptimizeOverflowCheck(
+    Instruction::BinaryOps BinaryOp, bool IsSigned, Value *LHS, Value *RHS,
+    Instruction &OrigI, Value *&Result, Constant *&Overflow) {
   if (OrigI.isCommutative() && isa<Constant>(LHS) && !isa<Constant>(RHS))
     std::swap(LHS, RHS);
 
-  auto SetResult = [&](Value *OpResult, Constant *OverflowVal, bool ReuseName) {
-    Result = OpResult;
-    Overflow = OverflowVal;
-    if (ReuseName)
-      Result->takeName(&OrigI);
-    return true;
-  };
-
   // If the overflow check was an add followed by a compare, the insertion point
   // may be pointing to the compare. We want to insert the new instructions
   // before the add in case there are uses of the add between the add and the
   // compare.
   Builder.SetInsertPoint(&OrigI);
 
-  switch (OCF) {
-  case OCF_INVALID:
-    llvm_unreachable("bad overflow check kind!");
-
-  case OCF_UNSIGNED_ADD: {
-    OverflowResult OR = computeOverflowForUnsignedAdd(LHS, RHS, &OrigI);
-    if (OR == OverflowResult::NeverOverflows)
-      return SetResult(Builder.CreateNUWAdd(LHS, RHS), Builder.getFalse(),
-                       true);
-
-    if (OR == OverflowResult::AlwaysOverflows)
-      return SetResult(Builder.CreateAdd(LHS, RHS), Builder.getTrue(), true);
-
-    // Fall through uadd into sadd
-    LLVM_FALLTHROUGH;
-  }
-  case OCF_SIGNED_ADD: {
-    // X + 0 -> {X, false}
-    if (match(RHS, m_Zero()))
-      return SetResult(LHS, Builder.getFalse(), false);
-
-    // We can strength reduce this signed add into a regular add if we can prove
-    // that it will never overflow.
-    if (OCF == OCF_SIGNED_ADD)
-      if (willNotOverflowSignedAdd(LHS, RHS, OrigI))
-        return SetResult(Builder.CreateNSWAdd(LHS, RHS), Builder.getFalse(),
-                         true);
-    break;
-  }
-
-  case OCF_UNSIGNED_SUB:
-  case OCF_SIGNED_SUB: {
-    // X - 0 -> {X, false}
-    if (match(RHS, m_Zero()))
-      return SetResult(LHS, Builder.getFalse(), false);
-
-    if (OCF == OCF_SIGNED_SUB) {
-      if (willNotOverflowSignedSub(LHS, RHS, OrigI))
-        return SetResult(Builder.CreateNSWSub(LHS, RHS), Builder.getFalse(),
-                         true);
-    } else {
-      if (willNotOverflowUnsignedSub(LHS, RHS, OrigI))
-        return SetResult(Builder.CreateNUWSub(LHS, RHS), Builder.getFalse(),
-                         true);
-    }
-    break;
-  }
-
-  case OCF_UNSIGNED_MUL: {
-    OverflowResult OR = computeOverflowForUnsignedMul(LHS, RHS, &OrigI);
-    if (OR == OverflowResult::NeverOverflows)
-      return SetResult(Builder.CreateNUWMul(LHS, RHS), Builder.getFalse(),
-                       true);
-    if (OR == OverflowResult::AlwaysOverflows)
-      return SetResult(Builder.CreateMul(LHS, RHS), Builder.getTrue(), true);
-    LLVM_FALLTHROUGH;
+  if (isNeutralValue(BinaryOp, RHS)) {
+    Result = LHS;
+    Overflow = Builder.getFalse();
+    return true;
   }
-  case OCF_SIGNED_MUL:
-    // X * undef -> undef
-    if (isa<UndefValue>(RHS))
-      return SetResult(RHS, UndefValue::get(Builder.getInt1Ty()), false);
-
-    // X * 0 -> {0, false}
-    if (match(RHS, m_Zero()))
-      return SetResult(RHS, Builder.getFalse(), false);
-
-    // X * 1 -> {X, false}
-    if (match(RHS, m_One()))
-      return SetResult(LHS, Builder.getFalse(), false);
-
-    if (OCF == OCF_SIGNED_MUL)
-      if (willNotOverflowSignedMul(LHS, RHS, OrigI))
-        return SetResult(Builder.CreateNSWMul(LHS, RHS), Builder.getFalse(),
-                         true);
-    break;
+
+  switch (computeOverflow(BinaryOp, IsSigned, LHS, RHS, &OrigI)) {
+  case OverflowResult::MayOverflow:
+    return false;
+  case OverflowResult::AlwaysOverflowsLow:
+  case OverflowResult::AlwaysOverflowsHigh:
+    Result = Builder.CreateBinOp(BinaryOp, LHS, RHS);
+    Result->takeName(&OrigI);
+    Overflow = Builder.getTrue();
+    return true;
+  case OverflowResult::NeverOverflows:
+    Result = Builder.CreateBinOp(BinaryOp, LHS, RHS);
+    Result->takeName(&OrigI);
+    Overflow = Builder.getFalse();
+    if (auto *Inst = dyn_cast<Instruction>(Result)) {
+      if (IsSigned)
+        Inst->setHasNoSignedWrap();
+      else
+        Inst->setHasNoUnsignedWrap();
+    }
+    return true;
   }
-
-  return false;
+  llvm_unreachable("Unexpected overflow result");
 }
 
 /// Recognize and process idiom involving test for multiplication
@@ -4084,8 +4324,8 @@ static Instruction *processUMulZExtIdiom(ICmpInst &I, Value *MulVal,
     MulA = Builder.CreateZExt(A, MulType);
   if (WidthB < MulWidth)
     MulB = Builder.CreateZExt(B, MulType);
-  Value *F = Intrinsic::getDeclaration(I.getModule(),
-                                       Intrinsic::umul_with_overflow, MulType);
+  Function *F = Intrinsic::getDeclaration(
+      I.getModule(), Intrinsic::umul_with_overflow, MulType);
   CallInst *Call = Builder.CreateCall(F, {MulA, MulB}, "umul");
   IC.Worklist.Add(MulInstr);
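A hypothetical sample for the rewritten OptimizeOverflowCheck (reached from visitICmpInst's unsigned-add overflow pattern, whose updated call site appears further below):

  define i1 @uadd_check_never_overflows(i8 %a, i8 %b) {
    %x = and i8 %a, 15
    %y = and i8 %b, 15
    %s = add i8 %x, %y
    %c = icmp ult i8 %s, %x   ; the usual unsigned-overflow idiom
    ret i1 %c
  }

Both addends are at most 15, so computeOverflow reports NeverOverflows: %c is expected to fold to false and %s to gain the nuw flag.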
@@ -4881,61 +5121,8 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
       return New;
   }
 
-  // Zero-equality and sign-bit checks are preserved through sitofp + bitcast.
-  Value *X;
-  if (match(Op0, m_BitCast(m_SIToFP(m_Value(X))))) {
-    // icmp  eq (bitcast (sitofp X)), 0 --> icmp  eq X, 0
-    // icmp  ne (bitcast (sitofp X)), 0 --> icmp  ne X, 0
-    // icmp slt (bitcast (sitofp X)), 0 --> icmp slt X, 0
-    // icmp sgt (bitcast (sitofp X)), 0 --> icmp sgt X, 0
-    if ((Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_SLT ||
-         Pred == ICmpInst::ICMP_NE || Pred == ICmpInst::ICMP_SGT) &&
-        match(Op1, m_Zero()))
-      return new ICmpInst(Pred, X, ConstantInt::getNullValue(X->getType()));
-
-    // icmp slt (bitcast (sitofp X)), 1 --> icmp slt X, 1
-    if (Pred == ICmpInst::ICMP_SLT && match(Op1, m_One()))
-      return new ICmpInst(Pred, X, ConstantInt::get(X->getType(), 1));
-
-    // icmp sgt (bitcast (sitofp X)), -1 --> icmp sgt X, -1
-    if (Pred == ICmpInst::ICMP_SGT && match(Op1, m_AllOnes()))
-      return new ICmpInst(Pred, X, ConstantInt::getAllOnesValue(X->getType()));
-  }
-
-  // Zero-equality checks are preserved through unsigned floating-point casts:
-  // icmp eq (bitcast (uitofp X)), 0 --> icmp eq X, 0
-  // icmp ne (bitcast (uitofp X)), 0 --> icmp ne X, 0
-  if (match(Op0, m_BitCast(m_UIToFP(m_Value(X)))))
-    if (I.isEquality() && match(Op1, m_Zero()))
-      return new ICmpInst(Pred, X, ConstantInt::getNullValue(X->getType()));
-
-  // Test to see if the operands of the icmp are casted versions of other
-  // values. If the ptr->ptr cast can be stripped off both arguments, we do so
-  // now.
-  if (BitCastInst *CI = dyn_cast<BitCastInst>(Op0)) {
-    if (Op0->getType()->isPointerTy() &&
-        (isa<Constant>(Op1) || isa<BitCastInst>(Op1))) {
-      // We keep moving the cast from the left operand over to the right
-      // operand, where it can often be eliminated completely.
-      Op0 = CI->getOperand(0);
-
-      // If operand #1 is a bitcast instruction, it must also be a ptr->ptr cast
-      // so eliminate it as well.
-      if (BitCastInst *CI2 = dyn_cast<BitCastInst>(Op1))
-        Op1 = CI2->getOperand(0);
-
-      // If Op1 is a constant, we can fold the cast into the constant.
-      if (Op0->getType() != Op1->getType()) {
-        if (Constant *Op1C = dyn_cast<Constant>(Op1)) {
-          Op1 = ConstantExpr::getBitCast(Op1C, Op0->getType());
-        } else {
-          // Otherwise, cast the RHS right before the icmp
-          Op1 = Builder.CreateBitCast(Op1, Op0->getType());
-        }
-      }
-      return new ICmpInst(I.getPredicate(), Op0, Op1);
-    }
-  }
+  if (Instruction *Res = foldICmpBitCast(I, Builder))
+    return Res;
 
   if (isa<CastInst>(Op0)) {
     // Handle the special case of: icmp (cast bool to X), <cst>
@@ -4984,8 +5171,8 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
         isa<IntegerType>(A->getType())) {
       Value *Result;
       Constant *Overflow;
-      if (OptimizeOverflowCheck(OCF_UNSIGNED_ADD, A, B, *AddI, Result,
-                                Overflow)) {
+      if (OptimizeOverflowCheck(Instruction::Add, /*Signed*/false, A, B,
                                *AddI, Result, Overflow)) {
        replaceInstUsesWith(*AddI, Result);
        return replaceInstUsesWith(I, Overflow);
      }
@@ -5411,6 +5598,8 @@ Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) {
     return replaceInstUsesWith(I, V);
 
   // Simplify 'fcmp pred X, X'
+  Type *OpType = Op0->getType();
+  assert(OpType == Op1->getType() && "fcmp with different-typed operands?");
   if (Op0 == Op1) {
     switch (Pred) {
       default: break;
@@ -5420,7 +5609,7 @@ Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) {
     case FCmpInst::FCMP_UNE:    // True if unordered or not equal
       // Canonicalize these to be 'fcmp uno %X, 0.0'.
       I.setPredicate(FCmpInst::FCMP_UNO);
-      I.setOperand(1, Constant::getNullValue(Op0->getType()));
+      I.setOperand(1, Constant::getNullValue(OpType));
       return &I;
 
     case FCmpInst::FCMP_ORD:    // True if ordered (no nans)
@@ -5429,7 +5618,7 @@ Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) {
     case FCmpInst::FCMP_OLE:    // True if ordered and less than or equal
       // Canonicalize these to be 'fcmp ord %X, 0.0'.
       I.setPredicate(FCmpInst::FCMP_ORD);
-      I.setOperand(1, Constant::getNullValue(Op0->getType()));
+      I.setOperand(1, Constant::getNullValue(OpType));
       return &I;
     }
   }
@@ -5438,15 +5627,20 @@ Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) {
   // then canonicalize the operand to 0.0.
   if (Pred == CmpInst::FCMP_ORD || Pred == CmpInst::FCMP_UNO) {
     if (!match(Op0, m_PosZeroFP()) && isKnownNeverNaN(Op0, &TLI)) {
-      I.setOperand(0, ConstantFP::getNullValue(Op0->getType()));
+      I.setOperand(0, ConstantFP::getNullValue(OpType));
       return &I;
     }
 
     if (!match(Op1, m_PosZeroFP()) && isKnownNeverNaN(Op1, &TLI)) {
-      I.setOperand(1, ConstantFP::getNullValue(Op0->getType()));
+      I.setOperand(1, ConstantFP::getNullValue(OpType));
       return &I;
     }
   }
 
+  // fcmp pred (fneg X), (fneg Y) -> fcmp swap(pred) X, Y
+  Value *X, *Y;
+  if (match(Op0, m_FNeg(m_Value(X))) && match(Op1, m_FNeg(m_Value(Y))))
+    return new FCmpInst(I.getSwappedPredicate(), X, Y, "", &I);
+
   // Test if the FCmpInst instruction is used exclusively by a select as
   // part of a minimum or maximum operation. If so, refrain from doing
   // any other folding. This helps out other analyses which understand
@@ -5465,7 +5659,7 @@ Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) {
   // The sign of 0.0 is ignored by fcmp, so canonicalize to +0.0:
   // fcmp Pred X, -0.0 --> fcmp Pred X, 0.0
   if (match(Op1, m_AnyZeroFP()) && !match(Op1, m_PosZeroFP())) {
-    I.setOperand(1, ConstantFP::getNullValue(Op1->getType()));
+    I.setOperand(1, ConstantFP::getNullValue(OpType));
     return &I;
   }
@@ -5505,12 +5699,7 @@ Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) {
   if (Instruction *R = foldFabsWithFcmpZero(I))
     return R;
 
-  Value *X, *Y;
   if (match(Op0, m_FNeg(m_Value(X)))) {
-    // fcmp pred (fneg X), (fneg Y) -> fcmp swap(pred) X, Y
-    if (match(Op1, m_FNeg(m_Value(Y))))
-      return new FCmpInst(I.getSwappedPredicate(), X, Y, "", &I);
-
     // fcmp pred (fneg X), C --> fcmp swap(pred) X, -C
     Constant *C;
     if (match(Op1, m_Constant(C))) {
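Finally, the fneg/fneg fold that the patch moves ahead of the min/max bail-out, as a hypothetical sample:

  define i1 @fneg_fneg(float %x, float %y) {
    %nx = fneg float %x
    %ny = fneg float %y
    %c = fcmp olt float %nx, %ny   ; expected: fcmp ogt float %x, %y
    ret i1 %c
  }

Negating both operands reverses their order (and leaves NaN behavior unchanged), so the compare swaps its predicate and drops both fnegs.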