Diffstat (limited to 'lib/Target/X86/X86ISelLowering.cpp')
-rw-r--r--  lib/Target/X86/X86ISelLowering.cpp  8626
1 file changed, 5390 insertions, 3236 deletions
diff --git a/lib/Target/X86/X86ISelLowering.cpp b/lib/Target/X86/X86ISelLowering.cpp
index 9edd799779c7..7dcdb7967058 100644
--- a/lib/Target/X86/X86ISelLowering.cpp
+++ b/lib/Target/X86/X86ISelLowering.cpp
@@ -103,7 +103,7 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
bool UseX87 = !Subtarget.useSoftFloat() && Subtarget.hasX87();
X86ScalarSSEf64 = Subtarget.hasSSE2();
X86ScalarSSEf32 = Subtarget.hasSSE1();
- MVT PtrVT = MVT::getIntegerVT(8 * TM.getPointerSize());
+ MVT PtrVT = MVT::getIntegerVT(TM.getPointerSizeInBits(0));
// Set up the TargetLowering object.
@@ -216,6 +216,8 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
// We have an algorithm for SSE2, and we turn this into a 64-bit
// FILD or VCVTUSI2SS/SD for other targets.
setOperationAction(ISD::UINT_TO_FP , MVT::i32 , Custom);
+ } else {
+ setOperationAction(ISD::UINT_TO_FP , MVT::i32 , Expand);
}
// Promote i1/i8 SINT_TO_FP to larger SINT_TO_FP's, as X86 doesn't have
@@ -235,7 +237,7 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
}
} else {
setOperationAction(ISD::SINT_TO_FP , MVT::i16 , Promote);
- setOperationAction(ISD::SINT_TO_FP , MVT::i32 , Promote);
+ setOperationAction(ISD::SINT_TO_FP , MVT::i32 , Expand);
}
// Promote i1/i8 FP_TO_SINT to larger FP_TO_SINTS's, as X86 doesn't have
@@ -611,7 +613,7 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
// Long double always uses X87, except f128 in MMX.
if (UseX87) {
if (Subtarget.is64Bit() && Subtarget.hasMMX()) {
- addRegisterClass(MVT::f128, &X86::FR128RegClass);
+ addRegisterClass(MVT::f128, &X86::VR128RegClass);
ValueTypeActions.setTypeAction(MVT::f128, TypeSoftenFloat);
setOperationAction(ISD::FABS , MVT::f128, Custom);
setOperationAction(ISD::FNEG , MVT::f128, Custom);
@@ -790,19 +792,33 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
setOperationAction(ISD::FABS, MVT::v2f64, Custom);
setOperationAction(ISD::FCOPYSIGN, MVT::v2f64, Custom);
- setOperationAction(ISD::SMAX, MVT::v8i16, Legal);
- setOperationAction(ISD::UMAX, MVT::v16i8, Legal);
- setOperationAction(ISD::SMIN, MVT::v8i16, Legal);
- setOperationAction(ISD::UMIN, MVT::v16i8, Legal);
+ for (auto VT : { MVT::v16i8, MVT::v8i16, MVT::v4i32, MVT::v2i64 }) {
+ setOperationAction(ISD::SMAX, VT, VT == MVT::v8i16 ? Legal : Custom);
+ setOperationAction(ISD::SMIN, VT, VT == MVT::v8i16 ? Legal : Custom);
+ setOperationAction(ISD::UMAX, VT, VT == MVT::v16i8 ? Legal : Custom);
+ setOperationAction(ISD::UMIN, VT, VT == MVT::v16i8 ? Legal : Custom);
+ }
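// A minimal standalone sketch (not part of the patch) of roughly what the
// Custom action above can expand to for a type SSE2 lacks a native min/max
// for, e.g. signed max on v4i32: compare, then blend the operands with the
// mask. v8i16 SMAX/SMIN and v16i8 UMAX/UMIN stay Legal because SSE2 has
// pmaxsw/pminsw and pmaxub/pminub.
#include <emmintrin.h>

static __m128i smax_v4i32_sse2(__m128i a, __m128i b) {
  __m128i gt = _mm_cmpgt_epi32(a, b);            // lanewise a > b mask
  return _mm_or_si128(_mm_and_si128(gt, a),      // keep a where a > b
                      _mm_andnot_si128(gt, b));  // keep b elsewhere
}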
setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::v8i16, Custom);
setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::v4i32, Custom);
setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::v4f32, Custom);
+ // Provide custom widening for v2f32 setcc. This is really for VLX, where
+ // the setcc result type is v2i1/v4i1 for v2f32/v4f32, leading to type
+ // legalization changing the result type to v4i1 during widening.
+ // It works fine for SSE2 and is probably faster, so there is no need to
+ // qualify it with VLX support.
+ setOperationAction(ISD::SETCC, MVT::v2i32, Custom);
+
for (auto VT : { MVT::v16i8, MVT::v8i16, MVT::v4i32, MVT::v2i64 }) {
setOperationAction(ISD::SETCC, VT, Custom);
setOperationAction(ISD::CTPOP, VT, Custom);
setOperationAction(ISD::CTTZ, VT, Custom);
+
+ // The condition codes aren't legal in SSE/AVX and under AVX512 we use
+ // setcc all the way to isel and prefer SETGT in some isel patterns.
+ setCondCodeAction(ISD::SETLT, VT, Custom);
+ setCondCodeAction(ISD::SETLE, VT, Custom);
}
for (auto VT : { MVT::v16i8, MVT::v8i16, MVT::v4i32 }) {
@@ -874,6 +890,8 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
setOperationAction(ISD::BITCAST, MVT::v2i32, Custom);
setOperationAction(ISD::BITCAST, MVT::v4i16, Custom);
setOperationAction(ISD::BITCAST, MVT::v8i8, Custom);
+ if (!Subtarget.hasAVX512())
+ setOperationAction(ISD::BITCAST, MVT::v16i1, Custom);
setOperationAction(ISD::SIGN_EXTEND_VECTOR_INREG, MVT::v2i64, Custom);
setOperationAction(ISD::SIGN_EXTEND_VECTOR_INREG, MVT::v4i32, Custom);
@@ -886,6 +904,10 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
setOperationAction(ISD::SHL, VT, Custom);
setOperationAction(ISD::SRA, VT, Custom);
}
+
+ setOperationAction(ISD::ROTL, MVT::v4i32, Custom);
+ setOperationAction(ISD::ROTL, MVT::v8i16, Custom);
+ setOperationAction(ISD::ROTL, MVT::v16i8, Custom);
}
if (!Subtarget.useSoftFloat() && Subtarget.hasSSSE3()) {
@@ -967,7 +989,7 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
setOperationAction(ISD::BITREVERSE, VT, Custom);
}
- if (!Subtarget.useSoftFloat() && Subtarget.hasFp256()) {
+ if (!Subtarget.useSoftFloat() && Subtarget.hasAVX()) {
bool HasInt256 = Subtarget.hasInt256();
addRegisterClass(MVT::v32i8, Subtarget.hasVLX() ? &X86::VR256XRegClass
@@ -996,13 +1018,16 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
// (fp_to_int:v8i16 (v8f32 ..)) requires the result type to be promoted
// even though v8i16 is a legal type.
- setOperationAction(ISD::FP_TO_SINT, MVT::v8i16, Promote);
- setOperationAction(ISD::FP_TO_UINT, MVT::v8i16, Promote);
+ setOperationPromotedToType(ISD::FP_TO_SINT, MVT::v8i16, MVT::v8i32);
+ setOperationPromotedToType(ISD::FP_TO_UINT, MVT::v8i16, MVT::v8i32);
setOperationAction(ISD::FP_TO_SINT, MVT::v8i32, Legal);
setOperationAction(ISD::SINT_TO_FP, MVT::v8i32, Legal);
setOperationAction(ISD::FP_ROUND, MVT::v4f32, Legal);
+ if (!Subtarget.hasAVX512())
+ setOperationAction(ISD::BITCAST, MVT::v32i1, Custom);
+
for (MVT VT : MVT::fp_vector_valuetypes())
setLoadExtAction(ISD::EXTLOAD, VT, MVT::v4f32, Legal);
@@ -1014,6 +1039,10 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
setOperationAction(ISD::SRA, VT, Custom);
}
+ setOperationAction(ISD::ROTL, MVT::v8i32, Custom);
+ setOperationAction(ISD::ROTL, MVT::v16i16, Custom);
+ setOperationAction(ISD::ROTL, MVT::v32i8, Custom);
+
setOperationAction(ISD::SELECT, MVT::v4f64, Custom);
setOperationAction(ISD::SELECT, MVT::v4i64, Custom);
setOperationAction(ISD::SELECT, MVT::v8f32, Custom);
@@ -1034,6 +1063,11 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
setOperationAction(ISD::CTPOP, VT, Custom);
setOperationAction(ISD::CTTZ, VT, Custom);
setOperationAction(ISD::CTLZ, VT, Custom);
+
+ // The condition codes aren't legal in SSE/AVX and under AVX512 we use
+ // setcc all the way to isel and prefer SETGT in some isel patterns.
+ setCondCodeAction(ISD::SETLT, VT, Custom);
+ setCondCodeAction(ISD::SETLE, VT, Custom);
}
if (Subtarget.hasAnyFMA()) {
@@ -1060,6 +1094,11 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
setOperationAction(ISD::MULHU, MVT::v32i8, Custom);
setOperationAction(ISD::MULHS, MVT::v32i8, Custom);
+ setOperationAction(ISD::SMAX, MVT::v4i64, Custom);
+ setOperationAction(ISD::UMAX, MVT::v4i64, Custom);
+ setOperationAction(ISD::SMIN, MVT::v4i64, Custom);
+ setOperationAction(ISD::UMIN, MVT::v4i64, Custom);
+
for (auto VT : { MVT::v32i8, MVT::v16i16, MVT::v8i32 }) {
setOperationAction(ISD::ABS, VT, HasInt256 ? Legal : Custom);
setOperationAction(ISD::SMAX, VT, HasInt256 ? Legal : Custom);
@@ -1137,13 +1176,13 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
}
}
+ // This block controls legalization of the mask vector sizes that are
+ // available with AVX512. 512-bit vectors are in a separate block controlled
+ // by useAVX512Regs.
if (!Subtarget.useSoftFloat() && Subtarget.hasAVX512()) {
- addRegisterClass(MVT::v16i32, &X86::VR512RegClass);
- addRegisterClass(MVT::v16f32, &X86::VR512RegClass);
- addRegisterClass(MVT::v8i64, &X86::VR512RegClass);
- addRegisterClass(MVT::v8f64, &X86::VR512RegClass);
-
addRegisterClass(MVT::v1i1, &X86::VK1RegClass);
+ addRegisterClass(MVT::v2i1, &X86::VK2RegClass);
+ addRegisterClass(MVT::v4i1, &X86::VK4RegClass);
addRegisterClass(MVT::v8i1, &X86::VK8RegClass);
addRegisterClass(MVT::v16i1, &X86::VK16RegClass);
@@ -1151,24 +1190,34 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::v1i1, Custom);
setOperationAction(ISD::BUILD_VECTOR, MVT::v1i1, Custom);
- setOperationAction(ISD::SINT_TO_FP, MVT::v16i1, Custom);
- setOperationAction(ISD::UINT_TO_FP, MVT::v16i1, Custom);
- setOperationAction(ISD::SINT_TO_FP, MVT::v8i1, Custom);
- setOperationAction(ISD::UINT_TO_FP, MVT::v8i1, Custom);
- setOperationAction(ISD::SINT_TO_FP, MVT::v4i1, Custom);
- setOperationAction(ISD::UINT_TO_FP, MVT::v4i1, Custom);
- setOperationAction(ISD::SINT_TO_FP, MVT::v2i1, Custom);
- setOperationAction(ISD::UINT_TO_FP, MVT::v2i1, Custom);
-
- // Extends of v16i1/v8i1 to 128-bit vectors.
- setOperationAction(ISD::SIGN_EXTEND, MVT::v16i8, Custom);
- setOperationAction(ISD::ZERO_EXTEND, MVT::v16i8, Custom);
- setOperationAction(ISD::ANY_EXTEND, MVT::v16i8, Custom);
- setOperationAction(ISD::SIGN_EXTEND, MVT::v8i16, Custom);
- setOperationAction(ISD::ZERO_EXTEND, MVT::v8i16, Custom);
- setOperationAction(ISD::ANY_EXTEND, MVT::v8i16, Custom);
-
- for (auto VT : { MVT::v8i1, MVT::v16i1 }) {
+ setOperationPromotedToType(ISD::FP_TO_SINT, MVT::v8i1, MVT::v8i32);
+ setOperationPromotedToType(ISD::FP_TO_UINT, MVT::v8i1, MVT::v8i32);
+ setOperationPromotedToType(ISD::FP_TO_SINT, MVT::v4i1, MVT::v4i32);
+ setOperationPromotedToType(ISD::FP_TO_UINT, MVT::v4i1, MVT::v4i32);
+ setOperationAction(ISD::FP_TO_SINT, MVT::v2i1, Custom);
+ setOperationAction(ISD::FP_TO_UINT, MVT::v2i1, Custom);
+
+ // There is no byte sized k-register load or store without AVX512DQ.
+ if (!Subtarget.hasDQI()) {
+ setOperationAction(ISD::LOAD, MVT::v1i1, Custom);
+ setOperationAction(ISD::LOAD, MVT::v2i1, Custom);
+ setOperationAction(ISD::LOAD, MVT::v4i1, Custom);
+ setOperationAction(ISD::LOAD, MVT::v8i1, Custom);
+
+ setOperationAction(ISD::STORE, MVT::v1i1, Custom);
+ setOperationAction(ISD::STORE, MVT::v2i1, Custom);
+ setOperationAction(ISD::STORE, MVT::v4i1, Custom);
+ setOperationAction(ISD::STORE, MVT::v8i1, Custom);
+ }
+
+ // Extends of v16i1/v8i1/v4i1/v2i1 to 128-bit vectors.
+ for (auto VT : { MVT::v16i8, MVT::v8i16, MVT::v4i32, MVT::v2i64 }) {
+ setOperationAction(ISD::SIGN_EXTEND, VT, Custom);
+ setOperationAction(ISD::ZERO_EXTEND, VT, Custom);
+ setOperationAction(ISD::ANY_EXTEND, VT, Custom);
+ }
+
+ for (auto VT : { MVT::v2i1, MVT::v4i1, MVT::v8i1, MVT::v16i1 }) {
setOperationAction(ISD::ADD, VT, Custom);
setOperationAction(ISD::SUB, VT, Custom);
setOperationAction(ISD::MUL, VT, Custom);
@@ -1184,11 +1233,24 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
}
setOperationAction(ISD::CONCAT_VECTORS, MVT::v16i1, Custom);
+ setOperationAction(ISD::CONCAT_VECTORS, MVT::v8i1, Custom);
+ setOperationAction(ISD::CONCAT_VECTORS, MVT::v4i1, Custom);
+ setOperationAction(ISD::INSERT_SUBVECTOR, MVT::v2i1, Custom);
+ setOperationAction(ISD::INSERT_SUBVECTOR, MVT::v4i1, Custom);
setOperationAction(ISD::INSERT_SUBVECTOR, MVT::v8i1, Custom);
setOperationAction(ISD::INSERT_SUBVECTOR, MVT::v16i1, Custom);
- for (auto VT : { MVT::v1i1, MVT::v2i1, MVT::v4i1, MVT::v8i1,
- MVT::v16i1, MVT::v32i1, MVT::v64i1 })
- setOperationAction(ISD::EXTRACT_SUBVECTOR, VT, Legal);
+ for (auto VT : { MVT::v1i1, MVT::v2i1, MVT::v4i1, MVT::v8i1 })
+ setOperationAction(ISD::EXTRACT_SUBVECTOR, VT, Custom);
+ }
+
+ // This block controls legalization for 512-bit operations with 32/64-bit
+ // elements. 512-bit support can be disabled based on the prefer-vector-width
+ // and required-vector-width function attributes.
+ if (!Subtarget.useSoftFloat() && Subtarget.useAVX512Regs()) {
+ addRegisterClass(MVT::v16i32, &X86::VR512RegClass);
+ addRegisterClass(MVT::v16f32, &X86::VR512RegClass);
+ addRegisterClass(MVT::v8i64, &X86::VR512RegClass);
+ addRegisterClass(MVT::v8f64, &X86::VR512RegClass);
for (MVT VT : MVT::fp_vector_valuetypes())
setLoadExtAction(ISD::EXTLOAD, VT, MVT::v8f32, Legal);
@@ -1201,16 +1263,6 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
setLoadExtAction(ExtType, MVT::v8i64, MVT::v8i32, Legal);
}
- for (MVT VT : {MVT::v2i64, MVT::v4i32, MVT::v8i32, MVT::v4i64, MVT::v8i16,
- MVT::v16i8, MVT::v16i16, MVT::v32i8, MVT::v16i32,
- MVT::v8i64, MVT::v32i16, MVT::v64i8}) {
- MVT MaskVT = MVT::getVectorVT(MVT::i1, VT.getVectorNumElements());
- setLoadExtAction(ISD::SEXTLOAD, VT, MaskVT, Custom);
- setLoadExtAction(ISD::ZEXTLOAD, VT, MaskVT, Custom);
- setLoadExtAction(ISD::EXTLOAD, VT, MaskVT, Custom);
- setTruncStoreAction(VT, MaskVT, Custom);
- }
-
for (MVT VT : { MVT::v16f32, MVT::v8f64 }) {
setOperationAction(ISD::FNEG, VT, Custom);
setOperationAction(ISD::FABS, VT, Custom);
@@ -1219,11 +1271,13 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
}
setOperationAction(ISD::FP_TO_SINT, MVT::v16i32, Legal);
- setOperationAction(ISD::FP_TO_SINT, MVT::v16i16, Promote);
- setOperationAction(ISD::FP_TO_SINT, MVT::v16i8, Promote);
+ setOperationPromotedToType(ISD::FP_TO_SINT, MVT::v16i16, MVT::v16i32);
+ setOperationPromotedToType(ISD::FP_TO_SINT, MVT::v16i8, MVT::v16i32);
+ setOperationPromotedToType(ISD::FP_TO_SINT, MVT::v16i1, MVT::v16i32);
setOperationAction(ISD::FP_TO_UINT, MVT::v16i32, Legal);
- setOperationAction(ISD::FP_TO_UINT, MVT::v16i8, Promote);
- setOperationAction(ISD::FP_TO_UINT, MVT::v16i16, Promote);
+ setOperationPromotedToType(ISD::FP_TO_UINT, MVT::v16i1, MVT::v16i32);
+ setOperationPromotedToType(ISD::FP_TO_UINT, MVT::v16i8, MVT::v16i32);
+ setOperationPromotedToType(ISD::FP_TO_UINT, MVT::v16i16, MVT::v16i32);
setOperationAction(ISD::SINT_TO_FP, MVT::v16i32, Legal);
setOperationAction(ISD::UINT_TO_FP, MVT::v16i32, Legal);
@@ -1296,6 +1350,12 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
setOperationAction(ISD::CTTZ, VT, Custom);
setOperationAction(ISD::ROTL, VT, Custom);
setOperationAction(ISD::ROTR, VT, Custom);
+ setOperationAction(ISD::SETCC, VT, Custom);
+
+ // The condition codes aren't legal in SSE/AVX and under AVX512 we use
+ // setcc all the way to isel and prefer SETGT in some isel patterns.
+ setCondCodeAction(ISD::SETLT, VT, Custom);
+ setCondCodeAction(ISD::SETLE, VT, Custom);
}
// Need to promote to 64-bit even though we have 32-bit masked instructions
@@ -1310,6 +1370,8 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
setOperationAction(ISD::UINT_TO_FP, MVT::v8i64, Legal);
setOperationAction(ISD::FP_TO_SINT, MVT::v8i64, Legal);
setOperationAction(ISD::FP_TO_UINT, MVT::v8i64, Legal);
+
+ setOperationAction(ISD::MUL, MVT::v8i64, Legal);
}
if (Subtarget.hasCDI()) {
@@ -1349,10 +1411,18 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
setOperationPromotedToType(ISD::LOAD, VT, MVT::v8i64);
setOperationPromotedToType(ISD::SELECT, VT, MVT::v8i64);
}
+
+ // Need to custom split v32i16/v64i8 bitcasts.
+ if (!Subtarget.hasBWI()) {
+ setOperationAction(ISD::BITCAST, MVT::v32i16, Custom);
+ setOperationAction(ISD::BITCAST, MVT::v64i8, Custom);
+ }
}// has AVX-512
- if (!Subtarget.useSoftFloat() &&
- (Subtarget.hasAVX512() || Subtarget.hasVLX())) {
+ // This block controls legalization for operations that don't have
+ // pre-AVX512 equivalents. Without VLX we use 512-bit operations for
+ // narrower widths.
+ if (!Subtarget.useSoftFloat() && Subtarget.hasAVX512()) {
// These operations are handled on non-VLX by artificially widening in
// isel patterns.
// TODO: Custom widen in lowering on non-VLX and drop the isel patterns?
@@ -1376,6 +1446,10 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
setOperationAction(ISD::ROTR, VT, Custom);
}
+ // Custom legalize 2x32 to get a little better code.
+ setOperationAction(ISD::MSCATTER, MVT::v2f32, Custom);
+ setOperationAction(ISD::MSCATTER, MVT::v2i32, Custom);
+
for (auto VT : { MVT::v4i32, MVT::v8i32, MVT::v2i64, MVT::v4i64,
MVT::v4f32, MVT::v8f32, MVT::v2f64, MVT::v4f64 })
setOperationAction(ISD::MSCATTER, VT, Custom);
@@ -1386,6 +1460,8 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
setOperationAction(ISD::UINT_TO_FP, VT, Legal);
setOperationAction(ISD::FP_TO_SINT, VT, Legal);
setOperationAction(ISD::FP_TO_UINT, VT, Legal);
+
+ setOperationAction(ISD::MUL, VT, Legal);
}
}
@@ -1402,10 +1478,10 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
}
}
+ // This block controls legalization of v32i1/v64i1, which are available with
+ // AVX512BW. 512-bit v32i16 and v64i8 vector legalization is controlled with
+ // useBWIRegs.
if (!Subtarget.useSoftFloat() && Subtarget.hasBWI()) {
- addRegisterClass(MVT::v32i16, &X86::VR512RegClass);
- addRegisterClass(MVT::v64i8, &X86::VR512RegClass);
-
addRegisterClass(MVT::v32i1, &X86::VK32RegClass);
addRegisterClass(MVT::v64i1, &X86::VK64RegClass);
@@ -1428,11 +1504,22 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
setOperationAction(ISD::CONCAT_VECTORS, MVT::v64i1, Custom);
setOperationAction(ISD::INSERT_SUBVECTOR, MVT::v32i1, Custom);
setOperationAction(ISD::INSERT_SUBVECTOR, MVT::v64i1, Custom);
+ for (auto VT : { MVT::v16i1, MVT::v32i1 })
+ setOperationAction(ISD::EXTRACT_SUBVECTOR, VT, Custom);
// Extends from v32i1 masks to 256-bit vectors.
setOperationAction(ISD::SIGN_EXTEND, MVT::v32i8, Custom);
setOperationAction(ISD::ZERO_EXTEND, MVT::v32i8, Custom);
setOperationAction(ISD::ANY_EXTEND, MVT::v32i8, Custom);
+ }
+
+ // This block controls legalization for v32i16 and v64i8. 512-bit support can
+ // be disabled based on the prefer-vector-width and required-vector-width
+ // function attributes.
+ if (!Subtarget.useSoftFloat() && Subtarget.useBWIRegs()) {
+ addRegisterClass(MVT::v32i16, &X86::VR512RegClass);
+ addRegisterClass(MVT::v64i8, &X86::VR512RegClass);
+
// Extends from v64i1 masks to 512-bit vectors.
setOperationAction(ISD::SIGN_EXTEND, MVT::v64i8, Custom);
setOperationAction(ISD::ZERO_EXTEND, MVT::v64i8, Custom);
@@ -1482,6 +1569,7 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
setOperationAction(ISD::UMAX, VT, Legal);
setOperationAction(ISD::SMIN, VT, Legal);
setOperationAction(ISD::UMIN, VT, Legal);
+ setOperationAction(ISD::SETCC, VT, Custom);
setOperationPromotedToType(ISD::AND, VT, MVT::v8i64);
setOperationPromotedToType(ISD::OR, VT, MVT::v8i64);
@@ -1498,8 +1586,7 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
}
}
- if (!Subtarget.useSoftFloat() && Subtarget.hasBWI() &&
- (Subtarget.hasAVX512() || Subtarget.hasVLX())) {
+ if (!Subtarget.useSoftFloat() && Subtarget.hasBWI()) {
for (auto VT : { MVT::v32i8, MVT::v16i8, MVT::v16i16, MVT::v8i16 }) {
setOperationAction(ISD::MLOAD, VT, Subtarget.hasVLX() ? Legal : Custom);
setOperationAction(ISD::MSTORE, VT, Subtarget.hasVLX() ? Legal : Custom);
@@ -1516,39 +1603,6 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
}
if (!Subtarget.useSoftFloat() && Subtarget.hasVLX()) {
- addRegisterClass(MVT::v4i1, &X86::VK4RegClass);
- addRegisterClass(MVT::v2i1, &X86::VK2RegClass);
-
- for (auto VT : { MVT::v2i1, MVT::v4i1 }) {
- setOperationAction(ISD::ADD, VT, Custom);
- setOperationAction(ISD::SUB, VT, Custom);
- setOperationAction(ISD::MUL, VT, Custom);
- setOperationAction(ISD::VSELECT, VT, Expand);
-
- setOperationAction(ISD::TRUNCATE, VT, Custom);
- setOperationAction(ISD::SETCC, VT, Custom);
- setOperationAction(ISD::EXTRACT_VECTOR_ELT, VT, Custom);
- setOperationAction(ISD::INSERT_VECTOR_ELT, VT, Custom);
- setOperationAction(ISD::SELECT, VT, Custom);
- setOperationAction(ISD::BUILD_VECTOR, VT, Custom);
- setOperationAction(ISD::VECTOR_SHUFFLE, VT, Custom);
- }
-
- // TODO: v8i1 concat should be legal without VLX to support concats of
- // v1i1, but we won't legalize it correctly currently without introducing
- // a v4i1 concat in the middle.
- setOperationAction(ISD::CONCAT_VECTORS, MVT::v8i1, Custom);
- setOperationAction(ISD::CONCAT_VECTORS, MVT::v4i1, Custom);
- setOperationAction(ISD::INSERT_SUBVECTOR, MVT::v4i1, Custom);
-
- // Extends from v2i1/v4i1 masks to 128-bit vectors.
- setOperationAction(ISD::ZERO_EXTEND, MVT::v4i32, Custom);
- setOperationAction(ISD::ZERO_EXTEND, MVT::v2i64, Custom);
- setOperationAction(ISD::SIGN_EXTEND, MVT::v4i32, Custom);
- setOperationAction(ISD::SIGN_EXTEND, MVT::v2i64, Custom);
- setOperationAction(ISD::ANY_EXTEND, MVT::v4i32, Custom);
- setOperationAction(ISD::ANY_EXTEND, MVT::v2i64, Custom);
-
setTruncStoreAction(MVT::v4i64, MVT::v4i8, Legal);
setTruncStoreAction(MVT::v4i64, MVT::v4i16, Legal);
setTruncStoreAction(MVT::v4i64, MVT::v4i32, Legal);
@@ -1648,6 +1702,7 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
// We have target-specific dag combine patterns for the following nodes:
setTargetDAGCombine(ISD::VECTOR_SHUFFLE);
+ setTargetDAGCombine(ISD::SCALAR_TO_VECTOR);
setTargetDAGCombine(ISD::EXTRACT_VECTOR_ELT);
setTargetDAGCombine(ISD::INSERT_SUBVECTOR);
setTargetDAGCombine(ISD::EXTRACT_SUBVECTOR);
@@ -1733,6 +1788,9 @@ SDValue X86TargetLowering::emitStackGuardXorFP(SelectionDAG &DAG, SDValue Val,
TargetLoweringBase::LegalizeTypeAction
X86TargetLowering::getPreferredVectorAction(EVT VT) const {
+ if (VT == MVT::v32i1 && Subtarget.hasAVX512() && !Subtarget.hasBWI())
+ return TypeSplitVector;
+
if (ExperimentalVectorWideningLegalization &&
VT.getVectorNumElements() != 1 &&
VT.getVectorElementType().getSimpleVT() != MVT::i1)
@@ -1741,6 +1799,20 @@ X86TargetLowering::getPreferredVectorAction(EVT VT) const {
return TargetLoweringBase::getPreferredVectorAction(VT);
}
+MVT X86TargetLowering::getRegisterTypeForCallingConv(LLVMContext &Context,
+ EVT VT) const {
+ if (VT == MVT::v32i1 && Subtarget.hasAVX512() && !Subtarget.hasBWI())
+ return MVT::v32i8;
+ return TargetLowering::getRegisterTypeForCallingConv(Context, VT);
+}
+
+unsigned X86TargetLowering::getNumRegistersForCallingConv(LLVMContext &Context,
+ EVT VT) const {
+ if (VT == MVT::v32i1 && Subtarget.hasAVX512() && !Subtarget.hasBWI())
+ return 1;
+ return TargetLowering::getNumRegistersForCallingConv(Context, VT);
+}
+
EVT X86TargetLowering::getSetCCResultType(const DataLayout &DL,
LLVMContext& Context,
EVT VT) const {
@@ -1937,7 +2009,7 @@ void X86TargetLowering::markLibCallAttributes(MachineFunction *MF, unsigned CC,
// Mark the first N int arguments as having reg
for (unsigned Idx = 0; Idx < Args.size(); Idx++) {
Type *T = Args[Idx].Ty;
- if (T->isPointerTy() || T->isIntegerTy())
+ if (T->isIntOrPtrTy())
if (MF->getDataLayout().getTypeAllocSize(T) <= 8) {
unsigned numRegs = 1;
if (MF->getDataLayout().getTypeAllocSize(T) > 4)
@@ -2051,7 +2123,8 @@ Value *X86TargetLowering::getIRStackGuard(IRBuilder<> &IRB) const {
void X86TargetLowering::insertSSPDeclarations(Module &M) const {
// MSVC CRT provides functionalities for stack protection.
- if (Subtarget.getTargetTriple().isOSMSVCRT()) {
+ if (Subtarget.getTargetTriple().isWindowsMSVCEnvironment() ||
+ Subtarget.getTargetTriple().isWindowsItaniumEnvironment()) {
// MSVC CRT has a global variable holding security cookie.
M.getOrInsertGlobal("__security_cookie",
Type::getInt8PtrTy(M.getContext()));
@@ -2073,15 +2146,19 @@ void X86TargetLowering::insertSSPDeclarations(Module &M) const {
Value *X86TargetLowering::getSDagStackGuard(const Module &M) const {
// MSVC CRT has a global variable holding security cookie.
- if (Subtarget.getTargetTriple().isOSMSVCRT())
+ if (Subtarget.getTargetTriple().isWindowsMSVCEnvironment() ||
+ Subtarget.getTargetTriple().isWindowsItaniumEnvironment()) {
return M.getGlobalVariable("__security_cookie");
+ }
return TargetLowering::getSDagStackGuard(M);
}
Value *X86TargetLowering::getSSPStackGuardCheck(const Module &M) const {
// MSVC CRT has a function to validate security cookie.
- if (Subtarget.getTargetTriple().isOSMSVCRT())
+ if (Subtarget.getTargetTriple().isWindowsMSVCEnvironment() ||
+ Subtarget.getTargetTriple().isWindowsItaniumEnvironment()) {
return M.getFunction("__security_check_cookie");
+ }
return TargetLowering::getSSPStackGuardCheck(M);
}
@@ -2140,6 +2217,10 @@ static SDValue lowerMasksToReg(const SDValue &ValArg, const EVT &ValLoc,
const SDLoc &Dl, SelectionDAG &DAG) {
EVT ValVT = ValArg.getValueType();
+ if (ValVT == MVT::v1i1)
+ return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, Dl, ValLoc, ValArg,
+ DAG.getIntPtrConstant(0, Dl));
+
if ((ValVT == MVT::v8i1 && (ValLoc == MVT::i8 || ValLoc == MVT::i32)) ||
(ValVT == MVT::v16i1 && (ValLoc == MVT::i16 || ValLoc == MVT::i32))) {
// Two stage lowering might be required
@@ -2150,13 +2231,16 @@ static SDValue lowerMasksToReg(const SDValue &ValArg, const EVT &ValLoc,
if (ValLoc == MVT::i32)
ValToCopy = DAG.getNode(ISD::ANY_EXTEND, Dl, ValLoc, ValToCopy);
return ValToCopy;
- } else if ((ValVT == MVT::v32i1 && ValLoc == MVT::i32) ||
- (ValVT == MVT::v64i1 && ValLoc == MVT::i64)) {
+ }
+
+ if ((ValVT == MVT::v32i1 && ValLoc == MVT::i32) ||
+ (ValVT == MVT::v64i1 && ValLoc == MVT::i64)) {
// One stage lowering is required
// bitcast: v32i1 -> i32 / v64i1 -> i64
return DAG.getBitcast(ValLoc, ValArg);
- } else
- return DAG.getNode(ISD::SIGN_EXTEND, Dl, ValLoc, ValArg);
+ }
+
+ return DAG.getNode(ISD::ANY_EXTEND, Dl, ValLoc, ValArg);
}
/// Breaks v64i1 value into two registers and adds the new node to the DAG
@@ -2474,10 +2558,10 @@ static SDValue getv64i1Argument(CCValAssign &VA, CCValAssign &NextVA,
MachineFunction &MF = DAG.getMachineFunction();
const TargetRegisterClass *RC = &X86::GR32RegClass;
- // Read a 32 bit value from the registers
+ // Read a 32 bit value from the registers.
if (nullptr == InFlag) {
// When no physical register is present,
- // create an intermediate virtual register
+ // create an intermediate virtual register.
Reg = MF.addLiveIn(VA.getLocReg(), RC);
ArgValueLo = DAG.getCopyFromReg(Root, Dl, Reg, MVT::i32);
Reg = MF.addLiveIn(NextVA.getLocReg(), RC);
@@ -2493,13 +2577,13 @@ static SDValue getv64i1Argument(CCValAssign &VA, CCValAssign &NextVA,
*InFlag = ArgValueHi.getValue(2);
}
- // Convert the i32 type into v32i1 type
+ // Convert the i32 type into v32i1 type.
Lo = DAG.getBitcast(MVT::v32i1, ArgValueLo);
- // Convert the i32 type into v32i1 type
+ // Convert the i32 type into v32i1 type.
Hi = DAG.getBitcast(MVT::v32i1, ArgValueHi);
- // Concatenate the two values together
+ // Concatenate the two values together.
return DAG.getNode(ISD::CONCAT_VECTORS, Dl, MVT::v64i1, Lo, Hi);
}
@@ -2640,7 +2724,7 @@ enum StructReturnType {
StackStructReturn
};
static StructReturnType
-callIsStructReturn(const SmallVectorImpl<ISD::OutputArg> &Outs, bool IsMCU) {
+callIsStructReturn(ArrayRef<ISD::OutputArg> Outs, bool IsMCU) {
if (Outs.empty())
return NotStructReturn;
@@ -2654,7 +2738,7 @@ callIsStructReturn(const SmallVectorImpl<ISD::OutputArg> &Outs, bool IsMCU) {
/// Determines whether a function uses struct return semantics.
static StructReturnType
-argsAreStructReturn(const SmallVectorImpl<ISD::InputArg> &Ins, bool IsMCU) {
+argsAreStructReturn(ArrayRef<ISD::InputArg> Ins, bool IsMCU) {
if (Ins.empty())
return NotStructReturn;
@@ -2774,7 +2858,11 @@ X86TargetLowering::LowerMemArgument(SDValue Chain, CallingConv::ID CallConv,
if (Flags.isByVal()) {
unsigned Bytes = Flags.getByValSize();
if (Bytes == 0) Bytes = 1; // Don't create zero-sized stack objects.
- int FI = MFI.CreateFixedObject(Bytes, VA.getLocMemOffset(), isImmutable);
+
+ // FIXME: For now, all byval parameter objects are marked as aliasing. This
+ // can be improved with deeper analysis.
+ int FI = MFI.CreateFixedObject(Bytes, VA.getLocMemOffset(), isImmutable,
+ /*isAliased=*/true);
// Adjust SP offset of interrupt parameter.
if (CallConv == CallingConv::X86_INTR) {
MFI.setObjectOffset(FI, Offset);
@@ -2898,7 +2986,7 @@ static ArrayRef<MCPhysReg> get64BitArgumentXMMs(MachineFunction &MF,
}
#ifndef NDEBUG
-static bool isSortedByValueNo(const SmallVectorImpl<CCValAssign> &ArgLocs) {
+static bool isSortedByValueNo(ArrayRef<CCValAssign> ArgLocs) {
return std::is_sorted(ArgLocs.begin(), ArgLocs.end(),
[](const CCValAssign &A, const CCValAssign &B) -> bool {
return A.getValNo() < B.getValNo();
@@ -2975,7 +3063,11 @@ SDValue X86TargetLowering::LowerFormalArguments(
getv64i1Argument(VA, ArgLocs[++I], Chain, DAG, dl, Subtarget);
} else {
const TargetRegisterClass *RC;
- if (RegVT == MVT::i32)
+ if (RegVT == MVT::i8)
+ RC = &X86::GR8RegClass;
+ else if (RegVT == MVT::i16)
+ RC = &X86::GR16RegClass;
+ else if (RegVT == MVT::i32)
RC = &X86::GR32RegClass;
else if (Is64Bit && RegVT == MVT::i64)
RC = &X86::GR64RegClass;
@@ -2986,7 +3078,7 @@ SDValue X86TargetLowering::LowerFormalArguments(
else if (RegVT == MVT::f80)
RC = &X86::RFP80RegClass;
else if (RegVT == MVT::f128)
- RC = &X86::FR128RegClass;
+ RC = &X86::VR128RegClass;
else if (RegVT.is512BitVector())
RC = &X86::VR512RegClass;
else if (RegVT.is256BitVector())
@@ -3361,6 +3453,11 @@ X86TargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
const Function *Fn = CI ? CI->getCalledFunction() : nullptr;
bool HasNCSR = (CI && CI->hasFnAttr("no_caller_saved_registers")) ||
(Fn && Fn->hasFnAttribute("no_caller_saved_registers"));
+ const auto *II = dyn_cast_or_null<InvokeInst>(CLI.CS.getInstruction());
+ bool HasNoCfCheck =
+ (CI && CI->doesNoCfCheck()) || (II && II->doesNoCfCheck());
+ const Module *M = MF.getMMI().getModule();
+ Metadata *IsCFProtectionSupported = M->getModuleFlag("cf-protection-branch");
if (CallConv == CallingConv::X86_INTR)
report_fatal_error("X86 interrupts may not be called directly");
@@ -3743,6 +3840,14 @@ X86TargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
Callee = DAG.getTargetExternalSymbol(
S->getSymbol(), getPointerTy(DAG.getDataLayout()), OpFlags);
+
+ if (OpFlags == X86II::MO_GOTPCREL) {
+ Callee = DAG.getNode(X86ISD::WrapperRIP, dl,
+ getPointerTy(DAG.getDataLayout()), Callee);
+ Callee = DAG.getLoad(
+ getPointerTy(DAG.getDataLayout()), dl, DAG.getEntryNode(), Callee,
+ MachinePointerInfo::getGOT(DAG.getMachineFunction()));
+ }
} else if (Subtarget.isTarget64BitILP32() &&
Callee->getValueType(0) == MVT::i32) {
// Zero-extend the 32-bit Callee address into a 64-bit according to x32 ABI
@@ -3804,9 +3909,9 @@ X86TargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
const TargetRegisterInfo *TRI = Subtarget.getRegisterInfo();
// Allocate a new Reg Mask and copy Mask.
- RegMask = MF.allocateRegisterMask(TRI->getNumRegs());
- unsigned RegMaskSize = (TRI->getNumRegs() + 31) / 32;
- memcpy(RegMask, Mask, sizeof(uint32_t) * RegMaskSize);
+ RegMask = MF.allocateRegMask();
+ unsigned RegMaskSize = MachineOperand::getRegMaskSize(TRI->getNumRegs());
+ memcpy(RegMask, Mask, sizeof(RegMask[0]) * RegMaskSize);
// Make sure all sub registers of the argument registers are reset
// in the RegMask.
@@ -3836,7 +3941,11 @@ X86TargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
return DAG.getNode(X86ISD::TC_RETURN, dl, NodeTys, Ops);
}
- Chain = DAG.getNode(X86ISD::CALL, dl, NodeTys, Ops);
+ if (HasNoCfCheck && IsCFProtectionSupported) {
+ Chain = DAG.getNode(X86ISD::NT_CALL, dl, NodeTys, Ops);
+ } else {
+ Chain = DAG.getNode(X86ISD::CALL, dl, NodeTys, Ops);
+ }
InFlag = Chain.getValue(1);
// Create the CALLSEQ_END node.
@@ -4260,8 +4369,6 @@ static bool isTargetShuffle(unsigned Opcode) {
case X86ISD::VSRLDQ:
case X86ISD::MOVLHPS:
case X86ISD::MOVHLPS:
- case X86ISD::MOVLPS:
- case X86ISD::MOVLPD:
case X86ISD::MOVSHDUP:
case X86ISD::MOVSLDUP:
case X86ISD::MOVDDUP:
@@ -4273,12 +4380,12 @@ static bool isTargetShuffle(unsigned Opcode) {
case X86ISD::VPERMILPI:
case X86ISD::VPERMILPV:
case X86ISD::VPERM2X128:
+ case X86ISD::SHUF128:
case X86ISD::VPERMIL2:
case X86ISD::VPERMI:
case X86ISD::VPPERM:
case X86ISD::VPERMV:
case X86ISD::VPERMV3:
- case X86ISD::VPERMIV3:
case X86ISD::VZEXT_MOVL:
return true;
}
@@ -4294,7 +4401,6 @@ static bool isTargetShuffleVariableMask(unsigned Opcode) {
case X86ISD::VPPERM:
case X86ISD::VPERMV:
case X86ISD::VPERMV3:
- case X86ISD::VPERMIV3:
return true;
// 'Faux' Target Shuffles.
case ISD::AND:
@@ -4371,7 +4477,7 @@ bool X86::isCalleePop(CallingConv::ID CallingConv,
}
}
-/// \brief Return true if the condition is an unsigned comparison operation.
+/// Return true if the condition is an unsigned comparison operation.
static bool isX86CCUnsigned(unsigned X86CC) {
switch (X86CC) {
default:
@@ -4518,20 +4624,6 @@ bool X86TargetLowering::getTgtMemIntrinsic(IntrinsicInfo &Info,
Info.offset = 0;
switch (IntrData->Type) {
- case EXPAND_FROM_MEM: {
- Info.ptrVal = I.getArgOperand(0);
- Info.memVT = MVT::getVT(I.getType());
- Info.align = 1;
- Info.flags |= MachineMemOperand::MOLoad;
- break;
- }
- case COMPRESS_TO_MEM: {
- Info.ptrVal = I.getArgOperand(0);
- Info.memVT = MVT::getVT(I.getArgOperand(1)->getType());
- Info.align = 1;
- Info.flags |= MachineMemOperand::MOStore;
- break;
- }
case TRUNCATE_TO_MEM_VI8:
case TRUNCATE_TO_MEM_VI16:
case TRUNCATE_TO_MEM_VI32: {
@@ -4580,7 +4672,7 @@ bool X86TargetLowering::shouldReduceLoadWidth(SDNode *Load,
return true;
}
-/// \brief Returns true if it is beneficial to convert a load of a constant
+/// Returns true if it is beneficial to convert a load of a constant
/// to just the constant itself.
bool X86TargetLowering::shouldConvertConstantLoadToIntImm(const APInt &Imm,
Type *Ty) const {
@@ -4625,6 +4717,14 @@ bool X86TargetLowering::isCheapToSpeculateCtlz() const {
return Subtarget.hasLZCNT();
}
+bool X86TargetLowering::isLoadBitCastBeneficial(EVT LoadVT,
+ EVT BitcastVT) const {
+ if (!Subtarget.hasDQI() && BitcastVT == MVT::v8i1)
+ return false;
+
+ return TargetLowering::isLoadBitCastBeneficial(LoadVT, BitcastVT);
+}
+
bool X86TargetLowering::canMergeStoresTo(unsigned AddressSpace, EVT MemVT,
const SelectionDAG &DAG) const {
// Do not merge to float value size (128 bytes) if no implicit
@@ -4649,14 +4749,52 @@ bool X86TargetLowering::isMaskAndCmp0FoldingBeneficial(
}
bool X86TargetLowering::hasAndNotCompare(SDValue Y) const {
+ EVT VT = Y.getValueType();
+
+ if (VT.isVector())
+ return false;
+
if (!Subtarget.hasBMI())
return false;
// There are only 32-bit and 64-bit forms for 'andn'.
- EVT VT = Y.getValueType();
if (VT != MVT::i32 && VT != MVT::i64)
return false;
+ // A mask and compare against a constant is OK for an 'andn' too,
+ // even though the BMI instruction doesn't have an immediate form.
+
+ return true;
+}
+
+bool X86TargetLowering::hasAndNot(SDValue Y) const {
+ EVT VT = Y.getValueType();
+
+ if (!VT.isVector()) // x86 can't form 'andn' with an immediate.
+ return !isa<ConstantSDNode>(Y) && hasAndNotCompare(Y);
+
+ // Vector.
+
+ if (!Subtarget.hasSSE1() || VT.getSizeInBits() < 128)
+ return false;
+
+ if (VT == MVT::v4i32)
+ return true;
+
+ return Subtarget.hasSSE2();
+}
+
+bool X86TargetLowering::preferShiftsToClearExtremeBits(SDValue Y) const {
+ EVT VT = Y.getValueType();
+
+ // For vectors, we don't have a preference, but we probably want a mask.
+ if (VT.isVector())
+ return false;
+
+ // 64-bit shifts on 32-bit targets produce really bad bloated code.
+ if (VT == MVT::i64 && !Subtarget.is64Bit())
+ return false;
+
return true;
}
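// A small illustration (not from the patch) of the scalar patterns these
// hooks reason about. With BMI, ~x & y can be selected as ANDN, which only
// exists in 32- and 64-bit register forms and has no immediate form; and
// clearing the upper bits of an i64 on a 32-bit target is cheaper as a mask
// than as a shift pair, which is why preferShiftsToClearExtremeBits declines
// that case above.
#include <cstdint>

uint32_t andn32(uint32_t x, uint32_t y) { return ~x & y; }          // ANDN candidate
uint64_t clear_hi_mask(uint64_t v)  { return v & 0xFFFFFFFFull; }   // mask form
uint64_t clear_hi_shift(uint64_t v) { return (v << 32) >> 32; }     // equivalent shift pair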
@@ -4699,10 +4837,24 @@ static bool isUndefInRange(ArrayRef<int> Mask, unsigned Pos, unsigned Size) {
return true;
}
+/// Return true if Val falls within the specified range [Low, Hi).
+static bool isInRange(int Val, int Low, int Hi) {
+ return (Val >= Low && Val < Hi);
+}
+
+/// Return true if the value of any element in Mask falls within the specified
+/// range [Low, Hi).
+static bool isAnyInRange(ArrayRef<int> Mask, int Low, int Hi) {
+ for (int M : Mask)
+ if (isInRange(M, Low, Hi))
+ return true;
+ return false;
+}
+
/// Return true if Val is undef or if its value falls within the
/// specified range [Low, Hi).
static bool isUndefOrInRange(int Val, int Low, int Hi) {
- return (Val == SM_SentinelUndef) || (Val >= Low && Val < Hi);
+ return (Val == SM_SentinelUndef) || isInRange(Val, Low, Hi);
}
/// Return true if every element in Mask is undef or if its value
@@ -4718,7 +4870,7 @@ static bool isUndefOrInRange(ArrayRef<int> Mask,
/// Return true if Val is undef, zero or if its value falls within the
/// specified range [Low, Hi).
static bool isUndefOrZeroOrInRange(int Val, int Low, int Hi) {
- return isUndefOrZero(Val) || (Val >= Low && Val < Hi);
+ return isUndefOrZero(Val) || isInRange(Val, Low, Hi);
}
/// Return true if every element in Mask is undef, zero or if its value
@@ -4731,11 +4883,11 @@ static bool isUndefOrZeroOrInRange(ArrayRef<int> Mask, int Low, int Hi) {
}
/// Return true if every element in Mask, beginning
-/// from position Pos and ending in Pos+Size, falls within the specified
-/// sequential range (Low, Low+Size]. or is undef.
-static bool isSequentialOrUndefInRange(ArrayRef<int> Mask,
- unsigned Pos, unsigned Size, int Low) {
- for (unsigned i = Pos, e = Pos+Size; i != e; ++i, ++Low)
+/// from position Pos and ending in Pos + Size, falls within the specified
+/// sequence (Low, Low + Step, ..., Low + (Size - 1) * Step) or is undef.
+static bool isSequentialOrUndefInRange(ArrayRef<int> Mask, unsigned Pos,
+ unsigned Size, int Low, int Step = 1) {
+ for (unsigned i = Pos, e = Pos + Size; i != e; ++i, Low += Step)
if (!isUndefOrEqual(Mask[i], Low))
return false;
return true;
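// A standalone sketch (plain ints instead of the shuffle-mask sentinels) of
// what the new Step parameter buys: a strided sequence such as {0, 2, 4, 6}
// now matches Low = 0, Step = 2, which the old Step = 1 behavior rejected.
#include <cassert>
#include <vector>

static bool isSeqOrUndef(const std::vector<int> &Mask, unsigned Pos,
                         unsigned Size, int Low, int Step = 1) {
  const int Undef = -1; // stand-in for SM_SentinelUndef
  for (unsigned i = Pos, e = Pos + Size; i != e; ++i, Low += Step)
    if (Mask[i] != Undef && Mask[i] != Low)
      return false;
  return true;
}

int main() {
  assert(isSeqOrUndef({0, 2, 4, 6}, 0, 4, /*Low=*/0, /*Step=*/2)); // strided match
  assert(!isSeqOrUndef({0, 2, 4, 6}, 0, 4, /*Low=*/0));            // Step = 1 rejects it
  assert(isSeqOrUndef({0, -1, 2, 3}, 0, 4, /*Low=*/0));            // undef elements are skipped
  return 0;
}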
@@ -4762,7 +4914,7 @@ static bool isUndefOrZeroInRange(ArrayRef<int> Mask, unsigned Pos,
return true;
}
-/// \brief Helper function to test whether a shuffle mask could be
+/// Helper function to test whether a shuffle mask could be
/// simplified by widening the elements being shuffled.
///
/// Appends the mask for wider elements in WidenedMask if valid. Otherwise
@@ -4821,6 +4973,24 @@ static bool canWidenShuffleElements(ArrayRef<int> Mask,
return true;
}
+static bool canWidenShuffleElements(ArrayRef<int> Mask,
+ const APInt &Zeroable,
+ SmallVectorImpl<int> &WidenedMask) {
+ SmallVector<int, 32> TargetMask(Mask.begin(), Mask.end());
+ for (int i = 0, Size = TargetMask.size(); i < Size; ++i) {
+ if (TargetMask[i] == SM_SentinelUndef)
+ continue;
+ if (Zeroable[i])
+ TargetMask[i] = SM_SentinelZero;
+ }
+ return canWidenShuffleElements(TargetMask, WidenedMask);
+}
+
+static bool canWidenShuffleElements(ArrayRef<int> Mask) {
+ SmallVector<int, 32> WidenedMask;
+ return canWidenShuffleElements(Mask, WidenedMask);
+}
+
/// Returns true if Elt is a constant zero or a floating point constant +0.0.
bool X86::isZeroNode(SDValue Elt) {
return isNullConstant(Elt) || isNullFPConstant(Elt);
@@ -4916,8 +5086,6 @@ static SDValue getZeroVector(MVT VT, const X86Subtarget &Subtarget,
} else if (VT.getVectorElementType() == MVT::i1) {
assert((Subtarget.hasBWI() || VT.getVectorNumElements() <= 16) &&
"Unexpected vector type");
- assert((Subtarget.hasVLX() || VT.getVectorNumElements() >= 8) &&
- "Unexpected vector type");
Vec = DAG.getConstant(0, dl, VT);
} else {
unsigned Num32BitElts = VT.getSizeInBits() / 32;
@@ -5007,10 +5175,66 @@ static SDValue insert128BitVector(SDValue Result, SDValue Vec, unsigned IdxVal,
return insertSubVector(Result, Vec, IdxVal, DAG, dl, 128);
}
-static SDValue insert256BitVector(SDValue Result, SDValue Vec, unsigned IdxVal,
- SelectionDAG &DAG, const SDLoc &dl) {
- assert(Vec.getValueType().is256BitVector() && "Unexpected vector size!");
- return insertSubVector(Result, Vec, IdxVal, DAG, dl, 256);
+/// Widen a vector to a larger size with the same scalar type, with the new
+/// elements either zero or undef.
+static SDValue widenSubVector(MVT VT, SDValue Vec, bool ZeroNewElements,
+ const X86Subtarget &Subtarget, SelectionDAG &DAG,
+ const SDLoc &dl) {
+ assert(Vec.getValueSizeInBits() < VT.getSizeInBits() &&
+ Vec.getValueType().getScalarType() == VT.getScalarType() &&
+ "Unsupported vector widening type");
+ SDValue Res = ZeroNewElements ? getZeroVector(VT, Subtarget, DAG, dl)
+ : DAG.getUNDEF(VT);
+ return DAG.getNode(ISD::INSERT_SUBVECTOR, dl, VT, Res, Vec,
+ DAG.getIntPtrConstant(0, dl));
+}
+
+// Helper for splitting operands of an operation into legal target-sized pieces
+// and applying a function to each piece.
+// Useful for operations that are available on SSE2 in 128-bit, on AVX2 in
+// 256-bit and on AVX512BW in 512-bit. The argument VT is the type used for
+// deciding if/how to split Ops. Ops elements do *not* have to be of type VT.
+// The argument Builder is a function that will be applied on each split part:
+// SDValue Builder(SelectionDAG &G, const SDLoc &DL, ArrayRef<SDValue> Ops)
+template <typename F>
+SDValue SplitOpsAndApply(SelectionDAG &DAG, const X86Subtarget &Subtarget,
+ const SDLoc &DL, EVT VT, ArrayRef<SDValue> Ops,
+ F Builder, bool CheckBWI = true) {
+ assert(Subtarget.hasSSE2() && "Target assumed to support at least SSE2");
+ unsigned NumSubs = 1;
+ if ((CheckBWI && Subtarget.useBWIRegs()) ||
+ (!CheckBWI && Subtarget.useAVX512Regs())) {
+ if (VT.getSizeInBits() > 512) {
+ NumSubs = VT.getSizeInBits() / 512;
+ assert((VT.getSizeInBits() % 512) == 0 && "Illegal vector size");
+ }
+ } else if (Subtarget.hasAVX2()) {
+ if (VT.getSizeInBits() > 256) {
+ NumSubs = VT.getSizeInBits() / 256;
+ assert((VT.getSizeInBits() % 256) == 0 && "Illegal vector size");
+ }
+ } else {
+ if (VT.getSizeInBits() > 128) {
+ NumSubs = VT.getSizeInBits() / 128;
+ assert((VT.getSizeInBits() % 128) == 0 && "Illegal vector size");
+ }
+ }
+
+ if (NumSubs == 1)
+ return Builder(DAG, DL, Ops);
+
+ SmallVector<SDValue, 4> Subs;
+ for (unsigned i = 0; i != NumSubs; ++i) {
+ SmallVector<SDValue, 2> SubOps;
+ for (SDValue Op : Ops) {
+ EVT OpVT = Op.getValueType();
+ unsigned NumSubElts = OpVT.getVectorNumElements() / NumSubs;
+ unsigned SizeSub = OpVT.getSizeInBits() / NumSubs;
+ SubOps.push_back(extractSubVector(Op, i * NumSubElts, DAG, DL, SizeSub));
+ }
+ Subs.push_back(Builder(DAG, DL, SubOps));
+ }
+ return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, Subs);
}
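// A minimal standalone model (std::vector standing in for vector-typed
// SDValues; names below are placeholders) of the split/apply/concat pattern
// SplitOpsAndApply implements: chop each operand into legal-width chunks,
// run the builder on each chunk, and concatenate the per-chunk results.
#include <cassert>
#include <functional>
#include <vector>

using Vec = std::vector<int>;

static Vec splitAndApply(const Vec &A, const Vec &B, size_t ChunkElts,
                         const std::function<Vec(const Vec &, const Vec &)> &Builder) {
  assert(A.size() == B.size() && A.size() % ChunkElts == 0);
  Vec Result; // plays the role of ISD::CONCAT_VECTORS over the parts
  for (size_t I = 0; I < A.size(); I += ChunkElts) {
    Vec SubA(A.begin() + I, A.begin() + I + ChunkElts);
    Vec SubB(B.begin() + I, B.begin() + I + ChunkElts);
    Vec Part = Builder(SubA, SubB); // one legal-width operation per chunk
    Result.insert(Result.end(), Part.begin(), Part.end());
  }
  return Result;
}

int main() {
  // Element-wise add performed in chunks of four "lanes".
  auto Add = [](const Vec &X, const Vec &Y) {
    Vec R(X.size());
    for (size_t I = 0; I < X.size(); ++I)
      R[I] = X[I] + Y[I];
    return R;
  };
  Vec Out = splitAndApply({1, 2, 3, 4, 5, 6, 7, 8},
                          {10, 20, 30, 40, 50, 60, 70, 80}, 4, Add);
  assert(Out.size() == 8 && Out[0] == 11 && Out[7] == 88);
  return 0;
}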
// Return true if the instruction zeroes the unused upper part of the
@@ -5019,13 +5243,9 @@ static bool isMaskedZeroUpperBitsvXi1(unsigned int Opcode) {
switch (Opcode) {
default:
return false;
- case X86ISD::TESTM:
- case X86ISD::TESTNM:
- case X86ISD::PCMPEQM:
- case X86ISD::PCMPGTM:
case X86ISD::CMPM:
- case X86ISD::CMPMU:
case X86ISD::CMPM_RND:
+ case ISD::SETCC:
return true;
}
}
@@ -5166,22 +5386,11 @@ static SDValue insert1BitVector(SDValue Op, SelectionDAG &DAG,
return DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, OpVT, Op, ZeroIdx);
}
-/// Concat two 128-bit vectors into a 256 bit vector using VINSERTF128
-/// instructions. This is used because creating CONCAT_VECTOR nodes of
-/// BUILD_VECTORS returns a larger BUILD_VECTOR while we're trying to lower
-/// large BUILD_VECTORS.
-static SDValue concat128BitVectors(SDValue V1, SDValue V2, EVT VT,
- unsigned NumElems, SelectionDAG &DAG,
- const SDLoc &dl) {
- SDValue V = insert128BitVector(DAG.getUNDEF(VT), V1, 0, DAG, dl);
- return insert128BitVector(V, V2, NumElems / 2, DAG, dl);
-}
-
-static SDValue concat256BitVectors(SDValue V1, SDValue V2, EVT VT,
- unsigned NumElems, SelectionDAG &DAG,
- const SDLoc &dl) {
- SDValue V = insert256BitVector(DAG.getUNDEF(VT), V1, 0, DAG, dl);
- return insert256BitVector(V, V2, NumElems / 2, DAG, dl);
+static SDValue concatSubVectors(SDValue V1, SDValue V2, EVT VT,
+ unsigned NumElems, SelectionDAG &DAG,
+ const SDLoc &dl, unsigned VectorWidth) {
+ SDValue V = insertSubVector(DAG.getUNDEF(VT), V1, 0, DAG, dl, VectorWidth);
+ return insertSubVector(V, V2, NumElems / 2, DAG, dl, VectorWidth);
}
/// Returns a vector of specified type with all bits set.
@@ -5265,6 +5474,13 @@ static SDValue peekThroughOneUseBitcasts(SDValue V) {
return V;
}
+// Peek through EXTRACT_SUBVECTORs - typically used for AVX1 256-bit intops.
+static SDValue peekThroughEXTRACT_SUBVECTORs(SDValue V) {
+ while (V.getOpcode() == ISD::EXTRACT_SUBVECTOR)
+ V = V.getOperand(0);
+ return V;
+}
+
static const Constant *getTargetConstantFromNode(SDValue Op) {
Op = peekThroughBitcasts(Op);
@@ -5389,6 +5605,12 @@ static bool getTargetConstantBitsFromNode(SDValue Op, unsigned EltSizeInBits,
SmallVector<APInt, 64> SrcEltBits(1, Cst->getAPIntValue());
return CastBitData(UndefSrcElts, SrcEltBits);
}
+ if (auto *Cst = dyn_cast<ConstantFPSDNode>(Op)) {
+ APInt UndefSrcElts = APInt::getNullValue(1);
+ APInt RawBits = Cst->getValueAPF().bitcastToAPInt();
+ SmallVector<APInt, 64> SrcEltBits(1, RawBits);
+ return CastBitData(UndefSrcElts, SrcEltBits);
+ }
// Extract constant bits from build vector.
if (ISD::isBuildVectorOfConstantSDNodes(Op.getNode())) {
@@ -5525,14 +5747,15 @@ static bool getTargetShuffleMask(SDNode *N, MVT VT, bool AllowSentinelZero,
assert(N->getOperand(0).getValueType() == VT && "Unexpected value type");
assert(N->getOperand(1).getValueType() == VT && "Unexpected value type");
ImmN = N->getOperand(N->getNumOperands()-1);
- DecodeBLENDMask(VT, cast<ConstantSDNode>(ImmN)->getZExtValue(), Mask);
+ DecodeBLENDMask(NumElems, cast<ConstantSDNode>(ImmN)->getZExtValue(), Mask);
IsUnary = IsFakeUnary = N->getOperand(0) == N->getOperand(1);
break;
case X86ISD::SHUFP:
assert(N->getOperand(0).getValueType() == VT && "Unexpected value type");
assert(N->getOperand(1).getValueType() == VT && "Unexpected value type");
ImmN = N->getOperand(N->getNumOperands()-1);
- DecodeSHUFPMask(VT, cast<ConstantSDNode>(ImmN)->getZExtValue(), Mask);
+ DecodeSHUFPMask(NumElems, VT.getScalarSizeInBits(),
+ cast<ConstantSDNode>(ImmN)->getZExtValue(), Mask);
IsUnary = IsFakeUnary = N->getOperand(0) == N->getOperand(1);
break;
case X86ISD::INSERTPS:
@@ -5548,7 +5771,8 @@ static bool getTargetShuffleMask(SDNode *N, MVT VT, bool AllowSentinelZero,
isa<ConstantSDNode>(N->getOperand(2))) {
int BitLen = N->getConstantOperandVal(1);
int BitIdx = N->getConstantOperandVal(2);
- DecodeEXTRQIMask(VT, BitLen, BitIdx, Mask);
+ DecodeEXTRQIMask(NumElems, VT.getScalarSizeInBits(), BitLen, BitIdx,
+ Mask);
IsUnary = true;
}
break;
@@ -5559,20 +5783,21 @@ static bool getTargetShuffleMask(SDNode *N, MVT VT, bool AllowSentinelZero,
isa<ConstantSDNode>(N->getOperand(3))) {
int BitLen = N->getConstantOperandVal(2);
int BitIdx = N->getConstantOperandVal(3);
- DecodeINSERTQIMask(VT, BitLen, BitIdx, Mask);
+ DecodeINSERTQIMask(NumElems, VT.getScalarSizeInBits(), BitLen, BitIdx,
+ Mask);
IsUnary = IsFakeUnary = N->getOperand(0) == N->getOperand(1);
}
break;
case X86ISD::UNPCKH:
assert(N->getOperand(0).getValueType() == VT && "Unexpected value type");
assert(N->getOperand(1).getValueType() == VT && "Unexpected value type");
- DecodeUNPCKHMask(VT, Mask);
+ DecodeUNPCKHMask(NumElems, VT.getScalarSizeInBits(), Mask);
IsUnary = IsFakeUnary = N->getOperand(0) == N->getOperand(1);
break;
case X86ISD::UNPCKL:
assert(N->getOperand(0).getValueType() == VT && "Unexpected value type");
assert(N->getOperand(1).getValueType() == VT && "Unexpected value type");
- DecodeUNPCKLMask(VT, Mask);
+ DecodeUNPCKLMask(NumElems, VT.getScalarSizeInBits(), Mask);
IsUnary = IsFakeUnary = N->getOperand(0) == N->getOperand(1);
break;
case X86ISD::MOVHLPS:
@@ -5592,7 +5817,8 @@ static bool getTargetShuffleMask(SDNode *N, MVT VT, bool AllowSentinelZero,
assert(N->getOperand(0).getValueType() == VT && "Unexpected value type");
assert(N->getOperand(1).getValueType() == VT && "Unexpected value type");
ImmN = N->getOperand(N->getNumOperands()-1);
- DecodePALIGNRMask(VT, cast<ConstantSDNode>(ImmN)->getZExtValue(), Mask);
+ DecodePALIGNRMask(NumElems, cast<ConstantSDNode>(ImmN)->getZExtValue(),
+ Mask);
IsUnary = IsFakeUnary = N->getOperand(0) == N->getOperand(1);
Ops.push_back(N->getOperand(1));
Ops.push_back(N->getOperand(0));
@@ -5601,38 +5827,43 @@ static bool getTargetShuffleMask(SDNode *N, MVT VT, bool AllowSentinelZero,
assert(VT.getScalarType() == MVT::i8 && "Byte vector expected");
assert(N->getOperand(0).getValueType() == VT && "Unexpected value type");
ImmN = N->getOperand(N->getNumOperands() - 1);
- DecodePSLLDQMask(VT, cast<ConstantSDNode>(ImmN)->getZExtValue(), Mask);
+ DecodePSLLDQMask(NumElems, cast<ConstantSDNode>(ImmN)->getZExtValue(),
+ Mask);
IsUnary = true;
break;
case X86ISD::VSRLDQ:
assert(VT.getScalarType() == MVT::i8 && "Byte vector expected");
assert(N->getOperand(0).getValueType() == VT && "Unexpected value type");
ImmN = N->getOperand(N->getNumOperands() - 1);
- DecodePSRLDQMask(VT, cast<ConstantSDNode>(ImmN)->getZExtValue(), Mask);
+ DecodePSRLDQMask(NumElems, cast<ConstantSDNode>(ImmN)->getZExtValue(),
+ Mask);
IsUnary = true;
break;
case X86ISD::PSHUFD:
case X86ISD::VPERMILPI:
assert(N->getOperand(0).getValueType() == VT && "Unexpected value type");
ImmN = N->getOperand(N->getNumOperands()-1);
- DecodePSHUFMask(VT, cast<ConstantSDNode>(ImmN)->getZExtValue(), Mask);
+ DecodePSHUFMask(NumElems, VT.getScalarSizeInBits(),
+ cast<ConstantSDNode>(ImmN)->getZExtValue(), Mask);
IsUnary = true;
break;
case X86ISD::PSHUFHW:
assert(N->getOperand(0).getValueType() == VT && "Unexpected value type");
ImmN = N->getOperand(N->getNumOperands()-1);
- DecodePSHUFHWMask(VT, cast<ConstantSDNode>(ImmN)->getZExtValue(), Mask);
+ DecodePSHUFHWMask(NumElems, cast<ConstantSDNode>(ImmN)->getZExtValue(),
+ Mask);
IsUnary = true;
break;
case X86ISD::PSHUFLW:
assert(N->getOperand(0).getValueType() == VT && "Unexpected value type");
ImmN = N->getOperand(N->getNumOperands()-1);
- DecodePSHUFLWMask(VT, cast<ConstantSDNode>(ImmN)->getZExtValue(), Mask);
+ DecodePSHUFLWMask(NumElems, cast<ConstantSDNode>(ImmN)->getZExtValue(),
+ Mask);
IsUnary = true;
break;
case X86ISD::VZEXT_MOVL:
assert(N->getOperand(0).getValueType() == VT && "Unexpected value type");
- DecodeZeroMoveLowMask(VT, Mask);
+ DecodeZeroMoveLowMask(NumElems, Mask);
IsUnary = true;
break;
case X86ISD::VBROADCAST: {
@@ -5648,7 +5879,7 @@ static bool getTargetShuffleMask(SDNode *N, MVT VT, bool AllowSentinelZero,
// came from an extract from the original width. If we found one, we
// pushed it the Ops vector above.
if (N0.getValueType() == VT || !Ops.empty()) {
- DecodeVectorBroadcast(VT, Mask);
+ DecodeVectorBroadcast(NumElems, Mask);
IsUnary = true;
break;
}
@@ -5661,7 +5892,7 @@ static bool getTargetShuffleMask(SDNode *N, MVT VT, bool AllowSentinelZero,
unsigned MaskEltSize = VT.getScalarSizeInBits();
SmallVector<uint64_t, 32> RawMask;
if (getTargetShuffleMaskIndices(MaskNode, MaskEltSize, RawMask)) {
- DecodeVPERMILPMask(VT, RawMask, Mask);
+ DecodeVPERMILPMask(NumElems, VT.getScalarSizeInBits(), RawMask, Mask);
break;
}
if (auto *C = getTargetConstantFromNode(MaskNode)) {
@@ -5690,41 +5921,47 @@ static bool getTargetShuffleMask(SDNode *N, MVT VT, bool AllowSentinelZero,
case X86ISD::VPERMI:
assert(N->getOperand(0).getValueType() == VT && "Unexpected value type");
ImmN = N->getOperand(N->getNumOperands()-1);
- DecodeVPERMMask(VT, cast<ConstantSDNode>(ImmN)->getZExtValue(), Mask);
+ DecodeVPERMMask(NumElems, cast<ConstantSDNode>(ImmN)->getZExtValue(), Mask);
IsUnary = true;
break;
case X86ISD::MOVSS:
case X86ISD::MOVSD:
assert(N->getOperand(0).getValueType() == VT && "Unexpected value type");
assert(N->getOperand(1).getValueType() == VT && "Unexpected value type");
- DecodeScalarMoveMask(VT, /* IsLoad */ false, Mask);
+ DecodeScalarMoveMask(NumElems, /* IsLoad */ false, Mask);
break;
case X86ISD::VPERM2X128:
assert(N->getOperand(0).getValueType() == VT && "Unexpected value type");
assert(N->getOperand(1).getValueType() == VT && "Unexpected value type");
ImmN = N->getOperand(N->getNumOperands()-1);
- DecodeVPERM2X128Mask(VT, cast<ConstantSDNode>(ImmN)->getZExtValue(), Mask);
+ DecodeVPERM2X128Mask(NumElems, cast<ConstantSDNode>(ImmN)->getZExtValue(),
+ Mask);
+ IsUnary = IsFakeUnary = N->getOperand(0) == N->getOperand(1);
+ break;
+ case X86ISD::SHUF128:
+ assert(N->getOperand(0).getValueType() == VT && "Unexpected value type");
+ assert(N->getOperand(1).getValueType() == VT && "Unexpected value type");
+ ImmN = N->getOperand(N->getNumOperands()-1);
+ decodeVSHUF64x2FamilyMask(NumElems, VT.getScalarSizeInBits(),
+ cast<ConstantSDNode>(ImmN)->getZExtValue(),
+ Mask);
IsUnary = IsFakeUnary = N->getOperand(0) == N->getOperand(1);
break;
case X86ISD::MOVSLDUP:
assert(N->getOperand(0).getValueType() == VT && "Unexpected value type");
- DecodeMOVSLDUPMask(VT, Mask);
+ DecodeMOVSLDUPMask(NumElems, Mask);
IsUnary = true;
break;
case X86ISD::MOVSHDUP:
assert(N->getOperand(0).getValueType() == VT && "Unexpected value type");
- DecodeMOVSHDUPMask(VT, Mask);
+ DecodeMOVSHDUPMask(NumElems, Mask);
IsUnary = true;
break;
case X86ISD::MOVDDUP:
assert(N->getOperand(0).getValueType() == VT && "Unexpected value type");
- DecodeMOVDDUPMask(VT, Mask);
+ DecodeMOVDDUPMask(NumElems, Mask);
IsUnary = true;
break;
- case X86ISD::MOVLPD:
- case X86ISD::MOVLPS:
- // Not yet implemented
- return false;
case X86ISD::VPERMIL2: {
assert(N->getOperand(0).getValueType() == VT && "Unexpected value type");
assert(N->getOperand(1).getValueType() == VT && "Unexpected value type");
@@ -5736,7 +5973,8 @@ static bool getTargetShuffleMask(SDNode *N, MVT VT, bool AllowSentinelZero,
unsigned CtrlImm = CtrlOp->getZExtValue();
SmallVector<uint64_t, 32> RawMask;
if (getTargetShuffleMaskIndices(MaskNode, MaskEltSize, RawMask)) {
- DecodeVPERMIL2PMask(VT, CtrlImm, RawMask, Mask);
+ DecodeVPERMIL2PMask(NumElems, VT.getScalarSizeInBits(), CtrlImm,
+ RawMask, Mask);
break;
}
if (auto *C = getTargetConstantFromNode(MaskNode)) {
@@ -5795,21 +6033,6 @@ static bool getTargetShuffleMask(SDNode *N, MVT VT, bool AllowSentinelZero,
}
return false;
}
- case X86ISD::VPERMIV3: {
- assert(N->getOperand(1).getValueType() == VT && "Unexpected value type");
- assert(N->getOperand(2).getValueType() == VT && "Unexpected value type");
- IsUnary = IsFakeUnary = N->getOperand(1) == N->getOperand(2);
- // Unlike most shuffle nodes, VPERMIV3's mask operand is the first one.
- Ops.push_back(N->getOperand(1));
- Ops.push_back(N->getOperand(2));
- SDValue MaskNode = N->getOperand(0);
- unsigned MaskEltSize = VT.getScalarSizeInBits();
- if (auto *C = getTargetConstantFromNode(MaskNode)) {
- DecodeVPERMV3Mask(C, MaskEltSize, Mask);
- break;
- }
- return false;
- }
default: llvm_unreachable("unknown target shuffle node");
}
@@ -5927,7 +6150,7 @@ static bool setTargetShuffleZeroElements(SDValue N,
// destination value type.
static bool getFauxShuffleMask(SDValue N, SmallVectorImpl<int> &Mask,
SmallVectorImpl<SDValue> &Ops,
- SelectionDAG &DAG) {
+ const SelectionDAG &DAG) {
Mask.clear();
Ops.clear();
@@ -5940,6 +6163,17 @@ static bool getFauxShuffleMask(SDValue N, SmallVectorImpl<int> &Mask,
unsigned Opcode = N.getOpcode();
switch (Opcode) {
+ case ISD::VECTOR_SHUFFLE: {
+ // Don't treat ISD::VECTOR_SHUFFLE as a target shuffle so decode it here.
+ ArrayRef<int> ShuffleMask = cast<ShuffleVectorSDNode>(N)->getMask();
+ if (isUndefOrInRange(ShuffleMask, 0, 2 * NumElts)) {
+ Mask.append(ShuffleMask.begin(), ShuffleMask.end());
+ Ops.push_back(N.getOperand(0));
+ Ops.push_back(N.getOperand(1));
+ return true;
+ }
+ return false;
+ }
case ISD::AND:
case X86ISD::ANDNP: {
// Attempt to decode as a per-byte mask.
@@ -6001,8 +6235,11 @@ static bool getFauxShuffleMask(SDValue N, SmallVectorImpl<int> &Mask,
case X86ISD::PINSRW: {
SDValue InVec = N.getOperand(0);
SDValue InScl = N.getOperand(1);
+ SDValue InIndex = N.getOperand(2);
+ if (!isa<ConstantSDNode>(InIndex) ||
+ cast<ConstantSDNode>(InIndex)->getAPIntValue().uge(NumElts))
+ return false;
uint64_t InIdx = N.getConstantOperandVal(2);
- assert(InIdx < NumElts && "Illegal insertion index");
// Attempt to recognise a PINSR*(VEC, 0, Idx) shuffle pattern.
if (X86::isZeroNode(InScl)) {
@@ -6020,8 +6257,12 @@ static bool getFauxShuffleMask(SDValue N, SmallVectorImpl<int> &Mask,
return false;
SDValue ExVec = InScl.getOperand(0);
+ SDValue ExIndex = InScl.getOperand(1);
+ if (!isa<ConstantSDNode>(ExIndex) ||
+ cast<ConstantSDNode>(ExIndex)->getAPIntValue().uge(NumElts))
+ return false;
uint64_t ExIdx = InScl.getConstantOperandVal(1);
- assert(ExIdx < NumElts && "Illegal extraction index");
+
Ops.push_back(InVec);
Ops.push_back(ExVec);
for (unsigned i = 0; i != NumElts; ++i)
@@ -6097,7 +6338,8 @@ static bool getFauxShuffleMask(SDValue N, SmallVectorImpl<int> &Mask,
MVT SrcVT = Src.getSimpleValueType();
if (NumSizeInBits != SrcVT.getSizeInBits())
break;
- DecodeZeroExtendMask(SrcVT.getScalarType(), VT, Mask);
+ DecodeZeroExtendMask(SrcVT.getScalarSizeInBits(), VT.getScalarSizeInBits(),
+ VT.getVectorNumElements(), Mask);
Ops.push_back(Src);
return true;
}
@@ -6141,7 +6383,7 @@ static void resolveTargetShuffleInputsAndMask(SmallVectorImpl<SDValue> &Inputs,
static bool resolveTargetShuffleInputs(SDValue Op,
SmallVectorImpl<SDValue> &Inputs,
SmallVectorImpl<int> &Mask,
- SelectionDAG &DAG) {
+ const SelectionDAG &DAG) {
if (!setTargetShuffleZeroElements(Op, Mask, Inputs))
if (!getFauxShuffleMask(Op, Mask, Inputs, DAG))
return false;
@@ -6451,9 +6693,8 @@ static SDValue getVShift(bool isLeft, EVT VT, SDValue SrcOp, unsigned NumBits,
MVT ShVT = MVT::v16i8;
unsigned Opc = isLeft ? X86ISD::VSHLDQ : X86ISD::VSRLDQ;
SrcOp = DAG.getBitcast(ShVT, SrcOp);
- MVT ScalarShiftTy = TLI.getScalarShiftAmountTy(DAG.getDataLayout(), VT);
assert(NumBits % 8 == 0 && "Only support byte sized shifts");
- SDValue ShiftVal = DAG.getConstant(NumBits/8, dl, ScalarShiftTy);
+ SDValue ShiftVal = DAG.getConstant(NumBits/8, dl, MVT::i8);
return DAG.getBitcast(VT, DAG.getNode(Opc, dl, ShVT, SrcOp, ShiftVal));
}
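For orientation, VSHLDQ/VSRLDQ are the byte-granular whole-register shifts (pslldq/psrldq), so the NumBits/8 constant built here is a byte count. A minimal intrinsic sketch of the 32-bit case, under that assumption:

#include <immintrin.h>

// Shift a 128-bit value left by 32 bits, expressed as a 4-byte pslldq.
__m128i shift_left_32_bits(__m128i V) {
  return _mm_slli_si128(V, 4); // NumBits / 8 == 4 bytes
}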
@@ -6805,17 +7046,13 @@ static SDValue lowerBuildVectorAsBroadcast(BuildVectorSDNode *BVOp,
BOperand = ZeroExtended.getOperand(0);
else
BOperand = Ld.getOperand(0).getOperand(0);
- if (BOperand.getValueType().isVector() &&
- BOperand.getSimpleValueType().getVectorElementType() == MVT::i1) {
- if ((EltType == MVT::i64 && (VT.getVectorElementType() == MVT::i8 ||
- NumElts == 8)) || // for broadcastmb2q
- (EltType == MVT::i32 && (VT.getVectorElementType() == MVT::i16 ||
- NumElts == 16))) { // for broadcastmw2d
- SDValue Brdcst =
- DAG.getNode(X86ISD::VBROADCASTM, dl,
- MVT::getVectorVT(EltType, NumElts), BOperand);
- return DAG.getBitcast(VT, Brdcst);
- }
+ MVT MaskVT = BOperand.getSimpleValueType();
+ if ((EltType == MVT::i64 && MaskVT == MVT::v8i1) || // for broadcastmb2q
+ (EltType == MVT::i32 && MaskVT == MVT::v16i1)) { // for broadcastmw2d
+ SDValue Brdcst =
+ DAG.getNode(X86ISD::VBROADCASTM, dl,
+ MVT::getVectorVT(EltType, NumElts), BOperand);
+ return DAG.getBitcast(VT, Brdcst);
}
}
}
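For reference, the broadcastmb2q/broadcastmw2d forms matched above broadcast a mask register into every element of a vector register. A minimal sketch with AVX-512CD intrinsics, assuming the zmm-sized result shape:

#include <immintrin.h>

// Low 8 bits of k zero-extended into each i64 element (vpbroadcastmb2q).
__m512i broadcast_mask_b2q(__mmask8 K) { return _mm512_broadcastmb_epi64(K); }

// Low 16 bits of k zero-extended into each i32 element (vpbroadcastmw2d).
__m512i broadcast_mask_w2d(__mmask16 K) { return _mm512_broadcastmw_epi32(K); }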
@@ -6982,7 +7219,7 @@ static SDValue lowerBuildVectorAsBroadcast(BuildVectorSDNode *BVOp,
return SDValue();
}
-/// \brief For an EXTRACT_VECTOR_ELT with a constant index return the real
+/// For an EXTRACT_VECTOR_ELT with a constant index return the real
/// underlying vector and index.
///
/// Modifies \p ExtractedFromVec to the real vector and returns the real
@@ -7195,7 +7432,7 @@ static SDValue LowerBUILD_VECTORvXi1(SDValue Op, SelectionDAG &DAG,
return DstVec;
}
-/// \brief Return true if \p N implements a horizontal binop and return the
+/// Return true if \p N implements a horizontal binop and return the
/// operands for the horizontal binop into V0 and V1.
///
/// This is a helper function of LowerToHorizontalOp().
@@ -7292,7 +7529,7 @@ static bool isHorizontalBinOp(const BuildVectorSDNode *N, unsigned Opcode,
return CanFold;
}
-/// \brief Emit a sequence of two 128-bit horizontal add/sub followed by
+/// Emit a sequence of two 128-bit horizontal add/sub followed by
/// a concat_vector.
///
/// This is a helper function of LowerToHorizontalOp().
@@ -7360,18 +7597,18 @@ static SDValue ExpandHorizontalBinOp(const SDValue &V0, const SDValue &V1,
}
/// Returns true iff \p BV builds a vector with the result equivalent to
-/// the result of ADDSUB operation.
-/// If true is returned then the operands of ADDSUB = Opnd0 +- Opnd1 operation
-/// are written to the parameters \p Opnd0 and \p Opnd1.
-static bool isAddSub(const BuildVectorSDNode *BV,
- const X86Subtarget &Subtarget, SelectionDAG &DAG,
- SDValue &Opnd0, SDValue &Opnd1,
- unsigned &NumExtracts) {
+/// the result of ADDSUB/SUBADD operation.
+/// If true is returned then the operands of ADDSUB = Opnd0 +- Opnd1
+/// (SUBADD = Opnd0 -+ Opnd1) operation are written to the parameters
+/// \p Opnd0 and \p Opnd1.
+static bool isAddSubOrSubAdd(const BuildVectorSDNode *BV,
+ const X86Subtarget &Subtarget, SelectionDAG &DAG,
+ SDValue &Opnd0, SDValue &Opnd1,
+ unsigned &NumExtracts,
+ bool &IsSubAdd) {
MVT VT = BV->getSimpleValueType(0);
- if ((!Subtarget.hasSSE3() || (VT != MVT::v4f32 && VT != MVT::v2f64)) &&
- (!Subtarget.hasAVX() || (VT != MVT::v8f32 && VT != MVT::v4f64)) &&
- (!Subtarget.hasAVX512() || (VT != MVT::v16f32 && VT != MVT::v8f64)))
+ if (!Subtarget.hasSSE3() || !VT.isFloatingPoint())
return false;
unsigned NumElts = VT.getVectorNumElements();
@@ -7381,26 +7618,20 @@ static bool isAddSub(const BuildVectorSDNode *BV,
NumExtracts = 0;
// Odd-numbered elements in the input build vector are obtained from
- // adding two integer/float elements.
+ // adding/subtracting two integer/float elements.
// Even-numbered elements in the input build vector are obtained from
- // subtracting two integer/float elements.
- unsigned ExpectedOpcode = ISD::FSUB;
- unsigned NextExpectedOpcode = ISD::FADD;
- bool AddFound = false;
- bool SubFound = false;
-
+ // subtracting/adding two integer/float elements.
+ unsigned Opc[2] {0, 0};
for (unsigned i = 0, e = NumElts; i != e; ++i) {
SDValue Op = BV->getOperand(i);
// Skip 'undef' values.
unsigned Opcode = Op.getOpcode();
- if (Opcode == ISD::UNDEF) {
- std::swap(ExpectedOpcode, NextExpectedOpcode);
+ if (Opcode == ISD::UNDEF)
continue;
- }
// Early exit if we found an unexpected opcode.
- if (Opcode != ExpectedOpcode)
+ if (Opcode != ISD::FADD && Opcode != ISD::FSUB)
return false;
SDValue Op0 = Op.getOperand(0);
@@ -7420,11 +7651,11 @@ static bool isAddSub(const BuildVectorSDNode *BV,
if (I0 != i)
return false;
- // We found a valid add/sub node. Update the information accordingly.
- if (i & 1)
- AddFound = true;
- else
- SubFound = true;
+ // We found a valid add/sub node; make sure it's the same opcode as the
+ // previous elements for this parity.
+ if (Opc[i % 2] != 0 && Opc[i % 2] != Opcode)
+ return false;
+ Opc[i % 2] = Opcode;
// Update InVec0 and InVec1.
if (InVec0.isUndef()) {
@@ -7441,7 +7672,7 @@ static bool isAddSub(const BuildVectorSDNode *BV,
// Make sure that operands in input to each add/sub node always
// come from a same pair of vectors.
if (InVec0 != Op0.getOperand(0)) {
- if (ExpectedOpcode == ISD::FSUB)
+ if (Opcode == ISD::FSUB)
return false;
// FADD is commutable. Try to commute the operands
@@ -7454,24 +7685,26 @@ static bool isAddSub(const BuildVectorSDNode *BV,
if (InVec1 != Op1.getOperand(0))
return false;
- // Update the pair of expected opcodes.
- std::swap(ExpectedOpcode, NextExpectedOpcode);
-
// Increment the number of extractions done.
++NumExtracts;
}
- // Don't try to fold this build_vector into an ADDSUB if the inputs are undef.
- if (!AddFound || !SubFound || InVec0.isUndef() || InVec1.isUndef())
+ // Ensure we have found an opcode for both parities and that they are
+ // different. Don't try to fold this build_vector into an ADDSUB/SUBADD if the
+ // inputs are undef.
+ if (!Opc[0] || !Opc[1] || Opc[0] == Opc[1] ||
+ InVec0.isUndef() || InVec1.isUndef())
return false;
+ IsSubAdd = Opc[0] == ISD::FADD;
+
Opnd0 = InVec0;
Opnd1 = InVec1;
return true;
}
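The parity check above reduces to: even lanes must all use one opcode, odd lanes the other, and the two must differ. A minimal standalone sketch of that classification (hypothetical helper, not LLVM API, ignoring undef lanes):

#include <optional>
#include <vector>

enum class LaneOp { FAdd, FSub };

// Returns true for SUBADD (even lanes add), false for ADDSUB (even lanes
// subtract), and nullopt when the lanes do not form either pattern.
std::optional<bool> classifyAddSubPattern(const std::vector<LaneOp> &Lanes) {
  if (Lanes.size() < 2 || Lanes[0] == Lanes[1])
    return std::nullopt;
  for (size_t I = 0; I != Lanes.size(); ++I)
    if (Lanes[I] != Lanes[I % 2])
      return std::nullopt;
  return Lanes[0] == LaneOp::FAdd;
}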
/// Returns true if is possible to fold MUL and an idiom that has already been
-/// recognized as ADDSUB/SUBADD(\p Opnd0, \p Opnd1) into
+/// recognized as ADDSUB/SUBADD(\p Opnd0, \p Opnd1) into
/// FMADDSUB/FMSUBADD(x, y, \p Opnd1). If (and only if) true is returned, the
/// operands of FMADDSUB/FMSUBADD are written to parameters \p Opnd0, \p Opnd1, \p Opnd2.
///
@@ -7521,14 +7754,17 @@ static bool isFMAddSubOrFMSubAdd(const X86Subtarget &Subtarget,
return true;
}
-/// Try to fold a build_vector that performs an 'addsub' or 'fmaddsub' operation
-/// accordingly to X86ISD::ADDSUB or X86ISD::FMADDSUB node.
+/// Try to fold a build_vector that performs an 'addsub' or 'fmaddsub' or
+/// 'fsubadd' operation accordingly to X86ISD::ADDSUB or X86ISD::FMADDSUB or
+/// X86ISD::FMSUBADD node.
static SDValue lowerToAddSubOrFMAddSub(const BuildVectorSDNode *BV,
const X86Subtarget &Subtarget,
SelectionDAG &DAG) {
SDValue Opnd0, Opnd1;
unsigned NumExtracts;
- if (!isAddSub(BV, Subtarget, DAG, Opnd0, Opnd1, NumExtracts))
+ bool IsSubAdd;
+ if (!isAddSubOrSubAdd(BV, Subtarget, DAG, Opnd0, Opnd1, NumExtracts,
+ IsSubAdd))
return SDValue();
MVT VT = BV->getSimpleValueType(0);
@@ -7536,10 +7772,14 @@ static SDValue lowerToAddSubOrFMAddSub(const BuildVectorSDNode *BV,
// Try to generate X86ISD::FMADDSUB node here.
SDValue Opnd2;
- // TODO: According to coverage reports, the FMADDSUB transform is not
- // triggered by any tests.
- if (isFMAddSubOrFMSubAdd(Subtarget, DAG, Opnd0, Opnd1, Opnd2, NumExtracts))
- return DAG.getNode(X86ISD::FMADDSUB, DL, VT, Opnd0, Opnd1, Opnd2);
+ if (isFMAddSubOrFMSubAdd(Subtarget, DAG, Opnd0, Opnd1, Opnd2, NumExtracts)) {
+ unsigned Opc = IsSubAdd ? X86ISD::FMSUBADD : X86ISD::FMADDSUB;
+ return DAG.getNode(Opc, DL, VT, Opnd0, Opnd1, Opnd2);
+ }
+
+ // We only support ADDSUB.
+ if (IsSubAdd)
+ return SDValue();
// Do not generate X86ISD::ADDSUB node for 512-bit types even though
// the ADDSUB idiom has been successfully recognized. There are no known
@@ -7708,6 +7948,10 @@ static SDValue lowerBuildVectorToBitOp(BuildVectorSDNode *Op,
case ISD::AND:
case ISD::XOR:
case ISD::OR:
+ // Don't do this if the buildvector is a splat - we'd replace one
+ // constant with an entire vector.
+ if (Op->getSplatValue())
+ return SDValue();
if (!TLI.isOperationLegalOrPromote(Opcode, VT))
return SDValue();
break;
@@ -7762,66 +8006,268 @@ static SDValue materializeVectorConstant(SDValue Op, SelectionDAG &DAG,
return SDValue();
}
-// Tries to lower a BUILD_VECTOR composed of extract-extract chains that can be
-// reasoned to be a permutation of a vector by indices in a non-constant vector.
-// (build_vector (extract_elt V, (extract_elt I, 0)),
-// (extract_elt V, (extract_elt I, 1)),
-// ...
-// ->
-// (vpermv I, V)
-//
-// TODO: Handle undefs
-// TODO: Utilize pshufb and zero mask blending to support more efficient
-// construction of vectors with constant-0 elements.
-// TODO: Use smaller-element vectors of same width, and "interpolate" the indices,
-// when no native operation available.
-static SDValue
-LowerBUILD_VECTORAsVariablePermute(SDValue V, SelectionDAG &DAG,
- const X86Subtarget &Subtarget) {
- // Look for VPERMV and PSHUFB opportunities.
- MVT VT = V.getSimpleValueType();
+/// Look for opportunities to create a VPERMV/VPERMILPV/PSHUFB variable permute
+/// from a vector of source values and a vector of extraction indices.
+/// The vectors might be manipulated to match the type of the permute op.
+static SDValue createVariablePermute(MVT VT, SDValue SrcVec, SDValue IndicesVec,
+ SDLoc &DL, SelectionDAG &DAG,
+ const X86Subtarget &Subtarget) {
+ MVT ShuffleVT = VT;
+ EVT IndicesVT = EVT(VT).changeVectorElementTypeToInteger();
+ unsigned NumElts = VT.getVectorNumElements();
+ unsigned SizeInBits = VT.getSizeInBits();
+
+ // Adjust IndicesVec to match VT size.
+ assert(IndicesVec.getValueType().getVectorNumElements() >= NumElts &&
+ "Illegal variable permute mask size");
+ if (IndicesVec.getValueType().getVectorNumElements() > NumElts)
+ IndicesVec = extractSubVector(IndicesVec, 0, DAG, SDLoc(IndicesVec),
+ NumElts * VT.getScalarSizeInBits());
+ IndicesVec = DAG.getZExtOrTrunc(IndicesVec, SDLoc(IndicesVec), IndicesVT);
+
+ // Handle a SrcVec whose size doesn't match the VT size.
+ if (SrcVec.getValueSizeInBits() != SizeInBits) {
+ if ((SrcVec.getValueSizeInBits() % SizeInBits) == 0) {
+ // Handle larger SrcVec by treating it as a larger permute.
+ unsigned Scale = SrcVec.getValueSizeInBits() / SizeInBits;
+ VT = MVT::getVectorVT(VT.getScalarType(), Scale * NumElts);
+ IndicesVT = EVT(VT).changeVectorElementTypeToInteger();
+ IndicesVec = widenSubVector(IndicesVT.getSimpleVT(), IndicesVec, false,
+ Subtarget, DAG, SDLoc(IndicesVec));
+ return extractSubVector(
+ createVariablePermute(VT, SrcVec, IndicesVec, DL, DAG, Subtarget), 0,
+ DAG, DL, SizeInBits);
+ } else if (SrcVec.getValueSizeInBits() < SizeInBits) {
+ // Widen smaller SrcVec to match VT.
+ SrcVec = widenSubVector(VT, SrcVec, false, Subtarget, DAG, SDLoc(SrcVec));
+ } else
+ return SDValue();
+ }
+
+ auto ScaleIndices = [&DAG](SDValue Idx, uint64_t Scale) {
+ assert(isPowerOf2_64(Scale) && "Illegal variable permute shuffle scale");
+ EVT SrcVT = Idx.getValueType();
+ unsigned NumDstBits = SrcVT.getScalarSizeInBits() / Scale;
+ uint64_t IndexScale = 0;
+ uint64_t IndexOffset = 0;
+
+ // If we're scaling a smaller permute op, then we need to repeat the
+ // indices, scaling and offsetting them as well.
+ // e.g. v4i32 -> v16i8 (Scale = 4)
+ // IndexScale = v4i32 Splat(4 << 24 | 4 << 16 | 4 << 8 | 4)
+ // IndexOffset = v4i32 Splat(3 << 24 | 2 << 16 | 1 << 8 | 0)
+ for (uint64_t i = 0; i != Scale; ++i) {
+ IndexScale |= Scale << (i * NumDstBits);
+ IndexOffset |= i << (i * NumDstBits);
+ }
+
+ Idx = DAG.getNode(ISD::MUL, SDLoc(Idx), SrcVT, Idx,
+ DAG.getConstant(IndexScale, SDLoc(Idx), SrcVT));
+ Idx = DAG.getNode(ISD::ADD, SDLoc(Idx), SrcVT, Idx,
+ DAG.getConstant(IndexOffset, SDLoc(Idx), SrcVT));
+ return Idx;
+ };
+
+ unsigned Opcode = 0;
switch (VT.SimpleTy) {
default:
- return SDValue();
+ break;
case MVT::v16i8:
- if (!Subtarget.hasSSE3())
- return SDValue();
+ if (Subtarget.hasSSSE3())
+ Opcode = X86ISD::PSHUFB;
+ break;
+ case MVT::v8i16:
+ if (Subtarget.hasVLX() && Subtarget.hasBWI())
+ Opcode = X86ISD::VPERMV;
+ else if (Subtarget.hasSSSE3()) {
+ Opcode = X86ISD::PSHUFB;
+ ShuffleVT = MVT::v16i8;
+ }
+ break;
+ case MVT::v4f32:
+ case MVT::v4i32:
+ if (Subtarget.hasAVX()) {
+ Opcode = X86ISD::VPERMILPV;
+ ShuffleVT = MVT::v4f32;
+ } else if (Subtarget.hasSSSE3()) {
+ Opcode = X86ISD::PSHUFB;
+ ShuffleVT = MVT::v16i8;
+ }
+ break;
+ case MVT::v2f64:
+ case MVT::v2i64:
+ if (Subtarget.hasAVX()) {
+ // VPERMILPD selects using bit#1 of the index vector, so scale IndicesVec.
+ IndicesVec = DAG.getNode(ISD::ADD, DL, IndicesVT, IndicesVec, IndicesVec);
+ Opcode = X86ISD::VPERMILPV;
+ ShuffleVT = MVT::v2f64;
+ } else if (Subtarget.hasSSE41()) {
+ // SSE41 can compare v2i64 - select between indices 0 and 1.
+ return DAG.getSelectCC(
+ DL, IndicesVec,
+ getZeroVector(IndicesVT.getSimpleVT(), Subtarget, DAG, DL),
+ DAG.getVectorShuffle(VT, DL, SrcVec, SrcVec, {0, 0}),
+ DAG.getVectorShuffle(VT, DL, SrcVec, SrcVec, {1, 1}),
+ ISD::CondCode::SETEQ);
+ }
+ break;
+ case MVT::v32i8:
+ if (Subtarget.hasVLX() && Subtarget.hasVBMI())
+ Opcode = X86ISD::VPERMV;
+ else if (Subtarget.hasXOP()) {
+ SDValue LoSrc = extract128BitVector(SrcVec, 0, DAG, DL);
+ SDValue HiSrc = extract128BitVector(SrcVec, 16, DAG, DL);
+ SDValue LoIdx = extract128BitVector(IndicesVec, 0, DAG, DL);
+ SDValue HiIdx = extract128BitVector(IndicesVec, 16, DAG, DL);
+ return DAG.getNode(
+ ISD::CONCAT_VECTORS, DL, VT,
+ DAG.getNode(X86ISD::VPPERM, DL, MVT::v16i8, LoSrc, HiSrc, LoIdx),
+ DAG.getNode(X86ISD::VPPERM, DL, MVT::v16i8, LoSrc, HiSrc, HiIdx));
+ } else if (Subtarget.hasAVX()) {
+ SDValue Lo = extract128BitVector(SrcVec, 0, DAG, DL);
+ SDValue Hi = extract128BitVector(SrcVec, 16, DAG, DL);
+ SDValue LoLo = DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, Lo, Lo);
+ SDValue HiHi = DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, Hi, Hi);
+ auto PSHUFBBuilder = [](SelectionDAG &DAG, const SDLoc &DL,
+ ArrayRef<SDValue> Ops) {
+ // Permute Lo and Hi and then select based on index range.
+ // This works because PSHUFB uses only bits[3:0] to permute elements and we
+ // don't care about bit[7], as it's just an index vector.
+ SDValue Idx = Ops[2];
+ EVT VT = Idx.getValueType();
+ return DAG.getSelectCC(DL, Idx, DAG.getConstant(15, DL, VT),
+ DAG.getNode(X86ISD::PSHUFB, DL, VT, Ops[1], Idx),
+ DAG.getNode(X86ISD::PSHUFB, DL, VT, Ops[0], Idx),
+ ISD::CondCode::SETGT);
+ };
+ SDValue Ops[] = {LoLo, HiHi, IndicesVec};
+ return SplitOpsAndApply(DAG, Subtarget, DL, MVT::v32i8, Ops,
+ PSHUFBBuilder);
+ }
+ break;
+ case MVT::v16i16:
+ if (Subtarget.hasVLX() && Subtarget.hasBWI())
+ Opcode = X86ISD::VPERMV;
+ else if (Subtarget.hasAVX()) {
+ // Scale to v32i8 and perform as v32i8.
+ IndicesVec = ScaleIndices(IndicesVec, 2);
+ return DAG.getBitcast(
+ VT, createVariablePermute(
+ MVT::v32i8, DAG.getBitcast(MVT::v32i8, SrcVec),
+ DAG.getBitcast(MVT::v32i8, IndicesVec), DL, DAG, Subtarget));
+ }
break;
case MVT::v8f32:
case MVT::v8i32:
- if (!Subtarget.hasAVX2())
- return SDValue();
+ if (Subtarget.hasAVX2())
+ Opcode = X86ISD::VPERMV;
+ else if (Subtarget.hasAVX()) {
+ SrcVec = DAG.getBitcast(MVT::v8f32, SrcVec);
+ SDValue LoLo = DAG.getVectorShuffle(MVT::v8f32, DL, SrcVec, SrcVec,
+ {0, 1, 2, 3, 0, 1, 2, 3});
+ SDValue HiHi = DAG.getVectorShuffle(MVT::v8f32, DL, SrcVec, SrcVec,
+ {4, 5, 6, 7, 4, 5, 6, 7});
+ if (Subtarget.hasXOP())
+ return DAG.getBitcast(VT, DAG.getNode(X86ISD::VPERMIL2, DL, MVT::v8f32,
+ LoLo, HiHi, IndicesVec,
+ DAG.getConstant(0, DL, MVT::i8)));
+ // Permute Lo and Hi and then select based on index range.
+ // This works as VPERMILPS only uses index bits[0:1] to permute elements.
+ SDValue Res = DAG.getSelectCC(
+ DL, IndicesVec, DAG.getConstant(3, DL, MVT::v8i32),
+ DAG.getNode(X86ISD::VPERMILPV, DL, MVT::v8f32, HiHi, IndicesVec),
+ DAG.getNode(X86ISD::VPERMILPV, DL, MVT::v8f32, LoLo, IndicesVec),
+ ISD::CondCode::SETGT);
+ return DAG.getBitcast(VT, Res);
+ }
break;
case MVT::v4i64:
case MVT::v4f64:
- if (!Subtarget.hasVLX())
- return SDValue();
+ if (Subtarget.hasAVX512()) {
+ if (!Subtarget.hasVLX()) {
+ MVT WidenSrcVT = MVT::getVectorVT(VT.getScalarType(), 8);
+ SrcVec = widenSubVector(WidenSrcVT, SrcVec, false, Subtarget, DAG,
+ SDLoc(SrcVec));
+ IndicesVec = widenSubVector(MVT::v8i64, IndicesVec, false, Subtarget,
+ DAG, SDLoc(IndicesVec));
+ SDValue Res = createVariablePermute(WidenSrcVT, SrcVec, IndicesVec, DL,
+ DAG, Subtarget);
+ return extract256BitVector(Res, 0, DAG, DL);
+ }
+ Opcode = X86ISD::VPERMV;
+ } else if (Subtarget.hasAVX()) {
+ SrcVec = DAG.getBitcast(MVT::v4f64, SrcVec);
+ SDValue LoLo =
+ DAG.getVectorShuffle(MVT::v4f64, DL, SrcVec, SrcVec, {0, 1, 0, 1});
+ SDValue HiHi =
+ DAG.getVectorShuffle(MVT::v4f64, DL, SrcVec, SrcVec, {2, 3, 2, 3});
+ // VPERMIL2PD selects with bit#1 of the index vector, so scale IndicesVec.
+ IndicesVec = DAG.getNode(ISD::ADD, DL, IndicesVT, IndicesVec, IndicesVec);
+ if (Subtarget.hasXOP())
+ return DAG.getBitcast(VT, DAG.getNode(X86ISD::VPERMIL2, DL, MVT::v4f64,
+ LoLo, HiHi, IndicesVec,
+ DAG.getConstant(0, DL, MVT::i8)));
+ // Permute Lo and Hi and then select based on index range.
+ // This works as VPERMILPD only uses index bit[1] to permute elements.
+ SDValue Res = DAG.getSelectCC(
+ DL, IndicesVec, DAG.getConstant(2, DL, MVT::v4i64),
+ DAG.getNode(X86ISD::VPERMILPV, DL, MVT::v4f64, HiHi, IndicesVec),
+ DAG.getNode(X86ISD::VPERMILPV, DL, MVT::v4f64, LoLo, IndicesVec),
+ ISD::CondCode::SETGT);
+ return DAG.getBitcast(VT, Res);
+ }
break;
- case MVT::v16f32:
- case MVT::v8f64:
- case MVT::v16i32:
- case MVT::v8i64:
- if (!Subtarget.hasAVX512())
- return SDValue();
+ case MVT::v64i8:
+ if (Subtarget.hasVBMI())
+ Opcode = X86ISD::VPERMV;
break;
case MVT::v32i16:
- if (!Subtarget.hasBWI())
- return SDValue();
- break;
- case MVT::v8i16:
- case MVT::v16i16:
- if (!Subtarget.hasVLX() || !Subtarget.hasBWI())
- return SDValue();
- break;
- case MVT::v64i8:
- if (!Subtarget.hasVBMI())
- return SDValue();
+ if (Subtarget.hasBWI())
+ Opcode = X86ISD::VPERMV;
break;
- case MVT::v32i8:
- if (!Subtarget.hasVLX() || !Subtarget.hasVBMI())
- return SDValue();
+ case MVT::v16f32:
+ case MVT::v16i32:
+ case MVT::v8f64:
+ case MVT::v8i64:
+ if (Subtarget.hasAVX512())
+ Opcode = X86ISD::VPERMV;
break;
}
+ if (!Opcode)
+ return SDValue();
+
+ assert((VT.getSizeInBits() == ShuffleVT.getSizeInBits()) &&
+ (VT.getScalarSizeInBits() % ShuffleVT.getScalarSizeInBits()) == 0 &&
+ "Illegal variable permute shuffle type");
+
+ uint64_t Scale = VT.getScalarSizeInBits() / ShuffleVT.getScalarSizeInBits();
+ if (Scale > 1)
+ IndicesVec = ScaleIndices(IndicesVec, Scale);
+
+ EVT ShuffleIdxVT = EVT(ShuffleVT).changeVectorElementTypeToInteger();
+ IndicesVec = DAG.getBitcast(ShuffleIdxVT, IndicesVec);
+
+ SrcVec = DAG.getBitcast(ShuffleVT, SrcVec);
+ SDValue Res = Opcode == X86ISD::VPERMV
+ ? DAG.getNode(Opcode, DL, ShuffleVT, IndicesVec, SrcVec)
+ : DAG.getNode(Opcode, DL, ShuffleVT, SrcVec, IndicesVec);
+ return DAG.getBitcast(VT, Res);
+}
+
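As a worked example of the ScaleIndices constants used above: scaling v8i16 indices down to a v16i8 shuffle (Scale = 2, NumDstBits = 8) needs IndexScale = 0x0202 and IndexOffset = 0x0100 per 16-bit lane, so word index i expands to the byte pair {2*i, 2*i+1}. A minimal standalone check of that arithmetic:

#include <cassert>
#include <cstdint>

int main() {
  const uint64_t Scale = 2, NumDstBits = 8;
  uint64_t IndexScale = 0, IndexOffset = 0;
  for (uint64_t I = 0; I != Scale; ++I) {
    IndexScale |= Scale << (I * NumDstBits);
    IndexOffset |= I << (I * NumDstBits);
  }
  assert(IndexScale == 0x0202 && IndexOffset == 0x0100);
  // A word index of 5 expands to bytes {10, 11} once multiplied and offset.
  uint16_t Idx = 5;
  uint16_t Bytes = uint16_t(Idx * 0x0202 + 0x0100);
  assert((Bytes & 0xFF) == 10 && (Bytes >> 8) == 11);
  return 0;
}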
+// Tries to lower a BUILD_VECTOR composed of extract-extract chains that can be
+// reasoned to be a permutation of a vector by indices in a non-constant vector.
+// (build_vector (extract_elt V, (extract_elt I, 0)),
+// (extract_elt V, (extract_elt I, 1)),
+// ...
+// ->
+// (vpermv I, V)
+//
+// TODO: Handle undefs
+// TODO: Utilize pshufb and zero mask blending to support more efficient
+// construction of vectors with constant-0 elements.
+static SDValue
+LowerBUILD_VECTORAsVariablePermute(SDValue V, SelectionDAG &DAG,
+ const X86Subtarget &Subtarget) {
SDValue SrcVec, IndicesVec;
// Check for a match of the permute source vector and permute index elements.
// This is done by checking that the i-th build_vector operand is of the form:
@@ -7858,13 +8304,10 @@ LowerBUILD_VECTORAsVariablePermute(SDValue V, SelectionDAG &DAG,
if (!PermIdx || PermIdx->getZExtValue() != Idx)
return SDValue();
}
- MVT IndicesVT = VT;
- if (VT.isFloatingPoint())
- IndicesVT = MVT::getVectorVT(MVT::getIntegerVT(VT.getScalarSizeInBits()),
- VT.getVectorNumElements());
- IndicesVec = DAG.getZExtOrTrunc(IndicesVec, SDLoc(IndicesVec), IndicesVT);
- return DAG.getNode(VT == MVT::v16i8 ? X86ISD::PSHUFB : X86ISD::VPERMV,
- SDLoc(V), VT, IndicesVec, SrcVec);
+
+ SDLoc DL(V);
+ MVT VT = V.getSimpleValueType();
+ return createVariablePermute(VT, SrcVec, IndicesVec, DL, DAG, Subtarget);
}
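For the simplest case this catches, the whole build_vector of per-lane extracts collapses into a single variable permute. A minimal AVX2 sketch of the equivalent operation for v8f32, assuming the indices are already in range:

#include <immintrin.h>

// Result lane i = Src[Idx[i] & 7], i.e. vpermps (X86ISD::VPERMV on v8f32).
__m256 permute_by_variable_indices(__m256 Src, __m256i Idx) {
  return _mm256_permutevar8x32_ps(Src, Idx);
}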
SDValue
@@ -7872,7 +8315,7 @@ X86TargetLowering::LowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG) const {
SDLoc dl(Op);
MVT VT = Op.getSimpleValueType();
- MVT ExtVT = VT.getVectorElementType();
+ MVT EltVT = VT.getVectorElementType();
unsigned NumElems = Op.getNumOperands();
// Generate vectors for predicate vectors.
@@ -7883,8 +8326,6 @@ X86TargetLowering::LowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG) const {
return VectorConstant;
BuildVectorSDNode *BV = cast<BuildVectorSDNode>(Op.getNode());
- // TODO: Support FMSUBADD here if we ever get tests for the FMADDSUB
- // transform here.
if (SDValue AddSub = lowerToAddSubOrFMAddSub(BV, Subtarget, DAG))
return AddSub;
if (SDValue HorizontalOp = LowerToHorizontalOp(BV, Subtarget, DAG))
@@ -7894,7 +8335,7 @@ X86TargetLowering::LowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG) const {
if (SDValue BitOp = lowerBuildVectorToBitOp(BV, DAG))
return BitOp;
- unsigned EVTBits = ExtVT.getSizeInBits();
+ unsigned EVTBits = EltVT.getSizeInBits();
unsigned NumZero = 0;
unsigned NumNonZero = 0;
@@ -7930,13 +8371,13 @@ X86TargetLowering::LowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG) const {
// supported, we assume that we will fall back to a shuffle to get the scalar
// blended with the constants. Insertion into a zero vector is handled as a
// special-case somewhere below here.
- LLVMContext &Context = *DAG.getContext();
if (NumConstants == NumElems - 1 && NumNonZero != 1 &&
(isOperationLegalOrCustom(ISD::INSERT_VECTOR_ELT, VT) ||
isOperationLegalOrCustom(ISD::VECTOR_SHUFFLE, VT))) {
// Create an all-constant vector. The variable element in the old
// build vector is replaced by undef in the constant vector. Save the
// variable scalar element and its index for use in the insertelement.
+ LLVMContext &Context = *DAG.getContext();
Type *EltType = Op.getValueType().getScalarType().getTypeForEVT(Context);
SmallVector<Constant *, 16> ConstVecOps(NumElems, UndefValue::get(EltType));
SDValue VarElt;
@@ -7975,27 +8416,6 @@ X86TargetLowering::LowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG) const {
unsigned Idx = countTrailingZeros(NonZeros);
SDValue Item = Op.getOperand(Idx);
- // If this is an insertion of an i64 value on x86-32, and if the top bits of
- // the value are obviously zero, truncate the value to i32 and do the
- // insertion that way. Only do this if the value is non-constant or if the
- // value is a constant being inserted into element 0. It is cheaper to do
- // a constant pool load than it is to do a movd + shuffle.
- if (ExtVT == MVT::i64 && !Subtarget.is64Bit() &&
- (!IsAllConstants || Idx == 0)) {
- if (DAG.MaskedValueIsZero(Item, APInt::getHighBitsSet(64, 32))) {
- // Handle SSE only.
- assert(VT == MVT::v2i64 && "Expected an SSE value type!");
- MVT VecVT = MVT::v4i32;
-
- // Truncate the value (which may itself be a constant) to i32, and
- // convert it to a vector with movd (S2V+shuffle to zero extend).
- Item = DAG.getNode(ISD::TRUNCATE, dl, MVT::i32, Item);
- Item = DAG.getNode(ISD::SCALAR_TO_VECTOR, dl, VecVT, Item);
- return DAG.getBitcast(VT, getShuffleVectorZeroOrUndef(
- Item, Idx * 2, true, Subtarget, DAG));
- }
- }
-
// If we have a constant or non-constant insertion into the low element of
// a vector, we can do this with SCALAR_TO_VECTOR + shuffle of zero into
// the rest of the elements. This will be matched as movd/movq/movss/movsd
@@ -8004,8 +8424,8 @@ X86TargetLowering::LowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG) const {
if (NumZero == 0)
return DAG.getNode(ISD::SCALAR_TO_VECTOR, dl, VT, Item);
- if (ExtVT == MVT::i32 || ExtVT == MVT::f32 || ExtVT == MVT::f64 ||
- (ExtVT == MVT::i64 && Subtarget.is64Bit())) {
+ if (EltVT == MVT::i32 || EltVT == MVT::f32 || EltVT == MVT::f64 ||
+ (EltVT == MVT::i64 && Subtarget.is64Bit())) {
assert((VT.is128BitVector() || VT.is256BitVector() ||
VT.is512BitVector()) &&
"Expected an SSE value type!");
@@ -8016,7 +8436,7 @@ X86TargetLowering::LowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG) const {
// We can't directly insert an i8 or i16 into a vector, so zero extend
// it to i32 first.
- if (ExtVT == MVT::i16 || ExtVT == MVT::i8) {
+ if (EltVT == MVT::i16 || EltVT == MVT::i8) {
Item = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i32, Item);
if (VT.getSizeInBits() >= 256) {
MVT ShufVT = MVT::getVectorVT(MVT::i32, VT.getSizeInBits()/32);
@@ -8088,17 +8508,43 @@ X86TargetLowering::LowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG) const {
return V;
// See if we can use a vector load to get all of the elements.
- if (VT.is128BitVector() || VT.is256BitVector() || VT.is512BitVector()) {
+ {
SmallVector<SDValue, 64> Ops(Op->op_begin(), Op->op_begin() + NumElems);
if (SDValue LD =
EltsFromConsecutiveLoads(VT, Ops, dl, DAG, Subtarget, false))
return LD;
}
+ // If this is a splat of pairs of 32-bit elements, we can use a narrower
+ // build_vector and broadcast it.
+ // TODO: We could probably generalize this more.
+ if (Subtarget.hasAVX2() && EVTBits == 32 && Values.size() == 2) {
+ SDValue Ops[4] = { Op.getOperand(0), Op.getOperand(1),
+ DAG.getUNDEF(EltVT), DAG.getUNDEF(EltVT) };
+ auto CanSplat = [](SDValue Op, unsigned NumElems, ArrayRef<SDValue> Ops) {
+ // Make sure all the even/odd operands match.
+ for (unsigned i = 2; i != NumElems; ++i)
+ if (Ops[i % 2] != Op.getOperand(i))
+ return false;
+ return true;
+ };
+ if (CanSplat(Op, NumElems, Ops)) {
+ MVT WideEltVT = VT.isFloatingPoint() ? MVT::f64 : MVT::i64;
+ MVT NarrowVT = MVT::getVectorVT(EltVT, 4);
+ // Create a new build vector and cast to v2i64/v2f64.
+ SDValue NewBV = DAG.getBitcast(MVT::getVectorVT(WideEltVT, 2),
+ DAG.getBuildVector(NarrowVT, dl, Ops));
+ // Broadcast from v2i64/v2f64 and cast to final VT.
+ MVT BcastVT = MVT::getVectorVT(WideEltVT, NumElems/2);
+ return DAG.getBitcast(VT, DAG.getNode(X86ISD::VBROADCAST, dl, BcastVT,
+ NewBV));
+ }
+ }
+
// For AVX-length vectors, build the individual 128-bit pieces and use
// shuffles to put them in place.
- if (VT.is256BitVector() || VT.is512BitVector()) {
- EVT HVT = EVT::getVectorVT(Context, ExtVT, NumElems/2);
+ if (VT.getSizeInBits() > 128) {
+ MVT HVT = MVT::getVectorVT(EltVT, NumElems/2);
// Build both the lower and upper subvector.
SDValue Lower =
@@ -8107,9 +8553,8 @@ X86TargetLowering::LowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG) const {
HVT, dl, Op->ops().slice(NumElems / 2, NumElems /2));
// Recreate the wider vector with the lower and upper part.
- if (VT.is256BitVector())
- return concat128BitVectors(Lower, Upper, VT, NumElems, DAG, dl);
- return concat256BitVectors(Lower, Upper, VT, NumElems, DAG, dl);
+ return concatSubVectors(Lower, Upper, VT, NumElems, DAG, dl,
+ VT.getSizeInBits() / 2);
}
// Let legalizer expand 2-wide build_vectors.
@@ -8234,30 +8679,60 @@ X86TargetLowering::LowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG) const {
// 256-bit AVX can use the vinsertf128 instruction
// to create 256-bit vectors from two other 128-bit ones.
-static SDValue LowerAVXCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) {
+// TODO: Detect subvector broadcast here instead of DAG combine?
+static SDValue LowerAVXCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG,
+ const X86Subtarget &Subtarget) {
SDLoc dl(Op);
MVT ResVT = Op.getSimpleValueType();
assert((ResVT.is256BitVector() ||
ResVT.is512BitVector()) && "Value type must be 256-/512-bit wide");
- SDValue V1 = Op.getOperand(0);
- SDValue V2 = Op.getOperand(1);
- unsigned NumElems = ResVT.getVectorNumElements();
- if (ResVT.is256BitVector())
- return concat128BitVectors(V1, V2, ResVT, NumElems, DAG, dl);
+ unsigned NumOperands = Op.getNumOperands();
+ unsigned NumZero = 0;
+ unsigned NumNonZero = 0;
+ unsigned NonZeros = 0;
+ for (unsigned i = 0; i != NumOperands; ++i) {
+ SDValue SubVec = Op.getOperand(i);
+ if (SubVec.isUndef())
+ continue;
+ if (ISD::isBuildVectorAllZeros(SubVec.getNode()))
+ ++NumZero;
+ else {
+ assert(i < sizeof(NonZeros) * CHAR_BIT); // Ensure the shift is in range.
+ NonZeros |= 1 << i;
+ ++NumNonZero;
+ }
+ }
- if (Op.getNumOperands() == 4) {
+ // If we have more than 2 non-zeros, build each half separately.
+ if (NumNonZero > 2) {
MVT HalfVT = MVT::getVectorVT(ResVT.getVectorElementType(),
ResVT.getVectorNumElements()/2);
- SDValue V3 = Op.getOperand(2);
- SDValue V4 = Op.getOperand(3);
- return concat256BitVectors(
- concat128BitVectors(V1, V2, HalfVT, NumElems / 2, DAG, dl),
- concat128BitVectors(V3, V4, HalfVT, NumElems / 2, DAG, dl), ResVT,
- NumElems, DAG, dl);
+ ArrayRef<SDUse> Ops = Op->ops();
+ SDValue Lo = DAG.getNode(ISD::CONCAT_VECTORS, dl, HalfVT,
+ Ops.slice(0, NumOperands/2));
+ SDValue Hi = DAG.getNode(ISD::CONCAT_VECTORS, dl, HalfVT,
+ Ops.slice(NumOperands/2));
+ return DAG.getNode(ISD::CONCAT_VECTORS, dl, ResVT, Lo, Hi);
+ }
+
+ // Otherwise, build it up through insert_subvectors.
+ SDValue Vec = NumZero ? getZeroVector(ResVT, Subtarget, DAG, dl)
+ : DAG.getUNDEF(ResVT);
+
+ MVT SubVT = Op.getOperand(0).getSimpleValueType();
+ unsigned NumSubElems = SubVT.getVectorNumElements();
+ for (unsigned i = 0; i != NumOperands; ++i) {
+ if ((NonZeros & (1 << i)) == 0)
+ continue;
+
+ Vec = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, ResVT, Vec,
+ Op.getOperand(i),
+ DAG.getIntPtrConstant(i * NumSubElems, dl));
}
- return concat256BitVectors(V1, V2, ResVT, NumElems, DAG, dl);
+
+ return Vec;
}
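The insert_subvector path built above amounts to starting from a zero (or undef) wide register and inserting each non-zero half. A minimal AVX sketch of the two-operand 256-bit case:

#include <immintrin.h>

// Concatenate two 128-bit halves into a 256-bit vector via vinsertf128.
__m256 concat_two_halves(__m128 Lo, __m128 Hi) {
  __m256 V = _mm256_castps128_ps256(Lo); // low 128 bits = Lo, high bits undef
  return _mm256_insertf128_ps(V, Hi, 1); // high 128 bits = Hi
}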
// Return true if all the operands of the given CONCAT_VECTORS node are zeros
@@ -8314,6 +8789,7 @@ static SDValue isTypePromotionOfi1ZeroUpBits(SDValue Op) {
return SDValue();
}
+// TODO: Merge this with LowerAVXCONCAT_VECTORS?
static SDValue LowerCONCAT_VECTORSvXi1(SDValue Op,
const X86Subtarget &Subtarget,
SelectionDAG & DAG) {
@@ -8328,12 +8804,8 @@ static SDValue LowerCONCAT_VECTORSvXi1(SDValue Op,
// of a node with instruction that zeroes all upper (irrelevant) bits of the
// output register, mark it as legal and catch the pattern in instruction
// selection to avoid emitting extra instructions (for zeroing upper bits).
- if (SDValue Promoted = isTypePromotionOfi1ZeroUpBits(Op)) {
- SDValue ZeroC = DAG.getIntPtrConstant(0, dl);
- SDValue AllZeros = getZeroVector(ResVT, Subtarget, DAG, dl);
- return DAG.getNode(ISD::INSERT_SUBVECTOR, dl, ResVT, AllZeros, Promoted,
- ZeroC);
- }
+ if (SDValue Promoted = isTypePromotionOfi1ZeroUpBits(Op))
+ return widenSubVector(ResVT, Promoted, true, Subtarget, DAG, dl);
unsigned NumZero = 0;
unsigned NumNonZero = 0;
@@ -8404,7 +8876,7 @@ static SDValue LowerCONCAT_VECTORS(SDValue Op,
// from two other 128-bit ones.
// 512-bit vector may contain 2 256-bit vectors or 4 128-bit vectors
- return LowerAVXCONCAT_VECTORS(Op, DAG);
+ return LowerAVXCONCAT_VECTORS(Op, DAG, Subtarget);
}
//===----------------------------------------------------------------------===//
@@ -8418,7 +8890,7 @@ static SDValue LowerCONCAT_VECTORS(SDValue Op,
// patterns.
//===----------------------------------------------------------------------===//
-/// \brief Tiny helper function to identify a no-op mask.
+/// Tiny helper function to identify a no-op mask.
///
/// This is a somewhat boring predicate function. It checks whether the mask
/// array input, which is assumed to be a single-input shuffle mask of the kind
@@ -8434,7 +8906,7 @@ static bool isNoopShuffleMask(ArrayRef<int> Mask) {
return true;
}
-/// \brief Test whether there are elements crossing 128-bit lanes in this
+/// Test whether there are elements crossing 128-bit lanes in this
/// shuffle mask.
///
/// X86 divides up its shuffles into in-lane and cross-lane shuffle operations
@@ -8448,7 +8920,7 @@ static bool is128BitLaneCrossingShuffleMask(MVT VT, ArrayRef<int> Mask) {
return false;
}
-/// \brief Test whether a shuffle mask is equivalent within each sub-lane.
+/// Test whether a shuffle mask is equivalent within each sub-lane.
///
/// This checks a shuffle mask to see if it is performing the same
/// lane-relative shuffle in each sub-lane. This trivially implies
@@ -8494,6 +8966,12 @@ is128BitLaneRepeatedShuffleMask(MVT VT, ArrayRef<int> Mask,
return isRepeatedShuffleMask(128, VT, Mask, RepeatedMask);
}
+static bool
+is128BitLaneRepeatedShuffleMask(MVT VT, ArrayRef<int> Mask) {
+ SmallVector<int, 32> RepeatedMask;
+ return isRepeatedShuffleMask(128, VT, Mask, RepeatedMask);
+}
+
/// Test whether a shuffle mask is equivalent within each 256-bit lane.
static bool
is256BitLaneRepeatedShuffleMask(MVT VT, ArrayRef<int> Mask,
@@ -8537,7 +9015,7 @@ static bool isRepeatedTargetShuffleMask(unsigned LaneSizeInBits, MVT VT,
return true;
}
-/// \brief Checks whether a shuffle mask is equivalent to an explicit list of
+/// Checks whether a shuffle mask is equivalent to an explicit list of
/// arguments.
///
/// This is a fast way to test a shuffle mask against a fixed pattern:
@@ -8634,7 +9112,7 @@ static bool isUnpackWdShuffleMask(ArrayRef<int> Mask, MVT VT) {
return IsUnpackwdMask;
}
-/// \brief Get a 4-lane 8-bit shuffle immediate for a mask.
+/// Get a 4-lane 8-bit shuffle immediate for a mask.
///
/// This helper function produces an 8-bit shuffle immediate corresponding to
/// the ubiquitous shuffle encoding scheme used in x86 instructions for
@@ -8662,7 +9140,7 @@ static SDValue getV4X86ShuffleImm8ForMask(ArrayRef<int> Mask, const SDLoc &DL,
return DAG.getConstant(getV4X86ShuffleImm(Mask), DL, MVT::i8);
}
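A worked example of the 2-bits-per-lane immediate produced here (destination lane 0 in the low bits): the identity mask <0,1,2,3> encodes as 0xE4 and the full reversal <3,2,1,0> as 0x1B. A standalone sketch of the encoding:

#include <cassert>
#include <cstdint>

uint8_t getV4ShuffleImm(int M0, int M1, int M2, int M3) {
  return uint8_t(M0 | (M1 << 2) | (M2 << 4) | (M3 << 6));
}

int main() {
  assert(getV4ShuffleImm(0, 1, 2, 3) == 0xE4); // identity
  assert(getV4ShuffleImm(3, 2, 1, 0) == 0x1B); // reverse lanes
  return 0;
}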
-/// \brief Compute whether each element of a shuffle is zeroable.
+/// Compute whether each element of a shuffle is zeroable.
///
/// A "zeroable" vector shuffle element is one which can be lowered to zero.
/// Either it is an undef element in the shuffle mask, the element of the input
@@ -8859,8 +9337,8 @@ static SDValue lowerVectorShuffleToEXPAND(const SDLoc &DL, MVT VT,
static bool matchVectorShuffleWithUNPCK(MVT VT, SDValue &V1, SDValue &V2,
unsigned &UnpackOpcode, bool IsUnary,
- ArrayRef<int> TargetMask, SDLoc &DL,
- SelectionDAG &DAG,
+ ArrayRef<int> TargetMask,
+ const SDLoc &DL, SelectionDAG &DAG,
const X86Subtarget &Subtarget) {
int NumElts = VT.getVectorNumElements();
@@ -8969,6 +9447,99 @@ static SDValue lowerVectorShuffleWithUNPCK(const SDLoc &DL, MVT VT,
return SDValue();
}
+static bool matchVectorShuffleAsVPMOV(ArrayRef<int> Mask, bool SwappedOps,
+ int Delta) {
+ int Size = (int)Mask.size();
+ int Split = Size / Delta;
+ int TruncatedVectorStart = SwappedOps ? Size : 0;
+
+ // Match a mask starting with, e.g., <8, 10, 12, 14, ...> or <0, 2, 4, 6, ...>
+ if (!isSequentialOrUndefInRange(Mask, 0, Split, TruncatedVectorStart, Delta))
+ return false;
+
+ // The rest of the mask should not refer to the truncated vector's elements.
+ if (isAnyInRange(Mask.slice(Split, Size - Split), TruncatedVectorStart,
+ TruncatedVectorStart + Size))
+ return false;
+
+ return true;
+}
+
+// Try to lower trunc+vector_shuffle to a vpmovdb or a vpmovdw instruction.
+//
+// An example is the following:
+//
+// t0: ch = EntryToken
+// t2: v4i64,ch = CopyFromReg t0, Register:v4i64 %0
+// t25: v4i32 = truncate t2
+// t41: v8i16 = bitcast t25
+// t21: v8i16 = BUILD_VECTOR undef:i16, undef:i16, undef:i16, undef:i16,
+// Constant:i16<0>, Constant:i16<0>, Constant:i16<0>, Constant:i16<0>
+// t51: v8i16 = vector_shuffle<0,2,4,6,12,13,14,15> t41, t21
+// t18: v2i64 = bitcast t51
+//
+// Without avx512vl, this is lowered to:
+//
+// vpmovqd %zmm0, %ymm0
+// vpshufb {{.*#+}} xmm0 =
+// xmm0[0,1,4,5,8,9,12,13],zero,zero,zero,zero,zero,zero,zero,zero
+//
+// But when avx512vl is available, one can just use a single vpmovdw
+// instruction.
+static SDValue lowerVectorShuffleWithVPMOV(const SDLoc &DL, ArrayRef<int> Mask,
+ MVT VT, SDValue V1, SDValue V2,
+ SelectionDAG &DAG,
+ const X86Subtarget &Subtarget) {
+ if (VT != MVT::v16i8 && VT != MVT::v8i16)
+ return SDValue();
+
+ if (Mask.size() != VT.getVectorNumElements())
+ return SDValue();
+
+ bool SwappedOps = false;
+
+ if (!ISD::isBuildVectorAllZeros(V2.getNode())) {
+ if (!ISD::isBuildVectorAllZeros(V1.getNode()))
+ return SDValue();
+
+ std::swap(V1, V2);
+ SwappedOps = true;
+ }
+
+ // Look for:
+ //
+ // bitcast (truncate <8 x i32> %vec to <8 x i16>) to <16 x i8>
+ // bitcast (truncate <4 x i64> %vec to <4 x i32>) to <8 x i16>
+ //
+ // and similar ones.
+ if (V1.getOpcode() != ISD::BITCAST)
+ return SDValue();
+ if (V1.getOperand(0).getOpcode() != ISD::TRUNCATE)
+ return SDValue();
+
+ SDValue Src = V1.getOperand(0).getOperand(0);
+ MVT SrcVT = Src.getSimpleValueType();
+
+ // The vptrunc** instructions truncating 128 bit and 256 bit vectors
+ // are only available with avx512vl.
+ if (!SrcVT.is512BitVector() && !Subtarget.hasVLX())
+ return SDValue();
+
+ // Down Convert Word to Byte is only available with avx512bw. The case with
+ // 256-bit output doesn't contain a shuffle and is therefore not handled here.
+ if (SrcVT.getVectorElementType() == MVT::i16 && VT == MVT::v16i8 &&
+ !Subtarget.hasBWI())
+ return SDValue();
+
+ // The first half/quarter of the mask should refer to every second/fourth
+ // element of the truncated and bitcasted vector.
+ if (!matchVectorShuffleAsVPMOV(Mask, SwappedOps, 2) &&
+ !matchVectorShuffleAsVPMOV(Mask, SwappedOps, 4))
+ return SDValue();
+
+ return DAG.getNode(X86ISD::VTRUNC, DL, VT, Src);
+}
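When AVX-512VL is available, patterns like the comment's example collapse into one truncating move rather than a wide truncate plus shuffle. A minimal sketch of a related single-instruction truncation (vpmovqd, v4i64 -> v4i32) with AVX-512VL intrinsics:

#include <immintrin.h>

// Truncate each 64-bit lane to 32 bits in a single instruction (AVX-512VL).
__m128i truncate_v4i64_to_v4i32(__m256i V) {
  return _mm256_cvtepi64_epi32(V);
}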
+
// X86 has dedicated pack instructions that can handle specific truncation
// operations: PACKSS and PACKUS.
static bool matchVectorShuffleWithPACK(MVT VT, MVT &SrcVT, SDValue &V1,
@@ -8984,15 +9555,6 @@ static bool matchVectorShuffleWithPACK(MVT VT, MVT &SrcVT, SDValue &V1,
auto MatchPACK = [&](SDValue N1, SDValue N2) {
SDValue VV1 = DAG.getBitcast(PackVT, N1);
SDValue VV2 = DAG.getBitcast(PackVT, N2);
- if ((N1.isUndef() || DAG.ComputeNumSignBits(VV1) > BitSize) &&
- (N2.isUndef() || DAG.ComputeNumSignBits(VV2) > BitSize)) {
- V1 = VV1;
- V2 = VV2;
- SrcVT = PackVT;
- PackOpcode = X86ISD::PACKSS;
- return true;
- }
-
if (Subtarget.hasSSE41() || PackSVT == MVT::i16) {
APInt ZeroMask = APInt::getHighBitsSet(BitSize * 2, BitSize);
if ((N1.isUndef() || DAG.MaskedValueIsZero(VV1, ZeroMask)) &&
@@ -9004,7 +9566,14 @@ static bool matchVectorShuffleWithPACK(MVT VT, MVT &SrcVT, SDValue &V1,
return true;
}
}
-
+ if ((N1.isUndef() || DAG.ComputeNumSignBits(VV1) > BitSize) &&
+ (N2.isUndef() || DAG.ComputeNumSignBits(VV2) > BitSize)) {
+ V1 = VV1;
+ V2 = VV2;
+ SrcVT = PackVT;
+ PackOpcode = X86ISD::PACKSS;
+ return true;
+ }
return false;
};
@@ -9039,7 +9608,7 @@ static SDValue lowerVectorShuffleWithPACK(const SDLoc &DL, MVT VT,
return SDValue();
}
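The two pack forms the matcher tries differ only in the precondition on the upper half of each source element: PACKUS is exact when those bits are known zero, PACKSS when they are all copies of the sign bit. A minimal SSE2 sketch of the v8i16 -> v16i8 case:

#include <immintrin.h>

// Exact when the high byte of every 16-bit lane is zero (zero-extended input).
__m128i pack_unsigned(__m128i A, __m128i B) { return _mm_packus_epi16(A, B); }

// Exact when every 16-bit lane is a sign-extended 8-bit value.
__m128i pack_signed(__m128i A, __m128i B) { return _mm_packs_epi16(A, B); }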
-/// \brief Try to emit a bitmask instruction for a shuffle.
+/// Try to emit a bitmask instruction for a shuffle.
///
/// This handles cases where we can model a blend exactly as a bitmask due to
/// one of the inputs being zeroable.
@@ -9072,7 +9641,7 @@ static SDValue lowerVectorShuffleAsBitMask(const SDLoc &DL, MVT VT, SDValue V1,
return DAG.getNode(ISD::AND, DL, VT, V, VMask);
}
-/// \brief Try to emit a blend instruction for a shuffle using bit math.
+/// Try to emit a blend instruction for a shuffle using bit math.
///
/// This is used as a fallback approach when first class blend instructions are
/// unavailable. Currently it is only suitable for integer vectors, but could
@@ -9159,7 +9728,7 @@ static uint64_t scaleVectorShuffleBlendMask(uint64_t BlendMask, int Size,
return ScaledMask;
}
-/// \brief Try to emit a blend instruction for a shuffle.
+/// Try to emit a blend instruction for a shuffle.
///
/// This doesn't do any checks for the availability of instructions for blending
/// these values. It relies on the availability of the X86ISD::BLENDI pattern to
@@ -9305,7 +9874,7 @@ static SDValue lowerVectorShuffleAsBlend(const SDLoc &DL, MVT VT, SDValue V1,
}
}
-/// \brief Try to lower as a blend of elements from two inputs followed by
+/// Try to lower as a blend of elements from two inputs followed by
/// a single-input permutation.
///
/// This matches the pattern where we can blend elements from two inputs and
@@ -9337,7 +9906,7 @@ static SDValue lowerVectorShuffleAsBlendAndPermute(const SDLoc &DL, MVT VT,
return DAG.getVectorShuffle(VT, DL, V, DAG.getUNDEF(VT), PermuteMask);
}
-/// \brief Generic routine to decompose a shuffle and blend into independent
+/// Generic routine to decompose a shuffle and blend into independent
/// blends and permutes.
///
/// This matches the extremely common pattern for handling combined
@@ -9378,7 +9947,7 @@ static SDValue lowerVectorShuffleAsDecomposedShuffleBlend(const SDLoc &DL,
return DAG.getVectorShuffle(VT, DL, V1, V2, BlendMask);
}
-/// \brief Try to lower a vector shuffle as a rotation.
+/// Try to lower a vector shuffle as a rotation.
///
/// This is used for support PALIGNR for SSSE3 or VALIGND/Q for AVX512.
static int matchVectorShuffleAsRotate(SDValue &V1, SDValue &V2,
@@ -9450,7 +10019,7 @@ static int matchVectorShuffleAsRotate(SDValue &V1, SDValue &V2,
return Rotation;
}
-/// \brief Try to lower a vector shuffle as a byte rotation.
+/// Try to lower a vector shuffle as a byte rotation.
///
/// SSSE3 has a generic PALIGNR instruction in x86 that will do an arbitrary
/// byte-rotation of the concatenation of two vectors; pre-SSSE3 can use
@@ -9534,7 +10103,7 @@ static SDValue lowerVectorShuffleAsByteRotate(const SDLoc &DL, MVT VT,
DAG.getNode(ISD::OR, DL, MVT::v16i8, LoShift, HiShift));
}
-/// \brief Try to lower a vector shuffle as a dword/qword rotation.
+/// Try to lower a vector shuffle as a dword/qword rotation.
///
/// AVX512 has a VALIGND/VALIGNQ instructions that will do an arbitrary
/// rotation of the concatenation of two vectors; This routine will
@@ -9565,7 +10134,7 @@ static SDValue lowerVectorShuffleAsRotate(const SDLoc &DL, MVT VT,
DAG.getConstant(Rotation, DL, MVT::i8));
}
-/// \brief Try to lower a vector shuffle as a bit shift (shifts in zeros).
+/// Try to lower a vector shuffle as a bit shift (shifts in zeros).
///
/// Attempts to match a shuffle mask against the PSLL(W/D/Q/DQ) and
/// PSRL(W/D/Q/DQ) SSE2 and AVX2 logical bit-shift instructions. The function
@@ -9809,7 +10378,7 @@ static bool matchVectorShuffleAsINSERTQ(MVT VT, SDValue &V1, SDValue &V2,
return false;
}
-/// \brief Try to lower a vector shuffle using SSE4a EXTRQ/INSERTQ.
+/// Try to lower a vector shuffle using SSE4a EXTRQ/INSERTQ.
static SDValue lowerVectorShuffleWithSSE4A(const SDLoc &DL, MVT VT, SDValue V1,
SDValue V2, ArrayRef<int> Mask,
const APInt &Zeroable,
@@ -9829,7 +10398,7 @@ static SDValue lowerVectorShuffleWithSSE4A(const SDLoc &DL, MVT VT, SDValue V1,
return SDValue();
}
-/// \brief Lower a vector shuffle as a zero or any extension.
+/// Lower a vector shuffle as a zero or any extension.
///
/// Given a specific number of elements, element bit width, and extension
/// stride, produce either a zero or any extension based on the available
@@ -9984,7 +10553,7 @@ static SDValue lowerVectorShuffleAsSpecificZeroOrAnyExtend(
return DAG.getBitcast(VT, InputV);
}
-/// \brief Try to lower a vector shuffle as a zero extension on any microarch.
+/// Try to lower a vector shuffle as a zero extension on any microarch.
///
/// This routine will try to do everything in its power to cleverly lower
/// a shuffle which happens to match the pattern of a zero extend. It doesn't
@@ -10112,7 +10681,7 @@ static SDValue lowerVectorShuffleAsZeroOrAnyExtend(
return SDValue();
}
-/// \brief Try to get a scalar value for a specific element of a vector.
+/// Try to get a scalar value for a specific element of a vector.
///
/// Looks through BUILD_VECTOR and SCALAR_TO_VECTOR nodes to find a scalar.
static SDValue getScalarValueForVectorElement(SDValue V, int Idx,
@@ -10139,7 +10708,7 @@ static SDValue getScalarValueForVectorElement(SDValue V, int Idx,
return SDValue();
}
-/// \brief Helper to test for a load that can be folded with x86 shuffles.
+/// Helper to test for a load that can be folded with x86 shuffles.
///
/// This is particularly important because the set of instructions varies
/// significantly based on whether the operand is a load or not.
@@ -10148,7 +10717,7 @@ static bool isShuffleFoldableLoad(SDValue V) {
return ISD::isNON_EXTLoad(V.getNode());
}
-/// \brief Try to lower insertion of a single element into a zero vector.
+/// Try to lower insertion of a single element into a zero vector.
///
/// This is a common pattern that we have especially efficient patterns to lower
/// across all subtarget feature sets.
@@ -10239,9 +10808,7 @@ static SDValue lowerVectorShuffleAsElementInsertion(
V2 = DAG.getBitcast(MVT::v16i8, V2);
V2 = DAG.getNode(
X86ISD::VSHLDQ, DL, MVT::v16i8, V2,
- DAG.getConstant(V2Index * EltVT.getSizeInBits() / 8, DL,
- DAG.getTargetLoweringInfo().getScalarShiftAmountTy(
- DAG.getDataLayout(), VT)));
+ DAG.getConstant(V2Index * EltVT.getSizeInBits() / 8, DL, MVT::i8));
V2 = DAG.getBitcast(VT, V2);
}
}
@@ -10295,13 +10862,13 @@ static SDValue lowerVectorShuffleAsTruncBroadcast(const SDLoc &DL, MVT VT,
// vpbroadcast+vmovd+shr to vpshufb(m)+vmovd.
if (const int OffsetIdx = BroadcastIdx % Scale)
Scalar = DAG.getNode(ISD::SRL, DL, Scalar.getValueType(), Scalar,
- DAG.getConstant(OffsetIdx * EltSize, DL, Scalar.getValueType()));
+ DAG.getConstant(OffsetIdx * EltSize, DL, MVT::i8));
return DAG.getNode(X86ISD::VBROADCAST, DL, VT,
DAG.getNode(ISD::TRUNCATE, DL, EltVT, Scalar));
}
-/// \brief Try to lower broadcast of a single element.
+/// Try to lower broadcast of a single element.
///
/// For convenience, this code also bundles all of the subtarget feature set
/// filtering. While a little annoying to re-dispatch on type here, there isn't
@@ -10626,7 +11193,7 @@ static SDValue lowerVectorShuffleAsInsertPS(const SDLoc &DL, SDValue V1,
DAG.getConstant(InsertPSMask, DL, MVT::i8));
}
-/// \brief Try to lower a shuffle as a permute of the inputs followed by an
+/// Try to lower a shuffle as a permute of the inputs followed by an
/// UNPCK instruction.
///
/// This specifically targets cases where we end up with alternating between
@@ -10738,7 +11305,7 @@ static SDValue lowerVectorShuffleAsPermuteAndUnpack(const SDLoc &DL, MVT VT,
return SDValue();
}
-/// \brief Handle lowering of 2-lane 64-bit floating point shuffles.
+/// Handle lowering of 2-lane 64-bit floating point shuffles.
///
/// This is the basis function for the 2-lane 64-bit shuffles as we have full
/// support for floating point shuffles but not integer shuffles. These
@@ -10777,22 +11344,23 @@ static SDValue lowerV2F64VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask,
Mask[1] == SM_SentinelUndef ? DAG.getUNDEF(MVT::v2f64) : V1,
DAG.getConstant(SHUFPDMask, DL, MVT::i8));
}
- assert(Mask[0] >= 0 && Mask[0] < 2 && "Non-canonicalized blend!");
- assert(Mask[1] >= 2 && "Non-canonicalized blend!");
+ assert(Mask[0] >= 0 && "No undef lanes in multi-input v2 shuffles!");
+ assert(Mask[1] >= 0 && "No undef lanes in multi-input v2 shuffles!");
+ assert(Mask[0] < 2 && "We sort V1 to be the first input.");
+ assert(Mask[1] >= 2 && "We sort V2 to be the second input.");
- // If we have a single input, insert that into V1 if we can do so cheaply.
- if ((Mask[0] >= 2) + (Mask[1] >= 2) == 1) {
- if (SDValue Insertion = lowerVectorShuffleAsElementInsertion(
- DL, MVT::v2f64, V1, V2, Mask, Zeroable, Subtarget, DAG))
- return Insertion;
- // Try inverting the insertion since for v2 masks it is easy to do and we
- // can't reliably sort the mask one way or the other.
- int InverseMask[2] = {Mask[0] < 0 ? -1 : (Mask[0] ^ 2),
- Mask[1] < 0 ? -1 : (Mask[1] ^ 2)};
- if (SDValue Insertion = lowerVectorShuffleAsElementInsertion(
- DL, MVT::v2f64, V2, V1, InverseMask, Zeroable, Subtarget, DAG))
- return Insertion;
- }
+ // When loading a scalar and then shuffling it into a vector we can often do
+ // the insertion cheaply.
+ if (SDValue Insertion = lowerVectorShuffleAsElementInsertion(
+ DL, MVT::v2f64, V1, V2, Mask, Zeroable, Subtarget, DAG))
+ return Insertion;
+ // Try inverting the insertion since for v2 masks it is easy to do and we
+ // can't reliably sort the mask one way or the other.
+ int InverseMask[2] = {Mask[0] < 0 ? -1 : (Mask[0] ^ 2),
+ Mask[1] < 0 ? -1 : (Mask[1] ^ 2)};
+ if (SDValue Insertion = lowerVectorShuffleAsElementInsertion(
+ DL, MVT::v2f64, V2, V1, InverseMask, Zeroable, Subtarget, DAG))
+ return Insertion;
// Try to use one of the special instruction patterns to handle two common
// blend patterns if a zero-blend above didn't work.
@@ -10802,8 +11370,7 @@ static SDValue lowerV2F64VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask,
// We can either use a special instruction to load over the low double or
// to move just the low double.
return DAG.getNode(
- isShuffleFoldableLoad(V1S) ? X86ISD::MOVLPD : X86ISD::MOVSD,
- DL, MVT::v2f64, V2,
+ X86ISD::MOVSD, DL, MVT::v2f64, V2,
DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, MVT::v2f64, V1S));
if (Subtarget.hasSSE41())
@@ -10821,7 +11388,7 @@ static SDValue lowerV2F64VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask,
DAG.getConstant(SHUFPDMask, DL, MVT::i8));
}
-/// \brief Handle lowering of 2-lane 64-bit integer shuffles.
+/// Handle lowering of 2-lane 64-bit integer shuffles.
///
/// Tries to lower a 2-lane 64-bit shuffle using shuffle operations provided by
/// the integer unit to minimize domain crossing penalties. However, for blends
@@ -10918,7 +11485,7 @@ static SDValue lowerV2I64VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask,
DAG.getVectorShuffle(MVT::v2f64, DL, V1, V2, Mask));
}
-/// \brief Test whether this can be lowered with a single SHUFPS instruction.
+/// Test whether this can be lowered with a single SHUFPS instruction.
///
/// This is used to disable more specialized lowerings when the shufps lowering
/// will happen to be efficient.
@@ -10940,7 +11507,7 @@ static bool isSingleSHUFPSMask(ArrayRef<int> Mask) {
return true;
}
-/// \brief Lower a vector shuffle using the SHUFPS instruction.
+/// Lower a vector shuffle using the SHUFPS instruction.
///
/// This is a helper routine dedicated to lowering vector shuffles using SHUFPS.
/// It makes no assumptions about whether this is the *best* lowering, it simply
@@ -11027,7 +11594,7 @@ static SDValue lowerVectorShuffleWithSHUFPS(const SDLoc &DL, MVT VT,
getV4X86ShuffleImm8ForMask(NewMask, DL, DAG));
}
-/// \brief Lower 4-lane 32-bit floating point shuffles.
+/// Lower 4-lane 32-bit floating point shuffles.
///
/// Uses instructions exclusively from the floating point unit to minimize
/// domain crossing penalties, as these are sufficient to implement all v4f32
@@ -11123,7 +11690,7 @@ static SDValue lowerV4F32VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask,
return lowerVectorShuffleWithSHUFPS(DL, MVT::v4f32, Mask, V1, V2, DAG);
}
-/// \brief Lower 4-lane i32 vector shuffles.
+/// Lower 4-lane i32 vector shuffles.
///
/// We try to handle these with integer-domain shuffles where we can, but for
/// blends we use the floating point domain blend instructions.
@@ -11235,7 +11802,7 @@ static SDValue lowerV4I32VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask,
return DAG.getBitcast(MVT::v4i32, ShufPS);
}
-/// \brief Lowering of single-input v8i16 shuffles is the cornerstone of SSE2
+/// Lowering of single-input v8i16 shuffles is the cornerstone of SSE2
/// shuffle lowering, and the most complex part.
///
/// The lowering strategy is to try to form pairs of input lanes which are
@@ -11261,13 +11828,27 @@ static SDValue lowerV8I16GeneralSingleInputVectorShuffle(
MutableArrayRef<int> LoMask = Mask.slice(0, 4);
MutableArrayRef<int> HiMask = Mask.slice(4, 4);
+ // Attempt to directly match PSHUFLW or PSHUFHW.
+ if (isUndefOrInRange(LoMask, 0, 4) &&
+ isSequentialOrUndefInRange(HiMask, 0, 4, 4)) {
+ return DAG.getNode(X86ISD::PSHUFLW, DL, VT, V,
+ getV4X86ShuffleImm8ForMask(LoMask, DL, DAG));
+ }
+ if (isUndefOrInRange(HiMask, 4, 8) &&
+ isSequentialOrUndefInRange(LoMask, 0, 4, 0)) {
+ for (int i = 0; i != 4; ++i)
+ HiMask[i] = (HiMask[i] < 0 ? HiMask[i] : (HiMask[i] - 4));
+ return DAG.getNode(X86ISD::PSHUFHW, DL, VT, V,
+ getV4X86ShuffleImm8ForMask(HiMask, DL, DAG));
+ }
+
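The two direct matches added above rely on pshuflw touching only words 0-3 and pshufhw only words 4-7, each leaving the other half of the register untouched. A minimal intrinsic sketch of that behavior:

#include <immintrin.h>

// Reverse the low four words, keep the high four in place (pshuflw).
__m128i reverse_low_words(__m128i V) {
  return _mm_shufflelo_epi16(V, _MM_SHUFFLE(0, 1, 2, 3));
}

// Reverse the high four words, keep the low four in place (pshufhw).
__m128i reverse_high_words(__m128i V) {
  return _mm_shufflehi_epi16(V, _MM_SHUFFLE(0, 1, 2, 3));
}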
SmallVector<int, 4> LoInputs;
copy_if(LoMask, std::back_inserter(LoInputs), [](int M) { return M >= 0; });
- std::sort(LoInputs.begin(), LoInputs.end());
+ array_pod_sort(LoInputs.begin(), LoInputs.end());
LoInputs.erase(std::unique(LoInputs.begin(), LoInputs.end()), LoInputs.end());
SmallVector<int, 4> HiInputs;
copy_if(HiMask, std::back_inserter(HiInputs), [](int M) { return M >= 0; });
- std::sort(HiInputs.begin(), HiInputs.end());
+ array_pod_sort(HiInputs.begin(), HiInputs.end());
HiInputs.erase(std::unique(HiInputs.begin(), HiInputs.end()), HiInputs.end());
int NumLToL =
std::lower_bound(LoInputs.begin(), LoInputs.end(), 4) - LoInputs.begin();
@@ -11280,13 +11861,11 @@ static SDValue lowerV8I16GeneralSingleInputVectorShuffle(
MutableArrayRef<int> HToLInputs(LoInputs.data() + NumLToL, NumHToL);
MutableArrayRef<int> HToHInputs(HiInputs.data() + NumLToH, NumHToH);
- // If we are splatting two values from one half - one to each half, then
- // we can shuffle that half so each is splatted to a dword, then splat those
- // to their respective halves.
- auto SplatHalfs = [&](int LoInput, int HiInput, unsigned ShufWOp,
- int DOffset) {
- int PSHUFHalfMask[] = {LoInput % 4, LoInput % 4, HiInput % 4, HiInput % 4};
- int PSHUFDMask[] = {DOffset + 0, DOffset + 0, DOffset + 1, DOffset + 1};
+ // If we are shuffling values from one half, check how many different DWORD
+ // pairs we need to create. If there are only 1 or 2, we can perform this as a
+ // PSHUFLW/PSHUFHW + PSHUFD instead of the PSHUFD+PSHUFLW+PSHUFHW chain below.
+ auto ShuffleDWordPairs = [&](ArrayRef<int> PSHUFHalfMask,
+ ArrayRef<int> PSHUFDMask, unsigned ShufWOp) {
V = DAG.getNode(ShufWOp, DL, VT, V,
getV4X86ShuffleImm8ForMask(PSHUFHalfMask, DL, DAG));
V = DAG.getBitcast(PSHUFDVT, V);
@@ -11295,10 +11874,48 @@ static SDValue lowerV8I16GeneralSingleInputVectorShuffle(
return DAG.getBitcast(VT, V);
};
- if (NumLToL == 1 && NumLToH == 1 && (NumHToL + NumHToH) == 0)
- return SplatHalfs(LToLInputs[0], LToHInputs[0], X86ISD::PSHUFLW, 0);
- if (NumHToL == 1 && NumHToH == 1 && (NumLToL + NumLToH) == 0)
- return SplatHalfs(HToLInputs[0], HToHInputs[0], X86ISD::PSHUFHW, 2);
+ if ((NumHToL + NumHToH) == 0 || (NumLToL + NumLToH) == 0) {
+ int PSHUFDMask[4] = { -1, -1, -1, -1 };
+ SmallVector<std::pair<int, int>, 4> DWordPairs;
+ int DOffset = ((NumHToL + NumHToH) == 0 ? 0 : 2);
+
+ // Collect the different DWORD pairs.
+ for (int DWord = 0; DWord != 4; ++DWord) {
+ int M0 = Mask[2 * DWord + 0];
+ int M1 = Mask[2 * DWord + 1];
+ M0 = (M0 >= 0 ? M0 % 4 : M0);
+ M1 = (M1 >= 0 ? M1 % 4 : M1);
+ if (M0 < 0 && M1 < 0)
+ continue;
+
+ bool Match = false;
+ for (int j = 0, e = DWordPairs.size(); j < e; ++j) {
+ auto &DWordPair = DWordPairs[j];
+ if ((M0 < 0 || isUndefOrEqual(DWordPair.first, M0)) &&
+ (M1 < 0 || isUndefOrEqual(DWordPair.second, M1))) {
+ DWordPair.first = (M0 >= 0 ? M0 : DWordPair.first);
+ DWordPair.second = (M1 >= 0 ? M1 : DWordPair.second);
+ PSHUFDMask[DWord] = DOffset + j;
+ Match = true;
+ break;
+ }
+ }
+ if (!Match) {
+ PSHUFDMask[DWord] = DOffset + DWordPairs.size();
+ DWordPairs.push_back(std::make_pair(M0, M1));
+ }
+ }
+
+ if (DWordPairs.size() <= 2) {
+ DWordPairs.resize(2, std::make_pair(-1, -1));
+ int PSHUFHalfMask[4] = {DWordPairs[0].first, DWordPairs[0].second,
+ DWordPairs[1].first, DWordPairs[1].second};
+ if ((NumHToL + NumHToH) == 0)
+ return ShuffleDWordPairs(PSHUFHalfMask, PSHUFDMask, X86ISD::PSHUFLW);
+ if ((NumLToL + NumLToH) == 0)
+ return ShuffleDWordPairs(PSHUFHalfMask, PSHUFDMask, X86ISD::PSHUFHW);
+ }
+ }
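As a simplified worked example of the pairing above (ignoring undef lanes): the single-input v8i16 mask <0,1,0,1,2,3,2,3> draws everything from the low half (DOffset = 0) and produces only two distinct dword pairs, (0,1) and (2,3), so it lowers to an identity PSHUFLW followed by PSHUFD <0,0,1,1> instead of the longer chain below. A minimal standalone check of that grouping:

#include <algorithm>
#include <cassert>
#include <utility>
#include <vector>

int main() {
  const int Mask[8] = {0, 1, 0, 1, 2, 3, 2, 3}; // all inputs in words 0-3
  std::vector<std::pair<int, int>> DWordPairs;
  int PSHUFDMask[4];
  for (int DWord = 0; DWord != 4; ++DWord) {
    std::pair<int, int> P(Mask[2 * DWord] % 4, Mask[2 * DWord + 1] % 4);
    auto It = std::find(DWordPairs.begin(), DWordPairs.end(), P);
    if (It == DWordPairs.end()) {
      DWordPairs.push_back(P);
      It = DWordPairs.end() - 1;
    }
    PSHUFDMask[DWord] = int(It - DWordPairs.begin());
  }
  assert(DWordPairs.size() == 2); // pairs (0,1) and (2,3)
  assert(PSHUFDMask[0] == 0 && PSHUFDMask[1] == 0 &&
         PSHUFDMask[2] == 1 && PSHUFDMask[3] == 1);
  return 0;
}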
// Simplify the 1-into-3 and 3-into-1 cases with a single pshufd. For all
// such inputs we can swap two of the dwords across the half mark and end up
@@ -11750,7 +12367,7 @@ static SDValue lowerVectorShuffleAsBlendOfPSHUFBs(
return DAG.getBitcast(VT, V);
}
-/// \brief Generic lowering of 8-lane i16 shuffles.
+/// Generic lowering of 8-lane i16 shuffles.
///
/// This handles both single-input shuffles and combined shuffle/blends with
/// two inputs. The single input shuffles are immediately delegated to
@@ -11883,7 +12500,7 @@ static SDValue lowerV8I16VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask,
Mask, DAG);
}
-/// \brief Check whether a compaction lowering can be done by dropping even
+/// Check whether a compaction lowering can be done by dropping even
/// elements and compute how many times even elements must be dropped.
///
/// This handles shuffles which take every Nth element where N is a power of
@@ -11962,7 +12579,7 @@ static SDValue lowerVectorShuffleWithPERMV(const SDLoc &DL, MVT VT,
return DAG.getNode(X86ISD::VPERMV3, DL, VT, V1, MaskNode, V2);
}
-/// \brief Generic lowering of v16i8 shuffles.
+/// Generic lowering of v16i8 shuffles.
///
/// This is a hybrid strategy to lower v16i8 vectors. It first attempts to
/// detect any complexity reducing interleaving. If that doesn't help, it uses
@@ -12034,12 +12651,12 @@ static SDValue lowerV16I8VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask,
SmallVector<int, 4> LoInputs;
copy_if(Mask, std::back_inserter(LoInputs),
[](int M) { return M >= 0 && M < 8; });
- std::sort(LoInputs.begin(), LoInputs.end());
+ array_pod_sort(LoInputs.begin(), LoInputs.end());
LoInputs.erase(std::unique(LoInputs.begin(), LoInputs.end()),
LoInputs.end());
SmallVector<int, 4> HiInputs;
copy_if(Mask, std::back_inserter(HiInputs), [](int M) { return M >= 8; });
- std::sort(HiInputs.begin(), HiInputs.end());
+ array_pod_sort(HiInputs.begin(), HiInputs.end());
HiInputs.erase(std::unique(HiInputs.begin(), HiInputs.end()),
HiInputs.end());
@@ -12262,7 +12879,7 @@ static SDValue lowerV16I8VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask,
return DAG.getNode(X86ISD::PACKUS, DL, MVT::v16i8, LoV, HiV);
}
-/// \brief Dispatching routine to lower various 128-bit x86 vector shuffles.
+/// Dispatching routine to lower various 128-bit x86 vector shuffles.
///
/// This routine breaks down the specific type of 128-bit shuffle and
/// dispatches to the lowering routines accordingly.
@@ -12290,7 +12907,7 @@ static SDValue lower128BitVectorShuffle(const SDLoc &DL, ArrayRef<int> Mask,
}
}
-/// \brief Generic routine to split vector shuffle into half-sized shuffles.
+/// Generic routine to split vector shuffle into half-sized shuffles.
///
/// This routine just extracts two subvectors, shuffles them independently, and
/// then concatenates them back together. This should work effectively with all
@@ -12413,7 +13030,7 @@ static SDValue splitAndLowerVectorShuffle(const SDLoc &DL, MVT VT, SDValue V1,
return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, Lo, Hi);
}
-/// \brief Either split a vector in halves or decompose the shuffles and the
+/// Either split a vector in halves or decompose the shuffles and the
/// blend.
///
/// This is provided as a good fallback for many lowerings of non-single-input
@@ -12471,7 +13088,7 @@ static SDValue lowerVectorShuffleAsSplitOrBlend(const SDLoc &DL, MVT VT,
return lowerVectorShuffleAsDecomposedShuffleBlend(DL, VT, V1, V2, Mask, DAG);
}
-/// \brief Lower a vector shuffle crossing multiple 128-bit lanes as
+/// Lower a vector shuffle crossing multiple 128-bit lanes as
/// a permutation and blend of those lanes.
///
/// This essentially blends the out-of-lane inputs to each lane into the lane
@@ -12529,7 +13146,7 @@ static SDValue lowerVectorShuffleAsLanePermuteAndBlend(const SDLoc &DL, MVT VT,
return DAG.getVectorShuffle(VT, DL, V1, Flipped, FlippedBlendMask);
}
-/// \brief Handle lowering 2-lane 128-bit shuffles.
+/// Handle lowering 2-lane 128-bit shuffles.
static SDValue lowerV2X128VectorShuffle(const SDLoc &DL, MVT VT, SDValue V1,
SDValue V2, ArrayRef<int> Mask,
const APInt &Zeroable,
@@ -12540,9 +13157,22 @@ static SDValue lowerV2X128VectorShuffle(const SDLoc &DL, MVT VT, SDValue V1,
return SDValue();
SmallVector<int, 4> WidenedMask;
- if (!canWidenShuffleElements(Mask, WidenedMask))
+ if (!canWidenShuffleElements(Mask, Zeroable, WidenedMask))
return SDValue();
+ bool IsLowZero = (Zeroable & 0x3) == 0x3;
+ bool IsHighZero = (Zeroable & 0xc) == 0xc;
+
+ // Try to use an insert into a zero vector.
+ if (WidenedMask[0] == 0 && IsHighZero) {
+ MVT SubVT = MVT::getVectorVT(VT.getVectorElementType(), 2);
+ SDValue LoV = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, SubVT, V1,
+ DAG.getIntPtrConstant(0, DL));
+ return DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT,
+ getZeroVector(VT, Subtarget, DAG, DL), LoV,
+ DAG.getIntPtrConstant(0, DL));
+ }
+
 // TODO: If minimizing size and one of the inputs is a zero vector and
// the zero vector has only one use, we could use a VPERM2X128 to save the
// instruction bytes needed to explicitly generate the zero vector.
@@ -12552,9 +13182,6 @@ static SDValue lowerV2X128VectorShuffle(const SDLoc &DL, MVT VT, SDValue V1,
Zeroable, Subtarget, DAG))
return Blend;
- bool IsLowZero = (Zeroable & 0x3) == 0x3;
- bool IsHighZero = (Zeroable & 0xc) == 0xc;
-
// If either input operand is a zero vector, use VPERM2X128 because its mask
// allows us to replace the zero input with an implicit zero.
if (!IsLowZero && !IsHighZero) {
@@ -12566,14 +13193,12 @@ static SDValue lowerV2X128VectorShuffle(const SDLoc &DL, MVT VT, SDValue V1,
// With AVX1, use vperm2f128 (below) to allow load folding. Otherwise,
// this will likely become vinsertf128 which can't fold a 256-bit memop.
if (!isa<LoadSDNode>(peekThroughBitcasts(V1))) {
- MVT SubVT = MVT::getVectorVT(VT.getVectorElementType(),
- VT.getVectorNumElements() / 2);
- SDValue LoV = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, SubVT, V1,
- DAG.getIntPtrConstant(0, DL));
- SDValue HiV = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, SubVT,
- OnlyUsesV1 ? V1 : V2,
- DAG.getIntPtrConstant(0, DL));
- return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, LoV, HiV);
+ MVT SubVT = MVT::getVectorVT(VT.getVectorElementType(), 2);
+ SDValue SubVec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, SubVT,
+ OnlyUsesV1 ? V1 : V2,
+ DAG.getIntPtrConstant(0, DL));
+ return DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT, V1, SubVec,
+ DAG.getIntPtrConstant(2, DL));
}
}
@@ -12601,7 +13226,8 @@ static SDValue lowerV2X128VectorShuffle(const SDLoc &DL, MVT VT, SDValue V1,
// [6] - ignore
// [7] - zero high half of destination
- assert(WidenedMask[0] >= 0 && WidenedMask[1] >= 0 && "Undef half?");
+ assert((WidenedMask[0] >= 0 || IsLowZero) &&
+ (WidenedMask[1] >= 0 || IsHighZero) && "Undef half?");
unsigned PermMask = 0;
PermMask |= IsLowZero ? 0x08 : (WidenedMask[0] << 0);
@@ -12617,7 +13243,7 @@ static SDValue lowerV2X128VectorShuffle(const SDLoc &DL, MVT VT, SDValue V1,
DAG.getConstant(PermMask, DL, MVT::i8));
}
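// Illustrative, self-contained sketch (not part of this patch): how the
// VPERM2X128/VPERM2F128 immediate above is assembled once the mask has been
// widened to two 128-bit lane selectors. Selector values 0-3 mean V1.lo,
// V1.hi, V2.lo and V2.hi; bit 3 zeroes the low half of the destination and
// bit 7 zeroes the high half. The selectors are assumed valid (>= 0) whenever
// the corresponding half is not zeroable.
#include <cstdint>

static uint8_t modelVPERM2X128Imm(int WidenedLo, int WidenedHi,
                                  bool IsLowZero, bool IsHighZero) {
  uint8_t Imm = 0;
  Imm |= IsLowZero ? 0x08 : uint8_t(WidenedLo);        // Bits [1:0] or bit 3.
  Imm |= IsHighZero ? 0x80 : uint8_t(WidenedHi) << 4;  // Bits [5:4] or bit 7.
  return Imm;
}

// For example <V2.lo, V1.lo> encodes as 0x02 and <zero, V1.hi> as 0x18.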
-/// \brief Lower a vector shuffle by first fixing the 128-bit lanes and then
+/// Lower a vector shuffle by first fixing the 128-bit lanes and then
/// shuffling each lane.
///
/// This will only succeed when the result of fixing the 128-bit lanes results
@@ -12820,7 +13446,7 @@ static SDValue lowerVectorShuffleWithUndefHalf(const SDLoc &DL, MVT VT,
DAG.getIntPtrConstant(Offset, DL));
}
-/// \brief Test whether the specified input (0 or 1) is in-place blended by the
+/// Test whether the specified input (0 or 1) is in-place blended by the
/// given mask.
///
/// This returns true if the elements from a particular input are already in the
@@ -13056,7 +13682,7 @@ static SDValue lowerVectorShuffleWithSHUFPD(const SDLoc &DL, MVT VT,
DAG.getConstant(Immediate, DL, MVT::i8));
}
-/// \brief Handle lowering of 4-lane 64-bit floating point shuffles.
+/// Handle lowering of 4-lane 64-bit floating point shuffles.
///
/// Also ends up handling lowering of 4-lane 64-bit integer shuffles when AVX2
/// isn't available.
@@ -13098,7 +13724,7 @@ static SDValue lowerV4F64VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask,
getV4X86ShuffleImm8ForMask(Mask, DL, DAG));
// Try to create an in-lane repeating shuffle mask and then shuffle the
- // the results into the target lanes.
+ // results into the target lanes.
if (SDValue V = lowerShuffleAsRepeatedMaskAndLanePermute(
DL, MVT::v4f64, V1, V2, Mask, Subtarget, DAG))
return V;
@@ -13123,7 +13749,7 @@ static SDValue lowerV4F64VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask,
return Op;
// Try to create an in-lane repeating shuffle mask and then shuffle the
- // the results into the target lanes.
+ // results into the target lanes.
if (SDValue V = lowerShuffleAsRepeatedMaskAndLanePermute(
DL, MVT::v4f64, V1, V2, Mask, Subtarget, DAG))
return V;
@@ -13153,7 +13779,7 @@ static SDValue lowerV4F64VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask,
return lowerVectorShuffleAsSplitOrBlend(DL, MVT::v4f64, V1, V2, Mask, DAG);
}
-/// \brief Handle lowering of 4-lane 64-bit integer shuffles.
+/// Handle lowering of 4-lane 64-bit integer shuffles.
///
/// This routine is only called when we have AVX2 and thus a reasonable
/// instruction set for v4i64 shuffling.
@@ -13226,6 +13852,12 @@ static SDValue lowerV4I64VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask,
lowerVectorShuffleWithUNPCK(DL, MVT::v4i64, Mask, V1, V2, DAG))
return V;
+ // Try to create an in-lane repeating shuffle mask and then shuffle the
+ // results into the target lanes.
+ if (SDValue V = lowerShuffleAsRepeatedMaskAndLanePermute(
+ DL, MVT::v4i64, V1, V2, Mask, Subtarget, DAG))
+ return V;
+
// Try to simplify this by merging 128-bit lanes to enable a lane-based
// shuffle. However, if we have AVX2 and either inputs are already in place,
// we will be able to shuffle even across lanes the other input in a single
@@ -13241,7 +13873,7 @@ static SDValue lowerV4I64VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask,
Mask, DAG);
}
-/// \brief Handle lowering of 8-lane 32-bit floating point shuffles.
+/// Handle lowering of 8-lane 32-bit floating point shuffles.
///
/// Also ends up handling lowering of 8-lane 32-bit integer shuffles when AVX2
/// isn't available.
@@ -13291,7 +13923,7 @@ static SDValue lowerV8F32VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask,
}
// Try to create an in-lane repeating shuffle mask and then shuffle the
- // the results into the target lanes.
+ // results into the target lanes.
if (SDValue V = lowerShuffleAsRepeatedMaskAndLanePermute(
DL, MVT::v8f32, V1, V2, Mask, Subtarget, DAG))
return V;
@@ -13340,7 +13972,7 @@ static SDValue lowerV8F32VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask,
return lowerVectorShuffleAsSplitOrBlend(DL, MVT::v8f32, V1, V2, Mask, DAG);
}
-/// \brief Handle lowering of 8-lane 32-bit integer shuffles.
+/// Handle lowering of 8-lane 32-bit integer shuffles.
///
/// This routine is only called when we have AVX2 and thus a reasonable
/// instruction set for v8i32 shuffling.
@@ -13453,7 +14085,7 @@ static SDValue lowerV8I32VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask,
Mask, DAG);
}
-/// \brief Handle lowering of 16-lane 16-bit integer shuffles.
+/// Handle lowering of 16-lane 16-bit integer shuffles.
///
/// This routine is only called when we have AVX2 and thus a reasonable
/// instruction set for v16i16 shuffling.
@@ -13504,7 +14136,7 @@ static SDValue lowerV16I16VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask,
return Rotate;
// Try to create an in-lane repeating shuffle mask and then shuffle the
- // the results into the target lanes.
+ // results into the target lanes.
if (SDValue V = lowerShuffleAsRepeatedMaskAndLanePermute(
DL, MVT::v16i16, V1, V2, Mask, Subtarget, DAG))
return V;
@@ -13544,7 +14176,7 @@ static SDValue lowerV16I16VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask,
return lowerVectorShuffleAsSplitOrBlend(DL, MVT::v16i16, V1, V2, Mask, DAG);
}
-/// \brief Handle lowering of 32-lane 8-bit integer shuffles.
+/// Handle lowering of 32-lane 8-bit integer shuffles.
///
/// This routine is only called when we have AVX2 and thus a reasonable
/// instruction set for v32i8 shuffling.
@@ -13595,7 +14227,7 @@ static SDValue lowerV32I8VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask,
return Rotate;
// Try to create an in-lane repeating shuffle mask and then shuffle the
- // the results into the target lanes.
+ // results into the target lanes.
if (SDValue V = lowerShuffleAsRepeatedMaskAndLanePermute(
DL, MVT::v32i8, V1, V2, Mask, Subtarget, DAG))
return V;
@@ -13624,7 +14256,7 @@ static SDValue lowerV32I8VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask,
return lowerVectorShuffleAsSplitOrBlend(DL, MVT::v32i8, V1, V2, Mask, DAG);
}
-/// \brief High-level routine to lower various 256-bit x86 vector shuffles.
+/// High-level routine to lower various 256-bit x86 vector shuffles.
///
/// This routine either breaks down the specific type of a 256-bit x86 vector
/// shuffle or splits it into two 128-bit shuffles and fuses the results back
@@ -13694,10 +14326,13 @@ static SDValue lower256BitVectorShuffle(const SDLoc &DL, ArrayRef<int> Mask,
}
}
-/// \brief Try to lower a vector shuffle as a 128-bit shuffles.
+/// Try to lower a vector shuffle as a 128-bit shuffles.
static SDValue lowerV4X128VectorShuffle(const SDLoc &DL, MVT VT,
- ArrayRef<int> Mask, SDValue V1,
- SDValue V2, SelectionDAG &DAG) {
+ ArrayRef<int> Mask,
+ const APInt &Zeroable,
+ SDValue V1, SDValue V2,
+ const X86Subtarget &Subtarget,
+ SelectionDAG &DAG) {
assert(VT.getScalarSizeInBits() == 64 &&
"Unexpected element type size for 128bit shuffle.");
@@ -13705,10 +14340,23 @@ static SDValue lowerV4X128VectorShuffle(const SDLoc &DL, MVT VT,
 // function lowerV2X128VectorShuffle() is a better solution.
assert(VT.is512BitVector() && "Unexpected vector size for 512bit shuffle.");
+ // TODO - use Zeroable like we do for lowerV2X128VectorShuffle?
SmallVector<int, 4> WidenedMask;
if (!canWidenShuffleElements(Mask, WidenedMask))
return SDValue();
+ // Try to use an insert into a zero vector.
+ if (WidenedMask[0] == 0 && (Zeroable & 0xf0) == 0xf0 &&
+ (WidenedMask[1] == 1 || (Zeroable & 0x0c) == 0x0c)) {
+ unsigned NumElts = ((Zeroable & 0x0c) == 0x0c) ? 2 : 4;
+ MVT SubVT = MVT::getVectorVT(VT.getVectorElementType(), NumElts);
+ SDValue LoV = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, SubVT, V1,
+ DAG.getIntPtrConstant(0, DL));
+ return DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT,
+ getZeroVector(VT, Subtarget, DAG, DL), LoV,
+ DAG.getIntPtrConstant(0, DL));
+ }
+
// Check for patterns which can be matched with a single insert of a 256-bit
// subvector.
bool OnlyUsesV1 = isShuffleEquivalent(V1, V2, Mask,
@@ -13716,12 +14364,11 @@ static SDValue lowerV4X128VectorShuffle(const SDLoc &DL, MVT VT,
if (OnlyUsesV1 || isShuffleEquivalent(V1, V2, Mask,
{0, 1, 2, 3, 8, 9, 10, 11})) {
MVT SubVT = MVT::getVectorVT(VT.getVectorElementType(), 4);
- SDValue LoV = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, SubVT, V1,
- DAG.getIntPtrConstant(0, DL));
- SDValue HiV = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, SubVT,
- OnlyUsesV1 ? V1 : V2,
+ SDValue SubVec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, SubVT,
+ OnlyUsesV1 ? V1 : V2,
DAG.getIntPtrConstant(0, DL));
- return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, LoV, HiV);
+ return DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT, V1, SubVec,
+ DAG.getIntPtrConstant(4, DL));
}
assert(WidenedMask.size() == 4);
@@ -13756,7 +14403,7 @@ static SDValue lowerV4X128VectorShuffle(const SDLoc &DL, MVT VT,
return insert128BitVector(V1, Subvec, V2Index * 2, DAG, DL);
}
- // Try to lower to to vshuf64x2/vshuf32x4.
+ // Try to lower to vshuf64x2/vshuf32x4.
SDValue Ops[2] = {DAG.getUNDEF(VT), DAG.getUNDEF(VT)};
unsigned PermMask = 0;
 // Ensure elements came from the same Op.
@@ -13781,7 +14428,7 @@ static SDValue lowerV4X128VectorShuffle(const SDLoc &DL, MVT VT,
DAG.getConstant(PermMask, DL, MVT::i8));
}
-/// \brief Handle lowering of 8-lane 64-bit floating point shuffles.
+/// Handle lowering of 8-lane 64-bit floating point shuffles.
static SDValue lowerV8F64VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask,
const APInt &Zeroable,
SDValue V1, SDValue V2,
@@ -13814,7 +14461,8 @@ static SDValue lowerV8F64VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask,
}
if (SDValue Shuf128 =
- lowerV4X128VectorShuffle(DL, MVT::v8f64, Mask, V1, V2, DAG))
+ lowerV4X128VectorShuffle(DL, MVT::v8f64, Mask, Zeroable, V1, V2,
+ Subtarget, DAG))
return Shuf128;
if (SDValue Unpck =
@@ -13837,7 +14485,7 @@ static SDValue lowerV8F64VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask,
return lowerVectorShuffleWithPERMV(DL, MVT::v8f64, Mask, V1, V2, DAG);
}
-/// \brief Handle lowering of 16-lane 32-bit floating point shuffles.
+/// Handle lowering of 16-lane 32-bit floating point shuffles.
static SDValue lowerV16F32VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask,
const APInt &Zeroable,
SDValue V1, SDValue V2,
@@ -13892,7 +14540,7 @@ static SDValue lowerV16F32VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask,
return lowerVectorShuffleWithPERMV(DL, MVT::v16f32, Mask, V1, V2, DAG);
}
-/// \brief Handle lowering of 8-lane 64-bit integer shuffles.
+/// Handle lowering of 8-lane 64-bit integer shuffles.
static SDValue lowerV8I64VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask,
const APInt &Zeroable,
SDValue V1, SDValue V2,
@@ -13924,7 +14572,8 @@ static SDValue lowerV8I64VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask,
}
if (SDValue Shuf128 =
- lowerV4X128VectorShuffle(DL, MVT::v8i64, Mask, V1, V2, DAG))
+ lowerV4X128VectorShuffle(DL, MVT::v8i64, Mask, Zeroable,
+ V1, V2, Subtarget, DAG))
return Shuf128;
// Try to use shift instructions.
@@ -13957,7 +14606,7 @@ static SDValue lowerV8I64VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask,
return lowerVectorShuffleWithPERMV(DL, MVT::v8i64, Mask, V1, V2, DAG);
}
-/// \brief Handle lowering of 16-lane 32-bit integer shuffles.
+/// Handle lowering of 16-lane 32-bit integer shuffles.
static SDValue lowerV16I32VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask,
const APInt &Zeroable,
SDValue V1, SDValue V2,
@@ -14028,7 +14677,7 @@ static SDValue lowerV16I32VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask,
return lowerVectorShuffleWithPERMV(DL, MVT::v16i32, Mask, V1, V2, DAG);
}
-/// \brief Handle lowering of 32-lane 16-bit integer shuffles.
+/// Handle lowering of 32-lane 16-bit integer shuffles.
static SDValue lowerV32I16VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask,
const APInt &Zeroable,
SDValue V1, SDValue V2,
@@ -14083,7 +14732,7 @@ static SDValue lowerV32I16VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask,
return lowerVectorShuffleWithPERMV(DL, MVT::v32i16, Mask, V1, V2, DAG);
}
-/// \brief Handle lowering of 64-lane 8-bit integer shuffles.
+/// Handle lowering of 64-lane 8-bit integer shuffles.
static SDValue lowerV64I8VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask,
const APInt &Zeroable,
SDValue V1, SDValue V2,
@@ -14125,7 +14774,7 @@ static SDValue lowerV64I8VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask,
return lowerVectorShuffleWithPERMV(DL, MVT::v64i8, Mask, V1, V2, DAG);
// Try to create an in-lane repeating shuffle mask and then shuffle the
- // the results into the target lanes.
+ // results into the target lanes.
if (SDValue V = lowerShuffleAsRepeatedMaskAndLanePermute(
DL, MVT::v64i8, V1, V2, Mask, Subtarget, DAG))
return V;
@@ -14138,7 +14787,7 @@ static SDValue lowerV64I8VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask,
return splitAndLowerVectorShuffle(DL, MVT::v64i8, V1, V2, Mask, DAG);
}
-/// \brief High-level routine to lower various 512-bit x86 vector shuffles.
+/// High-level routine to lower various 512-bit x86 vector shuffles.
///
/// This routine either breaks down the specific type of a 512-bit x86 vector
/// shuffle or splits it into two 256-bit shuffles and fuses the results back
@@ -14200,8 +14849,36 @@ static SDValue lower512BitVectorShuffle(const SDLoc &DL, ArrayRef<int> Mask,
// vector, shuffle and then truncate it back.
static SDValue lower1BitVectorShuffle(const SDLoc &DL, ArrayRef<int> Mask,
MVT VT, SDValue V1, SDValue V2,
+ const APInt &Zeroable,
const X86Subtarget &Subtarget,
SelectionDAG &DAG) {
+ unsigned NumElts = Mask.size();
+
+ // Try to recognize shuffles that are just padding a subvector with zeros.
+ unsigned SubvecElts = 0;
+ for (int i = 0; i != (int)NumElts; ++i) {
+ if (Mask[i] >= 0 && Mask[i] != i)
+ break;
+
+ ++SubvecElts;
+ }
+ assert(SubvecElts != NumElts && "Identity shuffle?");
+
+ // Clip to a power of 2.
+ SubvecElts = PowerOf2Floor(SubvecElts);
+
+ // Make sure the number of zeroable bits in the top at least covers the bits
+ // not covered by the subvector.
+ if (Zeroable.countLeadingOnes() >= (NumElts - SubvecElts)) {
+ MVT ExtractVT = MVT::getVectorVT(MVT::i1, SubvecElts);
+ SDValue Extract = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ExtractVT,
+ V1, DAG.getIntPtrConstant(0, DL));
+ return DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT,
+ getZeroVector(VT, Subtarget, DAG, DL),
+ Extract, DAG.getIntPtrConstant(0, DL));
+ }
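// Illustrative, self-contained sketch (not part of this patch): the scalar
// check above for vXi1 shuffles that keep a leading subvector in place and
// zero everything after it. ZeroableTop stands in for
// Zeroable.countLeadingOnes(), i.e. the number of highest-numbered elements
// known to be zero. Returns the power-of-two subvector width to extract, or 0
// if the pattern does not apply.
#include <vector>

static unsigned modelZeroPaddedSubvector(const std::vector<int> &Mask,
                                         unsigned ZeroableTop) {
  unsigned NumElts = Mask.size();
  unsigned SubvecElts = 0;
  for (unsigned i = 0; i != NumElts; ++i) {
    if (Mask[i] >= 0 && Mask[i] != (int)i)
      break;                               // First non-identity element.
    ++SubvecElts;
  }
  if (SubvecElts == 0)
    return 0;
  unsigned Pow2 = 1;                       // PowerOf2Floor(SubvecElts).
  while (Pow2 * 2 <= SubvecElts)
    Pow2 *= 2;
  SubvecElts = Pow2;
  // The zeroable top elements must cover everything past the subvector.
  if (ZeroableTop >= NumElts - SubvecElts)
    return SubvecElts;
  return 0;
}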
+
+
assert(Subtarget.hasAVX512() &&
"Cannot lower 512-bit vectors w/o basic ISA!");
MVT ExtVT;
@@ -14220,38 +14897,31 @@ static SDValue lower1BitVectorShuffle(const SDLoc &DL, ArrayRef<int> Mask,
ExtVT = Subtarget.hasVLX() ? MVT::v8i32 : MVT::v8i64;
break;
case MVT::v16i1:
- ExtVT = MVT::v16i32;
+ // Take 512-bit type, unless we are avoiding 512-bit types and have the
+ // 256-bit operation available.
+ ExtVT = Subtarget.canExtendTo512DQ() ? MVT::v16i32 : MVT::v16i16;
break;
case MVT::v32i1:
- ExtVT = MVT::v32i16;
+ // Take 512-bit type, unless we are avoiding 512-bit types and have the
+ // 256-bit operation available.
+ assert(Subtarget.hasBWI() && "Expected AVX512BW support");
+ ExtVT = Subtarget.canExtendTo512BW() ? MVT::v32i16 : MVT::v32i8;
break;
case MVT::v64i1:
ExtVT = MVT::v64i8;
break;
}
- if (ISD::isBuildVectorAllZeros(V1.getNode()))
- V1 = getZeroVector(ExtVT, Subtarget, DAG, DL);
- else if (ISD::isBuildVectorAllOnes(V1.getNode()))
- V1 = getOnesVector(ExtVT, DAG, DL);
- else
- V1 = DAG.getNode(ISD::SIGN_EXTEND, DL, ExtVT, V1);
-
- if (V2.isUndef())
- V2 = DAG.getUNDEF(ExtVT);
- else if (ISD::isBuildVectorAllZeros(V2.getNode()))
- V2 = getZeroVector(ExtVT, Subtarget, DAG, DL);
- else if (ISD::isBuildVectorAllOnes(V2.getNode()))
- V2 = getOnesVector(ExtVT, DAG, DL);
- else
- V2 = DAG.getNode(ISD::SIGN_EXTEND, DL, ExtVT, V2);
+ V1 = DAG.getNode(ISD::SIGN_EXTEND, DL, ExtVT, V1);
+ V2 = DAG.getNode(ISD::SIGN_EXTEND, DL, ExtVT, V2);
SDValue Shuffle = DAG.getVectorShuffle(ExtVT, DL, V1, V2, Mask);
 // i1 was sign extended so we can produce the mask with a compare against zero.
int NumElems = VT.getVectorNumElements();
if ((Subtarget.hasBWI() && (NumElems >= 32)) ||
(Subtarget.hasDQI() && (NumElems < 32)))
- return DAG.getNode(X86ISD::CVT2MASK, DL, VT, Shuffle);
+ return DAG.getSetCC(DL, VT, DAG.getConstant(0, DL, ExtVT),
+ Shuffle, ISD::SETGT);
return DAG.getNode(ISD::TRUNCATE, DL, VT, Shuffle);
}
@@ -14320,7 +14990,7 @@ static bool canonicalizeShuffleMaskWithCommute(ArrayRef<int> Mask) {
return false;
}
-/// \brief Top-level lowering for x86 vector shuffles.
+/// Top-level lowering for x86 vector shuffles.
///
/// This handles decomposition, canonicalization, and lowering of all x86
/// vector shuffles. Most of the specific lowering strategies are encapsulated
@@ -14378,20 +15048,49 @@ static SDValue lowerVectorShuffle(SDValue Op, const X86Subtarget &Subtarget,
if (Zeroable.isAllOnesValue())
return getZeroVector(VT, Subtarget, DAG, DL);
+ bool V2IsZero = !V2IsUndef && ISD::isBuildVectorAllZeros(V2.getNode());
+
+ // Create an alternative mask with info about zeroable elements.
+ // Here we do not set undef elements as zeroable.
+ SmallVector<int, 64> ZeroableMask(Mask.begin(), Mask.end());
+ if (V2IsZero) {
+ assert(!Zeroable.isNullValue() && "V2's non-undef elements are used?!");
+ for (int i = 0; i != NumElements; ++i)
+ if (Mask[i] != SM_SentinelUndef && Zeroable[i])
+ ZeroableMask[i] = SM_SentinelZero;
+ }
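// Illustrative, self-contained sketch (not part of this patch): the rewrite
// above that marks zeroable, defined lanes with a zero sentinel when V2 is an
// all-zeros vector, while leaving undef lanes alone. The sentinels mirror
// SM_SentinelUndef (-1) and SM_SentinelZero (-2).
#include <vector>

static std::vector<int> modelZeroableMask(const std::vector<int> &Mask,
                                          const std::vector<bool> &Zeroable,
                                          bool V2IsZero) {
  const int SentinelUndef = -1, SentinelZero = -2;
  std::vector<int> ZeroableMask(Mask);
  if (V2IsZero)
    for (size_t i = 0; i != Mask.size(); ++i)
      if (Mask[i] != SentinelUndef && Zeroable[i])
        ZeroableMask[i] = SentinelZero;
  return ZeroableMask;
}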
+
// Try to collapse shuffles into using a vector type with fewer elements but
// wider element types. We cap this to not form integers or floating point
// elements wider than 64 bits, but it might be interesting to form i128
// integers to handle flipping the low and high halves of AVX 256-bit vectors.
SmallVector<int, 16> WidenedMask;
if (VT.getScalarSizeInBits() < 64 && !Is1BitVector &&
- canWidenShuffleElements(Mask, WidenedMask)) {
+ canWidenShuffleElements(ZeroableMask, WidenedMask)) {
MVT NewEltVT = VT.isFloatingPoint()
? MVT::getFloatingPointVT(VT.getScalarSizeInBits() * 2)
: MVT::getIntegerVT(VT.getScalarSizeInBits() * 2);
- MVT NewVT = MVT::getVectorVT(NewEltVT, VT.getVectorNumElements() / 2);
+ int NewNumElts = NumElements / 2;
+ MVT NewVT = MVT::getVectorVT(NewEltVT, NewNumElts);
// Make sure that the new vector type is legal. For example, v2f64 isn't
// legal on SSE1.
if (DAG.getTargetLoweringInfo().isTypeLegal(NewVT)) {
+ if (V2IsZero) {
+ // Modify the new Mask to take all zeros from the all-zero vector.
+ // Choose indices that are blend-friendly.
+ bool UsedZeroVector = false;
+ assert(find(WidenedMask, SM_SentinelZero) != WidenedMask.end() &&
+ "V2's non-undef elements are used?!");
+ for (int i = 0; i != NewNumElts; ++i)
+ if (WidenedMask[i] == SM_SentinelZero) {
+ WidenedMask[i] = i + NewNumElts;
+ UsedZeroVector = true;
+ }
+ // Ensure all elements of V2 are zero - isBuildVectorAllZeros permits
+ // some elements to be undef.
+ if (UsedZeroVector)
+ V2 = getZeroVector(NewVT, Subtarget, DAG, DL);
+ }
V1 = DAG.getBitcast(NewVT, V1);
V2 = DAG.getBitcast(NewVT, V2);
return DAG.getBitcast(
@@ -14403,6 +15102,10 @@ static SDValue lowerVectorShuffle(SDValue Op, const X86Subtarget &Subtarget,
if (canonicalizeShuffleMaskWithCommute(Mask))
return DAG.getCommutedVectorShuffle(*SVOp);
+ if (SDValue V =
+ lowerVectorShuffleWithVPMOV(DL, Mask, VT, V1, V2, DAG, Subtarget))
+ return V;
+
// For each vector width, delegate to a specialized lowering routine.
if (VT.is128BitVector())
return lower128BitVectorShuffle(DL, Mask, VT, V1, V2, Zeroable, Subtarget,
@@ -14417,12 +15120,13 @@ static SDValue lowerVectorShuffle(SDValue Op, const X86Subtarget &Subtarget,
DAG);
if (Is1BitVector)
- return lower1BitVectorShuffle(DL, Mask, VT, V1, V2, Subtarget, DAG);
+ return lower1BitVectorShuffle(DL, Mask, VT, V1, V2, Zeroable, Subtarget,
+ DAG);
llvm_unreachable("Unimplemented!");
}
-/// \brief Try to lower a VSELECT instruction to a vector shuffle.
+/// Try to lower a VSELECT instruction to a vector shuffle.
static SDValue lowerVSELECTtoVectorShuffle(SDValue Op,
const X86Subtarget &Subtarget,
SelectionDAG &DAG) {
@@ -14441,9 +15145,12 @@ static SDValue lowerVSELECTtoVectorShuffle(SDValue Op,
SmallVector<int, 32> Mask;
for (int i = 0, Size = VT.getVectorNumElements(); i < Size; ++i) {
SDValue CondElt = CondBV->getOperand(i);
- Mask.push_back(
- isa<ConstantSDNode>(CondElt) ? i + (isNullConstant(CondElt) ? Size : 0)
- : -1);
+ int M = i;
+ // We can't map undef to undef here; they have different meanings. Treat
+ // undef the same as zero.
+ if (CondElt.isUndef() || isNullConstant(CondElt))
+ M += Size;
+ Mask.push_back(M);
}
return DAG.getVectorShuffle(VT, dl, LHS, RHS, Mask);
}
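// Illustrative, self-contained sketch (not part of this patch): the blend mask
// built above from a constant condition vector. Lane i takes LHS element i for
// a non-zero constant condition and RHS element i (encoded as i + Size) for a
// zero or undef condition; an undef condition lane cannot stay undef in the
// shuffle, since that would let either input through, so it is treated as
// zero. CondLane values: +1 non-zero constant, 0 zero constant, -1 undef.
#include <vector>

static std::vector<int> modelVSelectBlendMask(const std::vector<int> &CondLanes) {
  int Size = CondLanes.size();
  std::vector<int> Mask;
  for (int i = 0; i != Size; ++i) {
    int M = i;                      // Default: take LHS element i.
    if (CondLanes[i] <= 0)          // Zero constant or undef condition.
      M += Size;                    // Take RHS element i instead.
    Mask.push_back(M);
  }
  return Mask;
}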
@@ -14483,9 +15190,11 @@ SDValue X86TargetLowering::LowerVSELECT(SDValue Op, SelectionDAG &DAG) const {
assert(Cond.getValueType().getScalarSizeInBits() ==
VT.getScalarSizeInBits() &&
"Should have a size-matched integer condition!");
- // Build a mask by testing the condition against itself (tests for zero).
+ // Build a mask by testing the condition against zero.
MVT MaskVT = MVT::getVectorVT(MVT::i1, VT.getVectorNumElements());
- SDValue Mask = DAG.getNode(X86ISD::TESTM, dl, MaskVT, Cond, Cond);
+ SDValue Mask = DAG.getSetCC(dl, MaskVT, Cond,
+ getZeroVector(VT, Subtarget, DAG, dl),
+ ISD::SETNE);
// Now return a new VSELECT using the mask.
return DAG.getSelect(dl, VT, Mask, Op.getOperand(1), Op.getOperand(2));
}
@@ -14506,10 +15215,15 @@ SDValue X86TargetLowering::LowerVSELECT(SDValue Op, SelectionDAG &DAG) const {
return SDValue();
case MVT::v8i16:
- case MVT::v16i16:
- // FIXME: We should custom lower this by fixing the condition and using i8
- // blends.
- return SDValue();
+ case MVT::v16i16: {
+ // Bitcast everything to the vXi8 type and use a vXi8 vselect.
+ MVT CastVT = MVT::getVectorVT(MVT::i8, VT.getVectorNumElements() * 2);
+ SDValue Cond = DAG.getBitcast(CastVT, Op->getOperand(0));
+ SDValue LHS = DAG.getBitcast(CastVT, Op->getOperand(1));
+ SDValue RHS = DAG.getBitcast(CastVT, Op->getOperand(2));
+ SDValue Select = DAG.getNode(ISD::VSELECT, dl, CastVT, Cond, LHS, RHS);
+ return DAG.getBitcast(VT, Select);
+ }
}
}
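// Illustrative, self-contained sketch (not part of this patch): why the vXi16
// vselect above can be re-expressed as a vXi8 vselect, assuming each 16-bit
// condition lane is all-zeros or all-ones (the sign-splatted form the vector
// blends expect). Both bytes of such a lane agree, so a byte-wise blend picks
// exactly the same data as a word-wise one.
#include <cstdint>

static uint16_t modelWordBlendViaBytes(uint16_t Cond, uint16_t A, uint16_t B) {
  // Cond is assumed to be 0x0000 or 0xFFFF for the model to match.
  uint8_t CondLo = uint8_t(Cond), CondHi = uint8_t(Cond >> 8);
  uint8_t Lo = CondLo ? uint8_t(A) : uint8_t(B);
  uint8_t Hi = CondHi ? uint8_t(A >> 8) : uint8_t(B >> 8);
  return uint16_t(Lo | (unsigned(Hi) << 8));
}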
@@ -14581,36 +15295,35 @@ static SDValue ExtractBitFromMaskVector(SDValue Op, SelectionDAG &DAG,
return DAG.getNode(ISD::TRUNCATE, dl, EltVT, Elt);
}
- // Canonicalize result type to MVT::i32.
- if (EltVT != MVT::i32) {
- SDValue Extract = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, MVT::i32,
- Vec, Idx);
- return DAG.getAnyExtOrTrunc(Extract, dl, EltVT);
- }
-
unsigned IdxVal = cast<ConstantSDNode>(Idx)->getZExtValue();
- // Extracts from element 0 are always allowed.
- if (IdxVal == 0)
- return Op;
-
// If the kshift instructions of the correct width aren't natively supported
// then we need to promote the vector to the native size to get the correct
// zeroing behavior.
- if ((!Subtarget.hasDQI() && (VecVT.getVectorNumElements() == 8)) ||
- (VecVT.getVectorNumElements() < 8)) {
+ if (VecVT.getVectorNumElements() < 16) {
VecVT = MVT::v16i1;
- Vec = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, VecVT,
- DAG.getUNDEF(VecVT),
- Vec,
+ Vec = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, MVT::v16i1,
+ DAG.getUNDEF(VecVT), Vec,
DAG.getIntPtrConstant(0, dl));
}
- // Use kshiftr instruction to move to the lower element.
- Vec = DAG.getNode(X86ISD::KSHIFTR, dl, VecVT, Vec,
- DAG.getConstant(IdxVal, dl, MVT::i8));
- return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, MVT::i32, Vec,
- DAG.getIntPtrConstant(0, dl));
+ // Extracts from element 0 are always allowed.
+ if (IdxVal != 0) {
+ // Use kshiftr instruction to move to the lower element.
+ Vec = DAG.getNode(X86ISD::KSHIFTR, dl, VecVT, Vec,
+ DAG.getConstant(IdxVal, dl, MVT::i8));
+ }
+
+ // Shrink to v16i1 since that's always legal.
+ if (VecVT.getVectorNumElements() > 16) {
+ VecVT = MVT::v16i1;
+ Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, VecVT, Vec,
+ DAG.getIntPtrConstant(0, dl));
+ }
+
+ // Convert to a bitcast+aext/trunc.
+ MVT CastVT = MVT::getIntegerVT(VecVT.getVectorNumElements());
+ return DAG.getAnyExtOrTrunc(DAG.getBitcast(CastVT, Vec), dl, EltVT);
}
SDValue
@@ -14713,7 +15426,7 @@ X86TargetLowering::LowerEXTRACT_VECTOR_ELT(SDValue Op,
int ShiftVal = (IdxVal % 4) * 8;
if (ShiftVal != 0)
Res = DAG.getNode(ISD::SRL, dl, MVT::i32, Res,
- DAG.getConstant(ShiftVal, dl, MVT::i32));
+ DAG.getConstant(ShiftVal, dl, MVT::i8));
return DAG.getNode(ISD::TRUNCATE, dl, VT, Res);
}
@@ -14724,7 +15437,7 @@ X86TargetLowering::LowerEXTRACT_VECTOR_ELT(SDValue Op,
int ShiftVal = (IdxVal % 2) * 8;
if (ShiftVal != 0)
Res = DAG.getNode(ISD::SRL, dl, MVT::i16, Res,
- DAG.getConstant(ShiftVal, dl, MVT::i16));
+ DAG.getConstant(ShiftVal, dl, MVT::i8));
return DAG.getNode(ISD::TRUNCATE, dl, VT, Res);
}
@@ -14780,74 +15493,11 @@ static SDValue InsertBitToMaskVector(SDValue Op, SelectionDAG &DAG,
return DAG.getNode(ISD::TRUNCATE, dl, VecVT, ExtOp);
}
- unsigned IdxVal = cast<ConstantSDNode>(Idx)->getZExtValue();
- unsigned NumElems = VecVT.getVectorNumElements();
-
- // If the kshift instructions of the correct width aren't natively supported
- // then we need to promote the vector to the native size to get the correct
- // zeroing behavior.
- if ((!Subtarget.hasDQI() && NumElems == 8) || (NumElems < 8)) {
- // Need to promote to v16i1, do the insert, then extract back.
- Vec = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, MVT::v16i1,
- DAG.getUNDEF(MVT::v16i1), Vec,
- DAG.getIntPtrConstant(0, dl));
- Op = DAG.getNode(ISD::INSERT_VECTOR_ELT, dl, MVT::v16i1, Vec, Elt, Idx);
- return DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, VecVT, Op,
- DAG.getIntPtrConstant(0, dl));
- }
-
- SDValue EltInVec = DAG.getNode(ISD::SCALAR_TO_VECTOR, dl, VecVT, Elt);
+ // Copy into a k-register, extract to v1i1 and insert_subvector.
+ SDValue EltInVec = DAG.getNode(ISD::SCALAR_TO_VECTOR, dl, MVT::v1i1, Elt);
- if (Vec.isUndef()) {
- if (IdxVal)
- EltInVec = DAG.getNode(X86ISD::KSHIFTL, dl, VecVT, EltInVec,
- DAG.getConstant(IdxVal, dl, MVT::i8));
- return EltInVec;
- }
-
- // Insertion of one bit into first position
- if (IdxVal == 0 ) {
- // Clean top bits of vector.
- EltInVec = DAG.getNode(X86ISD::KSHIFTL, dl, VecVT, EltInVec,
- DAG.getConstant(NumElems - 1, dl, MVT::i8));
- EltInVec = DAG.getNode(X86ISD::KSHIFTR, dl, VecVT, EltInVec,
- DAG.getConstant(NumElems - 1, dl, MVT::i8));
- // Clean the first bit in source vector.
- Vec = DAG.getNode(X86ISD::KSHIFTR, dl, VecVT, Vec,
- DAG.getConstant(1 , dl, MVT::i8));
- Vec = DAG.getNode(X86ISD::KSHIFTL, dl, VecVT, Vec,
- DAG.getConstant(1, dl, MVT::i8));
-
- return DAG.getNode(ISD::OR, dl, VecVT, Vec, EltInVec);
- }
- // Insertion of one bit into last position
- if (IdxVal == NumElems - 1) {
- // Move the bit to the last position inside the vector.
- EltInVec = DAG.getNode(X86ISD::KSHIFTL, dl, VecVT, EltInVec,
- DAG.getConstant(IdxVal, dl, MVT::i8));
- // Clean the last bit in the source vector.
- Vec = DAG.getNode(X86ISD::KSHIFTL, dl, VecVT, Vec,
- DAG.getConstant(1, dl, MVT::i8));
- Vec = DAG.getNode(X86ISD::KSHIFTR, dl, VecVT, Vec,
- DAG.getConstant(1 , dl, MVT::i8));
-
- return DAG.getNode(ISD::OR, dl, VecVT, Vec, EltInVec);
- }
-
- // Move the current value of the bit to be replace to bit 0.
- SDValue Merged = DAG.getNode(X86ISD::KSHIFTR, dl, VecVT, Vec,
- DAG.getConstant(IdxVal, dl, MVT::i8));
- // Xor with the new bit.
- Merged = DAG.getNode(ISD::XOR, dl, VecVT, Merged, EltInVec);
- // Shift to MSB, filling bottom bits with 0.
- Merged = DAG.getNode(X86ISD::KSHIFTL, dl, VecVT, Merged,
- DAG.getConstant(NumElems - 1, dl, MVT::i8));
- // Shift to the final position, filling upper bits with 0.
- Merged = DAG.getNode(X86ISD::KSHIFTR, dl, VecVT, Merged,
- DAG.getConstant(NumElems - 1 - IdxVal, dl, MVT::i8));
- // Xor with original vector to cancel out the original bit value that's still
- // present.
- return DAG.getNode(ISD::XOR, dl, VecVT, Merged, Vec);
+ return DAG.getNode(ISD::INSERT_SUBVECTOR, dl, VecVT, Vec, EltInVec,
+ Op.getOperand(2));
}
SDValue X86TargetLowering::LowerINSERT_VECTOR_ELT(SDValue Op,
@@ -15020,8 +15670,45 @@ static SDValue LowerINSERT_SUBVECTOR(SDValue Op, const X86Subtarget &Subtarget,
return insert1BitVector(Op, DAG, Subtarget);
}
+static SDValue LowerEXTRACT_SUBVECTOR(SDValue Op, const X86Subtarget &Subtarget,
+ SelectionDAG &DAG) {
+ assert(Op.getSimpleValueType().getVectorElementType() == MVT::i1 &&
+ "Only vXi1 extract_subvectors need custom lowering");
+
+ SDLoc dl(Op);
+ SDValue Vec = Op.getOperand(0);
+ SDValue Idx = Op.getOperand(1);
+
+ if (!isa<ConstantSDNode>(Idx))
+ return SDValue();
+
+ unsigned IdxVal = cast<ConstantSDNode>(Idx)->getZExtValue();
+ if (IdxVal == 0) // the operation is legal
+ return Op;
+
+ MVT VecVT = Vec.getSimpleValueType();
+ unsigned NumElems = VecVT.getVectorNumElements();
+
+ // Extend to natively supported kshift.
+ MVT WideVecVT = VecVT;
+ if ((!Subtarget.hasDQI() && NumElems == 8) || NumElems < 8) {
+ WideVecVT = Subtarget.hasDQI() ? MVT::v8i1 : MVT::v16i1;
+ Vec = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, WideVecVT,
+ DAG.getUNDEF(WideVecVT), Vec,
+ DAG.getIntPtrConstant(0, dl));
+ }
+
+ // Shift to the LSB.
+ Vec = DAG.getNode(X86ISD::KSHIFTR, dl, WideVecVT, Vec,
+ DAG.getConstant(IdxVal, dl, MVT::i8));
+
+ return DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, Op.getValueType(), Vec,
+ DAG.getIntPtrConstant(0, dl));
+}
+
// Returns the appropriate wrapper opcode for a global reference.
-unsigned X86TargetLowering::getGlobalWrapperKind(const GlobalValue *GV) const {
+unsigned X86TargetLowering::getGlobalWrapperKind(
+ const GlobalValue *GV, const unsigned char OpFlags) const {
// References to absolute symbols are never PC-relative.
if (GV && GV->isAbsoluteSymbolRef())
return X86ISD::Wrapper;
@@ -15031,6 +15718,10 @@ unsigned X86TargetLowering::getGlobalWrapperKind(const GlobalValue *GV) const {
(M == CodeModel::Small || M == CodeModel::Kernel))
return X86ISD::WrapperRIP;
+ // GOTPCREL references must always use RIP.
+ if (OpFlags == X86II::MO_GOTPCREL)
+ return X86ISD::WrapperRIP;
+
return X86ISD::Wrapper;
}
@@ -15154,7 +15845,7 @@ SDValue X86TargetLowering::LowerGlobalAddress(const GlobalValue *GV,
Result = DAG.getTargetGlobalAddress(GV, dl, PtrVT, 0, OpFlags);
}
- Result = DAG.getNode(getGlobalWrapperKind(GV), dl, PtrVT, Result);
+ Result = DAG.getNode(getGlobalWrapperKind(GV, OpFlags), dl, PtrVT, Result);
// With PIC, the address is actually $g + Offset.
if (isGlobalRelativeToPICBase(OpFlags)) {
@@ -15336,7 +16027,7 @@ X86TargetLowering::LowerGlobalTLSAddress(SDValue Op, SelectionDAG &DAG) const {
GlobalAddressSDNode *GA = cast<GlobalAddressSDNode>(Op);
- if (DAG.getTarget().Options.EmulatedTLS)
+ if (DAG.getTarget().useEmulatedTLS())
return LowerToTLSEmulatedModel(GA, DAG);
const GlobalValue *GV = GA->getGlobal();
@@ -15456,7 +16147,7 @@ X86TargetLowering::LowerGlobalTLSAddress(SDValue Op, SelectionDAG &DAG) const {
auto &DL = DAG.getDataLayout();
SDValue Scale =
- DAG.getConstant(Log2_64_Ceil(DL.getPointerSize()), dl, PtrVT);
+ DAG.getConstant(Log2_64_Ceil(DL.getPointerSize()), dl, MVT::i8);
IDX = DAG.getNode(ISD::SHL, dl, PtrVT, IDX, Scale);
res = DAG.getNode(ISD::ADD, dl, PtrVT, ThreadPointer, IDX);
@@ -15512,24 +16203,47 @@ static SDValue LowerShiftParts(SDValue Op, SelectionDAG &DAG) {
// values for large shift amounts.
SDValue AndNode = DAG.getNode(ISD::AND, dl, MVT::i8, ShAmt,
DAG.getConstant(VTBits, dl, MVT::i8));
- SDValue Cond = DAG.getNode(X86ISD::CMP, dl, MVT::i32,
- AndNode, DAG.getConstant(0, dl, MVT::i8));
+ SDValue Cond = DAG.getSetCC(dl, MVT::i8, AndNode,
+ DAG.getConstant(0, dl, MVT::i8), ISD::SETNE);
SDValue Hi, Lo;
- SDValue CC = DAG.getConstant(X86::COND_NE, dl, MVT::i8);
- SDValue Ops0[4] = { Tmp2, Tmp3, CC, Cond };
- SDValue Ops1[4] = { Tmp3, Tmp1, CC, Cond };
-
if (Op.getOpcode() == ISD::SHL_PARTS) {
- Hi = DAG.getNode(X86ISD::CMOV, dl, VT, Ops0);
- Lo = DAG.getNode(X86ISD::CMOV, dl, VT, Ops1);
+ Hi = DAG.getNode(ISD::SELECT, dl, VT, Cond, Tmp3, Tmp2);
+ Lo = DAG.getNode(ISD::SELECT, dl, VT, Cond, Tmp1, Tmp3);
} else {
- Lo = DAG.getNode(X86ISD::CMOV, dl, VT, Ops0);
- Hi = DAG.getNode(X86ISD::CMOV, dl, VT, Ops1);
+ Lo = DAG.getNode(ISD::SELECT, dl, VT, Cond, Tmp3, Tmp2);
+ Hi = DAG.getNode(ISD::SELECT, dl, VT, Cond, Tmp1, Tmp3);
}
- SDValue Ops[2] = { Lo, Hi };
- return DAG.getMergeValues(Ops, dl);
+ return DAG.getMergeValues({ Lo, Hi }, dl);
+}
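// Illustrative, self-contained sketch (not part of this patch): the scalar
// behaviour the SHL_PARTS lowering above produces with one double shift, one
// plain shift and two selects keyed on bit 5 of the shift amount. Shld and
// LoShl play the roles of Tmp2 and Tmp3; the zero fill selected into the low
// half for large amounts corresponds to the Tmp1 operand computed earlier in
// the function. Amt is assumed to be in [0, 63].
#include <cstdint>

static void modelShl64Via32(uint32_t Lo, uint32_t Hi, unsigned Amt,
                            uint32_t &OutLo, uint32_t &OutHi) {
  unsigned Safe = Amt & 31;                    // Shift amount modulo 32.
  uint32_t Shld = (Safe == 0)                  // SHLD: Hi shifted left, filled
                      ? Hi                     // with the top bits of Lo.
                      : (Hi << Safe) | (Lo >> (32 - Safe));
  uint32_t LoShl = Lo << Safe;
  bool AmtGE32 = (Amt & 32) != 0;              // The setcc on (Amt & 32) above.
  OutHi = AmtGE32 ? LoShl : Shld;              // select(Cond, Tmp3, Tmp2)
  OutLo = AmtGE32 ? 0 : LoShl;                 // select(Cond, Tmp1, Tmp3)
}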
+
+// Try to use a packed vector operation to handle i64 on 32-bit targets when
+// AVX512DQ is enabled.
+static SDValue LowerI64IntToFP_AVX512DQ(SDValue Op, SelectionDAG &DAG,
+ const X86Subtarget &Subtarget) {
+ assert((Op.getOpcode() == ISD::SINT_TO_FP ||
+ Op.getOpcode() == ISD::UINT_TO_FP) && "Unexpected opcode!");
+ SDValue Src = Op.getOperand(0);
+ MVT SrcVT = Src.getSimpleValueType();
+ MVT VT = Op.getSimpleValueType();
+
+ if (!Subtarget.hasDQI() || SrcVT != MVT::i64 || Subtarget.is64Bit() ||
+ (VT != MVT::f32 && VT != MVT::f64))
+ return SDValue();
+
+ // Pack the i64 into a vector, do the operation and extract.
+
+ // Using 256-bit to ensure result is 128-bits for f32 case.
+ unsigned NumElts = Subtarget.hasVLX() ? 4 : 8;
+ MVT VecInVT = MVT::getVectorVT(MVT::i64, NumElts);
+ MVT VecVT = MVT::getVectorVT(VT, NumElts);
+
+ SDLoc dl(Op);
+ SDValue InVec = DAG.getNode(ISD::SCALAR_TO_VECTOR, dl, VecInVT, Src);
+ SDValue CvtVec = DAG.getNode(Op.getOpcode(), dl, VecVT, InVec);
+ return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, VT, CvtVec,
+ DAG.getIntPtrConstant(0, dl));
}
SDValue X86TargetLowering::LowerSINT_TO_FP(SDValue Op,
@@ -15545,20 +16259,6 @@ SDValue X86TargetLowering::LowerSINT_TO_FP(SDValue Op,
DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v4i32, Src,
DAG.getUNDEF(SrcVT)));
}
- if (SrcVT.getVectorElementType() == MVT::i1) {
- if (SrcVT == MVT::v2i1) {
- // For v2i1, we need to widen to v4i1 first.
- assert(VT == MVT::v2f64 && "Unexpected type");
- Src = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v4i1, Src,
- DAG.getUNDEF(MVT::v2i1));
- return DAG.getNode(X86ISD::CVTSI2P, dl, Op.getValueType(),
- DAG.getNode(ISD::SIGN_EXTEND, dl, MVT::v4i32, Src));
- }
-
- MVT IntegerVT = MVT::getVectorVT(MVT::i32, SrcVT.getVectorNumElements());
- return DAG.getNode(ISD::SINT_TO_FP, dl, Op.getValueType(),
- DAG.getNode(ISD::SIGN_EXTEND, dl, IntegerVT, Src));
- }
return SDValue();
}
@@ -15567,15 +16267,17 @@ SDValue X86TargetLowering::LowerSINT_TO_FP(SDValue Op,
// These are really Legal; return the operand so the caller accepts it as
// Legal.
- if (SrcVT == MVT::i32 && isScalarFPTypeInSSEReg(Op.getValueType()))
+ if (SrcVT == MVT::i32 && isScalarFPTypeInSSEReg(VT))
return Op;
- if (SrcVT == MVT::i64 && isScalarFPTypeInSSEReg(Op.getValueType()) &&
- Subtarget.is64Bit()) {
+ if (SrcVT == MVT::i64 && isScalarFPTypeInSSEReg(VT) && Subtarget.is64Bit()) {
return Op;
}
+ if (SDValue V = LowerI64IntToFP_AVX512DQ(Op, DAG, Subtarget))
+ return V;
+
SDValue ValueToStore = Op.getOperand(0);
- if (SrcVT == MVT::i64 && isScalarFPTypeInSSEReg(Op.getValueType()) &&
+ if (SrcVT == MVT::i64 && isScalarFPTypeInSSEReg(VT) &&
!Subtarget.is64Bit())
// Bitcasting to f64 here allows us to do a single 64-bit store from
// an SSE register, avoiding the store forwarding penalty that would come
@@ -15760,7 +16462,8 @@ static SDValue LowerUINT_TO_FP_i32(SDValue Op, SelectionDAG &DAG,
}
static SDValue lowerUINT_TO_FP_v2i32(SDValue Op, SelectionDAG &DAG,
- const X86Subtarget &Subtarget, SDLoc &DL) {
+ const X86Subtarget &Subtarget,
+ const SDLoc &DL) {
if (Op.getSimpleValueType() != MVT::v2f64)
return SDValue();
@@ -15894,21 +16597,6 @@ static SDValue lowerUINT_TO_FP_vec(SDValue Op, SelectionDAG &DAG,
MVT SrcVT = N0.getSimpleValueType();
SDLoc dl(Op);
- if (SrcVT.getVectorElementType() == MVT::i1) {
- if (SrcVT == MVT::v2i1) {
- // For v2i1, we need to widen to v4i1 first.
- assert(Op.getValueType() == MVT::v2f64 && "Unexpected type");
- N0 = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v4i1, N0,
- DAG.getUNDEF(MVT::v2i1));
- return DAG.getNode(X86ISD::CVTUI2P, dl, MVT::v2f64,
- DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::v4i32, N0));
- }
-
- MVT IntegerVT = MVT::getVectorVT(MVT::i32, SrcVT.getVectorNumElements());
- return DAG.getNode(ISD::UINT_TO_FP, dl, Op.getValueType(),
- DAG.getNode(ISD::ZERO_EXTEND, dl, IntegerVT, N0));
- }
-
switch (SrcVT.SimpleTy) {
default:
llvm_unreachable("Custom UINT_TO_FP is not supported!");
@@ -15940,6 +16628,9 @@ SDValue X86TargetLowering::LowerUINT_TO_FP(SDValue Op,
return Op;
}
+ if (SDValue V = LowerI64IntToFP_AVX512DQ(Op, DAG, Subtarget))
+ return V;
+
if (SrcVT == MVT::i64 && DstVT == MVT::f64 && X86ScalarSSEf64)
return LowerUINT_TO_FP_i64(Op, DAG, Subtarget);
if (SrcVT == MVT::i32 && X86ScalarSSEf64)
@@ -16205,15 +16896,17 @@ static SDValue LowerAVXExtend(SDValue Op, SelectionDAG &DAG,
MVT InVT = In.getSimpleValueType();
SDLoc dl(Op);
- if ((VT != MVT::v4i64 || InVT != MVT::v4i32) &&
- (VT != MVT::v8i32 || InVT != MVT::v8i16) &&
- (VT != MVT::v16i16 || InVT != MVT::v16i8) &&
- (VT != MVT::v8i64 || InVT != MVT::v8i32) &&
- (VT != MVT::v8i64 || InVT != MVT::v8i16) &&
- (VT != MVT::v16i32 || InVT != MVT::v16i16) &&
- (VT != MVT::v16i32 || InVT != MVT::v16i8) &&
- (VT != MVT::v32i16 || InVT != MVT::v32i8))
- return SDValue();
+ assert(VT.isVector() && InVT.isVector() && "Expected vector type");
+ assert(VT.getVectorNumElements() == InVT.getVectorNumElements() &&
+ "Expected same number of elements");
+ assert((VT.getVectorElementType() == MVT::i16 ||
+ VT.getVectorElementType() == MVT::i32 ||
+ VT.getVectorElementType() == MVT::i64) &&
+ "Unexpected element type");
+ assert((InVT.getVectorElementType() == MVT::i8 ||
+ InVT.getVectorElementType() == MVT::i16 ||
+ InVT.getVectorElementType() == MVT::i32) &&
+ "Unexpected element type");
if (Subtarget.hasInt256())
return DAG.getNode(X86ISD::VZEXT, dl, VT, In);
@@ -16246,6 +16939,20 @@ static SDValue LowerAVXExtend(SDValue Op, SelectionDAG &DAG,
return DAG.getNode(ISD::CONCAT_VECTORS, dl, VT, OpLo, OpHi);
}
+// Helper to split and extend a v16i1 mask to v16i8 or v16i16.
+static SDValue SplitAndExtendv16i1(unsigned ExtOpc, MVT VT, SDValue In,
+ const SDLoc &dl, SelectionDAG &DAG) {
+ assert((VT == MVT::v16i8 || VT == MVT::v16i16) && "Unexpected VT.");
+ SDValue Lo = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, MVT::v8i1, In,
+ DAG.getIntPtrConstant(0, dl));
+ SDValue Hi = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, MVT::v8i1, In,
+ DAG.getIntPtrConstant(8, dl));
+ Lo = DAG.getNode(ExtOpc, dl, MVT::v8i16, Lo);
+ Hi = DAG.getNode(ExtOpc, dl, MVT::v8i16, Hi);
+ SDValue Res = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v16i16, Lo, Hi);
+ return DAG.getNode(ISD::TRUNCATE, dl, VT, Res);
+}
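// Illustrative, self-contained sketch (not part of this patch): the shape of
// SplitAndExtendv16i1 on a plain 16-bit mask. Each half is extended lane by
// lane to 16 bits (sign or zero extension, matching ExtOpc), the halves are
// concatenated, and the result is truncated to bytes for the v16i8 case.
#include <array>
#include <cstdint>

static std::array<uint8_t, 16> modelSplitAndExtendMask(uint16_t Mask,
                                                       bool SignExt) {
  std::array<uint8_t, 16> Out{};
  for (int Half = 0; Half != 2; ++Half)         // The two v8i1 halves.
    for (int i = 0; i != 8; ++i) {
      bool Bit = (Mask >> (Half * 8 + i)) & 1;
      uint16_t Lane = SignExt ? (Bit ? 0xFFFF : 0) : (Bit ? 1 : 0);
      Out[Half * 8 + i] = uint8_t(Lane);        // Final TRUNCATE to i8.
    }
  return Out;
}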
+
static SDValue LowerZERO_EXTEND_Mask(SDValue Op,
const X86Subtarget &Subtarget,
SelectionDAG &DAG) {
@@ -16256,11 +16963,23 @@ static SDValue LowerZERO_EXTEND_Mask(SDValue Op,
SDLoc DL(Op);
unsigned NumElts = VT.getVectorNumElements();
- // Extend VT if the scalar type is v8/v16 and BWI is not supported.
+ // For all vectors but vXi8 we can just emit a sign_extend and a shift. This
+ // avoids a constant pool load.
+ if (VT.getVectorElementType() != MVT::i8) {
+ SDValue Extend = DAG.getNode(ISD::SIGN_EXTEND, DL, VT, In);
+ return DAG.getNode(ISD::SRL, DL, VT, Extend,
+ DAG.getConstant(VT.getScalarSizeInBits() - 1, DL, VT));
+ }
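// Illustrative, self-contained sketch (not part of this patch): the identity
// the branch above relies on. Sign extending an i1 lane yields 0 or all-ones;
// a logical shift right by (bit width - 1) then leaves 0 or 1, which is the
// zero extension, without materializing a constant vector of ones.
#include <cstdint>

static uint32_t modelZextI1ViaSext(bool Bit) {
  int32_t Sext = Bit ? -1 : 0;          // SIGN_EXTEND of the i1 lane to i32.
  return uint32_t(Sext) >> 31;          // SRL by (32 - 1) gives 0 or 1.
}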
+
+ // Extend VT if BWI is not supported.
MVT ExtVT = VT;
- if (!Subtarget.hasBWI() &&
- (VT.getVectorElementType().getSizeInBits() <= 16))
+ if (!Subtarget.hasBWI()) {
+ // If v16i32 is to be avoided, we'll need to split and concatenate.
+ if (NumElts == 16 && !Subtarget.canExtendTo512DQ())
+ return SplitAndExtendv16i1(ISD::ZERO_EXTEND, VT, In, DL, DAG);
+
ExtVT = MVT::getVectorVT(MVT::i32, NumElts);
+ }
// Widen to 512-bits if VLX is not supported.
MVT WideVT = ExtVT;
@@ -16278,9 +16997,9 @@ static SDValue LowerZERO_EXTEND_Mask(SDValue Op,
SDValue SelectedVal = DAG.getSelect(DL, WideVT, In, One, Zero);
- // Truncate if we had to extend i16/i8 above.
+ // Truncate if we had to extend above.
if (VT != ExtVT) {
- WideVT = MVT::getVectorVT(VT.getVectorElementType(), NumElts);
+ WideVT = MVT::getVectorVT(MVT::i8, NumElts);
SelectedVal = DAG.getNode(ISD::TRUNCATE, DL, WideVT, SelectedVal);
}
@@ -16300,14 +17019,8 @@ static SDValue LowerZERO_EXTEND(SDValue Op, const X86Subtarget &Subtarget,
if (SVT.getVectorElementType() == MVT::i1)
return LowerZERO_EXTEND_Mask(Op, Subtarget, DAG);
- if (Subtarget.hasFp256())
- if (SDValue Res = LowerAVXExtend(Op, DAG, Subtarget))
- return Res;
-
- assert(!Op.getSimpleValueType().is256BitVector() || !SVT.is128BitVector() ||
- Op.getSimpleValueType().getVectorNumElements() !=
- SVT.getVectorNumElements());
- return SDValue();
+ assert(Subtarget.hasAVX() && "Expected AVX support");
+ return LowerAVXExtend(Op, DAG, Subtarget);
}
/// Helper to recursively truncate vector elements in half with PACKSS/PACKUS.
@@ -16321,8 +17034,8 @@ static SDValue truncateVectorWithPACK(unsigned Opcode, EVT DstVT, SDValue In,
assert((Opcode == X86ISD::PACKSS || Opcode == X86ISD::PACKUS) &&
"Unexpected PACK opcode");
- // Requires SSE2 but AVX512 has fast truncate.
- if (!Subtarget.hasSSE2() || Subtarget.hasAVX512())
+ // Requires SSE2 but AVX512 has fast vector truncate.
+ if (!Subtarget.hasSSE2() || Subtarget.hasAVX512() || !DstVT.isVector())
return SDValue();
EVT SrcVT = In.getValueType();
@@ -16331,40 +17044,53 @@ static SDValue truncateVectorWithPACK(unsigned Opcode, EVT DstVT, SDValue In,
if (SrcVT == DstVT)
return In;
- // We only support vector truncation to 128bits or greater from a
- // 256bits or greater source.
+ // We only support vector truncation to 64bits or greater from a
+ // 128bits or greater source.
unsigned DstSizeInBits = DstVT.getSizeInBits();
unsigned SrcSizeInBits = SrcVT.getSizeInBits();
- if ((DstSizeInBits % 128) != 0 || (SrcSizeInBits % 256) != 0)
+ if ((DstSizeInBits % 64) != 0 || (SrcSizeInBits % 128) != 0)
return SDValue();
- LLVMContext &Ctx = *DAG.getContext();
unsigned NumElems = SrcVT.getVectorNumElements();
+ if (!isPowerOf2_32(NumElems))
+ return SDValue();
+
+ LLVMContext &Ctx = *DAG.getContext();
assert(DstVT.getVectorNumElements() == NumElems && "Illegal truncation");
assert(SrcSizeInBits > DstSizeInBits && "Illegal truncation");
EVT PackedSVT = EVT::getIntegerVT(Ctx, SrcVT.getScalarSizeInBits() / 2);
- // Extract lower/upper subvectors.
- unsigned NumSubElts = NumElems / 2;
- SDValue Lo = extractSubVector(In, 0 * NumSubElts, DAG, DL, SrcSizeInBits / 2);
- SDValue Hi = extractSubVector(In, 1 * NumSubElts, DAG, DL, SrcSizeInBits / 2);
-
// Pack to the largest type possible:
// vXi64/vXi32 -> PACK*SDW and vXi16 -> PACK*SWB.
EVT InVT = MVT::i16, OutVT = MVT::i8;
- if (DstVT.getScalarSizeInBits() > 8 &&
+ if (SrcVT.getScalarSizeInBits() > 16 &&
(Opcode == X86ISD::PACKSS || Subtarget.hasSSE41())) {
InVT = MVT::i32;
OutVT = MVT::i16;
}
+ // 128bit -> 64bit truncate - PACK 128-bit src in the lower subvector.
+ if (SrcVT.is128BitVector()) {
+ InVT = EVT::getVectorVT(Ctx, InVT, 128 / InVT.getSizeInBits());
+ OutVT = EVT::getVectorVT(Ctx, OutVT, 128 / OutVT.getSizeInBits());
+ In = DAG.getBitcast(InVT, In);
+ SDValue Res = DAG.getNode(Opcode, DL, OutVT, In, In);
+ Res = extractSubVector(Res, 0, DAG, DL, 64);
+ return DAG.getBitcast(DstVT, Res);
+ }
+
+ // Extract lower/upper subvectors.
+ unsigned NumSubElts = NumElems / 2;
+ SDValue Lo = extractSubVector(In, 0 * NumSubElts, DAG, DL, SrcSizeInBits / 2);
+ SDValue Hi = extractSubVector(In, 1 * NumSubElts, DAG, DL, SrcSizeInBits / 2);
+
unsigned SubSizeInBits = SrcSizeInBits / 2;
InVT = EVT::getVectorVT(Ctx, InVT, SubSizeInBits / InVT.getSizeInBits());
OutVT = EVT::getVectorVT(Ctx, OutVT, SubSizeInBits / OutVT.getSizeInBits());
// 256bit -> 128bit truncate - PACK lower/upper 128-bit subvectors.
- if (SrcVT.is256BitVector()) {
+ if (SrcVT.is256BitVector() && DstVT.is128BitVector()) {
Lo = DAG.getBitcast(InVT, Lo);
Hi = DAG.getBitcast(InVT, Hi);
SDValue Res = DAG.getNode(Opcode, DL, OutVT, Lo, Hi);
@@ -16393,7 +17119,7 @@ static SDValue truncateVectorWithPACK(unsigned Opcode, EVT DstVT, SDValue In,
}
// Recursively pack lower/upper subvectors, concat result and pack again.
- assert(SrcSizeInBits >= 512 && "Expected 512-bit vector or greater");
+ assert(SrcSizeInBits >= 256 && "Expected 256-bit vector or greater");
EVT PackedVT = EVT::getVectorVT(Ctx, PackedSVT, NumSubElts);
Lo = truncateVectorWithPACK(Opcode, PackedVT, Lo, DL, DAG, Subtarget);
Hi = truncateVectorWithPACK(Opcode, PackedVT, Hi, DL, DAG, Subtarget);
@@ -16418,18 +17144,49 @@ static SDValue LowerTruncateVecI1(SDValue Op, SelectionDAG &DAG,
if (InVT.getScalarSizeInBits() <= 16) {
if (Subtarget.hasBWI()) {
// legal, will go to VPMOVB2M, VPMOVW2M
- // Shift packed bytes not supported natively, bitcast to word
- MVT ExtVT = MVT::getVectorVT(MVT::i16, InVT.getSizeInBits()/16);
- SDValue ShiftNode = DAG.getNode(ISD::SHL, DL, ExtVT,
- DAG.getBitcast(ExtVT, In),
- DAG.getConstant(ShiftInx, DL, ExtVT));
- ShiftNode = DAG.getBitcast(InVT, ShiftNode);
- return DAG.getNode(X86ISD::CVT2MASK, DL, VT, ShiftNode);
+ if (DAG.ComputeNumSignBits(In) < InVT.getScalarSizeInBits()) {
+ // We need to shift to get the lsb into sign position.
+ // Shift packed bytes not supported natively, bitcast to word
+ MVT ExtVT = MVT::getVectorVT(MVT::i16, InVT.getSizeInBits()/16);
+ In = DAG.getNode(ISD::SHL, DL, ExtVT,
+ DAG.getBitcast(ExtVT, In),
+ DAG.getConstant(ShiftInx, DL, ExtVT));
+ In = DAG.getBitcast(InVT, In);
+ }
+ return DAG.getSetCC(DL, VT, DAG.getConstant(0, DL, InVT),
+ In, ISD::SETGT);
}
// Use TESTD/Q, extended vector to packed dword/qword.
assert((InVT.is256BitVector() || InVT.is128BitVector()) &&
"Unexpected vector type.");
unsigned NumElts = InVT.getVectorNumElements();
+ assert((NumElts == 8 || NumElts == 16) && "Unexpected number of elements");
+ // We need to change to a wider element type that we have support for.
+ // For 8 element vectors this is easy, we either extend to v8i32 or v8i64.
+ // For 16 element vectors we extend to v16i32 unless we are explicitly
+ // trying to avoid 512-bit vectors. If we are avoiding 512-bit vectors
+ // we need to split into two 8 element vectors which we can extend to v8i32,
+ // truncate and concat the results. There's an additional complication if
+ // the original type is v16i8. In that case we can't split the v16i8 so
+ // first we pre-extend it to v16i16 which we can split to v8i16, then extend
+ // to v8i32, truncate that to v8i1 and concat the two halves.
+ if (NumElts == 16 && !Subtarget.canExtendTo512DQ()) {
+ if (InVT == MVT::v16i8) {
+ // First we need to sign extend up to 256-bits so we can split that.
+ InVT = MVT::v16i16;
+ In = DAG.getNode(ISD::SIGN_EXTEND, DL, InVT, In);
+ }
+ SDValue Lo = extract128BitVector(In, 0, DAG, DL);
+ SDValue Hi = extract128BitVector(In, 8, DAG, DL);
+ // We're split now; just emit two truncates and a concat. The two
+ // truncates will trigger legalization to come back to this function.
+ Lo = DAG.getNode(ISD::TRUNCATE, DL, MVT::v8i1, Lo);
+ Hi = DAG.getNode(ISD::TRUNCATE, DL, MVT::v8i1, Hi);
+ return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, Lo, Hi);
+ }
+ // We either have 8 elements or we're allowed to use 512-bit vectors.
+ // If we have VLX, we want to use the narrowest vector that can get the
+ // job done so we use vXi32.
MVT EltVT = Subtarget.hasVLX() ? MVT::i32 : MVT::getIntegerVT(512/NumElts);
MVT ExtVT = MVT::getVectorVT(EltVT, NumElts);
In = DAG.getNode(ISD::SIGN_EXTEND, DL, ExtVT, In);
@@ -16437,9 +17194,17 @@ static SDValue LowerTruncateVecI1(SDValue Op, SelectionDAG &DAG,
ShiftInx = InVT.getScalarSizeInBits() - 1;
}
- SDValue ShiftNode = DAG.getNode(ISD::SHL, DL, InVT, In,
- DAG.getConstant(ShiftInx, DL, InVT));
- return DAG.getNode(X86ISD::TESTM, DL, VT, ShiftNode, ShiftNode);
+ if (DAG.ComputeNumSignBits(In) < InVT.getScalarSizeInBits()) {
+ // We need to shift to get the lsb into sign position.
+ In = DAG.getNode(ISD::SHL, DL, InVT, In,
+ DAG.getConstant(ShiftInx, DL, InVT));
+ }
+ // If we have DQI, emit a pattern that will be iseled as vpmovq2m/vpmovd2m.
+ if (Subtarget.hasDQI())
+ return DAG.getSetCC(DL, VT, DAG.getConstant(0, DL, InVT),
+ In, ISD::SETGT);
+ return DAG.getSetCC(DL, VT, In, getZeroVector(InVT, Subtarget, DAG, DL),
+ ISD::SETNE);
}
SDValue X86TargetLowering::LowerTRUNCATE(SDValue Op, SelectionDAG &DAG) const {
@@ -16458,31 +17223,36 @@ SDValue X86TargetLowering::LowerTRUNCATE(SDValue Op, SelectionDAG &DAG) const {
// vpmovqb/w/d, vpmovdb/w, vpmovwb
if (Subtarget.hasAVX512()) {
// word to byte only under BWI
- if (InVT == MVT::v16i16 && !Subtarget.hasBWI()) // v16i16 -> v16i8
- return DAG.getNode(X86ISD::VTRUNC, DL, VT,
- getExtendInVec(X86ISD::VSEXT, DL, MVT::v16i32, In, DAG));
- return DAG.getNode(X86ISD::VTRUNC, DL, VT, In);
+ if (InVT == MVT::v16i16 && !Subtarget.hasBWI()) { // v16i16 -> v16i8
+ // Make sure we're allowed to promote to 512 bits.
+ if (Subtarget.canExtendTo512DQ())
+ return DAG.getNode(ISD::TRUNCATE, DL, VT,
+ DAG.getNode(X86ISD::VSEXT, DL, MVT::v16i32, In));
+ } else {
+ return Op;
+ }
}
- // Truncate with PACKSS if we are truncating a vector with sign-bits that
- // extend all the way to the packed/truncated value.
- unsigned NumPackedBits = std::min<unsigned>(VT.getScalarSizeInBits(), 16);
- if ((InNumEltBits - NumPackedBits) < DAG.ComputeNumSignBits(In))
- if (SDValue V =
- truncateVectorWithPACK(X86ISD::PACKSS, VT, In, DL, DAG, Subtarget))
- return V;
+ unsigned NumPackedSignBits = std::min<unsigned>(VT.getScalarSizeInBits(), 16);
+ unsigned NumPackedZeroBits = Subtarget.hasSSE41() ? NumPackedSignBits : 8;
// Truncate with PACKUS if we are truncating a vector with leading zero bits
// that extend all the way to the packed/truncated value.
// Pre-SSE41 we can only use PACKUSWB.
KnownBits Known;
DAG.computeKnownBits(In, Known);
- NumPackedBits = Subtarget.hasSSE41() ? NumPackedBits : 8;
- if ((InNumEltBits - NumPackedBits) <= Known.countMinLeadingZeros())
+ if ((InNumEltBits - NumPackedZeroBits) <= Known.countMinLeadingZeros())
if (SDValue V =
truncateVectorWithPACK(X86ISD::PACKUS, VT, In, DL, DAG, Subtarget))
return V;
+ // Truncate with PACKSS if we are truncating a vector with sign-bits that
+ // extend all the way to the packed/truncated value.
+ if ((InNumEltBits - NumPackedSignBits) < DAG.ComputeNumSignBits(In))
+ if (SDValue V =
+ truncateVectorWithPACK(X86ISD::PACKSS, VT, In, DL, DAG, Subtarget))
+ return V;
+
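// Illustrative, self-contained sketch (not part of this patch): why the
// saturating PACK instructions can stand in for a plain truncate under the
// known-bits / sign-bits conditions checked above. PACKUSWB clamps each i16
// to [0, 255] and PACKSSWB to [-128, 127]; when the leading zero bits (resp.
// sign bits) already reach down to the packed width, the clamp never fires
// and the low byte comes through unchanged, which is exactly a truncate.
#include <cstdint>

static uint8_t modelPackusWb(int16_t V) {   // Unsigned saturation to i8.
  if (V < 0)
    return 0;
  if (V > 255)
    return 255;
  return uint8_t(V);
}

static int8_t modelPackssWb(int16_t V) {    // Signed saturation to i8.
  if (V < -128)
    return -128;
  if (V > 127)
    return 127;
  return int8_t(V);
}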
if ((VT == MVT::v4i32) && (InVT == MVT::v4i64)) {
// On AVX2, v4i64 -> v4i32 becomes VPERMD.
if (Subtarget.hasInt256()) {
@@ -16549,10 +17319,9 @@ SDValue X86TargetLowering::LowerTRUNCATE(SDValue Op, SelectionDAG &DAG) const {
}
// Handle truncation of V256 to V128 using shuffles.
- if (!VT.is128BitVector() || !InVT.is256BitVector())
- return SDValue();
+ assert(VT.is128BitVector() && InVT.is256BitVector() && "Unexpected types!");
- assert(Subtarget.hasFp256() && "256-bit vector without AVX!");
+ assert(Subtarget.hasAVX() && "256-bit vector without AVX!");
unsigned NumElems = VT.getVectorNumElements();
MVT NVT = MVT::getVectorVT(VT.getVectorElementType(), NumElems * 2);
@@ -16572,9 +17341,29 @@ SDValue X86TargetLowering::LowerFP_TO_INT(SDValue Op, SelectionDAG &DAG) const {
MVT VT = Op.getSimpleValueType();
if (VT.isVector()) {
- assert(Subtarget.hasDQI() && Subtarget.hasVLX() && "Requires AVX512DQVL!");
SDValue Src = Op.getOperand(0);
SDLoc dl(Op);
+
+ if (VT == MVT::v2i1 && Src.getSimpleValueType() == MVT::v2f64) {
+ MVT ResVT = MVT::v4i32;
+ MVT TruncVT = MVT::v4i1;
+ unsigned Opc = IsSigned ? X86ISD::CVTTP2SI : X86ISD::CVTTP2UI;
+ if (!IsSigned && !Subtarget.hasVLX()) {
+ // Widen to 512-bits.
+ ResVT = MVT::v8i32;
+ TruncVT = MVT::v8i1;
+ Opc = ISD::FP_TO_UINT;
+ Src = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, MVT::v8f64,
+ DAG.getUNDEF(MVT::v8f64),
+ Src, DAG.getIntPtrConstant(0, dl));
+ }
+ SDValue Res = DAG.getNode(Opc, dl, ResVT, Src);
+ Res = DAG.getNode(ISD::TRUNCATE, dl, TruncVT, Res);
+ return DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, MVT::v2i1, Res,
+ DAG.getIntPtrConstant(0, dl));
+ }
+
+ assert(Subtarget.hasDQI() && Subtarget.hasVLX() && "Requires AVX512DQVL!");
if (VT == MVT::v2i64 && Src.getSimpleValueType() == MVT::v2f32) {
return DAG.getNode(IsSigned ? X86ISD::CVTTP2SI : X86ISD::CVTTP2UI, dl, VT,
DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v4f32, Src,
@@ -16771,8 +17560,16 @@ static SDValue LowerFGETSIGN(SDValue Op, SelectionDAG &DAG) {
return Res;
}
+/// Helper for creating a X86ISD::SETCC node.
+static SDValue getSETCC(X86::CondCode Cond, SDValue EFLAGS, const SDLoc &dl,
+ SelectionDAG &DAG) {
+ return DAG.getNode(X86ISD::SETCC, dl, MVT::i8,
+ DAG.getConstant(Cond, dl, MVT::i8), EFLAGS);
+}
+
// Check whether an OR'd tree is PTEST-able.
-static SDValue LowerVectorAllZeroTest(SDValue Op, const X86Subtarget &Subtarget,
+static SDValue LowerVectorAllZeroTest(SDValue Op, ISD::CondCode CC,
+ const X86Subtarget &Subtarget,
SelectionDAG &DAG) {
assert(Op.getOpcode() == ISD::OR && "Only check OR'd tree.");
@@ -16859,10 +17656,12 @@ static SDValue LowerVectorAllZeroTest(SDValue Op, const X86Subtarget &Subtarget,
VecIns.push_back(DAG.getNode(ISD::OR, DL, TestVT, LHS, RHS));
}
- return DAG.getNode(X86ISD::PTEST, DL, MVT::i32, VecIns.back(), VecIns.back());
+ SDValue Res = DAG.getNode(X86ISD::PTEST, DL, MVT::i32,
+ VecIns.back(), VecIns.back());
+ return getSETCC(CC == ISD::SETEQ ? X86::COND_E : X86::COND_NE, Res, DL, DAG);
}
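A standalone sketch (not part of the patch) of the reduction used above: PTEST(v, v) sets ZF exactly when v is all zero, and an OR of several values is zero exactly when every value is zero, so the whole OR'd tree collapses into one reduction plus one flag test.

#include <cassert>
#include <cstdint>
#include <vector>

static bool allZero(const std::vector<uint64_t> &Leaves) {
  uint64_t Acc = 0;
  for (uint64_t L : Leaves)
    Acc |= L;          // fold the OR'd tree into one value
  return Acc == 0;     // PTEST(Acc, Acc): ZF <=> Acc == 0
}

int main() {
  assert(allZero({0, 0, 0}));
  assert(!allZero({0, 0x10, 0}));
  return 0;
}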
-/// \brief return true if \c Op has a use that doesn't just read flags.
+/// Return true if \c Op has a use that doesn't just read flags.
static bool hasNonFlagsUse(SDValue Op) {
for (SDNode::use_iterator UI = Op->use_begin(), UE = Op->use_end(); UI != UE;
++UI) {
@@ -16881,33 +17680,10 @@ static bool hasNonFlagsUse(SDValue Op) {
return false;
}
-// Emit KTEST instruction for bit vectors on AVX-512
-static SDValue EmitKTEST(SDValue Op, SelectionDAG &DAG,
- const X86Subtarget &Subtarget) {
- if (Op.getOpcode() == ISD::BITCAST) {
- auto hasKTEST = [&](MVT VT) {
- unsigned SizeInBits = VT.getSizeInBits();
- return (Subtarget.hasDQI() && (SizeInBits == 8 || SizeInBits == 16)) ||
- (Subtarget.hasBWI() && (SizeInBits == 32 || SizeInBits == 64));
- };
- SDValue Op0 = Op.getOperand(0);
- MVT Op0VT = Op0.getValueType().getSimpleVT();
- if (Op0VT.isVector() && Op0VT.getVectorElementType() == MVT::i1 &&
- hasKTEST(Op0VT))
- return DAG.getNode(X86ISD::KTEST, SDLoc(Op), Op0VT, Op0, Op0);
- }
- return SDValue();
-}
-
/// Emit nodes that will be selected as "test Op0,Op0", or something
/// equivalent.
SDValue X86TargetLowering::EmitTest(SDValue Op, unsigned X86CC, const SDLoc &dl,
SelectionDAG &DAG) const {
- if (Op.getValueType() == MVT::i1) {
- SDValue ExtOp = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i8, Op);
- return DAG.getNode(X86ISD::CMP, dl, MVT::i32, ExtOp,
- DAG.getConstant(0, dl, MVT::i8));
- }
// CF and OF aren't always set the way we want. Determine which
// of these we need.
bool NeedCF = false;
@@ -16943,9 +17719,6 @@ SDValue X86TargetLowering::EmitTest(SDValue Op, unsigned X86CC, const SDLoc &dl,
// doing a separate TEST. TEST always sets OF and CF to 0, so unless
// we prove that the arithmetic won't overflow, we can't use OF or CF.
if (Op.getResNo() != 0 || NeedOF || NeedCF) {
- // Emit KTEST for bit vectors
- if (auto Node = EmitKTEST(Op, DAG, Subtarget))
- return Node;
// Emit a CMP with 0, which is the TEST pattern.
return DAG.getNode(X86ISD::CMP, dl, MVT::i32, Op,
DAG.getConstant(0, dl, Op.getValueType()));
@@ -17119,14 +17892,7 @@ SDValue X86TargetLowering::EmitTest(SDValue Op, unsigned X86CC, const SDLoc &dl,
case ISD::SUB: Opcode = X86ISD::SUB; break;
case ISD::XOR: Opcode = X86ISD::XOR; break;
case ISD::AND: Opcode = X86ISD::AND; break;
- case ISD::OR: {
- if (!NeedTruncation && ZeroCheck) {
- if (SDValue EFLAGS = LowerVectorAllZeroTest(Op, Subtarget, DAG))
- return EFLAGS;
- }
- Opcode = X86ISD::OR;
- break;
- }
+ case ISD::OR: Opcode = X86ISD::OR; break;
}
NumOperands = 2;
@@ -17168,16 +17934,13 @@ SDValue X86TargetLowering::EmitTest(SDValue Op, unsigned X86CC, const SDLoc &dl,
if (TLI.isOperationLegal(WideVal.getOpcode(), WideVT)) {
SDValue V0 = DAG.getNode(ISD::TRUNCATE, dl, VT, WideVal.getOperand(0));
SDValue V1 = DAG.getNode(ISD::TRUNCATE, dl, VT, WideVal.getOperand(1));
- Op = DAG.getNode(ConvertedOp, dl, VT, V0, V1);
+ SDVTList VTs = DAG.getVTList(Op.getValueType(), MVT::i32);
+ Op = DAG.getNode(ConvertedOp, dl, VTs, V0, V1);
}
}
}
if (Opcode == 0) {
- // Emit KTEST for bit vectors
- if (auto Node = EmitKTEST(Op, DAG, Subtarget))
- return Node;
-
// Emit a CMP with 0, which is the TEST pattern.
return DAG.getNode(X86ISD::CMP, dl, MVT::i32, Op,
DAG.getConstant(0, dl, Op.getValueType()));
@@ -17186,7 +17949,7 @@ SDValue X86TargetLowering::EmitTest(SDValue Op, unsigned X86CC, const SDLoc &dl,
SmallVector<SDValue, 4> Ops(Op->op_begin(), Op->op_begin() + NumOperands);
SDValue New = DAG.getNode(Opcode, dl, VTs, Ops);
- DAG.ReplaceAllUsesWith(Op, New);
+ DAG.ReplaceAllUsesOfValueWith(SDValue(Op.getNode(), 0), New);
return SDValue(New.getNode(), 1);
}
@@ -17271,7 +18034,6 @@ SDValue X86TargetLowering::getSqrtEstimate(SDValue Op,
EVT VT = Op.getValueType();
// SSE1 has rsqrtss and rsqrtps. AVX adds a 256-bit variant for rsqrtps.
- // TODO: Add support for AVX512 (v16f32).
// It is likely not profitable to do this for f64 because a double-precision
// rsqrt estimate with refinement on x86 prior to FMA requires at least 16
// instructions: convert to single, rsqrtss, convert back to double, refine
@@ -17282,12 +18044,15 @@ SDValue X86TargetLowering::getSqrtEstimate(SDValue Op,
if ((VT == MVT::f32 && Subtarget.hasSSE1()) ||
(VT == MVT::v4f32 && Subtarget.hasSSE1() && Reciprocal) ||
(VT == MVT::v4f32 && Subtarget.hasSSE2() && !Reciprocal) ||
- (VT == MVT::v8f32 && Subtarget.hasAVX())) {
+ (VT == MVT::v8f32 && Subtarget.hasAVX()) ||
+ (VT == MVT::v16f32 && Subtarget.useAVX512Regs())) {
if (RefinementSteps == ReciprocalEstimate::Unspecified)
RefinementSteps = 1;
UseOneConstNR = false;
- return DAG.getNode(X86ISD::FRSQRT, SDLoc(Op), VT, Op);
+ // There is no FRSQRT for 512-bits, but there is RSQRT14.
+ unsigned Opcode = VT == MVT::v16f32 ? X86ISD::RSQRT14 : X86ISD::FRSQRT;
+ return DAG.getNode(Opcode, SDLoc(Op), VT, Op);
}
return SDValue();
}
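A standalone sketch (not part of the patch) of what RefinementSteps = 1 means here: the hardware estimate (rsqrtps / vrsqrt14ps) is low precision, and one Newton-Raphson step, y1 = y0 * (1.5 - 0.5 * x * y0^2), roughly doubles the number of correct bits. The perturbed starting value below merely stands in for a real hardware estimate.

#include <cmath>
#include <cstdio>

int main() {
  float x = 3.0f;
  float y0 = 1.0f / std::sqrt(x) * (1.0f + 3e-4f); // stand-in ~12-bit estimate
  float y1 = y0 * (1.5f - 0.5f * x * y0 * y0);     // one Newton-Raphson step
  std::printf("exact    %.9f\n", 1.0 / std::sqrt(3.0));
  std::printf("estimate %.9f\n", y0);
  std::printf("refined  %.9f\n", y1);
  return 0;
}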
@@ -17300,7 +18065,6 @@ SDValue X86TargetLowering::getRecipEstimate(SDValue Op, SelectionDAG &DAG,
EVT VT = Op.getValueType();
// SSE1 has rcpss and rcpps. AVX adds a 256-bit variant for rcpps.
- // TODO: Add support for AVX512 (v16f32).
// It is likely not profitable to do this for f64 because a double-precision
// reciprocal estimate with refinement on x86 prior to FMA requires
// 15 instructions: convert to single, rcpss, convert back to double, refine
@@ -17309,7 +18073,8 @@ SDValue X86TargetLowering::getRecipEstimate(SDValue Op, SelectionDAG &DAG,
if ((VT == MVT::f32 && Subtarget.hasSSE1()) ||
(VT == MVT::v4f32 && Subtarget.hasSSE1()) ||
- (VT == MVT::v8f32 && Subtarget.hasAVX())) {
+ (VT == MVT::v8f32 && Subtarget.hasAVX()) ||
+ (VT == MVT::v16f32 && Subtarget.useAVX512Regs())) {
// Enable estimate codegen with 1 refinement step for vector division.
// Scalar division estimates are disabled because they break too much
// real-world code. These defaults are intended to match GCC behavior.
@@ -17319,7 +18084,9 @@ SDValue X86TargetLowering::getRecipEstimate(SDValue Op, SelectionDAG &DAG,
if (RefinementSteps == ReciprocalEstimate::Unspecified)
RefinementSteps = 1;
- return DAG.getNode(X86ISD::FRCP, SDLoc(Op), VT, Op);
+ // There is no FRCP for 512-bits, but there is RCP14.
+ unsigned Opcode = VT == MVT::v16f32 ? X86ISD::RCP14 : X86ISD::FRCP;
+ return DAG.getNode(Opcode, SDLoc(Op), VT, Op);
}
return SDValue();
}
@@ -17334,13 +18101,6 @@ unsigned X86TargetLowering::combineRepeatedFPDivisors() const {
return 2;
}
-/// Helper for creating a X86ISD::SETCC node.
-static SDValue getSETCC(X86::CondCode Cond, SDValue EFLAGS, const SDLoc &dl,
- SelectionDAG &DAG) {
- return DAG.getNode(X86ISD::SETCC, dl, MVT::i8,
- DAG.getConstant(Cond, dl, MVT::i8), EFLAGS);
-}
-
/// Create a BT (Bit Test) node - Test bit \p BitNo in \p Src and set condition
/// according to equal/not-equal condition code \p CC.
static SDValue getBitTestCondition(SDValue Src, SDValue BitNo, ISD::CondCode CC,
@@ -17408,12 +18168,15 @@ static SDValue LowerAndToBT(SDValue And, ISD::CondCode CC,
if (AndRHSVal == 1 && AndLHS.getOpcode() == ISD::SRL) {
LHS = AndLHS.getOperand(0);
RHS = AndLHS.getOperand(1);
- }
-
- // Use BT if the immediate can't be encoded in a TEST instruction.
- if (!isUInt<32>(AndRHSVal) && isPowerOf2_64(AndRHSVal)) {
- LHS = AndLHS;
- RHS = DAG.getConstant(Log2_64_Ceil(AndRHSVal), dl, LHS.getValueType());
+ } else {
+ // Use BT if the immediate can't be encoded in a TEST instruction or we
+ // are optimizing for size and the immediate won't fit in a byte.
+ bool OptForSize = DAG.getMachineFunction().getFunction().optForSize();
+ if ((!isUInt<32>(AndRHSVal) || (OptForSize && !isUInt<8>(AndRHSVal))) &&
+ isPowerOf2_64(AndRHSVal)) {
+ LHS = AndLHS;
+ RHS = DAG.getConstant(Log2_64_Ceil(AndRHSVal), dl, LHS.getValueType());
+ }
}
}
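A standalone sketch (not part of the patch) of the equivalence that lets a power-of-two AND mask become a bit test: BT of bit n and (x & (1ULL << n)) != 0 agree for every n, while BT avoids materializing a 64-bit (or, at -Os, a multi-byte) immediate for TEST.

#include <cassert>
#include <cstdint>

int main() {
  uint64_t x = 0x8000000100000000ULL;
  for (unsigned n = 0; n < 64; ++n) {
    bool ViaTest = (x & (1ULL << n)) != 0;  // TEST with a wide immediate
    bool ViaBT   = ((x >> n) & 1) != 0;     // what BT delivers into CF
    assert(ViaTest == ViaBT);
  }
  return 0;
}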
@@ -17498,49 +18261,6 @@ static SDValue Lower256IntVSETCC(SDValue Op, SelectionDAG &DAG) {
DAG.getNode(Op.getOpcode(), dl, NewVT, LHS2, RHS2, CC));
}
-static SDValue LowerBoolVSETCC_AVX512(SDValue Op, SelectionDAG &DAG) {
- SDValue Op0 = Op.getOperand(0);
- SDValue Op1 = Op.getOperand(1);
- SDValue CC = Op.getOperand(2);
- MVT VT = Op.getSimpleValueType();
- SDLoc dl(Op);
-
- assert(Op0.getSimpleValueType().getVectorElementType() == MVT::i1 &&
- "Unexpected type for boolean compare operation");
- ISD::CondCode SetCCOpcode = cast<CondCodeSDNode>(CC)->get();
- SDValue NotOp0 = DAG.getNode(ISD::XOR, dl, VT, Op0,
- DAG.getConstant(-1, dl, VT));
- SDValue NotOp1 = DAG.getNode(ISD::XOR, dl, VT, Op1,
- DAG.getConstant(-1, dl, VT));
- switch (SetCCOpcode) {
- default: llvm_unreachable("Unexpected SETCC condition");
- case ISD::SETEQ:
- // (x == y) -> ~(x ^ y)
- return DAG.getNode(ISD::XOR, dl, VT,
- DAG.getNode(ISD::XOR, dl, VT, Op0, Op1),
- DAG.getConstant(-1, dl, VT));
- case ISD::SETNE:
- // (x != y) -> (x ^ y)
- return DAG.getNode(ISD::XOR, dl, VT, Op0, Op1);
- case ISD::SETUGT:
- case ISD::SETGT:
- // (x > y) -> (x & ~y)
- return DAG.getNode(ISD::AND, dl, VT, Op0, NotOp1);
- case ISD::SETULT:
- case ISD::SETLT:
- // (x < y) -> (~x & y)
- return DAG.getNode(ISD::AND, dl, VT, NotOp0, Op1);
- case ISD::SETULE:
- case ISD::SETLE:
- // (x <= y) -> (~x | y)
- return DAG.getNode(ISD::OR, dl, VT, NotOp0, Op1);
- case ISD::SETUGE:
- case ISD::SETGE:
- // (x >=y) -> (x | ~y)
- return DAG.getNode(ISD::OR, dl, VT, Op0, NotOp1);
- }
-}
-
static SDValue LowerIntVSETCC_AVX512(SDValue Op, SelectionDAG &DAG) {
SDValue Op0 = Op.getOperand(0);
@@ -17553,48 +18273,24 @@ static SDValue LowerIntVSETCC_AVX512(SDValue Op, SelectionDAG &DAG) {
"Cannot set masked compare for this operation");
ISD::CondCode SetCCOpcode = cast<CondCodeSDNode>(CC)->get();
- unsigned Opc = 0;
- bool Unsigned = false;
- bool Swap = false;
- unsigned SSECC;
- switch (SetCCOpcode) {
- default: llvm_unreachable("Unexpected SETCC condition");
- case ISD::SETNE: SSECC = 4; break;
- case ISD::SETEQ: Opc = X86ISD::PCMPEQM; break;
- case ISD::SETUGT: SSECC = 6; Unsigned = true; break;
- case ISD::SETLT: Swap = true; LLVM_FALLTHROUGH;
- case ISD::SETGT: Opc = X86ISD::PCMPGTM; break;
- case ISD::SETULT: SSECC = 1; Unsigned = true; break;
- case ISD::SETUGE: SSECC = 5; Unsigned = true; break; //NLT
- case ISD::SETGE: Swap = true; SSECC = 2; break; // LE + swap
- case ISD::SETULE: Unsigned = true; LLVM_FALLTHROUGH;
- case ISD::SETLE: SSECC = 2; break;
- }
- if (Swap)
+ // If this is a seteq/setne, make sure any build vectors of all zeros are on
+ // the RHS. This helps with vptestm matching.
+ // TODO: Should we just canonicalize the setcc during DAG combine?
+ if ((SetCCOpcode == ISD::SETEQ || SetCCOpcode == ISD::SETNE) &&
+ ISD::isBuildVectorAllZeros(Op0.getNode()))
std::swap(Op0, Op1);
- // See if it is the case of CMP(EQ|NEQ,AND(A,B),ZERO) and change it to TESTM|NM.
- if ((!Opc && SSECC == 4) || Opc == X86ISD::PCMPEQM) {
- SDValue A = peekThroughBitcasts(Op0);
- if ((A.getOpcode() == ISD::AND || A.getOpcode() == X86ISD::FAND) &&
- ISD::isBuildVectorAllZeros(Op1.getNode())) {
- MVT VT0 = Op0.getSimpleValueType();
- SDValue RHS = DAG.getBitcast(VT0, A.getOperand(0));
- SDValue LHS = DAG.getBitcast(VT0, A.getOperand(1));
- return DAG.getNode(Opc == X86ISD::PCMPEQM ? X86ISD::TESTNM : X86ISD::TESTM,
- dl, VT, RHS, LHS);
- }
+ // Prefer SETGT over SETLT.
+ if (SetCCOpcode == ISD::SETLT) {
+ SetCCOpcode = ISD::getSetCCSwappedOperands(SetCCOpcode);
+ std::swap(Op0, Op1);
}
- if (Opc)
- return DAG.getNode(Opc, dl, VT, Op0, Op1);
- Opc = Unsigned ? X86ISD::CMPMU: X86ISD::CMPM;
- return DAG.getNode(Opc, dl, VT, Op0, Op1,
- DAG.getConstant(SSECC, dl, MVT::i8));
+ return DAG.getSetCC(dl, VT, Op0, Op1, SetCCOpcode);
}
-/// \brief Try to turn a VSETULT into a VSETULE by modifying its second
+/// Try to turn a VSETULT into a VSETULE by modifying its second
/// operand \p Op1. If non-trivial (for example because it's not constant)
/// return an empty value.
static SDValue ChangeVSETULTtoVSETULE(const SDLoc &dl, SDValue Op1,
@@ -17624,6 +18320,51 @@ static SDValue ChangeVSETULTtoVSETULE(const SDLoc &dl, SDValue Op1,
return DAG.getBuildVector(VT, dl, ULTOp1);
}
+/// As another special case, use PSUBUS[BW] when it's profitable. E.g. for
+/// Op0 u<= Op1:
+/// t = psubus Op0, Op1
+/// pcmpeq t, <0..0>
+static SDValue LowerVSETCCWithSUBUS(SDValue Op0, SDValue Op1, MVT VT,
+ ISD::CondCode Cond, const SDLoc &dl,
+ const X86Subtarget &Subtarget,
+ SelectionDAG &DAG) {
+ if (!Subtarget.hasSSE2())
+ return SDValue();
+
+ MVT VET = VT.getVectorElementType();
+ if (VET != MVT::i8 && VET != MVT::i16)
+ return SDValue();
+
+ switch (Cond) {
+ default:
+ return SDValue();
+ case ISD::SETULT: {
+ // If the comparison is against a constant we can turn this into a
+ // setule. With psubus, setule does not require a swap. This is
+ // beneficial because the constant in the register is no longer
+ // clobbered as the destination, so it can be hoisted out of a loop.
+ // Only do this pre-AVX since vpcmp* is no longer destructive.
+ if (Subtarget.hasAVX())
+ return SDValue();
+ SDValue ULEOp1 = ChangeVSETULTtoVSETULE(dl, Op1, DAG);
+ if (!ULEOp1)
+ return SDValue();
+ Op1 = ULEOp1;
+ break;
+ }
+ // Psubus is better than flip-sign because it requires no inversion.
+ case ISD::SETUGE:
+ std::swap(Op0, Op1);
+ break;
+ case ISD::SETULE:
+ break;
+ }
+
+ SDValue Result = DAG.getNode(X86ISD::SUBUS, dl, VT, Op0, Op1);
+ return DAG.getNode(X86ISD::PCMPEQ, dl, VT, Result,
+ getZeroVector(VT, Subtarget, DAG, dl));
+}
+
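A standalone sketch (not part of the patch) of the PSUBUS identity quoted in the comment above, checked exhaustively on one u8 lane: with unsigned saturating subtraction, subus(a, b) == 0 exactly when a u<= b, so the compare becomes one psubus plus one pcmpeq against zero.

#include <cassert>
#include <cstdint>

static uint8_t subus(uint8_t a, uint8_t b) {
  return a > b ? static_cast<uint8_t>(a - b) : 0; // unsigned saturating subtract
}

int main() {
  for (unsigned a = 0; a < 256; ++a)
    for (unsigned b = 0; b < 256; ++b)
      assert((subus(a, b) == 0) == (a <= b));
  return 0;
}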
static SDValue LowerVSETCC(SDValue Op, const X86Subtarget &Subtarget,
SelectionDAG &DAG) {
SDValue Op0 = Op.getOperand(0);
@@ -17697,23 +18438,10 @@ static SDValue LowerVSETCC(SDValue Op, const X86Subtarget &Subtarget,
assert(VT.getVectorNumElements() == VTOp0.getVectorNumElements() &&
"Invalid number of packed elements for source and destination!");
- if (VT.is128BitVector() && VTOp0.is256BitVector()) {
- // On non-AVX512 targets, a vector of MVT::i1 is promoted by the type
- // legalizer to a wider vector type. In the case of 'vsetcc' nodes, the
- // legalizer firstly checks if the first operand in input to the setcc has
- // a legal type. If so, then it promotes the return type to that same type.
- // Otherwise, the return type is promoted to the 'next legal type' which,
- // for a vector of MVT::i1 is always a 128-bit integer vector type.
- //
- // We reach this code only if the following two conditions are met:
- // 1. Both return type and operand type have been promoted to wider types
- // by the type legalizer.
- // 2. The original operand type has been promoted to a 256-bit vector.
- //
- // Note that condition 2. only applies for AVX targets.
- SDValue NewOp = DAG.getSetCC(dl, VTOp0, Op0, Op1, Cond);
- return DAG.getZExtOrTrunc(NewOp, dl, VT);
- }
+ // This is being called by type legalization because v2i32 is marked custom
+ // for result type legalization for v2f32.
+ if (VTOp0 == MVT::v2i32)
+ return SDValue();
// The non-AVX512 code below works under the assumption that source and
// destination types are the same.
@@ -17724,31 +18452,17 @@ static SDValue LowerVSETCC(SDValue Op, const X86Subtarget &Subtarget,
if (VT.is256BitVector() && !Subtarget.hasInt256())
return Lower256IntVSETCC(Op, DAG);
- // Operands are boolean (vectors of i1)
- MVT OpVT = Op1.getSimpleValueType();
- if (OpVT.getVectorElementType() == MVT::i1)
- return LowerBoolVSETCC_AVX512(Op, DAG);
-
// The result is boolean, but operands are int/float
if (VT.getVectorElementType() == MVT::i1) {
// In AVX-512 architecture setcc returns mask with i1 elements,
// But there is no compare instruction for i8 and i16 elements in KNL.
- // In this case use SSE compare
- bool UseAVX512Inst =
- (OpVT.is512BitVector() ||
- OpVT.getScalarSizeInBits() >= 32 ||
- (Subtarget.hasBWI() && Subtarget.hasVLX()));
-
- if (UseAVX512Inst)
- return LowerIntVSETCC_AVX512(Op, DAG);
-
- return DAG.getNode(ISD::TRUNCATE, dl, VT,
- DAG.getNode(ISD::SETCC, dl, OpVT, Op0, Op1, CC));
+ assert((VTOp0.getScalarSizeInBits() >= 32 || Subtarget.hasBWI()) &&
+ "Unexpected operand type");
+ return LowerIntVSETCC_AVX512(Op, DAG);
}
// Lower using XOP integer comparisons.
- if ((VT == MVT::v16i8 || VT == MVT::v8i16 ||
- VT == MVT::v4i32 || VT == MVT::v2i64) && Subtarget.hasXOP()) {
+ if (VT.is128BitVector() && Subtarget.hasXOP()) {
// Translate compare code to XOP PCOM compare mode.
unsigned CmpMode = 0;
switch (Cond) {
@@ -17791,15 +18505,18 @@ static SDValue LowerVSETCC(SDValue Op, const X86Subtarget &Subtarget,
}
}
- // We are handling one of the integer comparisons here. Since SSE only has
- // GT and EQ comparisons for integer, swapping operands and multiple
- // operations may be required for some comparisons.
- unsigned Opc = (Cond == ISD::SETEQ || Cond == ISD::SETNE) ? X86ISD::PCMPEQ
- : X86ISD::PCMPGT;
- bool Swap = Cond == ISD::SETLT || Cond == ISD::SETULT ||
- Cond == ISD::SETGE || Cond == ISD::SETUGE;
- bool Invert = Cond == ISD::SETNE ||
- (Cond != ISD::SETEQ && ISD::isTrueWhenEqual(Cond));
+ // If this is a SETNE against the signed minimum value, change it to SETGT.
+ // If this is a SETNE against the signed maximum value, change it to SETLT,
+ // which will then be swapped to SETGT.
+ // Otherwise we use PCMPEQ+invert.
+ APInt ConstValue;
+ if (Cond == ISD::SETNE &&
+ ISD::isConstantSplatVector(Op1.getNode(), ConstValue)) {
+ if (ConstValue.isMinSignedValue())
+ Cond = ISD::SETGT;
+ else if (ConstValue.isMaxSignedValue())
+ Cond = ISD::SETLT;
+ }
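A standalone sketch (not part of the patch) of why the splatted-constant rewrite above is sound: for any signed x, x != MIN is the same as x > MIN and x != MAX is the same as x < MAX, which map onto PCMPGT with no extra inversion.

#include <cassert>
#include <cstdint>
#include <limits>

int main() {
  for (int v = -128; v <= 127; ++v) {
    int8_t x = static_cast<int8_t>(v);
    assert((x != std::numeric_limits<int8_t>::min()) ==
           (x > std::numeric_limits<int8_t>::min()));
    assert((x != std::numeric_limits<int8_t>::max()) ==
           (x < std::numeric_limits<int8_t>::max()));
  }
  return 0;
}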
// If both operands are known non-negative, then an unsigned compare is the
// same as a signed compare and there's no need to flip signbits.
@@ -17808,58 +18525,47 @@ static SDValue LowerVSETCC(SDValue Op, const X86Subtarget &Subtarget,
bool FlipSigns = ISD::isUnsignedIntSetCC(Cond) &&
!(DAG.SignBitIsZero(Op0) && DAG.SignBitIsZero(Op1));
- // Special case: Use min/max operations for SETULE/SETUGE
- MVT VET = VT.getVectorElementType();
- bool HasMinMax =
- (Subtarget.hasAVX512() && VET == MVT::i64) ||
- (Subtarget.hasSSE41() && (VET == MVT::i16 || VET == MVT::i32)) ||
- (Subtarget.hasSSE2() && (VET == MVT::i8));
- bool MinMax = false;
- if (HasMinMax) {
+ // Special case: Use min/max operations for unsigned compares. We only want
+ // to do this for unsigned compares if we need to flip signs or if it allows
+ // us to avoid an invert.
+ const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+ if (ISD::isUnsignedIntSetCC(Cond) &&
+ (FlipSigns || ISD::isTrueWhenEqual(Cond)) &&
+ TLI.isOperationLegal(ISD::UMIN, VT)) {
+ bool Invert = false;
+ unsigned Opc;
switch (Cond) {
- default: break;
- case ISD::SETULE: Opc = ISD::UMIN; MinMax = true; break;
- case ISD::SETUGE: Opc = ISD::UMAX; MinMax = true; break;
+ default: llvm_unreachable("Unexpected condition code");
+ case ISD::SETUGT: Invert = true; LLVM_FALLTHROUGH;
+ case ISD::SETULE: Opc = ISD::UMIN; break;
+ case ISD::SETULT: Invert = true; LLVM_FALLTHROUGH;
+ case ISD::SETUGE: Opc = ISD::UMAX; break;
}
- if (MinMax)
- Swap = Invert = FlipSigns = false;
- }
+ SDValue Result = DAG.getNode(Opc, dl, VT, Op0, Op1);
+ Result = DAG.getNode(X86ISD::PCMPEQ, dl, VT, Op0, Result);
- bool HasSubus = Subtarget.hasSSE2() && (VET == MVT::i8 || VET == MVT::i16);
- bool Subus = false;
- if (!MinMax && HasSubus) {
- // As another special case, use PSUBUS[BW] when it's profitable. E.g. for
- // Op0 u<= Op1:
- // t = psubus Op0, Op1
- // pcmpeq t, <0..0>
- switch (Cond) {
- default: break;
- case ISD::SETULT: {
- // If the comparison is against a constant we can turn this into a
- // setule. With psubus, setule does not require a swap. This is
- // beneficial because the constant in the register is no longer
- // destructed as the destination so it can be hoisted out of a loop.
- // Only do this pre-AVX since vpcmp* is no longer destructive.
- if (Subtarget.hasAVX())
- break;
- if (SDValue ULEOp1 = ChangeVSETULTtoVSETULE(dl, Op1, DAG)) {
- Op1 = ULEOp1;
- Subus = true; Invert = false; Swap = false;
- }
- break;
- }
- // Psubus is better than flip-sign because it requires no inversion.
- case ISD::SETUGE: Subus = true; Invert = false; Swap = true; break;
- case ISD::SETULE: Subus = true; Invert = false; Swap = false; break;
- }
+ // If the logical-not of the result is required, perform that now.
+ if (Invert)
+ Result = DAG.getNOT(dl, Result, VT);
- if (Subus) {
- Opc = X86ISD::SUBUS;
- FlipSigns = false;
- }
+ return Result;
}
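A standalone sketch (not part of the patch) of the min/max rewrite above, checked on one unsigned lane: a u<= b iff umin(a, b) == a, a u>= b iff umax(a, b) == a, and the strict forms are the logical NOT of the opposite test, which is what the Invert flag produces.

#include <algorithm>
#include <cassert>
#include <cstdint>

int main() {
  for (unsigned a = 0; a < 256; ++a)
    for (unsigned b = 0; b < 256; ++b) {
      uint8_t A = a, B = b;
      assert((std::min(A, B) == A) == (A <= B));   // SETULE via UMIN
      assert((std::max(A, B) == A) == (A >= B));   // SETUGE via UMAX
      assert(!(std::min(A, B) == A) == (A > B));   // SETUGT: UMIN + invert
      assert(!(std::max(A, B) == A) == (A < B));   // SETULT: UMAX + invert
    }
  return 0;
}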
+ // Try to use SUBUS and PCMPEQ.
+ if (SDValue V = LowerVSETCCWithSUBUS(Op0, Op1, VT, Cond, dl, Subtarget, DAG))
+ return V;
+
+ // We are handling one of the integer comparisons here. Since SSE only has
+ // GT and EQ comparisons for integer, swapping operands and multiple
+ // operations may be required for some comparisons.
+ unsigned Opc = (Cond == ISD::SETEQ || Cond == ISD::SETNE) ? X86ISD::PCMPEQ
+ : X86ISD::PCMPGT;
+ bool Swap = Cond == ISD::SETLT || Cond == ISD::SETULT ||
+ Cond == ISD::SETGE || Cond == ISD::SETUGE;
+ bool Invert = Cond == ISD::SETNE ||
+ (Cond != ISD::SETEQ && ISD::isTrueWhenEqual(Cond));
+
if (Swap)
std::swap(Op0, Op1);
@@ -17947,14 +18653,47 @@ static SDValue LowerVSETCC(SDValue Op, const X86Subtarget &Subtarget,
if (Invert)
Result = DAG.getNOT(dl, Result, VT);
- if (MinMax)
- Result = DAG.getNode(X86ISD::PCMPEQ, dl, VT, Op0, Result);
+ return Result;
+}
- if (Subus)
- Result = DAG.getNode(X86ISD::PCMPEQ, dl, VT, Result,
- getZeroVector(VT, Subtarget, DAG, dl));
+// Try to select this as a KTEST+SETCC if possible.
+static SDValue EmitKTEST(SDValue Op0, SDValue Op1, ISD::CondCode CC,
+ const SDLoc &dl, SelectionDAG &DAG,
+ const X86Subtarget &Subtarget) {
+ // Only support equality comparisons.
+ if (CC != ISD::SETEQ && CC != ISD::SETNE)
+ return SDValue();
- return Result;
+ // Must be a bitcast from vXi1.
+ if (Op0.getOpcode() != ISD::BITCAST)
+ return SDValue();
+
+ Op0 = Op0.getOperand(0);
+ MVT VT = Op0.getSimpleValueType();
+ if (!(Subtarget.hasAVX512() && VT == MVT::v16i1) &&
+ !(Subtarget.hasDQI() && VT == MVT::v8i1) &&
+ !(Subtarget.hasBWI() && (VT == MVT::v32i1 || VT == MVT::v64i1)))
+ return SDValue();
+
+ X86::CondCode X86CC;
+ if (isNullConstant(Op1)) {
+ X86CC = CC == ISD::SETEQ ? X86::COND_E : X86::COND_NE;
+ } else if (isAllOnesConstant(Op1)) {
+ // C flag is set for all ones.
+ X86CC = CC == ISD::SETEQ ? X86::COND_B : X86::COND_AE;
+ } else
+ return SDValue();
+
+ // If the input is an OR, we can combine its operands into the KORTEST.
+ SDValue LHS = Op0;
+ SDValue RHS = Op0;
+ if (Op0.getOpcode() == ISD::OR && Op0.hasOneUse()) {
+ LHS = Op0.getOperand(0);
+ RHS = Op0.getOperand(1);
+ }
+
+ SDValue KORTEST = DAG.getNode(X86ISD::KORTEST, dl, MVT::i32, LHS, RHS);
+ return getSETCC(X86CC, KORTEST, dl, DAG);
}
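A standalone sketch (not part of the patch) of what the KORTEST-based lowering computes, modeled on a 16-bit mask: KORTESTW sets ZF when LHS | RHS is all zero and CF when it is all ones, so equality of a bitcast v16i1 with 0 or -1 needs only a flag read, and an OR feeding the compare can supply both operands directly.

#include <cassert>
#include <cstdint>

struct Flags { bool ZF, CF; };

static Flags kortestw(uint16_t LHS, uint16_t RHS) {
  uint16_t Or = LHS | RHS;
  return {Or == 0, Or == 0xFFFF};
}

int main() {
  uint16_t Mask = 0x00F0;
  assert(kortestw(Mask, Mask).ZF == (Mask == 0));      // x == 0  -> COND_E
  assert(kortestw(Mask, Mask).CF == (Mask == 0xFFFF)); // x == -1 -> COND_B
  // An OR feeding the compare can supply both KORTEST operands directly.
  uint16_t A = 0x0001, B = 0x0000;
  assert(kortestw(A, B).ZF == ((A | B) == 0));
  return 0;
}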
SDValue X86TargetLowering::LowerSETCC(SDValue Op, SelectionDAG &DAG) const {
@@ -17979,6 +18718,18 @@ SDValue X86TargetLowering::LowerSETCC(SDValue Op, SelectionDAG &DAG) const {
return NewSetCC;
}
+ // Try to use PTEST for a tree ORs equality compared with 0.
+ // TODO: We could do AND tree with all 1s as well by using the C flag.
+ if (Op0.getOpcode() == ISD::OR && isNullConstant(Op1) &&
+ (CC == ISD::SETEQ || CC == ISD::SETNE)) {
+ if (SDValue NewSetCC = LowerVectorAllZeroTest(Op0, CC, Subtarget, DAG))
+ return NewSetCC;
+ }
+
+ // Try to lower using KTEST.
+ if (SDValue NewSetCC = EmitKTEST(Op0, Op1, CC, dl, DAG, Subtarget))
+ return NewSetCC;
+
// Look for X == 0, X == 1, X != 0, or X != 1. We can simplify some forms of
// these.
if ((isOneConstant(Op1) || isNullConstant(Op1)) &&
@@ -18070,7 +18821,7 @@ SDValue X86TargetLowering::LowerSELECT(SDValue Op, SelectionDAG &DAG) const {
// are available or VBLENDV if AVX is available.
// Otherwise FP cmovs get lowered into a less efficient branch sequence later.
if (Cond.getOpcode() == ISD::SETCC &&
- ((Subtarget.hasSSE2() && (VT == MVT::f32 || VT == MVT::f64)) ||
+ ((Subtarget.hasSSE2() && VT == MVT::f64) ||
(Subtarget.hasSSE1() && VT == MVT::f32)) &&
VT == Cond.getOperand(0).getSimpleValueType() && Cond->hasOneUse()) {
SDValue CondOp0 = Cond.getOperand(0), CondOp1 = Cond.getOperand(1);
@@ -18132,6 +18883,18 @@ SDValue X86TargetLowering::LowerSELECT(SDValue Op, SelectionDAG &DAG) const {
return DAG.getNode(X86ISD::SELECTS, DL, VT, Cmp, Op1, Op2);
}
+ // For v64i1 without 64-bit support we need to split and rejoin.
+ if (VT == MVT::v64i1 && !Subtarget.is64Bit()) {
+ assert(Subtarget.hasBWI() && "Expected BWI to be legal");
+ SDValue Op1Lo = extractSubVector(Op1, 0, DAG, DL, 32);
+ SDValue Op2Lo = extractSubVector(Op2, 0, DAG, DL, 32);
+ SDValue Op1Hi = extractSubVector(Op1, 32, DAG, DL, 32);
+ SDValue Op2Hi = extractSubVector(Op2, 32, DAG, DL, 32);
+ SDValue Lo = DAG.getSelect(DL, MVT::v32i1, Cond, Op1Lo, Op2Lo);
+ SDValue Hi = DAG.getSelect(DL, MVT::v32i1, Cond, Op1Hi, Op2Hi);
+ return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, Lo, Hi);
+ }
+
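A standalone sketch (not part of the patch) of the v64i1 split above, modeled with a 64-bit mask value: selecting each 32-bit half under the same condition and re-concatenating gives the same result as a single 64-bit select.

#include <cassert>
#include <cstdint>

static uint64_t select64(bool Cond, uint64_t T, uint64_t F) {
  uint32_t Lo = Cond ? static_cast<uint32_t>(T) : static_cast<uint32_t>(F);
  uint32_t Hi = Cond ? static_cast<uint32_t>(T >> 32)
                     : static_cast<uint32_t>(F >> 32);
  return (static_cast<uint64_t>(Hi) << 32) | Lo;  // concat the two halves
}

int main() {
  uint64_t T = 0x0123456789ABCDEFULL, F = 0xFEDCBA9876543210ULL;
  assert(select64(true, T, F) == T);
  assert(select64(false, T, F) == F);
  return 0;
}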
if (VT.isVector() && VT.getVectorElementType() == MVT::i1) {
SDValue Op1Scalar;
if (ISD::isBuildVectorOfConstantSDNodes(Op1.getNode()))
@@ -18379,6 +19142,15 @@ SDValue X86TargetLowering::LowerSELECT(SDValue Op, SelectionDAG &DAG) const {
}
}
+ // Promote i16 cmovs if it won't prevent folding a load.
+ if (Op.getValueType() == MVT::i16 && !MayFoldLoad(Op1) && !MayFoldLoad(Op2)) {
+ Op1 = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i32, Op1);
+ Op2 = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i32, Op2);
+ SDValue Ops[] = { Op2, Op1, CC, Cond };
+ SDValue Cmov = DAG.getNode(X86ISD::CMOV, DL, MVT::i32, Ops);
+ return DAG.getNode(ISD::TRUNCATE, DL, Op.getValueType(), Cmov);
+ }
+
// X86ISD::CMOV means set the result (which is operand 1) to the RHS if
// condition is true.
SDValue Ops[] = { Op2, Op1, CC, Cond };
@@ -18399,8 +19171,13 @@ static SDValue LowerSIGN_EXTEND_Mask(SDValue Op,
// Extend VT if the scalar type is v8/v16 and BWI is not supported.
MVT ExtVT = VT;
- if (!Subtarget.hasBWI() && VTElt.getSizeInBits() <= 16)
+ if (!Subtarget.hasBWI() && VTElt.getSizeInBits() <= 16) {
+ // If v16i32 is to be avoided, we'll need to split and concatenate.
+ if (NumElts == 16 && !Subtarget.canExtendTo512DQ())
+ return SplitAndExtendv16i1(ISD::SIGN_EXTEND, VT, In, dl, DAG);
+
ExtVT = MVT::getVectorVT(MVT::i32, NumElts);
+ }
// Widen to 512-bits if VLX is not supported.
MVT WideVT = ExtVT;
@@ -18416,7 +19193,7 @@ static SDValue LowerSIGN_EXTEND_Mask(SDValue Op,
MVT WideEltVT = WideVT.getVectorElementType();
if ((Subtarget.hasDQI() && WideEltVT.getSizeInBits() >= 32) ||
(Subtarget.hasBWI() && WideEltVT.getSizeInBits() <= 16)) {
- V = getExtendInVec(X86ISD::VSEXT, dl, WideVT, In, DAG);
+ V = DAG.getNode(ISD::SIGN_EXTEND, dl, WideVT, In);
} else {
SDValue NegOne = getOnesVector(WideVT, DAG, dl);
SDValue Zero = getZeroVector(WideVT, Subtarget, DAG, dl);
@@ -18445,11 +19222,8 @@ static SDValue LowerANY_EXTEND(SDValue Op, const X86Subtarget &Subtarget,
if (InVT.getVectorElementType() == MVT::i1)
return LowerSIGN_EXTEND_Mask(Op, Subtarget, DAG);
- if (Subtarget.hasFp256())
- if (SDValue Res = LowerAVXExtend(Op, DAG, Subtarget))
- return Res;
-
- return SDValue();
+ assert(Subtarget.hasAVX() && "Expected AVX support");
+ return LowerAVXExtend(Op, DAG, Subtarget);
}
// Lowering for SIGN_EXTEND_VECTOR_INREG and ZERO_EXTEND_VECTOR_INREG.
@@ -18549,15 +19323,17 @@ static SDValue LowerSIGN_EXTEND(SDValue Op, const X86Subtarget &Subtarget,
if (InVT.getVectorElementType() == MVT::i1)
return LowerSIGN_EXTEND_Mask(Op, Subtarget, DAG);
- if ((VT != MVT::v4i64 || InVT != MVT::v4i32) &&
- (VT != MVT::v8i32 || InVT != MVT::v8i16) &&
- (VT != MVT::v16i16 || InVT != MVT::v16i8) &&
- (VT != MVT::v8i64 || InVT != MVT::v8i32) &&
- (VT != MVT::v8i64 || InVT != MVT::v8i16) &&
- (VT != MVT::v16i32 || InVT != MVT::v16i16) &&
- (VT != MVT::v16i32 || InVT != MVT::v16i8) &&
- (VT != MVT::v32i16 || InVT != MVT::v32i8))
- return SDValue();
+ assert(VT.isVector() && InVT.isVector() && "Expected vector type");
+ assert(VT.getVectorNumElements() == InVT.getVectorNumElements() &&
+ "Expected same number of elements");
+ assert((VT.getVectorElementType() == MVT::i16 ||
+ VT.getVectorElementType() == MVT::i32 ||
+ VT.getVectorElementType() == MVT::i64) &&
+ "Unexpected element type");
+ assert((InVT.getVectorElementType() == MVT::i8 ||
+ InVT.getVectorElementType() == MVT::i16 ||
+ InVT.getVectorElementType() == MVT::i32) &&
+ "Unexpected element type");
if (Subtarget.hasInt256())
return DAG.getNode(X86ISD::VSEXT, dl, VT, In);
@@ -18595,164 +19371,29 @@ static SDValue LowerSIGN_EXTEND(SDValue Op, const X86Subtarget &Subtarget,
return DAG.getNode(ISD::CONCAT_VECTORS, dl, VT, OpLo, OpHi);
}
-// Lower truncating store. We need a special lowering to vXi1 vectors
-static SDValue LowerTruncatingStore(SDValue StOp, const X86Subtarget &Subtarget,
- SelectionDAG &DAG) {
- StoreSDNode *St = cast<StoreSDNode>(StOp.getNode());
+static SDValue LowerStore(SDValue Op, const X86Subtarget &Subtarget,
+ SelectionDAG &DAG) {
+ StoreSDNode *St = cast<StoreSDNode>(Op.getNode());
SDLoc dl(St);
- EVT MemVT = St->getMemoryVT();
- assert(St->isTruncatingStore() && "We only custom truncating store.");
- assert(MemVT.isVector() && MemVT.getVectorElementType() == MVT::i1 &&
- "Expected truncstore of i1 vector");
-
- SDValue Op = St->getValue();
- MVT OpVT = Op.getValueType().getSimpleVT();
- unsigned NumElts = OpVT.getVectorNumElements();
- if ((Subtarget.hasVLX() && Subtarget.hasBWI() && Subtarget.hasDQI()) ||
- NumElts == 16) {
- // Truncate and store - everything is legal
- Op = DAG.getNode(ISD::TRUNCATE, dl, MemVT, Op);
- if (MemVT.getSizeInBits() < 8)
- Op = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, MVT::v8i1,
- DAG.getUNDEF(MVT::v8i1), Op,
- DAG.getIntPtrConstant(0, dl));
- return DAG.getStore(St->getChain(), dl, Op, St->getBasePtr(),
- St->getMemOperand());
- }
-
- // A subset, assume that we have only AVX-512F
- if (NumElts <= 8) {
- if (NumElts < 8) {
- // Extend to 8-elts vector
- MVT ExtVT = MVT::getVectorVT(OpVT.getScalarType(), 8);
- Op = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, ExtVT,
- DAG.getUNDEF(ExtVT), Op, DAG.getIntPtrConstant(0, dl));
- }
- Op = DAG.getNode(ISD::TRUNCATE, dl, MVT::v8i1, Op);
- return DAG.getStore(St->getChain(), dl, Op, St->getBasePtr(),
- St->getMemOperand());
- }
- // v32i8
- assert(OpVT == MVT::v32i8 && "Unexpected operand type");
- // Divide the vector into 2 parts and store each part separately
- SDValue Lo = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, MVT::v16i8, Op,
- DAG.getIntPtrConstant(0, dl));
- Lo = DAG.getNode(ISD::TRUNCATE, dl, MVT::v16i1, Lo);
- SDValue BasePtr = St->getBasePtr();
- SDValue StLo = DAG.getStore(St->getChain(), dl, Lo, BasePtr,
- St->getMemOperand());
- SDValue Hi = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, MVT::v16i8, Op,
- DAG.getIntPtrConstant(16, dl));
- Hi = DAG.getNode(ISD::TRUNCATE, dl, MVT::v16i1, Hi);
-
- SDValue BasePtrHi =
- DAG.getNode(ISD::ADD, dl, BasePtr.getValueType(), BasePtr,
- DAG.getConstant(2, dl, BasePtr.getValueType()));
-
- SDValue StHi = DAG.getStore(St->getChain(), dl, Hi,
- BasePtrHi, St->getMemOperand());
- return DAG.getNode(ISD::TokenFactor, dl, MVT::Other, StLo, StHi);
-}
-
-static SDValue LowerExtended1BitVectorLoad(SDValue Op,
- const X86Subtarget &Subtarget,
- SelectionDAG &DAG) {
-
- LoadSDNode *Ld = cast<LoadSDNode>(Op.getNode());
- SDLoc dl(Ld);
- EVT MemVT = Ld->getMemoryVT();
- assert(MemVT.isVector() && MemVT.getScalarType() == MVT::i1 &&
- "Expected i1 vector load");
- unsigned ExtOpcode = Ld->getExtensionType() == ISD::ZEXTLOAD ?
- ISD::ZERO_EXTEND : ISD::SIGN_EXTEND;
- MVT VT = Op.getValueType().getSimpleVT();
- unsigned NumElts = VT.getVectorNumElements();
-
- if ((Subtarget.hasBWI() && NumElts >= 32) ||
- (Subtarget.hasDQI() && NumElts < 16) ||
- NumElts == 16) {
- // Load and extend - everything is legal
- if (NumElts < 8) {
- SDValue Load = DAG.getLoad(MVT::v8i1, dl, Ld->getChain(),
- Ld->getBasePtr(),
- Ld->getMemOperand());
- // Replace chain users with the new chain.
- assert(Load->getNumValues() == 2 && "Loads must carry a chain!");
- DAG.ReplaceAllUsesOfValueWith(SDValue(Ld, 1), Load.getValue(1));
- if (Subtarget.hasVLX()) {
- // Extract to v4i1/v2i1.
- SDValue Extract = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, MemVT, Load,
- DAG.getIntPtrConstant(0, dl));
- // Finally, do a normal sign-extend to the desired register.
- return DAG.getNode(ExtOpcode, dl, Op.getValueType(), Extract);
- }
-
- MVT ExtVT = MVT::getVectorVT(VT.getScalarType(), 8);
- SDValue ExtVec = DAG.getNode(ExtOpcode, dl, ExtVT, Load);
-
- return DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, VT, ExtVec,
- DAG.getIntPtrConstant(0, dl));
- }
- SDValue Load = DAG.getLoad(MemVT, dl, Ld->getChain(),
- Ld->getBasePtr(),
- Ld->getMemOperand());
- // Replace chain users with the new chain.
- assert(Load->getNumValues() == 2 && "Loads must carry a chain!");
- DAG.ReplaceAllUsesOfValueWith(SDValue(Ld, 1), Load.getValue(1));
-
- // Finally, do a normal sign-extend to the desired register.
- return DAG.getNode(ExtOpcode, dl, Op.getValueType(), Load);
- }
-
- if (NumElts <= 8) {
- // A subset, assume that we have only AVX-512F
- SDValue Load = DAG.getLoad(MVT::i8, dl, Ld->getChain(),
- Ld->getBasePtr(),
- Ld->getMemOperand());
- // Replace chain users with the new chain.
- assert(Load->getNumValues() == 2 && "Loads must carry a chain!");
- DAG.ReplaceAllUsesOfValueWith(SDValue(Ld, 1), Load.getValue(1));
-
- SDValue BitVec = DAG.getBitcast(MVT::v8i1, Load);
-
- if (NumElts == 8)
- return DAG.getNode(ExtOpcode, dl, VT, BitVec);
-
- if (Subtarget.hasVLX()) {
- // Extract to v4i1/v2i1.
- SDValue Extract = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, MemVT, BitVec,
- DAG.getIntPtrConstant(0, dl));
- // Finally, do a normal sign-extend to the desired register.
- return DAG.getNode(ExtOpcode, dl, Op.getValueType(), Extract);
- }
-
- MVT ExtVT = MVT::getVectorVT(VT.getScalarType(), 8);
- SDValue ExtVec = DAG.getNode(ExtOpcode, dl, ExtVT, BitVec);
- return DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, VT, ExtVec,
- DAG.getIntPtrConstant(0, dl));
- }
-
- assert(VT == MVT::v32i8 && "Unexpected extload type");
-
- SDValue BasePtr = Ld->getBasePtr();
- SDValue LoadLo = DAG.getLoad(MVT::v16i1, dl, Ld->getChain(),
- Ld->getBasePtr(),
- Ld->getMemOperand());
-
- SDValue BasePtrHi = DAG.getMemBasePlusOffset(BasePtr, 2, dl);
-
- SDValue LoadHi = DAG.getLoad(MVT::v16i1, dl, Ld->getChain(), BasePtrHi,
- Ld->getPointerInfo().getWithOffset(2),
- MinAlign(Ld->getAlignment(), 2U),
- Ld->getMemOperand()->getFlags());
-
- SDValue NewChain = DAG.getNode(ISD::TokenFactor, dl, MVT::Other,
- LoadLo.getValue(1), LoadHi.getValue(1));
- DAG.ReplaceAllUsesOfValueWith(SDValue(Ld, 1), NewChain);
+ SDValue StoredVal = St->getValue();
+
+ // Without AVX512DQ, we need to use a scalar type for v2i1/v4i1/v8i1 stores.
+ assert(StoredVal.getValueType().isVector() &&
+ StoredVal.getValueType().getVectorElementType() == MVT::i1 &&
+ StoredVal.getValueType().getVectorNumElements() <= 8 &&
+ "Unexpected VT");
+ assert(!St->isTruncatingStore() && "Expected non-truncating store");
+ assert(Subtarget.hasAVX512() && !Subtarget.hasDQI() &&
+ "Expected AVX512F without AVX512DQI");
+
+ StoredVal = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, MVT::v8i1,
+ DAG.getUNDEF(MVT::v8i1), StoredVal,
+ DAG.getIntPtrConstant(0, dl));
+ StoredVal = DAG.getBitcast(MVT::i8, StoredVal);
- SDValue Lo = DAG.getNode(ExtOpcode, dl, MVT::v16i8, LoadLo);
- SDValue Hi = DAG.getNode(ExtOpcode, dl, MVT::v16i8, LoadHi);
- return DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v32i8, Lo, Hi);
+ return DAG.getStore(St->getChain(), dl, StoredVal, St->getBasePtr(),
+ St->getPointerInfo(), St->getAlignment(),
+ St->getMemOperand()->getFlags());
}
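A standalone sketch (not part of the patch) of the shape of this store lowering, modeled with plain bools: the short mask is widened to 8 lanes, reinterpreted as a byte, and stored as a scalar. The padding lanes are undef in the DAG; they are zeroed here only to keep the example deterministic.

#include <cassert>
#include <cstdint>

static uint8_t packMask(const bool *Bits, unsigned NumElts) {
  uint8_t Byte = 0;
  for (unsigned i = 0; i < 8; ++i)                // widen to 8 lanes
    if (i < NumElts && Bits[i])
      Byte |= static_cast<uint8_t>(1u << i);      // bitcast v8i1 -> i8
  return Byte;                                    // stored as a single byte
}

int main() {
  bool V4[4] = {true, false, true, true};
  assert(packMask(V4, 4) == 0x0D);
  return 0;
}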
// Lower vector extended loads using a shuffle. If SSSE3 is not available we
@@ -18762,21 +19403,40 @@ static SDValue LowerExtended1BitVectorLoad(SDValue Op,
// FIXME: Is the expansion actually better than scalar code? It doesn't seem so.
// TODO: It is possible to support ZExt by zeroing the undef values during
// the shuffle phase or after the shuffle.
-static SDValue LowerExtendedLoad(SDValue Op, const X86Subtarget &Subtarget,
+static SDValue LowerLoad(SDValue Op, const X86Subtarget &Subtarget,
SelectionDAG &DAG) {
MVT RegVT = Op.getSimpleValueType();
assert(RegVT.isVector() && "We only custom lower vector sext loads.");
assert(RegVT.isInteger() &&
"We only custom lower integer vector sext loads.");
- // Nothing useful we can do without SSE2 shuffles.
- assert(Subtarget.hasSSE2() && "We only custom lower sext loads with SSE2.");
-
LoadSDNode *Ld = cast<LoadSDNode>(Op.getNode());
SDLoc dl(Ld);
EVT MemVT = Ld->getMemoryVT();
- if (MemVT.getScalarType() == MVT::i1)
- return LowerExtended1BitVectorLoad(Op, Subtarget, DAG);
+
+ // Without AVX512DQ, we need to use a scalar type for v2i1/v4i1/v8i1 loads.
+ if (RegVT.isVector() && RegVT.getVectorElementType() == MVT::i1) {
+ assert(EVT(RegVT) == MemVT && "Expected non-extending load");
+ assert(RegVT.getVectorNumElements() <= 8 && "Unexpected VT");
+ assert(Subtarget.hasAVX512() && !Subtarget.hasDQI() &&
+ "Expected AVX512F without AVX512DQI");
+
+ SDValue NewLd = DAG.getLoad(MVT::i8, dl, Ld->getChain(), Ld->getBasePtr(),
+ Ld->getPointerInfo(), Ld->getAlignment(),
+ Ld->getMemOperand()->getFlags());
+
+ // Replace chain users with the new chain.
+ assert(NewLd->getNumValues() == 2 && "Loads must carry a chain!");
+ DAG.ReplaceAllUsesOfValueWith(SDValue(Ld, 1), NewLd.getValue(1));
+
+ SDValue Extract = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, RegVT,
+ DAG.getBitcast(MVT::v8i1, NewLd),
+ DAG.getIntPtrConstant(0, dl));
+ return DAG.getMergeValues({Extract, NewLd.getValue(1)}, dl);
+ }
+
+ // Nothing useful we can do without SSE2 shuffles.
+ assert(Subtarget.hasSSE2() && "We only custom lower sext loads with SSE2.");
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
unsigned RegSz = RegVT.getSizeInBits();
@@ -19619,7 +20279,7 @@ static SDValue getTargetVShiftNode(unsigned Opc, const SDLoc &dl, MVT VT,
return DAG.getNode(Opc, dl, VT, SrcOp, ShAmt);
}
-/// \brief Return Mask with the necessary casting or extending
+/// Return Mask with the necessary casting or extending
/// for \p Mask according to \p MaskVT when lowering masking intrinsics
static SDValue getMaskNode(SDValue Mask, MVT MaskVT,
const X86Subtarget &Subtarget, SelectionDAG &DAG,
@@ -19637,27 +20297,19 @@ static SDValue getMaskNode(SDValue Mask, MVT MaskVT,
}
if (Mask.getSimpleValueType() == MVT::i64 && Subtarget.is32Bit()) {
- if (MaskVT == MVT::v64i1) {
- assert(Subtarget.hasBWI() && "Expected AVX512BW target!");
- // In case 32bit mode, bitcast i64 is illegal, extend/split it.
- SDValue Lo, Hi;
- Lo = DAG.getNode(ISD::EXTRACT_ELEMENT, dl, MVT::i32, Mask,
- DAG.getConstant(0, dl, MVT::i32));
- Hi = DAG.getNode(ISD::EXTRACT_ELEMENT, dl, MVT::i32, Mask,
- DAG.getConstant(1, dl, MVT::i32));
-
- Lo = DAG.getBitcast(MVT::v32i1, Lo);
- Hi = DAG.getBitcast(MVT::v32i1, Hi);
-
- return DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v64i1, Lo, Hi);
- } else {
- // MaskVT require < 64bit. Truncate mask (should succeed in any case),
- // and bitcast.
- MVT TruncVT = MVT::getIntegerVT(MaskVT.getSizeInBits());
- return DAG.getBitcast(MaskVT,
- DAG.getNode(ISD::TRUNCATE, dl, TruncVT, Mask));
- }
-
+ assert(MaskVT == MVT::v64i1 && "Expected v64i1 mask!");
+ assert(Subtarget.hasBWI() && "Expected AVX512BW target!");
+ // In 32-bit mode a bitcast of i64 is illegal; extend/split it instead.
+ SDValue Lo, Hi;
+ Lo = DAG.getNode(ISD::EXTRACT_ELEMENT, dl, MVT::i32, Mask,
+ DAG.getConstant(0, dl, MVT::i32));
+ Hi = DAG.getNode(ISD::EXTRACT_ELEMENT, dl, MVT::i32, Mask,
+ DAG.getConstant(1, dl, MVT::i32));
+
+ Lo = DAG.getBitcast(MVT::v32i1, Lo);
+ Hi = DAG.getBitcast(MVT::v32i1, Hi);
+
+ return DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v64i1, Lo, Hi);
} else {
MVT BitcastVT = MVT::getVectorVT(MVT::i1,
Mask.getSimpleValueType().getSizeInBits());
@@ -19669,7 +20321,7 @@ static SDValue getMaskNode(SDValue Mask, MVT MaskVT,
}
}
-/// \brief Return (and \p Op, \p Mask) for compare instructions or
+/// Return (and \p Op, \p Mask) for compare instructions or
/// (vselect \p Mask, \p Op, \p PreservedSrc) for others along with the
/// necessary casting or extending for \p Mask when lowering masking intrinsics
static SDValue getVectorMaskingNode(SDValue Op, SDValue Mask,
@@ -19690,11 +20342,10 @@ static SDValue getVectorMaskingNode(SDValue Op, SDValue Mask,
default: break;
case X86ISD::CMPM:
case X86ISD::CMPM_RND:
- case X86ISD::CMPMU:
case X86ISD::VPSHUFBITQMB:
- return DAG.getNode(ISD::AND, dl, VT, Op, VMask);
case X86ISD::VFPCLASS:
- return DAG.getNode(ISD::OR, dl, VT, Op, VMask);
+ return DAG.getNode(ISD::AND, dl, VT, Op, VMask);
+ case ISD::TRUNCATE:
case X86ISD::VTRUNC:
case X86ISD::VTRUNCS:
case X86ISD::VTRUNCUS:
@@ -19710,7 +20361,7 @@ static SDValue getVectorMaskingNode(SDValue Op, SDValue Mask,
return DAG.getNode(OpcodeSelect, dl, VT, VMask, Op, PreservedSrc);
}
-/// \brief Creates an SDNode for a predicated scalar operation.
+/// Creates an SDNode for a predicated scalar operation.
/// \returns (X86vselect \p Mask, \p Op, \p PreservedSrc).
/// The mask is coming as MVT::i8 and it should be transformed
/// to MVT::v1i1 while lowering masking intrinsics.
@@ -19729,12 +20380,12 @@ static SDValue getScalarMaskingNode(SDValue Op, SDValue Mask,
MVT VT = Op.getSimpleValueType();
SDLoc dl(Op);
+ assert(Mask.getValueType() == MVT::i8 && "Unexpected type");
SDValue IMask = DAG.getNode(ISD::SCALAR_TO_VECTOR, dl, MVT::v1i1, Mask);
if (Op.getOpcode() == X86ISD::FSETCCM ||
- Op.getOpcode() == X86ISD::FSETCCM_RND)
+ Op.getOpcode() == X86ISD::FSETCCM_RND ||
+ Op.getOpcode() == X86ISD::VFPCLASSS)
return DAG.getNode(ISD::AND, dl, VT, Op, IMask);
- if (Op.getOpcode() == X86ISD::VFPCLASSS)
- return DAG.getNode(ISD::OR, dl, VT, Op, IMask);
if (PreservedSrc.isUndef())
PreservedSrc = getZeroVector(VT, Subtarget, DAG, dl);
@@ -19819,14 +20470,67 @@ SDValue X86TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
const IntrinsicData* IntrData = getIntrinsicWithoutChain(IntNo);
if (IntrData) {
switch(IntrData->Type) {
- case INTR_TYPE_1OP:
+ case INTR_TYPE_1OP: {
+ // We specify 2 possible opcodes for intrinsics with rounding modes.
+ // First, we check if the intrinsic may have non-default rounding mode,
+ // (IntrData->Opc1 != 0), then we check the rounding mode operand.
+ unsigned IntrWithRoundingModeOpcode = IntrData->Opc1;
+ if (IntrWithRoundingModeOpcode != 0) {
+ SDValue Rnd = Op.getOperand(2);
+ if (!isRoundModeCurDirection(Rnd)) {
+ return DAG.getNode(IntrWithRoundingModeOpcode, dl, Op.getValueType(),
+ Op.getOperand(1), Rnd);
+ }
+ }
return DAG.getNode(IntrData->Opc0, dl, Op.getValueType(), Op.getOperand(1));
+ }
case INTR_TYPE_2OP:
- return DAG.getNode(IntrData->Opc0, dl, Op.getValueType(), Op.getOperand(1),
- Op.getOperand(2));
+ case INTR_TYPE_2OP_IMM8: {
+ SDValue Src2 = Op.getOperand(2);
+
+ if (IntrData->Type == INTR_TYPE_2OP_IMM8)
+ Src2 = DAG.getNode(ISD::TRUNCATE, dl, MVT::i8, Src2);
+
+ // We specify 2 possible opcodes for intrinsics with rounding modes.
+ // First, we check if the intrinsic may have non-default rounding mode,
+ // (IntrData->Opc1 != 0), then we check the rounding mode operand.
+ unsigned IntrWithRoundingModeOpcode = IntrData->Opc1;
+ if (IntrWithRoundingModeOpcode != 0) {
+ SDValue Rnd = Op.getOperand(3);
+ if (!isRoundModeCurDirection(Rnd)) {
+ return DAG.getNode(IntrWithRoundingModeOpcode, dl, Op.getValueType(),
+ Op.getOperand(1), Src2, Rnd);
+ }
+ }
+
+ return DAG.getNode(IntrData->Opc0, dl, Op.getValueType(),
+ Op.getOperand(1), Src2);
+ }
case INTR_TYPE_3OP:
- return DAG.getNode(IntrData->Opc0, dl, Op.getValueType(), Op.getOperand(1),
- Op.getOperand(2), Op.getOperand(3));
+ case INTR_TYPE_3OP_IMM8: {
+ SDValue Src1 = Op.getOperand(1);
+ SDValue Src2 = Op.getOperand(2);
+ SDValue Src3 = Op.getOperand(3);
+
+ if (IntrData->Type == INTR_TYPE_3OP_IMM8)
+ Src3 = DAG.getNode(ISD::TRUNCATE, dl, MVT::i8, Src3);
+
+ // We specify 2 possible opcodes for intrinsics with rounding modes.
+ // First, we check if the intrinsic may have non-default rounding mode,
+ // (IntrData->Opc1 != 0), then we check the rounding mode operand.
+ unsigned IntrWithRoundingModeOpcode = IntrData->Opc1;
+ if (IntrWithRoundingModeOpcode != 0) {
+ SDValue Rnd = Op.getOperand(4);
+ if (!isRoundModeCurDirection(Rnd)) {
+ return DAG.getNode(IntrWithRoundingModeOpcode,
+ dl, Op.getValueType(),
+ Src1, Src2, Src3, Rnd);
+ }
+ }
+
+ return DAG.getNode(IntrData->Opc0, dl, Op.getValueType(),
+ Src1, Src2, Src3);
+ }
case INTR_TYPE_4OP:
return DAG.getNode(IntrData->Opc0, dl, Op.getValueType(), Op.getOperand(1),
Op.getOperand(2), Op.getOperand(3), Op.getOperand(4));
@@ -19927,16 +20631,12 @@ SDValue X86TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
RoundingMode, Sae),
Mask, Src0, Subtarget, DAG);
}
- case INTR_TYPE_2OP_MASK:
- case INTR_TYPE_2OP_IMM8_MASK: {
+ case INTR_TYPE_2OP_MASK: {
SDValue Src1 = Op.getOperand(1);
SDValue Src2 = Op.getOperand(2);
SDValue PassThru = Op.getOperand(3);
SDValue Mask = Op.getOperand(4);
- if (IntrData->Type == INTR_TYPE_2OP_IMM8_MASK)
- Src2 = DAG.getNode(ISD::TRUNCATE, dl, MVT::i8, Src2);
-
// We specify 2 possible opcodes for intrinsics with rounding modes.
// First, we check if the intrinsic may have non-default rounding mode,
// (IntrData->Opc1 != 0), then we check the rounding mode operand.
@@ -19991,26 +20691,6 @@ SDValue X86TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
Src2, Src3),
Mask, PassThru, Subtarget, DAG);
}
- case INTR_TYPE_3OP_MASK_RM: {
- SDValue Src1 = Op.getOperand(1);
- SDValue Src2 = Op.getOperand(2);
- SDValue Imm = Op.getOperand(3);
- SDValue PassThru = Op.getOperand(4);
- SDValue Mask = Op.getOperand(5);
- // We specify 2 possible modes for intrinsics, with/without rounding
- // modes.
- // First, we check if the intrinsic have rounding mode (7 operands),
- // if not, we set rounding mode to "current".
- SDValue Rnd;
- if (Op.getNumOperands() == 7)
- Rnd = Op.getOperand(6);
- else
- Rnd = DAG.getConstant(X86::STATIC_ROUNDING::CUR_DIRECTION, dl, MVT::i32);
- return getVectorMaskingNode(DAG.getNode(IntrData->Opc0, dl, VT,
- Src1, Src2, Imm, Rnd),
- Mask, PassThru, Subtarget, DAG);
- }
- case INTR_TYPE_3OP_IMM8_MASK:
case INTR_TYPE_3OP_MASK: {
SDValue Src1 = Op.getOperand(1);
SDValue Src2 = Op.getOperand(2);
@@ -20018,9 +20698,6 @@ SDValue X86TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
SDValue PassThru = Op.getOperand(4);
SDValue Mask = Op.getOperand(5);
- if (IntrData->Type == INTR_TYPE_3OP_IMM8_MASK)
- Src3 = DAG.getNode(ISD::TRUNCATE, dl, MVT::i8, Src3);
-
// We specify 2 possible opcodes for intrinsics with rounding modes.
// First, we check if the intrinsic may have non-default rounding mode,
// (IntrData->Opc1 != 0), then we check the rounding mode operand.
@@ -20038,41 +20715,13 @@ SDValue X86TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
Src1, Src2, Src3),
Mask, PassThru, Subtarget, DAG);
}
- case VPERM_2OP_MASK : {
+ case VPERM_2OP: {
SDValue Src1 = Op.getOperand(1);
SDValue Src2 = Op.getOperand(2);
- SDValue PassThru = Op.getOperand(3);
- SDValue Mask = Op.getOperand(4);
-
- // Swap Src1 and Src2 in the node creation
- return getVectorMaskingNode(DAG.getNode(IntrData->Opc0, dl, VT,Src2, Src1),
- Mask, PassThru, Subtarget, DAG);
- }
- case VPERM_3OP_MASKZ:
- case VPERM_3OP_MASK:{
- MVT VT = Op.getSimpleValueType();
- // Src2 is the PassThru
- SDValue Src1 = Op.getOperand(1);
- // PassThru needs to be the same type as the destination in order
- // to pattern match correctly.
- SDValue Src2 = DAG.getBitcast(VT, Op.getOperand(2));
- SDValue Src3 = Op.getOperand(3);
- SDValue Mask = Op.getOperand(4);
- SDValue PassThru = SDValue();
-
- // set PassThru element
- if (IntrData->Type == VPERM_3OP_MASKZ)
- PassThru = getZeroVector(VT, Subtarget, DAG, dl);
- else
- PassThru = Src2;
// Swap Src1 and Src2 in the node creation
- return getVectorMaskingNode(DAG.getNode(IntrData->Opc0,
- dl, Op.getValueType(),
- Src2, Src1, Src3),
- Mask, PassThru, Subtarget, DAG);
+ return DAG.getNode(IntrData->Opc0, dl, VT, Src2, Src1);
}
- case FMA_OP_MASK3:
case FMA_OP_MASKZ:
case FMA_OP_MASK: {
SDValue Src1 = Op.getOperand(1);
@@ -20085,8 +20734,6 @@ SDValue X86TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
// set PassThru element
if (IntrData->Type == FMA_OP_MASKZ)
PassThru = getZeroVector(VT, Subtarget, DAG, dl);
- else if (IntrData->Type == FMA_OP_MASK3)
- PassThru = Src3;
else
PassThru = Src1;
@@ -20107,76 +20754,11 @@ SDValue X86TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
Src1, Src2, Src3),
Mask, PassThru, Subtarget, DAG);
}
- case FMA_OP_SCALAR_MASK:
- case FMA_OP_SCALAR_MASK3:
- case FMA_OP_SCALAR_MASKZ: {
- SDValue Src1 = Op.getOperand(1);
- SDValue Src2 = Op.getOperand(2);
- SDValue Src3 = Op.getOperand(3);
- SDValue Mask = Op.getOperand(4);
- MVT VT = Op.getSimpleValueType();
- SDValue PassThru = SDValue();
-
- // set PassThru element
- if (IntrData->Type == FMA_OP_SCALAR_MASKZ)
- PassThru = getZeroVector(VT, Subtarget, DAG, dl);
- else if (IntrData->Type == FMA_OP_SCALAR_MASK3)
- PassThru = Src3;
- else
- PassThru = Src1;
-
- unsigned IntrWithRoundingModeOpcode = IntrData->Opc1;
- if (IntrWithRoundingModeOpcode != 0) {
- SDValue Rnd = Op.getOperand(5);
- if (!isRoundModeCurDirection(Rnd))
- return getScalarMaskingNode(DAG.getNode(IntrWithRoundingModeOpcode, dl,
- Op.getValueType(), Src1, Src2,
- Src3, Rnd),
- Mask, PassThru, Subtarget, DAG);
- }
-
- return getScalarMaskingNode(DAG.getNode(IntrData->Opc0, dl,
- Op.getValueType(), Src1, Src2,
- Src3),
- Mask, PassThru, Subtarget, DAG);
- }
- case IFMA_OP_MASKZ:
- case IFMA_OP_MASK: {
- SDValue Src1 = Op.getOperand(1);
- SDValue Src2 = Op.getOperand(2);
- SDValue Src3 = Op.getOperand(3);
- SDValue Mask = Op.getOperand(4);
- MVT VT = Op.getSimpleValueType();
- SDValue PassThru = Src1;
-
- // set PassThru element
- if (IntrData->Type == IFMA_OP_MASKZ)
- PassThru = getZeroVector(VT, Subtarget, DAG, dl);
-
- // Node we need to swizzle the operands to pass the multiply operands
+ case IFMA_OP:
+ // NOTE: We need to swizzle the operands to pass the multiply operands
// first.
- return getVectorMaskingNode(DAG.getNode(IntrData->Opc0,
- dl, Op.getValueType(),
- Src2, Src3, Src1),
- Mask, PassThru, Subtarget, DAG);
- }
- case TERLOG_OP_MASK:
- case TERLOG_OP_MASKZ: {
- SDValue Src1 = Op.getOperand(1);
- SDValue Src2 = Op.getOperand(2);
- SDValue Src3 = Op.getOperand(3);
- SDValue Src4 = DAG.getNode(ISD::TRUNCATE, dl, MVT::i8, Op.getOperand(4));
- SDValue Mask = Op.getOperand(5);
- MVT VT = Op.getSimpleValueType();
- SDValue PassThru = Src1;
- // Set PassThru element.
- if (IntrData->Type == TERLOG_OP_MASKZ)
- PassThru = getZeroVector(VT, Subtarget, DAG, dl);
-
- return getVectorMaskingNode(DAG.getNode(IntrData->Opc0, dl, VT,
- Src1, Src2, Src3, Src4),
- Mask, PassThru, Subtarget, DAG);
- }
+ return DAG.getNode(IntrData->Opc0, dl, Op.getValueType(),
+ Op.getOperand(2), Op.getOperand(3), Op.getOperand(1));
case CVTPD2PS:
// ISD::FP_ROUND has a second argument that indicates if the truncation
// does not change the value. Set it to 0 since it can change.
@@ -20207,21 +20789,11 @@ SDValue X86TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
Mask, PassThru, Subtarget, DAG);
}
case FPCLASS: {
- // FPclass intrinsics with mask
- SDValue Src1 = Op.getOperand(1);
- MVT VT = Src1.getSimpleValueType();
- MVT MaskVT = MVT::getVectorVT(MVT::i1, VT.getVectorNumElements());
- SDValue Imm = Op.getOperand(2);
- SDValue Mask = Op.getOperand(3);
- MVT BitcastVT = MVT::getVectorVT(MVT::i1,
- Mask.getSimpleValueType().getSizeInBits());
- SDValue FPclass = DAG.getNode(IntrData->Opc0, dl, MaskVT, Src1, Imm);
- SDValue FPclassMask = getVectorMaskingNode(FPclass, Mask, SDValue(),
- Subtarget, DAG);
- SDValue Res = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, BitcastVT,
- DAG.getUNDEF(BitcastVT), FPclassMask,
- DAG.getIntPtrConstant(0, dl));
- return DAG.getBitcast(Op.getValueType(), Res);
+ // FPclass intrinsics
+ SDValue Src1 = Op.getOperand(1);
+ MVT MaskVT = Op.getSimpleValueType();
+ SDValue Imm = Op.getOperand(2);
+ return DAG.getNode(IntrData->Opc0, dl, MaskVT, Src1, Imm);
}
case FPCLASSS: {
SDValue Src1 = Op.getOperand(1);
@@ -20230,17 +20802,20 @@ SDValue X86TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
SDValue FPclass = DAG.getNode(IntrData->Opc0, dl, MVT::v1i1, Src1, Imm);
SDValue FPclassMask = getScalarMaskingNode(FPclass, Mask, SDValue(),
Subtarget, DAG);
- return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, MVT::i8, FPclassMask,
- DAG.getIntPtrConstant(0, dl));
- }
- case CMP_MASK:
- case CMP_MASK_CC: {
+ // Need to fill with zeros to ensure the bitcast will produce zeroes
+ // for the upper bits. An EXTRACT_ELEMENT here wouldn't guarantee that.
+ SDValue Ins = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, MVT::v8i1,
+ DAG.getConstant(0, dl, MVT::v8i1),
+ FPclassMask, DAG.getIntPtrConstant(0, dl));
+ return DAG.getBitcast(MVT::i8, Ins);
+ }
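A standalone sketch (not part of the patch) of why the scalar mask results here are inserted into an all-zero v8i1 before the bitcast to i8: if the seven upper lanes stayed undefined, the resulting byte could differ from the zero-extended mask bit even though bit 0 itself is correct.

#include <cassert>
#include <cstdint>

// "insert_subvector into zeros, then bitcast to i8" for a single mask bit.
static uint8_t intoZeros(bool Bit) { return Bit ? 1 : 0; }

// The same bit dropped into a byte whose other lanes hold arbitrary garbage,
// standing in for an insert into undef.
static uint8_t intoUndef(bool Bit, uint8_t Garbage) {
  return static_cast<uint8_t>((Garbage & 0xFE) | (Bit ? 1 : 0));
}

int main() {
  for (int Bit = 0; Bit <= 1; ++Bit) {
    uint8_t Clean = intoZeros(Bit != 0);
    uint8_t Dirty = intoUndef(Bit != 0, 0xA4);
    assert((Clean & 1) == (Dirty & 1));  // the mask bit agrees either way
    assert(Clean == (Bit != 0 ? 1 : 0)); // but only the zeroed form equals the
                                         // i8 value the intrinsic promises
  }
  return 0;
}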
+ case CMP_MASK: {
// Comparison intrinsics with masks.
// Example of transformation:
// (i8 (int_x86_avx512_mask_pcmpeq_q_128
// (v2i64 %a), (v2i64 %b), (i8 %mask))) ->
// (i8 (bitcast
- // (v8i1 (insert_subvector undef,
+ // (v8i1 (insert_subvector zero,
// (v2i1 (and (PCMPEQM %a, %b),
// (extract_subvector
// (v8i1 (bitcast %mask)), 0))), 0))))
@@ -20249,36 +20824,39 @@ SDValue X86TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
SDValue Mask = Op.getOperand((IntrData->Type == CMP_MASK_CC) ? 4 : 3);
MVT BitcastVT = MVT::getVectorVT(MVT::i1,
Mask.getSimpleValueType().getSizeInBits());
- SDValue Cmp;
- if (IntrData->Type == CMP_MASK_CC) {
- SDValue CC = Op.getOperand(3);
- CC = DAG.getNode(ISD::TRUNCATE, dl, MVT::i8, CC);
- // We specify 2 possible opcodes for intrinsics with rounding modes.
- // First, we check if the intrinsic may have non-default rounding mode,
- // (IntrData->Opc1 != 0), then we check the rounding mode operand.
- if (IntrData->Opc1 != 0) {
- SDValue Rnd = Op.getOperand(5);
- if (!isRoundModeCurDirection(Rnd))
- Cmp = DAG.getNode(IntrData->Opc1, dl, MaskVT, Op.getOperand(1),
- Op.getOperand(2), CC, Rnd);
- }
- //default rounding mode
- if(!Cmp.getNode())
- Cmp = DAG.getNode(IntrData->Opc0, dl, MaskVT, Op.getOperand(1),
- Op.getOperand(2), CC);
-
- } else {
- assert(IntrData->Type == CMP_MASK && "Unexpected intrinsic type!");
- Cmp = DAG.getNode(IntrData->Opc0, dl, MaskVT, Op.getOperand(1),
- Op.getOperand(2));
- }
+ SDValue Cmp = DAG.getNode(IntrData->Opc0, dl, MaskVT, Op.getOperand(1),
+ Op.getOperand(2));
SDValue CmpMask = getVectorMaskingNode(Cmp, Mask, SDValue(),
Subtarget, DAG);
+ // Need to fill with zeros to ensure the bitcast will produce zeroes
+ // for the upper bits in the v2i1/v4i1 case.
SDValue Res = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, BitcastVT,
- DAG.getUNDEF(BitcastVT), CmpMask,
- DAG.getIntPtrConstant(0, dl));
+ DAG.getConstant(0, dl, BitcastVT),
+ CmpMask, DAG.getIntPtrConstant(0, dl));
return DAG.getBitcast(Op.getValueType(), Res);
}
+
+ case CMP_MASK_CC: {
+ MVT MaskVT = Op.getSimpleValueType();
+ SDValue Cmp;
+ SDValue CC = Op.getOperand(3);
+ CC = DAG.getNode(ISD::TRUNCATE, dl, MVT::i8, CC);
+ // We specify 2 possible opcodes for intrinsics with rounding modes.
+ // First, we check if the intrinsic may have non-default rounding mode,
+ // (IntrData->Opc1 != 0), then we check the rounding mode operand.
+ if (IntrData->Opc1 != 0) {
+ SDValue Rnd = Op.getOperand(4);
+ if (!isRoundModeCurDirection(Rnd))
+ Cmp = DAG.getNode(IntrData->Opc1, dl, MaskVT, Op.getOperand(1),
+ Op.getOperand(2), CC, Rnd);
+ }
+    // Default rounding mode.
+ if (!Cmp.getNode())
+ Cmp = DAG.getNode(IntrData->Opc0, dl, MaskVT, Op.getOperand(1),
+ Op.getOperand(2), CC);
+
+ return Cmp;
+ }
case CMP_MASK_SCALAR_CC: {
SDValue Src1 = Op.getOperand(1);
SDValue Src2 = Op.getOperand(2);
@@ -20297,8 +20875,12 @@ SDValue X86TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
SDValue CmpMask = getScalarMaskingNode(Cmp, Mask, SDValue(),
Subtarget, DAG);
- return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, MVT::i8, CmpMask,
- DAG.getIntPtrConstant(0, dl));
+ // Need to fill with zeros to ensure the bitcast will produce zeroes
+ // for the upper bits. An EXTRACT_ELEMENT here wouldn't guarantee that.
+ SDValue Ins = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, MVT::v8i1,
+ DAG.getConstant(0, dl, MVT::v8i1),
+ CmpMask, DAG.getIntPtrConstant(0, dl));
+ return DAG.getBitcast(MVT::i8, Ins);
}
case COMI: { // Comparison intrinsics
ISD::CondCode CC = (ISD::CondCode)IntrData->Opc1;
@@ -20351,8 +20933,13 @@ SDValue X86TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
else
FCmp = DAG.getNode(X86ISD::FSETCCM_RND, dl, MVT::v1i1, LHS, RHS,
DAG.getConstant(CondVal, dl, MVT::i8), Sae);
- return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, MVT::i32, FCmp,
- DAG.getIntPtrConstant(0, dl));
+ // Need to fill with zeros to ensure the bitcast will produce zeroes
+ // for the upper bits. An EXTRACT_ELEMENT here wouldn't guarantee that.
+ SDValue Ins = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, MVT::v16i1,
+ DAG.getConstant(0, dl, MVT::v16i1),
+ FCmp, DAG.getIntPtrConstant(0, dl));
+ return DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i32,
+ DAG.getBitcast(MVT::i16, Ins));
}
case VSHIFT:
return getTargetVShiftNode(IntrData->Opc0, dl, Op.getSimpleValueType(),
@@ -20369,22 +20956,6 @@ SDValue X86TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
DataToCompress),
Mask, PassThru, Subtarget, DAG);
}
- case BROADCASTM: {
- SDValue Mask = Op.getOperand(1);
- MVT MaskVT = MVT::getVectorVT(MVT::i1,
- Mask.getSimpleValueType().getSizeInBits());
- Mask = DAG.getBitcast(MaskVT, Mask);
- return DAG.getNode(IntrData->Opc0, dl, Op.getValueType(), Mask);
- }
- case MASK_BINOP: {
- MVT VT = Op.getSimpleValueType();
- MVT MaskVT = MVT::getVectorVT(MVT::i1, VT.getSizeInBits());
-
- SDValue Src1 = getMaskNode(Op.getOperand(1), MaskVT, Subtarget, DAG, dl);
- SDValue Src2 = getMaskNode(Op.getOperand(2), MaskVT, Subtarget, DAG, dl);
- SDValue Res = DAG.getNode(IntrData->Opc0, dl, MaskVT, Src1, Src2);
- return DAG.getBitcast(VT, Res);
- }
case FIXUPIMMS:
case FIXUPIMMS_MASKZ:
case FIXUPIMM:
@@ -20414,18 +20985,6 @@ SDValue X86TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
Src1, Src2, Src3, Imm, Rnd),
Mask, Passthru, Subtarget, DAG);
}
- case CONVERT_TO_MASK: {
- MVT SrcVT = Op.getOperand(1).getSimpleValueType();
- MVT MaskVT = MVT::getVectorVT(MVT::i1, SrcVT.getVectorNumElements());
- MVT BitcastVT = MVT::getVectorVT(MVT::i1, VT.getSizeInBits());
-
- SDValue CvtMask = DAG.getNode(IntrData->Opc0, dl, MaskVT,
- Op.getOperand(1));
- SDValue Res = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, BitcastVT,
- DAG.getUNDEF(BitcastVT), CvtMask,
- DAG.getIntPtrConstant(0, dl));
- return DAG.getBitcast(Op.getValueType(), Res);
- }
case ROUNDP: {
assert(IntrData->Opc0 == X86ISD::VRNDSCALE && "Unexpected opcode");
// Clear the upper bits of the rounding immediate so that the legacy
@@ -20454,13 +21013,6 @@ SDValue X86TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
switch (IntNo) {
default: return SDValue(); // Don't custom lower most intrinsics.
- case Intrinsic::x86_avx2_permd:
- case Intrinsic::x86_avx2_permps:
- // Operands intentionally swapped. Mask is last operand to intrinsic,
- // but second operand for node/instruction.
- return DAG.getNode(X86ISD::VPERMV, dl, Op.getValueType(),
- Op.getOperand(2), Op.getOperand(1));
-
// ptest and testp intrinsics. The intrinsic these come from are designed to
// return an integer value, not just an instruction so lower it to the ptest
// or testp pattern and a setcc for the result.
@@ -20528,43 +21080,6 @@ SDValue X86TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
SDValue SetCC = getSETCC(X86CC, Test, dl, DAG);
return DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i32, SetCC);
}
- case Intrinsic::x86_avx512_kortestz_w:
- case Intrinsic::x86_avx512_kortestc_w: {
- X86::CondCode X86CC =
- (IntNo == Intrinsic::x86_avx512_kortestz_w) ? X86::COND_E : X86::COND_B;
- SDValue LHS = DAG.getBitcast(MVT::v16i1, Op.getOperand(1));
- SDValue RHS = DAG.getBitcast(MVT::v16i1, Op.getOperand(2));
- SDValue Test = DAG.getNode(X86ISD::KORTEST, dl, MVT::i32, LHS, RHS);
- SDValue SetCC = getSETCC(X86CC, Test, dl, DAG);
- return DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i32, SetCC);
- }
-
- case Intrinsic::x86_avx512_knot_w: {
- SDValue LHS = DAG.getBitcast(MVT::v16i1, Op.getOperand(1));
- SDValue RHS = DAG.getConstant(1, dl, MVT::v16i1);
- SDValue Res = DAG.getNode(ISD::XOR, dl, MVT::v16i1, LHS, RHS);
- return DAG.getBitcast(MVT::i16, Res);
- }
-
- case Intrinsic::x86_avx512_kandn_w: {
- SDValue LHS = DAG.getBitcast(MVT::v16i1, Op.getOperand(1));
- // Invert LHS for the not.
- LHS = DAG.getNode(ISD::XOR, dl, MVT::v16i1, LHS,
- DAG.getConstant(1, dl, MVT::v16i1));
- SDValue RHS = DAG.getBitcast(MVT::v16i1, Op.getOperand(2));
- SDValue Res = DAG.getNode(ISD::AND, dl, MVT::v16i1, LHS, RHS);
- return DAG.getBitcast(MVT::i16, Res);
- }
-
- case Intrinsic::x86_avx512_kxnor_w: {
- SDValue LHS = DAG.getBitcast(MVT::v16i1, Op.getOperand(1));
- SDValue RHS = DAG.getBitcast(MVT::v16i1, Op.getOperand(2));
- SDValue Res = DAG.getNode(ISD::XOR, dl, MVT::v16i1, LHS, RHS);
- // Invert result for the not.
- Res = DAG.getNode(ISD::XOR, dl, MVT::v16i1, Res,
- DAG.getConstant(1, dl, MVT::v16i1));
- return DAG.getBitcast(MVT::i16, Res);
- }
case Intrinsic::x86_sse42_pcmpistria128:
case Intrinsic::x86_sse42_pcmpestria128:
@@ -20581,50 +21096,50 @@ SDValue X86TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
switch (IntNo) {
default: llvm_unreachable("Impossible intrinsic"); // Can't reach here.
case Intrinsic::x86_sse42_pcmpistria128:
- Opcode = X86ISD::PCMPISTRI;
+ Opcode = X86ISD::PCMPISTR;
X86CC = X86::COND_A;
break;
case Intrinsic::x86_sse42_pcmpestria128:
- Opcode = X86ISD::PCMPESTRI;
+ Opcode = X86ISD::PCMPESTR;
X86CC = X86::COND_A;
break;
case Intrinsic::x86_sse42_pcmpistric128:
- Opcode = X86ISD::PCMPISTRI;
+ Opcode = X86ISD::PCMPISTR;
X86CC = X86::COND_B;
break;
case Intrinsic::x86_sse42_pcmpestric128:
- Opcode = X86ISD::PCMPESTRI;
+ Opcode = X86ISD::PCMPESTR;
X86CC = X86::COND_B;
break;
case Intrinsic::x86_sse42_pcmpistrio128:
- Opcode = X86ISD::PCMPISTRI;
+ Opcode = X86ISD::PCMPISTR;
X86CC = X86::COND_O;
break;
case Intrinsic::x86_sse42_pcmpestrio128:
- Opcode = X86ISD::PCMPESTRI;
+ Opcode = X86ISD::PCMPESTR;
X86CC = X86::COND_O;
break;
case Intrinsic::x86_sse42_pcmpistris128:
- Opcode = X86ISD::PCMPISTRI;
+ Opcode = X86ISD::PCMPISTR;
X86CC = X86::COND_S;
break;
case Intrinsic::x86_sse42_pcmpestris128:
- Opcode = X86ISD::PCMPESTRI;
+ Opcode = X86ISD::PCMPESTR;
X86CC = X86::COND_S;
break;
case Intrinsic::x86_sse42_pcmpistriz128:
- Opcode = X86ISD::PCMPISTRI;
+ Opcode = X86ISD::PCMPISTR;
X86CC = X86::COND_E;
break;
case Intrinsic::x86_sse42_pcmpestriz128:
- Opcode = X86ISD::PCMPESTRI;
+ Opcode = X86ISD::PCMPESTR;
X86CC = X86::COND_E;
break;
}
SmallVector<SDValue, 5> NewOps(Op->op_begin()+1, Op->op_end());
- SDVTList VTs = DAG.getVTList(Op.getValueType(), MVT::i32);
- SDValue PCMP = DAG.getNode(Opcode, dl, VTs, NewOps);
- SDValue SetCC = getSETCC(X86CC, SDValue(PCMP.getNode(), 1), dl, DAG);
+ SDVTList VTs = DAG.getVTList(MVT::i32, MVT::v16i8, MVT::i32);
+ SDValue PCMP = DAG.getNode(Opcode, dl, VTs, NewOps).getValue(2);
+ SDValue SetCC = getSETCC(X86CC, PCMP, dl, DAG);
return DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i32, SetCC);
}
@@ -20632,15 +21147,28 @@ SDValue X86TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
case Intrinsic::x86_sse42_pcmpestri128: {
unsigned Opcode;
if (IntNo == Intrinsic::x86_sse42_pcmpistri128)
- Opcode = X86ISD::PCMPISTRI;
+ Opcode = X86ISD::PCMPISTR;
else
- Opcode = X86ISD::PCMPESTRI;
+ Opcode = X86ISD::PCMPESTR;
SmallVector<SDValue, 5> NewOps(Op->op_begin()+1, Op->op_end());
- SDVTList VTs = DAG.getVTList(Op.getValueType(), MVT::i32);
+ SDVTList VTs = DAG.getVTList(MVT::i32, MVT::v16i8, MVT::i32);
return DAG.getNode(Opcode, dl, VTs, NewOps);
}
+ case Intrinsic::x86_sse42_pcmpistrm128:
+ case Intrinsic::x86_sse42_pcmpestrm128: {
+ unsigned Opcode;
+ if (IntNo == Intrinsic::x86_sse42_pcmpistrm128)
+ Opcode = X86ISD::PCMPISTR;
+ else
+ Opcode = X86ISD::PCMPESTR;
+
+ SmallVector<SDValue, 5> NewOps(Op->op_begin()+1, Op->op_end());
+ SDVTList VTs = DAG.getVTList(MVT::i32, MVT::v16i8, MVT::i32);
+ return DAG.getNode(Opcode, dl, VTs, NewOps).getValue(1);
+ }
+
case Intrinsic::eh_sjlj_lsda: {
MachineFunction &MF = DAG.getMachineFunction();
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
@@ -20708,7 +21236,7 @@ static SDValue getAVX2GatherNode(unsigned Opc, SDValue Op, SelectionDAG &DAG,
SDValue Segment = DAG.getRegister(0, MVT::i32);
// If source is undef or we know it won't be used, use a zero vector
// to break register dependency.
- // TODO: use undef instead and let ExecutionDepsFix deal with it?
+ // TODO: use undef instead and let BreakFalseDeps deal with it?
if (Src.isUndef() || ISD::isBuildVectorAllOnes(Mask.getNode()))
Src = getZeroVector(Op.getSimpleValueType(), Subtarget, DAG, dl);
SDValue Ops[] = {Src, Base, Scale, Index, Disp, Segment, Mask, Chain};
@@ -20736,7 +21264,7 @@ static SDValue getGatherNode(unsigned Opc, SDValue Op, SelectionDAG &DAG,
SDValue Segment = DAG.getRegister(0, MVT::i32);
// If source is undef or we know it won't be used, use a zero vector
// to break register dependency.
- // TODO: use undef instead and let ExecutionDepsFix deal with it?
+ // TODO: use undef instead and let BreakFalseDeps deal with it?
if (Src.isUndef() || ISD::isBuildVectorAllOnes(VMask.getNode()))
Src = getZeroVector(Op.getSimpleValueType(), Subtarget, DAG, dl);
SDValue Ops[] = {Src, VMask, Base, Scale, Index, Disp, Segment, Chain};
@@ -21029,17 +21557,35 @@ static SDValue LowerINTRINSIC_W_CHAIN(SDValue Op, const X86Subtarget &Subtarget,
return SDValue();
}
case Intrinsic::x86_lwpins32:
- case Intrinsic::x86_lwpins64: {
+ case Intrinsic::x86_lwpins64:
+ case Intrinsic::x86_umwait:
+ case Intrinsic::x86_tpause: {
SDLoc dl(Op);
SDValue Chain = Op->getOperand(0);
SDVTList VTs = DAG.getVTList(MVT::i32, MVT::Other);
- SDValue LwpIns =
- DAG.getNode(X86ISD::LWPINS, dl, VTs, Chain, Op->getOperand(2),
+ unsigned Opcode;
+
+ switch (IntNo) {
+ default: llvm_unreachable("Impossible intrinsic");
+ case Intrinsic::x86_umwait:
+ Opcode = X86ISD::UMWAIT;
+ break;
+ case Intrinsic::x86_tpause:
+ Opcode = X86ISD::TPAUSE;
+ break;
+ case Intrinsic::x86_lwpins32:
+ case Intrinsic::x86_lwpins64:
+ Opcode = X86ISD::LWPINS;
+ break;
+ }
+
+ SDValue Operation =
+ DAG.getNode(Opcode, dl, VTs, Chain, Op->getOperand(2),
Op->getOperand(3), Op->getOperand(4));
- SDValue SetCC = getSETCC(X86::COND_B, LwpIns.getValue(0), dl, DAG);
+ SDValue SetCC = getSETCC(X86::COND_B, Operation.getValue(0), dl, DAG);
SDValue Result = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i8, SetCC);
return DAG.getNode(ISD::MERGE_VALUES, dl, Op->getVTList(), Result,
- LwpIns.getValue(1));
+ Operation.getValue(1));
}
}
return SDValue();
@@ -21155,27 +21701,6 @@ static SDValue LowerINTRINSIC_W_CHAIN(SDValue Op, const X86Subtarget &Subtarget,
SDValue Results[] = { SetCC, Store };
return DAG.getMergeValues(Results, dl);
}
- case COMPRESS_TO_MEM: {
- SDValue Mask = Op.getOperand(4);
- SDValue DataToCompress = Op.getOperand(3);
- SDValue Addr = Op.getOperand(2);
- SDValue Chain = Op.getOperand(0);
- MVT VT = DataToCompress.getSimpleValueType();
-
- MemIntrinsicSDNode *MemIntr = dyn_cast<MemIntrinsicSDNode>(Op);
- assert(MemIntr && "Expected MemIntrinsicSDNode!");
-
- if (isAllOnesConstant(Mask)) // return just a store
- return DAG.getStore(Chain, dl, DataToCompress, Addr,
- MemIntr->getMemOperand());
-
- MVT MaskVT = MVT::getVectorVT(MVT::i1, VT.getVectorNumElements());
- SDValue VMask = getMaskNode(Mask, MaskVT, Subtarget, DAG, dl);
-
- return DAG.getMaskedStore(Chain, dl, DataToCompress, Addr, VMask, VT,
- MemIntr->getMemOperand(),
- false /* truncating */, true /* compressing */);
- }
case TRUNCATE_TO_MEM_VI8:
case TRUNCATE_TO_MEM_VI16:
case TRUNCATE_TO_MEM_VI32: {
@@ -21219,28 +21744,6 @@ static SDValue LowerINTRINSIC_W_CHAIN(SDValue Op, const X86Subtarget &Subtarget,
llvm_unreachable("Unsupported truncstore intrinsic");
}
}
-
- case EXPAND_FROM_MEM: {
- SDValue Mask = Op.getOperand(4);
- SDValue PassThru = Op.getOperand(3);
- SDValue Addr = Op.getOperand(2);
- SDValue Chain = Op.getOperand(0);
- MVT VT = Op.getSimpleValueType();
-
- MemIntrinsicSDNode *MemIntr = dyn_cast<MemIntrinsicSDNode>(Op);
- assert(MemIntr && "Expected MemIntrinsicSDNode!");
-
- if (isAllOnesConstant(Mask)) // Return a regular (unmasked) vector load.
- return DAG.getLoad(VT, dl, Chain, Addr, MemIntr->getMemOperand());
- if (X86::isZeroNode(Mask))
- return DAG.getUNDEF(VT);
-
- MVT MaskVT = MVT::getVectorVT(MVT::i1, VT.getVectorNumElements());
- SDValue VMask = getMaskNode(Mask, MaskVT, Subtarget, DAG, dl);
- return DAG.getMaskedLoad(VT, dl, Chain, Addr, VMask, PassThru, VT,
- MemIntr->getMemOperand(), ISD::NON_EXTLOAD,
- true /* expanding */);
- }
}
}
@@ -21657,14 +22160,16 @@ static SDValue LowerVectorIntUnary(SDValue Op, SelectionDAG &DAG) {
MVT VT = Op.getSimpleValueType();
unsigned NumElems = VT.getVectorNumElements();
unsigned SizeInBits = VT.getSizeInBits();
+ MVT EltVT = VT.getVectorElementType();
+ SDValue Src = Op.getOperand(0);
+ assert(EltVT == Src.getSimpleValueType().getVectorElementType() &&
+ "Src and Op should have the same element type!");
// Extract the Lo/Hi vectors
SDLoc dl(Op);
- SDValue Src = Op.getOperand(0);
SDValue Lo = extractSubVector(Src, 0, DAG, dl, SizeInBits / 2);
SDValue Hi = extractSubVector(Src, NumElems / 2, DAG, dl, SizeInBits / 2);
- MVT EltVT = VT.getVectorElementType();
MVT NewVT = MVT::getVectorVT(EltVT, NumElems / 2);
return DAG.getNode(ISD::CONCAT_VECTORS, dl, VT,
DAG.getNode(Op.getOpcode(), dl, NewVT, Lo),
@@ -21687,13 +22192,14 @@ static SDValue Lower512IntUnary(SDValue Op, SelectionDAG &DAG) {
return LowerVectorIntUnary(Op, DAG);
}
-/// \brief Lower a vector CTLZ using native supported vector CTLZ instruction.
+/// Lower a vector CTLZ using native supported vector CTLZ instruction.
//
// i8/i16 vector implemented using dword LZCNT vector instruction
// ( sub(trunc(lzcnt(zext32(x)))) ). In case zext32(x) is illegal,
// split the vector, perform the operation on its Lo and Hi parts and
// concatenate the results.
-static SDValue LowerVectorCTLZ_AVX512CDI(SDValue Op, SelectionDAG &DAG) {
+static SDValue LowerVectorCTLZ_AVX512CDI(SDValue Op, SelectionDAG &DAG,
+ const X86Subtarget &Subtarget) {
assert(Op.getOpcode() == ISD::CTLZ);
SDLoc dl(Op);
MVT VT = Op.getSimpleValueType();
@@ -21704,7 +22210,8 @@ static SDValue LowerVectorCTLZ_AVX512CDI(SDValue Op, SelectionDAG &DAG) {
"Unsupported element type");
// Split the vector; its Lo and Hi parts will be handled in the next iteration.
- if (16 < NumElems)
+ if (NumElems > 16 ||
+ (NumElems == 16 && !Subtarget.canExtendTo512DQ()))
return LowerVectorIntUnary(Op, DAG);
MVT NewVT = MVT::getVectorVT(MVT::i32, NumElems);
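// Illustrative scalar sketch (not part of the patch; helper name is ours):
// the sub(trunc(lzcnt(zext32(x)))) formula from the comment above, written
// out for one i8 element. Zero-extending to 32 bits adds exactly 24 leading
// zeros, which the subtraction removes again. __builtin_clz is the GCC/Clang
// builtin; the zero case is handled explicitly since the builtin leaves it
// undefined.
#include <cstdint>
static uint8_t ctlz8ViaLzcnt32(uint8_t X) {
  uint32_t Wide = X;                              // zext32(x)
  unsigned LZ = Wide ? __builtin_clz(Wide) : 32;  // lzcnt
  return (uint8_t)(LZ - 24);                      // trunc + sub(24)
}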
@@ -21809,8 +22316,10 @@ static SDValue LowerVectorCTLZ(SDValue Op, const SDLoc &DL,
SelectionDAG &DAG) {
MVT VT = Op.getSimpleValueType();
- if (Subtarget.hasCDI())
- return LowerVectorCTLZ_AVX512CDI(Op, DAG);
+ if (Subtarget.hasCDI() &&
+ // vXi8 vectors need to be promoted to 512-bits for vXi32.
+ (Subtarget.canExtendTo512DQ() || VT.getVectorElementType() != MVT::i8))
+ return LowerVectorCTLZ_AVX512CDI(Op, DAG, Subtarget);
// Decompose 256-bit ops into smaller 128-bit ops.
if (VT.is256BitVector() && !Subtarget.hasInt256())
@@ -21999,10 +22508,42 @@ static SDValue LowerABS(SDValue Op, SelectionDAG &DAG) {
}
static SDValue LowerMINMAX(SDValue Op, SelectionDAG &DAG) {
- assert(Op.getSimpleValueType().is256BitVector() &&
- Op.getSimpleValueType().isInteger() &&
- "Only handle AVX 256-bit vector integer operation");
- return Lower256IntArith(Op, DAG);
+ MVT VT = Op.getSimpleValueType();
+
+ // For AVX1 cases, split to use legal ops (everything but v4i64).
+ if (VT.getScalarType() != MVT::i64 && VT.is256BitVector())
+ return Lower256IntArith(Op, DAG);
+
+ SDLoc DL(Op);
+ unsigned Opcode = Op.getOpcode();
+ SDValue N0 = Op.getOperand(0);
+ SDValue N1 = Op.getOperand(1);
+
+ // For pre-SSE41, we can perform UMIN/UMAX v8i16 by flipping the signbit,
+ // using the SMIN/SMAX instructions and flipping the signbit back.
+ if (VT == MVT::v8i16) {
+ assert((Opcode == ISD::UMIN || Opcode == ISD::UMAX) &&
+ "Unexpected MIN/MAX opcode");
+ SDValue Sign = DAG.getConstant(APInt::getSignedMinValue(16), DL, VT);
+ N0 = DAG.getNode(ISD::XOR, DL, VT, N0, Sign);
+ N1 = DAG.getNode(ISD::XOR, DL, VT, N1, Sign);
+ Opcode = (Opcode == ISD::UMIN ? ISD::SMIN : ISD::SMAX);
+ SDValue Result = DAG.getNode(Opcode, DL, VT, N0, N1);
+ return DAG.getNode(ISD::XOR, DL, VT, Result, Sign);
+ }
+
+ // Else, expand to a compare/select.
+ ISD::CondCode CC;
+ switch (Opcode) {
+ case ISD::SMIN: CC = ISD::CondCode::SETLT; break;
+ case ISD::SMAX: CC = ISD::CondCode::SETGT; break;
+ case ISD::UMIN: CC = ISD::CondCode::SETULT; break;
+ case ISD::UMAX: CC = ISD::CondCode::SETUGT; break;
+ default: llvm_unreachable("Unknown MINMAX opcode");
+ }
+
+ SDValue Cond = DAG.getSetCC(DL, VT, N0, N1, CC);
+ return DAG.getSelect(DL, VT, Cond, N0, N1);
}
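// Illustrative scalar sketch (not part of the patch; helper name is ours):
// the signbit-flip trick used above for pre-SSE41 v8i16 UMIN/UMAX. XOR-ing
// both operands with 0x8000 maps unsigned order onto signed order, so
// SMIN/SMAX can be reused; the final XOR undoes the bias.
#include <cstdint>
static uint16_t umin16ViaSmin(uint16_t A, uint16_t B) {
  int16_t SA = (int16_t)(A ^ 0x8000);   // flip sign bits
  int16_t SB = (int16_t)(B ^ 0x8000);
  int16_t SM = SA < SB ? SA : SB;       // SMIN on the biased values
  return (uint16_t)(SM ^ 0x8000);       // flip back
}
// e.g. A = 0xFFFF, B = 0x0001: biased to 0x7FFF / 0x8001, SMIN picks 0x8001,
// and un-biasing yields 0x0001 == UMIN(A, B).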
static SDValue LowerMUL(SDValue Op, const X86Subtarget &Subtarget,
@@ -22048,40 +22589,26 @@ static SDValue LowerMUL(SDValue Op, const X86Subtarget &Subtarget,
MVT ExVT = MVT::v8i16;
// Extract the lo parts and sign extend to i16
- SDValue ALo, BLo;
- if (Subtarget.hasSSE41()) {
- ALo = DAG.getSignExtendVectorInReg(A, dl, ExVT);
- BLo = DAG.getSignExtendVectorInReg(B, dl, ExVT);
- } else {
- const int ShufMask[] = {-1, 0, -1, 1, -1, 2, -1, 3,
- -1, 4, -1, 5, -1, 6, -1, 7};
- ALo = DAG.getVectorShuffle(VT, dl, A, A, ShufMask);
- BLo = DAG.getVectorShuffle(VT, dl, B, B, ShufMask);
- ALo = DAG.getBitcast(ExVT, ALo);
- BLo = DAG.getBitcast(ExVT, BLo);
- ALo = DAG.getNode(ISD::SRA, dl, ExVT, ALo, DAG.getConstant(8, dl, ExVT));
- BLo = DAG.getNode(ISD::SRA, dl, ExVT, BLo, DAG.getConstant(8, dl, ExVT));
- }
+  // We only keep the low byte of each result element of the pmullw (the high
+  // byte is masked away below), so it doesn't matter what's in the high byte
+  // of each 16-bit source element.
+ const int LoShufMask[] = {0, -1, 1, -1, 2, -1, 3, -1,
+ 4, -1, 5, -1, 6, -1, 7, -1};
+ SDValue ALo = DAG.getVectorShuffle(VT, dl, A, A, LoShufMask);
+ SDValue BLo = DAG.getVectorShuffle(VT, dl, B, B, LoShufMask);
+ ALo = DAG.getBitcast(ExVT, ALo);
+ BLo = DAG.getBitcast(ExVT, BLo);
// Extract the hi parts and sign extend to i16
- SDValue AHi, BHi;
- if (Subtarget.hasSSE41()) {
- const int ShufMask[] = {8, 9, 10, 11, 12, 13, 14, 15,
- -1, -1, -1, -1, -1, -1, -1, -1};
- AHi = DAG.getVectorShuffle(VT, dl, A, A, ShufMask);
- BHi = DAG.getVectorShuffle(VT, dl, B, B, ShufMask);
- AHi = DAG.getSignExtendVectorInReg(AHi, dl, ExVT);
- BHi = DAG.getSignExtendVectorInReg(BHi, dl, ExVT);
- } else {
- const int ShufMask[] = {-1, 8, -1, 9, -1, 10, -1, 11,
- -1, 12, -1, 13, -1, 14, -1, 15};
- AHi = DAG.getVectorShuffle(VT, dl, A, A, ShufMask);
- BHi = DAG.getVectorShuffle(VT, dl, B, B, ShufMask);
- AHi = DAG.getBitcast(ExVT, AHi);
- BHi = DAG.getBitcast(ExVT, BHi);
- AHi = DAG.getNode(ISD::SRA, dl, ExVT, AHi, DAG.getConstant(8, dl, ExVT));
- BHi = DAG.getNode(ISD::SRA, dl, ExVT, BHi, DAG.getConstant(8, dl, ExVT));
- }
+  // We only keep the low byte of each result element of the pmullw (the high
+  // byte is masked away below), so it doesn't matter what's in the high byte
+  // of each 16-bit source element.
+ const int HiShufMask[] = {8, -1, 9, -1, 10, -1, 11, -1,
+ 12, -1, 13, -1, 14, -1, 15, -1};
+ SDValue AHi = DAG.getVectorShuffle(VT, dl, A, A, HiShufMask);
+ SDValue BHi = DAG.getVectorShuffle(VT, dl, B, B, HiShufMask);
+ AHi = DAG.getBitcast(ExVT, AHi);
+ BHi = DAG.getBitcast(ExVT, BHi);
// Multiply, mask the lower 8bits of the lo/hi results and pack
SDValue RLo = DAG.getNode(ISD::MUL, dl, ExVT, ALo, BLo);
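// Illustrative scalar sketch (not part of the patch; helper name is ours):
// one byte lane of the pmullw-based v16i8 multiply above. The low byte of a
// 16-bit product equals the 8-bit product regardless of the high source
// bytes, which is why the shuffles above may leave the high halves undefined.
#include <cstdint>
static uint8_t mulByteViaI16(uint8_t A, uint8_t B) {
  uint16_t Prod = (uint16_t)A * (uint16_t)B;  // PMULLW on one 16-bit lane
  return (uint8_t)(Prod & 0xFF);              // mask to 8 bits, then PACKUS
}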
@@ -22096,22 +22623,19 @@ static SDValue LowerMUL(SDValue Op, const X86Subtarget &Subtarget,
assert(Subtarget.hasSSE2() && !Subtarget.hasSSE41() &&
"Should not custom lower when pmulld is available!");
- // If the upper 17 bits of each element are zero then we can use PMADD.
- APInt Mask17 = APInt::getHighBitsSet(32, 17);
- if (DAG.MaskedValueIsZero(A, Mask17) && DAG.MaskedValueIsZero(B, Mask17))
- return DAG.getNode(X86ISD::VPMADDWD, dl, VT,
- DAG.getBitcast(MVT::v8i16, A),
- DAG.getBitcast(MVT::v8i16, B));
-
// Extract the odd parts.
static const int UnpackMask[] = { 1, -1, 3, -1 };
SDValue Aodds = DAG.getVectorShuffle(VT, dl, A, A, UnpackMask);
SDValue Bodds = DAG.getVectorShuffle(VT, dl, B, B, UnpackMask);
// Multiply the even parts.
- SDValue Evens = DAG.getNode(X86ISD::PMULUDQ, dl, MVT::v2i64, A, B);
+ SDValue Evens = DAG.getNode(X86ISD::PMULUDQ, dl, MVT::v2i64,
+ DAG.getBitcast(MVT::v2i64, A),
+ DAG.getBitcast(MVT::v2i64, B));
// Now multiply odd parts.
- SDValue Odds = DAG.getNode(X86ISD::PMULUDQ, dl, MVT::v2i64, Aodds, Bodds);
+ SDValue Odds = DAG.getNode(X86ISD::PMULUDQ, dl, MVT::v2i64,
+ DAG.getBitcast(MVT::v2i64, Aodds),
+ DAG.getBitcast(MVT::v2i64, Bodds));
Evens = DAG.getBitcast(VT, Evens);
Odds = DAG.getBitcast(VT, Odds);
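// Illustrative scalar sketch (not part of the patch; helper name is ours):
// the even/odd PMULUDQ lowering above. Each 32x32->32 lane product is the low
// half of a widening 32x32->64 multiply, so two PMULUDQs (even lanes, and odd
// lanes shuffled into even position) plus a re-interleave recover v4i32 MUL.
#include <cstdint>
static void mulV4I32ViaPMULUDQ(const uint32_t A[4], const uint32_t B[4],
                               uint32_t R[4]) {
  for (int I = 0; I < 4; I += 2) {
    uint64_t Even = (uint64_t)A[I] * B[I];         // PMULUDQ, even lanes
    uint64_t Odd  = (uint64_t)A[I + 1] * B[I + 1]; // PMULUDQ, odd lanes
    R[I] = (uint32_t)Even;                         // keep low 32 bits
    R[I + 1] = (uint32_t)Odd;                      // shuffle back into place
  }
}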
@@ -22124,17 +22648,7 @@ static SDValue LowerMUL(SDValue Op, const X86Subtarget &Subtarget,
assert((VT == MVT::v2i64 || VT == MVT::v4i64 || VT == MVT::v8i64) &&
"Only know how to lower V2I64/V4I64/V8I64 multiply");
-
- // 32-bit vector types used for MULDQ/MULUDQ.
- MVT MulVT = MVT::getVectorVT(MVT::i32, VT.getSizeInBits() / 32);
-
- // MULDQ returns the 64-bit result of the signed multiplication of the lower
- // 32-bits. We can lower with this if the sign bits stretch that far.
- if (Subtarget.hasSSE41() && DAG.ComputeNumSignBits(A) > 32 &&
- DAG.ComputeNumSignBits(B) > 32) {
- return DAG.getNode(X86ISD::PMULDQ, dl, VT, DAG.getBitcast(MulVT, A),
- DAG.getBitcast(MulVT, B));
- }
+ assert(!Subtarget.hasDQI() && "DQI should use MULLQ");
// Ahi = psrlqi(a, 32);
// Bhi = psrlqi(b, 32);
@@ -22145,42 +22659,35 @@ static SDValue LowerMUL(SDValue Op, const X86Subtarget &Subtarget,
//
// Hi = psllqi(AloBhi + AhiBlo, 32);
// return AloBlo + Hi;
+ KnownBits AKnown, BKnown;
+ DAG.computeKnownBits(A, AKnown);
+ DAG.computeKnownBits(B, BKnown);
+
APInt LowerBitsMask = APInt::getLowBitsSet(64, 32);
- bool ALoIsZero = DAG.MaskedValueIsZero(A, LowerBitsMask);
- bool BLoIsZero = DAG.MaskedValueIsZero(B, LowerBitsMask);
+ bool ALoIsZero = LowerBitsMask.isSubsetOf(AKnown.Zero);
+ bool BLoIsZero = LowerBitsMask.isSubsetOf(BKnown.Zero);
APInt UpperBitsMask = APInt::getHighBitsSet(64, 32);
- bool AHiIsZero = DAG.MaskedValueIsZero(A, UpperBitsMask);
- bool BHiIsZero = DAG.MaskedValueIsZero(B, UpperBitsMask);
-
- // If DQI is supported we can use MULLQ, but MULUDQ is still better if the
- // the high bits are known to be zero.
- if (Subtarget.hasDQI() && (!AHiIsZero || !BHiIsZero))
- return Op;
-
- // Bit cast to 32-bit vectors for MULUDQ.
- SDValue Alo = DAG.getBitcast(MulVT, A);
- SDValue Blo = DAG.getBitcast(MulVT, B);
+ bool AHiIsZero = UpperBitsMask.isSubsetOf(AKnown.Zero);
+ bool BHiIsZero = UpperBitsMask.isSubsetOf(BKnown.Zero);
SDValue Zero = getZeroVector(VT, Subtarget, DAG, dl);
// Only multiply lo/hi halves that aren't known to be zero.
SDValue AloBlo = Zero;
if (!ALoIsZero && !BLoIsZero)
- AloBlo = DAG.getNode(X86ISD::PMULUDQ, dl, VT, Alo, Blo);
+ AloBlo = DAG.getNode(X86ISD::PMULUDQ, dl, VT, A, B);
SDValue AloBhi = Zero;
if (!ALoIsZero && !BHiIsZero) {
SDValue Bhi = getTargetVShiftByConstNode(X86ISD::VSRLI, dl, VT, B, 32, DAG);
- Bhi = DAG.getBitcast(MulVT, Bhi);
- AloBhi = DAG.getNode(X86ISD::PMULUDQ, dl, VT, Alo, Bhi);
+ AloBhi = DAG.getNode(X86ISD::PMULUDQ, dl, VT, A, Bhi);
}
SDValue AhiBlo = Zero;
if (!AHiIsZero && !BLoIsZero) {
SDValue Ahi = getTargetVShiftByConstNode(X86ISD::VSRLI, dl, VT, A, 32, DAG);
- Ahi = DAG.getBitcast(MulVT, Ahi);
- AhiBlo = DAG.getNode(X86ISD::PMULUDQ, dl, VT, Ahi, Blo);
+ AhiBlo = DAG.getNode(X86ISD::PMULUDQ, dl, VT, Ahi, B);
}
SDValue Hi = DAG.getNode(ISD::ADD, dl, VT, AloBhi, AhiBlo);
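// Illustrative scalar sketch (not part of the patch; helper name is ours):
// the identity behind the PMULUDQ sequence above. Writing A = Ahi*2^32 + Alo
// and B = Bhi*2^32 + Blo, the Ahi*Bhi term falls entirely outside the low
// 64 bits, so:
//   lo64(A*B) = Alo*Blo + ((Alo*Bhi + Ahi*Blo) << 32)
#include <cstdint>
static uint64_t mul64ViaPMULUDQ(uint64_t A, uint64_t B) {
  uint64_t Alo = A & 0xffffffffu, Ahi = A >> 32;
  uint64_t Blo = B & 0xffffffffu, Bhi = B >> 32;
  uint64_t AloBlo = Alo * Blo;                 // PMULUDQ(A, B)
  uint64_t AloBhi = Alo * Bhi;                 // PMULUDQ(A, psrlq(B, 32))
  uint64_t AhiBlo = Ahi * Blo;                 // PMULUDQ(psrlq(A, 32), B)
  return AloBlo + ((AloBhi + AhiBlo) << 32);   // psllq(Hi, 32) + AloBlo
}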
@@ -22226,7 +22733,7 @@ static SDValue LowerMULH(SDValue Op, const X86Subtarget &Subtarget,
SDValue Hi = DAG.getIntPtrConstant(NumElems / 2, dl);
if (VT == MVT::v32i8) {
- if (Subtarget.hasBWI()) {
+ if (Subtarget.canExtendTo512BW()) {
SDValue ExA = DAG.getNode(ExAVX, dl, MVT::v32i16, A);
SDValue ExB = DAG.getNode(ExAVX, dl, MVT::v32i16, B);
SDValue Mul = DAG.getNode(ISD::MUL, dl, MVT::v32i16, ExA, ExB);
@@ -22277,13 +22784,14 @@ static SDValue LowerMULH(SDValue Op, const X86Subtarget &Subtarget,
assert(VT == MVT::v16i8 &&
"Pre-AVX2 support only supports v16i8 multiplication");
MVT ExVT = MVT::v8i16;
- unsigned ExSSE41 = (ISD::MULHU == Opcode ? X86ISD::VZEXT : X86ISD::VSEXT);
+ unsigned ExSSE41 = ISD::MULHU == Opcode ? ISD::ZERO_EXTEND_VECTOR_INREG
+ : ISD::SIGN_EXTEND_VECTOR_INREG;
// Extract the lo parts and zero/sign extend to i16.
SDValue ALo, BLo;
if (Subtarget.hasSSE41()) {
- ALo = getExtendInVec(ExSSE41, dl, ExVT, A, DAG);
- BLo = getExtendInVec(ExSSE41, dl, ExVT, B, DAG);
+ ALo = DAG.getNode(ExSSE41, dl, ExVT, A);
+ BLo = DAG.getNode(ExSSE41, dl, ExVT, B);
} else {
const int ShufMask[] = {-1, 0, -1, 1, -1, 2, -1, 3,
-1, 4, -1, 5, -1, 6, -1, 7};
@@ -22302,8 +22810,8 @@ static SDValue LowerMULH(SDValue Op, const X86Subtarget &Subtarget,
-1, -1, -1, -1, -1, -1, -1, -1};
AHi = DAG.getVectorShuffle(VT, dl, A, A, ShufMask);
BHi = DAG.getVectorShuffle(VT, dl, B, B, ShufMask);
- AHi = getExtendInVec(ExSSE41, dl, ExVT, AHi, DAG);
- BHi = getExtendInVec(ExSSE41, dl, ExVT, BHi, DAG);
+ AHi = DAG.getNode(ExSSE41, dl, ExVT, AHi);
+ BHi = DAG.getNode(ExSSE41, dl, ExVT, BHi);
} else {
const int ShufMask[] = {-1, 8, -1, 9, -1, 10, -1, 11,
-1, 12, -1, 13, -1, 14, -1, 15};
@@ -22438,10 +22946,14 @@ static SDValue LowerMUL_LOHI(SDValue Op, const X86Subtarget &Subtarget,
(!IsSigned || !Subtarget.hasSSE41()) ? X86ISD::PMULUDQ : X86ISD::PMULDQ;
// PMULUDQ <4 x i32> <a|b|c|d>, <4 x i32> <e|f|g|h>
// => <2 x i64> <ae|cg>
- SDValue Mul1 = DAG.getBitcast(VT, DAG.getNode(Opcode, dl, MulVT, Op0, Op1));
+ SDValue Mul1 = DAG.getBitcast(VT, DAG.getNode(Opcode, dl, MulVT,
+ DAG.getBitcast(MulVT, Op0),
+ DAG.getBitcast(MulVT, Op1)));
// PMULUDQ <4 x i32> <b|undef|d|undef>, <4 x i32> <f|undef|h|undef>
// => <2 x i64> <bf|dh>
- SDValue Mul2 = DAG.getBitcast(VT, DAG.getNode(Opcode, dl, MulVT, Odd0, Odd1));
+ SDValue Mul2 = DAG.getBitcast(VT, DAG.getNode(Opcode, dl, MulVT,
+ DAG.getBitcast(MulVT, Odd0),
+ DAG.getBitcast(MulVT, Odd1)));
// Shuffle it back into the right order.
SmallVector<int, 16> HighMask(NumElts);
@@ -22601,7 +23113,8 @@ static SDValue LowerScalarImmediateShift(SDValue Op, SelectionDAG &DAG,
SDValue Zeros = getZeroVector(VT, Subtarget, DAG, dl);
if (VT.is512BitVector()) {
assert(VT == MVT::v64i8 && "Unexpected element type!");
- SDValue CMP = DAG.getNode(X86ISD::PCMPGTM, dl, MVT::v64i1, Zeros, R);
+ SDValue CMP = DAG.getSetCC(dl, MVT::v64i1, Zeros, R,
+ ISD::SETGT);
return DAG.getNode(ISD::SIGN_EXTEND, dl, VT, CMP);
}
return DAG.getNode(X86ISD::PCMPGT, dl, VT, Zeros, R);
@@ -22711,57 +23224,81 @@ static SDValue LowerScalarImmediateShift(SDValue Op, SelectionDAG &DAG,
return SDValue();
}
+// Determine if V is a splat value, and return the scalar.
+static SDValue IsSplatValue(MVT VT, SDValue V, const SDLoc &dl,
+ SelectionDAG &DAG, const X86Subtarget &Subtarget,
+ unsigned Opcode) {
+ V = peekThroughEXTRACT_SUBVECTORs(V);
+
+ // Check if this is a splat build_vector node.
+ if (BuildVectorSDNode *BV = dyn_cast<BuildVectorSDNode>(V)) {
+ SDValue SplatAmt = BV->getSplatValue();
+ if (SplatAmt && SplatAmt.isUndef())
+ return SDValue();
+ return SplatAmt;
+ }
+
+ // Check for SUB(SPLAT_BV, SPLAT) cases from rotate patterns.
+ if (V.getOpcode() == ISD::SUB &&
+ !SupportedVectorVarShift(VT, Subtarget, Opcode)) {
+ SDValue LHS = peekThroughEXTRACT_SUBVECTORs(V.getOperand(0));
+ SDValue RHS = peekThroughEXTRACT_SUBVECTORs(V.getOperand(1));
+
+ // Ensure that the corresponding splat BV element is not UNDEF.
+ BitVector UndefElts;
+ BuildVectorSDNode *BV0 = dyn_cast<BuildVectorSDNode>(LHS);
+ ShuffleVectorSDNode *SVN1 = dyn_cast<ShuffleVectorSDNode>(RHS);
+ if (BV0 && SVN1 && BV0->getSplatValue(&UndefElts) && SVN1->isSplat()) {
+ unsigned SplatIdx = (unsigned)SVN1->getSplatIndex();
+ if (!UndefElts[SplatIdx])
+ return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl,
+ VT.getVectorElementType(), V,
+ DAG.getIntPtrConstant(SplatIdx, dl));
+ }
+ }
+
+ // Check if this is a shuffle node doing a splat.
+ ShuffleVectorSDNode *SVN = dyn_cast<ShuffleVectorSDNode>(V);
+ if (!SVN || !SVN->isSplat())
+ return SDValue();
+
+ unsigned SplatIdx = (unsigned)SVN->getSplatIndex();
+ SDValue InVec = V.getOperand(0);
+ if (InVec.getOpcode() == ISD::BUILD_VECTOR) {
+ assert((SplatIdx < VT.getVectorNumElements()) &&
+ "Unexpected shuffle index found!");
+ return InVec.getOperand(SplatIdx);
+ } else if (InVec.getOpcode() == ISD::INSERT_VECTOR_ELT) {
+ if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(InVec.getOperand(2)))
+ if (C->getZExtValue() == SplatIdx)
+ return InVec.getOperand(1);
+ }
+
+ // Avoid introducing an extract element from a shuffle.
+ return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl,
+ VT.getVectorElementType(), InVec,
+ DAG.getIntPtrConstant(SplatIdx, dl));
+}
+
static SDValue LowerScalarVariableShift(SDValue Op, SelectionDAG &DAG,
const X86Subtarget &Subtarget) {
MVT VT = Op.getSimpleValueType();
SDLoc dl(Op);
SDValue R = Op.getOperand(0);
SDValue Amt = Op.getOperand(1);
+ unsigned Opcode = Op.getOpcode();
- unsigned X86OpcI = (Op.getOpcode() == ISD::SHL) ? X86ISD::VSHLI :
- (Op.getOpcode() == ISD::SRL) ? X86ISD::VSRLI : X86ISD::VSRAI;
-
- unsigned X86OpcV = (Op.getOpcode() == ISD::SHL) ? X86ISD::VSHL :
- (Op.getOpcode() == ISD::SRL) ? X86ISD::VSRL : X86ISD::VSRA;
-
- if (SupportedVectorShiftWithBaseAmnt(VT, Subtarget, Op.getOpcode())) {
- SDValue BaseShAmt;
- MVT EltVT = VT.getVectorElementType();
-
- if (BuildVectorSDNode *BV = dyn_cast<BuildVectorSDNode>(Amt)) {
- // Check if this build_vector node is doing a splat.
- // If so, then set BaseShAmt equal to the splat value.
- BaseShAmt = BV->getSplatValue();
- if (BaseShAmt && BaseShAmt.isUndef())
- BaseShAmt = SDValue();
- } else {
- if (Amt.getOpcode() == ISD::EXTRACT_SUBVECTOR)
- Amt = Amt.getOperand(0);
+ unsigned X86OpcI = (Opcode == ISD::SHL) ? X86ISD::VSHLI :
+ (Opcode == ISD::SRL) ? X86ISD::VSRLI : X86ISD::VSRAI;
- ShuffleVectorSDNode *SVN = dyn_cast<ShuffleVectorSDNode>(Amt);
- if (SVN && SVN->isSplat()) {
- unsigned SplatIdx = (unsigned)SVN->getSplatIndex();
- SDValue InVec = Amt.getOperand(0);
- if (InVec.getOpcode() == ISD::BUILD_VECTOR) {
- assert((SplatIdx < InVec.getSimpleValueType().getVectorNumElements()) &&
- "Unexpected shuffle index found!");
- BaseShAmt = InVec.getOperand(SplatIdx);
- } else if (InVec.getOpcode() == ISD::INSERT_VECTOR_ELT) {
- if (ConstantSDNode *C =
- dyn_cast<ConstantSDNode>(InVec.getOperand(2))) {
- if (C->getZExtValue() == SplatIdx)
- BaseShAmt = InVec.getOperand(1);
- }
- }
+ unsigned X86OpcV = (Opcode == ISD::SHL) ? X86ISD::VSHL :
+ (Opcode == ISD::SRL) ? X86ISD::VSRL : X86ISD::VSRA;
- if (!BaseShAmt)
- // Avoid introducing an extract element from a shuffle.
- BaseShAmt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, InVec,
- DAG.getIntPtrConstant(SplatIdx, dl));
- }
- }
+ Amt = peekThroughEXTRACT_SUBVECTORs(Amt);
- if (BaseShAmt.getNode()) {
+ if (SupportedVectorShiftWithBaseAmnt(VT, Subtarget, Opcode)) {
+ if (SDValue BaseShAmt = IsSplatValue(VT, Amt, dl, DAG, Subtarget, Opcode)) {
+ MVT EltVT = VT.getVectorElementType();
assert(EltVT.bitsLE(MVT::i64) && "Unexpected element type!");
if (EltVT != MVT::i64 && EltVT.bitsGT(MVT::i32))
BaseShAmt = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i64, BaseShAmt);
@@ -22793,6 +23330,70 @@ static SDValue LowerScalarVariableShift(SDValue Op, SelectionDAG &DAG,
return SDValue();
}
+// Convert a shift/rotate left amount to a multiplication scale factor.
+static SDValue convertShiftLeftToScale(SDValue Amt, const SDLoc &dl,
+ const X86Subtarget &Subtarget,
+ SelectionDAG &DAG) {
+ MVT VT = Amt.getSimpleValueType();
+ if (!(VT == MVT::v8i16 || VT == MVT::v4i32 ||
+ (Subtarget.hasInt256() && VT == MVT::v16i16) ||
+ (!Subtarget.hasAVX512() && VT == MVT::v16i8)))
+ return SDValue();
+
+ if (ISD::isBuildVectorOfConstantSDNodes(Amt.getNode())) {
+ SmallVector<SDValue, 8> Elts;
+ MVT SVT = VT.getVectorElementType();
+ unsigned SVTBits = SVT.getSizeInBits();
+ APInt One(SVTBits, 1);
+ unsigned NumElems = VT.getVectorNumElements();
+
+ for (unsigned i = 0; i != NumElems; ++i) {
+ SDValue Op = Amt->getOperand(i);
+ if (Op->isUndef()) {
+ Elts.push_back(Op);
+ continue;
+ }
+
+ ConstantSDNode *ND = cast<ConstantSDNode>(Op);
+ APInt C(SVTBits, ND->getAPIntValue().getZExtValue());
+ uint64_t ShAmt = C.getZExtValue();
+ if (ShAmt >= SVTBits) {
+ Elts.push_back(DAG.getUNDEF(SVT));
+ continue;
+ }
+ Elts.push_back(DAG.getConstant(One.shl(ShAmt), dl, SVT));
+ }
+ return DAG.getBuildVector(VT, dl, Elts);
+ }
+
+ // If the target doesn't support variable shifts, use either FP conversion
+ // or integer multiplication to avoid shifting each element individually.
+ if (VT == MVT::v4i32) {
+ Amt = DAG.getNode(ISD::SHL, dl, VT, Amt, DAG.getConstant(23, dl, VT));
+ Amt = DAG.getNode(ISD::ADD, dl, VT, Amt,
+ DAG.getConstant(0x3f800000U, dl, VT));
+ Amt = DAG.getBitcast(MVT::v4f32, Amt);
+ return DAG.getNode(ISD::FP_TO_SINT, dl, VT, Amt);
+ }
+
+ // AVX2 can more effectively perform this as a zext/trunc to/from v8i32.
+ if (VT == MVT::v8i16 && !Subtarget.hasAVX2()) {
+ SDValue Z = getZeroVector(VT, Subtarget, DAG, dl);
+ SDValue Lo = DAG.getBitcast(MVT::v4i32, getUnpackl(DAG, dl, VT, Amt, Z));
+ SDValue Hi = DAG.getBitcast(MVT::v4i32, getUnpackh(DAG, dl, VT, Amt, Z));
+ Lo = convertShiftLeftToScale(Lo, dl, Subtarget, DAG);
+ Hi = convertShiftLeftToScale(Hi, dl, Subtarget, DAG);
+ if (Subtarget.hasSSE41())
+ return DAG.getNode(X86ISD::PACKUS, dl, VT, Lo, Hi);
+
+ return DAG.getVectorShuffle(VT, dl, DAG.getBitcast(VT, Lo),
+ DAG.getBitcast(VT, Hi),
+ {0, 2, 4, 6, 8, 10, 12, 14});
+ }
+
+ return SDValue();
+}
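// Illustrative scalar sketch (not part of the patch; helper name is ours):
// the v4i32 float-exponent trick above. (Amt << 23) + 0x3f800000 builds an
// IEEE-754 float with exponent 127 + Amt and zero mantissa, i.e. the value
// 2.0^Amt, which FP_TO_SINT then converts back to the integer scale factor
// (shown here for Amt in [0, 32)).
#include <cstdint>
#include <cstring>
static uint32_t shiftAmountToScale(uint32_t Amt) {
  uint32_t Bits = (Amt << 23) + 0x3f800000u;  // bias the exponent by 127
  float F;
  std::memcpy(&F, &Bits, sizeof(F));          // bitcast i32 -> f32
  return (uint32_t)F;                         // 2^Amt
}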
+
static SDValue LowerShift(SDValue Op, const X86Subtarget &Subtarget,
SelectionDAG &DAG) {
MVT VT = Op.getSimpleValueType();
@@ -22815,11 +23416,10 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget &Subtarget,
// XOP has 128-bit variable logical/arithmetic shifts.
// +ve/-ve Amt = shift left/right.
- if (Subtarget.hasXOP() &&
- (VT == MVT::v2i64 || VT == MVT::v4i32 ||
- VT == MVT::v8i16 || VT == MVT::v16i8)) {
+ if (Subtarget.hasXOP() && (VT == MVT::v2i64 || VT == MVT::v4i32 ||
+ VT == MVT::v8i16 || VT == MVT::v16i8)) {
if (Op.getOpcode() == ISD::SRL || Op.getOpcode() == ISD::SRA) {
- SDValue Zero = getZeroVector(VT, Subtarget, DAG, dl);
+ SDValue Zero = DAG.getConstant(0, dl, VT);
Amt = DAG.getNode(ISD::SUB, dl, VT, Zero, Amt);
}
if (Op.getOpcode() == ISD::SHL || Op.getOpcode() == ISD::SRL)
@@ -22852,51 +23452,8 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget &Subtarget,
return R;
}
- // If possible, lower this packed shift into a vector multiply instead of
- // expanding it into a sequence of scalar shifts.
- // Do this only if the vector shift count is a constant build_vector.
- if (ConstantAmt && Op.getOpcode() == ISD::SHL &&
- (VT == MVT::v8i16 || VT == MVT::v4i32 ||
- (Subtarget.hasInt256() && VT == MVT::v16i16))) {
- SmallVector<SDValue, 8> Elts;
- MVT SVT = VT.getVectorElementType();
- unsigned SVTBits = SVT.getSizeInBits();
- APInt One(SVTBits, 1);
- unsigned NumElems = VT.getVectorNumElements();
-
- for (unsigned i=0; i !=NumElems; ++i) {
- SDValue Op = Amt->getOperand(i);
- if (Op->isUndef()) {
- Elts.push_back(Op);
- continue;
- }
-
- ConstantSDNode *ND = cast<ConstantSDNode>(Op);
- APInt C(SVTBits, ND->getAPIntValue().getZExtValue());
- uint64_t ShAmt = C.getZExtValue();
- if (ShAmt >= SVTBits) {
- Elts.push_back(DAG.getUNDEF(SVT));
- continue;
- }
- Elts.push_back(DAG.getConstant(One.shl(ShAmt), dl, SVT));
- }
- SDValue BV = DAG.getBuildVector(VT, dl, Elts);
- return DAG.getNode(ISD::MUL, dl, VT, R, BV);
- }
-
- // Lower SHL with variable shift amount.
- if (VT == MVT::v4i32 && Op->getOpcode() == ISD::SHL) {
- Op = DAG.getNode(ISD::SHL, dl, VT, Amt, DAG.getConstant(23, dl, VT));
-
- Op = DAG.getNode(ISD::ADD, dl, VT, Op,
- DAG.getConstant(0x3f800000U, dl, VT));
- Op = DAG.getBitcast(MVT::v4f32, Op);
- Op = DAG.getNode(ISD::FP_TO_SINT, dl, VT, Op);
- return DAG.getNode(ISD::MUL, dl, VT, Op, R);
- }
-
// If possible, lower this shift as a sequence of two shifts by
- // constant plus a MOVSS/MOVSD/PBLEND instead of scalarizing it.
+ // constant plus a BLENDing shuffle instead of scalarizing it.
// Example:
// (v4i32 (srl A, (build_vector < X, Y, Y, Y>)))
//
@@ -22904,67 +23461,54 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget &Subtarget,
// (v4i32 (MOVSS (srl A, <Y,Y,Y,Y>), (srl A, <X,X,X,X>)))
//
// The advantage is that the two shifts from the example would be
- // lowered as X86ISD::VSRLI nodes. This would be cheaper than scalarizing
- // the vector shift into four scalar shifts plus four pairs of vector
- // insert/extract.
- if (ConstantAmt && (VT == MVT::v8i16 || VT == MVT::v4i32)) {
- bool UseMOVSD = false;
- bool CanBeSimplified;
- // The splat value for the first packed shift (the 'X' from the example).
- SDValue Amt1 = Amt->getOperand(0);
- // The splat value for the second packed shift (the 'Y' from the example).
- SDValue Amt2 = (VT == MVT::v4i32) ? Amt->getOperand(1) : Amt->getOperand(2);
-
- // See if it is possible to replace this node with a sequence of
- // two shifts followed by a MOVSS/MOVSD/PBLEND.
- if (VT == MVT::v4i32) {
- // Check if it is legal to use a MOVSS.
- CanBeSimplified = Amt2 == Amt->getOperand(2) &&
- Amt2 == Amt->getOperand(3);
- if (!CanBeSimplified) {
- // Otherwise, check if we can still simplify this node using a MOVSD.
- CanBeSimplified = Amt1 == Amt->getOperand(1) &&
- Amt->getOperand(2) == Amt->getOperand(3);
- UseMOVSD = true;
- Amt2 = Amt->getOperand(2);
+ // lowered as X86ISD::VSRLI nodes in parallel before blending.
+ if (ConstantAmt && (VT == MVT::v8i16 || VT == MVT::v4i32 ||
+ (VT == MVT::v16i16 && Subtarget.hasInt256()))) {
+ SDValue Amt1, Amt2;
+ unsigned NumElts = VT.getVectorNumElements();
+ SmallVector<int, 8> ShuffleMask;
+ for (unsigned i = 0; i != NumElts; ++i) {
+ SDValue A = Amt->getOperand(i);
+ if (A.isUndef()) {
+ ShuffleMask.push_back(SM_SentinelUndef);
+ continue;
}
- } else {
- // Do similar checks for the case where the machine value type
- // is MVT::v8i16.
- CanBeSimplified = Amt1 == Amt->getOperand(1);
- for (unsigned i=3; i != 8 && CanBeSimplified; ++i)
- CanBeSimplified = Amt2 == Amt->getOperand(i);
-
- if (!CanBeSimplified) {
- UseMOVSD = true;
- CanBeSimplified = true;
- Amt2 = Amt->getOperand(4);
- for (unsigned i=0; i != 4 && CanBeSimplified; ++i)
- CanBeSimplified = Amt1 == Amt->getOperand(i);
- for (unsigned j=4; j != 8 && CanBeSimplified; ++j)
- CanBeSimplified = Amt2 == Amt->getOperand(j);
+ if (!Amt1 || Amt1 == A) {
+ ShuffleMask.push_back(i);
+ Amt1 = A;
+ continue;
}
+ if (!Amt2 || Amt2 == A) {
+ ShuffleMask.push_back(i + NumElts);
+ Amt2 = A;
+ continue;
+ }
+ break;
}
- if (CanBeSimplified && isa<ConstantSDNode>(Amt1) &&
- isa<ConstantSDNode>(Amt2)) {
- // Replace this node with two shifts followed by a MOVSS/MOVSD/PBLEND.
+    // Only perform this blend if we can do so without loading a mask.
+ if (ShuffleMask.size() == NumElts && Amt1 && Amt2 &&
+ isa<ConstantSDNode>(Amt1) && isa<ConstantSDNode>(Amt2) &&
+ (VT != MVT::v16i16 ||
+ is128BitLaneRepeatedShuffleMask(VT, ShuffleMask)) &&
+ (VT == MVT::v4i32 || Subtarget.hasSSE41() ||
+ Op.getOpcode() != ISD::SHL || canWidenShuffleElements(ShuffleMask))) {
SDValue Splat1 =
DAG.getConstant(cast<ConstantSDNode>(Amt1)->getAPIntValue(), dl, VT);
SDValue Shift1 = DAG.getNode(Op->getOpcode(), dl, VT, R, Splat1);
SDValue Splat2 =
DAG.getConstant(cast<ConstantSDNode>(Amt2)->getAPIntValue(), dl, VT);
SDValue Shift2 = DAG.getNode(Op->getOpcode(), dl, VT, R, Splat2);
- SDValue BitCast1 = DAG.getBitcast(MVT::v4i32, Shift1);
- SDValue BitCast2 = DAG.getBitcast(MVT::v4i32, Shift2);
- if (UseMOVSD)
- return DAG.getBitcast(VT, DAG.getVectorShuffle(MVT::v4i32, dl, BitCast1,
- BitCast2, {0, 1, 6, 7}));
- return DAG.getBitcast(VT, DAG.getVectorShuffle(MVT::v4i32, dl, BitCast1,
- BitCast2, {0, 5, 6, 7}));
+ return DAG.getVectorShuffle(VT, dl, Shift1, Shift2, ShuffleMask);
}
}
+ // If possible, lower this packed shift into a vector multiply instead of
+ // expanding it into a sequence of scalar shifts.
+ if (Op.getOpcode() == ISD::SHL)
+ if (SDValue Scale = convertShiftLeftToScale(Amt, dl, Subtarget, DAG))
+ return DAG.getNode(ISD::MUL, dl, VT, R, Scale);
+
// v4i32 Non Uniform Shifts.
// If the shift amount is constant we can shift each lane using the SSE2
// immediate shifts, else we need to zero-extend each lane to the lower i64
@@ -22994,31 +23538,56 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget &Subtarget,
break;
}
// The SSE2 shifts use the lower i64 as the same shift amount for
- // all lanes and the upper i64 is ignored. These shuffle masks
- // optimally zero-extend each lanes on SSE2/SSE41/AVX targets.
- SDValue Z = getZeroVector(VT, Subtarget, DAG, dl);
- Amt0 = DAG.getVectorShuffle(VT, dl, Amt, Z, {0, 4, -1, -1});
- Amt1 = DAG.getVectorShuffle(VT, dl, Amt, Z, {1, 5, -1, -1});
- Amt2 = DAG.getVectorShuffle(VT, dl, Amt, Z, {2, 6, -1, -1});
- Amt3 = DAG.getVectorShuffle(VT, dl, Amt, Z, {3, 7, -1, -1});
+ // all lanes and the upper i64 is ignored. On AVX we're better off
+ // just zero-extending, but for SSE just duplicating the top 16-bits is
+ // cheaper and has the same effect for out of range values.
+ if (Subtarget.hasAVX()) {
+ SDValue Z = getZeroVector(VT, Subtarget, DAG, dl);
+ Amt0 = DAG.getVectorShuffle(VT, dl, Amt, Z, {0, 4, -1, -1});
+ Amt1 = DAG.getVectorShuffle(VT, dl, Amt, Z, {1, 5, -1, -1});
+ Amt2 = DAG.getVectorShuffle(VT, dl, Amt, Z, {2, 6, -1, -1});
+ Amt3 = DAG.getVectorShuffle(VT, dl, Amt, Z, {3, 7, -1, -1});
+ } else {
+ SDValue Amt01 = DAG.getBitcast(MVT::v8i16, Amt);
+ SDValue Amt23 = DAG.getVectorShuffle(MVT::v8i16, dl, Amt01, Amt01,
+ {4, 5, 6, 7, -1, -1, -1, -1});
+ Amt0 = DAG.getVectorShuffle(MVT::v8i16, dl, Amt01, Amt01,
+ {0, 1, 1, 1, -1, -1, -1, -1});
+ Amt1 = DAG.getVectorShuffle(MVT::v8i16, dl, Amt01, Amt01,
+ {2, 3, 3, 3, -1, -1, -1, -1});
+ Amt2 = DAG.getVectorShuffle(MVT::v8i16, dl, Amt23, Amt23,
+ {0, 1, 1, 1, -1, -1, -1, -1});
+ Amt3 = DAG.getVectorShuffle(MVT::v8i16, dl, Amt23, Amt23,
+ {2, 3, 3, 3, -1, -1, -1, -1});
+ }
}
- SDValue R0 = DAG.getNode(Opc, dl, VT, R, Amt0);
- SDValue R1 = DAG.getNode(Opc, dl, VT, R, Amt1);
- SDValue R2 = DAG.getNode(Opc, dl, VT, R, Amt2);
- SDValue R3 = DAG.getNode(Opc, dl, VT, R, Amt3);
- SDValue R02 = DAG.getVectorShuffle(VT, dl, R0, R2, {0, -1, 6, -1});
- SDValue R13 = DAG.getVectorShuffle(VT, dl, R1, R3, {-1, 1, -1, 7});
- return DAG.getVectorShuffle(VT, dl, R02, R13, {0, 5, 2, 7});
+ SDValue R0 = DAG.getNode(Opc, dl, VT, R, DAG.getBitcast(VT, Amt0));
+ SDValue R1 = DAG.getNode(Opc, dl, VT, R, DAG.getBitcast(VT, Amt1));
+ SDValue R2 = DAG.getNode(Opc, dl, VT, R, DAG.getBitcast(VT, Amt2));
+ SDValue R3 = DAG.getNode(Opc, dl, VT, R, DAG.getBitcast(VT, Amt3));
+
+ // Merge the shifted lane results optimally with/without PBLENDW.
+ // TODO - ideally shuffle combining would handle this.
+ if (Subtarget.hasSSE41()) {
+ SDValue R02 = DAG.getVectorShuffle(VT, dl, R0, R2, {0, -1, 6, -1});
+ SDValue R13 = DAG.getVectorShuffle(VT, dl, R1, R3, {-1, 1, -1, 7});
+ return DAG.getVectorShuffle(VT, dl, R02, R13, {0, 5, 2, 7});
+ }
+ SDValue R01 = DAG.getVectorShuffle(VT, dl, R0, R1, {0, -1, -1, 5});
+ SDValue R23 = DAG.getVectorShuffle(VT, dl, R2, R3, {2, -1, -1, 7});
+ return DAG.getVectorShuffle(VT, dl, R01, R23, {0, 3, 4, 7});
}
// It's worth extending once and using the vXi16/vXi32 shifts for smaller
// types, but without AVX512 the extra overheads to get from vXi8 to vXi32
// make the existing SSE solution better.
+  // NOTE: We honor the preferred vector width before promoting to 512-bits.
if ((Subtarget.hasInt256() && VT == MVT::v8i16) ||
- (Subtarget.hasAVX512() && VT == MVT::v16i16) ||
- (Subtarget.hasAVX512() && VT == MVT::v16i8) ||
- (Subtarget.hasBWI() && VT == MVT::v32i8)) {
+ (Subtarget.canExtendTo512DQ() && VT == MVT::v16i16) ||
+ (Subtarget.canExtendTo512DQ() && VT == MVT::v16i8) ||
+ (Subtarget.canExtendTo512BW() && VT == MVT::v32i8) ||
+ (Subtarget.hasBWI() && Subtarget.hasVLX() && VT == MVT::v16i8)) {
assert((!Subtarget.hasBWI() || VT == MVT::v32i8 || VT == MVT::v16i8) &&
"Unexpected vector type");
MVT EvtSVT = Subtarget.hasBWI() ? MVT::i16 : MVT::i32;
@@ -23046,7 +23615,8 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget &Subtarget,
V0 = DAG.getBitcast(VT, V0);
V1 = DAG.getBitcast(VT, V1);
Sel = DAG.getBitcast(VT, Sel);
- Sel = DAG.getNode(X86ISD::CVT2MASK, dl, MaskVT, Sel);
+ Sel = DAG.getSetCC(dl, MaskVT, DAG.getConstant(0, dl, VT), Sel,
+ ISD::SETGT);
return DAG.getBitcast(SelVT, DAG.getSelect(dl, VT, Sel, V0, V1));
} else if (Subtarget.hasSSE41()) {
// On SSE41 targets we make use of the fact that VSELECT lowers
@@ -23242,13 +23812,15 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget &Subtarget,
static SDValue LowerRotate(SDValue Op, const X86Subtarget &Subtarget,
SelectionDAG &DAG) {
MVT VT = Op.getSimpleValueType();
+ assert(VT.isVector() && "Custom lowering only for vector rotates!");
+
SDLoc DL(Op);
SDValue R = Op.getOperand(0);
SDValue Amt = Op.getOperand(1);
unsigned Opcode = Op.getOpcode();
unsigned EltSizeInBits = VT.getScalarSizeInBits();
- if (Subtarget.hasAVX512()) {
+ if (Subtarget.hasAVX512() && 32 <= EltSizeInBits) {
// Attempt to rotate by immediate.
APInt UndefElts;
SmallVector<APInt, 16> EltBits;
@@ -23267,31 +23839,178 @@ static SDValue LowerRotate(SDValue Op, const X86Subtarget &Subtarget,
return Op;
}
- assert(VT.isVector() && "Custom lowering only for vector rotates!");
- assert(Subtarget.hasXOP() && "XOP support required for vector rotates!");
assert((Opcode == ISD::ROTL) && "Only ROTL supported");
// XOP has 128-bit vector variable + immediate rotates.
// +ve/-ve Amt = rotate left/right - just need to handle ISD::ROTL.
+ if (Subtarget.hasXOP()) {
+ // Split 256-bit integers.
+ if (VT.is256BitVector())
+ return Lower256IntArith(Op, DAG);
+ assert(VT.is128BitVector() && "Only rotate 128-bit vectors!");
- // Split 256-bit integers.
- if (VT.is256BitVector())
+ // Attempt to rotate by immediate.
+ if (auto *BVAmt = dyn_cast<BuildVectorSDNode>(Amt)) {
+ if (auto *RotateConst = BVAmt->getConstantSplatNode()) {
+ uint64_t RotateAmt = RotateConst->getAPIntValue().getZExtValue();
+ assert(RotateAmt < EltSizeInBits && "Rotation out of range");
+ return DAG.getNode(X86ISD::VROTLI, DL, VT, R,
+ DAG.getConstant(RotateAmt, DL, MVT::i8));
+ }
+ }
+
+ // Use general rotate by variable (per-element).
+ return Op;
+ }
+
+ // Split 256-bit integers on pre-AVX2 targets.
+ if (VT.is256BitVector() && !Subtarget.hasAVX2())
return Lower256IntArith(Op, DAG);
- assert(VT.is128BitVector() && "Only rotate 128-bit vectors!");
+ assert((VT == MVT::v4i32 || VT == MVT::v8i16 || VT == MVT::v16i8 ||
+ ((VT == MVT::v8i32 || VT == MVT::v16i16 || VT == MVT::v32i8) &&
+ Subtarget.hasAVX2())) &&
+ "Only vXi32/vXi16/vXi8 vector rotates supported");
- // Attempt to rotate by immediate.
+  // Rotate by a uniform constant - expand back to shifts.
+ // TODO - legalizers should be able to handle this.
if (auto *BVAmt = dyn_cast<BuildVectorSDNode>(Amt)) {
if (auto *RotateConst = BVAmt->getConstantSplatNode()) {
uint64_t RotateAmt = RotateConst->getAPIntValue().getZExtValue();
assert(RotateAmt < EltSizeInBits && "Rotation out of range");
- return DAG.getNode(X86ISD::VROTLI, DL, VT, R,
- DAG.getConstant(RotateAmt, DL, MVT::i8));
+ if (RotateAmt == 0)
+ return R;
+
+ SDValue AmtR = DAG.getConstant(EltSizeInBits - RotateAmt, DL, VT);
+ SDValue SHL = DAG.getNode(ISD::SHL, DL, VT, R, Amt);
+ SDValue SRL = DAG.getNode(ISD::SRL, DL, VT, R, AmtR);
+ return DAG.getNode(ISD::OR, DL, VT, SHL, SRL);
}
}
- // Use general rotate by variable (per-element).
- return Op;
+ // Rotate by splat - expand back to shifts.
+ // TODO - legalizers should be able to handle this.
+ if ((EltSizeInBits >= 16 || Subtarget.hasBWI()) &&
+ IsSplatValue(VT, Amt, DL, DAG, Subtarget, Opcode)) {
+ SDValue AmtR = DAG.getConstant(EltSizeInBits, DL, VT);
+ AmtR = DAG.getNode(ISD::SUB, DL, VT, AmtR, Amt);
+ SDValue SHL = DAG.getNode(ISD::SHL, DL, VT, R, Amt);
+ SDValue SRL = DAG.getNode(ISD::SRL, DL, VT, R, AmtR);
+ return DAG.getNode(ISD::OR, DL, VT, SHL, SRL);
+ }
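// Illustrative scalar sketch (not part of the patch; helper name is ours):
// the shift-based ROTL expansion used above,
//   rotl(X, R) == (X << R) | (X >> (Bits - R)).
// The uniform-constant path above returns early for a zero amount because the
// complementary shift would otherwise equal the element width.
#include <cstdint>
static uint16_t rotl16(uint16_t X, unsigned R) {
  if ((R &= 15) == 0)
    return X;
  return (uint16_t)((X << R) | (X >> (16 - R)));
}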
+
+  // v16i8/v32i8: Split the rotation into rot4/rot2/rot1 stages and select
+  // each stage by the corresponding amount bit.
+ if (EltSizeInBits == 8) {
+ if (Subtarget.hasBWI()) {
+ SDValue AmtR = DAG.getConstant(EltSizeInBits, DL, VT);
+ AmtR = DAG.getNode(ISD::SUB, DL, VT, AmtR, Amt);
+ SDValue SHL = DAG.getNode(ISD::SHL, DL, VT, R, Amt);
+ SDValue SRL = DAG.getNode(ISD::SRL, DL, VT, R, AmtR);
+ return DAG.getNode(ISD::OR, DL, VT, SHL, SRL);
+ }
+
+ MVT ExtVT = MVT::getVectorVT(MVT::i16, VT.getVectorNumElements() / 2);
+
+ auto SignBitSelect = [&](MVT SelVT, SDValue Sel, SDValue V0, SDValue V1) {
+ if (Subtarget.hasSSE41()) {
+ // On SSE41 targets we make use of the fact that VSELECT lowers
+ // to PBLENDVB which selects bytes based just on the sign bit.
+ V0 = DAG.getBitcast(VT, V0);
+ V1 = DAG.getBitcast(VT, V1);
+ Sel = DAG.getBitcast(VT, Sel);
+ return DAG.getBitcast(SelVT, DAG.getSelect(DL, VT, Sel, V0, V1));
+ }
+ // On pre-SSE41 targets we test for the sign bit by comparing to
+ // zero - a negative value will set all bits of the lanes to true
+ // and VSELECT uses that in its OR(AND(V0,C),AND(V1,~C)) lowering.
+ SDValue Z = getZeroVector(SelVT, Subtarget, DAG, DL);
+ SDValue C = DAG.getNode(X86ISD::PCMPGT, DL, SelVT, Z, Sel);
+ return DAG.getSelect(DL, SelVT, C, V0, V1);
+ };
+
+ // Turn 'a' into a mask suitable for VSELECT: a = a << 5;
+ // We can safely do this using i16 shifts as we're only interested in
+ // the 3 lower bits of each byte.
+ Amt = DAG.getBitcast(ExtVT, Amt);
+ Amt = DAG.getNode(ISD::SHL, DL, ExtVT, Amt, DAG.getConstant(5, DL, ExtVT));
+ Amt = DAG.getBitcast(VT, Amt);
+
+ // r = VSELECT(r, rot(r, 4), a);
+ SDValue M;
+ M = DAG.getNode(
+ ISD::OR, DL, VT,
+ DAG.getNode(ISD::SHL, DL, VT, R, DAG.getConstant(4, DL, VT)),
+ DAG.getNode(ISD::SRL, DL, VT, R, DAG.getConstant(4, DL, VT)));
+ R = SignBitSelect(VT, Amt, M, R);
+
+ // a += a
+ Amt = DAG.getNode(ISD::ADD, DL, VT, Amt, Amt);
+
+ // r = VSELECT(r, rot(r, 2), a);
+ M = DAG.getNode(
+ ISD::OR, DL, VT,
+ DAG.getNode(ISD::SHL, DL, VT, R, DAG.getConstant(2, DL, VT)),
+ DAG.getNode(ISD::SRL, DL, VT, R, DAG.getConstant(6, DL, VT)));
+ R = SignBitSelect(VT, Amt, M, R);
+
+ // a += a
+ Amt = DAG.getNode(ISD::ADD, DL, VT, Amt, Amt);
+
+ // return VSELECT(r, rot(r, 1), a);
+ M = DAG.getNode(
+ ISD::OR, DL, VT,
+ DAG.getNode(ISD::SHL, DL, VT, R, DAG.getConstant(1, DL, VT)),
+ DAG.getNode(ISD::SRL, DL, VT, R, DAG.getConstant(7, DL, VT)));
+ return SignBitSelect(VT, Amt, M, R);
+ }
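// Illustrative scalar sketch (not part of the patch; helper name is ours):
// the staged byte rotate above. The amount is pre-shifted left by 5 so that
// bits 2, 1 and 0 of the original amount reach the sign bit before the rot4,
// rot2 and rot1 select stages respectively (the "a += a" steps move the next
// bit into position).
#include <cstdint>
static uint8_t rotl8Staged(uint8_t R, uint8_t Amt) {
  if (Amt & 4) R = (uint8_t)((R << 4) | (R >> 4)); // select rot-by-4
  if (Amt & 2) R = (uint8_t)((R << 2) | (R >> 6)); // select rot-by-2
  if (Amt & 1) R = (uint8_t)((R << 1) | (R >> 7)); // select rot-by-1
  return R;
}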
+
+ bool ConstantAmt = ISD::isBuildVectorOfConstantSDNodes(Amt.getNode());
+ bool LegalVarShifts = SupportedVectorVarShift(VT, Subtarget, ISD::SHL) &&
+ SupportedVectorVarShift(VT, Subtarget, ISD::SRL);
+
+  // Best to fall back for all supported variable shifts.
+  // AVX2 - best to fall back for non-constants as well.
+ // TODO - legalizers should be able to handle this.
+ if (LegalVarShifts || (Subtarget.hasAVX2() && !ConstantAmt)) {
+ SDValue AmtR = DAG.getConstant(EltSizeInBits, DL, VT);
+ AmtR = DAG.getNode(ISD::SUB, DL, VT, AmtR, Amt);
+ SDValue SHL = DAG.getNode(ISD::SHL, DL, VT, R, Amt);
+ SDValue SRL = DAG.getNode(ISD::SRL, DL, VT, R, AmtR);
+ return DAG.getNode(ISD::OR, DL, VT, SHL, SRL);
+ }
+
+ // As with shifts, convert the rotation amount to a multiplication factor.
+ SDValue Scale = convertShiftLeftToScale(Amt, DL, Subtarget, DAG);
+ assert(Scale && "Failed to convert ROTL amount to scale");
+
+ // v8i16/v16i16: perform unsigned multiply hi/lo and OR the results.
+ if (EltSizeInBits == 16) {
+ SDValue Lo = DAG.getNode(ISD::MUL, DL, VT, R, Scale);
+ SDValue Hi = DAG.getNode(ISD::MULHU, DL, VT, R, Scale);
+ return DAG.getNode(ISD::OR, DL, VT, Lo, Hi);
+ }
+
+ // v4i32: make use of the PMULUDQ instruction to multiply 2 lanes of v4i32
+ // to v2i64 results at a time. The upper 32-bits contain the wrapped bits
+ // that can then be OR'd with the lower 32-bits.
+ assert(VT == MVT::v4i32 && "Only v4i32 vector rotate expected");
+ static const int OddMask[] = {1, -1, 3, -1};
+ SDValue R13 = DAG.getVectorShuffle(VT, DL, R, R, OddMask);
+ SDValue Scale13 = DAG.getVectorShuffle(VT, DL, Scale, Scale, OddMask);
+
+ SDValue Res02 = DAG.getNode(X86ISD::PMULUDQ, DL, MVT::v2i64,
+ DAG.getBitcast(MVT::v2i64, R),
+ DAG.getBitcast(MVT::v2i64, Scale));
+ SDValue Res13 = DAG.getNode(X86ISD::PMULUDQ, DL, MVT::v2i64,
+ DAG.getBitcast(MVT::v2i64, R13),
+ DAG.getBitcast(MVT::v2i64, Scale13));
+ Res02 = DAG.getBitcast(VT, Res02);
+ Res13 = DAG.getBitcast(VT, Res13);
+
+ return DAG.getNode(ISD::OR, DL, VT,
+ DAG.getVectorShuffle(VT, DL, Res02, Res13, {0, 4, 2, 6}),
+ DAG.getVectorShuffle(VT, DL, Res02, Res13, {1, 5, 3, 7}));
}
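// Illustrative scalar sketch (not part of the patch; helper name is ours):
// the multiply-based rotates above. A widening multiply by 2^R pushes the
// rotated-out bits into the high half of the double-width product, so OR-ing
// the low and high halves yields the rotation (the v8i16 path uses MUL/MULHU,
// the v4i32 path PMULUDQ).
#include <cstdint>
static uint16_t rotl16ViaMul(uint16_t X, unsigned R) {
  uint32_t Prod = (uint32_t)X << (R & 15);   // X * 2^R, kept in 32 bits
  return (uint16_t)(Prod | (Prod >> 16));    // lo16 | hi16 == rotl(X, R)
}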
static SDValue LowerXALUO(SDValue Op, SelectionDAG &DAG) {
@@ -23353,9 +24072,6 @@ static SDValue LowerXALUO(SDValue Op, SelectionDAG &DAG) {
SDValue SetCC = getSETCC(X86::COND_O, SDValue(Sum.getNode(), 2), DL, DAG);
- if (N->getValueType(1) == MVT::i1)
- SetCC = DAG.getNode(ISD::TRUNCATE, DL, MVT::i1, SetCC);
-
return DAG.getNode(ISD::MERGE_VALUES, DL, N->getVTList(), Sum, SetCC);
}
}
@@ -23366,9 +24082,6 @@ static SDValue LowerXALUO(SDValue Op, SelectionDAG &DAG) {
SDValue SetCC = getSETCC(Cond, SDValue(Sum.getNode(), 1), DL, DAG);
- if (N->getValueType(1) == MVT::i1)
- SetCC = DAG.getNode(ISD::TRUNCATE, DL, MVT::i1, SetCC);
-
return DAG.getNode(ISD::MERGE_VALUES, DL, N->getVTList(), Sum, SetCC);
}
@@ -23572,11 +24285,68 @@ static SDValue LowerCMP_SWAP(SDValue Op, const X86Subtarget &Subtarget,
return SDValue();
}
+// Create MOVMSKB, taking into account whether we need to split for AVX1.
+static SDValue getPMOVMSKB(const SDLoc &DL, SDValue V, SelectionDAG &DAG,
+ const X86Subtarget &Subtarget) {
+ MVT InVT = V.getSimpleValueType();
+
+ if (InVT == MVT::v32i8 && !Subtarget.hasInt256()) {
+ SDValue Lo, Hi;
+ std::tie(Lo, Hi) = DAG.SplitVector(V, DL);
+ Lo = DAG.getNode(X86ISD::MOVMSK, DL, MVT::i32, Lo);
+ Hi = DAG.getNode(X86ISD::MOVMSK, DL, MVT::i32, Hi);
+ Hi = DAG.getNode(ISD::SHL, DL, MVT::i32, Hi,
+ DAG.getConstant(16, DL, MVT::i8));
+ return DAG.getNode(ISD::OR, DL, MVT::i32, Lo, Hi);
+ }
+
+ return DAG.getNode(X86ISD::MOVMSK, DL, MVT::i32, V);
+}
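// Illustrative scalar sketch (not part of the patch; helper name is ours):
// the AVX1 split above. Without AVX2 a 32-byte sign-bit mask is assembled
// from two 16-byte MOVMSKs, with the high half's mask shifted into bits 16-31.
#include <cstdint>
static uint32_t movmskb32(const int8_t V[32]) {
  uint32_t Lo = 0, Hi = 0;
  for (int I = 0; I < 16; ++I) {
    Lo |= (uint32_t)(V[I] < 0) << I;        // PMOVMSKB on the low 16 bytes
    Hi |= (uint32_t)(V[16 + I] < 0) << I;   // PMOVMSKB on the high 16 bytes
  }
  return Lo | (Hi << 16);                   // SHL + OR, as in the DAG above
}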
+
static SDValue LowerBITCAST(SDValue Op, const X86Subtarget &Subtarget,
SelectionDAG &DAG) {
- MVT SrcVT = Op.getOperand(0).getSimpleValueType();
+ SDValue Src = Op.getOperand(0);
+ MVT SrcVT = Src.getSimpleValueType();
MVT DstVT = Op.getSimpleValueType();
+ // Legalize (v64i1 (bitcast i64 (X))) by splitting the i64, bitcasting each
+ // half to v32i1 and concatenating the result.
+ if (SrcVT == MVT::i64 && DstVT == MVT::v64i1) {
+ assert(!Subtarget.is64Bit() && "Expected 32-bit mode");
+ assert(Subtarget.hasBWI() && "Expected BWI target");
+ SDLoc dl(Op);
+ SDValue Lo = DAG.getNode(ISD::EXTRACT_ELEMENT, dl, MVT::i32, Src,
+ DAG.getIntPtrConstant(0, dl));
+ Lo = DAG.getBitcast(MVT::v32i1, Lo);
+ SDValue Hi = DAG.getNode(ISD::EXTRACT_ELEMENT, dl, MVT::i32, Src,
+ DAG.getIntPtrConstant(1, dl));
+ Hi = DAG.getBitcast(MVT::v32i1, Hi);
+ return DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v64i1, Lo, Hi);
+ }
+
+ // Custom splitting for BWI types when AVX512F is available but BWI isn't.
+ if ((SrcVT == MVT::v32i16 || SrcVT == MVT::v64i8) && DstVT.isVector() &&
+ DAG.getTargetLoweringInfo().isTypeLegal(DstVT)) {
+ SDLoc dl(Op);
+ SDValue Lo, Hi;
+ std::tie(Lo, Hi) = DAG.SplitVector(Op.getOperand(0), dl);
+ EVT CastVT = MVT::getVectorVT(DstVT.getVectorElementType(),
+ DstVT.getVectorNumElements() / 2);
+ Lo = DAG.getBitcast(CastVT, Lo);
+ Hi = DAG.getBitcast(CastVT, Hi);
+ return DAG.getNode(ISD::CONCAT_VECTORS, dl, DstVT, Lo, Hi);
+ }
+
+ // Use MOVMSK for vector to scalar conversion to prevent scalarization.
+ if ((SrcVT == MVT::v16i1 || SrcVT == MVT::v32i1) && DstVT.isScalarInteger()) {
+ assert(!Subtarget.hasAVX512() && "Should use K-registers with AVX512");
+ MVT SExtVT = SrcVT == MVT::v16i1 ? MVT::v16i8 : MVT::v32i8;
+ SDLoc DL(Op);
+ SDValue V = DAG.getSExtOrTrunc(Src, DL, SExtVT);
+ V = getPMOVMSKB(DL, V, DAG, Subtarget);
+ return DAG.getZExtOrTrunc(V, DL, DstVT);
+ }
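// Illustrative intrinsics sketch of the v16i1 -> i16 path (not the lowering
// itself; SSE2 only, assumes one 0/1 boolean byte per lane): each lane is
// sign-extended to a 0x00/0xff byte and PMOVMSKB collapses the sign bits into
// a scalar bitmask.
#include <emmintrin.h>
#include <cstdint>

static inline uint16_t bitcast_v16i1_to_i16(__m128i BoolBytes /* 0 or 1 */) {
  __m128i Sext = _mm_cmpgt_epi8(BoolBytes, _mm_setzero_si128()); // 0x00/0xff
  return (uint16_t)_mm_movemask_epi8(Sext);
}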
+
if (SrcVT == MVT::v2i32 || SrcVT == MVT::v4i16 || SrcVT == MVT::v8i8 ||
SrcVT == MVT::i64) {
assert(Subtarget.hasSSE2() && "Requires at least SSE2!");
@@ -23584,7 +24354,6 @@ static SDValue LowerBITCAST(SDValue Op, const X86Subtarget &Subtarget,
// This conversion needs to be expanded.
return SDValue();
- SDValue Op0 = Op->getOperand(0);
SmallVector<SDValue, 16> Elts;
SDLoc dl(Op);
unsigned NumElts;
@@ -23596,14 +24365,14 @@ static SDValue LowerBITCAST(SDValue Op, const X86Subtarget &Subtarget,
// Widen the input vector in the case of MVT::v2i32.
// Example: from MVT::v2i32 to MVT::v4i32.
for (unsigned i = 0, e = NumElts; i != e; ++i)
- Elts.push_back(DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, SVT, Op0,
+ Elts.push_back(DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, SVT, Src,
DAG.getIntPtrConstant(i, dl)));
} else {
assert(SrcVT == MVT::i64 && !Subtarget.is64Bit() &&
"Unexpected source type in LowerBITCAST");
- Elts.push_back(DAG.getNode(ISD::EXTRACT_ELEMENT, dl, MVT::i32, Op0,
+ Elts.push_back(DAG.getNode(ISD::EXTRACT_ELEMENT, dl, MVT::i32, Src,
DAG.getIntPtrConstant(0, dl)));
- Elts.push_back(DAG.getNode(ISD::EXTRACT_ELEMENT, dl, MVT::i32, Op0,
+ Elts.push_back(DAG.getNode(ISD::EXTRACT_ELEMENT, dl, MVT::i32, Src,
DAG.getIntPtrConstant(1, dl)));
NumElts = 2;
SVT = MVT::i32;
@@ -23842,7 +24611,7 @@ static SDValue LowerVectorCTPOP(SDValue Op, const X86Subtarget &Subtarget,
unsigned NumElems = VT.getVectorNumElements();
assert((VT.getVectorElementType() == MVT::i8 ||
VT.getVectorElementType() == MVT::i16) && "Unexpected type");
- if (NumElems <= 16) {
+ if (NumElems < 16 || (NumElems == 16 && Subtarget.canExtendTo512DQ())) {
MVT NewVT = MVT::getVectorVT(MVT::i32, NumElems);
Op = DAG.getNode(ISD::ZERO_EXTEND, DL, NewVT, Op0);
Op = DAG.getNode(ISD::CTPOP, DL, NewVT, Op);
@@ -24224,76 +24993,81 @@ static SDValue LowerMSCATTER(SDValue Op, const X86Subtarget &Subtarget,
assert(VT.getScalarSizeInBits() >= 32 && "Unsupported scatter op");
SDLoc dl(Op);
+ SDValue Scale = N->getScale();
SDValue Index = N->getIndex();
SDValue Mask = N->getMask();
SDValue Chain = N->getChain();
SDValue BasePtr = N->getBasePtr();
- MVT MemVT = N->getMemoryVT().getSimpleVT();
+
+ if (VT == MVT::v2f32) {
+ assert(Mask.getValueType() == MVT::v2i1 && "Unexpected mask type");
+ // If the index is v2i64 and we have VLX we can use xmm for data and index.
+ if (Index.getValueType() == MVT::v2i64 && Subtarget.hasVLX()) {
+ Src = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v4f32, Src,
+ DAG.getUNDEF(MVT::v2f32));
+ SDVTList VTs = DAG.getVTList(MVT::v2i1, MVT::Other);
+ SDValue Ops[] = {Chain, Src, Mask, BasePtr, Index, Scale};
+ SDValue NewScatter = DAG.getTargetMemSDNode<X86MaskedScatterSDNode>(
+ VTs, Ops, dl, N->getMemoryVT(), N->getMemOperand());
+ DAG.ReplaceAllUsesWith(Op, SDValue(NewScatter.getNode(), 1));
+ return SDValue(NewScatter.getNode(), 1);
+ }
+ return SDValue();
+ }
+
+ if (VT == MVT::v2i32) {
+ assert(Mask.getValueType() == MVT::v2i1 && "Unexpected mask type");
+ Src = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v4i32, Src,
+ DAG.getUNDEF(MVT::v2i32));
+ // If the index is v2i64 and we have VLX we can use xmm for data and index.
+ if (Index.getValueType() == MVT::v2i64 && Subtarget.hasVLX()) {
+ SDVTList VTs = DAG.getVTList(MVT::v2i1, MVT::Other);
+ SDValue Ops[] = {Chain, Src, Mask, BasePtr, Index, Scale};
+ SDValue NewScatter = DAG.getTargetMemSDNode<X86MaskedScatterSDNode>(
+ VTs, Ops, dl, N->getMemoryVT(), N->getMemOperand());
+ DAG.ReplaceAllUsesWith(Op, SDValue(NewScatter.getNode(), 1));
+ return SDValue(NewScatter.getNode(), 1);
+ }
+ // Custom widen all the operands to avoid promotion.
+ EVT NewIndexVT = EVT::getVectorVT(
+ *DAG.getContext(), Index.getValueType().getVectorElementType(), 4);
+ Index = DAG.getNode(ISD::CONCAT_VECTORS, dl, NewIndexVT, Index,
+ DAG.getUNDEF(Index.getValueType()));
+ Mask = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v4i1, Mask,
+ DAG.getConstant(0, dl, MVT::v2i1));
+ SDValue Ops[] = {Chain, Src, Mask, BasePtr, Index, Scale};
+ return DAG.getMaskedScatter(DAG.getVTList(MVT::Other), N->getMemoryVT(), dl,
+ Ops, N->getMemOperand());
+ }
+
MVT IndexVT = Index.getSimpleValueType();
MVT MaskVT = Mask.getSimpleValueType();
- if (MemVT.getScalarSizeInBits() < VT.getScalarSizeInBits()) {
- // The v2i32 value was promoted to v2i64.
- // Now we "redo" the type legalizer's work and widen the original
- // v2i32 value to v4i32. The original v2i32 is retrieved from v2i64
- // with a shuffle.
- assert((MemVT == MVT::v2i32 && VT == MVT::v2i64) &&
- "Unexpected memory type");
- int ShuffleMask[] = {0, 2, -1, -1};
- Src = DAG.getVectorShuffle(MVT::v4i32, dl, DAG.getBitcast(MVT::v4i32, Src),
- DAG.getUNDEF(MVT::v4i32), ShuffleMask);
- // Now we have 4 elements instead of 2.
- // Expand the index.
- MVT NewIndexVT = MVT::getVectorVT(IndexVT.getScalarType(), 4);
- Index = ExtendToType(Index, NewIndexVT, DAG);
-
- // Expand the mask with zeroes
- // Mask may be <2 x i64> or <2 x i1> at this moment
- assert((MaskVT == MVT::v2i1 || MaskVT == MVT::v2i64) &&
- "Unexpected mask type");
- MVT ExtMaskVT = MVT::getVectorVT(MaskVT.getScalarType(), 4);
- Mask = ExtendToType(Mask, ExtMaskVT, DAG, true);
- VT = MVT::v4i32;
- }
+ // If the index is v2i32, we're being called by type legalization and we
+ // should just let the default handling take care of it.
+ if (IndexVT == MVT::v2i32)
+ return SDValue();
- unsigned NumElts = VT.getVectorNumElements();
+ // If we don't have VLX and neither the data nor the index is 512 bits, we
+ // need to widen until one is.
if (!Subtarget.hasVLX() && !VT.is512BitVector() &&
!Index.getSimpleValueType().is512BitVector()) {
- // AVX512F supports only 512-bit vectors. Or data or index should
- // be 512 bit wide. If now the both index and data are 256-bit, but
- // the vector contains 8 elements, we just sign-extend the index
- if (IndexVT == MVT::v8i32)
- // Just extend index
- Index = DAG.getNode(ISD::SIGN_EXTEND, dl, MVT::v8i64, Index);
- else {
- // The minimal number of elts in scatter is 8
- NumElts = 8;
- // Index
- MVT NewIndexVT = MVT::getVectorVT(IndexVT.getScalarType(), NumElts);
- // Use original index here, do not modify the index twice
- Index = ExtendToType(N->getIndex(), NewIndexVT, DAG);
- if (IndexVT.getScalarType() == MVT::i32)
- Index = DAG.getNode(ISD::SIGN_EXTEND, dl, MVT::v8i64, Index);
-
- // Mask
- // At this point we have promoted mask operand
- assert(MaskVT.getScalarSizeInBits() >= 32 && "unexpected mask type");
- MVT ExtMaskVT = MVT::getVectorVT(MaskVT.getScalarType(), NumElts);
- // Use the original mask here, do not modify the mask twice
- Mask = ExtendToType(N->getMask(), ExtMaskVT, DAG, true);
-
- // The value that should be stored
- MVT NewVT = MVT::getVectorVT(VT.getScalarType(), NumElts);
- Src = ExtendToType(Src, NewVT, DAG);
- }
- }
- // If the mask is "wide" at this point - truncate it to i1 vector
- MVT BitMaskVT = MVT::getVectorVT(MVT::i1, NumElts);
- Mask = DAG.getNode(ISD::TRUNCATE, dl, BitMaskVT, Mask);
-
- // The mask is killed by scatter, add it to the values
- SDVTList VTs = DAG.getVTList(BitMaskVT, MVT::Other);
- SDValue Ops[] = {Chain, Src, Mask, BasePtr, Index};
+ // Determine how much we need to widen by to get a 512-bit type.
+ unsigned Factor = std::min(512/VT.getSizeInBits(),
+ 512/IndexVT.getSizeInBits());
+ unsigned NumElts = VT.getVectorNumElements() * Factor;
+
+ VT = MVT::getVectorVT(VT.getVectorElementType(), NumElts);
+ IndexVT = MVT::getVectorVT(IndexVT.getVectorElementType(), NumElts);
+ MaskVT = MVT::getVectorVT(MVT::i1, NumElts);
+
+ Src = ExtendToType(Src, VT, DAG);
+ Index = ExtendToType(Index, IndexVT, DAG);
+ Mask = ExtendToType(Mask, MaskVT, DAG, true);
+ }
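// Worked example of the widening factor above (hypothetical helper, for
// illustration only): with v4f32 data (128 bits) and a v4i64 index (256
// bits), Factor = min(512/128, 512/256) = 2, so the scatter is widened to
// v8f32 data, a v8i64 index and a v8i1 mask -- the index becomes the 512-bit
// operand that plain AVX512F requires.
#include <algorithm>

static unsigned widenFactorTo512(unsigned DataBits, unsigned IndexBits) {
  return std::min(512u / DataBits, 512u / IndexBits);
}
// widenFactorTo512(128, 256) == 2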
+
+ SDVTList VTs = DAG.getVTList(MaskVT, MVT::Other);
+ SDValue Ops[] = {Chain, Src, Mask, BasePtr, Index, Scale};
SDValue NewScatter = DAG.getTargetMemSDNode<X86MaskedScatterSDNode>(
VTs, Ops, dl, N->getMemoryVT(), N->getMemOperand());
DAG.ReplaceAllUsesWith(Op, SDValue(NewScatter.getNode(), 1));
@@ -24315,11 +25089,6 @@ static SDValue LowerMLOAD(SDValue Op, const X86Subtarget &Subtarget,
assert((!N->isExpandingLoad() || ScalarVT.getSizeInBits() >= 32) &&
"Expanding masked load is supported for 32 and 64-bit types only!");
- // 4x32, 4x64 and 2x64 vectors of non-expanding loads are legal regardless of
- // VLX. These types for exp-loads are handled here.
- if (!N->isExpandingLoad() && VT.getVectorNumElements() <= 4)
- return Op;
-
assert(Subtarget.hasAVX512() && !Subtarget.hasVLX() && !VT.is512BitVector() &&
"Cannot lower masked load op.");
@@ -24336,16 +25105,12 @@ static SDValue LowerMLOAD(SDValue Op, const X86Subtarget &Subtarget,
Src0 = ExtendToType(Src0, WideDataVT, DAG);
// Mask element has to be i1.
- MVT MaskEltTy = Mask.getSimpleValueType().getScalarType();
- assert((MaskEltTy == MVT::i1 || VT.getVectorNumElements() <= 4) &&
- "We handle 4x32, 4x64 and 2x64 vectors only in this case");
+ assert(Mask.getSimpleValueType().getScalarType() == MVT::i1 &&
+ "Unexpected mask type");
- MVT WideMaskVT = MVT::getVectorVT(MaskEltTy, NumEltsInWideVec);
+ MVT WideMaskVT = MVT::getVectorVT(MVT::i1, NumEltsInWideVec);
Mask = ExtendToType(Mask, WideMaskVT, DAG, true);
- if (MaskEltTy != MVT::i1)
- Mask = DAG.getNode(ISD::TRUNCATE, dl,
- MVT::getVectorVT(MVT::i1, NumEltsInWideVec), Mask);
SDValue NewLoad = DAG.getMaskedLoad(WideDataVT, dl, N->getChain(),
N->getBasePtr(), Mask, Src0,
N->getMemoryVT(), N->getMemOperand(),
@@ -24374,10 +25139,6 @@ static SDValue LowerMSTORE(SDValue Op, const X86Subtarget &Subtarget,
assert((!N->isCompressingStore() || ScalarVT.getSizeInBits() >= 32) &&
"Expanding masked load is supported for 32 and 64-bit types only!");
- // 4x32 and 2x64 vectors of non-compressing stores are legal regardless to VLX.
- if (!N->isCompressingStore() && VT.getVectorNumElements() <= 4)
- return Op;
-
assert(Subtarget.hasAVX512() && !Subtarget.hasVLX() && !VT.is512BitVector() &&
"Cannot lower masked store op.");
@@ -24392,17 +25153,13 @@ static SDValue LowerMSTORE(SDValue Op, const X86Subtarget &Subtarget,
MVT WideDataVT = MVT::getVectorVT(ScalarVT, NumEltsInWideVec);
// Mask element has to be i1.
- MVT MaskEltTy = Mask.getSimpleValueType().getScalarType();
- assert((MaskEltTy == MVT::i1 || VT.getVectorNumElements() <= 4) &&
- "We handle 4x32, 4x64 and 2x64 vectors only in this case");
+ assert(Mask.getSimpleValueType().getScalarType() == MVT::i1 &&
+ "Unexpected mask type");
- MVT WideMaskVT = MVT::getVectorVT(MaskEltTy, NumEltsInWideVec);
+ MVT WideMaskVT = MVT::getVectorVT(MVT::i1, NumEltsInWideVec);
DataToStore = ExtendToType(DataToStore, WideDataVT, DAG);
Mask = ExtendToType(Mask, WideMaskVT, DAG, true);
- if (MaskEltTy != MVT::i1)
- Mask = DAG.getNode(ISD::TRUNCATE, dl,
- MVT::getVectorVT(MVT::i1, NumEltsInWideVec), Mask);
return DAG.getMaskedStore(N->getChain(), dl, DataToStore, N->getBasePtr(),
Mask, N->getMemoryVT(), N->getMemOperand(),
N->isTruncatingStore(), N->isCompressingStore());
@@ -24422,63 +25179,40 @@ static SDValue LowerMGATHER(SDValue Op, const X86Subtarget &Subtarget,
MVT IndexVT = Index.getSimpleValueType();
MVT MaskVT = Mask.getSimpleValueType();
- unsigned NumElts = VT.getVectorNumElements();
assert(VT.getScalarSizeInBits() >= 32 && "Unsupported gather op");
// If the index is v2i32, we're being called by type legalization.
if (IndexVT == MVT::v2i32)
return SDValue();
+ // If we don't have VLX and neither the passthru nor the index is 512 bits,
+ // we need to widen until one is.
+ MVT OrigVT = VT;
if (Subtarget.hasAVX512() && !Subtarget.hasVLX() && !VT.is512BitVector() &&
- !Index.getSimpleValueType().is512BitVector()) {
- // AVX512F supports only 512-bit vectors. Or data or index should
- // be 512 bit wide. If now the both index and data are 256-bit, but
- // the vector contains 8 elements, we just sign-extend the index
- if (NumElts == 8) {
- Index = DAG.getNode(ISD::SIGN_EXTEND, dl, MVT::v8i64, Index);
- SDValue Ops[] = { N->getChain(), Src0, Mask, N->getBasePtr(), Index };
- SDValue NewGather = DAG.getTargetMemSDNode<X86MaskedGatherSDNode>(
- DAG.getVTList(VT, MaskVT, MVT::Other), Ops, dl, N->getMemoryVT(),
- N->getMemOperand());
- return DAG.getMergeValues({NewGather, NewGather.getValue(2)}, dl);
- }
-
- // Minimal number of elements in Gather
- NumElts = 8;
- // Index
- MVT NewIndexVT = MVT::getVectorVT(IndexVT.getScalarType(), NumElts);
- Index = ExtendToType(Index, NewIndexVT, DAG);
- if (IndexVT.getScalarType() == MVT::i32)
- Index = DAG.getNode(ISD::SIGN_EXTEND, dl, MVT::v8i64, Index);
-
- // Mask
- MVT MaskBitVT = MVT::getVectorVT(MVT::i1, NumElts);
- // At this point we have promoted mask operand
- assert(MaskVT.getScalarSizeInBits() >= 32 && "unexpected mask type");
- MVT ExtMaskVT = MVT::getVectorVT(MaskVT.getScalarType(), NumElts);
- Mask = ExtendToType(Mask, ExtMaskVT, DAG, true);
- Mask = DAG.getNode(ISD::TRUNCATE, dl, MaskBitVT, Mask);
-
- // The pass-through value
- MVT NewVT = MVT::getVectorVT(VT.getScalarType(), NumElts);
- Src0 = ExtendToType(Src0, NewVT, DAG);
-
- SDValue Ops[] = { N->getChain(), Src0, Mask, N->getBasePtr(), Index };
- SDValue NewGather = DAG.getTargetMemSDNode<X86MaskedGatherSDNode>(
- DAG.getVTList(NewVT, MaskBitVT, MVT::Other), Ops, dl, N->getMemoryVT(),
- N->getMemOperand());
- SDValue Extract = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, VT,
- NewGather.getValue(0),
- DAG.getIntPtrConstant(0, dl));
- SDValue RetOps[] = {Extract, NewGather.getValue(2)};
- return DAG.getMergeValues(RetOps, dl);
+ !IndexVT.is512BitVector()) {
+ // Determine how much we need to widen by to get a 512-bit type.
+ unsigned Factor = std::min(512/VT.getSizeInBits(),
+ 512/IndexVT.getSizeInBits());
+
+ unsigned NumElts = VT.getVectorNumElements() * Factor;
+
+ VT = MVT::getVectorVT(VT.getVectorElementType(), NumElts);
+ IndexVT = MVT::getVectorVT(IndexVT.getVectorElementType(), NumElts);
+ MaskVT = MVT::getVectorVT(MVT::i1, NumElts);
+
+ Src0 = ExtendToType(Src0, VT, DAG);
+ Index = ExtendToType(Index, IndexVT, DAG);
+ Mask = ExtendToType(Mask, MaskVT, DAG, true);
}
- SDValue Ops[] = { N->getChain(), Src0, Mask, N->getBasePtr(), Index };
+ SDValue Ops[] = { N->getChain(), Src0, Mask, N->getBasePtr(), Index,
+ N->getScale() };
SDValue NewGather = DAG.getTargetMemSDNode<X86MaskedGatherSDNode>(
DAG.getVTList(VT, MaskVT, MVT::Other), Ops, dl, N->getMemoryVT(),
N->getMemOperand());
- return DAG.getMergeValues({NewGather, NewGather.getValue(2)}, dl);
+ SDValue Extract = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, OrigVT,
+ NewGather, DAG.getIntPtrConstant(0, dl));
+ return DAG.getMergeValues({Extract, NewGather.getValue(2)}, dl);
}
SDValue X86TargetLowering::LowerGC_TRANSITION_START(SDValue Op,
@@ -24545,6 +25279,7 @@ SDValue X86TargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
case ISD::EXTRACT_VECTOR_ELT: return LowerEXTRACT_VECTOR_ELT(Op, DAG);
case ISD::INSERT_VECTOR_ELT: return LowerINSERT_VECTOR_ELT(Op, DAG);
case ISD::INSERT_SUBVECTOR: return LowerINSERT_SUBVECTOR(Op, Subtarget,DAG);
+ case ISD::EXTRACT_SUBVECTOR: return LowerEXTRACT_SUBVECTOR(Op,Subtarget,DAG);
case ISD::SCALAR_TO_VECTOR: return LowerSCALAR_TO_VECTOR(Op, Subtarget,DAG);
case ISD::ConstantPool: return LowerConstantPool(Op, DAG);
case ISD::GlobalAddress: return LowerGlobalAddress(Op, DAG);
@@ -24566,7 +25301,8 @@ SDValue X86TargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
case ISD::FP_TO_SINT:
case ISD::FP_TO_UINT: return LowerFP_TO_INT(Op, DAG);
case ISD::FP_EXTEND: return LowerFP_EXTEND(Op, DAG);
- case ISD::LOAD: return LowerExtendedLoad(Op, Subtarget, DAG);
+ case ISD::LOAD: return LowerLoad(Op, Subtarget, DAG);
+ case ISD::STORE: return LowerStore(Op, Subtarget, DAG);
case ISD::FABS:
case ISD::FNEG: return LowerFABSorFNEG(Op, DAG);
case ISD::FCOPYSIGN: return LowerFCOPYSIGN(Op, DAG);
@@ -24635,7 +25371,6 @@ SDValue X86TargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
case ISD::GC_TRANSITION_START:
return LowerGC_TRANSITION_START(Op, DAG);
case ISD::GC_TRANSITION_END: return LowerGC_TRANSITION_END(Op, DAG);
- case ISD::STORE: return LowerTruncatingStore(Op, Subtarget, DAG);
}
}
@@ -24676,19 +25411,13 @@ void X86TargetLowering::ReplaceNodeResults(SDNode *N,
assert(Subtarget.hasSSE2() && "Requires at least SSE2!");
auto InVT = N->getValueType(0);
- auto InVTSize = InVT.getSizeInBits();
- const unsigned RegSize =
- (InVTSize > 128) ? ((InVTSize > 256) ? 512 : 256) : 128;
- assert((Subtarget.hasBWI() || RegSize < 512) &&
- "512-bit vector requires AVX512BW");
- assert((Subtarget.hasAVX2() || RegSize < 256) &&
- "256-bit vector requires AVX2");
-
- auto ElemVT = InVT.getVectorElementType();
- auto RegVT = EVT::getVectorVT(*DAG.getContext(), ElemVT,
- RegSize / ElemVT.getSizeInBits());
- assert(RegSize % InVT.getSizeInBits() == 0);
- unsigned NumConcat = RegSize / InVT.getSizeInBits();
+ assert(InVT.getSizeInBits() < 128);
+ assert(128 % InVT.getSizeInBits() == 0);
+ unsigned NumConcat = 128 / InVT.getSizeInBits();
+
+ EVT RegVT = EVT::getVectorVT(*DAG.getContext(),
+ InVT.getVectorElementType(),
+ NumConcat * InVT.getVectorNumElements());
SmallVector<SDValue, 16> Ops(NumConcat, DAG.getUNDEF(InVT));
Ops[0] = N->getOperand(0);
@@ -24697,12 +25426,32 @@ void X86TargetLowering::ReplaceNodeResults(SDNode *N,
SDValue InVec1 = DAG.getNode(ISD::CONCAT_VECTORS, dl, RegVT, Ops);
SDValue Res = DAG.getNode(X86ISD::AVG, dl, RegVT, InVec0, InVec1);
- if (!ExperimentalVectorWideningLegalization)
+ if (getTypeAction(*DAG.getContext(), InVT) != TypeWidenVector)
Res = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, InVT, Res,
DAG.getIntPtrConstant(0, dl));
Results.push_back(Res);
return;
}
+ case ISD::SETCC: {
+ // Widen v2i32 (setcc v2f32). This is really needed for AVX512VL when
+ // setcc result type is v2i1 because type legalization will end up with
+ // a v4i1 setcc plus an extend.
+ assert(N->getValueType(0) == MVT::v2i32 && "Unexpected type");
+ if (N->getOperand(0).getValueType() != MVT::v2f32)
+ return;
+ SDValue UNDEF = DAG.getUNDEF(MVT::v2f32);
+ SDValue LHS = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v4f32,
+ N->getOperand(0), UNDEF);
+ SDValue RHS = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v4f32,
+ N->getOperand(1), UNDEF);
+ SDValue Res = DAG.getNode(ISD::SETCC, dl, MVT::v4i32, LHS, RHS,
+ N->getOperand(2));
+ if (getTypeAction(*DAG.getContext(), MVT::v2i32) != TypeWidenVector)
+ Res = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, MVT::v2i32, Res,
+ DAG.getIntPtrConstant(0, dl));
+ Results.push_back(Res);
+ return;
+ }
// We might have generated v2f32 FMIN/FMAX operations. Widen them to v4f32.
case X86ISD::FMINC:
case X86ISD::FMIN:
@@ -24731,12 +25480,14 @@ void X86TargetLowering::ReplaceNodeResults(SDNode *N,
case ISD::FP_TO_SINT:
case ISD::FP_TO_UINT: {
bool IsSigned = N->getOpcode() == ISD::FP_TO_SINT;
+ EVT VT = N->getValueType(0);
+ SDValue Src = N->getOperand(0);
+ EVT SrcVT = Src.getValueType();
- if (N->getValueType(0) == MVT::v2i32) {
+ if (VT == MVT::v2i32) {
assert((IsSigned || Subtarget.hasAVX512()) &&
"Can only handle signed conversion without AVX512");
assert(Subtarget.hasSSE2() && "Requires at least SSE2!");
- SDValue Src = N->getOperand(0);
if (Src.getValueType() == MVT::v2f64) {
MVT ResVT = MVT::v4i32;
unsigned Opc = IsSigned ? X86ISD::CVTTP2SI : X86ISD::CVTTP2UI;
@@ -24749,20 +25500,21 @@ void X86TargetLowering::ReplaceNodeResults(SDNode *N,
Src, DAG.getIntPtrConstant(0, dl));
}
SDValue Res = DAG.getNode(Opc, dl, ResVT, Src);
- ResVT = ExperimentalVectorWideningLegalization ? MVT::v4i32
- : MVT::v2i32;
+ bool WidenType = getTypeAction(*DAG.getContext(),
+ MVT::v2i32) == TypeWidenVector;
+ ResVT = WidenType ? MVT::v4i32 : MVT::v2i32;
Res = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, ResVT, Res,
DAG.getIntPtrConstant(0, dl));
Results.push_back(Res);
return;
}
- if (Src.getValueType() == MVT::v2f32) {
+ if (SrcVT == MVT::v2f32) {
SDValue Idx = DAG.getIntPtrConstant(0, dl);
SDValue Res = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v4f32, Src,
DAG.getUNDEF(MVT::v2f32));
Res = DAG.getNode(IsSigned ? ISD::FP_TO_SINT
: ISD::FP_TO_UINT, dl, MVT::v4i32, Res);
- if (!ExperimentalVectorWideningLegalization)
+ if (getTypeAction(*DAG.getContext(), MVT::v2i32) != TypeWidenVector)
Res = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, MVT::v2i32, Res, Idx);
Results.push_back(Res);
return;
@@ -24773,11 +25525,30 @@ void X86TargetLowering::ReplaceNodeResults(SDNode *N,
return;
}
+ if (Subtarget.hasDQI() && VT == MVT::i64 &&
+ (SrcVT == MVT::f32 || SrcVT == MVT::f64)) {
+ assert(!Subtarget.is64Bit() && "i64 should be legal");
+ unsigned NumElts = Subtarget.hasVLX() ? 4 : 8;
+ // Using a 256-bit input here to guarantee 128-bit input for f32 case.
+ // TODO: Use 128-bit vectors for f64 case?
+ // TODO: Use 128-bit vectors for f32 by using CVTTP2SI/CVTTP2UI.
+ MVT VecVT = MVT::getVectorVT(MVT::i64, NumElts);
+ MVT VecInVT = MVT::getVectorVT(SrcVT.getSimpleVT(), NumElts);
+
+ SDValue ZeroIdx = DAG.getIntPtrConstant(0, dl);
+ SDValue Res = DAG.getNode(ISD::INSERT_VECTOR_ELT, dl, VecInVT,
+ DAG.getConstantFP(0.0, dl, VecInVT), Src,
+ ZeroIdx);
+ Res = DAG.getNode(N->getOpcode(), SDLoc(N), VecVT, Res);
+ Res = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, VT, Res, ZeroIdx);
+ Results.push_back(Res);
+ return;
+ }
+
std::pair<SDValue,SDValue> Vals =
FP_TO_INTHelper(SDValue(N, 0), DAG, IsSigned, /*IsReplace=*/ true);
SDValue FIST = Vals.first, StackSlot = Vals.second;
if (FIST.getNode()) {
- EVT VT = N->getValueType(0);
// Return a load from the stack slot.
if (StackSlot.getNode())
Results.push_back(
@@ -24963,6 +25734,32 @@ void X86TargetLowering::ReplaceNodeResults(SDNode *N,
EVT DstVT = N->getValueType(0);
EVT SrcVT = N->getOperand(0).getValueType();
+ // If this is a bitcast from a v64i1 k-register to an i64 on a 32-bit target
+ // we can split using the k-register rather than memory.
+ if (SrcVT == MVT::v64i1 && DstVT == MVT::i64 && Subtarget.hasBWI()) {
+ assert(!Subtarget.is64Bit() && "Expected 32-bit mode");
+ SDValue Lo, Hi;
+ std::tie(Lo, Hi) = DAG.SplitVectorOperand(N, 0);
+ Lo = DAG.getBitcast(MVT::i32, Lo);
+ Hi = DAG.getBitcast(MVT::i32, Hi);
+ SDValue Res = DAG.getNode(ISD::BUILD_PAIR, dl, MVT::i64, Lo, Hi);
+ Results.push_back(Res);
+ return;
+ }
+
+ // Custom splitting for BWI types when AVX512F is available but BWI isn't.
+ if ((DstVT == MVT::v32i16 || DstVT == MVT::v64i8) &&
+ SrcVT.isVector() && isTypeLegal(SrcVT)) {
+ SDValue Lo, Hi;
+ std::tie(Lo, Hi) = DAG.SplitVectorOperand(N, 0);
+ MVT CastVT = (DstVT == MVT::v32i16) ? MVT::v16i16 : MVT::v32i8;
+ Lo = DAG.getBitcast(CastVT, Lo);
+ Hi = DAG.getBitcast(CastVT, Hi);
+ SDValue Res = DAG.getNode(ISD::CONCAT_VECTORS, dl, DstVT, Lo, Hi);
+ Results.push_back(Res);
+ return;
+ }
+
if (SrcVT != MVT::f64 ||
(DstVT != MVT::v2i32 && DstVT != MVT::v4i16 && DstVT != MVT::v8i8))
return;
@@ -24974,7 +25771,7 @@ void X86TargetLowering::ReplaceNodeResults(SDNode *N,
MVT::v2f64, N->getOperand(0));
SDValue ToVecInt = DAG.getBitcast(WiderVT, Expanded);
- if (ExperimentalVectorWideningLegalization) {
+ if (getTypeAction(*DAG.getContext(), DstVT) == TypeWidenVector) {
// If we are legalizing vectors by widening, we already have the desired
// legal vector type, just return it.
Results.push_back(ToVecInt);
@@ -25009,7 +25806,7 @@ void X86TargetLowering::ReplaceNodeResults(SDNode *N,
Mask = DAG.getNode(ISD::SIGN_EXTEND, dl, MVT::v4i32, Mask);
}
SDValue Ops[] = { Gather->getChain(), Src0, Mask, Gather->getBasePtr(),
- Index };
+ Index, Gather->getScale() };
SDValue Res = DAG.getTargetMemSDNode<X86MaskedGatherSDNode>(
DAG.getVTList(MVT::v4f32, Mask.getValueType(), MVT::Other), Ops, dl,
Gather->getMemoryVT(), Gather->getMemOperand());
@@ -25036,12 +25833,12 @@ void X86TargetLowering::ReplaceNodeResults(SDNode *N,
Mask = DAG.getNode(ISD::SIGN_EXTEND, dl, MVT::v4i32, Mask);
}
SDValue Ops[] = { Gather->getChain(), Src0, Mask, Gather->getBasePtr(),
- Index };
+ Index, Gather->getScale() };
SDValue Res = DAG.getTargetMemSDNode<X86MaskedGatherSDNode>(
DAG.getVTList(MVT::v4i32, Mask.getValueType(), MVT::Other), Ops, dl,
Gather->getMemoryVT(), Gather->getMemOperand());
SDValue Chain = Res.getValue(2);
- if (!ExperimentalVectorWideningLegalization)
+ if (getTypeAction(*DAG.getContext(), VT) != TypeWidenVector)
Res = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, MVT::v2i32, Res,
DAG.getIntPtrConstant(0, dl));
Results.push_back(Res);
@@ -25057,12 +25854,12 @@ void X86TargetLowering::ReplaceNodeResults(SDNode *N,
Mask = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v4i1, Mask,
DAG.getConstant(0, dl, MVT::v2i1));
SDValue Ops[] = { Gather->getChain(), Src0, Mask, Gather->getBasePtr(),
- Index };
+ Index, Gather->getScale() };
SDValue Res = DAG.getMaskedGather(DAG.getVTList(MVT::v4i32, MVT::Other),
Gather->getMemoryVT(), dl, Ops,
Gather->getMemOperand());
SDValue Chain = Res.getValue(1);
- if (!ExperimentalVectorWideningLegalization)
+ if (getTypeAction(*DAG.getContext(), MVT::v2i32) != TypeWidenVector)
Res = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, MVT::v2i32, Res,
DAG.getIntPtrConstant(0, dl));
Results.push_back(Res);
@@ -25101,7 +25898,6 @@ const char *X86TargetLowering::getTargetNodeName(unsigned Opcode) const {
case X86ISD::COMI: return "X86ISD::COMI";
case X86ISD::UCOMI: return "X86ISD::UCOMI";
case X86ISD::CMPM: return "X86ISD::CMPM";
- case X86ISD::CMPMU: return "X86ISD::CMPMU";
case X86ISD::CMPM_RND: return "X86ISD::CMPM_RND";
case X86ISD::SETCC: return "X86ISD::SETCC";
case X86ISD::SETCC_CARRY: return "X86ISD::SETCC_CARRY";
@@ -25192,7 +25988,6 @@ const char *X86TargetLowering::getTargetNodeName(unsigned Opcode) const {
case X86ISD::VFPROUND: return "X86ISD::VFPROUND";
case X86ISD::VFPROUND_RND: return "X86ISD::VFPROUND_RND";
case X86ISD::VFPROUNDS_RND: return "X86ISD::VFPROUNDS_RND";
- case X86ISD::CVT2MASK: return "X86ISD::CVT2MASK";
case X86ISD::VSHLDQ: return "X86ISD::VSHLDQ";
case X86ISD::VSRLDQ: return "X86ISD::VSRLDQ";
case X86ISD::VSHL: return "X86ISD::VSHL";
@@ -25208,8 +26003,6 @@ const char *X86TargetLowering::getTargetNodeName(unsigned Opcode) const {
case X86ISD::CMPP: return "X86ISD::CMPP";
case X86ISD::PCMPEQ: return "X86ISD::PCMPEQ";
case X86ISD::PCMPGT: return "X86ISD::PCMPGT";
- case X86ISD::PCMPEQM: return "X86ISD::PCMPEQM";
- case X86ISD::PCMPGTM: return "X86ISD::PCMPGTM";
case X86ISD::PHMINPOS: return "X86ISD::PHMINPOS";
case X86ISD::ADD: return "X86ISD::ADD";
case X86ISD::SUB: return "X86ISD::SUB";
@@ -25226,14 +26019,14 @@ const char *X86TargetLowering::getTargetNodeName(unsigned Opcode) const {
case X86ISD::OR: return "X86ISD::OR";
case X86ISD::XOR: return "X86ISD::XOR";
case X86ISD::AND: return "X86ISD::AND";
+ case X86ISD::BEXTR: return "X86ISD::BEXTR";
case X86ISD::MUL_IMM: return "X86ISD::MUL_IMM";
case X86ISD::MOVMSK: return "X86ISD::MOVMSK";
case X86ISD::PTEST: return "X86ISD::PTEST";
case X86ISD::TESTP: return "X86ISD::TESTP";
- case X86ISD::TESTM: return "X86ISD::TESTM";
- case X86ISD::TESTNM: return "X86ISD::TESTNM";
case X86ISD::KORTEST: return "X86ISD::KORTEST";
case X86ISD::KTEST: return "X86ISD::KTEST";
+ case X86ISD::KADD: return "X86ISD::KADD";
case X86ISD::KSHIFTL: return "X86ISD::KSHIFTL";
case X86ISD::KSHIFTR: return "X86ISD::KSHIFTR";
case X86ISD::PACKSS: return "X86ISD::PACKSS";
@@ -25251,8 +26044,6 @@ const char *X86TargetLowering::getTargetNodeName(unsigned Opcode) const {
case X86ISD::SHUF128: return "X86ISD::SHUF128";
case X86ISD::MOVLHPS: return "X86ISD::MOVLHPS";
case X86ISD::MOVHLPS: return "X86ISD::MOVHLPS";
- case X86ISD::MOVLPS: return "X86ISD::MOVLPS";
- case X86ISD::MOVLPD: return "X86ISD::MOVLPD";
case X86ISD::MOVDDUP: return "X86ISD::MOVDDUP";
case X86ISD::MOVSHDUP: return "X86ISD::MOVSHDUP";
case X86ISD::MOVSLDUP: return "X86ISD::MOVSLDUP";
@@ -25268,7 +26059,6 @@ const char *X86TargetLowering::getTargetNodeName(unsigned Opcode) const {
case X86ISD::VPERM2X128: return "X86ISD::VPERM2X128";
case X86ISD::VPERMV: return "X86ISD::VPERMV";
case X86ISD::VPERMV3: return "X86ISD::VPERMV3";
- case X86ISD::VPERMIV3: return "X86ISD::VPERMIV3";
case X86ISD::VPERMI: return "X86ISD::VPERMI";
case X86ISD::VPTERNLOG: return "X86ISD::VPTERNLOG";
case X86ISD::VFIXUPIMM: return "X86ISD::VFIXUPIMM";
@@ -25308,26 +26098,6 @@ const char *X86TargetLowering::getTargetNodeName(unsigned Opcode) const {
case X86ISD::FNMSUB_RND: return "X86ISD::FNMSUB_RND";
case X86ISD::FMADDSUB_RND: return "X86ISD::FMADDSUB_RND";
case X86ISD::FMSUBADD_RND: return "X86ISD::FMSUBADD_RND";
- case X86ISD::FMADDS1: return "X86ISD::FMADDS1";
- case X86ISD::FNMADDS1: return "X86ISD::FNMADDS1";
- case X86ISD::FMSUBS1: return "X86ISD::FMSUBS1";
- case X86ISD::FNMSUBS1: return "X86ISD::FNMSUBS1";
- case X86ISD::FMADDS1_RND: return "X86ISD::FMADDS1_RND";
- case X86ISD::FNMADDS1_RND: return "X86ISD::FNMADDS1_RND";
- case X86ISD::FMSUBS1_RND: return "X86ISD::FMSUBS1_RND";
- case X86ISD::FNMSUBS1_RND: return "X86ISD::FNMSUBS1_RND";
- case X86ISD::FMADDS3: return "X86ISD::FMADDS3";
- case X86ISD::FNMADDS3: return "X86ISD::FNMADDS3";
- case X86ISD::FMSUBS3: return "X86ISD::FMSUBS3";
- case X86ISD::FNMSUBS3: return "X86ISD::FNMSUBS3";
- case X86ISD::FMADDS3_RND: return "X86ISD::FMADDS3_RND";
- case X86ISD::FNMADDS3_RND: return "X86ISD::FNMADDS3_RND";
- case X86ISD::FMSUBS3_RND: return "X86ISD::FMSUBS3_RND";
- case X86ISD::FNMSUBS3_RND: return "X86ISD::FNMSUBS3_RND";
- case X86ISD::FMADD4S: return "X86ISD::FMADD4S";
- case X86ISD::FNMADD4S: return "X86ISD::FNMADD4S";
- case X86ISD::FMSUB4S: return "X86ISD::FMSUB4S";
- case X86ISD::FNMSUB4S: return "X86ISD::FNMSUB4S";
case X86ISD::VPMADD52H: return "X86ISD::VPMADD52H";
case X86ISD::VPMADD52L: return "X86ISD::VPMADD52L";
case X86ISD::VRNDSCALE: return "X86ISD::VRNDSCALE";
@@ -25342,8 +26112,8 @@ const char *X86TargetLowering::getTargetNodeName(unsigned Opcode) const {
case X86ISD::VGETMANT_RND: return "X86ISD::VGETMANT_RND";
case X86ISD::VGETMANTS: return "X86ISD::VGETMANTS";
case X86ISD::VGETMANTS_RND: return "X86ISD::VGETMANTS_RND";
- case X86ISD::PCMPESTRI: return "X86ISD::PCMPESTRI";
- case X86ISD::PCMPISTRI: return "X86ISD::PCMPISTRI";
+ case X86ISD::PCMPESTR: return "X86ISD::PCMPESTR";
+ case X86ISD::PCMPISTR: return "X86ISD::PCMPISTR";
case X86ISD::XTEST: return "X86ISD::XTEST";
case X86ISD::COMPRESS: return "X86ISD::COMPRESS";
case X86ISD::EXPAND: return "X86ISD::EXPAND";
@@ -25412,6 +26182,10 @@ const char *X86TargetLowering::getTargetNodeName(unsigned Opcode) const {
case X86ISD::GF2P8MULB: return "X86ISD::GF2P8MULB";
case X86ISD::GF2P8AFFINEQB: return "X86ISD::GF2P8AFFINEQB";
case X86ISD::GF2P8AFFINEINVQB: return "X86ISD::GF2P8AFFINEINVQB";
+ case X86ISD::NT_CALL: return "X86ISD::NT_CALL";
+ case X86ISD::NT_BRIND: return "X86ISD::NT_BRIND";
+ case X86ISD::UMWAIT: return "X86ISD::UMWAIT";
+ case X86ISD::TPAUSE: return "X86ISD::TPAUSE";
}
return nullptr;
}
@@ -25478,11 +26252,20 @@ bool X86TargetLowering::isVectorShiftByScalarCheap(Type *Ty) const {
if (Bits == 8)
return false;
+ // XOP has v16i8/v8i16/v4i32/v2i64 variable vector shifts.
+ if (Subtarget.hasXOP() && Ty->getPrimitiveSizeInBits() == 128 &&
+ (Bits == 8 || Bits == 16 || Bits == 32 || Bits == 64))
+ return false;
+
// AVX2 has vpsllv[dq] instructions (and other shifts) that make variable
// shifts just as cheap as scalar ones.
if (Subtarget.hasAVX2() && (Bits == 32 || Bits == 64))
return false;
+ // AVX512BW has shifts such as vpsllvw.
+ if (Subtarget.hasBWI() && Bits == 16)
+ return false;
+
// Otherwise, it's significantly cheaper to shift by a scalar amount than by a
// fully general vector.
return true;
@@ -25561,7 +26344,15 @@ bool X86TargetLowering::isZExtFree(SDValue Val, EVT VT2) const {
return false;
}
-bool X86TargetLowering::isVectorLoadExtDesirable(SDValue) const { return true; }
+bool X86TargetLowering::isVectorLoadExtDesirable(SDValue ExtVal) const {
+ EVT SrcVT = ExtVal.getOperand(0).getValueType();
+
+ // There is no extending load for vXi1.
+ if (SrcVT.getScalarType() == MVT::i1)
+ return false;
+
+ return true;
+}
bool
X86TargetLowering::isFMAFasterThanFMulAndFAdd(EVT VT) const {
@@ -25610,13 +26401,27 @@ bool X86TargetLowering::isShuffleMaskLegal(ArrayRef<int> M, EVT VT) const {
return isTypeLegal(VT.getSimpleVT());
}
-bool
-X86TargetLowering::isVectorClearMaskLegal(const SmallVectorImpl<int> &Mask,
- EVT VT) const {
+bool X86TargetLowering::isVectorClearMaskLegal(ArrayRef<int> Mask,
+ EVT VT) const {
+ // Don't convert an 'and' into a shuffle that we don't directly support.
+ // vpblendw and vpshufb for 256-bit vectors are not available on AVX1.
+ if (!Subtarget.hasAVX2())
+ if (VT == MVT::v32i8 || VT == MVT::v16i16)
+ return false;
+
// Just delegate to the generic legality, clear masks aren't special.
return isShuffleMaskLegal(Mask, VT);
}
+bool X86TargetLowering::areJTsAllowed(const Function *Fn) const {
+ // If the subtarget is using retpolines, we must not generate jump tables.
+ if (Subtarget.useRetpoline())
+ return false;
+
+ // Otherwise, fall back on the generic logic.
+ return TargetLowering::areJTsAllowed(Fn);
+}
+
//===----------------------------------------------------------------------===//
// X86 Scheduler Hooks
//===----------------------------------------------------------------------===//
@@ -25697,79 +26502,6 @@ static MachineBasicBlock *emitXBegin(MachineInstr &MI, MachineBasicBlock *MBB,
return sinkMBB;
}
-// FIXME: When we get size specific XMM0 registers, i.e. XMM0_V16I8
-// or XMM0_V32I8 in AVX all of this code can be replaced with that
-// in the .td file.
-static MachineBasicBlock *emitPCMPSTRM(MachineInstr &MI, MachineBasicBlock *BB,
- const TargetInstrInfo *TII) {
- unsigned Opc;
- switch (MI.getOpcode()) {
- default: llvm_unreachable("illegal opcode!");
- case X86::PCMPISTRM128REG: Opc = X86::PCMPISTRM128rr; break;
- case X86::VPCMPISTRM128REG: Opc = X86::VPCMPISTRM128rr; break;
- case X86::PCMPISTRM128MEM: Opc = X86::PCMPISTRM128rm; break;
- case X86::VPCMPISTRM128MEM: Opc = X86::VPCMPISTRM128rm; break;
- case X86::PCMPESTRM128REG: Opc = X86::PCMPESTRM128rr; break;
- case X86::VPCMPESTRM128REG: Opc = X86::VPCMPESTRM128rr; break;
- case X86::PCMPESTRM128MEM: Opc = X86::PCMPESTRM128rm; break;
- case X86::VPCMPESTRM128MEM: Opc = X86::VPCMPESTRM128rm; break;
- }
-
- DebugLoc dl = MI.getDebugLoc();
- MachineInstrBuilder MIB = BuildMI(*BB, MI, dl, TII->get(Opc));
-
- unsigned NumArgs = MI.getNumOperands();
- for (unsigned i = 1; i < NumArgs; ++i) {
- MachineOperand &Op = MI.getOperand(i);
- if (!(Op.isReg() && Op.isImplicit()))
- MIB.add(Op);
- }
- if (MI.hasOneMemOperand())
- MIB->setMemRefs(MI.memoperands_begin(), MI.memoperands_end());
-
- BuildMI(*BB, MI, dl, TII->get(TargetOpcode::COPY), MI.getOperand(0).getReg())
- .addReg(X86::XMM0);
-
- MI.eraseFromParent();
- return BB;
-}
-
-// FIXME: Custom handling because TableGen doesn't support multiple implicit
-// defs in an instruction pattern
-static MachineBasicBlock *emitPCMPSTRI(MachineInstr &MI, MachineBasicBlock *BB,
- const TargetInstrInfo *TII) {
- unsigned Opc;
- switch (MI.getOpcode()) {
- default: llvm_unreachable("illegal opcode!");
- case X86::PCMPISTRIREG: Opc = X86::PCMPISTRIrr; break;
- case X86::VPCMPISTRIREG: Opc = X86::VPCMPISTRIrr; break;
- case X86::PCMPISTRIMEM: Opc = X86::PCMPISTRIrm; break;
- case X86::VPCMPISTRIMEM: Opc = X86::VPCMPISTRIrm; break;
- case X86::PCMPESTRIREG: Opc = X86::PCMPESTRIrr; break;
- case X86::VPCMPESTRIREG: Opc = X86::VPCMPESTRIrr; break;
- case X86::PCMPESTRIMEM: Opc = X86::PCMPESTRIrm; break;
- case X86::VPCMPESTRIMEM: Opc = X86::VPCMPESTRIrm; break;
- }
-
- DebugLoc dl = MI.getDebugLoc();
- MachineInstrBuilder MIB = BuildMI(*BB, MI, dl, TII->get(Opc));
-
- unsigned NumArgs = MI.getNumOperands(); // remove the results
- for (unsigned i = 1; i < NumArgs; ++i) {
- MachineOperand &Op = MI.getOperand(i);
- if (!(Op.isReg() && Op.isImplicit()))
- MIB.add(Op);
- }
- if (MI.hasOneMemOperand())
- MIB->setMemRefs(MI.memoperands_begin(), MI.memoperands_end());
-
- BuildMI(*BB, MI, dl, TII->get(TargetOpcode::COPY), MI.getOperand(0).getReg())
- .addReg(X86::ECX);
-
- MI.eraseFromParent();
- return BB;
-}
-
static MachineBasicBlock *emitWRPKRU(MachineInstr &MI, MachineBasicBlock *BB,
const X86Subtarget &Subtarget) {
DebugLoc dl = MI.getDebugLoc();
@@ -26158,7 +26890,7 @@ MachineBasicBlock *X86TargetLowering::EmitVAStartSaveXMMRegsWithCustomInserter(
!MI.getOperand(MI.getNumOperands() - 1).isReg() ||
MI.getOperand(MI.getNumOperands() - 1).getReg() == X86::EFLAGS) &&
"Expected last argument to be EFLAGS");
- unsigned MOVOpc = Subtarget.hasFp256() ? X86::VMOVAPSmr : X86::MOVAPSmr;
+ unsigned MOVOpc = Subtarget.hasAVX() ? X86::VMOVAPSmr : X86::MOVAPSmr;
// In the XMM save block, save all the XMM argument registers.
for (int i = 3, e = MI.getNumOperands() - 1; i != e; ++i) {
int64_t Offset = (i - 3) * 16 + VarArgsFPOffset;
@@ -26919,6 +27651,184 @@ X86TargetLowering::EmitLoweredTLSCall(MachineInstr &MI,
return BB;
}
+static unsigned getOpcodeForRetpoline(unsigned RPOpc) {
+ switch (RPOpc) {
+ case X86::RETPOLINE_CALL32:
+ return X86::CALLpcrel32;
+ case X86::RETPOLINE_CALL64:
+ return X86::CALL64pcrel32;
+ case X86::RETPOLINE_TCRETURN32:
+ return X86::TCRETURNdi;
+ case X86::RETPOLINE_TCRETURN64:
+ return X86::TCRETURNdi64;
+ }
+ llvm_unreachable("not retpoline opcode");
+}
+
+static const char *getRetpolineSymbol(const X86Subtarget &Subtarget,
+ unsigned Reg) {
+ if (Subtarget.useRetpolineExternalThunk()) {
+ // When using an external thunk for retpolines, we pick names that match the
+ // names GCC happens to use as well. This helps simplify the implementation
+ // of the thunks for kernels where they have no easy ability to create
+ // aliases and are doing non-trivial configuration of the thunk's body. For
+ // example, the Linux kernel will do boot-time hot patching of the thunk
+ // bodies and cannot easily export aliases of these to loaded modules.
+ //
+ // Note that at any point in the future, we may need to change the semantics
+ // of how we implement retpolines and at that time will likely change the
+ // name of the called thunk. Essentially, there is no hard guarantee that
+ // LLVM will generate calls to specific thunks, we merely make a best-effort
+ // attempt to help out kernels and other systems where duplicating the
+ // thunks is costly.
+ switch (Reg) {
+ case X86::EAX:
+ assert(!Subtarget.is64Bit() && "Should not be using a 32-bit thunk!");
+ return "__x86_indirect_thunk_eax";
+ case X86::ECX:
+ assert(!Subtarget.is64Bit() && "Should not be using a 32-bit thunk!");
+ return "__x86_indirect_thunk_ecx";
+ case X86::EDX:
+ assert(!Subtarget.is64Bit() && "Should not be using a 32-bit thunk!");
+ return "__x86_indirect_thunk_edx";
+ case X86::EDI:
+ assert(!Subtarget.is64Bit() && "Should not be using a 32-bit thunk!");
+ return "__x86_indirect_thunk_edi";
+ case X86::R11:
+ assert(Subtarget.is64Bit() && "Should not be using a 64-bit thunk!");
+ return "__x86_indirect_thunk_r11";
+ }
+ llvm_unreachable("unexpected reg for retpoline");
+ }
+
+ // When targeting an internal COMDAT thunk use an LLVM-specific name.
+ switch (Reg) {
+ case X86::EAX:
+ assert(!Subtarget.is64Bit() && "Should not be using a 32-bit thunk!");
+ return "__llvm_retpoline_eax";
+ case X86::ECX:
+ assert(!Subtarget.is64Bit() && "Should not be using a 32-bit thunk!");
+ return "__llvm_retpoline_ecx";
+ case X86::EDX:
+ assert(!Subtarget.is64Bit() && "Should not be using a 32-bit thunk!");
+ return "__llvm_retpoline_edx";
+ case X86::EDI:
+ assert(!Subtarget.is64Bit() && "Should not be using a 32-bit thunk!");
+ return "__llvm_retpoline_edi";
+ case X86::R11:
+ assert(Subtarget.is64Bit() && "Should not be using a 64-bit thunk!");
+ return "__llvm_retpoline_r11";
+ }
+ llvm_unreachable("unexpected reg for retpoline");
+}
+
+MachineBasicBlock *
+X86TargetLowering::EmitLoweredRetpoline(MachineInstr &MI,
+ MachineBasicBlock *BB) const {
+ // Copy the virtual register into an available scratch physical register
+ // (R11 on 64-bit) and call the retpoline thunk through it.
+ DebugLoc DL = MI.getDebugLoc();
+ const X86InstrInfo *TII = Subtarget.getInstrInfo();
+ unsigned CalleeVReg = MI.getOperand(0).getReg();
+ unsigned Opc = getOpcodeForRetpoline(MI.getOpcode());
+
+ // Find an available scratch register to hold the callee. On 64-bit, we can
+ // just use R11, but we scan for uses anyway to ensure we don't generate
+ // incorrect code. On 32-bit, we use one of EAX, ECX, or EDX that isn't
+ // already a register use operand to the call to hold the callee. If none
+ // are available, use EDI instead. EDI is chosen because EBX is the PIC base
+ // register and ESI is the base pointer to realigned stack frames with VLAs.
+ SmallVector<unsigned, 3> AvailableRegs;
+ if (Subtarget.is64Bit())
+ AvailableRegs.push_back(X86::R11);
+ else
+ AvailableRegs.append({X86::EAX, X86::ECX, X86::EDX, X86::EDI});
+
+ // Zero out any registers that are already used.
+ for (const auto &MO : MI.operands()) {
+ if (MO.isReg() && MO.isUse())
+ for (unsigned &Reg : AvailableRegs)
+ if (Reg == MO.getReg())
+ Reg = 0;
+ }
+
+ // Choose the first remaining non-zero available register.
+ unsigned AvailableReg = 0;
+ for (unsigned MaybeReg : AvailableRegs) {
+ if (MaybeReg) {
+ AvailableReg = MaybeReg;
+ break;
+ }
+ }
+ if (!AvailableReg)
+ report_fatal_error("calling convention incompatible with retpoline, no "
+ "available registers");
+
+ const char *Symbol = getRetpolineSymbol(Subtarget, AvailableReg);
+
+ BuildMI(*BB, MI, DL, TII->get(TargetOpcode::COPY), AvailableReg)
+ .addReg(CalleeVReg);
+ MI.getOperand(0).ChangeToES(Symbol);
+ MI.setDesc(TII->get(Opc));
+ MachineInstrBuilder(*BB->getParent(), &MI)
+ .addReg(AvailableReg, RegState::Implicit | RegState::Kill);
+ return BB;
+}
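// Minimal standalone sketch of the scratch-register selection above
// (hypothetical helper, plain C++): candidates that the call already uses are
// zeroed out and the first survivor is chosen; an all-zero result corresponds
// to the fatal "no available registers" case.
#include <vector>

static unsigned pickRetpolineScratch(std::vector<unsigned> Candidates,
                                     const std::vector<unsigned> &UsedByCall) {
  for (unsigned Used : UsedByCall)
    for (unsigned &C : Candidates)
      if (C == Used)
        C = 0;                 // knock out candidates the call already uses
  for (unsigned C : Candidates)
    if (C)
      return C;                // first remaining register wins
  return 0;                    // caller must report a fatal error
}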
+
+/// A call to SetJmp implies that control flow may later be changed by the
+/// corresponding LongJmp.
+/// Instead of using the 'return' instruction, the long jump fixes the stack and
+/// performs an indirect branch. To do so it uses the registers that were stored
+/// in the jump buffer (when calling SetJmp).
+/// If the shadow stack is enabled, we need to fix it as well, because some
+/// return addresses will be skipped.
+/// This function saves the SSP so that emitLongJmpShadowStackFix can fix the
+/// shadow stack later.
+/// \sa emitLongJmpShadowStackFix
+/// \param [in] MI The temporary Machine Instruction for the builtin.
+/// \param [in] MBB The Machine Basic Block that will be modified.
+void X86TargetLowering::emitSetJmpShadowStackFix(MachineInstr &MI,
+ MachineBasicBlock *MBB) const {
+ DebugLoc DL = MI.getDebugLoc();
+ MachineFunction *MF = MBB->getParent();
+ const TargetInstrInfo *TII = Subtarget.getInstrInfo();
+ MachineRegisterInfo &MRI = MF->getRegInfo();
+ MachineInstrBuilder MIB;
+
+ // Memory Reference.
+ MachineInstr::mmo_iterator MMOBegin = MI.memoperands_begin();
+ MachineInstr::mmo_iterator MMOEnd = MI.memoperands_end();
+
+ // Initialize a register with zero.
+ MVT PVT = getPointerTy(MF->getDataLayout());
+ const TargetRegisterClass *PtrRC = getRegClassFor(PVT);
+ unsigned ZReg = MRI.createVirtualRegister(PtrRC);
+ unsigned XorRROpc = (PVT == MVT::i64) ? X86::XOR64rr : X86::XOR32rr;
+ BuildMI(*MBB, MI, DL, TII->get(XorRROpc))
+ .addDef(ZReg)
+ .addReg(ZReg, RegState::Undef)
+ .addReg(ZReg, RegState::Undef);
+
+ // Read the current SSP Register value to the zeroed register.
+ unsigned SSPCopyReg = MRI.createVirtualRegister(PtrRC);
+ unsigned RdsspOpc = (PVT == MVT::i64) ? X86::RDSSPQ : X86::RDSSPD;
+ BuildMI(*MBB, MI, DL, TII->get(RdsspOpc), SSPCopyReg).addReg(ZReg);
+
+ // Write the SSP register value to offset 3 in input memory buffer.
+ unsigned PtrStoreOpc = (PVT == MVT::i64) ? X86::MOV64mr : X86::MOV32mr;
+ MIB = BuildMI(*MBB, MI, DL, TII->get(PtrStoreOpc));
+ const int64_t SSPOffset = 3 * PVT.getStoreSize();
+ const unsigned MemOpndSlot = 1;
+ for (unsigned i = 0; i < X86::AddrNumOperands; ++i) {
+ if (i == X86::AddrDisp)
+ MIB.addDisp(MI.getOperand(MemOpndSlot + i), SSPOffset);
+ else
+ MIB.add(MI.getOperand(MemOpndSlot + i));
+ }
+ MIB.addReg(SSPCopyReg);
+ MIB.setMemRefs(MMOBegin, MMOEnd);
+}
+
MachineBasicBlock *
X86TargetLowering::emitEHSjLjSetJmp(MachineInstr &MI,
MachineBasicBlock *MBB) const {
@@ -27028,6 +27938,11 @@ X86TargetLowering::emitEHSjLjSetJmp(MachineInstr &MI,
else
MIB.addMBB(restoreMBB);
MIB.setMemRefs(MMOBegin, MMOEnd);
+
+ if (MF->getMMI().getModule()->getModuleFlag("cf-protection-return")) {
+ emitSetJmpShadowStackFix(MI, thisMBB);
+ }
+
// Setup
MIB = BuildMI(*thisMBB, MI, DL, TII->get(X86::EH_SjLj_Setup))
.addMBB(restoreMBB);
@@ -27069,6 +27984,183 @@ X86TargetLowering::emitEHSjLjSetJmp(MachineInstr &MI,
return sinkMBB;
}
+/// Fix the shadow stack using the previously saved SSP value.
+/// \sa emitSetJmpShadowStackFix
+/// \param [in] MI The temporary Machine Instruction for the builtin.
+/// \param [in] MBB The Machine Basic Block that will be modified.
+/// \return The sink MBB that will perform the future indirect branch.
+MachineBasicBlock *
+X86TargetLowering::emitLongJmpShadowStackFix(MachineInstr &MI,
+ MachineBasicBlock *MBB) const {
+ DebugLoc DL = MI.getDebugLoc();
+ MachineFunction *MF = MBB->getParent();
+ const TargetInstrInfo *TII = Subtarget.getInstrInfo();
+ MachineRegisterInfo &MRI = MF->getRegInfo();
+
+ // Memory Reference
+ MachineInstr::mmo_iterator MMOBegin = MI.memoperands_begin();
+ MachineInstr::mmo_iterator MMOEnd = MI.memoperands_end();
+
+ MVT PVT = getPointerTy(MF->getDataLayout());
+ const TargetRegisterClass *PtrRC = getRegClassFor(PVT);
+
+ // checkSspMBB:
+ // xor vreg1, vreg1
+ // rdssp vreg1
+ // test vreg1, vreg1
+ // je sinkMBB # Jump if Shadow Stack is not supported
+ // fallMBB:
+ // mov buf+24/12(%rip), vreg2
+ // sub vreg1, vreg2
+ // jbe sinkMBB # No need to fix the Shadow Stack
+ // fixShadowMBB:
+ // shr 3/2, vreg2
+ // incssp vreg2 # fix the SSP according to the lower 8 bits
+ // shr 8, vreg2
+ // je sinkMBB
+ // fixShadowLoopPrepareMBB:
+ // shl vreg2
+ // mov 128, vreg3
+ // fixShadowLoopMBB:
+ // incssp vreg3
+ // dec vreg2
+ // jne fixShadowLoopMBB # Iterate until you finish fixing
+ // # the Shadow Stack
+ // sinkMBB:
+
+ MachineFunction::iterator I = ++MBB->getIterator();
+ const BasicBlock *BB = MBB->getBasicBlock();
+
+ MachineBasicBlock *checkSspMBB = MF->CreateMachineBasicBlock(BB);
+ MachineBasicBlock *fallMBB = MF->CreateMachineBasicBlock(BB);
+ MachineBasicBlock *fixShadowMBB = MF->CreateMachineBasicBlock(BB);
+ MachineBasicBlock *fixShadowLoopPrepareMBB = MF->CreateMachineBasicBlock(BB);
+ MachineBasicBlock *fixShadowLoopMBB = MF->CreateMachineBasicBlock(BB);
+ MachineBasicBlock *sinkMBB = MF->CreateMachineBasicBlock(BB);
+ MF->insert(I, checkSspMBB);
+ MF->insert(I, fallMBB);
+ MF->insert(I, fixShadowMBB);
+ MF->insert(I, fixShadowLoopPrepareMBB);
+ MF->insert(I, fixShadowLoopMBB);
+ MF->insert(I, sinkMBB);
+
+ // Transfer the remainder of BB and its successor edges to sinkMBB.
+ sinkMBB->splice(sinkMBB->begin(), MBB, MachineBasicBlock::iterator(MI),
+ MBB->end());
+ sinkMBB->transferSuccessorsAndUpdatePHIs(MBB);
+
+ MBB->addSuccessor(checkSspMBB);
+
+ // Initialize a register with zero.
+ unsigned ZReg = MRI.createVirtualRegister(PtrRC);
+ unsigned XorRROpc = (PVT == MVT::i64) ? X86::XOR64rr : X86::XOR32rr;
+ BuildMI(checkSspMBB, DL, TII->get(XorRROpc))
+ .addDef(ZReg)
+ .addReg(ZReg, RegState::Undef)
+ .addReg(ZReg, RegState::Undef);
+
+ // Read the current SSP Register value to the zeroed register.
+ unsigned SSPCopyReg = MRI.createVirtualRegister(PtrRC);
+ unsigned RdsspOpc = (PVT == MVT::i64) ? X86::RDSSPQ : X86::RDSSPD;
+ BuildMI(checkSspMBB, DL, TII->get(RdsspOpc), SSPCopyReg).addReg(ZReg);
+
+ // Check whether the SSP register value is zero; if it is, the shadow stack
+ // is not supported and we jump directly to the sink.
+ unsigned TestRROpc = (PVT == MVT::i64) ? X86::TEST64rr : X86::TEST32rr;
+ BuildMI(checkSspMBB, DL, TII->get(TestRROpc))
+ .addReg(SSPCopyReg)
+ .addReg(SSPCopyReg);
+ BuildMI(checkSspMBB, DL, TII->get(X86::JE_1)).addMBB(sinkMBB);
+ checkSspMBB->addSuccessor(sinkMBB);
+ checkSspMBB->addSuccessor(fallMBB);
+
+ // Reload the previously saved SSP register value.
+ unsigned PrevSSPReg = MRI.createVirtualRegister(PtrRC);
+ unsigned PtrLoadOpc = (PVT == MVT::i64) ? X86::MOV64rm : X86::MOV32rm;
+ const int64_t SSPOffset = 3 * PVT.getStoreSize();
+ MachineInstrBuilder MIB =
+ BuildMI(fallMBB, DL, TII->get(PtrLoadOpc), PrevSSPReg);
+ for (unsigned i = 0; i < X86::AddrNumOperands; ++i) {
+ if (i == X86::AddrDisp)
+ MIB.addDisp(MI.getOperand(i), SSPOffset);
+ else
+ MIB.add(MI.getOperand(i));
+ }
+ MIB.setMemRefs(MMOBegin, MMOEnd);
+
+ // Subtract the current SSP from the previous SSP.
+ unsigned SspSubReg = MRI.createVirtualRegister(PtrRC);
+ unsigned SubRROpc = (PVT == MVT::i64) ? X86::SUB64rr : X86::SUB32rr;
+ BuildMI(fallMBB, DL, TII->get(SubRROpc), SspSubReg)
+ .addReg(PrevSSPReg)
+ .addReg(SSPCopyReg);
+
+ // Jump to sink in case PrevSSPReg <= SSPCopyReg.
+ BuildMI(fallMBB, DL, TII->get(X86::JBE_1)).addMBB(sinkMBB);
+ fallMBB->addSuccessor(sinkMBB);
+ fallMBB->addSuccessor(fixShadowMBB);
+
+ // Shift right by 2/3 for 32/64 because incssp multiplies the argument by 4/8.
+ unsigned ShrRIOpc = (PVT == MVT::i64) ? X86::SHR64ri : X86::SHR32ri;
+ unsigned Offset = (PVT == MVT::i64) ? 3 : 2;
+ unsigned SspFirstShrReg = MRI.createVirtualRegister(PtrRC);
+ BuildMI(fixShadowMBB, DL, TII->get(ShrRIOpc), SspFirstShrReg)
+ .addReg(SspSubReg)
+ .addImm(Offset);
+
+ // Advance the SSP; incssp only looks at the lower 8 bits of the delta.
+ unsigned IncsspOpc = (PVT == MVT::i64) ? X86::INCSSPQ : X86::INCSSPD;
+ BuildMI(fixShadowMBB, DL, TII->get(IncsspOpc)).addReg(SspFirstShrReg);
+
+ // Reset the lower 8 bits.
+ unsigned SspSecondShrReg = MRI.createVirtualRegister(PtrRC);
+ BuildMI(fixShadowMBB, DL, TII->get(ShrRIOpc), SspSecondShrReg)
+ .addReg(SspFirstShrReg)
+ .addImm(8);
+
+ // Jump if the result of the shift is zero.
+ BuildMI(fixShadowMBB, DL, TII->get(X86::JE_1)).addMBB(sinkMBB);
+ fixShadowMBB->addSuccessor(sinkMBB);
+ fixShadowMBB->addSuccessor(fixShadowLoopPrepareMBB);
+
+ // Do a single shift left.
+ unsigned ShlR1Opc = (PVT == MVT::i64) ? X86::SHL64r1 : X86::SHL32r1;
+ unsigned SspAfterShlReg = MRI.createVirtualRegister(PtrRC);
+ BuildMI(fixShadowLoopPrepareMBB, DL, TII->get(ShlR1Opc), SspAfterShlReg)
+ .addReg(SspSecondShrReg);
+
+ // Save the value 128 to a register (will be used next with incssp).
+ unsigned Value128InReg = MRI.createVirtualRegister(PtrRC);
+ unsigned MovRIOpc = (PVT == MVT::i64) ? X86::MOV64ri32 : X86::MOV32ri;
+ BuildMI(fixShadowLoopPrepareMBB, DL, TII->get(MovRIOpc), Value128InReg)
+ .addImm(128);
+ fixShadowLoopPrepareMBB->addSuccessor(fixShadowLoopMBB);
+
+ // Since incssp only looks at the lower 8 bits, we might need to do several
+ // iterations of incssp until we finish fixing the shadow stack.
+ unsigned DecReg = MRI.createVirtualRegister(PtrRC);
+ unsigned CounterReg = MRI.createVirtualRegister(PtrRC);
+ BuildMI(fixShadowLoopMBB, DL, TII->get(X86::PHI), CounterReg)
+ .addReg(SspAfterShlReg)
+ .addMBB(fixShadowLoopPrepareMBB)
+ .addReg(DecReg)
+ .addMBB(fixShadowLoopMBB);
+
+ // Every iteration we increase the SSP by 128.
+ BuildMI(fixShadowLoopMBB, DL, TII->get(IncsspOpc)).addReg(Value128InReg);
+
+ // Every iteration we decrement the counter by 1.
+ unsigned DecROpc = (PVT == MVT::i64) ? X86::DEC64r : X86::DEC32r;
+ BuildMI(fixShadowLoopMBB, DL, TII->get(DecROpc), DecReg).addReg(CounterReg);
+
+ // Jump if the counter is not zero yet.
+ BuildMI(fixShadowLoopMBB, DL, TII->get(X86::JNE_1)).addMBB(fixShadowLoopMBB);
+ fixShadowLoopMBB->addSuccessor(sinkMBB);
+ fixShadowLoopMBB->addSuccessor(fixShadowLoopMBB);
+
+ return sinkMBB;
+}
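// Plain-C++ sanity check of the unwind arithmetic above (illustration only):
// incssp honours just the low 8 bits of its operand, so after one incssp with
// `Slots`, the remaining (Slots >> 8) * 256 entries are consumed by
// 2 * (Slots >> 8) iterations of "incssp 128".
#include <cassert>
#include <cstdint>

static uint64_t totalIncsspAdvance(uint64_t DeltaBytes, bool Is64Bit) {
  uint64_t Slots = DeltaBytes >> (Is64Bit ? 3 : 2); // entries to unwind
  uint64_t Advanced = Slots & 0xff;                 // first incssp
  uint64_t Iters = (Slots >> 8) << 1;               // fixShadowLoopPrepareMBB
  Advanced += Iters * 128;                          // the incssp-128 loop
  return Advanced;                                  // equals Slots
}

int main() {
  for (uint64_t Slots = 0; Slots < (1u << 20); Slots += 4099)
    assert(totalIncsspAdvance(Slots * 8, /*Is64Bit=*/true) == Slots);
  return 0;
}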
+
MachineBasicBlock *
X86TargetLowering::emitEHSjLjLongJmp(MachineInstr &MI,
MachineBasicBlock *MBB) const {
@@ -27101,13 +28193,21 @@ X86TargetLowering::emitEHSjLjLongJmp(MachineInstr &MI,
unsigned PtrLoadOpc = (PVT == MVT::i64) ? X86::MOV64rm : X86::MOV32rm;
unsigned IJmpOpc = (PVT == MVT::i64) ? X86::JMP64r : X86::JMP32r;
+ MachineBasicBlock *thisMBB = MBB;
+
+ // When CET and the shadow stack are enabled, we need to fix the shadow stack.
+ if (MF->getMMI().getModule()->getModuleFlag("cf-protection-return")) {
+ thisMBB = emitLongJmpShadowStackFix(MI, thisMBB);
+ }
+
// Reload FP
- MIB = BuildMI(*MBB, MI, DL, TII->get(PtrLoadOpc), FP);
+ MIB = BuildMI(*thisMBB, MI, DL, TII->get(PtrLoadOpc), FP);
for (unsigned i = 0; i < X86::AddrNumOperands; ++i)
MIB.add(MI.getOperand(i));
MIB.setMemRefs(MMOBegin, MMOEnd);
+
// Reload IP
- MIB = BuildMI(*MBB, MI, DL, TII->get(PtrLoadOpc), Tmp);
+ MIB = BuildMI(*thisMBB, MI, DL, TII->get(PtrLoadOpc), Tmp);
for (unsigned i = 0; i < X86::AddrNumOperands; ++i) {
if (i == X86::AddrDisp)
MIB.addDisp(MI.getOperand(i), LabelOffset);
@@ -27115,8 +28215,9 @@ X86TargetLowering::emitEHSjLjLongJmp(MachineInstr &MI,
MIB.add(MI.getOperand(i));
}
MIB.setMemRefs(MMOBegin, MMOEnd);
+
// Reload SP
- MIB = BuildMI(*MBB, MI, DL, TII->get(PtrLoadOpc), SP);
+ MIB = BuildMI(*thisMBB, MI, DL, TII->get(PtrLoadOpc), SP);
for (unsigned i = 0; i < X86::AddrNumOperands; ++i) {
if (i == X86::AddrDisp)
MIB.addDisp(MI.getOperand(i), SPOffset);
@@ -27124,11 +28225,12 @@ X86TargetLowering::emitEHSjLjLongJmp(MachineInstr &MI,
MIB.add(MI.getOperand(i));
}
MIB.setMemRefs(MMOBegin, MMOEnd);
+
// Jump
- BuildMI(*MBB, MI, DL, TII->get(IJmpOpc)).addReg(Tmp);
+ BuildMI(*thisMBB, MI, DL, TII->get(IJmpOpc)).addReg(Tmp);
MI.eraseFromParent();
- return MBB;
+ return thisMBB;
}
void X86TargetLowering::SetupEntryBlockForSjLj(MachineInstr &MI,
@@ -27201,7 +28303,7 @@ X86TargetLowering::EmitSjLjDispatchBlock(MachineInstr &MI,
MCSymbol *Sym = nullptr;
for (const auto &MI : MBB) {
- if (MI.isDebugValue())
+ if (MI.isDebugInstr())
continue;
assert(MI.isEHLabel() && "expected EH_LABEL");
@@ -27419,21 +28521,16 @@ X86TargetLowering::EmitInstrWithCustomInserter(MachineInstr &MI,
switch (MI.getOpcode()) {
default: llvm_unreachable("Unexpected instr type to insert");
- case X86::TAILJMPd64:
- case X86::TAILJMPr64:
- case X86::TAILJMPm64:
- case X86::TAILJMPr64_REX:
- case X86::TAILJMPm64_REX:
- llvm_unreachable("TAILJMP64 would not be touched here.");
- case X86::TCRETURNdi64:
- case X86::TCRETURNri64:
- case X86::TCRETURNmi64:
- return BB;
case X86::TLS_addr32:
case X86::TLS_addr64:
case X86::TLS_base_addr32:
case X86::TLS_base_addr64:
return EmitLoweredTLSAddr(MI, BB);
+ case X86::RETPOLINE_CALL32:
+ case X86::RETPOLINE_CALL64:
+ case X86::RETPOLINE_TCRETURN32:
+ case X86::RETPOLINE_TCRETURN64:
+ return EmitLoweredRetpoline(MI, BB);
case X86::CATCHRET:
return EmitLoweredCatchRet(MI, BB);
case X86::CATCHPAD:
@@ -27446,7 +28543,7 @@ X86TargetLowering::EmitInstrWithCustomInserter(MachineInstr &MI,
return EmitLoweredTLSCall(MI, BB);
case X86::CMOV_FR32:
case X86::CMOV_FR64:
- case X86::CMOV_FR128:
+ case X86::CMOV_F128:
case X86::CMOV_GR8:
case X86::CMOV_GR16:
case X86::CMOV_GR32:
@@ -27474,11 +28571,16 @@ X86TargetLowering::EmitInstrWithCustomInserter(MachineInstr &MI,
MI.getOpcode() == X86::RDFLAGS32 ? X86::PUSHF32 : X86::PUSHF64;
unsigned Pop = MI.getOpcode() == X86::RDFLAGS32 ? X86::POP32r : X86::POP64r;
MachineInstr *Push = BuildMI(*BB, MI, DL, TII->get(PushF));
- // Permit reads of the FLAGS register without it being defined.
+ // Permit reads of the EFLAGS and DF registers without them being defined.
// This intrinsic exists to read external processor state in flags, such as
// the trap flag, interrupt flag, and direction flag, none of which are
// modeled by the backend.
+ assert(Push->getOperand(2).getReg() == X86::EFLAGS &&
+ "Unexpected register in operand!");
Push->getOperand(2).setIsUndef();
+ assert(Push->getOperand(3).getReg() == X86::DF &&
+ "Unexpected register in operand!");
+ Push->getOperand(3).setIsUndef();
BuildMI(*BB, MI, DL, TII->get(Pop), MI.getOperand(0).getReg());
MI.eraseFromParent(); // The pseudo is gone now.
@@ -27561,32 +28663,6 @@ X86TargetLowering::EmitInstrWithCustomInserter(MachineInstr &MI,
MI.eraseFromParent(); // The pseudo instruction is gone now.
return BB;
}
- // String/text processing lowering.
- case X86::PCMPISTRM128REG:
- case X86::VPCMPISTRM128REG:
- case X86::PCMPISTRM128MEM:
- case X86::VPCMPISTRM128MEM:
- case X86::PCMPESTRM128REG:
- case X86::VPCMPESTRM128REG:
- case X86::PCMPESTRM128MEM:
- case X86::VPCMPESTRM128MEM:
- assert(Subtarget.hasSSE42() &&
- "Target must have SSE4.2 or AVX features enabled");
- return emitPCMPSTRM(MI, BB, Subtarget.getInstrInfo());
-
- // String/text processing lowering.
- case X86::PCMPISTRIREG:
- case X86::VPCMPISTRIREG:
- case X86::PCMPISTRIMEM:
- case X86::VPCMPISTRIMEM:
- case X86::PCMPESTRIREG:
- case X86::VPCMPESTRIREG:
- case X86::PCMPESTRIMEM:
- case X86::VPCMPESTRIMEM:
- assert(Subtarget.hasSSE42() &&
- "Target must have SSE4.2 or AVX features enabled");
- return emitPCMPSTRI(MI, BB, Subtarget.getInstrInfo());
-
// Thread synchronization.
case X86::MONITOR:
return emitMonitor(MI, BB, Subtarget, X86::MONITORrrr);
@@ -27633,8 +28709,10 @@ X86TargetLowering::EmitInstrWithCustomInserter(MachineInstr &MI,
return emitPatchPoint(MI, BB);
case TargetOpcode::PATCHABLE_EVENT_CALL:
- // Do nothing here, handle in xray instrumentation pass.
- return BB;
+ return emitXRayCustomEvent(MI, BB);
+
+ case TargetOpcode::PATCHABLE_TYPED_EVENT_CALL:
+ return emitXRayTypedEvent(MI, BB);
case X86::LCMPXCHG8B: {
const X86RegisterInfo *TRI = Subtarget.getRegisterInfo();
@@ -27702,6 +28780,65 @@ X86TargetLowering::EmitInstrWithCustomInserter(MachineInstr &MI,
// X86 Optimization Hooks
//===----------------------------------------------------------------------===//
+bool
+X86TargetLowering::targetShrinkDemandedConstant(SDValue Op,
+ const APInt &Demanded,
+ TargetLoweringOpt &TLO) const {
+ // Only optimize Ands to prevent shrinking a constant that could be
+ // matched by movzx.
+ if (Op.getOpcode() != ISD::AND)
+ return false;
+
+ EVT VT = Op.getValueType();
+
+ // Ignore vectors.
+ if (VT.isVector())
+ return false;
+
+ unsigned Size = VT.getSizeInBits();
+
+ // Make sure the RHS really is a constant.
+ ConstantSDNode *C = dyn_cast<ConstantSDNode>(Op.getOperand(1));
+ if (!C)
+ return false;
+
+ const APInt &Mask = C->getAPIntValue();
+
+ // Clear all non-demanded bits initially.
+ APInt ShrunkMask = Mask & Demanded;
+
+ // Find the width of the shrunk mask.
+ unsigned Width = ShrunkMask.getActiveBits();
+
+ // If the mask is all 0s there's nothing to do here.
+ if (Width == 0)
+ return false;
+
+ // Find the next power of 2 width, rounding up to a byte.
+ Width = PowerOf2Ceil(std::max(Width, 8U));
+ // Truncate the width to size to handle illegal types.
+ Width = std::min(Width, Size);
+
+ // Calculate a possible zero extend mask for this constant.
+ APInt ZeroExtendMask = APInt::getLowBitsSet(Size, Width);
+
+ // If we aren't changing the mask, just return true to keep it and prevent
+ // the caller from optimizing.
+ if (ZeroExtendMask == Mask)
+ return true;
+
+ // Make sure the new mask can be represented by a combination of mask bits
+ // and non-demanded bits.
+ if (!ZeroExtendMask.isSubsetOf(Mask | ~Demanded))
+ return false;
+
+ // Replace the constant with the zero extend mask.
+ SDLoc DL(Op);
+ SDValue NewC = TLO.DAG.getConstant(ZeroExtendMask, DL, VT);
+ SDValue NewOp = TLO.DAG.getNode(ISD::AND, DL, VT, Op.getOperand(0), NewC);
+ return TLO.CombineTo(Op, NewOp);
+}
+
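As a worked illustration of the constant-widening above (not part of this patch), here is a minimal standalone sketch of the same rounding arithmetic using plain integers instead of APInt; roundMaskWidth is a hypothetical helper, assuming the intent is to widen an AND mask so it can be matched by movzx:

#include <algorithm>
#include <cassert>

// Mirror of: Width = PowerOf2Ceil(std::max(Width, 8U)); Width = min(Width, Size);
static unsigned roundMaskWidth(unsigned ActiveBits, unsigned SizeInBits) {
  unsigned Width = std::max(ActiveBits, 8u);
  unsigned P = 1;
  while (P < Width)          // next power of two >= Width
    P <<= 1;
  return std::min(P, SizeInBits);
}

int main() {
  // i32 AND with 0x1FFF (13 active bits): rounds up to 16, so the constant
  // can become 0xFFFF and be matched by a movzwl of the low half.
  assert(roundMaskWidth(13, 32) == 16);
  // Widths that round past the type size are clamped back to it.
  assert(roundMaskWidth(20, 16) == 16);
  return 0;
}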
void X86TargetLowering::computeKnownBitsForTargetNode(const SDValue Op,
KnownBits &Known,
const APInt &DemandedElts,
@@ -27763,6 +28900,19 @@ void X86TargetLowering::computeKnownBitsForTargetNode(const SDValue Op,
}
break;
}
+ case X86ISD::PACKUS: {
+ // PACKUS is just a truncation if the upper half is zero.
+ // TODO: Add DemandedElts support.
+ KnownBits Known2;
+ DAG.computeKnownBits(Op.getOperand(0), Known, Depth + 1);
+ DAG.computeKnownBits(Op.getOperand(1), Known2, Depth + 1);
+ Known.One &= Known2.One;
+ Known.Zero &= Known2.Zero;
+ if (Known.countMinLeadingZeros() < BitWidth)
+ Known.resetAll();
+ Known = Known.trunc(BitWidth);
+ break;
+ }
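A small scalar illustration of why the PACKUS case above only keeps known bits when the upper half of each source lane is zero (standalone sketch, not part of the patch):

#include <cassert>
#include <cstdint>

int main() {
  // PACKUSWB saturates each i16 lane to [0, 255]. If the top 8 bits are known
  // zero, packing is a plain truncation and the known bits carry over:
  uint16_t a = 0x00F0;
  assert((uint8_t)a == 0xF0);
  // Otherwise saturation may fire and the packed value is unrelated to a
  // truncation of the source, which is why the code resets Known:
  uint16_t b = 0x1234;
  uint8_t packed = b > 0xFF ? 0xFF : (uint8_t)b;
  assert(packed == 0xFF);
  return 0;
}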
case X86ISD::VZEXT: {
// TODO: Add DemandedElts support.
SDValue N0 = Op.getOperand(0);
@@ -27801,6 +28951,57 @@ void X86TargetLowering::computeKnownBitsForTargetNode(const SDValue Op,
Known.Zero.setBitsFrom(8);
break;
}
+
+ // Handle target shuffles.
+ // TODO - use resolveTargetShuffleInputs once we can limit recursive depth.
+ if (isTargetShuffle(Opc)) {
+ bool IsUnary;
+ SmallVector<int, 64> Mask;
+ SmallVector<SDValue, 2> Ops;
+ if (getTargetShuffleMask(Op.getNode(), VT.getSimpleVT(), true, Ops, Mask,
+ IsUnary)) {
+ unsigned NumOps = Ops.size();
+ unsigned NumElts = VT.getVectorNumElements();
+ if (Mask.size() == NumElts) {
+ SmallVector<APInt, 2> DemandedOps(NumOps, APInt(NumElts, 0));
+ Known.Zero.setAllBits(); Known.One.setAllBits();
+ for (unsigned i = 0; i != NumElts; ++i) {
+ if (!DemandedElts[i])
+ continue;
+ int M = Mask[i];
+ if (M == SM_SentinelUndef) {
+ // For UNDEF elements, we don't know anything about the common state
+ // of the shuffle result.
+ Known.resetAll();
+ break;
+ } else if (M == SM_SentinelZero) {
+ Known.One.clearAllBits();
+ continue;
+ }
+ assert(0 <= M && (unsigned)M < (NumOps * NumElts) &&
+ "Shuffle index out of range");
+
+ unsigned OpIdx = (unsigned)M / NumElts;
+ unsigned EltIdx = (unsigned)M % NumElts;
+ if (Ops[OpIdx].getValueType() != VT) {
+ // TODO - handle target shuffle ops with different value types.
+ Known.resetAll();
+ break;
+ }
+ DemandedOps[OpIdx].setBit(EltIdx);
+ }
+ // Known bits are the values that are shared by every demanded element.
+ for (unsigned i = 0; i != NumOps && !Known.isUnknown(); ++i) {
+ if (!DemandedOps[i])
+ continue;
+ KnownBits Known2;
+ DAG.computeKnownBits(Ops[i], Known2, DemandedOps[i], Depth + 1);
+ Known.One &= Known2.One;
+ Known.Zero &= Known2.Zero;
+ }
+ }
+ }
+ }
}
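The target-shuffle handling above maps each demanded result lane back to a lane of one source operand and then intersects the known bits of those sources. A rough standalone model of that bookkeeping (demandedPerOperand is a hypothetical helper; the real code additionally bails out on undef lanes and mismatched operand types):

#include <cassert>
#include <cstdint>
#include <vector>

// Negative mask entries model SM_SentinelUndef/SM_SentinelZero lanes,
// which demand nothing from either source here.
static std::vector<uint64_t> demandedPerOperand(const std::vector<int> &Mask,
                                                uint64_t DemandedElts,
                                                unsigned NumOps,
                                                unsigned NumElts) {
  std::vector<uint64_t> DemandedOps(NumOps, 0);
  for (unsigned i = 0; i != NumElts; ++i) {
    if (!(DemandedElts & (1ull << i)))
      continue;
    int M = Mask[i];
    if (M < 0)
      continue;
    DemandedOps[(unsigned)M / NumElts] |= 1ull << ((unsigned)M % NumElts);
  }
  return DemandedOps;
}

int main() {
  // UNPCKL-style v4i32 mask <0,4,1,5> with only lanes 0 and 1 demanded:
  // lane 0 reads Op0[0] and lane 1 reads Op1[0].
  auto D = demandedPerOperand({0, 4, 1, 5}, 0b0011, 2, 4);
  assert(D[0] == 0b0001 && D[1] == 0b0001);
  return 0;
}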
unsigned X86TargetLowering::ComputeNumSignBitsForTargetNode(
@@ -27917,12 +29118,21 @@ bool X86TargetLowering::isGAPlusOffset(SDNode *N,
// TODO: Investigate sharing more of this with shuffle lowering.
static bool matchUnaryVectorShuffle(MVT MaskVT, ArrayRef<int> Mask,
bool AllowFloatDomain, bool AllowIntDomain,
- SDValue &V1, SDLoc &DL, SelectionDAG &DAG,
+ SDValue &V1, const SDLoc &DL,
+ SelectionDAG &DAG,
const X86Subtarget &Subtarget,
unsigned &Shuffle, MVT &SrcVT, MVT &DstVT) {
unsigned NumMaskElts = Mask.size();
unsigned MaskEltSize = MaskVT.getScalarSizeInBits();
+ // Match against a VZEXT_MOVL vXi32 zero-extending instruction.
+ if (MaskEltSize == 32 && isUndefOrEqual(Mask[0], 0) &&
+ isUndefOrZero(Mask[1]) && isUndefInRange(Mask, 2, NumMaskElts - 2)) {
+ Shuffle = X86ISD::VZEXT_MOVL;
+ SrcVT = DstVT = !Subtarget.hasSSE2() ? MVT::v4f32 : MaskVT;
+ return true;
+ }
+
// Match against a ZERO_EXTEND_VECTOR_INREG/VZEXT instruction.
// TODO: Add 512-bit vector support (split AVX512F and AVX512BW).
if (AllowIntDomain && ((MaskVT.is128BitVector() && Subtarget.hasSSE41()) ||
@@ -28165,7 +29375,7 @@ static bool matchUnaryPermuteVectorShuffle(MVT MaskVT, ArrayRef<int> Mask,
// TODO: Investigate sharing more of this with shuffle lowering.
static bool matchBinaryVectorShuffle(MVT MaskVT, ArrayRef<int> Mask,
bool AllowFloatDomain, bool AllowIntDomain,
- SDValue &V1, SDValue &V2, SDLoc &DL,
+ SDValue &V1, SDValue &V2, const SDLoc &DL,
SelectionDAG &DAG,
const X86Subtarget &Subtarget,
unsigned &Shuffle, MVT &SrcVT, MVT &DstVT,
@@ -28175,27 +29385,28 @@ static bool matchBinaryVectorShuffle(MVT MaskVT, ArrayRef<int> Mask,
if (MaskVT.is128BitVector()) {
if (isTargetShuffleEquivalent(Mask, {0, 0}) && AllowFloatDomain) {
V2 = V1;
- Shuffle = X86ISD::MOVLHPS;
- SrcVT = DstVT = MVT::v4f32;
+ V1 = (SM_SentinelUndef == Mask[0] ? DAG.getUNDEF(MVT::v4f32) : V1);
+ Shuffle = Subtarget.hasSSE2() ? X86ISD::UNPCKL : X86ISD::MOVLHPS;
+ SrcVT = DstVT = Subtarget.hasSSE2() ? MVT::v2f64 : MVT::v4f32;
return true;
}
if (isTargetShuffleEquivalent(Mask, {1, 1}) && AllowFloatDomain) {
V2 = V1;
- Shuffle = X86ISD::MOVHLPS;
- SrcVT = DstVT = MVT::v4f32;
+ Shuffle = Subtarget.hasSSE2() ? X86ISD::UNPCKH : X86ISD::MOVHLPS;
+ SrcVT = DstVT = Subtarget.hasSSE2() ? MVT::v2f64 : MVT::v4f32;
return true;
}
if (isTargetShuffleEquivalent(Mask, {0, 3}) && Subtarget.hasSSE2() &&
(AllowFloatDomain || !Subtarget.hasSSE41())) {
std::swap(V1, V2);
Shuffle = X86ISD::MOVSD;
- SrcVT = DstVT = MaskVT;
+ SrcVT = DstVT = MVT::v2f64;
return true;
}
if (isTargetShuffleEquivalent(Mask, {4, 1, 2, 3}) &&
(AllowFloatDomain || !Subtarget.hasSSE41())) {
Shuffle = X86ISD::MOVSS;
- SrcVT = DstVT = MaskVT;
+ SrcVT = DstVT = MVT::v4f32;
return true;
}
}
@@ -28228,15 +29439,11 @@ static bool matchBinaryVectorShuffle(MVT MaskVT, ArrayRef<int> Mask,
return false;
}
-static bool matchBinaryPermuteVectorShuffle(MVT MaskVT, ArrayRef<int> Mask,
- const APInt &Zeroable,
- bool AllowFloatDomain,
- bool AllowIntDomain,
- SDValue &V1, SDValue &V2, SDLoc &DL,
- SelectionDAG &DAG,
- const X86Subtarget &Subtarget,
- unsigned &Shuffle, MVT &ShuffleVT,
- unsigned &PermuteImm) {
+static bool matchBinaryPermuteVectorShuffle(
+ MVT MaskVT, ArrayRef<int> Mask, const APInt &Zeroable,
+ bool AllowFloatDomain, bool AllowIntDomain, SDValue &V1, SDValue &V2,
+ const SDLoc &DL, SelectionDAG &DAG, const X86Subtarget &Subtarget,
+ unsigned &Shuffle, MVT &ShuffleVT, unsigned &PermuteImm) {
unsigned NumMaskElts = Mask.size();
unsigned EltSizeInBits = MaskVT.getScalarSizeInBits();
@@ -28385,7 +29592,7 @@ static bool matchBinaryPermuteVectorShuffle(MVT MaskVT, ArrayRef<int> Mask,
return false;
}
-/// \brief Combine an arbitrary chain of shuffles into a single instruction if
+/// Combine an arbitrary chain of shuffles into a single instruction if
/// possible.
///
/// This is the leaf of the recursive combine below. When we have found some
@@ -28397,7 +29604,6 @@ static bool matchBinaryPermuteVectorShuffle(MVT MaskVT, ArrayRef<int> Mask,
static SDValue combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root,
ArrayRef<int> BaseMask, int Depth,
bool HasVariableMask, SelectionDAG &DAG,
- TargetLowering::DAGCombinerInfo &DCI,
const X86Subtarget &Subtarget) {
assert(!BaseMask.empty() && "Cannot combine an empty shuffle mask!");
assert((Inputs.size() == 1 || Inputs.size() == 2) &&
@@ -28430,6 +29636,7 @@ static SDValue combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root,
unsigned NumRootElts = RootVT.getVectorNumElements();
unsigned BaseMaskEltSizeInBits = RootSizeInBits / NumBaseMaskElts;
bool FloatDomain = VT1.isFloatingPoint() || VT2.isFloatingPoint() ||
+ (RootVT.isFloatingPoint() && Depth >= 2) ||
(RootVT.is256BitVector() && !Subtarget.hasAVX2());
// Don't combine if we are a AVX512/EVEX target and the mask element size
@@ -28458,11 +29665,9 @@ static SDValue combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root,
PermMask |= ((BaseMask[1] < 0 ? 0x8 : (BaseMask[1] & 1)) << 4);
Res = DAG.getBitcast(ShuffleVT, V1);
- DCI.AddToWorklist(Res.getNode());
Res = DAG.getNode(X86ISD::VPERM2X128, DL, ShuffleVT, Res,
DAG.getUNDEF(ShuffleVT),
DAG.getConstant(PermMask, DL, MVT::i8));
- DCI.AddToWorklist(Res.getNode());
return DAG.getBitcast(RootVT, Res);
}
@@ -28520,16 +29725,15 @@ static SDValue combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root,
}
}
+ SDValue NewV1 = V1; // Save operand in case early exit happens.
if (matchUnaryVectorShuffle(MaskVT, Mask, AllowFloatDomain, AllowIntDomain,
- V1, DL, DAG, Subtarget, Shuffle, ShuffleSrcVT,
- ShuffleVT) &&
+ NewV1, DL, DAG, Subtarget, Shuffle,
+ ShuffleSrcVT, ShuffleVT) &&
(!IsEVEXShuffle || (NumRootElts == ShuffleVT.getVectorNumElements()))) {
if (Depth == 1 && Root.getOpcode() == Shuffle)
return SDValue(); // Nothing to do!
- Res = DAG.getBitcast(ShuffleSrcVT, V1);
- DCI.AddToWorklist(Res.getNode());
+ Res = DAG.getBitcast(ShuffleSrcVT, NewV1);
Res = DAG.getNode(Shuffle, DL, ShuffleVT, Res);
- DCI.AddToWorklist(Res.getNode());
return DAG.getBitcast(RootVT, Res);
}
@@ -28540,43 +29744,38 @@ static SDValue combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root,
if (Depth == 1 && Root.getOpcode() == Shuffle)
return SDValue(); // Nothing to do!
Res = DAG.getBitcast(ShuffleVT, V1);
- DCI.AddToWorklist(Res.getNode());
Res = DAG.getNode(Shuffle, DL, ShuffleVT, Res,
DAG.getConstant(PermuteImm, DL, MVT::i8));
- DCI.AddToWorklist(Res.getNode());
return DAG.getBitcast(RootVT, Res);
}
}
+ SDValue NewV1 = V1; // Save operands in case early exit happens.
+ SDValue NewV2 = V2;
if (matchBinaryVectorShuffle(MaskVT, Mask, AllowFloatDomain, AllowIntDomain,
- V1, V2, DL, DAG, Subtarget, Shuffle,
+ NewV1, NewV2, DL, DAG, Subtarget, Shuffle,
ShuffleSrcVT, ShuffleVT, UnaryShuffle) &&
(!IsEVEXShuffle || (NumRootElts == ShuffleVT.getVectorNumElements()))) {
if (Depth == 1 && Root.getOpcode() == Shuffle)
return SDValue(); // Nothing to do!
- V1 = DAG.getBitcast(ShuffleSrcVT, V1);
- DCI.AddToWorklist(V1.getNode());
- V2 = DAG.getBitcast(ShuffleSrcVT, V2);
- DCI.AddToWorklist(V2.getNode());
- Res = DAG.getNode(Shuffle, DL, ShuffleVT, V1, V2);
- DCI.AddToWorklist(Res.getNode());
+ NewV1 = DAG.getBitcast(ShuffleSrcVT, NewV1);
+ NewV2 = DAG.getBitcast(ShuffleSrcVT, NewV2);
+ Res = DAG.getNode(Shuffle, DL, ShuffleVT, NewV1, NewV2);
return DAG.getBitcast(RootVT, Res);
}
- if (matchBinaryPermuteVectorShuffle(MaskVT, Mask, Zeroable, AllowFloatDomain,
- AllowIntDomain, V1, V2, DL, DAG,
- Subtarget, Shuffle, ShuffleVT,
- PermuteImm) &&
+ NewV1 = V1; // Save operands in case early exit happens.
+ NewV2 = V2;
+ if (matchBinaryPermuteVectorShuffle(
+ MaskVT, Mask, Zeroable, AllowFloatDomain, AllowIntDomain, NewV1,
+ NewV2, DL, DAG, Subtarget, Shuffle, ShuffleVT, PermuteImm) &&
(!IsEVEXShuffle || (NumRootElts == ShuffleVT.getVectorNumElements()))) {
if (Depth == 1 && Root.getOpcode() == Shuffle)
return SDValue(); // Nothing to do!
- V1 = DAG.getBitcast(ShuffleVT, V1);
- DCI.AddToWorklist(V1.getNode());
- V2 = DAG.getBitcast(ShuffleVT, V2);
- DCI.AddToWorklist(V2.getNode());
- Res = DAG.getNode(Shuffle, DL, ShuffleVT, V1, V2,
+ NewV1 = DAG.getBitcast(ShuffleVT, NewV1);
+ NewV2 = DAG.getBitcast(ShuffleVT, NewV2);
+ Res = DAG.getNode(Shuffle, DL, ShuffleVT, NewV1, NewV2,
DAG.getConstant(PermuteImm, DL, MVT::i8));
- DCI.AddToWorklist(Res.getNode());
return DAG.getBitcast(RootVT, Res);
}
@@ -28592,11 +29791,9 @@ static SDValue combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root,
if (Depth == 1 && Root.getOpcode() == X86ISD::EXTRQI)
return SDValue(); // Nothing to do!
V1 = DAG.getBitcast(IntMaskVT, V1);
- DCI.AddToWorklist(V1.getNode());
Res = DAG.getNode(X86ISD::EXTRQI, DL, IntMaskVT, V1,
DAG.getConstant(BitLen, DL, MVT::i8),
DAG.getConstant(BitIdx, DL, MVT::i8));
- DCI.AddToWorklist(Res.getNode());
return DAG.getBitcast(RootVT, Res);
}
@@ -28604,13 +29801,10 @@ static SDValue combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root,
if (Depth == 1 && Root.getOpcode() == X86ISD::INSERTQI)
return SDValue(); // Nothing to do!
V1 = DAG.getBitcast(IntMaskVT, V1);
- DCI.AddToWorklist(V1.getNode());
V2 = DAG.getBitcast(IntMaskVT, V2);
- DCI.AddToWorklist(V2.getNode());
Res = DAG.getNode(X86ISD::INSERTQI, DL, IntMaskVT, V1, V2,
DAG.getConstant(BitLen, DL, MVT::i8),
DAG.getConstant(BitIdx, DL, MVT::i8));
- DCI.AddToWorklist(Res.getNode());
return DAG.getBitcast(RootVT, Res);
}
}
@@ -28640,11 +29834,8 @@ static SDValue combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root,
(Subtarget.hasVBMI() && MaskVT == MVT::v64i8) ||
(Subtarget.hasVBMI() && Subtarget.hasVLX() && MaskVT == MVT::v32i8))) {
SDValue VPermMask = getConstVector(Mask, IntMaskVT, DAG, DL, true);
- DCI.AddToWorklist(VPermMask.getNode());
Res = DAG.getBitcast(MaskVT, V1);
- DCI.AddToWorklist(Res.getNode());
Res = DAG.getNode(X86ISD::VPERMV, DL, MaskVT, VPermMask, Res);
- DCI.AddToWorklist(Res.getNode());
return DAG.getBitcast(RootVT, Res);
}
@@ -28667,13 +29858,9 @@ static SDValue combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root,
Mask[i] = NumMaskElts + i;
SDValue VPermMask = getConstVector(Mask, IntMaskVT, DAG, DL, true);
- DCI.AddToWorklist(VPermMask.getNode());
Res = DAG.getBitcast(MaskVT, V1);
- DCI.AddToWorklist(Res.getNode());
SDValue Zero = getZeroVector(MaskVT, Subtarget, DAG, DL);
- DCI.AddToWorklist(Zero.getNode());
Res = DAG.getNode(X86ISD::VPERMV3, DL, MaskVT, Res, VPermMask, Zero);
- DCI.AddToWorklist(Res.getNode());
return DAG.getBitcast(RootVT, Res);
}
@@ -28690,13 +29877,9 @@ static SDValue combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root,
(Subtarget.hasVBMI() && MaskVT == MVT::v64i8) ||
(Subtarget.hasVBMI() && Subtarget.hasVLX() && MaskVT == MVT::v32i8))) {
SDValue VPermMask = getConstVector(Mask, IntMaskVT, DAG, DL, true);
- DCI.AddToWorklist(VPermMask.getNode());
V1 = DAG.getBitcast(MaskVT, V1);
- DCI.AddToWorklist(V1.getNode());
V2 = DAG.getBitcast(MaskVT, V2);
- DCI.AddToWorklist(V2.getNode());
Res = DAG.getNode(X86ISD::VPERMV3, DL, MaskVT, V1, VPermMask, V2);
- DCI.AddToWorklist(Res.getNode());
return DAG.getBitcast(RootVT, Res);
}
return SDValue();
@@ -28722,13 +29905,10 @@ static SDValue combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root,
EltBits[i] = AllOnes;
}
SDValue BitMask = getConstVector(EltBits, UndefElts, MaskVT, DAG, DL);
- DCI.AddToWorklist(BitMask.getNode());
Res = DAG.getBitcast(MaskVT, V1);
- DCI.AddToWorklist(Res.getNode());
unsigned AndOpcode =
FloatDomain ? unsigned(X86ISD::FAND) : unsigned(ISD::AND);
Res = DAG.getNode(AndOpcode, DL, MaskVT, Res, BitMask);
- DCI.AddToWorklist(Res.getNode());
return DAG.getBitcast(RootVT, Res);
}
@@ -28745,11 +29925,8 @@ static SDValue combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root,
VPermIdx.push_back(Idx);
}
SDValue VPermMask = DAG.getBuildVector(IntMaskVT, DL, VPermIdx);
- DCI.AddToWorklist(VPermMask.getNode());
Res = DAG.getBitcast(MaskVT, V1);
- DCI.AddToWorklist(Res.getNode());
Res = DAG.getNode(X86ISD::VPERMILPV, DL, MaskVT, Res, VPermMask);
- DCI.AddToWorklist(Res.getNode());
return DAG.getBitcast(RootVT, Res);
}
@@ -28781,14 +29958,10 @@ static SDValue combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root,
VPerm2Idx.push_back(Index);
}
V1 = DAG.getBitcast(MaskVT, V1);
- DCI.AddToWorklist(V1.getNode());
V2 = DAG.getBitcast(MaskVT, V2);
- DCI.AddToWorklist(V2.getNode());
SDValue VPerm2MaskOp = getConstVector(VPerm2Idx, IntMaskVT, DAG, DL, true);
- DCI.AddToWorklist(VPerm2MaskOp.getNode());
Res = DAG.getNode(X86ISD::VPERMIL2, DL, MaskVT, V1, V2, VPerm2MaskOp,
DAG.getConstant(M2ZImm, DL, MVT::i8));
- DCI.AddToWorklist(Res.getNode());
return DAG.getBitcast(RootVT, Res);
}
@@ -28820,11 +29993,8 @@ static SDValue combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root,
}
MVT ByteVT = MVT::getVectorVT(MVT::i8, NumBytes);
Res = DAG.getBitcast(ByteVT, V1);
- DCI.AddToWorklist(Res.getNode());
SDValue PSHUFBMaskOp = DAG.getBuildVector(ByteVT, DL, PSHUFBMask);
- DCI.AddToWorklist(PSHUFBMaskOp.getNode());
Res = DAG.getNode(X86ISD::PSHUFB, DL, ByteVT, Res, PSHUFBMaskOp);
- DCI.AddToWorklist(Res.getNode());
return DAG.getBitcast(RootVT, Res);
}
@@ -28853,13 +30023,9 @@ static SDValue combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root,
}
MVT ByteVT = MVT::v16i8;
V1 = DAG.getBitcast(ByteVT, V1);
- DCI.AddToWorklist(V1.getNode());
V2 = DAG.getBitcast(ByteVT, V2);
- DCI.AddToWorklist(V2.getNode());
SDValue VPPERMMaskOp = DAG.getBuildVector(ByteVT, DL, VPPERMMask);
- DCI.AddToWorklist(VPPERMMaskOp.getNode());
Res = DAG.getNode(X86ISD::VPPERM, DL, ByteVT, V1, V2, VPPERMMaskOp);
- DCI.AddToWorklist(Res.getNode());
return DAG.getBitcast(RootVT, Res);
}
@@ -28870,11 +30036,10 @@ static SDValue combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root,
// Attempt to constant fold all of the constant source ops.
// Returns true if the entire shuffle is folded to a constant.
// TODO: Extend this to merge multiple constant Ops and update the mask.
-static SDValue combineX86ShufflesConstants(const SmallVectorImpl<SDValue> &Ops,
+static SDValue combineX86ShufflesConstants(ArrayRef<SDValue> Ops,
ArrayRef<int> Mask, SDValue Root,
bool HasVariableMask,
SelectionDAG &DAG,
- TargetLowering::DAGCombinerInfo &DCI,
const X86Subtarget &Subtarget) {
MVT VT = Root.getSimpleValueType();
@@ -28950,11 +30115,10 @@ static SDValue combineX86ShufflesConstants(const SmallVectorImpl<SDValue> &Ops,
SDLoc DL(Root);
SDValue CstOp = getConstVector(ConstantBitData, UndefElts, MaskVT, DAG, DL);
- DCI.AddToWorklist(CstOp.getNode());
return DAG.getBitcast(VT, CstOp);
}
-/// \brief Fully generic combining of x86 shuffle instructions.
+/// Fully generic combining of x86 shuffle instructions.
///
/// This should be the last combine run over the x86 shuffle instructions. Once
/// they have been fully optimized, this will recursively consider all chains
@@ -28985,12 +30149,12 @@ static SDValue combineX86ShufflesConstants(const SmallVectorImpl<SDValue> &Ops,
/// combining in this recursive walk.
static SDValue combineX86ShufflesRecursively(
ArrayRef<SDValue> SrcOps, int SrcOpIndex, SDValue Root,
- ArrayRef<int> RootMask, ArrayRef<const SDNode *> SrcNodes, int Depth,
- bool HasVariableMask, SelectionDAG &DAG,
- TargetLowering::DAGCombinerInfo &DCI, const X86Subtarget &Subtarget) {
+ ArrayRef<int> RootMask, ArrayRef<const SDNode *> SrcNodes, unsigned Depth,
+ bool HasVariableMask, SelectionDAG &DAG, const X86Subtarget &Subtarget) {
// Bound the depth of our recursive combine because this is ultimately
// quadratic in nature.
- if (Depth > 8)
+ const unsigned MaxRecursionDepth = 8;
+ if (Depth > MaxRecursionDepth)
return SDValue();
// Directly rip through bitcasts to find the underlying operand.
@@ -29143,17 +30307,21 @@ static SDValue combineX86ShufflesRecursively(
// See if we can recurse into each shuffle source op (if it's a target
// shuffle). The source op should only be combined if it either has a
// single use (i.e. current Op) or all its users have already been combined.
- for (int i = 0, e = Ops.size(); i < e; ++i)
- if (Ops[i].getNode()->hasOneUse() ||
- SDNode::areOnlyUsersOf(CombinedNodes, Ops[i].getNode()))
- if (SDValue Res = combineX86ShufflesRecursively(
- Ops, i, Root, Mask, CombinedNodes, Depth + 1, HasVariableMask,
- DAG, DCI, Subtarget))
- return Res;
+ // Don't recurse if we already have more source ops than we can combine in
+ // the remaining recursion depth.
+ if (Ops.size() < (MaxRecursionDepth - Depth)) {
+ for (int i = 0, e = Ops.size(); i < e; ++i)
+ if (Ops[i].getNode()->hasOneUse() ||
+ SDNode::areOnlyUsersOf(CombinedNodes, Ops[i].getNode()))
+ if (SDValue Res = combineX86ShufflesRecursively(
+ Ops, i, Root, Mask, CombinedNodes, Depth + 1, HasVariableMask,
+ DAG, Subtarget))
+ return Res;
+ }
// Attempt to constant fold all of the constant source ops.
if (SDValue Cst = combineX86ShufflesConstants(
- Ops, Mask, Root, HasVariableMask, DAG, DCI, Subtarget))
+ Ops, Mask, Root, HasVariableMask, DAG, Subtarget))
return Cst;
// We can only combine unary and binary shuffle mask cases.
@@ -29179,10 +30347,10 @@ static SDValue combineX86ShufflesRecursively(
// Finally, try to combine into a single shuffle instruction.
return combineX86ShuffleChain(Ops, Root, Mask, Depth, HasVariableMask, DAG,
- DCI, Subtarget);
+ Subtarget);
}
-/// \brief Get the PSHUF-style mask from PSHUF node.
+/// Get the PSHUF-style mask from PSHUF node.
///
/// This is a very minor wrapper around getTargetShuffleMask to ease forming v4
/// PSHUF-style masks that can be reused with such instructions.
@@ -29225,7 +30393,7 @@ static SmallVector<int, 4> getPSHUFShuffleMask(SDValue N) {
}
}
-/// \brief Search for a combinable shuffle across a chain ending in pshufd.
+/// Search for a combinable shuffle across a chain ending in pshufd.
///
/// We walk up the chain and look for a combinable shuffle, skipping over
/// shuffles that we could hoist this shuffle's transformation past without
@@ -29358,7 +30526,7 @@ combineRedundantDWordShuffle(SDValue N, MutableArrayRef<int> Mask,
return V;
}
-/// \brief Search for a combinable shuffle across a chain ending in pshuflw or
+/// Search for a combinable shuffle across a chain ending in pshuflw or
/// pshufhw.
///
/// We walk up the chain, skipping shuffles of the other half and looking
@@ -29426,7 +30594,7 @@ static bool combineRedundantHalfShuffle(SDValue N, MutableArrayRef<int> Mask,
return true;
}
-/// \brief Try to combine x86 target specific shuffles.
+/// Try to combine x86 target specific shuffles.
static SDValue combineTargetShuffle(SDValue N, SelectionDAG &DAG,
TargetLowering::DAGCombinerInfo &DCI,
const X86Subtarget &Subtarget) {
@@ -29459,12 +30627,33 @@ static SDValue combineTargetShuffle(SDValue N, SelectionDAG &DAG,
Hi = BC1.getOperand(Opcode == X86ISD::UNPCKH ? 1 : 0);
}
SDValue Horiz = DAG.getNode(Opcode0, DL, VT0, Lo, Hi);
- DCI.AddToWorklist(Horiz.getNode());
return DAG.getBitcast(VT, Horiz);
}
}
switch (Opcode) {
+ case X86ISD::VBROADCAST: {
+ // If broadcasting from another shuffle, attempt to simplify it.
+ // TODO - we really need a general SimplifyDemandedVectorElts mechanism.
+ SDValue Src = N.getOperand(0);
+ SDValue BC = peekThroughBitcasts(Src);
+ EVT SrcVT = Src.getValueType();
+ EVT BCVT = BC.getValueType();
+ if (isTargetShuffle(BC.getOpcode()) &&
+ VT.getScalarSizeInBits() % BCVT.getScalarSizeInBits() == 0) {
+ unsigned Scale = VT.getScalarSizeInBits() / BCVT.getScalarSizeInBits();
+ SmallVector<int, 16> DemandedMask(BCVT.getVectorNumElements(),
+ SM_SentinelUndef);
+ for (unsigned i = 0; i != Scale; ++i)
+ DemandedMask[i] = i;
+ if (SDValue Res = combineX86ShufflesRecursively(
+ {BC}, 0, BC, DemandedMask, {}, /*Depth*/ 1,
+ /*HasVarMask*/ false, DAG, Subtarget))
+ return DAG.getNode(X86ISD::VBROADCAST, DL, VT,
+ DAG.getBitcast(SrcVT, Res));
+ }
+ return SDValue();
+ }
case X86ISD::PSHUFD:
case X86ISD::PSHUFLW:
case X86ISD::PSHUFHW:
@@ -29505,53 +30694,31 @@ static SDValue combineTargetShuffle(SDValue N, SelectionDAG &DAG,
}
return SDValue();
}
- case X86ISD::BLENDI: {
- SDValue V0 = N->getOperand(0);
- SDValue V1 = N->getOperand(1);
- assert(VT == V0.getSimpleValueType() && VT == V1.getSimpleValueType() &&
- "Unexpected input vector types");
-
- // Canonicalize a v2f64 blend with a mask of 2 by swapping the vector
- // operands and changing the mask to 1. This saves us a bunch of
- // pattern-matching possibilities related to scalar math ops in SSE/AVX.
- // x86InstrInfo knows how to commute this back after instruction selection
- // if it would help register allocation.
-
- // TODO: If optimizing for size or a processor that doesn't suffer from
- // partial register update stalls, this should be transformed into a MOVSD
- // instruction because a MOVSD is 1-2 bytes smaller than a BLENDPD.
-
- if (VT == MVT::v2f64)
- if (auto *Mask = dyn_cast<ConstantSDNode>(N->getOperand(2)))
- if (Mask->getZExtValue() == 2 && !isShuffleFoldableLoad(V0)) {
- SDValue NewMask = DAG.getConstant(1, DL, MVT::i8);
- return DAG.getNode(X86ISD::BLENDI, DL, VT, V1, V0, NewMask);
- }
-
- return SDValue();
- }
case X86ISD::MOVSD:
case X86ISD::MOVSS: {
- SDValue V0 = peekThroughBitcasts(N->getOperand(0));
- SDValue V1 = peekThroughBitcasts(N->getOperand(1));
- bool isZero0 = ISD::isBuildVectorAllZeros(V0.getNode());
- bool isZero1 = ISD::isBuildVectorAllZeros(V1.getNode());
- if (isZero0 && isZero1)
- return SDValue();
+ SDValue N0 = N.getOperand(0);
+ SDValue N1 = N.getOperand(1);
- // We often lower to MOVSD/MOVSS from integer as well as native float
- // types; remove unnecessary domain-crossing bitcasts if we can to make it
- // easier to combine shuffles later on. We've already accounted for the
- // domain switching cost when we decided to lower with it.
- bool isFloat = VT.isFloatingPoint();
- bool isFloat0 = V0.getSimpleValueType().isFloatingPoint();
- bool isFloat1 = V1.getSimpleValueType().isFloatingPoint();
- if ((isFloat != isFloat0 || isZero0) && (isFloat != isFloat1 || isZero1)) {
- MVT NewVT = isFloat ? (X86ISD::MOVSD == Opcode ? MVT::v2i64 : MVT::v4i32)
- : (X86ISD::MOVSD == Opcode ? MVT::v2f64 : MVT::v4f32);
- V0 = DAG.getBitcast(NewVT, V0);
- V1 = DAG.getBitcast(NewVT, V1);
- return DAG.getBitcast(VT, DAG.getNode(Opcode, DL, NewVT, V0, V1));
+ // Canonicalize scalar FPOps:
+ // MOVS*(N0, OP(N0, N1)) --> MOVS*(N0, SCALAR_TO_VECTOR(OP(N0[0], N1[0])))
+ // If commutable, allow OP(N1[0], N0[0]).
+ unsigned Opcode1 = N1.getOpcode();
+ if (Opcode1 == ISD::FADD || Opcode1 == ISD::FMUL || Opcode1 == ISD::FSUB ||
+ Opcode1 == ISD::FDIV) {
+ SDValue N10 = N1.getOperand(0);
+ SDValue N11 = N1.getOperand(1);
+ if (N10 == N0 ||
+ (N11 == N0 && (Opcode1 == ISD::FADD || Opcode1 == ISD::FMUL))) {
+ if (N10 != N0)
+ std::swap(N10, N11);
+ MVT SVT = VT.getVectorElementType();
+ SDValue ZeroIdx = DAG.getIntPtrConstant(0, DL);
+ N10 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, SVT, N10, ZeroIdx);
+ N11 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, SVT, N11, ZeroIdx);
+ SDValue Scl = DAG.getNode(Opcode1, DL, SVT, N10, N11);
+ SDValue SclVec = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, VT, Scl);
+ return DAG.getNode(Opcode, DL, VT, N0, SclVec);
+ }
}
return SDValue();
@@ -29647,7 +30814,7 @@ static SDValue combineTargetShuffle(SDValue N, SelectionDAG &DAG,
// Nuke no-op shuffles that show up after combining.
if (isNoopShuffleMask(Mask))
- return DCI.CombineTo(N.getNode(), N.getOperand(0), /*AddTo*/ true);
+ return N.getOperand(0);
// Look for simplifications involving one or two shuffle instructions.
SDValue V = N.getOperand(0);
@@ -29671,10 +30838,8 @@ static SDValue combineTargetShuffle(SDValue N, SelectionDAG &DAG,
DMask[DOffset + 1] = DOffset + 0;
MVT DVT = MVT::getVectorVT(MVT::i32, VT.getVectorNumElements() / 2);
V = DAG.getBitcast(DVT, V);
- DCI.AddToWorklist(V.getNode());
V = DAG.getNode(X86ISD::PSHUFD, DL, DVT, V,
getV4X86ShuffleImm8ForMask(DMask, DL, DAG));
- DCI.AddToWorklist(V.getNode());
return DAG.getBitcast(VT, V);
}
@@ -29705,7 +30870,6 @@ static SDValue combineTargetShuffle(SDValue N, SelectionDAG &DAG,
makeArrayRef(MappedMask).equals({4, 4, 5, 5, 6, 6, 7, 7})) {
// We can replace all three shuffles with an unpack.
V = DAG.getBitcast(VT, D.getOperand(0));
- DCI.AddToWorklist(V.getNode());
return DAG.getNode(MappedMask[0] == 0 ? X86ISD::UNPCKL
: X86ISD::UNPCKH,
DL, VT, V, V);
@@ -29725,6 +30889,37 @@ static SDValue combineTargetShuffle(SDValue N, SelectionDAG &DAG,
return SDValue();
}
+/// Checks whether the shuffle mask takes its elements alternately from two
+/// vectors: even lanes from one source, odd lanes from the other, with each
+/// lane keeping its index.
+/// For example <0, 5, 2, 7> or <8, 1, 10, 3, 12, 5, 14, 7> are both correct.
+static bool isAddSubOrSubAddMask(ArrayRef<int> Mask, bool &Op0Even) {
+
+ int ParitySrc[2] = {-1, -1};
+ unsigned Size = Mask.size();
+ for (unsigned i = 0; i != Size; ++i) {
+ int M = Mask[i];
+ if (M < 0)
+ continue;
+
+ // Make sure we are using the matching element from the input.
+ if ((M % Size) != i)
+ return false;
+
+ // Make sure we use the same input for all elements of the same parity.
+ int Src = M / Size;
+ if (ParitySrc[i % 2] >= 0 && ParitySrc[i % 2] != Src)
+ return false;
+ ParitySrc[i % 2] = Src;
+ }
+
+ // Make sure each input is used.
+ if (ParitySrc[0] < 0 || ParitySrc[1] < 0 || ParitySrc[0] == ParitySrc[1])
+ return false;
+
+ Op0Even = ParitySrc[0] == 0;
+ return true;
+}
+
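For reference, a self-contained scalar model of what this mask shape corresponds to: with the FSUB node as shuffle operand 0 and the FADD node as operand 1, the alternating mask <0, 5, 2, 7> reproduces ADDSUBPD lane for lane (illustration only, not part of the patch):

#include <cassert>

int main() {
  // ADDSUB semantics the mask check above is matching:
  //   result[i] = (i even) ? a[i] - b[i] : a[i] + b[i]
  double a[4] = {1, 2, 3, 4}, b[4] = {10, 20, 30, 40};
  double diffs[4], sums[4], res[4];
  for (int i = 0; i != 4; ++i) {
    diffs[i] = a[i] - b[i]; // the FSUB node (shuffle operand 0 here)
    sums[i] = a[i] + b[i];  // the FADD node (shuffle operand 1 here)
  }
  // Alternating mask <0, 5, 2, 7>: even lanes from operand 0, odd lanes from
  // operand 1, each lane keeping its index.
  int Mask[4] = {0, 5, 2, 7};
  for (int i = 0; i != 4; ++i)
    res[i] = Mask[i] < 4 ? diffs[Mask[i]] : sums[Mask[i] - 4];
  assert(res[0] == -9 && res[1] == 22 && res[2] == -27 && res[3] == 44);
  return 0;
}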
/// Returns true iff the shuffle node \p N can be replaced with ADDSUB(SUBADD)
/// operation. If true is returned then the operands of ADDSUB(SUBADD) operation
/// are written to the parameters \p Opnd0 and \p Opnd1.
@@ -29735,13 +30930,13 @@ static SDValue combineTargetShuffle(SDValue N, SelectionDAG &DAG,
/// by this operation to try to flow through the rest of the combiner
/// the fact that they're unused.
static bool isAddSubOrSubAdd(SDNode *N, const X86Subtarget &Subtarget,
- SDValue &Opnd0, SDValue &Opnd1,
- bool matchSubAdd = false) {
+ SelectionDAG &DAG, SDValue &Opnd0, SDValue &Opnd1,
+ bool &IsSubAdd) {
EVT VT = N->getValueType(0);
- if ((!Subtarget.hasSSE3() || (VT != MVT::v4f32 && VT != MVT::v2f64)) &&
- (!Subtarget.hasAVX() || (VT != MVT::v8f32 && VT != MVT::v4f64)) &&
- (!Subtarget.hasAVX512() || (VT != MVT::v16f32 && VT != MVT::v8f64)))
+ const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+ if (!Subtarget.hasSSE3() || !TLI.isTypeLegal(VT) ||
+ !VT.getSimpleVT().isFloatingPoint())
return false;
// We only handle target-independent shuffles.
@@ -29750,21 +30945,13 @@ static bool isAddSubOrSubAdd(SDNode *N, const X86Subtarget &Subtarget,
if (N->getOpcode() != ISD::VECTOR_SHUFFLE)
return false;
- ArrayRef<int> OrigMask = cast<ShuffleVectorSDNode>(N)->getMask();
- SmallVector<int, 16> Mask(OrigMask.begin(), OrigMask.end());
-
SDValue V1 = N->getOperand(0);
SDValue V2 = N->getOperand(1);
- unsigned ExpectedOpcode = matchSubAdd ? ISD::FADD : ISD::FSUB;
- unsigned NextExpectedOpcode = matchSubAdd ? ISD::FSUB : ISD::FADD;
-
- // We require the first shuffle operand to be the ExpectedOpcode node,
- // and the second to be the NextExpectedOpcode node.
- if (V1.getOpcode() == NextExpectedOpcode && V2.getOpcode() == ExpectedOpcode) {
- ShuffleVectorSDNode::commuteMask(Mask);
- std::swap(V1, V2);
- } else if (V1.getOpcode() != ExpectedOpcode || V2.getOpcode() != NextExpectedOpcode)
+ // Make sure we have an FADD and an FSUB.
+ if ((V1.getOpcode() != ISD::FADD && V1.getOpcode() != ISD::FSUB) ||
+ (V2.getOpcode() != ISD::FADD && V2.getOpcode() != ISD::FSUB) ||
+ V1.getOpcode() == V2.getOpcode())
return false;
// If there are other uses of these operations we can't fold them.
@@ -29773,41 +30960,101 @@ static bool isAddSubOrSubAdd(SDNode *N, const X86Subtarget &Subtarget,
// Ensure that both operations have the same operands. Note that we can
// commute the FADD operands.
- SDValue LHS = V1->getOperand(0), RHS = V1->getOperand(1);
- if ((V2->getOperand(0) != LHS || V2->getOperand(1) != RHS) &&
- (V2->getOperand(0) != RHS || V2->getOperand(1) != LHS))
- return false;
+ SDValue LHS, RHS;
+ if (V1.getOpcode() == ISD::FSUB) {
+ LHS = V1->getOperand(0); RHS = V1->getOperand(1);
+ if ((V2->getOperand(0) != LHS || V2->getOperand(1) != RHS) &&
+ (V2->getOperand(0) != RHS || V2->getOperand(1) != LHS))
+ return false;
+ } else {
+ assert(V2.getOpcode() == ISD::FSUB && "Unexpected opcode");
+ LHS = V2->getOperand(0); RHS = V2->getOperand(1);
+ if ((V1->getOperand(0) != LHS || V1->getOperand(1) != RHS) &&
+ (V1->getOperand(0) != RHS || V1->getOperand(1) != LHS))
+ return false;
+ }
- // We're looking for blends between FADD and FSUB nodes. We insist on these
- // nodes being lined up in a specific expected pattern.
- if (!(isShuffleEquivalent(V1, V2, Mask, {0, 3}) ||
- isShuffleEquivalent(V1, V2, Mask, {0, 5, 2, 7}) ||
- isShuffleEquivalent(V1, V2, Mask, {0, 9, 2, 11, 4, 13, 6, 15}) ||
- isShuffleEquivalent(V1, V2, Mask, {0, 17, 2, 19, 4, 21, 6, 23,
- 8, 25, 10, 27, 12, 29, 14, 31})))
+ ArrayRef<int> Mask = cast<ShuffleVectorSDNode>(N)->getMask();
+ bool Op0Even;
+ if (!isAddSubOrSubAddMask(Mask, Op0Even))
return false;
+ // It's a subadd if the vector in the even parity is an FADD.
+ IsSubAdd = Op0Even ? V1->getOpcode() == ISD::FADD
+ : V2->getOpcode() == ISD::FADD;
+
Opnd0 = LHS;
Opnd1 = RHS;
return true;
}
-/// \brief Try to combine a shuffle into a target-specific add-sub or
+/// Combine shuffle of two fma nodes into FMAddSub or FMSubAdd.
+static SDValue combineShuffleToFMAddSub(SDNode *N,
+ const X86Subtarget &Subtarget,
+ SelectionDAG &DAG) {
+ // We only handle target-independent shuffles.
+ // FIXME: It would be easy and harmless to use the target shuffle mask
+ // extraction tool to support more.
+ if (N->getOpcode() != ISD::VECTOR_SHUFFLE)
+ return SDValue();
+
+ MVT VT = N->getSimpleValueType(0);
+ const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+ if (!Subtarget.hasAnyFMA() || !TLI.isTypeLegal(VT))
+ return SDValue();
+
+  // We're trying to match shuffle(fma(a, b, c), X86Fmsub(a, b, c)).
+ SDValue Op0 = N->getOperand(0);
+ SDValue Op1 = N->getOperand(1);
+ SDValue FMAdd = Op0, FMSub = Op1;
+ if (FMSub.getOpcode() != X86ISD::FMSUB)
+ std::swap(FMAdd, FMSub);
+
+ if (FMAdd.getOpcode() != ISD::FMA || FMSub.getOpcode() != X86ISD::FMSUB ||
+ FMAdd.getOperand(0) != FMSub.getOperand(0) || !FMAdd.hasOneUse() ||
+ FMAdd.getOperand(1) != FMSub.getOperand(1) || !FMSub.hasOneUse() ||
+ FMAdd.getOperand(2) != FMSub.getOperand(2))
+ return SDValue();
+
+ // Check for correct shuffle mask.
+ ArrayRef<int> Mask = cast<ShuffleVectorSDNode>(N)->getMask();
+ bool Op0Even;
+ if (!isAddSubOrSubAddMask(Mask, Op0Even))
+ return SDValue();
+
+ // FMAddSub takes zeroth operand from FMSub node.
+ SDLoc DL(N);
+ bool IsSubAdd = Op0Even ? Op0 == FMAdd : Op1 == FMAdd;
+ unsigned Opcode = IsSubAdd ? X86ISD::FMSUBADD : X86ISD::FMADDSUB;
+ return DAG.getNode(Opcode, DL, VT, FMAdd.getOperand(0), FMAdd.getOperand(1),
+ FMAdd.getOperand(2));
+}
+
+/// Try to combine a shuffle into a target-specific add-sub or
/// mul-add-sub node.
static SDValue combineShuffleToAddSubOrFMAddSub(SDNode *N,
const X86Subtarget &Subtarget,
SelectionDAG &DAG) {
+ if (SDValue V = combineShuffleToFMAddSub(N, Subtarget, DAG))
+ return V;
+
SDValue Opnd0, Opnd1;
- if (!isAddSubOrSubAdd(N, Subtarget, Opnd0, Opnd1))
+ bool IsSubAdd;
+ if (!isAddSubOrSubAdd(N, Subtarget, DAG, Opnd0, Opnd1, IsSubAdd))
return SDValue();
- EVT VT = N->getValueType(0);
+ MVT VT = N->getSimpleValueType(0);
SDLoc DL(N);
// Try to generate X86ISD::FMADDSUB node here.
SDValue Opnd2;
- if (isFMAddSubOrFMSubAdd(Subtarget, DAG, Opnd0, Opnd1, Opnd2, 2))
- return DAG.getNode(X86ISD::FMADDSUB, DL, VT, Opnd0, Opnd1, Opnd2);
+ if (isFMAddSubOrFMSubAdd(Subtarget, DAG, Opnd0, Opnd1, Opnd2, 2)) {
+ unsigned Opc = IsSubAdd ? X86ISD::FMSUBADD : X86ISD::FMADDSUB;
+ return DAG.getNode(Opc, DL, VT, Opnd0, Opnd1, Opnd2);
+ }
+
+ if (IsSubAdd)
+ return SDValue();
// Do not generate X86ISD::ADDSUB node for 512-bit types even though
// the ADDSUB idiom has been successfully recognized. There are no known
@@ -29818,26 +31065,6 @@ static SDValue combineShuffleToAddSubOrFMAddSub(SDNode *N,
return DAG.getNode(X86ISD::ADDSUB, DL, VT, Opnd0, Opnd1);
}
-/// \brief Try to combine a shuffle into a target-specific
-/// mul-sub-add node.
-static SDValue combineShuffleToFMSubAdd(SDNode *N,
- const X86Subtarget &Subtarget,
- SelectionDAG &DAG) {
- SDValue Opnd0, Opnd1;
- if (!isAddSubOrSubAdd(N, Subtarget, Opnd0, Opnd1, true))
- return SDValue();
-
- EVT VT = N->getValueType(0);
- SDLoc DL(N);
-
- // Try to generate X86ISD::FMSUBADD node here.
- SDValue Opnd2;
- if (isFMAddSubOrFMSubAdd(Subtarget, DAG, Opnd0, Opnd1, Opnd2, 2))
- return DAG.getNode(X86ISD::FMSUBADD, DL, VT, Opnd0, Opnd1, Opnd2);
-
- return SDValue();
-}
-
// We are looking for a shuffle where both sources are concatenated with undef
// and have a width that is half of the output's width. AVX2 has VPERMD/Q, so
// if we can express this as a single-source shuffle, that's preferable.
@@ -29897,8 +31124,8 @@ static SDValue foldShuffleOfHorizOp(SDNode *N) {
// lanes of each operand as:
// v4X32: A[0] + A[1] , A[2] + A[3] , B[0] + B[1] , B[2] + B[3]
// ...similarly for v2f64 and v8i16.
- // TODO: 256-bit is not the same because...x86.
- if (HOp.getOperand(0) != HOp.getOperand(1) || HOp.getValueSizeInBits() != 128)
+ // TODO: Handle UNDEF operands.
+ if (HOp.getOperand(0) != HOp.getOperand(1))
return SDValue();
// When the operands of a horizontal math op are identical, the low half of
@@ -29909,9 +31136,17 @@ static SDValue foldShuffleOfHorizOp(SDNode *N) {
// TODO: Other mask possibilities like {1,1} and {1,0} could be added here,
// but this should be tied to whatever horizontal op matching and shuffle
// canonicalization are producing.
- if (isTargetShuffleEquivalent(Mask, { 0, 0 }) ||
- isTargetShuffleEquivalent(Mask, { 0, 1, 0, 1 }) ||
- isTargetShuffleEquivalent(Mask, { 0, 1, 2, 3, 0, 1, 2, 3 }))
+ if (HOp.getValueSizeInBits() == 128 &&
+ (isTargetShuffleEquivalent(Mask, {0, 0}) ||
+ isTargetShuffleEquivalent(Mask, {0, 1, 0, 1}) ||
+ isTargetShuffleEquivalent(Mask, {0, 1, 2, 3, 0, 1, 2, 3})))
+ return HOp;
+
+ if (HOp.getValueSizeInBits() == 256 &&
+ (isTargetShuffleEquivalent(Mask, {0, 0, 2, 2}) ||
+ isTargetShuffleEquivalent(Mask, {0, 1, 0, 1, 4, 5, 4, 5}) ||
+ isTargetShuffleEquivalent(
+ Mask, {0, 1, 2, 3, 0, 1, 2, 3, 8, 9, 10, 11, 8, 9, 10, 11})))
return HOp;
return SDValue();
@@ -29929,9 +31164,6 @@ static SDValue combineShuffle(SDNode *N, SelectionDAG &DAG,
if (SDValue AddSub = combineShuffleToAddSubOrFMAddSub(N, Subtarget, DAG))
return AddSub;
- if (SDValue FMSubAdd = combineShuffleToFMSubAdd(N, Subtarget, DAG))
- return FMSubAdd;
-
if (SDValue HAddSub = foldShuffleOfHorizOp(N))
return HAddSub;
}
@@ -30035,10 +31267,8 @@ static SDValue combineShuffle(SDNode *N, SelectionDAG &DAG,
// a particular chain.
if (SDValue Res = combineX86ShufflesRecursively(
{Op}, 0, Op, {0}, {}, /*Depth*/ 1,
- /*HasVarMask*/ false, DAG, DCI, Subtarget)) {
- DCI.CombineTo(N, Res);
- return SDValue();
- }
+ /*HasVarMask*/ false, DAG, Subtarget))
+ return Res;
}
return SDValue();
@@ -30155,53 +31385,6 @@ static SDValue combineBitcastvxi1(SelectionDAG &DAG, SDValue BitCast,
SDValue N0 = BitCast.getOperand(0);
EVT VecVT = N0->getValueType(0);
- if (VT.isVector() && VecVT.isScalarInteger() && Subtarget.hasAVX512() &&
- N0->getOpcode() == ISD::OR) {
- SDValue Op0 = N0->getOperand(0);
- SDValue Op1 = N0->getOperand(1);
- MVT TrunckVT;
- MVT BitcastVT;
- switch (VT.getSimpleVT().SimpleTy) {
- default:
- return SDValue();
- case MVT::v16i1:
- TrunckVT = MVT::i8;
- BitcastVT = MVT::v8i1;
- break;
- case MVT::v32i1:
- TrunckVT = MVT::i16;
- BitcastVT = MVT::v16i1;
- break;
- case MVT::v64i1:
- TrunckVT = MVT::i32;
- BitcastVT = MVT::v32i1;
- break;
- }
- bool isArg0UndefRight = Op0->getOpcode() == ISD::SHL;
- bool isArg0UndefLeft =
- Op0->getOpcode() == ISD::ZERO_EXTEND || Op0->getOpcode() == ISD::AND;
- bool isArg1UndefRight = Op1->getOpcode() == ISD::SHL;
- bool isArg1UndefLeft =
- Op1->getOpcode() == ISD::ZERO_EXTEND || Op1->getOpcode() == ISD::AND;
- SDValue OpLeft;
- SDValue OpRight;
- if (isArg0UndefRight && isArg1UndefLeft) {
- OpLeft = Op0;
- OpRight = Op1;
- } else if (isArg1UndefRight && isArg0UndefLeft) {
- OpLeft = Op1;
- OpRight = Op0;
- } else
- return SDValue();
- SDLoc DL(BitCast);
- SDValue Shr = OpLeft->getOperand(0);
- SDValue Trunc1 = DAG.getNode(ISD::TRUNCATE, DL, TrunckVT, Shr);
- SDValue Bitcast1 = DAG.getBitcast(BitcastVT, Trunc1);
- SDValue Trunc2 = DAG.getNode(ISD::TRUNCATE, DL, TrunckVT, OpRight);
- SDValue Bitcast2 = DAG.getBitcast(BitcastVT, Trunc2);
- return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, Bitcast1, Bitcast2);
- }
-
if (!VT.isScalarInteger() || !VecVT.isSimple())
return SDValue();
@@ -30269,17 +31452,8 @@ static SDValue combineBitcastvxi1(SelectionDAG &DAG, SDValue BitCast,
SDLoc DL(BitCast);
SDValue V = DAG.getSExtOrTrunc(N0, DL, SExtVT);
- if (SExtVT == MVT::v32i8 && !Subtarget.hasInt256()) {
- // Handle pre-AVX2 cases by splitting to two v16i1's.
- const TargetLowering &TLI = DAG.getTargetLoweringInfo();
- MVT ShiftTy = TLI.getScalarShiftAmountTy(DAG.getDataLayout(), MVT::i32);
- SDValue Lo = extract128BitVector(V, 0, DAG, DL);
- SDValue Hi = extract128BitVector(V, 16, DAG, DL);
- Lo = DAG.getNode(X86ISD::MOVMSK, DL, MVT::i32, Lo);
- Hi = DAG.getNode(X86ISD::MOVMSK, DL, MVT::i32, Hi);
- Hi = DAG.getNode(ISD::SHL, DL, MVT::i32, Hi,
- DAG.getConstant(16, DL, ShiftTy));
- V = DAG.getNode(ISD::OR, DL, MVT::i32, Lo, Hi);
+ if (SExtVT == MVT::v16i8 || SExtVT == MVT::v32i8) {
+ V = getPMOVMSKB(DL, V, DAG, Subtarget);
return DAG.getZExtOrTrunc(V, DL, VT);
}
@@ -30296,6 +31470,153 @@ static SDValue combineBitcastvxi1(SelectionDAG &DAG, SDValue BitCast,
return DAG.getZExtOrTrunc(V, DL, VT);
}
+// Convert a vXi1 constant build vector to the same width scalar integer.
+static SDValue combinevXi1ConstantToInteger(SDValue Op, SelectionDAG &DAG) {
+ EVT SrcVT = Op.getValueType();
+ assert(SrcVT.getVectorElementType() == MVT::i1 &&
+ "Expected a vXi1 vector");
+ assert(ISD::isBuildVectorOfConstantSDNodes(Op.getNode()) &&
+ "Expected a constant build vector");
+
+ APInt Imm(SrcVT.getVectorNumElements(), 0);
+ for (unsigned Idx = 0, e = Op.getNumOperands(); Idx < e; ++Idx) {
+ SDValue In = Op.getOperand(Idx);
+ if (!In.isUndef() && (cast<ConstantSDNode>(In)->getZExtValue() & 0x1))
+ Imm.setBit(Idx);
+ }
+ EVT IntVT = EVT::getIntegerVT(*DAG.getContext(), Imm.getBitWidth());
+ return DAG.getConstant(Imm, SDLoc(Op), IntVT);
+}
+
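A scalar sketch of the packing performed above, where element i of the constant vXi1 build vector becomes bit i of the integer (packMaskConstant is a hypothetical stand-in; -1 models an undef element, which is left as a zero bit):

#include <cassert>
#include <cstdint>
#include <vector>

static uint64_t packMaskConstant(const std::vector<int> &Elts) {
  uint64_t Imm = 0;
  for (size_t i = 0; i != Elts.size(); ++i)
    if (Elts[i] >= 0 && (Elts[i] & 1))
      Imm |= 1ull << i;
  return Imm;
}

int main() {
  // v4i1 <1, 0, 1, 1> bitcast to an integer yields 0b1101 == 13.
  assert(packMaskConstant({1, 0, 1, 1}) == 0b1101);
  // Undef elements contribute nothing: <1, undef, undef, 1> -> 0b1001.
  assert(packMaskConstant({1, -1, -1, 1}) == 0b1001);
  return 0;
}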
+static SDValue combineCastedMaskArithmetic(SDNode *N, SelectionDAG &DAG,
+ TargetLowering::DAGCombinerInfo &DCI,
+ const X86Subtarget &Subtarget) {
+ assert(N->getOpcode() == ISD::BITCAST && "Expected a bitcast");
+
+ if (!DCI.isBeforeLegalizeOps())
+ return SDValue();
+
+ // Only do this if we have k-registers.
+ if (!Subtarget.hasAVX512())
+ return SDValue();
+
+ EVT DstVT = N->getValueType(0);
+ SDValue Op = N->getOperand(0);
+ EVT SrcVT = Op.getValueType();
+
+ if (!Op.hasOneUse())
+ return SDValue();
+
+ // Look for logic ops.
+ if (Op.getOpcode() != ISD::AND &&
+ Op.getOpcode() != ISD::OR &&
+ Op.getOpcode() != ISD::XOR)
+ return SDValue();
+
+ // Make sure we have a bitcast between mask registers and a scalar type.
+ if (!(SrcVT.isVector() && SrcVT.getVectorElementType() == MVT::i1 &&
+ DstVT.isScalarInteger()) &&
+ !(DstVT.isVector() && DstVT.getVectorElementType() == MVT::i1 &&
+ SrcVT.isScalarInteger()))
+ return SDValue();
+
+ SDValue LHS = Op.getOperand(0);
+ SDValue RHS = Op.getOperand(1);
+
+ if (LHS.hasOneUse() && LHS.getOpcode() == ISD::BITCAST &&
+ LHS.getOperand(0).getValueType() == DstVT)
+ return DAG.getNode(Op.getOpcode(), SDLoc(N), DstVT, LHS.getOperand(0),
+ DAG.getBitcast(DstVT, RHS));
+
+ if (RHS.hasOneUse() && RHS.getOpcode() == ISD::BITCAST &&
+ RHS.getOperand(0).getValueType() == DstVT)
+ return DAG.getNode(Op.getOpcode(), SDLoc(N), DstVT,
+ DAG.getBitcast(DstVT, LHS), RHS.getOperand(0));
+
+ // If the RHS is a vXi1 build vector, this is a good reason to flip too.
+ // Most of these have to move a constant from the scalar domain anyway.
+ if (ISD::isBuildVectorOfConstantSDNodes(RHS.getNode())) {
+ RHS = combinevXi1ConstantToInteger(RHS, DAG);
+ return DAG.getNode(Op.getOpcode(), SDLoc(N), DstVT,
+ DAG.getBitcast(DstVT, LHS), RHS);
+ }
+
+ return SDValue();
+}
+
+static SDValue createMMXBuildVector(SDValue N, SelectionDAG &DAG,
+ const X86Subtarget &Subtarget) {
+ SDLoc DL(N);
+ unsigned NumElts = N.getNumOperands();
+
+ auto *BV = cast<BuildVectorSDNode>(N);
+ SDValue Splat = BV->getSplatValue();
+
+ // Build MMX element from integer GPR or SSE float values.
+ auto CreateMMXElement = [&](SDValue V) {
+ if (V.isUndef())
+ return DAG.getUNDEF(MVT::x86mmx);
+ if (V.getValueType().isFloatingPoint()) {
+ if (Subtarget.hasSSE1() && !isa<ConstantFPSDNode>(V)) {
+ V = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, MVT::v4f32, V);
+ V = DAG.getBitcast(MVT::v2i64, V);
+ return DAG.getNode(X86ISD::MOVDQ2Q, DL, MVT::x86mmx, V);
+ }
+ V = DAG.getBitcast(MVT::i32, V);
+ } else {
+ V = DAG.getAnyExtOrTrunc(V, DL, MVT::i32);
+ }
+ return DAG.getNode(X86ISD::MMX_MOVW2D, DL, MVT::x86mmx, V);
+ };
+
+ // Convert build vector ops to MMX data in the bottom elements.
+ SmallVector<SDValue, 8> Ops;
+
+ // Broadcast - use (PUNPCKL+)PSHUFW to broadcast single element.
+ if (Splat) {
+ if (Splat.isUndef())
+ return DAG.getUNDEF(MVT::x86mmx);
+
+ Splat = CreateMMXElement(Splat);
+
+ if (Subtarget.hasSSE1()) {
+ // Unpack v8i8 to splat i8 elements to lowest 16-bits.
+ if (NumElts == 8)
+ Splat = DAG.getNode(
+ ISD::INTRINSIC_WO_CHAIN, DL, MVT::x86mmx,
+ DAG.getConstant(Intrinsic::x86_mmx_punpcklbw, DL, MVT::i32), Splat,
+ Splat);
+
+ // Use PSHUFW to repeat 16-bit elements.
+ unsigned ShufMask = (NumElts > 2 ? 0 : 0x44);
+ return DAG.getNode(
+ ISD::INTRINSIC_WO_CHAIN, DL, MVT::x86mmx,
+ DAG.getConstant(Intrinsic::x86_sse_pshuf_w, DL, MVT::i32), Splat,
+ DAG.getConstant(ShufMask, DL, MVT::i8));
+ }
+ Ops.append(NumElts, Splat);
+ } else {
+ for (unsigned i = 0; i != NumElts; ++i)
+ Ops.push_back(CreateMMXElement(N.getOperand(i)));
+ }
+
+ // Use tree of PUNPCKLs to build up general MMX vector.
+ while (Ops.size() > 1) {
+ unsigned NumOps = Ops.size();
+ unsigned IntrinOp =
+ (NumOps == 2 ? Intrinsic::x86_mmx_punpckldq
+ : (NumOps == 4 ? Intrinsic::x86_mmx_punpcklwd
+ : Intrinsic::x86_mmx_punpcklbw));
+ SDValue Intrin = DAG.getConstant(IntrinOp, DL, MVT::i32);
+ for (unsigned i = 0; i != NumOps; i += 2)
+ Ops[i / 2] = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, MVT::x86mmx, Intrin,
+ Ops[i], Ops[i + 1]);
+ Ops.resize(NumOps / 2);
+ }
+
+ return Ops[0];
+}
+
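The reduction loop above halves the number of partially built MMX values each round, choosing the unpack width from how many values remain. A small standalone model of that schedule (unpackSchedule is a hypothetical name used only for illustration):

#include <cassert>
#include <string>
#include <vector>

// Each round pairs adjacent values with the narrowest unpack that still has
// work to do: bytes first, then words, then a final dword unpack.
static std::vector<std::string> unpackSchedule(unsigned NumElts) {
  std::vector<std::string> Rounds;
  for (unsigned NumOps = NumElts; NumOps > 1; NumOps /= 2)
    Rounds.push_back(NumOps == 2 ? "punpckldq"
                                 : NumOps == 4 ? "punpcklwd" : "punpcklbw");
  return Rounds;
}

int main() {
  // A v8i8 build vector takes three rounds: bytes -> words -> dwords.
  assert(unpackSchedule(8) ==
         (std::vector<std::string>{"punpcklbw", "punpcklwd", "punpckldq"}));
  // A v2i32 build vector needs only the final punpckldq.
  assert(unpackSchedule(2) == (std::vector<std::string>{"punpckldq"}));
  return 0;
}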
static SDValue combineBitcast(SDNode *N, SelectionDAG &DAG,
TargetLowering::DAGCombinerInfo &DCI,
const X86Subtarget &Subtarget) {
@@ -30309,42 +31630,124 @@ static SDValue combineBitcast(SDNode *N, SelectionDAG &DAG,
// (i16 movmsk (16i8 sext (v16i1 x)))
// before the setcc result is scalarized on subtargets that don't have legal
// vxi1 types.
- if (DCI.isBeforeLegalize())
+ if (DCI.isBeforeLegalize()) {
if (SDValue V = combineBitcastvxi1(DAG, SDValue(N, 0), Subtarget))
return V;
+
+ // If this is a bitcast between a MVT::v4i1/v2i1 and an illegal integer
+ // type, widen both sides to avoid a trip through memory.
+ if ((VT == MVT::v4i1 || VT == MVT::v2i1) && SrcVT.isScalarInteger() &&
+ Subtarget.hasAVX512()) {
+ SDLoc dl(N);
+ N0 = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i8, N0);
+ N0 = DAG.getBitcast(MVT::v8i1, N0);
+ return DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, VT, N0,
+ DAG.getIntPtrConstant(0, dl));
+ }
+
+ // If this is a bitcast between a MVT::v4i1/v2i1 and an illegal integer
+ // type, widen both sides to avoid a trip through memory.
+ if ((SrcVT == MVT::v4i1 || SrcVT == MVT::v2i1) && VT.isScalarInteger() &&
+ Subtarget.hasAVX512()) {
+ SDLoc dl(N);
+ unsigned NumConcats = 8 / SrcVT.getVectorNumElements();
+ SmallVector<SDValue, 4> Ops(NumConcats, DAG.getUNDEF(SrcVT));
+ Ops[0] = N0;
+ N0 = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v8i1, Ops);
+ N0 = DAG.getBitcast(MVT::i8, N0);
+ return DAG.getNode(ISD::TRUNCATE, dl, VT, N0);
+ }
+ }
+
// Since MMX types are special and don't usually play with other vector types,
// it's better to handle them early to be sure we emit efficient code by
// avoiding store-load conversions.
+ if (VT == MVT::x86mmx) {
+ // Detect MMX constant vectors.
+ APInt UndefElts;
+ SmallVector<APInt, 1> EltBits;
+ if (getTargetConstantBitsFromNode(N0, 64, UndefElts, EltBits)) {
+ SDLoc DL(N0);
+ // Handle zero-extension of i32 with MOVD.
+ if (EltBits[0].countLeadingZeros() >= 32)
+ return DAG.getNode(X86ISD::MMX_MOVW2D, DL, VT,
+ DAG.getConstant(EltBits[0].trunc(32), DL, MVT::i32));
+ // Else, bitcast to a double.
+ // TODO - investigate supporting sext 32-bit immediates on x86_64.
+ APFloat F64(APFloat::IEEEdouble(), EltBits[0]);
+ return DAG.getBitcast(VT, DAG.getConstantFP(F64, DL, MVT::f64));
+ }
+
+ // Detect bitcasts to x86mmx low word.
+ if (N0.getOpcode() == ISD::BUILD_VECTOR &&
+ (SrcVT == MVT::v2i32 || SrcVT == MVT::v4i16 || SrcVT == MVT::v8i8) &&
+ N0.getOperand(0).getValueType() == SrcVT.getScalarType()) {
+ bool LowUndef = true, AllUndefOrZero = true;
+ for (unsigned i = 1, e = SrcVT.getVectorNumElements(); i != e; ++i) {
+ SDValue Op = N0.getOperand(i);
+ LowUndef &= Op.isUndef() || (i >= e/2);
+ AllUndefOrZero &= (Op.isUndef() || isNullConstant(Op));
+ }
+ if (AllUndefOrZero) {
+ SDValue N00 = N0.getOperand(0);
+ SDLoc dl(N00);
+ N00 = LowUndef ? DAG.getAnyExtOrTrunc(N00, dl, MVT::i32)
+ : DAG.getZExtOrTrunc(N00, dl, MVT::i32);
+ return DAG.getNode(X86ISD::MMX_MOVW2D, dl, VT, N00);
+ }
+ }
- // Detect bitcasts between i32 to x86mmx low word.
- if (VT == MVT::x86mmx && N0.getOpcode() == ISD::BUILD_VECTOR &&
- SrcVT == MVT::v2i32 && isNullConstant(N0.getOperand(1))) {
- SDValue N00 = N0->getOperand(0);
- if (N00.getValueType() == MVT::i32)
- return DAG.getNode(X86ISD::MMX_MOVW2D, SDLoc(N00), VT, N00);
+ // Detect bitcasts of 64-bit build vectors and convert to a
+ // MMX UNPCK/PSHUFW which takes MMX type inputs with the value in the
+ // lowest element.
+ if (N0.getOpcode() == ISD::BUILD_VECTOR &&
+ (SrcVT == MVT::v2f32 || SrcVT == MVT::v2i32 || SrcVT == MVT::v4i16 ||
+ SrcVT == MVT::v8i8))
+ return createMMXBuildVector(N0, DAG, Subtarget);
+
+    // Detect bitcasts of element or subvector extractions to x86mmx.
+ if ((N0.getOpcode() == ISD::EXTRACT_VECTOR_ELT ||
+ N0.getOpcode() == ISD::EXTRACT_SUBVECTOR) &&
+ isNullConstant(N0.getOperand(1))) {
+ SDValue N00 = N0.getOperand(0);
+ if (N00.getValueType().is128BitVector())
+ return DAG.getNode(X86ISD::MOVDQ2Q, SDLoc(N00), VT,
+ DAG.getBitcast(MVT::v2i64, N00));
+ }
+
+ // Detect bitcasts from FP_TO_SINT to x86mmx.
+ if (SrcVT == MVT::v2i32 && N0.getOpcode() == ISD::FP_TO_SINT) {
+ SDLoc DL(N0);
+ SDValue Res = DAG.getNode(ISD::CONCAT_VECTORS, DL, MVT::v4i32, N0,
+ DAG.getUNDEF(MVT::v2i32));
+ return DAG.getNode(X86ISD::MOVDQ2Q, DL, VT,
+ DAG.getBitcast(MVT::v2i64, Res));
+ }
}
- // Detect bitcasts between element or subvector extraction to x86mmx.
- if (VT == MVT::x86mmx &&
- (N0.getOpcode() == ISD::EXTRACT_VECTOR_ELT ||
- N0.getOpcode() == ISD::EXTRACT_SUBVECTOR) &&
- isNullConstant(N0.getOperand(1))) {
- SDValue N00 = N0->getOperand(0);
- if (N00.getValueType().is128BitVector())
- return DAG.getNode(X86ISD::MOVDQ2Q, SDLoc(N00), VT,
- DAG.getBitcast(MVT::v2i64, N00));
+  // Try to remove a bitcast of a constant vXi1 vector. We have to legalize
+ // most of these to scalar anyway.
+ if (Subtarget.hasAVX512() && VT.isScalarInteger() &&
+ SrcVT.isVector() && SrcVT.getVectorElementType() == MVT::i1 &&
+ ISD::isBuildVectorOfConstantSDNodes(N0.getNode())) {
+ return combinevXi1ConstantToInteger(N0, DAG);
}
- // Detect bitcasts from FP_TO_SINT to x86mmx.
- if (VT == MVT::x86mmx && SrcVT == MVT::v2i32 &&
- N0.getOpcode() == ISD::FP_TO_SINT) {
- SDLoc DL(N0);
- SDValue Res = DAG.getNode(ISD::CONCAT_VECTORS, DL, MVT::v4i32, N0,
- DAG.getUNDEF(MVT::v2i32));
- return DAG.getNode(X86ISD::MOVDQ2Q, DL, VT,
- DAG.getBitcast(MVT::v2i64, Res));
+ if (Subtarget.hasAVX512() && SrcVT.isScalarInteger() &&
+ VT.isVector() && VT.getVectorElementType() == MVT::i1 &&
+ isa<ConstantSDNode>(N0)) {
+ auto *C = cast<ConstantSDNode>(N0);
+ if (C->isAllOnesValue())
+ return DAG.getConstant(1, SDLoc(N0), VT);
+ if (C->isNullValue())
+ return DAG.getConstant(0, SDLoc(N0), VT);
}
+ // Try to remove bitcasts from input and output of mask arithmetic to
+ // remove GPR<->K-register crossings.
+ if (SDValue V = combineCastedMaskArithmetic(N, DAG, DCI, Subtarget))
+ return V;
+
// Convert a bitcasted integer logic operation that has one bitcasted
// floating-point operand into a floating-point logic operation. This may
// create a load of a constant, but that is cheaper than materializing the
@@ -30517,8 +31920,8 @@ static bool detectZextAbsDiff(const SDValue &Select, SDValue &Op0,
// Given two zexts of <k x i8> to <k x i32>, create a PSADBW of the inputs
// to these zexts.
static SDValue createPSADBW(SelectionDAG &DAG, const SDValue &Zext0,
- const SDValue &Zext1, const SDLoc &DL) {
-
+ const SDValue &Zext1, const SDLoc &DL,
+ const X86Subtarget &Subtarget) {
// Find the appropriate width for the PSADBW.
EVT InVT = Zext0.getOperand(0).getValueType();
unsigned RegSize = std::max(128u, InVT.getSizeInBits());
@@ -30533,9 +31936,15 @@ static SDValue createPSADBW(SelectionDAG &DAG, const SDValue &Zext0,
Ops[0] = Zext1.getOperand(0);
SDValue SadOp1 = DAG.getNode(ISD::CONCAT_VECTORS, DL, ExtendedVT, Ops);
- // Actually build the SAD
+ // Actually build the SAD, split as 128/256/512 bits for SSE/AVX2/AVX512BW.
+ auto PSADBWBuilder = [](SelectionDAG &DAG, const SDLoc &DL,
+ ArrayRef<SDValue> Ops) {
+ MVT VT = MVT::getVectorVT(MVT::i64, Ops[0].getValueSizeInBits() / 64);
+ return DAG.getNode(X86ISD::PSADBW, DL, VT, Ops);
+ };
MVT SadVT = MVT::getVectorVT(MVT::i64, RegSize / 64);
- return DAG.getNode(X86ISD::PSADBW, DL, SadVT, SadOp0, SadOp1);
+ return SplitOpsAndApply(DAG, Subtarget, DL, SadVT, { SadOp0, SadOp1 },
+ PSADBWBuilder);
}
// Attempt to replace an min/max v8i16/v16i8 horizontal reduction with
@@ -30702,12 +32111,12 @@ static SDValue combineBasicSADPattern(SDNode *Extract, SelectionDAG &DAG,
return SDValue();
unsigned RegSize = 128;
- if (Subtarget.hasBWI())
+ if (Subtarget.useBWIRegs())
RegSize = 512;
- else if (Subtarget.hasAVX2())
+ else if (Subtarget.hasAVX())
RegSize = 256;
- // We handle upto v16i* for SSE2 / v32i* for AVX2 / v64i* for AVX512.
+  // We handle up to v16i* for SSE2 / v32i* for AVX / v64i* for AVX512.
// TODO: We should be able to handle larger vectors by splitting them before
// feeding them into several SADs, and then reducing over those.
if (RegSize / VT.getVectorNumElements() < 8)
@@ -30742,7 +32151,7 @@ static SDValue combineBasicSADPattern(SDNode *Extract, SelectionDAG &DAG,
// Create the SAD instruction.
SDLoc DL(Extract);
- SDValue SAD = createPSADBW(DAG, Zext0, Zext1, DL);
+ SDValue SAD = createPSADBW(DAG, Zext0, Zext1, DL, Subtarget);
// If the original vector was wider than 8 elements, sum over the results
// in the SAD vector.
@@ -30791,6 +32200,11 @@ static SDValue combineExtractWithShuffle(SDNode *N, SelectionDAG &DAG,
if (SrcSVT == MVT::i1 || !isa<ConstantSDNode>(Idx))
return SDValue();
+  // Handle extract(broadcast(scalar_value)); it doesn't matter what the index is.
+ if (X86ISD::VBROADCAST == Src.getOpcode() &&
+ Src.getOperand(0).getValueType() == VT)
+ return Src.getOperand(0);
+
// Resolve the target shuffle inputs and mask.
SmallVector<int, 16> Mask;
SmallVector<SDValue, 2> Ops;
@@ -30908,8 +32322,9 @@ static SDValue combineExtractVectorElt(SDNode *N, SelectionDAG &DAG,
isa<ConstantSDNode>(EltIdx) &&
isa<ConstantSDNode>(InputVector.getOperand(0))) {
uint64_t ExtractedElt = N->getConstantOperandVal(1);
- uint64_t InputValue = InputVector.getConstantOperandVal(0);
- uint64_t Res = (InputValue >> ExtractedElt) & 1;
+ auto *InputC = cast<ConstantSDNode>(InputVector.getOperand(0));
+ const APInt &InputValue = InputC->getAPIntValue();
+ uint64_t Res = InputValue[ExtractedElt];
return DAG.getConstant(Res, dl, MVT::i1);
}
@@ -30927,102 +32342,7 @@ static SDValue combineExtractVectorElt(SDNode *N, SelectionDAG &DAG,
if (SDValue MinMax = combineHorizontalMinMaxResult(N, DAG, Subtarget))
return MinMax;
- // Only operate on vectors of 4 elements, where the alternative shuffling
- // gets to be more expensive.
- if (SrcVT != MVT::v4i32)
- return SDValue();
-
- // Check whether every use of InputVector is an EXTRACT_VECTOR_ELT with a
- // single use which is a sign-extend or zero-extend, and all elements are
- // used.
- SmallVector<SDNode *, 4> Uses;
- unsigned ExtractedElements = 0;
- for (SDNode::use_iterator UI = InputVector.getNode()->use_begin(),
- UE = InputVector.getNode()->use_end(); UI != UE; ++UI) {
- if (UI.getUse().getResNo() != InputVector.getResNo())
- return SDValue();
-
- SDNode *Extract = *UI;
- if (Extract->getOpcode() != ISD::EXTRACT_VECTOR_ELT)
- return SDValue();
-
- if (Extract->getValueType(0) != MVT::i32)
- return SDValue();
- if (!Extract->hasOneUse())
- return SDValue();
- if (Extract->use_begin()->getOpcode() != ISD::SIGN_EXTEND &&
- Extract->use_begin()->getOpcode() != ISD::ZERO_EXTEND)
- return SDValue();
- if (!isa<ConstantSDNode>(Extract->getOperand(1)))
- return SDValue();
-
- // Record which element was extracted.
- ExtractedElements |= 1 << Extract->getConstantOperandVal(1);
- Uses.push_back(Extract);
- }
-
- // If not all the elements were used, this may not be worthwhile.
- if (ExtractedElements != 15)
- return SDValue();
-
- // Ok, we've now decided to do the transformation.
- // If 64-bit shifts are legal, use the extract-shift sequence,
- // otherwise bounce the vector off the cache.
- const TargetLowering &TLI = DAG.getTargetLoweringInfo();
- SDValue Vals[4];
-
- if (TLI.isOperationLegal(ISD::SRA, MVT::i64)) {
- SDValue Cst = DAG.getBitcast(MVT::v2i64, InputVector);
- auto &DL = DAG.getDataLayout();
- EVT VecIdxTy = DAG.getTargetLoweringInfo().getVectorIdxTy(DL);
- SDValue BottomHalf = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, MVT::i64, Cst,
- DAG.getConstant(0, dl, VecIdxTy));
- SDValue TopHalf = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, MVT::i64, Cst,
- DAG.getConstant(1, dl, VecIdxTy));
-
- SDValue ShAmt = DAG.getConstant(
- 32, dl, DAG.getTargetLoweringInfo().getShiftAmountTy(MVT::i64, DL));
- Vals[0] = DAG.getNode(ISD::TRUNCATE, dl, MVT::i32, BottomHalf);
- Vals[1] = DAG.getNode(ISD::TRUNCATE, dl, MVT::i32,
- DAG.getNode(ISD::SRA, dl, MVT::i64, BottomHalf, ShAmt));
- Vals[2] = DAG.getNode(ISD::TRUNCATE, dl, MVT::i32, TopHalf);
- Vals[3] = DAG.getNode(ISD::TRUNCATE, dl, MVT::i32,
- DAG.getNode(ISD::SRA, dl, MVT::i64, TopHalf, ShAmt));
- } else {
- // Store the value to a temporary stack slot.
- SDValue StackPtr = DAG.CreateStackTemporary(SrcVT);
- SDValue Ch = DAG.getStore(DAG.getEntryNode(), dl, InputVector, StackPtr,
- MachinePointerInfo());
-
- EVT ElementType = SrcVT.getVectorElementType();
- unsigned EltSize = ElementType.getSizeInBits() / 8;
-
- // Replace each use (extract) with a load of the appropriate element.
- for (unsigned i = 0; i < 4; ++i) {
- uint64_t Offset = EltSize * i;
- auto PtrVT = TLI.getPointerTy(DAG.getDataLayout());
- SDValue OffsetVal = DAG.getConstant(Offset, dl, PtrVT);
-
- SDValue ScalarAddr =
- DAG.getNode(ISD::ADD, dl, PtrVT, StackPtr, OffsetVal);
-
- // Load the scalar.
- Vals[i] =
- DAG.getLoad(ElementType, dl, Ch, ScalarAddr, MachinePointerInfo());
- }
- }
-
- // Replace the extracts
- for (SmallVectorImpl<SDNode *>::iterator UI = Uses.begin(),
- UE = Uses.end(); UI != UE; ++UI) {
- SDNode *Extract = *UI;
-
- uint64_t IdxVal = Extract->getConstantOperandVal(1);
- DAG.ReplaceAllUsesOfValueWith(SDValue(Extract, 0), Vals[IdxVal]);
- }
-
- // The replacement was made in place; return N so it won't be revisited.
- return SDValue(N, 0);
+ return SDValue();
}
/// If a vector select has an operand that is -1 or 0, try to simplify the
@@ -31051,8 +32371,7 @@ combineVSelectWithAllOnesOrZeros(SDNode *N, SelectionDAG &DAG,
if (TValIsAllZeros && Subtarget.hasAVX512() && Cond.hasOneUse() &&
CondVT.getVectorElementType() == MVT::i1) {
// Invert the cond to not(cond) : xor(op,allones)=not(op)
- SDValue CondNew = DAG.getNode(ISD::XOR, DL, CondVT, Cond,
- DAG.getAllOnesConstant(DL, CondVT));
+ SDValue CondNew = DAG.getNOT(DL, Cond, CondVT);
// Vselect cond, op1, op2 = Vselect not(cond), op2, op1
return DAG.getSelect(DL, VT, CondNew, RHS, LHS);
}
@@ -31191,68 +32510,77 @@ static SDValue combineSelectOfTwoConstants(SDNode *N, SelectionDAG &DAG) {
return SDValue();
}
-// If this is a bitcasted op that can be represented as another type, push the
-// the bitcast to the inputs. This allows more opportunities for pattern
-// matching masked instructions. This is called when we know that the operation
-// is used as one of the inputs of a vselect.
-static bool combineBitcastForMaskedOp(SDValue OrigOp, SelectionDAG &DAG,
- TargetLowering::DAGCombinerInfo &DCI) {
- // Make sure we have a bitcast.
- if (OrigOp.getOpcode() != ISD::BITCAST)
- return false;
-
- SDValue Op = OrigOp.getOperand(0);
-
- // If the operation is used by anything other than the bitcast, we shouldn't
- // do this combine as that would replicate the operation.
- if (!Op.hasOneUse())
- return false;
+/// If this is a *dynamic* select (non-constant condition) and we can match
+/// this node with one of the variable blend instructions, restructure the
+/// condition so that blends can use the high (sign) bit of each element.
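+/// Hardware variable blends (e.g. PBLENDVB/BLENDVPS) test only the sign bit
+/// of each mask element, so demanding just that bit from the condition is
+/// sufficient here.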
+static SDValue combineVSelectToShrunkBlend(SDNode *N, SelectionDAG &DAG,
+ TargetLowering::DAGCombinerInfo &DCI,
+ const X86Subtarget &Subtarget) {
+ SDValue Cond = N->getOperand(0);
+ if (N->getOpcode() != ISD::VSELECT ||
+ ISD::isBuildVectorOfConstantSDNodes(Cond.getNode()))
+ return SDValue();
+
+ // Don't optimize before the condition has been transformed to a legal type
+ // and don't ever optimize vector selects that map to AVX512 mask-registers.
+ unsigned BitWidth = Cond.getScalarValueSizeInBits();
+ if (BitWidth < 8 || BitWidth > 64)
+ return SDValue();
+
+ // We can only handle the cases where VSELECT is directly legal on the
+ // subtarget. We custom lower VSELECT nodes with constant conditions and
+ // this makes it hard to see whether a dynamic VSELECT will correctly
+ // lower, so we both check the operation's status and explicitly handle the
+ // cases where a *dynamic* blend will fail even though a constant-condition
+ // blend could be custom lowered.
+ // FIXME: We should find a better way to handle this class of problems.
+ // Potentially, we should combine constant-condition vselect nodes
+ // pre-legalization into shuffles and not mark as many types as custom
+ // lowered.
+ const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+ EVT VT = N->getValueType(0);
+ if (!TLI.isOperationLegalOrCustom(ISD::VSELECT, VT))
+ return SDValue();
+ // FIXME: We don't support i16-element blends currently. We could and
+ // should support them by making *all* the bits in the condition be set
+ // rather than just the high bit and using an i8-element blend.
+ if (VT.getVectorElementType() == MVT::i16)
+ return SDValue();
+ // Dynamic blending was only available from SSE4.1 onward.
+ if (VT.is128BitVector() && !Subtarget.hasSSE41())
+ return SDValue();
+ // Byte blends are only available in AVX2
+ if (VT == MVT::v32i8 && !Subtarget.hasAVX2())
+ return SDValue();
+ // There are no 512-bit blend instructions that use sign bits.
+ if (VT.is512BitVector())
+ return SDValue();
- MVT VT = OrigOp.getSimpleValueType();
- MVT EltVT = VT.getVectorElementType();
- SDLoc DL(Op.getNode());
+ // TODO: Add other opcodes eventually lowered into BLEND.
+ for (SDNode::use_iterator UI = Cond->use_begin(), UE = Cond->use_end();
+ UI != UE; ++UI)
+ if (UI->getOpcode() != ISD::VSELECT || UI.getOperandNo() != 0)
+ return SDValue();
- auto BitcastAndCombineShuffle = [&](unsigned Opcode, SDValue Op0, SDValue Op1,
- SDValue Op2) {
- Op0 = DAG.getBitcast(VT, Op0);
- DCI.AddToWorklist(Op0.getNode());
- Op1 = DAG.getBitcast(VT, Op1);
- DCI.AddToWorklist(Op1.getNode());
- DCI.CombineTo(OrigOp.getNode(),
- DAG.getNode(Opcode, DL, VT, Op0, Op1, Op2));
- return true;
- };
+ APInt DemandedMask(APInt::getSignMask(BitWidth));
+ KnownBits Known;
+ TargetLowering::TargetLoweringOpt TLO(DAG, !DCI.isBeforeLegalize(),
+ !DCI.isBeforeLegalizeOps());
+ if (!TLI.SimplifyDemandedBits(Cond, DemandedMask, Known, TLO, 0, true))
+ return SDValue();
- unsigned Opcode = Op.getOpcode();
- switch (Opcode) {
- case X86ISD::SHUF128: {
- if (EltVT.getSizeInBits() != 32 && EltVT.getSizeInBits() != 64)
- return false;
- // Only change element size, not type.
- if (VT.isInteger() != Op.getSimpleValueType().isInteger())
- return false;
- return BitcastAndCombineShuffle(Opcode, Op.getOperand(0), Op.getOperand(1),
- Op.getOperand(2));
- }
- case X86ISD::SUBV_BROADCAST: {
- unsigned EltSize = EltVT.getSizeInBits();
- if (EltSize != 32 && EltSize != 64)
- return false;
- // Only change element size, not type.
- if (VT.isInteger() != Op.getSimpleValueType().isInteger())
- return false;
- SDValue Op0 = Op.getOperand(0);
- MVT Op0VT = MVT::getVectorVT(EltVT,
- Op0.getSimpleValueType().getSizeInBits() / EltSize);
- Op0 = DAG.getBitcast(Op0VT, Op.getOperand(0));
- DCI.AddToWorklist(Op0.getNode());
- DCI.CombineTo(OrigOp.getNode(),
- DAG.getNode(Opcode, DL, VT, Op0));
- return true;
+ // If we changed the computation somewhere in the DAG, this change will
+ // affect all users of Cond. Update all the nodes so that we do not use
+ // the generic VSELECT anymore. Otherwise, we may perform wrong
+ // optimizations as we messed with the actual expectation for the vector
+ // boolean values.
+ for (SDNode *U : Cond->uses()) {
+ SDValue SB = DAG.getNode(X86ISD::SHRUNKBLEND, SDLoc(U), U->getValueType(0),
+ Cond, U->getOperand(1), U->getOperand(2));
+ DAG.ReplaceAllUsesOfValueWith(SDValue(U, 0), SB);
}
- }
-
- return false;
+ DCI.CommitTargetLoweringOpt(TLO);
+ return SDValue(N, 0);
}
/// Do target-specific dag combines on SELECT and VSELECT nodes.
@@ -31268,6 +32596,23 @@ static SDValue combineSelect(SDNode *N, SelectionDAG &DAG,
EVT CondVT = Cond.getValueType();
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+ // Convert vselects with constant condition into shuffles.
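+ // For example, (vselect <-1,0,undef,-1>, LHS, RHS) with v4i32 operands
+ // becomes (shuffle LHS, RHS, <0,5,6,3>): true lanes take element i from LHS,
+ // false/undef lanes take element i + 4 from RHS.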
+ if (ISD::isBuildVectorOfConstantSDNodes(Cond.getNode()) &&
+ DCI.isBeforeLegalizeOps()) {
+ SmallVector<int, 64> Mask(VT.getVectorNumElements(), -1);
+ for (int i = 0, Size = Mask.size(); i != Size; ++i) {
+ SDValue CondElt = Cond->getOperand(i);
+ Mask[i] = i;
+ // Arbitrarily choose from the 2nd operand if the select condition element
+ // is undef.
+ // TODO: Can we do better by matching patterns such as even/odd?
+ if (CondElt.isUndef() || isNullConstant(CondElt))
+ Mask[i] += Size;
+ }
+
+ return DAG.getVectorShuffle(VT, DL, LHS, RHS, Mask);
+ }
+
// If we have SSE[12] support, try to form min/max nodes. SSE min/max
// instructions match the semantics of the common C idiom x<y?x:y but not
// x<=y?x:y, because of how they handle negative zero (which can be
@@ -31292,7 +32637,8 @@ static SDValue combineSelect(SDNode *N, SelectionDAG &DAG,
// and negative zero incorrectly.
if (!DAG.isKnownNeverNaN(LHS) || !DAG.isKnownNeverNaN(RHS)) {
if (!DAG.getTarget().Options.UnsafeFPMath &&
- !(DAG.isKnownNeverZero(LHS) || DAG.isKnownNeverZero(RHS)))
+ !(DAG.isKnownNeverZeroFloat(LHS) ||
+ DAG.isKnownNeverZeroFloat(RHS)))
break;
std::swap(LHS, RHS);
}
@@ -31302,7 +32648,7 @@ static SDValue combineSelect(SDNode *N, SelectionDAG &DAG,
// Converting this to a min would handle comparisons between positive
// and negative zero incorrectly.
if (!DAG.getTarget().Options.UnsafeFPMath &&
- !DAG.isKnownNeverZero(LHS) && !DAG.isKnownNeverZero(RHS))
+ !DAG.isKnownNeverZeroFloat(LHS) && !DAG.isKnownNeverZeroFloat(RHS))
break;
Opcode = X86ISD::FMIN;
break;
@@ -31321,7 +32667,7 @@ static SDValue combineSelect(SDNode *N, SelectionDAG &DAG,
// Converting this to a max would handle comparisons between positive
// and negative zero incorrectly.
if (!DAG.getTarget().Options.UnsafeFPMath &&
- !DAG.isKnownNeverZero(LHS) && !DAG.isKnownNeverZero(RHS))
+ !DAG.isKnownNeverZeroFloat(LHS) && !DAG.isKnownNeverZeroFloat(RHS))
break;
Opcode = X86ISD::FMAX;
break;
@@ -31331,7 +32677,8 @@ static SDValue combineSelect(SDNode *N, SelectionDAG &DAG,
// and negative zero incorrectly.
if (!DAG.isKnownNeverNaN(LHS) || !DAG.isKnownNeverNaN(RHS)) {
if (!DAG.getTarget().Options.UnsafeFPMath &&
- !(DAG.isKnownNeverZero(LHS) || DAG.isKnownNeverZero(RHS)))
+ !(DAG.isKnownNeverZeroFloat(LHS) ||
+ DAG.isKnownNeverZeroFloat(RHS)))
break;
std::swap(LHS, RHS);
}
@@ -31358,7 +32705,8 @@ static SDValue combineSelect(SDNode *N, SelectionDAG &DAG,
// and negative zero incorrectly, and swapping the operands would
// cause it to handle NaNs incorrectly.
if (!DAG.getTarget().Options.UnsafeFPMath &&
- !(DAG.isKnownNeverZero(LHS) || DAG.isKnownNeverZero(RHS))) {
+ !(DAG.isKnownNeverZeroFloat(LHS) ||
+ DAG.isKnownNeverZeroFloat(RHS))) {
if (!DAG.isKnownNeverNaN(LHS) || !DAG.isKnownNeverNaN(RHS))
break;
std::swap(LHS, RHS);
@@ -31394,7 +32742,8 @@ static SDValue combineSelect(SDNode *N, SelectionDAG &DAG,
// and negative zero incorrectly, and swapping the operands would
// cause it to handle NaNs incorrectly.
if (!DAG.getTarget().Options.UnsafeFPMath &&
- !DAG.isKnownNeverZero(LHS) && !DAG.isKnownNeverZero(RHS)) {
+ !DAG.isKnownNeverZeroFloat(LHS) &&
+ !DAG.isKnownNeverZeroFloat(RHS)) {
if (!DAG.isKnownNeverNaN(LHS) || !DAG.isKnownNeverNaN(RHS))
break;
std::swap(LHS, RHS);
@@ -31418,19 +32767,38 @@ static SDValue combineSelect(SDNode *N, SelectionDAG &DAG,
return DAG.getNode(Opcode, DL, N->getValueType(0), LHS, RHS);
}
+ // Some mask scalar intrinsics rely on checking if only one bit is set
+ // and implement it in C code like this:
+ // A[0] = (U & 1) ? A[0] : W[0];
+ // This creates some redundant instructions that break pattern matching.
+ // fold (select (setcc (and (X, 1), 0, seteq), Y, Z)) -> select(and(X, 1),Z,Y)
+ if (Subtarget.hasAVX512() && N->getOpcode() == ISD::SELECT &&
+ Cond.getOpcode() == ISD::SETCC && (VT == MVT::f32 || VT == MVT::f64)) {
+ ISD::CondCode CC = cast<CondCodeSDNode>(Cond.getOperand(2))->get();
+ SDValue AndNode = Cond.getOperand(0);
+ if (AndNode.getOpcode() == ISD::AND && CC == ISD::SETEQ &&
+ isNullConstant(Cond.getOperand(1)) &&
+ isOneConstant(AndNode.getOperand(1))) {
+ // LHS and RHS are swapped because the setcc outputs 1 when the AND
+ // resulted in 0 and vice versa.
+ AndNode = DAG.getZExtOrTrunc(AndNode, DL, MVT::i8);
+ return DAG.getNode(ISD::SELECT, DL, VT, AndNode, RHS, LHS);
+ }
+ }
+
// v16i8 (select v16i1, v16i8, v16i8) does not have a proper
// lowering on KNL. In this case we convert it to
// v16i8 (select v16i8, v16i8, v16i8) and use AVX instruction.
- // The same situation for all 128 and 256-bit vectors of i8 and i16.
+ // The same situation holds for all vectors of i8 and i16 without BWI.
+ // Make sure we extend these even before type legalization gets a chance to
+ // split wide vectors.
// Since SKX these selects have a proper lowering.
- if (Subtarget.hasAVX512() && CondVT.isVector() &&
+ if (Subtarget.hasAVX512() && !Subtarget.hasBWI() && CondVT.isVector() &&
CondVT.getVectorElementType() == MVT::i1 &&
- (VT.is128BitVector() || VT.is256BitVector()) &&
+ VT.getVectorNumElements() > 4 &&
(VT.getVectorElementType() == MVT::i8 ||
- VT.getVectorElementType() == MVT::i16) &&
- !(Subtarget.hasBWI() && Subtarget.hasVLX())) {
+ VT.getVectorElementType() == MVT::i16)) {
Cond = DAG.getNode(ISD::SIGN_EXTEND, DL, VT, Cond);
- DCI.AddToWorklist(Cond.getNode());
return DAG.getNode(N->getOpcode(), DL, VT, Cond, LHS, RHS);
}
@@ -31476,7 +32844,7 @@ static SDValue combineSelect(SDNode *N, SelectionDAG &DAG,
if (N->getOpcode() == ISD::VSELECT && Cond.getOpcode() == ISD::SETCC &&
// psubus is available in SSE2 and AVX2 for i8 and i16 vectors.
((Subtarget.hasSSE2() && (VT == MVT::v16i8 || VT == MVT::v8i16)) ||
- (Subtarget.hasAVX2() && (VT == MVT::v32i8 || VT == MVT::v16i16)))) {
+ (Subtarget.hasAVX() && (VT == MVT::v32i8 || VT == MVT::v16i16)))) {
ISD::CondCode CC = cast<CondCodeSDNode>(Cond.getOperand(2))->get();
// Check if one of the arms of the VSELECT is a zero vector. If it's on the
@@ -31494,40 +32862,50 @@ static SDValue combineSelect(SDNode *N, SelectionDAG &DAG,
SDValue OpLHS = Other->getOperand(0), OpRHS = Other->getOperand(1);
SDValue CondRHS = Cond->getOperand(1);
+ auto SUBUSBuilder = [](SelectionDAG &DAG, const SDLoc &DL,
+ ArrayRef<SDValue> Ops) {
+ return DAG.getNode(X86ISD::SUBUS, DL, Ops[0].getValueType(), Ops);
+ };
+
// Look for a general sub with unsigned saturation first.
// x >= y ? x-y : 0 --> subus x, y
// x > y ? x-y : 0 --> subus x, y
if ((CC == ISD::SETUGE || CC == ISD::SETUGT) &&
Other->getOpcode() == ISD::SUB && DAG.isEqualTo(OpRHS, CondRHS))
- return DAG.getNode(X86ISD::SUBUS, DL, VT, OpLHS, OpRHS);
+ return SplitOpsAndApply(DAG, Subtarget, DL, VT, { OpLHS, OpRHS },
+ SUBUSBuilder);
if (auto *OpRHSBV = dyn_cast<BuildVectorSDNode>(OpRHS))
- if (auto *OpRHSConst = OpRHSBV->getConstantSplatNode()) {
- if (auto *CondRHSBV = dyn_cast<BuildVectorSDNode>(CondRHS))
- if (auto *CondRHSConst = CondRHSBV->getConstantSplatNode())
- // If the RHS is a constant we have to reverse the const
- // canonicalization.
- // x > C-1 ? x+-C : 0 --> subus x, C
- if (CC == ISD::SETUGT && Other->getOpcode() == ISD::ADD &&
- CondRHSConst->getAPIntValue() ==
- (-OpRHSConst->getAPIntValue() - 1))
- return DAG.getNode(
- X86ISD::SUBUS, DL, VT, OpLHS,
- DAG.getConstant(-OpRHSConst->getAPIntValue(), DL, VT));
+ if (isa<BuildVectorSDNode>(CondRHS)) {
+ // If the RHS is a constant we have to reverse the const
+ // canonicalization.
+ // x > C-1 ? x+-C : 0 --> subus x, C
+ auto MatchSUBUS = [](ConstantSDNode *Op, ConstantSDNode *Cond) {
+ return Cond->getAPIntValue() == (-Op->getAPIntValue() - 1);
+ };
+ if (CC == ISD::SETUGT && Other->getOpcode() == ISD::ADD &&
+ ISD::matchBinaryPredicate(OpRHS, CondRHS, MatchSUBUS)) {
+ OpRHS = DAG.getNode(ISD::SUB, DL, VT,
+ DAG.getConstant(0, DL, VT), OpRHS);
+ return SplitOpsAndApply(DAG, Subtarget, DL, VT, { OpLHS, OpRHS },
+ SUBUSBuilder);
+ }
// Another special case: If C was a sign bit, the sub has been
// canonicalized into a xor.
// FIXME: Would it be better to use computeKnownBits to determine
// whether it's safe to decanonicalize the xor?
// x s< 0 ? x^C : 0 --> subus x, C
- if (CC == ISD::SETLT && Other->getOpcode() == ISD::XOR &&
- ISD::isBuildVectorAllZeros(CondRHS.getNode()) &&
- OpRHSConst->getAPIntValue().isSignMask())
- // Note that we have to rebuild the RHS constant here to ensure we
- // don't rely on particular values of undef lanes.
- return DAG.getNode(
- X86ISD::SUBUS, DL, VT, OpLHS,
- DAG.getConstant(OpRHSConst->getAPIntValue(), DL, VT));
+ if (auto *OpRHSConst = OpRHSBV->getConstantSplatNode())
+ if (CC == ISD::SETLT && Other.getOpcode() == ISD::XOR &&
+ ISD::isBuildVectorAllZeros(CondRHS.getNode()) &&
+ OpRHSConst->getAPIntValue().isSignMask()) {
+ OpRHS = DAG.getConstant(OpRHSConst->getAPIntValue(), DL, VT);
+ // Note that we have to rebuild the RHS constant here to ensure we
+ // don't rely on particular values of undef lanes.
+ return SplitOpsAndApply(DAG, Subtarget, DL, VT, { OpLHS, OpRHS },
+ SUBUSBuilder);
+ }
}
}
}
@@ -31535,99 +32913,8 @@ static SDValue combineSelect(SDNode *N, SelectionDAG &DAG,
if (SDValue V = combineVSelectWithAllOnesOrZeros(N, DAG, DCI, Subtarget))
return V;
- // If this is a *dynamic* select (non-constant condition) and we can match
- // this node with one of the variable blend instructions, restructure the
- // condition so that blends can use the high (sign) bit of each element and
- // use SimplifyDemandedBits to simplify the condition operand.
- if (N->getOpcode() == ISD::VSELECT && DCI.isBeforeLegalizeOps() &&
- !DCI.isBeforeLegalize() &&
- !ISD::isBuildVectorOfConstantSDNodes(Cond.getNode())) {
- unsigned BitWidth = Cond.getScalarValueSizeInBits();
-
- // Don't optimize vector selects that map to mask-registers.
- if (BitWidth == 1)
- return SDValue();
-
- // We can only handle the cases where VSELECT is directly legal on the
- // subtarget. We custom lower VSELECT nodes with constant conditions and
- // this makes it hard to see whether a dynamic VSELECT will correctly
- // lower, so we both check the operation's status and explicitly handle the
- // cases where a *dynamic* blend will fail even though a constant-condition
- // blend could be custom lowered.
- // FIXME: We should find a better way to handle this class of problems.
- // Potentially, we should combine constant-condition vselect nodes
- // pre-legalization into shuffles and not mark as many types as custom
- // lowered.
- if (!TLI.isOperationLegalOrCustom(ISD::VSELECT, VT))
- return SDValue();
- // FIXME: We don't support i16-element blends currently. We could and
- // should support them by making *all* the bits in the condition be set
- // rather than just the high bit and using an i8-element blend.
- if (VT.getVectorElementType() == MVT::i16)
- return SDValue();
- // Dynamic blending was only available from SSE4.1 onward.
- if (VT.is128BitVector() && !Subtarget.hasSSE41())
- return SDValue();
- // Byte blends are only available in AVX2
- if (VT == MVT::v32i8 && !Subtarget.hasAVX2())
- return SDValue();
- // There are no 512-bit blend instructions that use sign bits.
- if (VT.is512BitVector())
- return SDValue();
-
- assert(BitWidth >= 8 && BitWidth <= 64 && "Invalid mask size");
- APInt DemandedMask(APInt::getSignMask(BitWidth));
- KnownBits Known;
- TargetLowering::TargetLoweringOpt TLO(DAG, !DCI.isBeforeLegalize(),
- !DCI.isBeforeLegalizeOps());
- if (TLI.ShrinkDemandedConstant(Cond, DemandedMask, TLO) ||
- TLI.SimplifyDemandedBits(Cond, DemandedMask, Known, TLO)) {
- // If we changed the computation somewhere in the DAG, this change will
- // affect all users of Cond. Make sure it is fine and update all the nodes
- // so that we do not use the generic VSELECT anymore. Otherwise, we may
- // perform wrong optimizations as we messed with the actual expectation
- // for the vector boolean values.
- if (Cond != TLO.Old) {
- // Check all uses of the condition operand to check whether it will be
- // consumed by non-BLEND instructions. Those may require that all bits
- // are set properly.
- for (SDNode *U : Cond->uses()) {
- // TODO: Add other opcodes eventually lowered into BLEND.
- if (U->getOpcode() != ISD::VSELECT)
- return SDValue();
- }
-
- // Update all users of the condition before committing the change, so
- // that the VSELECT optimizations that expect the correct vector boolean
- // value will not be triggered.
- for (SDNode *U : Cond->uses()) {
- SDValue SB = DAG.getNode(X86ISD::SHRUNKBLEND, SDLoc(U),
- U->getValueType(0), Cond, U->getOperand(1),
- U->getOperand(2));
- DAG.ReplaceAllUsesOfValueWith(SDValue(U, 0), SB);
- }
- DCI.CommitTargetLoweringOpt(TLO);
- return SDValue();
- }
- // Only Cond (rather than other nodes in the computation chain) was
- // changed. Change the condition just for N to keep the opportunity to
- // optimize all other users their own way.
- SDValue SB = DAG.getNode(X86ISD::SHRUNKBLEND, DL, VT, TLO.New, LHS, RHS);
- DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), SB);
- return SDValue();
- }
- }
-
- // Look for vselects with LHS/RHS being bitcasted from an operation that
- // can be executed on another type. Push the bitcast to the inputs of
- // the operation. This exposes opportunities for using masking instructions.
- if (N->getOpcode() == ISD::VSELECT && DCI.isAfterLegalizeVectorOps() &&
- CondVT.getVectorElementType() == MVT::i1) {
- if (combineBitcastForMaskedOp(LHS, DAG, DCI))
- return SDValue(N, 0);
- if (combineBitcastForMaskedOp(RHS, DAG, DCI))
- return SDValue(N, 0);
- }
+ if (SDValue V = combineVSelectToShrunkBlend(N, DAG, DCI, Subtarget))
+ return V;
// Custom action for SELECT MMX
if (VT == MVT::x86mmx) {
@@ -31969,17 +33256,6 @@ static SDValue combineCMov(SDNode *N, SelectionDAG &DAG,
X86::CondCode CC = (X86::CondCode)N->getConstantOperandVal(2);
SDValue Cond = N->getOperand(3);
- if (CC == X86::COND_E || CC == X86::COND_NE) {
- switch (Cond.getOpcode()) {
- default: break;
- case X86ISD::BSR:
- case X86ISD::BSF:
- // If operand of BSR / BSF are proven never zero, then ZF cannot be set.
- if (DAG.isKnownNeverZero(Cond.getOperand(0)))
- return (CC == X86::COND_E) ? FalseOp : TrueOp;
- }
- }
-
// Try to simplify the EFLAGS and condition code operands.
// We can't always do this as FCMOV only supports a subset of X86 cond.
if (SDValue Flags = combineSetCCEFLAGS(Cond, CC, DAG, Subtarget)) {
@@ -32149,6 +33425,36 @@ static SDValue combineCMov(SDNode *N, SelectionDAG &DAG,
}
}
+ // Handle (CMOV C-1, (ADD (CTTZ X), C), (X != 0)) ->
+ // (ADD (CMOV (CTTZ X), -1, (X != 0)), C) or
+ // (CMOV (ADD (CTTZ X), C), C-1, (X == 0)) ->
+ // (ADD (CMOV C-1, (CTTZ X), (X == 0)), C)
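+ // For example, with C = 7 the rewritten CMOV selects between (CTTZ X) and
+ // 6 - 7 = -1, and the constant 7 is re-added by the trailing ADD.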
+ if (CC == X86::COND_NE || CC == X86::COND_E) {
+ auto *Cnst = CC == X86::COND_E ? dyn_cast<ConstantSDNode>(TrueOp)
+ : dyn_cast<ConstantSDNode>(FalseOp);
+ SDValue Add = CC == X86::COND_E ? FalseOp : TrueOp;
+
+ if (Cnst && Add.getOpcode() == ISD::ADD && Add.hasOneUse()) {
+ auto *AddOp1 = dyn_cast<ConstantSDNode>(Add.getOperand(1));
+ SDValue AddOp2 = Add.getOperand(0);
+ if (AddOp1 && (AddOp2.getOpcode() == ISD::CTTZ_ZERO_UNDEF ||
+ AddOp2.getOpcode() == ISD::CTTZ)) {
+ APInt Diff = Cnst->getAPIntValue() - AddOp1->getAPIntValue();
+ if (CC == X86::COND_E) {
+ Add = DAG.getNode(X86ISD::CMOV, DL, Add.getValueType(), AddOp2,
+ DAG.getConstant(Diff, DL, Add.getValueType()),
+ DAG.getConstant(CC, DL, MVT::i8), Cond);
+ } else {
+ Add = DAG.getNode(X86ISD::CMOV, DL, Add.getValueType(),
+ DAG.getConstant(Diff, DL, Add.getValueType()),
+ AddOp2, DAG.getConstant(CC, DL, MVT::i8), Cond);
+ }
+ return DAG.getNode(X86ISD::ADD, DL, Add.getValueType(), Add,
+ SDValue(AddOp1, 0));
+ }
+ }
+ }
+
return SDValue();
}
@@ -32276,13 +33582,6 @@ static SDValue reduceVMULWidth(SDNode *N, SelectionDAG &DAG,
if ((NumElts % 2) != 0)
return SDValue();
- // If the upper 17 bits of each element are zero then we can use PMADD.
- APInt Mask17 = APInt::getHighBitsSet(32, 17);
- if (VT == MVT::v4i32 && DAG.MaskedValueIsZero(N0, Mask17) &&
- DAG.MaskedValueIsZero(N1, Mask17))
- return DAG.getNode(X86ISD::VPMADDWD, DL, VT, DAG.getBitcast(MVT::v8i16, N0),
- DAG.getBitcast(MVT::v8i16, N1));
-
unsigned RegSize = 128;
MVT OpsVT = MVT::getVectorVT(MVT::i16, RegSize / 16);
EVT ReducedVT = EVT::getVectorVT(*DAG.getContext(), MVT::i16, NumElts);
@@ -32378,7 +33677,7 @@ static SDValue reduceVMULWidth(SDNode *N, SelectionDAG &DAG,
}
static SDValue combineMulSpecial(uint64_t MulAmt, SDNode *N, SelectionDAG &DAG,
- EVT VT, SDLoc DL) {
+ EVT VT, const SDLoc &DL) {
auto combineMulShlAddOrSub = [&](int Mult, int Shift, bool isAdd) {
SDValue Result = DAG.getNode(X86ISD::MUL_IMM, DL, VT, N->getOperand(0),
@@ -32390,10 +33689,11 @@ static SDValue combineMulSpecial(uint64_t MulAmt, SDNode *N, SelectionDAG &DAG,
return Result;
};
- auto combineMulMulAddOrSub = [&](bool isAdd) {
+ auto combineMulMulAddOrSub = [&](int Mul1, int Mul2, bool isAdd) {
SDValue Result = DAG.getNode(X86ISD::MUL_IMM, DL, VT, N->getOperand(0),
- DAG.getConstant(9, DL, VT));
- Result = DAG.getNode(ISD::MUL, DL, VT, Result, DAG.getConstant(3, DL, VT));
+ DAG.getConstant(Mul1, DL, VT));
+ Result = DAG.getNode(X86ISD::MUL_IMM, DL, VT, Result,
+ DAG.getConstant(Mul2, DL, VT));
Result = DAG.getNode(isAdd ? ISD::ADD : ISD::SUB, DL, VT, Result,
N->getOperand(0));
return Result;
@@ -32408,43 +33708,137 @@ static SDValue combineMulSpecial(uint64_t MulAmt, SDNode *N, SelectionDAG &DAG,
case 21:
// mul x, 21 => add ((shl (mul x, 5), 2), x)
return combineMulShlAddOrSub(5, 2, /*isAdd*/ true);
+ case 41:
+ // mul x, 41 => add ((shl (mul x, 5), 3), x)
+ return combineMulShlAddOrSub(5, 3, /*isAdd*/ true);
case 22:
// mul x, 22 => add (add ((shl (mul x, 5), 2), x), x)
return DAG.getNode(ISD::ADD, DL, VT, N->getOperand(0),
combineMulShlAddOrSub(5, 2, /*isAdd*/ true));
case 19:
- // mul x, 19 => sub ((shl (mul x, 5), 2), x)
- return combineMulShlAddOrSub(5, 2, /*isAdd*/ false);
+ // mul x, 19 => add ((shl (mul x, 9), 1), x)
+ return combineMulShlAddOrSub(9, 1, /*isAdd*/ true);
+ case 37:
+ // mul x, 37 => add ((shl (mul x, 9), 2), x)
+ return combineMulShlAddOrSub(9, 2, /*isAdd*/ true);
+ case 73:
+ // mul x, 73 => add ((shl (mul x, 9), 3), x)
+ return combineMulShlAddOrSub(9, 3, /*isAdd*/ true);
case 13:
// mul x, 13 => add ((shl (mul x, 3), 2), x)
return combineMulShlAddOrSub(3, 2, /*isAdd*/ true);
case 23:
- // mul x, 13 => sub ((shl (mul x, 3), 3), x)
+ // mul x, 23 => sub ((shl (mul x, 3), 3), x)
return combineMulShlAddOrSub(3, 3, /*isAdd*/ false);
- case 14:
- // mul x, 14 => add (add ((shl (mul x, 3), 2), x), x)
- return DAG.getNode(ISD::ADD, DL, VT, N->getOperand(0),
- combineMulShlAddOrSub(3, 2, /*isAdd*/ true));
case 26:
- // mul x, 26 => sub ((mul (mul x, 9), 3), x)
- return combineMulMulAddOrSub(/*isAdd*/ false);
+ // mul x, 26 => add ((mul (mul x, 5), 5), x)
+ return combineMulMulAddOrSub(5, 5, /*isAdd*/ true);
case 28:
// mul x, 28 => add ((mul (mul x, 9), 3), x)
- return combineMulMulAddOrSub(/*isAdd*/ true);
+ return combineMulMulAddOrSub(9, 3, /*isAdd*/ true);
case 29:
// mul x, 29 => add (add ((mul (mul x, 9), 3), x), x)
return DAG.getNode(ISD::ADD, DL, VT, N->getOperand(0),
- combineMulMulAddOrSub(/*isAdd*/ true));
- case 30:
- // mul x, 30 => sub (sub ((shl x, 5), x), x)
- return DAG.getNode(
- ISD::SUB, DL, VT,
- DAG.getNode(ISD::SUB, DL, VT,
- DAG.getNode(ISD::SHL, DL, VT, N->getOperand(0),
- DAG.getConstant(5, DL, MVT::i8)),
- N->getOperand(0)),
- N->getOperand(0));
+ combineMulMulAddOrSub(9, 3, /*isAdd*/ true));
+ }
+
+ // Another trick. If this is a power of 2 + 2/4/8, we can use a shift
+ // followed by a single LEA.
+ // First check if this is a sum of two powers of 2 because that's easy. Then
+ // count the trailing zeros up to the first set bit.
+ // TODO: We can do this even without LEA at a cost of two shifts and an add.
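+ // For example, mul x, 40 (= 32 + 8) becomes (add (shl x, 5), (shl x, 3)),
+ // where the smaller shift amount (1-3) can later fold into an LEA scale.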
+ if (isPowerOf2_64(MulAmt & (MulAmt - 1))) {
+ unsigned ScaleShift = countTrailingZeros(MulAmt);
+ if (ScaleShift >= 1 && ScaleShift < 4) {
+ unsigned ShiftAmt = Log2_64((MulAmt & (MulAmt - 1)));
+ SDValue Shift1 = DAG.getNode(ISD::SHL, DL, VT, N->getOperand(0),
+ DAG.getConstant(ShiftAmt, DL, MVT::i8));
+ SDValue Shift2 = DAG.getNode(ISD::SHL, DL, VT, N->getOperand(0),
+ DAG.getConstant(ScaleShift, DL, MVT::i8));
+ return DAG.getNode(ISD::ADD, DL, VT, Shift1, Shift2);
+ }
+ }
+
+ return SDValue();
+}
+
+// If the upper 17 bits of each element are zero then we can use PMADDWD,
+// which is always at least as quick as PMULLD, except on KNL.
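+// With those bits zero, the odd i16 lanes of the bitcast inputs are zero and
+// the even lanes are non-negative, so each PMADDWD lane computes exactly the
+// desired i32 product (lo(a)*lo(b) + 0*0) without overflow.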
+static SDValue combineMulToPMADDWD(SDNode *N, SelectionDAG &DAG,
+ const X86Subtarget &Subtarget) {
+ if (!Subtarget.hasSSE2())
+ return SDValue();
+
+ if (Subtarget.getProcFamily() == X86Subtarget::IntelKNL)
+ return SDValue();
+
+ EVT VT = N->getValueType(0);
+
+ // Only support vXi32 vectors.
+ if (!VT.isVector() || VT.getVectorElementType() != MVT::i32)
+ return SDValue();
+
+ // Make sure the vXi16 type is legal. This covers the AVX512 without BWI case.
+ MVT WVT = MVT::getVectorVT(MVT::i16, 2 * VT.getVectorNumElements());
+ if (!DAG.getTargetLoweringInfo().isTypeLegal(WVT))
+ return SDValue();
+
+ SDValue N0 = N->getOperand(0);
+ SDValue N1 = N->getOperand(1);
+ APInt Mask17 = APInt::getHighBitsSet(32, 17);
+ if (!DAG.MaskedValueIsZero(N1, Mask17) ||
+ !DAG.MaskedValueIsZero(N0, Mask17))
+ return SDValue();
+
+ // Use SplitOpsAndApply to handle AVX splitting.
+ auto PMADDWDBuilder = [](SelectionDAG &DAG, const SDLoc &DL,
+ ArrayRef<SDValue> Ops) {
+ MVT VT = MVT::getVectorVT(MVT::i32, Ops[0].getValueSizeInBits() / 32);
+ return DAG.getNode(X86ISD::VPMADDWD, DL, VT, Ops);
+ };
+ return SplitOpsAndApply(DAG, Subtarget, SDLoc(N), VT,
+ { DAG.getBitcast(WVT, N0), DAG.getBitcast(WVT, N1) },
+ PMADDWDBuilder);
+}
+
+static SDValue combineMulToPMULDQ(SDNode *N, SelectionDAG &DAG,
+ const X86Subtarget &Subtarget) {
+ if (!Subtarget.hasSSE2())
+ return SDValue();
+
+ EVT VT = N->getValueType(0);
+
+ // Only support vXi64 vectors.
+ if (!VT.isVector() || VT.getVectorElementType() != MVT::i64 ||
+ !DAG.getTargetLoweringInfo().isTypeLegal(VT))
+ return SDValue();
+
+ SDValue N0 = N->getOperand(0);
+ SDValue N1 = N->getOperand(1);
+
+ // PMULDQ returns the 64-bit result of the signed multiplication of the
+ // lower 32 bits. We can lower with this if the sign bits stretch that far.
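+ // For example, two sign-extended i32 values have more than 32 sign bits
+ // each, so their v2i64 product maps directly onto PMULDQ; the zero-extended
+ // case below maps onto PMULUDQ the same way.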
+ if (Subtarget.hasSSE41() && DAG.ComputeNumSignBits(N0) > 32 &&
+ DAG.ComputeNumSignBits(N1) > 32) {
+ auto PMULDQBuilder = [](SelectionDAG &DAG, const SDLoc &DL,
+ ArrayRef<SDValue> Ops) {
+ return DAG.getNode(X86ISD::PMULDQ, DL, Ops[0].getValueType(), Ops);
+ };
+ return SplitOpsAndApply(DAG, Subtarget, SDLoc(N), VT, { N0, N1 },
+ PMULDQBuilder, /*CheckBWI*/false);
}
+
+ // If the upper bits are zero we can use a single pmuludq.
+ APInt Mask = APInt::getHighBitsSet(64, 32);
+ if (DAG.MaskedValueIsZero(N0, Mask) && DAG.MaskedValueIsZero(N1, Mask)) {
+ auto PMULUDQBuilder = [](SelectionDAG &DAG, const SDLoc &DL,
+ ArrayRef<SDValue> Ops) {
+ return DAG.getNode(X86ISD::PMULUDQ, DL, Ops[0].getValueType(), Ops);
+ };
+ return SplitOpsAndApply(DAG, Subtarget, SDLoc(N), VT, { N0, N1 },
+ PMULUDQBuilder, /*CheckBWI*/false);
+ }
+
return SDValue();
}
@@ -32454,6 +33848,13 @@ static SDValue combineMul(SDNode *N, SelectionDAG &DAG,
TargetLowering::DAGCombinerInfo &DCI,
const X86Subtarget &Subtarget) {
EVT VT = N->getValueType(0);
+
+ if (SDValue V = combineMulToPMADDWD(N, DAG, Subtarget))
+ return V;
+
+ if (SDValue V = combineMulToPMULDQ(N, DAG, Subtarget))
+ return V;
+
if (DCI.isBeforeLegalize() && VT.isVector())
return reduceVMULWidth(N, DAG, Subtarget);
@@ -32473,9 +33874,14 @@ static SDValue combineMul(SDNode *N, SelectionDAG &DAG,
if (!C)
return SDValue();
uint64_t MulAmt = C->getZExtValue();
- if (isPowerOf2_64(MulAmt) || MulAmt == 3 || MulAmt == 5 || MulAmt == 9)
+ if (isPowerOf2_64(MulAmt))
return SDValue();
+ SDLoc DL(N);
+ if (MulAmt == 3 || MulAmt == 5 || MulAmt == 9)
+ return DAG.getNode(X86ISD::MUL_IMM, DL, VT, N->getOperand(0),
+ N->getOperand(1));
+
uint64_t MulAmt1 = 0;
uint64_t MulAmt2 = 0;
if ((MulAmt % 9) == 0) {
@@ -32489,7 +33895,6 @@ static SDValue combineMul(SDNode *N, SelectionDAG &DAG,
MulAmt2 = MulAmt / 3;
}
- SDLoc DL(N);
SDValue NewMul;
if (MulAmt2 &&
(isPowerOf2_64(MulAmt2) || MulAmt2 == 3 || MulAmt2 == 5 || MulAmt2 == 9)){
@@ -32523,39 +33928,47 @@ static SDValue combineMul(SDNode *N, SelectionDAG &DAG,
"Both cases that could cause potential overflows should have "
"already been handled.");
int64_t SignMulAmt = C->getSExtValue();
- if ((SignMulAmt != INT64_MIN) && (SignMulAmt != INT64_MAX) &&
- (SignMulAmt != -INT64_MAX)) {
- int NumSign = SignMulAmt > 0 ? 1 : -1;
- bool IsPowerOf2_64PlusOne = isPowerOf2_64(NumSign * SignMulAmt - 1);
- bool IsPowerOf2_64MinusOne = isPowerOf2_64(NumSign * SignMulAmt + 1);
- if (IsPowerOf2_64PlusOne) {
- // (mul x, 2^N + 1) => (add (shl x, N), x)
- NewMul = DAG.getNode(
- ISD::ADD, DL, VT, N->getOperand(0),
- DAG.getNode(ISD::SHL, DL, VT, N->getOperand(0),
- DAG.getConstant(Log2_64(NumSign * SignMulAmt - 1), DL,
- MVT::i8)));
- } else if (IsPowerOf2_64MinusOne) {
- // (mul x, 2^N - 1) => (sub (shl x, N), x)
- NewMul = DAG.getNode(
- ISD::SUB, DL, VT,
- DAG.getNode(ISD::SHL, DL, VT, N->getOperand(0),
- DAG.getConstant(Log2_64(NumSign * SignMulAmt + 1), DL,
- MVT::i8)),
- N->getOperand(0));
- }
+ assert(SignMulAmt != INT64_MIN && "Int min should have been handled!");
+ uint64_t AbsMulAmt = SignMulAmt < 0 ? -SignMulAmt : SignMulAmt;
+ if (isPowerOf2_64(AbsMulAmt - 1)) {
+ // (mul x, 2^N + 1) => (add (shl x, N), x)
+ NewMul = DAG.getNode(
+ ISD::ADD, DL, VT, N->getOperand(0),
+ DAG.getNode(ISD::SHL, DL, VT, N->getOperand(0),
+ DAG.getConstant(Log2_64(AbsMulAmt - 1), DL,
+ MVT::i8)));
// To negate, subtract the number from zero
- if ((IsPowerOf2_64PlusOne || IsPowerOf2_64MinusOne) && NumSign == -1)
- NewMul =
- DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT), NewMul);
+ if (SignMulAmt < 0)
+ NewMul = DAG.getNode(ISD::SUB, DL, VT,
+ DAG.getConstant(0, DL, VT), NewMul);
+ } else if (isPowerOf2_64(AbsMulAmt + 1)) {
+ // (mul x, 2^N - 1) => (sub (shl x, N), x)
+ NewMul = DAG.getNode(ISD::SHL, DL, VT, N->getOperand(0),
+ DAG.getConstant(Log2_64(AbsMulAmt + 1),
+ DL, MVT::i8));
+ // To negate, reverse the operands of the subtract.
+ if (SignMulAmt < 0)
+ NewMul = DAG.getNode(ISD::SUB, DL, VT, N->getOperand(0), NewMul);
+ else
+ NewMul = DAG.getNode(ISD::SUB, DL, VT, NewMul, N->getOperand(0));
+ } else if (SignMulAmt >= 0 && isPowerOf2_64(AbsMulAmt - 2)) {
+ // (mul x, 2^N + 2) => (add (add (shl x, N), x), x)
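+ // For example, mul x, 34 => (add (add (shl x, 5), x), x).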
+ NewMul = DAG.getNode(ISD::SHL, DL, VT, N->getOperand(0),
+ DAG.getConstant(Log2_64(AbsMulAmt - 2),
+ DL, MVT::i8));
+ NewMul = DAG.getNode(ISD::ADD, DL, VT, NewMul, N->getOperand(0));
+ NewMul = DAG.getNode(ISD::ADD, DL, VT, NewMul, N->getOperand(0));
+ } else if (SignMulAmt >= 0 && isPowerOf2_64(AbsMulAmt + 2)) {
+ // (mul x, 2^N - 2) => (sub (sub (shl x, N), x), x)
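+ // For example, mul x, 30 => (sub (sub (shl x, 5), x), x).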
+ NewMul = DAG.getNode(ISD::SHL, DL, VT, N->getOperand(0),
+ DAG.getConstant(Log2_64(AbsMulAmt + 2),
+ DL, MVT::i8));
+ NewMul = DAG.getNode(ISD::SUB, DL, VT, NewMul, N->getOperand(0));
+ NewMul = DAG.getNode(ISD::SUB, DL, VT, NewMul, N->getOperand(0));
}
}
- if (NewMul)
- // Do not add new nodes to DAG combiner worklist.
- DCI.CombineTo(N, NewMul, false);
-
- return SDValue();
+ return NewMul;
}
static SDValue combineShiftLeft(SDNode *N, SelectionDAG &DAG) {
@@ -32670,11 +34083,17 @@ static SDValue combineShiftRightArithmetic(SDNode *N, SelectionDAG &DAG) {
return SDValue();
}
-static SDValue combineShiftRightLogical(SDNode *N, SelectionDAG &DAG) {
+static SDValue combineShiftRightLogical(SDNode *N, SelectionDAG &DAG,
+ TargetLowering::DAGCombinerInfo &DCI) {
SDValue N0 = N->getOperand(0);
SDValue N1 = N->getOperand(1);
EVT VT = N0.getValueType();
+ // Only do this on the last DAG combine as it can interfere with other
+ // combines.
+ if (!DCI.isAfterLegalizeDAG())
+ return SDValue();
+
// Try to improve a sequence of srl (and X, C1), C2 by inverting the order.
// TODO: This is a generic DAG combine that became an x86-only combine to
// avoid shortcomings in other folds such as bswap, bit-test ('bt'), and
@@ -32691,6 +34110,14 @@ static SDValue combineShiftRightLogical(SDNode *N, SelectionDAG &DAG) {
// transform should reduce code size. It may also enable secondary transforms
// from improved known-bits analysis or instruction selection.
APInt MaskVal = AndC->getAPIntValue();
+
+ // If this can be matched by a zero extend, don't optimize.
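+ // For example, in (srl (and X, 0xffff), 4) the 0xffff mask is kept so the
+ // AND can still be selected as a MOVZX.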
+ if (MaskVal.isMask()) {
+ unsigned TO = MaskVal.countTrailingOnes();
+ if (TO >= 8 && isPowerOf2_32(TO))
+ return SDValue();
+ }
+
APInt NewMaskVal = MaskVal.lshr(ShiftC->getAPIntValue());
unsigned OldMaskSize = MaskVal.getMinSignedBits();
unsigned NewMaskSize = NewMaskVal.getMinSignedBits();
@@ -32717,7 +34144,7 @@ static SDValue combineShift(SDNode* N, SelectionDAG &DAG,
return V;
if (N->getOpcode() == ISD::SRL)
- if (SDValue V = combineShiftRightLogical(N, DAG))
+ if (SDValue V = combineShiftRightLogical(N, DAG, DCI))
return V;
return SDValue();
@@ -32797,12 +34224,10 @@ static SDValue combineVectorPack(SDNode *N, SelectionDAG &DAG,
// Attempt to combine as shuffle.
SDValue Op(N, 0);
- if (SDValue Res = combineX86ShufflesRecursively(
- {Op}, 0, Op, {0}, {}, /*Depth*/ 1,
- /*HasVarMask*/ false, DAG, DCI, Subtarget)) {
- DCI.CombineTo(N, Res);
- return SDValue();
- }
+ if (SDValue Res =
+ combineX86ShufflesRecursively({Op}, 0, Op, {0}, {}, /*Depth*/ 1,
+ /*HasVarMask*/ false, DAG, Subtarget))
+ return Res;
return SDValue();
}
@@ -32861,10 +34286,8 @@ static SDValue combineVectorShiftImm(SDNode *N, SelectionDAG &DAG,
SDValue Op(N, 0);
if (SDValue Res = combineX86ShufflesRecursively(
{Op}, 0, Op, {0}, {}, /*Depth*/ 1,
- /*HasVarMask*/ false, DAG, DCI, Subtarget)) {
- DCI.CombineTo(N, Res);
- return SDValue();
- }
+ /*HasVarMask*/ false, DAG, Subtarget))
+ return Res;
}
// Constant Folding.
@@ -32900,12 +34323,10 @@ static SDValue combineVectorInsert(SDNode *N, SelectionDAG &DAG,
// Attempt to combine PINSRB/PINSRW patterns to a shuffle.
SDValue Op(N, 0);
- if (SDValue Res = combineX86ShufflesRecursively(
- {Op}, 0, Op, {0}, {}, /*Depth*/ 1,
- /*HasVarMask*/ false, DAG, DCI, Subtarget)) {
- DCI.CombineTo(N, Res);
- return SDValue();
- }
+ if (SDValue Res =
+ combineX86ShufflesRecursively({Op}, 0, Op, {0}, {}, /*Depth*/ 1,
+ /*HasVarMask*/ false, DAG, Subtarget))
+ return Res;
return SDValue();
}
@@ -32973,9 +34394,13 @@ static SDValue combineCompareEqual(SDNode *N, SelectionDAG &DAG,
SDValue FSetCC =
DAG.getNode(X86ISD::FSETCCM, DL, MVT::v1i1, CMP00, CMP01,
DAG.getConstant(x86cc, DL, MVT::i8));
- return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL,
- N->getSimpleValueType(0), FSetCC,
- DAG.getIntPtrConstant(0, DL));
+ // Need to fill with zeros to ensure the bitcast will produce zeroes
+ // for the upper bits. An EXTRACT_ELEMENT here wouldn't guarantee that.
+ SDValue Ins = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, MVT::v16i1,
+ DAG.getConstant(0, DL, MVT::v16i1),
+ FSetCC, DAG.getIntPtrConstant(0, DL));
+ return DAG.getZExtOrTrunc(DAG.getBitcast(MVT::i16, Ins), DL,
+ N->getSimpleValueType(0));
}
SDValue OnesOrZeroesF = DAG.getNode(X86ISD::FSETCC, DL,
CMP00.getValueType(), CMP00, CMP01,
@@ -33012,25 +34437,40 @@ static SDValue combineCompareEqual(SDNode *N, SelectionDAG &DAG,
return SDValue();
}
+// Try to match (and (xor X, -1), Y) logic pattern for (andnp X, Y) combines.
+static bool matchANDXORWithAllOnesAsANDNP(SDNode *N, SDValue &X, SDValue &Y) {
+ if (N->getOpcode() != ISD::AND)
+ return false;
+
+ SDValue N0 = N->getOperand(0);
+ SDValue N1 = N->getOperand(1);
+ if (N0.getOpcode() == ISD::XOR &&
+ ISD::isBuildVectorAllOnes(N0.getOperand(1).getNode())) {
+ X = N0.getOperand(0);
+ Y = N1;
+ return true;
+ }
+ if (N1.getOpcode() == ISD::XOR &&
+ ISD::isBuildVectorAllOnes(N1.getOperand(1).getNode())) {
+ X = N1.getOperand(0);
+ Y = N0;
+ return true;
+ }
+
+ return false;
+}
+
/// Try to fold: (and (xor X, -1), Y) -> (andnp X, Y).
static SDValue combineANDXORWithAllOnesIntoANDNP(SDNode *N, SelectionDAG &DAG) {
assert(N->getOpcode() == ISD::AND);
EVT VT = N->getValueType(0);
- SDValue N0 = N->getOperand(0);
- SDValue N1 = N->getOperand(1);
- SDLoc DL(N);
-
if (VT != MVT::v2i64 && VT != MVT::v4i64 && VT != MVT::v8i64)
return SDValue();
- if (N0.getOpcode() == ISD::XOR &&
- ISD::isBuildVectorAllOnes(N0.getOperand(1).getNode()))
- return DAG.getNode(X86ISD::ANDNP, DL, VT, N0.getOperand(0), N1);
-
- if (N1.getOpcode() == ISD::XOR &&
- ISD::isBuildVectorAllOnes(N1.getOperand(1).getNode()))
- return DAG.getNode(X86ISD::ANDNP, DL, VT, N1.getOperand(0), N0);
+ SDValue X, Y;
+ if (matchANDXORWithAllOnesAsANDNP(N, X, Y))
+ return DAG.getNode(X86ISD::ANDNP, SDLoc(N), VT, X, Y);
return SDValue();
}
@@ -33042,8 +34482,7 @@ static SDValue combineANDXORWithAllOnesIntoANDNP(SDNode *N, SelectionDAG &DAG) {
// Even with AVX-512 this is still useful for removing casts around logical
// operations on vXi1 mask types.
static SDValue WidenMaskArithmetic(SDNode *N, SelectionDAG &DAG,
- TargetLowering::DAGCombinerInfo &DCI,
- const X86Subtarget &Subtarget) {
+ const X86Subtarget &Subtarget) {
EVT VT = N->getValueType(0);
assert(VT.isVector() && "Expected vector type");
@@ -33214,7 +34653,7 @@ static bool hasBZHI(const X86Subtarget &Subtarget, MVT VT) {
// It's equivalent to performing bzhi (zero high bits) on the input, with the
// same index of the load.
static SDValue combineAndLoadToBZHI(SDNode *Node, SelectionDAG &DAG,
- const X86Subtarget &Subtarget) {
+ const X86Subtarget &Subtarget) {
MVT VT = Node->getSimpleValueType(0);
SDLoc dl(Node);
@@ -33269,15 +34708,16 @@ static SDValue combineAndLoadToBZHI(SDNode *Node, SelectionDAG &DAG,
// <- (and (srl 0xFFFFFFFF, (sub 32, idx)))
// that will be replaced with one bzhi instruction.
SDValue Inp = (i == 0) ? Node->getOperand(1) : Node->getOperand(0);
- SDValue SizeC = DAG.getConstant(VT.getSizeInBits(), dl, VT);
+ SDValue SizeC = DAG.getConstant(VT.getSizeInBits(), dl, MVT::i32);
// Get the Node which indexes into the array.
SDValue Index = getIndexFromUnindexedLoad(Ld);
if (!Index)
return SDValue();
- Index = DAG.getZExtOrTrunc(Index, dl, VT);
+ Index = DAG.getZExtOrTrunc(Index, dl, MVT::i32);
- SDValue Sub = DAG.getNode(ISD::SUB, dl, VT, SizeC, Index);
+ SDValue Sub = DAG.getNode(ISD::SUB, dl, MVT::i32, SizeC, Index);
+ Sub = DAG.getNode(ISD::TRUNCATE, dl, MVT::i8, Sub);
SDValue AllOnes = DAG.getAllOnesConstant(dl, VT);
SDValue LShr = DAG.getNode(ISD::SRL, dl, VT, AllOnes, Sub);
@@ -33303,6 +34743,20 @@ static SDValue combineAnd(SDNode *N, SelectionDAG &DAG,
DAG.getBitcast(MVT::v4f32, N->getOperand(1))));
}
+ // Use a 32-bit and+zext if upper bits known zero.
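+ // For example, if either input has its top 32 bits known zero, (and i64 X, Y)
+ // becomes (zext (and i32 (trunc X), (trunc Y))); the 32-bit AND already
+ // zeroes the upper half of the result and typically saves a REX.W prefix.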
+ if (VT == MVT::i64 && Subtarget.is64Bit() &&
+ !isa<ConstantSDNode>(N->getOperand(1))) {
+ APInt HiMask = APInt::getHighBitsSet(64, 32);
+ if (DAG.MaskedValueIsZero(N->getOperand(1), HiMask) ||
+ DAG.MaskedValueIsZero(N->getOperand(0), HiMask)) {
+ SDLoc dl(N);
+ SDValue LHS = DAG.getNode(ISD::TRUNCATE, dl, MVT::i32, N->getOperand(0));
+ SDValue RHS = DAG.getNode(ISD::TRUNCATE, dl, MVT::i32, N->getOperand(1));
+ return DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i64,
+ DAG.getNode(ISD::AND, dl, MVT::i32, LHS, RHS));
+ }
+ }
+
if (DCI.isBeforeLegalizeOps())
return SDValue();
@@ -33326,10 +34780,8 @@ static SDValue combineAnd(SDNode *N, SelectionDAG &DAG,
SDValue Op(N, 0);
if (SDValue Res = combineX86ShufflesRecursively(
{Op}, 0, Op, {0}, {}, /*Depth*/ 1,
- /*HasVarMask*/ false, DAG, DCI, Subtarget)) {
- DCI.CombineTo(N, Res);
- return SDValue();
- }
+ /*HasVarMask*/ false, DAG, Subtarget))
+ return Res;
}
// Attempt to combine a scalar bitmask AND with an extracted shuffle.
@@ -33365,7 +34817,7 @@ static SDValue combineAnd(SDNode *N, SelectionDAG &DAG,
if (SDValue Shuffle = combineX86ShufflesRecursively(
{SrcVec}, 0, SrcVec, ShuffleMask, {}, /*Depth*/ 2,
- /*HasVarMask*/ false, DAG, DCI, Subtarget))
+ /*HasVarMask*/ false, DAG, Subtarget))
return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SDLoc(N), VT, Shuffle,
N->getOperand(0).getOperand(1));
}
@@ -33374,6 +34826,38 @@ static SDValue combineAnd(SDNode *N, SelectionDAG &DAG,
return SDValue();
}
+// Try to match OR(AND(~MASK,X),AND(MASK,Y)) logic pattern.
+static bool matchLogicBlend(SDNode *N, SDValue &X, SDValue &Y, SDValue &Mask) {
+ if (N->getOpcode() != ISD::OR)
+ return false;
+
+ SDValue N0 = N->getOperand(0);
+ SDValue N1 = N->getOperand(1);
+
+ // Canonicalize AND to LHS.
+ if (N1.getOpcode() == ISD::AND)
+ std::swap(N0, N1);
+
+ // Attempt to match OR(AND(M,Y),ANDNP(M,X)).
+ if (N0.getOpcode() != ISD::AND || N1.getOpcode() != X86ISD::ANDNP)
+ return false;
+
+ Mask = N1.getOperand(0);
+ X = N1.getOperand(1);
+
+ // Check to see if the mask appeared in both the AND and ANDNP.
+ if (N0.getOperand(0) == Mask)
+ Y = N0.getOperand(1);
+ else if (N0.getOperand(1) == Mask)
+ Y = N0.getOperand(0);
+ else
+ return false;
+
+ // TODO: Attempt to match against AND(XOR(-1,M),Y) as well, waiting for
+ // ANDNP combine allows other combines to happen that prevent matching.
+ return true;
+}
+
// Try to fold:
// (or (and (m, y), (pandn m, x)))
// into:
@@ -33386,33 +34870,13 @@ static SDValue combineLogicBlendIntoPBLENDV(SDNode *N, SelectionDAG &DAG,
const X86Subtarget &Subtarget) {
assert(N->getOpcode() == ISD::OR && "Unexpected Opcode");
- SDValue N0 = N->getOperand(0);
- SDValue N1 = N->getOperand(1);
EVT VT = N->getValueType(0);
-
if (!((VT.is128BitVector() && Subtarget.hasSSE2()) ||
(VT.is256BitVector() && Subtarget.hasInt256())))
return SDValue();
- // Canonicalize AND to LHS.
- if (N1.getOpcode() == ISD::AND)
- std::swap(N0, N1);
-
- // TODO: Attempt to match against AND(XOR(-1,X),Y) as well, waiting for
- // ANDNP combine allows other combines to happen that prevent matching.
- if (N0.getOpcode() != ISD::AND || N1.getOpcode() != X86ISD::ANDNP)
- return SDValue();
-
- SDValue Mask = N1.getOperand(0);
- SDValue X = N1.getOperand(1);
- SDValue Y;
- if (N0.getOperand(0) == Mask)
- Y = N0.getOperand(1);
- if (N0.getOperand(1) == Mask)
- Y = N0.getOperand(0);
-
- // Check to see if the mask appeared in both the AND and ANDNP.
- if (!Y.getNode())
+ SDValue X, Y, Mask;
+ if (!matchLogicBlend(N, X, Y, Mask))
return SDValue();
// Validate that X, Y, and Mask are bitcasts, and see through them.
@@ -33509,7 +34973,7 @@ static SDValue lowerX86CmpEqZeroToCtlzSrl(SDValue Op, EVT ExtTy,
// encoding of shr and lzcnt is more desirable.
SDValue Trunc = DAG.getZExtOrTrunc(Clz, dl, MVT::i32);
SDValue Scc = DAG.getNode(ISD::SRL, dl, MVT::i32, Trunc,
- DAG.getConstant(Log2b, dl, VT));
+ DAG.getConstant(Log2b, dl, MVT::i8));
return DAG.getZExtOrTrunc(Scc, dl, ExtTy);
}
@@ -33829,63 +35293,180 @@ static bool isSATValidOnAVX512Subtarget(EVT SrcVT, EVT DstVT,
return false;
// FIXME: Scalar type may be supported if we move it to vector register.
- if (!SrcVT.isVector() || !SrcVT.isSimple() || SrcVT.getSizeInBits() > 512)
+ if (!SrcVT.isVector())
return false;
EVT SrcElVT = SrcVT.getScalarType();
EVT DstElVT = DstVT.getScalarType();
- if (SrcElVT.getSizeInBits() < 16 || SrcElVT.getSizeInBits() > 64)
- return false;
- if (DstElVT.getSizeInBits() < 8 || DstElVT.getSizeInBits() > 32)
+ if (DstElVT != MVT::i8 && DstElVT != MVT::i16 && DstElVT != MVT::i32)
return false;
if (SrcVT.is512BitVector() || Subtarget.hasVLX())
return SrcElVT.getSizeInBits() >= 32 || Subtarget.hasBWI();
return false;
}
-/// Detect a pattern of truncation with saturation:
-/// (truncate (umin (x, unsigned_max_of_dest_type)) to dest_type).
+/// Detect patterns of truncation with unsigned saturation:
+///
+/// 1. (truncate (umin (x, unsigned_max_of_dest_type)) to dest_type).
+/// Return the source value x to be truncated or SDValue() if the pattern was
+/// not matched.
+///
+/// 2. (truncate (smin (smax (x, C1), C2)) to dest_type),
+/// where C1 >= 0 and C2 is unsigned max of destination type.
+///
+/// (truncate (smax (smin (x, C2), C1)) to dest_type)
+/// where C1 >= 0, C2 is unsigned max of destination type and C1 <= C2.
+///
+/// These two patterns are equivalent to:
+/// (truncate (umin (smax(x, C1), unsigned_max_of_dest_type)) to dest_type)
+/// So return the smax(x, C1) value to be truncated or SDValue() if the
+/// pattern was not matched.
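+///
+/// For example, for a vXi8 destination, (truncate (umin X, 255)) matches
+/// pattern 1, and (truncate (smin (smax X, 0), 255)) matches pattern 2 with
+/// C1 = 0 and C2 = 255.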
+static SDValue detectUSatPattern(SDValue In, EVT VT, SelectionDAG &DAG,
+ const SDLoc &DL) {
+ EVT InVT = In.getValueType();
+
+ // Saturation with truncation. We truncate from InVT to VT.
+ assert(InVT.getScalarSizeInBits() > VT.getScalarSizeInBits() &&
+ "Unexpected types for truncate operation");
+
+ // Match min/max and return limit value as a parameter.
+ auto MatchMinMax = [](SDValue V, unsigned Opcode, APInt &Limit) -> SDValue {
+ if (V.getOpcode() == Opcode &&
+ ISD::isConstantSplatVector(V.getOperand(1).getNode(), Limit))
+ return V.getOperand(0);
+ return SDValue();
+ };
+
+ APInt C1, C2;
+ if (SDValue UMin = MatchMinMax(In, ISD::UMIN, C2))
+ // C2 should be equal to UINT32_MAX / UINT16_MAX / UINT8_MAX according to
+ // the element size of the destination type.
+ if (C2.isMask(VT.getScalarSizeInBits()))
+ return UMin;
+
+ if (SDValue SMin = MatchMinMax(In, ISD::SMIN, C2))
+ if (MatchMinMax(SMin, ISD::SMAX, C1))
+ if (C1.isNonNegative() && C2.isMask(VT.getScalarSizeInBits()))
+ return SMin;
+
+ if (SDValue SMax = MatchMinMax(In, ISD::SMAX, C1))
+ if (SDValue SMin = MatchMinMax(SMax, ISD::SMIN, C2))
+ if (C1.isNonNegative() && C2.isMask(VT.getScalarSizeInBits()) &&
+ C2.uge(C1)) {
+ return DAG.getNode(ISD::SMAX, DL, InVT, SMin, In.getOperand(1));
+ }
+
+ return SDValue();
+}
+
+/// Detect patterns of truncation with signed saturation:
+/// (truncate (smin ((smax (x, signed_min_of_dest_type)),
+/// signed_max_of_dest_type)) to dest_type)
+/// or:
+/// (truncate (smax ((smin (x, signed_max_of_dest_type)),
+/// signed_min_of_dest_type)) to dest_type).
+/// With MatchPackUS, the smax/smin range is [0, unsigned_max_of_dest_type].
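+/// For example, for a vXi8 destination this matches
+/// (truncate (smin (smax X, -128), 127)), or with MatchPackUS
+/// (truncate (smin (smax X, 0), 255)).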
/// Return the source value to be truncated or SDValue() if the pattern was not
/// matched.
-static SDValue detectUSatPattern(SDValue In, EVT VT) {
- if (In.getOpcode() != ISD::UMIN)
+static SDValue detectSSatPattern(SDValue In, EVT VT, bool MatchPackUS = false) {
+ unsigned NumDstBits = VT.getScalarSizeInBits();
+ unsigned NumSrcBits = In.getScalarValueSizeInBits();
+ assert(NumSrcBits > NumDstBits && "Unexpected types for truncate operation");
+
+ auto MatchMinMax = [](SDValue V, unsigned Opcode,
+ const APInt &Limit) -> SDValue {
+ APInt C;
+ if (V.getOpcode() == Opcode &&
+ ISD::isConstantSplatVector(V.getOperand(1).getNode(), C) && C == Limit)
+ return V.getOperand(0);
return SDValue();
+ };
- //Saturation with truncation. We truncate from InVT to VT.
- assert(In.getScalarValueSizeInBits() > VT.getScalarSizeInBits() &&
- "Unexpected types for truncate operation");
-
- APInt C;
- if (ISD::isConstantSplatVector(In.getOperand(1).getNode(), C)) {
- // C should be equal to UINT32_MAX / UINT16_MAX / UINT8_MAX according
- // the element size of the destination type.
- return C.isMask(VT.getScalarSizeInBits()) ? In.getOperand(0) :
- SDValue();
+ APInt SignedMax, SignedMin;
+ if (MatchPackUS) {
+ SignedMax = APInt::getAllOnesValue(NumDstBits).zext(NumSrcBits);
+ SignedMin = APInt(NumSrcBits, 0);
+ } else {
+ SignedMax = APInt::getSignedMaxValue(NumDstBits).sext(NumSrcBits);
+ SignedMin = APInt::getSignedMinValue(NumDstBits).sext(NumSrcBits);
}
+
+ if (SDValue SMin = MatchMinMax(In, ISD::SMIN, SignedMax))
+ if (SDValue SMax = MatchMinMax(SMin, ISD::SMAX, SignedMin))
+ return SMax;
+
+ if (SDValue SMax = MatchMinMax(In, ISD::SMAX, SignedMin))
+ if (SDValue SMin = MatchMinMax(SMax, ISD::SMIN, SignedMax))
+ return SMin;
+
return SDValue();
}
+/// Detect a pattern of truncation with signed saturation.
+/// The types should allow using the VPMOVSS* instruction on AVX512.
+/// Return the source value to be truncated or SDValue() if the pattern was not
+/// matched.
+static SDValue detectAVX512SSatPattern(SDValue In, EVT VT,
+ const X86Subtarget &Subtarget,
+ const TargetLowering &TLI) {
+ if (!TLI.isTypeLegal(In.getValueType()))
+ return SDValue();
+ if (!isSATValidOnAVX512Subtarget(In.getValueType(), VT, Subtarget))
+ return SDValue();
+ return detectSSatPattern(In, VT);
+}
+
/// Detect a pattern of truncation with saturation:
/// (truncate (umin (x, unsigned_max_of_dest_type)) to dest_type).
/// The types should allow to use VPMOVUS* instruction on AVX512.
/// Return the source value to be truncated or SDValue() if the pattern was not
/// matched.
-static SDValue detectAVX512USatPattern(SDValue In, EVT VT,
- const X86Subtarget &Subtarget) {
+static SDValue detectAVX512USatPattern(SDValue In, EVT VT, SelectionDAG &DAG,
+ const SDLoc &DL,
+ const X86Subtarget &Subtarget,
+ const TargetLowering &TLI) {
+ if (!TLI.isTypeLegal(In.getValueType()))
+ return SDValue();
if (!isSATValidOnAVX512Subtarget(In.getValueType(), VT, Subtarget))
return SDValue();
- return detectUSatPattern(In, VT);
+ return detectUSatPattern(In, VT, DAG, DL);
}
-static SDValue
-combineTruncateWithUSat(SDValue In, EVT VT, SDLoc &DL, SelectionDAG &DAG,
- const X86Subtarget &Subtarget) {
+static SDValue combineTruncateWithSat(SDValue In, EVT VT, const SDLoc &DL,
+ SelectionDAG &DAG,
+ const X86Subtarget &Subtarget) {
+ EVT SVT = VT.getScalarType();
+ EVT InVT = In.getValueType();
+ EVT InSVT = InVT.getScalarType();
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
- if (!TLI.isTypeLegal(In.getValueType()) || !TLI.isTypeLegal(VT))
- return SDValue();
- if (auto USatVal = detectUSatPattern(In, VT))
- if (isSATValidOnAVX512Subtarget(In.getValueType(), VT, Subtarget))
+ if (TLI.isTypeLegal(InVT) && TLI.isTypeLegal(VT) &&
+ isSATValidOnAVX512Subtarget(InVT, VT, Subtarget)) {
+ if (auto SSatVal = detectSSatPattern(In, VT))
+ return DAG.getNode(X86ISD::VTRUNCS, DL, VT, SSatVal);
+ if (auto USatVal = detectUSatPattern(In, VT, DAG, DL))
return DAG.getNode(X86ISD::VTRUNCUS, DL, VT, USatVal);
+ }
+ if (VT.isVector() && isPowerOf2_32(VT.getVectorNumElements()) &&
+ (SVT == MVT::i8 || SVT == MVT::i16) &&
+ (InSVT == MVT::i16 || InSVT == MVT::i32)) {
+ if (auto USatVal = detectSSatPattern(In, VT, true)) {
+ // vXi32 -> vXi8 must be performed as PACKUSWB(PACKSSDW,PACKSSDW).
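+ // The matched value is already in [0, 255], so PACKSSDW narrows i32 to i16
+ // without altering it and PACKUSWB then produces the i8 result; a direct
+ // PACKUSDW would require SSE4.1 (see the hasSSE41 check below).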
+ if (SVT == MVT::i8 && InSVT == MVT::i32) {
+ EVT MidVT = EVT::getVectorVT(*DAG.getContext(), MVT::i16,
+ VT.getVectorNumElements());
+ SDValue Mid = truncateVectorWithPACK(X86ISD::PACKSS, MidVT, USatVal, DL,
+ DAG, Subtarget);
+ if (Mid)
+ return truncateVectorWithPACK(X86ISD::PACKUS, VT, Mid, DL, DAG,
+ Subtarget);
+ } else if (SVT == MVT::i8 || Subtarget.hasSSE41())
+ return truncateVectorWithPACK(X86ISD::PACKUS, VT, USatVal, DL, DAG,
+ Subtarget);
+ }
+ if (auto SSatVal = detectSSatPattern(In, VT))
+ return truncateVectorWithPACK(X86ISD::PACKSS, VT, SSatVal, DL, DAG,
+ Subtarget);
+ }
return SDValue();
}
@@ -33895,7 +35476,7 @@ combineTruncateWithUSat(SDValue In, EVT VT, SDLoc &DL, SelectionDAG &DAG,
static SDValue detectAVGPattern(SDValue In, EVT VT, SelectionDAG &DAG,
const X86Subtarget &Subtarget,
const SDLoc &DL) {
- if (!VT.isVector() || !VT.isSimple())
+ if (!VT.isVector())
return SDValue();
EVT InVT = In.getValueType();
unsigned NumElems = VT.getVectorNumElements();
@@ -33937,42 +35518,13 @@ static SDValue detectAVGPattern(SDValue In, EVT VT, SelectionDAG &DAG,
ConstantSDNode *C = dyn_cast<ConstantSDNode>(Op);
if (!C)
return false;
- uint64_t Val = C->getZExtValue();
- if (Val < Min || Val > Max)
+ const APInt &Val = C->getAPIntValue();
+ if (Val.ult(Min) || Val.ugt(Max))
return false;
}
return true;
};
- // Split vectors to legal target size and apply AVG.
- auto LowerToAVG = [&](SDValue Op0, SDValue Op1) {
- unsigned NumSubs = 1;
- if (Subtarget.hasBWI()) {
- if (VT.getSizeInBits() > 512)
- NumSubs = VT.getSizeInBits() / 512;
- } else if (Subtarget.hasAVX2()) {
- if (VT.getSizeInBits() > 256)
- NumSubs = VT.getSizeInBits() / 256;
- } else {
- if (VT.getSizeInBits() > 128)
- NumSubs = VT.getSizeInBits() / 128;
- }
-
- if (NumSubs == 1)
- return DAG.getNode(X86ISD::AVG, DL, VT, Op0, Op1);
-
- SmallVector<SDValue, 4> Subs;
- EVT SubVT = EVT::getVectorVT(*DAG.getContext(), VT.getScalarType(),
- VT.getVectorNumElements() / NumSubs);
- for (unsigned i = 0; i != NumSubs; ++i) {
- unsigned Idx = i * SubVT.getVectorNumElements();
- SDValue LHS = extractSubVector(Op0, Idx, DAG, DL, SubVT.getSizeInBits());
- SDValue RHS = extractSubVector(Op1, Idx, DAG, DL, SubVT.getSizeInBits());
- Subs.push_back(DAG.getNode(X86ISD::AVG, DL, SubVT, LHS, RHS));
- }
- return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, Subs);
- };
-
// Check if each element of the vector is left-shifted by one.
auto LHS = In.getOperand(0);
auto RHS = In.getOperand(1);
@@ -33986,6 +35538,11 @@ static SDValue detectAVGPattern(SDValue In, EVT VT, SelectionDAG &DAG,
Operands[0] = LHS.getOperand(0);
Operands[1] = LHS.getOperand(1);
+ auto AVGBuilder = [](SelectionDAG &DAG, const SDLoc &DL,
+ ArrayRef<SDValue> Ops) {
+ return DAG.getNode(X86ISD::AVG, DL, Ops[0].getValueType(), Ops);
+ };
+
// Take care of the case when one of the operands is a constant vector whose
// element is in the range [1, 256].
if (IsConstVectorInRange(Operands[1], 1, ScalarVT == MVT::i8 ? 256 : 65536) &&
@@ -33996,7 +35553,9 @@ static SDValue detectAVGPattern(SDValue In, EVT VT, SelectionDAG &DAG,
SDValue VecOnes = DAG.getConstant(1, DL, InVT);
Operands[1] = DAG.getNode(ISD::SUB, DL, InVT, Operands[1], VecOnes);
Operands[1] = DAG.getNode(ISD::TRUNCATE, DL, VT, Operands[1]);
- return LowerToAVG(Operands[0].getOperand(0), Operands[1]);
+ return SplitOpsAndApply(DAG, Subtarget, DL, VT,
+ { Operands[0].getOperand(0), Operands[1] },
+ AVGBuilder);
}
if (Operands[0].getOpcode() == ISD::ADD)
@@ -34019,8 +35578,10 @@ static SDValue detectAVGPattern(SDValue In, EVT VT, SelectionDAG &DAG,
Operands[j].getOperand(0).getValueType() != VT)
return SDValue();
- // The pattern is detected, emit X86ISD::AVG instruction.
- return LowerToAVG(Operands[0].getOperand(0), Operands[1].getOperand(0));
+ // The pattern is detected, emit X86ISD::AVG instruction(s).
+ return SplitOpsAndApply(DAG, Subtarget, DL, VT,
+ { Operands[0].getOperand(0),
+ Operands[1].getOperand(0) }, AVGBuilder);
}
return SDValue();
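// A minimal scalar sketch (ours, not part of this change, assuming <cstdint>)
// of what X86ISD::AVG, i.e. PAVGB/PAVGW, computes per element: a rounding
// average taken in widened precision, which is the (a + b + 1) >> 1 pattern
// the detection above looks for.
static uint8_t pavgbElt(uint8_t A, uint8_t B) {
  return (uint8_t)(((unsigned)A + (unsigned)B + 1) >> 1);
}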
@@ -34451,6 +36012,63 @@ static SDValue combineStore(SDNode *N, SelectionDAG &DAG,
SDValue StoredVal = St->getOperand(1);
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+ // If this is a store of a scalar_to_vector to v1i1, just use a scalar store.
+ // This will avoid a copy to k-register.
+ if (VT == MVT::v1i1 && VT == StVT && Subtarget.hasAVX512() &&
+ StoredVal.getOpcode() == ISD::SCALAR_TO_VECTOR &&
+ StoredVal.getOperand(0).getValueType() == MVT::i8) {
+ return DAG.getStore(St->getChain(), dl, StoredVal.getOperand(0),
+ St->getBasePtr(), St->getPointerInfo(),
+ St->getAlignment(), St->getMemOperand()->getFlags());
+ }
+
+ // Widen v2i1/v4i1 stores to v8i1.
+ if ((VT == MVT::v2i1 || VT == MVT::v4i1) && VT == StVT &&
+ Subtarget.hasAVX512()) {
+ unsigned NumConcats = 8 / VT.getVectorNumElements();
+ SmallVector<SDValue, 4> Ops(NumConcats, DAG.getUNDEF(VT));
+ Ops[0] = StoredVal;
+ StoredVal = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v8i1, Ops);
+ return DAG.getStore(St->getChain(), dl, StoredVal, St->getBasePtr(),
+ St->getPointerInfo(), St->getAlignment(),
+ St->getMemOperand()->getFlags());
+ }
+
+ // Turn vXi1 stores of constants into a scalar store.
+ if ((VT == MVT::v8i1 || VT == MVT::v16i1 || VT == MVT::v32i1 ||
+ VT == MVT::v64i1) && VT == StVT && TLI.isTypeLegal(VT) &&
+ ISD::isBuildVectorOfConstantSDNodes(StoredVal.getNode())) {
+ // If it's a v64i1 store without 64-bit support, we need two stores.
+ if (VT == MVT::v64i1 && !Subtarget.is64Bit()) {
+ SDValue Lo = DAG.getBuildVector(MVT::v32i1, dl,
+ StoredVal->ops().slice(0, 32));
+ Lo = combinevXi1ConstantToInteger(Lo, DAG);
+ SDValue Hi = DAG.getBuildVector(MVT::v32i1, dl,
+ StoredVal->ops().slice(32, 32));
+ Hi = combinevXi1ConstantToInteger(Hi, DAG);
+
+ unsigned Alignment = St->getAlignment();
+
+ SDValue Ptr0 = St->getBasePtr();
+ SDValue Ptr1 = DAG.getMemBasePlusOffset(Ptr0, 4, dl);
+
+ SDValue Ch0 =
+ DAG.getStore(St->getChain(), dl, Lo, Ptr0, St->getPointerInfo(),
+ Alignment, St->getMemOperand()->getFlags());
+ SDValue Ch1 =
+ DAG.getStore(St->getChain(), dl, Hi, Ptr1,
+ St->getPointerInfo().getWithOffset(4),
+ MinAlign(Alignment, 4U),
+ St->getMemOperand()->getFlags());
+ return DAG.getNode(ISD::TokenFactor, dl, MVT::Other, Ch0, Ch1);
+ }
+
+ StoredVal = combinevXi1ConstantToInteger(StoredVal, DAG);
+ return DAG.getStore(St->getChain(), dl, StoredVal, St->getBasePtr(),
+ St->getPointerInfo(), St->getAlignment(),
+ St->getMemOperand()->getFlags());
+ }
+
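// A minimal sketch (ours, not part of this change) of how a constant v8i1 mask
// folds to the single byte that gets stored, assuming element i of the vXi1
// constant maps to bit i of the scalar, i.e. the vXi1 <-> iN bitcast layout
// that combinevXi1ConstantToInteger relies on.
static uint8_t packConstantV8i1(const bool (&Elts)[8]) {
  uint8_t Bits = 0;
  for (unsigned I = 0; I != 8; ++I)
    if (Elts[I])
      Bits |= (uint8_t)(1u << I);
  return Bits;
}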
// If we are saving a concatenation of two XMM registers and 32-byte stores
// are slow, such as on Sandy Bridge, perform two 16-byte stores.
bool Fast;
@@ -34493,13 +36111,19 @@ static SDValue combineStore(SDNode *N, SelectionDAG &DAG,
St->getPointerInfo(), St->getAlignment(),
St->getMemOperand()->getFlags());
+ const TargetLowering &TLI = DAG.getTargetLoweringInfo();
if (SDValue Val =
- detectAVX512USatPattern(St->getValue(), St->getMemoryVT(), Subtarget))
+ detectAVX512SSatPattern(St->getValue(), St->getMemoryVT(), Subtarget,
+ TLI))
+ return EmitTruncSStore(true /* Signed saturation */, St->getChain(),
+ dl, Val, St->getBasePtr(),
+ St->getMemoryVT(), St->getMemOperand(), DAG);
+ if (SDValue Val = detectAVX512USatPattern(St->getValue(), St->getMemoryVT(),
+ DAG, dl, Subtarget, TLI))
return EmitTruncSStore(false /* Unsigned saturation */, St->getChain(),
dl, Val, St->getBasePtr(),
St->getMemoryVT(), St->getMemOperand(), DAG);
- const TargetLowering &TLI = DAG.getTargetLoweringInfo();
unsigned NumElems = VT.getVectorNumElements();
assert(StVT != VT && "Cannot truncate to the same type");
unsigned FromSz = VT.getScalarSizeInBits();
@@ -34812,7 +36436,7 @@ static SDValue combineFaddFsub(SDNode *N, SelectionDAG &DAG,
// Try to synthesize horizontal add/sub from adds/subs of shuffles.
if (((Subtarget.hasSSE3() && (VT == MVT::v4f32 || VT == MVT::v2f64)) ||
- (Subtarget.hasFp256() && (VT == MVT::v8f32 || VT == MVT::v4f64))) &&
+ (Subtarget.hasAVX() && (VT == MVT::v8f32 || VT == MVT::v4f64))) &&
isHorizontalBinOp(LHS, RHS, IsFadd)) {
auto NewOpcode = IsFadd ? X86ISD::FHADD : X86ISD::FHSUB;
return DAG.getNode(NewOpcode, SDLoc(N), VT, LHS, RHS);
@@ -34825,7 +36449,7 @@ static SDValue combineFaddFsub(SDNode *N, SelectionDAG &DAG,
/// e.g. TRUNC( BINOP( X, Y ) ) --> BINOP( TRUNC( X ), TRUNC( Y ) )
static SDValue combineTruncatedArithmetic(SDNode *N, SelectionDAG &DAG,
const X86Subtarget &Subtarget,
- SDLoc &DL) {
+ const SDLoc &DL) {
assert(N->getOpcode() == ISD::TRUNCATE && "Wrong opcode");
SDValue Src = N->getOperand(0);
unsigned Opcode = Src.getOpcode();
@@ -34898,7 +36522,7 @@ static SDValue combineTruncatedArithmetic(SDNode *N, SelectionDAG &DAG,
// X86 is rubbish at scalar and vector i64 multiplies (until AVX512DQ) - its
// better to truncate if we have the chance.
if (SrcVT.getScalarType() == MVT::i64 && TLI.isOperationLegal(Opcode, VT) &&
- !Subtarget.hasDQI())
+ !TLI.isOperationLegal(Opcode, SrcVT))
return TruncateArithmetic(Src.getOperand(0), Src.getOperand(1));
LLVM_FALLTHROUGH;
case ISD::ADD: {
@@ -34915,88 +36539,50 @@ static SDValue combineTruncatedArithmetic(SDNode *N, SelectionDAG &DAG,
return SDValue();
}
-/// Truncate a group of v4i32 into v16i8/v8i16 using X86ISD::PACKUS.
-static SDValue
-combineVectorTruncationWithPACKUS(SDNode *N, SelectionDAG &DAG,
- SmallVector<SDValue, 8> &Regs) {
- assert(Regs.size() > 0 && (Regs[0].getValueType() == MVT::v4i32 ||
- Regs[0].getValueType() == MVT::v2i64));
+/// Truncate using ISD::AND mask and X86ISD::PACKUS.
+static SDValue combineVectorTruncationWithPACKUS(SDNode *N, const SDLoc &DL,
+ const X86Subtarget &Subtarget,
+ SelectionDAG &DAG) {
+ SDValue In = N->getOperand(0);
+ EVT InVT = In.getValueType();
+ EVT InSVT = InVT.getVectorElementType();
EVT OutVT = N->getValueType(0);
EVT OutSVT = OutVT.getVectorElementType();
- EVT InVT = Regs[0].getValueType();
- EVT InSVT = InVT.getVectorElementType();
- SDLoc DL(N);
- // First, use mask to unset all bits that won't appear in the result.
- assert((OutSVT == MVT::i8 || OutSVT == MVT::i16) &&
- "OutSVT can only be either i8 or i16.");
+ // Split a long vector into vectors of legal type and mask to unset all bits
+ // that won't appear in the result to prevent saturation.
+ // TODO - we should be doing this at the maximum legal size but this is
+ // causing regressions where we're concatenating back to max width just to
+ // perform the AND and then extracting back again.
+ unsigned NumSubRegs = InVT.getSizeInBits() / 128;
+ unsigned NumSubRegElts = 128 / InSVT.getSizeInBits();
+ EVT SubRegVT = EVT::getVectorVT(*DAG.getContext(), InSVT, NumSubRegElts);
+ SmallVector<SDValue, 8> SubVecs(NumSubRegs);
+
APInt Mask =
APInt::getLowBitsSet(InSVT.getSizeInBits(), OutSVT.getSizeInBits());
- SDValue MaskVal = DAG.getConstant(Mask, DL, InVT);
- for (auto &Reg : Regs)
- Reg = DAG.getNode(ISD::AND, DL, InVT, MaskVal, Reg);
-
- MVT UnpackedVT, PackedVT;
- if (OutSVT == MVT::i8) {
- UnpackedVT = MVT::v8i16;
- PackedVT = MVT::v16i8;
- } else {
- UnpackedVT = MVT::v4i32;
- PackedVT = MVT::v8i16;
- }
-
- // In each iteration, truncate the type by a half size.
- auto RegNum = Regs.size();
- for (unsigned j = 1, e = InSVT.getSizeInBits() / OutSVT.getSizeInBits();
- j < e; j *= 2, RegNum /= 2) {
- for (unsigned i = 0; i < RegNum; i++)
- Regs[i] = DAG.getBitcast(UnpackedVT, Regs[i]);
- for (unsigned i = 0; i < RegNum / 2; i++)
- Regs[i] = DAG.getNode(X86ISD::PACKUS, DL, PackedVT, Regs[i * 2],
- Regs[i * 2 + 1]);
- }
-
- // If the type of the result is v8i8, we need do one more X86ISD::PACKUS, and
- // then extract a subvector as the result since v8i8 is not a legal type.
- if (OutVT == MVT::v8i8) {
- Regs[0] = DAG.getNode(X86ISD::PACKUS, DL, PackedVT, Regs[0], Regs[0]);
- Regs[0] = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, OutVT, Regs[0],
- DAG.getIntPtrConstant(0, DL));
- return Regs[0];
- } else if (RegNum > 1) {
- Regs.resize(RegNum);
- return DAG.getNode(ISD::CONCAT_VECTORS, DL, OutVT, Regs);
- } else
- return Regs[0];
-}
-
-/// Truncate a group of v4i32 into v8i16 using X86ISD::PACKSS.
-static SDValue
-combineVectorTruncationWithPACKSS(SDNode *N, const X86Subtarget &Subtarget,
- SelectionDAG &DAG,
- SmallVector<SDValue, 8> &Regs) {
- assert(Regs.size() > 0 && Regs[0].getValueType() == MVT::v4i32);
- EVT OutVT = N->getValueType(0);
- SDLoc DL(N);
+ SDValue MaskVal = DAG.getConstant(Mask, DL, SubRegVT);
- // Shift left by 16 bits, then arithmetic-shift right by 16 bits.
- SDValue ShAmt = DAG.getConstant(16, DL, MVT::i32);
- for (auto &Reg : Regs) {
- Reg = getTargetVShiftNode(X86ISD::VSHLI, DL, MVT::v4i32, Reg, ShAmt,
- Subtarget, DAG);
- Reg = getTargetVShiftNode(X86ISD::VSRAI, DL, MVT::v4i32, Reg, ShAmt,
- Subtarget, DAG);
+ for (unsigned i = 0; i < NumSubRegs; i++) {
+ SDValue Sub = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, SubRegVT, In,
+ DAG.getIntPtrConstant(i * NumSubRegElts, DL));
+ SubVecs[i] = DAG.getNode(ISD::AND, DL, SubRegVT, Sub, MaskVal);
}
+ In = DAG.getNode(ISD::CONCAT_VECTORS, DL, InVT, SubVecs);
- for (unsigned i = 0, e = Regs.size() / 2; i < e; i++)
- Regs[i] = DAG.getNode(X86ISD::PACKSS, DL, MVT::v8i16, Regs[i * 2],
- Regs[i * 2 + 1]);
+ return truncateVectorWithPACK(X86ISD::PACKUS, OutVT, In, DL, DAG, Subtarget);
+}
- if (Regs.size() > 2) {
- Regs.resize(Regs.size() / 2);
- return DAG.getNode(ISD::CONCAT_VECTORS, DL, OutVT, Regs);
- } else
- return Regs[0];
+/// Truncate a group of v4i32 into v8i16 using X86ISD::PACKSS.
+static SDValue combineVectorTruncationWithPACKSS(SDNode *N, const SDLoc &DL,
+ const X86Subtarget &Subtarget,
+ SelectionDAG &DAG) {
+ SDValue In = N->getOperand(0);
+ EVT InVT = In.getValueType();
+ EVT OutVT = N->getValueType(0);
+ In = DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, InVT, In,
+ DAG.getValueType(OutVT));
+ return truncateVectorWithPACK(X86ISD::PACKSS, OutVT, In, DL, DAG, Subtarget);
}
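// A minimal scalar sketch (ours, not part of this change) of the i32 -> i16
// case above: SIGN_EXTEND_INREG makes the upper bits copies of bit 15, so the
// signed saturation in PACKSS never clips and the pack is an exact truncate.
static int16_t sextInRegThenPackssElt(int32_t X) {
  int32_t InReg = (int32_t)(int16_t)X;  // sign_extend_inreg from i16
  return (int16_t)(InReg > 32767 ? 32767 : InReg < -32768 ? -32768 : InReg);
}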
/// This function transforms truncation from vXi32/vXi64 to vXi8/vXi16 into
@@ -35037,32 +36623,21 @@ static SDValue combineVectorTruncation(SDNode *N, SelectionDAG &DAG,
return SDValue();
SDLoc DL(N);
-
- // Split a long vector into vectors of legal type.
- unsigned RegNum = InVT.getSizeInBits() / 128;
- SmallVector<SDValue, 8> SubVec(RegNum);
- unsigned NumSubRegElts = 128 / InSVT.getSizeInBits();
- EVT SubRegVT = EVT::getVectorVT(*DAG.getContext(), InSVT, NumSubRegElts);
-
- for (unsigned i = 0; i < RegNum; i++)
- SubVec[i] = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, SubRegVT, In,
- DAG.getIntPtrConstant(i * NumSubRegElts, DL));
-
// SSE2 provides PACKUS for only 2 x v8i16 -> v16i8 and SSE4.1 provides PACKUS
// for 2 x v4i32 -> v8i16. For SSSE3 and below, we need to use PACKSS to
// truncate 2 x v4i32 to v8i16.
if (Subtarget.hasSSE41() || OutSVT == MVT::i8)
- return combineVectorTruncationWithPACKUS(N, DAG, SubVec);
- else if (InSVT == MVT::i32)
- return combineVectorTruncationWithPACKSS(N, Subtarget, DAG, SubVec);
- else
- return SDValue();
+ return combineVectorTruncationWithPACKUS(N, DL, Subtarget, DAG);
+ if (InSVT == MVT::i32)
+ return combineVectorTruncationWithPACKSS(N, DL, Subtarget, DAG);
+
+ return SDValue();
}
/// This function transforms vector truncation of 'extended sign-bits' or
/// 'extended zero-bits' values.
/// vXi16/vXi32/vXi64 to vXi8/vXi16/vXi32 into X86ISD::PACKSS/PACKUS operations.
-static SDValue combineVectorSignBitsTruncation(SDNode *N, SDLoc &DL,
+static SDValue combineVectorSignBitsTruncation(SDNode *N, const SDLoc &DL,
SelectionDAG &DAG,
const X86Subtarget &Subtarget) {
// Requires SSE2 but AVX512 has fast truncate.
@@ -35082,7 +36657,7 @@ static SDValue combineVectorSignBitsTruncation(SDNode *N, SDLoc &DL,
MVT InVT = In.getValueType().getSimpleVT();
MVT InSVT = InVT.getScalarType();
- // Check we have a truncation suited for PACKSS.
+ // Check we have a truncation suited for PACKSS/PACKUS.
if (!VT.is128BitVector() && !VT.is256BitVector())
return SDValue();
if (SVT != MVT::i8 && SVT != MVT::i16 && SVT != MVT::i32)
@@ -35090,25 +36665,79 @@ static SDValue combineVectorSignBitsTruncation(SDNode *N, SDLoc &DL,
if (InSVT != MVT::i16 && InSVT != MVT::i32 && InSVT != MVT::i64)
return SDValue();
- // Use PACKSS if the input has sign-bits that extend all the way to the
- // packed/truncated value. e.g. Comparison result, sext_in_reg, etc.
- unsigned NumSignBits = DAG.ComputeNumSignBits(In);
- unsigned NumPackedBits = std::min<unsigned>(SVT.getSizeInBits(), 16);
- if (NumSignBits > (InSVT.getSizeInBits() - NumPackedBits))
- return truncateVectorWithPACK(X86ISD::PACKSS, VT, In, DL, DAG, Subtarget);
+ unsigned NumPackedSignBits = std::min<unsigned>(SVT.getSizeInBits(), 16);
+ unsigned NumPackedZeroBits = Subtarget.hasSSE41() ? NumPackedSignBits : 8;
// Use PACKUS if the input has zero-bits that extend all the way to the
// packed/truncated value. e.g. masks, zext_in_reg, etc.
KnownBits Known;
DAG.computeKnownBits(In, Known);
unsigned NumLeadingZeroBits = Known.countMinLeadingZeros();
- NumPackedBits = Subtarget.hasSSE41() ? NumPackedBits : 8;
- if (NumLeadingZeroBits >= (InSVT.getSizeInBits() - NumPackedBits))
+ if (NumLeadingZeroBits >= (InSVT.getSizeInBits() - NumPackedZeroBits))
return truncateVectorWithPACK(X86ISD::PACKUS, VT, In, DL, DAG, Subtarget);
+ // Use PACKSS if the input has sign-bits that extend all the way to the
+ // packed/truncated value. e.g. Comparison result, sext_in_reg, etc.
+ unsigned NumSignBits = DAG.ComputeNumSignBits(In);
+ if (NumSignBits > (InSVT.getSizeInBits() - NumPackedSignBits))
+ return truncateVectorWithPACK(X86ISD::PACKSS, VT, In, DL, DAG, Subtarget);
+
return SDValue();
}
+// Try to form a MULHU or MULHS node by looking for
+// (trunc (srl (mul ext, ext), 16))
+// TODO: This is X86 specific because we want to be able to handle wide types
+// before type legalization. But we can only do it if the vector will be
+// legalized via widening/splitting. Type legalization can't handle promotion
+// of a MULHU/MULHS. There isn't a way to convey this to the generic DAG
+// combiner.
+static SDValue combinePMULH(SDValue Src, EVT VT, const SDLoc &DL,
+ SelectionDAG &DAG, const X86Subtarget &Subtarget) {
+ // First instruction should be a right shift of a multiply.
+ if (Src.getOpcode() != ISD::SRL ||
+ Src.getOperand(0).getOpcode() != ISD::MUL)
+ return SDValue();
+
+ if (!Subtarget.hasSSE2())
+ return SDValue();
+
+ // Only handle vXi16 types that are at least 128-bits.
+ if (!VT.isVector() || VT.getVectorElementType() != MVT::i16 ||
+ VT.getVectorNumElements() < 8)
+ return SDValue();
+
+ // Input type should be vXi32.
+ EVT InVT = Src.getValueType();
+ if (InVT.getVectorElementType() != MVT::i32)
+ return SDValue();
+
+ // Need a shift by 16.
+ APInt ShiftAmt;
+ if (!ISD::isConstantSplatVector(Src.getOperand(1).getNode(), ShiftAmt) ||
+ ShiftAmt != 16)
+ return SDValue();
+
+ SDValue LHS = Src.getOperand(0).getOperand(0);
+ SDValue RHS = Src.getOperand(0).getOperand(1);
+
+ unsigned ExtOpc = LHS.getOpcode();
+ if ((ExtOpc != ISD::SIGN_EXTEND && ExtOpc != ISD::ZERO_EXTEND) ||
+ RHS.getOpcode() != ExtOpc)
+ return SDValue();
+
+ // Peek through the extends.
+ LHS = LHS.getOperand(0);
+ RHS = RHS.getOperand(0);
+
+ // Ensure the input types match.
+ if (LHS.getValueType() != VT || RHS.getValueType() != VT)
+ return SDValue();
+
+ unsigned Opc = ExtOpc == ISD::SIGN_EXTEND ? ISD::MULHS : ISD::MULHU;
+ return DAG.getNode(Opc, DL, VT, LHS, RHS);
+}
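// Minimal scalar sketches (ours, not part of this change, assuming <cstdint>)
// of the MULHS/MULHU nodes formed above, i.e. what PMULHW/PMULHUW compute per
// element: the high 16 bits of the widened 16 x 16 multiply.
static int16_t pmulhwElt(int16_t A, int16_t B) {
  int32_t Product = (int32_t)A * (int32_t)B;
  return (int16_t)((uint32_t)Product >> 16);    // Bits 31..16 of the product.
}
static uint16_t pmulhuwElt(uint16_t A, uint16_t B) {
  return (uint16_t)(((uint32_t)A * (uint32_t)B) >> 16);
}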
+
static SDValue combineTruncate(SDNode *N, SelectionDAG &DAG,
const X86Subtarget &Subtarget) {
EVT VT = N->getValueType(0);
@@ -35123,10 +36752,14 @@ static SDValue combineTruncate(SDNode *N, SelectionDAG &DAG,
if (SDValue Avg = detectAVGPattern(Src, VT, DAG, Subtarget, DL))
return Avg;
- // Try to combine truncation with unsigned saturation.
- if (SDValue Val = combineTruncateWithUSat(Src, VT, DL, DAG, Subtarget))
+ // Try to combine truncation with signed/unsigned saturation.
+ if (SDValue Val = combineTruncateWithSat(Src, VT, DL, DAG, Subtarget))
return Val;
+ // Try to combine PMULHUW/PMULHW for vXi16.
+ if (SDValue V = combinePMULH(Src, VT, DL, DAG, Subtarget))
+ return V;
+
// The bitcast source is a direct mmx result.
// Detect bitcasts between i32 to x86mmx
if (Src.getOpcode() == ISD::BITCAST && VT == MVT::i32) {
@@ -35224,7 +36857,7 @@ static SDValue combineFneg(SDNode *N, SelectionDAG &DAG,
// If we're negating an FMA node, then we can adjust the
// instruction to include the extra negation.
unsigned NewOpcode = 0;
- if (Arg.hasOneUse()) {
+ if (Arg.hasOneUse() && Subtarget.hasAnyFMA()) {
switch (Arg.getOpcode()) {
case ISD::FMA: NewOpcode = X86ISD::FNMSUB; break;
case X86ISD::FMSUB: NewOpcode = X86ISD::FNMADD; break;
@@ -35320,6 +36953,39 @@ static SDValue combineXor(SDNode *N, SelectionDAG &DAG,
return SDValue();
}
+static SDValue combineBEXTR(SDNode *N, SelectionDAG &DAG,
+ TargetLowering::DAGCombinerInfo &DCI,
+ const X86Subtarget &Subtarget) {
+ SDValue Op0 = N->getOperand(0);
+ SDValue Op1 = N->getOperand(1);
+ EVT VT = N->getValueType(0);
+ unsigned NumBits = VT.getSizeInBits();
+
+ const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+ TargetLowering::TargetLoweringOpt TLO(DAG, !DCI.isBeforeLegalize(),
+ !DCI.isBeforeLegalizeOps());
+
+ // TODO - Constant Folding.
+ if (auto *Cst1 = dyn_cast<ConstantSDNode>(Op1)) {
+ // Reduce Cst1 to the bottom 16-bits.
+ // NOTE: SimplifyDemandedBits won't do this for constants.
+ const APInt &Val1 = Cst1->getAPIntValue();
+ APInt MaskedVal1 = Val1 & 0xFFFF;
+ if (MaskedVal1 != Val1)
+ return DAG.getNode(X86ISD::BEXTR, SDLoc(N), VT, Op0,
+ DAG.getConstant(MaskedVal1, SDLoc(N), VT));
+ }
+
+ // Only the bottom 16 bits of the control are required.
+ KnownBits Known;
+ APInt DemandedMask(APInt::getLowBitsSet(NumBits, 16));
+ if (TLI.SimplifyDemandedBits(Op1, DemandedMask, Known, TLO)) {
+ DCI.CommitTargetLoweringOpt(TLO);
+ return SDValue(N, 0);
+ }
+
+ return SDValue();
+}
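// A minimal scalar sketch (ours, not part of this change) of how BEXTR reads
// its control operand, which is why only the low 16 bits are demanded above:
// bits [7:0] select the start bit and bits [15:8] the field length.
static uint32_t bextr32Model(uint32_t Src, uint32_t Ctrl) {
  unsigned Start = Ctrl & 0xFF;
  unsigned Len = (Ctrl >> 8) & 0xFF;
  if (Start >= 32 || Len == 0)
    return 0;
  uint64_t Field = (uint64_t)Src >> Start;
  if (Len < 32)
    Field &= ((uint64_t)1 << Len) - 1;
  return (uint32_t)Field;
}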
static bool isNullFPScalarOrVectorConst(SDValue V) {
return isNullFPConstant(V) || ISD::isBuildVectorAllZeros(V.getNode());
@@ -35450,8 +37116,6 @@ static SDValue combineFMinNumFMaxNum(SDNode *N, SelectionDAG &DAG,
if (Subtarget.useSoftFloat())
return SDValue();
- // TODO: Check for global or instruction-level "nnan". In that case, we
- // should be able to lower to FMAX/FMIN alone.
// TODO: If an operand is already known to be a NaN or not a NaN, this
// should be an optional swap and FMAX/FMIN.
@@ -35461,14 +37125,21 @@ static SDValue combineFMinNumFMaxNum(SDNode *N, SelectionDAG &DAG,
(Subtarget.hasAVX() && (VT == MVT::v8f32 || VT == MVT::v4f64))))
return SDValue();
- // This takes at least 3 instructions, so favor a library call when operating
- // on a scalar and minimizing code size.
- if (!VT.isVector() && DAG.getMachineFunction().getFunction().optForMinSize())
- return SDValue();
-
SDValue Op0 = N->getOperand(0);
SDValue Op1 = N->getOperand(1);
SDLoc DL(N);
+ auto MinMaxOp = N->getOpcode() == ISD::FMAXNUM ? X86ISD::FMAX : X86ISD::FMIN;
+
+ // If we don't have to respect NaN inputs, this is a direct translation to x86
+ // min/max instructions.
+ if (DAG.getTarget().Options.NoNaNsFPMath || N->getFlags().hasNoNaNs())
+ return DAG.getNode(MinMaxOp, DL, VT, Op0, Op1, N->getFlags());
+
+ // If we have to respect NaN inputs, this takes at least 3 instructions.
+ // Favor a library call when operating on a scalar and minimizing code size.
+ if (!VT.isVector() && DAG.getMachineFunction().getFunction().optForMinSize())
+ return SDValue();
+
EVT SetCCType = DAG.getTargetLoweringInfo().getSetCCResultType(
DAG.getDataLayout(), *DAG.getContext(), VT);
@@ -35491,9 +37162,8 @@ static SDValue combineFMinNumFMaxNum(SDNode *N, SelectionDAG &DAG,
// use those instructions for fmaxnum by selecting away a NaN input.
// If either operand is NaN, the 2nd source operand (Op0) is passed through.
- auto MinMaxOp = N->getOpcode() == ISD::FMAXNUM ? X86ISD::FMAX : X86ISD::FMIN;
SDValue MinOrMax = DAG.getNode(MinMaxOp, DL, VT, Op1, Op0);
- SDValue IsOp0Nan = DAG.getSetCC(DL, SetCCType , Op0, Op0, ISD::SETUO);
+ SDValue IsOp0Nan = DAG.getSetCC(DL, SetCCType, Op0, Op0, ISD::SETUO);
// If Op0 is a NaN, select Op1. Otherwise, select the max. If both operands
// are NaN, the NaN value of Op1 is the result.
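// A minimal scalar sketch (ours, not part of this change) of the x86 semantics
// the operand order above relies on: MINSS/MAXSS and their packed forms return
// the second source operand whenever the comparison is unordered, so placing
// Op0 second passes it through when either input is a NaN.
static float x86MinssModel(float SrcA, float SrcB) {
  return SrcA < SrcB ? SrcA : SrcB;  // A NaN compares false, so SrcB wins.
}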
@@ -35519,10 +37189,8 @@ static SDValue combineAndnp(SDNode *N, SelectionDAG &DAG,
SDValue Op(N, 0);
if (SDValue Res = combineX86ShufflesRecursively(
{Op}, 0, Op, {0}, {}, /*Depth*/ 1,
- /*HasVarMask*/ false, DAG, DCI, Subtarget)) {
- DCI.CombineTo(N, Res);
- return SDValue();
- }
+ /*HasVarMask*/ false, DAG, Subtarget))
+ return Res;
}
return SDValue();
@@ -35542,12 +37210,54 @@ static SDValue combineBT(SDNode *N, SelectionDAG &DAG,
return SDValue();
}
-static SDValue combineSignExtendInReg(SDNode *N, SelectionDAG &DAG,
- const X86Subtarget &Subtarget) {
+// Try to combine sext_in_reg of a cmov of constants by extending the constants.
+static SDValue combineSextInRegCmov(SDNode *N, SelectionDAG &DAG) {
EVT VT = N->getValueType(0);
- if (!VT.isVector())
+
+ SDValue N0 = N->getOperand(0);
+ SDValue N1 = N->getOperand(1);
+ EVT ExtraVT = cast<VTSDNode>(N1)->getVT();
+
+ if (ExtraVT != MVT::i16)
return SDValue();
+ // Look through single use any_extends.
+ if (N0.getOpcode() == ISD::ANY_EXTEND && N0.hasOneUse())
+ N0 = N0.getOperand(0);
+
+ // See if we have a single use cmov.
+ if (N0.getOpcode() != X86ISD::CMOV || !N0.hasOneUse())
+ return SDValue();
+
+ SDValue CMovOp0 = N0.getOperand(0);
+ SDValue CMovOp1 = N0.getOperand(1);
+
+ // Make sure both operands are constants.
+ if (!isa<ConstantSDNode>(CMovOp0.getNode()) ||
+ !isa<ConstantSDNode>(CMovOp1.getNode()))
+ return SDValue();
+
+ SDLoc DL(N);
+
+ // If we looked through an any_extend above, add the extend back to the constants.
+ if (N0.getValueType() != VT) {
+ CMovOp0 = DAG.getNode(ISD::ANY_EXTEND, DL, VT, CMovOp0);
+ CMovOp1 = DAG.getNode(ISD::ANY_EXTEND, DL, VT, CMovOp1);
+ }
+
+ CMovOp0 = DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, VT, CMovOp0, N1);
+ CMovOp1 = DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, VT, CMovOp1, N1);
+
+ return DAG.getNode(X86ISD::CMOV, DL, VT, CMovOp0, CMovOp1,
+ N0.getOperand(2), N0.getOperand(3));
+}
+
+static SDValue combineSignExtendInReg(SDNode *N, SelectionDAG &DAG,
+ const X86Subtarget &Subtarget) {
+ if (SDValue V = combineSextInRegCmov(N, DAG))
+ return V;
+
+ EVT VT = N->getValueType(0);
SDValue N0 = N->getOperand(0);
SDValue N1 = N->getOperand(1);
EVT ExtraVT = cast<VTSDNode>(N1)->getVT();
@@ -35686,7 +37396,7 @@ static SDValue getDivRem8(SDNode *N, SelectionDAG &DAG) {
// promotion).
static SDValue combineToExtendCMOV(SDNode *Extend, SelectionDAG &DAG) {
SDValue CMovN = Extend->getOperand(0);
- if (CMovN.getOpcode() != X86ISD::CMOV)
+ if (CMovN.getOpcode() != X86ISD::CMOV || !CMovN.hasOneUse())
return SDValue();
EVT TargetVT = Extend->getValueType(0);
@@ -35697,20 +37407,36 @@ static SDValue combineToExtendCMOV(SDNode *Extend, SelectionDAG &DAG) {
SDValue CMovOp0 = CMovN.getOperand(0);
SDValue CMovOp1 = CMovN.getOperand(1);
- bool DoPromoteCMOV =
- (VT == MVT::i16 && (TargetVT == MVT::i32 || TargetVT == MVT::i64)) &&
- CMovN.hasOneUse() &&
- (isa<ConstantSDNode>(CMovOp0.getNode()) &&
- isa<ConstantSDNode>(CMovOp1.getNode()));
+ if (!isa<ConstantSDNode>(CMovOp0.getNode()) ||
+ !isa<ConstantSDNode>(CMovOp1.getNode()))
+ return SDValue();
- if (!DoPromoteCMOV)
+ // Only extend to i32 or i64.
+ if (TargetVT != MVT::i32 && TargetVT != MVT::i64)
return SDValue();
- CMovOp0 = DAG.getNode(ExtendOpcode, DL, TargetVT, CMovOp0);
- CMovOp1 = DAG.getNode(ExtendOpcode, DL, TargetVT, CMovOp1);
+ // Only extend from i16 unless it's a sign_extend from i32. Zext/aext from i32
+ // are free.
+ if (VT != MVT::i16 && !(ExtendOpcode == ISD::SIGN_EXTEND && VT == MVT::i32))
+ return SDValue();
- return DAG.getNode(X86ISD::CMOV, DL, TargetVT, CMovOp0, CMovOp1,
- CMovN.getOperand(2), CMovN.getOperand(3));
+ // If this is a zero extend to i64, we should only extend to i32 and use a free
+ // zero extend to finish.
+ EVT ExtendVT = TargetVT;
+ if (TargetVT == MVT::i64 && ExtendOpcode != ISD::SIGN_EXTEND)
+ ExtendVT = MVT::i32;
+
+ CMovOp0 = DAG.getNode(ExtendOpcode, DL, ExtendVT, CMovOp0);
+ CMovOp1 = DAG.getNode(ExtendOpcode, DL, ExtendVT, CMovOp1);
+
+ SDValue Res = DAG.getNode(X86ISD::CMOV, DL, ExtendVT, CMovOp0, CMovOp1,
+ CMovN.getOperand(2), CMovN.getOperand(3));
+
+ // Finish extending if needed.
+ if (ExtendVT != TargetVT)
+ Res = DAG.getNode(ExtendOpcode, DL, TargetVT, Res);
+
+ return Res;
}
// Convert (vXiY *ext(vXi1 bitcast(iX))) to extend_in_reg(broadcast(iX)).
@@ -35866,7 +37592,7 @@ static SDValue combineToExtendVectorInReg(SDNode *N, SelectionDAG &DAG,
// Also use this if we don't have SSE41 to allow the legalizer do its job.
if (!Subtarget.hasSSE41() || VT.is128BitVector() ||
(VT.is256BitVector() && Subtarget.hasInt256()) ||
- (VT.is512BitVector() && Subtarget.hasAVX512())) {
+ (VT.is512BitVector() && Subtarget.useAVX512Regs())) {
SDValue ExOp = ExtendVecSize(DL, N0, VT.getSizeInBits());
return Opcode == ISD::SIGN_EXTEND
? DAG.getSignExtendVectorInReg(ExOp, DL, VT)
@@ -35899,12 +37625,55 @@ static SDValue combineToExtendVectorInReg(SDNode *N, SelectionDAG &DAG,
// On pre-AVX512 targets, split into 256-bit nodes of
// ISD::*_EXTEND_VECTOR_INREG.
- if (!Subtarget.hasAVX512() && !(VT.getSizeInBits() % 256))
+ if (!Subtarget.useAVX512Regs() && !(VT.getSizeInBits() % 256))
return SplitAndExtendInReg(256);
return SDValue();
}
+// Attempt to combine a (sext/zext (setcc)) to a setcc with a xmm/ymm/zmm
+// result type.
+static SDValue combineExtSetcc(SDNode *N, SelectionDAG &DAG,
+ const X86Subtarget &Subtarget) {
+ SDValue N0 = N->getOperand(0);
+ EVT VT = N->getValueType(0);
+ SDLoc dl(N);
+
+ // Only do this combine with AVX512 for vector extends.
+ if (!Subtarget.hasAVX512() || !VT.isVector() || N0->getOpcode() != ISD::SETCC)
+ return SDValue();
+
+ // Only combine legal element types.
+ EVT SVT = VT.getVectorElementType();
+ if (SVT != MVT::i8 && SVT != MVT::i16 && SVT != MVT::i32 &&
+ SVT != MVT::i64 && SVT != MVT::f32 && SVT != MVT::f64)
+ return SDValue();
+
+ // We can only do this if the vector size is 256 bits or less.
+ unsigned Size = VT.getSizeInBits();
+ if (Size > 256)
+ return SDValue();
+
+ // Don't fold if the condition code can't be handled by PCMPEQ/PCMPGT since
+ // those are the only integer compares we have.
+ ISD::CondCode CC = cast<CondCodeSDNode>(N0->getOperand(2))->get();
+ if (ISD::isUnsignedIntSetCC(CC))
+ return SDValue();
+
+ // Only do this combine if the extension will be fully consumed by the setcc.
+ EVT N00VT = N0.getOperand(0).getValueType();
+ EVT MatchingVecType = N00VT.changeVectorElementTypeToInteger();
+ if (Size != MatchingVecType.getSizeInBits())
+ return SDValue();
+
+ SDValue Res = DAG.getSetCC(dl, VT, N0.getOperand(0), N0.getOperand(1), CC);
+
+ if (N->getOpcode() == ISD::ZERO_EXTEND)
+ Res = DAG.getZeroExtendInReg(Res, dl, N0.getValueType().getScalarType());
+
+ return Res;
+}
+
static SDValue combineSext(SDNode *N, SelectionDAG &DAG,
TargetLowering::DAGCombinerInfo &DCI,
const X86Subtarget &Subtarget) {
@@ -35922,6 +37691,9 @@ static SDValue combineSext(SDNode *N, SelectionDAG &DAG,
if (!DCI.isBeforeLegalizeOps())
return SDValue();
+ if (SDValue V = combineExtSetcc(N, DAG, Subtarget))
+ return V;
+
if (InVT == MVT::i1 && N0.getOpcode() == ISD::XOR &&
isAllOnesConstant(N0.getOperand(1)) && N0.hasOneUse()) {
// Invert and sign-extend a boolean is the same as zero-extend and subtract
@@ -35939,7 +37711,7 @@ static SDValue combineSext(SDNode *N, SelectionDAG &DAG,
return V;
if (VT.isVector())
- if (SDValue R = WidenMaskArithmetic(N, DAG, DCI, Subtarget))
+ if (SDValue R = WidenMaskArithmetic(N, DAG, Subtarget))
return R;
if (SDValue NewAdd = promoteExtBeforeAdd(N, DAG, Subtarget))
@@ -35948,9 +37720,40 @@ static SDValue combineSext(SDNode *N, SelectionDAG &DAG,
return SDValue();
}
+static unsigned negateFMAOpcode(unsigned Opcode, bool NegMul, bool NegAcc) {
+ if (NegMul) {
+ switch (Opcode) {
+ default: llvm_unreachable("Unexpected opcode");
+ case ISD::FMA: Opcode = X86ISD::FNMADD; break;
+ case X86ISD::FMADD_RND: Opcode = X86ISD::FNMADD_RND; break;
+ case X86ISD::FMSUB: Opcode = X86ISD::FNMSUB; break;
+ case X86ISD::FMSUB_RND: Opcode = X86ISD::FNMSUB_RND; break;
+ case X86ISD::FNMADD: Opcode = ISD::FMA; break;
+ case X86ISD::FNMADD_RND: Opcode = X86ISD::FMADD_RND; break;
+ case X86ISD::FNMSUB: Opcode = X86ISD::FMSUB; break;
+ case X86ISD::FNMSUB_RND: Opcode = X86ISD::FMSUB_RND; break;
+ }
+ }
+
+ if (NegAcc) {
+ switch (Opcode) {
+ default: llvm_unreachable("Unexpected opcode");
+ case ISD::FMA: Opcode = X86ISD::FMSUB; break;
+ case X86ISD::FMADD_RND: Opcode = X86ISD::FMSUB_RND; break;
+ case X86ISD::FMSUB: Opcode = ISD::FMA; break;
+ case X86ISD::FMSUB_RND: Opcode = X86ISD::FMADD_RND; break;
+ case X86ISD::FNMADD: Opcode = X86ISD::FNMSUB; break;
+ case X86ISD::FNMADD_RND: Opcode = X86ISD::FNMSUB_RND; break;
+ case X86ISD::FNMSUB: Opcode = X86ISD::FNMADD; break;
+ case X86ISD::FNMSUB_RND: Opcode = X86ISD::FNMADD_RND; break;
+ }
+ }
+
+ return Opcode;
+}
+
static SDValue combineFMA(SDNode *N, SelectionDAG &DAG,
const X86Subtarget &Subtarget) {
- // TODO: Handle FMSUB/FNMADD/FNMSUB as the starting opcode.
SDLoc dl(N);
EVT VT = N->getValueType(0);
@@ -35966,96 +37769,41 @@ static SDValue combineFMA(SDNode *N, SelectionDAG &DAG,
SDValue B = N->getOperand(1);
SDValue C = N->getOperand(2);
- auto invertIfNegative = [](SDValue &V) {
+ auto invertIfNegative = [&DAG](SDValue &V) {
if (SDValue NegVal = isFNEG(V.getNode())) {
- V = NegVal;
+ V = DAG.getBitcast(V.getValueType(), NegVal);
return true;
}
+ // Look through extract_vector_elts. If it comes from an FNEG, create a
+ // new extract from the FNEG input.
+ if (V.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
+ isa<ConstantSDNode>(V.getOperand(1)) &&
+ cast<ConstantSDNode>(V.getOperand(1))->getZExtValue() == 0) {
+ if (SDValue NegVal = isFNEG(V.getOperand(0).getNode())) {
+ NegVal = DAG.getBitcast(V.getOperand(0).getValueType(), NegVal);
+ V = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SDLoc(V), V.getValueType(),
+ NegVal, V.getOperand(1));
+ return true;
+ }
+ }
+
return false;
};
// Do not convert the passthru input of scalar intrinsics.
// FIXME: We could allow negations of the lower element only.
- bool NegA = N->getOpcode() != X86ISD::FMADDS1 &&
- N->getOpcode() != X86ISD::FMADDS1_RND && invertIfNegative(A);
+ bool NegA = invertIfNegative(A);
bool NegB = invertIfNegative(B);
- bool NegC = N->getOpcode() != X86ISD::FMADDS3 &&
- N->getOpcode() != X86ISD::FMADDS3_RND && invertIfNegative(C);
+ bool NegC = invertIfNegative(C);
- // Negative multiplication when NegA xor NegB
- bool NegMul = (NegA != NegB);
- bool HasNeg = NegA || NegB || NegC;
+ if (!NegA && !NegB && !NegC)
+ return SDValue();
- unsigned NewOpcode;
- if (!NegMul)
- NewOpcode = (!NegC) ? unsigned(ISD::FMA) : unsigned(X86ISD::FMSUB);
- else
- NewOpcode = (!NegC) ? X86ISD::FNMADD : X86ISD::FNMSUB;
-
- // For FMA, we risk reconstructing the node we started with.
- // In order to avoid this, we check for negation or opcode change. If
- // one of the two happened, then it is a new node and we return it.
- if (N->getOpcode() == ISD::FMA) {
- if (HasNeg || NewOpcode != N->getOpcode())
- return DAG.getNode(NewOpcode, dl, VT, A, B, C);
- return SDValue();
- }
-
- if (N->getOpcode() == X86ISD::FMADD_RND) {
- switch (NewOpcode) {
- case ISD::FMA: NewOpcode = X86ISD::FMADD_RND; break;
- case X86ISD::FMSUB: NewOpcode = X86ISD::FMSUB_RND; break;
- case X86ISD::FNMADD: NewOpcode = X86ISD::FNMADD_RND; break;
- case X86ISD::FNMSUB: NewOpcode = X86ISD::FNMSUB_RND; break;
- }
- } else if (N->getOpcode() == X86ISD::FMADDS1) {
- switch (NewOpcode) {
- case ISD::FMA: NewOpcode = X86ISD::FMADDS1; break;
- case X86ISD::FMSUB: NewOpcode = X86ISD::FMSUBS1; break;
- case X86ISD::FNMADD: NewOpcode = X86ISD::FNMADDS1; break;
- case X86ISD::FNMSUB: NewOpcode = X86ISD::FNMSUBS1; break;
- }
- } else if (N->getOpcode() == X86ISD::FMADDS3) {
- switch (NewOpcode) {
- case ISD::FMA: NewOpcode = X86ISD::FMADDS3; break;
- case X86ISD::FMSUB: NewOpcode = X86ISD::FMSUBS3; break;
- case X86ISD::FNMADD: NewOpcode = X86ISD::FNMADDS3; break;
- case X86ISD::FNMSUB: NewOpcode = X86ISD::FNMSUBS3; break;
- }
- } else if (N->getOpcode() == X86ISD::FMADDS1_RND) {
- switch (NewOpcode) {
- case ISD::FMA: NewOpcode = X86ISD::FMADDS1_RND; break;
- case X86ISD::FMSUB: NewOpcode = X86ISD::FMSUBS1_RND; break;
- case X86ISD::FNMADD: NewOpcode = X86ISD::FNMADDS1_RND; break;
- case X86ISD::FNMSUB: NewOpcode = X86ISD::FNMSUBS1_RND; break;
- }
- } else if (N->getOpcode() == X86ISD::FMADDS3_RND) {
- switch (NewOpcode) {
- case ISD::FMA: NewOpcode = X86ISD::FMADDS3_RND; break;
- case X86ISD::FMSUB: NewOpcode = X86ISD::FMSUBS3_RND; break;
- case X86ISD::FNMADD: NewOpcode = X86ISD::FNMADDS3_RND; break;
- case X86ISD::FNMSUB: NewOpcode = X86ISD::FNMSUBS3_RND; break;
- }
- } else if (N->getOpcode() == X86ISD::FMADD4S) {
- switch (NewOpcode) {
- case ISD::FMA: NewOpcode = X86ISD::FMADD4S; break;
- case X86ISD::FMSUB: NewOpcode = X86ISD::FMSUB4S; break;
- case X86ISD::FNMADD: NewOpcode = X86ISD::FNMADD4S; break;
- case X86ISD::FNMSUB: NewOpcode = X86ISD::FNMSUB4S; break;
- }
- } else {
- llvm_unreachable("Unexpected opcode!");
- }
+ unsigned NewOpcode = negateFMAOpcode(N->getOpcode(), NegA != NegB, NegC);
- // Only return the node is the opcode was changed or one of the
- // operand was negated. If not, we'll just recreate the same node.
- if (HasNeg || NewOpcode != N->getOpcode()) {
- if (N->getNumOperands() == 4)
- return DAG.getNode(NewOpcode, dl, VT, A, B, C, N->getOperand(3));
- return DAG.getNode(NewOpcode, dl, VT, A, B, C);
- }
-
- return SDValue();
+ if (N->getNumOperands() == 4)
+ return DAG.getNode(NewOpcode, dl, VT, A, B, C, N->getOperand(3));
+ return DAG.getNode(NewOpcode, dl, VT, A, B, C);
}
// Combine FMADDSUB(A, B, FNEG(C)) -> FMSUBADD(A, B, C)
@@ -36124,6 +37872,10 @@ static SDValue combineZext(SDNode *N, SelectionDAG &DAG,
if (SDValue NewCMov = combineToExtendCMOV(N, DAG))
return NewCMov;
+ if (DCI.isBeforeLegalizeOps())
+ if (SDValue V = combineExtSetcc(N, DAG, Subtarget))
+ return V;
+
if (SDValue V = combineToExtendVectorInReg(N, DAG, DCI, Subtarget))
return V;
@@ -36131,7 +37883,7 @@ static SDValue combineZext(SDNode *N, SelectionDAG &DAG,
return V;
if (VT.isVector())
- if (SDValue R = WidenMaskArithmetic(N, DAG, DCI, Subtarget))
+ if (SDValue R = WidenMaskArithmetic(N, DAG, Subtarget))
return R;
if (SDValue DivRem8 = getDivRem8(N, DAG))
@@ -36153,13 +37905,23 @@ static SDValue combineVectorSizedSetCCEquality(SDNode *SetCC, SelectionDAG &DAG,
ISD::CondCode CC = cast<CondCodeSDNode>(SetCC->getOperand(2))->get();
assert((CC == ISD::SETNE || CC == ISD::SETEQ) && "Bad comparison predicate");
- // We're looking for an oversized integer equality comparison, but ignore a
- // comparison with zero because that gets special treatment in EmitTest().
+ // We're looking for an oversized integer equality comparison.
SDValue X = SetCC->getOperand(0);
SDValue Y = SetCC->getOperand(1);
EVT OpVT = X.getValueType();
unsigned OpSize = OpVT.getSizeInBits();
- if (!OpVT.isScalarInteger() || OpSize < 128 || isNullConstant(Y))
+ if (!OpVT.isScalarInteger() || OpSize < 128)
+ return SDValue();
+
+ // Ignore a comparison with zero because that gets special treatment in
+ // EmitTest(). But make an exception for the special case of a pair of
+ // logically-combined vector-sized operands compared to zero. This pattern may
+ // be generated by the memcmp expansion pass with oversized integer compares
+ // (see PR33325).
+ bool IsOrXorXorCCZero = isNullConstant(Y) && X.getOpcode() == ISD::OR &&
+ X.getOperand(0).getOpcode() == ISD::XOR &&
+ X.getOperand(1).getOpcode() == ISD::XOR;
+ if (isNullConstant(Y) && !IsOrXorXorCCZero)
return SDValue();
// Bail out if we know that this is not really just an oversized integer.
@@ -36174,15 +37936,29 @@ static SDValue combineVectorSizedSetCCEquality(SDNode *SetCC, SelectionDAG &DAG,
if ((OpSize == 128 && Subtarget.hasSSE2()) ||
(OpSize == 256 && Subtarget.hasAVX2())) {
EVT VecVT = OpSize == 128 ? MVT::v16i8 : MVT::v32i8;
- SDValue VecX = DAG.getBitcast(VecVT, X);
- SDValue VecY = DAG.getBitcast(VecVT, Y);
-
+ SDValue Cmp;
+ if (IsOrXorXorCCZero) {
+ // This is a bitwise-combined equality comparison of 2 pairs of vectors:
+ // setcc i128 (or (xor A, B), (xor C, D)), 0, eq|ne
+ // Use 2 vector equality compares and 'and' the results before doing a
+ // MOVMSK.
+ SDValue A = DAG.getBitcast(VecVT, X.getOperand(0).getOperand(0));
+ SDValue B = DAG.getBitcast(VecVT, X.getOperand(0).getOperand(1));
+ SDValue C = DAG.getBitcast(VecVT, X.getOperand(1).getOperand(0));
+ SDValue D = DAG.getBitcast(VecVT, X.getOperand(1).getOperand(1));
+ SDValue Cmp1 = DAG.getSetCC(DL, VecVT, A, B, ISD::SETEQ);
+ SDValue Cmp2 = DAG.getSetCC(DL, VecVT, C, D, ISD::SETEQ);
+ Cmp = DAG.getNode(ISD::AND, DL, VecVT, Cmp1, Cmp2);
+ } else {
+ SDValue VecX = DAG.getBitcast(VecVT, X);
+ SDValue VecY = DAG.getBitcast(VecVT, Y);
+ Cmp = DAG.getSetCC(DL, VecVT, VecX, VecY, ISD::SETEQ);
+ }
// If all bytes match (bitmask is 0x(FFFF)FFFF), that's equality.
// setcc i128 X, Y, eq --> setcc (pmovmskb (pcmpeqb X, Y)), 0xFFFF, eq
// setcc i128 X, Y, ne --> setcc (pmovmskb (pcmpeqb X, Y)), 0xFFFF, ne
// setcc i256 X, Y, eq --> setcc (vpmovmskb (vpcmpeqb X, Y)), 0xFFFFFFFF, eq
// setcc i256 X, Y, ne --> setcc (vpmovmskb (vpcmpeqb X, Y)), 0xFFFFFFFF, ne
- SDValue Cmp = DAG.getNode(X86ISD::PCMPEQ, DL, VecVT, VecX, VecY);
SDValue MovMsk = DAG.getNode(X86ISD::MOVMSK, DL, MVT::i32, Cmp);
SDValue FFFFs = DAG.getConstant(OpSize == 128 ? 0xFFFF : 0xFFFFFFFF, DL,
MVT::i32);
@@ -36198,10 +37974,10 @@ static SDValue combineSetCC(SDNode *N, SelectionDAG &DAG,
SDValue LHS = N->getOperand(0);
SDValue RHS = N->getOperand(1);
EVT VT = N->getValueType(0);
+ EVT OpVT = LHS.getValueType();
SDLoc DL(N);
if (CC == ISD::SETNE || CC == ISD::SETEQ) {
- EVT OpVT = LHS.getValueType();
// 0-x == y --> x+y == 0
// 0-x != y --> x+y != 0
if (LHS.getOpcode() == ISD::SUB && isNullConstant(LHS.getOperand(0)) &&
@@ -36250,6 +38026,20 @@ static SDValue combineSetCC(SDNode *N, SelectionDAG &DAG,
}
}
+ // If we have AVX512 but not BWI, and this is a vXi16/vXi8 setcc, just
+ // pre-promote its result type since vXi1 vectors don't get promoted
+ // during type legalization.
+ // NOTE: The element count check is to ignore operand types that need to
+ // go through type promotion to a 128-bit vector.
+ if (Subtarget.hasAVX512() && !Subtarget.hasBWI() && VT.isVector() &&
+ VT.getVectorElementType() == MVT::i1 && VT.getVectorNumElements() > 4 &&
+ (OpVT.getVectorElementType() == MVT::i8 ||
+ OpVT.getVectorElementType() == MVT::i16)) {
+ SDValue Setcc = DAG.getNode(ISD::SETCC, DL, OpVT, LHS, RHS,
+ N->getOperand(2));
+ return DAG.getNode(ISD::TRUNCATE, DL, VT, Setcc);
+ }
+
// For an SSE1-only target, lower a comparison of v4f32 to X86ISD::CMPP early
// to avoid scalarization via legalization because v4i32 is not a legal type.
if (Subtarget.hasSSE1() && !Subtarget.hasSSE2() && VT == MVT::v4i32 &&
@@ -36264,6 +38054,19 @@ static SDValue combineMOVMSK(SDNode *N, SelectionDAG &DAG,
SDValue Src = N->getOperand(0);
MVT SrcVT = Src.getSimpleValueType();
+ // Perform constant folding.
+ if (ISD::isBuildVectorOfConstantSDNodes(Src.getNode())) {
+ assert(N->getValueType(0) == MVT::i32 && "Unexpected result type");
+ APInt Imm(32, 0);
+ for (unsigned Idx = 0, e = Src.getNumOperands(); Idx < e; ++Idx) {
+ SDValue In = Src.getOperand(Idx);
+ if (!In.isUndef() &&
+ cast<ConstantSDNode>(In)->getAPIntValue().isNegative())
+ Imm.setBit(Idx);
+ }
+ return DAG.getConstant(Imm, SDLoc(N), N->getValueType(0));
+ }
+
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
TargetLowering::TargetLoweringOpt TLO(DAG, !DCI.isBeforeLegalize(),
!DCI.isBeforeLegalizeOps());
@@ -36295,12 +38098,14 @@ static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
Index.getOperand(0).getScalarValueSizeInBits() <= 32) {
SmallVector<SDValue, 5> NewOps(N->op_begin(), N->op_end());
NewOps[4] = Index.getOperand(0);
- DAG.UpdateNodeOperands(N, NewOps);
- // The original sign extend has less users, add back to worklist in case
- // it needs to be removed
- DCI.AddToWorklist(Index.getNode());
- DCI.AddToWorklist(N);
- return SDValue(N, 0);
+ SDNode *Res = DAG.UpdateNodeOperands(N, NewOps);
+ if (Res == N) {
+ // The original sign extend has fewer users; add it back to the worklist in
+ // case it needs to be removed.
+ DCI.AddToWorklist(Index.getNode());
+ DCI.AddToWorklist(N);
+ }
+ return SDValue(Res, 0);
}
}
@@ -36313,9 +38118,10 @@ static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
Index = DAG.getSExtOrTrunc(Index, DL, IndexVT);
SmallVector<SDValue, 5> NewOps(N->op_begin(), N->op_end());
NewOps[4] = Index;
- DAG.UpdateNodeOperands(N, NewOps);
- DCI.AddToWorklist(N);
- return SDValue(N, 0);
+ SDNode *Res = DAG.UpdateNodeOperands(N, NewOps);
+ if (Res == N)
+ DCI.AddToWorklist(N);
+ return SDValue(Res, 0);
}
// Try to remove zero extends from 32->64 if we know the sign bit of
@@ -36326,32 +38132,24 @@ static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
if (DAG.SignBitIsZero(Index.getOperand(0))) {
SmallVector<SDValue, 5> NewOps(N->op_begin(), N->op_end());
NewOps[4] = Index.getOperand(0);
- DAG.UpdateNodeOperands(N, NewOps);
- // The original zero extend has less users, add back to worklist in case
- // it needs to be removed
- DCI.AddToWorklist(Index.getNode());
- DCI.AddToWorklist(N);
- return SDValue(N, 0);
+ SDNode *Res = DAG.UpdateNodeOperands(N, NewOps);
+ if (Res == N) {
+ // The original zero extend has fewer users; add it back to the worklist in
+ // case it needs to be removed.
+ DCI.AddToWorklist(Index.getNode());
+ DCI.AddToWorklist(N);
+ }
+ return SDValue(Res, 0);
}
}
}
- // Gather and Scatter instructions use k-registers for masks. The type of
- // the masks is v*i1. So the mask will be truncated anyway.
- // The SIGN_EXTEND_INREG my be dropped.
- SDValue Mask = N->getOperand(2);
- if (Subtarget.hasAVX512() && Mask.getOpcode() == ISD::SIGN_EXTEND_INREG) {
- SmallVector<SDValue, 5> NewOps(N->op_begin(), N->op_end());
- NewOps[2] = Mask.getOperand(0);
- DAG.UpdateNodeOperands(N, NewOps);
- return SDValue(N, 0);
- }
-
// With AVX2 we only demand the upper bit of the mask.
if (!Subtarget.hasAVX512()) {
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
TargetLowering::TargetLoweringOpt TLO(DAG, !DCI.isBeforeLegalize(),
!DCI.isBeforeLegalizeOps());
+ SDValue Mask = N->getOperand(2);
KnownBits Known;
APInt DemandedMask(APInt::getSignMask(Mask.getScalarValueSizeInBits()));
if (TLI.SimplifyDemandedBits(Mask, DemandedMask, Known, TLO)) {
@@ -36448,11 +38246,11 @@ static SDValue combineUIntToFP(SDNode *N, SelectionDAG &DAG,
SDValue Op0 = N->getOperand(0);
EVT VT = N->getValueType(0);
EVT InVT = Op0.getValueType();
- EVT InSVT = InVT.getScalarType();
+ // UINT_TO_FP(vXi1) -> SINT_TO_FP(SEXT(vXi1 to vXi32))
// UINT_TO_FP(vXi8) -> SINT_TO_FP(ZEXT(vXi8 to vXi32))
// UINT_TO_FP(vXi16) -> SINT_TO_FP(ZEXT(vXi16 to vXi32))
- if (InVT.isVector() && (InSVT == MVT::i8 || InSVT == MVT::i16)) {
+ if (InVT.isVector() && InVT.getScalarSizeInBits() < 32) {
SDLoc dl(N);
EVT DstVT = EVT::getVectorVT(*DAG.getContext(), MVT::i32,
InVT.getVectorNumElements());
@@ -36482,14 +38280,11 @@ static SDValue combineSIntToFP(SDNode *N, SelectionDAG &DAG,
SDValue Op0 = N->getOperand(0);
EVT VT = N->getValueType(0);
EVT InVT = Op0.getValueType();
- EVT InSVT = InVT.getScalarType();
// SINT_TO_FP(vXi1) -> SINT_TO_FP(SEXT(vXi1 to vXi32))
// SINT_TO_FP(vXi8) -> SINT_TO_FP(SEXT(vXi8 to vXi32))
// SINT_TO_FP(vXi16) -> SINT_TO_FP(SEXT(vXi16 to vXi32))
- if (InVT.isVector() &&
- (InSVT == MVT::i8 || InSVT == MVT::i16 ||
- (InSVT == MVT::i1 && !DAG.getTargetLoweringInfo().isTypeLegal(InVT)))) {
+ if (InVT.isVector() && InVT.getScalarSizeInBits() < 32) {
SDLoc dl(N);
EVT DstVT = EVT::getVectorVT(*DAG.getContext(), MVT::i32,
InVT.getVectorNumElements());
@@ -36524,6 +38319,11 @@ static SDValue combineSIntToFP(SDNode *N, SelectionDAG &DAG,
if (VT == MVT::f16 || VT == MVT::f128)
return SDValue();
+ // If we have AVX512DQ we can use packed conversion instructions unless
+ // the VT is f80.
+ if (Subtarget.hasDQI() && VT != MVT::f80)
+ return SDValue();
+
if (!Ld->isVolatile() && !VT.isVector() &&
ISD::isNON_EXTLoad(Op0.getNode()) && Op0.hasOneUse() &&
!Subtarget.is64Bit() && LdVT == MVT::i64) {
@@ -36778,15 +38578,9 @@ static SDValue combineLoopMAddPattern(SDNode *N, SelectionDAG &DAG,
EVT VT = N->getValueType(0);
- unsigned RegSize = 128;
- if (Subtarget.hasBWI())
- RegSize = 512;
- else if (Subtarget.hasAVX2())
- RegSize = 256;
- unsigned VectorSize = VT.getVectorNumElements() * 16;
// If the vector size is less than 128, or greater than the supported RegSize,
// do not use PMADD.
- if (VectorSize < 128 || VectorSize > RegSize)
+ if (VT.getVectorNumElements() < 8)
return SDValue();
SDLoc DL(N);
@@ -36800,7 +38594,13 @@ static SDValue combineLoopMAddPattern(SDNode *N, SelectionDAG &DAG,
SDValue N1 = DAG.getNode(ISD::TRUNCATE, DL, ReducedVT, MulOp->getOperand(1));
// Madd vector size is half of the original vector size
- SDValue Madd = DAG.getNode(X86ISD::VPMADDWD, DL, MAddVT, N0, N1);
+ auto PMADDWDBuilder = [](SelectionDAG &DAG, const SDLoc &DL,
+ ArrayRef<SDValue> Ops) {
+ MVT VT = MVT::getVectorVT(MVT::i32, Ops[0].getValueSizeInBits() / 32);
+ return DAG.getNode(X86ISD::VPMADDWD, DL, VT, Ops);
+ };
+ SDValue Madd = SplitOpsAndApply(DAG, Subtarget, DL, MAddVT, { N0, N1 },
+ PMADDWDBuilder);
// Fill the rest of the output with 0
SDValue Zero = getZeroVector(Madd.getSimpleValueType(), Subtarget, DAG, DL);
SDValue Concat = DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, Madd, Zero);
@@ -36824,12 +38624,12 @@ static SDValue combineLoopSADPattern(SDNode *N, SelectionDAG &DAG,
return SDValue();
unsigned RegSize = 128;
- if (Subtarget.hasBWI())
+ if (Subtarget.useBWIRegs())
RegSize = 512;
- else if (Subtarget.hasAVX2())
+ else if (Subtarget.hasAVX())
RegSize = 256;
- // We only handle v16i32 for SSE2 / v32i32 for AVX2 / v64i32 for AVX512.
+ // We only handle v16i32 for SSE2 / v32i32 for AVX / v64i32 for AVX512.
// TODO: We should be able to handle larger vectors by splitting them before
// feeding them into several SADs, and then reducing over those.
if (VT.getSizeInBits() / 4 > RegSize)
@@ -36855,7 +38655,7 @@ static SDValue combineLoopSADPattern(SDNode *N, SelectionDAG &DAG,
// reduction. Note that the number of elements of the result of SAD is less
// than the number of elements of its input. Therefore, we could only update
// part of elements in the reduction vector.
- SDValue Sad = createPSADBW(DAG, Op0, Op1, DL);
+ SDValue Sad = createPSADBW(DAG, Op0, Op1, DL, Subtarget);
// The output of PSADBW is a vector of i64.
// We need to turn the vector of i64 into a vector of i32.
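// A minimal scalar sketch (ours, not part of this change, assuming <cstdint>)
// of one 64-bit lane of PSADBW: eight absolute byte differences are summed
// into the low bits of the lane, which is why the result is typed as vXi64 and
// then reinterpreted as i32 elements.
static uint64_t psadbwLane(const uint8_t (&A)[8], const uint8_t (&B)[8]) {
  uint64_t Sum = 0;
  for (unsigned I = 0; I != 8; ++I)
    Sum += (uint64_t)(A[I] > B[I] ? A[I] - B[I] : B[I] - A[I]);
  return Sum;  // At most 8 * 255, so the upper bits of the lane are zero.
}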
@@ -36905,6 +38705,236 @@ static SDValue combineIncDecVector(SDNode *N, SelectionDAG &DAG) {
return DAG.getNode(NewOpcode, SDLoc(N), VT, N->getOperand(0), AllOnesVec);
}
+static SDValue matchPMADDWD(SelectionDAG &DAG, SDValue Op0, SDValue Op1,
+ const SDLoc &DL, EVT VT,
+ const X86Subtarget &Subtarget) {
+ // Example of pattern we try to detect:
+ // t := (v8i32 mul (sext (v8i16 x0), (sext (v8i16 x1))))
+ //(add (build_vector (extract_elt t, 0),
+ // (extract_elt t, 2),
+ // (extract_elt t, 4),
+ // (extract_elt t, 6)),
+ // (build_vector (extract_elt t, 1),
+ // (extract_elt t, 3),
+ // (extract_elt t, 5),
+ // (extract_elt t, 7)))
+
+ if (!Subtarget.hasSSE2())
+ return SDValue();
+
+ if (Op0.getOpcode() != ISD::BUILD_VECTOR ||
+ Op1.getOpcode() != ISD::BUILD_VECTOR)
+ return SDValue();
+
+ if (!VT.isVector() || VT.getVectorElementType() != MVT::i32 ||
+ VT.getVectorNumElements() < 4 ||
+ !isPowerOf2_32(VT.getVectorNumElements()))
+ return SDValue();
+
+ // Check if one of Op0,Op1 is of the form:
+ // (build_vector (extract_elt Mul, 0),
+ // (extract_elt Mul, 2),
+ // (extract_elt Mul, 4),
+ // ...
+ // the other is of the form:
+ // (build_vector (extract_elt Mul, 1),
+ // (extract_elt Mul, 3),
+ // (extract_elt Mul, 5),
+ // ...
+ // and identify Mul.
+ SDValue Mul;
+ for (unsigned i = 0, e = VT.getVectorNumElements(); i != e; i += 2) {
+ SDValue Op0L = Op0->getOperand(i), Op1L = Op1->getOperand(i),
+ Op0H = Op0->getOperand(i + 1), Op1H = Op1->getOperand(i + 1);
+ // TODO: Be more tolerant to undefs.
+ if (Op0L.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
+ Op1L.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
+ Op0H.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
+ Op1H.getOpcode() != ISD::EXTRACT_VECTOR_ELT)
+ return SDValue();
+ auto *Const0L = dyn_cast<ConstantSDNode>(Op0L->getOperand(1));
+ auto *Const1L = dyn_cast<ConstantSDNode>(Op1L->getOperand(1));
+ auto *Const0H = dyn_cast<ConstantSDNode>(Op0H->getOperand(1));
+ auto *Const1H = dyn_cast<ConstantSDNode>(Op1H->getOperand(1));
+ if (!Const0L || !Const1L || !Const0H || !Const1H)
+ return SDValue();
+ unsigned Idx0L = Const0L->getZExtValue(), Idx1L = Const1L->getZExtValue(),
+ Idx0H = Const0H->getZExtValue(), Idx1H = Const1H->getZExtValue();
+ // Commutativity of mul allows factors of a product to reorder.
+ if (Idx0L > Idx1L)
+ std::swap(Idx0L, Idx1L);
+ if (Idx0H > Idx1H)
+ std::swap(Idx0H, Idx1H);
+ // Commutativity of add allows pairs of factors to reorder.
+ if (Idx0L > Idx0H) {
+ std::swap(Idx0L, Idx0H);
+ std::swap(Idx1L, Idx1H);
+ }
+ if (Idx0L != 2 * i || Idx1L != 2 * i + 1 || Idx0H != 2 * i + 2 ||
+ Idx1H != 2 * i + 3)
+ return SDValue();
+ if (!Mul) {
+ // First time an extract_elt's source vector is visited. Must be a MUL
+ // with twice as many vector elements as the BUILD_VECTOR.
+ // Both extracts must be from the same MUL.
+ Mul = Op0L->getOperand(0);
+ if (Mul->getOpcode() != ISD::MUL ||
+ Mul.getValueType().getVectorNumElements() != 2 * e)
+ return SDValue();
+ }
+ // Check that the extract is from the same MUL previously seen.
+ if (Mul != Op0L->getOperand(0) || Mul != Op1L->getOperand(0) ||
+ Mul != Op0H->getOperand(0) || Mul != Op1H->getOperand(0))
+ return SDValue();
+ }
+
+ // Check if the Mul source can be safely shrunk.
+ ShrinkMode Mode;
+ if (!canReduceVMulWidth(Mul.getNode(), DAG, Mode) || Mode == MULU16)
+ return SDValue();
+
+ auto PMADDBuilder = [](SelectionDAG &DAG, const SDLoc &DL,
+ ArrayRef<SDValue> Ops) {
+ // Shrink by adding truncate nodes and let DAGCombine fold with the
+ // sources.
+ EVT InVT = Ops[0].getValueType();
+ assert(InVT.getScalarType() == MVT::i32 &&
+ "Unexpected scalar element type");
+ assert(InVT == Ops[1].getValueType() && "Operands' types mismatch");
+ EVT ResVT = EVT::getVectorVT(*DAG.getContext(), MVT::i32,
+ InVT.getVectorNumElements() / 2);
+ EVT TruncVT = EVT::getVectorVT(*DAG.getContext(), MVT::i16,
+ InVT.getVectorNumElements());
+ return DAG.getNode(X86ISD::VPMADDWD, DL, ResVT,
+ DAG.getNode(ISD::TRUNCATE, DL, TruncVT, Ops[0]),
+ DAG.getNode(ISD::TRUNCATE, DL, TruncVT, Ops[1]));
+ };
+ return SplitOpsAndApply(DAG, Subtarget, DL, VT,
+ { Mul.getOperand(0), Mul.getOperand(1) },
+ PMADDBuilder);
+}
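// A minimal scalar sketch (ours, not part of this change) of one i32 lane of
// the VPMADDWD node built above: adjacent signed i16 products are summed in
// pairs, which is what the even/odd extract_elt sums matched above compute.
static int32_t pmaddwdElt(int16_t A0, int16_t A1, int16_t B0, int16_t B1) {
  int64_t Sum = (int64_t)A0 * B0 + (int64_t)A1 * B1;
  return (int32_t)Sum;  // Wraps in the single corner case that overflows i32.
}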
+
+// Attempt to turn this pattern into PMADDWD.
+// (mul (add (zext (build_vector)), (zext (build_vector))),
+// (add (zext (build_vector)), (zext (build_vector)))
+static SDValue matchPMADDWD_2(SelectionDAG &DAG, SDValue N0, SDValue N1,
+ const SDLoc &DL, EVT VT,
+ const X86Subtarget &Subtarget) {
+ if (!Subtarget.hasSSE2())
+ return SDValue();
+
+ if (N0.getOpcode() != ISD::MUL || N1.getOpcode() != ISD::MUL)
+ return SDValue();
+
+ if (!VT.isVector() || VT.getVectorElementType() != MVT::i32 ||
+ VT.getVectorNumElements() < 4 ||
+ !isPowerOf2_32(VT.getVectorNumElements()))
+ return SDValue();
+
+ SDValue N00 = N0.getOperand(0);
+ SDValue N01 = N0.getOperand(1);
+ SDValue N10 = N1.getOperand(0);
+ SDValue N11 = N1.getOperand(1);
+
+ // All inputs need to be sign extends.
+ // TODO: Support ZERO_EXTEND from known positive?
+ if (N00.getOpcode() != ISD::SIGN_EXTEND ||
+ N01.getOpcode() != ISD::SIGN_EXTEND ||
+ N10.getOpcode() != ISD::SIGN_EXTEND ||
+ N11.getOpcode() != ISD::SIGN_EXTEND)
+ return SDValue();
+
+ // Peek through the extends.
+ N00 = N00.getOperand(0);
+ N01 = N01.getOperand(0);
+ N10 = N10.getOperand(0);
+ N11 = N11.getOperand(0);
+
+ // Must be extending from vXi16.
+ EVT InVT = N00.getValueType();
+ if (InVT.getVectorElementType() != MVT::i16 || N01.getValueType() != InVT ||
+ N10.getValueType() != InVT || N11.getValueType() != InVT)
+ return SDValue();
+
+ // All inputs should be build_vectors.
+ if (N00.getOpcode() != ISD::BUILD_VECTOR ||
+ N01.getOpcode() != ISD::BUILD_VECTOR ||
+ N10.getOpcode() != ISD::BUILD_VECTOR ||
+ N11.getOpcode() != ISD::BUILD_VECTOR)
+ return SDValue();
+
+ // For each result element we need the even element of one input vector
+ // multiplied by the even element of the other, added to the product of the
+ // corresponding odd elements. That is, for each element i the following
+ // operation must be performed:
+ // A[2 * i] * B[2 * i] + A[2 * i + 1] * B[2 * i + 1]
+ SDValue In0, In1;
+ for (unsigned i = 0; i != N00.getNumOperands(); ++i) {
+ SDValue N00Elt = N00.getOperand(i);
+ SDValue N01Elt = N01.getOperand(i);
+ SDValue N10Elt = N10.getOperand(i);
+ SDValue N11Elt = N11.getOperand(i);
+ // TODO: Be more tolerant to undefs.
+ if (N00Elt.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
+ N01Elt.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
+ N10Elt.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
+ N11Elt.getOpcode() != ISD::EXTRACT_VECTOR_ELT)
+ return SDValue();
+ auto *ConstN00Elt = dyn_cast<ConstantSDNode>(N00Elt.getOperand(1));
+ auto *ConstN01Elt = dyn_cast<ConstantSDNode>(N01Elt.getOperand(1));
+ auto *ConstN10Elt = dyn_cast<ConstantSDNode>(N10Elt.getOperand(1));
+ auto *ConstN11Elt = dyn_cast<ConstantSDNode>(N11Elt.getOperand(1));
+ if (!ConstN00Elt || !ConstN01Elt || !ConstN10Elt || !ConstN11Elt)
+ return SDValue();
+ unsigned IdxN00 = ConstN00Elt->getZExtValue();
+ unsigned IdxN01 = ConstN01Elt->getZExtValue();
+ unsigned IdxN10 = ConstN10Elt->getZExtValue();
+ unsigned IdxN11 = ConstN11Elt->getZExtValue();
+ // Add is commutative so indices can be reordered.
+ if (IdxN00 > IdxN10) {
+ std::swap(IdxN00, IdxN10);
+ std::swap(IdxN01, IdxN11);
+ }
+ // N0 indices must be the even element. N1 indices must be the next odd element.
+ if (IdxN00 != 2 * i || IdxN10 != 2 * i + 1 ||
+ IdxN01 != 2 * i || IdxN11 != 2 * i + 1)
+ return SDValue();
+ SDValue N00In = N00Elt.getOperand(0);
+ SDValue N01In = N01Elt.getOperand(0);
+ SDValue N10In = N10Elt.getOperand(0);
+ SDValue N11In = N11Elt.getOperand(0);
+ // First time we find an input capture it.
+ if (!In0) {
+ In0 = N00In;
+ In1 = N01In;
+ }
+ // Mul is commutative so the input vectors can be in any order.
+ // Canonicalize to make the compares easier.
+ if (In0 != N00In)
+ std::swap(N00In, N01In);
+ if (In0 != N10In)
+ std::swap(N10In, N11In);
+ if (In0 != N00In || In1 != N01In || In0 != N10In || In1 != N11In)
+ return SDValue();
+ }
+
+ auto PMADDBuilder = [](SelectionDAG &DAG, const SDLoc &DL,
+ ArrayRef<SDValue> Ops) {
+ // The operands are already vXi16, so no truncation is needed; just build
+ // the VPMADDWD at the narrowed vXi32 result type.
+ EVT InVT = Ops[0].getValueType();
+ assert(InVT.getScalarType() == MVT::i16 &&
+ "Unexpected scalar element type");
+ assert(InVT == Ops[1].getValueType() && "Operands' types mismatch");
+ EVT ResVT = EVT::getVectorVT(*DAG.getContext(), MVT::i32,
+ InVT.getVectorNumElements() / 2);
+ return DAG.getNode(X86ISD::VPMADDWD, DL, ResVT, Ops[0], Ops[1]);
+ };
+ return SplitOpsAndApply(DAG, Subtarget, DL, VT, { In0, In1 },
+ PMADDBuilder);
+}
+
static SDValue combineAdd(SDNode *N, SelectionDAG &DAG,
const X86Subtarget &Subtarget) {
const SDNodeFlags Flags = N->getFlags();
@@ -36918,11 +38948,22 @@ static SDValue combineAdd(SDNode *N, SelectionDAG &DAG,
SDValue Op0 = N->getOperand(0);
SDValue Op1 = N->getOperand(1);
+ if (SDValue MAdd = matchPMADDWD(DAG, Op0, Op1, SDLoc(N), VT, Subtarget))
+ return MAdd;
+ if (SDValue MAdd = matchPMADDWD_2(DAG, Op0, Op1, SDLoc(N), VT, Subtarget))
+ return MAdd;
+
// Try to synthesize horizontal adds from adds of shuffles.
- if (((Subtarget.hasSSSE3() && (VT == MVT::v8i16 || VT == MVT::v4i32)) ||
- (Subtarget.hasInt256() && (VT == MVT::v16i16 || VT == MVT::v8i32))) &&
- isHorizontalBinOp(Op0, Op1, true))
- return DAG.getNode(X86ISD::HADD, SDLoc(N), VT, Op0, Op1);
+ if ((VT == MVT::v8i16 || VT == MVT::v4i32 || VT == MVT::v16i16 ||
+ VT == MVT::v8i32) &&
+ Subtarget.hasSSSE3() && isHorizontalBinOp(Op0, Op1, true)) {
+ auto HADDBuilder = [](SelectionDAG &DAG, const SDLoc &DL,
+ ArrayRef<SDValue> Ops) {
+ return DAG.getNode(X86ISD::HADD, DL, Ops[0].getValueType(), Ops);
+ };
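+ // Thanks to SplitOpsAndApply the 256-bit cases no longer require AVX2;
+ // without AVX2 they are simply split into two 128-bit HADDs.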
+ return SplitOpsAndApply(DAG, Subtarget, SDLoc(N), VT, {Op0, Op1},
+ HADDBuilder);
+ }
if (SDValue V = combineIncDecVector(N, DAG))
return V;
@@ -36936,20 +38977,19 @@ static SDValue combineSubToSubus(SDNode *N, SelectionDAG &DAG,
SDValue Op1 = N->getOperand(1);
EVT VT = N->getValueType(0);
- // PSUBUS is supported, starting from SSE2, but special preprocessing
- // for v8i32 requires umin, which appears in SSE41.
+ // PSUBUS is supported, starting from SSE2, but truncation for v8i32
+ // is only worth it with SSSE3 (PSHUFB).
if (!(Subtarget.hasSSE2() && (VT == MVT::v16i8 || VT == MVT::v8i16)) &&
- !(Subtarget.hasSSE41() && (VT == MVT::v8i32)) &&
- !(Subtarget.hasAVX2() && (VT == MVT::v32i8 || VT == MVT::v16i16)) &&
- !(Subtarget.hasAVX512() && Subtarget.hasBWI() &&
- (VT == MVT::v64i8 || VT == MVT::v32i16 || VT == MVT::v16i32 ||
- VT == MVT::v8i64)))
+ !(Subtarget.hasSSSE3() && (VT == MVT::v8i32 || VT == MVT::v8i64)) &&
+ !(Subtarget.hasAVX() && (VT == MVT::v32i8 || VT == MVT::v16i16)) &&
+ !(Subtarget.useBWIRegs() && (VT == MVT::v64i8 || VT == MVT::v32i16 ||
+ VT == MVT::v16i32 || VT == MVT::v8i64)))
return SDValue();
SDValue SubusLHS, SubusRHS;
// Try to find umax(a,b) - b or a - umin(a,b) patterns
// they may be converted to subus(a,b).
- // TODO: Need to add IR cannonicialization for this code.
+ // TODO: Need to add IR canonicalization for this code.
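+ // For example (v8i16): (sub (umax %a, %b), %b) yields a - b when a >= b and
+ // 0 otherwise, which is exactly the unsigned saturating subtract (psubusw).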
if (Op0.getOpcode() == ISD::UMAX) {
SubusRHS = Op1;
SDValue MaxLHS = Op0.getOperand(0);
@@ -36973,10 +39013,16 @@ static SDValue combineSubToSubus(SDNode *N, SelectionDAG &DAG,
} else
return SDValue();
+ auto SUBUSBuilder = [](SelectionDAG &DAG, const SDLoc &DL,
+ ArrayRef<SDValue> Ops) {
+ return DAG.getNode(X86ISD::SUBUS, DL, Ops[0].getValueType(), Ops);
+ };
+
// PSUBUS doesn't support v8i32/v8i64/v16i32, but it can be enabled with
// special preprocessing in some cases.
if (VT != MVT::v8i32 && VT != MVT::v16i32 && VT != MVT::v8i64)
- return DAG.getNode(X86ISD::SUBUS, SDLoc(N), VT, SubusLHS, SubusRHS);
+ return SplitOpsAndApply(DAG, Subtarget, SDLoc(N), VT,
+ { SubusLHS, SubusRHS }, SUBUSBuilder);
// The special preprocessing case can only be applied
// if the value was zero extended from 16 bits,
@@ -37006,8 +39052,9 @@ static SDValue combineSubToSubus(SDNode *N, SelectionDAG &DAG,
SDValue NewSubusLHS =
DAG.getZExtOrTrunc(SubusLHS, SDLoc(SubusLHS), ShrinkedType);
SDValue NewSubusRHS = DAG.getZExtOrTrunc(UMin, SDLoc(SubusRHS), ShrinkedType);
- SDValue Psubus = DAG.getNode(X86ISD::SUBUS, SDLoc(N), ShrinkedType,
- NewSubusLHS, NewSubusRHS);
+ SDValue Psubus =
+ SplitOpsAndApply(DAG, Subtarget, SDLoc(N), ShrinkedType,
+ { NewSubusLHS, NewSubusRHS }, SUBUSBuilder);
// Zero extend the result; it may be used somewhere as 32 bit.
// If not, the zext and the following trunc will shrink.
return DAG.getZExtOrTrunc(Psubus, SDLoc(N), ExtType);
@@ -37038,10 +39085,16 @@ static SDValue combineSub(SDNode *N, SelectionDAG &DAG,
// Try to synthesize horizontal subs from subs of shuffles.
EVT VT = N->getValueType(0);
- if (((Subtarget.hasSSSE3() && (VT == MVT::v8i16 || VT == MVT::v4i32)) ||
- (Subtarget.hasInt256() && (VT == MVT::v16i16 || VT == MVT::v8i32))) &&
- isHorizontalBinOp(Op0, Op1, false))
- return DAG.getNode(X86ISD::HSUB, SDLoc(N), VT, Op0, Op1);
+ if ((VT == MVT::v8i16 || VT == MVT::v4i32 || VT == MVT::v16i16 ||
+ VT == MVT::v8i32) &&
+ Subtarget.hasSSSE3() && isHorizontalBinOp(Op0, Op1, false)) {
+ auto HSUBBuilder = [](SelectionDAG &DAG, const SDLoc &DL,
+ ArrayRef<SDValue> Ops) {
+ return DAG.getNode(X86ISD::HSUB, DL, Ops[0].getValueType(), Ops);
+ };
+ return SplitOpsAndApply(DAG, Subtarget, SDLoc(N), VT, {Op0, Op1},
+ HSUBBuilder);
+ }
if (SDValue V = combineIncDecVector(N, DAG))
return V;
@@ -37145,28 +39198,6 @@ static SDValue combineVSZext(SDNode *N, SelectionDAG &DAG,
return SDValue();
}
-static SDValue combineTestM(SDNode *N, SelectionDAG &DAG,
- const X86Subtarget &Subtarget) {
- SDValue Op0 = N->getOperand(0);
- SDValue Op1 = N->getOperand(1);
-
- MVT VT = N->getSimpleValueType(0);
- SDLoc DL(N);
-
- // TEST (AND a, b) ,(AND a, b) -> TEST a, b
- if (Op0 == Op1 && Op1->getOpcode() == ISD::AND)
- return DAG.getNode(X86ISD::TESTM, DL, VT, Op0->getOperand(0),
- Op0->getOperand(1));
-
- // TEST op0, BUILD_VECTOR(all_zero) -> BUILD_VECTOR(all_zero)
- // TEST BUILD_VECTOR(all_zero), op1 -> BUILD_VECTOR(all_zero)
- if (ISD::isBuildVectorAllZeros(Op0.getNode()) ||
- ISD::isBuildVectorAllZeros(Op1.getNode()))
- return getZeroVector(VT, Subtarget, DAG, DL);
-
- return SDValue();
-}
-
static SDValue combineVectorCompare(SDNode *N, SelectionDAG &DAG,
const X86Subtarget &Subtarget) {
MVT VT = N->getSimpleValueType(0);
@@ -37190,9 +39221,7 @@ static SDValue combineInsertSubvector(SDNode *N, SelectionDAG &DAG,
MVT OpVT = N->getSimpleValueType(0);
- // Early out for mask vectors.
- if (OpVT.getVectorElementType() == MVT::i1)
- return SDValue();
+ bool IsI1Vector = OpVT.getVectorElementType() == MVT::i1;
SDLoc dl(N);
SDValue Vec = N->getOperand(0);
@@ -37204,23 +39233,40 @@ static SDValue combineInsertSubvector(SDNode *N, SelectionDAG &DAG,
if (ISD::isBuildVectorAllZeros(Vec.getNode())) {
// Inserting zeros into zeros is a nop.
if (ISD::isBuildVectorAllZeros(SubVec.getNode()))
- return Vec;
+ return getZeroVector(OpVT, Subtarget, DAG, dl);
// If we're inserting into a zero vector and then into a larger zero vector,
// just insert into the larger zero vector directly.
if (SubVec.getOpcode() == ISD::INSERT_SUBVECTOR &&
ISD::isBuildVectorAllZeros(SubVec.getOperand(0).getNode())) {
unsigned Idx2Val = SubVec.getConstantOperandVal(2);
- return DAG.getNode(ISD::INSERT_SUBVECTOR, dl, OpVT, Vec,
+ return DAG.getNode(ISD::INSERT_SUBVECTOR, dl, OpVT,
+ getZeroVector(OpVT, Subtarget, DAG, dl),
SubVec.getOperand(1),
DAG.getIntPtrConstant(IdxVal + Idx2Val, dl));
}
+ // If we're inserting into a zero vector and our input was extracted from an
+ // insert into a zero vector of the same type, and the extraction was at
+ // least as large as the original insertion, just insert the original
+ // subvector into a zero vector.
+ if (SubVec.getOpcode() == ISD::EXTRACT_SUBVECTOR && IdxVal == 0 &&
+ SubVec.getConstantOperandVal(1) == 0 &&
+ SubVec.getOperand(0).getOpcode() == ISD::INSERT_SUBVECTOR) {
+ SDValue Ins = SubVec.getOperand(0);
+ if (Ins.getConstantOperandVal(2) == 0 &&
+ ISD::isBuildVectorAllZeros(Ins.getOperand(0).getNode()) &&
+ Ins.getOperand(1).getValueSizeInBits() <= SubVecVT.getSizeInBits())
+ return DAG.getNode(ISD::INSERT_SUBVECTOR, dl, OpVT,
+ getZeroVector(OpVT, Subtarget, DAG, dl),
+ Ins.getOperand(1), N->getOperand(2));
+ }
+
// If we're inserting a bitcast into zeros, rewrite the insert and move the
// bitcast to the other side. This helps with detecting zero extending
// during isel.
// TODO: Is this useful for other indices than 0?
- if (SubVec.getOpcode() == ISD::BITCAST && IdxVal == 0) {
+ if (!IsI1Vector && SubVec.getOpcode() == ISD::BITCAST && IdxVal == 0) {
MVT CastVT = SubVec.getOperand(0).getSimpleValueType();
unsigned NumElems = OpVT.getSizeInBits() / CastVT.getScalarSizeInBits();
MVT NewVT = MVT::getVectorVT(CastVT.getVectorElementType(), NumElems);
@@ -37231,6 +39277,10 @@ static SDValue combineInsertSubvector(SDNode *N, SelectionDAG &DAG,
}
}
+ // Stop here if this is an i1 vector.
+ if (IsI1Vector)
+ return SDValue();
+
// If this is an insert of an extract, combine to a shuffle. Don't do this
// if the insert or extract can be represented with a subregister operation.
if (SubVec.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
@@ -37317,7 +39367,6 @@ static SDValue combineInsertSubvector(SDNode *N, SelectionDAG &DAG,
if (!Vec.getOperand(0).isUndef() && Vec.hasOneUse()) {
Vec = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, OpVT, DAG.getUNDEF(OpVT),
SubVec2, Vec.getOperand(2));
- DCI.AddToWorklist(Vec.getNode());
return DAG.getNode(ISD::INSERT_SUBVECTOR, dl, OpVT, Vec, SubVec,
N->getOperand(2));
@@ -37352,6 +39401,75 @@ static SDValue combineExtractSubvector(SDNode *N, SelectionDAG &DAG,
OpVT, SDLoc(N),
InVec.getNode()->ops().slice(IdxVal, OpVT.getVectorNumElements()));
+ // If we're extracting the lowest subvector and the source has no other
+ // users, we may be able to perform this with a smaller vector width.
+ if (IdxVal == 0 && InVec.hasOneUse()) {
+ unsigned InOpcode = InVec.getOpcode();
+ if (OpVT == MVT::v2f64 && InVec.getValueType() == MVT::v4f64) {
+ // v2f64 CVTDQ2PD(v4i32).
+ if (InOpcode == ISD::SINT_TO_FP &&
+ InVec.getOperand(0).getValueType() == MVT::v4i32) {
+ return DAG.getNode(X86ISD::CVTSI2P, SDLoc(N), OpVT, InVec.getOperand(0));
+ }
+ // v2f64 CVTPS2PD(v4f32).
+ if (InOpcode == ISD::FP_EXTEND &&
+ InVec.getOperand(0).getValueType() == MVT::v4f32) {
+ return DAG.getNode(X86ISD::VFPEXT, SDLoc(N), OpVT, InVec.getOperand(0));
+ }
+ }
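+ // Similarly, a 128-bit extract of a widened VZEXT/VSEXT can instead use the
+ // matching *_EXTEND_VECTOR_INREG node on the narrow source directly.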
+ if ((InOpcode == X86ISD::VZEXT || InOpcode == X86ISD::VSEXT) &&
+ OpVT.is128BitVector() &&
+ InVec.getOperand(0).getSimpleValueType().is128BitVector()) {
+ unsigned ExtOp = InOpcode == X86ISD::VZEXT ? ISD::ZERO_EXTEND_VECTOR_INREG
+ : ISD::SIGN_EXTEND_VECTOR_INREG;
+ return DAG.getNode(ExtOp, SDLoc(N), OpVT, InVec.getOperand(0));
+ }
+ }
+
+ return SDValue();
+}
+
+static SDValue combineScalarToVector(SDNode *N, SelectionDAG &DAG) {
+ EVT VT = N->getValueType(0);
+ SDValue Src = N->getOperand(0);
+
+ // If this is a scalar to vector to v1i1 from an AND with 1, bypass the and.
+ // This occurs frequently in our masked scalar intrinsic code and our
+ // floating point select lowering with AVX512.
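+ // Only bit 0 of the scalar survives the conversion to v1i1, and
+ // (and X, 1) has the same bit 0 as X, so the AND is redundant here.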
+ // TODO: SimplifyDemandedBits instead?
+ if (VT == MVT::v1i1 && Src.getOpcode() == ISD::AND && Src.hasOneUse())
+ if (auto *C = dyn_cast<ConstantSDNode>(Src.getOperand(1)))
+ if (C->getAPIntValue().isOneValue())
+ return DAG.getNode(ISD::SCALAR_TO_VECTOR, SDLoc(N), MVT::v1i1,
+ Src.getOperand(0));
+
+ return SDValue();
+}
+
+// Simplify PMULDQ and PMULUDQ operations.
+static SDValue combinePMULDQ(SDNode *N, SelectionDAG &DAG,
+ TargetLowering::DAGCombinerInfo &DCI) {
+ SDValue LHS = N->getOperand(0);
+ SDValue RHS = N->getOperand(1);
+
+ const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+ TargetLowering::TargetLoweringOpt TLO(DAG, !DCI.isBeforeLegalize(),
+ !DCI.isBeforeLegalizeOps());
+ APInt DemandedMask(APInt::getLowBitsSet(64, 32));
+
+ // PMULDQ/PMULUDQ only uses the lower 32 bits from each vector element.
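+ // If either operand can be simplified accordingly, commit the change and
+ // return SDValue(N, 0) to signal that N was updated in place.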
+ KnownBits LHSKnown;
+ if (TLI.SimplifyDemandedBits(LHS, DemandedMask, LHSKnown, TLO)) {
+ DCI.CommitTargetLoweringOpt(TLO);
+ return SDValue(N, 0);
+ }
+
+ KnownBits RHSKnown;
+ if (TLI.SimplifyDemandedBits(RHS, DemandedMask, RHSKnown, TLO)) {
+ DCI.CommitTargetLoweringOpt(TLO);
+ return SDValue(N, 0);
+ }
+
return SDValue();
}
@@ -37360,6 +39478,8 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N,
SelectionDAG &DAG = DCI.DAG;
switch (N->getOpcode()) {
default: break;
+ case ISD::SCALAR_TO_VECTOR:
+ return combineScalarToVector(N, DAG);
case ISD::EXTRACT_VECTOR_ELT:
case X86ISD::PEXTRW:
case X86ISD::PEXTRB:
@@ -37384,6 +39504,7 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N,
case ISD::AND: return combineAnd(N, DAG, DCI, Subtarget);
case ISD::OR: return combineOr(N, DAG, DCI, Subtarget);
case ISD::XOR: return combineXor(N, DAG, DCI, Subtarget);
+ case X86ISD::BEXTR: return combineBEXTR(N, DAG, DCI, Subtarget);
case ISD::LOAD: return combineLoad(N, DAG, DCI, Subtarget);
case ISD::MLOAD: return combineMaskedLoad(N, DAG, DCI, Subtarget);
case ISD::STORE: return combineStore(N, DAG, Subtarget);
@@ -37449,20 +39570,21 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N,
case X86ISD::VPERMI:
case X86ISD::VPERMV:
case X86ISD::VPERMV3:
- case X86ISD::VPERMIV3:
case X86ISD::VPERMIL2:
case X86ISD::VPERMILPI:
case X86ISD::VPERMILPV:
case X86ISD::VPERM2X128:
+ case X86ISD::SHUF128:
case X86ISD::VZEXT_MOVL:
case ISD::VECTOR_SHUFFLE: return combineShuffle(N, DAG, DCI,Subtarget);
case X86ISD::FMADD_RND:
- case X86ISD::FMADDS1_RND:
- case X86ISD::FMADDS3_RND:
- case X86ISD::FMADDS1:
- case X86ISD::FMADDS3:
- case X86ISD::FMADD4S:
- case ISD::FMA: return combineFMA(N, DAG, Subtarget);
+ case X86ISD::FMSUB:
+ case X86ISD::FMSUB_RND:
+ case X86ISD::FNMADD:
+ case X86ISD::FNMADD_RND:
+ case X86ISD::FNMSUB:
+ case X86ISD::FNMSUB_RND:
+ case ISD::FMA: return combineFMA(N, DAG, Subtarget);
case X86ISD::FMADDSUB_RND:
case X86ISD::FMSUBADD_RND:
case X86ISD::FMADDSUB:
@@ -37472,9 +39594,10 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N,
case X86ISD::MSCATTER:
case ISD::MGATHER:
case ISD::MSCATTER: return combineGatherScatter(N, DAG, DCI, Subtarget);
- case X86ISD::TESTM: return combineTestM(N, DAG, Subtarget);
case X86ISD::PCMPEQ:
case X86ISD::PCMPGT: return combineVectorCompare(N, DAG, Subtarget);
+ case X86ISD::PMULDQ:
+ case X86ISD::PMULUDQ: return combinePMULDQ(N, DAG, DCI);
}
return SDValue();
@@ -37487,6 +39610,11 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N,
bool X86TargetLowering::isTypeDesirableForOp(unsigned Opc, EVT VT) const {
if (!isTypeLegal(VT))
return false;
+
+ // There are no vXi8 shifts.
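+ // (x86 has no byte-granular vector shift instructions, so i8 vector shifts
+ // always have to be emulated at a wider element width.)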
+ if (Opc == ISD::SHL && VT.isVector() && VT.getVectorElementType() == MVT::i8)
+ return false;
+
if (VT != MVT::i16)
return true;
@@ -37509,23 +39637,20 @@ bool X86TargetLowering::isTypeDesirableForOp(unsigned Opc, EVT VT) const {
}
}
-/// This function checks if any of the users of EFLAGS copies the EFLAGS. We
-/// know that the code that lowers COPY of EFLAGS has to use the stack, and if
-/// we don't adjust the stack we clobber the first frame index.
-/// See X86InstrInfo::copyPhysReg.
-static bool hasCopyImplyingStackAdjustment(const MachineFunction &MF) {
- const MachineRegisterInfo &MRI = MF.getRegInfo();
- return any_of(MRI.reg_instructions(X86::EFLAGS),
- [](const MachineInstr &RI) { return RI.isCopy(); });
-}
-
-void X86TargetLowering::finalizeLowering(MachineFunction &MF) const {
- if (hasCopyImplyingStackAdjustment(MF)) {
- MachineFrameInfo &MFI = MF.getFrameInfo();
- MFI.setHasCopyImplyingStackAdjustment(true);
+SDValue X86TargetLowering::expandIndirectJTBranch(const SDLoc& dl,
+ SDValue Value, SDValue Addr,
+ SelectionDAG &DAG) const {
+ const Module *M = DAG.getMachineFunction().getMMI().getModule();
+ Metadata *IsCFProtectionSupported = M->getModuleFlag("cf-protection-branch");
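+ // The "cf-protection-branch" module flag is set by the frontend when CET
+ // indirect branch tracking (e.g. -fcf-protection=branch) is requested.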
+ if (IsCFProtectionSupported) {
+ // In case control-flow branch protection is enabled, we need to add a
+ // notrack prefix to the indirect branch. To do that we create an NT_BRIND
+ // SDNode; during ISel the pattern will convert it to a jmp with a NoTrack
+ // prefix.
+ return DAG.getNode(X86ISD::NT_BRIND, dl, MVT::Other, Value, Addr);
}
- TargetLoweringBase::finalizeLowering(MF);
+ return TargetLowering::expandIndirectJTBranch(dl, Value, Addr, DAG);
}
/// This method query the target whether it is beneficial for dag combiner to
@@ -37536,22 +39661,30 @@ bool X86TargetLowering::IsDesirableToPromoteOp(SDValue Op, EVT &PVT) const {
if (VT != MVT::i16)
return false;
- bool Promote = false;
+ auto IsFoldableRMW = [](SDValue Load, SDValue Op) {
+ if (!Op.hasOneUse())
+ return false;
+ SDNode *User = *Op->use_begin();
+ if (!ISD::isNormalStore(User))
+ return false;
+ auto *Ld = cast<LoadSDNode>(Load);
+ auto *St = cast<StoreSDNode>(User);
+ return Ld->getBasePtr() == St->getBasePtr();
+ };
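+ // When the op feeds a read-modify-write store of the same address, the
+ // checks below keep it at i16 so it can still fold into the memory-form
+ // instruction instead of being promoted to i32.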
+
bool Commute = false;
switch (Op.getOpcode()) {
- default: break;
+ default: return false;
case ISD::SIGN_EXTEND:
case ISD::ZERO_EXTEND:
case ISD::ANY_EXTEND:
- Promote = true;
break;
case ISD::SHL:
case ISD::SRL: {
SDValue N0 = Op.getOperand(0);
// Look out for (store (shl (load), x)).
- if (MayFoldLoad(N0) && MayFoldIntoStore(Op))
+ if (MayFoldLoad(N0) && IsFoldableRMW(N0, Op))
return false;
- Promote = true;
break;
}
case ISD::ADD:
@@ -37564,19 +39697,20 @@ bool X86TargetLowering::IsDesirableToPromoteOp(SDValue Op, EVT &PVT) const {
case ISD::SUB: {
SDValue N0 = Op.getOperand(0);
SDValue N1 = Op.getOperand(1);
- if (!Commute && MayFoldLoad(N1))
- return false;
// Avoid disabling potential load folding opportunities.
- if (MayFoldLoad(N0) && (!isa<ConstantSDNode>(N1) || MayFoldIntoStore(Op)))
+ if (MayFoldLoad(N1) &&
+ (!Commute || !isa<ConstantSDNode>(N0) ||
+ (Op.getOpcode() != ISD::MUL && IsFoldableRMW(N1, Op))))
return false;
- if (MayFoldLoad(N1) && (!isa<ConstantSDNode>(N0) || MayFoldIntoStore(Op)))
+ if (MayFoldLoad(N0) &&
+ ((Commute && !isa<ConstantSDNode>(N1)) ||
+ (Op.getOpcode() != ISD::MUL && IsFoldableRMW(N0, Op))))
return false;
- Promote = true;
}
}
PVT = MVT::i32;
- return Promote;
+ return true;
}
bool X86TargetLowering::
@@ -37862,7 +39996,7 @@ TargetLowering::ConstraintWeight
LLVM_FALLTHROUGH;
case 'x':
if (((type->getPrimitiveSizeInBits() == 128) && Subtarget.hasSSE1()) ||
- ((type->getPrimitiveSizeInBits() == 256) && Subtarget.hasFp256()))
+ ((type->getPrimitiveSizeInBits() == 256) && Subtarget.hasAVX()))
weight = CW_Register;
break;
case 'k':
@@ -38353,6 +40487,25 @@ X86TargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
return Res;
}
+ // Make sure it isn't a register that requires 64-bit mode.
+ if (!Subtarget.is64Bit() &&
+ (isFRClass(*Res.second) || isGRClass(*Res.second)) &&
+ TRI->getEncodingValue(Res.first) >= 8) {
+ // Register requires REX prefix, but we're in 32-bit mode.
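+ // (Encodings 8 and above correspond to r8-r15/xmm8-xmm15, which can only be
+ // addressed with a REX prefix.)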
+ Res.first = 0;
+ Res.second = nullptr;
+ return Res;
+ }
+
+ // Make sure it isn't a register that requires AVX512.
+ if (!Subtarget.hasAVX512() && isFRClass(*Res.second) &&
+ TRI->getEncodingValue(Res.first) & 0x10) {
+ // Register requires EVEX prefix.
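+ // (Bit 4 of the encoding is set only for xmm16-xmm31 and their ymm/zmm
+ // aliases.)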
+ Res.first = 0;
+ Res.second = nullptr;
+ return Res;
+ }
+
// Otherwise, check to see if this is a register class of the wrong value
// type. For example, we want to map "{ax},i32" -> {eax}, we don't want it to
// turn into {ax},{dx}.
@@ -38421,7 +40574,7 @@ int X86TargetLowering::getScalingFactorCost(const DataLayout &DL,
// will take 2 allocations in the out of order engine instead of 1
// for plain addressing mode, i.e. inst (reg1).
// E.g.,
- // vaddps (%rsi,%drx), %ymm0, %ymm1
+ // vaddps (%rsi,%rdx), %ymm0, %ymm1
// Requires two allocations (one for the load, one for the computation)
// whereas:
// vaddps (%rsi), %ymm0, %ymm1
@@ -38516,7 +40669,8 @@ StringRef X86TargetLowering::getStackProbeSymbolName(MachineFunction &MF) const
// Generally, if we aren't on Windows, the platform ABI does not include
// support for stack probes, so don't emit them.
- if (!Subtarget.isOSWindows() || Subtarget.isTargetMachO())
+ if (!Subtarget.isOSWindows() || Subtarget.isTargetMachO() ||
+ MF.getFunction().hasFnAttribute("no-stack-arg-probe"))
return "";
// We need a stack probe to conform to the Windows ABI. Choose the right