Diffstat (limited to 'contrib/llvm-project/llvm/lib/Target/X86/X86ISelLowering.cpp')
-rw-r--r--  contrib/llvm-project/llvm/lib/Target/X86/X86ISelLowering.cpp  126
1 file changed, 111 insertions, 15 deletions
diff --git a/contrib/llvm-project/llvm/lib/Target/X86/X86ISelLowering.cpp b/contrib/llvm-project/llvm/lib/Target/X86/X86ISelLowering.cpp
index 34f6bb156327..7826ee4e0d2f 100644
--- a/contrib/llvm-project/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/contrib/llvm-project/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -23190,6 +23190,10 @@ static SDValue EmitCmp(SDValue Op0, SDValue Op1, unsigned X86CC,
bool X86TargetLowering::isFsqrtCheap(SDValue Op, SelectionDAG &DAG) const {
EVT VT = Op.getValueType();
+ // We don't need to replace SQRT with RSQRT for half type.
+ if (VT.getScalarType() == MVT::f16)
+ return true;
+
// We never want to use both SQRT and RSQRT instructions for the same input.
if (DAG.getNodeIfExists(X86ISD::FRSQRT, DAG.getVTList(VT), Op))
return false;
@@ -23228,11 +23232,15 @@ SDValue X86TargetLowering::getSqrtEstimate(SDValue Op,
UseOneConstNR = false;
// There is no FSQRT for 512-bits, but there is RSQRT14.
unsigned Opcode = VT == MVT::v16f32 ? X86ISD::RSQRT14 : X86ISD::FRSQRT;
- return DAG.getNode(Opcode, DL, VT, Op);
+ SDValue Estimate = DAG.getNode(Opcode, DL, VT, Op);
+ if (RefinementSteps == 0 && !Reciprocal)
+ Estimate = DAG.getNode(ISD::FMUL, DL, VT, Op, Estimate);
+ return Estimate;
}
if (VT.getScalarType() == MVT::f16 && isTypeLegal(VT) &&
Subtarget.hasFP16()) {
+ assert(Reciprocal && "Don't replace SQRT with RSQRT for half type");
if (RefinementSteps == ReciprocalEstimate::Unspecified)
RefinementSteps = 0;
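The new RefinementSteps == 0 && !Reciprocal path above uses the identity sqrt(x) = x * rsqrt(x) to serve a square-root request from the reciprocal-square-root estimate. Below is a minimal standalone sketch of that idea using the SSE _mm_rsqrt_ss intrinsic; it is illustrative only, not part of this patch (note that x == 0 yields 0 * inf = NaN, one reason estimate-based lowering is reserved for approximate-math settings):

    #include <immintrin.h>
    #include <cmath>
    #include <cstdio>

    // sqrt(x) approximated as x * rsqrt(x), mirroring the
    // RefinementSteps == 0 && !Reciprocal case above.
    static float sqrt_via_rsqrt(float x) {
      __m128 est = _mm_rsqrt_ss(_mm_set_ss(x)); // ~12-bit rsqrt estimate
      return x * _mm_cvtss_f32(est);            // x * rsqrt(x) ~= sqrt(x)
    }

    int main() {
      printf("approx=%f exact=%f\n", sqrt_via_rsqrt(2.0f), sqrtf(2.0f));
      return 0;
    }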
@@ -45684,7 +45692,7 @@ static SDValue combineCompareEqual(SDNode *N, SelectionDAG &DAG,
if (is64BitFP && !Subtarget.is64Bit()) {
// On a 32-bit target, we cannot bitcast the 64-bit float to a
// 64-bit integer, since that's not a legal type. Since
- // OnesOrZeroesF is all ones of all zeroes, we don't need all the
+ // OnesOrZeroesF is all ones or all zeroes, we don't need all the
// bits, but can do this little dance to extract the lowest 32 bits
// and work with those going forward.
SDValue Vector64 = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, MVT::v2f64,
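The corrected comment states an invariant worth spelling out: since OnesOrZeroesF is known to be either all ones or all zeroes, its low 32 bits alone determine the entire 64-bit pattern, so the 32-bit target can extract and test just those. A freestanding illustration (not part of the patch; assumes a little-endian host, as on x86):

    #include <cstdint>
    #include <cstring>
    #include <cstdio>

    int main() {
      const uint64_t masks[] = {~uint64_t{0}, uint64_t{0}};
      for (uint64_t mask : masks) {
        double d;                          // plays the role of OnesOrZeroesF
        std::memcpy(&d, &mask, sizeof(d));
        uint32_t lo;
        std::memcpy(&lo, &d, sizeof(lo));  // lowest 32 bits (little-endian)
        printf("low32=0x%08x -> full lane is all %s\n", (unsigned)lo,
               lo ? "ones" : "zeroes");
      }
      return 0;
    }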
@@ -46581,6 +46589,59 @@ static SDValue combineOrCmpEqZeroToCtlzSrl(SDNode *N, SelectionDAG &DAG,
return Ret;
}
+static SDValue foldMaskedMergeImpl(SDValue And0_L, SDValue And0_R,
+ SDValue And1_L, SDValue And1_R, SDLoc DL,
+ SelectionDAG &DAG) {
+ if (!isBitwiseNot(And0_L, true) || !And0_L->hasOneUse())
+ return SDValue();
+ SDValue NotOp = And0_L->getOperand(0);
+ if (NotOp == And1_R)
+ std::swap(And1_R, And1_L);
+ if (NotOp != And1_L)
+ return SDValue();
+
+ // (~(NotOp) & And0_R) | (NotOp & And1_R)
+ // --> ((And0_R ^ And1_R) & NotOp) ^ And0_R
+ EVT VT = And1_L->getValueType(0);
+ SDValue Freeze_And0_R = DAG.getNode(ISD::FREEZE, SDLoc(), VT, And0_R);
+ SDValue Xor0 = DAG.getNode(ISD::XOR, DL, VT, And1_R, Freeze_And0_R);
+ SDValue And = DAG.getNode(ISD::AND, DL, VT, Xor0, NotOp);
+ SDValue Xor1 = DAG.getNode(ISD::XOR, DL, VT, And, Freeze_And0_R);
+ return Xor1;
+}
+
+/// Fold "masked merge" expressions like `(m & x) | (~m & y)` into the
+/// equivalent `((x ^ y) & m) ^ y` pattern.
+/// This is typically a better representation for targets without a fused
+/// "and-not" operation. This function is intended to be called from a
+/// `TargetLowering::PerformDAGCombine` callback on `ISD::OR` nodes.
+static SDValue foldMaskedMerge(SDNode *Node, SelectionDAG &DAG) {
+ // Note that masked-merge variants using XOR or ADD expressions are
+ // normalized to OR by InstCombine so we only check for OR.
+ assert(Node->getOpcode() == ISD::OR && "Must be called with ISD::OR node");
+ SDValue N0 = Node->getOperand(0);
+ if (N0->getOpcode() != ISD::AND || !N0->hasOneUse())
+ return SDValue();
+ SDValue N1 = Node->getOperand(1);
+ if (N1->getOpcode() != ISD::AND || !N1->hasOneUse())
+ return SDValue();
+
+ SDLoc DL(Node);
+ SDValue N00 = N0->getOperand(0);
+ SDValue N01 = N0->getOperand(1);
+ SDValue N10 = N1->getOperand(0);
+ SDValue N11 = N1->getOperand(1);
+ if (SDValue Result = foldMaskedMergeImpl(N00, N01, N10, N11, DL, DAG))
+ return Result;
+ if (SDValue Result = foldMaskedMergeImpl(N01, N00, N10, N11, DL, DAG))
+ return Result;
+ if (SDValue Result = foldMaskedMergeImpl(N10, N11, N00, N01, DL, DAG))
+ return Result;
+ if (SDValue Result = foldMaskedMergeImpl(N11, N10, N00, N01, DL, DAG))
+ return Result;
+ return SDValue();
+}
+
static SDValue combineOr(SDNode *N, SelectionDAG &DAG,
TargetLowering::DAGCombinerInfo &DCI,
const X86Subtarget &Subtarget) {
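foldMaskedMerge above rewrites (m & x) | (~m & y), which lowers best with an ANDN instruction, into ((x ^ y) & m) ^ y, which needs only XOR/AND/XOR. A small exhaustive check of that identity (illustrative only, not part of the patch; since the identity is bitwise, 4 bits per operand already covers every per-bit case):

    #include <cstdint>
    #include <cassert>
    #include <cstdio>

    // ANDN-friendly form: two ANDs, a NOT, and an OR.
    static uint32_t merge_andn(uint32_t m, uint32_t x, uint32_t y) {
      return (m & x) | (~m & y);
    }

    // Form produced by foldMaskedMerge: two XORs and one AND, no NOT.
    static uint32_t merge_xor(uint32_t m, uint32_t x, uint32_t y) {
      return ((x ^ y) & m) ^ y;
    }

    int main() {
      for (uint32_t m = 0; m < 16; ++m)
        for (uint32_t x = 0; x < 16; ++x)
          for (uint32_t y = 0; y < 16; ++y)
            assert(merge_andn(m, x, y) == merge_xor(m, x, y));
      printf("masked-merge identity holds\n");
      return 0;
    }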
@@ -46674,6 +46735,11 @@ static SDValue combineOr(SDNode *N, SelectionDAG &DAG,
return Res;
}
+ // We should fold "masked merge" patterns when `andn` is not available.
+ if (!Subtarget.hasBMI() && VT.isScalarInteger() && VT != MVT::i1)
+ if (SDValue R = foldMaskedMerge(N, DAG))
+ return R;
+
return SDValue();
}
@@ -48508,20 +48574,50 @@ static SDValue combinePMULH(SDValue Src, EVT VT, const SDLoc &DL,
SDValue LHS = Src.getOperand(0).getOperand(0);
SDValue RHS = Src.getOperand(0).getOperand(1);
- unsigned ExtOpc = LHS.getOpcode();
- if ((ExtOpc != ISD::SIGN_EXTEND && ExtOpc != ISD::ZERO_EXTEND) ||
- RHS.getOpcode() != ExtOpc)
- return SDValue();
-
- // Peek through the extends.
- LHS = LHS.getOperand(0);
- RHS = RHS.getOperand(0);
-
- // Ensure the input types match.
- if (LHS.getValueType() != VT || RHS.getValueType() != VT)
- return SDValue();
+ // Count leading sign/zero bits on both inputs - if there are enough then
+ // truncation back to vXi16 will be cheap - either as a pack/shuffle
+ // sequence or using AVX512 truncations. If the inputs are sext/zext then the
+ // truncations may actually be free by peeking through to the ext source.
+ auto IsSext = [&DAG](SDValue V) {
+ return DAG.ComputeMinSignedBits(V) <= 16;
+ };
+ auto IsZext = [&DAG](SDValue V) {
+ return DAG.computeKnownBits(V).countMaxActiveBits() <= 16;
+ };

- unsigned Opc = ExtOpc == ISD::SIGN_EXTEND ? ISD::MULHS : ISD::MULHU;
+ bool IsSigned = IsSext(LHS) && IsSext(RHS);
+ bool IsUnsigned = IsZext(LHS) && IsZext(RHS);
+ if (!IsSigned && !IsUnsigned)
+ return SDValue();
+
+ // Check if both inputs are extensions, which will be removed by truncation.
+ bool IsTruncateFree = (LHS.getOpcode() == ISD::SIGN_EXTEND ||
+ LHS.getOpcode() == ISD::ZERO_EXTEND) &&
+ (RHS.getOpcode() == ISD::SIGN_EXTEND ||
+ RHS.getOpcode() == ISD::ZERO_EXTEND) &&
+ LHS.getOperand(0).getScalarValueSizeInBits() <= 16 &&
+ RHS.getOperand(0).getScalarValueSizeInBits() <= 16;
+
+ // For AVX2+ targets, with the upper bits known zero, we can perform MULHU on
+ // the (bitcasted) inputs directly, and then cheaply pack/truncate the result
+ // (upper elts will be zero). Don't attempt this with just AVX512F as MULHU
+ // will have to split anyway.
+ unsigned InSizeInBits = InVT.getSizeInBits();
+ if (IsUnsigned && !IsTruncateFree && Subtarget.hasInt256() &&
+ !(Subtarget.hasAVX512() && !Subtarget.hasBWI() && VT.is256BitVector()) &&
+ (InSizeInBits % 16) == 0) {
+ EVT BCVT = EVT::getVectorVT(*DAG.getContext(), MVT::i16,
+ InVT.getSizeInBits() / 16);
+ SDValue Res = DAG.getNode(ISD::MULHU, DL, BCVT, DAG.getBitcast(BCVT, LHS),
+ DAG.getBitcast(BCVT, RHS));
+ return DAG.getNode(ISD::TRUNCATE, DL, VT, DAG.getBitcast(InVT, Res));
+ }
+
+ // Truncate back to source type.
+ LHS = DAG.getNode(ISD::TRUNCATE, DL, VT, LHS);
+ RHS = DAG.getNode(ISD::TRUNCATE, DL, VT, RHS);
+
+ unsigned Opc = IsSigned ? ISD::MULHS : ISD::MULHU;
return DAG.getNode(Opc, DL, VT, LHS, RHS);
}
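The rewritten combine hinges on a scalar fact: if both multiplicands have at most 16 significant (sign or zero) bits, then bits 16..31 of their 32-bit product are exactly the 16-bit high-half multiply (PMULHW/PMULHUW) of the truncated operands. A standalone sketch of that equivalence (illustrative only, not part of the patch; arithmetic right shift of negative values is assumed, as on all mainstream compilers):

    #include <cstdint>
    #include <cassert>
    #include <cstdio>

    // High half of a widening 16x16->32 multiply, as PMULHW computes per lane.
    static int16_t mulhs16(int16_t a, int16_t b) {
      return (int16_t)(((int32_t)a * (int32_t)b) >> 16);
    }

    // Unsigned variant, as PMULHUW computes per lane.
    static uint16_t mulhu16(uint16_t a, uint16_t b) {
      return (uint16_t)(((uint32_t)a * (uint32_t)b) >> 16);
    }

    int main() {
      // i32 operands that fit in 16 bits: (a * b) >> 16 matches the
      // 16-bit high-half multiply of the truncated values.
      int32_t a = -30000, b = 12345;
      assert(((a * b) >> 16) == mulhs16((int16_t)a, (int16_t)b));
      uint32_t ua = 60000, ub = 54321;
      assert(((ua * ub) >> 16) == (uint32_t)mulhu16((uint16_t)ua, (uint16_t)ub));
      printf("high-half multiply equivalence holds\n");
      return 0;
    }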