diff options
Diffstat (limited to 'llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp')
-rw-r--r-- | llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp | 70 |
1 files changed, 66 insertions, 4 deletions
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp index a9f64feb600c..f38dc436722d 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -2566,9 +2566,6 @@ Instruction *InstCombiner::foldICmpAddConstant(ICmpInst &Cmp, Type *Ty = Add->getType(); CmpInst::Predicate Pred = Cmp.getPredicate(); - if (!Add->hasOneUse()) - return nullptr; - // If the add does not wrap, we can always adjust the compare by subtracting // the constants. Equality comparisons are handled elsewhere. SGE/SLE/UGE/ULE // are canonicalized to SGT/SLT/UGT/ULT. @@ -2602,6 +2599,9 @@ Instruction *InstCombiner::foldICmpAddConstant(ICmpInst &Cmp, return new ICmpInst(ICmpInst::ICMP_UGE, X, ConstantInt::get(Ty, Lower)); } + if (!Add->hasOneUse()) + return nullptr; + // X+C <u C2 -> (X & -C2) == C // iff C & (C2-1) == 0 // C2 is a power of 2 @@ -3364,6 +3364,23 @@ static Value *foldICmpWithLowBitMaskedVal(ICmpInst &I, llvm_unreachable("All possible folds are handled."); } + // The mask value may be a vector constant that has undefined elements. But it + // may not be safe to propagate those undefs into the new compare, so replace + // those elements by copying an existing, defined, and safe scalar constant. + Type *OpTy = M->getType(); + auto *VecC = dyn_cast<Constant>(M); + if (OpTy->isVectorTy() && VecC && VecC->containsUndefElement()) { + Constant *SafeReplacementConstant = nullptr; + for (unsigned i = 0, e = OpTy->getVectorNumElements(); i != e; ++i) { + if (!isa<UndefValue>(VecC->getAggregateElement(i))) { + SafeReplacementConstant = VecC->getAggregateElement(i); + break; + } + } + assert(SafeReplacementConstant && "Failed to find undef replacement"); + M = Constant::replaceUndefsWith(VecC, SafeReplacementConstant); + } + return Builder.CreateICmp(DstPred, X, M); } @@ -4930,7 +4947,7 @@ Instruction *InstCombiner::foldICmpUsingKnownBits(ICmpInst &I) { // Get scalar or pointer size. unsigned BitWidth = Ty->isIntOrIntVectorTy() ? Ty->getScalarSizeInBits() - : DL.getIndexTypeSizeInBits(Ty->getScalarType()); + : DL.getPointerTypeSizeInBits(Ty->getScalarType()); if (!BitWidth) return nullptr; @@ -5167,6 +5184,7 @@ llvm::getFlippedStrictnessPredicateAndConstant(CmpInst::Predicate Pred, return WillIncrement ? !C->isMaxValue(IsSigned) : !C->isMinValue(IsSigned); }; + Constant *SafeReplacementConstant = nullptr; if (auto *CI = dyn_cast<ConstantInt>(C)) { // Bail out if the constant can't be safely incremented/decremented. if (!ConstantIsOk(CI)) @@ -5186,12 +5204,23 @@ llvm::getFlippedStrictnessPredicateAndConstant(CmpInst::Predicate Pred, auto *CI = dyn_cast<ConstantInt>(Elt); if (!CI || !ConstantIsOk(CI)) return llvm::None; + + if (!SafeReplacementConstant) + SafeReplacementConstant = CI; } } else { // ConstantExpr? return llvm::None; } + // It may not be safe to change a compare predicate in the presence of + // undefined elements, so replace those elements with the first safe constant + // that we found. + if (C->containsUndefElement()) { + assert(SafeReplacementConstant && "Replacement constant not set"); + C = Constant::replaceUndefsWith(C, SafeReplacementConstant); + } + CmpInst::Predicate NewPred = CmpInst::getFlippedStrictnessPredicate(Pred); // Increment or decrement the constant. @@ -5374,6 +5403,36 @@ static Instruction *foldVectorCmp(CmpInst &Cmp, return nullptr; } +// extract(uadd.with.overflow(A, B), 0) ult A +// -> extract(uadd.with.overflow(A, B), 1) +static Instruction *foldICmpOfUAddOv(ICmpInst &I) { + CmpInst::Predicate Pred = I.getPredicate(); + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + + Value *UAddOv; + Value *A, *B; + auto UAddOvResultPat = m_ExtractValue<0>( + m_Intrinsic<Intrinsic::uadd_with_overflow>(m_Value(A), m_Value(B))); + if (match(Op0, UAddOvResultPat) && + ((Pred == ICmpInst::ICMP_ULT && (Op1 == A || Op1 == B)) || + (Pred == ICmpInst::ICMP_EQ && match(Op1, m_ZeroInt()) && + (match(A, m_One()) || match(B, m_One()))) || + (Pred == ICmpInst::ICMP_NE && match(Op1, m_AllOnes()) && + (match(A, m_AllOnes()) || match(B, m_AllOnes()))))) + // extract(uadd.with.overflow(A, B), 0) < A + // extract(uadd.with.overflow(A, 1), 0) == 0 + // extract(uadd.with.overflow(A, -1), 0) != -1 + UAddOv = cast<ExtractValueInst>(Op0)->getAggregateOperand(); + else if (match(Op1, UAddOvResultPat) && + Pred == ICmpInst::ICMP_UGT && (Op0 == A || Op0 == B)) + // A > extract(uadd.with.overflow(A, B), 0) + UAddOv = cast<ExtractValueInst>(Op1)->getAggregateOperand(); + else + return nullptr; + + return ExtractValueInst::Create(UAddOv, 1); +} + Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { bool Changed = false; const SimplifyQuery Q = SQ.getWithInstruction(&I); @@ -5562,6 +5621,9 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { if (Instruction *Res = foldICmpEquality(I)) return Res; + if (Instruction *Res = foldICmpOfUAddOv(I)) + return Res; + // The 'cmpxchg' instruction returns an aggregate containing the old value and // an i1 which indicates whether or not we successfully did the swap. // |