Diffstat (limited to 'llvm/lib/Target/RISCV/RISCVISelLowering.cpp')
-rw-r--r--  llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 4658
1 file changed, 3533 insertions(+), 1125 deletions(-)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index f49c5011607f..f2ec422b54a9 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -27,6 +27,7 @@ #include "llvm/CodeGen/MachineInstrBuilder.h" #include "llvm/CodeGen/MachineJumpTableInfo.h" #include "llvm/CodeGen/MachineRegisterInfo.h" +#include "llvm/CodeGen/SelectionDAGAddressAnalysis.h" #include "llvm/CodeGen/TargetLoweringObjectFileImpl.h" #include "llvm/CodeGen/ValueTypes.h" #include "llvm/IR/DiagnosticInfo.h" @@ -38,6 +39,7 @@ #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/InstructionCost.h" #include "llvm/Support/KnownBits.h" #include "llvm/Support/MathExtras.h" #include "llvm/Support/raw_ostream.h" @@ -73,6 +75,10 @@ static cl::opt<int> "use for creating a floating-point immediate value"), cl::init(2)); +static cl::opt<bool> + RV64LegalI32("riscv-experimental-rv64-legal-i32", cl::ReallyHidden, + cl::desc("Make i32 a legal type for SelectionDAG on RV64.")); + RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM, const RISCVSubtarget &STI) : TargetLowering(TM), Subtarget(STI) { @@ -113,6 +119,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM, // Set up the register classes. addRegisterClass(XLenVT, &RISCV::GPRRegClass); + if (Subtarget.is64Bit() && RV64LegalI32) + addRegisterClass(MVT::i32, &RISCV::GPRRegClass); if (Subtarget.hasStdExtZfhOrZfhmin()) addRegisterClass(MVT::f16, &RISCV::FPR16RegClass); @@ -145,6 +153,9 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM, static const MVT::SimpleValueType F16VecVTs[] = { MVT::nxv1f16, MVT::nxv2f16, MVT::nxv4f16, MVT::nxv8f16, MVT::nxv16f16, MVT::nxv32f16}; + static const MVT::SimpleValueType BF16VecVTs[] = { + MVT::nxv1bf16, MVT::nxv2bf16, MVT::nxv4bf16, + MVT::nxv8bf16, MVT::nxv16bf16, MVT::nxv32bf16}; static const MVT::SimpleValueType F32VecVTs[] = { MVT::nxv1f32, MVT::nxv2f32, MVT::nxv4f32, MVT::nxv8f32, MVT::nxv16f32}; static const MVT::SimpleValueType F64VecVTs[] = { @@ -154,7 +165,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM, auto addRegClassForRVV = [this](MVT VT) { // Disable the smallest fractional LMUL types if ELEN is less than // RVVBitsPerBlock. 
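A concrete reading of the ELEN rule above (illustrative numbers, not taken from the change itself): RISCV::RVVBitsPerBlock is 64, so on a Zve32-class target the threshold below works out to two elements and the single-element fractional types are filtered out, while with a 64-bit ELEN nothing is skipped.

    // Sketch with assumed values (Zve32x: getELen() == 32, RVVBitsPerBlock == 64):
    unsigned MinElts = 64 / 32;   // == 2
    // nxv1i8 / nxv1f16 / nxv1bf16 have a minimum of 1 element, so 1 < 2 and no
    // register class is added for them; nxv2i8 and wider still qualify.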
- unsigned MinElts = RISCV::RVVBitsPerBlock / Subtarget.getELEN(); + unsigned MinElts = RISCV::RVVBitsPerBlock / Subtarget.getELen(); if (VT.getVectorMinNumElements() < MinElts) return; @@ -183,10 +194,14 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM, addRegClassForRVV(VT); } - if (Subtarget.hasVInstructionsF16()) + if (Subtarget.hasVInstructionsF16Minimal()) for (MVT VT : F16VecVTs) addRegClassForRVV(VT); + if (Subtarget.hasVInstructionsBF16()) + for (MVT VT : BF16VecVTs) + addRegClassForRVV(VT); + if (Subtarget.hasVInstructionsF32()) for (MVT VT : F32VecVTs) addRegClassForRVV(VT); @@ -228,8 +243,12 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM, setOperationAction(ISD::BR_JT, MVT::Other, Expand); setOperationAction(ISD::BR_CC, XLenVT, Expand); + if (RV64LegalI32 && Subtarget.is64Bit()) + setOperationAction(ISD::BR_CC, MVT::i32, Expand); setOperationAction(ISD::BRCOND, MVT::Other, Custom); setOperationAction(ISD::SELECT_CC, XLenVT, Expand); + if (RV64LegalI32 && Subtarget.is64Bit()) + setOperationAction(ISD::SELECT_CC, MVT::i32, Expand); setCondCodeAction(ISD::SETLE, XLenVT, Expand); setCondCodeAction(ISD::SETGT, XLenVT, Custom); @@ -238,6 +257,9 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM, setCondCodeAction(ISD::SETUGT, XLenVT, Custom); setCondCodeAction(ISD::SETUGE, XLenVT, Expand); + if (RV64LegalI32 && Subtarget.is64Bit()) + setOperationAction(ISD::SETCC, MVT::i32, Promote); + setOperationAction({ISD::STACKSAVE, ISD::STACKRESTORE}, MVT::Other, Expand); setOperationAction(ISD::VASTART, MVT::Other, Custom); @@ -253,14 +275,14 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM, if (Subtarget.is64Bit()) { setOperationAction(ISD::EH_DWARF_CFA, MVT::i64, Custom); - setOperationAction(ISD::LOAD, MVT::i32, Custom); - - setOperationAction({ISD::ADD, ISD::SUB, ISD::SHL, ISD::SRA, ISD::SRL}, - MVT::i32, Custom); - - setOperationAction(ISD::SADDO, MVT::i32, Custom); - setOperationAction({ISD::UADDO, ISD::USUBO, ISD::UADDSAT, ISD::USUBSAT}, - MVT::i32, Custom); + if (!RV64LegalI32) { + setOperationAction(ISD::LOAD, MVT::i32, Custom); + setOperationAction({ISD::ADD, ISD::SUB, ISD::SHL, ISD::SRA, ISD::SRL}, + MVT::i32, Custom); + setOperationAction(ISD::SADDO, MVT::i32, Custom); + setOperationAction({ISD::UADDO, ISD::USUBO, ISD::UADDSAT, ISD::USUBSAT}, + MVT::i32, Custom); + } } else { setLibcallName( {RTLIB::SHL_I128, RTLIB::SRL_I128, RTLIB::SRA_I128, RTLIB::MUL_I128}, @@ -268,19 +290,36 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM, setLibcallName(RTLIB::MULO_I64, nullptr); } - if (!Subtarget.hasStdExtM() && !Subtarget.hasStdExtZmmul()) + if (!Subtarget.hasStdExtM() && !Subtarget.hasStdExtZmmul()) { setOperationAction({ISD::MUL, ISD::MULHS, ISD::MULHU}, XLenVT, Expand); - else if (Subtarget.is64Bit()) - setOperationAction(ISD::MUL, {MVT::i32, MVT::i128}, Custom); - else + if (RV64LegalI32 && Subtarget.is64Bit()) + setOperationAction(ISD::MUL, MVT::i32, Promote); + } else if (Subtarget.is64Bit()) { + setOperationAction(ISD::MUL, MVT::i128, Custom); + if (!RV64LegalI32) + setOperationAction(ISD::MUL, MVT::i32, Custom); + } else { setOperationAction(ISD::MUL, MVT::i64, Custom); + } - if (!Subtarget.hasStdExtM()) + if (!Subtarget.hasStdExtM()) { setOperationAction({ISD::SDIV, ISD::UDIV, ISD::SREM, ISD::UREM}, XLenVT, Expand); - else if (Subtarget.is64Bit()) - setOperationAction({ISD::SDIV, ISD::UDIV, ISD::UREM}, - {MVT::i8, MVT::i16, MVT::i32}, Custom); + if (RV64LegalI32 && Subtarget.is64Bit()) + 
setOperationAction({ISD::SDIV, ISD::UDIV, ISD::SREM, ISD::UREM}, MVT::i32, + Promote); + } else if (Subtarget.is64Bit()) { + if (!RV64LegalI32) + setOperationAction({ISD::SDIV, ISD::UDIV, ISD::UREM}, + {MVT::i8, MVT::i16, MVT::i32}, Custom); + } + + if (RV64LegalI32 && Subtarget.is64Bit()) { + setOperationAction({ISD::MULHS, ISD::MULHU}, MVT::i32, Expand); + setOperationAction( + {ISD::SDIVREM, ISD::UDIVREM, ISD::SMUL_LOHI, ISD::UMUL_LOHI}, MVT::i32, + Expand); + } setOperationAction( {ISD::SDIVREM, ISD::UDIVREM, ISD::SMUL_LOHI, ISD::UMUL_LOHI}, XLenVT, @@ -290,7 +329,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM, Custom); if (Subtarget.hasStdExtZbb() || Subtarget.hasStdExtZbkb()) { - if (Subtarget.is64Bit()) + if (!RV64LegalI32 && Subtarget.is64Bit()) setOperationAction({ISD::ROTL, ISD::ROTR}, MVT::i32, Custom); } else if (Subtarget.hasVendorXTHeadBb()) { if (Subtarget.is64Bit()) @@ -298,6 +337,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM, setOperationAction({ISD::ROTL, ISD::ROTR}, XLenVT, Custom); } else { setOperationAction({ISD::ROTL, ISD::ROTR}, XLenVT, Expand); + if (RV64LegalI32 && Subtarget.is64Bit()) + setOperationAction({ISD::ROTL, ISD::ROTR}, MVT::i32, Expand); } // With Zbb we have an XLen rev8 instruction, but not GREVI. So we'll @@ -307,6 +348,13 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM, Subtarget.hasVendorXTHeadBb()) ? Legal : Expand); + if (RV64LegalI32 && Subtarget.is64Bit()) + setOperationAction(ISD::BSWAP, MVT::i32, + (Subtarget.hasStdExtZbb() || Subtarget.hasStdExtZbkb() || + Subtarget.hasVendorXTHeadBb()) + ? Promote + : Expand); + // Zbkb can use rev8+brev8 to implement bitreverse. setOperationAction(ISD::BITREVERSE, XLenVT, Subtarget.hasStdExtZbkb() ? Custom : Expand); @@ -314,30 +362,54 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM, if (Subtarget.hasStdExtZbb()) { setOperationAction({ISD::SMIN, ISD::SMAX, ISD::UMIN, ISD::UMAX}, XLenVT, Legal); + if (RV64LegalI32 && Subtarget.is64Bit()) + setOperationAction({ISD::SMIN, ISD::SMAX, ISD::UMIN, ISD::UMAX}, MVT::i32, + Promote); - if (Subtarget.is64Bit()) - setOperationAction( - {ISD::CTTZ, ISD::CTTZ_ZERO_UNDEF, ISD::CTLZ, ISD::CTLZ_ZERO_UNDEF}, - MVT::i32, Custom); + if (Subtarget.is64Bit()) { + if (RV64LegalI32) + setOperationAction(ISD::CTTZ, MVT::i32, Legal); + else + setOperationAction({ISD::CTTZ, ISD::CTTZ_ZERO_UNDEF}, MVT::i32, Custom); + } } else { - setOperationAction({ISD::CTTZ, ISD::CTLZ, ISD::CTPOP}, XLenVT, Expand); + setOperationAction({ISD::CTTZ, ISD::CTPOP}, XLenVT, Expand); + if (RV64LegalI32 && Subtarget.is64Bit()) + setOperationAction({ISD::CTTZ, ISD::CTPOP}, MVT::i32, Expand); } - if (Subtarget.hasVendorXTHeadBb()) { - setOperationAction(ISD::CTLZ, XLenVT, Legal); - + if (Subtarget.hasStdExtZbb() || Subtarget.hasVendorXTHeadBb()) { // We need the custom lowering to make sure that the resulting sequence // for the 32bit case is efficient on 64bit targets. - if (Subtarget.is64Bit()) - setOperationAction({ISD::CTLZ, ISD::CTLZ_ZERO_UNDEF}, MVT::i32, Custom); + if (Subtarget.is64Bit()) { + if (RV64LegalI32) { + setOperationAction(ISD::CTLZ, MVT::i32, + Subtarget.hasStdExtZbb() ? 
Legal : Promote); + if (!Subtarget.hasStdExtZbb()) + setOperationAction(ISD::CTLZ_ZERO_UNDEF, MVT::i32, Promote); + } else + setOperationAction({ISD::CTLZ, ISD::CTLZ_ZERO_UNDEF}, MVT::i32, Custom); + } + } else { + setOperationAction(ISD::CTLZ, XLenVT, Expand); + if (RV64LegalI32 && Subtarget.is64Bit()) + setOperationAction(ISD::CTLZ, MVT::i32, Expand); } - if (Subtarget.is64Bit()) + if (!RV64LegalI32 && Subtarget.is64Bit() && + !Subtarget.hasShortForwardBranchOpt()) setOperationAction(ISD::ABS, MVT::i32, Custom); + // We can use PseudoCCSUB to implement ABS. + if (Subtarget.hasShortForwardBranchOpt()) + setOperationAction(ISD::ABS, XLenVT, Legal); + if (!Subtarget.hasVendorXTHeadCondMov()) setOperationAction(ISD::SELECT, XLenVT, Custom); + if (RV64LegalI32 && Subtarget.is64Bit()) + setOperationAction(ISD::SELECT, MVT::i32, Promote); + static const unsigned FPLegalNodeTypes[] = { ISD::FMINNUM, ISD::FMAXNUM, ISD::LRINT, ISD::LLRINT, ISD::LROUND, ISD::LLROUND, @@ -361,7 +433,18 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM, if (Subtarget.hasStdExtZfhOrZfhminOrZhinxOrZhinxmin()) setOperationAction(ISD::BITCAST, MVT::i16, Custom); - + + static const unsigned ZfhminZfbfminPromoteOps[] = { + ISD::FMINNUM, ISD::FMAXNUM, ISD::FADD, + ISD::FSUB, ISD::FMUL, ISD::FMA, + ISD::FDIV, ISD::FSQRT, ISD::FABS, + ISD::FNEG, ISD::STRICT_FMA, ISD::STRICT_FADD, + ISD::STRICT_FSUB, ISD::STRICT_FMUL, ISD::STRICT_FDIV, + ISD::STRICT_FSQRT, ISD::STRICT_FSETCC, ISD::STRICT_FSETCCS, + ISD::SETCC, ISD::FCEIL, ISD::FFLOOR, + ISD::FTRUNC, ISD::FRINT, ISD::FROUND, + ISD::FROUNDEVEN, ISD::SELECT}; + if (Subtarget.hasStdExtZfbfmin()) { setOperationAction(ISD::BITCAST, MVT::i16, Custom); setOperationAction(ISD::BITCAST, MVT::bf16, Custom); @@ -369,6 +452,13 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM, setOperationAction(ISD::FP_EXTEND, MVT::f32, Custom); setOperationAction(ISD::FP_EXTEND, MVT::f64, Custom); setOperationAction(ISD::ConstantFP, MVT::bf16, Expand); + setOperationAction(ISD::SELECT_CC, MVT::bf16, Expand); + setOperationAction(ISD::BR_CC, MVT::bf16, Expand); + setOperationAction(ZfhminZfbfminPromoteOps, MVT::bf16, Promote); + setOperationAction(ISD::FREM, MVT::bf16, Promote); + // FIXME: Need to promote bf16 FCOPYSIGN to f32, but the + // DAGCombiner::visitFP_ROUND probably needs improvements first. + setOperationAction(ISD::FCOPYSIGN, MVT::bf16, Expand); } if (Subtarget.hasStdExtZfhOrZfhminOrZhinxOrZhinxmin()) { @@ -379,18 +469,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM, setOperationAction(ISD::SELECT, MVT::f16, Custom); setOperationAction(ISD::IS_FPCLASS, MVT::f16, Custom); } else { - static const unsigned ZfhminPromoteOps[] = { - ISD::FMINNUM, ISD::FMAXNUM, ISD::FADD, - ISD::FSUB, ISD::FMUL, ISD::FMA, - ISD::FDIV, ISD::FSQRT, ISD::FABS, - ISD::FNEG, ISD::STRICT_FMA, ISD::STRICT_FADD, - ISD::STRICT_FSUB, ISD::STRICT_FMUL, ISD::STRICT_FDIV, - ISD::STRICT_FSQRT, ISD::STRICT_FSETCC, ISD::STRICT_FSETCCS, - ISD::SETCC, ISD::FCEIL, ISD::FFLOOR, - ISD::FTRUNC, ISD::FRINT, ISD::FROUND, - ISD::FROUNDEVEN, ISD::SELECT}; - - setOperationAction(ZfhminPromoteOps, MVT::f16, Promote); + setOperationAction(ZfhminZfbfminPromoteOps, MVT::f16, Promote); setOperationAction({ISD::STRICT_LRINT, ISD::STRICT_LLRINT, ISD::STRICT_LROUND, ISD::STRICT_LLROUND}, MVT::f16, Legal); @@ -409,7 +488,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM, Subtarget.hasStdExtZfa() ? 
Legal : Promote); setOperationAction({ISD::FREM, ISD::FPOW, ISD::FPOWI, ISD::FCOS, ISD::FSIN, ISD::FSINCOS, ISD::FEXP, - ISD::FEXP2, ISD::FLOG, ISD::FLOG2, ISD::FLOG10}, + ISD::FEXP2, ISD::FEXP10, ISD::FLOG, ISD::FLOG2, + ISD::FLOG10}, MVT::f16, Promote); // FIXME: Need to promote f16 STRICT_* to f32 libcalls, but we don't have @@ -439,6 +519,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM, setOperationAction(FPOpToExpand, MVT::f32, Expand); setLoadExtAction(ISD::EXTLOAD, MVT::f32, MVT::f16, Expand); setTruncStoreAction(MVT::f32, MVT::f16, Expand); + setLoadExtAction(ISD::EXTLOAD, MVT::f32, MVT::bf16, Expand); + setTruncStoreAction(MVT::f32, MVT::bf16, Expand); setOperationAction(ISD::IS_FPCLASS, MVT::f32, Custom); setOperationAction(ISD::BF16_TO_FP, MVT::f32, Custom); setOperationAction(ISD::FP_TO_BF16, MVT::f32, @@ -481,6 +563,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM, setOperationAction(FPOpToExpand, MVT::f64, Expand); setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::f16, Expand); setTruncStoreAction(MVT::f64, MVT::f16, Expand); + setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::bf16, Expand); + setTruncStoreAction(MVT::f64, MVT::bf16, Expand); setOperationAction(ISD::IS_FPCLASS, MVT::f64, Custom); setOperationAction(ISD::BF16_TO_FP, MVT::f64, Custom); setOperationAction(ISD::FP_TO_BF16, MVT::f64, @@ -504,6 +588,11 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM, ISD::STRICT_UINT_TO_FP, ISD::STRICT_SINT_TO_FP}, XLenVT, Legal); + if (RV64LegalI32 && Subtarget.is64Bit()) + setOperationAction({ISD::STRICT_FP_TO_UINT, ISD::STRICT_FP_TO_SINT, + ISD::STRICT_UINT_TO_FP, ISD::STRICT_SINT_TO_FP}, + MVT::i32, Legal); + setOperationAction(ISD::GET_ROUNDING, XLenVT, Custom); setOperationAction(ISD::SET_ROUNDING, MVT::Other, Custom); } @@ -548,6 +637,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM, setBooleanVectorContents(ZeroOrOneBooleanContent); setOperationAction(ISD::VSCALE, XLenVT, Custom); + if (RV64LegalI32 && Subtarget.is64Bit()) + setOperationAction(ISD::VSCALE, MVT::i32, Custom); // RVV intrinsics may have illegal operands. // We also need to custom legalize vmv.x.s. @@ -576,7 +667,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM, ISD::VP_FP_TO_UINT, ISD::VP_SETCC, ISD::VP_SIGN_EXTEND, ISD::VP_ZERO_EXTEND, ISD::VP_TRUNCATE, ISD::VP_SMIN, ISD::VP_SMAX, ISD::VP_UMIN, ISD::VP_UMAX, - ISD::VP_ABS}; + ISD::VP_ABS, ISD::EXPERIMENTAL_VP_REVERSE}; static const unsigned FloatingPointVPOps[] = { ISD::VP_FADD, ISD::VP_FSUB, ISD::VP_FMUL, @@ -588,7 +679,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM, ISD::VP_SQRT, ISD::VP_FMINNUM, ISD::VP_FMAXNUM, ISD::VP_FCEIL, ISD::VP_FFLOOR, ISD::VP_FROUND, ISD::VP_FROUNDEVEN, ISD::VP_FCOPYSIGN, ISD::VP_FROUNDTOZERO, - ISD::VP_FRINT, ISD::VP_FNEARBYINT}; + ISD::VP_FRINT, ISD::VP_FNEARBYINT, ISD::VP_IS_FPCLASS, + ISD::EXPERIMENTAL_VP_REVERSE}; static const unsigned IntegerVecReduceOps[] = { ISD::VECREDUCE_ADD, ISD::VECREDUCE_AND, ISD::VECREDUCE_OR, @@ -659,9 +751,9 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM, // Expand all extending loads to types larger than this, and truncating // stores from types larger than this. 
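The swap of the two type operands in the loop below follows the TargetLowering signatures, which take the in-register value type before the memory type: setLoadExtAction(ExtType, ValVT, MemVT, Action) and setTruncStoreAction(ValVT, MemVT, Action). A minimal scalar illustration, not taken from the change:

    // An extending load producing i32 from an i8 memory location, and a
    // truncating store of an i32 value into i8 memory, both marked Expand:
    setLoadExtAction(ISD::ZEXTLOAD, MVT::i32, MVT::i8, Expand);
    setTruncStoreAction(MVT::i32, MVT::i8, Expand);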
for (MVT OtherVT : MVT::integer_scalable_vector_valuetypes()) { - setTruncStoreAction(OtherVT, VT, Expand); - setLoadExtAction({ISD::EXTLOAD, ISD::SEXTLOAD, ISD::ZEXTLOAD}, OtherVT, - VT, Expand); + setTruncStoreAction(VT, OtherVT, Expand); + setLoadExtAction({ISD::EXTLOAD, ISD::SEXTLOAD, ISD::ZEXTLOAD}, VT, + OtherVT, Expand); } setOperationAction({ISD::VP_FP_TO_SINT, ISD::VP_FP_TO_UINT, @@ -673,6 +765,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM, setOperationAction(ISD::VECTOR_REVERSE, VT, Custom); + setOperationAction(ISD::EXPERIMENTAL_VP_REVERSE, VT, Custom); + setOperationPromotedToType( ISD::VECTOR_SPLICE, VT, MVT::getVectorVT(MVT::i8, VT.getVectorElementCount())); @@ -695,8 +789,6 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM, setOperationAction({ISD::SMIN, ISD::SMAX, ISD::UMIN, ISD::UMAX}, VT, Legal); - setOperationAction({ISD::VP_FSHL, ISD::VP_FSHR}, VT, Expand); - // Custom-lower extensions and truncations from/to mask types. setOperationAction({ISD::ANY_EXTEND, ISD::SIGN_EXTEND, ISD::ZERO_EXTEND}, VT, Custom); @@ -712,7 +804,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM, VT, Custom); setOperationAction({ISD::FP_TO_SINT_SAT, ISD::FP_TO_UINT_SAT}, VT, Custom); - + setOperationAction({ISD::LRINT, ISD::LLRINT}, VT, Custom); setOperationAction( {ISD::SADDSAT, ISD::UADDSAT, ISD::SSUBSAT, ISD::USUBSAT}, VT, Legal); @@ -751,8 +843,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM, for (MVT OtherVT : MVT::integer_scalable_vector_valuetypes()) { setTruncStoreAction(VT, OtherVT, Expand); - setLoadExtAction({ISD::EXTLOAD, ISD::SEXTLOAD, ISD::ZEXTLOAD}, OtherVT, - VT, Expand); + setLoadExtAction({ISD::EXTLOAD, ISD::SEXTLOAD, ISD::ZEXTLOAD}, VT, + OtherVT, Expand); } setOperationAction(ISD::VECTOR_DEINTERLEAVE, VT, Custom); @@ -761,15 +853,22 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM, // Splice setOperationAction(ISD::VECTOR_SPLICE, VT, Custom); + if (Subtarget.hasStdExtZvkb()) { + setOperationAction(ISD::BSWAP, VT, Legal); + setOperationAction(ISD::VP_BSWAP, VT, Custom); + } else { + setOperationAction({ISD::BSWAP, ISD::VP_BSWAP}, VT, Expand); + setOperationAction({ISD::ROTL, ISD::ROTR}, VT, Expand); + } + if (Subtarget.hasStdExtZvbb()) { - setOperationAction({ISD::BITREVERSE, ISD::BSWAP}, VT, Legal); - setOperationAction({ISD::VP_BITREVERSE, ISD::VP_BSWAP}, VT, Custom); + setOperationAction(ISD::BITREVERSE, VT, Legal); + setOperationAction(ISD::VP_BITREVERSE, VT, Custom); setOperationAction({ISD::VP_CTLZ, ISD::VP_CTLZ_ZERO_UNDEF, ISD::VP_CTTZ, ISD::VP_CTTZ_ZERO_UNDEF, ISD::VP_CTPOP}, VT, Custom); } else { - setOperationAction({ISD::BITREVERSE, ISD::BSWAP}, VT, Expand); - setOperationAction({ISD::VP_BITREVERSE, ISD::VP_BSWAP}, VT, Expand); + setOperationAction({ISD::BITREVERSE, ISD::VP_BITREVERSE}, VT, Expand); setOperationAction({ISD::CTLZ, ISD::CTTZ, ISD::CTPOP}, VT, Expand); setOperationAction({ISD::VP_CTLZ, ISD::VP_CTLZ_ZERO_UNDEF, ISD::VP_CTTZ, ISD::VP_CTTZ_ZERO_UNDEF, ISD::VP_CTPOP}, @@ -784,8 +883,6 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM, ISD::VP_CTLZ_ZERO_UNDEF, ISD::VP_CTTZ_ZERO_UNDEF}, VT, Custom); } - - setOperationAction({ISD::ROTL, ISD::ROTR}, VT, Expand); } } @@ -802,6 +899,27 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM, ISD::SETGT, ISD::SETOGT, ISD::SETGE, ISD::SETOGE, }; + // TODO: support more ops. 
+ static const unsigned ZvfhminPromoteOps[] = { + ISD::FMINNUM, ISD::FMAXNUM, ISD::FADD, ISD::FSUB, + ISD::FMUL, ISD::FMA, ISD::FDIV, ISD::FSQRT, + ISD::FABS, ISD::FNEG, ISD::FCOPYSIGN, ISD::FCEIL, + ISD::FFLOOR, ISD::FROUND, ISD::FROUNDEVEN, ISD::FRINT, + ISD::FNEARBYINT, ISD::IS_FPCLASS, ISD::SETCC, ISD::FMAXIMUM, + ISD::FMINIMUM, ISD::STRICT_FADD, ISD::STRICT_FSUB, ISD::STRICT_FMUL, + ISD::STRICT_FDIV, ISD::STRICT_FSQRT, ISD::STRICT_FMA}; + + // TODO: support more vp ops. + static const unsigned ZvfhminPromoteVPOps[] = { + ISD::VP_FADD, ISD::VP_FSUB, ISD::VP_FMUL, + ISD::VP_FDIV, ISD::VP_FNEG, ISD::VP_FABS, + ISD::VP_FMA, ISD::VP_REDUCE_FADD, ISD::VP_REDUCE_SEQ_FADD, + ISD::VP_REDUCE_FMIN, ISD::VP_REDUCE_FMAX, ISD::VP_SQRT, + ISD::VP_FMINNUM, ISD::VP_FMAXNUM, ISD::VP_FCEIL, + ISD::VP_FFLOOR, ISD::VP_FROUND, ISD::VP_FROUNDEVEN, + ISD::VP_FCOPYSIGN, ISD::VP_FROUNDTOZERO, ISD::VP_FRINT, + ISD::VP_FNEARBYINT, ISD::VP_SETCC}; + // Sets common operation actions on RVV floating-point vector types. const auto SetCommonVFPActions = [&](MVT VT) { setOperationAction(ISD::SPLAT_VECTOR, VT, Legal); @@ -817,6 +935,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM, setCondCodeAction(VFPCCToExpand, VT, Expand); setOperationAction({ISD::FMINNUM, ISD::FMAXNUM}, VT, Legal); + setOperationAction({ISD::FMAXIMUM, ISD::FMINIMUM}, VT, Custom); setOperationAction({ISD::FTRUNC, ISD::FCEIL, ISD::FFLOOR, ISD::FROUND, ISD::FROUNDEVEN, ISD::FRINT, ISD::FNEARBYINT, @@ -833,6 +952,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM, setOperationAction(ISD::FSINCOS, VT, Expand); setOperationAction(ISD::FEXP, VT, Expand); setOperationAction(ISD::FEXP2, VT, Expand); + setOperationAction(ISD::FEXP10, VT, Expand); setOperationAction(ISD::FLOG, VT, Expand); setOperationAction(ISD::FLOG2, VT, Expand); setOperationAction(ISD::FLOG10, VT, Expand); @@ -891,6 +1011,38 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM, continue; SetCommonVFPActions(VT); } + } else if (Subtarget.hasVInstructionsF16Minimal()) { + for (MVT VT : F16VecVTs) { + if (!isTypeLegal(VT)) + continue; + setOperationAction({ISD::FP_ROUND, ISD::FP_EXTEND}, VT, Custom); + setOperationAction({ISD::STRICT_FP_ROUND, ISD::STRICT_FP_EXTEND}, VT, + Custom); + setOperationAction({ISD::VP_FP_ROUND, ISD::VP_FP_EXTEND}, VT, Custom); + setOperationAction({ISD::VP_MERGE, ISD::VP_SELECT, ISD::SELECT}, VT, + Custom); + setOperationAction(ISD::SELECT_CC, VT, Expand); + setOperationAction({ISD::SINT_TO_FP, ISD::UINT_TO_FP, + ISD::VP_SINT_TO_FP, ISD::VP_UINT_TO_FP}, + VT, Custom); + setOperationAction({ISD::CONCAT_VECTORS, ISD::INSERT_SUBVECTOR, + ISD::EXTRACT_SUBVECTOR, ISD::SCALAR_TO_VECTOR}, + VT, Custom); + setOperationAction(ISD::SPLAT_VECTOR, VT, Custom); + // load/store + setOperationAction({ISD::LOAD, ISD::STORE}, VT, Custom); + + // Custom split nxv32f16 since nxv32f32 if not legal. + if (VT == MVT::nxv32f16) { + setOperationAction(ZvfhminPromoteOps, VT, Custom); + setOperationAction(ZvfhminPromoteVPOps, VT, Custom); + continue; + } + // Add more promote ops. 
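A worked size check behind the nxv32f16 special case above: nxv32f16 needs 32 * 16 = 512 bits per vscale unit (LMUL 8 with the 64-bit RVVBitsPerBlock), so promoting its elements to f32 would ask for 32 * 32 = 1024 bits (LMUL 16), larger than the biggest register group; that single type therefore stays Custom instead of being promoted to an f32 vector like the narrower F16 types handled next.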
+ MVT F32VecVT = MVT::getVectorVT(MVT::f32, VT.getVectorElementCount()); + setOperationPromotedToType(ZvfhminPromoteOps, VT, F32VecVT); + setOperationPromotedToType(ZvfhminPromoteVPOps, VT, F32VecVT); + } } if (Subtarget.hasVInstructionsF32()) { @@ -922,8 +1074,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM, setOperationAction(Op, VT, Expand); for (MVT OtherVT : MVT::integer_fixedlen_vector_valuetypes()) { setTruncStoreAction(VT, OtherVT, Expand); - setLoadExtAction({ISD::EXTLOAD, ISD::SEXTLOAD, ISD::ZEXTLOAD}, - OtherVT, VT, Expand); + setLoadExtAction({ISD::EXTLOAD, ISD::SEXTLOAD, ISD::ZEXTLOAD}, VT, + OtherVT, Expand); } // Custom lower fixed vector undefs to scalable vector undefs to avoid @@ -986,6 +1138,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM, setOperationAction({ISD::VP_FP_TO_SINT, ISD::VP_FP_TO_UINT, ISD::VP_SETCC, ISD::VP_TRUNCATE}, VT, Custom); + + setOperationAction(ISD::EXPERIMENTAL_VP_REVERSE, VT, Custom); continue; } @@ -1039,13 +1193,22 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM, setOperationAction(IntegerVPOps, VT, Custom); - // Lower CTLZ_ZERO_UNDEF and CTTZ_ZERO_UNDEF if element of VT in the - // range of f32. - EVT FloatVT = MVT::getVectorVT(MVT::f32, VT.getVectorElementCount()); - if (isTypeLegal(FloatVT)) - setOperationAction( - {ISD::CTLZ, ISD::CTLZ_ZERO_UNDEF, ISD::CTTZ_ZERO_UNDEF}, VT, - Custom); + if (Subtarget.hasStdExtZvkb()) + setOperationAction({ISD::BSWAP, ISD::ROTL, ISD::ROTR}, VT, Custom); + + if (Subtarget.hasStdExtZvbb()) { + setOperationAction({ISD::BITREVERSE, ISD::CTLZ, ISD::CTLZ_ZERO_UNDEF, + ISD::CTTZ, ISD::CTTZ_ZERO_UNDEF, ISD::CTPOP}, + VT, Custom); + } else { + // Lower CTLZ_ZERO_UNDEF and CTTZ_ZERO_UNDEF if element of VT in the + // range of f32. + EVT FloatVT = MVT::getVectorVT(MVT::f32, VT.getVectorElementCount()); + if (isTypeLegal(FloatVT)) + setOperationAction( + {ISD::CTLZ, ISD::CTLZ_ZERO_UNDEF, ISD::CTTZ_ZERO_UNDEF}, VT, + Custom); + } } for (MVT VT : MVT::fp_fixedlen_vector_valuetypes()) { @@ -1066,6 +1229,34 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM, // expansion to a build_vector of 0s. setOperationAction(ISD::UNDEF, VT, Custom); + if (VT.getVectorElementType() == MVT::f16 && + !Subtarget.hasVInstructionsF16()) { + setOperationAction({ISD::FP_ROUND, ISD::FP_EXTEND}, VT, Custom); + setOperationAction({ISD::STRICT_FP_ROUND, ISD::STRICT_FP_EXTEND}, VT, + Custom); + setOperationAction({ISD::VP_FP_ROUND, ISD::VP_FP_EXTEND}, VT, Custom); + setOperationAction( + {ISD::VP_MERGE, ISD::VP_SELECT, ISD::VSELECT, ISD::SELECT}, VT, + Custom); + setOperationAction({ISD::SINT_TO_FP, ISD::UINT_TO_FP, + ISD::VP_SINT_TO_FP, ISD::VP_UINT_TO_FP}, + VT, Custom); + setOperationAction({ISD::CONCAT_VECTORS, ISD::INSERT_SUBVECTOR, + ISD::EXTRACT_SUBVECTOR, ISD::SCALAR_TO_VECTOR}, + VT, Custom); + setOperationAction({ISD::LOAD, ISD::STORE}, VT, Custom); + setOperationAction(ISD::SPLAT_VECTOR, VT, Custom); + MVT F32VecVT = MVT::getVectorVT(MVT::f32, VT.getVectorElementCount()); + // Don't promote f16 vector operations to f32 if f32 vector type is + // not legal. + // TODO: could split the f16 vector into two vectors and do promotion. + if (!isTypeLegal(F32VecVT)) + continue; + setOperationPromotedToType(ZvfhminPromoteOps, VT, F32VecVT); + setOperationPromotedToType(ZvfhminPromoteVPOps, VT, F32VecVT); + continue; + } + // We use EXTRACT_SUBVECTOR as a "cast" from scalable to fixed. 
setOperationAction({ISD::INSERT_SUBVECTOR, ISD::EXTRACT_SUBVECTOR}, VT, Custom); @@ -1088,7 +1279,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM, setOperationAction({ISD::FADD, ISD::FSUB, ISD::FMUL, ISD::FDIV, ISD::FNEG, ISD::FABS, ISD::FCOPYSIGN, ISD::FSQRT, ISD::FMA, ISD::FMINNUM, ISD::FMAXNUM, - ISD::IS_FPCLASS}, + ISD::IS_FPCLASS, ISD::FMAXIMUM, ISD::FMINIMUM}, VT, Custom); setOperationAction({ISD::FP_ROUND, ISD::FP_EXTEND}, VT, Custom); @@ -1132,14 +1323,20 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM, } } + if (Subtarget.hasStdExtA()) { + setOperationAction(ISD::ATOMIC_LOAD_SUB, XLenVT, Expand); + if (RV64LegalI32 && Subtarget.is64Bit()) + setOperationAction(ISD::ATOMIC_LOAD_SUB, MVT::i32, Expand); + } + if (Subtarget.hasForcedAtomics()) { - // Set atomic rmw/cas operations to expand to force __sync libcalls. + // Force __sync libcalls to be emitted for atomic rmw/cas operations. setOperationAction( {ISD::ATOMIC_CMP_SWAP, ISD::ATOMIC_SWAP, ISD::ATOMIC_LOAD_ADD, ISD::ATOMIC_LOAD_SUB, ISD::ATOMIC_LOAD_AND, ISD::ATOMIC_LOAD_OR, ISD::ATOMIC_LOAD_XOR, ISD::ATOMIC_LOAD_NAND, ISD::ATOMIC_LOAD_MIN, ISD::ATOMIC_LOAD_MAX, ISD::ATOMIC_LOAD_UMIN, ISD::ATOMIC_LOAD_UMAX}, - XLenVT, Expand); + XLenVT, LibCall); } if (Subtarget.hasVendorXTHeadMemIdx()) { @@ -1166,8 +1363,6 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM, setPrefFunctionAlignment(Subtarget.getPrefFunctionAlignment()); setPrefLoopAlignment(Subtarget.getPrefLoopAlignment()); - setMinimumJumpTableEntries(5); - // Jumps are expensive, compared to logic setJumpIsExpensive(); @@ -1197,7 +1392,9 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM, setTargetDAGCombine({ISD::FCOPYSIGN, ISD::MGATHER, ISD::MSCATTER, ISD::VP_GATHER, ISD::VP_SCATTER, ISD::SRA, ISD::SRL, ISD::SHL, ISD::STORE, ISD::SPLAT_VECTOR, - ISD::CONCAT_VECTORS}); + ISD::BUILD_VECTOR, ISD::CONCAT_VECTORS, + ISD::EXPERIMENTAL_VP_REVERSE, ISD::MUL, + ISD::INSERT_VECTOR_ELT}); if (Subtarget.hasVendorXTHeadMemPair()) setTargetDAGCombine({ISD::LOAD, ISD::STORE}); if (Subtarget.useRVVForFixedLengthVectors()) @@ -1239,7 +1436,7 @@ bool RISCVTargetLowering::shouldExpandGetVectorLength(EVT TripCountVT, return true; // Don't allow VF=1 if those types are't legal. - if (VF < RISCV::RVVBitsPerBlock / Subtarget.getELEN()) + if (VF < RISCV::RVVBitsPerBlock / Subtarget.getELen()) return true; // VLEN=32 support is incomplete. @@ -1677,7 +1874,7 @@ bool RISCVTargetLowering::shouldConvertConstantLoadToIntImm(const APInt &Imm, // replace. If we don't support unaligned scalar mem, prefer the constant // pool. // TODO: Can the caller pass down the alignment? - if (!Subtarget.enableUnalignedScalarMem()) + if (!Subtarget.hasFastUnalignedAccess()) return true; // Prefer to keep the load if it would require many instructions. @@ -1686,8 +1883,7 @@ bool RISCVTargetLowering::shouldConvertConstantLoadToIntImm(const APInt &Imm, // TODO: Should we keep the load only when we're definitely going to emit a // constant pool? - RISCVMatInt::InstSeq Seq = - RISCVMatInt::generateInstSeq(Val, Subtarget.getFeatureBits()); + RISCVMatInt::InstSeq Seq = RISCVMatInt::generateInstSeq(Val, Subtarget); return Seq.size() <= Subtarget.getMaxBuildIntsCost(); } @@ -1844,8 +2040,11 @@ bool RISCVTargetLowering::shouldScalarizeBinop(SDValue VecOp) const { // If the vector op is supported, but the scalar op is not, the transform may // not be worthwhile. 
+ // Permit a vector binary operation can be converted to scalar binary + // operation which is custom lowered with illegal type. EVT ScalarVT = VecVT.getScalarType(); - return isOperationLegalOrCustomOrPromote(Opc, ScalarVT); + return isOperationLegalOrCustomOrPromote(Opc, ScalarVT) || + isOperationCustom(Opc, ScalarVT); } bool RISCVTargetLowering::isOffsetFoldingLegal( @@ -1857,11 +2056,17 @@ bool RISCVTargetLowering::isOffsetFoldingLegal( return false; } -// Returns 0-31 if the fli instruction is available for the type and this is -// legal FP immediate for the type. Returns -1 otherwise. -int RISCVTargetLowering::getLegalZfaFPImm(const APFloat &Imm, EVT VT) const { +// Return one of the followings: +// (1) `{0-31 value, false}` if FLI is available for Imm's type and FP value. +// (2) `{0-31 value, true}` if Imm is negative and FLI is available for its +// positive counterpart, which will be materialized from the first returned +// element. The second returned element indicated that there should be a FNEG +// followed. +// (3) `{-1, _}` if there is no way FLI can be used to materialize Imm. +std::pair<int, bool> RISCVTargetLowering::getLegalZfaFPImm(const APFloat &Imm, + EVT VT) const { if (!Subtarget.hasStdExtZfa()) - return -1; + return std::make_pair(-1, false); bool IsSupportedVT = false; if (VT == MVT::f16) { @@ -1874,9 +2079,14 @@ int RISCVTargetLowering::getLegalZfaFPImm(const APFloat &Imm, EVT VT) const { } if (!IsSupportedVT) - return -1; + return std::make_pair(-1, false); - return RISCVLoadFPImm::getLoadFPImm(Imm); + int Index = RISCVLoadFPImm::getLoadFPImm(Imm); + if (Index < 0 && Imm.isNegative()) + // Try the combination of its positive counterpart + FNEG. + return std::make_pair(RISCVLoadFPImm::getLoadFPImm(-Imm), true); + else + return std::make_pair(Index, false); } bool RISCVTargetLowering::isFPImmLegal(const APFloat &Imm, EVT VT, @@ -1888,11 +2098,13 @@ bool RISCVTargetLowering::isFPImmLegal(const APFloat &Imm, EVT VT, IsLegalVT = Subtarget.hasStdExtFOrZfinx(); else if (VT == MVT::f64) IsLegalVT = Subtarget.hasStdExtDOrZdinx(); + else if (VT == MVT::bf16) + IsLegalVT = Subtarget.hasStdExtZfbfmin(); if (!IsLegalVT) return false; - if (getLegalZfaFPImm(Imm, VT) >= 0) + if (getLegalZfaFPImm(Imm, VT).first >= 0) return true; // Cannot create a 64 bit floating-point immediate value for rv32. @@ -1901,14 +2113,17 @@ bool RISCVTargetLowering::isFPImmLegal(const APFloat &Imm, EVT VT, // -0.0 can be created by fmv + fneg. return Imm.isZero(); } - // Special case: the cost for -0.0 is 1. - int Cost = Imm.isNegZero() - ? 1 - : RISCVMatInt::getIntMatCost(Imm.bitcastToAPInt(), - Subtarget.getXLen(), - Subtarget.getFeatureBits()); - // If the constantpool data is already in cache, only Cost 1 is cheaper. - return Cost < FPImmCost; + + // Special case: fmv + fneg + if (Imm.isNegZero()) + return true; + + // Building an integer and then converting requires a fmv at the end of + // the integer sequence. + const int Cost = + 1 + RISCVMatInt::getIntMatCost(Imm.bitcastToAPInt(), Subtarget.getXLen(), + Subtarget); + return Cost <= FPImmCost; } // TODO: This is very conservative. 
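A sketch of how a caller could interpret the new pair returned by getLegalZfaFPImm (hypothetical usage; -2.0 is chosen because its positive counterpart is an FLI-encodable value while -2.0 itself is not):

    // Hypothetical caller, assuming Zfa and the D extension are present:
    auto [Index, NeedsFNeg] = TLI.getLegalZfaFPImm(APFloat(-2.0), MVT::f64);
    // Index >= 0 names the fli.d table entry for +2.0, and NeedsFNeg == true
    // records that the loaded constant must still be negated, i.e. the
    // materialization becomes "fli.d fa0, 2.0" followed by "fneg.d fa0, fa0".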
@@ -1953,7 +2168,12 @@ MVT RISCVTargetLowering::getRegisterTypeForCallingConv(LLVMContext &Context, !Subtarget.hasStdExtZfhOrZfhminOrZhinxOrZhinxmin()) return MVT::f32; - return TargetLowering::getRegisterTypeForCallingConv(Context, CC, VT); + MVT PartVT = TargetLowering::getRegisterTypeForCallingConv(Context, CC, VT); + + if (RV64LegalI32 && Subtarget.is64Bit() && PartVT == MVT::i32) + return MVT::i64; + + return PartVT; } unsigned RISCVTargetLowering::getNumRegistersForCallingConv(LLVMContext &Context, @@ -1968,6 +2188,21 @@ unsigned RISCVTargetLowering::getNumRegistersForCallingConv(LLVMContext &Context return TargetLowering::getNumRegistersForCallingConv(Context, CC, VT); } +unsigned RISCVTargetLowering::getVectorTypeBreakdownForCallingConv( + LLVMContext &Context, CallingConv::ID CC, EVT VT, EVT &IntermediateVT, + unsigned &NumIntermediates, MVT &RegisterVT) const { + unsigned NumRegs = TargetLowering::getVectorTypeBreakdownForCallingConv( + Context, CC, VT, IntermediateVT, NumIntermediates, RegisterVT); + + if (RV64LegalI32 && Subtarget.is64Bit() && IntermediateVT == MVT::i32) + IntermediateVT = MVT::i64; + + if (RV64LegalI32 && Subtarget.is64Bit() && RegisterVT == MVT::i32) + RegisterVT = MVT::i64; + + return NumRegs; +} + // Changes the condition code and swaps operands if necessary, so the SetCC // operation matches one of the comparisons supported directly by branches // in the RISC-V ISA. May adjust compares to favor compare with 0 over compare @@ -2010,7 +2245,7 @@ static void translateSetCCForBranch(const SDLoc &DL, SDValue &LHS, SDValue &RHS, } break; case ISD::SETLT: - // Convert X < 1 to 0 <= X. + // Convert X < 1 to 0 >= X. if (C == 1) { RHS = LHS; LHS = DAG.getConstant(0, DL, RHS.getValueType()); @@ -2228,7 +2463,7 @@ static bool useRVVForFixedLengthVectorVT(MVT VT, return false; break; case MVT::f16: - if (!Subtarget.hasVInstructionsF16()) + if (!Subtarget.hasVInstructionsF16Minimal()) return false; break; case MVT::f32: @@ -2242,7 +2477,7 @@ static bool useRVVForFixedLengthVectorVT(MVT VT, } // Reject elements larger than ELEN. 
- if (EltVT.getSizeInBits() > Subtarget.getELEN()) + if (EltVT.getSizeInBits() > Subtarget.getELen()) return false; unsigned LMul = divideCeil(VT.getSizeInBits(), MinVLen); @@ -2271,7 +2506,7 @@ static MVT getContainerForFixedLengthVector(const TargetLowering &TLI, MVT VT, "Expected legal fixed length vector!"); unsigned MinVLen = Subtarget.getRealMinVLen(); - unsigned MaxELen = Subtarget.getELEN(); + unsigned MaxELen = Subtarget.getELen(); MVT EltVT = VT.getVectorElementType(); switch (EltVT.SimpleTy) { @@ -2354,6 +2589,15 @@ static SDValue getVLOp(uint64_t NumElts, const SDLoc &DL, SelectionDAG &DAG, } static std::pair<SDValue, SDValue> +getDefaultScalableVLOps(MVT VecVT, const SDLoc &DL, SelectionDAG &DAG, + const RISCVSubtarget &Subtarget) { + assert(VecVT.isScalableVector() && "Expecting a scalable vector"); + SDValue VL = DAG.getRegister(RISCV::X0, Subtarget.getXLenVT()); + SDValue Mask = getAllOnesMask(VecVT, VL, DL, DAG); + return {Mask, VL}; +} + +static std::pair<SDValue, SDValue> getDefaultVLOps(uint64_t NumElts, MVT ContainerVT, const SDLoc &DL, SelectionDAG &DAG, const RISCVSubtarget &Subtarget) { assert(ContainerVT.isScalableVector() && "Expecting scalable container type"); @@ -2373,18 +2617,7 @@ getDefaultVLOps(MVT VecVT, MVT ContainerVT, const SDLoc &DL, SelectionDAG &DAG, return getDefaultVLOps(VecVT.getVectorNumElements(), ContainerVT, DL, DAG, Subtarget); assert(ContainerVT.isScalableVector() && "Expecting scalable container type"); - MVT XLenVT = Subtarget.getXLenVT(); - SDValue VL = DAG.getRegister(RISCV::X0, XLenVT); - SDValue Mask = getAllOnesMask(ContainerVT, VL, DL, DAG); - return {Mask, VL}; -} - -// As above but assuming the given type is a scalable vector type. -static std::pair<SDValue, SDValue> -getDefaultScalableVLOps(MVT VecVT, const SDLoc &DL, SelectionDAG &DAG, - const RISCVSubtarget &Subtarget) { - assert(VecVT.isScalableVector() && "Expecting a scalable vector"); - return getDefaultVLOps(VecVT, VecVT, DL, DAG, Subtarget); + return getDefaultScalableVLOps(ContainerVT, DL, DAG, Subtarget); } SDValue RISCVTargetLowering::computeVLMax(MVT VecVT, const SDLoc &DL, @@ -2407,6 +2640,51 @@ bool RISCVTargetLowering::shouldExpandBuildVectorWithShuffles( return false; } +InstructionCost RISCVTargetLowering::getLMULCost(MVT VT) const { + // TODO: Here assume reciprocal throughput is 1 for LMUL_1, it is + // implementation-defined. + if (!VT.isVector()) + return InstructionCost::getInvalid(); + unsigned DLenFactor = Subtarget.getDLenFactor(); + unsigned Cost; + if (VT.isScalableVector()) { + unsigned LMul; + bool Fractional; + std::tie(LMul, Fractional) = + RISCVVType::decodeVLMUL(RISCVTargetLowering::getLMUL(VT)); + if (Fractional) + Cost = LMul <= DLenFactor ? (DLenFactor / LMul) : 1; + else + Cost = (LMul * DLenFactor); + } else { + Cost = divideCeil(VT.getSizeInBits(), Subtarget.getRealMinVLen() / DLenFactor); + } + return Cost; +} + + +/// Return the cost of a vrgather.vv instruction for the type VT. vrgather.vv +/// is generally quadratic in the number of vreg implied by LMUL. Note that +/// operand (index and possibly mask) are handled separately. +InstructionCost RISCVTargetLowering::getVRGatherVVCost(MVT VT) const { + return getLMULCost(VT) * getLMULCost(VT); +} + +/// Return the cost of a vrgather.vi (or vx) instruction for the type VT. +/// vrgather.vi/vx may be linear in the number of vregs implied by LMUL, +/// or may track the vrgather.vv cost. It is implementation-dependent. 
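To put numbers on these LMUL-based cost helpers (assuming DLEN == VLEN, i.e. a DLenFactor of 1): nxv8i32 is an LMUL = 4 type, so getLMULCost returns 4, a single slide on it is likewise costed at 4, and a vrgather.vv at 4 * 4 = 16, while a fractional type such as nxv1i32 (LMUL = 1/2) comes out at 1.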
+InstructionCost RISCVTargetLowering::getVRGatherVICost(MVT VT) const { + return getLMULCost(VT); +} + +/// Return the cost of a vslidedown.vi/vx or vslideup.vi/vx instruction +/// for the type VT. (This does not cover the vslide1up or vslide1down +/// variants.) Slides may be linear in the number of vregs implied by LMUL, +/// or may track the vrgather.vv cost. It is implementation-dependent. +InstructionCost RISCVTargetLowering::getVSlideCost(MVT VT) const { + return getLMULCost(VT); +} + static SDValue lowerFP_TO_INT_SAT(SDValue Op, SelectionDAG &DAG, const RISCVSubtarget &Subtarget) { // RISC-V FP-to-int conversions saturate to the destination register size, but @@ -2420,9 +2698,10 @@ static SDValue lowerFP_TO_INT_SAT(SDValue Op, SelectionDAG &DAG, bool IsSigned = Op.getOpcode() == ISD::FP_TO_SINT_SAT; if (!DstVT.isVector()) { - // In absense of Zfh, promote f16 to f32, then saturate the result. - if (Src.getSimpleValueType() == MVT::f16 && - !Subtarget.hasStdExtZfhOrZhinx()) { + // For bf16 or for f16 in absense of Zfh, promote to f32, then saturate + // the result. + if ((Src.getValueType() == MVT::f16 && !Subtarget.hasStdExtZfhOrZhinx()) || + Src.getValueType() == MVT::bf16) { Src = DAG.getNode(ISD::FP_EXTEND, SDLoc(Op), MVT::f32, Src); } @@ -2778,6 +3057,31 @@ lowerFTRUNC_FCEIL_FFLOOR_FROUND(SDValue Op, SelectionDAG &DAG, DAG.getTargetConstant(FRM, DL, Subtarget.getXLenVT())); } +// Expand vector LRINT and LLRINT by converting to the integer domain. +static SDValue lowerVectorXRINT(SDValue Op, SelectionDAG &DAG, + const RISCVSubtarget &Subtarget) { + MVT VT = Op.getSimpleValueType(); + assert(VT.isVector() && "Unexpected type"); + + SDLoc DL(Op); + SDValue Src = Op.getOperand(0); + MVT ContainerVT = VT; + + if (VT.isFixedLengthVector()) { + ContainerVT = getContainerForFixedLengthVector(DAG, VT, Subtarget); + Src = convertToScalableVector(ContainerVT, Src, DAG, Subtarget); + } + + auto [Mask, VL] = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget); + SDValue Truncated = + DAG.getNode(RISCVISD::VFCVT_X_F_VL, DL, ContainerVT, Src, Mask, VL); + + if (!VT.isFixedLengthVector()) + return Truncated; + + return convertFromScalableVector(VT, Truncated, DAG, Subtarget); +} + static SDValue getVSlidedown(SelectionDAG &DAG, const RISCVSubtarget &Subtarget, const SDLoc &DL, EVT VT, SDValue Merge, SDValue Op, @@ -2802,6 +3106,14 @@ getVSlideup(SelectionDAG &DAG, const RISCVSubtarget &Subtarget, const SDLoc &DL, return DAG.getNode(RISCVISD::VSLIDEUP_VL, DL, VT, Ops); } +static MVT getLMUL1VT(MVT VT) { + assert(VT.getVectorElementType().getSizeInBits() <= 64 && + "Unexpected vector MVT"); + return MVT::getScalableVectorVT( + VT.getVectorElementType(), + RISCV::RVVBitsPerBlock / VT.getVectorElementType().getSizeInBits()); +} + struct VIDSequence { int64_t StepNumerator; unsigned StepDenominator; @@ -2975,8 +3287,124 @@ static SDValue matchSplatAsGather(SDValue SplatVal, MVT VT, const SDLoc &DL, return convertFromScalableVector(VT, Gather, DAG, Subtarget); } -static SDValue lowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG, - const RISCVSubtarget &Subtarget) { + +/// Try and optimize BUILD_VECTORs with "dominant values" - these are values +/// which constitute a large proportion of the elements. In such cases we can +/// splat a vector with the dominant element and make up the shortfall with +/// INSERT_VECTOR_ELTs. Returns SDValue if not profitable. +/// Note that this includes vectors of 2 elements by association. 
The +/// upper-most element is the "dominant" one, allowing us to use a splat to +/// "insert" the upper element, and an insert of the lower element at position +/// 0, which improves codegen. +static SDValue lowerBuildVectorViaDominantValues(SDValue Op, SelectionDAG &DAG, + const RISCVSubtarget &Subtarget) { + MVT VT = Op.getSimpleValueType(); + assert(VT.isFixedLengthVector() && "Unexpected vector!"); + + MVT ContainerVT = getContainerForFixedLengthVector(DAG, VT, Subtarget); + + SDLoc DL(Op); + auto [Mask, VL] = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget); + + MVT XLenVT = Subtarget.getXLenVT(); + unsigned NumElts = Op.getNumOperands(); + + SDValue DominantValue; + unsigned MostCommonCount = 0; + DenseMap<SDValue, unsigned> ValueCounts; + unsigned NumUndefElts = + count_if(Op->op_values(), [](const SDValue &V) { return V.isUndef(); }); + + // Track the number of scalar loads we know we'd be inserting, estimated as + // any non-zero floating-point constant. Other kinds of element are either + // already in registers or are materialized on demand. The threshold at which + // a vector load is more desirable than several scalar materializion and + // vector-insertion instructions is not known. + unsigned NumScalarLoads = 0; + + for (SDValue V : Op->op_values()) { + if (V.isUndef()) + continue; + + ValueCounts.insert(std::make_pair(V, 0)); + unsigned &Count = ValueCounts[V]; + if (0 == Count) + if (auto *CFP = dyn_cast<ConstantFPSDNode>(V)) + NumScalarLoads += !CFP->isExactlyValue(+0.0); + + // Is this value dominant? In case of a tie, prefer the highest element as + // it's cheaper to insert near the beginning of a vector than it is at the + // end. + if (++Count >= MostCommonCount) { + DominantValue = V; + MostCommonCount = Count; + } + } + + assert(DominantValue && "Not expecting an all-undef BUILD_VECTOR"); + unsigned NumDefElts = NumElts - NumUndefElts; + unsigned DominantValueCountThreshold = NumDefElts <= 2 ? 0 : NumDefElts - 2; + + // Don't perform this optimization when optimizing for size, since + // materializing elements and inserting them tends to cause code bloat. + if (!DAG.shouldOptForSize() && NumScalarLoads < NumElts && + (NumElts != 2 || ISD::isBuildVectorOfConstantSDNodes(Op.getNode())) && + ((MostCommonCount > DominantValueCountThreshold) || + (ValueCounts.size() <= Log2_32(NumDefElts)))) { + // Start by splatting the most common element. + SDValue Vec = DAG.getSplatBuildVector(VT, DL, DominantValue); + + DenseSet<SDValue> Processed{DominantValue}; + + // We can handle an insert into the last element (of a splat) via + // v(f)slide1down. This is slightly better than the vslideup insert + // lowering as it avoids the need for a vector group temporary. It + // is also better than using vmerge.vx as it avoids the need to + // materialize the mask in a vector register. + if (SDValue LastOp = Op->getOperand(Op->getNumOperands() - 1); + !LastOp.isUndef() && ValueCounts[LastOp] == 1 && + LastOp != DominantValue) { + Vec = convertToScalableVector(ContainerVT, Vec, DAG, Subtarget); + auto OpCode = + VT.isFloatingPoint() ? 
RISCVISD::VFSLIDE1DOWN_VL : RISCVISD::VSLIDE1DOWN_VL; + if (!VT.isFloatingPoint()) + LastOp = DAG.getNode(ISD::ANY_EXTEND, DL, XLenVT, LastOp); + Vec = DAG.getNode(OpCode, DL, ContainerVT, DAG.getUNDEF(ContainerVT), Vec, + LastOp, Mask, VL); + Vec = convertFromScalableVector(VT, Vec, DAG, Subtarget); + Processed.insert(LastOp); + } + + MVT SelMaskTy = VT.changeVectorElementType(MVT::i1); + for (const auto &OpIdx : enumerate(Op->ops())) { + const SDValue &V = OpIdx.value(); + if (V.isUndef() || !Processed.insert(V).second) + continue; + if (ValueCounts[V] == 1) { + Vec = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, VT, Vec, V, + DAG.getConstant(OpIdx.index(), DL, XLenVT)); + } else { + // Blend in all instances of this value using a VSELECT, using a + // mask where each bit signals whether that element is the one + // we're after. + SmallVector<SDValue> Ops; + transform(Op->op_values(), std::back_inserter(Ops), [&](SDValue V1) { + return DAG.getConstant(V == V1, DL, XLenVT); + }); + Vec = DAG.getNode(ISD::VSELECT, DL, VT, + DAG.getBuildVector(SelMaskTy, DL, Ops), + DAG.getSplatBuildVector(VT, DL, V), Vec); + } + } + + return Vec; + } + + return SDValue(); +} + +static SDValue lowerBuildVectorOfConstants(SDValue Op, SelectionDAG &DAG, + const RISCVSubtarget &Subtarget) { MVT VT = Op.getSimpleValueType(); assert(VT.isFixedLengthVector() && "Unexpected vector!"); @@ -3008,94 +3436,68 @@ static SDValue lowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG, // XLenVT if we're producing a v8i1. This results in more consistent // codegen across RV32 and RV64. unsigned NumViaIntegerBits = std::clamp(NumElts, 8u, Subtarget.getXLen()); - NumViaIntegerBits = std::min(NumViaIntegerBits, Subtarget.getELEN()); - if (ISD::isBuildVectorOfConstantSDNodes(Op.getNode())) { - // If we have to use more than one INSERT_VECTOR_ELT then this - // optimization is likely to increase code size; avoid peforming it in - // such a case. We can use a load from a constant pool in this case. - if (DAG.shouldOptForSize() && NumElts > NumViaIntegerBits) - return SDValue(); - // Now we can create our integer vector type. Note that it may be larger - // than the resulting mask type: v4i1 would use v1i8 as its integer type. - unsigned IntegerViaVecElts = divideCeil(NumElts, NumViaIntegerBits); - MVT IntegerViaVecVT = - MVT::getVectorVT(MVT::getIntegerVT(NumViaIntegerBits), - IntegerViaVecElts); - - uint64_t Bits = 0; - unsigned BitPos = 0, IntegerEltIdx = 0; - SmallVector<SDValue, 8> Elts(IntegerViaVecElts); - - for (unsigned I = 0; I < NumElts;) { - SDValue V = Op.getOperand(I); - bool BitValue = !V.isUndef() && cast<ConstantSDNode>(V)->getZExtValue(); - Bits |= ((uint64_t)BitValue << BitPos); - ++BitPos; - ++I; + NumViaIntegerBits = std::min(NumViaIntegerBits, Subtarget.getELen()); + // If we have to use more than one INSERT_VECTOR_ELT then this + // optimization is likely to increase code size; avoid peforming it in + // such a case. We can use a load from a constant pool in this case. + if (DAG.shouldOptForSize() && NumElts > NumViaIntegerBits) + return SDValue(); + // Now we can create our integer vector type. Note that it may be larger + // than the resulting mask type: v4i1 would use v1i8 as its integer type. + unsigned IntegerViaVecElts = divideCeil(NumElts, NumViaIntegerBits); + MVT IntegerViaVecVT = + MVT::getVectorVT(MVT::getIntegerVT(NumViaIntegerBits), + IntegerViaVecElts); - // Once we accumulate enough bits to fill our scalar type or process the - // last element, insert into our vector and clear our accumulated data. 
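As a worked example of the bit-packing loop (constants chosen for illustration): a v8i1 build_vector of <1,0,1,1,0,0,1,0> accumulates element 0 into bit 0 upward, producing the i8 value 0b01001101 (0x4d); that scalar becomes a one-element v1i8 build_vector which is finally bitcast back to v8i1.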
- if (I % NumViaIntegerBits == 0 || I == NumElts) { - if (NumViaIntegerBits <= 32) - Bits = SignExtend64<32>(Bits); - SDValue Elt = DAG.getConstant(Bits, DL, XLenVT); - Elts[IntegerEltIdx] = Elt; - Bits = 0; - BitPos = 0; - IntegerEltIdx++; - } - } + uint64_t Bits = 0; + unsigned BitPos = 0, IntegerEltIdx = 0; + SmallVector<SDValue, 8> Elts(IntegerViaVecElts); - SDValue Vec = DAG.getBuildVector(IntegerViaVecVT, DL, Elts); + for (unsigned I = 0; I < NumElts;) { + SDValue V = Op.getOperand(I); + bool BitValue = !V.isUndef() && cast<ConstantSDNode>(V)->getZExtValue(); + Bits |= ((uint64_t)BitValue << BitPos); + ++BitPos; + ++I; - if (NumElts < NumViaIntegerBits) { - // If we're producing a smaller vector than our minimum legal integer - // type, bitcast to the equivalent (known-legal) mask type, and extract - // our final mask. - assert(IntegerViaVecVT == MVT::v1i8 && "Unexpected mask vector type"); - Vec = DAG.getBitcast(MVT::v8i1, Vec); - Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, Vec, - DAG.getConstant(0, DL, XLenVT)); - } else { - // Else we must have produced an integer type with the same size as the - // mask type; bitcast for the final result. - assert(VT.getSizeInBits() == IntegerViaVecVT.getSizeInBits()); - Vec = DAG.getBitcast(VT, Vec); + // Once we accumulate enough bits to fill our scalar type or process the + // last element, insert into our vector and clear our accumulated data. + if (I % NumViaIntegerBits == 0 || I == NumElts) { + if (NumViaIntegerBits <= 32) + Bits = SignExtend64<32>(Bits); + SDValue Elt = DAG.getConstant(Bits, DL, XLenVT); + Elts[IntegerEltIdx] = Elt; + Bits = 0; + BitPos = 0; + IntegerEltIdx++; } - - return Vec; } - // A BUILD_VECTOR can be lowered as a SETCC. For each fixed-length mask - // vector type, we have a legal equivalently-sized i8 type, so we can use - // that. - MVT WideVecVT = VT.changeVectorElementType(MVT::i8); - SDValue VecZero = DAG.getConstant(0, DL, WideVecVT); + SDValue Vec = DAG.getBuildVector(IntegerViaVecVT, DL, Elts); - SDValue WideVec; - if (SDValue Splat = cast<BuildVectorSDNode>(Op)->getSplatValue()) { - // For a splat, perform a scalar truncate before creating the wider - // vector. - assert(Splat.getValueType() == XLenVT && - "Unexpected type for i1 splat value"); - Splat = DAG.getNode(ISD::AND, DL, XLenVT, Splat, - DAG.getConstant(1, DL, XLenVT)); - WideVec = DAG.getSplatBuildVector(WideVecVT, DL, Splat); + if (NumElts < NumViaIntegerBits) { + // If we're producing a smaller vector than our minimum legal integer + // type, bitcast to the equivalent (known-legal) mask type, and extract + // our final mask. + assert(IntegerViaVecVT == MVT::v1i8 && "Unexpected mask vector type"); + Vec = DAG.getBitcast(MVT::v8i1, Vec); + Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, Vec, + DAG.getConstant(0, DL, XLenVT)); } else { - SmallVector<SDValue, 8> Ops(Op->op_values()); - WideVec = DAG.getBuildVector(WideVecVT, DL, Ops); - SDValue VecOne = DAG.getConstant(1, DL, WideVecVT); - WideVec = DAG.getNode(ISD::AND, DL, WideVecVT, WideVec, VecOne); + // Else we must have produced an integer type with the same size as the + // mask type; bitcast for the final result. + assert(VT.getSizeInBits() == IntegerViaVecVT.getSizeInBits()); + Vec = DAG.getBitcast(VT, Vec); } - return DAG.getSetCC(DL, VT, WideVec, VecZero, ISD::SETNE); + return Vec; } if (SDValue Splat = cast<BuildVectorSDNode>(Op)->getSplatValue()) { - if (auto Gather = matchSplatAsGather(Splat, VT, DL, DAG, Subtarget)) - return Gather; unsigned Opc = VT.isFloatingPoint() ? 
RISCVISD::VFMV_V_F_VL : RISCVISD::VMV_V_X_VL; + if (!VT.isFloatingPoint()) + Splat = DAG.getNode(ISD::ANY_EXTEND, DL, XLenVT, Splat); Splat = DAG.getNode(Opc, DL, ContainerVT, DAG.getUNDEF(ContainerVT), Splat, VL); return convertFromScalableVector(VT, Splat, DAG, Subtarget); @@ -3113,12 +3515,13 @@ static SDValue lowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG, bool Negate = false; int64_t SplatStepVal = StepNumerator; unsigned StepOpcode = ISD::MUL; - if (StepNumerator != 1) { - if (isPowerOf2_64(std::abs(StepNumerator))) { - Negate = StepNumerator < 0; - StepOpcode = ISD::SHL; - SplatStepVal = Log2_64(std::abs(StepNumerator)); - } + // Exclude INT64_MIN to avoid passing it to std::abs. We won't optimize it + // anyway as the shift of 63 won't fit in uimm5. + if (StepNumerator != 1 && StepNumerator != INT64_MIN && + isPowerOf2_64(std::abs(StepNumerator))) { + Negate = StepNumerator < 0; + StepOpcode = ISD::SHL; + SplatStepVal = Log2_64(std::abs(StepNumerator)); } // Only emit VIDs with suitably-small steps/addends. We use imm5 is a @@ -3141,18 +3544,16 @@ static SDValue lowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG, VID = convertFromScalableVector(VIDVT, VID, DAG, Subtarget); if ((StepOpcode == ISD::MUL && SplatStepVal != 1) || (StepOpcode == ISD::SHL && SplatStepVal != 0)) { - SDValue SplatStep = DAG.getSplatBuildVector( - VIDVT, DL, DAG.getConstant(SplatStepVal, DL, XLenVT)); + SDValue SplatStep = DAG.getConstant(SplatStepVal, DL, VIDVT); VID = DAG.getNode(StepOpcode, DL, VIDVT, VID, SplatStep); } if (StepDenominator != 1) { - SDValue SplatStep = DAG.getSplatBuildVector( - VIDVT, DL, DAG.getConstant(Log2_64(StepDenominator), DL, XLenVT)); + SDValue SplatStep = + DAG.getConstant(Log2_64(StepDenominator), DL, VIDVT); VID = DAG.getNode(ISD::SRL, DL, VIDVT, VID, SplatStep); } if (Addend != 0 || Negate) { - SDValue SplatAddend = DAG.getSplatBuildVector( - VIDVT, DL, DAG.getConstant(Addend, DL, XLenVT)); + SDValue SplatAddend = DAG.getConstant(Addend, DL, VIDVT); VID = DAG.getNode(Negate ? ISD::SUB : ISD::ADD, DL, VIDVT, SplatAddend, VID); } @@ -3164,6 +3565,48 @@ static SDValue lowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG, } } + // For very small build_vectors, use a single scalar insert of a constant. + // TODO: Base this on constant rematerialization cost, not size. + const unsigned EltBitSize = VT.getScalarSizeInBits(); + if (VT.getSizeInBits() <= 32 && + ISD::isBuildVectorOfConstantSDNodes(Op.getNode())) { + MVT ViaIntVT = MVT::getIntegerVT(VT.getSizeInBits()); + assert((ViaIntVT == MVT::i16 || ViaIntVT == MVT::i32) && + "Unexpected sequence type"); + // If we can use the original VL with the modified element type, this + // means we only have a VTYPE toggle, not a VL toggle. TODO: Should this + // be moved into InsertVSETVLI? + unsigned ViaVecLen = + (Subtarget.getRealMinVLen() >= VT.getSizeInBits() * NumElts) ? NumElts : 1; + MVT ViaVecVT = MVT::getVectorVT(ViaIntVT, ViaVecLen); + + uint64_t EltMask = maskTrailingOnes<uint64_t>(EltBitSize); + uint64_t SplatValue = 0; + // Construct the amalgamated value at this larger vector type. + for (const auto &OpIdx : enumerate(Op->op_values())) { + const auto &SeqV = OpIdx.value(); + if (!SeqV.isUndef()) + SplatValue |= ((cast<ConstantSDNode>(SeqV)->getZExtValue() & EltMask) + << (OpIdx.index() * EltBitSize)); + } + + // On RV64, sign-extend from 32 to 64 bits where possible in order to + // achieve better constant materializion. 
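For instance (an illustrative case, not drawn from the change): a constant v4i8 build_vector of <1, 2, 3, 4> has a total size of 32 bits, so ViaIntVT is i32 and the loop above assembles SplatValue = 0x04030201 with element 0 in the low byte; that single i32 is written into lane 0 of an i32 vector and the result is bitcast back to v4i8, replacing four separate element inserts with one scalar materialization.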
+ if (Subtarget.is64Bit() && ViaIntVT == MVT::i32) + SplatValue = SignExtend64<32>(SplatValue); + + SDValue Vec = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, ViaVecVT, + DAG.getUNDEF(ViaVecVT), + DAG.getConstant(SplatValue, DL, XLenVT), + DAG.getConstant(0, DL, XLenVT)); + if (ViaVecLen != 1) + Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, + MVT::getVectorVT(ViaIntVT, 1), Vec, + DAG.getConstant(0, DL, XLenVT)); + return DAG.getBitcast(VT, Vec); + } + + // Attempt to detect "hidden" splats, which only reveal themselves as splats // when re-interpreted as a vector with a larger element type. For example, // v4i16 = build_vector i16 0, i16 1, i16 0, i16 1 @@ -3172,7 +3615,6 @@ static SDValue lowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG, // TODO: This optimization could also work on non-constant splats, but it // would require bit-manipulation instructions to construct the splat value. SmallVector<SDValue> Sequence; - unsigned EltBitSize = VT.getScalarSizeInBits(); const auto *BV = cast<BuildVectorSDNode>(Op); if (VT.isInteger() && EltBitSize < 64 && ISD::isBuildVectorOfConstantSDNodes(Op.getNode()) && @@ -3180,11 +3622,19 @@ static SDValue lowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG, (Sequence.size() * EltBitSize) <= 64) { unsigned SeqLen = Sequence.size(); MVT ViaIntVT = MVT::getIntegerVT(EltBitSize * SeqLen); - MVT ViaVecVT = MVT::getVectorVT(ViaIntVT, NumElts / SeqLen); assert((ViaIntVT == MVT::i16 || ViaIntVT == MVT::i32 || ViaIntVT == MVT::i64) && "Unexpected sequence type"); + // If we can use the original VL with the modified element type, this + // means we only have a VTYPE toggle, not a VL toggle. TODO: Should this + // be moved into InsertVSETVLI? + const unsigned RequiredVL = NumElts / SeqLen; + const unsigned ViaVecLen = + (Subtarget.getRealMinVLen() >= ViaIntVT.getSizeInBits() * NumElts) ? + NumElts : RequiredVL; + MVT ViaVecVT = MVT::getVectorVT(ViaIntVT, ViaVecLen); + unsigned EltIdx = 0; uint64_t EltMask = maskTrailingOnes<uint64_t>(EltBitSize); uint64_t SplatValue = 0; @@ -3218,94 +3668,171 @@ static SDValue lowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG, DAG.getUNDEF(ViaContainerVT), DAG.getConstant(SplatValue, DL, XLenVT), ViaVL); Splat = convertFromScalableVector(ViaVecVT, Splat, DAG, Subtarget); + if (ViaVecLen != RequiredVL) + Splat = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, + MVT::getVectorVT(ViaIntVT, RequiredVL), Splat, + DAG.getConstant(0, DL, XLenVT)); return DAG.getBitcast(VT, Splat); } } - // Try and optimize BUILD_VECTORs with "dominant values" - these are values - // which constitute a large proportion of the elements. In such cases we can - // splat a vector with the dominant element and make up the shortfall with - // INSERT_VECTOR_ELTs. - // Note that this includes vectors of 2 elements by association. The - // upper-most element is the "dominant" one, allowing us to use a splat to - // "insert" the upper element, and an insert of the lower element at position - // 0, which improves codegen. - SDValue DominantValue; - unsigned MostCommonCount = 0; - DenseMap<SDValue, unsigned> ValueCounts; - unsigned NumUndefElts = - count_if(Op->op_values(), [](const SDValue &V) { return V.isUndef(); }); + // If the number of signbits allows, see if we can lower as a <N x i8>. + // Our main goal here is to reduce LMUL (and thus work) required to + // build the constant, but we will also narrow if the resulting + // narrow vector is known to materialize cheaply. + // TODO: We really should be costing the smaller vector. There are + // profitable cases this misses. 
+ if (EltBitSize > 8 && VT.isInteger() && + (NumElts <= 4 || VT.getSizeInBits() > Subtarget.getRealMinVLen())) { + unsigned SignBits = DAG.ComputeNumSignBits(Op); + if (EltBitSize - SignBits < 8) { + SDValue Source = DAG.getBuildVector(VT.changeVectorElementType(MVT::i8), + DL, Op->ops()); + Source = convertToScalableVector(ContainerVT.changeVectorElementType(MVT::i8), + Source, DAG, Subtarget); + SDValue Res = DAG.getNode(RISCVISD::VSEXT_VL, DL, ContainerVT, Source, Mask, VL); + return convertFromScalableVector(VT, Res, DAG, Subtarget); + } + } - // Track the number of scalar loads we know we'd be inserting, estimated as - // any non-zero floating-point constant. Other kinds of element are either - // already in registers or are materialized on demand. The threshold at which - // a vector load is more desirable than several scalar materializion and - // vector-insertion instructions is not known. - unsigned NumScalarLoads = 0; + if (SDValue Res = lowerBuildVectorViaDominantValues(Op, DAG, Subtarget)) + return Res; - for (SDValue V : Op->op_values()) { - if (V.isUndef()) - continue; + // For constant vectors, use generic constant pool lowering. Otherwise, + // we'd have to materialize constants in GPRs just to move them into the + // vector. + return SDValue(); +} - ValueCounts.insert(std::make_pair(V, 0)); - unsigned &Count = ValueCounts[V]; - if (0 == Count) - if (auto *CFP = dyn_cast<ConstantFPSDNode>(V)) - NumScalarLoads += !CFP->isExactlyValue(+0.0); +static SDValue lowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG, + const RISCVSubtarget &Subtarget) { + MVT VT = Op.getSimpleValueType(); + assert(VT.isFixedLengthVector() && "Unexpected vector!"); - // Is this value dominant? In case of a tie, prefer the highest element as - // it's cheaper to insert near the beginning of a vector than it is at the - // end. - if (++Count >= MostCommonCount) { - DominantValue = V; - MostCommonCount = Count; + if (ISD::isBuildVectorOfConstantSDNodes(Op.getNode()) || + ISD::isBuildVectorOfConstantFPSDNodes(Op.getNode())) + return lowerBuildVectorOfConstants(Op, DAG, Subtarget); + + MVT ContainerVT = getContainerForFixedLengthVector(DAG, VT, Subtarget); + + SDLoc DL(Op); + auto [Mask, VL] = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget); + + MVT XLenVT = Subtarget.getXLenVT(); + + if (VT.getVectorElementType() == MVT::i1) { + // A BUILD_VECTOR can be lowered as a SETCC. For each fixed-length mask + // vector type, we have a legal equivalently-sized i8 type, so we can use + // that. + MVT WideVecVT = VT.changeVectorElementType(MVT::i8); + SDValue VecZero = DAG.getConstant(0, DL, WideVecVT); + + SDValue WideVec; + if (SDValue Splat = cast<BuildVectorSDNode>(Op)->getSplatValue()) { + // For a splat, perform a scalar truncate before creating the wider + // vector. + Splat = DAG.getNode(ISD::AND, DL, Splat.getValueType(), Splat, + DAG.getConstant(1, DL, Splat.getValueType())); + WideVec = DAG.getSplatBuildVector(WideVecVT, DL, Splat); + } else { + SmallVector<SDValue, 8> Ops(Op->op_values()); + WideVec = DAG.getBuildVector(WideVecVT, DL, Ops); + SDValue VecOne = DAG.getConstant(1, DL, WideVecVT); + WideVec = DAG.getNode(ISD::AND, DL, WideVecVT, WideVec, VecOne); } + + return DAG.getSetCC(DL, VT, WideVec, VecZero, ISD::SETNE); } - assert(DominantValue && "Not expecting an all-undef BUILD_VECTOR"); - unsigned NumDefElts = NumElts - NumUndefElts; - unsigned DominantValueCountThreshold = NumDefElts <= 2 ? 
0 : NumDefElts - 2; + if (SDValue Splat = cast<BuildVectorSDNode>(Op)->getSplatValue()) { + if (auto Gather = matchSplatAsGather(Splat, VT, DL, DAG, Subtarget)) + return Gather; + unsigned Opc = VT.isFloatingPoint() ? RISCVISD::VFMV_V_F_VL + : RISCVISD::VMV_V_X_VL; + if (!VT.isFloatingPoint()) + Splat = DAG.getNode(ISD::ANY_EXTEND, DL, XLenVT, Splat); + Splat = + DAG.getNode(Opc, DL, ContainerVT, DAG.getUNDEF(ContainerVT), Splat, VL); + return convertFromScalableVector(VT, Splat, DAG, Subtarget); + } - // Don't perform this optimization when optimizing for size, since - // materializing elements and inserting them tends to cause code bloat. - if (!DAG.shouldOptForSize() && NumScalarLoads < NumElts && - (NumElts != 2 || ISD::isBuildVectorOfConstantSDNodes(Op.getNode())) && - ((MostCommonCount > DominantValueCountThreshold) || - (ValueCounts.size() <= Log2_32(NumDefElts)))) { - // Start by splatting the most common element. - SDValue Vec = DAG.getSplatBuildVector(VT, DL, DominantValue); + if (SDValue Res = lowerBuildVectorViaDominantValues(Op, DAG, Subtarget)) + return Res; - DenseSet<SDValue> Processed{DominantValue}; - MVT SelMaskTy = VT.changeVectorElementType(MVT::i1); - for (const auto &OpIdx : enumerate(Op->ops())) { - const SDValue &V = OpIdx.value(); - if (V.isUndef() || !Processed.insert(V).second) - continue; - if (ValueCounts[V] == 1) { - Vec = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, VT, Vec, V, - DAG.getConstant(OpIdx.index(), DL, XLenVT)); - } else { - // Blend in all instances of this value using a VSELECT, using a - // mask where each bit signals whether that element is the one - // we're after. - SmallVector<SDValue> Ops; - transform(Op->op_values(), std::back_inserter(Ops), [&](SDValue V1) { - return DAG.getConstant(V == V1, DL, XLenVT); - }); - Vec = DAG.getNode(ISD::VSELECT, DL, VT, - DAG.getBuildVector(SelMaskTy, DL, Ops), - DAG.getSplatBuildVector(VT, DL, V), Vec); - } + // If we're compiling for an exact VLEN value, we can split our work per + // register in the register group. + const unsigned MinVLen = Subtarget.getRealMinVLen(); + const unsigned MaxVLen = Subtarget.getRealMaxVLen(); + if (MinVLen == MaxVLen && VT.getSizeInBits().getKnownMinValue() > MinVLen) { + MVT ElemVT = VT.getVectorElementType(); + unsigned ElemsPerVReg = MinVLen / ElemVT.getFixedSizeInBits(); + EVT ContainerVT = getContainerForFixedLengthVector(DAG, VT, Subtarget); + MVT OneRegVT = MVT::getVectorVT(ElemVT, ElemsPerVReg); + MVT M1VT = getContainerForFixedLengthVector(DAG, OneRegVT, Subtarget); + assert(M1VT == getLMUL1VT(M1VT)); + + // The following semantically builds up a fixed length concat_vector + // of the component build_vectors. We eagerly lower to scalable and + // insert_subvector here to avoid DAG combining it back to a large + // build_vector. 
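A minimal sketch of the per-register split that the loop below performs, assuming VLEN is known to be exactly 128 bits; the element type and count are illustrative, not from the patch.

#include <cstdio>

int main() {
  const unsigned VLen = 128;    // exact VLEN (getRealMinVLen == getRealMaxVLen)
  const unsigned EltBits = 32;  // SEW of the fixed-length build_vector
  const unsigned NumElts = 8;   // v8i32, an LMUL=2 sized value

  // Each vector register holds ElemsPerVReg elements, so each slice can be
  // built as its own single-register (m1) build_vector and inserted at an
  // m1-aligned position in the register group.
  const unsigned ElemsPerVReg = VLen / EltBits; // 4
  for (unsigned I = 0; I < NumElts; I += ElemsPerVReg)
    printf("elements [%u, %u) -> m1 build_vector inserted into register %u\n",
           I, I + ElemsPerVReg, I / ElemsPerVReg);
}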
+ SmallVector<SDValue> BuildVectorOps(Op->op_begin(), Op->op_end()); + unsigned NumOpElts = M1VT.getVectorMinNumElements(); + SDValue Vec = DAG.getUNDEF(ContainerVT); + for (unsigned i = 0; i < VT.getVectorNumElements(); i += ElemsPerVReg) { + auto OneVRegOfOps = ArrayRef(BuildVectorOps).slice(i, ElemsPerVReg); + SDValue SubBV = + DAG.getNode(ISD::BUILD_VECTOR, DL, OneRegVT, OneVRegOfOps); + SubBV = convertToScalableVector(M1VT, SubBV, DAG, Subtarget); + unsigned InsertIdx = (i / ElemsPerVReg) * NumOpElts; + Vec = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, ContainerVT, Vec, SubBV, + DAG.getVectorIdxConstant(InsertIdx, DL)); } + return convertFromScalableVector(VT, Vec, DAG, Subtarget); + } - return Vec; + // Cap the cost at a value linear to the number of elements in the vector. + // The default lowering is to use the stack. The vector store + scalar loads + // is linear in VL. However, at high lmuls vslide1down and vslidedown end up + // being (at least) linear in LMUL. As a result, using the vslidedown + // lowering for every element ends up being VL*LMUL.. + // TODO: Should we be directly costing the stack alternative? Doing so might + // give us a more accurate upper bound. + InstructionCost LinearBudget = VT.getVectorNumElements() * 2; + + // TODO: unify with TTI getSlideCost. + InstructionCost PerSlideCost = 1; + switch (RISCVTargetLowering::getLMUL(ContainerVT)) { + default: break; + case RISCVII::VLMUL::LMUL_2: + PerSlideCost = 2; + break; + case RISCVII::VLMUL::LMUL_4: + PerSlideCost = 4; + break; + case RISCVII::VLMUL::LMUL_8: + PerSlideCost = 8; + break; } - // For constant vectors, use generic constant pool lowering. Otherwise, - // we'd have to materialize constants in GPRs just to move them into the - // vector. - if (ISD::isBuildVectorOfConstantSDNodes(Op.getNode()) || - ISD::isBuildVectorOfConstantFPSDNodes(Op.getNode())) + // TODO: Should we be using the build instseq then cost + evaluate scheme + // we use for integer constants here? + unsigned UndefCount = 0; + for (const SDValue &V : Op->ops()) { + if (V.isUndef()) { + UndefCount++; + continue; + } + if (UndefCount) { + LinearBudget -= PerSlideCost; + UndefCount = 0; + } + LinearBudget -= PerSlideCost; + } + if (UndefCount) { + LinearBudget -= PerSlideCost; + } + + if (LinearBudget < 0) return SDValue(); assert((!VT.isFloatingPoint() || @@ -3314,13 +3841,24 @@ static SDValue lowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG, const unsigned Policy = RISCVII::TAIL_AGNOSTIC | RISCVII::MASK_AGNOSTIC; - SDValue Vec = DAG.getUNDEF(ContainerVT); - unsigned UndefCount = 0; - for (const SDValue &V : Op->ops()) { + SDValue Vec; + UndefCount = 0; + for (SDValue V : Op->ops()) { if (V.isUndef()) { UndefCount++; continue; } + + // Start our sequence with a TA splat in the hopes that hardware is able to + // recognize there's no dependency on the prior value of our temporary + // register. + if (!Vec) { + Vec = DAG.getSplatVector(VT, DL, V); + Vec = convertToScalableVector(ContainerVT, Vec, DAG, Subtarget); + UndefCount = 0; + continue; + } + if (UndefCount) { const SDValue Offset = DAG.getConstant(UndefCount, DL, Subtarget.getXLenVT()); Vec = getVSlidedown(DAG, Subtarget, DL, ContainerVT, DAG.getUNDEF(ContainerVT), @@ -3329,6 +3867,8 @@ static SDValue lowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG, } auto OpCode = VT.isFloatingPoint() ? 
RISCVISD::VFSLIDE1DOWN_VL : RISCVISD::VSLIDE1DOWN_VL; + if (!VT.isFloatingPoint()) + V = DAG.getNode(ISD::ANY_EXTEND, DL, Subtarget.getXLenVT(), V); Vec = DAG.getNode(OpCode, DL, ContainerVT, DAG.getUNDEF(ContainerVT), Vec, V, Mask, VL); } @@ -3353,19 +3893,43 @@ static SDValue splatPartsI64WithVL(const SDLoc &DL, MVT VT, SDValue Passthru, if ((LoC >> 31) == HiC) return DAG.getNode(RISCVISD::VMV_V_X_VL, DL, VT, Passthru, Lo, VL); - // If vl is equal to XLEN_MAX and Hi constant is equal to Lo, we could use - // vmv.v.x whose EEW = 32 to lower it. - if (LoC == HiC && isAllOnesConstant(VL)) { - MVT InterVT = MVT::getVectorVT(MVT::i32, VT.getVectorElementCount() * 2); - // TODO: if vl <= min(VLMAX), we can also do this. But we could not - // access the subtarget here now. - auto InterVec = DAG.getNode( - RISCVISD::VMV_V_X_VL, DL, InterVT, DAG.getUNDEF(InterVT), Lo, - DAG.getRegister(RISCV::X0, MVT::i32)); - return DAG.getNode(ISD::BITCAST, DL, VT, InterVec); + // If vl is equal to VLMAX or fits in 4 bits and Hi constant is equal to Lo, + // we could use vmv.v.x whose EEW = 32 to lower it. This allows us to use + // vlmax vsetvli or vsetivli to change the VL. + // FIXME: Support larger constants? + // FIXME: Support non-constant VLs by saturating? + if (LoC == HiC) { + SDValue NewVL; + if (isAllOnesConstant(VL) || + (isa<RegisterSDNode>(VL) && + cast<RegisterSDNode>(VL)->getReg() == RISCV::X0)) + NewVL = DAG.getRegister(RISCV::X0, MVT::i32); + else if (isa<ConstantSDNode>(VL) && + isUInt<4>(cast<ConstantSDNode>(VL)->getZExtValue())) + NewVL = DAG.getNode(ISD::ADD, DL, VL.getValueType(), VL, VL); + + if (NewVL) { + MVT InterVT = + MVT::getVectorVT(MVT::i32, VT.getVectorElementCount() * 2); + auto InterVec = DAG.getNode(RISCVISD::VMV_V_X_VL, DL, InterVT, + DAG.getUNDEF(InterVT), Lo, + DAG.getRegister(RISCV::X0, MVT::i32)); + return DAG.getNode(ISD::BITCAST, DL, VT, InterVec); + } } } + // Detect cases where Hi is (SRA Lo, 31) which means Hi is Lo sign extended. + if (Hi.getOpcode() == ISD::SRA && Hi.getOperand(0) == Lo && + isa<ConstantSDNode>(Hi.getOperand(1)) && + Hi.getConstantOperandVal(1) == 31) + return DAG.getNode(RISCVISD::VMV_V_X_VL, DL, VT, Passthru, Lo, VL); + + // If the hi bits of the splat are undefined, then it's fine to just splat Lo + // even if it might be sign extended. + if (Hi.isUndef()) + return DAG.getNode(RISCVISD::VMV_V_X_VL, DL, VT, Passthru, Lo, VL); + // Fall back to a stack store and stride x0 vector load. return DAG.getNode(RISCVISD::SPLAT_VECTOR_SPLIT_I64_VL, DL, VT, Passthru, Lo, Hi, VL); @@ -3392,12 +3956,8 @@ static SDValue lowerScalarSplat(SDValue Passthru, SDValue Scalar, SDValue VL, bool HasPassthru = Passthru && !Passthru.isUndef(); if (!HasPassthru && !Passthru) Passthru = DAG.getUNDEF(VT); - if (VT.isFloatingPoint()) { - // If VL is 1, we could use vfmv.s.f. - if (isOneConstant(VL)) - return DAG.getNode(RISCVISD::VFMV_S_F_VL, DL, VT, Passthru, Scalar, VL); + if (VT.isFloatingPoint()) return DAG.getNode(RISCVISD::VFMV_V_F_VL, DL, VT, Passthru, Scalar, VL); - } MVT XLenVT = Subtarget.getXLenVT(); @@ -3410,12 +3970,6 @@ static SDValue lowerScalarSplat(SDValue Passthru, SDValue Scalar, SDValue VL, unsigned ExtOpc = isa<ConstantSDNode>(Scalar) ? ISD::SIGN_EXTEND : ISD::ANY_EXTEND; Scalar = DAG.getNode(ExtOpc, DL, XLenVT, Scalar); - ConstantSDNode *Const = dyn_cast<ConstantSDNode>(Scalar); - // If VL is 1 and the scalar value won't benefit from immediate, we could - // use vmv.s.x. 
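The EEW=32 splat trick in splatPartsI64WithVL above can be sanity-checked on the host for the Lo == Hi case; this byte-level model stands in for the bitcast from the doubled-length i32 vector back to the i64 vector, and the constant is arbitrary.

#include <cstdint>
#include <cstdio>
#include <cstring>

int main() {
  const uint64_t Want = 0x1234567812345678ull; // Lo == Hi case
  const uint32_t Lo = (uint32_t)Want;

  // "Splat" Lo at EEW=32 with twice the element count (vmv.v.x with a doubled
  // or vlmax VL in the patch)...
  uint32_t Words[8];
  for (uint32_t &W : Words)
    W = Lo;

  // ...then reinterpret the same bytes at EEW=64, the bitcast step.
  uint64_t Lanes[4];
  std::memcpy(Lanes, Words, sizeof(Lanes));
  for (uint64_t L : Lanes)
    printf("lane 0x%016llx %s\n", (unsigned long long)L,
           L == Want ? "ok" : "mismatch");
}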
- if (isOneConstant(VL) && - (!Const || isNullConstant(Scalar) || !isInt<5>(Const->getSExtValue()))) - return DAG.getNode(RISCVISD::VMV_S_X_VL, DL, VT, Passthru, Scalar, VL); return DAG.getNode(RISCVISD::VMV_V_X_VL, DL, VT, Passthru, Scalar, VL); } @@ -3430,14 +3984,6 @@ static SDValue lowerScalarSplat(SDValue Passthru, SDValue Scalar, SDValue VL, return splatSplitI64WithVL(DL, VT, Passthru, Scalar, VL, DAG); } -static MVT getLMUL1VT(MVT VT) { - assert(VT.getVectorElementType().getSizeInBits() <= 64 && - "Unexpected vector MVT"); - return MVT::getScalableVectorVT( - VT.getVectorElementType(), - RISCV::RVVBitsPerBlock / VT.getVectorElementType().getSizeInBits()); -} - // This function lowers an insert of a scalar operand Scalar into lane // 0 of the vector regardless of the value of VL. The contents of the // remaining lanes of the result vector are unspecified. VL is assumed @@ -3445,24 +3991,34 @@ static MVT getLMUL1VT(MVT VT) { static SDValue lowerScalarInsert(SDValue Scalar, SDValue VL, MVT VT, const SDLoc &DL, SelectionDAG &DAG, const RISCVSubtarget &Subtarget) { - const MVT XLenVT = Subtarget.getXLenVT(); + assert(VT.isScalableVector() && "Expect VT is scalable vector type."); + const MVT XLenVT = Subtarget.getXLenVT(); SDValue Passthru = DAG.getUNDEF(VT); - if (VT.isFloatingPoint()) { - // TODO: Use vmv.v.i for appropriate constants - // Use M1 or smaller to avoid over constraining register allocation - const MVT M1VT = getLMUL1VT(VT); - auto InnerVT = VT.bitsLE(M1VT) ? VT : M1VT; - SDValue Result = DAG.getNode(RISCVISD::VFMV_S_F_VL, DL, InnerVT, - DAG.getUNDEF(InnerVT), Scalar, VL); - if (VT != InnerVT) - Result = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT, - DAG.getUNDEF(VT), - Result, DAG.getConstant(0, DL, XLenVT)); - return Result; + + if (Scalar.getOpcode() == ISD::EXTRACT_VECTOR_ELT && + isNullConstant(Scalar.getOperand(1))) { + SDValue ExtractedVal = Scalar.getOperand(0); + MVT ExtractedVT = ExtractedVal.getSimpleValueType(); + MVT ExtractedContainerVT = ExtractedVT; + if (ExtractedContainerVT.isFixedLengthVector()) { + ExtractedContainerVT = getContainerForFixedLengthVector( + DAG, ExtractedContainerVT, Subtarget); + ExtractedVal = convertToScalableVector(ExtractedContainerVT, ExtractedVal, + DAG, Subtarget); + } + if (ExtractedContainerVT.bitsLE(VT)) + return DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT, Passthru, ExtractedVal, + DAG.getConstant(0, DL, XLenVT)); + return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, ExtractedVal, + DAG.getConstant(0, DL, XLenVT)); } + if (VT.isFloatingPoint()) + return DAG.getNode(RISCVISD::VFMV_S_F_VL, DL, VT, + DAG.getUNDEF(VT), Scalar, VL); + // Avoid the tricky legalization cases by falling back to using the // splat code which already handles it gracefully. if (!Scalar.getValueType().bitsLE(XLenVT)) @@ -3477,24 +4033,8 @@ static SDValue lowerScalarInsert(SDValue Scalar, SDValue VL, MVT VT, unsigned ExtOpc = isa<ConstantSDNode>(Scalar) ? ISD::SIGN_EXTEND : ISD::ANY_EXTEND; Scalar = DAG.getNode(ExtOpc, DL, XLenVT, Scalar); - // We use a vmv.v.i if possible. We limit this to LMUL1. LMUL2 or - // higher would involve overly constraining the register allocator for - // no purpose. 
- if (ConstantSDNode *Const = dyn_cast<ConstantSDNode>(Scalar)) { - if (!isNullConstant(Scalar) && isInt<5>(Const->getSExtValue()) && - VT.bitsLE(getLMUL1VT(VT))) - return DAG.getNode(RISCVISD::VMV_V_X_VL, DL, VT, Passthru, Scalar, VL); - } - // Use M1 or smaller to avoid over constraining register allocation - const MVT M1VT = getLMUL1VT(VT); - auto InnerVT = VT.bitsLE(M1VT) ? VT : M1VT; - SDValue Result = DAG.getNode(RISCVISD::VMV_S_X_VL, DL, InnerVT, - DAG.getUNDEF(InnerVT), Scalar, VL); - if (VT != InnerVT) - Result = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT, - DAG.getUNDEF(VT), - Result, DAG.getConstant(0, DL, XLenVT)); - return Result; + return DAG.getNode(RISCVISD::VMV_S_X_VL, DL, VT, + DAG.getUNDEF(VT), Scalar, VL); } // Is this a shuffle extracts either the even or odd elements of a vector? @@ -3508,7 +4048,7 @@ static bool isDeinterleaveShuffle(MVT VT, MVT ContainerVT, SDValue V1, SDValue V2, ArrayRef<int> Mask, const RISCVSubtarget &Subtarget) { // Need to be able to widen the vector. - if (VT.getScalarSizeInBits() >= Subtarget.getELEN()) + if (VT.getScalarSizeInBits() >= Subtarget.getELen()) return false; // Both input must be extracts. @@ -3552,7 +4092,7 @@ static bool isDeinterleaveShuffle(MVT VT, MVT ContainerVT, SDValue V1, static bool isInterleaveShuffle(ArrayRef<int> Mask, MVT VT, int &EvenSrc, int &OddSrc, const RISCVSubtarget &Subtarget) { // We need to be able to widen elements to the next larger integer type. - if (VT.getScalarSizeInBits() >= Subtarget.getELEN()) + if (VT.getScalarSizeInBits() >= Subtarget.getELen()) return false; int Size = Mask.size(); @@ -3881,6 +4421,8 @@ static SDValue lowerVECTOR_SHUFFLEAsVSlide1(const SDLoc &DL, MVT VT, auto OpCode = IsVSlidedown ? (VT.isFloatingPoint() ? RISCVISD::VFSLIDE1DOWN_VL : RISCVISD::VSLIDE1DOWN_VL) : (VT.isFloatingPoint() ? RISCVISD::VFSLIDE1UP_VL : RISCVISD::VSLIDE1UP_VL); + if (!VT.isFloatingPoint()) + Splat = DAG.getNode(ISD::ANY_EXTEND, DL, Subtarget.getXLenVT(), Splat); auto Vec = DAG.getNode(OpCode, DL, ContainerVT, DAG.getUNDEF(ContainerVT), convertToScalableVector(ContainerVT, V2, DAG, Subtarget), @@ -3903,7 +4445,7 @@ static SDValue getWideningInterleave(SDValue EvenV, SDValue OddV, OddV = convertToScalableVector(VecContainerVT, OddV, DAG, Subtarget); } - assert(VecVT.getScalarSizeInBits() < Subtarget.getELEN()); + assert(VecVT.getScalarSizeInBits() < Subtarget.getELen()); // We're working with a vector of the same size as the resulting // interleaved vector, but with half the number of elements and @@ -3924,24 +4466,37 @@ static SDValue getWideningInterleave(SDValue EvenV, SDValue OddV, auto [Mask, VL] = getDefaultVLOps(VecVT, VecContainerVT, DL, DAG, Subtarget); SDValue Passthru = DAG.getUNDEF(WideContainerVT); - // Widen EvenV and OddV with 0s and add one copy of OddV to EvenV with - // vwaddu.vv - SDValue Interleaved = DAG.getNode(RISCVISD::VWADDU_VL, DL, WideContainerVT, - EvenV, OddV, Passthru, Mask, VL); + SDValue Interleaved; + if (Subtarget.hasStdExtZvbb()) { + // Interleaved = (OddV << VecVT.getScalarSizeInBits()) + EvenV. 
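Both branches below produce (OddV << SEW) + EvenV; the non-Zvbb fallback leans on the identity Odd * (2^SEW - 1) + Odd + Even == (Odd << SEW) + Even, which this standalone check verifies for one 16-bit element pair with arbitrary values.

#include <cstdint>
#include <cstdio>

int main() {
  const uint16_t Even = 0xBEEF, Odd = 0x1234; // arbitrary element pair

  // Target lane: the odd element ends up in the high half of the widened lane.
  const uint32_t Want = ((uint32_t)Odd << 16) | Even;

  // vwaddu.vv: zero-extend both inputs and add.
  uint32_t Acc = (uint32_t)Even + (uint32_t)Odd;
  // vwmaccu.vx with the all-ones scalar: Acc += Odd * 0xFFFF (widening).
  Acc += (uint32_t)Odd * 0xFFFFu;

  printf("0x%08x vs 0x%08x (%s)\n", Acc, Want, Acc == Want ? "ok" : "mismatch");
}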
+ SDValue OffsetVec = + DAG.getSplatVector(VecContainerVT, DL, + DAG.getConstant(VecVT.getScalarSizeInBits(), DL, + Subtarget.getXLenVT())); + Interleaved = DAG.getNode(RISCVISD::VWSLL_VL, DL, WideContainerVT, OddV, + OffsetVec, Passthru, Mask, VL); + Interleaved = DAG.getNode(RISCVISD::VWADDU_W_VL, DL, WideContainerVT, + Interleaved, EvenV, Passthru, Mask, VL); + } else { + // Widen EvenV and OddV with 0s and add one copy of OddV to EvenV with + // vwaddu.vv + Interleaved = DAG.getNode(RISCVISD::VWADDU_VL, DL, WideContainerVT, EvenV, + OddV, Passthru, Mask, VL); - // Then get OddV * by 2^(VecVT.getScalarSizeInBits() - 1) - SDValue AllOnesVec = DAG.getSplatVector( - VecContainerVT, DL, DAG.getAllOnesConstant(DL, Subtarget.getXLenVT())); - SDValue OddsMul = DAG.getNode(RISCVISD::VWMULU_VL, DL, WideContainerVT, OddV, - AllOnesVec, Passthru, Mask, VL); + // Then get OddV * by 2^(VecVT.getScalarSizeInBits() - 1) + SDValue AllOnesVec = DAG.getSplatVector( + VecContainerVT, DL, DAG.getAllOnesConstant(DL, Subtarget.getXLenVT())); + SDValue OddsMul = DAG.getNode(RISCVISD::VWMULU_VL, DL, WideContainerVT, + OddV, AllOnesVec, Passthru, Mask, VL); - // Add the two together so we get - // (OddV * 0xff...ff) + (OddV + EvenV) - // = (OddV * 0x100...00) + EvenV - // = (OddV << VecVT.getScalarSizeInBits()) + EvenV - // Note the ADD_VL and VLMULU_VL should get selected as vwmaccu.vx - Interleaved = DAG.getNode(RISCVISD::ADD_VL, DL, WideContainerVT, Interleaved, - OddsMul, Passthru, Mask, VL); + // Add the two together so we get + // (OddV * 0xff...ff) + (OddV + EvenV) + // = (OddV * 0x100...00) + EvenV + // = (OddV << VecVT.getScalarSizeInBits()) + EvenV + // Note the ADD_VL and VLMULU_VL should get selected as vwmaccu.vx + Interleaved = DAG.getNode(RISCVISD::ADD_VL, DL, WideContainerVT, + Interleaved, OddsMul, Passthru, Mask, VL); + } // Bitcast from <vscale x n * ty*2> to <vscale x 2*n x ty> MVT ResultContainerVT = MVT::getVectorVT( @@ -3960,6 +4515,96 @@ static SDValue getWideningInterleave(SDValue EvenV, SDValue OddV, return Interleaved; } +// If we have a vector of bits that we want to reverse, we can use a vbrev on a +// larger element type, e.g. v32i1 can be reversed with a v1i32 bitreverse. +static SDValue lowerBitreverseShuffle(ShuffleVectorSDNode *SVN, + SelectionDAG &DAG, + const RISCVSubtarget &Subtarget) { + SDLoc DL(SVN); + MVT VT = SVN->getSimpleValueType(0); + SDValue V = SVN->getOperand(0); + unsigned NumElts = VT.getVectorNumElements(); + + assert(VT.getVectorElementType() == MVT::i1); + + if (!ShuffleVectorInst::isReverseMask(SVN->getMask(), + SVN->getMask().size()) || + !SVN->getOperand(1).isUndef()) + return SDValue(); + + unsigned ViaEltSize = std::max((uint64_t)8, PowerOf2Ceil(NumElts)); + EVT ViaVT = EVT::getVectorVT( + *DAG.getContext(), EVT::getIntegerVT(*DAG.getContext(), ViaEltSize), 1); + EVT ViaBitVT = + EVT::getVectorVT(*DAG.getContext(), MVT::i1, ViaVT.getScalarSizeInBits()); + + // If we don't have zvbb or the larger element type > ELEN, the operation will + // be illegal. + if (!Subtarget.getTargetLowering()->isOperationLegalOrCustom(ISD::BITREVERSE, + ViaVT) || + !Subtarget.getTargetLowering()->isTypeLegal(ViaBitVT)) + return SDValue(); + + // If the bit vector doesn't fit exactly into the larger element type, we need + // to insert it into the larger vector and then shift up the reversed bits + // afterwards to get rid of the gap introduced. 
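The gap handling described above can be traced with a 4-element mask; this is a host-side model, with BitReverse8 standing in for vbrev on a single i8 element and the mask value chosen arbitrarily.

#include <cstdint>
#include <cstdio>

static uint8_t BitReverse8(uint8_t V) {
  uint8_t R = 0;
  for (int I = 0; I < 8; ++I)
    R |= ((V >> I) & 1u) << (7 - I);
  return R;
}

int main() {
  const unsigned NumElts = 4;   // v4i1, element i stored in bit i
  const uint8_t Mask = 0b1010;  // elements {0, 1, 0, 1}

  // Insert the mask into the low bits of an i8, bit-reverse, then shift the
  // result down to remove the gap left by the unused high bits.
  const unsigned ViaEltSize = 8; // max(8, PowerOf2Ceil(NumElts))
  const uint8_t Res = BitReverse8(Mask) >> (ViaEltSize - NumElts);

  printf("reversed mask bits = %d%d%d%d (expect 0101)\n", (Res >> 3) & 1,
         (Res >> 2) & 1, (Res >> 1) & 1, Res & 1);
}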
+ if (ViaEltSize > NumElts) + V = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, ViaBitVT, DAG.getUNDEF(ViaBitVT), + V, DAG.getVectorIdxConstant(0, DL)); + + SDValue Res = + DAG.getNode(ISD::BITREVERSE, DL, ViaVT, DAG.getBitcast(ViaVT, V)); + + // Shift up the reversed bits if the vector didn't exactly fit into the larger + // element type. + if (ViaEltSize > NumElts) + Res = DAG.getNode(ISD::SRL, DL, ViaVT, Res, + DAG.getConstant(ViaEltSize - NumElts, DL, ViaVT)); + + Res = DAG.getBitcast(ViaBitVT, Res); + + if (ViaEltSize > NumElts) + Res = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, Res, + DAG.getVectorIdxConstant(0, DL)); + return Res; +} + +// Given a shuffle mask like <3, 0, 1, 2, 7, 4, 5, 6> for v8i8, we can +// reinterpret it as a v2i32 and rotate it right by 8 instead. We can lower this +// as a vror.vi if we have Zvkb, or otherwise as a vsll, vsrl and vor. +static SDValue lowerVECTOR_SHUFFLEAsRotate(ShuffleVectorSDNode *SVN, + SelectionDAG &DAG, + const RISCVSubtarget &Subtarget) { + SDLoc DL(SVN); + + EVT VT = SVN->getValueType(0); + unsigned NumElts = VT.getVectorNumElements(); + unsigned EltSizeInBits = VT.getScalarSizeInBits(); + unsigned NumSubElts, RotateAmt; + if (!ShuffleVectorInst::isBitRotateMask(SVN->getMask(), EltSizeInBits, 2, + NumElts, NumSubElts, RotateAmt)) + return SDValue(); + MVT RotateVT = MVT::getVectorVT(MVT::getIntegerVT(EltSizeInBits * NumSubElts), + NumElts / NumSubElts); + + // We might have a RotateVT that isn't legal, e.g. v4i64 on zve32x. + if (!Subtarget.getTargetLowering()->isTypeLegal(RotateVT)) + return SDValue(); + + SDValue Op = DAG.getBitcast(RotateVT, SVN->getOperand(0)); + + SDValue Rotate; + // A rotate of an i16 by 8 bits either direction is equivalent to a byteswap, + // so canonicalize to vrev8. + if (RotateVT.getScalarType() == MVT::i16 && RotateAmt == 8) + Rotate = DAG.getNode(ISD::BSWAP, DL, RotateVT, Op); + else + Rotate = DAG.getNode(ISD::ROTL, DL, RotateVT, Op, + DAG.getConstant(RotateAmt, DL, RotateVT)); + + return DAG.getBitcast(VT, Rotate); +} + static SDValue lowerVECTOR_SHUFFLE(SDValue Op, SelectionDAG &DAG, const RISCVSubtarget &Subtarget) { SDValue V1 = Op.getOperand(0); @@ -3970,8 +4615,15 @@ static SDValue lowerVECTOR_SHUFFLE(SDValue Op, SelectionDAG &DAG, unsigned NumElts = VT.getVectorNumElements(); ShuffleVectorSDNode *SVN = cast<ShuffleVectorSDNode>(Op.getNode()); - // Promote i1 shuffle to i8 shuffle. if (VT.getVectorElementType() == MVT::i1) { + // Lower to a vror.vi of a larger element type if possible before we promote + // i1s to i8s. + if (SDValue V = lowerVECTOR_SHUFFLEAsRotate(SVN, DAG, Subtarget)) + return V; + if (SDValue V = lowerBitreverseShuffle(SVN, DAG, Subtarget)) + return V; + + // Promote i1 shuffle to i8 shuffle. MVT WidenVT = MVT::getVectorVT(MVT::i8, VT.getVectorElementCount()); V1 = DAG.getNode(ISD::ZERO_EXTEND, DL, WidenVT, V1); V2 = V2.isUndef() ? DAG.getUNDEF(WidenVT) @@ -4007,8 +4659,8 @@ static SDValue lowerVECTOR_SHUFFLE(SDValue Op, SelectionDAG &DAG, if (ISD::isNormalLoad(V.getNode()) && cast<LoadSDNode>(V)->isSimple()) { auto *Ld = cast<LoadSDNode>(V); Offset *= SVT.getStoreSize(); - SDValue NewAddr = DAG.getMemBasePlusOffset(Ld->getBasePtr(), - TypeSize::Fixed(Offset), DL); + SDValue NewAddr = DAG.getMemBasePlusOffset( + Ld->getBasePtr(), TypeSize::getFixed(Offset), DL); // If this is SEW=64 on RV32, use a strided load with a stride of x0. 
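The reinterpret-and-rotate trick in lowerVECTOR_SHUFFLEAsRotate above can be checked on the host, assuming a little-endian host so that byte order within a lane matches the in-register element order; the input bytes are arbitrary.

#include <cstdint>
#include <cstdio>
#include <cstring>

int main() {
  const uint8_t In[8] = {0x10, 0x21, 0x32, 0x43, 0x54, 0x65, 0x76, 0x87};
  const int Mask[8] = {3, 0, 1, 2, 7, 4, 5, 6}; // the mask from the comment

  uint8_t Shuffled[8];
  for (int I = 0; I < 8; ++I)
    Shuffled[I] = In[Mask[I]];

  // Reinterpret both v8i8 values as v2i32 and compare each shuffled lane
  // against a plain 32-bit rotate of the input lane.
  uint32_t InLanes[2], ShufLanes[2];
  std::memcpy(InLanes, In, sizeof(InLanes));
  std::memcpy(ShufLanes, Shuffled, sizeof(ShufLanes));
  for (int L = 0; L < 2; ++L) {
    const uint32_t Rot = (InLanes[L] << 8) | (InLanes[L] >> 24);
    printf("lane %d: shuffle 0x%08x  rotate 0x%08x (%s)\n", L, ShufLanes[L],
           Rot, ShufLanes[L] == Rot ? "ok" : "mismatch");
  }
}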
if (SVT.isInteger() && SVT.bitsGT(XLenVT)) { @@ -4070,6 +4722,12 @@ static SDValue lowerVECTOR_SHUFFLE(SDValue Op, SelectionDAG &DAG, lowerVECTOR_SHUFFLEAsVSlidedown(DL, VT, V1, V2, Mask, Subtarget, DAG)) return V; + // A bitrotate will be one instruction on Zvkb, so try to lower to it first if + // available. + if (Subtarget.hasStdExtZvkb()) + if (SDValue V = lowerVECTOR_SHUFFLEAsRotate(SVN, DAG, Subtarget)) + return V; + // Lower rotations to a SLIDEDOWN and a SLIDEUP. One of the source vectors may // be undef which can be handled with a single SLIDEDOWN/UP. int LoSrc, HiSrc; @@ -4196,6 +4854,12 @@ static SDValue lowerVECTOR_SHUFFLE(SDValue Op, SelectionDAG &DAG, if (IsSelect) return DAG.getNode(ISD::VSELECT, DL, VT, SelectMask, V1, V2); + // We might be able to express the shuffle as a bitrotate. But even if we + // don't have Zvkb and have to expand, the expanded sequence of approx. 2 + // shifts and a vor will have a higher throughput than a vrgather. + if (SDValue V = lowerVECTOR_SHUFFLEAsRotate(SVN, DAG, Subtarget)) + return V; + if (VT.getScalarSizeInBits() == 8 && VT.getVectorNumElements() > 256) { // On such a large vector we're unable to use i8 as the index type. // FIXME: We could promote the index to i16 and use vrgatherei16, but that @@ -4215,6 +4879,15 @@ static SDValue lowerVECTOR_SHUFFLE(SDValue Op, SelectionDAG &DAG, IndexVT = IndexVT.changeVectorElementType(MVT::i16); } + // If the mask allows, we can do all the index computation in 16 bits. This + // requires less work and less register pressure at high LMUL, and creates + // smaller constants which may be cheaper to materialize. + if (IndexVT.getScalarType().bitsGT(MVT::i16) && isUInt<16>(NumElts - 1) && + (IndexVT.getSizeInBits() / Subtarget.getRealMinVLen()) > 1) { + GatherVVOpc = RISCVISD::VRGATHEREI16_VV_VL; + IndexVT = IndexVT.changeVectorElementType(MVT::i16); + } + MVT IndexContainerVT = ContainerVT.changeVectorElementType(IndexVT.getScalarType()); @@ -4489,26 +5162,26 @@ static SDValue lowerConstant(SDValue Op, SelectionDAG &DAG, if (!Subtarget.useConstantPoolForLargeInts()) return Op; - RISCVMatInt::InstSeq Seq = - RISCVMatInt::generateInstSeq(Imm, Subtarget.getFeatureBits()); + RISCVMatInt::InstSeq Seq = RISCVMatInt::generateInstSeq(Imm, Subtarget); if (Seq.size() <= Subtarget.getMaxBuildIntsCost()) return Op; - // Special case. See if we can build the constant as (ADD (SLLI X, 32), X) do + // Optimizations below are disabled for opt size. If we're optimizing for + // size, use a constant pool. + if (DAG.shouldOptForSize()) + return SDValue(); + + // Special case. See if we can build the constant as (ADD (SLLI X, C), X) do // that if it will avoid a constant pool. // It will require an extra temporary register though. - if (!DAG.shouldOptForSize()) { - int64_t LoVal = SignExtend64<32>(Imm); - int64_t HiVal = SignExtend64<32>(((uint64_t)Imm - (uint64_t)LoVal) >> 32); - if (LoVal == HiVal) { - RISCVMatInt::InstSeq SeqLo = - RISCVMatInt::generateInstSeq(LoVal, Subtarget.getFeatureBits()); - if ((SeqLo.size() + 2) <= Subtarget.getMaxBuildIntsCost()) - return Op; - } - } + // If we have Zba we can use (ADD_UW X, (SLLI X, 32)) to handle cases where + // low and high 32 bits are the same and bit 31 and 63 are set. + unsigned ShiftAmt, AddOpc; + RISCVMatInt::InstSeq SeqLo = + RISCVMatInt::generateTwoRegInstSeq(Imm, Subtarget, ShiftAmt, AddOpc); + if (!SeqLo.empty() && (SeqLo.size() + 2) <= Subtarget.getMaxBuildIntsCost()) + return Op; - // Expand to a constant pool using the default expansion code. 
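A host-side sketch of the two-register identities that generateTwoRegInstSeq relies on in lowerConstant above; the 64-bit constants are illustrative examples with matching 32-bit halves.

#include <cstdint>
#include <cstdio>

int main() {
  // Halves equal, bit 31 clear: X = low half (as left by lui/addi), then
  // (SLLI X, 32) + X rebuilds the constant from one 32-bit materialization.
  const int64_t X = 0x12345678;
  const uint64_t V = ((uint64_t)X << 32) + (uint64_t)X;
  printf("slli+add   : 0x%016llx\n", (unsigned long long)V);

  // Halves equal, bit 31 set: the sign-extended low half would corrupt the
  // high half in a plain add, which is why the patch uses Zba's add.uw
  // (zero-extend the low 32 bits of the addend) for this shape.
  const int64_t Y = (int64_t)(int32_t)0xF2345678;
  const uint64_t W = ((uint64_t)Y << 32) + (uint32_t)Y; // add.uw semantics
  printf("slli+add.uw: 0x%016llx\n", (unsigned long long)W);
}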
return SDValue(); } @@ -4546,8 +5219,7 @@ SDValue RISCVTargetLowering::LowerIS_FPCLASS(SDValue Op, SDLoc DL(Op); MVT VT = Op.getSimpleValueType(); MVT XLenVT = Subtarget.getXLenVT(); - auto CNode = cast<ConstantSDNode>(Op.getOperand(1)); - unsigned Check = CNode->getZExtValue(); + unsigned Check = Op.getConstantOperandVal(1); unsigned TDCMask = 0; if (Check & fcSNan) TDCMask |= RISCV::FPMASK_Signaling_NaN; @@ -4581,6 +5253,10 @@ SDValue RISCVTargetLowering::LowerIS_FPCLASS(SDValue Op, if (VT.isScalableVector()) { MVT DstVT = VT0.changeVectorElementTypeToInteger(); auto [Mask, VL] = getDefaultScalableVLOps(VT0, DL, DAG, Subtarget); + if (Op.getOpcode() == ISD::VP_IS_FPCLASS) { + Mask = Op.getOperand(2); + VL = Op.getOperand(3); + } SDValue FPCLASS = DAG.getNode(RISCVISD::FCLASS_VL, DL, DstVT, Op0, Mask, VL, Op->getFlags()); if (IsOneBitMask) @@ -4597,7 +5273,13 @@ SDValue RISCVTargetLowering::LowerIS_FPCLASS(SDValue Op, MVT ContainerVT = getContainerForFixedLengthVector(VT); MVT ContainerDstVT = ContainerVT0.changeVectorElementTypeToInteger(); auto [Mask, VL] = getDefaultVLOps(VT0, ContainerVT0, DL, DAG, Subtarget); - + if (Op.getOpcode() == ISD::VP_IS_FPCLASS) { + Mask = Op.getOperand(2); + MVT MaskContainerVT = + getContainerForFixedLengthVector(Mask.getSimpleValueType()); + Mask = convertToScalableVector(MaskContainerVT, Mask, DAG, Subtarget); + VL = Op.getOperand(3); + } Op0 = convertToScalableVector(ContainerVT0, Op0, DAG, Subtarget); SDValue FPCLASS = DAG.getNode(RISCVISD::FCLASS_VL, DL, ContainerDstVT, Op0, @@ -4615,7 +5297,7 @@ SDValue RISCVTargetLowering::LowerIS_FPCLASS(SDValue Op, SDValue AND = DAG.getNode(RISCVISD::AND_VL, DL, ContainerDstVT, FPCLASS, TDCMaskV, DAG.getUNDEF(ContainerDstVT), Mask, VL); - SDValue SplatZero = DAG.getConstant(0, DL, Subtarget.getXLenVT()); + SDValue SplatZero = DAG.getConstant(0, DL, XLenVT); SplatZero = DAG.getNode(RISCVISD::VMV_V_X_VL, DL, ContainerDstVT, DAG.getUNDEF(ContainerDstVT), SplatZero, VL); @@ -4625,10 +5307,11 @@ SDValue RISCVTargetLowering::LowerIS_FPCLASS(SDValue Op, return convertFromScalableVector(VT, VMSNE, DAG, Subtarget); } - SDValue FPCLASS = DAG.getNode(RISCVISD::FPCLASS, DL, VT, Op.getOperand(0)); - SDValue AND = DAG.getNode(ISD::AND, DL, VT, FPCLASS, TDCMaskV); - return DAG.getSetCC(DL, VT, AND, DAG.getConstant(0, DL, XLenVT), - ISD::CondCode::SETNE); + SDValue FCLASS = DAG.getNode(RISCVISD::FCLASS, DL, XLenVT, Op.getOperand(0)); + SDValue AND = DAG.getNode(ISD::AND, DL, XLenVT, FCLASS, TDCMaskV); + SDValue Res = DAG.getSetCC(DL, XLenVT, AND, DAG.getConstant(0, DL, XLenVT), + ISD::CondCode::SETNE); + return DAG.getNode(ISD::TRUNCATE, DL, VT, Res); } // Lower fmaximum and fminimum. Unlike our fmax and fmin instructions, these @@ -4636,38 +5319,88 @@ SDValue RISCVTargetLowering::LowerIS_FPCLASS(SDValue Op, static SDValue lowerFMAXIMUM_FMINIMUM(SDValue Op, SelectionDAG &DAG, const RISCVSubtarget &Subtarget) { SDLoc DL(Op); - EVT VT = Op.getValueType(); + MVT VT = Op.getSimpleValueType(); SDValue X = Op.getOperand(0); SDValue Y = Op.getOperand(1); - MVT XLenVT = Subtarget.getXLenVT(); + if (!VT.isVector()) { + MVT XLenVT = Subtarget.getXLenVT(); - // If X is a nan, replace Y with X. If Y is a nan, replace X with Y. This - // ensures that when one input is a nan, the other will also be a nan allowing - // the nan to propagate. If both inputs are nan, this will swap the inputs - // which is harmless. - // FIXME: Handle nonans FMF and use isKnownNeverNaN. 
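A scalar model of the NaN-propagation fixup applied in this hunk, with std::fmax standing in for the RISC-V fmax instruction (both return the other operand when exactly one input is a quiet NaN, whereas fmaximum must return the NaN).

#include <cmath>
#include <cstdio>
#include <limits>

static double emulatedFMaximum(double X, double Y) {
  // Mirror the selects in lowerFMAXIMUM_FMINIMUM: if one operand is a NaN,
  // copy it over the other operand so the NaN survives the fmax.
  const double NewY = (X == X) ? Y : X; // X == X is false only for NaN
  const double NewX = (Y == Y) ? X : Y;
  return std::fmax(NewX, NewY);
}

int main() {
  const double QNaN = std::numeric_limits<double>::quiet_NaN();
  printf("fmax(NaN, 3.0)     = %g\n", std::fmax(QNaN, 3.0));        // 3
  printf("fmaximum(NaN, 3.0) = %g\n", emulatedFMaximum(QNaN, 3.0)); // nan
  printf("fmaximum(1.0, 2.0) = %g\n", emulatedFMaximum(1.0, 2.0));  // 2
}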
- SDValue XIsNonNan = DAG.getSetCC(DL, XLenVT, X, X, ISD::SETOEQ); - SDValue NewY = DAG.getSelect(DL, VT, XIsNonNan, Y, X); + // If X is a nan, replace Y with X. If Y is a nan, replace X with Y. This + // ensures that when one input is a nan, the other will also be a nan + // allowing the nan to propagate. If both inputs are nan, this will swap the + // inputs which is harmless. - SDValue YIsNonNan = DAG.getSetCC(DL, XLenVT, Y, Y, ISD::SETOEQ); - SDValue NewX = DAG.getSelect(DL, VT, YIsNonNan, X, Y); + SDValue NewY = Y; + if (!Op->getFlags().hasNoNaNs() && !DAG.isKnownNeverNaN(X)) { + SDValue XIsNonNan = DAG.getSetCC(DL, XLenVT, X, X, ISD::SETOEQ); + NewY = DAG.getSelect(DL, VT, XIsNonNan, Y, X); + } + + SDValue NewX = X; + if (!Op->getFlags().hasNoNaNs() && !DAG.isKnownNeverNaN(Y)) { + SDValue YIsNonNan = DAG.getSetCC(DL, XLenVT, Y, Y, ISD::SETOEQ); + NewX = DAG.getSelect(DL, VT, YIsNonNan, X, Y); + } + + unsigned Opc = + Op.getOpcode() == ISD::FMAXIMUM ? RISCVISD::FMAX : RISCVISD::FMIN; + return DAG.getNode(Opc, DL, VT, NewX, NewY); + } + + // Check no NaNs before converting to fixed vector scalable. + bool XIsNeverNan = Op->getFlags().hasNoNaNs() || DAG.isKnownNeverNaN(X); + bool YIsNeverNan = Op->getFlags().hasNoNaNs() || DAG.isKnownNeverNaN(Y); + + MVT ContainerVT = VT; + if (VT.isFixedLengthVector()) { + ContainerVT = getContainerForFixedLengthVector(DAG, VT, Subtarget); + X = convertToScalableVector(ContainerVT, X, DAG, Subtarget); + Y = convertToScalableVector(ContainerVT, Y, DAG, Subtarget); + } + + auto [Mask, VL] = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget); + + SDValue NewY = Y; + if (!XIsNeverNan) { + SDValue XIsNonNan = DAG.getNode(RISCVISD::SETCC_VL, DL, Mask.getValueType(), + {X, X, DAG.getCondCode(ISD::SETOEQ), + DAG.getUNDEF(ContainerVT), Mask, VL}); + NewY = + DAG.getNode(RISCVISD::VSELECT_VL, DL, ContainerVT, XIsNonNan, Y, X, VL); + } + + SDValue NewX = X; + if (!YIsNeverNan) { + SDValue YIsNonNan = DAG.getNode(RISCVISD::SETCC_VL, DL, Mask.getValueType(), + {Y, Y, DAG.getCondCode(ISD::SETOEQ), + DAG.getUNDEF(ContainerVT), Mask, VL}); + NewX = + DAG.getNode(RISCVISD::VSELECT_VL, DL, ContainerVT, YIsNonNan, X, Y, VL); + } unsigned Opc = - Op.getOpcode() == ISD::FMAXIMUM ? RISCVISD::FMAX : RISCVISD::FMIN; - return DAG.getNode(Opc, DL, VT, NewX, NewY); + Op.getOpcode() == ISD::FMAXIMUM ? RISCVISD::VFMAX_VL : RISCVISD::VFMIN_VL; + SDValue Res = DAG.getNode(Opc, DL, ContainerVT, NewX, NewY, + DAG.getUNDEF(ContainerVT), Mask, VL); + if (VT.isFixedLengthVector()) + Res = convertFromScalableVector(VT, Res, DAG, Subtarget); + return Res; } -/// Get a RISCV target specified VL op for a given SDNode. +/// Get a RISC-V target specified VL op for a given SDNode. 
static unsigned getRISCVVLOp(SDValue Op) { #define OP_CASE(NODE) \ case ISD::NODE: \ return RISCVISD::NODE##_VL; +#define VP_CASE(NODE) \ + case ISD::VP_##NODE: \ + return RISCVISD::NODE##_VL; + // clang-format off switch (Op.getOpcode()) { default: llvm_unreachable("don't have RISC-V specified VL op for this SDNode"); - // clang-format off OP_CASE(ADD) OP_CASE(SUB) OP_CASE(MUL) @@ -4680,6 +5413,13 @@ static unsigned getRISCVVLOp(SDValue Op) { OP_CASE(SHL) OP_CASE(SRA) OP_CASE(SRL) + OP_CASE(ROTL) + OP_CASE(ROTR) + OP_CASE(BSWAP) + OP_CASE(CTTZ) + OP_CASE(CTLZ) + OP_CASE(CTPOP) + OP_CASE(BITREVERSE) OP_CASE(SADDSAT) OP_CASE(UADDSAT) OP_CASE(SSUBSAT) @@ -4695,47 +5435,113 @@ static unsigned getRISCVVLOp(SDValue Op) { OP_CASE(SMAX) OP_CASE(UMIN) OP_CASE(UMAX) - OP_CASE(FMINNUM) - OP_CASE(FMAXNUM) OP_CASE(STRICT_FADD) OP_CASE(STRICT_FSUB) OP_CASE(STRICT_FMUL) OP_CASE(STRICT_FDIV) OP_CASE(STRICT_FSQRT) - // clang-format on -#undef OP_CASE + VP_CASE(ADD) // VP_ADD + VP_CASE(SUB) // VP_SUB + VP_CASE(MUL) // VP_MUL + VP_CASE(SDIV) // VP_SDIV + VP_CASE(SREM) // VP_SREM + VP_CASE(UDIV) // VP_UDIV + VP_CASE(UREM) // VP_UREM + VP_CASE(SHL) // VP_SHL + VP_CASE(FADD) // VP_FADD + VP_CASE(FSUB) // VP_FSUB + VP_CASE(FMUL) // VP_FMUL + VP_CASE(FDIV) // VP_FDIV + VP_CASE(FNEG) // VP_FNEG + VP_CASE(FABS) // VP_FABS + VP_CASE(SMIN) // VP_SMIN + VP_CASE(SMAX) // VP_SMAX + VP_CASE(UMIN) // VP_UMIN + VP_CASE(UMAX) // VP_UMAX + VP_CASE(FCOPYSIGN) // VP_FCOPYSIGN + VP_CASE(SETCC) // VP_SETCC + VP_CASE(SINT_TO_FP) // VP_SINT_TO_FP + VP_CASE(UINT_TO_FP) // VP_UINT_TO_FP + VP_CASE(BITREVERSE) // VP_BITREVERSE + VP_CASE(BSWAP) // VP_BSWAP + VP_CASE(CTLZ) // VP_CTLZ + VP_CASE(CTTZ) // VP_CTTZ + VP_CASE(CTPOP) // VP_CTPOP + case ISD::CTLZ_ZERO_UNDEF: + case ISD::VP_CTLZ_ZERO_UNDEF: + return RISCVISD::CTLZ_VL; + case ISD::CTTZ_ZERO_UNDEF: + case ISD::VP_CTTZ_ZERO_UNDEF: + return RISCVISD::CTTZ_VL; case ISD::FMA: + case ISD::VP_FMA: return RISCVISD::VFMADD_VL; case ISD::STRICT_FMA: return RISCVISD::STRICT_VFMADD_VL; case ISD::AND: + case ISD::VP_AND: if (Op.getSimpleValueType().getVectorElementType() == MVT::i1) return RISCVISD::VMAND_VL; return RISCVISD::AND_VL; case ISD::OR: + case ISD::VP_OR: if (Op.getSimpleValueType().getVectorElementType() == MVT::i1) return RISCVISD::VMOR_VL; return RISCVISD::OR_VL; case ISD::XOR: + case ISD::VP_XOR: if (Op.getSimpleValueType().getVectorElementType() == MVT::i1) return RISCVISD::VMXOR_VL; return RISCVISD::XOR_VL; + case ISD::VP_SELECT: + return RISCVISD::VSELECT_VL; + case ISD::VP_MERGE: + return RISCVISD::VP_MERGE_VL; + case ISD::VP_ASHR: + return RISCVISD::SRA_VL; + case ISD::VP_LSHR: + return RISCVISD::SRL_VL; + case ISD::VP_SQRT: + return RISCVISD::FSQRT_VL; + case ISD::VP_SIGN_EXTEND: + return RISCVISD::VSEXT_VL; + case ISD::VP_ZERO_EXTEND: + return RISCVISD::VZEXT_VL; + case ISD::VP_FP_TO_SINT: + return RISCVISD::VFCVT_RTZ_X_F_VL; + case ISD::VP_FP_TO_UINT: + return RISCVISD::VFCVT_RTZ_XU_F_VL; + case ISD::FMINNUM: + case ISD::VP_FMINNUM: + return RISCVISD::VFMIN_VL; + case ISD::FMAXNUM: + case ISD::VP_FMAXNUM: + return RISCVISD::VFMAX_VL; } + // clang-format on +#undef OP_CASE +#undef VP_CASE } /// Return true if a RISC-V target specified op has a merge operand. 
static bool hasMergeOp(unsigned Opcode) { assert(Opcode > RISCVISD::FIRST_NUMBER && - Opcode <= RISCVISD::STRICT_VFROUND_NOEXCEPT_VL && + Opcode <= RISCVISD::LAST_RISCV_STRICTFP_OPCODE && "not a RISC-V target specific op"); - assert(RISCVISD::STRICT_VFROUND_NOEXCEPT_VL - RISCVISD::FIRST_NUMBER == 421 && - "adding target specific op should update this function"); - if (Opcode >= RISCVISD::ADD_VL && Opcode <= RISCVISD::FMAXNUM_VL) + static_assert(RISCVISD::LAST_VL_VECTOR_OP - RISCVISD::FIRST_VL_VECTOR_OP == + 125 && + RISCVISD::LAST_RISCV_STRICTFP_OPCODE - + ISD::FIRST_TARGET_STRICTFP_OPCODE == + 21 && + "adding target specific op should update this function"); + if (Opcode >= RISCVISD::ADD_VL && Opcode <= RISCVISD::VFMAX_VL) return true; if (Opcode == RISCVISD::FCOPYSIGN_VL) return true; if (Opcode >= RISCVISD::VWMUL_VL && Opcode <= RISCVISD::VFWSUB_W_VL) return true; + if (Opcode == RISCVISD::SETCC_VL) + return true; if (Opcode >= RISCVISD::STRICT_FADD_VL && Opcode <= RISCVISD::STRICT_FDIV_VL) return true; return false; @@ -4744,10 +5550,14 @@ static bool hasMergeOp(unsigned Opcode) { /// Return true if a RISC-V target specified op has a mask operand. static bool hasMaskOp(unsigned Opcode) { assert(Opcode > RISCVISD::FIRST_NUMBER && - Opcode <= RISCVISD::STRICT_VFROUND_NOEXCEPT_VL && + Opcode <= RISCVISD::LAST_RISCV_STRICTFP_OPCODE && "not a RISC-V target specific op"); - assert(RISCVISD::STRICT_VFROUND_NOEXCEPT_VL - RISCVISD::FIRST_NUMBER == 421 && - "adding target specific op should update this function"); + static_assert(RISCVISD::LAST_VL_VECTOR_OP - RISCVISD::FIRST_VL_VECTOR_OP == + 125 && + RISCVISD::LAST_RISCV_STRICTFP_OPCODE - + ISD::FIRST_TARGET_STRICTFP_OPCODE == + 21 && + "adding target specific op should update this function"); if (Opcode >= RISCVISD::TRUNCATE_VECTOR_VL && Opcode <= RISCVISD::SETCC_VL) return true; if (Opcode >= RISCVISD::VRGATHER_VX_VL && Opcode <= RISCVISD::VFIRST_VL) @@ -4758,6 +5568,112 @@ static bool hasMaskOp(unsigned Opcode) { return false; } +static SDValue SplitVectorOp(SDValue Op, SelectionDAG &DAG) { + auto [LoVT, HiVT] = DAG.GetSplitDestVTs(Op.getValueType()); + SDLoc DL(Op); + + SmallVector<SDValue, 4> LoOperands(Op.getNumOperands()); + SmallVector<SDValue, 4> HiOperands(Op.getNumOperands()); + + for (unsigned j = 0; j != Op.getNumOperands(); ++j) { + if (!Op.getOperand(j).getValueType().isVector()) { + LoOperands[j] = Op.getOperand(j); + HiOperands[j] = Op.getOperand(j); + continue; + } + std::tie(LoOperands[j], HiOperands[j]) = + DAG.SplitVector(Op.getOperand(j), DL); + } + + SDValue LoRes = + DAG.getNode(Op.getOpcode(), DL, LoVT, LoOperands, Op->getFlags()); + SDValue HiRes = + DAG.getNode(Op.getOpcode(), DL, HiVT, HiOperands, Op->getFlags()); + + return DAG.getNode(ISD::CONCAT_VECTORS, DL, Op.getValueType(), LoRes, HiRes); +} + +static SDValue SplitVPOp(SDValue Op, SelectionDAG &DAG) { + assert(ISD::isVPOpcode(Op.getOpcode()) && "Not a VP op"); + auto [LoVT, HiVT] = DAG.GetSplitDestVTs(Op.getValueType()); + SDLoc DL(Op); + + SmallVector<SDValue, 4> LoOperands(Op.getNumOperands()); + SmallVector<SDValue, 4> HiOperands(Op.getNumOperands()); + + for (unsigned j = 0; j != Op.getNumOperands(); ++j) { + if (ISD::getVPExplicitVectorLengthIdx(Op.getOpcode()) == j) { + std::tie(LoOperands[j], HiOperands[j]) = + DAG.SplitEVL(Op.getOperand(j), Op.getValueType(), DL); + continue; + } + if (!Op.getOperand(j).getValueType().isVector()) { + LoOperands[j] = Op.getOperand(j); + HiOperands[j] = Op.getOperand(j); + continue; + } + std::tie(LoOperands[j], 
HiOperands[j]) = + DAG.SplitVector(Op.getOperand(j), DL); + } + + SDValue LoRes = + DAG.getNode(Op.getOpcode(), DL, LoVT, LoOperands, Op->getFlags()); + SDValue HiRes = + DAG.getNode(Op.getOpcode(), DL, HiVT, HiOperands, Op->getFlags()); + + return DAG.getNode(ISD::CONCAT_VECTORS, DL, Op.getValueType(), LoRes, HiRes); +} + +static SDValue SplitVectorReductionOp(SDValue Op, SelectionDAG &DAG) { + SDLoc DL(Op); + + auto [Lo, Hi] = DAG.SplitVector(Op.getOperand(1), DL); + auto [MaskLo, MaskHi] = DAG.SplitVector(Op.getOperand(2), DL); + auto [EVLLo, EVLHi] = + DAG.SplitEVL(Op.getOperand(3), Op.getOperand(1).getValueType(), DL); + + SDValue ResLo = + DAG.getNode(Op.getOpcode(), DL, Op.getValueType(), + {Op.getOperand(0), Lo, MaskLo, EVLLo}, Op->getFlags()); + return DAG.getNode(Op.getOpcode(), DL, Op.getValueType(), + {ResLo, Hi, MaskHi, EVLHi}, Op->getFlags()); +} + +static SDValue SplitStrictFPVectorOp(SDValue Op, SelectionDAG &DAG) { + + assert(Op->isStrictFPOpcode()); + + auto [LoVT, HiVT] = DAG.GetSplitDestVTs(Op->getValueType(0)); + + SDVTList LoVTs = DAG.getVTList(LoVT, Op->getValueType(1)); + SDVTList HiVTs = DAG.getVTList(HiVT, Op->getValueType(1)); + + SDLoc DL(Op); + + SmallVector<SDValue, 4> LoOperands(Op.getNumOperands()); + SmallVector<SDValue, 4> HiOperands(Op.getNumOperands()); + + for (unsigned j = 0; j != Op.getNumOperands(); ++j) { + if (!Op.getOperand(j).getValueType().isVector()) { + LoOperands[j] = Op.getOperand(j); + HiOperands[j] = Op.getOperand(j); + continue; + } + std::tie(LoOperands[j], HiOperands[j]) = + DAG.SplitVector(Op.getOperand(j), DL); + } + + SDValue LoRes = + DAG.getNode(Op.getOpcode(), DL, LoVTs, LoOperands, Op->getFlags()); + HiOperands[0] = LoRes.getValue(1); + SDValue HiRes = + DAG.getNode(Op.getOpcode(), DL, HiVTs, HiOperands, Op->getFlags()); + + SDValue V = DAG.getNode(ISD::CONCAT_VECTORS, DL, Op->getValueType(0), + LoRes.getValue(0), HiRes.getValue(0)); + return DAG.getMergeValues({V, HiRes.getValue(1)}, DL); +} + SDValue RISCVTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const { switch (Op.getOpcode()) { @@ -4795,6 +5711,10 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op, return lowerShiftRightParts(Op, DAG, false); case ISD::ROTL: case ISD::ROTR: + if (Op.getValueType().isFixedLengthVector()) { + assert(Subtarget.hasStdExtZvkb()); + return lowerToScalableOp(Op, DAG); + } assert(Subtarget.hasVendorXTHeadBb() && !(Subtarget.hasStdExtZbb() || Subtarget.hasStdExtZbkb()) && "Unexpected custom legalization"); @@ -4888,6 +5808,10 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op, return LowerIS_FPCLASS(Op, DAG); case ISD::BITREVERSE: { MVT VT = Op.getSimpleValueType(); + if (VT.isFixedLengthVector()) { + assert(Subtarget.hasStdExtZvbb()); + return lowerToScalableOp(Op, DAG); + } SDLoc DL(Op); assert(Subtarget.hasStdExtZbkb() && "Unexpected custom legalization"); assert(Op.getOpcode() == ISD::BITREVERSE && "Unexpected opcode"); @@ -4930,6 +5854,7 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op, if (VT.isFixedLengthVector()) ContainerVT = getContainerForFixedLengthVector(VT); SDValue VL = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget).second; + Scalar = DAG.getNode(ISD::ANY_EXTEND, DL, Subtarget.getXLenVT(), Scalar); SDValue V = DAG.getNode(RISCVISD::VMV_S_X_VL, DL, ContainerVT, DAG.getUNDEF(ContainerVT), Scalar, VL); if (VT.isFixedLengthVector()) @@ -4937,9 +5862,10 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op, return V; } case ISD::VSCALE: { + MVT XLenVT = Subtarget.getXLenVT(); MVT VT = 
Op.getSimpleValueType(); SDLoc DL(Op); - SDValue VLENB = DAG.getNode(RISCVISD::READ_VLENB, DL, VT); + SDValue Res = DAG.getNode(RISCVISD::READ_VLENB, DL, XLenVT); // We define our scalable vector types for lmul=1 to use a 64 bit known // minimum size. e.g. <vscale x 2 x i32>. VLENB is in bytes so we calculate // vscale as VLENB / 8. @@ -4952,22 +5878,23 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op, if (isPowerOf2_64(Val)) { uint64_t Log2 = Log2_64(Val); if (Log2 < 3) - return DAG.getNode(ISD::SRL, DL, VT, VLENB, - DAG.getConstant(3 - Log2, DL, VT)); - if (Log2 > 3) - return DAG.getNode(ISD::SHL, DL, VT, VLENB, - DAG.getConstant(Log2 - 3, DL, VT)); - return VLENB; + Res = DAG.getNode(ISD::SRL, DL, XLenVT, Res, + DAG.getConstant(3 - Log2, DL, VT)); + else if (Log2 > 3) + Res = DAG.getNode(ISD::SHL, DL, XLenVT, Res, + DAG.getConstant(Log2 - 3, DL, XLenVT)); + } else if ((Val % 8) == 0) { + // If the multiplier is a multiple of 8, scale it down to avoid needing + // to shift the VLENB value. + Res = DAG.getNode(ISD::MUL, DL, XLenVT, Res, + DAG.getConstant(Val / 8, DL, XLenVT)); + } else { + SDValue VScale = DAG.getNode(ISD::SRL, DL, XLenVT, Res, + DAG.getConstant(3, DL, XLenVT)); + Res = DAG.getNode(ISD::MUL, DL, XLenVT, VScale, + DAG.getConstant(Val, DL, XLenVT)); } - // If the multiplier is a multiple of 8, scale it down to avoid needing - // to shift the VLENB value. - if ((Val % 8) == 0) - return DAG.getNode(ISD::MUL, DL, VT, VLENB, - DAG.getConstant(Val / 8, DL, VT)); - - SDValue VScale = DAG.getNode(ISD::SRL, DL, VT, VLENB, - DAG.getConstant(3, DL, VT)); - return DAG.getNode(ISD::MUL, DL, VT, VScale, Op.getOperand(0)); + return DAG.getNode(ISD::TRUNCATE, DL, VT, Res); } case ISD::FPOWI: { // Custom promote f16 powi with illegal i32 integer type on RV64. 
Once @@ -4985,6 +5912,10 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op, } case ISD::FMAXIMUM: case ISD::FMINIMUM: + if (Op.getValueType() == MVT::nxv32f16 && + (Subtarget.hasVInstructionsF16Minimal() && + !Subtarget.hasVInstructionsF16())) + return SplitVectorOp(Op, DAG); return lowerFMAXIMUM_FMINIMUM(Op, DAG, Subtarget); case ISD::FP_EXTEND: { SDLoc DL(Op); @@ -5025,10 +5956,42 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op, case ISD::STRICT_FP_ROUND: case ISD::STRICT_FP_EXTEND: return lowerStrictFPExtendOrRoundLike(Op, DAG); - case ISD::FP_TO_SINT: - case ISD::FP_TO_UINT: case ISD::SINT_TO_FP: case ISD::UINT_TO_FP: + if (Op.getValueType().isVector() && + Op.getValueType().getScalarType() == MVT::f16 && + (Subtarget.hasVInstructionsF16Minimal() && + !Subtarget.hasVInstructionsF16())) { + if (Op.getValueType() == MVT::nxv32f16) + return SplitVectorOp(Op, DAG); + // int -> f32 + SDLoc DL(Op); + MVT NVT = + MVT::getVectorVT(MVT::f32, Op.getValueType().getVectorElementCount()); + SDValue NC = DAG.getNode(Op.getOpcode(), DL, NVT, Op->ops()); + // f32 -> f16 + return DAG.getNode(ISD::FP_ROUND, DL, Op.getValueType(), NC, + DAG.getIntPtrConstant(0, DL, /*isTarget=*/true)); + } + [[fallthrough]]; + case ISD::FP_TO_SINT: + case ISD::FP_TO_UINT: + if (SDValue Op1 = Op.getOperand(0); + Op1.getValueType().isVector() && + Op1.getValueType().getScalarType() == MVT::f16 && + (Subtarget.hasVInstructionsF16Minimal() && + !Subtarget.hasVInstructionsF16())) { + if (Op1.getValueType() == MVT::nxv32f16) + return SplitVectorOp(Op, DAG); + // f16 -> f32 + SDLoc DL(Op); + MVT NVT = MVT::getVectorVT(MVT::f32, + Op1.getValueType().getVectorElementCount()); + SDValue WidenVec = DAG.getNode(ISD::FP_EXTEND, DL, NVT, Op1); + // f32 -> int + return DAG.getNode(Op.getOpcode(), DL, Op.getValueType(), WidenVec); + } + [[fallthrough]]; case ISD::STRICT_FP_TO_SINT: case ISD::STRICT_FP_TO_UINT: case ISD::STRICT_SINT_TO_FP: @@ -5179,7 +6142,7 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op, RTLIB::getFPROUND(Op.getOperand(0).getValueType(), MVT::bf16); SDValue Res = makeLibCall(DAG, LC, MVT::f32, Op.getOperand(0), CallOptions, DL).first; - if (Subtarget.is64Bit()) + if (Subtarget.is64Bit() && !RV64LegalI32) return DAG.getNode(RISCVISD::FMV_X_ANYEXTW_RV64, DL, MVT::i64, Res); return DAG.getBitcast(MVT::i32, Res); } @@ -5208,7 +6171,7 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op, RTLIB::getFPROUND(Op.getOperand(0).getValueType(), MVT::f16); SDValue Res = makeLibCall(DAG, LC, MVT::f32, Op.getOperand(0), CallOptions, DL).first; - if (Subtarget.is64Bit()) + if (Subtarget.is64Bit() && !RV64LegalI32) return DAG.getNode(RISCVISD::FMV_X_ANYEXTW_RV64, DL, MVT::i64, Res); return DAG.getBitcast(MVT::i32, Res); } @@ -5235,6 +6198,9 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op, case ISD::FROUND: case ISD::FROUNDEVEN: return lowerFTRUNC_FCEIL_FFLOOR_FROUND(Op, DAG, Subtarget); + case ISD::LRINT: + case ISD::LLRINT: + return lowerVectorXRINT(Op, DAG, Subtarget); case ISD::VECREDUCE_ADD: case ISD::VECREDUCE_UMAX: case ISD::VECREDUCE_SMAX: @@ -5261,6 +6227,10 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op, case ISD::VP_REDUCE_SEQ_FADD: case ISD::VP_REDUCE_FMIN: case ISD::VP_REDUCE_FMAX: + if (Op.getOperand(1).getValueType() == MVT::nxv32f16 && + (Subtarget.hasVInstructionsF16Minimal() && + !Subtarget.hasVInstructionsF16())) + return SplitVectorReductionOp(Op, DAG); return lowerVPREDUCE(Op, DAG); case ISD::VP_REDUCE_AND: case ISD::VP_REDUCE_OR: @@ -5290,6 +6260,21 @@ SDValue 
RISCVTargetLowering::LowerOperation(SDValue Op, case ISD::BUILD_VECTOR: return lowerBUILD_VECTOR(Op, DAG, Subtarget); case ISD::SPLAT_VECTOR: + if (Op.getValueType().getScalarType() == MVT::f16 && + (Subtarget.hasVInstructionsF16Minimal() && + !Subtarget.hasVInstructionsF16())) { + if (Op.getValueType() == MVT::nxv32f16) + return SplitVectorOp(Op, DAG); + SDLoc DL(Op); + SDValue NewScalar = + DAG.getNode(ISD::FP_EXTEND, DL, MVT::f32, Op.getOperand(0)); + SDValue NewSplat = DAG.getNode( + ISD::SPLAT_VECTOR, DL, + MVT::getVectorVT(MVT::f32, Op.getValueType().getVectorElementCount()), + NewScalar); + return DAG.getNode(ISD::FP_ROUND, DL, Op.getValueType(), NewSplat, + DAG.getIntPtrConstant(0, DL, /*isTarget=*/true)); + } if (Op.getValueType().getVectorElementType() == MVT::i1) return lowerVectorMaskSplat(Op, DAG); return SDValue(); @@ -5368,9 +6353,10 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op, if (isa<ConstantSDNode>(RHS)) { int64_t Imm = cast<ConstantSDNode>(RHS)->getSExtValue(); if (Imm != 0 && isInt<12>((uint64_t)Imm + 1)) { - // X > -1 should have been replaced with false. - assert((CCVal != ISD::SETUGT || Imm != -1) && - "Missing canonicalization"); + // If this is an unsigned compare and the constant is -1, incrementing + // the constant would change behavior. The result should be false. + if (CCVal == ISD::SETUGT && Imm == -1) + return DAG.getConstant(0, DL, VT); // Using getSetCCSwappedOperands will convert SET(U)GT->SET(U)LT. CCVal = ISD::getSetCCSwappedOperands(CCVal); SDValue SetCC = DAG.getSetCC( @@ -5385,6 +6371,11 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op, return DAG.getSetCC(DL, VT, RHS, LHS, CCVal); } + if (Op.getOperand(0).getSimpleValueType() == MVT::nxv32f16 && + (Subtarget.hasVInstructionsF16Minimal() && + !Subtarget.hasVInstructionsF16())) + return SplitVectorOp(Op, DAG); + return lowerFixedLengthVectorSetccToRVV(Op, DAG); } case ISD::ADD: @@ -5399,6 +6390,8 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op, case ISD::SREM: case ISD::UDIV: case ISD::UREM: + case ISD::BSWAP: + case ISD::CTPOP: return lowerToScalableOp(Op, DAG); case ISD::SHL: case ISD::SRA: @@ -5409,10 +6402,6 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op, assert(Op.getOperand(1).getValueType() == MVT::i32 && Subtarget.is64Bit() && "Unexpected custom legalisation"); return SDValue(); - case ISD::SADDSAT: - case ISD::UADDSAT: - case ISD::SSUBSAT: - case ISD::USUBSAT: case ISD::FADD: case ISD::FSUB: case ISD::FMUL: @@ -5421,23 +6410,40 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op, case ISD::FABS: case ISD::FSQRT: case ISD::FMA: + case ISD::FMINNUM: + case ISD::FMAXNUM: + if (Op.getValueType() == MVT::nxv32f16 && + (Subtarget.hasVInstructionsF16Minimal() && + !Subtarget.hasVInstructionsF16())) + return SplitVectorOp(Op, DAG); + [[fallthrough]]; + case ISD::SADDSAT: + case ISD::UADDSAT: + case ISD::SSUBSAT: + case ISD::USUBSAT: case ISD::SMIN: case ISD::SMAX: case ISD::UMIN: case ISD::UMAX: - case ISD::FMINNUM: - case ISD::FMAXNUM: return lowerToScalableOp(Op, DAG); case ISD::ABS: case ISD::VP_ABS: return lowerABS(Op, DAG); case ISD::CTLZ: case ISD::CTLZ_ZERO_UNDEF: + case ISD::CTTZ: case ISD::CTTZ_ZERO_UNDEF: + if (Subtarget.hasStdExtZvbb()) + return lowerToScalableOp(Op, DAG); + assert(Op.getOpcode() != ISD::CTTZ); return lowerCTLZ_CTTZ_ZERO_UNDEF(Op, DAG); case ISD::VSELECT: return lowerFixedLengthVectorSelectToRVV(Op, DAG); case ISD::FCOPYSIGN: + if (Op.getValueType() == MVT::nxv32f16 && + (Subtarget.hasVInstructionsF16Minimal() && + 
!Subtarget.hasVInstructionsF16())) + return SplitVectorOp(Op, DAG); return lowerFixedLengthVectorFCOPYSIGNToRVV(Op, DAG); case ISD::STRICT_FADD: case ISD::STRICT_FSUB: @@ -5445,6 +6451,10 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op, case ISD::STRICT_FDIV: case ISD::STRICT_FSQRT: case ISD::STRICT_FMA: + if (Op.getValueType() == MVT::nxv32f16 && + (Subtarget.hasVInstructionsF16Minimal() && + !Subtarget.hasVInstructionsF16())) + return SplitStrictFPVectorOp(Op, DAG); return lowerToScalableOp(Op, DAG); case ISD::STRICT_FSETCC: case ISD::STRICT_FSETCCS: @@ -5470,106 +6480,115 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op, case ISD::EH_DWARF_CFA: return lowerEH_DWARF_CFA(Op, DAG); case ISD::VP_SELECT: - return lowerVPOp(Op, DAG, RISCVISD::VSELECT_VL); case ISD::VP_MERGE: - return lowerVPOp(Op, DAG, RISCVISD::VP_MERGE_VL); case ISD::VP_ADD: - return lowerVPOp(Op, DAG, RISCVISD::ADD_VL, /*HasMergeOp*/ true); case ISD::VP_SUB: - return lowerVPOp(Op, DAG, RISCVISD::SUB_VL, /*HasMergeOp*/ true); case ISD::VP_MUL: - return lowerVPOp(Op, DAG, RISCVISD::MUL_VL, /*HasMergeOp*/ true); case ISD::VP_SDIV: - return lowerVPOp(Op, DAG, RISCVISD::SDIV_VL, /*HasMergeOp*/ true); case ISD::VP_UDIV: - return lowerVPOp(Op, DAG, RISCVISD::UDIV_VL, /*HasMergeOp*/ true); case ISD::VP_SREM: - return lowerVPOp(Op, DAG, RISCVISD::SREM_VL, /*HasMergeOp*/ true); case ISD::VP_UREM: - return lowerVPOp(Op, DAG, RISCVISD::UREM_VL, /*HasMergeOp*/ true); + return lowerVPOp(Op, DAG); case ISD::VP_AND: - return lowerLogicVPOp(Op, DAG, RISCVISD::VMAND_VL, RISCVISD::AND_VL); case ISD::VP_OR: - return lowerLogicVPOp(Op, DAG, RISCVISD::VMOR_VL, RISCVISD::OR_VL); case ISD::VP_XOR: - return lowerLogicVPOp(Op, DAG, RISCVISD::VMXOR_VL, RISCVISD::XOR_VL); - case ISD::VP_ASHR: - return lowerVPOp(Op, DAG, RISCVISD::SRA_VL, /*HasMergeOp*/ true); - case ISD::VP_LSHR: - return lowerVPOp(Op, DAG, RISCVISD::SRL_VL, /*HasMergeOp*/ true); - case ISD::VP_SHL: - return lowerVPOp(Op, DAG, RISCVISD::SHL_VL, /*HasMergeOp*/ true); + return lowerLogicVPOp(Op, DAG); case ISD::VP_FADD: - return lowerVPOp(Op, DAG, RISCVISD::FADD_VL, /*HasMergeOp*/ true); case ISD::VP_FSUB: - return lowerVPOp(Op, DAG, RISCVISD::FSUB_VL, /*HasMergeOp*/ true); case ISD::VP_FMUL: - return lowerVPOp(Op, DAG, RISCVISD::FMUL_VL, /*HasMergeOp*/ true); case ISD::VP_FDIV: - return lowerVPOp(Op, DAG, RISCVISD::FDIV_VL, /*HasMergeOp*/ true); case ISD::VP_FNEG: - return lowerVPOp(Op, DAG, RISCVISD::FNEG_VL); case ISD::VP_FABS: - return lowerVPOp(Op, DAG, RISCVISD::FABS_VL); case ISD::VP_SQRT: - return lowerVPOp(Op, DAG, RISCVISD::FSQRT_VL); case ISD::VP_FMA: - return lowerVPOp(Op, DAG, RISCVISD::VFMADD_VL); case ISD::VP_FMINNUM: - return lowerVPOp(Op, DAG, RISCVISD::FMINNUM_VL, /*HasMergeOp*/ true); case ISD::VP_FMAXNUM: - return lowerVPOp(Op, DAG, RISCVISD::FMAXNUM_VL, /*HasMergeOp*/ true); case ISD::VP_FCOPYSIGN: - return lowerVPOp(Op, DAG, RISCVISD::FCOPYSIGN_VL, /*HasMergeOp*/ true); + if (Op.getValueType() == MVT::nxv32f16 && + (Subtarget.hasVInstructionsF16Minimal() && + !Subtarget.hasVInstructionsF16())) + return SplitVPOp(Op, DAG); + [[fallthrough]]; + case ISD::VP_ASHR: + case ISD::VP_LSHR: + case ISD::VP_SHL: + return lowerVPOp(Op, DAG); + case ISD::VP_IS_FPCLASS: + return LowerIS_FPCLASS(Op, DAG); case ISD::VP_SIGN_EXTEND: case ISD::VP_ZERO_EXTEND: if (Op.getOperand(0).getSimpleValueType().getVectorElementType() == MVT::i1) return lowerVPExtMaskOp(Op, DAG); - return lowerVPOp(Op, DAG, - Op.getOpcode() == ISD::VP_SIGN_EXTEND - ? 
RISCVISD::VSEXT_VL - : RISCVISD::VZEXT_VL); + return lowerVPOp(Op, DAG); case ISD::VP_TRUNCATE: return lowerVectorTruncLike(Op, DAG); case ISD::VP_FP_EXTEND: case ISD::VP_FP_ROUND: return lowerVectorFPExtendOrRoundLike(Op, DAG); - case ISD::VP_FP_TO_SINT: - return lowerVPFPIntConvOp(Op, DAG, RISCVISD::VFCVT_RTZ_X_F_VL); - case ISD::VP_FP_TO_UINT: - return lowerVPFPIntConvOp(Op, DAG, RISCVISD::VFCVT_RTZ_XU_F_VL); case ISD::VP_SINT_TO_FP: - return lowerVPFPIntConvOp(Op, DAG, RISCVISD::SINT_TO_FP_VL); case ISD::VP_UINT_TO_FP: - return lowerVPFPIntConvOp(Op, DAG, RISCVISD::UINT_TO_FP_VL); + if (Op.getValueType().isVector() && + Op.getValueType().getScalarType() == MVT::f16 && + (Subtarget.hasVInstructionsF16Minimal() && + !Subtarget.hasVInstructionsF16())) { + if (Op.getValueType() == MVT::nxv32f16) + return SplitVPOp(Op, DAG); + // int -> f32 + SDLoc DL(Op); + MVT NVT = + MVT::getVectorVT(MVT::f32, Op.getValueType().getVectorElementCount()); + auto NC = DAG.getNode(Op.getOpcode(), DL, NVT, Op->ops()); + // f32 -> f16 + return DAG.getNode(ISD::FP_ROUND, DL, Op.getValueType(), NC, + DAG.getIntPtrConstant(0, DL, /*isTarget=*/true)); + } + [[fallthrough]]; + case ISD::VP_FP_TO_SINT: + case ISD::VP_FP_TO_UINT: + if (SDValue Op1 = Op.getOperand(0); + Op1.getValueType().isVector() && + Op1.getValueType().getScalarType() == MVT::f16 && + (Subtarget.hasVInstructionsF16Minimal() && + !Subtarget.hasVInstructionsF16())) { + if (Op1.getValueType() == MVT::nxv32f16) + return SplitVPOp(Op, DAG); + // f16 -> f32 + SDLoc DL(Op); + MVT NVT = MVT::getVectorVT(MVT::f32, + Op1.getValueType().getVectorElementCount()); + SDValue WidenVec = DAG.getNode(ISD::FP_EXTEND, DL, NVT, Op1); + // f32 -> int + return DAG.getNode(Op.getOpcode(), DL, Op.getValueType(), + {WidenVec, Op.getOperand(1), Op.getOperand(2)}); + } + return lowerVPFPIntConvOp(Op, DAG); case ISD::VP_SETCC: + if (Op.getOperand(0).getSimpleValueType() == MVT::nxv32f16 && + (Subtarget.hasVInstructionsF16Minimal() && + !Subtarget.hasVInstructionsF16())) + return SplitVPOp(Op, DAG); if (Op.getOperand(0).getSimpleValueType().getVectorElementType() == MVT::i1) return lowerVPSetCCMaskOp(Op, DAG); - return lowerVPOp(Op, DAG, RISCVISD::SETCC_VL, /*HasMergeOp*/ true); + [[fallthrough]]; case ISD::VP_SMIN: - return lowerVPOp(Op, DAG, RISCVISD::SMIN_VL, /*HasMergeOp*/ true); case ISD::VP_SMAX: - return lowerVPOp(Op, DAG, RISCVISD::SMAX_VL, /*HasMergeOp*/ true); case ISD::VP_UMIN: - return lowerVPOp(Op, DAG, RISCVISD::UMIN_VL, /*HasMergeOp*/ true); case ISD::VP_UMAX: - return lowerVPOp(Op, DAG, RISCVISD::UMAX_VL, /*HasMergeOp*/ true); case ISD::VP_BITREVERSE: - return lowerVPOp(Op, DAG, RISCVISD::BITREVERSE_VL, /*HasMergeOp*/ true); case ISD::VP_BSWAP: - return lowerVPOp(Op, DAG, RISCVISD::BSWAP_VL, /*HasMergeOp*/ true); + return lowerVPOp(Op, DAG); case ISD::VP_CTLZ: case ISD::VP_CTLZ_ZERO_UNDEF: if (Subtarget.hasStdExtZvbb()) - return lowerVPOp(Op, DAG, RISCVISD::CTLZ_VL, /*HasMergeOp*/ true); + return lowerVPOp(Op, DAG); return lowerCTLZ_CTTZ_ZERO_UNDEF(Op, DAG); case ISD::VP_CTTZ: case ISD::VP_CTTZ_ZERO_UNDEF: if (Subtarget.hasStdExtZvbb()) - return lowerVPOp(Op, DAG, RISCVISD::CTTZ_VL, /*HasMergeOp*/ true); + return lowerVPOp(Op, DAG); return lowerCTLZ_CTTZ_ZERO_UNDEF(Op, DAG); case ISD::VP_CTPOP: - return lowerVPOp(Op, DAG, RISCVISD::CTPOP_VL, /*HasMergeOp*/ true); + return lowerVPOp(Op, DAG); case ISD::EXPERIMENTAL_VP_STRIDED_LOAD: return lowerVPStridedLoad(Op, DAG); case ISD::EXPERIMENTAL_VP_STRIDED_STORE: @@ -5581,7 +6600,13 @@ SDValue 
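The VP cases above now funnel into a single lowerVPOp/lowerVPFPIntConvOp call because the RVV "_VL" opcode (and whether it carries a merge operand) is recovered from the generic opcode via getRISCVVLOp/hasMergeOp instead of being passed at each call site. A heavily simplified sketch of that dispatch shape; the enum values and helper names below are illustrative stand-ins, not the real LLVM API:

#include <cassert>

// Illustrative stand-ins for generic VP opcodes and RVV "_VL" opcodes.
enum GenericOp { VP_ADD, VP_SUB, VP_FADD };
enum RVVOp { ADD_VL, SUB_VL, FADD_VL };

// One mapping replaces passing the target opcode at every LowerOperation
// call site (the patch does this with getRISCVVLOp / hasMergeOp).
RVVOp getVLOp(GenericOp Op) {
  switch (Op) {
  case VP_SUB:
    return SUB_VL;
  case VP_FADD:
    return FADD_VL;
  case VP_ADD:
  default:
    return ADD_VL;
  }
}

// Every binary VP case can then share one lowering entry point.
RVVOp lowerVPOp(GenericOp Op) { return getVLOp(Op); }

int main() {
  assert(lowerVPOp(VP_SUB) == SUB_VL);
  return 0;
}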
RISCVTargetLowering::LowerOperation(SDValue Op, case ISD::VP_FROUND: case ISD::VP_FROUNDEVEN: case ISD::VP_FROUNDTOZERO: + if (Op.getValueType() == MVT::nxv32f16 && + (Subtarget.hasVInstructionsF16Minimal() && + !Subtarget.hasVInstructionsF16())) + return SplitVPOp(Op, DAG); return lowerVectorFTRUNC_FCEIL_FFLOOR_FROUND(Op, DAG, Subtarget); + case ISD::EXPERIMENTAL_VP_REVERSE: + return lowerVPReverseExperimental(Op, DAG); } } @@ -5628,15 +6653,15 @@ SDValue RISCVTargetLowering::getAddr(NodeTy *N, SelectionDAG &DAG, // Use PC-relative addressing to access the GOT for this symbol, then load // the address from the GOT. This generates the pattern (PseudoLGA sym), // which expands to (ld (addi (auipc %got_pcrel_hi(sym)) %pcrel_lo(auipc))). + SDValue Load = + SDValue(DAG.getMachineNode(RISCV::PseudoLGA, DL, Ty, Addr), 0); MachineFunction &MF = DAG.getMachineFunction(); MachineMemOperand *MemOp = MF.getMachineMemOperand( MachinePointerInfo::getGOT(MF), MachineMemOperand::MOLoad | MachineMemOperand::MODereferenceable | MachineMemOperand::MOInvariant, LLT(Ty.getSimpleVT()), Align(Ty.getFixedSizeInBits() / 8)); - SDValue Load = - DAG.getMemIntrinsicNode(RISCVISD::LGA, DL, DAG.getVTList(Ty, MVT::Other), - {DAG.getEntryNode(), Addr}, Ty, MemOp); + DAG.setNodeMemRefs(cast<MachineSDNode>(Load.getNode()), {MemOp}); return Load; } @@ -5658,16 +6683,15 @@ SDValue RISCVTargetLowering::getAddr(NodeTy *N, SelectionDAG &DAG, // not be within 2GiB of PC, so use GOT-indirect addressing to access the // symbol. This generates the pattern (PseudoLGA sym), which expands to // (ld (addi (auipc %got_pcrel_hi(sym)) %pcrel_lo(auipc))). + SDValue Load = + SDValue(DAG.getMachineNode(RISCV::PseudoLGA, DL, Ty, Addr), 0); MachineFunction &MF = DAG.getMachineFunction(); MachineMemOperand *MemOp = MF.getMachineMemOperand( MachinePointerInfo::getGOT(MF), MachineMemOperand::MOLoad | MachineMemOperand::MODereferenceable | MachineMemOperand::MOInvariant, LLT(Ty.getSimpleVT()), Align(Ty.getFixedSizeInBits() / 8)); - SDValue Load = - DAG.getMemIntrinsicNode(RISCVISD::LGA, DL, - DAG.getVTList(Ty, MVT::Other), - {DAG.getEntryNode(), Addr}, Ty, MemOp); + DAG.setNodeMemRefs(cast<MachineSDNode>(Load.getNode()), {MemOp}); return Load; } @@ -5722,15 +6746,15 @@ SDValue RISCVTargetLowering::getStaticTLSAddr(GlobalAddressSDNode *N, // the pattern (PseudoLA_TLS_IE sym), which expands to // (ld (auipc %tls_ie_pcrel_hi(sym)) %pcrel_lo(auipc)). SDValue Addr = DAG.getTargetGlobalAddress(GV, DL, Ty, 0, 0); + SDValue Load = + SDValue(DAG.getMachineNode(RISCV::PseudoLA_TLS_IE, DL, Ty, Addr), 0); MachineFunction &MF = DAG.getMachineFunction(); MachineMemOperand *MemOp = MF.getMachineMemOperand( MachinePointerInfo::getGOT(MF), MachineMemOperand::MOLoad | MachineMemOperand::MODereferenceable | MachineMemOperand::MOInvariant, LLT(Ty.getSimpleVT()), Align(Ty.getFixedSizeInBits() / 8)); - SDValue Load = DAG.getMemIntrinsicNode( - RISCVISD::LA_TLS_IE, DL, DAG.getVTList(Ty, MVT::Other), - {DAG.getEntryNode(), Addr}, Ty, MemOp); + DAG.setNodeMemRefs(cast<MachineSDNode>(Load.getNode()), {MemOp}); // Add the thread pointer. SDValue TPReg = DAG.getRegister(RISCV::X4, XLenVT); @@ -5766,7 +6790,8 @@ SDValue RISCVTargetLowering::getDynamicTLSAddr(GlobalAddressSDNode *N, // This generates the pattern (PseudoLA_TLS_GD sym), which expands to // (addi (auipc %tls_gd_pcrel_hi(sym)) %pcrel_lo(auipc)). 
SDValue Addr = DAG.getTargetGlobalAddress(GV, DL, Ty, 0, 0); - SDValue Load = DAG.getNode(RISCVISD::LA_TLS_GD, DL, Ty, Addr); + SDValue Load = + SDValue(DAG.getMachineNode(RISCV::PseudoLA_TLS_GD, DL, Ty, Addr), 0); // Prepare argument list to generate call. ArgListTy Args; @@ -5902,56 +6927,6 @@ static SDValue combineSelectToBinOp(SDNode *N, SelectionDAG &DAG, return SDValue(); } -/// RISC-V doesn't have general instructions for integer setne/seteq, but we can -/// check for equality with 0. This function emits nodes that convert the -/// seteq/setne into something that can be compared with 0. -/// Based on RISCVDAGToDAGISel::selectSETCC but modified to produce -/// target-independent SelectionDAG nodes rather than machine nodes. -static SDValue selectSETCC(SDValue N, ISD::CondCode ExpectedCCVal, - SelectionDAG &DAG) { - assert(ISD::isIntEqualitySetCC(ExpectedCCVal) && - "Unexpected condition code!"); - - // We're looking for a setcc. - if (N->getOpcode() != ISD::SETCC) - return SDValue(); - - // Must be an equality comparison. - ISD::CondCode CCVal = cast<CondCodeSDNode>(N->getOperand(2))->get(); - if (CCVal != ExpectedCCVal) - return SDValue(); - - SDValue LHS = N->getOperand(0); - SDValue RHS = N->getOperand(1); - - if (!LHS.getValueType().isScalarInteger()) - return SDValue(); - - // If the RHS side is 0, we don't need any extra instructions, return the LHS. - if (isNullConstant(RHS)) - return LHS; - - SDLoc DL(N); - - if (auto *C = dyn_cast<ConstantSDNode>(RHS)) { - int64_t CVal = C->getSExtValue(); - // If the RHS is -2048, we can use xori to produce 0 if the LHS is -2048 and - // non-zero otherwise. - if (CVal == -2048) - return DAG.getNode(ISD::XOR, DL, N->getValueType(0), LHS, - DAG.getConstant(CVal, DL, N->getValueType(0))); - // If the RHS is [-2047,2048], we can use addi with -RHS to produce 0 if the - // LHS is equal to the RHS and non-zero otherwise. - if (isInt<12>(CVal) || CVal == 2048) - return DAG.getNode(ISD::ADD, DL, N->getValueType(0), LHS, - DAG.getConstant(-CVal, DL, N->getValueType(0))); - } - - // If nothing else we can XOR the LHS and RHS to produce zero if they are - // equal and a non-zero value if they aren't. - return DAG.getNode(ISD::XOR, DL, N->getValueType(0), LHS, RHS); -} - // Transform `binOp (select cond, x, c0), c1` where `c0` and `c1` are constants // into `select cond, binOp(x, c1), binOp(c0, c1)` if profitable. // For now we only consider transformation profitable if `binOp(c0, c1)` ends up @@ -6039,35 +7014,6 @@ SDValue RISCVTargetLowering::lowerSELECT(SDValue Op, SelectionDAG &DAG) const { // sequence or RISCVISD::SELECT_CC node (branch-based select). 
if ((Subtarget.hasStdExtZicond() || Subtarget.hasVendorXVentanaCondOps()) && VT.isScalarInteger()) { - if (SDValue NewCondV = selectSETCC(CondV, ISD::SETNE, DAG)) { - // (select (riscv_setne c), t, 0) -> (czero_eqz t, c) - if (isNullConstant(FalseV)) - return DAG.getNode(RISCVISD::CZERO_EQZ, DL, VT, TrueV, NewCondV); - // (select (riscv_setne c), 0, f) -> (czero_nez f, c) - if (isNullConstant(TrueV)) - return DAG.getNode(RISCVISD::CZERO_NEZ, DL, VT, FalseV, NewCondV); - // (select (riscv_setne c), t, f) -> (or (czero_eqz t, c), (czero_nez f, - // c) - return DAG.getNode( - ISD::OR, DL, VT, - DAG.getNode(RISCVISD::CZERO_EQZ, DL, VT, TrueV, NewCondV), - DAG.getNode(RISCVISD::CZERO_NEZ, DL, VT, FalseV, NewCondV)); - } - if (SDValue NewCondV = selectSETCC(CondV, ISD::SETEQ, DAG)) { - // (select (riscv_seteq c), t, 0) -> (czero_nez t, c) - if (isNullConstant(FalseV)) - return DAG.getNode(RISCVISD::CZERO_NEZ, DL, VT, TrueV, NewCondV); - // (select (riscv_seteq c), 0, f) -> (czero_eqz f, c) - if (isNullConstant(TrueV)) - return DAG.getNode(RISCVISD::CZERO_EQZ, DL, VT, FalseV, NewCondV); - // (select (riscv_seteq c), t, f) -> (or (czero_eqz f, c), (czero_nez t, - // c) - return DAG.getNode( - ISD::OR, DL, VT, - DAG.getNode(RISCVISD::CZERO_EQZ, DL, VT, FalseV, NewCondV), - DAG.getNode(RISCVISD::CZERO_NEZ, DL, VT, TrueV, NewCondV)); - } - // (select c, t, 0) -> (czero_eqz t, c) if (isNullConstant(FalseV)) return DAG.getNode(RISCVISD::CZERO_EQZ, DL, VT, TrueV, CondV); @@ -6088,10 +7034,17 @@ SDValue RISCVTargetLowering::lowerSELECT(SDValue Op, SelectionDAG &DAG) const { ISD::OR, DL, VT, FalseV, DAG.getNode(RISCVISD::CZERO_EQZ, DL, VT, TrueV, CondV)); + // Try some other optimizations before falling back to generic lowering. + if (SDValue V = combineSelectToBinOp(Op.getNode(), DAG, Subtarget)) + return V; + // (select c, t, f) -> (or (czero_eqz t, c), (czero_nez f, c)) - return DAG.getNode(ISD::OR, DL, VT, - DAG.getNode(RISCVISD::CZERO_EQZ, DL, VT, TrueV, CondV), - DAG.getNode(RISCVISD::CZERO_NEZ, DL, VT, FalseV, CondV)); + // Unless we have the short forward branch optimization. 
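The Zicond/XVentanaCondOps lowering above relies on czero.eqz (keep the value unless the condition is zero) and czero.nez (keep the value unless the condition is non-zero): exactly one operand of the OR survives, which makes the OR a select. A small self-contained model of that identity:

#include <cassert>
#include <cstdint>

// Models of the conditional-zero instructions.
uint64_t czero_eqz(uint64_t Val, uint64_t Cond) { return Cond != 0 ? Val : 0; }
uint64_t czero_nez(uint64_t Val, uint64_t Cond) { return Cond != 0 ? 0 : Val; }

// (select c, t, f) == (czero_eqz t, c) | (czero_nez f, c):
// one of the two terms is always zero, so the OR picks the other.
uint64_t selectViaCzero(uint64_t C, uint64_t T, uint64_t F) {
  return czero_eqz(T, C) | czero_nez(F, C);
}

int main() {
  assert(selectViaCzero(1, 0x1234, 0x5678) == 0x1234);
  assert(selectViaCzero(0, 0x1234, 0x5678) == 0x5678);
  return 0;
}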
+ if (!Subtarget.hasShortForwardBranchOpt()) + return DAG.getNode( + ISD::OR, DL, VT, + DAG.getNode(RISCVISD::CZERO_EQZ, DL, VT, TrueV, CondV), + DAG.getNode(RISCVISD::CZERO_NEZ, DL, VT, FalseV, CondV)); } if (SDValue V = combineSelectToBinOp(Op.getNode(), DAG, Subtarget)) @@ -6295,7 +7248,7 @@ SDValue RISCVTargetLowering::lowerShiftLeftParts(SDValue Op, // if Shamt-XLEN < 0: // Shamt < XLEN // Lo = Lo << Shamt - // Hi = (Hi << Shamt) | ((Lo >>u 1) >>u (XLEN-1 ^ Shamt)) + // Hi = (Hi << Shamt) | ((Lo >>u 1) >>u (XLEN-1 - Shamt)) // else: // Lo = 0 // Hi = Lo << (Shamt-XLEN) @@ -6334,7 +7287,7 @@ SDValue RISCVTargetLowering::lowerShiftRightParts(SDValue Op, SelectionDAG &DAG, // SRA expansion: // if Shamt-XLEN < 0: // Shamt < XLEN - // Lo = (Lo >>u Shamt) | ((Hi << 1) << (ShAmt ^ XLEN-1)) + // Lo = (Lo >>u Shamt) | ((Hi << 1) << (XLEN-1 - ShAmt)) // Hi = Hi >>s Shamt // else: // Lo = Hi >>s (Shamt-XLEN); @@ -6342,7 +7295,7 @@ SDValue RISCVTargetLowering::lowerShiftRightParts(SDValue Op, SelectionDAG &DAG, // // SRL expansion: // if Shamt-XLEN < 0: // Shamt < XLEN - // Lo = (Lo >>u Shamt) | ((Hi << 1) << (ShAmt ^ XLEN-1)) + // Lo = (Lo >>u Shamt) | ((Hi << 1) << (XLEN-1 - ShAmt)) // Hi = Hi >>u Shamt // else: // Lo = Hi >>u (Shamt-XLEN); @@ -6392,12 +7345,9 @@ SDValue RISCVTargetLowering::lowerVectorMaskSplat(SDValue Op, SDValue VL = getDefaultScalableVLOps(VT, DL, DAG, Subtarget).second; return DAG.getNode(RISCVISD::VMCLR_VL, DL, VT, VL); } - MVT XLenVT = Subtarget.getXLenVT(); - assert(SplatVal.getValueType() == XLenVT && - "Unexpected type for i1 splat value"); MVT InterVT = VT.changeVectorElementType(MVT::i8); - SplatVal = DAG.getNode(ISD::AND, DL, XLenVT, SplatVal, - DAG.getConstant(1, DL, XLenVT)); + SplatVal = DAG.getNode(ISD::AND, DL, SplatVal.getValueType(), SplatVal, + DAG.getConstant(1, DL, SplatVal.getValueType())); SDValue LHS = DAG.getSplatVector(InterVT, DL, SplatVal); SDValue Zero = DAG.getConstant(0, DL, InterVT); return DAG.getSetCC(DL, VT, LHS, Zero, ISD::SETNE); @@ -6418,37 +7368,19 @@ SDValue RISCVTargetLowering::lowerSPLAT_VECTOR_PARTS(SDValue Op, SDValue Lo = Op.getOperand(0); SDValue Hi = Op.getOperand(1); - if (VecVT.isFixedLengthVector()) { - MVT ContainerVT = getContainerForFixedLengthVector(VecVT); - SDLoc DL(Op); - auto VL = getDefaultVLOps(VecVT, ContainerVT, DL, DAG, Subtarget).second; + MVT ContainerVT = VecVT; + if (VecVT.isFixedLengthVector()) + ContainerVT = getContainerForFixedLengthVector(VecVT); - SDValue Res = - splatPartsI64WithVL(DL, ContainerVT, SDValue(), Lo, Hi, VL, DAG); - return convertFromScalableVector(VecVT, Res, DAG, Subtarget); - } + auto VL = getDefaultVLOps(VecVT, ContainerVT, DL, DAG, Subtarget).second; - if (isa<ConstantSDNode>(Lo) && isa<ConstantSDNode>(Hi)) { - int32_t LoC = cast<ConstantSDNode>(Lo)->getSExtValue(); - int32_t HiC = cast<ConstantSDNode>(Hi)->getSExtValue(); - // If Hi constant is all the same sign bit as Lo, lower this as a custom - // node in order to try and match RVV vector/scalar instructions. - if ((LoC >> 31) == HiC) - return DAG.getNode(RISCVISD::VMV_V_X_VL, DL, VecVT, DAG.getUNDEF(VecVT), - Lo, DAG.getRegister(RISCV::X0, MVT::i32)); - } + SDValue Res = + splatPartsI64WithVL(DL, ContainerVT, SDValue(), Lo, Hi, VL, DAG); - // Detect cases where Hi is (SRA Lo, 31) which means Hi is Lo sign extended. 
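The comment fixes above swap "XLEN-1 ^ Shamt" for the equivalent "XLEN-1 - Shamt" (the two agree whenever 0 <= Shamt < XLEN, since XLEN-1 is all ones in the shift-amount width). A standalone check of the documented SRL expansion for XLEN = 32:

#include <cassert>
#include <cstdint>

// Logical right shift of a 64-bit value held as two 32-bit halves, for the
// "Shamt < XLEN" branch described in the comment. (Hi << 1) << (31 - Shamt)
// computes Hi << (32 - Shamt) without ever using a shift amount of 32.
void srl64Parts(uint32_t Lo, uint32_t Hi, unsigned Shamt,
                uint32_t &LoOut, uint32_t &HiOut) {
  LoOut = (Lo >> Shamt) | ((Hi << 1) << (31 - Shamt));
  HiOut = Hi >> Shamt;
}

int main() {
  for (unsigned Shamt = 0; Shamt < 32; ++Shamt) {
    uint64_t X = 0x123456789abcdef0ULL;
    uint32_t Lo, Hi;
    srl64Parts(uint32_t(X), uint32_t(X >> 32), Shamt, Lo, Hi);
    uint64_t Expect = X >> Shamt;
    assert(Lo == uint32_t(Expect) && Hi == uint32_t(Expect >> 32));
  }
  return 0;
}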
- if (Hi.getOpcode() == ISD::SRA && Hi.getOperand(0) == Lo && - isa<ConstantSDNode>(Hi.getOperand(1)) && - Hi.getConstantOperandVal(1) == 31) - return DAG.getNode(RISCVISD::VMV_V_X_VL, DL, VecVT, DAG.getUNDEF(VecVT), Lo, - DAG.getRegister(RISCV::X0, MVT::i32)); + if (VecVT.isFixedLengthVector()) + Res = convertFromScalableVector(VecVT, Res, DAG, Subtarget); - // Fall back to use a stack store and stride x0 vector load. Use X0 as VL. - return DAG.getNode(RISCVISD::SPLAT_VECTOR_SPLIT_I64_VL, DL, VecVT, - DAG.getUNDEF(VecVT), Lo, Hi, - DAG.getRegister(RISCV::X0, MVT::i32)); + return Res; } // Custom-lower extensions from mask vectors by using a vselect either with 1 @@ -6752,6 +7684,32 @@ RISCVTargetLowering::lowerVectorFPExtendOrRoundLike(SDValue Op, return Result; } +// Given a scalable vector type and an index into it, returns the type for the +// smallest subvector that the index fits in. This can be used to reduce LMUL +// for operations like vslidedown. +// +// E.g. With Zvl128b, index 3 in a nxv4i32 fits within the first nxv2i32. +static std::optional<MVT> +getSmallestVTForIndex(MVT VecVT, unsigned MaxIdx, SDLoc DL, SelectionDAG &DAG, + const RISCVSubtarget &Subtarget) { + assert(VecVT.isScalableVector()); + const unsigned EltSize = VecVT.getScalarSizeInBits(); + const unsigned VectorBitsMin = Subtarget.getRealMinVLen(); + const unsigned MinVLMAX = VectorBitsMin / EltSize; + MVT SmallerVT; + if (MaxIdx < MinVLMAX) + SmallerVT = getLMUL1VT(VecVT); + else if (MaxIdx < MinVLMAX * 2) + SmallerVT = getLMUL1VT(VecVT).getDoubleNumVectorElementsVT(); + else if (MaxIdx < MinVLMAX * 4) + SmallerVT = getLMUL1VT(VecVT) + .getDoubleNumVectorElementsVT() + .getDoubleNumVectorElementsVT(); + if (!SmallerVT.isValid() || !VecVT.bitsGT(SmallerVT)) + return std::nullopt; + return SmallerVT; +} + // Custom-legalize INSERT_VECTOR_ELT so that the value is inserted into the // first position of a vector, and that vector is slid up to the insert index. // By limiting the active vector length to index+1 and merging with the @@ -6782,6 +7740,43 @@ SDValue RISCVTargetLowering::lowerINSERT_VECTOR_ELT(SDValue Op, Vec = convertToScalableVector(ContainerVT, Vec, DAG, Subtarget); } + // If we know the index we're going to insert at, we can shrink Vec so that + // we're performing the scalar inserts and slideup on a smaller LMUL. + MVT OrigContainerVT = ContainerVT; + SDValue OrigVec = Vec; + SDValue AlignedIdx; + if (auto *IdxC = dyn_cast<ConstantSDNode>(Idx)) { + const unsigned OrigIdx = IdxC->getZExtValue(); + // Do we know an upper bound on LMUL? + if (auto ShrunkVT = getSmallestVTForIndex(ContainerVT, OrigIdx, + DL, DAG, Subtarget)) { + ContainerVT = *ShrunkVT; + AlignedIdx = DAG.getVectorIdxConstant(0, DL); + } + + // If we're compiling for an exact VLEN value, we can always perform + // the insert in m1 as we can determine the register corresponding to + // the index in the register group. 
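getSmallestVTForIndex buckets a known maximum index by how many LMUL-1 registers are guaranteed to precede it, using only the guaranteed minimum VLEN. A simplified model of that bucketing, returning the number of vector registers (1, 2 or 4) that are certain to cover the index:

#include <cassert>
#include <optional>

// Elements per register at LMUL=1 is MinVLen / EltBits; an index below that
// bound fits in one register, below twice the bound in two, and so on.
std::optional<unsigned> smallestCoveringVRegs(unsigned MinVLenBits,
                                              unsigned EltBits,
                                              unsigned MaxIdx) {
  unsigned MinVLMAX = MinVLenBits / EltBits;
  if (MaxIdx < MinVLMAX)
    return 1;
  if (MaxIdx < 2 * MinVLMAX)
    return 2;
  if (MaxIdx < 4 * MinVLMAX)
    return 4;
  return std::nullopt; // no profitable shrinking
}

int main() {
  // Zvl128b, i32 elements: index 3 is guaranteed to sit in the first
  // register (the nxv4i32 -> nxv2i32 example above); index 5 needs two.
  assert(smallestCoveringVRegs(128, 32, 3) == 1u);
  assert(smallestCoveringVRegs(128, 32, 5) == 2u);
  return 0;
}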
+ const unsigned MinVLen = Subtarget.getRealMinVLen(); + const unsigned MaxVLen = Subtarget.getRealMaxVLen(); + const MVT M1VT = getLMUL1VT(ContainerVT); + if (MinVLen == MaxVLen && ContainerVT.bitsGT(M1VT)) { + EVT ElemVT = VecVT.getVectorElementType(); + unsigned ElemsPerVReg = MinVLen / ElemVT.getFixedSizeInBits(); + unsigned RemIdx = OrigIdx % ElemsPerVReg; + unsigned SubRegIdx = OrigIdx / ElemsPerVReg; + unsigned ExtractIdx = + SubRegIdx * M1VT.getVectorElementCount().getKnownMinValue(); + AlignedIdx = DAG.getVectorIdxConstant(ExtractIdx, DL); + Idx = DAG.getVectorIdxConstant(RemIdx, DL); + ContainerVT = M1VT; + } + + if (AlignedIdx) + Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ContainerVT, Vec, + AlignedIdx); + } + MVT XLenVT = Subtarget.getXLenVT(); bool IsLegalInsert = Subtarget.is64Bit() || Val.getValueType() != MVT::i64; @@ -6805,7 +7800,13 @@ SDValue RISCVTargetLowering::lowerINSERT_VECTOR_ELT(SDValue Op, unsigned Opc = VecVT.isFloatingPoint() ? RISCVISD::VFMV_S_F_VL : RISCVISD::VMV_S_X_VL; if (isNullConstant(Idx)) { + if (!VecVT.isFloatingPoint()) + Val = DAG.getNode(ISD::ANY_EXTEND, DL, XLenVT, Val); Vec = DAG.getNode(Opc, DL, ContainerVT, Vec, Val, VL); + + if (AlignedIdx) + Vec = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, OrigContainerVT, OrigVec, + Vec, AlignedIdx); if (!VecVT.isFixedLengthVector()) return Vec; return convertFromScalableVector(VecVT, Vec, DAG, Subtarget); @@ -6838,6 +7839,10 @@ SDValue RISCVTargetLowering::lowerINSERT_VECTOR_ELT(SDValue Op, // Bitcast back to the right container type. ValInVec = DAG.getBitcast(ContainerVT, ValInVec); + if (AlignedIdx) + ValInVec = + DAG.getNode(ISD::INSERT_SUBVECTOR, DL, OrigContainerVT, OrigVec, + ValInVec, AlignedIdx); if (!VecVT.isFixedLengthVector()) return ValInVec; return convertFromScalableVector(VecVT, ValInVec, DAG, Subtarget); @@ -6868,6 +7873,10 @@ SDValue RISCVTargetLowering::lowerINSERT_VECTOR_ELT(SDValue Op, Policy = RISCVII::TAIL_AGNOSTIC; SDValue Slideup = getVSlideup(DAG, Subtarget, DL, ContainerVT, Vec, ValInVec, Idx, Mask, InsertVL, Policy); + + if (AlignedIdx) + Slideup = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, OrigContainerVT, OrigVec, + Slideup, AlignedIdx); if (!VecVT.isFixedLengthVector()) return Slideup; return convertFromScalableVector(VecVT, Slideup, DAG, Subtarget); @@ -6897,8 +7906,9 @@ SDValue RISCVTargetLowering::lowerEXTRACT_VECTOR_ELT(SDValue Op, auto [Mask, VL] = getDefaultVLOps(VecVT, ContainerVT, DL, DAG, Subtarget); SDValue Vfirst = DAG.getNode(RISCVISD::VFIRST_VL, DL, XLenVT, Vec, Mask, VL); - return DAG.getSetCC(DL, XLenVT, Vfirst, DAG.getConstant(0, DL, XLenVT), - ISD::SETEQ); + SDValue Res = DAG.getSetCC(DL, XLenVT, Vfirst, + DAG.getConstant(0, DL, XLenVT), ISD::SETEQ); + return DAG.getNode(ISD::TRUNCATE, DL, EltVT, Res); } if (VecVT.isFixedLengthVector()) { unsigned NumElts = VecVT.getVectorNumElements(); @@ -6907,7 +7917,7 @@ SDValue RISCVTargetLowering::lowerEXTRACT_VECTOR_ELT(SDValue Op, unsigned WidenVecLen; SDValue ExtractElementIdx; SDValue ExtractBitIdx; - unsigned MaxEEW = Subtarget.getELEN(); + unsigned MaxEEW = Subtarget.getELen(); MVT LargestEltVT = MVT::getIntegerVT( std::min(MaxEEW, unsigned(XLenVT.getSizeInBits()))); if (NumElts <= LargestEltVT.getSizeInBits()) { @@ -6936,8 +7946,9 @@ SDValue RISCVTargetLowering::lowerEXTRACT_VECTOR_ELT(SDValue Op, // Extract the bit from GPR. 
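When VLEN is known exactly, an element index decomposes into a register number within the group and an offset inside that register, which is what lets the insert run at LMUL 1. The arithmetic in isolation:

#include <cassert>

// With an exactly known VLEN, element OrigIdx of an LMUL>1 register group
// lives in vreg (OrigIdx / ElemsPerVReg) at offset (OrigIdx % ElemsPerVReg).
struct VRegPos { unsigned SubReg, Rem; };

VRegPos splitIndexForExactVLEN(unsigned OrigIdx, unsigned VLenBits,
                               unsigned EltBits) {
  unsigned ElemsPerVReg = VLenBits / EltBits;
  return {OrigIdx / ElemsPerVReg, OrigIdx % ElemsPerVReg};
}

int main() {
  // VLEN = 128, i32 elements: index 9 of an LMUL=4 group is element 1 of
  // the group's third vector register.
  VRegPos P = splitIndexForExactVLEN(9, 128, 32);
  assert(P.SubReg == 2 && P.Rem == 1);
  return 0;
}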
SDValue ShiftRight = DAG.getNode(ISD::SRL, DL, XLenVT, ExtractElt, ExtractBitIdx); - return DAG.getNode(ISD::AND, DL, XLenVT, ShiftRight, - DAG.getConstant(1, DL, XLenVT)); + SDValue Res = DAG.getNode(ISD::AND, DL, XLenVT, ShiftRight, + DAG.getConstant(1, DL, XLenVT)); + return DAG.getNode(ISD::TRUNCATE, DL, EltVT, Res); } } // Otherwise, promote to an i8 vector and extract from that. @@ -6953,6 +7964,61 @@ SDValue RISCVTargetLowering::lowerEXTRACT_VECTOR_ELT(SDValue Op, Vec = convertToScalableVector(ContainerVT, Vec, DAG, Subtarget); } + // If we're compiling for an exact VLEN value and we have a known + // constant index, we can always perform the extract in m1 (or + // smaller) as we can determine the register corresponding to + // the index in the register group. + const unsigned MinVLen = Subtarget.getRealMinVLen(); + const unsigned MaxVLen = Subtarget.getRealMaxVLen(); + if (auto *IdxC = dyn_cast<ConstantSDNode>(Idx); + IdxC && MinVLen == MaxVLen && + VecVT.getSizeInBits().getKnownMinValue() > MinVLen) { + MVT M1VT = getLMUL1VT(ContainerVT); + unsigned OrigIdx = IdxC->getZExtValue(); + EVT ElemVT = VecVT.getVectorElementType(); + unsigned ElemsPerVReg = MinVLen / ElemVT.getFixedSizeInBits(); + unsigned RemIdx = OrigIdx % ElemsPerVReg; + unsigned SubRegIdx = OrigIdx / ElemsPerVReg; + unsigned ExtractIdx = + SubRegIdx * M1VT.getVectorElementCount().getKnownMinValue(); + Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, M1VT, Vec, + DAG.getVectorIdxConstant(ExtractIdx, DL)); + Idx = DAG.getVectorIdxConstant(RemIdx, DL); + ContainerVT = M1VT; + } + + // Reduce the LMUL of our slidedown and vmv.x.s to the smallest LMUL which + // contains our index. + std::optional<uint64_t> MaxIdx; + if (VecVT.isFixedLengthVector()) + MaxIdx = VecVT.getVectorNumElements() - 1; + if (auto *IdxC = dyn_cast<ConstantSDNode>(Idx)) + MaxIdx = IdxC->getZExtValue(); + if (MaxIdx) { + if (auto SmallerVT = + getSmallestVTForIndex(ContainerVT, *MaxIdx, DL, DAG, Subtarget)) { + ContainerVT = *SmallerVT; + Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ContainerVT, Vec, + DAG.getConstant(0, DL, XLenVT)); + } + } + + // If after narrowing, the required slide is still greater than LMUL2, + // fallback to generic expansion and go through the stack. This is done + // for a subtle reason: extracting *all* elements out of a vector is + // widely expected to be linear in vector size, but because vslidedown + // is linear in LMUL, performing N extracts using vslidedown becomes + // O(n^2) / (VLEN/ETYPE) work. On the surface, going through the stack + // seems to have the same problem (the store is linear in LMUL), but the + // generic expansion *memoizes* the store, and thus for many extracts of + // the same vector we end up with one store and a bunch of loads. + // TODO: We don't have the same code for insert_vector_elt because we + // have BUILD_VECTOR and handle the degenerate case there. Should we + // consider adding an inverse BUILD_VECTOR node? + MVT LMUL2VT = getLMUL1VT(ContainerVT).getDoubleNumVectorElementsVT(); + if (ContainerVT.bitsGT(LMUL2VT) && VecVT.isFixedLengthVector()) + return SDValue(); + // If the index is 0, the vector is already in the right position. if (!isNullConstant(Idx)) { // Use a VL of 1 to avoid processing more elements than we need. @@ -7180,7 +8246,7 @@ static SDValue lowerGetVectorLength(SDNode *N, SelectionDAG &DAG, // Determine the VF that corresponds to LMUL 1 for ElementWidth. unsigned LMul1VF = RISCV::RVVBitsPerBlock / ElementWidth; // We don't support VF==1 with ELEN==32. 
- unsigned MinVF = RISCV::RVVBitsPerBlock / Subtarget.getELEN(); + unsigned MinVF = RISCV::RVVBitsPerBlock / Subtarget.getELen(); unsigned VF = N->getConstantOperandVal(2); assert(VF >= MinVF && VF <= (LMul1VF * 8) && isPowerOf2_32(VF) && @@ -7200,7 +8266,39 @@ static SDValue lowerGetVectorLength(SDNode *N, SelectionDAG &DAG, SDValue AVL = DAG.getNode(ISD::ZERO_EXTEND, DL, XLenVT, N->getOperand(1)); SDValue ID = DAG.getTargetConstant(Intrinsic::riscv_vsetvli, DL, XLenVT); - return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, XLenVT, ID, AVL, Sew, LMul); + SDValue Res = + DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, XLenVT, ID, AVL, Sew, LMul); + return DAG.getNode(ISD::TRUNCATE, DL, N->getValueType(0), Res); +} + +static void getVCIXOperands(SDValue &Op, SelectionDAG &DAG, + SmallVector<SDValue> &Ops) { + SDLoc DL(Op); + + const RISCVSubtarget &Subtarget = + DAG.getMachineFunction().getSubtarget<RISCVSubtarget>(); + for (const SDValue &V : Op->op_values()) { + EVT ValType = V.getValueType(); + if (ValType.isScalableVector() && ValType.isFloatingPoint()) { + MVT InterimIVT = + MVT::getVectorVT(MVT::getIntegerVT(ValType.getScalarSizeInBits()), + ValType.getVectorElementCount()); + Ops.push_back(DAG.getBitcast(InterimIVT, V)); + } else if (ValType.isFixedLengthVector()) { + MVT OpContainerVT = getContainerForFixedLengthVector( + DAG, V.getSimpleValueType(), Subtarget); + Ops.push_back(convertToScalableVector(OpContainerVT, V, DAG, Subtarget)); + } else + Ops.push_back(V); + } +} + +// LMUL * VLEN should be greater than or equal to EGS * SEW +static inline bool isValidEGW(int EGS, EVT VT, + const RISCVSubtarget &Subtarget) { + return (Subtarget.getRealMinVLen() * + VT.getSizeInBits().getKnownMinValue()) / RISCV::RVVBitsPerBlock >= + EGS * VT.getScalarSizeInBits(); } SDValue RISCVTargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op, @@ -7236,12 +8334,30 @@ SDValue RISCVTargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op, case Intrinsic::riscv_sm3p1: Opc = RISCVISD::SM3P1; break; } + if (RV64LegalI32 && Subtarget.is64Bit() && Op.getValueType() == MVT::i32) { + SDValue NewOp = + DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i64, Op.getOperand(1)); + SDValue Res = DAG.getNode(Opc, DL, MVT::i64, NewOp); + return DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, Res); + } + return DAG.getNode(Opc, DL, XLenVT, Op.getOperand(1)); } case Intrinsic::riscv_sm4ks: case Intrinsic::riscv_sm4ed: { unsigned Opc = IntNo == Intrinsic::riscv_sm4ks ? 
RISCVISD::SM4KS : RISCVISD::SM4ED; + + if (RV64LegalI32 && Subtarget.is64Bit() && Op.getValueType() == MVT::i32) { + SDValue NewOp0 = + DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i64, Op.getOperand(1)); + SDValue NewOp1 = + DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i64, Op.getOperand(2)); + SDValue Res = + DAG.getNode(Opc, DL, MVT::i64, NewOp0, NewOp1, Op.getOperand(3)); + return DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, Res); + } + return DAG.getNode(Opc, DL, XLenVT, Op.getOperand(1), Op.getOperand(2), Op.getOperand(3)); } @@ -7252,20 +8368,43 @@ SDValue RISCVTargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op, return DAG.getNode(Opc, DL, XLenVT, Op.getOperand(1)); } case Intrinsic::riscv_clmul: + if (RV64LegalI32 && Subtarget.is64Bit() && Op.getValueType() == MVT::i32) { + SDValue NewOp0 = + DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i64, Op.getOperand(1)); + SDValue NewOp1 = + DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i64, Op.getOperand(2)); + SDValue Res = DAG.getNode(RISCVISD::CLMUL, DL, MVT::i64, NewOp0, NewOp1); + return DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, Res); + } return DAG.getNode(RISCVISD::CLMUL, DL, XLenVT, Op.getOperand(1), Op.getOperand(2)); case Intrinsic::riscv_clmulh: - return DAG.getNode(RISCVISD::CLMULH, DL, XLenVT, Op.getOperand(1), - Op.getOperand(2)); - case Intrinsic::riscv_clmulr: - return DAG.getNode(RISCVISD::CLMULR, DL, XLenVT, Op.getOperand(1), - Op.getOperand(2)); + case Intrinsic::riscv_clmulr: { + unsigned Opc = + IntNo == Intrinsic::riscv_clmulh ? RISCVISD::CLMULH : RISCVISD::CLMULR; + if (RV64LegalI32 && Subtarget.is64Bit() && Op.getValueType() == MVT::i32) { + SDValue NewOp0 = + DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i64, Op.getOperand(1)); + SDValue NewOp1 = + DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i64, Op.getOperand(2)); + NewOp0 = DAG.getNode(ISD::SHL, DL, MVT::i64, NewOp0, + DAG.getConstant(32, DL, MVT::i64)); + NewOp1 = DAG.getNode(ISD::SHL, DL, MVT::i64, NewOp1, + DAG.getConstant(32, DL, MVT::i64)); + SDValue Res = DAG.getNode(Opc, DL, MVT::i64, NewOp0, NewOp1); + Res = DAG.getNode(ISD::SRL, DL, MVT::i64, Res, + DAG.getConstant(32, DL, MVT::i64)); + return DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, Res); + } + + return DAG.getNode(Opc, DL, XLenVT, Op.getOperand(1), Op.getOperand(2)); + } case Intrinsic::experimental_get_vector_length: return lowerGetVectorLength(Op.getNode(), DAG, Subtarget); - case Intrinsic::riscv_vmv_x_s: - assert(Op.getValueType() == XLenVT && "Unexpected VT!"); - return DAG.getNode(RISCVISD::VMV_X_S, DL, Op.getValueType(), - Op.getOperand(1)); + case Intrinsic::riscv_vmv_x_s: { + SDValue Res = DAG.getNode(RISCVISD::VMV_X_S, DL, XLenVT, Op.getOperand(1)); + return DAG.getNode(ISD::TRUNCATE, DL, Op.getValueType(), Res); + } case Intrinsic::riscv_vfmv_f_s: return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, Op.getValueType(), Op.getOperand(1), DAG.getConstant(0, DL, XLenVT)); @@ -7323,6 +8462,86 @@ SDValue RISCVTargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op, return DAG.getNode(RISCVISD::VSELECT_VL, DL, VT, SelectCond, SplattedVal, Vec, VL); } + // EGS * EEW >= 128 bits + case Intrinsic::riscv_vaesdf_vv: + case Intrinsic::riscv_vaesdf_vs: + case Intrinsic::riscv_vaesdm_vv: + case Intrinsic::riscv_vaesdm_vs: + case Intrinsic::riscv_vaesef_vv: + case Intrinsic::riscv_vaesef_vs: + case Intrinsic::riscv_vaesem_vv: + case Intrinsic::riscv_vaesem_vs: + case Intrinsic::riscv_vaeskf1: + case Intrinsic::riscv_vaeskf2: + case Intrinsic::riscv_vaesz_vs: + case Intrinsic::riscv_vsm4k: + case Intrinsic::riscv_vsm4r_vv: + case Intrinsic::riscv_vsm4r_vs: { + if 
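The RV64LegalI32 path above legalizes a 32-bit clmulh/clmulr by moving both operands into the upper halves, running the 64-bit operation, and shifting the result back down. A self-contained check of that identity for clmulh, using a plain software carry-less multiply purely for verification:

#include <cassert>
#include <cstdint>

// Carry-less (XOR) multiply helpers.
static uint64_t clmulLo64(uint64_t A, uint64_t B) { // low 64 product bits
  uint64_t R = 0;
  for (unsigned I = 0; I < 64; ++I)
    if ((B >> I) & 1)
      R ^= A << I;
  return R;
}

static uint64_t clmulHi64(uint64_t A, uint64_t B) { // high 64 product bits
  uint64_t R = 0;
  for (unsigned I = 1; I < 64; ++I)
    if ((B >> I) & 1)
      R ^= A >> (64 - I);
  return R;
}

// 32-bit clmulh is bits [63:32] of the 64-bit product of two 32-bit values.
static uint32_t clmulHi32(uint32_t A, uint32_t B) {
  return uint32_t(clmulLo64(A, B) >> 32);
}

int main() {
  uint32_t A = 0x9e3779b9, B = 0x85ebca6b;
  // Emulation used by the lowering: shift the operands up by 32, do the
  // 64-bit clmulh, shift the result back down, then truncate.
  uint64_t Wide = clmulHi64(uint64_t(A) << 32, uint64_t(B) << 32) >> 32;
  assert(uint32_t(Wide) == clmulHi32(A, B));
  return 0;
}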
(!isValidEGW(4, Op.getSimpleValueType(), Subtarget) || + !isValidEGW(4, Op->getOperand(1).getSimpleValueType(), Subtarget) || + !isValidEGW(4, Op->getOperand(2).getSimpleValueType(), Subtarget)) + report_fatal_error("EGW should be greater than or equal to 4 * SEW."); + return Op; + } + // EGS * EEW >= 256 bits + case Intrinsic::riscv_vsm3c: + case Intrinsic::riscv_vsm3me: { + if (!isValidEGW(8, Op.getSimpleValueType(), Subtarget) || + !isValidEGW(8, Op->getOperand(1).getSimpleValueType(), Subtarget)) + report_fatal_error("EGW should be greater than or equal to 8 * SEW."); + return Op; + } + // zvknha(SEW=32)/zvknhb(SEW=[32|64]) + case Intrinsic::riscv_vsha2ch: + case Intrinsic::riscv_vsha2cl: + case Intrinsic::riscv_vsha2ms: { + if (Op->getSimpleValueType(0).getScalarSizeInBits() == 64 && + !Subtarget.hasStdExtZvknhb()) + report_fatal_error("SEW=64 needs Zvknhb to be enabled."); + if (!isValidEGW(4, Op.getSimpleValueType(), Subtarget) || + !isValidEGW(4, Op->getOperand(1).getSimpleValueType(), Subtarget) || + !isValidEGW(4, Op->getOperand(2).getSimpleValueType(), Subtarget)) + report_fatal_error("EGW should be greater than or equal to 4 * SEW."); + return Op; + } + case Intrinsic::riscv_sf_vc_v_x: + case Intrinsic::riscv_sf_vc_v_i: + case Intrinsic::riscv_sf_vc_v_xv: + case Intrinsic::riscv_sf_vc_v_iv: + case Intrinsic::riscv_sf_vc_v_vv: + case Intrinsic::riscv_sf_vc_v_fv: + case Intrinsic::riscv_sf_vc_v_xvv: + case Intrinsic::riscv_sf_vc_v_ivv: + case Intrinsic::riscv_sf_vc_v_vvv: + case Intrinsic::riscv_sf_vc_v_fvv: + case Intrinsic::riscv_sf_vc_v_xvw: + case Intrinsic::riscv_sf_vc_v_ivw: + case Intrinsic::riscv_sf_vc_v_vvw: + case Intrinsic::riscv_sf_vc_v_fvw: { + MVT VT = Op.getSimpleValueType(); + + SmallVector<SDValue> Ops; + getVCIXOperands(Op, DAG, Ops); + + MVT RetVT = VT; + if (VT.isFixedLengthVector()) + RetVT = getContainerForFixedLengthVector(VT); + else if (VT.isFloatingPoint()) + RetVT = MVT::getVectorVT(MVT::getIntegerVT(VT.getScalarSizeInBits()), + VT.getVectorElementCount()); + + SDValue NewNode = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, RetVT, Ops); + + if (VT.isFixedLengthVector()) + NewNode = convertFromScalableVector(VT, NewNode, DAG, Subtarget); + else if (VT.isFloatingPoint()) + NewNode = DAG.getBitcast(VT, NewNode); + + if (Op == NewNode) + break; + + return NewNode; + } } return lowerVectorIntrinsicScalars(Op, DAG, Subtarget); @@ -7443,6 +8662,49 @@ SDValue RISCVTargetLowering::LowerINTRINSIC_W_CHAIN(SDValue Op, Results.push_back(Result.getValue(NF)); return DAG.getMergeValues(Results, DL); } + case Intrinsic::riscv_sf_vc_v_x_se: + case Intrinsic::riscv_sf_vc_v_i_se: + case Intrinsic::riscv_sf_vc_v_xv_se: + case Intrinsic::riscv_sf_vc_v_iv_se: + case Intrinsic::riscv_sf_vc_v_vv_se: + case Intrinsic::riscv_sf_vc_v_fv_se: + case Intrinsic::riscv_sf_vc_v_xvv_se: + case Intrinsic::riscv_sf_vc_v_ivv_se: + case Intrinsic::riscv_sf_vc_v_vvv_se: + case Intrinsic::riscv_sf_vc_v_fvv_se: + case Intrinsic::riscv_sf_vc_v_xvw_se: + case Intrinsic::riscv_sf_vc_v_ivw_se: + case Intrinsic::riscv_sf_vc_v_vvw_se: + case Intrinsic::riscv_sf_vc_v_fvw_se: { + MVT VT = Op.getSimpleValueType(); + SDLoc DL(Op); + SmallVector<SDValue> Ops; + getVCIXOperands(Op, DAG, Ops); + + MVT RetVT = VT; + if (VT.isFixedLengthVector()) + RetVT = getContainerForFixedLengthVector(VT); + else if (VT.isFloatingPoint()) + RetVT = MVT::getVectorVT(MVT::getIntegerVT(RetVT.getScalarSizeInBits()), + RetVT.getVectorElementCount()); + + SDVTList VTs = DAG.getVTList({RetVT, MVT::Other}); + SDValue NewNode = 
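The report_fatal_error checks above enforce the vector-crypto element-group rule quoted in the comment: the register group must be able to hold one whole element group, i.e. LMUL * VLEN >= EGS * SEW. A sketch of the check, simplified to take LMUL directly (the real isValidEGW derives it from the type's minimum size over RVVBitsPerBlock):

#include <cassert>

// EGW = EGS * SEW must fit in the LMUL * VLEN register group.
bool egwFits(unsigned EGS, unsigned MinVLenBits, unsigned LMUL,
             unsigned SEWBits) {
  return LMUL * MinVLenBits >= EGS * SEWBits;
}

int main() {
  // Zvl128b: a 4-element group of 32-bit elements (EGW = 128) fits at
  // LMUL 1, but a 4-element group of 64-bit elements (EGW = 256) does not.
  assert(egwFits(4, 128, 1, 32));
  assert(!egwFits(4, 128, 1, 64));
  assert(egwFits(4, 128, 2, 64));
  return 0;
}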
DAG.getNode(ISD::INTRINSIC_W_CHAIN, DL, VTs, Ops); + + if (VT.isFixedLengthVector()) { + SDValue FixedVector = + convertFromScalableVector(VT, NewNode, DAG, Subtarget); + NewNode = DAG.getMergeValues({FixedVector, NewNode.getValue(1)}, DL); + } else if (VT.isFloatingPoint()) { + SDValue BitCast = DAG.getBitcast(VT, NewNode.getValue(0)); + NewNode = DAG.getMergeValues({BitCast, NewNode.getValue(1)}, DL); + } + + if (Op == NewNode) + break; + + return NewNode; + } } return lowerVectorIntrinsicScalars(Op, DAG, Subtarget); @@ -7530,6 +8792,73 @@ SDValue RISCVTargetLowering::LowerINTRINSIC_VOID(SDValue Op, ISD::INTRINSIC_VOID, DL, DAG.getVTList(MVT::Other), Ops, FixedIntrinsic->getMemoryVT(), FixedIntrinsic->getMemOperand()); } + case Intrinsic::riscv_sf_vc_x_se_e8mf8: + case Intrinsic::riscv_sf_vc_x_se_e8mf4: + case Intrinsic::riscv_sf_vc_x_se_e8mf2: + case Intrinsic::riscv_sf_vc_x_se_e8m1: + case Intrinsic::riscv_sf_vc_x_se_e8m2: + case Intrinsic::riscv_sf_vc_x_se_e8m4: + case Intrinsic::riscv_sf_vc_x_se_e8m8: + case Intrinsic::riscv_sf_vc_x_se_e16mf4: + case Intrinsic::riscv_sf_vc_x_se_e16mf2: + case Intrinsic::riscv_sf_vc_x_se_e16m1: + case Intrinsic::riscv_sf_vc_x_se_e16m2: + case Intrinsic::riscv_sf_vc_x_se_e16m4: + case Intrinsic::riscv_sf_vc_x_se_e16m8: + case Intrinsic::riscv_sf_vc_x_se_e32mf2: + case Intrinsic::riscv_sf_vc_x_se_e32m1: + case Intrinsic::riscv_sf_vc_x_se_e32m2: + case Intrinsic::riscv_sf_vc_x_se_e32m4: + case Intrinsic::riscv_sf_vc_x_se_e32m8: + case Intrinsic::riscv_sf_vc_x_se_e64m1: + case Intrinsic::riscv_sf_vc_x_se_e64m2: + case Intrinsic::riscv_sf_vc_x_se_e64m4: + case Intrinsic::riscv_sf_vc_x_se_e64m8: + case Intrinsic::riscv_sf_vc_i_se_e8mf8: + case Intrinsic::riscv_sf_vc_i_se_e8mf4: + case Intrinsic::riscv_sf_vc_i_se_e8mf2: + case Intrinsic::riscv_sf_vc_i_se_e8m1: + case Intrinsic::riscv_sf_vc_i_se_e8m2: + case Intrinsic::riscv_sf_vc_i_se_e8m4: + case Intrinsic::riscv_sf_vc_i_se_e8m8: + case Intrinsic::riscv_sf_vc_i_se_e16mf4: + case Intrinsic::riscv_sf_vc_i_se_e16mf2: + case Intrinsic::riscv_sf_vc_i_se_e16m1: + case Intrinsic::riscv_sf_vc_i_se_e16m2: + case Intrinsic::riscv_sf_vc_i_se_e16m4: + case Intrinsic::riscv_sf_vc_i_se_e16m8: + case Intrinsic::riscv_sf_vc_i_se_e32mf2: + case Intrinsic::riscv_sf_vc_i_se_e32m1: + case Intrinsic::riscv_sf_vc_i_se_e32m2: + case Intrinsic::riscv_sf_vc_i_se_e32m4: + case Intrinsic::riscv_sf_vc_i_se_e32m8: + case Intrinsic::riscv_sf_vc_i_se_e64m1: + case Intrinsic::riscv_sf_vc_i_se_e64m2: + case Intrinsic::riscv_sf_vc_i_se_e64m4: + case Intrinsic::riscv_sf_vc_i_se_e64m8: + case Intrinsic::riscv_sf_vc_xv_se: + case Intrinsic::riscv_sf_vc_iv_se: + case Intrinsic::riscv_sf_vc_vv_se: + case Intrinsic::riscv_sf_vc_fv_se: + case Intrinsic::riscv_sf_vc_xvv_se: + case Intrinsic::riscv_sf_vc_ivv_se: + case Intrinsic::riscv_sf_vc_vvv_se: + case Intrinsic::riscv_sf_vc_fvv_se: + case Intrinsic::riscv_sf_vc_xvw_se: + case Intrinsic::riscv_sf_vc_ivw_se: + case Intrinsic::riscv_sf_vc_vvw_se: + case Intrinsic::riscv_sf_vc_fvw_se: { + SmallVector<SDValue> Ops; + getVCIXOperands(Op, DAG, Ops); + + SDValue NewNode = + DAG.getNode(ISD::INTRINSIC_VOID, SDLoc(Op), Op->getVTList(), Ops); + + if (Op == NewNode) + break; + + return NewNode; + } } return lowerVectorIntrinsicScalars(Op, DAG, Subtarget); @@ -7539,23 +8868,40 @@ static unsigned getRVVReductionOp(unsigned ISDOpcode) { switch (ISDOpcode) { default: llvm_unreachable("Unhandled reduction"); + case ISD::VP_REDUCE_ADD: case ISD::VECREDUCE_ADD: return RISCVISD::VECREDUCE_ADD_VL; + case 
ISD::VP_REDUCE_UMAX: case ISD::VECREDUCE_UMAX: return RISCVISD::VECREDUCE_UMAX_VL; + case ISD::VP_REDUCE_SMAX: case ISD::VECREDUCE_SMAX: return RISCVISD::VECREDUCE_SMAX_VL; + case ISD::VP_REDUCE_UMIN: case ISD::VECREDUCE_UMIN: return RISCVISD::VECREDUCE_UMIN_VL; + case ISD::VP_REDUCE_SMIN: case ISD::VECREDUCE_SMIN: return RISCVISD::VECREDUCE_SMIN_VL; + case ISD::VP_REDUCE_AND: case ISD::VECREDUCE_AND: return RISCVISD::VECREDUCE_AND_VL; + case ISD::VP_REDUCE_OR: case ISD::VECREDUCE_OR: return RISCVISD::VECREDUCE_OR_VL; + case ISD::VP_REDUCE_XOR: case ISD::VECREDUCE_XOR: return RISCVISD::VECREDUCE_XOR_VL; + case ISD::VP_REDUCE_FADD: + return RISCVISD::VECREDUCE_FADD_VL; + case ISD::VP_REDUCE_SEQ_FADD: + return RISCVISD::VECREDUCE_SEQ_FADD_VL; + case ISD::VP_REDUCE_FMAX: + return RISCVISD::VECREDUCE_FMAX_VL; + case ISD::VP_REDUCE_FMIN: + return RISCVISD::VECREDUCE_FMIN_VL; } + } SDValue RISCVTargetLowering::lowerVectorMaskVecReduction(SDValue Op, @@ -7573,8 +8919,6 @@ SDValue RISCVTargetLowering::lowerVectorMaskVecReduction(SDValue Op, "Unexpected reduction lowering"); MVT XLenVT = Subtarget.getXLenVT(); - assert(Op.getValueType() == XLenVT && - "Expected reduction output to be legalized to XLenVT"); MVT ContainerVT = VecVT; if (VecVT.isFixedLengthVector()) { @@ -7628,6 +8972,7 @@ SDValue RISCVTargetLowering::lowerVectorMaskVecReduction(SDValue Op, } SDValue SetCC = DAG.getSetCC(DL, XLenVT, Vec, Zero, CC); + SetCC = DAG.getNode(ISD::TRUNCATE, DL, Op.getValueType(), SetCC); if (!IsVP) return SetCC; @@ -7638,7 +8983,7 @@ SDValue RISCVTargetLowering::lowerVectorMaskVecReduction(SDValue Op, // 0 for an inactive vector, and so we've already received the neutral value: // AND gives us (0 == 0) -> 1 and OR/XOR give us (0 != 0) -> 0. Therefore we // can simply include the start value. - return DAG.getNode(BaseOpc, DL, XLenVT, SetCC, Op.getOperand(0)); + return DAG.getNode(BaseOpc, DL, Op.getValueType(), SetCC, Op.getOperand(0)); } static bool isNonZeroAVL(SDValue AVL) { @@ -7714,9 +9059,19 @@ SDValue RISCVTargetLowering::lowerVECREDUCE(SDValue Op, auto [Mask, VL] = getDefaultVLOps(VecVT, ContainerVT, DL, DAG, Subtarget); - SDValue NeutralElem = - DAG.getNeutralElement(BaseOpc, DL, VecEltVT, SDNodeFlags()); - return lowerReductionSeq(RVVOpcode, Op.getSimpleValueType(), NeutralElem, Vec, + SDValue StartV = DAG.getNeutralElement(BaseOpc, DL, VecEltVT, SDNodeFlags()); + switch (BaseOpc) { + case ISD::AND: + case ISD::OR: + case ISD::UMAX: + case ISD::UMIN: + case ISD::SMAX: + case ISD::SMIN: + MVT XLenVT = Subtarget.getXLenVT(); + StartV = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VecEltVT, Vec, + DAG.getConstant(0, DL, XLenVT)); + } + return lowerReductionSeq(RVVOpcode, Op.getSimpleValueType(), StartV, Vec, Mask, VL, DL, DAG, Subtarget); } @@ -7724,11 +9079,11 @@ SDValue RISCVTargetLowering::lowerVECREDUCE(SDValue Op, // the vector SDValue and the scalar SDValue required to lower this to a // RISCVISD node. 
static std::tuple<unsigned, SDValue, SDValue> -getRVVFPReductionOpAndOperands(SDValue Op, SelectionDAG &DAG, EVT EltVT) { +getRVVFPReductionOpAndOperands(SDValue Op, SelectionDAG &DAG, EVT EltVT, + const RISCVSubtarget &Subtarget) { SDLoc DL(Op); auto Flags = Op->getFlags(); unsigned Opcode = Op.getOpcode(); - unsigned BaseOpcode = ISD::getVecReduceBaseOpcode(Opcode); switch (Opcode) { default: llvm_unreachable("Unhandled reduction"); @@ -7742,11 +9097,16 @@ getRVVFPReductionOpAndOperands(SDValue Op, SelectionDAG &DAG, EVT EltVT) { return std::make_tuple(RISCVISD::VECREDUCE_SEQ_FADD_VL, Op.getOperand(1), Op.getOperand(0)); case ISD::VECREDUCE_FMIN: - return std::make_tuple(RISCVISD::VECREDUCE_FMIN_VL, Op.getOperand(0), - DAG.getNeutralElement(BaseOpcode, DL, EltVT, Flags)); - case ISD::VECREDUCE_FMAX: - return std::make_tuple(RISCVISD::VECREDUCE_FMAX_VL, Op.getOperand(0), - DAG.getNeutralElement(BaseOpcode, DL, EltVT, Flags)); + case ISD::VECREDUCE_FMAX: { + MVT XLenVT = Subtarget.getXLenVT(); + SDValue Front = + DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Op.getOperand(0), + DAG.getConstant(0, DL, XLenVT)); + unsigned RVVOpc = (Opcode == ISD::VECREDUCE_FMIN) + ? RISCVISD::VECREDUCE_FMIN_VL + : RISCVISD::VECREDUCE_FMAX_VL; + return std::make_tuple(RVVOpc, Op.getOperand(0), Front); + } } } @@ -7758,7 +9118,7 @@ SDValue RISCVTargetLowering::lowerFPVECREDUCE(SDValue Op, unsigned RVVOpcode; SDValue VectorVal, ScalarVal; std::tie(RVVOpcode, VectorVal, ScalarVal) = - getRVVFPReductionOpAndOperands(Op, DAG, VecEltVT); + getRVVFPReductionOpAndOperands(Op, DAG, VecEltVT, Subtarget); MVT VecVT = VectorVal.getSimpleValueType(); MVT ContainerVT = VecVT; @@ -7772,37 +9132,6 @@ SDValue RISCVTargetLowering::lowerFPVECREDUCE(SDValue Op, VectorVal, Mask, VL, DL, DAG, Subtarget); } -static unsigned getRVVVPReductionOp(unsigned ISDOpcode) { - switch (ISDOpcode) { - default: - llvm_unreachable("Unhandled reduction"); - case ISD::VP_REDUCE_ADD: - return RISCVISD::VECREDUCE_ADD_VL; - case ISD::VP_REDUCE_UMAX: - return RISCVISD::VECREDUCE_UMAX_VL; - case ISD::VP_REDUCE_SMAX: - return RISCVISD::VECREDUCE_SMAX_VL; - case ISD::VP_REDUCE_UMIN: - return RISCVISD::VECREDUCE_UMIN_VL; - case ISD::VP_REDUCE_SMIN: - return RISCVISD::VECREDUCE_SMIN_VL; - case ISD::VP_REDUCE_AND: - return RISCVISD::VECREDUCE_AND_VL; - case ISD::VP_REDUCE_OR: - return RISCVISD::VECREDUCE_OR_VL; - case ISD::VP_REDUCE_XOR: - return RISCVISD::VECREDUCE_XOR_VL; - case ISD::VP_REDUCE_FADD: - return RISCVISD::VECREDUCE_FADD_VL; - case ISD::VP_REDUCE_SEQ_FADD: - return RISCVISD::VECREDUCE_SEQ_FADD_VL; - case ISD::VP_REDUCE_FMAX: - return RISCVISD::VECREDUCE_FMAX_VL; - case ISD::VP_REDUCE_FMIN: - return RISCVISD::VECREDUCE_FMIN_VL; - } -} - SDValue RISCVTargetLowering::lowerVPREDUCE(SDValue Op, SelectionDAG &DAG) const { SDLoc DL(Op); @@ -7815,7 +9144,7 @@ SDValue RISCVTargetLowering::lowerVPREDUCE(SDValue Op, return SDValue(); MVT VecVT = VecEVT.getSimpleVT(); - unsigned RVVOpcode = getRVVVPReductionOp(Op.getOpcode()); + unsigned RVVOpcode = getRVVReductionOp(Op.getOpcode()); if (VecVT.isFixedLengthVector()) { auto ContainerVT = getContainerForFixedLengthVector(VecVT); @@ -7890,13 +9219,18 @@ SDValue RISCVTargetLowering::lowerINSERT_SUBVECTOR(SDValue Op, ContainerVT = getContainerForFixedLengthVector(VecVT); Vec = convertToScalableVector(ContainerVT, Vec, DAG, Subtarget); } - SubVec = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, ContainerVT, - DAG.getUNDEF(ContainerVT), SubVec, - DAG.getConstant(0, DL, XLenVT)); + if (OrigIdx == 0 && Vec.isUndef() && 
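Both reduction changes above seed the VL reduction with element 0 of the source vector instead of a neutral element. That is sound for AND/OR/MIN/MAX (and FMIN/FMAX), because op(x, x) == x, so folding the first element in twice cannot change the result, and it avoids materializing constants such as UINT32_MAX. A scalar check for the unsigned-min case:

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// min(x, x) == x, so seeding the reduction with the vector's own first
// element gives the same result as seeding it with the neutral element.
static uint32_t reduceUMin(const std::vector<uint32_t> &V, uint32_t Start) {
  uint32_t Acc = Start;
  for (uint32_t X : V)
    Acc = std::min(Acc, X);
  return Acc;
}

int main() {
  std::vector<uint32_t> V = {7, 3, 9, 3, 8};
  assert(reduceUMin(V, UINT32_MAX) == reduceUMin(V, V.front()));
  return 0;
}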
VecVT.isFixedLengthVector()) { + SubVec = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, ContainerVT, + DAG.getUNDEF(ContainerVT), SubVec, + DAG.getConstant(0, DL, XLenVT)); SubVec = convertFromScalableVector(VecVT, SubVec, DAG, Subtarget); return DAG.getBitcast(Op.getValueType(), SubVec); } + + SubVec = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, ContainerVT, + DAG.getUNDEF(ContainerVT), SubVec, + DAG.getConstant(0, DL, XLenVT)); SDValue Mask = getDefaultVLOps(VecVT, ContainerVT, DL, DAG, Subtarget).first; // Set the vector length to only the number of elements we care about. Note @@ -8049,21 +9383,32 @@ SDValue RISCVTargetLowering::lowerEXTRACT_SUBVECTOR(SDValue Op, } } + // With an index of 0 this is a cast-like subvector, which can be performed + // with subregister operations. + if (OrigIdx == 0) + return Op; + // If the subvector vector is a fixed-length type, we cannot use subregister // manipulation to simplify the codegen; we don't know which register of a // LMUL group contains the specific subvector as we only know the minimum // register size. Therefore we must slide the vector group down the full // amount. if (SubVecVT.isFixedLengthVector()) { - // With an index of 0 this is a cast-like subvector, which can be performed - // with subregister operations. - if (OrigIdx == 0) - return Op; MVT ContainerVT = VecVT; if (VecVT.isFixedLengthVector()) { ContainerVT = getContainerForFixedLengthVector(VecVT); Vec = convertToScalableVector(ContainerVT, Vec, DAG, Subtarget); } + + // Shrink down Vec so we're performing the slidedown on a smaller LMUL. + unsigned LastIdx = OrigIdx + SubVecVT.getVectorNumElements() - 1; + if (auto ShrunkVT = + getSmallestVTForIndex(ContainerVT, LastIdx, DL, DAG, Subtarget)) { + ContainerVT = *ShrunkVT; + Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ContainerVT, Vec, + DAG.getVectorIdxConstant(0, DL)); + } + SDValue Mask = getDefaultVLOps(VecVT, ContainerVT, DL, DAG, Subtarget).first; // Set the vector length to only the number of elements we care about. This @@ -8090,17 +9435,18 @@ SDValue RISCVTargetLowering::lowerEXTRACT_SUBVECTOR(SDValue Op, if (RemIdx == 0) return Op; - // Else we must shift our vector register directly to extract the subvector. - // Do this using VSLIDEDOWN. + // Else SubVecVT is a fractional LMUL and may need to be slid down. + assert(RISCVVType::decodeVLMUL(getLMUL(SubVecVT)).second); // If the vector type is an LMUL-group type, extract a subvector equal to the - // nearest full vector register type. This should resolve to a EXTRACT_SUBREG - // instruction. + // nearest full vector register type. MVT InterSubVT = VecVT; if (VecVT.bitsGT(getLMUL1VT(VecVT))) { + // If VecVT has an LMUL > 1, then SubVecVT should have a smaller LMUL, and + // we should have successfully decomposed the extract into a subregister. 
+ assert(SubRegIdx != RISCV::NoSubRegister); InterSubVT = getLMUL1VT(VecVT); - Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, InterSubVT, Vec, - DAG.getConstant(OrigIdx - RemIdx, DL, XLenVT)); + Vec = DAG.getTargetExtractSubreg(SubRegIdx, DL, InterSubVT, Vec); } // Slide this vector register down by the desired number of elements in order @@ -8198,7 +9544,7 @@ SDValue RISCVTargetLowering::lowerVECTOR_DEINTERLEAVE(SDValue Op, // We can deinterleave through vnsrl.wi if the element type is smaller than // ELEN - if (VecVT.getScalarSizeInBits() < Subtarget.getELEN()) { + if (VecVT.getScalarSizeInBits() < Subtarget.getELen()) { SDValue Even = getDeinterleaveViaVNSRL(DL, VecVT, Concat, true, Subtarget, DAG); SDValue Odd = @@ -8267,7 +9613,7 @@ SDValue RISCVTargetLowering::lowerVECTOR_INTERLEAVE(SDValue Op, // If the element type is smaller than ELEN, then we can interleave with // vwaddu.vv and vwmaccu.vx - if (VecVT.getScalarSizeInBits() < Subtarget.getELEN()) { + if (VecVT.getScalarSizeInBits() < Subtarget.getELen()) { Interleaved = getWideningInterleave(Op.getOperand(0), Op.getOperand(1), DL, DAG, Subtarget); } else { @@ -8900,9 +10246,10 @@ SDValue RISCVTargetLowering::lowerToScalableOp(SDValue Op, // * The EVL operand is promoted from i32 to i64 on RV64. // * Fixed-length vectors are converted to their scalable-vector container // types. -SDValue RISCVTargetLowering::lowerVPOp(SDValue Op, SelectionDAG &DAG, - unsigned RISCVISDOpc, - bool HasMergeOp) const { +SDValue RISCVTargetLowering::lowerVPOp(SDValue Op, SelectionDAG &DAG) const { + unsigned RISCVISDOpc = getRISCVVLOp(Op); + bool HasMergeOp = hasMergeOp(RISCVISDOpc); + SDLoc DL(Op); MVT VT = Op.getSimpleValueType(); SmallVector<SDValue, 4> Ops; @@ -9051,13 +10398,14 @@ SDValue RISCVTargetLowering::lowerVPSetCCMaskOp(SDValue Op, } // Lower Floating-Point/Integer Type-Convert VP SDNodes -SDValue RISCVTargetLowering::lowerVPFPIntConvOp(SDValue Op, SelectionDAG &DAG, - unsigned RISCVISDOpc) const { +SDValue RISCVTargetLowering::lowerVPFPIntConvOp(SDValue Op, + SelectionDAG &DAG) const { SDLoc DL(Op); SDValue Src = Op.getOperand(0); SDValue Mask = Op.getOperand(1); SDValue VL = Op.getOperand(2); + unsigned RISCVISDOpc = getRISCVVLOp(Op); MVT DstVT = Op.getSimpleValueType(); MVT SrcVT = Src.getSimpleValueType(); @@ -9183,12 +10531,132 @@ SDValue RISCVTargetLowering::lowerVPFPIntConvOp(SDValue Op, SelectionDAG &DAG, return convertFromScalableVector(VT, Result, DAG, Subtarget); } -SDValue RISCVTargetLowering::lowerLogicVPOp(SDValue Op, SelectionDAG &DAG, - unsigned MaskOpc, - unsigned VecOpc) const { +SDValue +RISCVTargetLowering::lowerVPReverseExperimental(SDValue Op, + SelectionDAG &DAG) const { + SDLoc DL(Op); + MVT VT = Op.getSimpleValueType(); + MVT XLenVT = Subtarget.getXLenVT(); + + SDValue Op1 = Op.getOperand(0); + SDValue Mask = Op.getOperand(1); + SDValue EVL = Op.getOperand(2); + + MVT ContainerVT = VT; + if (VT.isFixedLengthVector()) { + ContainerVT = getContainerForFixedLengthVector(VT); + Op1 = convertToScalableVector(ContainerVT, Op1, DAG, Subtarget); + MVT MaskVT = getMaskTypeFor(ContainerVT); + Mask = convertToScalableVector(MaskVT, Mask, DAG, Subtarget); + } + + MVT GatherVT = ContainerVT; + MVT IndicesVT = ContainerVT.changeVectorElementTypeToInteger(); + // Check if we are working with mask vectors + bool IsMaskVector = ContainerVT.getVectorElementType() == MVT::i1; + if (IsMaskVector) { + GatherVT = IndicesVT = ContainerVT.changeVectorElementType(MVT::i8); + + // Expand input operand + SDValue SplatOne = 
DAG.getNode(RISCVISD::VMV_V_X_VL, DL, IndicesVT, + DAG.getUNDEF(IndicesVT), + DAG.getConstant(1, DL, XLenVT), EVL); + SDValue SplatZero = DAG.getNode(RISCVISD::VMV_V_X_VL, DL, IndicesVT, + DAG.getUNDEF(IndicesVT), + DAG.getConstant(0, DL, XLenVT), EVL); + Op1 = DAG.getNode(RISCVISD::VSELECT_VL, DL, IndicesVT, Op1, SplatOne, + SplatZero, EVL); + } + + unsigned EltSize = GatherVT.getScalarSizeInBits(); + unsigned MinSize = GatherVT.getSizeInBits().getKnownMinValue(); + unsigned VectorBitsMax = Subtarget.getRealMaxVLen(); + unsigned MaxVLMAX = + RISCVTargetLowering::computeVLMAX(VectorBitsMax, EltSize, MinSize); + + unsigned GatherOpc = RISCVISD::VRGATHER_VV_VL; + // If this is SEW=8 and VLMAX is unknown or more than 256, we need + // to use vrgatherei16.vv. + // TODO: It's also possible to use vrgatherei16.vv for other types to + // decrease register width for the index calculation. + // NOTE: This code assumes VLMAX <= 65536 for LMUL=8 SEW=16. + if (MaxVLMAX > 256 && EltSize == 8) { + // If this is LMUL=8, we have to split before using vrgatherei16.vv. + // Split the vector in half and reverse each half using a full register + // reverse. + // Swap the halves and concatenate them. + // Slide the concatenated result by (VLMax - VL). + if (MinSize == (8 * RISCV::RVVBitsPerBlock)) { + auto [LoVT, HiVT] = DAG.GetSplitDestVTs(GatherVT); + auto [Lo, Hi] = DAG.SplitVector(Op1, DL); + + SDValue LoRev = DAG.getNode(ISD::VECTOR_REVERSE, DL, LoVT, Lo); + SDValue HiRev = DAG.getNode(ISD::VECTOR_REVERSE, DL, HiVT, Hi); + + // Reassemble the low and high pieces reversed. + // NOTE: this Result is unmasked (because we do not need masks for + // shuffles). If in the future this has to change, we can use a SELECT_VL + // between Result and UNDEF using the mask originally passed to VP_REVERSE + SDValue Result = + DAG.getNode(ISD::CONCAT_VECTORS, DL, GatherVT, HiRev, LoRev); + + // Slide off any elements from past EVL that were reversed into the low + // elements. + unsigned MinElts = GatherVT.getVectorMinNumElements(); + SDValue VLMax = DAG.getNode(ISD::VSCALE, DL, XLenVT, + DAG.getConstant(MinElts, DL, XLenVT)); + SDValue Diff = DAG.getNode(ISD::SUB, DL, XLenVT, VLMax, EVL); + + Result = getVSlidedown(DAG, Subtarget, DL, GatherVT, + DAG.getUNDEF(GatherVT), Result, Diff, Mask, EVL); + + if (IsMaskVector) { + // Truncate Result back to a mask vector + Result = + DAG.getNode(RISCVISD::SETCC_VL, DL, ContainerVT, + {Result, DAG.getConstant(0, DL, GatherVT), + DAG.getCondCode(ISD::SETNE), + DAG.getUNDEF(getMaskTypeFor(ContainerVT)), Mask, EVL}); + } + + if (!VT.isFixedLengthVector()) + return Result; + return convertFromScalableVector(VT, Result, DAG, Subtarget); + } + + // Just promote the int type to i16 which will double the LMUL. 
+ IndicesVT = MVT::getVectorVT(MVT::i16, IndicesVT.getVectorElementCount()); + GatherOpc = RISCVISD::VRGATHEREI16_VV_VL; + } + + SDValue VID = DAG.getNode(RISCVISD::VID_VL, DL, IndicesVT, Mask, EVL); + SDValue VecLen = + DAG.getNode(ISD::SUB, DL, XLenVT, EVL, DAG.getConstant(1, DL, XLenVT)); + SDValue VecLenSplat = DAG.getNode(RISCVISD::VMV_V_X_VL, DL, IndicesVT, + DAG.getUNDEF(IndicesVT), VecLen, EVL); + SDValue VRSUB = DAG.getNode(RISCVISD::SUB_VL, DL, IndicesVT, VecLenSplat, VID, + DAG.getUNDEF(IndicesVT), Mask, EVL); + SDValue Result = DAG.getNode(GatherOpc, DL, GatherVT, Op1, VRSUB, + DAG.getUNDEF(GatherVT), Mask, EVL); + + if (IsMaskVector) { + // Truncate Result back to a mask vector + Result = DAG.getNode( + RISCVISD::SETCC_VL, DL, ContainerVT, + {Result, DAG.getConstant(0, DL, GatherVT), DAG.getCondCode(ISD::SETNE), + DAG.getUNDEF(getMaskTypeFor(ContainerVT)), Mask, EVL}); + } + + if (!VT.isFixedLengthVector()) + return Result; + return convertFromScalableVector(VT, Result, DAG, Subtarget); +} + +SDValue RISCVTargetLowering::lowerLogicVPOp(SDValue Op, + SelectionDAG &DAG) const { MVT VT = Op.getSimpleValueType(); if (VT.getVectorElementType() != MVT::i1) - return lowerVPOp(Op, DAG, VecOpc, true); + return lowerVPOp(Op, DAG); // It is safe to drop mask parameter as masked-off elements are undef. SDValue Op1 = Op->getOperand(0); @@ -9204,7 +10672,7 @@ SDValue RISCVTargetLowering::lowerLogicVPOp(SDValue Op, SelectionDAG &DAG, } SDLoc DL(Op); - SDValue Val = DAG.getNode(MaskOpc, DL, ContainerVT, Op1, Op2, VL); + SDValue Val = DAG.getNode(getRISCVVLOp(Op), DL, ContainerVT, Op1, Op2, VL); if (!IsFixed) return Val; return convertFromScalableVector(VT, Val, DAG, Subtarget); @@ -9364,10 +10832,7 @@ SDValue RISCVTargetLowering::lowerMaskedGather(SDValue Op, if (XLenVT == MVT::i32 && IndexVT.getVectorElementType().bitsGT(XLenVT)) { IndexVT = IndexVT.changeVectorElementType(XLenVT); - SDValue TrueMask = DAG.getNode(RISCVISD::VMSET_VL, DL, Mask.getValueType(), - VL); - Index = DAG.getNode(RISCVISD::TRUNCATE_VECTOR_VL, DL, IndexVT, Index, - TrueMask, VL); + Index = DAG.getNode(ISD::TRUNCATE, DL, IndexVT, Index); } unsigned IntID = @@ -9466,10 +10931,7 @@ SDValue RISCVTargetLowering::lowerMaskedScatter(SDValue Op, if (XLenVT == MVT::i32 && IndexVT.getVectorElementType().bitsGT(XLenVT)) { IndexVT = IndexVT.changeVectorElementType(XLenVT); - SDValue TrueMask = DAG.getNode(RISCVISD::VMSET_VL, DL, Mask.getValueType(), - VL); - Index = DAG.getNode(RISCVISD::TRUNCATE_VECTOR_VL, DL, IndexVT, Index, - TrueMask, VL); + Index = DAG.getNode(ISD::TRUNCATE, DL, IndexVT, Index); } unsigned IntID = @@ -9537,6 +10999,8 @@ SDValue RISCVTargetLowering::lowerSET_ROUNDING(SDValue Op, (RISCVFPRndMode::RUP << 4 * int(RoundingMode::TowardPositive)) | (RISCVFPRndMode::RMM << 4 * int(RoundingMode::NearestTiesToAway)); + RMValue = DAG.getNode(ISD::ZERO_EXTEND, DL, XLenVT, RMValue); + SDValue Shift = DAG.getNode(ISD::SHL, DL, XLenVT, RMValue, DAG.getConstant(2, DL, XLenVT)); SDValue Shifted = DAG.getNode(ISD::SRL, DL, XLenVT, @@ -9651,8 +11115,11 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N, Results.push_back(Res.getValue(1)); return; } - // In absense of Zfh, promote f16 to f32, then convert. - if (Op0.getValueType() == MVT::f16 && !Subtarget.hasStdExtZfhOrZhinx()) + // For bf16, or f16 in absense of Zfh, promote [b]f16 to f32 and then + // convert. 
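The general vp.reverse path above builds its gather indices as (EVL - 1) - vid and feeds them to vrgather; mask vectors are first widened to i8 and turned back into a mask with a compare against zero, and the LMUL-8 SEW-8 case is split instead. A scalar model of the index computation; leaving the tail untouched is a simplifying assumption of this sketch, the real lowering leaves the tail to the merge policy:

#include <cassert>
#include <cstdint>
#include <vector>

// Scalar model of the vid / vrsub / vrgather sequence: reverse the first
// EVL elements by gathering with index (EVL - 1) - i.
std::vector<uint32_t> vpReverse(const std::vector<uint32_t> &Src,
                                unsigned EVL) {
  std::vector<uint32_t> Dst(Src);
  for (unsigned I = 0; I < EVL; ++I)   // vid.v produces I
    Dst[I] = Src[(EVL - 1) - I];       // vrsub.vx then vrgather.vv
  return Dst;
}

int main() {
  std::vector<uint32_t> V = {1, 2, 3, 4, 5, 6, 7, 8};
  std::vector<uint32_t> R = vpReverse(V, 5);
  // Only the first EVL = 5 lanes participate in the reverse.
  std::vector<uint32_t> Expect = {5, 4, 3, 2, 1, 6, 7, 8};
  assert(R == Expect);
  return 0;
}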
+ if ((Op0.getValueType() == MVT::f16 && + !Subtarget.hasStdExtZfhOrZhinx()) || + Op0.getValueType() == MVT::bf16) Op0 = DAG.getNode(ISD::FP_EXTEND, DL, MVT::f32, Op0); unsigned Opc = IsSigned ? RISCVISD::FCVT_W_RV64 : RISCVISD::FCVT_WU_RV64; @@ -10279,6 +11746,136 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N, } } +/// Given a binary operator, return the *associative* generic ISD::VECREDUCE_OP +/// which corresponds to it. +static unsigned getVecReduceOpcode(unsigned Opc) { + switch (Opc) { + default: + llvm_unreachable("Unhandled binary to transfrom reduction"); + case ISD::ADD: + return ISD::VECREDUCE_ADD; + case ISD::UMAX: + return ISD::VECREDUCE_UMAX; + case ISD::SMAX: + return ISD::VECREDUCE_SMAX; + case ISD::UMIN: + return ISD::VECREDUCE_UMIN; + case ISD::SMIN: + return ISD::VECREDUCE_SMIN; + case ISD::AND: + return ISD::VECREDUCE_AND; + case ISD::OR: + return ISD::VECREDUCE_OR; + case ISD::XOR: + return ISD::VECREDUCE_XOR; + case ISD::FADD: + // Note: This is the associative form of the generic reduction opcode. + return ISD::VECREDUCE_FADD; + } +} + +/// Perform two related transforms whose purpose is to incrementally recognize +/// an explode_vector followed by scalar reduction as a vector reduction node. +/// This exists to recover from a deficiency in SLP which can't handle +/// forests with multiple roots sharing common nodes. In some cases, one +/// of the trees will be vectorized, and the other will remain (unprofitably) +/// scalarized. +static SDValue +combineBinOpOfExtractToReduceTree(SDNode *N, SelectionDAG &DAG, + const RISCVSubtarget &Subtarget) { + + // This transforms need to run before all integer types have been legalized + // to i64 (so that the vector element type matches the add type), and while + // it's safe to introduce odd sized vector types. + if (DAG.NewNodesMustHaveLegalTypes) + return SDValue(); + + // Without V, this transform isn't useful. We could form the (illegal) + // operations and let them be scalarized again, but there's really no point. + if (!Subtarget.hasVInstructions()) + return SDValue(); + + const SDLoc DL(N); + const EVT VT = N->getValueType(0); + const unsigned Opc = N->getOpcode(); + + // For FADD, we only handle the case with reassociation allowed. We + // could handle strict reduction order, but at the moment, there's no + // known reason to, and the complexity isn't worth it. + // TODO: Handle fminnum and fmaxnum here + if (!VT.isInteger() && + (Opc != ISD::FADD || !N->getFlags().hasAllowReassociation())) + return SDValue(); + + const unsigned ReduceOpc = getVecReduceOpcode(Opc); + assert(Opc == ISD::getVecReduceBaseOpcode(ReduceOpc) && + "Inconsistent mappings"); + SDValue LHS = N->getOperand(0); + SDValue RHS = N->getOperand(1); + + if (!LHS.hasOneUse() || !RHS.hasOneUse()) + return SDValue(); + + if (RHS.getOpcode() != ISD::EXTRACT_VECTOR_ELT) + std::swap(LHS, RHS); + + if (RHS.getOpcode() != ISD::EXTRACT_VECTOR_ELT || + !isa<ConstantSDNode>(RHS.getOperand(1))) + return SDValue(); + + uint64_t RHSIdx = cast<ConstantSDNode>(RHS.getOperand(1))->getLimitedValue(); + SDValue SrcVec = RHS.getOperand(0); + EVT SrcVecVT = SrcVec.getValueType(); + assert(SrcVecVT.getVectorElementType() == VT); + if (SrcVecVT.isScalableVector()) + return SDValue(); + + if (SrcVecVT.getScalarSizeInBits() > Subtarget.getELen()) + return SDValue(); + + // match binop (extract_vector_elt V, 0), (extract_vector_elt V, 1) to + // reduce_op (extract_subvector [2 x VT] from V). This will form the + // root of our reduction tree. 
TODO: We could extend this to any two + // adjacent aligned constant indices if desired. + if (LHS.getOpcode() == ISD::EXTRACT_VECTOR_ELT && + LHS.getOperand(0) == SrcVec && isa<ConstantSDNode>(LHS.getOperand(1))) { + uint64_t LHSIdx = + cast<ConstantSDNode>(LHS.getOperand(1))->getLimitedValue(); + if (0 == std::min(LHSIdx, RHSIdx) && 1 == std::max(LHSIdx, RHSIdx)) { + EVT ReduceVT = EVT::getVectorVT(*DAG.getContext(), VT, 2); + SDValue Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ReduceVT, SrcVec, + DAG.getVectorIdxConstant(0, DL)); + return DAG.getNode(ReduceOpc, DL, VT, Vec, N->getFlags()); + } + } + + // Match (binop (reduce (extract_subvector V, 0), + // (extract_vector_elt V, sizeof(SubVec)))) + // into a reduction of one more element from the original vector V. + if (LHS.getOpcode() != ReduceOpc) + return SDValue(); + + SDValue ReduceVec = LHS.getOperand(0); + if (ReduceVec.getOpcode() == ISD::EXTRACT_SUBVECTOR && + ReduceVec.hasOneUse() && ReduceVec.getOperand(0) == RHS.getOperand(0) && + isNullConstant(ReduceVec.getOperand(1)) && + ReduceVec.getValueType().getVectorNumElements() == RHSIdx) { + // For illegal types (e.g. 3xi32), most will be combined again into a + // wider (hopefully legal) type. If this is a terminal state, we are + // relying on type legalization here to produce something reasonable + // and this lowering quality could probably be improved. (TODO) + EVT ReduceVT = EVT::getVectorVT(*DAG.getContext(), VT, RHSIdx + 1); + SDValue Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ReduceVT, SrcVec, + DAG.getVectorIdxConstant(0, DL)); + auto Flags = ReduceVec->getFlags(); + Flags.intersectWith(N->getFlags()); + return DAG.getNode(ReduceOpc, DL, VT, Vec, Flags); + } + + return SDValue(); +} + + // Try to fold (<bop> x, (reduction.<bop> vec, start)) static SDValue combineBinOpToReduce(SDNode *N, SelectionDAG &DAG, const RISCVSubtarget &Subtarget) { @@ -10451,8 +12048,23 @@ static SDValue combineSelectAndUse(SDNode *N, SDValue Slct, SDValue OtherOp, if (VT.isVector()) return SDValue(); - if (!Subtarget.hasShortForwardBranchOpt() || - (Slct.getOpcode() != ISD::SELECT && + if (!Subtarget.hasShortForwardBranchOpt()) { + // (select cond, x, (and x, c)) has custom lowering with Zicond. + if ((!Subtarget.hasStdExtZicond() && + !Subtarget.hasVendorXVentanaCondOps()) || + N->getOpcode() != ISD::AND) + return SDValue(); + + // Maybe harmful when condition code has multiple use. + if (Slct.getOpcode() == ISD::SELECT && !Slct.getOperand(0).hasOneUse()) + return SDValue(); + + // Maybe harmful when VT is wider than XLen. + if (VT.getSizeInBits() > Subtarget.getXLen()) + return SDValue(); + } + + if ((Slct.getOpcode() != ISD::SELECT && Slct.getOpcode() != RISCVISD::SELECT_CC) || !Slct.hasOneUse()) return SDValue(); @@ -10571,7 +12183,7 @@ static SDValue transformAddImmMulImm(SDNode *N, SelectionDAG &DAG, return DAG.getNode(ISD::ADD, DL, VT, New1, DAG.getConstant(CB, DL, VT)); } -// Try to turn (add (xor (setcc X, Y), 1) -1) into (neg (setcc X, Y)). +// Try to turn (add (xor bool, 1) -1) into (neg bool). static SDValue combineAddOfBooleanXor(SDNode *N, SelectionDAG &DAG) { SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); @@ -10582,9 +12194,13 @@ static SDValue combineAddOfBooleanXor(SDNode *N, SelectionDAG &DAG) { if (!isAllOnesConstant(N1)) return SDValue(); - // Look for an (xor (setcc X, Y), 1). - if (N0.getOpcode() != ISD::XOR || !isOneConstant(N0.getOperand(1)) || - N0.getOperand(0).getOpcode() != ISD::SETCC) + // Look for (xor X, 1). 
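A hypothetical standalone check (plain C++, not DAG code) of the scalar identity this fold relies on: when the xor input is known to be 0 or 1, (b ^ 1) + (-1) equals -b.

#include <cassert>
#include <cstdint>

int main() {
  for (int64_t B : {0, 1})                 // operand known to be 0 or 1
    assert(((B ^ 1) + (-1)) == -B);        // (add (xor b, 1), -1) == (neg b)
  return 0;
}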
+ if (N0.getOpcode() != ISD::XOR || !isOneConstant(N0.getOperand(1))) + return SDValue(); + + // First xor input should be 0 or 1. + APInt Mask = APInt::getBitsSetFrom(VT.getSizeInBits(), 1); + if (!DAG.MaskedValueIsZero(N0.getOperand(0), Mask)) return SDValue(); // Emit a negate of the setcc. @@ -10602,6 +12218,9 @@ static SDValue performADDCombine(SDNode *N, SelectionDAG &DAG, return V; if (SDValue V = combineBinOpToReduce(N, DAG, Subtarget)) return V; + if (SDValue V = combineBinOpOfExtractToReduceTree(N, DAG, Subtarget)) + return V; + // fold (add (select lhs, rhs, cc, 0, y), x) -> // (select lhs, rhs, cc, x, (add x, y)) return combineSelectAndUseCommutative(N, DAG, /*AllOnes*/ false, Subtarget); @@ -10730,7 +12349,7 @@ static SDValue performTRUNCATECombine(SDNode *N, SelectionDAG &DAG, // shift amounts larger than 31 would produce poison. If we wait until // type legalization, we'll create RISCVISD::SRLW and we can't recover it // to use a BEXT instruction. - if (Subtarget.is64Bit() && Subtarget.hasStdExtZbs() && VT == MVT::i1 && + if (!RV64LegalI32 && Subtarget.is64Bit() && Subtarget.hasStdExtZbs() && VT == MVT::i1 && N0.getValueType() == MVT::i32 && N0.getOpcode() == ISD::SRL && !isa<ConstantSDNode>(N0.getOperand(1)) && N0.hasOneUse()) { SDLoc DL(N0); @@ -10757,7 +12376,7 @@ static SDValue performANDCombine(SDNode *N, // shift amounts larger than 31 would produce poison. If we wait until // type legalization, we'll create RISCVISD::SRLW and we can't recover it // to use a BEXT instruction. - if (Subtarget.is64Bit() && Subtarget.hasStdExtZbs() && + if (!RV64LegalI32 && Subtarget.is64Bit() && Subtarget.hasStdExtZbs() && N->getValueType(0) == MVT::i32 && isOneConstant(N->getOperand(1)) && N0.getOpcode() == ISD::SRL && !isa<ConstantSDNode>(N0.getOperand(1)) && N0.hasOneUse()) { @@ -10772,6 +12391,8 @@ static SDValue performANDCombine(SDNode *N, if (SDValue V = combineBinOpToReduce(N, DAG, Subtarget)) return V; + if (SDValue V = combineBinOpOfExtractToReduceTree(N, DAG, Subtarget)) + return V; if (DCI.isAfterLegalizeDAG()) if (SDValue V = combineDeMorganOfBoolean(N, DAG)) @@ -10782,17 +12403,64 @@ static SDValue performANDCombine(SDNode *N, return combineSelectAndUseCommutative(N, DAG, /*AllOnes*/ true, Subtarget); } +// Try to pull an xor with 1 through a select idiom that uses czero_eqz/nez. +// FIXME: Generalize to other binary operators with same operand. +static SDValue combineOrOfCZERO(SDNode *N, SDValue N0, SDValue N1, + SelectionDAG &DAG) { + assert(N->getOpcode() == ISD::OR && "Unexpected opcode"); + + if (N0.getOpcode() != RISCVISD::CZERO_EQZ || + N1.getOpcode() != RISCVISD::CZERO_NEZ || + !N0.hasOneUse() || !N1.hasOneUse()) + return SDValue(); + + // Should have the same condition. 
+ SDValue Cond = N0.getOperand(1);
+ if (Cond != N1.getOperand(1))
+ return SDValue();
+
+ SDValue TrueV = N0.getOperand(0);
+ SDValue FalseV = N1.getOperand(0);
+
+ if (TrueV.getOpcode() != ISD::XOR || FalseV.getOpcode() != ISD::XOR ||
+ TrueV.getOperand(1) != FalseV.getOperand(1) ||
+ !isOneConstant(TrueV.getOperand(1)) ||
+ !TrueV.hasOneUse() || !FalseV.hasOneUse())
+ return SDValue();
+
+ EVT VT = N->getValueType(0);
+ SDLoc DL(N);
+
+ SDValue NewN0 = DAG.getNode(RISCVISD::CZERO_EQZ, DL, VT, TrueV.getOperand(0),
+ Cond);
+ SDValue NewN1 = DAG.getNode(RISCVISD::CZERO_NEZ, DL, VT, FalseV.getOperand(0),
+ Cond);
+ SDValue NewOr = DAG.getNode(ISD::OR, DL, VT, NewN0, NewN1);
+ return DAG.getNode(ISD::XOR, DL, VT, NewOr, TrueV.getOperand(1));
+}
+
 static SDValue performORCombine(SDNode *N,
 TargetLowering::DAGCombinerInfo &DCI,
 const RISCVSubtarget &Subtarget) {
 SelectionDAG &DAG = DCI.DAG;
 if (SDValue V = combineBinOpToReduce(N, DAG, Subtarget))
 return V;
+ if (SDValue V = combineBinOpOfExtractToReduceTree(N, DAG, Subtarget))
+ return V;
 if (DCI.isAfterLegalizeDAG())
 if (SDValue V = combineDeMorganOfBoolean(N, DAG))
 return V;
+ // Look for Or of CZERO_EQZ/NEZ with same condition which is the select idiom.
+ // We may be able to pull a common operation out of the true and false value.
+ SDValue N0 = N->getOperand(0);
+ SDValue N1 = N->getOperand(1);
+ if (SDValue V = combineOrOfCZERO(N, N0, N1, DAG))
+ return V;
+ if (SDValue V = combineOrOfCZERO(N, N1, N0, DAG))
+ return V;
+ // fold (or (select cond, 0, y), x) ->
 // (select cond, x, (or x, y))
 return combineSelectAndUseCommutative(N, DAG, /*AllOnes*/ false, Subtarget);
 @@ -10803,6 +12471,21 @@ static SDValue performXORCombine(SDNode *N, SelectionDAG &DAG,
 SDValue N0 = N->getOperand(0);
 SDValue N1 = N->getOperand(1);
+ // Pre-promote (i32 (xor (shl -1, X), ~0)) on RV64 with Zbs so we can use
+ // (ADDI (BSET X0, X), -1). If we wait until type legalization, we'll create
+ // RISCVISD::SLLW and we can't recover it to use a BSET instruction.
+ if (!RV64LegalI32 && Subtarget.is64Bit() && Subtarget.hasStdExtZbs() &&
+ N->getValueType(0) == MVT::i32 && isAllOnesConstant(N1) &&
+ N0.getOpcode() == ISD::SHL && isAllOnesConstant(N0.getOperand(0)) &&
+ !isa<ConstantSDNode>(N0.getOperand(1)) && N0.hasOneUse()) {
+ SDLoc DL(N);
+ SDValue Op0 = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i64, N0.getOperand(0));
+ SDValue Op1 = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i64, N0.getOperand(1));
+ SDValue Shl = DAG.getNode(ISD::SHL, DL, MVT::i64, Op0, Op1);
+ SDValue And = DAG.getNOT(DL, Shl, MVT::i64);
+ return DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, And);
+ }
+
 // fold (xor (sllw 1, x), -1) -> (rolw ~1, x)
 // NOTE: Assumes ROL being legal means ROLW is legal.
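A standalone check, in plain C++ with a hand-rolled rotate, of the 32-bit identity behind the rolw fold noted just above: all-ones-except-bit-x can be formed either as (1 << x) ^ -1 or by rotating ~1 left by x.

#include <cassert>
#include <cstdint>

static uint32_t rotl32(uint32_t V, unsigned S) {
  S &= 31;
  return S == 0 ? V : (V << S) | (V >> (32 - S));
}

int main() {
  for (unsigned X = 0; X < 32; ++X)
    assert(((UINT32_C(1) << X) ^ UINT32_C(0xFFFFFFFF)) == rotl32(~UINT32_C(1), X));
  return 0;
}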
const TargetLowering &TLI = DAG.getTargetLoweringInfo(); @@ -10815,7 +12498,7 @@ static SDValue performXORCombine(SDNode *N, SelectionDAG &DAG, } // Fold (xor (setcc constant, y, setlt), 1) -> (setcc y, constant + 1, setlt) - if (N0.hasOneUse() && N0.getOpcode() == ISD::SETCC && isOneConstant(N1)) { + if (N0.getOpcode() == ISD::SETCC && isOneConstant(N1) && N0.hasOneUse()) { auto *ConstN00 = dyn_cast<ConstantSDNode>(N0.getOperand(0)); ISD::CondCode CC = cast<CondCodeSDNode>(N0.getOperand(2))->get(); if (ConstN00 && CC == ISD::SETLT) { @@ -10830,32 +12513,102 @@ static SDValue performXORCombine(SDNode *N, SelectionDAG &DAG, if (SDValue V = combineBinOpToReduce(N, DAG, Subtarget)) return V; + if (SDValue V = combineBinOpOfExtractToReduceTree(N, DAG, Subtarget)) + return V; + // fold (xor (select cond, 0, y), x) -> // (select cond, x, (xor x, y)) return combineSelectAndUseCommutative(N, DAG, /*AllOnes*/ false, Subtarget); } -// According to the property that indexed load/store instructions -// zero-extended their indices, \p narrowIndex tries to narrow the type of index -// operand if it is matched to pattern (shl (zext x to ty), C) and bits(x) + C < -// bits(ty). -static SDValue narrowIndex(SDValue N, SelectionDAG &DAG) { - if (N.getOpcode() != ISD::SHL || !N->hasOneUse()) +static SDValue performMULCombine(SDNode *N, SelectionDAG &DAG) { + EVT VT = N->getValueType(0); + if (!VT.isVector()) return SDValue(); + SDLoc DL(N); + SDValue N0 = N->getOperand(0); + SDValue N1 = N->getOperand(1); + SDValue MulOper; + unsigned AddSubOpc; + + // vmadd: (mul (add x, 1), y) -> (add (mul x, y), y) + // (mul x, add (y, 1)) -> (add x, (mul x, y)) + // vnmsub: (mul (sub 1, x), y) -> (sub y, (mul x, y)) + // (mul x, (sub 1, y)) -> (sub x, (mul x, y)) + auto IsAddSubWith1 = [&](SDValue V) -> bool { + AddSubOpc = V->getOpcode(); + if ((AddSubOpc == ISD::ADD || AddSubOpc == ISD::SUB) && V->hasOneUse()) { + SDValue Opnd = V->getOperand(1); + MulOper = V->getOperand(0); + if (AddSubOpc == ISD::SUB) + std::swap(Opnd, MulOper); + if (isOneOrOneSplat(Opnd)) + return true; + } + return false; + }; + + if (IsAddSubWith1(N0)) { + SDValue MulVal = DAG.getNode(ISD::MUL, DL, VT, N1, MulOper); + return DAG.getNode(AddSubOpc, DL, VT, N1, MulVal); + } + + if (IsAddSubWith1(N1)) { + SDValue MulVal = DAG.getNode(ISD::MUL, DL, VT, N0, MulOper); + return DAG.getNode(AddSubOpc, DL, VT, N0, MulVal); + } + + return SDValue(); +} + +/// According to the property that indexed load/store instructions zero-extend +/// their indices, try to narrow the type of index operand. +static bool narrowIndex(SDValue &N, ISD::MemIndexType IndexType, SelectionDAG &DAG) { + if (isIndexTypeSigned(IndexType)) + return false; + + if (!N->hasOneUse()) + return false; + + EVT VT = N.getValueType(); + SDLoc DL(N); + + // In general, what we're doing here is seeing if we can sink a truncate to + // a smaller element type into the expression tree building our index. + // TODO: We can generalize this and handle a bunch more cases if useful. + + // Narrow a buildvector to the narrowest element type. This requires less + // work and less register pressure at high LMUL, and creates smaller constants + // which may be cheaper to materialize. 
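A minimal sketch (plain C++, hypothetical index values) of why this index narrowing is safe for unsigned indices: as long as every constant index fits in the narrow element type, truncating and then zero-extending reproduces the original offsets, which matches how the indexed load/store instructions consume the narrowed index.

#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  // A constant index vector whose values all fit in 8 bits.
  std::vector<uint32_t> Index{0, 4, 8, 200};

  for (uint32_t Wide : Index) {
    uint8_t Narrow = static_cast<uint8_t>(Wide);     // truncate the index
    assert(static_cast<uint32_t>(Narrow) == Wide);   // zero-extending recovers it
  }
  return 0;
}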
+ if (ISD::isBuildVectorOfConstantSDNodes(N.getNode())) { + KnownBits Known = DAG.computeKnownBits(N); + unsigned ActiveBits = std::max(8u, Known.countMaxActiveBits()); + LLVMContext &C = *DAG.getContext(); + EVT ResultVT = EVT::getIntegerVT(C, ActiveBits).getRoundIntegerType(C); + if (ResultVT.bitsLT(VT.getVectorElementType())) { + N = DAG.getNode(ISD::TRUNCATE, DL, + VT.changeVectorElementType(ResultVT), N); + return true; + } + } + + // Handle the pattern (shl (zext x to ty), C) and bits(x) + C < bits(ty). + if (N.getOpcode() != ISD::SHL) + return false; + SDValue N0 = N.getOperand(0); if (N0.getOpcode() != ISD::ZERO_EXTEND && N0.getOpcode() != RISCVISD::VZEXT_VL) - return SDValue(); + return false;; if (!N0->hasOneUse()) - return SDValue(); + return false;; APInt ShAmt; SDValue N1 = N.getOperand(1); if (!ISD::isConstantSplatVector(N1.getNode(), ShAmt)) - return SDValue(); + return false;; - SDLoc DL(N); SDValue Src = N0.getOperand(0); EVT SrcVT = Src.getValueType(); unsigned SrcElen = SrcVT.getScalarSizeInBits(); @@ -10865,14 +12618,15 @@ static SDValue narrowIndex(SDValue N, SelectionDAG &DAG) { // Skip if NewElen is not narrower than the original extended type. if (NewElen >= N0.getValueType().getScalarSizeInBits()) - return SDValue(); + return false; EVT NewEltVT = EVT::getIntegerVT(*DAG.getContext(), NewElen); EVT NewVT = SrcVT.changeVectorElementType(NewEltVT); SDValue NewExt = DAG.getNode(N0->getOpcode(), DL, NewVT, N0->ops()); SDValue NewShAmtVec = DAG.getConstant(ShAmtV, DL, NewVT); - return DAG.getNode(ISD::SHL, DL, NewVT, NewExt, NewShAmtVec); + N = DAG.getNode(ISD::SHL, DL, NewVT, NewExt, NewShAmtVec); + return true; } // Replace (seteq (i64 (and X, 0xffffffff)), C1) with @@ -11710,7 +13464,11 @@ static SDValue performFP_TO_INTCombine(SDNode *N, return SDValue(); RISCVFPRndMode::RoundingMode FRM = matchRoundingOp(Src.getOpcode()); - if (FRM == RISCVFPRndMode::Invalid) + // If the result is invalid, we didn't find a foldable instruction. + // If the result is dynamic, then we found an frint which we don't yet + // support. It will cause 7 to be written to the FRM CSR for vector. + // FIXME: We could support this by using VFCVT_X_F_VL/VFCVT_XU_F_VL below. + if (FRM == RISCVFPRndMode::Invalid || FRM == RISCVFPRndMode::DYN) return SDValue(); SDLoc DL(N); @@ -11943,10 +13701,18 @@ static SDValue combineVFMADD_VLWithVFNEG_VL(SDNode *N, SelectionDAG &DAG) { VL); } -static SDValue performVFMADD_VLCombine(SDNode *N, SelectionDAG &DAG) { +static SDValue performVFMADD_VLCombine(SDNode *N, SelectionDAG &DAG, + const RISCVSubtarget &Subtarget) { if (SDValue V = combineVFMADD_VLWithVFNEG_VL(N, DAG)) return V; + if (N->getValueType(0).isScalableVector() && + N->getValueType(0).getVectorElementType() == MVT::f32 && + (Subtarget.hasVInstructionsF16Minimal() && + !Subtarget.hasVInstructionsF16())) { + return SDValue(); + } + // FIXME: Ignore strict opcodes for now. if (N->isTargetStrictFPOpcode()) return SDValue(); @@ -11997,7 +13763,15 @@ static SDValue performVFMADD_VLCombine(SDNode *N, SelectionDAG &DAG) { N->getOperand(2), Mask, VL); } -static SDValue performVFMUL_VLCombine(SDNode *N, SelectionDAG &DAG) { +static SDValue performVFMUL_VLCombine(SDNode *N, SelectionDAG &DAG, + const RISCVSubtarget &Subtarget) { + if (N->getValueType(0).isScalableVector() && + N->getValueType(0).getVectorElementType() == MVT::f32 && + (Subtarget.hasVInstructionsF16Minimal() && + !Subtarget.hasVInstructionsF16())) { + return SDValue(); + } + // FIXME: Ignore strict opcodes for now. 
assert(!N->isTargetStrictFPOpcode() && "Unexpected opcode"); @@ -12030,7 +13804,15 @@ static SDValue performVFMUL_VLCombine(SDNode *N, SelectionDAG &DAG) { Op1, Merge, Mask, VL); } -static SDValue performFADDSUB_VLCombine(SDNode *N, SelectionDAG &DAG) { +static SDValue performFADDSUB_VLCombine(SDNode *N, SelectionDAG &DAG, + const RISCVSubtarget &Subtarget) { + if (N->getValueType(0).isScalableVector() && + N->getValueType(0).getVectorElementType() == MVT::f32 && + (Subtarget.hasVInstructionsF16Minimal() && + !Subtarget.hasVInstructionsF16())) { + return SDValue(); + } + SDValue Op0 = N->getOperand(0); SDValue Op1 = N->getOperand(1); SDValue Merge = N->getOperand(2); @@ -12261,12 +14043,10 @@ static bool combine_CC(SDValue &LHS, SDValue &RHS, SDValue &CC, const SDLoc &DL, // shift can be omitted. // Fold setlt (sra X, N), 0 -> setlt X, 0 and // setge (sra X, N), 0 -> setge X, 0 - if (auto *RHSConst = dyn_cast<ConstantSDNode>(RHS.getNode())) { - if ((CCVal == ISD::SETGE || CCVal == ISD::SETLT) && - LHS.getOpcode() == ISD::SRA && RHSConst->isZero()) { - LHS = LHS.getOperand(0); - return true; - } + if (isNullConstant(RHS) && (CCVal == ISD::SETGE || CCVal == ISD::SETLT) && + LHS.getOpcode() == ISD::SRA) { + LHS = LHS.getOperand(0); + return true; } if (!ISD::isIntEqualitySetCC(CCVal)) @@ -12352,9 +14132,13 @@ static SDValue tryFoldSelectIntoOp(SDNode *N, SelectionDAG &DAG, SDValue TrueVal, SDValue FalseVal, bool Swapped) { bool Commutative = true; - switch (TrueVal.getOpcode()) { + unsigned Opc = TrueVal.getOpcode(); + switch (Opc) { default: return SDValue(); + case ISD::SHL: + case ISD::SRA: + case ISD::SRL: case ISD::SUB: Commutative = false; break; @@ -12377,12 +14161,18 @@ static SDValue tryFoldSelectIntoOp(SDNode *N, SelectionDAG &DAG, EVT VT = N->getValueType(0); SDLoc DL(N); - SDValue Zero = DAG.getConstant(0, DL, VT); SDValue OtherOp = TrueVal.getOperand(1 - OpToFold); + EVT OtherOpVT = OtherOp->getValueType(0); + SDValue IdentityOperand = + DAG.getNeutralElement(Opc, DL, OtherOpVT, N->getFlags()); + if (!Commutative) + IdentityOperand = DAG.getConstant(0, DL, OtherOpVT); + assert(IdentityOperand && "No identity operand!"); if (Swapped) - std::swap(OtherOp, Zero); - SDValue NewSel = DAG.getSelect(DL, VT, N->getOperand(0), OtherOp, Zero); + std::swap(OtherOp, IdentityOperand); + SDValue NewSel = + DAG.getSelect(DL, OtherOpVT, N->getOperand(0), OtherOp, IdentityOperand); return DAG.getNode(TrueVal.getOpcode(), DL, VT, FalseVal, NewSel); } @@ -12447,11 +14237,45 @@ static SDValue foldSelectOfCTTZOrCTLZ(SDNode *N, SelectionDAG &DAG) { return DAG.getZExtOrTrunc(AndNode, SDLoc(N), N->getValueType(0)); } +static SDValue useInversedSetcc(SDNode *N, SelectionDAG &DAG, + const RISCVSubtarget &Subtarget) { + SDValue Cond = N->getOperand(0); + SDValue True = N->getOperand(1); + SDValue False = N->getOperand(2); + SDLoc DL(N); + EVT VT = N->getValueType(0); + EVT CondVT = Cond.getValueType(); + + if (Cond.getOpcode() != ISD::SETCC || !Cond.hasOneUse()) + return SDValue(); + + // Replace (setcc eq (and x, C)) with (setcc ne (and x, C))) to generate + // BEXTI, where C is power of 2. 
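A standalone check (plain C++, arbitrary operands) of the rewrite tryFoldSelectIntoOp performs earlier in this hunk: when one arm of the select is also an operand of the binop, the select can be pushed onto the other operand, using the operation's neutral element (or a zero shift amount) for the untaken case.

#include <cassert>
#include <cstdint>

int main() {
  const int64_t X = 1234, Y = 7;
  for (bool C : {false, true}) {
    // (select c, (add x, y), x) == (add x, (select c, y, 0))
    assert((C ? X + Y : X) == X + (C ? Y : 0));
    // (select c, (shl x, y), x) == (shl x, (select c, y, 0))
    assert((C ? X << Y : X) == (X << (C ? Y : 0)));
  }
  return 0;
}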
+ if (Subtarget.hasStdExtZbs() && VT.isScalarInteger() && + (Subtarget.hasStdExtZicond() || Subtarget.hasVendorXVentanaCondOps())) { + SDValue LHS = Cond.getOperand(0); + SDValue RHS = Cond.getOperand(1); + ISD::CondCode CC = cast<CondCodeSDNode>(Cond.getOperand(2))->get(); + if (CC == ISD::SETEQ && LHS.getOpcode() == ISD::AND && + isa<ConstantSDNode>(LHS.getOperand(1)) && isNullConstant(RHS)) { + uint64_t MaskVal = LHS.getConstantOperandVal(1); + if (isPowerOf2_64(MaskVal) && !isInt<12>(MaskVal)) + return DAG.getSelect(DL, VT, + DAG.getSetCC(DL, CondVT, LHS, RHS, ISD::SETNE), + False, True); + } + } + return SDValue(); +} + static SDValue performSELECTCombine(SDNode *N, SelectionDAG &DAG, const RISCVSubtarget &Subtarget) { if (SDValue Folded = foldSelectOfCTTZOrCTLZ(N, DAG)) return Folded; + if (SDValue V = useInversedSetcc(N, DAG, Subtarget)) + return V; + if (Subtarget.hasShortForwardBranchOpt()) return SDValue(); @@ -12462,6 +14286,132 @@ static SDValue performSELECTCombine(SDNode *N, SelectionDAG &DAG, return tryFoldSelectIntoOp(N, DAG, FalseVal, TrueVal, /*Swapped*/true); } +/// If we have a build_vector where each lane is binop X, C, where C +/// is a constant (but not necessarily the same constant on all lanes), +/// form binop (build_vector x1, x2, ...), (build_vector c1, c2, c3, ..). +/// We assume that materializing a constant build vector will be no more +/// expensive that performing O(n) binops. +static SDValue performBUILD_VECTORCombine(SDNode *N, SelectionDAG &DAG, + const RISCVSubtarget &Subtarget, + const RISCVTargetLowering &TLI) { + SDLoc DL(N); + EVT VT = N->getValueType(0); + + assert(!VT.isScalableVector() && "unexpected build vector"); + + if (VT.getVectorNumElements() == 1) + return SDValue(); + + const unsigned Opcode = N->op_begin()->getNode()->getOpcode(); + if (!TLI.isBinOp(Opcode)) + return SDValue(); + + if (!TLI.isOperationLegalOrCustom(Opcode, VT) || !TLI.isTypeLegal(VT)) + return SDValue(); + + SmallVector<SDValue> LHSOps; + SmallVector<SDValue> RHSOps; + for (SDValue Op : N->ops()) { + if (Op.isUndef()) { + // We can't form a divide or remainder from undef. + if (!DAG.isSafeToSpeculativelyExecute(Opcode)) + return SDValue(); + + LHSOps.push_back(Op); + RHSOps.push_back(Op); + continue; + } + + // TODO: We can handle operations which have an neutral rhs value + // (e.g. x + 0, a * 1 or a << 0), but we then have to keep track + // of profit in a more explicit manner. + if (Op.getOpcode() != Opcode || !Op.hasOneUse()) + return SDValue(); + + LHSOps.push_back(Op.getOperand(0)); + if (!isa<ConstantSDNode>(Op.getOperand(1)) && + !isa<ConstantFPSDNode>(Op.getOperand(1))) + return SDValue(); + // FIXME: Return failure if the RHS type doesn't match the LHS. Shifts may + // have different LHS and RHS types. 
+ if (Op.getOperand(0).getValueType() != Op.getOperand(1).getValueType()) + return SDValue(); + RHSOps.push_back(Op.getOperand(1)); + } + + return DAG.getNode(Opcode, DL, VT, DAG.getBuildVector(VT, DL, LHSOps), + DAG.getBuildVector(VT, DL, RHSOps)); +} + +static SDValue performINSERT_VECTOR_ELTCombine(SDNode *N, SelectionDAG &DAG, + const RISCVSubtarget &Subtarget, + const RISCVTargetLowering &TLI) { + SDValue InVec = N->getOperand(0); + SDValue InVal = N->getOperand(1); + SDValue EltNo = N->getOperand(2); + SDLoc DL(N); + + EVT VT = InVec.getValueType(); + if (VT.isScalableVector()) + return SDValue(); + + if (!InVec.hasOneUse()) + return SDValue(); + + // Given insert_vector_elt (binop a, VecC), (same_binop b, C2), Elt + // move the insert_vector_elts into the arms of the binop. Note that + // the new RHS must be a constant. + const unsigned InVecOpcode = InVec->getOpcode(); + if (InVecOpcode == InVal->getOpcode() && TLI.isBinOp(InVecOpcode) && + InVal.hasOneUse()) { + SDValue InVecLHS = InVec->getOperand(0); + SDValue InVecRHS = InVec->getOperand(1); + SDValue InValLHS = InVal->getOperand(0); + SDValue InValRHS = InVal->getOperand(1); + + if (!ISD::isBuildVectorOfConstantSDNodes(InVecRHS.getNode())) + return SDValue(); + if (!isa<ConstantSDNode>(InValRHS) && !isa<ConstantFPSDNode>(InValRHS)) + return SDValue(); + // FIXME: Return failure if the RHS type doesn't match the LHS. Shifts may + // have different LHS and RHS types. + if (InVec.getOperand(0).getValueType() != InVec.getOperand(1).getValueType()) + return SDValue(); + SDValue LHS = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, VT, + InVecLHS, InValLHS, EltNo); + SDValue RHS = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, VT, + InVecRHS, InValRHS, EltNo); + return DAG.getNode(InVecOpcode, DL, VT, LHS, RHS); + } + + // Given insert_vector_elt (concat_vectors ...), InVal, Elt + // move the insert_vector_elt to the source operand of the concat_vector. + if (InVec.getOpcode() != ISD::CONCAT_VECTORS) + return SDValue(); + + auto *IndexC = dyn_cast<ConstantSDNode>(EltNo); + if (!IndexC) + return SDValue(); + unsigned Elt = IndexC->getZExtValue(); + + EVT ConcatVT = InVec.getOperand(0).getValueType(); + if (ConcatVT.getVectorElementType() != InVal.getValueType()) + return SDValue(); + unsigned ConcatNumElts = ConcatVT.getVectorNumElements(); + SDValue NewIdx = DAG.getConstant(Elt % ConcatNumElts, DL, + EltNo.getValueType()); + + unsigned ConcatOpIdx = Elt / ConcatNumElts; + SDValue ConcatOp = InVec.getOperand(ConcatOpIdx); + ConcatOp = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, ConcatVT, + ConcatOp, InVal, NewIdx); + + SmallVector<SDValue> ConcatOps; + ConcatOps.append(InVec->op_begin(), InVec->op_end()); + ConcatOps[ConcatOpIdx] = ConcatOp; + return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, ConcatOps); +} + // If we're concatenating a series of vector loads like // concat_vectors (load v4i8, p+0), (load v4i8, p+n), (load v4i8, p+n*2) ... 
// Then we can turn this into a strided load by widening the vector elements @@ -12486,13 +14436,11 @@ static SDValue performCONCAT_VECTORSCombine(SDNode *N, SelectionDAG &DAG, return SDValue(); EVT BaseLdVT = BaseLd->getValueType(0); - SDValue BasePtr = BaseLd->getBasePtr(); // Go through the loads and check that they're strided - SDValue CurPtr = BasePtr; - SDValue Stride; + SmallVector<LoadSDNode *> Lds; + Lds.push_back(BaseLd); Align Align = BaseLd->getAlign(); - for (SDValue Op : N->ops().drop_front()) { auto *Ld = dyn_cast<LoadSDNode>(Op); if (!Ld || !Ld->isSimple() || !Op.hasOneUse() || @@ -12500,42 +14448,46 @@ static SDValue performCONCAT_VECTORSCombine(SDNode *N, SelectionDAG &DAG, Ld->getValueType(0) != BaseLdVT) return SDValue(); - SDValue Ptr = Ld->getBasePtr(); - // Check that each load's pointer is (add CurPtr, Stride) - if (Ptr.getOpcode() != ISD::ADD || Ptr.getOperand(0) != CurPtr) - return SDValue(); - SDValue Offset = Ptr.getOperand(1); - if (!Stride) - Stride = Offset; - else if (Offset != Stride) - return SDValue(); + Lds.push_back(Ld); // The common alignment is the most restrictive (smallest) of all the loads Align = std::min(Align, Ld->getAlign()); - - CurPtr = Ptr; } - // A special case is if the stride is exactly the width of one of the loads, - // in which case it's contiguous and can be combined into a regular vle - // without changing the element size - if (auto *ConstStride = dyn_cast<ConstantSDNode>(Stride); - ConstStride && - ConstStride->getZExtValue() == BaseLdVT.getFixedSizeInBits() / 8) { - MachineMemOperand *MMO = DAG.getMachineFunction().getMachineMemOperand( - BaseLd->getPointerInfo(), BaseLd->getMemOperand()->getFlags(), - VT.getStoreSize(), Align); - // Can't do the combine if the load isn't naturally aligned with the element - // type - if (!TLI.allowsMemoryAccessForAlignment(*DAG.getContext(), - DAG.getDataLayout(), VT, *MMO)) + using PtrDiff = std::pair<std::variant<int64_t, SDValue>, bool>; + auto GetPtrDiff = [&DAG](LoadSDNode *Ld1, + LoadSDNode *Ld2) -> std::optional<PtrDiff> { + // If the load ptrs can be decomposed into a common (Base + Index) with a + // common constant stride, then return the constant stride. + BaseIndexOffset BIO1 = BaseIndexOffset::match(Ld1, DAG); + BaseIndexOffset BIO2 = BaseIndexOffset::match(Ld2, DAG); + if (BIO1.equalBaseIndex(BIO2, DAG)) + return {{BIO2.getOffset() - BIO1.getOffset(), false}}; + + // Otherwise try to match (add LastPtr, Stride) or (add NextPtr, Stride) + SDValue P1 = Ld1->getBasePtr(); + SDValue P2 = Ld2->getBasePtr(); + if (P2.getOpcode() == ISD::ADD && P2.getOperand(0) == P1) + return {{P2.getOperand(1), false}}; + if (P1.getOpcode() == ISD::ADD && P1.getOperand(0) == P2) + return {{P1.getOperand(1), true}}; + + return std::nullopt; + }; + + // Get the distance between the first and second loads + auto BaseDiff = GetPtrDiff(Lds[0], Lds[1]); + if (!BaseDiff) + return SDValue(); + + // Check all the loads are the same distance apart + for (auto *It = Lds.begin() + 1; It != Lds.end() - 1; It++) + if (GetPtrDiff(*It, *std::next(It)) != BaseDiff) return SDValue(); - SDValue WideLoad = DAG.getLoad(VT, DL, BaseLd->getChain(), BasePtr, MMO); - for (SDValue Ld : N->ops()) - DAG.makeEquivalentMemoryOrdering(cast<LoadSDNode>(Ld), WideLoad); - return WideLoad; - } + // TODO: At this point, we've successfully matched a generalized gather + // load. Maybe we should emit that, and then move the specialized + // matchers above and below into a DAG combine? // Get the widened scalar type, e.g. 
v4i8 -> i64 unsigned WideScalarBitWidth = @@ -12551,21 +14503,29 @@ static SDValue performCONCAT_VECTORSCombine(SDNode *N, SelectionDAG &DAG, if (!TLI.isLegalStridedLoadStore(WideVecVT, Align)) return SDValue(); - MVT ContainerVT = TLI.getContainerForFixedLengthVector(WideVecVT); - SDValue VL = - getDefaultVLOps(WideVecVT, ContainerVT, DL, DAG, Subtarget).second; - SDVTList VTs = DAG.getVTList({ContainerVT, MVT::Other}); + auto [StrideVariant, MustNegateStride] = *BaseDiff; + SDValue Stride = std::holds_alternative<SDValue>(StrideVariant) + ? std::get<SDValue>(StrideVariant) + : DAG.getConstant(std::get<int64_t>(StrideVariant), DL, + Lds[0]->getOffset().getValueType()); + if (MustNegateStride) + Stride = DAG.getNegative(Stride, DL, Stride.getValueType()); + + SDVTList VTs = DAG.getVTList({WideVecVT, MVT::Other}); SDValue IntID = - DAG.getTargetConstant(Intrinsic::riscv_vlse, DL, Subtarget.getXLenVT()); - SDValue Ops[] = {BaseLd->getChain(), - IntID, - DAG.getUNDEF(ContainerVT), - BasePtr, - Stride, - VL}; + DAG.getTargetConstant(Intrinsic::riscv_masked_strided_load, DL, + Subtarget.getXLenVT()); + + SDValue AllOneMask = + DAG.getSplat(WideVecVT.changeVectorElementType(MVT::i1), DL, + DAG.getConstant(1, DL, MVT::i1)); + + SDValue Ops[] = {BaseLd->getChain(), IntID, DAG.getUNDEF(WideVecVT), + BaseLd->getBasePtr(), Stride, AllOneMask}; uint64_t MemSize; - if (auto *ConstStride = dyn_cast<ConstantSDNode>(Stride)) + if (auto *ConstStride = dyn_cast<ConstantSDNode>(Stride); + ConstStride && ConstStride->getSExtValue() >= 0) // total size = (elsize * n) + (stride - elsize) * (n-1) // = elsize + stride * (n-1) MemSize = WideScalarVT.getSizeInBits() + @@ -12583,11 +14543,7 @@ static SDValue performCONCAT_VECTORSCombine(SDNode *N, SelectionDAG &DAG, for (SDValue Ld : N->ops()) DAG.makeEquivalentMemoryOrdering(cast<LoadSDNode>(Ld), StridedLoad); - // Note: Perform the bitcast before the convertFromScalableVector so we have - // balanced pairs of convertFromScalable/convertToScalable - SDValue Res = DAG.getBitcast( - TLI.getContainerForFixedLengthVector(VT.getSimpleVT()), StridedLoad); - return convertFromScalableVector(VT, Res, DAG, Subtarget); + return DAG.getBitcast(VT.getSimpleVT(), StridedLoad); } static SDValue combineToVWMACC(SDNode *N, SelectionDAG &DAG, @@ -12647,9 +14603,121 @@ static SDValue combineToVWMACC(SDNode *N, SelectionDAG &DAG, return DAG.getNode(Opc, DL, VT, Ops); } +static bool legalizeScatterGatherIndexType(SDLoc DL, SDValue &Index, + ISD::MemIndexType &IndexType, + RISCVTargetLowering::DAGCombinerInfo &DCI) { + if (!DCI.isBeforeLegalize()) + return false; + + SelectionDAG &DAG = DCI.DAG; + const MVT XLenVT = + DAG.getMachineFunction().getSubtarget<RISCVSubtarget>().getXLenVT(); + + const EVT IndexVT = Index.getValueType(); + + // RISC-V indexed loads only support the "unsigned unscaled" addressing + // mode, so anything else must be manually legalized. + if (!isIndexTypeSigned(IndexType)) + return false; + + if (IndexVT.getVectorElementType().bitsLT(XLenVT)) { + // Any index legalization should first promote to XLenVT, so we don't lose + // bits when scaling. This may create an illegal index type so we let + // LLVM's legalization take care of the splitting. + // FIXME: LLVM can't split VP_GATHER or VP_SCATTER yet. 
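A standalone sketch, in plain C++ with no LLVM or RVV APIs, of the widening performCONCAT_VECTORSCombine applies above: each small load becomes one wide element of a strided load, so the byte sequence seen by the concatenation is unchanged. The sizes and stride below are made up for illustration.

#include <cassert>
#include <cstdint>
#include <cstring>
#include <vector>

int main() {
  std::vector<uint8_t> Mem(64);
  for (unsigned I = 0; I < Mem.size(); ++I)
    Mem[I] = static_cast<uint8_t>(I);

  const unsigned Stride = 12, NumLoads = 3, LoadBytes = 4; // e.g. three v4i8 loads
  std::vector<uint8_t> Concat;   // bytes produced by concat_vectors of the loads
  std::vector<uint32_t> Wide;    // elements produced by the widened strided load
  for (unsigned I = 0; I < NumLoads; ++I) {
    const uint8_t *P = Mem.data() + I * Stride;
    Concat.insert(Concat.end(), P, P + LoadBytes);
    uint32_t W;
    std::memcpy(&W, P, sizeof(W));          // one wide element per original load
    Wide.push_back(W);
  }
  assert(Concat.size() == Wide.size() * sizeof(uint32_t));
  assert(std::memcmp(Concat.data(), Wide.data(), Concat.size()) == 0);
  return 0;
}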
+ Index = DAG.getNode(ISD::SIGN_EXTEND, DL,
+ IndexVT.changeVectorElementType(XLenVT), Index);
+ }
+ IndexType = ISD::UNSIGNED_SCALED;
+ return true;
+}
+
+/// Match the index vector of a scatter or gather node as the shuffle mask
+/// which performs the rearrangement if possible. Will only match if
+/// all lanes are touched, and thus replacing the scatter or gather with
+/// a unit strided access and shuffle is legal.
+static bool matchIndexAsShuffle(EVT VT, SDValue Index, SDValue Mask,
+ SmallVector<int> &ShuffleMask) {
+ if (!ISD::isConstantSplatVectorAllOnes(Mask.getNode()))
+ return false;
+ if (!ISD::isBuildVectorOfConstantSDNodes(Index.getNode()))
+ return false;
+
+ const unsigned ElementSize = VT.getScalarStoreSize();
+ const unsigned NumElems = VT.getVectorNumElements();
+
+ // Create the shuffle mask and check all bits active
+ assert(ShuffleMask.empty());
+ BitVector ActiveLanes(NumElems);
+ for (unsigned i = 0; i < Index->getNumOperands(); i++) {
+ // TODO: We've found an active bit of UB, and could be
+ // more aggressive here if desired.
+ if (Index->getOperand(i)->isUndef())
+ return false;
+ uint64_t C = Index->getConstantOperandVal(i);
+ if (C % ElementSize != 0)
+ return false;
+ C = C / ElementSize;
+ if (C >= NumElems)
+ return false;
+ ShuffleMask.push_back(C);
+ ActiveLanes.set(C);
+ }
+ return ActiveLanes.all();
+}
+
+/// Match the index of a gather or scatter operation as an operation
+/// with twice the element width and half the number of elements. This is
+/// generally profitable (if legal) because these operations are linear
+/// in VL, so even if we cause some extra VTYPE/VL toggles, we still
+/// come out ahead.
+static bool matchIndexAsWiderOp(EVT VT, SDValue Index, SDValue Mask,
+ Align BaseAlign, const RISCVSubtarget &ST) {
+ if (!ISD::isConstantSplatVectorAllOnes(Mask.getNode()))
+ return false;
+ if (!ISD::isBuildVectorOfConstantSDNodes(Index.getNode()))
+ return false;
+
+ // Attempt a doubling. If we can use an element type 4x or 8x in
+ // size, this will happen via multiple iterations of the transform.
+ const unsigned NumElems = VT.getVectorNumElements();
+ if (NumElems % 2 != 0)
+ return false;
+
+ const unsigned ElementSize = VT.getScalarStoreSize();
+ const unsigned WiderElementSize = ElementSize * 2;
+ if (WiderElementSize > ST.getELen()/8)
+ return false;
+
+ if (!ST.hasFastUnalignedAccess() && BaseAlign < WiderElementSize)
+ return false;
+
+ for (unsigned i = 0; i < Index->getNumOperands(); i++) {
+ // TODO: We've found an active bit of UB, and could be
+ // more aggressive here if desired.
+ if (Index->getOperand(i)->isUndef())
+ return false;
+ // TODO: This offset check is too strict if we support fully
+ // misaligned memory operations.
+ uint64_t C = Index->getConstantOperandVal(i);
+ if (i % 2 == 0) {
+ if (C % WiderElementSize != 0)
+ return false;
+ continue;
+ }
+ uint64_t Last = Index->getConstantOperandVal(i-1);
+ if (C != Last + ElementSize)
+ return false;
+ }
+ return true;
+}
+
+
 SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
 DAGCombinerInfo &DCI) const {
 SelectionDAG &DAG = DCI.DAG;
+ const MVT XLenVT = Subtarget.getXLenVT();
+ SDLoc DL(N);
 // Helper to call SimplifyDemandedBits on an operand of N where only some low
 // bits are demanded. N will be added to the Worklist if it was not deleted.
@@ -12681,8 +14749,6 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N, return DCI.CombineTo(N, Lo, Hi); } - SDLoc DL(N); - // It's cheaper to materialise two 32-bit integers than to load a double // from the constant pool and transfer it to integer registers through the // stack. @@ -12789,14 +14855,21 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N, return performORCombine(N, DCI, Subtarget); case ISD::XOR: return performXORCombine(N, DAG, Subtarget); + case ISD::MUL: + return performMULCombine(N, DAG); case ISD::FADD: case ISD::UMAX: case ISD::UMIN: case ISD::SMAX: case ISD::SMIN: case ISD::FMAXNUM: - case ISD::FMINNUM: - return combineBinOpToReduce(N, DAG, Subtarget); + case ISD::FMINNUM: { + if (SDValue V = combineBinOpToReduce(N, DAG, Subtarget)) + return V; + if (SDValue V = combineBinOpOfExtractToReduceTree(N, DAG, Subtarget)) + return V; + return SDValue(); + } case ISD::SETCC: return performSETCCCombine(N, DAG, Subtarget); case ISD::SIGN_EXTEND_INREG: @@ -12823,6 +14896,56 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N, } } return SDValue(); + case RISCVISD::TRUNCATE_VECTOR_VL: { + // trunc (sra sext (X), zext (Y)) -> sra (X, smin (Y, scalarsize(Y) - 1)) + // This would be benefit for the cases where X and Y are both the same value + // type of low precision vectors. Since the truncate would be lowered into + // n-levels TRUNCATE_VECTOR_VL to satisfy RVV's SEW*2->SEW truncate + // restriction, such pattern would be expanded into a series of "vsetvli" + // and "vnsrl" instructions later to reach this point. + auto IsTruncNode = [](SDValue V) { + if (V.getOpcode() != RISCVISD::TRUNCATE_VECTOR_VL) + return false; + SDValue VL = V.getOperand(2); + auto *C = dyn_cast<ConstantSDNode>(VL); + // Assume all TRUNCATE_VECTOR_VL nodes use VLMAX for VMSET_VL operand + bool IsVLMAXForVMSET = (C && C->isAllOnes()) || + (isa<RegisterSDNode>(VL) && + cast<RegisterSDNode>(VL)->getReg() == RISCV::X0); + return V.getOperand(1).getOpcode() == RISCVISD::VMSET_VL && + IsVLMAXForVMSET; + }; + + SDValue Op = N->getOperand(0); + + // We need to first find the inner level of TRUNCATE_VECTOR_VL node + // to distinguish such pattern. + while (IsTruncNode(Op)) { + if (!Op.hasOneUse()) + return SDValue(); + Op = Op.getOperand(0); + } + + if (Op.getOpcode() == ISD::SRA && Op.hasOneUse()) { + SDValue N0 = Op.getOperand(0); + SDValue N1 = Op.getOperand(1); + if (N0.getOpcode() == ISD::SIGN_EXTEND && N0.hasOneUse() && + N1.getOpcode() == ISD::ZERO_EXTEND && N1.hasOneUse()) { + SDValue N00 = N0.getOperand(0); + SDValue N10 = N1.getOperand(0); + if (N00.getValueType().isVector() && + N00.getValueType() == N10.getValueType() && + N->getValueType(0) == N10.getValueType()) { + unsigned MaxShAmt = N10.getValueType().getScalarSizeInBits() - 1; + SDValue SMin = DAG.getNode( + ISD::SMIN, SDLoc(N1), N->getValueType(0), N10, + DAG.getConstant(MaxShAmt, SDLoc(N1), N->getValueType(0))); + return DAG.getNode(ISD::SRA, SDLoc(N), N->getValueType(0), N00, SMin); + } + } + } + break; + } case ISD::TRUNCATE: return performTRUNCATECombine(N, DAG, Subtarget); case ISD::SELECT: @@ -12933,6 +15056,19 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N, } } + // If both true/false are an xor with 1, pull through the select. + // This can occur after op legalization if both operands are setccs that + // require an xor to invert. + // FIXME: Generalize to other binary ops with identical operand? 
+ if (TrueV.getOpcode() == ISD::XOR && FalseV.getOpcode() == ISD::XOR && + TrueV.getOperand(1) == FalseV.getOperand(1) && + isOneConstant(TrueV.getOperand(1)) && + TrueV.hasOneUse() && FalseV.hasOneUse()) { + SDValue NewSel = DAG.getNode(RISCVISD::SELECT_CC, DL, VT, LHS, RHS, CC, + TrueV.getOperand(0), FalseV.getOperand(0)); + return DAG.getNode(ISD::XOR, DL, VT, NewSel, TrueV.getOperand(1)); + } + return SDValue(); } case RISCVISD::BR_CC: { @@ -12979,75 +15115,187 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N, return DAG.getNode(ISD::FCOPYSIGN, DL, VT, N->getOperand(0), DAG.getNode(ISD::FNEG, DL, VT, NewFPExtRound)); } - case ISD::MGATHER: - case ISD::MSCATTER: - case ISD::VP_GATHER: - case ISD::VP_SCATTER: { - if (!DCI.isBeforeLegalize()) - break; - SDValue Index, ScaleOp; - bool IsIndexSigned = false; - if (const auto *VPGSN = dyn_cast<VPGatherScatterSDNode>(N)) { - Index = VPGSN->getIndex(); - ScaleOp = VPGSN->getScale(); - IsIndexSigned = VPGSN->isIndexSigned(); - assert(!VPGSN->isIndexScaled() && - "Scaled gather/scatter should not be formed"); - } else { - const auto *MGSN = cast<MaskedGatherScatterSDNode>(N); - Index = MGSN->getIndex(); - ScaleOp = MGSN->getScale(); - IsIndexSigned = MGSN->isIndexSigned(); - assert(!MGSN->isIndexScaled() && - "Scaled gather/scatter should not be formed"); + case ISD::MGATHER: { + const auto *MGN = dyn_cast<MaskedGatherSDNode>(N); + const EVT VT = N->getValueType(0); + SDValue Index = MGN->getIndex(); + SDValue ScaleOp = MGN->getScale(); + ISD::MemIndexType IndexType = MGN->getIndexType(); + assert(!MGN->isIndexScaled() && + "Scaled gather/scatter should not be formed"); + + SDLoc DL(N); + if (legalizeScatterGatherIndexType(DL, Index, IndexType, DCI)) + return DAG.getMaskedGather( + N->getVTList(), MGN->getMemoryVT(), DL, + {MGN->getChain(), MGN->getPassThru(), MGN->getMask(), + MGN->getBasePtr(), Index, ScaleOp}, + MGN->getMemOperand(), IndexType, MGN->getExtensionType()); + + if (narrowIndex(Index, IndexType, DAG)) + return DAG.getMaskedGather( + N->getVTList(), MGN->getMemoryVT(), DL, + {MGN->getChain(), MGN->getPassThru(), MGN->getMask(), + MGN->getBasePtr(), Index, ScaleOp}, + MGN->getMemOperand(), IndexType, MGN->getExtensionType()); + + if (Index.getOpcode() == ISD::BUILD_VECTOR && + MGN->getExtensionType() == ISD::NON_EXTLOAD) { + if (std::optional<VIDSequence> SimpleVID = isSimpleVIDSequence(Index); + SimpleVID && SimpleVID->StepDenominator == 1) { + const int64_t StepNumerator = SimpleVID->StepNumerator; + const int64_t Addend = SimpleVID->Addend; + + // Note: We don't need to check alignment here since (by assumption + // from the existance of the gather), our offsets must be sufficiently + // aligned. 
+ const EVT PtrVT = getPointerTy(DAG.getDataLayout()); + assert(MGN->getBasePtr()->getValueType(0) == PtrVT); + assert(IndexType == ISD::UNSIGNED_SCALED); + SDValue BasePtr = DAG.getNode(ISD::ADD, DL, PtrVT, MGN->getBasePtr(), + DAG.getConstant(Addend, DL, PtrVT)); + + SDVTList VTs = DAG.getVTList({VT, MVT::Other}); + SDValue IntID = + DAG.getTargetConstant(Intrinsic::riscv_masked_strided_load, DL, + XLenVT); + SDValue Ops[] = + {MGN->getChain(), IntID, MGN->getPassThru(), BasePtr, + DAG.getConstant(StepNumerator, DL, XLenVT), MGN->getMask()}; + return DAG.getMemIntrinsicNode(ISD::INTRINSIC_W_CHAIN, DL, VTs, + Ops, VT, MGN->getMemOperand()); + } } - EVT IndexVT = Index.getValueType(); - MVT XLenVT = Subtarget.getXLenVT(); - // RISC-V indexed loads only support the "unsigned unscaled" addressing - // mode, so anything else must be manually legalized. - bool NeedsIdxLegalization = - (IsIndexSigned && IndexVT.getVectorElementType().bitsLT(XLenVT)); - if (!NeedsIdxLegalization) - break; + + SmallVector<int> ShuffleMask; + if (MGN->getExtensionType() == ISD::NON_EXTLOAD && + matchIndexAsShuffle(VT, Index, MGN->getMask(), ShuffleMask)) { + SDValue Load = DAG.getMaskedLoad(VT, DL, MGN->getChain(), + MGN->getBasePtr(), DAG.getUNDEF(XLenVT), + MGN->getMask(), DAG.getUNDEF(VT), + MGN->getMemoryVT(), MGN->getMemOperand(), + ISD::UNINDEXED, ISD::NON_EXTLOAD); + SDValue Shuffle = + DAG.getVectorShuffle(VT, DL, Load, DAG.getUNDEF(VT), ShuffleMask); + return DAG.getMergeValues({Shuffle, Load.getValue(1)}, DL); + } + + if (MGN->getExtensionType() == ISD::NON_EXTLOAD && + matchIndexAsWiderOp(VT, Index, MGN->getMask(), + MGN->getMemOperand()->getBaseAlign(), Subtarget)) { + SmallVector<SDValue> NewIndices; + for (unsigned i = 0; i < Index->getNumOperands(); i += 2) + NewIndices.push_back(Index.getOperand(i)); + EVT IndexVT = Index.getValueType() + .getHalfNumVectorElementsVT(*DAG.getContext()); + Index = DAG.getBuildVector(IndexVT, DL, NewIndices); + + unsigned ElementSize = VT.getScalarStoreSize(); + EVT WideScalarVT = MVT::getIntegerVT(ElementSize * 8 * 2); + auto EltCnt = VT.getVectorElementCount(); + assert(EltCnt.isKnownEven() && "Splitting vector, but not in half!"); + EVT WideVT = EVT::getVectorVT(*DAG.getContext(), WideScalarVT, + EltCnt.divideCoefficientBy(2)); + SDValue Passthru = DAG.getBitcast(WideVT, MGN->getPassThru()); + EVT MaskVT = EVT::getVectorVT(*DAG.getContext(), MVT::i1, + EltCnt.divideCoefficientBy(2)); + SDValue Mask = DAG.getSplat(MaskVT, DL, DAG.getConstant(1, DL, MVT::i1)); + + SDValue Gather = + DAG.getMaskedGather(DAG.getVTList(WideVT, MVT::Other), WideVT, DL, + {MGN->getChain(), Passthru, Mask, MGN->getBasePtr(), + Index, ScaleOp}, + MGN->getMemOperand(), IndexType, ISD::NON_EXTLOAD); + SDValue Result = DAG.getBitcast(VT, Gather.getValue(0)); + return DAG.getMergeValues({Result, Gather.getValue(1)}, DL); + } + break; + } + case ISD::MSCATTER:{ + const auto *MSN = dyn_cast<MaskedScatterSDNode>(N); + SDValue Index = MSN->getIndex(); + SDValue ScaleOp = MSN->getScale(); + ISD::MemIndexType IndexType = MSN->getIndexType(); + assert(!MSN->isIndexScaled() && + "Scaled gather/scatter should not be formed"); SDLoc DL(N); + if (legalizeScatterGatherIndexType(DL, Index, IndexType, DCI)) + return DAG.getMaskedScatter( + N->getVTList(), MSN->getMemoryVT(), DL, + {MSN->getChain(), MSN->getValue(), MSN->getMask(), MSN->getBasePtr(), + Index, ScaleOp}, + MSN->getMemOperand(), IndexType, MSN->isTruncatingStore()); - // Any index legalization should first promote to XLenVT, so we don't lose - 
// bits when scaling. This may create an illegal index type so we let - // LLVM's legalization take care of the splitting. - // FIXME: LLVM can't split VP_GATHER or VP_SCATTER yet. - if (IndexVT.getVectorElementType().bitsLT(XLenVT)) { - IndexVT = IndexVT.changeVectorElementType(XLenVT); - Index = DAG.getNode(IsIndexSigned ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND, - DL, IndexVT, Index); + if (narrowIndex(Index, IndexType, DAG)) + return DAG.getMaskedScatter( + N->getVTList(), MSN->getMemoryVT(), DL, + {MSN->getChain(), MSN->getValue(), MSN->getMask(), MSN->getBasePtr(), + Index, ScaleOp}, + MSN->getMemOperand(), IndexType, MSN->isTruncatingStore()); + + EVT VT = MSN->getValue()->getValueType(0); + SmallVector<int> ShuffleMask; + if (!MSN->isTruncatingStore() && + matchIndexAsShuffle(VT, Index, MSN->getMask(), ShuffleMask)) { + SDValue Shuffle = DAG.getVectorShuffle(VT, DL, MSN->getValue(), + DAG.getUNDEF(VT), ShuffleMask); + return DAG.getMaskedStore(MSN->getChain(), DL, Shuffle, MSN->getBasePtr(), + DAG.getUNDEF(XLenVT), MSN->getMask(), + MSN->getMemoryVT(), MSN->getMemOperand(), + ISD::UNINDEXED, false); } + break; + } + case ISD::VP_GATHER: { + const auto *VPGN = dyn_cast<VPGatherSDNode>(N); + SDValue Index = VPGN->getIndex(); + SDValue ScaleOp = VPGN->getScale(); + ISD::MemIndexType IndexType = VPGN->getIndexType(); + assert(!VPGN->isIndexScaled() && + "Scaled gather/scatter should not be formed"); + + SDLoc DL(N); + if (legalizeScatterGatherIndexType(DL, Index, IndexType, DCI)) + return DAG.getGatherVP(N->getVTList(), VPGN->getMemoryVT(), DL, + {VPGN->getChain(), VPGN->getBasePtr(), Index, + ScaleOp, VPGN->getMask(), + VPGN->getVectorLength()}, + VPGN->getMemOperand(), IndexType); - ISD::MemIndexType NewIndexTy = ISD::UNSIGNED_SCALED; - if (const auto *VPGN = dyn_cast<VPGatherSDNode>(N)) + if (narrowIndex(Index, IndexType, DAG)) return DAG.getGatherVP(N->getVTList(), VPGN->getMemoryVT(), DL, {VPGN->getChain(), VPGN->getBasePtr(), Index, ScaleOp, VPGN->getMask(), VPGN->getVectorLength()}, - VPGN->getMemOperand(), NewIndexTy); - if (const auto *VPSN = dyn_cast<VPScatterSDNode>(N)) + VPGN->getMemOperand(), IndexType); + + break; + } + case ISD::VP_SCATTER: { + const auto *VPSN = dyn_cast<VPScatterSDNode>(N); + SDValue Index = VPSN->getIndex(); + SDValue ScaleOp = VPSN->getScale(); + ISD::MemIndexType IndexType = VPSN->getIndexType(); + assert(!VPSN->isIndexScaled() && + "Scaled gather/scatter should not be formed"); + + SDLoc DL(N); + if (legalizeScatterGatherIndexType(DL, Index, IndexType, DCI)) return DAG.getScatterVP(N->getVTList(), VPSN->getMemoryVT(), DL, {VPSN->getChain(), VPSN->getValue(), VPSN->getBasePtr(), Index, ScaleOp, VPSN->getMask(), VPSN->getVectorLength()}, - VPSN->getMemOperand(), NewIndexTy); - if (const auto *MGN = dyn_cast<MaskedGatherSDNode>(N)) - return DAG.getMaskedGather( - N->getVTList(), MGN->getMemoryVT(), DL, - {MGN->getChain(), MGN->getPassThru(), MGN->getMask(), - MGN->getBasePtr(), Index, ScaleOp}, - MGN->getMemOperand(), NewIndexTy, MGN->getExtensionType()); - const auto *MSN = cast<MaskedScatterSDNode>(N); - return DAG.getMaskedScatter( - N->getVTList(), MSN->getMemoryVT(), DL, - {MSN->getChain(), MSN->getValue(), MSN->getMask(), MSN->getBasePtr(), - Index, ScaleOp}, - MSN->getMemOperand(), NewIndexTy, MSN->isTruncatingStore()); + VPSN->getMemOperand(), IndexType); + + if (narrowIndex(Index, IndexType, DAG)) + return DAG.getScatterVP(N->getVTList(), VPSN->getMemoryVT(), DL, + {VPSN->getChain(), VPSN->getValue(), + VPSN->getBasePtr(), Index, ScaleOp, + 
VPSN->getMask(), VPSN->getVectorLength()}, + VPSN->getMemOperand(), IndexType); + break; } case RISCVISD::SRA_VL: case RISCVISD::SRL_VL: @@ -13056,7 +15304,7 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N, if (ShAmt.getOpcode() == RISCVISD::SPLAT_VECTOR_SPLIT_I64_VL) { // We don't need the upper 32 bits of a 64-bit element for a shift amount. SDLoc DL(N); - SDValue VL = N->getOperand(3); + SDValue VL = N->getOperand(4); EVT VT = N->getValueType(0); ShAmt = DAG.getNode(RISCVISD::VMV_V_X_VL, DL, VT, DAG.getUNDEF(VT), ShAmt.getOperand(1), VL); @@ -13102,12 +15350,12 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N, case RISCVISD::STRICT_VFNMADD_VL: case RISCVISD::STRICT_VFMSUB_VL: case RISCVISD::STRICT_VFNMSUB_VL: - return performVFMADD_VLCombine(N, DAG); + return performVFMADD_VLCombine(N, DAG, Subtarget); case RISCVISD::FMUL_VL: - return performVFMUL_VLCombine(N, DAG); + return performVFMUL_VLCombine(N, DAG, Subtarget); case RISCVISD::FADD_VL: case RISCVISD::FSUB_VL: - return performFADDSUB_VLCombine(N, DAG); + return performFADDSUB_VLCombine(N, DAG, Subtarget); case ISD::LOAD: case ISD::STORE: { if (DCI.isAfterLegalizeDAG()) @@ -13143,16 +15391,17 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N, ISD::isBuildVectorOfConstantSDNodes(Val.getNode())) { // Get the constant vector bits APInt NewC(Val.getValueSizeInBits(), 0); + uint64_t EltSize = Val.getScalarValueSizeInBits(); for (unsigned i = 0; i < Val.getNumOperands(); i++) { if (Val.getOperand(i).isUndef()) continue; - NewC.insertBits(Val.getConstantOperandAPInt(i), - i * Val.getScalarValueSizeInBits()); + NewC.insertBits(Val.getConstantOperandAPInt(i).trunc(EltSize), + i * EltSize); } MVT NewVT = MVT::getIntegerVT(MemVT.getSizeInBits()); - if (RISCVMatInt::getIntMatCost(NewC, Subtarget.getXLen(), - Subtarget.getFeatureBits(), true) <= 2 && + if (RISCVMatInt::getIntMatCost(NewC, Subtarget.getXLen(), Subtarget, + true) <= 2 && allowsMemoryAccessForAlignment(*DAG.getContext(), DAG.getDataLayout(), NewVT, *Store->getMemOperand())) { SDValue NewV = DAG.getConstant(NewC, DL, NewVT); @@ -13195,7 +15444,7 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N, SDValue Src = Val.getOperand(0); MVT VecVT = Src.getSimpleValueType(); // VecVT should be scalable and memory VT should match the element type. - if (VecVT.isScalableVector() && + if (!Store->isIndexed() && VecVT.isScalableVector() && MemVT == VecVT.getVectorElementType()) { SDLoc DL(N); MVT MaskVT = getMaskTypeFor(VecVT); @@ -13220,19 +15469,51 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N, return Gather; break; } + case ISD::BUILD_VECTOR: + if (SDValue V = performBUILD_VECTORCombine(N, DAG, Subtarget, *this)) + return V; + break; case ISD::CONCAT_VECTORS: if (SDValue V = performCONCAT_VECTORSCombine(N, DAG, Subtarget, *this)) return V; break; + case ISD::INSERT_VECTOR_ELT: + if (SDValue V = performINSERT_VECTOR_ELTCombine(N, DAG, Subtarget, *this)) + return V; + break; + case RISCVISD::VFMV_V_F_VL: { + const MVT VT = N->getSimpleValueType(0); + SDValue Passthru = N->getOperand(0); + SDValue Scalar = N->getOperand(1); + SDValue VL = N->getOperand(2); + + // If VL is 1, we can use vfmv.s.f. 
+ if (isOneConstant(VL)) + return DAG.getNode(RISCVISD::VFMV_S_F_VL, DL, VT, Passthru, Scalar, VL); + break; + } case RISCVISD::VMV_V_X_VL: { + const MVT VT = N->getSimpleValueType(0); + SDValue Passthru = N->getOperand(0); + SDValue Scalar = N->getOperand(1); + SDValue VL = N->getOperand(2); + // Tail agnostic VMV.V.X only demands the vector element bitwidth from the // scalar input. - unsigned ScalarSize = N->getOperand(1).getValueSizeInBits(); - unsigned EltWidth = N->getValueType(0).getScalarSizeInBits(); - if (ScalarSize > EltWidth && N->getOperand(0).isUndef()) + unsigned ScalarSize = Scalar.getValueSizeInBits(); + unsigned EltWidth = VT.getScalarSizeInBits(); + if (ScalarSize > EltWidth && Passthru.isUndef()) if (SimplifyDemandedLowBitsHelper(1, EltWidth)) return SDValue(N, 0); + // If VL is 1 and the scalar value won't benefit from immediate, we can + // use vmv.s.x. + ConstantSDNode *Const = dyn_cast<ConstantSDNode>(Scalar); + if (isOneConstant(VL) && + (!Const || Const->isZero() || + !Const->getAPIntValue().sextOrTrunc(EltWidth).isSignedIntN(5))) + return DAG.getNode(RISCVISD::VMV_S_X_VL, DL, VT, Passthru, Scalar, VL); + break; } case RISCVISD::VFMV_S_F_VL: { @@ -13252,6 +15533,35 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N, return Src.getOperand(0); // TODO: Use insert_subvector/extract_subvector to change widen/narrow? } + [[fallthrough]]; + } + case RISCVISD::VMV_S_X_VL: { + const MVT VT = N->getSimpleValueType(0); + SDValue Passthru = N->getOperand(0); + SDValue Scalar = N->getOperand(1); + SDValue VL = N->getOperand(2); + + // Use M1 or smaller to avoid over constraining register allocation + const MVT M1VT = getLMUL1VT(VT); + if (M1VT.bitsLT(VT)) { + SDValue M1Passthru = + DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, M1VT, Passthru, + DAG.getVectorIdxConstant(0, DL)); + SDValue Result = + DAG.getNode(N->getOpcode(), DL, M1VT, M1Passthru, Scalar, VL); + Result = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT, Passthru, Result, + DAG.getConstant(0, DL, XLenVT)); + return Result; + } + + // We use a vmv.v.i if possible. We limit this to LMUL1. LMUL2 or + // higher would involve overly constraining the register allocator for + // no purpose. + if (ConstantSDNode *Const = dyn_cast<ConstantSDNode>(Scalar); + Const && !Const->isZero() && isInt<5>(Const->getSExtValue()) && + VT.bitsLE(getLMUL1VT(VT)) && Passthru.isUndef()) + return DAG.getNode(RISCVISD::VMV_V_X_VL, DL, VT, Passthru, Scalar, VL); + break; } case ISD::INTRINSIC_VOID: @@ -13263,6 +15573,43 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N, // By default we do not combine any intrinsic. default: return SDValue(); + case Intrinsic::riscv_masked_strided_load: { + MVT VT = N->getSimpleValueType(0); + auto *Load = cast<MemIntrinsicSDNode>(N); + SDValue PassThru = N->getOperand(2); + SDValue Base = N->getOperand(3); + SDValue Stride = N->getOperand(4); + SDValue Mask = N->getOperand(5); + + // If the stride is equal to the element size in bytes, we can use + // a masked.load. 
+ const unsigned ElementSize = VT.getScalarStoreSize(); + if (auto *StrideC = dyn_cast<ConstantSDNode>(Stride); + StrideC && StrideC->getZExtValue() == ElementSize) + return DAG.getMaskedLoad(VT, DL, Load->getChain(), Base, + DAG.getUNDEF(XLenVT), Mask, PassThru, + Load->getMemoryVT(), Load->getMemOperand(), + ISD::UNINDEXED, ISD::NON_EXTLOAD); + return SDValue(); + } + case Intrinsic::riscv_masked_strided_store: { + auto *Store = cast<MemIntrinsicSDNode>(N); + SDValue Value = N->getOperand(2); + SDValue Base = N->getOperand(3); + SDValue Stride = N->getOperand(4); + SDValue Mask = N->getOperand(5); + + // If the stride is equal to the element size in bytes, we can use + // a masked.store. + const unsigned ElementSize = Value.getValueType().getScalarStoreSize(); + if (auto *StrideC = dyn_cast<ConstantSDNode>(Stride); + StrideC && StrideC->getZExtValue() == ElementSize) + return DAG.getMaskedStore(Store->getChain(), DL, Value, Base, + DAG.getUNDEF(XLenVT), Mask, + Store->getMemoryVT(), Store->getMemOperand(), + ISD::UNINDEXED, false); + return SDValue(); + } case Intrinsic::riscv_vcpop: case Intrinsic::riscv_vcpop_mask: case Intrinsic::riscv_vfirst: @@ -13281,23 +15628,6 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N, return DAG.getConstant(-1, DL, VT); return DAG.getConstant(0, DL, VT); } - case Intrinsic::riscv_vloxei: - case Intrinsic::riscv_vloxei_mask: - case Intrinsic::riscv_vluxei: - case Intrinsic::riscv_vluxei_mask: - case Intrinsic::riscv_vsoxei: - case Intrinsic::riscv_vsoxei_mask: - case Intrinsic::riscv_vsuxei: - case Intrinsic::riscv_vsuxei_mask: - if (SDValue V = narrowIndex(N->getOperand(4), DAG)) { - SmallVector<SDValue, 8> Ops(N->ops()); - Ops[4] = V; - const auto *MemSD = cast<MemIntrinsicSDNode>(N); - return DAG.getMemIntrinsicNode(N->getOpcode(), SDLoc(N), N->getVTList(), - Ops, MemSD->getMemoryVT(), - MemSD->getMemOperand()); - } - return SDValue(); } } case ISD::BITCAST: { @@ -13380,12 +15710,12 @@ bool RISCVTargetLowering::isDesirableToCommuteWithShift( // Neither constant will fit into an immediate, so find materialisation // costs. - int C1Cost = RISCVMatInt::getIntMatCost(C1Int, Ty.getSizeInBits(), - Subtarget.getFeatureBits(), - /*CompressionCost*/true); + int C1Cost = + RISCVMatInt::getIntMatCost(C1Int, Ty.getSizeInBits(), Subtarget, + /*CompressionCost*/ true); int ShiftedC1Cost = RISCVMatInt::getIntMatCost( - ShiftedC1Int, Ty.getSizeInBits(), Subtarget.getFeatureBits(), - /*CompressionCost*/true); + ShiftedC1Int, Ty.getSizeInBits(), Subtarget, + /*CompressionCost*/ true); // Materialising `c1` is cheaper than materialising `c1 << c2`, so the // combine should be prevented. @@ -13556,6 +15886,15 @@ void RISCVTargetLowering::computeKnownBitsForTargetNode(const SDValue Op, Known = Known.sext(BitWidth); break; } + case RISCVISD::SLLW: { + KnownBits Known2; + Known = DAG.computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1); + Known2 = DAG.computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1); + Known = KnownBits::shl(Known.trunc(32), Known2.trunc(5).zext(32)); + // Restore the original width by sign extending. 
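// A standalone sketch (plain C++, not the LLVM lowering) of why a masked
// strided load whose stride equals the element size can become an ordinary
// masked.load: element i lives at Base + i * Stride, and with
// Stride == sizeof(element) that is exactly the contiguous layout a
// unit-stride access expects.
#include <cassert>
#include <cstdint>
#include <cstring>

void stridedLoad(uint32_t *Dst, const uint8_t *Base, unsigned N,
                 unsigned StrideBytes) {
  for (unsigned I = 0; I < N; ++I)
    std::memcpy(&Dst[I], Base + I * StrideBytes, sizeof(uint32_t));
}

int main() {
  uint32_t Src[4] = {10, 20, 30, 40};
  uint32_t A[4], B[4];
  stridedLoad(A, reinterpret_cast<const uint8_t *>(Src), 4, sizeof(uint32_t));
  std::memcpy(B, Src, sizeof(B)); // the contiguous load it can be replaced by
  assert(std::memcmp(A, B, sizeof(A)) == 0);
  return 0;
}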
+ Known = Known.sext(BitWidth); + break; + } case RISCVISD::CTZW: { KnownBits Known2 = DAG.computeKnownBits(Op.getOperand(0), Depth + 1); unsigned PossibleTZ = Known2.trunc(32).countMaxTrailingZeros(); @@ -13594,7 +15933,7 @@ void RISCVTargetLowering::computeKnownBitsForTargetNode(const SDValue Op, Known.One.setBit(Log2_32(MinVLenB)); break; } - case RISCVISD::FPCLASS: { + case RISCVISD::FCLASS: { // fclass will only set one of the low 10 bits. Known.Zero.setBitsFrom(10); break; @@ -13609,7 +15948,7 @@ void RISCVTargetLowering::computeKnownBitsForTargetNode(const SDValue Op, break; case Intrinsic::riscv_vsetvli: case Intrinsic::riscv_vsetvlimax: - // Assume that VL output is >= 65536. + // Assume that VL output is <= 65536. // TODO: Take SEW and LMUL into account. if (BitWidth > 17) Known.Zero.setBitsFrom(17); @@ -14181,47 +16520,6 @@ static MachineBasicBlock *emitSelectPseudo(MachineInstr &MI, return TailMBB; } -static MachineBasicBlock *emitVFCVT_RM(MachineInstr &MI, MachineBasicBlock *BB, - unsigned Opcode) { - DebugLoc DL = MI.getDebugLoc(); - - const TargetInstrInfo &TII = *BB->getParent()->getSubtarget().getInstrInfo(); - - MachineRegisterInfo &MRI = BB->getParent()->getRegInfo(); - Register SavedFRM = MRI.createVirtualRegister(&RISCV::GPRRegClass); - - assert(MI.getNumOperands() == 8 || MI.getNumOperands() == 7); - unsigned FRMIdx = MI.getNumOperands() == 8 ? 4 : 3; - - // Update FRM and save the old value. - BuildMI(*BB, MI, DL, TII.get(RISCV::SwapFRMImm), SavedFRM) - .addImm(MI.getOperand(FRMIdx).getImm()); - - // Emit an VFCVT with the FRM == DYN - auto MIB = BuildMI(*BB, MI, DL, TII.get(Opcode)); - - for (unsigned I = 0; I < MI.getNumOperands(); I++) - if (I != FRMIdx) - MIB = MIB.add(MI.getOperand(I)); - else - MIB = MIB.add(MachineOperand::CreateImm(7)); // frm = DYN - - MIB.add(MachineOperand::CreateReg(RISCV::FRM, - /*IsDef*/ false, - /*IsImp*/ true)); - - if (MI.getFlag(MachineInstr::MIFlag::NoFPExcept)) - MIB->setFlag(MachineInstr::MIFlag::NoFPExcept); - - // Restore FRM. - BuildMI(*BB, MI, DL, TII.get(RISCV::WriteFRM)) - .addReg(SavedFRM, RegState::Kill); - - // Erase the pseudoinstruction. 
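// A standalone model (plain C++) of RV64 SLLW matching the known-bits code
// above: only the low 32 bits of the value and the low 5 bits of the shift
// amount participate, and the 32-bit result is sign-extended back to 64 bits.
#include <cassert>
#include <cstdint>

int64_t sllw(int64_t Rs1, int64_t Rs2) {
  uint32_t Lo = static_cast<uint32_t>(Rs1);      // trunc(32)
  unsigned Sh = static_cast<unsigned>(Rs2) & 31; // trunc(5)
  return static_cast<int32_t>(Lo << Sh);         // shl, then sext to 64 bits
}

int main() {
  assert(sllw(1, 31) == -(int64_t(1) << 31)); // bit 31 set -> negative i64
  assert(sllw(0x100000001LL, 4) == 16);       // upper 32 bits are ignored
  assert(sllw(1, 32) == 1);                   // shift amount masked to 5 bits
  return 0;
}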
- MI.eraseFromParent(); - return BB; -} - static MachineBasicBlock *emitVFROUND_NOEXCEPT_MASK(MachineInstr &MI, MachineBasicBlock *BB, unsigned CVTXOpc, @@ -14466,43 +16764,6 @@ RISCVTargetLowering::EmitInstrWithCustomInserter(MachineInstr &MI, return emitQuietFCMP(MI, BB, RISCV::FLT_D_IN32X, RISCV::FEQ_D_IN32X, Subtarget); -#define PseudoVFCVT_RM_LMUL_CASE(RMOpc, Opc, LMUL) \ - case RISCV::RMOpc##_##LMUL: \ - return emitVFCVT_RM(MI, BB, RISCV::Opc##_##LMUL); \ - case RISCV::RMOpc##_##LMUL##_MASK: \ - return emitVFCVT_RM(MI, BB, RISCV::Opc##_##LMUL##_MASK); - -#define PseudoVFCVT_RM_CASE(RMOpc, Opc) \ - PseudoVFCVT_RM_LMUL_CASE(RMOpc, Opc, M1) \ - PseudoVFCVT_RM_LMUL_CASE(RMOpc, Opc, M2) \ - PseudoVFCVT_RM_LMUL_CASE(RMOpc, Opc, M4) \ - PseudoVFCVT_RM_LMUL_CASE(RMOpc, Opc, MF2) \ - PseudoVFCVT_RM_LMUL_CASE(RMOpc, Opc, MF4) - -#define PseudoVFCVT_RM_CASE_M8(RMOpc, Opc) \ - PseudoVFCVT_RM_CASE(RMOpc, Opc) \ - PseudoVFCVT_RM_LMUL_CASE(RMOpc, Opc, M8) - -#define PseudoVFCVT_RM_CASE_MF8(RMOpc, Opc) \ - PseudoVFCVT_RM_CASE(RMOpc, Opc) \ - PseudoVFCVT_RM_LMUL_CASE(RMOpc, Opc, MF8) - - // VFCVT - PseudoVFCVT_RM_CASE_M8(PseudoVFCVT_RM_X_F_V, PseudoVFCVT_X_F_V) - PseudoVFCVT_RM_CASE_M8(PseudoVFCVT_RM_XU_F_V, PseudoVFCVT_XU_F_V) - PseudoVFCVT_RM_CASE_M8(PseudoVFCVT_RM_F_XU_V, PseudoVFCVT_F_XU_V) - PseudoVFCVT_RM_CASE_M8(PseudoVFCVT_RM_F_X_V, PseudoVFCVT_F_X_V) - - // VFWCVT - PseudoVFCVT_RM_CASE(PseudoVFWCVT_RM_XU_F_V, PseudoVFWCVT_XU_F_V); - PseudoVFCVT_RM_CASE(PseudoVFWCVT_RM_X_F_V, PseudoVFWCVT_X_F_V); - - // VFNCVT - PseudoVFCVT_RM_CASE_MF8(PseudoVFNCVT_RM_XU_F_W, PseudoVFNCVT_XU_F_W); - PseudoVFCVT_RM_CASE_MF8(PseudoVFNCVT_RM_X_F_W, PseudoVFNCVT_X_F_W); - PseudoVFCVT_RM_CASE(PseudoVFNCVT_RM_F_XU_W, PseudoVFNCVT_F_XU_W); - PseudoVFCVT_RM_CASE(PseudoVFNCVT_RM_F_X_W, PseudoVFNCVT_F_X_W); - case RISCV::PseudoVFROUND_NOEXCEPT_V_M1_MASK: return emitVFROUND_NOEXCEPT_MASK(MI, BB, RISCV::PseudoVFCVT_X_F_V_M1_MASK, RISCV::PseudoVFCVT_F_X_V_M1_MASK); @@ -14529,41 +16790,26 @@ RISCVTargetLowering::EmitInstrWithCustomInserter(MachineInstr &MI, case RISCV::PseudoFROUND_D_INX: case RISCV::PseudoFROUND_D_IN32X: return emitFROUND(MI, BB, Subtarget); + case TargetOpcode::STATEPOINT: + case TargetOpcode::STACKMAP: + case TargetOpcode::PATCHPOINT: + if (!Subtarget.is64Bit()) + report_fatal_error("STACKMAP, PATCHPOINT and STATEPOINT are only " + "supported on 64-bit targets"); + return emitPatchPoint(MI, BB); } } -// Returns the index to the rounding mode immediate value if any, otherwise the -// function will return None. -static std::optional<unsigned> getRoundModeIdx(const MachineInstr &MI) { - uint64_t TSFlags = MI.getDesc().TSFlags; - if (!RISCVII::hasRoundModeOp(TSFlags)) - return std::nullopt; - - // The operand order - // ------------------------------------- - // | n-1 (if any) | n-2 | n-3 | n-4 | - // | policy | sew | vl | rm | - // ------------------------------------- - return MI.getNumExplicitOperands() - RISCVII::hasVecPolicyOp(TSFlags) - 3; -} - void RISCVTargetLowering::AdjustInstrPostInstrSelection(MachineInstr &MI, SDNode *Node) const { - // Add FRM dependency to vector floating-point instructions with dynamic - // rounding mode. - if (auto RoundModeIdx = getRoundModeIdx(MI)) { - unsigned FRMImm = MI.getOperand(*RoundModeIdx).getImm(); - if (FRMImm == RISCVFPRndMode::DYN && !MI.readsRegister(RISCV::FRM)) { - MI.addOperand(MachineOperand::CreateReg(RISCV::FRM, /*isDef*/ false, - /*isImp*/ true)); - } - } - // Add FRM dependency to any instructions with dynamic rounding mode. 
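// A tiny standalone sketch (plain C++) of the operand-layout arithmetic the
// removed getRoundModeIdx helper encoded: the trailing explicit operands are
// [rm][vl][sew][policy?], so the rounding mode sits at
// NumExplicitOperands - HasPolicyOp - 3. Illustrative only; the replacement
// code instead looks the frm operand up by name or via TSFlags.
#include <cassert>

unsigned roundModeIdx(unsigned NumExplicitOperands, bool HasPolicyOp) {
  return NumExplicitOperands - (HasPolicyOp ? 1 : 0) - 3;
}

int main() {
  // e.g. {vd, vs2, vs1, rm, vl, sew, policy} -> rm is operand 3
  assert(roundModeIdx(7, /*HasPolicyOp=*/true) == 3);
  // e.g. {vd, vs2, vs1, rm, vl, sew} -> rm is still operand 3
  assert(roundModeIdx(6, /*HasPolicyOp=*/false) == 3);
  return 0;
}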
- unsigned Opc = MI.getOpcode(); - auto Idx = RISCV::getNamedOperandIdx(Opc, RISCV::OpName::frm); - if (Idx < 0) - return; + int Idx = RISCV::getNamedOperandIdx(MI.getOpcode(), RISCV::OpName::frm); + if (Idx < 0) { + // Vector pseudos have FRM index indicated by TSFlags. + Idx = RISCVII::getFRMOpNum(MI.getDesc()); + if (Idx < 0) + return; + } if (MI.getOperand(Idx).getImm() != RISCVFPRndMode::DYN) return; // If the instruction already reads FRM, don't add another read. @@ -14598,10 +16844,6 @@ void RISCVTargetLowering::AdjustInstrPostInstrSelection(MachineInstr &MI, // register-size fields in the same situations they would be for fixed // arguments. -static const MCPhysReg ArgGPRs[] = { - RISCV::X10, RISCV::X11, RISCV::X12, RISCV::X13, - RISCV::X14, RISCV::X15, RISCV::X16, RISCV::X17 -}; static const MCPhysReg ArgFPR16s[] = { RISCV::F10_H, RISCV::F11_H, RISCV::F12_H, RISCV::F13_H, RISCV::F14_H, RISCV::F15_H, RISCV::F16_H, RISCV::F17_H @@ -14626,6 +16868,14 @@ static const MCPhysReg ArgVRM4s[] = {RISCV::V8M4, RISCV::V12M4, RISCV::V16M4, RISCV::V20M4}; static const MCPhysReg ArgVRM8s[] = {RISCV::V8M8, RISCV::V16M8}; +ArrayRef<MCPhysReg> RISCV::getArgGPRs() { + static const MCPhysReg ArgGPRs[] = {RISCV::X10, RISCV::X11, RISCV::X12, + RISCV::X13, RISCV::X14, RISCV::X15, + RISCV::X16, RISCV::X17}; + + return ArrayRef(ArgGPRs); +} + // Pass a 2*XLEN argument that has been split into two XLEN values through // registers or the stack as necessary. static bool CC_RISCVAssign2XLen(unsigned XLen, CCState &State, CCValAssign VA1, @@ -14633,6 +16883,7 @@ static bool CC_RISCVAssign2XLen(unsigned XLen, CCState &State, CCValAssign VA1, MVT ValVT2, MVT LocVT2, ISD::ArgFlagsTy ArgFlags2) { unsigned XLenInBytes = XLen / 8; + ArrayRef<MCPhysReg> ArgGPRs = RISCV::getArgGPRs(); if (Register Reg = State.AllocateReg(ArgGPRs)) { // At least one half can be passed via register. State.addLoc(CCValAssign::getReg(VA1.getValNo(), VA1.getValVT(), Reg, @@ -14753,6 +17004,8 @@ bool RISCV::CC_RISCV(const DataLayout &DL, RISCVABI::ABI ABI, unsigned ValNo, LocInfo = CCValAssign::BCvt; } + ArrayRef<MCPhysReg> ArgGPRs = RISCV::getArgGPRs(); + // If this is a variadic argument, the RISC-V calling convention requires // that it is assigned an 'even' or 'aligned' register if it has 8-byte // alignment (RV32) or 16-byte alignment (RV64). An aligned register should @@ -14779,23 +17032,29 @@ bool RISCV::CC_RISCV(const DataLayout &DL, RISCVABI::ABI ABI, unsigned ValNo, // Handle passing f64 on RV32D with a soft float ABI or when floating point // registers are exhausted. if (UseGPRForF64 && XLen == 32 && ValVT == MVT::f64) { - assert(!ArgFlags.isSplit() && PendingLocs.empty() && - "Can't lower f64 if it is split"); + assert(PendingLocs.empty() && "Can't lower f64 if it is split"); // Depending on available argument GPRS, f64 may be passed in a pair of // GPRs, split between a GPR and the stack, or passed completely on the // stack. LowerCall/LowerFormalArguments/LowerReturn must recognise these // cases. 
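// A minimal standalone model (plain C++, not LLVM's CCState) of how a value
// split into two XLEN halves is assigned on RV32: each half takes the next
// free argument GPR from a0..a7, and once those run out the remaining half
// goes to the stack. The register pool and return encoding are illustrative.
#include <cassert>
#include <string>
#include <vector>

std::vector<std::string> assignTwoXLenHalves(unsigned NextFreeGPR) {
  static const char *ArgGPRs[] = {"a0", "a1", "a2", "a3",
                                  "a4", "a5", "a6", "a7"};
  std::vector<std::string> Locs;
  for (int Half = 0; Half < 2; ++Half) {
    if (NextFreeGPR < 8)
      Locs.push_back(ArgGPRs[NextFreeGPR++]);
    else
      Locs.push_back("stack"); // remaining half is passed on the stack
  }
  return Locs;
}

int main() {
  assert(assignTwoXLenHalves(0) == (std::vector<std::string>{"a0", "a1"}));
  assert(assignTwoXLenHalves(7) == (std::vector<std::string>{"a7", "stack"}));
  assert(assignTwoXLenHalves(8) == (std::vector<std::string>{"stack", "stack"}));
  return 0;
}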
Register Reg = State.AllocateReg(ArgGPRs); - LocVT = MVT::i32; if (!Reg) { unsigned StackOffset = State.AllocateStack(8, Align(8)); State.addLoc( CCValAssign::getMem(ValNo, ValVT, StackOffset, LocVT, LocInfo)); return false; } - if (!State.AllocateReg(ArgGPRs)) - State.AllocateStack(4, Align(4)); - State.addLoc(CCValAssign::getReg(ValNo, ValVT, Reg, LocVT, LocInfo)); + LocVT = MVT::i32; + State.addLoc(CCValAssign::getCustomReg(ValNo, ValVT, Reg, LocVT, LocInfo)); + Register HiReg = State.AllocateReg(ArgGPRs); + if (HiReg) { + State.addLoc( + CCValAssign::getCustomReg(ValNo, ValVT, HiReg, LocVT, LocInfo)); + } else { + unsigned StackOffset = State.AllocateStack(4, Align(4)); + State.addLoc( + CCValAssign::getCustomMem(ValNo, ValVT, StackOffset, LocVT, LocInfo)); + } return false; } @@ -14996,12 +17255,18 @@ static SDValue convertLocVTToValVT(SelectionDAG &DAG, SDValue Val, break; case CCValAssign::BCvt: if (VA.getLocVT().isInteger() && - (VA.getValVT() == MVT::f16 || VA.getValVT() == MVT::bf16)) + (VA.getValVT() == MVT::f16 || VA.getValVT() == MVT::bf16)) { Val = DAG.getNode(RISCVISD::FMV_H_X, DL, VA.getValVT(), Val); - else if (VA.getLocVT() == MVT::i64 && VA.getValVT() == MVT::f32) - Val = DAG.getNode(RISCVISD::FMV_W_X_RV64, DL, MVT::f32, Val); - else + } else if (VA.getLocVT() == MVT::i64 && VA.getValVT() == MVT::f32) { + if (RV64LegalI32) { + Val = DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, Val); + Val = DAG.getNode(ISD::BITCAST, DL, MVT::f32, Val); + } else { + Val = DAG.getNode(RISCVISD::FMV_W_X_RV64, DL, MVT::f32, Val); + } + } else { Val = DAG.getNode(ISD::BITCAST, DL, VA.getValVT(), Val); + } break; } return Val; @@ -15055,13 +17320,19 @@ static SDValue convertValVTToLocVT(SelectionDAG &DAG, SDValue Val, Val = convertToScalableVector(LocVT, Val, DAG, Subtarget); break; case CCValAssign::BCvt: - if (VA.getLocVT().isInteger() && - (VA.getValVT() == MVT::f16 || VA.getValVT() == MVT::bf16)) - Val = DAG.getNode(RISCVISD::FMV_X_ANYEXTH, DL, VA.getLocVT(), Val); - else if (VA.getLocVT() == MVT::i64 && VA.getValVT() == MVT::f32) - Val = DAG.getNode(RISCVISD::FMV_X_ANYEXTW_RV64, DL, MVT::i64, Val); - else + if (LocVT.isInteger() && + (VA.getValVT() == MVT::f16 || VA.getValVT() == MVT::bf16)) { + Val = DAG.getNode(RISCVISD::FMV_X_ANYEXTH, DL, LocVT, Val); + } else if (LocVT == MVT::i64 && VA.getValVT() == MVT::f32) { + if (RV64LegalI32) { + Val = DAG.getNode(ISD::BITCAST, DL, MVT::i32, Val); + Val = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i64, Val); + } else { + Val = DAG.getNode(RISCVISD::FMV_X_ANYEXTW_RV64, DL, MVT::i64, Val); + } + } else { Val = DAG.getNode(ISD::BITCAST, DL, LocVT, Val); + } break; } return Val; @@ -15104,38 +17375,32 @@ static SDValue unpackFromMemLoc(SelectionDAG &DAG, SDValue Chain, } static SDValue unpackF64OnRV32DSoftABI(SelectionDAG &DAG, SDValue Chain, - const CCValAssign &VA, const SDLoc &DL) { + const CCValAssign &VA, + const CCValAssign &HiVA, + const SDLoc &DL) { assert(VA.getLocVT() == MVT::i32 && VA.getValVT() == MVT::f64 && "Unexpected VA"); MachineFunction &MF = DAG.getMachineFunction(); MachineFrameInfo &MFI = MF.getFrameInfo(); MachineRegisterInfo &RegInfo = MF.getRegInfo(); - if (VA.isMemLoc()) { - // f64 is passed on the stack. 
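// A standalone sketch (plain C++) of the RV64LegalI32 path above for an f32
// that arrives in an i64 GPR: truncate the i64 to i32, then bitcast those 32
// bits to float, instead of emitting the FMV_W_X_RV64 node. memcpy stands in
// for ISD::BITCAST.
#include <cassert>
#include <cstdint>
#include <cstring>

float f32FromI64Loc(uint64_t GprBits) {
  uint32_t Lo = static_cast<uint32_t>(GprBits); // ISD::TRUNCATE to i32
  float F;
  std::memcpy(&F, &Lo, sizeof(F));              // ISD::BITCAST i32 -> f32
  return F;
}

int main() {
  assert(f32FromI64Loc(0x3f800000ULL) == 1.0f);          // bit pattern of 1.0f
  assert(f32FromI64Loc(0xdeadbeef3f800000ULL) == 1.0f);  // upper bits ignored
  return 0;
}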
- int FI = - MFI.CreateFixedObject(8, VA.getLocMemOffset(), /*IsImmutable=*/true); - SDValue FIN = DAG.getFrameIndex(FI, MVT::i32); - return DAG.getLoad(MVT::f64, DL, Chain, FIN, - MachinePointerInfo::getFixedStack(MF, FI)); - } - assert(VA.isRegLoc() && "Expected register VA assignment"); Register LoVReg = RegInfo.createVirtualRegister(&RISCV::GPRRegClass); RegInfo.addLiveIn(VA.getLocReg(), LoVReg); SDValue Lo = DAG.getCopyFromReg(Chain, DL, LoVReg, MVT::i32); SDValue Hi; - if (VA.getLocReg() == RISCV::X17) { + if (HiVA.isMemLoc()) { // Second half of f64 is passed on the stack. - int FI = MFI.CreateFixedObject(4, 0, /*IsImmutable=*/true); + int FI = MFI.CreateFixedObject(4, HiVA.getLocMemOffset(), + /*IsImmutable=*/true); SDValue FIN = DAG.getFrameIndex(FI, MVT::i32); Hi = DAG.getLoad(MVT::i32, DL, Chain, FIN, MachinePointerInfo::getFixedStack(MF, FI)); } else { // Second half of f64 is passed in another GPR. Register HiVReg = RegInfo.createVirtualRegister(&RISCV::GPRRegClass); - RegInfo.addLiveIn(VA.getLocReg() + 1, HiVReg); + RegInfo.addLiveIn(HiVA.getLocReg(), HiVReg); Hi = DAG.getCopyFromReg(Chain, DL, HiVReg, MVT::i32); } return DAG.getNode(RISCVISD::BuildPairF64, DL, MVT::f64, Lo, Hi); @@ -15340,6 +17605,8 @@ SDValue RISCVTargetLowering::LowerFormalArguments( report_fatal_error("Unsupported calling convention"); case CallingConv::C: case CallingConv::Fast: + case CallingConv::SPIR_KERNEL: + case CallingConv::GRAAL: break; case CallingConv::GHC: if (!Subtarget.hasStdExtFOrZfinx() || !Subtarget.hasStdExtDOrZdinx()) @@ -15378,15 +17645,16 @@ SDValue RISCVTargetLowering::LowerFormalArguments( CallConv == CallingConv::Fast ? RISCV::CC_RISCV_FastCC : RISCV::CC_RISCV); - for (unsigned i = 0, e = ArgLocs.size(); i != e; ++i) { + for (unsigned i = 0, e = ArgLocs.size(), InsIdx = 0; i != e; ++i, ++InsIdx) { CCValAssign &VA = ArgLocs[i]; SDValue ArgValue; // Passing f64 on RV32D with a soft float ABI must be handled as a special // case. - if (VA.getLocVT() == MVT::i32 && VA.getValVT() == MVT::f64) - ArgValue = unpackF64OnRV32DSoftABI(DAG, Chain, VA, DL); - else if (VA.isRegLoc()) - ArgValue = unpackFromRegLoc(DAG, Chain, VA, DL, Ins[i], *this); + if (VA.getLocVT() == MVT::i32 && VA.getValVT() == MVT::f64) { + assert(VA.needsCustom()); + ArgValue = unpackF64OnRV32DSoftABI(DAG, Chain, VA, ArgLocs[++i], DL); + } else if (VA.isRegLoc()) + ArgValue = unpackFromRegLoc(DAG, Chain, VA, DL, Ins[InsIdx], *this); else ArgValue = unpackFromMemLoc(DAG, Chain, VA, DL); @@ -15398,12 +17666,12 @@ SDValue RISCVTargetLowering::LowerFormalArguments( // stores are relative to that. 
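// A standalone model (plain C++) of what RISCVISD::BuildPairF64 computes in
// the code above: the f64 is reassembled from the low half (first GPR) and
// the high half (second GPR or a stack slot). memcpy stands in for the
// bitcast; the constants are ordinary IEEE-754 double bit patterns.
#include <cassert>
#include <cstdint>
#include <cstring>

double buildPairF64(uint32_t Lo, uint32_t Hi) {
  uint64_t Bits = (static_cast<uint64_t>(Hi) << 32) | Lo;
  double D;
  std::memcpy(&D, &Bits, sizeof(D));
  return D;
}

int main() {
  assert(buildPairF64(0x00000000u, 0x3ff00000u) == 1.0);
  assert(buildPairF64(0x00000000u, 0xc0000000u) == -2.0);
  return 0;
}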
InVals.push_back(DAG.getLoad(VA.getValVT(), DL, Chain, ArgValue, MachinePointerInfo())); - unsigned ArgIndex = Ins[i].OrigArgIndex; - unsigned ArgPartOffset = Ins[i].PartOffset; + unsigned ArgIndex = Ins[InsIdx].OrigArgIndex; + unsigned ArgPartOffset = Ins[InsIdx].PartOffset; assert(VA.getValVT().isVector() || ArgPartOffset == 0); - while (i + 1 != e && Ins[i + 1].OrigArgIndex == ArgIndex) { + while (i + 1 != e && Ins[InsIdx + 1].OrigArgIndex == ArgIndex) { CCValAssign &PartVA = ArgLocs[i + 1]; - unsigned PartOffset = Ins[i + 1].PartOffset - ArgPartOffset; + unsigned PartOffset = Ins[InsIdx + 1].PartOffset - ArgPartOffset; SDValue Offset = DAG.getIntPtrConstant(PartOffset, DL); if (PartVA.getValVT().isScalableVector()) Offset = DAG.getNode(ISD::VSCALE, DL, XLenVT, Offset); @@ -15411,6 +17679,7 @@ SDValue RISCVTargetLowering::LowerFormalArguments( InVals.push_back(DAG.getLoad(PartVA.getValVT(), DL, Chain, Address, MachinePointerInfo())); ++i; + ++InsIdx; } continue; } @@ -15422,57 +17691,56 @@ SDValue RISCVTargetLowering::LowerFormalArguments( MF.getInfo<RISCVMachineFunctionInfo>()->setIsVectorCall(); if (IsVarArg) { - ArrayRef<MCPhysReg> ArgRegs = ArrayRef(ArgGPRs); + ArrayRef<MCPhysReg> ArgRegs = RISCV::getArgGPRs(); unsigned Idx = CCInfo.getFirstUnallocated(ArgRegs); const TargetRegisterClass *RC = &RISCV::GPRRegClass; MachineFrameInfo &MFI = MF.getFrameInfo(); MachineRegisterInfo &RegInfo = MF.getRegInfo(); RISCVMachineFunctionInfo *RVFI = MF.getInfo<RISCVMachineFunctionInfo>(); - // Offset of the first variable argument from stack pointer, and size of - // the vararg save area. For now, the varargs save area is either zero or - // large enough to hold a0-a7. - int VaArgOffset, VarArgsSaveSize; + // Size of the vararg save area. For now, the varargs save area is either + // zero or large enough to hold a0-a7. + int VarArgsSaveSize = XLenInBytes * (ArgRegs.size() - Idx); + int FI; // If all registers are allocated, then all varargs must be passed on the // stack and we don't need to save any argregs. - if (ArgRegs.size() == Idx) { - VaArgOffset = CCInfo.getStackSize(); - VarArgsSaveSize = 0; + if (VarArgsSaveSize == 0) { + int VaArgOffset = CCInfo.getStackSize(); + FI = MFI.CreateFixedObject(XLenInBytes, VaArgOffset, true); } else { - VarArgsSaveSize = XLenInBytes * (ArgRegs.size() - Idx); - VaArgOffset = -VarArgsSaveSize; + int VaArgOffset = -VarArgsSaveSize; + FI = MFI.CreateFixedObject(VarArgsSaveSize, VaArgOffset, true); + + // If saving an odd number of registers then create an extra stack slot to + // ensure that the frame pointer is 2*XLEN-aligned, which in turn ensures + // offsets to even-numbered registered remain 2*XLEN-aligned. + if (Idx % 2) { + MFI.CreateFixedObject( + XLenInBytes, VaArgOffset - static_cast<int>(XLenInBytes), true); + VarArgsSaveSize += XLenInBytes; + } + + SDValue FIN = DAG.getFrameIndex(FI, PtrVT); + + // Copy the integer registers that may have been used for passing varargs + // to the vararg save area. + for (unsigned I = Idx; I < ArgRegs.size(); ++I) { + const Register Reg = RegInfo.createVirtualRegister(RC); + RegInfo.addLiveIn(ArgRegs[I], Reg); + SDValue ArgValue = DAG.getCopyFromReg(Chain, DL, Reg, XLenVT); + SDValue Store = DAG.getStore( + Chain, DL, ArgValue, FIN, + MachinePointerInfo::getFixedStack(MF, FI, (I - Idx) * XLenInBytes)); + OutChains.push_back(Store); + FIN = + DAG.getMemBasePlusOffset(FIN, TypeSize::getFixed(XLenInBytes), DL); + } } // Record the frame index of the first variable argument // which is a value necessary to VASTART. 
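// A standalone sketch (plain C++) of the vararg save-area sizing above: the
// area holds the still-unallocated argument GPRs (a0..a7), and when the first
// unallocated register has an odd index an extra XLEN-sized slot is added so
// the area, and hence the frame pointer, stays 2*XLEN-aligned. Values are
// illustrative.
#include <cassert>

struct VarArgArea {
  int SaveSize;    // bytes reserved for the save area
  int FirstOffset; // offset of the first saved register
};

VarArgArea computeVarArgArea(unsigned FirstUnallocatedGPR,
                             unsigned XLenInBytes) {
  const unsigned NumArgGPRs = 8; // a0..a7
  VarArgArea A;
  A.SaveSize = XLenInBytes * (NumArgGPRs - FirstUnallocatedGPR);
  A.FirstOffset = -A.SaveSize;
  if (FirstUnallocatedGPR % 2) // pad to keep 2*XLEN alignment
    A.SaveSize += XLenInBytes;
  return A;
}

int main() {
  // RV32 (XLEN = 4 bytes), three fixed args in a0..a2: save a3..a7 (20 bytes)
  // plus 4 bytes of padding so the total stays 8-byte aligned.
  VarArgArea A = computeVarArgArea(3, 4);
  assert(A.SaveSize == 24 && A.FirstOffset == -20);
  // All argument registers consumed by fixed args: nothing to save.
  assert(computeVarArgArea(8, 4).SaveSize == 0);
  return 0;
}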
- int FI = MFI.CreateFixedObject(XLenInBytes, VaArgOffset, true); RVFI->setVarArgsFrameIndex(FI); - - // If saving an odd number of registers then create an extra stack slot to - // ensure that the frame pointer is 2*XLEN-aligned, which in turn ensures - // offsets to even-numbered registered remain 2*XLEN-aligned. - if (Idx % 2) { - MFI.CreateFixedObject(XLenInBytes, VaArgOffset - (int)XLenInBytes, true); - VarArgsSaveSize += XLenInBytes; - } - - // Copy the integer registers that may have been used for passing varargs - // to the vararg save area. - for (unsigned I = Idx; I < ArgRegs.size(); - ++I, VaArgOffset += XLenInBytes) { - const Register Reg = RegInfo.createVirtualRegister(RC); - RegInfo.addLiveIn(ArgRegs[I], Reg); - SDValue ArgValue = DAG.getCopyFromReg(Chain, DL, Reg, XLenVT); - FI = MFI.CreateFixedObject(XLenInBytes, VaArgOffset, true); - SDValue PtrOff = DAG.getFrameIndex(FI, getPointerTy(DAG.getDataLayout())); - SDValue Store = DAG.getStore(Chain, DL, ArgValue, PtrOff, - MachinePointerInfo::getFixedStack(MF, FI)); - cast<StoreSDNode>(Store.getNode()) - ->getMemOperand() - ->setValue((Value *)nullptr); - OutChains.push_back(Store); - } RVFI->setVarArgsSaveSize(VarArgsSaveSize); } @@ -15626,15 +17894,16 @@ SDValue RISCVTargetLowering::LowerCall(CallLoweringInfo &CLI, SmallVector<std::pair<Register, SDValue>, 8> RegsToPass; SmallVector<SDValue, 8> MemOpChains; SDValue StackPtr; - for (unsigned i = 0, j = 0, e = ArgLocs.size(); i != e; ++i) { + for (unsigned i = 0, j = 0, e = ArgLocs.size(), OutIdx = 0; i != e; + ++i, ++OutIdx) { CCValAssign &VA = ArgLocs[i]; - SDValue ArgValue = OutVals[i]; - ISD::ArgFlagsTy Flags = Outs[i].Flags; + SDValue ArgValue = OutVals[OutIdx]; + ISD::ArgFlagsTy Flags = Outs[OutIdx].Flags; // Handle passing f64 on RV32D with a soft float ABI as a special case. - bool IsF64OnRV32DSoftABI = - VA.getLocVT() == MVT::i32 && VA.getValVT() == MVT::f64; - if (IsF64OnRV32DSoftABI && VA.isRegLoc()) { + if (VA.getLocVT() == MVT::i32 && VA.getValVT() == MVT::f64) { + assert(VA.isRegLoc() && "Expected register VA assignment"); + assert(VA.needsCustom()); SDValue SplitF64 = DAG.getNode( RISCVISD::SplitF64, DL, DAG.getVTList(MVT::i32, MVT::i32), ArgValue); SDValue Lo = SplitF64.getValue(0); @@ -15643,32 +17912,33 @@ SDValue RISCVTargetLowering::LowerCall(CallLoweringInfo &CLI, Register RegLo = VA.getLocReg(); RegsToPass.push_back(std::make_pair(RegLo, Lo)); - if (RegLo == RISCV::X17) { + // Get the CCValAssign for the Hi part. + CCValAssign &HiVA = ArgLocs[++i]; + + if (HiVA.isMemLoc()) { // Second half of f64 is passed on the stack. - // Work out the address of the stack slot. if (!StackPtr.getNode()) StackPtr = DAG.getCopyFromReg(Chain, DL, RISCV::X2, PtrVT); + SDValue Address = + DAG.getNode(ISD::ADD, DL, PtrVT, StackPtr, + DAG.getIntPtrConstant(HiVA.getLocMemOffset(), DL)); // Emit the store. MemOpChains.push_back( - DAG.getStore(Chain, DL, Hi, StackPtr, MachinePointerInfo())); + DAG.getStore(Chain, DL, Hi, Address, MachinePointerInfo())); } else { // Second half of f64 is passed in another GPR. - assert(RegLo < RISCV::X31 && "Invalid register pair"); - Register RegHigh = RegLo + 1; + Register RegHigh = HiVA.getLocReg(); RegsToPass.push_back(std::make_pair(RegHigh, Hi)); } continue; } - // IsF64OnRV32DSoftABI && VA.isMemLoc() is handled below in the same way - // as any other MemLoc. - // Promote the value if needed. // For now, only handle fully promoted and indirect arguments. 
if (VA.getLocInfo() == CCValAssign::Indirect) { // Store the argument in a stack slot and pass its address. Align StackAlign = - std::max(getPrefTypeAlign(Outs[i].ArgVT, DAG), + std::max(getPrefTypeAlign(Outs[OutIdx].ArgVT, DAG), getPrefTypeAlign(ArgValue.getValueType(), DAG)); TypeSize StoredSize = ArgValue.getValueType().getStoreSize(); // If the original argument was split (e.g. i128), we need @@ -15676,16 +17946,16 @@ SDValue RISCVTargetLowering::LowerCall(CallLoweringInfo &CLI, // Vectors may be partly split to registers and partly to the stack, in // which case the base address is partly offset and subsequent stores are // relative to that. - unsigned ArgIndex = Outs[i].OrigArgIndex; - unsigned ArgPartOffset = Outs[i].PartOffset; + unsigned ArgIndex = Outs[OutIdx].OrigArgIndex; + unsigned ArgPartOffset = Outs[OutIdx].PartOffset; assert(VA.getValVT().isVector() || ArgPartOffset == 0); // Calculate the total size to store. We don't have access to what we're // actually storing other than performing the loop and collecting the // info. SmallVector<std::pair<SDValue, SDValue>> Parts; - while (i + 1 != e && Outs[i + 1].OrigArgIndex == ArgIndex) { - SDValue PartValue = OutVals[i + 1]; - unsigned PartOffset = Outs[i + 1].PartOffset - ArgPartOffset; + while (i + 1 != e && Outs[OutIdx + 1].OrigArgIndex == ArgIndex) { + SDValue PartValue = OutVals[OutIdx + 1]; + unsigned PartOffset = Outs[OutIdx + 1].PartOffset - ArgPartOffset; SDValue Offset = DAG.getIntPtrConstant(PartOffset, DL); EVT PartVT = PartValue.getValueType(); if (PartVT.isScalableVector()) @@ -15694,6 +17964,7 @@ SDValue RISCVTargetLowering::LowerCall(CallLoweringInfo &CLI, StackAlign = std::max(StackAlign, getPrefTypeAlign(PartVT, DAG)); Parts.push_back(std::make_pair(PartValue, Offset)); ++i; + ++OutIdx; } SDValue SpillSlot = DAG.CreateStackTemporary(StoredSize, StackAlign); int FI = cast<FrameIndexSDNode>(SpillSlot)->getIndex(); @@ -15835,7 +18106,8 @@ SDValue RISCVTargetLowering::LowerCall(CallLoweringInfo &CLI, analyzeInputArgs(MF, RetCCInfo, Ins, /*IsRet=*/true, RISCV::CC_RISCV); // Copy all of the result registers out of their specified physreg. - for (auto &VA : RVLocs) { + for (unsigned i = 0, e = RVLocs.size(); i != e; ++i) { + auto &VA = RVLocs[i]; // Copy the value out SDValue RetValue = DAG.getCopyFromReg(Chain, DL, VA.getLocReg(), VA.getLocVT(), Glue); @@ -15844,9 +18116,9 @@ SDValue RISCVTargetLowering::LowerCall(CallLoweringInfo &CLI, Glue = RetValue.getValue(2); if (VA.getLocVT() == MVT::i32 && VA.getValVT() == MVT::f64) { - assert(VA.getLocReg() == ArgGPRs[0] && "Unexpected reg assignment"); - SDValue RetValue2 = - DAG.getCopyFromReg(Chain, DL, ArgGPRs[1], MVT::i32, Glue); + assert(VA.needsCustom()); + SDValue RetValue2 = DAG.getCopyFromReg(Chain, DL, RVLocs[++i].getLocReg(), + MVT::i32, Glue); Chain = RetValue2.getValue(1); Glue = RetValue2.getValue(2); RetValue = DAG.getNode(RISCVISD::BuildPairF64, DL, MVT::f64, RetValue, @@ -15909,21 +18181,21 @@ RISCVTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv, SmallVector<SDValue, 4> RetOps(1, Chain); // Copy the result values into the output registers. - for (unsigned i = 0, e = RVLocs.size(); i < e; ++i) { - SDValue Val = OutVals[i]; + for (unsigned i = 0, e = RVLocs.size(), OutIdx = 0; i < e; ++i, ++OutIdx) { + SDValue Val = OutVals[OutIdx]; CCValAssign &VA = RVLocs[i]; assert(VA.isRegLoc() && "Can only return in registers!"); if (VA.getLocVT() == MVT::i32 && VA.getValVT() == MVT::f64) { // Handle returning f64 on RV32D with a soft float ABI. 
assert(VA.isRegLoc() && "Expected return via registers"); + assert(VA.needsCustom()); SDValue SplitF64 = DAG.getNode(RISCVISD::SplitF64, DL, DAG.getVTList(MVT::i32, MVT::i32), Val); SDValue Lo = SplitF64.getValue(0); SDValue Hi = SplitF64.getValue(1); Register RegLo = VA.getLocReg(); - assert(RegLo < RISCV::X31 && "Invalid register pair"); - Register RegHi = RegLo + 1; + Register RegHi = RVLocs[++i].getLocReg(); if (STI.isRegisterReservedByUser(RegLo) || STI.isRegisterReservedByUser(RegHi)) @@ -16061,10 +18333,7 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const { NODE_NAME_CASE(ADD_LO) NODE_NAME_CASE(HI) NODE_NAME_CASE(LLA) - NODE_NAME_CASE(LGA) NODE_NAME_CASE(ADD_TPREL) - NODE_NAME_CASE(LA_TLS_IE) - NODE_NAME_CASE(LA_TLS_GD) NODE_NAME_CASE(MULHSU) NODE_NAME_CASE(SLLW) NODE_NAME_CASE(SRAW) @@ -16091,7 +18360,7 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const { NODE_NAME_CASE(FP_ROUND_BF16) NODE_NAME_CASE(FP_EXTEND_BF16) NODE_NAME_CASE(FROUND) - NODE_NAME_CASE(FPCLASS) + NODE_NAME_CASE(FCLASS) NODE_NAME_CASE(FMAX) NODE_NAME_CASE(FMIN) NODE_NAME_CASE(READ_CYCLE_WIDE) @@ -16153,6 +18422,8 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const { NODE_NAME_CASE(SREM_VL) NODE_NAME_CASE(SRA_VL) NODE_NAME_CASE(SRL_VL) + NODE_NAME_CASE(ROTL_VL) + NODE_NAME_CASE(ROTR_VL) NODE_NAME_CASE(SUB_VL) NODE_NAME_CASE(UDIV_VL) NODE_NAME_CASE(UREM_VL) @@ -16187,8 +18458,8 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const { NODE_NAME_CASE(CTLZ_VL) NODE_NAME_CASE(CTTZ_VL) NODE_NAME_CASE(CTPOP_VL) - NODE_NAME_CASE(FMINNUM_VL) - NODE_NAME_CASE(FMAXNUM_VL) + NODE_NAME_CASE(VFMIN_VL) + NODE_NAME_CASE(VFMAX_VL) NODE_NAME_CASE(MULHS_VL) NODE_NAME_CASE(MULHU_VL) NODE_NAME_CASE(VFCVT_RTZ_X_F_VL) @@ -16235,6 +18506,7 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const { NODE_NAME_CASE(VWADDU_W_VL) NODE_NAME_CASE(VWSUB_W_VL) NODE_NAME_CASE(VWSUBU_W_VL) + NODE_NAME_CASE(VWSLL_VL) NODE_NAME_CASE(VFWMUL_VL) NODE_NAME_CASE(VFWADD_VL) NODE_NAME_CASE(VFWSUB_VL) @@ -16308,6 +18580,12 @@ RISCVTargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI, // TODO: Support fixed vectors up to XLen for P extension? if (VT.isVector()) break; + if (VT == MVT::f16 && Subtarget.hasStdExtZhinxOrZhinxmin()) + return std::make_pair(0U, &RISCV::GPRF16RegClass); + if (VT == MVT::f32 && Subtarget.hasStdExtZfinx()) + return std::make_pair(0U, &RISCV::GPRF32RegClass); + if (VT == MVT::f64 && Subtarget.hasStdExtZdinx() && !Subtarget.is64Bit()) + return std::make_pair(0U, &RISCV::GPRPF64RegClass); return std::make_pair(0U, &RISCV::GPRNoX0RegClass); case 'f': if (Subtarget.hasStdExtZfhOrZfhmin() && VT == MVT::f16) @@ -16495,13 +18773,13 @@ RISCVTargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI, return Res; } -unsigned +InlineAsm::ConstraintCode RISCVTargetLowering::getInlineAsmMemConstraint(StringRef ConstraintCode) const { // Currently only support length 1 constraints. 
if (ConstraintCode.size() == 1) { switch (ConstraintCode[0]) { case 'A': - return InlineAsm::Constraint_A; + return InlineAsm::ConstraintCode::A; default: break; } @@ -16511,10 +18789,10 @@ RISCVTargetLowering::getInlineAsmMemConstraint(StringRef ConstraintCode) const { } void RISCVTargetLowering::LowerAsmOperandForConstraint( - SDValue Op, std::string &Constraint, std::vector<SDValue> &Ops, + SDValue Op, StringRef Constraint, std::vector<SDValue> &Ops, SelectionDAG &DAG) const { // Currently only support length 1 constraints. - if (Constraint.length() == 1) { + if (Constraint.size() == 1) { switch (Constraint[0]) { case 'I': // Validate & create a 12-bit signed immediate operand. @@ -16575,8 +18853,11 @@ Instruction *RISCVTargetLowering::emitLeadingFence(IRBuilderBase &Builder, Instruction *RISCVTargetLowering::emitTrailingFence(IRBuilderBase &Builder, Instruction *Inst, AtomicOrdering Ord) const { - if (Subtarget.hasStdExtZtso()) + if (Subtarget.hasStdExtZtso()) { + if (isa<StoreInst>(Inst) && Ord == AtomicOrdering::SequentiallyConsistent) + return Builder.CreateFence(Ord); return nullptr; + } if (isa<LoadInst>(Inst) && isAcquireOrStronger(Ord)) return Builder.CreateFence(AtomicOrdering::Acquire); @@ -16660,6 +18941,22 @@ getIntrinsicForMaskedAtomicRMWBinOp(unsigned XLen, AtomicRMWInst::BinOp BinOp) { Value *RISCVTargetLowering::emitMaskedAtomicRMWIntrinsic( IRBuilderBase &Builder, AtomicRMWInst *AI, Value *AlignedAddr, Value *Incr, Value *Mask, Value *ShiftAmt, AtomicOrdering Ord) const { + // In the case of an atomicrmw xchg with a constant 0/-1 operand, replace + // the atomic instruction with an AtomicRMWInst::And/Or with appropriate + // mask, as this produces better code than the LR/SC loop emitted by + // int_riscv_masked_atomicrmw_xchg. + if (AI->getOperation() == AtomicRMWInst::Xchg && + isa<ConstantInt>(AI->getValOperand())) { + ConstantInt *CVal = cast<ConstantInt>(AI->getValOperand()); + if (CVal->isZero()) + return Builder.CreateAtomicRMW(AtomicRMWInst::And, AlignedAddr, + Builder.CreateNot(Mask, "Inv_Mask"), + AI->getAlign(), Ord); + if (CVal->isMinusOne()) + return Builder.CreateAtomicRMW(AtomicRMWInst::Or, AlignedAddr, Mask, + AI->getAlign(), Ord); + } + unsigned XLen = Subtarget.getXLen(); Value *Ordering = Builder.getIntN(XLen, static_cast<uint64_t>(AI->getOrdering())); @@ -16735,9 +19032,13 @@ Value *RISCVTargetLowering::emitMaskedAtomicCmpXchgIntrinsic( return Result; } -bool RISCVTargetLowering::shouldRemoveExtendFromGSIndex(EVT IndexVT, +bool RISCVTargetLowering::shouldRemoveExtendFromGSIndex(SDValue Extend, EVT DataVT) const { - return false; + // We have indexed loads for all legal index types. Indices are always + // zero extended + return Extend.getOpcode() == ISD::ZERO_EXTEND && + isTypeLegal(Extend.getValueType()) && + isTypeLegal(Extend.getOperand(0).getValueType()); } bool RISCVTargetLowering::shouldConvertFpToSat(unsigned Op, EVT FPVT, @@ -16993,8 +19294,8 @@ bool RISCVTargetLowering::allowsMisalignedMemoryAccesses( unsigned *Fast) const { if (!VT.isVector()) { if (Fast) - *Fast = Subtarget.enableUnalignedScalarMem(); - return Subtarget.enableUnalignedScalarMem(); + *Fast = Subtarget.hasFastUnalignedAccess(); + return Subtarget.hasFastUnalignedAccess(); } // All vector implementations must support element alignment @@ -17010,8 +19311,51 @@ bool RISCVTargetLowering::allowsMisalignedMemoryAccesses( // misaligned accesses. TODO: Work through the codegen implications of // allowing such accesses to be formed, and considered fast. 
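// A standalone illustration (plain, non-atomic C++) of the rewrite above:
// within a masked subword, atomically exchanging in the constant 0 is the
// same as AND-ing the containing word with ~Mask, and exchanging in the
// constant -1 is the same as OR-ing it with Mask, so no LR/SC loop is needed.
// Values are illustrative; only the bit identity is shown.
#include <cassert>
#include <cstdint>

int main() {
  uint32_t Word = 0x12345678; // aligned word containing the i8 field
  uint32_t Mask = 0x0000FF00; // the byte being operated on

  // xchg field, 0  ==  Word & ~Mask
  uint32_t ExchangedZero = (Word & ~Mask) | (0u & Mask);
  assert(ExchangedZero == (Word & ~Mask));

  // xchg field, -1  ==  Word | Mask
  uint32_t ExchangedOnes = (Word & ~Mask) | (0xFFFFFFFFu & Mask);
  assert(ExchangedOnes == (Word | Mask));
  return 0;
}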
if (Fast) - *Fast = Subtarget.enableUnalignedVectorMem(); - return Subtarget.enableUnalignedVectorMem(); + *Fast = Subtarget.hasFastUnalignedAccess(); + return Subtarget.hasFastUnalignedAccess(); +} + + +EVT RISCVTargetLowering::getOptimalMemOpType(const MemOp &Op, + const AttributeList &FuncAttributes) const { + if (!Subtarget.hasVInstructions()) + return MVT::Other; + + if (FuncAttributes.hasFnAttr(Attribute::NoImplicitFloat)) + return MVT::Other; + + // We use LMUL1 memory operations here for a non-obvious reason. Our caller + // has an expansion threshold, and we want the number of hardware memory + // operations to correspond roughly to that threshold. LMUL>1 operations + // are typically expanded linearly internally, and thus correspond to more + // than one actual memory operation. Note that store merging and load + // combining will typically form larger LMUL operations from the LMUL1 + // operations emitted here, and that's okay because combining isn't + // introducing new memory operations; it's just merging existing ones. + const unsigned MinVLenInBytes = Subtarget.getRealMinVLen()/8; + if (Op.size() < MinVLenInBytes) + // TODO: Figure out short memops. For the moment, do the default thing + // which ends up using scalar sequences. + return MVT::Other; + + // Prefer i8 for non-zero memset as it allows us to avoid materializing + // a large scalar constant and instead use vmv.v.x/i to do the + // broadcast. For everything else, prefer ELenVT to minimize VL and thus + // maximize the chance we can encode the size in the vsetvli. + MVT ELenVT = MVT::getIntegerVT(Subtarget.getELen()); + MVT PreferredVT = (Op.isMemset() && !Op.isZeroMemset()) ? MVT::i8 : ELenVT; + + // Do we have sufficient alignment for our preferred VT? If not, revert + // to largest size allowed by our alignment criteria. 
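// A standalone sketch (plain C++) of the element-type heuristic described
// above: a non-zero memset prefers i8 so the splat value is cheap
// (vmv.v.x/vmv.v.i), everything else prefers an ELEN-wide element to minimize
// VL, and without fast unaligned access the element width is clamped to the
// known alignment. Parameter names are illustrative, not the LLVM API.
#include <cassert>

unsigned chooseEltBytes(bool IsMemset, bool IsZeroMemset, unsigned ELenBytes,
                        bool FastUnaligned, unsigned KnownAlignBytes) {
  unsigned Preferred = (IsMemset && !IsZeroMemset) ? 1 : ELenBytes;
  if (Preferred != 1 && !FastUnaligned && KnownAlignBytes < Preferred)
    Preferred = KnownAlignBytes; // largest size the alignment allows
  return Preferred;
}

int main() {
  // memset(p, 42, n): use i8 elements so 42 splats directly.
  assert(chooseEltBytes(true, false, 8, false, 16) == 1);
  // memcpy with 16-byte-aligned operands and ELEN=64: use i64 elements.
  assert(chooseEltBytes(false, false, 8, false, 16) == 8);
  // memcpy with only 2-byte alignment and slow unaligned access: clamp to i16.
  assert(chooseEltBytes(false, false, 8, false, 2) == 2);
  return 0;
}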
+ if (PreferredVT != MVT::i8 && !Subtarget.hasFastUnalignedAccess()) { + Align RequiredAlign(PreferredVT.getStoreSize()); + if (Op.isFixedDstAlign()) + RequiredAlign = std::min(RequiredAlign, Op.getDstAlign()); + if (Op.isMemcpy()) + RequiredAlign = std::min(RequiredAlign, Op.getSrcAlign()); + PreferredVT = MVT::getIntegerVT(RequiredAlign.value() * 8); + } + return MVT::getVectorVT(PreferredVT, MinVLenInBytes/PreferredVT.getStoreSize()); } bool RISCVTargetLowering::splitValueIntoRegisterParts( @@ -17136,10 +19480,8 @@ static Value *useTpOffset(IRBuilderBase &IRB, unsigned Offset) { Module *M = IRB.GetInsertBlock()->getParent()->getParent(); Function *ThreadPointerFunc = Intrinsic::getDeclaration(M, Intrinsic::thread_pointer); - return IRB.CreatePointerCast( - IRB.CreateConstGEP1_32(IRB.getInt8Ty(), - IRB.CreateCall(ThreadPointerFunc), Offset), - IRB.getInt8PtrTy()->getPointerTo(0)); + return IRB.CreateConstGEP1_32(IRB.getInt8Ty(), + IRB.CreateCall(ThreadPointerFunc), Offset); } Value *RISCVTargetLowering::getIRStackGuard(IRBuilderBase &IRB) const { @@ -17197,7 +19539,7 @@ bool RISCVTargetLowering::isLegalStridedLoadStore(EVT DataType, if (!isLegalElementTypeForRVV(ScalarType)) return false; - if (!Subtarget.enableUnalignedVectorMem() && + if (!Subtarget.hasFastUnalignedAccess() && Alignment < ScalarType.getStoreSize()) return false; @@ -17497,6 +19839,72 @@ bool RISCVTargetLowering::areTwoSDNodeTargetMMOFlagsMergeable( return getTargetMMOFlags(NodeX) == getTargetMMOFlags(NodeY); } +bool RISCVTargetLowering::isCtpopFast(EVT VT) const { + if (VT.isScalableVector()) + return isTypeLegal(VT) && Subtarget.hasStdExtZvbb(); + if (VT.isFixedLengthVector() && Subtarget.hasStdExtZvbb()) + return true; + return Subtarget.hasStdExtZbb() && + (VT == MVT::i32 || VT == MVT::i64 || VT.isFixedLengthVector()); +} + +unsigned RISCVTargetLowering::getCustomCtpopCost(EVT VT, + ISD::CondCode Cond) const { + return isCtpopFast(VT) ? 0 : 1; +} + +bool RISCVTargetLowering::fallBackToDAGISel(const Instruction &Inst) const { + // At the moment, the only scalable instruction GISel knows how to lower is + // ret with scalable argument. + + if (Inst.getType()->isScalableTy()) + return true; + + for (unsigned i = 0; i < Inst.getNumOperands(); ++i) + if (Inst.getOperand(i)->getType()->isScalableTy() && + !isa<ReturnInst>(&Inst)) + return true; + + if (const AllocaInst *AI = dyn_cast<AllocaInst>(&Inst)) { + if (AI->getAllocatedType()->isScalableTy()) + return true; + } + + return false; +} + +SDValue +RISCVTargetLowering::BuildSDIVPow2(SDNode *N, const APInt &Divisor, + SelectionDAG &DAG, + SmallVectorImpl<SDNode *> &Created) const { + AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes(); + if (isIntDivCheap(N->getValueType(0), Attr)) + return SDValue(N, 0); // Lower SDIV as SDIV + + // Only perform this transform if short forward branch opt is supported. + if (!Subtarget.hasShortForwardBranchOpt()) + return SDValue(); + EVT VT = N->getValueType(0); + if (!(VT == MVT::i32 || (VT == MVT::i64 && Subtarget.is64Bit()))) + return SDValue(); + + // Ensure 2**k-1 < 2048 so that we can just emit a single addi/addiw. 
+ if (Divisor.sgt(2048) || Divisor.slt(-2048)) + return SDValue(); + return TargetLowering::buildSDIVPow2WithCMov(N, Divisor, DAG, Created); +} + +bool RISCVTargetLowering::shouldFoldSelectWithSingleBitTest( + EVT VT, const APInt &AndMask) const { + if (Subtarget.hasStdExtZicond() || Subtarget.hasVendorXVentanaCondOps()) + return !Subtarget.hasStdExtZbs() && AndMask.ugt(1024); + return TargetLowering::shouldFoldSelectWithSingleBitTest(VT, AndMask); +} + +unsigned RISCVTargetLowering::getMinimumJumpTableEntries() const { + return Subtarget.getMinimumJumpTableEntries(); +} + namespace llvm::RISCVVIntrinsicsTable { #define GET_RISCVVIntrinsicsTable_IMPL |
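// A standalone sketch (plain C++) of the expansion BuildSDIVPow2 above steers
// toward: signed division by 2**k rounds toward zero, so a negative dividend
// needs (2**k - 1) added before the arithmetic shift. Keeping 2**k - 1 below
// 2048 means that addend fits a single addi/addiw immediate, and the add can
// be guarded by a short forward branch / conditional move.
#include <cassert>
#include <cstdint>

int64_t sdivPow2(int64_t X, unsigned K) {
  int64_t Bias = (int64_t(1) << K) - 1; // must stay < 2048 for a single addi
  if (X < 0)                            // the conditional-move part
    X += Bias;
  return X >> K;                        // arithmetic shift
}

int main() {
  assert(sdivPow2(7, 2) == 7 / 4);   // 1
  assert(sdivPow2(-7, 2) == -7 / 4); // -1, not -2: rounds toward zero
  assert(sdivPow2(-8, 2) == -8 / 4); // -2
  return 0;
}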
