diff options
| author | Dimitry Andric <dim@FreeBSD.org> | 2023-12-09 13:28:42 +0000 |
|---|---|---|
| committer | Dimitry Andric <dim@FreeBSD.org> | 2023-12-09 13:28:42 +0000 |
| commit | b1c73532ee8997fe5dfbeb7d223027bdf99758a0 (patch) | |
| tree | 7d6e51c294ab6719475d660217aa0c0ad0526292 /llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp | |
| parent | 7fa27ce4a07f19b07799a767fc29416f3b625afb (diff) | |
Diffstat (limited to 'llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp')
| -rw-r--r-- | llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp | 487 |
1 files changed, 354 insertions, 133 deletions
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp index a84d35a6ea4e..c5977546828f 100644 --- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp @@ -18,6 +18,7 @@ #include "llvm/CodeGen/MachineFrameInfo.h" #include "llvm/CodeGen/MachineFunction.h" #include "llvm/CodeGen/MachineJumpTableInfo.h" +#include "llvm/CodeGen/MachineModuleInfoImpls.h" #include "llvm/CodeGen/MachineRegisterInfo.h" #include "llvm/CodeGen/SelectionDAG.h" #include "llvm/CodeGen/TargetRegisterInfo.h" @@ -472,6 +473,17 @@ TargetLowering::getPICJumpTableRelocBaseExpr(const MachineFunction *MF, return MCSymbolRefExpr::create(MF->getJTISymbol(JTI, Ctx), Ctx); } +SDValue TargetLowering::expandIndirectJTBranch(const SDLoc &dl, SDValue Value, + SDValue Addr, int JTI, + SelectionDAG &DAG) const { + SDValue Chain = Value; + // Jump table debug info is only needed if CodeView is enabled. + if (DAG.getTarget().getTargetTriple().isOSBinFormatCOFF()) { + Chain = DAG.getJumpTableDebugInfo(JTI, Chain, dl); + } + return DAG.getNode(ISD::BRIND, dl, MVT::Other, Chain, Addr); +} + bool TargetLowering::isOffsetFoldingLegal(const GlobalAddressSDNode *GA) const { const TargetMachine &TM = getTargetMachine(); @@ -554,8 +566,9 @@ bool TargetLowering::ShrinkDemandedConstant(SDValue Op, } /// Convert x+y to (VT)((SmallVT)x+(SmallVT)y) if the casts are free. -/// This uses isZExtFree and ZERO_EXTEND for the widening cast, but it could be -/// generalized for targets with other types of implicit widening casts. +/// This uses isTruncateFree/isZExtFree and ANY_EXTEND for the widening cast, +/// but it could be generalized for targets with other types of implicit +/// widening casts. bool TargetLowering::ShrinkDemandedOp(SDValue Op, unsigned BitWidth, const APInt &DemandedBits, TargetLoweringOpt &TLO) const { @@ -1040,13 +1053,10 @@ static SDValue combineShiftToAVG(SDValue Op, SelectionDAG &DAG, // larger type size to do the transform. if (!TLI.isOperationLegalOrCustom(AVGOpc, VT)) return SDValue(); - - if (DAG.computeOverflowForAdd(IsSigned, Add.getOperand(0), - Add.getOperand(1)) == - SelectionDAG::OFK_Never && - (!Add2 || DAG.computeOverflowForAdd(IsSigned, Add2.getOperand(0), - Add2.getOperand(1)) == - SelectionDAG::OFK_Never)) + if (DAG.willNotOverflowAdd(IsSigned, Add.getOperand(0), + Add.getOperand(1)) && + (!Add2 || DAG.willNotOverflowAdd(IsSigned, Add2.getOperand(0), + Add2.getOperand(1)))) NVT = VT; else return SDValue(); @@ -1155,6 +1165,18 @@ bool TargetLowering::SimplifyDemandedBits( // TODO: Call SimplifyDemandedBits for non-constant demanded elements. Known = TLO.DAG.computeKnownBits(Op, DemandedElts, Depth); return false; // Don't fall through, will infinitely loop. + case ISD::SPLAT_VECTOR: { + SDValue Scl = Op.getOperand(0); + APInt DemandedSclBits = DemandedBits.zextOrTrunc(Scl.getValueSizeInBits()); + KnownBits KnownScl; + if (SimplifyDemandedBits(Scl, DemandedSclBits, KnownScl, TLO, Depth + 1)) + return true; + + // Implicitly truncate the bits to match the official semantics of + // SPLAT_VECTOR. + Known = KnownScl.trunc(BitWidth); + break; + } case ISD::LOAD: { auto *LD = cast<LoadSDNode>(Op); if (getTargetConstantFromLoad(LD)) { @@ -1765,8 +1787,17 @@ bool TargetLowering::SimplifyDemandedBits( APInt InDemandedMask = DemandedBits.lshr(ShAmt); if (SimplifyDemandedBits(Op0, InDemandedMask, DemandedElts, Known, TLO, - Depth + 1)) + Depth + 1)) { + SDNodeFlags Flags = Op.getNode()->getFlags(); + if (Flags.hasNoSignedWrap() || Flags.hasNoUnsignedWrap()) { + // Disable the nsw and nuw flags. We can no longer guarantee that we + // won't wrap after simplification. + Flags.setNoSignedWrap(false); + Flags.setNoUnsignedWrap(false); + Op->setFlags(Flags); + } return true; + } assert(!Known.hasConflict() && "Bits known to be one AND zero?"); Known.Zero <<= ShAmt; Known.One <<= ShAmt; @@ -1788,6 +1819,37 @@ bool TargetLowering::SimplifyDemandedBits( if ((ShAmt < DemandedBits.getActiveBits()) && ShrinkDemandedOp(Op, BitWidth, DemandedBits, TLO)) return true; + + // Narrow shift to lower half - similar to ShrinkDemandedOp. + // (shl i64:x, K) -> (i64 zero_extend (shl (i32 (trunc i64:x)), K)) + // Only do this if we demand the upper half so the knownbits are correct. + unsigned HalfWidth = BitWidth / 2; + if ((BitWidth % 2) == 0 && !VT.isVector() && ShAmt < HalfWidth && + DemandedBits.countLeadingOnes() >= HalfWidth) { + EVT HalfVT = EVT::getIntegerVT(*TLO.DAG.getContext(), HalfWidth); + if (isNarrowingProfitable(VT, HalfVT) && + isTypeDesirableForOp(ISD::SHL, HalfVT) && + isTruncateFree(VT, HalfVT) && isZExtFree(HalfVT, VT) && + (!TLO.LegalOperations() || isOperationLegal(ISD::SHL, HalfVT))) { + // If we're demanding the upper bits at all, we must ensure + // that the upper bits of the shift result are known to be zero, + // which is equivalent to the narrow shift being NUW. + if (bool IsNUW = (Known.countMinLeadingZeros() >= HalfWidth)) { + bool IsNSW = Known.countMinSignBits() > HalfWidth; + SDNodeFlags Flags; + Flags.setNoSignedWrap(IsNSW); + Flags.setNoUnsignedWrap(IsNUW); + SDValue NewOp = TLO.DAG.getNode(ISD::TRUNCATE, dl, HalfVT, Op0); + SDValue NewShiftAmt = TLO.DAG.getShiftAmountConstant( + ShAmt, HalfVT, dl, TLO.LegalTypes()); + SDValue NewShift = TLO.DAG.getNode(ISD::SHL, dl, HalfVT, NewOp, + NewShiftAmt, Flags); + SDValue NewExt = + TLO.DAG.getNode(ISD::ZERO_EXTEND, dl, VT, NewShift); + return TLO.CombineTo(Op, NewExt); + } + } + } } else { // This is a variable shift, so we can't shift the demand mask by a known // amount. But if we are not demanding high bits, then we are not @@ -1870,15 +1932,15 @@ bool TargetLowering::SimplifyDemandedBits( // Narrow shift to lower half - similar to ShrinkDemandedOp. // (srl i64:x, K) -> (i64 zero_extend (srl (i32 (trunc i64:x)), K)) - if ((BitWidth % 2) == 0 && !VT.isVector() && - ((InDemandedMask.countLeadingZeros() >= (BitWidth / 2)) || - TLO.DAG.MaskedValueIsZero( - Op0, APInt::getHighBitsSet(BitWidth, BitWidth / 2)))) { + if ((BitWidth % 2) == 0 && !VT.isVector()) { + APInt HiBits = APInt::getHighBitsSet(BitWidth, BitWidth / 2); EVT HalfVT = EVT::getIntegerVT(*TLO.DAG.getContext(), BitWidth / 2); if (isNarrowingProfitable(VT, HalfVT) && isTypeDesirableForOp(ISD::SRL, HalfVT) && isTruncateFree(VT, HalfVT) && isZExtFree(HalfVT, VT) && - (!TLO.LegalOperations() || isOperationLegal(ISD::SRL, VT))) { + (!TLO.LegalOperations() || isOperationLegal(ISD::SRL, HalfVT)) && + ((InDemandedMask.countLeadingZeros() >= (BitWidth / 2)) || + TLO.DAG.MaskedValueIsZero(Op0, HiBits))) { SDValue NewOp = TLO.DAG.getNode(ISD::TRUNCATE, dl, HalfVT, Op0); SDValue NewShiftAmt = TLO.DAG.getShiftAmountConstant( ShAmt, HalfVT, dl, TLO.LegalTypes()); @@ -1945,6 +2007,35 @@ bool TargetLowering::SimplifyDemandedBits( if (ShAmt == 0) return TLO.CombineTo(Op, Op0); + // fold (sra (shl x, c1), c1) -> sext_inreg for some c1 and target + // supports sext_inreg. + if (Op0.getOpcode() == ISD::SHL) { + if (const APInt *InnerSA = + TLO.DAG.getValidShiftAmountConstant(Op0, DemandedElts)) { + unsigned LowBits = BitWidth - ShAmt; + EVT ExtVT = EVT::getIntegerVT(*TLO.DAG.getContext(), LowBits); + if (VT.isVector()) + ExtVT = EVT::getVectorVT(*TLO.DAG.getContext(), ExtVT, + VT.getVectorElementCount()); + + if (*InnerSA == ShAmt) { + if (!TLO.LegalOperations() || + getOperationAction(ISD::SIGN_EXTEND_INREG, ExtVT) == Legal) + return TLO.CombineTo( + Op, TLO.DAG.getNode(ISD::SIGN_EXTEND_INREG, dl, VT, + Op0.getOperand(0), + TLO.DAG.getValueType(ExtVT))); + + // Even if we can't convert to sext_inreg, we might be able to + // remove this shift pair if the input is already sign extended. + unsigned NumSignBits = + TLO.DAG.ComputeNumSignBits(Op0.getOperand(0), DemandedElts); + if (NumSignBits > ShAmt) + return TLO.CombineTo(Op, Op0.getOperand(0)); + } + } + } + APInt InDemandedMask = (DemandedBits << ShAmt); // If the shift is exact, then it does demand the low bits (and knows that @@ -2106,30 +2197,57 @@ bool TargetLowering::SimplifyDemandedBits( } break; } - case ISD::UMIN: { - // Check if one arg is always less than (or equal) to the other arg. - SDValue Op0 = Op.getOperand(0); - SDValue Op1 = Op.getOperand(1); - KnownBits Known0 = TLO.DAG.computeKnownBits(Op0, DemandedElts, Depth + 1); - KnownBits Known1 = TLO.DAG.computeKnownBits(Op1, DemandedElts, Depth + 1); - Known = KnownBits::umin(Known0, Known1); - if (std::optional<bool> IsULE = KnownBits::ule(Known0, Known1)) - return TLO.CombineTo(Op, *IsULE ? Op0 : Op1); - if (std::optional<bool> IsULT = KnownBits::ult(Known0, Known1)) - return TLO.CombineTo(Op, *IsULT ? Op0 : Op1); - break; - } + case ISD::SMIN: + case ISD::SMAX: + case ISD::UMIN: case ISD::UMAX: { - // Check if one arg is always greater than (or equal) to the other arg. + unsigned Opc = Op.getOpcode(); SDValue Op0 = Op.getOperand(0); SDValue Op1 = Op.getOperand(1); + + // If we're only demanding signbits, then we can simplify to OR/AND node. + unsigned BitOp = + (Opc == ISD::SMIN || Opc == ISD::UMAX) ? ISD::OR : ISD::AND; + unsigned NumSignBits = + std::min(TLO.DAG.ComputeNumSignBits(Op0, DemandedElts, Depth + 1), + TLO.DAG.ComputeNumSignBits(Op1, DemandedElts, Depth + 1)); + unsigned NumDemandedUpperBits = BitWidth - DemandedBits.countr_zero(); + if (NumSignBits >= NumDemandedUpperBits) + return TLO.CombineTo(Op, TLO.DAG.getNode(BitOp, SDLoc(Op), VT, Op0, Op1)); + + // Check if one arg is always less/greater than (or equal) to the other arg. KnownBits Known0 = TLO.DAG.computeKnownBits(Op0, DemandedElts, Depth + 1); KnownBits Known1 = TLO.DAG.computeKnownBits(Op1, DemandedElts, Depth + 1); - Known = KnownBits::umax(Known0, Known1); - if (std::optional<bool> IsUGE = KnownBits::uge(Known0, Known1)) - return TLO.CombineTo(Op, *IsUGE ? Op0 : Op1); - if (std::optional<bool> IsUGT = KnownBits::ugt(Known0, Known1)) - return TLO.CombineTo(Op, *IsUGT ? Op0 : Op1); + switch (Opc) { + case ISD::SMIN: + if (std::optional<bool> IsSLE = KnownBits::sle(Known0, Known1)) + return TLO.CombineTo(Op, *IsSLE ? Op0 : Op1); + if (std::optional<bool> IsSLT = KnownBits::slt(Known0, Known1)) + return TLO.CombineTo(Op, *IsSLT ? Op0 : Op1); + Known = KnownBits::smin(Known0, Known1); + break; + case ISD::SMAX: + if (std::optional<bool> IsSGE = KnownBits::sge(Known0, Known1)) + return TLO.CombineTo(Op, *IsSGE ? Op0 : Op1); + if (std::optional<bool> IsSGT = KnownBits::sgt(Known0, Known1)) + return TLO.CombineTo(Op, *IsSGT ? Op0 : Op1); + Known = KnownBits::smax(Known0, Known1); + break; + case ISD::UMIN: + if (std::optional<bool> IsULE = KnownBits::ule(Known0, Known1)) + return TLO.CombineTo(Op, *IsULE ? Op0 : Op1); + if (std::optional<bool> IsULT = KnownBits::ult(Known0, Known1)) + return TLO.CombineTo(Op, *IsULT ? Op0 : Op1); + Known = KnownBits::umin(Known0, Known1); + break; + case ISD::UMAX: + if (std::optional<bool> IsUGE = KnownBits::uge(Known0, Known1)) + return TLO.CombineTo(Op, *IsUGE ? Op0 : Op1); + if (std::optional<bool> IsUGT = KnownBits::ugt(Known0, Known1)) + return TLO.CombineTo(Op, *IsUGT ? Op0 : Op1); + Known = KnownBits::umax(Known0, Known1); + break; + } break; } case ISD::BITREVERSE: { @@ -2285,11 +2403,17 @@ bool TargetLowering::SimplifyDemandedBits( return TLO.CombineTo(Op, TLO.DAG.getNode(Opc, dl, VT, Src)); } + SDNodeFlags Flags = Op->getFlags(); APInt InDemandedBits = DemandedBits.trunc(InBits); APInt InDemandedElts = DemandedElts.zext(InElts); if (SimplifyDemandedBits(Src, InDemandedBits, InDemandedElts, Known, TLO, - Depth + 1)) + Depth + 1)) { + if (Flags.hasNonNeg()) { + Flags.setNonNeg(false); + Op->setFlags(Flags); + } return true; + } assert(!Known.hasConflict() && "Bits known to be one AND zero?"); assert(Known.getBitWidth() == InBits && "Src width has changed?"); Known = Known.zext(BitWidth); @@ -2653,7 +2777,7 @@ bool TargetLowering::SimplifyDemandedBits( // neg x with only low bit demanded is simply x. if (Op.getOpcode() == ISD::SUB && DemandedBits.isOne() && - isa<ConstantSDNode>(Op0) && cast<ConstantSDNode>(Op0)->isZero()) + isNullConstant(Op0)) return TLO.CombineTo(Op, Op1); // Attempt to avoid multi-use ops if we don't need anything from them. @@ -2913,8 +3037,9 @@ bool TargetLowering::SimplifyDemandedVectorElts( SDValue NewOp1 = SimplifyMultipleUseDemandedVectorElts(Op1, DemandedElts, TLO.DAG, Depth + 1); if (NewOp0 || NewOp1) { - SDValue NewOp = TLO.DAG.getNode( - Opcode, SDLoc(Op), VT, NewOp0 ? NewOp0 : Op0, NewOp1 ? NewOp1 : Op1); + SDValue NewOp = + TLO.DAG.getNode(Opcode, SDLoc(Op), VT, NewOp0 ? NewOp0 : Op0, + NewOp1 ? NewOp1 : Op1, Op->getFlags()); return TLO.CombineTo(Op, NewOp); } return false; @@ -3823,8 +3948,12 @@ SDValue TargetLowering::foldSetCCWithAnd(EVT VT, SDValue N0, SDValue N1, return SDValue(); } + // TODO: We should invert (X & Y) eq/ne 0 -> (X & Y) ne/eq Y if + // `isXAndYEqZeroPreferableToXAndYEqY` is false. This is a bit difficult as + // its liable to create and infinite loop. SDValue Zero = DAG.getConstant(0, DL, OpVT); - if (DAG.isKnownToBeAPowerOfTwo(Y)) { + if (isXAndYEqZeroPreferableToXAndYEqY(Cond, OpVT) && + DAG.isKnownToBeAPowerOfTwo(Y)) { // Simplify X & Y == Y to X & Y != 0 if Y has exactly one bit set. // Note that where Y is variable and is known to have at most one bit set // (for example, if it is Z & 1) we cannot do this; the expressions are not @@ -3843,8 +3972,7 @@ SDValue TargetLowering::foldSetCCWithAnd(EVT VT, SDValue N0, SDValue N1, // Bail out if the compare operand that we want to turn into a zero is // already a zero (otherwise, infinite loop). - auto *YConst = dyn_cast<ConstantSDNode>(Y); - if (YConst && YConst->isZero()) + if (isNullConstant(Y)) return SDValue(); // Transform this into: ~X & Y == 0. @@ -4088,8 +4216,8 @@ static SDValue simplifySetCCWithCTPOP(const TargetLowering &TLI, EVT VT, // (ctpop x) u< 2 -> (x & x-1) == 0 // (ctpop x) u> 1 -> (x & x-1) != 0 if (Cond == ISD::SETULT || Cond == ISD::SETUGT) { - // Keep the CTPOP if it is a legal vector op. - if (CTVT.isVector() && TLI.isOperationLegal(ISD::CTPOP, CTVT)) + // Keep the CTPOP if it is a cheap vector op. + if (CTVT.isVector() && TLI.isCtpopFast(CTVT)) return SDValue(); unsigned CostLimit = TLI.getCustomCtpopCost(CTVT, Cond); @@ -4110,28 +4238,32 @@ static SDValue simplifySetCCWithCTPOP(const TargetLowering &TLI, EVT VT, return DAG.getSetCC(dl, VT, Result, DAG.getConstant(0, dl, CTVT), CC); } - // Expand a power-of-2 comparison based on ctpop: - // (ctpop x) == 1 --> (x != 0) && ((x & x-1) == 0) - // (ctpop x) != 1 --> (x == 0) || ((x & x-1) != 0) + // Expand a power-of-2 comparison based on ctpop if ((Cond == ISD::SETEQ || Cond == ISD::SETNE) && C1 == 1) { - // Keep the CTPOP if it is legal. - if (TLI.isOperationLegal(ISD::CTPOP, CTVT)) + // Keep the CTPOP if it is cheap. + if (TLI.isCtpopFast(CTVT)) return SDValue(); SDValue Zero = DAG.getConstant(0, dl, CTVT); SDValue NegOne = DAG.getAllOnesConstant(dl, CTVT); assert(CTVT.isInteger()); - ISD::CondCode InvCond = ISD::getSetCCInverse(Cond, CTVT); SDValue Add = DAG.getNode(ISD::ADD, dl, CTVT, CTOp, NegOne); - SDValue And = DAG.getNode(ISD::AND, dl, CTVT, CTOp, Add); - SDValue RHS = DAG.getSetCC(dl, VT, And, Zero, Cond); + // Its not uncommon for known-never-zero X to exist in (ctpop X) eq/ne 1, so - // check before the emit a potentially unnecessary op. - if (DAG.isKnownNeverZero(CTOp)) + // check before emitting a potentially unnecessary op. + if (DAG.isKnownNeverZero(CTOp)) { + // (ctpop x) == 1 --> (x & x-1) == 0 + // (ctpop x) != 1 --> (x & x-1) != 0 + SDValue And = DAG.getNode(ISD::AND, dl, CTVT, CTOp, Add); + SDValue RHS = DAG.getSetCC(dl, VT, And, Zero, Cond); return RHS; - SDValue LHS = DAG.getSetCC(dl, VT, CTOp, Zero, InvCond); - unsigned LogicOpcode = Cond == ISD::SETEQ ? ISD::AND : ISD::OR; - return DAG.getNode(LogicOpcode, dl, VT, LHS, RHS); + } + + // (ctpop x) == 1 --> (x ^ x-1) > x-1 + // (ctpop x) != 1 --> (x ^ x-1) <= x-1 + SDValue Xor = DAG.getNode(ISD::XOR, dl, CTVT, CTOp, Add); + ISD::CondCode CmpCond = Cond == ISD::SETEQ ? ISD::SETUGT : ISD::SETULE; + return DAG.getSetCC(dl, VT, Xor, Add, CmpCond); } return SDValue(); @@ -4477,8 +4609,8 @@ SDValue TargetLowering::SimplifySetCC(EVT VT, SDValue N0, SDValue N1, shouldReduceLoadWidth(Lod, ISD::NON_EXTLOAD, newVT)) { SDValue Ptr = Lod->getBasePtr(); if (bestOffset != 0) - Ptr = - DAG.getMemBasePlusOffset(Ptr, TypeSize::Fixed(bestOffset), dl); + Ptr = DAG.getMemBasePlusOffset(Ptr, TypeSize::getFixed(bestOffset), + dl); SDValue NewLoad = DAG.getLoad(newVT, dl, Lod->getChain(), Ptr, Lod->getPointerInfo().getWithOffset(bestOffset), @@ -4983,6 +5115,21 @@ SDValue TargetLowering::SimplifySetCC(EVT VT, SDValue N0, SDValue N1, } } + // setueq/setoeq X, (fabs Inf) -> is_fpclass X, fcInf + if (isOperationLegalOrCustom(ISD::IS_FPCLASS, N0.getValueType()) && + !isFPImmLegal(CFP->getValueAPF(), CFP->getValueType(0))) { + bool IsFabs = N0.getOpcode() == ISD::FABS; + SDValue Op = IsFabs ? N0.getOperand(0) : N0; + if ((Cond == ISD::SETOEQ || Cond == ISD::SETUEQ) && CFP->isInfinity()) { + FPClassTest Flag = CFP->isNegative() ? (IsFabs ? fcNone : fcNegInf) + : (IsFabs ? fcInf : fcPosInf); + if (Cond == ISD::SETUEQ) + Flag |= fcNan; + return DAG.getNode(ISD::IS_FPCLASS, dl, VT, Op, + DAG.getTargetConstant(Flag, dl, MVT::i32)); + } + } + // If the condition is not legal, see if we can find an equivalent one // which is legal. if (!isCondCodeLegal(Cond, N0.getSimpleValueType())) { @@ -5037,7 +5184,8 @@ SDValue TargetLowering::SimplifySetCC(EVT VT, SDValue N0, SDValue N1, if (isBitwiseNot(N1)) return DAG.getSetCC(dl, VT, N1.getOperand(0), N0.getOperand(0), Cond); - if (DAG.isConstantIntBuildVectorOrConstantInt(N1)) { + if (DAG.isConstantIntBuildVectorOrConstantInt(N1) && + !DAG.isConstantIntBuildVectorOrConstantInt(N0.getOperand(0))) { SDValue Not = DAG.getNOT(dl, N1, OpVT); return DAG.getSetCC(dl, VT, Not, N0.getOperand(0), Cond); } @@ -5297,11 +5445,12 @@ SDValue TargetLowering::LowerAsmOutputForConstraint( /// Lower the specified operand into the Ops vector. /// If it is invalid, don't add anything to Ops. void TargetLowering::LowerAsmOperandForConstraint(SDValue Op, - std::string &Constraint, + StringRef Constraint, std::vector<SDValue> &Ops, SelectionDAG &DAG) const { - if (Constraint.length() > 1) return; + if (Constraint.size() > 1) + return; char ConstraintLetter = Constraint[0]; switch (ConstraintLetter) { @@ -5620,20 +5769,27 @@ TargetLowering::ParseConstraints(const DataLayout &DL, return ConstraintOperands; } -/// Return an integer indicating how general CT is. -static unsigned getConstraintGenerality(TargetLowering::ConstraintType CT) { +/// Return a number indicating our preference for chosing a type of constraint +/// over another, for the purpose of sorting them. Immediates are almost always +/// preferrable (when they can be emitted). A higher return value means a +/// stronger preference for one constraint type relative to another. +/// FIXME: We should prefer registers over memory but doing so may lead to +/// unrecoverable register exhaustion later. +/// https://github.com/llvm/llvm-project/issues/20571 +static unsigned getConstraintPiority(TargetLowering::ConstraintType CT) { switch (CT) { case TargetLowering::C_Immediate: case TargetLowering::C_Other: - case TargetLowering::C_Unknown: - return 0; - case TargetLowering::C_Register: - return 1; - case TargetLowering::C_RegisterClass: - return 2; + return 4; case TargetLowering::C_Memory: case TargetLowering::C_Address: return 3; + case TargetLowering::C_RegisterClass: + return 2; + case TargetLowering::C_Register: + return 1; + case TargetLowering::C_Unknown: + return 0; } llvm_unreachable("Invalid constraint type"); } @@ -5713,11 +5869,15 @@ TargetLowering::ConstraintWeight /// If there are multiple different constraints that we could pick for this /// operand (e.g. "imr") try to pick the 'best' one. -/// This is somewhat tricky: constraints fall into four classes: -/// Other -> immediates and magic values +/// This is somewhat tricky: constraints (TargetLowering::ConstraintType) fall +/// into seven classes: /// Register -> one specific register /// RegisterClass -> a group of regs /// Memory -> memory +/// Address -> a symbolic memory reference +/// Immediate -> immediate values +/// Other -> magic values (such as "Flag Output Operands") +/// Unknown -> something we don't recognize yet and can't handle /// Ideally, we would pick the most specific constraint possible: if we have /// something that fits into a register, we would pick it. The problem here /// is that if we have something that could either be in a register or in @@ -5731,18 +5891,13 @@ TargetLowering::ConstraintWeight /// 2) Otherwise, pick the most general constraint present. This prefers /// 'm' over 'r', for example. /// -static void ChooseConstraint(TargetLowering::AsmOperandInfo &OpInfo, - const TargetLowering &TLI, - SDValue Op, SelectionDAG *DAG) { - assert(OpInfo.Codes.size() > 1 && "Doesn't have multiple constraint options"); - unsigned BestIdx = 0; - TargetLowering::ConstraintType BestType = TargetLowering::C_Unknown; - int BestGenerality = -1; +TargetLowering::ConstraintGroup TargetLowering::getConstraintPreferences( + TargetLowering::AsmOperandInfo &OpInfo) const { + ConstraintGroup Ret; - // Loop over the options, keeping track of the most general one. - for (unsigned i = 0, e = OpInfo.Codes.size(); i != e; ++i) { - TargetLowering::ConstraintType CType = - TLI.getConstraintType(OpInfo.Codes[i]); + Ret.reserve(OpInfo.Codes.size()); + for (StringRef Code : OpInfo.Codes) { + TargetLowering::ConstraintType CType = getConstraintType(Code); // Indirect 'other' or 'immediate' constraints are not allowed. if (OpInfo.isIndirect && !(CType == TargetLowering::C_Memory || @@ -5750,40 +5905,38 @@ static void ChooseConstraint(TargetLowering::AsmOperandInfo &OpInfo, CType == TargetLowering::C_RegisterClass)) continue; - // If this is an 'other' or 'immediate' constraint, see if the operand is - // valid for it. For example, on X86 we might have an 'rI' constraint. If - // the operand is an integer in the range [0..31] we want to use I (saving a - // load of a register), otherwise we must use 'r'. - if ((CType == TargetLowering::C_Other || - CType == TargetLowering::C_Immediate) && Op.getNode()) { - assert(OpInfo.Codes[i].size() == 1 && - "Unhandled multi-letter 'other' constraint"); - std::vector<SDValue> ResultOps; - TLI.LowerAsmOperandForConstraint(Op, OpInfo.Codes[i], - ResultOps, *DAG); - if (!ResultOps.empty()) { - BestType = CType; - BestIdx = i; - break; - } - } - // Things with matching constraints can only be registers, per gcc // documentation. This mainly affects "g" constraints. if (CType == TargetLowering::C_Memory && OpInfo.hasMatchingInput()) continue; - // This constraint letter is more general than the previous one, use it. - int Generality = getConstraintGenerality(CType); - if (Generality > BestGenerality) { - BestType = CType; - BestIdx = i; - BestGenerality = Generality; - } + Ret.emplace_back(Code, CType); } - OpInfo.ConstraintCode = OpInfo.Codes[BestIdx]; - OpInfo.ConstraintType = BestType; + std::stable_sort( + Ret.begin(), Ret.end(), [](ConstraintPair a, ConstraintPair b) { + return getConstraintPiority(a.second) > getConstraintPiority(b.second); + }); + + return Ret; +} + +/// If we have an immediate, see if we can lower it. Return true if we can, +/// false otherwise. +static bool lowerImmediateIfPossible(TargetLowering::ConstraintPair &P, + SDValue Op, SelectionDAG *DAG, + const TargetLowering &TLI) { + + assert((P.second == TargetLowering::C_Other || + P.second == TargetLowering::C_Immediate) && + "need immediate or other"); + + if (!Op.getNode()) + return false; + + std::vector<SDValue> ResultOps; + TLI.LowerAsmOperandForConstraint(Op, P.first, ResultOps, *DAG); + return !ResultOps.empty(); } /// Determines the constraint code and constraint type to use for the specific @@ -5798,7 +5951,26 @@ void TargetLowering::ComputeConstraintToUse(AsmOperandInfo &OpInfo, OpInfo.ConstraintCode = OpInfo.Codes[0]; OpInfo.ConstraintType = getConstraintType(OpInfo.ConstraintCode); } else { - ChooseConstraint(OpInfo, *this, Op, DAG); + ConstraintGroup G = getConstraintPreferences(OpInfo); + if (G.empty()) + return; + + unsigned BestIdx = 0; + for (const unsigned E = G.size(); + BestIdx < E && (G[BestIdx].second == TargetLowering::C_Other || + G[BestIdx].second == TargetLowering::C_Immediate); + ++BestIdx) { + if (lowerImmediateIfPossible(G[BestIdx], Op, DAG, *this)) + break; + // If we're out of constraints, just pick the first one. + if (BestIdx + 1 == E) { + BestIdx = 0; + break; + } + } + + OpInfo.ConstraintCode = G[BestIdx].first; + OpInfo.ConstraintType = G[BestIdx].second; } // 'X' matches anything. @@ -5914,6 +6086,49 @@ TargetLowering::BuildSREMPow2(SDNode *N, const APInt &Divisor, return SDValue(); } +/// Build sdiv by power-of-2 with conditional move instructions +/// Ref: "Hacker's Delight" by Henry Warren 10-1 +/// If conditional move/branch is preferred, we lower sdiv x, +/-2**k into: +/// bgez x, label +/// add x, x, 2**k-1 +/// label: +/// sra res, x, k +/// neg res, res (when the divisor is negative) +SDValue TargetLowering::buildSDIVPow2WithCMov( + SDNode *N, const APInt &Divisor, SelectionDAG &DAG, + SmallVectorImpl<SDNode *> &Created) const { + unsigned Lg2 = Divisor.countr_zero(); + EVT VT = N->getValueType(0); + + SDLoc DL(N); + SDValue N0 = N->getOperand(0); + SDValue Zero = DAG.getConstant(0, DL, VT); + APInt Lg2Mask = APInt::getLowBitsSet(VT.getSizeInBits(), Lg2); + SDValue Pow2MinusOne = DAG.getConstant(Lg2Mask, DL, VT); + + // If N0 is negative, we need to add (Pow2 - 1) to it before shifting right. + EVT CCVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT); + SDValue Cmp = DAG.getSetCC(DL, CCVT, N0, Zero, ISD::SETLT); + SDValue Add = DAG.getNode(ISD::ADD, DL, VT, N0, Pow2MinusOne); + SDValue CMov = DAG.getNode(ISD::SELECT, DL, VT, Cmp, Add, N0); + + Created.push_back(Cmp.getNode()); + Created.push_back(Add.getNode()); + Created.push_back(CMov.getNode()); + + // Divide by pow2. + SDValue SRA = + DAG.getNode(ISD::SRA, DL, VT, CMov, DAG.getConstant(Lg2, DL, VT)); + + // If we're dividing by a positive value, we're done. Otherwise, we must + // negate the result. + if (Divisor.isNonNegative()) + return SRA; + + Created.push_back(SRA.getNode()); + return DAG.getNode(ISD::SUB, DL, VT, Zero, SRA); +} + /// Given an ISD::SDIV node expressing a divide by constant, /// return a DAG expression to select that will generate the same value by /// multiplying by a magic number. @@ -6016,7 +6231,7 @@ SDValue TargetLowering::BuildSDIV(SDNode *N, SelectionDAG &DAG, // Multiply the numerator (operand 0) by the magic value. // FIXME: We should support doing a MUL in a wider type. auto GetMULHS = [&](SDValue X, SDValue Y) { - // If the type isn't legal, use a wider mul of the the type calculated + // If the type isn't legal, use a wider mul of the type calculated // earlier. if (!isTypeLegal(VT)) { X = DAG.getNode(ISD::SIGN_EXTEND, dl, MulVT, X); @@ -6203,7 +6418,7 @@ SDValue TargetLowering::BuildUDIV(SDNode *N, SelectionDAG &DAG, // FIXME: We should support doing a MUL in a wider type. auto GetMULHU = [&](SDValue X, SDValue Y) { - // If the type isn't legal, use a wider mul of the the type calculated + // If the type isn't legal, use a wider mul of the type calculated // earlier. if (!isTypeLegal(VT)) { X = DAG.getNode(ISD::ZERO_EXTEND, dl, MulVT, X); @@ -9131,7 +9346,7 @@ TargetLowering::scalarizeVectorLoad(LoadSDNode *LD, SrcEltVT, LD->getOriginalAlign(), LD->getMemOperand()->getFlags(), LD->getAAInfo()); - BasePTR = DAG.getObjectPtrOffset(SL, BasePTR, TypeSize::Fixed(Stride)); + BasePTR = DAG.getObjectPtrOffset(SL, BasePTR, TypeSize::getFixed(Stride)); Vals.push_back(ScalarLoad.getValue(0)); LoadChains.push_back(ScalarLoad.getValue(1)); @@ -9206,7 +9421,7 @@ SDValue TargetLowering::scalarizeVectorStore(StoreSDNode *ST, DAG.getVectorIdxConstant(Idx, SL)); SDValue Ptr = - DAG.getObjectPtrOffset(SL, BasePtr, TypeSize::Fixed(Idx * Stride)); + DAG.getObjectPtrOffset(SL, BasePtr, TypeSize::getFixed(Idx * Stride)); // This scalar TruncStore may be illegal, but we legalize it later. SDValue Store = DAG.getTruncStore( @@ -9342,7 +9557,7 @@ TargetLowering::expandUnalignedLoad(LoadSDNode *LD, SelectionDAG &DAG) const { NewLoadedVT, Alignment, LD->getMemOperand()->getFlags(), LD->getAAInfo()); - Ptr = DAG.getObjectPtrOffset(dl, Ptr, TypeSize::Fixed(IncrementSize)); + Ptr = DAG.getObjectPtrOffset(dl, Ptr, TypeSize::getFixed(IncrementSize)); Hi = DAG.getExtLoad(HiExtType, dl, VT, Chain, Ptr, LD->getPointerInfo().getWithOffset(IncrementSize), NewLoadedVT, Alignment, LD->getMemOperand()->getFlags(), @@ -9352,7 +9567,7 @@ TargetLowering::expandUnalignedLoad(LoadSDNode *LD, SelectionDAG &DAG) const { NewLoadedVT, Alignment, LD->getMemOperand()->getFlags(), LD->getAAInfo()); - Ptr = DAG.getObjectPtrOffset(dl, Ptr, TypeSize::Fixed(IncrementSize)); + Ptr = DAG.getObjectPtrOffset(dl, Ptr, TypeSize::getFixed(IncrementSize)); Lo = DAG.getExtLoad(ISD::ZEXTLOAD, dl, VT, Chain, Ptr, LD->getPointerInfo().getWithOffset(IncrementSize), NewLoadedVT, Alignment, LD->getMemOperand()->getFlags(), @@ -9477,6 +9692,14 @@ SDValue TargetLowering::expandUnalignedStore(StoreSDNode *ST, SDValue ShiftAmount = DAG.getConstant( NumBits, dl, getShiftAmountTy(Val.getValueType(), DAG.getDataLayout())); SDValue Lo = Val; + // If Val is a constant, replace the upper bits with 0. The SRL will constant + // fold and not use the upper bits. A smaller constant may be easier to + // materialize. + if (auto *C = dyn_cast<ConstantSDNode>(Lo); C && !C->isOpaque()) + Lo = DAG.getNode( + ISD::AND, dl, VT, Lo, + DAG.getConstant(APInt::getLowBitsSet(VT.getSizeInBits(), NumBits), dl, + VT)); SDValue Hi = DAG.getNode(ISD::SRL, dl, VT, Val, ShiftAmount); // Store the two parts @@ -9486,7 +9709,7 @@ SDValue TargetLowering::expandUnalignedStore(StoreSDNode *ST, Ptr, ST->getPointerInfo(), NewStoredVT, Alignment, ST->getMemOperand()->getFlags()); - Ptr = DAG.getObjectPtrOffset(dl, Ptr, TypeSize::Fixed(IncrementSize)); + Ptr = DAG.getObjectPtrOffset(dl, Ptr, TypeSize::getFixed(IncrementSize)); Store2 = DAG.getTruncStore( Chain, dl, DAG.getDataLayout().isLittleEndian() ? Hi : Lo, Ptr, ST->getPointerInfo().getWithOffset(IncrementSize), NewStoredVT, Alignment, @@ -9618,7 +9841,7 @@ SDValue TargetLowering::LowerToTLSEmulatedModel(const GlobalAddressSDNode *GA, // Access to address of TLS varialbe xyz is lowered to a function call: // __emutls_get_address( address of global variable named "__emutls_v.xyz" ) EVT PtrVT = getPointerTy(DAG.getDataLayout()); - PointerType *VoidPtrType = Type::getInt8PtrTy(*DAG.getContext()); + PointerType *VoidPtrType = PointerType::get(*DAG.getContext(), 0); SDLoc dl(GA); ArgListTy Args; @@ -9657,20 +9880,18 @@ SDValue TargetLowering::lowerCmpEqZeroToCtlzSrl(SDValue Op, return SDValue(); ISD::CondCode CC = cast<CondCodeSDNode>(Op.getOperand(2))->get(); SDLoc dl(Op); - if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(Op.getOperand(1))) { - if (C->isZero() && CC == ISD::SETEQ) { - EVT VT = Op.getOperand(0).getValueType(); - SDValue Zext = Op.getOperand(0); - if (VT.bitsLT(MVT::i32)) { - VT = MVT::i32; - Zext = DAG.getNode(ISD::ZERO_EXTEND, dl, VT, Op.getOperand(0)); - } - unsigned Log2b = Log2_32(VT.getSizeInBits()); - SDValue Clz = DAG.getNode(ISD::CTLZ, dl, VT, Zext); - SDValue Scc = DAG.getNode(ISD::SRL, dl, VT, Clz, - DAG.getConstant(Log2b, dl, MVT::i32)); - return DAG.getNode(ISD::TRUNCATE, dl, MVT::i32, Scc); + if (isNullConstant(Op.getOperand(1)) && CC == ISD::SETEQ) { + EVT VT = Op.getOperand(0).getValueType(); + SDValue Zext = Op.getOperand(0); + if (VT.bitsLT(MVT::i32)) { + VT = MVT::i32; + Zext = DAG.getNode(ISD::ZERO_EXTEND, dl, VT, Op.getOperand(0)); } + unsigned Log2b = Log2_32(VT.getSizeInBits()); + SDValue Clz = DAG.getNode(ISD::CTLZ, dl, VT, Zext); + SDValue Scc = DAG.getNode(ISD::SRL, dl, VT, Clz, + DAG.getConstant(Log2b, dl, MVT::i32)); + return DAG.getNode(ISD::TRUNCATE, dl, MVT::i32, Scc); } return SDValue(); } @@ -10489,9 +10710,9 @@ SDValue TargetLowering::expandFP_TO_INT_SAT(SDNode *Node, MaxInt = APInt::getMaxValue(SatWidth).zext(DstWidth); } - // We cannot risk emitting FP_TO_XINT nodes with a source VT of f16, as + // We cannot risk emitting FP_TO_XINT nodes with a source VT of [b]f16, as // libcall emission cannot handle this. Large result types will fail. - if (SrcVT == MVT::f16) { + if (SrcVT == MVT::f16 || SrcVT == MVT::bf16) { Src = DAG.getNode(ISD::FP_EXTEND, dl, MVT::f32, Src); SrcVT = Src.getValueType(); } |
