Diffstat (limited to 'lib/Target/X86/X86ISelLowering.cpp')
-rw-r--r--  lib/Target/X86/X86ISelLowering.cpp  465
1 file changed, 250 insertions, 215 deletions
diff --git a/lib/Target/X86/X86ISelLowering.cpp b/lib/Target/X86/X86ISelLowering.cpp
index a72f4daa5e11..5ac5d0348f8a 100644
--- a/lib/Target/X86/X86ISelLowering.cpp
+++ b/lib/Target/X86/X86ISelLowering.cpp
@@ -461,7 +461,7 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
setOperationAction(ISD::SRL_PARTS, VT, Custom);
}
- if (Subtarget.hasSSE1())
+ if (Subtarget.hasSSEPrefetch() || Subtarget.has3DNow())
setOperationAction(ISD::PREFETCH , MVT::Other, Legal);
setOperationAction(ISD::ATOMIC_FENCE , MVT::Other, Custom);
@@ -1622,16 +1622,11 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
setLibcallName(RTLIB::MUL_I128, nullptr);
}
- // Combine sin / cos into one node or libcall if possible.
- if (Subtarget.hasSinCos()) {
- setLibcallName(RTLIB::SINCOS_F32, "sincosf");
- setLibcallName(RTLIB::SINCOS_F64, "sincos");
- if (Subtarget.isTargetDarwin()) {
- // For MacOSX, we don't want the normal expansion of a libcall to sincos.
- // We want to issue a libcall to __sincos_stret to avoid memory traffic.
- setOperationAction(ISD::FSINCOS, MVT::f64, Custom);
- setOperationAction(ISD::FSINCOS, MVT::f32, Custom);
- }
+ // Combine sin / cos into _sincos_stret if it is available.
+ if (getLibcallName(RTLIB::SINCOS_STRET_F32) != nullptr &&
+ getLibcallName(RTLIB::SINCOS_STRET_F64) != nullptr) {
+ setOperationAction(ISD::FSINCOS, MVT::f64, Custom);
+ setOperationAction(ISD::FSINCOS, MVT::f32, Custom);
}
if (Subtarget.isTargetWin64()) {
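
The rewritten condition keys off the generic RTLIB table rather than a Darwin-specific subtarget hook: FSINCOS is only marked Custom once a __sincos_stret-style routine has been named. A minimal sketch of the kind of registration this relies on, using the Darwin names that appear later in this patch (illustrative only, not part of the change itself):

    // Illustrative: naming the stret variants is what makes
    // getLibcallName(RTLIB::SINCOS_STRET_*) non-null, which in turn enables
    // the Custom FSINCOS lowering selected above.
    if (TT.isOSDarwin()) {
      setLibcallName(RTLIB::SINCOS_STRET_F32, "__sincosf_stret");
      setLibcallName(RTLIB::SINCOS_STRET_F64, "__sincos_stret");
    }
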
@@ -7480,9 +7475,9 @@ static bool isAddSub(const BuildVectorSDNode *BV,
}
/// Returns true if is possible to fold MUL and an idiom that has already been
-/// recognized as ADDSUB(\p Opnd0, \p Opnd1) into FMADDSUB(x, y, \p Opnd1).
-/// If (and only if) true is returned, the operands of FMADDSUB are written to
-/// parameters \p Opnd0, \p Opnd1, \p Opnd2.
+/// recognized as ADDSUB/SUBADD(\p Opnd0, \p Opnd1) into
+/// FMADDSUB/FMSUBADD(x, y, \p Opnd1). If (and only if) true is returned, the
+/// operands of FMADDSUB/FMSUBADD are written to parameters \p Opnd0, \p Opnd1, \p Opnd2.
///
/// Prior to calling this function it should be known that there is some
/// SDNode that potentially can be replaced with an X86ISD::ADDSUB operation
@@ -7505,12 +7500,12 @@ static bool isAddSub(const BuildVectorSDNode *BV,
/// recognized ADDSUB idiom with ADDSUB operation is that such replacement
/// is illegal sometimes. E.g. 512-bit ADDSUB is not available, while 512-bit
/// FMADDSUB is.
-static bool isFMAddSub(const X86Subtarget &Subtarget, SelectionDAG &DAG,
- SDValue &Opnd0, SDValue &Opnd1, SDValue &Opnd2,
- unsigned ExpectedUses) {
+static bool isFMAddSubOrFMSubAdd(const X86Subtarget &Subtarget,
+ SelectionDAG &DAG,
+ SDValue &Opnd0, SDValue &Opnd1, SDValue &Opnd2,
+ unsigned ExpectedUses) {
if (Opnd0.getOpcode() != ISD::FMUL ||
- !Opnd0->hasNUsesOfValue(ExpectedUses, 0) ||
- !Subtarget.hasAnyFMA())
+ !Opnd0->hasNUsesOfValue(ExpectedUses, 0) || !Subtarget.hasAnyFMA())
return false;
// FIXME: These checks must match the similar ones in
@@ -7547,7 +7542,7 @@ static SDValue lowerToAddSubOrFMAddSub(const BuildVectorSDNode *BV,
SDValue Opnd2;
// TODO: According to coverage reports, the FMADDSUB transform is not
// triggered by any tests.
- if (isFMAddSub(Subtarget, DAG, Opnd0, Opnd1, Opnd2, NumExtracts))
+ if (isFMAddSubOrFMSubAdd(Subtarget, DAG, Opnd0, Opnd1, Opnd2, NumExtracts))
return DAG.getNode(X86ISD::FMADDSUB, DL, VT, Opnd0, Opnd1, Opnd2);
// Do not generate X86ISD::ADDSUB node for 512-bit types even though
@@ -11958,6 +11953,19 @@ static int canLowerByDroppingEvenElements(ArrayRef<int> Mask,
return 0;
}
+static SDValue lowerVectorShuffleWithPERMV(const SDLoc &DL, MVT VT,
+ ArrayRef<int> Mask, SDValue V1,
+ SDValue V2, SelectionDAG &DAG) {
+ MVT MaskEltVT = MVT::getIntegerVT(VT.getScalarSizeInBits());
+ MVT MaskVecVT = MVT::getVectorVT(MaskEltVT, VT.getVectorNumElements());
+
+ SDValue MaskNode = getConstVector(Mask, MaskVecVT, DAG, DL, true);
+ if (V2.isUndef())
+ return DAG.getNode(X86ISD::VPERMV, DL, VT, MaskNode, V1);
+
+ return DAG.getNode(X86ISD::VPERMV3, DL, VT, V1, MaskNode, V2);
+}
+
/// \brief Generic lowering of v16i8 shuffles.
///
/// This is a hybrid strategy to lower v16i8 vectors. It first attempts to
@@ -12148,6 +12156,10 @@ static SDValue lowerV16I8VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask,
if (SDValue Unpack = lowerVectorShuffleAsPermuteAndUnpack(
DL, MVT::v16i8, V1, V2, Mask, DAG))
return Unpack;
+
+ // If we have VBMI we can use one VPERM instead of multiple PSHUFBs.
+ if (Subtarget.hasVBMI() && Subtarget.hasVLX())
+ return lowerVectorShuffleWithPERMV(DL, MVT::v16i8, Mask, V1, V2, DAG);
}
return PSHUFB;
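
In intrinsics terms, the effect of the new VBMI path is that an arbitrary 16-byte shuffle can be emitted as a single variable byte permute rather than the PSHUFB-based sequence. A short illustration; the intrinsic calls are an assumption about the selected instructions, not code from this patch:

    #include <immintrin.h>

    // One-source case: lowers to vpermb (AVX512VBMI + AVX512VL).
    __m128i permuteBytes(__m128i V, __m128i Idx) {
      return _mm_permutexvar_epi8(Idx, V);
    }

    // Two-source case: lowers to the two-source byte permute (vperm{i,t}2b).
    __m128i permuteBytes2(__m128i V1, __m128i Idx, __m128i V2) {
      return _mm_permutex2var_epi8(V1, Idx, V2);
    }
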
@@ -13048,19 +13060,6 @@ static SDValue lowerVectorShuffleWithSHUFPD(const SDLoc &DL, MVT VT,
DAG.getConstant(Immediate, DL, MVT::i8));
}
-static SDValue lowerVectorShuffleWithPERMV(const SDLoc &DL, MVT VT,
- ArrayRef<int> Mask, SDValue V1,
- SDValue V2, SelectionDAG &DAG) {
- MVT MaskEltVT = MVT::getIntegerVT(VT.getScalarSizeInBits());
- MVT MaskVecVT = MVT::getVectorVT(MaskEltVT, VT.getVectorNumElements());
-
- SDValue MaskNode = getConstVector(Mask, MaskVecVT, DAG, DL, true);
- if (V2.isUndef())
- return DAG.getNode(X86ISD::VPERMV, DL, VT, MaskNode, V1);
-
- return DAG.getNode(X86ISD::VPERMV3, DL, VT, V1, MaskNode, V2);
-}
-
/// \brief Handle lowering of 4-lane 64-bit floating point shuffles.
///
/// Also ends up handling lowering of 4-lane 64-bit integer shuffles when AVX2
@@ -13615,6 +13614,10 @@ static SDValue lowerV32I8VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask,
DL, MVT::v32i8, Mask, V1, V2, Zeroable, Subtarget, DAG))
return PSHUFB;
+ // AVX512VBMIVL can lower to VPERMB.
+ if (Subtarget.hasVBMI() && Subtarget.hasVLX())
+ return lowerVectorShuffleWithPERMV(DL, MVT::v32i8, Mask, V1, V2, DAG);
+
// Try to simplify this by merging 128-bit lanes to enable a lane-based
// shuffle.
if (SDValue Result = lowerVectorShuffleByMerging128BitLanes(
@@ -14077,6 +14080,10 @@ static SDValue lowerV32I16VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask,
Zeroable, Subtarget, DAG))
return Blend;
+ if (SDValue PSHUFB = lowerVectorShuffleWithPSHUFB(
+ DL, MVT::v32i16, Mask, V1, V2, Zeroable, Subtarget, DAG))
+ return PSHUFB;
+
return lowerVectorShuffleWithPERMV(DL, MVT::v32i16, Mask, V1, V2, DAG);
}
@@ -14212,7 +14219,9 @@ static SDValue lower1BitVectorShuffle(const SDLoc &DL, ArrayRef<int> Mask,
ExtVT = MVT::v4i32;
break;
case MVT::v8i1:
- ExtVT = MVT::v8i64; // Take 512-bit type, more shuffles on KNL
+ // Take 512-bit type, more shuffles on KNL. If we have VLX use a 256-bit
+ // shuffle.
+ ExtVT = Subtarget.hasVLX() ? MVT::v8i32 : MVT::v8i64;
break;
case MVT::v16i1:
ExtVT = MVT::v16i32;
@@ -14569,11 +14578,10 @@ static SDValue ExtractBitFromMaskVector(SDValue Op, SelectionDAG &DAG,
unsigned NumElts = VecVT.getVectorNumElements();
// Extending v8i1/v16i1 to 512-bit get better performance on KNL
// than extending to 128/256bit.
- unsigned VecSize = (NumElts <= 4 ? 128 : 512);
- MVT ExtVT = MVT::getVectorVT(MVT::getIntegerVT(VecSize / NumElts), NumElts);
- SDValue Ext = DAG.getNode(ISD::SIGN_EXTEND, dl, ExtVT, Vec);
- SDValue Elt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl,
- ExtVT.getVectorElementType(), Ext, Idx);
+ MVT ExtEltVT = (NumElts <= 8) ? MVT::getIntegerVT(128 / NumElts) : MVT::i8;
+ MVT ExtVecVT = MVT::getVectorVT(ExtEltVT, NumElts);
+ SDValue Ext = DAG.getNode(ISD::SIGN_EXTEND, dl, ExtVecVT, Vec);
+ SDValue Elt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, ExtEltVT, Ext, Idx);
return DAG.getNode(ISD::TRUNCATE, dl, EltVT, Elt);
}
@@ -14768,12 +14776,11 @@ static SDValue InsertBitToMaskVector(SDValue Op, SelectionDAG &DAG,
// Non constant index. Extend source and destination,
// insert element and then truncate the result.
unsigned NumElts = VecVT.getVectorNumElements();
- unsigned VecSize = (NumElts <= 4 ? 128 : 512);
- MVT ExtVecVT = MVT::getVectorVT(MVT::getIntegerVT(VecSize/NumElts), NumElts);
- MVT ExtEltVT = ExtVecVT.getVectorElementType();
+ MVT ExtEltVT = (NumElts <= 8) ? MVT::getIntegerVT(128 / NumElts) : MVT::i8;
+ MVT ExtVecVT = MVT::getVectorVT(ExtEltVT, NumElts);
SDValue ExtOp = DAG.getNode(ISD::INSERT_VECTOR_ELT, dl, ExtVecVT,
- DAG.getNode(ISD::ZERO_EXTEND, dl, ExtVecVT, Vec),
- DAG.getNode(ISD::ZERO_EXTEND, dl, ExtEltVT, Elt), Idx);
+ DAG.getNode(ISD::SIGN_EXTEND, dl, ExtVecVT, Vec),
+ DAG.getNode(ISD::SIGN_EXTEND, dl, ExtEltVT, Elt), Idx);
return DAG.getNode(ISD::TRUNCATE, dl, VecVT, ExtOp);
}
@@ -16287,21 +16294,6 @@ static SDValue LowerZERO_EXTEND_Mask(SDValue Op,
return SelectedVal;
}
-static SDValue LowerANY_EXTEND(SDValue Op, const X86Subtarget &Subtarget,
- SelectionDAG &DAG) {
- SDValue In = Op->getOperand(0);
- MVT InVT = In.getSimpleValueType();
-
- if (InVT.getVectorElementType() == MVT::i1)
- return LowerZERO_EXTEND_Mask(Op, Subtarget, DAG);
-
- if (Subtarget.hasFp256())
- if (SDValue Res = LowerAVXExtend(Op, DAG, Subtarget))
- return Res;
-
- return SDValue();
-}
-
static SDValue LowerZERO_EXTEND(SDValue Op, const X86Subtarget &Subtarget,
SelectionDAG &DAG) {
SDValue In = Op.getOperand(0);
@@ -16440,7 +16432,8 @@ static SDValue LowerTruncateVecI1(SDValue Op, SelectionDAG &DAG,
assert((InVT.is256BitVector() || InVT.is128BitVector()) &&
"Unexpected vector type.");
unsigned NumElts = InVT.getVectorNumElements();
- MVT ExtVT = MVT::getVectorVT(MVT::getIntegerVT(512/NumElts), NumElts);
+ MVT EltVT = Subtarget.hasVLX() ? MVT::i32 : MVT::getIntegerVT(512/NumElts);
+ MVT ExtVT = MVT::getVectorVT(EltVT, NumElts);
In = DAG.getNode(ISD::SIGN_EXTEND, DL, ExtVT, In);
InVT = ExtVT;
ShiftInx = InVT.getScalarSizeInBits() - 1;
@@ -18446,6 +18439,21 @@ static SDValue LowerSIGN_EXTEND_Mask(SDValue Op,
return V;
}
+static SDValue LowerANY_EXTEND(SDValue Op, const X86Subtarget &Subtarget,
+ SelectionDAG &DAG) {
+ SDValue In = Op->getOperand(0);
+ MVT InVT = In.getSimpleValueType();
+
+ if (InVT.getVectorElementType() == MVT::i1)
+ return LowerSIGN_EXTEND_Mask(Op, Subtarget, DAG);
+
+ if (Subtarget.hasFp256())
+ if (SDValue Res = LowerAVXExtend(Op, DAG, Subtarget))
+ return Res;
+
+ return SDValue();
+}
+
// Lowering for SIGN_EXTEND_VECTOR_INREG and ZERO_EXTEND_VECTOR_INREG.
// For sign extend this needs to handle all vector sizes and SSE4.1 and
// non-SSE4.1 targets. For zero extend this should only handle inputs of
@@ -21128,7 +21136,7 @@ static SDValue LowerINTRINSIC_W_CHAIN(SDValue Op, const X86Subtarget &Subtarget,
// ADC/ADCX/SBB
case ADX: {
SDVTList CFVTs = DAG.getVTList(Op->getValueType(0), MVT::i32);
- SDVTList VTs = DAG.getVTList(Op.getOperand(3)->getValueType(0), MVT::i32);
+ SDVTList VTs = DAG.getVTList(Op.getOperand(3).getValueType(), MVT::i32);
SDValue GenCF = DAG.getNode(X86ISD::ADD, dl, CFVTs, Op.getOperand(2),
DAG.getConstant(-1, dl, MVT::i8));
SDValue Res = DAG.getNode(IntrData->Opc0, dl, VTs, Op.getOperand(3),
@@ -22231,6 +22239,8 @@ static SDValue LowerMULH(SDValue Op, const X86Subtarget &Subtarget,
DAG.getVectorShuffle(MVT::v16i16, dl, Lo, Hi, HiMask));
}
+ assert(VT == MVT::v16i8 && "Unexpected VT");
+
SDValue ExA = DAG.getNode(ExAVX, dl, MVT::v16i16, A);
SDValue ExB = DAG.getNode(ExAVX, dl, MVT::v16i16, B);
SDValue Mul = DAG.getNode(ISD::MUL, dl, MVT::v16i16, ExA, ExB);
@@ -22989,12 +22999,14 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget &Subtarget,
(Subtarget.hasAVX512() && VT == MVT::v16i16) ||
(Subtarget.hasAVX512() && VT == MVT::v16i8) ||
(Subtarget.hasBWI() && VT == MVT::v32i8)) {
- MVT EvtSVT = (VT == MVT::v32i8 ? MVT::i16 : MVT::i32);
+ assert((!Subtarget.hasBWI() || VT == MVT::v32i8 || VT == MVT::v16i8) &&
+ "Unexpected vector type");
+ MVT EvtSVT = Subtarget.hasBWI() ? MVT::i16 : MVT::i32;
MVT ExtVT = MVT::getVectorVT(EvtSVT, VT.getVectorNumElements());
unsigned ExtOpc =
Op.getOpcode() == ISD::SRA ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
R = DAG.getNode(ExtOpc, dl, ExtVT, R);
- Amt = DAG.getNode(ISD::ANY_EXTEND, dl, ExtVT, Amt);
+ Amt = DAG.getNode(ISD::ZERO_EXTEND, dl, ExtVT, Amt);
return DAG.getNode(ISD::TRUNCATE, dl, VT,
DAG.getNode(Op.getOpcode(), dl, ExtVT, R, Amt));
}
@@ -24101,8 +24113,9 @@ static SDValue LowerFSINCOS(SDValue Op, const X86Subtarget &Subtarget,
// Only optimize x86_64 for now. i386 is a bit messy. For f32,
// the small struct {f32, f32} is returned in (eax, edx). For f64,
// the results are returned via SRet in memory.
- const char *LibcallName = isF64 ? "__sincos_stret" : "__sincosf_stret";
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+ RTLIB::Libcall LC = isF64 ? RTLIB::SINCOS_STRET_F64 : RTLIB::SINCOS_STRET_F32;
+ const char *LibcallName = TLI.getLibcallName(LC);
SDValue Callee =
DAG.getExternalSymbol(LibcallName, TLI.getPointerTy(DAG.getDataLayout()));
@@ -24928,7 +24941,7 @@ void X86TargetLowering::ReplaceNodeResults(SDNode *N,
case ISD::BITCAST: {
assert(Subtarget.hasSSE2() && "Requires at least SSE2!");
EVT DstVT = N->getValueType(0);
- EVT SrcVT = N->getOperand(0)->getValueType(0);
+ EVT SrcVT = N->getOperand(0).getValueType();
if (SrcVT != MVT::f64 ||
(DstVT != MVT::v2i32 && DstVT != MVT::v4i16 && DstVT != MVT::v8i8))
@@ -28407,8 +28420,6 @@ static SDValue combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root,
// TODO - attempt to narrow Mask back to writemask size.
bool IsEVEXShuffle =
RootSizeInBits == 512 || (Subtarget.hasVLX() && RootSizeInBits >= 128);
- if (IsEVEXShuffle && (RootVT.getScalarSizeInBits() != BaseMaskEltSizeInBits))
- return SDValue();
// TODO - handle 128/256-bit lane shuffles of 512-bit vectors.
@@ -28491,11 +28502,10 @@ static SDValue combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root,
if (matchUnaryVectorShuffle(MaskVT, Mask, AllowFloatDomain, AllowIntDomain,
V1, DL, DAG, Subtarget, Shuffle, ShuffleSrcVT,
- ShuffleVT)) {
+ ShuffleVT) &&
+ (!IsEVEXShuffle || (NumRootElts == ShuffleVT.getVectorNumElements()))) {
if (Depth == 1 && Root.getOpcode() == Shuffle)
return SDValue(); // Nothing to do!
- if (IsEVEXShuffle && (NumRootElts != ShuffleVT.getVectorNumElements()))
- return SDValue(); // AVX512 Writemask clash.
Res = DAG.getBitcast(ShuffleSrcVT, V1);
DCI.AddToWorklist(Res.getNode());
Res = DAG.getNode(Shuffle, DL, ShuffleVT, Res);
@@ -28505,11 +28515,10 @@ static SDValue combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root,
if (matchUnaryPermuteVectorShuffle(MaskVT, Mask, Zeroable, AllowFloatDomain,
AllowIntDomain, Subtarget, Shuffle,
- ShuffleVT, PermuteImm)) {
+ ShuffleVT, PermuteImm) &&
+ (!IsEVEXShuffle || (NumRootElts == ShuffleVT.getVectorNumElements()))) {
if (Depth == 1 && Root.getOpcode() == Shuffle)
return SDValue(); // Nothing to do!
- if (IsEVEXShuffle && (NumRootElts != ShuffleVT.getVectorNumElements()))
- return SDValue(); // AVX512 Writemask clash.
Res = DAG.getBitcast(ShuffleVT, V1);
DCI.AddToWorklist(Res.getNode());
Res = DAG.getNode(Shuffle, DL, ShuffleVT, Res,
@@ -28520,12 +28529,11 @@ static SDValue combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root,
}
if (matchBinaryVectorShuffle(MaskVT, Mask, AllowFloatDomain, AllowIntDomain,
- V1, V2, DL, DAG, Subtarget, Shuffle, ShuffleSrcVT,
- ShuffleVT, UnaryShuffle)) {
+ V1, V2, DL, DAG, Subtarget, Shuffle,
+ ShuffleSrcVT, ShuffleVT, UnaryShuffle) &&
+ (!IsEVEXShuffle || (NumRootElts == ShuffleVT.getVectorNumElements()))) {
if (Depth == 1 && Root.getOpcode() == Shuffle)
return SDValue(); // Nothing to do!
- if (IsEVEXShuffle && (NumRootElts != ShuffleVT.getVectorNumElements()))
- return SDValue(); // AVX512 Writemask clash.
V1 = DAG.getBitcast(ShuffleSrcVT, V1);
DCI.AddToWorklist(V1.getNode());
V2 = DAG.getBitcast(ShuffleSrcVT, V2);
@@ -28538,11 +28546,10 @@ static SDValue combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root,
if (matchBinaryPermuteVectorShuffle(MaskVT, Mask, Zeroable, AllowFloatDomain,
AllowIntDomain, V1, V2, DL, DAG,
Subtarget, Shuffle, ShuffleVT,
- PermuteImm)) {
+ PermuteImm) &&
+ (!IsEVEXShuffle || (NumRootElts == ShuffleVT.getVectorNumElements()))) {
if (Depth == 1 && Root.getOpcode() == Shuffle)
return SDValue(); // Nothing to do!
- if (IsEVEXShuffle && (NumRootElts != ShuffleVT.getVectorNumElements()))
- return SDValue(); // AVX512 Writemask clash.
V1 = DAG.getBitcast(ShuffleVT, V1);
DCI.AddToWorklist(V1.getNode());
V2 = DAG.getBitcast(ShuffleVT, V2);
@@ -28594,8 +28601,8 @@ static SDValue combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root,
return SDValue();
// Depth threshold above which we can efficiently use variable mask shuffles.
- // TODO This should probably be target specific.
- bool AllowVariableMask = (Depth >= 3) || HasVariableMask;
+ int VariableShuffleDepth = Subtarget.hasFastVariableShuffle() ? 2 : 3;
+ bool AllowVariableMask = (Depth >= VariableShuffleDepth) || HasVariableMask;
bool MaskContainsZeros =
any_of(Mask, [](int M) { return M == SM_SentinelZero; });
@@ -29698,17 +29705,18 @@ static SDValue combineTargetShuffle(SDValue N, SelectionDAG &DAG,
return SDValue();
}
-/// Returns true iff the shuffle node \p N can be replaced with ADDSUB
-/// operation. If true is returned then the operands of ADDSUB operation
+/// Returns true iff the shuffle node \p N can be replaced with ADDSUB(SUBADD)
+/// operation. If true is returned then the operands of ADDSUB(SUBADD) operation
/// are written to the parameters \p Opnd0 and \p Opnd1.
///
-/// We combine shuffle to ADDSUB directly on the abstract vector shuffle nodes
+/// We combine shuffle to ADDSUB(SUBADD) directly on the abstract vector shuffle nodes
/// so it is easier to generically match. We also insert dummy vector shuffle
/// nodes for the operands which explicitly discard the lanes which are unused
/// by this operation to try to flow through the rest of the combiner
/// the fact that they're unused.
-static bool isAddSub(SDNode *N, const X86Subtarget &Subtarget,
- SDValue &Opnd0, SDValue &Opnd1) {
+static bool isAddSubOrSubAdd(SDNode *N, const X86Subtarget &Subtarget,
+ SDValue &Opnd0, SDValue &Opnd1,
+ bool matchSubAdd = false) {
EVT VT = N->getValueType(0);
if ((!Subtarget.hasSSE3() || (VT != MVT::v4f32 && VT != MVT::v2f64)) &&
@@ -29728,12 +29736,15 @@ static bool isAddSub(SDNode *N, const X86Subtarget &Subtarget,
SDValue V1 = N->getOperand(0);
SDValue V2 = N->getOperand(1);
- // We require the first shuffle operand to be the FSUB node, and the second to
- // be the FADD node.
- if (V1.getOpcode() == ISD::FADD && V2.getOpcode() == ISD::FSUB) {
+ unsigned ExpectedOpcode = matchSubAdd ? ISD::FADD : ISD::FSUB;
+ unsigned NextExpectedOpcode = matchSubAdd ? ISD::FSUB : ISD::FADD;
+
+ // We require the first shuffle operand to be the ExpectedOpcode node,
+ // and the second to be the NextExpectedOpcode node.
+ if (V1.getOpcode() == NextExpectedOpcode && V2.getOpcode() == ExpectedOpcode) {
ShuffleVectorSDNode::commuteMask(Mask);
std::swap(V1, V2);
- } else if (V1.getOpcode() != ISD::FSUB || V2.getOpcode() != ISD::FADD)
+ } else if (V1.getOpcode() != ExpectedOpcode || V2.getOpcode() != NextExpectedOpcode)
return false;
// If there are other uses of these operations we can't fold them.
@@ -29767,7 +29778,7 @@ static SDValue combineShuffleToAddSubOrFMAddSub(SDNode *N,
const X86Subtarget &Subtarget,
SelectionDAG &DAG) {
SDValue Opnd0, Opnd1;
- if (!isAddSub(N, Subtarget, Opnd0, Opnd1))
+ if (!isAddSubOrSubAdd(N, Subtarget, Opnd0, Opnd1))
return SDValue();
EVT VT = N->getValueType(0);
@@ -29775,7 +29786,7 @@ static SDValue combineShuffleToAddSubOrFMAddSub(SDNode *N,
// Try to generate X86ISD::FMADDSUB node here.
SDValue Opnd2;
- if (isFMAddSub(Subtarget, DAG, Opnd0, Opnd1, Opnd2, 2))
+ if (isFMAddSubOrFMSubAdd(Subtarget, DAG, Opnd0, Opnd1, Opnd2, 2))
return DAG.getNode(X86ISD::FMADDSUB, DL, VT, Opnd0, Opnd1, Opnd2);
// Do not generate X86ISD::ADDSUB node for 512-bit types even though
@@ -29787,6 +29798,26 @@ static SDValue combineShuffleToAddSubOrFMAddSub(SDNode *N,
return DAG.getNode(X86ISD::ADDSUB, DL, VT, Opnd0, Opnd1);
}
+/// \brief Try to combine a shuffle into a target-specific
+/// mul-sub-add node.
+static SDValue combineShuffleToFMSubAdd(SDNode *N,
+ const X86Subtarget &Subtarget,
+ SelectionDAG &DAG) {
+ SDValue Opnd0, Opnd1;
+ if (!isAddSubOrSubAdd(N, Subtarget, Opnd0, Opnd1, true))
+ return SDValue();
+
+ EVT VT = N->getValueType(0);
+ SDLoc DL(N);
+
+ // Try to generate X86ISD::FMSUBADD node here.
+ SDValue Opnd2;
+ if (isFMAddSubOrFMSubAdd(Subtarget, DAG, Opnd0, Opnd1, Opnd2, 2))
+ return DAG.getNode(X86ISD::FMSUBADD, DL, VT, Opnd0, Opnd1, Opnd2);
+
+ return SDValue();
+}
+
// We are looking for a shuffle where both sources are concatenated with undef
// and have a width that is half of the output's width. AVX2 has VPERMD/Q, so
// if we can express this as a single-source shuffle, that's preferable.
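
For reference, the two lane patterns that isAddSubOrSubAdd now distinguishes differ only in which lanes add and which subtract. A scalar model for 4 x float (illustrative sketch, not code from the patch):

    #include <array>
    using Vec4 = std::array<float, 4>;

    // X86ISD::ADDSUB (addsubps): even lanes subtract, odd lanes add.
    Vec4 addsub(const Vec4 &A, const Vec4 &B) {
      return {A[0] - B[0], A[1] + B[1], A[2] - B[2], A[3] + B[3]};
    }

    // The SUBADD-shaped pattern matched with matchSubAdd = true.
    Vec4 subadd(const Vec4 &A, const Vec4 &B) {
      return {A[0] + B[0], A[1] - B[1], A[2] + B[2], A[3] - B[3]};
    }

    // With A = X * Y elementwise, these become FMADDSUB(X, Y, B) and
    // FMSUBADD(X, Y, B), which is what combineShuffleToFMSubAdd now emits.
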
@@ -29873,11 +29904,14 @@ static SDValue combineShuffle(SDNode *N, SelectionDAG &DAG,
EVT VT = N->getValueType(0);
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
// If we have legalized the vector types, look for blends of FADD and FSUB
- // nodes that we can fuse into an ADDSUB node.
+ // nodes that we can fuse into an ADDSUB, FMADDSUB, or FMSUBADD node.
if (TLI.isTypeLegal(VT)) {
if (SDValue AddSub = combineShuffleToAddSubOrFMAddSub(N, Subtarget, DAG))
return AddSub;
+ if (SDValue FMSubAdd = combineShuffleToFMSubAdd(N, Subtarget, DAG))
+ return FMSubAdd;
+
if (SDValue HAddSub = foldShuffleOfHorizOp(N))
return HAddSub;
}
@@ -30181,7 +30215,7 @@ static SDValue combineBitcastvxi1(SelectionDAG &DAG, SDValue BitCast,
// For cases such as (i4 bitcast (v4i1 setcc v4i64 v1, v2))
// sign-extend to a 256-bit operation to avoid truncation.
if (N0->getOpcode() == ISD::SETCC && Subtarget.hasAVX() &&
- N0->getOperand(0)->getValueType(0).is256BitVector()) {
+ N0->getOperand(0).getValueType().is256BitVector()) {
SExtVT = MVT::v4i64;
FPCastVT = MVT::v4f64;
}
@@ -30194,8 +30228,8 @@ static SDValue combineBitcastvxi1(SelectionDAG &DAG, SDValue BitCast,
// 256-bit because the shuffle is cheaper than sign extending the result of
// the compare.
if (N0->getOpcode() == ISD::SETCC && Subtarget.hasAVX() &&
- (N0->getOperand(0)->getValueType(0).is256BitVector() ||
- N0->getOperand(0)->getValueType(0).is512BitVector())) {
+ (N0->getOperand(0).getValueType().is256BitVector() ||
+ N0->getOperand(0).getValueType().is512BitVector())) {
SExtVT = MVT::v8i32;
FPCastVT = MVT::v8f32;
}
@@ -30484,7 +30518,8 @@ static SDValue createPSADBW(SelectionDAG &DAG, const SDValue &Zext0,
return DAG.getNode(X86ISD::PSADBW, DL, SadVT, SadOp0, SadOp1);
}
-// Attempt to replace an min/max v8i16 horizontal reduction with PHMINPOSUW.
+// Attempt to replace an min/max v8i16/v16i8 horizontal reduction with
+// PHMINPOSUW.
static SDValue combineHorizontalMinMaxResult(SDNode *Extract, SelectionDAG &DAG,
const X86Subtarget &Subtarget) {
// Bail without SSE41.
@@ -30492,7 +30527,7 @@ static SDValue combineHorizontalMinMaxResult(SDNode *Extract, SelectionDAG &DAG,
return SDValue();
EVT ExtractVT = Extract->getValueType(0);
- if (ExtractVT != MVT::i16)
+ if (ExtractVT != MVT::i16 && ExtractVT != MVT::i8)
return SDValue();
// Check for SMAX/SMIN/UMAX/UMIN horizontal reduction patterns.
@@ -30504,7 +30539,7 @@ static SDValue combineHorizontalMinMaxResult(SDNode *Extract, SelectionDAG &DAG,
EVT SrcVT = Src.getValueType();
EVT SrcSVT = SrcVT.getScalarType();
- if (SrcSVT != MVT::i16 || (SrcVT.getSizeInBits() % 128) != 0)
+ if (SrcSVT != ExtractVT || (SrcVT.getSizeInBits() % 128) != 0)
return SDValue();
SDLoc DL(Extract);
@@ -30520,22 +30555,39 @@ static SDValue combineHorizontalMinMaxResult(SDNode *Extract, SelectionDAG &DAG,
SDValue Hi = extractSubVector(MinPos, NumSubElts, DAG, DL, SubSizeInBits);
MinPos = DAG.getNode(BinOp, DL, SrcVT, Lo, Hi);
}
- assert(SrcVT == MVT::v8i16 && "Unexpected value type");
+ assert(((SrcVT == MVT::v8i16 && ExtractVT == MVT::i16) ||
+ (SrcVT == MVT::v16i8 && ExtractVT == MVT::i8)) &&
+ "Unexpected value type");
// PHMINPOSUW applies to UMIN(v8i16), for SMIN/SMAX/UMAX we must apply a mask
// to flip the value accordingly.
SDValue Mask;
+ unsigned MaskEltsBits = ExtractVT.getSizeInBits();
if (BinOp == ISD::SMAX)
- Mask = DAG.getConstant(APInt::getSignedMaxValue(16), DL, SrcVT);
+ Mask = DAG.getConstant(APInt::getSignedMaxValue(MaskEltsBits), DL, SrcVT);
else if (BinOp == ISD::SMIN)
- Mask = DAG.getConstant(APInt::getSignedMinValue(16), DL, SrcVT);
+ Mask = DAG.getConstant(APInt::getSignedMinValue(MaskEltsBits), DL, SrcVT);
else if (BinOp == ISD::UMAX)
- Mask = DAG.getConstant(APInt::getAllOnesValue(16), DL, SrcVT);
+ Mask = DAG.getConstant(APInt::getAllOnesValue(MaskEltsBits), DL, SrcVT);
if (Mask)
MinPos = DAG.getNode(ISD::XOR, DL, SrcVT, Mask, MinPos);
- MinPos = DAG.getNode(X86ISD::PHMINPOS, DL, SrcVT, MinPos);
+ // For v16i8 cases we need to perform UMIN on pairs of byte elements,
+ // shuffling each upper element down and insert zeros. This means that the
+ // v16i8 UMIN will leave the upper element as zero, performing zero-extension
+ // ready for the PHMINPOS.
+ if (ExtractVT == MVT::i8) {
+ SDValue Upper = DAG.getVectorShuffle(
+ SrcVT, DL, MinPos, getZeroVector(MVT::v16i8, Subtarget, DAG, DL),
+ {1, 16, 3, 16, 5, 16, 7, 16, 9, 16, 11, 16, 13, 16, 15, 16});
+ MinPos = DAG.getNode(ISD::UMIN, DL, SrcVT, MinPos, Upper);
+ }
+
+ // Perform the PHMINPOS on a v8i16 vector,
+ MinPos = DAG.getBitcast(MVT::v8i16, MinPos);
+ MinPos = DAG.getNode(X86ISD::PHMINPOS, DL, MVT::v8i16, MinPos);
+ MinPos = DAG.getBitcast(SrcVT, MinPos);
if (Mask)
MinPos = DAG.getNode(ISD::XOR, DL, SrcVT, Mask, MinPos);
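
The byte path relies on a simple identity: the unsigned minimum of 16 bytes equals the minimum over 8 words when each word holds the smaller byte of a pair, zero-extended. A standalone scalar model of that step (illustrative only):

    #include <algorithm>
    #include <cstdint>

    // The shuffle {1,16,3,16,...} moves each odd byte down and inserts a zero
    // above it, so the following UMIN leaves min(pair) zero-extended to 16
    // bits, ready for the word-wise PHMINPOSUW.
    uint8_t uminV16i8(const uint8_t (&V)[16]) {
      uint16_t Words[8];
      for (int I = 0; I != 8; ++I)
        Words[I] = std::min(V[2 * I], V[2 * I + 1]);
      uint16_t M = Words[0];
      for (int I = 1; I != 8; ++I)
        M = std::min(M, Words[I]); // what phminposuw computes horizontally
      return static_cast<uint8_t>(M);
    }
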
@@ -30851,7 +30903,7 @@ static SDValue combineExtractVectorElt(SDNode *N, SelectionDAG &DAG,
if (SDValue Cmp = combineHorizontalPredicateResult(N, DAG, Subtarget))
return Cmp;
- // Attempt to replace min/max v8i16 reductions with PHMINPOSUW.
+ // Attempt to replace min/max v8i16/v16i8 reductions with PHMINPOSUW.
if (SDValue MinMax = combineHorizontalMinMaxResult(N, DAG, Subtarget))
return MinMax;
@@ -32555,7 +32607,7 @@ static SDValue combineShiftRightArithmetic(SDNode *N, SelectionDAG &DAG) {
// 1. MOVs can write to a register that differs from source
// 2. MOVs accept memory operands
- if (!VT.isInteger() || VT.isVector() || N1.getOpcode() != ISD::Constant ||
+ if (VT.isVector() || N1.getOpcode() != ISD::Constant ||
N0.getOpcode() != ISD::SHL || !N0.hasOneUse() ||
N0.getOperand(1).getOpcode() != ISD::Constant)
return SDValue();
@@ -32569,11 +32621,11 @@ static SDValue combineShiftRightArithmetic(SDNode *N, SelectionDAG &DAG) {
if (SarConst.isNegative())
return SDValue();
- for (MVT SVT : MVT::integer_valuetypes()) {
+ for (MVT SVT : { MVT::i8, MVT::i16, MVT::i32 }) {
unsigned ShiftSize = SVT.getSizeInBits();
// skipping types without corresponding sext/zext and
// ShlConst that is not one of [56,48,32,24,16]
- if (ShiftSize < 8 || ShiftSize > 64 || ShlConst != Size - ShiftSize)
+ if (ShiftSize >= Size || ShlConst != Size - ShiftSize)
continue;
SDLoc DL(N);
SDValue NN =
@@ -32626,37 +32678,6 @@ static SDValue combineShiftRightLogical(SDNode *N, SelectionDAG &DAG) {
return SDValue();
}
-/// \brief Returns a vector of 0s if the node in input is a vector logical
-/// shift by a constant amount which is known to be bigger than or equal
-/// to the vector element size in bits.
-static SDValue performShiftToAllZeros(SDNode *N, SelectionDAG &DAG,
- const X86Subtarget &Subtarget) {
- EVT VT = N->getValueType(0);
-
- if (VT != MVT::v2i64 && VT != MVT::v4i32 && VT != MVT::v8i16 &&
- (!Subtarget.hasInt256() ||
- (VT != MVT::v4i64 && VT != MVT::v8i32 && VT != MVT::v16i16)))
- return SDValue();
-
- SDValue Amt = N->getOperand(1);
- SDLoc DL(N);
- if (auto *AmtBV = dyn_cast<BuildVectorSDNode>(Amt))
- if (auto *AmtSplat = AmtBV->getConstantSplatNode()) {
- const APInt &ShiftAmt = AmtSplat->getAPIntValue();
- unsigned MaxAmount =
- VT.getSimpleVT().getScalarSizeInBits();
-
- // SSE2/AVX2 logical shifts always return a vector of 0s
- // if the shift amount is bigger than or equal to
- // the element size. The constant shift amount will be
- // encoded as a 8-bit immediate.
- if (ShiftAmt.trunc(8).uge(MaxAmount))
- return getZeroVector(VT.getSimpleVT(), Subtarget, DAG, DL);
- }
-
- return SDValue();
-}
-
static SDValue combineShift(SDNode* N, SelectionDAG &DAG,
TargetLowering::DAGCombinerInfo &DCI,
const X86Subtarget &Subtarget) {
@@ -32672,11 +32693,6 @@ static SDValue combineShift(SDNode* N, SelectionDAG &DAG,
if (SDValue V = combineShiftRightLogical(N, DAG))
return V;
- // Try to fold this logical shift into a zero vector.
- if (N->getOpcode() != ISD::SRA)
- if (SDValue V = performShiftToAllZeros(N, DAG, Subtarget))
- return V;
-
return SDValue();
}
@@ -32996,21 +33012,20 @@ static SDValue combineANDXORWithAllOnesIntoANDNP(SDNode *N, SelectionDAG &DAG) {
// register. In most cases we actually compare or select YMM-sized registers
// and mixing the two types creates horrible code. This method optimizes
// some of the transition sequences.
+// Even with AVX-512 this is still useful for removing casts around logical
+// operations on vXi1 mask types.
static SDValue WidenMaskArithmetic(SDNode *N, SelectionDAG &DAG,
TargetLowering::DAGCombinerInfo &DCI,
const X86Subtarget &Subtarget) {
EVT VT = N->getValueType(0);
- if (!VT.is256BitVector())
- return SDValue();
+ assert(VT.isVector() && "Expected vector type");
assert((N->getOpcode() == ISD::ANY_EXTEND ||
N->getOpcode() == ISD::ZERO_EXTEND ||
N->getOpcode() == ISD::SIGN_EXTEND) && "Invalid Node");
SDValue Narrow = N->getOperand(0);
- EVT NarrowVT = Narrow->getValueType(0);
- if (!NarrowVT.is128BitVector())
- return SDValue();
+ EVT NarrowVT = Narrow.getValueType();
if (Narrow->getOpcode() != ISD::XOR &&
Narrow->getOpcode() != ISD::AND &&
@@ -33026,12 +33041,12 @@ static SDValue WidenMaskArithmetic(SDNode *N, SelectionDAG &DAG,
return SDValue();
// The type of the truncated inputs.
- EVT WideVT = N0->getOperand(0)->getValueType(0);
- if (WideVT != VT)
+ if (N0->getOperand(0).getValueType() != VT)
return SDValue();
// The right side has to be a 'trunc' or a constant vector.
- bool RHSTrunc = N1.getOpcode() == ISD::TRUNCATE;
+ bool RHSTrunc = N1.getOpcode() == ISD::TRUNCATE &&
+ N1.getOperand(0).getValueType() == VT;
ConstantSDNode *RHSConstSplat = nullptr;
if (auto *RHSBV = dyn_cast<BuildVectorSDNode>(N1))
RHSConstSplat = RHSBV->getConstantSplatNode();
@@ -33040,37 +33055,31 @@ static SDValue WidenMaskArithmetic(SDNode *N, SelectionDAG &DAG,
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
- if (!TLI.isOperationLegalOrPromote(Narrow->getOpcode(), WideVT))
+ if (!TLI.isOperationLegalOrPromote(Narrow->getOpcode(), VT))
return SDValue();
// Set N0 and N1 to hold the inputs to the new wide operation.
N0 = N0->getOperand(0);
if (RHSConstSplat) {
- N1 = DAG.getNode(ISD::ZERO_EXTEND, DL, WideVT.getVectorElementType(),
+ N1 = DAG.getNode(ISD::ZERO_EXTEND, DL, VT.getVectorElementType(),
SDValue(RHSConstSplat, 0));
- N1 = DAG.getSplatBuildVector(WideVT, DL, N1);
+ N1 = DAG.getSplatBuildVector(VT, DL, N1);
} else if (RHSTrunc) {
N1 = N1->getOperand(0);
}
// Generate the wide operation.
- SDValue Op = DAG.getNode(Narrow->getOpcode(), DL, WideVT, N0, N1);
+ SDValue Op = DAG.getNode(Narrow->getOpcode(), DL, VT, N0, N1);
unsigned Opcode = N->getOpcode();
switch (Opcode) {
+ default: llvm_unreachable("Unexpected opcode");
case ISD::ANY_EXTEND:
return Op;
- case ISD::ZERO_EXTEND: {
- unsigned InBits = NarrowVT.getScalarSizeInBits();
- APInt Mask = APInt::getAllOnesValue(InBits);
- Mask = Mask.zext(VT.getScalarSizeInBits());
- return DAG.getNode(ISD::AND, DL, VT,
- Op, DAG.getConstant(Mask, DL, VT));
- }
+ case ISD::ZERO_EXTEND:
+ return DAG.getZeroExtendInReg(Op, DL, NarrowVT.getScalarType());
case ISD::SIGN_EXTEND:
return DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, VT,
Op, DAG.getValueType(NarrowVT));
- default:
- llvm_unreachable("Unexpected opcode");
}
}
@@ -33882,16 +33891,6 @@ static SDValue detectAVGPattern(SDValue In, EVT VT, SelectionDAG &DAG,
if (!Subtarget.hasSSE2())
return SDValue();
- if (Subtarget.hasBWI()) {
- if (VT.getSizeInBits() > 512)
- return SDValue();
- } else if (Subtarget.hasAVX2()) {
- if (VT.getSizeInBits() > 256)
- return SDValue();
- } else {
- if (VT.getSizeInBits() > 128)
- return SDValue();
- }
// Detect the following pattern:
//
@@ -33903,7 +33902,6 @@ static SDValue detectAVGPattern(SDValue In, EVT VT, SelectionDAG &DAG,
// %6 = trunc <N x i32> %5 to <N x i8>
//
// In AVX512, the last instruction can also be a trunc store.
-
if (In.getOpcode() != ISD::SRL)
return SDValue();
@@ -33924,6 +33922,35 @@ static SDValue detectAVGPattern(SDValue In, EVT VT, SelectionDAG &DAG,
return true;
};
+ // Split vectors to legal target size and apply AVG.
+ auto LowerToAVG = [&](SDValue Op0, SDValue Op1) {
+ unsigned NumSubs = 1;
+ if (Subtarget.hasBWI()) {
+ if (VT.getSizeInBits() > 512)
+ NumSubs = VT.getSizeInBits() / 512;
+ } else if (Subtarget.hasAVX2()) {
+ if (VT.getSizeInBits() > 256)
+ NumSubs = VT.getSizeInBits() / 256;
+ } else {
+ if (VT.getSizeInBits() > 128)
+ NumSubs = VT.getSizeInBits() / 128;
+ }
+
+ if (NumSubs == 1)
+ return DAG.getNode(X86ISD::AVG, DL, VT, Op0, Op1);
+
+ SmallVector<SDValue, 4> Subs;
+ EVT SubVT = EVT::getVectorVT(*DAG.getContext(), VT.getScalarType(),
+ VT.getVectorNumElements() / NumSubs);
+ for (unsigned i = 0; i != NumSubs; ++i) {
+ unsigned Idx = i * SubVT.getVectorNumElements();
+ SDValue LHS = extractSubVector(Op0, Idx, DAG, DL, SubVT.getSizeInBits());
+ SDValue RHS = extractSubVector(Op1, Idx, DAG, DL, SubVT.getSizeInBits());
+ Subs.push_back(DAG.getNode(X86ISD::AVG, DL, SubVT, LHS, RHS));
+ }
+ return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, Subs);
+ };
+
// Check if each element of the vector is left-shifted by one.
auto LHS = In.getOperand(0);
auto RHS = In.getOperand(1);
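
The new LowerToAVG helper splits an oversized vector into legal-width pieces instead of bailing out; the per-lane result is unchanged because X86ISD::AVG is purely lane-wise. A scalar model of the lane operation together with the splitting arithmetic (illustrative sketch, not taken from the patch):

    #include <cstdint>
    #include <vector>

    // Lane semantics of X86ISD::AVG / pavgb: (a + b + 1) >> 1. Splitting a
    // 512-bit v64i8 into two v32i8 halves on an AVX2-only target
    // (NumSubs = 512 / 256 = 2) and concatenating the results therefore
    // matches a single wide operation.
    std::vector<uint8_t> avgU8(const std::vector<uint8_t> &A,
                               const std::vector<uint8_t> &B) {
      std::vector<uint8_t> R(A.size());
      for (size_t I = 0; I != A.size(); ++I)
        R[I] = static_cast<uint8_t>((unsigned(A[I]) + unsigned(B[I]) + 1) >> 1);
      return R;
    }
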
@@ -33947,8 +33974,7 @@ static SDValue detectAVGPattern(SDValue In, EVT VT, SelectionDAG &DAG,
SDValue VecOnes = DAG.getConstant(1, DL, InVT);
Operands[1] = DAG.getNode(ISD::SUB, DL, InVT, Operands[1], VecOnes);
Operands[1] = DAG.getNode(ISD::TRUNCATE, DL, VT, Operands[1]);
- return DAG.getNode(X86ISD::AVG, DL, VT, Operands[0].getOperand(0),
- Operands[1]);
+ return LowerToAVG(Operands[0].getOperand(0), Operands[1]);
}
if (Operands[0].getOpcode() == ISD::ADD)
@@ -33972,8 +33998,7 @@ static SDValue detectAVGPattern(SDValue In, EVT VT, SelectionDAG &DAG,
return SDValue();
// The pattern is detected, emit X86ISD::AVG instruction.
- return DAG.getNode(X86ISD::AVG, DL, VT, Operands[0].getOperand(0),
- Operands[1].getOperand(0));
+ return LowerToAVG(Operands[0].getOperand(0), Operands[1].getOperand(0));
}
return SDValue();
@@ -35872,14 +35897,8 @@ static SDValue combineSext(SDNode *N, SelectionDAG &DAG,
if (SDValue NewCMov = combineToExtendCMOV(N, DAG))
return NewCMov;
- if (!DCI.isBeforeLegalizeOps()) {
- if (InVT == MVT::i1) {
- SDValue Zero = DAG.getConstant(0, DL, VT);
- SDValue AllOnes = DAG.getAllOnesConstant(DL, VT);
- return DAG.getSelect(DL, VT, N0, AllOnes, Zero);
- }
+ if (!DCI.isBeforeLegalizeOps())
return SDValue();
- }
if (InVT == MVT::i1 && N0.getOpcode() == ISD::XOR &&
isAllOnesConstant(N0.getOperand(1)) && N0.hasOneUse()) {
@@ -35897,7 +35916,7 @@ static SDValue combineSext(SDNode *N, SelectionDAG &DAG,
if (SDValue V = combineToExtendBoolVectorInReg(N, DAG, DCI, Subtarget))
return V;
- if (Subtarget.hasAVX() && VT.is256BitVector())
+ if (VT.isVector())
if (SDValue R = WidenMaskArithmetic(N, DAG, DCI, Subtarget))
return R;
@@ -36089,7 +36108,7 @@ static SDValue combineZext(SDNode *N, SelectionDAG &DAG,
if (SDValue V = combineToExtendBoolVectorInReg(N, DAG, DCI, Subtarget))
return V;
- if (VT.is256BitVector())
+ if (VT.isVector())
if (SDValue R = WidenMaskArithmetic(N, DAG, DCI, Subtarget))
return R;
@@ -36244,39 +36263,54 @@ static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
const X86Subtarget &Subtarget) {
SDLoc DL(N);
- // Pre-shrink oversized index elements to avoid triggering scalarization.
- if (DCI.isBeforeLegalize()) {
+ if (DCI.isBeforeLegalizeOps()) {
SDValue Index = N->getOperand(4);
- if (Index.getScalarValueSizeInBits() > 64) {
- EVT IndexVT = EVT::getVectorVT(*DAG.getContext(), MVT::i64,
+ // Remove any sign extends from 32 or smaller to larger than 32.
+ // Only do this before LegalizeOps in case we need the sign extend for
+ // legalization.
+ if (Index.getOpcode() == ISD::SIGN_EXTEND) {
+ if (Index.getScalarValueSizeInBits() > 32 &&
+ Index.getOperand(0).getScalarValueSizeInBits() <= 32) {
+ SmallVector<SDValue, 5> NewOps(N->op_begin(), N->op_end());
+ NewOps[4] = Index.getOperand(0);
+ DAG.UpdateNodeOperands(N, NewOps);
+ // The original sign extend has less users, add back to worklist in case
+ // it needs to be removed
+ DCI.AddToWorklist(Index.getNode());
+ DCI.AddToWorklist(N);
+ return SDValue(N, 0);
+ }
+ }
+
+ // Make sure the index is either i32 or i64
+ unsigned ScalarSize = Index.getScalarValueSizeInBits();
+ if (ScalarSize != 32 && ScalarSize != 64) {
+ MVT EltVT = ScalarSize > 32 ? MVT::i64 : MVT::i32;
+ EVT IndexVT = EVT::getVectorVT(*DAG.getContext(), EltVT,
Index.getValueType().getVectorNumElements());
- SDValue Trunc = DAG.getNode(ISD::TRUNCATE, DL, IndexVT, Index);
+ Index = DAG.getSExtOrTrunc(Index, DL, IndexVT);
SmallVector<SDValue, 5> NewOps(N->op_begin(), N->op_end());
- NewOps[4] = Trunc;
+ NewOps[4] = Index;
DAG.UpdateNodeOperands(N, NewOps);
DCI.AddToWorklist(N);
return SDValue(N, 0);
}
- }
- // Try to remove sign extends from i32 to i64 on the index.
- // Only do this before legalize in case we are relying on it for
- // legalization.
- // TODO: We should maybe remove any sign extend once we learn how to sign
- // extend narrow index during lowering.
- if (DCI.isBeforeLegalizeOps()) {
- SDValue Index = N->getOperand(4);
- if (Index.getScalarValueSizeInBits() == 64 &&
- Index.getOpcode() == ISD::SIGN_EXTEND &&
+ // Try to remove zero extends from 32->64 if we know the sign bit of
+ // the input is zero.
+ if (Index.getOpcode() == ISD::ZERO_EXTEND &&
+ Index.getScalarValueSizeInBits() == 64 &&
Index.getOperand(0).getScalarValueSizeInBits() == 32) {
- SmallVector<SDValue, 5> NewOps(N->op_begin(), N->op_end());
- NewOps[4] = Index.getOperand(0);
- DAG.UpdateNodeOperands(N, NewOps);
- // The original sign extend has less users, add back to worklist in case
- // it needs to be removed.
- DCI.AddToWorklist(Index.getNode());
- DCI.AddToWorklist(N);
- return SDValue(N, 0);
+ if (DAG.SignBitIsZero(Index.getOperand(0))) {
+ SmallVector<SDValue, 5> NewOps(N->op_begin(), N->op_end());
+ NewOps[4] = Index.getOperand(0);
+ DAG.UpdateNodeOperands(N, NewOps);
+ // The original zero extend has less users, add back to worklist in case
+ // it needs to be removed
+ DCI.AddToWorklist(Index.getNode());
+ DCI.AddToWorklist(N);
+ return SDValue(N, 0);
+ }
}
}
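
The SignBitIsZero check is what makes dropping the zero_extend safe: a 32-bit gather/scatter index lane is treated as signed by the addressing mode, so the narrow index addresses the same element as its zero-extended 64-bit form only when the top bit is clear. A small standalone illustration of that condition (the signed-index behaviour is an assumption stated here for context, not code from the patch):

    #include <cassert>
    #include <cstdint>

    int main() {
      // Sign bit set: zero-extension and sign-extension disagree, so the
      // zero_extend cannot be dropped.
      uint32_t Neg = 0x80000000u;
      assert(int64_t(uint64_t(Neg)) != int64_t(int32_t(Neg)));
      // Sign bit clear: both extensions agree and the narrow index is safe.
      uint32_t Pos = 42;
      assert(int64_t(uint64_t(Pos)) == int64_t(int32_t(Pos)));
      return 0;
    }
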
@@ -36288,6 +36322,7 @@ static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
SmallVector<SDValue, 5> NewOps(N->op_begin(), N->op_end());
NewOps[2] = Mask.getOperand(0);
DAG.UpdateNodeOperands(N, NewOps);
+ return SDValue(N, 0);
}
// With AVX2 we only demand the upper bit of the mask.
@@ -36356,7 +36391,7 @@ static SDValue combineVectorCompareAndMaskUnaryOp(SDNode *N,
EVT VT = N->getValueType(0);
if (!VT.isVector() || N->getOperand(0)->getOpcode() != ISD::AND ||
N->getOperand(0)->getOperand(0)->getOpcode() != ISD::SETCC ||
- VT.getSizeInBits() != N->getOperand(0)->getValueType(0).getSizeInBits())
+ VT.getSizeInBits() != N->getOperand(0).getValueSizeInBits())
return SDValue();
// Now check that the other operand of the AND is a constant. We could