Diffstat (limited to 'contrib/llvm-project/llvm/lib/Target/X86/X86ISelLowering.cpp')
-rw-r--r--   contrib/llvm-project/llvm/lib/Target/X86/X86ISelLowering.cpp | 126
1 file changed, 111 insertions(+), 15 deletions(-)
diff --git a/contrib/llvm-project/llvm/lib/Target/X86/X86ISelLowering.cpp b/contrib/llvm-project/llvm/lib/Target/X86/X86ISelLowering.cpp
index 34f6bb156327..7826ee4e0d2f 100644
--- a/contrib/llvm-project/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/contrib/llvm-project/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -23190,6 +23190,10 @@ static SDValue EmitCmp(SDValue Op0, SDValue Op1, unsigned X86CC,
 bool X86TargetLowering::isFsqrtCheap(SDValue Op, SelectionDAG &DAG) const {
   EVT VT = Op.getValueType();
 
+  // We don't need to replace SQRT with RSQRT for half type.
+  if (VT.getScalarType() == MVT::f16)
+    return true;
+
   // We never want to use both SQRT and RSQRT instructions for the same input.
   if (DAG.getNodeIfExists(X86ISD::FRSQRT, DAG.getVTList(VT), Op))
     return false;
@@ -23228,11 +23232,15 @@ SDValue X86TargetLowering::getSqrtEstimate(SDValue Op,
     UseOneConstNR = false;
     // There is no FSQRT for 512-bits, but there is RSQRT14.
     unsigned Opcode = VT == MVT::v16f32 ? X86ISD::RSQRT14 : X86ISD::FRSQRT;
-    return DAG.getNode(Opcode, DL, VT, Op);
+    SDValue Estimate = DAG.getNode(Opcode, DL, VT, Op);
+    if (RefinementSteps == 0 && !Reciprocal)
+      Estimate = DAG.getNode(ISD::FMUL, DL, VT, Op, Estimate);
+    return Estimate;
   }
 
   if (VT.getScalarType() == MVT::f16 && isTypeLegal(VT) &&
       Subtarget.hasFP16()) {
+    assert(Reciprocal && "Don't replace SQRT with RSQRT for half type");
    if (RefinementSteps == ReciprocalEstimate::Unspecified)
       RefinementSteps = 0;
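
The getSqrtEstimate change above relies on the identity sqrt(x) = x * (1/sqrt(x)): when no Newton-Raphson refinement steps are requested and the caller wants sqrt rather than rsqrt, the reciprocal-square-root estimate is simply multiplied by the input, avoiding an FSQRT entirely. A minimal scalar sketch of that identity; rsqrt_estimate() is a hypothetical stand-in for the low-precision hardware RSQRT/RSQRT14 estimate and is not code from this patch:

    #include <cmath>
    #include <cstdio>

    // Hypothetical stand-in for the hardware reciprocal square-root
    // estimate (RSQRT/RSQRT14); modeled here with full precision.
    static float rsqrt_estimate(float x) { return 1.0f / std::sqrt(x); }

    int main() {
      float x = 42.0f;
      float r = rsqrt_estimate(x); // r ~= 1/sqrt(x)
      float s = x * r;             // s ~= sqrt(x), no FSQRT executed
      std::printf("x*rsqrt(x) = %f, sqrtf(x) = %f\n", s, std::sqrt(x));
      return 0;
    }
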
@@ -45684,7 +45692,7 @@ static SDValue combineCompareEqual(SDNode *N, SelectionDAG &DAG,
     if (is64BitFP && !Subtarget.is64Bit()) {
       // On a 32-bit target, we cannot bitcast the 64-bit float to a
       // 64-bit integer, since that's not a legal type. Since
-      // OnesOrZeroesF is all ones of all zeroes, we don't need all the
+      // OnesOrZeroesF is all ones or all zeroes, we don't need all the
       // bits, but can do this little dance to extract the lowest 32 bits
       // and work with those going forward.
       SDValue Vector64 = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, MVT::v2f64,
@@ -46581,6 +46589,59 @@ static SDValue combineOrCmpEqZeroToCtlzSrl(SDNode *N, SelectionDAG &DAG,
   return Ret;
 }
 
+static SDValue foldMaskedMergeImpl(SDValue And0_L, SDValue And0_R,
+                                   SDValue And1_L, SDValue And1_R, SDLoc DL,
+                                   SelectionDAG &DAG) {
+  if (!isBitwiseNot(And0_L, true) || !And0_L->hasOneUse())
+    return SDValue();
+  SDValue NotOp = And0_L->getOperand(0);
+  if (NotOp == And1_R)
+    std::swap(And1_R, And1_L);
+  if (NotOp != And1_L)
+    return SDValue();
+
+  // (~(NotOp) & And0_R) | (NotOp & And1_R)
+  // --> ((And0_R ^ And1_R) & NotOp) ^ And1_R
+  EVT VT = And1_L->getValueType(0);
+  SDValue Freeze_And0_R = DAG.getNode(ISD::FREEZE, SDLoc(), VT, And0_R);
+  SDValue Xor0 = DAG.getNode(ISD::XOR, DL, VT, And1_R, Freeze_And0_R);
+  SDValue And = DAG.getNode(ISD::AND, DL, VT, Xor0, NotOp);
+  SDValue Xor1 = DAG.getNode(ISD::XOR, DL, VT, And, Freeze_And0_R);
+  return Xor1;
+}
+
+/// Fold "masked merge" expressions like `(m & x) | (~m & y)` into the
+/// equivalent `((x ^ y) & m) ^ y` pattern.
+/// This is typically a better representation for targets without a fused
+/// "and-not" operation. This function is intended to be called from a
+/// `TargetLowering::PerformDAGCombine` callback on `ISD::OR` nodes.
+static SDValue foldMaskedMerge(SDNode *Node, SelectionDAG &DAG) {
+  // Note that masked-merge variants using XOR or ADD expressions are
+  // normalized to OR by InstCombine so we only check for OR.
+  assert(Node->getOpcode() == ISD::OR && "Must be called with ISD::OR node");
+  SDValue N0 = Node->getOperand(0);
+  if (N0->getOpcode() != ISD::AND || !N0->hasOneUse())
+    return SDValue();
+  SDValue N1 = Node->getOperand(1);
+  if (N1->getOpcode() != ISD::AND || !N1->hasOneUse())
+    return SDValue();
+
+  SDLoc DL(Node);
+  SDValue N00 = N0->getOperand(0);
+  SDValue N01 = N0->getOperand(1);
+  SDValue N10 = N1->getOperand(0);
+  SDValue N11 = N1->getOperand(1);
+  if (SDValue Result = foldMaskedMergeImpl(N00, N01, N10, N11, DL, DAG))
+    return Result;
+  if (SDValue Result = foldMaskedMergeImpl(N01, N00, N10, N11, DL, DAG))
+    return Result;
+  if (SDValue Result = foldMaskedMergeImpl(N10, N11, N00, N01, DL, DAG))
+    return Result;
+  if (SDValue Result = foldMaskedMergeImpl(N11, N10, N00, N01, DL, DAG))
+    return Result;
+  return SDValue();
+}
+
 static SDValue combineOr(SDNode *N, SelectionDAG &DAG,
                          TargetLowering::DAGCombinerInfo &DCI,
                          const X86Subtarget &Subtarget) {
@@ -46674,6 +46735,11 @@ static SDValue combineOr(SDNode *N, SelectionDAG &DAG,
     return Res;
   }
 
+  // We should fold "masked merge" patterns when `andn` is not available.
+  if (!Subtarget.hasBMI() && VT.isScalarInteger() && VT != MVT::i1)
+    if (SDValue R = foldMaskedMerge(N, DAG))
+      return R;
+
   return SDValue();
 }
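
The rewrite implemented by foldMaskedMerge selects x where a mask bit of m is set and y where it is clear, but without ever materializing the inverted mask. A small standalone C++ check, not part of the patch, that verifies the identity exhaustively over all 8-bit operands:

    #include <cassert>
    #include <cstdint>
    #include <cstdio>

    int main() {
      // Verify (m & x) | (~m & y) == ((x ^ y) & m) ^ y for all 8-bit values.
      for (unsigned m = 0; m < 256; ++m)
        for (unsigned x = 0; x < 256; ++x)
          for (unsigned y = 0; y < 256; ++y) {
            uint8_t merged = uint8_t((m & x) | (~m & y));
            uint8_t folded = uint8_t(((x ^ y) & m) ^ y);
            assert(merged == folded);
          }
      std::puts("masked-merge identity holds for all 8-bit operands");
      return 0;
    }

The folded form costs three operations (xor, and, xor), while the original form costs four without a fused and-not. That is why the combine above is gated on !Subtarget.hasBMI(): with BMI, (~m & y) is a single ANDN and the original form is already three operations.
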
@@ -48508,20 +48574,50 @@ static SDValue combinePMULH(SDValue Src, EVT VT, const SDLoc &DL,
   SDValue LHS = Src.getOperand(0).getOperand(0);
   SDValue RHS = Src.getOperand(0).getOperand(1);
 
-  unsigned ExtOpc = LHS.getOpcode();
-  if ((ExtOpc != ISD::SIGN_EXTEND && ExtOpc != ISD::ZERO_EXTEND) ||
-      RHS.getOpcode() != ExtOpc)
-    return SDValue();
-
-  // Peek through the extends.
-  LHS = LHS.getOperand(0);
-  RHS = RHS.getOperand(0);
-
-  // Ensure the input types match.
-  if (LHS.getValueType() != VT || RHS.getValueType() != VT)
-    return SDValue();
+  // Count leading sign/zero bits on both inputs - if there are enough then
+  // truncation back to vXi16 will be cheap - either as a pack/shuffle
+  // sequence or using AVX512 truncations. If the inputs are sext/zext then
+  // the truncations may actually be free by peeking through to the ext
+  // source.
+  auto IsSext = [&DAG](SDValue V) {
+    return DAG.ComputeMinSignedBits(V) <= 16;
+  };
+  auto IsZext = [&DAG](SDValue V) {
+    return DAG.computeKnownBits(V).countMaxActiveBits() <= 16;
+  };
 
-  unsigned Opc = ExtOpc == ISD::SIGN_EXTEND ? ISD::MULHS : ISD::MULHU;
+  bool IsSigned = IsSext(LHS) && IsSext(RHS);
+  bool IsUnsigned = IsZext(LHS) && IsZext(RHS);
+  if (!IsSigned && !IsUnsigned)
+    return SDValue();
+
+  // Check if both inputs are extensions, which will be removed by truncation.
+  bool IsTruncateFree = (LHS.getOpcode() == ISD::SIGN_EXTEND ||
+                         LHS.getOpcode() == ISD::ZERO_EXTEND) &&
+                        (RHS.getOpcode() == ISD::SIGN_EXTEND ||
+                         RHS.getOpcode() == ISD::ZERO_EXTEND) &&
+                        LHS.getOperand(0).getScalarValueSizeInBits() <= 16 &&
+                        RHS.getOperand(0).getScalarValueSizeInBits() <= 16;
+
+  // For AVX2+ targets, with the upper bits known zero, we can perform MULHU
+  // on the (bitcasted) inputs directly, and then cheaply pack/truncate the
+  // result (upper elts will be zero). Don't attempt this with just AVX512F
+  // as MULHU will have to split anyway.
+  unsigned InSizeInBits = InVT.getSizeInBits();
+  if (IsUnsigned && !IsTruncateFree && Subtarget.hasInt256() &&
+      !(Subtarget.hasAVX512() && !Subtarget.hasBWI() && VT.is256BitVector()) &&
+      (InSizeInBits % 16) == 0) {
+    EVT BCVT = EVT::getVectorVT(*DAG.getContext(), MVT::i16,
+                                InVT.getSizeInBits() / 16);
+    SDValue Res = DAG.getNode(ISD::MULHU, DL, BCVT, DAG.getBitcast(BCVT, LHS),
+                              DAG.getBitcast(BCVT, RHS));
+    return DAG.getNode(ISD::TRUNCATE, DL, VT, DAG.getBitcast(InVT, Res));
+  }
+
+  // Truncate back to source type.
+  LHS = DAG.getNode(ISD::TRUNCATE, DL, VT, LHS);
+  RHS = DAG.getNode(ISD::TRUNCATE, DL, VT, RHS);
+
+  unsigned Opc = IsSigned ? ISD::MULHS : ISD::MULHU;
   return DAG.getNode(Opc, DL, VT, LHS, RHS);
 }
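
combinePMULH recognizes the widen-multiply-shift-truncate idiom: extend two vXi16 inputs to 32 bits, multiply, shift right by 16, truncate back. The result is exactly the high half of each 16x16-bit product, which is what PMULHW (signed) and PMULHUW (unsigned) compute per lane. A scalar model of the pattern, not taken from the patch, showing why the rewrite is sound:

    #include <cstdint>
    #include <cstdio>

    // High 16 bits of the signed 16x16 product: the pattern
    // trunc((sext(a) * sext(b)) >> 16) that combinePMULH turns into MULHS.
    static int16_t mulhs16(int16_t a, int16_t b) {
      int32_t wide = int32_t(a) * int32_t(b); // sext + widening multiply
      return int16_t(wide >> 16);             // shift + truncate
    }

    // Unsigned variant, lowered to MULHU (vector PMULHUW).
    static uint16_t mulhu16(uint16_t a, uint16_t b) {
      uint32_t wide = uint32_t(a) * uint32_t(b);
      return uint16_t(wide >> 16);
    }

    int main() {
      std::printf("mulhs16(-30000, 17000) = %d\n",
                  int(mulhs16(-30000, 17000)));
      std::printf("mulhu16(50000, 60000) = %u\n",
                  unsigned(mulhu16(50000, 60000)));
      return 0;
    }

The patch generalizes the match from literal sext/zext operands to anything ComputeMinSignedBits/computeKnownBits can prove has at most 16 significant bits, so, for example, inputs masked with 0xFFFF also qualify for the unsigned form even though no ZERO_EXTEND node is present.
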
