Diffstat (limited to 'contrib/llvm/lib/Target/X86/X86ISelLowering.cpp')
-rw-r--r--  contrib/llvm/lib/Target/X86/X86ISelLowering.cpp | 465
1 file changed, 250 insertions(+), 215 deletions(-)
diff --git a/contrib/llvm/lib/Target/X86/X86ISelLowering.cpp b/contrib/llvm/lib/Target/X86/X86ISelLowering.cpp
index a72f4daa5e11..5ac5d0348f8a 100644
--- a/contrib/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/contrib/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -461,7 +461,7 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
     setOperationAction(ISD::SRL_PARTS, VT, Custom);
   }
 
-  if (Subtarget.hasSSE1())
+  if (Subtarget.hasSSEPrefetch() || Subtarget.has3DNow())
     setOperationAction(ISD::PREFETCH      , MVT::Other, Legal);
 
   setOperationAction(ISD::ATOMIC_FENCE  , MVT::Other, Custom);
@@ -1622,16 +1622,11 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
     setLibcallName(RTLIB::MUL_I128, nullptr);
   }
 
-  // Combine sin / cos into one node or libcall if possible.
-  if (Subtarget.hasSinCos()) {
-    setLibcallName(RTLIB::SINCOS_F32, "sincosf");
-    setLibcallName(RTLIB::SINCOS_F64, "sincos");
-    if (Subtarget.isTargetDarwin()) {
-      // For MacOSX, we don't want the normal expansion of a libcall to sincos.
-      // We want to issue a libcall to __sincos_stret to avoid memory traffic.
-      setOperationAction(ISD::FSINCOS, MVT::f64, Custom);
-      setOperationAction(ISD::FSINCOS, MVT::f32, Custom);
-    }
+  // Combine sin / cos into _sincos_stret if it is available.
+  if (getLibcallName(RTLIB::SINCOS_STRET_F32) != nullptr &&
+      getLibcallName(RTLIB::SINCOS_STRET_F64) != nullptr) {
+    setOperationAction(ISD::FSINCOS, MVT::f64, Custom);
+    setOperationAction(ISD::FSINCOS, MVT::f32, Custom);
   }
 
   if (Subtarget.isTargetWin64()) {
@@ -7480,9 +7475,9 @@ static bool isAddSub(const BuildVectorSDNode *BV,
 }
 
 /// Returns true if is possible to fold MUL and an idiom that has already been
-/// recognized as ADDSUB(\p Opnd0, \p Opnd1) into FMADDSUB(x, y, \p Opnd1).
-/// If (and only if) true is returned, the operands of FMADDSUB are written to
-/// parameters \p Opnd0, \p Opnd1, \p Opnd2.
+/// recognized as ADDSUB/SUBADD(\p Opnd0, \p Opnd1) into
+/// FMADDSUB/FMSUBADD(x, y, \p Opnd1). If (and only if) true is returned, the
+/// operands of FMADDSUB/FMSUBADD are written to parameters \p Opnd0, \p Opnd1, \p Opnd2.
 ///
 /// Prior to calling this function it should be known that there is some
 /// SDNode that potentially can be replaced with an X86ISD::ADDSUB operation
@@ -7505,12 +7500,12 @@ static bool isAddSub(const BuildVectorSDNode *BV,
 /// recognized ADDSUB idiom with ADDSUB operation is that such replacement
 /// is illegal sometimes. E.g. 512-bit ADDSUB is not available, while 512-bit
 /// FMADDSUB is.
-static bool isFMAddSub(const X86Subtarget &Subtarget, SelectionDAG &DAG,
-                       SDValue &Opnd0, SDValue &Opnd1, SDValue &Opnd2,
-                       unsigned ExpectedUses) {
+static bool isFMAddSubOrFMSubAdd(const X86Subtarget &Subtarget,
+                                 SelectionDAG &DAG,
+                                 SDValue &Opnd0, SDValue &Opnd1, SDValue &Opnd2,
+                                 unsigned ExpectedUses) {
   if (Opnd0.getOpcode() != ISD::FMUL ||
-      !Opnd0->hasNUsesOfValue(ExpectedUses, 0) ||
-      !Subtarget.hasAnyFMA())
+      !Opnd0->hasNUsesOfValue(ExpectedUses, 0) || !Subtarget.hasAnyFMA())
     return false;
 
   // FIXME: These checks must match the similar ones in
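Note (editorial, not part of the patch): the rename from isFMAddSub to isFMAddSubOrFMSubAdd reflects that the same matcher now feeds both fused forms. As a rough scalar model of the lane pattern behind X86ISD::ADDSUB and X86ISD::FMADDSUB (an illustrative sketch, not LLVM code): even lanes subtract, odd lanes add, and FMADDSUB fuses a multiply in front of that alternating step.

#include <array>
#include <cstddef>
#include <cstdio>

// ADDSUB reference: even lanes subtract, odd lanes add.
template <std::size_t N>
std::array<float, N> addsub(const std::array<float, N> &a,
                            const std::array<float, N> &b) {
  std::array<float, N> r{};
  for (std::size_t i = 0; i < N; ++i)
    r[i] = (i % 2) ? a[i] + b[i] : a[i] - b[i];
  return r;
}

// FMADDSUB reference: multiply, then the same alternating add/sub.
template <std::size_t N>
std::array<float, N> fmaddsub(const std::array<float, N> &x,
                              const std::array<float, N> &y,
                              const std::array<float, N> &z) {
  std::array<float, N> r{};
  for (std::size_t i = 0; i < N; ++i)
    r[i] = (i % 2) ? x[i] * y[i] + z[i] : x[i] * y[i] - z[i];
  return r;
}

int main() {
  std::array<float, 4> a{1, 2, 3, 4}, b{10, 10, 10, 10};
  auto r = addsub(a, b);
  std::printf("%g %g %g %g\n", r[0], r[1], r[2], r[3]); // -9 12 -7 14
}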
@@ -7547,7 +7542,7 @@ static SDValue lowerToAddSubOrFMAddSub(const BuildVectorSDNode *BV,
   SDValue Opnd2;
   // TODO: According to coverage reports, the FMADDSUB transform is not
   // triggered by any tests.
-  if (isFMAddSub(Subtarget, DAG, Opnd0, Opnd1, Opnd2, NumExtracts))
+  if (isFMAddSubOrFMSubAdd(Subtarget, DAG, Opnd0, Opnd1, Opnd2, NumExtracts))
     return DAG.getNode(X86ISD::FMADDSUB, DL, VT, Opnd0, Opnd1, Opnd2);
 
   // Do not generate X86ISD::ADDSUB node for 512-bit types even though
@@ -11958,6 +11953,19 @@ static int canLowerByDroppingEvenElements(ArrayRef<int> Mask,
   return 0;
 }
 
+static SDValue lowerVectorShuffleWithPERMV(const SDLoc &DL, MVT VT,
+                                           ArrayRef<int> Mask, SDValue V1,
+                                           SDValue V2, SelectionDAG &DAG) {
+  MVT MaskEltVT = MVT::getIntegerVT(VT.getScalarSizeInBits());
+  MVT MaskVecVT = MVT::getVectorVT(MaskEltVT, VT.getVectorNumElements());
+
+  SDValue MaskNode = getConstVector(Mask, MaskVecVT, DAG, DL, true);
+  if (V2.isUndef())
+    return DAG.getNode(X86ISD::VPERMV, DL, VT, MaskNode, V1);
+
+  return DAG.getNode(X86ISD::VPERMV3, DL, VT, V1, MaskNode, V2);
+}
+
 /// \brief Generic lowering of v16i8 shuffles.
 ///
 /// This is a hybrid strategy to lower v16i8 vectors. It first attempts to
@@ -12148,6 +12156,10 @@ static SDValue lowerV16I8VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask,
       if (SDValue Unpack = lowerVectorShuffleAsPermuteAndUnpack(
              DL, MVT::v16i8, V1, V2, Mask, DAG))
         return Unpack;
+
+      // If we have VBMI we can use one VPERM instead of multiple PSHUFBs.
+      if (Subtarget.hasVBMI() && Subtarget.hasVLX())
+        return lowerVectorShuffleWithPERMV(DL, MVT::v16i8, Mask, V1, V2, DAG);
     }
 
     return PSHUFB;
@@ -13048,19 +13060,6 @@ static SDValue lowerVectorShuffleWithSHUFPD(const SDLoc &DL, MVT VT,
                      DAG.getConstant(Immediate, DL, MVT::i8));
 }
 
-static SDValue lowerVectorShuffleWithPERMV(const SDLoc &DL, MVT VT,
-                                           ArrayRef<int> Mask, SDValue V1,
-                                           SDValue V2, SelectionDAG &DAG) {
-  MVT MaskEltVT = MVT::getIntegerVT(VT.getScalarSizeInBits());
-  MVT MaskVecVT = MVT::getVectorVT(MaskEltVT, VT.getVectorNumElements());
-
-  SDValue MaskNode = getConstVector(Mask, MaskVecVT, DAG, DL, true);
-  if (V2.isUndef())
-    return DAG.getNode(X86ISD::VPERMV, DL, VT, MaskNode, V1);
-
-  return DAG.getNode(X86ISD::VPERMV3, DL, VT, V1, MaskNode, V2);
-}
-
 /// \brief Handle lowering of 4-lane 64-bit floating point shuffles.
 ///
 /// Also ends up handling lowering of 4-lane 64-bit integer shuffles when AVX2
@@ -13615,6 +13614,10 @@ static SDValue lowerV32I8VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask,
           DL, MVT::v32i8, Mask, V1, V2, Zeroable, Subtarget, DAG))
     return PSHUFB;
 
+  // AVX512VBMIVL can lower to VPERMB.
+  if (Subtarget.hasVBMI() && Subtarget.hasVLX())
+    return lowerVectorShuffleWithPERMV(DL, MVT::v32i8, Mask, V1, V2, DAG);
+
   // Try to simplify this by merging 128-bit lanes to enable a lane-based
   // shuffle.
   if (SDValue Result = lowerVectorShuffleByMerging128BitLanes(
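Note (editorial, not part of the patch): unlike PSHUFB, which only moves bytes within a 128-bit lane, VPERMB/VPERMV indexes across the whole vector, and VPERMV3 indexes into the concatenation of two sources; that is why a single VPERM can replace several PSHUFBs plus blends. A scalar model of the two nodes (illustrative; indices assumed non-negative):

#include <array>
#include <cstddef>
#include <cstdint>

// VPERMV reference: r[i] = v1[mask[i] % N]
template <std::size_t N>
std::array<uint8_t, N> permv(const std::array<unsigned, N> &mask,
                             const std::array<uint8_t, N> &v1) {
  std::array<uint8_t, N> r{};
  for (std::size_t i = 0; i < N; ++i)
    r[i] = v1[mask[i] % N];
  return r;
}

// VPERMV3 reference: r[i] picks from the 2N-element concat of v1 and v2.
template <std::size_t N>
std::array<uint8_t, N> permv3(const std::array<uint8_t, N> &v1,
                              const std::array<unsigned, N> &mask,
                              const std::array<uint8_t, N> &v2) {
  std::array<uint8_t, N> r{};
  for (std::size_t i = 0; i < N; ++i) {
    std::size_t idx = mask[i] % (2 * N);
    r[i] = idx < N ? v1[idx] : v2[idx - N];
  }
  return r;
}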
@@ -14077,6 +14080,10 @@ static SDValue lowerV32I16VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask,
                                                 Zeroable, Subtarget, DAG))
     return Blend;
 
+  if (SDValue PSHUFB = lowerVectorShuffleWithPSHUFB(
+          DL, MVT::v32i16, Mask, V1, V2, Zeroable, Subtarget, DAG))
+    return PSHUFB;
+
   return lowerVectorShuffleWithPERMV(DL, MVT::v32i16, Mask, V1, V2, DAG);
 }
 
@@ -14212,7 +14219,9 @@ static SDValue lower1BitVectorShuffle(const SDLoc &DL, ArrayRef<int> Mask,
     ExtVT = MVT::v4i32;
     break;
   case MVT::v8i1:
-    ExtVT = MVT::v8i64; // Take 512-bit type, more shuffles on KNL
+    // Take 512-bit type, more shuffles on KNL. If we have VLX use a 256-bit
+    // shuffle.
+    ExtVT = Subtarget.hasVLX() ? MVT::v8i32 : MVT::v8i64;
     break;
   case MVT::v16i1:
     ExtVT = MVT::v16i32;
@@ -14569,11 +14578,10 @@ static SDValue ExtractBitFromMaskVector(SDValue Op, SelectionDAG &DAG,
     unsigned NumElts = VecVT.getVectorNumElements();
     // Extending v8i1/v16i1 to 512-bit get better performance on KNL
     // than extending to 128/256bit.
-    unsigned VecSize = (NumElts <= 4 ? 128 : 512);
-    MVT ExtVT = MVT::getVectorVT(MVT::getIntegerVT(VecSize / NumElts), NumElts);
-    SDValue Ext = DAG.getNode(ISD::SIGN_EXTEND, dl, ExtVT, Vec);
-    SDValue Elt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl,
-                              ExtVT.getVectorElementType(), Ext, Idx);
+    MVT ExtEltVT = (NumElts <= 8) ? MVT::getIntegerVT(128 / NumElts) : MVT::i8;
+    MVT ExtVecVT = MVT::getVectorVT(ExtEltVT, NumElts);
+    SDValue Ext = DAG.getNode(ISD::SIGN_EXTEND, dl, ExtVecVT, Vec);
+    SDValue Elt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, ExtEltVT, Ext, Idx);
     return DAG.getNode(ISD::TRUNCATE, dl, EltVT, Elt);
   }
 
@@ -14768,12 +14776,11 @@ static SDValue InsertBitToMaskVector(SDValue Op, SelectionDAG &DAG,
     // Non constant index. Extend source and destination,
     // insert element and then truncate the result.
     unsigned NumElts = VecVT.getVectorNumElements();
-    unsigned VecSize = (NumElts <= 4 ? 128 : 512);
-    MVT ExtVecVT = MVT::getVectorVT(MVT::getIntegerVT(VecSize/NumElts), NumElts);
-    MVT ExtEltVT = ExtVecVT.getVectorElementType();
+    MVT ExtEltVT = (NumElts <= 8) ? MVT::getIntegerVT(128 / NumElts) : MVT::i8;
+    MVT ExtVecVT = MVT::getVectorVT(ExtEltVT, NumElts);
     SDValue ExtOp = DAG.getNode(ISD::INSERT_VECTOR_ELT, dl, ExtVecVT,
-      DAG.getNode(ISD::ZERO_EXTEND, dl, ExtVecVT, Vec),
-      DAG.getNode(ISD::ZERO_EXTEND, dl, ExtEltVT, Elt), Idx);
+      DAG.getNode(ISD::SIGN_EXTEND, dl, ExtVecVT, Vec),
+      DAG.getNode(ISD::SIGN_EXTEND, dl, ExtEltVT, Elt), Idx);
     return DAG.getNode(ISD::TRUNCATE, dl, VecVT, ExtOp);
   }
 
@@ -16287,21 +16294,6 @@ static  SDValue LowerZERO_EXTEND_Mask(SDValue Op,
   return SelectedVal;
 }
 
-static SDValue LowerANY_EXTEND(SDValue Op, const X86Subtarget &Subtarget,
-                               SelectionDAG &DAG) {
-  SDValue In = Op->getOperand(0);
-  MVT InVT = In.getSimpleValueType();
-
-  if (InVT.getVectorElementType() == MVT::i1)
-    return LowerZERO_EXTEND_Mask(Op, Subtarget, DAG);
-
-  if (Subtarget.hasFp256())
-    if (SDValue Res = LowerAVXExtend(Op, DAG, Subtarget))
-      return Res;
-
-  return SDValue();
-}
-
 static SDValue LowerZERO_EXTEND(SDValue Op, const X86Subtarget &Subtarget,
                                 SelectionDAG &DAG) {
   SDValue In = Op.getOperand(0);
@@ -16440,7 +16432,8 @@ static SDValue LowerTruncateVecI1(SDValue Op, SelectionDAG &DAG,
     assert((InVT.is256BitVector() || InVT.is128BitVector()) &&
            "Unexpected vector type.");
     unsigned NumElts = InVT.getVectorNumElements();
-    MVT ExtVT = MVT::getVectorVT(MVT::getIntegerVT(512/NumElts), NumElts);
+    MVT EltVT = Subtarget.hasVLX() ? MVT::i32 : MVT::getIntegerVT(512/NumElts);
+    MVT ExtVT = MVT::getVectorVT(EltVT, NumElts);
     In = DAG.getNode(ISD::SIGN_EXTEND, DL, ExtVT, In);
     InVT = ExtVT;
     ShiftInx = InVT.getScalarSizeInBits() - 1;
@@ -18446,6 +18439,21 @@ static SDValue LowerSIGN_EXTEND_Mask(SDValue Op,
   return V;
 }
 
+static SDValue LowerANY_EXTEND(SDValue Op, const X86Subtarget &Subtarget,
+                               SelectionDAG &DAG) {
+  SDValue In = Op->getOperand(0);
+  MVT InVT = In.getSimpleValueType();
+
+  if (InVT.getVectorElementType() == MVT::i1)
+    return LowerSIGN_EXTEND_Mask(Op, Subtarget, DAG);
+
+  if (Subtarget.hasFp256())
+    if (SDValue Res = LowerAVXExtend(Op, DAG, Subtarget))
+      return Res;
+
+  return SDValue();
+}
+
 // Lowering for SIGN_EXTEND_VECTOR_INREG and ZERO_EXTEND_VECTOR_INREG.
 // For sign extend this needs to handle all vector sizes and SSE4.1 and
 // non-SSE4.1 targets. For zero extend this should only handle inputs of
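Note (editorial, not part of the patch): both mask-vector hunks above implement the same idiom for a vXi1 vector with a non-constant index: widen the mask to an integer vector, do the element access there, and truncate back. The switch from ZERO_EXTEND to SIGN_EXTEND matches how AVX-512 materializes true mask lanes as all-ones. A scalar sketch of the insert case (hypothetical helper name; assumes idx < 16):

#include <cstdint>

// Widen a 16-lane i1 mask to i8 lanes, insert one lane, truncate back.
uint16_t insert_mask_bit(uint16_t mask, unsigned idx, bool bit) {
  int8_t lanes[16];
  for (unsigned i = 0; i < 16; ++i)            // SIGN_EXTEND: i1 -> i8
    lanes[i] = (mask >> i) & 1 ? -1 : 0;       // true lane becomes all-ones
  lanes[idx] = bit ? -1 : 0;                   // INSERT_VECTOR_ELT
  uint16_t out = 0;
  for (unsigned i = 0; i < 16; ++i)            // TRUNCATE: i8 -> i1
    out |= static_cast<uint16_t>(lanes[i] & 1) << i;
  return out;
}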
@@ -21128,7 +21136,7 @@ static SDValue LowerINTRINSIC_W_CHAIN(SDValue Op, const X86Subtarget &Subtarget,
   // ADC/ADCX/SBB
   case ADX: {
     SDVTList CFVTs = DAG.getVTList(Op->getValueType(0), MVT::i32);
-    SDVTList VTs = DAG.getVTList(Op.getOperand(3)->getValueType(0), MVT::i32);
+    SDVTList VTs = DAG.getVTList(Op.getOperand(3).getValueType(), MVT::i32);
     SDValue GenCF = DAG.getNode(X86ISD::ADD, dl, CFVTs, Op.getOperand(2),
                                 DAG.getConstant(-1, dl, MVT::i8));
     SDValue Res = DAG.getNode(IntrData->Opc0, dl, VTs, Op.getOperand(3),
@@ -22231,6 +22239,8 @@ static SDValue LowerMULH(SDValue Op, const X86Subtarget &Subtarget,
                          DAG.getVectorShuffle(MVT::v16i16, dl, Lo, Hi, HiMask));
     }
 
+    assert(VT == MVT::v16i8 && "Unexpected VT");
+
     SDValue ExA = DAG.getNode(ExAVX, dl, MVT::v16i16, A);
     SDValue ExB = DAG.getNode(ExAVX, dl, MVT::v16i16, B);
     SDValue Mul = DAG.getNode(ISD::MUL, dl, MVT::v16i16, ExA, ExB);
@@ -22989,12 +22999,14 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget &Subtarget,
       (Subtarget.hasAVX512() && VT == MVT::v16i16) ||
       (Subtarget.hasAVX512() && VT == MVT::v16i8) ||
       (Subtarget.hasBWI() && VT == MVT::v32i8)) {
-    MVT EvtSVT = (VT == MVT::v32i8 ? MVT::i16 : MVT::i32);
+    assert((!Subtarget.hasBWI() || VT == MVT::v32i8 || VT == MVT::v16i8) &&
+           "Unexpected vector type");
+    MVT EvtSVT = Subtarget.hasBWI() ? MVT::i16 : MVT::i32;
     MVT ExtVT = MVT::getVectorVT(EvtSVT, VT.getVectorNumElements());
     unsigned ExtOpc =
         Op.getOpcode() == ISD::SRA ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
     R = DAG.getNode(ExtOpc, dl, ExtVT, R);
-    Amt = DAG.getNode(ISD::ANY_EXTEND, dl, ExtVT, Amt);
+    Amt = DAG.getNode(ISD::ZERO_EXTEND, dl, ExtVT, Amt);
     return DAG.getNode(ISD::TRUNCATE, dl, VT,
                        DAG.getNode(Op.getOpcode(), dl, ExtVT, R, Amt));
   }
@@ -24101,8 +24113,9 @@ static SDValue LowerFSINCOS(SDValue Op, const X86Subtarget &Subtarget,
   // Only optimize x86_64 for now. i386 is a bit messy. For f32,
   // the small struct {f32, f32} is returned in (eax, edx). For f64,
   // the results are returned via SRet in memory.
-  const char *LibcallName =  isF64 ? "__sincos_stret" : "__sincosf_stret";
   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+  RTLIB::Libcall LC = isF64 ? RTLIB::SINCOS_STRET_F64 : RTLIB::SINCOS_STRET_F32;
+  const char *LibcallName = TLI.getLibcallName(LC);
   SDValue Callee =
       DAG.getExternalSymbol(LibcallName, TLI.getPointerTy(DAG.getDataLayout()));
 
@@ -24928,7 +24941,7 @@ void X86TargetLowering::ReplaceNodeResults(SDNode *N,
   case ISD::BITCAST: {
     assert(Subtarget.hasSSE2() && "Requires at least SSE2!");
     EVT DstVT = N->getValueType(0);
-    EVT SrcVT = N->getOperand(0)->getValueType(0);
+    EVT SrcVT = N->getOperand(0).getValueType();
 
     if (SrcVT != MVT::f64 ||
         (DstVT != MVT::v2i32 && DstVT != MVT::v4i16 && DstVT != MVT::v8i8))
@@ -28407,8 +28420,6 @@ static SDValue combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root,
   // TODO - attempt to narrow Mask back to writemask size.
   bool IsEVEXShuffle =
       RootSizeInBits == 512 || (Subtarget.hasVLX() && RootSizeInBits >= 128);
-  if (IsEVEXShuffle && (RootVT.getScalarSizeInBits() != BaseMaskEltSizeInBits))
-    return SDValue();
 
   // TODO - handle 128/256-bit lane shuffles of 512-bit vectors.
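Note (editorial, not part of the patch): the LowerShift hunk widens element types that lack a native per-element shift, shifts in the wider type, and truncates. The value operand's extension must match the shift kind (sign-extend for SRA), while the amount is now explicitly zero-extended: ANY_EXTEND would leave the high bits of the amount undefined. A scalar model (amounts assumed < 8):

#include <cstdint>

// SRA on i8 via i16: sign-extend the value, zero-extend the amount.
uint8_t sra_i8_via_i16(uint8_t v, uint8_t amt) {
  int16_t wide = static_cast<int8_t>(v);     // SIGN_EXTEND the value
  uint16_t wamt = amt;                       // ZERO_EXTEND the amount
  return static_cast<uint8_t>(wide >> wamt); // shift, then TRUNCATE
}

// SRL on i8 via i16: zero-extend the value instead.
uint8_t srl_i8_via_i16(uint8_t v, uint8_t amt) {
  uint16_t wide = v;                         // ZERO_EXTEND the value
  return static_cast<uint8_t>(wide >> amt);
}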
@@ -28491,11 +28502,10 @@ static SDValue combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root,
 
     if (matchUnaryVectorShuffle(MaskVT, Mask, AllowFloatDomain, AllowIntDomain,
                                 V1, DL, DAG, Subtarget, Shuffle, ShuffleSrcVT,
-                                ShuffleVT)) {
+                                ShuffleVT) &&
+        (!IsEVEXShuffle || (NumRootElts == ShuffleVT.getVectorNumElements()))) {
       if (Depth == 1 && Root.getOpcode() == Shuffle)
         return SDValue(); // Nothing to do!
-      if (IsEVEXShuffle && (NumRootElts != ShuffleVT.getVectorNumElements()))
-        return SDValue(); // AVX512 Writemask clash.
       Res = DAG.getBitcast(ShuffleSrcVT, V1);
       DCI.AddToWorklist(Res.getNode());
       Res = DAG.getNode(Shuffle, DL, ShuffleVT, Res);
@@ -28505,11 +28515,10 @@ static SDValue combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root,
 
     if (matchUnaryPermuteVectorShuffle(MaskVT, Mask, Zeroable, AllowFloatDomain,
                                        AllowIntDomain, Subtarget, Shuffle,
-                                       ShuffleVT, PermuteImm)) {
+                                       ShuffleVT, PermuteImm) &&
+        (!IsEVEXShuffle || (NumRootElts == ShuffleVT.getVectorNumElements()))) {
       if (Depth == 1 && Root.getOpcode() == Shuffle)
         return SDValue(); // Nothing to do!
-      if (IsEVEXShuffle && (NumRootElts != ShuffleVT.getVectorNumElements()))
-        return SDValue(); // AVX512 Writemask clash.
       Res = DAG.getBitcast(ShuffleVT, V1);
       DCI.AddToWorklist(Res.getNode());
       Res = DAG.getNode(Shuffle, DL, ShuffleVT, Res,
@@ -28520,12 +28529,11 @@ static SDValue combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root,
   }
 
   if (matchBinaryVectorShuffle(MaskVT, Mask, AllowFloatDomain, AllowIntDomain,
-                               V1, V2, DL, DAG, Subtarget, Shuffle, ShuffleSrcVT,
-                               ShuffleVT, UnaryShuffle)) {
+                               V1, V2, DL, DAG, Subtarget, Shuffle,
+                               ShuffleSrcVT, ShuffleVT, UnaryShuffle) &&
+      (!IsEVEXShuffle || (NumRootElts == ShuffleVT.getVectorNumElements()))) {
     if (Depth == 1 && Root.getOpcode() == Shuffle)
       return SDValue(); // Nothing to do!
-    if (IsEVEXShuffle && (NumRootElts != ShuffleVT.getVectorNumElements()))
-      return SDValue(); // AVX512 Writemask clash.
     V1 = DAG.getBitcast(ShuffleSrcVT, V1);
     DCI.AddToWorklist(V1.getNode());
     V2 = DAG.getBitcast(ShuffleSrcVT, V2);
@@ -28538,11 +28546,10 @@ static SDValue combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root,
   if (matchBinaryPermuteVectorShuffle(MaskVT, Mask, Zeroable, AllowFloatDomain,
                                       AllowIntDomain, V1, V2, DL, DAG,
                                       Subtarget, Shuffle, ShuffleVT,
-                                      PermuteImm)) {
+                                      PermuteImm) &&
+      (!IsEVEXShuffle || (NumRootElts == ShuffleVT.getVectorNumElements()))) {
     if (Depth == 1 && Root.getOpcode() == Shuffle)
       return SDValue(); // Nothing to do!
-    if (IsEVEXShuffle && (NumRootElts != ShuffleVT.getVectorNumElements()))
-      return SDValue(); // AVX512 Writemask clash.
     V1 = DAG.getBitcast(ShuffleVT, V1);
     DCI.AddToWorklist(V1.getNode());
     V2 = DAG.getBitcast(ShuffleVT, V2);
@@ -28594,8 +28601,8 @@ static SDValue combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root,
     return SDValue();
 
   // Depth threshold above which we can efficiently use variable mask shuffles.
-  // TODO This should probably be target specific.
-  bool AllowVariableMask = (Depth >= 3) || HasVariableMask;
+  int VariableShuffleDepth = Subtarget.hasFastVariableShuffle() ? 2 : 3;
+  bool AllowVariableMask = (Depth >= VariableShuffleDepth) || HasVariableMask;
 
   bool MaskContainsZeros =
       any_of(Mask, [](int M) { return M == SM_SentinelZero; });
@@ -29698,17 +29705,18 @@ static SDValue combineTargetShuffle(SDValue N, SelectionDAG &DAG,
   return SDValue();
 }
 
-/// Returns true iff the shuffle node \p N can be replaced with ADDSUB
-/// operation. If true is returned then the operands of ADDSUB operation
+/// Returns true iff the shuffle node \p N can be replaced with ADDSUB(SUBADD)
+/// operation. If true is returned then the operands of ADDSUB(SUBADD) operation
 /// are written to the parameters \p Opnd0 and \p Opnd1.
 ///
-/// We combine shuffle to ADDSUB directly on the abstract vector shuffle nodes
+/// We combine shuffle to ADDSUB(SUBADD) directly on the abstract vector shuffle nodes
 /// so it is easier to generically match. We also insert dummy vector shuffle
 /// nodes for the operands which explicitly discard the lanes which are unused
 /// by this operation to try to flow through the rest of the combiner
 /// the fact that they're unused.
-static bool isAddSub(SDNode *N, const X86Subtarget &Subtarget,
-                     SDValue &Opnd0, SDValue &Opnd1) {
+static bool isAddSubOrSubAdd(SDNode *N, const X86Subtarget &Subtarget,
+                             SDValue &Opnd0, SDValue &Opnd1,
+                             bool matchSubAdd = false) {
   EVT VT = N->getValueType(0);
   if ((!Subtarget.hasSSE3() || (VT != MVT::v4f32 && VT != MVT::v2f64)) &&
@@ -29728,12 +29736,15 @@ static bool isAddSub(SDNode *N, const X86Subtarget &Subtarget,
   SDValue V1 = N->getOperand(0);
   SDValue V2 = N->getOperand(1);
 
-  // We require the first shuffle operand to be the FSUB node, and the second to
-  // be the FADD node.
-  if (V1.getOpcode() == ISD::FADD && V2.getOpcode() == ISD::FSUB) {
+  unsigned ExpectedOpcode = matchSubAdd ? ISD::FADD : ISD::FSUB;
+  unsigned NextExpectedOpcode = matchSubAdd ? ISD::FSUB : ISD::FADD;
+
+  // We require the first shuffle operand to be the ExpectedOpcode node,
+  // and the second to be the NextExpectedOpcode node.
+  if (V1.getOpcode() == NextExpectedOpcode && V2.getOpcode() == ExpectedOpcode) {
     ShuffleVectorSDNode::commuteMask(Mask);
     std::swap(V1, V2);
-  } else if (V1.getOpcode() != ISD::FSUB || V2.getOpcode() != ISD::FADD)
+  } else if (V1.getOpcode() != ExpectedOpcode || V2.getOpcode() != NextExpectedOpcode)
     return false;
 
   // If there are other uses of these operations we can't fold them.
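Note (editorial, not part of the patch): with matchSubAdd set, the matcher accepts the mirrored FADD/FSUB pattern, which the hunks below turn into X86ISD::FMSUBADD. In the same scalar model as the ADDSUB sketch earlier, SUBADD is its exact complement:

#include <array>
#include <cstddef>

// SUBADD reference: even lanes add, odd lanes subtract.
template <std::size_t N>
std::array<double, N> subadd(const std::array<double, N> &a,
                             const std::array<double, N> &b) {
  std::array<double, N> r{};
  for (std::size_t i = 0; i < N; ++i)
    r[i] = (i % 2) ? a[i] - b[i] : a[i] + b[i];
  return r;
}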
@@ -29767,7 +29778,7 @@ static SDValue combineShuffleToAddSubOrFMAddSub(SDNode *N,
                                                 const X86Subtarget &Subtarget,
                                                 SelectionDAG &DAG) {
   SDValue Opnd0, Opnd1;
-  if (!isAddSub(N, Subtarget, Opnd0, Opnd1))
+  if (!isAddSubOrSubAdd(N, Subtarget, Opnd0, Opnd1))
     return SDValue();
 
   EVT VT = N->getValueType(0);
@@ -29775,7 +29786,7 @@ static SDValue combineShuffleToAddSubOrFMAddSub(SDNode *N,
 
   // Try to generate X86ISD::FMADDSUB node here.
   SDValue Opnd2;
-  if (isFMAddSub(Subtarget, DAG, Opnd0, Opnd1, Opnd2, 2))
+  if (isFMAddSubOrFMSubAdd(Subtarget, DAG, Opnd0, Opnd1, Opnd2, 2))
     return DAG.getNode(X86ISD::FMADDSUB, DL, VT, Opnd0, Opnd1, Opnd2);
 
   // Do not generate X86ISD::ADDSUB node for 512-bit types even though
@@ -29787,6 +29798,26 @@ static SDValue combineShuffleToAddSubOrFMAddSub(SDNode *N,
   return DAG.getNode(X86ISD::ADDSUB, DL, VT, Opnd0, Opnd1);
 }
 
+/// \brief Try to combine a shuffle into a target-specific
+/// mul-sub-add node.
+static SDValue combineShuffleToFMSubAdd(SDNode *N,
+                                        const X86Subtarget &Subtarget,
+                                        SelectionDAG &DAG) {
+  SDValue Opnd0, Opnd1;
+  if (!isAddSubOrSubAdd(N, Subtarget, Opnd0, Opnd1, true))
+    return SDValue();
+
+  EVT VT = N->getValueType(0);
+  SDLoc DL(N);
+
+  // Try to generate X86ISD::FMSUBADD node here.
+  SDValue Opnd2;
+  if (isFMAddSubOrFMSubAdd(Subtarget, DAG, Opnd0, Opnd1, Opnd2, 2))
+    return DAG.getNode(X86ISD::FMSUBADD, DL, VT, Opnd0, Opnd1, Opnd2);
+
+  return SDValue();
+}
+
 // We are looking for a shuffle where both sources are concatenated with undef
 // and have a width that is half of the output's width. AVX2 has VPERMD/Q, so
 // if we can express this as a single-source shuffle, that's preferable.
@@ -29873,11 +29904,14 @@ static SDValue combineShuffle(SDNode *N, SelectionDAG &DAG,
   EVT VT = N->getValueType(0);
   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
   // If we have legalized the vector types, look for blends of FADD and FSUB
-  // nodes that we can fuse into an ADDSUB node.
+  // nodes that we can fuse into an ADDSUB, FMADDSUB, or FMSUBADD node.
   if (TLI.isTypeLegal(VT)) {
     if (SDValue AddSub = combineShuffleToAddSubOrFMAddSub(N, Subtarget, DAG))
       return AddSub;
 
+    if (SDValue FMSubAdd = combineShuffleToFMSubAdd(N, Subtarget, DAG))
+      return FMSubAdd;
+
     if (SDValue HAddSub = foldShuffleOfHorizOp(N))
       return HAddSub;
   }
@@ -30181,7 +30215,7 @@ static SDValue combineBitcastvxi1(SelectionDAG &DAG, SDValue BitCast,
     // For cases such as (i4 bitcast (v4i1 setcc v4i64 v1, v2))
     // sign-extend to a 256-bit operation to avoid truncation.
     if (N0->getOpcode() == ISD::SETCC && Subtarget.hasAVX() &&
-        N0->getOperand(0)->getValueType(0).is256BitVector()) {
+        N0->getOperand(0).getValueType().is256BitVector()) {
       SExtVT = MVT::v4i64;
       FPCastVT = MVT::v4f64;
     }
@@ -30194,8 +30228,8 @@ static SDValue combineBitcastvxi1(SelectionDAG &DAG, SDValue BitCast,
     // 256-bit because the shuffle is cheaper than sign extending the result of
     // the compare.
     if (N0->getOpcode() == ISD::SETCC && Subtarget.hasAVX() &&
-        (N0->getOperand(0)->getValueType(0).is256BitVector() ||
-         N0->getOperand(0)->getValueType(0).is512BitVector())) {
+        (N0->getOperand(0).getValueType().is256BitVector() ||
+         N0->getOperand(0).getValueType().is512BitVector())) {
       SExtVT = MVT::v8i32;
       FPCastVT = MVT::v8f32;
     }
@@ -30484,7 +30518,8 @@ static SDValue createPSADBW(SelectionDAG &DAG, const SDValue &Zext0,
   return DAG.getNode(X86ISD::PSADBW, DL, SadVT, SadOp0, SadOp1);
 }
 
-// Attempt to replace an min/max v8i16 horizontal reduction with PHMINPOSUW.
+// Attempt to replace an min/max v8i16/v16i8 horizontal reduction with
+// PHMINPOSUW.
 static SDValue combineHorizontalMinMaxResult(SDNode *Extract, SelectionDAG &DAG,
                                              const X86Subtarget &Subtarget) {
   // Bail without SSE41.
@@ -30492,7 +30527,7 @@ static SDValue combineHorizontalMinMaxResult(SDNode *Extract, SelectionDAG &DAG,
     return SDValue();
 
   EVT ExtractVT = Extract->getValueType(0);
-  if (ExtractVT != MVT::i16)
+  if (ExtractVT != MVT::i16 && ExtractVT != MVT::i8)
     return SDValue();
 
   // Check for SMAX/SMIN/UMAX/UMIN horizontal reduction patterns.
@@ -30504,7 +30539,7 @@ static SDValue combineHorizontalMinMaxResult(SDNode *Extract, SelectionDAG &DAG,
 
   EVT SrcVT = Src.getValueType();
   EVT SrcSVT = SrcVT.getScalarType();
-  if (SrcSVT != MVT::i16 || (SrcVT.getSizeInBits() % 128) != 0)
+  if (SrcSVT != ExtractVT || (SrcVT.getSizeInBits() % 128) != 0)
     return SDValue();
 
   SDLoc DL(Extract);
@@ -30520,22 +30555,39 @@ static SDValue combineHorizontalMinMaxResult(SDNode *Extract, SelectionDAG &DAG,
     SDValue Hi = extractSubVector(MinPos, NumSubElts, DAG, DL, SubSizeInBits);
     MinPos = DAG.getNode(BinOp, DL, SrcVT, Lo, Hi);
   }
-  assert(SrcVT == MVT::v8i16 && "Unexpected value type");
+  assert(((SrcVT == MVT::v8i16 && ExtractVT == MVT::i16) ||
+          (SrcVT == MVT::v16i8 && ExtractVT == MVT::i8)) &&
+         "Unexpected value type");
 
   // PHMINPOSUW applies to UMIN(v8i16), for SMIN/SMAX/UMAX we must apply a mask
   // to flip the value accordingly.
   SDValue Mask;
+  unsigned MaskEltsBits = ExtractVT.getSizeInBits();
   if (BinOp == ISD::SMAX)
-    Mask = DAG.getConstant(APInt::getSignedMaxValue(16), DL, SrcVT);
+    Mask = DAG.getConstant(APInt::getSignedMaxValue(MaskEltsBits), DL, SrcVT);
   else if (BinOp == ISD::SMIN)
-    Mask = DAG.getConstant(APInt::getSignedMinValue(16), DL, SrcVT);
+    Mask = DAG.getConstant(APInt::getSignedMinValue(MaskEltsBits), DL, SrcVT);
   else if (BinOp == ISD::UMAX)
-    Mask = DAG.getConstant(APInt::getAllOnesValue(16), DL, SrcVT);
+    Mask = DAG.getConstant(APInt::getAllOnesValue(MaskEltsBits), DL, SrcVT);
 
   if (Mask)
     MinPos = DAG.getNode(ISD::XOR, DL, SrcVT, Mask, MinPos);
 
-  MinPos = DAG.getNode(X86ISD::PHMINPOS, DL, SrcVT, MinPos);
+  // For v16i8 cases we need to perform UMIN on pairs of byte elements,
+  // shuffling each upper element down and insert zeros. This means that the
+  // v16i8 UMIN will leave the upper element as zero, performing zero-extension
+  // ready for the PHMINPOS.
+  if (ExtractVT == MVT::i8) {
+    SDValue Upper = DAG.getVectorShuffle(
+        SrcVT, DL, MinPos, getZeroVector(MVT::v16i8, Subtarget, DAG, DL),
+        {1, 16, 3, 16, 5, 16, 7, 16, 9, 16, 11, 16, 13, 16, 15, 16});
+    MinPos = DAG.getNode(ISD::UMIN, DL, SrcVT, MinPos, Upper);
+  }
+
+  // Perform the PHMINPOS on a v8i16 vector,
+  MinPos = DAG.getBitcast(MVT::v8i16, MinPos);
+  MinPos = DAG.getNode(X86ISD::PHMINPOS, DL, MVT::v8i16, MinPos);
+  MinPos = DAG.getBitcast(SrcVT, MinPos);
 
   if (Mask)
     MinPos = DAG.getNode(ISD::XOR, DL, SrcVT, Mask, MinPos);
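Note (editorial, not part of the patch): PHMINPOSUW computes the minimum unsigned i16 lane (plus its index). The hunk derives SMIN/SMAX/UMAX by XOR-flipping a bias before and after, and handles v16i8 by first taking a pairwise UMIN against the odd bytes with zeros shuffled in, so each i16 lane ends up holding one zero-extended byte. A scalar model:

#include <algorithm>
#include <array>
#include <cstddef>
#include <cstdint>

// PHMINPOSUW reference: smallest unsigned 16-bit lane.
uint16_t phminposuw(const std::array<uint16_t, 8> &v) {
  return *std::min_element(v.begin(), v.end());
}

// UMAX derived by bias flipping: UMAX(x) == ~UMIN(~x).
uint16_t umax_via_phminposuw(std::array<uint16_t, 8> v) {
  for (auto &e : v) e ^= 0xFFFF;
  return phminposuw(v) ^ 0xFFFF;
}

// v16i8 UMIN: pairwise byte UMIN leaves a zero-extended byte per i16 lane,
// ready for phminposuw.
uint8_t umin_v16i8(const std::array<uint8_t, 16> &b) {
  std::array<uint16_t, 8> v{};
  for (std::size_t i = 0; i < 8; ++i)
    v[i] = std::min(b[2 * i], b[2 * i + 1]);
  return static_cast<uint8_t>(phminposuw(v));
}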
@@ -30851,7 +30903,7 @@ static SDValue combineExtractVectorElt(SDNode *N, SelectionDAG &DAG,
   if (SDValue Cmp = combineHorizontalPredicateResult(N, DAG, Subtarget))
     return Cmp;
 
-  // Attempt to replace min/max v8i16 reductions with PHMINPOSUW.
+  // Attempt to replace min/max v8i16/v16i8 reductions with PHMINPOSUW.
   if (SDValue MinMax = combineHorizontalMinMaxResult(N, DAG, Subtarget))
     return MinMax;
 
@@ -32555,7 +32607,7 @@ static SDValue combineShiftRightArithmetic(SDNode *N, SelectionDAG &DAG) {
   // 1. MOVs can write to a register that differs from source
   // 2. MOVs accept memory operands
 
-  if (!VT.isInteger() || VT.isVector() || N1.getOpcode() != ISD::Constant ||
+  if (VT.isVector() || N1.getOpcode() != ISD::Constant ||
       N0.getOpcode() != ISD::SHL || !N0.hasOneUse() ||
       N0.getOperand(1).getOpcode() != ISD::Constant)
     return SDValue();
@@ -32569,11 +32621,11 @@ static SDValue combineShiftRightArithmetic(SDNode *N, SelectionDAG &DAG) {
   if (SarConst.isNegative())
     return SDValue();
 
-  for (MVT SVT : MVT::integer_valuetypes()) {
+  for (MVT SVT : { MVT::i8, MVT::i16, MVT::i32 }) {
     unsigned ShiftSize = SVT.getSizeInBits();
     // skipping types without corresponding sext/zext and
     // ShlConst that is not one of [56,48,32,24,16]
-    if (ShiftSize < 8 || ShiftSize > 64 || ShlConst != Size - ShiftSize)
+    if (ShiftSize >= Size || ShlConst != Size - ShiftSize)
       continue;
     SDLoc DL(N);
     SDValue NN =
@@ -32626,37 +32678,6 @@ static SDValue combineShiftRightLogical(SDNode *N, SelectionDAG &DAG) {
   return SDValue();
 }
 
-/// \brief Returns a vector of 0s if the node in input is a vector logical
-/// shift by a constant amount which is known to be bigger than or equal
-/// to the vector element size in bits.
-static SDValue performShiftToAllZeros(SDNode *N, SelectionDAG &DAG,
-                                      const X86Subtarget &Subtarget) {
-  EVT VT = N->getValueType(0);
-
-  if (VT != MVT::v2i64 && VT != MVT::v4i32 && VT != MVT::v8i16 &&
-      (!Subtarget.hasInt256() ||
-       (VT != MVT::v4i64 && VT != MVT::v8i32 && VT != MVT::v16i16)))
-    return SDValue();
-
-  SDValue Amt = N->getOperand(1);
-  SDLoc DL(N);
-  if (auto *AmtBV = dyn_cast<BuildVectorSDNode>(Amt))
-    if (auto *AmtSplat = AmtBV->getConstantSplatNode()) {
-      const APInt &ShiftAmt = AmtSplat->getAPIntValue();
-      unsigned MaxAmount =
-        VT.getSimpleVT().getScalarSizeInBits();
-
-      // SSE2/AVX2 logical shifts always return a vector of 0s
-      // if the shift amount is bigger than or equal to
-      // the element size. The constant shift amount will be
-      // encoded as a 8-bit immediate.
-      if (ShiftAmt.trunc(8).uge(MaxAmount))
-        return getZeroVector(VT.getSimpleVT(), Subtarget, DAG, DL);
-    }
-
-  return SDValue();
-}
-
 static SDValue combineShift(SDNode* N, SelectionDAG &DAG,
                             TargetLowering::DAGCombinerInfo &DCI,
                             const X86Subtarget &Subtarget) {
@@ -32672,11 +32693,6 @@ static SDValue combineShift(SDNode* N, SelectionDAG &DAG,
     if (SDValue V = combineShiftRightLogical(N, DAG))
       return V;
 
-  // Try to fold this logical shift into a zero vector.
-  if (N->getOpcode() != ISD::SRA)
-    if (SDValue V = performShiftToAllZeros(N, DAG, Subtarget))
-      return V;
-
   return SDValue();
 }
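Note (editorial, not part of the patch): restricting the combineShiftRightArithmetic loop to i8/i16/i32 drops inner types for which no movsx/movzx form exists. The fold itself rewrites a shift pair as a sign extension, e.g. on i64:

#include <cstdint>

// (sra (shl x, 56), 56) is exactly a sign extension of the low byte,
// which x86 can do with a single movsx instead of two shifts.
int64_t sra_shl_as_sext(int64_t x) {
  return static_cast<int8_t>(x); // equivalent to (x << 56) >> 56
}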
@@ -32996,21 +33012,20 @@ static SDValue combineANDXORWithAllOnesIntoANDNP(SDNode *N, SelectionDAG &DAG) {
 // register. In most cases we actually compare or select YMM-sized registers
 // and mixing the two types creates horrible code. This method optimizes
 // some of the transition sequences.
+// Even with AVX-512 this is still useful for removing casts around logical
+// operations on vXi1 mask types.
 static SDValue WidenMaskArithmetic(SDNode *N, SelectionDAG &DAG,
                                  TargetLowering::DAGCombinerInfo &DCI,
                                  const X86Subtarget &Subtarget) {
   EVT VT = N->getValueType(0);
-  if (!VT.is256BitVector())
-    return SDValue();
+  assert(VT.isVector() && "Expected vector type");
 
   assert((N->getOpcode() == ISD::ANY_EXTEND ||
           N->getOpcode() == ISD::ZERO_EXTEND ||
           N->getOpcode() == ISD::SIGN_EXTEND) && "Invalid Node");
 
   SDValue Narrow = N->getOperand(0);
-  EVT NarrowVT = Narrow->getValueType(0);
-  if (!NarrowVT.is128BitVector())
-    return SDValue();
+  EVT NarrowVT = Narrow.getValueType();
 
   if (Narrow->getOpcode() != ISD::XOR &&
       Narrow->getOpcode() != ISD::AND &&
@@ -33026,12 +33041,12 @@ static SDValue WidenMaskArithmetic(SDNode *N, SelectionDAG &DAG,
     return SDValue();
 
   // The type of the truncated inputs.
-  EVT WideVT = N0->getOperand(0)->getValueType(0);
-  if (WideVT != VT)
+  if (N0->getOperand(0).getValueType() != VT)
     return SDValue();
 
   // The right side has to be a 'trunc' or a constant vector.
-  bool RHSTrunc = N1.getOpcode() == ISD::TRUNCATE;
+  bool RHSTrunc = N1.getOpcode() == ISD::TRUNCATE &&
+                  N1.getOperand(0).getValueType() == VT;
   ConstantSDNode *RHSConstSplat = nullptr;
   if (auto *RHSBV = dyn_cast<BuildVectorSDNode>(N1))
     RHSConstSplat = RHSBV->getConstantSplatNode();
@@ -33040,37 +33055,31 @@ static SDValue WidenMaskArithmetic(SDNode *N, SelectionDAG &DAG,
 
   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
 
-  if (!TLI.isOperationLegalOrPromote(Narrow->getOpcode(), WideVT))
+  if (!TLI.isOperationLegalOrPromote(Narrow->getOpcode(), VT))
     return SDValue();
 
   // Set N0 and N1 to hold the inputs to the new wide operation.
   N0 = N0->getOperand(0);
   if (RHSConstSplat) {
-    N1 = DAG.getNode(ISD::ZERO_EXTEND, DL, WideVT.getVectorElementType(),
+    N1 = DAG.getNode(ISD::ZERO_EXTEND, DL, VT.getVectorElementType(),
                      SDValue(RHSConstSplat, 0));
-    N1 = DAG.getSplatBuildVector(WideVT, DL, N1);
+    N1 = DAG.getSplatBuildVector(VT, DL, N1);
   } else if (RHSTrunc) {
     N1 = N1->getOperand(0);
   }
 
   // Generate the wide operation.
-  SDValue Op = DAG.getNode(Narrow->getOpcode(), DL, WideVT, N0, N1);
+  SDValue Op = DAG.getNode(Narrow->getOpcode(), DL, VT, N0, N1);
   unsigned Opcode = N->getOpcode();
   switch (Opcode) {
+  default: llvm_unreachable("Unexpected opcode");
   case ISD::ANY_EXTEND:
     return Op;
-  case ISD::ZERO_EXTEND: {
-    unsigned InBits = NarrowVT.getScalarSizeInBits();
-    APInt Mask = APInt::getAllOnesValue(InBits);
-    Mask = Mask.zext(VT.getScalarSizeInBits());
-    return DAG.getNode(ISD::AND, DL, VT,
-                       Op, DAG.getConstant(Mask, DL, VT));
-  }
+  case ISD::ZERO_EXTEND:
+    return DAG.getZeroExtendInReg(Op, DL, NarrowVT.getScalarType());
   case ISD::SIGN_EXTEND:
     return DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, VT,
                        Op, DAG.getValueType(NarrowVT));
-  default:
-    llvm_unreachable("Unexpected opcode");
   }
 }
 
@@ -33882,16 +33891,6 @@ static SDValue detectAVGPattern(SDValue In, EVT VT, SelectionDAG &DAG,
   if (!Subtarget.hasSSE2())
     return SDValue();
 
-  if (Subtarget.hasBWI()) {
-    if (VT.getSizeInBits() > 512)
-      return SDValue();
-  } else if (Subtarget.hasAVX2()) {
-    if (VT.getSizeInBits() > 256)
-      return SDValue();
-  } else {
-    if (VT.getSizeInBits() > 128)
-      return SDValue();
-  }
 
   // Detect the following pattern:
   //
@@ -33903,7 +33902,6 @@ static SDValue detectAVGPattern(SDValue In, EVT VT, SelectionDAG &DAG,
   //   %6 = trunc <N x i32> %5 to <N x i8>
   //
   // In AVX512, the last instruction can also be a trunc store.
-
   if (In.getOpcode() != ISD::SRL)
     return SDValue();
 
@@ -33924,6 +33922,35 @@ static SDValue detectAVGPattern(SDValue In, EVT VT, SelectionDAG &DAG,
     return true;
   };
 
+  // Split vectors to legal target size and apply AVG.
+  auto LowerToAVG = [&](SDValue Op0, SDValue Op1) {
+    unsigned NumSubs = 1;
+    if (Subtarget.hasBWI()) {
+      if (VT.getSizeInBits() > 512)
+        NumSubs = VT.getSizeInBits() / 512;
+    } else if (Subtarget.hasAVX2()) {
+      if (VT.getSizeInBits() > 256)
+        NumSubs = VT.getSizeInBits() / 256;
+    } else {
+      if (VT.getSizeInBits() > 128)
+        NumSubs = VT.getSizeInBits() / 128;
+    }
+
+    if (NumSubs == 1)
+      return DAG.getNode(X86ISD::AVG, DL, VT, Op0, Op1);
+
+    SmallVector<SDValue, 4> Subs;
+    EVT SubVT = EVT::getVectorVT(*DAG.getContext(), VT.getScalarType(),
+                                 VT.getVectorNumElements() / NumSubs);
+    for (unsigned i = 0; i != NumSubs; ++i) {
+      unsigned Idx = i * SubVT.getVectorNumElements();
+      SDValue LHS = extractSubVector(Op0, Idx, DAG, DL, SubVT.getSizeInBits());
+      SDValue RHS = extractSubVector(Op1, Idx, DAG, DL, SubVT.getSizeInBits());
+      Subs.push_back(DAG.getNode(X86ISD::AVG, DL, SubVT, LHS, RHS));
+    }
+    return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, Subs);
+  };
+
   // Check if each element of the vector is left-shifted by one.
   auto LHS = In.getOperand(0);
   auto RHS = In.getOperand(1);
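Note (editorial, not part of the patch): instead of bailing out on vectors wider than the legal register size, detectAVGPattern now splits them into 128/256/512-bit pieces and emits one X86ISD::AVG per piece. The pattern itself is the rounded unsigned average, which pavgb/pavgw compute without any widening:

#include <cstdint>

// Scalar model of the IR idiom being matched:
// trunc((zext(a) + zext(b) + 1) >> 1) == rounded unsigned average.
uint8_t avg_round_u8(uint8_t a, uint8_t b) {
  return static_cast<uint8_t>((uint16_t(a) + uint16_t(b) + 1) >> 1);
}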
@@ -33947,8 +33974,7 @@ static SDValue detectAVGPattern(SDValue In, EVT VT, SelectionDAG &DAG,
     SDValue VecOnes = DAG.getConstant(1, DL, InVT);
     Operands[1] = DAG.getNode(ISD::SUB, DL, InVT, Operands[1], VecOnes);
     Operands[1] = DAG.getNode(ISD::TRUNCATE, DL, VT, Operands[1]);
-    return DAG.getNode(X86ISD::AVG, DL, VT, Operands[0].getOperand(0),
-                       Operands[1]);
+    return LowerToAVG(Operands[0].getOperand(0), Operands[1]);
   }
 
   if (Operands[0].getOpcode() == ISD::ADD)
@@ -33972,8 +33998,7 @@ static SDValue detectAVGPattern(SDValue In, EVT VT, SelectionDAG &DAG,
         return SDValue();
 
     // The pattern is detected, emit X86ISD::AVG instruction.
-    return DAG.getNode(X86ISD::AVG, DL, VT, Operands[0].getOperand(0),
-                       Operands[1].getOperand(0));
+    return LowerToAVG(Operands[0].getOperand(0), Operands[1].getOperand(0));
   }
 
   return SDValue();
@@ -35872,14 +35897,8 @@ static SDValue combineSext(SDNode *N, SelectionDAG &DAG,
   if (SDValue NewCMov = combineToExtendCMOV(N, DAG))
     return NewCMov;
 
-  if (!DCI.isBeforeLegalizeOps()) {
-    if (InVT == MVT::i1) {
-      SDValue Zero = DAG.getConstant(0, DL, VT);
-      SDValue AllOnes = DAG.getAllOnesConstant(DL, VT);
-      return DAG.getSelect(DL, VT, N0, AllOnes, Zero);
-    }
+  if (!DCI.isBeforeLegalizeOps())
     return SDValue();
-  }
 
   if (InVT == MVT::i1 && N0.getOpcode() == ISD::XOR &&
       isAllOnesConstant(N0.getOperand(1)) && N0.hasOneUse()) {
@@ -35897,7 +35916,7 @@ static SDValue combineSext(SDNode *N, SelectionDAG &DAG,
   if (SDValue V = combineToExtendBoolVectorInReg(N, DAG, DCI, Subtarget))
     return V;
 
-  if (Subtarget.hasAVX() && VT.is256BitVector())
+  if (VT.isVector())
     if (SDValue R = WidenMaskArithmetic(N, DAG, DCI, Subtarget))
       return R;
 
@@ -36089,7 +36108,7 @@ static SDValue combineZext(SDNode *N, SelectionDAG &DAG,
   if (SDValue V = combineToExtendBoolVectorInReg(N, DAG, DCI, Subtarget))
     return V;
 
-  if (VT.is256BitVector())
+  if (VT.isVector())
     if (SDValue R = WidenMaskArithmetic(N, DAG, DCI, Subtarget))
       return R;
 
@@ -36244,39 +36263,54 @@ static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
                                     const X86Subtarget &Subtarget) {
   SDLoc DL(N);
 
-  // Pre-shrink oversized index elements to avoid triggering scalarization.
-  if (DCI.isBeforeLegalize()) {
+  if (DCI.isBeforeLegalizeOps()) {
     SDValue Index = N->getOperand(4);
-    if (Index.getScalarValueSizeInBits() > 64) {
-      EVT IndexVT = EVT::getVectorVT(*DAG.getContext(), MVT::i64,
+    // Remove any sign extends from 32 or smaller to larger than 32.
+    // Only do this before LegalizeOps in case we need the sign extend for
+    // legalization.
+    if (Index.getOpcode() == ISD::SIGN_EXTEND) {
+      if (Index.getScalarValueSizeInBits() > 32 &&
+          Index.getOperand(0).getScalarValueSizeInBits() <= 32) {
+        SmallVector<SDValue, 5> NewOps(N->op_begin(), N->op_end());
+        NewOps[4] = Index.getOperand(0);
+        DAG.UpdateNodeOperands(N, NewOps);
+        // The original sign extend has less users, add back to worklist in case
+        // it needs to be removed
+        DCI.AddToWorklist(Index.getNode());
+        DCI.AddToWorklist(N);
+        return SDValue(N, 0);
+      }
+    }
+
+    // Make sure the index is either i32 or i64
+    unsigned ScalarSize = Index.getScalarValueSizeInBits();
+    if (ScalarSize != 32 && ScalarSize != 64) {
+      MVT EltVT = ScalarSize > 32 ? MVT::i64 : MVT::i32;
+      EVT IndexVT = EVT::getVectorVT(*DAG.getContext(), EltVT,
                                     Index.getValueType().getVectorNumElements());
-      SDValue Trunc = DAG.getNode(ISD::TRUNCATE, DL, IndexVT, Index);
+      Index = DAG.getSExtOrTrunc(Index, DL, IndexVT);
       SmallVector<SDValue, 5> NewOps(N->op_begin(), N->op_end());
-      NewOps[4] = Trunc;
+      NewOps[4] = Index;
       DAG.UpdateNodeOperands(N, NewOps);
       DCI.AddToWorklist(N);
       return SDValue(N, 0);
     }
-  }
 
-  // Try to remove sign extends from i32 to i64 on the index.
-  // Only do this before legalize in case we are relying on it for
-  // legalization.
-  // TODO: We should maybe remove any sign extend once we learn how to sign
-  // extend narrow index during lowering.
-  if (DCI.isBeforeLegalizeOps()) {
-    SDValue Index = N->getOperand(4);
-    if (Index.getScalarValueSizeInBits() == 64 &&
-        Index.getOpcode() == ISD::SIGN_EXTEND &&
+    // Try to remove zero extends from 32->64 if we know the sign bit of
+    // the input is zero.
+    if (Index.getOpcode() == ISD::ZERO_EXTEND &&
+        Index.getScalarValueSizeInBits() == 64 &&
         Index.getOperand(0).getScalarValueSizeInBits() == 32) {
-      SmallVector<SDValue, 5> NewOps(N->op_begin(), N->op_end());
-      NewOps[4] = Index.getOperand(0);
-      DAG.UpdateNodeOperands(N, NewOps);
-      // The original sign extend has less users, add back to worklist in case
-      // it needs to be removed.
-      DCI.AddToWorklist(Index.getNode());
-      DCI.AddToWorklist(N);
-      return SDValue(N, 0);
+      if (DAG.SignBitIsZero(Index.getOperand(0))) {
+        SmallVector<SDValue, 5> NewOps(N->op_begin(), N->op_end());
+        NewOps[4] = Index.getOperand(0);
+        DAG.UpdateNodeOperands(N, NewOps);
+        // The original zero extend has less users, add back to worklist in case
+        // it needs to be removed
+        DCI.AddToWorklist(Index.getNode());
+        DCI.AddToWorklist(N);
+        return SDValue(N, 0);
+      }
     }
   }
 
@@ -36288,6 +36322,7 @@ static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
     SmallVector<SDValue, 5> NewOps(N->op_begin(), N->op_end());
     NewOps[2] = Mask.getOperand(0);
     DAG.UpdateNodeOperands(N, NewOps);
+    return SDValue(N, 0);
   }
 
   // With AVX2 we only demand the upper bit of the mask.
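Note (editorial, not part of the patch): hardware gathers and scatters only take i32 or i64 indices, so the rewritten code normalizes everything else: sign extends from 32-bit or narrower sources are peeled off, sub-i32 indices are sign-extended to i32, and a zero extend from i32 to i64 is dropped when the sign bit is provably zero. A scalar model of the index normalization:

#include <cstdint>
#include <vector>

// Gather with i16 indices: each index is sign-extended to i32 first,
// mirroring DAG.getSExtOrTrunc in the hunk above.
std::vector<float> gather(const float *base, const std::vector<int16_t> &idx) {
  std::vector<float> out;
  out.reserve(idx.size());
  for (int16_t i : idx)
    out.push_back(base[static_cast<int32_t>(i)]);
  return out;
}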
@@ -36356,7 +36391,7 @@ static SDValue combineVectorCompareAndMaskUnaryOp(SDNode *N,
   EVT VT = N->getValueType(0);
   if (!VT.isVector() || N->getOperand(0)->getOpcode() != ISD::AND ||
       N->getOperand(0)->getOperand(0)->getOpcode() != ISD::SETCC ||
-      VT.getSizeInBits() != N->getOperand(0)->getValueType(0).getSizeInBits())
+      VT.getSizeInBits() != N->getOperand(0).getValueSizeInBits())
     return SDValue();
 
   // Now check that the other operand of the AND is a constant. We could
