Diffstat (limited to 'contrib/llvm-project/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp')
-rw-r--r-- | contrib/llvm-project/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp | 848
1 file changed, 709 insertions(+), 139 deletions(-)
diff --git a/contrib/llvm-project/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/contrib/llvm-project/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 40d861702e86..b3b8756ae9ba 100644
--- a/contrib/llvm-project/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/contrib/llvm-project/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -24,9 +24,7 @@
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/Triple.h"
 #include "llvm/ADT/Twine.h"
-#include "llvm/Analysis/BlockFrequencyInfo.h"
 #include "llvm/Analysis/MemoryLocation.h"
-#include "llvm/Analysis/ProfileSummaryInfo.h"
 #include "llvm/Analysis/ValueTracking.h"
 #include "llvm/CodeGen/Analysis.h"
 #include "llvm/CodeGen/FunctionLoweringInfo.h"
@@ -55,7 +53,6 @@
 #include "llvm/IR/GlobalValue.h"
 #include "llvm/IR/Metadata.h"
 #include "llvm/IR/Type.h"
-#include "llvm/IR/Value.h"
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/CodeGen.h"
 #include "llvm/Support/Compiler.h"
@@ -144,11 +141,11 @@ bool ISD::isConstantSplatVector(const SDNode *N, APInt &SplatVal) {
     unsigned EltSize =
         N->getValueType(0).getVectorElementType().getSizeInBits();
     if (auto *Op0 = dyn_cast<ConstantSDNode>(N->getOperand(0))) {
-      SplatVal = Op0->getAPIntValue().truncOrSelf(EltSize);
+      SplatVal = Op0->getAPIntValue().trunc(EltSize);
       return true;
     }
     if (auto *Op0 = dyn_cast<ConstantFPSDNode>(N->getOperand(0))) {
-      SplatVal = Op0->getValueAPF().bitcastToAPInt().truncOrSelf(EltSize);
+      SplatVal = Op0->getValueAPF().bitcastToAPInt().trunc(EltSize);
       return true;
     }
   }
@@ -714,6 +711,7 @@ static void AddNodeIDCustom(FoldingSetNodeID &ID, const SDNode *N) {
     ID.AddInteger(LD->getMemoryVT().getRawBits());
     ID.AddInteger(LD->getRawSubclassData());
     ID.AddInteger(LD->getPointerInfo().getAddrSpace());
+    ID.AddInteger(LD->getMemOperand()->getFlags());
     break;
   }
   case ISD::STORE: {
@@ -721,6 +719,7 @@ static void AddNodeIDCustom(FoldingSetNodeID &ID, const SDNode *N) {
     ID.AddInteger(ST->getMemoryVT().getRawBits());
     ID.AddInteger(ST->getRawSubclassData());
     ID.AddInteger(ST->getPointerInfo().getAddrSpace());
+    ID.AddInteger(ST->getMemOperand()->getFlags());
     break;
   }
   case ISD::VP_LOAD: {
@@ -728,6 +727,7 @@ static void AddNodeIDCustom(FoldingSetNodeID &ID, const SDNode *N) {
     ID.AddInteger(ELD->getMemoryVT().getRawBits());
     ID.AddInteger(ELD->getRawSubclassData());
     ID.AddInteger(ELD->getPointerInfo().getAddrSpace());
+    ID.AddInteger(ELD->getMemOperand()->getFlags());
     break;
   }
   case ISD::VP_STORE: {
@@ -735,6 +735,21 @@ static void AddNodeIDCustom(FoldingSetNodeID &ID, const SDNode *N) {
     ID.AddInteger(EST->getMemoryVT().getRawBits());
     ID.AddInteger(EST->getRawSubclassData());
     ID.AddInteger(EST->getPointerInfo().getAddrSpace());
+    ID.AddInteger(EST->getMemOperand()->getFlags());
+    break;
+  }
+  case ISD::EXPERIMENTAL_VP_STRIDED_LOAD: {
+    const VPStridedLoadSDNode *SLD = cast<VPStridedLoadSDNode>(N);
+    ID.AddInteger(SLD->getMemoryVT().getRawBits());
+    ID.AddInteger(SLD->getRawSubclassData());
+    ID.AddInteger(SLD->getPointerInfo().getAddrSpace());
+    break;
+  }
+  case ISD::EXPERIMENTAL_VP_STRIDED_STORE: {
+    const VPStridedStoreSDNode *SST = cast<VPStridedStoreSDNode>(N);
+    ID.AddInteger(SST->getMemoryVT().getRawBits());
+    ID.AddInteger(SST->getRawSubclassData());
+    ID.AddInteger(SST->getPointerInfo().getAddrSpace());
     break;
   }
   case ISD::VP_GATHER: {
@@ -742,6 +757,7 @@ static void AddNodeIDCustom(FoldingSetNodeID &ID, const SDNode *N) {
     ID.AddInteger(EG->getMemoryVT().getRawBits());
     ID.AddInteger(EG->getRawSubclassData());
     ID.AddInteger(EG->getPointerInfo().getAddrSpace());
+    ID.AddInteger(EG->getMemOperand()->getFlags());
     break;
   }
   case ISD::VP_SCATTER: {
@@ -749,6 +765,7 @@ static void AddNodeIDCustom(FoldingSetNodeID &ID, const SDNode *N) {
     ID.AddInteger(ES->getMemoryVT().getRawBits());
     ID.AddInteger(ES->getRawSubclassData());
     ID.AddInteger(ES->getPointerInfo().getAddrSpace());
+    ID.AddInteger(ES->getMemOperand()->getFlags());
     break;
   }
   case ISD::MLOAD: {
@@ -756,6 +773,7 @@ static void AddNodeIDCustom(FoldingSetNodeID &ID, const SDNode *N) {
     ID.AddInteger(MLD->getMemoryVT().getRawBits());
     ID.AddInteger(MLD->getRawSubclassData());
     ID.AddInteger(MLD->getPointerInfo().getAddrSpace());
+    ID.AddInteger(MLD->getMemOperand()->getFlags());
     break;
   }
   case ISD::MSTORE: {
@@ -763,6 +781,7 @@ static void AddNodeIDCustom(FoldingSetNodeID &ID, const SDNode *N) {
     ID.AddInteger(MST->getMemoryVT().getRawBits());
     ID.AddInteger(MST->getRawSubclassData());
     ID.AddInteger(MST->getPointerInfo().getAddrSpace());
+    ID.AddInteger(MST->getMemOperand()->getFlags());
     break;
   }
   case ISD::MGATHER: {
@@ -770,6 +789,7 @@ static void AddNodeIDCustom(FoldingSetNodeID &ID, const SDNode *N) {
     ID.AddInteger(MG->getMemoryVT().getRawBits());
     ID.AddInteger(MG->getRawSubclassData());
     ID.AddInteger(MG->getPointerInfo().getAddrSpace());
+    ID.AddInteger(MG->getMemOperand()->getFlags());
     break;
   }
   case ISD::MSCATTER: {
@@ -777,6 +797,7 @@ static void AddNodeIDCustom(FoldingSetNodeID &ID, const SDNode *N) {
     ID.AddInteger(MS->getMemoryVT().getRawBits());
     ID.AddInteger(MS->getRawSubclassData());
     ID.AddInteger(MS->getPointerInfo().getAddrSpace());
+    ID.AddInteger(MS->getMemOperand()->getFlags());
     break;
   }
   case ISD::ATOMIC_CMP_SWAP:
@@ -799,11 +820,13 @@ static void AddNodeIDCustom(FoldingSetNodeID &ID, const SDNode *N) {
     ID.AddInteger(AT->getMemoryVT().getRawBits());
     ID.AddInteger(AT->getRawSubclassData());
     ID.AddInteger(AT->getPointerInfo().getAddrSpace());
+    ID.AddInteger(AT->getMemOperand()->getFlags());
     break;
   }
   case ISD::PREFETCH: {
     const MemSDNode *PF = cast<MemSDNode>(N);
     ID.AddInteger(PF->getPointerInfo().getAddrSpace());
+    ID.AddInteger(PF->getMemOperand()->getFlags());
     break;
   }
   case ISD::VECTOR_SHUFFLE: {
@@ -821,11 +844,18 @@ static void AddNodeIDCustom(FoldingSetNodeID &ID, const SDNode *N) {
     ID.AddInteger(BA->getTargetFlags());
     break;
   }
+  case ISD::AssertAlign:
+    ID.AddInteger(cast<AssertAlignSDNode>(N)->getAlign().value());
+    break;
   } // end switch (N->getOpcode())
 
-  // Target specific memory nodes could also have address spaces to check.
-  if (N->isTargetMemoryOpcode())
-    ID.AddInteger(cast<MemSDNode>(N)->getPointerInfo().getAddrSpace());
+  // Target specific memory nodes could also have address spaces and flags
+  // to check.
+  if (N->isTargetMemoryOpcode()) {
+    const MemSDNode *MN = cast<MemSDNode>(N);
+    ID.AddInteger(MN->getPointerInfo().getAddrSpace());
+    ID.AddInteger(MN->getMemOperand()->getFlags());
+  }
 }
 
 /// AddNodeIDNode - Generic routine for adding a nodes info to the NodeID
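The hunks above fold each memory node's MachineMemOperand flags into its FoldingSet profile, so two nodes that are identical except for a flag (for instance, one access volatile and one not) no longer unify in the CSE map. A minimal standalone sketch of that keying decision, in plain C++ rather than the LLVM API, with illustrative flag values:

#include <cstdint>
#include <iostream>
#include <set>
#include <tuple>

// Illustrative stand-ins for memory-operand flags (not LLVM's values).
enum MemFlags : uint32_t { MONone = 0, MOVolatile = 1u << 2 };

// Hypothetical CSE key mirroring AddNodeIDCustom: opcode, address space,
// and (after this change) the memory-operand flags.
using NodeID = std::tuple<unsigned, unsigned, uint32_t>;

int main() {
  std::set<NodeID> CSEMap;
  // A normal load and a volatile load of the same location.
  bool FirstIsNew  = CSEMap.insert(NodeID{1, 0, MONone}).second;
  bool SecondIsNew = CSEMap.insert(NodeID{1, 0, MOVolatile}).second;
  // With flags in the key, both insertions create distinct nodes.
  std::cout << FirstIsNew << " " << SecondIsNew << "\n"; // prints: 1 1
}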
@@ -1395,6 +1425,12 @@ SDValue SelectionDAG::getLogicalNOT(const SDLoc &DL, SDValue Val, EVT VT) {
   return getNode(ISD::XOR, DL, VT, Val, TrueValue);
 }
 
+SDValue SelectionDAG::getVPLogicalNOT(const SDLoc &DL, SDValue Val,
+                                      SDValue Mask, SDValue EVL, EVT VT) {
+  SDValue TrueValue = getBoolConstant(true, DL, VT, VT);
+  return getNode(ISD::VP_XOR, DL, VT, Val, TrueValue, Mask, EVL);
+}
+
 SDValue SelectionDAG::getBoolConstant(bool V, const SDLoc &DL, EVT VT,
                                       EVT OpVT) {
   if (!V)
@@ -2433,23 +2469,9 @@ SDValue SelectionDAG::GetDemandedBits(SDValue V, const APInt &DemandedBits) {
   if (VT.isScalableVector())
     return SDValue();
 
-  APInt DemandedElts = VT.isVector()
-                           ? APInt::getAllOnes(VT.getVectorNumElements())
-                           : APInt(1, 1);
-  return GetDemandedBits(V, DemandedBits, DemandedElts);
-}
-
-/// See if the specified operand can be simplified with the knowledge that only
-/// the bits specified by DemandedBits are used in the elements specified by
-/// DemandedElts.
-/// TODO: really we should be making this into the DAG equivalent of
-/// SimplifyMultipleUseDemandedBits and not generate any new nodes.
-SDValue SelectionDAG::GetDemandedBits(SDValue V, const APInt &DemandedBits,
-                                      const APInt &DemandedElts) {
   switch (V.getOpcode()) {
   default:
-    return TLI->SimplifyMultipleUseDemandedBits(V, DemandedBits, DemandedElts,
-                                                *this);
+    return TLI->SimplifyMultipleUseDemandedBits(V, DemandedBits, *this);
   case ISD::Constant: {
     const APInt &CVal = cast<ConstantSDNode>(V)->getAPIntValue();
     APInt NewVal = CVal & DemandedBits;
@@ -2469,8 +2491,8 @@ SDValue SelectionDAG::GetDemandedBits(SDValue V, const APInt &DemandedBits,
     if (Amt >= DemandedBits.getBitWidth())
       break;
     APInt SrcDemandedBits = DemandedBits << Amt;
-    if (SDValue SimplifyLHS =
-            GetDemandedBits(V.getOperand(0), SrcDemandedBits))
+    if (SDValue SimplifyLHS = TLI->SimplifyMultipleUseDemandedBits(
+            V.getOperand(0), SrcDemandedBits, *this))
       return getNode(ISD::SRL, SDLoc(V), V.getValueType(), SimplifyLHS,
                      V.getOperand(1));
   }
@@ -2503,6 +2525,14 @@ bool SelectionDAG::MaskedValueIsZero(SDValue V, const APInt &Mask,
   return Mask.isSubsetOf(computeKnownBits(V, DemandedElts, Depth).Zero);
 }
 
+/// MaskedVectorIsZero - Return true if 'Op' is known to be zero in
+/// DemandedElts. We use this predicate to simplify operations downstream.
+bool SelectionDAG::MaskedVectorIsZero(SDValue V, const APInt &DemandedElts,
+                                      unsigned Depth /* = 0 */) const {
+  APInt Mask = APInt::getAllOnes(V.getScalarValueSizeInBits());
+  return Mask.isSubsetOf(computeKnownBits(V, DemandedElts, Depth).Zero);
+}
+
 /// MaskedValueIsAllOnes - Return true if '(Op & Mask) == Mask'.
 bool SelectionDAG::MaskedValueIsAllOnes(SDValue V, const APInt &Mask,
                                         unsigned Depth) const {
@@ -2587,9 +2617,9 @@ bool SelectionDAG::isSplatValue(SDValue V, const APInt &DemandedElts,
     return true;
   }
   case ISD::VECTOR_SHUFFLE: {
-    // Check if this is a shuffle node doing a splat.
-    // TODO: Do we need to handle shuffle(splat, undef, mask)?
-    int SplatIndex = -1;
+    // Check if this is a shuffle node doing a splat or a shuffle of a splat.
+    APInt DemandedLHS = APInt::getNullValue(NumElts);
+    APInt DemandedRHS = APInt::getNullValue(NumElts);
     ArrayRef<int> Mask = cast<ShuffleVectorSDNode>(V)->getMask();
     for (int i = 0; i != (int)NumElts; ++i) {
       int M = Mask[i];
@@ -2599,11 +2629,30 @@ bool SelectionDAG::isSplatValue(SDValue V, const APInt &DemandedElts,
       }
       if (!DemandedElts[i])
         continue;
-      if (0 <= SplatIndex && SplatIndex != M)
-        return false;
-      SplatIndex = M;
+      if (M < (int)NumElts)
+        DemandedLHS.setBit(M);
+      else
+        DemandedRHS.setBit(M - NumElts);
     }
-    return true;
+
+    // If we aren't demanding either op, assume there's no splat.
+    // If we are demanding both ops, assume there's no splat.
+    if ((DemandedLHS.isZero() && DemandedRHS.isZero()) ||
+        (!DemandedLHS.isZero() && !DemandedRHS.isZero()))
+      return false;
+
+    // See if the demanded elts of the source op is a splat or we only demand
+    // one element, which should always be a splat.
+    // TODO: Handle source ops splats with undefs.
+    auto CheckSplatSrc = [&](SDValue Src, const APInt &SrcElts) {
+      APInt SrcUndefs;
+      return (SrcElts.countPopulation() == 1) ||
+             (isSplatValue(Src, SrcElts, SrcUndefs, Depth + 1) &&
+              (SrcElts & SrcUndefs).isZero());
+    };
+    if (!DemandedLHS.isZero())
+      return CheckSplatSrc(V.getOperand(0), DemandedLHS);
+    return CheckSplatSrc(V.getOperand(1), DemandedRHS);
   }
   case ISD::EXTRACT_SUBVECTOR: {
     // Offset the demanded elts by the subvector index.
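A standalone sketch of the new shuffle handling, in plain C++ rather than the LLVM API: the demanded lanes are partitioned between the two shuffle sources, and the shuffle is a splat when exactly one source is demanded and its demanded lanes are themselves a splat (or a single lane).

#include <array>
#include <bitset>
#include <cstdio>
#include <optional>

constexpr unsigned NumElts = 4;
using Vec = std::array<int, NumElts>;

// A demanded-lane set of a source is a splat if all demanded lanes hold the
// same value (a single demanded lane is trivially a splat).
static bool splatInLanes(const Vec &Src, std::bitset<NumElts> Lanes) {
  std::optional<int> Seen;
  for (unsigned i = 0; i != NumElts; ++i) {
    if (!Lanes[i])
      continue;
    if (Seen && *Seen != Src[i])
      return false;
    Seen = Src[i];
  }
  return true;
}

static bool isShuffleSplat(const Vec &LHS, const Vec &RHS,
                           const std::array<int, NumElts> &Mask,
                           std::bitset<NumElts> Demanded) {
  std::bitset<NumElts> DemandedLHS, DemandedRHS;
  for (unsigned i = 0; i != NumElts; ++i) {
    int M = Mask[i];
    if (M < 0 || !Demanded[i])
      continue; // undef lane, or a lane nobody cares about
    if (M < (int)NumElts)
      DemandedLHS.set(M);
    else
      DemandedRHS.set(M - NumElts);
  }
  // Lanes must come from exactly one operand (not both, not neither).
  if (DemandedLHS.any() == DemandedRHS.any())
    return false;
  return DemandedLHS.any() ? splatInLanes(LHS, DemandedLHS)
                           : splatInLanes(RHS, DemandedRHS);
}

int main() {
  Vec Splat{7, 7, 7, 7}, Other{1, 2, 3, 4};
  // shuffle(splat, other) picking only from the splat operand.
  std::printf("%d\n", isShuffleSplat(Splat, Other, {2, 0, 1, 3},
                                     std::bitset<NumElts>("1111"))); // 1
}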
@@ -2614,7 +2663,7 @@ bool SelectionDAG::isSplatValue(SDValue V, const APInt &DemandedElts,
     uint64_t Idx = V.getConstantOperandVal(1);
     unsigned NumSrcElts = Src.getValueType().getVectorNumElements();
     APInt UndefSrcElts;
-    APInt DemandedSrcElts = DemandedElts.zextOrSelf(NumSrcElts).shl(Idx);
+    APInt DemandedSrcElts = DemandedElts.zext(NumSrcElts).shl(Idx);
     if (isSplatValue(Src, DemandedSrcElts, UndefSrcElts, Depth + 1)) {
       UndefElts = UndefSrcElts.extractBits(NumElts, Idx);
       return true;
@@ -2631,9 +2680,49 @@ bool SelectionDAG::isSplatValue(SDValue V, const APInt &DemandedElts,
       return false;
     unsigned NumSrcElts = Src.getValueType().getVectorNumElements();
     APInt UndefSrcElts;
-    APInt DemandedSrcElts = DemandedElts.zextOrSelf(NumSrcElts);
+    APInt DemandedSrcElts = DemandedElts.zext(NumSrcElts);
     if (isSplatValue(Src, DemandedSrcElts, UndefSrcElts, Depth + 1)) {
-      UndefElts = UndefSrcElts.truncOrSelf(NumElts);
+      UndefElts = UndefSrcElts.trunc(NumElts);
+      return true;
+    }
+    break;
+  }
+  case ISD::BITCAST: {
+    SDValue Src = V.getOperand(0);
+    EVT SrcVT = Src.getValueType();
+    unsigned SrcBitWidth = SrcVT.getScalarSizeInBits();
+    unsigned BitWidth = VT.getScalarSizeInBits();
+
+    // Ignore bitcasts from unsupported types.
+    // TODO: Add fp support?
+    if (!SrcVT.isVector() || !SrcVT.isInteger() || !VT.isInteger())
+      break;
+
+    // Bitcast 'small element' vector to 'large element' vector.
+    if ((BitWidth % SrcBitWidth) == 0) {
+      // See if each sub element is a splat.
+      unsigned Scale = BitWidth / SrcBitWidth;
+      unsigned NumSrcElts = SrcVT.getVectorNumElements();
+      APInt ScaledDemandedElts =
+          APIntOps::ScaleBitMask(DemandedElts, NumSrcElts);
+      for (unsigned I = 0; I != Scale; ++I) {
+        APInt SubUndefElts;
+        APInt SubDemandedElt = APInt::getOneBitSet(Scale, I);
+        APInt SubDemandedElts = APInt::getSplat(NumSrcElts, SubDemandedElt);
+        SubDemandedElts &= ScaledDemandedElts;
+        if (!isSplatValue(Src, SubDemandedElts, SubUndefElts, Depth + 1))
+          return false;
+
+        // Here we can't do "MatchAnyBits" operation merge for undef bits.
+        // Because some operation only use part value of the source.
+        // Take llvm.fshl.* for example:
+        // t1: v4i32 = Constant:i32<12>, undef:i32, Constant:i32<12>, undef:i32
+        // t2: v2i64 = bitcast t1
+        // t5: v2i64 = fshl t3, t4, t2
+        // We can not convert t2 to {i64 undef, i64 undef}
+        UndefElts |= APIntOps::ScaleBitMask(SubUndefElts, NumElts,
                                            /*MatchAllBits=*/true);
+      }
       return true;
     }
     break;
   }
@@ -2978,7 +3067,7 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
       break;
     uint64_t Idx = Op.getConstantOperandVal(1);
     unsigned NumSrcElts = Src.getValueType().getVectorNumElements();
-    APInt DemandedSrcElts = DemandedElts.zextOrSelf(NumSrcElts).shl(Idx);
+    APInt DemandedSrcElts = DemandedElts.zext(NumSrcElts).shl(Idx);
     Known = computeKnownBits(Src, DemandedSrcElts, Depth + 1);
     break;
   }
@@ -3083,9 +3172,18 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
     Known2 = computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
     bool SelfMultiply = Op.getOperand(0) == Op.getOperand(1);
     // TODO: SelfMultiply can be poison, but not undef.
-    SelfMultiply &= isGuaranteedNotToBeUndefOrPoison(
-        Op.getOperand(0), DemandedElts, false, Depth + 1);
+    if (SelfMultiply)
+      SelfMultiply &= isGuaranteedNotToBeUndefOrPoison(
+          Op.getOperand(0), DemandedElts, false, Depth + 1);
     Known = KnownBits::mul(Known, Known2, SelfMultiply);
+
+    // If the multiplication is known not to overflow, the product of a number
+    // with itself is non-negative. Only do this if we didn't already computed
+    // the opposite value for the sign bit.
+    if (Op->getFlags().hasNoSignedWrap() &&
+        Op.getOperand(0) == Op.getOperand(1) &&
+        !Known.isNegative())
+      Known.makeNonNegative();
     break;
   }
   case ISD::MULHU: {
@@ -3128,6 +3226,16 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
     Known = KnownBits::udiv(Known, Known2);
     break;
   }
+  case ISD::AVGCEILU: {
+    Known = computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
+    Known2 = computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
+    Known = Known.zext(BitWidth + 1);
+    Known2 = Known2.zext(BitWidth + 1);
+    KnownBits One = KnownBits::makeConstant(APInt(1, 1));
+    Known = KnownBits::computeForAddCarry(Known, Known2, One);
+    Known = Known.extractBits(BitWidth, 1);
+    break;
+  }
   case ISD::SELECT:
   case ISD::VSELECT:
     Known = computeKnownBits(Op.getOperand(2), DemandedElts, Depth+1);
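The AVGCEILU case derives known bits by widening one bit so the +1 and the carry cannot be lost, then taking bits [BitWidth:1] of the widened sum. A standalone sketch of the same computation on concrete values, in plain C++ rather than the KnownBits API:

#include <cstdint>
#include <cstdio>

static uint8_t avgceilu8(uint8_t A, uint8_t B) {
  uint16_t Wide = uint16_t(A) + uint16_t(B) + 1; // 9-bit sum, cannot overflow
  return uint8_t(Wide >> 1);                     // extract bits [8:1]
}

int main() {
  std::printf("%u\n", avgceilu8(255, 255)); // 255, not a wrapped value
  std::printf("%u\n", avgceilu8(10, 11));   // 11 (rounds up)
}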
@@ -3330,7 +3438,7 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
   }
   case ISD::ZERO_EXTEND_VECTOR_INREG: {
     EVT InVT = Op.getOperand(0).getValueType();
-    APInt InDemandedElts = DemandedElts.zextOrSelf(InVT.getVectorNumElements());
+    APInt InDemandedElts = DemandedElts.zext(InVT.getVectorNumElements());
     Known = computeKnownBits(Op.getOperand(0), InDemandedElts, Depth + 1);
     Known = Known.zext(BitWidth);
     break;
@@ -3342,7 +3450,7 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
   }
   case ISD::SIGN_EXTEND_VECTOR_INREG: {
     EVT InVT = Op.getOperand(0).getValueType();
-    APInt InDemandedElts = DemandedElts.zextOrSelf(InVT.getVectorNumElements());
+    APInt InDemandedElts = DemandedElts.zext(InVT.getVectorNumElements());
     Known = computeKnownBits(Op.getOperand(0), InDemandedElts, Depth + 1);
     // If the sign bit is known to be zero or one, then sext will extend
     // it to the top bits, else it will just zext.
@@ -3358,7 +3466,7 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
   }
   case ISD::ANY_EXTEND_VECTOR_INREG: {
     EVT InVT = Op.getOperand(0).getValueType();
-    APInt InDemandedElts = DemandedElts.zextOrSelf(InVT.getVectorNumElements());
+    APInt InDemandedElts = DemandedElts.zext(InVT.getVectorNumElements());
     Known = computeKnownBits(Op.getOperand(0), InDemandedElts, Depth + 1);
     Known = Known.anyext(BitWidth);
     break;
@@ -3605,6 +3713,19 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
       Known = KnownBits::smax(Known, Known2);
     else
       Known = KnownBits::smin(Known, Known2);
+
+    // For SMAX, if CstLow is non-negative we know the result will be
+    // non-negative and thus all sign bits are 0.
+    // TODO: There's an equivalent of this for smin with negative constant for
+    // known ones.
+    if (IsMax && CstLow) {
+      const APInt &ValueLow = CstLow->getAPIntValue();
+      if (ValueLow.isNonNegative()) {
+        unsigned SignBits = ComputeNumSignBits(Op.getOperand(0), Depth + 1);
+        Known.Zero.setHighBits(std::min(SignBits, ValueLow.getNumSignBits()));
+      }
+    }
+
     break;
   }
   case ISD::FP_TO_UINT_SAT: {
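The smax clamp added above uses a simple bound: if one operand of a signed max is a non-negative constant, the result is at least that constant, so its high sign bits are known zero. A standalone exhaustive check over i8, in plain C++ rather than the KnownBits API:

#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
  const int8_t C = 3; // non-negative low bound
  for (int v = -128; v < 128; ++v) {
    int8_t R = std::max<int8_t>((int8_t)v, C);
    if (R < 0) { // never taken: smax(x, 3) >= 3 >= 0
      std::puts("counterexample");
      return 1;
    }
  }
  std::puts("smax(x, 3) is always non-negative");
}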
@@ -3905,7 +4026,7 @@ unsigned SelectionDAG::ComputeNumSignBits(SDValue Op, const APInt &DemandedElts,
   case ISD::SIGN_EXTEND_VECTOR_INREG: {
     SDValue Src = Op.getOperand(0);
     EVT SrcVT = Src.getValueType();
-    APInt DemandedSrcElts = DemandedElts.zextOrSelf(SrcVT.getVectorNumElements());
+    APInt DemandedSrcElts = DemandedElts.zext(SrcVT.getVectorNumElements());
     Tmp = VTBits - SrcVT.getScalarSizeInBits();
     return ComputeNumSignBits(Src, DemandedSrcElts, Depth+1) + Tmp;
   }
@@ -4192,7 +4313,7 @@ unsigned SelectionDAG::ComputeNumSignBits(SDValue Op, const APInt &DemandedElts,
       break;
     uint64_t Idx = Op.getConstantOperandVal(1);
     unsigned NumSrcElts = Src.getValueType().getVectorNumElements();
-    APInt DemandedSrcElts = DemandedElts.zextOrSelf(NumSrcElts).shl(Idx);
+    APInt DemandedSrcElts = DemandedElts.zext(NumSrcElts).shl(Idx);
     return ComputeNumSignBits(Src, DemandedSrcElts, Depth + 1);
   }
   case ISD::CONCAT_VECTORS: {
@@ -4585,26 +4706,54 @@ bool SelectionDAG::isEqualTo(SDValue A, SDValue B) const {
   return false;
 }
 
+// Only bits set in Mask must be negated, other bits may be arbitrary.
+SDValue llvm::getBitwiseNotOperand(SDValue V, SDValue Mask, bool AllowUndefs) {
+  if (isBitwiseNot(V, AllowUndefs))
+    return V.getOperand(0);
+
+  // Handle any_extend (not (truncate X)) pattern, where Mask only sets
+  // bits in the non-extended part.
+  ConstantSDNode *MaskC = isConstOrConstSplat(Mask);
+  if (!MaskC || V.getOpcode() != ISD::ANY_EXTEND)
+    return SDValue();
+  SDValue ExtArg = V.getOperand(0);
+  if (ExtArg.getScalarValueSizeInBits() >=
+          MaskC->getAPIntValue().getActiveBits() &&
+      isBitwiseNot(ExtArg, AllowUndefs) &&
+      ExtArg.getOperand(0).getOpcode() == ISD::TRUNCATE &&
+      ExtArg.getOperand(0).getOperand(0).getValueType() == V.getValueType())
+    return ExtArg.getOperand(0).getOperand(0);
+  return SDValue();
+}
+
+static bool haveNoCommonBitsSetCommutative(SDValue A, SDValue B) {
+  // Match masked merge pattern (X & ~M) op (Y & M)
+  // Including degenerate case (X & ~M) op M
+  auto MatchNoCommonBitsPattern = [&](SDValue Not, SDValue Mask,
+                                      SDValue Other) {
+    if (SDValue NotOperand =
+            getBitwiseNotOperand(Not, Mask, /* AllowUndefs */ true)) {
+      if (Other == NotOperand)
+        return true;
+      if (Other->getOpcode() == ISD::AND)
+        return NotOperand == Other->getOperand(0) ||
+               NotOperand == Other->getOperand(1);
+    }
+    return false;
+  };
+  if (A->getOpcode() == ISD::AND)
+    return MatchNoCommonBitsPattern(A->getOperand(0), A->getOperand(1), B) ||
+           MatchNoCommonBitsPattern(A->getOperand(1), A->getOperand(0), B);
+  return false;
+}
+
 // FIXME: unify with llvm::haveNoCommonBitsSet.
 bool SelectionDAG::haveNoCommonBitsSet(SDValue A, SDValue B) const {
   assert(A.getValueType() == B.getValueType() &&
          "Values must have the same type");
-  // Match masked merge pattern (X & ~M) op (Y & M)
-  if (A->getOpcode() == ISD::AND && B->getOpcode() == ISD::AND) {
-    auto MatchNoCommonBitsPattern = [&](SDValue NotM, SDValue And) {
-      if (isBitwiseNot(NotM, true)) {
-        SDValue NotOperand = NotM->getOperand(0);
-        return NotOperand == And->getOperand(0) ||
-               NotOperand == And->getOperand(1);
-      }
-      return false;
-    };
-    if (MatchNoCommonBitsPattern(A->getOperand(0), B) ||
-        MatchNoCommonBitsPattern(A->getOperand(1), B) ||
-        MatchNoCommonBitsPattern(B->getOperand(0), A) ||
-        MatchNoCommonBitsPattern(B->getOperand(1), A))
-      return true;
-  }
+  if (haveNoCommonBitsSetCommutative(A, B) ||
+      haveNoCommonBitsSetCommutative(B, A))
+    return true;
   return KnownBits::haveNoCommonBitsSet(computeKnownBits(A),
                                         computeKnownBits(B));
 }
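getBitwiseNotOperand feeds haveNoCommonBitsSet's masked-merge match: (X & ~M) and (Y & M) select disjoint bit positions, so they can never share a set bit, and their add and or agree. A standalone randomized check of that identity, in plain C++ rather than the LLVM API:

#include <cstdint>
#include <cstdio>
#include <random>

int main() {
  std::mt19937 Rng(0);
  for (int i = 0; i < 1000; ++i) {
    uint32_t X = Rng(), Y = Rng(), M = Rng();
    uint32_t A = X & ~M, B = Y & M; // the masked-merge halves
    if (A & B) { std::puts("common bits found"); return 1; }
    if (A + B != (A | B)) { std::puts("add != or"); return 1; }
  }
  std::puts("(X & ~M) and (Y & M) never share set bits");
}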
@@ -4833,9 +4982,11 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
     case ISD::CTTZ_ZERO_UNDEF:
       return getConstant(Val.countTrailingZeros(), DL, VT, C->isTargetOpcode(),
                          C->isOpaque());
-    case ISD::FP16_TO_FP: {
+    case ISD::FP16_TO_FP:
+    case ISD::BF16_TO_FP: {
      bool Ignored;
-      APFloat FPV(APFloat::IEEEhalf(),
+      APFloat FPV(Opcode == ISD::FP16_TO_FP ? APFloat::IEEEhalf()
+                                            : APFloat::BFloat(),
                   (Val.getBitWidth() == 16) ? Val : Val.trunc(16));
 
       // This can return overflow, underflow, or inexact; we don't care.
@@ -4909,11 +5060,13 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
       if (VT == MVT::i64 && C->getValueType(0) == MVT::f64)
         return getConstant(V.bitcastToAPInt().getZExtValue(), DL, VT);
       break;
-    case ISD::FP_TO_FP16: {
+    case ISD::FP_TO_FP16:
+    case ISD::FP_TO_BF16: {
       bool Ignored;
       // This can return overflow, underflow, or inexact; we don't care.
       // FIXME need to be more flexible about rounding mode.
-      (void)V.convert(APFloat::IEEEhalf(),
+      (void)V.convert(Opcode == ISD::FP_TO_FP16 ? APFloat::IEEEhalf()
+                                                : APFloat::BFloat(),
                       APFloat::rmNearestTiesToEven, &Ignored);
       return getConstant(V.bitcastToAPInt().getZExtValue(), DL, VT);
     }
@@ -4965,6 +5118,8 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
     break;
   case ISD::FREEZE:
     assert(VT == Operand.getValueType() && "Unexpected VT!");
+    if (isGuaranteedNotToBeUndefOrPoison(Operand))
+      return Operand;
     break;
   case ISD::TokenFactor:
   case ISD::MERGE_VALUES:
@@ -5114,7 +5269,7 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
     assert(VT.isInteger() && VT == Operand.getValueType() &&
            "Invalid ABS!");
     if (OpOpcode == ISD::UNDEF)
-      return getUNDEF(VT);
+      return getConstant(0, DL, VT);
     break;
   case ISD::BSWAP:
     assert(VT.isInteger() && VT == Operand.getValueType() &&
@@ -5182,6 +5337,10 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
     if (Operand.getValueType().getScalarType() == MVT::i1)
       return getNOT(DL, Operand, Operand.getValueType());
     break;
+  case ISD::VECREDUCE_ADD:
+    if (Operand.getValueType().getScalarType() == MVT::i1)
+      return getNode(ISD::VECREDUCE_XOR, DL, VT, Operand);
+    break;
   case ISD::VECREDUCE_SMIN:
   case ISD::VECREDUCE_UMAX:
     if (Operand.getValueType().getScalarType() == MVT::i1)
@@ -5273,6 +5432,30 @@ static llvm::Optional<APInt> FoldValue(unsigned Opcode, const APInt &C1,
     APInt C2Ext = C2.zext(FullWidth);
     return (C1Ext * C2Ext).extractBits(C1.getBitWidth(), C1.getBitWidth());
   }
+  case ISD::AVGFLOORS: {
+    unsigned FullWidth = C1.getBitWidth() + 1;
+    APInt C1Ext = C1.sext(FullWidth);
+    APInt C2Ext = C2.sext(FullWidth);
+    return (C1Ext + C2Ext).extractBits(C1.getBitWidth(), 1);
+  }
+  case ISD::AVGFLOORU: {
+    unsigned FullWidth = C1.getBitWidth() + 1;
+    APInt C1Ext = C1.zext(FullWidth);
+    APInt C2Ext = C2.zext(FullWidth);
+    return (C1Ext + C2Ext).extractBits(C1.getBitWidth(), 1);
+  }
+  case ISD::AVGCEILS: {
+    unsigned FullWidth = C1.getBitWidth() + 1;
+    APInt C1Ext = C1.sext(FullWidth);
+    APInt C2Ext = C2.sext(FullWidth);
+    return (C1Ext + C2Ext + 1).extractBits(C1.getBitWidth(), 1);
+  }
+  case ISD::AVGCEILU: {
+    unsigned FullWidth = C1.getBitWidth() + 1;
+    APInt C1Ext = C1.zext(FullWidth);
+    APInt C2Ext = C2.zext(FullWidth);
+    return (C1Ext + C2Ext + 1).extractBits(C1.getBitWidth(), 1);
+  }
   }
   return llvm::None;
 }
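All four averaging folds follow one recipe: extend by one bit (sign or zero), add, optionally add one, and keep bits [BitWidth:1] of the widened sum. A standalone sketch on i8 values, in plain C++ rather than the APInt code above (assumes arithmetic right shift of negative values, which every mainstream compiler provides and C++20 guarantees):

#include <cstdint>
#include <cstdio>

static int8_t avgfloors8(int8_t A, int8_t B) {
  return int8_t((int16_t(A) + int16_t(B)) >> 1);        // sext, no +1
}
static uint8_t avgflooru8(uint8_t A, uint8_t B) {
  return uint8_t((uint16_t(A) + uint16_t(B)) >> 1);     // zext, no +1
}
static int8_t avgceils8(int8_t A, int8_t B) {
  return int8_t((int16_t(A) + int16_t(B) + 1) >> 1);    // sext, +1
}
static uint8_t avgceilu8(uint8_t A, uint8_t B) {
  return uint8_t((uint16_t(A) + uint16_t(B) + 1) >> 1); // zext, +1
}

int main() {
  std::printf("%d %u %d %u\n",
              avgfloors8(-3, 2),    // -1 (floor of -0.5)
              avgflooru8(255, 254), // 254
              avgceils8(-3, 2),     // 0 (ceil of -0.5)
              avgceilu8(255, 254)); // 255
}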
@@ -5355,7 +5538,7 @@ SDValue SelectionDAG::FoldConstantArithmetic(unsigned Opcode, const SDLoc &DL,
     if (!FoldAttempt)
       return SDValue();
 
-    SDValue Folded = getConstant(FoldAttempt.getValue(), DL, VT);
+    SDValue Folded = getConstant(*FoldAttempt, DL, VT);
     assert((!Folded || !VT.isVector()) &&
            "Can't fold vectors ops with scalar operands");
     return Folded;
@@ -5400,7 +5583,7 @@ SDValue SelectionDAG::FoldConstantArithmetic(unsigned Opcode, const SDLoc &DL,
       Optional<APInt> Fold = FoldValue(Opcode, RawBits1[I], RawBits2[I]);
       if (!Fold)
         break;
-      RawBits.push_back(Fold.getValue());
+      RawBits.push_back(*Fold);
     }
     if (RawBits.size() == NumElts.getFixedValue()) {
       // We have constant folded, but we need to cast this again back to
@@ -5416,7 +5599,7 @@ SDValue SelectionDAG::FoldConstantArithmetic(unsigned Opcode, const SDLoc &DL,
       for (unsigned I = 0, E = DstBits.size(); I != E; ++I) {
         if (DstUndefs[I])
           continue;
-        Ops[I] = getConstant(DstBits[I].sextOrSelf(BVEltBits), DL, BVEltVT);
+        Ops[I] = getConstant(DstBits[I].sext(BVEltBits), DL, BVEltVT);
       }
       return getBitcast(VT, getBuildVector(BVVT, DL, Ops));
     }
@@ -5455,9 +5638,14 @@ SDValue SelectionDAG::FoldConstantArithmetic(unsigned Opcode, const SDLoc &DL,
       !llvm::all_of(Ops, IsScalarOrSameVectorSize))
     return SDValue();
 
-  // If we are comparing vectors, then the result needs to be a i1 boolean
-  // that is then sign-extended back to the legal result type.
+  // If we are comparing vectors, then the result needs to be a i1 boolean that
+  // is then extended back to the legal result type depending on how booleans
+  // are represented.
   EVT SVT = (Opcode == ISD::SETCC ? MVT::i1 : VT.getScalarType());
+  ISD::NodeType ExtendCode =
+      (Opcode == ISD::SETCC && SVT != VT.getScalarType())
+          ? TargetLowering::getExtendForContent(TLI->getBooleanContents(VT))
+          : ISD::SIGN_EXTEND;
 
   // Find legal integer scalar type for constant promotion and
   // ensure that its scalar size is at least as large as source.
@@ -5515,7 +5703,7 @@ SDValue SelectionDAG::FoldConstantArithmetic(unsigned Opcode, const SDLoc &DL,
 
     // Legalize the (integer) scalar constant if necessary.
     if (LegalSVT != SVT)
-      ScalarResult = getNode(ISD::SIGN_EXTEND, DL, LegalSVT, ScalarResult);
+      ScalarResult = getNode(ExtendCode, DL, LegalSVT, ScalarResult);
 
     // Scalar folding only succeeded if the result is a constant or UNDEF.
     if (!ScalarResult.isUndef() && ScalarResult.getOpcode() != ISD::Constant &&
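The SETCC fold now picks the extension opcode from the target's boolean contents instead of always sign-extending. A standalone sketch of the distinction, in plain C++; the enum merely mirrors the idea and is not LLVM's TargetLowering API:

#include <cstdint>
#include <cstdio>

enum class BooleanContent { ZeroOrOne, ZeroOrNegativeOne };

static int32_t extendSetCCResult(bool B, BooleanContent BC) {
  int32_t Bit = B ? 1 : 0; // the i1 comparison result
  return BC == BooleanContent::ZeroOrNegativeOne
             ? -Bit  // sign extend: true -> all-ones
             : Bit;  // zero extend: true -> 1
}

int main() {
  std::printf("%d %d\n",
              extendSetCCResult(true, BooleanContent::ZeroOrNegativeOne),
              extendSetCCResult(true, BooleanContent::ZeroOrOne)); // -1 1
}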
@@ -5639,20 +5827,34 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
   return getNode(Opcode, DL, VT, N1, N2, Flags);
 }
 
+void SelectionDAG::canonicalizeCommutativeBinop(unsigned Opcode, SDValue &N1,
+                                                SDValue &N2) const {
+  if (!TLI->isCommutativeBinOp(Opcode))
+    return;
+
+  // Canonicalize:
+  //   binop(const, nonconst) -> binop(nonconst, const)
+  bool IsN1C = isConstantIntBuildVectorOrConstantInt(N1);
+  bool IsN2C = isConstantIntBuildVectorOrConstantInt(N2);
+  bool IsN1CFP = isConstantFPBuildVectorOrConstantFP(N1);
+  bool IsN2CFP = isConstantFPBuildVectorOrConstantFP(N2);
+  if ((IsN1C && !IsN2C) || (IsN1CFP && !IsN2CFP))
+    std::swap(N1, N2);
+
+  // Canonicalize:
+  //   binop(splat(x), step_vector) -> binop(step_vector, splat(x))
+  else if (N1.getOpcode() == ISD::SPLAT_VECTOR &&
+           N2.getOpcode() == ISD::STEP_VECTOR)
+    std::swap(N1, N2);
+}
+
 SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
                               SDValue N1, SDValue N2, const SDNodeFlags Flags) {
   assert(N1.getOpcode() != ISD::DELETED_NODE &&
          N2.getOpcode() != ISD::DELETED_NODE &&
          "Operand is DELETED_NODE!");
-  // Canonicalize constant to RHS if commutative.
-  if (TLI->isCommutativeBinOp(Opcode)) {
-    bool IsN1C = isConstantIntBuildVectorOrConstantInt(N1);
-    bool IsN2C = isConstantIntBuildVectorOrConstantInt(N2);
-    bool IsN1CFP = isConstantFPBuildVectorOrConstantFP(N1);
-    bool IsN2CFP = isConstantFPBuildVectorOrConstantFP(N2);
-    if ((IsN1C && !IsN2C) || (IsN1CFP && !IsN2CFP))
-      std::swap(N1, N2);
-  }
+
+  canonicalizeCommutativeBinop(Opcode, N1, N2);
 
   auto *N1C = dyn_cast<ConstantSDNode>(N1);
   auto *N2C = dyn_cast<ConstantSDNode>(N2);
@@ -5956,6 +6158,10 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
         if (N1Op2C->getZExtValue() == N2C->getZExtValue()) {
           if (VT == N1.getOperand(1).getValueType())
             return N1.getOperand(1);
+          if (VT.isFloatingPoint()) {
+            assert(VT.getSizeInBits() > N1.getOperand(1).getValueType().getSizeInBits());
+            return getFPExtendOrRound(N1.getOperand(1), DL, VT);
+          }
           return getSExtOrTrunc(N1.getOperand(1), DL, VT);
         }
         return getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, N1.getOperand(0), N2);
@@ -6053,9 +6259,9 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
       std::swap(N1, N2);
     } else {
       switch (Opcode) {
-      case ISD::SIGN_EXTEND_INREG:
       case ISD::SUB:
         return getUNDEF(VT);     // fold op(undef, arg2) -> undef
+      case ISD::SIGN_EXTEND_INREG:
       case ISD::UDIV:
       case ISD::SDIV:
       case ISD::UREM:
@@ -6544,7 +6750,7 @@ static SDValue getMemcpyLoadsAndStores(SelectionDAG &DAG, const SDLoc &dl,
     const TargetRegisterInfo *TRI = MF.getSubtarget().getRegisterInfo();
     if (!TRI->hasStackRealignment(MF))
       while (NewAlign > Alignment && DL.exceedsNaturalStackAlignment(NewAlign))
-        NewAlign = NewAlign / 2;
+        NewAlign = NewAlign.previous();
 
     if (NewAlign > Alignment) {
       // Give the stack frame object a larger alignment if needed.
@@ -6792,17 +6998,18 @@ static SDValue getMemmoveLoadsAndStores(SelectionDAG &DAG, const SDLoc &dl,
 /// \param Size Number of bytes to write.
 /// \param Alignment Alignment of the destination in bytes.
 /// \param isVol True if destination is volatile.
+/// \param AlwaysInline Makes sure no function call is generated.
 /// \param DstPtrInfo IR information on the memory pointer.
 /// \returns New head in the control flow, if lowering was successful, empty
 /// SDValue otherwise.
 ///
 /// The function tries to replace 'llvm.memset' intrinsic with several store
 /// operations and value calculation code. This is usually profitable for small
-/// memory size.
+/// memory size or when the semantic requires inlining.
 static SDValue getMemsetStores(SelectionDAG &DAG, const SDLoc &dl,
                                SDValue Chain, SDValue Dst, SDValue Src,
                                uint64_t Size, Align Alignment, bool isVol,
-                               MachinePointerInfo DstPtrInfo,
+                               bool AlwaysInline, MachinePointerInfo DstPtrInfo,
                                const AAMDNodes &AAInfo) {
   // Turn a memset of undef to nop.
   // FIXME: We need to honor volatile even is Src is undef.
@@ -6822,8 +7029,10 @@ static SDValue getMemsetStores(SelectionDAG &DAG, const SDLoc &dl,
     DstAlignCanChange = true;
   bool IsZeroVal =
       isa<ConstantSDNode>(Src) && cast<ConstantSDNode>(Src)->isZero();
+  unsigned Limit = AlwaysInline ? ~0 : TLI.getMaxStoresPerMemset(OptSize);
+
   if (!TLI.findOptimalMemOpLowering(
-          MemOps, TLI.getMaxStoresPerMemset(OptSize),
+          MemOps, Limit,
          MemOp::Set(Size, DstAlignCanChange, Alignment, IsZeroVal, isVol),
          DstPtrInfo.getAddrSpace(), ~0u, MF.getFunction().getAttributes()))
     return SDValue();
@@ -6974,10 +7183,9 @@ SDValue SelectionDAG::getMemcpy(SDValue Chain, const SDLoc &dl, SDValue Dst,
 }
 
 SDValue SelectionDAG::getAtomicMemcpy(SDValue Chain, const SDLoc &dl,
-                                      SDValue Dst, unsigned DstAlign,
-                                      SDValue Src, unsigned SrcAlign,
-                                      SDValue Size, Type *SizeTy,
-                                      unsigned ElemSz, bool isTailCall,
+                                      SDValue Dst, SDValue Src, SDValue Size,
+                                      Type *SizeTy, unsigned ElemSz,
+                                      bool isTailCall,
                                       MachinePointerInfo DstPtrInfo,
                                       MachinePointerInfo SrcPtrInfo) {
   // Emit a library call.
@@ -7077,10 +7285,9 @@ SDValue SelectionDAG::getMemmove(SDValue Chain, const SDLoc &dl, SDValue Dst,
 }
 
 SDValue SelectionDAG::getAtomicMemmove(SDValue Chain, const SDLoc &dl,
-                                       SDValue Dst, unsigned DstAlign,
-                                       SDValue Src, unsigned SrcAlign,
-                                       SDValue Size, Type *SizeTy,
-                                       unsigned ElemSz, bool isTailCall,
+                                       SDValue Dst, SDValue Src, SDValue Size,
+                                       Type *SizeTy, unsigned ElemSz,
+                                       bool isTailCall,
                                        MachinePointerInfo DstPtrInfo,
                                        MachinePointerInfo SrcPtrInfo) {
   // Emit a library call.
@@ -7119,7 +7326,7 @@ SDValue SelectionDAG::getAtomicMemmove(SDValue Chain, const SDLoc &dl,
 
 SDValue SelectionDAG::getMemset(SDValue Chain, const SDLoc &dl, SDValue Dst,
                                 SDValue Src, SDValue Size, Align Alignment,
-                                bool isVol, bool isTailCall,
+                                bool isVol, bool AlwaysInline, bool isTailCall,
                                 MachinePointerInfo DstPtrInfo,
                                 const AAMDNodes &AAInfo) {
   // Check to see if we should lower the memset to stores first.
@@ -7132,7 +7339,7 @@ SDValue SelectionDAG::getMemset(SDValue Chain, const SDLoc &dl, SDValue Dst,
 
     SDValue Result = getMemsetStores(*this, dl, Chain, Dst, Src,
                                      ConstantSize->getZExtValue(), Alignment,
-                                     isVol, DstPtrInfo, AAInfo);
+                                     isVol, false, DstPtrInfo, AAInfo);
 
     if (Result.getNode())
       return Result;
@@ -7142,45 +7349,75 @@ SDValue SelectionDAG::getMemset(SDValue Chain, const SDLoc &dl, SDValue Dst,
   // code. If the target chooses to do this, this is the next best.
   if (TSI) {
     SDValue Result = TSI->EmitTargetCodeForMemset(
-        *this, dl, Chain, Dst, Src, Size, Alignment, isVol, DstPtrInfo);
+        *this, dl, Chain, Dst, Src, Size, Alignment, isVol, AlwaysInline, DstPtrInfo);
     if (Result.getNode())
       return Result;
   }
 
+  // If we really need inline code and the target declined to provide it,
+  // use a (potentially long) sequence of loads and stores.
+  if (AlwaysInline) {
+    assert(ConstantSize && "AlwaysInline requires a constant size!");
+    SDValue Result = getMemsetStores(*this, dl, Chain, Dst, Src,
                                     ConstantSize->getZExtValue(), Alignment,
+                                     isVol, true, DstPtrInfo, AAInfo);
+    assert(Result &&
+           "getMemsetStores must return a valid sequence when AlwaysInline");
+    return Result;
+  }
+
  checkAddrSpaceIsValidForLibcall(TLI, DstPtrInfo.getAddrSpace());
 
   // Emit a library call.
-  TargetLowering::ArgListTy Args;
-  TargetLowering::ArgListEntry Entry;
-  Entry.Node = Dst; Entry.Ty = Type::getInt8PtrTy(*getContext());
-  Args.push_back(Entry);
-  Entry.Node = Src;
-  Entry.Ty = Src.getValueType().getTypeForEVT(*getContext());
-  Args.push_back(Entry);
-  Entry.Node = Size;
-  Entry.Ty = getDataLayout().getIntPtrType(*getContext());
-  Args.push_back(Entry);
+  auto &Ctx = *getContext();
+  const auto& DL = getDataLayout();
 
-  // FIXME: pass in SDLoc
   TargetLowering::CallLoweringInfo CLI(*this);
-  CLI.setDebugLoc(dl)
-      .setChain(Chain)
-      .setLibCallee(TLI->getLibcallCallingConv(RTLIB::MEMSET),
-                    Dst.getValueType().getTypeForEVT(*getContext()),
-                    getExternalSymbol(TLI->getLibcallName(RTLIB::MEMSET),
-                                      TLI->getPointerTy(getDataLayout())),
-                    std::move(Args))
-      .setDiscardResult()
-      .setTailCall(isTailCall);
+  // FIXME: pass in SDLoc
+  CLI.setDebugLoc(dl).setChain(Chain);
+
+  ConstantSDNode *ConstantSrc = dyn_cast<ConstantSDNode>(Src);
+  const bool SrcIsZero = ConstantSrc && ConstantSrc->isZero();
+  const char *BzeroName = getTargetLoweringInfo().getLibcallName(RTLIB::BZERO);
+
+  // Helper function to create an Entry from Node and Type.
+  const auto CreateEntry = [](SDValue Node, Type *Ty) {
+    TargetLowering::ArgListEntry Entry;
+    Entry.Node = Node;
+    Entry.Ty = Ty;
+    return Entry;
+  };
 
-  std::pair<SDValue,SDValue> CallResult = TLI->LowerCallTo(CLI);
+  // If zeroing out and bzero is present, use it.
+  if (SrcIsZero && BzeroName) {
+    TargetLowering::ArgListTy Args;
+    Args.push_back(CreateEntry(Dst, Type::getInt8PtrTy(Ctx)));
+    Args.push_back(CreateEntry(Size, DL.getIntPtrType(Ctx)));
+    CLI.setLibCallee(
+        TLI->getLibcallCallingConv(RTLIB::BZERO), Type::getVoidTy(Ctx),
+        getExternalSymbol(BzeroName, TLI->getPointerTy(DL)), std::move(Args));
+  } else {
+    TargetLowering::ArgListTy Args;
+    Args.push_back(CreateEntry(Dst, Type::getInt8PtrTy(Ctx)));
+    Args.push_back(CreateEntry(Src, Src.getValueType().getTypeForEVT(Ctx)));
+    Args.push_back(CreateEntry(Size, DL.getIntPtrType(Ctx)));
+    CLI.setLibCallee(TLI->getLibcallCallingConv(RTLIB::MEMSET),
+                     Dst.getValueType().getTypeForEVT(Ctx),
+                     getExternalSymbol(TLI->getLibcallName(RTLIB::MEMSET),
+                                       TLI->getPointerTy(DL)),
+                     std::move(Args));
+  }
+
+  CLI.setDiscardResult().setTailCall(isTailCall);
+
+  std::pair<SDValue, SDValue> CallResult = TLI->LowerCallTo(CLI);
   return CallResult.second;
 }
 
 SDValue SelectionDAG::getAtomicMemset(SDValue Chain, const SDLoc &dl,
-                                      SDValue Dst, unsigned DstAlign,
-                                      SDValue Value, SDValue Size, Type *SizeTy,
-                                      unsigned ElemSz, bool isTailCall,
+                                      SDValue Dst, SDValue Value, SDValue Size,
+                                      Type *SizeTy, unsigned ElemSz,
+                                      bool isTailCall,
                                       MachinePointerInfo DstPtrInfo) {
   // Emit a library call.
   TargetLowering::ArgListTy Args;
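The rewritten lowering prefers bzero for a constant-zero memset when the target registers a name for RTLIB::BZERO, and falls back to memset otherwise. A standalone sketch of just that decision, in plain C++; the struct is illustrative, not the LLVM API:

#include <cstdio>

struct MemsetCall { const char *Callee; bool PassesValue; };

static MemsetCall lowerMemset(bool SrcIsZero, bool HaveBzero) {
  if (SrcIsZero && HaveBzero)
    return {"bzero", false}; // void bzero(void *s, size_t n)
  return {"memset", true};   // void *memset(void *s, int c, size_t n)
}

int main() {
  MemsetCall C = lowerMemset(/*SrcIsZero=*/true, /*HaveBzero=*/true);
  std::printf("call %s (%d args)\n", C.Callee, C.PassesValue ? 3 : 2);
}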
@@ -7224,6 +7461,7 @@ SDValue SelectionDAG::getAtomic(unsigned Opcode, const SDLoc &dl, EVT MemVT,
   ID.AddInteger(MemVT.getRawBits());
   AddNodeIDNode(ID, Opcode, VTList, Ops);
   ID.AddInteger(MMO->getPointerInfo().getAddrSpace());
+  ID.AddInteger(MMO->getFlags());
   void* IP = nullptr;
   if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) {
     cast<AtomicSDNode>(E)->refineAlignment(MMO);
@@ -7336,6 +7574,7 @@ SDValue SelectionDAG::getMemIntrinsicNode(unsigned Opcode, const SDLoc &dl,
     ID.AddInteger(getSyntheticNodeSubclassData<MemIntrinsicSDNode>(
         Opcode, dl.getIROrder(), VTList, MemVT, MMO));
     ID.AddInteger(MMO->getPointerInfo().getAddrSpace());
+    ID.AddInteger(MMO->getFlags());
     void *IP = nullptr;
     if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) {
       cast<MemIntrinsicSDNode>(E)->refineAlignment(MMO);
@@ -7508,6 +7747,7 @@ SDValue SelectionDAG::getLoad(ISD::MemIndexedMode AM, ISD::LoadExtType ExtType,
   ID.AddInteger(getSyntheticNodeSubclassData<LoadSDNode>(
       dl.getIROrder(), VTs, AM, ExtType, MemVT, MMO));
   ID.AddInteger(MMO->getPointerInfo().getAddrSpace());
+  ID.AddInteger(MMO->getFlags());
   void *IP = nullptr;
   if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) {
     cast<LoadSDNode>(E)->refineAlignment(MMO);
@@ -7609,6 +7849,7 @@ SDValue SelectionDAG::getStore(SDValue Chain, const SDLoc &dl, SDValue Val,
   ID.AddInteger(getSyntheticNodeSubclassData<StoreSDNode>(
       dl.getIROrder(), VTs, ISD::UNINDEXED, false, VT, MMO));
   ID.AddInteger(MMO->getPointerInfo().getAddrSpace());
+  ID.AddInteger(MMO->getFlags());
   void *IP = nullptr;
   if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) {
     cast<StoreSDNode>(E)->refineAlignment(MMO);
@@ -7675,6 +7916,7 @@ SDValue SelectionDAG::getTruncStore(SDValue Chain, const SDLoc &dl, SDValue Val,
   ID.AddInteger(getSyntheticNodeSubclassData<StoreSDNode>(
       dl.getIROrder(), VTs, ISD::UNINDEXED, true, SVT, MMO));
   ID.AddInteger(MMO->getPointerInfo().getAddrSpace());
+  ID.AddInteger(MMO->getFlags());
   void *IP = nullptr;
   if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) {
     cast<StoreSDNode>(E)->refineAlignment(MMO);
@@ -7703,6 +7945,7 @@ SDValue SelectionDAG::getIndexedStore(SDValue OrigStore, const SDLoc &dl,
   ID.AddInteger(ST->getMemoryVT().getRawBits());
   ID.AddInteger(ST->getRawSubclassData());
   ID.AddInteger(ST->getPointerInfo().getAddrSpace());
+  ID.AddInteger(ST->getMemOperand()->getFlags());
   void *IP = nullptr;
   if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP))
     return SDValue(E, 0);
@@ -7760,6 +8003,7 @@ SDValue SelectionDAG::getLoadVP(ISD::MemIndexedMode AM,
   ID.AddInteger(getSyntheticNodeSubclassData<VPLoadSDNode>(
       dl.getIROrder(), VTs, AM, ExtType, IsExpanding, MemVT, MMO));
   ID.AddInteger(MMO->getPointerInfo().getAddrSpace());
+  ID.AddInteger(MMO->getFlags());
   void *IP = nullptr;
   if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) {
     cast<VPLoadSDNode>(E)->refineAlignment(MMO);
@@ -7852,6 +8096,7 @@ SDValue SelectionDAG::getStoreVP(SDValue Chain, const SDLoc &dl, SDValue Val,
   ID.AddInteger(getSyntheticNodeSubclassData<VPStoreSDNode>(
       dl.getIROrder(), VTs, AM, IsTruncating, IsCompressing, MemVT, MMO));
   ID.AddInteger(MMO->getPointerInfo().getAddrSpace());
+  ID.AddInteger(MMO->getFlags());
   void *IP = nullptr;
   if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) {
     cast<VPStoreSDNode>(E)->refineAlignment(MMO);
@@ -7922,6 +8167,7 @@ SDValue SelectionDAG::getTruncStoreVP(SDValue Chain, const SDLoc &dl,
   ID.AddInteger(getSyntheticNodeSubclassData<VPStoreSDNode>(
       dl.getIROrder(), VTs, ISD::UNINDEXED, true, IsCompressing, SVT, MMO));
   ID.AddInteger(MMO->getPointerInfo().getAddrSpace());
+  ID.AddInteger(MMO->getFlags());
   void *IP = nullptr;
   if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) {
     cast<VPStoreSDNode>(E)->refineAlignment(MMO);
@@ -7952,6 +8198,7 @@ SDValue SelectionDAG::getIndexedStoreVP(SDValue OrigStore, const SDLoc &dl,
   ID.AddInteger(ST->getMemoryVT().getRawBits());
   ID.AddInteger(ST->getRawSubclassData());
   ID.AddInteger(ST->getPointerInfo().getAddrSpace());
+  ID.AddInteger(ST->getMemOperand()->getFlags());
   void *IP = nullptr;
   if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP))
     return SDValue(E, 0);
@@ -7968,6 +8215,259 @@ SDValue SelectionDAG::getIndexedStoreVP(SDValue OrigStore, const SDLoc &dl,
   return V;
 }
 
+SDValue SelectionDAG::getStridedLoadVP(
+    ISD::MemIndexedMode AM, ISD::LoadExtType ExtType, EVT VT, const SDLoc &DL,
+    SDValue Chain, SDValue Ptr, SDValue Offset, SDValue Stride, SDValue Mask,
+    SDValue EVL, MachinePointerInfo PtrInfo, EVT MemVT, Align Alignment,
+    MachineMemOperand::Flags MMOFlags, const AAMDNodes &AAInfo,
+    const MDNode *Ranges, bool IsExpanding) {
+  assert(Chain.getValueType() == MVT::Other && "Invalid chain type");
+
+  MMOFlags |= MachineMemOperand::MOLoad;
+  assert((MMOFlags & MachineMemOperand::MOStore) == 0);
+  // If we don't have a PtrInfo, infer the trivial frame index case to simplify
+  // clients.
+  if (PtrInfo.V.isNull())
+    PtrInfo = InferPointerInfo(PtrInfo, *this, Ptr, Offset);
+
+  uint64_t Size = MemoryLocation::UnknownSize;
+  MachineFunction &MF = getMachineFunction();
+  MachineMemOperand *MMO = MF.getMachineMemOperand(PtrInfo, MMOFlags, Size,
+                                                   Alignment, AAInfo, Ranges);
+  return getStridedLoadVP(AM, ExtType, VT, DL, Chain, Ptr, Offset, Stride, Mask,
+                          EVL, MemVT, MMO, IsExpanding);
+}
+
+SDValue SelectionDAG::getStridedLoadVP(
+    ISD::MemIndexedMode AM, ISD::LoadExtType ExtType, EVT VT, const SDLoc &DL,
+    SDValue Chain, SDValue Ptr, SDValue Offset, SDValue Stride, SDValue Mask,
+    SDValue EVL, EVT MemVT, MachineMemOperand *MMO, bool IsExpanding) {
+  bool Indexed = AM != ISD::UNINDEXED;
+  assert((Indexed || Offset.isUndef()) && "Unindexed load with an offset!");
+
+  SDValue Ops[] = {Chain, Ptr, Offset, Stride, Mask, EVL};
+  SDVTList VTs = Indexed ? getVTList(VT, Ptr.getValueType(), MVT::Other)
+                         : getVTList(VT, MVT::Other);
+  FoldingSetNodeID ID;
+  AddNodeIDNode(ID, ISD::EXPERIMENTAL_VP_STRIDED_LOAD, VTs, Ops);
+  ID.AddInteger(VT.getRawBits());
+  ID.AddInteger(getSyntheticNodeSubclassData<VPStridedLoadSDNode>(
+      DL.getIROrder(), VTs, AM, ExtType, IsExpanding, MemVT, MMO));
+  ID.AddInteger(MMO->getPointerInfo().getAddrSpace());
+
+  void *IP = nullptr;
+  if (SDNode *E = FindNodeOrInsertPos(ID, DL, IP)) {
+    cast<VPStridedLoadSDNode>(E)->refineAlignment(MMO);
+    return SDValue(E, 0);
+  }
+
+  auto *N =
+      newSDNode<VPStridedLoadSDNode>(DL.getIROrder(), DL.getDebugLoc(), VTs, AM,
+                                     ExtType, IsExpanding, MemVT, MMO);
+  createOperands(N, Ops);
+  CSEMap.InsertNode(N, IP);
+  InsertNode(N);
+  SDValue V(N, 0);
+  NewSDValueDbgMsg(V, "Creating new node: ", this);
+  return V;
+}
+
+SDValue SelectionDAG::getStridedLoadVP(
+    EVT VT, const SDLoc &DL, SDValue Chain, SDValue Ptr, SDValue Stride,
+    SDValue Mask, SDValue EVL, MachinePointerInfo PtrInfo, MaybeAlign Alignment,
+    MachineMemOperand::Flags MMOFlags, const AAMDNodes &AAInfo,
+    const MDNode *Ranges, bool IsExpanding) {
+  SDValue Undef = getUNDEF(Ptr.getValueType());
+  return getStridedLoadVP(ISD::UNINDEXED, ISD::NON_EXTLOAD, VT, DL, Chain, Ptr,
+                          Undef, Stride, Mask, EVL, PtrInfo, VT, Alignment,
+                          MMOFlags, AAInfo, Ranges, IsExpanding);
+}
+
+SDValue SelectionDAG::getStridedLoadVP(EVT VT, const SDLoc &DL, SDValue Chain,
+                                       SDValue Ptr, SDValue Stride,
+                                       SDValue Mask, SDValue EVL,
+                                       MachineMemOperand *MMO,
+                                       bool IsExpanding) {
+  SDValue Undef = getUNDEF(Ptr.getValueType());
+  return getStridedLoadVP(ISD::UNINDEXED, ISD::NON_EXTLOAD, VT, DL, Chain, Ptr,
+                          Undef, Stride, Mask, EVL, VT, MMO, IsExpanding);
+}
+
+SDValue SelectionDAG::getExtStridedLoadVP(
+    ISD::LoadExtType ExtType, const SDLoc &DL, EVT VT, SDValue Chain,
+    SDValue Ptr, SDValue Stride, SDValue Mask, SDValue EVL,
+    MachinePointerInfo PtrInfo, EVT MemVT, MaybeAlign Alignment,
+    MachineMemOperand::Flags MMOFlags, const AAMDNodes &AAInfo,
+    bool IsExpanding) {
+  SDValue Undef = getUNDEF(Ptr.getValueType());
+  return getStridedLoadVP(ISD::UNINDEXED, ExtType, VT, DL, Chain, Ptr, Undef,
+                          Stride, Mask, EVL, PtrInfo, MemVT, Alignment,
+                          MMOFlags, AAInfo, nullptr, IsExpanding);
+}
+
+SDValue SelectionDAG::getExtStridedLoadVP(
+    ISD::LoadExtType ExtType, const SDLoc &DL, EVT VT, SDValue Chain,
+    SDValue Ptr, SDValue Stride, SDValue Mask, SDValue EVL, EVT MemVT,
+    MachineMemOperand *MMO, bool IsExpanding) {
+  SDValue Undef = getUNDEF(Ptr.getValueType());
+  return getStridedLoadVP(ISD::UNINDEXED, ExtType, VT, DL, Chain, Ptr, Undef,
+                          Stride, Mask, EVL, MemVT, MMO, IsExpanding);
+}
+
+SDValue SelectionDAG::getIndexedStridedLoadVP(SDValue OrigLoad, const SDLoc &DL,
+                                              SDValue Base, SDValue Offset,
+                                              ISD::MemIndexedMode AM) {
+  auto *SLD = cast<VPStridedLoadSDNode>(OrigLoad);
+  assert(SLD->getOffset().isUndef() &&
+         "Strided load is already a indexed load!");
+  // Don't propagate the invariant or dereferenceable flags.
+  auto MMOFlags =
+      SLD->getMemOperand()->getFlags() &
+      ~(MachineMemOperand::MOInvariant | MachineMemOperand::MODereferenceable);
+  return getStridedLoadVP(
+      AM, SLD->getExtensionType(), OrigLoad.getValueType(), DL, SLD->getChain(),
+      Base, Offset, SLD->getStride(), SLD->getMask(), SLD->getVectorLength(),
+      SLD->getPointerInfo(), SLD->getMemoryVT(), SLD->getAlign(), MMOFlags,
+      SLD->getAAInfo(), nullptr, SLD->isExpandingLoad());
+}
+
+SDValue SelectionDAG::getStridedStoreVP(SDValue Chain, const SDLoc &DL,
+                                        SDValue Val, SDValue Ptr,
+                                        SDValue Offset, SDValue Stride,
+                                        SDValue Mask, SDValue EVL, EVT MemVT,
+                                        MachineMemOperand *MMO,
+                                        ISD::MemIndexedMode AM,
+                                        bool IsTruncating, bool IsCompressing) {
+  assert(Chain.getValueType() == MVT::Other && "Invalid chain type");
+  bool Indexed = AM != ISD::UNINDEXED;
+  assert((Indexed || Offset.isUndef()) && "Unindexed vp_store with an offset!");
+  SDVTList VTs = Indexed ? getVTList(Ptr.getValueType(), MVT::Other)
+                         : getVTList(MVT::Other);
+  SDValue Ops[] = {Chain, Val, Ptr, Offset, Stride, Mask, EVL};
+  FoldingSetNodeID ID;
+  AddNodeIDNode(ID, ISD::EXPERIMENTAL_VP_STRIDED_STORE, VTs, Ops);
+  ID.AddInteger(MemVT.getRawBits());
+  ID.AddInteger(getSyntheticNodeSubclassData<VPStridedStoreSDNode>(
+      DL.getIROrder(), VTs, AM, IsTruncating, IsCompressing, MemVT, MMO));
+  ID.AddInteger(MMO->getPointerInfo().getAddrSpace());
+  void *IP = nullptr;
+  if (SDNode *E = FindNodeOrInsertPos(ID, DL, IP)) {
+    cast<VPStridedStoreSDNode>(E)->refineAlignment(MMO);
+    return SDValue(E, 0);
+  }
+  auto *N = newSDNode<VPStridedStoreSDNode>(DL.getIROrder(), DL.getDebugLoc(),
+                                            VTs, AM, IsTruncating,
+                                            IsCompressing, MemVT, MMO);
+  createOperands(N, Ops);
+
+  CSEMap.InsertNode(N, IP);
+  InsertNode(N);
+  SDValue V(N, 0);
+  NewSDValueDbgMsg(V, "Creating new node: ", this);
+  return V;
+}
+
+SDValue SelectionDAG::getTruncStridedStoreVP(
+    SDValue Chain, const SDLoc &DL, SDValue Val, SDValue Ptr, SDValue Stride,
+    SDValue Mask, SDValue EVL, MachinePointerInfo PtrInfo, EVT SVT,
+    Align Alignment, MachineMemOperand::Flags MMOFlags, const AAMDNodes &AAInfo,
+    bool IsCompressing) {
+  assert(Chain.getValueType() == MVT::Other && "Invalid chain type");
+
+  MMOFlags |= MachineMemOperand::MOStore;
+  assert((MMOFlags & MachineMemOperand::MOLoad) == 0);
+
+  if (PtrInfo.V.isNull())
+    PtrInfo = InferPointerInfo(PtrInfo, *this, Ptr);
+
+  MachineFunction &MF = getMachineFunction();
+  MachineMemOperand *MMO = MF.getMachineMemOperand(
+      PtrInfo, MMOFlags, MemoryLocation::UnknownSize, Alignment, AAInfo);
+  return getTruncStridedStoreVP(Chain, DL, Val, Ptr, Stride, Mask, EVL, SVT,
+                                MMO, IsCompressing);
+}
+
+SDValue SelectionDAG::getTruncStridedStoreVP(SDValue Chain, const SDLoc &DL,
+                                             SDValue Val, SDValue Ptr,
+                                             SDValue Stride, SDValue Mask,
+                                             SDValue EVL, EVT SVT,
+                                             MachineMemOperand *MMO,
+                                             bool IsCompressing) {
+  EVT VT = Val.getValueType();
+
+  assert(Chain.getValueType() == MVT::Other && "Invalid chain type");
+  if (VT == SVT)
+    return getStridedStoreVP(Chain, DL, Val, Ptr, getUNDEF(Ptr.getValueType()),
+                             Stride, Mask, EVL, VT, MMO, ISD::UNINDEXED,
+                             /*IsTruncating*/ false, IsCompressing);
+
+  assert(SVT.getScalarType().bitsLT(VT.getScalarType()) &&
+         "Should only be a truncating store, not extending!");
+  assert(VT.isInteger() == SVT.isInteger() && "Can't do FP-INT conversion!");
+  assert(VT.isVector() == SVT.isVector() &&
+         "Cannot use trunc store to convert to or from a vector!");
+  assert((!VT.isVector() ||
+          VT.getVectorElementCount() == SVT.getVectorElementCount()) &&
+         "Cannot use trunc store to change the number of vector elements!");
+
+  SDVTList VTs = getVTList(MVT::Other);
+  SDValue Undef = getUNDEF(Ptr.getValueType());
+  SDValue Ops[] = {Chain, Val, Ptr, Undef, Stride, Mask, EVL};
+  FoldingSetNodeID ID;
+  AddNodeIDNode(ID, ISD::EXPERIMENTAL_VP_STRIDED_STORE, VTs, Ops);
+  ID.AddInteger(SVT.getRawBits());
+  ID.AddInteger(getSyntheticNodeSubclassData<VPStridedStoreSDNode>(
+      DL.getIROrder(), VTs, ISD::UNINDEXED, true, IsCompressing, SVT, MMO));
+  ID.AddInteger(MMO->getPointerInfo().getAddrSpace());
+  void *IP = nullptr;
+  if (SDNode *E = FindNodeOrInsertPos(ID, DL, IP)) {
+    cast<VPStridedStoreSDNode>(E)->refineAlignment(MMO);
+    return SDValue(E, 0);
+  }
+  auto *N = newSDNode<VPStridedStoreSDNode>(DL.getIROrder(), DL.getDebugLoc(),
+                                            VTs, ISD::UNINDEXED, true,
+                                            IsCompressing, SVT, MMO);
+  createOperands(N, Ops);
+
+  CSEMap.InsertNode(N, IP);
+  InsertNode(N);
+  SDValue V(N, 0);
+  NewSDValueDbgMsg(V, "Creating new node: ", this);
+  return V;
+}
+
+SDValue SelectionDAG::getIndexedStridedStoreVP(SDValue OrigStore,
+                                               const SDLoc &DL, SDValue Base,
+                                               SDValue Offset,
+                                               ISD::MemIndexedMode AM) {
+  auto *SST = cast<VPStridedStoreSDNode>(OrigStore);
+  assert(SST->getOffset().isUndef() &&
+         "Strided store is already an indexed store!");
+  SDVTList VTs = getVTList(Base.getValueType(), MVT::Other);
+  SDValue Ops[] = {
+      SST->getChain(), SST->getValue(), Base, Offset, SST->getStride(),
+      SST->getMask(), SST->getVectorLength()};
+  FoldingSetNodeID ID;
+  AddNodeIDNode(ID, ISD::EXPERIMENTAL_VP_STRIDED_STORE, VTs, Ops);
+  ID.AddInteger(SST->getMemoryVT().getRawBits());
+  ID.AddInteger(SST->getRawSubclassData());
+  ID.AddInteger(SST->getPointerInfo().getAddrSpace());
+  void *IP = nullptr;
+  if (SDNode *E = FindNodeOrInsertPos(ID, DL, IP))
+    return SDValue(E, 0);
+
+  auto *N = newSDNode<VPStridedStoreSDNode>(
+      DL.getIROrder(), DL.getDebugLoc(), VTs, AM, SST->isTruncatingStore(),
+      SST->isCompressingStore(), SST->getMemoryVT(), SST->getMemOperand());
+  createOperands(N, Ops);
+
+  CSEMap.InsertNode(N, IP);
+  InsertNode(N);
+  SDValue V(N, 0);
+  NewSDValueDbgMsg(V, "Creating new node: ", this);
+  return V;
+}
+
 SDValue SelectionDAG::getGatherVP(SDVTList VTs, EVT VT, const SDLoc &dl,
                                   ArrayRef<SDValue> Ops, MachineMemOperand *MMO,
                                   ISD::MemIndexType IndexType) {
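The new EXPERIMENTAL_VP_STRIDED_LOAD/STORE nodes describe a predicated strided access: lane i touches Base + i * Stride, only where the mask is set and only below the explicit vector length (EVL). A standalone model of the load side, in plain C++ rather than the LLVM API; for simplicity the stride here is counted in elements, whereas the node's stride operand is a byte offset:

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

static std::vector<int32_t> stridedLoad(const int32_t *Base,
                                        std::ptrdiff_t StrideInElts,
                                        const std::vector<bool> &Mask,
                                        unsigned EVL, unsigned VF) {
  std::vector<int32_t> Result(VF, 0); // disabled lanes left as zero here
  for (unsigned i = 0; i < VF && i < EVL; ++i)
    if (Mask[i])
      Result[i] = Base[i * StrideInElts];
  return Result;
}

int main() {
  int32_t Mem[16];
  for (int i = 0; i < 16; ++i)
    Mem[i] = i;
  // VF=4, stride of two elements, all lanes enabled, EVL=3: loads Mem[0,2,4].
  auto V = stridedLoad(Mem, 2, {true, true, true, true}, 3, 4);
  std::printf("%d %d %d %d\n", V[0], V[1], V[2], V[3]); // 0 2 4 0
}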
@@ -7979,6 +8479,7 @@ SDValue SelectionDAG::getGatherVP(SDVTList VTs, EVT VT, const SDLoc &dl,
   ID.AddInteger(getSyntheticNodeSubclassData<VPGatherSDNode>(
       dl.getIROrder(), VTs, VT, MMO, IndexType));
   ID.AddInteger(MMO->getPointerInfo().getAddrSpace());
+  ID.AddInteger(MMO->getFlags());
   void *IP = nullptr;
   if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) {
     cast<VPGatherSDNode>(E)->refineAlignment(MMO);
@@ -8022,6 +8523,7 @@ SDValue SelectionDAG::getScatterVP(SDVTList VTs, EVT VT, const SDLoc &dl,
   ID.AddInteger(getSyntheticNodeSubclassData<VPScatterSDNode>(
       dl.getIROrder(), VTs, VT, MMO, IndexType));
   ID.AddInteger(MMO->getPointerInfo().getAddrSpace());
+  ID.AddInteger(MMO->getFlags());
   void *IP = nullptr;
   if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) {
     cast<VPScatterSDNode>(E)->refineAlignment(MMO);
@@ -8071,6 +8573,7 @@ SDValue SelectionDAG::getMaskedLoad(EVT VT, const SDLoc &dl, SDValue Chain,
   ID.AddInteger(getSyntheticNodeSubclassData<MaskedLoadSDNode>(
       dl.getIROrder(), VTs, AM, ExtTy, isExpanding, MemVT, MMO));
   ID.AddInteger(MMO->getPointerInfo().getAddrSpace());
+  ID.AddInteger(MMO->getFlags());
   void *IP = nullptr;
   if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) {
     cast<MaskedLoadSDNode>(E)->refineAlignment(MMO);
@@ -8118,6 +8621,7 @@ SDValue SelectionDAG::getMaskedStore(SDValue Chain, const SDLoc &dl,
   ID.AddInteger(getSyntheticNodeSubclassData<MaskedStoreSDNode>(
       dl.getIROrder(), VTs, AM, IsTruncating, IsCompressing, MemVT, MMO));
   ID.AddInteger(MMO->getPointerInfo().getAddrSpace());
+  ID.AddInteger(MMO->getFlags());
   void *IP = nullptr;
   if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) {
     cast<MaskedStoreSDNode>(E)->refineAlignment(MMO);
@@ -8159,13 +8663,13 @@ SDValue SelectionDAG::getMaskedGather(SDVTList VTs, EVT MemVT, const SDLoc &dl,
   ID.AddInteger(getSyntheticNodeSubclassData<MaskedGatherSDNode>(
       dl.getIROrder(), VTs, MemVT, MMO, IndexType, ExtTy));
   ID.AddInteger(MMO->getPointerInfo().getAddrSpace());
+  ID.AddInteger(MMO->getFlags());
   void *IP = nullptr;
   if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) {
     cast<MaskedGatherSDNode>(E)->refineAlignment(MMO);
     return SDValue(E, 0);
   }
 
-  IndexType = TLI->getCanonicalIndexType(IndexType, MemVT, Ops[4]);
   auto *N = newSDNode<MaskedGatherSDNode>(dl.getIROrder(), dl.getDebugLoc(),
                                           VTs, MemVT, MMO, IndexType, ExtTy);
   createOperands(N, Ops);
@@ -8206,13 +8710,13 @@ SDValue SelectionDAG::getMaskedScatter(SDVTList VTs, EVT MemVT, const SDLoc &dl,
   ID.AddInteger(getSyntheticNodeSubclassData<MaskedScatterSDNode>(
       dl.getIROrder(), VTs, MemVT, MMO, IndexType, IsTrunc));
   ID.AddInteger(MMO->getPointerInfo().getAddrSpace());
+  ID.AddInteger(MMO->getFlags());
   void *IP = nullptr;
   if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) {
     cast<MaskedScatterSDNode>(E)->refineAlignment(MMO);
     return SDValue(E, 0);
   }
 
-  IndexType = TLI->getCanonicalIndexType(IndexType, MemVT, Ops[4]);
   auto *N = newSDNode<MaskedScatterSDNode>(dl.getIROrder(), dl.getDebugLoc(),
                                            VTs, MemVT, MMO, IndexType, IsTrunc);
   createOperands(N, Ops);
@@ -8410,6 +8914,41 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
     assert(Ops[2].getValueType() == Ops[3].getValueType() &&
            "LHS/RHS of comparison should match types!");
     break;
+  case ISD::VP_ADD:
+  case ISD::VP_SUB:
+    // If it is VP_ADD/VP_SUB mask operation then turn it to VP_XOR
+    if (VT.isVector() && VT.getVectorElementType() == MVT::i1)
+      Opcode = ISD::VP_XOR;
+    break;
+  case ISD::VP_MUL:
+    // If it is VP_MUL mask operation then turn it to VP_AND
+    if (VT.isVector() && VT.getVectorElementType() == MVT::i1)
+      Opcode = ISD::VP_AND;
+    break;
+  case ISD::VP_REDUCE_MUL:
+    // If it is VP_REDUCE_MUL mask operation then turn it to VP_REDUCE_AND
+    if (VT == MVT::i1)
+      Opcode = ISD::VP_REDUCE_AND;
+    break;
+  case ISD::VP_REDUCE_ADD:
+    // If it is VP_REDUCE_ADD mask operation then turn it to VP_REDUCE_XOR
+    if (VT == MVT::i1)
+      Opcode = ISD::VP_REDUCE_XOR;
+    break;
+  case ISD::VP_REDUCE_SMAX:
+  case ISD::VP_REDUCE_UMIN:
+    // If it is VP_REDUCE_SMAX/VP_REDUCE_UMIN mask operation then turn it to
+    // VP_REDUCE_AND.
+    if (VT == MVT::i1)
+      Opcode = ISD::VP_REDUCE_AND;
+    break;
+  case ISD::VP_REDUCE_SMIN:
+  case ISD::VP_REDUCE_UMAX:
+    // If it is VP_REDUCE_SMIN/VP_REDUCE_UMAX mask operation then turn it to
+    // VP_REDUCE_OR.
+    if (VT == MVT::i1)
+      Opcode = ISD::VP_REDUCE_OR;
+    break;
   }
 
   // Memoize nodes.
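The i1 canonicalizations above rest on two-element Boolean algebra: add and sub are xor modulo 2, mul is and, signed max (with true read as -1) is and, and unsigned max is or. A standalone exhaustive check, in plain C++ rather than the LLVM API:

#include <cstdio>

int main() {
  for (int a = 0; a <= 1; ++a)
    for (int b = 0; b <= 1; ++b) {
      bool Add = ((a + b) & 1) == (a ^ b);           // add == xor (mod 2)
      bool Mul = ((a * b) & 1) == (a & b);           // mul == and
      bool SMax = ((-a > -b ? -a : -b) == -(a & b)); // i1 smax == and
      bool UMax = ((a > b ? a : b) == (a | b));      // umax == or
      if (!(Add && Mul && SMax && UMax)) {
        std::puts("mismatch");
        return 1;
      }
    }
  std::puts("i1 identities hold: add->xor, mul->and, smax->and, umax->or");
}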
@@ -8456,7 +8995,7 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, SDVTList VTList,
 SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, SDVTList VTList,
                               ArrayRef<SDValue> Ops, const SDNodeFlags Flags) {
   if (VTList.NumVTs == 1)
-    return getNode(Opcode, DL, VTList.VTs[0], Ops);
+    return getNode(Opcode, DL, VTList.VTs[0], Ops, Flags);
 
 #ifndef NDEBUG
   for (auto &Op : Ops)
@@ -9669,19 +10208,36 @@ void SelectionDAG::ReplaceAllUsesOfValueWith(SDValue From, SDValue To){
 
 namespace {
 
-  /// UseMemo - This class is used by SelectionDAG::ReplaceAllUsesOfValuesWith
-  /// to record information about a use.
-  struct UseMemo {
-    SDNode *User;
-    unsigned Index;
-    SDUse *Use;
-  };
+/// UseMemo - This class is used by SelectionDAG::ReplaceAllUsesOfValuesWith
+/// to record information about a use.
+struct UseMemo {
+  SDNode *User;
+  unsigned Index;
+  SDUse *Use;
+};
 
-  /// operator< - Sort Memos by User.
-  bool operator<(const UseMemo &L, const UseMemo &R) {
-    return (intptr_t)L.User < (intptr_t)R.User;
+/// operator< - Sort Memos by User.
+bool operator<(const UseMemo &L, const UseMemo &R) {
+  return (intptr_t)L.User < (intptr_t)R.User;
+}
+
+/// RAUOVWUpdateListener - Helper for ReplaceAllUsesOfValuesWith - When the node
+/// pointed to by a UseMemo is deleted, set the User to nullptr to indicate that
+/// the node already has been taken care of recursively.
+class RAUOVWUpdateListener : public SelectionDAG::DAGUpdateListener {
+  SmallVector<UseMemo, 4> &Uses;
+
+  void NodeDeleted(SDNode *N, SDNode *E) override {
+    for (UseMemo &Memo : Uses)
+      if (Memo.User == N)
+        Memo.User = nullptr;
   }
+
+public:
+  RAUOVWUpdateListener(SelectionDAG &d, SmallVector<UseMemo, 4> &uses)
+      : SelectionDAG::DAGUpdateListener(d), Uses(uses) {}
+};
 
 } // end anonymous namespace
 
@@ -9773,12 +10329,19 @@ void SelectionDAG::ReplaceAllUsesOfValuesWith(const SDValue *From,
 
   // Sort the uses, so that all the uses from a given User are together.
   llvm::sort(Uses);
+  RAUOVWUpdateListener Listener(*this, Uses);
 
   for (unsigned UseIndex = 0, UseIndexEnd = Uses.size();
        UseIndex != UseIndexEnd; ) {
     // We know that this user uses some value of From.  If it is the right
     // value, update it.
     SDNode *User = Uses[UseIndex].User;
+    // If the node has been deleted by recursive CSE updates when updating
+    // another node, then just skip this entry.
+    if (User == nullptr) {
+      ++UseIndex;
+      continue;
+    }
 
     // This node is about to morph, remove its old self from the CSE maps.
     RemoveNodeFromCSEMaps(User);
@@ -9975,6 +10538,11 @@ bool llvm::isOneConstant(SDValue V) {
   return Const != nullptr && Const->isOne();
 }
 
+bool llvm::isMinSignedConstant(SDValue V) {
+  ConstantSDNode *Const = dyn_cast<ConstantSDNode>(V);
+  return Const != nullptr && Const->isMinSignedValue();
+}
+
 SDValue llvm::peekThroughBitcasts(SDValue V) {
   while (V.getOpcode() == ISD::BITCAST)
     V = V.getOperand(0);
@@ -10105,10 +10673,9 @@ bool llvm::isNullOrNullSplat(SDValue N, bool AllowUndefs) {
 }
 
 bool llvm::isOneOrOneSplat(SDValue N, bool AllowUndefs) {
-  // TODO: may want to use peekThroughBitcast() here.
-  unsigned BitWidth = N.getScalarValueSizeInBits();
-  ConstantSDNode *C = isConstOrConstSplat(N, AllowUndefs);
-  return C && C->isOne() && C->getValueSizeInBits(0) == BitWidth;
+  ConstantSDNode *C =
+      isConstOrConstSplat(N, AllowUndefs, /*AllowTruncation*/ true);
+  return C && C->isOne();
 }
 
 bool llvm::isAllOnesOrAllOnesSplat(SDValue N, bool AllowUndefs) {
@@ -10957,9 +11524,8 @@ bool BuildVectorSDNode::getConstantRawBits(
     auto *CInt = dyn_cast<ConstantSDNode>(Op);
     auto *CFP = dyn_cast<ConstantFPSDNode>(Op);
     assert((CInt || CFP) && "Unknown constant");
-    SrcBitElements[I] =
-        CInt ? CInt->getAPIntValue().truncOrSelf(SrcEltSizeInBits)
-             : CFP->getValueAPF().bitcastToAPInt();
+    SrcBitElements[I] = CInt ? CInt->getAPIntValue().trunc(SrcEltSizeInBits)
+                             : CFP->getValueAPF().bitcastToAPInt();
   }
 
   // Recast to dst width.
@@ -11078,6 +11644,10 @@ SDNode *SelectionDAG::isConstantFPBuildVectorOrConstantFP(SDValue N) const {
   if (ISD::isBuildVectorOfConstantFPSDNodes(N.getNode()))
     return N.getNode();
 
+  if ((N.getOpcode() == ISD::SPLAT_VECTOR) &&
+      isa<ConstantFPSDNode>(N.getOperand(0)))
+    return N.getNode();
+
   return nullptr;
 }