summaryrefslogtreecommitdiff
path: root/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp')
-rw-r--r--llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp76
1 files changed, 58 insertions, 18 deletions
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
index 15b51ae8a5ee..e357a9da8b12 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
@@ -55,7 +55,7 @@ static bool ShrinkDemandedConstant(Instruction *I, unsigned OpNo,
bool InstCombinerImpl::SimplifyDemandedInstructionBits(Instruction &Inst) {
unsigned BitWidth = Inst.getType()->getScalarSizeInBits();
KnownBits Known(BitWidth);
- APInt DemandedMask(APInt::getAllOnesValue(BitWidth));
+ APInt DemandedMask(APInt::getAllOnes(BitWidth));
Value *V = SimplifyDemandedUseBits(&Inst, DemandedMask, Known,
0, &Inst);
@@ -124,7 +124,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
}
Known.resetAll();
- if (DemandedMask.isNullValue()) // Not demanding any bits from V.
+ if (DemandedMask.isZero()) // Not demanding any bits from V.
return UndefValue::get(VTy);
if (Depth == MaxAnalysisRecursionDepth)
@@ -274,8 +274,8 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
// constant because that's a canonical 'not' op, and that is better for
// combining, SCEV, and codegen.
const APInt *C;
- if (match(I->getOperand(1), m_APInt(C)) && !C->isAllOnesValue()) {
- if ((*C | ~DemandedMask).isAllOnesValue()) {
+ if (match(I->getOperand(1), m_APInt(C)) && !C->isAllOnes()) {
+ if ((*C | ~DemandedMask).isAllOnes()) {
// Force bits to 1 to create a 'not' op.
I->setOperand(1, ConstantInt::getAllOnesValue(VTy));
return I;
@@ -385,8 +385,26 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
Known = KnownBits::commonBits(LHSKnown, RHSKnown);
break;
}
- case Instruction::ZExt:
case Instruction::Trunc: {
+ // If we do not demand the high bits of a right-shifted and truncated value,
+ // then we may be able to truncate it before the shift.
+ Value *X;
+ const APInt *C;
+ if (match(I->getOperand(0), m_OneUse(m_LShr(m_Value(X), m_APInt(C))))) {
+ // The shift amount must be valid (not poison) in the narrow type, and
+ // it must not be greater than the high bits demanded of the result.
+ if (C->ult(I->getType()->getScalarSizeInBits()) &&
+ C->ule(DemandedMask.countLeadingZeros())) {
+ // trunc (lshr X, C) --> lshr (trunc X), C
+ IRBuilderBase::InsertPointGuard Guard(Builder);
+ Builder.SetInsertPoint(I);
+ Value *Trunc = Builder.CreateTrunc(X, I->getType());
+ return Builder.CreateLShr(Trunc, C->getZExtValue());
+ }
+ }
+ }
+ LLVM_FALLTHROUGH;
+ case Instruction::ZExt: {
unsigned SrcBitWidth = I->getOperand(0)->getType()->getScalarSizeInBits();
APInt InputDemandedMask = DemandedMask.zextOrTrunc(SrcBitWidth);
@@ -516,8 +534,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
return I->getOperand(0);
// We can't do this with the LHS for subtraction, unless we are only
// demanding the LSB.
- if ((I->getOpcode() == Instruction::Add ||
- DemandedFromOps.isOneValue()) &&
+ if ((I->getOpcode() == Instruction::Add || DemandedFromOps.isOne()) &&
DemandedFromOps.isSubsetOf(LHSKnown.Zero))
return I->getOperand(1);
@@ -615,7 +632,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
// always convert this into a logical shr, even if the shift amount is
// variable. The low bit of the shift cannot be an input sign bit unless
// the shift amount is >= the size of the datatype, which is undefined.
- if (DemandedMask.isOneValue()) {
+ if (DemandedMask.isOne()) {
// Perform the logical shift right.
Instruction *NewVal = BinaryOperator::CreateLShr(
I->getOperand(0), I->getOperand(1), I->getName());
@@ -743,7 +760,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
}
case Instruction::URem: {
KnownBits Known2(BitWidth);
- APInt AllOnes = APInt::getAllOnesValue(BitWidth);
+ APInt AllOnes = APInt::getAllOnes(BitWidth);
if (SimplifyDemandedBits(I, 0, AllOnes, Known2, Depth + 1) ||
SimplifyDemandedBits(I, 1, AllOnes, Known2, Depth + 1))
return I;
@@ -829,6 +846,29 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
KnownBitsComputed = true;
break;
}
+ case Intrinsic::umax: {
+ // UMax(A, C) == A if ...
+ // The lowest non-zero bit of DemandMask is higher than the highest
+ // non-zero bit of C.
+ const APInt *C;
+ unsigned CTZ = DemandedMask.countTrailingZeros();
+ if (match(II->getArgOperand(1), m_APInt(C)) &&
+ CTZ >= C->getActiveBits())
+ return II->getArgOperand(0);
+ break;
+ }
+ case Intrinsic::umin: {
+ // UMin(A, C) == A if ...
+ // The lowest non-zero bit of DemandMask is higher than the highest
+ // non-one bit of C.
+ // This comes from using DeMorgans on the above umax example.
+ const APInt *C;
+ unsigned CTZ = DemandedMask.countTrailingZeros();
+ if (match(II->getArgOperand(1), m_APInt(C)) &&
+ CTZ >= C->getBitWidth() - C->countLeadingOnes())
+ return II->getArgOperand(0);
+ break;
+ }
default: {
// Handle target specific intrinsics
Optional<Value *> V = targetSimplifyDemandedUseBitsIntrinsic(
@@ -1021,8 +1061,8 @@ Value *InstCombinerImpl::simplifyShrShlDemandedBits(
Known.Zero.setLowBits(ShlAmt - 1);
Known.Zero &= DemandedMask;
- APInt BitMask1(APInt::getAllOnesValue(BitWidth));
- APInt BitMask2(APInt::getAllOnesValue(BitWidth));
+ APInt BitMask1(APInt::getAllOnes(BitWidth));
+ APInt BitMask2(APInt::getAllOnes(BitWidth));
bool isLshr = (Shr->getOpcode() == Instruction::LShr);
BitMask1 = isLshr ? (BitMask1.lshr(ShrAmt) << ShlAmt) :
@@ -1088,7 +1128,7 @@ Value *InstCombinerImpl::SimplifyDemandedVectorElts(Value *V,
return nullptr;
unsigned VWidth = cast<FixedVectorType>(V->getType())->getNumElements();
- APInt EltMask(APInt::getAllOnesValue(VWidth));
+ APInt EltMask(APInt::getAllOnes(VWidth));
assert((DemandedElts & ~EltMask) == 0 && "Invalid DemandedElts!");
if (match(V, m_Undef())) {
@@ -1097,7 +1137,7 @@ Value *InstCombinerImpl::SimplifyDemandedVectorElts(Value *V,
return nullptr;
}
- if (DemandedElts.isNullValue()) { // If nothing is demanded, provide poison.
+ if (DemandedElts.isZero()) { // If nothing is demanded, provide poison.
UndefElts = EltMask;
return PoisonValue::get(V->getType());
}
@@ -1107,7 +1147,7 @@ Value *InstCombinerImpl::SimplifyDemandedVectorElts(Value *V,
if (auto *C = dyn_cast<Constant>(V)) {
// Check if this is identity. If so, return 0 since we are not simplifying
// anything.
- if (DemandedElts.isAllOnesValue())
+ if (DemandedElts.isAllOnes())
return nullptr;
Type *EltTy = cast<VectorType>(V->getType())->getElementType();
@@ -1260,7 +1300,7 @@ Value *InstCombinerImpl::SimplifyDemandedVectorElts(Value *V,
// Handle trivial case of a splat. Only check the first element of LHS
// operand.
if (all_of(Shuffle->getShuffleMask(), [](int Elt) { return Elt == 0; }) &&
- DemandedElts.isAllOnesValue()) {
+ DemandedElts.isAllOnes()) {
if (!match(I->getOperand(1), m_Undef())) {
I->setOperand(1, PoisonValue::get(I->getOperand(1)->getType()));
MadeChange = true;
@@ -1515,8 +1555,8 @@ Value *InstCombinerImpl::SimplifyDemandedVectorElts(Value *V,
// Subtlety: If we load from a pointer, the pointer must be valid
// regardless of whether the element is demanded. Doing otherwise risks
// segfaults which didn't exist in the original program.
- APInt DemandedPtrs(APInt::getAllOnesValue(VWidth)),
- DemandedPassThrough(DemandedElts);
+ APInt DemandedPtrs(APInt::getAllOnes(VWidth)),
+ DemandedPassThrough(DemandedElts);
if (auto *CV = dyn_cast<ConstantVector>(II->getOperand(2)))
for (unsigned i = 0; i < VWidth; i++) {
Constant *CElt = CV->getAggregateElement(i);
@@ -1568,7 +1608,7 @@ Value *InstCombinerImpl::SimplifyDemandedVectorElts(Value *V,
// If we've proven all of the lanes undef, return an undef value.
// TODO: Intersect w/demanded lanes
- if (UndefElts.isAllOnesValue())
+ if (UndefElts.isAllOnes())
return UndefValue::get(I->getType());;
return MadeChange ? I : nullptr;