diff options
Diffstat (limited to 'lib/Target/X86/X86ISelLowering.cpp')
-rw-r--r-- | lib/Target/X86/X86ISelLowering.cpp | 465 |
1 files changed, 250 insertions, 215 deletions
diff --git a/lib/Target/X86/X86ISelLowering.cpp b/lib/Target/X86/X86ISelLowering.cpp index a72f4daa5e11..5ac5d0348f8a 100644 --- a/lib/Target/X86/X86ISelLowering.cpp +++ b/lib/Target/X86/X86ISelLowering.cpp @@ -461,7 +461,7 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, setOperationAction(ISD::SRL_PARTS, VT, Custom); } - if (Subtarget.hasSSE1()) + if (Subtarget.hasSSEPrefetch() || Subtarget.has3DNow()) setOperationAction(ISD::PREFETCH , MVT::Other, Legal); setOperationAction(ISD::ATOMIC_FENCE , MVT::Other, Custom); @@ -1622,16 +1622,11 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, setLibcallName(RTLIB::MUL_I128, nullptr); } - // Combine sin / cos into one node or libcall if possible. - if (Subtarget.hasSinCos()) { - setLibcallName(RTLIB::SINCOS_F32, "sincosf"); - setLibcallName(RTLIB::SINCOS_F64, "sincos"); - if (Subtarget.isTargetDarwin()) { - // For MacOSX, we don't want the normal expansion of a libcall to sincos. - // We want to issue a libcall to __sincos_stret to avoid memory traffic. - setOperationAction(ISD::FSINCOS, MVT::f64, Custom); - setOperationAction(ISD::FSINCOS, MVT::f32, Custom); - } + // Combine sin / cos into _sincos_stret if it is available. + if (getLibcallName(RTLIB::SINCOS_STRET_F32) != nullptr && + getLibcallName(RTLIB::SINCOS_STRET_F64) != nullptr) { + setOperationAction(ISD::FSINCOS, MVT::f64, Custom); + setOperationAction(ISD::FSINCOS, MVT::f32, Custom); } if (Subtarget.isTargetWin64()) { @@ -7480,9 +7475,9 @@ static bool isAddSub(const BuildVectorSDNode *BV, } /// Returns true if is possible to fold MUL and an idiom that has already been -/// recognized as ADDSUB(\p Opnd0, \p Opnd1) into FMADDSUB(x, y, \p Opnd1). -/// If (and only if) true is returned, the operands of FMADDSUB are written to -/// parameters \p Opnd0, \p Opnd1, \p Opnd2. +/// recognized as ADDSUB/SUBADD(\p Opnd0, \p Opnd1) into +/// FMADDSUB/FMSUBADD(x, y, \p Opnd1). If (and only if) true is returned, the +/// operands of FMADDSUB/FMSUBADD are written to parameters \p Opnd0, \p Opnd1, \p Opnd2. /// /// Prior to calling this function it should be known that there is some /// SDNode that potentially can be replaced with an X86ISD::ADDSUB operation @@ -7505,12 +7500,12 @@ static bool isAddSub(const BuildVectorSDNode *BV, /// recognized ADDSUB idiom with ADDSUB operation is that such replacement /// is illegal sometimes. E.g. 512-bit ADDSUB is not available, while 512-bit /// FMADDSUB is. -static bool isFMAddSub(const X86Subtarget &Subtarget, SelectionDAG &DAG, - SDValue &Opnd0, SDValue &Opnd1, SDValue &Opnd2, - unsigned ExpectedUses) { +static bool isFMAddSubOrFMSubAdd(const X86Subtarget &Subtarget, + SelectionDAG &DAG, + SDValue &Opnd0, SDValue &Opnd1, SDValue &Opnd2, + unsigned ExpectedUses) { if (Opnd0.getOpcode() != ISD::FMUL || - !Opnd0->hasNUsesOfValue(ExpectedUses, 0) || - !Subtarget.hasAnyFMA()) + !Opnd0->hasNUsesOfValue(ExpectedUses, 0) || !Subtarget.hasAnyFMA()) return false; // FIXME: These checks must match the similar ones in @@ -7547,7 +7542,7 @@ static SDValue lowerToAddSubOrFMAddSub(const BuildVectorSDNode *BV, SDValue Opnd2; // TODO: According to coverage reports, the FMADDSUB transform is not // triggered by any tests. - if (isFMAddSub(Subtarget, DAG, Opnd0, Opnd1, Opnd2, NumExtracts)) + if (isFMAddSubOrFMSubAdd(Subtarget, DAG, Opnd0, Opnd1, Opnd2, NumExtracts)) return DAG.getNode(X86ISD::FMADDSUB, DL, VT, Opnd0, Opnd1, Opnd2); // Do not generate X86ISD::ADDSUB node for 512-bit types even though @@ -11958,6 +11953,19 @@ static int canLowerByDroppingEvenElements(ArrayRef<int> Mask, return 0; } +static SDValue lowerVectorShuffleWithPERMV(const SDLoc &DL, MVT VT, + ArrayRef<int> Mask, SDValue V1, + SDValue V2, SelectionDAG &DAG) { + MVT MaskEltVT = MVT::getIntegerVT(VT.getScalarSizeInBits()); + MVT MaskVecVT = MVT::getVectorVT(MaskEltVT, VT.getVectorNumElements()); + + SDValue MaskNode = getConstVector(Mask, MaskVecVT, DAG, DL, true); + if (V2.isUndef()) + return DAG.getNode(X86ISD::VPERMV, DL, VT, MaskNode, V1); + + return DAG.getNode(X86ISD::VPERMV3, DL, VT, V1, MaskNode, V2); +} + /// \brief Generic lowering of v16i8 shuffles. /// /// This is a hybrid strategy to lower v16i8 vectors. It first attempts to @@ -12148,6 +12156,10 @@ static SDValue lowerV16I8VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, if (SDValue Unpack = lowerVectorShuffleAsPermuteAndUnpack( DL, MVT::v16i8, V1, V2, Mask, DAG)) return Unpack; + + // If we have VBMI we can use one VPERM instead of multiple PSHUFBs. + if (Subtarget.hasVBMI() && Subtarget.hasVLX()) + return lowerVectorShuffleWithPERMV(DL, MVT::v16i8, Mask, V1, V2, DAG); } return PSHUFB; @@ -13048,19 +13060,6 @@ static SDValue lowerVectorShuffleWithSHUFPD(const SDLoc &DL, MVT VT, DAG.getConstant(Immediate, DL, MVT::i8)); } -static SDValue lowerVectorShuffleWithPERMV(const SDLoc &DL, MVT VT, - ArrayRef<int> Mask, SDValue V1, - SDValue V2, SelectionDAG &DAG) { - MVT MaskEltVT = MVT::getIntegerVT(VT.getScalarSizeInBits()); - MVT MaskVecVT = MVT::getVectorVT(MaskEltVT, VT.getVectorNumElements()); - - SDValue MaskNode = getConstVector(Mask, MaskVecVT, DAG, DL, true); - if (V2.isUndef()) - return DAG.getNode(X86ISD::VPERMV, DL, VT, MaskNode, V1); - - return DAG.getNode(X86ISD::VPERMV3, DL, VT, V1, MaskNode, V2); -} - /// \brief Handle lowering of 4-lane 64-bit floating point shuffles. /// /// Also ends up handling lowering of 4-lane 64-bit integer shuffles when AVX2 @@ -13615,6 +13614,10 @@ static SDValue lowerV32I8VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, DL, MVT::v32i8, Mask, V1, V2, Zeroable, Subtarget, DAG)) return PSHUFB; + // AVX512VBMIVL can lower to VPERMB. + if (Subtarget.hasVBMI() && Subtarget.hasVLX()) + return lowerVectorShuffleWithPERMV(DL, MVT::v32i8, Mask, V1, V2, DAG); + // Try to simplify this by merging 128-bit lanes to enable a lane-based // shuffle. if (SDValue Result = lowerVectorShuffleByMerging128BitLanes( @@ -14077,6 +14080,10 @@ static SDValue lowerV32I16VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, Zeroable, Subtarget, DAG)) return Blend; + if (SDValue PSHUFB = lowerVectorShuffleWithPSHUFB( + DL, MVT::v32i16, Mask, V1, V2, Zeroable, Subtarget, DAG)) + return PSHUFB; + return lowerVectorShuffleWithPERMV(DL, MVT::v32i16, Mask, V1, V2, DAG); } @@ -14212,7 +14219,9 @@ static SDValue lower1BitVectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, ExtVT = MVT::v4i32; break; case MVT::v8i1: - ExtVT = MVT::v8i64; // Take 512-bit type, more shuffles on KNL + // Take 512-bit type, more shuffles on KNL. If we have VLX use a 256-bit + // shuffle. + ExtVT = Subtarget.hasVLX() ? MVT::v8i32 : MVT::v8i64; break; case MVT::v16i1: ExtVT = MVT::v16i32; @@ -14569,11 +14578,10 @@ static SDValue ExtractBitFromMaskVector(SDValue Op, SelectionDAG &DAG, unsigned NumElts = VecVT.getVectorNumElements(); // Extending v8i1/v16i1 to 512-bit get better performance on KNL // than extending to 128/256bit. - unsigned VecSize = (NumElts <= 4 ? 128 : 512); - MVT ExtVT = MVT::getVectorVT(MVT::getIntegerVT(VecSize / NumElts), NumElts); - SDValue Ext = DAG.getNode(ISD::SIGN_EXTEND, dl, ExtVT, Vec); - SDValue Elt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, - ExtVT.getVectorElementType(), Ext, Idx); + MVT ExtEltVT = (NumElts <= 8) ? MVT::getIntegerVT(128 / NumElts) : MVT::i8; + MVT ExtVecVT = MVT::getVectorVT(ExtEltVT, NumElts); + SDValue Ext = DAG.getNode(ISD::SIGN_EXTEND, dl, ExtVecVT, Vec); + SDValue Elt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, ExtEltVT, Ext, Idx); return DAG.getNode(ISD::TRUNCATE, dl, EltVT, Elt); } @@ -14768,12 +14776,11 @@ static SDValue InsertBitToMaskVector(SDValue Op, SelectionDAG &DAG, // Non constant index. Extend source and destination, // insert element and then truncate the result. unsigned NumElts = VecVT.getVectorNumElements(); - unsigned VecSize = (NumElts <= 4 ? 128 : 512); - MVT ExtVecVT = MVT::getVectorVT(MVT::getIntegerVT(VecSize/NumElts), NumElts); - MVT ExtEltVT = ExtVecVT.getVectorElementType(); + MVT ExtEltVT = (NumElts <= 8) ? MVT::getIntegerVT(128 / NumElts) : MVT::i8; + MVT ExtVecVT = MVT::getVectorVT(ExtEltVT, NumElts); SDValue ExtOp = DAG.getNode(ISD::INSERT_VECTOR_ELT, dl, ExtVecVT, - DAG.getNode(ISD::ZERO_EXTEND, dl, ExtVecVT, Vec), - DAG.getNode(ISD::ZERO_EXTEND, dl, ExtEltVT, Elt), Idx); + DAG.getNode(ISD::SIGN_EXTEND, dl, ExtVecVT, Vec), + DAG.getNode(ISD::SIGN_EXTEND, dl, ExtEltVT, Elt), Idx); return DAG.getNode(ISD::TRUNCATE, dl, VecVT, ExtOp); } @@ -16287,21 +16294,6 @@ static SDValue LowerZERO_EXTEND_Mask(SDValue Op, return SelectedVal; } -static SDValue LowerANY_EXTEND(SDValue Op, const X86Subtarget &Subtarget, - SelectionDAG &DAG) { - SDValue In = Op->getOperand(0); - MVT InVT = In.getSimpleValueType(); - - if (InVT.getVectorElementType() == MVT::i1) - return LowerZERO_EXTEND_Mask(Op, Subtarget, DAG); - - if (Subtarget.hasFp256()) - if (SDValue Res = LowerAVXExtend(Op, DAG, Subtarget)) - return Res; - - return SDValue(); -} - static SDValue LowerZERO_EXTEND(SDValue Op, const X86Subtarget &Subtarget, SelectionDAG &DAG) { SDValue In = Op.getOperand(0); @@ -16440,7 +16432,8 @@ static SDValue LowerTruncateVecI1(SDValue Op, SelectionDAG &DAG, assert((InVT.is256BitVector() || InVT.is128BitVector()) && "Unexpected vector type."); unsigned NumElts = InVT.getVectorNumElements(); - MVT ExtVT = MVT::getVectorVT(MVT::getIntegerVT(512/NumElts), NumElts); + MVT EltVT = Subtarget.hasVLX() ? MVT::i32 : MVT::getIntegerVT(512/NumElts); + MVT ExtVT = MVT::getVectorVT(EltVT, NumElts); In = DAG.getNode(ISD::SIGN_EXTEND, DL, ExtVT, In); InVT = ExtVT; ShiftInx = InVT.getScalarSizeInBits() - 1; @@ -18446,6 +18439,21 @@ static SDValue LowerSIGN_EXTEND_Mask(SDValue Op, return V; } +static SDValue LowerANY_EXTEND(SDValue Op, const X86Subtarget &Subtarget, + SelectionDAG &DAG) { + SDValue In = Op->getOperand(0); + MVT InVT = In.getSimpleValueType(); + + if (InVT.getVectorElementType() == MVT::i1) + return LowerSIGN_EXTEND_Mask(Op, Subtarget, DAG); + + if (Subtarget.hasFp256()) + if (SDValue Res = LowerAVXExtend(Op, DAG, Subtarget)) + return Res; + + return SDValue(); +} + // Lowering for SIGN_EXTEND_VECTOR_INREG and ZERO_EXTEND_VECTOR_INREG. // For sign extend this needs to handle all vector sizes and SSE4.1 and // non-SSE4.1 targets. For zero extend this should only handle inputs of @@ -21128,7 +21136,7 @@ static SDValue LowerINTRINSIC_W_CHAIN(SDValue Op, const X86Subtarget &Subtarget, // ADC/ADCX/SBB case ADX: { SDVTList CFVTs = DAG.getVTList(Op->getValueType(0), MVT::i32); - SDVTList VTs = DAG.getVTList(Op.getOperand(3)->getValueType(0), MVT::i32); + SDVTList VTs = DAG.getVTList(Op.getOperand(3).getValueType(), MVT::i32); SDValue GenCF = DAG.getNode(X86ISD::ADD, dl, CFVTs, Op.getOperand(2), DAG.getConstant(-1, dl, MVT::i8)); SDValue Res = DAG.getNode(IntrData->Opc0, dl, VTs, Op.getOperand(3), @@ -22231,6 +22239,8 @@ static SDValue LowerMULH(SDValue Op, const X86Subtarget &Subtarget, DAG.getVectorShuffle(MVT::v16i16, dl, Lo, Hi, HiMask)); } + assert(VT == MVT::v16i8 && "Unexpected VT"); + SDValue ExA = DAG.getNode(ExAVX, dl, MVT::v16i16, A); SDValue ExB = DAG.getNode(ExAVX, dl, MVT::v16i16, B); SDValue Mul = DAG.getNode(ISD::MUL, dl, MVT::v16i16, ExA, ExB); @@ -22989,12 +22999,14 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget &Subtarget, (Subtarget.hasAVX512() && VT == MVT::v16i16) || (Subtarget.hasAVX512() && VT == MVT::v16i8) || (Subtarget.hasBWI() && VT == MVT::v32i8)) { - MVT EvtSVT = (VT == MVT::v32i8 ? MVT::i16 : MVT::i32); + assert((!Subtarget.hasBWI() || VT == MVT::v32i8 || VT == MVT::v16i8) && + "Unexpected vector type"); + MVT EvtSVT = Subtarget.hasBWI() ? MVT::i16 : MVT::i32; MVT ExtVT = MVT::getVectorVT(EvtSVT, VT.getVectorNumElements()); unsigned ExtOpc = Op.getOpcode() == ISD::SRA ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND; R = DAG.getNode(ExtOpc, dl, ExtVT, R); - Amt = DAG.getNode(ISD::ANY_EXTEND, dl, ExtVT, Amt); + Amt = DAG.getNode(ISD::ZERO_EXTEND, dl, ExtVT, Amt); return DAG.getNode(ISD::TRUNCATE, dl, VT, DAG.getNode(Op.getOpcode(), dl, ExtVT, R, Amt)); } @@ -24101,8 +24113,9 @@ static SDValue LowerFSINCOS(SDValue Op, const X86Subtarget &Subtarget, // Only optimize x86_64 for now. i386 is a bit messy. For f32, // the small struct {f32, f32} is returned in (eax, edx). For f64, // the results are returned via SRet in memory. - const char *LibcallName = isF64 ? "__sincos_stret" : "__sincosf_stret"; const TargetLowering &TLI = DAG.getTargetLoweringInfo(); + RTLIB::Libcall LC = isF64 ? RTLIB::SINCOS_STRET_F64 : RTLIB::SINCOS_STRET_F32; + const char *LibcallName = TLI.getLibcallName(LC); SDValue Callee = DAG.getExternalSymbol(LibcallName, TLI.getPointerTy(DAG.getDataLayout())); @@ -24928,7 +24941,7 @@ void X86TargetLowering::ReplaceNodeResults(SDNode *N, case ISD::BITCAST: { assert(Subtarget.hasSSE2() && "Requires at least SSE2!"); EVT DstVT = N->getValueType(0); - EVT SrcVT = N->getOperand(0)->getValueType(0); + EVT SrcVT = N->getOperand(0).getValueType(); if (SrcVT != MVT::f64 || (DstVT != MVT::v2i32 && DstVT != MVT::v4i16 && DstVT != MVT::v8i8)) @@ -28407,8 +28420,6 @@ static SDValue combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root, // TODO - attempt to narrow Mask back to writemask size. bool IsEVEXShuffle = RootSizeInBits == 512 || (Subtarget.hasVLX() && RootSizeInBits >= 128); - if (IsEVEXShuffle && (RootVT.getScalarSizeInBits() != BaseMaskEltSizeInBits)) - return SDValue(); // TODO - handle 128/256-bit lane shuffles of 512-bit vectors. @@ -28491,11 +28502,10 @@ static SDValue combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root, if (matchUnaryVectorShuffle(MaskVT, Mask, AllowFloatDomain, AllowIntDomain, V1, DL, DAG, Subtarget, Shuffle, ShuffleSrcVT, - ShuffleVT)) { + ShuffleVT) && + (!IsEVEXShuffle || (NumRootElts == ShuffleVT.getVectorNumElements()))) { if (Depth == 1 && Root.getOpcode() == Shuffle) return SDValue(); // Nothing to do! - if (IsEVEXShuffle && (NumRootElts != ShuffleVT.getVectorNumElements())) - return SDValue(); // AVX512 Writemask clash. Res = DAG.getBitcast(ShuffleSrcVT, V1); DCI.AddToWorklist(Res.getNode()); Res = DAG.getNode(Shuffle, DL, ShuffleVT, Res); @@ -28505,11 +28515,10 @@ static SDValue combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root, if (matchUnaryPermuteVectorShuffle(MaskVT, Mask, Zeroable, AllowFloatDomain, AllowIntDomain, Subtarget, Shuffle, - ShuffleVT, PermuteImm)) { + ShuffleVT, PermuteImm) && + (!IsEVEXShuffle || (NumRootElts == ShuffleVT.getVectorNumElements()))) { if (Depth == 1 && Root.getOpcode() == Shuffle) return SDValue(); // Nothing to do! - if (IsEVEXShuffle && (NumRootElts != ShuffleVT.getVectorNumElements())) - return SDValue(); // AVX512 Writemask clash. Res = DAG.getBitcast(ShuffleVT, V1); DCI.AddToWorklist(Res.getNode()); Res = DAG.getNode(Shuffle, DL, ShuffleVT, Res, @@ -28520,12 +28529,11 @@ static SDValue combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root, } if (matchBinaryVectorShuffle(MaskVT, Mask, AllowFloatDomain, AllowIntDomain, - V1, V2, DL, DAG, Subtarget, Shuffle, ShuffleSrcVT, - ShuffleVT, UnaryShuffle)) { + V1, V2, DL, DAG, Subtarget, Shuffle, + ShuffleSrcVT, ShuffleVT, UnaryShuffle) && + (!IsEVEXShuffle || (NumRootElts == ShuffleVT.getVectorNumElements()))) { if (Depth == 1 && Root.getOpcode() == Shuffle) return SDValue(); // Nothing to do! - if (IsEVEXShuffle && (NumRootElts != ShuffleVT.getVectorNumElements())) - return SDValue(); // AVX512 Writemask clash. V1 = DAG.getBitcast(ShuffleSrcVT, V1); DCI.AddToWorklist(V1.getNode()); V2 = DAG.getBitcast(ShuffleSrcVT, V2); @@ -28538,11 +28546,10 @@ static SDValue combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root, if (matchBinaryPermuteVectorShuffle(MaskVT, Mask, Zeroable, AllowFloatDomain, AllowIntDomain, V1, V2, DL, DAG, Subtarget, Shuffle, ShuffleVT, - PermuteImm)) { + PermuteImm) && + (!IsEVEXShuffle || (NumRootElts == ShuffleVT.getVectorNumElements()))) { if (Depth == 1 && Root.getOpcode() == Shuffle) return SDValue(); // Nothing to do! - if (IsEVEXShuffle && (NumRootElts != ShuffleVT.getVectorNumElements())) - return SDValue(); // AVX512 Writemask clash. V1 = DAG.getBitcast(ShuffleVT, V1); DCI.AddToWorklist(V1.getNode()); V2 = DAG.getBitcast(ShuffleVT, V2); @@ -28594,8 +28601,8 @@ static SDValue combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root, return SDValue(); // Depth threshold above which we can efficiently use variable mask shuffles. - // TODO This should probably be target specific. - bool AllowVariableMask = (Depth >= 3) || HasVariableMask; + int VariableShuffleDepth = Subtarget.hasFastVariableShuffle() ? 2 : 3; + bool AllowVariableMask = (Depth >= VariableShuffleDepth) || HasVariableMask; bool MaskContainsZeros = any_of(Mask, [](int M) { return M == SM_SentinelZero; }); @@ -29698,17 +29705,18 @@ static SDValue combineTargetShuffle(SDValue N, SelectionDAG &DAG, return SDValue(); } -/// Returns true iff the shuffle node \p N can be replaced with ADDSUB -/// operation. If true is returned then the operands of ADDSUB operation +/// Returns true iff the shuffle node \p N can be replaced with ADDSUB(SUBADD) +/// operation. If true is returned then the operands of ADDSUB(SUBADD) operation /// are written to the parameters \p Opnd0 and \p Opnd1. /// -/// We combine shuffle to ADDSUB directly on the abstract vector shuffle nodes +/// We combine shuffle to ADDSUB(SUBADD) directly on the abstract vector shuffle nodes /// so it is easier to generically match. We also insert dummy vector shuffle /// nodes for the operands which explicitly discard the lanes which are unused /// by this operation to try to flow through the rest of the combiner /// the fact that they're unused. -static bool isAddSub(SDNode *N, const X86Subtarget &Subtarget, - SDValue &Opnd0, SDValue &Opnd1) { +static bool isAddSubOrSubAdd(SDNode *N, const X86Subtarget &Subtarget, + SDValue &Opnd0, SDValue &Opnd1, + bool matchSubAdd = false) { EVT VT = N->getValueType(0); if ((!Subtarget.hasSSE3() || (VT != MVT::v4f32 && VT != MVT::v2f64)) && @@ -29728,12 +29736,15 @@ static bool isAddSub(SDNode *N, const X86Subtarget &Subtarget, SDValue V1 = N->getOperand(0); SDValue V2 = N->getOperand(1); - // We require the first shuffle operand to be the FSUB node, and the second to - // be the FADD node. - if (V1.getOpcode() == ISD::FADD && V2.getOpcode() == ISD::FSUB) { + unsigned ExpectedOpcode = matchSubAdd ? ISD::FADD : ISD::FSUB; + unsigned NextExpectedOpcode = matchSubAdd ? ISD::FSUB : ISD::FADD; + + // We require the first shuffle operand to be the ExpectedOpcode node, + // and the second to be the NextExpectedOpcode node. + if (V1.getOpcode() == NextExpectedOpcode && V2.getOpcode() == ExpectedOpcode) { ShuffleVectorSDNode::commuteMask(Mask); std::swap(V1, V2); - } else if (V1.getOpcode() != ISD::FSUB || V2.getOpcode() != ISD::FADD) + } else if (V1.getOpcode() != ExpectedOpcode || V2.getOpcode() != NextExpectedOpcode) return false; // If there are other uses of these operations we can't fold them. @@ -29767,7 +29778,7 @@ static SDValue combineShuffleToAddSubOrFMAddSub(SDNode *N, const X86Subtarget &Subtarget, SelectionDAG &DAG) { SDValue Opnd0, Opnd1; - if (!isAddSub(N, Subtarget, Opnd0, Opnd1)) + if (!isAddSubOrSubAdd(N, Subtarget, Opnd0, Opnd1)) return SDValue(); EVT VT = N->getValueType(0); @@ -29775,7 +29786,7 @@ static SDValue combineShuffleToAddSubOrFMAddSub(SDNode *N, // Try to generate X86ISD::FMADDSUB node here. SDValue Opnd2; - if (isFMAddSub(Subtarget, DAG, Opnd0, Opnd1, Opnd2, 2)) + if (isFMAddSubOrFMSubAdd(Subtarget, DAG, Opnd0, Opnd1, Opnd2, 2)) return DAG.getNode(X86ISD::FMADDSUB, DL, VT, Opnd0, Opnd1, Opnd2); // Do not generate X86ISD::ADDSUB node for 512-bit types even though @@ -29787,6 +29798,26 @@ static SDValue combineShuffleToAddSubOrFMAddSub(SDNode *N, return DAG.getNode(X86ISD::ADDSUB, DL, VT, Opnd0, Opnd1); } +/// \brief Try to combine a shuffle into a target-specific +/// mul-sub-add node. +static SDValue combineShuffleToFMSubAdd(SDNode *N, + const X86Subtarget &Subtarget, + SelectionDAG &DAG) { + SDValue Opnd0, Opnd1; + if (!isAddSubOrSubAdd(N, Subtarget, Opnd0, Opnd1, true)) + return SDValue(); + + EVT VT = N->getValueType(0); + SDLoc DL(N); + + // Try to generate X86ISD::FMSUBADD node here. + SDValue Opnd2; + if (isFMAddSubOrFMSubAdd(Subtarget, DAG, Opnd0, Opnd1, Opnd2, 2)) + return DAG.getNode(X86ISD::FMSUBADD, DL, VT, Opnd0, Opnd1, Opnd2); + + return SDValue(); +} + // We are looking for a shuffle where both sources are concatenated with undef // and have a width that is half of the output's width. AVX2 has VPERMD/Q, so // if we can express this as a single-source shuffle, that's preferable. @@ -29873,11 +29904,14 @@ static SDValue combineShuffle(SDNode *N, SelectionDAG &DAG, EVT VT = N->getValueType(0); const TargetLowering &TLI = DAG.getTargetLoweringInfo(); // If we have legalized the vector types, look for blends of FADD and FSUB - // nodes that we can fuse into an ADDSUB node. + // nodes that we can fuse into an ADDSUB, FMADDSUB, or FMSUBADD node. if (TLI.isTypeLegal(VT)) { if (SDValue AddSub = combineShuffleToAddSubOrFMAddSub(N, Subtarget, DAG)) return AddSub; + if (SDValue FMSubAdd = combineShuffleToFMSubAdd(N, Subtarget, DAG)) + return FMSubAdd; + if (SDValue HAddSub = foldShuffleOfHorizOp(N)) return HAddSub; } @@ -30181,7 +30215,7 @@ static SDValue combineBitcastvxi1(SelectionDAG &DAG, SDValue BitCast, // For cases such as (i4 bitcast (v4i1 setcc v4i64 v1, v2)) // sign-extend to a 256-bit operation to avoid truncation. if (N0->getOpcode() == ISD::SETCC && Subtarget.hasAVX() && - N0->getOperand(0)->getValueType(0).is256BitVector()) { + N0->getOperand(0).getValueType().is256BitVector()) { SExtVT = MVT::v4i64; FPCastVT = MVT::v4f64; } @@ -30194,8 +30228,8 @@ static SDValue combineBitcastvxi1(SelectionDAG &DAG, SDValue BitCast, // 256-bit because the shuffle is cheaper than sign extending the result of // the compare. if (N0->getOpcode() == ISD::SETCC && Subtarget.hasAVX() && - (N0->getOperand(0)->getValueType(0).is256BitVector() || - N0->getOperand(0)->getValueType(0).is512BitVector())) { + (N0->getOperand(0).getValueType().is256BitVector() || + N0->getOperand(0).getValueType().is512BitVector())) { SExtVT = MVT::v8i32; FPCastVT = MVT::v8f32; } @@ -30484,7 +30518,8 @@ static SDValue createPSADBW(SelectionDAG &DAG, const SDValue &Zext0, return DAG.getNode(X86ISD::PSADBW, DL, SadVT, SadOp0, SadOp1); } -// Attempt to replace an min/max v8i16 horizontal reduction with PHMINPOSUW. +// Attempt to replace an min/max v8i16/v16i8 horizontal reduction with +// PHMINPOSUW. static SDValue combineHorizontalMinMaxResult(SDNode *Extract, SelectionDAG &DAG, const X86Subtarget &Subtarget) { // Bail without SSE41. @@ -30492,7 +30527,7 @@ static SDValue combineHorizontalMinMaxResult(SDNode *Extract, SelectionDAG &DAG, return SDValue(); EVT ExtractVT = Extract->getValueType(0); - if (ExtractVT != MVT::i16) + if (ExtractVT != MVT::i16 && ExtractVT != MVT::i8) return SDValue(); // Check for SMAX/SMIN/UMAX/UMIN horizontal reduction patterns. @@ -30504,7 +30539,7 @@ static SDValue combineHorizontalMinMaxResult(SDNode *Extract, SelectionDAG &DAG, EVT SrcVT = Src.getValueType(); EVT SrcSVT = SrcVT.getScalarType(); - if (SrcSVT != MVT::i16 || (SrcVT.getSizeInBits() % 128) != 0) + if (SrcSVT != ExtractVT || (SrcVT.getSizeInBits() % 128) != 0) return SDValue(); SDLoc DL(Extract); @@ -30520,22 +30555,39 @@ static SDValue combineHorizontalMinMaxResult(SDNode *Extract, SelectionDAG &DAG, SDValue Hi = extractSubVector(MinPos, NumSubElts, DAG, DL, SubSizeInBits); MinPos = DAG.getNode(BinOp, DL, SrcVT, Lo, Hi); } - assert(SrcVT == MVT::v8i16 && "Unexpected value type"); + assert(((SrcVT == MVT::v8i16 && ExtractVT == MVT::i16) || + (SrcVT == MVT::v16i8 && ExtractVT == MVT::i8)) && + "Unexpected value type"); // PHMINPOSUW applies to UMIN(v8i16), for SMIN/SMAX/UMAX we must apply a mask // to flip the value accordingly. SDValue Mask; + unsigned MaskEltsBits = ExtractVT.getSizeInBits(); if (BinOp == ISD::SMAX) - Mask = DAG.getConstant(APInt::getSignedMaxValue(16), DL, SrcVT); + Mask = DAG.getConstant(APInt::getSignedMaxValue(MaskEltsBits), DL, SrcVT); else if (BinOp == ISD::SMIN) - Mask = DAG.getConstant(APInt::getSignedMinValue(16), DL, SrcVT); + Mask = DAG.getConstant(APInt::getSignedMinValue(MaskEltsBits), DL, SrcVT); else if (BinOp == ISD::UMAX) - Mask = DAG.getConstant(APInt::getAllOnesValue(16), DL, SrcVT); + Mask = DAG.getConstant(APInt::getAllOnesValue(MaskEltsBits), DL, SrcVT); if (Mask) MinPos = DAG.getNode(ISD::XOR, DL, SrcVT, Mask, MinPos); - MinPos = DAG.getNode(X86ISD::PHMINPOS, DL, SrcVT, MinPos); + // For v16i8 cases we need to perform UMIN on pairs of byte elements, + // shuffling each upper element down and insert zeros. This means that the + // v16i8 UMIN will leave the upper element as zero, performing zero-extension + // ready for the PHMINPOS. + if (ExtractVT == MVT::i8) { + SDValue Upper = DAG.getVectorShuffle( + SrcVT, DL, MinPos, getZeroVector(MVT::v16i8, Subtarget, DAG, DL), + {1, 16, 3, 16, 5, 16, 7, 16, 9, 16, 11, 16, 13, 16, 15, 16}); + MinPos = DAG.getNode(ISD::UMIN, DL, SrcVT, MinPos, Upper); + } + + // Perform the PHMINPOS on a v8i16 vector, + MinPos = DAG.getBitcast(MVT::v8i16, MinPos); + MinPos = DAG.getNode(X86ISD::PHMINPOS, DL, MVT::v8i16, MinPos); + MinPos = DAG.getBitcast(SrcVT, MinPos); if (Mask) MinPos = DAG.getNode(ISD::XOR, DL, SrcVT, Mask, MinPos); @@ -30851,7 +30903,7 @@ static SDValue combineExtractVectorElt(SDNode *N, SelectionDAG &DAG, if (SDValue Cmp = combineHorizontalPredicateResult(N, DAG, Subtarget)) return Cmp; - // Attempt to replace min/max v8i16 reductions with PHMINPOSUW. + // Attempt to replace min/max v8i16/v16i8 reductions with PHMINPOSUW. if (SDValue MinMax = combineHorizontalMinMaxResult(N, DAG, Subtarget)) return MinMax; @@ -32555,7 +32607,7 @@ static SDValue combineShiftRightArithmetic(SDNode *N, SelectionDAG &DAG) { // 1. MOVs can write to a register that differs from source // 2. MOVs accept memory operands - if (!VT.isInteger() || VT.isVector() || N1.getOpcode() != ISD::Constant || + if (VT.isVector() || N1.getOpcode() != ISD::Constant || N0.getOpcode() != ISD::SHL || !N0.hasOneUse() || N0.getOperand(1).getOpcode() != ISD::Constant) return SDValue(); @@ -32569,11 +32621,11 @@ static SDValue combineShiftRightArithmetic(SDNode *N, SelectionDAG &DAG) { if (SarConst.isNegative()) return SDValue(); - for (MVT SVT : MVT::integer_valuetypes()) { + for (MVT SVT : { MVT::i8, MVT::i16, MVT::i32 }) { unsigned ShiftSize = SVT.getSizeInBits(); // skipping types without corresponding sext/zext and // ShlConst that is not one of [56,48,32,24,16] - if (ShiftSize < 8 || ShiftSize > 64 || ShlConst != Size - ShiftSize) + if (ShiftSize >= Size || ShlConst != Size - ShiftSize) continue; SDLoc DL(N); SDValue NN = @@ -32626,37 +32678,6 @@ static SDValue combineShiftRightLogical(SDNode *N, SelectionDAG &DAG) { return SDValue(); } -/// \brief Returns a vector of 0s if the node in input is a vector logical -/// shift by a constant amount which is known to be bigger than or equal -/// to the vector element size in bits. -static SDValue performShiftToAllZeros(SDNode *N, SelectionDAG &DAG, - const X86Subtarget &Subtarget) { - EVT VT = N->getValueType(0); - - if (VT != MVT::v2i64 && VT != MVT::v4i32 && VT != MVT::v8i16 && - (!Subtarget.hasInt256() || - (VT != MVT::v4i64 && VT != MVT::v8i32 && VT != MVT::v16i16))) - return SDValue(); - - SDValue Amt = N->getOperand(1); - SDLoc DL(N); - if (auto *AmtBV = dyn_cast<BuildVectorSDNode>(Amt)) - if (auto *AmtSplat = AmtBV->getConstantSplatNode()) { - const APInt &ShiftAmt = AmtSplat->getAPIntValue(); - unsigned MaxAmount = - VT.getSimpleVT().getScalarSizeInBits(); - - // SSE2/AVX2 logical shifts always return a vector of 0s - // if the shift amount is bigger than or equal to - // the element size. The constant shift amount will be - // encoded as a 8-bit immediate. - if (ShiftAmt.trunc(8).uge(MaxAmount)) - return getZeroVector(VT.getSimpleVT(), Subtarget, DAG, DL); - } - - return SDValue(); -} - static SDValue combineShift(SDNode* N, SelectionDAG &DAG, TargetLowering::DAGCombinerInfo &DCI, const X86Subtarget &Subtarget) { @@ -32672,11 +32693,6 @@ static SDValue combineShift(SDNode* N, SelectionDAG &DAG, if (SDValue V = combineShiftRightLogical(N, DAG)) return V; - // Try to fold this logical shift into a zero vector. - if (N->getOpcode() != ISD::SRA) - if (SDValue V = performShiftToAllZeros(N, DAG, Subtarget)) - return V; - return SDValue(); } @@ -32996,21 +33012,20 @@ static SDValue combineANDXORWithAllOnesIntoANDNP(SDNode *N, SelectionDAG &DAG) { // register. In most cases we actually compare or select YMM-sized registers // and mixing the two types creates horrible code. This method optimizes // some of the transition sequences. +// Even with AVX-512 this is still useful for removing casts around logical +// operations on vXi1 mask types. static SDValue WidenMaskArithmetic(SDNode *N, SelectionDAG &DAG, TargetLowering::DAGCombinerInfo &DCI, const X86Subtarget &Subtarget) { EVT VT = N->getValueType(0); - if (!VT.is256BitVector()) - return SDValue(); + assert(VT.isVector() && "Expected vector type"); assert((N->getOpcode() == ISD::ANY_EXTEND || N->getOpcode() == ISD::ZERO_EXTEND || N->getOpcode() == ISD::SIGN_EXTEND) && "Invalid Node"); SDValue Narrow = N->getOperand(0); - EVT NarrowVT = Narrow->getValueType(0); - if (!NarrowVT.is128BitVector()) - return SDValue(); + EVT NarrowVT = Narrow.getValueType(); if (Narrow->getOpcode() != ISD::XOR && Narrow->getOpcode() != ISD::AND && @@ -33026,12 +33041,12 @@ static SDValue WidenMaskArithmetic(SDNode *N, SelectionDAG &DAG, return SDValue(); // The type of the truncated inputs. - EVT WideVT = N0->getOperand(0)->getValueType(0); - if (WideVT != VT) + if (N0->getOperand(0).getValueType() != VT) return SDValue(); // The right side has to be a 'trunc' or a constant vector. - bool RHSTrunc = N1.getOpcode() == ISD::TRUNCATE; + bool RHSTrunc = N1.getOpcode() == ISD::TRUNCATE && + N1.getOperand(0).getValueType() == VT; ConstantSDNode *RHSConstSplat = nullptr; if (auto *RHSBV = dyn_cast<BuildVectorSDNode>(N1)) RHSConstSplat = RHSBV->getConstantSplatNode(); @@ -33040,37 +33055,31 @@ static SDValue WidenMaskArithmetic(SDNode *N, SelectionDAG &DAG, const TargetLowering &TLI = DAG.getTargetLoweringInfo(); - if (!TLI.isOperationLegalOrPromote(Narrow->getOpcode(), WideVT)) + if (!TLI.isOperationLegalOrPromote(Narrow->getOpcode(), VT)) return SDValue(); // Set N0 and N1 to hold the inputs to the new wide operation. N0 = N0->getOperand(0); if (RHSConstSplat) { - N1 = DAG.getNode(ISD::ZERO_EXTEND, DL, WideVT.getVectorElementType(), + N1 = DAG.getNode(ISD::ZERO_EXTEND, DL, VT.getVectorElementType(), SDValue(RHSConstSplat, 0)); - N1 = DAG.getSplatBuildVector(WideVT, DL, N1); + N1 = DAG.getSplatBuildVector(VT, DL, N1); } else if (RHSTrunc) { N1 = N1->getOperand(0); } // Generate the wide operation. - SDValue Op = DAG.getNode(Narrow->getOpcode(), DL, WideVT, N0, N1); + SDValue Op = DAG.getNode(Narrow->getOpcode(), DL, VT, N0, N1); unsigned Opcode = N->getOpcode(); switch (Opcode) { + default: llvm_unreachable("Unexpected opcode"); case ISD::ANY_EXTEND: return Op; - case ISD::ZERO_EXTEND: { - unsigned InBits = NarrowVT.getScalarSizeInBits(); - APInt Mask = APInt::getAllOnesValue(InBits); - Mask = Mask.zext(VT.getScalarSizeInBits()); - return DAG.getNode(ISD::AND, DL, VT, - Op, DAG.getConstant(Mask, DL, VT)); - } + case ISD::ZERO_EXTEND: + return DAG.getZeroExtendInReg(Op, DL, NarrowVT.getScalarType()); case ISD::SIGN_EXTEND: return DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, VT, Op, DAG.getValueType(NarrowVT)); - default: - llvm_unreachable("Unexpected opcode"); } } @@ -33882,16 +33891,6 @@ static SDValue detectAVGPattern(SDValue In, EVT VT, SelectionDAG &DAG, if (!Subtarget.hasSSE2()) return SDValue(); - if (Subtarget.hasBWI()) { - if (VT.getSizeInBits() > 512) - return SDValue(); - } else if (Subtarget.hasAVX2()) { - if (VT.getSizeInBits() > 256) - return SDValue(); - } else { - if (VT.getSizeInBits() > 128) - return SDValue(); - } // Detect the following pattern: // @@ -33903,7 +33902,6 @@ static SDValue detectAVGPattern(SDValue In, EVT VT, SelectionDAG &DAG, // %6 = trunc <N x i32> %5 to <N x i8> // // In AVX512, the last instruction can also be a trunc store. - if (In.getOpcode() != ISD::SRL) return SDValue(); @@ -33924,6 +33922,35 @@ static SDValue detectAVGPattern(SDValue In, EVT VT, SelectionDAG &DAG, return true; }; + // Split vectors to legal target size and apply AVG. + auto LowerToAVG = [&](SDValue Op0, SDValue Op1) { + unsigned NumSubs = 1; + if (Subtarget.hasBWI()) { + if (VT.getSizeInBits() > 512) + NumSubs = VT.getSizeInBits() / 512; + } else if (Subtarget.hasAVX2()) { + if (VT.getSizeInBits() > 256) + NumSubs = VT.getSizeInBits() / 256; + } else { + if (VT.getSizeInBits() > 128) + NumSubs = VT.getSizeInBits() / 128; + } + + if (NumSubs == 1) + return DAG.getNode(X86ISD::AVG, DL, VT, Op0, Op1); + + SmallVector<SDValue, 4> Subs; + EVT SubVT = EVT::getVectorVT(*DAG.getContext(), VT.getScalarType(), + VT.getVectorNumElements() / NumSubs); + for (unsigned i = 0; i != NumSubs; ++i) { + unsigned Idx = i * SubVT.getVectorNumElements(); + SDValue LHS = extractSubVector(Op0, Idx, DAG, DL, SubVT.getSizeInBits()); + SDValue RHS = extractSubVector(Op1, Idx, DAG, DL, SubVT.getSizeInBits()); + Subs.push_back(DAG.getNode(X86ISD::AVG, DL, SubVT, LHS, RHS)); + } + return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, Subs); + }; + // Check if each element of the vector is left-shifted by one. auto LHS = In.getOperand(0); auto RHS = In.getOperand(1); @@ -33947,8 +33974,7 @@ static SDValue detectAVGPattern(SDValue In, EVT VT, SelectionDAG &DAG, SDValue VecOnes = DAG.getConstant(1, DL, InVT); Operands[1] = DAG.getNode(ISD::SUB, DL, InVT, Operands[1], VecOnes); Operands[1] = DAG.getNode(ISD::TRUNCATE, DL, VT, Operands[1]); - return DAG.getNode(X86ISD::AVG, DL, VT, Operands[0].getOperand(0), - Operands[1]); + return LowerToAVG(Operands[0].getOperand(0), Operands[1]); } if (Operands[0].getOpcode() == ISD::ADD) @@ -33972,8 +33998,7 @@ static SDValue detectAVGPattern(SDValue In, EVT VT, SelectionDAG &DAG, return SDValue(); // The pattern is detected, emit X86ISD::AVG instruction. - return DAG.getNode(X86ISD::AVG, DL, VT, Operands[0].getOperand(0), - Operands[1].getOperand(0)); + return LowerToAVG(Operands[0].getOperand(0), Operands[1].getOperand(0)); } return SDValue(); @@ -35872,14 +35897,8 @@ static SDValue combineSext(SDNode *N, SelectionDAG &DAG, if (SDValue NewCMov = combineToExtendCMOV(N, DAG)) return NewCMov; - if (!DCI.isBeforeLegalizeOps()) { - if (InVT == MVT::i1) { - SDValue Zero = DAG.getConstant(0, DL, VT); - SDValue AllOnes = DAG.getAllOnesConstant(DL, VT); - return DAG.getSelect(DL, VT, N0, AllOnes, Zero); - } + if (!DCI.isBeforeLegalizeOps()) return SDValue(); - } if (InVT == MVT::i1 && N0.getOpcode() == ISD::XOR && isAllOnesConstant(N0.getOperand(1)) && N0.hasOneUse()) { @@ -35897,7 +35916,7 @@ static SDValue combineSext(SDNode *N, SelectionDAG &DAG, if (SDValue V = combineToExtendBoolVectorInReg(N, DAG, DCI, Subtarget)) return V; - if (Subtarget.hasAVX() && VT.is256BitVector()) + if (VT.isVector()) if (SDValue R = WidenMaskArithmetic(N, DAG, DCI, Subtarget)) return R; @@ -36089,7 +36108,7 @@ static SDValue combineZext(SDNode *N, SelectionDAG &DAG, if (SDValue V = combineToExtendBoolVectorInReg(N, DAG, DCI, Subtarget)) return V; - if (VT.is256BitVector()) + if (VT.isVector()) if (SDValue R = WidenMaskArithmetic(N, DAG, DCI, Subtarget)) return R; @@ -36244,39 +36263,54 @@ static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG, const X86Subtarget &Subtarget) { SDLoc DL(N); - // Pre-shrink oversized index elements to avoid triggering scalarization. - if (DCI.isBeforeLegalize()) { + if (DCI.isBeforeLegalizeOps()) { SDValue Index = N->getOperand(4); - if (Index.getScalarValueSizeInBits() > 64) { - EVT IndexVT = EVT::getVectorVT(*DAG.getContext(), MVT::i64, + // Remove any sign extends from 32 or smaller to larger than 32. + // Only do this before LegalizeOps in case we need the sign extend for + // legalization. + if (Index.getOpcode() == ISD::SIGN_EXTEND) { + if (Index.getScalarValueSizeInBits() > 32 && + Index.getOperand(0).getScalarValueSizeInBits() <= 32) { + SmallVector<SDValue, 5> NewOps(N->op_begin(), N->op_end()); + NewOps[4] = Index.getOperand(0); + DAG.UpdateNodeOperands(N, NewOps); + // The original sign extend has less users, add back to worklist in case + // it needs to be removed + DCI.AddToWorklist(Index.getNode()); + DCI.AddToWorklist(N); + return SDValue(N, 0); + } + } + + // Make sure the index is either i32 or i64 + unsigned ScalarSize = Index.getScalarValueSizeInBits(); + if (ScalarSize != 32 && ScalarSize != 64) { + MVT EltVT = ScalarSize > 32 ? MVT::i64 : MVT::i32; + EVT IndexVT = EVT::getVectorVT(*DAG.getContext(), EltVT, Index.getValueType().getVectorNumElements()); - SDValue Trunc = DAG.getNode(ISD::TRUNCATE, DL, IndexVT, Index); + Index = DAG.getSExtOrTrunc(Index, DL, IndexVT); SmallVector<SDValue, 5> NewOps(N->op_begin(), N->op_end()); - NewOps[4] = Trunc; + NewOps[4] = Index; DAG.UpdateNodeOperands(N, NewOps); DCI.AddToWorklist(N); return SDValue(N, 0); } - } - // Try to remove sign extends from i32 to i64 on the index. - // Only do this before legalize in case we are relying on it for - // legalization. - // TODO: We should maybe remove any sign extend once we learn how to sign - // extend narrow index during lowering. - if (DCI.isBeforeLegalizeOps()) { - SDValue Index = N->getOperand(4); - if (Index.getScalarValueSizeInBits() == 64 && - Index.getOpcode() == ISD::SIGN_EXTEND && + // Try to remove zero extends from 32->64 if we know the sign bit of + // the input is zero. + if (Index.getOpcode() == ISD::ZERO_EXTEND && + Index.getScalarValueSizeInBits() == 64 && Index.getOperand(0).getScalarValueSizeInBits() == 32) { - SmallVector<SDValue, 5> NewOps(N->op_begin(), N->op_end()); - NewOps[4] = Index.getOperand(0); - DAG.UpdateNodeOperands(N, NewOps); - // The original sign extend has less users, add back to worklist in case - // it needs to be removed. - DCI.AddToWorklist(Index.getNode()); - DCI.AddToWorklist(N); - return SDValue(N, 0); + if (DAG.SignBitIsZero(Index.getOperand(0))) { + SmallVector<SDValue, 5> NewOps(N->op_begin(), N->op_end()); + NewOps[4] = Index.getOperand(0); + DAG.UpdateNodeOperands(N, NewOps); + // The original zero extend has less users, add back to worklist in case + // it needs to be removed + DCI.AddToWorklist(Index.getNode()); + DCI.AddToWorklist(N); + return SDValue(N, 0); + } } } @@ -36288,6 +36322,7 @@ static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG, SmallVector<SDValue, 5> NewOps(N->op_begin(), N->op_end()); NewOps[2] = Mask.getOperand(0); DAG.UpdateNodeOperands(N, NewOps); + return SDValue(N, 0); } // With AVX2 we only demand the upper bit of the mask. @@ -36356,7 +36391,7 @@ static SDValue combineVectorCompareAndMaskUnaryOp(SDNode *N, EVT VT = N->getValueType(0); if (!VT.isVector() || N->getOperand(0)->getOpcode() != ISD::AND || N->getOperand(0)->getOperand(0)->getOpcode() != ISD::SETCC || - VT.getSizeInBits() != N->getOperand(0)->getValueType(0).getSizeInBits()) + VT.getSizeInBits() != N->getOperand(0).getValueSizeInBits()) return SDValue(); // Now check that the other operand of the AND is a constant. We could |