diff options
Diffstat (limited to 'llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp')
| -rw-r--r-- | llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp | 76 |
1 files changed, 58 insertions, 18 deletions
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp index 15b51ae8a5ee..e357a9da8b12 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp @@ -55,7 +55,7 @@ static bool ShrinkDemandedConstant(Instruction *I, unsigned OpNo, bool InstCombinerImpl::SimplifyDemandedInstructionBits(Instruction &Inst) { unsigned BitWidth = Inst.getType()->getScalarSizeInBits(); KnownBits Known(BitWidth); - APInt DemandedMask(APInt::getAllOnesValue(BitWidth)); + APInt DemandedMask(APInt::getAllOnes(BitWidth)); Value *V = SimplifyDemandedUseBits(&Inst, DemandedMask, Known, 0, &Inst); @@ -124,7 +124,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, } Known.resetAll(); - if (DemandedMask.isNullValue()) // Not demanding any bits from V. + if (DemandedMask.isZero()) // Not demanding any bits from V. return UndefValue::get(VTy); if (Depth == MaxAnalysisRecursionDepth) @@ -274,8 +274,8 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, // constant because that's a canonical 'not' op, and that is better for // combining, SCEV, and codegen. const APInt *C; - if (match(I->getOperand(1), m_APInt(C)) && !C->isAllOnesValue()) { - if ((*C | ~DemandedMask).isAllOnesValue()) { + if (match(I->getOperand(1), m_APInt(C)) && !C->isAllOnes()) { + if ((*C | ~DemandedMask).isAllOnes()) { // Force bits to 1 to create a 'not' op. I->setOperand(1, ConstantInt::getAllOnesValue(VTy)); return I; @@ -385,8 +385,26 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, Known = KnownBits::commonBits(LHSKnown, RHSKnown); break; } - case Instruction::ZExt: case Instruction::Trunc: { + // If we do not demand the high bits of a right-shifted and truncated value, + // then we may be able to truncate it before the shift. + Value *X; + const APInt *C; + if (match(I->getOperand(0), m_OneUse(m_LShr(m_Value(X), m_APInt(C))))) { + // The shift amount must be valid (not poison) in the narrow type, and + // it must not be greater than the high bits demanded of the result. + if (C->ult(I->getType()->getScalarSizeInBits()) && + C->ule(DemandedMask.countLeadingZeros())) { + // trunc (lshr X, C) --> lshr (trunc X), C + IRBuilderBase::InsertPointGuard Guard(Builder); + Builder.SetInsertPoint(I); + Value *Trunc = Builder.CreateTrunc(X, I->getType()); + return Builder.CreateLShr(Trunc, C->getZExtValue()); + } + } + } + LLVM_FALLTHROUGH; + case Instruction::ZExt: { unsigned SrcBitWidth = I->getOperand(0)->getType()->getScalarSizeInBits(); APInt InputDemandedMask = DemandedMask.zextOrTrunc(SrcBitWidth); @@ -516,8 +534,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, return I->getOperand(0); // We can't do this with the LHS for subtraction, unless we are only // demanding the LSB. - if ((I->getOpcode() == Instruction::Add || - DemandedFromOps.isOneValue()) && + if ((I->getOpcode() == Instruction::Add || DemandedFromOps.isOne()) && DemandedFromOps.isSubsetOf(LHSKnown.Zero)) return I->getOperand(1); @@ -615,7 +632,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, // always convert this into a logical shr, even if the shift amount is // variable. The low bit of the shift cannot be an input sign bit unless // the shift amount is >= the size of the datatype, which is undefined. - if (DemandedMask.isOneValue()) { + if (DemandedMask.isOne()) { // Perform the logical shift right. Instruction *NewVal = BinaryOperator::CreateLShr( I->getOperand(0), I->getOperand(1), I->getName()); @@ -743,7 +760,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, } case Instruction::URem: { KnownBits Known2(BitWidth); - APInt AllOnes = APInt::getAllOnesValue(BitWidth); + APInt AllOnes = APInt::getAllOnes(BitWidth); if (SimplifyDemandedBits(I, 0, AllOnes, Known2, Depth + 1) || SimplifyDemandedBits(I, 1, AllOnes, Known2, Depth + 1)) return I; @@ -829,6 +846,29 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, KnownBitsComputed = true; break; } + case Intrinsic::umax: { + // UMax(A, C) == A if ... + // The lowest non-zero bit of DemandMask is higher than the highest + // non-zero bit of C. + const APInt *C; + unsigned CTZ = DemandedMask.countTrailingZeros(); + if (match(II->getArgOperand(1), m_APInt(C)) && + CTZ >= C->getActiveBits()) + return II->getArgOperand(0); + break; + } + case Intrinsic::umin: { + // UMin(A, C) == A if ... + // The lowest non-zero bit of DemandMask is higher than the highest + // non-one bit of C. + // This comes from using DeMorgans on the above umax example. + const APInt *C; + unsigned CTZ = DemandedMask.countTrailingZeros(); + if (match(II->getArgOperand(1), m_APInt(C)) && + CTZ >= C->getBitWidth() - C->countLeadingOnes()) + return II->getArgOperand(0); + break; + } default: { // Handle target specific intrinsics Optional<Value *> V = targetSimplifyDemandedUseBitsIntrinsic( @@ -1021,8 +1061,8 @@ Value *InstCombinerImpl::simplifyShrShlDemandedBits( Known.Zero.setLowBits(ShlAmt - 1); Known.Zero &= DemandedMask; - APInt BitMask1(APInt::getAllOnesValue(BitWidth)); - APInt BitMask2(APInt::getAllOnesValue(BitWidth)); + APInt BitMask1(APInt::getAllOnes(BitWidth)); + APInt BitMask2(APInt::getAllOnes(BitWidth)); bool isLshr = (Shr->getOpcode() == Instruction::LShr); BitMask1 = isLshr ? (BitMask1.lshr(ShrAmt) << ShlAmt) : @@ -1088,7 +1128,7 @@ Value *InstCombinerImpl::SimplifyDemandedVectorElts(Value *V, return nullptr; unsigned VWidth = cast<FixedVectorType>(V->getType())->getNumElements(); - APInt EltMask(APInt::getAllOnesValue(VWidth)); + APInt EltMask(APInt::getAllOnes(VWidth)); assert((DemandedElts & ~EltMask) == 0 && "Invalid DemandedElts!"); if (match(V, m_Undef())) { @@ -1097,7 +1137,7 @@ Value *InstCombinerImpl::SimplifyDemandedVectorElts(Value *V, return nullptr; } - if (DemandedElts.isNullValue()) { // If nothing is demanded, provide poison. + if (DemandedElts.isZero()) { // If nothing is demanded, provide poison. UndefElts = EltMask; return PoisonValue::get(V->getType()); } @@ -1107,7 +1147,7 @@ Value *InstCombinerImpl::SimplifyDemandedVectorElts(Value *V, if (auto *C = dyn_cast<Constant>(V)) { // Check if this is identity. If so, return 0 since we are not simplifying // anything. - if (DemandedElts.isAllOnesValue()) + if (DemandedElts.isAllOnes()) return nullptr; Type *EltTy = cast<VectorType>(V->getType())->getElementType(); @@ -1260,7 +1300,7 @@ Value *InstCombinerImpl::SimplifyDemandedVectorElts(Value *V, // Handle trivial case of a splat. Only check the first element of LHS // operand. if (all_of(Shuffle->getShuffleMask(), [](int Elt) { return Elt == 0; }) && - DemandedElts.isAllOnesValue()) { + DemandedElts.isAllOnes()) { if (!match(I->getOperand(1), m_Undef())) { I->setOperand(1, PoisonValue::get(I->getOperand(1)->getType())); MadeChange = true; @@ -1515,8 +1555,8 @@ Value *InstCombinerImpl::SimplifyDemandedVectorElts(Value *V, // Subtlety: If we load from a pointer, the pointer must be valid // regardless of whether the element is demanded. Doing otherwise risks // segfaults which didn't exist in the original program. - APInt DemandedPtrs(APInt::getAllOnesValue(VWidth)), - DemandedPassThrough(DemandedElts); + APInt DemandedPtrs(APInt::getAllOnes(VWidth)), + DemandedPassThrough(DemandedElts); if (auto *CV = dyn_cast<ConstantVector>(II->getOperand(2))) for (unsigned i = 0; i < VWidth; i++) { Constant *CElt = CV->getAggregateElement(i); @@ -1568,7 +1608,7 @@ Value *InstCombinerImpl::SimplifyDemandedVectorElts(Value *V, // If we've proven all of the lanes undef, return an undef value. // TODO: Intersect w/demanded lanes - if (UndefElts.isAllOnesValue()) + if (UndefElts.isAllOnes()) return UndefValue::get(I->getType());; return MadeChange ? I : nullptr; |
