aboutsummaryrefslogtreecommitdiff
path: root/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
diff options
context:
space:
mode:
authorDimitry Andric <dim@FreeBSD.org>2022-07-04 19:20:19 +0000
committerDimitry Andric <dim@FreeBSD.org>2023-02-08 19:02:26 +0000
commit81ad626541db97eb356e2c1d4a20eb2a26a766ab (patch)
tree311b6a8987c32b1e1dcbab65c54cfac3fdb56175 /contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
parent5fff09660e06a66bed6482da9c70df328e16bbb6 (diff)
parent145449b1e420787bb99721a429341fa6be3adfb6 (diff)
downloadsrc-81ad626541db97eb356e2c1d4a20eb2a26a766ab.tar.gz
src-81ad626541db97eb356e2c1d4a20eb2a26a766ab.zip
Diffstat (limited to 'contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp')
-rw-r--r--contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp202
1 files changed, 127 insertions, 75 deletions
diff --git a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
index 3f064cfda712..9d4c01ac03e2 100644
--- a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
+++ b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
@@ -12,8 +12,8 @@
//===----------------------------------------------------------------------===//
#include "InstCombineInternal.h"
-#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
+#include "llvm/IR/GetElementPtrTypeIterator.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/Support/KnownBits.h"
@@ -154,6 +154,29 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
if (Depth == 0 && !V->hasOneUse())
DemandedMask.setAllBits();
+ // If the high-bits of an ADD/SUB/MUL are not demanded, then we do not care
+ // about the high bits of the operands.
+ auto simplifyOperandsBasedOnUnusedHighBits = [&](APInt &DemandedFromOps) {
+ unsigned NLZ = DemandedMask.countLeadingZeros();
+ // Right fill the mask of bits for the operands to demand the most
+ // significant bit and all those below it.
+ DemandedFromOps = APInt::getLowBitsSet(BitWidth, BitWidth - NLZ);
+ if (ShrinkDemandedConstant(I, 0, DemandedFromOps) ||
+ SimplifyDemandedBits(I, 0, DemandedFromOps, LHSKnown, Depth + 1) ||
+ ShrinkDemandedConstant(I, 1, DemandedFromOps) ||
+ SimplifyDemandedBits(I, 1, DemandedFromOps, RHSKnown, Depth + 1)) {
+ if (NLZ > 0) {
+ // Disable the nsw and nuw flags here: We can no longer guarantee that
+ // we won't wrap after simplification. Removing the nsw/nuw flags is
+ // legal here because the top bit is not demanded.
+ I->setHasNoSignedWrap(false);
+ I->setHasNoUnsignedWrap(false);
+ }
+ return true;
+ }
+ return false;
+ };
+
switch (I->getOpcode()) {
default:
computeKnownBits(I, Known, Depth, CxtI);
@@ -297,13 +320,11 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
(LHSKnown.One & RHSKnown.One & DemandedMask) != 0) {
APInt NewMask = ~(LHSKnown.One & RHSKnown.One & DemandedMask);
- Constant *AndC =
- ConstantInt::get(I->getType(), NewMask & AndRHS->getValue());
+ Constant *AndC = ConstantInt::get(VTy, NewMask & AndRHS->getValue());
Instruction *NewAnd = BinaryOperator::CreateAnd(I->getOperand(0), AndC);
InsertNewInstWith(NewAnd, *I);
- Constant *XorC =
- ConstantInt::get(I->getType(), NewMask & XorRHS->getValue());
+ Constant *XorC = ConstantInt::get(VTy, NewMask & XorRHS->getValue());
Instruction *NewXor = BinaryOperator::CreateXor(NewAnd, XorC);
return InsertNewInstWith(NewXor, *I);
}
@@ -311,33 +332,6 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
break;
}
case Instruction::Select: {
- Value *LHS, *RHS;
- SelectPatternFlavor SPF = matchSelectPattern(I, LHS, RHS).Flavor;
- if (SPF == SPF_UMAX) {
- // UMax(A, C) == A if ...
- // The lowest non-zero bit of DemandMask is higher than the highest
- // non-zero bit of C.
- const APInt *C;
- unsigned CTZ = DemandedMask.countTrailingZeros();
- if (match(RHS, m_APInt(C)) && CTZ >= C->getActiveBits())
- return LHS;
- } else if (SPF == SPF_UMIN) {
- // UMin(A, C) == A if ...
- // The lowest non-zero bit of DemandMask is higher than the highest
- // non-one bit of C.
- // This comes from using DeMorgans on the above umax example.
- const APInt *C;
- unsigned CTZ = DemandedMask.countTrailingZeros();
- if (match(RHS, m_APInt(C)) &&
- CTZ >= C->getBitWidth() - C->countLeadingOnes())
- return LHS;
- }
-
- // If this is a select as part of any other min/max pattern, don't simplify
- // any further in case we break the structure.
- if (SPF != SPF_UNKNOWN)
- return nullptr;
-
if (SimplifyDemandedBits(I, 2, DemandedMask, RHSKnown, Depth + 1) ||
SimplifyDemandedBits(I, 1, DemandedMask, LHSKnown, Depth + 1))
return I;
@@ -393,12 +387,12 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
if (match(I->getOperand(0), m_OneUse(m_LShr(m_Value(X), m_APInt(C))))) {
// The shift amount must be valid (not poison) in the narrow type, and
// it must not be greater than the high bits demanded of the result.
- if (C->ult(I->getType()->getScalarSizeInBits()) &&
+ if (C->ult(VTy->getScalarSizeInBits()) &&
C->ule(DemandedMask.countLeadingZeros())) {
// trunc (lshr X, C) --> lshr (trunc X), C
IRBuilderBase::InsertPointGuard Guard(Builder);
Builder.SetInsertPoint(I);
- Value *Trunc = Builder.CreateTrunc(X, I->getType());
+ Value *Trunc = Builder.CreateTrunc(X, VTy);
return Builder.CreateLShr(Trunc, C->getZExtValue());
}
}
@@ -420,9 +414,8 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
if (!I->getOperand(0)->getType()->isIntOrIntVectorTy())
return nullptr; // vector->int or fp->int?
- if (VectorType *DstVTy = dyn_cast<VectorType>(I->getType())) {
- if (VectorType *SrcVTy =
- dyn_cast<VectorType>(I->getOperand(0)->getType())) {
+ if (auto *DstVTy = dyn_cast<VectorType>(VTy)) {
+ if (auto *SrcVTy = dyn_cast<VectorType>(I->getOperand(0)->getType())) {
if (cast<FixedVectorType>(DstVTy)->getNumElements() !=
cast<FixedVectorType>(SrcVTy)->getNumElements())
// Don't touch a bitcast between vectors of different element counts.
@@ -507,26 +500,9 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
}
LLVM_FALLTHROUGH;
case Instruction::Sub: {
- /// If the high-bits of an ADD/SUB are not demanded, then we do not care
- /// about the high bits of the operands.
- unsigned NLZ = DemandedMask.countLeadingZeros();
- // Right fill the mask of bits for this ADD/SUB to demand the most
- // significant bit and all those below it.
- APInt DemandedFromOps(APInt::getLowBitsSet(BitWidth, BitWidth-NLZ));
- if (ShrinkDemandedConstant(I, 0, DemandedFromOps) ||
- SimplifyDemandedBits(I, 0, DemandedFromOps, LHSKnown, Depth + 1) ||
- ShrinkDemandedConstant(I, 1, DemandedFromOps) ||
- SimplifyDemandedBits(I, 1, DemandedFromOps, RHSKnown, Depth + 1)) {
- if (NLZ > 0) {
- // Disable the nsw and nuw flags here: We can no longer guarantee that
- // we won't wrap after simplification. Removing the nsw/nuw flags is
- // legal here because the top bit is not demanded.
- BinaryOperator &BinOP = *cast<BinaryOperator>(I);
- BinOP.setHasNoSignedWrap(false);
- BinOP.setHasNoUnsignedWrap(false);
- }
+ APInt DemandedFromOps;
+ if (simplifyOperandsBasedOnUnusedHighBits(DemandedFromOps))
return I;
- }
// If we are known to be adding/subtracting zeros to every bit below
// the highest demanded bit, we just return the other side.
@@ -544,6 +520,36 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
NSW, LHSKnown, RHSKnown);
break;
}
+ case Instruction::Mul: {
+ APInt DemandedFromOps;
+ if (simplifyOperandsBasedOnUnusedHighBits(DemandedFromOps))
+ return I;
+
+ if (DemandedMask.isPowerOf2()) {
+ // The LSB of X*Y is set only if (X & 1) == 1 and (Y & 1) == 1.
+ // If we demand exactly one bit N and we have "X * (C' << N)" where C' is
+ // odd (has LSB set), then the left-shifted low bit of X is the answer.
+ unsigned CTZ = DemandedMask.countTrailingZeros();
+ const APInt *C;
+ if (match(I->getOperand(1), m_APInt(C)) &&
+ C->countTrailingZeros() == CTZ) {
+ Constant *ShiftC = ConstantInt::get(VTy, CTZ);
+ Instruction *Shl = BinaryOperator::CreateShl(I->getOperand(0), ShiftC);
+ return InsertNewInstWith(Shl, *I);
+ }
+ }
+ // For a squared value "X * X", the bottom 2 bits are 0 and X[0] because:
+ // X * X is odd iff X is odd.
+ // 'Quadratic Reciprocity': X * X -> 0 for bit[1]
+ if (I->getOperand(0) == I->getOperand(1) && DemandedMask.ult(4)) {
+ Constant *One = ConstantInt::get(VTy, 1);
+ Instruction *And1 = BinaryOperator::CreateAnd(I->getOperand(0), One);
+ return InsertNewInstWith(And1, *I);
+ }
+
+ computeKnownBits(I, Known, Depth, CxtI);
+ break;
+ }
case Instruction::Shl: {
const APInt *SA;
if (match(I->getOperand(1), m_APInt(SA))) {
@@ -554,7 +560,26 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
DemandedMask, Known))
return R;
+ // TODO: If we only want bits that already match the signbit then we don't
+ // need to shift.
+
+ // If we can pre-shift a right-shifted constant to the left without
+ // losing any high bits and we don't demand the low bits, then eliminate
+ // the left-shift:
+ // (C >> X) << LeftShiftAmtC --> (C << RightShiftAmtC) >> X
uint64_t ShiftAmt = SA->getLimitedValue(BitWidth-1);
+ Value *X;
+ Constant *C;
+ if (DemandedMask.countTrailingZeros() >= ShiftAmt &&
+ match(I->getOperand(0), m_LShr(m_ImmConstant(C), m_Value(X)))) {
+ Constant *LeftShiftAmtC = ConstantInt::get(VTy, ShiftAmt);
+ Constant *NewC = ConstantExpr::getShl(C, LeftShiftAmtC);
+ if (ConstantExpr::getLShr(NewC, LeftShiftAmtC) == C) {
+ Instruction *Lshr = BinaryOperator::CreateLShr(NewC, X);
+ return InsertNewInstWith(Lshr, *I);
+ }
+ }
+
APInt DemandedMaskIn(DemandedMask.lshr(ShiftAmt));
// If the shift is NUW/NSW, then it does demand the high bits.
@@ -584,7 +609,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
else if (SignBitOne)
Known.One.setSignBit();
if (Known.hasConflict())
- return UndefValue::get(I->getType());
+ return UndefValue::get(VTy);
}
} else {
// This is a variable shift, so we can't shift the demand mask by a known
@@ -607,6 +632,34 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
if (match(I->getOperand(1), m_APInt(SA))) {
uint64_t ShiftAmt = SA->getLimitedValue(BitWidth-1);
+ // If we are just demanding the shifted sign bit and below, then this can
+ // be treated as an ASHR in disguise.
+ if (DemandedMask.countLeadingZeros() >= ShiftAmt) {
+ // If we only want bits that already match the signbit then we don't
+ // need to shift.
+ unsigned NumHiDemandedBits =
+ BitWidth - DemandedMask.countTrailingZeros();
+ unsigned SignBits =
+ ComputeNumSignBits(I->getOperand(0), Depth + 1, CxtI);
+ if (SignBits >= NumHiDemandedBits)
+ return I->getOperand(0);
+
+ // If we can pre-shift a left-shifted constant to the right without
+ // losing any low bits (we already know we don't demand the high bits),
+ // then eliminate the right-shift:
+ // (C << X) >> RightShiftAmtC --> (C >> RightShiftAmtC) << X
+ Value *X;
+ Constant *C;
+ if (match(I->getOperand(0), m_Shl(m_ImmConstant(C), m_Value(X)))) {
+ Constant *RightShiftAmtC = ConstantInt::get(VTy, ShiftAmt);
+ Constant *NewC = ConstantExpr::getLShr(C, RightShiftAmtC);
+ if (ConstantExpr::getShl(NewC, RightShiftAmtC) == C) {
+ Instruction *Shl = BinaryOperator::CreateShl(NewC, X);
+ return InsertNewInstWith(Shl, *I);
+ }
+ }
+ }
+
// Unsigned shift right.
APInt DemandedMaskIn(DemandedMask.shl(ShiftAmt));
@@ -628,6 +681,14 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
break;
}
case Instruction::AShr: {
+ unsigned SignBits = ComputeNumSignBits(I->getOperand(0), Depth + 1, CxtI);
+
+ // If we only want bits that already match the signbit then we don't need
+ // to shift.
+ unsigned NumHiDemandedBits = BitWidth - DemandedMask.countTrailingZeros();
+ if (SignBits >= NumHiDemandedBits)
+ return I->getOperand(0);
+
// If this is an arithmetic shift right and only the low-bit is set, we can
// always convert this into a logical shr, even if the shift amount is
// variable. The low bit of the shift cannot be an input sign bit unless
@@ -639,11 +700,6 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
return InsertNewInstWith(NewVal, *I);
}
- // If the sign bit is the only bit demanded by this ashr, then there is no
- // need to do it, the shift doesn't change the high bit.
- if (DemandedMask.isSignMask())
- return I->getOperand(0);
-
const APInt *SA;
if (match(I->getOperand(1), m_APInt(SA))) {
uint32_t ShiftAmt = SA->getLimitedValue(BitWidth-1);
@@ -663,8 +719,6 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
if (SimplifyDemandedBits(I, 0, DemandedMaskIn, Known, Depth + 1))
return I;
- unsigned SignBits = ComputeNumSignBits(I->getOperand(0), Depth + 1, CxtI);
-
assert(!Known.hasConflict() && "Bits known to be one AND zero?");
// Compute the new bits that are at the top now plus sign bits.
APInt HighBits(APInt::getHighBitsSet(
@@ -713,13 +767,13 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
break;
}
case Instruction::SRem: {
- ConstantInt *Rem;
- if (match(I->getOperand(1), m_ConstantInt(Rem))) {
+ const APInt *Rem;
+ if (match(I->getOperand(1), m_APInt(Rem))) {
// X % -1 demands all the bits because we don't want to introduce
// INT_MIN % -1 (== undef) by accident.
- if (Rem->isMinusOne())
+ if (Rem->isAllOnes())
break;
- APInt RA = Rem->getValue().abs();
+ APInt RA = Rem->abs();
if (RA.isPowerOf2()) {
if (DemandedMask.ult(RA)) // srem won't affect demanded bits
return I->getOperand(0);
@@ -786,7 +840,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
if (DemandedMask == 1 && VTy->getScalarSizeInBits() % 2 == 0 &&
match(II->getArgOperand(0), m_Not(m_Value(X)))) {
Function *Ctpop = Intrinsic::getDeclaration(
- II->getModule(), Intrinsic::ctpop, II->getType());
+ II->getModule(), Intrinsic::ctpop, VTy);
return InsertNewInstWith(CallInst::Create(Ctpop, {X}), *I);
}
break;
@@ -809,12 +863,10 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
Instruction *NewVal;
if (NLZ > NTZ)
NewVal = BinaryOperator::CreateLShr(
- II->getArgOperand(0),
- ConstantInt::get(I->getType(), NLZ - NTZ));
+ II->getArgOperand(0), ConstantInt::get(VTy, NLZ - NTZ));
else
NewVal = BinaryOperator::CreateShl(
- II->getArgOperand(0),
- ConstantInt::get(I->getType(), NTZ - NLZ));
+ II->getArgOperand(0), ConstantInt::get(VTy, NTZ - NLZ));
NewVal->takeName(I);
return InsertNewInstWith(NewVal, *I);
}
@@ -872,7 +924,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
// Handle target specific intrinsics
Optional<Value *> V = targetSimplifyDemandedUseBitsIntrinsic(
*II, DemandedMask, Known, KnownBitsComputed);
- if (V.hasValue())
+ if (V)
return V.getValue();
break;
}
@@ -1583,7 +1635,7 @@ Value *InstCombinerImpl::SimplifyDemandedVectorElts(Value *V,
Optional<Value *> V = targetSimplifyDemandedVectorEltsIntrinsic(
*II, DemandedElts, UndefElts, UndefElts2, UndefElts3,
simplifyAndSetOp);
- if (V.hasValue())
+ if (V)
return V.getValue();
break;
}