diff options
author | Dimitry Andric <dim@FreeBSD.org> | 2022-07-04 19:20:19 +0000 |
---|---|---|
committer | Dimitry Andric <dim@FreeBSD.org> | 2023-02-08 19:02:26 +0000 |
commit | 81ad626541db97eb356e2c1d4a20eb2a26a766ab (patch) | |
tree | 311b6a8987c32b1e1dcbab65c54cfac3fdb56175 /contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp | |
parent | 5fff09660e06a66bed6482da9c70df328e16bbb6 (diff) | |
parent | 145449b1e420787bb99721a429341fa6be3adfb6 (diff) | |
download | src-81ad626541db97eb356e2c1d4a20eb2a26a766ab.tar.gz src-81ad626541db97eb356e2c1d4a20eb2a26a766ab.zip |
Diffstat (limited to 'contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp')
-rw-r--r-- | contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp | 202 |
1 file changed, 127 insertions, 75 deletions
diff --git a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp index 3f064cfda712..9d4c01ac03e2 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp @@ -12,8 +12,8 @@ //===----------------------------------------------------------------------===// #include "InstCombineInternal.h" -#include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/GetElementPtrTypeIterator.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/PatternMatch.h" #include "llvm/Support/KnownBits.h" @@ -154,6 +154,29 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, if (Depth == 0 && !V->hasOneUse()) DemandedMask.setAllBits(); + // If the high-bits of an ADD/SUB/MUL are not demanded, then we do not care + // about the high bits of the operands. + auto simplifyOperandsBasedOnUnusedHighBits = [&](APInt &DemandedFromOps) { + unsigned NLZ = DemandedMask.countLeadingZeros(); + // Right fill the mask of bits for the operands to demand the most + // significant bit and all those below it. + DemandedFromOps = APInt::getLowBitsSet(BitWidth, BitWidth - NLZ); + if (ShrinkDemandedConstant(I, 0, DemandedFromOps) || + SimplifyDemandedBits(I, 0, DemandedFromOps, LHSKnown, Depth + 1) || + ShrinkDemandedConstant(I, 1, DemandedFromOps) || + SimplifyDemandedBits(I, 1, DemandedFromOps, RHSKnown, Depth + 1)) { + if (NLZ > 0) { + // Disable the nsw and nuw flags here: We can no longer guarantee that + // we won't wrap after simplification. Removing the nsw/nuw flags is + // legal here because the top bit is not demanded. 
+ I->setHasNoSignedWrap(false); + I->setHasNoUnsignedWrap(false); + } + return true; + } + return false; + }; + switch (I->getOpcode()) { default: computeKnownBits(I, Known, Depth, CxtI); @@ -297,13 +320,11 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, (LHSKnown.One & RHSKnown.One & DemandedMask) != 0) { APInt NewMask = ~(LHSKnown.One & RHSKnown.One & DemandedMask); - Constant *AndC = - ConstantInt::get(I->getType(), NewMask & AndRHS->getValue()); + Constant *AndC = ConstantInt::get(VTy, NewMask & AndRHS->getValue()); Instruction *NewAnd = BinaryOperator::CreateAnd(I->getOperand(0), AndC); InsertNewInstWith(NewAnd, *I); - Constant *XorC = - ConstantInt::get(I->getType(), NewMask & XorRHS->getValue()); + Constant *XorC = ConstantInt::get(VTy, NewMask & XorRHS->getValue()); Instruction *NewXor = BinaryOperator::CreateXor(NewAnd, XorC); return InsertNewInstWith(NewXor, *I); } @@ -311,33 +332,6 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, break; } case Instruction::Select: { - Value *LHS, *RHS; - SelectPatternFlavor SPF = matchSelectPattern(I, LHS, RHS).Flavor; - if (SPF == SPF_UMAX) { - // UMax(A, C) == A if ... - // The lowest non-zero bit of DemandMask is higher than the highest - // non-zero bit of C. - const APInt *C; - unsigned CTZ = DemandedMask.countTrailingZeros(); - if (match(RHS, m_APInt(C)) && CTZ >= C->getActiveBits()) - return LHS; - } else if (SPF == SPF_UMIN) { - // UMin(A, C) == A if ... - // The lowest non-zero bit of DemandMask is higher than the highest - // non-one bit of C. - // This comes from using DeMorgans on the above umax example. - const APInt *C; - unsigned CTZ = DemandedMask.countTrailingZeros(); - if (match(RHS, m_APInt(C)) && - CTZ >= C->getBitWidth() - C->countLeadingOnes()) - return LHS; - } - - // If this is a select as part of any other min/max pattern, don't simplify - // any further in case we break the structure. 
- if (SPF != SPF_UNKNOWN) - return nullptr; - if (SimplifyDemandedBits(I, 2, DemandedMask, RHSKnown, Depth + 1) || SimplifyDemandedBits(I, 1, DemandedMask, LHSKnown, Depth + 1)) return I; @@ -393,12 +387,12 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, if (match(I->getOperand(0), m_OneUse(m_LShr(m_Value(X), m_APInt(C))))) { // The shift amount must be valid (not poison) in the narrow type, and // it must not be greater than the high bits demanded of the result. - if (C->ult(I->getType()->getScalarSizeInBits()) && + if (C->ult(VTy->getScalarSizeInBits()) && C->ule(DemandedMask.countLeadingZeros())) { // trunc (lshr X, C) --> lshr (trunc X), C IRBuilderBase::InsertPointGuard Guard(Builder); Builder.SetInsertPoint(I); - Value *Trunc = Builder.CreateTrunc(X, I->getType()); + Value *Trunc = Builder.CreateTrunc(X, VTy); return Builder.CreateLShr(Trunc, C->getZExtValue()); } } @@ -420,9 +414,8 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, if (!I->getOperand(0)->getType()->isIntOrIntVectorTy()) return nullptr; // vector->int or fp->int? - if (VectorType *DstVTy = dyn_cast<VectorType>(I->getType())) { - if (VectorType *SrcVTy = - dyn_cast<VectorType>(I->getOperand(0)->getType())) { + if (auto *DstVTy = dyn_cast<VectorType>(VTy)) { + if (auto *SrcVTy = dyn_cast<VectorType>(I->getOperand(0)->getType())) { if (cast<FixedVectorType>(DstVTy)->getNumElements() != cast<FixedVectorType>(SrcVTy)->getNumElements()) // Don't touch a bitcast between vectors of different element counts. @@ -507,26 +500,9 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, } LLVM_FALLTHROUGH; case Instruction::Sub: { - /// If the high-bits of an ADD/SUB are not demanded, then we do not care - /// about the high bits of the operands. - unsigned NLZ = DemandedMask.countLeadingZeros(); - // Right fill the mask of bits for this ADD/SUB to demand the most - // significant bit and all those below it. 
- APInt DemandedFromOps(APInt::getLowBitsSet(BitWidth, BitWidth-NLZ)); - if (ShrinkDemandedConstant(I, 0, DemandedFromOps) || - SimplifyDemandedBits(I, 0, DemandedFromOps, LHSKnown, Depth + 1) || - ShrinkDemandedConstant(I, 1, DemandedFromOps) || - SimplifyDemandedBits(I, 1, DemandedFromOps, RHSKnown, Depth + 1)) { - if (NLZ > 0) { - // Disable the nsw and nuw flags here: We can no longer guarantee that - // we won't wrap after simplification. Removing the nsw/nuw flags is - // legal here because the top bit is not demanded. - BinaryOperator &BinOP = *cast<BinaryOperator>(I); - BinOP.setHasNoSignedWrap(false); - BinOP.setHasNoUnsignedWrap(false); - } + APInt DemandedFromOps; + if (simplifyOperandsBasedOnUnusedHighBits(DemandedFromOps)) return I; - } // If we are known to be adding/subtracting zeros to every bit below // the highest demanded bit, we just return the other side. @@ -544,6 +520,36 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, NSW, LHSKnown, RHSKnown); break; } + case Instruction::Mul: { + APInt DemandedFromOps; + if (simplifyOperandsBasedOnUnusedHighBits(DemandedFromOps)) + return I; + + if (DemandedMask.isPowerOf2()) { + // The LSB of X*Y is set only if (X & 1) == 1 and (Y & 1) == 1. + // If we demand exactly one bit N and we have "X * (C' << N)" where C' is + // odd (has LSB set), then the left-shifted low bit of X is the answer. + unsigned CTZ = DemandedMask.countTrailingZeros(); + const APInt *C; + if (match(I->getOperand(1), m_APInt(C)) && + C->countTrailingZeros() == CTZ) { + Constant *ShiftC = ConstantInt::get(VTy, CTZ); + Instruction *Shl = BinaryOperator::CreateShl(I->getOperand(0), ShiftC); + return InsertNewInstWith(Shl, *I); + } + } + // For a squared value "X * X", the bottom 2 bits are 0 and X[0] because: + // X * X is odd iff X is odd. 
+ // 'Quadratic Reciprocity': X * X -> 0 for bit[1] + if (I->getOperand(0) == I->getOperand(1) && DemandedMask.ult(4)) { + Constant *One = ConstantInt::get(VTy, 1); + Instruction *And1 = BinaryOperator::CreateAnd(I->getOperand(0), One); + return InsertNewInstWith(And1, *I); + } + + computeKnownBits(I, Known, Depth, CxtI); + break; + } case Instruction::Shl: { const APInt *SA; if (match(I->getOperand(1), m_APInt(SA))) { @@ -554,7 +560,26 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, DemandedMask, Known)) return R; + // TODO: If we only want bits that already match the signbit then we don't + // need to shift. + + // If we can pre-shift a right-shifted constant to the left without + // losing any high bits amd we don't demand the low bits, then eliminate + // the left-shift: + // (C >> X) << LeftShiftAmtC --> (C << RightShiftAmtC) >> X uint64_t ShiftAmt = SA->getLimitedValue(BitWidth-1); + Value *X; + Constant *C; + if (DemandedMask.countTrailingZeros() >= ShiftAmt && + match(I->getOperand(0), m_LShr(m_ImmConstant(C), m_Value(X)))) { + Constant *LeftShiftAmtC = ConstantInt::get(VTy, ShiftAmt); + Constant *NewC = ConstantExpr::getShl(C, LeftShiftAmtC); + if (ConstantExpr::getLShr(NewC, LeftShiftAmtC) == C) { + Instruction *Lshr = BinaryOperator::CreateLShr(NewC, X); + return InsertNewInstWith(Lshr, *I); + } + } + APInt DemandedMaskIn(DemandedMask.lshr(ShiftAmt)); // If the shift is NUW/NSW, then it does demand the high bits. 
@@ -584,7 +609,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, else if (SignBitOne) Known.One.setSignBit(); if (Known.hasConflict()) - return UndefValue::get(I->getType()); + return UndefValue::get(VTy); } } else { // This is a variable shift, so we can't shift the demand mask by a known @@ -607,6 +632,34 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, if (match(I->getOperand(1), m_APInt(SA))) { uint64_t ShiftAmt = SA->getLimitedValue(BitWidth-1); + // If we are just demanding the shifted sign bit and below, then this can + // be treated as an ASHR in disguise. + if (DemandedMask.countLeadingZeros() >= ShiftAmt) { + // If we only want bits that already match the signbit then we don't + // need to shift. + unsigned NumHiDemandedBits = + BitWidth - DemandedMask.countTrailingZeros(); + unsigned SignBits = + ComputeNumSignBits(I->getOperand(0), Depth + 1, CxtI); + if (SignBits >= NumHiDemandedBits) + return I->getOperand(0); + + // If we can pre-shift a left-shifted constant to the right without + // losing any low bits (we already know we don't demand the high bits), + // then eliminate the right-shift: + // (C << X) >> RightShiftAmtC --> (C >> RightShiftAmtC) << X + Value *X; + Constant *C; + if (match(I->getOperand(0), m_Shl(m_ImmConstant(C), m_Value(X)))) { + Constant *RightShiftAmtC = ConstantInt::get(VTy, ShiftAmt); + Constant *NewC = ConstantExpr::getLShr(C, RightShiftAmtC); + if (ConstantExpr::getShl(NewC, RightShiftAmtC) == C) { + Instruction *Shl = BinaryOperator::CreateShl(NewC, X); + return InsertNewInstWith(Shl, *I); + } + } + } + // Unsigned shift right. 
APInt DemandedMaskIn(DemandedMask.shl(ShiftAmt)); @@ -628,6 +681,14 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, break; } case Instruction::AShr: { + unsigned SignBits = ComputeNumSignBits(I->getOperand(0), Depth + 1, CxtI); + + // If we only want bits that already match the signbit then we don't need + // to shift. + unsigned NumHiDemandedBits = BitWidth - DemandedMask.countTrailingZeros(); + if (SignBits >= NumHiDemandedBits) + return I->getOperand(0); + // If this is an arithmetic shift right and only the low-bit is set, we can // always convert this into a logical shr, even if the shift amount is // variable. The low bit of the shift cannot be an input sign bit unless @@ -639,11 +700,6 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, return InsertNewInstWith(NewVal, *I); } - // If the sign bit is the only bit demanded by this ashr, then there is no - // need to do it, the shift doesn't change the high bit. - if (DemandedMask.isSignMask()) - return I->getOperand(0); - const APInt *SA; if (match(I->getOperand(1), m_APInt(SA))) { uint32_t ShiftAmt = SA->getLimitedValue(BitWidth-1); @@ -663,8 +719,6 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, if (SimplifyDemandedBits(I, 0, DemandedMaskIn, Known, Depth + 1)) return I; - unsigned SignBits = ComputeNumSignBits(I->getOperand(0), Depth + 1, CxtI); - assert(!Known.hasConflict() && "Bits known to be one AND zero?"); // Compute the new bits that are at the top now plus sign bits. APInt HighBits(APInt::getHighBitsSet( @@ -713,13 +767,13 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, break; } case Instruction::SRem: { - ConstantInt *Rem; - if (match(I->getOperand(1), m_ConstantInt(Rem))) { + const APInt *Rem; + if (match(I->getOperand(1), m_APInt(Rem))) { // X % -1 demands all the bits because we don't want to introduce // INT_MIN % -1 (== undef) by accident. 
- if (Rem->isMinusOne()) + if (Rem->isAllOnes()) break; - APInt RA = Rem->getValue().abs(); + APInt RA = Rem->abs(); if (RA.isPowerOf2()) { if (DemandedMask.ult(RA)) // srem won't affect demanded bits return I->getOperand(0); @@ -786,7 +840,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, if (DemandedMask == 1 && VTy->getScalarSizeInBits() % 2 == 0 && match(II->getArgOperand(0), m_Not(m_Value(X)))) { Function *Ctpop = Intrinsic::getDeclaration( - II->getModule(), Intrinsic::ctpop, II->getType()); + II->getModule(), Intrinsic::ctpop, VTy); return InsertNewInstWith(CallInst::Create(Ctpop, {X}), *I); } break; @@ -809,12 +863,10 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, Instruction *NewVal; if (NLZ > NTZ) NewVal = BinaryOperator::CreateLShr( - II->getArgOperand(0), - ConstantInt::get(I->getType(), NLZ - NTZ)); + II->getArgOperand(0), ConstantInt::get(VTy, NLZ - NTZ)); else NewVal = BinaryOperator::CreateShl( - II->getArgOperand(0), - ConstantInt::get(I->getType(), NTZ - NLZ)); + II->getArgOperand(0), ConstantInt::get(VTy, NTZ - NLZ)); NewVal->takeName(I); return InsertNewInstWith(NewVal, *I); } @@ -872,7 +924,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, // Handle target specific intrinsics Optional<Value *> V = targetSimplifyDemandedUseBitsIntrinsic( *II, DemandedMask, Known, KnownBitsComputed); - if (V.hasValue()) + if (V) return V.getValue(); break; } @@ -1583,7 +1635,7 @@ Value *InstCombinerImpl::SimplifyDemandedVectorElts(Value *V, Optional<Value *> V = targetSimplifyDemandedVectorEltsIntrinsic( *II, DemandedElts, UndefElts, UndefElts2, UndefElts3, simplifyAndSetOp); - if (V.hasValue()) + if (V) return V.getValue(); break; } |