Diffstat (limited to 'lib/Analysis/InstructionSimplify.cpp')
-rw-r--r-- | lib/Analysis/InstructionSimplify.cpp | 713
1 file changed, 285 insertions, 428 deletions
diff --git a/lib/Analysis/InstructionSimplify.cpp b/lib/Analysis/InstructionSimplify.cpp index ccf907c144f0..e34bf6f4e43f 100644 --- a/lib/Analysis/InstructionSimplify.cpp +++ b/lib/Analysis/InstructionSimplify.cpp @@ -1,9 +1,8 @@ //===- InstructionSimplify.cpp - Fold instruction operands ----------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -34,6 +33,8 @@ #include "llvm/IR/Dominators.h" #include "llvm/IR/GetElementPtrTypeIterator.h" #include "llvm/IR/GlobalAlias.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instructions.h" #include "llvm/IR/Operator.h" #include "llvm/IR/PatternMatch.h" #include "llvm/IR/ValueHandle.h" @@ -50,6 +51,9 @@ STATISTIC(NumExpand, "Number of expansions"); STATISTIC(NumReassoc, "Number of reassociations"); static Value *SimplifyAndInst(Value *, Value *, const SimplifyQuery &, unsigned); +static Value *simplifyUnOp(unsigned, Value *, const SimplifyQuery &, unsigned); +static Value *simplifyFPUnOp(unsigned, Value *, const FastMathFlags &, + const SimplifyQuery &, unsigned); static Value *SimplifyBinOp(unsigned, Value *, Value *, const SimplifyQuery &, unsigned); static Value *SimplifyFPBinOp(unsigned, Value *, Value *, const FastMathFlags &, @@ -655,32 +659,11 @@ static Constant *stripAndComputeConstantOffsets(const DataLayout &DL, Value *&V, Type *IntPtrTy = DL.getIntPtrType(V->getType())->getScalarType(); APInt Offset = APInt::getNullValue(IntPtrTy->getIntegerBitWidth()); - // Even though we don't look through PHI nodes, we could be called on an - // instruction in an unreachable block, which may be on a cycle. - SmallPtrSet<Value *, 4> Visited; - Visited.insert(V); - do { - if (GEPOperator *GEP = dyn_cast<GEPOperator>(V)) { - if ((!AllowNonInbounds && !GEP->isInBounds()) || - !GEP->accumulateConstantOffset(DL, Offset)) - break; - V = GEP->getPointerOperand(); - } else if (Operator::getOpcode(V) == Instruction::BitCast) { - V = cast<Operator>(V)->getOperand(0); - } else if (GlobalAlias *GA = dyn_cast<GlobalAlias>(V)) { - if (GA->isInterposable()) - break; - V = GA->getAliasee(); - } else { - if (auto CS = CallSite(V)) - if (Value *RV = CS.getReturnedArgOperand()) { - V = RV; - continue; - } - break; - } - assert(V->getType()->isPtrOrPtrVectorTy() && "Unexpected operand type!"); - } while (Visited.insert(V).second); + V = V->stripAndAccumulateConstantOffsets(DL, Offset, AllowNonInbounds); + // As that strip may trace through `addrspacecast`, need to sext or trunc + // the offset calculated. 
+ IntPtrTy = DL.getIntPtrType(V->getType())->getScalarType(); + Offset = Offset.sextOrTrunc(IntPtrTy->getIntegerBitWidth()); Constant *OffsetIntPtr = ConstantInt::get(IntPtrTy, Offset); if (V->getType()->isVectorTy()) @@ -1841,6 +1824,16 @@ static Value *SimplifyAndInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, return Op1; } + // This is a similar pattern used for checking if a value is a power-of-2: + // (A - 1) & A --> 0 (if A is a power-of-2 or 0) + // A & (A - 1) --> 0 (if A is a power-of-2 or 0) + if (match(Op0, m_Add(m_Specific(Op1), m_AllOnes())) && + isKnownToBeAPowerOfTwo(Op1, Q.DL, /*OrZero*/ true, 0, Q.AC, Q.CxtI, Q.DT)) + return Constant::getNullValue(Op1->getType()); + if (match(Op1, m_Add(m_Specific(Op0), m_AllOnes())) && + isKnownToBeAPowerOfTwo(Op0, Q.DL, /*OrZero*/ true, 0, Q.AC, Q.CxtI, Q.DT)) + return Constant::getNullValue(Op0->getType()); + if (Value *V = simplifyAndOrOfCmps(Q, Op0, Op1, true)) return V; @@ -2280,12 +2273,12 @@ computePointerICmp(const DataLayout &DL, const TargetLibraryInfo *TLI, // come from a pointer that cannot overlap with dynamically-allocated // memory within the lifetime of the current function (allocas, byval // arguments, globals), then determine the comparison result here. - SmallVector<Value *, 8> LHSUObjs, RHSUObjs; + SmallVector<const Value *, 8> LHSUObjs, RHSUObjs; GetUnderlyingObjects(LHS, LHSUObjs, DL); GetUnderlyingObjects(RHS, RHSUObjs, DL); // Is the set of underlying objects all noalias calls? - auto IsNAC = [](ArrayRef<Value *> Objects) { + auto IsNAC = [](ArrayRef<const Value *> Objects) { return all_of(Objects, isNoAliasCall); }; @@ -2295,8 +2288,8 @@ computePointerICmp(const DataLayout &DL, const TargetLibraryInfo *TLI, // live with the compared-to allocation). For globals, we exclude symbols // that might be resolve lazily to symbols in another dynamically-loaded // library (and, thus, could be malloc'ed by the implementation). - auto IsAllocDisjoint = [](ArrayRef<Value *> Objects) { - return all_of(Objects, [](Value *V) { + auto IsAllocDisjoint = [](ArrayRef<const Value *> Objects) { + return all_of(Objects, [](const Value *V) { if (const AllocaInst *AI = dyn_cast<AllocaInst>(V)) return AI->getParent() && AI->getFunction() && AI->isStaticAlloca(); if (const GlobalValue *GV = dyn_cast<GlobalValue>(V)) @@ -2472,228 +2465,6 @@ static Value *simplifyICmpWithZero(CmpInst::Predicate Pred, Value *LHS, return nullptr; } -/// Many binary operators with a constant operand have an easy-to-compute -/// range of outputs. This can be used to fold a comparison to always true or -/// always false. -static void setLimitsForBinOp(BinaryOperator &BO, APInt &Lower, APInt &Upper, - const InstrInfoQuery &IIQ) { - unsigned Width = Lower.getBitWidth(); - const APInt *C; - switch (BO.getOpcode()) { - case Instruction::Add: - if (match(BO.getOperand(1), m_APInt(C)) && !C->isNullValue()) { - // FIXME: If we have both nuw and nsw, we should reduce the range further. - if (IIQ.hasNoUnsignedWrap(cast<OverflowingBinaryOperator>(&BO))) { - // 'add nuw x, C' produces [C, UINT_MAX]. - Lower = *C; - } else if (IIQ.hasNoSignedWrap(cast<OverflowingBinaryOperator>(&BO))) { - if (C->isNegative()) { - // 'add nsw x, -C' produces [SINT_MIN, SINT_MAX - C]. - Lower = APInt::getSignedMinValue(Width); - Upper = APInt::getSignedMaxValue(Width) + *C + 1; - } else { - // 'add nsw x, +C' produces [SINT_MIN + C, SINT_MAX]. 
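As a standalone sanity check of the new '(A - 1) & A --> 0' fold added above (plain C++, not part of the patch; the popcount loop is only an illustrative way to model isKnownToBeAPowerOfTwo with OrZero set):

#include <cassert>
#include <cstdint>

// True when A has at most one bit set, i.e. A is zero or a power of two.
static bool isPow2OrZero(uint16_t A) {
  int Bits = 0;
  for (int I = 0; I < 16; ++I)
    Bits += (A >> I) & 1;
  return Bits <= 1;
}

int main() {
  for (uint32_t A = 0; A <= 0xFFFF; ++A) {
    uint16_t V = static_cast<uint16_t>(A);
    if (isPow2OrZero(V))
      // (A - 1) & A == 0 whenever A is zero or a power of two.
      assert(static_cast<uint16_t>((V - 1) & V) == 0);
  }
  return 0;
}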
- Lower = APInt::getSignedMinValue(Width) + *C; - Upper = APInt::getSignedMaxValue(Width) + 1; - } - } - } - break; - - case Instruction::And: - if (match(BO.getOperand(1), m_APInt(C))) - // 'and x, C' produces [0, C]. - Upper = *C + 1; - break; - - case Instruction::Or: - if (match(BO.getOperand(1), m_APInt(C))) - // 'or x, C' produces [C, UINT_MAX]. - Lower = *C; - break; - - case Instruction::AShr: - if (match(BO.getOperand(1), m_APInt(C)) && C->ult(Width)) { - // 'ashr x, C' produces [INT_MIN >> C, INT_MAX >> C]. - Lower = APInt::getSignedMinValue(Width).ashr(*C); - Upper = APInt::getSignedMaxValue(Width).ashr(*C) + 1; - } else if (match(BO.getOperand(0), m_APInt(C))) { - unsigned ShiftAmount = Width - 1; - if (!C->isNullValue() && IIQ.isExact(&BO)) - ShiftAmount = C->countTrailingZeros(); - if (C->isNegative()) { - // 'ashr C, x' produces [C, C >> (Width-1)] - Lower = *C; - Upper = C->ashr(ShiftAmount) + 1; - } else { - // 'ashr C, x' produces [C >> (Width-1), C] - Lower = C->ashr(ShiftAmount); - Upper = *C + 1; - } - } - break; - - case Instruction::LShr: - if (match(BO.getOperand(1), m_APInt(C)) && C->ult(Width)) { - // 'lshr x, C' produces [0, UINT_MAX >> C]. - Upper = APInt::getAllOnesValue(Width).lshr(*C) + 1; - } else if (match(BO.getOperand(0), m_APInt(C))) { - // 'lshr C, x' produces [C >> (Width-1), C]. - unsigned ShiftAmount = Width - 1; - if (!C->isNullValue() && IIQ.isExact(&BO)) - ShiftAmount = C->countTrailingZeros(); - Lower = C->lshr(ShiftAmount); - Upper = *C + 1; - } - break; - - case Instruction::Shl: - if (match(BO.getOperand(0), m_APInt(C))) { - if (IIQ.hasNoUnsignedWrap(&BO)) { - // 'shl nuw C, x' produces [C, C << CLZ(C)] - Lower = *C; - Upper = Lower.shl(Lower.countLeadingZeros()) + 1; - } else if (BO.hasNoSignedWrap()) { // TODO: What if both nuw+nsw? - if (C->isNegative()) { - // 'shl nsw C, x' produces [C << CLO(C)-1, C] - unsigned ShiftAmount = C->countLeadingOnes() - 1; - Lower = C->shl(ShiftAmount); - Upper = *C + 1; - } else { - // 'shl nsw C, x' produces [C, C << CLZ(C)-1] - unsigned ShiftAmount = C->countLeadingZeros() - 1; - Lower = *C; - Upper = C->shl(ShiftAmount) + 1; - } - } - } - break; - - case Instruction::SDiv: - if (match(BO.getOperand(1), m_APInt(C))) { - APInt IntMin = APInt::getSignedMinValue(Width); - APInt IntMax = APInt::getSignedMaxValue(Width); - if (C->isAllOnesValue()) { - // 'sdiv x, -1' produces [INT_MIN + 1, INT_MAX] - // where C != -1 and C != 0 and C != 1 - Lower = IntMin + 1; - Upper = IntMax + 1; - } else if (C->countLeadingZeros() < Width - 1) { - // 'sdiv x, C' produces [INT_MIN / C, INT_MAX / C] - // where C != -1 and C != 0 and C != 1 - Lower = IntMin.sdiv(*C); - Upper = IntMax.sdiv(*C); - if (Lower.sgt(Upper)) - std::swap(Lower, Upper); - Upper = Upper + 1; - assert(Upper != Lower && "Upper part of range has wrapped!"); - } - } else if (match(BO.getOperand(0), m_APInt(C))) { - if (C->isMinSignedValue()) { - // 'sdiv INT_MIN, x' produces [INT_MIN, INT_MIN / -2]. - Lower = *C; - Upper = Lower.lshr(1) + 1; - } else { - // 'sdiv C, x' produces [-|C|, |C|]. - Upper = C->abs() + 1; - Lower = (-Upper) + 1; - } - } - break; - - case Instruction::UDiv: - if (match(BO.getOperand(1), m_APInt(C)) && !C->isNullValue()) { - // 'udiv x, C' produces [0, UINT_MAX / C]. - Upper = APInt::getMaxValue(Width).udiv(*C) + 1; - } else if (match(BO.getOperand(0), m_APInt(C))) { - // 'udiv C, x' produces [0, C]. 
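The per-opcode ranges deleted here are not lost; the patch expects computeConstantRange to recover them. A minimal brute-force check of two of the listed ranges, written as ordinary C++ over an 8-bit model rather than LLVM code:

#include <cassert>
#include <cstdint>

int main() {
  const uint8_t C = 0x2D;  // arbitrary constant operand
  const unsigned Sh = 3;   // arbitrary shift amount < bit width
  for (unsigned X = 0; X <= 0xFF; ++X) {
    uint8_t V = static_cast<uint8_t>(X);
    // 'and x, C' produces a value in [0, C].
    assert(static_cast<uint8_t>(V & C) <= C);
    // 'lshr x, Sh' produces a value in [0, UINT8_MAX >> Sh] in this 8-bit model.
    assert(static_cast<uint8_t>(V >> Sh) <= static_cast<uint8_t>(0xFF >> Sh));
  }
  return 0;
}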
- Upper = *C + 1; - } - break; - - case Instruction::SRem: - if (match(BO.getOperand(1), m_APInt(C))) { - // 'srem x, C' produces (-|C|, |C|). - Upper = C->abs(); - Lower = (-Upper) + 1; - } - break; - - case Instruction::URem: - if (match(BO.getOperand(1), m_APInt(C))) - // 'urem x, C' produces [0, C). - Upper = *C; - break; - - default: - break; - } -} - -/// Some intrinsics with a constant operand have an easy-to-compute range of -/// outputs. This can be used to fold a comparison to always true or always -/// false. -static void setLimitsForIntrinsic(IntrinsicInst &II, APInt &Lower, - APInt &Upper) { - unsigned Width = Lower.getBitWidth(); - const APInt *C; - switch (II.getIntrinsicID()) { - case Intrinsic::uadd_sat: - // uadd.sat(x, C) produces [C, UINT_MAX]. - if (match(II.getOperand(0), m_APInt(C)) || - match(II.getOperand(1), m_APInt(C))) - Lower = *C; - break; - case Intrinsic::sadd_sat: - if (match(II.getOperand(0), m_APInt(C)) || - match(II.getOperand(1), m_APInt(C))) { - if (C->isNegative()) { - // sadd.sat(x, -C) produces [SINT_MIN, SINT_MAX + (-C)]. - Lower = APInt::getSignedMinValue(Width); - Upper = APInt::getSignedMaxValue(Width) + *C + 1; - } else { - // sadd.sat(x, +C) produces [SINT_MIN + C, SINT_MAX]. - Lower = APInt::getSignedMinValue(Width) + *C; - Upper = APInt::getSignedMaxValue(Width) + 1; - } - } - break; - case Intrinsic::usub_sat: - // usub.sat(C, x) produces [0, C]. - if (match(II.getOperand(0), m_APInt(C))) - Upper = *C + 1; - // usub.sat(x, C) produces [0, UINT_MAX - C]. - else if (match(II.getOperand(1), m_APInt(C))) - Upper = APInt::getMaxValue(Width) - *C + 1; - break; - case Intrinsic::ssub_sat: - if (match(II.getOperand(0), m_APInt(C))) { - if (C->isNegative()) { - // ssub.sat(-C, x) produces [SINT_MIN, -SINT_MIN + (-C)]. - Lower = APInt::getSignedMinValue(Width); - Upper = *C - APInt::getSignedMinValue(Width) + 1; - } else { - // ssub.sat(+C, x) produces [-SINT_MAX + C, SINT_MAX]. - Lower = *C - APInt::getSignedMaxValue(Width); - Upper = APInt::getSignedMaxValue(Width) + 1; - } - } else if (match(II.getOperand(1), m_APInt(C))) { - if (C->isNegative()) { - // ssub.sat(x, -C) produces [SINT_MIN - (-C), SINT_MAX]: - Lower = APInt::getSignedMinValue(Width) - *C; - Upper = APInt::getSignedMaxValue(Width) + 1; - } else { - // ssub.sat(x, +C) produces [SINT_MIN, SINT_MAX - C]. - Lower = APInt::getSignedMinValue(Width); - Upper = APInt::getSignedMaxValue(Width) - *C + 1; - } - } - break; - default: - break; - } -} - static Value *simplifyICmpWithConstant(CmpInst::Predicate Pred, Value *LHS, Value *RHS, const InstrInfoQuery &IIQ) { Type *ITy = GetCompareTy(RHS); // The return type. @@ -2721,22 +2492,7 @@ static Value *simplifyICmpWithConstant(CmpInst::Predicate Pred, Value *LHS, if (RHS_CR.isFullSet()) return ConstantInt::getTrue(ITy); - // Find the range of possible values for binary operators. - unsigned Width = C->getBitWidth(); - APInt Lower = APInt(Width, 0); - APInt Upper = APInt(Width, 0); - if (auto *BO = dyn_cast<BinaryOperator>(LHS)) - setLimitsForBinOp(*BO, Lower, Upper, IIQ); - else if (auto *II = dyn_cast<IntrinsicInst>(LHS)) - setLimitsForIntrinsic(*II, Lower, Upper); - - ConstantRange LHS_CR = - Lower != Upper ? 
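Likewise for the deleted setLimitsForIntrinsic cases: the saturating-arithmetic ranges it encoded can be spot-checked with a plain C++ model (uadd_sat/usub_sat below are assumed 8-bit stand-ins for the llvm.uadd.sat/llvm.usub.sat intrinsics, not the real lowering):

#include <cassert>
#include <cstdint>

static uint8_t uadd_sat(uint8_t A, uint8_t B) {
  unsigned S = unsigned(A) + unsigned(B);
  return S > 0xFF ? 0xFF : uint8_t(S);
}
static uint8_t usub_sat(uint8_t A, uint8_t B) {
  return A > B ? uint8_t(A - B) : 0;
}

int main() {
  const uint8_t C = 42;
  for (unsigned X = 0; X <= 0xFF; ++X) {
    uint8_t V = uint8_t(X);
    assert(uadd_sat(V, C) >= C);                  // uadd.sat(x, C) in [C, UINT_MAX]
    assert(usub_sat(C, V) <= C);                  // usub.sat(C, x) in [0, C]
    assert(usub_sat(V, C) <= uint8_t(0xFF - C));  // usub.sat(x, C) in [0, UINT_MAX - C]
  }
  return 0;
}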
ConstantRange(Lower, Upper) : ConstantRange(Width, true); - - if (auto *I = dyn_cast<Instruction>(LHS)) - if (auto *Ranges = IIQ.getMetadata(I, LLVMContext::MD_range)) - LHS_CR = LHS_CR.intersectWith(getConstantRangeFromMetadata(*Ranges)); - + ConstantRange LHS_CR = computeConstantRange(LHS, IIQ.UseInstrInfo); if (!LHS_CR.isFullSet()) { if (RHS_CR.contains(LHS_CR)) return ConstantInt::getTrue(ITy); @@ -3062,44 +2818,6 @@ static Value *simplifyICmpWithBinOp(CmpInst::Predicate Pred, Value *LHS, return nullptr; } -static Value *simplifyICmpWithAbsNabs(CmpInst::Predicate Pred, Value *Op0, - Value *Op1) { - // We need a comparison with a constant. - const APInt *C; - if (!match(Op1, m_APInt(C))) - return nullptr; - - // matchSelectPattern returns the negation part of an abs pattern in SP1. - // If the negate has an NSW flag, abs(INT_MIN) is undefined. Without that - // constraint, we can't make a contiguous range for the result of abs. - ICmpInst::Predicate AbsPred = ICmpInst::BAD_ICMP_PREDICATE; - Value *SP0, *SP1; - SelectPatternFlavor SPF = matchSelectPattern(Op0, SP0, SP1).Flavor; - if (SPF == SelectPatternFlavor::SPF_ABS && - cast<Instruction>(SP1)->hasNoSignedWrap()) - // The result of abs(X) is >= 0 (with nsw). - AbsPred = ICmpInst::ICMP_SGE; - if (SPF == SelectPatternFlavor::SPF_NABS) - // The result of -abs(X) is <= 0. - AbsPred = ICmpInst::ICMP_SLE; - - if (AbsPred == ICmpInst::BAD_ICMP_PREDICATE) - return nullptr; - - // If there is no intersection between abs/nabs and the range of this icmp, - // the icmp must be false. If the abs/nabs range is a subset of the icmp - // range, the icmp must be true. - APInt Zero = APInt::getNullValue(C->getBitWidth()); - ConstantRange AbsRange = ConstantRange::makeExactICmpRegion(AbsPred, Zero); - ConstantRange CmpRange = ConstantRange::makeExactICmpRegion(Pred, *C); - if (AbsRange.intersectWith(CmpRange).isEmptySet()) - return getFalse(GetCompareTy(Op0)); - if (CmpRange.contains(AbsRange)) - return getTrue(GetCompareTy(Op0)); - - return nullptr; -} - /// Simplify integer comparisons where at least one operand of the compare /// matches an integer min/max idiom. static Value *simplifyICmpWithMinMax(CmpInst::Predicate Pred, Value *LHS, @@ -3319,9 +3037,16 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, std::swap(LHS, RHS); Pred = CmpInst::getSwappedPredicate(Pred); } + assert(!isa<UndefValue>(LHS) && "Unexpected icmp undef,%X"); Type *ITy = GetCompareTy(LHS); // The return type. + // For EQ and NE, we can always pick a value for the undef to make the + // predicate pass or fail, so we can return undef. + // Matches behavior in llvm::ConstantFoldCompareInstruction. + if (isa<UndefValue>(RHS) && ICmpInst::isEquality(Pred)) + return UndefValue::get(ITy); + // icmp X, X -> true/false // icmp X, undef -> true/false because undef could be X. if (LHS == RHS || isa<UndefValue>(RHS)) @@ -3531,9 +3256,6 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, if (Value *V = simplifyICmpWithMinMax(Pred, LHS, RHS, Q, MaxRecurse)) return V; - if (Value *V = simplifyICmpWithAbsNabs(Pred, LHS, RHS)) - return V; - // Simplify comparisons of related pointers using a powerful, recursive // GEP-walk when we have target data available.. if (LHS->getType()->isPointerTy()) @@ -3647,6 +3369,8 @@ static Value *SimplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS, } // Handle fcmp with constant RHS. + // TODO: Use match with a specific FP value, so these work with vectors with + // undef lanes. 
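The deleted simplifyICmpWithAbsNabs depended only on the fact that an nsw abs is non-negative, which the generic constant-range path above can now exploit; a small illustrative check of that underlying range (plain C++, not LLVM):

#include <cassert>
#include <cstdint>

int main() {
  // With nsw, abs(INT_MIN) is excluded, so abs(x) always lands in [0, INT_MAX].
  for (int X = INT8_MIN + 1; X <= INT8_MAX; ++X) {
    int8_t A = X < 0 ? int8_t(-X) : int8_t(X); // the abs select pattern
    assert(A >= 0);    // so 'icmp slt (abs nsw X), 0' folds to false
    assert(!(A < -5)); // and any predicate range disjoint from [0, INT_MAX] folds too
  }
  return 0;
}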
const APFloat *C; if (match(RHS, m_APFloat(C))) { // Check whether the constant is an infinity. @@ -3675,28 +3399,7 @@ static Value *SimplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS, } } } - if (C->isZero()) { - switch (Pred) { - case FCmpInst::FCMP_OGE: - if (FMF.noNaNs() && CannotBeOrderedLessThanZero(LHS, Q.TLI)) - return getTrue(RetTy); - break; - case FCmpInst::FCMP_UGE: - if (CannotBeOrderedLessThanZero(LHS, Q.TLI)) - return getTrue(RetTy); - break; - case FCmpInst::FCMP_ULT: - if (FMF.noNaNs() && CannotBeOrderedLessThanZero(LHS, Q.TLI)) - return getFalse(RetTy); - break; - case FCmpInst::FCMP_OLT: - if (CannotBeOrderedLessThanZero(LHS, Q.TLI)) - return getFalse(RetTy); - break; - default: - break; - } - } else if (C->isNegative()) { + if (C->isNegative() && !C->isNegZero()) { assert(!C->isNaN() && "Unexpected NaN constant!"); // TODO: We can catch more cases by using a range check rather than // relying on CannotBeOrderedLessThanZero. @@ -3719,6 +3422,67 @@ static Value *SimplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS, break; } } + + // Check comparison of [minnum/maxnum with constant] with other constant. + const APFloat *C2; + if ((match(LHS, m_Intrinsic<Intrinsic::minnum>(m_Value(), m_APFloat(C2))) && + C2->compare(*C) == APFloat::cmpLessThan) || + (match(LHS, m_Intrinsic<Intrinsic::maxnum>(m_Value(), m_APFloat(C2))) && + C2->compare(*C) == APFloat::cmpGreaterThan)) { + bool IsMaxNum = + cast<IntrinsicInst>(LHS)->getIntrinsicID() == Intrinsic::maxnum; + // The ordered relationship and minnum/maxnum guarantee that we do not + // have NaN constants, so ordered/unordered preds are handled the same. + switch (Pred) { + case FCmpInst::FCMP_OEQ: case FCmpInst::FCMP_UEQ: + // minnum(X, LesserC) == C --> false + // maxnum(X, GreaterC) == C --> false + return getFalse(RetTy); + case FCmpInst::FCMP_ONE: case FCmpInst::FCMP_UNE: + // minnum(X, LesserC) != C --> true + // maxnum(X, GreaterC) != C --> true + return getTrue(RetTy); + case FCmpInst::FCMP_OGE: case FCmpInst::FCMP_UGE: + case FCmpInst::FCMP_OGT: case FCmpInst::FCMP_UGT: + // minnum(X, LesserC) >= C --> false + // minnum(X, LesserC) > C --> false + // maxnum(X, GreaterC) >= C --> true + // maxnum(X, GreaterC) > C --> true + return ConstantInt::get(RetTy, IsMaxNum); + case FCmpInst::FCMP_OLE: case FCmpInst::FCMP_ULE: + case FCmpInst::FCMP_OLT: case FCmpInst::FCMP_ULT: + // minnum(X, LesserC) <= C --> true + // minnum(X, LesserC) < C --> true + // maxnum(X, GreaterC) <= C --> false + // maxnum(X, GreaterC) < C --> false + return ConstantInt::get(RetTy, !IsMaxNum); + default: + // TRUE/FALSE/ORD/UNO should be handled before this. + llvm_unreachable("Unexpected fcmp predicate"); + } + } + } + + if (match(RHS, m_AnyZeroFP())) { + switch (Pred) { + case FCmpInst::FCMP_OGE: + case FCmpInst::FCMP_ULT: + // Positive or zero X >= 0.0 --> true + // Positive or zero X < 0.0 --> false + if ((FMF.noNaNs() || isKnownNeverNaN(LHS, Q.TLI)) && + CannotBeOrderedLessThanZero(LHS, Q.TLI)) + return Pred == FCmpInst::FCMP_OGE ? getTrue(RetTy) : getFalse(RetTy); + break; + case FCmpInst::FCMP_UGE: + case FCmpInst::FCMP_OLT: + // Positive or zero or nan X >= 0.0 --> true + // Positive or zero or nan X < 0.0 --> false + if (CannotBeOrderedLessThanZero(LHS, Q.TLI)) + return Pred == FCmpInst::FCMP_UGE ? 
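A numeric sketch of the new minnum/maxnum-versus-constant fold, using std::fmin as a stand-in for llvm.minnum (both ignore a NaN operand); the sample values are arbitrary:

#include <cassert>
#include <cmath>
#include <limits>

int main() {
  const double C = 2.0, LesserC = 1.0; // LesserC < C, as the fold requires
  const double Samples[] = {-1.0e9, -0.0, 0.5, 1.0, 2.0, 3.0,
                            std::numeric_limits<double>::infinity(),
                            std::numeric_limits<double>::quiet_NaN()};
  for (double X : Samples) {
    double M = std::fmin(X, LesserC); // fmin follows minnum semantics for NaN
    assert(M <= LesserC && M < C);    // so 'fcmp oeq (minnum X, LesserC), C' is false
    assert(!(M >= C));                // and 'oge'/'ogt' are false, 'olt'/'ole' are true
  }
  return 0;
}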
getTrue(RetTy) : getFalse(RetTy); + break; + default: + break; + } } // If the comparison is with the result of a select instruction, check whether @@ -3904,27 +3668,44 @@ static Value *simplifySelectWithICmpCond(Value *CondVal, Value *TrueVal, Pred == ICmpInst::ICMP_EQ)) return V; - // Test for zero-shift-guard-ops around funnel shifts. These are used to - // avoid UB from oversized shifts in raw IR rotate patterns, but the - // intrinsics do not have that problem. + // Test for a bogus zero-shift-guard-op around funnel-shift or rotate. Value *ShAmt; auto isFsh = m_CombineOr(m_Intrinsic<Intrinsic::fshl>(m_Value(X), m_Value(), m_Value(ShAmt)), m_Intrinsic<Intrinsic::fshr>(m_Value(), m_Value(X), m_Value(ShAmt))); - // (ShAmt != 0) ? fshl(X, *, ShAmt) : X --> fshl(X, *, ShAmt) - // (ShAmt != 0) ? fshr(*, X, ShAmt) : X --> fshr(*, X, ShAmt) // (ShAmt == 0) ? fshl(X, *, ShAmt) : X --> X // (ShAmt == 0) ? fshr(*, X, ShAmt) : X --> X - if (match(TrueVal, isFsh) && FalseVal == X && CmpLHS == ShAmt) - return Pred == ICmpInst::ICMP_NE ? TrueVal : X; - - // (ShAmt == 0) ? X : fshl(X, *, ShAmt) --> fshl(X, *, ShAmt) - // (ShAmt == 0) ? X : fshr(*, X, ShAmt) --> fshr(*, X, ShAmt) + if (match(TrueVal, isFsh) && FalseVal == X && CmpLHS == ShAmt && + Pred == ICmpInst::ICMP_EQ) + return X; // (ShAmt != 0) ? X : fshl(X, *, ShAmt) --> X // (ShAmt != 0) ? X : fshr(*, X, ShAmt) --> X - if (match(FalseVal, isFsh) && TrueVal == X && CmpLHS == ShAmt) - return Pred == ICmpInst::ICMP_EQ ? FalseVal : X; + if (match(FalseVal, isFsh) && TrueVal == X && CmpLHS == ShAmt && + Pred == ICmpInst::ICMP_NE) + return X; + + // Test for a zero-shift-guard-op around rotates. These are used to + // avoid UB from oversized shifts in raw IR rotate patterns, but the + // intrinsics do not have that problem. + // We do not allow this transform for the general funnel shift case because + // that would not preserve the poison safety of the original code. + auto isRotate = m_CombineOr(m_Intrinsic<Intrinsic::fshl>(m_Value(X), + m_Deferred(X), + m_Value(ShAmt)), + m_Intrinsic<Intrinsic::fshr>(m_Value(X), + m_Deferred(X), + m_Value(ShAmt))); + // (ShAmt != 0) ? fshl(X, X, ShAmt) : X --> fshl(X, X, ShAmt) + // (ShAmt != 0) ? fshr(X, X, ShAmt) : X --> fshr(X, X, ShAmt) + if (match(TrueVal, isRotate) && FalseVal == X && CmpLHS == ShAmt && + Pred == ICmpInst::ICMP_NE) + return TrueVal; + // (ShAmt == 0) ? X : fshl(X, X, ShAmt) --> fshl(X, X, ShAmt) + // (ShAmt == 0) ? X : fshr(X, X, ShAmt) --> fshr(X, X, ShAmt) + if (match(FalseVal, isRotate) && TrueVal == X && CmpLHS == ShAmt && + Pred == ICmpInst::ICMP_EQ) + return FalseVal; } // Check for other compares that behave like bit test. @@ -4218,6 +3999,17 @@ Value *llvm::SimplifyInsertElementInst(Value *Vec, Value *Val, Value *Idx, if (isa<UndefValue>(Idx)) return UndefValue::get(Vec->getType()); + // Inserting an undef scalar? Assume it is the same value as the existing + // vector element. 
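To see why the rotate form of the zero-shift guard can be dropped, a brute-force check over i8 (rotl8 is a hypothetical helper modelling fshl(X, X, ShAmt), not the real lowering):

#include <cassert>
#include <cstdint>

static uint8_t rotl8(uint8_t X, unsigned Sh) {
  Sh &= 7;
  return Sh ? uint8_t((X << Sh) | (X >> (8 - Sh))) : X;
}

int main() {
  for (unsigned X = 0; X <= 0xFF; ++X)
    for (unsigned Sh = 0; Sh <= 15; ++Sh) {
      uint8_t V = uint8_t(X), A = uint8_t(Sh);
      // The guarded select '(A != 0) ? rotl(V, A) : V' always equals rotl(V, A),
      // because a rotate by zero is the identity.
      uint8_t Guarded = (A != 0) ? rotl8(V, A) : V;
      assert(Guarded == rotl8(V, A));
    }
  return 0;
}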
+ if (isa<UndefValue>(Val)) + return Vec; + + // If we are extracting a value from a vector, then inserting it into the same + // place, that's the input vector: + // insertelt Vec, (extractelt Vec, Idx), Idx --> Vec + if (match(Val, m_ExtractElement(m_Specific(Vec), m_Specific(Idx)))) + return Vec; + return nullptr; } @@ -4495,6 +4287,33 @@ Value *llvm::SimplifyShuffleVectorInst(Value *Op0, Value *Op1, Constant *Mask, return ::SimplifyShuffleVectorInst(Op0, Op1, Mask, RetTy, Q, RecursionLimit); } +static Constant *foldConstant(Instruction::UnaryOps Opcode, + Value *&Op, const SimplifyQuery &Q) { + if (auto *C = dyn_cast<Constant>(Op)) + return ConstantFoldUnaryOpOperand(Opcode, C, Q.DL); + return nullptr; +} + +/// Given the operand for an FNeg, see if we can fold the result. If not, this +/// returns null. +static Value *simplifyFNegInst(Value *Op, FastMathFlags FMF, + const SimplifyQuery &Q, unsigned MaxRecurse) { + if (Constant *C = foldConstant(Instruction::FNeg, Op, Q)) + return C; + + Value *X; + // fneg (fneg X) ==> X + if (match(Op, m_FNeg(m_Value(X)))) + return X; + + return nullptr; +} + +Value *llvm::SimplifyFNegInst(Value *Op, FastMathFlags FMF, + const SimplifyQuery &Q) { + return ::simplifyFNegInst(Op, FMF, Q, RecursionLimit); +} + static Constant *propagateNaN(Constant *In) { // If the input is a vector with undef elements, just return a default NaN. if (!In->isNaN()) @@ -4536,16 +4355,22 @@ static Value *SimplifyFAddInst(Value *Op0, Value *Op1, FastMathFlags FMF, (FMF.noSignedZeros() || CannotBeNegativeZero(Op0, Q.TLI))) return Op0; - // With nnan: (+/-0.0 - X) + X --> 0.0 (and commuted variant) + // With nnan: -X + X --> 0.0 (and commuted variant) // We don't have to explicitly exclude infinities (ninf): INF + -INF == NaN. // Negative zeros are allowed because we always end up with positive zero: // X = -0.0: (-0.0 - (-0.0)) + (-0.0) == ( 0.0) + (-0.0) == 0.0 // X = -0.0: ( 0.0 - (-0.0)) + (-0.0) == ( 0.0) + (-0.0) == 0.0 // X = 0.0: (-0.0 - ( 0.0)) + ( 0.0) == (-0.0) + ( 0.0) == 0.0 // X = 0.0: ( 0.0 - ( 0.0)) + ( 0.0) == ( 0.0) + ( 0.0) == 0.0 - if (FMF.noNaNs() && (match(Op0, m_FSub(m_AnyZeroFP(), m_Specific(Op1))) || - match(Op1, m_FSub(m_AnyZeroFP(), m_Specific(Op0))))) - return ConstantFP::getNullValue(Op0->getType()); + if (FMF.noNaNs()) { + if (match(Op0, m_FSub(m_AnyZeroFP(), m_Specific(Op1))) || + match(Op1, m_FSub(m_AnyZeroFP(), m_Specific(Op0)))) + return ConstantFP::getNullValue(Op0->getType()); + + if (match(Op0, m_FNeg(m_Specific(Op1))) || + match(Op1, m_FNeg(m_Specific(Op0)))) + return ConstantFP::getNullValue(Op0->getType()); + } // (X - Y) + Y --> X // Y + (X - Y) --> X @@ -4578,14 +4403,17 @@ static Value *SimplifyFSubInst(Value *Op0, Value *Op1, FastMathFlags FMF, return Op0; // fsub -0.0, (fsub -0.0, X) ==> X + // fsub -0.0, (fneg X) ==> X Value *X; if (match(Op0, m_NegZeroFP()) && - match(Op1, m_FSub(m_NegZeroFP(), m_Value(X)))) + match(Op1, m_FNeg(m_Value(X)))) return X; // fsub 0.0, (fsub 0.0, X) ==> X if signed zeros are ignored. + // fsub 0.0, (fneg X) ==> X if signed zeros are ignored. if (FMF.noSignedZeros() && match(Op0, m_AnyZeroFP()) && - match(Op1, m_FSub(m_AnyZeroFP(), m_Value(X)))) + (match(Op1, m_FSub(m_AnyZeroFP(), m_Value(X))) || + match(Op1, m_FNeg(m_Value(X))))) return X; // fsub nnan x, x ==> 0.0 @@ -4722,6 +4550,42 @@ Value *llvm::SimplifyFRemInst(Value *Op0, Value *Op1, FastMathFlags FMF, //=== Helper functions for higher up the class hierarchy. +/// Given the operand for a UnaryOperator, see if we can fold the result. 
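A quick standalone check (plain C++, not LLVM) of the fneg identities the new simplifyFNegInst and the updated fadd/fsub patterns rely on, including the sign of zero:

#include <cassert>
#include <cmath>

int main() {
  const double Samples[] = {0.0, -0.0, 1.5, -2.25, 1.0e308, -1.0e-308};
  for (double X : Samples) {
    // fneg (fneg X) ==> X, exactly (including the sign of zero).
    double R1 = -(-X);
    assert(R1 == X && std::signbit(R1) == std::signbit(X));
    // fsub -0.0, (fneg X) ==> X: -0.0 - (-X) reproduces X exactly.
    double R2 = -0.0 - (-X);
    assert(R2 == X && std::signbit(R2) == std::signbit(X));
  }
  return 0;
}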
+/// If not, this returns null. +static Value *simplifyUnOp(unsigned Opcode, Value *Op, const SimplifyQuery &Q, + unsigned MaxRecurse) { + switch (Opcode) { + case Instruction::FNeg: + return simplifyFNegInst(Op, FastMathFlags(), Q, MaxRecurse); + default: + llvm_unreachable("Unexpected opcode"); + } +} + +/// Given the operand for a UnaryOperator, see if we can fold the result. +/// If not, this returns null. +/// In contrast to SimplifyUnOp, try to use FastMathFlag when folding the +/// result. In case we don't need FastMathFlags, simply fall to SimplifyUnOp. +static Value *simplifyFPUnOp(unsigned Opcode, Value *Op, + const FastMathFlags &FMF, + const SimplifyQuery &Q, unsigned MaxRecurse) { + switch (Opcode) { + case Instruction::FNeg: + return simplifyFNegInst(Op, FMF, Q, MaxRecurse); + default: + return simplifyUnOp(Opcode, Op, Q, MaxRecurse); + } +} + +Value *llvm::SimplifyUnOp(unsigned Opcode, Value *Op, const SimplifyQuery &Q) { + return ::simplifyUnOp(Opcode, Op, Q, RecursionLimit); +} + +Value *llvm::SimplifyFPUnOp(unsigned Opcode, Value *Op, FastMathFlags FMF, + const SimplifyQuery &Q) { + return ::simplifyFPUnOp(Opcode, Op, FMF, Q, RecursionLimit); +} + /// Given operands for a BinaryOperator, see if we can fold the result. /// If not, this returns null. static Value *SimplifyBinOp(unsigned Opcode, Value *LHS, Value *RHS, @@ -4885,22 +4749,6 @@ static Value *SimplifyRelativeLoad(Constant *Ptr, Constant *Offset, return ConstantExpr::getBitCast(LoadedLHSPtr, Int8PtrTy); } -static bool maskIsAllZeroOrUndef(Value *Mask) { - auto *ConstMask = dyn_cast<Constant>(Mask); - if (!ConstMask) - return false; - if (ConstMask->isNullValue() || isa<UndefValue>(ConstMask)) - return true; - for (unsigned I = 0, E = ConstMask->getType()->getVectorNumElements(); I != E; - ++I) { - if (auto *MaskElt = ConstMask->getAggregateElement(I)) - if (MaskElt->isNullValue() || isa<UndefValue>(MaskElt)) - continue; - return false; - } - return true; -} - static Value *simplifyUnaryIntrinsic(Function *F, Value *Op0, const SimplifyQuery &Q) { // Idempotent functions return the same result when called repeatedly. @@ -4941,8 +4789,32 @@ static Value *simplifyUnaryIntrinsic(Function *F, Value *Op0, case Intrinsic::log2: // log2(exp2(x)) -> x if (Q.CxtI->hasAllowReassoc() && - match(Op0, m_Intrinsic<Intrinsic::exp2>(m_Value(X)))) return X; + (match(Op0, m_Intrinsic<Intrinsic::exp2>(m_Value(X))) || + match(Op0, m_Intrinsic<Intrinsic::pow>(m_SpecificFP(2.0), + m_Value(X))))) return X; + break; + case Intrinsic::log10: + // log10(pow(10.0, x)) -> x + if (Q.CxtI->hasAllowReassoc() && + match(Op0, m_Intrinsic<Intrinsic::pow>(m_SpecificFP(10.0), + m_Value(X)))) return X; break; + case Intrinsic::floor: + case Intrinsic::trunc: + case Intrinsic::ceil: + case Intrinsic::round: + case Intrinsic::nearbyint: + case Intrinsic::rint: { + // floor (sitofp x) -> sitofp x + // floor (uitofp x) -> uitofp x + // + // Converting from int always results in a finite integral number or + // infinity. For either of those inputs, these rounding functions always + // return the same value, so the rounding can be eliminated. 
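The log2/log10 folds above require allow-reassoc because the round trip is not guaranteed to be bit-exact; a small illustration under that assumption (plain C++, sample values arbitrary):

#include <cmath>
#include <cstdio>

int main() {
  // log2(exp2(x)) matches x only up to rounding, which is why the fold is
  // gated on the allow-reassoc fast-math flag rather than applied always.
  const double Samples[] = {0.1, 1.0, 7.3, -4.2, 123.456};
  for (double X : Samples) {
    double RoundTrip = std::log2(std::exp2(X));
    std::printf("x=%.17g  log2(exp2(x))=%.17g  exact=%d\n", X, RoundTrip,
                RoundTrip == X);
  }
  return 0;
}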
+ if (match(Op0, m_SIToFP(m_Value())) || match(Op0, m_UIToFP(m_Value()))) + return Op0; + break; + } default: break; } @@ -4960,16 +4832,19 @@ static Value *simplifyBinaryIntrinsic(Function *F, Value *Op0, Value *Op1, // X - X -> { 0, false } if (Op0 == Op1) return Constant::getNullValue(ReturnType); - // X - undef -> undef - // undef - X -> undef - if (isa<UndefValue>(Op0) || isa<UndefValue>(Op1)) - return UndefValue::get(ReturnType); - break; + LLVM_FALLTHROUGH; case Intrinsic::uadd_with_overflow: case Intrinsic::sadd_with_overflow: - // X + undef -> undef - if (isa<UndefValue>(Op0) || isa<UndefValue>(Op1)) - return UndefValue::get(ReturnType); + // X - undef -> { undef, false } + // undef - X -> { undef, false } + // X + undef -> { undef, false } + // undef + x -> { undef, false } + if (isa<UndefValue>(Op0) || isa<UndefValue>(Op1)) { + return ConstantStruct::get( + cast<StructType>(ReturnType), + {UndefValue::get(ReturnType->getStructElementType(0)), + Constant::getNullValue(ReturnType->getStructElementType(1))}); + } break; case Intrinsic::umul_with_overflow: case Intrinsic::smul_with_overflow: @@ -5085,26 +4960,28 @@ static Value *simplifyBinaryIntrinsic(Function *F, Value *Op0, Value *Op1, return nullptr; } -template <typename IterTy> -static Value *simplifyIntrinsic(Function *F, IterTy ArgBegin, IterTy ArgEnd, - const SimplifyQuery &Q) { +static Value *simplifyIntrinsic(CallBase *Call, const SimplifyQuery &Q) { + // Intrinsics with no operands have some kind of side effect. Don't simplify. - unsigned NumOperands = std::distance(ArgBegin, ArgEnd); - if (NumOperands == 0) + unsigned NumOperands = Call->getNumArgOperands(); + if (!NumOperands) return nullptr; + Function *F = cast<Function>(Call->getCalledFunction()); Intrinsic::ID IID = F->getIntrinsicID(); if (NumOperands == 1) - return simplifyUnaryIntrinsic(F, ArgBegin[0], Q); + return simplifyUnaryIntrinsic(F, Call->getArgOperand(0), Q); if (NumOperands == 2) - return simplifyBinaryIntrinsic(F, ArgBegin[0], ArgBegin[1], Q); + return simplifyBinaryIntrinsic(F, Call->getArgOperand(0), + Call->getArgOperand(1), Q); // Handle intrinsics with 3 or more arguments. switch (IID) { - case Intrinsic::masked_load: { - Value *MaskArg = ArgBegin[2]; - Value *PassthruArg = ArgBegin[3]; + case Intrinsic::masked_load: + case Intrinsic::masked_gather: { + Value *MaskArg = Call->getArgOperand(2); + Value *PassthruArg = Call->getArgOperand(3); // If the mask is all zeros or undef, the "passthru" argument is the result. if (maskIsAllZeroOrUndef(MaskArg)) return PassthruArg; @@ -5112,7 +4989,8 @@ static Value *simplifyIntrinsic(Function *F, IterTy ArgBegin, IterTy ArgEnd, } case Intrinsic::fshl: case Intrinsic::fshr: { - Value *Op0 = ArgBegin[0], *Op1 = ArgBegin[1], *ShAmtArg = ArgBegin[2]; + Value *Op0 = Call->getArgOperand(0), *Op1 = Call->getArgOperand(1), + *ShAmtArg = Call->getArgOperand(2); // If both operands are undef, the result is undef. if (match(Op0, m_Undef()) && match(Op1, m_Undef())) @@ -5120,15 +4998,14 @@ static Value *simplifyIntrinsic(Function *F, IterTy ArgBegin, IterTy ArgEnd, // If shift amount is undef, assume it is zero. if (match(ShAmtArg, m_Undef())) - return ArgBegin[IID == Intrinsic::fshl ? 0 : 1]; + return Call->getArgOperand(IID == Intrinsic::fshl ? 0 : 1); const APInt *ShAmtC; if (match(ShAmtArg, m_APInt(ShAmtC))) { // If there's effectively no shift, return the 1st arg or 2nd arg. - // TODO: For vectors, we could check each element of a non-splat constant. 
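A standalone confirmation of the new rounding-intrinsic fold: converting an integer to double always yields an integral value, so floor/trunc/ceil leave it unchanged (i32 sampled here for brevity):

#include <cassert>
#include <cmath>
#include <cstdint>

int main() {
  // floor (sitofp x) -> sitofp x: the converted value is already integral
  // (or infinity for very wide integers), so the rounding call is a no-op.
  for (int32_t I = -100000; I <= 100000; I += 7) {
    double D = static_cast<double>(I); // sitofp
    assert(std::floor(D) == D && std::trunc(D) == D && std::ceil(D) == D);
  }
  return 0;
}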
APInt BitWidth = APInt(ShAmtC->getBitWidth(), ShAmtC->getBitWidth()); if (ShAmtC->urem(BitWidth).isNullValue()) - return ArgBegin[IID == Intrinsic::fshl ? 0 : 1]; + return Call->getArgOperand(IID == Intrinsic::fshl ? 0 : 1); } return nullptr; } @@ -5137,58 +5014,36 @@ static Value *simplifyIntrinsic(Function *F, IterTy ArgBegin, IterTy ArgEnd, } } -template <typename IterTy> -static Value *SimplifyCall(ImmutableCallSite CS, Value *V, IterTy ArgBegin, - IterTy ArgEnd, const SimplifyQuery &Q, - unsigned MaxRecurse) { - Type *Ty = V->getType(); - if (PointerType *PTy = dyn_cast<PointerType>(Ty)) - Ty = PTy->getElementType(); - FunctionType *FTy = cast<FunctionType>(Ty); +Value *llvm::SimplifyCall(CallBase *Call, const SimplifyQuery &Q) { + Value *Callee = Call->getCalledValue(); // call undef -> undef // call null -> undef - if (isa<UndefValue>(V) || isa<ConstantPointerNull>(V)) - return UndefValue::get(FTy->getReturnType()); + if (isa<UndefValue>(Callee) || isa<ConstantPointerNull>(Callee)) + return UndefValue::get(Call->getType()); - Function *F = dyn_cast<Function>(V); + Function *F = dyn_cast<Function>(Callee); if (!F) return nullptr; if (F->isIntrinsic()) - if (Value *Ret = simplifyIntrinsic(F, ArgBegin, ArgEnd, Q)) + if (Value *Ret = simplifyIntrinsic(Call, Q)) return Ret; - if (!canConstantFoldCallTo(CS, F)) + if (!canConstantFoldCallTo(Call, F)) return nullptr; SmallVector<Constant *, 4> ConstantArgs; - ConstantArgs.reserve(ArgEnd - ArgBegin); - for (IterTy I = ArgBegin, E = ArgEnd; I != E; ++I) { - Constant *C = dyn_cast<Constant>(*I); + unsigned NumArgs = Call->getNumArgOperands(); + ConstantArgs.reserve(NumArgs); + for (auto &Arg : Call->args()) { + Constant *C = dyn_cast<Constant>(&Arg); if (!C) return nullptr; ConstantArgs.push_back(C); } - return ConstantFoldCall(CS, F, ConstantArgs, Q.TLI); -} - -Value *llvm::SimplifyCall(ImmutableCallSite CS, Value *V, - User::op_iterator ArgBegin, User::op_iterator ArgEnd, - const SimplifyQuery &Q) { - return ::SimplifyCall(CS, V, ArgBegin, ArgEnd, Q, RecursionLimit); -} - -Value *llvm::SimplifyCall(ImmutableCallSite CS, Value *V, - ArrayRef<Value *> Args, const SimplifyQuery &Q) { - return ::SimplifyCall(CS, V, Args.begin(), Args.end(), Q, RecursionLimit); -} - -Value *llvm::SimplifyCall(ImmutableCallSite ICS, const SimplifyQuery &Q) { - CallSite CS(const_cast<Instruction*>(ICS.getInstruction())); - return ::SimplifyCall(CS, CS.getCalledValue(), CS.arg_begin(), CS.arg_end(), - Q, RecursionLimit); + return ConstantFoldCall(Call, F, ConstantArgs, Q.TLI); } /// See if we can compute a simplified version of this instruction. @@ -5203,6 +5058,9 @@ Value *llvm::SimplifyInstruction(Instruction *I, const SimplifyQuery &SQ, default: Result = ConstantFoldInstruction(I, Q.DL, Q.TLI); break; + case Instruction::FNeg: + Result = SimplifyFNegInst(I->getOperand(0), I->getFastMathFlags(), Q); + break; case Instruction::FAdd: Result = SimplifyFAddInst(I->getOperand(0), I->getOperand(1), I->getFastMathFlags(), Q); @@ -5327,8 +5185,7 @@ Value *llvm::SimplifyInstruction(Instruction *I, const SimplifyQuery &SQ, Result = SimplifyPHINode(cast<PHINode>(I), Q); break; case Instruction::Call: { - CallSite CS(cast<CallInst>(I)); - Result = SimplifyCall(CS, Q); + Result = SimplifyCall(cast<CallInst>(I), Q); break; } #define HANDLE_CAST_INST(num, opc, clas) case Instruction::opc: |
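Finally, a scalar model of the fshl fold retained above: when the shift amount is a multiple of the bit width there is effectively no shift, so the first operand is returned (fshl8 is a hypothetical i8 model of llvm.fshl, not the real lowering):

#include <cassert>
#include <cstdint>

// Concatenate {Hi,Lo}, shift left by ShAmt modulo the bit width, keep the
// high 8 bits -- the defined semantics of a funnel shift left.
static uint8_t fshl8(uint8_t Hi, uint8_t Lo, unsigned ShAmt) {
  unsigned Sh = ShAmt % 8;
  uint16_t Concat = uint16_t((uint16_t(Hi) << 8) | Lo);
  return uint8_t((Concat << Sh) >> 8);
}

int main() {
  const unsigned ShAmts[] = {0, 8, 16, 64}; // all multiples of the bit width
  for (unsigned Hi = 0; Hi <= 0xFF; Hi += 5)
    for (unsigned Lo = 0; Lo <= 0xFF; Lo += 7)
      for (unsigned ShAmt : ShAmts)
        // With no effective shift, fshl returns its first operand.
        assert(fshl8(uint8_t(Hi), uint8_t(Lo), ShAmt) == uint8_t(Hi));
  return 0;
}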