Diffstat (limited to 'lib/Target/X86/X86ISelLowering.cpp')
-rw-r--r-- | lib/Target/X86/X86ISelLowering.cpp | 8626
1 file changed, 5390 insertions, 3236 deletions
diff --git a/lib/Target/X86/X86ISelLowering.cpp b/lib/Target/X86/X86ISelLowering.cpp index 9edd799779c7..7dcdb7967058 100644 --- a/lib/Target/X86/X86ISelLowering.cpp +++ b/lib/Target/X86/X86ISelLowering.cpp @@ -103,7 +103,7 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, bool UseX87 = !Subtarget.useSoftFloat() && Subtarget.hasX87(); X86ScalarSSEf64 = Subtarget.hasSSE2(); X86ScalarSSEf32 = Subtarget.hasSSE1(); - MVT PtrVT = MVT::getIntegerVT(8 * TM.getPointerSize()); + MVT PtrVT = MVT::getIntegerVT(TM.getPointerSizeInBits(0)); // Set up the TargetLowering object. @@ -216,6 +216,8 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, // We have an algorithm for SSE2, and we turn this into a 64-bit // FILD or VCVTUSI2SS/SD for other targets. setOperationAction(ISD::UINT_TO_FP , MVT::i32 , Custom); + } else { + setOperationAction(ISD::UINT_TO_FP , MVT::i32 , Expand); } // Promote i1/i8 SINT_TO_FP to larger SINT_TO_FP's, as X86 doesn't have @@ -235,7 +237,7 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, } } else { setOperationAction(ISD::SINT_TO_FP , MVT::i16 , Promote); - setOperationAction(ISD::SINT_TO_FP , MVT::i32 , Promote); + setOperationAction(ISD::SINT_TO_FP , MVT::i32 , Expand); } // Promote i1/i8 FP_TO_SINT to larger FP_TO_SINTS's, as X86 doesn't have @@ -611,7 +613,7 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, // Long double always uses X87, except f128 in MMX. if (UseX87) { if (Subtarget.is64Bit() && Subtarget.hasMMX()) { - addRegisterClass(MVT::f128, &X86::FR128RegClass); + addRegisterClass(MVT::f128, &X86::VR128RegClass); ValueTypeActions.setTypeAction(MVT::f128, TypeSoftenFloat); setOperationAction(ISD::FABS , MVT::f128, Custom); setOperationAction(ISD::FNEG , MVT::f128, Custom); @@ -790,19 +792,33 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, setOperationAction(ISD::FABS, MVT::v2f64, Custom); setOperationAction(ISD::FCOPYSIGN, MVT::v2f64, Custom); - setOperationAction(ISD::SMAX, MVT::v8i16, Legal); - setOperationAction(ISD::UMAX, MVT::v16i8, Legal); - setOperationAction(ISD::SMIN, MVT::v8i16, Legal); - setOperationAction(ISD::UMIN, MVT::v16i8, Legal); + for (auto VT : { MVT::v16i8, MVT::v8i16, MVT::v4i32, MVT::v2i64 }) { + setOperationAction(ISD::SMAX, VT, VT == MVT::v8i16 ? Legal : Custom); + setOperationAction(ISD::SMIN, VT, VT == MVT::v8i16 ? Legal : Custom); + setOperationAction(ISD::UMAX, VT, VT == MVT::v16i8 ? Legal : Custom); + setOperationAction(ISD::UMIN, VT, VT == MVT::v16i8 ? Legal : Custom); + } setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::v8i16, Custom); setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::v4i32, Custom); setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::v4f32, Custom); + // Provide custom widening for v2f32 setcc. This is really for VLX when + // setcc result type returns v2i1/v4i1 vector for v2f32/v4f32 leading to + // type legalization changing the result type to v4i1 during widening. + // It works fine for SSE2 and is probably faster so no need to qualify with + // VLX support. + setOperationAction(ISD::SETCC, MVT::v2i32, Custom); + for (auto VT : { MVT::v16i8, MVT::v8i16, MVT::v4i32, MVT::v2i64 }) { setOperationAction(ISD::SETCC, VT, Custom); setOperationAction(ISD::CTPOP, VT, Custom); setOperationAction(ISD::CTTZ, VT, Custom); + + // The condition codes aren't legal in SSE/AVX and under AVX512 we use + // setcc all the way to isel and prefer SETGT in some isel patterns. 
+ setCondCodeAction(ISD::SETLT, VT, Custom); + setCondCodeAction(ISD::SETLE, VT, Custom); } for (auto VT : { MVT::v16i8, MVT::v8i16, MVT::v4i32 }) { @@ -874,6 +890,8 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, setOperationAction(ISD::BITCAST, MVT::v2i32, Custom); setOperationAction(ISD::BITCAST, MVT::v4i16, Custom); setOperationAction(ISD::BITCAST, MVT::v8i8, Custom); + if (!Subtarget.hasAVX512()) + setOperationAction(ISD::BITCAST, MVT::v16i1, Custom); setOperationAction(ISD::SIGN_EXTEND_VECTOR_INREG, MVT::v2i64, Custom); setOperationAction(ISD::SIGN_EXTEND_VECTOR_INREG, MVT::v4i32, Custom); @@ -886,6 +904,10 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, setOperationAction(ISD::SHL, VT, Custom); setOperationAction(ISD::SRA, VT, Custom); } + + setOperationAction(ISD::ROTL, MVT::v4i32, Custom); + setOperationAction(ISD::ROTL, MVT::v8i16, Custom); + setOperationAction(ISD::ROTL, MVT::v16i8, Custom); } if (!Subtarget.useSoftFloat() && Subtarget.hasSSSE3()) { @@ -967,7 +989,7 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, setOperationAction(ISD::BITREVERSE, VT, Custom); } - if (!Subtarget.useSoftFloat() && Subtarget.hasFp256()) { + if (!Subtarget.useSoftFloat() && Subtarget.hasAVX()) { bool HasInt256 = Subtarget.hasInt256(); addRegisterClass(MVT::v32i8, Subtarget.hasVLX() ? &X86::VR256XRegClass @@ -996,13 +1018,16 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, // (fp_to_int:v8i16 (v8f32 ..)) requires the result type to be promoted // even though v8i16 is a legal type. - setOperationAction(ISD::FP_TO_SINT, MVT::v8i16, Promote); - setOperationAction(ISD::FP_TO_UINT, MVT::v8i16, Promote); + setOperationPromotedToType(ISD::FP_TO_SINT, MVT::v8i16, MVT::v8i32); + setOperationPromotedToType(ISD::FP_TO_UINT, MVT::v8i16, MVT::v8i32); setOperationAction(ISD::FP_TO_SINT, MVT::v8i32, Legal); setOperationAction(ISD::SINT_TO_FP, MVT::v8i32, Legal); setOperationAction(ISD::FP_ROUND, MVT::v4f32, Legal); + if (!Subtarget.hasAVX512()) + setOperationAction(ISD::BITCAST, MVT::v32i1, Custom); + for (MVT VT : MVT::fp_vector_valuetypes()) setLoadExtAction(ISD::EXTLOAD, VT, MVT::v4f32, Legal); @@ -1014,6 +1039,10 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, setOperationAction(ISD::SRA, VT, Custom); } + setOperationAction(ISD::ROTL, MVT::v8i32, Custom); + setOperationAction(ISD::ROTL, MVT::v16i16, Custom); + setOperationAction(ISD::ROTL, MVT::v32i8, Custom); + setOperationAction(ISD::SELECT, MVT::v4f64, Custom); setOperationAction(ISD::SELECT, MVT::v4i64, Custom); setOperationAction(ISD::SELECT, MVT::v8f32, Custom); @@ -1034,6 +1063,11 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, setOperationAction(ISD::CTPOP, VT, Custom); setOperationAction(ISD::CTTZ, VT, Custom); setOperationAction(ISD::CTLZ, VT, Custom); + + // The condition codes aren't legal in SSE/AVX and under AVX512 we use + // setcc all the way to isel and prefer SETGT in some isel patterns. 
+ setCondCodeAction(ISD::SETLT, VT, Custom); + setCondCodeAction(ISD::SETLE, VT, Custom); } if (Subtarget.hasAnyFMA()) { @@ -1060,6 +1094,11 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, setOperationAction(ISD::MULHU, MVT::v32i8, Custom); setOperationAction(ISD::MULHS, MVT::v32i8, Custom); + setOperationAction(ISD::SMAX, MVT::v4i64, Custom); + setOperationAction(ISD::UMAX, MVT::v4i64, Custom); + setOperationAction(ISD::SMIN, MVT::v4i64, Custom); + setOperationAction(ISD::UMIN, MVT::v4i64, Custom); + for (auto VT : { MVT::v32i8, MVT::v16i16, MVT::v8i32 }) { setOperationAction(ISD::ABS, VT, HasInt256 ? Legal : Custom); setOperationAction(ISD::SMAX, VT, HasInt256 ? Legal : Custom); @@ -1137,13 +1176,13 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, } } + // This block controls legalization of the mask vector sizes that are + // available with AVX512. 512-bit vectors are in a separate block controlled + // by useAVX512Regs. if (!Subtarget.useSoftFloat() && Subtarget.hasAVX512()) { - addRegisterClass(MVT::v16i32, &X86::VR512RegClass); - addRegisterClass(MVT::v16f32, &X86::VR512RegClass); - addRegisterClass(MVT::v8i64, &X86::VR512RegClass); - addRegisterClass(MVT::v8f64, &X86::VR512RegClass); - addRegisterClass(MVT::v1i1, &X86::VK1RegClass); + addRegisterClass(MVT::v2i1, &X86::VK2RegClass); + addRegisterClass(MVT::v4i1, &X86::VK4RegClass); addRegisterClass(MVT::v8i1, &X86::VK8RegClass); addRegisterClass(MVT::v16i1, &X86::VK16RegClass); @@ -1151,24 +1190,34 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::v1i1, Custom); setOperationAction(ISD::BUILD_VECTOR, MVT::v1i1, Custom); - setOperationAction(ISD::SINT_TO_FP, MVT::v16i1, Custom); - setOperationAction(ISD::UINT_TO_FP, MVT::v16i1, Custom); - setOperationAction(ISD::SINT_TO_FP, MVT::v8i1, Custom); - setOperationAction(ISD::UINT_TO_FP, MVT::v8i1, Custom); - setOperationAction(ISD::SINT_TO_FP, MVT::v4i1, Custom); - setOperationAction(ISD::UINT_TO_FP, MVT::v4i1, Custom); - setOperationAction(ISD::SINT_TO_FP, MVT::v2i1, Custom); - setOperationAction(ISD::UINT_TO_FP, MVT::v2i1, Custom); - - // Extends of v16i1/v8i1 to 128-bit vectors. - setOperationAction(ISD::SIGN_EXTEND, MVT::v16i8, Custom); - setOperationAction(ISD::ZERO_EXTEND, MVT::v16i8, Custom); - setOperationAction(ISD::ANY_EXTEND, MVT::v16i8, Custom); - setOperationAction(ISD::SIGN_EXTEND, MVT::v8i16, Custom); - setOperationAction(ISD::ZERO_EXTEND, MVT::v8i16, Custom); - setOperationAction(ISD::ANY_EXTEND, MVT::v8i16, Custom); - - for (auto VT : { MVT::v8i1, MVT::v16i1 }) { + setOperationPromotedToType(ISD::FP_TO_SINT, MVT::v8i1, MVT::v8i32); + setOperationPromotedToType(ISD::FP_TO_UINT, MVT::v8i1, MVT::v8i32); + setOperationPromotedToType(ISD::FP_TO_SINT, MVT::v4i1, MVT::v4i32); + setOperationPromotedToType(ISD::FP_TO_UINT, MVT::v4i1, MVT::v4i32); + setOperationAction(ISD::FP_TO_SINT, MVT::v2i1, Custom); + setOperationAction(ISD::FP_TO_UINT, MVT::v2i1, Custom); + + // There is no byte sized k-register load or store without AVX512DQ. 
+ if (!Subtarget.hasDQI()) { + setOperationAction(ISD::LOAD, MVT::v1i1, Custom); + setOperationAction(ISD::LOAD, MVT::v2i1, Custom); + setOperationAction(ISD::LOAD, MVT::v4i1, Custom); + setOperationAction(ISD::LOAD, MVT::v8i1, Custom); + + setOperationAction(ISD::STORE, MVT::v1i1, Custom); + setOperationAction(ISD::STORE, MVT::v2i1, Custom); + setOperationAction(ISD::STORE, MVT::v4i1, Custom); + setOperationAction(ISD::STORE, MVT::v8i1, Custom); + } + + // Extends of v16i1/v8i1/v4i1/v2i1 to 128-bit vectors. + for (auto VT : { MVT::v16i8, MVT::v8i16, MVT::v4i32, MVT::v2i64 }) { + setOperationAction(ISD::SIGN_EXTEND, VT, Custom); + setOperationAction(ISD::ZERO_EXTEND, VT, Custom); + setOperationAction(ISD::ANY_EXTEND, VT, Custom); + } + + for (auto VT : { MVT::v2i1, MVT::v4i1, MVT::v8i1, MVT::v16i1 }) { setOperationAction(ISD::ADD, VT, Custom); setOperationAction(ISD::SUB, VT, Custom); setOperationAction(ISD::MUL, VT, Custom); @@ -1184,11 +1233,24 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, } setOperationAction(ISD::CONCAT_VECTORS, MVT::v16i1, Custom); + setOperationAction(ISD::CONCAT_VECTORS, MVT::v8i1, Custom); + setOperationAction(ISD::CONCAT_VECTORS, MVT::v4i1, Custom); + setOperationAction(ISD::INSERT_SUBVECTOR, MVT::v2i1, Custom); + setOperationAction(ISD::INSERT_SUBVECTOR, MVT::v4i1, Custom); setOperationAction(ISD::INSERT_SUBVECTOR, MVT::v8i1, Custom); setOperationAction(ISD::INSERT_SUBVECTOR, MVT::v16i1, Custom); - for (auto VT : { MVT::v1i1, MVT::v2i1, MVT::v4i1, MVT::v8i1, - MVT::v16i1, MVT::v32i1, MVT::v64i1 }) - setOperationAction(ISD::EXTRACT_SUBVECTOR, VT, Legal); + for (auto VT : { MVT::v1i1, MVT::v2i1, MVT::v4i1, MVT::v8i1 }) + setOperationAction(ISD::EXTRACT_SUBVECTOR, VT, Custom); + } + + // This block controls legalization for 512-bit operations with 32/64 bit + // elements. 512-bits can be disabled based on prefer-vector-width and + // required-vector-width function attributes. 
+ if (!Subtarget.useSoftFloat() && Subtarget.useAVX512Regs()) { + addRegisterClass(MVT::v16i32, &X86::VR512RegClass); + addRegisterClass(MVT::v16f32, &X86::VR512RegClass); + addRegisterClass(MVT::v8i64, &X86::VR512RegClass); + addRegisterClass(MVT::v8f64, &X86::VR512RegClass); for (MVT VT : MVT::fp_vector_valuetypes()) setLoadExtAction(ISD::EXTLOAD, VT, MVT::v8f32, Legal); @@ -1201,16 +1263,6 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, setLoadExtAction(ExtType, MVT::v8i64, MVT::v8i32, Legal); } - for (MVT VT : {MVT::v2i64, MVT::v4i32, MVT::v8i32, MVT::v4i64, MVT::v8i16, - MVT::v16i8, MVT::v16i16, MVT::v32i8, MVT::v16i32, - MVT::v8i64, MVT::v32i16, MVT::v64i8}) { - MVT MaskVT = MVT::getVectorVT(MVT::i1, VT.getVectorNumElements()); - setLoadExtAction(ISD::SEXTLOAD, VT, MaskVT, Custom); - setLoadExtAction(ISD::ZEXTLOAD, VT, MaskVT, Custom); - setLoadExtAction(ISD::EXTLOAD, VT, MaskVT, Custom); - setTruncStoreAction(VT, MaskVT, Custom); - } - for (MVT VT : { MVT::v16f32, MVT::v8f64 }) { setOperationAction(ISD::FNEG, VT, Custom); setOperationAction(ISD::FABS, VT, Custom); @@ -1219,11 +1271,13 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, } setOperationAction(ISD::FP_TO_SINT, MVT::v16i32, Legal); - setOperationAction(ISD::FP_TO_SINT, MVT::v16i16, Promote); - setOperationAction(ISD::FP_TO_SINT, MVT::v16i8, Promote); + setOperationPromotedToType(ISD::FP_TO_SINT, MVT::v16i16, MVT::v16i32); + setOperationPromotedToType(ISD::FP_TO_SINT, MVT::v16i8, MVT::v16i32); + setOperationPromotedToType(ISD::FP_TO_SINT, MVT::v16i1, MVT::v16i32); setOperationAction(ISD::FP_TO_UINT, MVT::v16i32, Legal); - setOperationAction(ISD::FP_TO_UINT, MVT::v16i8, Promote); - setOperationAction(ISD::FP_TO_UINT, MVT::v16i16, Promote); + setOperationPromotedToType(ISD::FP_TO_UINT, MVT::v16i1, MVT::v16i32); + setOperationPromotedToType(ISD::FP_TO_UINT, MVT::v16i8, MVT::v16i32); + setOperationPromotedToType(ISD::FP_TO_UINT, MVT::v16i16, MVT::v16i32); setOperationAction(ISD::SINT_TO_FP, MVT::v16i32, Legal); setOperationAction(ISD::UINT_TO_FP, MVT::v16i32, Legal); @@ -1296,6 +1350,12 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, setOperationAction(ISD::CTTZ, VT, Custom); setOperationAction(ISD::ROTL, VT, Custom); setOperationAction(ISD::ROTR, VT, Custom); + setOperationAction(ISD::SETCC, VT, Custom); + + // The condition codes aren't legal in SSE/AVX and under AVX512 we use + // setcc all the way to isel and prefer SETGT in some isel patterns. + setCondCodeAction(ISD::SETLT, VT, Custom); + setCondCodeAction(ISD::SETLE, VT, Custom); } // Need to promote to 64-bit even though we have 32-bit masked instructions @@ -1310,6 +1370,8 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, setOperationAction(ISD::UINT_TO_FP, MVT::v8i64, Legal); setOperationAction(ISD::FP_TO_SINT, MVT::v8i64, Legal); setOperationAction(ISD::FP_TO_UINT, MVT::v8i64, Legal); + + setOperationAction(ISD::MUL, MVT::v8i64, Legal); } if (Subtarget.hasCDI()) { @@ -1349,10 +1411,18 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, setOperationPromotedToType(ISD::LOAD, VT, MVT::v8i64); setOperationPromotedToType(ISD::SELECT, VT, MVT::v8i64); } + + // Need to custom split v32i16/v64i8 bitcasts. 
+ if (!Subtarget.hasBWI()) { + setOperationAction(ISD::BITCAST, MVT::v32i16, Custom); + setOperationAction(ISD::BITCAST, MVT::v64i8, Custom); + } }// has AVX-512 - if (!Subtarget.useSoftFloat() && - (Subtarget.hasAVX512() || Subtarget.hasVLX())) { + // This block controls legalization for operations that don't have + // pre-AVX512 equivalents. Without VLX we use 512-bit operations for + // narrower widths. + if (!Subtarget.useSoftFloat() && Subtarget.hasAVX512()) { // These operations are handled on non-VLX by artificially widening in // isel patterns. // TODO: Custom widen in lowering on non-VLX and drop the isel patterns? @@ -1376,6 +1446,10 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, setOperationAction(ISD::ROTR, VT, Custom); } + // Custom legalize 2x32 to get a little better code. + setOperationAction(ISD::MSCATTER, MVT::v2f32, Custom); + setOperationAction(ISD::MSCATTER, MVT::v2i32, Custom); + for (auto VT : { MVT::v4i32, MVT::v8i32, MVT::v2i64, MVT::v4i64, MVT::v4f32, MVT::v8f32, MVT::v2f64, MVT::v4f64 }) setOperationAction(ISD::MSCATTER, VT, Custom); @@ -1386,6 +1460,8 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, setOperationAction(ISD::UINT_TO_FP, VT, Legal); setOperationAction(ISD::FP_TO_SINT, VT, Legal); setOperationAction(ISD::FP_TO_UINT, VT, Legal); + + setOperationAction(ISD::MUL, VT, Legal); } } @@ -1402,10 +1478,10 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, } } + // This block control legalization of v32i1/v64i1 which are available with + // AVX512BW. 512-bit v32i16 and v64i8 vector legalization is controlled with + // useBWIRegs. if (!Subtarget.useSoftFloat() && Subtarget.hasBWI()) { - addRegisterClass(MVT::v32i16, &X86::VR512RegClass); - addRegisterClass(MVT::v64i8, &X86::VR512RegClass); - addRegisterClass(MVT::v32i1, &X86::VK32RegClass); addRegisterClass(MVT::v64i1, &X86::VK64RegClass); @@ -1428,11 +1504,22 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, setOperationAction(ISD::CONCAT_VECTORS, MVT::v64i1, Custom); setOperationAction(ISD::INSERT_SUBVECTOR, MVT::v32i1, Custom); setOperationAction(ISD::INSERT_SUBVECTOR, MVT::v64i1, Custom); + for (auto VT : { MVT::v16i1, MVT::v32i1 }) + setOperationAction(ISD::EXTRACT_SUBVECTOR, VT, Custom); // Extends from v32i1 masks to 256-bit vectors. setOperationAction(ISD::SIGN_EXTEND, MVT::v32i8, Custom); setOperationAction(ISD::ZERO_EXTEND, MVT::v32i8, Custom); setOperationAction(ISD::ANY_EXTEND, MVT::v32i8, Custom); + } + + // This block controls legalization for v32i16 and v64i8. 512-bits can be + // disabled based on prefer-vector-width and required-vector-width function + // attributes. + if (!Subtarget.useSoftFloat() && Subtarget.useBWIRegs()) { + addRegisterClass(MVT::v32i16, &X86::VR512RegClass); + addRegisterClass(MVT::v64i8, &X86::VR512RegClass); + // Extends from v64i1 masks to 512-bit vectors. 
setOperationAction(ISD::SIGN_EXTEND, MVT::v64i8, Custom); setOperationAction(ISD::ZERO_EXTEND, MVT::v64i8, Custom); @@ -1482,6 +1569,7 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, setOperationAction(ISD::UMAX, VT, Legal); setOperationAction(ISD::SMIN, VT, Legal); setOperationAction(ISD::UMIN, VT, Legal); + setOperationAction(ISD::SETCC, VT, Custom); setOperationPromotedToType(ISD::AND, VT, MVT::v8i64); setOperationPromotedToType(ISD::OR, VT, MVT::v8i64); @@ -1498,8 +1586,7 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, } } - if (!Subtarget.useSoftFloat() && Subtarget.hasBWI() && - (Subtarget.hasAVX512() || Subtarget.hasVLX())) { + if (!Subtarget.useSoftFloat() && Subtarget.hasBWI()) { for (auto VT : { MVT::v32i8, MVT::v16i8, MVT::v16i16, MVT::v8i16 }) { setOperationAction(ISD::MLOAD, VT, Subtarget.hasVLX() ? Legal : Custom); setOperationAction(ISD::MSTORE, VT, Subtarget.hasVLX() ? Legal : Custom); @@ -1516,39 +1603,6 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, } if (!Subtarget.useSoftFloat() && Subtarget.hasVLX()) { - addRegisterClass(MVT::v4i1, &X86::VK4RegClass); - addRegisterClass(MVT::v2i1, &X86::VK2RegClass); - - for (auto VT : { MVT::v2i1, MVT::v4i1 }) { - setOperationAction(ISD::ADD, VT, Custom); - setOperationAction(ISD::SUB, VT, Custom); - setOperationAction(ISD::MUL, VT, Custom); - setOperationAction(ISD::VSELECT, VT, Expand); - - setOperationAction(ISD::TRUNCATE, VT, Custom); - setOperationAction(ISD::SETCC, VT, Custom); - setOperationAction(ISD::EXTRACT_VECTOR_ELT, VT, Custom); - setOperationAction(ISD::INSERT_VECTOR_ELT, VT, Custom); - setOperationAction(ISD::SELECT, VT, Custom); - setOperationAction(ISD::BUILD_VECTOR, VT, Custom); - setOperationAction(ISD::VECTOR_SHUFFLE, VT, Custom); - } - - // TODO: v8i1 concat should be legal without VLX to support concats of - // v1i1, but we won't legalize it correctly currently without introducing - // a v4i1 concat in the middle. - setOperationAction(ISD::CONCAT_VECTORS, MVT::v8i1, Custom); - setOperationAction(ISD::CONCAT_VECTORS, MVT::v4i1, Custom); - setOperationAction(ISD::INSERT_SUBVECTOR, MVT::v4i1, Custom); - - // Extends from v2i1/v4i1 masks to 128-bit vectors. 
- setOperationAction(ISD::ZERO_EXTEND, MVT::v4i32, Custom); - setOperationAction(ISD::ZERO_EXTEND, MVT::v2i64, Custom); - setOperationAction(ISD::SIGN_EXTEND, MVT::v4i32, Custom); - setOperationAction(ISD::SIGN_EXTEND, MVT::v2i64, Custom); - setOperationAction(ISD::ANY_EXTEND, MVT::v4i32, Custom); - setOperationAction(ISD::ANY_EXTEND, MVT::v2i64, Custom); - setTruncStoreAction(MVT::v4i64, MVT::v4i8, Legal); setTruncStoreAction(MVT::v4i64, MVT::v4i16, Legal); setTruncStoreAction(MVT::v4i64, MVT::v4i32, Legal); @@ -1648,6 +1702,7 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, // We have target-specific dag combine patterns for the following nodes: setTargetDAGCombine(ISD::VECTOR_SHUFFLE); + setTargetDAGCombine(ISD::SCALAR_TO_VECTOR); setTargetDAGCombine(ISD::EXTRACT_VECTOR_ELT); setTargetDAGCombine(ISD::INSERT_SUBVECTOR); setTargetDAGCombine(ISD::EXTRACT_SUBVECTOR); @@ -1733,6 +1788,9 @@ SDValue X86TargetLowering::emitStackGuardXorFP(SelectionDAG &DAG, SDValue Val, TargetLoweringBase::LegalizeTypeAction X86TargetLowering::getPreferredVectorAction(EVT VT) const { + if (VT == MVT::v32i1 && Subtarget.hasAVX512() && !Subtarget.hasBWI()) + return TypeSplitVector; + if (ExperimentalVectorWideningLegalization && VT.getVectorNumElements() != 1 && VT.getVectorElementType().getSimpleVT() != MVT::i1) @@ -1741,6 +1799,20 @@ X86TargetLowering::getPreferredVectorAction(EVT VT) const { return TargetLoweringBase::getPreferredVectorAction(VT); } +MVT X86TargetLowering::getRegisterTypeForCallingConv(LLVMContext &Context, + EVT VT) const { + if (VT == MVT::v32i1 && Subtarget.hasAVX512() && !Subtarget.hasBWI()) + return MVT::v32i8; + return TargetLowering::getRegisterTypeForCallingConv(Context, VT); +} + +unsigned X86TargetLowering::getNumRegistersForCallingConv(LLVMContext &Context, + EVT VT) const { + if (VT == MVT::v32i1 && Subtarget.hasAVX512() && !Subtarget.hasBWI()) + return 1; + return TargetLowering::getNumRegistersForCallingConv(Context, VT); +} + EVT X86TargetLowering::getSetCCResultType(const DataLayout &DL, LLVMContext& Context, EVT VT) const { @@ -1937,7 +2009,7 @@ void X86TargetLowering::markLibCallAttributes(MachineFunction *MF, unsigned CC, // Mark the first N int arguments as having reg for (unsigned Idx = 0; Idx < Args.size(); Idx++) { Type *T = Args[Idx].Ty; - if (T->isPointerTy() || T->isIntegerTy()) + if (T->isIntOrPtrTy()) if (MF->getDataLayout().getTypeAllocSize(T) <= 8) { unsigned numRegs = 1; if (MF->getDataLayout().getTypeAllocSize(T) > 4) @@ -2051,7 +2123,8 @@ Value *X86TargetLowering::getIRStackGuard(IRBuilder<> &IRB) const { void X86TargetLowering::insertSSPDeclarations(Module &M) const { // MSVC CRT provides functionalities for stack protection. - if (Subtarget.getTargetTriple().isOSMSVCRT()) { + if (Subtarget.getTargetTriple().isWindowsMSVCEnvironment() || + Subtarget.getTargetTriple().isWindowsItaniumEnvironment()) { // MSVC CRT has a global variable holding security cookie. M.getOrInsertGlobal("__security_cookie", Type::getInt8PtrTy(M.getContext())); @@ -2073,15 +2146,19 @@ void X86TargetLowering::insertSSPDeclarations(Module &M) const { Value *X86TargetLowering::getSDagStackGuard(const Module &M) const { // MSVC CRT has a global variable holding security cookie. 
- if (Subtarget.getTargetTriple().isOSMSVCRT()) + if (Subtarget.getTargetTriple().isWindowsMSVCEnvironment() || + Subtarget.getTargetTriple().isWindowsItaniumEnvironment()) { return M.getGlobalVariable("__security_cookie"); + } return TargetLowering::getSDagStackGuard(M); } Value *X86TargetLowering::getSSPStackGuardCheck(const Module &M) const { // MSVC CRT has a function to validate security cookie. - if (Subtarget.getTargetTriple().isOSMSVCRT()) + if (Subtarget.getTargetTriple().isWindowsMSVCEnvironment() || + Subtarget.getTargetTriple().isWindowsItaniumEnvironment()) { return M.getFunction("__security_check_cookie"); + } return TargetLowering::getSSPStackGuardCheck(M); } @@ -2140,6 +2217,10 @@ static SDValue lowerMasksToReg(const SDValue &ValArg, const EVT &ValLoc, const SDLoc &Dl, SelectionDAG &DAG) { EVT ValVT = ValArg.getValueType(); + if (ValVT == MVT::v1i1) + return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, Dl, ValLoc, ValArg, + DAG.getIntPtrConstant(0, Dl)); + if ((ValVT == MVT::v8i1 && (ValLoc == MVT::i8 || ValLoc == MVT::i32)) || (ValVT == MVT::v16i1 && (ValLoc == MVT::i16 || ValLoc == MVT::i32))) { // Two stage lowering might be required @@ -2150,13 +2231,16 @@ static SDValue lowerMasksToReg(const SDValue &ValArg, const EVT &ValLoc, if (ValLoc == MVT::i32) ValToCopy = DAG.getNode(ISD::ANY_EXTEND, Dl, ValLoc, ValToCopy); return ValToCopy; - } else if ((ValVT == MVT::v32i1 && ValLoc == MVT::i32) || - (ValVT == MVT::v64i1 && ValLoc == MVT::i64)) { + } + + if ((ValVT == MVT::v32i1 && ValLoc == MVT::i32) || + (ValVT == MVT::v64i1 && ValLoc == MVT::i64)) { // One stage lowering is required // bitcast: v32i1 -> i32 / v64i1 -> i64 return DAG.getBitcast(ValLoc, ValArg); - } else - return DAG.getNode(ISD::SIGN_EXTEND, Dl, ValLoc, ValArg); + } + + return DAG.getNode(ISD::ANY_EXTEND, Dl, ValLoc, ValArg); } /// Breaks v64i1 value into two registers and adds the new node to the DAG @@ -2474,10 +2558,10 @@ static SDValue getv64i1Argument(CCValAssign &VA, CCValAssign &NextVA, MachineFunction &MF = DAG.getMachineFunction(); const TargetRegisterClass *RC = &X86::GR32RegClass; - // Read a 32 bit value from the registers + // Read a 32 bit value from the registers. if (nullptr == InFlag) { // When no physical register is present, - // create an intermediate virtual register + // create an intermediate virtual register. Reg = MF.addLiveIn(VA.getLocReg(), RC); ArgValueLo = DAG.getCopyFromReg(Root, Dl, Reg, MVT::i32); Reg = MF.addLiveIn(NextVA.getLocReg(), RC); @@ -2493,13 +2577,13 @@ static SDValue getv64i1Argument(CCValAssign &VA, CCValAssign &NextVA, *InFlag = ArgValueHi.getValue(2); } - // Convert the i32 type into v32i1 type + // Convert the i32 type into v32i1 type. Lo = DAG.getBitcast(MVT::v32i1, ArgValueLo); - // Convert the i32 type into v32i1 type + // Convert the i32 type into v32i1 type. Hi = DAG.getBitcast(MVT::v32i1, ArgValueHi); - // Concatenate the two values together + // Concatenate the two values together. return DAG.getNode(ISD::CONCAT_VECTORS, Dl, MVT::v64i1, Lo, Hi); } @@ -2640,7 +2724,7 @@ enum StructReturnType { StackStructReturn }; static StructReturnType -callIsStructReturn(const SmallVectorImpl<ISD::OutputArg> &Outs, bool IsMCU) { +callIsStructReturn(ArrayRef<ISD::OutputArg> Outs, bool IsMCU) { if (Outs.empty()) return NotStructReturn; @@ -2654,7 +2738,7 @@ callIsStructReturn(const SmallVectorImpl<ISD::OutputArg> &Outs, bool IsMCU) { /// Determines whether a function uses struct return semantics. 
static StructReturnType -argsAreStructReturn(const SmallVectorImpl<ISD::InputArg> &Ins, bool IsMCU) { +argsAreStructReturn(ArrayRef<ISD::InputArg> Ins, bool IsMCU) { if (Ins.empty()) return NotStructReturn; @@ -2774,7 +2858,11 @@ X86TargetLowering::LowerMemArgument(SDValue Chain, CallingConv::ID CallConv, if (Flags.isByVal()) { unsigned Bytes = Flags.getByValSize(); if (Bytes == 0) Bytes = 1; // Don't create zero-sized stack objects. - int FI = MFI.CreateFixedObject(Bytes, VA.getLocMemOffset(), isImmutable); + + // FIXME: For now, all byval parameter objects are marked as aliasing. This + // can be improved with deeper analysis. + int FI = MFI.CreateFixedObject(Bytes, VA.getLocMemOffset(), isImmutable, + /*isAliased=*/true); // Adjust SP offset of interrupt parameter. if (CallConv == CallingConv::X86_INTR) { MFI.setObjectOffset(FI, Offset); @@ -2898,7 +2986,7 @@ static ArrayRef<MCPhysReg> get64BitArgumentXMMs(MachineFunction &MF, } #ifndef NDEBUG -static bool isSortedByValueNo(const SmallVectorImpl<CCValAssign> &ArgLocs) { +static bool isSortedByValueNo(ArrayRef<CCValAssign> ArgLocs) { return std::is_sorted(ArgLocs.begin(), ArgLocs.end(), [](const CCValAssign &A, const CCValAssign &B) -> bool { return A.getValNo() < B.getValNo(); @@ -2975,7 +3063,11 @@ SDValue X86TargetLowering::LowerFormalArguments( getv64i1Argument(VA, ArgLocs[++I], Chain, DAG, dl, Subtarget); } else { const TargetRegisterClass *RC; - if (RegVT == MVT::i32) + if (RegVT == MVT::i8) + RC = &X86::GR8RegClass; + else if (RegVT == MVT::i16) + RC = &X86::GR16RegClass; + else if (RegVT == MVT::i32) RC = &X86::GR32RegClass; else if (Is64Bit && RegVT == MVT::i64) RC = &X86::GR64RegClass; @@ -2986,7 +3078,7 @@ SDValue X86TargetLowering::LowerFormalArguments( else if (RegVT == MVT::f80) RC = &X86::RFP80RegClass; else if (RegVT == MVT::f128) - RC = &X86::FR128RegClass; + RC = &X86::VR128RegClass; else if (RegVT.is512BitVector()) RC = &X86::VR512RegClass; else if (RegVT.is256BitVector()) @@ -3361,6 +3453,11 @@ X86TargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, const Function *Fn = CI ? CI->getCalledFunction() : nullptr; bool HasNCSR = (CI && CI->hasFnAttr("no_caller_saved_registers")) || (Fn && Fn->hasFnAttribute("no_caller_saved_registers")); + const auto *II = dyn_cast_or_null<InvokeInst>(CLI.CS.getInstruction()); + bool HasNoCfCheck = + (CI && CI->doesNoCfCheck()) || (II && II->doesNoCfCheck()); + const Module *M = MF.getMMI().getModule(); + Metadata *IsCFProtectionSupported = M->getModuleFlag("cf-protection-branch"); if (CallConv == CallingConv::X86_INTR) report_fatal_error("X86 interrupts may not be called directly"); @@ -3743,6 +3840,14 @@ X86TargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, Callee = DAG.getTargetExternalSymbol( S->getSymbol(), getPointerTy(DAG.getDataLayout()), OpFlags); + + if (OpFlags == X86II::MO_GOTPCREL) { + Callee = DAG.getNode(X86ISD::WrapperRIP, dl, + getPointerTy(DAG.getDataLayout()), Callee); + Callee = DAG.getLoad( + getPointerTy(DAG.getDataLayout()), dl, DAG.getEntryNode(), Callee, + MachinePointerInfo::getGOT(DAG.getMachineFunction())); + } } else if (Subtarget.isTarget64BitILP32() && Callee->getValueType(0) == MVT::i32) { // Zero-extend the 32-bit Callee address into a 64-bit according to x32 ABI @@ -3804,9 +3909,9 @@ X86TargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, const TargetRegisterInfo *TRI = Subtarget.getRegisterInfo(); // Allocate a new Reg Mask and copy Mask. 
- RegMask = MF.allocateRegisterMask(TRI->getNumRegs()); - unsigned RegMaskSize = (TRI->getNumRegs() + 31) / 32; - memcpy(RegMask, Mask, sizeof(uint32_t) * RegMaskSize); + RegMask = MF.allocateRegMask(); + unsigned RegMaskSize = MachineOperand::getRegMaskSize(TRI->getNumRegs()); + memcpy(RegMask, Mask, sizeof(RegMask[0]) * RegMaskSize); // Make sure all sub registers of the argument registers are reset // in the RegMask. @@ -3836,7 +3941,11 @@ X86TargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, return DAG.getNode(X86ISD::TC_RETURN, dl, NodeTys, Ops); } - Chain = DAG.getNode(X86ISD::CALL, dl, NodeTys, Ops); + if (HasNoCfCheck && IsCFProtectionSupported) { + Chain = DAG.getNode(X86ISD::NT_CALL, dl, NodeTys, Ops); + } else { + Chain = DAG.getNode(X86ISD::CALL, dl, NodeTys, Ops); + } InFlag = Chain.getValue(1); // Create the CALLSEQ_END node. @@ -4260,8 +4369,6 @@ static bool isTargetShuffle(unsigned Opcode) { case X86ISD::VSRLDQ: case X86ISD::MOVLHPS: case X86ISD::MOVHLPS: - case X86ISD::MOVLPS: - case X86ISD::MOVLPD: case X86ISD::MOVSHDUP: case X86ISD::MOVSLDUP: case X86ISD::MOVDDUP: @@ -4273,12 +4380,12 @@ static bool isTargetShuffle(unsigned Opcode) { case X86ISD::VPERMILPI: case X86ISD::VPERMILPV: case X86ISD::VPERM2X128: + case X86ISD::SHUF128: case X86ISD::VPERMIL2: case X86ISD::VPERMI: case X86ISD::VPPERM: case X86ISD::VPERMV: case X86ISD::VPERMV3: - case X86ISD::VPERMIV3: case X86ISD::VZEXT_MOVL: return true; } @@ -4294,7 +4401,6 @@ static bool isTargetShuffleVariableMask(unsigned Opcode) { case X86ISD::VPPERM: case X86ISD::VPERMV: case X86ISD::VPERMV3: - case X86ISD::VPERMIV3: return true; // 'Faux' Target Shuffles. case ISD::AND: @@ -4371,7 +4477,7 @@ bool X86::isCalleePop(CallingConv::ID CallingConv, } } -/// \brief Return true if the condition is an unsigned comparison operation. +/// Return true if the condition is an unsigned comparison operation. static bool isX86CCUnsigned(unsigned X86CC) { switch (X86CC) { default: @@ -4518,20 +4624,6 @@ bool X86TargetLowering::getTgtMemIntrinsic(IntrinsicInfo &Info, Info.offset = 0; switch (IntrData->Type) { - case EXPAND_FROM_MEM: { - Info.ptrVal = I.getArgOperand(0); - Info.memVT = MVT::getVT(I.getType()); - Info.align = 1; - Info.flags |= MachineMemOperand::MOLoad; - break; - } - case COMPRESS_TO_MEM: { - Info.ptrVal = I.getArgOperand(0); - Info.memVT = MVT::getVT(I.getArgOperand(1)->getType()); - Info.align = 1; - Info.flags |= MachineMemOperand::MOStore; - break; - } case TRUNCATE_TO_MEM_VI8: case TRUNCATE_TO_MEM_VI16: case TRUNCATE_TO_MEM_VI32: { @@ -4580,7 +4672,7 @@ bool X86TargetLowering::shouldReduceLoadWidth(SDNode *Load, return true; } -/// \brief Returns true if it is beneficial to convert a load of a constant +/// Returns true if it is beneficial to convert a load of a constant /// to just the constant itself. 
bool X86TargetLowering::shouldConvertConstantLoadToIntImm(const APInt &Imm, Type *Ty) const { @@ -4625,6 +4717,14 @@ bool X86TargetLowering::isCheapToSpeculateCtlz() const { return Subtarget.hasLZCNT(); } +bool X86TargetLowering::isLoadBitCastBeneficial(EVT LoadVT, + EVT BitcastVT) const { + if (!Subtarget.hasDQI() && BitcastVT == MVT::v8i1) + return false; + + return TargetLowering::isLoadBitCastBeneficial(LoadVT, BitcastVT); +} + bool X86TargetLowering::canMergeStoresTo(unsigned AddressSpace, EVT MemVT, const SelectionDAG &DAG) const { // Do not merge to float value size (128 bytes) if no implicit @@ -4649,14 +4749,52 @@ bool X86TargetLowering::isMaskAndCmp0FoldingBeneficial( } bool X86TargetLowering::hasAndNotCompare(SDValue Y) const { + EVT VT = Y.getValueType(); + + if (VT.isVector()) + return false; + if (!Subtarget.hasBMI()) return false; // There are only 32-bit and 64-bit forms for 'andn'. - EVT VT = Y.getValueType(); if (VT != MVT::i32 && VT != MVT::i64) return false; + // A mask and compare against constant is ok for an 'andn' too + // even though the BMI instruction doesn't have an immediate form. + + return true; +} + +bool X86TargetLowering::hasAndNot(SDValue Y) const { + EVT VT = Y.getValueType(); + + if (!VT.isVector()) // x86 can't form 'andn' with an immediate. + return !isa<ConstantSDNode>(Y) && hasAndNotCompare(Y); + + // Vector. + + if (!Subtarget.hasSSE1() || VT.getSizeInBits() < 128) + return false; + + if (VT == MVT::v4i32) + return true; + + return Subtarget.hasSSE2(); +} + +bool X86TargetLowering::preferShiftsToClearExtremeBits(SDValue Y) const { + EVT VT = Y.getValueType(); + + // For vectors, we don't have a preference, but we probably want a mask. + if (VT.isVector()) + return false; + + // 64-bit shifts on 32-bit targets produce really bad bloated code. + if (VT == MVT::i64 && !Subtarget.is64Bit()) + return false; + return true; } @@ -4699,10 +4837,24 @@ static bool isUndefInRange(ArrayRef<int> Mask, unsigned Pos, unsigned Size) { return true; } +/// Return true if Val falls within the specified range (L, H]. +static bool isInRange(int Val, int Low, int Hi) { + return (Val >= Low && Val < Hi); +} + +/// Return true if the value of any element in Mask falls within the specified +/// range (L, H]. +static bool isAnyInRange(ArrayRef<int> Mask, int Low, int Hi) { + for (int M : Mask) + if (isInRange(M, Low, Hi)) + return true; + return false; +} + /// Return true if Val is undef or if its value falls within the /// specified range (L, H]. static bool isUndefOrInRange(int Val, int Low, int Hi) { - return (Val == SM_SentinelUndef) || (Val >= Low && Val < Hi); + return (Val == SM_SentinelUndef) || isInRange(Val, Low, Hi); } /// Return true if every element in Mask is undef or if its value @@ -4718,7 +4870,7 @@ static bool isUndefOrInRange(ArrayRef<int> Mask, /// Return true if Val is undef, zero or if its value falls within the /// specified range (L, H]. static bool isUndefOrZeroOrInRange(int Val, int Low, int Hi) { - return isUndefOrZero(Val) || (Val >= Low && Val < Hi); + return isUndefOrZero(Val) || isInRange(Val, Low, Hi); } /// Return true if every element in Mask is undef, zero or if its value @@ -4731,11 +4883,11 @@ static bool isUndefOrZeroOrInRange(ArrayRef<int> Mask, int Low, int Hi) { } /// Return true if every element in Mask, beginning -/// from position Pos and ending in Pos+Size, falls within the specified -/// sequential range (Low, Low+Size]. or is undef. 
-static bool isSequentialOrUndefInRange(ArrayRef<int> Mask, - unsigned Pos, unsigned Size, int Low) { - for (unsigned i = Pos, e = Pos+Size; i != e; ++i, ++Low) +/// from position Pos and ending in Pos + Size, falls within the specified +/// sequence (Low, Low + Step, ..., Low + (Size - 1) * Step) or is undef. +static bool isSequentialOrUndefInRange(ArrayRef<int> Mask, unsigned Pos, + unsigned Size, int Low, int Step = 1) { + for (unsigned i = Pos, e = Pos + Size; i != e; ++i, Low += Step) if (!isUndefOrEqual(Mask[i], Low)) return false; return true; @@ -4762,7 +4914,7 @@ static bool isUndefOrZeroInRange(ArrayRef<int> Mask, unsigned Pos, return true; } -/// \brief Helper function to test whether a shuffle mask could be +/// Helper function to test whether a shuffle mask could be /// simplified by widening the elements being shuffled. /// /// Appends the mask for wider elements in WidenedMask if valid. Otherwise @@ -4821,6 +4973,24 @@ static bool canWidenShuffleElements(ArrayRef<int> Mask, return true; } +static bool canWidenShuffleElements(ArrayRef<int> Mask, + const APInt &Zeroable, + SmallVectorImpl<int> &WidenedMask) { + SmallVector<int, 32> TargetMask(Mask.begin(), Mask.end()); + for (int i = 0, Size = TargetMask.size(); i < Size; ++i) { + if (TargetMask[i] == SM_SentinelUndef) + continue; + if (Zeroable[i]) + TargetMask[i] = SM_SentinelZero; + } + return canWidenShuffleElements(TargetMask, WidenedMask); +} + +static bool canWidenShuffleElements(ArrayRef<int> Mask) { + SmallVector<int, 32> WidenedMask; + return canWidenShuffleElements(Mask, WidenedMask); +} + /// Returns true if Elt is a constant zero or a floating point constant +0.0. bool X86::isZeroNode(SDValue Elt) { return isNullConstant(Elt) || isNullFPConstant(Elt); @@ -4916,8 +5086,6 @@ static SDValue getZeroVector(MVT VT, const X86Subtarget &Subtarget, } else if (VT.getVectorElementType() == MVT::i1) { assert((Subtarget.hasBWI() || VT.getVectorNumElements() <= 16) && "Unexpected vector type"); - assert((Subtarget.hasVLX() || VT.getVectorNumElements() >= 8) && - "Unexpected vector type"); Vec = DAG.getConstant(0, dl, VT); } else { unsigned Num32BitElts = VT.getSizeInBits() / 32; @@ -5007,10 +5175,66 @@ static SDValue insert128BitVector(SDValue Result, SDValue Vec, unsigned IdxVal, return insertSubVector(Result, Vec, IdxVal, DAG, dl, 128); } -static SDValue insert256BitVector(SDValue Result, SDValue Vec, unsigned IdxVal, - SelectionDAG &DAG, const SDLoc &dl) { - assert(Vec.getValueType().is256BitVector() && "Unexpected vector size!"); - return insertSubVector(Result, Vec, IdxVal, DAG, dl, 256); +/// Widen a vector to a larger size with the same scalar type, with the new +/// elements either zero or undef. +static SDValue widenSubVector(MVT VT, SDValue Vec, bool ZeroNewElements, + const X86Subtarget &Subtarget, SelectionDAG &DAG, + const SDLoc &dl) { + assert(Vec.getValueSizeInBits() < VT.getSizeInBits() && + Vec.getValueType().getScalarType() == VT.getScalarType() && + "Unsupported vector widening type"); + SDValue Res = ZeroNewElements ? getZeroVector(VT, Subtarget, DAG, dl) + : DAG.getUNDEF(VT); + return DAG.getNode(ISD::INSERT_SUBVECTOR, dl, VT, Res, Vec, + DAG.getIntPtrConstant(0, dl)); +} + +// Helper for splitting operands of an operation to legal target size and +// apply a function on each part. +// Useful for operations that are available on SSE2 in 128-bit, on AVX2 in +// 256-bit and on AVX512BW in 512-bit. The argument VT is the type used for +// deciding if/how to split Ops. 
Ops elements do *not* have to be of type VT. +// The argument Builder is a function that will be applied on each split part: +// SDValue Builder(SelectionDAG&G, SDLoc, ArrayRef<SDValue>) +template <typename F> +SDValue SplitOpsAndApply(SelectionDAG &DAG, const X86Subtarget &Subtarget, + const SDLoc &DL, EVT VT, ArrayRef<SDValue> Ops, + F Builder, bool CheckBWI = true) { + assert(Subtarget.hasSSE2() && "Target assumed to support at least SSE2"); + unsigned NumSubs = 1; + if ((CheckBWI && Subtarget.useBWIRegs()) || + (!CheckBWI && Subtarget.useAVX512Regs())) { + if (VT.getSizeInBits() > 512) { + NumSubs = VT.getSizeInBits() / 512; + assert((VT.getSizeInBits() % 512) == 0 && "Illegal vector size"); + } + } else if (Subtarget.hasAVX2()) { + if (VT.getSizeInBits() > 256) { + NumSubs = VT.getSizeInBits() / 256; + assert((VT.getSizeInBits() % 256) == 0 && "Illegal vector size"); + } + } else { + if (VT.getSizeInBits() > 128) { + NumSubs = VT.getSizeInBits() / 128; + assert((VT.getSizeInBits() % 128) == 0 && "Illegal vector size"); + } + } + + if (NumSubs == 1) + return Builder(DAG, DL, Ops); + + SmallVector<SDValue, 4> Subs; + for (unsigned i = 0; i != NumSubs; ++i) { + SmallVector<SDValue, 2> SubOps; + for (SDValue Op : Ops) { + EVT OpVT = Op.getValueType(); + unsigned NumSubElts = OpVT.getVectorNumElements() / NumSubs; + unsigned SizeSub = OpVT.getSizeInBits() / NumSubs; + SubOps.push_back(extractSubVector(Op, i * NumSubElts, DAG, DL, SizeSub)); + } + Subs.push_back(Builder(DAG, DL, SubOps)); + } + return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, Subs); } // Return true if the instruction zeroes the unused upper part of the @@ -5019,13 +5243,9 @@ static bool isMaskedZeroUpperBitsvXi1(unsigned int Opcode) { switch (Opcode) { default: return false; - case X86ISD::TESTM: - case X86ISD::TESTNM: - case X86ISD::PCMPEQM: - case X86ISD::PCMPGTM: case X86ISD::CMPM: - case X86ISD::CMPMU: case X86ISD::CMPM_RND: + case ISD::SETCC: return true; } } @@ -5166,22 +5386,11 @@ static SDValue insert1BitVector(SDValue Op, SelectionDAG &DAG, return DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, OpVT, Op, ZeroIdx); } -/// Concat two 128-bit vectors into a 256 bit vector using VINSERTF128 -/// instructions. This is used because creating CONCAT_VECTOR nodes of -/// BUILD_VECTORS returns a larger BUILD_VECTOR while we're trying to lower -/// large BUILD_VECTORS. -static SDValue concat128BitVectors(SDValue V1, SDValue V2, EVT VT, - unsigned NumElems, SelectionDAG &DAG, - const SDLoc &dl) { - SDValue V = insert128BitVector(DAG.getUNDEF(VT), V1, 0, DAG, dl); - return insert128BitVector(V, V2, NumElems / 2, DAG, dl); -} - -static SDValue concat256BitVectors(SDValue V1, SDValue V2, EVT VT, - unsigned NumElems, SelectionDAG &DAG, - const SDLoc &dl) { - SDValue V = insert256BitVector(DAG.getUNDEF(VT), V1, 0, DAG, dl); - return insert256BitVector(V, V2, NumElems / 2, DAG, dl); +static SDValue concatSubVectors(SDValue V1, SDValue V2, EVT VT, + unsigned NumElems, SelectionDAG &DAG, + const SDLoc &dl, unsigned VectorWidth) { + SDValue V = insertSubVector(DAG.getUNDEF(VT), V1, 0, DAG, dl, VectorWidth); + return insertSubVector(V, V2, NumElems / 2, DAG, dl, VectorWidth); } /// Returns a vector of specified type with all bits set. @@ -5265,6 +5474,13 @@ static SDValue peekThroughOneUseBitcasts(SDValue V) { return V; } +// Peek through EXTRACT_SUBVECTORs - typically used for AVX1 256-bit intops. 
+static SDValue peekThroughEXTRACT_SUBVECTORs(SDValue V) { + while (V.getOpcode() == ISD::EXTRACT_SUBVECTOR) + V = V.getOperand(0); + return V; +} + static const Constant *getTargetConstantFromNode(SDValue Op) { Op = peekThroughBitcasts(Op); @@ -5389,6 +5605,12 @@ static bool getTargetConstantBitsFromNode(SDValue Op, unsigned EltSizeInBits, SmallVector<APInt, 64> SrcEltBits(1, Cst->getAPIntValue()); return CastBitData(UndefSrcElts, SrcEltBits); } + if (auto *Cst = dyn_cast<ConstantFPSDNode>(Op)) { + APInt UndefSrcElts = APInt::getNullValue(1); + APInt RawBits = Cst->getValueAPF().bitcastToAPInt(); + SmallVector<APInt, 64> SrcEltBits(1, RawBits); + return CastBitData(UndefSrcElts, SrcEltBits); + } // Extract constant bits from build vector. if (ISD::isBuildVectorOfConstantSDNodes(Op.getNode())) { @@ -5525,14 +5747,15 @@ static bool getTargetShuffleMask(SDNode *N, MVT VT, bool AllowSentinelZero, assert(N->getOperand(0).getValueType() == VT && "Unexpected value type"); assert(N->getOperand(1).getValueType() == VT && "Unexpected value type"); ImmN = N->getOperand(N->getNumOperands()-1); - DecodeBLENDMask(VT, cast<ConstantSDNode>(ImmN)->getZExtValue(), Mask); + DecodeBLENDMask(NumElems, cast<ConstantSDNode>(ImmN)->getZExtValue(), Mask); IsUnary = IsFakeUnary = N->getOperand(0) == N->getOperand(1); break; case X86ISD::SHUFP: assert(N->getOperand(0).getValueType() == VT && "Unexpected value type"); assert(N->getOperand(1).getValueType() == VT && "Unexpected value type"); ImmN = N->getOperand(N->getNumOperands()-1); - DecodeSHUFPMask(VT, cast<ConstantSDNode>(ImmN)->getZExtValue(), Mask); + DecodeSHUFPMask(NumElems, VT.getScalarSizeInBits(), + cast<ConstantSDNode>(ImmN)->getZExtValue(), Mask); IsUnary = IsFakeUnary = N->getOperand(0) == N->getOperand(1); break; case X86ISD::INSERTPS: @@ -5548,7 +5771,8 @@ static bool getTargetShuffleMask(SDNode *N, MVT VT, bool AllowSentinelZero, isa<ConstantSDNode>(N->getOperand(2))) { int BitLen = N->getConstantOperandVal(1); int BitIdx = N->getConstantOperandVal(2); - DecodeEXTRQIMask(VT, BitLen, BitIdx, Mask); + DecodeEXTRQIMask(NumElems, VT.getScalarSizeInBits(), BitLen, BitIdx, + Mask); IsUnary = true; } break; @@ -5559,20 +5783,21 @@ static bool getTargetShuffleMask(SDNode *N, MVT VT, bool AllowSentinelZero, isa<ConstantSDNode>(N->getOperand(3))) { int BitLen = N->getConstantOperandVal(2); int BitIdx = N->getConstantOperandVal(3); - DecodeINSERTQIMask(VT, BitLen, BitIdx, Mask); + DecodeINSERTQIMask(NumElems, VT.getScalarSizeInBits(), BitLen, BitIdx, + Mask); IsUnary = IsFakeUnary = N->getOperand(0) == N->getOperand(1); } break; case X86ISD::UNPCKH: assert(N->getOperand(0).getValueType() == VT && "Unexpected value type"); assert(N->getOperand(1).getValueType() == VT && "Unexpected value type"); - DecodeUNPCKHMask(VT, Mask); + DecodeUNPCKHMask(NumElems, VT.getScalarSizeInBits(), Mask); IsUnary = IsFakeUnary = N->getOperand(0) == N->getOperand(1); break; case X86ISD::UNPCKL: assert(N->getOperand(0).getValueType() == VT && "Unexpected value type"); assert(N->getOperand(1).getValueType() == VT && "Unexpected value type"); - DecodeUNPCKLMask(VT, Mask); + DecodeUNPCKLMask(NumElems, VT.getScalarSizeInBits(), Mask); IsUnary = IsFakeUnary = N->getOperand(0) == N->getOperand(1); break; case X86ISD::MOVHLPS: @@ -5592,7 +5817,8 @@ static bool getTargetShuffleMask(SDNode *N, MVT VT, bool AllowSentinelZero, assert(N->getOperand(0).getValueType() == VT && "Unexpected value type"); assert(N->getOperand(1).getValueType() == VT && "Unexpected value type"); ImmN = 
N->getOperand(N->getNumOperands()-1); - DecodePALIGNRMask(VT, cast<ConstantSDNode>(ImmN)->getZExtValue(), Mask); + DecodePALIGNRMask(NumElems, cast<ConstantSDNode>(ImmN)->getZExtValue(), + Mask); IsUnary = IsFakeUnary = N->getOperand(0) == N->getOperand(1); Ops.push_back(N->getOperand(1)); Ops.push_back(N->getOperand(0)); @@ -5601,38 +5827,43 @@ static bool getTargetShuffleMask(SDNode *N, MVT VT, bool AllowSentinelZero, assert(VT.getScalarType() == MVT::i8 && "Byte vector expected"); assert(N->getOperand(0).getValueType() == VT && "Unexpected value type"); ImmN = N->getOperand(N->getNumOperands() - 1); - DecodePSLLDQMask(VT, cast<ConstantSDNode>(ImmN)->getZExtValue(), Mask); + DecodePSLLDQMask(NumElems, cast<ConstantSDNode>(ImmN)->getZExtValue(), + Mask); IsUnary = true; break; case X86ISD::VSRLDQ: assert(VT.getScalarType() == MVT::i8 && "Byte vector expected"); assert(N->getOperand(0).getValueType() == VT && "Unexpected value type"); ImmN = N->getOperand(N->getNumOperands() - 1); - DecodePSRLDQMask(VT, cast<ConstantSDNode>(ImmN)->getZExtValue(), Mask); + DecodePSRLDQMask(NumElems, cast<ConstantSDNode>(ImmN)->getZExtValue(), + Mask); IsUnary = true; break; case X86ISD::PSHUFD: case X86ISD::VPERMILPI: assert(N->getOperand(0).getValueType() == VT && "Unexpected value type"); ImmN = N->getOperand(N->getNumOperands()-1); - DecodePSHUFMask(VT, cast<ConstantSDNode>(ImmN)->getZExtValue(), Mask); + DecodePSHUFMask(NumElems, VT.getScalarSizeInBits(), + cast<ConstantSDNode>(ImmN)->getZExtValue(), Mask); IsUnary = true; break; case X86ISD::PSHUFHW: assert(N->getOperand(0).getValueType() == VT && "Unexpected value type"); ImmN = N->getOperand(N->getNumOperands()-1); - DecodePSHUFHWMask(VT, cast<ConstantSDNode>(ImmN)->getZExtValue(), Mask); + DecodePSHUFHWMask(NumElems, cast<ConstantSDNode>(ImmN)->getZExtValue(), + Mask); IsUnary = true; break; case X86ISD::PSHUFLW: assert(N->getOperand(0).getValueType() == VT && "Unexpected value type"); ImmN = N->getOperand(N->getNumOperands()-1); - DecodePSHUFLWMask(VT, cast<ConstantSDNode>(ImmN)->getZExtValue(), Mask); + DecodePSHUFLWMask(NumElems, cast<ConstantSDNode>(ImmN)->getZExtValue(), + Mask); IsUnary = true; break; case X86ISD::VZEXT_MOVL: assert(N->getOperand(0).getValueType() == VT && "Unexpected value type"); - DecodeZeroMoveLowMask(VT, Mask); + DecodeZeroMoveLowMask(NumElems, Mask); IsUnary = true; break; case X86ISD::VBROADCAST: { @@ -5648,7 +5879,7 @@ static bool getTargetShuffleMask(SDNode *N, MVT VT, bool AllowSentinelZero, // came from an extract from the original width. If we found one, we // pushed it the Ops vector above. 
if (N0.getValueType() == VT || !Ops.empty()) { - DecodeVectorBroadcast(VT, Mask); + DecodeVectorBroadcast(NumElems, Mask); IsUnary = true; break; } @@ -5661,7 +5892,7 @@ static bool getTargetShuffleMask(SDNode *N, MVT VT, bool AllowSentinelZero, unsigned MaskEltSize = VT.getScalarSizeInBits(); SmallVector<uint64_t, 32> RawMask; if (getTargetShuffleMaskIndices(MaskNode, MaskEltSize, RawMask)) { - DecodeVPERMILPMask(VT, RawMask, Mask); + DecodeVPERMILPMask(NumElems, VT.getScalarSizeInBits(), RawMask, Mask); break; } if (auto *C = getTargetConstantFromNode(MaskNode)) { @@ -5690,41 +5921,47 @@ static bool getTargetShuffleMask(SDNode *N, MVT VT, bool AllowSentinelZero, case X86ISD::VPERMI: assert(N->getOperand(0).getValueType() == VT && "Unexpected value type"); ImmN = N->getOperand(N->getNumOperands()-1); - DecodeVPERMMask(VT, cast<ConstantSDNode>(ImmN)->getZExtValue(), Mask); + DecodeVPERMMask(NumElems, cast<ConstantSDNode>(ImmN)->getZExtValue(), Mask); IsUnary = true; break; case X86ISD::MOVSS: case X86ISD::MOVSD: assert(N->getOperand(0).getValueType() == VT && "Unexpected value type"); assert(N->getOperand(1).getValueType() == VT && "Unexpected value type"); - DecodeScalarMoveMask(VT, /* IsLoad */ false, Mask); + DecodeScalarMoveMask(NumElems, /* IsLoad */ false, Mask); break; case X86ISD::VPERM2X128: assert(N->getOperand(0).getValueType() == VT && "Unexpected value type"); assert(N->getOperand(1).getValueType() == VT && "Unexpected value type"); ImmN = N->getOperand(N->getNumOperands()-1); - DecodeVPERM2X128Mask(VT, cast<ConstantSDNode>(ImmN)->getZExtValue(), Mask); + DecodeVPERM2X128Mask(NumElems, cast<ConstantSDNode>(ImmN)->getZExtValue(), + Mask); + IsUnary = IsFakeUnary = N->getOperand(0) == N->getOperand(1); + break; + case X86ISD::SHUF128: + assert(N->getOperand(0).getValueType() == VT && "Unexpected value type"); + assert(N->getOperand(1).getValueType() == VT && "Unexpected value type"); + ImmN = N->getOperand(N->getNumOperands()-1); + decodeVSHUF64x2FamilyMask(NumElems, VT.getScalarSizeInBits(), + cast<ConstantSDNode>(ImmN)->getZExtValue(), + Mask); IsUnary = IsFakeUnary = N->getOperand(0) == N->getOperand(1); break; case X86ISD::MOVSLDUP: assert(N->getOperand(0).getValueType() == VT && "Unexpected value type"); - DecodeMOVSLDUPMask(VT, Mask); + DecodeMOVSLDUPMask(NumElems, Mask); IsUnary = true; break; case X86ISD::MOVSHDUP: assert(N->getOperand(0).getValueType() == VT && "Unexpected value type"); - DecodeMOVSHDUPMask(VT, Mask); + DecodeMOVSHDUPMask(NumElems, Mask); IsUnary = true; break; case X86ISD::MOVDDUP: assert(N->getOperand(0).getValueType() == VT && "Unexpected value type"); - DecodeMOVDDUPMask(VT, Mask); + DecodeMOVDDUPMask(NumElems, Mask); IsUnary = true; break; - case X86ISD::MOVLPD: - case X86ISD::MOVLPS: - // Not yet implemented - return false; case X86ISD::VPERMIL2: { assert(N->getOperand(0).getValueType() == VT && "Unexpected value type"); assert(N->getOperand(1).getValueType() == VT && "Unexpected value type"); @@ -5736,7 +5973,8 @@ static bool getTargetShuffleMask(SDNode *N, MVT VT, bool AllowSentinelZero, unsigned CtrlImm = CtrlOp->getZExtValue(); SmallVector<uint64_t, 32> RawMask; if (getTargetShuffleMaskIndices(MaskNode, MaskEltSize, RawMask)) { - DecodeVPERMIL2PMask(VT, CtrlImm, RawMask, Mask); + DecodeVPERMIL2PMask(NumElems, VT.getScalarSizeInBits(), CtrlImm, + RawMask, Mask); break; } if (auto *C = getTargetConstantFromNode(MaskNode)) { @@ -5795,21 +6033,6 @@ static bool getTargetShuffleMask(SDNode *N, MVT VT, bool AllowSentinelZero, } return false; } - case 
X86ISD::VPERMIV3: { - assert(N->getOperand(1).getValueType() == VT && "Unexpected value type"); - assert(N->getOperand(2).getValueType() == VT && "Unexpected value type"); - IsUnary = IsFakeUnary = N->getOperand(1) == N->getOperand(2); - // Unlike most shuffle nodes, VPERMIV3's mask operand is the first one. - Ops.push_back(N->getOperand(1)); - Ops.push_back(N->getOperand(2)); - SDValue MaskNode = N->getOperand(0); - unsigned MaskEltSize = VT.getScalarSizeInBits(); - if (auto *C = getTargetConstantFromNode(MaskNode)) { - DecodeVPERMV3Mask(C, MaskEltSize, Mask); - break; - } - return false; - } default: llvm_unreachable("unknown target shuffle node"); } @@ -5927,7 +6150,7 @@ static bool setTargetShuffleZeroElements(SDValue N, // destination value type. static bool getFauxShuffleMask(SDValue N, SmallVectorImpl<int> &Mask, SmallVectorImpl<SDValue> &Ops, - SelectionDAG &DAG) { + const SelectionDAG &DAG) { Mask.clear(); Ops.clear(); @@ -5940,6 +6163,17 @@ static bool getFauxShuffleMask(SDValue N, SmallVectorImpl<int> &Mask, unsigned Opcode = N.getOpcode(); switch (Opcode) { + case ISD::VECTOR_SHUFFLE: { + // Don't treat ISD::VECTOR_SHUFFLE as a target shuffle so decode it here. + ArrayRef<int> ShuffleMask = cast<ShuffleVectorSDNode>(N)->getMask(); + if (isUndefOrInRange(ShuffleMask, 0, 2 * NumElts)) { + Mask.append(ShuffleMask.begin(), ShuffleMask.end()); + Ops.push_back(N.getOperand(0)); + Ops.push_back(N.getOperand(1)); + return true; + } + return false; + } case ISD::AND: case X86ISD::ANDNP: { // Attempt to decode as a per-byte mask. @@ -6001,8 +6235,11 @@ static bool getFauxShuffleMask(SDValue N, SmallVectorImpl<int> &Mask, case X86ISD::PINSRW: { SDValue InVec = N.getOperand(0); SDValue InScl = N.getOperand(1); + SDValue InIndex = N.getOperand(2); + if (!isa<ConstantSDNode>(InIndex) || + cast<ConstantSDNode>(InIndex)->getAPIntValue().uge(NumElts)) + return false; uint64_t InIdx = N.getConstantOperandVal(2); - assert(InIdx < NumElts && "Illegal insertion index"); // Attempt to recognise a PINSR*(VEC, 0, Idx) shuffle pattern. if (X86::isZeroNode(InScl)) { @@ -6020,8 +6257,12 @@ static bool getFauxShuffleMask(SDValue N, SmallVectorImpl<int> &Mask, return false; SDValue ExVec = InScl.getOperand(0); + SDValue ExIndex = InScl.getOperand(1); + if (!isa<ConstantSDNode>(ExIndex) || + cast<ConstantSDNode>(ExIndex)->getAPIntValue().uge(NumElts)) + return false; uint64_t ExIdx = InScl.getConstantOperandVal(1); - assert(ExIdx < NumElts && "Illegal extraction index"); + Ops.push_back(InVec); Ops.push_back(ExVec); for (unsigned i = 0; i != NumElts; ++i) @@ -6097,7 +6338,8 @@ static bool getFauxShuffleMask(SDValue N, SmallVectorImpl<int> &Mask, MVT SrcVT = Src.getSimpleValueType(); if (NumSizeInBits != SrcVT.getSizeInBits()) break; - DecodeZeroExtendMask(SrcVT.getScalarType(), VT, Mask); + DecodeZeroExtendMask(SrcVT.getScalarSizeInBits(), VT.getScalarSizeInBits(), + VT.getVectorNumElements(), Mask); Ops.push_back(Src); return true; } @@ -6141,7 +6383,7 @@ static void resolveTargetShuffleInputsAndMask(SmallVectorImpl<SDValue> &Inputs, static bool resolveTargetShuffleInputs(SDValue Op, SmallVectorImpl<SDValue> &Inputs, SmallVectorImpl<int> &Mask, - SelectionDAG &DAG) { + const SelectionDAG &DAG) { if (!setTargetShuffleZeroElements(Op, Mask, Inputs)) if (!getFauxShuffleMask(Op, Mask, Inputs, DAG)) return false; @@ -6451,9 +6693,8 @@ static SDValue getVShift(bool isLeft, EVT VT, SDValue SrcOp, unsigned NumBits, MVT ShVT = MVT::v16i8; unsigned Opc = isLeft ? 
X86ISD::VSHLDQ : X86ISD::VSRLDQ; SrcOp = DAG.getBitcast(ShVT, SrcOp); - MVT ScalarShiftTy = TLI.getScalarShiftAmountTy(DAG.getDataLayout(), VT); assert(NumBits % 8 == 0 && "Only support byte sized shifts"); - SDValue ShiftVal = DAG.getConstant(NumBits/8, dl, ScalarShiftTy); + SDValue ShiftVal = DAG.getConstant(NumBits/8, dl, MVT::i8); return DAG.getBitcast(VT, DAG.getNode(Opc, dl, ShVT, SrcOp, ShiftVal)); } @@ -6805,17 +7046,13 @@ static SDValue lowerBuildVectorAsBroadcast(BuildVectorSDNode *BVOp, BOperand = ZeroExtended.getOperand(0); else BOperand = Ld.getOperand(0).getOperand(0); - if (BOperand.getValueType().isVector() && - BOperand.getSimpleValueType().getVectorElementType() == MVT::i1) { - if ((EltType == MVT::i64 && (VT.getVectorElementType() == MVT::i8 || - NumElts == 8)) || // for broadcastmb2q - (EltType == MVT::i32 && (VT.getVectorElementType() == MVT::i16 || - NumElts == 16))) { // for broadcastmw2d - SDValue Brdcst = - DAG.getNode(X86ISD::VBROADCASTM, dl, - MVT::getVectorVT(EltType, NumElts), BOperand); - return DAG.getBitcast(VT, Brdcst); - } + MVT MaskVT = BOperand.getSimpleValueType(); + if ((EltType == MVT::i64 && MaskVT == MVT::v8i1) || // for broadcastmb2q + (EltType == MVT::i32 && MaskVT == MVT::v16i1)) { // for broadcastmw2d + SDValue Brdcst = + DAG.getNode(X86ISD::VBROADCASTM, dl, + MVT::getVectorVT(EltType, NumElts), BOperand); + return DAG.getBitcast(VT, Brdcst); } } } @@ -6982,7 +7219,7 @@ static SDValue lowerBuildVectorAsBroadcast(BuildVectorSDNode *BVOp, return SDValue(); } -/// \brief For an EXTRACT_VECTOR_ELT with a constant index return the real +/// For an EXTRACT_VECTOR_ELT with a constant index return the real /// underlying vector and index. /// /// Modifies \p ExtractedFromVec to the real vector and returns the real @@ -7195,7 +7432,7 @@ static SDValue LowerBUILD_VECTORvXi1(SDValue Op, SelectionDAG &DAG, return DstVec; } -/// \brief Return true if \p N implements a horizontal binop and return the +/// Return true if \p N implements a horizontal binop and return the /// operands for the horizontal binop into V0 and V1. /// /// This is a helper function of LowerToHorizontalOp(). @@ -7292,7 +7529,7 @@ static bool isHorizontalBinOp(const BuildVectorSDNode *N, unsigned Opcode, return CanFold; } -/// \brief Emit a sequence of two 128-bit horizontal add/sub followed by +/// Emit a sequence of two 128-bit horizontal add/sub followed by /// a concat_vector. /// /// This is a helper function of LowerToHorizontalOp(). @@ -7360,18 +7597,18 @@ static SDValue ExpandHorizontalBinOp(const SDValue &V0, const SDValue &V1, } /// Returns true iff \p BV builds a vector with the result equivalent to -/// the result of ADDSUB operation. -/// If true is returned then the operands of ADDSUB = Opnd0 +- Opnd1 operation -/// are written to the parameters \p Opnd0 and \p Opnd1. -static bool isAddSub(const BuildVectorSDNode *BV, - const X86Subtarget &Subtarget, SelectionDAG &DAG, - SDValue &Opnd0, SDValue &Opnd1, - unsigned &NumExtracts) { +/// the result of ADDSUB/SUBADD operation. +/// If true is returned then the operands of ADDSUB = Opnd0 +- Opnd1 +/// (SUBADD = Opnd0 -+ Opnd1) operation are written to the parameters +/// \p Opnd0 and \p Opnd1. 
+static bool isAddSubOrSubAdd(const BuildVectorSDNode *BV, + const X86Subtarget &Subtarget, SelectionDAG &DAG, + SDValue &Opnd0, SDValue &Opnd1, + unsigned &NumExtracts, + bool &IsSubAdd) { MVT VT = BV->getSimpleValueType(0); - if ((!Subtarget.hasSSE3() || (VT != MVT::v4f32 && VT != MVT::v2f64)) && - (!Subtarget.hasAVX() || (VT != MVT::v8f32 && VT != MVT::v4f64)) && - (!Subtarget.hasAVX512() || (VT != MVT::v16f32 && VT != MVT::v8f64))) + if (!Subtarget.hasSSE3() || !VT.isFloatingPoint()) return false; unsigned NumElts = VT.getVectorNumElements(); @@ -7381,26 +7618,20 @@ static bool isAddSub(const BuildVectorSDNode *BV, NumExtracts = 0; // Odd-numbered elements in the input build vector are obtained from - // adding two integer/float elements. + // adding/subtracting two integer/float elements. // Even-numbered elements in the input build vector are obtained from - // subtracting two integer/float elements. - unsigned ExpectedOpcode = ISD::FSUB; - unsigned NextExpectedOpcode = ISD::FADD; - bool AddFound = false; - bool SubFound = false; - + // subtracting/adding two integer/float elements. + unsigned Opc[2] {0, 0}; for (unsigned i = 0, e = NumElts; i != e; ++i) { SDValue Op = BV->getOperand(i); // Skip 'undef' values. unsigned Opcode = Op.getOpcode(); - if (Opcode == ISD::UNDEF) { - std::swap(ExpectedOpcode, NextExpectedOpcode); + if (Opcode == ISD::UNDEF) continue; - } // Early exit if we found an unexpected opcode. - if (Opcode != ExpectedOpcode) + if (Opcode != ISD::FADD && Opcode != ISD::FSUB) return false; SDValue Op0 = Op.getOperand(0); @@ -7420,11 +7651,11 @@ static bool isAddSub(const BuildVectorSDNode *BV, if (I0 != i) return false; - // We found a valid add/sub node. Update the information accordingly. - if (i & 1) - AddFound = true; - else - SubFound = true; + // We found a valid add/sub node, make sure its the same opcode as previous + // elements for this parity. + if (Opc[i % 2] != 0 && Opc[i % 2] != Opcode) + return false; + Opc[i % 2] = Opcode; // Update InVec0 and InVec1. if (InVec0.isUndef()) { @@ -7441,7 +7672,7 @@ static bool isAddSub(const BuildVectorSDNode *BV, // Make sure that operands in input to each add/sub node always // come from a same pair of vectors. if (InVec0 != Op0.getOperand(0)) { - if (ExpectedOpcode == ISD::FSUB) + if (Opcode == ISD::FSUB) return false; // FADD is commutable. Try to commute the operands @@ -7454,24 +7685,26 @@ static bool isAddSub(const BuildVectorSDNode *BV, if (InVec1 != Op1.getOperand(0)) return false; - // Update the pair of expected opcodes. - std::swap(ExpectedOpcode, NextExpectedOpcode); - // Increment the number of extractions done. ++NumExtracts; } - // Don't try to fold this build_vector into an ADDSUB if the inputs are undef. - if (!AddFound || !SubFound || InVec0.isUndef() || InVec1.isUndef()) + // Ensure we have found an opcode for both parities and that they are + // different. Don't try to fold this build_vector into an ADDSUB/SUBADD if the + // inputs are undef. + if (!Opc[0] || !Opc[1] || Opc[0] == Opc[1] || + InVec0.isUndef() || InVec1.isUndef()) return false; + IsSubAdd = Opc[0] == ISD::FADD; + Opnd0 = InVec0; Opnd1 = InVec1; return true; } /// Returns true if is possible to fold MUL and an idiom that has already been -/// recognized as ADDSUB/SUBADD(\p Opnd0, \p Opnd1) into +/// recognized as ADDSUB/SUBADD(\p Opnd0, \p Opnd1) into /// FMADDSUB/FMSUBADD(x, y, \p Opnd1). If (and only if) true is returned, the /// operands of FMADDSUB/FMSUBADD are written to parameters \p Opnd0, \p Opnd1, \p Opnd2. 
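// A standalone scalar sketch (not the SelectionDAG code; helper names here are
// illustrative) of the parity bookkeeping isAddSubOrSubAdd performs above:
// remember which of FADD/FSUB feeds the even and the odd lanes, reject a lane
// whose opcode disagrees with earlier lanes of the same parity, and derive
// IsSubAdd at the end. 'A', 'S' and 0 stand in for ISD::FADD, ISD::FSUB and an
// undef lane.
#include <cstddef>

static bool classifyAddSubSketch(const char *LaneOps, std::size_t NumLanes,
                                 bool &IsSubAdd) {
  char Opc[2] = {0, 0};                 // Opcode seen so far on even/odd lanes.
  for (std::size_t i = 0; i != NumLanes; ++i) {
    char Op = LaneOps[i];
    if (Op == 0)
      continue;                         // Undef lane adds no constraint.
    if (Op != 'A' && Op != 'S')
      return false;                     // Only FADD/FSUB participate.
    if (Opc[i % 2] != 0 && Opc[i % 2] != Op)
      return false;                     // One opcode per parity.
    Opc[i % 2] = Op;
  }
  if (!Opc[0] || !Opc[1] || Opc[0] == Opc[1])
    return false;                       // Need both parities, with different ops.
  IsSubAdd = (Opc[0] == 'A');           // FADD in the even lanes => SUBADD (-+).
  return true;
}
// e.g. {'S','A','S','A'} matches ADDSUB, {'A','S','A','S'} matches SUBADD,
// and {'S','S','A','A'} is rejected.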
/// @@ -7521,14 +7754,17 @@ static bool isFMAddSubOrFMSubAdd(const X86Subtarget &Subtarget, return true; } -/// Try to fold a build_vector that performs an 'addsub' or 'fmaddsub' operation -/// accordingly to X86ISD::ADDSUB or X86ISD::FMADDSUB node. +/// Try to fold a build_vector that performs an 'addsub' or 'fmaddsub' or +/// 'fsubadd' operation accordingly to X86ISD::ADDSUB or X86ISD::FMADDSUB or +/// X86ISD::FMSUBADD node. static SDValue lowerToAddSubOrFMAddSub(const BuildVectorSDNode *BV, const X86Subtarget &Subtarget, SelectionDAG &DAG) { SDValue Opnd0, Opnd1; unsigned NumExtracts; - if (!isAddSub(BV, Subtarget, DAG, Opnd0, Opnd1, NumExtracts)) + bool IsSubAdd; + if (!isAddSubOrSubAdd(BV, Subtarget, DAG, Opnd0, Opnd1, NumExtracts, + IsSubAdd)) return SDValue(); MVT VT = BV->getSimpleValueType(0); @@ -7536,10 +7772,14 @@ static SDValue lowerToAddSubOrFMAddSub(const BuildVectorSDNode *BV, // Try to generate X86ISD::FMADDSUB node here. SDValue Opnd2; - // TODO: According to coverage reports, the FMADDSUB transform is not - // triggered by any tests. - if (isFMAddSubOrFMSubAdd(Subtarget, DAG, Opnd0, Opnd1, Opnd2, NumExtracts)) - return DAG.getNode(X86ISD::FMADDSUB, DL, VT, Opnd0, Opnd1, Opnd2); + if (isFMAddSubOrFMSubAdd(Subtarget, DAG, Opnd0, Opnd1, Opnd2, NumExtracts)) { + unsigned Opc = IsSubAdd ? X86ISD::FMSUBADD : X86ISD::FMADDSUB; + return DAG.getNode(Opc, DL, VT, Opnd0, Opnd1, Opnd2); + } + + // We only support ADDSUB. + if (IsSubAdd) + return SDValue(); // Do not generate X86ISD::ADDSUB node for 512-bit types even though // the ADDSUB idiom has been successfully recognized. There are no known @@ -7708,6 +7948,10 @@ static SDValue lowerBuildVectorToBitOp(BuildVectorSDNode *Op, case ISD::AND: case ISD::XOR: case ISD::OR: + // Don't do this if the buildvector is a splat - we'd replace one + // constant with an entire vector. + if (Op->getSplatValue()) + return SDValue(); if (!TLI.isOperationLegalOrPromote(Opcode, VT)) return SDValue(); break; @@ -7762,66 +8006,268 @@ static SDValue materializeVectorConstant(SDValue Op, SelectionDAG &DAG, return SDValue(); } -// Tries to lower a BUILD_VECTOR composed of extract-extract chains that can be -// reasoned to be a permutation of a vector by indices in a non-constant vector. -// (build_vector (extract_elt V, (extract_elt I, 0)), -// (extract_elt V, (extract_elt I, 1)), -// ... -// -> -// (vpermv I, V) -// -// TODO: Handle undefs -// TODO: Utilize pshufb and zero mask blending to support more efficient -// construction of vectors with constant-0 elements. -// TODO: Use smaller-element vectors of same width, and "interpolate" the indices, -// when no native operation available. -static SDValue -LowerBUILD_VECTORAsVariablePermute(SDValue V, SelectionDAG &DAG, - const X86Subtarget &Subtarget) { - // Look for VPERMV and PSHUFB opportunities. - MVT VT = V.getSimpleValueType(); +/// Look for opportunities to create a VPERMV/VPERMILPV/PSHUFB variable permute +/// from a vector of source values and a vector of extraction indices. +/// The vectors might be manipulated to match the type of the permute op. +static SDValue createVariablePermute(MVT VT, SDValue SrcVec, SDValue IndicesVec, + SDLoc &DL, SelectionDAG &DAG, + const X86Subtarget &Subtarget) { + MVT ShuffleVT = VT; + EVT IndicesVT = EVT(VT).changeVectorElementTypeToInteger(); + unsigned NumElts = VT.getVectorNumElements(); + unsigned SizeInBits = VT.getSizeInBits(); + + // Adjust IndicesVec to match VT size. 
+ assert(IndicesVec.getValueType().getVectorNumElements() >= NumElts && + "Illegal variable permute mask size"); + if (IndicesVec.getValueType().getVectorNumElements() > NumElts) + IndicesVec = extractSubVector(IndicesVec, 0, DAG, SDLoc(IndicesVec), + NumElts * VT.getScalarSizeInBits()); + IndicesVec = DAG.getZExtOrTrunc(IndicesVec, SDLoc(IndicesVec), IndicesVT); + + // Handle SrcVec that don't match VT type. + if (SrcVec.getValueSizeInBits() != SizeInBits) { + if ((SrcVec.getValueSizeInBits() % SizeInBits) == 0) { + // Handle larger SrcVec by treating it as a larger permute. + unsigned Scale = SrcVec.getValueSizeInBits() / SizeInBits; + VT = MVT::getVectorVT(VT.getScalarType(), Scale * NumElts); + IndicesVT = EVT(VT).changeVectorElementTypeToInteger(); + IndicesVec = widenSubVector(IndicesVT.getSimpleVT(), IndicesVec, false, + Subtarget, DAG, SDLoc(IndicesVec)); + return extractSubVector( + createVariablePermute(VT, SrcVec, IndicesVec, DL, DAG, Subtarget), 0, + DAG, DL, SizeInBits); + } else if (SrcVec.getValueSizeInBits() < SizeInBits) { + // Widen smaller SrcVec to match VT. + SrcVec = widenSubVector(VT, SrcVec, false, Subtarget, DAG, SDLoc(SrcVec)); + } else + return SDValue(); + } + + auto ScaleIndices = [&DAG](SDValue Idx, uint64_t Scale) { + assert(isPowerOf2_64(Scale) && "Illegal variable permute shuffle scale"); + EVT SrcVT = Idx.getValueType(); + unsigned NumDstBits = SrcVT.getScalarSizeInBits() / Scale; + uint64_t IndexScale = 0; + uint64_t IndexOffset = 0; + + // If we're scaling a smaller permute op, then we need to repeat the + // indices, scaling and offsetting them as well. + // e.g. v4i32 -> v16i8 (Scale = 4) + // IndexScale = v4i32 Splat(4 << 24 | 4 << 16 | 4 << 8 | 4) + // IndexOffset = v4i32 Splat(3 << 24 | 2 << 16 | 1 << 8 | 0) + for (uint64_t i = 0; i != Scale; ++i) { + IndexScale |= Scale << (i * NumDstBits); + IndexOffset |= i << (i * NumDstBits); + } + + Idx = DAG.getNode(ISD::MUL, SDLoc(Idx), SrcVT, Idx, + DAG.getConstant(IndexScale, SDLoc(Idx), SrcVT)); + Idx = DAG.getNode(ISD::ADD, SDLoc(Idx), SrcVT, Idx, + DAG.getConstant(IndexOffset, SDLoc(Idx), SrcVT)); + return Idx; + }; + + unsigned Opcode = 0; switch (VT.SimpleTy) { default: - return SDValue(); + break; case MVT::v16i8: - if (!Subtarget.hasSSE3()) - return SDValue(); + if (Subtarget.hasSSSE3()) + Opcode = X86ISD::PSHUFB; + break; + case MVT::v8i16: + if (Subtarget.hasVLX() && Subtarget.hasBWI()) + Opcode = X86ISD::VPERMV; + else if (Subtarget.hasSSSE3()) { + Opcode = X86ISD::PSHUFB; + ShuffleVT = MVT::v16i8; + } + break; + case MVT::v4f32: + case MVT::v4i32: + if (Subtarget.hasAVX()) { + Opcode = X86ISD::VPERMILPV; + ShuffleVT = MVT::v4f32; + } else if (Subtarget.hasSSSE3()) { + Opcode = X86ISD::PSHUFB; + ShuffleVT = MVT::v16i8; + } + break; + case MVT::v2f64: + case MVT::v2i64: + if (Subtarget.hasAVX()) { + // VPERMILPD selects using bit#1 of the index vector, so scale IndicesVec. + IndicesVec = DAG.getNode(ISD::ADD, DL, IndicesVT, IndicesVec, IndicesVec); + Opcode = X86ISD::VPERMILPV; + ShuffleVT = MVT::v2f64; + } else if (Subtarget.hasSSE41()) { + // SSE41 can compare v2i64 - select between indices 0 and 1. 
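// A scalar sketch of the index-scaling trick the comment above describes for
// reusing v4i32 permute indices as v16i8 byte-permute indices (Scale = 4).
// Multiplying a dword index d by the packed constant 0x04040404 replicates
// 4*d into every byte (no cross-byte carries as long as 4*d < 256), and adding
// 0x03020100 leaves the bytes holding 4*d+3 | 4*d+2 | 4*d+1 | 4*d+0, exactly
// the per-byte indices a byte permute needs. Function names are illustrative.
#include <cassert>
#include <cstdint>

static uint32_t scaleDwordIndexToBytes(uint32_t d) {
  const uint32_t IndexScale  = 0x04040404u;  // Scale splatted into each byte.
  const uint32_t IndexOffset = 0x03020100u;  // 3,2,1,0 per byte (little-endian).
  return d * IndexScale + IndexOffset;
}

static void checkIndexScaling() {
  for (uint32_t d = 0; d != 4; ++d) {        // All valid v4i32 source indices.
    uint32_t Packed = scaleDwordIndexToBytes(d);
    for (uint32_t Byte = 0; Byte != 4; ++Byte)
      assert(((Packed >> (8 * Byte)) & 0xffu) == 4 * d + Byte);
  }
}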
+ return DAG.getSelectCC( + DL, IndicesVec, + getZeroVector(IndicesVT.getSimpleVT(), Subtarget, DAG, DL), + DAG.getVectorShuffle(VT, DL, SrcVec, SrcVec, {0, 0}), + DAG.getVectorShuffle(VT, DL, SrcVec, SrcVec, {1, 1}), + ISD::CondCode::SETEQ); + } + break; + case MVT::v32i8: + if (Subtarget.hasVLX() && Subtarget.hasVBMI()) + Opcode = X86ISD::VPERMV; + else if (Subtarget.hasXOP()) { + SDValue LoSrc = extract128BitVector(SrcVec, 0, DAG, DL); + SDValue HiSrc = extract128BitVector(SrcVec, 16, DAG, DL); + SDValue LoIdx = extract128BitVector(IndicesVec, 0, DAG, DL); + SDValue HiIdx = extract128BitVector(IndicesVec, 16, DAG, DL); + return DAG.getNode( + ISD::CONCAT_VECTORS, DL, VT, + DAG.getNode(X86ISD::VPPERM, DL, MVT::v16i8, LoSrc, HiSrc, LoIdx), + DAG.getNode(X86ISD::VPPERM, DL, MVT::v16i8, LoSrc, HiSrc, HiIdx)); + } else if (Subtarget.hasAVX()) { + SDValue Lo = extract128BitVector(SrcVec, 0, DAG, DL); + SDValue Hi = extract128BitVector(SrcVec, 16, DAG, DL); + SDValue LoLo = DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, Lo, Lo); + SDValue HiHi = DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, Hi, Hi); + auto PSHUFBBuilder = [](SelectionDAG &DAG, const SDLoc &DL, + ArrayRef<SDValue> Ops) { + // Permute Lo and Hi and then select based on index range. + // This works as SHUFB uses bits[3:0] to permute elements and we don't + // care about the bit[7] as its just an index vector. + SDValue Idx = Ops[2]; + EVT VT = Idx.getValueType(); + return DAG.getSelectCC(DL, Idx, DAG.getConstant(15, DL, VT), + DAG.getNode(X86ISD::PSHUFB, DL, VT, Ops[1], Idx), + DAG.getNode(X86ISD::PSHUFB, DL, VT, Ops[0], Idx), + ISD::CondCode::SETGT); + }; + SDValue Ops[] = {LoLo, HiHi, IndicesVec}; + return SplitOpsAndApply(DAG, Subtarget, DL, MVT::v32i8, Ops, + PSHUFBBuilder); + } + break; + case MVT::v16i16: + if (Subtarget.hasVLX() && Subtarget.hasBWI()) + Opcode = X86ISD::VPERMV; + else if (Subtarget.hasAVX()) { + // Scale to v32i8 and perform as v32i8. + IndicesVec = ScaleIndices(IndicesVec, 2); + return DAG.getBitcast( + VT, createVariablePermute( + MVT::v32i8, DAG.getBitcast(MVT::v32i8, SrcVec), + DAG.getBitcast(MVT::v32i8, IndicesVec), DL, DAG, Subtarget)); + } break; case MVT::v8f32: case MVT::v8i32: - if (!Subtarget.hasAVX2()) - return SDValue(); + if (Subtarget.hasAVX2()) + Opcode = X86ISD::VPERMV; + else if (Subtarget.hasAVX()) { + SrcVec = DAG.getBitcast(MVT::v8f32, SrcVec); + SDValue LoLo = DAG.getVectorShuffle(MVT::v8f32, DL, SrcVec, SrcVec, + {0, 1, 2, 3, 0, 1, 2, 3}); + SDValue HiHi = DAG.getVectorShuffle(MVT::v8f32, DL, SrcVec, SrcVec, + {4, 5, 6, 7, 4, 5, 6, 7}); + if (Subtarget.hasXOP()) + return DAG.getBitcast(VT, DAG.getNode(X86ISD::VPERMIL2, DL, MVT::v8f32, + LoLo, HiHi, IndicesVec, + DAG.getConstant(0, DL, MVT::i8))); + // Permute Lo and Hi and then select based on index range. + // This works as VPERMILPS only uses index bits[0:1] to permute elements. 
+ SDValue Res = DAG.getSelectCC( + DL, IndicesVec, DAG.getConstant(3, DL, MVT::v8i32), + DAG.getNode(X86ISD::VPERMILPV, DL, MVT::v8f32, HiHi, IndicesVec), + DAG.getNode(X86ISD::VPERMILPV, DL, MVT::v8f32, LoLo, IndicesVec), + ISD::CondCode::SETGT); + return DAG.getBitcast(VT, Res); + } break; case MVT::v4i64: case MVT::v4f64: - if (!Subtarget.hasVLX()) - return SDValue(); + if (Subtarget.hasAVX512()) { + if (!Subtarget.hasVLX()) { + MVT WidenSrcVT = MVT::getVectorVT(VT.getScalarType(), 8); + SrcVec = widenSubVector(WidenSrcVT, SrcVec, false, Subtarget, DAG, + SDLoc(SrcVec)); + IndicesVec = widenSubVector(MVT::v8i64, IndicesVec, false, Subtarget, + DAG, SDLoc(IndicesVec)); + SDValue Res = createVariablePermute(WidenSrcVT, SrcVec, IndicesVec, DL, + DAG, Subtarget); + return extract256BitVector(Res, 0, DAG, DL); + } + Opcode = X86ISD::VPERMV; + } else if (Subtarget.hasAVX()) { + SrcVec = DAG.getBitcast(MVT::v4f64, SrcVec); + SDValue LoLo = + DAG.getVectorShuffle(MVT::v4f64, DL, SrcVec, SrcVec, {0, 1, 0, 1}); + SDValue HiHi = + DAG.getVectorShuffle(MVT::v4f64, DL, SrcVec, SrcVec, {2, 3, 2, 3}); + // VPERMIL2PD selects with bit#1 of the index vector, so scale IndicesVec. + IndicesVec = DAG.getNode(ISD::ADD, DL, IndicesVT, IndicesVec, IndicesVec); + if (Subtarget.hasXOP()) + return DAG.getBitcast(VT, DAG.getNode(X86ISD::VPERMIL2, DL, MVT::v4f64, + LoLo, HiHi, IndicesVec, + DAG.getConstant(0, DL, MVT::i8))); + // Permute Lo and Hi and then select based on index range. + // This works as VPERMILPD only uses index bit[1] to permute elements. + SDValue Res = DAG.getSelectCC( + DL, IndicesVec, DAG.getConstant(2, DL, MVT::v4i64), + DAG.getNode(X86ISD::VPERMILPV, DL, MVT::v4f64, HiHi, IndicesVec), + DAG.getNode(X86ISD::VPERMILPV, DL, MVT::v4f64, LoLo, IndicesVec), + ISD::CondCode::SETGT); + return DAG.getBitcast(VT, Res); + } break; - case MVT::v16f32: - case MVT::v8f64: - case MVT::v16i32: - case MVT::v8i64: - if (!Subtarget.hasAVX512()) - return SDValue(); + case MVT::v64i8: + if (Subtarget.hasVBMI()) + Opcode = X86ISD::VPERMV; break; case MVT::v32i16: - if (!Subtarget.hasBWI()) - return SDValue(); - break; - case MVT::v8i16: - case MVT::v16i16: - if (!Subtarget.hasVLX() || !Subtarget.hasBWI()) - return SDValue(); - break; - case MVT::v64i8: - if (!Subtarget.hasVBMI()) - return SDValue(); + if (Subtarget.hasBWI()) + Opcode = X86ISD::VPERMV; break; - case MVT::v32i8: - if (!Subtarget.hasVLX() || !Subtarget.hasVBMI()) - return SDValue(); + case MVT::v16f32: + case MVT::v16i32: + case MVT::v8f64: + case MVT::v8i64: + if (Subtarget.hasAVX512()) + Opcode = X86ISD::VPERMV; break; } + if (!Opcode) + return SDValue(); + + assert((VT.getSizeInBits() == ShuffleVT.getSizeInBits()) && + (VT.getScalarSizeInBits() % ShuffleVT.getScalarSizeInBits()) == 0 && + "Illegal variable permute shuffle type"); + + uint64_t Scale = VT.getScalarSizeInBits() / ShuffleVT.getScalarSizeInBits(); + if (Scale > 1) + IndicesVec = ScaleIndices(IndicesVec, Scale); + + EVT ShuffleIdxVT = EVT(ShuffleVT).changeVectorElementTypeToInteger(); + IndicesVec = DAG.getBitcast(ShuffleIdxVT, IndicesVec); + + SrcVec = DAG.getBitcast(ShuffleVT, SrcVec); + SDValue Res = Opcode == X86ISD::VPERMV + ? DAG.getNode(Opcode, DL, ShuffleVT, IndicesVec, SrcVec) + : DAG.getNode(Opcode, DL, ShuffleVT, SrcVec, IndicesVec); + return DAG.getBitcast(VT, Res); +} + +// Tries to lower a BUILD_VECTOR composed of extract-extract chains that can be +// reasoned to be a permutation of a vector by indices in a non-constant vector. 
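// A scalar model (illustrative, not the DAG lowering) of the "permute Lo and
// Hi, then select on the index range" fallback used above when no native
// cross-lane variable permute exists: the in-lane byte pick only honours
// bits[3:0] of each index, and a compare-against-15 select recovers the full
// 0..31 range, so the result equals Src[Idx[i] & 31].
#include <array>
#include <cstdint>

static std::array<uint8_t, 32>
emulateV32I8VariablePermute(const std::array<uint8_t, 32> &Src,
                            const std::array<uint8_t, 32> &Idx) {
  std::array<uint8_t, 32> Out{};
  for (unsigned i = 0; i != 32; ++i) {
    uint8_t M = Idx[i] & 31;              // Assume indices only address Src.
    uint8_t FromLo = Src[M & 15];         // "PSHUFB(LoLo, Idx)" style pick.
    uint8_t FromHi = Src[16 + (M & 15)];  // "PSHUFB(HiHi, Idx)" style pick.
    Out[i] = (M > 15) ? FromHi : FromLo;  // Select based on index range.
  }
  return Out;
}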
+// (build_vector (extract_elt V, (extract_elt I, 0)), +// (extract_elt V, (extract_elt I, 1)), +// ... +// -> +// (vpermv I, V) +// +// TODO: Handle undefs +// TODO: Utilize pshufb and zero mask blending to support more efficient +// construction of vectors with constant-0 elements. +static SDValue +LowerBUILD_VECTORAsVariablePermute(SDValue V, SelectionDAG &DAG, + const X86Subtarget &Subtarget) { SDValue SrcVec, IndicesVec; // Check for a match of the permute source vector and permute index elements. // This is done by checking that the i-th build_vector operand is of the form: @@ -7858,13 +8304,10 @@ LowerBUILD_VECTORAsVariablePermute(SDValue V, SelectionDAG &DAG, if (!PermIdx || PermIdx->getZExtValue() != Idx) return SDValue(); } - MVT IndicesVT = VT; - if (VT.isFloatingPoint()) - IndicesVT = MVT::getVectorVT(MVT::getIntegerVT(VT.getScalarSizeInBits()), - VT.getVectorNumElements()); - IndicesVec = DAG.getZExtOrTrunc(IndicesVec, SDLoc(IndicesVec), IndicesVT); - return DAG.getNode(VT == MVT::v16i8 ? X86ISD::PSHUFB : X86ISD::VPERMV, - SDLoc(V), VT, IndicesVec, SrcVec); + + SDLoc DL(V); + MVT VT = V.getSimpleValueType(); + return createVariablePermute(VT, SrcVec, IndicesVec, DL, DAG, Subtarget); } SDValue @@ -7872,7 +8315,7 @@ X86TargetLowering::LowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG) const { SDLoc dl(Op); MVT VT = Op.getSimpleValueType(); - MVT ExtVT = VT.getVectorElementType(); + MVT EltVT = VT.getVectorElementType(); unsigned NumElems = Op.getNumOperands(); // Generate vectors for predicate vectors. @@ -7883,8 +8326,6 @@ X86TargetLowering::LowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG) const { return VectorConstant; BuildVectorSDNode *BV = cast<BuildVectorSDNode>(Op.getNode()); - // TODO: Support FMSUBADD here if we ever get tests for the FMADDSUB - // transform here. if (SDValue AddSub = lowerToAddSubOrFMAddSub(BV, Subtarget, DAG)) return AddSub; if (SDValue HorizontalOp = LowerToHorizontalOp(BV, Subtarget, DAG)) @@ -7894,7 +8335,7 @@ X86TargetLowering::LowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG) const { if (SDValue BitOp = lowerBuildVectorToBitOp(BV, DAG)) return BitOp; - unsigned EVTBits = ExtVT.getSizeInBits(); + unsigned EVTBits = EltVT.getSizeInBits(); unsigned NumZero = 0; unsigned NumNonZero = 0; @@ -7930,13 +8371,13 @@ X86TargetLowering::LowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG) const { // supported, we assume that we will fall back to a shuffle to get the scalar // blended with the constants. Insertion into a zero vector is handled as a // special-case somewhere below here. - LLVMContext &Context = *DAG.getContext(); if (NumConstants == NumElems - 1 && NumNonZero != 1 && (isOperationLegalOrCustom(ISD::INSERT_VECTOR_ELT, VT) || isOperationLegalOrCustom(ISD::VECTOR_SHUFFLE, VT))) { // Create an all-constant vector. The variable element in the old // build vector is replaced by undef in the constant vector. Save the // variable scalar element and its index for use in the insertelement. + LLVMContext &Context = *DAG.getContext(); Type *EltType = Op.getValueType().getScalarType().getTypeForEVT(Context); SmallVector<Constant *, 16> ConstVecOps(NumElems, UndefValue::get(EltType)); SDValue VarElt; @@ -7975,27 +8416,6 @@ X86TargetLowering::LowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG) const { unsigned Idx = countTrailingZeros(NonZeros); SDValue Item = Op.getOperand(Idx); - // If this is an insertion of an i64 value on x86-32, and if the top bits of - // the value are obviously zero, truncate the value to i32 and do the - // insertion that way. 
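// The pattern matched above boils down to the scalar loop sketched here
// (illustrative only): every build_vector lane i reads SrcVec at the position
// stored in lane i of IndicesVec, so the whole node is a variable permute
// out[i] = Src[Indices[i]]. The modulo is only there to keep this standalone
// sketch in bounds; the real lowering leaves index interpretation to the
// chosen permute instruction.
#include <cstddef>
#include <vector>

static std::vector<int>
variablePermuteSemanticsSketch(const std::vector<int> &Src,
                               const std::vector<unsigned> &Indices) {
  std::vector<int> Out(Indices.size());
  for (std::size_t i = 0; i != Indices.size(); ++i)
    Out[i] = Src[Indices[i] % Src.size()]; // One extract-of-extract per lane.
  return Out;
}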
Only do this if the value is non-constant or if the - // value is a constant being inserted into element 0. It is cheaper to do - // a constant pool load than it is to do a movd + shuffle. - if (ExtVT == MVT::i64 && !Subtarget.is64Bit() && - (!IsAllConstants || Idx == 0)) { - if (DAG.MaskedValueIsZero(Item, APInt::getHighBitsSet(64, 32))) { - // Handle SSE only. - assert(VT == MVT::v2i64 && "Expected an SSE value type!"); - MVT VecVT = MVT::v4i32; - - // Truncate the value (which may itself be a constant) to i32, and - // convert it to a vector with movd (S2V+shuffle to zero extend). - Item = DAG.getNode(ISD::TRUNCATE, dl, MVT::i32, Item); - Item = DAG.getNode(ISD::SCALAR_TO_VECTOR, dl, VecVT, Item); - return DAG.getBitcast(VT, getShuffleVectorZeroOrUndef( - Item, Idx * 2, true, Subtarget, DAG)); - } - } - // If we have a constant or non-constant insertion into the low element of // a vector, we can do this with SCALAR_TO_VECTOR + shuffle of zero into // the rest of the elements. This will be matched as movd/movq/movss/movsd @@ -8004,8 +8424,8 @@ X86TargetLowering::LowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG) const { if (NumZero == 0) return DAG.getNode(ISD::SCALAR_TO_VECTOR, dl, VT, Item); - if (ExtVT == MVT::i32 || ExtVT == MVT::f32 || ExtVT == MVT::f64 || - (ExtVT == MVT::i64 && Subtarget.is64Bit())) { + if (EltVT == MVT::i32 || EltVT == MVT::f32 || EltVT == MVT::f64 || + (EltVT == MVT::i64 && Subtarget.is64Bit())) { assert((VT.is128BitVector() || VT.is256BitVector() || VT.is512BitVector()) && "Expected an SSE value type!"); @@ -8016,7 +8436,7 @@ X86TargetLowering::LowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG) const { // We can't directly insert an i8 or i16 into a vector, so zero extend // it to i32 first. - if (ExtVT == MVT::i16 || ExtVT == MVT::i8) { + if (EltVT == MVT::i16 || EltVT == MVT::i8) { Item = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i32, Item); if (VT.getSizeInBits() >= 256) { MVT ShufVT = MVT::getVectorVT(MVT::i32, VT.getSizeInBits()/32); @@ -8088,17 +8508,43 @@ X86TargetLowering::LowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG) const { return V; // See if we can use a vector load to get all of the elements. - if (VT.is128BitVector() || VT.is256BitVector() || VT.is512BitVector()) { + { SmallVector<SDValue, 64> Ops(Op->op_begin(), Op->op_begin() + NumElems); if (SDValue LD = EltsFromConsecutiveLoads(VT, Ops, dl, DAG, Subtarget, false)) return LD; } + // If this is a splat of pairs of 32-bit elements, we can use a narrower + // build_vector and broadcast it. + // TODO: We could probably generalize this more. + if (Subtarget.hasAVX2() && EVTBits == 32 && Values.size() == 2) { + SDValue Ops[4] = { Op.getOperand(0), Op.getOperand(1), + DAG.getUNDEF(EltVT), DAG.getUNDEF(EltVT) }; + auto CanSplat = [](SDValue Op, unsigned NumElems, ArrayRef<SDValue> Ops) { + // Make sure all the even/odd operands match. + for (unsigned i = 2; i != NumElems; ++i) + if (Ops[i % 2] != Op.getOperand(i)) + return false; + return true; + }; + if (CanSplat(Op, NumElems, Ops)) { + MVT WideEltVT = VT.isFloatingPoint() ? MVT::f64 : MVT::i64; + MVT NarrowVT = MVT::getVectorVT(EltVT, 4); + // Create a new build vector and cast to v2i64/v2f64. + SDValue NewBV = DAG.getBitcast(MVT::getVectorVT(WideEltVT, 2), + DAG.getBuildVector(NarrowVT, dl, Ops)); + // Broadcast from v2i64/v2f64 and cast to final VT. 
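// A minimal sketch of the CanSplat test added above: a build_vector of 32-bit
// elements of the form a,b,a,b,... can be rebuilt as a narrow vector
// {a, b, undef, undef}, reinterpreted as a single 64-bit element and then
// broadcast. Only the alternating-pattern check is modelled here; the
// broadcast itself stays with the target lowering. Names are illustrative.
#include <cstddef>
#include <vector>

static bool canSplatPairSketch(const std::vector<int> &Ops) {
  if (Ops.size() < 2)
    return false;
  for (std::size_t i = 2; i != Ops.size(); ++i)
    if (Ops[i] != Ops[i % 2])          // Even ops match Ops[0], odd ops Ops[1].
      return false;
  return true;
}
// canSplatPairSketch({7, 9, 7, 9, 7, 9, 7, 9}) returns true.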
+ MVT BcastVT = MVT::getVectorVT(WideEltVT, NumElems/2); + return DAG.getBitcast(VT, DAG.getNode(X86ISD::VBROADCAST, dl, BcastVT, + NewBV)); + } + } + // For AVX-length vectors, build the individual 128-bit pieces and use // shuffles to put them in place. - if (VT.is256BitVector() || VT.is512BitVector()) { - EVT HVT = EVT::getVectorVT(Context, ExtVT, NumElems/2); + if (VT.getSizeInBits() > 128) { + MVT HVT = MVT::getVectorVT(EltVT, NumElems/2); // Build both the lower and upper subvector. SDValue Lower = @@ -8107,9 +8553,8 @@ X86TargetLowering::LowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG) const { HVT, dl, Op->ops().slice(NumElems / 2, NumElems /2)); // Recreate the wider vector with the lower and upper part. - if (VT.is256BitVector()) - return concat128BitVectors(Lower, Upper, VT, NumElems, DAG, dl); - return concat256BitVectors(Lower, Upper, VT, NumElems, DAG, dl); + return concatSubVectors(Lower, Upper, VT, NumElems, DAG, dl, + VT.getSizeInBits() / 2); } // Let legalizer expand 2-wide build_vectors. @@ -8234,30 +8679,60 @@ X86TargetLowering::LowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG) const { // 256-bit AVX can use the vinsertf128 instruction // to create 256-bit vectors from two other 128-bit ones. -static SDValue LowerAVXCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) { +// TODO: Detect subvector broadcast here instead of DAG combine? +static SDValue LowerAVXCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG, + const X86Subtarget &Subtarget) { SDLoc dl(Op); MVT ResVT = Op.getSimpleValueType(); assert((ResVT.is256BitVector() || ResVT.is512BitVector()) && "Value type must be 256-/512-bit wide"); - SDValue V1 = Op.getOperand(0); - SDValue V2 = Op.getOperand(1); - unsigned NumElems = ResVT.getVectorNumElements(); - if (ResVT.is256BitVector()) - return concat128BitVectors(V1, V2, ResVT, NumElems, DAG, dl); + unsigned NumOperands = Op.getNumOperands(); + unsigned NumZero = 0; + unsigned NumNonZero = 0; + unsigned NonZeros = 0; + for (unsigned i = 0; i != NumOperands; ++i) { + SDValue SubVec = Op.getOperand(i); + if (SubVec.isUndef()) + continue; + if (ISD::isBuildVectorAllZeros(SubVec.getNode())) + ++NumZero; + else { + assert(i < sizeof(NonZeros) * CHAR_BIT); // Ensure the shift is in range. + NonZeros |= 1 << i; + ++NumNonZero; + } + } - if (Op.getNumOperands() == 4) { + // If we have more than 2 non-zeros, build each half separately. + if (NumNonZero > 2) { MVT HalfVT = MVT::getVectorVT(ResVT.getVectorElementType(), ResVT.getVectorNumElements()/2); - SDValue V3 = Op.getOperand(2); - SDValue V4 = Op.getOperand(3); - return concat256BitVectors( - concat128BitVectors(V1, V2, HalfVT, NumElems / 2, DAG, dl), - concat128BitVectors(V3, V4, HalfVT, NumElems / 2, DAG, dl), ResVT, - NumElems, DAG, dl); + ArrayRef<SDUse> Ops = Op->ops(); + SDValue Lo = DAG.getNode(ISD::CONCAT_VECTORS, dl, HalfVT, + Ops.slice(0, NumOperands/2)); + SDValue Hi = DAG.getNode(ISD::CONCAT_VECTORS, dl, HalfVT, + Ops.slice(NumOperands/2)); + return DAG.getNode(ISD::CONCAT_VECTORS, dl, ResVT, Lo, Hi); + } + + // Otherwise, build it up through insert_subvectors. + SDValue Vec = NumZero ? 
getZeroVector(ResVT, Subtarget, DAG, dl) + : DAG.getUNDEF(ResVT); + + MVT SubVT = Op.getOperand(0).getSimpleValueType(); + unsigned NumSubElems = SubVT.getVectorNumElements(); + for (unsigned i = 0; i != NumOperands; ++i) { + if ((NonZeros & (1 << i)) == 0) + continue; + + Vec = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, ResVT, Vec, + Op.getOperand(i), + DAG.getIntPtrConstant(i * NumSubElems, dl)); } - return concat256BitVectors(V1, V2, ResVT, NumElems, DAG, dl); + + return Vec; } // Return true if all the operands of the given CONCAT_VECTORS node are zeros @@ -8314,6 +8789,7 @@ static SDValue isTypePromotionOfi1ZeroUpBits(SDValue Op) { return SDValue(); } +// TODO: Merge this with LowerAVXCONCAT_VECTORS? static SDValue LowerCONCAT_VECTORSvXi1(SDValue Op, const X86Subtarget &Subtarget, SelectionDAG & DAG) { @@ -8328,12 +8804,8 @@ static SDValue LowerCONCAT_VECTORSvXi1(SDValue Op, // of a node with instruction that zeroes all upper (irrelevant) bits of the // output register, mark it as legal and catch the pattern in instruction // selection to avoid emitting extra instructions (for zeroing upper bits). - if (SDValue Promoted = isTypePromotionOfi1ZeroUpBits(Op)) { - SDValue ZeroC = DAG.getIntPtrConstant(0, dl); - SDValue AllZeros = getZeroVector(ResVT, Subtarget, DAG, dl); - return DAG.getNode(ISD::INSERT_SUBVECTOR, dl, ResVT, AllZeros, Promoted, - ZeroC); - } + if (SDValue Promoted = isTypePromotionOfi1ZeroUpBits(Op)) + return widenSubVector(ResVT, Promoted, true, Subtarget, DAG, dl); unsigned NumZero = 0; unsigned NumNonZero = 0; @@ -8404,7 +8876,7 @@ static SDValue LowerCONCAT_VECTORS(SDValue Op, // from two other 128-bit ones. // 512-bit vector may contain 2 256-bit vectors or 4 128-bit vectors - return LowerAVXCONCAT_VECTORS(Op, DAG); + return LowerAVXCONCAT_VECTORS(Op, DAG, Subtarget); } //===----------------------------------------------------------------------===// @@ -8418,7 +8890,7 @@ static SDValue LowerCONCAT_VECTORS(SDValue Op, // patterns. //===----------------------------------------------------------------------===// -/// \brief Tiny helper function to identify a no-op mask. +/// Tiny helper function to identify a no-op mask. /// /// This is a somewhat boring predicate function. It checks whether the mask /// array input, which is assumed to be a single-input shuffle mask of the kind @@ -8434,7 +8906,7 @@ static bool isNoopShuffleMask(ArrayRef<int> Mask) { return true; } -/// \brief Test whether there are elements crossing 128-bit lanes in this +/// Test whether there are elements crossing 128-bit lanes in this /// shuffle mask. /// /// X86 divides up its shuffles into in-lane and cross-lane shuffle operations @@ -8448,7 +8920,7 @@ static bool is128BitLaneCrossingShuffleMask(MVT VT, ArrayRef<int> Mask) { return false; } -/// \brief Test whether a shuffle mask is equivalent within each sub-lane. +/// Test whether a shuffle mask is equivalent within each sub-lane. /// /// This checks a shuffle mask to see if it is performing the same /// lane-relative shuffle in each sub-lane. This trivially implies @@ -8494,6 +8966,12 @@ is128BitLaneRepeatedShuffleMask(MVT VT, ArrayRef<int> Mask, return isRepeatedShuffleMask(128, VT, Mask, RepeatedMask); } +static bool +is128BitLaneRepeatedShuffleMask(MVT VT, ArrayRef<int> Mask) { + SmallVector<int, 32> RepeatedMask; + return isRepeatedShuffleMask(128, VT, Mask, RepeatedMask); +} + /// Test whether a shuffle mask is equivalent within each 256-bit lane. 
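// A standalone sketch of the concat_vectors strategy implemented above: when
// at most two operands are non-zero, start from a zero base vector and insert
// just those operands at their element offsets, instead of chaining pairwise
// 128-bit concatenations. The recursive split used for more than two non-zero
// operands is omitted; plain std::vector stands in for SelectionDAG values.
#include <cstddef>
#include <vector>

using SubVecSketch = std::vector<int>;

static bool isAllZerosSketch(const SubVecSketch &V) {
  for (int X : V)
    if (X != 0)
      return false;
  return true;
}

static std::vector<int>
concatViaInsertSketch(const std::vector<SubVecSketch> &Subs) {
  std::size_t NumSubElems = Subs.empty() ? 0 : Subs[0].size();
  std::vector<int> Result(Subs.size() * NumSubElems, 0); // Zero base vector.
  for (std::size_t i = 0; i != Subs.size(); ++i) {
    if (isAllZerosSketch(Subs[i]))
      continue;                                   // Covered by the zero base.
    for (std::size_t j = 0; j != NumSubElems; ++j)
      Result[i * NumSubElems + j] = Subs[i][j];   // insert_subvector at i*N.
  }
  return Result;
}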
static bool is256BitLaneRepeatedShuffleMask(MVT VT, ArrayRef<int> Mask, @@ -8537,7 +9015,7 @@ static bool isRepeatedTargetShuffleMask(unsigned LaneSizeInBits, MVT VT, return true; } -/// \brief Checks whether a shuffle mask is equivalent to an explicit list of +/// Checks whether a shuffle mask is equivalent to an explicit list of /// arguments. /// /// This is a fast way to test a shuffle mask against a fixed pattern: @@ -8634,7 +9112,7 @@ static bool isUnpackWdShuffleMask(ArrayRef<int> Mask, MVT VT) { return IsUnpackwdMask; } -/// \brief Get a 4-lane 8-bit shuffle immediate for a mask. +/// Get a 4-lane 8-bit shuffle immediate for a mask. /// /// This helper function produces an 8-bit shuffle immediate corresponding to /// the ubiquitous shuffle encoding scheme used in x86 instructions for @@ -8662,7 +9140,7 @@ static SDValue getV4X86ShuffleImm8ForMask(ArrayRef<int> Mask, const SDLoc &DL, return DAG.getConstant(getV4X86ShuffleImm(Mask), DL, MVT::i8); } -/// \brief Compute whether each element of a shuffle is zeroable. +/// Compute whether each element of a shuffle is zeroable. /// /// A "zeroable" vector shuffle element is one which can be lowered to zero. /// Either it is an undef element in the shuffle mask, the element of the input @@ -8859,8 +9337,8 @@ static SDValue lowerVectorShuffleToEXPAND(const SDLoc &DL, MVT VT, static bool matchVectorShuffleWithUNPCK(MVT VT, SDValue &V1, SDValue &V2, unsigned &UnpackOpcode, bool IsUnary, - ArrayRef<int> TargetMask, SDLoc &DL, - SelectionDAG &DAG, + ArrayRef<int> TargetMask, + const SDLoc &DL, SelectionDAG &DAG, const X86Subtarget &Subtarget) { int NumElts = VT.getVectorNumElements(); @@ -8969,6 +9447,99 @@ static SDValue lowerVectorShuffleWithUNPCK(const SDLoc &DL, MVT VT, return SDValue(); } +static bool matchVectorShuffleAsVPMOV(ArrayRef<int> Mask, bool SwappedOps, + int Delta) { + int Size = (int)Mask.size(); + int Split = Size / Delta; + int TruncatedVectorStart = SwappedOps ? Size : 0; + + // Match for mask starting with e.g.: <8, 10, 12, 14,... or <0, 2, 4, 6,... + if (!isSequentialOrUndefInRange(Mask, 0, Split, TruncatedVectorStart, Delta)) + return false; + + // The rest of the mask should not refer to the truncated vector's elements. + if (isAnyInRange(Mask.slice(Split, Size - Split), TruncatedVectorStart, + TruncatedVectorStart + Size)) + return false; + + return true; +} + +// Try to lower trunc+vector_shuffle to a vpmovdb or a vpmovdw instruction. +// +// An example is the following: +// +// t0: ch = EntryToken +// t2: v4i64,ch = CopyFromReg t0, Register:v4i64 %0 +// t25: v4i32 = truncate t2 +// t41: v8i16 = bitcast t25 +// t21: v8i16 = BUILD_VECTOR undef:i16, undef:i16, undef:i16, undef:i16, +// Constant:i16<0>, Constant:i16<0>, Constant:i16<0>, Constant:i16<0> +// t51: v8i16 = vector_shuffle<0,2,4,6,12,13,14,15> t41, t21 +// t18: v2i64 = bitcast t51 +// +// Without avx512vl, this is lowered to: +// +// vpmovqd %zmm0, %ymm0 +// vpshufb {{.*#+}} xmm0 = +// xmm0[0,1,4,5,8,9,12,13],zero,zero,zero,zero,zero,zero,zero,zero +// +// But when avx512vl is available, one can just use a single vpmovdw +// instruction. 
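// A sketch of the 4-lane, 2-bits-per-lane immediate encoding that
// getV4X86ShuffleImm8ForMask produces for PSHUFD/SHUFPS-style instructions:
// lane i of the result takes its source from bits [2*i+1 : 2*i] of the
// immediate; undef mask entries may pick anything, so 0 is used here.
#include <array>
#include <cstdint>

static uint8_t shuffleImm8Sketch(const std::array<int, 4> &Mask) {
  uint8_t Imm = 0;
  for (unsigned i = 0; i != 4; ++i) {
    int M = Mask[i] < 0 ? 0 : (Mask[i] & 3);  // Undef -> arbitrary lane (0).
    Imm |= (uint8_t)(M << (2 * i));
  }
  return Imm;                                 // e.g. {3, 2, 1, 0} -> 0x1B.
}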
+static SDValue lowerVectorShuffleWithVPMOV(const SDLoc &DL, ArrayRef<int> Mask, + MVT VT, SDValue V1, SDValue V2, + SelectionDAG &DAG, + const X86Subtarget &Subtarget) { + if (VT != MVT::v16i8 && VT != MVT::v8i16) + return SDValue(); + + if (Mask.size() != VT.getVectorNumElements()) + return SDValue(); + + bool SwappedOps = false; + + if (!ISD::isBuildVectorAllZeros(V2.getNode())) { + if (!ISD::isBuildVectorAllZeros(V1.getNode())) + return SDValue(); + + std::swap(V1, V2); + SwappedOps = true; + } + + // Look for: + // + // bitcast (truncate <8 x i32> %vec to <8 x i16>) to <16 x i8> + // bitcast (truncate <4 x i64> %vec to <4 x i32>) to <8 x i16> + // + // and similar ones. + if (V1.getOpcode() != ISD::BITCAST) + return SDValue(); + if (V1.getOperand(0).getOpcode() != ISD::TRUNCATE) + return SDValue(); + + SDValue Src = V1.getOperand(0).getOperand(0); + MVT SrcVT = Src.getSimpleValueType(); + + // The vptrunc** instructions truncating 128 bit and 256 bit vectors + // are only available with avx512vl. + if (!SrcVT.is512BitVector() && !Subtarget.hasVLX()) + return SDValue(); + + // Down Convert Word to Byte is only available with avx512bw. The case with + // 256-bit output doesn't contain a shuffle and is therefore not handled here. + if (SrcVT.getVectorElementType() == MVT::i16 && VT == MVT::v16i8 && + !Subtarget.hasBWI()) + return SDValue(); + + // The first half/quarter of the mask should refer to every second/fourth + // element of the vector truncated and bitcasted. + if (!matchVectorShuffleAsVPMOV(Mask, SwappedOps, 2) && + !matchVectorShuffleAsVPMOV(Mask, SwappedOps, 4)) + return SDValue(); + + return DAG.getNode(X86ISD::VTRUNC, DL, VT, Src); +} + // X86 has dedicated pack instructions that can handle specific truncation // operations: PACKSS and PACKUS. static bool matchVectorShuffleWithPACK(MVT VT, MVT &SrcVT, SDValue &V1, @@ -8984,15 +9555,6 @@ static bool matchVectorShuffleWithPACK(MVT VT, MVT &SrcVT, SDValue &V1, auto MatchPACK = [&](SDValue N1, SDValue N2) { SDValue VV1 = DAG.getBitcast(PackVT, N1); SDValue VV2 = DAG.getBitcast(PackVT, N2); - if ((N1.isUndef() || DAG.ComputeNumSignBits(VV1) > BitSize) && - (N2.isUndef() || DAG.ComputeNumSignBits(VV2) > BitSize)) { - V1 = VV1; - V2 = VV2; - SrcVT = PackVT; - PackOpcode = X86ISD::PACKSS; - return true; - } - if (Subtarget.hasSSE41() || PackSVT == MVT::i16) { APInt ZeroMask = APInt::getHighBitsSet(BitSize * 2, BitSize); if ((N1.isUndef() || DAG.MaskedValueIsZero(VV1, ZeroMask)) && @@ -9004,7 +9566,14 @@ static bool matchVectorShuffleWithPACK(MVT VT, MVT &SrcVT, SDValue &V1, return true; } } - + if ((N1.isUndef() || DAG.ComputeNumSignBits(VV1) > BitSize) && + (N2.isUndef() || DAG.ComputeNumSignBits(VV2) > BitSize)) { + V1 = VV1; + V2 = VV2; + SrcVT = PackVT; + PackOpcode = X86ISD::PACKSS; + return true; + } return false; }; @@ -9039,7 +9608,7 @@ static SDValue lowerVectorShuffleWithPACK(const SDLoc &DL, MVT VT, return SDValue(); } -/// \brief Try to emit a bitmask instruction for a shuffle. +/// Try to emit a bitmask instruction for a shuffle. /// /// This handles cases where we can model a blend exactly as a bitmask due to /// one of the inputs being zeroable. @@ -9072,7 +9641,7 @@ static SDValue lowerVectorShuffleAsBitMask(const SDLoc &DL, MVT VT, SDValue V1, return DAG.getNode(ISD::AND, DL, VT, V, VMask); } -/// \brief Try to emit a blend instruction for a shuffle using bit math. +/// Try to emit a blend instruction for a shuffle using bit math. 
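// Scalar model of why the PACKSS/PACKUS matching above can stand in for a
// truncating shuffle: the pack instructions saturate each 16-bit lane to 8
// bits, so they equal a plain truncation exactly when the upper byte is
// already known to be zero (PACKUS) or when the value is already the
// sign-extension of an 8-bit value, i.e. has enough known sign bits (PACKSS).
#include <cstdint>

static uint8_t packusLaneSketch(int16_t X) {   // Unsigned saturation to i8.
  return X < 0 ? 0 : (X > 255 ? 255 : (uint8_t)X);
}

static int8_t packssLaneSketch(int16_t X) {    // Signed saturation to i8.
  return X < -128 ? -128 : (X > 127 ? 127 : (int8_t)X);
}

// Holds for every X: if the top byte is zero, PACKUS == truncate; if X fits in
// a sign-extended i8, PACKSS == truncate.
static bool packEqualsTruncateSketch(int16_t X) {
  bool HighByteZero = (uint16_t)X <= 0xff;
  bool IsSExtI8 = X >= -128 && X <= 127;
  uint8_t Trunc = (uint8_t)X;
  return (!HighByteZero || packusLaneSketch(X) == Trunc) &&
         (!IsSExtI8 || (uint8_t)packssLaneSketch(X) == Trunc);
}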
/// /// This is used as a fallback approach when first class blend instructions are /// unavailable. Currently it is only suitable for integer vectors, but could @@ -9159,7 +9728,7 @@ static uint64_t scaleVectorShuffleBlendMask(uint64_t BlendMask, int Size, return ScaledMask; } -/// \brief Try to emit a blend instruction for a shuffle. +/// Try to emit a blend instruction for a shuffle. /// /// This doesn't do any checks for the availability of instructions for blending /// these values. It relies on the availability of the X86ISD::BLENDI pattern to @@ -9305,7 +9874,7 @@ static SDValue lowerVectorShuffleAsBlend(const SDLoc &DL, MVT VT, SDValue V1, } } -/// \brief Try to lower as a blend of elements from two inputs followed by +/// Try to lower as a blend of elements from two inputs followed by /// a single-input permutation. /// /// This matches the pattern where we can blend elements from two inputs and @@ -9337,7 +9906,7 @@ static SDValue lowerVectorShuffleAsBlendAndPermute(const SDLoc &DL, MVT VT, return DAG.getVectorShuffle(VT, DL, V, DAG.getUNDEF(VT), PermuteMask); } -/// \brief Generic routine to decompose a shuffle and blend into independent +/// Generic routine to decompose a shuffle and blend into independent /// blends and permutes. /// /// This matches the extremely common pattern for handling combined @@ -9378,7 +9947,7 @@ static SDValue lowerVectorShuffleAsDecomposedShuffleBlend(const SDLoc &DL, return DAG.getVectorShuffle(VT, DL, V1, V2, BlendMask); } -/// \brief Try to lower a vector shuffle as a rotation. +/// Try to lower a vector shuffle as a rotation. /// /// This is used for support PALIGNR for SSSE3 or VALIGND/Q for AVX512. static int matchVectorShuffleAsRotate(SDValue &V1, SDValue &V2, @@ -9450,7 +10019,7 @@ static int matchVectorShuffleAsRotate(SDValue &V1, SDValue &V2, return Rotation; } -/// \brief Try to lower a vector shuffle as a byte rotation. +/// Try to lower a vector shuffle as a byte rotation. /// /// SSSE3 has a generic PALIGNR instruction in x86 that will do an arbitrary /// byte-rotation of the concatenation of two vectors; pre-SSSE3 can use @@ -9534,7 +10103,7 @@ static SDValue lowerVectorShuffleAsByteRotate(const SDLoc &DL, MVT VT, DAG.getNode(ISD::OR, DL, MVT::v16i8, LoShift, HiShift)); } -/// \brief Try to lower a vector shuffle as a dword/qword rotation. +/// Try to lower a vector shuffle as a dword/qword rotation. /// /// AVX512 has a VALIGND/VALIGNQ instructions that will do an arbitrary /// rotation of the concatenation of two vectors; This routine will @@ -9565,7 +10134,7 @@ static SDValue lowerVectorShuffleAsRotate(const SDLoc &DL, MVT VT, DAG.getConstant(Rotation, DL, MVT::i8)); } -/// \brief Try to lower a vector shuffle as a bit shift (shifts in zeros). +/// Try to lower a vector shuffle as a bit shift (shifts in zeros). /// /// Attempts to match a shuffle mask against the PSLL(W/D/Q/DQ) and /// PSRL(W/D/Q/DQ) SSE2 and AVX2 logical bit-shift instructions. The function @@ -9809,7 +10378,7 @@ static bool matchVectorShuffleAsINSERTQ(MVT VT, SDValue &V1, SDValue &V2, return false; } -/// \brief Try to lower a vector shuffle using SSE4a EXTRQ/INSERTQ. +/// Try to lower a vector shuffle using SSE4a EXTRQ/INSERTQ. static SDValue lowerVectorShuffleWithSSE4A(const SDLoc &DL, MVT VT, SDValue V1, SDValue V2, ArrayRef<int> Mask, const APInt &Zeroable, @@ -9829,7 +10398,7 @@ static SDValue lowerVectorShuffleWithSSE4A(const SDLoc &DL, MVT VT, SDValue V1, return SDValue(); } -/// \brief Lower a vector shuffle as a zero or any extension. 
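// Scalar picture (illustrative; the operand-to-register mapping is glossed
// over) of the byte rotation the PALIGNR-based lowering above matches:
// conceptually the two 16-byte inputs are concatenated and the pair is shifted
// right by R bytes, so a shuffle mask that reads indices R, R+1, ..., R+15 out
// of that concatenation is a single rotation for R in [0, 16).
#include <array>
#include <cstdint>

static std::array<uint8_t, 16>
byteRotateConcatSketch(const std::array<uint8_t, 16> &Lo,
                       const std::array<uint8_t, 16> &Hi, unsigned R) {
  std::array<uint8_t, 16> Out{};
  for (unsigned i = 0; i != 16; ++i) {
    unsigned Src = i + R;                        // Index into concat(Lo, Hi).
    Out[i] = Src < 16 ? Lo[Src] : Hi[Src - 16];  // Valid for R < 16.
  }
  return Out;
}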
+/// Lower a vector shuffle as a zero or any extension. /// /// Given a specific number of elements, element bit width, and extension /// stride, produce either a zero or any extension based on the available @@ -9984,7 +10553,7 @@ static SDValue lowerVectorShuffleAsSpecificZeroOrAnyExtend( return DAG.getBitcast(VT, InputV); } -/// \brief Try to lower a vector shuffle as a zero extension on any microarch. +/// Try to lower a vector shuffle as a zero extension on any microarch. /// /// This routine will try to do everything in its power to cleverly lower /// a shuffle which happens to match the pattern of a zero extend. It doesn't @@ -10112,7 +10681,7 @@ static SDValue lowerVectorShuffleAsZeroOrAnyExtend( return SDValue(); } -/// \brief Try to get a scalar value for a specific element of a vector. +/// Try to get a scalar value for a specific element of a vector. /// /// Looks through BUILD_VECTOR and SCALAR_TO_VECTOR nodes to find a scalar. static SDValue getScalarValueForVectorElement(SDValue V, int Idx, @@ -10139,7 +10708,7 @@ static SDValue getScalarValueForVectorElement(SDValue V, int Idx, return SDValue(); } -/// \brief Helper to test for a load that can be folded with x86 shuffles. +/// Helper to test for a load that can be folded with x86 shuffles. /// /// This is particularly important because the set of instructions varies /// significantly based on whether the operand is a load or not. @@ -10148,7 +10717,7 @@ static bool isShuffleFoldableLoad(SDValue V) { return ISD::isNON_EXTLoad(V.getNode()); } -/// \brief Try to lower insertion of a single element into a zero vector. +/// Try to lower insertion of a single element into a zero vector. /// /// This is a common pattern that we have especially efficient patterns to lower /// across all subtarget feature sets. @@ -10239,9 +10808,7 @@ static SDValue lowerVectorShuffleAsElementInsertion( V2 = DAG.getBitcast(MVT::v16i8, V2); V2 = DAG.getNode( X86ISD::VSHLDQ, DL, MVT::v16i8, V2, - DAG.getConstant(V2Index * EltVT.getSizeInBits() / 8, DL, - DAG.getTargetLoweringInfo().getScalarShiftAmountTy( - DAG.getDataLayout(), VT))); + DAG.getConstant(V2Index * EltVT.getSizeInBits() / 8, DL, MVT::i8)); V2 = DAG.getBitcast(VT, V2); } } @@ -10295,13 +10862,13 @@ static SDValue lowerVectorShuffleAsTruncBroadcast(const SDLoc &DL, MVT VT, // vpbroadcast+vmovd+shr to vpshufb(m)+vmovd. if (const int OffsetIdx = BroadcastIdx % Scale) Scalar = DAG.getNode(ISD::SRL, DL, Scalar.getValueType(), Scalar, - DAG.getConstant(OffsetIdx * EltSize, DL, Scalar.getValueType())); + DAG.getConstant(OffsetIdx * EltSize, DL, MVT::i8)); return DAG.getNode(X86ISD::VBROADCAST, DL, VT, DAG.getNode(ISD::TRUNCATE, DL, EltVT, Scalar)); } -/// \brief Try to lower broadcast of a single element. +/// Try to lower broadcast of a single element. /// /// For convenience, this code also bundles all of the subtarget feature set /// filtering. While a little annoying to re-dispatch on type here, there isn't @@ -10626,7 +11193,7 @@ static SDValue lowerVectorShuffleAsInsertPS(const SDLoc &DL, SDValue V1, DAG.getConstant(InsertPSMask, DL, MVT::i8)); } -/// \brief Try to lower a shuffle as a permute of the inputs followed by an +/// Try to lower a shuffle as a permute of the inputs followed by an /// UNPCK instruction. /// /// This specifically targets cases where we end up with alternating between @@ -10738,7 +11305,7 @@ static SDValue lowerVectorShuffleAsPermuteAndUnpack(const SDLoc &DL, MVT VT, return SDValue(); } -/// \brief Handle lowering of 2-lane 64-bit floating point shuffles. 
+/// Handle lowering of 2-lane 64-bit floating point shuffles. /// /// This is the basis function for the 2-lane 64-bit shuffles as we have full /// support for floating point shuffles but not integer shuffles. These @@ -10777,22 +11344,23 @@ static SDValue lowerV2F64VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, Mask[1] == SM_SentinelUndef ? DAG.getUNDEF(MVT::v2f64) : V1, DAG.getConstant(SHUFPDMask, DL, MVT::i8)); } - assert(Mask[0] >= 0 && Mask[0] < 2 && "Non-canonicalized blend!"); - assert(Mask[1] >= 2 && "Non-canonicalized blend!"); + assert(Mask[0] >= 0 && "No undef lanes in multi-input v2 shuffles!"); + assert(Mask[1] >= 0 && "No undef lanes in multi-input v2 shuffles!"); + assert(Mask[0] < 2 && "We sort V1 to be the first input."); + assert(Mask[1] >= 2 && "We sort V2 to be the second input."); - // If we have a single input, insert that into V1 if we can do so cheaply. - if ((Mask[0] >= 2) + (Mask[1] >= 2) == 1) { - if (SDValue Insertion = lowerVectorShuffleAsElementInsertion( - DL, MVT::v2f64, V1, V2, Mask, Zeroable, Subtarget, DAG)) - return Insertion; - // Try inverting the insertion since for v2 masks it is easy to do and we - // can't reliably sort the mask one way or the other. - int InverseMask[2] = {Mask[0] < 0 ? -1 : (Mask[0] ^ 2), - Mask[1] < 0 ? -1 : (Mask[1] ^ 2)}; - if (SDValue Insertion = lowerVectorShuffleAsElementInsertion( - DL, MVT::v2f64, V2, V1, InverseMask, Zeroable, Subtarget, DAG)) - return Insertion; - } + // When loading a scalar and then shuffling it into a vector we can often do + // the insertion cheaply. + if (SDValue Insertion = lowerVectorShuffleAsElementInsertion( + DL, MVT::v2f64, V1, V2, Mask, Zeroable, Subtarget, DAG)) + return Insertion; + // Try inverting the insertion since for v2 masks it is easy to do and we + // can't reliably sort the mask one way or the other. + int InverseMask[2] = {Mask[0] < 0 ? -1 : (Mask[0] ^ 2), + Mask[1] < 0 ? -1 : (Mask[1] ^ 2)}; + if (SDValue Insertion = lowerVectorShuffleAsElementInsertion( + DL, MVT::v2f64, V2, V1, InverseMask, Zeroable, Subtarget, DAG)) + return Insertion; // Try to use one of the special instruction patterns to handle two common // blend patterns if a zero-blend above didn't work. @@ -10802,8 +11370,7 @@ static SDValue lowerV2F64VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, // We can either use a special instruction to load over the low double or // to move just the low double. return DAG.getNode( - isShuffleFoldableLoad(V1S) ? X86ISD::MOVLPD : X86ISD::MOVSD, - DL, MVT::v2f64, V2, + X86ISD::MOVSD, DL, MVT::v2f64, V2, DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, MVT::v2f64, V1S)); if (Subtarget.hasSSE41()) @@ -10821,7 +11388,7 @@ static SDValue lowerV2F64VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, DAG.getConstant(SHUFPDMask, DL, MVT::i8)); } -/// \brief Handle lowering of 2-lane 64-bit integer shuffles. +/// Handle lowering of 2-lane 64-bit integer shuffles. /// /// Tries to lower a 2-lane 64-bit shuffle using shuffle operations provided by /// the integer unit to minimize domain crossing penalties. However, for blends @@ -10918,7 +11485,7 @@ static SDValue lowerV2I64VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, DAG.getVectorShuffle(MVT::v2f64, DL, V1, V2, Mask)); } -/// \brief Test whether this can be lowered with a single SHUFPS instruction. +/// Test whether this can be lowered with a single SHUFPS instruction. /// /// This is used to disable more specialized lowerings when the shufps lowering /// will happen to be efficient. 
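// A tiny illustrative helper for the v2 mask inversion used above: with two
// elements per operand, XOR-ing a defined mask entry with 2 flips which
// operand it reads from (0<->2, 1<->3), which is what lets the element
// insertion lowering be retried with V1 and V2 swapped when the first attempt
// fails.
#include <array>

static std::array<int, 2>
invertTwoLaneMaskSketch(const std::array<int, 2> &Mask) {
  return {Mask[0] < 0 ? -1 : (Mask[0] ^ 2),
          Mask[1] < 0 ? -1 : (Mask[1] ^ 2)};
}
// {0, 3} (lane 0 of V1, lane 1 of V2) becomes {2, 1} once the operands swap.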
@@ -10940,7 +11507,7 @@ static bool isSingleSHUFPSMask(ArrayRef<int> Mask) { return true; } -/// \brief Lower a vector shuffle using the SHUFPS instruction. +/// Lower a vector shuffle using the SHUFPS instruction. /// /// This is a helper routine dedicated to lowering vector shuffles using SHUFPS. /// It makes no assumptions about whether this is the *best* lowering, it simply @@ -11027,7 +11594,7 @@ static SDValue lowerVectorShuffleWithSHUFPS(const SDLoc &DL, MVT VT, getV4X86ShuffleImm8ForMask(NewMask, DL, DAG)); } -/// \brief Lower 4-lane 32-bit floating point shuffles. +/// Lower 4-lane 32-bit floating point shuffles. /// /// Uses instructions exclusively from the floating point unit to minimize /// domain crossing penalties, as these are sufficient to implement all v4f32 @@ -11123,7 +11690,7 @@ static SDValue lowerV4F32VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, return lowerVectorShuffleWithSHUFPS(DL, MVT::v4f32, Mask, V1, V2, DAG); } -/// \brief Lower 4-lane i32 vector shuffles. +/// Lower 4-lane i32 vector shuffles. /// /// We try to handle these with integer-domain shuffles where we can, but for /// blends we use the floating point domain blend instructions. @@ -11235,7 +11802,7 @@ static SDValue lowerV4I32VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, return DAG.getBitcast(MVT::v4i32, ShufPS); } -/// \brief Lowering of single-input v8i16 shuffles is the cornerstone of SSE2 +/// Lowering of single-input v8i16 shuffles is the cornerstone of SSE2 /// shuffle lowering, and the most complex part. /// /// The lowering strategy is to try to form pairs of input lanes which are @@ -11261,13 +11828,27 @@ static SDValue lowerV8I16GeneralSingleInputVectorShuffle( MutableArrayRef<int> LoMask = Mask.slice(0, 4); MutableArrayRef<int> HiMask = Mask.slice(4, 4); + // Attempt to directly match PSHUFLW or PSHUFHW. + if (isUndefOrInRange(LoMask, 0, 4) && + isSequentialOrUndefInRange(HiMask, 0, 4, 4)) { + return DAG.getNode(X86ISD::PSHUFLW, DL, VT, V, + getV4X86ShuffleImm8ForMask(LoMask, DL, DAG)); + } + if (isUndefOrInRange(HiMask, 4, 8) && + isSequentialOrUndefInRange(LoMask, 0, 4, 0)) { + for (int i = 0; i != 4; ++i) + HiMask[i] = (HiMask[i] < 0 ? HiMask[i] : (HiMask[i] - 4)); + return DAG.getNode(X86ISD::PSHUFHW, DL, VT, V, + getV4X86ShuffleImm8ForMask(HiMask, DL, DAG)); + } + SmallVector<int, 4> LoInputs; copy_if(LoMask, std::back_inserter(LoInputs), [](int M) { return M >= 0; }); - std::sort(LoInputs.begin(), LoInputs.end()); + array_pod_sort(LoInputs.begin(), LoInputs.end()); LoInputs.erase(std::unique(LoInputs.begin(), LoInputs.end()), LoInputs.end()); SmallVector<int, 4> HiInputs; copy_if(HiMask, std::back_inserter(HiInputs), [](int M) { return M >= 0; }); - std::sort(HiInputs.begin(), HiInputs.end()); + array_pod_sort(HiInputs.begin(), HiInputs.end()); HiInputs.erase(std::unique(HiInputs.begin(), HiInputs.end()), HiInputs.end()); int NumLToL = std::lower_bound(LoInputs.begin(), LoInputs.end(), 4) - LoInputs.begin(); @@ -11280,13 +11861,11 @@ static SDValue lowerV8I16GeneralSingleInputVectorShuffle( MutableArrayRef<int> HToLInputs(LoInputs.data() + NumLToL, NumHToL); MutableArrayRef<int> HToHInputs(HiInputs.data() + NumLToH, NumHToH); - // If we are splatting two values from one half - one to each half, then - // we can shuffle that half so each is splatted to a dword, then splat those - // to their respective halves. 
- auto SplatHalfs = [&](int LoInput, int HiInput, unsigned ShufWOp, - int DOffset) { - int PSHUFHalfMask[] = {LoInput % 4, LoInput % 4, HiInput % 4, HiInput % 4}; - int PSHUFDMask[] = {DOffset + 0, DOffset + 0, DOffset + 1, DOffset + 1}; + // If we are shuffling values from one half - check how many different DWORD + // pairs we need to create. If only 1 or 2 then we can perform this as a + // PSHUFLW/PSHUFHW + PSHUFD instead of the PSHUFD+PSHUFLW+PSHUFHW chain below. + auto ShuffleDWordPairs = [&](ArrayRef<int> PSHUFHalfMask, + ArrayRef<int> PSHUFDMask, unsigned ShufWOp) { V = DAG.getNode(ShufWOp, DL, VT, V, getV4X86ShuffleImm8ForMask(PSHUFHalfMask, DL, DAG)); V = DAG.getBitcast(PSHUFDVT, V); @@ -11295,10 +11874,48 @@ static SDValue lowerV8I16GeneralSingleInputVectorShuffle( return DAG.getBitcast(VT, V); }; - if (NumLToL == 1 && NumLToH == 1 && (NumHToL + NumHToH) == 0) - return SplatHalfs(LToLInputs[0], LToHInputs[0], X86ISD::PSHUFLW, 0); - if (NumHToL == 1 && NumHToH == 1 && (NumLToL + NumLToH) == 0) - return SplatHalfs(HToLInputs[0], HToHInputs[0], X86ISD::PSHUFHW, 2); + if ((NumHToL + NumHToH) == 0 || (NumLToL + NumLToH) == 0) { + int PSHUFDMask[4] = { -1, -1, -1, -1 }; + SmallVector<std::pair<int, int>, 4> DWordPairs; + int DOffset = ((NumHToL + NumHToH) == 0 ? 0 : 2); + + // Collect the different DWORD pairs. + for (int DWord = 0; DWord != 4; ++DWord) { + int M0 = Mask[2 * DWord + 0]; + int M1 = Mask[2 * DWord + 1]; + M0 = (M0 >= 0 ? M0 % 4 : M0); + M1 = (M1 >= 0 ? M1 % 4 : M1); + if (M0 < 0 && M1 < 0) + continue; + + bool Match = false; + for (int j = 0, e = DWordPairs.size(); j < e; ++j) { + auto &DWordPair = DWordPairs[j]; + if ((M0 < 0 || isUndefOrEqual(DWordPair.first, M0)) && + (M1 < 0 || isUndefOrEqual(DWordPair.second, M1))) { + DWordPair.first = (M0 >= 0 ? M0 : DWordPair.first); + DWordPair.second = (M1 >= 0 ? M1 : DWordPair.second); + PSHUFDMask[DWord] = DOffset + j; + Match = true; + break; + } + } + if (!Match) { + PSHUFDMask[DWord] = DOffset + DWordPairs.size(); + DWordPairs.push_back(std::make_pair(M0, M1)); + } + } + + if (DWordPairs.size() <= 2) { + DWordPairs.resize(2, std::make_pair(-1, -1)); + int PSHUFHalfMask[4] = {DWordPairs[0].first, DWordPairs[0].second, + DWordPairs[1].first, DWordPairs[1].second}; + if ((NumHToL + NumHToH) == 0) + return ShuffleDWordPairs(PSHUFHalfMask, PSHUFDMask, X86ISD::PSHUFLW); + if ((NumLToL + NumLToH) == 0) + return ShuffleDWordPairs(PSHUFHalfMask, PSHUFDMask, X86ISD::PSHUFHW); + } + } // Simplify the 1-into-3 and 3-into-1 cases with a single pshufd. For all // such inputs we can swap two of the dwords across the half mark and end up @@ -11750,7 +12367,7 @@ static SDValue lowerVectorShuffleAsBlendOfPSHUFBs( return DAG.getBitcast(VT, V); } -/// \brief Generic lowering of 8-lane i16 shuffles. +/// Generic lowering of 8-lane i16 shuffles. /// /// This handles both single-input shuffles and combined shuffle/blends with /// two inputs. The single input shuffles are immediately delegated to @@ -11883,7 +12500,7 @@ static SDValue lowerV8I16VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, Mask, DAG); } -/// \brief Check whether a compaction lowering can be done by dropping even +/// Check whether a compaction lowering can be done by dropping even /// elements and compute how many times even elements must be dropped. 
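// Standalone sketch of the dword-pair counting introduced above for
// single-input v8i16 shuffles whose inputs all come from one half: each output
// dword wants some pair of half-relative words (M0 % 4, M1 % 4). If at most
// two distinct pairs are needed, one PSHUFLW/PSHUFHW can materialise them and
// a PSHUFD can place them, avoiding the longer PSHUFD+PSHUFLW+PSHUFHW chain.
// DOffset is 0 for the low half and 2 for the high half, as in the code above;
// names and containers are illustrative.
#include <utility>
#include <vector>

static bool countDWordPairsSketch(const int Mask[8], int DOffset,
                                  int PSHUFDMask[4],
                                  std::vector<std::pair<int, int>> &Pairs) {
  for (int DWord = 0; DWord != 4; ++DWord) {
    int M0 = Mask[2 * DWord + 0];
    int M1 = Mask[2 * DWord + 1];
    M0 = M0 >= 0 ? M0 % 4 : M0;                  // Make indices half-relative.
    M1 = M1 >= 0 ? M1 % 4 : M1;
    if (M0 < 0 && M1 < 0) {
      PSHUFDMask[DWord] = -1;                    // Fully undef output dword.
      continue;
    }
    bool Matched = false;
    for (int j = 0, e = (int)Pairs.size(); j != e; ++j) {
      std::pair<int, int> &P = Pairs[j];
      if ((M0 < 0 || P.first < 0 || P.first == M0) &&
          (M1 < 0 || P.second < 0 || P.second == M1)) {
        P.first = M0 >= 0 ? M0 : P.first;        // Fold undefs into the pair.
        P.second = M1 >= 0 ? M1 : P.second;
        PSHUFDMask[DWord] = DOffset + j;
        Matched = true;
        break;
      }
    }
    if (!Matched) {
      PSHUFDMask[DWord] = DOffset + (int)Pairs.size();
      Pairs.push_back(std::make_pair(M0, M1));
    }
  }
  return Pairs.size() <= 2;                      // Cheap PSHUF*W + PSHUFD path.
}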
/// /// This handles shuffles which take every Nth element where N is a power of @@ -11962,7 +12579,7 @@ static SDValue lowerVectorShuffleWithPERMV(const SDLoc &DL, MVT VT, return DAG.getNode(X86ISD::VPERMV3, DL, VT, V1, MaskNode, V2); } -/// \brief Generic lowering of v16i8 shuffles. +/// Generic lowering of v16i8 shuffles. /// /// This is a hybrid strategy to lower v16i8 vectors. It first attempts to /// detect any complexity reducing interleaving. If that doesn't help, it uses @@ -12034,12 +12651,12 @@ static SDValue lowerV16I8VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, SmallVector<int, 4> LoInputs; copy_if(Mask, std::back_inserter(LoInputs), [](int M) { return M >= 0 && M < 8; }); - std::sort(LoInputs.begin(), LoInputs.end()); + array_pod_sort(LoInputs.begin(), LoInputs.end()); LoInputs.erase(std::unique(LoInputs.begin(), LoInputs.end()), LoInputs.end()); SmallVector<int, 4> HiInputs; copy_if(Mask, std::back_inserter(HiInputs), [](int M) { return M >= 8; }); - std::sort(HiInputs.begin(), HiInputs.end()); + array_pod_sort(HiInputs.begin(), HiInputs.end()); HiInputs.erase(std::unique(HiInputs.begin(), HiInputs.end()), HiInputs.end()); @@ -12262,7 +12879,7 @@ static SDValue lowerV16I8VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, return DAG.getNode(X86ISD::PACKUS, DL, MVT::v16i8, LoV, HiV); } -/// \brief Dispatching routine to lower various 128-bit x86 vector shuffles. +/// Dispatching routine to lower various 128-bit x86 vector shuffles. /// /// This routine breaks down the specific type of 128-bit shuffle and /// dispatches to the lowering routines accordingly. @@ -12290,7 +12907,7 @@ static SDValue lower128BitVectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, } } -/// \brief Generic routine to split vector shuffle into half-sized shuffles. +/// Generic routine to split vector shuffle into half-sized shuffles. /// /// This routine just extracts two subvectors, shuffles them independently, and /// then concatenates them back together. This should work effectively with all @@ -12413,7 +13030,7 @@ static SDValue splitAndLowerVectorShuffle(const SDLoc &DL, MVT VT, SDValue V1, return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, Lo, Hi); } -/// \brief Either split a vector in halves or decompose the shuffles and the +/// Either split a vector in halves or decompose the shuffles and the /// blend. /// /// This is provided as a good fallback for many lowerings of non-single-input @@ -12471,7 +13088,7 @@ static SDValue lowerVectorShuffleAsSplitOrBlend(const SDLoc &DL, MVT VT, return lowerVectorShuffleAsDecomposedShuffleBlend(DL, VT, V1, V2, Mask, DAG); } -/// \brief Lower a vector shuffle crossing multiple 128-bit lanes as +/// Lower a vector shuffle crossing multiple 128-bit lanes as /// a permutation and blend of those lanes. /// /// This essentially blends the out-of-lane inputs to each lane into the lane @@ -12529,7 +13146,7 @@ static SDValue lowerVectorShuffleAsLanePermuteAndBlend(const SDLoc &DL, MVT VT, return DAG.getVectorShuffle(VT, DL, V1, Flipped, FlippedBlendMask); } -/// \brief Handle lowering 2-lane 128-bit shuffles. +/// Handle lowering 2-lane 128-bit shuffles. 
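// A simplified sketch (looser than the in-tree matcher, which also handles
// interleaved operand layouts) of the compaction test described above: the
// mask must take every Nth element of the concatenated inputs, with N a power
// of two, and the return value is how many times the even elements get
// dropped (0 means no compaction was found).
#include <vector>

static int compactionDropCountSketch(const std::vector<int> &Mask) {
  for (int Drops = 1; Drops <= 2; ++Drops) {     // Keep every 2nd or every 4th.
    int N = 1 << Drops;
    bool Viable = true;
    for (int i = 0, e = (int)Mask.size(); i != e && Viable; ++i)
      Viable = Mask[i] < 0 || Mask[i] == i * N;  // Undef lanes are fine.
    if (Viable)
      return Drops;
  }
  return 0;
}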
static SDValue lowerV2X128VectorShuffle(const SDLoc &DL, MVT VT, SDValue V1, SDValue V2, ArrayRef<int> Mask, const APInt &Zeroable, @@ -12540,9 +13157,22 @@ static SDValue lowerV2X128VectorShuffle(const SDLoc &DL, MVT VT, SDValue V1, return SDValue(); SmallVector<int, 4> WidenedMask; - if (!canWidenShuffleElements(Mask, WidenedMask)) + if (!canWidenShuffleElements(Mask, Zeroable, WidenedMask)) return SDValue(); + bool IsLowZero = (Zeroable & 0x3) == 0x3; + bool IsHighZero = (Zeroable & 0xc) == 0xc; + + // Try to use an insert into a zero vector. + if (WidenedMask[0] == 0 && IsHighZero) { + MVT SubVT = MVT::getVectorVT(VT.getVectorElementType(), 2); + SDValue LoV = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, SubVT, V1, + DAG.getIntPtrConstant(0, DL)); + return DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT, + getZeroVector(VT, Subtarget, DAG, DL), LoV, + DAG.getIntPtrConstant(0, DL)); + } + // TODO: If minimizing size and one of the inputs is a zero vector and the // the zero vector has only one use, we could use a VPERM2X128 to save the // instruction bytes needed to explicitly generate the zero vector. @@ -12552,9 +13182,6 @@ static SDValue lowerV2X128VectorShuffle(const SDLoc &DL, MVT VT, SDValue V1, Zeroable, Subtarget, DAG)) return Blend; - bool IsLowZero = (Zeroable & 0x3) == 0x3; - bool IsHighZero = (Zeroable & 0xc) == 0xc; - // If either input operand is a zero vector, use VPERM2X128 because its mask // allows us to replace the zero input with an implicit zero. if (!IsLowZero && !IsHighZero) { @@ -12566,14 +13193,12 @@ static SDValue lowerV2X128VectorShuffle(const SDLoc &DL, MVT VT, SDValue V1, // With AVX1, use vperm2f128 (below) to allow load folding. Otherwise, // this will likely become vinsertf128 which can't fold a 256-bit memop. if (!isa<LoadSDNode>(peekThroughBitcasts(V1))) { - MVT SubVT = MVT::getVectorVT(VT.getVectorElementType(), - VT.getVectorNumElements() / 2); - SDValue LoV = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, SubVT, V1, - DAG.getIntPtrConstant(0, DL)); - SDValue HiV = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, SubVT, - OnlyUsesV1 ? V1 : V2, - DAG.getIntPtrConstant(0, DL)); - return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, LoV, HiV); + MVT SubVT = MVT::getVectorVT(VT.getVectorElementType(), 2); + SDValue SubVec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, SubVT, + OnlyUsesV1 ? V1 : V2, + DAG.getIntPtrConstant(0, DL)); + return DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT, V1, SubVec, + DAG.getIntPtrConstant(2, DL)); } } @@ -12601,7 +13226,8 @@ static SDValue lowerV2X128VectorShuffle(const SDLoc &DL, MVT VT, SDValue V1, // [6] - ignore // [7] - zero high half of destination - assert(WidenedMask[0] >= 0 && WidenedMask[1] >= 0 && "Undef half?"); + assert((WidenedMask[0] >= 0 || IsLowZero) && + (WidenedMask[1] >= 0 || IsHighZero) && "Undef half?"); unsigned PermMask = 0; PermMask |= IsLowZero ? 0x08 : (WidenedMask[0] << 0); @@ -12617,7 +13243,7 @@ static SDValue lowerV2X128VectorShuffle(const SDLoc &DL, MVT VT, SDValue V1, DAG.getConstant(PermMask, DL, MVT::i8)); } -/// \brief Lower a vector shuffle by first fixing the 128-bit lanes and then +/// Lower a vector shuffle by first fixing the 128-bit lanes and then /// shuffling each lane. 
/// /// This will only succeed when the result of fixing the 128-bit lanes results @@ -12820,7 +13446,7 @@ static SDValue lowerVectorShuffleWithUndefHalf(const SDLoc &DL, MVT VT, DAG.getIntPtrConstant(Offset, DL)); } -/// \brief Test whether the specified input (0 or 1) is in-place blended by the +/// Test whether the specified input (0 or 1) is in-place blended by the /// given mask. /// /// This returns true if the elements from a particular input are already in the @@ -13056,7 +13682,7 @@ static SDValue lowerVectorShuffleWithSHUFPD(const SDLoc &DL, MVT VT, DAG.getConstant(Immediate, DL, MVT::i8)); } -/// \brief Handle lowering of 4-lane 64-bit floating point shuffles. +/// Handle lowering of 4-lane 64-bit floating point shuffles. /// /// Also ends up handling lowering of 4-lane 64-bit integer shuffles when AVX2 /// isn't available. @@ -13098,7 +13724,7 @@ static SDValue lowerV4F64VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, getV4X86ShuffleImm8ForMask(Mask, DL, DAG)); // Try to create an in-lane repeating shuffle mask and then shuffle the - // the results into the target lanes. + // results into the target lanes. if (SDValue V = lowerShuffleAsRepeatedMaskAndLanePermute( DL, MVT::v4f64, V1, V2, Mask, Subtarget, DAG)) return V; @@ -13123,7 +13749,7 @@ static SDValue lowerV4F64VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, return Op; // Try to create an in-lane repeating shuffle mask and then shuffle the - // the results into the target lanes. + // results into the target lanes. if (SDValue V = lowerShuffleAsRepeatedMaskAndLanePermute( DL, MVT::v4f64, V1, V2, Mask, Subtarget, DAG)) return V; @@ -13153,7 +13779,7 @@ static SDValue lowerV4F64VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, return lowerVectorShuffleAsSplitOrBlend(DL, MVT::v4f64, V1, V2, Mask, DAG); } -/// \brief Handle lowering of 4-lane 64-bit integer shuffles. +/// Handle lowering of 4-lane 64-bit integer shuffles. /// /// This routine is only called when we have AVX2 and thus a reasonable /// instruction set for v4i64 shuffling.. @@ -13226,6 +13852,12 @@ static SDValue lowerV4I64VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, lowerVectorShuffleWithUNPCK(DL, MVT::v4i64, Mask, V1, V2, DAG)) return V; + // Try to create an in-lane repeating shuffle mask and then shuffle the + // results into the target lanes. + if (SDValue V = lowerShuffleAsRepeatedMaskAndLanePermute( + DL, MVT::v4i64, V1, V2, Mask, Subtarget, DAG)) + return V; + // Try to simplify this by merging 128-bit lanes to enable a lane-based // shuffle. However, if we have AVX2 and either inputs are already in place, // we will be able to shuffle even across lanes the other input in a single @@ -13241,7 +13873,7 @@ static SDValue lowerV4I64VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, Mask, DAG); } -/// \brief Handle lowering of 8-lane 32-bit floating point shuffles. +/// Handle lowering of 8-lane 32-bit floating point shuffles. /// /// Also ends up handling lowering of 8-lane 32-bit integer shuffles when AVX2 /// isn't available. @@ -13291,7 +13923,7 @@ static SDValue lowerV8F32VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, } // Try to create an in-lane repeating shuffle mask and then shuffle the - // the results into the target lanes. + // results into the target lanes. 
if (SDValue V = lowerShuffleAsRepeatedMaskAndLanePermute( DL, MVT::v8f32, V1, V2, Mask, Subtarget, DAG)) return V; @@ -13340,7 +13972,7 @@ static SDValue lowerV8F32VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, return lowerVectorShuffleAsSplitOrBlend(DL, MVT::v8f32, V1, V2, Mask, DAG); } -/// \brief Handle lowering of 8-lane 32-bit integer shuffles. +/// Handle lowering of 8-lane 32-bit integer shuffles. /// /// This routine is only called when we have AVX2 and thus a reasonable /// instruction set for v8i32 shuffling.. @@ -13453,7 +14085,7 @@ static SDValue lowerV8I32VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, Mask, DAG); } -/// \brief Handle lowering of 16-lane 16-bit integer shuffles. +/// Handle lowering of 16-lane 16-bit integer shuffles. /// /// This routine is only called when we have AVX2 and thus a reasonable /// instruction set for v16i16 shuffling.. @@ -13504,7 +14136,7 @@ static SDValue lowerV16I16VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, return Rotate; // Try to create an in-lane repeating shuffle mask and then shuffle the - // the results into the target lanes. + // results into the target lanes. if (SDValue V = lowerShuffleAsRepeatedMaskAndLanePermute( DL, MVT::v16i16, V1, V2, Mask, Subtarget, DAG)) return V; @@ -13544,7 +14176,7 @@ static SDValue lowerV16I16VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, return lowerVectorShuffleAsSplitOrBlend(DL, MVT::v16i16, V1, V2, Mask, DAG); } -/// \brief Handle lowering of 32-lane 8-bit integer shuffles. +/// Handle lowering of 32-lane 8-bit integer shuffles. /// /// This routine is only called when we have AVX2 and thus a reasonable /// instruction set for v32i8 shuffling.. @@ -13595,7 +14227,7 @@ static SDValue lowerV32I8VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, return Rotate; // Try to create an in-lane repeating shuffle mask and then shuffle the - // the results into the target lanes. + // results into the target lanes. if (SDValue V = lowerShuffleAsRepeatedMaskAndLanePermute( DL, MVT::v32i8, V1, V2, Mask, Subtarget, DAG)) return V; @@ -13624,7 +14256,7 @@ static SDValue lowerV32I8VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, return lowerVectorShuffleAsSplitOrBlend(DL, MVT::v32i8, V1, V2, Mask, DAG); } -/// \brief High-level routine to lower various 256-bit x86 vector shuffles. +/// High-level routine to lower various 256-bit x86 vector shuffles. /// /// This routine either breaks down the specific type of a 256-bit x86 vector /// shuffle or splits it into two 128-bit shuffles and fuses the results back @@ -13694,10 +14326,13 @@ static SDValue lower256BitVectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, } } -/// \brief Try to lower a vector shuffle as a 128-bit shuffles. +/// Try to lower a vector shuffle as a 128-bit shuffles. static SDValue lowerV4X128VectorShuffle(const SDLoc &DL, MVT VT, - ArrayRef<int> Mask, SDValue V1, - SDValue V2, SelectionDAG &DAG) { + ArrayRef<int> Mask, + const APInt &Zeroable, + SDValue V1, SDValue V2, + const X86Subtarget &Subtarget, + SelectionDAG &DAG) { assert(VT.getScalarSizeInBits() == 64 && "Unexpected element type size for 128bit shuffle."); @@ -13705,10 +14340,23 @@ static SDValue lowerV4X128VectorShuffle(const SDLoc &DL, MVT VT, // function lowerV2X128VectorShuffle() is better solution. assert(VT.is512BitVector() && "Unexpected vector size for 512bit shuffle."); + // TODO - use Zeroable like we do for lowerV2X128VectorShuffle? 
SmallVector<int, 4> WidenedMask; if (!canWidenShuffleElements(Mask, WidenedMask)) return SDValue(); + // Try to use an insert into a zero vector. + if (WidenedMask[0] == 0 && (Zeroable & 0xf0) == 0xf0 && + (WidenedMask[1] == 1 || (Zeroable & 0x0c) == 0x0c)) { + unsigned NumElts = ((Zeroable & 0x0c) == 0x0c) ? 2 : 4; + MVT SubVT = MVT::getVectorVT(VT.getVectorElementType(), NumElts); + SDValue LoV = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, SubVT, V1, + DAG.getIntPtrConstant(0, DL)); + return DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT, + getZeroVector(VT, Subtarget, DAG, DL), LoV, + DAG.getIntPtrConstant(0, DL)); + } + // Check for patterns which can be matched with a single insert of a 256-bit // subvector. bool OnlyUsesV1 = isShuffleEquivalent(V1, V2, Mask, @@ -13716,12 +14364,11 @@ static SDValue lowerV4X128VectorShuffle(const SDLoc &DL, MVT VT, if (OnlyUsesV1 || isShuffleEquivalent(V1, V2, Mask, {0, 1, 2, 3, 8, 9, 10, 11})) { MVT SubVT = MVT::getVectorVT(VT.getVectorElementType(), 4); - SDValue LoV = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, SubVT, V1, - DAG.getIntPtrConstant(0, DL)); - SDValue HiV = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, SubVT, - OnlyUsesV1 ? V1 : V2, + SDValue SubVec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, SubVT, + OnlyUsesV1 ? V1 : V2, DAG.getIntPtrConstant(0, DL)); - return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, LoV, HiV); + return DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT, V1, SubVec, + DAG.getIntPtrConstant(4, DL)); } assert(WidenedMask.size() == 4); @@ -13756,7 +14403,7 @@ static SDValue lowerV4X128VectorShuffle(const SDLoc &DL, MVT VT, return insert128BitVector(V1, Subvec, V2Index * 2, DAG, DL); } - // Try to lower to to vshuf64x2/vshuf32x4. + // Try to lower to vshuf64x2/vshuf32x4. SDValue Ops[2] = {DAG.getUNDEF(VT), DAG.getUNDEF(VT)}; unsigned PermMask = 0; // Insure elements came from the same Op. @@ -13781,7 +14428,7 @@ static SDValue lowerV4X128VectorShuffle(const SDLoc &DL, MVT VT, DAG.getConstant(PermMask, DL, MVT::i8)); } -/// \brief Handle lowering of 8-lane 64-bit floating point shuffles. +/// Handle lowering of 8-lane 64-bit floating point shuffles. static SDValue lowerV8F64VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, const APInt &Zeroable, SDValue V1, SDValue V2, @@ -13814,7 +14461,8 @@ static SDValue lowerV8F64VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, } if (SDValue Shuf128 = - lowerV4X128VectorShuffle(DL, MVT::v8f64, Mask, V1, V2, DAG)) + lowerV4X128VectorShuffle(DL, MVT::v8f64, Mask, Zeroable, V1, V2, + Subtarget, DAG)) return Shuf128; if (SDValue Unpck = @@ -13837,7 +14485,7 @@ static SDValue lowerV8F64VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, return lowerVectorShuffleWithPERMV(DL, MVT::v8f64, Mask, V1, V2, DAG); } -/// \brief Handle lowering of 16-lane 32-bit floating point shuffles. +/// Handle lowering of 16-lane 32-bit floating point shuffles. static SDValue lowerV16F32VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, const APInt &Zeroable, SDValue V1, SDValue V2, @@ -13892,7 +14540,7 @@ static SDValue lowerV16F32VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, return lowerVectorShuffleWithPERMV(DL, MVT::v16f32, Mask, V1, V2, DAG); } -/// \brief Handle lowering of 8-lane 64-bit integer shuffles. +/// Handle lowering of 8-lane 64-bit integer shuffles. 
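Both lowerV2X128VectorShuffle and lowerV4X128VectorShuffle above finish by packing a lane-select immediate. Below is a minimal standalone sketch of the VPERM2X128 encoding, assuming the bit layout documented in the comment table earlier ([1:0] and [5:4] pick the source lane, bits 3 and 7 zero the corresponding half); the function is hypothetical and not part of the patch.

// Illustrative only: build the vperm2f128/vperm2i128 immediate from a widened
// 2 x 128-bit lane mask (values 0-1 pick lanes of the first source, 2-3 of
// the second) plus per-half zeroable flags, mirroring the PermMask logic.
static unsigned modelVPerm2X128Imm(int WidenedLo, int WidenedHi,
                                   bool IsLowZero, bool IsHighZero) {
  unsigned Imm = 0;
  Imm |= IsLowZero ? 0x08u : (unsigned)(WidenedLo & 0x3);
  Imm |= IsHighZero ? 0x80u : (unsigned)((WidenedHi & 0x3) << 4);
  return Imm;
}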
static SDValue lowerV8I64VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, const APInt &Zeroable, SDValue V1, SDValue V2, @@ -13924,7 +14572,8 @@ static SDValue lowerV8I64VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, } if (SDValue Shuf128 = - lowerV4X128VectorShuffle(DL, MVT::v8i64, Mask, V1, V2, DAG)) + lowerV4X128VectorShuffle(DL, MVT::v8i64, Mask, Zeroable, + V1, V2, Subtarget, DAG)) return Shuf128; // Try to use shift instructions. @@ -13957,7 +14606,7 @@ static SDValue lowerV8I64VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, return lowerVectorShuffleWithPERMV(DL, MVT::v8i64, Mask, V1, V2, DAG); } -/// \brief Handle lowering of 16-lane 32-bit integer shuffles. +/// Handle lowering of 16-lane 32-bit integer shuffles. static SDValue lowerV16I32VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, const APInt &Zeroable, SDValue V1, SDValue V2, @@ -14028,7 +14677,7 @@ static SDValue lowerV16I32VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, return lowerVectorShuffleWithPERMV(DL, MVT::v16i32, Mask, V1, V2, DAG); } -/// \brief Handle lowering of 32-lane 16-bit integer shuffles. +/// Handle lowering of 32-lane 16-bit integer shuffles. static SDValue lowerV32I16VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, const APInt &Zeroable, SDValue V1, SDValue V2, @@ -14083,7 +14732,7 @@ static SDValue lowerV32I16VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, return lowerVectorShuffleWithPERMV(DL, MVT::v32i16, Mask, V1, V2, DAG); } -/// \brief Handle lowering of 64-lane 8-bit integer shuffles. +/// Handle lowering of 64-lane 8-bit integer shuffles. static SDValue lowerV64I8VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, const APInt &Zeroable, SDValue V1, SDValue V2, @@ -14125,7 +14774,7 @@ static SDValue lowerV64I8VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, return lowerVectorShuffleWithPERMV(DL, MVT::v64i8, Mask, V1, V2, DAG); // Try to create an in-lane repeating shuffle mask and then shuffle the - // the results into the target lanes. + // results into the target lanes. if (SDValue V = lowerShuffleAsRepeatedMaskAndLanePermute( DL, MVT::v64i8, V1, V2, Mask, Subtarget, DAG)) return V; @@ -14138,7 +14787,7 @@ static SDValue lowerV64I8VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, return splitAndLowerVectorShuffle(DL, MVT::v64i8, V1, V2, Mask, DAG); } -/// \brief High-level routine to lower various 512-bit x86 vector shuffles. +/// High-level routine to lower various 512-bit x86 vector shuffles. /// /// This routine either breaks down the specific type of a 512-bit x86 vector /// shuffle or splits it into two 256-bit shuffles and fuses the results back @@ -14200,8 +14849,36 @@ static SDValue lower512BitVectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, // vector, shuffle and then truncate it back. static SDValue lower1BitVectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, MVT VT, SDValue V1, SDValue V2, + const APInt &Zeroable, const X86Subtarget &Subtarget, SelectionDAG &DAG) { + unsigned NumElts = Mask.size(); + + // Try to recognize shuffles that are just padding a subvector with zeros. + unsigned SubvecElts = 0; + for (int i = 0; i != (int)NumElts; ++i) { + if (Mask[i] >= 0 && Mask[i] != i) + break; + + ++SubvecElts; + } + assert(SubvecElts != NumElts && "Identity shuffle?"); + + // Clip to a power 2. + SubvecElts = PowerOf2Floor(SubvecElts); + + // Make sure the number of zeroable bits in the top at least covers the bits + // not covered by the subvector. 
+ if (Zeroable.countLeadingOnes() >= (NumElts - SubvecElts)) { + MVT ExtractVT = MVT::getVectorVT(MVT::i1, SubvecElts); + SDValue Extract = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ExtractVT, + V1, DAG.getIntPtrConstant(0, DL)); + return DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT, + getZeroVector(VT, Subtarget, DAG, DL), + Extract, DAG.getIntPtrConstant(0, DL)); + } + + assert(Subtarget.hasAVX512() && "Cannot lower 512-bit vectors w/o basic ISA!"); MVT ExtVT; @@ -14220,38 +14897,31 @@ static SDValue lower1BitVectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, ExtVT = Subtarget.hasVLX() ? MVT::v8i32 : MVT::v8i64; break; case MVT::v16i1: - ExtVT = MVT::v16i32; + // Take 512-bit type, unless we are avoiding 512-bit types and have the + // 256-bit operation available. + ExtVT = Subtarget.canExtendTo512DQ() ? MVT::v16i32 : MVT::v16i16; break; case MVT::v32i1: - ExtVT = MVT::v32i16; + // Take 512-bit type, unless we are avoiding 512-bit types and have the + // 256-bit operation available. + assert(Subtarget.hasBWI() && "Expected AVX512BW support"); + ExtVT = Subtarget.canExtendTo512BW() ? MVT::v32i16 : MVT::v32i8; break; case MVT::v64i1: ExtVT = MVT::v64i8; break; } - if (ISD::isBuildVectorAllZeros(V1.getNode())) - V1 = getZeroVector(ExtVT, Subtarget, DAG, DL); - else if (ISD::isBuildVectorAllOnes(V1.getNode())) - V1 = getOnesVector(ExtVT, DAG, DL); - else - V1 = DAG.getNode(ISD::SIGN_EXTEND, DL, ExtVT, V1); - - if (V2.isUndef()) - V2 = DAG.getUNDEF(ExtVT); - else if (ISD::isBuildVectorAllZeros(V2.getNode())) - V2 = getZeroVector(ExtVT, Subtarget, DAG, DL); - else if (ISD::isBuildVectorAllOnes(V2.getNode())) - V2 = getOnesVector(ExtVT, DAG, DL); - else - V2 = DAG.getNode(ISD::SIGN_EXTEND, DL, ExtVT, V2); + V1 = DAG.getNode(ISD::SIGN_EXTEND, DL, ExtVT, V1); + V2 = DAG.getNode(ISD::SIGN_EXTEND, DL, ExtVT, V2); SDValue Shuffle = DAG.getVectorShuffle(ExtVT, DL, V1, V2, Mask); // i1 was sign extended we can use X86ISD::CVT2MASK. int NumElems = VT.getVectorNumElements(); if ((Subtarget.hasBWI() && (NumElems >= 32)) || (Subtarget.hasDQI() && (NumElems < 32))) - return DAG.getNode(X86ISD::CVT2MASK, DL, VT, Shuffle); + return DAG.getSetCC(DL, VT, DAG.getConstant(0, DL, ExtVT), + Shuffle, ISD::SETGT); return DAG.getNode(ISD::TRUNCATE, DL, VT, Shuffle); } @@ -14320,7 +14990,7 @@ static bool canonicalizeShuffleMaskWithCommute(ArrayRef<int> Mask) { return false; } -/// \brief Top-level lowering for x86 vector shuffles. +/// Top-level lowering for x86 vector shuffles. /// /// This handles decomposition, canonicalization, and lowering of all x86 /// vector shuffles. Most of the specific lowering strategies are encapsulated @@ -14378,20 +15048,49 @@ static SDValue lowerVectorShuffle(SDValue Op, const X86Subtarget &Subtarget, if (Zeroable.isAllOnesValue()) return getZeroVector(VT, Subtarget, DAG, DL); + bool V2IsZero = !V2IsUndef && ISD::isBuildVectorAllZeros(V2.getNode()); + + // Create an alternative mask with info about zeroable elements. + // Here we do not set undef elements as zeroable. + SmallVector<int, 64> ZeroableMask(Mask.begin(), Mask.end()); + if (V2IsZero) { + assert(!Zeroable.isNullValue() && "V2's non-undef elements are used?!"); + for (int i = 0; i != NumElements; ++i) + if (Mask[i] != SM_SentinelUndef && Zeroable[i]) + ZeroableMask[i] = SM_SentinelZero; + } + // Try to collapse shuffles into using a vector type with fewer elements but // wider element types. 
We cap this to not form integers or floating point // elements wider than 64 bits, but it might be interesting to form i128 // integers to handle flipping the low and high halves of AVX 256-bit vectors. SmallVector<int, 16> WidenedMask; if (VT.getScalarSizeInBits() < 64 && !Is1BitVector && - canWidenShuffleElements(Mask, WidenedMask)) { + canWidenShuffleElements(ZeroableMask, WidenedMask)) { MVT NewEltVT = VT.isFloatingPoint() ? MVT::getFloatingPointVT(VT.getScalarSizeInBits() * 2) : MVT::getIntegerVT(VT.getScalarSizeInBits() * 2); - MVT NewVT = MVT::getVectorVT(NewEltVT, VT.getVectorNumElements() / 2); + int NewNumElts = NumElements / 2; + MVT NewVT = MVT::getVectorVT(NewEltVT, NewNumElts); // Make sure that the new vector type is legal. For example, v2f64 isn't // legal on SSE1. if (DAG.getTargetLoweringInfo().isTypeLegal(NewVT)) { + if (V2IsZero) { + // Modify the new Mask to take all zeros from the all-zero vector. + // Choose indices that are blend-friendly. + bool UsedZeroVector = false; + assert(find(WidenedMask, SM_SentinelZero) != WidenedMask.end() && + "V2's non-undef elements are used?!"); + for (int i = 0; i != NewNumElts; ++i) + if (WidenedMask[i] == SM_SentinelZero) { + WidenedMask[i] = i + NewNumElts; + UsedZeroVector = true; + } + // Ensure all elements of V2 are zero - isBuildVectorAllZeros permits + // some elements to be undef. + if (UsedZeroVector) + V2 = getZeroVector(NewVT, Subtarget, DAG, DL); + } V1 = DAG.getBitcast(NewVT, V1); V2 = DAG.getBitcast(NewVT, V2); return DAG.getBitcast( @@ -14403,6 +15102,10 @@ static SDValue lowerVectorShuffle(SDValue Op, const X86Subtarget &Subtarget, if (canonicalizeShuffleMaskWithCommute(Mask)) return DAG.getCommutedVectorShuffle(*SVOp); + if (SDValue V = + lowerVectorShuffleWithVPMOV(DL, Mask, VT, V1, V2, DAG, Subtarget)) + return V; + // For each vector width, delegate to a specialized lowering routine. if (VT.is128BitVector()) return lower128BitVectorShuffle(DL, Mask, VT, V1, V2, Zeroable, Subtarget, @@ -14417,12 +15120,13 @@ static SDValue lowerVectorShuffle(SDValue Op, const X86Subtarget &Subtarget, DAG); if (Is1BitVector) - return lower1BitVectorShuffle(DL, Mask, VT, V1, V2, Subtarget, DAG); + return lower1BitVectorShuffle(DL, Mask, VT, V1, V2, Zeroable, Subtarget, + DAG); llvm_unreachable("Unimplemented!"); } -/// \brief Try to lower a VSELECT instruction to a vector shuffle. +/// Try to lower a VSELECT instruction to a vector shuffle. static SDValue lowerVSELECTtoVectorShuffle(SDValue Op, const X86Subtarget &Subtarget, SelectionDAG &DAG) { @@ -14441,9 +15145,12 @@ static SDValue lowerVSELECTtoVectorShuffle(SDValue Op, SmallVector<int, 32> Mask; for (int i = 0, Size = VT.getVectorNumElements(); i < Size; ++i) { SDValue CondElt = CondBV->getOperand(i); - Mask.push_back( - isa<ConstantSDNode>(CondElt) ? i + (isNullConstant(CondElt) ? Size : 0) - : -1); + int M = i; + // We can't map undef to undef here. They have different meanings. Treat + // as the same as zero. + if (CondElt.isUndef() || isNullConstant(CondElt)) + M += Size; + Mask.push_back(M); } return DAG.getVectorShuffle(VT, dl, LHS, RHS, Mask); } @@ -14483,9 +15190,11 @@ SDValue X86TargetLowering::LowerVSELECT(SDValue Op, SelectionDAG &DAG) const { assert(Cond.getValueType().getScalarSizeInBits() == VT.getScalarSizeInBits() && "Should have a size-matched integer condition!"); - // Build a mask by testing the condition against itself (tests for zero). + // Build a mask by testing the condition against zero. 
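The widened-mask rewrite above points zeroable lanes at the all-zero operand so that later lowering can treat them as an ordinary blend. Here is a standalone sketch of just that loop; it is not patch code, and SM_SentinelZero is assumed here to be -2, matching the sentinel convention used by the X86 shuffle code.

#include <vector>

// Illustration: rewrite zero sentinels in a widened mask into blend-friendly
// indices into a second, all-zero operand occupying lanes [N, 2N).
static bool modelRedirectZeroLanes(std::vector<int> &WidenedMask) {
  const int SM_SentinelZero = -2; // assumed sentinel value
  int NewNumElts = (int)WidenedMask.size();
  bool UsedZeroVector = false;
  for (int i = 0; i != NewNumElts; ++i)
    if (WidenedMask[i] == SM_SentinelZero) {
      WidenedMask[i] = i + NewNumElts; // take lane i of the zero vector
      UsedZeroVector = true;
    }
  return UsedZeroVector; // caller materializes the zero vector only if needed
}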
MVT MaskVT = MVT::getVectorVT(MVT::i1, VT.getVectorNumElements()); - SDValue Mask = DAG.getNode(X86ISD::TESTM, dl, MaskVT, Cond, Cond); + SDValue Mask = DAG.getSetCC(dl, MaskVT, Cond, + getZeroVector(VT, Subtarget, DAG, dl), + ISD::SETNE); // Now return a new VSELECT using the mask. return DAG.getSelect(dl, VT, Mask, Op.getOperand(1), Op.getOperand(2)); } @@ -14506,10 +15215,15 @@ SDValue X86TargetLowering::LowerVSELECT(SDValue Op, SelectionDAG &DAG) const { return SDValue(); case MVT::v8i16: - case MVT::v16i16: - // FIXME: We should custom lower this by fixing the condition and using i8 - // blends. - return SDValue(); + case MVT::v16i16: { + // Bitcast everything to the vXi8 type and use a vXi8 vselect. + MVT CastVT = MVT::getVectorVT(MVT::i8, VT.getVectorNumElements() * 2); + SDValue Cond = DAG.getBitcast(CastVT, Op->getOperand(0)); + SDValue LHS = DAG.getBitcast(CastVT, Op->getOperand(1)); + SDValue RHS = DAG.getBitcast(CastVT, Op->getOperand(2)); + SDValue Select = DAG.getNode(ISD::VSELECT, dl, CastVT, Cond, LHS, RHS); + return DAG.getBitcast(VT, Select); + } } } @@ -14581,36 +15295,35 @@ static SDValue ExtractBitFromMaskVector(SDValue Op, SelectionDAG &DAG, return DAG.getNode(ISD::TRUNCATE, dl, EltVT, Elt); } - // Canonicalize result type to MVT::i32. - if (EltVT != MVT::i32) { - SDValue Extract = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, MVT::i32, - Vec, Idx); - return DAG.getAnyExtOrTrunc(Extract, dl, EltVT); - } - unsigned IdxVal = cast<ConstantSDNode>(Idx)->getZExtValue(); - // Extracts from element 0 are always allowed. - if (IdxVal == 0) - return Op; - // If the kshift instructions of the correct width aren't natively supported // then we need to promote the vector to the native size to get the correct // zeroing behavior. - if ((!Subtarget.hasDQI() && (VecVT.getVectorNumElements() == 8)) || - (VecVT.getVectorNumElements() < 8)) { + if (VecVT.getVectorNumElements() < 16) { VecVT = MVT::v16i1; - Vec = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, VecVT, - DAG.getUNDEF(VecVT), - Vec, + Vec = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, MVT::v16i1, + DAG.getUNDEF(VecVT), Vec, DAG.getIntPtrConstant(0, dl)); } - // Use kshiftr instruction to move to the lower element. - Vec = DAG.getNode(X86ISD::KSHIFTR, dl, VecVT, Vec, - DAG.getConstant(IdxVal, dl, MVT::i8)); - return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, MVT::i32, Vec, - DAG.getIntPtrConstant(0, dl)); + // Extracts from element 0 are always allowed. + if (IdxVal != 0) { + // Use kshiftr instruction to move to the lower element. + Vec = DAG.getNode(X86ISD::KSHIFTR, dl, VecVT, Vec, + DAG.getConstant(IdxVal, dl, MVT::i8)); + } + + // Shrink to v16i1 since that's always legal. + if (VecVT.getVectorNumElements() > 16) { + VecVT = MVT::v16i1; + Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, VecVT, Vec, + DAG.getIntPtrConstant(0, dl)); + } + + // Convert to a bitcast+aext/trunc. 
+ MVT CastVT = MVT::getIntegerVT(VecVT.getVectorNumElements()); + return DAG.getAnyExtOrTrunc(DAG.getBitcast(CastVT, Vec), dl, EltVT); } SDValue @@ -14713,7 +15426,7 @@ X86TargetLowering::LowerEXTRACT_VECTOR_ELT(SDValue Op, int ShiftVal = (IdxVal % 4) * 8; if (ShiftVal != 0) Res = DAG.getNode(ISD::SRL, dl, MVT::i32, Res, - DAG.getConstant(ShiftVal, dl, MVT::i32)); + DAG.getConstant(ShiftVal, dl, MVT::i8)); return DAG.getNode(ISD::TRUNCATE, dl, VT, Res); } @@ -14724,7 +15437,7 @@ X86TargetLowering::LowerEXTRACT_VECTOR_ELT(SDValue Op, int ShiftVal = (IdxVal % 2) * 8; if (ShiftVal != 0) Res = DAG.getNode(ISD::SRL, dl, MVT::i16, Res, - DAG.getConstant(ShiftVal, dl, MVT::i16)); + DAG.getConstant(ShiftVal, dl, MVT::i8)); return DAG.getNode(ISD::TRUNCATE, dl, VT, Res); } @@ -14780,74 +15493,11 @@ static SDValue InsertBitToMaskVector(SDValue Op, SelectionDAG &DAG, return DAG.getNode(ISD::TRUNCATE, dl, VecVT, ExtOp); } - unsigned IdxVal = cast<ConstantSDNode>(Idx)->getZExtValue(); - unsigned NumElems = VecVT.getVectorNumElements(); - - // If the kshift instructions of the correct width aren't natively supported - // then we need to promote the vector to the native size to get the correct - // zeroing behavior. - if ((!Subtarget.hasDQI() && NumElems == 8) || (NumElems < 8)) { - // Need to promote to v16i1, do the insert, then extract back. - Vec = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, MVT::v16i1, - DAG.getUNDEF(MVT::v16i1), Vec, - DAG.getIntPtrConstant(0, dl)); - Op = DAG.getNode(ISD::INSERT_VECTOR_ELT, dl, MVT::v16i1, Vec, Elt, Idx); - return DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, VecVT, Op, - DAG.getIntPtrConstant(0, dl)); - } - - SDValue EltInVec = DAG.getNode(ISD::SCALAR_TO_VECTOR, dl, VecVT, Elt); + // Copy into a k-register, extract to v1i1 and insert_subvector. + SDValue EltInVec = DAG.getNode(ISD::SCALAR_TO_VECTOR, dl, MVT::v1i1, Elt); - if (Vec.isUndef()) { - if (IdxVal) - EltInVec = DAG.getNode(X86ISD::KSHIFTL, dl, VecVT, EltInVec, - DAG.getConstant(IdxVal, dl, MVT::i8)); - return EltInVec; - } - - // Insertion of one bit into first position - if (IdxVal == 0 ) { - // Clean top bits of vector. - EltInVec = DAG.getNode(X86ISD::KSHIFTL, dl, VecVT, EltInVec, - DAG.getConstant(NumElems - 1, dl, MVT::i8)); - EltInVec = DAG.getNode(X86ISD::KSHIFTR, dl, VecVT, EltInVec, - DAG.getConstant(NumElems - 1, dl, MVT::i8)); - // Clean the first bit in source vector. - Vec = DAG.getNode(X86ISD::KSHIFTR, dl, VecVT, Vec, - DAG.getConstant(1 , dl, MVT::i8)); - Vec = DAG.getNode(X86ISD::KSHIFTL, dl, VecVT, Vec, - DAG.getConstant(1, dl, MVT::i8)); - - return DAG.getNode(ISD::OR, dl, VecVT, Vec, EltInVec); - } - // Insertion of one bit into last position - if (IdxVal == NumElems - 1) { - // Move the bit to the last position inside the vector. - EltInVec = DAG.getNode(X86ISD::KSHIFTL, dl, VecVT, EltInVec, - DAG.getConstant(IdxVal, dl, MVT::i8)); - // Clean the last bit in the source vector. - Vec = DAG.getNode(X86ISD::KSHIFTL, dl, VecVT, Vec, - DAG.getConstant(1, dl, MVT::i8)); - Vec = DAG.getNode(X86ISD::KSHIFTR, dl, VecVT, Vec, - DAG.getConstant(1 , dl, MVT::i8)); - - return DAG.getNode(ISD::OR, dl, VecVT, Vec, EltInVec); - } - - // Move the current value of the bit to be replace to bit 0. - SDValue Merged = DAG.getNode(X86ISD::KSHIFTR, dl, VecVT, Vec, - DAG.getConstant(IdxVal, dl, MVT::i8)); - // Xor with the new bit. - Merged = DAG.getNode(ISD::XOR, dl, VecVT, Merged, EltInVec); - // Shift to MSB, filling bottom bits with 0. 
- Merged = DAG.getNode(X86ISD::KSHIFTL, dl, VecVT, Merged, - DAG.getConstant(NumElems - 1, dl, MVT::i8)); - // Shift to the final position, filling upper bits with 0. - Merged = DAG.getNode(X86ISD::KSHIFTR, dl, VecVT, Merged, - DAG.getConstant(NumElems - 1 - IdxVal, dl, MVT::i8)); - // Xor with original vector to cancel out the original bit value that's still - // present. - return DAG.getNode(ISD::XOR, dl, VecVT, Merged, Vec); + return DAG.getNode(ISD::INSERT_SUBVECTOR, dl, VecVT, Vec, EltInVec, + Op.getOperand(2)); } SDValue X86TargetLowering::LowerINSERT_VECTOR_ELT(SDValue Op, @@ -15020,8 +15670,45 @@ static SDValue LowerINSERT_SUBVECTOR(SDValue Op, const X86Subtarget &Subtarget, return insert1BitVector(Op, DAG, Subtarget); } +static SDValue LowerEXTRACT_SUBVECTOR(SDValue Op, const X86Subtarget &Subtarget, + SelectionDAG &DAG) { + assert(Op.getSimpleValueType().getVectorElementType() == MVT::i1 && + "Only vXi1 extract_subvectors need custom lowering"); + + SDLoc dl(Op); + SDValue Vec = Op.getOperand(0); + SDValue Idx = Op.getOperand(1); + + if (!isa<ConstantSDNode>(Idx)) + return SDValue(); + + unsigned IdxVal = cast<ConstantSDNode>(Idx)->getZExtValue(); + if (IdxVal == 0) // the operation is legal + return Op; + + MVT VecVT = Vec.getSimpleValueType(); + unsigned NumElems = VecVT.getVectorNumElements(); + + // Extend to natively supported kshift. + MVT WideVecVT = VecVT; + if ((!Subtarget.hasDQI() && NumElems == 8) || NumElems < 8) { + WideVecVT = Subtarget.hasDQI() ? MVT::v8i1 : MVT::v16i1; + Vec = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, WideVecVT, + DAG.getUNDEF(WideVecVT), Vec, + DAG.getIntPtrConstant(0, dl)); + } + + // Shift to the LSB. + Vec = DAG.getNode(X86ISD::KSHIFTR, dl, WideVecVT, Vec, + DAG.getConstant(IdxVal, dl, MVT::i8)); + + return DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, Op.getValueType(), Vec, + DAG.getIntPtrConstant(0, dl)); +} + // Returns the appropriate wrapper opcode for a global reference. -unsigned X86TargetLowering::getGlobalWrapperKind(const GlobalValue *GV) const { +unsigned X86TargetLowering::getGlobalWrapperKind( + const GlobalValue *GV, const unsigned char OpFlags) const { // References to absolute symbols are never PC-relative. if (GV && GV->isAbsoluteSymbolRef()) return X86ISD::Wrapper; @@ -15031,6 +15718,10 @@ unsigned X86TargetLowering::getGlobalWrapperKind(const GlobalValue *GV) const { (M == CodeModel::Small || M == CodeModel::Kernel)) return X86ISD::WrapperRIP; + // GOTPCREL references must always use RIP. + if (OpFlags == X86II::MO_GOTPCREL) + return X86ISD::WrapperRIP; + return X86ISD::Wrapper; } @@ -15154,7 +15845,7 @@ SDValue X86TargetLowering::LowerGlobalAddress(const GlobalValue *GV, Result = DAG.getTargetGlobalAddress(GV, dl, PtrVT, 0, OpFlags); } - Result = DAG.getNode(getGlobalWrapperKind(GV), dl, PtrVT, Result); + Result = DAG.getNode(getGlobalWrapperKind(GV, OpFlags), dl, PtrVT, Result); // With PIC, the address is actually $g + Offset. 
if (isGlobalRelativeToPICBase(OpFlags)) { @@ -15336,7 +16027,7 @@ X86TargetLowering::LowerGlobalTLSAddress(SDValue Op, SelectionDAG &DAG) const { GlobalAddressSDNode *GA = cast<GlobalAddressSDNode>(Op); - if (DAG.getTarget().Options.EmulatedTLS) + if (DAG.getTarget().useEmulatedTLS()) return LowerToTLSEmulatedModel(GA, DAG); const GlobalValue *GV = GA->getGlobal(); @@ -15456,7 +16147,7 @@ X86TargetLowering::LowerGlobalTLSAddress(SDValue Op, SelectionDAG &DAG) const { auto &DL = DAG.getDataLayout(); SDValue Scale = - DAG.getConstant(Log2_64_Ceil(DL.getPointerSize()), dl, PtrVT); + DAG.getConstant(Log2_64_Ceil(DL.getPointerSize()), dl, MVT::i8); IDX = DAG.getNode(ISD::SHL, dl, PtrVT, IDX, Scale); res = DAG.getNode(ISD::ADD, dl, PtrVT, ThreadPointer, IDX); @@ -15512,24 +16203,47 @@ static SDValue LowerShiftParts(SDValue Op, SelectionDAG &DAG) { // values for large shift amounts. SDValue AndNode = DAG.getNode(ISD::AND, dl, MVT::i8, ShAmt, DAG.getConstant(VTBits, dl, MVT::i8)); - SDValue Cond = DAG.getNode(X86ISD::CMP, dl, MVT::i32, - AndNode, DAG.getConstant(0, dl, MVT::i8)); + SDValue Cond = DAG.getSetCC(dl, MVT::i8, AndNode, + DAG.getConstant(0, dl, MVT::i8), ISD::SETNE); SDValue Hi, Lo; - SDValue CC = DAG.getConstant(X86::COND_NE, dl, MVT::i8); - SDValue Ops0[4] = { Tmp2, Tmp3, CC, Cond }; - SDValue Ops1[4] = { Tmp3, Tmp1, CC, Cond }; - if (Op.getOpcode() == ISD::SHL_PARTS) { - Hi = DAG.getNode(X86ISD::CMOV, dl, VT, Ops0); - Lo = DAG.getNode(X86ISD::CMOV, dl, VT, Ops1); + Hi = DAG.getNode(ISD::SELECT, dl, VT, Cond, Tmp3, Tmp2); + Lo = DAG.getNode(ISD::SELECT, dl, VT, Cond, Tmp1, Tmp3); } else { - Lo = DAG.getNode(X86ISD::CMOV, dl, VT, Ops0); - Hi = DAG.getNode(X86ISD::CMOV, dl, VT, Ops1); + Lo = DAG.getNode(ISD::SELECT, dl, VT, Cond, Tmp3, Tmp2); + Hi = DAG.getNode(ISD::SELECT, dl, VT, Cond, Tmp1, Tmp3); } - SDValue Ops[2] = { Lo, Hi }; - return DAG.getMergeValues(Ops, dl); + return DAG.getMergeValues({ Lo, Hi }, dl); +} + +// Try to use a packed vector operation to handle i64 on 32-bit targets when +// AVX512DQ is enabled. +static SDValue LowerI64IntToFP_AVX512DQ(SDValue Op, SelectionDAG &DAG, + const X86Subtarget &Subtarget) { + assert((Op.getOpcode() == ISD::SINT_TO_FP || + Op.getOpcode() == ISD::UINT_TO_FP) && "Unexpected opcode!"); + SDValue Src = Op.getOperand(0); + MVT SrcVT = Src.getSimpleValueType(); + MVT VT = Op.getSimpleValueType(); + + if (!Subtarget.hasDQI() || SrcVT != MVT::i64 || Subtarget.is64Bit() || + (VT != MVT::f32 && VT != MVT::f64)) + return SDValue(); + + // Pack the i64 into a vector, do the operation and extract. + + // Using 256-bit to ensure result is 128-bits for f32 case. + unsigned NumElts = Subtarget.hasVLX() ? 4 : 8; + MVT VecInVT = MVT::getVectorVT(MVT::i64, NumElts); + MVT VecVT = MVT::getVectorVT(VT, NumElts); + + SDLoc dl(Op); + SDValue InVec = DAG.getNode(ISD::SCALAR_TO_VECTOR, dl, VecInVT, Src); + SDValue CvtVec = DAG.getNode(Op.getOpcode(), dl, VecVT, InVec); + return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, VT, CvtVec, + DAG.getIntPtrConstant(0, dl)); } SDValue X86TargetLowering::LowerSINT_TO_FP(SDValue Op, @@ -15545,20 +16259,6 @@ SDValue X86TargetLowering::LowerSINT_TO_FP(SDValue Op, DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v4i32, Src, DAG.getUNDEF(SrcVT))); } - if (SrcVT.getVectorElementType() == MVT::i1) { - if (SrcVT == MVT::v2i1) { - // For v2i1, we need to widen to v4i1 first. 
- assert(VT == MVT::v2f64 && "Unexpected type"); - Src = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v4i1, Src, - DAG.getUNDEF(MVT::v2i1)); - return DAG.getNode(X86ISD::CVTSI2P, dl, Op.getValueType(), - DAG.getNode(ISD::SIGN_EXTEND, dl, MVT::v4i32, Src)); - } - - MVT IntegerVT = MVT::getVectorVT(MVT::i32, SrcVT.getVectorNumElements()); - return DAG.getNode(ISD::SINT_TO_FP, dl, Op.getValueType(), - DAG.getNode(ISD::SIGN_EXTEND, dl, IntegerVT, Src)); - } return SDValue(); } @@ -15567,15 +16267,17 @@ SDValue X86TargetLowering::LowerSINT_TO_FP(SDValue Op, // These are really Legal; return the operand so the caller accepts it as // Legal. - if (SrcVT == MVT::i32 && isScalarFPTypeInSSEReg(Op.getValueType())) + if (SrcVT == MVT::i32 && isScalarFPTypeInSSEReg(VT)) return Op; - if (SrcVT == MVT::i64 && isScalarFPTypeInSSEReg(Op.getValueType()) && - Subtarget.is64Bit()) { + if (SrcVT == MVT::i64 && isScalarFPTypeInSSEReg(VT) && Subtarget.is64Bit()) { return Op; } + if (SDValue V = LowerI64IntToFP_AVX512DQ(Op, DAG, Subtarget)) + return V; + SDValue ValueToStore = Op.getOperand(0); - if (SrcVT == MVT::i64 && isScalarFPTypeInSSEReg(Op.getValueType()) && + if (SrcVT == MVT::i64 && isScalarFPTypeInSSEReg(VT) && !Subtarget.is64Bit()) // Bitcasting to f64 here allows us to do a single 64-bit store from // an SSE register, avoiding the store forwarding penalty that would come @@ -15760,7 +16462,8 @@ static SDValue LowerUINT_TO_FP_i32(SDValue Op, SelectionDAG &DAG, } static SDValue lowerUINT_TO_FP_v2i32(SDValue Op, SelectionDAG &DAG, - const X86Subtarget &Subtarget, SDLoc &DL) { + const X86Subtarget &Subtarget, + const SDLoc &DL) { if (Op.getSimpleValueType() != MVT::v2f64) return SDValue(); @@ -15894,21 +16597,6 @@ static SDValue lowerUINT_TO_FP_vec(SDValue Op, SelectionDAG &DAG, MVT SrcVT = N0.getSimpleValueType(); SDLoc dl(Op); - if (SrcVT.getVectorElementType() == MVT::i1) { - if (SrcVT == MVT::v2i1) { - // For v2i1, we need to widen to v4i1 first. 
- assert(Op.getValueType() == MVT::v2f64 && "Unexpected type"); - N0 = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v4i1, N0, - DAG.getUNDEF(MVT::v2i1)); - return DAG.getNode(X86ISD::CVTUI2P, dl, MVT::v2f64, - DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::v4i32, N0)); - } - - MVT IntegerVT = MVT::getVectorVT(MVT::i32, SrcVT.getVectorNumElements()); - return DAG.getNode(ISD::UINT_TO_FP, dl, Op.getValueType(), - DAG.getNode(ISD::ZERO_EXTEND, dl, IntegerVT, N0)); - } - switch (SrcVT.SimpleTy) { default: llvm_unreachable("Custom UINT_TO_FP is not supported!"); @@ -15940,6 +16628,9 @@ SDValue X86TargetLowering::LowerUINT_TO_FP(SDValue Op, return Op; } + if (SDValue V = LowerI64IntToFP_AVX512DQ(Op, DAG, Subtarget)) + return V; + if (SrcVT == MVT::i64 && DstVT == MVT::f64 && X86ScalarSSEf64) return LowerUINT_TO_FP_i64(Op, DAG, Subtarget); if (SrcVT == MVT::i32 && X86ScalarSSEf64) @@ -16205,15 +16896,17 @@ static SDValue LowerAVXExtend(SDValue Op, SelectionDAG &DAG, MVT InVT = In.getSimpleValueType(); SDLoc dl(Op); - if ((VT != MVT::v4i64 || InVT != MVT::v4i32) && - (VT != MVT::v8i32 || InVT != MVT::v8i16) && - (VT != MVT::v16i16 || InVT != MVT::v16i8) && - (VT != MVT::v8i64 || InVT != MVT::v8i32) && - (VT != MVT::v8i64 || InVT != MVT::v8i16) && - (VT != MVT::v16i32 || InVT != MVT::v16i16) && - (VT != MVT::v16i32 || InVT != MVT::v16i8) && - (VT != MVT::v32i16 || InVT != MVT::v32i8)) - return SDValue(); + assert(VT.isVector() && InVT.isVector() && "Expected vector type"); + assert(VT.getVectorNumElements() == VT.getVectorNumElements() && + "Expected same number of elements"); + assert((VT.getVectorElementType() == MVT::i16 || + VT.getVectorElementType() == MVT::i32 || + VT.getVectorElementType() == MVT::i64) && + "Unexpected element type"); + assert((InVT.getVectorElementType() == MVT::i8 || + InVT.getVectorElementType() == MVT::i16 || + InVT.getVectorElementType() == MVT::i32) && + "Unexpected element type"); if (Subtarget.hasInt256()) return DAG.getNode(X86ISD::VZEXT, dl, VT, In); @@ -16246,6 +16939,20 @@ static SDValue LowerAVXExtend(SDValue Op, SelectionDAG &DAG, return DAG.getNode(ISD::CONCAT_VECTORS, dl, VT, OpLo, OpHi); } +// Helper to split and extend a v16i1 mask to v16i8 or v16i16. +static SDValue SplitAndExtendv16i1(unsigned ExtOpc, MVT VT, SDValue In, + const SDLoc &dl, SelectionDAG &DAG) { + assert((VT == MVT::v16i8 || VT == MVT::v16i16) && "Unexpected VT."); + SDValue Lo = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, MVT::v8i1, In, + DAG.getIntPtrConstant(0, dl)); + SDValue Hi = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, MVT::v8i1, In, + DAG.getIntPtrConstant(8, dl)); + Lo = DAG.getNode(ExtOpc, dl, MVT::v8i16, Lo); + Hi = DAG.getNode(ExtOpc, dl, MVT::v8i16, Hi); + SDValue Res = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v16i16, Lo, Hi); + return DAG.getNode(ISD::TRUNCATE, dl, VT, Res); +} + static SDValue LowerZERO_EXTEND_Mask(SDValue Op, const X86Subtarget &Subtarget, SelectionDAG &DAG) { @@ -16256,11 +16963,23 @@ static SDValue LowerZERO_EXTEND_Mask(SDValue Op, SDLoc DL(Op); unsigned NumElts = VT.getVectorNumElements(); - // Extend VT if the scalar type is v8/v16 and BWI is not supported. + // For all vectors, but vXi8 we can just emit a sign_extend a shift. This + // avoids a constant pool load. + if (VT.getVectorElementType() != MVT::i8) { + SDValue Extend = DAG.getNode(ISD::SIGN_EXTEND, DL, VT, In); + return DAG.getNode(ISD::SRL, DL, VT, Extend, + DAG.getConstant(VT.getScalarSizeInBits() - 1, DL, VT)); + } + + // Extend VT if BWI is not supported. 
MVT ExtVT = VT; - if (!Subtarget.hasBWI() && - (VT.getVectorElementType().getSizeInBits() <= 16)) + if (!Subtarget.hasBWI()) { + // If v16i32 is to be avoided, we'll need to split and concatenate. + if (NumElts == 16 && !Subtarget.canExtendTo512DQ()) + return SplitAndExtendv16i1(ISD::ZERO_EXTEND, VT, In, DL, DAG); + ExtVT = MVT::getVectorVT(MVT::i32, NumElts); + } // Widen to 512-bits if VLX is not supported. MVT WideVT = ExtVT; @@ -16278,9 +16997,9 @@ static SDValue LowerZERO_EXTEND_Mask(SDValue Op, SDValue SelectedVal = DAG.getSelect(DL, WideVT, In, One, Zero); - // Truncate if we had to extend i16/i8 above. + // Truncate if we had to extend above. if (VT != ExtVT) { - WideVT = MVT::getVectorVT(VT.getVectorElementType(), NumElts); + WideVT = MVT::getVectorVT(MVT::i8, NumElts); SelectedVal = DAG.getNode(ISD::TRUNCATE, DL, WideVT, SelectedVal); } @@ -16300,14 +17019,8 @@ static SDValue LowerZERO_EXTEND(SDValue Op, const X86Subtarget &Subtarget, if (SVT.getVectorElementType() == MVT::i1) return LowerZERO_EXTEND_Mask(Op, Subtarget, DAG); - if (Subtarget.hasFp256()) - if (SDValue Res = LowerAVXExtend(Op, DAG, Subtarget)) - return Res; - - assert(!Op.getSimpleValueType().is256BitVector() || !SVT.is128BitVector() || - Op.getSimpleValueType().getVectorNumElements() != - SVT.getVectorNumElements()); - return SDValue(); + assert(Subtarget.hasAVX() && "Expected AVX support"); + return LowerAVXExtend(Op, DAG, Subtarget); } /// Helper to recursively truncate vector elements in half with PACKSS/PACKUS. @@ -16321,8 +17034,8 @@ static SDValue truncateVectorWithPACK(unsigned Opcode, EVT DstVT, SDValue In, assert((Opcode == X86ISD::PACKSS || Opcode == X86ISD::PACKUS) && "Unexpected PACK opcode"); - // Requires SSE2 but AVX512 has fast truncate. - if (!Subtarget.hasSSE2() || Subtarget.hasAVX512()) + // Requires SSE2 but AVX512 has fast vector truncate. + if (!Subtarget.hasSSE2() || Subtarget.hasAVX512() || !DstVT.isVector()) return SDValue(); EVT SrcVT = In.getValueType(); @@ -16331,40 +17044,53 @@ static SDValue truncateVectorWithPACK(unsigned Opcode, EVT DstVT, SDValue In, if (SrcVT == DstVT) return In; - // We only support vector truncation to 128bits or greater from a - // 256bits or greater source. + // We only support vector truncation to 64bits or greater from a + // 128bits or greater source. unsigned DstSizeInBits = DstVT.getSizeInBits(); unsigned SrcSizeInBits = SrcVT.getSizeInBits(); - if ((DstSizeInBits % 128) != 0 || (SrcSizeInBits % 256) != 0) + if ((DstSizeInBits % 64) != 0 || (SrcSizeInBits % 128) != 0) return SDValue(); - LLVMContext &Ctx = *DAG.getContext(); unsigned NumElems = SrcVT.getVectorNumElements(); + if (!isPowerOf2_32(NumElems)) + return SDValue(); + + LLVMContext &Ctx = *DAG.getContext(); assert(DstVT.getVectorNumElements() == NumElems && "Illegal truncation"); assert(SrcSizeInBits > DstSizeInBits && "Illegal truncation"); EVT PackedSVT = EVT::getIntegerVT(Ctx, SrcVT.getScalarSizeInBits() / 2); - // Extract lower/upper subvectors. - unsigned NumSubElts = NumElems / 2; - SDValue Lo = extractSubVector(In, 0 * NumSubElts, DAG, DL, SrcSizeInBits / 2); - SDValue Hi = extractSubVector(In, 1 * NumSubElts, DAG, DL, SrcSizeInBits / 2); - // Pack to the largest type possible: // vXi64/vXi32 -> PACK*SDW and vXi16 -> PACK*SWB. 
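truncateVectorWithPACK is only sound because the PACK instructions saturate and the callers first prove the value already fits. A scalar model of one 32-to-16 lane of each instruction, shown for illustration and not taken from the patch:

#include <algorithm>
#include <cstdint>

// One 32->16 lane of PACKSSDW: saturate a signed dword to the signed word range.
static int16_t packssLane(int32_t V) {
  return (int16_t)std::min<int32_t>(std::max<int32_t>(V, INT16_MIN), INT16_MAX);
}

// One 32->16 lane of PACKUSDW: saturate a signed dword to the unsigned word range.
static uint16_t packusLane(int32_t V) {
  return (uint16_t)std::min<int32_t>(std::max<int32_t>(V, 0), UINT16_MAX);
}

Because saturation never fires once the dropped bits are known zero (PACKUS) or are sign-bit copies (PACKSS), the packed result equals a plain truncate.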
EVT InVT = MVT::i16, OutVT = MVT::i8; - if (DstVT.getScalarSizeInBits() > 8 && + if (SrcVT.getScalarSizeInBits() > 16 && (Opcode == X86ISD::PACKSS || Subtarget.hasSSE41())) { InVT = MVT::i32; OutVT = MVT::i16; } + // 128bit -> 64bit truncate - PACK 128-bit src in the lower subvector. + if (SrcVT.is128BitVector()) { + InVT = EVT::getVectorVT(Ctx, InVT, 128 / InVT.getSizeInBits()); + OutVT = EVT::getVectorVT(Ctx, OutVT, 128 / OutVT.getSizeInBits()); + In = DAG.getBitcast(InVT, In); + SDValue Res = DAG.getNode(Opcode, DL, OutVT, In, In); + Res = extractSubVector(Res, 0, DAG, DL, 64); + return DAG.getBitcast(DstVT, Res); + } + + // Extract lower/upper subvectors. + unsigned NumSubElts = NumElems / 2; + SDValue Lo = extractSubVector(In, 0 * NumSubElts, DAG, DL, SrcSizeInBits / 2); + SDValue Hi = extractSubVector(In, 1 * NumSubElts, DAG, DL, SrcSizeInBits / 2); + unsigned SubSizeInBits = SrcSizeInBits / 2; InVT = EVT::getVectorVT(Ctx, InVT, SubSizeInBits / InVT.getSizeInBits()); OutVT = EVT::getVectorVT(Ctx, OutVT, SubSizeInBits / OutVT.getSizeInBits()); // 256bit -> 128bit truncate - PACK lower/upper 128-bit subvectors. - if (SrcVT.is256BitVector()) { + if (SrcVT.is256BitVector() && DstVT.is128BitVector()) { Lo = DAG.getBitcast(InVT, Lo); Hi = DAG.getBitcast(InVT, Hi); SDValue Res = DAG.getNode(Opcode, DL, OutVT, Lo, Hi); @@ -16393,7 +17119,7 @@ static SDValue truncateVectorWithPACK(unsigned Opcode, EVT DstVT, SDValue In, } // Recursively pack lower/upper subvectors, concat result and pack again. - assert(SrcSizeInBits >= 512 && "Expected 512-bit vector or greater"); + assert(SrcSizeInBits >= 256 && "Expected 256-bit vector or greater"); EVT PackedVT = EVT::getVectorVT(Ctx, PackedSVT, NumSubElts); Lo = truncateVectorWithPACK(Opcode, PackedVT, Lo, DL, DAG, Subtarget); Hi = truncateVectorWithPACK(Opcode, PackedVT, Hi, DL, DAG, Subtarget); @@ -16418,18 +17144,49 @@ static SDValue LowerTruncateVecI1(SDValue Op, SelectionDAG &DAG, if (InVT.getScalarSizeInBits() <= 16) { if (Subtarget.hasBWI()) { // legal, will go to VPMOVB2M, VPMOVW2M - // Shift packed bytes not supported natively, bitcast to word - MVT ExtVT = MVT::getVectorVT(MVT::i16, InVT.getSizeInBits()/16); - SDValue ShiftNode = DAG.getNode(ISD::SHL, DL, ExtVT, - DAG.getBitcast(ExtVT, In), - DAG.getConstant(ShiftInx, DL, ExtVT)); - ShiftNode = DAG.getBitcast(InVT, ShiftNode); - return DAG.getNode(X86ISD::CVT2MASK, DL, VT, ShiftNode); + if (DAG.ComputeNumSignBits(In) < InVT.getScalarSizeInBits()) { + // We need to shift to get the lsb into sign position. + // Shift packed bytes not supported natively, bitcast to word + MVT ExtVT = MVT::getVectorVT(MVT::i16, InVT.getSizeInBits()/16); + In = DAG.getNode(ISD::SHL, DL, ExtVT, + DAG.getBitcast(ExtVT, In), + DAG.getConstant(ShiftInx, DL, ExtVT)); + In = DAG.getBitcast(InVT, In); + } + return DAG.getSetCC(DL, VT, DAG.getConstant(0, DL, InVT), + In, ISD::SETGT); } // Use TESTD/Q, extended vector to packed dword/qword. assert((InVT.is256BitVector() || InVT.is128BitVector()) && "Unexpected vector type."); unsigned NumElts = InVT.getVectorNumElements(); + assert((NumElts == 8 || NumElts == 16) && "Unexpected number of elements"); + // We need to change to a wider element type that we have support for. + // For 8 element vectors this is easy, we either extend to v8i32 or v8i64. + // For 16 element vectors we extend to v16i32 unless we are explicitly + // trying to avoid 512-bit vectors. 
If we are avoiding 512-bit vectors + // we need to split into two 8 element vectors which we can extend to v8i32, + // truncate and concat the results. There's an additional complication if + // the original type is v16i8. In that case we can't split the v16i8 so + // first we pre-extend it to v16i16 which we can split to v8i16, then extend + // to v8i32, truncate that to v8i1 and concat the two halves. + if (NumElts == 16 && !Subtarget.canExtendTo512DQ()) { + if (InVT == MVT::v16i8) { + // First we need to sign extend up to 256-bits so we can split that. + InVT = MVT::v16i16; + In = DAG.getNode(ISD::SIGN_EXTEND, DL, InVT, In); + } + SDValue Lo = extract128BitVector(In, 0, DAG, DL); + SDValue Hi = extract128BitVector(In, 8, DAG, DL); + // We're split now, just emit two truncates and a concat. The two + // truncates will trigger legalization to come back to this function. + Lo = DAG.getNode(ISD::TRUNCATE, DL, MVT::v8i1, Lo); + Hi = DAG.getNode(ISD::TRUNCATE, DL, MVT::v8i1, Hi); + return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, Lo, Hi); + } + // We either have 8 elements or we're allowed to use 512-bit vectors. + // If we have VLX, we want to use the narrowest vector that can get the + // job done so we use vXi32. MVT EltVT = Subtarget.hasVLX() ? MVT::i32 : MVT::getIntegerVT(512/NumElts); MVT ExtVT = MVT::getVectorVT(EltVT, NumElts); In = DAG.getNode(ISD::SIGN_EXTEND, DL, ExtVT, In); @@ -16437,9 +17194,17 @@ static SDValue LowerTruncateVecI1(SDValue Op, SelectionDAG &DAG, ShiftInx = InVT.getScalarSizeInBits() - 1; } - SDValue ShiftNode = DAG.getNode(ISD::SHL, DL, InVT, In, - DAG.getConstant(ShiftInx, DL, InVT)); - return DAG.getNode(X86ISD::TESTM, DL, VT, ShiftNode, ShiftNode); + if (DAG.ComputeNumSignBits(In) < InVT.getScalarSizeInBits()) { + // We need to shift to get the lsb into sign position. + In = DAG.getNode(ISD::SHL, DL, InVT, In, + DAG.getConstant(ShiftInx, DL, InVT)); + } + // If we have DQI, emit a pattern that will be iseled as vpmovq2m/vpmovd2m. + if (Subtarget.hasDQI()) + return DAG.getSetCC(DL, VT, DAG.getConstant(0, DL, InVT), + In, ISD::SETGT); + return DAG.getSetCC(DL, VT, In, getZeroVector(InVT, Subtarget, DAG, DL), + ISD::SETNE); } SDValue X86TargetLowering::LowerTRUNCATE(SDValue Op, SelectionDAG &DAG) const { @@ -16458,31 +17223,36 @@ SDValue X86TargetLowering::LowerTRUNCATE(SDValue Op, SelectionDAG &DAG) const { // vpmovqb/w/d, vpmovdb/w, vpmovwb if (Subtarget.hasAVX512()) { // word to byte only under BWI - if (InVT == MVT::v16i16 && !Subtarget.hasBWI()) // v16i16 -> v16i8 - return DAG.getNode(X86ISD::VTRUNC, DL, VT, - getExtendInVec(X86ISD::VSEXT, DL, MVT::v16i32, In, DAG)); - return DAG.getNode(X86ISD::VTRUNC, DL, VT, In); + if (InVT == MVT::v16i16 && !Subtarget.hasBWI()) { // v16i16 -> v16i8 + // Make sure we're allowed to promote 512-bits. + if (Subtarget.canExtendTo512DQ()) + return DAG.getNode(ISD::TRUNCATE, DL, VT, + DAG.getNode(X86ISD::VSEXT, DL, MVT::v16i32, In)); + } else { + return Op; + } } - // Truncate with PACKSS if we are truncating a vector with sign-bits that - // extend all the way to the packed/truncated value. - unsigned NumPackedBits = std::min<unsigned>(VT.getScalarSizeInBits(), 16); - if ((InNumEltBits - NumPackedBits) < DAG.ComputeNumSignBits(In)) - if (SDValue V = - truncateVectorWithPACK(X86ISD::PACKSS, VT, In, DL, DAG, Subtarget)) - return V; + unsigned NumPackedSignBits = std::min<unsigned>(VT.getScalarSizeInBits(), 16); + unsigned NumPackedZeroBits = Subtarget.hasSSE41() ? 
NumPackedSignBits : 8; // Truncate with PACKUS if we are truncating a vector with leading zero bits // that extend all the way to the packed/truncated value. // Pre-SSE41 we can only use PACKUSWB. KnownBits Known; DAG.computeKnownBits(In, Known); - NumPackedBits = Subtarget.hasSSE41() ? NumPackedBits : 8; - if ((InNumEltBits - NumPackedBits) <= Known.countMinLeadingZeros()) + if ((InNumEltBits - NumPackedZeroBits) <= Known.countMinLeadingZeros()) if (SDValue V = truncateVectorWithPACK(X86ISD::PACKUS, VT, In, DL, DAG, Subtarget)) return V; + // Truncate with PACKSS if we are truncating a vector with sign-bits that + // extend all the way to the packed/truncated value. + if ((InNumEltBits - NumPackedSignBits) < DAG.ComputeNumSignBits(In)) + if (SDValue V = + truncateVectorWithPACK(X86ISD::PACKSS, VT, In, DL, DAG, Subtarget)) + return V; + if ((VT == MVT::v4i32) && (InVT == MVT::v4i64)) { // On AVX2, v4i64 -> v4i32 becomes VPERMD. if (Subtarget.hasInt256()) { @@ -16549,10 +17319,9 @@ SDValue X86TargetLowering::LowerTRUNCATE(SDValue Op, SelectionDAG &DAG) const { } // Handle truncation of V256 to V128 using shuffles. - if (!VT.is128BitVector() || !InVT.is256BitVector()) - return SDValue(); + assert(VT.is128BitVector() && InVT.is256BitVector() && "Unexpected types!"); - assert(Subtarget.hasFp256() && "256-bit vector without AVX!"); + assert(Subtarget.hasAVX() && "256-bit vector without AVX!"); unsigned NumElems = VT.getVectorNumElements(); MVT NVT = MVT::getVectorVT(VT.getVectorElementType(), NumElems * 2); @@ -16572,9 +17341,29 @@ SDValue X86TargetLowering::LowerFP_TO_INT(SDValue Op, SelectionDAG &DAG) const { MVT VT = Op.getSimpleValueType(); if (VT.isVector()) { - assert(Subtarget.hasDQI() && Subtarget.hasVLX() && "Requires AVX512DQVL!"); SDValue Src = Op.getOperand(0); SDLoc dl(Op); + + if (VT == MVT::v2i1 && Src.getSimpleValueType() == MVT::v2f64) { + MVT ResVT = MVT::v4i32; + MVT TruncVT = MVT::v4i1; + unsigned Opc = IsSigned ? X86ISD::CVTTP2SI : X86ISD::CVTTP2UI; + if (!IsSigned && !Subtarget.hasVLX()) { + // Widen to 512-bits. + ResVT = MVT::v8i32; + TruncVT = MVT::v8i1; + Opc = ISD::FP_TO_UINT; + Src = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, MVT::v8f64, + DAG.getUNDEF(MVT::v8f64), + Src, DAG.getIntPtrConstant(0, dl)); + } + SDValue Res = DAG.getNode(Opc, dl, ResVT, Src); + Res = DAG.getNode(ISD::TRUNCATE, dl, TruncVT, Res); + return DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, MVT::v2i1, Res, + DAG.getIntPtrConstant(0, dl)); + } + + assert(Subtarget.hasDQI() && Subtarget.hasVLX() && "Requires AVX512DQVL!"); if (VT == MVT::v2i64 && Src.getSimpleValueType() == MVT::v2f32) { return DAG.getNode(IsSigned ? X86ISD::CVTTP2SI : X86ISD::CVTTP2UI, dl, VT, DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v4f32, Src, @@ -16771,8 +17560,16 @@ static SDValue LowerFGETSIGN(SDValue Op, SelectionDAG &DAG) { return Res; } +/// Helper for creating a X86ISD::SETCC node. +static SDValue getSETCC(X86::CondCode Cond, SDValue EFLAGS, const SDLoc &dl, + SelectionDAG &DAG) { + return DAG.getNode(X86ISD::SETCC, dl, MVT::i8, + DAG.getConstant(Cond, dl, MVT::i8), EFLAGS); +} + // Check whether an OR'd tree is PTEST-able. 
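The reordered checks above decide which PACK opcode may be used for the truncate. Restating the two inequalities as free-standing predicates (an illustrative paraphrase with hypothetical names, not code from the patch):

// PACKUS needs every bit above the packed width to be known zero; PACKSS
// needs those bits to be sign-bit copies, i.e. strictly more sign bits than
// the number of bits being dropped.
static bool canTruncateWithPACKUS(unsigned SrcEltBits,
                                  unsigned NumPackedZeroBits,
                                  unsigned KnownMinLeadingZeros) {
  return (SrcEltBits - NumPackedZeroBits) <= KnownMinLeadingZeros;
}
static bool canTruncateWithPACKSS(unsigned SrcEltBits,
                                  unsigned NumPackedSignBits,
                                  unsigned NumSignBits) {
  return (SrcEltBits - NumPackedSignBits) < NumSignBits;
}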
-static SDValue LowerVectorAllZeroTest(SDValue Op, const X86Subtarget &Subtarget, +static SDValue LowerVectorAllZeroTest(SDValue Op, ISD::CondCode CC, + const X86Subtarget &Subtarget, SelectionDAG &DAG) { assert(Op.getOpcode() == ISD::OR && "Only check OR'd tree."); @@ -16859,10 +17656,12 @@ static SDValue LowerVectorAllZeroTest(SDValue Op, const X86Subtarget &Subtarget, VecIns.push_back(DAG.getNode(ISD::OR, DL, TestVT, LHS, RHS)); } - return DAG.getNode(X86ISD::PTEST, DL, MVT::i32, VecIns.back(), VecIns.back()); + SDValue Res = DAG.getNode(X86ISD::PTEST, DL, MVT::i32, + VecIns.back(), VecIns.back()); + return getSETCC(CC == ISD::SETEQ ? X86::COND_E : X86::COND_NE, Res, DL, DAG); } -/// \brief return true if \c Op has a use that doesn't just read flags. +/// return true if \c Op has a use that doesn't just read flags. static bool hasNonFlagsUse(SDValue Op) { for (SDNode::use_iterator UI = Op->use_begin(), UE = Op->use_end(); UI != UE; ++UI) { @@ -16881,33 +17680,10 @@ static bool hasNonFlagsUse(SDValue Op) { return false; } -// Emit KTEST instruction for bit vectors on AVX-512 -static SDValue EmitKTEST(SDValue Op, SelectionDAG &DAG, - const X86Subtarget &Subtarget) { - if (Op.getOpcode() == ISD::BITCAST) { - auto hasKTEST = [&](MVT VT) { - unsigned SizeInBits = VT.getSizeInBits(); - return (Subtarget.hasDQI() && (SizeInBits == 8 || SizeInBits == 16)) || - (Subtarget.hasBWI() && (SizeInBits == 32 || SizeInBits == 64)); - }; - SDValue Op0 = Op.getOperand(0); - MVT Op0VT = Op0.getValueType().getSimpleVT(); - if (Op0VT.isVector() && Op0VT.getVectorElementType() == MVT::i1 && - hasKTEST(Op0VT)) - return DAG.getNode(X86ISD::KTEST, SDLoc(Op), Op0VT, Op0, Op0); - } - return SDValue(); -} - /// Emit nodes that will be selected as "test Op0,Op0", or something /// equivalent. SDValue X86TargetLowering::EmitTest(SDValue Op, unsigned X86CC, const SDLoc &dl, SelectionDAG &DAG) const { - if (Op.getValueType() == MVT::i1) { - SDValue ExtOp = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i8, Op); - return DAG.getNode(X86ISD::CMP, dl, MVT::i32, ExtOp, - DAG.getConstant(0, dl, MVT::i8)); - } // CF and OF aren't always set the way we want. Determine which // of these we need. bool NeedCF = false; @@ -16943,9 +17719,6 @@ SDValue X86TargetLowering::EmitTest(SDValue Op, unsigned X86CC, const SDLoc &dl, // doing a separate TEST. TEST always sets OF and CF to 0, so unless // we prove that the arithmetic won't overflow, we can't use OF or CF. if (Op.getResNo() != 0 || NeedOF || NeedCF) { - // Emit KTEST for bit vectors - if (auto Node = EmitKTEST(Op, DAG, Subtarget)) - return Node; // Emit a CMP with 0, which is the TEST pattern. 
return DAG.getNode(X86ISD::CMP, dl, MVT::i32, Op, DAG.getConstant(0, dl, Op.getValueType())); @@ -17119,14 +17892,7 @@ SDValue X86TargetLowering::EmitTest(SDValue Op, unsigned X86CC, const SDLoc &dl, case ISD::SUB: Opcode = X86ISD::SUB; break; case ISD::XOR: Opcode = X86ISD::XOR; break; case ISD::AND: Opcode = X86ISD::AND; break; - case ISD::OR: { - if (!NeedTruncation && ZeroCheck) { - if (SDValue EFLAGS = LowerVectorAllZeroTest(Op, Subtarget, DAG)) - return EFLAGS; - } - Opcode = X86ISD::OR; - break; - } + case ISD::OR: Opcode = X86ISD::OR; break; } NumOperands = 2; @@ -17168,16 +17934,13 @@ SDValue X86TargetLowering::EmitTest(SDValue Op, unsigned X86CC, const SDLoc &dl, if (TLI.isOperationLegal(WideVal.getOpcode(), WideVT)) { SDValue V0 = DAG.getNode(ISD::TRUNCATE, dl, VT, WideVal.getOperand(0)); SDValue V1 = DAG.getNode(ISD::TRUNCATE, dl, VT, WideVal.getOperand(1)); - Op = DAG.getNode(ConvertedOp, dl, VT, V0, V1); + SDVTList VTs = DAG.getVTList(Op.getValueType(), MVT::i32); + Op = DAG.getNode(ConvertedOp, dl, VTs, V0, V1); } } } if (Opcode == 0) { - // Emit KTEST for bit vectors - if (auto Node = EmitKTEST(Op, DAG, Subtarget)) - return Node; - // Emit a CMP with 0, which is the TEST pattern. return DAG.getNode(X86ISD::CMP, dl, MVT::i32, Op, DAG.getConstant(0, dl, Op.getValueType())); @@ -17186,7 +17949,7 @@ SDValue X86TargetLowering::EmitTest(SDValue Op, unsigned X86CC, const SDLoc &dl, SmallVector<SDValue, 4> Ops(Op->op_begin(), Op->op_begin() + NumOperands); SDValue New = DAG.getNode(Opcode, dl, VTs, Ops); - DAG.ReplaceAllUsesWith(Op, New); + DAG.ReplaceAllUsesOfValueWith(SDValue(Op.getNode(), 0), New); return SDValue(New.getNode(), 1); } @@ -17271,7 +18034,6 @@ SDValue X86TargetLowering::getSqrtEstimate(SDValue Op, EVT VT = Op.getValueType(); // SSE1 has rsqrtss and rsqrtps. AVX adds a 256-bit variant for rsqrtps. - // TODO: Add support for AVX512 (v16f32). // It is likely not profitable to do this for f64 because a double-precision // rsqrt estimate with refinement on x86 prior to FMA requires at least 16 // instructions: convert to single, rsqrtss, convert back to double, refine @@ -17282,12 +18044,15 @@ SDValue X86TargetLowering::getSqrtEstimate(SDValue Op, if ((VT == MVT::f32 && Subtarget.hasSSE1()) || (VT == MVT::v4f32 && Subtarget.hasSSE1() && Reciprocal) || (VT == MVT::v4f32 && Subtarget.hasSSE2() && !Reciprocal) || - (VT == MVT::v8f32 && Subtarget.hasAVX())) { + (VT == MVT::v8f32 && Subtarget.hasAVX()) || + (VT == MVT::v16f32 && Subtarget.useAVX512Regs())) { if (RefinementSteps == ReciprocalEstimate::Unspecified) RefinementSteps = 1; UseOneConstNR = false; - return DAG.getNode(X86ISD::FRSQRT, SDLoc(Op), VT, Op); + // There is no FSQRT for 512-bits, but there is RSQRT14. + unsigned Opcode = VT == MVT::v16f32 ? X86ISD::RSQRT14 : X86ISD::FRSQRT; + return DAG.getNode(Opcode, SDLoc(Op), VT, Op); } return SDValue(); } @@ -17300,7 +18065,6 @@ SDValue X86TargetLowering::getRecipEstimate(SDValue Op, SelectionDAG &DAG, EVT VT = Op.getValueType(); // SSE1 has rcpss and rcpps. AVX adds a 256-bit variant for rcpps. - // TODO: Add support for AVX512 (v16f32). 
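
The v16f32 cases above still rely on the generic estimate-plus-refinement expansion; with RefinementSteps = 1 the DAG wraps the RSQRT/RSQRT14 estimate in a single Newton-Raphson step. Below is a standalone scalar model of that step (illustrative C++ only; the 0.70f starting value merely plays the role of a low-precision hardware estimate).

#include <cmath>
#include <cstdio>

// One Newton-Raphson iteration for 1/sqrt(A): x1 = x0 * (1.5 - 0.5 * A * x0^2).
// Each step roughly doubles the number of correct bits in the estimate.
static float refineRsqrt(float A, float Est) {
  return Est * (1.5f - 0.5f * A * Est * Est);
}

int main() {
  float A = 2.0f;
  float Est = 0.70f;  // crude stand-in for the ~12-bit RSQRT/RSQRT14 estimate
  std::printf("exact %f  estimate %f  refined %f\n",
              1.0f / std::sqrt(A), Est, refineRsqrt(A, Est));
  return 0;
}
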
// It is likely not profitable to do this for f64 because a double-precision // reciprocal estimate with refinement on x86 prior to FMA requires // 15 instructions: convert to single, rcpss, convert back to double, refine @@ -17309,7 +18073,8 @@ SDValue X86TargetLowering::getRecipEstimate(SDValue Op, SelectionDAG &DAG, if ((VT == MVT::f32 && Subtarget.hasSSE1()) || (VT == MVT::v4f32 && Subtarget.hasSSE1()) || - (VT == MVT::v8f32 && Subtarget.hasAVX())) { + (VT == MVT::v8f32 && Subtarget.hasAVX()) || + (VT == MVT::v16f32 && Subtarget.useAVX512Regs())) { // Enable estimate codegen with 1 refinement step for vector division. // Scalar division estimates are disabled because they break too much // real-world code. These defaults are intended to match GCC behavior. @@ -17319,7 +18084,9 @@ SDValue X86TargetLowering::getRecipEstimate(SDValue Op, SelectionDAG &DAG, if (RefinementSteps == ReciprocalEstimate::Unspecified) RefinementSteps = 1; - return DAG.getNode(X86ISD::FRCP, SDLoc(Op), VT, Op); + // There is no FSQRT for 512-bits, but there is RCP14. + unsigned Opcode = VT == MVT::v16f32 ? X86ISD::RCP14 : X86ISD::FRCP; + return DAG.getNode(Opcode, SDLoc(Op), VT, Op); } return SDValue(); } @@ -17334,13 +18101,6 @@ unsigned X86TargetLowering::combineRepeatedFPDivisors() const { return 2; } -/// Helper for creating a X86ISD::SETCC node. -static SDValue getSETCC(X86::CondCode Cond, SDValue EFLAGS, const SDLoc &dl, - SelectionDAG &DAG) { - return DAG.getNode(X86ISD::SETCC, dl, MVT::i8, - DAG.getConstant(Cond, dl, MVT::i8), EFLAGS); -} - /// Create a BT (Bit Test) node - Test bit \p BitNo in \p Src and set condition /// according to equal/not-equal condition code \p CC. static SDValue getBitTestCondition(SDValue Src, SDValue BitNo, ISD::CondCode CC, @@ -17408,12 +18168,15 @@ static SDValue LowerAndToBT(SDValue And, ISD::CondCode CC, if (AndRHSVal == 1 && AndLHS.getOpcode() == ISD::SRL) { LHS = AndLHS.getOperand(0); RHS = AndLHS.getOperand(1); - } - - // Use BT if the immediate can't be encoded in a TEST instruction. - if (!isUInt<32>(AndRHSVal) && isPowerOf2_64(AndRHSVal)) { - LHS = AndLHS; - RHS = DAG.getConstant(Log2_64_Ceil(AndRHSVal), dl, LHS.getValueType()); + } else { + // Use BT if the immediate can't be encoded in a TEST instruction or we + // are optimizing for size and the immedaite won't fit in a byte. 
+ bool OptForSize = DAG.getMachineFunction().getFunction().optForSize(); + if ((!isUInt<32>(AndRHSVal) || (OptForSize && !isUInt<8>(AndRHSVal))) && + isPowerOf2_64(AndRHSVal)) { + LHS = AndLHS; + RHS = DAG.getConstant(Log2_64_Ceil(AndRHSVal), dl, LHS.getValueType()); + } } } @@ -17498,49 +18261,6 @@ static SDValue Lower256IntVSETCC(SDValue Op, SelectionDAG &DAG) { DAG.getNode(Op.getOpcode(), dl, NewVT, LHS2, RHS2, CC)); } -static SDValue LowerBoolVSETCC_AVX512(SDValue Op, SelectionDAG &DAG) { - SDValue Op0 = Op.getOperand(0); - SDValue Op1 = Op.getOperand(1); - SDValue CC = Op.getOperand(2); - MVT VT = Op.getSimpleValueType(); - SDLoc dl(Op); - - assert(Op0.getSimpleValueType().getVectorElementType() == MVT::i1 && - "Unexpected type for boolean compare operation"); - ISD::CondCode SetCCOpcode = cast<CondCodeSDNode>(CC)->get(); - SDValue NotOp0 = DAG.getNode(ISD::XOR, dl, VT, Op0, - DAG.getConstant(-1, dl, VT)); - SDValue NotOp1 = DAG.getNode(ISD::XOR, dl, VT, Op1, - DAG.getConstant(-1, dl, VT)); - switch (SetCCOpcode) { - default: llvm_unreachable("Unexpected SETCC condition"); - case ISD::SETEQ: - // (x == y) -> ~(x ^ y) - return DAG.getNode(ISD::XOR, dl, VT, - DAG.getNode(ISD::XOR, dl, VT, Op0, Op1), - DAG.getConstant(-1, dl, VT)); - case ISD::SETNE: - // (x != y) -> (x ^ y) - return DAG.getNode(ISD::XOR, dl, VT, Op0, Op1); - case ISD::SETUGT: - case ISD::SETGT: - // (x > y) -> (x & ~y) - return DAG.getNode(ISD::AND, dl, VT, Op0, NotOp1); - case ISD::SETULT: - case ISD::SETLT: - // (x < y) -> (~x & y) - return DAG.getNode(ISD::AND, dl, VT, NotOp0, Op1); - case ISD::SETULE: - case ISD::SETLE: - // (x <= y) -> (~x | y) - return DAG.getNode(ISD::OR, dl, VT, NotOp0, Op1); - case ISD::SETUGE: - case ISD::SETGE: - // (x >=y) -> (x | ~y) - return DAG.getNode(ISD::OR, dl, VT, Op0, NotOp1); - } -} - static SDValue LowerIntVSETCC_AVX512(SDValue Op, SelectionDAG &DAG) { SDValue Op0 = Op.getOperand(0); @@ -17553,48 +18273,24 @@ static SDValue LowerIntVSETCC_AVX512(SDValue Op, SelectionDAG &DAG) { "Cannot set masked compare for this operation"); ISD::CondCode SetCCOpcode = cast<CondCodeSDNode>(CC)->get(); - unsigned Opc = 0; - bool Unsigned = false; - bool Swap = false; - unsigned SSECC; - switch (SetCCOpcode) { - default: llvm_unreachable("Unexpected SETCC condition"); - case ISD::SETNE: SSECC = 4; break; - case ISD::SETEQ: Opc = X86ISD::PCMPEQM; break; - case ISD::SETUGT: SSECC = 6; Unsigned = true; break; - case ISD::SETLT: Swap = true; LLVM_FALLTHROUGH; - case ISD::SETGT: Opc = X86ISD::PCMPGTM; break; - case ISD::SETULT: SSECC = 1; Unsigned = true; break; - case ISD::SETUGE: SSECC = 5; Unsigned = true; break; //NLT - case ISD::SETGE: Swap = true; SSECC = 2; break; // LE + swap - case ISD::SETULE: Unsigned = true; LLVM_FALLTHROUGH; - case ISD::SETLE: SSECC = 2; break; - } - if (Swap) + // If this is a seteq make sure any build vectors of all zeros are on the RHS. + // This helps with vptestm matching. + // TODO: Should we just canonicalize the setcc during DAG combine? + if ((SetCCOpcode == ISD::SETEQ || SetCCOpcode == ISD::SETNE) && + ISD::isBuildVectorAllZeros(Op0.getNode())) std::swap(Op0, Op1); - // See if it is the case of CMP(EQ|NEQ,AND(A,B),ZERO) and change it to TESTM|NM. 
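
Returning to the LowerAndToBT change earlier in this hunk: it is a pure encoding-size heuristic. A small standalone C++ sketch of the predicate follows (the helper names are illustrative, not the LLVM ones): a single-bit AND mask is tested with BT reg, imm8 whenever TEST would need a 64-bit immediate it does not have, or, under optsize, anything wider than an 8-bit one.

#include <cstdint>

static bool isUInt32(uint64_t V) { return V <= 0xffffffffULL; }
static bool isUInt8(uint64_t V)  { return V <= 0xffULL; }
static bool isPowerOf2(uint64_t V) { return V != 0 && (V & (V - 1)) == 0; }

// Mirrors the condition guarding the BT lowering above.
static bool shouldUseBT(uint64_t AndMask, bool OptForSize) {
  return isPowerOf2(AndMask) &&
         (!isUInt32(AndMask) || (OptForSize && !isUInt8(AndMask)));
}
// shouldUseBT(1ULL << 40, false) -> true : no 64-bit TEST immediate exists
// shouldUseBT(1ULL << 12, true)  -> true : saves bytes over a 32-bit immediate
// shouldUseBT(1ULL << 3,  false) -> false: TEST with a small immediate is fine
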
- if ((!Opc && SSECC == 4) || Opc == X86ISD::PCMPEQM) { - SDValue A = peekThroughBitcasts(Op0); - if ((A.getOpcode() == ISD::AND || A.getOpcode() == X86ISD::FAND) && - ISD::isBuildVectorAllZeros(Op1.getNode())) { - MVT VT0 = Op0.getSimpleValueType(); - SDValue RHS = DAG.getBitcast(VT0, A.getOperand(0)); - SDValue LHS = DAG.getBitcast(VT0, A.getOperand(1)); - return DAG.getNode(Opc == X86ISD::PCMPEQM ? X86ISD::TESTNM : X86ISD::TESTM, - dl, VT, RHS, LHS); - } + // Prefer SETGT over SETLT. + if (SetCCOpcode == ISD::SETLT) { + SetCCOpcode = ISD::getSetCCSwappedOperands(SetCCOpcode); + std::swap(Op0, Op1); } - if (Opc) - return DAG.getNode(Opc, dl, VT, Op0, Op1); - Opc = Unsigned ? X86ISD::CMPMU: X86ISD::CMPM; - return DAG.getNode(Opc, dl, VT, Op0, Op1, - DAG.getConstant(SSECC, dl, MVT::i8)); + return DAG.getSetCC(dl, VT, Op0, Op1, SetCCOpcode); } -/// \brief Try to turn a VSETULT into a VSETULE by modifying its second +/// Try to turn a VSETULT into a VSETULE by modifying its second /// operand \p Op1. If non-trivial (for example because it's not constant) /// return an empty value. static SDValue ChangeVSETULTtoVSETULE(const SDLoc &dl, SDValue Op1, @@ -17624,6 +18320,51 @@ static SDValue ChangeVSETULTtoVSETULE(const SDLoc &dl, SDValue Op1, return DAG.getBuildVector(VT, dl, ULTOp1); } +/// As another special case, use PSUBUS[BW] when it's profitable. E.g. for +/// Op0 u<= Op1: +/// t = psubus Op0, Op1 +/// pcmpeq t, <0..0> +static SDValue LowerVSETCCWithSUBUS(SDValue Op0, SDValue Op1, MVT VT, + ISD::CondCode Cond, const SDLoc &dl, + const X86Subtarget &Subtarget, + SelectionDAG &DAG) { + if (!Subtarget.hasSSE2()) + return SDValue(); + + MVT VET = VT.getVectorElementType(); + if (VET != MVT::i8 && VET != MVT::i16) + return SDValue(); + + switch (Cond) { + default: + return SDValue(); + case ISD::SETULT: { + // If the comparison is against a constant we can turn this into a + // setule. With psubus, setule does not require a swap. This is + // beneficial because the constant in the register is no longer + // destructed as the destination so it can be hoisted out of a loop. + // Only do this pre-AVX since vpcmp* is no longer destructive. + if (Subtarget.hasAVX()) + return SDValue(); + SDValue ULEOp1 = ChangeVSETULTtoVSETULE(dl, Op1, DAG); + if (!ULEOp1) + return SDValue(); + Op1 = ULEOp1; + break; + } + // Psubus is better than flip-sign because it requires no inversion. + case ISD::SETUGE: + std::swap(Op0, Op1); + break; + case ISD::SETULE: + break; + } + + SDValue Result = DAG.getNode(X86ISD::SUBUS, dl, VT, Op0, Op1); + return DAG.getNode(X86ISD::PCMPEQ, dl, VT, Result, + getZeroVector(VT, Subtarget, DAG, dl)); +} + static SDValue LowerVSETCC(SDValue Op, const X86Subtarget &Subtarget, SelectionDAG &DAG) { SDValue Op0 = Op.getOperand(0); @@ -17697,23 +18438,10 @@ static SDValue LowerVSETCC(SDValue Op, const X86Subtarget &Subtarget, assert(VT.getVectorNumElements() == VTOp0.getVectorNumElements() && "Invalid number of packed elements for source and destination!"); - if (VT.is128BitVector() && VTOp0.is256BitVector()) { - // On non-AVX512 targets, a vector of MVT::i1 is promoted by the type - // legalizer to a wider vector type. In the case of 'vsetcc' nodes, the - // legalizer firstly checks if the first operand in input to the setcc has - // a legal type. If so, then it promotes the return type to that same type. - // Otherwise, the return type is promoted to the 'next legal type' which, - // for a vector of MVT::i1 is always a 128-bit integer vector type. 
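
LowerVSETCCWithSUBUS above relies on a simple saturation identity. Here is a one-lane standalone C++ model (illustrative, not the patch's code): unsigned-saturating a - b is zero exactly when a <= b, so SUBUS followed by PCMPEQ against zero computes u<=, and u>= is the same test with the operands swapped.

#include <cstdint>

// psubusb on a single u8 lane: max(a - b, 0) in unsigned arithmetic.
static uint8_t psubus(uint8_t A, uint8_t B) {
  return A > B ? uint8_t(A - B) : uint8_t(0);
}

// a u<= b  <=>  psubus(a, b) == 0, i.e. X86ISD::SUBUS + X86ISD::PCMPEQ vs zero.
static bool uleViaSubus(uint8_t A, uint8_t B) {
  return psubus(A, B) == 0;
}
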
- // - // We reach this code only if the following two conditions are met: - // 1. Both return type and operand type have been promoted to wider types - // by the type legalizer. - // 2. The original operand type has been promoted to a 256-bit vector. - // - // Note that condition 2. only applies for AVX targets. - SDValue NewOp = DAG.getSetCC(dl, VTOp0, Op0, Op1, Cond); - return DAG.getZExtOrTrunc(NewOp, dl, VT); - } + // This is being called by type legalization because v2i32 is marked custom + // for result type legalization for v2f32. + if (VTOp0 == MVT::v2i32) + return SDValue(); // The non-AVX512 code below works under the assumption that source and // destination types are the same. @@ -17724,31 +18452,17 @@ static SDValue LowerVSETCC(SDValue Op, const X86Subtarget &Subtarget, if (VT.is256BitVector() && !Subtarget.hasInt256()) return Lower256IntVSETCC(Op, DAG); - // Operands are boolean (vectors of i1) - MVT OpVT = Op1.getSimpleValueType(); - if (OpVT.getVectorElementType() == MVT::i1) - return LowerBoolVSETCC_AVX512(Op, DAG); - // The result is boolean, but operands are int/float if (VT.getVectorElementType() == MVT::i1) { // In AVX-512 architecture setcc returns mask with i1 elements, // But there is no compare instruction for i8 and i16 elements in KNL. - // In this case use SSE compare - bool UseAVX512Inst = - (OpVT.is512BitVector() || - OpVT.getScalarSizeInBits() >= 32 || - (Subtarget.hasBWI() && Subtarget.hasVLX())); - - if (UseAVX512Inst) - return LowerIntVSETCC_AVX512(Op, DAG); - - return DAG.getNode(ISD::TRUNCATE, dl, VT, - DAG.getNode(ISD::SETCC, dl, OpVT, Op0, Op1, CC)); + assert((VTOp0.getScalarSizeInBits() >= 32 || Subtarget.hasBWI()) && + "Unexpected operand type"); + return LowerIntVSETCC_AVX512(Op, DAG); } // Lower using XOP integer comparisons. - if ((VT == MVT::v16i8 || VT == MVT::v8i16 || - VT == MVT::v4i32 || VT == MVT::v2i64) && Subtarget.hasXOP()) { + if (VT.is128BitVector() && Subtarget.hasXOP()) { // Translate compare code to XOP PCOM compare mode. unsigned CmpMode = 0; switch (Cond) { @@ -17791,15 +18505,18 @@ static SDValue LowerVSETCC(SDValue Op, const X86Subtarget &Subtarget, } } - // We are handling one of the integer comparisons here. Since SSE only has - // GT and EQ comparisons for integer, swapping operands and multiple - // operations may be required for some comparisons. - unsigned Opc = (Cond == ISD::SETEQ || Cond == ISD::SETNE) ? X86ISD::PCMPEQ - : X86ISD::PCMPGT; - bool Swap = Cond == ISD::SETLT || Cond == ISD::SETULT || - Cond == ISD::SETGE || Cond == ISD::SETUGE; - bool Invert = Cond == ISD::SETNE || - (Cond != ISD::SETEQ && ISD::isTrueWhenEqual(Cond)); + // If this is a SETNE against the signed minimum value, change it to SETGT. + // If this is a SETNE against the signed maximum value, change it to SETLT. + // which will be swapped to SETGT. + // Otherwise we use PCMPEQ+invert. + APInt ConstValue; + if (Cond == ISD::SETNE && + ISD::isConstantSplatVector(Op1.getNode(), ConstValue)) { + if (ConstValue.isMinSignedValue()) + Cond = ISD::SETGT; + else if (ConstValue.isMaxSignedValue()) + Cond = ISD::SETLT; + } // If both operands are known non-negative, then an unsigned compare is the // same as a signed compare and there's no need to flip signbits. 
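
The SETNE canonicalization above leans on two boundary identities for signed lanes: nothing is smaller than the signed minimum and nothing is larger than the signed maximum. A standalone exhaustive check for i8 (plain C++, for illustration only):

#include <cassert>
#include <cstdint>

int main() {
  for (int V = -128; V <= 127; ++V) {
    int8_t X = static_cast<int8_t>(V);
    // x != INT8_MIN  <=>  x > INT8_MIN, so SETNE against a splat of the
    // signed minimum can become SETGT and reuse PCMPGT without an invert.
    assert((X != INT8_MIN) == (X > INT8_MIN));
    // x != INT8_MAX  <=>  x < INT8_MAX, which is then swapped to SETGT.
    assert((X != INT8_MAX) == (X < INT8_MAX));
  }
  return 0;
}
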
@@ -17808,58 +18525,47 @@ static SDValue LowerVSETCC(SDValue Op, const X86Subtarget &Subtarget, bool FlipSigns = ISD::isUnsignedIntSetCC(Cond) && !(DAG.SignBitIsZero(Op0) && DAG.SignBitIsZero(Op1)); - // Special case: Use min/max operations for SETULE/SETUGE - MVT VET = VT.getVectorElementType(); - bool HasMinMax = - (Subtarget.hasAVX512() && VET == MVT::i64) || - (Subtarget.hasSSE41() && (VET == MVT::i16 || VET == MVT::i32)) || - (Subtarget.hasSSE2() && (VET == MVT::i8)); - bool MinMax = false; - if (HasMinMax) { + // Special case: Use min/max operations for unsigned compares. We only want + // to do this for unsigned compares if we need to flip signs or if it allows + // use to avoid an invert. + const TargetLowering &TLI = DAG.getTargetLoweringInfo(); + if (ISD::isUnsignedIntSetCC(Cond) && + (FlipSigns || ISD::isTrueWhenEqual(Cond)) && + TLI.isOperationLegal(ISD::UMIN, VT)) { + bool Invert = false; + unsigned Opc; switch (Cond) { - default: break; - case ISD::SETULE: Opc = ISD::UMIN; MinMax = true; break; - case ISD::SETUGE: Opc = ISD::UMAX; MinMax = true; break; + default: llvm_unreachable("Unexpected condition code"); + case ISD::SETUGT: Invert = true; LLVM_FALLTHROUGH; + case ISD::SETULE: Opc = ISD::UMIN; break; + case ISD::SETULT: Invert = true; LLVM_FALLTHROUGH; + case ISD::SETUGE: Opc = ISD::UMAX; break; } - if (MinMax) - Swap = Invert = FlipSigns = false; - } + SDValue Result = DAG.getNode(Opc, dl, VT, Op0, Op1); + Result = DAG.getNode(X86ISD::PCMPEQ, dl, VT, Op0, Result); - bool HasSubus = Subtarget.hasSSE2() && (VET == MVT::i8 || VET == MVT::i16); - bool Subus = false; - if (!MinMax && HasSubus) { - // As another special case, use PSUBUS[BW] when it's profitable. E.g. for - // Op0 u<= Op1: - // t = psubus Op0, Op1 - // pcmpeq t, <0..0> - switch (Cond) { - default: break; - case ISD::SETULT: { - // If the comparison is against a constant we can turn this into a - // setule. With psubus, setule does not require a swap. This is - // beneficial because the constant in the register is no longer - // destructed as the destination so it can be hoisted out of a loop. - // Only do this pre-AVX since vpcmp* is no longer destructive. - if (Subtarget.hasAVX()) - break; - if (SDValue ULEOp1 = ChangeVSETULTtoVSETULE(dl, Op1, DAG)) { - Op1 = ULEOp1; - Subus = true; Invert = false; Swap = false; - } - break; - } - // Psubus is better than flip-sign because it requires no inversion. - case ISD::SETUGE: Subus = true; Invert = false; Swap = true; break; - case ISD::SETULE: Subus = true; Invert = false; Swap = false; break; - } + // If the logical-not of the result is required, perform that now. + if (Invert) + Result = DAG.getNOT(dl, Result, VT); - if (Subus) { - Opc = X86ISD::SUBUS; - FlipSigns = false; - } + return Result; } + // Try to use SUBUS and PCMPEQ. + if (SDValue V = LowerVSETCCWithSUBUS(Op0, Op1, VT, Cond, dl, Subtarget, DAG)) + return V; + + // We are handling one of the integer comparisons here. Since SSE only has + // GT and EQ comparisons for integer, swapping operands and multiple + // operations may be required for some comparisons. + unsigned Opc = (Cond == ISD::SETEQ || Cond == ISD::SETNE) ? 
X86ISD::PCMPEQ + : X86ISD::PCMPGT; + bool Swap = Cond == ISD::SETLT || Cond == ISD::SETULT || + Cond == ISD::SETGE || Cond == ISD::SETUGE; + bool Invert = Cond == ISD::SETNE || + (Cond != ISD::SETEQ && ISD::isTrueWhenEqual(Cond)); + if (Swap) std::swap(Op0, Op1); @@ -17947,14 +18653,47 @@ static SDValue LowerVSETCC(SDValue Op, const X86Subtarget &Subtarget, if (Invert) Result = DAG.getNOT(dl, Result, VT); - if (MinMax) - Result = DAG.getNode(X86ISD::PCMPEQ, dl, VT, Op0, Result); + return Result; +} - if (Subus) - Result = DAG.getNode(X86ISD::PCMPEQ, dl, VT, Result, - getZeroVector(VT, Subtarget, DAG, dl)); +// Try to select this as a KTEST+SETCC if possible. +static SDValue EmitKTEST(SDValue Op0, SDValue Op1, ISD::CondCode CC, + const SDLoc &dl, SelectionDAG &DAG, + const X86Subtarget &Subtarget) { + // Only support equality comparisons. + if (CC != ISD::SETEQ && CC != ISD::SETNE) + return SDValue(); - return Result; + // Must be a bitcast from vXi1. + if (Op0.getOpcode() != ISD::BITCAST) + return SDValue(); + + Op0 = Op0.getOperand(0); + MVT VT = Op0.getSimpleValueType(); + if (!(Subtarget.hasAVX512() && VT == MVT::v16i1) && + !(Subtarget.hasDQI() && VT == MVT::v8i1) && + !(Subtarget.hasBWI() && (VT == MVT::v32i1 || VT == MVT::v64i1))) + return SDValue(); + + X86::CondCode X86CC; + if (isNullConstant(Op1)) { + X86CC = CC == ISD::SETEQ ? X86::COND_E : X86::COND_NE; + } else if (isAllOnesConstant(Op1)) { + // C flag is set for all ones. + X86CC = CC == ISD::SETEQ ? X86::COND_B : X86::COND_AE; + } else + return SDValue(); + + // If the input is an OR, we can combine it's operands into the KORTEST. + SDValue LHS = Op0; + SDValue RHS = Op0; + if (Op0.getOpcode() == ISD::OR && Op0.hasOneUse()) { + LHS = Op0.getOperand(0); + RHS = Op0.getOperand(1); + } + + SDValue KORTEST = DAG.getNode(X86ISD::KORTEST, dl, MVT::i32, LHS, RHS); + return getSETCC(X86CC, KORTEST, dl, DAG); } SDValue X86TargetLowering::LowerSETCC(SDValue Op, SelectionDAG &DAG) const { @@ -17979,6 +18718,18 @@ SDValue X86TargetLowering::LowerSETCC(SDValue Op, SelectionDAG &DAG) const { return NewSetCC; } + // Try to use PTEST for a tree ORs equality compared with 0. + // TODO: We could do AND tree with all 1s as well by using the C flag. + if (Op0.getOpcode() == ISD::OR && isNullConstant(Op1) && + (CC == ISD::SETEQ || CC == ISD::SETNE)) { + if (SDValue NewSetCC = LowerVectorAllZeroTest(Op0, CC, Subtarget, DAG)) + return NewSetCC; + } + + // Try to lower using KTEST. + if (SDValue NewSetCC = EmitKTEST(Op0, Op1, CC, dl, DAG, Subtarget)) + return NewSetCC; + // Look for X == 0, X == 1, X != 0, or X != 1. We can simplify some forms of // these. if ((isOneConstant(Op1) || isNullConstant(Op1)) && @@ -18070,7 +18821,7 @@ SDValue X86TargetLowering::LowerSELECT(SDValue Op, SelectionDAG &DAG) const { // are available or VBLENDV if AVX is available. // Otherwise FP cmovs get lowered into a less efficient branch sequence later. if (Cond.getOpcode() == ISD::SETCC && - ((Subtarget.hasSSE2() && (VT == MVT::f32 || VT == MVT::f64)) || + ((Subtarget.hasSSE2() && VT == MVT::f64) || (Subtarget.hasSSE1() && VT == MVT::f32)) && VT == Cond.getOperand(0).getSimpleValueType() && Cond->hasOneUse()) { SDValue CondOp0 = Cond.getOperand(0), CondOp1 = Cond.getOperand(1); @@ -18132,6 +18883,18 @@ SDValue X86TargetLowering::LowerSELECT(SDValue Op, SelectionDAG &DAG) const { return DAG.getNode(X86ISD::SELECTS, DL, VT, Cmp, Op1, Op2); } + // For v64i1 without 64-bit support we need to split and rejoin. 
+ if (VT == MVT::v64i1 && !Subtarget.is64Bit()) { + assert(Subtarget.hasBWI() && "Expected BWI to be legal"); + SDValue Op1Lo = extractSubVector(Op1, 0, DAG, DL, 32); + SDValue Op2Lo = extractSubVector(Op2, 0, DAG, DL, 32); + SDValue Op1Hi = extractSubVector(Op1, 32, DAG, DL, 32); + SDValue Op2Hi = extractSubVector(Op2, 32, DAG, DL, 32); + SDValue Lo = DAG.getSelect(DL, MVT::v32i1, Cond, Op1Lo, Op2Lo); + SDValue Hi = DAG.getSelect(DL, MVT::v32i1, Cond, Op1Hi, Op2Hi); + return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, Lo, Hi); + } + if (VT.isVector() && VT.getVectorElementType() == MVT::i1) { SDValue Op1Scalar; if (ISD::isBuildVectorOfConstantSDNodes(Op1.getNode())) @@ -18379,6 +19142,15 @@ SDValue X86TargetLowering::LowerSELECT(SDValue Op, SelectionDAG &DAG) const { } } + // Promote i16 cmovs if it won't prevent folding a load. + if (Op.getValueType() == MVT::i16 && !MayFoldLoad(Op1) && !MayFoldLoad(Op2)) { + Op1 = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i32, Op1); + Op2 = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i32, Op2); + SDValue Ops[] = { Op2, Op1, CC, Cond }; + SDValue Cmov = DAG.getNode(X86ISD::CMOV, DL, MVT::i32, Ops); + return DAG.getNode(ISD::TRUNCATE, DL, Op.getValueType(), Cmov); + } + // X86ISD::CMOV means set the result (which is operand 1) to the RHS if // condition is true. SDValue Ops[] = { Op2, Op1, CC, Cond }; @@ -18399,8 +19171,13 @@ static SDValue LowerSIGN_EXTEND_Mask(SDValue Op, // Extend VT if the scalar type is v8/v16 and BWI is not supported. MVT ExtVT = VT; - if (!Subtarget.hasBWI() && VTElt.getSizeInBits() <= 16) + if (!Subtarget.hasBWI() && VTElt.getSizeInBits() <= 16) { + // If v16i32 is to be avoided, we'll need to split and concatenate. + if (NumElts == 16 && !Subtarget.canExtendTo512DQ()) + return SplitAndExtendv16i1(ISD::SIGN_EXTEND, VT, In, dl, DAG); + ExtVT = MVT::getVectorVT(MVT::i32, NumElts); + } // Widen to 512-bits if VLX is not supported. MVT WideVT = ExtVT; @@ -18416,7 +19193,7 @@ static SDValue LowerSIGN_EXTEND_Mask(SDValue Op, MVT WideEltVT = WideVT.getVectorElementType(); if ((Subtarget.hasDQI() && WideEltVT.getSizeInBits() >= 32) || (Subtarget.hasBWI() && WideEltVT.getSizeInBits() <= 16)) { - V = getExtendInVec(X86ISD::VSEXT, dl, WideVT, In, DAG); + V = DAG.getNode(ISD::SIGN_EXTEND, dl, WideVT, In); } else { SDValue NegOne = getOnesVector(WideVT, DAG, dl); SDValue Zero = getZeroVector(WideVT, Subtarget, DAG, dl); @@ -18445,11 +19222,8 @@ static SDValue LowerANY_EXTEND(SDValue Op, const X86Subtarget &Subtarget, if (InVT.getVectorElementType() == MVT::i1) return LowerSIGN_EXTEND_Mask(Op, Subtarget, DAG); - if (Subtarget.hasFp256()) - if (SDValue Res = LowerAVXExtend(Op, DAG, Subtarget)) - return Res; - - return SDValue(); + assert(Subtarget.hasAVX() && "Expected AVX support"); + return LowerAVXExtend(Op, DAG, Subtarget); } // Lowering for SIGN_EXTEND_VECTOR_INREG and ZERO_EXTEND_VECTOR_INREG. 
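
Treating the v64i1 operands as 64-bit mask words makes the split-and-rejoin above easy to picture. A standalone C++ model (illustrative only; on the real target the halves stay in k-registers rather than integer registers):

#include <cstdint>

// Select between two 64-lane i1 masks on a 32-bit target: split each mask
// into 32-lane halves, select each half, then concatenate the results,
// mirroring extractSubVector / getSelect / CONCAT_VECTORS above.
static uint64_t selectV64i1(bool Cond, uint64_t Op1, uint64_t Op2) {
  uint32_t Lo = Cond ? uint32_t(Op1) : uint32_t(Op2);
  uint32_t Hi = Cond ? uint32_t(Op1 >> 32) : uint32_t(Op2 >> 32);
  return (uint64_t(Hi) << 32) | Lo;
}
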
@@ -18549,15 +19323,17 @@ static SDValue LowerSIGN_EXTEND(SDValue Op, const X86Subtarget &Subtarget, if (InVT.getVectorElementType() == MVT::i1) return LowerSIGN_EXTEND_Mask(Op, Subtarget, DAG); - if ((VT != MVT::v4i64 || InVT != MVT::v4i32) && - (VT != MVT::v8i32 || InVT != MVT::v8i16) && - (VT != MVT::v16i16 || InVT != MVT::v16i8) && - (VT != MVT::v8i64 || InVT != MVT::v8i32) && - (VT != MVT::v8i64 || InVT != MVT::v8i16) && - (VT != MVT::v16i32 || InVT != MVT::v16i16) && - (VT != MVT::v16i32 || InVT != MVT::v16i8) && - (VT != MVT::v32i16 || InVT != MVT::v32i8)) - return SDValue(); + assert(VT.isVector() && InVT.isVector() && "Expected vector type"); + assert(VT.getVectorNumElements() == VT.getVectorNumElements() && + "Expected same number of elements"); + assert((VT.getVectorElementType() == MVT::i16 || + VT.getVectorElementType() == MVT::i32 || + VT.getVectorElementType() == MVT::i64) && + "Unexpected element type"); + assert((InVT.getVectorElementType() == MVT::i8 || + InVT.getVectorElementType() == MVT::i16 || + InVT.getVectorElementType() == MVT::i32) && + "Unexpected element type"); if (Subtarget.hasInt256()) return DAG.getNode(X86ISD::VSEXT, dl, VT, In); @@ -18595,164 +19371,29 @@ static SDValue LowerSIGN_EXTEND(SDValue Op, const X86Subtarget &Subtarget, return DAG.getNode(ISD::CONCAT_VECTORS, dl, VT, OpLo, OpHi); } -// Lower truncating store. We need a special lowering to vXi1 vectors -static SDValue LowerTruncatingStore(SDValue StOp, const X86Subtarget &Subtarget, - SelectionDAG &DAG) { - StoreSDNode *St = cast<StoreSDNode>(StOp.getNode()); +static SDValue LowerStore(SDValue Op, const X86Subtarget &Subtarget, + SelectionDAG &DAG) { + StoreSDNode *St = cast<StoreSDNode>(Op.getNode()); SDLoc dl(St); - EVT MemVT = St->getMemoryVT(); - assert(St->isTruncatingStore() && "We only custom truncating store."); - assert(MemVT.isVector() && MemVT.getVectorElementType() == MVT::i1 && - "Expected truncstore of i1 vector"); - - SDValue Op = St->getValue(); - MVT OpVT = Op.getValueType().getSimpleVT(); - unsigned NumElts = OpVT.getVectorNumElements(); - if ((Subtarget.hasVLX() && Subtarget.hasBWI() && Subtarget.hasDQI()) || - NumElts == 16) { - // Truncate and store - everything is legal - Op = DAG.getNode(ISD::TRUNCATE, dl, MemVT, Op); - if (MemVT.getSizeInBits() < 8) - Op = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, MVT::v8i1, - DAG.getUNDEF(MVT::v8i1), Op, - DAG.getIntPtrConstant(0, dl)); - return DAG.getStore(St->getChain(), dl, Op, St->getBasePtr(), - St->getMemOperand()); - } - - // A subset, assume that we have only AVX-512F - if (NumElts <= 8) { - if (NumElts < 8) { - // Extend to 8-elts vector - MVT ExtVT = MVT::getVectorVT(OpVT.getScalarType(), 8); - Op = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, ExtVT, - DAG.getUNDEF(ExtVT), Op, DAG.getIntPtrConstant(0, dl)); - } - Op = DAG.getNode(ISD::TRUNCATE, dl, MVT::v8i1, Op); - return DAG.getStore(St->getChain(), dl, Op, St->getBasePtr(), - St->getMemOperand()); - } - // v32i8 - assert(OpVT == MVT::v32i8 && "Unexpected operand type"); - // Divide the vector into 2 parts and store each part separately - SDValue Lo = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, MVT::v16i8, Op, - DAG.getIntPtrConstant(0, dl)); - Lo = DAG.getNode(ISD::TRUNCATE, dl, MVT::v16i1, Lo); - SDValue BasePtr = St->getBasePtr(); - SDValue StLo = DAG.getStore(St->getChain(), dl, Lo, BasePtr, - St->getMemOperand()); - SDValue Hi = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, MVT::v16i8, Op, - DAG.getIntPtrConstant(16, dl)); - Hi = DAG.getNode(ISD::TRUNCATE, dl, MVT::v16i1, Hi); - - SDValue 
BasePtrHi = - DAG.getNode(ISD::ADD, dl, BasePtr.getValueType(), BasePtr, - DAG.getConstant(2, dl, BasePtr.getValueType())); - - SDValue StHi = DAG.getStore(St->getChain(), dl, Hi, - BasePtrHi, St->getMemOperand()); - return DAG.getNode(ISD::TokenFactor, dl, MVT::Other, StLo, StHi); -} - -static SDValue LowerExtended1BitVectorLoad(SDValue Op, - const X86Subtarget &Subtarget, - SelectionDAG &DAG) { - - LoadSDNode *Ld = cast<LoadSDNode>(Op.getNode()); - SDLoc dl(Ld); - EVT MemVT = Ld->getMemoryVT(); - assert(MemVT.isVector() && MemVT.getScalarType() == MVT::i1 && - "Expected i1 vector load"); - unsigned ExtOpcode = Ld->getExtensionType() == ISD::ZEXTLOAD ? - ISD::ZERO_EXTEND : ISD::SIGN_EXTEND; - MVT VT = Op.getValueType().getSimpleVT(); - unsigned NumElts = VT.getVectorNumElements(); - - if ((Subtarget.hasBWI() && NumElts >= 32) || - (Subtarget.hasDQI() && NumElts < 16) || - NumElts == 16) { - // Load and extend - everything is legal - if (NumElts < 8) { - SDValue Load = DAG.getLoad(MVT::v8i1, dl, Ld->getChain(), - Ld->getBasePtr(), - Ld->getMemOperand()); - // Replace chain users with the new chain. - assert(Load->getNumValues() == 2 && "Loads must carry a chain!"); - DAG.ReplaceAllUsesOfValueWith(SDValue(Ld, 1), Load.getValue(1)); - if (Subtarget.hasVLX()) { - // Extract to v4i1/v2i1. - SDValue Extract = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, MemVT, Load, - DAG.getIntPtrConstant(0, dl)); - // Finally, do a normal sign-extend to the desired register. - return DAG.getNode(ExtOpcode, dl, Op.getValueType(), Extract); - } - - MVT ExtVT = MVT::getVectorVT(VT.getScalarType(), 8); - SDValue ExtVec = DAG.getNode(ExtOpcode, dl, ExtVT, Load); - - return DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, VT, ExtVec, - DAG.getIntPtrConstant(0, dl)); - } - SDValue Load = DAG.getLoad(MemVT, dl, Ld->getChain(), - Ld->getBasePtr(), - Ld->getMemOperand()); - // Replace chain users with the new chain. - assert(Load->getNumValues() == 2 && "Loads must carry a chain!"); - DAG.ReplaceAllUsesOfValueWith(SDValue(Ld, 1), Load.getValue(1)); - - // Finally, do a normal sign-extend to the desired register. - return DAG.getNode(ExtOpcode, dl, Op.getValueType(), Load); - } - - if (NumElts <= 8) { - // A subset, assume that we have only AVX-512F - SDValue Load = DAG.getLoad(MVT::i8, dl, Ld->getChain(), - Ld->getBasePtr(), - Ld->getMemOperand()); - // Replace chain users with the new chain. - assert(Load->getNumValues() == 2 && "Loads must carry a chain!"); - DAG.ReplaceAllUsesOfValueWith(SDValue(Ld, 1), Load.getValue(1)); - - SDValue BitVec = DAG.getBitcast(MVT::v8i1, Load); - - if (NumElts == 8) - return DAG.getNode(ExtOpcode, dl, VT, BitVec); - - if (Subtarget.hasVLX()) { - // Extract to v4i1/v2i1. - SDValue Extract = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, MemVT, BitVec, - DAG.getIntPtrConstant(0, dl)); - // Finally, do a normal sign-extend to the desired register. 
- return DAG.getNode(ExtOpcode, dl, Op.getValueType(), Extract); - } - - MVT ExtVT = MVT::getVectorVT(VT.getScalarType(), 8); - SDValue ExtVec = DAG.getNode(ExtOpcode, dl, ExtVT, BitVec); - return DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, VT, ExtVec, - DAG.getIntPtrConstant(0, dl)); - } - - assert(VT == MVT::v32i8 && "Unexpected extload type"); - - SDValue BasePtr = Ld->getBasePtr(); - SDValue LoadLo = DAG.getLoad(MVT::v16i1, dl, Ld->getChain(), - Ld->getBasePtr(), - Ld->getMemOperand()); - - SDValue BasePtrHi = DAG.getMemBasePlusOffset(BasePtr, 2, dl); - - SDValue LoadHi = DAG.getLoad(MVT::v16i1, dl, Ld->getChain(), BasePtrHi, - Ld->getPointerInfo().getWithOffset(2), - MinAlign(Ld->getAlignment(), 2U), - Ld->getMemOperand()->getFlags()); - - SDValue NewChain = DAG.getNode(ISD::TokenFactor, dl, MVT::Other, - LoadLo.getValue(1), LoadHi.getValue(1)); - DAG.ReplaceAllUsesOfValueWith(SDValue(Ld, 1), NewChain); + SDValue StoredVal = St->getValue(); + + // Without AVX512DQ, we need to use a scalar type for v2i1/v4i1/v8i1 loads. + assert(StoredVal.getValueType().isVector() && + StoredVal.getValueType().getVectorElementType() == MVT::i1 && + StoredVal.getValueType().getVectorNumElements() <= 8 && + "Unexpected VT"); + assert(!St->isTruncatingStore() && "Expected non-truncating store"); + assert(Subtarget.hasAVX512() && !Subtarget.hasDQI() && + "Expected AVX512F without AVX512DQI"); + + StoredVal = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, MVT::v8i1, + DAG.getUNDEF(MVT::v8i1), StoredVal, + DAG.getIntPtrConstant(0, dl)); + StoredVal = DAG.getBitcast(MVT::i8, StoredVal); - SDValue Lo = DAG.getNode(ExtOpcode, dl, MVT::v16i8, LoadLo); - SDValue Hi = DAG.getNode(ExtOpcode, dl, MVT::v16i8, LoadHi); - return DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v32i8, Lo, Hi); + return DAG.getStore(St->getChain(), dl, StoredVal, St->getBasePtr(), + St->getPointerInfo(), St->getAlignment(), + St->getMemOperand()->getFlags()); } // Lower vector extended loads using a shuffle. If SSSE3 is not available we @@ -18762,21 +19403,40 @@ static SDValue LowerExtended1BitVectorLoad(SDValue Op, // FIXME: Is the expansion actually better than scalar code? It doesn't seem so. // TODO: It is possible to support ZExt by zeroing the undef values during // the shuffle phase or after the shuffle. -static SDValue LowerExtendedLoad(SDValue Op, const X86Subtarget &Subtarget, +static SDValue LowerLoad(SDValue Op, const X86Subtarget &Subtarget, SelectionDAG &DAG) { MVT RegVT = Op.getSimpleValueType(); assert(RegVT.isVector() && "We only custom lower vector sext loads."); assert(RegVT.isInteger() && "We only custom lower integer vector sext loads."); - // Nothing useful we can do without SSE2 shuffles. - assert(Subtarget.hasSSE2() && "We only custom lower sext loads with SSE2."); - LoadSDNode *Ld = cast<LoadSDNode>(Op.getNode()); SDLoc dl(Ld); EVT MemVT = Ld->getMemoryVT(); - if (MemVT.getScalarType() == MVT::i1) - return LowerExtended1BitVectorLoad(Op, Subtarget, DAG); + + // Without AVX512DQ, we need to use a scalar type for v2i1/v4i1/v8i1 loads. + if (RegVT.isVector() && RegVT.getVectorElementType() == MVT::i1) { + assert(EVT(RegVT) == MemVT && "Expected non-extending load"); + assert(RegVT.getVectorNumElements() <= 8 && "Unexpected VT"); + assert(Subtarget.hasAVX512() && !Subtarget.hasDQI() && + "Expected AVX512F without AVX512DQI"); + + SDValue NewLd = DAG.getLoad(MVT::i8, dl, Ld->getChain(), Ld->getBasePtr(), + Ld->getPointerInfo(), Ld->getAlignment(), + Ld->getMemOperand()->getFlags()); + + // Replace chain users with the new chain. 
+ assert(NewLd->getNumValues() == 2 && "Loads must carry a chain!"); + DAG.ReplaceAllUsesOfValueWith(SDValue(Ld, 1), NewLd.getValue(1)); + + SDValue Extract = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, RegVT, + DAG.getBitcast(MVT::v8i1, NewLd), + DAG.getIntPtrConstant(0, dl)); + return DAG.getMergeValues({Extract, NewLd.getValue(1)}, dl); + } + + // Nothing useful we can do without SSE2 shuffles. + assert(Subtarget.hasSSE2() && "We only custom lower sext loads with SSE2."); const TargetLowering &TLI = DAG.getTargetLoweringInfo(); unsigned RegSz = RegVT.getSizeInBits(); @@ -19619,7 +20279,7 @@ static SDValue getTargetVShiftNode(unsigned Opc, const SDLoc &dl, MVT VT, return DAG.getNode(Opc, dl, VT, SrcOp, ShAmt); } -/// \brief Return Mask with the necessary casting or extending +/// Return Mask with the necessary casting or extending /// for \p Mask according to \p MaskVT when lowering masking intrinsics static SDValue getMaskNode(SDValue Mask, MVT MaskVT, const X86Subtarget &Subtarget, SelectionDAG &DAG, @@ -19637,27 +20297,19 @@ static SDValue getMaskNode(SDValue Mask, MVT MaskVT, } if (Mask.getSimpleValueType() == MVT::i64 && Subtarget.is32Bit()) { - if (MaskVT == MVT::v64i1) { - assert(Subtarget.hasBWI() && "Expected AVX512BW target!"); - // In case 32bit mode, bitcast i64 is illegal, extend/split it. - SDValue Lo, Hi; - Lo = DAG.getNode(ISD::EXTRACT_ELEMENT, dl, MVT::i32, Mask, - DAG.getConstant(0, dl, MVT::i32)); - Hi = DAG.getNode(ISD::EXTRACT_ELEMENT, dl, MVT::i32, Mask, - DAG.getConstant(1, dl, MVT::i32)); - - Lo = DAG.getBitcast(MVT::v32i1, Lo); - Hi = DAG.getBitcast(MVT::v32i1, Hi); - - return DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v64i1, Lo, Hi); - } else { - // MaskVT require < 64bit. Truncate mask (should succeed in any case), - // and bitcast. - MVT TruncVT = MVT::getIntegerVT(MaskVT.getSizeInBits()); - return DAG.getBitcast(MaskVT, - DAG.getNode(ISD::TRUNCATE, dl, TruncVT, Mask)); - } - + assert(MaskVT == MVT::v64i1 && "Expected v64i1 mask!"); + assert(Subtarget.hasBWI() && "Expected AVX512BW target!"); + // In case 32bit mode, bitcast i64 is illegal, extend/split it. 
+ SDValue Lo, Hi; + Lo = DAG.getNode(ISD::EXTRACT_ELEMENT, dl, MVT::i32, Mask, + DAG.getConstant(0, dl, MVT::i32)); + Hi = DAG.getNode(ISD::EXTRACT_ELEMENT, dl, MVT::i32, Mask, + DAG.getConstant(1, dl, MVT::i32)); + + Lo = DAG.getBitcast(MVT::v32i1, Lo); + Hi = DAG.getBitcast(MVT::v32i1, Hi); + + return DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v64i1, Lo, Hi); } else { MVT BitcastVT = MVT::getVectorVT(MVT::i1, Mask.getSimpleValueType().getSizeInBits()); @@ -19669,7 +20321,7 @@ static SDValue getMaskNode(SDValue Mask, MVT MaskVT, } } -/// \brief Return (and \p Op, \p Mask) for compare instructions or +/// Return (and \p Op, \p Mask) for compare instructions or /// (vselect \p Mask, \p Op, \p PreservedSrc) for others along with the /// necessary casting or extending for \p Mask when lowering masking intrinsics static SDValue getVectorMaskingNode(SDValue Op, SDValue Mask, @@ -19690,11 +20342,10 @@ static SDValue getVectorMaskingNode(SDValue Op, SDValue Mask, default: break; case X86ISD::CMPM: case X86ISD::CMPM_RND: - case X86ISD::CMPMU: case X86ISD::VPSHUFBITQMB: - return DAG.getNode(ISD::AND, dl, VT, Op, VMask); case X86ISD::VFPCLASS: - return DAG.getNode(ISD::OR, dl, VT, Op, VMask); + return DAG.getNode(ISD::AND, dl, VT, Op, VMask); + case ISD::TRUNCATE: case X86ISD::VTRUNC: case X86ISD::VTRUNCS: case X86ISD::VTRUNCUS: @@ -19710,7 +20361,7 @@ static SDValue getVectorMaskingNode(SDValue Op, SDValue Mask, return DAG.getNode(OpcodeSelect, dl, VT, VMask, Op, PreservedSrc); } -/// \brief Creates an SDNode for a predicated scalar operation. +/// Creates an SDNode for a predicated scalar operation. /// \returns (X86vselect \p Mask, \p Op, \p PreservedSrc). /// The mask is coming as MVT::i8 and it should be transformed /// to MVT::v1i1 while lowering masking intrinsics. @@ -19729,12 +20380,12 @@ static SDValue getScalarMaskingNode(SDValue Op, SDValue Mask, MVT VT = Op.getSimpleValueType(); SDLoc dl(Op); + assert(Mask.getValueType() == MVT::i8 && "Unexpect type"); SDValue IMask = DAG.getNode(ISD::SCALAR_TO_VECTOR, dl, MVT::v1i1, Mask); if (Op.getOpcode() == X86ISD::FSETCCM || - Op.getOpcode() == X86ISD::FSETCCM_RND) + Op.getOpcode() == X86ISD::FSETCCM_RND || + Op.getOpcode() == X86ISD::VFPCLASSS) return DAG.getNode(ISD::AND, dl, VT, Op, IMask); - if (Op.getOpcode() == X86ISD::VFPCLASSS) - return DAG.getNode(ISD::OR, dl, VT, Op, IMask); if (PreservedSrc.isUndef()) PreservedSrc = getZeroVector(VT, Subtarget, DAG, dl); @@ -19819,14 +20470,67 @@ SDValue X86TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op, const IntrinsicData* IntrData = getIntrinsicWithoutChain(IntNo); if (IntrData) { switch(IntrData->Type) { - case INTR_TYPE_1OP: + case INTR_TYPE_1OP: { + // We specify 2 possible opcodes for intrinsics with rounding modes. + // First, we check if the intrinsic may have non-default rounding mode, + // (IntrData->Opc1 != 0), then we check the rounding mode operand. 
+ unsigned IntrWithRoundingModeOpcode = IntrData->Opc1; + if (IntrWithRoundingModeOpcode != 0) { + SDValue Rnd = Op.getOperand(2); + if (!isRoundModeCurDirection(Rnd)) { + return DAG.getNode(IntrWithRoundingModeOpcode, dl, Op.getValueType(), + Op.getOperand(1), Rnd); + } + } return DAG.getNode(IntrData->Opc0, dl, Op.getValueType(), Op.getOperand(1)); + } case INTR_TYPE_2OP: - return DAG.getNode(IntrData->Opc0, dl, Op.getValueType(), Op.getOperand(1), - Op.getOperand(2)); + case INTR_TYPE_2OP_IMM8: { + SDValue Src2 = Op.getOperand(2); + + if (IntrData->Type == INTR_TYPE_2OP_IMM8) + Src2 = DAG.getNode(ISD::TRUNCATE, dl, MVT::i8, Src2); + + // We specify 2 possible opcodes for intrinsics with rounding modes. + // First, we check if the intrinsic may have non-default rounding mode, + // (IntrData->Opc1 != 0), then we check the rounding mode operand. + unsigned IntrWithRoundingModeOpcode = IntrData->Opc1; + if (IntrWithRoundingModeOpcode != 0) { + SDValue Rnd = Op.getOperand(3); + if (!isRoundModeCurDirection(Rnd)) { + return DAG.getNode(IntrWithRoundingModeOpcode, dl, Op.getValueType(), + Op.getOperand(1), Src2, Rnd); + } + } + + return DAG.getNode(IntrData->Opc0, dl, Op.getValueType(), + Op.getOperand(1), Src2); + } case INTR_TYPE_3OP: - return DAG.getNode(IntrData->Opc0, dl, Op.getValueType(), Op.getOperand(1), - Op.getOperand(2), Op.getOperand(3)); + case INTR_TYPE_3OP_IMM8: { + SDValue Src1 = Op.getOperand(1); + SDValue Src2 = Op.getOperand(2); + SDValue Src3 = Op.getOperand(3); + + if (IntrData->Type == INTR_TYPE_3OP_IMM8) + Src3 = DAG.getNode(ISD::TRUNCATE, dl, MVT::i8, Src3); + + // We specify 2 possible opcodes for intrinsics with rounding modes. + // First, we check if the intrinsic may have non-default rounding mode, + // (IntrData->Opc1 != 0), then we check the rounding mode operand. + unsigned IntrWithRoundingModeOpcode = IntrData->Opc1; + if (IntrWithRoundingModeOpcode != 0) { + SDValue Rnd = Op.getOperand(4); + if (!isRoundModeCurDirection(Rnd)) { + return DAG.getNode(IntrWithRoundingModeOpcode, + dl, Op.getValueType(), + Src1, Src2, Src3, Rnd); + } + } + + return DAG.getNode(IntrData->Opc0, dl, Op.getValueType(), + Src1, Src2, Src3); + } case INTR_TYPE_4OP: return DAG.getNode(IntrData->Opc0, dl, Op.getValueType(), Op.getOperand(1), Op.getOperand(2), Op.getOperand(3), Op.getOperand(4)); @@ -19927,16 +20631,12 @@ SDValue X86TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op, RoundingMode, Sae), Mask, Src0, Subtarget, DAG); } - case INTR_TYPE_2OP_MASK: - case INTR_TYPE_2OP_IMM8_MASK: { + case INTR_TYPE_2OP_MASK: { SDValue Src1 = Op.getOperand(1); SDValue Src2 = Op.getOperand(2); SDValue PassThru = Op.getOperand(3); SDValue Mask = Op.getOperand(4); - if (IntrData->Type == INTR_TYPE_2OP_IMM8_MASK) - Src2 = DAG.getNode(ISD::TRUNCATE, dl, MVT::i8, Src2); - // We specify 2 possible opcodes for intrinsics with rounding modes. // First, we check if the intrinsic may have non-default rounding mode, // (IntrData->Opc1 != 0), then we check the rounding mode operand. @@ -19991,26 +20691,6 @@ SDValue X86TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op, Src2, Src3), Mask, PassThru, Subtarget, DAG); } - case INTR_TYPE_3OP_MASK_RM: { - SDValue Src1 = Op.getOperand(1); - SDValue Src2 = Op.getOperand(2); - SDValue Imm = Op.getOperand(3); - SDValue PassThru = Op.getOperand(4); - SDValue Mask = Op.getOperand(5); - // We specify 2 possible modes for intrinsics, with/without rounding - // modes. 
- // First, we check if the intrinsic have rounding mode (7 operands), - // if not, we set rounding mode to "current". - SDValue Rnd; - if (Op.getNumOperands() == 7) - Rnd = Op.getOperand(6); - else - Rnd = DAG.getConstant(X86::STATIC_ROUNDING::CUR_DIRECTION, dl, MVT::i32); - return getVectorMaskingNode(DAG.getNode(IntrData->Opc0, dl, VT, - Src1, Src2, Imm, Rnd), - Mask, PassThru, Subtarget, DAG); - } - case INTR_TYPE_3OP_IMM8_MASK: case INTR_TYPE_3OP_MASK: { SDValue Src1 = Op.getOperand(1); SDValue Src2 = Op.getOperand(2); @@ -20018,9 +20698,6 @@ SDValue X86TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op, SDValue PassThru = Op.getOperand(4); SDValue Mask = Op.getOperand(5); - if (IntrData->Type == INTR_TYPE_3OP_IMM8_MASK) - Src3 = DAG.getNode(ISD::TRUNCATE, dl, MVT::i8, Src3); - // We specify 2 possible opcodes for intrinsics with rounding modes. // First, we check if the intrinsic may have non-default rounding mode, // (IntrData->Opc1 != 0), then we check the rounding mode operand. @@ -20038,41 +20715,13 @@ SDValue X86TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op, Src1, Src2, Src3), Mask, PassThru, Subtarget, DAG); } - case VPERM_2OP_MASK : { + case VPERM_2OP : { SDValue Src1 = Op.getOperand(1); SDValue Src2 = Op.getOperand(2); - SDValue PassThru = Op.getOperand(3); - SDValue Mask = Op.getOperand(4); - - // Swap Src1 and Src2 in the node creation - return getVectorMaskingNode(DAG.getNode(IntrData->Opc0, dl, VT,Src2, Src1), - Mask, PassThru, Subtarget, DAG); - } - case VPERM_3OP_MASKZ: - case VPERM_3OP_MASK:{ - MVT VT = Op.getSimpleValueType(); - // Src2 is the PassThru - SDValue Src1 = Op.getOperand(1); - // PassThru needs to be the same type as the destination in order - // to pattern match correctly. - SDValue Src2 = DAG.getBitcast(VT, Op.getOperand(2)); - SDValue Src3 = Op.getOperand(3); - SDValue Mask = Op.getOperand(4); - SDValue PassThru = SDValue(); - - // set PassThru element - if (IntrData->Type == VPERM_3OP_MASKZ) - PassThru = getZeroVector(VT, Subtarget, DAG, dl); - else - PassThru = Src2; // Swap Src1 and Src2 in the node creation - return getVectorMaskingNode(DAG.getNode(IntrData->Opc0, - dl, Op.getValueType(), - Src2, Src1, Src3), - Mask, PassThru, Subtarget, DAG); + return DAG.getNode(IntrData->Opc0, dl, VT,Src2, Src1); } - case FMA_OP_MASK3: case FMA_OP_MASKZ: case FMA_OP_MASK: { SDValue Src1 = Op.getOperand(1); @@ -20085,8 +20734,6 @@ SDValue X86TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op, // set PassThru element if (IntrData->Type == FMA_OP_MASKZ) PassThru = getZeroVector(VT, Subtarget, DAG, dl); - else if (IntrData->Type == FMA_OP_MASK3) - PassThru = Src3; else PassThru = Src1; @@ -20107,76 +20754,11 @@ SDValue X86TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op, Src1, Src2, Src3), Mask, PassThru, Subtarget, DAG); } - case FMA_OP_SCALAR_MASK: - case FMA_OP_SCALAR_MASK3: - case FMA_OP_SCALAR_MASKZ: { - SDValue Src1 = Op.getOperand(1); - SDValue Src2 = Op.getOperand(2); - SDValue Src3 = Op.getOperand(3); - SDValue Mask = Op.getOperand(4); - MVT VT = Op.getSimpleValueType(); - SDValue PassThru = SDValue(); - - // set PassThru element - if (IntrData->Type == FMA_OP_SCALAR_MASKZ) - PassThru = getZeroVector(VT, Subtarget, DAG, dl); - else if (IntrData->Type == FMA_OP_SCALAR_MASK3) - PassThru = Src3; - else - PassThru = Src1; - - unsigned IntrWithRoundingModeOpcode = IntrData->Opc1; - if (IntrWithRoundingModeOpcode != 0) { - SDValue Rnd = Op.getOperand(5); - if (!isRoundModeCurDirection(Rnd)) - return 
getScalarMaskingNode(DAG.getNode(IntrWithRoundingModeOpcode, dl, - Op.getValueType(), Src1, Src2, - Src3, Rnd), - Mask, PassThru, Subtarget, DAG); - } - - return getScalarMaskingNode(DAG.getNode(IntrData->Opc0, dl, - Op.getValueType(), Src1, Src2, - Src3), - Mask, PassThru, Subtarget, DAG); - } - case IFMA_OP_MASKZ: - case IFMA_OP_MASK: { - SDValue Src1 = Op.getOperand(1); - SDValue Src2 = Op.getOperand(2); - SDValue Src3 = Op.getOperand(3); - SDValue Mask = Op.getOperand(4); - MVT VT = Op.getSimpleValueType(); - SDValue PassThru = Src1; - - // set PassThru element - if (IntrData->Type == IFMA_OP_MASKZ) - PassThru = getZeroVector(VT, Subtarget, DAG, dl); - - // Node we need to swizzle the operands to pass the multiply operands + case IFMA_OP: + // NOTE: We need to swizzle the operands to pass the multiply operands // first. - return getVectorMaskingNode(DAG.getNode(IntrData->Opc0, - dl, Op.getValueType(), - Src2, Src3, Src1), - Mask, PassThru, Subtarget, DAG); - } - case TERLOG_OP_MASK: - case TERLOG_OP_MASKZ: { - SDValue Src1 = Op.getOperand(1); - SDValue Src2 = Op.getOperand(2); - SDValue Src3 = Op.getOperand(3); - SDValue Src4 = DAG.getNode(ISD::TRUNCATE, dl, MVT::i8, Op.getOperand(4)); - SDValue Mask = Op.getOperand(5); - MVT VT = Op.getSimpleValueType(); - SDValue PassThru = Src1; - // Set PassThru element. - if (IntrData->Type == TERLOG_OP_MASKZ) - PassThru = getZeroVector(VT, Subtarget, DAG, dl); - - return getVectorMaskingNode(DAG.getNode(IntrData->Opc0, dl, VT, - Src1, Src2, Src3, Src4), - Mask, PassThru, Subtarget, DAG); - } + return DAG.getNode(IntrData->Opc0, dl, Op.getValueType(), + Op.getOperand(2), Op.getOperand(3), Op.getOperand(1)); case CVTPD2PS: // ISD::FP_ROUND has a second argument that indicates if the truncation // does not change the value. Set it to 0 since it can change. @@ -20207,21 +20789,11 @@ SDValue X86TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op, Mask, PassThru, Subtarget, DAG); } case FPCLASS: { - // FPclass intrinsics with mask - SDValue Src1 = Op.getOperand(1); - MVT VT = Src1.getSimpleValueType(); - MVT MaskVT = MVT::getVectorVT(MVT::i1, VT.getVectorNumElements()); - SDValue Imm = Op.getOperand(2); - SDValue Mask = Op.getOperand(3); - MVT BitcastVT = MVT::getVectorVT(MVT::i1, - Mask.getSimpleValueType().getSizeInBits()); - SDValue FPclass = DAG.getNode(IntrData->Opc0, dl, MaskVT, Src1, Imm); - SDValue FPclassMask = getVectorMaskingNode(FPclass, Mask, SDValue(), - Subtarget, DAG); - SDValue Res = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, BitcastVT, - DAG.getUNDEF(BitcastVT), FPclassMask, - DAG.getIntPtrConstant(0, dl)); - return DAG.getBitcast(Op.getValueType(), Res); + // FPclass intrinsics + SDValue Src1 = Op.getOperand(1); + MVT MaskVT = Op.getSimpleValueType(); + SDValue Imm = Op.getOperand(2); + return DAG.getNode(IntrData->Opc0, dl, MaskVT, Src1, Imm); } case FPCLASSS: { SDValue Src1 = Op.getOperand(1); @@ -20230,17 +20802,20 @@ SDValue X86TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op, SDValue FPclass = DAG.getNode(IntrData->Opc0, dl, MVT::v1i1, Src1, Imm); SDValue FPclassMask = getScalarMaskingNode(FPclass, Mask, SDValue(), Subtarget, DAG); - return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, MVT::i8, FPclassMask, - DAG.getIntPtrConstant(0, dl)); - } - case CMP_MASK: - case CMP_MASK_CC: { + // Need to fill with zeros to ensure the bitcast will produce zeroes + // for the upper bits. An EXTRACT_ELEMENT here wouldn't guarantee that. 
+ SDValue Ins = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, MVT::v8i1, + DAG.getConstant(0, dl, MVT::v8i1), + FPclassMask, DAG.getIntPtrConstant(0, dl)); + return DAG.getBitcast(MVT::i8, Ins); + } + case CMP_MASK: { // Comparison intrinsics with masks. // Example of transformation: // (i8 (int_x86_avx512_mask_pcmpeq_q_128 // (v2i64 %a), (v2i64 %b), (i8 %mask))) -> // (i8 (bitcast - // (v8i1 (insert_subvector undef, + // (v8i1 (insert_subvector zero, // (v2i1 (and (PCMPEQM %a, %b), // (extract_subvector // (v8i1 (bitcast %mask)), 0))), 0)))) @@ -20249,36 +20824,39 @@ SDValue X86TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op, SDValue Mask = Op.getOperand((IntrData->Type == CMP_MASK_CC) ? 4 : 3); MVT BitcastVT = MVT::getVectorVT(MVT::i1, Mask.getSimpleValueType().getSizeInBits()); - SDValue Cmp; - if (IntrData->Type == CMP_MASK_CC) { - SDValue CC = Op.getOperand(3); - CC = DAG.getNode(ISD::TRUNCATE, dl, MVT::i8, CC); - // We specify 2 possible opcodes for intrinsics with rounding modes. - // First, we check if the intrinsic may have non-default rounding mode, - // (IntrData->Opc1 != 0), then we check the rounding mode operand. - if (IntrData->Opc1 != 0) { - SDValue Rnd = Op.getOperand(5); - if (!isRoundModeCurDirection(Rnd)) - Cmp = DAG.getNode(IntrData->Opc1, dl, MaskVT, Op.getOperand(1), - Op.getOperand(2), CC, Rnd); - } - //default rounding mode - if(!Cmp.getNode()) - Cmp = DAG.getNode(IntrData->Opc0, dl, MaskVT, Op.getOperand(1), - Op.getOperand(2), CC); - - } else { - assert(IntrData->Type == CMP_MASK && "Unexpected intrinsic type!"); - Cmp = DAG.getNode(IntrData->Opc0, dl, MaskVT, Op.getOperand(1), - Op.getOperand(2)); - } + SDValue Cmp = DAG.getNode(IntrData->Opc0, dl, MaskVT, Op.getOperand(1), + Op.getOperand(2)); SDValue CmpMask = getVectorMaskingNode(Cmp, Mask, SDValue(), Subtarget, DAG); + // Need to fill with zeros to ensure the bitcast will produce zeroes + // for the upper bits in the v2i1/v4i1 case. SDValue Res = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, BitcastVT, - DAG.getUNDEF(BitcastVT), CmpMask, - DAG.getIntPtrConstant(0, dl)); + DAG.getConstant(0, dl, BitcastVT), + CmpMask, DAG.getIntPtrConstant(0, dl)); return DAG.getBitcast(Op.getValueType(), Res); } + + case CMP_MASK_CC: { + MVT MaskVT = Op.getSimpleValueType(); + SDValue Cmp; + SDValue CC = Op.getOperand(3); + CC = DAG.getNode(ISD::TRUNCATE, dl, MVT::i8, CC); + // We specify 2 possible opcodes for intrinsics with rounding modes. + // First, we check if the intrinsic may have non-default rounding mode, + // (IntrData->Opc1 != 0), then we check the rounding mode operand. + if (IntrData->Opc1 != 0) { + SDValue Rnd = Op.getOperand(4); + if (!isRoundModeCurDirection(Rnd)) + Cmp = DAG.getNode(IntrData->Opc1, dl, MaskVT, Op.getOperand(1), + Op.getOperand(2), CC, Rnd); + } + //default rounding mode + if (!Cmp.getNode()) + Cmp = DAG.getNode(IntrData->Opc0, dl, MaskVT, Op.getOperand(1), + Op.getOperand(2), CC); + + return Cmp; + } case CMP_MASK_SCALAR_CC: { SDValue Src1 = Op.getOperand(1); SDValue Src2 = Op.getOperand(2); @@ -20297,8 +20875,12 @@ SDValue X86TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op, SDValue CmpMask = getScalarMaskingNode(Cmp, Mask, SDValue(), Subtarget, DAG); - return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, MVT::i8, CmpMask, - DAG.getIntPtrConstant(0, dl)); + // Need to fill with zeros to ensure the bitcast will produce zeroes + // for the upper bits. An EXTRACT_ELEMENT here wouldn't guarantee that. 
+ SDValue Ins = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, MVT::v8i1, + DAG.getConstant(0, dl, MVT::v8i1), + CmpMask, DAG.getIntPtrConstant(0, dl)); + return DAG.getBitcast(MVT::i8, Ins); } case COMI: { // Comparison intrinsics ISD::CondCode CC = (ISD::CondCode)IntrData->Opc1; @@ -20351,8 +20933,13 @@ SDValue X86TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op, else FCmp = DAG.getNode(X86ISD::FSETCCM_RND, dl, MVT::v1i1, LHS, RHS, DAG.getConstant(CondVal, dl, MVT::i8), Sae); - return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, MVT::i32, FCmp, - DAG.getIntPtrConstant(0, dl)); + // Need to fill with zeros to ensure the bitcast will produce zeroes + // for the upper bits. An EXTRACT_ELEMENT here wouldn't guarantee that. + SDValue Ins = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, MVT::v16i1, + DAG.getConstant(0, dl, MVT::v16i1), + FCmp, DAG.getIntPtrConstant(0, dl)); + return DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i32, + DAG.getBitcast(MVT::i16, Ins)); } case VSHIFT: return getTargetVShiftNode(IntrData->Opc0, dl, Op.getSimpleValueType(), @@ -20369,22 +20956,6 @@ SDValue X86TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op, DataToCompress), Mask, PassThru, Subtarget, DAG); } - case BROADCASTM: { - SDValue Mask = Op.getOperand(1); - MVT MaskVT = MVT::getVectorVT(MVT::i1, - Mask.getSimpleValueType().getSizeInBits()); - Mask = DAG.getBitcast(MaskVT, Mask); - return DAG.getNode(IntrData->Opc0, dl, Op.getValueType(), Mask); - } - case MASK_BINOP: { - MVT VT = Op.getSimpleValueType(); - MVT MaskVT = MVT::getVectorVT(MVT::i1, VT.getSizeInBits()); - - SDValue Src1 = getMaskNode(Op.getOperand(1), MaskVT, Subtarget, DAG, dl); - SDValue Src2 = getMaskNode(Op.getOperand(2), MaskVT, Subtarget, DAG, dl); - SDValue Res = DAG.getNode(IntrData->Opc0, dl, MaskVT, Src1, Src2); - return DAG.getBitcast(VT, Res); - } case FIXUPIMMS: case FIXUPIMMS_MASKZ: case FIXUPIMM: @@ -20414,18 +20985,6 @@ SDValue X86TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op, Src1, Src2, Src3, Imm, Rnd), Mask, Passthru, Subtarget, DAG); } - case CONVERT_TO_MASK: { - MVT SrcVT = Op.getOperand(1).getSimpleValueType(); - MVT MaskVT = MVT::getVectorVT(MVT::i1, SrcVT.getVectorNumElements()); - MVT BitcastVT = MVT::getVectorVT(MVT::i1, VT.getSizeInBits()); - - SDValue CvtMask = DAG.getNode(IntrData->Opc0, dl, MaskVT, - Op.getOperand(1)); - SDValue Res = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, BitcastVT, - DAG.getUNDEF(BitcastVT), CvtMask, - DAG.getIntPtrConstant(0, dl)); - return DAG.getBitcast(Op.getValueType(), Res); - } case ROUNDP: { assert(IntrData->Opc0 == X86ISD::VRNDSCALE && "Unexpected opcode"); // Clear the upper bits of the rounding immediate so that the legacy @@ -20454,13 +21013,6 @@ SDValue X86TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op, switch (IntNo) { default: return SDValue(); // Don't custom lower most intrinsics. - case Intrinsic::x86_avx2_permd: - case Intrinsic::x86_avx2_permps: - // Operands intentionally swapped. Mask is last operand to intrinsic, - // but second operand for node/instruction. - return DAG.getNode(X86ISD::VPERMV, dl, Op.getValueType(), - Op.getOperand(2), Op.getOperand(1)); - // ptest and testp intrinsics. The intrinsic these come from are designed to // return an integer value, not just an instruction so lower it to the ptest // or testp pattern and a setcc for the result. 
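
Several of the scalar-mask intrinsics above (FPCLASSS, CMP_MASK_SCALAR_CC, COMI_RM) switched from an element extract to inserting the 1-bit result into an all-zero vector before the bitcast. A tiny standalone C++ model of why (illustrative only): the zero container is what guarantees the upper bits of the resulting i8 are zero.

#include <cstdint>

// Model of "INSERT_SUBVECTOR into a zero v8i1, then bitcast to i8": start from
// an all-zero byte so every bit above the compare result is known zero, which
// a plain EXTRACT_VECTOR_ELT of the v1i1 result would not guarantee.
static uint8_t scalarMaskToI8(bool CmpBit) {
  uint8_t Container = 0;                 // zero v8i1
  Container |= uint8_t(CmpBit) << 0;     // place the v1i1 result in lane 0
  return Container;                      // bitcast v8i1 -> i8
}
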
@@ -20528,43 +21080,6 @@ SDValue X86TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op, SDValue SetCC = getSETCC(X86CC, Test, dl, DAG); return DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i32, SetCC); } - case Intrinsic::x86_avx512_kortestz_w: - case Intrinsic::x86_avx512_kortestc_w: { - X86::CondCode X86CC = - (IntNo == Intrinsic::x86_avx512_kortestz_w) ? X86::COND_E : X86::COND_B; - SDValue LHS = DAG.getBitcast(MVT::v16i1, Op.getOperand(1)); - SDValue RHS = DAG.getBitcast(MVT::v16i1, Op.getOperand(2)); - SDValue Test = DAG.getNode(X86ISD::KORTEST, dl, MVT::i32, LHS, RHS); - SDValue SetCC = getSETCC(X86CC, Test, dl, DAG); - return DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i32, SetCC); - } - - case Intrinsic::x86_avx512_knot_w: { - SDValue LHS = DAG.getBitcast(MVT::v16i1, Op.getOperand(1)); - SDValue RHS = DAG.getConstant(1, dl, MVT::v16i1); - SDValue Res = DAG.getNode(ISD::XOR, dl, MVT::v16i1, LHS, RHS); - return DAG.getBitcast(MVT::i16, Res); - } - - case Intrinsic::x86_avx512_kandn_w: { - SDValue LHS = DAG.getBitcast(MVT::v16i1, Op.getOperand(1)); - // Invert LHS for the not. - LHS = DAG.getNode(ISD::XOR, dl, MVT::v16i1, LHS, - DAG.getConstant(1, dl, MVT::v16i1)); - SDValue RHS = DAG.getBitcast(MVT::v16i1, Op.getOperand(2)); - SDValue Res = DAG.getNode(ISD::AND, dl, MVT::v16i1, LHS, RHS); - return DAG.getBitcast(MVT::i16, Res); - } - - case Intrinsic::x86_avx512_kxnor_w: { - SDValue LHS = DAG.getBitcast(MVT::v16i1, Op.getOperand(1)); - SDValue RHS = DAG.getBitcast(MVT::v16i1, Op.getOperand(2)); - SDValue Res = DAG.getNode(ISD::XOR, dl, MVT::v16i1, LHS, RHS); - // Invert result for the not. - Res = DAG.getNode(ISD::XOR, dl, MVT::v16i1, Res, - DAG.getConstant(1, dl, MVT::v16i1)); - return DAG.getBitcast(MVT::i16, Res); - } case Intrinsic::x86_sse42_pcmpistria128: case Intrinsic::x86_sse42_pcmpestria128: @@ -20581,50 +21096,50 @@ SDValue X86TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op, switch (IntNo) { default: llvm_unreachable("Impossible intrinsic"); // Can't reach here. 
case Intrinsic::x86_sse42_pcmpistria128: - Opcode = X86ISD::PCMPISTRI; + Opcode = X86ISD::PCMPISTR; X86CC = X86::COND_A; break; case Intrinsic::x86_sse42_pcmpestria128: - Opcode = X86ISD::PCMPESTRI; + Opcode = X86ISD::PCMPESTR; X86CC = X86::COND_A; break; case Intrinsic::x86_sse42_pcmpistric128: - Opcode = X86ISD::PCMPISTRI; + Opcode = X86ISD::PCMPISTR; X86CC = X86::COND_B; break; case Intrinsic::x86_sse42_pcmpestric128: - Opcode = X86ISD::PCMPESTRI; + Opcode = X86ISD::PCMPESTR; X86CC = X86::COND_B; break; case Intrinsic::x86_sse42_pcmpistrio128: - Opcode = X86ISD::PCMPISTRI; + Opcode = X86ISD::PCMPISTR; X86CC = X86::COND_O; break; case Intrinsic::x86_sse42_pcmpestrio128: - Opcode = X86ISD::PCMPESTRI; + Opcode = X86ISD::PCMPESTR; X86CC = X86::COND_O; break; case Intrinsic::x86_sse42_pcmpistris128: - Opcode = X86ISD::PCMPISTRI; + Opcode = X86ISD::PCMPISTR; X86CC = X86::COND_S; break; case Intrinsic::x86_sse42_pcmpestris128: - Opcode = X86ISD::PCMPESTRI; + Opcode = X86ISD::PCMPESTR; X86CC = X86::COND_S; break; case Intrinsic::x86_sse42_pcmpistriz128: - Opcode = X86ISD::PCMPISTRI; + Opcode = X86ISD::PCMPISTR; X86CC = X86::COND_E; break; case Intrinsic::x86_sse42_pcmpestriz128: - Opcode = X86ISD::PCMPESTRI; + Opcode = X86ISD::PCMPESTR; X86CC = X86::COND_E; break; } SmallVector<SDValue, 5> NewOps(Op->op_begin()+1, Op->op_end()); - SDVTList VTs = DAG.getVTList(Op.getValueType(), MVT::i32); - SDValue PCMP = DAG.getNode(Opcode, dl, VTs, NewOps); - SDValue SetCC = getSETCC(X86CC, SDValue(PCMP.getNode(), 1), dl, DAG); + SDVTList VTs = DAG.getVTList(MVT::i32, MVT::v16i8, MVT::i32); + SDValue PCMP = DAG.getNode(Opcode, dl, VTs, NewOps).getValue(2); + SDValue SetCC = getSETCC(X86CC, PCMP, dl, DAG); return DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i32, SetCC); } @@ -20632,15 +21147,28 @@ SDValue X86TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op, case Intrinsic::x86_sse42_pcmpestri128: { unsigned Opcode; if (IntNo == Intrinsic::x86_sse42_pcmpistri128) - Opcode = X86ISD::PCMPISTRI; + Opcode = X86ISD::PCMPISTR; else - Opcode = X86ISD::PCMPESTRI; + Opcode = X86ISD::PCMPESTR; SmallVector<SDValue, 5> NewOps(Op->op_begin()+1, Op->op_end()); - SDVTList VTs = DAG.getVTList(Op.getValueType(), MVT::i32); + SDVTList VTs = DAG.getVTList(MVT::i32, MVT::v16i8, MVT::i32); return DAG.getNode(Opcode, dl, VTs, NewOps); } + case Intrinsic::x86_sse42_pcmpistrm128: + case Intrinsic::x86_sse42_pcmpestrm128: { + unsigned Opcode; + if (IntNo == Intrinsic::x86_sse42_pcmpistrm128) + Opcode = X86ISD::PCMPISTR; + else + Opcode = X86ISD::PCMPESTR; + + SmallVector<SDValue, 5> NewOps(Op->op_begin()+1, Op->op_end()); + SDVTList VTs = DAG.getVTList(MVT::i32, MVT::v16i8, MVT::i32); + return DAG.getNode(Opcode, dl, VTs, NewOps).getValue(1); + } + case Intrinsic::eh_sjlj_lsda: { MachineFunction &MF = DAG.getMachineFunction(); const TargetLowering &TLI = DAG.getTargetLoweringInfo(); @@ -20708,7 +21236,7 @@ static SDValue getAVX2GatherNode(unsigned Opc, SDValue Op, SelectionDAG &DAG, SDValue Segment = DAG.getRegister(0, MVT::i32); // If source is undef or we know it won't be used, use a zero vector // to break register dependency. - // TODO: use undef instead and let ExecutionDepsFix deal with it? + // TODO: use undef instead and let BreakFalseDeps deal with it? 
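// [Editor's note -- illustrative sketch, not part of this patch.] The refactoring
// above folds the PCMPISTRI/PCMPESTRI nodes into unified PCMPISTR/PCMPESTR nodes
// with three results -- the index (i32), the byte mask (v16i8) and the flags (i32) --
// selected via getValue(0/1/2). At the source level the same instruction already
// exposes all three results through separate SSE4.2 intrinsics; a hedged example
// (the helper name is hypothetical):
#include <nmmintrin.h>  // SSE4.2 string-compare intrinsics

// All three calls perform the same PCMPISTRI/PCMPISTRM comparison; only the
// consumed result differs (index, match mask, carry flag).
static void pcmpistrResults(__m128i A, __m128i B, int &Index, __m128i &Mask,
                            int &Carry) {
  Index = _mm_cmpistri(A, B, _SIDD_UBYTE_OPS | _SIDD_CMP_EQUAL_EACH);
  Mask  = _mm_cmpistrm(A, B, _SIDD_UBYTE_OPS | _SIDD_CMP_EQUAL_EACH);
  Carry = _mm_cmpistrc(A, B, _SIDD_UBYTE_OPS | _SIDD_CMP_EQUAL_EACH);
}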
if (Src.isUndef() || ISD::isBuildVectorAllOnes(Mask.getNode())) Src = getZeroVector(Op.getSimpleValueType(), Subtarget, DAG, dl); SDValue Ops[] = {Src, Base, Scale, Index, Disp, Segment, Mask, Chain}; @@ -20736,7 +21264,7 @@ static SDValue getGatherNode(unsigned Opc, SDValue Op, SelectionDAG &DAG, SDValue Segment = DAG.getRegister(0, MVT::i32); // If source is undef or we know it won't be used, use a zero vector // to break register dependency. - // TODO: use undef instead and let ExecutionDepsFix deal with it? + // TODO: use undef instead and let BreakFalseDeps deal with it? if (Src.isUndef() || ISD::isBuildVectorAllOnes(VMask.getNode())) Src = getZeroVector(Op.getSimpleValueType(), Subtarget, DAG, dl); SDValue Ops[] = {Src, VMask, Base, Scale, Index, Disp, Segment, Chain}; @@ -21029,17 +21557,35 @@ static SDValue LowerINTRINSIC_W_CHAIN(SDValue Op, const X86Subtarget &Subtarget, return SDValue(); } case Intrinsic::x86_lwpins32: - case Intrinsic::x86_lwpins64: { + case Intrinsic::x86_lwpins64: + case Intrinsic::x86_umwait: + case Intrinsic::x86_tpause: { SDLoc dl(Op); SDValue Chain = Op->getOperand(0); SDVTList VTs = DAG.getVTList(MVT::i32, MVT::Other); - SDValue LwpIns = - DAG.getNode(X86ISD::LWPINS, dl, VTs, Chain, Op->getOperand(2), + unsigned Opcode; + + switch (IntNo) { + default: llvm_unreachable("Impossible intrinsic"); + case Intrinsic::x86_umwait: + Opcode = X86ISD::UMWAIT; + break; + case Intrinsic::x86_tpause: + Opcode = X86ISD::TPAUSE; + break; + case Intrinsic::x86_lwpins32: + case Intrinsic::x86_lwpins64: + Opcode = X86ISD::LWPINS; + break; + } + + SDValue Operation = + DAG.getNode(Opcode, dl, VTs, Chain, Op->getOperand(2), Op->getOperand(3), Op->getOperand(4)); - SDValue SetCC = getSETCC(X86::COND_B, LwpIns.getValue(0), dl, DAG); + SDValue SetCC = getSETCC(X86::COND_B, Operation.getValue(0), dl, DAG); SDValue Result = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i8, SetCC); return DAG.getNode(ISD::MERGE_VALUES, dl, Op->getVTList(), Result, - LwpIns.getValue(1)); + Operation.getValue(1)); } } return SDValue(); @@ -21155,27 +21701,6 @@ static SDValue LowerINTRINSIC_W_CHAIN(SDValue Op, const X86Subtarget &Subtarget, SDValue Results[] = { SetCC, Store }; return DAG.getMergeValues(Results, dl); } - case COMPRESS_TO_MEM: { - SDValue Mask = Op.getOperand(4); - SDValue DataToCompress = Op.getOperand(3); - SDValue Addr = Op.getOperand(2); - SDValue Chain = Op.getOperand(0); - MVT VT = DataToCompress.getSimpleValueType(); - - MemIntrinsicSDNode *MemIntr = dyn_cast<MemIntrinsicSDNode>(Op); - assert(MemIntr && "Expected MemIntrinsicSDNode!"); - - if (isAllOnesConstant(Mask)) // return just a store - return DAG.getStore(Chain, dl, DataToCompress, Addr, - MemIntr->getMemOperand()); - - MVT MaskVT = MVT::getVectorVT(MVT::i1, VT.getVectorNumElements()); - SDValue VMask = getMaskNode(Mask, MaskVT, Subtarget, DAG, dl); - - return DAG.getMaskedStore(Chain, dl, DataToCompress, Addr, VMask, VT, - MemIntr->getMemOperand(), - false /* truncating */, true /* compressing */); - } case TRUNCATE_TO_MEM_VI8: case TRUNCATE_TO_MEM_VI16: case TRUNCATE_TO_MEM_VI32: { @@ -21219,28 +21744,6 @@ static SDValue LowerINTRINSIC_W_CHAIN(SDValue Op, const X86Subtarget &Subtarget, llvm_unreachable("Unsupported truncstore intrinsic"); } } - - case EXPAND_FROM_MEM: { - SDValue Mask = Op.getOperand(4); - SDValue PassThru = Op.getOperand(3); - SDValue Addr = Op.getOperand(2); - SDValue Chain = Op.getOperand(0); - MVT VT = Op.getSimpleValueType(); - - MemIntrinsicSDNode *MemIntr = dyn_cast<MemIntrinsicSDNode>(Op); - 
assert(MemIntr && "Expected MemIntrinsicSDNode!"); - - if (isAllOnesConstant(Mask)) // Return a regular (unmasked) vector load. - return DAG.getLoad(VT, dl, Chain, Addr, MemIntr->getMemOperand()); - if (X86::isZeroNode(Mask)) - return DAG.getUNDEF(VT); - - MVT MaskVT = MVT::getVectorVT(MVT::i1, VT.getVectorNumElements()); - SDValue VMask = getMaskNode(Mask, MaskVT, Subtarget, DAG, dl); - return DAG.getMaskedLoad(VT, dl, Chain, Addr, VMask, PassThru, VT, - MemIntr->getMemOperand(), ISD::NON_EXTLOAD, - true /* expanding */); - } } } @@ -21657,14 +22160,16 @@ static SDValue LowerVectorIntUnary(SDValue Op, SelectionDAG &DAG) { MVT VT = Op.getSimpleValueType(); unsigned NumElems = VT.getVectorNumElements(); unsigned SizeInBits = VT.getSizeInBits(); + MVT EltVT = VT.getVectorElementType(); + SDValue Src = Op.getOperand(0); + assert(EltVT == Src.getSimpleValueType().getVectorElementType() && + "Src and Op should have the same element type!"); // Extract the Lo/Hi vectors SDLoc dl(Op); - SDValue Src = Op.getOperand(0); SDValue Lo = extractSubVector(Src, 0, DAG, dl, SizeInBits / 2); SDValue Hi = extractSubVector(Src, NumElems / 2, DAG, dl, SizeInBits / 2); - MVT EltVT = VT.getVectorElementType(); MVT NewVT = MVT::getVectorVT(EltVT, NumElems / 2); return DAG.getNode(ISD::CONCAT_VECTORS, dl, VT, DAG.getNode(Op.getOpcode(), dl, NewVT, Lo), @@ -21687,13 +22192,14 @@ static SDValue Lower512IntUnary(SDValue Op, SelectionDAG &DAG) { return LowerVectorIntUnary(Op, DAG); } -/// \brief Lower a vector CTLZ using native supported vector CTLZ instruction. +/// Lower a vector CTLZ using native supported vector CTLZ instruction. // // i8/i16 vector implemented using dword LZCNT vector instruction // ( sub(trunc(lzcnt(zext32(x)))) ). In case zext32(x) is illegal, // split the vector, perform operation on it's Lo a Hi part and // concatenate the results. -static SDValue LowerVectorCTLZ_AVX512CDI(SDValue Op, SelectionDAG &DAG) { +static SDValue LowerVectorCTLZ_AVX512CDI(SDValue Op, SelectionDAG &DAG, + const X86Subtarget &Subtarget) { assert(Op.getOpcode() == ISD::CTLZ); SDLoc dl(Op); MVT VT = Op.getSimpleValueType(); @@ -21704,7 +22210,8 @@ static SDValue LowerVectorCTLZ_AVX512CDI(SDValue Op, SelectionDAG &DAG) { "Unsupported element type"); // Split vector, it's Lo and Hi parts will be handled in next iteration. - if (16 < NumElems) + if (NumElems > 16 || + (NumElems == 16 && !Subtarget.canExtendTo512DQ())) return LowerVectorIntUnary(Op, DAG); MVT NewVT = MVT::getVectorVT(MVT::i32, NumElems); @@ -21809,8 +22316,10 @@ static SDValue LowerVectorCTLZ(SDValue Op, const SDLoc &DL, SelectionDAG &DAG) { MVT VT = Op.getSimpleValueType(); - if (Subtarget.hasCDI()) - return LowerVectorCTLZ_AVX512CDI(Op, DAG); + if (Subtarget.hasCDI() && + // vXi8 vectors need to be promoted to 512-bits for vXi32. + (Subtarget.canExtendTo512DQ() || VT.getVectorElementType() != MVT::i8)) + return LowerVectorCTLZ_AVX512CDI(Op, DAG, Subtarget); // Decompose 256-bit ops into smaller 128-bit ops. if (VT.is256BitVector() && !Subtarget.hasInt256()) @@ -21999,10 +22508,42 @@ static SDValue LowerABS(SDValue Op, SelectionDAG &DAG) { } static SDValue LowerMINMAX(SDValue Op, SelectionDAG &DAG) { - assert(Op.getSimpleValueType().is256BitVector() && - Op.getSimpleValueType().isInteger() && - "Only handle AVX 256-bit vector integer operation"); - return Lower256IntArith(Op, DAG); + MVT VT = Op.getSimpleValueType(); + + // For AVX1 cases, split to use legal ops (everything but v4i64). 
+ if (VT.getScalarType() != MVT::i64 && VT.is256BitVector()) + return Lower256IntArith(Op, DAG); + + SDLoc DL(Op); + unsigned Opcode = Op.getOpcode(); + SDValue N0 = Op.getOperand(0); + SDValue N1 = Op.getOperand(1); + + // For pre-SSE41, we can perform UMIN/UMAX v8i16 by flipping the signbit, + // using the SMIN/SMAX instructions and flipping the signbit back. + if (VT == MVT::v8i16) { + assert((Opcode == ISD::UMIN || Opcode == ISD::UMAX) && + "Unexpected MIN/MAX opcode"); + SDValue Sign = DAG.getConstant(APInt::getSignedMinValue(16), DL, VT); + N0 = DAG.getNode(ISD::XOR, DL, VT, N0, Sign); + N1 = DAG.getNode(ISD::XOR, DL, VT, N1, Sign); + Opcode = (Opcode == ISD::UMIN ? ISD::SMIN : ISD::SMAX); + SDValue Result = DAG.getNode(Opcode, DL, VT, N0, N1); + return DAG.getNode(ISD::XOR, DL, VT, Result, Sign); + } + + // Else, expand to a compare/select. + ISD::CondCode CC; + switch (Opcode) { + case ISD::SMIN: CC = ISD::CondCode::SETLT; break; + case ISD::SMAX: CC = ISD::CondCode::SETGT; break; + case ISD::UMIN: CC = ISD::CondCode::SETULT; break; + case ISD::UMAX: CC = ISD::CondCode::SETUGT; break; + default: llvm_unreachable("Unknown MINMAX opcode"); + } + + SDValue Cond = DAG.getSetCC(DL, VT, N0, N1, CC); + return DAG.getSelect(DL, VT, Cond, N0, N1); } static SDValue LowerMUL(SDValue Op, const X86Subtarget &Subtarget, @@ -22048,40 +22589,26 @@ static SDValue LowerMUL(SDValue Op, const X86Subtarget &Subtarget, MVT ExVT = MVT::v8i16; // Extract the lo parts and sign extend to i16 - SDValue ALo, BLo; - if (Subtarget.hasSSE41()) { - ALo = DAG.getSignExtendVectorInReg(A, dl, ExVT); - BLo = DAG.getSignExtendVectorInReg(B, dl, ExVT); - } else { - const int ShufMask[] = {-1, 0, -1, 1, -1, 2, -1, 3, - -1, 4, -1, 5, -1, 6, -1, 7}; - ALo = DAG.getVectorShuffle(VT, dl, A, A, ShufMask); - BLo = DAG.getVectorShuffle(VT, dl, B, B, ShufMask); - ALo = DAG.getBitcast(ExVT, ALo); - BLo = DAG.getBitcast(ExVT, BLo); - ALo = DAG.getNode(ISD::SRA, dl, ExVT, ALo, DAG.getConstant(8, dl, ExVT)); - BLo = DAG.getNode(ISD::SRA, dl, ExVT, BLo, DAG.getConstant(8, dl, ExVT)); - } + // We're going to mask off the low byte of each result element of the + // pmullw, so it doesn't matter what's in the high byte of each 16-bit + // element. + const int LoShufMask[] = {0, -1, 1, -1, 2, -1, 3, -1, + 4, -1, 5, -1, 6, -1, 7, -1}; + SDValue ALo = DAG.getVectorShuffle(VT, dl, A, A, LoShufMask); + SDValue BLo = DAG.getVectorShuffle(VT, dl, B, B, LoShufMask); + ALo = DAG.getBitcast(ExVT, ALo); + BLo = DAG.getBitcast(ExVT, BLo); // Extract the hi parts and sign extend to i16 - SDValue AHi, BHi; - if (Subtarget.hasSSE41()) { - const int ShufMask[] = {8, 9, 10, 11, 12, 13, 14, 15, - -1, -1, -1, -1, -1, -1, -1, -1}; - AHi = DAG.getVectorShuffle(VT, dl, A, A, ShufMask); - BHi = DAG.getVectorShuffle(VT, dl, B, B, ShufMask); - AHi = DAG.getSignExtendVectorInReg(AHi, dl, ExVT); - BHi = DAG.getSignExtendVectorInReg(BHi, dl, ExVT); - } else { - const int ShufMask[] = {-1, 8, -1, 9, -1, 10, -1, 11, - -1, 12, -1, 13, -1, 14, -1, 15}; - AHi = DAG.getVectorShuffle(VT, dl, A, A, ShufMask); - BHi = DAG.getVectorShuffle(VT, dl, B, B, ShufMask); - AHi = DAG.getBitcast(ExVT, AHi); - BHi = DAG.getBitcast(ExVT, BHi); - AHi = DAG.getNode(ISD::SRA, dl, ExVT, AHi, DAG.getConstant(8, dl, ExVT)); - BHi = DAG.getNode(ISD::SRA, dl, ExVT, BHi, DAG.getConstant(8, dl, ExVT)); - } + // We're going to mask off the low byte of each result element of the + // pmullw, so it doesn't matter what's in the high byte of each 16-bit + // element. 
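// [Editor's note -- illustrative sketch, not part of this patch.] The LowerMINMAX
// hunk above handles pre-SSE41 UMIN/UMAX on v8i16 by flipping the sign bit, using
// the available signed PMINSW/PMAXSW, and flipping the sign bit back. The
// underlying per-lane identity (the helper name is hypothetical):
#include <algorithm>
#include <cstdint>

// umin(a, b) == smin(a ^ 0x8000, b ^ 0x8000) ^ 0x8000 for 16-bit lanes.
static uint16_t uminViaSmin(uint16_t A, uint16_t B) {
  int16_t SA = int16_t(A ^ 0x8000);
  int16_t SB = int16_t(B ^ 0x8000);
  return uint16_t(uint16_t(std::min(SA, SB)) ^ 0x8000);
}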
+ const int HiShufMask[] = {8, -1, 9, -1, 10, -1, 11, -1, + 12, -1, 13, -1, 14, -1, 15, -1}; + SDValue AHi = DAG.getVectorShuffle(VT, dl, A, A, HiShufMask); + SDValue BHi = DAG.getVectorShuffle(VT, dl, B, B, HiShufMask); + AHi = DAG.getBitcast(ExVT, AHi); + BHi = DAG.getBitcast(ExVT, BHi); // Multiply, mask the lower 8bits of the lo/hi results and pack SDValue RLo = DAG.getNode(ISD::MUL, dl, ExVT, ALo, BLo); @@ -22096,22 +22623,19 @@ static SDValue LowerMUL(SDValue Op, const X86Subtarget &Subtarget, assert(Subtarget.hasSSE2() && !Subtarget.hasSSE41() && "Should not custom lower when pmulld is available!"); - // If the upper 17 bits of each element are zero then we can use PMADD. - APInt Mask17 = APInt::getHighBitsSet(32, 17); - if (DAG.MaskedValueIsZero(A, Mask17) && DAG.MaskedValueIsZero(B, Mask17)) - return DAG.getNode(X86ISD::VPMADDWD, dl, VT, - DAG.getBitcast(MVT::v8i16, A), - DAG.getBitcast(MVT::v8i16, B)); - // Extract the odd parts. static const int UnpackMask[] = { 1, -1, 3, -1 }; SDValue Aodds = DAG.getVectorShuffle(VT, dl, A, A, UnpackMask); SDValue Bodds = DAG.getVectorShuffle(VT, dl, B, B, UnpackMask); // Multiply the even parts. - SDValue Evens = DAG.getNode(X86ISD::PMULUDQ, dl, MVT::v2i64, A, B); + SDValue Evens = DAG.getNode(X86ISD::PMULUDQ, dl, MVT::v2i64, + DAG.getBitcast(MVT::v2i64, A), + DAG.getBitcast(MVT::v2i64, B)); // Now multiply odd parts. - SDValue Odds = DAG.getNode(X86ISD::PMULUDQ, dl, MVT::v2i64, Aodds, Bodds); + SDValue Odds = DAG.getNode(X86ISD::PMULUDQ, dl, MVT::v2i64, + DAG.getBitcast(MVT::v2i64, Aodds), + DAG.getBitcast(MVT::v2i64, Bodds)); Evens = DAG.getBitcast(VT, Evens); Odds = DAG.getBitcast(VT, Odds); @@ -22124,17 +22648,7 @@ static SDValue LowerMUL(SDValue Op, const X86Subtarget &Subtarget, assert((VT == MVT::v2i64 || VT == MVT::v4i64 || VT == MVT::v8i64) && "Only know how to lower V2I64/V4I64/V8I64 multiply"); - - // 32-bit vector types used for MULDQ/MULUDQ. - MVT MulVT = MVT::getVectorVT(MVT::i32, VT.getSizeInBits() / 32); - - // MULDQ returns the 64-bit result of the signed multiplication of the lower - // 32-bits. We can lower with this if the sign bits stretch that far. - if (Subtarget.hasSSE41() && DAG.ComputeNumSignBits(A) > 32 && - DAG.ComputeNumSignBits(B) > 32) { - return DAG.getNode(X86ISD::PMULDQ, dl, VT, DAG.getBitcast(MulVT, A), - DAG.getBitcast(MulVT, B)); - } + assert(!Subtarget.hasDQI() && "DQI should use MULLQ"); // Ahi = psrlqi(a, 32); // Bhi = psrlqi(b, 32); @@ -22145,42 +22659,35 @@ static SDValue LowerMUL(SDValue Op, const X86Subtarget &Subtarget, // // Hi = psllqi(AloBhi + AhiBlo, 32); // return AloBlo + Hi; + KnownBits AKnown, BKnown; + DAG.computeKnownBits(A, AKnown); + DAG.computeKnownBits(B, BKnown); + APInt LowerBitsMask = APInt::getLowBitsSet(64, 32); - bool ALoIsZero = DAG.MaskedValueIsZero(A, LowerBitsMask); - bool BLoIsZero = DAG.MaskedValueIsZero(B, LowerBitsMask); + bool ALoIsZero = LowerBitsMask.isSubsetOf(AKnown.Zero); + bool BLoIsZero = LowerBitsMask.isSubsetOf(BKnown.Zero); APInt UpperBitsMask = APInt::getHighBitsSet(64, 32); - bool AHiIsZero = DAG.MaskedValueIsZero(A, UpperBitsMask); - bool BHiIsZero = DAG.MaskedValueIsZero(B, UpperBitsMask); - - // If DQI is supported we can use MULLQ, but MULUDQ is still better if the - // the high bits are known to be zero. - if (Subtarget.hasDQI() && (!AHiIsZero || !BHiIsZero)) - return Op; - - // Bit cast to 32-bit vectors for MULUDQ. 
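// [Editor's note -- illustrative sketch, not part of this patch.] The v16i8 MUL
// rewrite above unpacks bytes into 16-bit lanes without sign-extending, because
// only the low byte of each PMULLW product is kept: the low 8 bits of a 16-bit
// product depend only on the low 8 bits of the operands, so the high byte of each
// widened lane can hold arbitrary data. Per-lane statement of that fact:
#include <cstdint>

// uint8_t(A16 * B16) == uint8_t((A16 & 0xFF) * (B16 & 0xFF)) for any high bytes.
static uint8_t mulLowByte(uint16_t A16, uint16_t B16) {
  return uint8_t(uint32_t(A16) * uint32_t(B16));
}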
- SDValue Alo = DAG.getBitcast(MulVT, A); - SDValue Blo = DAG.getBitcast(MulVT, B); + bool AHiIsZero = UpperBitsMask.isSubsetOf(AKnown.Zero); + bool BHiIsZero = UpperBitsMask.isSubsetOf(BKnown.Zero); SDValue Zero = getZeroVector(VT, Subtarget, DAG, dl); // Only multiply lo/hi halves that aren't known to be zero. SDValue AloBlo = Zero; if (!ALoIsZero && !BLoIsZero) - AloBlo = DAG.getNode(X86ISD::PMULUDQ, dl, VT, Alo, Blo); + AloBlo = DAG.getNode(X86ISD::PMULUDQ, dl, VT, A, B); SDValue AloBhi = Zero; if (!ALoIsZero && !BHiIsZero) { SDValue Bhi = getTargetVShiftByConstNode(X86ISD::VSRLI, dl, VT, B, 32, DAG); - Bhi = DAG.getBitcast(MulVT, Bhi); - AloBhi = DAG.getNode(X86ISD::PMULUDQ, dl, VT, Alo, Bhi); + AloBhi = DAG.getNode(X86ISD::PMULUDQ, dl, VT, A, Bhi); } SDValue AhiBlo = Zero; if (!AHiIsZero && !BLoIsZero) { SDValue Ahi = getTargetVShiftByConstNode(X86ISD::VSRLI, dl, VT, A, 32, DAG); - Ahi = DAG.getBitcast(MulVT, Ahi); - AhiBlo = DAG.getNode(X86ISD::PMULUDQ, dl, VT, Ahi, Blo); + AhiBlo = DAG.getNode(X86ISD::PMULUDQ, dl, VT, Ahi, B); } SDValue Hi = DAG.getNode(ISD::ADD, dl, VT, AloBhi, AhiBlo); @@ -22226,7 +22733,7 @@ static SDValue LowerMULH(SDValue Op, const X86Subtarget &Subtarget, SDValue Hi = DAG.getIntPtrConstant(NumElems / 2, dl); if (VT == MVT::v32i8) { - if (Subtarget.hasBWI()) { + if (Subtarget.canExtendTo512BW()) { SDValue ExA = DAG.getNode(ExAVX, dl, MVT::v32i16, A); SDValue ExB = DAG.getNode(ExAVX, dl, MVT::v32i16, B); SDValue Mul = DAG.getNode(ISD::MUL, dl, MVT::v32i16, ExA, ExB); @@ -22277,13 +22784,14 @@ static SDValue LowerMULH(SDValue Op, const X86Subtarget &Subtarget, assert(VT == MVT::v16i8 && "Pre-AVX2 support only supports v16i8 multiplication"); MVT ExVT = MVT::v8i16; - unsigned ExSSE41 = (ISD::MULHU == Opcode ? X86ISD::VZEXT : X86ISD::VSEXT); + unsigned ExSSE41 = ISD::MULHU == Opcode ? ISD::ZERO_EXTEND_VECTOR_INREG + : ISD::SIGN_EXTEND_VECTOR_INREG; // Extract the lo parts and zero/sign extend to i16. SDValue ALo, BLo; if (Subtarget.hasSSE41()) { - ALo = getExtendInVec(ExSSE41, dl, ExVT, A, DAG); - BLo = getExtendInVec(ExSSE41, dl, ExVT, B, DAG); + ALo = DAG.getNode(ExSSE41, dl, ExVT, A); + BLo = DAG.getNode(ExSSE41, dl, ExVT, B); } else { const int ShufMask[] = {-1, 0, -1, 1, -1, 2, -1, 3, -1, 4, -1, 5, -1, 6, -1, 7}; @@ -22302,8 +22810,8 @@ static SDValue LowerMULH(SDValue Op, const X86Subtarget &Subtarget, -1, -1, -1, -1, -1, -1, -1, -1}; AHi = DAG.getVectorShuffle(VT, dl, A, A, ShufMask); BHi = DAG.getVectorShuffle(VT, dl, B, B, ShufMask); - AHi = getExtendInVec(ExSSE41, dl, ExVT, AHi, DAG); - BHi = getExtendInVec(ExSSE41, dl, ExVT, BHi, DAG); + AHi = DAG.getNode(ExSSE41, dl, ExVT, AHi); + BHi = DAG.getNode(ExSSE41, dl, ExVT, BHi); } else { const int ShufMask[] = {-1, 8, -1, 9, -1, 10, -1, 11, -1, 12, -1, 13, -1, 14, -1, 15}; @@ -22438,10 +22946,14 @@ static SDValue LowerMUL_LOHI(SDValue Op, const X86Subtarget &Subtarget, (!IsSigned || !Subtarget.hasSSE41()) ? 
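// [Editor's note -- illustrative sketch, not part of this patch.] The v2i64 MUL
// lowering above builds a 64x64->64 product from three 32x32->64 PMULUDQ
// multiplies: AloBlo + ((AloBhi + AhiBlo) << 32), with the AhiBhi term wrapping
// out of the 64-bit result. Scalar equivalent (the helper name is hypothetical):
#include <cstdint>

static uint64_t mul64ViaPmuludqPattern(uint64_t A, uint64_t B) {
  uint64_t ALo = A & 0xFFFFFFFFu, AHi = A >> 32;
  uint64_t BLo = B & 0xFFFFFFFFu, BHi = B >> 32;
  uint64_t ALoBLo = ALo * BLo;   // PMULUDQ(A, B)
  uint64_t ALoBHi = ALo * BHi;   // PMULUDQ(A, B >> 32)
  uint64_t AHiBLo = AHi * BLo;   // PMULUDQ(A >> 32, B)
  return ALoBLo + ((ALoBHi + AHiBLo) << 32);
}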
X86ISD::PMULUDQ : X86ISD::PMULDQ; // PMULUDQ <4 x i32> <a|b|c|d>, <4 x i32> <e|f|g|h> // => <2 x i64> <ae|cg> - SDValue Mul1 = DAG.getBitcast(VT, DAG.getNode(Opcode, dl, MulVT, Op0, Op1)); + SDValue Mul1 = DAG.getBitcast(VT, DAG.getNode(Opcode, dl, MulVT, + DAG.getBitcast(MulVT, Op0), + DAG.getBitcast(MulVT, Op1))); // PMULUDQ <4 x i32> <b|undef|d|undef>, <4 x i32> <f|undef|h|undef> // => <2 x i64> <bf|dh> - SDValue Mul2 = DAG.getBitcast(VT, DAG.getNode(Opcode, dl, MulVT, Odd0, Odd1)); + SDValue Mul2 = DAG.getBitcast(VT, DAG.getNode(Opcode, dl, MulVT, + DAG.getBitcast(MulVT, Odd0), + DAG.getBitcast(MulVT, Odd1))); // Shuffle it back into the right order. SmallVector<int, 16> HighMask(NumElts); @@ -22601,7 +23113,8 @@ static SDValue LowerScalarImmediateShift(SDValue Op, SelectionDAG &DAG, SDValue Zeros = getZeroVector(VT, Subtarget, DAG, dl); if (VT.is512BitVector()) { assert(VT == MVT::v64i8 && "Unexpected element type!"); - SDValue CMP = DAG.getNode(X86ISD::PCMPGTM, dl, MVT::v64i1, Zeros, R); + SDValue CMP = DAG.getSetCC(dl, MVT::v64i1, Zeros, R, + ISD::SETGT); return DAG.getNode(ISD::SIGN_EXTEND, dl, VT, CMP); } return DAG.getNode(X86ISD::PCMPGT, dl, VT, Zeros, R); @@ -22711,57 +23224,81 @@ static SDValue LowerScalarImmediateShift(SDValue Op, SelectionDAG &DAG, return SDValue(); } +// Determine if V is a splat value, and return the scalar. +static SDValue IsSplatValue(MVT VT, SDValue V, const SDLoc &dl, + SelectionDAG &DAG, const X86Subtarget &Subtarget, + unsigned Opcode) { + V = peekThroughEXTRACT_SUBVECTORs(V); + + // Check if this is a splat build_vector node. + if (BuildVectorSDNode *BV = dyn_cast<BuildVectorSDNode>(V)) { + SDValue SplatAmt = BV->getSplatValue(); + if (SplatAmt && SplatAmt.isUndef()) + return SDValue(); + return SplatAmt; + } + + // Check for SUB(SPLAT_BV, SPLAT) cases from rotate patterns. + if (V.getOpcode() == ISD::SUB && + !SupportedVectorVarShift(VT, Subtarget, Opcode)) { + SDValue LHS = peekThroughEXTRACT_SUBVECTORs(V.getOperand(0)); + SDValue RHS = peekThroughEXTRACT_SUBVECTORs(V.getOperand(1)); + + // Ensure that the corresponding splat BV element is not UNDEF. + BitVector UndefElts; + BuildVectorSDNode *BV0 = dyn_cast<BuildVectorSDNode>(LHS); + ShuffleVectorSDNode *SVN1 = dyn_cast<ShuffleVectorSDNode>(RHS); + if (BV0 && SVN1 && BV0->getSplatValue(&UndefElts) && SVN1->isSplat()) { + unsigned SplatIdx = (unsigned)SVN1->getSplatIndex(); + if (!UndefElts[SplatIdx]) + return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, + VT.getVectorElementType(), V, + DAG.getIntPtrConstant(SplatIdx, dl)); + } + } + + // Check if this is a shuffle node doing a splat. + ShuffleVectorSDNode *SVN = dyn_cast<ShuffleVectorSDNode>(V); + if (!SVN || !SVN->isSplat()) + return SDValue(); + + unsigned SplatIdx = (unsigned)SVN->getSplatIndex(); + SDValue InVec = V.getOperand(0); + if (InVec.getOpcode() == ISD::BUILD_VECTOR) { + assert((SplatIdx < VT.getVectorNumElements()) && + "Unexpected shuffle index found!"); + return InVec.getOperand(SplatIdx); + } else if (InVec.getOpcode() == ISD::INSERT_VECTOR_ELT) { + if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(InVec.getOperand(2))) + if (C->getZExtValue() == SplatIdx) + return InVec.getOperand(1); + } + + // Avoid introducing an extract element from a shuffle. 
+ return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, + VT.getVectorElementType(), InVec, + DAG.getIntPtrConstant(SplatIdx, dl)); +} + static SDValue LowerScalarVariableShift(SDValue Op, SelectionDAG &DAG, const X86Subtarget &Subtarget) { MVT VT = Op.getSimpleValueType(); SDLoc dl(Op); SDValue R = Op.getOperand(0); SDValue Amt = Op.getOperand(1); + unsigned Opcode = Op.getOpcode(); - unsigned X86OpcI = (Op.getOpcode() == ISD::SHL) ? X86ISD::VSHLI : - (Op.getOpcode() == ISD::SRL) ? X86ISD::VSRLI : X86ISD::VSRAI; - - unsigned X86OpcV = (Op.getOpcode() == ISD::SHL) ? X86ISD::VSHL : - (Op.getOpcode() == ISD::SRL) ? X86ISD::VSRL : X86ISD::VSRA; - - if (SupportedVectorShiftWithBaseAmnt(VT, Subtarget, Op.getOpcode())) { - SDValue BaseShAmt; - MVT EltVT = VT.getVectorElementType(); - - if (BuildVectorSDNode *BV = dyn_cast<BuildVectorSDNode>(Amt)) { - // Check if this build_vector node is doing a splat. - // If so, then set BaseShAmt equal to the splat value. - BaseShAmt = BV->getSplatValue(); - if (BaseShAmt && BaseShAmt.isUndef()) - BaseShAmt = SDValue(); - } else { - if (Amt.getOpcode() == ISD::EXTRACT_SUBVECTOR) - Amt = Amt.getOperand(0); + unsigned X86OpcI = (Opcode == ISD::SHL) ? X86ISD::VSHLI : + (Opcode == ISD::SRL) ? X86ISD::VSRLI : X86ISD::VSRAI; - ShuffleVectorSDNode *SVN = dyn_cast<ShuffleVectorSDNode>(Amt); - if (SVN && SVN->isSplat()) { - unsigned SplatIdx = (unsigned)SVN->getSplatIndex(); - SDValue InVec = Amt.getOperand(0); - if (InVec.getOpcode() == ISD::BUILD_VECTOR) { - assert((SplatIdx < InVec.getSimpleValueType().getVectorNumElements()) && - "Unexpected shuffle index found!"); - BaseShAmt = InVec.getOperand(SplatIdx); - } else if (InVec.getOpcode() == ISD::INSERT_VECTOR_ELT) { - if (ConstantSDNode *C = - dyn_cast<ConstantSDNode>(InVec.getOperand(2))) { - if (C->getZExtValue() == SplatIdx) - BaseShAmt = InVec.getOperand(1); - } - } + unsigned X86OpcV = (Opcode == ISD::SHL) ? X86ISD::VSHL : + (Opcode == ISD::SRL) ? X86ISD::VSRL : X86ISD::VSRA; - if (!BaseShAmt) - // Avoid introducing an extract element from a shuffle. - BaseShAmt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, InVec, - DAG.getIntPtrConstant(SplatIdx, dl)); - } - } + Amt = peekThroughEXTRACT_SUBVECTORs(Amt); - if (BaseShAmt.getNode()) { + if (SupportedVectorShiftWithBaseAmnt(VT, Subtarget, Opcode)) { + if (SDValue BaseShAmt = IsSplatValue(VT, Amt, dl, DAG, Subtarget, Opcode)) { + MVT EltVT = VT.getVectorElementType(); assert(EltVT.bitsLE(MVT::i64) && "Unexpected element type!"); if (EltVT != MVT::i64 && EltVT.bitsGT(MVT::i32)) BaseShAmt = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i64, BaseShAmt); @@ -22793,6 +23330,70 @@ static SDValue LowerScalarVariableShift(SDValue Op, SelectionDAG &DAG, return SDValue(); } +// Convert a shift/rotate left amount to a multiplication scale factor. 
+static SDValue convertShiftLeftToScale(SDValue Amt, const SDLoc &dl, + const X86Subtarget &Subtarget, + SelectionDAG &DAG) { + MVT VT = Amt.getSimpleValueType(); + if (!(VT == MVT::v8i16 || VT == MVT::v4i32 || + (Subtarget.hasInt256() && VT == MVT::v16i16) || + (!Subtarget.hasAVX512() && VT == MVT::v16i8))) + return SDValue(); + + if (ISD::isBuildVectorOfConstantSDNodes(Amt.getNode())) { + SmallVector<SDValue, 8> Elts; + MVT SVT = VT.getVectorElementType(); + unsigned SVTBits = SVT.getSizeInBits(); + APInt One(SVTBits, 1); + unsigned NumElems = VT.getVectorNumElements(); + + for (unsigned i = 0; i != NumElems; ++i) { + SDValue Op = Amt->getOperand(i); + if (Op->isUndef()) { + Elts.push_back(Op); + continue; + } + + ConstantSDNode *ND = cast<ConstantSDNode>(Op); + APInt C(SVTBits, ND->getAPIntValue().getZExtValue()); + uint64_t ShAmt = C.getZExtValue(); + if (ShAmt >= SVTBits) { + Elts.push_back(DAG.getUNDEF(SVT)); + continue; + } + Elts.push_back(DAG.getConstant(One.shl(ShAmt), dl, SVT)); + } + return DAG.getBuildVector(VT, dl, Elts); + } + + // If the target doesn't support variable shifts, use either FP conversion + // or integer multiplication to avoid shifting each element individually. + if (VT == MVT::v4i32) { + Amt = DAG.getNode(ISD::SHL, dl, VT, Amt, DAG.getConstant(23, dl, VT)); + Amt = DAG.getNode(ISD::ADD, dl, VT, Amt, + DAG.getConstant(0x3f800000U, dl, VT)); + Amt = DAG.getBitcast(MVT::v4f32, Amt); + return DAG.getNode(ISD::FP_TO_SINT, dl, VT, Amt); + } + + // AVX2 can more effectively perform this as a zext/trunc to/from v8i32. + if (VT == MVT::v8i16 && !Subtarget.hasAVX2()) { + SDValue Z = getZeroVector(VT, Subtarget, DAG, dl); + SDValue Lo = DAG.getBitcast(MVT::v4i32, getUnpackl(DAG, dl, VT, Amt, Z)); + SDValue Hi = DAG.getBitcast(MVT::v4i32, getUnpackh(DAG, dl, VT, Amt, Z)); + Lo = convertShiftLeftToScale(Lo, dl, Subtarget, DAG); + Hi = convertShiftLeftToScale(Hi, dl, Subtarget, DAG); + if (Subtarget.hasSSE41()) + return DAG.getNode(X86ISD::PACKUS, dl, VT, Lo, Hi); + + return DAG.getVectorShuffle(VT, dl, DAG.getBitcast(VT, Lo), + DAG.getBitcast(VT, Hi), + {0, 2, 4, 6, 8, 10, 12, 14}); + } + + return SDValue(); +} + static SDValue LowerShift(SDValue Op, const X86Subtarget &Subtarget, SelectionDAG &DAG) { MVT VT = Op.getSimpleValueType(); @@ -22815,11 +23416,10 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget &Subtarget, // XOP has 128-bit variable logical/arithmetic shifts. // +ve/-ve Amt = shift left/right. - if (Subtarget.hasXOP() && - (VT == MVT::v2i64 || VT == MVT::v4i32 || - VT == MVT::v8i16 || VT == MVT::v16i8)) { + if (Subtarget.hasXOP() && (VT == MVT::v2i64 || VT == MVT::v4i32 || + VT == MVT::v8i16 || VT == MVT::v16i8)) { if (Op.getOpcode() == ISD::SRL || Op.getOpcode() == ISD::SRA) { - SDValue Zero = getZeroVector(VT, Subtarget, DAG, dl); + SDValue Zero = DAG.getConstant(0, dl, VT); Amt = DAG.getNode(ISD::SUB, dl, VT, Zero, Amt); } if (Op.getOpcode() == ISD::SHL || Op.getOpcode() == ISD::SRL) @@ -22852,51 +23452,8 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget &Subtarget, return R; } - // If possible, lower this packed shift into a vector multiply instead of - // expanding it into a sequence of scalar shifts. - // Do this only if the vector shift count is a constant build_vector. 
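// [Editor's note -- illustrative sketch, not part of this patch.] For v4i32,
// convertShiftLeftToScale above builds 2^Amt without variable shifts: the amount is
// shifted into the float exponent field, the bias 0x3f800000 (the bit pattern of
// 1.0f) is added, and the result is converted from float back to integer. Scalar
// equivalent (the helper name is hypothetical; valid for 0 <= N <= 30 when
// converting to a signed 32-bit integer):
#include <cstdint>
#include <cstring>

static int32_t pow2ViaFloatExponent(uint32_t N) {
  uint32_t Bits = (N << 23) + 0x3f800000u;  // exponent = 127 + N, mantissa = 0
  float F;
  std::memcpy(&F, &Bits, sizeof(F));        // the bitcast to v4f32 in the DAG
  return int32_t(F);                        // FP_TO_SINT yields 1 << N
}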
- if (ConstantAmt && Op.getOpcode() == ISD::SHL && - (VT == MVT::v8i16 || VT == MVT::v4i32 || - (Subtarget.hasInt256() && VT == MVT::v16i16))) { - SmallVector<SDValue, 8> Elts; - MVT SVT = VT.getVectorElementType(); - unsigned SVTBits = SVT.getSizeInBits(); - APInt One(SVTBits, 1); - unsigned NumElems = VT.getVectorNumElements(); - - for (unsigned i=0; i !=NumElems; ++i) { - SDValue Op = Amt->getOperand(i); - if (Op->isUndef()) { - Elts.push_back(Op); - continue; - } - - ConstantSDNode *ND = cast<ConstantSDNode>(Op); - APInt C(SVTBits, ND->getAPIntValue().getZExtValue()); - uint64_t ShAmt = C.getZExtValue(); - if (ShAmt >= SVTBits) { - Elts.push_back(DAG.getUNDEF(SVT)); - continue; - } - Elts.push_back(DAG.getConstant(One.shl(ShAmt), dl, SVT)); - } - SDValue BV = DAG.getBuildVector(VT, dl, Elts); - return DAG.getNode(ISD::MUL, dl, VT, R, BV); - } - - // Lower SHL with variable shift amount. - if (VT == MVT::v4i32 && Op->getOpcode() == ISD::SHL) { - Op = DAG.getNode(ISD::SHL, dl, VT, Amt, DAG.getConstant(23, dl, VT)); - - Op = DAG.getNode(ISD::ADD, dl, VT, Op, - DAG.getConstant(0x3f800000U, dl, VT)); - Op = DAG.getBitcast(MVT::v4f32, Op); - Op = DAG.getNode(ISD::FP_TO_SINT, dl, VT, Op); - return DAG.getNode(ISD::MUL, dl, VT, Op, R); - } - // If possible, lower this shift as a sequence of two shifts by - // constant plus a MOVSS/MOVSD/PBLEND instead of scalarizing it. + // constant plus a BLENDing shuffle instead of scalarizing it. // Example: // (v4i32 (srl A, (build_vector < X, Y, Y, Y>))) // @@ -22904,67 +23461,54 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget &Subtarget, // (v4i32 (MOVSS (srl A, <Y,Y,Y,Y>), (srl A, <X,X,X,X>))) // // The advantage is that the two shifts from the example would be - // lowered as X86ISD::VSRLI nodes. This would be cheaper than scalarizing - // the vector shift into four scalar shifts plus four pairs of vector - // insert/extract. - if (ConstantAmt && (VT == MVT::v8i16 || VT == MVT::v4i32)) { - bool UseMOVSD = false; - bool CanBeSimplified; - // The splat value for the first packed shift (the 'X' from the example). - SDValue Amt1 = Amt->getOperand(0); - // The splat value for the second packed shift (the 'Y' from the example). - SDValue Amt2 = (VT == MVT::v4i32) ? Amt->getOperand(1) : Amt->getOperand(2); - - // See if it is possible to replace this node with a sequence of - // two shifts followed by a MOVSS/MOVSD/PBLEND. - if (VT == MVT::v4i32) { - // Check if it is legal to use a MOVSS. - CanBeSimplified = Amt2 == Amt->getOperand(2) && - Amt2 == Amt->getOperand(3); - if (!CanBeSimplified) { - // Otherwise, check if we can still simplify this node using a MOVSD. - CanBeSimplified = Amt1 == Amt->getOperand(1) && - Amt->getOperand(2) == Amt->getOperand(3); - UseMOVSD = true; - Amt2 = Amt->getOperand(2); + // lowered as X86ISD::VSRLI nodes in parallel before blending. + if (ConstantAmt && (VT == MVT::v8i16 || VT == MVT::v4i32 || + (VT == MVT::v16i16 && Subtarget.hasInt256()))) { + SDValue Amt1, Amt2; + unsigned NumElts = VT.getVectorNumElements(); + SmallVector<int, 8> ShuffleMask; + for (unsigned i = 0; i != NumElts; ++i) { + SDValue A = Amt->getOperand(i); + if (A.isUndef()) { + ShuffleMask.push_back(SM_SentinelUndef); + continue; } - } else { - // Do similar checks for the case where the machine value type - // is MVT::v8i16. 
- CanBeSimplified = Amt1 == Amt->getOperand(1); - for (unsigned i=3; i != 8 && CanBeSimplified; ++i) - CanBeSimplified = Amt2 == Amt->getOperand(i); - - if (!CanBeSimplified) { - UseMOVSD = true; - CanBeSimplified = true; - Amt2 = Amt->getOperand(4); - for (unsigned i=0; i != 4 && CanBeSimplified; ++i) - CanBeSimplified = Amt1 == Amt->getOperand(i); - for (unsigned j=4; j != 8 && CanBeSimplified; ++j) - CanBeSimplified = Amt2 == Amt->getOperand(j); + if (!Amt1 || Amt1 == A) { + ShuffleMask.push_back(i); + Amt1 = A; + continue; } + if (!Amt2 || Amt2 == A) { + ShuffleMask.push_back(i + NumElts); + Amt2 = A; + continue; + } + break; } - if (CanBeSimplified && isa<ConstantSDNode>(Amt1) && - isa<ConstantSDNode>(Amt2)) { - // Replace this node with two shifts followed by a MOVSS/MOVSD/PBLEND. + // Only perform this blend if we can perform it without loading a mask. + if (ShuffleMask.size() == NumElts && Amt1 && Amt2 && + isa<ConstantSDNode>(Amt1) && isa<ConstantSDNode>(Amt2) && + (VT != MVT::v16i16 || + is128BitLaneRepeatedShuffleMask(VT, ShuffleMask)) && + (VT == MVT::v4i32 || Subtarget.hasSSE41() || + Op.getOpcode() != ISD::SHL || canWidenShuffleElements(ShuffleMask))) { SDValue Splat1 = DAG.getConstant(cast<ConstantSDNode>(Amt1)->getAPIntValue(), dl, VT); SDValue Shift1 = DAG.getNode(Op->getOpcode(), dl, VT, R, Splat1); SDValue Splat2 = DAG.getConstant(cast<ConstantSDNode>(Amt2)->getAPIntValue(), dl, VT); SDValue Shift2 = DAG.getNode(Op->getOpcode(), dl, VT, R, Splat2); - SDValue BitCast1 = DAG.getBitcast(MVT::v4i32, Shift1); - SDValue BitCast2 = DAG.getBitcast(MVT::v4i32, Shift2); - if (UseMOVSD) - return DAG.getBitcast(VT, DAG.getVectorShuffle(MVT::v4i32, dl, BitCast1, - BitCast2, {0, 1, 6, 7})); - return DAG.getBitcast(VT, DAG.getVectorShuffle(MVT::v4i32, dl, BitCast1, - BitCast2, {0, 5, 6, 7})); + return DAG.getVectorShuffle(VT, dl, Shift1, Shift2, ShuffleMask); } } + // If possible, lower this packed shift into a vector multiply instead of + // expanding it into a sequence of scalar shifts. + if (Op.getOpcode() == ISD::SHL) + if (SDValue Scale = convertShiftLeftToScale(Amt, dl, Subtarget, DAG)) + return DAG.getNode(ISD::MUL, dl, VT, R, Scale); + // v4i32 Non Uniform Shifts. // If the shift amount is constant we can shift each lane using the SSE2 // immediate shifts, else we need to zero-extend each lane to the lower i64 @@ -22994,31 +23538,56 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget &Subtarget, break; } // The SSE2 shifts use the lower i64 as the same shift amount for - // all lanes and the upper i64 is ignored. These shuffle masks - // optimally zero-extend each lanes on SSE2/SSE41/AVX targets. - SDValue Z = getZeroVector(VT, Subtarget, DAG, dl); - Amt0 = DAG.getVectorShuffle(VT, dl, Amt, Z, {0, 4, -1, -1}); - Amt1 = DAG.getVectorShuffle(VT, dl, Amt, Z, {1, 5, -1, -1}); - Amt2 = DAG.getVectorShuffle(VT, dl, Amt, Z, {2, 6, -1, -1}); - Amt3 = DAG.getVectorShuffle(VT, dl, Amt, Z, {3, 7, -1, -1}); + // all lanes and the upper i64 is ignored. On AVX we're better off + // just zero-extending, but for SSE just duplicating the top 16-bits is + // cheaper and has the same effect for out of range values. 
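// [Editor's note -- illustrative sketch, not part of this patch.] The rewritten
// "two shifts plus a blend" path above handles amount vectors built from only two
// distinct constants: shift the whole vector by each constant, then blend lanes
// according to which amount each lane wanted. Scalar model of the v4i32 <X,Y,Y,Y>
// example from the comment, assuming X, Y < 32 (the helper name is hypothetical):
#include <cstdint>

static void srlTwoAmounts(uint32_t R[4], const uint32_t A[4], unsigned X,
                          unsigned Y) {
  uint32_t ByX[4], ByY[4];
  for (int I = 0; I != 4; ++I) {
    ByX[I] = A[I] >> X;            // VSRLI by X
    ByY[I] = A[I] >> Y;            // VSRLI by Y
  }
  R[0] = ByX[0];                   // blend mask {0, 5, 6, 7}:
  R[1] = ByY[1];                   // lane 0 from the X shift,
  R[2] = ByY[2];                   // lanes 1-3 from the Y shift
  R[3] = ByY[3];
}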
+ if (Subtarget.hasAVX()) { + SDValue Z = getZeroVector(VT, Subtarget, DAG, dl); + Amt0 = DAG.getVectorShuffle(VT, dl, Amt, Z, {0, 4, -1, -1}); + Amt1 = DAG.getVectorShuffle(VT, dl, Amt, Z, {1, 5, -1, -1}); + Amt2 = DAG.getVectorShuffle(VT, dl, Amt, Z, {2, 6, -1, -1}); + Amt3 = DAG.getVectorShuffle(VT, dl, Amt, Z, {3, 7, -1, -1}); + } else { + SDValue Amt01 = DAG.getBitcast(MVT::v8i16, Amt); + SDValue Amt23 = DAG.getVectorShuffle(MVT::v8i16, dl, Amt01, Amt01, + {4, 5, 6, 7, -1, -1, -1, -1}); + Amt0 = DAG.getVectorShuffle(MVT::v8i16, dl, Amt01, Amt01, + {0, 1, 1, 1, -1, -1, -1, -1}); + Amt1 = DAG.getVectorShuffle(MVT::v8i16, dl, Amt01, Amt01, + {2, 3, 3, 3, -1, -1, -1, -1}); + Amt2 = DAG.getVectorShuffle(MVT::v8i16, dl, Amt23, Amt23, + {0, 1, 1, 1, -1, -1, -1, -1}); + Amt3 = DAG.getVectorShuffle(MVT::v8i16, dl, Amt23, Amt23, + {2, 3, 3, 3, -1, -1, -1, -1}); + } } - SDValue R0 = DAG.getNode(Opc, dl, VT, R, Amt0); - SDValue R1 = DAG.getNode(Opc, dl, VT, R, Amt1); - SDValue R2 = DAG.getNode(Opc, dl, VT, R, Amt2); - SDValue R3 = DAG.getNode(Opc, dl, VT, R, Amt3); - SDValue R02 = DAG.getVectorShuffle(VT, dl, R0, R2, {0, -1, 6, -1}); - SDValue R13 = DAG.getVectorShuffle(VT, dl, R1, R3, {-1, 1, -1, 7}); - return DAG.getVectorShuffle(VT, dl, R02, R13, {0, 5, 2, 7}); + SDValue R0 = DAG.getNode(Opc, dl, VT, R, DAG.getBitcast(VT, Amt0)); + SDValue R1 = DAG.getNode(Opc, dl, VT, R, DAG.getBitcast(VT, Amt1)); + SDValue R2 = DAG.getNode(Opc, dl, VT, R, DAG.getBitcast(VT, Amt2)); + SDValue R3 = DAG.getNode(Opc, dl, VT, R, DAG.getBitcast(VT, Amt3)); + + // Merge the shifted lane results optimally with/without PBLENDW. + // TODO - ideally shuffle combining would handle this. + if (Subtarget.hasSSE41()) { + SDValue R02 = DAG.getVectorShuffle(VT, dl, R0, R2, {0, -1, 6, -1}); + SDValue R13 = DAG.getVectorShuffle(VT, dl, R1, R3, {-1, 1, -1, 7}); + return DAG.getVectorShuffle(VT, dl, R02, R13, {0, 5, 2, 7}); + } + SDValue R01 = DAG.getVectorShuffle(VT, dl, R0, R1, {0, -1, -1, 5}); + SDValue R23 = DAG.getVectorShuffle(VT, dl, R2, R3, {2, -1, -1, 7}); + return DAG.getVectorShuffle(VT, dl, R01, R23, {0, 3, 4, 7}); } // It's worth extending once and using the vXi16/vXi32 shifts for smaller // types, but without AVX512 the extra overheads to get from vXi8 to vXi32 // make the existing SSE solution better. + // NOTE: We honor prefered vector width before promoting to 512-bits. if ((Subtarget.hasInt256() && VT == MVT::v8i16) || - (Subtarget.hasAVX512() && VT == MVT::v16i16) || - (Subtarget.hasAVX512() && VT == MVT::v16i8) || - (Subtarget.hasBWI() && VT == MVT::v32i8)) { + (Subtarget.canExtendTo512DQ() && VT == MVT::v16i16) || + (Subtarget.canExtendTo512DQ() && VT == MVT::v16i8) || + (Subtarget.canExtendTo512BW() && VT == MVT::v32i8) || + (Subtarget.hasBWI() && Subtarget.hasVLX() && VT == MVT::v16i8)) { assert((!Subtarget.hasBWI() || VT == MVT::v32i8 || VT == MVT::v16i8) && "Unexpected vector type"); MVT EvtSVT = Subtarget.hasBWI() ? 
MVT::i16 : MVT::i32; @@ -23046,7 +23615,8 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget &Subtarget, V0 = DAG.getBitcast(VT, V0); V1 = DAG.getBitcast(VT, V1); Sel = DAG.getBitcast(VT, Sel); - Sel = DAG.getNode(X86ISD::CVT2MASK, dl, MaskVT, Sel); + Sel = DAG.getSetCC(dl, MaskVT, DAG.getConstant(0, dl, VT), Sel, + ISD::SETGT); return DAG.getBitcast(SelVT, DAG.getSelect(dl, VT, Sel, V0, V1)); } else if (Subtarget.hasSSE41()) { // On SSE41 targets we make use of the fact that VSELECT lowers @@ -23242,13 +23812,15 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget &Subtarget, static SDValue LowerRotate(SDValue Op, const X86Subtarget &Subtarget, SelectionDAG &DAG) { MVT VT = Op.getSimpleValueType(); + assert(VT.isVector() && "Custom lowering only for vector rotates!"); + SDLoc DL(Op); SDValue R = Op.getOperand(0); SDValue Amt = Op.getOperand(1); unsigned Opcode = Op.getOpcode(); unsigned EltSizeInBits = VT.getScalarSizeInBits(); - if (Subtarget.hasAVX512()) { + if (Subtarget.hasAVX512() && 32 <= EltSizeInBits) { // Attempt to rotate by immediate. APInt UndefElts; SmallVector<APInt, 16> EltBits; @@ -23267,31 +23839,178 @@ static SDValue LowerRotate(SDValue Op, const X86Subtarget &Subtarget, return Op; } - assert(VT.isVector() && "Custom lowering only for vector rotates!"); - assert(Subtarget.hasXOP() && "XOP support required for vector rotates!"); assert((Opcode == ISD::ROTL) && "Only ROTL supported"); // XOP has 128-bit vector variable + immediate rotates. // +ve/-ve Amt = rotate left/right - just need to handle ISD::ROTL. + if (Subtarget.hasXOP()) { + // Split 256-bit integers. + if (VT.is256BitVector()) + return Lower256IntArith(Op, DAG); + assert(VT.is128BitVector() && "Only rotate 128-bit vectors!"); - // Split 256-bit integers. - if (VT.is256BitVector()) + // Attempt to rotate by immediate. + if (auto *BVAmt = dyn_cast<BuildVectorSDNode>(Amt)) { + if (auto *RotateConst = BVAmt->getConstantSplatNode()) { + uint64_t RotateAmt = RotateConst->getAPIntValue().getZExtValue(); + assert(RotateAmt < EltSizeInBits && "Rotation out of range"); + return DAG.getNode(X86ISD::VROTLI, DL, VT, R, + DAG.getConstant(RotateAmt, DL, MVT::i8)); + } + } + + // Use general rotate by variable (per-element). + return Op; + } + + // Split 256-bit integers on pre-AVX2 targets. + if (VT.is256BitVector() && !Subtarget.hasAVX2()) return Lower256IntArith(Op, DAG); - assert(VT.is128BitVector() && "Only rotate 128-bit vectors!"); + assert((VT == MVT::v4i32 || VT == MVT::v8i16 || VT == MVT::v16i8 || + ((VT == MVT::v8i32 || VT == MVT::v16i16 || VT == MVT::v32i8) && + Subtarget.hasAVX2())) && + "Only vXi32/vXi16/vXi8 vector rotates supported"); - // Attempt to rotate by immediate. + // Rotate by an uniform constant - expand back to shifts. + // TODO - legalizers should be able to handle this. if (auto *BVAmt = dyn_cast<BuildVectorSDNode>(Amt)) { if (auto *RotateConst = BVAmt->getConstantSplatNode()) { uint64_t RotateAmt = RotateConst->getAPIntValue().getZExtValue(); assert(RotateAmt < EltSizeInBits && "Rotation out of range"); - return DAG.getNode(X86ISD::VROTLI, DL, VT, R, - DAG.getConstant(RotateAmt, DL, MVT::i8)); + if (RotateAmt == 0) + return R; + + SDValue AmtR = DAG.getConstant(EltSizeInBits - RotateAmt, DL, VT); + SDValue SHL = DAG.getNode(ISD::SHL, DL, VT, R, Amt); + SDValue SRL = DAG.getNode(ISD::SRL, DL, VT, R, AmtR); + return DAG.getNode(ISD::OR, DL, VT, SHL, SRL); } } - // Use general rotate by variable (per-element). - return Op; + // Rotate by splat - expand back to shifts. 
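// [Editor's note -- illustrative sketch, not part of this patch.] The uniform and
// splat rotate cases above expand ROTL back to a shift pair,
// (R << Amt) | (R >> (EltSizeInBits - Amt)). A scalar C version additionally has
// to mask the right-shift amount to stay clear of undefined behaviour when the
// rotate amount is zero (the helper name is hypothetical):
#include <cstdint>

static uint32_t rotl32(uint32_t X, unsigned N) {
  N &= 31;
  return (X << N) | (X >> ((32u - N) & 31u));
}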
+ // TODO - legalizers should be able to handle this. + if ((EltSizeInBits >= 16 || Subtarget.hasBWI()) && + IsSplatValue(VT, Amt, DL, DAG, Subtarget, Opcode)) { + SDValue AmtR = DAG.getConstant(EltSizeInBits, DL, VT); + AmtR = DAG.getNode(ISD::SUB, DL, VT, AmtR, Amt); + SDValue SHL = DAG.getNode(ISD::SHL, DL, VT, R, Amt); + SDValue SRL = DAG.getNode(ISD::SRL, DL, VT, R, AmtR); + return DAG.getNode(ISD::OR, DL, VT, SHL, SRL); + } + + // v16i8/v32i8: Split rotation into rot4/rot2/rot1 stages and select by + // the amount bit. + if (EltSizeInBits == 8) { + if (Subtarget.hasBWI()) { + SDValue AmtR = DAG.getConstant(EltSizeInBits, DL, VT); + AmtR = DAG.getNode(ISD::SUB, DL, VT, AmtR, Amt); + SDValue SHL = DAG.getNode(ISD::SHL, DL, VT, R, Amt); + SDValue SRL = DAG.getNode(ISD::SRL, DL, VT, R, AmtR); + return DAG.getNode(ISD::OR, DL, VT, SHL, SRL); + } + + MVT ExtVT = MVT::getVectorVT(MVT::i16, VT.getVectorNumElements() / 2); + + auto SignBitSelect = [&](MVT SelVT, SDValue Sel, SDValue V0, SDValue V1) { + if (Subtarget.hasSSE41()) { + // On SSE41 targets we make use of the fact that VSELECT lowers + // to PBLENDVB which selects bytes based just on the sign bit. + V0 = DAG.getBitcast(VT, V0); + V1 = DAG.getBitcast(VT, V1); + Sel = DAG.getBitcast(VT, Sel); + return DAG.getBitcast(SelVT, DAG.getSelect(DL, VT, Sel, V0, V1)); + } + // On pre-SSE41 targets we test for the sign bit by comparing to + // zero - a negative value will set all bits of the lanes to true + // and VSELECT uses that in its OR(AND(V0,C),AND(V1,~C)) lowering. + SDValue Z = getZeroVector(SelVT, Subtarget, DAG, DL); + SDValue C = DAG.getNode(X86ISD::PCMPGT, DL, SelVT, Z, Sel); + return DAG.getSelect(DL, SelVT, C, V0, V1); + }; + + // Turn 'a' into a mask suitable for VSELECT: a = a << 5; + // We can safely do this using i16 shifts as we're only interested in + // the 3 lower bits of each byte. + Amt = DAG.getBitcast(ExtVT, Amt); + Amt = DAG.getNode(ISD::SHL, DL, ExtVT, Amt, DAG.getConstant(5, DL, ExtVT)); + Amt = DAG.getBitcast(VT, Amt); + + // r = VSELECT(r, rot(r, 4), a); + SDValue M; + M = DAG.getNode( + ISD::OR, DL, VT, + DAG.getNode(ISD::SHL, DL, VT, R, DAG.getConstant(4, DL, VT)), + DAG.getNode(ISD::SRL, DL, VT, R, DAG.getConstant(4, DL, VT))); + R = SignBitSelect(VT, Amt, M, R); + + // a += a + Amt = DAG.getNode(ISD::ADD, DL, VT, Amt, Amt); + + // r = VSELECT(r, rot(r, 2), a); + M = DAG.getNode( + ISD::OR, DL, VT, + DAG.getNode(ISD::SHL, DL, VT, R, DAG.getConstant(2, DL, VT)), + DAG.getNode(ISD::SRL, DL, VT, R, DAG.getConstant(6, DL, VT))); + R = SignBitSelect(VT, Amt, M, R); + + // a += a + Amt = DAG.getNode(ISD::ADD, DL, VT, Amt, Amt); + + // return VSELECT(r, rot(r, 1), a); + M = DAG.getNode( + ISD::OR, DL, VT, + DAG.getNode(ISD::SHL, DL, VT, R, DAG.getConstant(1, DL, VT)), + DAG.getNode(ISD::SRL, DL, VT, R, DAG.getConstant(7, DL, VT))); + return SignBitSelect(VT, Amt, M, R); + } + + bool ConstantAmt = ISD::isBuildVectorOfConstantSDNodes(Amt.getNode()); + bool LegalVarShifts = SupportedVectorVarShift(VT, Subtarget, ISD::SHL) && + SupportedVectorVarShift(VT, Subtarget, ISD::SRL); + + // Best to fallback for all supported variable shifts. + // AVX2 - best to fallback for non-constants as well. + // TODO - legalizers should be able to handle this. 
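// [Editor's note -- illustrative sketch, not part of this patch.] Without BWI, the
// v16i8/v32i8 rotate above is built from three fixed rotations (by 4, 2 and 1)
// that are conditionally applied by testing one bit of the amount at a time with
// the PBLENDVB/PCMPGT sign-bit select. Per-lane scalar model (the helper name is
// hypothetical; only the low 3 bits of the amount matter):
#include <cstdint>

static uint8_t rotl8Staged(uint8_t R, uint8_t Amt) {
  if (Amt & 4) R = uint8_t((R << 4) | (R >> 4));   // rot(R, 4)
  if (Amt & 2) R = uint8_t((R << 2) | (R >> 6));   // rot(R, 2)
  if (Amt & 1) R = uint8_t((R << 1) | (R >> 7));   // rot(R, 1)
  return R;
}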
+ if (LegalVarShifts || (Subtarget.hasAVX2() && !ConstantAmt)) { + SDValue AmtR = DAG.getConstant(EltSizeInBits, DL, VT); + AmtR = DAG.getNode(ISD::SUB, DL, VT, AmtR, Amt); + SDValue SHL = DAG.getNode(ISD::SHL, DL, VT, R, Amt); + SDValue SRL = DAG.getNode(ISD::SRL, DL, VT, R, AmtR); + return DAG.getNode(ISD::OR, DL, VT, SHL, SRL); + } + + // As with shifts, convert the rotation amount to a multiplication factor. + SDValue Scale = convertShiftLeftToScale(Amt, DL, Subtarget, DAG); + assert(Scale && "Failed to convert ROTL amount to scale"); + + // v8i16/v16i16: perform unsigned multiply hi/lo and OR the results. + if (EltSizeInBits == 16) { + SDValue Lo = DAG.getNode(ISD::MUL, DL, VT, R, Scale); + SDValue Hi = DAG.getNode(ISD::MULHU, DL, VT, R, Scale); + return DAG.getNode(ISD::OR, DL, VT, Lo, Hi); + } + + // v4i32: make use of the PMULUDQ instruction to multiply 2 lanes of v4i32 + // to v2i64 results at a time. The upper 32-bits contain the wrapped bits + // that can then be OR'd with the lower 32-bits. + assert(VT == MVT::v4i32 && "Only v4i32 vector rotate expected"); + static const int OddMask[] = {1, -1, 3, -1}; + SDValue R13 = DAG.getVectorShuffle(VT, DL, R, R, OddMask); + SDValue Scale13 = DAG.getVectorShuffle(VT, DL, Scale, Scale, OddMask); + + SDValue Res02 = DAG.getNode(X86ISD::PMULUDQ, DL, MVT::v2i64, + DAG.getBitcast(MVT::v2i64, R), + DAG.getBitcast(MVT::v2i64, Scale)); + SDValue Res13 = DAG.getNode(X86ISD::PMULUDQ, DL, MVT::v2i64, + DAG.getBitcast(MVT::v2i64, R13), + DAG.getBitcast(MVT::v2i64, Scale13)); + Res02 = DAG.getBitcast(VT, Res02); + Res13 = DAG.getBitcast(VT, Res13); + + return DAG.getNode(ISD::OR, DL, VT, + DAG.getVectorShuffle(VT, DL, Res02, Res13, {0, 4, 2, 6}), + DAG.getVectorShuffle(VT, DL, Res02, Res13, {1, 5, 3, 7})); } static SDValue LowerXALUO(SDValue Op, SelectionDAG &DAG) { @@ -23353,9 +24072,6 @@ static SDValue LowerXALUO(SDValue Op, SelectionDAG &DAG) { SDValue SetCC = getSETCC(X86::COND_O, SDValue(Sum.getNode(), 2), DL, DAG); - if (N->getValueType(1) == MVT::i1) - SetCC = DAG.getNode(ISD::TRUNCATE, DL, MVT::i1, SetCC); - return DAG.getNode(ISD::MERGE_VALUES, DL, N->getVTList(), Sum, SetCC); } } @@ -23366,9 +24082,6 @@ static SDValue LowerXALUO(SDValue Op, SelectionDAG &DAG) { SDValue SetCC = getSETCC(Cond, SDValue(Sum.getNode(), 1), DL, DAG); - if (N->getValueType(1) == MVT::i1) - SetCC = DAG.getNode(ISD::TRUNCATE, DL, MVT::i1, SetCC); - return DAG.getNode(ISD::MERGE_VALUES, DL, N->getVTList(), Sum, SetCC); } @@ -23572,11 +24285,68 @@ static SDValue LowerCMP_SWAP(SDValue Op, const X86Subtarget &Subtarget, return SDValue(); } +// Create MOVMSKB, taking into account whether we need to split for AVX1. 
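// [Editor's note -- illustrative sketch, not part of this patch.] For constant,
// non-uniform v4i32 amounts the rotate lowering above multiplies each lane by
// 2^Amt using PMULUDQ: the low 32 bits of the 64-bit product are R << Amt and the
// high 32 bits are the wrapped-around R >> (32 - Amt), so OR-ing the two halves
// gives the rotation. Per-lane scalar model (the helper name is hypothetical):
#include <cstdint>

static uint32_t rotlViaWideningMul(uint32_t R, unsigned Amt) {
  uint64_t Prod = uint64_t(R) * (uint64_t(1) << (Amt & 31));  // PMULUDQ by 2^Amt
  return uint32_t(Prod) | uint32_t(Prod >> 32);
}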
+static SDValue getPMOVMSKB(const SDLoc &DL, SDValue V, SelectionDAG &DAG, + const X86Subtarget &Subtarget) { + MVT InVT = V.getSimpleValueType(); + + if (InVT == MVT::v32i8 && !Subtarget.hasInt256()) { + SDValue Lo, Hi; + std::tie(Lo, Hi) = DAG.SplitVector(V, DL); + Lo = DAG.getNode(X86ISD::MOVMSK, DL, MVT::i32, Lo); + Hi = DAG.getNode(X86ISD::MOVMSK, DL, MVT::i32, Hi); + Hi = DAG.getNode(ISD::SHL, DL, MVT::i32, Hi, + DAG.getConstant(16, DL, MVT::i8)); + return DAG.getNode(ISD::OR, DL, MVT::i32, Lo, Hi); + } + + return DAG.getNode(X86ISD::MOVMSK, DL, MVT::i32, V); +} + static SDValue LowerBITCAST(SDValue Op, const X86Subtarget &Subtarget, SelectionDAG &DAG) { - MVT SrcVT = Op.getOperand(0).getSimpleValueType(); + SDValue Src = Op.getOperand(0); + MVT SrcVT = Src.getSimpleValueType(); MVT DstVT = Op.getSimpleValueType(); + // Legalize (v64i1 (bitcast i64 (X))) by splitting the i64, bitcasting each + // half to v32i1 and concatenating the result. + if (SrcVT == MVT::i64 && DstVT == MVT::v64i1) { + assert(!Subtarget.is64Bit() && "Expected 32-bit mode"); + assert(Subtarget.hasBWI() && "Expected BWI target"); + SDLoc dl(Op); + SDValue Lo = DAG.getNode(ISD::EXTRACT_ELEMENT, dl, MVT::i32, Src, + DAG.getIntPtrConstant(0, dl)); + Lo = DAG.getBitcast(MVT::v32i1, Lo); + SDValue Hi = DAG.getNode(ISD::EXTRACT_ELEMENT, dl, MVT::i32, Src, + DAG.getIntPtrConstant(1, dl)); + Hi = DAG.getBitcast(MVT::v32i1, Hi); + return DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v64i1, Lo, Hi); + } + + // Custom splitting for BWI types when AVX512F is available but BWI isn't. + if ((SrcVT == MVT::v32i16 || SrcVT == MVT::v64i8) && DstVT.isVector() && + DAG.getTargetLoweringInfo().isTypeLegal(DstVT)) { + SDLoc dl(Op); + SDValue Lo, Hi; + std::tie(Lo, Hi) = DAG.SplitVector(Op.getOperand(0), dl); + EVT CastVT = MVT::getVectorVT(DstVT.getVectorElementType(), + DstVT.getVectorNumElements() / 2); + Lo = DAG.getBitcast(CastVT, Lo); + Hi = DAG.getBitcast(CastVT, Hi); + return DAG.getNode(ISD::CONCAT_VECTORS, dl, DstVT, Lo, Hi); + } + + // Use MOVMSK for vector to scalar conversion to prevent scalarization. + if ((SrcVT == MVT::v16i1 || SrcVT == MVT::v32i1) && DstVT.isScalarInteger()) { + assert(!Subtarget.hasAVX512() && "Should use K-registers with AVX512"); + MVT SExtVT = SrcVT == MVT::v16i1 ? MVT::v16i8 : MVT::v32i8; + SDLoc DL(Op); + SDValue V = DAG.getSExtOrTrunc(Src, DL, SExtVT); + V = getPMOVMSKB(DL, V, DAG, Subtarget); + return DAG.getZExtOrTrunc(V, DL, DstVT); + } + if (SrcVT == MVT::v2i32 || SrcVT == MVT::v4i16 || SrcVT == MVT::v8i8 || SrcVT == MVT::i64) { assert(Subtarget.hasSSE2() && "Requires at least SSE2!"); @@ -23584,7 +24354,6 @@ static SDValue LowerBITCAST(SDValue Op, const X86Subtarget &Subtarget, // This conversion needs to be expanded. return SDValue(); - SDValue Op0 = Op->getOperand(0); SmallVector<SDValue, 16> Elts; SDLoc dl(Op); unsigned NumElts; @@ -23596,14 +24365,14 @@ static SDValue LowerBITCAST(SDValue Op, const X86Subtarget &Subtarget, // Widen the vector in input in the case of MVT::v2i32. // Example: from MVT::v2i32 to MVT::v4i32. 
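// [Editor's note -- illustrative sketch, not part of this patch.] getPMOVMSKB above
// emits MOVMSK directly, and for v32i8 without AVX2 it splits the vector and glues
// the two 16-bit results together as Lo | (Hi << 16). Intrinsic-level equivalent
// (the helper name is hypothetical):
#include <emmintrin.h>  // SSE2
#include <cstdint>

static uint32_t movemask32(__m128i Lo, __m128i Hi) {
  uint32_t LoBits = uint32_t(_mm_movemask_epi8(Lo));        // bits 0..15
  uint32_t HiBits = uint32_t(_mm_movemask_epi8(Hi)) << 16;  // bits 16..31
  return LoBits | HiBits;
}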
for (unsigned i = 0, e = NumElts; i != e; ++i) - Elts.push_back(DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, SVT, Op0, + Elts.push_back(DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, SVT, Src, DAG.getIntPtrConstant(i, dl))); } else { assert(SrcVT == MVT::i64 && !Subtarget.is64Bit() && "Unexpected source type in LowerBITCAST"); - Elts.push_back(DAG.getNode(ISD::EXTRACT_ELEMENT, dl, MVT::i32, Op0, + Elts.push_back(DAG.getNode(ISD::EXTRACT_ELEMENT, dl, MVT::i32, Src, DAG.getIntPtrConstant(0, dl))); - Elts.push_back(DAG.getNode(ISD::EXTRACT_ELEMENT, dl, MVT::i32, Op0, + Elts.push_back(DAG.getNode(ISD::EXTRACT_ELEMENT, dl, MVT::i32, Src, DAG.getIntPtrConstant(1, dl))); NumElts = 2; SVT = MVT::i32; @@ -23842,7 +24611,7 @@ static SDValue LowerVectorCTPOP(SDValue Op, const X86Subtarget &Subtarget, unsigned NumElems = VT.getVectorNumElements(); assert((VT.getVectorElementType() == MVT::i8 || VT.getVectorElementType() == MVT::i16) && "Unexpected type"); - if (NumElems <= 16) { + if (NumElems < 16 || (NumElems == 16 && Subtarget.canExtendTo512DQ())) { MVT NewVT = MVT::getVectorVT(MVT::i32, NumElems); Op = DAG.getNode(ISD::ZERO_EXTEND, DL, NewVT, Op0); Op = DAG.getNode(ISD::CTPOP, DL, NewVT, Op); @@ -24224,76 +24993,81 @@ static SDValue LowerMSCATTER(SDValue Op, const X86Subtarget &Subtarget, assert(VT.getScalarSizeInBits() >= 32 && "Unsupported scatter op"); SDLoc dl(Op); + SDValue Scale = N->getScale(); SDValue Index = N->getIndex(); SDValue Mask = N->getMask(); SDValue Chain = N->getChain(); SDValue BasePtr = N->getBasePtr(); - MVT MemVT = N->getMemoryVT().getSimpleVT(); + + if (VT == MVT::v2f32) { + assert(Mask.getValueType() == MVT::v2i1 && "Unexpected mask type"); + // If the index is v2i64 and we have VLX we can use xmm for data and index. + if (Index.getValueType() == MVT::v2i64 && Subtarget.hasVLX()) { + Src = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v4f32, Src, + DAG.getUNDEF(MVT::v2f32)); + SDVTList VTs = DAG.getVTList(MVT::v2i1, MVT::Other); + SDValue Ops[] = {Chain, Src, Mask, BasePtr, Index, Scale}; + SDValue NewScatter = DAG.getTargetMemSDNode<X86MaskedScatterSDNode>( + VTs, Ops, dl, N->getMemoryVT(), N->getMemOperand()); + DAG.ReplaceAllUsesWith(Op, SDValue(NewScatter.getNode(), 1)); + return SDValue(NewScatter.getNode(), 1); + } + return SDValue(); + } + + if (VT == MVT::v2i32) { + assert(Mask.getValueType() == MVT::v2i1 && "Unexpected mask type"); + Src = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v4i32, Src, + DAG.getUNDEF(MVT::v2i32)); + // If the index is v2i64 and we have VLX we can use xmm for data and index. + if (Index.getValueType() == MVT::v2i64 && Subtarget.hasVLX()) { + SDVTList VTs = DAG.getVTList(MVT::v2i1, MVT::Other); + SDValue Ops[] = {Chain, Src, Mask, BasePtr, Index, Scale}; + SDValue NewScatter = DAG.getTargetMemSDNode<X86MaskedScatterSDNode>( + VTs, Ops, dl, N->getMemoryVT(), N->getMemOperand()); + DAG.ReplaceAllUsesWith(Op, SDValue(NewScatter.getNode(), 1)); + return SDValue(NewScatter.getNode(), 1); + } + // Custom widen all the operands to avoid promotion. 
+ EVT NewIndexVT = EVT::getVectorVT( + *DAG.getContext(), Index.getValueType().getVectorElementType(), 4); + Index = DAG.getNode(ISD::CONCAT_VECTORS, dl, NewIndexVT, Index, + DAG.getUNDEF(Index.getValueType())); + Mask = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v4i1, Mask, + DAG.getConstant(0, dl, MVT::v2i1)); + SDValue Ops[] = {Chain, Src, Mask, BasePtr, Index, Scale}; + return DAG.getMaskedScatter(DAG.getVTList(MVT::Other), N->getMemoryVT(), dl, + Ops, N->getMemOperand()); + } + MVT IndexVT = Index.getSimpleValueType(); MVT MaskVT = Mask.getSimpleValueType(); - if (MemVT.getScalarSizeInBits() < VT.getScalarSizeInBits()) { - // The v2i32 value was promoted to v2i64. - // Now we "redo" the type legalizer's work and widen the original - // v2i32 value to v4i32. The original v2i32 is retrieved from v2i64 - // with a shuffle. - assert((MemVT == MVT::v2i32 && VT == MVT::v2i64) && - "Unexpected memory type"); - int ShuffleMask[] = {0, 2, -1, -1}; - Src = DAG.getVectorShuffle(MVT::v4i32, dl, DAG.getBitcast(MVT::v4i32, Src), - DAG.getUNDEF(MVT::v4i32), ShuffleMask); - // Now we have 4 elements instead of 2. - // Expand the index. - MVT NewIndexVT = MVT::getVectorVT(IndexVT.getScalarType(), 4); - Index = ExtendToType(Index, NewIndexVT, DAG); - - // Expand the mask with zeroes - // Mask may be <2 x i64> or <2 x i1> at this moment - assert((MaskVT == MVT::v2i1 || MaskVT == MVT::v2i64) && - "Unexpected mask type"); - MVT ExtMaskVT = MVT::getVectorVT(MaskVT.getScalarType(), 4); - Mask = ExtendToType(Mask, ExtMaskVT, DAG, true); - VT = MVT::v4i32; - } + // If the index is v2i32, we're being called by type legalization and we + // should just let the default handling take care of it. + if (IndexVT == MVT::v2i32) + return SDValue(); - unsigned NumElts = VT.getVectorNumElements(); + // If we don't have VLX and neither the passthru or index is 512-bits, we + // need to widen until one is. if (!Subtarget.hasVLX() && !VT.is512BitVector() && !Index.getSimpleValueType().is512BitVector()) { - // AVX512F supports only 512-bit vectors. Or data or index should - // be 512 bit wide. 
If now the both index and data are 256-bit, but - // the vector contains 8 elements, we just sign-extend the index - if (IndexVT == MVT::v8i32) - // Just extend index - Index = DAG.getNode(ISD::SIGN_EXTEND, dl, MVT::v8i64, Index); - else { - // The minimal number of elts in scatter is 8 - NumElts = 8; - // Index - MVT NewIndexVT = MVT::getVectorVT(IndexVT.getScalarType(), NumElts); - // Use original index here, do not modify the index twice - Index = ExtendToType(N->getIndex(), NewIndexVT, DAG); - if (IndexVT.getScalarType() == MVT::i32) - Index = DAG.getNode(ISD::SIGN_EXTEND, dl, MVT::v8i64, Index); - - // Mask - // At this point we have promoted mask operand - assert(MaskVT.getScalarSizeInBits() >= 32 && "unexpected mask type"); - MVT ExtMaskVT = MVT::getVectorVT(MaskVT.getScalarType(), NumElts); - // Use the original mask here, do not modify the mask twice - Mask = ExtendToType(N->getMask(), ExtMaskVT, DAG, true); - - // The value that should be stored - MVT NewVT = MVT::getVectorVT(VT.getScalarType(), NumElts); - Src = ExtendToType(Src, NewVT, DAG); - } - } - // If the mask is "wide" at this point - truncate it to i1 vector - MVT BitMaskVT = MVT::getVectorVT(MVT::i1, NumElts); - Mask = DAG.getNode(ISD::TRUNCATE, dl, BitMaskVT, Mask); - - // The mask is killed by scatter, add it to the values - SDVTList VTs = DAG.getVTList(BitMaskVT, MVT::Other); - SDValue Ops[] = {Chain, Src, Mask, BasePtr, Index}; + // Determine how much we need to widen by to get a 512-bit type. + unsigned Factor = std::min(512/VT.getSizeInBits(), + 512/IndexVT.getSizeInBits()); + unsigned NumElts = VT.getVectorNumElements() * Factor; + + VT = MVT::getVectorVT(VT.getVectorElementType(), NumElts); + IndexVT = MVT::getVectorVT(IndexVT.getVectorElementType(), NumElts); + MaskVT = MVT::getVectorVT(MVT::i1, NumElts); + + Src = ExtendToType(Src, VT, DAG); + Index = ExtendToType(Index, IndexVT, DAG); + Mask = ExtendToType(Mask, MaskVT, DAG, true); + } + + SDVTList VTs = DAG.getVTList(MaskVT, MVT::Other); + SDValue Ops[] = {Chain, Src, Mask, BasePtr, Index, Scale}; SDValue NewScatter = DAG.getTargetMemSDNode<X86MaskedScatterSDNode>( VTs, Ops, dl, N->getMemoryVT(), N->getMemOperand()); DAG.ReplaceAllUsesWith(Op, SDValue(NewScatter.getNode(), 1)); @@ -24315,11 +25089,6 @@ static SDValue LowerMLOAD(SDValue Op, const X86Subtarget &Subtarget, assert((!N->isExpandingLoad() || ScalarVT.getSizeInBits() >= 32) && "Expanding masked load is supported for 32 and 64-bit types only!"); - // 4x32, 4x64 and 2x64 vectors of non-expanding loads are legal regardless of - // VLX. These types for exp-loads are handled here. - if (!N->isExpandingLoad() && VT.getVectorNumElements() <= 4) - return Op; - assert(Subtarget.hasAVX512() && !Subtarget.hasVLX() && !VT.is512BitVector() && "Cannot lower masked load op."); @@ -24336,16 +25105,12 @@ static SDValue LowerMLOAD(SDValue Op, const X86Subtarget &Subtarget, Src0 = ExtendToType(Src0, WideDataVT, DAG); // Mask element has to be i1. 
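The masked load side is widened the same way: the i1 mask is padded with false lanes and the passthru with don't-care lanes, and only the original low lanes of the wide result are used afterwards. A plain-array sketch with made-up names and lane counts:

    #include <array>
    #include <cstdio>

    // Model of a masked load widened from 4 to 8 lanes; lanes 4..7 exist only
    // to make the type legal and their results are ignored.
    static std::array<int, 8> maskedLoad8(const int *Base,
                                          const std::array<bool, 8> &Mask,
                                          const std::array<int, 8> &Passthru) {
      std::array<int, 8> R{};
      for (int i = 0; i < 8; ++i)
        R[i] = Mask[i] ? Base[i] : Passthru[i];
      return R;
    }

    int main() {
      int Mem[8] = {1, 2, 3, 4, 5, 6, 7, 8};
      std::array<bool, 8> Mask = {true, false, true, true}; // padding lanes stay false
      std::array<int, 8> Passthru = {0, -1, 0, 0};          // padding lanes don't matter
      auto Wide = maskedLoad8(Mem, Mask, Passthru);
      std::printf("%d %d %d %d\n", Wide[0], Wide[1], Wide[2], Wide[3]); // 1 -1 3 4
      return 0;
    }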
- MVT MaskEltTy = Mask.getSimpleValueType().getScalarType(); - assert((MaskEltTy == MVT::i1 || VT.getVectorNumElements() <= 4) && - "We handle 4x32, 4x64 and 2x64 vectors only in this case"); + assert(Mask.getSimpleValueType().getScalarType() == MVT::i1 && + "Unexpected mask type"); - MVT WideMaskVT = MVT::getVectorVT(MaskEltTy, NumEltsInWideVec); + MVT WideMaskVT = MVT::getVectorVT(MVT::i1, NumEltsInWideVec); Mask = ExtendToType(Mask, WideMaskVT, DAG, true); - if (MaskEltTy != MVT::i1) - Mask = DAG.getNode(ISD::TRUNCATE, dl, - MVT::getVectorVT(MVT::i1, NumEltsInWideVec), Mask); SDValue NewLoad = DAG.getMaskedLoad(WideDataVT, dl, N->getChain(), N->getBasePtr(), Mask, Src0, N->getMemoryVT(), N->getMemOperand(), @@ -24374,10 +25139,6 @@ static SDValue LowerMSTORE(SDValue Op, const X86Subtarget &Subtarget, assert((!N->isCompressingStore() || ScalarVT.getSizeInBits() >= 32) && "Expanding masked load is supported for 32 and 64-bit types only!"); - // 4x32 and 2x64 vectors of non-compressing stores are legal regardless to VLX. - if (!N->isCompressingStore() && VT.getVectorNumElements() <= 4) - return Op; - assert(Subtarget.hasAVX512() && !Subtarget.hasVLX() && !VT.is512BitVector() && "Cannot lower masked store op."); @@ -24392,17 +25153,13 @@ static SDValue LowerMSTORE(SDValue Op, const X86Subtarget &Subtarget, MVT WideDataVT = MVT::getVectorVT(ScalarVT, NumEltsInWideVec); // Mask element has to be i1. - MVT MaskEltTy = Mask.getSimpleValueType().getScalarType(); - assert((MaskEltTy == MVT::i1 || VT.getVectorNumElements() <= 4) && - "We handle 4x32, 4x64 and 2x64 vectors only in this case"); + assert(Mask.getSimpleValueType().getScalarType() == MVT::i1 && + "Unexpected mask type"); - MVT WideMaskVT = MVT::getVectorVT(MaskEltTy, NumEltsInWideVec); + MVT WideMaskVT = MVT::getVectorVT(MVT::i1, NumEltsInWideVec); DataToStore = ExtendToType(DataToStore, WideDataVT, DAG); Mask = ExtendToType(Mask, WideMaskVT, DAG, true); - if (MaskEltTy != MVT::i1) - Mask = DAG.getNode(ISD::TRUNCATE, dl, - MVT::getVectorVT(MVT::i1, NumEltsInWideVec), Mask); return DAG.getMaskedStore(N->getChain(), dl, DataToStore, N->getBasePtr(), Mask, N->getMemoryVT(), N->getMemOperand(), N->isTruncatingStore(), N->isCompressingStore()); @@ -24422,63 +25179,40 @@ static SDValue LowerMGATHER(SDValue Op, const X86Subtarget &Subtarget, MVT IndexVT = Index.getSimpleValueType(); MVT MaskVT = Mask.getSimpleValueType(); - unsigned NumElts = VT.getVectorNumElements(); assert(VT.getScalarSizeInBits() >= 32 && "Unsupported gather op"); // If the index is v2i32, we're being called by type legalization. if (IndexVT == MVT::v2i32) return SDValue(); + // If we don't have VLX and neither the passthru or index is 512-bits, we + // need to widen until one is. + MVT OrigVT = VT; if (Subtarget.hasAVX512() && !Subtarget.hasVLX() && !VT.is512BitVector() && - !Index.getSimpleValueType().is512BitVector()) { - // AVX512F supports only 512-bit vectors. Or data or index should - // be 512 bit wide. 
If now the both index and data are 256-bit, but - // the vector contains 8 elements, we just sign-extend the index - if (NumElts == 8) { - Index = DAG.getNode(ISD::SIGN_EXTEND, dl, MVT::v8i64, Index); - SDValue Ops[] = { N->getChain(), Src0, Mask, N->getBasePtr(), Index }; - SDValue NewGather = DAG.getTargetMemSDNode<X86MaskedGatherSDNode>( - DAG.getVTList(VT, MaskVT, MVT::Other), Ops, dl, N->getMemoryVT(), - N->getMemOperand()); - return DAG.getMergeValues({NewGather, NewGather.getValue(2)}, dl); - } - - // Minimal number of elements in Gather - NumElts = 8; - // Index - MVT NewIndexVT = MVT::getVectorVT(IndexVT.getScalarType(), NumElts); - Index = ExtendToType(Index, NewIndexVT, DAG); - if (IndexVT.getScalarType() == MVT::i32) - Index = DAG.getNode(ISD::SIGN_EXTEND, dl, MVT::v8i64, Index); - - // Mask - MVT MaskBitVT = MVT::getVectorVT(MVT::i1, NumElts); - // At this point we have promoted mask operand - assert(MaskVT.getScalarSizeInBits() >= 32 && "unexpected mask type"); - MVT ExtMaskVT = MVT::getVectorVT(MaskVT.getScalarType(), NumElts); - Mask = ExtendToType(Mask, ExtMaskVT, DAG, true); - Mask = DAG.getNode(ISD::TRUNCATE, dl, MaskBitVT, Mask); - - // The pass-through value - MVT NewVT = MVT::getVectorVT(VT.getScalarType(), NumElts); - Src0 = ExtendToType(Src0, NewVT, DAG); - - SDValue Ops[] = { N->getChain(), Src0, Mask, N->getBasePtr(), Index }; - SDValue NewGather = DAG.getTargetMemSDNode<X86MaskedGatherSDNode>( - DAG.getVTList(NewVT, MaskBitVT, MVT::Other), Ops, dl, N->getMemoryVT(), - N->getMemOperand()); - SDValue Extract = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, VT, - NewGather.getValue(0), - DAG.getIntPtrConstant(0, dl)); - SDValue RetOps[] = {Extract, NewGather.getValue(2)}; - return DAG.getMergeValues(RetOps, dl); + !IndexVT.is512BitVector()) { + // Determine how much we need to widen by to get a 512-bit type. 
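The factor computed next is simply the smaller of how many copies of the data vector and how many copies of the index vector fit in 512 bits, so that after widening at least one of the two is exactly 512 bits wide. A minimal arithmetic sketch (widenFactor is an illustrative name, not an LLVM helper):

    #include <algorithm>
    #include <cstdio>

    static unsigned widenFactor(unsigned DataBits, unsigned IndexBits) {
      // Both operands are at most 256 bits when this path is taken.
      return std::min(512u / DataBits, 512u / IndexBits);
    }

    int main() {
      // v8f32 data with a v8i32 index (256 bits each): widen 2x to 16 lanes.
      std::printf("factor=%u lanes=%u\n", widenFactor(256, 256),
                  8 * widenFactor(256, 256));
      // v2f64 data with a v2i64 index (128 bits each): widen 4x to 8 lanes.
      std::printf("factor=%u lanes=%u\n", widenFactor(128, 128),
                  2 * widenFactor(128, 128));
      return 0;
    }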
+ unsigned Factor = std::min(512/VT.getSizeInBits(), + 512/IndexVT.getSizeInBits()); + + unsigned NumElts = VT.getVectorNumElements() * Factor; + + VT = MVT::getVectorVT(VT.getVectorElementType(), NumElts); + IndexVT = MVT::getVectorVT(IndexVT.getVectorElementType(), NumElts); + MaskVT = MVT::getVectorVT(MVT::i1, NumElts); + + Src0 = ExtendToType(Src0, VT, DAG); + Index = ExtendToType(Index, IndexVT, DAG); + Mask = ExtendToType(Mask, MaskVT, DAG, true); } - SDValue Ops[] = { N->getChain(), Src0, Mask, N->getBasePtr(), Index }; + SDValue Ops[] = { N->getChain(), Src0, Mask, N->getBasePtr(), Index, + N->getScale() }; SDValue NewGather = DAG.getTargetMemSDNode<X86MaskedGatherSDNode>( DAG.getVTList(VT, MaskVT, MVT::Other), Ops, dl, N->getMemoryVT(), N->getMemOperand()); - return DAG.getMergeValues({NewGather, NewGather.getValue(2)}, dl); + SDValue Extract = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, OrigVT, + NewGather, DAG.getIntPtrConstant(0, dl)); + return DAG.getMergeValues({Extract, NewGather.getValue(2)}, dl); } SDValue X86TargetLowering::LowerGC_TRANSITION_START(SDValue Op, @@ -24545,6 +25279,7 @@ SDValue X86TargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const { case ISD::EXTRACT_VECTOR_ELT: return LowerEXTRACT_VECTOR_ELT(Op, DAG); case ISD::INSERT_VECTOR_ELT: return LowerINSERT_VECTOR_ELT(Op, DAG); case ISD::INSERT_SUBVECTOR: return LowerINSERT_SUBVECTOR(Op, Subtarget,DAG); + case ISD::EXTRACT_SUBVECTOR: return LowerEXTRACT_SUBVECTOR(Op,Subtarget,DAG); case ISD::SCALAR_TO_VECTOR: return LowerSCALAR_TO_VECTOR(Op, Subtarget,DAG); case ISD::ConstantPool: return LowerConstantPool(Op, DAG); case ISD::GlobalAddress: return LowerGlobalAddress(Op, DAG); @@ -24566,7 +25301,8 @@ SDValue X86TargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const { case ISD::FP_TO_SINT: case ISD::FP_TO_UINT: return LowerFP_TO_INT(Op, DAG); case ISD::FP_EXTEND: return LowerFP_EXTEND(Op, DAG); - case ISD::LOAD: return LowerExtendedLoad(Op, Subtarget, DAG); + case ISD::LOAD: return LowerLoad(Op, Subtarget, DAG); + case ISD::STORE: return LowerStore(Op, Subtarget, DAG); case ISD::FABS: case ISD::FNEG: return LowerFABSorFNEG(Op, DAG); case ISD::FCOPYSIGN: return LowerFCOPYSIGN(Op, DAG); @@ -24635,7 +25371,6 @@ SDValue X86TargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const { case ISD::GC_TRANSITION_START: return LowerGC_TRANSITION_START(Op, DAG); case ISD::GC_TRANSITION_END: return LowerGC_TRANSITION_END(Op, DAG); - case ISD::STORE: return LowerTruncatingStore(Op, Subtarget, DAG); } } @@ -24676,19 +25411,13 @@ void X86TargetLowering::ReplaceNodeResults(SDNode *N, assert(Subtarget.hasSSE2() && "Requires at least SSE2!"); auto InVT = N->getValueType(0); - auto InVTSize = InVT.getSizeInBits(); - const unsigned RegSize = - (InVTSize > 128) ? ((InVTSize > 256) ? 
512 : 256) : 128; - assert((Subtarget.hasBWI() || RegSize < 512) && - "512-bit vector requires AVX512BW"); - assert((Subtarget.hasAVX2() || RegSize < 256) && - "256-bit vector requires AVX2"); - - auto ElemVT = InVT.getVectorElementType(); - auto RegVT = EVT::getVectorVT(*DAG.getContext(), ElemVT, - RegSize / ElemVT.getSizeInBits()); - assert(RegSize % InVT.getSizeInBits() == 0); - unsigned NumConcat = RegSize / InVT.getSizeInBits(); + assert(InVT.getSizeInBits() < 128); + assert(128 % InVT.getSizeInBits() == 0); + unsigned NumConcat = 128 / InVT.getSizeInBits(); + + EVT RegVT = EVT::getVectorVT(*DAG.getContext(), + InVT.getVectorElementType(), + NumConcat * InVT.getVectorNumElements()); SmallVector<SDValue, 16> Ops(NumConcat, DAG.getUNDEF(InVT)); Ops[0] = N->getOperand(0); @@ -24697,12 +25426,32 @@ void X86TargetLowering::ReplaceNodeResults(SDNode *N, SDValue InVec1 = DAG.getNode(ISD::CONCAT_VECTORS, dl, RegVT, Ops); SDValue Res = DAG.getNode(X86ISD::AVG, dl, RegVT, InVec0, InVec1); - if (!ExperimentalVectorWideningLegalization) + if (getTypeAction(*DAG.getContext(), InVT) != TypeWidenVector) Res = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, InVT, Res, DAG.getIntPtrConstant(0, dl)); Results.push_back(Res); return; } + case ISD::SETCC: { + // Widen v2i32 (setcc v2f32). This is really needed for AVX512VL when + // setCC result type is v2i1 because type legalzation will end up with + // a v4i1 setcc plus an extend. + assert(N->getValueType(0) == MVT::v2i32 && "Unexpected type"); + if (N->getOperand(0).getValueType() != MVT::v2f32) + return; + SDValue UNDEF = DAG.getUNDEF(MVT::v2f32); + SDValue LHS = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v4f32, + N->getOperand(0), UNDEF); + SDValue RHS = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v4f32, + N->getOperand(1), UNDEF); + SDValue Res = DAG.getNode(ISD::SETCC, dl, MVT::v4i32, LHS, RHS, + N->getOperand(2)); + if (getTypeAction(*DAG.getContext(), MVT::v2i32) != TypeWidenVector) + Res = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, MVT::v2i32, Res, + DAG.getIntPtrConstant(0, dl)); + Results.push_back(Res); + return; + } // We might have generated v2f32 FMIN/FMAX operations. Widen them to v4f32. case X86ISD::FMINC: case X86ISD::FMIN: @@ -24731,12 +25480,14 @@ void X86TargetLowering::ReplaceNodeResults(SDNode *N, case ISD::FP_TO_SINT: case ISD::FP_TO_UINT: { bool IsSigned = N->getOpcode() == ISD::FP_TO_SINT; + EVT VT = N->getValueType(0); + SDValue Src = N->getOperand(0); + EVT SrcVT = Src.getValueType(); - if (N->getValueType(0) == MVT::v2i32) { + if (VT == MVT::v2i32) { assert((IsSigned || Subtarget.hasAVX512()) && "Can only handle signed conversion without AVX512"); assert(Subtarget.hasSSE2() && "Requires at least SSE2!"); - SDValue Src = N->getOperand(0); if (Src.getValueType() == MVT::v2f64) { MVT ResVT = MVT::v4i32; unsigned Opc = IsSigned ? X86ISD::CVTTP2SI : X86ISD::CVTTP2UI; @@ -24749,20 +25500,21 @@ void X86TargetLowering::ReplaceNodeResults(SDNode *N, Src, DAG.getIntPtrConstant(0, dl)); } SDValue Res = DAG.getNode(Opc, dl, ResVT, Src); - ResVT = ExperimentalVectorWideningLegalization ? MVT::v4i32 - : MVT::v2i32; + bool WidenType = getTypeAction(*DAG.getContext(), + MVT::v2i32) == TypeWidenVector; + ResVT = WidenType ? 
MVT::v4i32 : MVT::v2i32; Res = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, ResVT, Res, DAG.getIntPtrConstant(0, dl)); Results.push_back(Res); return; } - if (Src.getValueType() == MVT::v2f32) { + if (SrcVT == MVT::v2f32) { SDValue Idx = DAG.getIntPtrConstant(0, dl); SDValue Res = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v4f32, Src, DAG.getUNDEF(MVT::v2f32)); Res = DAG.getNode(IsSigned ? ISD::FP_TO_SINT : ISD::FP_TO_UINT, dl, MVT::v4i32, Res); - if (!ExperimentalVectorWideningLegalization) + if (getTypeAction(*DAG.getContext(), MVT::v2i32) != TypeWidenVector) Res = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, MVT::v2i32, Res, Idx); Results.push_back(Res); return; @@ -24773,11 +25525,30 @@ void X86TargetLowering::ReplaceNodeResults(SDNode *N, return; } + if (Subtarget.hasDQI() && VT == MVT::i64 && + (SrcVT == MVT::f32 || SrcVT == MVT::f64)) { + assert(!Subtarget.is64Bit() && "i64 should be legal"); + unsigned NumElts = Subtarget.hasVLX() ? 4 : 8; + // Using a 256-bit input here to guarantee 128-bit input for f32 case. + // TODO: Use 128-bit vectors for f64 case? + // TODO: Use 128-bit vectors for f32 by using CVTTP2SI/CVTTP2UI. + MVT VecVT = MVT::getVectorVT(MVT::i64, NumElts); + MVT VecInVT = MVT::getVectorVT(SrcVT.getSimpleVT(), NumElts); + + SDValue ZeroIdx = DAG.getIntPtrConstant(0, dl); + SDValue Res = DAG.getNode(ISD::INSERT_VECTOR_ELT, dl, VecInVT, + DAG.getConstantFP(0.0, dl, VecInVT), Src, + ZeroIdx); + Res = DAG.getNode(N->getOpcode(), SDLoc(N), VecVT, Res); + Res = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, VT, Res, ZeroIdx); + Results.push_back(Res); + return; + } + std::pair<SDValue,SDValue> Vals = FP_TO_INTHelper(SDValue(N, 0), DAG, IsSigned, /*IsReplace=*/ true); SDValue FIST = Vals.first, StackSlot = Vals.second; if (FIST.getNode()) { - EVT VT = N->getValueType(0); // Return a load from the stack slot. if (StackSlot.getNode()) Results.push_back( @@ -24963,6 +25734,32 @@ void X86TargetLowering::ReplaceNodeResults(SDNode *N, EVT DstVT = N->getValueType(0); EVT SrcVT = N->getOperand(0).getValueType(); + // If this is a bitcast from a v64i1 k-register to a i64 on a 32-bit target + // we can split using the k-register rather than memory. + if (SrcVT == MVT::v64i1 && DstVT == MVT::i64 && Subtarget.hasBWI()) { + assert(!Subtarget.is64Bit() && "Expected 32-bit mode"); + SDValue Lo, Hi; + std::tie(Lo, Hi) = DAG.SplitVectorOperand(N, 0); + Lo = DAG.getBitcast(MVT::i32, Lo); + Hi = DAG.getBitcast(MVT::i32, Hi); + SDValue Res = DAG.getNode(ISD::BUILD_PAIR, dl, MVT::i64, Lo, Hi); + Results.push_back(Res); + return; + } + + // Custom splitting for BWI types when AVX512F is available but BWI isn't. + if ((DstVT == MVT::v32i16 || DstVT == MVT::v64i8) && + SrcVT.isVector() && isTypeLegal(SrcVT)) { + SDValue Lo, Hi; + std::tie(Lo, Hi) = DAG.SplitVectorOperand(N, 0); + MVT CastVT = (DstVT == MVT::v32i16) ? MVT::v16i16 : MVT::v32i8; + Lo = DAG.getBitcast(CastVT, Lo); + Hi = DAG.getBitcast(CastVT, Hi); + SDValue Res = DAG.getNode(ISD::CONCAT_VECTORS, dl, DstVT, Lo, Hi); + Results.push_back(Res); + return; + } + if (SrcVT != MVT::f64 || (DstVT != MVT::v2i32 && DstVT != MVT::v4i16 && DstVT != MVT::v8i8)) return; @@ -24974,7 +25771,7 @@ void X86TargetLowering::ReplaceNodeResults(SDNode *N, MVT::v2f64, N->getOperand(0)); SDValue ToVecInt = DAG.getBitcast(WiderVT, Expanded); - if (ExperimentalVectorWideningLegalization) { + if (getTypeAction(*DAG.getContext(), DstVT) == TypeWidenVector) { // If we are legalizing vectors by widening, we already have the desired // legal vector type, just return it. 
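Going back to the v64i1 to i64 bitcast hunk above: on a 32-bit target the mask is split into two 32-lane halves, each half is bitcast to an i32, and the halves are glued back together with BUILD_PAIR. The same recombination written with plain integers, as a sketch only (buildPair is a made-up helper):

    #include <cstdint>
    #include <cstdio>

    // BUILD_PAIR(Lo, Hi) denotes exactly this: Hi becomes the upper 32 bits.
    static uint64_t buildPair(uint32_t Lo, uint32_t Hi) {
      return (static_cast<uint64_t>(Hi) << 32) | Lo;
    }

    int main() {
      uint64_t Mask = 0x8000000000000001ULL;      // lanes 0 and 63 set
      uint32_t Lo = static_cast<uint32_t>(Mask);
      uint32_t Hi = static_cast<uint32_t>(Mask >> 32);
      std::printf("0x%016llx\n", (unsigned long long)buildPair(Lo, Hi));
      return 0;
    }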
Results.push_back(ToVecInt); @@ -25009,7 +25806,7 @@ void X86TargetLowering::ReplaceNodeResults(SDNode *N, Mask = DAG.getNode(ISD::SIGN_EXTEND, dl, MVT::v4i32, Mask); } SDValue Ops[] = { Gather->getChain(), Src0, Mask, Gather->getBasePtr(), - Index }; + Index, Gather->getScale() }; SDValue Res = DAG.getTargetMemSDNode<X86MaskedGatherSDNode>( DAG.getVTList(MVT::v4f32, Mask.getValueType(), MVT::Other), Ops, dl, Gather->getMemoryVT(), Gather->getMemOperand()); @@ -25036,12 +25833,12 @@ void X86TargetLowering::ReplaceNodeResults(SDNode *N, Mask = DAG.getNode(ISD::SIGN_EXTEND, dl, MVT::v4i32, Mask); } SDValue Ops[] = { Gather->getChain(), Src0, Mask, Gather->getBasePtr(), - Index }; + Index, Gather->getScale() }; SDValue Res = DAG.getTargetMemSDNode<X86MaskedGatherSDNode>( DAG.getVTList(MVT::v4i32, Mask.getValueType(), MVT::Other), Ops, dl, Gather->getMemoryVT(), Gather->getMemOperand()); SDValue Chain = Res.getValue(2); - if (!ExperimentalVectorWideningLegalization) + if (getTypeAction(*DAG.getContext(), VT) != TypeWidenVector) Res = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, MVT::v2i32, Res, DAG.getIntPtrConstant(0, dl)); Results.push_back(Res); @@ -25057,12 +25854,12 @@ void X86TargetLowering::ReplaceNodeResults(SDNode *N, Mask = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v4i1, Mask, DAG.getConstant(0, dl, MVT::v2i1)); SDValue Ops[] = { Gather->getChain(), Src0, Mask, Gather->getBasePtr(), - Index }; + Index, Gather->getScale() }; SDValue Res = DAG.getMaskedGather(DAG.getVTList(MVT::v4i32, MVT::Other), Gather->getMemoryVT(), dl, Ops, Gather->getMemOperand()); SDValue Chain = Res.getValue(1); - if (!ExperimentalVectorWideningLegalization) + if (getTypeAction(*DAG.getContext(), MVT::v2i32) != TypeWidenVector) Res = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, MVT::v2i32, Res, DAG.getIntPtrConstant(0, dl)); Results.push_back(Res); @@ -25101,7 +25898,6 @@ const char *X86TargetLowering::getTargetNodeName(unsigned Opcode) const { case X86ISD::COMI: return "X86ISD::COMI"; case X86ISD::UCOMI: return "X86ISD::UCOMI"; case X86ISD::CMPM: return "X86ISD::CMPM"; - case X86ISD::CMPMU: return "X86ISD::CMPMU"; case X86ISD::CMPM_RND: return "X86ISD::CMPM_RND"; case X86ISD::SETCC: return "X86ISD::SETCC"; case X86ISD::SETCC_CARRY: return "X86ISD::SETCC_CARRY"; @@ -25192,7 +25988,6 @@ const char *X86TargetLowering::getTargetNodeName(unsigned Opcode) const { case X86ISD::VFPROUND: return "X86ISD::VFPROUND"; case X86ISD::VFPROUND_RND: return "X86ISD::VFPROUND_RND"; case X86ISD::VFPROUNDS_RND: return "X86ISD::VFPROUNDS_RND"; - case X86ISD::CVT2MASK: return "X86ISD::CVT2MASK"; case X86ISD::VSHLDQ: return "X86ISD::VSHLDQ"; case X86ISD::VSRLDQ: return "X86ISD::VSRLDQ"; case X86ISD::VSHL: return "X86ISD::VSHL"; @@ -25208,8 +26003,6 @@ const char *X86TargetLowering::getTargetNodeName(unsigned Opcode) const { case X86ISD::CMPP: return "X86ISD::CMPP"; case X86ISD::PCMPEQ: return "X86ISD::PCMPEQ"; case X86ISD::PCMPGT: return "X86ISD::PCMPGT"; - case X86ISD::PCMPEQM: return "X86ISD::PCMPEQM"; - case X86ISD::PCMPGTM: return "X86ISD::PCMPGTM"; case X86ISD::PHMINPOS: return "X86ISD::PHMINPOS"; case X86ISD::ADD: return "X86ISD::ADD"; case X86ISD::SUB: return "X86ISD::SUB"; @@ -25226,14 +26019,14 @@ const char *X86TargetLowering::getTargetNodeName(unsigned Opcode) const { case X86ISD::OR: return "X86ISD::OR"; case X86ISD::XOR: return "X86ISD::XOR"; case X86ISD::AND: return "X86ISD::AND"; + case X86ISD::BEXTR: return "X86ISD::BEXTR"; case X86ISD::MUL_IMM: return "X86ISD::MUL_IMM"; case X86ISD::MOVMSK: return "X86ISD::MOVMSK"; case 
X86ISD::PTEST: return "X86ISD::PTEST"; case X86ISD::TESTP: return "X86ISD::TESTP"; - case X86ISD::TESTM: return "X86ISD::TESTM"; - case X86ISD::TESTNM: return "X86ISD::TESTNM"; case X86ISD::KORTEST: return "X86ISD::KORTEST"; case X86ISD::KTEST: return "X86ISD::KTEST"; + case X86ISD::KADD: return "X86ISD::KADD"; case X86ISD::KSHIFTL: return "X86ISD::KSHIFTL"; case X86ISD::KSHIFTR: return "X86ISD::KSHIFTR"; case X86ISD::PACKSS: return "X86ISD::PACKSS"; @@ -25251,8 +26044,6 @@ const char *X86TargetLowering::getTargetNodeName(unsigned Opcode) const { case X86ISD::SHUF128: return "X86ISD::SHUF128"; case X86ISD::MOVLHPS: return "X86ISD::MOVLHPS"; case X86ISD::MOVHLPS: return "X86ISD::MOVHLPS"; - case X86ISD::MOVLPS: return "X86ISD::MOVLPS"; - case X86ISD::MOVLPD: return "X86ISD::MOVLPD"; case X86ISD::MOVDDUP: return "X86ISD::MOVDDUP"; case X86ISD::MOVSHDUP: return "X86ISD::MOVSHDUP"; case X86ISD::MOVSLDUP: return "X86ISD::MOVSLDUP"; @@ -25268,7 +26059,6 @@ const char *X86TargetLowering::getTargetNodeName(unsigned Opcode) const { case X86ISD::VPERM2X128: return "X86ISD::VPERM2X128"; case X86ISD::VPERMV: return "X86ISD::VPERMV"; case X86ISD::VPERMV3: return "X86ISD::VPERMV3"; - case X86ISD::VPERMIV3: return "X86ISD::VPERMIV3"; case X86ISD::VPERMI: return "X86ISD::VPERMI"; case X86ISD::VPTERNLOG: return "X86ISD::VPTERNLOG"; case X86ISD::VFIXUPIMM: return "X86ISD::VFIXUPIMM"; @@ -25308,26 +26098,6 @@ const char *X86TargetLowering::getTargetNodeName(unsigned Opcode) const { case X86ISD::FNMSUB_RND: return "X86ISD::FNMSUB_RND"; case X86ISD::FMADDSUB_RND: return "X86ISD::FMADDSUB_RND"; case X86ISD::FMSUBADD_RND: return "X86ISD::FMSUBADD_RND"; - case X86ISD::FMADDS1: return "X86ISD::FMADDS1"; - case X86ISD::FNMADDS1: return "X86ISD::FNMADDS1"; - case X86ISD::FMSUBS1: return "X86ISD::FMSUBS1"; - case X86ISD::FNMSUBS1: return "X86ISD::FNMSUBS1"; - case X86ISD::FMADDS1_RND: return "X86ISD::FMADDS1_RND"; - case X86ISD::FNMADDS1_RND: return "X86ISD::FNMADDS1_RND"; - case X86ISD::FMSUBS1_RND: return "X86ISD::FMSUBS1_RND"; - case X86ISD::FNMSUBS1_RND: return "X86ISD::FNMSUBS1_RND"; - case X86ISD::FMADDS3: return "X86ISD::FMADDS3"; - case X86ISD::FNMADDS3: return "X86ISD::FNMADDS3"; - case X86ISD::FMSUBS3: return "X86ISD::FMSUBS3"; - case X86ISD::FNMSUBS3: return "X86ISD::FNMSUBS3"; - case X86ISD::FMADDS3_RND: return "X86ISD::FMADDS3_RND"; - case X86ISD::FNMADDS3_RND: return "X86ISD::FNMADDS3_RND"; - case X86ISD::FMSUBS3_RND: return "X86ISD::FMSUBS3_RND"; - case X86ISD::FNMSUBS3_RND: return "X86ISD::FNMSUBS3_RND"; - case X86ISD::FMADD4S: return "X86ISD::FMADD4S"; - case X86ISD::FNMADD4S: return "X86ISD::FNMADD4S"; - case X86ISD::FMSUB4S: return "X86ISD::FMSUB4S"; - case X86ISD::FNMSUB4S: return "X86ISD::FNMSUB4S"; case X86ISD::VPMADD52H: return "X86ISD::VPMADD52H"; case X86ISD::VPMADD52L: return "X86ISD::VPMADD52L"; case X86ISD::VRNDSCALE: return "X86ISD::VRNDSCALE"; @@ -25342,8 +26112,8 @@ const char *X86TargetLowering::getTargetNodeName(unsigned Opcode) const { case X86ISD::VGETMANT_RND: return "X86ISD::VGETMANT_RND"; case X86ISD::VGETMANTS: return "X86ISD::VGETMANTS"; case X86ISD::VGETMANTS_RND: return "X86ISD::VGETMANTS_RND"; - case X86ISD::PCMPESTRI: return "X86ISD::PCMPESTRI"; - case X86ISD::PCMPISTRI: return "X86ISD::PCMPISTRI"; + case X86ISD::PCMPESTR: return "X86ISD::PCMPESTR"; + case X86ISD::PCMPISTR: return "X86ISD::PCMPISTR"; case X86ISD::XTEST: return "X86ISD::XTEST"; case X86ISD::COMPRESS: return "X86ISD::COMPRESS"; case X86ISD::EXPAND: return "X86ISD::EXPAND"; @@ -25412,6 +26182,10 @@ const char 
*X86TargetLowering::getTargetNodeName(unsigned Opcode) const { case X86ISD::GF2P8MULB: return "X86ISD::GF2P8MULB"; case X86ISD::GF2P8AFFINEQB: return "X86ISD::GF2P8AFFINEQB"; case X86ISD::GF2P8AFFINEINVQB: return "X86ISD::GF2P8AFFINEINVQB"; + case X86ISD::NT_CALL: return "X86ISD::NT_CALL"; + case X86ISD::NT_BRIND: return "X86ISD::NT_BRIND"; + case X86ISD::UMWAIT: return "X86ISD::UMWAIT"; + case X86ISD::TPAUSE: return "X86ISD::TPAUSE"; } return nullptr; } @@ -25478,11 +26252,20 @@ bool X86TargetLowering::isVectorShiftByScalarCheap(Type *Ty) const { if (Bits == 8) return false; + // XOP has v16i8/v8i16/v4i32/v2i64 variable vector shifts. + if (Subtarget.hasXOP() && Ty->getPrimitiveSizeInBits() == 128 && + (Bits == 8 || Bits == 16 || Bits == 32 || Bits == 64)) + return false; + // AVX2 has vpsllv[dq] instructions (and other shifts) that make variable // shifts just as cheap as scalar ones. if (Subtarget.hasAVX2() && (Bits == 32 || Bits == 64)) return false; + // AVX512BW has shifts such as vpsllvw. + if (Subtarget.hasBWI() && Bits == 16) + return false; + // Otherwise, it's significantly cheaper to shift by a scalar amount than by a // fully general vector. return true; @@ -25561,7 +26344,15 @@ bool X86TargetLowering::isZExtFree(SDValue Val, EVT VT2) const { return false; } -bool X86TargetLowering::isVectorLoadExtDesirable(SDValue) const { return true; } +bool X86TargetLowering::isVectorLoadExtDesirable(SDValue ExtVal) const { + EVT SrcVT = ExtVal.getOperand(0).getValueType(); + + // There is no extending load for vXi1. + if (SrcVT.getScalarType() == MVT::i1) + return false; + + return true; +} bool X86TargetLowering::isFMAFasterThanFMulAndFAdd(EVT VT) const { @@ -25610,13 +26401,27 @@ bool X86TargetLowering::isShuffleMaskLegal(ArrayRef<int> M, EVT VT) const { return isTypeLegal(VT.getSimpleVT()); } -bool -X86TargetLowering::isVectorClearMaskLegal(const SmallVectorImpl<int> &Mask, - EVT VT) const { +bool X86TargetLowering::isVectorClearMaskLegal(ArrayRef<int> Mask, + EVT VT) const { + // Don't convert an 'and' into a shuffle that we don't directly support. + // vpblendw and vpshufb for 256-bit vectors are not available on AVX1. + if (!Subtarget.hasAVX2()) + if (VT == MVT::v32i8 || VT == MVT::v16i16) + return false; + // Just delegate to the generic legality, clear masks aren't special. return isShuffleMaskLegal(Mask, VT); } +bool X86TargetLowering::areJTsAllowed(const Function *Fn) const { + // If the subtarget is using retpolines, we need to not generate jump tables. + if (Subtarget.useRetpoline()) + return false; + + // Otherwise, fallback on the generic logic. + return TargetLowering::areJTsAllowed(Fn); +} + //===----------------------------------------------------------------------===// // X86 Scheduler Hooks //===----------------------------------------------------------------------===// @@ -25697,79 +26502,6 @@ static MachineBasicBlock *emitXBegin(MachineInstr &MI, MachineBasicBlock *MBB, return sinkMBB; } -// FIXME: When we get size specific XMM0 registers, i.e. XMM0_V16I8 -// or XMM0_V32I8 in AVX all of this code can be replaced with that -// in the .td file. 
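On the isVectorClearMaskLegal hunk above: a clear-mask AND, where every lane of the constant is all-ones or all-zeros, is equivalent to a shuffle that selects between the input and a zero vector, which is only worth forming when the target has a cheap lane-select shuffle (vpblendw/vpshufb need AVX2 for 256-bit types). A small check of that equivalence, illustrative only:

    #include <array>
    #include <cstdint>
    #include <cstdio>

    int main() {
      std::array<uint16_t, 4> V = {0x1111, 0x2222, 0x3333, 0x4444};
      std::array<uint16_t, 4> M = {0xffff, 0x0000, 0xffff, 0x0000}; // clear mask
      std::array<uint16_t, 4> ByAnd{}, ByShuffle{};
      for (int i = 0; i < 4; ++i) {
        ByAnd[i] = V[i] & M[i];
        ByShuffle[i] = M[i] ? V[i] : 0;   // lane select against a zero vector
      }
      std::printf("%d\n", ByAnd == ByShuffle);    // 1
      return 0;
    }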
-static MachineBasicBlock *emitPCMPSTRM(MachineInstr &MI, MachineBasicBlock *BB, - const TargetInstrInfo *TII) { - unsigned Opc; - switch (MI.getOpcode()) { - default: llvm_unreachable("illegal opcode!"); - case X86::PCMPISTRM128REG: Opc = X86::PCMPISTRM128rr; break; - case X86::VPCMPISTRM128REG: Opc = X86::VPCMPISTRM128rr; break; - case X86::PCMPISTRM128MEM: Opc = X86::PCMPISTRM128rm; break; - case X86::VPCMPISTRM128MEM: Opc = X86::VPCMPISTRM128rm; break; - case X86::PCMPESTRM128REG: Opc = X86::PCMPESTRM128rr; break; - case X86::VPCMPESTRM128REG: Opc = X86::VPCMPESTRM128rr; break; - case X86::PCMPESTRM128MEM: Opc = X86::PCMPESTRM128rm; break; - case X86::VPCMPESTRM128MEM: Opc = X86::VPCMPESTRM128rm; break; - } - - DebugLoc dl = MI.getDebugLoc(); - MachineInstrBuilder MIB = BuildMI(*BB, MI, dl, TII->get(Opc)); - - unsigned NumArgs = MI.getNumOperands(); - for (unsigned i = 1; i < NumArgs; ++i) { - MachineOperand &Op = MI.getOperand(i); - if (!(Op.isReg() && Op.isImplicit())) - MIB.add(Op); - } - if (MI.hasOneMemOperand()) - MIB->setMemRefs(MI.memoperands_begin(), MI.memoperands_end()); - - BuildMI(*BB, MI, dl, TII->get(TargetOpcode::COPY), MI.getOperand(0).getReg()) - .addReg(X86::XMM0); - - MI.eraseFromParent(); - return BB; -} - -// FIXME: Custom handling because TableGen doesn't support multiple implicit -// defs in an instruction pattern -static MachineBasicBlock *emitPCMPSTRI(MachineInstr &MI, MachineBasicBlock *BB, - const TargetInstrInfo *TII) { - unsigned Opc; - switch (MI.getOpcode()) { - default: llvm_unreachable("illegal opcode!"); - case X86::PCMPISTRIREG: Opc = X86::PCMPISTRIrr; break; - case X86::VPCMPISTRIREG: Opc = X86::VPCMPISTRIrr; break; - case X86::PCMPISTRIMEM: Opc = X86::PCMPISTRIrm; break; - case X86::VPCMPISTRIMEM: Opc = X86::VPCMPISTRIrm; break; - case X86::PCMPESTRIREG: Opc = X86::PCMPESTRIrr; break; - case X86::VPCMPESTRIREG: Opc = X86::VPCMPESTRIrr; break; - case X86::PCMPESTRIMEM: Opc = X86::PCMPESTRIrm; break; - case X86::VPCMPESTRIMEM: Opc = X86::VPCMPESTRIrm; break; - } - - DebugLoc dl = MI.getDebugLoc(); - MachineInstrBuilder MIB = BuildMI(*BB, MI, dl, TII->get(Opc)); - - unsigned NumArgs = MI.getNumOperands(); // remove the results - for (unsigned i = 1; i < NumArgs; ++i) { - MachineOperand &Op = MI.getOperand(i); - if (!(Op.isReg() && Op.isImplicit())) - MIB.add(Op); - } - if (MI.hasOneMemOperand()) - MIB->setMemRefs(MI.memoperands_begin(), MI.memoperands_end()); - - BuildMI(*BB, MI, dl, TII->get(TargetOpcode::COPY), MI.getOperand(0).getReg()) - .addReg(X86::ECX); - - MI.eraseFromParent(); - return BB; -} - static MachineBasicBlock *emitWRPKRU(MachineInstr &MI, MachineBasicBlock *BB, const X86Subtarget &Subtarget) { DebugLoc dl = MI.getDebugLoc(); @@ -26158,7 +26890,7 @@ MachineBasicBlock *X86TargetLowering::EmitVAStartSaveXMMRegsWithCustomInserter( !MI.getOperand(MI.getNumOperands() - 1).isReg() || MI.getOperand(MI.getNumOperands() - 1).getReg() == X86::EFLAGS) && "Expected last argument to be EFLAGS"); - unsigned MOVOpc = Subtarget.hasFp256() ? X86::VMOVAPSmr : X86::MOVAPSmr; + unsigned MOVOpc = Subtarget.hasAVX() ? X86::VMOVAPSmr : X86::MOVAPSmr; // In the XMM save block, save all the XMM argument registers. 
for (int i = 3, e = MI.getNumOperands() - 1; i != e; ++i) { int64_t Offset = (i - 3) * 16 + VarArgsFPOffset; @@ -26919,6 +27651,184 @@ X86TargetLowering::EmitLoweredTLSCall(MachineInstr &MI, return BB; } +static unsigned getOpcodeForRetpoline(unsigned RPOpc) { + switch (RPOpc) { + case X86::RETPOLINE_CALL32: + return X86::CALLpcrel32; + case X86::RETPOLINE_CALL64: + return X86::CALL64pcrel32; + case X86::RETPOLINE_TCRETURN32: + return X86::TCRETURNdi; + case X86::RETPOLINE_TCRETURN64: + return X86::TCRETURNdi64; + } + llvm_unreachable("not retpoline opcode"); +} + +static const char *getRetpolineSymbol(const X86Subtarget &Subtarget, + unsigned Reg) { + if (Subtarget.useRetpolineExternalThunk()) { + // When using an external thunk for retpolines, we pick names that match the + // names GCC happens to use as well. This helps simplify the implementation + // of the thunks for kernels where they have no easy ability to create + // aliases and are doing non-trivial configuration of the thunk's body. For + // example, the Linux kernel will do boot-time hot patching of the thunk + // bodies and cannot easily export aliases of these to loaded modules. + // + // Note that at any point in the future, we may need to change the semantics + // of how we implement retpolines and at that time will likely change the + // name of the called thunk. Essentially, there is no hard guarantee that + // LLVM will generate calls to specific thunks, we merely make a best-effort + // attempt to help out kernels and other systems where duplicating the + // thunks is costly. + switch (Reg) { + case X86::EAX: + assert(!Subtarget.is64Bit() && "Should not be using a 32-bit thunk!"); + return "__x86_indirect_thunk_eax"; + case X86::ECX: + assert(!Subtarget.is64Bit() && "Should not be using a 32-bit thunk!"); + return "__x86_indirect_thunk_ecx"; + case X86::EDX: + assert(!Subtarget.is64Bit() && "Should not be using a 32-bit thunk!"); + return "__x86_indirect_thunk_edx"; + case X86::EDI: + assert(!Subtarget.is64Bit() && "Should not be using a 32-bit thunk!"); + return "__x86_indirect_thunk_edi"; + case X86::R11: + assert(Subtarget.is64Bit() && "Should not be using a 64-bit thunk!"); + return "__x86_indirect_thunk_r11"; + } + llvm_unreachable("unexpected reg for retpoline"); + } + + // When targeting an internal COMDAT thunk use an LLVM-specific name. + switch (Reg) { + case X86::EAX: + assert(!Subtarget.is64Bit() && "Should not be using a 32-bit thunk!"); + return "__llvm_retpoline_eax"; + case X86::ECX: + assert(!Subtarget.is64Bit() && "Should not be using a 32-bit thunk!"); + return "__llvm_retpoline_ecx"; + case X86::EDX: + assert(!Subtarget.is64Bit() && "Should not be using a 32-bit thunk!"); + return "__llvm_retpoline_edx"; + case X86::EDI: + assert(!Subtarget.is64Bit() && "Should not be using a 32-bit thunk!"); + return "__llvm_retpoline_edi"; + case X86::R11: + assert(Subtarget.is64Bit() && "Should not be using a 64-bit thunk!"); + return "__llvm_retpoline_r11"; + } + llvm_unreachable("unexpected reg for retpoline"); +} + +MachineBasicBlock * +X86TargetLowering::EmitLoweredRetpoline(MachineInstr &MI, + MachineBasicBlock *BB) const { + // Copy the virtual register into the R11 physical register and + // call the retpoline thunk. + DebugLoc DL = MI.getDebugLoc(); + const X86InstrInfo *TII = Subtarget.getInstrInfo(); + unsigned CalleeVReg = MI.getOperand(0).getReg(); + unsigned Opc = getOpcodeForRetpoline(MI.getOpcode()); + + // Find an available scratch register to hold the callee. 
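The scratch register choice being described here can be modeled on its own: take the candidate list for the target, drop anything the call already uses, and pick the first survivor. A standalone sketch with made-up names (the real code works on MachineOperands and physical register numbers):

    #include <algorithm>
    #include <cstdio>
    #include <string>
    #include <vector>

    static std::string pickRetpolineScratch(bool Is64Bit,
                                            const std::vector<std::string> &Used) {
      std::vector<std::string> Candidates =
          Is64Bit ? std::vector<std::string>{"r11"}
                  : std::vector<std::string>{"eax", "ecx", "edx", "edi"};
      for (const std::string &Reg : Candidates)
        if (std::find(Used.begin(), Used.end(), Reg) == Used.end())
          return Reg;
      return "";   // nothing left; the real code reports a fatal error here
    }

    int main() {
      // A 32-bit call that already uses EAX and ECX gets EDX as its scratch.
      std::printf("%s\n", pickRetpolineScratch(false, {"eax", "ecx"}).c_str());
      return 0;
    }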
On 64-bit, we can + // just use R11, but we scan for uses anyway to ensure we don't generate + // incorrect code. On 32-bit, we use one of EAX, ECX, or EDX that isn't + // already a register use operand to the call to hold the callee. If none + // are available, use EDI instead. EDI is chosen because EBX is the PIC base + // register and ESI is the base pointer to realigned stack frames with VLAs. + SmallVector<unsigned, 3> AvailableRegs; + if (Subtarget.is64Bit()) + AvailableRegs.push_back(X86::R11); + else + AvailableRegs.append({X86::EAX, X86::ECX, X86::EDX, X86::EDI}); + + // Zero out any registers that are already used. + for (const auto &MO : MI.operands()) { + if (MO.isReg() && MO.isUse()) + for (unsigned &Reg : AvailableRegs) + if (Reg == MO.getReg()) + Reg = 0; + } + + // Choose the first remaining non-zero available register. + unsigned AvailableReg = 0; + for (unsigned MaybeReg : AvailableRegs) { + if (MaybeReg) { + AvailableReg = MaybeReg; + break; + } + } + if (!AvailableReg) + report_fatal_error("calling convention incompatible with retpoline, no " + "available registers"); + + const char *Symbol = getRetpolineSymbol(Subtarget, AvailableReg); + + BuildMI(*BB, MI, DL, TII->get(TargetOpcode::COPY), AvailableReg) + .addReg(CalleeVReg); + MI.getOperand(0).ChangeToES(Symbol); + MI.setDesc(TII->get(Opc)); + MachineInstrBuilder(*BB->getParent(), &MI) + .addReg(AvailableReg, RegState::Implicit | RegState::Kill); + return BB; +} + +/// SetJmp implies future control flow change upon calling the corresponding +/// LongJmp. +/// Instead of using the 'return' instruction, the long jump fixes the stack and +/// performs an indirect branch. To do so it uses the registers that were stored +/// in the jump buffer (when calling SetJmp). +/// In case the shadow stack is enabled we need to fix it as well, because some +/// return addresses will be skipped. +/// The function will save the SSP for future fixing in the function +/// emitLongJmpShadowStackFix. +/// \sa emitLongJmpShadowStackFix +/// \param [in] MI The temporary Machine Instruction for the builtin. +/// \param [in] MBB The Machine Basic Block that will be modified. +void X86TargetLowering::emitSetJmpShadowStackFix(MachineInstr &MI, + MachineBasicBlock *MBB) const { + DebugLoc DL = MI.getDebugLoc(); + MachineFunction *MF = MBB->getParent(); + const TargetInstrInfo *TII = Subtarget.getInstrInfo(); + MachineRegisterInfo &MRI = MF->getRegInfo(); + MachineInstrBuilder MIB; + + // Memory Reference. + MachineInstr::mmo_iterator MMOBegin = MI.memoperands_begin(); + MachineInstr::mmo_iterator MMOEnd = MI.memoperands_end(); + + // Initialize a register with zero. + MVT PVT = getPointerTy(MF->getDataLayout()); + const TargetRegisterClass *PtrRC = getRegClassFor(PVT); + unsigned ZReg = MRI.createVirtualRegister(PtrRC); + unsigned XorRROpc = (PVT == MVT::i64) ? X86::XOR64rr : X86::XOR32rr; + BuildMI(*MBB, MI, DL, TII->get(XorRROpc)) + .addDef(ZReg) + .addReg(ZReg, RegState::Undef) + .addReg(ZReg, RegState::Undef); + + // Read the current SSP Register value to the zeroed register. + unsigned SSPCopyReg = MRI.createVirtualRegister(PtrRC); + unsigned RdsspOpc = (PVT == MVT::i64) ? X86::RDSSPQ : X86::RDSSPD; + BuildMI(*MBB, MI, DL, TII->get(RdsspOpc), SSPCopyReg).addReg(ZReg); + + // Write the SSP register value to offset 3 in input memory buffer. + unsigned PtrStoreOpc = (PVT == MVT::i64) ? 
X86::MOV64mr : X86::MOV32mr; + MIB = BuildMI(*MBB, MI, DL, TII->get(PtrStoreOpc)); + const int64_t SSPOffset = 3 * PVT.getStoreSize(); + const unsigned MemOpndSlot = 1; + for (unsigned i = 0; i < X86::AddrNumOperands; ++i) { + if (i == X86::AddrDisp) + MIB.addDisp(MI.getOperand(MemOpndSlot + i), SSPOffset); + else + MIB.add(MI.getOperand(MemOpndSlot + i)); + } + MIB.addReg(SSPCopyReg); + MIB.setMemRefs(MMOBegin, MMOEnd); +} + MachineBasicBlock * X86TargetLowering::emitEHSjLjSetJmp(MachineInstr &MI, MachineBasicBlock *MBB) const { @@ -27028,6 +27938,11 @@ X86TargetLowering::emitEHSjLjSetJmp(MachineInstr &MI, else MIB.addMBB(restoreMBB); MIB.setMemRefs(MMOBegin, MMOEnd); + + if (MF->getMMI().getModule()->getModuleFlag("cf-protection-return")) { + emitSetJmpShadowStackFix(MI, thisMBB); + } + // Setup MIB = BuildMI(*thisMBB, MI, DL, TII->get(X86::EH_SjLj_Setup)) .addMBB(restoreMBB); @@ -27069,6 +27984,183 @@ X86TargetLowering::emitEHSjLjSetJmp(MachineInstr &MI, return sinkMBB; } +/// Fix the shadow stack using the previously saved SSP pointer. +/// \sa emitSetJmpShadowStackFix +/// \param [in] MI The temporary Machine Instruction for the builtin. +/// \param [in] MBB The Machine Basic Block that will be modified. +/// \return The sink MBB that will perform the future indirect branch. +MachineBasicBlock * +X86TargetLowering::emitLongJmpShadowStackFix(MachineInstr &MI, + MachineBasicBlock *MBB) const { + DebugLoc DL = MI.getDebugLoc(); + MachineFunction *MF = MBB->getParent(); + const TargetInstrInfo *TII = Subtarget.getInstrInfo(); + MachineRegisterInfo &MRI = MF->getRegInfo(); + + // Memory Reference + MachineInstr::mmo_iterator MMOBegin = MI.memoperands_begin(); + MachineInstr::mmo_iterator MMOEnd = MI.memoperands_end(); + + MVT PVT = getPointerTy(MF->getDataLayout()); + const TargetRegisterClass *PtrRC = getRegClassFor(PVT); + + // checkSspMBB: + // xor vreg1, vreg1 + // rdssp vreg1 + // test vreg1, vreg1 + // je sinkMBB # Jump if Shadow Stack is not supported + // fallMBB: + // mov buf+24/12(%rip), vreg2 + // sub vreg1, vreg2 + // jbe sinkMBB # No need to fix the Shadow Stack + // fixShadowMBB: + // shr 3/2, vreg2 + // incssp vreg2 # fix the SSP according to the lower 8 bits + // shr 8, vreg2 + // je sinkMBB + // fixShadowLoopPrepareMBB: + // shl vreg2 + // mov 128, vreg3 + // fixShadowLoopMBB: + // incssp vreg3 + // dec vreg2 + // jne fixShadowLoopMBB # Iterate until you finish fixing + // # the Shadow Stack + // sinkMBB: + + MachineFunction::iterator I = ++MBB->getIterator(); + const BasicBlock *BB = MBB->getBasicBlock(); + + MachineBasicBlock *checkSspMBB = MF->CreateMachineBasicBlock(BB); + MachineBasicBlock *fallMBB = MF->CreateMachineBasicBlock(BB); + MachineBasicBlock *fixShadowMBB = MF->CreateMachineBasicBlock(BB); + MachineBasicBlock *fixShadowLoopPrepareMBB = MF->CreateMachineBasicBlock(BB); + MachineBasicBlock *fixShadowLoopMBB = MF->CreateMachineBasicBlock(BB); + MachineBasicBlock *sinkMBB = MF->CreateMachineBasicBlock(BB); + MF->insert(I, checkSspMBB); + MF->insert(I, fallMBB); + MF->insert(I, fixShadowMBB); + MF->insert(I, fixShadowLoopPrepareMBB); + MF->insert(I, fixShadowLoopMBB); + MF->insert(I, sinkMBB); + + // Transfer the remainder of BB and its successor edges to sinkMBB. + sinkMBB->splice(sinkMBB->begin(), MBB, MachineBasicBlock::iterator(MI), + MBB->end()); + sinkMBB->transferSuccessorsAndUpdatePHIs(MBB); + + MBB->addSuccessor(checkSspMBB); + + // Initialize a register with zero. 
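Before the MIR that builds these blocks, it is worth checking the arithmetic from the block comment above: incssp only honours the low 8 bits of its operand, so the code pops N & 0xff entries once and then loops, popping 128 entries per iteration for 2 * (N >> 8) iterations, which adds back up to N. A plain C++ check of that identity (poppedEntries is an illustrative name):

    #include <cassert>
    #include <cstdio>
    #include <initializer_list>

    static unsigned poppedEntries(unsigned N) {
      unsigned Popped = N & 0xff;          // first incssp
      unsigned Counter = (N >> 8) << 1;    // shr 8, then the single shl
      for (unsigned i = 0; i < Counter; ++i)
        Popped += 128;                     // incssp with the constant 128
      return Popped;
    }

    int main() {
      for (unsigned N : {0u, 5u, 255u, 256u, 1000u, 70000u})
        assert(poppedEntries(N) == N);
      std::printf("shadow stack fix-up model ok\n");
      return 0;
    }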
+ unsigned ZReg = MRI.createVirtualRegister(PtrRC); + unsigned XorRROpc = (PVT == MVT::i64) ? X86::XOR64rr : X86::XOR32rr; + BuildMI(checkSspMBB, DL, TII->get(XorRROpc)) + .addDef(ZReg) + .addReg(ZReg, RegState::Undef) + .addReg(ZReg, RegState::Undef); + + // Read the current SSP Register value to the zeroed register. + unsigned SSPCopyReg = MRI.createVirtualRegister(PtrRC); + unsigned RdsspOpc = (PVT == MVT::i64) ? X86::RDSSPQ : X86::RDSSPD; + BuildMI(checkSspMBB, DL, TII->get(RdsspOpc), SSPCopyReg).addReg(ZReg); + + // Check whether the result of the SSP register is zero and jump directly + // to the sink. + unsigned TestRROpc = (PVT == MVT::i64) ? X86::TEST64rr : X86::TEST32rr; + BuildMI(checkSspMBB, DL, TII->get(TestRROpc)) + .addReg(SSPCopyReg) + .addReg(SSPCopyReg); + BuildMI(checkSspMBB, DL, TII->get(X86::JE_1)).addMBB(sinkMBB); + checkSspMBB->addSuccessor(sinkMBB); + checkSspMBB->addSuccessor(fallMBB); + + // Reload the previously saved SSP register value. + unsigned PrevSSPReg = MRI.createVirtualRegister(PtrRC); + unsigned PtrLoadOpc = (PVT == MVT::i64) ? X86::MOV64rm : X86::MOV32rm; + const int64_t SPPOffset = 3 * PVT.getStoreSize(); + MachineInstrBuilder MIB = + BuildMI(fallMBB, DL, TII->get(PtrLoadOpc), PrevSSPReg); + for (unsigned i = 0; i < X86::AddrNumOperands; ++i) { + if (i == X86::AddrDisp) + MIB.addDisp(MI.getOperand(i), SPPOffset); + else + MIB.add(MI.getOperand(i)); + } + MIB.setMemRefs(MMOBegin, MMOEnd); + + // Subtract the current SSP from the previous SSP. + unsigned SspSubReg = MRI.createVirtualRegister(PtrRC); + unsigned SubRROpc = (PVT == MVT::i64) ? X86::SUB64rr : X86::SUB32rr; + BuildMI(fallMBB, DL, TII->get(SubRROpc), SspSubReg) + .addReg(PrevSSPReg) + .addReg(SSPCopyReg); + + // Jump to sink in case PrevSSPReg <= SSPCopyReg. + BuildMI(fallMBB, DL, TII->get(X86::JBE_1)).addMBB(sinkMBB); + fallMBB->addSuccessor(sinkMBB); + fallMBB->addSuccessor(fixShadowMBB); + + // Shift right by 2/3 for 32/64 because incssp multiplies the argument by 4/8. + unsigned ShrRIOpc = (PVT == MVT::i64) ? X86::SHR64ri : X86::SHR32ri; + unsigned Offset = (PVT == MVT::i64) ? 3 : 2; + unsigned SspFirstShrReg = MRI.createVirtualRegister(PtrRC); + BuildMI(fixShadowMBB, DL, TII->get(ShrRIOpc), SspFirstShrReg) + .addReg(SspSubReg) + .addImm(Offset); + + // Increase SSP when looking only on the lower 8 bits of the delta. + unsigned IncsspOpc = (PVT == MVT::i64) ? X86::INCSSPQ : X86::INCSSPD; + BuildMI(fixShadowMBB, DL, TII->get(IncsspOpc)).addReg(SspFirstShrReg); + + // Reset the lower 8 bits. + unsigned SspSecondShrReg = MRI.createVirtualRegister(PtrRC); + BuildMI(fixShadowMBB, DL, TII->get(ShrRIOpc), SspSecondShrReg) + .addReg(SspFirstShrReg) + .addImm(8); + + // Jump if the result of the shift is zero. + BuildMI(fixShadowMBB, DL, TII->get(X86::JE_1)).addMBB(sinkMBB); + fixShadowMBB->addSuccessor(sinkMBB); + fixShadowMBB->addSuccessor(fixShadowLoopPrepareMBB); + + // Do a single shift left. + unsigned ShlR1Opc = (PVT == MVT::i64) ? X86::SHL64r1 : X86::SHL32r1; + unsigned SspAfterShlReg = MRI.createVirtualRegister(PtrRC); + BuildMI(fixShadowLoopPrepareMBB, DL, TII->get(ShlR1Opc), SspAfterShlReg) + .addReg(SspSecondShrReg); + + // Save the value 128 to a register (will be used next with incssp). + unsigned Value128InReg = MRI.createVirtualRegister(PtrRC); + unsigned MovRIOpc = (PVT == MVT::i64) ? 
X86::MOV64ri32 : X86::MOV32ri; + BuildMI(fixShadowLoopPrepareMBB, DL, TII->get(MovRIOpc), Value128InReg) + .addImm(128); + fixShadowLoopPrepareMBB->addSuccessor(fixShadowLoopMBB); + + // Since incssp only looks at the lower 8 bits, we might need to do several + // iterations of incssp until we finish fixing the shadow stack. + unsigned DecReg = MRI.createVirtualRegister(PtrRC); + unsigned CounterReg = MRI.createVirtualRegister(PtrRC); + BuildMI(fixShadowLoopMBB, DL, TII->get(X86::PHI), CounterReg) + .addReg(SspAfterShlReg) + .addMBB(fixShadowLoopPrepareMBB) + .addReg(DecReg) + .addMBB(fixShadowLoopMBB); + + // Every iteration we increase the SSP by 128. + BuildMI(fixShadowLoopMBB, DL, TII->get(IncsspOpc)).addReg(Value128InReg); + + // Every iteration we decrement the counter by 1. + unsigned DecROpc = (PVT == MVT::i64) ? X86::DEC64r : X86::DEC32r; + BuildMI(fixShadowLoopMBB, DL, TII->get(DecROpc), DecReg).addReg(CounterReg); + + // Jump if the counter is not zero yet. + BuildMI(fixShadowLoopMBB, DL, TII->get(X86::JNE_1)).addMBB(fixShadowLoopMBB); + fixShadowLoopMBB->addSuccessor(sinkMBB); + fixShadowLoopMBB->addSuccessor(fixShadowLoopMBB); + + return sinkMBB; +} + MachineBasicBlock * X86TargetLowering::emitEHSjLjLongJmp(MachineInstr &MI, MachineBasicBlock *MBB) const { @@ -27101,13 +28193,21 @@ X86TargetLowering::emitEHSjLjLongJmp(MachineInstr &MI, unsigned PtrLoadOpc = (PVT == MVT::i64) ? X86::MOV64rm : X86::MOV32rm; unsigned IJmpOpc = (PVT == MVT::i64) ? X86::JMP64r : X86::JMP32r; + MachineBasicBlock *thisMBB = MBB; + + // When CET and shadow stack is enabled, we need to fix the Shadow Stack. + if (MF->getMMI().getModule()->getModuleFlag("cf-protection-return")) { + thisMBB = emitLongJmpShadowStackFix(MI, thisMBB); + } + // Reload FP - MIB = BuildMI(*MBB, MI, DL, TII->get(PtrLoadOpc), FP); + MIB = BuildMI(*thisMBB, MI, DL, TII->get(PtrLoadOpc), FP); for (unsigned i = 0; i < X86::AddrNumOperands; ++i) MIB.add(MI.getOperand(i)); MIB.setMemRefs(MMOBegin, MMOEnd); + // Reload IP - MIB = BuildMI(*MBB, MI, DL, TII->get(PtrLoadOpc), Tmp); + MIB = BuildMI(*thisMBB, MI, DL, TII->get(PtrLoadOpc), Tmp); for (unsigned i = 0; i < X86::AddrNumOperands; ++i) { if (i == X86::AddrDisp) MIB.addDisp(MI.getOperand(i), LabelOffset); @@ -27115,8 +28215,9 @@ X86TargetLowering::emitEHSjLjLongJmp(MachineInstr &MI, MIB.add(MI.getOperand(i)); } MIB.setMemRefs(MMOBegin, MMOEnd); + // Reload SP - MIB = BuildMI(*MBB, MI, DL, TII->get(PtrLoadOpc), SP); + MIB = BuildMI(*thisMBB, MI, DL, TII->get(PtrLoadOpc), SP); for (unsigned i = 0; i < X86::AddrNumOperands; ++i) { if (i == X86::AddrDisp) MIB.addDisp(MI.getOperand(i), SPOffset); @@ -27124,11 +28225,12 @@ X86TargetLowering::emitEHSjLjLongJmp(MachineInstr &MI, MIB.add(MI.getOperand(i)); } MIB.setMemRefs(MMOBegin, MMOEnd); + // Jump - BuildMI(*MBB, MI, DL, TII->get(IJmpOpc)).addReg(Tmp); + BuildMI(*thisMBB, MI, DL, TII->get(IJmpOpc)).addReg(Tmp); MI.eraseFromParent(); - return MBB; + return thisMBB; } void X86TargetLowering::SetupEntryBlockForSjLj(MachineInstr &MI, @@ -27201,7 +28303,7 @@ X86TargetLowering::EmitSjLjDispatchBlock(MachineInstr &MI, MCSymbol *Sym = nullptr; for (const auto &MI : MBB) { - if (MI.isDebugValue()) + if (MI.isDebugInstr()) continue; assert(MI.isEHLabel() && "expected EH_LABEL"); @@ -27419,21 +28521,16 @@ X86TargetLowering::EmitInstrWithCustomInserter(MachineInstr &MI, switch (MI.getOpcode()) { default: llvm_unreachable("Unexpected instr type to insert"); - case X86::TAILJMPd64: - case X86::TAILJMPr64: - case X86::TAILJMPm64: - case 
X86::TAILJMPr64_REX: - case X86::TAILJMPm64_REX: - llvm_unreachable("TAILJMP64 would not be touched here."); - case X86::TCRETURNdi64: - case X86::TCRETURNri64: - case X86::TCRETURNmi64: - return BB; case X86::TLS_addr32: case X86::TLS_addr64: case X86::TLS_base_addr32: case X86::TLS_base_addr64: return EmitLoweredTLSAddr(MI, BB); + case X86::RETPOLINE_CALL32: + case X86::RETPOLINE_CALL64: + case X86::RETPOLINE_TCRETURN32: + case X86::RETPOLINE_TCRETURN64: + return EmitLoweredRetpoline(MI, BB); case X86::CATCHRET: return EmitLoweredCatchRet(MI, BB); case X86::CATCHPAD: @@ -27446,7 +28543,7 @@ X86TargetLowering::EmitInstrWithCustomInserter(MachineInstr &MI, return EmitLoweredTLSCall(MI, BB); case X86::CMOV_FR32: case X86::CMOV_FR64: - case X86::CMOV_FR128: + case X86::CMOV_F128: case X86::CMOV_GR8: case X86::CMOV_GR16: case X86::CMOV_GR32: @@ -27474,11 +28571,16 @@ X86TargetLowering::EmitInstrWithCustomInserter(MachineInstr &MI, MI.getOpcode() == X86::RDFLAGS32 ? X86::PUSHF32 : X86::PUSHF64; unsigned Pop = MI.getOpcode() == X86::RDFLAGS32 ? X86::POP32r : X86::POP64r; MachineInstr *Push = BuildMI(*BB, MI, DL, TII->get(PushF)); - // Permit reads of the FLAGS register without it being defined. + // Permit reads of the EFLAGS and DF registers without them being defined. // This intrinsic exists to read external processor state in flags, such as // the trap flag, interrupt flag, and direction flag, none of which are // modeled by the backend. + assert(Push->getOperand(2).getReg() == X86::EFLAGS && + "Unexpected register in operand!"); Push->getOperand(2).setIsUndef(); + assert(Push->getOperand(3).getReg() == X86::DF && + "Unexpected register in operand!"); + Push->getOperand(3).setIsUndef(); BuildMI(*BB, MI, DL, TII->get(Pop), MI.getOperand(0).getReg()); MI.eraseFromParent(); // The pseudo is gone now. @@ -27561,32 +28663,6 @@ X86TargetLowering::EmitInstrWithCustomInserter(MachineInstr &MI, MI.eraseFromParent(); // The pseudo instruction is gone now. return BB; } - // String/text processing lowering. - case X86::PCMPISTRM128REG: - case X86::VPCMPISTRM128REG: - case X86::PCMPISTRM128MEM: - case X86::VPCMPISTRM128MEM: - case X86::PCMPESTRM128REG: - case X86::VPCMPESTRM128REG: - case X86::PCMPESTRM128MEM: - case X86::VPCMPESTRM128MEM: - assert(Subtarget.hasSSE42() && - "Target must have SSE4.2 or AVX features enabled"); - return emitPCMPSTRM(MI, BB, Subtarget.getInstrInfo()); - - // String/text processing lowering. - case X86::PCMPISTRIREG: - case X86::VPCMPISTRIREG: - case X86::PCMPISTRIMEM: - case X86::VPCMPISTRIMEM: - case X86::PCMPESTRIREG: - case X86::VPCMPESTRIREG: - case X86::PCMPESTRIMEM: - case X86::VPCMPESTRIMEM: - assert(Subtarget.hasSSE42() && - "Target must have SSE4.2 or AVX features enabled"); - return emitPCMPSTRI(MI, BB, Subtarget.getInstrInfo()); - // Thread synchronization. case X86::MONITOR: return emitMonitor(MI, BB, Subtarget, X86::MONITORrrr); @@ -27633,8 +28709,10 @@ X86TargetLowering::EmitInstrWithCustomInserter(MachineInstr &MI, return emitPatchPoint(MI, BB); case TargetOpcode::PATCHABLE_EVENT_CALL: - // Do nothing here, handle in xray instrumentation pass. 
- return BB; + return emitXRayCustomEvent(MI, BB); + + case TargetOpcode::PATCHABLE_TYPED_EVENT_CALL: + return emitXRayTypedEvent(MI, BB); case X86::LCMPXCHG8B: { const X86RegisterInfo *TRI = Subtarget.getRegisterInfo(); @@ -27702,6 +28780,65 @@ X86TargetLowering::EmitInstrWithCustomInserter(MachineInstr &MI, // X86 Optimization Hooks //===----------------------------------------------------------------------===// +bool +X86TargetLowering::targetShrinkDemandedConstant(SDValue Op, + const APInt &Demanded, + TargetLoweringOpt &TLO) const { + // Only optimize Ands to prevent shrinking a constant that could be + // matched by movzx. + if (Op.getOpcode() != ISD::AND) + return false; + + EVT VT = Op.getValueType(); + + // Ignore vectors. + if (VT.isVector()) + return false; + + unsigned Size = VT.getSizeInBits(); + + // Make sure the RHS really is a constant. + ConstantSDNode *C = dyn_cast<ConstantSDNode>(Op.getOperand(1)); + if (!C) + return false; + + const APInt &Mask = C->getAPIntValue(); + + // Clear all non-demanded bits initially. + APInt ShrunkMask = Mask & Demanded; + + // Find the width of the shrunk mask. + unsigned Width = ShrunkMask.getActiveBits(); + + // If the mask is all 0s there's nothing to do here. + if (Width == 0) + return false; + + // Find the next power of 2 width, rounding up to a byte. + Width = PowerOf2Ceil(std::max(Width, 8U)); + // Truncate the width to size to handle illegal types. + Width = std::min(Width, Size); + + // Calculate a possible zero extend mask for this constant. + APInt ZeroExtendMask = APInt::getLowBitsSet(Size, Width); + + // If we aren't changing the mask, just return true to keep it and prevent + // the caller from optimizing. + if (ZeroExtendMask == Mask) + return true; + + // Make sure the new mask can be represented by a combination of mask bits + // and non-demanded bits. + if (!ZeroExtendMask.isSubsetOf(Mask | ~Demanded)) + return false; + + // Replace the constant with the zero extend mask. + SDLoc DL(Op); + SDValue NewC = TLO.DAG.getConstant(ZeroExtendMask, DL, VT); + SDValue NewOp = TLO.DAG.getNode(ISD::AND, DL, VT, Op.getOperand(0), NewC); + return TLO.CombineTo(Op, NewOp); +} + void X86TargetLowering::computeKnownBitsForTargetNode(const SDValue Op, KnownBits &Known, const APInt &DemandedElts, @@ -27763,6 +28900,19 @@ void X86TargetLowering::computeKnownBitsForTargetNode(const SDValue Op, } break; } + case X86ISD::PACKUS: { + // PACKUS is just a truncation if the upper half is zero. + // TODO: Add DemandedElts support. + KnownBits Known2; + DAG.computeKnownBits(Op.getOperand(0), Known, Depth + 1); + DAG.computeKnownBits(Op.getOperand(1), Known2, Depth + 1); + Known.One &= Known2.One; + Known.Zero &= Known2.Zero; + if (Known.countMinLeadingZeros() < BitWidth) + Known.resetAll(); + Known = Known.trunc(BitWidth); + break; + } case X86ISD::VZEXT: { // TODO: Add DemandedElts support. SDValue N0 = Op.getOperand(0); @@ -27801,6 +28951,57 @@ void X86TargetLowering::computeKnownBitsForTargetNode(const SDValue Op, Known.Zero.setBitsFrom(8); break; } + + // Handle target shuffles. + // TODO - use resolveTargetShuffleInputs once we can limit recursive depth. 
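Stepping back to the targetShrinkDemandedConstant hunk above: rather than shrinking the AND constant, it tries to grow it into a (1 << k) - 1 mask so isel can match a movzx. The same idea with plain integers, simplified (APInt, TargetLoweringOpt and the keep-as-is case are left out; growToZeroExtendMask is a made-up name):

    #include <cstdint>
    #include <cstdio>

    static bool growToZeroExtendMask(uint64_t Mask, uint64_t Demanded,
                                     unsigned SizeInBits, uint64_t &NewMask) {
      uint64_t Shrunk = Mask & Demanded;
      if (Shrunk == 0)
        return false;
      unsigned Width = 0;
      for (uint64_t V = Shrunk; V; V >>= 1)
        ++Width;                             // active bits of the shrunk mask
      unsigned Pow2 = 8;                     // round up to a byte ...
      while (Pow2 < Width)
        Pow2 *= 2;                           // ... and then to a power of two
      if (Pow2 > SizeInBits)
        Pow2 = SizeInBits;
      uint64_t ZExtMask = Pow2 >= 64 ? ~0ULL : (1ULL << Pow2) - 1;
      // The new mask may only differ from the old one in non-demanded bits.
      if ((ZExtMask & ~(Mask | ~Demanded)) != 0)
        return false;
      NewMask = ZExtMask;
      return true;
    }

    int main() {
      uint64_t NewMask = 0;
      // AND with 0x3ffe where only bits 1..11 are demanded can use 0xffff
      // instead, which matches a 16-bit zero extend.
      if (growToZeroExtendMask(0x3ffe, 0x0ffe, 32, NewMask))
        std::printf("0x%llx\n", (unsigned long long)NewMask);   // 0xffff
      return 0;
    }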
+ if (isTargetShuffle(Opc)) { + bool IsUnary; + SmallVector<int, 64> Mask; + SmallVector<SDValue, 2> Ops; + if (getTargetShuffleMask(Op.getNode(), VT.getSimpleVT(), true, Ops, Mask, + IsUnary)) { + unsigned NumOps = Ops.size(); + unsigned NumElts = VT.getVectorNumElements(); + if (Mask.size() == NumElts) { + SmallVector<APInt, 2> DemandedOps(NumOps, APInt(NumElts, 0)); + Known.Zero.setAllBits(); Known.One.setAllBits(); + for (unsigned i = 0; i != NumElts; ++i) { + if (!DemandedElts[i]) + continue; + int M = Mask[i]; + if (M == SM_SentinelUndef) { + // For UNDEF elements, we don't know anything about the common state + // of the shuffle result. + Known.resetAll(); + break; + } else if (M == SM_SentinelZero) { + Known.One.clearAllBits(); + continue; + } + assert(0 <= M && (unsigned)M < (NumOps * NumElts) && + "Shuffle index out of range"); + + unsigned OpIdx = (unsigned)M / NumElts; + unsigned EltIdx = (unsigned)M % NumElts; + if (Ops[OpIdx].getValueType() != VT) { + // TODO - handle target shuffle ops with different value types. + Known.resetAll(); + break; + } + DemandedOps[OpIdx].setBit(EltIdx); + } + // Known bits are the values that are shared by every demanded element. + for (unsigned i = 0; i != NumOps && !Known.isUnknown(); ++i) { + if (!DemandedOps[i]) + continue; + KnownBits Known2; + DAG.computeKnownBits(Ops[i], Known2, DemandedOps[i], Depth + 1); + Known.One &= Known2.One; + Known.Zero &= Known2.Zero; + } + } + } + } } unsigned X86TargetLowering::ComputeNumSignBitsForTargetNode( @@ -27917,12 +29118,21 @@ bool X86TargetLowering::isGAPlusOffset(SDNode *N, // TODO: Investigate sharing more of this with shuffle lowering. static bool matchUnaryVectorShuffle(MVT MaskVT, ArrayRef<int> Mask, bool AllowFloatDomain, bool AllowIntDomain, - SDValue &V1, SDLoc &DL, SelectionDAG &DAG, + SDValue &V1, const SDLoc &DL, + SelectionDAG &DAG, const X86Subtarget &Subtarget, unsigned &Shuffle, MVT &SrcVT, MVT &DstVT) { unsigned NumMaskElts = Mask.size(); unsigned MaskEltSize = MaskVT.getScalarSizeInBits(); + // Match against a VZEXT_MOVL vXi32 zero-extending instruction. + if (MaskEltSize == 32 && isUndefOrEqual(Mask[0], 0) && + isUndefOrZero(Mask[1]) && isUndefInRange(Mask, 2, NumMaskElts - 2)) { + Shuffle = X86ISD::VZEXT_MOVL; + SrcVT = DstVT = !Subtarget.hasSSE2() ? MVT::v4f32 : MaskVT; + return true; + } + // Match against a ZERO_EXTEND_VECTOR_INREG/VZEXT instruction. // TODO: Add 512-bit vector support (split AVX512F and AVX512BW). if (AllowIntDomain && ((MaskVT.is128BitVector() && Subtarget.hasSSE41()) || @@ -28165,7 +29375,7 @@ static bool matchUnaryPermuteVectorShuffle(MVT MaskVT, ArrayRef<int> Mask, // TODO: Investigate sharing more of this with shuffle lowering. static bool matchBinaryVectorShuffle(MVT MaskVT, ArrayRef<int> Mask, bool AllowFloatDomain, bool AllowIntDomain, - SDValue &V1, SDValue &V2, SDLoc &DL, + SDValue &V1, SDValue &V2, const SDLoc &DL, SelectionDAG &DAG, const X86Subtarget &Subtarget, unsigned &Shuffle, MVT &SrcVT, MVT &DstVT, @@ -28175,27 +29385,28 @@ static bool matchBinaryVectorShuffle(MVT MaskVT, ArrayRef<int> Mask, if (MaskVT.is128BitVector()) { if (isTargetShuffleEquivalent(Mask, {0, 0}) && AllowFloatDomain) { V2 = V1; - Shuffle = X86ISD::MOVLHPS; - SrcVT = DstVT = MVT::v4f32; + V1 = (SM_SentinelUndef == Mask[0] ? DAG.getUNDEF(MVT::v4f32) : V1); + Shuffle = Subtarget.hasSSE2() ? X86ISD::UNPCKL : X86ISD::MOVLHPS; + SrcVT = DstVT = Subtarget.hasSSE2() ? 
MVT::v2f64 : MVT::v4f32; return true; } if (isTargetShuffleEquivalent(Mask, {1, 1}) && AllowFloatDomain) { V2 = V1; - Shuffle = X86ISD::MOVHLPS; - SrcVT = DstVT = MVT::v4f32; + Shuffle = Subtarget.hasSSE2() ? X86ISD::UNPCKH : X86ISD::MOVHLPS; + SrcVT = DstVT = Subtarget.hasSSE2() ? MVT::v2f64 : MVT::v4f32; return true; } if (isTargetShuffleEquivalent(Mask, {0, 3}) && Subtarget.hasSSE2() && (AllowFloatDomain || !Subtarget.hasSSE41())) { std::swap(V1, V2); Shuffle = X86ISD::MOVSD; - SrcVT = DstVT = MaskVT; + SrcVT = DstVT = MVT::v2f64; return true; } if (isTargetShuffleEquivalent(Mask, {4, 1, 2, 3}) && (AllowFloatDomain || !Subtarget.hasSSE41())) { Shuffle = X86ISD::MOVSS; - SrcVT = DstVT = MaskVT; + SrcVT = DstVT = MVT::v4f32; return true; } } @@ -28228,15 +29439,11 @@ static bool matchBinaryVectorShuffle(MVT MaskVT, ArrayRef<int> Mask, return false; } -static bool matchBinaryPermuteVectorShuffle(MVT MaskVT, ArrayRef<int> Mask, - const APInt &Zeroable, - bool AllowFloatDomain, - bool AllowIntDomain, - SDValue &V1, SDValue &V2, SDLoc &DL, - SelectionDAG &DAG, - const X86Subtarget &Subtarget, - unsigned &Shuffle, MVT &ShuffleVT, - unsigned &PermuteImm) { +static bool matchBinaryPermuteVectorShuffle( + MVT MaskVT, ArrayRef<int> Mask, const APInt &Zeroable, + bool AllowFloatDomain, bool AllowIntDomain, SDValue &V1, SDValue &V2, + const SDLoc &DL, SelectionDAG &DAG, const X86Subtarget &Subtarget, + unsigned &Shuffle, MVT &ShuffleVT, unsigned &PermuteImm) { unsigned NumMaskElts = Mask.size(); unsigned EltSizeInBits = MaskVT.getScalarSizeInBits(); @@ -28385,7 +29592,7 @@ static bool matchBinaryPermuteVectorShuffle(MVT MaskVT, ArrayRef<int> Mask, return false; } -/// \brief Combine an arbitrary chain of shuffles into a single instruction if +/// Combine an arbitrary chain of shuffles into a single instruction if /// possible. /// /// This is the leaf of the recursive combine below. When we have found some @@ -28397,7 +29604,6 @@ static bool matchBinaryPermuteVectorShuffle(MVT MaskVT, ArrayRef<int> Mask, static SDValue combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root, ArrayRef<int> BaseMask, int Depth, bool HasVariableMask, SelectionDAG &DAG, - TargetLowering::DAGCombinerInfo &DCI, const X86Subtarget &Subtarget) { assert(!BaseMask.empty() && "Cannot combine an empty shuffle mask!"); assert((Inputs.size() == 1 || Inputs.size() == 2) && @@ -28430,6 +29636,7 @@ static SDValue combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root, unsigned NumRootElts = RootVT.getVectorNumElements(); unsigned BaseMaskEltSizeInBits = RootSizeInBits / NumBaseMaskElts; bool FloatDomain = VT1.isFloatingPoint() || VT2.isFloatingPoint() || + (RootVT.isFloatingPoint() && Depth >= 2) || (RootVT.is256BitVector() && !Subtarget.hasAVX2()); // Don't combine if we are a AVX512/EVEX target and the mask element size @@ -28458,11 +29665,9 @@ static SDValue combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root, PermMask |= ((BaseMask[1] < 0 ? 0x8 : (BaseMask[1] & 1)) << 4); Res = DAG.getBitcast(ShuffleVT, V1); - DCI.AddToWorklist(Res.getNode()); Res = DAG.getNode(X86ISD::VPERM2X128, DL, ShuffleVT, Res, DAG.getUNDEF(ShuffleVT), DAG.getConstant(PermMask, DL, MVT::i8)); - DCI.AddToWorklist(Res.getNode()); return DAG.getBitcast(RootVT, Res); } @@ -28520,16 +29725,15 @@ static SDValue combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root, } } + SDValue NewV1 = V1; // Save operand in case early exit happens. 
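// Illustrative sketch, not LLVM code: the shuffle shape accepted by the new
// VZEXT_MOVL match above -- lane 0 taken from the source (or undef), lane 1
// undef-or-zero, remaining lanes undef -- behaves like this scalar model of
// "move the low element and zero the rest".
#include <cstdint>

static void vzextMovlV4i32(const uint32_t Src[4], uint32_t Dst[4]) {
  Dst[0] = Src[0];            // keep the low element
  for (int i = 1; i < 4; ++i)
    Dst[i] = 0;               // zero (or don't-care) upper elements
}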
if (matchUnaryVectorShuffle(MaskVT, Mask, AllowFloatDomain, AllowIntDomain, - V1, DL, DAG, Subtarget, Shuffle, ShuffleSrcVT, - ShuffleVT) && + NewV1, DL, DAG, Subtarget, Shuffle, + ShuffleSrcVT, ShuffleVT) && (!IsEVEXShuffle || (NumRootElts == ShuffleVT.getVectorNumElements()))) { if (Depth == 1 && Root.getOpcode() == Shuffle) return SDValue(); // Nothing to do! - Res = DAG.getBitcast(ShuffleSrcVT, V1); - DCI.AddToWorklist(Res.getNode()); + Res = DAG.getBitcast(ShuffleSrcVT, NewV1); Res = DAG.getNode(Shuffle, DL, ShuffleVT, Res); - DCI.AddToWorklist(Res.getNode()); return DAG.getBitcast(RootVT, Res); } @@ -28540,43 +29744,38 @@ static SDValue combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root, if (Depth == 1 && Root.getOpcode() == Shuffle) return SDValue(); // Nothing to do! Res = DAG.getBitcast(ShuffleVT, V1); - DCI.AddToWorklist(Res.getNode()); Res = DAG.getNode(Shuffle, DL, ShuffleVT, Res, DAG.getConstant(PermuteImm, DL, MVT::i8)); - DCI.AddToWorklist(Res.getNode()); return DAG.getBitcast(RootVT, Res); } } + SDValue NewV1 = V1; // Save operands in case early exit happens. + SDValue NewV2 = V2; if (matchBinaryVectorShuffle(MaskVT, Mask, AllowFloatDomain, AllowIntDomain, - V1, V2, DL, DAG, Subtarget, Shuffle, + NewV1, NewV2, DL, DAG, Subtarget, Shuffle, ShuffleSrcVT, ShuffleVT, UnaryShuffle) && (!IsEVEXShuffle || (NumRootElts == ShuffleVT.getVectorNumElements()))) { if (Depth == 1 && Root.getOpcode() == Shuffle) return SDValue(); // Nothing to do! - V1 = DAG.getBitcast(ShuffleSrcVT, V1); - DCI.AddToWorklist(V1.getNode()); - V2 = DAG.getBitcast(ShuffleSrcVT, V2); - DCI.AddToWorklist(V2.getNode()); - Res = DAG.getNode(Shuffle, DL, ShuffleVT, V1, V2); - DCI.AddToWorklist(Res.getNode()); + NewV1 = DAG.getBitcast(ShuffleSrcVT, NewV1); + NewV2 = DAG.getBitcast(ShuffleSrcVT, NewV2); + Res = DAG.getNode(Shuffle, DL, ShuffleVT, NewV1, NewV2); return DAG.getBitcast(RootVT, Res); } - if (matchBinaryPermuteVectorShuffle(MaskVT, Mask, Zeroable, AllowFloatDomain, - AllowIntDomain, V1, V2, DL, DAG, - Subtarget, Shuffle, ShuffleVT, - PermuteImm) && + NewV1 = V1; // Save operands in case early exit happens. + NewV2 = V2; + if (matchBinaryPermuteVectorShuffle( + MaskVT, Mask, Zeroable, AllowFloatDomain, AllowIntDomain, NewV1, + NewV2, DL, DAG, Subtarget, Shuffle, ShuffleVT, PermuteImm) && (!IsEVEXShuffle || (NumRootElts == ShuffleVT.getVectorNumElements()))) { if (Depth == 1 && Root.getOpcode() == Shuffle) return SDValue(); // Nothing to do! - V1 = DAG.getBitcast(ShuffleVT, V1); - DCI.AddToWorklist(V1.getNode()); - V2 = DAG.getBitcast(ShuffleVT, V2); - DCI.AddToWorklist(V2.getNode()); - Res = DAG.getNode(Shuffle, DL, ShuffleVT, V1, V2, + NewV1 = DAG.getBitcast(ShuffleVT, NewV1); + NewV2 = DAG.getBitcast(ShuffleVT, NewV2); + Res = DAG.getNode(Shuffle, DL, ShuffleVT, NewV1, NewV2, DAG.getConstant(PermuteImm, DL, MVT::i8)); - DCI.AddToWorklist(Res.getNode()); return DAG.getBitcast(RootVT, Res); } @@ -28592,11 +29791,9 @@ static SDValue combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root, if (Depth == 1 && Root.getOpcode() == X86ISD::EXTRQI) return SDValue(); // Nothing to do! 
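// Illustrative sketch, not LLVM code: a simplified scalar model of the SSE4A
// EXTRQI form emitted here -- pull BitLen bits starting at BitIdx out of the
// low 64-bit lane and zero-extend them. Special encodings (length 0 meaning
// 64, out-of-range fields) are ignored in this sketch.
#include <cstdint>

static uint64_t extrqiModel(uint64_t SrcLowQword, unsigned BitIdx,
                            unsigned BitLen) {
  uint64_t Field = SrcLowQword >> BitIdx;
  if (BitLen < 64)
    Field &= (1ULL << BitLen) - 1;   // keep only the extracted field
  return Field;                      // remaining bits of the lane become zero
}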
V1 = DAG.getBitcast(IntMaskVT, V1); - DCI.AddToWorklist(V1.getNode()); Res = DAG.getNode(X86ISD::EXTRQI, DL, IntMaskVT, V1, DAG.getConstant(BitLen, DL, MVT::i8), DAG.getConstant(BitIdx, DL, MVT::i8)); - DCI.AddToWorklist(Res.getNode()); return DAG.getBitcast(RootVT, Res); } @@ -28604,13 +29801,10 @@ static SDValue combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root, if (Depth == 1 && Root.getOpcode() == X86ISD::INSERTQI) return SDValue(); // Nothing to do! V1 = DAG.getBitcast(IntMaskVT, V1); - DCI.AddToWorklist(V1.getNode()); V2 = DAG.getBitcast(IntMaskVT, V2); - DCI.AddToWorklist(V2.getNode()); Res = DAG.getNode(X86ISD::INSERTQI, DL, IntMaskVT, V1, V2, DAG.getConstant(BitLen, DL, MVT::i8), DAG.getConstant(BitIdx, DL, MVT::i8)); - DCI.AddToWorklist(Res.getNode()); return DAG.getBitcast(RootVT, Res); } } @@ -28640,11 +29834,8 @@ static SDValue combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root, (Subtarget.hasVBMI() && MaskVT == MVT::v64i8) || (Subtarget.hasVBMI() && Subtarget.hasVLX() && MaskVT == MVT::v32i8))) { SDValue VPermMask = getConstVector(Mask, IntMaskVT, DAG, DL, true); - DCI.AddToWorklist(VPermMask.getNode()); Res = DAG.getBitcast(MaskVT, V1); - DCI.AddToWorklist(Res.getNode()); Res = DAG.getNode(X86ISD::VPERMV, DL, MaskVT, VPermMask, Res); - DCI.AddToWorklist(Res.getNode()); return DAG.getBitcast(RootVT, Res); } @@ -28667,13 +29858,9 @@ static SDValue combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root, Mask[i] = NumMaskElts + i; SDValue VPermMask = getConstVector(Mask, IntMaskVT, DAG, DL, true); - DCI.AddToWorklist(VPermMask.getNode()); Res = DAG.getBitcast(MaskVT, V1); - DCI.AddToWorklist(Res.getNode()); SDValue Zero = getZeroVector(MaskVT, Subtarget, DAG, DL); - DCI.AddToWorklist(Zero.getNode()); Res = DAG.getNode(X86ISD::VPERMV3, DL, MaskVT, Res, VPermMask, Zero); - DCI.AddToWorklist(Res.getNode()); return DAG.getBitcast(RootVT, Res); } @@ -28690,13 +29877,9 @@ static SDValue combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root, (Subtarget.hasVBMI() && MaskVT == MVT::v64i8) || (Subtarget.hasVBMI() && Subtarget.hasVLX() && MaskVT == MVT::v32i8))) { SDValue VPermMask = getConstVector(Mask, IntMaskVT, DAG, DL, true); - DCI.AddToWorklist(VPermMask.getNode()); V1 = DAG.getBitcast(MaskVT, V1); - DCI.AddToWorklist(V1.getNode()); V2 = DAG.getBitcast(MaskVT, V2); - DCI.AddToWorklist(V2.getNode()); Res = DAG.getNode(X86ISD::VPERMV3, DL, MaskVT, V1, VPermMask, V2); - DCI.AddToWorklist(Res.getNode()); return DAG.getBitcast(RootVT, Res); } return SDValue(); @@ -28722,13 +29905,10 @@ static SDValue combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root, EltBits[i] = AllOnes; } SDValue BitMask = getConstVector(EltBits, UndefElts, MaskVT, DAG, DL); - DCI.AddToWorklist(BitMask.getNode()); Res = DAG.getBitcast(MaskVT, V1); - DCI.AddToWorklist(Res.getNode()); unsigned AndOpcode = FloatDomain ? 
unsigned(X86ISD::FAND) : unsigned(ISD::AND); Res = DAG.getNode(AndOpcode, DL, MaskVT, Res, BitMask); - DCI.AddToWorklist(Res.getNode()); return DAG.getBitcast(RootVT, Res); } @@ -28745,11 +29925,8 @@ static SDValue combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root, VPermIdx.push_back(Idx); } SDValue VPermMask = DAG.getBuildVector(IntMaskVT, DL, VPermIdx); - DCI.AddToWorklist(VPermMask.getNode()); Res = DAG.getBitcast(MaskVT, V1); - DCI.AddToWorklist(Res.getNode()); Res = DAG.getNode(X86ISD::VPERMILPV, DL, MaskVT, Res, VPermMask); - DCI.AddToWorklist(Res.getNode()); return DAG.getBitcast(RootVT, Res); } @@ -28781,14 +29958,10 @@ static SDValue combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root, VPerm2Idx.push_back(Index); } V1 = DAG.getBitcast(MaskVT, V1); - DCI.AddToWorklist(V1.getNode()); V2 = DAG.getBitcast(MaskVT, V2); - DCI.AddToWorklist(V2.getNode()); SDValue VPerm2MaskOp = getConstVector(VPerm2Idx, IntMaskVT, DAG, DL, true); - DCI.AddToWorklist(VPerm2MaskOp.getNode()); Res = DAG.getNode(X86ISD::VPERMIL2, DL, MaskVT, V1, V2, VPerm2MaskOp, DAG.getConstant(M2ZImm, DL, MVT::i8)); - DCI.AddToWorklist(Res.getNode()); return DAG.getBitcast(RootVT, Res); } @@ -28820,11 +29993,8 @@ static SDValue combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root, } MVT ByteVT = MVT::getVectorVT(MVT::i8, NumBytes); Res = DAG.getBitcast(ByteVT, V1); - DCI.AddToWorklist(Res.getNode()); SDValue PSHUFBMaskOp = DAG.getBuildVector(ByteVT, DL, PSHUFBMask); - DCI.AddToWorklist(PSHUFBMaskOp.getNode()); Res = DAG.getNode(X86ISD::PSHUFB, DL, ByteVT, Res, PSHUFBMaskOp); - DCI.AddToWorklist(Res.getNode()); return DAG.getBitcast(RootVT, Res); } @@ -28853,13 +30023,9 @@ static SDValue combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root, } MVT ByteVT = MVT::v16i8; V1 = DAG.getBitcast(ByteVT, V1); - DCI.AddToWorklist(V1.getNode()); V2 = DAG.getBitcast(ByteVT, V2); - DCI.AddToWorklist(V2.getNode()); SDValue VPPERMMaskOp = DAG.getBuildVector(ByteVT, DL, VPPERMMask); - DCI.AddToWorklist(VPPERMMaskOp.getNode()); Res = DAG.getNode(X86ISD::VPPERM, DL, ByteVT, V1, V2, VPPERMMaskOp); - DCI.AddToWorklist(Res.getNode()); return DAG.getBitcast(RootVT, Res); } @@ -28870,11 +30036,10 @@ static SDValue combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root, // Attempt to constant fold all of the constant source ops. // Returns true if the entire shuffle is folded to a constant. // TODO: Extend this to merge multiple constant Ops and update the mask. -static SDValue combineX86ShufflesConstants(const SmallVectorImpl<SDValue> &Ops, +static SDValue combineX86ShufflesConstants(ArrayRef<SDValue> Ops, ArrayRef<int> Mask, SDValue Root, bool HasVariableMask, SelectionDAG &DAG, - TargetLowering::DAGCombinerInfo &DCI, const X86Subtarget &Subtarget) { MVT VT = Root.getSimpleValueType(); @@ -28950,11 +30115,10 @@ static SDValue combineX86ShufflesConstants(const SmallVectorImpl<SDValue> &Ops, SDLoc DL(Root); SDValue CstOp = getConstVector(ConstantBitData, UndefElts, MaskVT, DAG, DL); - DCI.AddToWorklist(CstOp.getNode()); return DAG.getBitcast(VT, CstOp); } -/// \brief Fully generic combining of x86 shuffle instructions. +/// Fully generic combining of x86 shuffle instructions. /// /// This should be the last combine run over the x86 shuffle instructions. Once /// they have been fully optimized, this will recursively consider all chains @@ -28985,12 +30149,12 @@ static SDValue combineX86ShufflesConstants(const SmallVectorImpl<SDValue> &Ops, /// combining in this recursive walk. 
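// Illustrative sketch, not LLVM code: scalar semantics of the 128-bit PSHUFB
// that the variable-mask fallback above builds a constant selector for --
// each destination byte copies the source byte addressed by the selector's
// low four bits, and a selector with its top bit set (0x80) yields zero.
#include <cstdint>

static void pshufb128(const uint8_t Src[16], const uint8_t Sel[16],
                      uint8_t Dst[16]) {
  for (int i = 0; i < 16; ++i)
    Dst[i] = (Sel[i] & 0x80) ? 0 : Src[Sel[i] & 0x0F];
}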
static SDValue combineX86ShufflesRecursively( ArrayRef<SDValue> SrcOps, int SrcOpIndex, SDValue Root, - ArrayRef<int> RootMask, ArrayRef<const SDNode *> SrcNodes, int Depth, - bool HasVariableMask, SelectionDAG &DAG, - TargetLowering::DAGCombinerInfo &DCI, const X86Subtarget &Subtarget) { + ArrayRef<int> RootMask, ArrayRef<const SDNode *> SrcNodes, unsigned Depth, + bool HasVariableMask, SelectionDAG &DAG, const X86Subtarget &Subtarget) { // Bound the depth of our recursive combine because this is ultimately // quadratic in nature. - if (Depth > 8) + const unsigned MaxRecursionDepth = 8; + if (Depth > MaxRecursionDepth) return SDValue(); // Directly rip through bitcasts to find the underlying operand. @@ -29143,17 +30307,21 @@ static SDValue combineX86ShufflesRecursively( // See if we can recurse into each shuffle source op (if it's a target // shuffle). The source op should only be combined if it either has a // single use (i.e. current Op) or all its users have already been combined. - for (int i = 0, e = Ops.size(); i < e; ++i) - if (Ops[i].getNode()->hasOneUse() || - SDNode::areOnlyUsersOf(CombinedNodes, Ops[i].getNode())) - if (SDValue Res = combineX86ShufflesRecursively( - Ops, i, Root, Mask, CombinedNodes, Depth + 1, HasVariableMask, - DAG, DCI, Subtarget)) - return Res; + // Don't recurse if we already have more source ops than we can combine in + // the remaining recursion depth. + if (Ops.size() < (MaxRecursionDepth - Depth)) { + for (int i = 0, e = Ops.size(); i < e; ++i) + if (Ops[i].getNode()->hasOneUse() || + SDNode::areOnlyUsersOf(CombinedNodes, Ops[i].getNode())) + if (SDValue Res = combineX86ShufflesRecursively( + Ops, i, Root, Mask, CombinedNodes, Depth + 1, HasVariableMask, + DAG, Subtarget)) + return Res; + } // Attempt to constant fold all of the constant source ops. if (SDValue Cst = combineX86ShufflesConstants( - Ops, Mask, Root, HasVariableMask, DAG, DCI, Subtarget)) + Ops, Mask, Root, HasVariableMask, DAG, Subtarget)) return Cst; // We can only combine unary and binary shuffle mask cases. @@ -29179,10 +30347,10 @@ static SDValue combineX86ShufflesRecursively( // Finally, try to combine into a single shuffle instruction. return combineX86ShuffleChain(Ops, Root, Mask, Depth, HasVariableMask, DAG, - DCI, Subtarget); + Subtarget); } -/// \brief Get the PSHUF-style mask from PSHUF node. +/// Get the PSHUF-style mask from PSHUF node. /// /// This is a very minor wrapper around getTargetShuffleMask to easy forming v4 /// PSHUF-style masks that can be reused with such instructions. @@ -29225,7 +30393,7 @@ static SmallVector<int, 4> getPSHUFShuffleMask(SDValue N) { } } -/// \brief Search for a combinable shuffle across a chain ending in pshufd. +/// Search for a combinable shuffle across a chain ending in pshufd. /// /// We walk up the chain and look for a combinable shuffle, skipping over /// shuffles that we could hoist this shuffle's transformation past without @@ -29358,7 +30526,7 @@ combineRedundantDWordShuffle(SDValue N, MutableArrayRef<int> Mask, return V; } -/// \brief Search for a combinable shuffle across a chain ending in pshuflw or +/// Search for a combinable shuffle across a chain ending in pshuflw or /// pshufhw. /// /// We walk up the chain, skipping shuffles of the other half and looking @@ -29426,7 +30594,7 @@ static bool combineRedundantHalfShuffle(SDValue N, MutableArrayRef<int> Mask, return true; } -/// \brief Try to combine x86 target specific shuffles. +/// Try to combine x86 target specific shuffles. 
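// Illustrative sketch, not LLVM code: how the 4-lane PSHUF-style masks that
// getPSHUFShuffleMask above recovers map back onto the 8-bit shuffle
// immediate -- two bits per destination lane, lane i in bits [2*i+1 : 2*i].
#include <cstdint>

static uint8_t pshufImm8(const int Mask[4]) {
  uint8_t Imm = 0;
  for (int i = 0; i < 4; ++i)
    Imm |= static_cast<uint8_t>((Mask[i] & 0x3) << (2 * i));
  return Imm;   // e.g. {2, 3, 0, 1} -> 0x4E, the classic swap-halves immediate
}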
static SDValue combineTargetShuffle(SDValue N, SelectionDAG &DAG, TargetLowering::DAGCombinerInfo &DCI, const X86Subtarget &Subtarget) { @@ -29459,12 +30627,33 @@ static SDValue combineTargetShuffle(SDValue N, SelectionDAG &DAG, Hi = BC1.getOperand(Opcode == X86ISD::UNPCKH ? 1 : 0); } SDValue Horiz = DAG.getNode(Opcode0, DL, VT0, Lo, Hi); - DCI.AddToWorklist(Horiz.getNode()); return DAG.getBitcast(VT, Horiz); } } switch (Opcode) { + case X86ISD::VBROADCAST: { + // If broadcasting from another shuffle, attempt to simplify it. + // TODO - we really need a general SimplifyDemandedVectorElts mechanism. + SDValue Src = N.getOperand(0); + SDValue BC = peekThroughBitcasts(Src); + EVT SrcVT = Src.getValueType(); + EVT BCVT = BC.getValueType(); + if (isTargetShuffle(BC.getOpcode()) && + VT.getScalarSizeInBits() % BCVT.getScalarSizeInBits() == 0) { + unsigned Scale = VT.getScalarSizeInBits() / BCVT.getScalarSizeInBits(); + SmallVector<int, 16> DemandedMask(BCVT.getVectorNumElements(), + SM_SentinelUndef); + for (unsigned i = 0; i != Scale; ++i) + DemandedMask[i] = i; + if (SDValue Res = combineX86ShufflesRecursively( + {BC}, 0, BC, DemandedMask, {}, /*Depth*/ 1, + /*HasVarMask*/ false, DAG, Subtarget)) + return DAG.getNode(X86ISD::VBROADCAST, DL, VT, + DAG.getBitcast(SrcVT, Res)); + } + return SDValue(); + } case X86ISD::PSHUFD: case X86ISD::PSHUFLW: case X86ISD::PSHUFHW: @@ -29505,53 +30694,31 @@ static SDValue combineTargetShuffle(SDValue N, SelectionDAG &DAG, } return SDValue(); } - case X86ISD::BLENDI: { - SDValue V0 = N->getOperand(0); - SDValue V1 = N->getOperand(1); - assert(VT == V0.getSimpleValueType() && VT == V1.getSimpleValueType() && - "Unexpected input vector types"); - - // Canonicalize a v2f64 blend with a mask of 2 by swapping the vector - // operands and changing the mask to 1. This saves us a bunch of - // pattern-matching possibilities related to scalar math ops in SSE/AVX. - // x86InstrInfo knows how to commute this back after instruction selection - // if it would help register allocation. - - // TODO: If optimizing for size or a processor that doesn't suffer from - // partial register update stalls, this should be transformed into a MOVSD - // instruction because a MOVSD is 1-2 bytes smaller than a BLENDPD. - - if (VT == MVT::v2f64) - if (auto *Mask = dyn_cast<ConstantSDNode>(N->getOperand(2))) - if (Mask->getZExtValue() == 2 && !isShuffleFoldableLoad(V0)) { - SDValue NewMask = DAG.getConstant(1, DL, MVT::i8); - return DAG.getNode(X86ISD::BLENDI, DL, VT, V1, V0, NewMask); - } - - return SDValue(); - } case X86ISD::MOVSD: case X86ISD::MOVSS: { - SDValue V0 = peekThroughBitcasts(N->getOperand(0)); - SDValue V1 = peekThroughBitcasts(N->getOperand(1)); - bool isZero0 = ISD::isBuildVectorAllZeros(V0.getNode()); - bool isZero1 = ISD::isBuildVectorAllZeros(V1.getNode()); - if (isZero0 && isZero1) - return SDValue(); + SDValue N0 = N.getOperand(0); + SDValue N1 = N.getOperand(1); - // We often lower to MOVSD/MOVSS from integer as well as native float - // types; remove unnecessary domain-crossing bitcasts if we can to make it - // easier to combine shuffles later on. We've already accounted for the - // domain switching cost when we decided to lower with it. - bool isFloat = VT.isFloatingPoint(); - bool isFloat0 = V0.getSimpleValueType().isFloatingPoint(); - bool isFloat1 = V1.getSimpleValueType().isFloatingPoint(); - if ((isFloat != isFloat0 || isZero0) && (isFloat != isFloat1 || isZero1)) { - MVT NewVT = isFloat ? (X86ISD::MOVSD == Opcode ? 
MVT::v2i64 : MVT::v4i32) - : (X86ISD::MOVSD == Opcode ? MVT::v2f64 : MVT::v4f32); - V0 = DAG.getBitcast(NewVT, V0); - V1 = DAG.getBitcast(NewVT, V1); - return DAG.getBitcast(VT, DAG.getNode(Opcode, DL, NewVT, V0, V1)); + // Canonicalize scalar FPOps: + // MOVS*(N0, OP(N0, N1)) --> MOVS*(N0, SCALAR_TO_VECTOR(OP(N0[0], N1[0]))) + // If commutable, allow OP(N1[0], N0[0]). + unsigned Opcode1 = N1.getOpcode(); + if (Opcode1 == ISD::FADD || Opcode1 == ISD::FMUL || Opcode1 == ISD::FSUB || + Opcode1 == ISD::FDIV) { + SDValue N10 = N1.getOperand(0); + SDValue N11 = N1.getOperand(1); + if (N10 == N0 || + (N11 == N0 && (Opcode1 == ISD::FADD || Opcode1 == ISD::FMUL))) { + if (N10 != N0) + std::swap(N10, N11); + MVT SVT = VT.getVectorElementType(); + SDValue ZeroIdx = DAG.getIntPtrConstant(0, DL); + N10 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, SVT, N10, ZeroIdx); + N11 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, SVT, N11, ZeroIdx); + SDValue Scl = DAG.getNode(Opcode1, DL, SVT, N10, N11); + SDValue SclVec = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, VT, Scl); + return DAG.getNode(Opcode, DL, VT, N0, SclVec); + } } return SDValue(); @@ -29647,7 +30814,7 @@ static SDValue combineTargetShuffle(SDValue N, SelectionDAG &DAG, // Nuke no-op shuffles that show up after combining. if (isNoopShuffleMask(Mask)) - return DCI.CombineTo(N.getNode(), N.getOperand(0), /*AddTo*/ true); + return N.getOperand(0); // Look for simplifications involving one or two shuffle instructions. SDValue V = N.getOperand(0); @@ -29671,10 +30838,8 @@ static SDValue combineTargetShuffle(SDValue N, SelectionDAG &DAG, DMask[DOffset + 1] = DOffset + 0; MVT DVT = MVT::getVectorVT(MVT::i32, VT.getVectorNumElements() / 2); V = DAG.getBitcast(DVT, V); - DCI.AddToWorklist(V.getNode()); V = DAG.getNode(X86ISD::PSHUFD, DL, DVT, V, getV4X86ShuffleImm8ForMask(DMask, DL, DAG)); - DCI.AddToWorklist(V.getNode()); return DAG.getBitcast(VT, V); } @@ -29705,7 +30870,6 @@ static SDValue combineTargetShuffle(SDValue N, SelectionDAG &DAG, makeArrayRef(MappedMask).equals({4, 4, 5, 5, 6, 6, 7, 7})) { // We can replace all three shuffles with an unpack. V = DAG.getBitcast(VT, D.getOperand(0)); - DCI.AddToWorklist(V.getNode()); return DAG.getNode(MappedMask[0] == 0 ? X86ISD::UNPCKL : X86ISD::UNPCKH, DL, VT, V, V); @@ -29725,6 +30889,37 @@ static SDValue combineTargetShuffle(SDValue N, SelectionDAG &DAG, return SDValue(); } +/// Checks if the shuffle mask takes subsequent elements +/// alternately from two vectors. +/// For example <0, 5, 2, 7> or <8, 1, 10, 3, 12, 5, 14, 7> are both correct. +static bool isAddSubOrSubAddMask(ArrayRef<int> Mask, bool &Op0Even) { + + int ParitySrc[2] = {-1, -1}; + unsigned Size = Mask.size(); + for (unsigned i = 0; i != Size; ++i) { + int M = Mask[i]; + if (M < 0) + continue; + + // Make sure we are using the matching element from the input. + if ((M % Size) != i) + return false; + + // Make sure we use the same input for all elements of the same parity. + int Src = M / Size; + if (ParitySrc[i % 2] >= 0 && ParitySrc[i % 2] != Src) + return false; + ParitySrc[i % 2] = Src; + } + + // Make sure each input is used. + if (ParitySrc[0] < 0 || ParitySrc[1] < 0 || ParitySrc[0] == ParitySrc[1]) + return false; + + Op0Even = ParitySrc[0] == 0; + return true; +} + /// Returns true iff the shuffle node \p N can be replaced with ADDSUB(SUBADD) /// operation. If true is returned then the operands of ADDSUB(SUBADD) operation /// are written to the parameters \p Opnd0 and \p Opnd1. 
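// Illustrative sketch, not LLVM code: what the matched ADDSUB computes for
// v4f32. A mask such as {0, 5, 2, 7} over the concatenation
// fsub(A,B) ++ fadd(A,B) takes the subtraction in even lanes and the addition
// in odd lanes, which is the scalar model below (ADDSUBPS); the SUBADD case
// is the mirror image with even and odd swapped.
static void addsubPS(const float A[4], const float B[4], float R[4]) {
  for (int i = 0; i < 4; ++i)
    R[i] = (i % 2 == 0) ? A[i] - B[i] : A[i] + B[i];
}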
@@ -29735,13 +30930,13 @@ static SDValue combineTargetShuffle(SDValue N, SelectionDAG &DAG, /// by this operation to try to flow through the rest of the combiner /// the fact that they're unused. static bool isAddSubOrSubAdd(SDNode *N, const X86Subtarget &Subtarget, - SDValue &Opnd0, SDValue &Opnd1, - bool matchSubAdd = false) { + SelectionDAG &DAG, SDValue &Opnd0, SDValue &Opnd1, + bool &IsSubAdd) { EVT VT = N->getValueType(0); - if ((!Subtarget.hasSSE3() || (VT != MVT::v4f32 && VT != MVT::v2f64)) && - (!Subtarget.hasAVX() || (VT != MVT::v8f32 && VT != MVT::v4f64)) && - (!Subtarget.hasAVX512() || (VT != MVT::v16f32 && VT != MVT::v8f64))) + const TargetLowering &TLI = DAG.getTargetLoweringInfo(); + if (!Subtarget.hasSSE3() || !TLI.isTypeLegal(VT) || + !VT.getSimpleVT().isFloatingPoint()) return false; // We only handle target-independent shuffles. @@ -29750,21 +30945,13 @@ static bool isAddSubOrSubAdd(SDNode *N, const X86Subtarget &Subtarget, if (N->getOpcode() != ISD::VECTOR_SHUFFLE) return false; - ArrayRef<int> OrigMask = cast<ShuffleVectorSDNode>(N)->getMask(); - SmallVector<int, 16> Mask(OrigMask.begin(), OrigMask.end()); - SDValue V1 = N->getOperand(0); SDValue V2 = N->getOperand(1); - unsigned ExpectedOpcode = matchSubAdd ? ISD::FADD : ISD::FSUB; - unsigned NextExpectedOpcode = matchSubAdd ? ISD::FSUB : ISD::FADD; - - // We require the first shuffle operand to be the ExpectedOpcode node, - // and the second to be the NextExpectedOpcode node. - if (V1.getOpcode() == NextExpectedOpcode && V2.getOpcode() == ExpectedOpcode) { - ShuffleVectorSDNode::commuteMask(Mask); - std::swap(V1, V2); - } else if (V1.getOpcode() != ExpectedOpcode || V2.getOpcode() != NextExpectedOpcode) + // Make sure we have an FADD and an FSUB. + if ((V1.getOpcode() != ISD::FADD && V1.getOpcode() != ISD::FSUB) || + (V2.getOpcode() != ISD::FADD && V2.getOpcode() != ISD::FSUB) || + V1.getOpcode() == V2.getOpcode()) return false; // If there are other uses of these operations we can't fold them. @@ -29773,41 +30960,101 @@ static bool isAddSubOrSubAdd(SDNode *N, const X86Subtarget &Subtarget, // Ensure that both operations have the same operands. Note that we can // commute the FADD operands. - SDValue LHS = V1->getOperand(0), RHS = V1->getOperand(1); - if ((V2->getOperand(0) != LHS || V2->getOperand(1) != RHS) && - (V2->getOperand(0) != RHS || V2->getOperand(1) != LHS)) - return false; + SDValue LHS, RHS; + if (V1.getOpcode() == ISD::FSUB) { + LHS = V1->getOperand(0); RHS = V1->getOperand(1); + if ((V2->getOperand(0) != LHS || V2->getOperand(1) != RHS) && + (V2->getOperand(0) != RHS || V2->getOperand(1) != LHS)) + return false; + } else { + assert(V2.getOpcode() == ISD::FSUB && "Unexpected opcode"); + LHS = V2->getOperand(0); RHS = V2->getOperand(1); + if ((V1->getOperand(0) != LHS || V1->getOperand(1) != RHS) && + (V1->getOperand(0) != RHS || V1->getOperand(1) != LHS)) + return false; + } - // We're looking for blends between FADD and FSUB nodes. We insist on these - // nodes being lined up in a specific expected pattern. - if (!(isShuffleEquivalent(V1, V2, Mask, {0, 3}) || - isShuffleEquivalent(V1, V2, Mask, {0, 5, 2, 7}) || - isShuffleEquivalent(V1, V2, Mask, {0, 9, 2, 11, 4, 13, 6, 15}) || - isShuffleEquivalent(V1, V2, Mask, {0, 17, 2, 19, 4, 21, 6, 23, - 8, 25, 10, 27, 12, 29, 14, 31}))) + ArrayRef<int> Mask = cast<ShuffleVectorSDNode>(N)->getMask(); + bool Op0Even; + if (!isAddSubOrSubAddMask(Mask, Op0Even)) return false; + // It's a subadd if the vector in the even parity is an FADD. + IsSubAdd = Op0Even ? 
V1->getOpcode() == ISD::FADD + : V2->getOpcode() == ISD::FADD; + Opnd0 = LHS; Opnd1 = RHS; return true; } -/// \brief Try to combine a shuffle into a target-specific add-sub or +/// Combine shuffle of two fma nodes into FMAddSub or FMSubAdd. +static SDValue combineShuffleToFMAddSub(SDNode *N, + const X86Subtarget &Subtarget, + SelectionDAG &DAG) { + // We only handle target-independent shuffles. + // FIXME: It would be easy and harmless to use the target shuffle mask + // extraction tool to support more. + if (N->getOpcode() != ISD::VECTOR_SHUFFLE) + return SDValue(); + + MVT VT = N->getSimpleValueType(0); + const TargetLowering &TLI = DAG.getTargetLoweringInfo(); + if (!Subtarget.hasAnyFMA() || !TLI.isTypeLegal(VT)) + return SDValue(); + + // We're trying to match (shuffle fma(a, b, c), X86Fmsub(a, b, c). + SDValue Op0 = N->getOperand(0); + SDValue Op1 = N->getOperand(1); + SDValue FMAdd = Op0, FMSub = Op1; + if (FMSub.getOpcode() != X86ISD::FMSUB) + std::swap(FMAdd, FMSub); + + if (FMAdd.getOpcode() != ISD::FMA || FMSub.getOpcode() != X86ISD::FMSUB || + FMAdd.getOperand(0) != FMSub.getOperand(0) || !FMAdd.hasOneUse() || + FMAdd.getOperand(1) != FMSub.getOperand(1) || !FMSub.hasOneUse() || + FMAdd.getOperand(2) != FMSub.getOperand(2)) + return SDValue(); + + // Check for correct shuffle mask. + ArrayRef<int> Mask = cast<ShuffleVectorSDNode>(N)->getMask(); + bool Op0Even; + if (!isAddSubOrSubAddMask(Mask, Op0Even)) + return SDValue(); + + // FMAddSub takes zeroth operand from FMSub node. + SDLoc DL(N); + bool IsSubAdd = Op0Even ? Op0 == FMAdd : Op1 == FMAdd; + unsigned Opcode = IsSubAdd ? X86ISD::FMSUBADD : X86ISD::FMADDSUB; + return DAG.getNode(Opcode, DL, VT, FMAdd.getOperand(0), FMAdd.getOperand(1), + FMAdd.getOperand(2)); +} + +/// Try to combine a shuffle into a target-specific add-sub or /// mul-add-sub node. static SDValue combineShuffleToAddSubOrFMAddSub(SDNode *N, const X86Subtarget &Subtarget, SelectionDAG &DAG) { + if (SDValue V = combineShuffleToFMAddSub(N, Subtarget, DAG)) + return V; + SDValue Opnd0, Opnd1; - if (!isAddSubOrSubAdd(N, Subtarget, Opnd0, Opnd1)) + bool IsSubAdd; + if (!isAddSubOrSubAdd(N, Subtarget, DAG, Opnd0, Opnd1, IsSubAdd)) return SDValue(); - EVT VT = N->getValueType(0); + MVT VT = N->getSimpleValueType(0); SDLoc DL(N); // Try to generate X86ISD::FMADDSUB node here. SDValue Opnd2; - if (isFMAddSubOrFMSubAdd(Subtarget, DAG, Opnd0, Opnd1, Opnd2, 2)) - return DAG.getNode(X86ISD::FMADDSUB, DL, VT, Opnd0, Opnd1, Opnd2); + if (isFMAddSubOrFMSubAdd(Subtarget, DAG, Opnd0, Opnd1, Opnd2, 2)) { + unsigned Opc = IsSubAdd ? X86ISD::FMSUBADD : X86ISD::FMADDSUB; + return DAG.getNode(Opc, DL, VT, Opnd0, Opnd1, Opnd2); + } + + if (IsSubAdd) + return SDValue(); // Do not generate X86ISD::ADDSUB node for 512-bit types even though // the ADDSUB idiom has been successfully recognized. There are no known @@ -29818,26 +31065,6 @@ static SDValue combineShuffleToAddSubOrFMAddSub(SDNode *N, return DAG.getNode(X86ISD::ADDSUB, DL, VT, Opnd0, Opnd1); } -/// \brief Try to combine a shuffle into a target-specific -/// mul-sub-add node. -static SDValue combineShuffleToFMSubAdd(SDNode *N, - const X86Subtarget &Subtarget, - SelectionDAG &DAG) { - SDValue Opnd0, Opnd1; - if (!isAddSubOrSubAdd(N, Subtarget, Opnd0, Opnd1, true)) - return SDValue(); - - EVT VT = N->getValueType(0); - SDLoc DL(N); - - // Try to generate X86ISD::FMSUBADD node here. 
- SDValue Opnd2; - if (isFMAddSubOrFMSubAdd(Subtarget, DAG, Opnd0, Opnd1, Opnd2, 2)) - return DAG.getNode(X86ISD::FMSUBADD, DL, VT, Opnd0, Opnd1, Opnd2); - - return SDValue(); -} - // We are looking for a shuffle where both sources are concatenated with undef // and have a width that is half of the output's width. AVX2 has VPERMD/Q, so // if we can express this as a single-source shuffle, that's preferable. @@ -29897,8 +31124,8 @@ static SDValue foldShuffleOfHorizOp(SDNode *N) { // lanes of each operand as: // v4X32: A[0] + A[1] , A[2] + A[3] , B[0] + B[1] , B[2] + B[3] // ...similarly for v2f64 and v8i16. - // TODO: 256-bit is not the same because...x86. - if (HOp.getOperand(0) != HOp.getOperand(1) || HOp.getValueSizeInBits() != 128) + // TODO: Handle UNDEF operands. + if (HOp.getOperand(0) != HOp.getOperand(1)) return SDValue(); // When the operands of a horizontal math op are identical, the low half of @@ -29909,9 +31136,17 @@ static SDValue foldShuffleOfHorizOp(SDNode *N) { // TODO: Other mask possibilities like {1,1} and {1,0} could be added here, // but this should be tied to whatever horizontal op matching and shuffle // canonicalization are producing. - if (isTargetShuffleEquivalent(Mask, { 0, 0 }) || - isTargetShuffleEquivalent(Mask, { 0, 1, 0, 1 }) || - isTargetShuffleEquivalent(Mask, { 0, 1, 2, 3, 0, 1, 2, 3 })) + if (HOp.getValueSizeInBits() == 128 && + (isTargetShuffleEquivalent(Mask, {0, 0}) || + isTargetShuffleEquivalent(Mask, {0, 1, 0, 1}) || + isTargetShuffleEquivalent(Mask, {0, 1, 2, 3, 0, 1, 2, 3}))) + return HOp; + + if (HOp.getValueSizeInBits() == 256 && + (isTargetShuffleEquivalent(Mask, {0, 0, 2, 2}) || + isTargetShuffleEquivalent(Mask, {0, 1, 0, 1, 4, 5, 4, 5}) || + isTargetShuffleEquivalent( + Mask, {0, 1, 2, 3, 0, 1, 2, 3, 8, 9, 10, 11, 8, 9, 10, 11}))) return HOp; return SDValue(); @@ -29929,9 +31164,6 @@ static SDValue combineShuffle(SDNode *N, SelectionDAG &DAG, if (SDValue AddSub = combineShuffleToAddSubOrFMAddSub(N, Subtarget, DAG)) return AddSub; - if (SDValue FMSubAdd = combineShuffleToFMSubAdd(N, Subtarget, DAG)) - return FMSubAdd; - if (SDValue HAddSub = foldShuffleOfHorizOp(N)) return HAddSub; } @@ -30035,10 +31267,8 @@ static SDValue combineShuffle(SDNode *N, SelectionDAG &DAG, // a particular chain. 
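// Illustrative sketch, not LLVM code: why foldShuffleOfHorizOp above can drop
// the shuffle when a horizontal op repeats one operand. For v4f32 HADD with
// both inputs equal to A the two halves of the result are identical, so masks
// such as {0, 1, 0, 1} are no-ops on it.
static void haddSameOperand(const float A[4], float R[4]) {
  R[0] = A[0] + A[1];
  R[1] = A[2] + A[3];
  R[2] = A[0] + A[1];   // same pair sums again: low half == high half
  R[3] = A[2] + A[3];
}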
if (SDValue Res = combineX86ShufflesRecursively( {Op}, 0, Op, {0}, {}, /*Depth*/ 1, - /*HasVarMask*/ false, DAG, DCI, Subtarget)) { - DCI.CombineTo(N, Res); - return SDValue(); - } + /*HasVarMask*/ false, DAG, Subtarget)) + return Res; } return SDValue(); @@ -30155,53 +31385,6 @@ static SDValue combineBitcastvxi1(SelectionDAG &DAG, SDValue BitCast, SDValue N0 = BitCast.getOperand(0); EVT VecVT = N0->getValueType(0); - if (VT.isVector() && VecVT.isScalarInteger() && Subtarget.hasAVX512() && - N0->getOpcode() == ISD::OR) { - SDValue Op0 = N0->getOperand(0); - SDValue Op1 = N0->getOperand(1); - MVT TrunckVT; - MVT BitcastVT; - switch (VT.getSimpleVT().SimpleTy) { - default: - return SDValue(); - case MVT::v16i1: - TrunckVT = MVT::i8; - BitcastVT = MVT::v8i1; - break; - case MVT::v32i1: - TrunckVT = MVT::i16; - BitcastVT = MVT::v16i1; - break; - case MVT::v64i1: - TrunckVT = MVT::i32; - BitcastVT = MVT::v32i1; - break; - } - bool isArg0UndefRight = Op0->getOpcode() == ISD::SHL; - bool isArg0UndefLeft = - Op0->getOpcode() == ISD::ZERO_EXTEND || Op0->getOpcode() == ISD::AND; - bool isArg1UndefRight = Op1->getOpcode() == ISD::SHL; - bool isArg1UndefLeft = - Op1->getOpcode() == ISD::ZERO_EXTEND || Op1->getOpcode() == ISD::AND; - SDValue OpLeft; - SDValue OpRight; - if (isArg0UndefRight && isArg1UndefLeft) { - OpLeft = Op0; - OpRight = Op1; - } else if (isArg1UndefRight && isArg0UndefLeft) { - OpLeft = Op1; - OpRight = Op0; - } else - return SDValue(); - SDLoc DL(BitCast); - SDValue Shr = OpLeft->getOperand(0); - SDValue Trunc1 = DAG.getNode(ISD::TRUNCATE, DL, TrunckVT, Shr); - SDValue Bitcast1 = DAG.getBitcast(BitcastVT, Trunc1); - SDValue Trunc2 = DAG.getNode(ISD::TRUNCATE, DL, TrunckVT, OpRight); - SDValue Bitcast2 = DAG.getBitcast(BitcastVT, Trunc2); - return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, Bitcast1, Bitcast2); - } - if (!VT.isScalarInteger() || !VecVT.isSimple()) return SDValue(); @@ -30269,17 +31452,8 @@ static SDValue combineBitcastvxi1(SelectionDAG &DAG, SDValue BitCast, SDLoc DL(BitCast); SDValue V = DAG.getSExtOrTrunc(N0, DL, SExtVT); - if (SExtVT == MVT::v32i8 && !Subtarget.hasInt256()) { - // Handle pre-AVX2 cases by splitting to two v16i1's. - const TargetLowering &TLI = DAG.getTargetLoweringInfo(); - MVT ShiftTy = TLI.getScalarShiftAmountTy(DAG.getDataLayout(), MVT::i32); - SDValue Lo = extract128BitVector(V, 0, DAG, DL); - SDValue Hi = extract128BitVector(V, 16, DAG, DL); - Lo = DAG.getNode(X86ISD::MOVMSK, DL, MVT::i32, Lo); - Hi = DAG.getNode(X86ISD::MOVMSK, DL, MVT::i32, Hi); - Hi = DAG.getNode(ISD::SHL, DL, MVT::i32, Hi, - DAG.getConstant(16, DL, ShiftTy)); - V = DAG.getNode(ISD::OR, DL, MVT::i32, Lo, Hi); + if (SExtVT == MVT::v16i8 || SExtVT == MVT::v32i8) { + V = getPMOVMSKB(DL, V, DAG, Subtarget); return DAG.getZExtOrTrunc(V, DL, VT); } @@ -30296,6 +31470,153 @@ static SDValue combineBitcastvxi1(SelectionDAG &DAG, SDValue BitCast, return DAG.getZExtOrTrunc(V, DL, VT); } +// Convert a vXi1 constant build vector to the same width scalar integer. 
+static SDValue combinevXi1ConstantToInteger(SDValue Op, SelectionDAG &DAG) { + EVT SrcVT = Op.getValueType(); + assert(SrcVT.getVectorElementType() == MVT::i1 && + "Expected a vXi1 vector"); + assert(ISD::isBuildVectorOfConstantSDNodes(Op.getNode()) && + "Expected a constant build vector"); + + APInt Imm(SrcVT.getVectorNumElements(), 0); + for (unsigned Idx = 0, e = Op.getNumOperands(); Idx < e; ++Idx) { + SDValue In = Op.getOperand(Idx); + if (!In.isUndef() && (cast<ConstantSDNode>(In)->getZExtValue() & 0x1)) + Imm.setBit(Idx); + } + EVT IntVT = EVT::getIntegerVT(*DAG.getContext(), Imm.getBitWidth()); + return DAG.getConstant(Imm, SDLoc(Op), IntVT); +} + +static SDValue combineCastedMaskArithmetic(SDNode *N, SelectionDAG &DAG, + TargetLowering::DAGCombinerInfo &DCI, + const X86Subtarget &Subtarget) { + assert(N->getOpcode() == ISD::BITCAST && "Expected a bitcast"); + + if (!DCI.isBeforeLegalizeOps()) + return SDValue(); + + // Only do this if we have k-registers. + if (!Subtarget.hasAVX512()) + return SDValue(); + + EVT DstVT = N->getValueType(0); + SDValue Op = N->getOperand(0); + EVT SrcVT = Op.getValueType(); + + if (!Op.hasOneUse()) + return SDValue(); + + // Look for logic ops. + if (Op.getOpcode() != ISD::AND && + Op.getOpcode() != ISD::OR && + Op.getOpcode() != ISD::XOR) + return SDValue(); + + // Make sure we have a bitcast between mask registers and a scalar type. + if (!(SrcVT.isVector() && SrcVT.getVectorElementType() == MVT::i1 && + DstVT.isScalarInteger()) && + !(DstVT.isVector() && DstVT.getVectorElementType() == MVT::i1 && + SrcVT.isScalarInteger())) + return SDValue(); + + SDValue LHS = Op.getOperand(0); + SDValue RHS = Op.getOperand(1); + + if (LHS.hasOneUse() && LHS.getOpcode() == ISD::BITCAST && + LHS.getOperand(0).getValueType() == DstVT) + return DAG.getNode(Op.getOpcode(), SDLoc(N), DstVT, LHS.getOperand(0), + DAG.getBitcast(DstVT, RHS)); + + if (RHS.hasOneUse() && RHS.getOpcode() == ISD::BITCAST && + RHS.getOperand(0).getValueType() == DstVT) + return DAG.getNode(Op.getOpcode(), SDLoc(N), DstVT, + DAG.getBitcast(DstVT, LHS), RHS.getOperand(0)); + + // If the RHS is a vXi1 build vector, this is a good reason to flip too. + // Most of these have to move a constant from the scalar domain anyway. + if (ISD::isBuildVectorOfConstantSDNodes(RHS.getNode())) { + RHS = combinevXi1ConstantToInteger(RHS, DAG); + return DAG.getNode(Op.getOpcode(), SDLoc(N), DstVT, + DAG.getBitcast(DstVT, LHS), RHS); + } + + return SDValue(); +} + +static SDValue createMMXBuildVector(SDValue N, SelectionDAG &DAG, + const X86Subtarget &Subtarget) { + SDLoc DL(N); + unsigned NumElts = N.getNumOperands(); + + auto *BV = cast<BuildVectorSDNode>(N); + SDValue Splat = BV->getSplatValue(); + + // Build MMX element from integer GPR or SSE float values. + auto CreateMMXElement = [&](SDValue V) { + if (V.isUndef()) + return DAG.getUNDEF(MVT::x86mmx); + if (V.getValueType().isFloatingPoint()) { + if (Subtarget.hasSSE1() && !isa<ConstantFPSDNode>(V)) { + V = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, MVT::v4f32, V); + V = DAG.getBitcast(MVT::v2i64, V); + return DAG.getNode(X86ISD::MOVDQ2Q, DL, MVT::x86mmx, V); + } + V = DAG.getBitcast(MVT::i32, V); + } else { + V = DAG.getAnyExtOrTrunc(V, DL, MVT::i32); + } + return DAG.getNode(X86ISD::MMX_MOVW2D, DL, MVT::x86mmx, V); + }; + + // Convert build vector ops to MMX data in the bottom elements. + SmallVector<SDValue, 8> Ops; + + // Broadcast - use (PUNPCKL+)PSHUFW to broadcast single element. 
+ if (Splat) { + if (Splat.isUndef()) + return DAG.getUNDEF(MVT::x86mmx); + + Splat = CreateMMXElement(Splat); + + if (Subtarget.hasSSE1()) { + // Unpack v8i8 to splat i8 elements to lowest 16-bits. + if (NumElts == 8) + Splat = DAG.getNode( + ISD::INTRINSIC_WO_CHAIN, DL, MVT::x86mmx, + DAG.getConstant(Intrinsic::x86_mmx_punpcklbw, DL, MVT::i32), Splat, + Splat); + + // Use PSHUFW to repeat 16-bit elements. + unsigned ShufMask = (NumElts > 2 ? 0 : 0x44); + return DAG.getNode( + ISD::INTRINSIC_WO_CHAIN, DL, MVT::x86mmx, + DAG.getConstant(Intrinsic::x86_sse_pshuf_w, DL, MVT::i32), Splat, + DAG.getConstant(ShufMask, DL, MVT::i8)); + } + Ops.append(NumElts, Splat); + } else { + for (unsigned i = 0; i != NumElts; ++i) + Ops.push_back(CreateMMXElement(N.getOperand(i))); + } + + // Use tree of PUNPCKLs to build up general MMX vector. + while (Ops.size() > 1) { + unsigned NumOps = Ops.size(); + unsigned IntrinOp = + (NumOps == 2 ? Intrinsic::x86_mmx_punpckldq + : (NumOps == 4 ? Intrinsic::x86_mmx_punpcklwd + : Intrinsic::x86_mmx_punpcklbw)); + SDValue Intrin = DAG.getConstant(IntrinOp, DL, MVT::i32); + for (unsigned i = 0; i != NumOps; i += 2) + Ops[i / 2] = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, MVT::x86mmx, Intrin, + Ops[i], Ops[i + 1]); + Ops.resize(NumOps / 2); + } + + return Ops[0]; +} + static SDValue combineBitcast(SDNode *N, SelectionDAG &DAG, TargetLowering::DAGCombinerInfo &DCI, const X86Subtarget &Subtarget) { @@ -30309,42 +31630,124 @@ static SDValue combineBitcast(SDNode *N, SelectionDAG &DAG, // (i16 movmsk (16i8 sext (v16i1 x))) // before the setcc result is scalarized on subtargets that don't have legal // vxi1 types. - if (DCI.isBeforeLegalize()) + if (DCI.isBeforeLegalize()) { if (SDValue V = combineBitcastvxi1(DAG, SDValue(N, 0), Subtarget)) return V; + + // If this is a bitcast between a MVT::v4i1/v2i1 and an illegal integer + // type, widen both sides to avoid a trip through memory. + if ((VT == MVT::v4i1 || VT == MVT::v2i1) && SrcVT.isScalarInteger() && + Subtarget.hasAVX512()) { + SDLoc dl(N); + N0 = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i8, N0); + N0 = DAG.getBitcast(MVT::v8i1, N0); + return DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, VT, N0, + DAG.getIntPtrConstant(0, dl)); + } + + // If this is a bitcast between a MVT::v4i1/v2i1 and an illegal integer + // type, widen both sides to avoid a trip through memory. + if ((SrcVT == MVT::v4i1 || SrcVT == MVT::v2i1) && VT.isScalarInteger() && + Subtarget.hasAVX512()) { + SDLoc dl(N); + unsigned NumConcats = 8 / SrcVT.getVectorNumElements(); + SmallVector<SDValue, 4> Ops(NumConcats, DAG.getUNDEF(SrcVT)); + Ops[0] = N0; + N0 = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v8i1, Ops); + N0 = DAG.getBitcast(MVT::i8, N0); + return DAG.getNode(ISD::TRUNCATE, dl, VT, N0); + } + } + // Since MMX types are special and don't usually play with other vector types, // it's better to handle them early to be sure we emit efficient code by // avoiding store-load conversions. + if (VT == MVT::x86mmx) { + // Detect MMX constant vectors. + APInt UndefElts; + SmallVector<APInt, 1> EltBits; + if (getTargetConstantBitsFromNode(N0, 64, UndefElts, EltBits)) { + SDLoc DL(N0); + // Handle zero-extension of i32 with MOVD. + if (EltBits[0].countLeadingZeros() >= 32) + return DAG.getNode(X86ISD::MMX_MOVW2D, DL, VT, + DAG.getConstant(EltBits[0].trunc(32), DL, MVT::i32)); + // Else, bitcast to a double. + // TODO - investigate supporting sext 32-bit immediates on x86_64. 
+ APFloat F64(APFloat::IEEEdouble(), EltBits[0]); + return DAG.getBitcast(VT, DAG.getConstantFP(F64, DL, MVT::f64)); + } + + // Detect bitcasts to x86mmx low word. + if (N0.getOpcode() == ISD::BUILD_VECTOR && + (SrcVT == MVT::v2i32 || SrcVT == MVT::v4i16 || SrcVT == MVT::v8i8) && + N0.getOperand(0).getValueType() == SrcVT.getScalarType()) { + bool LowUndef = true, AllUndefOrZero = true; + for (unsigned i = 1, e = SrcVT.getVectorNumElements(); i != e; ++i) { + SDValue Op = N0.getOperand(i); + LowUndef &= Op.isUndef() || (i >= e/2); + AllUndefOrZero &= (Op.isUndef() || isNullConstant(Op)); + } + if (AllUndefOrZero) { + SDValue N00 = N0.getOperand(0); + SDLoc dl(N00); + N00 = LowUndef ? DAG.getAnyExtOrTrunc(N00, dl, MVT::i32) + : DAG.getZExtOrTrunc(N00, dl, MVT::i32); + return DAG.getNode(X86ISD::MMX_MOVW2D, dl, VT, N00); + } + } - // Detect bitcasts between i32 to x86mmx low word. - if (VT == MVT::x86mmx && N0.getOpcode() == ISD::BUILD_VECTOR && - SrcVT == MVT::v2i32 && isNullConstant(N0.getOperand(1))) { - SDValue N00 = N0->getOperand(0); - if (N00.getValueType() == MVT::i32) - return DAG.getNode(X86ISD::MMX_MOVW2D, SDLoc(N00), VT, N00); + // Detect bitcasts of 64-bit build vectors and convert to a + // MMX UNPCK/PSHUFW which takes MMX type inputs with the value in the + // lowest element. + if (N0.getOpcode() == ISD::BUILD_VECTOR && + (SrcVT == MVT::v2f32 || SrcVT == MVT::v2i32 || SrcVT == MVT::v4i16 || + SrcVT == MVT::v8i8)) + return createMMXBuildVector(N0, DAG, Subtarget); + + // Detect bitcasts between element or subvector extraction to x86mmx. + if ((N0.getOpcode() == ISD::EXTRACT_VECTOR_ELT || + N0.getOpcode() == ISD::EXTRACT_SUBVECTOR) && + isNullConstant(N0.getOperand(1))) { + SDValue N00 = N0.getOperand(0); + if (N00.getValueType().is128BitVector()) + return DAG.getNode(X86ISD::MOVDQ2Q, SDLoc(N00), VT, + DAG.getBitcast(MVT::v2i64, N00)); + } + + // Detect bitcasts from FP_TO_SINT to x86mmx. + if (SrcVT == MVT::v2i32 && N0.getOpcode() == ISD::FP_TO_SINT) { + SDLoc DL(N0); + SDValue Res = DAG.getNode(ISD::CONCAT_VECTORS, DL, MVT::v4i32, N0, + DAG.getUNDEF(MVT::v2i32)); + return DAG.getNode(X86ISD::MOVDQ2Q, DL, VT, + DAG.getBitcast(MVT::v2i64, Res)); + } } - // Detect bitcasts between element or subvector extraction to x86mmx. - if (VT == MVT::x86mmx && - (N0.getOpcode() == ISD::EXTRACT_VECTOR_ELT || - N0.getOpcode() == ISD::EXTRACT_SUBVECTOR) && - isNullConstant(N0.getOperand(1))) { - SDValue N00 = N0->getOperand(0); - if (N00.getValueType().is128BitVector()) - return DAG.getNode(X86ISD::MOVDQ2Q, SDLoc(N00), VT, - DAG.getBitcast(MVT::v2i64, N00)); + // Try to remove a bitcast of constant vXi1 vector. We have to legalize + // most of these to scalar anyway. + if (Subtarget.hasAVX512() && VT.isScalarInteger() && + SrcVT.isVector() && SrcVT.getVectorElementType() == MVT::i1 && + ISD::isBuildVectorOfConstantSDNodes(N0.getNode())) { + return combinevXi1ConstantToInteger(N0, DAG); } - // Detect bitcasts from FP_TO_SINT to x86mmx. 
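// Illustrative sketch, not LLVM code: the vXi1-constant folding used just
// above turns a constant mask vector into the equivalent scalar immediate,
// one bit per lane, e.g. {1, 0, 1, 1} -> 0b1101.
#include <cstdint>
#include <vector>

static uint64_t maskVectorToImmediate(const std::vector<bool> &Lanes) {
  uint64_t Imm = 0;
  for (size_t i = 0; i < Lanes.size(); ++i)
    if (Lanes[i])
      Imm |= uint64_t(1) << i;
  return Imm;
}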
- if (VT == MVT::x86mmx && SrcVT == MVT::v2i32 && - N0.getOpcode() == ISD::FP_TO_SINT) { - SDLoc DL(N0); - SDValue Res = DAG.getNode(ISD::CONCAT_VECTORS, DL, MVT::v4i32, N0, - DAG.getUNDEF(MVT::v2i32)); - return DAG.getNode(X86ISD::MOVDQ2Q, DL, VT, - DAG.getBitcast(MVT::v2i64, Res)); + if (Subtarget.hasAVX512() && SrcVT.isScalarInteger() && + VT.isVector() && VT.getVectorElementType() == MVT::i1 && + isa<ConstantSDNode>(N0)) { + auto *C = cast<ConstantSDNode>(N0); + if (C->isAllOnesValue()) + return DAG.getConstant(1, SDLoc(N0), VT); + if (C->isNullValue()) + return DAG.getConstant(0, SDLoc(N0), VT); } + // Try to remove bitcasts from input and output of mask arithmetic to + // remove GPR<->K-register crossings. + if (SDValue V = combineCastedMaskArithmetic(N, DAG, DCI, Subtarget)) + return V; + // Convert a bitcasted integer logic operation that has one bitcasted // floating-point operand into a floating-point logic operation. This may // create a load of a constant, but that is cheaper than materializing the @@ -30517,8 +31920,8 @@ static bool detectZextAbsDiff(const SDValue &Select, SDValue &Op0, // Given two zexts of <k x i8> to <k x i32>, create a PSADBW of the inputs // to these zexts. static SDValue createPSADBW(SelectionDAG &DAG, const SDValue &Zext0, - const SDValue &Zext1, const SDLoc &DL) { - + const SDValue &Zext1, const SDLoc &DL, + const X86Subtarget &Subtarget) { // Find the appropriate width for the PSADBW. EVT InVT = Zext0.getOperand(0).getValueType(); unsigned RegSize = std::max(128u, InVT.getSizeInBits()); @@ -30533,9 +31936,15 @@ static SDValue createPSADBW(SelectionDAG &DAG, const SDValue &Zext0, Ops[0] = Zext1.getOperand(0); SDValue SadOp1 = DAG.getNode(ISD::CONCAT_VECTORS, DL, ExtendedVT, Ops); - // Actually build the SAD + // Actually build the SAD, split as 128/256/512 bits for SSE/AVX2/AVX512BW. + auto PSADBWBuilder = [](SelectionDAG &DAG, const SDLoc &DL, + ArrayRef<SDValue> Ops) { + MVT VT = MVT::getVectorVT(MVT::i64, Ops[0].getValueSizeInBits() / 64); + return DAG.getNode(X86ISD::PSADBW, DL, VT, Ops); + }; MVT SadVT = MVT::getVectorVT(MVT::i64, RegSize / 64); - return DAG.getNode(X86ISD::PSADBW, DL, SadVT, SadOp0, SadOp1); + return SplitOpsAndApply(DAG, Subtarget, DL, SadVT, { SadOp0, SadOp1 }, + PSADBWBuilder); } // Attempt to replace an min/max v8i16/v16i8 horizontal reduction with @@ -30702,12 +32111,12 @@ static SDValue combineBasicSADPattern(SDNode *Extract, SelectionDAG &DAG, return SDValue(); unsigned RegSize = 128; - if (Subtarget.hasBWI()) + if (Subtarget.useBWIRegs()) RegSize = 512; - else if (Subtarget.hasAVX2()) + else if (Subtarget.hasAVX()) RegSize = 256; - // We handle upto v16i* for SSE2 / v32i* for AVX2 / v64i* for AVX512. + // We handle upto v16i* for SSE2 / v32i* for AVX / v64i* for AVX512. // TODO: We should be able to handle larger vectors by splitting them before // feeding them into several SADs, and then reducing over those. if (RegSize / VT.getVectorNumElements() < 8) @@ -30742,7 +32151,7 @@ static SDValue combineBasicSADPattern(SDNode *Extract, SelectionDAG &DAG, // Create the SAD instruction. SDLoc DL(Extract); - SDValue SAD = createPSADBW(DAG, Zext0, Zext1, DL); + SDValue SAD = createPSADBW(DAG, Zext0, Zext1, DL, Subtarget); // If the original vector was wider than 8 elements, sum over the results // in the SAD vector. 
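// Illustrative sketch, not LLVM code: one 64-bit PSADBW lane, the primitive
// the widened/split SAD lowering above is built from -- the sum of absolute
// differences of eight unsigned bytes, which always fits in 16 bits.
#include <cstdint>
#include <cstdlib>

static uint64_t psadbwLane(const uint8_t A[8], const uint8_t B[8]) {
  uint64_t Sum = 0;
  for (int i = 0; i < 8; ++i)
    Sum += static_cast<uint64_t>(std::abs(int(A[i]) - int(B[i])));
  return Sum;   // at most 8 * 255 = 2040
}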
@@ -30791,6 +32200,11 @@ static SDValue combineExtractWithShuffle(SDNode *N, SelectionDAG &DAG, if (SrcSVT == MVT::i1 || !isa<ConstantSDNode>(Idx)) return SDValue(); + // Handle extract(broadcast(scalar_value)), it doesn't matter what index is. + if (X86ISD::VBROADCAST == Src.getOpcode() && + Src.getOperand(0).getValueType() == VT) + return Src.getOperand(0); + // Resolve the target shuffle inputs and mask. SmallVector<int, 16> Mask; SmallVector<SDValue, 2> Ops; @@ -30908,8 +32322,9 @@ static SDValue combineExtractVectorElt(SDNode *N, SelectionDAG &DAG, isa<ConstantSDNode>(EltIdx) && isa<ConstantSDNode>(InputVector.getOperand(0))) { uint64_t ExtractedElt = N->getConstantOperandVal(1); - uint64_t InputValue = InputVector.getConstantOperandVal(0); - uint64_t Res = (InputValue >> ExtractedElt) & 1; + auto *InputC = cast<ConstantSDNode>(InputVector.getOperand(0)); + const APInt &InputValue = InputC->getAPIntValue(); + uint64_t Res = InputValue[ExtractedElt]; return DAG.getConstant(Res, dl, MVT::i1); } @@ -30927,102 +32342,7 @@ static SDValue combineExtractVectorElt(SDNode *N, SelectionDAG &DAG, if (SDValue MinMax = combineHorizontalMinMaxResult(N, DAG, Subtarget)) return MinMax; - // Only operate on vectors of 4 elements, where the alternative shuffling - // gets to be more expensive. - if (SrcVT != MVT::v4i32) - return SDValue(); - - // Check whether every use of InputVector is an EXTRACT_VECTOR_ELT with a - // single use which is a sign-extend or zero-extend, and all elements are - // used. - SmallVector<SDNode *, 4> Uses; - unsigned ExtractedElements = 0; - for (SDNode::use_iterator UI = InputVector.getNode()->use_begin(), - UE = InputVector.getNode()->use_end(); UI != UE; ++UI) { - if (UI.getUse().getResNo() != InputVector.getResNo()) - return SDValue(); - - SDNode *Extract = *UI; - if (Extract->getOpcode() != ISD::EXTRACT_VECTOR_ELT) - return SDValue(); - - if (Extract->getValueType(0) != MVT::i32) - return SDValue(); - if (!Extract->hasOneUse()) - return SDValue(); - if (Extract->use_begin()->getOpcode() != ISD::SIGN_EXTEND && - Extract->use_begin()->getOpcode() != ISD::ZERO_EXTEND) - return SDValue(); - if (!isa<ConstantSDNode>(Extract->getOperand(1))) - return SDValue(); - - // Record which element was extracted. - ExtractedElements |= 1 << Extract->getConstantOperandVal(1); - Uses.push_back(Extract); - } - - // If not all the elements were used, this may not be worthwhile. - if (ExtractedElements != 15) - return SDValue(); - - // Ok, we've now decided to do the transformation. - // If 64-bit shifts are legal, use the extract-shift sequence, - // otherwise bounce the vector off the cache. 
- const TargetLowering &TLI = DAG.getTargetLoweringInfo(); - SDValue Vals[4]; - - if (TLI.isOperationLegal(ISD::SRA, MVT::i64)) { - SDValue Cst = DAG.getBitcast(MVT::v2i64, InputVector); - auto &DL = DAG.getDataLayout(); - EVT VecIdxTy = DAG.getTargetLoweringInfo().getVectorIdxTy(DL); - SDValue BottomHalf = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, MVT::i64, Cst, - DAG.getConstant(0, dl, VecIdxTy)); - SDValue TopHalf = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, MVT::i64, Cst, - DAG.getConstant(1, dl, VecIdxTy)); - - SDValue ShAmt = DAG.getConstant( - 32, dl, DAG.getTargetLoweringInfo().getShiftAmountTy(MVT::i64, DL)); - Vals[0] = DAG.getNode(ISD::TRUNCATE, dl, MVT::i32, BottomHalf); - Vals[1] = DAG.getNode(ISD::TRUNCATE, dl, MVT::i32, - DAG.getNode(ISD::SRA, dl, MVT::i64, BottomHalf, ShAmt)); - Vals[2] = DAG.getNode(ISD::TRUNCATE, dl, MVT::i32, TopHalf); - Vals[3] = DAG.getNode(ISD::TRUNCATE, dl, MVT::i32, - DAG.getNode(ISD::SRA, dl, MVT::i64, TopHalf, ShAmt)); - } else { - // Store the value to a temporary stack slot. - SDValue StackPtr = DAG.CreateStackTemporary(SrcVT); - SDValue Ch = DAG.getStore(DAG.getEntryNode(), dl, InputVector, StackPtr, - MachinePointerInfo()); - - EVT ElementType = SrcVT.getVectorElementType(); - unsigned EltSize = ElementType.getSizeInBits() / 8; - - // Replace each use (extract) with a load of the appropriate element. - for (unsigned i = 0; i < 4; ++i) { - uint64_t Offset = EltSize * i; - auto PtrVT = TLI.getPointerTy(DAG.getDataLayout()); - SDValue OffsetVal = DAG.getConstant(Offset, dl, PtrVT); - - SDValue ScalarAddr = - DAG.getNode(ISD::ADD, dl, PtrVT, StackPtr, OffsetVal); - - // Load the scalar. - Vals[i] = - DAG.getLoad(ElementType, dl, Ch, ScalarAddr, MachinePointerInfo()); - } - } - - // Replace the extracts - for (SmallVectorImpl<SDNode *>::iterator UI = Uses.begin(), - UE = Uses.end(); UI != UE; ++UI) { - SDNode *Extract = *UI; - - uint64_t IdxVal = Extract->getConstantOperandVal(1); - DAG.ReplaceAllUsesOfValueWith(SDValue(Extract, 0), Vals[IdxVal]); - } - - // The replacement was made in place; return N so it won't be revisited. - return SDValue(N, 0); + return SDValue(); } /// If a vector select has an operand that is -1 or 0, try to simplify the @@ -31051,8 +32371,7 @@ combineVSelectWithAllOnesOrZeros(SDNode *N, SelectionDAG &DAG, if (TValIsAllZeros && Subtarget.hasAVX512() && Cond.hasOneUse() && CondVT.getVectorElementType() == MVT::i1) { // Invert the cond to not(cond) : xor(op,allones)=not(op) - SDValue CondNew = DAG.getNode(ISD::XOR, DL, CondVT, Cond, - DAG.getAllOnesConstant(DL, CondVT)); + SDValue CondNew = DAG.getNOT(DL, Cond, CondVT); // Vselect cond, op1, op2 = Vselect not(cond), op2, op1 return DAG.getSelect(DL, VT, CondNew, RHS, LHS); } @@ -31191,68 +32510,77 @@ static SDValue combineSelectOfTwoConstants(SDNode *N, SelectionDAG &DAG) { return SDValue(); } -// If this is a bitcasted op that can be represented as another type, push the -// the bitcast to the inputs. This allows more opportunities for pattern -// matching masked instructions. This is called when we know that the operation -// is used as one of the inputs of a vselect. -static bool combineBitcastForMaskedOp(SDValue OrigOp, SelectionDAG &DAG, - TargetLowering::DAGCombinerInfo &DCI) { - // Make sure we have a bitcast. - if (OrigOp.getOpcode() != ISD::BITCAST) - return false; - - SDValue Op = OrigOp.getOperand(0); - - // If the operation is used by anything other than the bitcast, we shouldn't - // do this combine as that would replicate the operation. 
- if (!Op.hasOneUse()) - return false; +/// If this is a *dynamic* select (non-constant condition) and we can match +/// this node with one of the variable blend instructions, restructure the +/// condition so that blends can use the high (sign) bit of each element. +static SDValue combineVSelectToShrunkBlend(SDNode *N, SelectionDAG &DAG, + TargetLowering::DAGCombinerInfo &DCI, + const X86Subtarget &Subtarget) { + SDValue Cond = N->getOperand(0); + if (N->getOpcode() != ISD::VSELECT || + ISD::isBuildVectorOfConstantSDNodes(Cond.getNode())) + return SDValue(); + + // Don't optimize before the condition has been transformed to a legal type + // and don't ever optimize vector selects that map to AVX512 mask-registers. + unsigned BitWidth = Cond.getScalarValueSizeInBits(); + if (BitWidth < 8 || BitWidth > 64) + return SDValue(); + + // We can only handle the cases where VSELECT is directly legal on the + // subtarget. We custom lower VSELECT nodes with constant conditions and + // this makes it hard to see whether a dynamic VSELECT will correctly + // lower, so we both check the operation's status and explicitly handle the + // cases where a *dynamic* blend will fail even though a constant-condition + // blend could be custom lowered. + // FIXME: We should find a better way to handle this class of problems. + // Potentially, we should combine constant-condition vselect nodes + // pre-legalization into shuffles and not mark as many types as custom + // lowered. + const TargetLowering &TLI = DAG.getTargetLoweringInfo(); + EVT VT = N->getValueType(0); + if (!TLI.isOperationLegalOrCustom(ISD::VSELECT, VT)) + return SDValue(); + // FIXME: We don't support i16-element blends currently. We could and + // should support them by making *all* the bits in the condition be set + // rather than just the high bit and using an i8-element blend. + if (VT.getVectorElementType() == MVT::i16) + return SDValue(); + // Dynamic blending was only available from SSE4.1 onward. + if (VT.is128BitVector() && !Subtarget.hasSSE41()) + return SDValue(); + // Byte blends are only available in AVX2 + if (VT == MVT::v32i8 && !Subtarget.hasAVX2()) + return SDValue(); + // There are no 512-bit blend instructions that use sign bits. + if (VT.is512BitVector()) + return SDValue(); - MVT VT = OrigOp.getSimpleValueType(); - MVT EltVT = VT.getVectorElementType(); - SDLoc DL(Op.getNode()); + // TODO: Add other opcodes eventually lowered into BLEND. + for (SDNode::use_iterator UI = Cond->use_begin(), UE = Cond->use_end(); + UI != UE; ++UI) + if (UI->getOpcode() != ISD::VSELECT || UI.getOperandNo() != 0) + return SDValue(); - auto BitcastAndCombineShuffle = [&](unsigned Opcode, SDValue Op0, SDValue Op1, - SDValue Op2) { - Op0 = DAG.getBitcast(VT, Op0); - DCI.AddToWorklist(Op0.getNode()); - Op1 = DAG.getBitcast(VT, Op1); - DCI.AddToWorklist(Op1.getNode()); - DCI.CombineTo(OrigOp.getNode(), - DAG.getNode(Opcode, DL, VT, Op0, Op1, Op2)); - return true; - }; + APInt DemandedMask(APInt::getSignMask(BitWidth)); + KnownBits Known; + TargetLowering::TargetLoweringOpt TLO(DAG, !DCI.isBeforeLegalize(), + !DCI.isBeforeLegalizeOps()); + if (!TLI.SimplifyDemandedBits(Cond, DemandedMask, Known, TLO, 0, true)) + return SDValue(); - unsigned Opcode = Op.getOpcode(); - switch (Opcode) { - case X86ISD::SHUF128: { - if (EltVT.getSizeInBits() != 32 && EltVT.getSizeInBits() != 64) - return false; - // Only change element size, not type. 
- if (VT.isInteger() != Op.getSimpleValueType().isInteger()) - return false; - return BitcastAndCombineShuffle(Opcode, Op.getOperand(0), Op.getOperand(1), - Op.getOperand(2)); - } - case X86ISD::SUBV_BROADCAST: { - unsigned EltSize = EltVT.getSizeInBits(); - if (EltSize != 32 && EltSize != 64) - return false; - // Only change element size, not type. - if (VT.isInteger() != Op.getSimpleValueType().isInteger()) - return false; - SDValue Op0 = Op.getOperand(0); - MVT Op0VT = MVT::getVectorVT(EltVT, - Op0.getSimpleValueType().getSizeInBits() / EltSize); - Op0 = DAG.getBitcast(Op0VT, Op.getOperand(0)); - DCI.AddToWorklist(Op0.getNode()); - DCI.CombineTo(OrigOp.getNode(), - DAG.getNode(Opcode, DL, VT, Op0)); - return true; + // If we changed the computation somewhere in the DAG, this change will + // affect all users of Cond. Update all the nodes so that we do not use + // the generic VSELECT anymore. Otherwise, we may perform wrong + // optimizations as we messed with the actual expectation for the vector + // boolean values. + for (SDNode *U : Cond->uses()) { + SDValue SB = DAG.getNode(X86ISD::SHRUNKBLEND, SDLoc(U), U->getValueType(0), + Cond, U->getOperand(1), U->getOperand(2)); + DAG.ReplaceAllUsesOfValueWith(SDValue(U, 0), SB); } - } - - return false; + DCI.CommitTargetLoweringOpt(TLO); + return SDValue(N, 0); } /// Do target-specific dag combines on SELECT and VSELECT nodes. @@ -31268,6 +32596,23 @@ static SDValue combineSelect(SDNode *N, SelectionDAG &DAG, EVT CondVT = Cond.getValueType(); const TargetLowering &TLI = DAG.getTargetLoweringInfo(); + // Convert vselects with constant condition into shuffles. + if (ISD::isBuildVectorOfConstantSDNodes(Cond.getNode()) && + DCI.isBeforeLegalizeOps()) { + SmallVector<int, 64> Mask(VT.getVectorNumElements(), -1); + for (int i = 0, Size = Mask.size(); i != Size; ++i) { + SDValue CondElt = Cond->getOperand(i); + Mask[i] = i; + // Arbitrarily choose from the 2nd operand if the select condition element + // is undef. + // TODO: Can we do better by matching patterns such as even/odd? + if (CondElt.isUndef() || isNullConstant(CondElt)) + Mask[i] += Size; + } + + return DAG.getVectorShuffle(VT, DL, LHS, RHS, Mask); + } + // If we have SSE[12] support, try to form min/max nodes. SSE min/max // instructions match the semantics of the common C idiom x<y?x:y but not // x<=y?x:y, because of how they handle negative zero (which can be @@ -31292,7 +32637,8 @@ static SDValue combineSelect(SDNode *N, SelectionDAG &DAG, // and negative zero incorrectly. if (!DAG.isKnownNeverNaN(LHS) || !DAG.isKnownNeverNaN(RHS)) { if (!DAG.getTarget().Options.UnsafeFPMath && - !(DAG.isKnownNeverZero(LHS) || DAG.isKnownNeverZero(RHS))) + !(DAG.isKnownNeverZeroFloat(LHS) || + DAG.isKnownNeverZeroFloat(RHS))) break; std::swap(LHS, RHS); } @@ -31302,7 +32648,7 @@ static SDValue combineSelect(SDNode *N, SelectionDAG &DAG, // Converting this to a min would handle comparisons between positive // and negative zero incorrectly. if (!DAG.getTarget().Options.UnsafeFPMath && - !DAG.isKnownNeverZero(LHS) && !DAG.isKnownNeverZero(RHS)) + !DAG.isKnownNeverZeroFloat(LHS) && !DAG.isKnownNeverZeroFloat(RHS)) break; Opcode = X86ISD::FMIN; break; @@ -31321,7 +32667,7 @@ static SDValue combineSelect(SDNode *N, SelectionDAG &DAG, // Converting this to a max would handle comparisons between positive // and negative zero incorrectly. 
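// A scalar sketch (plain C++) of the MINSS semantics that force the
// NeverZeroFloat/NeverNaN checks in this switch: MINSS(a, b) behaves like
// "a < b ? a : b" and therefore returns the *second* operand when the inputs
// compare equal (e.g. +0.0 vs -0.0) or when either input is a NaN.
#include <cassert>
#include <cmath>

static double minssModel(double A, double B) { return A < B ? A : B; }

int main() {
  double PZ = +0.0, NZ = -0.0;
  // "x <= y ? x : y" picks +0.0 here, but MINSS would yield -0.0, so the
  // combine must prove neither operand can be a zero before using FMIN.
  assert(std::signbit(minssModel(PZ, NZ)) && !std::signbit(PZ <= NZ ? PZ : NZ));
  // With a NaN operand the compare is false and MINSS returns operand two,
  // which is why operand swaps are guarded by the NeverNaN checks.
  assert(minssModel(std::nan(""), 1.0) == 1.0);
  return 0;
}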
if (!DAG.getTarget().Options.UnsafeFPMath && - !DAG.isKnownNeverZero(LHS) && !DAG.isKnownNeverZero(RHS)) + !DAG.isKnownNeverZeroFloat(LHS) && !DAG.isKnownNeverZeroFloat(RHS)) break; Opcode = X86ISD::FMAX; break; @@ -31331,7 +32677,8 @@ static SDValue combineSelect(SDNode *N, SelectionDAG &DAG, // and negative zero incorrectly. if (!DAG.isKnownNeverNaN(LHS) || !DAG.isKnownNeverNaN(RHS)) { if (!DAG.getTarget().Options.UnsafeFPMath && - !(DAG.isKnownNeverZero(LHS) || DAG.isKnownNeverZero(RHS))) + !(DAG.isKnownNeverZeroFloat(LHS) || + DAG.isKnownNeverZeroFloat(RHS))) break; std::swap(LHS, RHS); } @@ -31358,7 +32705,8 @@ static SDValue combineSelect(SDNode *N, SelectionDAG &DAG, // and negative zero incorrectly, and swapping the operands would // cause it to handle NaNs incorrectly. if (!DAG.getTarget().Options.UnsafeFPMath && - !(DAG.isKnownNeverZero(LHS) || DAG.isKnownNeverZero(RHS))) { + !(DAG.isKnownNeverZeroFloat(LHS) || + DAG.isKnownNeverZeroFloat(RHS))) { if (!DAG.isKnownNeverNaN(LHS) || !DAG.isKnownNeverNaN(RHS)) break; std::swap(LHS, RHS); @@ -31394,7 +32742,8 @@ static SDValue combineSelect(SDNode *N, SelectionDAG &DAG, // and negative zero incorrectly, and swapping the operands would // cause it to handle NaNs incorrectly. if (!DAG.getTarget().Options.UnsafeFPMath && - !DAG.isKnownNeverZero(LHS) && !DAG.isKnownNeverZero(RHS)) { + !DAG.isKnownNeverZeroFloat(LHS) && + !DAG.isKnownNeverZeroFloat(RHS)) { if (!DAG.isKnownNeverNaN(LHS) || !DAG.isKnownNeverNaN(RHS)) break; std::swap(LHS, RHS); @@ -31418,19 +32767,38 @@ static SDValue combineSelect(SDNode *N, SelectionDAG &DAG, return DAG.getNode(Opcode, DL, N->getValueType(0), LHS, RHS); } + // Some mask scalar intrinsics rely on checking if only one bit is set + // and implement it in C code like this: + // A[0] = (U & 1) ? A[0] : W[0]; + // This creates some redundant instructions that break pattern matching. + // fold (select (setcc (and (X, 1), 0, seteq), Y, Z)) -> select(and(X, 1),Z,Y) + if (Subtarget.hasAVX512() && N->getOpcode() == ISD::SELECT && + Cond.getOpcode() == ISD::SETCC && (VT == MVT::f32 || VT == MVT::f64)) { + ISD::CondCode CC = cast<CondCodeSDNode>(Cond.getOperand(2))->get(); + SDValue AndNode = Cond.getOperand(0); + if (AndNode.getOpcode() == ISD::AND && CC == ISD::SETEQ && + isNullConstant(Cond.getOperand(1)) && + isOneConstant(AndNode.getOperand(1))) { + // LHS and RHS swapped due to + // setcc outputting 1 when AND resulted in 0 and vice versa. + AndNode = DAG.getZExtOrTrunc(AndNode, DL, MVT::i8); + return DAG.getNode(ISD::SELECT, DL, VT, AndNode, RHS, LHS); + } + } + // v16i8 (select v16i1, v16i8, v16i8) does not have a proper // lowering on KNL. In this case we convert it to // v16i8 (select v16i8, v16i8, v16i8) and use AVX instruction. - // The same situation for all 128 and 256-bit vectors of i8 and i16. + // The same situation all vectors of i8 and i16 without BWI. + // Make sure we extend these even before type legalization gets a chance to + // split wide vectors. // Since SKX these selects have a proper lowering. 
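// A scalar sketch of the mask-intrinsic idiom folded in the hunk above:
// testing "(U & 1) == 0 ? Y : Z" is the same as selecting directly on
// (U & 1) with the arms swapped, which removes the intermediate SETCC that
// breaks the AVX512 masked-instruction patterns. Helper names are
// illustrative only.
#include <cassert>
#include <cstdint>

static double selOnSetcc(uint8_t U, double Y, double Z) {
  return ((U & 1) == 0) ? Y : Z;   // what the C source expresses
}
static double selFolded(uint8_t U, double Y, double Z) {
  return (U & 1) ? Z : Y;          // the folded form, arms swapped
}

int main() {
  for (int U = 0; U < 4; ++U)
    assert(selOnSetcc((uint8_t)U, 1.5, 2.5) == selFolded((uint8_t)U, 1.5, 2.5));
  return 0;
}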
- if (Subtarget.hasAVX512() && CondVT.isVector() && + if (Subtarget.hasAVX512() && !Subtarget.hasBWI() && CondVT.isVector() && CondVT.getVectorElementType() == MVT::i1 && - (VT.is128BitVector() || VT.is256BitVector()) && + VT.getVectorNumElements() > 4 && (VT.getVectorElementType() == MVT::i8 || - VT.getVectorElementType() == MVT::i16) && - !(Subtarget.hasBWI() && Subtarget.hasVLX())) { + VT.getVectorElementType() == MVT::i16)) { Cond = DAG.getNode(ISD::SIGN_EXTEND, DL, VT, Cond); - DCI.AddToWorklist(Cond.getNode()); return DAG.getNode(N->getOpcode(), DL, VT, Cond, LHS, RHS); } @@ -31476,7 +32844,7 @@ static SDValue combineSelect(SDNode *N, SelectionDAG &DAG, if (N->getOpcode() == ISD::VSELECT && Cond.getOpcode() == ISD::SETCC && // psubus is available in SSE2 and AVX2 for i8 and i16 vectors. ((Subtarget.hasSSE2() && (VT == MVT::v16i8 || VT == MVT::v8i16)) || - (Subtarget.hasAVX2() && (VT == MVT::v32i8 || VT == MVT::v16i16)))) { + (Subtarget.hasAVX() && (VT == MVT::v32i8 || VT == MVT::v16i16)))) { ISD::CondCode CC = cast<CondCodeSDNode>(Cond.getOperand(2))->get(); // Check if one of the arms of the VSELECT is a zero vector. If it's on the @@ -31494,40 +32862,50 @@ static SDValue combineSelect(SDNode *N, SelectionDAG &DAG, SDValue OpLHS = Other->getOperand(0), OpRHS = Other->getOperand(1); SDValue CondRHS = Cond->getOperand(1); + auto SUBUSBuilder = [](SelectionDAG &DAG, const SDLoc &DL, + ArrayRef<SDValue> Ops) { + return DAG.getNode(X86ISD::SUBUS, DL, Ops[0].getValueType(), Ops); + }; + // Look for a general sub with unsigned saturation first. // x >= y ? x-y : 0 --> subus x, y // x > y ? x-y : 0 --> subus x, y if ((CC == ISD::SETUGE || CC == ISD::SETUGT) && Other->getOpcode() == ISD::SUB && DAG.isEqualTo(OpRHS, CondRHS)) - return DAG.getNode(X86ISD::SUBUS, DL, VT, OpLHS, OpRHS); + return SplitOpsAndApply(DAG, Subtarget, DL, VT, { OpLHS, OpRHS }, + SUBUSBuilder); if (auto *OpRHSBV = dyn_cast<BuildVectorSDNode>(OpRHS)) - if (auto *OpRHSConst = OpRHSBV->getConstantSplatNode()) { - if (auto *CondRHSBV = dyn_cast<BuildVectorSDNode>(CondRHS)) - if (auto *CondRHSConst = CondRHSBV->getConstantSplatNode()) - // If the RHS is a constant we have to reverse the const - // canonicalization. - // x > C-1 ? x+-C : 0 --> subus x, C - if (CC == ISD::SETUGT && Other->getOpcode() == ISD::ADD && - CondRHSConst->getAPIntValue() == - (-OpRHSConst->getAPIntValue() - 1)) - return DAG.getNode( - X86ISD::SUBUS, DL, VT, OpLHS, - DAG.getConstant(-OpRHSConst->getAPIntValue(), DL, VT)); + if (isa<BuildVectorSDNode>(CondRHS)) { + // If the RHS is a constant we have to reverse the const + // canonicalization. + // x > C-1 ? x+-C : 0 --> subus x, C + auto MatchSUBUS = [](ConstantSDNode *Op, ConstantSDNode *Cond) { + return Cond->getAPIntValue() == (-Op->getAPIntValue() - 1); + }; + if (CC == ISD::SETUGT && Other->getOpcode() == ISD::ADD && + ISD::matchBinaryPredicate(OpRHS, CondRHS, MatchSUBUS)) { + OpRHS = DAG.getNode(ISD::SUB, DL, VT, + DAG.getConstant(0, DL, VT), OpRHS); + return SplitOpsAndApply(DAG, Subtarget, DL, VT, { OpLHS, OpRHS }, + SUBUSBuilder); + } // Another special case: If C was a sign bit, the sub has been // canonicalized into a xor. // FIXME: Would it be better to use computeKnownBits to determine // whether it's safe to decanonicalize the xor? // x s< 0 ? 
x^C : 0 --> subus x, C - if (CC == ISD::SETLT && Other->getOpcode() == ISD::XOR && - ISD::isBuildVectorAllZeros(CondRHS.getNode()) && - OpRHSConst->getAPIntValue().isSignMask()) - // Note that we have to rebuild the RHS constant here to ensure we - // don't rely on particular values of undef lanes. - return DAG.getNode( - X86ISD::SUBUS, DL, VT, OpLHS, - DAG.getConstant(OpRHSConst->getAPIntValue(), DL, VT)); + if (auto *OpRHSConst = OpRHSBV->getConstantSplatNode()) + if (CC == ISD::SETLT && Other.getOpcode() == ISD::XOR && + ISD::isBuildVectorAllZeros(CondRHS.getNode()) && + OpRHSConst->getAPIntValue().isSignMask()) { + OpRHS = DAG.getConstant(OpRHSConst->getAPIntValue(), DL, VT); + // Note that we have to rebuild the RHS constant here to ensure we + // don't rely on particular values of undef lanes. + return SplitOpsAndApply(DAG, Subtarget, DL, VT, { OpLHS, OpRHS }, + SUBUSBuilder); + } } } } @@ -31535,99 +32913,8 @@ static SDValue combineSelect(SDNode *N, SelectionDAG &DAG, if (SDValue V = combineVSelectWithAllOnesOrZeros(N, DAG, DCI, Subtarget)) return V; - // If this is a *dynamic* select (non-constant condition) and we can match - // this node with one of the variable blend instructions, restructure the - // condition so that blends can use the high (sign) bit of each element and - // use SimplifyDemandedBits to simplify the condition operand. - if (N->getOpcode() == ISD::VSELECT && DCI.isBeforeLegalizeOps() && - !DCI.isBeforeLegalize() && - !ISD::isBuildVectorOfConstantSDNodes(Cond.getNode())) { - unsigned BitWidth = Cond.getScalarValueSizeInBits(); - - // Don't optimize vector selects that map to mask-registers. - if (BitWidth == 1) - return SDValue(); - - // We can only handle the cases where VSELECT is directly legal on the - // subtarget. We custom lower VSELECT nodes with constant conditions and - // this makes it hard to see whether a dynamic VSELECT will correctly - // lower, so we both check the operation's status and explicitly handle the - // cases where a *dynamic* blend will fail even though a constant-condition - // blend could be custom lowered. - // FIXME: We should find a better way to handle this class of problems. - // Potentially, we should combine constant-condition vselect nodes - // pre-legalization into shuffles and not mark as many types as custom - // lowered. - if (!TLI.isOperationLegalOrCustom(ISD::VSELECT, VT)) - return SDValue(); - // FIXME: We don't support i16-element blends currently. We could and - // should support them by making *all* the bits in the condition be set - // rather than just the high bit and using an i8-element blend. - if (VT.getVectorElementType() == MVT::i16) - return SDValue(); - // Dynamic blending was only available from SSE4.1 onward. - if (VT.is128BitVector() && !Subtarget.hasSSE41()) - return SDValue(); - // Byte blends are only available in AVX2 - if (VT == MVT::v32i8 && !Subtarget.hasAVX2()) - return SDValue(); - // There are no 512-bit blend instructions that use sign bits. - if (VT.is512BitVector()) - return SDValue(); - - assert(BitWidth >= 8 && BitWidth <= 64 && "Invalid mask size"); - APInt DemandedMask(APInt::getSignMask(BitWidth)); - KnownBits Known; - TargetLowering::TargetLoweringOpt TLO(DAG, !DCI.isBeforeLegalize(), - !DCI.isBeforeLegalizeOps()); - if (TLI.ShrinkDemandedConstant(Cond, DemandedMask, TLO) || - TLI.SimplifyDemandedBits(Cond, DemandedMask, Known, TLO)) { - // If we changed the computation somewhere in the DAG, this change will - // affect all users of Cond. 
Make sure it is fine and update all the nodes - // so that we do not use the generic VSELECT anymore. Otherwise, we may - // perform wrong optimizations as we messed with the actual expectation - // for the vector boolean values. - if (Cond != TLO.Old) { - // Check all uses of the condition operand to check whether it will be - // consumed by non-BLEND instructions. Those may require that all bits - // are set properly. - for (SDNode *U : Cond->uses()) { - // TODO: Add other opcodes eventually lowered into BLEND. - if (U->getOpcode() != ISD::VSELECT) - return SDValue(); - } - - // Update all users of the condition before committing the change, so - // that the VSELECT optimizations that expect the correct vector boolean - // value will not be triggered. - for (SDNode *U : Cond->uses()) { - SDValue SB = DAG.getNode(X86ISD::SHRUNKBLEND, SDLoc(U), - U->getValueType(0), Cond, U->getOperand(1), - U->getOperand(2)); - DAG.ReplaceAllUsesOfValueWith(SDValue(U, 0), SB); - } - DCI.CommitTargetLoweringOpt(TLO); - return SDValue(); - } - // Only Cond (rather than other nodes in the computation chain) was - // changed. Change the condition just for N to keep the opportunity to - // optimize all other users their own way. - SDValue SB = DAG.getNode(X86ISD::SHRUNKBLEND, DL, VT, TLO.New, LHS, RHS); - DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), SB); - return SDValue(); - } - } - - // Look for vselects with LHS/RHS being bitcasted from an operation that - // can be executed on another type. Push the bitcast to the inputs of - // the operation. This exposes opportunities for using masking instructions. - if (N->getOpcode() == ISD::VSELECT && DCI.isAfterLegalizeVectorOps() && - CondVT.getVectorElementType() == MVT::i1) { - if (combineBitcastForMaskedOp(LHS, DAG, DCI)) - return SDValue(N, 0); - if (combineBitcastForMaskedOp(RHS, DAG, DCI)) - return SDValue(N, 0); - } + if (SDValue V = combineVSelectToShrunkBlend(N, DAG, DCI, Subtarget)) + return V; // Custom action for SELECT MMX if (VT == MVT::x86mmx) { @@ -31969,17 +33256,6 @@ static SDValue combineCMov(SDNode *N, SelectionDAG &DAG, X86::CondCode CC = (X86::CondCode)N->getConstantOperandVal(2); SDValue Cond = N->getOperand(3); - if (CC == X86::COND_E || CC == X86::COND_NE) { - switch (Cond.getOpcode()) { - default: break; - case X86ISD::BSR: - case X86ISD::BSF: - // If operand of BSR / BSF are proven never zero, then ZF cannot be set. - if (DAG.isKnownNeverZero(Cond.getOperand(0))) - return (CC == X86::COND_E) ? FalseOp : TrueOp; - } - } - // Try to simplify the EFLAGS and condition code operands. // We can't always do this as FCMOV only supports a subset of X86 cond. if (SDValue Flags = combineSetCCEFLAGS(Cond, CC, DAG, Subtarget)) { @@ -32149,6 +33425,36 @@ static SDValue combineCMov(SDNode *N, SelectionDAG &DAG, } } + // Handle (CMOV C-1, (ADD (CTTZ X), C), (X != 0)) -> + // (ADD (CMOV (CTTZ X), -1, (X != 0)), C) or + // (CMOV (ADD (CTTZ X), C), C-1, (X == 0)) -> + // (ADD (CMOV C-1, (CTTZ X), (X == 0)), C) + if (CC == X86::COND_NE || CC == X86::COND_E) { + auto *Cnst = CC == X86::COND_E ? dyn_cast<ConstantSDNode>(TrueOp) + : dyn_cast<ConstantSDNode>(FalseOp); + SDValue Add = CC == X86::COND_E ? 
FalseOp : TrueOp; + + if (Cnst && Add.getOpcode() == ISD::ADD && Add.hasOneUse()) { + auto *AddOp1 = dyn_cast<ConstantSDNode>(Add.getOperand(1)); + SDValue AddOp2 = Add.getOperand(0); + if (AddOp1 && (AddOp2.getOpcode() == ISD::CTTZ_ZERO_UNDEF || + AddOp2.getOpcode() == ISD::CTTZ)) { + APInt Diff = Cnst->getAPIntValue() - AddOp1->getAPIntValue(); + if (CC == X86::COND_E) { + Add = DAG.getNode(X86ISD::CMOV, DL, Add.getValueType(), AddOp2, + DAG.getConstant(Diff, DL, Add.getValueType()), + DAG.getConstant(CC, DL, MVT::i8), Cond); + } else { + Add = DAG.getNode(X86ISD::CMOV, DL, Add.getValueType(), + DAG.getConstant(Diff, DL, Add.getValueType()), + AddOp2, DAG.getConstant(CC, DL, MVT::i8), Cond); + } + return DAG.getNode(X86ISD::ADD, DL, Add.getValueType(), Add, + SDValue(AddOp1, 0)); + } + } + } + return SDValue(); } @@ -32276,13 +33582,6 @@ static SDValue reduceVMULWidth(SDNode *N, SelectionDAG &DAG, if ((NumElts % 2) != 0) return SDValue(); - // If the upper 17 bits of each element are zero then we can use PMADD. - APInt Mask17 = APInt::getHighBitsSet(32, 17); - if (VT == MVT::v4i32 && DAG.MaskedValueIsZero(N0, Mask17) && - DAG.MaskedValueIsZero(N1, Mask17)) - return DAG.getNode(X86ISD::VPMADDWD, DL, VT, DAG.getBitcast(MVT::v8i16, N0), - DAG.getBitcast(MVT::v8i16, N1)); - unsigned RegSize = 128; MVT OpsVT = MVT::getVectorVT(MVT::i16, RegSize / 16); EVT ReducedVT = EVT::getVectorVT(*DAG.getContext(), MVT::i16, NumElts); @@ -32378,7 +33677,7 @@ static SDValue reduceVMULWidth(SDNode *N, SelectionDAG &DAG, } static SDValue combineMulSpecial(uint64_t MulAmt, SDNode *N, SelectionDAG &DAG, - EVT VT, SDLoc DL) { + EVT VT, const SDLoc &DL) { auto combineMulShlAddOrSub = [&](int Mult, int Shift, bool isAdd) { SDValue Result = DAG.getNode(X86ISD::MUL_IMM, DL, VT, N->getOperand(0), @@ -32390,10 +33689,11 @@ static SDValue combineMulSpecial(uint64_t MulAmt, SDNode *N, SelectionDAG &DAG, return Result; }; - auto combineMulMulAddOrSub = [&](bool isAdd) { + auto combineMulMulAddOrSub = [&](int Mul1, int Mul2, bool isAdd) { SDValue Result = DAG.getNode(X86ISD::MUL_IMM, DL, VT, N->getOperand(0), - DAG.getConstant(9, DL, VT)); - Result = DAG.getNode(ISD::MUL, DL, VT, Result, DAG.getConstant(3, DL, VT)); + DAG.getConstant(Mul1, DL, VT)); + Result = DAG.getNode(X86ISD::MUL_IMM, DL, VT, Result, + DAG.getConstant(Mul2, DL, VT)); Result = DAG.getNode(isAdd ? 
ISD::ADD : ISD::SUB, DL, VT, Result, N->getOperand(0)); return Result; @@ -32408,43 +33708,137 @@ static SDValue combineMulSpecial(uint64_t MulAmt, SDNode *N, SelectionDAG &DAG, case 21: // mul x, 21 => add ((shl (mul x, 5), 2), x) return combineMulShlAddOrSub(5, 2, /*isAdd*/ true); + case 41: + // mul x, 41 => add ((shl (mul x, 5), 3), x) + return combineMulShlAddOrSub(5, 3, /*isAdd*/ true); case 22: // mul x, 22 => add (add ((shl (mul x, 5), 2), x), x) return DAG.getNode(ISD::ADD, DL, VT, N->getOperand(0), combineMulShlAddOrSub(5, 2, /*isAdd*/ true)); case 19: - // mul x, 19 => sub ((shl (mul x, 5), 2), x) - return combineMulShlAddOrSub(5, 2, /*isAdd*/ false); + // mul x, 19 => add ((shl (mul x, 9), 1), x) + return combineMulShlAddOrSub(9, 1, /*isAdd*/ true); + case 37: + // mul x, 37 => add ((shl (mul x, 9), 2), x) + return combineMulShlAddOrSub(9, 2, /*isAdd*/ true); + case 73: + // mul x, 73 => add ((shl (mul x, 9), 3), x) + return combineMulShlAddOrSub(9, 3, /*isAdd*/ true); case 13: // mul x, 13 => add ((shl (mul x, 3), 2), x) return combineMulShlAddOrSub(3, 2, /*isAdd*/ true); case 23: - // mul x, 13 => sub ((shl (mul x, 3), 3), x) + // mul x, 23 => sub ((shl (mul x, 3), 3), x) return combineMulShlAddOrSub(3, 3, /*isAdd*/ false); - case 14: - // mul x, 14 => add (add ((shl (mul x, 3), 2), x), x) - return DAG.getNode(ISD::ADD, DL, VT, N->getOperand(0), - combineMulShlAddOrSub(3, 2, /*isAdd*/ true)); case 26: - // mul x, 26 => sub ((mul (mul x, 9), 3), x) - return combineMulMulAddOrSub(/*isAdd*/ false); + // mul x, 26 => add ((mul (mul x, 5), 5), x) + return combineMulMulAddOrSub(5, 5, /*isAdd*/ true); case 28: // mul x, 28 => add ((mul (mul x, 9), 3), x) - return combineMulMulAddOrSub(/*isAdd*/ true); + return combineMulMulAddOrSub(9, 3, /*isAdd*/ true); case 29: // mul x, 29 => add (add ((mul (mul x, 9), 3), x), x) return DAG.getNode(ISD::ADD, DL, VT, N->getOperand(0), - combineMulMulAddOrSub(/*isAdd*/ true)); - case 30: - // mul x, 30 => sub (sub ((shl x, 5), x), x) - return DAG.getNode( - ISD::SUB, DL, VT, - DAG.getNode(ISD::SUB, DL, VT, - DAG.getNode(ISD::SHL, DL, VT, N->getOperand(0), - DAG.getConstant(5, DL, MVT::i8)), - N->getOperand(0)), - N->getOperand(0)); + combineMulMulAddOrSub(9, 3, /*isAdd*/ true)); + } + + // Another trick. If this is a power 2 + 2/4/8, we can use a shift followed + // by a single LEA. + // First check if this a sum of two power of 2s because that's easy. Then + // count how many zeros are up to the first bit. + // TODO: We can do this even without LEA at a cost of two shifts and an add. + if (isPowerOf2_64(MulAmt & (MulAmt - 1))) { + unsigned ScaleShift = countTrailingZeros(MulAmt); + if (ScaleShift >= 1 && ScaleShift < 4) { + unsigned ShiftAmt = Log2_64((MulAmt & (MulAmt - 1))); + SDValue Shift1 = DAG.getNode(ISD::SHL, DL, VT, N->getOperand(0), + DAG.getConstant(ShiftAmt, DL, MVT::i8)); + SDValue Shift2 = DAG.getNode(ISD::SHL, DL, VT, N->getOperand(0), + DAG.getConstant(ScaleShift, DL, MVT::i8)); + return DAG.getNode(ISD::ADD, DL, VT, Shift1, Shift2); + } + } + + return SDValue(); +} + +// If the upper 17 bits of each element are zero then we can use PMADDWD, +// which is always at least as quick as PMULLD, expect on KNL. +static SDValue combineMulToPMADDWD(SDNode *N, SelectionDAG &DAG, + const X86Subtarget &Subtarget) { + if (!Subtarget.hasSSE2()) + return SDValue(); + + if (Subtarget.getProcFamily() == X86Subtarget::IntelKNL) + return SDValue(); + + EVT VT = N->getValueType(0); + + // Only support vXi32 vectors. 
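// A scalar sketch of the trick combineMulToPMADDWD relies on: when the top
// 17 bits of each i32 lane are zero, the lane is a non-negative value below
// 2^15, so PMADDWD's "lo*lo + hi*hi" over signed i16 halves reduces to the
// plain 32-bit product (the hi halves are zero and nothing can overflow).
// Plain C++ model, not the DAG code.
#include <cassert>
#include <cstdint>

static int32_t pmaddwdLane(uint32_t A, uint32_t B) {
  int16_t ALo = (int16_t)(A & 0xffff), AHi = (int16_t)(A >> 16);
  int16_t BLo = (int16_t)(B & 0xffff), BHi = (int16_t)(B >> 16);
  return (int32_t)ALo * BLo + (int32_t)AHi * BHi;   // PMADDWD, per i32 lane
}

int main() {
  const uint32_t As[] = {0u, 1u, 1000u, 32767u};    // all below 2^15
  const uint32_t Bs[] = {0u, 3u, 20000u, 32767u};
  for (uint32_t A : As)
    for (uint32_t B : Bs)
      assert(pmaddwdLane(A, B) == (int32_t)(A * B));
  return 0;
}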
+ if (!VT.isVector() || VT.getVectorElementType() != MVT::i32) + return SDValue(); + + // Make sure the vXi16 type is legal. This covers the AVX512 without BWI case. + MVT WVT = MVT::getVectorVT(MVT::i16, 2 * VT.getVectorNumElements()); + if (!DAG.getTargetLoweringInfo().isTypeLegal(WVT)) + return SDValue(); + + SDValue N0 = N->getOperand(0); + SDValue N1 = N->getOperand(1); + APInt Mask17 = APInt::getHighBitsSet(32, 17); + if (!DAG.MaskedValueIsZero(N1, Mask17) || + !DAG.MaskedValueIsZero(N0, Mask17)) + return SDValue(); + + // Use SplitOpsAndApply to handle AVX splitting. + auto PMADDWDBuilder = [](SelectionDAG &DAG, const SDLoc &DL, + ArrayRef<SDValue> Ops) { + MVT VT = MVT::getVectorVT(MVT::i32, Ops[0].getValueSizeInBits() / 32); + return DAG.getNode(X86ISD::VPMADDWD, DL, VT, Ops); + }; + return SplitOpsAndApply(DAG, Subtarget, SDLoc(N), VT, + { DAG.getBitcast(WVT, N0), DAG.getBitcast(WVT, N1) }, + PMADDWDBuilder); +} + +static SDValue combineMulToPMULDQ(SDNode *N, SelectionDAG &DAG, + const X86Subtarget &Subtarget) { + if (!Subtarget.hasSSE2()) + return SDValue(); + + EVT VT = N->getValueType(0); + + // Only support vXi64 vectors. + if (!VT.isVector() || VT.getVectorElementType() != MVT::i64 || + !DAG.getTargetLoweringInfo().isTypeLegal(VT)) + return SDValue(); + + SDValue N0 = N->getOperand(0); + SDValue N1 = N->getOperand(1); + + // MULDQ returns the 64-bit result of the signed multiplication of the lower + // 32-bits. We can lower with this if the sign bits stretch that far. + if (Subtarget.hasSSE41() && DAG.ComputeNumSignBits(N0) > 32 && + DAG.ComputeNumSignBits(N1) > 32) { + auto PMULDQBuilder = [](SelectionDAG &DAG, const SDLoc &DL, + ArrayRef<SDValue> Ops) { + return DAG.getNode(X86ISD::PMULDQ, DL, Ops[0].getValueType(), Ops); + }; + return SplitOpsAndApply(DAG, Subtarget, SDLoc(N), VT, { N0, N1 }, + PMULDQBuilder, /*CheckBWI*/false); } + + // If the upper bits are zero we can use a single pmuludq. 
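// A scalar sketch of the reasoning in combineMulToPMULDQ: if both i64
// operands are sign-extensions of their low 32 bits (more than 32 sign
// bits), the full 64-bit product equals the 64-bit product of those 32-bit
// halves, which is exactly what PMULDQ computes per lane; PMULUDQ is the
// analogous case when the upper 32 bits are known zero.
#include <cassert>
#include <cstdint>

static int64_t pmuldqLane(int64_t A, int64_t B) {
  return (int64_t)(int32_t)A * (int64_t)(int32_t)B;     // PMULDQ model
}
static uint64_t pmuludqLane(uint64_t A, uint64_t B) {
  return (uint64_t)(uint32_t)A * (uint64_t)(uint32_t)B; // PMULUDQ model
}

int main() {
  const int64_t As[] = {-5, 7, -2000000000LL};          // all fit in i32
  const int64_t Bs[] = {3, -123456, 2000000000LL};
  for (int64_t A : As)
    for (int64_t B : Bs)
      assert(pmuldqLane(A, B) == A * B);
  assert(pmuludqLane(0xffffffffULL, 2) == 0x1fffffffeULL);
  return 0;
}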
+ APInt Mask = APInt::getHighBitsSet(64, 32); + if (DAG.MaskedValueIsZero(N0, Mask) && DAG.MaskedValueIsZero(N1, Mask)) { + auto PMULUDQBuilder = [](SelectionDAG &DAG, const SDLoc &DL, + ArrayRef<SDValue> Ops) { + return DAG.getNode(X86ISD::PMULUDQ, DL, Ops[0].getValueType(), Ops); + }; + return SplitOpsAndApply(DAG, Subtarget, SDLoc(N), VT, { N0, N1 }, + PMULUDQBuilder, /*CheckBWI*/false); + } + return SDValue(); } @@ -32454,6 +33848,13 @@ static SDValue combineMul(SDNode *N, SelectionDAG &DAG, TargetLowering::DAGCombinerInfo &DCI, const X86Subtarget &Subtarget) { EVT VT = N->getValueType(0); + + if (SDValue V = combineMulToPMADDWD(N, DAG, Subtarget)) + return V; + + if (SDValue V = combineMulToPMULDQ(N, DAG, Subtarget)) + return V; + if (DCI.isBeforeLegalize() && VT.isVector()) return reduceVMULWidth(N, DAG, Subtarget); @@ -32473,9 +33874,14 @@ static SDValue combineMul(SDNode *N, SelectionDAG &DAG, if (!C) return SDValue(); uint64_t MulAmt = C->getZExtValue(); - if (isPowerOf2_64(MulAmt) || MulAmt == 3 || MulAmt == 5 || MulAmt == 9) + if (isPowerOf2_64(MulAmt)) return SDValue(); + SDLoc DL(N); + if (MulAmt == 3 || MulAmt == 5 || MulAmt == 9) + return DAG.getNode(X86ISD::MUL_IMM, DL, VT, N->getOperand(0), + N->getOperand(1)); + uint64_t MulAmt1 = 0; uint64_t MulAmt2 = 0; if ((MulAmt % 9) == 0) { @@ -32489,7 +33895,6 @@ static SDValue combineMul(SDNode *N, SelectionDAG &DAG, MulAmt2 = MulAmt / 3; } - SDLoc DL(N); SDValue NewMul; if (MulAmt2 && (isPowerOf2_64(MulAmt2) || MulAmt2 == 3 || MulAmt2 == 5 || MulAmt2 == 9)){ @@ -32523,39 +33928,47 @@ static SDValue combineMul(SDNode *N, SelectionDAG &DAG, "Both cases that could cause potential overflows should have " "already been handled."); int64_t SignMulAmt = C->getSExtValue(); - if ((SignMulAmt != INT64_MIN) && (SignMulAmt != INT64_MAX) && - (SignMulAmt != -INT64_MAX)) { - int NumSign = SignMulAmt > 0 ? 1 : -1; - bool IsPowerOf2_64PlusOne = isPowerOf2_64(NumSign * SignMulAmt - 1); - bool IsPowerOf2_64MinusOne = isPowerOf2_64(NumSign * SignMulAmt + 1); - if (IsPowerOf2_64PlusOne) { - // (mul x, 2^N + 1) => (add (shl x, N), x) - NewMul = DAG.getNode( - ISD::ADD, DL, VT, N->getOperand(0), - DAG.getNode(ISD::SHL, DL, VT, N->getOperand(0), - DAG.getConstant(Log2_64(NumSign * SignMulAmt - 1), DL, - MVT::i8))); - } else if (IsPowerOf2_64MinusOne) { - // (mul x, 2^N - 1) => (sub (shl x, N), x) - NewMul = DAG.getNode( - ISD::SUB, DL, VT, - DAG.getNode(ISD::SHL, DL, VT, N->getOperand(0), - DAG.getConstant(Log2_64(NumSign * SignMulAmt + 1), DL, - MVT::i8)), - N->getOperand(0)); - } + assert(SignMulAmt != INT64_MIN && "Int min should have been handled!"); + uint64_t AbsMulAmt = SignMulAmt < 0 ? -SignMulAmt : SignMulAmt; + if (isPowerOf2_64(AbsMulAmt - 1)) { + // (mul x, 2^N + 1) => (add (shl x, N), x) + NewMul = DAG.getNode( + ISD::ADD, DL, VT, N->getOperand(0), + DAG.getNode(ISD::SHL, DL, VT, N->getOperand(0), + DAG.getConstant(Log2_64(AbsMulAmt - 1), DL, + MVT::i8))); // To negate, subtract the number from zero - if ((IsPowerOf2_64PlusOne || IsPowerOf2_64MinusOne) && NumSign == -1) - NewMul = - DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT), NewMul); + if (SignMulAmt < 0) + NewMul = DAG.getNode(ISD::SUB, DL, VT, + DAG.getConstant(0, DL, VT), NewMul); + } else if (isPowerOf2_64(AbsMulAmt + 1)) { + // (mul x, 2^N - 1) => (sub (shl x, N), x) + NewMul = DAG.getNode(ISD::SHL, DL, VT, N->getOperand(0), + DAG.getConstant(Log2_64(AbsMulAmt + 1), + DL, MVT::i8)); + // To negate, reverse the operands of the subtract. 
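// A scalar sketch (plain C++; helper names are illustrative) of the
// constant-multiply decompositions built in this hunk: 2^N + 1 becomes
// (x << N) + x, 2^N - 1 becomes (x << N) - x, and a negative amount is
// handled either by subtracting the result from zero or by reversing the
// subtraction.
#include <cassert>
#include <cstdint>

// Shift helper that stays well-defined for negative values (two's
// complement model).
static int64_t shl(int64_t X, unsigned N) {
  return (int64_t)((uint64_t)X << N);
}

int main() {
  const int64_t Xs[] = {-7, 0, 3, 12345};
  for (int64_t X : Xs) {
    assert(shl(X, 3) + X == X * 9);          // 2^3 + 1
    assert(shl(X, 5) - X == X * 31);         // 2^5 - 1
    assert(0 - (shl(X, 3) + X) == X * -9);   // negate: subtract from zero
    assert(X - shl(X, 5) == X * -31);        // negate: reversed subtraction
  }
  return 0;
}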
+ if (SignMulAmt < 0) + NewMul = DAG.getNode(ISD::SUB, DL, VT, N->getOperand(0), NewMul); + else + NewMul = DAG.getNode(ISD::SUB, DL, VT, NewMul, N->getOperand(0)); + } else if (SignMulAmt >= 0 && isPowerOf2_64(AbsMulAmt - 2)) { + // (mul x, 2^N + 2) => (add (add (shl x, N), x), x) + NewMul = DAG.getNode(ISD::SHL, DL, VT, N->getOperand(0), + DAG.getConstant(Log2_64(AbsMulAmt - 2), + DL, MVT::i8)); + NewMul = DAG.getNode(ISD::ADD, DL, VT, NewMul, N->getOperand(0)); + NewMul = DAG.getNode(ISD::ADD, DL, VT, NewMul, N->getOperand(0)); + } else if (SignMulAmt >= 0 && isPowerOf2_64(AbsMulAmt + 2)) { + // (mul x, 2^N - 2) => (sub (sub (shl x, N), x), x) + NewMul = DAG.getNode(ISD::SHL, DL, VT, N->getOperand(0), + DAG.getConstant(Log2_64(AbsMulAmt + 2), + DL, MVT::i8)); + NewMul = DAG.getNode(ISD::SUB, DL, VT, NewMul, N->getOperand(0)); + NewMul = DAG.getNode(ISD::SUB, DL, VT, NewMul, N->getOperand(0)); } } - if (NewMul) - // Do not add new nodes to DAG combiner worklist. - DCI.CombineTo(N, NewMul, false); - - return SDValue(); + return NewMul; } static SDValue combineShiftLeft(SDNode *N, SelectionDAG &DAG) { @@ -32670,11 +34083,17 @@ static SDValue combineShiftRightArithmetic(SDNode *N, SelectionDAG &DAG) { return SDValue(); } -static SDValue combineShiftRightLogical(SDNode *N, SelectionDAG &DAG) { +static SDValue combineShiftRightLogical(SDNode *N, SelectionDAG &DAG, + TargetLowering::DAGCombinerInfo &DCI) { SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); EVT VT = N0.getValueType(); + // Only do this on the last DAG combine as it can interfere with other + // combines. + if (!DCI.isAfterLegalizeDAG()) + return SDValue(); + // Try to improve a sequence of srl (and X, C1), C2 by inverting the order. // TODO: This is a generic DAG combine that became an x86-only combine to // avoid shortcomings in other folds such as bswap, bit-test ('bt'), and @@ -32691,6 +34110,14 @@ static SDValue combineShiftRightLogical(SDNode *N, SelectionDAG &DAG) { // transform should reduce code size. It may also enable secondary transforms // from improved known-bits analysis or instruction selection. APInt MaskVal = AndC->getAPIntValue(); + + // If this can be matched by a zero extend, don't optimize. + if (MaskVal.isMask()) { + unsigned TO = MaskVal.countTrailingOnes(); + if (TO >= 8 && isPowerOf2_32(TO)) + return SDValue(); + } + APInt NewMaskVal = MaskVal.lshr(ShiftC->getAPIntValue()); unsigned OldMaskSize = MaskVal.getMinSignedBits(); unsigned NewMaskSize = NewMaskVal.getMinSignedBits(); @@ -32717,7 +34144,7 @@ static SDValue combineShift(SDNode* N, SelectionDAG &DAG, return V; if (N->getOpcode() == ISD::SRL) - if (SDValue V = combineShiftRightLogical(N, DAG)) + if (SDValue V = combineShiftRightLogical(N, DAG, DCI)) return V; return SDValue(); @@ -32797,12 +34224,10 @@ static SDValue combineVectorPack(SDNode *N, SelectionDAG &DAG, // Attempt to combine as shuffle. 
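// A scalar sketch of the reassociation weighed in combineShiftRightLogical
// above: logically shifting an AND result is the same as ANDing the shifted
// value with the shifted mask, i.e. (x & C1) >> C2 == (x >> C2) & (C1 >> C2).
// The combine now bails out when C1 is a low-bit run of at least 8 ones with
// power-of-two length, since that form is already a free zero-extend (movzx).
#include <cassert>
#include <cstdint>

int main() {
  uint64_t C1 = 0x0000ffff00ULL;
  unsigned C2 = 8;
  const uint64_t Xs[] = {0x123456789aULL, 0xffffffffffULL, 0x8000000000ULL};
  for (uint64_t X : Xs)
    assert(((X & C1) >> C2) == ((X >> C2) & (C1 >> C2)));
  return 0;
}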
SDValue Op(N, 0); - if (SDValue Res = combineX86ShufflesRecursively( - {Op}, 0, Op, {0}, {}, /*Depth*/ 1, - /*HasVarMask*/ false, DAG, DCI, Subtarget)) { - DCI.CombineTo(N, Res); - return SDValue(); - } + if (SDValue Res = + combineX86ShufflesRecursively({Op}, 0, Op, {0}, {}, /*Depth*/ 1, + /*HasVarMask*/ false, DAG, Subtarget)) + return Res; return SDValue(); } @@ -32861,10 +34286,8 @@ static SDValue combineVectorShiftImm(SDNode *N, SelectionDAG &DAG, SDValue Op(N, 0); if (SDValue Res = combineX86ShufflesRecursively( {Op}, 0, Op, {0}, {}, /*Depth*/ 1, - /*HasVarMask*/ false, DAG, DCI, Subtarget)) { - DCI.CombineTo(N, Res); - return SDValue(); - } + /*HasVarMask*/ false, DAG, Subtarget)) + return Res; } // Constant Folding. @@ -32900,12 +34323,10 @@ static SDValue combineVectorInsert(SDNode *N, SelectionDAG &DAG, // Attempt to combine PINSRB/PINSRW patterns to a shuffle. SDValue Op(N, 0); - if (SDValue Res = combineX86ShufflesRecursively( - {Op}, 0, Op, {0}, {}, /*Depth*/ 1, - /*HasVarMask*/ false, DAG, DCI, Subtarget)) { - DCI.CombineTo(N, Res); - return SDValue(); - } + if (SDValue Res = + combineX86ShufflesRecursively({Op}, 0, Op, {0}, {}, /*Depth*/ 1, + /*HasVarMask*/ false, DAG, Subtarget)) + return Res; return SDValue(); } @@ -32973,9 +34394,13 @@ static SDValue combineCompareEqual(SDNode *N, SelectionDAG &DAG, SDValue FSetCC = DAG.getNode(X86ISD::FSETCCM, DL, MVT::v1i1, CMP00, CMP01, DAG.getConstant(x86cc, DL, MVT::i8)); - return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, - N->getSimpleValueType(0), FSetCC, - DAG.getIntPtrConstant(0, DL)); + // Need to fill with zeros to ensure the bitcast will produce zeroes + // for the upper bits. An EXTRACT_ELEMENT here wouldn't guarantee that. + SDValue Ins = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, MVT::v16i1, + DAG.getConstant(0, DL, MVT::v16i1), + FSetCC, DAG.getIntPtrConstant(0, DL)); + return DAG.getZExtOrTrunc(DAG.getBitcast(MVT::i16, Ins), DL, + N->getSimpleValueType(0)); } SDValue OnesOrZeroesF = DAG.getNode(X86ISD::FSETCC, DL, CMP00.getValueType(), CMP00, CMP01, @@ -33012,25 +34437,40 @@ static SDValue combineCompareEqual(SDNode *N, SelectionDAG &DAG, return SDValue(); } +// Try to match (and (xor X, -1), Y) logic pattern for (andnp X, Y) combines. +static bool matchANDXORWithAllOnesAsANDNP(SDNode *N, SDValue &X, SDValue &Y) { + if (N->getOpcode() != ISD::AND) + return false; + + SDValue N0 = N->getOperand(0); + SDValue N1 = N->getOperand(1); + if (N0.getOpcode() == ISD::XOR && + ISD::isBuildVectorAllOnes(N0.getOperand(1).getNode())) { + X = N0.getOperand(0); + Y = N1; + return true; + } + if (N1.getOpcode() == ISD::XOR && + ISD::isBuildVectorAllOnes(N1.getOperand(1).getNode())) { + X = N1.getOperand(0); + Y = N0; + return true; + } + + return false; +} + /// Try to fold: (and (xor X, -1), Y) -> (andnp X, Y). 
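// A scalar sketch of the pattern matchANDXORWithAllOnesAsANDNP looks for:
// (x ^ all-ones) & y is the "and-not" ~x & y that PANDN/VPANDN computes
// directly, with the inverted value as the first operand. Plain C++ model.
#include <cassert>
#include <cstdint>

static uint64_t andnpLane(uint64_t X, uint64_t Y) { return ~X & Y; }

int main() {
  const uint64_t Xs[] = {0ULL, 0xff00ff00ff00ff00ULL, ~0ULL};
  const uint64_t Ys[] = {0x123456789abcdef0ULL, ~0ULL};
  for (uint64_t X : Xs)
    for (uint64_t Y : Ys)
      // XOR with all-ones is bitwise NOT, so the pair folds to ANDNP(X, Y).
      assert(((X ^ ~0ULL) & Y) == andnpLane(X, Y));
  return 0;
}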
static SDValue combineANDXORWithAllOnesIntoANDNP(SDNode *N, SelectionDAG &DAG) { assert(N->getOpcode() == ISD::AND); EVT VT = N->getValueType(0); - SDValue N0 = N->getOperand(0); - SDValue N1 = N->getOperand(1); - SDLoc DL(N); - if (VT != MVT::v2i64 && VT != MVT::v4i64 && VT != MVT::v8i64) return SDValue(); - if (N0.getOpcode() == ISD::XOR && - ISD::isBuildVectorAllOnes(N0.getOperand(1).getNode())) - return DAG.getNode(X86ISD::ANDNP, DL, VT, N0.getOperand(0), N1); - - if (N1.getOpcode() == ISD::XOR && - ISD::isBuildVectorAllOnes(N1.getOperand(1).getNode())) - return DAG.getNode(X86ISD::ANDNP, DL, VT, N1.getOperand(0), N0); + SDValue X, Y; + if (matchANDXORWithAllOnesAsANDNP(N, X, Y)) + return DAG.getNode(X86ISD::ANDNP, SDLoc(N), VT, X, Y); return SDValue(); } @@ -33042,8 +34482,7 @@ static SDValue combineANDXORWithAllOnesIntoANDNP(SDNode *N, SelectionDAG &DAG) { // Even with AVX-512 this is still useful for removing casts around logical // operations on vXi1 mask types. static SDValue WidenMaskArithmetic(SDNode *N, SelectionDAG &DAG, - TargetLowering::DAGCombinerInfo &DCI, - const X86Subtarget &Subtarget) { + const X86Subtarget &Subtarget) { EVT VT = N->getValueType(0); assert(VT.isVector() && "Expected vector type"); @@ -33214,7 +34653,7 @@ static bool hasBZHI(const X86Subtarget &Subtarget, MVT VT) { // It's equivalent to performing bzhi (zero high bits) on the input, with the // same index of the load. static SDValue combineAndLoadToBZHI(SDNode *Node, SelectionDAG &DAG, - const X86Subtarget &Subtarget) { + const X86Subtarget &Subtarget) { MVT VT = Node->getSimpleValueType(0); SDLoc dl(Node); @@ -33269,15 +34708,16 @@ static SDValue combineAndLoadToBZHI(SDNode *Node, SelectionDAG &DAG, // <- (and (srl 0xFFFFFFFF, (sub 32, idx))) // that will be replaced with one bzhi instruction. SDValue Inp = (i == 0) ? Node->getOperand(1) : Node->getOperand(0); - SDValue SizeC = DAG.getConstant(VT.getSizeInBits(), dl, VT); + SDValue SizeC = DAG.getConstant(VT.getSizeInBits(), dl, MVT::i32); // Get the Node which indexes into the array. SDValue Index = getIndexFromUnindexedLoad(Ld); if (!Index) return SDValue(); - Index = DAG.getZExtOrTrunc(Index, dl, VT); + Index = DAG.getZExtOrTrunc(Index, dl, MVT::i32); - SDValue Sub = DAG.getNode(ISD::SUB, dl, VT, SizeC, Index); + SDValue Sub = DAG.getNode(ISD::SUB, dl, MVT::i32, SizeC, Index); + Sub = DAG.getNode(ISD::TRUNCATE, dl, MVT::i8, Sub); SDValue AllOnes = DAG.getAllOnesConstant(dl, VT); SDValue LShr = DAG.getNode(ISD::SRL, dl, VT, AllOnes, Sub); @@ -33303,6 +34743,20 @@ static SDValue combineAnd(SDNode *N, SelectionDAG &DAG, DAG.getBitcast(MVT::v4f32, N->getOperand(1)))); } + // Use a 32-bit and+zext if upper bits known zero. 
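// A scalar sketch of the narrowing added just below: when either i64 operand
// of an AND has its upper 32 bits known to be zero, the whole AND can be
// done in 32 bits and the result zero-extended, relying on the fact that
// 32-bit operations implicitly zero the upper half on x86-64.
#include <cassert>
#include <cstdint>

int main() {
  uint64_t A = 0x123456789abcdef0ULL;
  uint64_t B = 0x00000000f0f0f0f0ULL;               // upper 32 bits are zero
  uint64_t Narrow = (uint64_t)((uint32_t)A & (uint32_t)B);
  assert(Narrow == (A & B));                        // same result, narrower op
  return 0;
}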
+ if (VT == MVT::i64 && Subtarget.is64Bit() && + !isa<ConstantSDNode>(N->getOperand(1))) { + APInt HiMask = APInt::getHighBitsSet(64, 32); + if (DAG.MaskedValueIsZero(N->getOperand(1), HiMask) || + DAG.MaskedValueIsZero(N->getOperand(0), HiMask)) { + SDLoc dl(N); + SDValue LHS = DAG.getNode(ISD::TRUNCATE, dl, MVT::i32, N->getOperand(0)); + SDValue RHS = DAG.getNode(ISD::TRUNCATE, dl, MVT::i32, N->getOperand(1)); + return DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i64, + DAG.getNode(ISD::AND, dl, MVT::i32, LHS, RHS)); + } + } + if (DCI.isBeforeLegalizeOps()) return SDValue(); @@ -33326,10 +34780,8 @@ static SDValue combineAnd(SDNode *N, SelectionDAG &DAG, SDValue Op(N, 0); if (SDValue Res = combineX86ShufflesRecursively( {Op}, 0, Op, {0}, {}, /*Depth*/ 1, - /*HasVarMask*/ false, DAG, DCI, Subtarget)) { - DCI.CombineTo(N, Res); - return SDValue(); - } + /*HasVarMask*/ false, DAG, Subtarget)) + return Res; } // Attempt to combine a scalar bitmask AND with an extracted shuffle. @@ -33365,7 +34817,7 @@ static SDValue combineAnd(SDNode *N, SelectionDAG &DAG, if (SDValue Shuffle = combineX86ShufflesRecursively( {SrcVec}, 0, SrcVec, ShuffleMask, {}, /*Depth*/ 2, - /*HasVarMask*/ false, DAG, DCI, Subtarget)) + /*HasVarMask*/ false, DAG, Subtarget)) return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SDLoc(N), VT, Shuffle, N->getOperand(0).getOperand(1)); } @@ -33374,6 +34826,38 @@ static SDValue combineAnd(SDNode *N, SelectionDAG &DAG, return SDValue(); } +// Try to match OR(AND(~MASK,X),AND(MASK,Y)) logic pattern. +static bool matchLogicBlend(SDNode *N, SDValue &X, SDValue &Y, SDValue &Mask) { + if (N->getOpcode() != ISD::OR) + return false; + + SDValue N0 = N->getOperand(0); + SDValue N1 = N->getOperand(1); + + // Canonicalize AND to LHS. + if (N1.getOpcode() == ISD::AND) + std::swap(N0, N1); + + // Attempt to match OR(AND(M,Y),ANDNP(M,X)). + if (N0.getOpcode() != ISD::AND || N1.getOpcode() != X86ISD::ANDNP) + return false; + + Mask = N1.getOperand(0); + X = N1.getOperand(1); + + // Check to see if the mask appeared in both the AND and ANDNP. + if (N0.getOperand(0) == Mask) + Y = N0.getOperand(1); + else if (N0.getOperand(1) == Mask) + Y = N0.getOperand(0); + else + return false; + + // TODO: Attempt to match against AND(XOR(-1,M),Y) as well, waiting for + // ANDNP combine allows other combines to happen that prevent matching. + return true; +} + // Try to fold: // (or (and (m, y), (pandn m, x))) // into: @@ -33386,33 +34870,13 @@ static SDValue combineLogicBlendIntoPBLENDV(SDNode *N, SelectionDAG &DAG, const X86Subtarget &Subtarget) { assert(N->getOpcode() == ISD::OR && "Unexpected Opcode"); - SDValue N0 = N->getOperand(0); - SDValue N1 = N->getOperand(1); EVT VT = N->getValueType(0); - if (!((VT.is128BitVector() && Subtarget.hasSSE2()) || (VT.is256BitVector() && Subtarget.hasInt256()))) return SDValue(); - // Canonicalize AND to LHS. - if (N1.getOpcode() == ISD::AND) - std::swap(N0, N1); - - // TODO: Attempt to match against AND(XOR(-1,X),Y) as well, waiting for - // ANDNP combine allows other combines to happen that prevent matching. - if (N0.getOpcode() != ISD::AND || N1.getOpcode() != X86ISD::ANDNP) - return SDValue(); - - SDValue Mask = N1.getOperand(0); - SDValue X = N1.getOperand(1); - SDValue Y; - if (N0.getOperand(0) == Mask) - Y = N0.getOperand(1); - if (N0.getOperand(1) == Mask) - Y = N0.getOperand(0); - - // Check to see if the mask appeared in both the AND and ANDNP. 
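// A scalar sketch of the blend matched by matchLogicBlend: for any mask M,
// (M & Y) | (~M & X) selects Y's bits where M is set and X's bits elsewhere,
// which is why the pattern can later become a variable blend when M is a
// per-lane all-ones/all-zeros mask derived from a sign bit.
#include <cassert>
#include <cstdint>

static uint64_t logicBlend(uint64_t M, uint64_t X, uint64_t Y) {
  return (M & Y) | (~M & X);
}

int main() {
  uint64_t X = 0x1111111111111111ULL, Y = 0xffffffffffffffffULL;
  assert(logicBlend(0, X, Y) == X);                     // mask clear -> X
  assert(logicBlend(~0ULL, X, Y) == Y);                 // mask set   -> Y
  assert(logicBlend(0x00ff00ff00ff00ffULL, X, Y) ==
         0x11ff11ff11ff11ffULL);                        // per-bit select
  return 0;
}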
- if (!Y.getNode()) + SDValue X, Y, Mask; + if (!matchLogicBlend(N, X, Y, Mask)) return SDValue(); // Validate that X, Y, and Mask are bitcasts, and see through them. @@ -33509,7 +34973,7 @@ static SDValue lowerX86CmpEqZeroToCtlzSrl(SDValue Op, EVT ExtTy, // encoding of shr and lzcnt is more desirable. SDValue Trunc = DAG.getZExtOrTrunc(Clz, dl, MVT::i32); SDValue Scc = DAG.getNode(ISD::SRL, dl, MVT::i32, Trunc, - DAG.getConstant(Log2b, dl, VT)); + DAG.getConstant(Log2b, dl, MVT::i8)); return DAG.getZExtOrTrunc(Scc, dl, ExtTy); } @@ -33829,63 +35293,180 @@ static bool isSATValidOnAVX512Subtarget(EVT SrcVT, EVT DstVT, return false; // FIXME: Scalar type may be supported if we move it to vector register. - if (!SrcVT.isVector() || !SrcVT.isSimple() || SrcVT.getSizeInBits() > 512) + if (!SrcVT.isVector()) return false; EVT SrcElVT = SrcVT.getScalarType(); EVT DstElVT = DstVT.getScalarType(); - if (SrcElVT.getSizeInBits() < 16 || SrcElVT.getSizeInBits() > 64) - return false; - if (DstElVT.getSizeInBits() < 8 || DstElVT.getSizeInBits() > 32) + if (DstElVT != MVT::i8 && DstElVT != MVT::i16 && DstElVT != MVT::i32) return false; if (SrcVT.is512BitVector() || Subtarget.hasVLX()) return SrcElVT.getSizeInBits() >= 32 || Subtarget.hasBWI(); return false; } -/// Detect a pattern of truncation with saturation: -/// (truncate (umin (x, unsigned_max_of_dest_type)) to dest_type). +/// Detect patterns of truncation with unsigned saturation: +/// +/// 1. (truncate (umin (x, unsigned_max_of_dest_type)) to dest_type). +/// Return the source value x to be truncated or SDValue() if the pattern was +/// not matched. +/// +/// 2. (truncate (smin (smax (x, C1), C2)) to dest_type), +/// where C1 >= 0 and C2 is unsigned max of destination type. +/// +/// (truncate (smax (smin (x, C2), C1)) to dest_type) +/// where C1 >= 0, C2 is unsigned max of destination type and C1 <= C2. +/// +/// These two patterns are equivalent to: +/// (truncate (umin (smax(x, C1), unsigned_max_of_dest_type)) to dest_type) +/// So return the smax(x, C1) value to be truncated or SDValue() if the +/// pattern was not matched. +static SDValue detectUSatPattern(SDValue In, EVT VT, SelectionDAG &DAG, + const SDLoc &DL) { + EVT InVT = In.getValueType(); + + // Saturation with truncation. We truncate from InVT to VT. + assert(InVT.getScalarSizeInBits() > VT.getScalarSizeInBits() && + "Unexpected types for truncate operation"); + + // Match min/max and return limit value as a parameter. + auto MatchMinMax = [](SDValue V, unsigned Opcode, APInt &Limit) -> SDValue { + if (V.getOpcode() == Opcode && + ISD::isConstantSplatVector(V.getOperand(1).getNode(), Limit)) + return V.getOperand(0); + return SDValue(); + }; + + APInt C1, C2; + if (SDValue UMin = MatchMinMax(In, ISD::UMIN, C2)) + // C2 should be equal to UINT32_MAX / UINT16_MAX / UINT8_MAX according + // the element size of the destination type. 
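// A scalar sketch of the unsigned-saturation forms detectUSatPattern
// recognises when truncating i32 to i8: a umin against 255, or an
// smin/smax clamp with a non-negative lower bound, both reduce to the same
// saturated value (here shown with C1 = 0 for simplicity).
#include <algorithm>
#include <cassert>
#include <cstdint>

static uint8_t viaUMin(uint32_t X) {
  return (uint8_t)std::min<uint32_t>(X, 255);           // umin(x, 255)
}
static uint8_t viaSClamp(int32_t X) {
  return (uint8_t)std::min<int32_t>(std::max<int32_t>(X, 0), 255);
}

int main() {
  const int32_t Xs[] = {-5, 0, 7, 255, 100000};
  for (int32_t X : Xs)
    // smin(smax(x, 0), 255) is the umin form applied to smax(x, 0).
    assert(viaSClamp(X) == viaUMin((uint32_t)std::max<int32_t>(X, 0)));
  return 0;
}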
+ if (C2.isMask(VT.getScalarSizeInBits())) + return UMin; + + if (SDValue SMin = MatchMinMax(In, ISD::SMIN, C2)) + if (MatchMinMax(SMin, ISD::SMAX, C1)) + if (C1.isNonNegative() && C2.isMask(VT.getScalarSizeInBits())) + return SMin; + + if (SDValue SMax = MatchMinMax(In, ISD::SMAX, C1)) + if (SDValue SMin = MatchMinMax(SMax, ISD::SMIN, C2)) + if (C1.isNonNegative() && C2.isMask(VT.getScalarSizeInBits()) && + C2.uge(C1)) { + return DAG.getNode(ISD::SMAX, DL, InVT, SMin, In.getOperand(1)); + } + + return SDValue(); +} + +/// Detect patterns of truncation with signed saturation: +/// (truncate (smin ((smax (x, signed_min_of_dest_type)), +/// signed_max_of_dest_type)) to dest_type) +/// or: +/// (truncate (smax ((smin (x, signed_max_of_dest_type)), +/// signed_min_of_dest_type)) to dest_type). +/// With MatchPackUS, the smax/smin range is [0, unsigned_max_of_dest_type]. /// Return the source value to be truncated or SDValue() if the pattern was not /// matched. -static SDValue detectUSatPattern(SDValue In, EVT VT) { - if (In.getOpcode() != ISD::UMIN) +static SDValue detectSSatPattern(SDValue In, EVT VT, bool MatchPackUS = false) { + unsigned NumDstBits = VT.getScalarSizeInBits(); + unsigned NumSrcBits = In.getScalarValueSizeInBits(); + assert(NumSrcBits > NumDstBits && "Unexpected types for truncate operation"); + + auto MatchMinMax = [](SDValue V, unsigned Opcode, + const APInt &Limit) -> SDValue { + APInt C; + if (V.getOpcode() == Opcode && + ISD::isConstantSplatVector(V.getOperand(1).getNode(), C) && C == Limit) + return V.getOperand(0); return SDValue(); + }; - //Saturation with truncation. We truncate from InVT to VT. - assert(In.getScalarValueSizeInBits() > VT.getScalarSizeInBits() && - "Unexpected types for truncate operation"); - - APInt C; - if (ISD::isConstantSplatVector(In.getOperand(1).getNode(), C)) { - // C should be equal to UINT32_MAX / UINT16_MAX / UINT8_MAX according - // the element size of the destination type. - return C.isMask(VT.getScalarSizeInBits()) ? In.getOperand(0) : - SDValue(); + APInt SignedMax, SignedMin; + if (MatchPackUS) { + SignedMax = APInt::getAllOnesValue(NumDstBits).zext(NumSrcBits); + SignedMin = APInt(NumSrcBits, 0); + } else { + SignedMax = APInt::getSignedMaxValue(NumDstBits).sext(NumSrcBits); + SignedMin = APInt::getSignedMinValue(NumDstBits).sext(NumSrcBits); } + + if (SDValue SMin = MatchMinMax(In, ISD::SMIN, SignedMax)) + if (SDValue SMax = MatchMinMax(SMin, ISD::SMAX, SignedMin)) + return SMax; + + if (SDValue SMax = MatchMinMax(In, ISD::SMAX, SignedMin)) + if (SDValue SMin = MatchMinMax(SMax, ISD::SMIN, SignedMax)) + return SMin; + return SDValue(); } +/// Detect a pattern of truncation with signed saturation. +/// The types should allow to use VPMOVSS* instruction on AVX512. +/// Return the source value to be truncated or SDValue() if the pattern was not +/// matched. +static SDValue detectAVX512SSatPattern(SDValue In, EVT VT, + const X86Subtarget &Subtarget, + const TargetLowering &TLI) { + if (!TLI.isTypeLegal(In.getValueType())) + return SDValue(); + if (!isSATValidOnAVX512Subtarget(In.getValueType(), VT, Subtarget)) + return SDValue(); + return detectSSatPattern(In, VT); +} + /// Detect a pattern of truncation with saturation: /// (truncate (umin (x, unsigned_max_of_dest_type)) to dest_type). /// The types should allow to use VPMOVUS* instruction on AVX512. /// Return the source value to be truncated or SDValue() if the pattern was not /// matched. 
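// A scalar sketch of the clamp detectSSatPattern matches: nesting smin/smax
// against the destination type's signed bounds (or against [0, unsigned-max]
// in the MatchPackUS case) is exactly a saturating truncation, and the two
// nesting orders are interchangeable. Plain C++, i32 -> i8 for illustration.
#include <algorithm>
#include <cassert>
#include <cstdint>

static int8_t ssatTrunc(int32_t X) {            // smin(smax(x, -128), 127)
  return (int8_t)std::min<int32_t>(std::max<int32_t>(X, -128), 127);
}

int main() {
  const int32_t Xs[] = {-100000, -129, -128, 0, 127, 128, 100000};
  for (int32_t X : Xs) {
    int32_t Ref = X < -128 ? -128 : X > 127 ? 127 : X;
    assert(ssatTrunc(X) == (int8_t)Ref);
    // smax(smin(x, 127), -128) clamps to the same range.
    assert((int8_t)std::max<int32_t>(std::min<int32_t>(X, 127), -128) ==
           ssatTrunc(X));
  }
  return 0;
}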
-static SDValue detectAVX512USatPattern(SDValue In, EVT VT, - const X86Subtarget &Subtarget) { +static SDValue detectAVX512USatPattern(SDValue In, EVT VT, SelectionDAG &DAG, + const SDLoc &DL, + const X86Subtarget &Subtarget, + const TargetLowering &TLI) { + if (!TLI.isTypeLegal(In.getValueType())) + return SDValue(); if (!isSATValidOnAVX512Subtarget(In.getValueType(), VT, Subtarget)) return SDValue(); - return detectUSatPattern(In, VT); + return detectUSatPattern(In, VT, DAG, DL); } -static SDValue -combineTruncateWithUSat(SDValue In, EVT VT, SDLoc &DL, SelectionDAG &DAG, - const X86Subtarget &Subtarget) { +static SDValue combineTruncateWithSat(SDValue In, EVT VT, const SDLoc &DL, + SelectionDAG &DAG, + const X86Subtarget &Subtarget) { + EVT SVT = VT.getScalarType(); + EVT InVT = In.getValueType(); + EVT InSVT = InVT.getScalarType(); const TargetLowering &TLI = DAG.getTargetLoweringInfo(); - if (!TLI.isTypeLegal(In.getValueType()) || !TLI.isTypeLegal(VT)) - return SDValue(); - if (auto USatVal = detectUSatPattern(In, VT)) - if (isSATValidOnAVX512Subtarget(In.getValueType(), VT, Subtarget)) + if (TLI.isTypeLegal(InVT) && TLI.isTypeLegal(VT) && + isSATValidOnAVX512Subtarget(InVT, VT, Subtarget)) { + if (auto SSatVal = detectSSatPattern(In, VT)) + return DAG.getNode(X86ISD::VTRUNCS, DL, VT, SSatVal); + if (auto USatVal = detectUSatPattern(In, VT, DAG, DL)) return DAG.getNode(X86ISD::VTRUNCUS, DL, VT, USatVal); + } + if (VT.isVector() && isPowerOf2_32(VT.getVectorNumElements()) && + (SVT == MVT::i8 || SVT == MVT::i16) && + (InSVT == MVT::i16 || InSVT == MVT::i32)) { + if (auto USatVal = detectSSatPattern(In, VT, true)) { + // vXi32 -> vXi8 must be performed as PACKUSWB(PACKSSDW,PACKSSDW). + if (SVT == MVT::i8 && InSVT == MVT::i32) { + EVT MidVT = EVT::getVectorVT(*DAG.getContext(), MVT::i16, + VT.getVectorNumElements()); + SDValue Mid = truncateVectorWithPACK(X86ISD::PACKSS, MidVT, USatVal, DL, + DAG, Subtarget); + if (Mid) + return truncateVectorWithPACK(X86ISD::PACKUS, VT, Mid, DL, DAG, + Subtarget); + } else if (SVT == MVT::i8 || Subtarget.hasSSE41()) + return truncateVectorWithPACK(X86ISD::PACKUS, VT, USatVal, DL, DAG, + Subtarget); + } + if (auto SSatVal = detectSSatPattern(In, VT)) + return truncateVectorWithPACK(X86ISD::PACKSS, VT, SSatVal, DL, DAG, + Subtarget); + } return SDValue(); } @@ -33895,7 +35476,7 @@ combineTruncateWithUSat(SDValue In, EVT VT, SDLoc &DL, SelectionDAG &DAG, static SDValue detectAVGPattern(SDValue In, EVT VT, SelectionDAG &DAG, const X86Subtarget &Subtarget, const SDLoc &DL) { - if (!VT.isVector() || !VT.isSimple()) + if (!VT.isVector()) return SDValue(); EVT InVT = In.getValueType(); unsigned NumElems = VT.getVectorNumElements(); @@ -33937,42 +35518,13 @@ static SDValue detectAVGPattern(SDValue In, EVT VT, SelectionDAG &DAG, ConstantSDNode *C = dyn_cast<ConstantSDNode>(Op); if (!C) return false; - uint64_t Val = C->getZExtValue(); - if (Val < Min || Val > Max) + const APInt &Val = C->getAPIntValue(); + if (Val.ult(Min) || Val.ugt(Max)) return false; } return true; }; - // Split vectors to legal target size and apply AVG. 
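// A scalar sketch of the PAVGB/PAVGW semantics detectAVGPattern is matching:
// the rounded average (a + b + 1) >> 1, computed in a wider type so the
// carry is not lost, which is what the zero-extend/add/truncate pattern in
// the IR expresses.
#include <cassert>
#include <cstdint>

static uint8_t pavgbLane(uint8_t A, uint8_t B) {
  return (uint8_t)(((uint16_t)A + (uint16_t)B + 1) >> 1);   // PAVGB per lane
}

int main() {
  assert(pavgbLane(0, 0) == 0);
  assert(pavgbLane(255, 255) == 255);   // widening keeps the carry
  assert(pavgbLane(10, 13) == 12);      // 11.5 rounds up to 12
  return 0;
}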
- auto LowerToAVG = [&](SDValue Op0, SDValue Op1) { - unsigned NumSubs = 1; - if (Subtarget.hasBWI()) { - if (VT.getSizeInBits() > 512) - NumSubs = VT.getSizeInBits() / 512; - } else if (Subtarget.hasAVX2()) { - if (VT.getSizeInBits() > 256) - NumSubs = VT.getSizeInBits() / 256; - } else { - if (VT.getSizeInBits() > 128) - NumSubs = VT.getSizeInBits() / 128; - } - - if (NumSubs == 1) - return DAG.getNode(X86ISD::AVG, DL, VT, Op0, Op1); - - SmallVector<SDValue, 4> Subs; - EVT SubVT = EVT::getVectorVT(*DAG.getContext(), VT.getScalarType(), - VT.getVectorNumElements() / NumSubs); - for (unsigned i = 0; i != NumSubs; ++i) { - unsigned Idx = i * SubVT.getVectorNumElements(); - SDValue LHS = extractSubVector(Op0, Idx, DAG, DL, SubVT.getSizeInBits()); - SDValue RHS = extractSubVector(Op1, Idx, DAG, DL, SubVT.getSizeInBits()); - Subs.push_back(DAG.getNode(X86ISD::AVG, DL, SubVT, LHS, RHS)); - } - return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, Subs); - }; - // Check if each element of the vector is left-shifted by one. auto LHS = In.getOperand(0); auto RHS = In.getOperand(1); @@ -33986,6 +35538,11 @@ static SDValue detectAVGPattern(SDValue In, EVT VT, SelectionDAG &DAG, Operands[0] = LHS.getOperand(0); Operands[1] = LHS.getOperand(1); + auto AVGBuilder = [](SelectionDAG &DAG, const SDLoc &DL, + ArrayRef<SDValue> Ops) { + return DAG.getNode(X86ISD::AVG, DL, Ops[0].getValueType(), Ops); + }; + // Take care of the case when one of the operands is a constant vector whose // element is in the range [1, 256]. if (IsConstVectorInRange(Operands[1], 1, ScalarVT == MVT::i8 ? 256 : 65536) && @@ -33996,7 +35553,9 @@ static SDValue detectAVGPattern(SDValue In, EVT VT, SelectionDAG &DAG, SDValue VecOnes = DAG.getConstant(1, DL, InVT); Operands[1] = DAG.getNode(ISD::SUB, DL, InVT, Operands[1], VecOnes); Operands[1] = DAG.getNode(ISD::TRUNCATE, DL, VT, Operands[1]); - return LowerToAVG(Operands[0].getOperand(0), Operands[1]); + return SplitOpsAndApply(DAG, Subtarget, DL, VT, + { Operands[0].getOperand(0), Operands[1] }, + AVGBuilder); } if (Operands[0].getOpcode() == ISD::ADD) @@ -34019,8 +35578,10 @@ static SDValue detectAVGPattern(SDValue In, EVT VT, SelectionDAG &DAG, Operands[j].getOperand(0).getValueType() != VT) return SDValue(); - // The pattern is detected, emit X86ISD::AVG instruction. - return LowerToAVG(Operands[0].getOperand(0), Operands[1].getOperand(0)); + // The pattern is detected, emit X86ISD::AVG instruction(s). + return SplitOpsAndApply(DAG, Subtarget, DL, VT, + { Operands[0].getOperand(0), + Operands[1].getOperand(0) }, AVGBuilder); } return SDValue(); @@ -34451,6 +36012,63 @@ static SDValue combineStore(SDNode *N, SelectionDAG &DAG, SDValue StoredVal = St->getOperand(1); const TargetLowering &TLI = DAG.getTargetLoweringInfo(); + // If this is a store of a scalar_to_vector to v1i1, just use a scalar store. + // This will avoid a copy to k-register. + if (VT == MVT::v1i1 && VT == StVT && Subtarget.hasAVX512() && + StoredVal.getOpcode() == ISD::SCALAR_TO_VECTOR && + StoredVal.getOperand(0).getValueType() == MVT::i8) { + return DAG.getStore(St->getChain(), dl, StoredVal.getOperand(0), + St->getBasePtr(), St->getPointerInfo(), + St->getAlignment(), St->getMemOperand()->getFlags()); + } + + // Widen v2i1/v4i1 stores to v8i1. 
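// A scalar sketch of the packing done for the constant vXi1 stores handled
// in the hunk below (the diff's combinevXi1ConstantToInteger helper):
// each i1 lane contributes one bit of a scalar integer, assuming lane 0
// maps to bit 0, so an 8-lane constant mask becomes a single i8 store
// instead of a k-register round trip. Helper name here is illustrative.
#include <cassert>
#include <cstdint>

static uint8_t packMask8(const bool (&Lanes)[8]) {
  uint8_t Bits = 0;
  for (unsigned I = 0; I != 8; ++I)
    Bits |= (uint8_t)(Lanes[I] ? 1u : 0u) << I;   // lane I -> bit I
  return Bits;
}

int main() {
  bool Lanes[8] = {true, false, true, true, false, false, false, true};
  assert(packMask8(Lanes) == 0x8d);               // 0b10001101
  return 0;
}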
+ if ((VT == MVT::v2i1 || VT == MVT::v4i1) && VT == StVT && + Subtarget.hasAVX512()) { + unsigned NumConcats = 8 / VT.getVectorNumElements(); + SmallVector<SDValue, 4> Ops(NumConcats, DAG.getUNDEF(VT)); + Ops[0] = StoredVal; + StoredVal = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v8i1, Ops); + return DAG.getStore(St->getChain(), dl, StoredVal, St->getBasePtr(), + St->getPointerInfo(), St->getAlignment(), + St->getMemOperand()->getFlags()); + } + + // Turn vXi1 stores of constants into a scalar store. + if ((VT == MVT::v8i1 || VT == MVT::v16i1 || VT == MVT::v32i1 || + VT == MVT::v64i1) && VT == StVT && TLI.isTypeLegal(VT) && + ISD::isBuildVectorOfConstantSDNodes(StoredVal.getNode())) { + // If its a v64i1 store without 64-bit support, we need two stores. + if (VT == MVT::v64i1 && !Subtarget.is64Bit()) { + SDValue Lo = DAG.getBuildVector(MVT::v32i1, dl, + StoredVal->ops().slice(0, 32)); + Lo = combinevXi1ConstantToInteger(Lo, DAG); + SDValue Hi = DAG.getBuildVector(MVT::v32i1, dl, + StoredVal->ops().slice(32, 32)); + Hi = combinevXi1ConstantToInteger(Hi, DAG); + + unsigned Alignment = St->getAlignment(); + + SDValue Ptr0 = St->getBasePtr(); + SDValue Ptr1 = DAG.getMemBasePlusOffset(Ptr0, 4, dl); + + SDValue Ch0 = + DAG.getStore(St->getChain(), dl, Lo, Ptr0, St->getPointerInfo(), + Alignment, St->getMemOperand()->getFlags()); + SDValue Ch1 = + DAG.getStore(St->getChain(), dl, Hi, Ptr1, + St->getPointerInfo().getWithOffset(4), + MinAlign(Alignment, 4U), + St->getMemOperand()->getFlags()); + return DAG.getNode(ISD::TokenFactor, dl, MVT::Other, Ch0, Ch1); + } + + StoredVal = combinevXi1ConstantToInteger(StoredVal, DAG); + return DAG.getStore(St->getChain(), dl, StoredVal, St->getBasePtr(), + St->getPointerInfo(), St->getAlignment(), + St->getMemOperand()->getFlags()); + } + // If we are saving a concatenation of two XMM registers and 32-byte stores // are slow, such as on Sandy Bridge, perform two 16-byte stores. bool Fast; @@ -34493,13 +36111,19 @@ static SDValue combineStore(SDNode *N, SelectionDAG &DAG, St->getPointerInfo(), St->getAlignment(), St->getMemOperand()->getFlags()); + const TargetLowering &TLI = DAG.getTargetLoweringInfo(); if (SDValue Val = - detectAVX512USatPattern(St->getValue(), St->getMemoryVT(), Subtarget)) + detectAVX512SSatPattern(St->getValue(), St->getMemoryVT(), Subtarget, + TLI)) + return EmitTruncSStore(true /* Signed saturation */, St->getChain(), + dl, Val, St->getBasePtr(), + St->getMemoryVT(), St->getMemOperand(), DAG); + if (SDValue Val = detectAVX512USatPattern(St->getValue(), St->getMemoryVT(), + DAG, dl, Subtarget, TLI)) return EmitTruncSStore(false /* Unsigned saturation */, St->getChain(), dl, Val, St->getBasePtr(), St->getMemoryVT(), St->getMemOperand(), DAG); - const TargetLowering &TLI = DAG.getTargetLoweringInfo(); unsigned NumElems = VT.getVectorNumElements(); assert(StVT != VT && "Cannot truncate to the same type"); unsigned FromSz = VT.getScalarSizeInBits(); @@ -34812,7 +36436,7 @@ static SDValue combineFaddFsub(SDNode *N, SelectionDAG &DAG, // Try to synthesize horizontal add/sub from adds/subs of shuffles. if (((Subtarget.hasSSE3() && (VT == MVT::v4f32 || VT == MVT::v2f64)) || - (Subtarget.hasFp256() && (VT == MVT::v8f32 || VT == MVT::v4f64))) && + (Subtarget.hasAVX() && (VT == MVT::v8f32 || VT == MVT::v4f64))) && isHorizontalBinOp(LHS, RHS, IsFadd)) { auto NewOpcode = IsFadd ? 
X86ISD::FHADD : X86ISD::FHSUB; return DAG.getNode(NewOpcode, SDLoc(N), VT, LHS, RHS); @@ -34825,7 +36449,7 @@ static SDValue combineFaddFsub(SDNode *N, SelectionDAG &DAG, /// e.g. TRUNC( BINOP( X, Y ) ) --> BINOP( TRUNC( X ), TRUNC( Y ) ) static SDValue combineTruncatedArithmetic(SDNode *N, SelectionDAG &DAG, const X86Subtarget &Subtarget, - SDLoc &DL) { + const SDLoc &DL) { assert(N->getOpcode() == ISD::TRUNCATE && "Wrong opcode"); SDValue Src = N->getOperand(0); unsigned Opcode = Src.getOpcode(); @@ -34898,7 +36522,7 @@ static SDValue combineTruncatedArithmetic(SDNode *N, SelectionDAG &DAG, // X86 is rubbish at scalar and vector i64 multiplies (until AVX512DQ) - its // better to truncate if we have the chance. if (SrcVT.getScalarType() == MVT::i64 && TLI.isOperationLegal(Opcode, VT) && - !Subtarget.hasDQI()) + !TLI.isOperationLegal(Opcode, SrcVT)) return TruncateArithmetic(Src.getOperand(0), Src.getOperand(1)); LLVM_FALLTHROUGH; case ISD::ADD: { @@ -34915,88 +36539,50 @@ static SDValue combineTruncatedArithmetic(SDNode *N, SelectionDAG &DAG, return SDValue(); } -/// Truncate a group of v4i32 into v16i8/v8i16 using X86ISD::PACKUS. -static SDValue -combineVectorTruncationWithPACKUS(SDNode *N, SelectionDAG &DAG, - SmallVector<SDValue, 8> &Regs) { - assert(Regs.size() > 0 && (Regs[0].getValueType() == MVT::v4i32 || - Regs[0].getValueType() == MVT::v2i64)); +/// Truncate using ISD::AND mask and X86ISD::PACKUS. +static SDValue combineVectorTruncationWithPACKUS(SDNode *N, const SDLoc &DL, + const X86Subtarget &Subtarget, + SelectionDAG &DAG) { + SDValue In = N->getOperand(0); + EVT InVT = In.getValueType(); + EVT InSVT = InVT.getVectorElementType(); EVT OutVT = N->getValueType(0); EVT OutSVT = OutVT.getVectorElementType(); - EVT InVT = Regs[0].getValueType(); - EVT InSVT = InVT.getVectorElementType(); - SDLoc DL(N); - // First, use mask to unset all bits that won't appear in the result. - assert((OutSVT == MVT::i8 || OutSVT == MVT::i16) && - "OutSVT can only be either i8 or i16."); + // Split a long vector into vectors of legal type and mask to unset all bits + // that won't appear in the result to prevent saturation. + // TODO - we should be doing this at the maximum legal size but this is + // causing regressions where we're concatenating back to max width just to + // perform the AND and then extracting back again..... + unsigned NumSubRegs = InVT.getSizeInBits() / 128; + unsigned NumSubRegElts = 128 / InSVT.getSizeInBits(); + EVT SubRegVT = EVT::getVectorVT(*DAG.getContext(), InSVT, NumSubRegElts); + SmallVector<SDValue, 8> SubVecs(NumSubRegs); + APInt Mask = APInt::getLowBitsSet(InSVT.getSizeInBits(), OutSVT.getSizeInBits()); - SDValue MaskVal = DAG.getConstant(Mask, DL, InVT); - for (auto &Reg : Regs) - Reg = DAG.getNode(ISD::AND, DL, InVT, MaskVal, Reg); - - MVT UnpackedVT, PackedVT; - if (OutSVT == MVT::i8) { - UnpackedVT = MVT::v8i16; - PackedVT = MVT::v16i8; - } else { - UnpackedVT = MVT::v4i32; - PackedVT = MVT::v8i16; - } - - // In each iteration, truncate the type by a half size. 
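// A scalar sketch of why combineVectorTruncationWithPACKUS masks its inputs
// first: PACKUSWB saturates, so packing raw i16 lanes would clamp anything
// outside [0, 255], but after "x & 0xff" every lane is already in range and
// the pack degenerates to an exact truncation.
#include <cassert>
#include <cstdint>

static uint8_t packusLane(int16_t X) {            // PACKUSWB per lane
  return X < 0 ? 0 : X > 255 ? 255 : (uint8_t)X;
}

int main() {
  const int16_t Xs[] = {-7, 5, 300, -32768};
  for (int16_t X : Xs)
    assert(packusLane((int16_t)(X & 0xff)) == (uint8_t)X);  // masked: exact
  // Without the mask, saturation would give the wrong truncation, e.g. 300.
  assert(packusLane(300) == 255 && (uint8_t)300 == 44);
  return 0;
}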
- auto RegNum = Regs.size(); - for (unsigned j = 1, e = InSVT.getSizeInBits() / OutSVT.getSizeInBits(); - j < e; j *= 2, RegNum /= 2) { - for (unsigned i = 0; i < RegNum; i++) - Regs[i] = DAG.getBitcast(UnpackedVT, Regs[i]); - for (unsigned i = 0; i < RegNum / 2; i++) - Regs[i] = DAG.getNode(X86ISD::PACKUS, DL, PackedVT, Regs[i * 2], - Regs[i * 2 + 1]); - } - - // If the type of the result is v8i8, we need do one more X86ISD::PACKUS, and - // then extract a subvector as the result since v8i8 is not a legal type. - if (OutVT == MVT::v8i8) { - Regs[0] = DAG.getNode(X86ISD::PACKUS, DL, PackedVT, Regs[0], Regs[0]); - Regs[0] = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, OutVT, Regs[0], - DAG.getIntPtrConstant(0, DL)); - return Regs[0]; - } else if (RegNum > 1) { - Regs.resize(RegNum); - return DAG.getNode(ISD::CONCAT_VECTORS, DL, OutVT, Regs); - } else - return Regs[0]; -} - -/// Truncate a group of v4i32 into v8i16 using X86ISD::PACKSS. -static SDValue -combineVectorTruncationWithPACKSS(SDNode *N, const X86Subtarget &Subtarget, - SelectionDAG &DAG, - SmallVector<SDValue, 8> &Regs) { - assert(Regs.size() > 0 && Regs[0].getValueType() == MVT::v4i32); - EVT OutVT = N->getValueType(0); - SDLoc DL(N); + SDValue MaskVal = DAG.getConstant(Mask, DL, SubRegVT); - // Shift left by 16 bits, then arithmetic-shift right by 16 bits. - SDValue ShAmt = DAG.getConstant(16, DL, MVT::i32); - for (auto &Reg : Regs) { - Reg = getTargetVShiftNode(X86ISD::VSHLI, DL, MVT::v4i32, Reg, ShAmt, - Subtarget, DAG); - Reg = getTargetVShiftNode(X86ISD::VSRAI, DL, MVT::v4i32, Reg, ShAmt, - Subtarget, DAG); + for (unsigned i = 0; i < NumSubRegs; i++) { + SDValue Sub = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, SubRegVT, In, + DAG.getIntPtrConstant(i * NumSubRegElts, DL)); + SubVecs[i] = DAG.getNode(ISD::AND, DL, SubRegVT, Sub, MaskVal); } + In = DAG.getNode(ISD::CONCAT_VECTORS, DL, InVT, SubVecs); - for (unsigned i = 0, e = Regs.size() / 2; i < e; i++) - Regs[i] = DAG.getNode(X86ISD::PACKSS, DL, MVT::v8i16, Regs[i * 2], - Regs[i * 2 + 1]); + return truncateVectorWithPACK(X86ISD::PACKUS, OutVT, In, DL, DAG, Subtarget); +} - if (Regs.size() > 2) { - Regs.resize(Regs.size() / 2); - return DAG.getNode(ISD::CONCAT_VECTORS, DL, OutVT, Regs); - } else - return Regs[0]; +/// Truncate a group of v4i32 into v8i16 using X86ISD::PACKSS. +static SDValue combineVectorTruncationWithPACKSS(SDNode *N, const SDLoc &DL, + const X86Subtarget &Subtarget, + SelectionDAG &DAG) { + SDValue In = N->getOperand(0); + EVT InVT = In.getValueType(); + EVT OutVT = N->getValueType(0); + In = DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, InVT, In, + DAG.getValueType(OutVT)); + return truncateVectorWithPACK(X86ISD::PACKSS, OutVT, In, DL, DAG, Subtarget); } /// This function transforms truncation from vXi32/vXi64 to vXi8/vXi16 into @@ -35037,32 +36623,21 @@ static SDValue combineVectorTruncation(SDNode *N, SelectionDAG &DAG, return SDValue(); SDLoc DL(N); - - // Split a long vector into vectors of legal type. - unsigned RegNum = InVT.getSizeInBits() / 128; - SmallVector<SDValue, 8> SubVec(RegNum); - unsigned NumSubRegElts = 128 / InSVT.getSizeInBits(); - EVT SubRegVT = EVT::getVectorVT(*DAG.getContext(), InSVT, NumSubRegElts); - - for (unsigned i = 0; i < RegNum; i++) - SubVec[i] = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, SubRegVT, In, - DAG.getIntPtrConstant(i * NumSubRegElts, DL)); - // SSE2 provides PACKUS for only 2 x v8i16 -> v16i8 and SSE4.1 provides PACKUS // for 2 x v4i32 -> v8i16. For SSSE3 and below, we need to use PACKSS to // truncate 2 x v4i32 to v8i16. 
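  // Illustrative example, not taken from this change: a v8i32 -> v8i16
  // truncation on SSE2 goes down the PACKSS path roughly as
  //   t  = sign_extend_inreg v8i32 In, v8i16  ; makes saturation a no-op
  //   lo = extract_subvector t, 0             ; v4i32
  //   hi = extract_subvector t, 4             ; v4i32
  //   r  = X86ISD::PACKSS lo, hi              ; v8i16
  // while the PACKUS path masks every element with 0xFFFF instead, so the
  // unsigned saturation in PACKUS cannot clamp any lane.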
if (Subtarget.hasSSE41() || OutSVT == MVT::i8) - return combineVectorTruncationWithPACKUS(N, DAG, SubVec); - else if (InSVT == MVT::i32) - return combineVectorTruncationWithPACKSS(N, Subtarget, DAG, SubVec); - else - return SDValue(); + return combineVectorTruncationWithPACKUS(N, DL, Subtarget, DAG); + if (InSVT == MVT::i32) + return combineVectorTruncationWithPACKSS(N, DL, Subtarget, DAG); + + return SDValue(); } /// This function transforms vector truncation of 'extended sign-bits' or /// 'extended zero-bits' values. /// vXi16/vXi32/vXi64 to vXi8/vXi16/vXi32 into X86ISD::PACKSS/PACKUS operations. -static SDValue combineVectorSignBitsTruncation(SDNode *N, SDLoc &DL, +static SDValue combineVectorSignBitsTruncation(SDNode *N, const SDLoc &DL, SelectionDAG &DAG, const X86Subtarget &Subtarget) { // Requires SSE2 but AVX512 has fast truncate. @@ -35082,7 +36657,7 @@ static SDValue combineVectorSignBitsTruncation(SDNode *N, SDLoc &DL, MVT InVT = In.getValueType().getSimpleVT(); MVT InSVT = InVT.getScalarType(); - // Check we have a truncation suited for PACKSS. + // Check we have a truncation suited for PACKSS/PACKUS. if (!VT.is128BitVector() && !VT.is256BitVector()) return SDValue(); if (SVT != MVT::i8 && SVT != MVT::i16 && SVT != MVT::i32) @@ -35090,25 +36665,79 @@ static SDValue combineVectorSignBitsTruncation(SDNode *N, SDLoc &DL, if (InSVT != MVT::i16 && InSVT != MVT::i32 && InSVT != MVT::i64) return SDValue(); - // Use PACKSS if the input has sign-bits that extend all the way to the - // packed/truncated value. e.g. Comparison result, sext_in_reg, etc. - unsigned NumSignBits = DAG.ComputeNumSignBits(In); - unsigned NumPackedBits = std::min<unsigned>(SVT.getSizeInBits(), 16); - if (NumSignBits > (InSVT.getSizeInBits() - NumPackedBits)) - return truncateVectorWithPACK(X86ISD::PACKSS, VT, In, DL, DAG, Subtarget); + unsigned NumPackedSignBits = std::min<unsigned>(SVT.getSizeInBits(), 16); + unsigned NumPackedZeroBits = Subtarget.hasSSE41() ? NumPackedSignBits : 8; // Use PACKUS if the input has zero-bits that extend all the way to the // packed/truncated value. e.g. masks, zext_in_reg, etc. KnownBits Known; DAG.computeKnownBits(In, Known); unsigned NumLeadingZeroBits = Known.countMinLeadingZeros(); - NumPackedBits = Subtarget.hasSSE41() ? NumPackedBits : 8; - if (NumLeadingZeroBits >= (InSVT.getSizeInBits() - NumPackedBits)) + if (NumLeadingZeroBits >= (InSVT.getSizeInBits() - NumPackedZeroBits)) return truncateVectorWithPACK(X86ISD::PACKUS, VT, In, DL, DAG, Subtarget); + // Use PACKSS if the input has sign-bits that extend all the way to the + // packed/truncated value. e.g. Comparison result, sext_in_reg, etc. + unsigned NumSignBits = DAG.ComputeNumSignBits(In); + if (NumSignBits > (InSVT.getSizeInBits() - NumPackedSignBits)) + return truncateVectorWithPACK(X86ISD::PACKSS, VT, In, DL, DAG, Subtarget); + return SDValue(); } +// Try to form a MULHU or MULHS node by looking for +// (trunc (srl (mul ext, ext), 16)) +// TODO: This is X86 specific because we want to be able to handle wide types +// before type legalization. But we can only do it if the vector will be +// legalized via widening/splitting. Type legalization can't handle promotion +// of a MULHU/MULHS. There isn't a way to convey this to the generic DAG +// combiner. +static SDValue combinePMULH(SDValue Src, EVT VT, const SDLoc &DL, + SelectionDAG &DAG, const X86Subtarget &Subtarget) { + // First instruction should be a right shift of a multiply. 
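  // Illustrative IR shape this aims to catch (assumed names, not from the
  // change itself):
  //   %ea = sext <8 x i16> %a to <8 x i32>
  //   %eb = sext <8 x i16> %b to <8 x i32>
  //   %m  = mul  <8 x i32> %ea, %eb
  //   %s  = lshr <8 x i32> %m, <splat of 16>
  //   %t  = trunc <8 x i32> %s to <8 x i16>   ; folds to MULHS (PMULHW)
  // With zero_extend instead of sign_extend this becomes MULHU (PMULHUW).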
+ if (Src.getOpcode() != ISD::SRL || + Src.getOperand(0).getOpcode() != ISD::MUL) + return SDValue(); + + if (!Subtarget.hasSSE2()) + return SDValue(); + + // Only handle vXi16 types that are at least 128-bits. + if (!VT.isVector() || VT.getVectorElementType() != MVT::i16 || + VT.getVectorNumElements() < 8) + return SDValue(); + + // Input type should be vXi32. + EVT InVT = Src.getValueType(); + if (InVT.getVectorElementType() != MVT::i32) + return SDValue(); + + // Need a shift by 16. + APInt ShiftAmt; + if (!ISD::isConstantSplatVector(Src.getOperand(1).getNode(), ShiftAmt) || + ShiftAmt != 16) + return SDValue(); + + SDValue LHS = Src.getOperand(0).getOperand(0); + SDValue RHS = Src.getOperand(0).getOperand(1); + + unsigned ExtOpc = LHS.getOpcode(); + if ((ExtOpc != ISD::SIGN_EXTEND && ExtOpc != ISD::ZERO_EXTEND) || + RHS.getOpcode() != ExtOpc) + return SDValue(); + + // Peek through the extends. + LHS = LHS.getOperand(0); + RHS = RHS.getOperand(0); + + // Ensure the input types match. + if (LHS.getValueType() != VT || RHS.getValueType() != VT) + return SDValue(); + + unsigned Opc = ExtOpc == ISD::SIGN_EXTEND ? ISD::MULHS : ISD::MULHU; + return DAG.getNode(Opc, DL, VT, LHS, RHS); +} + static SDValue combineTruncate(SDNode *N, SelectionDAG &DAG, const X86Subtarget &Subtarget) { EVT VT = N->getValueType(0); @@ -35123,10 +36752,14 @@ static SDValue combineTruncate(SDNode *N, SelectionDAG &DAG, if (SDValue Avg = detectAVGPattern(Src, VT, DAG, Subtarget, DL)) return Avg; - // Try to combine truncation with unsigned saturation. - if (SDValue Val = combineTruncateWithUSat(Src, VT, DL, DAG, Subtarget)) + // Try to combine truncation with signed/unsigned saturation. + if (SDValue Val = combineTruncateWithSat(Src, VT, DL, DAG, Subtarget)) return Val; + // Try to combine PMULHUW/PMULHW for vXi16. + if (SDValue V = combinePMULH(Src, VT, DL, DAG, Subtarget)) + return V; + // The bitcast source is a direct mmx result. // Detect bitcasts between i32 to x86mmx if (Src.getOpcode() == ISD::BITCAST && VT == MVT::i32) { @@ -35224,7 +36857,7 @@ static SDValue combineFneg(SDNode *N, SelectionDAG &DAG, // If we're negating an FMA node, then we can adjust the // instruction to include the extra negation. unsigned NewOpcode = 0; - if (Arg.hasOneUse()) { + if (Arg.hasOneUse() && Subtarget.hasAnyFMA()) { switch (Arg.getOpcode()) { case ISD::FMA: NewOpcode = X86ISD::FNMSUB; break; case X86ISD::FMSUB: NewOpcode = X86ISD::FNMADD; break; @@ -35320,6 +36953,39 @@ static SDValue combineXor(SDNode *N, SelectionDAG &DAG, return SDValue(); } +static SDValue combineBEXTR(SDNode *N, SelectionDAG &DAG, + TargetLowering::DAGCombinerInfo &DCI, + const X86Subtarget &Subtarget) { + SDValue Op0 = N->getOperand(0); + SDValue Op1 = N->getOperand(1); + EVT VT = N->getValueType(0); + unsigned NumBits = VT.getSizeInBits(); + + const TargetLowering &TLI = DAG.getTargetLoweringInfo(); + TargetLowering::TargetLoweringOpt TLO(DAG, !DCI.isBeforeLegalize(), + !DCI.isBeforeLegalizeOps()); + + // TODO - Constant Folding. + if (auto *Cst1 = dyn_cast<ConstantSDNode>(Op1)) { + // Reduce Cst1 to the bottom 16-bits. + // NOTE: SimplifyDemandedBits won't do this for constants. + const APInt &Val1 = Cst1->getAPIntValue(); + APInt MaskedVal1 = Val1 & 0xFFFF; + if (MaskedVal1 != Val1) + return DAG.getNode(X86ISD::BEXTR, SDLoc(N), VT, Op0, + DAG.getConstant(MaskedVal1, SDLoc(N), VT)); + } + + // Only bottom 16-bits of the control bits are required. 
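  // (BEXTR encodes the start bit in bits [7:0] of the control operand and
  // the extract length in bits [15:8]; the hardware ignores everything
  // above bit 15, which is why only the low 16 bits are demanded here.)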
+ KnownBits Known; + APInt DemandedMask(APInt::getLowBitsSet(NumBits, 16)); + if (TLI.SimplifyDemandedBits(Op1, DemandedMask, Known, TLO)) { + DCI.CommitTargetLoweringOpt(TLO); + return SDValue(N, 0); + } + + return SDValue(); +} static bool isNullFPScalarOrVectorConst(SDValue V) { return isNullFPConstant(V) || ISD::isBuildVectorAllZeros(V.getNode()); @@ -35450,8 +37116,6 @@ static SDValue combineFMinNumFMaxNum(SDNode *N, SelectionDAG &DAG, if (Subtarget.useSoftFloat()) return SDValue(); - // TODO: Check for global or instruction-level "nnan". In that case, we - // should be able to lower to FMAX/FMIN alone. // TODO: If an operand is already known to be a NaN or not a NaN, this // should be an optional swap and FMAX/FMIN. @@ -35461,14 +37125,21 @@ static SDValue combineFMinNumFMaxNum(SDNode *N, SelectionDAG &DAG, (Subtarget.hasAVX() && (VT == MVT::v8f32 || VT == MVT::v4f64)))) return SDValue(); - // This takes at least 3 instructions, so favor a library call when operating - // on a scalar and minimizing code size. - if (!VT.isVector() && DAG.getMachineFunction().getFunction().optForMinSize()) - return SDValue(); - SDValue Op0 = N->getOperand(0); SDValue Op1 = N->getOperand(1); SDLoc DL(N); + auto MinMaxOp = N->getOpcode() == ISD::FMAXNUM ? X86ISD::FMAX : X86ISD::FMIN; + + // If we don't have to respect NaN inputs, this is a direct translation to x86 + // min/max instructions. + if (DAG.getTarget().Options.NoNaNsFPMath || N->getFlags().hasNoNaNs()) + return DAG.getNode(MinMaxOp, DL, VT, Op0, Op1, N->getFlags()); + + // If we have to respect NaN inputs, this takes at least 3 instructions. + // Favor a library call when operating on a scalar and minimizing code size. + if (!VT.isVector() && DAG.getMachineFunction().getFunction().optForMinSize()) + return SDValue(); + EVT SetCCType = DAG.getTargetLoweringInfo().getSetCCResultType( DAG.getDataLayout(), *DAG.getContext(), VT); @@ -35491,9 +37162,8 @@ static SDValue combineFMinNumFMaxNum(SDNode *N, SelectionDAG &DAG, // use those instructions for fmaxnum by selecting away a NaN input. // If either operand is NaN, the 2nd source operand (Op0) is passed through. - auto MinMaxOp = N->getOpcode() == ISD::FMAXNUM ? X86ISD::FMAX : X86ISD::FMIN; SDValue MinOrMax = DAG.getNode(MinMaxOp, DL, VT, Op1, Op0); - SDValue IsOp0Nan = DAG.getSetCC(DL, SetCCType , Op0, Op0, ISD::SETUO); + SDValue IsOp0Nan = DAG.getSetCC(DL, SetCCType, Op0, Op0, ISD::SETUO); // If Op0 is a NaN, select Op1. Otherwise, select the max. If both operands // are NaN, the NaN value of Op1 is the result. @@ -35519,10 +37189,8 @@ static SDValue combineAndnp(SDNode *N, SelectionDAG &DAG, SDValue Op(N, 0); if (SDValue Res = combineX86ShufflesRecursively( {Op}, 0, Op, {0}, {}, /*Depth*/ 1, - /*HasVarMask*/ false, DAG, DCI, Subtarget)) { - DCI.CombineTo(N, Res); - return SDValue(); - } + /*HasVarMask*/ false, DAG, Subtarget)) + return Res; } return SDValue(); @@ -35542,12 +37210,54 @@ static SDValue combineBT(SDNode *N, SelectionDAG &DAG, return SDValue(); } -static SDValue combineSignExtendInReg(SDNode *N, SelectionDAG &DAG, - const X86Subtarget &Subtarget) { +// Try to combine sext_in_reg of a cmov of constants by extending the constants. +static SDValue combineSextInRegCmov(SDNode *N, SelectionDAG &DAG) { EVT VT = N->getValueType(0); - if (!VT.isVector()) + + SDValue N0 = N->getOperand(0); + SDValue N1 = N->getOperand(1); + EVT ExtraVT = cast<VTSDNode>(N1)->getVT(); + + if (ExtraVT != MVT::i16) return SDValue(); + // Look through single use any_extends. 
+ if (N0.getOpcode() == ISD::ANY_EXTEND && N0.hasOneUse()) + N0 = N0.getOperand(0); + + // See if we have a single use cmov. + if (N0.getOpcode() != X86ISD::CMOV || !N0.hasOneUse()) + return SDValue(); + + SDValue CMovOp0 = N0.getOperand(0); + SDValue CMovOp1 = N0.getOperand(1); + + // Make sure both operands are constants. + if (!isa<ConstantSDNode>(CMovOp0.getNode()) || + !isa<ConstantSDNode>(CMovOp1.getNode())) + return SDValue(); + + SDLoc DL(N); + + // If we looked through an any_extend above, add one to the constants. + if (N0.getValueType() != VT) { + CMovOp0 = DAG.getNode(ISD::ANY_EXTEND, DL, VT, CMovOp0); + CMovOp1 = DAG.getNode(ISD::ANY_EXTEND, DL, VT, CMovOp1); + } + + CMovOp0 = DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, VT, CMovOp0, N1); + CMovOp1 = DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, VT, CMovOp1, N1); + + return DAG.getNode(X86ISD::CMOV, DL, VT, CMovOp0, CMovOp1, + N0.getOperand(2), N0.getOperand(3)); +} + +static SDValue combineSignExtendInReg(SDNode *N, SelectionDAG &DAG, + const X86Subtarget &Subtarget) { + if (SDValue V = combineSextInRegCmov(N, DAG)) + return V; + + EVT VT = N->getValueType(0); SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); EVT ExtraVT = cast<VTSDNode>(N1)->getVT(); @@ -35686,7 +37396,7 @@ static SDValue getDivRem8(SDNode *N, SelectionDAG &DAG) { // promotion). static SDValue combineToExtendCMOV(SDNode *Extend, SelectionDAG &DAG) { SDValue CMovN = Extend->getOperand(0); - if (CMovN.getOpcode() != X86ISD::CMOV) + if (CMovN.getOpcode() != X86ISD::CMOV || !CMovN.hasOneUse()) return SDValue(); EVT TargetVT = Extend->getValueType(0); @@ -35697,20 +37407,36 @@ static SDValue combineToExtendCMOV(SDNode *Extend, SelectionDAG &DAG) { SDValue CMovOp0 = CMovN.getOperand(0); SDValue CMovOp1 = CMovN.getOperand(1); - bool DoPromoteCMOV = - (VT == MVT::i16 && (TargetVT == MVT::i32 || TargetVT == MVT::i64)) && - CMovN.hasOneUse() && - (isa<ConstantSDNode>(CMovOp0.getNode()) && - isa<ConstantSDNode>(CMovOp1.getNode())); + if (!isa<ConstantSDNode>(CMovOp0.getNode()) || + !isa<ConstantSDNode>(CMovOp1.getNode())) + return SDValue(); - if (!DoPromoteCMOV) + // Only extend to i32 or i64. + if (TargetVT != MVT::i32 && TargetVT != MVT::i64) return SDValue(); - CMovOp0 = DAG.getNode(ExtendOpcode, DL, TargetVT, CMovOp0); - CMovOp1 = DAG.getNode(ExtendOpcode, DL, TargetVT, CMovOp1); + // Only extend from i16 unless its a sign_extend from i32. Zext/aext from i32 + // are free. + if (VT != MVT::i16 && !(ExtendOpcode == ISD::SIGN_EXTEND && VT == MVT::i32)) + return SDValue(); - return DAG.getNode(X86ISD::CMOV, DL, TargetVT, CMovOp0, CMovOp1, - CMovN.getOperand(2), CMovN.getOperand(3)); + // If this a zero extend to i64, we should only extend to i32 and use a free + // zero extend to finish. + EVT ExtendVT = TargetVT; + if (TargetVT == MVT::i64 && ExtendOpcode != ISD::SIGN_EXTEND) + ExtendVT = MVT::i32; + + CMovOp0 = DAG.getNode(ExtendOpcode, DL, ExtendVT, CMovOp0); + CMovOp1 = DAG.getNode(ExtendOpcode, DL, ExtendVT, CMovOp1); + + SDValue Res = DAG.getNode(X86ISD::CMOV, DL, ExtendVT, CMovOp0, CMovOp1, + CMovN.getOperand(2), CMovN.getOperand(3)); + + // Finish extending if needed. + if (ExtendVT != TargetVT) + Res = DAG.getNode(ExtendOpcode, DL, TargetVT, Res); + + return Res; } // Convert (vXiY *ext(vXi1 bitcast(iX))) to extend_in_reg(broadcast(iX)). @@ -35866,7 +37592,7 @@ static SDValue combineToExtendVectorInReg(SDNode *N, SelectionDAG &DAG, // Also use this if we don't have SSE41 to allow the legalizer do its job. 
if (!Subtarget.hasSSE41() || VT.is128BitVector() || (VT.is256BitVector() && Subtarget.hasInt256()) || - (VT.is512BitVector() && Subtarget.hasAVX512())) { + (VT.is512BitVector() && Subtarget.useAVX512Regs())) { SDValue ExOp = ExtendVecSize(DL, N0, VT.getSizeInBits()); return Opcode == ISD::SIGN_EXTEND ? DAG.getSignExtendVectorInReg(ExOp, DL, VT) @@ -35899,12 +37625,55 @@ static SDValue combineToExtendVectorInReg(SDNode *N, SelectionDAG &DAG, // On pre-AVX512 targets, split into 256-bit nodes of // ISD::*_EXTEND_VECTOR_INREG. - if (!Subtarget.hasAVX512() && !(VT.getSizeInBits() % 256)) + if (!Subtarget.useAVX512Regs() && !(VT.getSizeInBits() % 256)) return SplitAndExtendInReg(256); return SDValue(); } +// Attempt to combine a (sext/zext (setcc)) to a setcc with a xmm/ymm/zmm +// result type. +static SDValue combineExtSetcc(SDNode *N, SelectionDAG &DAG, + const X86Subtarget &Subtarget) { + SDValue N0 = N->getOperand(0); + EVT VT = N->getValueType(0); + SDLoc dl(N); + + // Only do this combine with AVX512 for vector extends. + if (!Subtarget.hasAVX512() || !VT.isVector() || N0->getOpcode() != ISD::SETCC) + return SDValue(); + + // Only combine legal element types. + EVT SVT = VT.getVectorElementType(); + if (SVT != MVT::i8 && SVT != MVT::i16 && SVT != MVT::i32 && + SVT != MVT::i64 && SVT != MVT::f32 && SVT != MVT::f64) + return SDValue(); + + // We can only do this if the vector size in 256 bits or less. + unsigned Size = VT.getSizeInBits(); + if (Size > 256) + return SDValue(); + + // Don't fold if the condition code can't be handled by PCMPEQ/PCMPGT since + // that's the only integer compares with we have. + ISD::CondCode CC = cast<CondCodeSDNode>(N0->getOperand(2))->get(); + if (ISD::isUnsignedIntSetCC(CC)) + return SDValue(); + + // Only do this combine if the extension will be fully consumed by the setcc. 
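  // Illustrative example with assumed types: on an AVX512 target
  //   (v8i32 sign_extend (v8i1 setcc (v8i32 a), (v8i32 b), setgt))
  // becomes a single v8i32 setcc, which selects to VPCMPGTD and already
  // produces all-ones/all-zeros lanes; for zero_extend the result is
  // additionally masked down to 0/1 below.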
+ EVT N00VT = N0.getOperand(0).getValueType(); + EVT MatchingVecType = N00VT.changeVectorElementTypeToInteger(); + if (Size != MatchingVecType.getSizeInBits()) + return SDValue(); + + SDValue Res = DAG.getSetCC(dl, VT, N0.getOperand(0), N0.getOperand(1), CC); + + if (N->getOpcode() == ISD::ZERO_EXTEND) + Res = DAG.getZeroExtendInReg(Res, dl, N0.getValueType().getScalarType()); + + return Res; +} + static SDValue combineSext(SDNode *N, SelectionDAG &DAG, TargetLowering::DAGCombinerInfo &DCI, const X86Subtarget &Subtarget) { @@ -35922,6 +37691,9 @@ static SDValue combineSext(SDNode *N, SelectionDAG &DAG, if (!DCI.isBeforeLegalizeOps()) return SDValue(); + if (SDValue V = combineExtSetcc(N, DAG, Subtarget)) + return V; + if (InVT == MVT::i1 && N0.getOpcode() == ISD::XOR && isAllOnesConstant(N0.getOperand(1)) && N0.hasOneUse()) { // Invert and sign-extend a boolean is the same as zero-extend and subtract @@ -35939,7 +37711,7 @@ static SDValue combineSext(SDNode *N, SelectionDAG &DAG, return V; if (VT.isVector()) - if (SDValue R = WidenMaskArithmetic(N, DAG, DCI, Subtarget)) + if (SDValue R = WidenMaskArithmetic(N, DAG, Subtarget)) return R; if (SDValue NewAdd = promoteExtBeforeAdd(N, DAG, Subtarget)) @@ -35948,9 +37720,40 @@ static SDValue combineSext(SDNode *N, SelectionDAG &DAG, return SDValue(); } +static unsigned negateFMAOpcode(unsigned Opcode, bool NegMul, bool NegAcc) { + if (NegMul) { + switch (Opcode) { + default: llvm_unreachable("Unexpected opcode"); + case ISD::FMA: Opcode = X86ISD::FNMADD; break; + case X86ISD::FMADD_RND: Opcode = X86ISD::FNMADD_RND; break; + case X86ISD::FMSUB: Opcode = X86ISD::FNMSUB; break; + case X86ISD::FMSUB_RND: Opcode = X86ISD::FNMSUB_RND; break; + case X86ISD::FNMADD: Opcode = ISD::FMA; break; + case X86ISD::FNMADD_RND: Opcode = X86ISD::FMADD_RND; break; + case X86ISD::FNMSUB: Opcode = X86ISD::FMSUB; break; + case X86ISD::FNMSUB_RND: Opcode = X86ISD::FMSUB_RND; break; + } + } + + if (NegAcc) { + switch (Opcode) { + default: llvm_unreachable("Unexpected opcode"); + case ISD::FMA: Opcode = X86ISD::FMSUB; break; + case X86ISD::FMADD_RND: Opcode = X86ISD::FMSUB_RND; break; + case X86ISD::FMSUB: Opcode = ISD::FMA; break; + case X86ISD::FMSUB_RND: Opcode = X86ISD::FMADD_RND; break; + case X86ISD::FNMADD: Opcode = X86ISD::FNMSUB; break; + case X86ISD::FNMADD_RND: Opcode = X86ISD::FNMSUB_RND; break; + case X86ISD::FNMSUB: Opcode = X86ISD::FNMADD; break; + case X86ISD::FNMSUB_RND: Opcode = X86ISD::FNMADD_RND; break; + } + } + + return Opcode; +} + static SDValue combineFMA(SDNode *N, SelectionDAG &DAG, const X86Subtarget &Subtarget) { - // TODO: Handle FMSUB/FNMADD/FNMSUB as the starting opcode. SDLoc dl(N); EVT VT = N->getValueType(0); @@ -35966,96 +37769,41 @@ static SDValue combineFMA(SDNode *N, SelectionDAG &DAG, SDValue B = N->getOperand(1); SDValue C = N->getOperand(2); - auto invertIfNegative = [](SDValue &V) { + auto invertIfNegative = [&DAG](SDValue &V) { if (SDValue NegVal = isFNEG(V.getNode())) { - V = NegVal; + V = DAG.getBitcast(V.getValueType(), NegVal); return true; } + // Look through extract_vector_elts. If it comes from an FNEG, create a + // new extract from the FNEG input. 
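  // e.g. (extract_vector_elt (fneg %v), 0) is rebuilt as
  // (extract_vector_elt %v, 0) and the negation is folded into the FMA
  // opcode instead (illustrative; only element 0 is matched below).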
+ if (V.getOpcode() == ISD::EXTRACT_VECTOR_ELT && + isa<ConstantSDNode>(V.getOperand(1)) && + cast<ConstantSDNode>(V.getOperand(1))->getZExtValue() == 0) { + if (SDValue NegVal = isFNEG(V.getOperand(0).getNode())) { + NegVal = DAG.getBitcast(V.getOperand(0).getValueType(), NegVal); + V = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SDLoc(V), V.getValueType(), + NegVal, V.getOperand(1)); + return true; + } + } + return false; }; // Do not convert the passthru input of scalar intrinsics. // FIXME: We could allow negations of the lower element only. - bool NegA = N->getOpcode() != X86ISD::FMADDS1 && - N->getOpcode() != X86ISD::FMADDS1_RND && invertIfNegative(A); + bool NegA = invertIfNegative(A); bool NegB = invertIfNegative(B); - bool NegC = N->getOpcode() != X86ISD::FMADDS3 && - N->getOpcode() != X86ISD::FMADDS3_RND && invertIfNegative(C); + bool NegC = invertIfNegative(C); - // Negative multiplication when NegA xor NegB - bool NegMul = (NegA != NegB); - bool HasNeg = NegA || NegB || NegC; + if (!NegA && !NegB && !NegC) + return SDValue(); - unsigned NewOpcode; - if (!NegMul) - NewOpcode = (!NegC) ? unsigned(ISD::FMA) : unsigned(X86ISD::FMSUB); - else - NewOpcode = (!NegC) ? X86ISD::FNMADD : X86ISD::FNMSUB; - - // For FMA, we risk reconstructing the node we started with. - // In order to avoid this, we check for negation or opcode change. If - // one of the two happened, then it is a new node and we return it. - if (N->getOpcode() == ISD::FMA) { - if (HasNeg || NewOpcode != N->getOpcode()) - return DAG.getNode(NewOpcode, dl, VT, A, B, C); - return SDValue(); - } - - if (N->getOpcode() == X86ISD::FMADD_RND) { - switch (NewOpcode) { - case ISD::FMA: NewOpcode = X86ISD::FMADD_RND; break; - case X86ISD::FMSUB: NewOpcode = X86ISD::FMSUB_RND; break; - case X86ISD::FNMADD: NewOpcode = X86ISD::FNMADD_RND; break; - case X86ISD::FNMSUB: NewOpcode = X86ISD::FNMSUB_RND; break; - } - } else if (N->getOpcode() == X86ISD::FMADDS1) { - switch (NewOpcode) { - case ISD::FMA: NewOpcode = X86ISD::FMADDS1; break; - case X86ISD::FMSUB: NewOpcode = X86ISD::FMSUBS1; break; - case X86ISD::FNMADD: NewOpcode = X86ISD::FNMADDS1; break; - case X86ISD::FNMSUB: NewOpcode = X86ISD::FNMSUBS1; break; - } - } else if (N->getOpcode() == X86ISD::FMADDS3) { - switch (NewOpcode) { - case ISD::FMA: NewOpcode = X86ISD::FMADDS3; break; - case X86ISD::FMSUB: NewOpcode = X86ISD::FMSUBS3; break; - case X86ISD::FNMADD: NewOpcode = X86ISD::FNMADDS3; break; - case X86ISD::FNMSUB: NewOpcode = X86ISD::FNMSUBS3; break; - } - } else if (N->getOpcode() == X86ISD::FMADDS1_RND) { - switch (NewOpcode) { - case ISD::FMA: NewOpcode = X86ISD::FMADDS1_RND; break; - case X86ISD::FMSUB: NewOpcode = X86ISD::FMSUBS1_RND; break; - case X86ISD::FNMADD: NewOpcode = X86ISD::FNMADDS1_RND; break; - case X86ISD::FNMSUB: NewOpcode = X86ISD::FNMSUBS1_RND; break; - } - } else if (N->getOpcode() == X86ISD::FMADDS3_RND) { - switch (NewOpcode) { - case ISD::FMA: NewOpcode = X86ISD::FMADDS3_RND; break; - case X86ISD::FMSUB: NewOpcode = X86ISD::FMSUBS3_RND; break; - case X86ISD::FNMADD: NewOpcode = X86ISD::FNMADDS3_RND; break; - case X86ISD::FNMSUB: NewOpcode = X86ISD::FNMSUBS3_RND; break; - } - } else if (N->getOpcode() == X86ISD::FMADD4S) { - switch (NewOpcode) { - case ISD::FMA: NewOpcode = X86ISD::FMADD4S; break; - case X86ISD::FMSUB: NewOpcode = X86ISD::FMSUB4S; break; - case X86ISD::FNMADD: NewOpcode = X86ISD::FNMADD4S; break; - case X86ISD::FNMSUB: NewOpcode = X86ISD::FNMSUB4S; break; - } - } else { - llvm_unreachable("Unexpected opcode!"); - } + unsigned NewOpcode = 
negateFMAOpcode(N->getOpcode(), NegA != NegB, NegC); - // Only return the node is the opcode was changed or one of the - // operand was negated. If not, we'll just recreate the same node. - if (HasNeg || NewOpcode != N->getOpcode()) { - if (N->getNumOperands() == 4) - return DAG.getNode(NewOpcode, dl, VT, A, B, C, N->getOperand(3)); - return DAG.getNode(NewOpcode, dl, VT, A, B, C); - } - - return SDValue(); + if (N->getNumOperands() == 4) + return DAG.getNode(NewOpcode, dl, VT, A, B, C, N->getOperand(3)); + return DAG.getNode(NewOpcode, dl, VT, A, B, C); } // Combine FMADDSUB(A, B, FNEG(C)) -> FMSUBADD(A, B, C) @@ -36124,6 +37872,10 @@ static SDValue combineZext(SDNode *N, SelectionDAG &DAG, if (SDValue NewCMov = combineToExtendCMOV(N, DAG)) return NewCMov; + if (DCI.isBeforeLegalizeOps()) + if (SDValue V = combineExtSetcc(N, DAG, Subtarget)) + return V; + if (SDValue V = combineToExtendVectorInReg(N, DAG, DCI, Subtarget)) return V; @@ -36131,7 +37883,7 @@ static SDValue combineZext(SDNode *N, SelectionDAG &DAG, return V; if (VT.isVector()) - if (SDValue R = WidenMaskArithmetic(N, DAG, DCI, Subtarget)) + if (SDValue R = WidenMaskArithmetic(N, DAG, Subtarget)) return R; if (SDValue DivRem8 = getDivRem8(N, DAG)) @@ -36153,13 +37905,23 @@ static SDValue combineVectorSizedSetCCEquality(SDNode *SetCC, SelectionDAG &DAG, ISD::CondCode CC = cast<CondCodeSDNode>(SetCC->getOperand(2))->get(); assert((CC == ISD::SETNE || CC == ISD::SETEQ) && "Bad comparison predicate"); - // We're looking for an oversized integer equality comparison, but ignore a - // comparison with zero because that gets special treatment in EmitTest(). + // We're looking for an oversized integer equality comparison. SDValue X = SetCC->getOperand(0); SDValue Y = SetCC->getOperand(1); EVT OpVT = X.getValueType(); unsigned OpSize = OpVT.getSizeInBits(); - if (!OpVT.isScalarInteger() || OpSize < 128 || isNullConstant(Y)) + if (!OpVT.isScalarInteger() || OpSize < 128) + return SDValue(); + + // Ignore a comparison with zero because that gets special treatment in + // EmitTest(). But make an exception for the special case of a pair of + // logically-combined vector-sized operands compared to zero. This pattern may + // be generated by the memcmp expansion pass with oversized integer compares + // (see PR33325). + bool IsOrXorXorCCZero = isNullConstant(Y) && X.getOpcode() == ISD::OR && + X.getOperand(0).getOpcode() == ISD::XOR && + X.getOperand(1).getOpcode() == ISD::XOR; + if (isNullConstant(Y) && !IsOrXorXorCCZero) return SDValue(); // Bail out if we know that this is not really just an oversized integer. @@ -36174,15 +37936,29 @@ static SDValue combineVectorSizedSetCCEquality(SDNode *SetCC, SelectionDAG &DAG, if ((OpSize == 128 && Subtarget.hasSSE2()) || (OpSize == 256 && Subtarget.hasAVX2())) { EVT VecVT = OpSize == 128 ? MVT::v16i8 : MVT::v32i8; - SDValue VecX = DAG.getBitcast(VecVT, X); - SDValue VecY = DAG.getBitcast(VecVT, Y); - + SDValue Cmp; + if (IsOrXorXorCCZero) { + // This is a bitwise-combined equality comparison of 2 pairs of vectors: + // setcc i128 (or (xor A, B), (xor C, D)), 0, eq|ne + // Use 2 vector equality compares and 'and' the results before doing a + // MOVMSK. 
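  // Illustrative lowering for the i128 case (assumed operands):
  //   (pmovmskb (and (pcmpeqb A, B), (pcmpeqb C, D))) == 0xFFFF
  // i.e. both pairs must be byte-for-byte equal for the OR of the XORs to
  // be zero.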
+ SDValue A = DAG.getBitcast(VecVT, X.getOperand(0).getOperand(0)); + SDValue B = DAG.getBitcast(VecVT, X.getOperand(0).getOperand(1)); + SDValue C = DAG.getBitcast(VecVT, X.getOperand(1).getOperand(0)); + SDValue D = DAG.getBitcast(VecVT, X.getOperand(1).getOperand(1)); + SDValue Cmp1 = DAG.getSetCC(DL, VecVT, A, B, ISD::SETEQ); + SDValue Cmp2 = DAG.getSetCC(DL, VecVT, C, D, ISD::SETEQ); + Cmp = DAG.getNode(ISD::AND, DL, VecVT, Cmp1, Cmp2); + } else { + SDValue VecX = DAG.getBitcast(VecVT, X); + SDValue VecY = DAG.getBitcast(VecVT, Y); + Cmp = DAG.getSetCC(DL, VecVT, VecX, VecY, ISD::SETEQ); + } // If all bytes match (bitmask is 0x(FFFF)FFFF), that's equality. // setcc i128 X, Y, eq --> setcc (pmovmskb (pcmpeqb X, Y)), 0xFFFF, eq // setcc i128 X, Y, ne --> setcc (pmovmskb (pcmpeqb X, Y)), 0xFFFF, ne // setcc i256 X, Y, eq --> setcc (vpmovmskb (vpcmpeqb X, Y)), 0xFFFFFFFF, eq // setcc i256 X, Y, ne --> setcc (vpmovmskb (vpcmpeqb X, Y)), 0xFFFFFFFF, ne - SDValue Cmp = DAG.getNode(X86ISD::PCMPEQ, DL, VecVT, VecX, VecY); SDValue MovMsk = DAG.getNode(X86ISD::MOVMSK, DL, MVT::i32, Cmp); SDValue FFFFs = DAG.getConstant(OpSize == 128 ? 0xFFFF : 0xFFFFFFFF, DL, MVT::i32); @@ -36198,10 +37974,10 @@ static SDValue combineSetCC(SDNode *N, SelectionDAG &DAG, SDValue LHS = N->getOperand(0); SDValue RHS = N->getOperand(1); EVT VT = N->getValueType(0); + EVT OpVT = LHS.getValueType(); SDLoc DL(N); if (CC == ISD::SETNE || CC == ISD::SETEQ) { - EVT OpVT = LHS.getValueType(); // 0-x == y --> x+y == 0 // 0-x != y --> x+y != 0 if (LHS.getOpcode() == ISD::SUB && isNullConstant(LHS.getOperand(0)) && @@ -36250,6 +38026,20 @@ static SDValue combineSetCC(SDNode *N, SelectionDAG &DAG, } } + // If we have AVX512, but not BWI and this is a vXi16/vXi8 setcc, just + // pre-promote its result type since vXi1 vectors don't get promoted + // during type legalization. + // NOTE: The element count check is to ignore operand types that need to + // go through type promotion to a 128-bit vector. + if (Subtarget.hasAVX512() && !Subtarget.hasBWI() && VT.isVector() && + VT.getVectorElementType() == MVT::i1 && VT.getVectorNumElements() > 4 && + (OpVT.getVectorElementType() == MVT::i8 || + OpVT.getVectorElementType() == MVT::i16)) { + SDValue Setcc = DAG.getNode(ISD::SETCC, DL, OpVT, LHS, RHS, + N->getOperand(2)); + return DAG.getNode(ISD::TRUNCATE, DL, VT, Setcc); + } + // For an SSE1-only target, lower a comparison of v4f32 to X86ISD::CMPP early // to avoid scalarization via legalization because v4i32 is not a legal type. if (Subtarget.hasSSE1() && !Subtarget.hasSSE2() && VT == MVT::v4i32 && @@ -36264,6 +38054,19 @@ static SDValue combineMOVMSK(SDNode *N, SelectionDAG &DAG, SDValue Src = N->getOperand(0); MVT SrcVT = Src.getSimpleValueType(); + // Perform constant folding. 
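  // e.g. (i32 movmsk (v4i32 build_vector -1, 0, -1, 0)) folds to the
  // constant 0b0101: bit i is set exactly when lane i has its sign bit
  // set, and undef lanes contribute a zero bit (illustrative example).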
+ if (ISD::isBuildVectorOfConstantSDNodes(Src.getNode())) { + assert(N->getValueType(0) == MVT::i32 && "Unexpected result type"); + APInt Imm(32, 0); + for (unsigned Idx = 0, e = Src.getNumOperands(); Idx < e; ++Idx) { + SDValue In = Src.getOperand(Idx); + if (!In.isUndef() && + cast<ConstantSDNode>(In)->getAPIntValue().isNegative()) + Imm.setBit(Idx); + } + return DAG.getConstant(Imm, SDLoc(N), N->getValueType(0)); + } + const TargetLowering &TLI = DAG.getTargetLoweringInfo(); TargetLowering::TargetLoweringOpt TLO(DAG, !DCI.isBeforeLegalize(), !DCI.isBeforeLegalizeOps()); @@ -36295,12 +38098,14 @@ static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG, Index.getOperand(0).getScalarValueSizeInBits() <= 32) { SmallVector<SDValue, 5> NewOps(N->op_begin(), N->op_end()); NewOps[4] = Index.getOperand(0); - DAG.UpdateNodeOperands(N, NewOps); - // The original sign extend has less users, add back to worklist in case - // it needs to be removed - DCI.AddToWorklist(Index.getNode()); - DCI.AddToWorklist(N); - return SDValue(N, 0); + SDNode *Res = DAG.UpdateNodeOperands(N, NewOps); + if (Res == N) { + // The original sign extend has less users, add back to worklist in + // case it needs to be removed + DCI.AddToWorklist(Index.getNode()); + DCI.AddToWorklist(N); + } + return SDValue(Res, 0); } } @@ -36313,9 +38118,10 @@ static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG, Index = DAG.getSExtOrTrunc(Index, DL, IndexVT); SmallVector<SDValue, 5> NewOps(N->op_begin(), N->op_end()); NewOps[4] = Index; - DAG.UpdateNodeOperands(N, NewOps); - DCI.AddToWorklist(N); - return SDValue(N, 0); + SDNode *Res = DAG.UpdateNodeOperands(N, NewOps); + if (Res == N) + DCI.AddToWorklist(N); + return SDValue(Res, 0); } // Try to remove zero extends from 32->64 if we know the sign bit of @@ -36326,32 +38132,24 @@ static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG, if (DAG.SignBitIsZero(Index.getOperand(0))) { SmallVector<SDValue, 5> NewOps(N->op_begin(), N->op_end()); NewOps[4] = Index.getOperand(0); - DAG.UpdateNodeOperands(N, NewOps); - // The original zero extend has less users, add back to worklist in case - // it needs to be removed - DCI.AddToWorklist(Index.getNode()); - DCI.AddToWorklist(N); - return SDValue(N, 0); + SDNode *Res = DAG.UpdateNodeOperands(N, NewOps); + if (Res == N) { + // The original sign extend has less users, add back to worklist in + // case it needs to be removed + DCI.AddToWorklist(Index.getNode()); + DCI.AddToWorklist(N); + } + return SDValue(Res, 0); } } } - // Gather and Scatter instructions use k-registers for masks. The type of - // the masks is v*i1. So the mask will be truncated anyway. - // The SIGN_EXTEND_INREG my be dropped. - SDValue Mask = N->getOperand(2); - if (Subtarget.hasAVX512() && Mask.getOpcode() == ISD::SIGN_EXTEND_INREG) { - SmallVector<SDValue, 5> NewOps(N->op_begin(), N->op_end()); - NewOps[2] = Mask.getOperand(0); - DAG.UpdateNodeOperands(N, NewOps); - return SDValue(N, 0); - } - // With AVX2 we only demand the upper bit of the mask. 
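  // (AVX2 gathers such as VPGATHERDD test only the sign bit of each mask
  // element, so SimplifyDemandedBits below can strip anything feeding the
  // mask that merely changes its lower bits.)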
if (!Subtarget.hasAVX512()) { const TargetLowering &TLI = DAG.getTargetLoweringInfo(); TargetLowering::TargetLoweringOpt TLO(DAG, !DCI.isBeforeLegalize(), !DCI.isBeforeLegalizeOps()); + SDValue Mask = N->getOperand(2); KnownBits Known; APInt DemandedMask(APInt::getSignMask(Mask.getScalarValueSizeInBits())); if (TLI.SimplifyDemandedBits(Mask, DemandedMask, Known, TLO)) { @@ -36448,11 +38246,11 @@ static SDValue combineUIntToFP(SDNode *N, SelectionDAG &DAG, SDValue Op0 = N->getOperand(0); EVT VT = N->getValueType(0); EVT InVT = Op0.getValueType(); - EVT InSVT = InVT.getScalarType(); + // UINT_TO_FP(vXi1) -> SINT_TO_FP(SEXT(vXi1 to vXi32)) // UINT_TO_FP(vXi8) -> SINT_TO_FP(ZEXT(vXi8 to vXi32)) // UINT_TO_FP(vXi16) -> SINT_TO_FP(ZEXT(vXi16 to vXi32)) - if (InVT.isVector() && (InSVT == MVT::i8 || InSVT == MVT::i16)) { + if (InVT.isVector() && InVT.getScalarSizeInBits() < 32) { SDLoc dl(N); EVT DstVT = EVT::getVectorVT(*DAG.getContext(), MVT::i32, InVT.getVectorNumElements()); @@ -36482,14 +38280,11 @@ static SDValue combineSIntToFP(SDNode *N, SelectionDAG &DAG, SDValue Op0 = N->getOperand(0); EVT VT = N->getValueType(0); EVT InVT = Op0.getValueType(); - EVT InSVT = InVT.getScalarType(); // SINT_TO_FP(vXi1) -> SINT_TO_FP(SEXT(vXi1 to vXi32)) // SINT_TO_FP(vXi8) -> SINT_TO_FP(SEXT(vXi8 to vXi32)) // SINT_TO_FP(vXi16) -> SINT_TO_FP(SEXT(vXi16 to vXi32)) - if (InVT.isVector() && - (InSVT == MVT::i8 || InSVT == MVT::i16 || - (InSVT == MVT::i1 && !DAG.getTargetLoweringInfo().isTypeLegal(InVT)))) { + if (InVT.isVector() && InVT.getScalarSizeInBits() < 32) { SDLoc dl(N); EVT DstVT = EVT::getVectorVT(*DAG.getContext(), MVT::i32, InVT.getVectorNumElements()); @@ -36524,6 +38319,11 @@ static SDValue combineSIntToFP(SDNode *N, SelectionDAG &DAG, if (VT == MVT::f16 || VT == MVT::f128) return SDValue(); + // If we have AVX512DQ we can use packed conversion instructions unless + // the VT is f80. + if (Subtarget.hasDQI() && VT != MVT::f80) + return SDValue(); + if (!Ld->isVolatile() && !VT.isVector() && ISD::isNON_EXTLoad(Op0.getNode()) && Op0.hasOneUse() && !Subtarget.is64Bit() && LdVT == MVT::i64) { @@ -36778,15 +38578,9 @@ static SDValue combineLoopMAddPattern(SDNode *N, SelectionDAG &DAG, EVT VT = N->getValueType(0); - unsigned RegSize = 128; - if (Subtarget.hasBWI()) - RegSize = 512; - else if (Subtarget.hasAVX2()) - RegSize = 256; - unsigned VectorSize = VT.getVectorNumElements() * 16; // If the vector size is less than 128, or greater than the supported RegSize, // do not use PMADD. 
- if (VectorSize < 128 || VectorSize > RegSize) + if (VT.getVectorNumElements() < 8) return SDValue(); SDLoc DL(N); @@ -36800,7 +38594,13 @@ static SDValue combineLoopMAddPattern(SDNode *N, SelectionDAG &DAG, SDValue N1 = DAG.getNode(ISD::TRUNCATE, DL, ReducedVT, MulOp->getOperand(1)); // Madd vector size is half of the original vector size - SDValue Madd = DAG.getNode(X86ISD::VPMADDWD, DL, MAddVT, N0, N1); + auto PMADDWDBuilder = [](SelectionDAG &DAG, const SDLoc &DL, + ArrayRef<SDValue> Ops) { + MVT VT = MVT::getVectorVT(MVT::i32, Ops[0].getValueSizeInBits() / 32); + return DAG.getNode(X86ISD::VPMADDWD, DL, VT, Ops); + }; + SDValue Madd = SplitOpsAndApply(DAG, Subtarget, DL, MAddVT, { N0, N1 }, + PMADDWDBuilder); // Fill the rest of the output with 0 SDValue Zero = getZeroVector(Madd.getSimpleValueType(), Subtarget, DAG, DL); SDValue Concat = DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, Madd, Zero); @@ -36824,12 +38624,12 @@ static SDValue combineLoopSADPattern(SDNode *N, SelectionDAG &DAG, return SDValue(); unsigned RegSize = 128; - if (Subtarget.hasBWI()) + if (Subtarget.useBWIRegs()) RegSize = 512; - else if (Subtarget.hasAVX2()) + else if (Subtarget.hasAVX()) RegSize = 256; - // We only handle v16i32 for SSE2 / v32i32 for AVX2 / v64i32 for AVX512. + // We only handle v16i32 for SSE2 / v32i32 for AVX / v64i32 for AVX512. // TODO: We should be able to handle larger vectors by splitting them before // feeding them into several SADs, and then reducing over those. if (VT.getSizeInBits() / 4 > RegSize) @@ -36855,7 +38655,7 @@ static SDValue combineLoopSADPattern(SDNode *N, SelectionDAG &DAG, // reduction. Note that the number of elements of the result of SAD is less // than the number of elements of its input. Therefore, we could only update // part of elements in the reduction vector. - SDValue Sad = createPSADBW(DAG, Op0, Op1, DL); + SDValue Sad = createPSADBW(DAG, Op0, Op1, DL, Subtarget); // The output of PSADBW is a vector of i64. // We need to turn the vector of i64 into a vector of i32. @@ -36905,6 +38705,236 @@ static SDValue combineIncDecVector(SDNode *N, SelectionDAG &DAG) { return DAG.getNode(NewOpcode, SDLoc(N), VT, N->getOperand(0), AllOnesVec); } +static SDValue matchPMADDWD(SelectionDAG &DAG, SDValue Op0, SDValue Op1, + const SDLoc &DL, EVT VT, + const X86Subtarget &Subtarget) { + // Example of pattern we try to detect: + // t := (v8i32 mul (sext (v8i16 x0), (sext (v8i16 x1)))) + //(add (build_vector (extract_elt t, 0), + // (extract_elt t, 2), + // (extract_elt t, 4), + // (extract_elt t, 6)), + // (build_vector (extract_elt t, 1), + // (extract_elt t, 3), + // (extract_elt t, 5), + // (extract_elt t, 7))) + + if (!Subtarget.hasSSE2()) + return SDValue(); + + if (Op0.getOpcode() != ISD::BUILD_VECTOR || + Op1.getOpcode() != ISD::BUILD_VECTOR) + return SDValue(); + + if (!VT.isVector() || VT.getVectorElementType() != MVT::i32 || + VT.getVectorNumElements() < 4 || + !isPowerOf2_32(VT.getVectorNumElements())) + return SDValue(); + + // Check if one of Op0,Op1 is of the form: + // (build_vector (extract_elt Mul, 0), + // (extract_elt Mul, 2), + // (extract_elt Mul, 4), + // ... + // the other is of the form: + // (build_vector (extract_elt Mul, 1), + // (extract_elt Mul, 3), + // (extract_elt Mul, 5), + // ... + // and identify Mul. + SDValue Mul; + for (unsigned i = 0, e = VT.getVectorNumElements(); i != e; i += 2) { + SDValue Op0L = Op0->getOperand(i), Op1L = Op1->getOperand(i), + Op0H = Op0->getOperand(i + 1), Op1H = Op1->getOperand(i + 1); + // TODO: Be more tolerant to undefs. 
+ if (Op0L.getOpcode() != ISD::EXTRACT_VECTOR_ELT || + Op1L.getOpcode() != ISD::EXTRACT_VECTOR_ELT || + Op0H.getOpcode() != ISD::EXTRACT_VECTOR_ELT || + Op1H.getOpcode() != ISD::EXTRACT_VECTOR_ELT) + return SDValue(); + auto *Const0L = dyn_cast<ConstantSDNode>(Op0L->getOperand(1)); + auto *Const1L = dyn_cast<ConstantSDNode>(Op1L->getOperand(1)); + auto *Const0H = dyn_cast<ConstantSDNode>(Op0H->getOperand(1)); + auto *Const1H = dyn_cast<ConstantSDNode>(Op1H->getOperand(1)); + if (!Const0L || !Const1L || !Const0H || !Const1H) + return SDValue(); + unsigned Idx0L = Const0L->getZExtValue(), Idx1L = Const1L->getZExtValue(), + Idx0H = Const0H->getZExtValue(), Idx1H = Const1H->getZExtValue(); + // Commutativity of mul allows factors of a product to reorder. + if (Idx0L > Idx1L) + std::swap(Idx0L, Idx1L); + if (Idx0H > Idx1H) + std::swap(Idx0H, Idx1H); + // Commutativity of add allows pairs of factors to reorder. + if (Idx0L > Idx0H) { + std::swap(Idx0L, Idx0H); + std::swap(Idx1L, Idx1H); + } + if (Idx0L != 2 * i || Idx1L != 2 * i + 1 || Idx0H != 2 * i + 2 || + Idx1H != 2 * i + 3) + return SDValue(); + if (!Mul) { + // First time an extract_elt's source vector is visited. Must be a MUL + // with 2X number of vector elements than the BUILD_VECTOR. + // Both extracts must be from same MUL. + Mul = Op0L->getOperand(0); + if (Mul->getOpcode() != ISD::MUL || + Mul.getValueType().getVectorNumElements() != 2 * e) + return SDValue(); + } + // Check that the extract is from the same MUL previously seen. + if (Mul != Op0L->getOperand(0) || Mul != Op1L->getOperand(0) || + Mul != Op0H->getOperand(0) || Mul != Op1H->getOperand(0)) + return SDValue(); + } + + // Check if the Mul source can be safely shrunk. + ShrinkMode Mode; + if (!canReduceVMulWidth(Mul.getNode(), DAG, Mode) || Mode == MULU16) + return SDValue(); + + auto PMADDBuilder = [](SelectionDAG &DAG, const SDLoc &DL, + ArrayRef<SDValue> Ops) { + // Shrink by adding truncate nodes and let DAGCombine fold with the + // sources. + EVT InVT = Ops[0].getValueType(); + assert(InVT.getScalarType() == MVT::i32 && + "Unexpected scalar element type"); + assert(InVT == Ops[1].getValueType() && "Operands' types mismatch"); + EVT ResVT = EVT::getVectorVT(*DAG.getContext(), MVT::i32, + InVT.getVectorNumElements() / 2); + EVT TruncVT = EVT::getVectorVT(*DAG.getContext(), MVT::i16, + InVT.getVectorNumElements()); + return DAG.getNode(X86ISD::VPMADDWD, DL, ResVT, + DAG.getNode(ISD::TRUNCATE, DL, TruncVT, Ops[0]), + DAG.getNode(ISD::TRUNCATE, DL, TruncVT, Ops[1])); + }; + return SplitOpsAndApply(DAG, Subtarget, DL, VT, + { Mul.getOperand(0), Mul.getOperand(1) }, + PMADDBuilder); +} + +// Attempt to turn this pattern into PMADDWD. +// (mul (add (zext (build_vector)), (zext (build_vector))), +// (add (zext (build_vector)), (zext (build_vector))) +static SDValue matchPMADDWD_2(SelectionDAG &DAG, SDValue N0, SDValue N1, + const SDLoc &DL, EVT VT, + const X86Subtarget &Subtarget) { + if (!Subtarget.hasSSE2()) + return SDValue(); + + if (N0.getOpcode() != ISD::MUL || N1.getOpcode() != ISD::MUL) + return SDValue(); + + if (!VT.isVector() || VT.getVectorElementType() != MVT::i32 || + VT.getVectorNumElements() < 4 || + !isPowerOf2_32(VT.getVectorNumElements())) + return SDValue(); + + SDValue N00 = N0.getOperand(0); + SDValue N01 = N0.getOperand(1); + SDValue N10 = N1.getOperand(0); + SDValue N11 = N1.getOperand(1); + + // All inputs need to be sign extends. + // TODO: Support ZERO_EXTEND from known positive? 
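  // (VPMADDWD multiplies signed 16-bit elements and sums adjacent pairs
  // into 32-bit lanes, which is why the inputs must be sign extends; a
  // zero extend would only be equivalent for known-non-negative values.)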
+ if (N00.getOpcode() != ISD::SIGN_EXTEND || + N01.getOpcode() != ISD::SIGN_EXTEND || + N10.getOpcode() != ISD::SIGN_EXTEND || + N11.getOpcode() != ISD::SIGN_EXTEND) + return SDValue(); + + // Peek through the extends. + N00 = N00.getOperand(0); + N01 = N01.getOperand(0); + N10 = N10.getOperand(0); + N11 = N11.getOperand(0); + + // Must be extending from vXi16. + EVT InVT = N00.getValueType(); + if (InVT.getVectorElementType() != MVT::i16 || N01.getValueType() != InVT || + N10.getValueType() != InVT || N11.getValueType() != InVT) + return SDValue(); + + // All inputs should be build_vectors. + if (N00.getOpcode() != ISD::BUILD_VECTOR || + N01.getOpcode() != ISD::BUILD_VECTOR || + N10.getOpcode() != ISD::BUILD_VECTOR || + N11.getOpcode() != ISD::BUILD_VECTOR) + return SDValue(); + + // For each element, we need to ensure we have an odd element from one vector + // multiplied by the odd element of another vector and the even element from + // one of the same vectors being multiplied by the even element from the + // other vector. So we need to make sure for each element i, this operator + // is being performed: + // A[2 * i] * B[2 * i] + A[2 * i + 1] * B[2 * i + 1] + SDValue In0, In1; + for (unsigned i = 0; i != N00.getNumOperands(); ++i) { + SDValue N00Elt = N00.getOperand(i); + SDValue N01Elt = N01.getOperand(i); + SDValue N10Elt = N10.getOperand(i); + SDValue N11Elt = N11.getOperand(i); + // TODO: Be more tolerant to undefs. + if (N00Elt.getOpcode() != ISD::EXTRACT_VECTOR_ELT || + N01Elt.getOpcode() != ISD::EXTRACT_VECTOR_ELT || + N10Elt.getOpcode() != ISD::EXTRACT_VECTOR_ELT || + N11Elt.getOpcode() != ISD::EXTRACT_VECTOR_ELT) + return SDValue(); + auto *ConstN00Elt = dyn_cast<ConstantSDNode>(N00Elt.getOperand(1)); + auto *ConstN01Elt = dyn_cast<ConstantSDNode>(N01Elt.getOperand(1)); + auto *ConstN10Elt = dyn_cast<ConstantSDNode>(N10Elt.getOperand(1)); + auto *ConstN11Elt = dyn_cast<ConstantSDNode>(N11Elt.getOperand(1)); + if (!ConstN00Elt || !ConstN01Elt || !ConstN10Elt || !ConstN11Elt) + return SDValue(); + unsigned IdxN00 = ConstN00Elt->getZExtValue(); + unsigned IdxN01 = ConstN01Elt->getZExtValue(); + unsigned IdxN10 = ConstN10Elt->getZExtValue(); + unsigned IdxN11 = ConstN11Elt->getZExtValue(); + // Add is commutative so indices can be reordered. + if (IdxN00 > IdxN10) { + std::swap(IdxN00, IdxN10); + std::swap(IdxN01, IdxN11); + } + // N0 indices be the even elemtn. N1 indices must be the next odd element. + if (IdxN00 != 2 * i || IdxN10 != 2 * i + 1 || + IdxN01 != 2 * i || IdxN11 != 2 * i + 1) + return SDValue(); + SDValue N00In = N00Elt.getOperand(0); + SDValue N01In = N01Elt.getOperand(0); + SDValue N10In = N10Elt.getOperand(0); + SDValue N11In = N11Elt.getOperand(0); + // First time we find an input capture it. + if (!In0) { + In0 = N00In; + In1 = N01In; + } + // Mul is commutative so the input vectors can be in any order. + // Canonicalize to make the compares easier. + if (In0 != N00In) + std::swap(N00In, N01In); + if (In0 != N10In) + std::swap(N10In, N11In); + if (In0 != N00In || In1 != N01In || In0 != N10In || In1 != N11In) + return SDValue(); + } + + auto PMADDBuilder = [](SelectionDAG &DAG, const SDLoc &DL, + ArrayRef<SDValue> Ops) { + // Shrink by adding truncate nodes and let DAGCombine fold with the + // sources. 
+ EVT InVT = Ops[0].getValueType(); + assert(InVT.getScalarType() == MVT::i16 && + "Unexpected scalar element type"); + assert(InVT == Ops[1].getValueType() && "Operands' types mismatch"); + EVT ResVT = EVT::getVectorVT(*DAG.getContext(), MVT::i32, + InVT.getVectorNumElements() / 2); + return DAG.getNode(X86ISD::VPMADDWD, DL, ResVT, Ops[0], Ops[1]); + }; + return SplitOpsAndApply(DAG, Subtarget, DL, VT, { In0, In1 }, + PMADDBuilder); +} + static SDValue combineAdd(SDNode *N, SelectionDAG &DAG, const X86Subtarget &Subtarget) { const SDNodeFlags Flags = N->getFlags(); @@ -36918,11 +38948,22 @@ static SDValue combineAdd(SDNode *N, SelectionDAG &DAG, SDValue Op0 = N->getOperand(0); SDValue Op1 = N->getOperand(1); + if (SDValue MAdd = matchPMADDWD(DAG, Op0, Op1, SDLoc(N), VT, Subtarget)) + return MAdd; + if (SDValue MAdd = matchPMADDWD_2(DAG, Op0, Op1, SDLoc(N), VT, Subtarget)) + return MAdd; + // Try to synthesize horizontal adds from adds of shuffles. - if (((Subtarget.hasSSSE3() && (VT == MVT::v8i16 || VT == MVT::v4i32)) || - (Subtarget.hasInt256() && (VT == MVT::v16i16 || VT == MVT::v8i32))) && - isHorizontalBinOp(Op0, Op1, true)) - return DAG.getNode(X86ISD::HADD, SDLoc(N), VT, Op0, Op1); + if ((VT == MVT::v8i16 || VT == MVT::v4i32 || VT == MVT::v16i16 || + VT == MVT::v8i32) && + Subtarget.hasSSSE3() && isHorizontalBinOp(Op0, Op1, true)) { + auto HADDBuilder = [](SelectionDAG &DAG, const SDLoc &DL, + ArrayRef<SDValue> Ops) { + return DAG.getNode(X86ISD::HADD, DL, Ops[0].getValueType(), Ops); + }; + return SplitOpsAndApply(DAG, Subtarget, SDLoc(N), VT, {Op0, Op1}, + HADDBuilder); + } if (SDValue V = combineIncDecVector(N, DAG)) return V; @@ -36936,20 +38977,19 @@ static SDValue combineSubToSubus(SDNode *N, SelectionDAG &DAG, SDValue Op1 = N->getOperand(1); EVT VT = N->getValueType(0); - // PSUBUS is supported, starting from SSE2, but special preprocessing - // for v8i32 requires umin, which appears in SSE41. + // PSUBUS is supported, starting from SSE2, but truncation for v8i32 + // is only worth it with SSSE3 (PSHUFB). if (!(Subtarget.hasSSE2() && (VT == MVT::v16i8 || VT == MVT::v8i16)) && - !(Subtarget.hasSSE41() && (VT == MVT::v8i32)) && - !(Subtarget.hasAVX2() && (VT == MVT::v32i8 || VT == MVT::v16i16)) && - !(Subtarget.hasAVX512() && Subtarget.hasBWI() && - (VT == MVT::v64i8 || VT == MVT::v32i16 || VT == MVT::v16i32 || - VT == MVT::v8i64))) + !(Subtarget.hasSSSE3() && (VT == MVT::v8i32 || VT == MVT::v8i64)) && + !(Subtarget.hasAVX() && (VT == MVT::v32i8 || VT == MVT::v16i16)) && + !(Subtarget.useBWIRegs() && (VT == MVT::v64i8 || VT == MVT::v32i16 || + VT == MVT::v16i32 || VT == MVT::v8i64))) return SDValue(); SDValue SubusLHS, SubusRHS; // Try to find umax(a,b) - b or a - umin(a,b) patterns // they may be converted to subus(a,b). - // TODO: Need to add IR cannonicialization for this code. + // TODO: Need to add IR canonicalization for this code. if (Op0.getOpcode() == ISD::UMAX) { SubusRHS = Op1; SDValue MaxLHS = Op0.getOperand(0); @@ -36973,10 +39013,16 @@ static SDValue combineSubToSubus(SDNode *N, SelectionDAG &DAG, } else return SDValue(); + auto SUBUSBuilder = [](SelectionDAG &DAG, const SDLoc &DL, + ArrayRef<SDValue> Ops) { + return DAG.getNode(X86ISD::SUBUS, DL, Ops[0].getValueType(), Ops); + }; + // PSUBUS doesn't support v8i32/v8i64/v16i32, but it can be enabled with // special preprocessing in some cases. 
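  // Illustrative example (assumed value ranges): for a v8i32 subus
  // pattern whose operands are known to have been zero extended from 16
  // bits, both sides are truncated to v8i16, a single PSUBUSW is emitted,
  // and the result is zero extended back to v8i32 further down.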
if (VT != MVT::v8i32 && VT != MVT::v16i32 && VT != MVT::v8i64) - return DAG.getNode(X86ISD::SUBUS, SDLoc(N), VT, SubusLHS, SubusRHS); + return SplitOpsAndApply(DAG, Subtarget, SDLoc(N), VT, + { SubusLHS, SubusRHS }, SUBUSBuilder); // Special preprocessing case can be only applied // if the value was zero extended from 16 bit, @@ -37006,8 +39052,9 @@ static SDValue combineSubToSubus(SDNode *N, SelectionDAG &DAG, SDValue NewSubusLHS = DAG.getZExtOrTrunc(SubusLHS, SDLoc(SubusLHS), ShrinkedType); SDValue NewSubusRHS = DAG.getZExtOrTrunc(UMin, SDLoc(SubusRHS), ShrinkedType); - SDValue Psubus = DAG.getNode(X86ISD::SUBUS, SDLoc(N), ShrinkedType, - NewSubusLHS, NewSubusRHS); + SDValue Psubus = + SplitOpsAndApply(DAG, Subtarget, SDLoc(N), ShrinkedType, + { NewSubusLHS, NewSubusRHS }, SUBUSBuilder); // Zero extend the result, it may be used somewhere as 32 bit, // if not zext and following trunc will shrink. return DAG.getZExtOrTrunc(Psubus, SDLoc(N), ExtType); @@ -37038,10 +39085,16 @@ static SDValue combineSub(SDNode *N, SelectionDAG &DAG, // Try to synthesize horizontal subs from subs of shuffles. EVT VT = N->getValueType(0); - if (((Subtarget.hasSSSE3() && (VT == MVT::v8i16 || VT == MVT::v4i32)) || - (Subtarget.hasInt256() && (VT == MVT::v16i16 || VT == MVT::v8i32))) && - isHorizontalBinOp(Op0, Op1, false)) - return DAG.getNode(X86ISD::HSUB, SDLoc(N), VT, Op0, Op1); + if ((VT == MVT::v8i16 || VT == MVT::v4i32 || VT == MVT::v16i16 || + VT == MVT::v8i32) && + Subtarget.hasSSSE3() && isHorizontalBinOp(Op0, Op1, false)) { + auto HSUBBuilder = [](SelectionDAG &DAG, const SDLoc &DL, + ArrayRef<SDValue> Ops) { + return DAG.getNode(X86ISD::HSUB, DL, Ops[0].getValueType(), Ops); + }; + return SplitOpsAndApply(DAG, Subtarget, SDLoc(N), VT, {Op0, Op1}, + HSUBBuilder); + } if (SDValue V = combineIncDecVector(N, DAG)) return V; @@ -37145,28 +39198,6 @@ static SDValue combineVSZext(SDNode *N, SelectionDAG &DAG, return SDValue(); } -static SDValue combineTestM(SDNode *N, SelectionDAG &DAG, - const X86Subtarget &Subtarget) { - SDValue Op0 = N->getOperand(0); - SDValue Op1 = N->getOperand(1); - - MVT VT = N->getSimpleValueType(0); - SDLoc DL(N); - - // TEST (AND a, b) ,(AND a, b) -> TEST a, b - if (Op0 == Op1 && Op1->getOpcode() == ISD::AND) - return DAG.getNode(X86ISD::TESTM, DL, VT, Op0->getOperand(0), - Op0->getOperand(1)); - - // TEST op0, BUILD_VECTOR(all_zero) -> BUILD_VECTOR(all_zero) - // TEST BUILD_VECTOR(all_zero), op1 -> BUILD_VECTOR(all_zero) - if (ISD::isBuildVectorAllZeros(Op0.getNode()) || - ISD::isBuildVectorAllZeros(Op1.getNode())) - return getZeroVector(VT, Subtarget, DAG, DL); - - return SDValue(); -} - static SDValue combineVectorCompare(SDNode *N, SelectionDAG &DAG, const X86Subtarget &Subtarget) { MVT VT = N->getSimpleValueType(0); @@ -37190,9 +39221,7 @@ static SDValue combineInsertSubvector(SDNode *N, SelectionDAG &DAG, MVT OpVT = N->getSimpleValueType(0); - // Early out for mask vectors. - if (OpVT.getVectorElementType() == MVT::i1) - return SDValue(); + bool IsI1Vector = OpVT.getVectorElementType() == MVT::i1; SDLoc dl(N); SDValue Vec = N->getOperand(0); @@ -37204,23 +39233,40 @@ static SDValue combineInsertSubvector(SDNode *N, SelectionDAG &DAG, if (ISD::isBuildVectorAllZeros(Vec.getNode())) { // Inserting zeros into zeros is a nop. if (ISD::isBuildVectorAllZeros(SubVec.getNode())) - return Vec; + return getZeroVector(OpVT, Subtarget, DAG, dl); // If we're inserting into a zero vector and then into a larger zero vector, // just insert into the larger zero vector directly. 
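  // e.g. (insert_subvector (v16i32 zero),
  //        (insert_subvector (v8i32 zero), X, 2), 8)
  //      --> (insert_subvector (v16i32 zero), X, 10)
  // with the inner and outer indices added (illustrative types/indices).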
if (SubVec.getOpcode() == ISD::INSERT_SUBVECTOR && ISD::isBuildVectorAllZeros(SubVec.getOperand(0).getNode())) { unsigned Idx2Val = SubVec.getConstantOperandVal(2); - return DAG.getNode(ISD::INSERT_SUBVECTOR, dl, OpVT, Vec, + return DAG.getNode(ISD::INSERT_SUBVECTOR, dl, OpVT, + getZeroVector(OpVT, Subtarget, DAG, dl), SubVec.getOperand(1), DAG.getIntPtrConstant(IdxVal + Idx2Val, dl)); } + // If we're inserting into a zero vector and our input was extracted from an + // insert into a zero vector of the same type and the extraction was at + // least as large as the original insertion. Just insert the original + // subvector into a zero vector. + if (SubVec.getOpcode() == ISD::EXTRACT_SUBVECTOR && IdxVal == 0 && + SubVec.getConstantOperandVal(1) == 0 && + SubVec.getOperand(0).getOpcode() == ISD::INSERT_SUBVECTOR) { + SDValue Ins = SubVec.getOperand(0); + if (Ins.getConstantOperandVal(2) == 0 && + ISD::isBuildVectorAllZeros(Ins.getOperand(0).getNode()) && + Ins.getOperand(1).getValueSizeInBits() <= SubVecVT.getSizeInBits()) + return DAG.getNode(ISD::INSERT_SUBVECTOR, dl, OpVT, + getZeroVector(OpVT, Subtarget, DAG, dl), + Ins.getOperand(1), N->getOperand(2)); + } + // If we're inserting a bitcast into zeros, rewrite the insert and move the // bitcast to the other side. This helps with detecting zero extending // during isel. // TODO: Is this useful for other indices than 0? - if (SubVec.getOpcode() == ISD::BITCAST && IdxVal == 0) { + if (!IsI1Vector && SubVec.getOpcode() == ISD::BITCAST && IdxVal == 0) { MVT CastVT = SubVec.getOperand(0).getSimpleValueType(); unsigned NumElems = OpVT.getSizeInBits() / CastVT.getScalarSizeInBits(); MVT NewVT = MVT::getVectorVT(CastVT.getVectorElementType(), NumElems); @@ -37231,6 +39277,10 @@ static SDValue combineInsertSubvector(SDNode *N, SelectionDAG &DAG, } } + // Stop here if this is an i1 vector. + if (IsI1Vector) + return SDValue(); + // If this is an insert of an extract, combine to a shuffle. Don't do this // if the insert or extract can be represented with a subregister operation. if (SubVec.getOpcode() == ISD::EXTRACT_SUBVECTOR && @@ -37317,7 +39367,6 @@ static SDValue combineInsertSubvector(SDNode *N, SelectionDAG &DAG, if (!Vec.getOperand(0).isUndef() && Vec.hasOneUse()) { Vec = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, OpVT, DAG.getUNDEF(OpVT), SubVec2, Vec.getOperand(2)); - DCI.AddToWorklist(Vec.getNode()); return DAG.getNode(ISD::INSERT_SUBVECTOR, dl, OpVT, Vec, SubVec, N->getOperand(2)); @@ -37352,6 +39401,75 @@ static SDValue combineExtractSubvector(SDNode *N, SelectionDAG &DAG, OpVT, SDLoc(N), InVec.getNode()->ops().slice(IdxVal, OpVT.getVectorNumElements())); + // If we're extracting the lowest subvector and we're the only user, + // we may be able to perform this with a smaller vector width. + if (IdxVal == 0 && InVec.hasOneUse()) { + unsigned InOpcode = InVec.getOpcode(); + if (OpVT == MVT::v2f64 && InVec.getValueType() == MVT::v4f64) { + // v2f64 CVTDQ2PD(v4i32). + if (InOpcode == ISD::SINT_TO_FP && + InVec.getOperand(0).getValueType() == MVT::v4i32) { + return DAG.getNode(X86ISD::CVTSI2P, SDLoc(N), OpVT, InVec.getOperand(0)); + } + // v2f64 CVTPS2PD(v4f32). + if (InOpcode == ISD::FP_EXTEND && + InVec.getOperand(0).getValueType() == MVT::v4f32) { + return DAG.getNode(X86ISD::VFPEXT, SDLoc(N), OpVT, InVec.getOperand(0)); + } + } + if ((InOpcode == X86ISD::VZEXT || InOpcode == X86ISD::VSEXT) && + OpVT.is128BitVector() && + InVec.getOperand(0).getSimpleValueType().is128BitVector()) { + unsigned ExtOp = InOpcode == X86ISD::VZEXT ? 
ISD::ZERO_EXTEND_VECTOR_INREG + : ISD::SIGN_EXTEND_VECTOR_INREG; + return DAG.getNode(ExtOp, SDLoc(N), OpVT, InVec.getOperand(0)); + } + } + + return SDValue(); +} + +static SDValue combineScalarToVector(SDNode *N, SelectionDAG &DAG) { + EVT VT = N->getValueType(0); + SDValue Src = N->getOperand(0); + + // If this is a scalar to vector to v1i1 from an AND with 1, bypass the and. + // This occurs frequently in our masked scalar intrinsic code and our + // floating point select lowering with AVX512. + // TODO: SimplifyDemandedBits instead? + if (VT == MVT::v1i1 && Src.getOpcode() == ISD::AND && Src.hasOneUse()) + if (auto *C = dyn_cast<ConstantSDNode>(Src.getOperand(1))) + if (C->getAPIntValue().isOneValue()) + return DAG.getNode(ISD::SCALAR_TO_VECTOR, SDLoc(N), MVT::v1i1, + Src.getOperand(0)); + + return SDValue(); +} + +// Simplify PMULDQ and PMULUDQ operations. +static SDValue combinePMULDQ(SDNode *N, SelectionDAG &DAG, + TargetLowering::DAGCombinerInfo &DCI) { + SDValue LHS = N->getOperand(0); + SDValue RHS = N->getOperand(1); + + const TargetLowering &TLI = DAG.getTargetLoweringInfo(); + TargetLowering::TargetLoweringOpt TLO(DAG, !DCI.isBeforeLegalize(), + !DCI.isBeforeLegalizeOps()); + APInt DemandedMask(APInt::getLowBitsSet(64, 32)); + + // PMULQDQ/PMULUDQ only uses lower 32 bits from each vector element. + KnownBits LHSKnown; + if (TLI.SimplifyDemandedBits(LHS, DemandedMask, LHSKnown, TLO)) { + DCI.CommitTargetLoweringOpt(TLO); + return SDValue(N, 0); + } + + KnownBits RHSKnown; + if (TLI.SimplifyDemandedBits(RHS, DemandedMask, RHSKnown, TLO)) { + DCI.CommitTargetLoweringOpt(TLO); + return SDValue(N, 0); + } + return SDValue(); } @@ -37360,6 +39478,8 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N, SelectionDAG &DAG = DCI.DAG; switch (N->getOpcode()) { default: break; + case ISD::SCALAR_TO_VECTOR: + return combineScalarToVector(N, DAG); case ISD::EXTRACT_VECTOR_ELT: case X86ISD::PEXTRW: case X86ISD::PEXTRB: @@ -37384,6 +39504,7 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N, case ISD::AND: return combineAnd(N, DAG, DCI, Subtarget); case ISD::OR: return combineOr(N, DAG, DCI, Subtarget); case ISD::XOR: return combineXor(N, DAG, DCI, Subtarget); + case X86ISD::BEXTR: return combineBEXTR(N, DAG, DCI, Subtarget); case ISD::LOAD: return combineLoad(N, DAG, DCI, Subtarget); case ISD::MLOAD: return combineMaskedLoad(N, DAG, DCI, Subtarget); case ISD::STORE: return combineStore(N, DAG, Subtarget); @@ -37449,20 +39570,21 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N, case X86ISD::VPERMI: case X86ISD::VPERMV: case X86ISD::VPERMV3: - case X86ISD::VPERMIV3: case X86ISD::VPERMIL2: case X86ISD::VPERMILPI: case X86ISD::VPERMILPV: case X86ISD::VPERM2X128: + case X86ISD::SHUF128: case X86ISD::VZEXT_MOVL: case ISD::VECTOR_SHUFFLE: return combineShuffle(N, DAG, DCI,Subtarget); case X86ISD::FMADD_RND: - case X86ISD::FMADDS1_RND: - case X86ISD::FMADDS3_RND: - case X86ISD::FMADDS1: - case X86ISD::FMADDS3: - case X86ISD::FMADD4S: - case ISD::FMA: return combineFMA(N, DAG, Subtarget); + case X86ISD::FMSUB: + case X86ISD::FMSUB_RND: + case X86ISD::FNMADD: + case X86ISD::FNMADD_RND: + case X86ISD::FNMSUB: + case X86ISD::FNMSUB_RND: + case ISD::FMA: return combineFMA(N, DAG, Subtarget); case X86ISD::FMADDSUB_RND: case X86ISD::FMSUBADD_RND: case X86ISD::FMADDSUB: @@ -37472,9 +39594,10 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N, case X86ISD::MSCATTER: case ISD::MGATHER: case ISD::MSCATTER: return combineGatherScatter(N, DAG, DCI, Subtarget); - case X86ISD::TESTM: 
return combineTestM(N, DAG, Subtarget); case X86ISD::PCMPEQ: case X86ISD::PCMPGT: return combineVectorCompare(N, DAG, Subtarget); + case X86ISD::PMULDQ: + case X86ISD::PMULUDQ: return combinePMULDQ(N, DAG, DCI); } return SDValue(); @@ -37487,6 +39610,11 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N, bool X86TargetLowering::isTypeDesirableForOp(unsigned Opc, EVT VT) const { if (!isTypeLegal(VT)) return false; + + // There are no vXi8 shifts. + if (Opc == ISD::SHL && VT.isVector() && VT.getVectorElementType() == MVT::i8) + return false; + if (VT != MVT::i16) return true; @@ -37509,23 +39637,20 @@ bool X86TargetLowering::isTypeDesirableForOp(unsigned Opc, EVT VT) const { } } -/// This function checks if any of the users of EFLAGS copies the EFLAGS. We -/// know that the code that lowers COPY of EFLAGS has to use the stack, and if -/// we don't adjust the stack we clobber the first frame index. -/// See X86InstrInfo::copyPhysReg. -static bool hasCopyImplyingStackAdjustment(const MachineFunction &MF) { - const MachineRegisterInfo &MRI = MF.getRegInfo(); - return any_of(MRI.reg_instructions(X86::EFLAGS), - [](const MachineInstr &RI) { return RI.isCopy(); }); -} - -void X86TargetLowering::finalizeLowering(MachineFunction &MF) const { - if (hasCopyImplyingStackAdjustment(MF)) { - MachineFrameInfo &MFI = MF.getFrameInfo(); - MFI.setHasCopyImplyingStackAdjustment(true); +SDValue X86TargetLowering::expandIndirectJTBranch(const SDLoc& dl, + SDValue Value, SDValue Addr, + SelectionDAG &DAG) const { + const Module *M = DAG.getMachineFunction().getMMI().getModule(); + Metadata *IsCFProtectionSupported = M->getModuleFlag("cf-protection-branch"); + if (IsCFProtectionSupported) { + // In case control-flow branch protection is enabled, we need to add + // notrack prefix to the indirect branch. + // In order to do that we create NT_BRIND SDNode. + // Upon ISEL, the pattern will convert it to jmp with NoTrack prefix. + return DAG.getNode(X86ISD::NT_BRIND, dl, MVT::Other, Value, Addr); } - TargetLoweringBase::finalizeLowering(MF); + return TargetLowering::expandIndirectJTBranch(dl, Value, Addr, DAG); } /// This method query the target whether it is beneficial for dag combiner to @@ -37536,22 +39661,30 @@ bool X86TargetLowering::IsDesirableToPromoteOp(SDValue Op, EVT &PVT) const { if (VT != MVT::i16) return false; - bool Promote = false; + auto IsFoldableRMW = [](SDValue Load, SDValue Op) { + if (!Op.hasOneUse()) + return false; + SDNode *User = *Op->use_begin(); + if (!ISD::isNormalStore(User)) + return false; + auto *Ld = cast<LoadSDNode>(Load); + auto *St = cast<StoreSDNode>(User); + return Ld->getBasePtr() == St->getBasePtr(); + }; + bool Commute = false; switch (Op.getOpcode()) { - default: break; + default: return false; case ISD::SIGN_EXTEND: case ISD::ZERO_EXTEND: case ISD::ANY_EXTEND: - Promote = true; break; case ISD::SHL: case ISD::SRL: { SDValue N0 = Op.getOperand(0); // Look out for (store (shl (load), x)). - if (MayFoldLoad(N0) && MayFoldIntoStore(Op)) + if (MayFoldLoad(N0) && IsFoldableRMW(N0, Op)) return false; - Promote = true; break; } case ISD::ADD: @@ -37564,19 +39697,20 @@ bool X86TargetLowering::IsDesirableToPromoteOp(SDValue Op, EVT &PVT) const { case ISD::SUB: { SDValue N0 = Op.getOperand(0); SDValue N1 = Op.getOperand(1); - if (!Commute && MayFoldLoad(N1)) - return false; // Avoid disabling potential load folding opportunities. 
- if (MayFoldLoad(N0) && (!isa<ConstantSDNode>(N1) || MayFoldIntoStore(Op))) + if (MayFoldLoad(N1) && + (!Commute || !isa<ConstantSDNode>(N0) || + (Op.getOpcode() != ISD::MUL && IsFoldableRMW(N1, Op)))) return false; - if (MayFoldLoad(N1) && (!isa<ConstantSDNode>(N0) || MayFoldIntoStore(Op))) + if (MayFoldLoad(N0) && + ((Commute && !isa<ConstantSDNode>(N1)) || + (Op.getOpcode() != ISD::MUL && IsFoldableRMW(N0, Op)))) return false; - Promote = true; } } PVT = MVT::i32; - return Promote; + return true; } bool X86TargetLowering:: @@ -37862,7 +39996,7 @@ TargetLowering::ConstraintWeight LLVM_FALLTHROUGH; case 'x': if (((type->getPrimitiveSizeInBits() == 128) && Subtarget.hasSSE1()) || - ((type->getPrimitiveSizeInBits() == 256) && Subtarget.hasFp256())) + ((type->getPrimitiveSizeInBits() == 256) && Subtarget.hasAVX())) weight = CW_Register; break; case 'k': @@ -38353,6 +40487,25 @@ X86TargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI, return Res; } + // Make sure it isn't a register that requires 64-bit mode. + if (!Subtarget.is64Bit() && + (isFRClass(*Res.second) || isGRClass(*Res.second)) && + TRI->getEncodingValue(Res.first) >= 8) { + // Register requires REX prefix, but we're in 32-bit mode. + Res.first = 0; + Res.second = nullptr; + return Res; + } + + // Make sure it isn't a register that requires AVX512. + if (!Subtarget.hasAVX512() && isFRClass(*Res.second) && + TRI->getEncodingValue(Res.first) & 0x10) { + // Register requires EVEX prefix. + Res.first = 0; + Res.second = nullptr; + return Res; + } + // Otherwise, check to see if this is a register class of the wrong value // type. For example, we want to map "{ax},i32" -> {eax}, we don't want it to // turn into {ax},{dx}. @@ -38421,7 +40574,7 @@ int X86TargetLowering::getScalingFactorCost(const DataLayout &DL, // will take 2 allocations in the out of order engine instead of 1 // for plain addressing mode, i.e. inst (reg1). // E.g., - // vaddps (%rsi,%drx), %ymm0, %ymm1 + // vaddps (%rsi,%rdx), %ymm0, %ymm1 // Requires two allocations (one for the load, one for the computation) // whereas: // vaddps (%rsi), %ymm0, %ymm1 @@ -38516,7 +40669,8 @@ StringRef X86TargetLowering::getStackProbeSymbolName(MachineFunction &MF) const // Generally, if we aren't on Windows, the platform ABI does not include // support for stack probes, so don't emit them. - if (!Subtarget.isOSWindows() || Subtarget.isTargetMachO()) + if (!Subtarget.isOSWindows() || Subtarget.isTargetMachO() || + MF.getFunction().hasFnAttribute("no-stack-arg-probe")) return ""; // We need a stack probe to conform to the Windows ABI. Choose the right |
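
A note on the SUBUS and HSUB hunks above: instead of emitting one wide X86ISD node directly, the combines now hand a small builder callback (SUBUSBuilder, HSUBBuilder) to SplitOpsAndApply, which breaks operands that are wider than the subtarget's legal vector width into legal-width pieces and invokes the callback once per piece. The standalone C++ sketch below illustrates only that split-and-apply idea, not the LLVM API; the names splitOpsAndApply and SubUsBuilder and the fixed four-lane "legal width" are hypothetical, chosen purely for illustration.

    #include <cassert>
    #include <cstddef>
    #include <iostream>
    #include <vector>

    // Hypothetical stand-in for the legal vector width (e.g. 128 bits = 4 x i32).
    constexpr std::size_t kLegalLanes = 4;

    // Conceptual analogue of SplitOpsAndApply: chop both operands into
    // legal-width pieces, let the builder callback create one operation per
    // piece, then concatenate the per-piece results.
    template <typename BuilderTy>
    std::vector<int> splitOpsAndApply(const std::vector<int> &LHS,
                                      const std::vector<int> &RHS,
                                      BuilderTy Builder) {
      assert(LHS.size() == RHS.size() && LHS.size() % kLegalLanes == 0);
      std::vector<int> Result;
      for (std::size_t I = 0; I < LHS.size(); I += kLegalLanes) {
        std::vector<int> L(LHS.begin() + I, LHS.begin() + I + kLegalLanes);
        std::vector<int> R(RHS.begin() + I, RHS.begin() + I + kLegalLanes);
        std::vector<int> Piece = Builder(L, R); // one "node" per legal piece
        Result.insert(Result.end(), Piece.begin(), Piece.end());
      }
      return Result;
    }

    int main() {
      // Plays the role of SUBUSBuilder: it only knows how to build a single
      // legal-width unsigned saturating subtract.
      auto SubUsBuilder = [](const std::vector<int> &L,
                             const std::vector<int> &R) {
        std::vector<int> Out(L.size());
        for (std::size_t I = 0; I < L.size(); ++I)
          Out[I] = L[I] > R[I] ? L[I] - R[I] : 0;
        return Out;
      };

      std::vector<int> A = {5, 1, 9, 3, 8, 2, 7, 0}; // a "v8i32" operand
      std::vector<int> B = {2, 4, 1, 3, 9, 1, 7, 5};
      for (int V : splitOpsAndApply(A, B, SubUsBuilder))
        std::cout << V << ' ';                       // prints: 3 0 8 0 0 1 0 0
      std::cout << '\n';
      return 0;
    }

The helper in the file above additionally consults the subtarget, so the split width it chooses is presumably the widest one the ISA actually supports; the sketch hard-codes four lanes only to keep the example short.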
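
The new combinePMULDQ relies on an instruction property rather than on DAG structure: PMULDQ and PMULUDQ read only the low 32 bits of each 64-bit lane, so whatever feeds the high halves of the operands is dead, and SimplifyDemandedBits with the APInt::getLowBitsSet(64, 32) mask is free to rewrite it. A minimal scalar sketch of that property for the unsigned case, using ordinary C++ arithmetic rather than the DAG combine itself:

    #include <cassert>
    #include <cstdint>

    // PMULUDQ semantics on one 64-bit lane: multiply the low 32 bits of each
    // operand as unsigned integers, producing a full 64-bit product.
    uint64_t pmuludqLane(uint64_t A, uint64_t B) {
      return static_cast<uint64_t>(static_cast<uint32_t>(A)) *
             static_cast<uint64_t>(static_cast<uint32_t>(B));
    }

    int main() {
      uint64_t A = 0xDEADBEEF00000007ULL; // garbage in the high 32 bits
      uint64_t B = 0xCAFEBABE00000003ULL;
      // Clearing the high halves cannot change the result, which is exactly
      // what the demanded-bits mask in the combine above encodes.
      assert(pmuludqLane(A, B) ==
             pmuludqLane(A & 0xFFFFFFFFULL, B & 0xFFFFFFFFULL));
      assert(pmuludqLane(A, B) == 21);
      return 0;
    }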
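
The getRegForInlineAsmConstraint hunk rejects explicitly named registers that the current target cannot encode: hardware encodings 8 through 15 need a REX prefix, which does not exist in 32-bit mode, and encodings 16 and up need an EVEX prefix, which requires AVX-512. A small sketch of those two predicates with a few illustrative encodings; the helper names are made up, and only the numeric tests mirror the checks in the diff:

    #include <cassert>

    // Hardware register encodings: xmm0..xmm7 -> 0..7, xmm8..xmm15 -> 8..15,
    // xmm16..xmm31 -> 16..31 (r8..r15 likewise use encodings 8..15).
    bool needsRexPrefix(unsigned Encoding) { return Encoding >= 8; }
    bool needsEvexPrefix(unsigned Encoding) { return (Encoding & 0x10) != 0; }

    int main() {
      assert(!needsRexPrefix(5));   // xmm5: usable everywhere
      assert(needsRexPrefix(8));    // xmm8/r8: rejected in 32-bit mode
      assert(!needsEvexPrefix(15)); // xmm15: REX is enough
      assert(needsEvexPrefix(16));  // xmm16: rejected without AVX-512
      return 0;
    }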