Diffstat (limited to 'contrib/llvm/lib/Target/X86/X86ISelLowering.cpp')
-rw-r--r--  contrib/llvm/lib/Target/X86/X86ISelLowering.cpp  |  8589
1 file changed, 5244 insertions, 3345 deletions
diff --git a/contrib/llvm/lib/Target/X86/X86ISelLowering.cpp b/contrib/llvm/lib/Target/X86/X86ISelLowering.cpp index 67a127fe0a2b..b6a692ee187d 100644 --- a/contrib/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/contrib/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -19,7 +19,6 @@ #include "X86InstrBuilder.h" #include "X86IntrinsicsInfo.h" #include "X86MachineFunctionInfo.h" -#include "X86ShuffleDecodeConstantPool.h" #include "X86TargetMachine.h" #include "X86TargetObjectFile.h" #include "llvm/ADT/SmallBitVector.h" @@ -196,6 +195,14 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, setOperationAction(ISD::ABS , MVT::i64 , Custom); } + // Funnel shifts. + for (auto ShiftOp : {ISD::FSHL, ISD::FSHR}) { + setOperationAction(ShiftOp , MVT::i16 , Custom); + setOperationAction(ShiftOp , MVT::i32 , Custom); + if (Subtarget.is64Bit()) + setOperationAction(ShiftOp , MVT::i64 , Custom); + } + // Promote all UINT_TO_FP to larger SINT_TO_FP's, as X86 doesn't have this // operation. setOperationAction(ISD::UINT_TO_FP , MVT::i1 , Promote); @@ -533,6 +540,10 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, // Use ANDPD and ORPD to simulate FCOPYSIGN. setOperationAction(ISD::FCOPYSIGN, VT, Custom); + // These might be better off as horizontal vector ops. + setOperationAction(ISD::FADD, VT, Custom); + setOperationAction(ISD::FSUB, VT, Custom); + // We don't support sin/cos/fmod setOperationAction(ISD::FSIN , VT, Expand); setOperationAction(ISD::FCOS , VT, Expand); @@ -543,15 +554,12 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, setOperationAction(ISD::FGETSIGN, MVT::i64, Custom); setOperationAction(ISD::FGETSIGN, MVT::i32, Custom); - // Expand FP immediates into loads from the stack, except for the special - // cases we handle. - addLegalFPImmediate(APFloat(+0.0)); // xorpd - addLegalFPImmediate(APFloat(+0.0f)); // xorps - } else if (UseX87 && X86ScalarSSEf32) { + } else if (!useSoftFloat() && X86ScalarSSEf32 && (UseX87 || Is64Bit)) { // Use SSE for f32, x87 for f64. // Set up the FP register classes. addRegisterClass(MVT::f32, &X86::FR32RegClass); - addRegisterClass(MVT::f64, &X86::RFP64RegClass); + if (UseX87) + addRegisterClass(MVT::f64, &X86::RFP64RegClass); // Use ANDPS to simulate FABS. setOperationAction(ISD::FABS , MVT::f32, Custom); @@ -559,10 +567,12 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, // Use XORP to simulate FNEG. setOperationAction(ISD::FNEG , MVT::f32, Custom); - setOperationAction(ISD::UNDEF, MVT::f64, Expand); + if (UseX87) + setOperationAction(ISD::UNDEF, MVT::f64, Expand); // Use ANDPS and ORPS to simulate FCOPYSIGN. - setOperationAction(ISD::FCOPYSIGN, MVT::f64, Expand); + if (UseX87) + setOperationAction(ISD::FCOPYSIGN, MVT::f64, Expand); setOperationAction(ISD::FCOPYSIGN, MVT::f32, Custom); // We don't support sin/cos/fmod @@ -570,17 +580,12 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, setOperationAction(ISD::FCOS , MVT::f32, Expand); setOperationAction(ISD::FSINCOS, MVT::f32, Expand); - // Special cases we handle for FP constants. - addLegalFPImmediate(APFloat(+0.0f)); // xorps - addLegalFPImmediate(APFloat(+0.0)); // FLD0 - addLegalFPImmediate(APFloat(+1.0)); // FLD1 - addLegalFPImmediate(APFloat(-0.0)); // FLD0/FCHS - addLegalFPImmediate(APFloat(-1.0)); // FLD1/FCHS - - // Always expand sin/cos functions even though x87 has an instruction. 
- setOperationAction(ISD::FSIN , MVT::f64, Expand); - setOperationAction(ISD::FCOS , MVT::f64, Expand); - setOperationAction(ISD::FSINCOS, MVT::f64, Expand); + if (UseX87) { + // Always expand sin/cos functions even though x87 has an instruction. + setOperationAction(ISD::FSIN, MVT::f64, Expand); + setOperationAction(ISD::FCOS, MVT::f64, Expand); + setOperationAction(ISD::FSINCOS, MVT::f64, Expand); + } } else if (UseX87) { // f32 and f64 in x87. // Set up the FP register classes. @@ -596,14 +601,27 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, setOperationAction(ISD::FCOS , VT, Expand); setOperationAction(ISD::FSINCOS, VT, Expand); } - addLegalFPImmediate(APFloat(+0.0)); // FLD0 - addLegalFPImmediate(APFloat(+1.0)); // FLD1 - addLegalFPImmediate(APFloat(-0.0)); // FLD0/FCHS - addLegalFPImmediate(APFloat(-1.0)); // FLD1/FCHS - addLegalFPImmediate(APFloat(+0.0f)); // FLD0 - addLegalFPImmediate(APFloat(+1.0f)); // FLD1 - addLegalFPImmediate(APFloat(-0.0f)); // FLD0/FCHS - addLegalFPImmediate(APFloat(-1.0f)); // FLD1/FCHS + } + + // Expand FP32 immediates into loads from the stack, save special cases. + if (isTypeLegal(MVT::f32)) { + if (UseX87 && (getRegClassFor(MVT::f32) == &X86::RFP32RegClass)) { + addLegalFPImmediate(APFloat(+0.0f)); // FLD0 + addLegalFPImmediate(APFloat(+1.0f)); // FLD1 + addLegalFPImmediate(APFloat(-0.0f)); // FLD0/FCHS + addLegalFPImmediate(APFloat(-1.0f)); // FLD1/FCHS + } else // SSE immediates. + addLegalFPImmediate(APFloat(+0.0f)); // xorps + } + // Expand FP64 immediates into loads from the stack, save special cases. + if (isTypeLegal(MVT::f64)) { + if (UseX87 && getRegClassFor(MVT::f64) == &X86::RFP64RegClass) { + addLegalFPImmediate(APFloat(+0.0)); // FLD0 + addLegalFPImmediate(APFloat(+1.0)); // FLD1 + addLegalFPImmediate(APFloat(-0.0)); // FLD0/FCHS + addLegalFPImmediate(APFloat(-1.0)); // FLD1/FCHS + } else // SSE immediates. + addLegalFPImmediate(APFloat(+0.0)); // xorpd } // We don't support FMA. @@ -613,7 +631,8 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, // Long double always uses X87, except f128 in MMX. if (UseX87) { if (Subtarget.is64Bit() && Subtarget.hasMMX()) { - addRegisterClass(MVT::f128, &X86::VR128RegClass); + addRegisterClass(MVT::f128, Subtarget.hasVLX() ? &X86::VR128XRegClass + : &X86::VR128RegClass); ValueTypeActions.setTypeAction(MVT::f128, TypeSoftenFloat); setOperationAction(ISD::FABS , MVT::f128, Custom); setOperationAction(ISD::FNEG , MVT::f128, Custom); @@ -778,11 +797,26 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, addRegisterClass(MVT::v2i64, Subtarget.hasVLX() ? 
&X86::VR128XRegClass : &X86::VR128RegClass); + for (auto VT : { MVT::v2i8, MVT::v4i8, MVT::v8i8, + MVT::v2i16, MVT::v4i16, MVT::v2i32 }) { + setOperationAction(ISD::SDIV, VT, Custom); + setOperationAction(ISD::SREM, VT, Custom); + setOperationAction(ISD::UDIV, VT, Custom); + setOperationAction(ISD::UREM, VT, Custom); + } + + setOperationAction(ISD::MUL, MVT::v2i8, Custom); + setOperationAction(ISD::MUL, MVT::v2i16, Custom); + setOperationAction(ISD::MUL, MVT::v2i32, Custom); + setOperationAction(ISD::MUL, MVT::v4i8, Custom); + setOperationAction(ISD::MUL, MVT::v4i16, Custom); + setOperationAction(ISD::MUL, MVT::v8i8, Custom); + setOperationAction(ISD::MUL, MVT::v16i8, Custom); setOperationAction(ISD::MUL, MVT::v4i32, Custom); setOperationAction(ISD::MUL, MVT::v2i64, Custom); - setOperationAction(ISD::UMUL_LOHI, MVT::v4i32, Custom); - setOperationAction(ISD::SMUL_LOHI, MVT::v4i32, Custom); + setOperationAction(ISD::MULHU, MVT::v4i32, Custom); + setOperationAction(ISD::MULHS, MVT::v4i32, Custom); setOperationAction(ISD::MULHU, MVT::v16i8, Custom); setOperationAction(ISD::MULHS, MVT::v16i8, Custom); setOperationAction(ISD::MULHU, MVT::v8i16, Legal); @@ -799,6 +833,26 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, setOperationAction(ISD::UMIN, VT, VT == MVT::v16i8 ? Legal : Custom); } + setOperationAction(ISD::UADDSAT, MVT::v16i8, Legal); + setOperationAction(ISD::SADDSAT, MVT::v16i8, Legal); + setOperationAction(ISD::USUBSAT, MVT::v16i8, Legal); + setOperationAction(ISD::SSUBSAT, MVT::v16i8, Legal); + setOperationAction(ISD::UADDSAT, MVT::v8i16, Legal); + setOperationAction(ISD::SADDSAT, MVT::v8i16, Legal); + setOperationAction(ISD::USUBSAT, MVT::v8i16, Legal); + setOperationAction(ISD::SSUBSAT, MVT::v8i16, Legal); + + if (!ExperimentalVectorWideningLegalization) { + // Use widening instead of promotion. + for (auto VT : { MVT::v8i8, MVT::v4i8, MVT::v2i8, + MVT::v4i16, MVT::v2i16 }) { + setOperationAction(ISD::UADDSAT, VT, Custom); + setOperationAction(ISD::SADDSAT, VT, Custom); + setOperationAction(ISD::USUBSAT, VT, Custom); + setOperationAction(ISD::SSUBSAT, VT, Custom); + } + } + setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::v8i16, Custom); setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::v4i32, Custom); setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::v4f32, Custom); @@ -813,7 +867,7 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, for (auto VT : { MVT::v16i8, MVT::v8i16, MVT::v4i32, MVT::v2i64 }) { setOperationAction(ISD::SETCC, VT, Custom); setOperationAction(ISD::CTPOP, VT, Custom); - setOperationAction(ISD::CTTZ, VT, Custom); + setOperationAction(ISD::ABS, VT, Custom); // The condition codes aren't legal in SSE/AVX and under AVX512 we use // setcc all the way to isel and prefer SETGT in some isel patterns. @@ -834,9 +888,6 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, // scalars) and extend in-register to a legal 128-bit vector type. For sext // loads these must work with a single scalar load. 
for (MVT VT : MVT::integer_vector_valuetypes()) { - setLoadExtAction(ISD::SEXTLOAD, VT, MVT::v4i8, Custom); - setLoadExtAction(ISD::SEXTLOAD, VT, MVT::v4i16, Custom); - setLoadExtAction(ISD::SEXTLOAD, VT, MVT::v8i8, Custom); setLoadExtAction(ISD::EXTLOAD, VT, MVT::v2i8, Custom); setLoadExtAction(ISD::EXTLOAD, VT, MVT::v2i16, Custom); setLoadExtAction(ISD::EXTLOAD, VT, MVT::v2i32, Custom); @@ -857,21 +908,36 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, setOperationAction(ISD::EXTRACT_VECTOR_ELT, VT, Custom); } - // Promote v16i8, v8i16, v4i32 load, select, and, or, xor to v2i64. - for (auto VT : { MVT::v16i8, MVT::v8i16, MVT::v4i32 }) { - setOperationPromotedToType(ISD::AND, VT, MVT::v2i64); - setOperationPromotedToType(ISD::OR, VT, MVT::v2i64); - setOperationPromotedToType(ISD::XOR, VT, MVT::v2i64); - setOperationPromotedToType(ISD::LOAD, VT, MVT::v2i64); - setOperationPromotedToType(ISD::SELECT, VT, MVT::v2i64); - } - // Custom lower v2i64 and v2f64 selects. setOperationAction(ISD::SELECT, MVT::v2f64, Custom); setOperationAction(ISD::SELECT, MVT::v2i64, Custom); + setOperationAction(ISD::SELECT, MVT::v4i32, Custom); + setOperationAction(ISD::SELECT, MVT::v8i16, Custom); + setOperationAction(ISD::SELECT, MVT::v16i8, Custom); setOperationAction(ISD::FP_TO_SINT, MVT::v4i32, Legal); setOperationAction(ISD::FP_TO_SINT, MVT::v2i32, Custom); + setOperationAction(ISD::FP_TO_SINT, MVT::v2i16, Custom); + + // Custom legalize these to avoid over promotion or custom promotion. + setOperationAction(ISD::FP_TO_SINT, MVT::v2i8, Custom); + setOperationAction(ISD::FP_TO_SINT, MVT::v4i8, Custom); + setOperationAction(ISD::FP_TO_SINT, MVT::v8i8, Custom); + setOperationAction(ISD::FP_TO_SINT, MVT::v2i16, Custom); + setOperationAction(ISD::FP_TO_SINT, MVT::v4i16, Custom); + setOperationAction(ISD::FP_TO_UINT, MVT::v2i8, Custom); + setOperationAction(ISD::FP_TO_UINT, MVT::v4i8, Custom); + setOperationAction(ISD::FP_TO_UINT, MVT::v8i8, Custom); + setOperationAction(ISD::FP_TO_UINT, MVT::v2i16, Custom); + setOperationAction(ISD::FP_TO_UINT, MVT::v4i16, Custom); + + // By marking FP_TO_SINT v8i16 as Custom, will trick type legalization into + // promoting v8i8 FP_TO_UINT into FP_TO_SINT. When the v8i16 FP_TO_SINT is + // split again based on the input type, this will cause an AssertSExt i16 to + // be emitted instead of an AssertZExt. This will allow packssdw followed by + // packuswb to be used to truncate to v8i8. This is necessary since packusdw + // isn't available until sse4.1. + setOperationAction(ISD::FP_TO_SINT, MVT::v8i16, Custom); setOperationAction(ISD::SINT_TO_FP, MVT::v4i32, Legal); setOperationAction(ISD::SINT_TO_FP, MVT::v2i32, Custom); @@ -887,6 +953,18 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, for (MVT VT : MVT::fp_vector_valuetypes()) setLoadExtAction(ISD::EXTLOAD, VT, MVT::v2f32, Legal); + // We want to legalize this to an f64 load rather than an i64 load on + // 64-bit targets and two 32-bit loads on a 32-bit target. Similar for + // store. 
+ setOperationAction(ISD::LOAD, MVT::v2f32, Custom); + setOperationAction(ISD::LOAD, MVT::v2i32, Custom); + setOperationAction(ISD::LOAD, MVT::v4i16, Custom); + setOperationAction(ISD::LOAD, MVT::v8i8, Custom); + setOperationAction(ISD::STORE, MVT::v2f32, Custom); + setOperationAction(ISD::STORE, MVT::v2i32, Custom); + setOperationAction(ISD::STORE, MVT::v4i16, Custom); + setOperationAction(ISD::STORE, MVT::v8i8, Custom); + setOperationAction(ISD::BITCAST, MVT::v2i32, Custom); setOperationAction(ISD::BITCAST, MVT::v4i16, Custom); setOperationAction(ISD::BITCAST, MVT::v8i8, Custom); @@ -897,6 +975,19 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, setOperationAction(ISD::SIGN_EXTEND_VECTOR_INREG, MVT::v4i32, Custom); setOperationAction(ISD::SIGN_EXTEND_VECTOR_INREG, MVT::v8i16, Custom); + if (ExperimentalVectorWideningLegalization) { + setOperationAction(ISD::SIGN_EXTEND, MVT::v4i64, Custom); + + setOperationAction(ISD::TRUNCATE, MVT::v2i8, Custom); + setOperationAction(ISD::TRUNCATE, MVT::v2i16, Custom); + setOperationAction(ISD::TRUNCATE, MVT::v2i32, Custom); + setOperationAction(ISD::TRUNCATE, MVT::v4i8, Custom); + setOperationAction(ISD::TRUNCATE, MVT::v4i16, Custom); + setOperationAction(ISD::TRUNCATE, MVT::v8i8, Custom); + } else { + setOperationAction(ISD::SIGN_EXTEND_VECTOR_INREG, MVT::v4i64, Custom); + } + // In the customized shift lowering, the legal v4i32/v2i64 cases // in AVX2 will be recognized. for (auto VT : { MVT::v16i8, MVT::v8i16, MVT::v4i32, MVT::v2i64 }) { @@ -907,7 +998,10 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, setOperationAction(ISD::ROTL, MVT::v4i32, Custom); setOperationAction(ISD::ROTL, MVT::v8i16, Custom); - setOperationAction(ISD::ROTL, MVT::v16i8, Custom); + + // With AVX512, expanding (and promoting the shifts) is better. + if (!Subtarget.hasAVX512()) + setOperationAction(ISD::ROTL, MVT::v16i8, Custom); } if (!Subtarget.useSoftFloat() && Subtarget.hasSSSE3()) { @@ -919,6 +1013,12 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, setOperationAction(ISD::CTLZ, MVT::v8i16, Custom); setOperationAction(ISD::CTLZ, MVT::v4i32, Custom); setOperationAction(ISD::CTLZ, MVT::v2i64, Custom); + + // These might be better off as horizontal vector ops. + setOperationAction(ISD::ADD, MVT::i16, Custom); + setOperationAction(ISD::ADD, MVT::i32, Custom); + setOperationAction(ISD::SUB, MVT::i16, Custom); + setOperationAction(ISD::SUB, MVT::i32, Custom); } if (!Subtarget.useSoftFloat() && Subtarget.hasSSE41()) { @@ -953,17 +1053,22 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, setOperationAction(ISD::ZERO_EXTEND_VECTOR_INREG, VT, Legal); } - for (MVT VT : MVT::integer_vector_valuetypes()) { - setLoadExtAction(ISD::SEXTLOAD, VT, MVT::v2i8, Custom); - setLoadExtAction(ISD::SEXTLOAD, VT, MVT::v2i16, Custom); - setLoadExtAction(ISD::SEXTLOAD, VT, MVT::v2i32, Custom); + if (!ExperimentalVectorWideningLegalization) { + // Avoid narrow result types when widening. The legal types are listed + // in the next loop. 
+ for (MVT VT : MVT::integer_vector_valuetypes()) { + setLoadExtAction(ISD::SEXTLOAD, VT, MVT::v2i8, Custom); + setLoadExtAction(ISD::SEXTLOAD, VT, MVT::v2i16, Custom); + setLoadExtAction(ISD::SEXTLOAD, VT, MVT::v2i32, Custom); + } } // SSE41 also has vector sign/zero extending loads, PMOV[SZ]X for (auto LoadExtOp : { ISD::SEXTLOAD, ISD::ZEXTLOAD }) { setLoadExtAction(LoadExtOp, MVT::v8i16, MVT::v8i8, Legal); setLoadExtAction(LoadExtOp, MVT::v4i32, MVT::v4i8, Legal); - setLoadExtAction(LoadExtOp, MVT::v2i32, MVT::v2i8, Legal); + if (!ExperimentalVectorWideningLegalization) + setLoadExtAction(LoadExtOp, MVT::v2i32, MVT::v2i8, Legal); setLoadExtAction(LoadExtOp, MVT::v2i64, MVT::v2i8, Legal); setLoadExtAction(LoadExtOp, MVT::v4i32, MVT::v4i16, Legal); setLoadExtAction(LoadExtOp, MVT::v2i64, MVT::v2i16, Legal); @@ -1039,12 +1144,26 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, setOperationAction(ISD::SRA, VT, Custom); } + if (ExperimentalVectorWideningLegalization) { + // These types need custom splitting if their input is a 128-bit vector. + setOperationAction(ISD::SIGN_EXTEND, MVT::v8i64, Custom); + setOperationAction(ISD::SIGN_EXTEND, MVT::v16i32, Custom); + setOperationAction(ISD::ZERO_EXTEND, MVT::v8i64, Custom); + setOperationAction(ISD::ZERO_EXTEND, MVT::v16i32, Custom); + } + setOperationAction(ISD::ROTL, MVT::v8i32, Custom); setOperationAction(ISD::ROTL, MVT::v16i16, Custom); - setOperationAction(ISD::ROTL, MVT::v32i8, Custom); + + // With BWI, expanding (and promoting the shifts) is the better. + if (!Subtarget.hasBWI()) + setOperationAction(ISD::ROTL, MVT::v32i8, Custom); setOperationAction(ISD::SELECT, MVT::v4f64, Custom); setOperationAction(ISD::SELECT, MVT::v4i64, Custom); + setOperationAction(ISD::SELECT, MVT::v8i32, Custom); + setOperationAction(ISD::SELECT, MVT::v16i16, Custom); + setOperationAction(ISD::SELECT, MVT::v32i8, Custom); setOperationAction(ISD::SELECT, MVT::v8f32, Custom); for (auto VT : { MVT::v16i16, MVT::v8i32, MVT::v4i64 }) { @@ -1061,9 +1180,11 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, for (auto VT : { MVT::v32i8, MVT::v16i16, MVT::v8i32, MVT::v4i64 }) { setOperationAction(ISD::SETCC, VT, Custom); setOperationAction(ISD::CTPOP, VT, Custom); - setOperationAction(ISD::CTTZ, VT, Custom); setOperationAction(ISD::CTLZ, VT, Custom); + // TODO - remove this once 256-bit X86ISD::ANDNP correctly split. + setOperationAction(ISD::CTTZ, VT, HasInt256 ? Expand : Custom); + // The condition codes aren't legal in SSE/AVX and under AVX512 we use // setcc all the way to isel and prefer SETGT in some isel patterns. setCondCodeAction(ISD::SETLT, VT, Custom); @@ -1086,19 +1207,28 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, setOperationAction(ISD::MUL, MVT::v16i16, HasInt256 ? Legal : Custom); setOperationAction(ISD::MUL, MVT::v32i8, Custom); - setOperationAction(ISD::UMUL_LOHI, MVT::v8i32, Custom); - setOperationAction(ISD::SMUL_LOHI, MVT::v8i32, Custom); - + setOperationAction(ISD::MULHU, MVT::v8i32, Custom); + setOperationAction(ISD::MULHS, MVT::v8i32, Custom); setOperationAction(ISD::MULHU, MVT::v16i16, HasInt256 ? Legal : Custom); setOperationAction(ISD::MULHS, MVT::v16i16, HasInt256 ? 
Legal : Custom); setOperationAction(ISD::MULHU, MVT::v32i8, Custom); setOperationAction(ISD::MULHS, MVT::v32i8, Custom); + setOperationAction(ISD::ABS, MVT::v4i64, Custom); setOperationAction(ISD::SMAX, MVT::v4i64, Custom); setOperationAction(ISD::UMAX, MVT::v4i64, Custom); setOperationAction(ISD::SMIN, MVT::v4i64, Custom); setOperationAction(ISD::UMIN, MVT::v4i64, Custom); + setOperationAction(ISD::UADDSAT, MVT::v32i8, HasInt256 ? Legal : Custom); + setOperationAction(ISD::SADDSAT, MVT::v32i8, HasInt256 ? Legal : Custom); + setOperationAction(ISD::USUBSAT, MVT::v32i8, HasInt256 ? Legal : Custom); + setOperationAction(ISD::SSUBSAT, MVT::v32i8, HasInt256 ? Legal : Custom); + setOperationAction(ISD::UADDSAT, MVT::v16i16, HasInt256 ? Legal : Custom); + setOperationAction(ISD::SADDSAT, MVT::v16i16, HasInt256 ? Legal : Custom); + setOperationAction(ISD::USUBSAT, MVT::v16i16, HasInt256 ? Legal : Custom); + setOperationAction(ISD::SSUBSAT, MVT::v16i16, HasInt256 ? Legal : Custom); + for (auto VT : { MVT::v32i8, MVT::v16i16, MVT::v8i32 }) { setOperationAction(ISD::ABS, VT, HasInt256 ? Legal : Custom); setOperationAction(ISD::SMAX, VT, HasInt256 ? Legal : Custom); @@ -1107,11 +1237,12 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, setOperationAction(ISD::UMIN, VT, HasInt256 ? Legal : Custom); } - if (HasInt256) { - setOperationAction(ISD::SIGN_EXTEND_VECTOR_INREG, MVT::v4i64, Custom); - setOperationAction(ISD::SIGN_EXTEND_VECTOR_INREG, MVT::v8i32, Custom); - setOperationAction(ISD::SIGN_EXTEND_VECTOR_INREG, MVT::v16i16, Custom); + for (auto VT : {MVT::v16i16, MVT::v8i32, MVT::v4i64}) { + setOperationAction(ISD::SIGN_EXTEND_VECTOR_INREG, VT, Custom); + setOperationAction(ISD::ZERO_EXTEND_VECTOR_INREG, VT, Custom); + } + if (HasInt256) { // The custom lowering for UINT_TO_FP for v8i32 becomes interesting // when we have a 256bit-wide blend with immediate. setOperationAction(ISD::UINT_TO_FP, MVT::v8i32, Custom); @@ -1156,15 +1287,6 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, if (HasInt256) setOperationAction(ISD::VSELECT, MVT::v32i8, Legal); - // Promote v32i8, v16i16, v8i32 select, and, or, xor to v4i64. - for (auto VT : { MVT::v32i8, MVT::v16i16, MVT::v8i32 }) { - setOperationPromotedToType(ISD::AND, VT, MVT::v4i64); - setOperationPromotedToType(ISD::OR, VT, MVT::v4i64); - setOperationPromotedToType(ISD::XOR, VT, MVT::v4i64); - setOperationPromotedToType(ISD::LOAD, VT, MVT::v4i64); - setOperationPromotedToType(ISD::SELECT, VT, MVT::v4i64); - } - if (HasInt256) { // Custom legalize 2x32 to get a little better code. setOperationAction(ISD::MGATHER, MVT::v2f32, Custom); @@ -1224,6 +1346,10 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, setOperationAction(ISD::SETCC, VT, Custom); setOperationAction(ISD::SELECT, VT, Custom); setOperationAction(ISD::TRUNCATE, VT, Custom); + setOperationAction(ISD::UADDSAT, VT, Custom); + setOperationAction(ISD::SADDSAT, VT, Custom); + setOperationAction(ISD::USUBSAT, VT, Custom); + setOperationAction(ISD::SSUBSAT, VT, Custom); setOperationAction(ISD::BUILD_VECTOR, VT, Custom); setOperationAction(ISD::EXTRACT_VECTOR_ELT, VT, Custom); @@ -1307,6 +1433,13 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, setOperationAction(ISD::SIGN_EXTEND, MVT::v16i32, Custom); setOperationAction(ISD::SIGN_EXTEND, MVT::v8i64, Custom); + if (ExperimentalVectorWideningLegalization) { + // Need to custom widen this if we don't have AVX512BW. 
+ setOperationAction(ISD::ANY_EXTEND, MVT::v8i8, Custom); + setOperationAction(ISD::ZERO_EXTEND, MVT::v8i8, Custom); + setOperationAction(ISD::SIGN_EXTEND, MVT::v8i8, Custom); + } + for (auto VT : { MVT::v16f32, MVT::v8f64 }) { setOperationAction(ISD::FFLOOR, VT, Legal); setOperationAction(ISD::FCEIL, VT, Legal); @@ -1315,12 +1448,11 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, setOperationAction(ISD::FNEARBYINT, VT, Legal); } - setOperationAction(ISD::SIGN_EXTEND_VECTOR_INREG, MVT::v8i64, Custom); - setOperationAction(ISD::SIGN_EXTEND_VECTOR_INREG, MVT::v16i32, Custom); - // Without BWI we need to use custom lowering to handle MVT::v64i8 input. - setOperationAction(ISD::SIGN_EXTEND_VECTOR_INREG, MVT::v64i8, Custom); - setOperationAction(ISD::ZERO_EXTEND_VECTOR_INREG, MVT::v64i8, Custom); + for (auto VT : {MVT::v16i32, MVT::v8i64, MVT::v64i8}) { + setOperationAction(ISD::SIGN_EXTEND_VECTOR_INREG, VT, Custom); + setOperationAction(ISD::ZERO_EXTEND_VECTOR_INREG, VT, Custom); + } setOperationAction(ISD::CONCAT_VECTORS, MVT::v8f64, Custom); setOperationAction(ISD::CONCAT_VECTORS, MVT::v8i64, Custom); @@ -1330,11 +1462,14 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, setOperationAction(ISD::MUL, MVT::v8i64, Custom); setOperationAction(ISD::MUL, MVT::v16i32, Legal); - setOperationAction(ISD::UMUL_LOHI, MVT::v16i32, Custom); - setOperationAction(ISD::SMUL_LOHI, MVT::v16i32, Custom); + setOperationAction(ISD::MULHU, MVT::v16i32, Custom); + setOperationAction(ISD::MULHS, MVT::v16i32, Custom); setOperationAction(ISD::SELECT, MVT::v8f64, Custom); setOperationAction(ISD::SELECT, MVT::v8i64, Custom); + setOperationAction(ISD::SELECT, MVT::v16i32, Custom); + setOperationAction(ISD::SELECT, MVT::v32i16, Custom); + setOperationAction(ISD::SELECT, MVT::v64i8, Custom); setOperationAction(ISD::SELECT, MVT::v16f32, Custom); for (auto VT : { MVT::v16i32, MVT::v8i64 }) { @@ -1347,7 +1482,6 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, setOperationAction(ISD::SHL, VT, Custom); setOperationAction(ISD::SRA, VT, Custom); setOperationAction(ISD::CTPOP, VT, Custom); - setOperationAction(ISD::CTTZ, VT, Custom); setOperationAction(ISD::ROTL, VT, Custom); setOperationAction(ISD::ROTR, VT, Custom); setOperationAction(ISD::SETCC, VT, Custom); @@ -1358,13 +1492,6 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, setCondCodeAction(ISD::SETLE, VT, Custom); } - // Need to promote to 64-bit even though we have 32-bit masked instructions - // because the IR optimizers rearrange bitcasts around logic ops leaving - // too many variations to handle if we don't promote them. - setOperationPromotedToType(ISD::AND, MVT::v16i32, MVT::v8i64); - setOperationPromotedToType(ISD::OR, MVT::v16i32, MVT::v8i64); - setOperationPromotedToType(ISD::XOR, MVT::v16i32, MVT::v8i64); - if (Subtarget.hasDQI()) { setOperationAction(ISD::SINT_TO_FP, MVT::v8i64, Legal); setOperationAction(ISD::UINT_TO_FP, MVT::v8i64, Legal); @@ -1378,7 +1505,6 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, // NonVLX sub-targets extend 128/256 vectors to use the 512 version. 
for (auto VT : { MVT::v16i32, MVT::v8i64} ) { setOperationAction(ISD::CTLZ, VT, Legal); - setOperationAction(ISD::CTTZ_ZERO_UNDEF, VT, Custom); } } // Subtarget.hasCDI() @@ -1407,16 +1533,18 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, setOperationAction(ISD::MGATHER, VT, Custom); setOperationAction(ISD::MSCATTER, VT, Custom); } - for (auto VT : { MVT::v64i8, MVT::v32i16, MVT::v16i32 }) { - setOperationPromotedToType(ISD::LOAD, VT, MVT::v8i64); - setOperationPromotedToType(ISD::SELECT, VT, MVT::v8i64); - } - // Need to custom split v32i16/v64i8 bitcasts. if (!Subtarget.hasBWI()) { setOperationAction(ISD::BITCAST, MVT::v32i16, Custom); setOperationAction(ISD::BITCAST, MVT::v64i8, Custom); } + + if (Subtarget.hasVBMI2()) { + for (auto VT : { MVT::v16i32, MVT::v8i64 }) { + setOperationAction(ISD::FSHL, VT, Custom); + setOperationAction(ISD::FSHR, VT, Custom); + } + } }// has AVX-512 // This block controls legalization for operations that don't have @@ -1468,7 +1596,6 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, if (Subtarget.hasCDI()) { for (auto VT : { MVT::v4i32, MVT::v8i32, MVT::v2i64, MVT::v4i64 }) { setOperationAction(ISD::CTLZ, VT, Legal); - setOperationAction(ISD::CTTZ_ZERO_UNDEF, VT, Custom); } } // Subtarget.hasCDI() @@ -1490,6 +1617,10 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, setOperationAction(ISD::SUB, VT, Custom); setOperationAction(ISD::MUL, VT, Custom); setOperationAction(ISD::VSELECT, VT, Expand); + setOperationAction(ISD::UADDSAT, VT, Custom); + setOperationAction(ISD::SADDSAT, VT, Custom); + setOperationAction(ISD::USUBSAT, VT, Custom); + setOperationAction(ISD::SSUBSAT, VT, Custom); setOperationAction(ISD::TRUNCATE, VT, Custom); setOperationAction(ISD::SETCC, VT, Custom); @@ -1550,6 +1681,7 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, setOperationAction(ISD::BITREVERSE, MVT::v64i8, Custom); setOperationAction(ISD::SIGN_EXTEND_VECTOR_INREG, MVT::v32i16, Custom); + setOperationAction(ISD::ZERO_EXTEND_VECTOR_INREG, MVT::v32i16, Custom); setTruncStoreAction(MVT::v32i16, MVT::v32i8, Legal); @@ -1563,17 +1695,21 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, setOperationAction(ISD::MLOAD, VT, Legal); setOperationAction(ISD::MSTORE, VT, Legal); setOperationAction(ISD::CTPOP, VT, Custom); - setOperationAction(ISD::CTTZ, VT, Custom); setOperationAction(ISD::CTLZ, VT, Custom); setOperationAction(ISD::SMAX, VT, Legal); setOperationAction(ISD::UMAX, VT, Legal); setOperationAction(ISD::SMIN, VT, Legal); setOperationAction(ISD::UMIN, VT, Legal); setOperationAction(ISD::SETCC, VT, Custom); + setOperationAction(ISD::UADDSAT, VT, Legal); + setOperationAction(ISD::SADDSAT, VT, Legal); + setOperationAction(ISD::USUBSAT, VT, Legal); + setOperationAction(ISD::SSUBSAT, VT, Legal); - setOperationPromotedToType(ISD::AND, VT, MVT::v8i64); - setOperationPromotedToType(ISD::OR, VT, MVT::v8i64); - setOperationPromotedToType(ISD::XOR, VT, MVT::v8i64); + // The condition codes aren't legal in SSE/AVX and under AVX512 we use + // setcc all the way to isel and prefer SETGT in some isel patterns. 
+ setCondCodeAction(ISD::SETLT, VT, Custom); + setCondCodeAction(ISD::SETLE, VT, Custom); } for (auto ExtType : {ISD::ZEXTLOAD, ISD::SEXTLOAD}) { @@ -1584,6 +1720,11 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, for (auto VT : { MVT::v64i8, MVT::v32i16 }) setOperationAction(ISD::CTPOP, VT, Legal); } + + if (Subtarget.hasVBMI2()) { + setOperationAction(ISD::FSHL, MVT::v32i16, Custom); + setOperationAction(ISD::FSHR, MVT::v32i16, Custom); + } } if (!Subtarget.useSoftFloat() && Subtarget.hasBWI()) { @@ -1630,6 +1771,15 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, setTruncStoreAction(MVT::v16i16, MVT::v16i8, Legal); setTruncStoreAction(MVT::v8i16, MVT::v8i8, Legal); } + + if (Subtarget.hasVBMI2()) { + // TODO: Make these legal even without VLX? + for (auto VT : { MVT::v8i16, MVT::v4i32, MVT::v2i64, + MVT::v16i16, MVT::v8i32, MVT::v4i64 }) { + setOperationAction(ISD::FSHL, VT, Custom); + setOperationAction(ISD::FSHR, VT, Custom); + } + } } // We want to custom lower some of our intrinsics. @@ -1731,8 +1881,6 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, setTargetDAGCombine(ISD::ANY_EXTEND); setTargetDAGCombine(ISD::SIGN_EXTEND); setTargetDAGCombine(ISD::SIGN_EXTEND_INREG); - setTargetDAGCombine(ISD::SIGN_EXTEND_VECTOR_INREG); - setTargetDAGCombine(ISD::ZERO_EXTEND_VECTOR_INREG); setTargetDAGCombine(ISD::SINT_TO_FP); setTargetDAGCombine(ISD::UINT_TO_FP); setTargetDAGCombine(ISD::SETCC); @@ -1787,13 +1935,13 @@ SDValue X86TargetLowering::emitStackGuardXorFP(SelectionDAG &DAG, SDValue Val, } TargetLoweringBase::LegalizeTypeAction -X86TargetLowering::getPreferredVectorAction(EVT VT) const { +X86TargetLowering::getPreferredVectorAction(MVT VT) const { if (VT == MVT::v32i1 && Subtarget.hasAVX512() && !Subtarget.hasBWI()) return TypeSplitVector; if (ExperimentalVectorWideningLegalization && VT.getVectorNumElements() != 1 && - VT.getVectorElementType().getSimpleVT() != MVT::i1) + VT.getVectorElementType() != MVT::i1) return TypeWidenVector; return TargetLoweringBase::getPreferredVectorAction(VT); @@ -1926,7 +2074,8 @@ X86TargetLowering::getOptimalMemOpType(uint64_t Size, if (Subtarget.hasSSE2()) return MVT::v16i8; // TODO: Can SSE1 handle a byte vector? - if (Subtarget.hasSSE1()) + // If we have SSE1 registers we should be able to use them. + if (Subtarget.hasSSE1() && (Subtarget.is64Bit() || Subtarget.hasX87())) return MVT::v4f32; } else if ((!IsMemset || ZeroMemset) && !MemcpyStrSrc && Size >= 8 && !Subtarget.is64Bit() && Subtarget.hasSSE2()) { @@ -3138,7 +3287,7 @@ SDValue X86TargetLowering::LowerFormalArguments( } // If value is passed via pointer - do a load. - if (VA.getLocInfo() == CCValAssign::Indirect) + if (VA.getLocInfo() == CCValAssign::Indirect && !Ins[I].Flags.isByVal()) ArgValue = DAG.getLoad(VA.getValVT(), dl, Chain, ArgValue, MachinePointerInfo()); @@ -3621,13 +3770,29 @@ X86TargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, Arg = DAG.getBitcast(RegVT, Arg); break; case CCValAssign::Indirect: { - // Store the argument. - SDValue SpillSlot = DAG.CreateStackTemporary(VA.getValVT()); - int FI = cast<FrameIndexSDNode>(SpillSlot)->getIndex(); - Chain = DAG.getStore( - Chain, dl, Arg, SpillSlot, - MachinePointerInfo::getFixedStack(DAG.getMachineFunction(), FI)); - Arg = SpillSlot; + if (isByVal) { + // Memcpy the argument to a temporary stack slot to prevent + // the caller from seeing any modifications the callee may make + // as guaranteed by the `byval` attribute. 
+ int FrameIdx = MF.getFrameInfo().CreateStackObject( + Flags.getByValSize(), std::max(16, (int)Flags.getByValAlign()), + false); + SDValue StackSlot = + DAG.getFrameIndex(FrameIdx, getPointerTy(DAG.getDataLayout())); + Chain = + CreateCopyOfByValArgument(Arg, StackSlot, Chain, Flags, DAG, dl); + // From now on treat this as a regular pointer + Arg = StackSlot; + isByVal = false; + } else { + // Store the argument. + SDValue SpillSlot = DAG.CreateStackTemporary(VA.getValVT()); + int FI = cast<FrameIndexSDNode>(SpillSlot)->getIndex(); + Chain = DAG.getStore( + Chain, dl, Arg, SpillSlot, + MachinePointerInfo::getFixedStack(DAG.getMachineFunction(), FI)); + Arg = SpillSlot; + } break; } } @@ -4405,6 +4570,7 @@ static bool isTargetShuffleVariableMask(unsigned Opcode) { case X86ISD::VPERMV3: return true; // 'Faux' Target Shuffles. + case ISD::OR: case ISD::AND: case X86ISD::ANDNP: return true; @@ -4686,6 +4852,14 @@ bool X86TargetLowering::shouldConvertConstantLoadToIntImm(const APInt &Imm, return true; } +bool X86TargetLowering::reduceSelectOfFPConstantLoads(bool IsFPSetCC) const { + // If we are using XMM registers in the ABI and the condition of the select is + // a floating-point compare and we have blendv or conditional move, then it is + // cheaper to select instead of doing a cross-register move and creating a + // load that depends on the compare result. + return !IsFPSetCC || !Subtarget.isTarget64BitLP64() || !Subtarget.hasAVX(); +} + bool X86TargetLowering::convertSelectOfConstantsToMath(EVT VT) const { // TODO: It might be a win to ease or lift this restriction, but the generic // folds in DAGCombiner conflict with vector folds for an AVX512 target. @@ -4695,6 +4869,31 @@ bool X86TargetLowering::convertSelectOfConstantsToMath(EVT VT) const { return true; } +bool X86TargetLowering::decomposeMulByConstant(EVT VT, SDValue C) const { + // TODO: We handle scalars using custom code, but generic combining could make + // that unnecessary. + APInt MulC; + if (!ISD::isConstantSplatVector(C.getNode(), MulC)) + return false; + + // If vector multiply is legal, assume that's faster than shl + add/sub. + // TODO: Multiply is a complex op with higher latency and lower througput in + // most implementations, so this check could be loosened based on type + // and/or a CPU attribute. + if (isOperationLegal(ISD::MUL, VT)) + return false; + + // shl+add, shl+sub, shl+add+neg + return (MulC + 1).isPowerOf2() || (MulC - 1).isPowerOf2() || + (1 - MulC).isPowerOf2() || (-(MulC + 1)).isPowerOf2(); +} + +bool X86TargetLowering::shouldUseStrictFP_TO_INT(EVT FpVT, EVT IntVT, + bool IsSigned) const { + // f80 UINT_TO_FP is more efficient using Strict code if FCMOV is available. + return !IsSigned && FpVT == MVT::f80 && Subtarget.hasCMov(); +} + bool X86TargetLowering::isExtractSubvectorCheap(EVT ResVT, EVT SrcVT, unsigned Index) const { if (!isOperationLegalOrCustom(ISD::EXTRACT_SUBVECTOR, ResVT)) @@ -4709,6 +4908,18 @@ bool X86TargetLowering::isExtractSubvectorCheap(EVT ResVT, EVT SrcVT, return (Index % ResVT.getVectorNumElements()) == 0; } +bool X86TargetLowering::shouldScalarizeBinop(SDValue VecOp) const { + // If the vector op is not supported, try to convert to scalar. + EVT VecVT = VecOp.getValueType(); + if (!isOperationLegalOrCustomOrPromote(VecOp.getOpcode(), VecVT)) + return true; + + // If the vector op is supported, but the scalar op is not, the transform may + // not be worthwhile. 
+ EVT ScalarVT = VecVT.getScalarType(); + return isOperationLegalOrCustomOrPromote(VecOp.getOpcode(), ScalarVT); +} + bool X86TargetLowering::isCheapToSpeculateCttz() const { // Speculate cttz only if we can directly use TZCNT. return Subtarget.hasBMI(); @@ -4721,7 +4932,11 @@ bool X86TargetLowering::isCheapToSpeculateCtlz() const { bool X86TargetLowering::isLoadBitCastBeneficial(EVT LoadVT, EVT BitcastVT) const { - if (!Subtarget.hasDQI() && BitcastVT == MVT::v8i1) + if (!Subtarget.hasAVX512() && !LoadVT.isVector() && BitcastVT.isVector() && + BitcastVT.getVectorElementType() == MVT::i1) + return false; + + if (!Subtarget.hasDQI() && BitcastVT == MVT::v8i1 && LoadVT == MVT::i8) return false; return TargetLowering::isLoadBitCastBeneficial(LoadVT, BitcastVT); @@ -4763,17 +4978,14 @@ bool X86TargetLowering::hasAndNotCompare(SDValue Y) const { if (VT != MVT::i32 && VT != MVT::i64) return false; - // A mask and compare against constant is ok for an 'andn' too - // even though the BMI instruction doesn't have an immediate form. - - return true; + return !isa<ConstantSDNode>(Y); } bool X86TargetLowering::hasAndNot(SDValue Y) const { EVT VT = Y.getValueType(); - if (!VT.isVector()) // x86 can't form 'andn' with an immediate. - return !isa<ConstantSDNode>(Y) && hasAndNotCompare(Y); + if (!VT.isVector()) + return hasAndNotCompare(Y); // Vector. @@ -4800,6 +5012,12 @@ bool X86TargetLowering::preferShiftsToClearExtremeBits(SDValue Y) const { return true; } +bool X86TargetLowering::shouldSplatInsEltVarIndex(EVT VT) const { + // Any legal vector type can be splatted more efficiently than + // loading/spilling from memory. + return isTypeLegal(VT); +} + MVT X86TargetLowering::hasFastEqualityCompare(unsigned NumBits) const { MVT VT = MVT::getIntegerVT(NumBits); if (isTypeLegal(VT)) @@ -5408,24 +5626,29 @@ static SDValue getOnesVector(EVT VT, SelectionDAG &DAG, const SDLoc &dl) { return DAG.getBitcast(VT, Vec); } -static SDValue getExtendInVec(unsigned Opc, const SDLoc &DL, EVT VT, SDValue In, +static SDValue getExtendInVec(bool Signed, const SDLoc &DL, EVT VT, SDValue In, SelectionDAG &DAG) { EVT InVT = In.getValueType(); - assert((X86ISD::VSEXT == Opc || X86ISD::VZEXT == Opc) && "Unexpected opcode"); - - if (VT.is128BitVector() && InVT.is128BitVector()) - return X86ISD::VSEXT == Opc ? DAG.getSignExtendVectorInReg(In, DL, VT) - : DAG.getZeroExtendVectorInReg(In, DL, VT); + assert(VT.isVector() && InVT.isVector() && "Expected vector VTs."); // For 256-bit vectors, we only need the lower (128-bit) input half. // For 512-bit vectors, we only need the lower input half or quarter. - if (VT.getSizeInBits() > 128 && InVT.getSizeInBits() > 128) { - int Scale = VT.getScalarSizeInBits() / InVT.getScalarSizeInBits(); + if (InVT.getSizeInBits() > 128) { + assert(VT.getSizeInBits() == InVT.getSizeInBits() && + "Expected VTs to be the same size!"); + unsigned Scale = VT.getScalarSizeInBits() / InVT.getScalarSizeInBits(); In = extractSubVector(In, 0, DAG, DL, - std::max(128, (int)VT.getSizeInBits() / Scale)); + std::max(128U, VT.getSizeInBits() / Scale)); + InVT = In.getValueType(); } - return DAG.getNode(Opc, DL, VT, In); + if (VT.getVectorNumElements() == InVT.getVectorNumElements()) + return DAG.getNode(Signed ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND, + DL, VT, In); + + return DAG.getNode(Signed ? ISD::SIGN_EXTEND_VECTOR_INREG + : ISD::ZERO_EXTEND_VECTOR_INREG, + DL, VT, In); } /// Returns a vector_shuffle node for an unpackl operation. 
@@ -5463,19 +5686,6 @@ static SDValue getShuffleVectorZeroOrUndef(SDValue V2, int Idx, return DAG.getVectorShuffle(VT, SDLoc(V2), V1, V2, MaskVec); } -static SDValue peekThroughBitcasts(SDValue V) { - while (V.getNode() && V.getOpcode() == ISD::BITCAST) - V = V.getOperand(0); - return V; -} - -static SDValue peekThroughOneUseBitcasts(SDValue V) { - while (V.getNode() && V.getOpcode() == ISD::BITCAST && - V.getOperand(0).hasOneUse()) - V = V.getOperand(0); - return V; -} - // Peek through EXTRACT_SUBVECTORs - typically used for AVX1 256-bit intops. static SDValue peekThroughEXTRACT_SUBVECTORs(SDValue V) { while (V.getOpcode() == ISD::EXTRACT_SUBVECTOR) @@ -5496,10 +5706,10 @@ static const Constant *getTargetConstantFromNode(SDValue Op) { Ptr = Ptr->getOperand(0); auto *CNode = dyn_cast<ConstantPoolSDNode>(Ptr); - if (!CNode || CNode->isMachineConstantPoolEntry()) + if (!CNode || CNode->isMachineConstantPoolEntry() || CNode->getOffset() != 0) return nullptr; - return dyn_cast<Constant>(CNode->getConstVal()); + return CNode->getConstVal(); } // Extract raw constant bits from constant pools. @@ -5632,15 +5842,34 @@ static bool getTargetConstantBitsFromNode(SDValue Op, unsigned EltSizeInBits, } return CastBitData(UndefSrcElts, SrcEltBits); } + if (ISD::isBuildVectorOfConstantFPSDNodes(Op.getNode())) { + unsigned SrcEltSizeInBits = VT.getScalarSizeInBits(); + unsigned NumSrcElts = SizeInBits / SrcEltSizeInBits; + + APInt UndefSrcElts(NumSrcElts, 0); + SmallVector<APInt, 64> SrcEltBits(NumSrcElts, APInt(SrcEltSizeInBits, 0)); + for (unsigned i = 0, e = Op.getNumOperands(); i != e; ++i) { + const SDValue &Src = Op.getOperand(i); + if (Src.isUndef()) { + UndefSrcElts.setBit(i); + continue; + } + auto *Cst = cast<ConstantFPSDNode>(Src); + APInt RawBits = Cst->getValueAPF().bitcastToAPInt(); + SrcEltBits[i] = RawBits.zextOrTrunc(SrcEltSizeInBits); + } + return CastBitData(UndefSrcElts, SrcEltBits); + } // Extract constant bits from constant pool vector. if (auto *Cst = getTargetConstantFromNode(Op)) { Type *CstTy = Cst->getType(); - if (!CstTy->isVectorTy() || (SizeInBits != CstTy->getPrimitiveSizeInBits())) + unsigned CstSizeInBits = CstTy->getPrimitiveSizeInBits(); + if (!CstTy->isVectorTy() || (CstSizeInBits % SizeInBits) != 0) return false; unsigned SrcEltSizeInBits = CstTy->getScalarSizeInBits(); - unsigned NumSrcElts = CstTy->getVectorNumElements(); + unsigned NumSrcElts = SizeInBits / SrcEltSizeInBits; APInt UndefSrcElts(NumSrcElts, 0); SmallVector<APInt, 64> SrcEltBits(NumSrcElts, APInt(SrcEltSizeInBits, 0)); @@ -5685,19 +5914,107 @@ static bool getTargetConstantBitsFromNode(SDValue Op, unsigned EltSizeInBits, return CastBitData(UndefSrcElts, SrcEltBits); } + // Extract constant bits from a subvector's source. + if (Op.getOpcode() == ISD::EXTRACT_SUBVECTOR && + isa<ConstantSDNode>(Op.getOperand(1))) { + // TODO - support extract_subvector through bitcasts. 
+ if (EltSizeInBits != VT.getScalarSizeInBits()) + return false; + + if (getTargetConstantBitsFromNode(Op.getOperand(0), EltSizeInBits, + UndefElts, EltBits, AllowWholeUndefs, + AllowPartialUndefs)) { + EVT SrcVT = Op.getOperand(0).getValueType(); + unsigned NumSrcElts = SrcVT.getVectorNumElements(); + unsigned NumSubElts = VT.getVectorNumElements(); + unsigned BaseIdx = Op.getConstantOperandVal(1); + UndefElts = UndefElts.extractBits(NumSubElts, BaseIdx); + if ((BaseIdx + NumSubElts) != NumSrcElts) + EltBits.erase(EltBits.begin() + BaseIdx + NumSubElts, EltBits.end()); + if (BaseIdx != 0) + EltBits.erase(EltBits.begin(), EltBits.begin() + BaseIdx); + return true; + } + } + + // Extract constant bits from shuffle node sources. + if (auto *SVN = dyn_cast<ShuffleVectorSDNode>(Op)) { + // TODO - support shuffle through bitcasts. + if (EltSizeInBits != VT.getScalarSizeInBits()) + return false; + + ArrayRef<int> Mask = SVN->getMask(); + if ((!AllowWholeUndefs || !AllowPartialUndefs) && + llvm::any_of(Mask, [](int M) { return M < 0; })) + return false; + + APInt UndefElts0, UndefElts1; + SmallVector<APInt, 32> EltBits0, EltBits1; + if (isAnyInRange(Mask, 0, NumElts) && + !getTargetConstantBitsFromNode(Op.getOperand(0), EltSizeInBits, + UndefElts0, EltBits0, AllowWholeUndefs, + AllowPartialUndefs)) + return false; + if (isAnyInRange(Mask, NumElts, 2 * NumElts) && + !getTargetConstantBitsFromNode(Op.getOperand(1), EltSizeInBits, + UndefElts1, EltBits1, AllowWholeUndefs, + AllowPartialUndefs)) + return false; + + UndefElts = APInt::getNullValue(NumElts); + for (int i = 0; i != (int)NumElts; ++i) { + int M = Mask[i]; + if (M < 0) { + UndefElts.setBit(i); + EltBits.push_back(APInt::getNullValue(EltSizeInBits)); + } else if (M < (int)NumElts) { + if (UndefElts0[M]) + UndefElts.setBit(i); + EltBits.push_back(EltBits0[M]); + } else { + if (UndefElts1[M - NumElts]) + UndefElts.setBit(i); + EltBits.push_back(EltBits1[M - NumElts]); + } + } + return true; + } + return false; } -static bool getTargetShuffleMaskIndices(SDValue MaskNode, - unsigned MaskEltSizeInBits, - SmallVectorImpl<uint64_t> &RawMask) { +static bool isConstantSplat(SDValue Op, APInt &SplatVal) { APInt UndefElts; - SmallVector<APInt, 64> EltBits; + SmallVector<APInt, 16> EltBits; + if (getTargetConstantBitsFromNode(Op, Op.getScalarValueSizeInBits(), + UndefElts, EltBits, true, false)) { + int SplatIndex = -1; + for (int i = 0, e = EltBits.size(); i != e; ++i) { + if (UndefElts[i]) + continue; + if (0 <= SplatIndex && EltBits[i] != EltBits[SplatIndex]) { + SplatIndex = -1; + break; + } + SplatIndex = i; + } + if (0 <= SplatIndex) { + SplatVal = EltBits[SplatIndex]; + return true; + } + } + + return false; +} +static bool getTargetShuffleMaskIndices(SDValue MaskNode, + unsigned MaskEltSizeInBits, + SmallVectorImpl<uint64_t> &RawMask, + APInt &UndefElts) { // Extract the raw target constant bits. - // FIXME: We currently don't support UNDEF bits or mask entries. + SmallVector<APInt, 64> EltBits; if (!getTargetConstantBitsFromNode(MaskNode, MaskEltSizeInBits, UndefElts, - EltBits, /* AllowWholeUndefs */ false, + EltBits, /* AllowWholeUndefs */ true, /* AllowPartialUndefs */ false)) return false; @@ -5726,6 +6043,31 @@ static void createPackShuffleMask(MVT VT, SmallVectorImpl<int> &Mask, } } +// Split the demanded elts of a PACKSS/PACKUS node between its operands. 
+static void getPackDemandedElts(EVT VT, const APInt &DemandedElts, + APInt &DemandedLHS, APInt &DemandedRHS) { + int NumLanes = VT.getSizeInBits() / 128; + int NumElts = DemandedElts.getBitWidth(); + int NumInnerElts = NumElts / 2; + int NumEltsPerLane = NumElts / NumLanes; + int NumInnerEltsPerLane = NumInnerElts / NumLanes; + + DemandedLHS = APInt::getNullValue(NumInnerElts); + DemandedRHS = APInt::getNullValue(NumInnerElts); + + // Map DemandedElts to the packed operands. + for (int Lane = 0; Lane != NumLanes; ++Lane) { + for (int Elt = 0; Elt != NumInnerEltsPerLane; ++Elt) { + int OuterIdx = (Lane * NumEltsPerLane) + Elt; + int InnerIdx = (Lane * NumInnerEltsPerLane) + Elt; + if (DemandedElts[OuterIdx]) + DemandedLHS.setBit(InnerIdx); + if (DemandedElts[OuterIdx + NumInnerEltsPerLane]) + DemandedRHS.setBit(InnerIdx); + } + } +} + /// Calculates the shuffle mask corresponding to the target-specific opcode. /// If the mask could be calculated, returns it in \p Mask, returns the shuffle /// operands in \p Ops, and returns true. @@ -5737,6 +6079,9 @@ static bool getTargetShuffleMask(SDNode *N, MVT VT, bool AllowSentinelZero, SmallVectorImpl<SDValue> &Ops, SmallVectorImpl<int> &Mask, bool &IsUnary) { unsigned NumElems = VT.getVectorNumElements(); + unsigned MaskEltSize = VT.getScalarSizeInBits(); + SmallVector<uint64_t, 32> RawMask; + APInt RawUndefs; SDValue ImmN; assert(Mask.empty() && "getTargetShuffleMask expects an empty Mask vector"); @@ -5744,26 +6089,26 @@ static bool getTargetShuffleMask(SDNode *N, MVT VT, bool AllowSentinelZero, IsUnary = false; bool IsFakeUnary = false; - switch(N->getOpcode()) { + switch (N->getOpcode()) { case X86ISD::BLENDI: assert(N->getOperand(0).getValueType() == VT && "Unexpected value type"); assert(N->getOperand(1).getValueType() == VT && "Unexpected value type"); - ImmN = N->getOperand(N->getNumOperands()-1); + ImmN = N->getOperand(N->getNumOperands() - 1); DecodeBLENDMask(NumElems, cast<ConstantSDNode>(ImmN)->getZExtValue(), Mask); IsUnary = IsFakeUnary = N->getOperand(0) == N->getOperand(1); break; case X86ISD::SHUFP: assert(N->getOperand(0).getValueType() == VT && "Unexpected value type"); assert(N->getOperand(1).getValueType() == VT && "Unexpected value type"); - ImmN = N->getOperand(N->getNumOperands()-1); - DecodeSHUFPMask(NumElems, VT.getScalarSizeInBits(), + ImmN = N->getOperand(N->getNumOperands() - 1); + DecodeSHUFPMask(NumElems, MaskEltSize, cast<ConstantSDNode>(ImmN)->getZExtValue(), Mask); IsUnary = IsFakeUnary = N->getOperand(0) == N->getOperand(1); break; case X86ISD::INSERTPS: assert(N->getOperand(0).getValueType() == VT && "Unexpected value type"); assert(N->getOperand(1).getValueType() == VT && "Unexpected value type"); - ImmN = N->getOperand(N->getNumOperands()-1); + ImmN = N->getOperand(N->getNumOperands() - 1); DecodeINSERTPSMask(cast<ConstantSDNode>(ImmN)->getZExtValue(), Mask); IsUnary = IsFakeUnary = N->getOperand(0) == N->getOperand(1); break; @@ -5773,8 +6118,7 @@ static bool getTargetShuffleMask(SDNode *N, MVT VT, bool AllowSentinelZero, isa<ConstantSDNode>(N->getOperand(2))) { int BitLen = N->getConstantOperandVal(1); int BitIdx = N->getConstantOperandVal(2); - DecodeEXTRQIMask(NumElems, VT.getScalarSizeInBits(), BitLen, BitIdx, - Mask); + DecodeEXTRQIMask(NumElems, MaskEltSize, BitLen, BitIdx, Mask); IsUnary = true; } break; @@ -5785,21 +6129,20 @@ static bool getTargetShuffleMask(SDNode *N, MVT VT, bool AllowSentinelZero, isa<ConstantSDNode>(N->getOperand(3))) { int BitLen = N->getConstantOperandVal(2); int BitIdx = 
N->getConstantOperandVal(3); - DecodeINSERTQIMask(NumElems, VT.getScalarSizeInBits(), BitLen, BitIdx, - Mask); + DecodeINSERTQIMask(NumElems, MaskEltSize, BitLen, BitIdx, Mask); IsUnary = IsFakeUnary = N->getOperand(0) == N->getOperand(1); } break; case X86ISD::UNPCKH: assert(N->getOperand(0).getValueType() == VT && "Unexpected value type"); assert(N->getOperand(1).getValueType() == VT && "Unexpected value type"); - DecodeUNPCKHMask(NumElems, VT.getScalarSizeInBits(), Mask); + DecodeUNPCKHMask(NumElems, MaskEltSize, Mask); IsUnary = IsFakeUnary = N->getOperand(0) == N->getOperand(1); break; case X86ISD::UNPCKL: assert(N->getOperand(0).getValueType() == VT && "Unexpected value type"); assert(N->getOperand(1).getValueType() == VT && "Unexpected value type"); - DecodeUNPCKLMask(NumElems, VT.getScalarSizeInBits(), Mask); + DecodeUNPCKLMask(NumElems, MaskEltSize, Mask); IsUnary = IsFakeUnary = N->getOperand(0) == N->getOperand(1); break; case X86ISD::MOVHLPS: @@ -5818,7 +6161,7 @@ static bool getTargetShuffleMask(SDNode *N, MVT VT, bool AllowSentinelZero, assert(VT.getScalarType() == MVT::i8 && "Byte vector expected"); assert(N->getOperand(0).getValueType() == VT && "Unexpected value type"); assert(N->getOperand(1).getValueType() == VT && "Unexpected value type"); - ImmN = N->getOperand(N->getNumOperands()-1); + ImmN = N->getOperand(N->getNumOperands() - 1); DecodePALIGNRMask(NumElems, cast<ConstantSDNode>(ImmN)->getZExtValue(), Mask); IsUnary = IsFakeUnary = N->getOperand(0) == N->getOperand(1); @@ -5844,21 +6187,21 @@ static bool getTargetShuffleMask(SDNode *N, MVT VT, bool AllowSentinelZero, case X86ISD::PSHUFD: case X86ISD::VPERMILPI: assert(N->getOperand(0).getValueType() == VT && "Unexpected value type"); - ImmN = N->getOperand(N->getNumOperands()-1); - DecodePSHUFMask(NumElems, VT.getScalarSizeInBits(), + ImmN = N->getOperand(N->getNumOperands() - 1); + DecodePSHUFMask(NumElems, MaskEltSize, cast<ConstantSDNode>(ImmN)->getZExtValue(), Mask); IsUnary = true; break; case X86ISD::PSHUFHW: assert(N->getOperand(0).getValueType() == VT && "Unexpected value type"); - ImmN = N->getOperand(N->getNumOperands()-1); + ImmN = N->getOperand(N->getNumOperands() - 1); DecodePSHUFHWMask(NumElems, cast<ConstantSDNode>(ImmN)->getZExtValue(), Mask); IsUnary = true; break; case X86ISD::PSHUFLW: assert(N->getOperand(0).getValueType() == VT && "Unexpected value type"); - ImmN = N->getOperand(N->getNumOperands()-1); + ImmN = N->getOperand(N->getNumOperands() - 1); DecodePSHUFLWMask(NumElems, cast<ConstantSDNode>(ImmN)->getZExtValue(), Mask); IsUnary = true; @@ -5891,14 +6234,9 @@ static bool getTargetShuffleMask(SDNode *N, MVT VT, bool AllowSentinelZero, assert(N->getOperand(0).getValueType() == VT && "Unexpected value type"); IsUnary = true; SDValue MaskNode = N->getOperand(1); - unsigned MaskEltSize = VT.getScalarSizeInBits(); - SmallVector<uint64_t, 32> RawMask; - if (getTargetShuffleMaskIndices(MaskNode, MaskEltSize, RawMask)) { - DecodeVPERMILPMask(NumElems, VT.getScalarSizeInBits(), RawMask, Mask); - break; - } - if (auto *C = getTargetConstantFromNode(MaskNode)) { - DecodeVPERMILPMask(C, MaskEltSize, Mask); + if (getTargetShuffleMaskIndices(MaskNode, MaskEltSize, RawMask, + RawUndefs)) { + DecodeVPERMILPMask(NumElems, MaskEltSize, RawMask, RawUndefs, Mask); break; } return false; @@ -5909,20 +6247,15 @@ static bool getTargetShuffleMask(SDNode *N, MVT VT, bool AllowSentinelZero, assert(N->getOperand(1).getValueType() == VT && "Unexpected value type"); IsUnary = true; SDValue MaskNode = N->getOperand(1); - 
SmallVector<uint64_t, 32> RawMask; - if (getTargetShuffleMaskIndices(MaskNode, 8, RawMask)) { - DecodePSHUFBMask(RawMask, Mask); - break; - } - if (auto *C = getTargetConstantFromNode(MaskNode)) { - DecodePSHUFBMask(C, Mask); + if (getTargetShuffleMaskIndices(MaskNode, 8, RawMask, RawUndefs)) { + DecodePSHUFBMask(RawMask, RawUndefs, Mask); break; } return false; } case X86ISD::VPERMI: assert(N->getOperand(0).getValueType() == VT && "Unexpected value type"); - ImmN = N->getOperand(N->getNumOperands()-1); + ImmN = N->getOperand(N->getNumOperands() - 1); DecodeVPERMMask(NumElems, cast<ConstantSDNode>(ImmN)->getZExtValue(), Mask); IsUnary = true; break; @@ -5935,7 +6268,7 @@ static bool getTargetShuffleMask(SDNode *N, MVT VT, bool AllowSentinelZero, case X86ISD::VPERM2X128: assert(N->getOperand(0).getValueType() == VT && "Unexpected value type"); assert(N->getOperand(1).getValueType() == VT && "Unexpected value type"); - ImmN = N->getOperand(N->getNumOperands()-1); + ImmN = N->getOperand(N->getNumOperands() - 1); DecodeVPERM2X128Mask(NumElems, cast<ConstantSDNode>(ImmN)->getZExtValue(), Mask); IsUnary = IsFakeUnary = N->getOperand(0) == N->getOperand(1); @@ -5943,10 +6276,9 @@ static bool getTargetShuffleMask(SDNode *N, MVT VT, bool AllowSentinelZero, case X86ISD::SHUF128: assert(N->getOperand(0).getValueType() == VT && "Unexpected value type"); assert(N->getOperand(1).getValueType() == VT && "Unexpected value type"); - ImmN = N->getOperand(N->getNumOperands()-1); - decodeVSHUF64x2FamilyMask(NumElems, VT.getScalarSizeInBits(), - cast<ConstantSDNode>(ImmN)->getZExtValue(), - Mask); + ImmN = N->getOperand(N->getNumOperands() - 1); + decodeVSHUF64x2FamilyMask(NumElems, MaskEltSize, + cast<ConstantSDNode>(ImmN)->getZExtValue(), Mask); IsUnary = IsFakeUnary = N->getOperand(0) == N->getOperand(1); break; case X86ISD::MOVSLDUP: @@ -5968,19 +6300,14 @@ static bool getTargetShuffleMask(SDNode *N, MVT VT, bool AllowSentinelZero, assert(N->getOperand(0).getValueType() == VT && "Unexpected value type"); assert(N->getOperand(1).getValueType() == VT && "Unexpected value type"); IsUnary = IsFakeUnary = N->getOperand(0) == N->getOperand(1); - unsigned MaskEltSize = VT.getScalarSizeInBits(); SDValue MaskNode = N->getOperand(2); SDValue CtrlNode = N->getOperand(3); if (ConstantSDNode *CtrlOp = dyn_cast<ConstantSDNode>(CtrlNode)) { unsigned CtrlImm = CtrlOp->getZExtValue(); - SmallVector<uint64_t, 32> RawMask; - if (getTargetShuffleMaskIndices(MaskNode, MaskEltSize, RawMask)) { - DecodeVPERMIL2PMask(NumElems, VT.getScalarSizeInBits(), CtrlImm, - RawMask, Mask); - break; - } - if (auto *C = getTargetConstantFromNode(MaskNode)) { - DecodeVPERMIL2PMask(C, CtrlImm, MaskEltSize, Mask); + if (getTargetShuffleMaskIndices(MaskNode, MaskEltSize, RawMask, + RawUndefs)) { + DecodeVPERMIL2PMask(NumElems, MaskEltSize, CtrlImm, RawMask, RawUndefs, + Mask); break; } } @@ -5991,13 +6318,8 @@ static bool getTargetShuffleMask(SDNode *N, MVT VT, bool AllowSentinelZero, assert(N->getOperand(1).getValueType() == VT && "Unexpected value type"); IsUnary = IsFakeUnary = N->getOperand(0) == N->getOperand(1); SDValue MaskNode = N->getOperand(2); - SmallVector<uint64_t, 32> RawMask; - if (getTargetShuffleMaskIndices(MaskNode, 8, RawMask)) { - DecodeVPPERMMask(RawMask, Mask); - break; - } - if (auto *C = getTargetConstantFromNode(MaskNode)) { - DecodeVPPERMMask(C, Mask); + if (getTargetShuffleMaskIndices(MaskNode, 8, RawMask, RawUndefs)) { + DecodeVPPERMMask(RawMask, RawUndefs, Mask); break; } return false; @@ -6008,14 +6330,9 @@ static 
bool getTargetShuffleMask(SDNode *N, MVT VT, bool AllowSentinelZero, // Unlike most shuffle nodes, VPERMV's mask operand is operand 0. Ops.push_back(N->getOperand(1)); SDValue MaskNode = N->getOperand(0); - SmallVector<uint64_t, 32> RawMask; - unsigned MaskEltSize = VT.getScalarSizeInBits(); - if (getTargetShuffleMaskIndices(MaskNode, MaskEltSize, RawMask)) { - DecodeVPERMVMask(RawMask, Mask); - break; - } - if (auto *C = getTargetConstantFromNode(MaskNode)) { - DecodeVPERMVMask(C, MaskEltSize, Mask); + if (getTargetShuffleMaskIndices(MaskNode, MaskEltSize, RawMask, + RawUndefs)) { + DecodeVPERMVMask(RawMask, RawUndefs, Mask); break; } return false; @@ -6028,9 +6345,9 @@ static bool getTargetShuffleMask(SDNode *N, MVT VT, bool AllowSentinelZero, Ops.push_back(N->getOperand(0)); Ops.push_back(N->getOperand(2)); SDValue MaskNode = N->getOperand(1); - unsigned MaskEltSize = VT.getScalarSizeInBits(); - if (auto *C = getTargetConstantFromNode(MaskNode)) { - DecodeVPERMV3Mask(C, MaskEltSize, Mask); + if (getTargetShuffleMaskIndices(MaskNode, MaskEltSize, RawMask, + RawUndefs)) { + DecodeVPERMV3Mask(RawMask, RawUndefs, Mask); break; } return false; @@ -6147,6 +6464,12 @@ static bool setTargetShuffleZeroElements(SDValue N, return true; } +// Forward declaration (for getFauxShuffleMask recursive check). +static bool resolveTargetShuffleInputs(SDValue Op, + SmallVectorImpl<SDValue> &Inputs, + SmallVectorImpl<int> &Mask, + const SelectionDAG &DAG); + // Attempt to decode ops that could be represented as a shuffle mask. // The decoded shuffle mask may contain a different number of elements to the // destination value type. @@ -6200,6 +6523,78 @@ static bool getFauxShuffleMask(SDValue N, SmallVectorImpl<int> &Mask, Ops.push_back(IsAndN ? N1 : N0); return true; } + case ISD::OR: { + // Handle OR(SHUFFLE,SHUFFLE) case where one source is zero and the other + // is a valid shuffle index. + SDValue N0 = peekThroughOneUseBitcasts(N.getOperand(0)); + SDValue N1 = peekThroughOneUseBitcasts(N.getOperand(1)); + if (!N0.getValueType().isVector() || !N1.getValueType().isVector()) + return false; + SmallVector<int, 64> SrcMask0, SrcMask1; + SmallVector<SDValue, 2> SrcInputs0, SrcInputs1; + if (!resolveTargetShuffleInputs(N0, SrcInputs0, SrcMask0, DAG) || + !resolveTargetShuffleInputs(N1, SrcInputs1, SrcMask1, DAG)) + return false; + int MaskSize = std::max(SrcMask0.size(), SrcMask1.size()); + SmallVector<int, 64> Mask0, Mask1; + scaleShuffleMask<int>(MaskSize / SrcMask0.size(), SrcMask0, Mask0); + scaleShuffleMask<int>(MaskSize / SrcMask1.size(), SrcMask1, Mask1); + for (int i = 0; i != MaskSize; ++i) { + if (Mask0[i] == SM_SentinelUndef && Mask1[i] == SM_SentinelUndef) + Mask.push_back(SM_SentinelUndef); + else if (Mask0[i] == SM_SentinelZero && Mask1[i] == SM_SentinelZero) + Mask.push_back(SM_SentinelZero); + else if (Mask1[i] == SM_SentinelZero) + Mask.push_back(Mask0[i]); + else if (Mask0[i] == SM_SentinelZero) + Mask.push_back(Mask1[i] + (MaskSize * SrcInputs0.size())); + else + return false; + } + for (SDValue &Op : SrcInputs0) + Ops.push_back(Op); + for (SDValue &Op : SrcInputs1) + Ops.push_back(Op); + return true; + } + case ISD::INSERT_SUBVECTOR: { + // Handle INSERT_SUBVECTOR(SRC0, SHUFFLE(EXTRACT_SUBVECTOR(SRC1)) where + // SRC0/SRC1 are both of the same valuetype VT. + // TODO - add peekThroughOneUseBitcasts support. 
+ SDValue Src = N.getOperand(0); + SDValue Sub = N.getOperand(1); + EVT SubVT = Sub.getValueType(); + unsigned NumSubElts = SubVT.getVectorNumElements(); + if (!isa<ConstantSDNode>(N.getOperand(2)) || + !N->isOnlyUserOf(Sub.getNode())) + return false; + SmallVector<int, 64> SubMask; + SmallVector<SDValue, 2> SubInputs; + if (!resolveTargetShuffleInputs(Sub, SubInputs, SubMask, DAG) || + SubMask.size() != NumSubElts) + return false; + Ops.push_back(Src); + for (SDValue &SubInput : SubInputs) { + if (SubInput.getOpcode() != ISD::EXTRACT_SUBVECTOR || + SubInput.getOperand(0).getValueType() != VT || + !isa<ConstantSDNode>(SubInput.getOperand(1))) + return false; + Ops.push_back(SubInput.getOperand(0)); + } + int InsertIdx = N.getConstantOperandVal(2); + for (int i = 0; i != (int)NumElts; ++i) + Mask.push_back(i); + for (int i = 0; i != (int)NumSubElts; ++i) { + int M = SubMask[i]; + if (0 <= M) { + int InputIdx = M / NumSubElts; + int ExtractIdx = SubInputs[InputIdx].getConstantOperandVal(1); + M = (NumElts * (1 + InputIdx)) + ExtractIdx + (M % NumSubElts); + } + Mask[i + InsertIdx] = M; + } + return true; + } case ISD::SCALAR_TO_VECTOR: { // Match against a scalar_to_vector of an extract from a vector, // for PEXTRW/PEXTRB we must handle the implicit zext of the scalar. @@ -6334,14 +6729,14 @@ static bool getFauxShuffleMask(SDValue N, SmallVectorImpl<int> &Mask, return true; } case ISD::ZERO_EXTEND_VECTOR_INREG: - case X86ISD::VZEXT: { + case ISD::ZERO_EXTEND: { // TODO - add support for VPMOVZX with smaller input vector types. SDValue Src = N.getOperand(0); MVT SrcVT = Src.getSimpleValueType(); if (NumSizeInBits != SrcVT.getSizeInBits()) break; - DecodeZeroExtendMask(SrcVT.getScalarSizeInBits(), VT.getScalarSizeInBits(), - VT.getVectorNumElements(), Mask); + DecodeZeroExtendMask(SrcVT.getScalarSizeInBits(), NumBitsPerElt, NumElts, + Mask); Ops.push_back(Src); return true; } @@ -6586,6 +6981,26 @@ static SDValue LowerBuildVectorv8i16(SDValue Op, unsigned NonZeros, /// Custom lower build_vector of v4i32 or v4f32. static SDValue LowerBuildVectorv4x32(SDValue Op, SelectionDAG &DAG, const X86Subtarget &Subtarget) { + // If this is a splat of a pair of elements, use MOVDDUP (unless the target + // has XOP; in that case defer lowering to potentially use VPERMIL2PS). + // Because we're creating a less complicated build vector here, we may enable + // further folding of the MOVDDUP via shuffle transforms. + if (Subtarget.hasSSE3() && !Subtarget.hasXOP() && + Op.getOperand(0) == Op.getOperand(2) && + Op.getOperand(1) == Op.getOperand(3) && + Op.getOperand(0) != Op.getOperand(1)) { + SDLoc DL(Op); + MVT VT = Op.getSimpleValueType(); + MVT EltVT = VT.getVectorElementType(); + // Create a new build vector with the first 2 elements followed by undef + // padding, bitcast to v2f64, duplicate, and bitcast back. + SDValue Ops[4] = { Op.getOperand(0), Op.getOperand(1), + DAG.getUNDEF(EltVT), DAG.getUNDEF(EltVT) }; + SDValue NewBV = DAG.getBitcast(MVT::v2f64, DAG.getBuildVector(VT, DL, Ops)); + SDValue Dup = DAG.getNode(X86ISD::MOVDDUP, DL, MVT::v2f64, NewBV); + return DAG.getBitcast(VT, Dup); + } + // Find all zeroable elements. std::bitset<4> Zeroable; for (int i=0; i < 4; ++i) { @@ -7059,9 +7474,9 @@ static SDValue lowerBuildVectorAsBroadcast(BuildVectorSDNode *BVOp, } } - // We need a splat of a single value to use broadcast, and it doesn't - // make any sense if the value is only in one element of the vector. 
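Annotation (not part of the patch): the INSERT_SUBVECTOR case above builds its combined shuffle mask over the concatenated operand list [Src, SubInput0, SubInput1, ...], where operand k owns index range [k*NumElts, (k+1)*NumElts). The standalone sketch below reproduces only that index arithmetic; the helper name, parameters, and test values are illustrative assumptions, and it presumes each sub-input came from its parent via EXTRACT_SUBVECTOR at a known offset.

// Illustrative sketch of the INSERT_SUBVECTOR -> shuffle-mask arithmetic.
#include <cstdio>
#include <vector>

// NumElts     - elements in the destination (and each full-width source).
// InsertIdx   - element offset at which the subvector is inserted.
// SubMask     - shuffle mask of the subvector, indexing the concatenation of
//               the sub-inputs (NumSubElts elements each).
// ExtractIdxs - EXTRACT_SUBVECTOR offset of each sub-input in its parent.
static std::vector<int>
buildInsertSubvectorMask(int NumElts, int NumSubElts, int InsertIdx,
                         const std::vector<int> &SubMask,
                         const std::vector<int> &ExtractIdxs) {
  // Start with the identity mask selecting from operand 0 (the base vector).
  std::vector<int> Mask(NumElts);
  for (int i = 0; i != NumElts; ++i)
    Mask[i] = i;

  // Overwrite the inserted window: sub-input k is operand (1 + k), so its
  // elements live at [(1 + k) * NumElts, (2 + k) * NumElts) in the combined
  // mask space, offset by where the subvector was extracted from its parent.
  for (int i = 0; i != NumSubElts; ++i) {
    int M = SubMask[i];
    if (M >= 0) {
      int InputIdx = M / NumSubElts;
      M = NumElts * (1 + InputIdx) + ExtractIdxs[InputIdx] + (M % NumSubElts);
    }
    Mask[i + InsertIdx] = M;
  }
  return Mask;
}

int main() {
  // v8i32 destination: insert a reversed v4i32, extracted from its parent at
  // offset 4, into the high half of the destination.
  std::vector<int> Mask = buildInsertSubvectorMask(8, 4, 4, {3, 2, 1, 0}, {4});
  for (int M : Mask)
    std::printf("%d ", M); // prints: 0 1 2 3 15 14 13 12
  std::printf("\n");
  return 0;
}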
- if (!Ld || (VT.getVectorNumElements() - UndefElements.count()) <= 1) { + unsigned NumElts = VT.getVectorNumElements(); + unsigned NumUndefElts = UndefElements.count(); + if (!Ld || (NumElts - NumUndefElts) <= 1) { APInt SplatValue, Undef; unsigned SplatBitSize; bool HasUndef; @@ -7137,7 +7552,17 @@ static SDValue lowerBuildVectorAsBroadcast(BuildVectorSDNode *BVOp, } } } - return SDValue(); + + // If we are moving a scalar into a vector (Ld must be set and all elements + // but 1 are undef) and that operation is not obviously supported by + // vmovd/vmovq/vmovss/vmovsd, then keep trying to form a broadcast. + // That's better than general shuffling and may eliminate a load to GPR and + // move from scalar to vector register. + if (!Ld || NumElts - NumUndefElts != 1) + return SDValue(); + unsigned ScalarSize = Ld.getValueSizeInBits(); + if (!(UndefElements[0] || (ScalarSize != 32 && ScalarSize != 64))) + return SDValue(); } bool ConstSplatVal = @@ -7434,13 +7859,14 @@ static SDValue LowerBUILD_VECTORvXi1(SDValue Op, SelectionDAG &DAG, return DstVec; } -/// Return true if \p N implements a horizontal binop and return the -/// operands for the horizontal binop into V0 and V1. -/// /// This is a helper function of LowerToHorizontalOp(). /// This function checks that the build_vector \p N in input implements a -/// horizontal operation. Parameter \p Opcode defines the kind of horizontal -/// operation to match. +/// 128-bit partial horizontal operation on a 256-bit vector, but that operation +/// may not match the layout of an x86 256-bit horizontal instruction. +/// In other words, if this returns true, then some extraction/insertion will +/// be required to produce a valid horizontal instruction. +/// +/// Parameter \p Opcode defines the kind of horizontal operation to match. /// For example, if \p Opcode is equal to ISD::ADD, then this function /// checks if \p N implements a horizontal arithmetic add; if instead \p Opcode /// is equal to ISD::SUB, then this function checks if this is a horizontal @@ -7448,12 +7874,17 @@ static SDValue LowerBUILD_VECTORvXi1(SDValue Op, SelectionDAG &DAG, /// /// This function only analyzes elements of \p N whose indices are /// in range [BaseIdx, LastIdx). -static bool isHorizontalBinOp(const BuildVectorSDNode *N, unsigned Opcode, - SelectionDAG &DAG, - unsigned BaseIdx, unsigned LastIdx, - SDValue &V0, SDValue &V1) { +/// +/// TODO: This function was originally used to match both real and fake partial +/// horizontal operations, but the index-matching logic is incorrect for that. +/// See the corrected implementation in isHopBuildVector(). Can we reduce this +/// code because it is only used for partial h-op matching now? +static bool isHorizontalBinOpPart(const BuildVectorSDNode *N, unsigned Opcode, + SelectionDAG &DAG, + unsigned BaseIdx, unsigned LastIdx, + SDValue &V0, SDValue &V1) { EVT VT = N->getValueType(0); - + assert(VT.is256BitVector() && "Only use for matching partial 256-bit h-ops"); assert(BaseIdx * 2 <= LastIdx && "Invalid Indices in input!"); assert(VT.isVector() && VT.getVectorNumElements() >= LastIdx && "Invalid Vector in input!"); @@ -7623,7 +8054,7 @@ static bool isAddSubOrSubAdd(const BuildVectorSDNode *BV, // adding/subtracting two integer/float elements. // Even-numbered elements in the input build vector are obtained from // subtracting/adding two integer/float elements. 
- unsigned Opc[2] {0, 0}; + unsigned Opc[2] = {0, 0}; for (unsigned i = 0, e = NumElts; i != e; ++i) { SDValue Op = BV->getOperand(i); @@ -7794,17 +8225,158 @@ static SDValue lowerToAddSubOrFMAddSub(const BuildVectorSDNode *BV, return DAG.getNode(X86ISD::ADDSUB, DL, VT, Opnd0, Opnd1); } +static bool isHopBuildVector(const BuildVectorSDNode *BV, SelectionDAG &DAG, + unsigned &HOpcode, SDValue &V0, SDValue &V1) { + // Initialize outputs to known values. + MVT VT = BV->getSimpleValueType(0); + HOpcode = ISD::DELETED_NODE; + V0 = DAG.getUNDEF(VT); + V1 = DAG.getUNDEF(VT); + + // x86 256-bit horizontal ops are defined in a non-obvious way. Each 128-bit + // half of the result is calculated independently from the 128-bit halves of + // the inputs, so that makes the index-checking logic below more complicated. + unsigned NumElts = VT.getVectorNumElements(); + unsigned GenericOpcode = ISD::DELETED_NODE; + unsigned Num128BitChunks = VT.is256BitVector() ? 2 : 1; + unsigned NumEltsIn128Bits = NumElts / Num128BitChunks; + unsigned NumEltsIn64Bits = NumEltsIn128Bits / 2; + for (unsigned i = 0; i != Num128BitChunks; ++i) { + for (unsigned j = 0; j != NumEltsIn128Bits; ++j) { + // Ignore undef elements. + SDValue Op = BV->getOperand(i * NumEltsIn128Bits + j); + if (Op.isUndef()) + continue; + + // If there's an opcode mismatch, we're done. + if (HOpcode != ISD::DELETED_NODE && Op.getOpcode() != GenericOpcode) + return false; + + // Initialize horizontal opcode. + if (HOpcode == ISD::DELETED_NODE) { + GenericOpcode = Op.getOpcode(); + switch (GenericOpcode) { + case ISD::ADD: HOpcode = X86ISD::HADD; break; + case ISD::SUB: HOpcode = X86ISD::HSUB; break; + case ISD::FADD: HOpcode = X86ISD::FHADD; break; + case ISD::FSUB: HOpcode = X86ISD::FHSUB; break; + default: return false; + } + } + + SDValue Op0 = Op.getOperand(0); + SDValue Op1 = Op.getOperand(1); + if (Op0.getOpcode() != ISD::EXTRACT_VECTOR_ELT || + Op1.getOpcode() != ISD::EXTRACT_VECTOR_ELT || + Op0.getOperand(0) != Op1.getOperand(0) || + !isa<ConstantSDNode>(Op0.getOperand(1)) || + !isa<ConstantSDNode>(Op1.getOperand(1)) || !Op.hasOneUse()) + return false; + + // The source vector is chosen based on which 64-bit half of the + // destination vector is being calculated. + if (j < NumEltsIn64Bits) { + if (V0.isUndef()) + V0 = Op0.getOperand(0); + } else { + if (V1.isUndef()) + V1 = Op0.getOperand(0); + } + + SDValue SourceVec = (j < NumEltsIn64Bits) ? V0 : V1; + if (SourceVec != Op0.getOperand(0)) + return false; + + // op (extract_vector_elt A, I), (extract_vector_elt A, I+1) + unsigned ExtIndex0 = Op0.getConstantOperandVal(1); + unsigned ExtIndex1 = Op1.getConstantOperandVal(1); + unsigned ExpectedIndex = i * NumEltsIn128Bits + + (j % NumEltsIn64Bits) * 2; + if (ExpectedIndex == ExtIndex0 && ExtIndex1 == ExtIndex0 + 1) + continue; + + // If this is not a commutative op, this does not match. + if (GenericOpcode != ISD::ADD && GenericOpcode != ISD::FADD) + return false; + + // Addition is commutative, so try swapping the extract indexes. + // op (extract_vector_elt A, I+1), (extract_vector_elt A, I) + if (ExpectedIndex == ExtIndex1 && ExtIndex0 == ExtIndex1 + 1) + continue; + + // Extract indexes do not match horizontal requirement. + return false; + } + } + // We matched. Opcode and operands are returned by reference as arguments. 
+ return true; +} + +static SDValue getHopForBuildVector(const BuildVectorSDNode *BV, + SelectionDAG &DAG, unsigned HOpcode, + SDValue V0, SDValue V1) { + // If either input vector is not the same size as the build vector, + // extract/insert the low bits to the correct size. + // This is free (examples: zmm --> xmm, xmm --> ymm). + MVT VT = BV->getSimpleValueType(0); + unsigned Width = VT.getSizeInBits(); + if (V0.getValueSizeInBits() > Width) + V0 = extractSubVector(V0, 0, DAG, SDLoc(BV), Width); + else if (V0.getValueSizeInBits() < Width) + V0 = insertSubVector(DAG.getUNDEF(VT), V0, 0, DAG, SDLoc(BV), Width); + + if (V1.getValueSizeInBits() > Width) + V1 = extractSubVector(V1, 0, DAG, SDLoc(BV), Width); + else if (V1.getValueSizeInBits() < Width) + V1 = insertSubVector(DAG.getUNDEF(VT), V1, 0, DAG, SDLoc(BV), Width); + + return DAG.getNode(HOpcode, SDLoc(BV), VT, V0, V1); +} + /// Lower BUILD_VECTOR to a horizontal add/sub operation if possible. static SDValue LowerToHorizontalOp(const BuildVectorSDNode *BV, const X86Subtarget &Subtarget, SelectionDAG &DAG) { + // We need at least 2 non-undef elements to make this worthwhile by default. + unsigned NumNonUndefs = 0; + for (const SDValue &V : BV->op_values()) + if (!V.isUndef()) + ++NumNonUndefs; + + if (NumNonUndefs < 2) + return SDValue(); + + // There are 4 sets of horizontal math operations distinguished by type: + // int/FP at 128-bit/256-bit. Each type was introduced with a different + // subtarget feature. Try to match those "native" patterns first. MVT VT = BV->getSimpleValueType(0); + unsigned HOpcode; + SDValue V0, V1; + if ((VT == MVT::v4f32 || VT == MVT::v2f64) && Subtarget.hasSSE3()) + if (isHopBuildVector(BV, DAG, HOpcode, V0, V1)) + return getHopForBuildVector(BV, DAG, HOpcode, V0, V1); + + if ((VT == MVT::v8i16 || VT == MVT::v4i32) && Subtarget.hasSSSE3()) + if (isHopBuildVector(BV, DAG, HOpcode, V0, V1)) + return getHopForBuildVector(BV, DAG, HOpcode, V0, V1); + + if ((VT == MVT::v8f32 || VT == MVT::v4f64) && Subtarget.hasAVX()) + if (isHopBuildVector(BV, DAG, HOpcode, V0, V1)) + return getHopForBuildVector(BV, DAG, HOpcode, V0, V1); + + if ((VT == MVT::v16i16 || VT == MVT::v8i32) && Subtarget.hasAVX2()) + if (isHopBuildVector(BV, DAG, HOpcode, V0, V1)) + return getHopForBuildVector(BV, DAG, HOpcode, V0, V1); + + // Try harder to match 256-bit ops by using extract/concat. + if (!Subtarget.hasAVX() || !VT.is256BitVector()) + return SDValue(); + + // Count the number of UNDEF operands in the build_vector in input. unsigned NumElts = VT.getVectorNumElements(); + unsigned Half = NumElts / 2; unsigned NumUndefsLO = 0; unsigned NumUndefsHI = 0; - unsigned Half = NumElts/2; - - // Count the number of UNDEF operands in the build_vector in input. for (unsigned i = 0, e = Half; i != e; ++i) if (BV->getOperand(i)->isUndef()) NumUndefsLO++; @@ -7813,96 +8385,61 @@ static SDValue LowerToHorizontalOp(const BuildVectorSDNode *BV, if (BV->getOperand(i)->isUndef()) NumUndefsHI++; - // Early exit if this is either a build_vector of all UNDEFs or all the - // operands but one are UNDEF. - if (NumUndefsLO + NumUndefsHI + 1 >= NumElts) - return SDValue(); - SDLoc DL(BV); SDValue InVec0, InVec1; - if ((VT == MVT::v4f32 || VT == MVT::v2f64) && Subtarget.hasSSE3()) { - // Try to match an SSE3 float HADD/HSUB. 
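Annotation (not part of the patch): the index pattern checked by isHopBuildVector above can be summarised as follows. For a 256-bit horizontal op, output element j of 128-bit chunk i must be built from extract indices i*NumEltsIn128Bits + 2*(j % NumEltsIn64Bits) and that value plus one, with the first source feeding the low 64 bits of each chunk and the second source the high 64 bits. The sketch below only prints that expected layout; the constants model v8f32 and are chosen for illustration.

// Illustrative printout of the extract-index pattern matched by isHopBuildVector.
#include <cstdio>

int main() {
  // Model a 256-bit v8f32 horizontal add: 2 x 128-bit chunks, 4 elements per
  // chunk, 2 elements per 64-bit half.
  const int NumEltsIn128Bits = 4;
  const int NumEltsIn64Bits = 2;
  const int Num128BitChunks = 2;

  for (int i = 0; i != Num128BitChunks; ++i) {
    for (int j = 0; j != NumEltsIn128Bits; ++j) {
      // Low 64 bits of each chunk come from the first source (A), high 64
      // bits from the second source (B).
      const char *Src = (j < NumEltsIn64Bits) ? "A" : "B";
      int Expected = i * NumEltsIn128Bits + (j % NumEltsIn64Bits) * 2;
      std::printf("result[%d] = %s[%d] + %s[%d]\n",
                  i * NumEltsIn128Bits + j, Src, Expected, Src, Expected + 1);
    }
  }
  // Output matches the VHADDPS ymm layout:
  //   result[0] = A[0]+A[1]   result[1] = A[2]+A[3]
  //   result[2] = B[0]+B[1]   result[3] = B[2]+B[3]
  //   result[4] = A[4]+A[5]   result[5] = A[6]+A[7]
  //   result[6] = B[4]+B[5]   result[7] = B[6]+B[7]
  return 0;
}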
- if (isHorizontalBinOp(BV, ISD::FADD, DAG, 0, NumElts, InVec0, InVec1)) - return DAG.getNode(X86ISD::FHADD, DL, VT, InVec0, InVec1); - - if (isHorizontalBinOp(BV, ISD::FSUB, DAG, 0, NumElts, InVec0, InVec1)) - return DAG.getNode(X86ISD::FHSUB, DL, VT, InVec0, InVec1); - } else if ((VT == MVT::v4i32 || VT == MVT::v8i16) && Subtarget.hasSSSE3()) { - // Try to match an SSSE3 integer HADD/HSUB. - if (isHorizontalBinOp(BV, ISD::ADD, DAG, 0, NumElts, InVec0, InVec1)) - return DAG.getNode(X86ISD::HADD, DL, VT, InVec0, InVec1); - - if (isHorizontalBinOp(BV, ISD::SUB, DAG, 0, NumElts, InVec0, InVec1)) - return DAG.getNode(X86ISD::HSUB, DL, VT, InVec0, InVec1); - } - - if (!Subtarget.hasAVX()) - return SDValue(); - - if ((VT == MVT::v8f32 || VT == MVT::v4f64)) { - // Try to match an AVX horizontal add/sub of packed single/double - // precision floating point values from 256-bit vectors. - SDValue InVec2, InVec3; - if (isHorizontalBinOp(BV, ISD::FADD, DAG, 0, Half, InVec0, InVec1) && - isHorizontalBinOp(BV, ISD::FADD, DAG, Half, NumElts, InVec2, InVec3) && - ((InVec0.isUndef() || InVec2.isUndef()) || InVec0 == InVec2) && - ((InVec1.isUndef() || InVec3.isUndef()) || InVec1 == InVec3)) - return DAG.getNode(X86ISD::FHADD, DL, VT, InVec0, InVec1); - - if (isHorizontalBinOp(BV, ISD::FSUB, DAG, 0, Half, InVec0, InVec1) && - isHorizontalBinOp(BV, ISD::FSUB, DAG, Half, NumElts, InVec2, InVec3) && - ((InVec0.isUndef() || InVec2.isUndef()) || InVec0 == InVec2) && - ((InVec1.isUndef() || InVec3.isUndef()) || InVec1 == InVec3)) - return DAG.getNode(X86ISD::FHSUB, DL, VT, InVec0, InVec1); - } else if (VT == MVT::v8i32 || VT == MVT::v16i16) { - // Try to match an AVX2 horizontal add/sub of signed integers. + if (VT == MVT::v8i32 || VT == MVT::v16i16) { SDValue InVec2, InVec3; unsigned X86Opcode; bool CanFold = true; - if (isHorizontalBinOp(BV, ISD::ADD, DAG, 0, Half, InVec0, InVec1) && - isHorizontalBinOp(BV, ISD::ADD, DAG, Half, NumElts, InVec2, InVec3) && + if (isHorizontalBinOpPart(BV, ISD::ADD, DAG, 0, Half, InVec0, InVec1) && + isHorizontalBinOpPart(BV, ISD::ADD, DAG, Half, NumElts, InVec2, + InVec3) && ((InVec0.isUndef() || InVec2.isUndef()) || InVec0 == InVec2) && ((InVec1.isUndef() || InVec3.isUndef()) || InVec1 == InVec3)) X86Opcode = X86ISD::HADD; - else if (isHorizontalBinOp(BV, ISD::SUB, DAG, 0, Half, InVec0, InVec1) && - isHorizontalBinOp(BV, ISD::SUB, DAG, Half, NumElts, InVec2, InVec3) && - ((InVec0.isUndef() || InVec2.isUndef()) || InVec0 == InVec2) && - ((InVec1.isUndef() || InVec3.isUndef()) || InVec1 == InVec3)) + else if (isHorizontalBinOpPart(BV, ISD::SUB, DAG, 0, Half, InVec0, + InVec1) && + isHorizontalBinOpPart(BV, ISD::SUB, DAG, Half, NumElts, InVec2, + InVec3) && + ((InVec0.isUndef() || InVec2.isUndef()) || InVec0 == InVec2) && + ((InVec1.isUndef() || InVec3.isUndef()) || InVec1 == InVec3)) X86Opcode = X86ISD::HSUB; else CanFold = false; if (CanFold) { - // Fold this build_vector into a single horizontal add/sub. - // Do this only if the target has AVX2. - if (Subtarget.hasAVX2()) - return DAG.getNode(X86Opcode, DL, VT, InVec0, InVec1); - // Do not try to expand this build_vector into a pair of horizontal // add/sub if we can emit a pair of scalar add/sub. if (NumUndefsLO + 1 == Half || NumUndefsHI + 1 == Half) return SDValue(); - // Convert this build_vector into a pair of horizontal binop followed by - // a concat vector. + // Convert this build_vector into a pair of horizontal binops followed by + // a concat vector. 
We must adjust the outputs from the partial horizontal + // matching calls above to account for undefined vector halves. + SDValue V0 = InVec0.isUndef() ? InVec2 : InVec0; + SDValue V1 = InVec1.isUndef() ? InVec3 : InVec1; + assert((!V0.isUndef() || !V1.isUndef()) && "Horizontal-op of undefs?"); bool isUndefLO = NumUndefsLO == Half; bool isUndefHI = NumUndefsHI == Half; - return ExpandHorizontalBinOp(InVec0, InVec1, DL, DAG, X86Opcode, false, - isUndefLO, isUndefHI); + return ExpandHorizontalBinOp(V0, V1, DL, DAG, X86Opcode, false, isUndefLO, + isUndefHI); } } - if ((VT == MVT::v8f32 || VT == MVT::v4f64 || VT == MVT::v8i32 || - VT == MVT::v16i16) && Subtarget.hasAVX()) { + if (VT == MVT::v8f32 || VT == MVT::v4f64 || VT == MVT::v8i32 || + VT == MVT::v16i16) { unsigned X86Opcode; - if (isHorizontalBinOp(BV, ISD::ADD, DAG, 0, NumElts, InVec0, InVec1)) + if (isHorizontalBinOpPart(BV, ISD::ADD, DAG, 0, NumElts, InVec0, InVec1)) X86Opcode = X86ISD::HADD; - else if (isHorizontalBinOp(BV, ISD::SUB, DAG, 0, NumElts, InVec0, InVec1)) + else if (isHorizontalBinOpPart(BV, ISD::SUB, DAG, 0, NumElts, InVec0, + InVec1)) X86Opcode = X86ISD::HSUB; - else if (isHorizontalBinOp(BV, ISD::FADD, DAG, 0, NumElts, InVec0, InVec1)) + else if (isHorizontalBinOpPart(BV, ISD::FADD, DAG, 0, NumElts, InVec0, + InVec1)) X86Opcode = X86ISD::FHADD; - else if (isHorizontalBinOp(BV, ISD::FSUB, DAG, 0, NumElts, InVec0, InVec1)) + else if (isHorizontalBinOpPart(BV, ISD::FSUB, DAG, 0, NumElts, InVec0, + InVec1)) X86Opcode = X86ISD::FHSUB; else return SDValue(); @@ -8370,9 +8907,9 @@ X86TargetLowering::LowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG) const { // If we are inserting one variable into a vector of non-zero constants, try // to avoid loading each constant element as a scalar. Load the constants as a // vector and then insert the variable scalar element. If insertion is not - // supported, we assume that we will fall back to a shuffle to get the scalar - // blended with the constants. Insertion into a zero vector is handled as a - // special-case somewhere below here. + // supported, fall back to a shuffle to get the scalar blended with the + // constants. Insertion into a zero vector is handled as a special-case + // somewhere below here. if (NumConstants == NumElems - 1 && NumNonZero != 1 && (isOperationLegalOrCustom(ISD::INSERT_VECTOR_ELT, VT) || isOperationLegalOrCustom(ISD::VECTOR_SHUFFLE, VT))) { @@ -8410,7 +8947,21 @@ X86TargetLowering::LowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG) const { MachineFunction &MF = DAG.getMachineFunction(); MachinePointerInfo MPI = MachinePointerInfo::getConstantPool(MF); SDValue Ld = DAG.getLoad(VT, dl, DAG.getEntryNode(), LegalDAGConstVec, MPI); - return DAG.getNode(ISD::INSERT_VECTOR_ELT, dl, VT, Ld, VarElt, InsIndex); + unsigned InsertC = cast<ConstantSDNode>(InsIndex)->getZExtValue(); + unsigned NumEltsInLow128Bits = 128 / VT.getScalarSizeInBits(); + if (InsertC < NumEltsInLow128Bits) + return DAG.getNode(ISD::INSERT_VECTOR_ELT, dl, VT, Ld, VarElt, InsIndex); + + // There's no good way to insert into the high elements of a >128-bit + // vector, so use shuffles to avoid an extract/insert sequence. + assert(VT.getSizeInBits() > 128 && "Invalid insertion index?"); + assert(Subtarget.hasAVX() && "Must have AVX with >16-byte vector"); + SmallVector<int, 8> ShuffleMask; + unsigned NumElts = VT.getVectorNumElements(); + for (unsigned i = 0; i != NumElts; ++i) + ShuffleMask.push_back(i == InsertC ? 
NumElts : i); + SDValue S2V = DAG.getNode(ISD::SCALAR_TO_VECTOR, dl, VT, VarElt); + return DAG.getVectorShuffle(VT, dl, Ld, S2V, ShuffleMask); } // Special case for single non-zero, non-undef, element. @@ -9097,6 +9648,28 @@ static SmallVector<int, 64> createTargetShuffleMask(ArrayRef<int> Mask, return TargetMask; } +// Attempt to create a shuffle mask from a VSELECT condition mask. +static bool createShuffleMaskFromVSELECT(SmallVectorImpl<int> &Mask, + SDValue Cond) { + if (!ISD::isBuildVectorOfConstantSDNodes(Cond.getNode())) + return false; + + unsigned Size = Cond.getValueType().getVectorNumElements(); + Mask.resize(Size, SM_SentinelUndef); + + for (int i = 0; i != (int)Size; ++i) { + SDValue CondElt = Cond.getOperand(i); + Mask[i] = i; + // Arbitrarily choose from the 2nd operand if the select condition element + // is undef. + // TODO: Can we do better by matching patterns such as even/odd? + if (CondElt.isUndef() || isNullConstant(CondElt)) + Mask[i] += Size; + } + + return true; +} + // Check if the shuffle mask is suitable for the AVX vpunpcklwd or vpunpckhwd // instructions. static bool isUnpackWdShuffleMask(ArrayRef<int> Mask, MVT VT) { @@ -9664,11 +10237,7 @@ static SDValue lowerVectorShuffleAsBitBlend(const SDLoc &DL, MVT VT, SDValue V1, SDValue V1Mask = DAG.getBuildVector(VT, DL, MaskOps); V1 = DAG.getNode(ISD::AND, DL, VT, V1, V1Mask); - // We have to cast V2 around. - MVT MaskVT = MVT::getVectorVT(MVT::i64, VT.getSizeInBits() / 64); - V2 = DAG.getBitcast(VT, DAG.getNode(X86ISD::ANDNP, DL, MaskVT, - DAG.getBitcast(MaskVT, V1Mask), - DAG.getBitcast(MaskVT, V2))); + V2 = DAG.getNode(X86ISD::ANDNP, DL, VT, V1Mask, V2); return DAG.getNode(ISD::OR, DL, VT, V1, V2); } @@ -9762,7 +10331,6 @@ static SDValue lowerVectorShuffleAsBlend(const SDLoc &DL, MVT VT, SDValue V1, case MVT::v8f32: return DAG.getNode(X86ISD::BLENDI, DL, VT, V1, V2, DAG.getConstant(BlendMask, DL, MVT::i8)); - case MVT::v4i64: case MVT::v8i32: assert(Subtarget.hasAVX2() && "256-bit integer blends require AVX2!"); @@ -9794,7 +10362,6 @@ static SDValue lowerVectorShuffleAsBlend(const SDLoc &DL, MVT VT, SDValue V1, DAG.getNode(X86ISD::BLENDI, DL, MVT::v8i16, V1, V2, DAG.getConstant(BlendMask, DL, MVT::i8))); } - case MVT::v16i16: { assert(Subtarget.hasAVX2() && "256-bit integer blends require AVX2!"); SmallVector<int, 8> RepeatedMask; @@ -9808,6 +10375,20 @@ static SDValue lowerVectorShuffleAsBlend(const SDLoc &DL, MVT VT, SDValue V1, return DAG.getNode(X86ISD::BLENDI, DL, MVT::v16i16, V1, V2, DAG.getConstant(BlendMask, DL, MVT::i8)); } + // Use PBLENDW for lower/upper lanes and then blend lanes. + // TODO - we should allow 2 PBLENDW here and leave shuffle combine to + // merge to VSELECT where useful. + uint64_t LoMask = BlendMask & 0xFF; + uint64_t HiMask = (BlendMask >> 8) & 0xFF; + if (LoMask == 0 || LoMask == 255 || HiMask == 0 || HiMask == 255) { + SDValue Lo = DAG.getNode(X86ISD::BLENDI, DL, MVT::v16i16, V1, V2, + DAG.getConstant(LoMask, DL, MVT::i8)); + SDValue Hi = DAG.getNode(X86ISD::BLENDI, DL, MVT::v16i16, V1, V2, + DAG.getConstant(HiMask, DL, MVT::i8)); + return DAG.getVectorShuffle( + MVT::v16i16, DL, Lo, Hi, + {0, 1, 2, 3, 4, 5, 6, 7, 24, 25, 26, 27, 28, 29, 30, 31}); + } LLVM_FALLTHROUGH; } case MVT::v16i8: @@ -9815,6 +10396,11 @@ static SDValue lowerVectorShuffleAsBlend(const SDLoc &DL, MVT VT, SDValue V1, assert((VT.is128BitVector() || Subtarget.hasAVX2()) && "256-bit byte-blends require AVX2 support!"); + // Attempt to lower to a bitmask if we can. VPAND is faster than VPBLENDVB. 
+ if (SDValue Masked = + lowerVectorShuffleAsBitMask(DL, VT, V1, V2, Mask, Zeroable, DAG)) + return Masked; + if (Subtarget.hasBWI() && Subtarget.hasVLX()) { MVT IntegerType = MVT::getIntegerVT(std::max((int)VT.getVectorNumElements(), 8)); @@ -9822,11 +10408,6 @@ static SDValue lowerVectorShuffleAsBlend(const SDLoc &DL, MVT VT, SDValue V1, return getVectorMaskingNode(V2, MaskNode, V1, Subtarget, DAG); } - // Attempt to lower to a bitmask if we can. VPAND is faster than VPBLENDVB. - if (SDValue Masked = - lowerVectorShuffleAsBitMask(DL, VT, V1, V2, Mask, Zeroable, DAG)) - return Masked; - // Scale the blend by the number of bytes per element. int Scale = VT.getScalarSizeInBits() / 8; @@ -9834,6 +10415,15 @@ static SDValue lowerVectorShuffleAsBlend(const SDLoc &DL, MVT VT, SDValue V1, // type. MVT BlendVT = MVT::getVectorVT(MVT::i8, VT.getSizeInBits() / 8); + // x86 allows load folding with blendvb from the 2nd source operand. But + // we are still using LLVM select here (see comment below), so that's V1. + // If V2 can be load-folded and V1 cannot be load-folded, then commute to + // allow that load-folding possibility. + if (!ISD::isNormalLoad(V1.getNode()) && ISD::isNormalLoad(V2.getNode())) { + ShuffleVectorSDNode::commuteMask(Mask); + std::swap(V1, V2); + } + // Compute the VSELECT mask. Note that VSELECT is really confusing in the // mix of LLVM's code generator and the x86 backend. We tell the code // generator that boolean values in the elements of an x86 vector register @@ -9884,7 +10474,8 @@ static SDValue lowerVectorShuffleAsBlend(const SDLoc &DL, MVT VT, SDValue V1, static SDValue lowerVectorShuffleAsBlendAndPermute(const SDLoc &DL, MVT VT, SDValue V1, SDValue V2, ArrayRef<int> Mask, - SelectionDAG &DAG) { + SelectionDAG &DAG, + bool ImmBlends = false) { // We build up the blend mask while checking whether a blend is a viable way // to reduce the shuffle. SmallVector<int, 32> BlendMask(Mask.size(), -1); @@ -9904,10 +10495,168 @@ static SDValue lowerVectorShuffleAsBlendAndPermute(const SDLoc &DL, MVT VT, PermuteMask[i] = Mask[i] % Size; } + // If only immediate blends, then bail if the blend mask can't be widened to + // i16. + unsigned EltSize = VT.getScalarSizeInBits(); + if (ImmBlends && EltSize == 8 && !canWidenShuffleElements(BlendMask)) + return SDValue(); + SDValue V = DAG.getVectorShuffle(VT, DL, V1, V2, BlendMask); return DAG.getVectorShuffle(VT, DL, V, DAG.getUNDEF(VT), PermuteMask); } +/// Try to lower as an unpack of elements from two inputs followed by +/// a single-input permutation. +/// +/// This matches the pattern where we can unpack elements from two inputs and +/// then reduce the shuffle to a single-input (wider) permutation. +static SDValue lowerVectorShuffleAsUNPCKAndPermute(const SDLoc &DL, MVT VT, + SDValue V1, SDValue V2, + ArrayRef<int> Mask, + SelectionDAG &DAG) { + int NumElts = Mask.size(); + int NumLanes = VT.getSizeInBits() / 128; + int NumLaneElts = NumElts / NumLanes; + int NumHalfLaneElts = NumLaneElts / 2; + + bool MatchLo = true, MatchHi = true; + SDValue Ops[2] = {DAG.getUNDEF(VT), DAG.getUNDEF(VT)}; + + // Determine UNPCKL/UNPCKH type and operand order. 
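Annotation (not part of the patch): the MatchLo/MatchHi range checks in the new unpack-and-permute lowering encode what UNPCKL/UNPCKH actually produce: within every 128-bit lane, UNPCKL interleaves the low half-lane of the two inputs and UNPCKH the high half-lane, so a candidate mask element must fall into the corresponding half-lane range of either operand. The generator below is a standalone illustration with assumed names and a v16i16 example.

// Illustrative generator for per-lane UNPCKL/UNPCKH shuffle masks.
#include <cstdio>
#include <vector>

static std::vector<int> makeUnpackMask(int NumElts, int NumLaneElts, bool Lo) {
  std::vector<int> Mask;
  int NumHalfLaneElts = NumLaneElts / 2;
  for (int Lane = 0; Lane != NumElts; Lane += NumLaneElts) {
    int Base = Lane + (Lo ? 0 : NumHalfLaneElts);
    for (int k = 0; k != NumHalfLaneElts; ++k) {
      Mask.push_back(Base + k);           // element from the first input
      Mask.push_back(Base + k + NumElts); // element from the second input
    }
  }
  return Mask;
}

int main() {
  // v16i16: two 128-bit lanes of 8 x i16 each.
  for (bool Lo : {true, false}) {
    std::printf(Lo ? "UNPCKL:" : "UNPCKH:");
    for (int M : makeUnpackMask(16, 8, Lo))
      std::printf(" %d", M);
    std::printf("\n");
  }
  // UNPCKL: 0 16 1 17 2 18 3 19 8 24 9 25 10 26 11 27
  // UNPCKH: 4 20 5 21 6 22 7 23 12 28 13 29 14 30 15 31
  return 0;
}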
+ for (int Lane = 0; Lane != NumElts; Lane += NumLaneElts) { + for (int Elt = 0; Elt != NumLaneElts; ++Elt) { + int M = Mask[Lane + Elt]; + if (M < 0) + continue; + + SDValue &Op = Ops[Elt & 1]; + if (M < NumElts && (Op.isUndef() || Op == V1)) + Op = V1; + else if (NumElts <= M && (Op.isUndef() || Op == V2)) + Op = V2; + else + return SDValue(); + + int Lo = Lane, Mid = Lane + NumHalfLaneElts, Hi = Lane + NumLaneElts; + MatchLo &= isUndefOrInRange(M, Lo, Mid) || + isUndefOrInRange(M, NumElts + Lo, NumElts + Mid); + MatchHi &= isUndefOrInRange(M, Mid, Hi) || + isUndefOrInRange(M, NumElts + Mid, NumElts + Hi); + if (!MatchLo && !MatchHi) + return SDValue(); + } + } + assert((MatchLo ^ MatchHi) && "Failed to match UNPCKLO/UNPCKHI"); + + // Now check that each pair of elts come from the same unpack pair + // and set the permute mask based on each pair. + // TODO - Investigate cases where we permute individual elements. + SmallVector<int, 32> PermuteMask(NumElts, -1); + for (int Lane = 0; Lane != NumElts; Lane += NumLaneElts) { + for (int Elt = 0; Elt != NumLaneElts; Elt += 2) { + int M0 = Mask[Lane + Elt + 0]; + int M1 = Mask[Lane + Elt + 1]; + if (0 <= M0 && 0 <= M1 && + (M0 % NumHalfLaneElts) != (M1 % NumHalfLaneElts)) + return SDValue(); + if (0 <= M0) + PermuteMask[Lane + Elt + 0] = Lane + (2 * (M0 % NumHalfLaneElts)); + if (0 <= M1) + PermuteMask[Lane + Elt + 1] = Lane + (2 * (M1 % NumHalfLaneElts)) + 1; + } + } + + unsigned UnpckOp = MatchLo ? X86ISD::UNPCKL : X86ISD::UNPCKH; + SDValue Unpck = DAG.getNode(UnpckOp, DL, VT, Ops); + return DAG.getVectorShuffle(VT, DL, Unpck, DAG.getUNDEF(VT), PermuteMask); +} + +/// Helper to form a PALIGNR-based rotate+permute, merging 2 inputs and then +/// permuting the elements of the result in place. +static SDValue lowerVectorShuffleAsByteRotateAndPermute( + const SDLoc &DL, MVT VT, SDValue V1, SDValue V2, ArrayRef<int> Mask, + const X86Subtarget &Subtarget, SelectionDAG &DAG) { + if ((VT.is128BitVector() && !Subtarget.hasSSSE3()) || + (VT.is256BitVector() && !Subtarget.hasAVX2()) || + (VT.is512BitVector() && !Subtarget.hasBWI())) + return SDValue(); + + // We don't currently support lane crossing permutes. + if (is128BitLaneCrossingShuffleMask(VT, Mask)) + return SDValue(); + + int Scale = VT.getScalarSizeInBits() / 8; + int NumLanes = VT.getSizeInBits() / 128; + int NumElts = VT.getVectorNumElements(); + int NumEltsPerLane = NumElts / NumLanes; + + // Determine range of mask elts. + bool Blend1 = true; + bool Blend2 = true; + std::pair<int, int> Range1 = std::make_pair(INT_MAX, INT_MIN); + std::pair<int, int> Range2 = std::make_pair(INT_MAX, INT_MIN); + for (int Lane = 0; Lane != NumElts; Lane += NumEltsPerLane) { + for (int Elt = 0; Elt != NumEltsPerLane; ++Elt) { + int M = Mask[Lane + Elt]; + if (M < 0) + continue; + if (M < NumElts) { + Blend1 &= (M == (Lane + Elt)); + assert(Lane <= M && M < (Lane + NumEltsPerLane) && "Out of range mask"); + M = M % NumEltsPerLane; + Range1.first = std::min(Range1.first, M); + Range1.second = std::max(Range1.second, M); + } else { + M -= NumElts; + Blend2 &= (M == (Lane + Elt)); + assert(Lane <= M && M < (Lane + NumEltsPerLane) && "Out of range mask"); + M = M % NumEltsPerLane; + Range2.first = std::min(Range2.first, M); + Range2.second = std::max(Range2.second, M); + } + } + } + + // Bail if we don't need both elements. + // TODO - it might be worth doing this for unary shuffles if the permute + // can be widened. 
+ if (!(0 <= Range1.first && Range1.second < NumEltsPerLane) || + !(0 <= Range2.first && Range2.second < NumEltsPerLane)) + return SDValue(); + + if (VT.getSizeInBits() > 128 && (Blend1 || Blend2)) + return SDValue(); + + // Rotate the 2 ops so we can access both ranges, then permute the result. + auto RotateAndPermute = [&](SDValue Lo, SDValue Hi, int RotAmt, int Ofs) { + MVT ByteVT = MVT::getVectorVT(MVT::i8, VT.getSizeInBits() / 8); + SDValue Rotate = DAG.getBitcast( + VT, DAG.getNode(X86ISD::PALIGNR, DL, ByteVT, DAG.getBitcast(ByteVT, Hi), + DAG.getBitcast(ByteVT, Lo), + DAG.getConstant(Scale * RotAmt, DL, MVT::i8))); + SmallVector<int, 64> PermMask(NumElts, SM_SentinelUndef); + for (int Lane = 0; Lane != NumElts; Lane += NumEltsPerLane) { + for (int Elt = 0; Elt != NumEltsPerLane; ++Elt) { + int M = Mask[Lane + Elt]; + if (M < 0) + continue; + if (M < NumElts) + PermMask[Lane + Elt] = Lane + ((M + Ofs - RotAmt) % NumEltsPerLane); + else + PermMask[Lane + Elt] = Lane + ((M - Ofs - RotAmt) % NumEltsPerLane); + } + } + return DAG.getVectorShuffle(VT, DL, Rotate, DAG.getUNDEF(VT), PermMask); + }; + + // Check if the ranges are small enough to rotate from either direction. + if (Range2.second < Range1.first) + return RotateAndPermute(V1, V2, Range1.first, 0); + if (Range1.second < Range2.first) + return RotateAndPermute(V2, V1, Range2.first, NumElts); + return SDValue(); +} + /// Generic routine to decompose a shuffle and blend into independent /// blends and permutes. /// @@ -9915,11 +10664,9 @@ static SDValue lowerVectorShuffleAsBlendAndPermute(const SDLoc &DL, MVT VT, /// shuffle+blend operations on newer X86 ISAs where we have very fast blend /// operations. It will try to pick the best arrangement of shuffles and /// blends. -static SDValue lowerVectorShuffleAsDecomposedShuffleBlend(const SDLoc &DL, - MVT VT, SDValue V1, - SDValue V2, - ArrayRef<int> Mask, - SelectionDAG &DAG) { +static SDValue lowerVectorShuffleAsDecomposedShuffleBlend( + const SDLoc &DL, MVT VT, SDValue V1, SDValue V2, ArrayRef<int> Mask, + const X86Subtarget &Subtarget, SelectionDAG &DAG) { // Shuffle the input elements into the desired positions in V1 and V2 and // blend them together. SmallVector<int, 32> V1Mask(Mask.size(), -1); @@ -9934,15 +10681,27 @@ static SDValue lowerVectorShuffleAsDecomposedShuffleBlend(const SDLoc &DL, BlendMask[i] = i + Size; } - // Try to lower with the simpler initial blend strategy unless one of the - // input shuffles would be a no-op. We prefer to shuffle inputs as the - // shuffle may be able to fold with a load or other benefit. However, when - // we'll have to do 2x as many shuffles in order to achieve this, blending - // first is a better strategy. - if (!isNoopShuffleMask(V1Mask) && !isNoopShuffleMask(V2Mask)) + // Try to lower with the simpler initial blend/unpack/rotate strategies unless + // one of the input shuffles would be a no-op. We prefer to shuffle inputs as + // the shuffle may be able to fold with a load or other benefit. However, when + // we'll have to do 2x as many shuffles in order to achieve this, a 2-input + // pre-shuffle first is a better strategy. + if (!isNoopShuffleMask(V1Mask) && !isNoopShuffleMask(V2Mask)) { + // Only prefer immediate blends to unpack/rotate. 
+ if (SDValue BlendPerm = lowerVectorShuffleAsBlendAndPermute( + DL, VT, V1, V2, Mask, DAG, true)) + return BlendPerm; + if (SDValue UnpackPerm = + lowerVectorShuffleAsUNPCKAndPermute(DL, VT, V1, V2, Mask, DAG)) + return UnpackPerm; + if (SDValue RotatePerm = lowerVectorShuffleAsByteRotateAndPermute( + DL, VT, V1, V2, Mask, Subtarget, DAG)) + return RotatePerm; + // Unpack/rotate failed - try again with variable blends. if (SDValue BlendPerm = lowerVectorShuffleAsBlendAndPermute(DL, VT, V1, V2, Mask, DAG)) return BlendPerm; + } V1 = DAG.getVectorShuffle(VT, DL, V1, DAG.getUNDEF(VT), V1Mask); V2 = DAG.getVectorShuffle(VT, DL, V2, DAG.getUNDEF(VT), V2Mask); @@ -10452,7 +11211,7 @@ static SDValue lowerVectorShuffleAsSpecificZeroOrAnyExtend( MVT ExtVT = MVT::getVectorVT(MVT::getIntegerVT(EltBits * Scale), NumElements / Scale); InputV = ShuffleOffset(InputV); - InputV = getExtendInVec(X86ISD::VZEXT, DL, ExtVT, InputV, DAG); + InputV = getExtendInVec(/*Signed*/false, DL, ExtVT, InputV, DAG); return DAG.getBitcast(VT, InputV); } @@ -10930,7 +11689,8 @@ static SDValue lowerVectorShuffleAsBroadcast(const SDLoc &DL, MVT VT, continue; } case ISD::CONCAT_VECTORS: { - int OperandSize = Mask.size() / V.getNumOperands(); + int OperandSize = + V.getOperand(0).getSimpleValueType().getVectorNumElements(); V = V.getOperand(BroadcastIdx / OperandSize); BroadcastIdx %= OperandSize; continue; @@ -10989,7 +11749,7 @@ static SDValue lowerVectorShuffleAsBroadcast(const SDLoc &DL, MVT VT, SDValue BC = peekThroughBitcasts(V); // Also check the simpler case, where we can directly reuse the scalar. - if (V.getOpcode() == ISD::BUILD_VECTOR || + if ((V.getOpcode() == ISD::BUILD_VECTOR && V.hasOneUse()) || (V.getOpcode() == ISD::SCALAR_TO_VECTOR && BroadcastIdx == 0)) { V = V.getOperand(BroadcastIdx); @@ -11204,10 +11964,9 @@ static SDValue lowerVectorShuffleAsInsertPS(const SDLoc &DL, SDValue V1, /// because for floating point vectors we have a generalized SHUFPS lowering /// strategy that handles everything that doesn't *exactly* match an unpack, /// making this clever lowering unnecessary. -static SDValue lowerVectorShuffleAsPermuteAndUnpack(const SDLoc &DL, MVT VT, - SDValue V1, SDValue V2, - ArrayRef<int> Mask, - SelectionDAG &DAG) { +static SDValue lowerVectorShuffleAsPermuteAndUnpack( + const SDLoc &DL, MVT VT, SDValue V1, SDValue V2, ArrayRef<int> Mask, + const X86Subtarget &Subtarget, SelectionDAG &DAG) { assert(!VT.isFloatingPoint() && "This routine only supports integer vectors."); assert(VT.is128BitVector() && @@ -11276,6 +12035,12 @@ static SDValue lowerVectorShuffleAsPermuteAndUnpack(const SDLoc &DL, MVT VT, if (SDValue Unpack = TryUnpack(ScalarSize, ScalarSize / OrigScalarSize)) return Unpack; + // If we're shuffling with a zero vector then we're better off not doing + // VECTOR_SHUFFLE(UNPCK()) as we lose track of those zero elements. + if (ISD::isBuildVectorAllZeros(V1.getNode()) || + ISD::isBuildVectorAllZeros(V2.getNode())) + return SDValue(); + // If none of the unpack-rooted lowerings worked (or were profitable) try an // initial unpack. if (NumLoInputs == 0 || NumHiInputs == 0) { @@ -11475,7 +12240,7 @@ static SDValue lowerV2I64VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, // a permute. That will be faster than the domain cross. 
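Annotation (not part of the patch): the decomposition that lowerVectorShuffleAsDecomposedShuffleBlend performs a few hunks above splits one two-input shuffle into two single-input shuffles plus an elementwise blend. The sketch below reproduces only the V1Mask/V2Mask/BlendMask construction with an assumed v4i32 example; names and values are illustrative.

// Illustrative sketch of the shuffle + shuffle + blend decomposition.
#include <cstdio>
#include <vector>

int main() {
  const int Size = 4;                     // e.g. v4i32
  std::vector<int> Mask = {6, 1, 7, 2};   // mixes both inputs per element
  std::vector<int> V1Mask(Size, -1), V2Mask(Size, -1), BlendMask(Size, -1);

  for (int i = 0; i != Size; ++i) {
    if (Mask[i] >= 0 && Mask[i] < Size) {
      V1Mask[i] = Mask[i];                // move the V1 element into place
      BlendMask[i] = i;                   // then take lane i from shuffled V1
    } else if (Mask[i] >= Size) {
      V2Mask[i] = Mask[i] - Size;         // move the V2 element into place
      BlendMask[i] = i + Size;            // then take lane i from shuffled V2
    }
  }

  auto Dump = [](const char *Name, const std::vector<int> &M) {
    std::printf("%s:", Name);
    for (int V : M)
      std::printf(" %d", V);
    std::printf("\n");
  };
  Dump("V1Mask   ", V1Mask);    // -1 1 -1 2
  Dump("V2Mask   ", V2Mask);    //  2 -1 3 -1
  Dump("BlendMask", BlendMask); //  4 1 6 3
  return 0;
}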
if (IsBlendSupported) return lowerVectorShuffleAsDecomposedShuffleBlend(DL, MVT::v2i64, V1, V2, - Mask, DAG); + Mask, Subtarget, DAG); // We implement this with SHUFPD which is pretty lame because it will likely // incur 2 cycles of stall for integer vectors on Nehalem and older chips. @@ -11785,11 +12550,11 @@ static SDValue lowerV4I32VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, // a permute. That will be faster than the domain cross. if (IsBlendSupported) return lowerVectorShuffleAsDecomposedShuffleBlend(DL, MVT::v4i32, V1, V2, - Mask, DAG); + Mask, Subtarget, DAG); // Try to lower by permuting the inputs into an unpack instruction. if (SDValue Unpack = lowerVectorShuffleAsPermuteAndUnpack( - DL, MVT::v4i32, V1, V2, Mask, DAG)) + DL, MVT::v4i32, V1, V2, Mask, Subtarget, DAG)) return Unpack; } @@ -12321,47 +13086,48 @@ static SDValue lowerV8I16GeneralSingleInputVectorShuffle( /// blend if only one input is used. static SDValue lowerVectorShuffleAsBlendOfPSHUFBs( const SDLoc &DL, MVT VT, SDValue V1, SDValue V2, ArrayRef<int> Mask, - const APInt &Zeroable, SelectionDAG &DAG, bool &V1InUse, - bool &V2InUse) { - SDValue V1Mask[16]; - SDValue V2Mask[16]; + const APInt &Zeroable, SelectionDAG &DAG, bool &V1InUse, bool &V2InUse) { + assert(!is128BitLaneCrossingShuffleMask(VT, Mask) && + "Lane crossing shuffle masks not supported"); + + int NumBytes = VT.getSizeInBits() / 8; + int Size = Mask.size(); + int Scale = NumBytes / Size; + + SmallVector<SDValue, 64> V1Mask(NumBytes, DAG.getUNDEF(MVT::i8)); + SmallVector<SDValue, 64> V2Mask(NumBytes, DAG.getUNDEF(MVT::i8)); V1InUse = false; V2InUse = false; - int Size = Mask.size(); - int Scale = 16 / Size; - for (int i = 0; i < 16; ++i) { - if (Mask[i / Scale] < 0) { - V1Mask[i] = V2Mask[i] = DAG.getUNDEF(MVT::i8); - } else { - const int ZeroMask = 0x80; - int V1Idx = Mask[i / Scale] < Size ? Mask[i / Scale] * Scale + i % Scale - : ZeroMask; - int V2Idx = Mask[i / Scale] < Size - ? ZeroMask - : (Mask[i / Scale] - Size) * Scale + i % Scale; - if (Zeroable[i / Scale]) - V1Idx = V2Idx = ZeroMask; - V1Mask[i] = DAG.getConstant(V1Idx, DL, MVT::i8); - V2Mask[i] = DAG.getConstant(V2Idx, DL, MVT::i8); - V1InUse |= (ZeroMask != V1Idx); - V2InUse |= (ZeroMask != V2Idx); - } + for (int i = 0; i < NumBytes; ++i) { + int M = Mask[i / Scale]; + if (M < 0) + continue; + + const int ZeroMask = 0x80; + int V1Idx = M < Size ? M * Scale + i % Scale : ZeroMask; + int V2Idx = M < Size ? ZeroMask : (M - Size) * Scale + i % Scale; + if (Zeroable[i / Scale]) + V1Idx = V2Idx = ZeroMask; + + V1Mask[i] = DAG.getConstant(V1Idx, DL, MVT::i8); + V2Mask[i] = DAG.getConstant(V2Idx, DL, MVT::i8); + V1InUse |= (ZeroMask != V1Idx); + V2InUse |= (ZeroMask != V2Idx); } + MVT ShufVT = MVT::getVectorVT(MVT::i8, NumBytes); if (V1InUse) - V1 = DAG.getNode(X86ISD::PSHUFB, DL, MVT::v16i8, - DAG.getBitcast(MVT::v16i8, V1), - DAG.getBuildVector(MVT::v16i8, DL, V1Mask)); + V1 = DAG.getNode(X86ISD::PSHUFB, DL, ShufVT, DAG.getBitcast(ShufVT, V1), + DAG.getBuildVector(ShufVT, DL, V1Mask)); if (V2InUse) - V2 = DAG.getNode(X86ISD::PSHUFB, DL, MVT::v16i8, - DAG.getBitcast(MVT::v16i8, V2), - DAG.getBuildVector(MVT::v16i8, DL, V2Mask)); + V2 = DAG.getNode(X86ISD::PSHUFB, DL, ShufVT, DAG.getBitcast(ShufVT, V2), + DAG.getBuildVector(ShufVT, DL, V2Mask)); // If we need shuffled inputs from both, blend the two. SDValue V; if (V1InUse && V2InUse) - V = DAG.getNode(ISD::OR, DL, MVT::v16i8, V1, V2); + V = DAG.getNode(ISD::OR, DL, ShufVT, V1, V2); else V = V1InUse ? 
V1 : V2; @@ -12484,8 +13250,8 @@ static SDValue lowerV8I16VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, return BitBlend; // Try to lower by permuting the inputs into an unpack instruction. - if (SDValue Unpack = lowerVectorShuffleAsPermuteAndUnpack(DL, MVT::v8i16, V1, - V2, Mask, DAG)) + if (SDValue Unpack = lowerVectorShuffleAsPermuteAndUnpack( + DL, MVT::v8i16, V1, V2, Mask, Subtarget, DAG)) return Unpack; // If we can't directly blend but can use PSHUFB, that will be better as it @@ -12499,7 +13265,7 @@ static SDValue lowerV8I16VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, // We can always bit-blend if we have to so the fallback strategy is to // decompose into single-input permutes and blends. return lowerVectorShuffleAsDecomposedShuffleBlend(DL, MVT::v8i16, V1, V2, - Mask, DAG); + Mask, Subtarget, DAG); } /// Check whether a compaction lowering can be done by dropping even @@ -12632,6 +13398,10 @@ static SDValue lowerV16I8VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, DL, MVT::v16i8, V1, V2, Mask, Subtarget, DAG)) return Broadcast; + if (SDValue V = + lowerVectorShuffleWithUNPCK(DL, MVT::v16i8, Mask, V1, V2, DAG)) + return V; + // Check whether we can widen this to an i16 shuffle by duplicating bytes. // Notably, this handles splat and partial-splat shuffles more efficiently. // However, it only makes sense if the pre-duplication shuffle simplifies @@ -12769,12 +13539,18 @@ static SDValue lowerV16I8VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, // shuffles will both be pshufb, in which case we shouldn't bother with // this. if (SDValue Unpack = lowerVectorShuffleAsPermuteAndUnpack( - DL, MVT::v16i8, V1, V2, Mask, DAG)) + DL, MVT::v16i8, V1, V2, Mask, Subtarget, DAG)) return Unpack; // If we have VBMI we can use one VPERM instead of multiple PSHUFBs. if (Subtarget.hasVBMI() && Subtarget.hasVLX()) return lowerVectorShuffleWithPERMV(DL, MVT::v16i8, Mask, V1, V2, DAG); + + // Use PALIGNR+Permute if possible - permute might become PSHUFB but the + // PALIGNR will be cheaper than the second PSHUFB+OR. + if (SDValue V = lowerVectorShuffleAsByteRotateAndPermute( + DL, MVT::v16i8, V1, V2, Mask, Subtarget, DAG)) + return V; } return PSHUFB; @@ -12830,7 +13606,7 @@ static SDValue lowerV16I8VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, // Handle multi-input cases by blending single-input shuffles. 
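Annotation (not part of the patch): the rewritten lowerVectorShuffleAsBlendOfPSHUFBs above builds one PSHUFB control vector per input, where each output byte either pulls byte M*Scale + i%Scale from its input or is zeroed via the 0x80 sentinel, and the two shuffled inputs are then ORed together. The sketch below models just that control-byte construction for an assumed v8i16 mask; Zeroable handling is omitted and all names and values are illustrative.

// Illustrative model of the two PSHUFB control vectors for a blended shuffle.
#include <cstdio>
#include <vector>

int main() {
  const int ZeroMask = 0x80; // PSHUFB writes 0 when bit 7 of the control byte is set.
  const int NumBytes = 16;   // 128-bit vector.
  // v8i16 shuffle mask: 0..7 select from the first input, 8..15 from the
  // second, -1 is undef.
  std::vector<int> Mask = {0, 8, 1, 9, 2, 10, -1, 11};
  const int Size = (int)Mask.size();
  const int Scale = NumBytes / Size; // bytes per mask element (2 for i16).

  std::vector<int> V1Mask(NumBytes, -1), V2Mask(NumBytes, -1);
  for (int i = 0; i != NumBytes; ++i) {
    int M = Mask[i / Scale];
    if (M < 0)
      continue; // leave undef bytes alone
    V1Mask[i] = M < Size ? M * Scale + i % Scale : ZeroMask;
    V2Mask[i] = M < Size ? ZeroMask : (M - Size) * Scale + i % Scale;
  }

  std::printf("V1 control:");
  for (int B : V1Mask)
    std::printf(" %3d", B);
  std::printf("\nV2 control:");
  for (int B : V2Mask)
    std::printf(" %3d", B);
  std::printf("\n");
  // Bytes zeroed in one control are produced by the other, so OR-ing the two
  // PSHUFB results reassembles the blended shuffle.
  return 0;
}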
if (NumV2Elements > 0) return lowerVectorShuffleAsDecomposedShuffleBlend(DL, MVT::v16i8, V1, V2, - Mask, DAG); + Mask, Subtarget, DAG); // The fallback path for single-input shuffles widens this into two v8i16 // vectors with unpacks, shuffles those, and then pulls them back together @@ -13043,6 +13819,7 @@ static SDValue splitAndLowerVectorShuffle(const SDLoc &DL, MVT VT, SDValue V1, static SDValue lowerVectorShuffleAsSplitOrBlend(const SDLoc &DL, MVT VT, SDValue V1, SDValue V2, ArrayRef<int> Mask, + const X86Subtarget &Subtarget, SelectionDAG &DAG) { assert(!V2.isUndef() && "This routine must not be used to lower single-input " "shuffles as it could then recurse on itself."); @@ -13069,7 +13846,7 @@ static SDValue lowerVectorShuffleAsSplitOrBlend(const SDLoc &DL, MVT VT, }; if (DoBothBroadcast()) return lowerVectorShuffleAsDecomposedShuffleBlend(DL, VT, V1, V2, Mask, - DAG); + Subtarget, DAG); // If the inputs all stem from a single 128-bit lane of each input, then we // split them rather than blending because the split will decompose to @@ -13087,7 +13864,62 @@ static SDValue lowerVectorShuffleAsSplitOrBlend(const SDLoc &DL, MVT VT, // Otherwise, just fall back to decomposed shuffles and a blend. This requires // that the decomposed single-input shuffles don't end up here. - return lowerVectorShuffleAsDecomposedShuffleBlend(DL, VT, V1, V2, Mask, DAG); + return lowerVectorShuffleAsDecomposedShuffleBlend(DL, VT, V1, V2, Mask, + Subtarget, DAG); +} + +/// Lower a vector shuffle crossing multiple 128-bit lanes as +/// a lane permutation followed by a per-lane permutation. +/// +/// This is mainly for cases where we can have non-repeating permutes +/// in each lane. +/// +/// TODO: This is very similar to lowerVectorShuffleByMerging128BitLanes, +/// we should investigate merging them. +static SDValue lowerVectorShuffleAsLanePermuteAndPermute( + const SDLoc &DL, MVT VT, SDValue V1, SDValue V2, ArrayRef<int> Mask, + SelectionDAG &DAG, const X86Subtarget &Subtarget) { + int NumElts = VT.getVectorNumElements(); + int NumLanes = VT.getSizeInBits() / 128; + int NumEltsPerLane = NumElts / NumLanes; + + SmallVector<int, 4> SrcLaneMask(NumLanes, SM_SentinelUndef); + SmallVector<int, 16> LaneMask(NumElts, SM_SentinelUndef); + SmallVector<int, 16> PermMask(NumElts, SM_SentinelUndef); + + for (int i = 0; i != NumElts; ++i) { + int M = Mask[i]; + if (M < 0) + continue; + + // Ensure that each lane comes from a single source lane. + int SrcLane = M / NumEltsPerLane; + int DstLane = i / NumEltsPerLane; + if (!isUndefOrEqual(SrcLaneMask[DstLane], SrcLane)) + return SDValue(); + SrcLaneMask[DstLane] = SrcLane; + + LaneMask[i] = (SrcLane * NumEltsPerLane) + (i % NumEltsPerLane); + PermMask[i] = (DstLane * NumEltsPerLane) + (M % NumEltsPerLane); + } + + // If we're only shuffling a single lowest lane and the rest are identity + // then don't bother. + // TODO - isShuffleMaskInputInPlace could be extended to something like this. 
+ int NumIdentityLanes = 0; + bool OnlyShuffleLowestLane = true; + for (int i = 0; i != NumLanes; ++i) { + if (isSequentialOrUndefInRange(PermMask, i * NumEltsPerLane, NumEltsPerLane, + i * NumEltsPerLane)) + NumIdentityLanes++; + else if (SrcLaneMask[i] != 0 && SrcLaneMask[i] != NumLanes) + OnlyShuffleLowestLane = false; + } + if (OnlyShuffleLowestLane && NumIdentityLanes == (NumLanes - 1)) + return SDValue(); + + SDValue LanePermute = DAG.getVectorShuffle(VT, DL, V1, V2, LaneMask); + return DAG.getVectorShuffle(VT, DL, LanePermute, DAG.getUNDEF(VT), PermMask); } /// Lower a vector shuffle crossing multiple 128-bit lanes as @@ -13248,79 +14080,174 @@ static SDValue lowerV2X128VectorShuffle(const SDLoc &DL, MVT VT, SDValue V1, /// Lower a vector shuffle by first fixing the 128-bit lanes and then /// shuffling each lane. /// -/// This will only succeed when the result of fixing the 128-bit lanes results -/// in a single-input non-lane-crossing shuffle with a repeating shuffle mask in -/// each 128-bit lanes. This handles many cases where we can quickly blend away -/// the lane crosses early and then use simpler shuffles within each lane. +/// This attempts to create a repeated lane shuffle where each lane uses one +/// or two of the lanes of the inputs. The lanes of the input vectors are +/// shuffled in one or two independent shuffles to get the lanes into the +/// position needed by the final shuffle. /// -/// FIXME: It might be worthwhile at some point to support this without -/// requiring the 128-bit lane-relative shuffles to be repeating, but currently -/// in x86 only floating point has interesting non-repeating shuffles, and even -/// those are still *marginally* more expensive. +/// FIXME: This should be generalized to 512-bit shuffles. static SDValue lowerVectorShuffleByMerging128BitLanes( const SDLoc &DL, MVT VT, SDValue V1, SDValue V2, ArrayRef<int> Mask, const X86Subtarget &Subtarget, SelectionDAG &DAG) { assert(!V2.isUndef() && "This is only useful with multiple inputs."); + if (is128BitLaneRepeatedShuffleMask(VT, Mask)) + return SDValue(); + int Size = Mask.size(); int LaneSize = 128 / VT.getScalarSizeInBits(); int NumLanes = Size / LaneSize; - assert(NumLanes > 1 && "Only handles 256-bit and wider shuffles."); + assert(NumLanes == 2 && "Only handles 256-bit shuffles."); + + SmallVector<int, 16> RepeatMask(LaneSize, -1); + int LaneSrcs[2][2] = { { -1, -1 }, { -1 , -1 } }; + + // First pass will try to fill in the RepeatMask from lanes that need two + // sources. + for (int Lane = 0; Lane != NumLanes; ++Lane) { + int Srcs[2] = { -1, -1 }; + SmallVector<int, 16> InLaneMask(LaneSize, -1); + for (int i = 0; i != LaneSize; ++i) { + int M = Mask[(Lane * LaneSize) + i]; + if (M < 0) + continue; + // Determine which of the 4 possible input lanes (2 from each source) + // this element comes from. Assign that as one of the sources for this + // lane. We can assign up to 2 sources for this lane. If we run out + // sources we can't do anything. + int LaneSrc = M / LaneSize; + int Src; + if (Srcs[0] < 0 || Srcs[0] == LaneSrc) + Src = 0; + else if (Srcs[1] < 0 || Srcs[1] == LaneSrc) + Src = 1; + else + return SDValue(); - // See if we can build a hypothetical 128-bit lane-fixing shuffle mask. Also - // check whether the in-128-bit lane shuffles share a repeating pattern. 
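Annotation (not part of the patch): the lane-permute-plus-permute decomposition introduced above can be sanity-checked by composing its two masks, since applying LaneMask and then PermMask must reproduce the original mask whenever each destination lane draws from a single source lane. The check below is a standalone sketch under that assumption, with an illustrative v8f32-style mask.

// Illustrative check that LaneMask followed by PermMask equals the mask.
#include <cassert>
#include <cstdio>
#include <vector>

int main() {
  const int NumElts = 8, NumLanes = 2;          // e.g. a 256-bit v8f32.
  const int NumEltsPerLane = NumElts / NumLanes;

  // A lane-crossing single-source mask where every destination lane pulls all
  // of its elements from exactly one source lane.
  std::vector<int> Mask = {5, 7, 4, 6, 1, 0, 3, 2};

  std::vector<int> LaneMask(NumElts), PermMask(NumElts);
  for (int i = 0; i != NumElts; ++i) {
    int M = Mask[i];
    int SrcLane = M / NumEltsPerLane;
    int DstLane = i / NumEltsPerLane;
    LaneMask[i] = SrcLane * NumEltsPerLane + (i % NumEltsPerLane);
    PermMask[i] = DstLane * NumEltsPerLane + (M % NumEltsPerLane);
  }

  // Compose the two shuffles: element i of the result is LaneMask[PermMask[i]].
  for (int i = 0; i != NumElts; ++i)
    assert(LaneMask[PermMask[i]] == Mask[i]);
  std::printf("lane permute + in-lane permute reproduces the original mask\n");
  return 0;
}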
- SmallVector<int, 4> Lanes((unsigned)NumLanes, -1); - SmallVector<int, 4> InLaneMask((unsigned)LaneSize, -1); - for (int i = 0; i < Size; ++i) { - if (Mask[i] < 0) + Srcs[Src] = LaneSrc; + InLaneMask[i] = (M % LaneSize) + Src * Size; + } + + // If this lane has two sources, see if it fits with the repeat mask so far. + if (Srcs[1] < 0) continue; - int j = i / LaneSize; + LaneSrcs[Lane][0] = Srcs[0]; + LaneSrcs[Lane][1] = Srcs[1]; - if (Lanes[j] < 0) { - // First entry we've seen for this lane. - Lanes[j] = Mask[i] / LaneSize; - } else if (Lanes[j] != Mask[i] / LaneSize) { - // This doesn't match the lane selected previously! - return SDValue(); + auto MatchMasks = [](ArrayRef<int> M1, ArrayRef<int> M2) { + assert(M1.size() == M2.size() && "Unexpected mask size"); + for (int i = 0, e = M1.size(); i != e; ++i) + if (M1[i] >= 0 && M2[i] >= 0 && M1[i] != M2[i]) + return false; + return true; + }; + + auto MergeMasks = [](ArrayRef<int> Mask, MutableArrayRef<int> MergedMask) { + assert(Mask.size() == MergedMask.size() && "Unexpected mask size"); + for (int i = 0, e = MergedMask.size(); i != e; ++i) { + int M = Mask[i]; + if (M < 0) + continue; + assert((MergedMask[i] < 0 || MergedMask[i] == M) && + "Unexpected mask element"); + MergedMask[i] = M; + } + }; + + if (MatchMasks(InLaneMask, RepeatMask)) { + // Merge this lane mask into the final repeat mask. + MergeMasks(InLaneMask, RepeatMask); + continue; } - // Check that within each lane we have a consistent shuffle mask. - int k = i % LaneSize; - if (InLaneMask[k] < 0) { - InLaneMask[k] = Mask[i] % LaneSize; - } else if (InLaneMask[k] != Mask[i] % LaneSize) { - // This doesn't fit a repeating in-lane mask. - return SDValue(); + // Didn't find a match. Swap the operands and try again. + std::swap(LaneSrcs[Lane][0], LaneSrcs[Lane][1]); + ShuffleVectorSDNode::commuteMask(InLaneMask); + + if (MatchMasks(InLaneMask, RepeatMask)) { + // Merge this lane mask into the final repeat mask. + MergeMasks(InLaneMask, RepeatMask); + continue; } + + // Couldn't find a match with the operands in either order. + return SDValue(); } - // First shuffle the lanes into place. - MVT LaneVT = MVT::getVectorVT(VT.isFloatingPoint() ? MVT::f64 : MVT::i64, - VT.getSizeInBits() / 64); - SmallVector<int, 8> LaneMask((unsigned)NumLanes * 2, -1); - for (int i = 0; i < NumLanes; ++i) - if (Lanes[i] >= 0) { - LaneMask[2 * i + 0] = 2*Lanes[i] + 0; - LaneMask[2 * i + 1] = 2*Lanes[i] + 1; + // Now handle any lanes with only one source. + for (int Lane = 0; Lane != NumLanes; ++Lane) { + // If this lane has already been processed, skip it. + if (LaneSrcs[Lane][0] >= 0) + continue; + + for (int i = 0; i != LaneSize; ++i) { + int M = Mask[(Lane * LaneSize) + i]; + if (M < 0) + continue; + + // If RepeatMask isn't defined yet we can define it ourself. 
+ if (RepeatMask[i] < 0) + RepeatMask[i] = M % LaneSize; + + if (RepeatMask[i] < Size) { + if (RepeatMask[i] != M % LaneSize) + return SDValue(); + LaneSrcs[Lane][0] = M / LaneSize; + } else { + if (RepeatMask[i] != ((M % LaneSize) + Size)) + return SDValue(); + LaneSrcs[Lane][1] = M / LaneSize; + } } - V1 = DAG.getBitcast(LaneVT, V1); - V2 = DAG.getBitcast(LaneVT, V2); - SDValue LaneShuffle = DAG.getVectorShuffle(LaneVT, DL, V1, V2, LaneMask); + if (LaneSrcs[Lane][0] < 0 && LaneSrcs[Lane][1] < 0) + return SDValue(); + } + + SmallVector<int, 16> NewMask(Size, -1); + for (int Lane = 0; Lane != NumLanes; ++Lane) { + int Src = LaneSrcs[Lane][0]; + for (int i = 0; i != LaneSize; ++i) { + int M = -1; + if (Src >= 0) + M = Src * LaneSize + i; + NewMask[Lane * LaneSize + i] = M; + } + } + SDValue NewV1 = DAG.getVectorShuffle(VT, DL, V1, V2, NewMask); + // Ensure we didn't get back the shuffle we started with. + // FIXME: This is a hack to make up for some splat handling code in + // getVectorShuffle. + if (isa<ShuffleVectorSDNode>(NewV1) && + cast<ShuffleVectorSDNode>(NewV1)->getMask() == Mask) + return SDValue(); - // Cast it back to the type we actually want. - LaneShuffle = DAG.getBitcast(VT, LaneShuffle); + for (int Lane = 0; Lane != NumLanes; ++Lane) { + int Src = LaneSrcs[Lane][1]; + for (int i = 0; i != LaneSize; ++i) { + int M = -1; + if (Src >= 0) + M = Src * LaneSize + i; + NewMask[Lane * LaneSize + i] = M; + } + } + SDValue NewV2 = DAG.getVectorShuffle(VT, DL, V1, V2, NewMask); + // Ensure we didn't get back the shuffle we started with. + // FIXME: This is a hack to make up for some splat handling code in + // getVectorShuffle. + if (isa<ShuffleVectorSDNode>(NewV2) && + cast<ShuffleVectorSDNode>(NewV2)->getMask() == Mask) + return SDValue(); - // Now do a simple shuffle that isn't lane crossing. - SmallVector<int, 8> NewMask((unsigned)Size, -1); - for (int i = 0; i < Size; ++i) - if (Mask[i] >= 0) - NewMask[i] = (i / LaneSize) * LaneSize + Mask[i] % LaneSize; - assert(!is128BitLaneCrossingShuffleMask(VT, NewMask) && - "Must not introduce lane crosses at this point!"); + for (int i = 0; i != Size; ++i) { + NewMask[i] = RepeatMask[i % LaneSize]; + if (NewMask[i] < 0) + continue; - return DAG.getVectorShuffle(VT, DL, LaneShuffle, DAG.getUNDEF(VT), NewMask); + NewMask[i] += (i / LaneSize) * LaneSize; + } + return DAG.getVectorShuffle(VT, DL, NewV1, NewV2, NewMask); } /// Lower shuffles where an entire half of a 256 or 512-bit vector is UNDEF. @@ -13731,6 +14658,11 @@ static SDValue lowerV4F64VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, DL, MVT::v4f64, V1, V2, Mask, Subtarget, DAG)) return V; + // Try to permute the lanes and then use a per-lane permute. + if (SDValue V = lowerVectorShuffleAsLanePermuteAndPermute( + DL, MVT::v4f64, V1, V2, Mask, DAG, Subtarget)) + return V; + // Otherwise, fall back. return lowerVectorShuffleAsLanePermuteAndBlend(DL, MVT::v4f64, V1, V2, Mask, DAG, Subtarget); @@ -13765,6 +14697,7 @@ static SDValue lowerV4F64VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, if (SDValue Result = lowerVectorShuffleByMerging128BitLanes( DL, MVT::v4f64, V1, V2, Mask, Subtarget, DAG)) return Result; + // If we have VLX support, we can use VEXPAND. if (Subtarget.hasVLX()) if (SDValue V = lowerVectorShuffleToEXPAND(DL, MVT::v4f64, Zeroable, Mask, @@ -13775,10 +14708,11 @@ static SDValue lowerV4F64VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, // can fully permute the elements. 
if (Subtarget.hasAVX2()) return lowerVectorShuffleAsDecomposedShuffleBlend(DL, MVT::v4f64, V1, V2, - Mask, DAG); + Mask, Subtarget, DAG); // Otherwise fall back on generic lowering. - return lowerVectorShuffleAsSplitOrBlend(DL, MVT::v4f64, V1, V2, Mask, DAG); + return lowerVectorShuffleAsSplitOrBlend(DL, MVT::v4f64, V1, V2, Mask, + Subtarget, DAG); } /// Handle lowering of 4-lane 64-bit integer shuffles. @@ -13872,7 +14806,7 @@ static SDValue lowerV4I64VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, // Otherwise fall back on generic blend lowering. return lowerVectorShuffleAsDecomposedShuffleBlend(DL, MVT::v4i64, V1, V2, - Mask, DAG); + Mask, Subtarget, DAG); } /// Handle lowering of 8-lane 32-bit floating point shuffles. @@ -13961,17 +14895,18 @@ static SDValue lowerV8F32VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, // vpunpckhwd instrs than vblend. if (!Subtarget.hasAVX512() && isUnpackWdShuffleMask(Mask, MVT::v8f32)) if (SDValue V = lowerVectorShuffleAsSplitOrBlend(DL, MVT::v8f32, V1, V2, - Mask, DAG)) + Mask, Subtarget, DAG)) return V; // If we have AVX2 then we always want to lower with a blend because at v8 we // can fully permute the elements. if (Subtarget.hasAVX2()) return lowerVectorShuffleAsDecomposedShuffleBlend(DL, MVT::v8f32, V1, V2, - Mask, DAG); + Mask, Subtarget, DAG); // Otherwise fall back on generic lowering. - return lowerVectorShuffleAsSplitOrBlend(DL, MVT::v8f32, V1, V2, Mask, DAG); + return lowerVectorShuffleAsSplitOrBlend(DL, MVT::v8f32, V1, V2, Mask, + Subtarget, DAG); } /// Handle lowering of 8-lane 32-bit integer shuffles. @@ -14000,8 +14935,8 @@ static SDValue lowerV8I32VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, // vpunpcklwd and vpunpckhwd instrs. if (isUnpackWdShuffleMask(Mask, MVT::v8i32) && !V2.isUndef() && !Subtarget.hasAVX512()) - if (SDValue V = - lowerVectorShuffleAsSplitOrBlend(DL, MVT::v8i32, V1, V2, Mask, DAG)) + if (SDValue V = lowerVectorShuffleAsSplitOrBlend(DL, MVT::v8i32, V1, V2, + Mask, Subtarget, DAG)) return V; if (SDValue Blend = lowerVectorShuffleAsBlend(DL, MVT::v8i32, V1, V2, Mask, @@ -14084,7 +15019,7 @@ static SDValue lowerV8I32VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, // Otherwise fall back on generic blend lowering. return lowerVectorShuffleAsDecomposedShuffleBlend(DL, MVT::v8i32, V1, V2, - Mask, DAG); + Mask, Subtarget, DAG); } /// Handle lowering of 16-lane 16-bit integer shuffles. @@ -14146,9 +15081,14 @@ static SDValue lowerV16I16VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, if (V2.isUndef()) { // There are no generalized cross-lane shuffle operations available on i16 // element types. - if (is128BitLaneCrossingShuffleMask(MVT::v16i16, Mask)) + if (is128BitLaneCrossingShuffleMask(MVT::v16i16, Mask)) { + if (SDValue V = lowerVectorShuffleAsLanePermuteAndPermute( + DL, MVT::v16i16, V1, V2, Mask, DAG, Subtarget)) + return V; + return lowerVectorShuffleAsLanePermuteAndBlend(DL, MVT::v16i16, V1, V2, Mask, DAG, Subtarget); + } SmallVector<int, 8> RepeatedMask; if (is128BitLaneRepeatedShuffleMask(MVT::v16i16, Mask, RepeatedMask)) { @@ -14174,8 +15114,14 @@ static SDValue lowerV16I16VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, DL, MVT::v16i16, V1, V2, Mask, Subtarget, DAG)) return Result; + // Try to permute the lanes and then use a per-lane permute. + if (SDValue V = lowerVectorShuffleAsLanePermuteAndPermute( + DL, MVT::v16i16, V1, V2, Mask, DAG, Subtarget)) + return V; + // Otherwise fall back on generic lowering. 
- return lowerVectorShuffleAsSplitOrBlend(DL, MVT::v16i16, V1, V2, Mask, DAG); + return lowerVectorShuffleAsSplitOrBlend(DL, MVT::v16i16, V1, V2, Mask, + Subtarget, DAG); } /// Handle lowering of 32-lane 8-bit integer shuffles. @@ -14236,9 +15182,14 @@ static SDValue lowerV32I8VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, // There are no generalized cross-lane shuffle operations available on i8 // element types. - if (V2.isUndef() && is128BitLaneCrossingShuffleMask(MVT::v32i8, Mask)) + if (V2.isUndef() && is128BitLaneCrossingShuffleMask(MVT::v32i8, Mask)) { + if (SDValue V = lowerVectorShuffleAsLanePermuteAndPermute( + DL, MVT::v32i8, V1, V2, Mask, DAG, Subtarget)) + return V; + return lowerVectorShuffleAsLanePermuteAndBlend(DL, MVT::v32i8, V1, V2, Mask, DAG, Subtarget); + } if (SDValue PSHUFB = lowerVectorShuffleWithPSHUFB( DL, MVT::v32i8, Mask, V1, V2, Zeroable, Subtarget, DAG)) @@ -14254,8 +15205,14 @@ static SDValue lowerV32I8VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, DL, MVT::v32i8, V1, V2, Mask, Subtarget, DAG)) return Result; + // Try to permute the lanes and then use a per-lane permute. + if (SDValue V = lowerVectorShuffleAsLanePermuteAndPermute( + DL, MVT::v32i8, V1, V2, Mask, DAG, Subtarget)) + return V; + // Otherwise fall back on generic lowering. - return lowerVectorShuffleAsSplitOrBlend(DL, MVT::v32i8, V1, V2, Mask, DAG); + return lowerVectorShuffleAsSplitOrBlend(DL, MVT::v32i8, V1, V2, Mask, + Subtarget, DAG); } /// High-level routine to lower various 256-bit x86 vector shuffles. @@ -14757,6 +15714,11 @@ static SDValue lowerV64I8VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, lowerVectorShuffleWithUNPCK(DL, MVT::v64i8, Mask, V1, V2, DAG)) return V; + // Use dedicated pack instructions for masks that match their pattern. + if (SDValue V = lowerVectorShuffleWithPACK(DL, MVT::v64i8, Mask, V1, V2, DAG, + Subtarget)) + return V; + // Try to use shift instructions. if (SDValue Shift = lowerVectorShuffleAsShift(DL, MVT::v64i8, V1, V2, Mask, Zeroable, Subtarget, DAG)) @@ -14845,6 +15807,39 @@ static SDValue lower512BitVectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, } } +// Determine if this shuffle can be implemented with a KSHIFT instruction. +// Returns the shift amount if possible or -1 if not. This is a simplified +// version of matchVectorShuffleAsShift. +static int match1BitShuffleAsKSHIFT(unsigned &Opcode, ArrayRef<int> Mask, + int MaskOffset, const APInt &Zeroable) { + int Size = Mask.size(); + + auto CheckZeros = [&](int Shift, bool Left) { + for (int j = 0; j < Shift; ++j) + if (!Zeroable[j + (Left ? 0 : (Size - Shift))]) + return false; + + return true; + }; + + auto MatchShift = [&](int Shift, bool Left) { + unsigned Pos = Left ? Shift : 0; + unsigned Low = Left ? 0 : Shift; + unsigned Len = Size - Shift; + return isSequentialOrUndefInRange(Mask, Pos, Len, Low + MaskOffset); + }; + + for (int Shift = 1; Shift != Size; ++Shift) + for (bool Left : {true, false}) + if (CheckZeros(Shift, Left) && MatchShift(Shift, Left)) { + Opcode = Left ? X86ISD::KSHIFTL : X86ISD::KSHIFTR; + return Shift; + } + + return -1; +} + + // Lower vXi1 vector shuffles. // There is no a dedicated instruction on AVX-512 that shuffles the masks. 
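As a side note, the mask patterns that match1BitShuffleAsKSHIFT accepts are exactly bitwise shifts of the k-register; a minimal scalar sketch of the v8i1 case, assuming element i is stored in bit i (illustrative only):

    #include <cassert>
    #include <cstdint>

    int main() {
      // v8i1 source mask, element i stored in bit i.
      uint8_t K = 0b00101101;

      // Shuffle mask {Z, Z, 0, 1, 2, 3, 4, 5} (Z = zeroable) moves element j to
      // position j + 2 and zero-fills the bottom: that is KSHIFTL by 2.
      assert(uint8_t(K << 2) == 0b10110100);

      // Shuffle mask {3, 4, 5, 6, 7, Z, Z, Z} drops the low three elements and
      // zero-fills the top: that is KSHIFTR by 3.
      assert(uint8_t(K >> 3) == 0b00000101);
      return 0;
    }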
// The only way to shuffle bits is to sign-extend the mask vector to SIMD @@ -14854,6 +15849,9 @@ static SDValue lower1BitVectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, const APInt &Zeroable, const X86Subtarget &Subtarget, SelectionDAG &DAG) { + assert(Subtarget.hasAVX512() && + "Cannot lower 512-bit vectors w/o basic ISA!"); + unsigned NumElts = Mask.size(); // Try to recognize shuffles that are just padding a subvector with zeros. @@ -14880,9 +15878,21 @@ static SDValue lower1BitVectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, Extract, DAG.getIntPtrConstant(0, DL)); } + // Try to match KSHIFTs. + // TODO: Support narrower than legal shifts by widening and extracting. + if (NumElts >= 16 || (Subtarget.hasDQI() && NumElts == 8)) { + unsigned Offset = 0; + for (SDValue V : { V1, V2 }) { + unsigned Opcode; + int ShiftAmt = match1BitShuffleAsKSHIFT(Opcode, Mask, Offset, Zeroable); + if (ShiftAmt >= 0) + return DAG.getNode(Opcode, DL, VT, V, + DAG.getConstant(ShiftAmt, DL, MVT::i8)); + Offset += NumElts; // Increment for next iteration. + } + } + - assert(Subtarget.hasAVX512() && - "Cannot lower 512-bit vectors w/o basic ISA!"); MVT ExtVT; switch (VT.SimpleTy) { default: @@ -15069,6 +16079,14 @@ static SDValue lowerVectorShuffle(SDValue Op, const X86Subtarget &Subtarget, SmallVector<int, 16> WidenedMask; if (VT.getScalarSizeInBits() < 64 && !Is1BitVector && canWidenShuffleElements(ZeroableMask, WidenedMask)) { + // Shuffle mask widening should not interfere with a broadcast opportunity + // by obfuscating the operands with bitcasts. + // TODO: Avoid lowering directly from this top-level function: make this + // a query (canLowerAsBroadcast) and defer lowering to the type-based calls. + if (SDValue Broadcast = + lowerVectorShuffleAsBroadcast(DL, VT, V1, V2, Mask, Subtarget, DAG)) + return Broadcast; + MVT NewEltVT = VT.isFloatingPoint() ? MVT::getFloatingPointVT(VT.getScalarSizeInBits() * 2) : MVT::getIntegerVT(VT.getScalarSizeInBits() * 2); @@ -15135,34 +16153,27 @@ static SDValue lowerVSELECTtoVectorShuffle(SDValue Op, SDValue Cond = Op.getOperand(0); SDValue LHS = Op.getOperand(1); SDValue RHS = Op.getOperand(2); - SDLoc dl(Op); MVT VT = Op.getSimpleValueType(); - if (!ISD::isBuildVectorOfConstantSDNodes(Cond.getNode())) - return SDValue(); - auto *CondBV = cast<BuildVectorSDNode>(Cond); - // Only non-legal VSELECTs reach this lowering, convert those into generic // shuffles and re-use the shuffle lowering path for blends. SmallVector<int, 32> Mask; - for (int i = 0, Size = VT.getVectorNumElements(); i < Size; ++i) { - SDValue CondElt = CondBV->getOperand(i); - int M = i; - // We can't map undef to undef here. They have different meanings. Treat - // as the same as zero. - if (CondElt.isUndef() || isNullConstant(CondElt)) - M += Size; - Mask.push_back(M); - } - return DAG.getVectorShuffle(VT, dl, LHS, RHS, Mask); + if (createShuffleMaskFromVSELECT(Mask, Cond)) + return DAG.getVectorShuffle(VT, SDLoc(Op), LHS, RHS, Mask); + + return SDValue(); } SDValue X86TargetLowering::LowerVSELECT(SDValue Op, SelectionDAG &DAG) const { + SDValue Cond = Op.getOperand(0); + SDValue LHS = Op.getOperand(1); + SDValue RHS = Op.getOperand(2); + // A vselect where all conditions and data are constants can be optimized into // a single vector load by SelectionDAGLegalize::ExpandBUILD_VECTOR(). 
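The open-coded loop removed here (now reached through createShuffleMaskFromVSELECT) turns a constant vselect condition into a blend shuffle mask; a rough standalone sketch of that mapping, using a hypothetical plain-array form instead of the DAG types:

    #include <cassert>
    #include <vector>

    // A set condition bit selects lane i of the LHS; a clear or undef bit selects
    // lane i of the RHS, which the shuffle addresses as i + Size (undef is folded
    // to "pick RHS" because undef means something different in a shuffle mask
    // than it does in a select condition).
    static std::vector<int> blendMaskFromCondition(const std::vector<bool> &Cond) {
      int Size = int(Cond.size());
      std::vector<int> Mask(Size);
      for (int i = 0; i != Size; ++i)
        Mask[i] = Cond[i] ? i : i + Size;
      return Mask;
    }

    int main() {
      // v4 condition <1, 0, 1, 0> becomes the blend mask <0, 5, 2, 7>.
      assert(blendMaskFromCondition({true, false, true, false}) ==
             std::vector<int>({0, 5, 2, 7}));
      return 0;
    }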
- if (ISD::isBuildVectorOfConstantSDNodes(Op.getOperand(0).getNode()) && - ISD::isBuildVectorOfConstantSDNodes(Op.getOperand(1).getNode()) && - ISD::isBuildVectorOfConstantSDNodes(Op.getOperand(2).getNode())) + if (ISD::isBuildVectorOfConstantSDNodes(Cond.getNode()) && + ISD::isBuildVectorOfConstantSDNodes(LHS.getNode()) && + ISD::isBuildVectorOfConstantSDNodes(RHS.getNode())) return SDValue(); // Try to lower this to a blend-style vector shuffle. This can handle all @@ -15172,7 +16183,9 @@ SDValue X86TargetLowering::LowerVSELECT(SDValue Op, SelectionDAG &DAG) const { // If this VSELECT has a vector if i1 as a mask, it will be directly matched // with patterns on the mask registers on AVX-512. - if (Op->getOperand(0).getValueType().getScalarSizeInBits() == 1) + MVT CondVT = Cond.getSimpleValueType(); + unsigned CondEltSize = Cond.getScalarValueSizeInBits(); + if (CondEltSize == 1) return Op; // Variable blends are only legal from SSE4.1 onward. @@ -15181,24 +16194,32 @@ SDValue X86TargetLowering::LowerVSELECT(SDValue Op, SelectionDAG &DAG) const { SDLoc dl(Op); MVT VT = Op.getSimpleValueType(); + unsigned EltSize = VT.getScalarSizeInBits(); + unsigned NumElts = VT.getVectorNumElements(); // If the VSELECT is on a 512-bit type, we have to convert a non-i1 condition // into an i1 condition so that we can use the mask-based 512-bit blend // instructions. if (VT.getSizeInBits() == 512) { - SDValue Cond = Op.getOperand(0); - // The vNi1 condition case should be handled above as it can be trivially - // lowered. - assert(Cond.getValueType().getScalarSizeInBits() == - VT.getScalarSizeInBits() && - "Should have a size-matched integer condition!"); // Build a mask by testing the condition against zero. - MVT MaskVT = MVT::getVectorVT(MVT::i1, VT.getVectorNumElements()); + MVT MaskVT = MVT::getVectorVT(MVT::i1, NumElts); SDValue Mask = DAG.getSetCC(dl, MaskVT, Cond, - getZeroVector(VT, Subtarget, DAG, dl), + DAG.getConstant(0, dl, CondVT), ISD::SETNE); // Now return a new VSELECT using the mask. - return DAG.getSelect(dl, VT, Mask, Op.getOperand(1), Op.getOperand(2)); + return DAG.getSelect(dl, VT, Mask, LHS, RHS); + } + + // SEXT/TRUNC cases where the mask doesn't match the destination size. + if (CondEltSize != EltSize) { + // If we don't have a sign splat, rely on the expansion. + if (CondEltSize != DAG.ComputeNumSignBits(Cond)) + return SDValue(); + + MVT NewCondSVT = MVT::getIntegerVT(EltSize); + MVT NewCondVT = MVT::getVectorVT(NewCondSVT, NumElts); + Cond = DAG.getSExtOrTrunc(Cond, dl, NewCondVT); + return DAG.getNode(ISD::VSELECT, dl, VT, Cond, LHS, RHS); } // Only some types will be legal on some subtargets. If we can emit a legal @@ -15219,10 +16240,10 @@ SDValue X86TargetLowering::LowerVSELECT(SDValue Op, SelectionDAG &DAG) const { case MVT::v8i16: case MVT::v16i16: { // Bitcast everything to the vXi8 type and use a vXi8 vselect. 
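The sign-splat test added before resizing the condition relies on boolean lanes surviving sext/trunc unchanged; a small scalar sketch of that property (illustrative only):

    #include <cassert>
    #include <cstdint>

    int main() {
      // A condition element whose bits all equal the sign bit (what the
      // ComputeNumSignBits check requires) is either 0 or -1, and it stays a
      // valid all-zeros/all-ones predicate after truncation or sign extension
      // to the data element width.
      const int32_t CondElts[2] = {0, -1};
      for (int32_t C : CondElts) {
        int16_t Trunc = int16_t(C); // i32 condition -> i16 data width
        int64_t Sext = int64_t(C);  // i32 condition -> i64 data width
        assert((Trunc != 0) == (C != 0) && (Trunc == 0 || Trunc == -1));
        assert((Sext != 0) == (C != 0) && (Sext == 0 || Sext == -1));
      }
      return 0;
    }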
- MVT CastVT = MVT::getVectorVT(MVT::i8, VT.getVectorNumElements() * 2); - SDValue Cond = DAG.getBitcast(CastVT, Op->getOperand(0)); - SDValue LHS = DAG.getBitcast(CastVT, Op->getOperand(1)); - SDValue RHS = DAG.getBitcast(CastVT, Op->getOperand(2)); + MVT CastVT = MVT::getVectorVT(MVT::i8, NumElts * 2); + Cond = DAG.getBitcast(CastVT, Cond); + LHS = DAG.getBitcast(CastVT, LHS); + RHS = DAG.getBitcast(CastVT, RHS); SDValue Select = DAG.getNode(ISD::VSELECT, dl, CastVT, Cond, LHS, RHS); return DAG.getBitcast(VT, Select); } @@ -15298,34 +16319,25 @@ static SDValue ExtractBitFromMaskVector(SDValue Op, SelectionDAG &DAG, } unsigned IdxVal = cast<ConstantSDNode>(Idx)->getZExtValue(); + if (IdxVal == 0) // the operation is legal + return Op; - // If the kshift instructions of the correct width aren't natively supported - // then we need to promote the vector to the native size to get the correct - // zeroing behavior. - if (VecVT.getVectorNumElements() < 16) { - VecVT = MVT::v16i1; - Vec = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, MVT::v16i1, - DAG.getUNDEF(VecVT), Vec, + // Extend to natively supported kshift. + unsigned NumElems = VecVT.getVectorNumElements(); + MVT WideVecVT = VecVT; + if ((!Subtarget.hasDQI() && NumElems == 8) || NumElems < 8) { + WideVecVT = Subtarget.hasDQI() ? MVT::v8i1 : MVT::v16i1; + Vec = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, WideVecVT, + DAG.getUNDEF(WideVecVT), Vec, DAG.getIntPtrConstant(0, dl)); } - // Extracts from element 0 are always allowed. - if (IdxVal != 0) { - // Use kshiftr instruction to move to the lower element. - Vec = DAG.getNode(X86ISD::KSHIFTR, dl, VecVT, Vec, - DAG.getConstant(IdxVal, dl, MVT::i8)); - } - - // Shrink to v16i1 since that's always legal. - if (VecVT.getVectorNumElements() > 16) { - VecVT = MVT::v16i1; - Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, VecVT, Vec, - DAG.getIntPtrConstant(0, dl)); - } + // Use kshiftr instruction to move to the lower element. + Vec = DAG.getNode(X86ISD::KSHIFTR, dl, WideVecVT, Vec, + DAG.getConstant(IdxVal, dl, MVT::i8)); - // Convert to a bitcast+aext/trunc. - MVT CastVT = MVT::getIntegerVT(VecVT.getVectorNumElements()); - return DAG.getAnyExtOrTrunc(DAG.getBitcast(CastVT, Vec), dl, EltVT); + return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, Op.getValueType(), Vec, + DAG.getIntPtrConstant(0, dl)); } SDValue @@ -15793,7 +16805,7 @@ X86TargetLowering::LowerExternalSymbol(SDValue Op, SelectionDAG &DAG) const { Result = DAG.getNode(getGlobalWrapperKind(), DL, PtrVT, Result); // With PIC, the address is actually $g + Offset. - if (isPositionIndependent() && !Subtarget.is64Bit()) { + if (OpFlag) { Result = DAG.getNode(ISD::ADD, DL, PtrVT, DAG.getNode(X86ISD::GlobalBaseReg, SDLoc(), PtrVT), Result); @@ -16173,6 +17185,7 @@ X86TargetLowering::LowerGlobalTLSAddress(SDValue Op, SelectionDAG &DAG) const { /// Lower SRA_PARTS and friends, which return two i32 values /// and take a 2 x i32 value to shift plus a shift amount. +/// TODO: Can this be moved to general expansion code? static SDValue LowerShiftParts(SDValue Op, SelectionDAG &DAG) { assert(Op.getNumOperands() == 3 && "Not a double-shift!"); MVT VT = Op.getSimpleValueType(); @@ -16182,8 +17195,8 @@ static SDValue LowerShiftParts(SDValue Op, SelectionDAG &DAG) { SDValue ShOpLo = Op.getOperand(0); SDValue ShOpHi = Op.getOperand(1); SDValue ShAmt = Op.getOperand(2); - // X86ISD::SHLD and X86ISD::SHRD have defined overflow behavior but the - // generic ISD nodes haven't. 
Insert an AND to be safe, it's optimized away + // ISD::FSHL and ISD::FSHR have defined overflow behavior but ISD::SHL and + // ISD::SRA/L nodes haven't. Insert an AND to be safe, it's optimized away // during isel. SDValue SafeShAmt = DAG.getNode(ISD::AND, dl, MVT::i8, ShAmt, DAG.getConstant(VTBits - 1, dl, MVT::i8)); @@ -16193,10 +17206,10 @@ static SDValue LowerShiftParts(SDValue Op, SelectionDAG &DAG) { SDValue Tmp2, Tmp3; if (Op.getOpcode() == ISD::SHL_PARTS) { - Tmp2 = DAG.getNode(X86ISD::SHLD, dl, VT, ShOpHi, ShOpLo, ShAmt); + Tmp2 = DAG.getNode(ISD::FSHL, dl, VT, ShOpHi, ShOpLo, ShAmt); Tmp3 = DAG.getNode(ISD::SHL, dl, VT, ShOpLo, SafeShAmt); } else { - Tmp2 = DAG.getNode(X86ISD::SHRD, dl, VT, ShOpLo, ShOpHi, ShAmt); + Tmp2 = DAG.getNode(ISD::FSHR, dl, VT, ShOpHi, ShOpLo, ShAmt); Tmp3 = DAG.getNode(isSRA ? ISD::SRA : ISD::SRL, dl, VT, ShOpHi, SafeShAmt); } @@ -16220,6 +17233,56 @@ static SDValue LowerShiftParts(SDValue Op, SelectionDAG &DAG) { return DAG.getMergeValues({ Lo, Hi }, dl); } +static SDValue LowerFunnelShift(SDValue Op, const X86Subtarget &Subtarget, + SelectionDAG &DAG) { + MVT VT = Op.getSimpleValueType(); + assert((Op.getOpcode() == ISD::FSHL || Op.getOpcode() == ISD::FSHR) && + "Unexpected funnel shift opcode!"); + + SDLoc DL(Op); + SDValue Op0 = Op.getOperand(0); + SDValue Op1 = Op.getOperand(1); + SDValue Amt = Op.getOperand(2); + + bool IsFSHR = Op.getOpcode() == ISD::FSHR; + + if (VT.isVector()) { + assert(Subtarget.hasVBMI2() && "Expected VBMI2"); + + if (IsFSHR) + std::swap(Op0, Op1); + + APInt APIntShiftAmt; + if (isConstantSplat(Amt, APIntShiftAmt)) { + uint64_t ShiftAmt = APIntShiftAmt.getZExtValue(); + return DAG.getNode(IsFSHR ? X86ISD::VSHRD : X86ISD::VSHLD, DL, VT, + Op0, Op1, DAG.getConstant(ShiftAmt, DL, MVT::i8)); + } + + return DAG.getNode(IsFSHR ? X86ISD::VSHRDV : X86ISD::VSHLDV, DL, VT, + Op0, Op1, Amt); + } + + assert((VT == MVT::i16 || VT == MVT::i32 || VT == MVT::i64) && + "Unexpected funnel shift type!"); + + // Expand slow SHLD/SHRD cases if we are not optimizing for size. + bool OptForSize = DAG.getMachineFunction().getFunction().optForSize(); + if (!OptForSize && Subtarget.isSHLDSlow()) + return SDValue(); + + if (IsFSHR) + std::swap(Op0, Op1); + + // i16 needs to modulo the shift amount, but i32/i64 have implicit modulo. + if (VT == MVT::i16) + Amt = DAG.getNode(ISD::AND, DL, Amt.getValueType(), Amt, + DAG.getConstant(15, DL, Amt.getValueType())); + + unsigned SHDOp = (IsFSHR ? X86ISD::SHRD : X86ISD::SHLD); + return DAG.getNode(SHDOp, DL, VT, Op0, Op1, Amt); +} + // Try to use a packed vector operation to handle i64 on 32-bit targets when // AVX512DQ is enabled. static SDValue LowerI64IntToFP_AVX512DQ(SDValue Op, SelectionDAG &DAG, @@ -16271,9 +17334,8 @@ SDValue X86TargetLowering::LowerSINT_TO_FP(SDValue Op, // Legal. if (SrcVT == MVT::i32 && isScalarFPTypeInSSEReg(VT)) return Op; - if (SrcVT == MVT::i64 && isScalarFPTypeInSSEReg(VT) && Subtarget.is64Bit()) { + if (SrcVT == MVT::i64 && isScalarFPTypeInSSEReg(VT) && Subtarget.is64Bit()) return Op; - } if (SDValue V = LowerI64IntToFP_AVX512DQ(Op, DAG, Subtarget)) return V; @@ -16331,7 +17393,7 @@ SDValue X86TargetLowering::BuildFILD(SDValue Op, EVT SrcVT, SDValue Chain, Chain = Result.getValue(1); SDValue InFlag = Result.getValue(2); - // FIXME: Currently the FST is flagged to the FILD_FLAG. This + // FIXME: Currently the FST is glued to the FILD_FLAG. This // shouldn't be necessary except that RFP cannot be live across // multiple blocks. When stackifier is fixed, they can be uncoupled. 
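For reference, the ISD::FSHL/FSHR semantics used by the new LowerFunnelShift and the reworked LowerShiftParts reduce to the following scalar form (a minimal sketch of the node semantics, not the lowering itself):

    #include <cassert>
    #include <cstdint>

    // fshl(Hi, Lo, Amt): shift the concatenation Hi:Lo left by Amt (mod width)
    // and keep the high word; fshr shifts right and keeps the low word. This is
    // what SHLD/SHRD compute; i32/i64 mask the amount implicitly, while i16
    // needs the explicit 'and Amt, 15' the code inserts.
    static uint32_t fshl32(uint32_t Hi, uint32_t Lo, unsigned Amt) {
      Amt &= 31;
      return Amt ? (Hi << Amt) | (Lo >> (32 - Amt)) : Hi;
    }
    static uint32_t fshr32(uint32_t Hi, uint32_t Lo, unsigned Amt) {
      Amt &= 31;
      return Amt ? (Lo >> Amt) | (Hi << (32 - Amt)) : Lo;
    }

    int main() {
      assert(fshl32(0x12345678u, 0x9ABCDEF0u, 8) == 0x3456789Au);
      assert(fshr32(0x12345678u, 0x9ABCDEF0u, 8) == 0x789ABCDEu);
      return 0;
    }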
MachineFunction &MF = DAG.getMachineFunction(); @@ -16412,13 +17474,11 @@ static SDValue LowerUINT_TO_FP_i64(SDValue Op, SelectionDAG &DAG, SDValue Result; if (Subtarget.hasSSE3()) { - // FIXME: The 'haddpd' instruction may be slower than 'movhlps + addsd'. + // FIXME: The 'haddpd' instruction may be slower than 'shuffle + addsd'. Result = DAG.getNode(X86ISD::FHADD, dl, MVT::v2f64, Sub, Sub); } else { - SDValue S2F = DAG.getBitcast(MVT::v4i32, Sub); - SDValue Shuffle = DAG.getVectorShuffle(MVT::v4i32, dl, S2F, S2F, {2,3,0,1}); - Result = DAG.getNode(ISD::FADD, dl, MVT::v2f64, - DAG.getBitcast(MVT::v2f64, Shuffle), Sub); + SDValue Shuffle = DAG.getVectorShuffle(MVT::v2f64, dl, Sub, Sub, {1,-1}); + Result = DAG.getNode(ISD::FADD, dl, MVT::v2f64, Shuffle, Sub); } return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, MVT::f64, Result, @@ -16910,33 +17970,43 @@ static SDValue LowerAVXExtend(SDValue Op, SelectionDAG &DAG, InVT.getVectorElementType() == MVT::i32) && "Unexpected element type"); + // Custom legalize v8i8->v8i64 on CPUs without avx512bw. + if (InVT == MVT::v8i8) { + if (!ExperimentalVectorWideningLegalization || VT != MVT::v8i64) + return SDValue(); + + In = DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(Op), + MVT::v16i8, In, DAG.getUNDEF(MVT::v8i8)); + // FIXME: This should be ANY_EXTEND_VECTOR_INREG for ANY_EXTEND input. + return DAG.getNode(ISD::ZERO_EXTEND_VECTOR_INREG, dl, VT, In); + } + if (Subtarget.hasInt256()) - return DAG.getNode(X86ISD::VZEXT, dl, VT, In); + return Op; // Optimize vectors in AVX mode: // // v8i16 -> v8i32 - // Use vpunpcklwd for 4 lower elements v8i16 -> v4i32. + // Use vpmovzwd for 4 lower elements v8i16 -> v4i32. // Use vpunpckhwd for 4 upper elements v8i16 -> v4i32. // Concat upper and lower parts. // // v4i32 -> v4i64 - // Use vpunpckldq for 4 lower elements v4i32 -> v2i64. + // Use vpmovzdq for 4 lower elements v4i32 -> v2i64. // Use vpunpckhdq for 4 upper elements v4i32 -> v2i64. // Concat upper and lower parts. // - SDValue ZeroVec = getZeroVector(InVT, Subtarget, DAG, dl); + MVT HalfVT = MVT::getVectorVT(VT.getVectorElementType(), + VT.getVectorNumElements() / 2); + + SDValue OpLo = DAG.getNode(ISD::ZERO_EXTEND_VECTOR_INREG, dl, HalfVT, In); + + SDValue ZeroVec = DAG.getConstant(0, dl, InVT); SDValue Undef = DAG.getUNDEF(InVT); bool NeedZero = Op.getOpcode() == ISD::ZERO_EXTEND; - SDValue OpLo = getUnpackl(DAG, dl, InVT, In, NeedZero ? ZeroVec : Undef); SDValue OpHi = getUnpackh(DAG, dl, InVT, In, NeedZero ? ZeroVec : Undef); - - MVT HVT = MVT::getVectorVT(VT.getVectorElementType(), - VT.getVectorNumElements()/2); - - OpLo = DAG.getBitcast(HVT, OpLo); - OpHi = DAG.getBitcast(HVT, OpHi); + OpHi = DAG.getBitcast(HalfVT, OpHi); return DAG.getNode(ISD::CONCAT_VECTORS, dl, VT, OpLo, OpHi); } @@ -16965,7 +18035,7 @@ static SDValue LowerZERO_EXTEND_Mask(SDValue Op, SDLoc DL(Op); unsigned NumElts = VT.getVectorNumElements(); - // For all vectors, but vXi8 we can just emit a sign_extend a shift. This + // For all vectors, but vXi8 we can just emit a sign_extend and a shift. This // avoids a constant pool load. 
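A scalar illustration of why the sign_extend-plus-shift sequence mentioned above zero-extends a mask element without loading a constant vector of ones (illustrative only):

    #include <cassert>
    #include <cstdint>

    int main() {
      // Sign-extending a set i1 gives all ones; a logical shift right by
      // (width - 1) turns that back into 0 or 1, which is the zero extension.
      for (int Bit = 0; Bit != 2; ++Bit) {
        int32_t Sext = Bit ? -1 : 0;          // sign_extend i1 -> i32
        uint32_t Zext = uint32_t(Sext) >> 31; // srl by (32 - 1)
        assert(Zext == uint32_t(Bit));        // equals zero_extend i1 -> i32
      }
      return 0;
    }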
if (VT.getVectorElementType() != MVT::i8) { SDValue Extend = DAG.getNode(ISD::SIGN_EXTEND, DL, VT, In); @@ -16995,7 +18065,7 @@ static SDValue LowerZERO_EXTEND_Mask(SDValue Op, } SDValue One = DAG.getConstant(1, DL, WideVT); - SDValue Zero = getZeroVector(WideVT, Subtarget, DAG, DL); + SDValue Zero = DAG.getConstant(0, DL, WideVT); SDValue SelectedVal = DAG.getSelect(DL, WideVT, In, One, Zero); @@ -17035,9 +18105,10 @@ static SDValue truncateVectorWithPACK(unsigned Opcode, EVT DstVT, SDValue In, const X86Subtarget &Subtarget) { assert((Opcode == X86ISD::PACKSS || Opcode == X86ISD::PACKUS) && "Unexpected PACK opcode"); + assert(DstVT.isVector() && "VT not a vector?"); // Requires SSE2 but AVX512 has fast vector truncate. - if (!Subtarget.hasSSE2() || Subtarget.hasAVX512() || !DstVT.isVector()) + if (!Subtarget.hasSSE2()) return SDValue(); EVT SrcVT = In.getValueType(); @@ -17203,10 +18274,8 @@ static SDValue LowerTruncateVecI1(SDValue Op, SelectionDAG &DAG, } // If we have DQI, emit a pattern that will be iseled as vpmovq2m/vpmovd2m. if (Subtarget.hasDQI()) - return DAG.getSetCC(DL, VT, DAG.getConstant(0, DL, InVT), - In, ISD::SETGT); - return DAG.getSetCC(DL, VT, In, getZeroVector(InVT, Subtarget, DAG, DL), - ISD::SETNE); + return DAG.getSetCC(DL, VT, DAG.getConstant(0, DL, InVT), In, ISD::SETGT); + return DAG.getSetCC(DL, VT, In, DAG.getConstant(0, DL, InVT), ISD::SETNE); } SDValue X86TargetLowering::LowerTRUNCATE(SDValue Op, SelectionDAG &DAG) const { @@ -17219,20 +18288,22 @@ SDValue X86TargetLowering::LowerTRUNCATE(SDValue Op, SelectionDAG &DAG) const { assert(VT.getVectorNumElements() == InVT.getVectorNumElements() && "Invalid TRUNCATE operation"); + // If called by the legalizer just return. + if (!DAG.getTargetLoweringInfo().isTypeLegal(InVT)) + return SDValue(); + if (VT.getVectorElementType() == MVT::i1) return LowerTruncateVecI1(Op, DAG, Subtarget); // vpmovqb/w/d, vpmovdb/w, vpmovwb if (Subtarget.hasAVX512()) { - // word to byte only under BWI - if (InVT == MVT::v16i16 && !Subtarget.hasBWI()) { // v16i16 -> v16i8 - // Make sure we're allowed to promote 512-bits. - if (Subtarget.canExtendTo512DQ()) - return DAG.getNode(ISD::TRUNCATE, DL, VT, - DAG.getNode(X86ISD::VSEXT, DL, MVT::v16i32, In)); - } else { + // word to byte only under BWI. Otherwise we have to promoted to v16i32 + // and then truncate that. But we should only do that if we haven't been + // asked to avoid 512-bit vectors. The actual promotion to v16i32 will be + // handled by isel patterns. + if (InVT != MVT::v16i16 || Subtarget.hasBWI() || + Subtarget.canExtendTo512DQ()) return Op; - } } unsigned NumPackedSignBits = std::min<unsigned>(VT.getScalarSizeInBits(), 16); @@ -17241,8 +18312,7 @@ SDValue X86TargetLowering::LowerTRUNCATE(SDValue Op, SelectionDAG &DAG) const { // Truncate with PACKUS if we are truncating a vector with leading zero bits // that extend all the way to the packed/truncated value. // Pre-SSE41 we can only use PACKUSWB. - KnownBits Known; - DAG.computeKnownBits(In, Known); + KnownBits Known = DAG.computeKnownBits(In); if ((InNumEltBits - NumPackedZeroBits) <= Known.countMinLeadingZeros()) if (SDValue V = truncateVectorWithPACK(X86ISD::PACKUS, VT, In, DL, DAG, Subtarget)) @@ -17320,6 +18390,17 @@ SDValue X86TargetLowering::LowerTRUNCATE(SDValue Op, SelectionDAG &DAG) const { return DAG.getBitcast(MVT::v8i16, res); } + if (VT == MVT::v16i8 && InVT == MVT::v16i16) { + // Use an AND to zero uppper bits for PACKUS. 
+ In = DAG.getNode(ISD::AND, DL, InVT, In, DAG.getConstant(255, DL, InVT)); + + SDValue InLo = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::v8i16, In, + DAG.getIntPtrConstant(0, DL)); + SDValue InHi = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::v8i16, In, + DAG.getIntPtrConstant(8, DL)); + return DAG.getNode(X86ISD::PACKUS, DL, VT, InLo, InHi); + } + // Handle truncation of V256 to V128 using shuffles. assert(VT.is128BitVector() && InVT.is256BitVector() && "Unexpected types!"); @@ -17405,6 +18486,98 @@ static SDValue LowerFP_EXTEND(SDValue Op, SelectionDAG &DAG) { In, DAG.getUNDEF(SVT))); } +/// Horizontal vector math instructions may be slower than normal math with +/// shuffles. Limit horizontal op codegen based on size/speed trade-offs, uarch +/// implementation, and likely shuffle complexity of the alternate sequence. +static bool shouldUseHorizontalOp(bool IsSingleSource, SelectionDAG &DAG, + const X86Subtarget &Subtarget) { + bool IsOptimizingSize = DAG.getMachineFunction().getFunction().optForSize(); + bool HasFastHOps = Subtarget.hasFastHorizontalOps(); + return !IsSingleSource || IsOptimizingSize || HasFastHOps; +} + +/// Depending on uarch and/or optimizing for size, we might prefer to use a +/// vector operation in place of the typical scalar operation. +static SDValue lowerAddSubToHorizontalOp(SDValue Op, SelectionDAG &DAG, + const X86Subtarget &Subtarget) { + // If both operands have other uses, this is probably not profitable. + SDValue LHS = Op.getOperand(0); + SDValue RHS = Op.getOperand(1); + if (!LHS.hasOneUse() && !RHS.hasOneUse()) + return Op; + + // FP horizontal add/sub were added with SSE3. Integer with SSSE3. + bool IsFP = Op.getSimpleValueType().isFloatingPoint(); + if (IsFP && !Subtarget.hasSSE3()) + return Op; + if (!IsFP && !Subtarget.hasSSSE3()) + return Op; + + // Defer forming the minimal horizontal op if the vector source has more than + // the 2 extract element uses that we're matching here. In that case, we might + // form a horizontal op that includes more than 1 add/sub op. + if (LHS.getOpcode() != ISD::EXTRACT_VECTOR_ELT || + RHS.getOpcode() != ISD::EXTRACT_VECTOR_ELT || + LHS.getOperand(0) != RHS.getOperand(0) || + !LHS.getOperand(0)->hasNUsesOfValue(2, 0)) + return Op; + + if (!isa<ConstantSDNode>(LHS.getOperand(1)) || + !isa<ConstantSDNode>(RHS.getOperand(1)) || + !shouldUseHorizontalOp(true, DAG, Subtarget)) + return Op; + + // Allow commuted 'hadd' ops. + // TODO: Allow commuted (f)sub by negating the result of (F)HSUB? + unsigned HOpcode; + switch (Op.getOpcode()) { + case ISD::ADD: HOpcode = X86ISD::HADD; break; + case ISD::SUB: HOpcode = X86ISD::HSUB; break; + case ISD::FADD: HOpcode = X86ISD::FHADD; break; + case ISD::FSUB: HOpcode = X86ISD::FHSUB; break; + default: + llvm_unreachable("Trying to lower unsupported opcode to horizontal op"); + } + unsigned LExtIndex = LHS.getConstantOperandVal(1); + unsigned RExtIndex = RHS.getConstantOperandVal(1); + if (LExtIndex == 1 && RExtIndex == 0 && + (HOpcode == X86ISD::HADD || HOpcode == X86ISD::FHADD)) + std::swap(LExtIndex, RExtIndex); + + // TODO: This can be extended to handle other adjacent extract pairs. 
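The horizontal-op rewrite above depends on how the 128-bit (F)HADD forms pair adjacent elements; a sketch of the v2f64 haddpd case it targets (reference semantics only, not the DAG code):

    #include <cassert>

    // Reference semantics: haddpd(A, B) = { A[0] + A[1], B[0] + B[1] }.
    static void haddpd(const double A[2], const double B[2], double R[2]) {
      R[0] = A[0] + A[1];
      R[1] = B[0] + B[1];
    }

    int main() {
      double X[2] = {1.5, 2.25}, R[2];
      haddpd(X, X, R);
      // add(extractelt(X, 0), extractelt(X, 1)) == extractelt(hadd(X, X), 0),
      // which is the rewrite performed when horizontal ops are considered
      // cheap enough; hadd also covers the commuted X[1] + X[0] form, while
      // hsub is order-sensitive.
      assert(R[0] == X[0] + X[1]);
      return 0;
    }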
+ if (LExtIndex != 0 || RExtIndex != 1) + return Op; + + SDValue X = LHS.getOperand(0); + EVT VecVT = X.getValueType(); + unsigned BitWidth = VecVT.getSizeInBits(); + assert((BitWidth == 128 || BitWidth == 256 || BitWidth == 512) && + "Not expecting illegal vector widths here"); + + // Creating a 256-bit horizontal op would be wasteful, and there is no 512-bit + // equivalent, so extract the 256/512-bit source op to 128-bit. + // This is free: ymm/zmm -> xmm. + SDLoc DL(Op); + if (BitWidth == 256 || BitWidth == 512) + X = extract128BitVector(X, 0, DAG, DL); + + // add (extractelt (X, 0), extractelt (X, 1)) --> extractelt (hadd X, X), 0 + // add (extractelt (X, 1), extractelt (X, 0)) --> extractelt (hadd X, X), 0 + // sub (extractelt (X, 0), extractelt (X, 1)) --> extractelt (hsub X, X), 0 + SDValue HOp = DAG.getNode(HOpcode, DL, X.getValueType(), X, X); + return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, Op.getSimpleValueType(), HOp, + DAG.getIntPtrConstant(0, DL)); +} + +/// Depending on uarch and/or optimizing for size, we might prefer to use a +/// vector operation in place of the typical scalar operation. +static SDValue lowerFaddFsub(SDValue Op, SelectionDAG &DAG, + const X86Subtarget &Subtarget) { + assert((Op.getValueType() == MVT::f32 || Op.getValueType() == MVT::f64) && + "Only expecting float/double"); + return lowerAddSubToHorizontalOp(Op, DAG, Subtarget); +} + /// The only differences between FABS and FNEG are the mask and the logic op. /// FNEG also has a folding opportunity for FNEG(FABS(x)). static SDValue LowerFABSorFNEG(SDValue Op, SelectionDAG &DAG) { @@ -17424,43 +18597,36 @@ static SDValue LowerFABSorFNEG(SDValue Op, SelectionDAG &DAG) { MVT VT = Op.getSimpleValueType(); bool IsF128 = (VT == MVT::f128); + assert((VT == MVT::f64 || VT == MVT::f32 || VT == MVT::f128 || + VT == MVT::v2f64 || VT == MVT::v4f64 || VT == MVT::v4f32 || + VT == MVT::v8f32 || VT == MVT::v8f64 || VT == MVT::v16f32) && + "Unexpected type in LowerFABSorFNEG"); // FIXME: Use function attribute "OptimizeForSize" and/or CodeGenOpt::Level to // decide if we should generate a 16-byte constant mask when we only need 4 or // 8 bytes for the scalar case. - MVT LogicVT; - MVT EltVT; - - if (VT.isVector()) { - LogicVT = VT; - EltVT = VT.getVectorElementType(); - } else if (IsF128) { - // SSE instructions are used for optimized f128 logical operations. - LogicVT = MVT::f128; - EltVT = VT; - } else { - // There are no scalar bitwise logical SSE/AVX instructions, so we - // generate a 16-byte vector constant and logic op even for the scalar case. - // Using a 16-byte mask allows folding the load of the mask with - // the logic op, so it can save (~4 bytes) on code size. + // There are no scalar bitwise logical SSE/AVX instructions, so we + // generate a 16-byte vector constant and logic op even for the scalar case. + // Using a 16-byte mask allows folding the load of the mask with + // the logic op, so it can save (~4 bytes) on code size. + bool IsFakeVector = !VT.isVector() && !IsF128; + MVT LogicVT = VT; + if (IsFakeVector) LogicVT = (VT == MVT::f64) ? MVT::v2f64 : MVT::v4f32; - EltVT = VT; - } - unsigned EltBits = EltVT.getSizeInBits(); + unsigned EltBits = VT.getScalarSizeInBits(); // For FABS, mask is 0x7f...; for FNEG, mask is 0x80... - APInt MaskElt = - IsFABS ? APInt::getSignedMaxValue(EltBits) : APInt::getSignMask(EltBits); - const fltSemantics &Sem = - EltVT == MVT::f64 ? APFloat::IEEEdouble() : - (IsF128 ? APFloat::IEEEquad() : APFloat::IEEEsingle()); + APInt MaskElt = IsFABS ? 
APInt::getSignedMaxValue(EltBits) : + APInt::getSignMask(EltBits); + const fltSemantics &Sem = SelectionDAG::EVTToAPFloatSemantics(VT); SDValue Mask = DAG.getConstantFP(APFloat(Sem, MaskElt), dl, LogicVT); SDValue Op0 = Op.getOperand(0); bool IsFNABS = !IsFABS && (Op0.getOpcode() == ISD::FABS); - unsigned LogicOp = - IsFABS ? X86ISD::FAND : IsFNABS ? X86ISD::FOR : X86ISD::FXOR; + unsigned LogicOp = IsFABS ? X86ISD::FAND : + IsFNABS ? X86ISD::FOR : + X86ISD::FXOR; SDValue Operand = IsFNABS ? Op0.getOperand(0) : Op0; if (VT.isVector() || IsF128) @@ -17496,10 +18662,7 @@ static SDValue LowerFCOPYSIGN(SDValue Op, SelectionDAG &DAG) { VT == MVT::v8f32 || VT == MVT::v8f64 || VT == MVT::v16f32) && "Unexpected type in LowerFCOPYSIGN"); - MVT EltVT = VT.getScalarType(); - const fltSemantics &Sem = - EltVT == MVT::f64 ? APFloat::IEEEdouble() - : (IsF128 ? APFloat::IEEEquad() : APFloat::IEEEsingle()); + const fltSemantics &Sem = SelectionDAG::EVTToAPFloatSemantics(VT); // Perform all scalar logic operations as 16-byte vectors because there are no // scalar FP logic instructions in SSE. @@ -17516,7 +18679,7 @@ static SDValue LowerFCOPYSIGN(SDValue Op, SelectionDAG &DAG) { SDValue SignMask = DAG.getConstantFP( APFloat(Sem, APInt::getSignMask(EltSizeInBits)), dl, LogicVT); SDValue MagMask = DAG.getConstantFP( - APFloat(Sem, ~APInt::getSignMask(EltSizeInBits)), dl, LogicVT); + APFloat(Sem, APInt::getSignedMaxValue(EltSizeInBits)), dl, LogicVT); // First, clear all bits but the sign bit from the second operand (sign). if (IsFakeVector) @@ -17527,7 +18690,7 @@ static SDValue LowerFCOPYSIGN(SDValue Op, SelectionDAG &DAG) { // TODO: If we had general constant folding for FP logic ops, this check // wouldn't be necessary. SDValue MagBits; - if (ConstantFPSDNode *Op0CN = dyn_cast<ConstantFPSDNode>(Mag)) { + if (ConstantFPSDNode *Op0CN = isConstOrConstSplatFP(Mag)) { APFloat APF = Op0CN->getValueAPF(); APF.clearSign(); MagBits = DAG.getConstantFP(APF, dl, LogicVT); @@ -17572,7 +18735,8 @@ static SDValue getSETCC(X86::CondCode Cond, SDValue EFLAGS, const SDLoc &dl, // Check whether an OR'd tree is PTEST-able. static SDValue LowerVectorAllZeroTest(SDValue Op, ISD::CondCode CC, const X86Subtarget &Subtarget, - SelectionDAG &DAG) { + SelectionDAG &DAG, + SDValue &X86CC) { assert(Op.getOpcode() == ISD::OR && "Only check OR'd tree."); if (!Subtarget.hasSSE41()) @@ -17658,9 +18822,10 @@ static SDValue LowerVectorAllZeroTest(SDValue Op, ISD::CondCode CC, VecIns.push_back(DAG.getNode(ISD::OR, DL, TestVT, LHS, RHS)); } - SDValue Res = DAG.getNode(X86ISD::PTEST, DL, MVT::i32, - VecIns.back(), VecIns.back()); - return getSETCC(CC == ISD::SETEQ ? X86::COND_E : X86::COND_NE, Res, DL, DAG); + X86CC = DAG.getConstant(CC == ISD::SETEQ ? X86::COND_E : X86::COND_NE, + DL, MVT::i8); + return DAG.getNode(X86ISD::PTEST, DL, MVT::i32, + VecIns.back(), VecIns.back()); } /// return true if \c Op has a use that doesn't just read flags. @@ -17684,8 +18849,8 @@ static bool hasNonFlagsUse(SDValue Op) { /// Emit nodes that will be selected as "test Op0,Op0", or something /// equivalent. -SDValue X86TargetLowering::EmitTest(SDValue Op, unsigned X86CC, const SDLoc &dl, - SelectionDAG &DAG) const { +static SDValue EmitTest(SDValue Op, unsigned X86CC, const SDLoc &dl, + SelectionDAG &DAG, const X86Subtarget &Subtarget) { // CF and OF aren't always set the way we want. Determine which // of these we need. 
bool NeedCF = false; @@ -17728,159 +18893,26 @@ SDValue X86TargetLowering::EmitTest(SDValue Op, unsigned X86CC, const SDLoc &dl, unsigned Opcode = 0; unsigned NumOperands = 0; - // Truncate operations may prevent the merge of the SETCC instruction - // and the arithmetic instruction before it. Attempt to truncate the operands - // of the arithmetic instruction and use a reduced bit-width instruction. - bool NeedTruncation = false; SDValue ArithOp = Op; - if (Op->getOpcode() == ISD::TRUNCATE && Op->hasOneUse()) { - SDValue Arith = Op->getOperand(0); - // Both the trunc and the arithmetic op need to have one user each. - if (Arith->hasOneUse()) - switch (Arith.getOpcode()) { - default: break; - case ISD::ADD: - case ISD::SUB: - case ISD::AND: - case ISD::OR: - case ISD::XOR: { - NeedTruncation = true; - ArithOp = Arith; - } - } - } - - // Sometimes flags can be set either with an AND or with an SRL/SHL - // instruction. SRL/SHL variant should be preferred for masks longer than this - // number of bits. - const int ShiftToAndMaxMaskWidth = 32; - const bool ZeroCheck = (X86CC == X86::COND_E || X86CC == X86::COND_NE); // NOTICE: In the code below we use ArithOp to hold the arithmetic operation // which may be the result of a CAST. We use the variable 'Op', which is the // non-casted variable when we check for possible users. switch (ArithOp.getOpcode()) { - case ISD::ADD: - // We only want to rewrite this as a target-specific node with attached - // flags if there is a reasonable chance of either using that to do custom - // instructions selection that can fold some of the memory operands, or if - // only the flags are used. If there are other uses, leave the node alone - // and emit a test instruction. - for (SDNode::use_iterator UI = Op.getNode()->use_begin(), - UE = Op.getNode()->use_end(); UI != UE; ++UI) - if (UI->getOpcode() != ISD::CopyToReg && - UI->getOpcode() != ISD::SETCC && - UI->getOpcode() != ISD::STORE) - goto default_case; - - if (auto *C = dyn_cast<ConstantSDNode>(ArithOp.getOperand(1))) { - // An add of one will be selected as an INC. - if (C->isOne() && - (!Subtarget.slowIncDec() || - DAG.getMachineFunction().getFunction().optForSize())) { - Opcode = X86ISD::INC; - NumOperands = 1; - break; - } - - // An add of negative one (subtract of one) will be selected as a DEC. - if (C->isAllOnesValue() && - (!Subtarget.slowIncDec() || - DAG.getMachineFunction().getFunction().optForSize())) { - Opcode = X86ISD::DEC; - NumOperands = 1; - break; - } - } - - // Otherwise use a regular EFLAGS-setting add. - Opcode = X86ISD::ADD; - NumOperands = 2; - break; - case ISD::SHL: - case ISD::SRL: - // If we have a constant logical shift that's only used in a comparison - // against zero turn it into an equivalent AND. This allows turning it into - // a TEST instruction later. - if (ZeroCheck && Op->hasOneUse() && - isa<ConstantSDNode>(Op->getOperand(1)) && !hasNonFlagsUse(Op)) { - EVT VT = Op.getValueType(); - unsigned BitWidth = VT.getSizeInBits(); - unsigned ShAmt = Op->getConstantOperandVal(1); - if (ShAmt >= BitWidth) // Avoid undefined shifts. - break; - APInt Mask = ArithOp.getOpcode() == ISD::SRL - ? 
APInt::getHighBitsSet(BitWidth, BitWidth - ShAmt) - : APInt::getLowBitsSet(BitWidth, BitWidth - ShAmt); - if (!Mask.isSignedIntN(ShiftToAndMaxMaskWidth)) - break; - Op = DAG.getNode(ISD::AND, dl, VT, Op->getOperand(0), - DAG.getConstant(Mask, dl, VT)); - } - break; - case ISD::AND: // If the primary 'and' result isn't used, don't bother using X86ISD::AND, - // because a TEST instruction will be better. However, AND should be - // preferred if the instruction can be combined into ANDN. - if (!hasNonFlagsUse(Op)) { - SDValue Op0 = ArithOp->getOperand(0); - SDValue Op1 = ArithOp->getOperand(1); - EVT VT = ArithOp.getValueType(); - bool isAndn = isBitwiseNot(Op0) || isBitwiseNot(Op1); - bool isLegalAndnType = VT == MVT::i32 || VT == MVT::i64; - bool isProperAndn = isAndn && isLegalAndnType && Subtarget.hasBMI(); - - // If we cannot select an ANDN instruction, check if we can replace - // AND+IMM64 with a shift before giving up. This is possible for masks - // like 0xFF000000 or 0x00FFFFFF and if we care only about the zero flag. - if (!isProperAndn) { - if (!ZeroCheck) - break; - - assert(!isa<ConstantSDNode>(Op0) && "AND node isn't canonicalized"); - auto *CN = dyn_cast<ConstantSDNode>(Op1); - if (!CN) - break; - - const APInt &Mask = CN->getAPIntValue(); - if (Mask.isSignedIntN(ShiftToAndMaxMaskWidth)) - break; // Prefer TEST instruction. - - unsigned BitWidth = Mask.getBitWidth(); - unsigned LeadingOnes = Mask.countLeadingOnes(); - unsigned TrailingZeros = Mask.countTrailingZeros(); - - if (LeadingOnes + TrailingZeros == BitWidth) { - assert(TrailingZeros < VT.getSizeInBits() && - "Shift amount should be less than the type width"); - MVT ShTy = getScalarShiftAmountTy(DAG.getDataLayout(), VT); - SDValue ShAmt = DAG.getConstant(TrailingZeros, dl, ShTy); - Op = DAG.getNode(ISD::SRL, dl, VT, Op0, ShAmt); - break; - } - - unsigned LeadingZeros = Mask.countLeadingZeros(); - unsigned TrailingOnes = Mask.countTrailingOnes(); - - if (LeadingZeros + TrailingOnes == BitWidth) { - assert(LeadingZeros < VT.getSizeInBits() && - "Shift amount should be less than the type width"); - MVT ShTy = getScalarShiftAmountTy(DAG.getDataLayout(), VT); - SDValue ShAmt = DAG.getConstant(LeadingZeros, dl, ShTy); - Op = DAG.getNode(ISD::SHL, dl, VT, Op0, ShAmt); - break; - } + // because a TEST instruction will be better. + if (!hasNonFlagsUse(Op)) + break; - break; - } - } LLVM_FALLTHROUGH; + case ISD::ADD: case ISD::SUB: case ISD::OR: case ISD::XOR: - // Similar to ISD::ADD above, check if the uses will preclude useful - // lowering of the target-specific node. + // Transform to an x86-specific ALU node with flags if there is a chance of + // using an RMW op or only the flags are used. Otherwise, leave + // the node alone and emit a 'test' instruction. for (SDNode::use_iterator UI = Op.getNode()->use_begin(), UE = Op.getNode()->use_end(); UI != UE; ++UI) if (UI->getOpcode() != ISD::CopyToReg && @@ -17891,6 +18923,7 @@ SDValue X86TargetLowering::EmitTest(SDValue Op, unsigned X86CC, const SDLoc &dl, // Otherwise use a regular EFLAGS-setting instruction. 
switch (ArithOp.getOpcode()) { default: llvm_unreachable("unexpected operator!"); + case ISD::ADD: Opcode = X86ISD::ADD; break; case ISD::SUB: Opcode = X86ISD::SUB; break; case ISD::XOR: Opcode = X86ISD::XOR; break; case ISD::AND: Opcode = X86ISD::AND; break; @@ -17901,8 +18934,6 @@ SDValue X86TargetLowering::EmitTest(SDValue Op, unsigned X86CC, const SDLoc &dl, break; case X86ISD::ADD: case X86ISD::SUB: - case X86ISD::INC: - case X86ISD::DEC: case X86ISD::OR: case X86ISD::XOR: case X86ISD::AND: @@ -17912,36 +18943,6 @@ SDValue X86TargetLowering::EmitTest(SDValue Op, unsigned X86CC, const SDLoc &dl, break; } - // If we found that truncation is beneficial, perform the truncation and - // update 'Op'. - if (NeedTruncation) { - EVT VT = Op.getValueType(); - SDValue WideVal = Op->getOperand(0); - EVT WideVT = WideVal.getValueType(); - unsigned ConvertedOp = 0; - // Use a target machine opcode to prevent further DAGCombine - // optimizations that may separate the arithmetic operations - // from the setcc node. - switch (WideVal.getOpcode()) { - default: break; - case ISD::ADD: ConvertedOp = X86ISD::ADD; break; - case ISD::SUB: ConvertedOp = X86ISD::SUB; break; - case ISD::AND: ConvertedOp = X86ISD::AND; break; - case ISD::OR: ConvertedOp = X86ISD::OR; break; - case ISD::XOR: ConvertedOp = X86ISD::XOR; break; - } - - if (ConvertedOp) { - const TargetLowering &TLI = DAG.getTargetLoweringInfo(); - if (TLI.isOperationLegal(WideVal.getOpcode(), WideVT)) { - SDValue V0 = DAG.getNode(ISD::TRUNCATE, dl, VT, WideVal.getOperand(0)); - SDValue V1 = DAG.getNode(ISD::TRUNCATE, dl, VT, WideVal.getOperand(1)); - SDVTList VTs = DAG.getVTList(Op.getValueType(), MVT::i32); - Op = DAG.getNode(ConvertedOp, dl, VTs, V0, V1); - } - } - } - if (Opcode == 0) { // Emit a CMP with 0, which is the TEST pattern. return DAG.getNode(X86ISD::CMP, dl, MVT::i32, Op, @@ -17960,17 +18961,17 @@ SDValue X86TargetLowering::EmitTest(SDValue Op, unsigned X86CC, const SDLoc &dl, SDValue X86TargetLowering::EmitCmp(SDValue Op0, SDValue Op1, unsigned X86CC, const SDLoc &dl, SelectionDAG &DAG) const { if (isNullConstant(Op1)) - return EmitTest(Op0, X86CC, dl, DAG); - - assert(!(isa<ConstantSDNode>(Op1) && Op0.getValueType() == MVT::i1) && - "Unexpected comparison operation for MVT::i1 operands"); + return EmitTest(Op0, X86CC, dl, DAG, Subtarget); if ((Op0.getValueType() == MVT::i8 || Op0.getValueType() == MVT::i16 || Op0.getValueType() == MVT::i32 || Op0.getValueType() == MVT::i64)) { // Only promote the compare up to I32 if it is a 16 bit operation // with an immediate. 16 bit immediates are to be avoided. 
- if ((Op0.getValueType() == MVT::i16 && - (isa<ConstantSDNode>(Op0) || isa<ConstantSDNode>(Op1))) && + if (Op0.getValueType() == MVT::i16 && + ((isa<ConstantSDNode>(Op0) && + !cast<ConstantSDNode>(Op0)->getAPIntValue().isSignedIntN(8)) || + (isa<ConstantSDNode>(Op1) && + !cast<ConstantSDNode>(Op1)->getAPIntValue().isSignedIntN(8))) && !DAG.getMachineFunction().getFunction().optForMinSize() && !Subtarget.isAtom()) { unsigned ExtendOp = @@ -17983,6 +18984,7 @@ SDValue X86TargetLowering::EmitCmp(SDValue Op0, SDValue Op1, unsigned X86CC, SDValue Sub = DAG.getNode(X86ISD::SUB, dl, VTs, Op0, Op1); return SDValue(Sub.getNode(), 1); } + assert(Op0.getValueType().isFloatingPoint() && "Unexpected VT!"); return DAG.getNode(X86ISD::CMP, dl, MVT::i32, Op0, Op1); } @@ -18103,39 +19105,11 @@ unsigned X86TargetLowering::combineRepeatedFPDivisors() const { return 2; } -/// Create a BT (Bit Test) node - Test bit \p BitNo in \p Src and set condition -/// according to equal/not-equal condition code \p CC. -static SDValue getBitTestCondition(SDValue Src, SDValue BitNo, ISD::CondCode CC, - const SDLoc &dl, SelectionDAG &DAG) { - // If Src is i8, promote it to i32 with any_extend. There is no i8 BT - // instruction. Since the shift amount is in-range-or-undefined, we know - // that doing a bittest on the i32 value is ok. We extend to i32 because - // the encoding for the i16 version is larger than the i32 version. - // Also promote i16 to i32 for performance / code size reason. - if (Src.getValueType() == MVT::i8 || Src.getValueType() == MVT::i16) - Src = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i32, Src); - - // See if we can use the 32-bit instruction instead of the 64-bit one for a - // shorter encoding. Since the former takes the modulo 32 of BitNo and the - // latter takes the modulo 64, this is only valid if the 5th bit of BitNo is - // known to be zero. - if (Src.getValueType() == MVT::i64 && - DAG.MaskedValueIsZero(BitNo, APInt(BitNo.getValueSizeInBits(), 32))) - Src = DAG.getNode(ISD::TRUNCATE, dl, MVT::i32, Src); - - // If the operand types disagree, extend the shift amount to match. Since - // BT ignores high bits (like shifts) we can use anyextend. - if (Src.getValueType() != BitNo.getValueType()) - BitNo = DAG.getNode(ISD::ANY_EXTEND, dl, Src.getValueType(), BitNo); - - SDValue BT = DAG.getNode(X86ISD::BT, dl, MVT::i32, Src, BitNo); - X86::CondCode Cond = CC == ISD::SETEQ ? X86::COND_AE : X86::COND_B; - return getSETCC(Cond, BT, dl , DAG); -} - /// Result of 'and' is compared against zero. Change to a BT node if possible. +/// Returns the BT node and the condition code needed to use it. 
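Every pattern LowerAndToBT now returns a BT node for boils down to testing a single bit of the source; a scalar sketch of the equivalences (illustrative only):

    #include <cassert>
    #include <cstdint>

    // All three recognized forms test bit N of X, which is what BT puts in CF:
    //   (X & (1 << N)) == 0
    //   ((X >> N) & 1) != 0
    //   (X & Pow2Imm) != 0 with N = log2(Pow2Imm)
    static bool bitTest(uint64_t X, unsigned N) { return (X >> N) & 1; }

    int main() {
      uint64_t X = 0x00F0;
      assert(((X & (1ull << 5)) != 0) == bitTest(X, 5));
      assert((((X >> 7) & 1) != 0) == bitTest(X, 7));
      // An immediate too wide for TEST, the case where BT is preferred.
      assert(((X & 0x0000000100000000ull) != 0) == bitTest(X, 32));
      return 0;
    }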
static SDValue LowerAndToBT(SDValue And, ISD::CondCode CC, - const SDLoc &dl, SelectionDAG &DAG) { + const SDLoc &dl, SelectionDAG &DAG, + SDValue &X86CC) { assert(And.getOpcode() == ISD::AND && "Expected AND node!"); SDValue Op0 = And.getOperand(0); SDValue Op1 = And.getOperand(1); @@ -18144,7 +19118,7 @@ static SDValue LowerAndToBT(SDValue And, ISD::CondCode CC, if (Op1.getOpcode() == ISD::TRUNCATE) Op1 = Op1.getOperand(0); - SDValue LHS, RHS; + SDValue Src, BitNo; if (Op1.getOpcode() == ISD::SHL) std::swap(Op0, Op1); if (Op0.getOpcode() == ISD::SHL) { @@ -18154,13 +19128,12 @@ static SDValue LowerAndToBT(SDValue And, ISD::CondCode CC, unsigned BitWidth = Op0.getValueSizeInBits(); unsigned AndBitWidth = And.getValueSizeInBits(); if (BitWidth > AndBitWidth) { - KnownBits Known; - DAG.computeKnownBits(Op0, Known); + KnownBits Known = DAG.computeKnownBits(Op0); if (Known.countMinLeadingZeros() < BitWidth - AndBitWidth) return SDValue(); } - LHS = Op1; - RHS = Op0.getOperand(1); + Src = Op1; + BitNo = Op0.getOperand(1); } } else if (Op1.getOpcode() == ISD::Constant) { ConstantSDNode *AndRHS = cast<ConstantSDNode>(Op1); @@ -18168,24 +19141,49 @@ static SDValue LowerAndToBT(SDValue And, ISD::CondCode CC, SDValue AndLHS = Op0; if (AndRHSVal == 1 && AndLHS.getOpcode() == ISD::SRL) { - LHS = AndLHS.getOperand(0); - RHS = AndLHS.getOperand(1); + Src = AndLHS.getOperand(0); + BitNo = AndLHS.getOperand(1); } else { // Use BT if the immediate can't be encoded in a TEST instruction or we // are optimizing for size and the immedaite won't fit in a byte. bool OptForSize = DAG.getMachineFunction().getFunction().optForSize(); if ((!isUInt<32>(AndRHSVal) || (OptForSize && !isUInt<8>(AndRHSVal))) && isPowerOf2_64(AndRHSVal)) { - LHS = AndLHS; - RHS = DAG.getConstant(Log2_64_Ceil(AndRHSVal), dl, LHS.getValueType()); + Src = AndLHS; + BitNo = DAG.getConstant(Log2_64_Ceil(AndRHSVal), dl, + Src.getValueType()); } } } - if (LHS.getNode()) - return getBitTestCondition(LHS, RHS, CC, dl, DAG); + // No patterns found, give up. + if (!Src.getNode()) + return SDValue(); - return SDValue(); + // If Src is i8, promote it to i32 with any_extend. There is no i8 BT + // instruction. Since the shift amount is in-range-or-undefined, we know + // that doing a bittest on the i32 value is ok. We extend to i32 because + // the encoding for the i16 version is larger than the i32 version. + // Also promote i16 to i32 for performance / code size reason. + if (Src.getValueType() == MVT::i8 || Src.getValueType() == MVT::i16) + Src = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i32, Src); + + // See if we can use the 32-bit instruction instead of the 64-bit one for a + // shorter encoding. Since the former takes the modulo 32 of BitNo and the + // latter takes the modulo 64, this is only valid if the 5th bit of BitNo is + // known to be zero. + if (Src.getValueType() == MVT::i64 && + DAG.MaskedValueIsZero(BitNo, APInt(BitNo.getValueSizeInBits(), 32))) + Src = DAG.getNode(ISD::TRUNCATE, dl, MVT::i32, Src); + + // If the operand types disagree, extend the shift amount to match. Since + // BT ignores high bits (like shifts) we can use anyextend. + if (Src.getValueType() != BitNo.getValueType()) + BitNo = DAG.getNode(ISD::ANY_EXTEND, dl, Src.getValueType(), BitNo); + + X86CC = DAG.getConstant(CC == ISD::SETEQ ? 
X86::COND_AE : X86::COND_B, + dl, MVT::i8); + return DAG.getNode(X86ISD::BT, dl, MVT::i32, Src, BitNo); } /// Turns an ISD::CondCode into a value suitable for SSE floating-point mask @@ -18292,34 +19290,32 @@ static SDValue LowerIntVSETCC_AVX512(SDValue Op, SelectionDAG &DAG) { return DAG.getSetCC(dl, VT, Op0, Op1, SetCCOpcode); } -/// Try to turn a VSETULT into a VSETULE by modifying its second -/// operand \p Op1. If non-trivial (for example because it's not constant) -/// return an empty value. -static SDValue ChangeVSETULTtoVSETULE(const SDLoc &dl, SDValue Op1, - SelectionDAG &DAG) { - BuildVectorSDNode *BV = dyn_cast<BuildVectorSDNode>(Op1.getNode()); +/// Given a simple buildvector constant, return a new vector constant with each +/// element decremented. If decrementing would result in underflow or this +/// is not a simple vector constant, return an empty value. +static SDValue decrementVectorConstant(SDValue V, SelectionDAG &DAG) { + auto *BV = dyn_cast<BuildVectorSDNode>(V.getNode()); if (!BV) return SDValue(); - MVT VT = Op1.getSimpleValueType(); - MVT EVT = VT.getVectorElementType(); - unsigned n = VT.getVectorNumElements(); - SmallVector<SDValue, 8> ULTOp1; - - for (unsigned i = 0; i < n; ++i) { - ConstantSDNode *Elt = dyn_cast<ConstantSDNode>(BV->getOperand(i)); - if (!Elt || Elt->isOpaque() || Elt->getSimpleValueType(0) != EVT) + MVT VT = V.getSimpleValueType(); + MVT EltVT = VT.getVectorElementType(); + unsigned NumElts = VT.getVectorNumElements(); + SmallVector<SDValue, 8> NewVecC; + SDLoc DL(V); + for (unsigned i = 0; i < NumElts; ++i) { + auto *Elt = dyn_cast<ConstantSDNode>(BV->getOperand(i)); + if (!Elt || Elt->isOpaque() || Elt->getSimpleValueType(0) != EltVT) return SDValue(); // Avoid underflow. - APInt Val = Elt->getAPIntValue(); - if (Val == 0) + if (Elt->getAPIntValue().isNullValue()) return SDValue(); - ULTOp1.push_back(DAG.getConstant(Val - 1, dl, EVT)); + NewVecC.push_back(DAG.getConstant(Elt->getAPIntValue() - 1, DL, EltVT)); } - return DAG.getBuildVector(VT, dl, ULTOp1); + return DAG.getBuildVector(VT, DL, NewVecC); } /// As another special case, use PSUBUS[BW] when it's profitable. E.g. for @@ -18348,7 +19344,7 @@ static SDValue LowerVSETCCWithSUBUS(SDValue Op0, SDValue Op1, MVT VT, // Only do this pre-AVX since vpcmp* is no longer destructive. if (Subtarget.hasAVX()) return SDValue(); - SDValue ULEOp1 = ChangeVSETULTtoVSETULE(dl, Op1, DAG); + SDValue ULEOp1 = decrementVectorConstant(Op1, DAG); if (!ULEOp1) return SDValue(); Op1 = ULEOp1; @@ -18362,9 +19358,9 @@ static SDValue LowerVSETCCWithSUBUS(SDValue Op0, SDValue Op1, MVT VT, break; } - SDValue Result = DAG.getNode(X86ISD::SUBUS, dl, VT, Op0, Op1); + SDValue Result = DAG.getNode(ISD::USUBSAT, dl, VT, Op0, Op1); return DAG.getNode(X86ISD::PCMPEQ, dl, VT, Result, - getZeroVector(VT, Subtarget, DAG, dl)); + DAG.getConstant(0, dl, VT)); } static SDValue LowerVSETCC(SDValue Op, const X86Subtarget &Subtarget, @@ -18527,13 +19523,26 @@ static SDValue LowerVSETCC(SDValue Op, const X86Subtarget &Subtarget, bool FlipSigns = ISD::isUnsignedIntSetCC(Cond) && !(DAG.SignBitIsZero(Op0) && DAG.SignBitIsZero(Op1)); - // Special case: Use min/max operations for unsigned compares. We only want - // to do this for unsigned compares if we need to flip signs or if it allows - // use to avoid an invert. + // Special case: Use min/max operations for unsigned compares. 
const TargetLowering &TLI = DAG.getTargetLoweringInfo(); if (ISD::isUnsignedIntSetCC(Cond) && (FlipSigns || ISD::isTrueWhenEqual(Cond)) && TLI.isOperationLegal(ISD::UMIN, VT)) { + // If we have a constant operand, increment/decrement it and change the + // condition to avoid an invert. + // TODO: This could be extended to handle a non-splat constant by checking + // that each element of the constant is not the max/null value. + APInt C; + if (Cond == ISD::SETUGT && isConstantSplat(Op1, C) && !C.isMaxValue()) { + // X > C --> X >= (C+1) --> X == umax(X, C+1) + Op1 = DAG.getConstant(C + 1, dl, VT); + Cond = ISD::SETUGE; + } + if (Cond == ISD::SETULT && isConstantSplat(Op1, C) && !C.isNullValue()) { + // X < C --> X <= (C-1) --> X == umin(X, C-1) + Op1 = DAG.getConstant(C - 1, dl, VT); + Cond = ISD::SETULE; + } bool Invert = false; unsigned Opc; switch (Cond) { @@ -18577,23 +19586,21 @@ static SDValue LowerVSETCC(SDValue Op, const X86Subtarget &Subtarget, if (Opc == X86ISD::PCMPGT && !Subtarget.hasSSE42()) { assert(Subtarget.hasSSE2() && "Don't know how to lower!"); - // First cast everything to the right type. - Op0 = DAG.getBitcast(MVT::v4i32, Op0); - Op1 = DAG.getBitcast(MVT::v4i32, Op1); - // Since SSE has no unsigned integer comparisons, we need to flip the sign // bits of the inputs before performing those operations. The lower // compare is always unsigned. SDValue SB; if (FlipSigns) { - SB = DAG.getConstant(0x80000000U, dl, MVT::v4i32); + SB = DAG.getConstant(0x8000000080000000ULL, dl, MVT::v2i64); } else { - SDValue Sign = DAG.getConstant(0x80000000U, dl, MVT::i32); - SDValue Zero = DAG.getConstant(0x00000000U, dl, MVT::i32); - SB = DAG.getBuildVector(MVT::v4i32, dl, {Sign, Zero, Sign, Zero}); + SB = DAG.getConstant(0x0000000080000000ULL, dl, MVT::v2i64); } - Op0 = DAG.getNode(ISD::XOR, dl, MVT::v4i32, Op0, SB); - Op1 = DAG.getNode(ISD::XOR, dl, MVT::v4i32, Op1, SB); + Op0 = DAG.getNode(ISD::XOR, dl, MVT::v2i64, Op0, SB); + Op1 = DAG.getNode(ISD::XOR, dl, MVT::v2i64, Op1, SB); + + // Cast everything to the right type. + Op0 = DAG.getBitcast(MVT::v4i32, Op0); + Op1 = DAG.getBitcast(MVT::v4i32, Op1); // Emulate PCMPGTQ with (hi1 > hi2) | ((hi1 == hi2) & (lo1 > lo2)) SDValue GT = DAG.getNode(X86ISD::PCMPGT, dl, MVT::v4i32, Op0, Op1); @@ -18658,10 +19665,11 @@ static SDValue LowerVSETCC(SDValue Op, const X86Subtarget &Subtarget, return Result; } -// Try to select this as a KTEST+SETCC if possible. -static SDValue EmitKTEST(SDValue Op0, SDValue Op1, ISD::CondCode CC, - const SDLoc &dl, SelectionDAG &DAG, - const X86Subtarget &Subtarget) { +// Try to select this as a KORTEST+SETCC if possible. +static SDValue EmitKORTEST(SDValue Op0, SDValue Op1, ISD::CondCode CC, + const SDLoc &dl, SelectionDAG &DAG, + const X86Subtarget &Subtarget, + SDValue &X86CC) { // Only support equality comparisons. if (CC != ISD::SETEQ && CC != ISD::SETNE) return SDValue(); @@ -18677,12 +19685,12 @@ static SDValue EmitKTEST(SDValue Op0, SDValue Op1, ISD::CondCode CC, !(Subtarget.hasBWI() && (VT == MVT::v32i1 || VT == MVT::v64i1))) return SDValue(); - X86::CondCode X86CC; + X86::CondCode X86Cond; if (isNullConstant(Op1)) { - X86CC = CC == ISD::SETEQ ? X86::COND_E : X86::COND_NE; + X86Cond = CC == ISD::SETEQ ? X86::COND_E : X86::COND_NE; } else if (isAllOnesConstant(Op1)) { // C flag is set for all ones. - X86CC = CC == ISD::SETEQ ? X86::COND_B : X86::COND_AE; + X86Cond = CC == ISD::SETEQ ? 
X86::COND_B : X86::COND_AE; } else return SDValue(); @@ -18694,70 +19702,87 @@ static SDValue EmitKTEST(SDValue Op0, SDValue Op1, ISD::CondCode CC, RHS = Op0.getOperand(1); } - SDValue KORTEST = DAG.getNode(X86ISD::KORTEST, dl, MVT::i32, LHS, RHS); - return getSETCC(X86CC, KORTEST, dl, DAG); + X86CC = DAG.getConstant(X86Cond, dl, MVT::i8); + return DAG.getNode(X86ISD::KORTEST, dl, MVT::i32, LHS, RHS); } -SDValue X86TargetLowering::LowerSETCC(SDValue Op, SelectionDAG &DAG) const { - - MVT VT = Op.getSimpleValueType(); - - if (VT.isVector()) return LowerVSETCC(Op, Subtarget, DAG); - - assert(VT == MVT::i8 && "SetCC type must be 8-bit integer"); - SDValue Op0 = Op.getOperand(0); - SDValue Op1 = Op.getOperand(1); - SDLoc dl(Op); - ISD::CondCode CC = cast<CondCodeSDNode>(Op.getOperand(2))->get(); - +/// Emit flags for the given setcc condition and operands. Also returns the +/// corresponding X86 condition code constant in X86CC. +SDValue X86TargetLowering::emitFlagsForSetcc(SDValue Op0, SDValue Op1, + ISD::CondCode CC, const SDLoc &dl, + SelectionDAG &DAG, + SDValue &X86CC) const { // Optimize to BT if possible. // Lower (X & (1 << N)) == 0 to BT(X, N). // Lower ((X >>u N) & 1) != 0 to BT(X, N). // Lower ((X >>s N) & 1) != 0 to BT(X, N). if (Op0.getOpcode() == ISD::AND && Op0.hasOneUse() && isNullConstant(Op1) && (CC == ISD::SETEQ || CC == ISD::SETNE)) { - if (SDValue NewSetCC = LowerAndToBT(Op0, CC, dl, DAG)) - return NewSetCC; + if (SDValue BT = LowerAndToBT(Op0, CC, dl, DAG, X86CC)) + return BT; } // Try to use PTEST for a tree ORs equality compared with 0. // TODO: We could do AND tree with all 1s as well by using the C flag. if (Op0.getOpcode() == ISD::OR && isNullConstant(Op1) && (CC == ISD::SETEQ || CC == ISD::SETNE)) { - if (SDValue NewSetCC = LowerVectorAllZeroTest(Op0, CC, Subtarget, DAG)) - return NewSetCC; + if (SDValue PTEST = LowerVectorAllZeroTest(Op0, CC, Subtarget, DAG, X86CC)) + return PTEST; } - // Try to lower using KTEST. - if (SDValue NewSetCC = EmitKTEST(Op0, Op1, CC, dl, DAG, Subtarget)) - return NewSetCC; + // Try to lower using KORTEST. + if (SDValue KORTEST = EmitKORTEST(Op0, Op1, CC, dl, DAG, Subtarget, X86CC)) + return KORTEST; // Look for X == 0, X == 1, X != 0, or X != 1. We can simplify some forms of // these. if ((isOneConstant(Op1) || isNullConstant(Op1)) && (CC == ISD::SETEQ || CC == ISD::SETNE)) { - // If the input is a setcc, then reuse the input setcc or use a new one with // the inverted condition. 
if (Op0.getOpcode() == X86ISD::SETCC) { - X86::CondCode CCode = (X86::CondCode)Op0.getConstantOperandVal(0); bool Invert = (CC == ISD::SETNE) ^ isNullConstant(Op1); - if (!Invert) - return Op0; - CCode = X86::GetOppositeBranchCondition(CCode); - return getSETCC(CCode, Op0.getOperand(1), dl, DAG); + X86CC = Op0.getOperand(0); + if (Invert) { + X86::CondCode CCode = (X86::CondCode)Op0.getConstantOperandVal(0); + CCode = X86::GetOppositeBranchCondition(CCode); + X86CC = DAG.getConstant(CCode, dl, MVT::i8); + } + + return Op0.getOperand(1); } } bool IsFP = Op1.getSimpleValueType().isFloatingPoint(); - X86::CondCode X86CC = TranslateX86CC(CC, dl, IsFP, Op0, Op1, DAG); - if (X86CC == X86::COND_INVALID) + X86::CondCode CondCode = TranslateX86CC(CC, dl, IsFP, Op0, Op1, DAG); + if (CondCode == X86::COND_INVALID) return SDValue(); - SDValue EFLAGS = EmitCmp(Op0, Op1, X86CC, dl, DAG); + SDValue EFLAGS = EmitCmp(Op0, Op1, CondCode, dl, DAG); EFLAGS = ConvertCmpIfNecessary(EFLAGS, DAG); - return getSETCC(X86CC, EFLAGS, dl, DAG); + X86CC = DAG.getConstant(CondCode, dl, MVT::i8); + return EFLAGS; +} + +SDValue X86TargetLowering::LowerSETCC(SDValue Op, SelectionDAG &DAG) const { + + MVT VT = Op.getSimpleValueType(); + + if (VT.isVector()) return LowerVSETCC(Op, Subtarget, DAG); + + assert(VT == MVT::i8 && "SetCC type must be 8-bit integer"); + SDValue Op0 = Op.getOperand(0); + SDValue Op1 = Op.getOperand(1); + SDLoc dl(Op); + ISD::CondCode CC = cast<CondCodeSDNode>(Op.getOperand(2))->get(); + + SDValue X86CC; + SDValue EFLAGS = emitFlagsForSetcc(Op0, Op1, CC, dl, DAG, X86CC); + if (!EFLAGS) + return SDValue(); + + return DAG.getNode(X86ISD::SETCC, dl, MVT::i8, X86CC, EFLAGS); } SDValue X86TargetLowering::LowerSETCCCARRY(SDValue Op, SelectionDAG &DAG) const { @@ -18781,6 +19806,70 @@ SDValue X86TargetLowering::LowerSETCCCARRY(SDValue Op, SelectionDAG &DAG) const return getSETCC(CC, Cmp.getValue(1), DL, DAG); } +// This function returns three things: the arithmetic computation itself +// (Value), an EFLAGS result (Overflow), and a condition code (Cond). The +// flag and the condition code define the case in which the arithmetic +// computation overflows. +static std::pair<SDValue, SDValue> +getX86XALUOOp(X86::CondCode &Cond, SDValue Op, SelectionDAG &DAG) { + assert(Op.getResNo() == 0 && "Unexpected result number!"); + SDValue Value, Overflow; + SDValue LHS = Op.getOperand(0); + SDValue RHS = Op.getOperand(1); + unsigned BaseOp = 0; + SDLoc DL(Op); + switch (Op.getOpcode()) { + default: llvm_unreachable("Unknown ovf instruction!"); + case ISD::SADDO: + BaseOp = X86ISD::ADD; + Cond = X86::COND_O; + break; + case ISD::UADDO: + BaseOp = X86ISD::ADD; + Cond = X86::COND_B; + break; + case ISD::SSUBO: + BaseOp = X86ISD::SUB; + Cond = X86::COND_O; + break; + case ISD::USUBO: + BaseOp = X86ISD::SUB; + Cond = X86::COND_B; + break; + case ISD::SMULO: + BaseOp = X86ISD::SMUL; + Cond = X86::COND_O; + break; + case ISD::UMULO: + BaseOp = X86ISD::UMUL; + Cond = X86::COND_O; + break; + } + + if (BaseOp) { + // Also sets EFLAGS. + SDVTList VTs = DAG.getVTList(Op.getValueType(), MVT::i32); + Value = DAG.getNode(BaseOp, DL, VTs, LHS, RHS); + Overflow = Value.getValue(1); + } + + return std::make_pair(Value, Overflow); +} + +static SDValue LowerXALUO(SDValue Op, SelectionDAG &DAG) { + // Lower the "add/sub/mul with overflow" instruction into a regular ins plus + // a "setcc" instruction that checks the overflow flag. 
The "brcond" lowering + // looks for this combo and may remove the "setcc" instruction if the "setcc" + // has only one use. + SDLoc DL(Op); + X86::CondCode Cond; + SDValue Value, Overflow; + std::tie(Value, Overflow) = getX86XALUOOp(Cond, Op, DAG); + + SDValue SetCC = getSETCC(Cond, Overflow, DL, DAG); + return DAG.getNode(ISD::MERGE_VALUES, DL, Op->getVTList(), Value, SetCC); +} + /// Return true if opcode is a X86 logical comparison. static bool isX86LogicalCmp(SDValue Op) { unsigned Opc = Op.getOpcode(); @@ -18789,12 +19878,8 @@ static bool isX86LogicalCmp(SDValue Op) { return true; if (Op.getResNo() == 1 && (Opc == X86ISD::ADD || Opc == X86ISD::SUB || Opc == X86ISD::ADC || - Opc == X86ISD::SBB || Opc == X86ISD::SMUL || - Opc == X86ISD::INC || Opc == X86ISD::DEC || Opc == X86ISD::OR || - Opc == X86ISD::XOR || Opc == X86ISD::AND)) - return true; - - if (Op.getResNo() == 2 && Opc == X86ISD::UMUL) + Opc == X86ISD::SBB || Opc == X86ISD::SMUL || Opc == X86ISD::UMUL || + Opc == X86ISD::OR || Opc == X86ISD::XOR || Opc == X86ISD::AND)) return true; return false; @@ -18845,7 +19930,7 @@ SDValue X86TargetLowering::LowerSELECT(SDValue Op, SelectionDAG &DAG) const { // of 3 logic instructions for size savings and potentially speed. // Unfortunately, there is no scalar form of VBLENDV. - // If either operand is a constant, don't try this. We can expect to + // If either operand is a +0.0 constant, don't try this. We can expect to // optimize away at least one of the logic instructions later in that // case, so that sequence would be faster than a variable blend. @@ -18853,13 +19938,10 @@ SDValue X86TargetLowering::LowerSELECT(SDValue Op, SelectionDAG &DAG) const { // uses XMM0 as the selection register. That may need just as many // instructions as the AND/ANDN/OR sequence due to register moves, so // don't bother. - - if (Subtarget.hasAVX() && - !isa<ConstantFPSDNode>(Op1) && !isa<ConstantFPSDNode>(Op2)) { - + if (Subtarget.hasAVX() && !isNullFPConstant(Op1) && + !isNullFPConstant(Op2)) { // Convert to vectors, do a VSELECT, and convert back to scalar. // All of the conversions should be optimized away. - MVT VecVT = VT == MVT::f32 ? 
MVT::v4f32 : MVT::v2f64; SDValue VOp1 = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, VecVT, Op1); SDValue VOp2 = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, VecVT, Op2); @@ -18919,16 +20001,6 @@ SDValue X86TargetLowering::LowerSELECT(SDValue Op, SelectionDAG &DAG) const { } } - if (VT == MVT::v4i1 || VT == MVT::v2i1) { - SDValue zeroConst = DAG.getIntPtrConstant(0, DL); - Op1 = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, MVT::v8i1, - DAG.getUNDEF(MVT::v8i1), Op1, zeroConst); - Op2 = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, MVT::v8i1, - DAG.getUNDEF(MVT::v8i1), Op2, zeroConst); - SDValue newSelect = DAG.getSelect(DL, MVT::v8i1, Cond, Op1, Op2); - return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, newSelect, zeroConst); - } - if (Cond.getOpcode() == ISD::SETCC) { if (SDValue NewCond = LowerSETCC(Cond, DAG)) { Cond = NewCond; @@ -18963,22 +20035,21 @@ SDValue X86TargetLowering::LowerSELECT(SDValue Op, SelectionDAG &DAG) const { // (select (x == 0), 0, -1) -> neg & sbb if (isNullConstant(Y) && (isAllOnesConstant(Op1) == (CondCode == X86::COND_NE))) { - SDVTList VTs = DAG.getVTList(CmpOp0.getValueType(), MVT::i32); SDValue Zero = DAG.getConstant(0, DL, CmpOp0.getValueType()); - SDValue Neg = DAG.getNode(X86ISD::SUB, DL, VTs, Zero, CmpOp0); - SDValue Res = DAG.getNode(X86ISD::SETCC_CARRY, DL, Op.getValueType(), - DAG.getConstant(X86::COND_B, DL, MVT::i8), - SDValue(Neg.getNode(), 1)); - return Res; + SDValue Cmp = DAG.getNode(X86ISD::CMP, DL, MVT::i32, Zero, CmpOp0); + SDVTList VTs = DAG.getVTList(Op.getValueType(), MVT::i32); + Zero = DAG.getConstant(0, DL, Op.getValueType()); + return DAG.getNode(X86ISD::SBB, DL, VTs, Zero, Zero, Cmp); } Cmp = DAG.getNode(X86ISD::CMP, DL, MVT::i32, CmpOp0, DAG.getConstant(1, DL, CmpOp0.getValueType())); Cmp = ConvertCmpIfNecessary(Cmp, DAG); + SDVTList VTs = DAG.getVTList(Op.getValueType(), MVT::i32); + SDValue Zero = DAG.getConstant(0, DL, Op.getValueType()); SDValue Res = // Res = 0 or -1. 
- DAG.getNode(X86ISD::SETCC_CARRY, DL, Op.getValueType(), - DAG.getConstant(X86::COND_B, DL, MVT::i8), Cmp); + DAG.getNode(X86ISD::SBB, DL, VTs, Zero, Zero, Cmp); if (isAllOnesConstant(Op1) != (CondCode == X86::COND_E)) Res = DAG.getNOT(DL, Res, Res.getValueType()); @@ -19055,34 +20126,10 @@ SDValue X86TargetLowering::LowerSELECT(SDValue Op, SelectionDAG &DAG) const { } } else if (CondOpcode == ISD::USUBO || CondOpcode == ISD::SSUBO || CondOpcode == ISD::UADDO || CondOpcode == ISD::SADDO || - ((CondOpcode == ISD::UMULO || CondOpcode == ISD::SMULO) && - Cond.getOperand(0).getValueType() != MVT::i8)) { - SDValue LHS = Cond.getOperand(0); - SDValue RHS = Cond.getOperand(1); - unsigned X86Opcode; - unsigned X86Cond; - SDVTList VTs; - switch (CondOpcode) { - case ISD::UADDO: X86Opcode = X86ISD::ADD; X86Cond = X86::COND_B; break; - case ISD::SADDO: X86Opcode = X86ISD::ADD; X86Cond = X86::COND_O; break; - case ISD::USUBO: X86Opcode = X86ISD::SUB; X86Cond = X86::COND_B; break; - case ISD::SSUBO: X86Opcode = X86ISD::SUB; X86Cond = X86::COND_O; break; - case ISD::UMULO: X86Opcode = X86ISD::UMUL; X86Cond = X86::COND_O; break; - case ISD::SMULO: X86Opcode = X86ISD::SMUL; X86Cond = X86::COND_O; break; - default: llvm_unreachable("unexpected overflowing operator"); - } - if (CondOpcode == ISD::UMULO) - VTs = DAG.getVTList(LHS.getValueType(), LHS.getValueType(), - MVT::i32); - else - VTs = DAG.getVTList(LHS.getValueType(), MVT::i32); - - SDValue X86Op = DAG.getNode(X86Opcode, DL, VTs, LHS, RHS); - - if (CondOpcode == ISD::UMULO) - Cond = X86Op.getValue(2); - else - Cond = X86Op.getValue(1); + CondOpcode == ISD::UMULO || CondOpcode == ISD::SMULO) { + SDValue Value; + X86::CondCode X86Cond; + std::tie(Value, Cond) = getX86XALUOOp(X86Cond, Cond.getValue(0), DAG); CC = DAG.getConstant(X86Cond, DL, MVT::i8); AddTest = false; @@ -19096,9 +20143,10 @@ SDValue X86TargetLowering::LowerSELECT(SDValue Op, SelectionDAG &DAG) const { // We know the result of AND is compared against zero. Try to match // it to BT. if (Cond.getOpcode() == ISD::AND && Cond.hasOneUse()) { - if (SDValue NewSetCC = LowerAndToBT(Cond, ISD::SETNE, DL, DAG)) { - CC = NewSetCC.getOperand(0); - Cond = NewSetCC.getOperand(1); + SDValue BTCC; + if (SDValue BT = LowerAndToBT(Cond, ISD::SETNE, DL, DAG, BTCC)) { + CC = BTCC; + Cond = BT; AddTest = false; } } @@ -19106,7 +20154,8 @@ SDValue X86TargetLowering::LowerSELECT(SDValue Op, SelectionDAG &DAG) const { if (AddTest) { CC = DAG.getConstant(X86::COND_NE, DL, MVT::i8); - Cond = EmitTest(Cond, X86::COND_NE, DL, DAG); + Cond = EmitCmp(Cond, DAG.getConstant(0, DL, Cond.getValueType()), + X86::COND_NE, DL, DAG); } // a < b ? -1 : 0 -> RES = ~setcc_carry @@ -19171,12 +20220,12 @@ static SDValue LowerSIGN_EXTEND_Mask(SDValue Op, unsigned NumElts = VT.getVectorNumElements(); - // Extend VT if the scalar type is v8/v16 and BWI is not supported. + // Extend VT if the scalar type is i8/i16 and BWI is not supported. MVT ExtVT = VT; if (!Subtarget.hasBWI() && VTElt.getSizeInBits() <= 16) { // If v16i32 is to be avoided, we'll need to split and concatenate. 
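The SBB-based rewrite above leans on a classic idiom: comparing 0 against x sets the carry flag exactly when x is non-zero, and a subtract-with-borrow of a register from itself then materializes 0 or -1 from that flag alone. A rough scalar model (function name invented, not from the patch):

#include <cstdint>

// select(x == 0, 0, -1) without a branch: CMP 0, x then SBB r, r.
int32_t zeroOrAllOnes(uint32_t X) {
  uint32_t CF = (0u < X) ? 1u : 0u;   // borrow out of 0 - X
  return (int32_t)(0u - 0u - CF);     // SBB r, r  ==  -CF
}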
if (NumElts == 16 && !Subtarget.canExtendTo512DQ()) - return SplitAndExtendv16i1(ISD::SIGN_EXTEND, VT, In, dl, DAG); + return SplitAndExtendv16i1(Op.getOpcode(), VT, In, dl, DAG); ExtVT = MVT::getVectorVT(MVT::i32, NumElts); } @@ -19195,10 +20244,10 @@ static SDValue LowerSIGN_EXTEND_Mask(SDValue Op, MVT WideEltVT = WideVT.getVectorElementType(); if ((Subtarget.hasDQI() && WideEltVT.getSizeInBits() >= 32) || (Subtarget.hasBWI() && WideEltVT.getSizeInBits() <= 16)) { - V = DAG.getNode(ISD::SIGN_EXTEND, dl, WideVT, In); + V = DAG.getNode(Op.getOpcode(), dl, WideVT, In); } else { - SDValue NegOne = getOnesVector(WideVT, DAG, dl); - SDValue Zero = getZeroVector(WideVT, Subtarget, DAG, dl); + SDValue NegOne = DAG.getConstant(-1, dl, WideVT); + SDValue Zero = DAG.getConstant(0, dl, WideVT); V = DAG.getSelect(dl, WideVT, In, NegOne, Zero); } @@ -19238,7 +20287,6 @@ static SDValue LowerEXTEND_VECTOR_INREG(SDValue Op, SDValue In = Op->getOperand(0); MVT VT = Op->getSimpleValueType(0); MVT InVT = In.getSimpleValueType(); - assert(VT.getSizeInBits() == InVT.getSizeInBits()); MVT SVT = VT.getVectorElementType(); MVT InSVT = InVT.getVectorElementType(); @@ -19249,70 +20297,100 @@ static SDValue LowerEXTEND_VECTOR_INREG(SDValue Op, if (InSVT != MVT::i32 && InSVT != MVT::i16 && InSVT != MVT::i8) return SDValue(); if (!(VT.is128BitVector() && Subtarget.hasSSE2()) && - !(VT.is256BitVector() && Subtarget.hasInt256()) && + !(VT.is256BitVector() && Subtarget.hasAVX()) && !(VT.is512BitVector() && Subtarget.hasAVX512())) return SDValue(); SDLoc dl(Op); + unsigned Opc = Op.getOpcode(); + unsigned NumElts = VT.getVectorNumElements(); // For 256-bit vectors, we only need the lower (128-bit) half of the input. // For 512-bit vectors, we need 128-bits or 256-bits. - if (VT.getSizeInBits() > 128) { + if (InVT.getSizeInBits() > 128) { // Input needs to be at least the same number of elements as output, and // at least 128-bits. - int InSize = InSVT.getSizeInBits() * VT.getVectorNumElements(); + int InSize = InSVT.getSizeInBits() * NumElts; In = extractSubVector(In, 0, DAG, dl, std::max(InSize, 128)); + InVT = In.getSimpleValueType(); } - assert((Op.getOpcode() != ISD::ZERO_EXTEND_VECTOR_INREG || - InVT == MVT::v64i8) && "Zero extend only for v64i8 input!"); - - // SSE41 targets can use the pmovsx* instructions directly for 128-bit results, + // SSE41 targets can use the pmov[sz]x* instructions directly for 128-bit results, // so are legal and shouldn't occur here. AVX2/AVX512 pmovsx* instructions still // need to be handled here for 256/512-bit results. if (Subtarget.hasInt256()) { assert(VT.getSizeInBits() > 128 && "Unexpected 128-bit vector extension"); - unsigned ExtOpc = Op.getOpcode() == ISD::SIGN_EXTEND_VECTOR_INREG ? - X86ISD::VSEXT : X86ISD::VZEXT; + + if (InVT.getVectorNumElements() != NumElts) + return DAG.getNode(Op.getOpcode(), dl, VT, In); + + // FIXME: Apparently we create inreg operations that could be regular + // extends. + unsigned ExtOpc = + Opc == ISD::SIGN_EXTEND_VECTOR_INREG ? ISD::SIGN_EXTEND + : ISD::ZERO_EXTEND; return DAG.getNode(ExtOpc, dl, VT, In); } + // pre-AVX2 256-bit extensions need to be split into 128-bit instructions. 
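The split performed in the next hunk extends the low input half directly and shuffles the high input elements down so a second 128-bit extend can handle them; the mask it builds is just an identity offset by half the element count. A standalone sketch of that loop (names invented):

#include <vector>

// Hi-half shuffle mask: HiMask[i] = HalfNumElts + i, remaining lanes undef (-1).
std::vector<int> buildHiHalfMask(int NumSrcElts, int HalfNumElts) {
  std::vector<int> HiMask(NumSrcElts, -1);
  for (int i = 0; i != HalfNumElts; ++i)
    HiMask[i] = HalfNumElts + i;
  return HiMask;
}

For a v8i16 -> v8i32 extend this yields {4, 5, 6, 7, -1, -1, -1, -1}: the upper four lanes moved to the bottom, ready for the second half extend.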
+ if (Subtarget.hasAVX()) { + assert(VT.is256BitVector() && "256-bit vector expected"); + int HalfNumElts = NumElts / 2; + MVT HalfVT = MVT::getVectorVT(SVT, HalfNumElts); + + unsigned NumSrcElts = InVT.getVectorNumElements(); + SmallVector<int, 16> HiMask(NumSrcElts, SM_SentinelUndef); + for (int i = 0; i != HalfNumElts; ++i) + HiMask[i] = HalfNumElts + i; + + SDValue Lo = DAG.getNode(Opc, dl, HalfVT, In); + SDValue Hi = DAG.getVectorShuffle(InVT, dl, In, DAG.getUNDEF(InVT), HiMask); + Hi = DAG.getNode(Opc, dl, HalfVT, Hi); + return DAG.getNode(ISD::CONCAT_VECTORS, dl, VT, Lo, Hi); + } + // We should only get here for sign extend. - assert(Op.getOpcode() == ISD::SIGN_EXTEND_VECTOR_INREG && - "Unexpected opcode!"); + assert(Opc == ISD::SIGN_EXTEND_VECTOR_INREG && "Unexpected opcode!"); + assert(VT.is128BitVector() && InVT.is128BitVector() && "Unexpected VTs"); // pre-SSE41 targets unpack lower lanes and then sign-extend using SRAI. SDValue Curr = In; - MVT CurrVT = InVT; + SDValue SignExt = Curr; // As SRAI is only available on i16/i32 types, we expand only up to i32 // and handle i64 separately. - while (CurrVT != VT && CurrVT.getVectorElementType() != MVT::i32) { - Curr = DAG.getNode(X86ISD::UNPCKL, dl, CurrVT, DAG.getUNDEF(CurrVT), Curr); - MVT CurrSVT = MVT::getIntegerVT(CurrVT.getScalarSizeInBits() * 2); - CurrVT = MVT::getVectorVT(CurrSVT, CurrVT.getVectorNumElements() / 2); - Curr = DAG.getBitcast(CurrVT, Curr); - } + if (InVT != MVT::v4i32) { + MVT DestVT = VT == MVT::v2i64 ? MVT::v4i32 : VT; - SDValue SignExt = Curr; - if (CurrVT != InVT) { - unsigned SignExtShift = - CurrVT.getScalarSizeInBits() - InSVT.getSizeInBits(); - SignExt = DAG.getNode(X86ISD::VSRAI, dl, CurrVT, Curr, + unsigned DestWidth = DestVT.getScalarSizeInBits(); + unsigned Scale = DestWidth / InSVT.getSizeInBits(); + + unsigned InNumElts = InVT.getVectorNumElements(); + unsigned DestElts = DestVT.getVectorNumElements(); + + // Build a shuffle mask that takes each input element and places it in the + // MSBs of the new element size. + SmallVector<int, 16> Mask(InNumElts, SM_SentinelUndef); + for (unsigned i = 0; i != DestElts; ++i) + Mask[i * Scale + (Scale - 1)] = i; + + Curr = DAG.getVectorShuffle(InVT, dl, In, In, Mask); + Curr = DAG.getBitcast(DestVT, Curr); + + unsigned SignExtShift = DestWidth - InSVT.getSizeInBits(); + SignExt = DAG.getNode(X86ISD::VSRAI, dl, DestVT, Curr, DAG.getConstant(SignExtShift, dl, MVT::i8)); } - if (CurrVT == VT) - return SignExt; - - if (VT == MVT::v2i64 && CurrVT == MVT::v4i32) { - SDValue Sign = DAG.getNode(X86ISD::VSRAI, dl, CurrVT, Curr, - DAG.getConstant(31, dl, MVT::i8)); - SDValue Ext = DAG.getVectorShuffle(CurrVT, dl, SignExt, Sign, {0, 4, 1, 5}); - return DAG.getBitcast(VT, Ext); + if (VT == MVT::v2i64) { + assert(Curr.getValueType() == MVT::v4i32 && "Unexpected input VT"); + SDValue Zero = DAG.getConstant(0, dl, MVT::v4i32); + SDValue Sign = DAG.getSetCC(dl, MVT::v4i32, Zero, Curr, ISD::SETGT); + SignExt = DAG.getVectorShuffle(MVT::v4i32, dl, SignExt, Sign, {0, 4, 1, 5}); + SignExt = DAG.getBitcast(VT, SignExt); } - return SDValue(); + return SignExt; } static SDValue LowerSIGN_EXTEND(SDValue Op, const X86Subtarget &Subtarget, @@ -19337,38 +20415,40 @@ static SDValue LowerSIGN_EXTEND(SDValue Op, const X86Subtarget &Subtarget, InVT.getVectorElementType() == MVT::i32) && "Unexpected element type"); + // Custom legalize v8i8->v8i64 on CPUs without avx512bw. 
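Per element, the shuffle-plus-VSRAI path earlier in this hunk is the standard shift trick: move the narrow value into the most significant bits of the wide lane, then arithmetic-shift it back so the sign bit is replicated. A scalar sketch for i8 -> i32, assuming the usual arithmetic right shift on signed types:

#include <cstdint>

int32_t signExtend8To32(uint8_t Byte) {
  uint32_t InMsbs = (uint32_t)Byte << 24;   // the shuffle's job
  return (int32_t)InMsbs >> 24;             // the VSRAI's job
}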
+ if (InVT == MVT::v8i8) { + if (!ExperimentalVectorWideningLegalization || VT != MVT::v8i64) + return SDValue(); + + In = DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(Op), + MVT::v16i8, In, DAG.getUNDEF(MVT::v8i8)); + return DAG.getNode(ISD::SIGN_EXTEND_VECTOR_INREG, dl, VT, In); + } + if (Subtarget.hasInt256()) - return DAG.getNode(X86ISD::VSEXT, dl, VT, In); + return Op; // Optimize vectors in AVX mode // Sign extend v8i16 to v8i32 and // v4i32 to v4i64 // // Divide input vector into two parts - // for v4i32 the shuffle mask will be { 0, 1, -1, -1} {2, 3, -1, -1} + // for v4i32 the high shuffle mask will be {2, 3, -1, -1} // use vpmovsx instruction to extend v4i32 -> v2i64; v8i16 -> v4i32 // concat the vectors to original VT - unsigned NumElems = InVT.getVectorNumElements(); - SDValue Undef = DAG.getUNDEF(InVT); - - SmallVector<int,8> ShufMask1(NumElems, -1); - for (unsigned i = 0; i != NumElems/2; ++i) - ShufMask1[i] = i; + MVT HalfVT = MVT::getVectorVT(VT.getVectorElementType(), + VT.getVectorNumElements() / 2); - SDValue OpLo = DAG.getVectorShuffle(InVT, dl, In, Undef, ShufMask1); + SDValue OpLo = DAG.getNode(ISD::SIGN_EXTEND_VECTOR_INREG, dl, HalfVT, In); - SmallVector<int,8> ShufMask2(NumElems, -1); + unsigned NumElems = InVT.getVectorNumElements(); + SmallVector<int,8> ShufMask(NumElems, -1); for (unsigned i = 0; i != NumElems/2; ++i) - ShufMask2[i] = i + NumElems/2; + ShufMask[i] = i + NumElems/2; - SDValue OpHi = DAG.getVectorShuffle(InVT, dl, In, Undef, ShufMask2); - - MVT HalfVT = MVT::getVectorVT(VT.getVectorElementType(), - VT.getVectorNumElements() / 2); - - OpLo = DAG.getSignExtendVectorInReg(OpLo, dl, HalfVT); - OpHi = DAG.getSignExtendVectorInReg(OpHi, dl, HalfVT); + SDValue OpHi = DAG.getVectorShuffle(InVT, dl, In, In, ShufMask); + OpHi = DAG.getNode(ISD::SIGN_EXTEND_VECTOR_INREG, dl, HalfVT, OpHi); return DAG.getNode(ISD::CONCAT_VECTORS, dl, VT, OpLo, OpHi); } @@ -19379,19 +20459,47 @@ static SDValue LowerStore(SDValue Op, const X86Subtarget &Subtarget, SDLoc dl(St); SDValue StoredVal = St->getValue(); - // Without AVX512DQ, we need to use a scalar type for v2i1/v4i1/v8i1 loads. - assert(StoredVal.getValueType().isVector() && - StoredVal.getValueType().getVectorElementType() == MVT::i1 && - StoredVal.getValueType().getVectorNumElements() <= 8 && - "Unexpected VT"); - assert(!St->isTruncatingStore() && "Expected non-truncating store"); - assert(Subtarget.hasAVX512() && !Subtarget.hasDQI() && - "Expected AVX512F without AVX512DQI"); + // Without AVX512DQ, we need to use a scalar type for v2i1/v4i1/v8i1 stores. 
+ if (StoredVal.getValueType().isVector() && + StoredVal.getValueType().getVectorElementType() == MVT::i1) { + assert(StoredVal.getValueType().getVectorNumElements() <= 8 && + "Unexpected VT"); + assert(!St->isTruncatingStore() && "Expected non-truncating store"); + assert(Subtarget.hasAVX512() && !Subtarget.hasDQI() && + "Expected AVX512F without AVX512DQI"); + + StoredVal = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, MVT::v16i1, + DAG.getUNDEF(MVT::v16i1), StoredVal, + DAG.getIntPtrConstant(0, dl)); + StoredVal = DAG.getBitcast(MVT::i16, StoredVal); + StoredVal = DAG.getNode(ISD::TRUNCATE, dl, MVT::i8, StoredVal); + + return DAG.getStore(St->getChain(), dl, StoredVal, St->getBasePtr(), + St->getPointerInfo(), St->getAlignment(), + St->getMemOperand()->getFlags()); + } + + if (St->isTruncatingStore()) + return SDValue(); - StoredVal = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, MVT::v8i1, - DAG.getUNDEF(MVT::v8i1), StoredVal, + MVT StoreVT = StoredVal.getSimpleValueType(); + assert(StoreVT.isVector() && StoreVT.getSizeInBits() == 64 && + "Unexpected VT"); + if (DAG.getTargetLoweringInfo().getTypeAction(*DAG.getContext(), StoreVT) != + TargetLowering::TypeWidenVector) + return SDValue(); + + // Widen the vector, cast to a v2x64 type, extract the single 64-bit element + // and store it. + MVT WideVT = MVT::getVectorVT(StoreVT.getVectorElementType(), + StoreVT.getVectorNumElements() * 2); + StoredVal = DAG.getNode(ISD::CONCAT_VECTORS, dl, WideVT, StoredVal, + DAG.getUNDEF(StoreVT)); + MVT StVT = Subtarget.is64Bit() && StoreVT.isInteger() ? MVT::i64 : MVT::f64; + MVT CastVT = MVT::getVectorVT(StVT, 2); + StoredVal = DAG.getBitcast(CastVT, StoredVal); + StoredVal = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, StVT, StoredVal, DAG.getIntPtrConstant(0, dl)); - StoredVal = DAG.getBitcast(MVT::i8, StoredVal); return DAG.getStore(St->getChain(), dl, StoredVal, St->getBasePtr(), St->getPointerInfo(), St->getAlignment(), @@ -19400,7 +20508,7 @@ static SDValue LowerStore(SDValue Op, const X86Subtarget &Subtarget, // Lower vector extended loads using a shuffle. If SSSE3 is not available we // may emit an illegal shuffle but the expansion is still better than scalar -// code. We generate X86ISD::VSEXT for SEXTLOADs if it's available, otherwise +// code. We generate sext/sext_invec for SEXTLOADs if it's available, otherwise // we'll emit a shuffle and a arithmetic shift. // FIXME: Is the expansion actually better than scalar code? It doesn't seem so. // TODO: It is possible to support ZExt by zeroing the undef values during @@ -19408,16 +20516,16 @@ static SDValue LowerStore(SDValue Op, const X86Subtarget &Subtarget, static SDValue LowerLoad(SDValue Op, const X86Subtarget &Subtarget, SelectionDAG &DAG) { MVT RegVT = Op.getSimpleValueType(); - assert(RegVT.isVector() && "We only custom lower vector sext loads."); + assert(RegVT.isVector() && "We only custom lower vector loads."); assert(RegVT.isInteger() && - "We only custom lower integer vector sext loads."); + "We only custom lower integer vector loads."); LoadSDNode *Ld = cast<LoadSDNode>(Op.getNode()); SDLoc dl(Ld); EVT MemVT = Ld->getMemoryVT(); // Without AVX512DQ, we need to use a scalar type for v2i1/v4i1/v8i1 loads. 
- if (RegVT.isVector() && RegVT.getVectorElementType() == MVT::i1) { + if (RegVT.getVectorElementType() == MVT::i1) { assert(EVT(RegVT) == MemVT && "Expected non-extending load"); assert(RegVT.getVectorNumElements() <= 8 && "Unexpected VT"); assert(Subtarget.hasAVX512() && !Subtarget.hasDQI() && @@ -19429,12 +20537,12 @@ static SDValue LowerLoad(SDValue Op, const X86Subtarget &Subtarget, // Replace chain users with the new chain. assert(NewLd->getNumValues() == 2 && "Loads must carry a chain!"); - DAG.ReplaceAllUsesOfValueWith(SDValue(Ld, 1), NewLd.getValue(1)); - SDValue Extract = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, RegVT, - DAG.getBitcast(MVT::v8i1, NewLd), - DAG.getIntPtrConstant(0, dl)); - return DAG.getMergeValues({Extract, NewLd.getValue(1)}, dl); + SDValue Val = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, NewLd); + Val = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, RegVT, + DAG.getBitcast(MVT::v16i1, Val), + DAG.getIntPtrConstant(0, dl)); + return DAG.getMergeValues({Val, NewLd.getValue(1)}, dl); } // Nothing useful we can do without SSE2 shuffles. @@ -19490,10 +20598,10 @@ static SDValue LowerLoad(SDValue Op, const X86Subtarget &Subtarget, // Replace chain users with the new chain. assert(Load->getNumValues() == 2 && "Loads must carry a chain!"); - DAG.ReplaceAllUsesOfValueWith(SDValue(Ld, 1), Load.getValue(1)); // Finally, do a normal sign-extend to the desired register. - return DAG.getSExtOrTrunc(Load, dl, RegVT); + SDValue SExt = DAG.getSExtOrTrunc(Load, dl, RegVT); + return DAG.getMergeValues({SExt, Load.getValue(1)}, dl); } // All sizes must be a power of two. @@ -19521,26 +20629,26 @@ static SDValue LowerLoad(SDValue Op, const X86Subtarget &Subtarget, assert((Ext != ISD::SEXTLOAD || NumLoads == 1) && "Can only lower sext loads with a single scalar load!"); - unsigned loadRegZize = RegSz; + unsigned loadRegSize = RegSz; if (Ext == ISD::SEXTLOAD && RegSz >= 256) - loadRegZize = 128; + loadRegSize = 128; // If we don't have BWI we won't be able to create the shuffle needed for // v8i8->v8i64. if (Ext == ISD::EXTLOAD && !Subtarget.hasBWI() && RegVT == MVT::v8i64 && MemVT == MVT::v8i8) - loadRegZize = 128; + loadRegSize = 128; // Represent our vector as a sequence of elements which are the // largest scalar that we can load. EVT LoadUnitVecVT = EVT::getVectorVT( - *DAG.getContext(), SclrLoadTy, loadRegZize / SclrLoadTy.getSizeInBits()); + *DAG.getContext(), SclrLoadTy, loadRegSize / SclrLoadTy.getSizeInBits()); // Represent the data using the same element type that is stored in // memory. In practice, we ''widen'' MemVT. EVT WideVecVT = EVT::getVectorVT(*DAG.getContext(), MemVT.getScalarType(), - loadRegZize / MemVT.getScalarSizeInBits()); + loadRegSize / MemVT.getScalarSizeInBits()); assert(WideVecVT.getSizeInBits() == LoadUnitVecVT.getSizeInBits() && "Invalid vector type"); @@ -19551,15 +20659,20 @@ static SDValue LowerLoad(SDValue Op, const X86Subtarget &Subtarget, SmallVector<SDValue, 8> Chains; SDValue Ptr = Ld->getBasePtr(); - SDValue Increment = DAG.getConstant(SclrLoadTy.getSizeInBits() / 8, dl, + unsigned OffsetInc = SclrLoadTy.getSizeInBits() / 8; + SDValue Increment = DAG.getConstant(OffsetInc, dl, TLI.getPointerTy(DAG.getDataLayout())); SDValue Res = DAG.getUNDEF(LoadUnitVecVT); + unsigned Offset = 0; for (unsigned i = 0; i < NumLoads; ++i) { + unsigned NewAlign = MinAlign(Ld->getAlignment(), Offset); + // Perform a single load. 
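Each scalar load produced by the loop above now carries a pointer-info offset and a reduced alignment. A conservative bound for the alignment of base + Offset is the smaller of the base alignment and the largest power of two dividing the offset; a helper with the same intent as the MinAlign call above might look like this (name and shape invented for illustration):

#include <cstdint>

uint64_t alignAfterOffset(uint64_t BaseAlign, uint64_t Offset) {
  if (Offset == 0)
    return BaseAlign;
  uint64_t OffsetAlign = Offset & (~Offset + 1);   // lowest set bit of Offset
  return BaseAlign < OffsetAlign ? BaseAlign : OffsetAlign;
}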
SDValue ScalarLoad = - DAG.getLoad(SclrLoadTy, dl, Ld->getChain(), Ptr, Ld->getPointerInfo(), - Ld->getAlignment(), Ld->getMemOperand()->getFlags()); + DAG.getLoad(SclrLoadTy, dl, Ld->getChain(), Ptr, + Ld->getPointerInfo().getWithOffset(Offset), + NewAlign, Ld->getMemOperand()->getFlags()); Chains.push_back(ScalarLoad.getValue(1)); // Create the first element type using SCALAR_TO_VECTOR in order to avoid // another round of DAGCombining. @@ -19570,6 +20683,7 @@ static SDValue LowerLoad(SDValue Op, const X86Subtarget &Subtarget, ScalarLoad, DAG.getIntPtrConstant(i, dl)); Ptr = DAG.getNode(ISD::ADD, dl, Ptr.getValueType(), Ptr, Increment); + Offset += OffsetInc; } SDValue TF = DAG.getNode(ISD::TokenFactor, dl, MVT::Other, Chains); @@ -19580,28 +20694,14 @@ static SDValue LowerLoad(SDValue Op, const X86Subtarget &Subtarget, unsigned SizeRatio = RegSz / MemSz; if (Ext == ISD::SEXTLOAD) { - // If we have SSE4.1, we can directly emit a VSEXT node. - if (Subtarget.hasSSE41()) { - SDValue Sext = getExtendInVec(X86ISD::VSEXT, dl, RegVT, SlicedVec, DAG); - DAG.ReplaceAllUsesOfValueWith(SDValue(Ld, 1), TF); - return Sext; - } - - // Otherwise we'll use SIGN_EXTEND_VECTOR_INREG to sign extend the lowest - // lanes. - assert(TLI.isOperationLegalOrCustom(ISD::SIGN_EXTEND_VECTOR_INREG, RegVT) && - "We can't implement a sext load without SIGN_EXTEND_VECTOR_INREG!"); - - SDValue Shuff = DAG.getSignExtendVectorInReg(SlicedVec, dl, RegVT); - DAG.ReplaceAllUsesOfValueWith(SDValue(Ld, 1), TF); - return Shuff; + SDValue Sext = getExtendInVec(/*Signed*/true, dl, RegVT, SlicedVec, DAG); + return DAG.getMergeValues({Sext, TF}, dl); } if (Ext == ISD::EXTLOAD && !Subtarget.hasBWI() && RegVT == MVT::v8i64 && MemVT == MVT::v8i8) { - SDValue Sext = getExtendInVec(X86ISD::VZEXT, dl, RegVT, SlicedVec, DAG); - DAG.ReplaceAllUsesOfValueWith(SDValue(Ld, 1), TF); - return Sext; + SDValue Sext = getExtendInVec(/*Signed*/false, dl, RegVT, SlicedVec, DAG); + return DAG.getMergeValues({Sext, TF}, dl); } // Redistribute the loaded elements into the different locations. @@ -19614,8 +20714,7 @@ static SDValue LowerLoad(SDValue Op, const X86Subtarget &Subtarget, // Bitcast to the requested type. Shuff = DAG.getBitcast(RegVT, Shuff); - DAG.ReplaceAllUsesOfValueWith(SDValue(Ld, 1), TF); - return Shuff; + return DAG.getMergeValues({Shuff, TF}, dl); } /// Return true if node is an ISD::AND or ISD::OR of two X86ISD::SETCC nodes @@ -19712,49 +20811,13 @@ SDValue X86TargetLowering::LowerBRCOND(SDValue Op, SelectionDAG &DAG) const { CondOpcode = Cond.getOpcode(); if (CondOpcode == ISD::UADDO || CondOpcode == ISD::SADDO || CondOpcode == ISD::USUBO || CondOpcode == ISD::SSUBO || - ((CondOpcode == ISD::UMULO || CondOpcode == ISD::SMULO) && - Cond.getOperand(0).getValueType() != MVT::i8)) { - SDValue LHS = Cond.getOperand(0); - SDValue RHS = Cond.getOperand(1); - unsigned X86Opcode; - unsigned X86Cond; - SDVTList VTs; - // Keep this in sync with LowerXALUO, otherwise we might create redundant - // instructions that can't be removed afterwards (i.e. X86ISD::ADD and - // X86ISD::INC). 
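The condition codes chosen by the overflow lowering above are the usual flag checks: unsigned overflow is the carry flag (X86::COND_B), signed overflow is the overflow flag (X86::COND_O). A scalar model for 32-bit addition, with invented helper names:

#include <cstdint>

// Unsigned add overflow <=> carry flag (X86::COND_B).
bool uaddOverflows(uint32_t A, uint32_t B) {
  return A + B < A;
}

// Signed add overflow <=> overflow flag (X86::COND_O): the operands agree in
// sign but the result does not.
bool saddOverflows(int32_t A, int32_t B) {
  uint32_t UA = (uint32_t)A, UB = (uint32_t)B, R = UA + UB;
  return ((~(UA ^ UB) & (UA ^ R)) >> 31) != 0;
}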
- switch (CondOpcode) { - case ISD::UADDO: X86Opcode = X86ISD::ADD; X86Cond = X86::COND_B; break; - case ISD::SADDO: - if (isOneConstant(RHS)) { - X86Opcode = X86ISD::INC; X86Cond = X86::COND_O; - break; - } - X86Opcode = X86ISD::ADD; X86Cond = X86::COND_O; break; - case ISD::USUBO: X86Opcode = X86ISD::SUB; X86Cond = X86::COND_B; break; - case ISD::SSUBO: - if (isOneConstant(RHS)) { - X86Opcode = X86ISD::DEC; X86Cond = X86::COND_O; - break; - } - X86Opcode = X86ISD::SUB; X86Cond = X86::COND_O; break; - case ISD::UMULO: X86Opcode = X86ISD::UMUL; X86Cond = X86::COND_O; break; - case ISD::SMULO: X86Opcode = X86ISD::SMUL; X86Cond = X86::COND_O; break; - default: llvm_unreachable("unexpected overflowing operator"); - } - if (Inverted) - X86Cond = X86::GetOppositeBranchCondition((X86::CondCode)X86Cond); - if (CondOpcode == ISD::UMULO) - VTs = DAG.getVTList(LHS.getValueType(), LHS.getValueType(), - MVT::i32); - else - VTs = DAG.getVTList(LHS.getValueType(), MVT::i32); - - SDValue X86Op = DAG.getNode(X86Opcode, dl, VTs, LHS, RHS); + CondOpcode == ISD::UMULO || CondOpcode == ISD::SMULO) { + SDValue Value; + X86::CondCode X86Cond; + std::tie(Value, Cond) = getX86XALUOOp(X86Cond, Cond.getValue(0), DAG); - if (CondOpcode == ISD::UMULO) - Cond = X86Op.getValue(2); - else - Cond = X86Op.getValue(1); + if (Inverted) + X86Cond = X86::GetOppositeBranchCondition(X86Cond); CC = DAG.getConstant(X86Cond, dl, MVT::i8); addTest = false; @@ -19855,34 +20918,17 @@ SDValue X86TargetLowering::LowerBRCOND(SDValue Op, SelectionDAG &DAG) const { } else if (Cond.getOpcode() == ISD::SETCC && cast<CondCodeSDNode>(Cond.getOperand(2))->get() == ISD::SETUNE) { // For FCMP_UNE, we can emit - // two branches instead of an explicit AND instruction with a - // separate test. However, we only do this if this block doesn't - // have a fall-through edge, because this requires an explicit - // jmp when the condition is false. - if (Op.getNode()->hasOneUse()) { - SDNode *User = *Op.getNode()->use_begin(); - // Look for an unconditional branch following this conditional branch. - // We need this because we need to reverse the successors in order - // to implement FCMP_UNE. - if (User->getOpcode() == ISD::BR) { - SDValue FalseBB = User->getOperand(1); - SDNode *NewBR = - DAG.UpdateNodeOperands(User, User->getOperand(0), Dest); - assert(NewBR == User); - (void)NewBR; - - SDValue Cmp = DAG.getNode(X86ISD::CMP, dl, MVT::i32, - Cond.getOperand(0), Cond.getOperand(1)); - Cmp = ConvertCmpIfNecessary(Cmp, DAG); - CC = DAG.getConstant(X86::COND_NE, dl, MVT::i8); - Chain = DAG.getNode(X86ISD::BRCOND, dl, Op.getValueType(), - Chain, Dest, CC, Cmp); - CC = DAG.getConstant(X86::COND_NP, dl, MVT::i8); - Cond = Cmp; - addTest = false; - Dest = FalseBB; - } - } + // two branches instead of an explicit OR instruction with a + // separate test. + SDValue Cmp = DAG.getNode(X86ISD::CMP, dl, MVT::i32, + Cond.getOperand(0), Cond.getOperand(1)); + Cmp = ConvertCmpIfNecessary(Cmp, DAG); + CC = DAG.getConstant(X86::COND_NE, dl, MVT::i8); + Chain = DAG.getNode(X86ISD::BRCOND, dl, Op.getValueType(), + Chain, Dest, CC, Cmp); + CC = DAG.getConstant(X86::COND_P, dl, MVT::i8); + Cond = Cmp; + addTest = false; } } @@ -19894,9 +20940,10 @@ SDValue X86TargetLowering::LowerBRCOND(SDValue Op, SelectionDAG &DAG) const { // We know the result of AND is compared against zero. Try to match // it to BT. 
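The AND-to-BT match referred to above rests on the identity that (X & (1 << N)) != 0, ((X >> N) & 1) != 0, and a BT instruction all test the same bit, with BT leaving the result in the carry flag for the later SETCC/BRCOND. In scalar form:

#include <cstdint>

bool testBit(uint32_t X, unsigned N) {
  return ((X >> N) & 1u) != 0;   // what BT X, N leaves in CF
}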
if (Cond.getOpcode() == ISD::AND && Cond.hasOneUse()) { - if (SDValue NewSetCC = LowerAndToBT(Cond, ISD::SETNE, dl, DAG)) { - CC = NewSetCC.getOperand(0); - Cond = NewSetCC.getOperand(1); + SDValue BTCC; + if (SDValue BT = LowerAndToBT(Cond, ISD::SETNE, dl, DAG, BTCC)) { + CC = BTCC; + Cond = BT; addTest = false; } } @@ -19905,7 +20952,8 @@ SDValue X86TargetLowering::LowerBRCOND(SDValue Op, SelectionDAG &DAG) const { if (addTest) { X86::CondCode X86Cond = Inverted ? X86::COND_E : X86::COND_NE; CC = DAG.getConstant(X86Cond, dl, MVT::i8); - Cond = EmitTest(Cond, X86Cond, dl, DAG); + Cond = EmitCmp(Cond, DAG.getConstant(0, dl, Cond.getValueType()), + X86Cond, dl, DAG); } Cond = ConvertCmpIfNecessary(Cond, DAG); return DAG.getNode(X86ISD::BRCOND, dl, Op.getValueType(), @@ -20141,6 +21189,25 @@ static SDValue LowerVACOPY(SDValue Op, const X86Subtarget &Subtarget, MachinePointerInfo(DstSV), MachinePointerInfo(SrcSV)); } +// Helper to get immediate/variable SSE shift opcode from other shift opcodes. +static unsigned getTargetVShiftUniformOpcode(unsigned Opc, bool IsVariable) { + switch (Opc) { + case ISD::SHL: + case X86ISD::VSHL: + case X86ISD::VSHLI: + return IsVariable ? X86ISD::VSHL : X86ISD::VSHLI; + case ISD::SRL: + case X86ISD::VSRL: + case X86ISD::VSRLI: + return IsVariable ? X86ISD::VSRL : X86ISD::VSRLI; + case ISD::SRA: + case X86ISD::VSRA: + case X86ISD::VSRAI: + return IsVariable ? X86ISD::VSRA : X86ISD::VSRAI; + } + llvm_unreachable("Unknown target vector shift node"); +} + /// Handle vector element shifts where the shift amount is a constant. /// Takes immediate version of shift as input. static SDValue getTargetVShiftByConstNode(unsigned Opc, const SDLoc &dl, MVT VT, @@ -20236,46 +21303,57 @@ static SDValue getTargetVShiftNode(unsigned Opc, const SDLoc &dl, MVT VT, return getTargetVShiftByConstNode(Opc, dl, VT, SrcOp, CShAmt->getZExtValue(), DAG); - // Change opcode to non-immediate version - switch (Opc) { - default: llvm_unreachable("Unknown target vector shift node"); - case X86ISD::VSHLI: Opc = X86ISD::VSHL; break; - case X86ISD::VSRLI: Opc = X86ISD::VSRL; break; - case X86ISD::VSRAI: Opc = X86ISD::VSRA; break; - } + // Change opcode to non-immediate version. + Opc = getTargetVShiftUniformOpcode(Opc, true); // Need to build a vector containing shift amount. // SSE/AVX packed shifts only use the lower 64-bit of the shift count. - // +=================+============+=======================================+ - // | ShAmt is | HasSSE4.1? | Construct ShAmt vector as | - // +=================+============+=======================================+ - // | i64 | Yes, No | Use ShAmt as lowest elt | - // | i32 | Yes | zero-extend in-reg | - // | (i32 zext(i16)) | Yes | zero-extend in-reg | - // | i16/i32 | No | v4i32 build_vector(ShAmt, 0, ud, ud)) | - // +=================+============+=======================================+ + // +====================+============+=======================================+ + // | ShAmt is | HasSSE4.1? 
| Construct ShAmt vector as | + // +====================+============+=======================================+ + // | i64 | Yes, No | Use ShAmt as lowest elt | + // | i32 | Yes | zero-extend in-reg | + // | (i32 zext(i16/i8)) | Yes | zero-extend in-reg | + // | (i32 zext(i16/i8)) | No | byte-shift-in-reg | + // | i16/i32 | No | v4i32 build_vector(ShAmt, 0, ud, ud)) | + // +====================+============+=======================================+ if (SVT == MVT::i64) ShAmt = DAG.getNode(ISD::SCALAR_TO_VECTOR, SDLoc(ShAmt), MVT::v2i64, ShAmt); - else if (Subtarget.hasSSE41() && ShAmt.getOpcode() == ISD::ZERO_EXTEND && - ShAmt.getOperand(0).getSimpleValueType() == MVT::i16) { + else if (ShAmt.getOpcode() == ISD::ZERO_EXTEND && + ShAmt.getOperand(0).getOpcode() == ISD::EXTRACT_VECTOR_ELT && + (ShAmt.getOperand(0).getSimpleValueType() == MVT::i16 || + ShAmt.getOperand(0).getSimpleValueType() == MVT::i8)) { ShAmt = ShAmt.getOperand(0); - ShAmt = DAG.getNode(ISD::SCALAR_TO_VECTOR, SDLoc(ShAmt), MVT::v8i16, ShAmt); - ShAmt = DAG.getZeroExtendVectorInReg(ShAmt, SDLoc(ShAmt), MVT::v2i64); + MVT AmtTy = ShAmt.getSimpleValueType() == MVT::i8 ? MVT::v16i8 : MVT::v8i16; + ShAmt = DAG.getNode(ISD::SCALAR_TO_VECTOR, SDLoc(ShAmt), AmtTy, ShAmt); + if (Subtarget.hasSSE41()) + ShAmt = DAG.getNode(ISD::ZERO_EXTEND_VECTOR_INREG, SDLoc(ShAmt), + MVT::v2i64, ShAmt); + else { + SDValue ByteShift = DAG.getConstant( + (128 - AmtTy.getScalarSizeInBits()) / 8, SDLoc(ShAmt), MVT::i8); + ShAmt = DAG.getBitcast(MVT::v16i8, ShAmt); + ShAmt = DAG.getNode(X86ISD::VSHLDQ, SDLoc(ShAmt), MVT::v16i8, ShAmt, + ByteShift); + ShAmt = DAG.getNode(X86ISD::VSRLDQ, SDLoc(ShAmt), MVT::v16i8, ShAmt, + ByteShift); + } } else if (Subtarget.hasSSE41() && ShAmt.getOpcode() == ISD::EXTRACT_VECTOR_ELT) { ShAmt = DAG.getNode(ISD::SCALAR_TO_VECTOR, SDLoc(ShAmt), MVT::v4i32, ShAmt); - ShAmt = DAG.getZeroExtendVectorInReg(ShAmt, SDLoc(ShAmt), MVT::v2i64); + ShAmt = DAG.getNode(ISD::ZERO_EXTEND_VECTOR_INREG, SDLoc(ShAmt), + MVT::v2i64, ShAmt); } else { - SDValue ShOps[4] = {ShAmt, DAG.getConstant(0, dl, SVT), - DAG.getUNDEF(SVT), DAG.getUNDEF(SVT)}; + SDValue ShOps[4] = {ShAmt, DAG.getConstant(0, dl, SVT), DAG.getUNDEF(SVT), + DAG.getUNDEF(SVT)}; ShAmt = DAG.getBuildVector(MVT::v4i32, dl, ShOps); } // The return type has to be a 128-bit type with the same element // type as the input type. 
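Without SSE4.1's packed zero-extend, the code above clears everything above the i8/i16 shift amount with a whole-register byte shift left and then right (the VSHLDQ/VSRLDQ pair). The same trick on a 64-bit scalar, keeping only the low 16 bits:

#include <cstdint>

uint64_t zeroExtendLow16(uint64_t Lane) {
  return (Lane << 48) >> 48;   // shift up, shift back down, upper bits now zero
}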
MVT EltVT = VT.getVectorElementType(); - MVT ShVT = MVT::getVectorVT(EltVT, 128/EltVT.getSizeInBits()); + MVT ShVT = MVT::getVectorVT(EltVT, 128 / EltVT.getSizeInBits()); ShAmt = DAG.getBitcast(ShVT, ShAmt); return DAG.getNode(Opc, dl, VT, SrcOp, ShAmt); @@ -20292,11 +21370,7 @@ static SDValue getMaskNode(SDValue Mask, MVT MaskVT, if (X86::isZeroNode(Mask)) return DAG.getConstant(0, dl, MaskVT); - if (MaskVT.bitsGT(Mask.getSimpleValueType())) { - // Mask should be extended - Mask = DAG.getNode(ISD::ANY_EXTEND, dl, - MVT::getIntegerVT(MaskVT.getSizeInBits()), Mask); - } + assert(MaskVT.bitsLE(Mask.getSimpleValueType()) && "Unexpected mask size!"); if (Mask.getSimpleValueType() == MVT::i64 && Subtarget.is32Bit()) { assert(MaskVT == MVT::v64i1 && "Expected v64i1 mask!"); @@ -20340,24 +21414,6 @@ static SDValue getVectorMaskingNode(SDValue Op, SDValue Mask, SDValue VMask = getMaskNode(Mask, MaskVT, Subtarget, DAG, dl); - switch (Op.getOpcode()) { - default: break; - case X86ISD::CMPM: - case X86ISD::CMPM_RND: - case X86ISD::VPSHUFBITQMB: - case X86ISD::VFPCLASS: - return DAG.getNode(ISD::AND, dl, VT, Op, VMask); - case ISD::TRUNCATE: - case X86ISD::VTRUNC: - case X86ISD::VTRUNCS: - case X86ISD::VTRUNCUS: - case X86ISD::CVTPS2PH: - // We can't use ISD::VSELECT here because it is not always "Legal" - // for the destination type. For example vpmovqb require only AVX512 - // and vselect that can operate on byte element type require BWI - OpcodeSelect = X86ISD::SELECT; - break; - } if (PreservedSrc.isUndef()) PreservedSrc = getZeroVector(VT, Subtarget, DAG, dl); return DAG.getNode(OpcodeSelect, dl, VT, VMask, Op, PreservedSrc); @@ -20383,7 +21439,9 @@ static SDValue getScalarMaskingNode(SDValue Op, SDValue Mask, SDLoc dl(Op); assert(Mask.getValueType() == MVT::i8 && "Unexpect type"); - SDValue IMask = DAG.getNode(ISD::SCALAR_TO_VECTOR, dl, MVT::v1i1, Mask); + SDValue IMask = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, MVT::v1i1, + DAG.getBitcast(MVT::v8i1, Mask), + DAG.getIntPtrConstant(0, dl)); if (Op.getOpcode() == X86ISD::FSETCCM || Op.getOpcode() == X86ISD::FSETCCM_RND || Op.getOpcode() == X86ISD::VFPCLASSS) @@ -20486,13 +21544,9 @@ SDValue X86TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op, } return DAG.getNode(IntrData->Opc0, dl, Op.getValueType(), Op.getOperand(1)); } - case INTR_TYPE_2OP: - case INTR_TYPE_2OP_IMM8: { + case INTR_TYPE_2OP: { SDValue Src2 = Op.getOperand(2); - if (IntrData->Type == INTR_TYPE_2OP_IMM8) - Src2 = DAG.getNode(ISD::TRUNCATE, dl, MVT::i8, Src2); - // We specify 2 possible opcodes for intrinsics with rounding modes. // First, we check if the intrinsic may have non-default rounding mode, // (IntrData->Opc1 != 0), then we check the rounding mode operand. @@ -20724,38 +21778,6 @@ SDValue X86TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op, // Swap Src1 and Src2 in the node creation return DAG.getNode(IntrData->Opc0, dl, VT,Src2, Src1); } - case FMA_OP_MASKZ: - case FMA_OP_MASK: { - SDValue Src1 = Op.getOperand(1); - SDValue Src2 = Op.getOperand(2); - SDValue Src3 = Op.getOperand(3); - SDValue Mask = Op.getOperand(4); - MVT VT = Op.getSimpleValueType(); - SDValue PassThru = SDValue(); - - // set PassThru element - if (IntrData->Type == FMA_OP_MASKZ) - PassThru = getZeroVector(VT, Subtarget, DAG, dl); - else - PassThru = Src1; - - // We specify 2 possible opcodes for intrinsics with rounding modes. - // First, we check if the intrinsic may have non-default rounding mode, - // (IntrData->Opc1 != 0), then we check the rounding mode operand. 
- unsigned IntrWithRoundingModeOpcode = IntrData->Opc1; - if (IntrWithRoundingModeOpcode != 0) { - SDValue Rnd = Op.getOperand(5); - if (!isRoundModeCurDirection(Rnd)) - return getVectorMaskingNode(DAG.getNode(IntrWithRoundingModeOpcode, - dl, Op.getValueType(), - Src1, Src2, Src3, Rnd), - Mask, PassThru, Subtarget, DAG); - } - return getVectorMaskingNode(DAG.getNode(IntrData->Opc0, - dl, Op.getValueType(), - Src1, Src2, Src3), - Mask, PassThru, Subtarget, DAG); - } case IFMA_OP: // NOTE: We need to swizzle the operands to pass the multiply operands // first. @@ -20766,7 +21788,7 @@ SDValue X86TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op, // does not change the value. Set it to 0 since it can change. return DAG.getNode(IntrData->Opc0, dl, VT, Op.getOperand(1), DAG.getIntPtrConstant(0, dl)); - case CVTPD2PS_MASK: { + case CVTPD2PS_RND_MASK: { SDValue Src = Op.getOperand(1); SDValue PassThru = Op.getOperand(2); SDValue Mask = Op.getOperand(3); @@ -20790,13 +21812,6 @@ SDValue X86TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op, DAG.getIntPtrConstant(0, dl)), Mask, PassThru, Subtarget, DAG); } - case FPCLASS: { - // FPclass intrinsics - SDValue Src1 = Op.getOperand(1); - MVT MaskVT = Op.getSimpleValueType(); - SDValue Imm = Op.getOperand(2); - return DAG.getNode(IntrData->Opc0, dl, MaskVT, Src1, Imm); - } case FPCLASSS: { SDValue Src1 = Op.getOperand(1); SDValue Imm = Op.getOperand(2); @@ -20811,32 +21826,6 @@ SDValue X86TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op, FPclassMask, DAG.getIntPtrConstant(0, dl)); return DAG.getBitcast(MVT::i8, Ins); } - case CMP_MASK: { - // Comparison intrinsics with masks. - // Example of transformation: - // (i8 (int_x86_avx512_mask_pcmpeq_q_128 - // (v2i64 %a), (v2i64 %b), (i8 %mask))) -> - // (i8 (bitcast - // (v8i1 (insert_subvector zero, - // (v2i1 (and (PCMPEQM %a, %b), - // (extract_subvector - // (v8i1 (bitcast %mask)), 0))), 0)))) - MVT VT = Op.getOperand(1).getSimpleValueType(); - MVT MaskVT = MVT::getVectorVT(MVT::i1, VT.getVectorNumElements()); - SDValue Mask = Op.getOperand((IntrData->Type == CMP_MASK_CC) ? 4 : 3); - MVT BitcastVT = MVT::getVectorVT(MVT::i1, - Mask.getSimpleValueType().getSizeInBits()); - SDValue Cmp = DAG.getNode(IntrData->Opc0, dl, MaskVT, Op.getOperand(1), - Op.getOperand(2)); - SDValue CmpMask = getVectorMaskingNode(Cmp, Mask, SDValue(), - Subtarget, DAG); - // Need to fill with zeros to ensure the bitcast will produce zeroes - // for the upper bits in the v2i1/v4i1 case. - SDValue Res = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, BitcastVT, - DAG.getConstant(0, dl, BitcastVT), - CmpMask, DAG.getIntPtrConstant(0, dl)); - return DAG.getBitcast(Op.getValueType(), Res); - } case CMP_MASK_CC: { MVT MaskVT = Op.getSimpleValueType(); @@ -21007,6 +21996,59 @@ SDValue X86TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op, return DAG.getNode(IntrData->Opc0, dl, Op.getValueType(), Op.getOperand(1), Op.getOperand(2), RoundingMode); } + // ADC/ADCX/SBB + case ADX: { + SDVTList CFVTs = DAG.getVTList(Op->getValueType(0), MVT::i32); + SDVTList VTs = DAG.getVTList(Op.getOperand(2).getValueType(), MVT::i32); + + SDValue Res; + // If the carry in is zero, then we should just use ADD/SUB instead of + // ADC/SBB. 
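The ADX handling that follows regenerates the carry flag from the incoming carry byte with an add of -1, which sets CF exactly when the byte is non-zero, and only falls back to a plain ADD/SUB when the carry-in is a known zero. A scalar model of the addcarry semantics being implemented (names invented):

#include <cstdint>

uint32_t addCarry32(uint8_t CarryIn, uint32_t A, uint32_t B, bool &CarryOut) {
  // "ADD CarryIn, 0xFF" sets CF iff CarryIn != 0; ADC then consumes that CF.
  uint64_t Wide = (uint64_t)A + B + (CarryIn != 0 ? 1u : 0u);
  CarryOut = (Wide >> 32) != 0;
  return (uint32_t)Wide;
}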
+ if (isNullConstant(Op.getOperand(1))) { + Res = DAG.getNode(IntrData->Opc1, dl, VTs, Op.getOperand(2), + Op.getOperand(3)); + } else { + SDValue GenCF = DAG.getNode(X86ISD::ADD, dl, CFVTs, Op.getOperand(1), + DAG.getConstant(-1, dl, MVT::i8)); + Res = DAG.getNode(IntrData->Opc0, dl, VTs, Op.getOperand(2), + Op.getOperand(3), GenCF.getValue(1)); + } + SDValue SetCC = getSETCC(X86::COND_B, Res.getValue(1), dl, DAG); + SDValue Results[] = { SetCC, Res }; + return DAG.getMergeValues(Results, dl); + } + case CVTPD2PS_MASK: + case CVTPD2I_MASK: + case TRUNCATE_TO_REG: { + SDValue Src = Op.getOperand(1); + SDValue PassThru = Op.getOperand(2); + SDValue Mask = Op.getOperand(3); + + if (isAllOnesConstant(Mask)) + return DAG.getNode(IntrData->Opc0, dl, Op.getValueType(), Src); + + MVT SrcVT = Src.getSimpleValueType(); + MVT MaskVT = MVT::getVectorVT(MVT::i1, SrcVT.getVectorNumElements()); + Mask = getMaskNode(Mask, MaskVT, Subtarget, DAG, dl); + return DAG.getNode(IntrData->Opc1, dl, Op.getValueType(), Src, PassThru, + Mask); + } + case CVTPS2PH_MASK: { + SDValue Src = Op.getOperand(1); + SDValue Rnd = Op.getOperand(2); + SDValue PassThru = Op.getOperand(3); + SDValue Mask = Op.getOperand(4); + + if (isAllOnesConstant(Mask)) + return DAG.getNode(IntrData->Opc0, dl, Op.getValueType(), Src, Rnd); + + MVT SrcVT = Src.getSimpleValueType(); + MVT MaskVT = MVT::getVectorVT(MVT::i1, SrcVT.getVectorNumElements()); + Mask = getMaskNode(Mask, MaskVT, Subtarget, DAG, dl); + return DAG.getNode(IntrData->Opc1, dl, Op.getValueType(), Src, Rnd, + PassThru, Mask); + + } default: break; } @@ -21018,6 +22060,14 @@ SDValue X86TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op, // ptest and testp intrinsics. The intrinsic these come from are designed to // return an integer value, not just an instruction so lower it to the ptest // or testp pattern and a setcc for the result. 
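For reference, the flag semantics these test intrinsics rely on can be modelled per 128-bit value as two 64-bit words: PTEST sets ZF when the AND of its operands is all zero and CF when the AND-NOT is all zero, with the "testz" forms reading ZF and the "testc" forms reading CF. A sketch only, not the intrinsics' real signatures:

#include <cstdint>

void ptestFlags(const uint64_t A[2], const uint64_t B[2], bool &ZF, bool &CF) {
  ZF = ((A[0] & B[0]) | (A[1] & B[1])) == 0;     // testz result
  CF = ((~A[0] & B[0]) | (~A[1] & B[1])) == 0;   // testc result
}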
+ case Intrinsic::x86_avx512_ktestc_b: + case Intrinsic::x86_avx512_ktestc_w: + case Intrinsic::x86_avx512_ktestc_d: + case Intrinsic::x86_avx512_ktestc_q: + case Intrinsic::x86_avx512_ktestz_b: + case Intrinsic::x86_avx512_ktestz_w: + case Intrinsic::x86_avx512_ktestz_d: + case Intrinsic::x86_avx512_ktestz_q: case Intrinsic::x86_sse41_ptestz: case Intrinsic::x86_sse41_ptestc: case Intrinsic::x86_sse41_ptestnzc: @@ -21036,15 +22086,30 @@ SDValue X86TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op, case Intrinsic::x86_avx_vtestz_pd_256: case Intrinsic::x86_avx_vtestc_pd_256: case Intrinsic::x86_avx_vtestnzc_pd_256: { - bool IsTestPacked = false; + unsigned TestOpc = X86ISD::PTEST; X86::CondCode X86CC; switch (IntNo) { default: llvm_unreachable("Bad fallthrough in Intrinsic lowering."); + case Intrinsic::x86_avx512_ktestc_b: + case Intrinsic::x86_avx512_ktestc_w: + case Intrinsic::x86_avx512_ktestc_d: + case Intrinsic::x86_avx512_ktestc_q: + // CF = 1 + TestOpc = X86ISD::KTEST; + X86CC = X86::COND_B; + break; + case Intrinsic::x86_avx512_ktestz_b: + case Intrinsic::x86_avx512_ktestz_w: + case Intrinsic::x86_avx512_ktestz_d: + case Intrinsic::x86_avx512_ktestz_q: + TestOpc = X86ISD::KTEST; + X86CC = X86::COND_E; + break; case Intrinsic::x86_avx_vtestz_ps: case Intrinsic::x86_avx_vtestz_pd: case Intrinsic::x86_avx_vtestz_ps_256: case Intrinsic::x86_avx_vtestz_pd_256: - IsTestPacked = true; + TestOpc = X86ISD::TESTP; LLVM_FALLTHROUGH; case Intrinsic::x86_sse41_ptestz: case Intrinsic::x86_avx_ptestz_256: @@ -21055,7 +22120,7 @@ SDValue X86TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op, case Intrinsic::x86_avx_vtestc_pd: case Intrinsic::x86_avx_vtestc_ps_256: case Intrinsic::x86_avx_vtestc_pd_256: - IsTestPacked = true; + TestOpc = X86ISD::TESTP; LLVM_FALLTHROUGH; case Intrinsic::x86_sse41_ptestc: case Intrinsic::x86_avx_ptestc_256: @@ -21066,7 +22131,7 @@ SDValue X86TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op, case Intrinsic::x86_avx_vtestnzc_pd: case Intrinsic::x86_avx_vtestnzc_ps_256: case Intrinsic::x86_avx_vtestnzc_pd_256: - IsTestPacked = true; + TestOpc = X86ISD::TESTP; LLVM_FALLTHROUGH; case Intrinsic::x86_sse41_ptestnzc: case Intrinsic::x86_avx_ptestnzc_256: @@ -21077,7 +22142,6 @@ SDValue X86TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op, SDValue LHS = Op.getOperand(1); SDValue RHS = Op.getOperand(2); - unsigned TestOpc = IsTestPacked ? X86ISD::TESTP : X86ISD::PTEST; SDValue Test = DAG.getNode(TestOpc, dl, MVT::i32, LHS, RHS); SDValue SetCC = getSETCC(X86CC, Test, dl, DAG); return DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i32, SetCC); @@ -21196,14 +22260,14 @@ SDValue X86TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op, return DAG.getNode(X86ISD::Wrapper, dl, VT, Result); } - case Intrinsic::x86_seh_recoverfp: { + case Intrinsic::eh_recoverfp: { SDValue FnOp = Op.getOperand(1); SDValue IncomingFPOp = Op.getOperand(2); GlobalAddressSDNode *GSD = dyn_cast<GlobalAddressSDNode>(FnOp); auto *Fn = dyn_cast_or_null<Function>(GSD ? 
GSD->getGlobal() : nullptr); if (!Fn) report_fatal_error( - "llvm.x86.seh.recoverfp must take a function as the first argument"); + "llvm.eh.recoverfp must take a function as the first argument"); return recoverFramePointer(DAG, Fn, IncomingFPOp); } @@ -21251,25 +22315,31 @@ static SDValue getGatherNode(unsigned Opc, SDValue Op, SelectionDAG &DAG, SDValue Src, SDValue Mask, SDValue Base, SDValue Index, SDValue ScaleOp, SDValue Chain, const X86Subtarget &Subtarget) { + MVT VT = Op.getSimpleValueType(); SDLoc dl(Op); auto *C = dyn_cast<ConstantSDNode>(ScaleOp); // Scale must be constant. if (!C) return SDValue(); SDValue Scale = DAG.getTargetConstant(C->getZExtValue(), dl, MVT::i8); - MVT MaskVT = MVT::getVectorVT(MVT::i1, - Index.getSimpleValueType().getVectorNumElements()); + unsigned MinElts = std::min(Index.getSimpleValueType().getVectorNumElements(), + VT.getVectorNumElements()); + MVT MaskVT = MVT::getVectorVT(MVT::i1, MinElts); + + // We support two versions of the gather intrinsics. One with scalar mask and + // one with vXi1 mask. Convert scalar to vXi1 if necessary. + if (Mask.getValueType() != MaskVT) + Mask = getMaskNode(Mask, MaskVT, Subtarget, DAG, dl); - SDValue VMask = getMaskNode(Mask, MaskVT, Subtarget, DAG, dl); SDVTList VTs = DAG.getVTList(Op.getValueType(), MaskVT, MVT::Other); SDValue Disp = DAG.getTargetConstant(0, dl, MVT::i32); SDValue Segment = DAG.getRegister(0, MVT::i32); // If source is undef or we know it won't be used, use a zero vector // to break register dependency. // TODO: use undef instead and let BreakFalseDeps deal with it? - if (Src.isUndef() || ISD::isBuildVectorAllOnes(VMask.getNode())) + if (Src.isUndef() || ISD::isBuildVectorAllOnes(Mask.getNode())) Src = getZeroVector(Op.getSimpleValueType(), Subtarget, DAG, dl); - SDValue Ops[] = {Src, VMask, Base, Scale, Index, Disp, Segment, Chain}; + SDValue Ops[] = {Src, Mask, Base, Scale, Index, Disp, Segment, Chain}; SDNode *Res = DAG.getMachineNode(Opc, dl, VTs, Ops); SDValue RetOps[] = { SDValue(Res, 0), SDValue(Res, 2) }; return DAG.getMergeValues(RetOps, dl); @@ -21287,12 +22357,17 @@ static SDValue getScatterNode(unsigned Opc, SDValue Op, SelectionDAG &DAG, SDValue Scale = DAG.getTargetConstant(C->getZExtValue(), dl, MVT::i8); SDValue Disp = DAG.getTargetConstant(0, dl, MVT::i32); SDValue Segment = DAG.getRegister(0, MVT::i32); - MVT MaskVT = MVT::getVectorVT(MVT::i1, - Index.getSimpleValueType().getVectorNumElements()); + unsigned MinElts = std::min(Index.getSimpleValueType().getVectorNumElements(), + Src.getSimpleValueType().getVectorNumElements()); + MVT MaskVT = MVT::getVectorVT(MVT::i1, MinElts); + + // We support two versions of the scatter intrinsics. One with scalar mask and + // one with vXi1 mask. Convert scalar to vXi1 if necessary. + if (Mask.getValueType() != MaskVT) + Mask = getMaskNode(Mask, MaskVT, Subtarget, DAG, dl); - SDValue VMask = getMaskNode(Mask, MaskVT, Subtarget, DAG, dl); SDVTList VTs = DAG.getVTList(MaskVT, MVT::Other); - SDValue Ops[] = {Base, Scale, Index, Disp, Segment, VMask, Src, Chain}; + SDValue Ops[] = {Base, Scale, Index, Disp, Segment, Mask, Src, Chain}; SDNode *Res = DAG.getMachineNode(Opc, dl, VTs, Ops); return SDValue(Res, 1); } @@ -21433,39 +22508,39 @@ static void getReadTimeStampCounter(SDNode *N, const SDLoc &DL, unsigned Opcode, } SDValue Chain = HI.getValue(1); + SDValue TSC; + if (Subtarget.is64Bit()) { + // The EDX register is loaded with the high-order 32 bits of the MSR, and + // the EAX register is loaded with the low-order 32 bits. 
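The combine built just below the comment above is the standard EDX:EAX pairing; in scalar terms:

#include <cstdint>

uint64_t combineTsc(uint32_t Lo, uint32_t Hi) {
  return ((uint64_t)Hi << 32) | Lo;   // the SHL + OR nodes on 64-bit targets
}

On 32-bit targets the same pairing is expressed with a BUILD_PAIR node instead of an explicit shift and OR.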
+ TSC = DAG.getNode(ISD::SHL, DL, MVT::i64, HI, + DAG.getConstant(32, DL, MVT::i8)); + TSC = DAG.getNode(ISD::OR, DL, MVT::i64, LO, TSC); + } else { + // Use a buildpair to merge the two 32-bit values into a 64-bit one. + TSC = DAG.getNode(ISD::BUILD_PAIR, DL, MVT::i64, { LO, HI }); + } + if (Opcode == X86ISD::RDTSCP_DAG) { - assert(N->getNumOperands() == 3 && "Unexpected number of operands!"); + assert(N->getNumOperands() == 2 && "Unexpected number of operands!"); // Instruction RDTSCP loads the IA32:TSC_AUX_MSR (address C000_0103H) into // the ECX register. Add 'ecx' explicitly to the chain. SDValue ecx = DAG.getCopyFromReg(Chain, DL, X86::ECX, MVT::i32, HI.getValue(2)); - // Explicitly store the content of ECX at the location passed in input - // to the 'rdtscp' intrinsic. - Chain = DAG.getStore(ecx.getValue(1), DL, ecx, N->getOperand(2), - MachinePointerInfo()); - } - if (Subtarget.is64Bit()) { - // The EDX register is loaded with the high-order 32 bits of the MSR, and - // the EAX register is loaded with the low-order 32 bits. - SDValue Tmp = DAG.getNode(ISD::SHL, DL, MVT::i64, HI, - DAG.getConstant(32, DL, MVT::i8)); - Results.push_back(DAG.getNode(ISD::OR, DL, MVT::i64, LO, Tmp)); - Results.push_back(Chain); + Results.push_back(TSC); + Results.push_back(ecx); + Results.push_back(ecx.getValue(1)); return; } - // Use a buildpair to merge the two 32-bit values into a 64-bit one. - SDValue Ops[] = { LO, HI }; - SDValue Pair = DAG.getNode(ISD::BUILD_PAIR, DL, MVT::i64, Ops); - Results.push_back(Pair); + Results.push_back(TSC); Results.push_back(Chain); } static SDValue LowerREADCYCLECOUNTER(SDValue Op, const X86Subtarget &Subtarget, SelectionDAG &DAG) { - SmallVector<SDValue, 2> Results; + SmallVector<SDValue, 3> Results; SDLoc DL(Op); getReadTimeStampCounter(Op.getNode(), DL, X86ISD::RDTSC_DAG, DAG, Subtarget, Results); @@ -21529,7 +22604,7 @@ EmitMaskedTruncSStore(bool SignedSat, SDValue Chain, const SDLoc &Dl, MachineMemOperand *MMO, SelectionDAG &DAG) { SDVTList VTs = DAG.getVTList(MVT::Other); - SDValue Ops[] = { Chain, Ptr, Mask, Val }; + SDValue Ops[] = { Chain, Val, Ptr, Mask }; return SignedSat ? DAG.getTargetMemSDNode<MaskedTruncSStoreSDNode>(VTs, Ops, Dl, MemVT, MMO) : DAG.getTargetMemSDNode<MaskedTruncUSStoreSDNode>(VTs, Ops, Dl, MemVT, MMO); @@ -21689,20 +22764,6 @@ static SDValue LowerINTRINSIC_W_CHAIN(SDValue Op, const X86Subtarget &Subtarget, return DAG.getNode(ISD::MERGE_VALUES, dl, Op->getVTList(), Ret, SDValue(InTrans.getNode(), 1)); } - // ADC/ADCX/SBB - case ADX: { - SDVTList CFVTs = DAG.getVTList(Op->getValueType(0), MVT::i32); - SDVTList VTs = DAG.getVTList(Op.getOperand(3).getValueType(), MVT::i32); - SDValue GenCF = DAG.getNode(X86ISD::ADD, dl, CFVTs, Op.getOperand(2), - DAG.getConstant(-1, dl, MVT::i8)); - SDValue Res = DAG.getNode(IntrData->Opc0, dl, VTs, Op.getOperand(3), - Op.getOperand(4), GenCF.getValue(1)); - SDValue Store = DAG.getStore(Op.getOperand(0), dl, Res.getValue(0), - Op.getOperand(5), MachinePointerInfo()); - SDValue SetCC = getSETCC(X86::COND_B, Res.getValue(1), dl, DAG); - SDValue Results[] = { SetCC, Store }; - return DAG.getMergeValues(Results, dl); - } case TRUNCATE_TO_MEM_VI8: case TRUNCATE_TO_MEM_VI16: case TRUNCATE_TO_MEM_VI32: { @@ -22255,11 +23316,10 @@ static SDValue LowerVectorCTLZInRegLUT(SDValue Op, const SDLoc &DL, // we just take the hi result (by masking the lo result to zero before the // add). 
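A per-byte scalar model of the nibble-LUT CTLZ being assembled here: if the high nibble is zero, the answer is 4 plus the LUT entry of the low nibble, otherwise it is the LUT entry of the high nibble alone (the vector code expresses the same select with PSHUFB lookups and a compare mask). Sketch only, with an invented name:

#include <cstdint>

uint32_t ctlz8(uint8_t X) {
  static const uint8_t LUT[16] = {4, 3, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0};
  uint8_t Hi = X >> 4, Lo = X & 0xF;
  return Hi ? LUT[Hi] : 4u + LUT[Lo];
}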
SDValue Op0 = DAG.getBitcast(CurrVT, Op.getOperand(0)); - SDValue Zero = getZeroVector(CurrVT, Subtarget, DAG, DL); + SDValue Zero = DAG.getConstant(0, DL, CurrVT); - SDValue NibbleMask = DAG.getConstant(0xF, DL, CurrVT); SDValue NibbleShift = DAG.getConstant(0x4, DL, CurrVT); - SDValue Lo = DAG.getNode(ISD::AND, DL, CurrVT, Op0, NibbleMask); + SDValue Lo = Op0; SDValue Hi = DAG.getNode(ISD::SRL, DL, CurrVT, Op0, NibbleShift); SDValue HiZ; if (CurrVT.is512BitVector()) { @@ -22377,38 +23437,23 @@ static SDValue LowerCTLZ(SDValue Op, const X86Subtarget &Subtarget, return Op; } -static SDValue LowerCTTZ(SDValue Op, SelectionDAG &DAG) { +static SDValue LowerCTTZ(SDValue Op, const X86Subtarget &Subtarget, + SelectionDAG &DAG) { MVT VT = Op.getSimpleValueType(); unsigned NumBits = VT.getScalarSizeInBits(); + SDValue N0 = Op.getOperand(0); SDLoc dl(Op); - if (VT.isVector()) { - SDValue N0 = Op.getOperand(0); - SDValue Zero = DAG.getConstant(0, dl, VT); - - // lsb(x) = (x & -x) - SDValue LSB = DAG.getNode(ISD::AND, dl, VT, N0, - DAG.getNode(ISD::SUB, dl, VT, Zero, N0)); - - // cttz_undef(x) = (width - 1) - ctlz(lsb) - if (Op.getOpcode() == ISD::CTTZ_ZERO_UNDEF) { - SDValue WidthMinusOne = DAG.getConstant(NumBits - 1, dl, VT); - return DAG.getNode(ISD::SUB, dl, VT, WidthMinusOne, - DAG.getNode(ISD::CTLZ, dl, VT, LSB)); - } - - // cttz(x) = ctpop(lsb - 1) - SDValue One = DAG.getConstant(1, dl, VT); - return DAG.getNode(ISD::CTPOP, dl, VT, - DAG.getNode(ISD::SUB, dl, VT, LSB, One)); - } + // Decompose 256-bit ops into smaller 128-bit ops. + if (VT.is256BitVector() && !Subtarget.hasInt256()) + return Lower256IntUnary(Op, DAG); - assert(Op.getOpcode() == ISD::CTTZ && + assert(!VT.isVector() && Op.getOpcode() == ISD::CTTZ && "Only scalar CTTZ requires custom lowering"); // Issue a bsf (scan bits forward) which also sets EFLAGS. SDVTList VTs = DAG.getVTList(VT, MVT::i32); - Op = DAG.getNode(X86ISD::BSF, dl, VTs, Op.getOperand(0)); + Op = DAG.getNode(X86ISD::BSF, dl, VTs, N0); // If src is zero (i.e. bsf sets ZF), returns NumBits. SDValue Ops[] = { @@ -22422,7 +23467,7 @@ static SDValue LowerCTTZ(SDValue Op, SelectionDAG &DAG) { /// Break a 256-bit integer operation into two new 128-bit ones and then /// concatenate the result back. -static SDValue Lower256IntArith(SDValue Op, SelectionDAG &DAG) { +static SDValue split256IntArith(SDValue Op, SelectionDAG &DAG) { MVT VT = Op.getSimpleValueType(); assert(VT.is256BitVector() && VT.isInteger() && @@ -22451,7 +23496,7 @@ static SDValue Lower256IntArith(SDValue Op, SelectionDAG &DAG) { /// Break a 512-bit integer operation into two new 256-bit ones and then /// concatenate the result back. 
-static SDValue Lower512IntArith(SDValue Op, SelectionDAG &DAG) { +static SDValue split512IntArith(SDValue Op, SelectionDAG &DAG) { MVT VT = Op.getSimpleValueType(); assert(VT.is512BitVector() && VT.isInteger() && @@ -22478,18 +23523,46 @@ static SDValue Lower512IntArith(SDValue Op, SelectionDAG &DAG) { DAG.getNode(Op.getOpcode(), dl, NewVT, LHS2, RHS2)); } -static SDValue LowerADD_SUB(SDValue Op, SelectionDAG &DAG) { +static SDValue lowerAddSub(SDValue Op, SelectionDAG &DAG, + const X86Subtarget &Subtarget) { MVT VT = Op.getSimpleValueType(); + if (VT == MVT::i16 || VT == MVT::i32) + return lowerAddSubToHorizontalOp(Op, DAG, Subtarget); + if (VT.getScalarType() == MVT::i1) return DAG.getNode(ISD::XOR, SDLoc(Op), VT, Op.getOperand(0), Op.getOperand(1)); + assert(Op.getSimpleValueType().is256BitVector() && Op.getSimpleValueType().isInteger() && "Only handle AVX 256-bit vector integer operation"); - return Lower256IntArith(Op, DAG); + return split256IntArith(Op, DAG); } -static SDValue LowerABS(SDValue Op, SelectionDAG &DAG) { +static SDValue LowerADDSAT_SUBSAT(SDValue Op, SelectionDAG &DAG) { + MVT VT = Op.getSimpleValueType(); + if (VT.getScalarType() == MVT::i1) { + SDLoc dl(Op); + switch (Op.getOpcode()) { + default: llvm_unreachable("Expected saturated arithmetic opcode"); + case ISD::UADDSAT: + case ISD::SADDSAT: + return DAG.getNode(ISD::OR, dl, VT, Op.getOperand(0), Op.getOperand(1)); + case ISD::USUBSAT: + case ISD::SSUBSAT: + return DAG.getNode(ISD::AND, dl, VT, Op.getOperand(0), + DAG.getNOT(dl, Op.getOperand(1), VT)); + } + } + + assert(Op.getSimpleValueType().is256BitVector() && + Op.getSimpleValueType().isInteger() && + "Only handle AVX 256-bit vector integer operation"); + return split256IntArith(Op, DAG); +} + +static SDValue LowerABS(SDValue Op, const X86Subtarget &Subtarget, + SelectionDAG &DAG) { MVT VT = Op.getSimpleValueType(); if (VT == MVT::i16 || VT == MVT::i32 || VT == MVT::i64) { // Since X86 does not have CMOV for 8-bit integer, we don't convert @@ -22503,10 +23576,23 @@ static SDValue LowerABS(SDValue Op, SelectionDAG &DAG) { return DAG.getNode(X86ISD::CMOV, DL, VT, Ops); } - assert(Op.getSimpleValueType().is256BitVector() && - Op.getSimpleValueType().isInteger() && - "Only handle AVX 256-bit vector integer operation"); - return Lower256IntUnary(Op, DAG); + // ABS(vXi64 X) --> VPBLENDVPD(X, 0-X, X). + if ((VT == MVT::v2i64 || VT == MVT::v4i64) && Subtarget.hasSSE41()) { + SDLoc DL(Op); + SDValue Src = Op.getOperand(0); + SDValue Sub = + DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT), Src); + return DAG.getNode(X86ISD::BLENDV, DL, VT, Src, Sub, Src); + } + + if (VT.is256BitVector() && !Subtarget.hasInt256()) { + assert(VT.isInteger() && + "Only handle AVX 256-bit vector integer operation"); + return Lower256IntUnary(Op, DAG); + } + + // Default to expand. + return SDValue(); } static SDValue LowerMINMAX(SDValue Op, SelectionDAG &DAG) { @@ -22514,7 +23600,7 @@ static SDValue LowerMINMAX(SDValue Op, SelectionDAG &DAG) { // For AVX1 cases, split to use legal ops (everything but v4i64). if (VT.getScalarType() != MVT::i64 && VT.is256BitVector()) - return Lower256IntArith(Op, DAG); + return split256IntArith(Op, DAG); SDLoc DL(Op); unsigned Opcode = Op.getOpcode(); @@ -22556,9 +23642,9 @@ static SDValue LowerMUL(SDValue Op, const X86Subtarget &Subtarget, if (VT.getScalarType() == MVT::i1) return DAG.getNode(ISD::AND, dl, VT, Op.getOperand(0), Op.getOperand(1)); - // Decompose 256-bit ops into smaller 128-bit ops. + // Decompose 256-bit ops into 128-bit ops. 
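The vXi1 special cases above collapse to plain logic; a sketch with bool standing in for a single mask bit (helper names are illustrative): on an i1 element, saturating add is OR, saturating subtract is AND-NOT, and ordinary add/sub wrap around to XOR.

// i1 models of the lowerings above: the signed domain is {-1, 0} and the
// unsigned domain is {0, 1}, so saturation reduces to bitwise logic.
static inline bool addsat_i1(bool A, bool B) { return A || B; }   // UADDSAT / SADDSAT
static inline bool subsat_i1(bool A, bool B) { return A && !B; }  // USUBSAT / SSUBSAT
static inline bool addsub_i1(bool A, bool B) { return A != B; }   // ADD / SUB -> XOR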
if (VT.is256BitVector() && !Subtarget.hasInt256()) - return Lower256IntArith(Op, DAG); + return split256IntArith(Op, DAG); SDValue A = Op.getOperand(0); SDValue B = Op.getOperand(1); @@ -22566,53 +23652,49 @@ static SDValue LowerMUL(SDValue Op, const X86Subtarget &Subtarget, // Lower v16i8/v32i8/v64i8 mul as sign-extension to v8i16/v16i16/v32i16 // vector pairs, multiply and truncate. if (VT == MVT::v16i8 || VT == MVT::v32i8 || VT == MVT::v64i8) { - if (Subtarget.hasInt256()) { - // For 512-bit vectors, split into 256-bit vectors to allow the - // sign-extension to occur. - if (VT == MVT::v64i8) - return Lower512IntArith(Op, DAG); - - // For 256-bit vectors, split into 128-bit vectors to allow the - // sign-extension to occur. We don't need this on AVX512BW as we can - // safely sign-extend to v32i16. - if (VT == MVT::v32i8 && !Subtarget.hasBWI()) - return Lower256IntArith(Op, DAG); + unsigned NumElts = VT.getVectorNumElements(); + if ((VT == MVT::v16i8 && Subtarget.hasInt256()) || + (VT == MVT::v32i8 && Subtarget.canExtendTo512BW())) { MVT ExVT = MVT::getVectorVT(MVT::i16, VT.getVectorNumElements()); return DAG.getNode( ISD::TRUNCATE, dl, VT, DAG.getNode(ISD::MUL, dl, ExVT, - DAG.getNode(ISD::SIGN_EXTEND, dl, ExVT, A), - DAG.getNode(ISD::SIGN_EXTEND, dl, ExVT, B))); + DAG.getNode(ISD::ANY_EXTEND, dl, ExVT, A), + DAG.getNode(ISD::ANY_EXTEND, dl, ExVT, B))); } - assert(VT == MVT::v16i8 && - "Pre-AVX2 support only supports v16i8 multiplication"); - MVT ExVT = MVT::v8i16; + MVT ExVT = MVT::getVectorVT(MVT::i16, NumElts / 2); - // Extract the lo parts and sign extend to i16 + // Extract the lo/hi parts to any extend to i16. // We're going to mask off the low byte of each result element of the // pmullw, so it doesn't matter what's in the high byte of each 16-bit // element. - const int LoShufMask[] = {0, -1, 1, -1, 2, -1, 3, -1, - 4, -1, 5, -1, 6, -1, 7, -1}; - SDValue ALo = DAG.getVectorShuffle(VT, dl, A, A, LoShufMask); - SDValue BLo = DAG.getVectorShuffle(VT, dl, B, B, LoShufMask); - ALo = DAG.getBitcast(ExVT, ALo); - BLo = DAG.getBitcast(ExVT, BLo); - - // Extract the hi parts and sign extend to i16 - // We're going to mask off the low byte of each result element of the - // pmullw, so it doesn't matter what's in the high byte of each 16-bit - // element. - const int HiShufMask[] = {8, -1, 9, -1, 10, -1, 11, -1, - 12, -1, 13, -1, 14, -1, 15, -1}; - SDValue AHi = DAG.getVectorShuffle(VT, dl, A, A, HiShufMask); - SDValue BHi = DAG.getVectorShuffle(VT, dl, B, B, HiShufMask); - AHi = DAG.getBitcast(ExVT, AHi); - BHi = DAG.getBitcast(ExVT, BHi); - - // Multiply, mask the lower 8bits of the lo/hi results and pack + SDValue Undef = DAG.getUNDEF(VT); + SDValue ALo = DAG.getBitcast(ExVT, getUnpackl(DAG, dl, VT, A, Undef)); + SDValue AHi = DAG.getBitcast(ExVT, getUnpackh(DAG, dl, VT, A, Undef)); + + SDValue BLo, BHi; + if (ISD::isBuildVectorOfConstantSDNodes(B.getNode())) { + // If the LHS is a constant, manually unpackl/unpackh. 
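Per element, the vXi8 multiply being lowered here is a widen-multiply-truncate; a scalar sketch under wrap-around i8 semantics (the helper name is illustrative):

#include <cstdint>

// Widen both operands to 16 bits, multiply, and keep only the low byte of
// each product; whether the DAG uses TRUNCATE or the AND-with-255 plus
// PACKUS route, only these low 8 bits survive.
static inline uint8_t mul_i8_via_i16(uint8_t A, uint8_t B) {
  uint16_t Wide = uint16_t(uint16_t(A) * uint16_t(B));
  return uint8_t(Wide & 0xFF);
}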
+ SmallVector<SDValue, 16> LoOps, HiOps; + for (unsigned i = 0; i != NumElts; i += 16) { + for (unsigned j = 0; j != 8; ++j) { + LoOps.push_back(DAG.getAnyExtOrTrunc(B.getOperand(i + j), dl, + MVT::i16)); + HiOps.push_back(DAG.getAnyExtOrTrunc(B.getOperand(i + j + 8), dl, + MVT::i16)); + } + } + + BLo = DAG.getBuildVector(ExVT, dl, LoOps); + BHi = DAG.getBuildVector(ExVT, dl, HiOps); + } else { + BLo = DAG.getBitcast(ExVT, getUnpackl(DAG, dl, VT, B, Undef)); + BHi = DAG.getBitcast(ExVT, getUnpackh(DAG, dl, VT, B, Undef)); + } + + // Multiply, mask the lower 8bits of the lo/hi results and pack. SDValue RLo = DAG.getNode(ISD::MUL, dl, ExVT, ALo, BLo); SDValue RHi = DAG.getNode(ISD::MUL, dl, ExVT, AHi, BHi); RLo = DAG.getNode(ISD::AND, dl, ExVT, RLo, DAG.getConstant(255, dl, ExVT)); @@ -22661,9 +23743,8 @@ static SDValue LowerMUL(SDValue Op, const X86Subtarget &Subtarget, // // Hi = psllqi(AloBhi + AhiBlo, 32); // return AloBlo + Hi; - KnownBits AKnown, BKnown; - DAG.computeKnownBits(A, AKnown); - DAG.computeKnownBits(B, BKnown); + KnownBits AKnown = DAG.computeKnownBits(A); + KnownBits BKnown = DAG.computeKnownBits(B); APInt LowerBitsMask = APInt::getLowBitsSet(64, 32); bool ALoIsZero = LowerBitsMask.isSubsetOf(AKnown.Zero); @@ -22673,7 +23754,7 @@ static SDValue LowerMUL(SDValue Op, const X86Subtarget &Subtarget, bool AHiIsZero = UpperBitsMask.isSubsetOf(AKnown.Zero); bool BHiIsZero = UpperBitsMask.isSubsetOf(BKnown.Zero); - SDValue Zero = getZeroVector(VT, Subtarget, DAG, dl); + SDValue Zero = DAG.getConstant(0, dl, VT); // Only multiply lo/hi halves that aren't known to be zero. SDValue AloBlo = Zero; @@ -22702,10 +23783,79 @@ static SDValue LowerMULH(SDValue Op, const X86Subtarget &Subtarget, SelectionDAG &DAG) { SDLoc dl(Op); MVT VT = Op.getSimpleValueType(); + bool IsSigned = Op->getOpcode() == ISD::MULHS; + unsigned NumElts = VT.getVectorNumElements(); + SDValue A = Op.getOperand(0); + SDValue B = Op.getOperand(1); - // Decompose 256-bit ops into smaller 128-bit ops. + // Decompose 256-bit ops into 128-bit ops. if (VT.is256BitVector() && !Subtarget.hasInt256()) - return Lower256IntArith(Op, DAG); + return split256IntArith(Op, DAG); + + if (VT == MVT::v4i32 || VT == MVT::v8i32 || VT == MVT::v16i32) { + assert((VT == MVT::v4i32 && Subtarget.hasSSE2()) || + (VT == MVT::v8i32 && Subtarget.hasInt256()) || + (VT == MVT::v16i32 && Subtarget.hasAVX512())); + + // PMULxD operations multiply each even value (starting at 0) of LHS with + // the related value of RHS and produce a widen result. + // E.g., PMULUDQ <4 x i32> <a|b|c|d>, <4 x i32> <e|f|g|h> + // => <2 x i64> <ae|cg> + // + // In other word, to have all the results, we need to perform two PMULxD: + // 1. one with the even values. + // 2. one with the odd values. + // To achieve #2, with need to place the odd values at an even position. + // + // Place the odd value at an even position (basically, shift all values 1 + // step to the left): + const int Mask[] = {1, -1, 3, -1, 5, -1, 7, -1, + 9, -1, 11, -1, 13, -1, 15, -1}; + // <a|b|c|d> => <b|undef|d|undef> + SDValue Odd0 = DAG.getVectorShuffle(VT, dl, A, A, + makeArrayRef(&Mask[0], NumElts)); + // <e|f|g|h> => <f|undef|h|undef> + SDValue Odd1 = DAG.getVectorShuffle(VT, dl, B, B, + makeArrayRef(&Mask[0], NumElts)); + + // Emit two multiplies, one for the lower 2 ints and one for the higher 2 + // ints. + MVT MulVT = MVT::getVectorVT(MVT::i64, NumElts / 2); + unsigned Opcode = + (IsSigned && Subtarget.hasSSE41()) ? 
X86ISD::PMULDQ : X86ISD::PMULUDQ; + // PMULUDQ <4 x i32> <a|b|c|d>, <4 x i32> <e|f|g|h> + // => <2 x i64> <ae|cg> + SDValue Mul1 = DAG.getBitcast(VT, DAG.getNode(Opcode, dl, MulVT, + DAG.getBitcast(MulVT, A), + DAG.getBitcast(MulVT, B))); + // PMULUDQ <4 x i32> <b|undef|d|undef>, <4 x i32> <f|undef|h|undef> + // => <2 x i64> <bf|dh> + SDValue Mul2 = DAG.getBitcast(VT, DAG.getNode(Opcode, dl, MulVT, + DAG.getBitcast(MulVT, Odd0), + DAG.getBitcast(MulVT, Odd1))); + + // Shuffle it back into the right order. + SmallVector<int, 16> ShufMask(NumElts); + for (int i = 0; i != (int)NumElts; ++i) + ShufMask[i] = (i / 2) * 2 + ((i % 2) * NumElts) + 1; + + SDValue Res = DAG.getVectorShuffle(VT, dl, Mul1, Mul2, ShufMask); + + // If we have a signed multiply but no PMULDQ fix up the result of an + // unsigned multiply. + if (IsSigned && !Subtarget.hasSSE41()) { + SDValue Zero = DAG.getConstant(0, dl, VT); + SDValue T1 = DAG.getNode(ISD::AND, dl, VT, + DAG.getSetCC(dl, VT, Zero, A, ISD::SETGT), B); + SDValue T2 = DAG.getNode(ISD::AND, dl, VT, + DAG.getSetCC(dl, VT, Zero, B, ISD::SETGT), A); + + SDValue Fixup = DAG.getNode(ISD::ADD, dl, VT, T1, T2); + Res = DAG.getNode(ISD::SUB, dl, VT, Res, Fixup); + } + + return Res; + } // Only i8 vectors should need custom lowering after this. assert((VT == MVT::v16i8 || (VT == MVT::v32i8 && Subtarget.hasInt256()) || @@ -22714,123 +23864,141 @@ static SDValue LowerMULH(SDValue Op, const X86Subtarget &Subtarget, // Lower v16i8/v32i8 as extension to v8i16/v16i16 vector pairs, multiply, // logical shift down the upper half and pack back to i8. - SDValue A = Op.getOperand(0); - SDValue B = Op.getOperand(1); // With SSE41 we can use sign/zero extend, but for pre-SSE41 we unpack // and then ashr/lshr the upper bits down to the lower bits before multiply. - unsigned Opcode = Op.getOpcode(); - unsigned ExShift = (ISD::MULHU == Opcode ? ISD::SRL : ISD::SRA); - unsigned ExAVX = (ISD::MULHU == Opcode ? ISD::ZERO_EXTEND : ISD::SIGN_EXTEND); + unsigned ExAVX = IsSigned ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND; + + if ((VT == MVT::v16i8 && Subtarget.hasInt256()) || + (VT == MVT::v32i8 && Subtarget.canExtendTo512BW())) { + MVT ExVT = MVT::getVectorVT(MVT::i16, NumElts); + SDValue ExA = DAG.getNode(ExAVX, dl, ExVT, A); + SDValue ExB = DAG.getNode(ExAVX, dl, ExVT, B); + SDValue Mul = DAG.getNode(ISD::MUL, dl, ExVT, ExA, ExB); + Mul = getTargetVShiftByConstNode(X86ISD::VSRLI, dl, ExVT, Mul, 8, DAG); + return DAG.getNode(ISD::TRUNCATE, dl, VT, Mul); + } - // For 512-bit vectors, split into 256-bit vectors to allow the + // For signed 512-bit vectors, split into 256-bit vectors to allow the // sign-extension to occur. - if (VT == MVT::v64i8) - return Lower512IntArith(Op, DAG); + if (VT == MVT::v64i8 && IsSigned) + return split512IntArith(Op, DAG); - // AVX2 implementations - extend xmm subvectors to ymm. - if (Subtarget.hasInt256()) { - unsigned NumElems = VT.getVectorNumElements(); + // Signed AVX2 implementation - extend xmm subvectors to ymm. 
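The even/odd PMULxD split above produces, per lane, the high 32 bits of a 64-bit product, and when only the unsigned PMULUDQ is available the signed high half is recovered with the fixup at the end of that block. A scalar sketch of the identity (function name illustrative):

#include <cstdint>

// Signed high-half multiply rebuilt from the unsigned one: subtract B when
// A is negative and A when B is negative, all modulo 2^32.
static inline int32_t mulhs_via_mulhu(int32_t A, int32_t B) {
  uint32_t UA = uint32_t(A), UB = uint32_t(B);
  uint32_t HiU = uint32_t((uint64_t(UA) * uint64_t(UB)) >> 32);
  uint32_t Fix = (A < 0 ? UB : 0u) + (B < 0 ? UA : 0u);
  return int32_t(HiU - Fix);
}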
+ if (VT == MVT::v32i8 && IsSigned) { SDValue Lo = DAG.getIntPtrConstant(0, dl); - SDValue Hi = DAG.getIntPtrConstant(NumElems / 2, dl); - - if (VT == MVT::v32i8) { - if (Subtarget.canExtendTo512BW()) { - SDValue ExA = DAG.getNode(ExAVX, dl, MVT::v32i16, A); - SDValue ExB = DAG.getNode(ExAVX, dl, MVT::v32i16, B); - SDValue Mul = DAG.getNode(ISD::MUL, dl, MVT::v32i16, ExA, ExB); - Mul = DAG.getNode(ISD::SRL, dl, MVT::v32i16, Mul, - DAG.getConstant(8, dl, MVT::v32i16)); - return DAG.getNode(ISD::TRUNCATE, dl, VT, Mul); - } - SDValue ALo = extract128BitVector(A, 0, DAG, dl); - SDValue BLo = extract128BitVector(B, 0, DAG, dl); - SDValue AHi = extract128BitVector(A, NumElems / 2, DAG, dl); - SDValue BHi = extract128BitVector(B, NumElems / 2, DAG, dl); - ALo = DAG.getNode(ExAVX, dl, MVT::v16i16, ALo); - BLo = DAG.getNode(ExAVX, dl, MVT::v16i16, BLo); - AHi = DAG.getNode(ExAVX, dl, MVT::v16i16, AHi); - BHi = DAG.getNode(ExAVX, dl, MVT::v16i16, BHi); - Lo = DAG.getNode(ISD::SRL, dl, MVT::v16i16, - DAG.getNode(ISD::MUL, dl, MVT::v16i16, ALo, BLo), - DAG.getConstant(8, dl, MVT::v16i16)); - Hi = DAG.getNode(ISD::SRL, dl, MVT::v16i16, - DAG.getNode(ISD::MUL, dl, MVT::v16i16, AHi, BHi), - DAG.getConstant(8, dl, MVT::v16i16)); - // The ymm variant of PACKUS treats the 128-bit lanes separately, so before - // using PACKUS we need to permute the inputs to the correct lo/hi xmm lane. - const int LoMask[] = {0, 1, 2, 3, 4, 5, 6, 7, - 16, 17, 18, 19, 20, 21, 22, 23}; - const int HiMask[] = {8, 9, 10, 11, 12, 13, 14, 15, - 24, 25, 26, 27, 28, 29, 30, 31}; - return DAG.getNode(X86ISD::PACKUS, dl, VT, - DAG.getVectorShuffle(MVT::v16i16, dl, Lo, Hi, LoMask), - DAG.getVectorShuffle(MVT::v16i16, dl, Lo, Hi, HiMask)); - } - - assert(VT == MVT::v16i8 && "Unexpected VT"); - - SDValue ExA = DAG.getNode(ExAVX, dl, MVT::v16i16, A); - SDValue ExB = DAG.getNode(ExAVX, dl, MVT::v16i16, B); - SDValue Mul = DAG.getNode(ISD::MUL, dl, MVT::v16i16, ExA, ExB); - Mul = DAG.getNode(ISD::SRL, dl, MVT::v16i16, Mul, - DAG.getConstant(8, dl, MVT::v16i16)); - // If we have BWI we can use truncate instruction. - if (Subtarget.hasBWI()) - return DAG.getNode(ISD::TRUNCATE, dl, VT, Mul); - Lo = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, MVT::v8i16, Mul, Lo); - Hi = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, MVT::v8i16, Mul, Hi); - return DAG.getNode(X86ISD::PACKUS, dl, VT, Lo, Hi); - } - - assert(VT == MVT::v16i8 && - "Pre-AVX2 support only supports v16i8 multiplication"); - MVT ExVT = MVT::v8i16; - unsigned ExSSE41 = ISD::MULHU == Opcode ? ISD::ZERO_EXTEND_VECTOR_INREG - : ISD::SIGN_EXTEND_VECTOR_INREG; + SDValue Hi = DAG.getIntPtrConstant(NumElts / 2, dl); + + MVT ExVT = MVT::v16i16; + SDValue ALo = extract128BitVector(A, 0, DAG, dl); + SDValue BLo = extract128BitVector(B, 0, DAG, dl); + SDValue AHi = extract128BitVector(A, NumElts / 2, DAG, dl); + SDValue BHi = extract128BitVector(B, NumElts / 2, DAG, dl); + ALo = DAG.getNode(ExAVX, dl, ExVT, ALo); + BLo = DAG.getNode(ExAVX, dl, ExVT, BLo); + AHi = DAG.getNode(ExAVX, dl, ExVT, AHi); + BHi = DAG.getNode(ExAVX, dl, ExVT, BHi); + Lo = DAG.getNode(ISD::MUL, dl, ExVT, ALo, BLo); + Hi = DAG.getNode(ISD::MUL, dl, ExVT, AHi, BHi); + Lo = getTargetVShiftByConstNode(X86ISD::VSRLI, dl, ExVT, Lo, 8, DAG); + Hi = getTargetVShiftByConstNode(X86ISD::VSRLI, dl, ExVT, Hi, 8, DAG); + + // Bitcast back to VT and then pack all the even elements from Lo and Hi. 
+ // Shuffle lowering should turn this into PACKUS+PERMQ + Lo = DAG.getBitcast(VT, Lo); + Hi = DAG.getBitcast(VT, Hi); + return DAG.getVectorShuffle(VT, dl, Lo, Hi, + { 0, 2, 4, 6, 8, 10, 12, 14, + 16, 18, 20, 22, 24, 26, 28, 30, + 32, 34, 36, 38, 40, 42, 44, 46, + 48, 50, 52, 54, 56, 58, 60, 62}); + } + + // For signed v16i8 and all unsigned vXi8 we will unpack the low and high + // half of each 128 bit lane to widen to a vXi16 type. Do the multiplies, + // shift the results and pack the half lane results back together. + + MVT ExVT = MVT::getVectorVT(MVT::i16, NumElts / 2); + + static const int PSHUFDMask[] = { 8, 9, 10, 11, 12, 13, 14, 15, + -1, -1, -1, -1, -1, -1, -1, -1}; // Extract the lo parts and zero/sign extend to i16. - SDValue ALo, BLo; - if (Subtarget.hasSSE41()) { - ALo = DAG.getNode(ExSSE41, dl, ExVT, A); - BLo = DAG.getNode(ExSSE41, dl, ExVT, B); + // Only use SSE4.1 instructions for signed v16i8 where using unpack requires + // shifts to sign extend. Using unpack for unsigned only requires an xor to + // create zeros and a copy due to tied registers contraints pre-avx. But using + // zero_extend_vector_inreg would require an additional pshufd for the high + // part. + + SDValue ALo, AHi; + if (IsSigned && VT == MVT::v16i8 && Subtarget.hasSSE41()) { + ALo = DAG.getNode(ISD::SIGN_EXTEND_VECTOR_INREG, dl, ExVT, A); + + AHi = DAG.getVectorShuffle(VT, dl, A, A, PSHUFDMask); + AHi = DAG.getNode(ISD::SIGN_EXTEND_VECTOR_INREG, dl, ExVT, AHi); + } else if (IsSigned) { + ALo = DAG.getBitcast(ExVT, getUnpackl(DAG, dl, VT, DAG.getUNDEF(VT), A)); + AHi = DAG.getBitcast(ExVT, getUnpackh(DAG, dl, VT, DAG.getUNDEF(VT), A)); + + ALo = getTargetVShiftByConstNode(X86ISD::VSRAI, dl, ExVT, ALo, 8, DAG); + AHi = getTargetVShiftByConstNode(X86ISD::VSRAI, dl, ExVT, AHi, 8, DAG); } else { - const int ShufMask[] = {-1, 0, -1, 1, -1, 2, -1, 3, - -1, 4, -1, 5, -1, 6, -1, 7}; - ALo = DAG.getVectorShuffle(VT, dl, A, A, ShufMask); - BLo = DAG.getVectorShuffle(VT, dl, B, B, ShufMask); - ALo = DAG.getBitcast(ExVT, ALo); - BLo = DAG.getBitcast(ExVT, BLo); - ALo = DAG.getNode(ExShift, dl, ExVT, ALo, DAG.getConstant(8, dl, ExVT)); - BLo = DAG.getNode(ExShift, dl, ExVT, BLo, DAG.getConstant(8, dl, ExVT)); - } - - // Extract the hi parts and zero/sign extend to i16. - SDValue AHi, BHi; - if (Subtarget.hasSSE41()) { - const int ShufMask[] = {8, 9, 10, 11, 12, 13, 14, 15, - -1, -1, -1, -1, -1, -1, -1, -1}; - AHi = DAG.getVectorShuffle(VT, dl, A, A, ShufMask); - BHi = DAG.getVectorShuffle(VT, dl, B, B, ShufMask); - AHi = DAG.getNode(ExSSE41, dl, ExVT, AHi); - BHi = DAG.getNode(ExSSE41, dl, ExVT, BHi); + ALo = DAG.getBitcast(ExVT, getUnpackl(DAG, dl, VT, A, + DAG.getConstant(0, dl, VT))); + AHi = DAG.getBitcast(ExVT, getUnpackh(DAG, dl, VT, A, + DAG.getConstant(0, dl, VT))); + } + + SDValue BLo, BHi; + if (ISD::isBuildVectorOfConstantSDNodes(B.getNode())) { + // If the LHS is a constant, manually unpackl/unpackh and extend. 
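Per element, both the signed and unsigned vXi8 MULH paths sketched here reduce to a 16-bit widening multiply whose upper byte is kept; a scalar model (names illustrative):

#include <cstdint>

// i8 high-half multiplies via a 16-bit widen, multiply and shift by 8 -
// the pattern the unpack / shift / PACKUS code builds per lane.
static inline uint8_t mulhu_i8(uint8_t A, uint8_t B) {
  return uint8_t((uint16_t(A) * uint16_t(B)) >> 8);
}
static inline int8_t mulhs_i8(int8_t A, int8_t B) {
  return int8_t((int16_t(A) * int16_t(B)) >> 8);
}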
+ SmallVector<SDValue, 16> LoOps, HiOps; + for (unsigned i = 0; i != NumElts; i += 16) { + for (unsigned j = 0; j != 8; ++j) { + SDValue LoOp = B.getOperand(i + j); + SDValue HiOp = B.getOperand(i + j + 8); + + if (IsSigned) { + LoOp = DAG.getSExtOrTrunc(LoOp, dl, MVT::i16); + HiOp = DAG.getSExtOrTrunc(HiOp, dl, MVT::i16); + } else { + LoOp = DAG.getZExtOrTrunc(LoOp, dl, MVT::i16); + HiOp = DAG.getZExtOrTrunc(HiOp, dl, MVT::i16); + } + + LoOps.push_back(LoOp); + HiOps.push_back(HiOp); + } + } + + BLo = DAG.getBuildVector(ExVT, dl, LoOps); + BHi = DAG.getBuildVector(ExVT, dl, HiOps); + } else if (IsSigned && VT == MVT::v16i8 && Subtarget.hasSSE41()) { + BLo = DAG.getNode(ISD::SIGN_EXTEND_VECTOR_INREG, dl, ExVT, B); + + BHi = DAG.getVectorShuffle(VT, dl, B, B, PSHUFDMask); + BHi = DAG.getNode(ISD::SIGN_EXTEND_VECTOR_INREG, dl, ExVT, BHi); + } else if (IsSigned) { + BLo = DAG.getBitcast(ExVT, getUnpackl(DAG, dl, VT, DAG.getUNDEF(VT), B)); + BHi = DAG.getBitcast(ExVT, getUnpackh(DAG, dl, VT, DAG.getUNDEF(VT), B)); + + BLo = getTargetVShiftByConstNode(X86ISD::VSRAI, dl, ExVT, BLo, 8, DAG); + BHi = getTargetVShiftByConstNode(X86ISD::VSRAI, dl, ExVT, BHi, 8, DAG); } else { - const int ShufMask[] = {-1, 8, -1, 9, -1, 10, -1, 11, - -1, 12, -1, 13, -1, 14, -1, 15}; - AHi = DAG.getVectorShuffle(VT, dl, A, A, ShufMask); - BHi = DAG.getVectorShuffle(VT, dl, B, B, ShufMask); - AHi = DAG.getBitcast(ExVT, AHi); - BHi = DAG.getBitcast(ExVT, BHi); - AHi = DAG.getNode(ExShift, dl, ExVT, AHi, DAG.getConstant(8, dl, ExVT)); - BHi = DAG.getNode(ExShift, dl, ExVT, BHi, DAG.getConstant(8, dl, ExVT)); + BLo = DAG.getBitcast(ExVT, getUnpackl(DAG, dl, VT, B, + DAG.getConstant(0, dl, VT))); + BHi = DAG.getBitcast(ExVT, getUnpackh(DAG, dl, VT, B, + DAG.getConstant(0, dl, VT))); } // Multiply, lshr the upper 8bits to the lower 8bits of the lo/hi results and - // pack back to v16i8. + // pack back to vXi8. SDValue RLo = DAG.getNode(ISD::MUL, dl, ExVT, ALo, BLo); SDValue RHi = DAG.getNode(ISD::MUL, dl, ExVT, AHi, BHi); - RLo = DAG.getNode(ISD::SRL, dl, ExVT, RLo, DAG.getConstant(8, dl, ExVT)); - RHi = DAG.getNode(ISD::SRL, dl, ExVT, RHi, DAG.getConstant(8, dl, ExVT)); + RLo = getTargetVShiftByConstNode(X86ISD::VSRLI, dl, ExVT, RLo, 8, DAG); + RHi = getTargetVShiftByConstNode(X86ISD::VSRLI, dl, ExVT, RHi, 8, DAG); + + // Bitcast back to VT and then pack all the even elements from Lo and Hi. return DAG.getNode(X86ISD::PACKUS, dl, VT, RLo, RHi); } @@ -22890,105 +24058,6 @@ SDValue X86TargetLowering::LowerWin64_i128OP(SDValue Op, SelectionDAG &DAG) cons return DAG.getBitcast(VT, CallInfo.first); } -static SDValue LowerMUL_LOHI(SDValue Op, const X86Subtarget &Subtarget, - SelectionDAG &DAG) { - SDValue Op0 = Op.getOperand(0), Op1 = Op.getOperand(1); - MVT VT = Op0.getSimpleValueType(); - SDLoc dl(Op); - - // Decompose 256-bit ops into smaller 128-bit ops. 
- if (VT.is256BitVector() && !Subtarget.hasInt256()) { - unsigned Opcode = Op.getOpcode(); - unsigned NumElems = VT.getVectorNumElements(); - MVT HalfVT = MVT::getVectorVT(VT.getScalarType(), NumElems / 2); - SDValue Lo0 = extract128BitVector(Op0, 0, DAG, dl); - SDValue Lo1 = extract128BitVector(Op1, 0, DAG, dl); - SDValue Hi0 = extract128BitVector(Op0, NumElems / 2, DAG, dl); - SDValue Hi1 = extract128BitVector(Op1, NumElems / 2, DAG, dl); - SDValue Lo = DAG.getNode(Opcode, dl, DAG.getVTList(HalfVT, HalfVT), Lo0, Lo1); - SDValue Hi = DAG.getNode(Opcode, dl, DAG.getVTList(HalfVT, HalfVT), Hi0, Hi1); - SDValue Ops[] = { - DAG.getNode(ISD::CONCAT_VECTORS, dl, VT, Lo.getValue(0), Hi.getValue(0)), - DAG.getNode(ISD::CONCAT_VECTORS, dl, VT, Lo.getValue(1), Hi.getValue(1)) - }; - return DAG.getMergeValues(Ops, dl); - } - - assert((VT == MVT::v4i32 && Subtarget.hasSSE2()) || - (VT == MVT::v8i32 && Subtarget.hasInt256()) || - (VT == MVT::v16i32 && Subtarget.hasAVX512())); - - int NumElts = VT.getVectorNumElements(); - - // PMULxD operations multiply each even value (starting at 0) of LHS with - // the related value of RHS and produce a widen result. - // E.g., PMULUDQ <4 x i32> <a|b|c|d>, <4 x i32> <e|f|g|h> - // => <2 x i64> <ae|cg> - // - // In other word, to have all the results, we need to perform two PMULxD: - // 1. one with the even values. - // 2. one with the odd values. - // To achieve #2, with need to place the odd values at an even position. - // - // Place the odd value at an even position (basically, shift all values 1 - // step to the left): - const int Mask[] = {1, -1, 3, -1, 5, -1, 7, -1, 9, -1, 11, -1, 13, -1, 15, -1}; - // <a|b|c|d> => <b|undef|d|undef> - SDValue Odd0 = DAG.getVectorShuffle(VT, dl, Op0, Op0, - makeArrayRef(&Mask[0], NumElts)); - // <e|f|g|h> => <f|undef|h|undef> - SDValue Odd1 = DAG.getVectorShuffle(VT, dl, Op1, Op1, - makeArrayRef(&Mask[0], NumElts)); - - // Emit two multiplies, one for the lower 2 ints and one for the higher 2 - // ints. - MVT MulVT = MVT::getVectorVT(MVT::i64, NumElts / 2); - bool IsSigned = Op->getOpcode() == ISD::SMUL_LOHI; - unsigned Opcode = - (!IsSigned || !Subtarget.hasSSE41()) ? X86ISD::PMULUDQ : X86ISD::PMULDQ; - // PMULUDQ <4 x i32> <a|b|c|d>, <4 x i32> <e|f|g|h> - // => <2 x i64> <ae|cg> - SDValue Mul1 = DAG.getBitcast(VT, DAG.getNode(Opcode, dl, MulVT, - DAG.getBitcast(MulVT, Op0), - DAG.getBitcast(MulVT, Op1))); - // PMULUDQ <4 x i32> <b|undef|d|undef>, <4 x i32> <f|undef|h|undef> - // => <2 x i64> <bf|dh> - SDValue Mul2 = DAG.getBitcast(VT, DAG.getNode(Opcode, dl, MulVT, - DAG.getBitcast(MulVT, Odd0), - DAG.getBitcast(MulVT, Odd1))); - - // Shuffle it back into the right order. - SmallVector<int, 16> HighMask(NumElts); - SmallVector<int, 16> LowMask(NumElts); - for (int i = 0; i != NumElts; ++i) { - HighMask[i] = (i / 2) * 2 + ((i % 2) * NumElts) + 1; - LowMask[i] = (i / 2) * 2 + ((i % 2) * NumElts); - } - - SDValue Highs = DAG.getVectorShuffle(VT, dl, Mul1, Mul2, HighMask); - SDValue Lows = DAG.getVectorShuffle(VT, dl, Mul1, Mul2, LowMask); - - // If we have a signed multiply but no PMULDQ fix up the high parts of a - // unsigned multiply. 
- if (IsSigned && !Subtarget.hasSSE41()) { - SDValue ShAmt = DAG.getConstant( - 31, dl, - DAG.getTargetLoweringInfo().getShiftAmountTy(VT, DAG.getDataLayout())); - SDValue T1 = DAG.getNode(ISD::AND, dl, VT, - DAG.getNode(ISD::SRA, dl, VT, Op0, ShAmt), Op1); - SDValue T2 = DAG.getNode(ISD::AND, dl, VT, - DAG.getNode(ISD::SRA, dl, VT, Op1, ShAmt), Op0); - - SDValue Fixup = DAG.getNode(ISD::ADD, dl, VT, T1, T2); - Highs = DAG.getNode(ISD::SUB, dl, VT, Highs, Fixup); - } - - // The first result of MUL_LOHI is actually the low value, followed by the - // high value. - SDValue Ops[] = {Lows, Highs}; - return DAG.getMergeValues(Ops, dl); -} - // Return true if the required (according to Opcode) shift-imm form is natively // supported by the Subtarget static bool SupportedVectorShiftWithImm(MVT VT, const X86Subtarget &Subtarget, @@ -23042,9 +24111,7 @@ static SDValue LowerScalarImmediateShift(SDValue Op, SelectionDAG &DAG, SDLoc dl(Op); SDValue R = Op.getOperand(0); SDValue Amt = Op.getOperand(1); - - unsigned X86Opc = (Op.getOpcode() == ISD::SHL) ? X86ISD::VSHLI : - (Op.getOpcode() == ISD::SRL) ? X86ISD::VSRLI : X86ISD::VSRAI; + unsigned X86Opc = getTargetVShiftUniformOpcode(Op.getOpcode(), false); auto ArithmeticShiftRight64 = [&](uint64_t ShiftAmt) { assert((VT == MVT::v2i64 || VT == MVT::v4i64) && "Unexpected SRA type"); @@ -23055,8 +24122,7 @@ static SDValue LowerScalarImmediateShift(SDValue Op, SelectionDAG &DAG, if (ShiftAmt == 63 && Subtarget.hasSSE42()) { assert((VT != MVT::v4i64 || Subtarget.hasInt256()) && "Unsupported PCMPGT op"); - return DAG.getNode(X86ISD::PCMPGT, dl, VT, - getZeroVector(VT, Subtarget, DAG, dl), R); + return DAG.getNode(X86ISD::PCMPGT, dl, VT, DAG.getConstant(0, dl, VT), R); } if (ShiftAmt >= 32) { @@ -23071,7 +24137,7 @@ static SDValue LowerScalarImmediateShift(SDValue Op, SelectionDAG &DAG, Ex = DAG.getVectorShuffle(ExVT, dl, Upper, Lower, {9, 1, 11, 3, 13, 5, 15, 7}); } else { - // SRA upper i32, SHL whole i64 and select lower i32. + // SRA upper i32, SRL whole i64 and select lower i32. SDValue Upper = getTargetVShiftByConstNode(X86ISD::VSRAI, dl, ExVT, Ex, ShiftAmt, DAG); SDValue Lower = @@ -23087,199 +24153,123 @@ static SDValue LowerScalarImmediateShift(SDValue Op, SelectionDAG &DAG, }; // Optimize shl/srl/sra with constant shift amount. - if (auto *BVAmt = dyn_cast<BuildVectorSDNode>(Amt)) { - if (auto *ShiftConst = BVAmt->getConstantSplatNode()) { - uint64_t ShiftAmt = ShiftConst->getZExtValue(); - - if (SupportedVectorShiftWithImm(VT, Subtarget, Op.getOpcode())) - return getTargetVShiftByConstNode(X86Opc, dl, VT, R, ShiftAmt, DAG); - - // i64 SRA needs to be performed as partial shifts. 
- if (((!Subtarget.hasXOP() && VT == MVT::v2i64) || - (Subtarget.hasInt256() && VT == MVT::v4i64)) && - Op.getOpcode() == ISD::SRA) - return ArithmeticShiftRight64(ShiftAmt); - - if (VT == MVT::v16i8 || - (Subtarget.hasInt256() && VT == MVT::v32i8) || - VT == MVT::v64i8) { - unsigned NumElts = VT.getVectorNumElements(); - MVT ShiftVT = MVT::getVectorVT(MVT::i16, NumElts / 2); - - // Simple i8 add case - if (Op.getOpcode() == ISD::SHL && ShiftAmt == 1) - return DAG.getNode(ISD::ADD, dl, VT, R, R); - - // ashr(R, 7) === cmp_slt(R, 0) - if (Op.getOpcode() == ISD::SRA && ShiftAmt == 7) { - SDValue Zeros = getZeroVector(VT, Subtarget, DAG, dl); - if (VT.is512BitVector()) { - assert(VT == MVT::v64i8 && "Unexpected element type!"); - SDValue CMP = DAG.getSetCC(dl, MVT::v64i1, Zeros, R, - ISD::SETGT); - return DAG.getNode(ISD::SIGN_EXTEND, dl, VT, CMP); - } - return DAG.getNode(X86ISD::PCMPGT, dl, VT, Zeros, R); - } + APInt APIntShiftAmt; + if (!isConstantSplat(Amt, APIntShiftAmt)) + return SDValue(); + uint64_t ShiftAmt = APIntShiftAmt.getZExtValue(); - // XOP can shift v16i8 directly instead of as shift v8i16 + mask. - if (VT == MVT::v16i8 && Subtarget.hasXOP()) - return SDValue(); + if (SupportedVectorShiftWithImm(VT, Subtarget, Op.getOpcode())) + return getTargetVShiftByConstNode(X86Opc, dl, VT, R, ShiftAmt, DAG); - if (Op.getOpcode() == ISD::SHL) { - // Make a large shift. - SDValue SHL = getTargetVShiftByConstNode(X86ISD::VSHLI, dl, ShiftVT, - R, ShiftAmt, DAG); - SHL = DAG.getBitcast(VT, SHL); - // Zero out the rightmost bits. - return DAG.getNode(ISD::AND, dl, VT, SHL, - DAG.getConstant(uint8_t(-1U << ShiftAmt), dl, VT)); - } - if (Op.getOpcode() == ISD::SRL) { - // Make a large shift. - SDValue SRL = getTargetVShiftByConstNode(X86ISD::VSRLI, dl, ShiftVT, - R, ShiftAmt, DAG); - SRL = DAG.getBitcast(VT, SRL); - // Zero out the leftmost bits. - return DAG.getNode(ISD::AND, dl, VT, SRL, - DAG.getConstant(uint8_t(-1U) >> ShiftAmt, dl, VT)); - } - if (Op.getOpcode() == ISD::SRA) { - // ashr(R, Amt) === sub(xor(lshr(R, Amt), Mask), Mask) - SDValue Res = DAG.getNode(ISD::SRL, dl, VT, R, Amt); - - SDValue Mask = DAG.getConstant(128 >> ShiftAmt, dl, VT); - Res = DAG.getNode(ISD::XOR, dl, VT, Res, Mask); - Res = DAG.getNode(ISD::SUB, dl, VT, Res, Mask); - return Res; - } - llvm_unreachable("Unknown shift opcode."); - } - } - } + // i64 SRA needs to be performed as partial shifts. + if (((!Subtarget.hasXOP() && VT == MVT::v2i64) || + (Subtarget.hasInt256() && VT == MVT::v4i64)) && + Op.getOpcode() == ISD::SRA) + return ArithmeticShiftRight64(ShiftAmt); - // Check cases (mainly 32-bit) where i64 is expanded into high and low parts. - // TODO: Replace constant extraction with getTargetConstantBitsFromNode. - if (!Subtarget.hasXOP() && - (VT == MVT::v2i64 || (Subtarget.hasInt256() && VT == MVT::v4i64) || - (Subtarget.hasAVX512() && VT == MVT::v8i64))) { + if (VT == MVT::v16i8 || (Subtarget.hasInt256() && VT == MVT::v32i8) || + VT == MVT::v64i8) { + unsigned NumElts = VT.getVectorNumElements(); + MVT ShiftVT = MVT::getVectorVT(MVT::i16, NumElts / 2); - // AVX1 targets maybe extracting a 128-bit vector from a 256-bit constant. 
- unsigned SubVectorScale = 1; - if (Amt.getOpcode() == ISD::EXTRACT_SUBVECTOR) { - SubVectorScale = - Amt.getOperand(0).getValueSizeInBits() / Amt.getValueSizeInBits(); - Amt = Amt.getOperand(0); - } + // Simple i8 add case + if (Op.getOpcode() == ISD::SHL && ShiftAmt == 1) + return DAG.getNode(ISD::ADD, dl, VT, R, R); - // Peek through any splat that was introduced for i64 shift vectorization. - int SplatIndex = -1; - if (ShuffleVectorSDNode *SVN = dyn_cast<ShuffleVectorSDNode>(Amt.getNode())) - if (SVN->isSplat()) { - SplatIndex = SVN->getSplatIndex(); - Amt = Amt.getOperand(0); - assert(SplatIndex < (int)VT.getVectorNumElements() && - "Splat shuffle referencing second operand"); + // ashr(R, 7) === cmp_slt(R, 0) + if (Op.getOpcode() == ISD::SRA && ShiftAmt == 7) { + SDValue Zeros = DAG.getConstant(0, dl, VT); + if (VT.is512BitVector()) { + assert(VT == MVT::v64i8 && "Unexpected element type!"); + SDValue CMP = DAG.getSetCC(dl, MVT::v64i1, Zeros, R, ISD::SETGT); + return DAG.getNode(ISD::SIGN_EXTEND, dl, VT, CMP); } + return DAG.getNode(X86ISD::PCMPGT, dl, VT, Zeros, R); + } - if (Amt.getOpcode() != ISD::BITCAST || - Amt.getOperand(0).getOpcode() != ISD::BUILD_VECTOR) + // XOP can shift v16i8 directly instead of as shift v8i16 + mask. + if (VT == MVT::v16i8 && Subtarget.hasXOP()) return SDValue(); - Amt = Amt.getOperand(0); - unsigned Ratio = Amt.getSimpleValueType().getVectorNumElements() / - (SubVectorScale * VT.getVectorNumElements()); - unsigned RatioInLog2 = Log2_32_Ceil(Ratio); - uint64_t ShiftAmt = 0; - unsigned BaseOp = (SplatIndex < 0 ? 0 : SplatIndex * Ratio); - for (unsigned i = 0; i != Ratio; ++i) { - ConstantSDNode *C = dyn_cast<ConstantSDNode>(Amt.getOperand(i + BaseOp)); - if (!C) - return SDValue(); - // 6 == Log2(64) - ShiftAmt |= C->getZExtValue() << (i * (1 << (6 - RatioInLog2))); - } - - // Check remaining shift amounts (if not a splat). - if (SplatIndex < 0) { - for (unsigned i = Ratio; i != Amt.getNumOperands(); i += Ratio) { - uint64_t ShAmt = 0; - for (unsigned j = 0; j != Ratio; ++j) { - ConstantSDNode *C = dyn_cast<ConstantSDNode>(Amt.getOperand(i + j)); - if (!C) - return SDValue(); - // 6 == Log2(64) - ShAmt |= C->getZExtValue() << (j * (1 << (6 - RatioInLog2))); - } - if (ShAmt != ShiftAmt) - return SDValue(); - } + if (Op.getOpcode() == ISD::SHL) { + // Make a large shift. + SDValue SHL = getTargetVShiftByConstNode(X86ISD::VSHLI, dl, ShiftVT, R, + ShiftAmt, DAG); + SHL = DAG.getBitcast(VT, SHL); + // Zero out the rightmost bits. + return DAG.getNode(ISD::AND, dl, VT, SHL, + DAG.getConstant(uint8_t(-1U << ShiftAmt), dl, VT)); + } + if (Op.getOpcode() == ISD::SRL) { + // Make a large shift. + SDValue SRL = getTargetVShiftByConstNode(X86ISD::VSRLI, dl, ShiftVT, R, + ShiftAmt, DAG); + SRL = DAG.getBitcast(VT, SRL); + // Zero out the leftmost bits. 
+ return DAG.getNode(ISD::AND, dl, VT, SRL, + DAG.getConstant(uint8_t(-1U) >> ShiftAmt, dl, VT)); + } + if (Op.getOpcode() == ISD::SRA) { + // ashr(R, Amt) === sub(xor(lshr(R, Amt), Mask), Mask) + SDValue Res = DAG.getNode(ISD::SRL, dl, VT, R, Amt); + + SDValue Mask = DAG.getConstant(128 >> ShiftAmt, dl, VT); + Res = DAG.getNode(ISD::XOR, dl, VT, Res, Mask); + Res = DAG.getNode(ISD::SUB, dl, VT, Res, Mask); + return Res; } - - if (SupportedVectorShiftWithImm(VT, Subtarget, Op.getOpcode())) - return getTargetVShiftByConstNode(X86Opc, dl, VT, R, ShiftAmt, DAG); - - if (Op.getOpcode() == ISD::SRA) - return ArithmeticShiftRight64(ShiftAmt); + llvm_unreachable("Unknown shift opcode."); } return SDValue(); } -// Determine if V is a splat value, and return the scalar. -static SDValue IsSplatValue(MVT VT, SDValue V, const SDLoc &dl, - SelectionDAG &DAG, const X86Subtarget &Subtarget, - unsigned Opcode) { - V = peekThroughEXTRACT_SUBVECTORs(V); - - // Check if this is a splat build_vector node. - if (BuildVectorSDNode *BV = dyn_cast<BuildVectorSDNode>(V)) { - SDValue SplatAmt = BV->getSplatValue(); - if (SplatAmt && SplatAmt.isUndef()) - return SDValue(); - return SplatAmt; - } - - // Check for SUB(SPLAT_BV, SPLAT) cases from rotate patterns. - if (V.getOpcode() == ISD::SUB && - !SupportedVectorVarShift(VT, Subtarget, Opcode)) { - SDValue LHS = peekThroughEXTRACT_SUBVECTORs(V.getOperand(0)); - SDValue RHS = peekThroughEXTRACT_SUBVECTORs(V.getOperand(1)); +// If V is a splat value, return the source vector and splat index; +static SDValue IsSplatVector(SDValue V, int &SplatIdx, SelectionDAG &DAG) { + V = peekThroughEXTRACT_SUBVECTORs(V); - // Ensure that the corresponding splat BV element is not UNDEF. - BitVector UndefElts; - BuildVectorSDNode *BV0 = dyn_cast<BuildVectorSDNode>(LHS); - ShuffleVectorSDNode *SVN1 = dyn_cast<ShuffleVectorSDNode>(RHS); - if (BV0 && SVN1 && BV0->getSplatValue(&UndefElts) && SVN1->isSplat()) { - unsigned SplatIdx = (unsigned)SVN1->getSplatIndex(); - if (!UndefElts[SplatIdx]) - return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, - VT.getVectorElementType(), V, - DAG.getIntPtrConstant(SplatIdx, dl)); + EVT VT = V.getValueType(); + unsigned Opcode = V.getOpcode(); + switch (Opcode) { + default: { + APInt UndefElts; + APInt DemandedElts = APInt::getAllOnesValue(VT.getVectorNumElements()); + if (DAG.isSplatValue(V, DemandedElts, UndefElts)) { + // Handle case where all demanded elements are UNDEF. + if (DemandedElts.isSubsetOf(UndefElts)) { + SplatIdx = 0; + return DAG.getUNDEF(VT); + } + SplatIdx = (UndefElts & DemandedElts).countTrailingOnes(); + return V; } + break; } - - // Check if this is a shuffle node doing a splat. - ShuffleVectorSDNode *SVN = dyn_cast<ShuffleVectorSDNode>(V); - if (!SVN || !SVN->isSplat()) - return SDValue(); - - unsigned SplatIdx = (unsigned)SVN->getSplatIndex(); - SDValue InVec = V.getOperand(0); - if (InVec.getOpcode() == ISD::BUILD_VECTOR) { - assert((SplatIdx < VT.getVectorNumElements()) && - "Unexpected shuffle index found!"); - return InVec.getOperand(SplatIdx); - } else if (InVec.getOpcode() == ISD::INSERT_VECTOR_ELT) { - if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(InVec.getOperand(2))) - if (C->getZExtValue() == SplatIdx) - return InVec.getOperand(1); + case ISD::VECTOR_SHUFFLE: { + // Check if this is a shuffle node doing a splat. + // TODO - remove this and rely purely on SelectionDAG::isSplatValue, + // getTargetVShiftNode currently struggles without the splat source. 
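The constant-amount vXi8 path above shifts in 16-bit lanes and then repairs each byte; a scalar sketch of the three per-element identities, assuming the amount K is in [0, 7] (helper names illustrative; in scalar form the masks are redundant, but they mirror the per-byte AND applied after the wide shift):

#include <cstdint>

// SHL: the AND mirrors clearing the K low bits that the 16-bit lane shift
// pulls in from the neighbouring byte.
static inline uint8_t shl_i8(uint8_t X, unsigned K) {
  return uint8_t((X << K) & uint8_t(0xFFu << K));
}
// SRL: likewise, the AND clears the K high bits leaked from the other byte.
static inline uint8_t srl_i8(uint8_t X, unsigned K) {
  return uint8_t((X >> K) & uint8_t(0xFFu >> K));
}
// SRA: rebuild the arithmetic shift from the logical one with the xor/sub
// sign-extension identity; Mask is the relocated sign bit (128 >> K).
static inline int8_t sra_i8(uint8_t X, unsigned K) {
  uint8_t Mask = uint8_t(0x80u >> K);
  return int8_t(uint8_t((uint8_t(X >> K) ^ Mask) - Mask));
}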
+ auto *SVN = cast<ShuffleVectorSDNode>(V); + if (!SVN->isSplat()) + break; + int Idx = SVN->getSplatIndex(); + int NumElts = V.getValueType().getVectorNumElements(); + SplatIdx = Idx % NumElts; + return V.getOperand(Idx / NumElts); } + } + + return SDValue(); +} - // Avoid introducing an extract element from a shuffle. - return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, - VT.getVectorElementType(), InVec, - DAG.getIntPtrConstant(SplatIdx, dl)); +static SDValue GetSplatValue(SDValue V, const SDLoc &dl, + SelectionDAG &DAG) { + int SplatIdx; + if (SDValue SrcVector = IsSplatVector(V, SplatIdx, DAG)) + return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, + SrcVector.getValueType().getScalarType(), SrcVector, + DAG.getIntPtrConstant(SplatIdx, dl)); + return SDValue(); } static SDValue LowerScalarVariableShift(SDValue Op, SelectionDAG &DAG, @@ -23289,17 +24279,11 @@ static SDValue LowerScalarVariableShift(SDValue Op, SelectionDAG &DAG, SDValue R = Op.getOperand(0); SDValue Amt = Op.getOperand(1); unsigned Opcode = Op.getOpcode(); + unsigned X86OpcI = getTargetVShiftUniformOpcode(Opcode, false); + unsigned X86OpcV = getTargetVShiftUniformOpcode(Opcode, true); - unsigned X86OpcI = (Opcode == ISD::SHL) ? X86ISD::VSHLI : - (Opcode == ISD::SRL) ? X86ISD::VSRLI : X86ISD::VSRAI; - - unsigned X86OpcV = (Opcode == ISD::SHL) ? X86ISD::VSHL : - (Opcode == ISD::SRL) ? X86ISD::VSRL : X86ISD::VSRA; - - Amt = peekThroughEXTRACT_SUBVECTORs(Amt); - - if (SupportedVectorShiftWithBaseAmnt(VT, Subtarget, Opcode)) { - if (SDValue BaseShAmt = IsSplatValue(VT, Amt, dl, DAG, Subtarget, Opcode)) { + if (SDValue BaseShAmt = GetSplatValue(Amt, dl, DAG)) { + if (SupportedVectorShiftWithBaseAmnt(VT, Subtarget, Opcode)) { MVT EltVT = VT.getVectorElementType(); assert(EltVT.bitsLE(MVT::i64) && "Unexpected element type!"); if (EltVT != MVT::i64 && EltVT.bitsGT(MVT::i32)) @@ -23309,6 +24293,50 @@ static SDValue LowerScalarVariableShift(SDValue Op, SelectionDAG &DAG, return getTargetVShiftNode(X86OpcI, dl, VT, R, BaseShAmt, Subtarget, DAG); } + + // vXi8 shifts - shift as v8i16 + mask result. + if (((VT == MVT::v16i8 && !Subtarget.canExtendTo512DQ()) || + (VT == MVT::v32i8 && !Subtarget.canExtendTo512BW()) || + VT == MVT::v64i8) && + !Subtarget.hasXOP()) { + unsigned NumElts = VT.getVectorNumElements(); + MVT ExtVT = MVT::getVectorVT(MVT::i16, NumElts / 2); + if (SupportedVectorShiftWithBaseAmnt(ExtVT, Subtarget, Opcode)) { + unsigned LogicalOp = (Opcode == ISD::SHL ? ISD::SHL : ISD::SRL); + unsigned LogicalX86Op = getTargetVShiftUniformOpcode(LogicalOp, false); + BaseShAmt = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i32, BaseShAmt); + + // Create the mask using vXi16 shifts. For shift-rights we need to move + // the upper byte down before splatting the vXi8 mask. + SDValue BitMask = DAG.getConstant(-1, dl, ExtVT); + BitMask = getTargetVShiftNode(LogicalX86Op, dl, ExtVT, BitMask, + BaseShAmt, Subtarget, DAG); + if (Opcode != ISD::SHL) + BitMask = getTargetVShiftByConstNode(LogicalX86Op, dl, ExtVT, BitMask, + 8, DAG); + BitMask = DAG.getBitcast(VT, BitMask); + BitMask = DAG.getVectorShuffle(VT, dl, BitMask, BitMask, + SmallVector<int, 64>(NumElts, 0)); + + SDValue Res = getTargetVShiftNode(LogicalX86Op, dl, ExtVT, + DAG.getBitcast(ExtVT, R), BaseShAmt, + Subtarget, DAG); + Res = DAG.getBitcast(VT, Res); + Res = DAG.getNode(ISD::AND, dl, VT, Res, BitMask); + + if (Opcode == ISD::SRA) { + // ashr(R, Amt) === sub(xor(lshr(R, Amt), SignMask), SignMask) + // SignMask = lshr(SignBit, Amt) - safe to do this with PSRLW. 
+ SDValue SignMask = DAG.getConstant(0x8080, dl, ExtVT); + SignMask = getTargetVShiftNode(LogicalX86Op, dl, ExtVT, SignMask, + BaseShAmt, Subtarget, DAG); + SignMask = DAG.getBitcast(VT, SignMask); + Res = DAG.getNode(ISD::XOR, dl, VT, Res, SignMask); + Res = DAG.getNode(ISD::SUB, dl, VT, Res, SignMask); + } + return Res; + } + } } // Check cases (mainly 32-bit) where i64 is expanded into high and low parts. @@ -23379,7 +24407,7 @@ static SDValue convertShiftLeftToScale(SDValue Amt, const SDLoc &dl, // AVX2 can more effectively perform this as a zext/trunc to/from v8i32. if (VT == MVT::v8i16 && !Subtarget.hasAVX2()) { - SDValue Z = getZeroVector(VT, Subtarget, DAG, dl); + SDValue Z = DAG.getConstant(0, dl, VT); SDValue Lo = DAG.getBitcast(MVT::v4i32, getUnpackl(DAG, dl, VT, Amt, Z)); SDValue Hi = DAG.getBitcast(MVT::v4i32, getUnpackh(DAG, dl, VT, Amt, Z)); Lo = convertShiftLeftToScale(Lo, dl, Subtarget, DAG); @@ -23401,8 +24429,13 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget &Subtarget, SDLoc dl(Op); SDValue R = Op.getOperand(0); SDValue Amt = Op.getOperand(1); + unsigned EltSizeInBits = VT.getScalarSizeInBits(); bool ConstantAmt = ISD::isBuildVectorOfConstantSDNodes(Amt.getNode()); + unsigned Opc = Op.getOpcode(); + unsigned X86OpcV = getTargetVShiftUniformOpcode(Opc, true); + unsigned X86OpcI = getTargetVShiftUniformOpcode(Opc, false); + assert(VT.isVector() && "Custom lowering only for vector shifts!"); assert(Subtarget.hasSSE2() && "Only custom lower when we have SSE2!"); @@ -23412,31 +24445,31 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget &Subtarget, if (SDValue V = LowerScalarVariableShift(Op, DAG, Subtarget)) return V; - if (SupportedVectorVarShift(VT, Subtarget, Op.getOpcode())) + if (SupportedVectorVarShift(VT, Subtarget, Opc)) return Op; // XOP has 128-bit variable logical/arithmetic shifts. // +ve/-ve Amt = shift left/right. if (Subtarget.hasXOP() && (VT == MVT::v2i64 || VT == MVT::v4i32 || VT == MVT::v8i16 || VT == MVT::v16i8)) { - if (Op.getOpcode() == ISD::SRL || Op.getOpcode() == ISD::SRA) { + if (Opc == ISD::SRL || Opc == ISD::SRA) { SDValue Zero = DAG.getConstant(0, dl, VT); Amt = DAG.getNode(ISD::SUB, dl, VT, Zero, Amt); } - if (Op.getOpcode() == ISD::SHL || Op.getOpcode() == ISD::SRL) + if (Opc == ISD::SHL || Opc == ISD::SRL) return DAG.getNode(X86ISD::VPSHL, dl, VT, R, Amt); - if (Op.getOpcode() == ISD::SRA) + if (Opc == ISD::SRA) return DAG.getNode(X86ISD::VPSHA, dl, VT, R, Amt); } // 2i64 vector logical shifts can efficiently avoid scalarization - do the // shifts per-lane and then shuffle the partial results back together. - if (VT == MVT::v2i64 && Op.getOpcode() != ISD::SRA) { + if (VT == MVT::v2i64 && Opc != ISD::SRA) { // Splat the shift amounts so the scalar shifts above will catch it. 
SDValue Amt0 = DAG.getVectorShuffle(VT, dl, Amt, Amt, {0, 0}); SDValue Amt1 = DAG.getVectorShuffle(VT, dl, Amt, Amt, {1, 1}); - SDValue R0 = DAG.getNode(Op->getOpcode(), dl, VT, R, Amt0); - SDValue R1 = DAG.getNode(Op->getOpcode(), dl, VT, R, Amt1); + SDValue R0 = DAG.getNode(Opc, dl, VT, R, Amt0); + SDValue R1 = DAG.getNode(Opc, dl, VT, R, Amt1); return DAG.getVectorShuffle(VT, dl, R0, R1, {0, 3}); } @@ -23444,7 +24477,7 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget &Subtarget, // M = lshr(SIGN_MASK, Amt) // ashr(R, Amt) === sub(xor(lshr(R, Amt), M), M) if ((VT == MVT::v2i64 || (VT == MVT::v4i64 && Subtarget.hasInt256())) && - Op.getOpcode() == ISD::SRA) { + Opc == ISD::SRA) { SDValue S = DAG.getConstant(APInt::getSignMask(64), dl, VT); SDValue M = DAG.getNode(ISD::SRL, dl, VT, S, Amt); R = DAG.getNode(ISD::SRL, dl, VT, R, Amt); @@ -23489,36 +24522,34 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget &Subtarget, // Only perform this blend if we can perform it without loading a mask. if (ShuffleMask.size() == NumElts && Amt1 && Amt2 && - isa<ConstantSDNode>(Amt1) && isa<ConstantSDNode>(Amt2) && (VT != MVT::v16i16 || is128BitLaneRepeatedShuffleMask(VT, ShuffleMask)) && - (VT == MVT::v4i32 || Subtarget.hasSSE41() || - Op.getOpcode() != ISD::SHL || canWidenShuffleElements(ShuffleMask))) { - SDValue Splat1 = - DAG.getConstant(cast<ConstantSDNode>(Amt1)->getAPIntValue(), dl, VT); - SDValue Shift1 = DAG.getNode(Op->getOpcode(), dl, VT, R, Splat1); - SDValue Splat2 = - DAG.getConstant(cast<ConstantSDNode>(Amt2)->getAPIntValue(), dl, VT); - SDValue Shift2 = DAG.getNode(Op->getOpcode(), dl, VT, R, Splat2); - return DAG.getVectorShuffle(VT, dl, Shift1, Shift2, ShuffleMask); + (VT == MVT::v4i32 || Subtarget.hasSSE41() || Opc != ISD::SHL || + canWidenShuffleElements(ShuffleMask))) { + auto *Cst1 = dyn_cast<ConstantSDNode>(Amt1); + auto *Cst2 = dyn_cast<ConstantSDNode>(Amt2); + if (Cst1 && Cst2 && Cst1->getAPIntValue().ult(EltSizeInBits) && + Cst2->getAPIntValue().ult(EltSizeInBits)) { + SDValue Shift1 = getTargetVShiftByConstNode(X86OpcI, dl, VT, R, + Cst1->getZExtValue(), DAG); + SDValue Shift2 = getTargetVShiftByConstNode(X86OpcI, dl, VT, R, + Cst2->getZExtValue(), DAG); + return DAG.getVectorShuffle(VT, dl, Shift1, Shift2, ShuffleMask); + } } } // If possible, lower this packed shift into a vector multiply instead of // expanding it into a sequence of scalar shifts. - if (Op.getOpcode() == ISD::SHL) + if (Opc == ISD::SHL) if (SDValue Scale = convertShiftLeftToScale(Amt, dl, Subtarget, DAG)) return DAG.getNode(ISD::MUL, dl, VT, R, Scale); - // Constant ISD::SRL can be performed efficiently on vXi8/vXi16 vectors as we + // Constant ISD::SRL can be performed efficiently on vXi16 vectors as we // can replace with ISD::MULHU, creating scale factor from (NumEltBits - Amt). - // TODO: Improve support for the shift by zero special case. 
- if (Op.getOpcode() == ISD::SRL && ConstantAmt && - ((Subtarget.hasSSE41() && VT == MVT::v8i16) || - DAG.isKnownNeverZero(Amt)) && - (VT == MVT::v16i8 || VT == MVT::v8i16 || - ((VT == MVT::v32i8 || VT == MVT::v16i16) && Subtarget.hasInt256()))) { - SDValue EltBits = DAG.getConstant(VT.getScalarSizeInBits(), dl, VT); + if (Opc == ISD::SRL && ConstantAmt && + (VT == MVT::v8i16 || (VT == MVT::v16i16 && Subtarget.hasInt256()))) { + SDValue EltBits = DAG.getConstant(EltSizeInBits, dl, VT); SDValue RAmt = DAG.getNode(ISD::SUB, dl, VT, EltBits, Amt); if (SDValue Scale = convertShiftLeftToScale(RAmt, dl, Subtarget, DAG)) { SDValue Zero = DAG.getConstant(0, dl, VT); @@ -23528,13 +24559,36 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget &Subtarget, } } + // Constant ISD::SRA can be performed efficiently on vXi16 vectors as we + // can replace with ISD::MULHS, creating scale factor from (NumEltBits - Amt). + // TODO: Special case handling for shift by 0/1, really we can afford either + // of these cases in pre-SSE41/XOP/AVX512 but not both. + if (Opc == ISD::SRA && ConstantAmt && + (VT == MVT::v8i16 || (VT == MVT::v16i16 && Subtarget.hasInt256())) && + ((Subtarget.hasSSE41() && !Subtarget.hasXOP() && + !Subtarget.hasAVX512()) || + DAG.isKnownNeverZero(Amt))) { + SDValue EltBits = DAG.getConstant(EltSizeInBits, dl, VT); + SDValue RAmt = DAG.getNode(ISD::SUB, dl, VT, EltBits, Amt); + if (SDValue Scale = convertShiftLeftToScale(RAmt, dl, Subtarget, DAG)) { + SDValue Amt0 = + DAG.getSetCC(dl, VT, Amt, DAG.getConstant(0, dl, VT), ISD::SETEQ); + SDValue Amt1 = + DAG.getSetCC(dl, VT, Amt, DAG.getConstant(1, dl, VT), ISD::SETEQ); + SDValue Sra1 = + getTargetVShiftByConstNode(X86ISD::VSRAI, dl, VT, R, 1, DAG); + SDValue Res = DAG.getNode(ISD::MULHS, dl, VT, R, Scale); + Res = DAG.getSelect(dl, VT, Amt0, R, Res); + return DAG.getSelect(dl, VT, Amt1, Sra1, Res); + } + } + // v4i32 Non Uniform Shifts. // If the shift amount is constant we can shift each lane using the SSE2 // immediate shifts, else we need to zero-extend each lane to the lower i64 // and shift using the SSE2 variable shifts. // The separate results can then be blended together. if (VT == MVT::v4i32) { - unsigned Opc = Op.getOpcode(); SDValue Amt0, Amt1, Amt2, Amt3; if (ConstantAmt) { Amt0 = DAG.getVectorShuffle(VT, dl, Amt, DAG.getUNDEF(VT), {0, 0, 0, 0}); @@ -23542,26 +24596,12 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget &Subtarget, Amt2 = DAG.getVectorShuffle(VT, dl, Amt, DAG.getUNDEF(VT), {2, 2, 2, 2}); Amt3 = DAG.getVectorShuffle(VT, dl, Amt, DAG.getUNDEF(VT), {3, 3, 3, 3}); } else { - // ISD::SHL is handled above but we include it here for completeness. - switch (Opc) { - default: - llvm_unreachable("Unknown target vector shift node"); - case ISD::SHL: - Opc = X86ISD::VSHL; - break; - case ISD::SRL: - Opc = X86ISD::VSRL; - break; - case ISD::SRA: - Opc = X86ISD::VSRA; - break; - } // The SSE2 shifts use the lower i64 as the same shift amount for // all lanes and the upper i64 is ignored. On AVX we're better off // just zero-extending, but for SSE just duplicating the top 16-bits is // cheaper and has the same effect for out of range values. 
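The two vXi16 blocks above turn a constant shift into a high-half multiply: x >> k equals the upper 16 bits of x * 2^(16-k), with k == 0 (and k == 1 in the signed case, where 2^15 is not representable as a positive i16 scale) peeled off through selects. A scalar sketch of the unsigned identity for 1 <= K <= 15 (name illustrative):

#include <cstdint>

// Logical shift right rebuilt as an unsigned high-half multiply.
static inline uint16_t srl_via_mulhu(uint16_t X, unsigned K) {
  uint16_t Scale = uint16_t(1u << (16 - K)); // K == 0 would need 2^16, hence the select above.
  return uint16_t((uint32_t(X) * uint32_t(Scale)) >> 16);
}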
if (Subtarget.hasAVX()) { - SDValue Z = getZeroVector(VT, Subtarget, DAG, dl); + SDValue Z = DAG.getConstant(0, dl, VT); Amt0 = DAG.getVectorShuffle(VT, dl, Amt, Z, {0, 4, -1, -1}); Amt1 = DAG.getVectorShuffle(VT, dl, Amt, Z, {1, 5, -1, -1}); Amt2 = DAG.getVectorShuffle(VT, dl, Amt, Z, {2, 6, -1, -1}); @@ -23581,10 +24621,11 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget &Subtarget, } } - SDValue R0 = DAG.getNode(Opc, dl, VT, R, DAG.getBitcast(VT, Amt0)); - SDValue R1 = DAG.getNode(Opc, dl, VT, R, DAG.getBitcast(VT, Amt1)); - SDValue R2 = DAG.getNode(Opc, dl, VT, R, DAG.getBitcast(VT, Amt2)); - SDValue R3 = DAG.getNode(Opc, dl, VT, R, DAG.getBitcast(VT, Amt3)); + unsigned ShOpc = ConstantAmt ? Opc : X86OpcV; + SDValue R0 = DAG.getNode(ShOpc, dl, VT, R, DAG.getBitcast(VT, Amt0)); + SDValue R1 = DAG.getNode(ShOpc, dl, VT, R, DAG.getBitcast(VT, Amt1)); + SDValue R2 = DAG.getNode(ShOpc, dl, VT, R, DAG.getBitcast(VT, Amt2)); + SDValue R3 = DAG.getNode(ShOpc, dl, VT, R, DAG.getBitcast(VT, Amt3)); // Merge the shifted lane results optimally with/without PBLENDW. // TODO - ideally shuffle combining would handle this. @@ -23611,19 +24652,66 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget &Subtarget, "Unexpected vector type"); MVT EvtSVT = Subtarget.hasBWI() ? MVT::i16 : MVT::i32; MVT ExtVT = MVT::getVectorVT(EvtSVT, VT.getVectorNumElements()); - unsigned ExtOpc = - Op.getOpcode() == ISD::SRA ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND; + unsigned ExtOpc = Opc == ISD::SRA ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND; R = DAG.getNode(ExtOpc, dl, ExtVT, R); Amt = DAG.getNode(ISD::ZERO_EXTEND, dl, ExtVT, Amt); return DAG.getNode(ISD::TRUNCATE, dl, VT, - DAG.getNode(Op.getOpcode(), dl, ExtVT, R, Amt)); + DAG.getNode(Opc, dl, ExtVT, R, Amt)); + } + + // Constant ISD::SRA/SRL can be performed efficiently on vXi8 vectors as we + // extend to vXi16 to perform a MUL scale effectively as a MUL_LOHI. + if (ConstantAmt && (Opc == ISD::SRA || Opc == ISD::SRL) && + (VT == MVT::v16i8 || VT == MVT::v64i8 || + (VT == MVT::v32i8 && Subtarget.hasInt256())) && + !Subtarget.hasXOP()) { + int NumElts = VT.getVectorNumElements(); + SDValue Cst8 = DAG.getConstant(8, dl, MVT::i8); + + // Extend constant shift amount to vXi16 (it doesn't matter if the type + // isn't legal). + MVT ExVT = MVT::getVectorVT(MVT::i16, NumElts); + Amt = DAG.getZExtOrTrunc(Amt, dl, ExVT); + Amt = DAG.getNode(ISD::SUB, dl, ExVT, DAG.getConstant(8, dl, ExVT), Amt); + Amt = DAG.getNode(ISD::SHL, dl, ExVT, DAG.getConstant(1, dl, ExVT), Amt); + assert(ISD::isBuildVectorOfConstantSDNodes(Amt.getNode()) && + "Constant build vector expected"); + + if (VT == MVT::v16i8 && Subtarget.hasInt256()) { + R = Opc == ISD::SRA ? 
DAG.getSExtOrTrunc(R, dl, ExVT) + : DAG.getZExtOrTrunc(R, dl, ExVT); + R = DAG.getNode(ISD::MUL, dl, ExVT, R, Amt); + R = DAG.getNode(X86ISD::VSRLI, dl, ExVT, R, Cst8); + return DAG.getZExtOrTrunc(R, dl, VT); + } + + SmallVector<SDValue, 16> LoAmt, HiAmt; + for (int i = 0; i != NumElts; i += 16) { + for (int j = 0; j != 8; ++j) { + LoAmt.push_back(Amt.getOperand(i + j)); + HiAmt.push_back(Amt.getOperand(i + j + 8)); + } + } + + MVT VT16 = MVT::getVectorVT(MVT::i16, NumElts / 2); + SDValue LoA = DAG.getBuildVector(VT16, dl, LoAmt); + SDValue HiA = DAG.getBuildVector(VT16, dl, HiAmt); + + SDValue LoR = DAG.getBitcast(VT16, getUnpackl(DAG, dl, VT, R, R)); + SDValue HiR = DAG.getBitcast(VT16, getUnpackh(DAG, dl, VT, R, R)); + LoR = DAG.getNode(X86OpcI, dl, VT16, LoR, Cst8); + HiR = DAG.getNode(X86OpcI, dl, VT16, HiR, Cst8); + LoR = DAG.getNode(ISD::MUL, dl, VT16, LoR, LoA); + HiR = DAG.getNode(ISD::MUL, dl, VT16, HiR, HiA); + LoR = DAG.getNode(X86ISD::VSRLI, dl, VT16, LoR, Cst8); + HiR = DAG.getNode(X86ISD::VSRLI, dl, VT16, HiR, Cst8); + return DAG.getNode(X86ISD::PACKUS, dl, VT, LoR, HiR); } if (VT == MVT::v16i8 || (VT == MVT::v32i8 && Subtarget.hasInt256() && !Subtarget.hasXOP()) || (VT == MVT::v64i8 && Subtarget.hasBWI())) { MVT ExtVT = MVT::getVectorVT(MVT::i16, VT.getVectorNumElements() / 2); - unsigned ShiftOpcode = Op->getOpcode(); auto SignBitSelect = [&](MVT SelVT, SDValue Sel, SDValue V0, SDValue V1) { if (VT.is512BitVector()) { @@ -23648,7 +24736,7 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget &Subtarget, // On pre-SSE41 targets we test for the sign bit by comparing to // zero - a negative value will set all bits of the lanes to true // and VSELECT uses that in its OR(AND(V0,C),AND(V1,~C)) lowering. - SDValue Z = getZeroVector(SelVT, Subtarget, DAG, dl); + SDValue Z = DAG.getConstant(0, dl, SelVT); SDValue C = DAG.getNode(X86ISD::PCMPGT, dl, SelVT, Z, Sel); return DAG.getSelect(dl, SelVT, C, V0, V1); }; @@ -23657,49 +24745,46 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget &Subtarget, // We can safely do this using i16 shifts as we're only interested in // the 3 lower bits of each byte. Amt = DAG.getBitcast(ExtVT, Amt); - Amt = DAG.getNode(ISD::SHL, dl, ExtVT, Amt, DAG.getConstant(5, dl, ExtVT)); + Amt = getTargetVShiftByConstNode(X86ISD::VSHLI, dl, ExtVT, Amt, 5, DAG); Amt = DAG.getBitcast(VT, Amt); - if (Op->getOpcode() == ISD::SHL || Op->getOpcode() == ISD::SRL) { + if (Opc == ISD::SHL || Opc == ISD::SRL) { // r = VSELECT(r, shift(r, 4), a); - SDValue M = - DAG.getNode(ShiftOpcode, dl, VT, R, DAG.getConstant(4, dl, VT)); + SDValue M = DAG.getNode(Opc, dl, VT, R, DAG.getConstant(4, dl, VT)); R = SignBitSelect(VT, Amt, M, R); // a += a Amt = DAG.getNode(ISD::ADD, dl, VT, Amt, Amt); // r = VSELECT(r, shift(r, 2), a); - M = DAG.getNode(ShiftOpcode, dl, VT, R, DAG.getConstant(2, dl, VT)); + M = DAG.getNode(Opc, dl, VT, R, DAG.getConstant(2, dl, VT)); R = SignBitSelect(VT, Amt, M, R); // a += a Amt = DAG.getNode(ISD::ADD, dl, VT, Amt, Amt); // return VSELECT(r, shift(r, 1), a); - M = DAG.getNode(ShiftOpcode, dl, VT, R, DAG.getConstant(1, dl, VT)); + M = DAG.getNode(Opc, dl, VT, R, DAG.getConstant(1, dl, VT)); R = SignBitSelect(VT, Amt, M, R); return R; } - if (Op->getOpcode() == ISD::SRA) { + if (Opc == ISD::SRA) { // For SRA we need to unpack each byte to the higher byte of a i16 vector // so we can correctly sign extend. We don't care what happens to the // lower byte. 
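The constant vXi8 block above applies the same multiply trick per byte: extend to 16 bits (sign- or zero-extending to match SRA/SRL), multiply by 2^(8-K), and keep the upper byte of the product. A scalar sketch for 0 <= K <= 7 (names illustrative):

#include <cstdint>

// Logical i8 shift right as a scaled multiply: the upper byte of
// X * 2^(8-K) is X >> K.
static inline uint8_t srl_i8_via_mul(uint8_t X, unsigned K) {
  uint16_t Scale = uint16_t(1u << (8 - K));
  return uint8_t((uint16_t(X) * Scale) >> 8);
}
// Arithmetic variant: sign-extend first, then take the upper byte of the
// wrapped 16-bit product.
static inline int8_t sra_i8_via_mul(int8_t X, unsigned K) {
  uint16_t Scale = uint16_t(1u << (8 - K));
  uint16_t Prod = uint16_t(uint16_t(int16_t(X)) * Scale); // i16 product, wrapped
  return int8_t(Prod >> 8);
}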
- SDValue ALo = DAG.getNode(X86ISD::UNPCKL, dl, VT, DAG.getUNDEF(VT), Amt); - SDValue AHi = DAG.getNode(X86ISD::UNPCKH, dl, VT, DAG.getUNDEF(VT), Amt); - SDValue RLo = DAG.getNode(X86ISD::UNPCKL, dl, VT, DAG.getUNDEF(VT), R); - SDValue RHi = DAG.getNode(X86ISD::UNPCKH, dl, VT, DAG.getUNDEF(VT), R); + SDValue ALo = getUnpackl(DAG, dl, VT, DAG.getUNDEF(VT), Amt); + SDValue AHi = getUnpackh(DAG, dl, VT, DAG.getUNDEF(VT), Amt); + SDValue RLo = getUnpackl(DAG, dl, VT, DAG.getUNDEF(VT), R); + SDValue RHi = getUnpackh(DAG, dl, VT, DAG.getUNDEF(VT), R); ALo = DAG.getBitcast(ExtVT, ALo); AHi = DAG.getBitcast(ExtVT, AHi); RLo = DAG.getBitcast(ExtVT, RLo); RHi = DAG.getBitcast(ExtVT, RHi); // r = VSELECT(r, shift(r, 4), a); - SDValue MLo = DAG.getNode(ShiftOpcode, dl, ExtVT, RLo, - DAG.getConstant(4, dl, ExtVT)); - SDValue MHi = DAG.getNode(ShiftOpcode, dl, ExtVT, RHi, - DAG.getConstant(4, dl, ExtVT)); + SDValue MLo = getTargetVShiftByConstNode(X86OpcI, dl, ExtVT, RLo, 4, DAG); + SDValue MHi = getTargetVShiftByConstNode(X86OpcI, dl, ExtVT, RHi, 4, DAG); RLo = SignBitSelect(ExtVT, ALo, MLo, RLo); RHi = SignBitSelect(ExtVT, AHi, MHi, RHi); @@ -23708,10 +24793,8 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget &Subtarget, AHi = DAG.getNode(ISD::ADD, dl, ExtVT, AHi, AHi); // r = VSELECT(r, shift(r, 2), a); - MLo = DAG.getNode(ShiftOpcode, dl, ExtVT, RLo, - DAG.getConstant(2, dl, ExtVT)); - MHi = DAG.getNode(ShiftOpcode, dl, ExtVT, RHi, - DAG.getConstant(2, dl, ExtVT)); + MLo = getTargetVShiftByConstNode(X86OpcI, dl, ExtVT, RLo, 2, DAG); + MHi = getTargetVShiftByConstNode(X86OpcI, dl, ExtVT, RHi, 2, DAG); RLo = SignBitSelect(ExtVT, ALo, MLo, RLo); RHi = SignBitSelect(ExtVT, AHi, MHi, RHi); @@ -23720,45 +24803,38 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget &Subtarget, AHi = DAG.getNode(ISD::ADD, dl, ExtVT, AHi, AHi); // r = VSELECT(r, shift(r, 1), a); - MLo = DAG.getNode(ShiftOpcode, dl, ExtVT, RLo, - DAG.getConstant(1, dl, ExtVT)); - MHi = DAG.getNode(ShiftOpcode, dl, ExtVT, RHi, - DAG.getConstant(1, dl, ExtVT)); + MLo = getTargetVShiftByConstNode(X86OpcI, dl, ExtVT, RLo, 1, DAG); + MHi = getTargetVShiftByConstNode(X86OpcI, dl, ExtVT, RHi, 1, DAG); RLo = SignBitSelect(ExtVT, ALo, MLo, RLo); RHi = SignBitSelect(ExtVT, AHi, MHi, RHi); // Logical shift the result back to the lower byte, leaving a zero upper - // byte - // meaning that we can safely pack with PACKUSWB. - RLo = - DAG.getNode(ISD::SRL, dl, ExtVT, RLo, DAG.getConstant(8, dl, ExtVT)); - RHi = - DAG.getNode(ISD::SRL, dl, ExtVT, RHi, DAG.getConstant(8, dl, ExtVT)); + // byte meaning that we can safely pack with PACKUSWB. 
+ RLo = getTargetVShiftByConstNode(X86ISD::VSRLI, dl, ExtVT, RLo, 8, DAG); + RHi = getTargetVShiftByConstNode(X86ISD::VSRLI, dl, ExtVT, RHi, 8, DAG); return DAG.getNode(X86ISD::PACKUS, dl, VT, RLo, RHi); } } if (Subtarget.hasInt256() && !Subtarget.hasXOP() && VT == MVT::v16i16) { MVT ExtVT = MVT::v8i32; - SDValue Z = getZeroVector(VT, Subtarget, DAG, dl); - SDValue ALo = DAG.getNode(X86ISD::UNPCKL, dl, VT, Amt, Z); - SDValue AHi = DAG.getNode(X86ISD::UNPCKH, dl, VT, Amt, Z); - SDValue RLo = DAG.getNode(X86ISD::UNPCKL, dl, VT, Z, R); - SDValue RHi = DAG.getNode(X86ISD::UNPCKH, dl, VT, Z, R); + SDValue Z = DAG.getConstant(0, dl, VT); + SDValue ALo = getUnpackl(DAG, dl, VT, Amt, Z); + SDValue AHi = getUnpackh(DAG, dl, VT, Amt, Z); + SDValue RLo = getUnpackl(DAG, dl, VT, Z, R); + SDValue RHi = getUnpackh(DAG, dl, VT, Z, R); ALo = DAG.getBitcast(ExtVT, ALo); AHi = DAG.getBitcast(ExtVT, AHi); RLo = DAG.getBitcast(ExtVT, RLo); RHi = DAG.getBitcast(ExtVT, RHi); - SDValue Lo = DAG.getNode(Op.getOpcode(), dl, ExtVT, RLo, ALo); - SDValue Hi = DAG.getNode(Op.getOpcode(), dl, ExtVT, RHi, AHi); - Lo = DAG.getNode(ISD::SRL, dl, ExtVT, Lo, DAG.getConstant(16, dl, ExtVT)); - Hi = DAG.getNode(ISD::SRL, dl, ExtVT, Hi, DAG.getConstant(16, dl, ExtVT)); + SDValue Lo = DAG.getNode(Opc, dl, ExtVT, RLo, ALo); + SDValue Hi = DAG.getNode(Opc, dl, ExtVT, RHi, AHi); + Lo = getTargetVShiftByConstNode(X86ISD::VSRLI, dl, ExtVT, Lo, 16, DAG); + Hi = getTargetVShiftByConstNode(X86ISD::VSRLI, dl, ExtVT, Hi, 16, DAG); return DAG.getNode(X86ISD::PACKUS, dl, VT, Lo, Hi); } if (VT == MVT::v8i16) { - unsigned ShiftOpcode = Op->getOpcode(); - // If we have a constant shift amount, the non-SSE41 path is best as // avoiding bitcasts make it easier to constant fold and reduce to PBLENDW. bool UseSSE41 = Subtarget.hasSSE41() && @@ -23778,7 +24854,7 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget &Subtarget, // set all bits of the lanes to true and VSELECT uses that in // its OR(AND(V0,C),AND(V1,~C)) lowering. SDValue C = - DAG.getNode(ISD::SRA, dl, VT, Sel, DAG.getConstant(15, dl, VT)); + getTargetVShiftByConstNode(X86ISD::VSRAI, dl, VT, Sel, 15, DAG); return DAG.getSelect(dl, VT, C, V0, V1); }; @@ -23788,42 +24864,42 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget &Subtarget, // bytes for PBLENDVB. 
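For the AVX2 v16i16 block just above, the unpack-with-zero layout means a single i32 shift serves all three opcodes; a scalar sketch of one lane (illustrative helper, not the DAG code):

#include <cstdint>
// The lane lives in the upper half of an i32 (unpack {0, r}) and the amount is
// zero-extended in its own i32 (unpack {amt, 0}); after the shift, VSRLI by 16
// and PACKUS recover the i16 result for SHL, SRL and SRA alike.
static uint16_t shift16_via_i32(uint16_t r, uint32_t amt,
                                int opc /*0=SHL, 1=SRL, 2=SRA*/) {
  uint32_t wide = (uint32_t)r << 16;
  uint32_t res = opc == 0 ? wide << amt
               : opc == 1 ? wide >> amt
                          : (uint32_t)((int32_t)wide >> amt);
  return (uint16_t)(res >> 16);
}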
Amt = DAG.getNode( ISD::OR, dl, VT, - DAG.getNode(ISD::SHL, dl, VT, Amt, DAG.getConstant(4, dl, VT)), - DAG.getNode(ISD::SHL, dl, VT, Amt, DAG.getConstant(12, dl, VT))); + getTargetVShiftByConstNode(X86ISD::VSHLI, dl, VT, Amt, 4, DAG), + getTargetVShiftByConstNode(X86ISD::VSHLI, dl, VT, Amt, 12, DAG)); } else { - Amt = DAG.getNode(ISD::SHL, dl, VT, Amt, DAG.getConstant(12, dl, VT)); + Amt = getTargetVShiftByConstNode(X86ISD::VSHLI, dl, VT, Amt, 12, DAG); } // r = VSELECT(r, shift(r, 8), a); - SDValue M = DAG.getNode(ShiftOpcode, dl, VT, R, DAG.getConstant(8, dl, VT)); + SDValue M = getTargetVShiftByConstNode(X86OpcI, dl, VT, R, 8, DAG); R = SignBitSelect(Amt, M, R); // a += a Amt = DAG.getNode(ISD::ADD, dl, VT, Amt, Amt); // r = VSELECT(r, shift(r, 4), a); - M = DAG.getNode(ShiftOpcode, dl, VT, R, DAG.getConstant(4, dl, VT)); + M = getTargetVShiftByConstNode(X86OpcI, dl, VT, R, 4, DAG); R = SignBitSelect(Amt, M, R); // a += a Amt = DAG.getNode(ISD::ADD, dl, VT, Amt, Amt); // r = VSELECT(r, shift(r, 2), a); - M = DAG.getNode(ShiftOpcode, dl, VT, R, DAG.getConstant(2, dl, VT)); + M = getTargetVShiftByConstNode(X86OpcI, dl, VT, R, 2, DAG); R = SignBitSelect(Amt, M, R); // a += a Amt = DAG.getNode(ISD::ADD, dl, VT, Amt, Amt); // return VSELECT(r, shift(r, 1), a); - M = DAG.getNode(ShiftOpcode, dl, VT, R, DAG.getConstant(1, dl, VT)); + M = getTargetVShiftByConstNode(X86OpcI, dl, VT, R, 1, DAG); R = SignBitSelect(Amt, M, R); return R; } - // Decompose 256-bit shifts into smaller 128-bit shifts. + // Decompose 256-bit shifts into 128-bit shifts. if (VT.is256BitVector()) - return Lower256IntArith(Op, DAG); + return split256IntArith(Op, DAG); return SDValue(); } @@ -23838,20 +24914,31 @@ static SDValue LowerRotate(SDValue Op, const X86Subtarget &Subtarget, SDValue Amt = Op.getOperand(1); unsigned Opcode = Op.getOpcode(); unsigned EltSizeInBits = VT.getScalarSizeInBits(); + int NumElts = VT.getVectorNumElements(); + + // Check for constant splat rotation amount. + APInt UndefElts; + SmallVector<APInt, 32> EltBits; + int CstSplatIndex = -1; + if (getTargetConstantBitsFromNode(Amt, EltSizeInBits, UndefElts, EltBits)) + for (int i = 0; i != NumElts; ++i) + if (!UndefElts[i]) { + if (CstSplatIndex < 0 || EltBits[i] == EltBits[CstSplatIndex]) { + CstSplatIndex = i; + continue; + } + CstSplatIndex = -1; + break; + } + // AVX512 implicitly uses modulo rotation amounts. if (Subtarget.hasAVX512() && 32 <= EltSizeInBits) { // Attempt to rotate by immediate. - APInt UndefElts; - SmallVector<APInt, 16> EltBits; - if (getTargetConstantBitsFromNode(Amt, EltSizeInBits, UndefElts, EltBits)) { - if (!UndefElts && llvm::all_of(EltBits, [EltBits](APInt &V) { - return EltBits[0] == V; - })) { - unsigned Op = (Opcode == ISD::ROTL ? X86ISD::VROTLI : X86ISD::VROTRI); - uint64_t RotateAmt = EltBits[0].urem(EltSizeInBits); - return DAG.getNode(Op, DL, VT, R, - DAG.getConstant(RotateAmt, DL, MVT::i8)); - } + if (0 <= CstSplatIndex) { + unsigned Op = (Opcode == ISD::ROTL ? X86ISD::VROTLI : X86ISD::VROTRI); + uint64_t RotateAmt = EltBits[CstSplatIndex].urem(EltSizeInBits); + return DAG.getNode(Op, DL, VT, R, + DAG.getConstant(RotateAmt, DL, MVT::i8)); } // Else, fall-back on VPROLV/VPRORV. @@ -23862,20 +24949,17 @@ static SDValue LowerRotate(SDValue Op, const X86Subtarget &Subtarget, // XOP has 128-bit vector variable + immediate rotates. // +ve/-ve Amt = rotate left/right - just need to handle ISD::ROTL. + // XOP implicitly uses modulo rotation amounts. if (Subtarget.hasXOP()) { - // Split 256-bit integers. 
if (VT.is256BitVector()) - return Lower256IntArith(Op, DAG); + return split256IntArith(Op, DAG); assert(VT.is128BitVector() && "Only rotate 128-bit vectors!"); // Attempt to rotate by immediate. - if (auto *BVAmt = dyn_cast<BuildVectorSDNode>(Amt)) { - if (auto *RotateConst = BVAmt->getConstantSplatNode()) { - uint64_t RotateAmt = RotateConst->getAPIntValue().getZExtValue(); - assert(RotateAmt < EltSizeInBits && "Rotation out of range"); - return DAG.getNode(X86ISD::VROTLI, DL, VT, R, - DAG.getConstant(RotateAmt, DL, MVT::i8)); - } + if (0 <= CstSplatIndex) { + uint64_t RotateAmt = EltBits[CstSplatIndex].urem(EltSizeInBits); + return DAG.getNode(X86ISD::VROTLI, DL, VT, R, + DAG.getConstant(RotateAmt, DL, MVT::i8)); } // Use general rotate by variable (per-element). @@ -23884,7 +24968,7 @@ static SDValue LowerRotate(SDValue Op, const X86Subtarget &Subtarget, // Split 256-bit integers on pre-AVX2 targets. if (VT.is256BitVector() && !Subtarget.hasAVX2()) - return Lower256IntArith(Op, DAG); + return split256IntArith(Op, DAG); assert((VT == MVT::v4i32 || VT == MVT::v8i16 || VT == MVT::v16i8 || ((VT == MVT::v8i32 || VT == MVT::v16i16 || VT == MVT::v32i8) && @@ -23892,44 +24976,19 @@ static SDValue LowerRotate(SDValue Op, const X86Subtarget &Subtarget, "Only vXi32/vXi16/vXi8 vector rotates supported"); // Rotate by an uniform constant - expand back to shifts. - // TODO - legalizers should be able to handle this. - if (auto *BVAmt = dyn_cast<BuildVectorSDNode>(Amt)) { - if (auto *RotateConst = BVAmt->getConstantSplatNode()) { - uint64_t RotateAmt = RotateConst->getAPIntValue().getZExtValue(); - assert(RotateAmt < EltSizeInBits && "Rotation out of range"); - if (RotateAmt == 0) - return R; - - SDValue AmtR = DAG.getConstant(EltSizeInBits - RotateAmt, DL, VT); - SDValue SHL = DAG.getNode(ISD::SHL, DL, VT, R, Amt); - SDValue SRL = DAG.getNode(ISD::SRL, DL, VT, R, AmtR); - return DAG.getNode(ISD::OR, DL, VT, SHL, SRL); - } - } + if (0 <= CstSplatIndex) + return SDValue(); - // Rotate by splat - expand back to shifts. - // TODO - legalizers should be able to handle this. - if ((EltSizeInBits >= 16 || Subtarget.hasBWI()) && - IsSplatValue(VT, Amt, DL, DAG, Subtarget, Opcode)) { - SDValue AmtR = DAG.getConstant(EltSizeInBits, DL, VT); - AmtR = DAG.getNode(ISD::SUB, DL, VT, AmtR, Amt); - SDValue SHL = DAG.getNode(ISD::SHL, DL, VT, R, Amt); - SDValue SRL = DAG.getNode(ISD::SRL, DL, VT, R, AmtR); - return DAG.getNode(ISD::OR, DL, VT, SHL, SRL); - } + bool IsSplatAmt = DAG.isSplatValue(Amt); // v16i8/v32i8: Split rotation into rot4/rot2/rot1 stages and select by // the amount bit. - if (EltSizeInBits == 8) { - if (Subtarget.hasBWI()) { - SDValue AmtR = DAG.getConstant(EltSizeInBits, DL, VT); - AmtR = DAG.getNode(ISD::SUB, DL, VT, AmtR, Amt); - SDValue SHL = DAG.getNode(ISD::SHL, DL, VT, R, Amt); - SDValue SRL = DAG.getNode(ISD::SRL, DL, VT, R, AmtR); - return DAG.getNode(ISD::OR, DL, VT, SHL, SRL); - } + if (EltSizeInBits == 8 && !IsSplatAmt) { + if (ISD::isBuildVectorOfConstantSDNodes(Amt.getNode())) + return SDValue(); - MVT ExtVT = MVT::getVectorVT(MVT::i16, VT.getVectorNumElements() / 2); + // We don't need ModuloAmt here as we just peek at individual bits. 
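The removed expand-back-to-shifts path (and the generic legalizer that now handles it) relies on the usual rotate identity; in scalar form, for a non-zero amount already reduced modulo the element width:

#include <cstdint>
// ROTL by c (0 < c < 32) is a left shift OR'd with the complementary right
// shift; with the modulo-reduced amount this is exactly SHL + SRL + OR.
static uint32_t rotl32(uint32_t x, unsigned c) {
  return (x << c) | (x >> (32 - c));
}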
+ MVT ExtVT = MVT::getVectorVT(MVT::i16, NumElts / 2); auto SignBitSelect = [&](MVT SelVT, SDValue Sel, SDValue V0, SDValue V1) { if (Subtarget.hasSSE41()) { @@ -23943,7 +25002,7 @@ static SDValue LowerRotate(SDValue Op, const X86Subtarget &Subtarget, // On pre-SSE41 targets we test for the sign bit by comparing to // zero - a negative value will set all bits of the lanes to true // and VSELECT uses that in its OR(AND(V0,C),AND(V1,~C)) lowering. - SDValue Z = getZeroVector(SelVT, Subtarget, DAG, DL); + SDValue Z = DAG.getConstant(0, DL, SelVT); SDValue C = DAG.getNode(X86ISD::PCMPGT, DL, SelVT, Z, Sel); return DAG.getSelect(DL, SelVT, C, V0, V1); }; @@ -23984,14 +25043,17 @@ static SDValue LowerRotate(SDValue Op, const X86Subtarget &Subtarget, return SignBitSelect(VT, Amt, M, R); } + // ISD::ROT* uses modulo rotate amounts. + Amt = DAG.getNode(ISD::AND, DL, VT, Amt, + DAG.getConstant(EltSizeInBits - 1, DL, VT)); + bool ConstantAmt = ISD::isBuildVectorOfConstantSDNodes(Amt.getNode()); bool LegalVarShifts = SupportedVectorVarShift(VT, Subtarget, ISD::SHL) && SupportedVectorVarShift(VT, Subtarget, ISD::SRL); - // Best to fallback for all supported variable shifts. - // AVX2 - best to fallback for non-constants as well. - // TODO - legalizers should be able to handle this. - if (LegalVarShifts || (Subtarget.hasAVX2() && !ConstantAmt)) { + // Fallback for splats + all supported variable shifts. + // Fallback for non-constants AVX2 vXi16 as well. + if (IsSplatAmt || LegalVarShifts || (Subtarget.hasAVX2() && !ConstantAmt)) { SDValue AmtR = DAG.getConstant(EltSizeInBits, DL, VT); AmtR = DAG.getNode(ISD::SUB, DL, VT, AmtR, Amt); SDValue SHL = DAG.getNode(ISD::SHL, DL, VT, R, Amt); @@ -24032,78 +25094,6 @@ static SDValue LowerRotate(SDValue Op, const X86Subtarget &Subtarget, DAG.getVectorShuffle(VT, DL, Res02, Res13, {1, 5, 3, 7})); } -static SDValue LowerXALUO(SDValue Op, SelectionDAG &DAG) { - // Lower the "add/sub/mul with overflow" instruction into a regular ins plus - // a "setcc" instruction that checks the overflow flag. The "brcond" lowering - // looks for this combo and may remove the "setcc" instruction if the "setcc" - // has only one use. - SDNode *N = Op.getNode(); - SDValue LHS = N->getOperand(0); - SDValue RHS = N->getOperand(1); - unsigned BaseOp = 0; - X86::CondCode Cond; - SDLoc DL(Op); - switch (Op.getOpcode()) { - default: llvm_unreachable("Unknown ovf instruction!"); - case ISD::SADDO: - // A subtract of one will be selected as a INC. Note that INC doesn't - // set CF, so we can't do this for UADDO. - if (isOneConstant(RHS)) { - BaseOp = X86ISD::INC; - Cond = X86::COND_O; - break; - } - BaseOp = X86ISD::ADD; - Cond = X86::COND_O; - break; - case ISD::UADDO: - BaseOp = X86ISD::ADD; - Cond = X86::COND_B; - break; - case ISD::SSUBO: - // A subtract of one will be selected as a DEC. Note that DEC doesn't - // set CF, so we can't do this for USUBO. - if (isOneConstant(RHS)) { - BaseOp = X86ISD::DEC; - Cond = X86::COND_O; - break; - } - BaseOp = X86ISD::SUB; - Cond = X86::COND_O; - break; - case ISD::USUBO: - BaseOp = X86ISD::SUB; - Cond = X86::COND_B; - break; - case ISD::SMULO: - BaseOp = N->getValueType(0) == MVT::i8 ? 
X86ISD::SMUL8 : X86ISD::SMUL; - Cond = X86::COND_O; - break; - case ISD::UMULO: { // i64, i8 = umulo lhs, rhs --> i64, i64, i32 umul lhs,rhs - if (N->getValueType(0) == MVT::i8) { - BaseOp = X86ISD::UMUL8; - Cond = X86::COND_O; - break; - } - SDVTList VTs = DAG.getVTList(N->getValueType(0), N->getValueType(0), - MVT::i32); - SDValue Sum = DAG.getNode(X86ISD::UMUL, DL, VTs, LHS, RHS); - - SDValue SetCC = getSETCC(X86::COND_O, SDValue(Sum.getNode(), 2), DL, DAG); - - return DAG.getNode(ISD::MERGE_VALUES, DL, N->getVTList(), Sum, SetCC); - } - } - - // Also sets EFLAGS. - SDVTList VTs = DAG.getVTList(N->getValueType(0), MVT::i32); - SDValue Sum = DAG.getNode(BaseOp, DL, VTs, LHS, RHS); - - SDValue SetCC = getSETCC(Cond, SDValue(Sum.getNode(), 1), DL, DAG); - - return DAG.getNode(ISD::MERGE_VALUES, DL, N->getVTList(), Sum, SetCC); -} - /// Returns true if the operand type is exactly twice the native width, and /// the corresponding cmpxchg8b or cmpxchg16b instruction is available. /// Used to know whether to use cmpxchg8/16b when expanding atomic operations @@ -24246,7 +25236,7 @@ static SDValue LowerATOMIC_FENCE(SDValue Op, const X86Subtarget &Subtarget, return DAG.getNode(X86ISD::MFENCE, dl, MVT::Other, Op.getOperand(0)); SDValue Chain = Op.getOperand(0); - SDValue Zero = DAG.getConstant(0, dl, MVT::i32); + SDValue Zero = DAG.getTargetConstant(0, dl, MVT::i32); SDValue Ops[] = { DAG.getRegister(X86::ESP, MVT::i32), // Base DAG.getTargetConstant(1, dl, MVT::i8), // Scale @@ -24256,7 +25246,7 @@ static SDValue LowerATOMIC_FENCE(SDValue Op, const X86Subtarget &Subtarget, Zero, Chain }; - SDNode *Res = DAG.getMachineNode(X86::OR32mrLocked, dl, MVT::Other, Ops); + SDNode *Res = DAG.getMachineNode(X86::OR32mi8Locked, dl, MVT::Other, Ops); return SDValue(Res, 0); } @@ -24369,40 +25359,32 @@ static SDValue LowerBITCAST(SDValue Op, const X86Subtarget &Subtarget, if (SrcVT == MVT::v2i32 || SrcVT == MVT::v4i16 || SrcVT == MVT::v8i8 || SrcVT == MVT::i64) { assert(Subtarget.hasSSE2() && "Requires at least SSE2!"); - if (DstVT != MVT::f64) + if (DstVT != MVT::f64 && DstVT != MVT::i64 && + !(DstVT == MVT::x86mmx && SrcVT.isVector())) // This conversion needs to be expanded. return SDValue(); - SmallVector<SDValue, 16> Elts; SDLoc dl(Op); - unsigned NumElts; - MVT SVT; if (SrcVT.isVector()) { - NumElts = SrcVT.getVectorNumElements(); - SVT = SrcVT.getVectorElementType(); - // Widen the vector in input in the case of MVT::v2i32. // Example: from MVT::v2i32 to MVT::v4i32. - for (unsigned i = 0, e = NumElts; i != e; ++i) - Elts.push_back(DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, SVT, Src, - DAG.getIntPtrConstant(i, dl))); + MVT NewVT = MVT::getVectorVT(SrcVT.getVectorElementType(), + SrcVT.getVectorNumElements() * 2); + Src = DAG.getNode(ISD::CONCAT_VECTORS, dl, NewVT, Src, + DAG.getUNDEF(SrcVT)); } else { assert(SrcVT == MVT::i64 && !Subtarget.is64Bit() && "Unexpected source type in LowerBITCAST"); - Elts.push_back(DAG.getNode(ISD::EXTRACT_ELEMENT, dl, MVT::i32, Src, - DAG.getIntPtrConstant(0, dl))); - Elts.push_back(DAG.getNode(ISD::EXTRACT_ELEMENT, dl, MVT::i32, Src, - DAG.getIntPtrConstant(1, dl))); - NumElts = 2; - SVT = MVT::i32; - } - // Explicitly mark the extra elements as Undef. 
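The widened 64-bit bitcast above is a pure reinterpretation; routing it through a 128-bit register only keeps 32-bit targets from splitting the value. A scalar equivalent, for illustration:

#include <cstdint>
#include <cstring>
// Reinterpret the 64 bits of an i64 as an f64 with no value conversion; the
// DAG does this as SCALAR_TO_VECTOR + bitcast + extract of the low lane.
static double bitcast_i64_to_f64(uint64_t bits) {
  double d;
  std::memcpy(&d, &bits, sizeof(d));
  return d;
}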
- Elts.append(NumElts, DAG.getUNDEF(SVT)); - - EVT NewVT = EVT::getVectorVT(*DAG.getContext(), SVT, NumElts * 2); - SDValue BV = DAG.getBuildVector(NewVT, dl, Elts); - SDValue ToV2F64 = DAG.getBitcast(MVT::v2f64, BV); - return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, MVT::f64, ToV2F64, + Src = DAG.getNode(ISD::SCALAR_TO_VECTOR, dl, MVT::v2i64, Src); + } + + MVT V2X64VT = DstVT == MVT::f64 ? MVT::v2f64 : MVT::v2i64; + Src = DAG.getNode(ISD::BITCAST, dl, V2X64VT, Src); + + if (DstVT == MVT::x86mmx) + return DAG.getNode(X86ISD::MOVDQ2Q, dl, DstVT, Src); + + return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, DstVT, Src, DAG.getIntPtrConstant(0, dl)); } @@ -24445,7 +25427,7 @@ static SDValue LowerHorizontalByteSum(SDValue V, MVT VT, // PSADBW instruction horizontally add all bytes and leave the result in i64 // chunks, thus directly computes the pop count for v2i64 and v4i64. if (EltVT == MVT::i64) { - SDValue Zeros = getZeroVector(ByteVecVT, Subtarget, DAG, DL); + SDValue Zeros = DAG.getConstant(0, DL, ByteVecVT); MVT SadVecVT = MVT::getVectorVT(MVT::i64, VecSize / 64); V = DAG.getNode(X86ISD::PSADBW, DL, SadVecVT, V, Zeros); return DAG.getBitcast(VT, V); @@ -24457,13 +25439,13 @@ static SDValue LowerHorizontalByteSum(SDValue V, MVT VT, // this is that it lines up the results of two PSADBW instructions to be // two v2i64 vectors which concatenated are the 4 population counts. We can // then use PACKUSWB to shrink and concatenate them into a v4i32 again. - SDValue Zeros = getZeroVector(VT, Subtarget, DAG, DL); + SDValue Zeros = DAG.getConstant(0, DL, VT); SDValue V32 = DAG.getBitcast(VT, V); - SDValue Low = DAG.getNode(X86ISD::UNPCKL, DL, VT, V32, Zeros); - SDValue High = DAG.getNode(X86ISD::UNPCKH, DL, VT, V32, Zeros); + SDValue Low = getUnpackl(DAG, DL, VT, V32, Zeros); + SDValue High = getUnpackh(DAG, DL, VT, V32, Zeros); // Do the horizontal sums into two v2i64s. - Zeros = getZeroVector(ByteVecVT, Subtarget, DAG, DL); + Zeros = DAG.getConstant(0, DL, ByteVecVT); MVT SadVecVT = MVT::getVectorVT(MVT::i64, VecSize / 64); Low = DAG.getNode(X86ISD::PSADBW, DL, SadVecVT, DAG.getBitcast(ByteVecVT, Low), Zeros); @@ -24498,7 +25480,9 @@ static SDValue LowerVectorCTPOPInRegLUT(SDValue Op, const SDLoc &DL, SelectionDAG &DAG) { MVT VT = Op.getSimpleValueType(); MVT EltVT = VT.getVectorElementType(); - unsigned VecSize = VT.getSizeInBits(); + int NumElts = VT.getVectorNumElements(); + (void)EltVT; + assert(EltVT == MVT::i8 && "Only vXi8 vector CTPOP lowering supported."); // Implement a lookup table in register by using an algorithm based on: // http://wm.ite.pl/articles/sse-popcount.html @@ -24510,109 +25494,30 @@ static SDValue LowerVectorCTPOPInRegLUT(SDValue Op, const SDLoc &DL, // masked out higher ones) for each byte. PSHUFB is used separately with both // to index the in-register table. Next, both are added and the result is a // i8 vector where each element contains the pop count for input byte. - // - // To obtain the pop count for elements != i8, we follow up with the same - // approach and use additional tricks as described below. 
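In scalar form, the in-register lookup the comment describes is a pair of nibble lookups; PSHUFB simply performs sixteen or more of them per instruction:

#include <cstdint>
// Per-byte popcount via a 16-entry nibble table, mirroring the PSHUFB lowering.
static uint8_t popcnt8_lut(uint8_t v) {
  static const uint8_t LUT[16] = {0,1,1,2, 1,2,2,3, 1,2,2,3, 2,3,3,4};
  return LUT[v >> 4] + LUT[v & 0x0F];   // high-nibble count + low-nibble count
}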
- // const int LUT[16] = {/* 0 */ 0, /* 1 */ 1, /* 2 */ 1, /* 3 */ 2, /* 4 */ 1, /* 5 */ 2, /* 6 */ 2, /* 7 */ 3, /* 8 */ 1, /* 9 */ 2, /* a */ 2, /* b */ 3, /* c */ 2, /* d */ 3, /* e */ 3, /* f */ 4}; - int NumByteElts = VecSize / 8; - MVT ByteVecVT = MVT::getVectorVT(MVT::i8, NumByteElts); - SDValue In = DAG.getBitcast(ByteVecVT, Op); SmallVector<SDValue, 64> LUTVec; - for (int i = 0; i < NumByteElts; ++i) + for (int i = 0; i < NumElts; ++i) LUTVec.push_back(DAG.getConstant(LUT[i % 16], DL, MVT::i8)); - SDValue InRegLUT = DAG.getBuildVector(ByteVecVT, DL, LUTVec); - SDValue M0F = DAG.getConstant(0x0F, DL, ByteVecVT); + SDValue InRegLUT = DAG.getBuildVector(VT, DL, LUTVec); + SDValue M0F = DAG.getConstant(0x0F, DL, VT); // High nibbles - SDValue FourV = DAG.getConstant(4, DL, ByteVecVT); - SDValue HighNibbles = DAG.getNode(ISD::SRL, DL, ByteVecVT, In, FourV); + SDValue FourV = DAG.getConstant(4, DL, VT); + SDValue HiNibbles = DAG.getNode(ISD::SRL, DL, VT, Op, FourV); // Low nibbles - SDValue LowNibbles = DAG.getNode(ISD::AND, DL, ByteVecVT, In, M0F); + SDValue LoNibbles = DAG.getNode(ISD::AND, DL, VT, Op, M0F); // The input vector is used as the shuffle mask that index elements into the // LUT. After counting low and high nibbles, add the vector to obtain the // final pop count per i8 element. - SDValue HighPopCnt = - DAG.getNode(X86ISD::PSHUFB, DL, ByteVecVT, InRegLUT, HighNibbles); - SDValue LowPopCnt = - DAG.getNode(X86ISD::PSHUFB, DL, ByteVecVT, InRegLUT, LowNibbles); - SDValue PopCnt = DAG.getNode(ISD::ADD, DL, ByteVecVT, HighPopCnt, LowPopCnt); - - if (EltVT == MVT::i8) - return PopCnt; - - return LowerHorizontalByteSum(PopCnt, VT, Subtarget, DAG); -} - -static SDValue LowerVectorCTPOPBitmath(SDValue Op, const SDLoc &DL, - const X86Subtarget &Subtarget, - SelectionDAG &DAG) { - MVT VT = Op.getSimpleValueType(); - assert(VT.is128BitVector() && - "Only 128-bit vector bitmath lowering supported."); - - int VecSize = VT.getSizeInBits(); - MVT EltVT = VT.getVectorElementType(); - int Len = EltVT.getSizeInBits(); - - // This is the vectorized version of the "best" algorithm from - // http://graphics.stanford.edu/~seander/bithacks.html#CountBitsSetParallel - // with a minor tweak to use a series of adds + shifts instead of vector - // multiplications. Implemented for all integer vector types. We only use - // this when we don't have SSSE3 which allows a LUT-based lowering that is - // much faster, even faster than using native popcnt instructions. - - auto GetShift = [&](unsigned OpCode, SDValue V, int Shifter) { - MVT VT = V.getSimpleValueType(); - SDValue ShifterV = DAG.getConstant(Shifter, DL, VT); - return DAG.getNode(OpCode, DL, VT, V, ShifterV); - }; - auto GetMask = [&](SDValue V, APInt Mask) { - MVT VT = V.getSimpleValueType(); - SDValue MaskV = DAG.getConstant(Mask, DL, VT); - return DAG.getNode(ISD::AND, DL, VT, V, MaskV); - }; - - // We don't want to incur the implicit masks required to SRL vNi8 vectors on - // x86, so set the SRL type to have elements at least i16 wide. This is - // correct because all of our SRLs are followed immediately by a mask anyways - // that handles any bits that sneak into the high bits of the byte elements. - MVT SrlVT = Len > 8 ? VT : MVT::getVectorVT(MVT::i16, VecSize / 16); - - SDValue V = Op; - - // v = v - ((v >> 1) & 0x55555555...) 
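For reference, the scalar version of the bit-math path being removed here (the SWAR algorithm the comment cites; the vector code finished with shifts and PSADBW rather than the final multiply):

#include <cstdint>
// Classic SWAR popcount: pairwise 2-bit counts, then 4-bit counts, then byte
// counts, then a horizontal byte sum.
static uint32_t popcnt32_swar(uint32_t v) {
  v = v - ((v >> 1) & 0x55555555u);
  v = (v & 0x33333333u) + ((v >> 2) & 0x33333333u);
  v = (v + (v >> 4)) & 0x0F0F0F0Fu;
  return (v * 0x01010101u) >> 24;
}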
- SDValue Srl = - DAG.getBitcast(VT, GetShift(ISD::SRL, DAG.getBitcast(SrlVT, V), 1)); - SDValue And = GetMask(Srl, APInt::getSplat(Len, APInt(8, 0x55))); - V = DAG.getNode(ISD::SUB, DL, VT, V, And); - - // v = (v & 0x33333333...) + ((v >> 2) & 0x33333333...) - SDValue AndLHS = GetMask(V, APInt::getSplat(Len, APInt(8, 0x33))); - Srl = DAG.getBitcast(VT, GetShift(ISD::SRL, DAG.getBitcast(SrlVT, V), 2)); - SDValue AndRHS = GetMask(Srl, APInt::getSplat(Len, APInt(8, 0x33))); - V = DAG.getNode(ISD::ADD, DL, VT, AndLHS, AndRHS); - - // v = (v + (v >> 4)) & 0x0F0F0F0F... - Srl = DAG.getBitcast(VT, GetShift(ISD::SRL, DAG.getBitcast(SrlVT, V), 4)); - SDValue Add = DAG.getNode(ISD::ADD, DL, VT, V, Srl); - V = GetMask(Add, APInt::getSplat(Len, APInt(8, 0x0F))); - - // At this point, V contains the byte-wise population count, and we are - // merely doing a horizontal sum if necessary to get the wider element - // counts. - if (EltVT == MVT::i8) - return V; - - return LowerHorizontalByteSum( - DAG.getBitcast(MVT::getVectorVT(MVT::i8, VecSize / 8), V), VT, Subtarget, - DAG); + SDValue HiPopCnt = DAG.getNode(X86ISD::PSHUFB, DL, VT, InRegLUT, HiNibbles); + SDValue LoPopCnt = DAG.getNode(X86ISD::PSHUFB, DL, VT, InRegLUT, LoNibbles); + return DAG.getNode(ISD::ADD, DL, VT, HiPopCnt, LoPopCnt); } // Please ensure that any codegen change from LowerVectorCTPOP is reflected in @@ -24638,12 +25543,6 @@ static SDValue LowerVectorCTPOP(SDValue Op, const X86Subtarget &Subtarget, } } - if (!Subtarget.hasSSSE3()) { - // We can't use the fast LUT approach, so fall back on vectorized bitmath. - assert(VT.is128BitVector() && "Only 128-bit vectors supported in SSE!"); - return LowerVectorCTPOPBitmath(Op0, DL, Subtarget, DAG); - } - // Decompose 256-bit ops into smaller 128-bit ops. if (VT.is256BitVector() && !Subtarget.hasInt256()) return Lower256IntUnary(Op, DAG); @@ -24652,6 +25551,18 @@ static SDValue LowerVectorCTPOP(SDValue Op, const X86Subtarget &Subtarget, if (VT.is512BitVector() && !Subtarget.hasBWI()) return Lower512IntUnary(Op, DAG); + // For element types greater than i8, do vXi8 pop counts and a bytesum. + if (VT.getScalarType() != MVT::i8) { + MVT ByteVT = MVT::getVectorVT(MVT::i8, VT.getSizeInBits() / 8); + SDValue ByteOp = DAG.getBitcast(ByteVT, Op0); + SDValue PopCnt8 = DAG.getNode(ISD::CTPOP, DL, ByteVT, ByteOp); + return LowerHorizontalByteSum(PopCnt8, VT, Subtarget, DAG); + } + + // We can't use the fast LUT approach, so fall back on LegalizeDAG. + if (!Subtarget.hasSSSE3()) + return SDValue(); + return LowerVectorCTPOPInRegLUT(Op0, DL, Subtarget, DAG); } @@ -24759,8 +25670,7 @@ static SDValue LowerBITREVERSE(SDValue Op, const X86Subtarget &Subtarget, } static SDValue lowerAtomicArithWithLOCK(SDValue N, SelectionDAG &DAG, - const X86Subtarget &Subtarget, - bool AllowIncDec = true) { + const X86Subtarget &Subtarget) { unsigned NewOpc = 0; switch (N->getOpcode()) { case ISD::ATOMIC_LOAD_ADD: @@ -24784,25 +25694,6 @@ static SDValue lowerAtomicArithWithLOCK(SDValue N, SelectionDAG &DAG, MachineMemOperand *MMO = cast<MemSDNode>(N)->getMemOperand(); - if (auto *C = dyn_cast<ConstantSDNode>(N->getOperand(2))) { - // Convert to inc/dec if they aren't slow or we are optimizing for size. 
- if (AllowIncDec && (!Subtarget.slowIncDec() || - DAG.getMachineFunction().getFunction().optForSize())) { - if ((NewOpc == X86ISD::LADD && C->isOne()) || - (NewOpc == X86ISD::LSUB && C->isAllOnesValue())) - return DAG.getMemIntrinsicNode(X86ISD::LINC, SDLoc(N), - DAG.getVTList(MVT::i32, MVT::Other), - {N->getOperand(0), N->getOperand(1)}, - /*MemVT=*/N->getSimpleValueType(0), MMO); - if ((NewOpc == X86ISD::LSUB && C->isOne()) || - (NewOpc == X86ISD::LADD && C->isAllOnesValue())) - return DAG.getMemIntrinsicNode(X86ISD::LDEC, SDLoc(N), - DAG.getVTList(MVT::i32, MVT::Other), - {N->getOperand(0), N->getOperand(1)}, - /*MemVT=*/N->getSimpleValueType(0), MMO); - } - } - return DAG.getMemIntrinsicNode( NewOpc, SDLoc(N), DAG.getVTList(MVT::i32, MVT::Other), {N->getOperand(0), N->getOperand(1), N->getOperand(2)}, @@ -25120,8 +26011,7 @@ static SDValue LowerMLOAD(SDValue Op, const X86Subtarget &Subtarget, // VLX the vector should be widened to 512 bit unsigned NumEltsInWideVec = 512 / VT.getScalarSizeInBits(); MVT WideDataVT = MVT::getVectorVT(ScalarVT, NumEltsInWideVec); - SDValue Src0 = N->getSrc0(); - Src0 = ExtendToType(Src0, WideDataVT, DAG); + SDValue PassThru = ExtendToType(N->getPassThru(), WideDataVT, DAG); // Mask element has to be i1. assert(Mask.getSimpleValueType().getScalarType() == MVT::i1 && @@ -25131,7 +26021,7 @@ static SDValue LowerMLOAD(SDValue Op, const X86Subtarget &Subtarget, Mask = ExtendToType(Mask, WideMaskVT, DAG, true); SDValue NewLoad = DAG.getMaskedLoad(WideDataVT, dl, N->getChain(), - N->getBasePtr(), Mask, Src0, + N->getBasePtr(), Mask, PassThru, N->getMemoryVT(), N->getMemOperand(), N->getExtensionType(), N->isExpandingLoad()); @@ -25194,7 +26084,7 @@ static SDValue LowerMGATHER(SDValue Op, const X86Subtarget &Subtarget, MVT VT = Op.getSimpleValueType(); SDValue Index = N->getIndex(); SDValue Mask = N->getMask(); - SDValue Src0 = N->getValue(); + SDValue PassThru = N->getPassThru(); MVT IndexVT = Index.getSimpleValueType(); MVT MaskVT = Mask.getSimpleValueType(); @@ -25219,12 +26109,12 @@ static SDValue LowerMGATHER(SDValue Op, const X86Subtarget &Subtarget, IndexVT = MVT::getVectorVT(IndexVT.getVectorElementType(), NumElts); MaskVT = MVT::getVectorVT(MVT::i1, NumElts); - Src0 = ExtendToType(Src0, VT, DAG); + PassThru = ExtendToType(PassThru, VT, DAG); Index = ExtendToType(Index, IndexVT, DAG); Mask = ExtendToType(Mask, MaskVT, DAG, true); } - SDValue Ops[] = { N->getChain(), Src0, Mask, N->getBasePtr(), Index, + SDValue Ops[] = { N->getChain(), PassThru, Mask, N->getBasePtr(), Index, N->getScale() }; SDValue NewGather = DAG.getTargetMemSDNode<X86MaskedGatherSDNode>( DAG.getVTList(VT, MaskVT, MVT::Other), Ops, dl, N->getMemoryVT(), @@ -25308,6 +26198,8 @@ SDValue X86TargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const { case ISD::SHL_PARTS: case ISD::SRA_PARTS: case ISD::SRL_PARTS: return LowerShiftParts(Op, DAG); + case ISD::FSHL: + case ISD::FSHR: return LowerFunnelShift(Op, Subtarget, DAG); case ISD::SINT_TO_FP: return LowerSINT_TO_FP(Op, DAG); case ISD::UINT_TO_FP: return LowerUINT_TO_FP(Op, DAG); case ISD::TRUNCATE: return LowerTRUNCATE(Op, DAG); @@ -25322,6 +26214,8 @@ SDValue X86TargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const { case ISD::FP_EXTEND: return LowerFP_EXTEND(Op, DAG); case ISD::LOAD: return LowerLoad(Op, Subtarget, DAG); case ISD::STORE: return LowerStore(Op, Subtarget, DAG); + case ISD::FADD: + case ISD::FSUB: return lowerFaddFsub(Op, DAG, Subtarget); case ISD::FABS: case ISD::FNEG: return LowerFABSorFNEG(Op, 
DAG); case ISD::FCOPYSIGN: return LowerFCOPYSIGN(Op, DAG); @@ -25354,12 +26248,10 @@ SDValue X86TargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const { case ISD::CTLZ: case ISD::CTLZ_ZERO_UNDEF: return LowerCTLZ(Op, Subtarget, DAG); case ISD::CTTZ: - case ISD::CTTZ_ZERO_UNDEF: return LowerCTTZ(Op, DAG); + case ISD::CTTZ_ZERO_UNDEF: return LowerCTTZ(Op, Subtarget, DAG); case ISD::MUL: return LowerMUL(Op, Subtarget, DAG); case ISD::MULHS: case ISD::MULHU: return LowerMULH(Op, Subtarget, DAG); - case ISD::UMUL_LOHI: - case ISD::SMUL_LOHI: return LowerMUL_LOHI(Op, Subtarget, DAG); case ISD::ROTL: case ISD::ROTR: return LowerRotate(Op, Subtarget, DAG); case ISD::SRA: @@ -25376,12 +26268,16 @@ SDValue X86TargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const { case ISD::ADDCARRY: case ISD::SUBCARRY: return LowerADDSUBCARRY(Op, DAG); case ISD::ADD: - case ISD::SUB: return LowerADD_SUB(Op, DAG); + case ISD::SUB: return lowerAddSub(Op, DAG, Subtarget); + case ISD::UADDSAT: + case ISD::SADDSAT: + case ISD::USUBSAT: + case ISD::SSUBSAT: return LowerADDSAT_SUBSAT(Op, DAG); case ISD::SMAX: case ISD::SMIN: case ISD::UMAX: case ISD::UMIN: return LowerMINMAX(Op, DAG); - case ISD::ABS: return LowerABS(Op, DAG); + case ISD::ABS: return LowerABS(Op, Subtarget, DAG); case ISD::FSINCOS: return LowerFSINCOS(Op, Subtarget, DAG); case ISD::MLOAD: return LowerMLOAD(Op, Subtarget, DAG); case ISD::MSTORE: return LowerMSTORE(Op, Subtarget, DAG); @@ -25421,32 +26317,70 @@ void X86TargetLowering::ReplaceNodeResults(SDNode *N, SmallVectorImpl<SDValue>&Results, SelectionDAG &DAG) const { SDLoc dl(N); - const TargetLowering &TLI = DAG.getTargetLoweringInfo(); switch (N->getOpcode()) { default: llvm_unreachable("Do not know how to custom type legalize this operation!"); + case ISD::MUL: { + EVT VT = N->getValueType(0); + assert(VT.isVector() && "Unexpected VT"); + if (getTypeAction(*DAG.getContext(), VT) == TypePromoteInteger && + VT.getVectorNumElements() == 2) { + // Promote to a pattern that will be turned into PMULUDQ. + SDValue N0 = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::v2i64, + N->getOperand(0)); + SDValue N1 = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::v2i64, + N->getOperand(1)); + SDValue Mul = DAG.getNode(X86ISD::PMULUDQ, dl, MVT::v2i64, N0, N1); + Results.push_back(DAG.getNode(ISD::TRUNCATE, dl, VT, Mul)); + } else if (getTypeAction(*DAG.getContext(), VT) == TypeWidenVector && + VT.getVectorElementType() == MVT::i8) { + // Pre-promote these to vXi16 to avoid op legalization thinking all 16 + // elements are needed. + MVT MulVT = MVT::getVectorVT(MVT::i16, VT.getVectorNumElements()); + SDValue Op0 = DAG.getNode(ISD::ANY_EXTEND, dl, MulVT, N->getOperand(0)); + SDValue Op1 = DAG.getNode(ISD::ANY_EXTEND, dl, MulVT, N->getOperand(1)); + SDValue Res = DAG.getNode(ISD::MUL, dl, MulVT, Op0, Op1); + Res = DAG.getNode(ISD::TRUNCATE, dl, VT, Res); + unsigned NumConcats = 16 / VT.getVectorNumElements(); + SmallVector<SDValue, 8> ConcatOps(NumConcats, DAG.getUNDEF(VT)); + ConcatOps[0] = Res; + Res = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v16i8, ConcatOps); + Results.push_back(Res); + } + return; + } + case ISD::UADDSAT: + case ISD::SADDSAT: + case ISD::USUBSAT: + case ISD::SSUBSAT: + case X86ISD::VPMADDWD: case X86ISD::AVG: { - // Legalize types for X86ISD::AVG by expanding vectors. + // Legalize types for ISD::UADDSAT/SADDSAT/USUBSAT/SSUBSAT and + // X86ISD::AVG/VPMADDWD by widening. 
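The PMULUDQ promotion a few hunks above is safe because the instruction reads only the low 32 bits of each 64-bit lane, so the any-extended garbage is ignored; in scalar terms:

#include <cstdint>
// A 32-bit multiply recovered from a 32x32->64 unsigned multiply: truncating
// the wide product yields the same low 32 bits regardless of signedness.
static uint32_t mul32_via_pmuludq(uint32_t a, uint32_t b) {
  uint64_t wide = (uint64_t)a * (uint64_t)b;
  return (uint32_t)wide;
}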
assert(Subtarget.hasSSE2() && "Requires at least SSE2!"); - auto InVT = N->getValueType(0); - assert(InVT.getSizeInBits() < 128); - assert(128 % InVT.getSizeInBits() == 0); + EVT VT = N->getValueType(0); + EVT InVT = N->getOperand(0).getValueType(); + assert(VT.getSizeInBits() < 128 && 128 % VT.getSizeInBits() == 0 && + "Expected a VT that divides into 128 bits."); unsigned NumConcat = 128 / InVT.getSizeInBits(); - EVT RegVT = EVT::getVectorVT(*DAG.getContext(), - InVT.getVectorElementType(), - NumConcat * InVT.getVectorNumElements()); + EVT InWideVT = EVT::getVectorVT(*DAG.getContext(), + InVT.getVectorElementType(), + NumConcat * InVT.getVectorNumElements()); + EVT WideVT = EVT::getVectorVT(*DAG.getContext(), + VT.getVectorElementType(), + NumConcat * VT.getVectorNumElements()); SmallVector<SDValue, 16> Ops(NumConcat, DAG.getUNDEF(InVT)); Ops[0] = N->getOperand(0); - SDValue InVec0 = DAG.getNode(ISD::CONCAT_VECTORS, dl, RegVT, Ops); + SDValue InVec0 = DAG.getNode(ISD::CONCAT_VECTORS, dl, InWideVT, Ops); Ops[0] = N->getOperand(1); - SDValue InVec1 = DAG.getNode(ISD::CONCAT_VECTORS, dl, RegVT, Ops); + SDValue InVec1 = DAG.getNode(ISD::CONCAT_VECTORS, dl, InWideVT, Ops); - SDValue Res = DAG.getNode(X86ISD::AVG, dl, RegVT, InVec0, InVec1); - if (getTypeAction(*DAG.getContext(), InVT) != TypeWidenVector) - Res = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, InVT, Res, + SDValue Res = DAG.getNode(N->getOpcode(), dl, WideVT, InVec0, InVec1); + if (getTypeAction(*DAG.getContext(), VT) != TypeWidenVector) + Res = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, VT, Res, DAG.getIntPtrConstant(0, dl)); Results.push_back(Res); return; @@ -25456,7 +26390,8 @@ void X86TargetLowering::ReplaceNodeResults(SDNode *N, // setCC result type is v2i1 because type legalzation will end up with // a v4i1 setcc plus an extend. assert(N->getValueType(0) == MVT::v2i32 && "Unexpected type"); - if (N->getOperand(0).getValueType() != MVT::v2f32) + if (N->getOperand(0).getValueType() != MVT::v2f32 || + getTypeAction(*DAG.getContext(), MVT::v2i32) == TypeWidenVector) return; SDValue UNDEF = DAG.getUNDEF(MVT::v2f32); SDValue LHS = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v4f32, @@ -25465,9 +26400,8 @@ void X86TargetLowering::ReplaceNodeResults(SDNode *N, N->getOperand(1), UNDEF); SDValue Res = DAG.getNode(ISD::SETCC, dl, MVT::v4i32, LHS, RHS, N->getOperand(2)); - if (getTypeAction(*DAG.getContext(), MVT::v2i32) != TypeWidenVector) - Res = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, MVT::v2i32, Res, - DAG.getIntPtrConstant(0, dl)); + Res = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, MVT::v2i32, Res, + DAG.getIntPtrConstant(0, dl)); Results.push_back(Res); return; } @@ -25489,13 +26423,198 @@ void X86TargetLowering::ReplaceNodeResults(SDNode *N, case ISD::SDIV: case ISD::UDIV: case ISD::SREM: - case ISD::UREM: + case ISD::UREM: { + EVT VT = N->getValueType(0); + if (getTypeAction(*DAG.getContext(), VT) == TypeWidenVector) { + // If this RHS is a constant splat vector we can widen this and let + // division/remainder by constant optimize it. + // TODO: Can we do something for non-splat? 
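What "let division/remainder by constant optimize it" buys, in scalar terms, is the usual magic-number strengthening; division by 10 as a representative example:

#include <cstdint>
// Unsigned division by the constant 10 as a multiply plus shift; the same
// transform fires on the widened vector once the divisor is a constant splat.
static uint32_t udiv_by_10(uint32_t x) {
  return (uint32_t)(((uint64_t)x * 0xCCCCCCCDu) >> 35);
}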
+ APInt SplatVal; + if (ISD::isConstantSplatVector(N->getOperand(1).getNode(), SplatVal)) { + unsigned NumConcats = 128 / VT.getSizeInBits(); + SmallVector<SDValue, 8> Ops0(NumConcats, DAG.getUNDEF(VT)); + Ops0[0] = N->getOperand(0); + EVT ResVT = getTypeToTransformTo(*DAG.getContext(), VT); + SDValue N0 = DAG.getNode(ISD::CONCAT_VECTORS, dl, ResVT, Ops0); + SDValue N1 = DAG.getConstant(SplatVal, dl, ResVT); + SDValue Res = DAG.getNode(N->getOpcode(), dl, ResVT, N0, N1); + Results.push_back(Res); + } + return; + } + + if (VT == MVT::v2i32) { + // Legalize v2i32 div/rem by unrolling. Otherwise we promote to the + // v2i64 and unroll later. But then we create i64 scalar ops which + // might be slow in 64-bit mode or require a libcall in 32-bit mode. + Results.push_back(DAG.UnrollVectorOp(N)); + return; + } + + if (VT.isVector()) + return; + + LLVM_FALLTHROUGH; + } case ISD::SDIVREM: case ISD::UDIVREM: { SDValue V = LowerWin64_i128OP(SDValue(N,0), DAG); Results.push_back(V); return; } + case ISD::TRUNCATE: { + MVT VT = N->getSimpleValueType(0); + if (getTypeAction(*DAG.getContext(), VT) != TypeWidenVector) + return; + + // The generic legalizer will try to widen the input type to the same + // number of elements as the widened result type. But this isn't always + // the best thing so do some custom legalization to avoid some cases. + MVT WidenVT = getTypeToTransformTo(*DAG.getContext(), VT).getSimpleVT(); + SDValue In = N->getOperand(0); + EVT InVT = In.getValueType(); + + unsigned InBits = InVT.getSizeInBits(); + if (128 % InBits == 0) { + // 128 bit and smaller inputs should avoid truncate all together and + // just use a build_vector that will become a shuffle. + // TODO: Widen and use a shuffle directly? + MVT InEltVT = InVT.getSimpleVT().getVectorElementType(); + EVT EltVT = VT.getVectorElementType(); + unsigned WidenNumElts = WidenVT.getVectorNumElements(); + SmallVector<SDValue, 16> Ops(WidenNumElts, DAG.getUNDEF(EltVT)); + // Use the original element count so we don't do more scalar opts than + // necessary. + unsigned MinElts = VT.getVectorNumElements(); + for (unsigned i=0; i < MinElts; ++i) { + SDValue Val = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, InEltVT, In, + DAG.getIntPtrConstant(i, dl)); + Ops[i] = DAG.getNode(ISD::TRUNCATE, dl, EltVT, Val); + } + Results.push_back(DAG.getBuildVector(WidenVT, dl, Ops)); + return; + } + // With AVX512 there are some cases that can use a target specific + // truncate node to go from 256/512 to less than 128 with zeros in the + // upper elements of the 128 bit result. + if (Subtarget.hasAVX512() && isTypeLegal(InVT)) { + // We can use VTRUNC directly if for 256 bits with VLX or for any 512. + if ((InBits == 256 && Subtarget.hasVLX()) || InBits == 512) { + Results.push_back(DAG.getNode(X86ISD::VTRUNC, dl, WidenVT, In)); + return; + } + // There's one case we can widen to 512 bits and use VTRUNC. + if (InVT == MVT::v4i64 && VT == MVT::v4i8 && isTypeLegal(MVT::v8i64)) { + In = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v8i64, In, + DAG.getUNDEF(MVT::v4i64)); + Results.push_back(DAG.getNode(X86ISD::VTRUNC, dl, WidenVT, In)); + return; + } + } + return; + } + case ISD::SIGN_EXTEND_VECTOR_INREG: { + if (ExperimentalVectorWideningLegalization) + return; + + EVT VT = N->getValueType(0); + SDValue In = N->getOperand(0); + EVT InVT = In.getValueType(); + if (!Subtarget.hasSSE41() && VT == MVT::v4i64 && + (InVT == MVT::v16i16 || InVT == MVT::v32i8)) { + // Custom split this so we can extend i8/i16->i32 invec. 
This is better + // since sign_extend_inreg i8/i16->i64 requires an extend to i32 using + // sra. Then extending from i32 to i64 using pcmpgt. By custom splitting + // we allow the sra from the extend to i32 to be shared by the split. + EVT ExtractVT = EVT::getVectorVT(*DAG.getContext(), + InVT.getVectorElementType(), + InVT.getVectorNumElements() / 2); + MVT ExtendVT = MVT::getVectorVT(MVT::i32, + VT.getVectorNumElements()); + In = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, ExtractVT, + In, DAG.getIntPtrConstant(0, dl)); + In = DAG.getNode(ISD::SIGN_EXTEND_VECTOR_INREG, dl, MVT::v4i32, In); + + // Fill a vector with sign bits for each element. + SDValue Zero = DAG.getConstant(0, dl, ExtendVT); + SDValue SignBits = DAG.getSetCC(dl, ExtendVT, Zero, In, ISD::SETGT); + + EVT LoVT, HiVT; + std::tie(LoVT, HiVT) = DAG.GetSplitDestVTs(N->getValueType(0)); + + // Create an unpackl and unpackh to interleave the sign bits then bitcast + // to vXi64. + SDValue Lo = getUnpackl(DAG, dl, ExtendVT, In, SignBits); + Lo = DAG.getNode(ISD::BITCAST, dl, LoVT, Lo); + SDValue Hi = getUnpackh(DAG, dl, ExtendVT, In, SignBits); + Hi = DAG.getNode(ISD::BITCAST, dl, HiVT, Hi); + + SDValue Res = DAG.getNode(ISD::CONCAT_VECTORS, dl, VT, Lo, Hi); + Results.push_back(Res); + return; + } + return; + } + case ISD::SIGN_EXTEND: + case ISD::ZERO_EXTEND: { + if (!ExperimentalVectorWideningLegalization) + return; + + EVT VT = N->getValueType(0); + SDValue In = N->getOperand(0); + EVT InVT = In.getValueType(); + if (!Subtarget.hasSSE41() && VT == MVT::v4i64 && + (InVT == MVT::v4i16 || InVT == MVT::v4i8)) { + // Custom split this so we can extend i8/i16->i32 invec. This is better + // since sign_extend_inreg i8/i16->i64 requires an extend to i32 using + // sra. Then extending from i32 to i64 using pcmpgt. By custom splitting + // we allow the sra from the extend to i32 to be shared by the split. + In = DAG.getNode(ISD::SIGN_EXTEND, dl, MVT::v4i32, In); + + // Fill a vector with sign bits for each element. + SDValue Zero = DAG.getConstant(0, dl, MVT::v4i32); + SDValue SignBits = DAG.getSetCC(dl, MVT::v4i32, Zero, In, ISD::SETGT); + + // Create an unpackl and unpackh to interleave the sign bits then bitcast + // to v2i64. + SDValue Lo = DAG.getVectorShuffle(MVT::v4i32, dl, In, SignBits, + {0, 4, 1, 5}); + Lo = DAG.getNode(ISD::BITCAST, dl, MVT::v2i64, Lo); + SDValue Hi = DAG.getVectorShuffle(MVT::v4i32, dl, In, SignBits, + {2, 6, 3, 7}); + Hi = DAG.getNode(ISD::BITCAST, dl, MVT::v2i64, Hi); + + SDValue Res = DAG.getNode(ISD::CONCAT_VECTORS, dl, VT, Lo, Hi); + Results.push_back(Res); + return; + } + + if ((VT == MVT::v16i32 || VT == MVT::v8i64) && InVT.is128BitVector()) { + // Perform custom splitting instead of the two stage extend we would get + // by default. + EVT LoVT, HiVT; + std::tie(LoVT, HiVT) = DAG.GetSplitDestVTs(N->getValueType(0)); + assert(isTypeLegal(LoVT) && "Split VT not legal?"); + + bool IsSigned = N->getOpcode() == ISD::SIGN_EXTEND; + + SDValue Lo = getExtendInVec(IsSigned, dl, LoVT, In, DAG); + + // We need to shift the input over by half the number of elements. 
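The sign-bit interleave above is the lane-wise form of the usual 32-to-64 extension: the value goes in the low half and its sign mask (the PCMPGT-against-zero result) in the high half. A scalar sketch:

#include <cstdint>
// Sign-extend i32 -> i64 by pairing the value with its sign mask, which is
// what the unpack/shuffle with the SETGT(0, x) result builds per lane.
static uint64_t sext_i32_to_i64(int32_t x) {
  uint32_t sign = (0 > x) ? 0xFFFFFFFFu : 0u;
  return ((uint64_t)sign << 32) | (uint32_t)x;
}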
+ unsigned NumElts = InVT.getVectorNumElements(); + unsigned HalfNumElts = NumElts / 2; + SmallVector<int, 16> ShufMask(NumElts, SM_SentinelUndef); + for (unsigned i = 0; i != HalfNumElts; ++i) + ShufMask[i] = i + HalfNumElts; + + SDValue Hi = DAG.getVectorShuffle(InVT, dl, In, In, ShufMask); + Hi = getExtendInVec(IsSigned, dl, HiVT, Hi, DAG); + + SDValue Res = DAG.getNode(ISD::CONCAT_VECTORS, dl, VT, Lo, Hi); + Results.push_back(Res); + } + return; + } case ISD::FP_TO_SINT: case ISD::FP_TO_UINT: { bool IsSigned = N->getOpcode() == ISD::FP_TO_SINT; @@ -25503,38 +26622,90 @@ void X86TargetLowering::ReplaceNodeResults(SDNode *N, SDValue Src = N->getOperand(0); EVT SrcVT = Src.getValueType(); + // Promote these manually to avoid over promotion to v2i64. Type + // legalization will revisit the v2i32 operation for more cleanup. + if ((VT == MVT::v2i8 || VT == MVT::v2i16) && + getTypeAction(*DAG.getContext(), VT) == TypePromoteInteger) { + // AVX512DQ provides instructions that produce a v2i64 result. + if (Subtarget.hasDQI()) + return; + + SDValue Res = DAG.getNode(ISD::FP_TO_SINT, dl, MVT::v2i32, Src); + Res = DAG.getNode(N->getOpcode() == ISD::FP_TO_UINT ? ISD::AssertZext + : ISD::AssertSext, + dl, MVT::v2i32, Res, + DAG.getValueType(VT.getVectorElementType())); + Res = DAG.getNode(ISD::TRUNCATE, dl, VT, Res); + Results.push_back(Res); + return; + } + + if (VT.isVector() && VT.getScalarSizeInBits() < 32) { + if (getTypeAction(*DAG.getContext(), VT) != TypeWidenVector) + return; + + // Try to create a 128 bit vector, but don't exceed a 32 bit element. + unsigned NewEltWidth = std::min(128 / VT.getVectorNumElements(), 32U); + MVT PromoteVT = MVT::getVectorVT(MVT::getIntegerVT(NewEltWidth), + VT.getVectorNumElements()); + SDValue Res = DAG.getNode(ISD::FP_TO_SINT, dl, PromoteVT, Src); + + // Preserve what we know about the size of the original result. Except + // when the result is v2i32 since we can't widen the assert. + if (PromoteVT != MVT::v2i32) + Res = DAG.getNode(N->getOpcode() == ISD::FP_TO_UINT ? ISD::AssertZext + : ISD::AssertSext, + dl, PromoteVT, Res, + DAG.getValueType(VT.getVectorElementType())); + + // Truncate back to the original width. + Res = DAG.getNode(ISD::TRUNCATE, dl, VT, Res); + + // Now widen to 128 bits. + unsigned NumConcats = 128 / VT.getSizeInBits(); + MVT ConcatVT = MVT::getVectorVT(VT.getSimpleVT().getVectorElementType(), + VT.getVectorNumElements() * NumConcats); + SmallVector<SDValue, 8> ConcatOps(NumConcats, DAG.getUNDEF(VT)); + ConcatOps[0] = Res; + Res = DAG.getNode(ISD::CONCAT_VECTORS, dl, ConcatVT, ConcatOps); + Results.push_back(Res); + return; + } + + if (VT == MVT::v2i32) { assert((IsSigned || Subtarget.hasAVX512()) && "Can only handle signed conversion without AVX512"); assert(Subtarget.hasSSE2() && "Requires at least SSE2!"); + bool Widenv2i32 = + getTypeAction(*DAG.getContext(), MVT::v2i32) == TypeWidenVector; if (Src.getValueType() == MVT::v2f64) { - MVT ResVT = MVT::v4i32; unsigned Opc = IsSigned ? X86ISD::CVTTP2SI : X86ISD::CVTTP2UI; if (!IsSigned && !Subtarget.hasVLX()) { - // Widen to 512-bits. - ResVT = MVT::v8i32; + // If v2i32 is widened, we can defer to the generic legalizer. + if (Widenv2i32) + return; + // Custom widen by doubling to a legal vector with. Isel will + // further widen to v8f64. 
Opc = ISD::FP_TO_UINT; - Src = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, MVT::v8f64, - DAG.getUNDEF(MVT::v8f64), - Src, DAG.getIntPtrConstant(0, dl)); + Src = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v4f64, + Src, DAG.getUNDEF(MVT::v2f64)); } - SDValue Res = DAG.getNode(Opc, dl, ResVT, Src); - bool WidenType = getTypeAction(*DAG.getContext(), - MVT::v2i32) == TypeWidenVector; - ResVT = WidenType ? MVT::v4i32 : MVT::v2i32; - Res = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, ResVT, Res, - DAG.getIntPtrConstant(0, dl)); + SDValue Res = DAG.getNode(Opc, dl, MVT::v4i32, Src); + if (!Widenv2i32) + Res = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, MVT::v2i32, Res, + DAG.getIntPtrConstant(0, dl)); Results.push_back(Res); return; } - if (SrcVT == MVT::v2f32) { + if (SrcVT == MVT::v2f32 && + getTypeAction(*DAG.getContext(), VT) != TypeWidenVector) { SDValue Idx = DAG.getIntPtrConstant(0, dl); SDValue Res = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v4f32, Src, DAG.getUNDEF(MVT::v2f32)); Res = DAG.getNode(IsSigned ? ISD::FP_TO_SINT : ISD::FP_TO_UINT, dl, MVT::v4i32, Res); - if (getTypeAction(*DAG.getContext(), MVT::v2i32) != TypeWidenVector) - Res = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, MVT::v2i32, Res, Idx); + Res = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, MVT::v2i32, Res, Idx); Results.push_back(Res); return; } @@ -25610,7 +26781,7 @@ void X86TargetLowering::ReplaceNodeResults(SDNode *N, return; } case ISD::FP_ROUND: { - if (!TLI.isTypeLegal(N->getOperand(0).getValueType())) + if (!isTypeLegal(N->getOperand(0).getValueType())) return; SDValue V = DAG.getNode(X86ISD::VFPROUND, dl, MVT::v4f32, N->getOperand(0)); Results.push_back(V); @@ -25780,29 +26951,19 @@ void X86TargetLowering::ReplaceNodeResults(SDNode *N, } if (SrcVT != MVT::f64 || - (DstVT != MVT::v2i32 && DstVT != MVT::v4i16 && DstVT != MVT::v8i8)) + (DstVT != MVT::v2i32 && DstVT != MVT::v4i16 && DstVT != MVT::v8i8) || + getTypeAction(*DAG.getContext(), DstVT) == TypeWidenVector) return; unsigned NumElts = DstVT.getVectorNumElements(); EVT SVT = DstVT.getVectorElementType(); EVT WiderVT = EVT::getVectorVT(*DAG.getContext(), SVT, NumElts * 2); - SDValue Expanded = DAG.getNode(ISD::SCALAR_TO_VECTOR, dl, - MVT::v2f64, N->getOperand(0)); - SDValue ToVecInt = DAG.getBitcast(WiderVT, Expanded); - - if (getTypeAction(*DAG.getContext(), DstVT) == TypeWidenVector) { - // If we are legalizing vectors by widening, we already have the desired - // legal vector type, just return it. 
- Results.push_back(ToVecInt); - return; - } - - SmallVector<SDValue, 8> Elts; - for (unsigned i = 0, e = NumElts; i != e; ++i) - Elts.push_back(DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, SVT, - ToVecInt, DAG.getIntPtrConstant(i, dl))); - - Results.push_back(DAG.getBuildVector(DstVT, dl, Elts)); + SDValue Res; + Res = DAG.getNode(ISD::SCALAR_TO_VECTOR, dl, MVT::v2f64, N->getOperand(0)); + Res = DAG.getBitcast(WiderVT, Res); + Res = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, DstVT, Res, + DAG.getIntPtrConstant(0, dl)); + Results.push_back(Res); return; } case ISD::MGATHER: { @@ -25814,9 +26975,9 @@ void X86TargetLowering::ReplaceNodeResults(SDNode *N, return; SDValue Mask = Gather->getMask(); assert(Mask.getValueType() == MVT::v2i1 && "Unexpected mask type"); - SDValue Src0 = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v4f32, - Gather->getValue(), - DAG.getUNDEF(MVT::v2f32)); + SDValue PassThru = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v4f32, + Gather->getPassThru(), + DAG.getUNDEF(MVT::v2f32)); if (!Subtarget.hasVLX()) { // We need to widen the mask, but the instruction will only use 2 // of its elements. So we can use undef. @@ -25824,8 +26985,8 @@ void X86TargetLowering::ReplaceNodeResults(SDNode *N, DAG.getUNDEF(MVT::v2i1)); Mask = DAG.getNode(ISD::SIGN_EXTEND, dl, MVT::v4i32, Mask); } - SDValue Ops[] = { Gather->getChain(), Src0, Mask, Gather->getBasePtr(), - Index, Gather->getScale() }; + SDValue Ops[] = { Gather->getChain(), PassThru, Mask, + Gather->getBasePtr(), Index, Gather->getScale() }; SDValue Res = DAG.getTargetMemSDNode<X86MaskedGatherSDNode>( DAG.getVTList(MVT::v4f32, Mask.getValueType(), MVT::Other), Ops, dl, Gather->getMemoryVT(), Gather->getMemOperand()); @@ -25838,9 +26999,9 @@ void X86TargetLowering::ReplaceNodeResults(SDNode *N, SDValue Index = Gather->getIndex(); SDValue Mask = Gather->getMask(); assert(Mask.getValueType() == MVT::v2i1 && "Unexpected mask type"); - SDValue Src0 = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v4i32, - Gather->getValue(), - DAG.getUNDEF(MVT::v2i32)); + SDValue PassThru = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v4i32, + Gather->getPassThru(), + DAG.getUNDEF(MVT::v2i32)); // If the index is v2i64 we can use it directly. if (Index.getValueType() == MVT::v2i64 && (Subtarget.hasVLX() || !Subtarget.hasAVX512())) { @@ -25851,8 +27012,8 @@ void X86TargetLowering::ReplaceNodeResults(SDNode *N, DAG.getUNDEF(MVT::v2i1)); Mask = DAG.getNode(ISD::SIGN_EXTEND, dl, MVT::v4i32, Mask); } - SDValue Ops[] = { Gather->getChain(), Src0, Mask, Gather->getBasePtr(), - Index, Gather->getScale() }; + SDValue Ops[] = { Gather->getChain(), PassThru, Mask, + Gather->getBasePtr(), Index, Gather->getScale() }; SDValue Res = DAG.getTargetMemSDNode<X86MaskedGatherSDNode>( DAG.getVTList(MVT::v4i32, Mask.getValueType(), MVT::Other), Ops, dl, Gather->getMemoryVT(), Gather->getMemOperand()); @@ -25864,28 +27025,56 @@ void X86TargetLowering::ReplaceNodeResults(SDNode *N, Results.push_back(Chain); return; } - EVT IndexVT = Index.getValueType(); - EVT NewIndexVT = EVT::getVectorVT(*DAG.getContext(), - IndexVT.getScalarType(), 4); - // Otherwise we need to custom widen everything to avoid promotion. 
- Index = DAG.getNode(ISD::CONCAT_VECTORS, dl, NewIndexVT, Index, - DAG.getUNDEF(IndexVT)); - Mask = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v4i1, Mask, - DAG.getConstant(0, dl, MVT::v2i1)); - SDValue Ops[] = { Gather->getChain(), Src0, Mask, Gather->getBasePtr(), - Index, Gather->getScale() }; - SDValue Res = DAG.getMaskedGather(DAG.getVTList(MVT::v4i32, MVT::Other), - Gather->getMemoryVT(), dl, Ops, - Gather->getMemOperand()); - SDValue Chain = Res.getValue(1); - if (getTypeAction(*DAG.getContext(), MVT::v2i32) != TypeWidenVector) - Res = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, MVT::v2i32, Res, - DAG.getIntPtrConstant(0, dl)); - Results.push_back(Res); - Results.push_back(Chain); - return; + if (getTypeAction(*DAG.getContext(), VT) != TypeWidenVector) { + EVT IndexVT = Index.getValueType(); + EVT NewIndexVT = EVT::getVectorVT(*DAG.getContext(), + IndexVT.getScalarType(), 4); + // Otherwise we need to custom widen everything to avoid promotion. + Index = DAG.getNode(ISD::CONCAT_VECTORS, dl, NewIndexVT, Index, + DAG.getUNDEF(IndexVT)); + Mask = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v4i1, Mask, + DAG.getConstant(0, dl, MVT::v2i1)); + SDValue Ops[] = { Gather->getChain(), PassThru, Mask, + Gather->getBasePtr(), Index, Gather->getScale() }; + SDValue Res = DAG.getMaskedGather(DAG.getVTList(MVT::v4i32, MVT::Other), + Gather->getMemoryVT(), dl, Ops, + Gather->getMemOperand()); + SDValue Chain = Res.getValue(1); + if (getTypeAction(*DAG.getContext(), MVT::v2i32) != TypeWidenVector) + Res = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, MVT::v2i32, Res, + DAG.getIntPtrConstant(0, dl)); + Results.push_back(Res); + Results.push_back(Chain); + return; + } } - break; + return; + } + case ISD::LOAD: { + // Use an f64/i64 load and a scalar_to_vector for v2f32/v2i32 loads. This + // avoids scalarizing in 32-bit mode. In 64-bit mode this avoids a int->fp + // cast since type legalization will try to use an i64 load. + MVT VT = N->getSimpleValueType(0); + assert(VT.isVector() && VT.getSizeInBits() == 64 && "Unexpected VT"); + if (getTypeAction(*DAG.getContext(), VT) != TypeWidenVector) + return; + if (!ISD::isNON_EXTLoad(N)) + return; + auto *Ld = cast<LoadSDNode>(N); + MVT LdVT = Subtarget.is64Bit() && VT.isInteger() ? 
MVT::i64 : MVT::f64; + SDValue Res = DAG.getLoad(LdVT, dl, Ld->getChain(), Ld->getBasePtr(), + Ld->getPointerInfo(), + Ld->getAlignment(), + Ld->getMemOperand()->getFlags()); + SDValue Chain = Res.getValue(1); + MVT WideVT = MVT::getVectorVT(LdVT, 2); + Res = DAG.getNode(ISD::SCALAR_TO_VECTOR, dl, WideVT, Res); + MVT CastVT = MVT::getVectorVT(VT.getVectorElementType(), + VT.getVectorNumElements() * 2); + Res = DAG.getBitcast(CastVT, Res); + Results.push_back(Res); + Results.push_back(Chain); + return; } } } @@ -25943,9 +27132,7 @@ const char *X86TargetLowering::getTargetNodeName(unsigned Opcode) const { case X86ISD::PSHUFB: return "X86ISD::PSHUFB"; case X86ISD::ANDNP: return "X86ISD::ANDNP"; case X86ISD::BLENDI: return "X86ISD::BLENDI"; - case X86ISD::SHRUNKBLEND: return "X86ISD::SHRUNKBLEND"; - case X86ISD::ADDUS: return "X86ISD::ADDUS"; - case X86ISD::SUBUS: return "X86ISD::SUBUS"; + case X86ISD::BLENDV: return "X86ISD::BLENDV"; case X86ISD::HADD: return "X86ISD::HADD"; case X86ISD::HSUB: return "X86ISD::HSUB"; case X86ISD::FHADD: return "X86ISD::FHADD"; @@ -25988,15 +27175,14 @@ const char *X86TargetLowering::getTargetNodeName(unsigned Opcode) const { case X86ISD::LOR: return "X86ISD::LOR"; case X86ISD::LXOR: return "X86ISD::LXOR"; case X86ISD::LAND: return "X86ISD::LAND"; - case X86ISD::LINC: return "X86ISD::LINC"; - case X86ISD::LDEC: return "X86ISD::LDEC"; case X86ISD::VZEXT_MOVL: return "X86ISD::VZEXT_MOVL"; case X86ISD::VZEXT_LOAD: return "X86ISD::VZEXT_LOAD"; - case X86ISD::VZEXT: return "X86ISD::VZEXT"; - case X86ISD::VSEXT: return "X86ISD::VSEXT"; case X86ISD::VTRUNC: return "X86ISD::VTRUNC"; case X86ISD::VTRUNCS: return "X86ISD::VTRUNCS"; case X86ISD::VTRUNCUS: return "X86ISD::VTRUNCUS"; + case X86ISD::VMTRUNC: return "X86ISD::VMTRUNC"; + case X86ISD::VMTRUNCS: return "X86ISD::VMTRUNCS"; + case X86ISD::VMTRUNCUS: return "X86ISD::VMTRUNCUS"; case X86ISD::VTRUNCSTORES: return "X86ISD::VTRUNCSTORES"; case X86ISD::VTRUNCSTOREUS: return "X86ISD::VTRUNCSTOREUS"; case X86ISD::VMTRUNCSTORES: return "X86ISD::VMTRUNCSTORES"; @@ -26005,6 +27191,7 @@ const char *X86TargetLowering::getTargetNodeName(unsigned Opcode) const { case X86ISD::VFPEXT_RND: return "X86ISD::VFPEXT_RND"; case X86ISD::VFPEXTS_RND: return "X86ISD::VFPEXTS_RND"; case X86ISD::VFPROUND: return "X86ISD::VFPROUND"; + case X86ISD::VMFPROUND: return "X86ISD::VMFPROUND"; case X86ISD::VFPROUND_RND: return "X86ISD::VFPROUND_RND"; case X86ISD::VFPROUNDS_RND: return "X86ISD::VFPROUNDS_RND"; case X86ISD::VSHLDQ: return "X86ISD::VSHLDQ"; @@ -26029,16 +27216,11 @@ const char *X86TargetLowering::getTargetNodeName(unsigned Opcode) const { case X86ISD::SBB: return "X86ISD::SBB"; case X86ISD::SMUL: return "X86ISD::SMUL"; case X86ISD::UMUL: return "X86ISD::UMUL"; - case X86ISD::SMUL8: return "X86ISD::SMUL8"; - case X86ISD::UMUL8: return "X86ISD::UMUL8"; - case X86ISD::SDIVREM8_SEXT_HREG: return "X86ISD::SDIVREM8_SEXT_HREG"; - case X86ISD::UDIVREM8_ZEXT_HREG: return "X86ISD::UDIVREM8_ZEXT_HREG"; - case X86ISD::INC: return "X86ISD::INC"; - case X86ISD::DEC: return "X86ISD::DEC"; case X86ISD::OR: return "X86ISD::OR"; case X86ISD::XOR: return "X86ISD::XOR"; case X86ISD::AND: return "X86ISD::AND"; case X86ISD::BEXTR: return "X86ISD::BEXTR"; + case X86ISD::BZHI: return "X86ISD::BZHI"; case X86ISD::MUL_IMM: return "X86ISD::MUL_IMM"; case X86ISD::MOVMSK: return "X86ISD::MOVMSK"; case X86ISD::PTEST: return "X86ISD::PTEST"; @@ -26136,7 +27318,6 @@ const char *X86TargetLowering::getTargetNodeName(unsigned Opcode) const { case X86ISD::XTEST: return 
"X86ISD::XTEST"; case X86ISD::COMPRESS: return "X86ISD::COMPRESS"; case X86ISD::EXPAND: return "X86ISD::EXPAND"; - case X86ISD::SELECT: return "X86ISD::SELECT"; case X86ISD::SELECTS: return "X86ISD::SELECTS"; case X86ISD::ADDSUB: return "X86ISD::ADDSUB"; case X86ISD::RCP14: return "X86ISD::RCP14"; @@ -26162,16 +27343,18 @@ const char *X86TargetLowering::getTargetNodeName(unsigned Opcode) const { case X86ISD::FGETEXPS_RND: return "X86ISD::FGETEXPS_RND"; case X86ISD::SCALEF: return "X86ISD::SCALEF"; case X86ISD::SCALEFS: return "X86ISD::SCALEFS"; - case X86ISD::ADDS: return "X86ISD::ADDS"; - case X86ISD::SUBS: return "X86ISD::SUBS"; case X86ISD::AVG: return "X86ISD::AVG"; case X86ISD::MULHRS: return "X86ISD::MULHRS"; case X86ISD::SINT_TO_FP_RND: return "X86ISD::SINT_TO_FP_RND"; case X86ISD::UINT_TO_FP_RND: return "X86ISD::UINT_TO_FP_RND"; case X86ISD::CVTTP2SI: return "X86ISD::CVTTP2SI"; case X86ISD::CVTTP2UI: return "X86ISD::CVTTP2UI"; + case X86ISD::MCVTTP2SI: return "X86ISD::MCVTTP2SI"; + case X86ISD::MCVTTP2UI: return "X86ISD::MCVTTP2UI"; case X86ISD::CVTTP2SI_RND: return "X86ISD::CVTTP2SI_RND"; case X86ISD::CVTTP2UI_RND: return "X86ISD::CVTTP2UI_RND"; + case X86ISD::CVTTS2SI: return "X86ISD::CVTTS2SI"; + case X86ISD::CVTTS2UI: return "X86ISD::CVTTS2UI"; case X86ISD::CVTTS2SI_RND: return "X86ISD::CVTTS2SI_RND"; case X86ISD::CVTTS2UI_RND: return "X86ISD::CVTTS2UI_RND"; case X86ISD::CVTSI2P: return "X86ISD::CVTSI2P"; @@ -26182,12 +27365,17 @@ const char *X86TargetLowering::getTargetNodeName(unsigned Opcode) const { case X86ISD::SCALAR_SINT_TO_FP_RND: return "X86ISD::SCALAR_SINT_TO_FP_RND"; case X86ISD::SCALAR_UINT_TO_FP_RND: return "X86ISD::SCALAR_UINT_TO_FP_RND"; case X86ISD::CVTPS2PH: return "X86ISD::CVTPS2PH"; + case X86ISD::MCVTPS2PH: return "X86ISD::MCVTPS2PH"; case X86ISD::CVTPH2PS: return "X86ISD::CVTPH2PS"; case X86ISD::CVTPH2PS_RND: return "X86ISD::CVTPH2PS_RND"; case X86ISD::CVTP2SI: return "X86ISD::CVTP2SI"; case X86ISD::CVTP2UI: return "X86ISD::CVTP2UI"; + case X86ISD::MCVTP2SI: return "X86ISD::MCVTP2SI"; + case X86ISD::MCVTP2UI: return "X86ISD::MCVTP2UI"; case X86ISD::CVTP2SI_RND: return "X86ISD::CVTP2SI_RND"; case X86ISD::CVTP2UI_RND: return "X86ISD::CVTP2UI_RND"; + case X86ISD::CVTS2SI: return "X86ISD::CVTS2SI"; + case X86ISD::CVTS2UI: return "X86ISD::CVTS2UI"; case X86ISD::CVTS2SI_RND: return "X86ISD::CVTS2SI_RND"; case X86ISD::CVTS2UI_RND: return "X86ISD::CVTS2UI_RND"; case X86ISD::LWPINS: return "X86ISD::LWPINS"; @@ -26321,6 +27509,10 @@ bool X86TargetLowering::isLegalAddImmediate(int64_t Imm) const { return isInt<32>(Imm); } +bool X86TargetLowering::isLegalStoreImmediate(int64_t Imm) const { + return isInt<32>(Imm); +} + bool X86TargetLowering::isTruncateFree(EVT VT1, EVT VT2) const { if (!VT1.isInteger() || !VT2.isInteger()) return false; @@ -26434,7 +27626,7 @@ bool X86TargetLowering::isVectorClearMaskLegal(ArrayRef<int> Mask, bool X86TargetLowering::areJTsAllowed(const Function *Fn) const { // If the subtarget is using retpolines, we need to not generate jump tables. - if (Subtarget.useRetpoline()) + if (Subtarget.useRetpolineIndirectBranches()) return false; // Otherwise, fallback on the generic logic. 
@@ -26633,8 +27825,8 @@ X86TargetLowering::EmitVAARG64WithCustomInserter(MachineInstr &MI, // Memory Reference assert(MI.hasOneMemOperand() && "Expected VAARG_64 to have one memoperand"); - MachineInstr::mmo_iterator MMOBegin = MI.memoperands_begin(); - MachineInstr::mmo_iterator MMOEnd = MI.memoperands_end(); + SmallVector<MachineMemOperand *, 1> MMOs(MI.memoperands_begin(), + MI.memoperands_end()); // Machine Information const TargetInstrInfo *TII = Subtarget.getInstrInfo(); @@ -26732,7 +27924,7 @@ X86TargetLowering::EmitVAARG64WithCustomInserter(MachineInstr &MI, .add(Index) .addDisp(Disp, UseFPOffset ? 4 : 0) .add(Segment) - .setMemRefs(MMOBegin, MMOEnd); + .setMemRefs(MMOs); // Check if there is enough room left to pull this argument. BuildMI(thisMBB, DL, TII->get(X86::CMP32ri)) @@ -26757,7 +27949,7 @@ X86TargetLowering::EmitVAARG64WithCustomInserter(MachineInstr &MI, .add(Index) .addDisp(Disp, 16) .add(Segment) - .setMemRefs(MMOBegin, MMOEnd); + .setMemRefs(MMOs); // Zero-extend the offset unsigned OffsetReg64 = MRI.createVirtualRegister(AddrRegClass); @@ -26785,7 +27977,7 @@ X86TargetLowering::EmitVAARG64WithCustomInserter(MachineInstr &MI, .addDisp(Disp, UseFPOffset ? 4 : 0) .add(Segment) .addReg(NextOffsetReg) - .setMemRefs(MMOBegin, MMOEnd); + .setMemRefs(MMOs); // Jump to endMBB BuildMI(offsetMBB, DL, TII->get(X86::JMP_1)) @@ -26804,7 +27996,7 @@ X86TargetLowering::EmitVAARG64WithCustomInserter(MachineInstr &MI, .add(Index) .addDisp(Disp, 8) .add(Segment) - .setMemRefs(MMOBegin, MMOEnd); + .setMemRefs(MMOs); // If we need to align it, do so. Otherwise, just copy the address // to OverflowDestReg. @@ -26841,7 +28033,7 @@ X86TargetLowering::EmitVAARG64WithCustomInserter(MachineInstr &MI, .addDisp(Disp, 8) .add(Segment) .addReg(NextAddrReg) - .setMemRefs(MMOBegin, MMOEnd); + .setMemRefs(MMOs); // If we branched, emit the PHI to the front of endMBB. if (offsetMBB) { @@ -26981,19 +28173,17 @@ static bool isCMOVPseudo(MachineInstr &MI) { case X86::CMOV_RFP32: case X86::CMOV_RFP64: case X86::CMOV_RFP80: - case X86::CMOV_V2F64: - case X86::CMOV_V2I64: - case X86::CMOV_V4F32: - case X86::CMOV_V4F64: - case X86::CMOV_V4I64: - case X86::CMOV_V16F32: - case X86::CMOV_V8F32: - case X86::CMOV_V8F64: - case X86::CMOV_V8I64: - case X86::CMOV_V8I1: - case X86::CMOV_V16I1: - case X86::CMOV_V32I1: - case X86::CMOV_V64I1: + case X86::CMOV_VR128: + case X86::CMOV_VR128X: + case X86::CMOV_VR256: + case X86::CMOV_VR256X: + case X86::CMOV_VR512: + case X86::CMOV_VK2: + case X86::CMOV_VK4: + case X86::CMOV_VK8: + case X86::CMOV_VK16: + case X86::CMOV_VK32: + case X86::CMOV_VK64: return true; default: @@ -27815,8 +29005,8 @@ void X86TargetLowering::emitSetJmpShadowStackFix(MachineInstr &MI, MachineInstrBuilder MIB; // Memory Reference. - MachineInstr::mmo_iterator MMOBegin = MI.memoperands_begin(); - MachineInstr::mmo_iterator MMOEnd = MI.memoperands_end(); + SmallVector<MachineMemOperand *, 2> MMOs(MI.memoperands_begin(), + MI.memoperands_end()); // Initialize a register with zero. 
MVT PVT = getPointerTy(MF->getDataLayout()); @@ -27845,7 +29035,7 @@ void X86TargetLowering::emitSetJmpShadowStackFix(MachineInstr &MI, MIB.add(MI.getOperand(MemOpndSlot + i)); } MIB.addReg(SSPCopyReg); - MIB.setMemRefs(MMOBegin, MMOEnd); + MIB.setMemRefs(MMOs); } MachineBasicBlock * @@ -27861,8 +29051,8 @@ X86TargetLowering::emitEHSjLjSetJmp(MachineInstr &MI, MachineFunction::iterator I = ++MBB->getIterator(); // Memory Reference - MachineInstr::mmo_iterator MMOBegin = MI.memoperands_begin(); - MachineInstr::mmo_iterator MMOEnd = MI.memoperands_end(); + SmallVector<MachineMemOperand *, 2> MMOs(MI.memoperands_begin(), + MI.memoperands_end()); unsigned DstReg; unsigned MemOpndSlot = 0; @@ -27956,7 +29146,7 @@ X86TargetLowering::emitEHSjLjSetJmp(MachineInstr &MI, MIB.addReg(LabelReg); else MIB.addMBB(restoreMBB); - MIB.setMemRefs(MMOBegin, MMOEnd); + MIB.setMemRefs(MMOs); if (MF->getMMI().getModule()->getModuleFlag("cf-protection-return")) { emitSetJmpShadowStackFix(MI, thisMBB); @@ -28017,8 +29207,8 @@ X86TargetLowering::emitLongJmpShadowStackFix(MachineInstr &MI, MachineRegisterInfo &MRI = MF->getRegInfo(); // Memory Reference - MachineInstr::mmo_iterator MMOBegin = MI.memoperands_begin(); - MachineInstr::mmo_iterator MMOEnd = MI.memoperands_end(); + SmallVector<MachineMemOperand *, 2> MMOs(MI.memoperands_begin(), + MI.memoperands_end()); MVT PVT = getPointerTy(MF->getDataLayout()); const TargetRegisterClass *PtrRC = getRegClassFor(PVT); @@ -28100,12 +29290,16 @@ X86TargetLowering::emitLongJmpShadowStackFix(MachineInstr &MI, MachineInstrBuilder MIB = BuildMI(fallMBB, DL, TII->get(PtrLoadOpc), PrevSSPReg); for (unsigned i = 0; i < X86::AddrNumOperands; ++i) { + const MachineOperand &MO = MI.getOperand(i); if (i == X86::AddrDisp) - MIB.addDisp(MI.getOperand(i), SPPOffset); + MIB.addDisp(MO, SPPOffset); + else if (MO.isReg()) // Don't add the whole operand, we don't want to + // preserve kill flags. + MIB.addReg(MO.getReg()); else - MIB.add(MI.getOperand(i)); + MIB.add(MO); } - MIB.setMemRefs(MMOBegin, MMOEnd); + MIB.setMemRefs(MMOs); // Subtract the current SSP from the previous SSP. unsigned SspSubReg = MRI.createVirtualRegister(PtrRC); @@ -28189,8 +29383,8 @@ X86TargetLowering::emitEHSjLjLongJmp(MachineInstr &MI, MachineRegisterInfo &MRI = MF->getRegInfo(); // Memory Reference - MachineInstr::mmo_iterator MMOBegin = MI.memoperands_begin(); - MachineInstr::mmo_iterator MMOEnd = MI.memoperands_end(); + SmallVector<MachineMemOperand *, 2> MMOs(MI.memoperands_begin(), + MI.memoperands_end()); MVT PVT = getPointerTy(MF->getDataLayout()); assert((PVT == MVT::i64 || PVT == MVT::i32) && @@ -28221,19 +29415,29 @@ X86TargetLowering::emitEHSjLjLongJmp(MachineInstr &MI, // Reload FP MIB = BuildMI(*thisMBB, MI, DL, TII->get(PtrLoadOpc), FP); - for (unsigned i = 0; i < X86::AddrNumOperands; ++i) - MIB.add(MI.getOperand(i)); - MIB.setMemRefs(MMOBegin, MMOEnd); + for (unsigned i = 0; i < X86::AddrNumOperands; ++i) { + const MachineOperand &MO = MI.getOperand(i); + if (MO.isReg()) // Don't add the whole operand, we don't want to + // preserve kill flags. + MIB.addReg(MO.getReg()); + else + MIB.add(MO); + } + MIB.setMemRefs(MMOs); // Reload IP MIB = BuildMI(*thisMBB, MI, DL, TII->get(PtrLoadOpc), Tmp); for (unsigned i = 0; i < X86::AddrNumOperands; ++i) { + const MachineOperand &MO = MI.getOperand(i); if (i == X86::AddrDisp) - MIB.addDisp(MI.getOperand(i), LabelOffset); + MIB.addDisp(MO, LabelOffset); + else if (MO.isReg()) // Don't add the whole operand, we don't want to + // preserve kill flags. 
+ MIB.addReg(MO.getReg()); else - MIB.add(MI.getOperand(i)); + MIB.add(MO); } - MIB.setMemRefs(MMOBegin, MMOEnd); + MIB.setMemRefs(MMOs); // Reload SP MIB = BuildMI(*thisMBB, MI, DL, TII->get(PtrLoadOpc), SP); @@ -28241,9 +29445,10 @@ X86TargetLowering::emitEHSjLjLongJmp(MachineInstr &MI, if (i == X86::AddrDisp) MIB.addDisp(MI.getOperand(i), SPOffset); else - MIB.add(MI.getOperand(i)); + MIB.add(MI.getOperand(i)); // We can preserve the kill flags here, it's + // the last instruction of the expansion. } - MIB.setMemRefs(MMOBegin, MMOEnd); + MIB.setMemRefs(MMOs); // Jump BuildMI(*thisMBB, MI, DL, TII->get(IJmpOpc)).addReg(Tmp); @@ -28562,26 +29767,23 @@ X86TargetLowering::EmitInstrWithCustomInserter(MachineInstr &MI, return EmitLoweredTLSCall(MI, BB); case X86::CMOV_FR32: case X86::CMOV_FR64: - case X86::CMOV_F128: case X86::CMOV_GR8: case X86::CMOV_GR16: case X86::CMOV_GR32: case X86::CMOV_RFP32: case X86::CMOV_RFP64: case X86::CMOV_RFP80: - case X86::CMOV_V2F64: - case X86::CMOV_V2I64: - case X86::CMOV_V4F32: - case X86::CMOV_V4F64: - case X86::CMOV_V4I64: - case X86::CMOV_V16F32: - case X86::CMOV_V8F32: - case X86::CMOV_V8F64: - case X86::CMOV_V8I64: - case X86::CMOV_V8I1: - case X86::CMOV_V16I1: - case X86::CMOV_V32I1: - case X86::CMOV_V64I1: + case X86::CMOV_VR128: + case X86::CMOV_VR128X: + case X86::CMOV_VR256: + case X86::CMOV_VR256X: + case X86::CMOV_VR512: + case X86::CMOV_VK2: + case X86::CMOV_VK4: + case X86::CMOV_VK8: + case X86::CMOV_VK16: + case X86::CMOV_VK32: + case X86::CMOV_VK64: return EmitLoweredSelect(MI, BB); case X86::RDFLAGS32: @@ -28890,11 +30092,12 @@ void X86TargetLowering::computeKnownBitsForTargetNode(const SDValue Op, EVT SrcVT = Src.getValueType(); APInt DemandedElt = APInt::getOneBitSet(SrcVT.getVectorNumElements(), Op.getConstantOperandVal(1)); - DAG.computeKnownBits(Src, Known, DemandedElt, Depth + 1); + Known = DAG.computeKnownBits(Src, DemandedElt, Depth + 1); Known = Known.zextOrTrunc(BitWidth); Known.Zero.setBitsFrom(SrcVT.getScalarSizeInBits()); break; } + case X86ISD::VSRAI: case X86ISD::VSHLI: case X86ISD::VSRLI: { if (auto *ShiftImm = dyn_cast<ConstantSDNode>(Op.getOperand(1))) { @@ -28903,72 +30106,62 @@ void X86TargetLowering::computeKnownBitsForTargetNode(const SDValue Op, break; } - DAG.computeKnownBits(Op.getOperand(0), Known, DemandedElts, Depth + 1); + Known = DAG.computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1); unsigned ShAmt = ShiftImm->getZExtValue(); if (Opc == X86ISD::VSHLI) { Known.Zero <<= ShAmt; Known.One <<= ShAmt; // Low bits are known zero. Known.Zero.setLowBits(ShAmt); - } else { + } else if (Opc == X86ISD::VSRLI) { Known.Zero.lshrInPlace(ShAmt); Known.One.lshrInPlace(ShAmt); // High bits are known zero. Known.Zero.setHighBits(ShAmt); + } else { + Known.Zero.ashrInPlace(ShAmt); + Known.One.ashrInPlace(ShAmt); } } break; } case X86ISD::PACKUS: { // PACKUS is just a truncation if the upper half is zero. - // TODO: Add DemandedElts support. 
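The VSHLI/VSRLI/VSRAI known-bits hunk a short way above shifts the Zero and One masks just as the instruction shifts data, with the newly handled VSRAI case using an arithmetic shift. A scalar model of the VSHLI branch, using a stand-in struct rather than LLVM's KnownBits (sketch only, shift amount assumed below the bit width):

#include <cassert>
#include <cstdint>

struct KnownBitsModel { uint32_t Zero = 0, One = 0; }; // bit set => provably 0 / provably 1

// Both masks shift left and the vacated low bits become known zero.
static KnownBitsModel knownShl(KnownBitsModel K, unsigned ShAmt) {
  K.Zero <<= ShAmt;
  K.One  <<= ShAmt;
  K.Zero |= (1u << ShAmt) - 1;
  return K;
}

int main() {
  KnownBitsModel K;                              // nothing known about the input
  K = knownShl(K, 4);
  assert((K.Zero & 0xFu) == 0xFu && K.One == 0); // only the low four bits are known (zero)
}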
+ APInt DemandedLHS, DemandedRHS; + getPackDemandedElts(VT, DemandedElts, DemandedLHS, DemandedRHS); + + Known.One = APInt::getAllOnesValue(BitWidth * 2); + Known.Zero = APInt::getAllOnesValue(BitWidth * 2); + KnownBits Known2; - DAG.computeKnownBits(Op.getOperand(0), Known, Depth + 1); - DAG.computeKnownBits(Op.getOperand(1), Known2, Depth + 1); - Known.One &= Known2.One; - Known.Zero &= Known2.Zero; + if (!!DemandedLHS) { + Known2 = DAG.computeKnownBits(Op.getOperand(0), DemandedLHS, Depth + 1); + Known.One &= Known2.One; + Known.Zero &= Known2.Zero; + } + if (!!DemandedRHS) { + Known2 = DAG.computeKnownBits(Op.getOperand(1), DemandedRHS, Depth + 1); + Known.One &= Known2.One; + Known.Zero &= Known2.Zero; + } + if (Known.countMinLeadingZeros() < BitWidth) Known.resetAll(); Known = Known.trunc(BitWidth); break; } - case X86ISD::VZEXT: { - // TODO: Add DemandedElts support. - SDValue N0 = Op.getOperand(0); - unsigned NumElts = VT.getVectorNumElements(); - - EVT SrcVT = N0.getValueType(); - unsigned InNumElts = SrcVT.getVectorNumElements(); - unsigned InBitWidth = SrcVT.getScalarSizeInBits(); - assert(InNumElts >= NumElts && "Illegal VZEXT input"); - - Known = KnownBits(InBitWidth); - APInt DemandedSrcElts = APInt::getLowBitsSet(InNumElts, NumElts); - DAG.computeKnownBits(N0, Known, DemandedSrcElts, Depth + 1); - Known = Known.zext(BitWidth); - Known.Zero.setBitsFrom(InBitWidth); - break; - } case X86ISD::CMOV: { - DAG.computeKnownBits(Op.getOperand(1), Known, Depth+1); + Known = DAG.computeKnownBits(Op.getOperand(1), Depth+1); // If we don't know any bits, early out. if (Known.isUnknown()) break; - KnownBits Known2; - DAG.computeKnownBits(Op.getOperand(0), Known2, Depth+1); + KnownBits Known2 = DAG.computeKnownBits(Op.getOperand(0), Depth+1); // Only known if known in both the LHS and RHS. Known.One &= Known2.One; Known.Zero &= Known2.Zero; break; } - case X86ISD::UDIVREM8_ZEXT_HREG: - // TODO: Support more than just the zero extended bits? - if (Op.getResNo() != 1) - break; - // The remainder is zero extended. - Known.Zero.setBitsFrom(8); - break; } // Handle target shuffles. @@ -29013,8 +30206,8 @@ void X86TargetLowering::computeKnownBitsForTargetNode(const SDValue Op, for (unsigned i = 0; i != NumOps && !Known.isUnknown(); ++i) { if (!DemandedOps[i]) continue; - KnownBits Known2; - DAG.computeKnownBits(Ops[i], Known2, DemandedOps[i], Depth + 1); + KnownBits Known2 = + DAG.computeKnownBits(Ops[i], DemandedOps[i], Depth + 1); Known.One &= Known2.One; Known.Zero &= Known2.Zero; } @@ -29033,14 +30226,6 @@ unsigned X86TargetLowering::ComputeNumSignBitsForTargetNode( // SETCC_CARRY sets the dest to ~0 for true or 0 for false. return VTBits; - case X86ISD::VSEXT: { - // TODO: Add DemandedElts support. - SDValue Src = Op.getOperand(0); - unsigned Tmp = DAG.ComputeNumSignBits(Src, Depth + 1); - Tmp += VTBits - Src.getScalarValueSizeInBits(); - return Tmp; - } - case X86ISD::VTRUNC: { // TODO: Add DemandedElts support. SDValue Src = Op.getOperand(0); @@ -29054,10 +30239,16 @@ unsigned X86TargetLowering::ComputeNumSignBitsForTargetNode( case X86ISD::PACKSS: { // PACKSS is just a truncation if the sign bits extend to the packed size. - // TODO: Add DemandedElts support. 
+ APInt DemandedLHS, DemandedRHS; + getPackDemandedElts(Op.getValueType(), DemandedElts, DemandedLHS, + DemandedRHS); + unsigned SrcBits = Op.getOperand(0).getScalarValueSizeInBits(); - unsigned Tmp0 = DAG.ComputeNumSignBits(Op.getOperand(0), Depth + 1); - unsigned Tmp1 = DAG.ComputeNumSignBits(Op.getOperand(1), Depth + 1); + unsigned Tmp0 = SrcBits, Tmp1 = SrcBits; + if (!!DemandedLHS) + Tmp0 = DAG.ComputeNumSignBits(Op.getOperand(0), DemandedLHS, Depth + 1); + if (!!DemandedRHS) + Tmp1 = DAG.ComputeNumSignBits(Op.getOperand(1), DemandedRHS, Depth + 1); unsigned Tmp = std::min(Tmp0, Tmp1); if (Tmp > (SrcBits - VTBits)) return Tmp - (SrcBits - VTBits); @@ -29099,12 +30290,6 @@ unsigned X86TargetLowering::ComputeNumSignBitsForTargetNode( unsigned Tmp1 = DAG.ComputeNumSignBits(Op.getOperand(1), Depth+1); return std::min(Tmp0, Tmp1); } - case X86ISD::SDIVREM8_SEXT_HREG: - // TODO: Support more than just the sign extended bits? - if (Op.getResNo() != 1) - break; - // The remainder is sign extended. - return VTBits - 7; } // Fallback case. @@ -29117,21 +30302,6 @@ SDValue X86TargetLowering::unwrapAddress(SDValue N) const { return N; } -/// Returns true (and the GlobalValue and the offset) if the node is a -/// GlobalAddress + offset. -bool X86TargetLowering::isGAPlusOffset(SDNode *N, - const GlobalValue* &GA, - int64_t &Offset) const { - if (N->getOpcode() == X86ISD::Wrapper) { - if (isa<GlobalAddressSDNode>(N->getOperand(0))) { - GA = cast<GlobalAddressSDNode>(N->getOperand(0))->getGlobal(); - Offset = cast<GlobalAddressSDNode>(N->getOperand(0))->getOffset(); - return true; - } - } - return TargetLowering::isGAPlusOffset(N, GA, Offset); -} - // Attempt to match a combined shuffle mask against supported unary shuffle // instructions. // TODO: Investigate sharing more of this with shuffle lowering. @@ -29170,10 +30340,12 @@ static bool matchUnaryVectorShuffle(MVT MaskVT, ArrayRef<int> Mask, MVT::getIntegerVT(MaskEltSize); SrcVT = MVT::getVectorVT(ScalarTy, SrcSize / MaskEltSize); - if (SrcVT.getSizeInBits() != MaskVT.getSizeInBits()) { + if (SrcVT.getSizeInBits() != MaskVT.getSizeInBits()) V1 = extractSubVector(V1, 0, DAG, DL, SrcSize); - Shuffle = unsigned(X86ISD::VZEXT); - } else + + if (SrcVT.getVectorNumElements() == NumDstElts) + Shuffle = unsigned(ISD::ZERO_EXTEND); + else Shuffle = unsigned(ISD::ZERO_EXTEND_VECTOR_INREG); DstVT = MVT::getIntegerVT(Scale * MaskEltSize); @@ -29430,9 +30602,10 @@ static bool matchBinaryVectorShuffle(MVT MaskVT, ArrayRef<int> Mask, } } - // Attempt to match against either a unary or binary PACKSS/PACKUS shuffle. - // TODO add support for 256/512-bit types. - if ((MaskVT == MVT::v8i16 || MaskVT == MVT::v16i8) && Subtarget.hasSSE2()) { + // Attempt to match against either an unary or binary PACKSS/PACKUS shuffle. + if (((MaskVT == MVT::v8i16 || MaskVT == MVT::v16i8) && Subtarget.hasSSE2()) || + ((MaskVT == MVT::v16i16 || MaskVT == MVT::v32i8) && Subtarget.hasInt256()) || + ((MaskVT == MVT::v32i16 || MaskVT == MVT::v64i8) && Subtarget.hasBWI())) { if (matchVectorShuffleWithPACK(MaskVT, SrcVT, V1, V2, Shuffle, Mask, DAG, Subtarget)) { DstVT = MaskVT; @@ -29622,7 +30795,8 @@ static bool matchBinaryPermuteVectorShuffle( /// instruction but should only be used to replace chains over a certain depth. 
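The PACKSS/PACKUS hunks above no longer treat the pack inputs monolithically: getPackDemandedElts splits the demanded result elements between the two sources, since each 128-bit lane of a pack takes its low half from the first operand and its high half from the second. A rough standalone model of that mapping (illustrative helper, not the LLVM routine):

#include <cassert>
#include <vector>

static void splitPackDemanded(const std::vector<bool> &Demanded,
                              unsigned EltsPerLane, // result elements per 128-bit lane
                              std::vector<bool> &LHS, std::vector<bool> &RHS) {
  unsigned NumElts = (unsigned)Demanded.size();
  LHS.assign(NumElts / 2, false);
  RHS.assign(NumElts / 2, false);
  for (unsigned i = 0; i != NumElts; ++i) {
    if (!Demanded[i]) continue;
    unsigned Lane = i / EltsPerLane, Idx = i % EltsPerLane, Half = EltsPerLane / 2;
    if (Idx < Half) LHS[Lane * Half + Idx] = true;          // low half of the lane
    else            RHS[Lane * Half + (Idx - Half)] = true; // high half of the lane
  }
}

int main() {
  // v16i8 result packed from two v8i16 sources (a single 128-bit lane).
  std::vector<bool> Demanded(16, false), L, R;
  Demanded[0] = Demanded[9] = true;   // want result bytes 0 and 9
  splitPackDemanded(Demanded, 16, L, R);
  assert(L[0] && R[1]);               // byte 0 comes from LHS[0], byte 9 from RHS[1]
}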
static SDValue combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root, ArrayRef<int> BaseMask, int Depth, - bool HasVariableMask, SelectionDAG &DAG, + bool HasVariableMask, + bool AllowVariableMask, SelectionDAG &DAG, const X86Subtarget &Subtarget) { assert(!BaseMask.empty() && "Cannot combine an empty shuffle mask!"); assert((Inputs.size() == 1 || Inputs.size() == 2) && @@ -29835,7 +31009,7 @@ static SDValue combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root, // Depth threshold above which we can efficiently use variable mask shuffles. int VariableShuffleDepth = Subtarget.hasFastVariableShuffle() ? 2 : 3; - bool AllowVariableMask = (Depth >= VariableShuffleDepth) || HasVariableMask; + AllowVariableMask &= (Depth >= VariableShuffleDepth) || HasVariableMask; bool MaskContainsZeros = any_of(Mask, [](int M) { return M == SM_SentinelZero; }); @@ -30169,7 +31343,8 @@ static SDValue combineX86ShufflesConstants(ArrayRef<SDValue> Ops, static SDValue combineX86ShufflesRecursively( ArrayRef<SDValue> SrcOps, int SrcOpIndex, SDValue Root, ArrayRef<int> RootMask, ArrayRef<const SDNode *> SrcNodes, unsigned Depth, - bool HasVariableMask, SelectionDAG &DAG, const X86Subtarget &Subtarget) { + bool HasVariableMask, bool AllowVariableMask, SelectionDAG &DAG, + const X86Subtarget &Subtarget) { // Bound the depth of our recursive combine because this is ultimately // quadratic in nature. const unsigned MaxRecursionDepth = 8; @@ -30195,30 +31370,36 @@ static SDValue combineX86ShufflesRecursively( if (!resolveTargetShuffleInputs(Op, OpInputs, OpMask, DAG)) return SDValue(); - assert(OpInputs.size() <= 2 && "Too many shuffle inputs"); + // TODO - Add support for more than 2 inputs. + if (2 < OpInputs.size()) + return SDValue(); + SDValue Input0 = (OpInputs.size() > 0 ? OpInputs[0] : SDValue()); SDValue Input1 = (OpInputs.size() > 1 ? OpInputs[1] : SDValue()); // Add the inputs to the Ops list, avoiding duplicates. SmallVector<SDValue, 16> Ops(SrcOps.begin(), SrcOps.end()); - int InputIdx0 = -1, InputIdx1 = -1; - for (int i = 0, e = Ops.size(); i < e; ++i) { - SDValue BC = peekThroughBitcasts(Ops[i]); - if (Input0 && BC == peekThroughBitcasts(Input0)) - InputIdx0 = i; - if (Input1 && BC == peekThroughBitcasts(Input1)) - InputIdx1 = i; - } + auto AddOp = [&Ops](SDValue Input, int InsertionPoint) -> int { + if (!Input) + return -1; + // Attempt to find an existing match. + SDValue InputBC = peekThroughBitcasts(Input); + for (int i = 0, e = Ops.size(); i < e; ++i) + if (InputBC == peekThroughBitcasts(Ops[i])) + return i; + // Match failed - should we replace an existing Op? + if (InsertionPoint >= 0) { + Ops[InsertionPoint] = Input; + return InsertionPoint; + } + // Add to the end of the Ops list. + Ops.push_back(Input); + return Ops.size() - 1; + }; - if (Input0 && InputIdx0 < 0) { - InputIdx0 = SrcOpIndex; - Ops[SrcOpIndex] = Input0; - } - if (Input1 && InputIdx1 < 0) { - InputIdx1 = Ops.size(); - Ops.push_back(Input1); - } + int InputIdx0 = AddOp(Input0, SrcOpIndex); + int InputIdx1 = AddOp(Input1, -1); assert(((RootMask.size() > OpMask.size() && RootMask.size() % OpMask.size() == 0) || @@ -30324,18 +31505,23 @@ static SDValue combineX86ShufflesRecursively( CombinedNodes.push_back(Op.getNode()); // See if we can recurse into each shuffle source op (if it's a target - // shuffle). The source op should only be combined if it either has a - // single use (i.e. current Op) or all its users have already been combined. + // shuffle). 
The source op should only be generally combined if it either has + // a single use (i.e. current Op) or all its users have already been combined, + // if not then we can still combine but should prevent generation of variable + // shuffles to avoid constant pool bloat. // Don't recurse if we already have more source ops than we can combine in // the remaining recursion depth. if (Ops.size() < (MaxRecursionDepth - Depth)) { - for (int i = 0, e = Ops.size(); i < e; ++i) + for (int i = 0, e = Ops.size(); i < e; ++i) { + bool AllowVar = false; if (Ops[i].getNode()->hasOneUse() || SDNode::areOnlyUsersOf(CombinedNodes, Ops[i].getNode())) - if (SDValue Res = combineX86ShufflesRecursively( - Ops, i, Root, Mask, CombinedNodes, Depth + 1, HasVariableMask, - DAG, Subtarget)) - return Res; + AllowVar = AllowVariableMask; + if (SDValue Res = combineX86ShufflesRecursively( + Ops, i, Root, Mask, CombinedNodes, Depth + 1, HasVariableMask, + AllowVar, DAG, Subtarget)) + return Res; + } } // Attempt to constant fold all of the constant source ops. @@ -30365,8 +31551,8 @@ static SDValue combineX86ShufflesRecursively( } // Finally, try to combine into a single shuffle instruction. - return combineX86ShuffleChain(Ops, Root, Mask, Depth, HasVariableMask, DAG, - Subtarget); + return combineX86ShuffleChain(Ops, Root, Mask, Depth, HasVariableMask, + AllowVariableMask, DAG, Subtarget); } /// Get the PSHUF-style mask from PSHUF node. @@ -30545,74 +31731,6 @@ combineRedundantDWordShuffle(SDValue N, MutableArrayRef<int> Mask, return V; } -/// Search for a combinable shuffle across a chain ending in pshuflw or -/// pshufhw. -/// -/// We walk up the chain, skipping shuffles of the other half and looking -/// through shuffles which switch halves trying to find a shuffle of the same -/// pair of dwords. -static bool combineRedundantHalfShuffle(SDValue N, MutableArrayRef<int> Mask, - SelectionDAG &DAG, - TargetLowering::DAGCombinerInfo &DCI) { - assert( - (N.getOpcode() == X86ISD::PSHUFLW || N.getOpcode() == X86ISD::PSHUFHW) && - "Called with something other than an x86 128-bit half shuffle!"); - SDLoc DL(N); - unsigned CombineOpcode = N.getOpcode(); - - // Walk up a single-use chain looking for a combinable shuffle. - SDValue V = N.getOperand(0); - for (; V.hasOneUse(); V = V.getOperand(0)) { - switch (V.getOpcode()) { - default: - return false; // Nothing combined! - - case ISD::BITCAST: - // Skip bitcasts as we always know the type for the target specific - // instructions. - continue; - - case X86ISD::PSHUFLW: - case X86ISD::PSHUFHW: - if (V.getOpcode() == CombineOpcode) - break; - - // Other-half shuffles are no-ops. - continue; - } - // Break out of the loop if we break out of the switch. - break; - } - - if (!V.hasOneUse()) - // We fell out of the loop without finding a viable combining instruction. - return false; - - // Combine away the bottom node as its shuffle will be accumulated into - // a preceding shuffle. - DCI.CombineTo(N.getNode(), N.getOperand(0), /*AddTo*/ true); - - // Record the old value. - SDValue Old = V; - - // Merge this node's mask and our incoming mask (adjusted to account for all - // the pshufd instructions encountered). - SmallVector<int, 4> VMask = getPSHUFShuffleMask(V); - for (int &M : Mask) - M = VMask[M]; - V = DAG.getNode(V.getOpcode(), DL, MVT::v8i16, V.getOperand(0), - getV4X86ShuffleImm8ForMask(Mask, DL, DAG)); - - // Check that the shuffles didn't cancel each other out. If not, we need to - // combine to the new one. 
- if (Old != V) - // Replace the combinable shuffle with the combined one, updating all users - // so that we re-evaluate the chain here. - DCI.CombineTo(Old.getNode(), V, /*AddTo*/ true); - - return true; -} - /// Try to combine x86 target specific shuffles. static SDValue combineTargetShuffle(SDValue N, SelectionDAG &DAG, TargetLowering::DAGCombinerInfo &DCI, @@ -30667,7 +31785,7 @@ static SDValue combineTargetShuffle(SDValue N, SelectionDAG &DAG, DemandedMask[i] = i; if (SDValue Res = combineX86ShufflesRecursively( {BC}, 0, BC, DemandedMask, {}, /*Depth*/ 1, - /*HasVarMask*/ false, DAG, Subtarget)) + /*HasVarMask*/ false, /*AllowVarMask*/ true, DAG, Subtarget)) return DAG.getNode(X86ISD::VBROADCAST, DL, VT, DAG.getBitcast(SrcVT, Res)); } @@ -30679,40 +31797,6 @@ static SDValue combineTargetShuffle(SDValue N, SelectionDAG &DAG, Mask = getPSHUFShuffleMask(N); assert(Mask.size() == 4); break; - case X86ISD::UNPCKL: { - // Combine X86ISD::UNPCKL and ISD::VECTOR_SHUFFLE into X86ISD::UNPCKH, in - // which X86ISD::UNPCKL has a ISD::UNDEF operand, and ISD::VECTOR_SHUFFLE - // moves upper half elements into the lower half part. For example: - // - // t2: v16i8 = vector_shuffle<8,9,10,11,12,13,14,15,u,u,u,u,u,u,u,u> t1, - // undef:v16i8 - // t3: v16i8 = X86ISD::UNPCKL undef:v16i8, t2 - // - // will be combined to: - // - // t3: v16i8 = X86ISD::UNPCKH undef:v16i8, t1 - - // This is only for 128-bit vectors. From SSE4.1 onward this combine may not - // happen due to advanced instructions. - if (!VT.is128BitVector()) - return SDValue(); - - auto Op0 = N.getOperand(0); - auto Op1 = N.getOperand(1); - if (Op0.isUndef() && Op1.getOpcode() == ISD::VECTOR_SHUFFLE) { - ArrayRef<int> Mask = cast<ShuffleVectorSDNode>(Op1.getNode())->getMask(); - - unsigned NumElts = VT.getVectorNumElements(); - SmallVector<int, 8> ExpectedMask(NumElts, -1); - std::iota(ExpectedMask.begin(), ExpectedMask.begin() + NumElts / 2, - NumElts / 2); - - auto ShufOp = Op1.getOperand(0); - if (isShuffleEquivalent(Op1, ShufOp, Mask, ExpectedMask)) - return DAG.getNode(X86ISD::UNPCKH, DL, VT, N.getOperand(0), ShufOp); - } - return SDValue(); - } case X86ISD::MOVSD: case X86ISD::MOVSS: { SDValue N0 = N.getOperand(0); @@ -30844,9 +31928,6 @@ static SDValue combineTargetShuffle(SDValue N, SelectionDAG &DAG, case X86ISD::PSHUFHW: assert(VT.getVectorElementType() == MVT::i16 && "Bad word shuffle type!"); - if (combineRedundantHalfShuffle(N, Mask, DAG, DCI)) - return SDValue(); // We combined away this shuffle, so we're done. - // See if this reduces to a PSHUFD which is no more expensive and can // combine with more operations. Note that it has to at least flip the // dwords as otherwise it would have been removed as a no-op. @@ -31286,13 +32367,404 @@ static SDValue combineShuffle(SDNode *N, SelectionDAG &DAG, // a particular chain. if (SDValue Res = combineX86ShufflesRecursively( {Op}, 0, Op, {0}, {}, /*Depth*/ 1, - /*HasVarMask*/ false, DAG, Subtarget)) + /*HasVarMask*/ false, /*AllowVarMask*/ true, DAG, Subtarget)) return Res; + + // Simplify source operands based on shuffle mask. + // TODO - merge this into combineX86ShufflesRecursively. + APInt KnownUndef, KnownZero; + APInt DemandedElts = APInt::getAllOnesValue(VT.getVectorNumElements()); + if (TLI.SimplifyDemandedVectorElts(Op, DemandedElts, KnownUndef, KnownZero, DCI)) + return SDValue(N, 0); + } + + // Look for a truncating shuffle to v2i32 of a PMULUDQ where one of the + // operands is an extend from v2i32 to v2i64. Turn it into a pmulld. 
+ // FIXME: This can probably go away once we default to widening legalization. + if (Subtarget.hasSSE41() && VT == MVT::v4i32 && + N->getOpcode() == ISD::VECTOR_SHUFFLE && + N->getOperand(0).getOpcode() == ISD::BITCAST && + N->getOperand(0).getOperand(0).getOpcode() == X86ISD::PMULUDQ) { + SDValue BC = N->getOperand(0); + SDValue MULUDQ = BC.getOperand(0); + ShuffleVectorSDNode *SVOp = cast<ShuffleVectorSDNode>(N); + ArrayRef<int> Mask = SVOp->getMask(); + if (BC.hasOneUse() && MULUDQ.hasOneUse() && + Mask[0] == 0 && Mask[1] == 2 && Mask[2] == -1 && Mask[3] == -1) { + SDValue Op0 = MULUDQ.getOperand(0); + SDValue Op1 = MULUDQ.getOperand(1); + if (Op0.getOpcode() == ISD::BITCAST && + Op0.getOperand(0).getOpcode() == ISD::VECTOR_SHUFFLE && + Op0.getOperand(0).getValueType() == MVT::v4i32) { + ShuffleVectorSDNode *SVOp0 = + cast<ShuffleVectorSDNode>(Op0.getOperand(0)); + ArrayRef<int> Mask2 = SVOp0->getMask(); + if (Mask2[0] == 0 && Mask2[1] == -1 && + Mask2[2] == 1 && Mask2[3] == -1) { + Op0 = SVOp0->getOperand(0); + Op1 = DAG.getBitcast(MVT::v4i32, Op1); + Op1 = DAG.getVectorShuffle(MVT::v4i32, dl, Op1, Op1, Mask); + return DAG.getNode(ISD::MUL, dl, MVT::v4i32, Op0, Op1); + } + } + if (Op1.getOpcode() == ISD::BITCAST && + Op1.getOperand(0).getOpcode() == ISD::VECTOR_SHUFFLE && + Op1.getOperand(0).getValueType() == MVT::v4i32) { + ShuffleVectorSDNode *SVOp1 = + cast<ShuffleVectorSDNode>(Op1.getOperand(0)); + ArrayRef<int> Mask2 = SVOp1->getMask(); + if (Mask2[0] == 0 && Mask2[1] == -1 && + Mask2[2] == 1 && Mask2[3] == -1) { + Op0 = DAG.getBitcast(MVT::v4i32, Op0); + Op0 = DAG.getVectorShuffle(MVT::v4i32, dl, Op0, Op0, Mask); + Op1 = SVOp1->getOperand(0); + return DAG.getNode(ISD::MUL, dl, MVT::v4i32, Op0, Op1); + } + } + } } return SDValue(); } +bool X86TargetLowering::SimplifyDemandedVectorEltsForTargetNode( + SDValue Op, const APInt &DemandedElts, APInt &KnownUndef, APInt &KnownZero, + TargetLoweringOpt &TLO, unsigned Depth) const { + int NumElts = DemandedElts.getBitWidth(); + unsigned Opc = Op.getOpcode(); + EVT VT = Op.getValueType(); + + // Handle special case opcodes. + switch (Opc) { + case X86ISD::VSHL: + case X86ISD::VSRL: + case X86ISD::VSRA: { + // We only need the bottom 64-bits of the (128-bit) shift amount. + SDValue Amt = Op.getOperand(1); + MVT AmtVT = Amt.getSimpleValueType(); + assert(AmtVT.is128BitVector() && "Unexpected value type"); + APInt AmtUndef, AmtZero; + unsigned NumAmtElts = AmtVT.getVectorNumElements(); + APInt AmtElts = APInt::getLowBitsSet(NumAmtElts, NumAmtElts / 2); + if (SimplifyDemandedVectorElts(Amt, AmtElts, AmtUndef, AmtZero, TLO, + Depth + 1)) + return true; + LLVM_FALLTHROUGH; + } + case X86ISD::VSHLI: + case X86ISD::VSRLI: + case X86ISD::VSRAI: { + SDValue Src = Op.getOperand(0); + APInt SrcUndef; + if (SimplifyDemandedVectorElts(Src, DemandedElts, SrcUndef, KnownZero, TLO, + Depth + 1)) + return true; + // TODO convert SrcUndef to KnownUndef. 
+ break; + } + case X86ISD::CVTSI2P: + case X86ISD::CVTUI2P: { + SDValue Src = Op.getOperand(0); + MVT SrcVT = Src.getSimpleValueType(); + APInt SrcUndef, SrcZero; + APInt SrcElts = DemandedElts.zextOrTrunc(SrcVT.getVectorNumElements()); + if (SimplifyDemandedVectorElts(Src, SrcElts, SrcUndef, SrcZero, TLO, + Depth + 1)) + return true; + break; + } + case X86ISD::PACKSS: + case X86ISD::PACKUS: { + APInt DemandedLHS, DemandedRHS; + getPackDemandedElts(VT, DemandedElts, DemandedLHS, DemandedRHS); + + APInt SrcUndef, SrcZero; + if (SimplifyDemandedVectorElts(Op.getOperand(0), DemandedLHS, SrcUndef, + SrcZero, TLO, Depth + 1)) + return true; + if (SimplifyDemandedVectorElts(Op.getOperand(1), DemandedRHS, SrcUndef, + SrcZero, TLO, Depth + 1)) + return true; + break; + } + case X86ISD::VBROADCAST: { + SDValue Src = Op.getOperand(0); + MVT SrcVT = Src.getSimpleValueType(); + if (!SrcVT.isVector()) + return false; + // Don't bother broadcasting if we just need the 0'th element. + if (DemandedElts == 1) { + if(Src.getValueType() != VT) + Src = widenSubVector(VT.getSimpleVT(), Src, false, Subtarget, TLO.DAG, + SDLoc(Op)); + return TLO.CombineTo(Op, Src); + } + APInt SrcUndef, SrcZero; + APInt SrcElts = APInt::getOneBitSet(SrcVT.getVectorNumElements(), 0); + if (SimplifyDemandedVectorElts(Src, SrcElts, SrcUndef, SrcZero, TLO, + Depth + 1)) + return true; + break; + } + case X86ISD::PSHUFB: { + // TODO - simplify other variable shuffle masks. + SDValue Mask = Op.getOperand(1); + APInt MaskUndef, MaskZero; + if (SimplifyDemandedVectorElts(Mask, DemandedElts, MaskUndef, MaskZero, TLO, + Depth + 1)) + return true; + break; + } + } + + // Simplify target shuffles. + if (!isTargetShuffle(Opc) || !VT.isSimple()) + return false; + + // Get target shuffle mask. + bool IsUnary; + SmallVector<int, 64> OpMask; + SmallVector<SDValue, 2> OpInputs; + if (!getTargetShuffleMask(Op.getNode(), VT.getSimpleVT(), true, OpInputs, + OpMask, IsUnary)) + return false; + + // Shuffle inputs must be the same type as the result. + if (llvm::any_of(OpInputs, + [VT](SDValue V) { return VT != V.getValueType(); })) + return false; + + // Clear known elts that might have been set above. + KnownZero.clearAllBits(); + KnownUndef.clearAllBits(); + + // Check if shuffle mask can be simplified to undef/zero/identity. + int NumSrcs = OpInputs.size(); + for (int i = 0; i != NumElts; ++i) { + int &M = OpMask[i]; + if (!DemandedElts[i]) + M = SM_SentinelUndef; + else if (0 <= M && OpInputs[M / NumElts].isUndef()) + M = SM_SentinelUndef; + } + + if (isUndefInRange(OpMask, 0, NumElts)) { + KnownUndef.setAllBits(); + return TLO.CombineTo(Op, TLO.DAG.getUNDEF(VT)); + } + if (isUndefOrZeroInRange(OpMask, 0, NumElts)) { + KnownZero.setAllBits(); + return TLO.CombineTo( + Op, getZeroVector(VT.getSimpleVT(), Subtarget, TLO.DAG, SDLoc(Op))); + } + for (int Src = 0; Src != NumSrcs; ++Src) + if (isSequentialOrUndefInRange(OpMask, 0, NumElts, Src * NumElts)) + return TLO.CombineTo(Op, OpInputs[Src]); + + // Attempt to simplify inputs. + for (int Src = 0; Src != NumSrcs; ++Src) { + int Lo = Src * NumElts; + APInt SrcElts = APInt::getNullValue(NumElts); + for (int i = 0; i != NumElts; ++i) + if (DemandedElts[i]) { + int M = OpMask[i] - Lo; + if (0 <= M && M < NumElts) + SrcElts.setBit(M); + } + + APInt SrcUndef, SrcZero; + if (SimplifyDemandedVectorElts(OpInputs[Src], SrcElts, SrcUndef, SrcZero, + TLO, Depth + 1)) + return true; + } + + // Extract known zero/undef elements. + // TODO - Propagate input undef/zero elts. 
+ for (int i = 0; i != NumElts; ++i) { + if (OpMask[i] == SM_SentinelUndef) + KnownUndef.setBit(i); + if (OpMask[i] == SM_SentinelZero) + KnownZero.setBit(i); + } + + return false; +} + +bool X86TargetLowering::SimplifyDemandedBitsForTargetNode( + SDValue Op, const APInt &OriginalDemandedBits, + const APInt &OriginalDemandedElts, KnownBits &Known, TargetLoweringOpt &TLO, + unsigned Depth) const { + EVT VT = Op.getValueType(); + unsigned BitWidth = OriginalDemandedBits.getBitWidth(); + unsigned Opc = Op.getOpcode(); + switch(Opc) { + case X86ISD::PMULDQ: + case X86ISD::PMULUDQ: { + // PMULDQ/PMULUDQ only uses lower 32 bits from each vector element. + KnownBits KnownOp; + SDValue LHS = Op.getOperand(0); + SDValue RHS = Op.getOperand(1); + // FIXME: Can we bound this better? + APInt DemandedMask = APInt::getLowBitsSet(64, 32); + if (SimplifyDemandedBits(LHS, DemandedMask, KnownOp, TLO, Depth + 1)) + return true; + if (SimplifyDemandedBits(RHS, DemandedMask, KnownOp, TLO, Depth + 1)) + return true; + break; + } + case X86ISD::VSHLI: { + SDValue Op0 = Op.getOperand(0); + SDValue Op1 = Op.getOperand(1); + + if (auto *ShiftImm = dyn_cast<ConstantSDNode>(Op1)) { + if (ShiftImm->getAPIntValue().uge(BitWidth)) + break; + + unsigned ShAmt = ShiftImm->getZExtValue(); + APInt DemandedMask = OriginalDemandedBits.lshr(ShAmt); + + // If this is ((X >>u C1) << ShAmt), see if we can simplify this into a + // single shift. We can do this if the bottom bits (which are shifted + // out) are never demanded. + if (Op0.getOpcode() == X86ISD::VSRLI && + OriginalDemandedBits.countTrailingZeros() >= ShAmt) { + if (auto *Shift2Imm = dyn_cast<ConstantSDNode>(Op0.getOperand(1))) { + if (Shift2Imm->getAPIntValue().ult(BitWidth)) { + int Diff = ShAmt - Shift2Imm->getZExtValue(); + if (Diff == 0) + return TLO.CombineTo(Op, Op0.getOperand(0)); + + unsigned NewOpc = Diff < 0 ? X86ISD::VSRLI : X86ISD::VSHLI; + SDValue NewShift = TLO.DAG.getNode( + NewOpc, SDLoc(Op), VT, Op0.getOperand(0), + TLO.DAG.getConstant(std::abs(Diff), SDLoc(Op), MVT::i8)); + return TLO.CombineTo(Op, NewShift); + } + } + } + + if (SimplifyDemandedBits(Op0, DemandedMask, OriginalDemandedElts, Known, + TLO, Depth + 1)) + return true; + + assert(!Known.hasConflict() && "Bits known to be one AND zero?"); + Known.Zero <<= ShAmt; + Known.One <<= ShAmt; + + // Low bits known zero. + Known.Zero.setLowBits(ShAmt); + } + break; + } + case X86ISD::VSRLI: { + if (auto *ShiftImm = dyn_cast<ConstantSDNode>(Op.getOperand(1))) { + if (ShiftImm->getAPIntValue().uge(BitWidth)) + break; + + unsigned ShAmt = ShiftImm->getZExtValue(); + APInt DemandedMask = OriginalDemandedBits << ShAmt; + + if (SimplifyDemandedBits(Op.getOperand(0), DemandedMask, + OriginalDemandedElts, Known, TLO, Depth + 1)) + return true; + + assert(!Known.hasConflict() && "Bits known to be one AND zero?"); + Known.Zero.lshrInPlace(ShAmt); + Known.One.lshrInPlace(ShAmt); + + // High bits known zero. + Known.Zero.setHighBits(ShAmt); + } + break; + } + case X86ISD::VSRAI: { + SDValue Op0 = Op.getOperand(0); + SDValue Op1 = Op.getOperand(1); + + if (auto *ShiftImm = dyn_cast<ConstantSDNode>(Op1)) { + if (ShiftImm->getAPIntValue().uge(BitWidth)) + break; + + unsigned ShAmt = ShiftImm->getZExtValue(); + APInt DemandedMask = OriginalDemandedBits << ShAmt; + + // If we just want the sign bit then we don't need to shift it. 
+ if (OriginalDemandedBits.isSignMask()) + return TLO.CombineTo(Op, Op0); + + // fold (VSRAI (VSHLI X, C1), C1) --> X iff NumSignBits(X) > C1 + if (Op0.getOpcode() == X86ISD::VSHLI && Op1 == Op0.getOperand(1)) { + SDValue Op00 = Op0.getOperand(0); + unsigned NumSignBits = + TLO.DAG.ComputeNumSignBits(Op00, OriginalDemandedElts); + if (ShAmt < NumSignBits) + return TLO.CombineTo(Op, Op00); + } + + // If any of the demanded bits are produced by the sign extension, we also + // demand the input sign bit. + if (OriginalDemandedBits.countLeadingZeros() < ShAmt) + DemandedMask.setSignBit(); + + if (SimplifyDemandedBits(Op0, DemandedMask, OriginalDemandedElts, Known, + TLO, Depth + 1)) + return true; + + assert(!Known.hasConflict() && "Bits known to be one AND zero?"); + Known.Zero.lshrInPlace(ShAmt); + Known.One.lshrInPlace(ShAmt); + + // If the input sign bit is known to be zero, or if none of the top bits + // are demanded, turn this into an unsigned shift right. + if (Known.Zero[BitWidth - ShAmt - 1] || + OriginalDemandedBits.countLeadingZeros() >= ShAmt) + return TLO.CombineTo( + Op, TLO.DAG.getNode(X86ISD::VSRLI, SDLoc(Op), VT, Op0, Op1)); + + // High bits are known one. + if (Known.One[BitWidth - ShAmt - 1]) + Known.One.setHighBits(ShAmt); + } + break; + } + case X86ISD::MOVMSK: { + SDValue Src = Op.getOperand(0); + MVT SrcVT = Src.getSimpleValueType(); + unsigned SrcBits = SrcVT.getScalarSizeInBits(); + unsigned NumElts = SrcVT.getVectorNumElements(); + + // If we don't need the sign bits at all just return zero. + if (OriginalDemandedBits.countTrailingZeros() >= NumElts) + return TLO.CombineTo(Op, TLO.DAG.getConstant(0, SDLoc(Op), VT)); + + // Only demand the vector elements of the sign bits we need. + APInt KnownUndef, KnownZero; + APInt DemandedElts = OriginalDemandedBits.zextOrTrunc(NumElts); + if (SimplifyDemandedVectorElts(Src, DemandedElts, KnownUndef, KnownZero, + TLO, Depth + 1)) + return true; + + Known.Zero = KnownZero.zextOrSelf(BitWidth); + Known.Zero.setHighBits(BitWidth - NumElts); + + // MOVMSK only uses the MSB from each vector element. + KnownBits KnownSrc; + if (SimplifyDemandedBits(Src, APInt::getSignMask(SrcBits), DemandedElts, + KnownSrc, TLO, Depth + 1)) + return true; + + if (KnownSrc.One[SrcBits - 1]) + Known.One.setLowBits(NumElts); + else if (KnownSrc.Zero[SrcBits - 1]) + Known.Zero.setLowBits(NumElts); + return false; + } + } + + return TargetLowering::SimplifyDemandedBitsForTargetNode( + Op, OriginalDemandedBits, OriginalDemandedElts, Known, TLO, Depth); +} + /// Check if a vector extract from a target-specific shuffle of a load can be /// folded into a single element load. /// Similar handling for VECTOR_SHUFFLE is performed by DAGCombiner, but @@ -31344,9 +32816,13 @@ static SDValue XFormVExtractWithShuffleIntoLoad(SDNode *N, SelectionDAG &DAG, if (Idx == SM_SentinelUndef) return DAG.getUNDEF(EltVT); + // Bail if any mask element is SM_SentinelZero - getVectorShuffle below + // won't handle it. + if (llvm::any_of(ShuffleMask, [](int M) { return M == SM_SentinelZero; })) + return SDValue(); + assert(0 <= Idx && Idx < (int)(2 * NumElems) && "Shuffle index out of range"); - SDValue LdNode = (Idx < (int)NumElems) ? ShuffleOps[0] - : ShuffleOps[1]; + SDValue LdNode = (Idx < (int)NumElems) ? 
ShuffleOps[0] : ShuffleOps[1]; // If inputs to shuffle are the same for both ops, then allow 2 uses unsigned AllowedUses = @@ -31407,9 +32883,18 @@ static SDValue combineBitcastvxi1(SelectionDAG &DAG, SDValue BitCast, if (!VT.isScalarInteger() || !VecVT.isSimple()) return SDValue(); + // If the input is a truncate from v16i8 or v32i8 go ahead and use a + // movmskb even with avx512. This will be better than truncating to vXi1 and + // using a kmov. This can especially help KNL if the input is a v16i8/v32i8 + // vpcmpeqb/vpcmpgtb. + bool IsTruncated = N0.getOpcode() == ISD::TRUNCATE && N0.hasOneUse() && + (N0.getOperand(0).getValueType() == MVT::v16i8 || + N0.getOperand(0).getValueType() == MVT::v32i8 || + N0.getOperand(0).getValueType() == MVT::v64i8); + // With AVX512 vxi1 types are legal and we prefer using k-regs. // MOVMSK is supported in SSE2 or later. - if (Subtarget.hasAVX512() || !Subtarget.hasSSE2()) + if (!Subtarget.hasSSE2() || (Subtarget.hasAVX512() && !IsTruncated)) return SDValue(); // There are MOVMSK flavors for types v16i8, v32i8, v4f32, v8f32, v4f64 and @@ -31423,23 +32908,19 @@ static SDValue combineBitcastvxi1(SelectionDAG &DAG, SDValue BitCast, // For example, t0 := (v8i16 sext(v8i1 x)) needs to be shuffled as: // (v16i8 shuffle <0,2,4,6,8,10,12,14,u,u,...,u> (v16i8 bitcast t0), undef) MVT SExtVT; - MVT FPCastVT = MVT::INVALID_SIMPLE_VALUE_TYPE; switch (VecVT.getSimpleVT().SimpleTy) { default: return SDValue(); case MVT::v2i1: SExtVT = MVT::v2i64; - FPCastVT = MVT::v2f64; break; case MVT::v4i1: SExtVT = MVT::v4i32; - FPCastVT = MVT::v4f32; // For cases such as (i4 bitcast (v4i1 setcc v4i64 v1, v2)) // sign-extend to a 256-bit operation to avoid truncation. if (N0->getOpcode() == ISD::SETCC && Subtarget.hasAVX() && N0->getOperand(0).getValueType().is256BitVector()) { SExtVT = MVT::v4i64; - FPCastVT = MVT::v4f64; } break; case MVT::v8i1: @@ -31453,7 +32934,6 @@ static SDValue combineBitcastvxi1(SelectionDAG &DAG, SDValue BitCast, (N0->getOperand(0).getValueType().is256BitVector() || N0->getOperand(0).getValueType().is512BitVector())) { SExtVT = MVT::v8i32; - FPCastVT = MVT::v8f32; } break; case MVT::v16i1: @@ -31466,26 +32946,37 @@ static SDValue combineBitcastvxi1(SelectionDAG &DAG, SDValue BitCast, case MVT::v32i1: SExtVT = MVT::v32i8; break; + case MVT::v64i1: + // If we have AVX512F, but not AVX512BW and the input is truncated from + // v64i8 checked earlier. Then split the input and make two pmovmskbs. 
+ if (Subtarget.hasAVX512() && !Subtarget.hasBWI()) { + SExtVT = MVT::v64i8; + break; + } + return SDValue(); }; SDLoc DL(BitCast); - SDValue V = DAG.getSExtOrTrunc(N0, DL, SExtVT); + SDValue V = DAG.getNode(ISD::SIGN_EXTEND, DL, SExtVT, N0); - if (SExtVT == MVT::v16i8 || SExtVT == MVT::v32i8) { + if (SExtVT == MVT::v64i8) { + SDValue Lo, Hi; + std::tie(Lo, Hi) = DAG.SplitVector(V, DL); + Lo = DAG.getNode(X86ISD::MOVMSK, DL, MVT::i32, Lo); + Lo = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i64, Lo); + Hi = DAG.getNode(X86ISD::MOVMSK, DL, MVT::i32, Hi); + Hi = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i64, Hi); + Hi = DAG.getNode(ISD::SHL, DL, MVT::i64, Hi, + DAG.getConstant(32, DL, MVT::i8)); + V = DAG.getNode(ISD::OR, DL, MVT::i64, Lo, Hi); + } else if (SExtVT == MVT::v16i8 || SExtVT == MVT::v32i8) { V = getPMOVMSKB(DL, V, DAG, Subtarget); - return DAG.getZExtOrTrunc(V, DL, VT); + } else { + if (SExtVT == MVT::v8i16) + V = DAG.getNode(X86ISD::PACKSS, DL, MVT::v16i8, V, + DAG.getUNDEF(MVT::v8i16)); + V = DAG.getNode(X86ISD::MOVMSK, DL, MVT::i32, V); } - - if (SExtVT == MVT::v8i16) { - assert(16 == DAG.ComputeNumSignBits(V) && "Expected all/none bit vector"); - V = DAG.getNode(X86ISD::PACKSS, DL, MVT::v16i8, V, - DAG.getUNDEF(MVT::v8i16)); - } else - assert(SExtVT.getScalarType() != MVT::i16 && - "Vectors of i16 must be packed"); - if (FPCastVT != MVT::INVALID_SIMPLE_VALUE_TYPE) - V = DAG.getBitcast(FPCastVT, V); - V = DAG.getNode(X86ISD::MOVMSK, DL, MVT::i32, V); return DAG.getZExtOrTrunc(V, DL, VT); } @@ -31806,65 +33297,6 @@ static SDValue combineBitcast(SDNode *N, SelectionDAG &DAG, return SDValue(); } -// Match a binop + shuffle pyramid that represents a horizontal reduction over -// the elements of a vector. -// Returns the vector that is being reduced on, or SDValue() if a reduction -// was not matched. -static SDValue matchBinOpReduction(SDNode *Extract, unsigned &BinOp, - ArrayRef<ISD::NodeType> CandidateBinOps) { - // The pattern must end in an extract from index 0. - if ((Extract->getOpcode() != ISD::EXTRACT_VECTOR_ELT) || - !isNullConstant(Extract->getOperand(1))) - return SDValue(); - - SDValue Op = Extract->getOperand(0); - unsigned Stages = Log2_32(Op.getValueType().getVectorNumElements()); - - // Match against one of the candidate binary ops. - if (llvm::none_of(CandidateBinOps, [Op](ISD::NodeType BinOp) { - return Op.getOpcode() == unsigned(BinOp); - })) - return SDValue(); - - // At each stage, we're looking for something that looks like: - // %s = shufflevector <8 x i32> %op, <8 x i32> undef, - // <8 x i32> <i32 2, i32 3, i32 undef, i32 undef, - // i32 undef, i32 undef, i32 undef, i32 undef> - // %a = binop <8 x i32> %op, %s - // Where the mask changes according to the stage. E.g. for a 3-stage pyramid, - // we expect something like: - // <4,5,6,7,u,u,u,u> - // <2,3,u,u,u,u,u,u> - // <1,u,u,u,u,u,u,u> - unsigned CandidateBinOp = Op.getOpcode(); - for (unsigned i = 0; i < Stages; ++i) { - if (Op.getOpcode() != CandidateBinOp) - return SDValue(); - - ShuffleVectorSDNode *Shuffle = - dyn_cast<ShuffleVectorSDNode>(Op.getOperand(0).getNode()); - if (Shuffle) { - Op = Op.getOperand(1); - } else { - Shuffle = dyn_cast<ShuffleVectorSDNode>(Op.getOperand(1).getNode()); - Op = Op.getOperand(0); - } - - // The first operand of the shuffle should be the same as the other operand - // of the binop. - if (!Shuffle || Shuffle->getOperand(0) != Op) - return SDValue(); - - // Verify the shuffle has the expected (at this stage of the pyramid) mask. 
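For the v64i1 bitcast path added just above (AVX512F without AVX512BW), the vector is split and the two 32-bit PMOVMSKB results are stitched back into one 64-bit mask. The recombination step, as a plain scalar sketch:

#include <cassert>
#include <cstdint>

// Matches the zext/shl/or sequence built in the DAG: low mask in bits 0..31,
// high mask in bits 32..63.
static uint64_t combineMovmskHalves(uint32_t LoMask, uint32_t HiMask) {
  return uint64_t(LoMask) | (uint64_t(HiMask) << 32);
}

int main() {
  assert(combineMovmskHalves(0x0000FFFFu, 0x80000001u) == 0x800000010000FFFFull);
}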
- for (int Index = 0, MaskEnd = 1 << i; Index < MaskEnd; ++Index) - if (Shuffle->getMaskElt(Index) != MaskEnd + Index) - return SDValue(); - } - - BinOp = CandidateBinOp; - return Op; -} - // Given a select, detect the following pattern: // 1: %2 = zext <N x i8> %0 to <N x i32> // 2: %3 = zext <N x i8> %1 to <N x i32> @@ -31979,8 +33411,8 @@ static SDValue combineHorizontalMinMaxResult(SDNode *Extract, SelectionDAG &DAG, return SDValue(); // Check for SMAX/SMIN/UMAX/UMIN horizontal reduction patterns. - unsigned BinOp; - SDValue Src = matchBinOpReduction( + ISD::NodeType BinOp; + SDValue Src = DAG.matchBinOpReduction( Extract, BinOp, {ISD::SMAX, ISD::SMIN, ISD::UMAX, ISD::UMIN}); if (!Src) return SDValue(); @@ -32027,7 +33459,7 @@ static SDValue combineHorizontalMinMaxResult(SDNode *Extract, SelectionDAG &DAG, // ready for the PHMINPOS. if (ExtractVT == MVT::i8) { SDValue Upper = DAG.getVectorShuffle( - SrcVT, DL, MinPos, getZeroVector(MVT::v16i8, Subtarget, DAG, DL), + SrcVT, DL, MinPos, DAG.getConstant(0, DL, MVT::v16i8), {1, 16, 3, 16, 5, 16, 7, 16, 9, 16, 11, 16, 13, 16, 15, 16}); MinPos = DAG.getNode(ISD::UMIN, DL, SrcVT, MinPos, Upper); } @@ -32059,8 +33491,8 @@ static SDValue combineHorizontalPredicateResult(SDNode *Extract, return SDValue(); // Check for OR(any_of) and AND(all_of) horizontal reduction patterns. - unsigned BinOp = 0; - SDValue Match = matchBinOpReduction(Extract, BinOp, {ISD::OR, ISD::AND}); + ISD::NodeType BinOp; + SDValue Match = DAG.matchBinOpReduction(Extract, BinOp, {ISD::OR, ISD::AND}); if (!Match) return SDValue(); @@ -32142,8 +33574,8 @@ static SDValue combineBasicSADPattern(SDNode *Extract, SelectionDAG &DAG, return SDValue(); // Match shuffle + add pyramid. - unsigned BinOp = 0; - SDValue Root = matchBinOpReduction(Extract, BinOp, {ISD::ADD}); + ISD::NodeType BinOp; + SDValue Root = DAG.matchBinOpReduction(Extract, BinOp, {ISD::ADD}); // The operand is expected to be zero extended from i8 // (verified in detectZextAbsDiff). @@ -32238,6 +33670,15 @@ static SDValue combineExtractWithShuffle(SDNode *N, SelectionDAG &DAG, scaleShuffleMask<int>(Scale, Mask, ScaledMask); Mask = std::move(ScaledMask); } else if ((Mask.size() % NumSrcElts) == 0) { + // Simplify Mask based on demanded element. + int ExtractIdx = (int)N->getConstantOperandVal(1); + int Scale = Mask.size() / NumSrcElts; + int Lo = Scale * ExtractIdx; + int Hi = Scale * (ExtractIdx + 1); + for (int i = 0, e = (int)Mask.size(); i != e; ++i) + if (i < Lo || Hi <= i) + Mask[i] = SM_SentinelUndef; + SmallVector<int, 16> WidenedMask; while (Mask.size() > NumSrcElts && canWidenShuffleElements(Mask, WidenedMask)) @@ -32532,11 +33973,14 @@ static SDValue combineSelectOfTwoConstants(SDNode *N, SelectionDAG &DAG) { /// If this is a *dynamic* select (non-constant condition) and we can match /// this node with one of the variable blend instructions, restructure the /// condition so that blends can use the high (sign) bit of each element. -static SDValue combineVSelectToShrunkBlend(SDNode *N, SelectionDAG &DAG, +/// This function will also call SimplfiyDemandedBits on already created +/// BLENDV to perform additional simplifications. 
+static SDValue combineVSelectToBLENDV(SDNode *N, SelectionDAG &DAG, TargetLowering::DAGCombinerInfo &DCI, const X86Subtarget &Subtarget) { SDValue Cond = N->getOperand(0); - if (N->getOpcode() != ISD::VSELECT || + if ((N->getOpcode() != ISD::VSELECT && + N->getOpcode() != X86ISD::BLENDV) || ISD::isBuildVectorOfConstantSDNodes(Cond.getNode())) return SDValue(); @@ -32578,7 +34022,9 @@ static SDValue combineVSelectToShrunkBlend(SDNode *N, SelectionDAG &DAG, // TODO: Add other opcodes eventually lowered into BLEND. for (SDNode::use_iterator UI = Cond->use_begin(), UE = Cond->use_end(); UI != UE; ++UI) - if (UI->getOpcode() != ISD::VSELECT || UI.getOperandNo() != 0) + if ((UI->getOpcode() != ISD::VSELECT && + UI->getOpcode() != X86ISD::BLENDV) || + UI.getOperandNo() != 0) return SDValue(); APInt DemandedMask(APInt::getSignMask(BitWidth)); @@ -32594,9 +34040,13 @@ static SDValue combineVSelectToShrunkBlend(SDNode *N, SelectionDAG &DAG, // optimizations as we messed with the actual expectation for the vector // boolean values. for (SDNode *U : Cond->uses()) { - SDValue SB = DAG.getNode(X86ISD::SHRUNKBLEND, SDLoc(U), U->getValueType(0), + if (U->getOpcode() == X86ISD::BLENDV) + continue; + + SDValue SB = DAG.getNode(X86ISD::BLENDV, SDLoc(U), U->getValueType(0), Cond, U->getOperand(1), U->getOperand(2)); DAG.ReplaceAllUsesOfValueWith(SDValue(U, 0), SB); + DCI.AddToWorklist(U); } DCI.CommitTargetLoweringOpt(TLO); return SDValue(N, 0); @@ -32608,9 +34058,14 @@ static SDValue combineSelect(SDNode *N, SelectionDAG &DAG, const X86Subtarget &Subtarget) { SDLoc DL(N); SDValue Cond = N->getOperand(0); - // Get the LHS/RHS of the select. SDValue LHS = N->getOperand(1); SDValue RHS = N->getOperand(2); + + // Try simplification again because we use this function to optimize + // BLENDV nodes that are not handled by the generic combiner. + if (SDValue V = DAG.simplifySelect(Cond, LHS, RHS)) + return V; + EVT VT = LHS.getValueType(); EVT CondVT = Cond.getValueType(); const TargetLowering &TLI = DAG.getTargetLoweringInfo(); @@ -32618,18 +34073,9 @@ static SDValue combineSelect(SDNode *N, SelectionDAG &DAG, // Convert vselects with constant condition into shuffles. if (ISD::isBuildVectorOfConstantSDNodes(Cond.getNode()) && DCI.isBeforeLegalizeOps()) { - SmallVector<int, 64> Mask(VT.getVectorNumElements(), -1); - for (int i = 0, Size = Mask.size(); i != Size; ++i) { - SDValue CondElt = Cond->getOperand(i); - Mask[i] = i; - // Arbitrarily choose from the 2nd operand if the select condition element - // is undef. - // TODO: Can we do better by matching patterns such as even/odd? - if (CondElt.isUndef() || isNullConstant(CondElt)) - Mask[i] += Size; - } - - return DAG.getVectorShuffle(VT, DL, LHS, RHS, Mask); + SmallVector<int, 64> Mask; + if (createShuffleMaskFromVSELECT(Mask, Cond)) + return DAG.getVectorShuffle(VT, DL, LHS, RHS, Mask); } // If we have SSE[12] support, try to form min/max nodes. SSE min/max @@ -32814,7 +34260,8 @@ static SDValue combineSelect(SDNode *N, SelectionDAG &DAG, // Since SKX these selects have a proper lowering. 
if (Subtarget.hasAVX512() && !Subtarget.hasBWI() && CondVT.isVector() && CondVT.getVectorElementType() == MVT::i1 && - VT.getVectorNumElements() > 4 && + (ExperimentalVectorWideningLegalization || + VT.getVectorNumElements() > 4) && (VT.getVectorElementType() == MVT::i8 || VT.getVectorElementType() == MVT::i16)) { Cond = DAG.getNode(ISD::SIGN_EXTEND, DL, VT, Cond); @@ -32855,15 +34302,13 @@ static SDValue combineSelect(SDNode *N, SelectionDAG &DAG, } } - // Early exit check - if (!TLI.isTypeLegal(VT)) - return SDValue(); - // Match VSELECTs into subs with unsigned saturation. if (N->getOpcode() == ISD::VSELECT && Cond.getOpcode() == ISD::SETCC && - // psubus is available in SSE2 and AVX2 for i8 and i16 vectors. - ((Subtarget.hasSSE2() && (VT == MVT::v16i8 || VT == MVT::v8i16)) || - (Subtarget.hasAVX() && (VT == MVT::v32i8 || VT == MVT::v16i16)))) { + // psubus is available in SSE2 for i8 and i16 vectors. + Subtarget.hasSSE2() && VT.getVectorNumElements() >= 2 && + isPowerOf2_32(VT.getVectorNumElements()) && + (VT.getVectorElementType() == MVT::i8 || + VT.getVectorElementType() == MVT::i16)) { ISD::CondCode CC = cast<CondCodeSDNode>(Cond.getOperand(2))->get(); // Check if one of the arms of the VSELECT is a zero vector. If it's on the @@ -32877,37 +34322,31 @@ static SDValue combineSelect(SDNode *N, SelectionDAG &DAG, } if (Other.getNode() && Other->getNumOperands() == 2 && - DAG.isEqualTo(Other->getOperand(0), Cond.getOperand(0))) { + Other->getOperand(0) == Cond.getOperand(0)) { SDValue OpLHS = Other->getOperand(0), OpRHS = Other->getOperand(1); SDValue CondRHS = Cond->getOperand(1); - auto SUBUSBuilder = [](SelectionDAG &DAG, const SDLoc &DL, - ArrayRef<SDValue> Ops) { - return DAG.getNode(X86ISD::SUBUS, DL, Ops[0].getValueType(), Ops); - }; - // Look for a general sub with unsigned saturation first. // x >= y ? x-y : 0 --> subus x, y // x > y ? x-y : 0 --> subus x, y if ((CC == ISD::SETUGE || CC == ISD::SETUGT) && - Other->getOpcode() == ISD::SUB && DAG.isEqualTo(OpRHS, CondRHS)) - return SplitOpsAndApply(DAG, Subtarget, DL, VT, { OpLHS, OpRHS }, - SUBUSBuilder); + Other->getOpcode() == ISD::SUB && OpRHS == CondRHS) + return DAG.getNode(ISD::USUBSAT, DL, VT, OpLHS, OpRHS); - if (auto *OpRHSBV = dyn_cast<BuildVectorSDNode>(OpRHS)) + if (auto *OpRHSBV = dyn_cast<BuildVectorSDNode>(OpRHS)) { if (isa<BuildVectorSDNode>(CondRHS)) { // If the RHS is a constant we have to reverse the const // canonicalization. // x > C-1 ? x+-C : 0 --> subus x, C - auto MatchSUBUS = [](ConstantSDNode *Op, ConstantSDNode *Cond) { + // TODO: Handle build_vectors with undef elements. + auto MatchUSUBSAT = [](ConstantSDNode *Op, ConstantSDNode *Cond) { return Cond->getAPIntValue() == (-Op->getAPIntValue() - 1); }; if (CC == ISD::SETUGT && Other->getOpcode() == ISD::ADD && - ISD::matchBinaryPredicate(OpRHS, CondRHS, MatchSUBUS)) { + ISD::matchBinaryPredicate(OpRHS, CondRHS, MatchUSUBSAT)) { OpRHS = DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT), OpRHS); - return SplitOpsAndApply(DAG, Subtarget, DL, VT, { OpLHS, OpRHS }, - SUBUSBuilder); + return DAG.getNode(ISD::USUBSAT, DL, VT, OpLHS, OpRHS); } // Another special case: If C was a sign bit, the sub has been @@ -32915,24 +34354,82 @@ static SDValue combineSelect(SDNode *N, SelectionDAG &DAG, // FIXME: Would it be better to use computeKnownBits to determine // whether it's safe to decanonicalize the xor? // x s< 0 ? 
x^C : 0 --> subus x, C - if (auto *OpRHSConst = OpRHSBV->getConstantSplatNode()) + if (auto *OpRHSConst = OpRHSBV->getConstantSplatNode()) { if (CC == ISD::SETLT && Other.getOpcode() == ISD::XOR && ISD::isBuildVectorAllZeros(CondRHS.getNode()) && OpRHSConst->getAPIntValue().isSignMask()) { - OpRHS = DAG.getConstant(OpRHSConst->getAPIntValue(), DL, VT); // Note that we have to rebuild the RHS constant here to ensure we // don't rely on particular values of undef lanes. - return SplitOpsAndApply(DAG, Subtarget, DL, VT, { OpLHS, OpRHS }, - SUBUSBuilder); + OpRHS = DAG.getConstant(OpRHSConst->getAPIntValue(), DL, VT); + return DAG.getNode(ISD::USUBSAT, DL, VT, OpLHS, OpRHS); } + } } + } + } + } + + // Match VSELECTs into add with unsigned saturation. + if (N->getOpcode() == ISD::VSELECT && Cond.getOpcode() == ISD::SETCC && + // paddus is available in SSE2 for i8 and i16 vectors. + Subtarget.hasSSE2() && VT.getVectorNumElements() >= 2 && + isPowerOf2_32(VT.getVectorNumElements()) && + (VT.getVectorElementType() == MVT::i8 || + VT.getVectorElementType() == MVT::i16)) { + ISD::CondCode CC = cast<CondCodeSDNode>(Cond.getOperand(2))->get(); + + SDValue CondLHS = Cond->getOperand(0); + SDValue CondRHS = Cond->getOperand(1); + + // Check if one of the arms of the VSELECT is vector with all bits set. + // If it's on the left side invert the predicate to simplify logic below. + SDValue Other; + if (ISD::isBuildVectorAllOnes(LHS.getNode())) { + Other = RHS; + CC = ISD::getSetCCInverse(CC, true); + } else if (ISD::isBuildVectorAllOnes(RHS.getNode())) { + Other = LHS; + } + + if (Other.getNode() && Other.getOpcode() == ISD::ADD) { + SDValue OpLHS = Other.getOperand(0), OpRHS = Other.getOperand(1); + + // Canonicalize condition operands. + if (CC == ISD::SETUGE) { + std::swap(CondLHS, CondRHS); + CC = ISD::SETULE; + } + + // We can test against either of the addition operands. + // x <= x+y ? x+y : ~0 --> addus x, y + // x+y >= x ? x+y : ~0 --> addus x, y + if (CC == ISD::SETULE && Other == CondRHS && + (OpLHS == CondLHS || OpRHS == CondLHS)) + return DAG.getNode(ISD::UADDSAT, DL, VT, OpLHS, OpRHS); + + if (isa<BuildVectorSDNode>(OpRHS) && isa<BuildVectorSDNode>(CondRHS) && + CondLHS == OpLHS) { + // If the RHS is a constant we have to reverse the const + // canonicalization. + // x > ~C ? x+C : ~0 --> addus x, C + auto MatchUADDSAT = [](ConstantSDNode *Op, ConstantSDNode *Cond) { + return Cond->getAPIntValue() == ~Op->getAPIntValue(); + }; + if (CC == ISD::SETULE && + ISD::matchBinaryPredicate(OpRHS, CondRHS, MatchUADDSAT)) + return DAG.getNode(ISD::UADDSAT, DL, VT, OpLHS, OpRHS); + } } } + // Early exit check + if (!TLI.isTypeLegal(VT)) + return SDValue(); + if (SDValue V = combineVSelectWithAllOnesOrZeros(N, DAG, DCI, Subtarget)) return V; - if (SDValue V = combineVSelectToShrunkBlend(N, DAG, DCI, Subtarget)) + if (SDValue V = combineVSelectToBLENDV(N, DAG, DCI, Subtarget)) return V; // Custom action for SELECT MMX @@ -33014,16 +34511,7 @@ static SDValue combineSetCCAtomicArith(SDValue Cmp, X86::CondCode &CC, /*Chain*/ CmpLHS.getOperand(0), /*LHS*/ CmpLHS.getOperand(1), /*RHS*/ DAG.getConstant(-Addend, SDLoc(CmpRHS), CmpRHS.getValueType()), AN->getMemOperand()); - // If the comparision uses the CF flag we can't use INC/DEC instructions. 
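The UADDSAT matching added above relies on a basic overflow fact: an unsigned sum x+y wraps below x exactly when the true result does not fit, so the select x+y >= x ? x+y : ~0 is an unsigned saturating add. A scalar sanity check of that identity for i16 (plain C++, not part of the patch):

#include <cassert>
#include <cstdint>

static uint16_t uaddsat16(uint16_t X, uint16_t Y) {
  uint16_t Sum = uint16_t(X + Y);            // wraps modulo 2^16
  return Sum >= X ? Sum : uint16_t(0xFFFF);  // wrapped => saturate to all-ones
}

int main() {
  assert(uaddsat16(65000, 1000) == 0xFFFF);  // 66000 does not fit, saturates
  assert(uaddsat16(100, 200) == 300);        // in range, plain addition
}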
- bool NeedCF = false; - switch (CC) { - default: break; - case X86::COND_A: case X86::COND_AE: - case X86::COND_B: case X86::COND_BE: - NeedCF = true; - break; - } - auto LockOp = lowerAtomicArithWithLOCK(AtomicSub, DAG, Subtarget, !NeedCF); + auto LockOp = lowerAtomicArithWithLOCK(AtomicSub, DAG, Subtarget); DAG.ReplaceAllUsesOfValueWith(CmpLHS.getValue(0), DAG.getUNDEF(CmpLHS.getValueType())); DAG.ReplaceAllUsesOfValueWith(CmpLHS.getValue(1), LockOp.getValue(1)); @@ -33453,10 +34941,13 @@ static SDValue combineCMov(SDNode *N, SelectionDAG &DAG, SDValue Add = TrueOp; SDValue Const = FalseOp; // Canonicalize the condition code for easier matching and output. - if (CC == X86::COND_E) { + if (CC == X86::COND_E) std::swap(Add, Const); - CC = X86::COND_NE; - } + + // We might have replaced the constant in the cmov with the LHS of the + // compare. If so change it to the RHS of the compare. + if (Const == Cond.getOperand(0)) + Const = Cond.getOperand(1); // Ok, now make sure that Add is (add (cttz X), C2) and Const is a constant. if (isa<ConstantSDNode>(Const) && Add.getOpcode() == ISD::ADD && @@ -33468,7 +34959,8 @@ static SDValue combineCMov(SDNode *N, SelectionDAG &DAG, // This should constant fold. SDValue Diff = DAG.getNode(ISD::SUB, DL, VT, Const, Add.getOperand(1)); SDValue CMov = DAG.getNode(X86ISD::CMOV, DL, VT, Diff, Add.getOperand(0), - DAG.getConstant(CC, DL, MVT::i8), Cond); + DAG.getConstant(X86::COND_NE, DL, MVT::i8), + Cond); return DAG.getNode(ISD::ADD, DL, VT, CMov, Add.getOperand(1)); } } @@ -33490,40 +34982,8 @@ static bool canReduceVMulWidth(SDNode *N, SelectionDAG &DAG, ShrinkMode &Mode) { for (unsigned i = 0; i < 2; i++) { SDValue Opd = N->getOperand(i); - // DAG.ComputeNumSignBits return 1 for ISD::ANY_EXTEND, so we need to - // compute signbits for it separately. - if (Opd.getOpcode() == ISD::ANY_EXTEND) { - // For anyextend, it is safe to assume an appropriate number of leading - // sign/zero bits. - if (Opd.getOperand(0).getValueType().getVectorElementType() == MVT::i8) - SignBits[i] = 25; - else if (Opd.getOperand(0).getValueType().getVectorElementType() == - MVT::i16) - SignBits[i] = 17; - else - return false; - IsPositive[i] = true; - } else if (Opd.getOpcode() == ISD::BUILD_VECTOR) { - // All the operands of BUILD_VECTOR need to be int constant. - // Find the smallest value range which all the operands belong to. - SignBits[i] = 32; - IsPositive[i] = true; - for (const SDValue &SubOp : Opd.getNode()->op_values()) { - if (SubOp.isUndef()) - continue; - auto *CN = dyn_cast<ConstantSDNode>(SubOp); - if (!CN) - return false; - APInt IntVal = CN->getAPIntValue(); - if (IntVal.isNegative()) - IsPositive[i] = false; - SignBits[i] = std::min(SignBits[i], IntVal.getNumSignBits()); - } - } else { - SignBits[i] = DAG.ComputeNumSignBits(Opd); - if (Opd.getOpcode() == ISD::ZERO_EXTEND) - IsPositive[i] = true; - } + SignBits[i] = DAG.ComputeNumSignBits(Opd); + IsPositive[i] = DAG.SignBitIsZero(Opd); } bool AllPositive = IsPositive[0] && IsPositive[1]; @@ -33608,90 +35068,90 @@ static SDValue reduceVMULWidth(SDNode *N, SelectionDAG &DAG, SDValue NewN0 = DAG.getNode(ISD::TRUNCATE, DL, ReducedVT, N0); SDValue NewN1 = DAG.getNode(ISD::TRUNCATE, DL, ReducedVT, N1); - if (NumElts >= OpsVT.getVectorNumElements()) { + if (ExperimentalVectorWideningLegalization || + NumElts >= OpsVT.getVectorNumElements()) { // Generate the lower part of mul: pmullw. For MULU8/MULS8, only the // lower part is needed. 
SDValue MulLo = DAG.getNode(ISD::MUL, DL, ReducedVT, NewN0, NewN1); - if (Mode == MULU8 || Mode == MULS8) { + if (Mode == MULU8 || Mode == MULS8) return DAG.getNode((Mode == MULU8) ? ISD::ZERO_EXTEND : ISD::SIGN_EXTEND, DL, VT, MulLo); - } else { - MVT ResVT = MVT::getVectorVT(MVT::i32, NumElts / 2); - // Generate the higher part of mul: pmulhw/pmulhuw. For MULU16/MULS16, - // the higher part is also needed. - SDValue MulHi = DAG.getNode(Mode == MULS16 ? ISD::MULHS : ISD::MULHU, DL, - ReducedVT, NewN0, NewN1); - - // Repack the lower part and higher part result of mul into a wider - // result. - // Generate shuffle functioning as punpcklwd. - SmallVector<int, 16> ShuffleMask(NumElts); - for (unsigned i = 0, e = NumElts / 2; i < e; i++) { - ShuffleMask[2 * i] = i; - ShuffleMask[2 * i + 1] = i + NumElts; - } - SDValue ResLo = - DAG.getVectorShuffle(ReducedVT, DL, MulLo, MulHi, ShuffleMask); - ResLo = DAG.getBitcast(ResVT, ResLo); - // Generate shuffle functioning as punpckhwd. - for (unsigned i = 0, e = NumElts / 2; i < e; i++) { - ShuffleMask[2 * i] = i + NumElts / 2; - ShuffleMask[2 * i + 1] = i + NumElts * 3 / 2; - } - SDValue ResHi = - DAG.getVectorShuffle(ReducedVT, DL, MulLo, MulHi, ShuffleMask); - ResHi = DAG.getBitcast(ResVT, ResHi); - return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, ResLo, ResHi); - } - } else { - // When VT.getVectorNumElements() < OpsVT.getVectorNumElements(), we want - // to legalize the mul explicitly because implicit legalization for type - // <4 x i16> to <4 x i32> sometimes involves unnecessary unpack - // instructions which will not exist when we explicitly legalize it by - // extending <4 x i16> to <8 x i16> (concatenating the <4 x i16> val with - // <4 x i16> undef). - // - // Legalize the operands of mul. - // FIXME: We may be able to handle non-concatenated vectors by insertion. - unsigned ReducedSizeInBits = ReducedVT.getSizeInBits(); - if ((RegSize % ReducedSizeInBits) != 0) - return SDValue(); - SmallVector<SDValue, 16> Ops(RegSize / ReducedSizeInBits, - DAG.getUNDEF(ReducedVT)); - Ops[0] = NewN0; - NewN0 = DAG.getNode(ISD::CONCAT_VECTORS, DL, OpsVT, Ops); - Ops[0] = NewN1; - NewN1 = DAG.getNode(ISD::CONCAT_VECTORS, DL, OpsVT, Ops); - - if (Mode == MULU8 || Mode == MULS8) { - // Generate lower part of mul: pmullw. For MULU8/MULS8, only the lower - // part is needed. - SDValue Mul = DAG.getNode(ISD::MUL, DL, OpsVT, NewN0, NewN1); - - // convert the type of mul result to VT. - MVT ResVT = MVT::getVectorVT(MVT::i32, RegSize / 32); - SDValue Res = DAG.getNode(Mode == MULU8 ? ISD::ZERO_EXTEND_VECTOR_INREG - : ISD::SIGN_EXTEND_VECTOR_INREG, - DL, ResVT, Mul); - return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, Res, - DAG.getIntPtrConstant(0, DL)); - } else { - // Generate the lower and higher part of mul: pmulhw/pmulhuw. For - // MULU16/MULS16, both parts are needed. - SDValue MulLo = DAG.getNode(ISD::MUL, DL, OpsVT, NewN0, NewN1); - SDValue MulHi = DAG.getNode(Mode == MULS16 ? ISD::MULHS : ISD::MULHU, DL, - OpsVT, NewN0, NewN1); - - // Repack the lower part and higher part result of mul into a wider - // result. Make sure the type of mul result is VT. - MVT ResVT = MVT::getVectorVT(MVT::i32, RegSize / 32); - SDValue Res = getUnpackl(DAG, DL, OpsVT, MulLo, MulHi); - Res = DAG.getBitcast(ResVT, Res); - return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, Res, - DAG.getIntPtrConstant(0, DL)); - } + MVT ResVT = MVT::getVectorVT(MVT::i32, NumElts / 2); + // Generate the higher part of mul: pmulhw/pmulhuw. For MULU16/MULS16, + // the higher part is also needed. 
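For the unsigned MULU16 path, the repack of the pmullw and pmulhuw results into i32 lanes works because those two instructions produce exactly the low and high halves of the full 16x16 product. A standalone sketch of the per-element identity (plain C++, separate from the patch):

#include <cassert>
#include <cstdint>

int main() {
  const uint32_t vals[] = {0, 1, 37, 0x7FFF, 0x8000, 0xFFFF};
  for (uint32_t a : vals)
    for (uint32_t b : vals) {
      uint16_t lo = (uint16_t)(a * b);          // what a pmullw lane yields
      uint16_t hi = (uint16_t)((a * b) >> 16);  // what a pmulhuw lane yields
      // punpcklwd/punpckhwd interleave lo and hi back into the i32 product.
      assert((((uint32_t)hi << 16) | lo) == a * b);
    }
  return 0;
}

The signed MULS16 variant is the same picture with pmulhw providing the high half.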
+ SDValue MulHi = DAG.getNode(Mode == MULS16 ? ISD::MULHS : ISD::MULHU, DL, + ReducedVT, NewN0, NewN1); + + // Repack the lower part and higher part result of mul into a wider + // result. + // Generate shuffle functioning as punpcklwd. + SmallVector<int, 16> ShuffleMask(NumElts); + for (unsigned i = 0, e = NumElts / 2; i < e; i++) { + ShuffleMask[2 * i] = i; + ShuffleMask[2 * i + 1] = i + NumElts; + } + SDValue ResLo = + DAG.getVectorShuffle(ReducedVT, DL, MulLo, MulHi, ShuffleMask); + ResLo = DAG.getBitcast(ResVT, ResLo); + // Generate shuffle functioning as punpckhwd. + for (unsigned i = 0, e = NumElts / 2; i < e; i++) { + ShuffleMask[2 * i] = i + NumElts / 2; + ShuffleMask[2 * i + 1] = i + NumElts * 3 / 2; + } + SDValue ResHi = + DAG.getVectorShuffle(ReducedVT, DL, MulLo, MulHi, ShuffleMask); + ResHi = DAG.getBitcast(ResVT, ResHi); + return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, ResLo, ResHi); + } + + // When VT.getVectorNumElements() < OpsVT.getVectorNumElements(), we want + // to legalize the mul explicitly because implicit legalization for type + // <4 x i16> to <4 x i32> sometimes involves unnecessary unpack + // instructions which will not exist when we explicitly legalize it by + // extending <4 x i16> to <8 x i16> (concatenating the <4 x i16> val with + // <4 x i16> undef). + // + // Legalize the operands of mul. + // FIXME: We may be able to handle non-concatenated vectors by insertion. + unsigned ReducedSizeInBits = ReducedVT.getSizeInBits(); + if ((RegSize % ReducedSizeInBits) != 0) + return SDValue(); + + SmallVector<SDValue, 16> Ops(RegSize / ReducedSizeInBits, + DAG.getUNDEF(ReducedVT)); + Ops[0] = NewN0; + NewN0 = DAG.getNode(ISD::CONCAT_VECTORS, DL, OpsVT, Ops); + Ops[0] = NewN1; + NewN1 = DAG.getNode(ISD::CONCAT_VECTORS, DL, OpsVT, Ops); + + if (Mode == MULU8 || Mode == MULS8) { + // Generate lower part of mul: pmullw. For MULU8/MULS8, only the lower + // part is needed. + SDValue Mul = DAG.getNode(ISD::MUL, DL, OpsVT, NewN0, NewN1); + + // convert the type of mul result to VT. + MVT ResVT = MVT::getVectorVT(MVT::i32, RegSize / 32); + SDValue Res = DAG.getNode(Mode == MULU8 ? ISD::ZERO_EXTEND_VECTOR_INREG + : ISD::SIGN_EXTEND_VECTOR_INREG, + DL, ResVT, Mul); + return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, Res, + DAG.getIntPtrConstant(0, DL)); } + + // Generate the lower and higher part of mul: pmulhw/pmulhuw. For + // MULU16/MULS16, both parts are needed. + SDValue MulLo = DAG.getNode(ISD::MUL, DL, OpsVT, NewN0, NewN1); + SDValue MulHi = DAG.getNode(Mode == MULS16 ? ISD::MULHS : ISD::MULHU, DL, + OpsVT, NewN0, NewN1); + + // Repack the lower part and higher part result of mul into a wider + // result. Make sure the type of mul result is VT. + MVT ResVT = MVT::getVectorVT(MVT::i32, RegSize / 32); + SDValue Res = getUnpackl(DAG, DL, OpsVT, MulLo, MulHi); + Res = DAG.getBitcast(ResVT, Res); + return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, Res, + DAG.getIntPtrConstant(0, DL)); } static SDValue combineMulSpecial(uint64_t MulAmt, SDNode *N, SelectionDAG &DAG, @@ -33781,13 +35241,13 @@ static SDValue combineMulSpecial(uint64_t MulAmt, SDNode *N, SelectionDAG &DAG, } // If the upper 17 bits of each element are zero then we can use PMADDWD, -// which is always at least as quick as PMULLD, expect on KNL. +// which is always at least as quick as PMULLD, except on KNL. 
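The PMADDWD trick named in the comment above works because, once the upper 17 bits of each i32 element are known zero, the high 16-bit half of every element is zero and the low half is a non-negative signed value, so the signed multiply-add collapses to a plain 32-bit multiply. A standalone model of one lane (plain C++, separate from the patch):

#include <cassert>
#include <cstdint>

// One 32-bit lane of PMADDWD: dot product of the two signed 16-bit halves.
static int32_t pmaddwd_lane(uint32_t a, uint32_t b) {
  int32_t alo = (int16_t)(a & 0xFFFF), ahi = (int16_t)(a >> 16);
  int32_t blo = (int16_t)(b & 0xFFFF), bhi = (int16_t)(b >> 16);
  return alo * blo + ahi * bhi;
}

int main() {
  // Values below 2^15 have 17 known-zero high bits, so the high halves
  // contribute nothing and the lane equals the ordinary product.
  const uint32_t vals[] = {0, 1, 1000, 0x7FFF};
  for (uint32_t a : vals)
    for (uint32_t b : vals)
      assert(pmaddwd_lane(a, b) == (int32_t)(a * b));
  return 0;
}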
static SDValue combineMulToPMADDWD(SDNode *N, SelectionDAG &DAG, const X86Subtarget &Subtarget) { if (!Subtarget.hasSSE2()) return SDValue(); - if (Subtarget.getProcFamily() == X86Subtarget::IntelKNL) + if (Subtarget.isPMADDWDSlow()) return SDValue(); EVT VT = N->getValueType(0); @@ -33797,12 +35257,24 @@ static SDValue combineMulToPMADDWD(SDNode *N, SelectionDAG &DAG, return SDValue(); // Make sure the vXi16 type is legal. This covers the AVX512 without BWI case. + // Also allow v2i32 if it will be widened. MVT WVT = MVT::getVectorVT(MVT::i16, 2 * VT.getVectorNumElements()); - if (!DAG.getTargetLoweringInfo().isTypeLegal(WVT)) + if (!((ExperimentalVectorWideningLegalization && VT == MVT::v2i32) || + DAG.getTargetLoweringInfo().isTypeLegal(WVT))) return SDValue(); SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); + + // If we are zero extending two steps without SSE4.1, its better to reduce + // the vmul width instead. + if (!Subtarget.hasSSE41() && + (N0.getOpcode() == ISD::ZERO_EXTEND && + N0.getOperand(0).getScalarValueSizeInBits() <= 8) && + (N1.getOpcode() == ISD::ZERO_EXTEND && + N1.getOperand(0).getScalarValueSizeInBits() <= 8)) + return SDValue(); + APInt Mask17 = APInt::getHighBitsSet(32, 17); if (!DAG.MaskedValueIsZero(N1, Mask17) || !DAG.MaskedValueIsZero(N0, Mask17)) @@ -33828,7 +35300,8 @@ static SDValue combineMulToPMULDQ(SDNode *N, SelectionDAG &DAG, // Only support vXi64 vectors. if (!VT.isVector() || VT.getVectorElementType() != MVT::i64 || - !DAG.getTargetLoweringInfo().isTypeLegal(VT)) + VT.getVectorNumElements() < 2 || + !isPowerOf2_32(VT.getVectorNumElements())) return SDValue(); SDValue N0 = N->getOperand(0); @@ -33929,10 +35402,12 @@ static SDValue combineMul(SDNode *N, SelectionDAG &DAG, (SignMulAmt >= 0 && (MulAmt2 == 3 || MulAmt2 == 5 || MulAmt2 == 9)))) { if (isPowerOf2_64(MulAmt2) && - !(N->hasOneUse() && N->use_begin()->getOpcode() == ISD::ADD)) + !(SignMulAmt >= 0 && N->hasOneUse() && + N->use_begin()->getOpcode() == ISD::ADD)) // If second multiplifer is pow2, issue it first. We want the multiply by // 3, 5, or 9 to be folded into the addressing mode unless the lone use - // is an add. + // is an add. Only do this for positive multiply amounts since the + // negate would prevent it from being used as an address mode anyway. std::swap(MulAmt1, MulAmt2); if (isPowerOf2_64(MulAmt1)) @@ -34197,6 +35672,8 @@ static SDValue combineVectorPack(SDNode *N, SelectionDAG &DAG, N1.getScalarValueSizeInBits() == SrcBitsPerElt && "Unexpected PACKSS/PACKUS input type"); + bool IsSigned = (X86ISD::PACKSS == Opcode); + // Constant Folding. APInt UndefElts0, UndefElts1; SmallVector<APInt, 32> EltBits0, EltBits1; @@ -34209,7 +35686,6 @@ static SDValue combineVectorPack(SDNode *N, SelectionDAG &DAG, unsigned NumSrcElts = NumDstElts / 2; unsigned NumDstEltsPerLane = NumDstElts / NumLanes; unsigned NumSrcEltsPerLane = NumSrcElts / NumLanes; - bool IsSigned = (X86ISD::PACKSS == Opcode); APInt Undefs(NumDstElts, 0); SmallVector<APInt, 32> Bits(NumDstElts, APInt::getNullValue(DstBitsPerElt)); @@ -34253,16 +35729,58 @@ static SDValue combineVectorPack(SDNode *N, SelectionDAG &DAG, return getConstVector(Bits, Undefs, VT.getSimpleVT(), DAG, SDLoc(N)); } + // Try to combine a PACKUSWB/PACKSSWB implemented truncate with a regular + // truncate to create a larger truncate. 
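The pack-based truncation mentioned here is only sound because PACKUSWB/PACKSSWB saturate: they act as a plain truncate exactly when the value is already known to fit, via known-zero high bits for the unsigned form or more than 8 sign bits for the signed form. A standalone sketch of the per-element rule (plain C++, separate from the patch):

#include <cassert>
#include <cstdint>

static uint8_t packus_elt(int16_t v) { return v < 0 ? 0 : v > 255 ? 255 : (uint8_t)v; }
static int8_t  packss_elt(int16_t v) { return v < -128 ? -128 : v > 127 ? 127 : (int8_t)v; }

int main() {
  for (int v = -32768; v <= 32767; ++v) {
    int16_t x = (int16_t)v;
    // PACKUS is a truncate when the high byte is known zero.
    if ((uint16_t)x <= 0xFF)
      assert(packus_elt(x) == (uint8_t)x);
    // PACKSS is a truncate when there are more than 8 sign bits (fits in i8).
    if (x >= -128 && x <= 127)
      assert(packss_elt(x) == (int8_t)x);
  }
  return 0;
}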
+ if (Subtarget.hasAVX512() && + N0.getOpcode() == ISD::TRUNCATE && N1.isUndef() && VT == MVT::v16i8 && + N0.getOperand(0).getValueType() == MVT::v8i32) { + if ((IsSigned && DAG.ComputeNumSignBits(N0) > 8) || + (!IsSigned && + DAG.MaskedValueIsZero(N0, APInt::getHighBitsSet(16, 8)))) { + if (Subtarget.hasVLX()) + return DAG.getNode(X86ISD::VTRUNC, SDLoc(N), VT, N0.getOperand(0)); + + // Widen input to v16i32 so we can truncate that. + SDLoc dl(N); + SDValue Concat = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v16i32, + N0.getOperand(0), DAG.getUNDEF(MVT::v8i32)); + return DAG.getNode(ISD::TRUNCATE, SDLoc(N), VT, Concat); + } + } + // Attempt to combine as shuffle. SDValue Op(N, 0); if (SDValue Res = combineX86ShufflesRecursively({Op}, 0, Op, {0}, {}, /*Depth*/ 1, - /*HasVarMask*/ false, DAG, Subtarget)) + /*HasVarMask*/ false, + /*AllowVarMask*/ true, DAG, Subtarget)) return Res; return SDValue(); } +static SDValue combineVectorShiftVar(SDNode *N, SelectionDAG &DAG, + TargetLowering::DAGCombinerInfo &DCI, + const X86Subtarget &Subtarget) { + assert((X86ISD::VSHL == N->getOpcode() || X86ISD::VSRA == N->getOpcode() || + X86ISD::VSRL == N->getOpcode()) && + "Unexpected shift opcode"); + EVT VT = N->getValueType(0); + + // Shift zero -> zero. + if (ISD::isBuildVectorAllZeros(N->getOperand(0).getNode())) + return DAG.getConstant(0, SDLoc(N), VT); + + APInt KnownUndef, KnownZero; + const TargetLowering &TLI = DAG.getTargetLoweringInfo(); + APInt DemandedElts = APInt::getAllOnesValue(VT.getVectorNumElements()); + if (TLI.SimplifyDemandedVectorElts(SDValue(N, 0), DemandedElts, KnownUndef, + KnownZero, DCI)) + return SDValue(N, 0); + + return SDValue(); +} + static SDValue combineVectorShiftImm(SDNode *N, SelectionDAG &DAG, TargetLowering::DAGCombinerInfo &DCI, const X86Subtarget &Subtarget) { @@ -34277,13 +35795,14 @@ static SDValue combineVectorShiftImm(SDNode *N, SelectionDAG &DAG, unsigned NumBitsPerElt = VT.getScalarSizeInBits(); assert(VT == N0.getValueType() && (NumBitsPerElt % 8) == 0 && "Unexpected value type"); + assert(N1.getValueType() == MVT::i8 && "Unexpected shift amount type"); // Out of range logical bit shifts are guaranteed to be zero. // Out of range arithmetic bit shifts splat the sign bit. - APInt ShiftVal = cast<ConstantSDNode>(N1)->getAPIntValue(); - if (ShiftVal.zextOrTrunc(8).uge(NumBitsPerElt)) { + unsigned ShiftVal = cast<ConstantSDNode>(N1)->getZExtValue(); + if (ShiftVal >= NumBitsPerElt) { if (LogicalShift) - return getZeroVector(VT.getSimpleVT(), Subtarget, DAG, SDLoc(N)); + return DAG.getConstant(0, SDLoc(N), VT); else ShiftVal = NumBitsPerElt - 1; } @@ -34294,30 +35813,25 @@ static SDValue combineVectorShiftImm(SDNode *N, SelectionDAG &DAG, // Shift zero -> zero. if (ISD::isBuildVectorAllZeros(N0.getNode())) - return getZeroVector(VT.getSimpleVT(), Subtarget, DAG, SDLoc(N)); - - // fold (VSRLI (VSRAI X, Y), 31) -> (VSRLI X, 31). - // This VSRLI only looks at the sign bit, which is unmodified by VSRAI. - // TODO - support other sra opcodes as needed. 
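The immediate-shift combines in this area rest on x86's out-of-range behaviour (logical shifts of the element width or more yield zero, arithmetic shifts clamp to width minus one) and on arithmetic shifts composing by addition of their clamped amounts, which is what the new VSRAI-of-VSRAI fold uses. A standalone sketch for i16 elements (plain C++, separate from the patch; C++ only guarantees an arithmetic right shift of negative values from C++20 on):

#include <cassert>
#include <cstdint>

// One i16 lane of X86ISD::VSRAI: the immediate is clamped to 15, never UB.
static int16_t vsrai16(int16_t x, unsigned amt) {
  return (int16_t)(x >> (amt > 15 ? 15 : amt));
}

int main() {
  const int vals[] = {-32768, -1234, -1, 0, 1, 1234, 32767};
  for (int v : vals)
    for (unsigned c1 = 0; c1 <= 18; ++c1)
      for (unsigned c2 = 0; c2 <= 18; ++c2)
        // (VSRAI (VSRAI X, C1), C2) == (VSRAI X, min(C1 + C2, 15))
        assert(vsrai16(vsrai16((int16_t)v, c1), c2) ==
               vsrai16((int16_t)v, c1 + c2));
  return 0;
}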
- if (Opcode == X86ISD::VSRLI && (ShiftVal + 1) == NumBitsPerElt && - N0.getOpcode() == X86ISD::VSRAI) - return DAG.getNode(X86ISD::VSRLI, SDLoc(N), VT, N0.getOperand(0), N1); - - // fold (VSRAI (VSHLI X, C1), C1) --> X iff NumSignBits(X) > C1 - if (Opcode == X86ISD::VSRAI && N0.getOpcode() == X86ISD::VSHLI && - N1 == N0.getOperand(1)) { - SDValue N00 = N0.getOperand(0); - unsigned NumSignBits = DAG.ComputeNumSignBits(N00); - if (ShiftVal.ult(NumSignBits)) - return N00; + return DAG.getConstant(0, SDLoc(N), VT); + + // Fold (VSRAI (VSRAI X, C1), C2) --> (VSRAI X, (C1 + C2)) with (C1 + C2) + // clamped to (NumBitsPerElt - 1). + if (Opcode == X86ISD::VSRAI && N0.getOpcode() == X86ISD::VSRAI) { + unsigned ShiftVal2 = cast<ConstantSDNode>(N0.getOperand(1))->getZExtValue(); + unsigned NewShiftVal = ShiftVal + ShiftVal2; + if (NewShiftVal >= NumBitsPerElt) + NewShiftVal = NumBitsPerElt - 1; + return DAG.getNode(X86ISD::VSRAI, SDLoc(N), VT, N0.getOperand(0), + DAG.getConstant(NewShiftVal, SDLoc(N), MVT::i8)); } // We can decode 'whole byte' logical bit shifts as shuffles. - if (LogicalShift && (ShiftVal.getZExtValue() % 8) == 0) { + if (LogicalShift && (ShiftVal % 8) == 0) { SDValue Op(N, 0); if (SDValue Res = combineX86ShufflesRecursively( {Op}, 0, Op, {0}, {}, /*Depth*/ 1, - /*HasVarMask*/ false, DAG, Subtarget)) + /*HasVarMask*/ false, /*AllowVarMask*/ true, DAG, Subtarget)) return Res; } @@ -34328,18 +35842,22 @@ static SDValue combineVectorShiftImm(SDNode *N, SelectionDAG &DAG, getTargetConstantBitsFromNode(N0, NumBitsPerElt, UndefElts, EltBits)) { assert(EltBits.size() == VT.getVectorNumElements() && "Unexpected shift value type"); - unsigned ShiftImm = ShiftVal.getZExtValue(); for (APInt &Elt : EltBits) { if (X86ISD::VSHLI == Opcode) - Elt <<= ShiftImm; + Elt <<= ShiftVal; else if (X86ISD::VSRAI == Opcode) - Elt.ashrInPlace(ShiftImm); + Elt.ashrInPlace(ShiftVal); else - Elt.lshrInPlace(ShiftImm); + Elt.lshrInPlace(ShiftVal); } return getConstVector(EltBits, UndefElts, VT.getSimpleVT(), DAG, SDLoc(N)); } + const TargetLowering &TLI = DAG.getTargetLoweringInfo(); + if (TLI.SimplifyDemandedBits(SDValue(N, 0), + APInt::getAllOnesValue(NumBitsPerElt), DCI)) + return SDValue(N, 0); + return SDValue(); } @@ -34356,7 +35874,8 @@ static SDValue combineVectorInsert(SDNode *N, SelectionDAG &DAG, SDValue Op(N, 0); if (SDValue Res = combineX86ShufflesRecursively({Op}, 0, Op, {0}, {}, /*Depth*/ 1, - /*HasVarMask*/ false, DAG, Subtarget)) + /*HasVarMask*/ false, + /*AllowVarMask*/ true, DAG, Subtarget)) return Res; return SDValue(); @@ -34468,42 +35987,31 @@ static SDValue combineCompareEqual(SDNode *N, SelectionDAG &DAG, return SDValue(); } -// Try to match (and (xor X, -1), Y) logic pattern for (andnp X, Y) combines. -static bool matchANDXORWithAllOnesAsANDNP(SDNode *N, SDValue &X, SDValue &Y) { - if (N->getOpcode() != ISD::AND) - return false; +/// Try to fold: (and (xor X, -1), Y) -> (andnp X, Y). 
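The ANDNP fold named in that comment is just the Boolean identity (x ^ -1) & y == ~x & y, applied lane-wise once bitcasts are peeled off. A quick standalone check (plain C++, separate from the patch):

#include <cassert>
#include <cstdint>

int main() {
  const uint64_t xs[] = {0, 1, 0xF0F0F0F0F0F0F0F0ull, ~0ull};
  const uint64_t ys[] = {0, 0x123456789ABCDEF0ull, ~0ull};
  for (uint64_t x : xs)
    for (uint64_t y : ys)
      assert(((x ^ ~0ull) & y) == (~x & y));  // (and (xor X, -1), Y) == andnp(X, Y)
  return 0;
}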
+static SDValue combineANDXORWithAllOnesIntoANDNP(SDNode *N, SelectionDAG &DAG) { + assert(N->getOpcode() == ISD::AND); - SDValue N0 = N->getOperand(0); - SDValue N1 = N->getOperand(1); + MVT VT = N->getSimpleValueType(0); + if (!VT.is128BitVector() && !VT.is256BitVector() && !VT.is512BitVector()) + return SDValue(); + + SDValue X, Y; + SDValue N0 = peekThroughBitcasts(N->getOperand(0)); + SDValue N1 = peekThroughBitcasts(N->getOperand(1)); if (N0.getOpcode() == ISD::XOR && ISD::isBuildVectorAllOnes(N0.getOperand(1).getNode())) { X = N0.getOperand(0); Y = N1; - return true; - } - if (N1.getOpcode() == ISD::XOR && - ISD::isBuildVectorAllOnes(N1.getOperand(1).getNode())) { + } else if (N1.getOpcode() == ISD::XOR && + ISD::isBuildVectorAllOnes(N1.getOperand(1).getNode())) { X = N1.getOperand(0); Y = N0; - return true; - } - - return false; -} - -/// Try to fold: (and (xor X, -1), Y) -> (andnp X, Y). -static SDValue combineANDXORWithAllOnesIntoANDNP(SDNode *N, SelectionDAG &DAG) { - assert(N->getOpcode() == ISD::AND); - - EVT VT = N->getValueType(0); - if (VT != MVT::v2i64 && VT != MVT::v4i64 && VT != MVT::v8i64) + } else return SDValue(); - SDValue X, Y; - if (matchANDXORWithAllOnesAsANDNP(N, X, Y)) - return DAG.getNode(X86ISD::ANDNP, SDLoc(N), VT, X, Y); - - return SDValue(); + X = DAG.getBitcast(VT, X); + Y = DAG.getBitcast(VT, Y); + return DAG.getNode(X86ISD::ANDNP, SDLoc(N), VT, X, Y); } // On AVX/AVX2 the type v8i1 is legalized to v8i16, which is an XMM sized @@ -34512,8 +36020,8 @@ static SDValue combineANDXORWithAllOnesIntoANDNP(SDNode *N, SelectionDAG &DAG) { // some of the transition sequences. // Even with AVX-512 this is still useful for removing casts around logical // operations on vXi1 mask types. -static SDValue WidenMaskArithmetic(SDNode *N, SelectionDAG &DAG, - const X86Subtarget &Subtarget) { +static SDValue PromoteMaskArithmetic(SDNode *N, SelectionDAG &DAG, + const X86Subtarget &Subtarget) { EVT VT = N->getValueType(0); assert(VT.isVector() && "Expected vector type"); @@ -34628,6 +36136,10 @@ static SDValue combineAndMaskToShift(SDNode *N, SelectionDAG &DAG, !SplatVal.isMask()) return SDValue(); + // Don't prevent creation of ANDN. + if (isBitwiseNot(Op0)) + return SDValue(); + if (!SupportedVectorShiftWithImm(VT0.getSimpleVT(), Subtarget, ISD::SRL)) return SDValue(); @@ -34761,6 +36273,73 @@ static SDValue combineAndLoadToBZHI(SDNode *Node, SelectionDAG &DAG, return SDValue(); } +// Look for (and (ctpop X), 1) which is the IR form of __builtin_parity. +// Turn it into series of XORs and a setnp. +static SDValue combineParity(SDNode *N, SelectionDAG &DAG, + const X86Subtarget &Subtarget) { + EVT VT = N->getValueType(0); + + // We only support 64-bit and 32-bit. 64-bit requires special handling + // unless the 64-bit popcnt instruction is legal. + if (VT != MVT::i32 && VT != MVT::i64) + return SDValue(); + + const TargetLowering &TLI = DAG.getTargetLoweringInfo(); + if (TLI.isTypeLegal(VT) && TLI.isOperationLegal(ISD::CTPOP, VT)) + return SDValue(); + + SDValue N0 = N->getOperand(0); + SDValue N1 = N->getOperand(1); + + // LHS needs to be a single use CTPOP. + if (N0.getOpcode() != ISD::CTPOP || !N0.hasOneUse()) + return SDValue(); + + // RHS needs to be 1. + if (!isOneConstant(N1)) + return SDValue(); + + SDLoc DL(N); + SDValue X = N0.getOperand(0); + + // If this is 64-bit, its always best to xor the two 32-bit pieces together + // even if we have popcnt. 
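The parity lowering introduced in combineParity is sound because XOR-folding the halves of a value preserves the parity of its population count; the patch stops folding at a single byte and reads the inverted parity flag with SETNP instead of folding further. A standalone software model checked against popcount (plain C++, separate from the patch):

#include <bitset>
#include <cassert>
#include <cstdint>

static unsigned parity32(uint32_t x) {
  x ^= x >> 16;  // the i64 path first xors the two 32-bit halves the same way
  x ^= x >> 8;   // the patch does this step with an 8-bit flag-setting XOR
  x ^= x >> 4;   // ...and lets the CPU's PF summarize the remaining byte;
  x ^= x >> 2;   // here we simply keep folding in software instead.
  x ^= x >> 1;
  return x & 1;
}

int main() {
  const uint32_t vals[] = {0, 1, 0xFF, 0x8000, 0xDEADBEEF, 0x80000000, 0xFFFFFFFF};
  for (uint32_t x : vals)
    assert(parity32(x) == (std::bitset<32>(x).count() & 1));
  return 0;
}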
+ if (VT == MVT::i64) { + SDValue Hi = DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, + DAG.getNode(ISD::SRL, DL, VT, X, + DAG.getConstant(32, DL, MVT::i8))); + SDValue Lo = DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, X); + X = DAG.getNode(ISD::XOR, DL, MVT::i32, Lo, Hi); + // Generate a 32-bit parity idiom. This will bring us back here if we need + // to expand it too. + SDValue Parity = DAG.getNode(ISD::AND, DL, MVT::i32, + DAG.getNode(ISD::CTPOP, DL, MVT::i32, X), + DAG.getConstant(1, DL, MVT::i32)); + return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, Parity); + } + assert(VT == MVT::i32 && "Unexpected VT!"); + + // Xor the high and low 16-bits together using a 32-bit operation. + SDValue Hi16 = DAG.getNode(ISD::SRL, DL, VT, X, + DAG.getConstant(16, DL, MVT::i8)); + X = DAG.getNode(ISD::XOR, DL, VT, X, Hi16); + + // Finally xor the low 2 bytes together and use a 8-bit flag setting xor. + // This should allow an h-reg to be used to save a shift. + // FIXME: We only get an h-reg in 32-bit mode. + SDValue Hi = DAG.getNode(ISD::TRUNCATE, DL, MVT::i8, + DAG.getNode(ISD::SRL, DL, VT, X, + DAG.getConstant(8, DL, MVT::i8))); + SDValue Lo = DAG.getNode(ISD::TRUNCATE, DL, MVT::i8, X); + SDVTList VTs = DAG.getVTList(MVT::i8, MVT::i32); + SDValue Flags = DAG.getNode(X86ISD::XOR, DL, VTs, Lo, Hi).getValue(1); + + // Copy the inverse of the parity flag into a register with setcc. + SDValue Setnp = getSETCC(X86::COND_NP, Flags, DL, DAG); + // Zero extend to original type. + return DAG.getNode(ISD::ZERO_EXTEND, DL, N->getValueType(0), Setnp); +} + static SDValue combineAnd(SDNode *N, SelectionDAG &DAG, TargetLowering::DAGCombinerInfo &DCI, const X86Subtarget &Subtarget) { @@ -34788,6 +36367,10 @@ static SDValue combineAnd(SDNode *N, SelectionDAG &DAG, } } + // This must be done before legalization has expanded the ctpop. + if (SDValue V = combineParity(N, DAG, Subtarget)) + return V; + if (DCI.isBeforeLegalizeOps()) return SDValue(); @@ -34811,7 +36394,7 @@ static SDValue combineAnd(SDNode *N, SelectionDAG &DAG, SDValue Op(N, 0); if (SDValue Res = combineX86ShufflesRecursively( {Op}, 0, Op, {0}, {}, /*Depth*/ 1, - /*HasVarMask*/ false, DAG, Subtarget)) + /*HasVarMask*/ false, /*AllowVarMask*/ true, DAG, Subtarget)) return Res; } @@ -34848,7 +36431,7 @@ static SDValue combineAnd(SDNode *N, SelectionDAG &DAG, if (SDValue Shuffle = combineX86ShufflesRecursively( {SrcVec}, 0, SrcVec, ShuffleMask, {}, /*Depth*/ 2, - /*HasVarMask*/ false, DAG, Subtarget)) + /*HasVarMask*/ false, /*AllowVarMask*/ true, DAG, Subtarget)) return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SDLoc(N), VT, Shuffle, N->getOperand(0).getOperand(1)); } @@ -34978,7 +36561,7 @@ static SDValue combineLogicBlendIntoPBLENDV(SDNode *N, SelectionDAG &DAG, if (!Subtarget.hasSSE41()) return SDValue(); - MVT BlendVT = (VT == MVT::v4i64) ? MVT::v32i8 : MVT::v16i8; + MVT BlendVT = VT.is256BitVector() ? MVT::v32i8 : MVT::v16i8; X = DAG.getBitcast(BlendVT, X); Y = DAG.getBitcast(BlendVT, Y); @@ -35122,11 +36705,21 @@ static SDValue combineOr(SDNode *N, SelectionDAG &DAG, if (SDValue R = combineLogicBlendIntoPBLENDV(N, DAG, Subtarget)) return R; + // Attempt to recursively combine an OR of shuffles. 
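Further down, the SHLD/SHRD matching in combineOr is driven by the scalar identity behind SHLD: shift the concatenation x:y left by c and keep the high word. A standalone sketch for 32-bit operands, including the (y >> 1) >> (c ^ 31) variant the code also accepts (plain C++, separate from the patch):

#include <cassert>
#include <cstdint>

// Reference SHLD: shift the 64-bit concatenation x:y left by c, keep the top half.
static uint32_t shld32(uint32_t x, uint32_t y, unsigned c) {
  return (uint32_t)(((((uint64_t)x << 32) | y) << c) >> 32);
}

int main() {
  uint32_t x = 0xDEADBEEF, y = 0x12345678;
  for (unsigned c = 1; c < 32; ++c) {
    // OR( SHL( X, C ), SRL( Y, 32 - C ) ) -> SHLD( X, Y, C )
    assert(((x << c) | (y >> (32 - c))) == shld32(x, y, c));
    // OR( SHL( X, C ), SRL( SRL( Y, 1 ), XOR( C, 31 ) ) ) -> SHLD( X, Y, C )
    assert(((x << c) | ((y >> 1) >> (c ^ 31))) == shld32(x, y, c));
  }
  return 0;
}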
+ if (VT.isVector() && (VT.getScalarSizeInBits() % 8) == 0) { + SDValue Op(N, 0); + if (SDValue Res = combineX86ShufflesRecursively( + {Op}, 0, Op, {0}, {}, /*Depth*/ 1, + /*HasVarMask*/ false, /*AllowVarMask*/ true, DAG, Subtarget)) + return Res; + } + if (VT != MVT::i16 && VT != MVT::i32 && VT != MVT::i64) return SDValue(); // fold (or (x << c) | (y >> (64 - c))) ==> (shld64 x, y, c) bool OptForSize = DAG.getMachineFunction().getFunction().optForSize(); + unsigned Bits = VT.getScalarSizeInBits(); // SHLD/SHRD instructions have lower register pressure, but on some // platforms they have higher latency than the equivalent @@ -35149,6 +36742,23 @@ static SDValue combineOr(SDNode *N, SelectionDAG &DAG, SDValue ShAmt1 = N1.getOperand(1); if (ShAmt1.getValueType() != MVT::i8) return SDValue(); + + // Peek through any modulo shift masks. + SDValue ShMsk0; + if (ShAmt0.getOpcode() == ISD::AND && + isa<ConstantSDNode>(ShAmt0.getOperand(1)) && + ShAmt0.getConstantOperandVal(1) == (Bits - 1)) { + ShMsk0 = ShAmt0; + ShAmt0 = ShAmt0.getOperand(0); + } + SDValue ShMsk1; + if (ShAmt1.getOpcode() == ISD::AND && + isa<ConstantSDNode>(ShAmt1.getOperand(1)) && + ShAmt1.getConstantOperandVal(1) == (Bits - 1)) { + ShMsk1 = ShAmt1; + ShAmt1 = ShAmt1.getOperand(0); + } + if (ShAmt0.getOpcode() == ISD::TRUNCATE) ShAmt0 = ShAmt0.getOperand(0); if (ShAmt1.getOpcode() == ISD::TRUNCATE) @@ -35163,27 +36773,29 @@ static SDValue combineOr(SDNode *N, SelectionDAG &DAG, Opc = X86ISD::SHRD; std::swap(Op0, Op1); std::swap(ShAmt0, ShAmt1); + std::swap(ShMsk0, ShMsk1); } // OR( SHL( X, C ), SRL( Y, 32 - C ) ) -> SHLD( X, Y, C ) // OR( SRL( X, C ), SHL( Y, 32 - C ) ) -> SHRD( X, Y, C ) // OR( SHL( X, C ), SRL( SRL( Y, 1 ), XOR( C, 31 ) ) ) -> SHLD( X, Y, C ) // OR( SRL( X, C ), SHL( SHL( Y, 1 ), XOR( C, 31 ) ) ) -> SHRD( X, Y, C ) - unsigned Bits = VT.getSizeInBits(); + // OR( SHL( X, AND( C, 31 ) ), SRL( Y, AND( 0 - C, 31 ) ) ) -> SHLD( X, Y, C ) + // OR( SRL( X, AND( C, 31 ) ), SHL( Y, AND( 0 - C, 31 ) ) ) -> SHRD( X, Y, C ) if (ShAmt1.getOpcode() == ISD::SUB) { SDValue Sum = ShAmt1.getOperand(0); - if (ConstantSDNode *SumC = dyn_cast<ConstantSDNode>(Sum)) { + if (auto *SumC = dyn_cast<ConstantSDNode>(Sum)) { SDValue ShAmt1Op1 = ShAmt1.getOperand(1); if (ShAmt1Op1.getOpcode() == ISD::TRUNCATE) ShAmt1Op1 = ShAmt1Op1.getOperand(0); - if (SumC->getSExtValue() == Bits && ShAmt1Op1 == ShAmt0) - return DAG.getNode(Opc, DL, VT, - Op0, Op1, - DAG.getNode(ISD::TRUNCATE, DL, - MVT::i8, ShAmt0)); - } - } else if (ConstantSDNode *ShAmt1C = dyn_cast<ConstantSDNode>(ShAmt1)) { - ConstantSDNode *ShAmt0C = dyn_cast<ConstantSDNode>(ShAmt0); + if ((SumC->getAPIntValue() == Bits || + (SumC->getAPIntValue() == 0 && ShMsk1)) && + ShAmt1Op1 == ShAmt0) + return DAG.getNode(Opc, DL, VT, Op0, Op1, + DAG.getNode(ISD::TRUNCATE, DL, MVT::i8, ShAmt0)); + } + } else if (auto *ShAmt1C = dyn_cast<ConstantSDNode>(ShAmt1)) { + auto *ShAmt0C = dyn_cast<ConstantSDNode>(ShAmt0); if (ShAmt0C && (ShAmt0C->getSExtValue() + ShAmt1C->getSExtValue()) == Bits) return DAG.getNode(Opc, DL, VT, N0.getOperand(0), N1.getOperand(0), @@ -35191,12 +36803,13 @@ static SDValue combineOr(SDNode *N, SelectionDAG &DAG, MVT::i8, ShAmt0)); } else if (ShAmt1.getOpcode() == ISD::XOR) { SDValue Mask = ShAmt1.getOperand(1); - if (ConstantSDNode *MaskC = dyn_cast<ConstantSDNode>(Mask)) { + if (auto *MaskC = dyn_cast<ConstantSDNode>(Mask)) { unsigned InnerShift = (X86ISD::SHLD == Opc ? 
ISD::SRL : ISD::SHL); SDValue ShAmt1Op0 = ShAmt1.getOperand(0); if (ShAmt1Op0.getOpcode() == ISD::TRUNCATE) ShAmt1Op0 = ShAmt1Op0.getOperand(0); - if (MaskC->getSExtValue() == (Bits - 1) && ShAmt1Op0 == ShAmt0) { + if (MaskC->getSExtValue() == (Bits - 1) && + (ShAmt1Op0 == ShAmt0 || ShAmt1Op0 == ShMsk0)) { if (Op1.getOpcode() == InnerShift && isa<ConstantSDNode>(Op1.getOperand(1)) && Op1.getConstantOperandVal(1) == 1) { @@ -35207,7 +36820,7 @@ static SDValue combineOr(SDNode *N, SelectionDAG &DAG, if (InnerShift == ISD::SHL && Op1.getOpcode() == ISD::ADD && Op1.getOperand(0) == Op1.getOperand(1)) { return DAG.getNode(Opc, DL, VT, Op0, Op1.getOperand(0), - DAG.getNode(ISD::TRUNCATE, DL, MVT::i8, ShAmt0)); + DAG.getNode(ISD::TRUNCATE, DL, MVT::i8, ShAmt0)); } } } @@ -35478,6 +37091,7 @@ static SDValue combineTruncateWithSat(SDValue In, EVT VT, const SDLoc &DL, return DAG.getNode(X86ISD::VTRUNCUS, DL, VT, USatVal); } if (VT.isVector() && isPowerOf2_32(VT.getVectorNumElements()) && + !Subtarget.hasAVX512() && (SVT == MVT::i8 || SVT == MVT::i16) && (InSVT == MVT::i16 || InSVT == MVT::i32)) { if (auto USatVal = detectSSatPattern(In, VT, true)) { @@ -35514,7 +37128,7 @@ static SDValue detectAVGPattern(SDValue In, EVT VT, SelectionDAG &DAG, EVT ScalarVT = VT.getVectorElementType(); if (!((ScalarVT == MVT::i8 || ScalarVT == MVT::i16) && - isPowerOf2_32(NumElems))) + NumElems >= 2 && isPowerOf2_32(NumElems))) return SDValue(); // InScalarVT is the intermediate type in AVG pattern and it should be greater @@ -35752,8 +37366,8 @@ reduceMaskedLoadToScalarLoad(MaskedLoadSDNode *ML, SelectionDAG &DAG, Alignment, ML->getMemOperand()->getFlags()); // Insert the loaded element into the appropriate place in the vector. - SDValue Insert = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, VT, ML->getSrc0(), - Load, VecIndex); + SDValue Insert = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, VT, + ML->getPassThru(), Load, VecIndex); return DCI.CombineTo(ML, Insert, Load.getValue(1), true); } @@ -35776,7 +37390,8 @@ combineMaskedLoadConstantMask(MaskedLoadSDNode *ML, SelectionDAG &DAG, if (LoadFirstElt && LoadLastElt) { SDValue VecLd = DAG.getLoad(VT, DL, ML->getChain(), ML->getBasePtr(), ML->getMemOperand()); - SDValue Blend = DAG.getSelect(DL, VT, ML->getMask(), VecLd, ML->getSrc0()); + SDValue Blend = DAG.getSelect(DL, VT, ML->getMask(), VecLd, + ML->getPassThru()); return DCI.CombineTo(ML, Blend, VecLd.getValue(1), true); } @@ -35786,7 +37401,7 @@ combineMaskedLoadConstantMask(MaskedLoadSDNode *ML, SelectionDAG &DAG, // Don't try this if the pass-through operand is already undefined. That would // cause an infinite loop because that's what we're about to create. - if (ML->getSrc0().isUndef()) + if (ML->getPassThru().isUndef()) return SDValue(); // The new masked load has an undef pass-through operand. The select uses the @@ -35795,7 +37410,8 @@ combineMaskedLoadConstantMask(MaskedLoadSDNode *ML, SelectionDAG &DAG, ML->getMask(), DAG.getUNDEF(VT), ML->getMemoryVT(), ML->getMemOperand(), ML->getExtensionType()); - SDValue Blend = DAG.getSelect(DL, VT, ML->getMask(), NewML, ML->getSrc0()); + SDValue Blend = DAG.getSelect(DL, VT, ML->getMask(), NewML, + ML->getPassThru()); return DCI.CombineTo(ML, Blend, NewML.getValue(1), true); } @@ -35842,9 +37458,9 @@ static SDValue combineMaskedLoad(SDNode *N, SelectionDAG &DAG, LdVT.getScalarType(), NumElems*SizeRatio); assert(WideVecVT.getSizeInBits() == VT.getSizeInBits()); - // Convert Src0 value. 
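The AVG matching above (detectAVGPattern) recognizes trunc(((zext a) + (zext b) + 1) >> 1); the widening matters because the +1 and the shift must see the carry that a narrow add would drop. A standalone sketch of the PAVGB lane semantics (plain C++, separate from the patch):

#include <cassert>
#include <cstdint>

static uint8_t pavgb(uint8_t a, uint8_t b) {
  return (uint8_t)(((unsigned)a + b + 1) >> 1);  // rounding average, no lost carry
}

int main() {
  for (int a = 0; a < 256; ++a)
    for (int b = 0; b < 256; ++b)
      assert(pavgb((uint8_t)a, (uint8_t)b) == (uint8_t)((a + b + 1) >> 1));
  // An 8-bit add would lose the carry: avg(200, 200) must be 200, not 72.
  assert(pavgb(200, 200) == 200);
  assert((uint8_t)((uint8_t)(200 + 200 + 1) >> 1) == 72);
  return 0;
}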
- SDValue WideSrc0 = DAG.getBitcast(WideVecVT, Mld->getSrc0()); - if (!Mld->getSrc0().isUndef()) { + // Convert PassThru value. + SDValue WidePassThru = DAG.getBitcast(WideVecVT, Mld->getPassThru()); + if (!Mld->getPassThru().isUndef()) { SmallVector<int, 16> ShuffleVec(NumElems * SizeRatio, -1); for (unsigned i = 0; i != NumElems; ++i) ShuffleVec[i] = i * SizeRatio; @@ -35852,7 +37468,7 @@ static SDValue combineMaskedLoad(SDNode *N, SelectionDAG &DAG, // Can't shuffle using an illegal type. assert(DAG.getTargetLoweringInfo().isTypeLegal(WideVecVT) && "WideVecVT should be legal"); - WideSrc0 = DAG.getVectorShuffle(WideVecVT, dl, WideSrc0, + WidePassThru = DAG.getVectorShuffle(WideVecVT, dl, WidePassThru, DAG.getUNDEF(WideVecVT), ShuffleVec); } @@ -35885,10 +37501,10 @@ static SDValue combineMaskedLoad(SDNode *N, SelectionDAG &DAG, } SDValue WideLd = DAG.getMaskedLoad(WideVecVT, dl, Mld->getChain(), - Mld->getBasePtr(), NewMask, WideSrc0, + Mld->getBasePtr(), NewMask, WidePassThru, Mld->getMemoryVT(), Mld->getMemOperand(), ISD::NON_EXTLOAD); - SDValue NewVec = getExtendInVec(X86ISD::VSEXT, dl, VT, WideLd, DAG); + SDValue NewVec = getExtendInVec(/*Signed*/true, dl, VT, WideLd, DAG); return DCI.CombineTo(N, NewVec, WideLd.getValue(1), true); } @@ -35920,31 +37536,25 @@ static SDValue reduceMaskedStoreToScalarStore(MaskedStoreSDNode *MS, } static SDValue combineMaskedStore(SDNode *N, SelectionDAG &DAG, + TargetLowering::DAGCombinerInfo &DCI, const X86Subtarget &Subtarget) { MaskedStoreSDNode *Mst = cast<MaskedStoreSDNode>(N); - if (Mst->isCompressingStore()) return SDValue(); + EVT VT = Mst->getValue().getValueType(); if (!Mst->isTruncatingStore()) { if (SDValue ScalarStore = reduceMaskedStoreToScalarStore(Mst, DAG)) return ScalarStore; - // If the mask is checking (0 > X), we're creating a vector with all-zeros - // or all-ones elements based on the sign bits of X. AVX1 masked store only - // cares about the sign bit of each mask element, so eliminate the compare: - // mstore val, ptr, (pcmpgt 0, X) --> mstore val, ptr, X - // Note that by waiting to match an x86-specific PCMPGT node, we're - // eliminating potentially more complex matching of a setcc node which has - // a full range of predicates. + // If the mask value has been legalized to a non-boolean vector, try to + // simplify ops leading up to it. We only demand the MSB of each lane. SDValue Mask = Mst->getMask(); - if (Mask.getOpcode() == X86ISD::PCMPGT && - ISD::isBuildVectorAllZeros(Mask.getOperand(0).getNode())) { - assert(Mask.getValueType() == Mask.getOperand(1).getValueType() && - "Unexpected type for PCMPGT"); - return DAG.getMaskedStore( - Mst->getChain(), SDLoc(N), Mst->getValue(), Mst->getBasePtr(), - Mask.getOperand(1), Mst->getMemoryVT(), Mst->getMemOperand()); + if (Mask.getScalarValueSizeInBits() != 1) { + const TargetLowering &TLI = DAG.getTargetLoweringInfo(); + APInt DemandedMask(APInt::getSignMask(VT.getScalarSizeInBits())); + if (TLI.SimplifyDemandedBits(Mask, DemandedMask, DCI)) + return SDValue(N, 0); } // TODO: AVX512 targets should also be able to simplify something like the @@ -35955,7 +37565,6 @@ static SDValue combineMaskedStore(SDNode *N, SelectionDAG &DAG, } // Resolve truncating stores. 
- EVT VT = Mst->getValue().getValueType(); unsigned NumElems = VT.getVectorNumElements(); EVT StVT = Mst->getMemoryVT(); SDLoc dl(Mst); @@ -36043,6 +37652,18 @@ static SDValue combineStore(SDNode *N, SelectionDAG &DAG, SDValue StoredVal = St->getOperand(1); const TargetLowering &TLI = DAG.getTargetLoweringInfo(); + // Convert a store of vXi1 into a store of iX and a bitcast. + if (!Subtarget.hasAVX512() && VT == StVT && VT.isVector() && + VT.getVectorElementType() == MVT::i1) { + + EVT NewVT = EVT::getIntegerVT(*DAG.getContext(), VT.getVectorNumElements()); + StoredVal = DAG.getBitcast(NewVT, StoredVal); + + return DAG.getStore(St->getChain(), dl, StoredVal, St->getBasePtr(), + St->getPointerInfo(), St->getAlignment(), + St->getMemOperand()->getFlags()); + } + // If this is a store of a scalar_to_vector to v1i1, just use a scalar store. // This will avoid a copy to k-register. if (VT == MVT::v1i1 && VT == StVT && Subtarget.hasAVX512() && @@ -36269,7 +37890,8 @@ static SDValue combineStore(SDNode *N, SelectionDAG &DAG, // Otherwise, if it's legal to use f64 SSE instructions, use f64 load/store // pair instead. if (Subtarget.is64Bit() || F64IsLegal) { - MVT LdVT = Subtarget.is64Bit() ? MVT::i64 : MVT::f64; + MVT LdVT = (Subtarget.is64Bit() && + (!VT.isFloatingPoint() || !F64IsLegal)) ? MVT::i64 : MVT::f64; SDValue NewLd = DAG.getLoad(LdVT, LdDL, Ld->getChain(), Ld->getBasePtr(), Ld->getMemOperand()); @@ -36343,10 +37965,12 @@ static SDValue combineStore(SDNode *N, SelectionDAG &DAG, /// In short, LHS and RHS are inspected to see if LHS op RHS is of the form /// A horizontal-op B, for some already available A and B, and if so then LHS is /// set to A, RHS to B, and the routine returns 'true'. -/// Note that the binary operation should have the property that if one of the -/// operands is UNDEF then the result is UNDEF. static bool isHorizontalBinOp(SDValue &LHS, SDValue &RHS, bool IsCommutative) { - // Look for the following pattern: if + // If either operand is undef, bail out. The binop should be simplified. + if (LHS.isUndef() || RHS.isUndef()) + return false; + + // Look for the following pattern: // A = < float a0, float a1, float a2, float a3 > // B = < float b0, float b1, float b2, float b3 > // and @@ -36361,25 +37985,15 @@ static bool isHorizontalBinOp(SDValue &LHS, SDValue &RHS, bool IsCommutative) { return false; MVT VT = LHS.getSimpleValueType(); - assert((VT.is128BitVector() || VT.is256BitVector()) && "Unsupported vector type for horizontal add/sub"); - // Handle 128 and 256-bit vector lengths. AVX defines horizontal add/sub to - // operate independently on 128-bit lanes. - unsigned NumElts = VT.getVectorNumElements(); - unsigned NumLanes = VT.getSizeInBits()/128; - unsigned NumLaneElts = NumElts / NumLanes; - assert((NumLaneElts % 2 == 0) && - "Vector type should have an even number of elements in each lane"); - unsigned HalfLaneElts = NumLaneElts/2; - // View LHS in the form // LHS = VECTOR_SHUFFLE A, B, LMask - // If LHS is not a shuffle then pretend it is the shuffle + // If LHS is not a shuffle, then pretend it is the identity shuffle: // LHS = VECTOR_SHUFFLE LHS, undef, <0, 1, ..., N-1> - // NOTE: in what follows a default initialized SDValue represents an UNDEF of - // type VT. + // NOTE: A default initialized SDValue represents an UNDEF of type VT. 
+ unsigned NumElts = VT.getVectorNumElements(); SDValue A, B; SmallVector<int, 16> LMask(NumElts); if (LHS.getOpcode() == ISD::VECTOR_SHUFFLE) { @@ -36388,10 +38002,9 @@ static bool isHorizontalBinOp(SDValue &LHS, SDValue &RHS, bool IsCommutative) { if (!LHS.getOperand(1).isUndef()) B = LHS.getOperand(1); ArrayRef<int> Mask = cast<ShuffleVectorSDNode>(LHS.getNode())->getMask(); - std::copy(Mask.begin(), Mask.end(), LMask.begin()); + llvm::copy(Mask, LMask.begin()); } else { - if (!LHS.isUndef()) - A = LHS; + A = LHS; for (unsigned i = 0; i != NumElts; ++i) LMask[i] = i; } @@ -36406,45 +38019,51 @@ static bool isHorizontalBinOp(SDValue &LHS, SDValue &RHS, bool IsCommutative) { if (!RHS.getOperand(1).isUndef()) D = RHS.getOperand(1); ArrayRef<int> Mask = cast<ShuffleVectorSDNode>(RHS.getNode())->getMask(); - std::copy(Mask.begin(), Mask.end(), RMask.begin()); + llvm::copy(Mask, RMask.begin()); } else { - if (!RHS.isUndef()) - C = RHS; + C = RHS; for (unsigned i = 0; i != NumElts; ++i) RMask[i] = i; } + // If A and B occur in reverse order in RHS, then canonicalize by commuting + // RHS operands and shuffle mask. + if (A != C) { + std::swap(C, D); + ShuffleVectorSDNode::commuteMask(RMask); + } // Check that the shuffles are both shuffling the same vectors. - if (!(A == C && B == D) && !(A == D && B == C)) - return false; - - // If everything is UNDEF then bail out: it would be better to fold to UNDEF. - if (!A.getNode() && !B.getNode()) + if (!(A == C && B == D)) return false; - // If A and B occur in reverse order in RHS, then "swap" them (which means - // rewriting the mask). - if (A != C) - ShuffleVectorSDNode::commuteMask(RMask); - - // At this point LHS and RHS are equivalent to - // LHS = VECTOR_SHUFFLE A, B, LMask - // RHS = VECTOR_SHUFFLE A, B, RMask + // LHS and RHS are now: + // LHS = shuffle A, B, LMask + // RHS = shuffle A, B, RMask // Check that the masks correspond to performing a horizontal operation. - for (unsigned l = 0; l != NumElts; l += NumLaneElts) { - for (unsigned i = 0; i != NumLaneElts; ++i) { - int LIdx = LMask[i+l], RIdx = RMask[i+l]; - - // Ignore any UNDEF components. + // AVX defines horizontal add/sub to operate independently on 128-bit lanes, + // so we just repeat the inner loop if this is a 256-bit op. + unsigned Num128BitChunks = VT.getSizeInBits() / 128; + unsigned NumEltsPer128BitChunk = NumElts / Num128BitChunks; + assert((NumEltsPer128BitChunk % 2 == 0) && + "Vector type should have an even number of elements in each lane"); + for (unsigned j = 0; j != NumElts; j += NumEltsPer128BitChunk) { + for (unsigned i = 0; i != NumEltsPer128BitChunk; ++i) { + // Ignore undefined components. + int LIdx = LMask[i + j], RIdx = RMask[i + j]; if (LIdx < 0 || RIdx < 0 || (!A.getNode() && (LIdx < (int)NumElts || RIdx < (int)NumElts)) || (!B.getNode() && (LIdx >= (int)NumElts || RIdx >= (int)NumElts))) continue; - // Check that successive elements are being operated on. If not, this is + // The low half of the 128-bit result must choose from A. + // The high half of the 128-bit result must choose from B, + // unless B is undef. In that case, we are always choosing from A. + unsigned NumEltsPer64BitChunk = NumEltsPer128BitChunk / 2; + unsigned Src = B.getNode() ? i >= NumEltsPer64BitChunk : 0; + + // Check that successive elements are being operated on. If not, this is // not a horizontal operation. 
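The mask check being rewritten here encodes what HADDPS actually computes: each output element is the sum of two successive elements of one source, with the low half of each 128-bit chunk drawn from the first operand and the high half from the second. A scalar model of one v4f32 chunk (plain C++, separate from the patch):

#include <cassert>

int main() {
  float A[4] = {1, 2, 3, 4}, B[4] = {10, 20, 30, 40};
  // haddps(A, B) = { A0+A1, A2+A3, B0+B1, B2+B3 }
  float H[4] = {A[0] + A[1], A[2] + A[3], B[0] + B[1], B[2] + B[3]};
  // The combine proves LHS/RHS are shuffles picking successive elements:
  //   LHS = shuffle<0,2,4,6>(A,B), RHS = shuffle<1,3,5,7>(A,B),
  // so that the element-wise add LHS + RHS equals haddps(A, B).
  float LHS[4] = {A[0], A[2], B[0], B[2]};
  float RHS[4] = {A[1], A[3], B[1], B[3]};
  for (int i = 0; i < 4; ++i)
    assert(LHS[i] + RHS[i] == H[i]);
  return 0;
}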
- unsigned Src = (i/HalfLaneElts); // each lane is split between srcs - int Index = 2*(i%HalfLaneElts) + NumElts*Src + l; + int Index = 2 * (i % NumEltsPer64BitChunk) + NumElts * Src + j; if (!(LIdx == Index && RIdx == Index + 1) && !(IsCommutative && LIdx == Index + 1 && RIdx == Index)) return false; @@ -36463,21 +38082,24 @@ static SDValue combineFaddFsub(SDNode *N, SelectionDAG &DAG, SDValue LHS = N->getOperand(0); SDValue RHS = N->getOperand(1); bool IsFadd = N->getOpcode() == ISD::FADD; + auto HorizOpcode = IsFadd ? X86ISD::FHADD : X86ISD::FHSUB; assert((IsFadd || N->getOpcode() == ISD::FSUB) && "Wrong opcode"); // Try to synthesize horizontal add/sub from adds/subs of shuffles. if (((Subtarget.hasSSE3() && (VT == MVT::v4f32 || VT == MVT::v2f64)) || (Subtarget.hasAVX() && (VT == MVT::v8f32 || VT == MVT::v4f64))) && - isHorizontalBinOp(LHS, RHS, IsFadd)) { - auto NewOpcode = IsFadd ? X86ISD::FHADD : X86ISD::FHSUB; - return DAG.getNode(NewOpcode, SDLoc(N), VT, LHS, RHS); - } + isHorizontalBinOp(LHS, RHS, IsFadd) && + shouldUseHorizontalOp(LHS == RHS, DAG, Subtarget)) + return DAG.getNode(HorizOpcode, SDLoc(N), VT, LHS, RHS); + return SDValue(); } /// Attempt to pre-truncate inputs to arithmetic ops if it will simplify /// the codegen. /// e.g. TRUNC( BINOP( X, Y ) ) --> BINOP( TRUNC( X ), TRUNC( Y ) ) +/// TODO: This overlaps with the generic combiner's visitTRUNCATE. Remove +/// anything that is guaranteed to be transformed by DAGCombiner. static SDValue combineTruncatedArithmetic(SDNode *N, SelectionDAG &DAG, const X86Subtarget &Subtarget, const SDLoc &DL) { @@ -36489,34 +38111,20 @@ static SDValue combineTruncatedArithmetic(SDNode *N, SelectionDAG &DAG, EVT VT = N->getValueType(0); EVT SrcVT = Src.getValueType(); - auto IsRepeatedOpOrFreeTruncation = [VT](SDValue Op0, SDValue Op1) { + auto IsFreeTruncation = [VT](SDValue Op) { unsigned TruncSizeInBits = VT.getScalarSizeInBits(); - // Repeated operand, so we are only trading one output truncation for - // one input truncation. - if (Op0 == Op1) - return true; - - // See if either operand has been extended from a smaller/equal size to + // See if this has been extended from a smaller/equal size to // the truncation size, allowing a truncation to combine with the extend. - unsigned Opcode0 = Op0.getOpcode(); - if ((Opcode0 == ISD::ANY_EXTEND || Opcode0 == ISD::SIGN_EXTEND || - Opcode0 == ISD::ZERO_EXTEND) && - Op0.getOperand(0).getScalarValueSizeInBits() <= TruncSizeInBits) - return true; - - unsigned Opcode1 = Op1.getOpcode(); - if ((Opcode1 == ISD::ANY_EXTEND || Opcode1 == ISD::SIGN_EXTEND || - Opcode1 == ISD::ZERO_EXTEND) && - Op1.getOperand(0).getScalarValueSizeInBits() <= TruncSizeInBits) + unsigned Opcode = Op.getOpcode(); + if ((Opcode == ISD::ANY_EXTEND || Opcode == ISD::SIGN_EXTEND || + Opcode == ISD::ZERO_EXTEND) && + Op.getOperand(0).getScalarValueSizeInBits() <= TruncSizeInBits) return true; - // See if either operand is a single use constant which can be constant - // folded. - SDValue BC0 = peekThroughOneUseBitcasts(Op0); - SDValue BC1 = peekThroughOneUseBitcasts(Op1); - return ISD::isBuildVectorOfConstantSDNodes(BC0.getNode()) || - ISD::isBuildVectorOfConstantSDNodes(BC1.getNode()); + // See if this is a single use constant which can be constant folded. 
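combineTruncatedArithmetic is justified because, for add, sub, mul and the bitwise ops, the low bits of the result depend only on the low bits of the operands, so the truncate can be hoisted above the operation. A standalone check (plain C++, separate from the patch):

#include <cassert>
#include <cstdint>

int main() {
  const uint32_t xs[] = {0, 7, 0x1234, 0xDEADBEEF, 0xFFFFFFFF};
  const uint32_t ys[] = {1, 0x80, 0xFFFF, 0xCAFEBABE};
  for (uint32_t x : xs)
    for (uint32_t y : ys) {
      // TRUNC(BINOP(X, Y)) == BINOP(TRUNC(X), TRUNC(Y)) for these opcodes.
      assert((uint8_t)(x + y) == (uint8_t)((uint8_t)x + (uint8_t)y));
      assert((uint8_t)(x - y) == (uint8_t)((uint8_t)x - (uint8_t)y));
      assert((uint8_t)(x * y) == (uint8_t)((uint8_t)x * (uint8_t)y));
      assert((uint8_t)(x & y) == (uint8_t)((uint8_t)x & (uint8_t)y));
      assert((uint8_t)(x | y) == (uint8_t)((uint8_t)x | (uint8_t)y));
      assert((uint8_t)(x ^ y) == (uint8_t)((uint8_t)x ^ (uint8_t)y));
    }
  return 0;
}

Shifts and divisions do not have this property, which is why the switch below restricts itself to this opcode list.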
+ SDValue BC = peekThroughOneUseBitcasts(Op); + return ISD::isBuildVectorOfConstantSDNodes(BC.getNode()); }; auto TruncateArithmetic = [&](SDValue N0, SDValue N1) { @@ -36526,7 +38134,7 @@ static SDValue combineTruncatedArithmetic(SDNode *N, SelectionDAG &DAG, }; // Don't combine if the operation has other uses. - if (!N->isOnlyUserOf(Src.getNode())) + if (!Src.hasOneUse()) return SDValue(); // Only support vector truncation for now. @@ -36544,7 +38152,7 @@ static SDValue combineTruncatedArithmetic(SDNode *N, SelectionDAG &DAG, SDValue Op0 = Src.getOperand(0); SDValue Op1 = Src.getOperand(1); if (TLI.isOperationLegalOrPromote(Opcode, VT) && - IsRepeatedOpOrFreeTruncation(Op0, Op1)) + (Op0 == Op1 || IsFreeTruncation(Op0) || IsFreeTruncation(Op1))) return TruncateArithmetic(Op0, Op1); break; } @@ -36557,11 +38165,20 @@ static SDValue combineTruncatedArithmetic(SDNode *N, SelectionDAG &DAG, return TruncateArithmetic(Src.getOperand(0), Src.getOperand(1)); LLVM_FALLTHROUGH; case ISD::ADD: { - // TODO: ISD::SUB should be here but interferes with combineSubToSubus. SDValue Op0 = Src.getOperand(0); SDValue Op1 = Src.getOperand(1); if (TLI.isOperationLegal(Opcode, VT) && - IsRepeatedOpOrFreeTruncation(Op0, Op1)) + (Op0 == Op1 || IsFreeTruncation(Op0) || IsFreeTruncation(Op1))) + return TruncateArithmetic(Op0, Op1); + break; + } + case ISD::SUB: { + // TODO: ISD::SUB We are conservative and require both sides to be freely + // truncatable to avoid interfering with combineSubToSubus. + SDValue Op0 = Src.getOperand(0); + SDValue Op1 = Src.getOperand(1); + if (TLI.isOperationLegal(Opcode, VT) && + (Op0 == Op1 || (IsFreeTruncation(Op0) && IsFreeTruncation(Op1)))) return TruncateArithmetic(Op0, Op1); break; } @@ -36701,8 +38318,7 @@ static SDValue combineVectorSignBitsTruncation(SDNode *N, const SDLoc &DL, // Use PACKUS if the input has zero-bits that extend all the way to the // packed/truncated value. e.g. masks, zext_in_reg, etc. - KnownBits Known; - DAG.computeKnownBits(In, Known); + KnownBits Known = DAG.computeKnownBits(In); unsigned NumLeadingZeroBits = Known.countMinLeadingZeros(); if (NumLeadingZeroBits >= (InSVT.getSizeInBits() - NumPackedZeroBits)) return truncateVectorWithPACK(X86ISD::PACKUS, VT, In, DL, DAG, Subtarget); @@ -36733,9 +38349,11 @@ static SDValue combinePMULH(SDValue Src, EVT VT, const SDLoc &DL, if (!Subtarget.hasSSE2()) return SDValue(); - // Only handle vXi16 types that are at least 128-bits. + // Only handle vXi16 types that are at least 128-bits unless they will be + // widened. if (!VT.isVector() || VT.getVectorElementType() != MVT::i16 || - VT.getVectorNumElements() < 8) + (!ExperimentalVectorWideningLegalization && + VT.getVectorNumElements() < 8)) return SDValue(); // Input type should be vXi32. @@ -36951,29 +38569,72 @@ static SDValue combineTruncate(SDNode *N, SelectionDAG &DAG, /// Returns the negated value if the node \p N flips sign of FP value. /// -/// FP-negation node may have different forms: FNEG(x) or FXOR (x, 0x80000000). +/// FP-negation node may have different forms: FNEG(x), FXOR (x, 0x80000000) +/// or FSUB(0, x) /// AVX512F does not have FXOR, so FNEG is lowered as /// (bitcast (xor (bitcast x), (bitcast ConstantFP(0x80000000)))). /// In this case we go though all bitcasts. -static SDValue isFNEG(SDNode *N) { +/// This also recognizes splat of a negated value and returns the splat of that +/// value. 
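isFNEG can treat an XOR with the sign-bit constant as a negation because IEEE-754 negation is exactly a sign-bit flip, which is also why the FSUB form and the constant checks only accept sign-bit masks. A standalone sketch (plain C++, separate from the patch):

#include <cassert>
#include <cstdint>
#include <cstring>

static float xor_signbit(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  bits ^= 0x80000000u;                // the sign-bit mask the combine looks for
  std::memcpy(&f, &bits, sizeof(bits));
  return f;
}

int main() {
  const float vals[] = {0.0f, -0.0f, 1.5f, -3.25f, 1e30f};
  for (float f : vals) {
    float n = xor_signbit(f);
    float neg = -f;
    // Bit-for-bit identical to negation, including for +/-0.0.
    assert(std::memcmp(&n, &neg, sizeof(float)) == 0);
  }
  return 0;
}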
+static SDValue isFNEG(SelectionDAG &DAG, SDNode *N) { if (N->getOpcode() == ISD::FNEG) return N->getOperand(0); SDValue Op = peekThroughBitcasts(SDValue(N, 0)); - if (Op.getOpcode() != X86ISD::FXOR && Op.getOpcode() != ISD::XOR) + auto VT = Op->getValueType(0); + if (auto SVOp = dyn_cast<ShuffleVectorSDNode>(Op.getNode())) { + // For a VECTOR_SHUFFLE(VEC1, VEC2), if the VEC2 is undef, then the negate + // of this is VECTOR_SHUFFLE(-VEC1, UNDEF). The mask can be anything here. + if (!SVOp->getOperand(1).isUndef()) + return SDValue(); + if (SDValue NegOp0 = isFNEG(DAG, SVOp->getOperand(0).getNode())) + return DAG.getVectorShuffle(VT, SDLoc(SVOp), NegOp0, DAG.getUNDEF(VT), + SVOp->getMask()); + return SDValue(); + } + unsigned Opc = Op.getOpcode(); + if (Opc == ISD::INSERT_VECTOR_ELT) { + // Negate of INSERT_VECTOR_ELT(UNDEF, V, INDEX) is INSERT_VECTOR_ELT(UNDEF, + // -V, INDEX). + SDValue InsVector = Op.getOperand(0); + SDValue InsVal = Op.getOperand(1); + if (!InsVector.isUndef()) + return SDValue(); + if (SDValue NegInsVal = isFNEG(DAG, InsVal.getNode())) + return DAG.getNode(ISD::INSERT_VECTOR_ELT, SDLoc(Op), VT, InsVector, + NegInsVal, Op.getOperand(2)); + return SDValue(); + } + + if (Opc != X86ISD::FXOR && Opc != ISD::XOR && Opc != ISD::FSUB) return SDValue(); SDValue Op1 = peekThroughBitcasts(Op.getOperand(1)); if (!Op1.getValueType().isFloatingPoint()) return SDValue(); - // Extract constant bits and see if they are all sign bit masks. + SDValue Op0 = peekThroughBitcasts(Op.getOperand(0)); + + // For XOR and FXOR, we want to check if constant bits of Op1 are sign bit + // masks. For FSUB, we have to check if constant bits of Op0 are sign bit + // masks and hence we swap the operands. + if (Opc == ISD::FSUB) + std::swap(Op0, Op1); + APInt UndefElts; SmallVector<APInt, 16> EltBits; + // Extract constant bits and see if they are all sign bit masks. Ignore the + // undef elements. if (getTargetConstantBitsFromNode(Op1, Op1.getScalarValueSizeInBits(), - UndefElts, EltBits, false, false)) - if (llvm::all_of(EltBits, [](APInt &I) { return I.isSignMask(); })) - return peekThroughBitcasts(Op.getOperand(0)); + UndefElts, EltBits, + /* AllowWholeUndefs */ true, + /* AllowPartialUndefs */ false)) { + for (unsigned I = 0, E = EltBits.size(); I < E; I++) + if (!UndefElts[I] && !EltBits[I].isSignMask()) + return SDValue(); + + return peekThroughBitcasts(Op0); + } return SDValue(); } @@ -36982,8 +38643,9 @@ static SDValue isFNEG(SDNode *N) { static SDValue combineFneg(SDNode *N, SelectionDAG &DAG, const X86Subtarget &Subtarget) { EVT OrigVT = N->getValueType(0); - SDValue Arg = isFNEG(N); - assert(Arg.getNode() && "N is expected to be an FNEG node"); + SDValue Arg = isFNEG(DAG, N); + if (!Arg) + return SDValue(); EVT VT = Arg.getValueType(); EVT SVT = VT.getScalarType(); @@ -37033,25 +38695,27 @@ static SDValue lowerX86FPLogicOp(SDNode *N, SelectionDAG &DAG, const X86Subtarget &Subtarget) { MVT VT = N->getSimpleValueType(0); // If we have integer vector types available, use the integer opcodes. 
- if (VT.isVector() && Subtarget.hasSSE2()) { - SDLoc dl(N); + if (!VT.isVector() || !Subtarget.hasSSE2()) + return SDValue(); - MVT IntVT = MVT::getVectorVT(MVT::i64, VT.getSizeInBits() / 64); + SDLoc dl(N); - SDValue Op0 = DAG.getBitcast(IntVT, N->getOperand(0)); - SDValue Op1 = DAG.getBitcast(IntVT, N->getOperand(1)); - unsigned IntOpcode; - switch (N->getOpcode()) { - default: llvm_unreachable("Unexpected FP logic op"); - case X86ISD::FOR: IntOpcode = ISD::OR; break; - case X86ISD::FXOR: IntOpcode = ISD::XOR; break; - case X86ISD::FAND: IntOpcode = ISD::AND; break; - case X86ISD::FANDN: IntOpcode = X86ISD::ANDNP; break; - } - SDValue IntOp = DAG.getNode(IntOpcode, dl, IntVT, Op0, Op1); - return DAG.getBitcast(VT, IntOp); + unsigned IntBits = VT.getScalarSizeInBits(); + MVT IntSVT = MVT::getIntegerVT(IntBits); + MVT IntVT = MVT::getVectorVT(IntSVT, VT.getSizeInBits() / IntBits); + + SDValue Op0 = DAG.getBitcast(IntVT, N->getOperand(0)); + SDValue Op1 = DAG.getBitcast(IntVT, N->getOperand(1)); + unsigned IntOpcode; + switch (N->getOpcode()) { + default: llvm_unreachable("Unexpected FP logic op"); + case X86ISD::FOR: IntOpcode = ISD::OR; break; + case X86ISD::FXOR: IntOpcode = ISD::XOR; break; + case X86ISD::FAND: IntOpcode = ISD::AND; break; + case X86ISD::FANDN: IntOpcode = X86ISD::ANDNP; break; } - return SDValue(); + SDValue IntOp = DAG.getNode(IntOpcode, dl, IntVT, Op0, Op1); + return DAG.getBitcast(VT, IntOp); } @@ -37098,9 +38762,7 @@ static SDValue combineXor(SDNode *N, SelectionDAG &DAG, if (SDValue FPLogic = convertIntLogicToFPLogic(N, DAG, Subtarget)) return FPLogic; - if (isFNEG(N)) - return combineFneg(N, DAG, Subtarget); - return SDValue(); + return combineFneg(N, DAG, Subtarget); } static SDValue combineBEXTR(SDNode *N, SelectionDAG &DAG, @@ -37112,8 +38774,6 @@ static SDValue combineBEXTR(SDNode *N, SelectionDAG &DAG, unsigned NumBits = VT.getSizeInBits(); const TargetLowering &TLI = DAG.getTargetLoweringInfo(); - TargetLowering::TargetLoweringOpt TLO(DAG, !DCI.isBeforeLegalize(), - !DCI.isBeforeLegalizeOps()); // TODO - Constant Folding. if (auto *Cst1 = dyn_cast<ConstantSDNode>(Op1)) { @@ -37127,12 +38787,9 @@ static SDValue combineBEXTR(SDNode *N, SelectionDAG &DAG, } // Only bottom 16-bits of the control bits are required. 
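The demanded-bits simplification on the BEXTR control ("only bottom 16-bits") reflects the instruction's encoding: bits 7:0 of the control are the start position, bits 15:8 the length, and the hardware ignores everything above bit 15. A scalar model of the 32-bit form (plain C++, separate from the patch, simplified to the common cases):

#include <cassert>
#include <cstdint>

static uint32_t bextr32(uint32_t x, uint32_t control) {
  unsigned start = control & 0xFF, len = (control >> 8) & 0xFF;
  if (start >= 32 || len == 0) return 0;
  uint32_t v = x >> start;
  return len >= 32 ? v : v & ((1u << len) - 1);
}

int main() {
  uint32_t x = 0xDEADBEEF;
  assert(bextr32(x, (8u << 8) | 4u) == 0xEE);  // 8 bits starting at bit 4
  // Control bits above bit 15 are ignored, which is what the
  // SimplifyDemandedBits call encodes.
  assert(bextr32(x, 0xABCD0000u | (8u << 8) | 4u) == bextr32(x, (8u << 8) | 4u));
  return 0;
}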
- KnownBits Known; APInt DemandedMask(APInt::getLowBitsSet(NumBits, 16)); - if (TLI.SimplifyDemandedBits(Op1, DemandedMask, Known, TLO)) { - DCI.CommitTargetLoweringOpt(TLO); + if (TLI.SimplifyDemandedBits(Op1, DemandedMask, DCI)) return SDValue(N, 0); - } return SDValue(); } @@ -37233,9 +38890,8 @@ static SDValue combineFOr(SDNode *N, SelectionDAG &DAG, if (isNullFPScalarOrVectorConst(N->getOperand(1))) return N->getOperand(0); - if (isFNEG(N)) - if (SDValue NewVal = combineFneg(N, DAG, Subtarget)) - return NewVal; + if (SDValue NewVal = combineFneg(N, DAG, Subtarget)) + return NewVal; return lowerX86FPLogicOp(N, DAG, Subtarget); } @@ -37320,26 +38976,47 @@ static SDValue combineFMinNumFMaxNum(SDNode *N, SelectionDAG &DAG, return DAG.getSelect(DL, VT, IsOp0Nan, Op1, MinOrMax); } +static SDValue combineX86INT_TO_FP(SDNode *N, SelectionDAG &DAG, + TargetLowering::DAGCombinerInfo &DCI) { + EVT VT = N->getValueType(0); + const TargetLowering &TLI = DAG.getTargetLoweringInfo(); + + APInt KnownUndef, KnownZero; + APInt DemandedElts = APInt::getAllOnesValue(VT.getVectorNumElements()); + if (TLI.SimplifyDemandedVectorElts(SDValue(N, 0), DemandedElts, KnownUndef, + KnownZero, DCI)) + return SDValue(N, 0); + + return SDValue(); +} + /// Do target-specific dag combines on X86ISD::ANDNP nodes. static SDValue combineAndnp(SDNode *N, SelectionDAG &DAG, TargetLowering::DAGCombinerInfo &DCI, const X86Subtarget &Subtarget) { + MVT VT = N->getSimpleValueType(0); + // ANDNP(0, x) -> x if (ISD::isBuildVectorAllZeros(N->getOperand(0).getNode())) return N->getOperand(1); // ANDNP(x, 0) -> 0 if (ISD::isBuildVectorAllZeros(N->getOperand(1).getNode())) - return getZeroVector(N->getSimpleValueType(0), Subtarget, DAG, SDLoc(N)); + return DAG.getConstant(0, SDLoc(N), VT); - EVT VT = N->getValueType(0); + // Turn ANDNP back to AND if input is inverted. + if (VT.isVector() && N->getOperand(0).getOpcode() == ISD::XOR && + ISD::isBuildVectorAllOnes(N->getOperand(0).getOperand(1).getNode())) { + return DAG.getNode(ISD::AND, SDLoc(N), VT, + N->getOperand(0).getOperand(0), N->getOperand(1)); + } // Attempt to recursively combine a bitmask ANDNP with shuffles. if (VT.isVector() && (VT.getScalarSizeInBits() % 8) == 0) { SDValue Op(N, 0); if (SDValue Res = combineX86ShufflesRecursively( {Op}, 0, Op, {0}, {}, /*Depth*/ 1, - /*HasVarMask*/ false, DAG, Subtarget)) + /*HasVarMask*/ false, /*AllowVarMask*/ true, DAG, Subtarget)) return Res; } @@ -37502,36 +39179,6 @@ static SDValue promoteExtBeforeAdd(SDNode *Ext, SelectionDAG &DAG, return DAG.getNode(ISD::ADD, SDLoc(Add), VT, NewExt, NewConstant, Flags); } -/// (i8,i32 {s/z}ext ({s/u}divrem (i8 x, i8 y)) -> -/// (i8,i32 ({s/u}divrem_sext_hreg (i8 x, i8 y) -/// This exposes the {s/z}ext to the sdivrem lowering, so that it directly -/// extends from AH (which we otherwise need to do contortions to access). -static SDValue getDivRem8(SDNode *N, SelectionDAG &DAG) { - SDValue N0 = N->getOperand(0); - auto OpcodeN = N->getOpcode(); - auto OpcodeN0 = N0.getOpcode(); - if (!((OpcodeN == ISD::SIGN_EXTEND && OpcodeN0 == ISD::SDIVREM) || - (OpcodeN == ISD::ZERO_EXTEND && OpcodeN0 == ISD::UDIVREM))) - return SDValue(); - - EVT VT = N->getValueType(0); - EVT InVT = N0.getValueType(); - if (N0.getResNo() != 1 || InVT != MVT::i8 || - !(VT == MVT::i32 || VT == MVT::i64)) - return SDValue(); - - SDVTList NodeTys = DAG.getVTList(MVT::i8, MVT::i32); - auto DivRemOpcode = OpcodeN0 == ISD::SDIVREM ? 
X86ISD::SDIVREM8_SEXT_HREG - : X86ISD::UDIVREM8_ZEXT_HREG; - SDValue R = DAG.getNode(DivRemOpcode, SDLoc(N), NodeTys, N0.getOperand(0), - N0.getOperand(1)); - DAG.ReplaceAllUsesOfValueWith(N0.getValue(0), R.getValue(0)); - // If this was a 64-bit extend, complete it. - if (VT == MVT::i64) - return DAG.getNode(OpcodeN, SDLoc(N), VT, R.getValue(1)); - return R.getValue(1); -} - // If we face {ANY,SIGN,ZERO}_EXTEND that is applied to a CMOV with constant // operands and the result of CMOV is not used anywhere else - promote CMOV // itself instead of promoting its result. This could be beneficial, because: @@ -37685,6 +39332,9 @@ combineToExtendBoolVectorInReg(SDNode *N, SelectionDAG &DAG, static SDValue combineToExtendVectorInReg(SDNode *N, SelectionDAG &DAG, TargetLowering::DAGCombinerInfo &DCI, const X86Subtarget &Subtarget) { + if (ExperimentalVectorWideningLegalization) + return SDValue(); + unsigned Opcode = N->getOpcode(); if (Opcode != ISD::SIGN_EXTEND && Opcode != ISD::ZERO_EXTEND) return SDValue(); @@ -37699,17 +39349,33 @@ static SDValue combineToExtendVectorInReg(SDNode *N, SelectionDAG &DAG, EVT InVT = N0.getValueType(); EVT InSVT = InVT.getScalarType(); + // FIXME: Generic DAGCombiner previously had a bug that would cause a + // sign_extend of setcc to sometimes return the original node and tricked it + // into thinking CombineTo was used which prevented the target combines from + // running. + // Earlying out here to avoid regressions like this + // (v4i32 (sext (v4i1 (setcc (v4i16))))) + // Becomes + // (v4i32 (sext_invec (v8i16 (concat (v4i16 (setcc (v4i16))), undef)))) + // Type legalized to + // (v4i32 (sext_invec (v8i16 (trunc_invec (v4i32 (setcc (v4i32))))))) + // Leading to a packssdw+pmovsxwd + // We could write a DAG combine to fix this, but really we shouldn't be + // creating sext_invec that's forcing v8i16 into the DAG. + if (N0.getOpcode() == ISD::SETCC) + return SDValue(); + // Input type must be a vector and we must be extending legal integer types. - if (!VT.isVector()) + if (!VT.isVector() || VT.getVectorNumElements() < 2) return SDValue(); if (SVT != MVT::i64 && SVT != MVT::i32 && SVT != MVT::i16) return SDValue(); if (InSVT != MVT::i32 && InSVT != MVT::i16 && InSVT != MVT::i8) return SDValue(); - // On AVX2+ targets, if the input/output types are both legal then we will be - // able to use SIGN_EXTEND/ZERO_EXTEND directly. - if (Subtarget.hasInt256() && DAG.getTargetLoweringInfo().isTypeLegal(VT) && + // If the input/output types are both legal then we have at least AVX1 and + // we will be able to use SIGN_EXTEND/ZERO_EXTEND directly. + if (DAG.getTargetLoweringInfo().isTypeLegal(VT) && DAG.getTargetLoweringInfo().isTypeLegal(InVT)) return SDValue(); @@ -37737,16 +39403,16 @@ static SDValue combineToExtendVectorInReg(SDNode *N, SelectionDAG &DAG, DAG.getIntPtrConstant(0, DL)); } - // If target-size is 128-bits (or 256-bits on AVX2 target), then convert to + // If target-size is 128-bits (or 256-bits on AVX target), then convert to // ISD::*_EXTEND_VECTOR_INREG which ensures lowering to X86ISD::V*EXT. // Also use this if we don't have SSE41 to allow the legalizer do its job. if (!Subtarget.hasSSE41() || VT.is128BitVector() || - (VT.is256BitVector() && Subtarget.hasInt256()) || + (VT.is256BitVector() && Subtarget.hasAVX()) || (VT.is512BitVector() && Subtarget.useAVX512Regs())) { SDValue ExOp = ExtendVecSize(DL, N0, VT.getSizeInBits()); - return Opcode == ISD::SIGN_EXTEND - ? 
DAG.getSignExtendVectorInReg(ExOp, DL, VT) - : DAG.getZeroExtendVectorInReg(ExOp, DL, VT); + Opcode = Opcode == ISD::SIGN_EXTEND ? ISD::SIGN_EXTEND_VECTOR_INREG + : ISD::ZERO_EXTEND_VECTOR_INREG; + return DAG.getNode(Opcode, DL, VT, ExOp); } auto SplitAndExtendInReg = [&](unsigned SplitSize) { @@ -37755,22 +39421,23 @@ static SDValue combineToExtendVectorInReg(SDNode *N, SelectionDAG &DAG, EVT SubVT = EVT::getVectorVT(*DAG.getContext(), SVT, NumSubElts); EVT InSubVT = EVT::getVectorVT(*DAG.getContext(), InSVT, NumSubElts); + unsigned IROpc = Opcode == ISD::SIGN_EXTEND ? ISD::SIGN_EXTEND_VECTOR_INREG + : ISD::ZERO_EXTEND_VECTOR_INREG; + SmallVector<SDValue, 8> Opnds; for (unsigned i = 0, Offset = 0; i != NumVecs; ++i, Offset += NumSubElts) { SDValue SrcVec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, InSubVT, N0, DAG.getIntPtrConstant(Offset, DL)); SrcVec = ExtendVecSize(DL, SrcVec, SplitSize); - SrcVec = Opcode == ISD::SIGN_EXTEND - ? DAG.getSignExtendVectorInReg(SrcVec, DL, SubVT) - : DAG.getZeroExtendVectorInReg(SrcVec, DL, SubVT); + SrcVec = DAG.getNode(IROpc, DL, SubVT, SrcVec); Opnds.push_back(SrcVec); } return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, Opnds); }; - // On pre-AVX2 targets, split into 128-bit nodes of + // On pre-AVX targets, split into 128-bit nodes of // ISD::*_EXTEND_VECTOR_INREG. - if (!Subtarget.hasInt256() && !(VT.getSizeInBits() % 128)) + if (!Subtarget.hasAVX() && !(VT.getSizeInBits() % 128)) return SplitAndExtendInReg(128); // On pre-AVX512 targets, split into 256-bit nodes of @@ -37832,9 +39499,6 @@ static SDValue combineSext(SDNode *N, SelectionDAG &DAG, EVT InVT = N0.getValueType(); SDLoc DL(N); - if (SDValue DivRem8 = getDivRem8(N, DAG)) - return DivRem8; - if (SDValue NewCMov = combineToExtendCMOV(N, DAG)) return NewCMov; @@ -37861,7 +39525,7 @@ static SDValue combineSext(SDNode *N, SelectionDAG &DAG, return V; if (VT.isVector()) - if (SDValue R = WidenMaskArithmetic(N, DAG, Subtarget)) + if (SDValue R = PromoteMaskArithmetic(N, DAG, Subtarget)) return R; if (SDValue NewAdd = promoteExtBeforeAdd(N, DAG, Subtarget)) @@ -37920,7 +39584,7 @@ static SDValue combineFMA(SDNode *N, SelectionDAG &DAG, SDValue C = N->getOperand(2); auto invertIfNegative = [&DAG](SDValue &V) { - if (SDValue NegVal = isFNEG(V.getNode())) { + if (SDValue NegVal = isFNEG(DAG, V.getNode())) { V = DAG.getBitcast(V.getValueType(), NegVal); return true; } @@ -37928,7 +39592,7 @@ static SDValue combineFMA(SDNode *N, SelectionDAG &DAG, // new extract from the FNEG input. 
if (V.getOpcode() == ISD::EXTRACT_VECTOR_ELT && isNullConstant(V.getOperand(1))) { - if (SDValue NegVal = isFNEG(V.getOperand(0).getNode())) { + if (SDValue NegVal = isFNEG(DAG, V.getOperand(0).getNode())) { NegVal = DAG.getBitcast(V.getOperand(0).getValueType(), NegVal); V = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SDLoc(V), V.getValueType(), NegVal, V.getOperand(1)); @@ -37961,7 +39625,7 @@ static SDValue combineFMADDSUB(SDNode *N, SelectionDAG &DAG, SDLoc dl(N); EVT VT = N->getValueType(0); - SDValue NegVal = isFNEG(N->getOperand(2).getNode()); + SDValue NegVal = isFNEG(DAG, N->getOperand(2).getNode()); if (!NegVal) return SDValue(); @@ -38032,12 +39696,9 @@ static SDValue combineZext(SDNode *N, SelectionDAG &DAG, return V; if (VT.isVector()) - if (SDValue R = WidenMaskArithmetic(N, DAG, Subtarget)) + if (SDValue R = PromoteMaskArithmetic(N, DAG, Subtarget)) return R; - if (SDValue DivRem8 = getDivRem8(N, DAG)) - return DivRem8; - if (SDValue NewAdd = promoteExtBeforeAdd(N, DAG, Subtarget)) return NewAdd; @@ -38079,12 +39740,15 @@ static SDValue combineVectorSizedSetCCEquality(SDNode *SetCC, SelectionDAG &DAG, return SDValue(); // TODO: Use PXOR + PTEST for SSE4.1 or later? - // TODO: Add support for AVX-512. EVT VT = SetCC->getValueType(0); SDLoc DL(SetCC); if ((OpSize == 128 && Subtarget.hasSSE2()) || - (OpSize == 256 && Subtarget.hasAVX2())) { - EVT VecVT = OpSize == 128 ? MVT::v16i8 : MVT::v32i8; + (OpSize == 256 && Subtarget.hasAVX2()) || + (OpSize == 512 && Subtarget.useAVX512Regs())) { + EVT VecVT = OpSize == 512 ? MVT::v16i32 : + OpSize == 256 ? MVT::v32i8 : + MVT::v16i8; + EVT CmpVT = OpSize == 512 ? MVT::v16i1 : VecVT; SDValue Cmp; if (IsOrXorXorCCZero) { // This is a bitwise-combined equality comparison of 2 pairs of vectors: @@ -38095,14 +39759,18 @@ static SDValue combineVectorSizedSetCCEquality(SDNode *SetCC, SelectionDAG &DAG, SDValue B = DAG.getBitcast(VecVT, X.getOperand(0).getOperand(1)); SDValue C = DAG.getBitcast(VecVT, X.getOperand(1).getOperand(0)); SDValue D = DAG.getBitcast(VecVT, X.getOperand(1).getOperand(1)); - SDValue Cmp1 = DAG.getSetCC(DL, VecVT, A, B, ISD::SETEQ); - SDValue Cmp2 = DAG.getSetCC(DL, VecVT, C, D, ISD::SETEQ); - Cmp = DAG.getNode(ISD::AND, DL, VecVT, Cmp1, Cmp2); + SDValue Cmp1 = DAG.getSetCC(DL, CmpVT, A, B, ISD::SETEQ); + SDValue Cmp2 = DAG.getSetCC(DL, CmpVT, C, D, ISD::SETEQ); + Cmp = DAG.getNode(ISD::AND, DL, CmpVT, Cmp1, Cmp2); } else { SDValue VecX = DAG.getBitcast(VecVT, X); SDValue VecY = DAG.getBitcast(VecVT, Y); - Cmp = DAG.getSetCC(DL, VecVT, VecX, VecY, ISD::SETEQ); + Cmp = DAG.getSetCC(DL, CmpVT, VecX, VecY, ISD::SETEQ); } + // For 512-bits we want to emit a setcc that will lower to kortest. + if (OpSize == 512) + return DAG.getSetCC(DL, VT, DAG.getBitcast(MVT::i16, Cmp), + DAG.getConstant(0xFFFF, DL, MVT::i16), CC); // If all bytes match (bitmask is 0x(FFFF)FFFF), that's equality. // setcc i128 X, Y, eq --> setcc (pmovmskb (pcmpeqb X, Y)), 0xFFFF, eq // setcc i128 X, Y, ne --> setcc (pmovmskb (pcmpeqb X, Y)), 0xFFFF, ne @@ -38181,7 +39849,9 @@ static SDValue combineSetCC(SDNode *N, SelectionDAG &DAG, // NOTE: The element count check is to ignore operand types that need to // go through type promotion to a 128-bit vector. 
if (Subtarget.hasAVX512() && !Subtarget.hasBWI() && VT.isVector() && - VT.getVectorElementType() == MVT::i1 && VT.getVectorNumElements() > 4 && + VT.getVectorElementType() == MVT::i1 && + (ExperimentalVectorWideningLegalization || + VT.getVectorNumElements() > 4) && (OpVT.getVectorElementType() == MVT::i8 || OpVT.getVectorElementType() == MVT::i16)) { SDValue Setcc = DAG.getNode(ISD::SETCC, DL, OpVT, LHS, RHS, @@ -38202,10 +39872,11 @@ static SDValue combineMOVMSK(SDNode *N, SelectionDAG &DAG, TargetLowering::DAGCombinerInfo &DCI) { SDValue Src = N->getOperand(0); MVT SrcVT = Src.getSimpleValueType(); + MVT VT = N->getSimpleValueType(0); // Perform constant folding. if (ISD::isBuildVectorOfConstantSDNodes(Src.getNode())) { - assert(N->getValueType(0) == MVT::i32 && "Unexpected result type"); + assert(VT== MVT::i32 && "Unexpected result type"); APInt Imm(32, 0); for (unsigned Idx = 0, e = Src.getNumOperands(); Idx < e; ++Idx) { SDValue In = Src.getOperand(Idx); @@ -38213,20 +39884,53 @@ static SDValue combineMOVMSK(SDNode *N, SelectionDAG &DAG, cast<ConstantSDNode>(In)->getAPIntValue().isNegative()) Imm.setBit(Idx); } - return DAG.getConstant(Imm, SDLoc(N), N->getValueType(0)); + return DAG.getConstant(Imm, SDLoc(N), VT); } - const TargetLowering &TLI = DAG.getTargetLoweringInfo(); - TargetLowering::TargetLoweringOpt TLO(DAG, !DCI.isBeforeLegalize(), - !DCI.isBeforeLegalizeOps()); + // Look through int->fp bitcasts that don't change the element width. + if (Src.getOpcode() == ISD::BITCAST && Src.hasOneUse() && + SrcVT.isFloatingPoint() && + Src.getOperand(0).getValueType() == + EVT(SrcVT).changeVectorElementTypeToInteger()) + Src = Src.getOperand(0); - // MOVMSK only uses the MSB from each vector element. - KnownBits Known; - APInt DemandedMask(APInt::getSignMask(SrcVT.getScalarSizeInBits())); - if (TLI.SimplifyDemandedBits(Src, DemandedMask, Known, TLO)) { - DCI.AddToWorklist(Src.getNode()); - DCI.CommitTargetLoweringOpt(TLO); + // Simplify the inputs. + const TargetLowering &TLI = DAG.getTargetLoweringInfo(); + APInt DemandedMask(APInt::getAllOnesValue(VT.getScalarSizeInBits())); + if (TLI.SimplifyDemandedBits(SDValue(N, 0), DemandedMask, DCI)) return SDValue(N, 0); + + // Combine (movmsk (setne (and X, (1 << C)), 0)) -> (movmsk (X << C)). + // Only do this when the setcc input and output types are the same and the + // setcc and the 'and' node have a single use. + // FIXME: Support 256-bits with AVX1. The movmsk is split, but the and isn't. + APInt SplatVal; + if (Src.getOpcode() == ISD::SETCC && Src.hasOneUse() && + Src.getOperand(0).getValueType() == Src.getValueType() && + cast<CondCodeSDNode>(Src.getOperand(2))->get() == ISD::SETNE && + ISD::isBuildVectorAllZeros(Src.getOperand(1).getNode()) && + Src.getOperand(0).getOpcode() == ISD::AND) { + SDValue And = Src.getOperand(0); + if (And.hasOneUse() && + ISD::isConstantSplatVector(And.getOperand(1).getNode(), SplatVal) && + SplatVal.isPowerOf2()) { + MVT VT = Src.getSimpleValueType(); + unsigned BitWidth = VT.getScalarSizeInBits(); + unsigned ShAmt = BitWidth - SplatVal.logBase2() - 1; + SDLoc DL(And); + SDValue X = And.getOperand(0); + // If the element type is i8, we need to bitcast to i16 to use a legal + // shift. If we wait until lowering we end up with an extra and to bits + // from crossing the 8-bit elements, but we don't care about that here. 
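[Illustrative aside, not part of the patch] The MOVMSK combine above replaces "test bit C of each element, compare against zero" with a left shift that moves bit C into each element's sign bit, and the comment explains why an i8 vector can be shifted as i16: bits that cross a byte boundary never land in a byte's MSB, which is all MOVMSK reads. A standalone SSE2 sketch of that equivalence — the tested bit and element values are arbitrary:

#include <immintrin.h>
#include <cassert>

int main() {
  const int C = 3; // bit tested in each i8 element
  __m128i X = _mm_set_epi8(1, 2, 4, 8, 0x0F, 0x10, 0x7F, (char)0xFF,
                           0, 3, 5, 9, 0x0E, 0x18, 0x40, (char)0x88);

  // movmsk(setne(and(X, 1 << C), 0)) -- the pattern before the combine.
  __m128i Bit = _mm_and_si128(X, _mm_set1_epi8(1 << C));
  __m128i IsZero = _mm_cmpeq_epi8(Bit, _mm_setzero_si128());
  int MaskA = _mm_movemask_epi8(_mm_xor_si128(IsZero, _mm_set1_epi8(-1)));

  // movmsk(X << (7 - C)) -- the form after the combine. There is no 8-bit
  // shift, so shift as i16 (the bitcast trick from the patch); bits that
  // cross a byte boundary never reach a byte's MSB, which is all MOVMSK uses.
  int MaskB = _mm_movemask_epi8(_mm_slli_epi16(X, 7 - C));

  assert(MaskA == MaskB);
  return 0;
}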
+ if (VT.getVectorElementType() == MVT::i8) { + VT = MVT::getVectorVT(MVT::i16, VT.getVectorNumElements() / 2); + X = DAG.getBitcast(VT, X); + } + SDValue Shl = DAG.getNode(ISD::SHL, DL, VT, X, + DAG.getConstant(ShAmt, DL, VT)); + SDValue Cast = DAG.getBitcast(SrcVT, Shl); + return DAG.getNode(X86ISD::MOVMSK, SDLoc(N), N->getValueType(0), Cast); + } } return SDValue(); @@ -38296,16 +40000,10 @@ static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG, // With AVX2 we only demand the upper bit of the mask. if (!Subtarget.hasAVX512()) { const TargetLowering &TLI = DAG.getTargetLoweringInfo(); - TargetLowering::TargetLoweringOpt TLO(DAG, !DCI.isBeforeLegalize(), - !DCI.isBeforeLegalizeOps()); SDValue Mask = N->getOperand(2); - KnownBits Known; APInt DemandedMask(APInt::getSignMask(Mask.getScalarValueSizeInBits())); - if (TLI.SimplifyDemandedBits(Mask, DemandedMask, Known, TLO)) { - DCI.AddToWorklist(Mask.getNode()); - DCI.CommitTargetLoweringOpt(TLO); + if (TLI.SimplifyDemandedBits(Mask, DemandedMask, DCI)) return SDValue(N, 0); - } } return SDValue(); @@ -38396,7 +40094,7 @@ static SDValue combineUIntToFP(SDNode *N, SelectionDAG &DAG, EVT VT = N->getValueType(0); EVT InVT = Op0.getValueType(); - // UINT_TO_FP(vXi1) -> SINT_TO_FP(SEXT(vXi1 to vXi32)) + // UINT_TO_FP(vXi1) -> SINT_TO_FP(ZEXT(vXi1 to vXi32)) // UINT_TO_FP(vXi8) -> SINT_TO_FP(ZEXT(vXi8 to vXi32)) // UINT_TO_FP(vXi16) -> SINT_TO_FP(ZEXT(vXi16 to vXi32)) if (InVT.isVector() && InVT.getScalarSizeInBits() < 32) { @@ -38460,7 +40158,8 @@ static SDValue combineSIntToFP(SDNode *N, SelectionDAG &DAG, // Transform (SINT_TO_FP (i64 ...)) into an x87 operation if we have // a 32-bit target where SSE doesn't support i64->FP operations. - if (!Subtarget.useSoftFloat() && Op0.getOpcode() == ISD::LOAD) { + if (!Subtarget.useSoftFloat() && Subtarget.hasX87() && + Op0.getOpcode() == ISD::LOAD) { LoadSDNode *Ld = cast<LoadSDNode>(Op0.getNode()); EVT LdVT = Ld->getValueType(0); @@ -38485,6 +40184,159 @@ static SDValue combineSIntToFP(SDNode *N, SelectionDAG &DAG, return SDValue(); } +static bool needCarryOrOverflowFlag(SDValue Flags) { + assert(Flags.getValueType() == MVT::i32 && "Unexpected VT!"); + + for (SDNode::use_iterator UI = Flags->use_begin(), UE = Flags->use_end(); + UI != UE; ++UI) { + SDNode *User = *UI; + + X86::CondCode CC; + switch (User->getOpcode()) { + default: + // Be conservative. + return true; + case X86ISD::SETCC: + case X86ISD::SETCC_CARRY: + CC = (X86::CondCode)User->getConstantOperandVal(0); + break; + case X86ISD::BRCOND: + CC = (X86::CondCode)User->getConstantOperandVal(2); + break; + case X86ISD::CMOV: + CC = (X86::CondCode)User->getConstantOperandVal(2); + break; + } + + switch (CC) { + default: break; + case X86::COND_A: case X86::COND_AE: + case X86::COND_B: case X86::COND_BE: + case X86::COND_O: case X86::COND_NO: + case X86::COND_G: case X86::COND_GE: + case X86::COND_L: case X86::COND_LE: + return true; + } + } + + return false; +} + +static bool onlyZeroFlagUsed(SDValue Flags) { + assert(Flags.getValueType() == MVT::i32 && "Unexpected VT!"); + + for (SDNode::use_iterator UI = Flags->use_begin(), UE = Flags->use_end(); + UI != UE; ++UI) { + SDNode *User = *UI; + + unsigned CCOpNo; + switch (User->getOpcode()) { + default: + // Be conservative. 
+ return false; + case X86ISD::SETCC: CCOpNo = 0; break; + case X86ISD::SETCC_CARRY: CCOpNo = 0; break; + case X86ISD::BRCOND: CCOpNo = 2; break; + case X86ISD::CMOV: CCOpNo = 2; break; + } + + X86::CondCode CC = (X86::CondCode)User->getConstantOperandVal(CCOpNo); + if (CC != X86::COND_E && CC != X86::COND_NE) + return false; + } + + return true; +} + +static SDValue combineCMP(SDNode *N, SelectionDAG &DAG) { + // Only handle test patterns. + if (!isNullConstant(N->getOperand(1))) + return SDValue(); + + // If we have a CMP of a truncated binop, see if we can make a smaller binop + // and use its flags directly. + // TODO: Maybe we should try promoting compares that only use the zero flag + // first if we can prove the upper bits with computeKnownBits? + SDLoc dl(N); + SDValue Op = N->getOperand(0); + EVT VT = Op.getValueType(); + + // If we have a constant logical shift that's only used in a comparison + // against zero turn it into an equivalent AND. This allows turning it into + // a TEST instruction later. + if ((Op.getOpcode() == ISD::SRL || Op.getOpcode() == ISD::SHL) && + Op.hasOneUse() && isa<ConstantSDNode>(Op.getOperand(1)) && + onlyZeroFlagUsed(SDValue(N, 0))) { + EVT VT = Op.getValueType(); + unsigned BitWidth = VT.getSizeInBits(); + unsigned ShAmt = Op.getConstantOperandVal(1); + if (ShAmt < BitWidth) { // Avoid undefined shifts. + APInt Mask = Op.getOpcode() == ISD::SRL + ? APInt::getHighBitsSet(BitWidth, BitWidth - ShAmt) + : APInt::getLowBitsSet(BitWidth, BitWidth - ShAmt); + if (Mask.isSignedIntN(32)) { + Op = DAG.getNode(ISD::AND, dl, VT, Op.getOperand(0), + DAG.getConstant(Mask, dl, VT)); + return DAG.getNode(X86ISD::CMP, dl, MVT::i32, Op, + DAG.getConstant(0, dl, VT)); + } + } + } + + + // Look for a truncate with a single use. + if (Op.getOpcode() != ISD::TRUNCATE || !Op.hasOneUse()) + return SDValue(); + + Op = Op.getOperand(0); + + // Arithmetic op can only have one use. + if (!Op.hasOneUse()) + return SDValue(); + + unsigned NewOpc; + switch (Op.getOpcode()) { + default: return SDValue(); + case ISD::AND: + // Skip and with constant. We have special handling for and with immediate + // during isel to generate test instructions. + if (isa<ConstantSDNode>(Op.getOperand(1))) + return SDValue(); + NewOpc = X86ISD::AND; + break; + case ISD::OR: NewOpc = X86ISD::OR; break; + case ISD::XOR: NewOpc = X86ISD::XOR; break; + case ISD::ADD: + // If the carry or overflow flag is used, we can't truncate. + if (needCarryOrOverflowFlag(SDValue(N, 0))) + return SDValue(); + NewOpc = X86ISD::ADD; + break; + case ISD::SUB: + // If the carry or overflow flag is used, we can't truncate. + if (needCarryOrOverflowFlag(SDValue(N, 0))) + return SDValue(); + NewOpc = X86ISD::SUB; + break; + } + + // We found an op we can narrow. Truncate its inputs. + SDValue Op0 = DAG.getNode(ISD::TRUNCATE, dl, VT, Op.getOperand(0)); + SDValue Op1 = DAG.getNode(ISD::TRUNCATE, dl, VT, Op.getOperand(1)); + + // Use a X86 specific opcode to avoid DAG combine messing with it. + SDVTList VTs = DAG.getVTList(VT, MVT::i32); + Op = DAG.getNode(NewOpc, dl, VTs, Op0, Op1); + + // For AND, keep a CMP so that we can match the test pattern. + if (NewOpc == X86ISD::AND) + return DAG.getNode(X86ISD::CMP, dl, MVT::i32, Op, + DAG.getConstant(0, dl, VT)); + + // Return the flags. 
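[Illustrative aside, not part of the patch] combineCMP above relies on the identity that a constant logical shift whose result is only compared against zero can be replaced by an AND with exactly the bits the shift would have kept, which later matches a TEST instruction. A standalone check of that identity — the shift amount and sampling stride are arbitrary:

#include <cassert>
#include <cstdint>

int main() {
  const unsigned ShAmt = 5;
  for (uint64_t I = 0; I <= 0xFFFFFFFFull; I += 99991) { // sample the i32 range
    uint32_t X = uint32_t(I);

    // SRL: (X >> ShAmt) == 0  <=>  (X & HighBits) == 0
    uint32_t HighBits = ~uint32_t(0) << ShAmt; // bits [31:ShAmt]
    assert(((X >> ShAmt) == 0) == ((X & HighBits) == 0));

    // SHL: (X << ShAmt) == 0  <=>  (X & LowBits) == 0
    uint32_t LowBits = ~uint32_t(0) >> ShAmt;  // bits [31-ShAmt:0]
    assert(((X << ShAmt) == 0) == ((X & LowBits) == 0));
  }
  return 0;
}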
+ return Op.getValue(1); +} + static SDValue combineSBB(SDNode *N, SelectionDAG &DAG) { if (SDValue Flags = combineCarryThroughADD(N->getOperand(2))) { MVT VT = N->getSimpleValueType(0); @@ -38531,21 +40383,6 @@ static SDValue combineADC(SDNode *N, SelectionDAG &DAG, return SDValue(); } -/// Materialize "setb reg" as "sbb reg,reg", since it produces an all-ones bit -/// which is more useful than 0/1 in some cases. -static SDValue materializeSBB(SDNode *N, SDValue EFLAGS, SelectionDAG &DAG) { - SDLoc DL(N); - // "Condition code B" is also known as "the carry flag" (CF). - SDValue CF = DAG.getConstant(X86::COND_B, DL, MVT::i8); - SDValue SBB = DAG.getNode(X86ISD::SETCC_CARRY, DL, MVT::i8, CF, EFLAGS); - MVT VT = N->getSimpleValueType(0); - if (VT == MVT::i8) - return DAG.getNode(ISD::AND, DL, VT, SBB, DAG.getConstant(1, DL, VT)); - - assert(VT == MVT::i1 && "Unexpected type for SETCC node"); - return DAG.getNode(ISD::TRUNCATE, DL, MVT::i1, SBB); -} - /// If this is an add or subtract where one operand is produced by a cmp+setcc, /// then try to convert it to an ADC or SBB. This replaces TEST+SET+{ADD/SUB} /// with CMP+{ADC, SBB}. @@ -38616,13 +40453,11 @@ static SDValue combineAddOrSubToADCOrSBB(SDNode *N, SelectionDAG &DAG) { } if (CC == X86::COND_B) { - // X + SETB Z --> X + (mask SBB Z, Z) - // X - SETB Z --> X - (mask SBB Z, Z) - // TODO: Produce ADC/SBB here directly and avoid SETCC_CARRY? - SDValue SBB = materializeSBB(Y.getNode(), Y.getOperand(1), DAG); - if (SBB.getValueSizeInBits() != VT.getSizeInBits()) - SBB = DAG.getZExtOrTrunc(SBB, DL, VT); - return DAG.getNode(IsSub ? ISD::SUB : ISD::ADD, DL, VT, X, SBB); + // X + SETB Z --> adc X, 0 + // X - SETB Z --> sbb X, 0 + return DAG.getNode(IsSub ? X86ISD::SBB : X86ISD::ADC, DL, + DAG.getVTList(VT, MVT::i32), X, + DAG.getConstant(0, DL, VT), Y.getOperand(1)); } if (CC == X86::COND_A) { @@ -38640,10 +40475,9 @@ static SDValue combineAddOrSubToADCOrSBB(SDNode *N, SelectionDAG &DAG) { EFLAGS.getNode()->getVTList(), EFLAGS.getOperand(1), EFLAGS.getOperand(0)); SDValue NewEFLAGS = SDValue(NewSub.getNode(), EFLAGS.getResNo()); - SDValue SBB = materializeSBB(Y.getNode(), NewEFLAGS, DAG); - if (SBB.getValueSizeInBits() != VT.getSizeInBits()) - SBB = DAG.getZExtOrTrunc(SBB, DL, VT); - return DAG.getNode(IsSub ? ISD::SUB : ISD::ADD, DL, VT, X, SBB); + return DAG.getNode(IsSub ? X86ISD::SBB : X86ISD::ADC, DL, + DAG.getVTList(VT, MVT::i32), X, + DAG.getConstant(0, DL, VT), NewEFLAGS); } } @@ -38713,23 +40547,23 @@ static SDValue combineLoopMAddPattern(SDNode *N, SelectionDAG &DAG, if (!Subtarget.hasSSE2()) return SDValue(); - SDValue MulOp = N->getOperand(0); - SDValue Phi = N->getOperand(1); - - if (MulOp.getOpcode() != ISD::MUL) - std::swap(MulOp, Phi); - if (MulOp.getOpcode() != ISD::MUL) - return SDValue(); - - ShrinkMode Mode; - if (!canReduceVMulWidth(MulOp.getNode(), DAG, Mode) || Mode == MULU16) - return SDValue(); + SDValue Op0 = N->getOperand(0); + SDValue Op1 = N->getOperand(1); EVT VT = N->getValueType(0); // If the vector size is less than 128, or greater than the supported RegSize, // do not use PMADD. 
- if (VT.getVectorNumElements() < 8) + if (!VT.isVector() || VT.getVectorNumElements() < 8) + return SDValue(); + + if (Op0.getOpcode() != ISD::MUL) + std::swap(Op0, Op1); + if (Op0.getOpcode() != ISD::MUL) + return SDValue(); + + ShrinkMode Mode; + if (!canReduceVMulWidth(Op0.getNode(), DAG, Mode) || Mode == MULU16) return SDValue(); SDLoc DL(N); @@ -38738,22 +40572,34 @@ static SDValue combineLoopMAddPattern(SDNode *N, SelectionDAG &DAG, EVT MAddVT = EVT::getVectorVT(*DAG.getContext(), MVT::i32, VT.getVectorNumElements() / 2); - // Shrink the operands of mul. - SDValue N0 = DAG.getNode(ISD::TRUNCATE, DL, ReducedVT, MulOp->getOperand(0)); - SDValue N1 = DAG.getNode(ISD::TRUNCATE, DL, ReducedVT, MulOp->getOperand(1)); - // Madd vector size is half of the original vector size auto PMADDWDBuilder = [](SelectionDAG &DAG, const SDLoc &DL, ArrayRef<SDValue> Ops) { MVT VT = MVT::getVectorVT(MVT::i32, Ops[0].getValueSizeInBits() / 32); return DAG.getNode(X86ISD::VPMADDWD, DL, VT, Ops); }; - SDValue Madd = SplitOpsAndApply(DAG, Subtarget, DL, MAddVT, { N0, N1 }, - PMADDWDBuilder); - // Fill the rest of the output with 0 - SDValue Zero = getZeroVector(Madd.getSimpleValueType(), Subtarget, DAG, DL); - SDValue Concat = DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, Madd, Zero); - return DAG.getNode(ISD::ADD, DL, VT, Concat, Phi); + + auto BuildPMADDWD = [&](SDValue Mul) { + // Shrink the operands of mul. + SDValue N0 = DAG.getNode(ISD::TRUNCATE, DL, ReducedVT, Mul.getOperand(0)); + SDValue N1 = DAG.getNode(ISD::TRUNCATE, DL, ReducedVT, Mul.getOperand(1)); + + SDValue Madd = SplitOpsAndApply(DAG, Subtarget, DL, MAddVT, { N0, N1 }, + PMADDWDBuilder); + // Fill the rest of the output with 0 + return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, Madd, + DAG.getConstant(0, DL, MAddVT)); + }; + + Op0 = BuildPMADDWD(Op0); + + // It's possible that Op1 is also a mul we can reduce. + if (Op1.getOpcode() == ISD::MUL && + canReduceVMulWidth(Op1.getNode(), DAG, Mode) && Mode != MULU16) { + Op1 = BuildPMADDWD(Op1); + } + + return DAG.getNode(ISD::ADD, DL, VT, Op0, Op1); } static SDValue combineLoopSADPattern(SDNode *N, SelectionDAG &DAG, @@ -38786,45 +40632,53 @@ static SDValue combineLoopSADPattern(SDNode *N, SelectionDAG &DAG, // We know N is a reduction add, which means one of its operands is a phi. // To match SAD, we need the other operand to be a vector select. - SDValue SelectOp, Phi; - if (Op0.getOpcode() == ISD::VSELECT) { - SelectOp = Op0; - Phi = Op1; - } else if (Op1.getOpcode() == ISD::VSELECT) { - SelectOp = Op1; - Phi = Op0; - } else - return SDValue(); + if (Op0.getOpcode() != ISD::VSELECT) + std::swap(Op0, Op1); + if (Op0.getOpcode() != ISD::VSELECT) + return SDValue(); + + auto BuildPSADBW = [&](SDValue Op0, SDValue Op1) { + // SAD pattern detected. Now build a SAD instruction and an addition for + // reduction. Note that the number of elements of the result of SAD is less + // than the number of elements of its input. Therefore, we could only update + // part of elements in the reduction vector. + SDValue Sad = createPSADBW(DAG, Op0, Op1, DL, Subtarget); + + // The output of PSADBW is a vector of i64. + // We need to turn the vector of i64 into a vector of i32. + // If the reduction vector is at least as wide as the psadbw result, just + // bitcast. If it's narrower, truncate - the high i32 of each i64 is zero + // anyway. 
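[Illustrative aside, not part of the patch] The BuildPSADBW helper above widens the i64 PSADBW result back to the reduction type; the bitcast/truncate is safe because each 64-bit lane of PSADBW holds the sum of eight absolute byte differences (at most 8*255), so its upper 32 bits are always zero. A standalone SSE2 sketch of that property — the input byte patterns are arbitrary:

#include <immintrin.h>
#include <cassert>
#include <cstdint>

int main() {
  alignas(16) uint8_t A[16], B[16];
  for (int i = 0; i != 16; ++i) {
    A[i] = uint8_t(i * 37 + 11);
    B[i] = uint8_t(200 - i * 13);
  }

  __m128i Sad = _mm_sad_epu8(_mm_load_si128((const __m128i *)A),
                             _mm_load_si128((const __m128i *)B));

  alignas(16) uint64_t R[2];
  _mm_store_si128((__m128i *)R, Sad);

  for (int Lane = 0; Lane != 2; ++Lane) {
    uint64_t Expected = 0;
    for (int i = 0; i != 8; ++i) {
      int D = int(A[Lane * 8 + i]) - int(B[Lane * 8 + i]);
      Expected += D < 0 ? -D : D;
    }
    // Each lane holds the full sum and its upper 32 bits are zero, so a
    // bitcast/truncate of the vector to i32 elements loses nothing.
    assert(R[Lane] == Expected && (R[Lane] >> 32) == 0);
  }
  return 0;
}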
+ MVT ResVT = MVT::getVectorVT(MVT::i32, Sad.getValueSizeInBits() / 32); + if (VT.getSizeInBits() >= ResVT.getSizeInBits()) + Sad = DAG.getNode(ISD::BITCAST, DL, ResVT, Sad); + else + Sad = DAG.getNode(ISD::TRUNCATE, DL, VT, Sad); + + if (VT.getSizeInBits() > ResVT.getSizeInBits()) { + // Fill the upper elements with zero to match the add width. + SDValue Zero = DAG.getConstant(0, DL, VT); + Sad = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT, Zero, Sad, + DAG.getIntPtrConstant(0, DL)); + } + + return Sad; + }; // Check whether we have an abs-diff pattern feeding into the select. - if(!detectZextAbsDiff(SelectOp, Op0, Op1)) - return SDValue(); - - // SAD pattern detected. Now build a SAD instruction and an addition for - // reduction. Note that the number of elements of the result of SAD is less - // than the number of elements of its input. Therefore, we could only update - // part of elements in the reduction vector. - SDValue Sad = createPSADBW(DAG, Op0, Op1, DL, Subtarget); - - // The output of PSADBW is a vector of i64. - // We need to turn the vector of i64 into a vector of i32. - // If the reduction vector is at least as wide as the psadbw result, just - // bitcast. If it's narrower, truncate - the high i32 of each i64 is zero - // anyway. - MVT ResVT = MVT::getVectorVT(MVT::i32, Sad.getValueSizeInBits() / 32); - if (VT.getSizeInBits() >= ResVT.getSizeInBits()) - Sad = DAG.getNode(ISD::BITCAST, DL, ResVT, Sad); - else - Sad = DAG.getNode(ISD::TRUNCATE, DL, VT, Sad); + SDValue SadOp0, SadOp1; + if (!detectZextAbsDiff(Op0, SadOp0, SadOp1)) + return SDValue(); - if (VT.getSizeInBits() > ResVT.getSizeInBits()) { - // Fill the upper elements with zero to match the add width. - SDValue Zero = DAG.getConstant(0, DL, VT); - Sad = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT, Zero, Sad, - DAG.getIntPtrConstant(0, DL)); + Op0 = BuildPSADBW(SadOp0, SadOp1); + + // It's possible we have a sad on the other side too. + if (Op1.getOpcode() == ISD::VSELECT && + detectZextAbsDiff(Op1, SadOp0, SadOp1)) { + Op1 = BuildPSADBW(SadOp0, SadOp1); } - return DAG.getNode(ISD::ADD, DL, VT, Sad, Phi); + return DAG.getNode(ISD::ADD, DL, VT, Op0, Op1); } /// Convert vector increment or decrement to sub/add with an all-ones constant: @@ -38843,10 +40697,8 @@ static SDValue combineIncDecVector(SDNode *N, SelectionDAG &DAG) { if (!VT.is128BitVector() && !VT.is256BitVector() && !VT.is512BitVector()) return SDValue(); - SDNode *N1 = N->getOperand(1).getNode(); APInt SplatVal; - if (!ISD::isConstantSplatVector(N1, SplatVal) || - !SplatVal.isOneValue()) + if (!isConstantSplat(N->getOperand(1), SplatVal) || !SplatVal.isOneValue()) return SDValue(); SDValue AllOnesVec = getOnesVector(VT, DAG, SDLoc(N)); @@ -38963,6 +40815,39 @@ static SDValue matchPMADDWD(SelectionDAG &DAG, SDValue Op0, SDValue Op1, PMADDBuilder); } +// Try to turn (add (umax X, C), -C) into (psubus X, C) +static SDValue combineAddToSUBUS(SDNode *N, SelectionDAG &DAG, + const X86Subtarget &Subtarget) { + if (!Subtarget.hasSSE2()) + return SDValue(); + + EVT VT = N->getValueType(0); + + // psubus is available in SSE2 for i8 and i16 vectors. + if (!VT.isVector() || VT.getVectorNumElements() < 2 || + !isPowerOf2_32(VT.getVectorNumElements()) || + !(VT.getVectorElementType() == MVT::i8 || + VT.getVectorElementType() == MVT::i16)) + return SDValue(); + + SDValue Op0 = N->getOperand(0); + SDValue Op1 = N->getOperand(1); + if (Op0.getOpcode() != ISD::UMAX) + return SDValue(); + + // The add should have a constant that is the negative of the max. 
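[Illustrative aside, not part of the patch] combineAddToSUBUS above matches add(umax(X, C), -C); for unsigned elements that expression is exactly the saturating subtraction USUBSAT(X, C), which PSUBUS implements. A standalone exhaustive check of the identity for i8 (the same argument applies to i16):

#include <algorithm>
#include <cassert>
#include <cstdint>

int main() {
  for (unsigned X = 0; X < 256; ++X) {
    for (unsigned C = 0; C < 256; ++C) {
      // add(umax(X, C), -C) with i8 wrap-around arithmetic.
      uint8_t Add = uint8_t(std::max<uint8_t>(X, C) + uint8_t(-C));
      // usubsat(X, C): subtract with clamping at zero.
      uint8_t Sat = X > C ? uint8_t(X - C) : 0;
      assert(Add == Sat);
    }
  }
  return 0;
}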
+ // TODO: Handle build_vectors with undef elements. + auto MatchUSUBSAT = [](ConstantSDNode *Max, ConstantSDNode *Op) { + return Max->getAPIntValue() == (-Op->getAPIntValue()); + }; + if (!ISD::matchBinaryPredicate(Op0.getOperand(1), Op1, MatchUSUBSAT)) + return SDValue(); + + SDLoc DL(N); + return DAG.getNode(ISD::USUBSAT, DL, VT, Op0.getOperand(0), + Op0.getOperand(1)); +} + // Attempt to turn this pattern into PMADDWD. // (mul (add (zext (build_vector)), (zext (build_vector))), // (add (zext (build_vector)), (zext (build_vector))) @@ -39105,7 +40990,8 @@ static SDValue combineAdd(SDNode *N, SelectionDAG &DAG, // Try to synthesize horizontal adds from adds of shuffles. if ((VT == MVT::v8i16 || VT == MVT::v4i32 || VT == MVT::v16i16 || VT == MVT::v8i32) && - Subtarget.hasSSSE3() && isHorizontalBinOp(Op0, Op1, true)) { + Subtarget.hasSSSE3() && isHorizontalBinOp(Op0, Op1, true) && + shouldUseHorizontalOp(Op0 == Op1, DAG, Subtarget)) { auto HADDBuilder = [](SelectionDAG &DAG, const SDLoc &DL, ArrayRef<SDValue> Ops) { return DAG.getNode(X86ISD::HADD, DL, Ops[0].getValueType(), Ops); @@ -39117,6 +41003,9 @@ static SDValue combineAdd(SDNode *N, SelectionDAG &DAG, if (SDValue V = combineIncDecVector(N, DAG)) return V; + if (SDValue V = combineAddToSUBUS(N, DAG, Subtarget)) + return V; + return combineAddOrSubToADCOrSBB(N, DAG); } @@ -39162,23 +41051,22 @@ static SDValue combineSubToSubus(SDNode *N, SelectionDAG &DAG, } else return SDValue(); - auto SUBUSBuilder = [](SelectionDAG &DAG, const SDLoc &DL, - ArrayRef<SDValue> Ops) { - return DAG.getNode(X86ISD::SUBUS, DL, Ops[0].getValueType(), Ops); + auto USUBSATBuilder = [](SelectionDAG &DAG, const SDLoc &DL, + ArrayRef<SDValue> Ops) { + return DAG.getNode(ISD::USUBSAT, DL, Ops[0].getValueType(), Ops); }; // PSUBUS doesn't support v8i32/v8i64/v16i32, but it can be enabled with // special preprocessing in some cases. if (VT != MVT::v8i32 && VT != MVT::v16i32 && VT != MVT::v8i64) return SplitOpsAndApply(DAG, Subtarget, SDLoc(N), VT, - { SubusLHS, SubusRHS }, SUBUSBuilder); + { SubusLHS, SubusRHS }, USUBSATBuilder); // Special preprocessing case can be only applied // if the value was zero extended from 16 bit, // so we require first 16 bits to be zeros for 32 bit // values, or first 48 bits for 64 bit values. - KnownBits Known; - DAG.computeKnownBits(SubusLHS, Known); + KnownBits Known = DAG.computeKnownBits(SubusLHS); unsigned NumZeros = Known.countMinLeadingZeros(); if ((VT == MVT::v8i64 && NumZeros < 48) || NumZeros < 16) return SDValue(); @@ -39203,7 +41091,7 @@ static SDValue combineSubToSubus(SDNode *N, SelectionDAG &DAG, SDValue NewSubusRHS = DAG.getZExtOrTrunc(UMin, SDLoc(SubusRHS), ShrinkedType); SDValue Psubus = SplitOpsAndApply(DAG, Subtarget, SDLoc(N), ShrinkedType, - { NewSubusLHS, NewSubusRHS }, SUBUSBuilder); + { NewSubusLHS, NewSubusRHS }, USUBSATBuilder); // Zero extend the result, it may be used somewhere as 32 bit, // if not zext and following trunc will shrink. 
return DAG.getZExtOrTrunc(Psubus, SDLoc(N), ExtType); @@ -39236,7 +41124,8 @@ static SDValue combineSub(SDNode *N, SelectionDAG &DAG, EVT VT = N->getValueType(0); if ((VT == MVT::v8i16 || VT == MVT::v4i32 || VT == MVT::v16i16 || VT == MVT::v8i32) && - Subtarget.hasSSSE3() && isHorizontalBinOp(Op0, Op1, false)) { + Subtarget.hasSSSE3() && isHorizontalBinOp(Op0, Op1, false) && + shouldUseHorizontalOp(Op0 == Op1, DAG, Subtarget)) { auto HSUBBuilder = [](SelectionDAG &DAG, const SDLoc &DL, ArrayRef<SDValue> Ops) { return DAG.getNode(X86ISD::HSUB, DL, Ops[0].getValueType(), Ops); @@ -39255,98 +41144,6 @@ static SDValue combineSub(SDNode *N, SelectionDAG &DAG, return combineAddOrSubToADCOrSBB(N, DAG); } -static SDValue combineVSZext(SDNode *N, SelectionDAG &DAG, - TargetLowering::DAGCombinerInfo &DCI, - const X86Subtarget &Subtarget) { - if (DCI.isBeforeLegalize()) - return SDValue(); - - SDLoc DL(N); - unsigned Opcode = N->getOpcode(); - MVT VT = N->getSimpleValueType(0); - MVT SVT = VT.getVectorElementType(); - unsigned NumElts = VT.getVectorNumElements(); - unsigned EltSizeInBits = SVT.getSizeInBits(); - - SDValue Op = N->getOperand(0); - MVT OpVT = Op.getSimpleValueType(); - MVT OpEltVT = OpVT.getVectorElementType(); - unsigned OpEltSizeInBits = OpEltVT.getSizeInBits(); - unsigned InputBits = OpEltSizeInBits * NumElts; - - // Perform any constant folding. - // FIXME: Reduce constant pool usage and don't fold when OptSize is enabled. - APInt UndefElts; - SmallVector<APInt, 64> EltBits; - if (getTargetConstantBitsFromNode(Op, OpEltSizeInBits, UndefElts, EltBits)) { - APInt Undefs(NumElts, 0); - SmallVector<APInt, 4> Vals(NumElts, APInt(EltSizeInBits, 0)); - bool IsZEXT = - (Opcode == X86ISD::VZEXT) || (Opcode == ISD::ZERO_EXTEND_VECTOR_INREG); - for (unsigned i = 0; i != NumElts; ++i) { - if (UndefElts[i]) { - Undefs.setBit(i); - continue; - } - Vals[i] = IsZEXT ? EltBits[i].zextOrTrunc(EltSizeInBits) - : EltBits[i].sextOrTrunc(EltSizeInBits); - } - return getConstVector(Vals, Undefs, VT, DAG, DL); - } - - // (vzext (bitcast (vzext (x)) -> (vzext x) - // TODO: (vsext (bitcast (vsext (x)) -> (vsext x) - SDValue V = peekThroughBitcasts(Op); - if (Opcode == X86ISD::VZEXT && V != Op && V.getOpcode() == X86ISD::VZEXT) { - MVT InnerVT = V.getSimpleValueType(); - MVT InnerEltVT = InnerVT.getVectorElementType(); - - // If the element sizes match exactly, we can just do one larger vzext. This - // is always an exact type match as vzext operates on integer types. - if (OpEltVT == InnerEltVT) { - assert(OpVT == InnerVT && "Types must match for vzext!"); - return DAG.getNode(X86ISD::VZEXT, DL, VT, V.getOperand(0)); - } - - // The only other way we can combine them is if only a single element of the - // inner vzext is used in the input to the outer vzext. - if (InnerEltVT.getSizeInBits() < InputBits) - return SDValue(); - - // In this case, the inner vzext is completely dead because we're going to - // only look at bits inside of the low element. Just do the outer vzext on - // a bitcast of the input to the inner. - return DAG.getNode(X86ISD::VZEXT, DL, VT, DAG.getBitcast(OpVT, V)); - } - - // Check if we can bypass extracting and re-inserting an element of an input - // vector. 
Essentially: - // (bitcast (sclr2vec (ext_vec_elt x))) -> (bitcast x) - // TODO: Add X86ISD::VSEXT support - if (Opcode == X86ISD::VZEXT && - V.getOpcode() == ISD::SCALAR_TO_VECTOR && - V.getOperand(0).getOpcode() == ISD::EXTRACT_VECTOR_ELT && - V.getOperand(0).getSimpleValueType().getSizeInBits() == InputBits) { - SDValue ExtractedV = V.getOperand(0); - SDValue OrigV = ExtractedV.getOperand(0); - if (isNullConstant(ExtractedV.getOperand(1))) { - MVT OrigVT = OrigV.getSimpleValueType(); - // Extract a subvector if necessary... - if (OrigVT.getSizeInBits() > OpVT.getSizeInBits()) { - int Ratio = OrigVT.getSizeInBits() / OpVT.getSizeInBits(); - OrigVT = MVT::getVectorVT(OrigVT.getVectorElementType(), - OrigVT.getVectorNumElements() / Ratio); - OrigV = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, OrigVT, OrigV, - DAG.getIntPtrConstant(0, DL)); - } - Op = DAG.getBitcast(OpVT, OrigV); - return DAG.getNode(X86ISD::VZEXT, DL, VT, Op); - } - } - - return SDValue(); -} - static SDValue combineVectorCompare(SDNode *N, SelectionDAG &DAG, const X86Subtarget &Subtarget) { MVT VT = N->getSimpleValueType(0); @@ -39354,9 +41151,9 @@ static SDValue combineVectorCompare(SDNode *N, SelectionDAG &DAG, if (N->getOperand(0) == N->getOperand(1)) { if (N->getOpcode() == X86ISD::PCMPEQ) - return getOnesVector(VT, DAG, DL); + return DAG.getConstant(-1, DL, VT); if (N->getOpcode() == X86ISD::PCMPGT) - return getZeroVector(VT, Subtarget, DAG, DL); + return DAG.getConstant(0, DL, VT); } return SDValue(); @@ -39487,11 +41284,10 @@ static SDValue combineInsertSubvector(SDNode *N, SelectionDAG &DAG, return Ld; } } - // If lower/upper loads are the same and the only users of the load, then - // lower to a VBROADCASTF128/VBROADCASTI128/etc. + // If lower/upper loads are the same and there's no other use of the lower + // load, then splat the loaded value with a broadcast. if (auto *Ld = dyn_cast<LoadSDNode>(peekThroughOneUseBitcasts(SubVec2))) - if (SubVec2 == SubVec && ISD::isNormalLoad(Ld) && - SDNode::areOnlyUsersOf({N, Vec.getNode()}, SubVec2.getNode())) + if (SubVec2 == SubVec && ISD::isNormalLoad(Ld) && Vec.hasOneUse()) return DAG.getNode(X86ISD::SUBV_BROADCAST, dl, OpVT, SubVec); // If this is subv_broadcast insert into both halves, use a larger @@ -39528,6 +41324,39 @@ static SDValue combineInsertSubvector(SDNode *N, SelectionDAG &DAG, static SDValue combineExtractSubvector(SDNode *N, SelectionDAG &DAG, TargetLowering::DAGCombinerInfo &DCI, const X86Subtarget &Subtarget) { + // For AVX1 only, if we are extracting from a 256-bit and+not (which will + // eventually get combined/lowered into ANDNP) with a concatenated operand, + // split the 'and' into 128-bit ops to avoid the concatenate and extract. + // We let generic combining take over from there to simplify the + // insert/extract and 'not'. + // This pattern emerges during AVX1 legalization. We handle it before lowering + // to avoid complications like splitting constant vector loads. + + // Capture the original wide type in the likely case that we need to bitcast + // back to this type. 
+ EVT VT = N->getValueType(0); + EVT WideVecVT = N->getOperand(0).getValueType(); + SDValue WideVec = peekThroughBitcasts(N->getOperand(0)); + const TargetLowering &TLI = DAG.getTargetLoweringInfo(); + if (Subtarget.hasAVX() && !Subtarget.hasAVX2() && + TLI.isTypeLegal(WideVecVT) && + WideVecVT.getSizeInBits() == 256 && WideVec.getOpcode() == ISD::AND) { + auto isConcatenatedNot = [] (SDValue V) { + V = peekThroughBitcasts(V); + if (!isBitwiseNot(V)) + return false; + SDValue NotOp = V->getOperand(0); + return peekThroughBitcasts(NotOp).getOpcode() == ISD::CONCAT_VECTORS; + }; + if (isConcatenatedNot(WideVec.getOperand(0)) || + isConcatenatedNot(WideVec.getOperand(1))) { + // extract (and v4i64 X, (not (concat Y1, Y2))), n -> andnp v2i64 X(n), Y1 + SDValue Concat = split256IntArith(WideVec, DAG); + return DAG.getNode(ISD::EXTRACT_SUBVECTOR, SDLoc(N), VT, + DAG.getBitcast(WideVecVT, Concat), N->getOperand(1)); + } + } + if (DCI.isBeforeLegalizeOps()) return SDValue(); @@ -39565,13 +41394,32 @@ static SDValue combineExtractSubvector(SDNode *N, SelectionDAG &DAG, return DAG.getNode(X86ISD::VFPEXT, SDLoc(N), OpVT, InVec.getOperand(0)); } } - if ((InOpcode == X86ISD::VZEXT || InOpcode == X86ISD::VSEXT) && + if ((InOpcode == ISD::ZERO_EXTEND || InOpcode == ISD::SIGN_EXTEND) && OpVT.is128BitVector() && InVec.getOperand(0).getSimpleValueType().is128BitVector()) { - unsigned ExtOp = InOpcode == X86ISD::VZEXT ? ISD::ZERO_EXTEND_VECTOR_INREG - : ISD::SIGN_EXTEND_VECTOR_INREG; + unsigned ExtOp = + InOpcode == ISD::ZERO_EXTEND ? ISD::ZERO_EXTEND_VECTOR_INREG + : ISD::SIGN_EXTEND_VECTOR_INREG; return DAG.getNode(ExtOp, SDLoc(N), OpVT, InVec.getOperand(0)); } + if ((InOpcode == ISD::ZERO_EXTEND_VECTOR_INREG || + InOpcode == ISD::SIGN_EXTEND_VECTOR_INREG) && + OpVT.is128BitVector() && + InVec.getOperand(0).getSimpleValueType().is128BitVector()) { + return DAG.getNode(InOpcode, SDLoc(N), OpVT, InVec.getOperand(0)); + } + if (InOpcode == ISD::BITCAST) { + // TODO - do this for target shuffles in general. + SDValue InVecBC = peekThroughOneUseBitcasts(InVec); + if (InVecBC.getOpcode() == X86ISD::PSHUFB && OpVT.is128BitVector()) { + SDLoc DL(N); + SDValue SubPSHUFB = + DAG.getNode(X86ISD::PSHUFB, DL, MVT::v16i8, + extract128BitVector(InVecBC.getOperand(0), 0, DAG, DL), + extract128BitVector(InVecBC.getOperand(1), 0, DAG, DL)); + return DAG.getBitcast(OpVT, SubPSHUFB); + } + } } return SDValue(); @@ -39591,6 +41439,15 @@ static SDValue combineScalarToVector(SDNode *N, SelectionDAG &DAG) { return DAG.getNode(ISD::SCALAR_TO_VECTOR, SDLoc(N), MVT::v1i1, Src.getOperand(0)); + // Combine scalar_to_vector of an extract_vector_elt into an extract_subvec. + if (VT == MVT::v1i1 && Src.getOpcode() == ISD::EXTRACT_VECTOR_ELT && + Src.hasOneUse() && Src.getOperand(0).getValueType().isVector() && + Src.getOperand(0).getValueType().getVectorElementType() == MVT::i1) + if (auto *C = dyn_cast<ConstantSDNode>(Src.getOperand(1))) + if (C->isNullValue()) + return DAG.getNode(ISD::EXTRACT_SUBVECTOR, SDLoc(N), VT, + Src.getOperand(0), Src.getOperand(1)); + return SDValue(); } @@ -39600,23 +41457,28 @@ static SDValue combinePMULDQ(SDNode *N, SelectionDAG &DAG, SDValue LHS = N->getOperand(0); SDValue RHS = N->getOperand(1); + // Canonicalize constant to RHS. + if (DAG.isConstantIntBuildVectorOrConstantInt(LHS) && + !DAG.isConstantIntBuildVectorOrConstantInt(RHS)) + return DAG.getNode(N->getOpcode(), SDLoc(N), N->getValueType(0), RHS, LHS); + + // Multiply by zero. 
+ if (ISD::isBuildVectorAllZeros(RHS.getNode())) + return RHS; + + // Aggressively peek through ops to get at the demanded low bits. + APInt DemandedMask = APInt::getLowBitsSet(64, 32); + SDValue DemandedLHS = DAG.GetDemandedBits(LHS, DemandedMask); + SDValue DemandedRHS = DAG.GetDemandedBits(RHS, DemandedMask); + if (DemandedLHS || DemandedRHS) + return DAG.getNode(N->getOpcode(), SDLoc(N), N->getValueType(0), + DemandedLHS ? DemandedLHS : LHS, + DemandedRHS ? DemandedRHS : RHS); + + // PMULDQ/PMULUDQ only uses lower 32 bits from each vector element. const TargetLowering &TLI = DAG.getTargetLoweringInfo(); - TargetLowering::TargetLoweringOpt TLO(DAG, !DCI.isBeforeLegalize(), - !DCI.isBeforeLegalizeOps()); - APInt DemandedMask(APInt::getLowBitsSet(64, 32)); - - // PMULQDQ/PMULUDQ only uses lower 32 bits from each vector element. - KnownBits LHSKnown; - if (TLI.SimplifyDemandedBits(LHS, DemandedMask, LHSKnown, TLO)) { - DCI.CommitTargetLoweringOpt(TLO); + if (TLI.SimplifyDemandedBits(SDValue(N, 0), APInt::getAllOnesValue(64), DCI)) return SDValue(N, 0); - } - - KnownBits RHSKnown; - if (TLI.SimplifyDemandedBits(RHS, DemandedMask, RHSKnown, TLO)) { - DCI.CommitTargetLoweringOpt(TLO); - return SDValue(N, 0); - } return SDValue(); } @@ -39638,9 +41500,10 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N, return combineExtractSubvector(N, DAG, DCI, Subtarget); case ISD::VSELECT: case ISD::SELECT: - case X86ISD::SHRUNKBLEND: return combineSelect(N, DAG, DCI, Subtarget); + case X86ISD::BLENDV: return combineSelect(N, DAG, DCI, Subtarget); case ISD::BITCAST: return combineBitcast(N, DAG, DCI, Subtarget); case X86ISD::CMOV: return combineCMov(N, DAG, DCI, Subtarget); + case X86ISD::CMP: return combineCMP(N, DAG); case ISD::ADD: return combineAdd(N, DAG, Subtarget); case ISD::SUB: return combineSub(N, DAG, Subtarget); case X86ISD::SBB: return combineSBB(N, DAG); @@ -39656,7 +41519,7 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N, case ISD::LOAD: return combineLoad(N, DAG, DCI, Subtarget); case ISD::MLOAD: return combineMaskedLoad(N, DAG, DCI, Subtarget); case ISD::STORE: return combineStore(N, DAG, Subtarget); - case ISD::MSTORE: return combineMaskedStore(N, DAG, Subtarget); + case ISD::MSTORE: return combineMaskedStore(N, DAG, DCI, Subtarget); case ISD::SINT_TO_FP: return combineSIntToFP(N, DAG, Subtarget); case ISD::UINT_TO_FP: return combineUIntToFP(N, DAG, Subtarget); case ISD::FADD: @@ -39672,6 +41535,8 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N, case X86ISD::FMAX: return combineFMinFMax(N, DAG); case ISD::FMINNUM: case ISD::FMAXNUM: return combineFMinNumFMaxNum(N, DAG, Subtarget); + case X86ISD::CVTSI2P: + case X86ISD::CVTUI2P: return combineX86INT_TO_FP(N, DAG, DCI); case X86ISD::BT: return combineBT(N, DAG, DCI); case ISD::ANY_EXTEND: case ISD::ZERO_EXTEND: return combineZext(N, DAG, DCI, Subtarget); @@ -39682,14 +41547,14 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N, case X86ISD::BRCOND: return combineBrCond(N, DAG, Subtarget); case X86ISD::PACKSS: case X86ISD::PACKUS: return combineVectorPack(N, DAG, DCI, Subtarget); + case X86ISD::VSHL: + case X86ISD::VSRA: + case X86ISD::VSRL: + return combineVectorShiftVar(N, DAG, DCI, Subtarget); case X86ISD::VSHLI: case X86ISD::VSRAI: case X86ISD::VSRLI: return combineVectorShiftImm(N, DAG, DCI, Subtarget); - case ISD::SIGN_EXTEND_VECTOR_INREG: - case ISD::ZERO_EXTEND_VECTOR_INREG: - case X86ISD::VSEXT: - case X86ISD::VZEXT: return combineVSZext(N, DAG, DCI, Subtarget); case X86ISD::PINSRB: case 
X86ISD::PINSRW: return combineVectorInsert(N, DAG, DCI, Subtarget); case X86ISD::SHUFP: // Handle all target specific shuffles @@ -39751,10 +41616,6 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N, return SDValue(); } -/// Return true if the target has native support for the specified value type -/// and it is 'desirable' to use the type for the given node type. e.g. On x86 -/// i16 is legal, but undesirable since i16 instruction encodings are longer and -/// some i16 instructions are slow. bool X86TargetLowering::isTypeDesirableForOp(unsigned Opc, EVT VT) const { if (!isTypeLegal(VT)) return false; @@ -39763,26 +41624,37 @@ bool X86TargetLowering::isTypeDesirableForOp(unsigned Opc, EVT VT) const { if (Opc == ISD::SHL && VT.isVector() && VT.getVectorElementType() == MVT::i8) return false; - if (VT != MVT::i16) - return true; - - switch (Opc) { - default: - return true; - case ISD::LOAD: - case ISD::SIGN_EXTEND: - case ISD::ZERO_EXTEND: - case ISD::ANY_EXTEND: - case ISD::SHL: - case ISD::SRL: - case ISD::SUB: - case ISD::ADD: - case ISD::MUL: - case ISD::AND: - case ISD::OR: - case ISD::XOR: + // 8-bit multiply is probably not much cheaper than 32-bit multiply, and + // we have specializations to turn 32-bit multiply into LEA or other ops. + // Also, see the comment in "IsDesirableToPromoteOp" - where we additionally + // check for a constant operand to the multiply. + if (Opc == ISD::MUL && VT == MVT::i8) return false; + + // i16 instruction encodings are longer and some i16 instructions are slow, + // so those are not desirable. + if (VT == MVT::i16) { + switch (Opc) { + default: + break; + case ISD::LOAD: + case ISD::SIGN_EXTEND: + case ISD::ZERO_EXTEND: + case ISD::ANY_EXTEND: + case ISD::SHL: + case ISD::SRL: + case ISD::SUB: + case ISD::ADD: + case ISD::MUL: + case ISD::AND: + case ISD::OR: + case ISD::XOR: + return false; + } } + + // Any legal type not explicitly accounted for above here is desirable. + return true; } SDValue X86TargetLowering::expandIndirectJTBranch(const SDLoc& dl, @@ -39801,12 +41673,16 @@ SDValue X86TargetLowering::expandIndirectJTBranch(const SDLoc& dl, return TargetLowering::expandIndirectJTBranch(dl, Value, Addr, DAG); } -/// This method query the target whether it is beneficial for dag combiner to -/// promote the specified node. If true, it should return the desired promotion -/// type by reference. bool X86TargetLowering::IsDesirableToPromoteOp(SDValue Op, EVT &PVT) const { EVT VT = Op.getValueType(); - if (VT != MVT::i16) + bool Is8BitMulByConstant = VT == MVT::i8 && Op.getOpcode() == ISD::MUL && + isa<ConstantSDNode>(Op.getOperand(1)); + + // i16 is legal, but undesirable since i16 instruction encodings are longer + // and some i16 instructions are slow. + // 8-bit multiply-by-constant can usually be expanded to something cheaper + // using LEA and/or other ALU ops. 
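[Illustrative aside, not part of the patch] The comment above notes that an 8-bit multiply by a constant is usually cheaper after promotion because it decomposes into LEA/shift/add style ops; for example, x*9 is x + 8*x, which LEA can encode in a single instruction. A standalone check of that decomposition — the constant 9 is an arbitrary example:

#include <cassert>
#include <cstdint>

int main() {
  for (unsigned X = 0; X < 256; ++X) {
    uint8_t ByMul = uint8_t(X * 9);
    uint8_t ByLea = uint8_t((X << 3) + X); // lea-style: base + index*8
    assert(ByMul == ByLea);
  }
  return 0;
}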
+ if (VT != MVT::i16 && !Is8BitMulByConstant) return false; auto IsFoldableRMW = [](SDValue Load, SDValue Op) { @@ -39820,6 +41696,19 @@ bool X86TargetLowering::IsDesirableToPromoteOp(SDValue Op, EVT &PVT) const { return Ld->getBasePtr() == St->getBasePtr(); }; + auto IsFoldableAtomicRMW = [](SDValue Load, SDValue Op) { + if (!Load.hasOneUse() || Load.getOpcode() != ISD::ATOMIC_LOAD) + return false; + if (!Op.hasOneUse()) + return false; + SDNode *User = *Op->use_begin(); + if (User->getOpcode() != ISD::ATOMIC_STORE) + return false; + auto *Ld = cast<AtomicSDNode>(Load); + auto *St = cast<AtomicSDNode>(User); + return Ld->getBasePtr() == St->getBasePtr(); + }; + bool Commute = false; switch (Op.getOpcode()) { default: return false; @@ -39854,6 +41743,9 @@ bool X86TargetLowering::IsDesirableToPromoteOp(SDValue Op, EVT &PVT) const { ((Commute && !isa<ConstantSDNode>(N1)) || (Op.getOpcode() != ISD::MUL && IsFoldableRMW(N0, Op)))) return false; + if (IsFoldableAtomicRMW(N0, Op) || + (Commute && IsFoldableAtomicRMW(N1, Op))) + return false; } } @@ -40593,44 +42485,33 @@ X86TargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI, if (!Res.second) { // Map st(0) -> st(7) -> ST0 if (Constraint.size() == 7 && Constraint[0] == '{' && - tolower(Constraint[1]) == 's' && - tolower(Constraint[2]) == 't' && + tolower(Constraint[1]) == 's' && tolower(Constraint[2]) == 't' && Constraint[3] == '(' && (Constraint[4] >= '0' && Constraint[4] <= '7') && - Constraint[5] == ')' && - Constraint[6] == '}') { - - Res.first = X86::FP0+Constraint[4]-'0'; - Res.second = &X86::RFP80RegClass; - return Res; + Constraint[5] == ')' && Constraint[6] == '}') { + // st(7) is not allocatable and thus not a member of RFP80. Return + // singleton class in cases where we have a reference to it. + if (Constraint[4] == '7') + return std::make_pair(X86::FP7, &X86::RFP80_7RegClass); + return std::make_pair(X86::FP0 + Constraint[4] - '0', + &X86::RFP80RegClass); } // GCC allows "st(0)" to be called just plain "st". - if (StringRef("{st}").equals_lower(Constraint)) { - Res.first = X86::FP0; - Res.second = &X86::RFP80RegClass; - return Res; - } + if (StringRef("{st}").equals_lower(Constraint)) + return std::make_pair(X86::FP0, &X86::RFP80RegClass); // flags -> EFLAGS - if (StringRef("{flags}").equals_lower(Constraint)) { - Res.first = X86::EFLAGS; - Res.second = &X86::CCRRegClass; - return Res; - } + if (StringRef("{flags}").equals_lower(Constraint)) + return std::make_pair(X86::EFLAGS, &X86::CCRRegClass); // 'A' means [ER]AX + [ER]DX. if (Constraint == "A") { - if (Subtarget.is64Bit()) { - Res.first = X86::RAX; - Res.second = &X86::GR64_ADRegClass; - } else { - assert((Subtarget.is32Bit() || Subtarget.is16Bit()) && - "Expecting 64, 32 or 16 bit subtarget"); - Res.first = X86::EAX; - Res.second = &X86::GR32_ADRegClass; - } - return Res; + if (Subtarget.is64Bit()) + return std::make_pair(X86::RAX, &X86::GR64_ADRegClass); + assert((Subtarget.is32Bit() || Subtarget.is16Bit()) && + "Expecting 64, 32 or 16 bit subtarget"); + return std::make_pair(X86::EAX, &X86::GR32_ADRegClass); } return Res; } @@ -40640,18 +42521,14 @@ X86TargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI, (isFRClass(*Res.second) || isGRClass(*Res.second)) && TRI->getEncodingValue(Res.first) >= 8) { // Register requires REX prefix, but we're in 32-bit mode. - Res.first = 0; - Res.second = nullptr; - return Res; + return std::make_pair(0, nullptr); } // Make sure it isn't a register that requires AVX512. 
if (!Subtarget.hasAVX512() && isFRClass(*Res.second) && TRI->getEncodingValue(Res.first) & 0x10) { // Register requires EVEX prefix. - Res.first = 0; - Res.second = nullptr; - return Res; + return std::make_pair(0, nullptr); } // Otherwise, check to see if this is a register class of the wrong value @@ -40679,14 +42556,36 @@ X86TargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI, Size == 8 ? (is64Bit ? &X86::GR8RegClass : &X86::GR8_NOREXRegClass) : Size == 16 ? (is64Bit ? &X86::GR16RegClass : &X86::GR16_NOREXRegClass) : Size == 32 ? (is64Bit ? &X86::GR32RegClass : &X86::GR32_NOREXRegClass) - : &X86::GR64RegClass; - if (RC->contains(DestReg)) - Res = std::make_pair(DestReg, RC); - } else { - // No register found/type mismatch. - Res.first = 0; - Res.second = nullptr; + : Size == 64 ? (is64Bit ? &X86::GR64RegClass : nullptr) + : nullptr; + if (Size == 64 && !is64Bit) { + // Model GCC's behavior here and select a fixed pair of 32-bit + // registers. + switch (Res.first) { + case X86::EAX: + return std::make_pair(X86::EAX, &X86::GR32_ADRegClass); + case X86::EDX: + return std::make_pair(X86::EDX, &X86::GR32_DCRegClass); + case X86::ECX: + return std::make_pair(X86::ECX, &X86::GR32_CBRegClass); + case X86::EBX: + return std::make_pair(X86::EBX, &X86::GR32_BSIRegClass); + case X86::ESI: + return std::make_pair(X86::ESI, &X86::GR32_SIDIRegClass); + case X86::EDI: + return std::make_pair(X86::EDI, &X86::GR32_DIBPRegClass); + case X86::EBP: + return std::make_pair(X86::EBP, &X86::GR32_BPSPRegClass); + default: + return std::make_pair(0, nullptr); + } + } + if (RC && RC->contains(DestReg)) + return std::make_pair(DestReg, RC); + return Res; } + // No register found/type mismatch. + return std::make_pair(0, nullptr); } else if (isFRClass(*Class)) { // Handle references to XMM physical registers that got mapped into the // wrong class. This can happen with constraints like {xmm0} where the |
