| author | Dimitry Andric <dim@FreeBSD.org> | 2022-07-03 14:10:23 +0000 |
|---|---|---|
| committer | Dimitry Andric <dim@FreeBSD.org> | 2022-07-03 14:10:23 +0000 |
| commit | 145449b1e420787bb99721a429341fa6be3adfb6 (patch) | |
| tree | 1d56ae694a6de602e348dd80165cf881a36600ed /llvm/lib/Target/RISCV/RISCVISelLowering.cpp | |
| parent | ecbca9f5fb7d7613d2b94982c4825eb0d33d6842 (diff) | |
| download | src-145449b1e420787bb99721a429341fa6be3adfb6.tar.gz, src-145449b1e420787bb99721a429341fa6be3adfb6.zip | |
Vendor import of llvm-project main llvmorg-15-init-15358-g53dc0f107877.

Tag: vendor/llvm-project/llvmorg-15-init-15358-g53dc0f107877
Diffstat (limited to 'llvm/lib/Target/RISCV/RISCVISelLowering.cpp')
| -rw-r--r-- | llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 4004 |

1 file changed, 2692 insertions, 1312 deletions
```diff
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 97d24c8e9c0b..ff645dea4e7a 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -112,17 +112,24 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
   if (Subtarget.hasVInstructions()) {
     auto addRegClassForRVV = [this](MVT VT) {
+      // Disable the smallest fractional LMUL types if ELEN is less than
+      // RVVBitsPerBlock.
+      unsigned MinElts = RISCV::RVVBitsPerBlock / Subtarget.getELEN();
+      if (VT.getVectorMinNumElements() < MinElts)
+        return;
+
       unsigned Size = VT.getSizeInBits().getKnownMinValue();
-      assert(Size <= 512 && isPowerOf2_32(Size));
       const TargetRegisterClass *RC;
-      if (Size <= 64)
+      if (Size <= RISCV::RVVBitsPerBlock)
         RC = &RISCV::VRRegClass;
-      else if (Size == 128)
+      else if (Size == 2 * RISCV::RVVBitsPerBlock)
         RC = &RISCV::VRM2RegClass;
-      else if (Size == 256)
+      else if (Size == 4 * RISCV::RVVBitsPerBlock)
         RC = &RISCV::VRM4RegClass;
-      else
+      else if (Size == 8 * RISCV::RVVBitsPerBlock)
         RC = &RISCV::VRM8RegClass;
+      else
+        llvm_unreachable("Unexpected size");
 
       addRegisterClass(VT, RC);
     };
@@ -170,8 +177,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
 
   setStackPointerRegisterToSaveRestore(RISCV::X2);
 
-  for (auto N : {ISD::EXTLOAD, ISD::SEXTLOAD, ISD::ZEXTLOAD})
-    setLoadExtAction(N, XLenVT, MVT::i1, Promote);
+  setLoadExtAction({ISD::EXTLOAD, ISD::SEXTLOAD, ISD::ZEXTLOAD}, XLenVT,
+                   MVT::i1, Promote);
 
   // TODO: add all necessary setOperationAction calls.
   setOperationAction(ISD::DYNAMIC_STACKALLOC, XLenVT, Expand);
@@ -181,100 +188,75 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
   setOperationAction(ISD::BRCOND, MVT::Other, Custom);
   setOperationAction(ISD::SELECT_CC, XLenVT, Expand);
 
-  setOperationAction(ISD::STACKSAVE, MVT::Other, Expand);
-  setOperationAction(ISD::STACKRESTORE, MVT::Other, Expand);
+  setOperationAction({ISD::STACKSAVE, ISD::STACKRESTORE}, MVT::Other, Expand);
 
   setOperationAction(ISD::VASTART, MVT::Other, Custom);
-  setOperationAction(ISD::VAARG, MVT::Other, Expand);
-  setOperationAction(ISD::VACOPY, MVT::Other, Expand);
-  setOperationAction(ISD::VAEND, MVT::Other, Expand);
+  setOperationAction({ISD::VAARG, ISD::VACOPY, ISD::VAEND}, MVT::Other, Expand);
 
   setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i1, Expand);
-  if (!Subtarget.hasStdExtZbb()) {
-    setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i8, Expand);
-    setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i16, Expand);
-  }
+
+  setOperationAction(ISD::EH_DWARF_CFA, MVT::i32, Custom);
+
+  if (!Subtarget.hasStdExtZbb())
+    setOperationAction(ISD::SIGN_EXTEND_INREG, {MVT::i8, MVT::i16}, Expand);
 
   if (Subtarget.is64Bit()) {
-    setOperationAction(ISD::ADD, MVT::i32, Custom);
-    setOperationAction(ISD::SUB, MVT::i32, Custom);
-    setOperationAction(ISD::SHL, MVT::i32, Custom);
-    setOperationAction(ISD::SRA, MVT::i32, Custom);
-    setOperationAction(ISD::SRL, MVT::i32, Custom);
-
-    setOperationAction(ISD::UADDO, MVT::i32, Custom);
-    setOperationAction(ISD::USUBO, MVT::i32, Custom);
-    setOperationAction(ISD::UADDSAT, MVT::i32, Custom);
-    setOperationAction(ISD::USUBSAT, MVT::i32, Custom);
+    setOperationAction(ISD::EH_DWARF_CFA, MVT::i64, Custom);
+
+    setOperationAction({ISD::ADD, ISD::SUB, ISD::SHL, ISD::SRA, ISD::SRL},
+                       MVT::i32, Custom);
+
+    setOperationAction({ISD::UADDO, ISD::USUBO, ISD::UADDSAT, ISD::USUBSAT},
+                       MVT::i32, Custom);
   } else {
-    setLibcallName(RTLIB::SHL_I128, nullptr);
-    setLibcallName(RTLIB::SRL_I128, nullptr);
-    setLibcallName(RTLIB::SRA_I128, nullptr);
-    setLibcallName(RTLIB::MUL_I128, nullptr);
+    setLibcallName(
+        {RTLIB::SHL_I128, RTLIB::SRL_I128, RTLIB::SRA_I128, RTLIB::MUL_I128},
+        nullptr);
     setLibcallName(RTLIB::MULO_I64, nullptr);
   }
 
   if (!Subtarget.hasStdExtM()) {
-    setOperationAction(ISD::MUL, XLenVT, Expand);
-    setOperationAction(ISD::MULHS, XLenVT, Expand);
-    setOperationAction(ISD::MULHU, XLenVT, Expand);
-    setOperationAction(ISD::SDIV, XLenVT, Expand);
-    setOperationAction(ISD::UDIV, XLenVT, Expand);
-    setOperationAction(ISD::SREM, XLenVT, Expand);
-    setOperationAction(ISD::UREM, XLenVT, Expand);
+    setOperationAction({ISD::MUL, ISD::MULHS, ISD::MULHU, ISD::SDIV, ISD::UDIV,
+                        ISD::SREM, ISD::UREM},
+                       XLenVT, Expand);
   } else {
     if (Subtarget.is64Bit()) {
-      setOperationAction(ISD::MUL, MVT::i32, Custom);
-      setOperationAction(ISD::MUL, MVT::i128, Custom);
-
-      setOperationAction(ISD::SDIV, MVT::i8, Custom);
-      setOperationAction(ISD::UDIV, MVT::i8, Custom);
-      setOperationAction(ISD::UREM, MVT::i8, Custom);
-      setOperationAction(ISD::SDIV, MVT::i16, Custom);
-      setOperationAction(ISD::UDIV, MVT::i16, Custom);
-      setOperationAction(ISD::UREM, MVT::i16, Custom);
-      setOperationAction(ISD::SDIV, MVT::i32, Custom);
-      setOperationAction(ISD::UDIV, MVT::i32, Custom);
-      setOperationAction(ISD::UREM, MVT::i32, Custom);
+      setOperationAction(ISD::MUL, {MVT::i32, MVT::i128}, Custom);
+
+      setOperationAction({ISD::SDIV, ISD::UDIV, ISD::UREM},
+                         {MVT::i8, MVT::i16, MVT::i32}, Custom);
    } else {
       setOperationAction(ISD::MUL, MVT::i64, Custom);
     }
   }
 
-  setOperationAction(ISD::SDIVREM, XLenVT, Expand);
-  setOperationAction(ISD::UDIVREM, XLenVT, Expand);
-  setOperationAction(ISD::SMUL_LOHI, XLenVT, Expand);
-  setOperationAction(ISD::UMUL_LOHI, XLenVT, Expand);
+  setOperationAction(
+      {ISD::SDIVREM, ISD::UDIVREM, ISD::SMUL_LOHI, ISD::UMUL_LOHI}, XLenVT,
+      Expand);
 
-  setOperationAction(ISD::SHL_PARTS, XLenVT, Custom);
-  setOperationAction(ISD::SRL_PARTS, XLenVT, Custom);
-  setOperationAction(ISD::SRA_PARTS, XLenVT, Custom);
+  setOperationAction({ISD::SHL_PARTS, ISD::SRL_PARTS, ISD::SRA_PARTS}, XLenVT,
+                     Custom);
 
   if (Subtarget.hasStdExtZbb() || Subtarget.hasStdExtZbp() ||
       Subtarget.hasStdExtZbkb()) {
-    if (Subtarget.is64Bit()) {
-      setOperationAction(ISD::ROTL, MVT::i32, Custom);
-      setOperationAction(ISD::ROTR, MVT::i32, Custom);
-    }
+    if (Subtarget.is64Bit())
+      setOperationAction({ISD::ROTL, ISD::ROTR}, MVT::i32, Custom);
   } else {
-    setOperationAction(ISD::ROTL, XLenVT, Expand);
-    setOperationAction(ISD::ROTR, XLenVT, Expand);
+    setOperationAction({ISD::ROTL, ISD::ROTR}, XLenVT, Expand);
   }
 
   if (Subtarget.hasStdExtZbp()) {
     // Custom lower bswap/bitreverse so we can convert them to GREVI to enable
     // more combining.
-    setOperationAction(ISD::BITREVERSE, XLenVT, Custom);
-    setOperationAction(ISD::BSWAP, XLenVT, Custom);
-    setOperationAction(ISD::BITREVERSE, MVT::i8, Custom);
+    setOperationAction({ISD::BITREVERSE, ISD::BSWAP}, XLenVT, Custom);
+    // BSWAP i8 doesn't exist.
-    setOperationAction(ISD::BITREVERSE, MVT::i16, Custom);
-    setOperationAction(ISD::BSWAP, MVT::i16, Custom);
+    setOperationAction(ISD::BITREVERSE, MVT::i8, Custom);
-    if (Subtarget.is64Bit()) {
-      setOperationAction(ISD::BITREVERSE, MVT::i32, Custom);
-      setOperationAction(ISD::BSWAP, MVT::i32, Custom);
-    }
+    setOperationAction({ISD::BITREVERSE, ISD::BSWAP}, MVT::i16, Custom);
+
+    if (Subtarget.is64Bit())
+      setOperationAction({ISD::BITREVERSE, ISD::BSWAP}, MVT::i32, Custom);
   } else {
     // With Zbb we have an XLen rev8 instruction, but not GREVI. So we'll
     // pattern match it directly in isel.
@@ -288,36 +270,38 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
   }
 
   if (Subtarget.hasStdExtZbb()) {
-    setOperationAction(ISD::SMIN, XLenVT, Legal);
-    setOperationAction(ISD::SMAX, XLenVT, Legal);
-    setOperationAction(ISD::UMIN, XLenVT, Legal);
-    setOperationAction(ISD::UMAX, XLenVT, Legal);
+    setOperationAction({ISD::SMIN, ISD::SMAX, ISD::UMIN, ISD::UMAX}, XLenVT,
+                       Legal);
 
-    if (Subtarget.is64Bit()) {
-      setOperationAction(ISD::CTTZ, MVT::i32, Custom);
-      setOperationAction(ISD::CTTZ_ZERO_UNDEF, MVT::i32, Custom);
-      setOperationAction(ISD::CTLZ, MVT::i32, Custom);
-      setOperationAction(ISD::CTLZ_ZERO_UNDEF, MVT::i32, Custom);
-    }
+    if (Subtarget.is64Bit())
+      setOperationAction(
+          {ISD::CTTZ, ISD::CTTZ_ZERO_UNDEF, ISD::CTLZ, ISD::CTLZ_ZERO_UNDEF},
+          MVT::i32, Custom);
   } else {
-    setOperationAction(ISD::CTTZ, XLenVT, Expand);
-    setOperationAction(ISD::CTLZ, XLenVT, Expand);
-    setOperationAction(ISD::CTPOP, XLenVT, Expand);
+    setOperationAction({ISD::CTTZ, ISD::CTLZ, ISD::CTPOP}, XLenVT, Expand);
+
+    if (Subtarget.is64Bit())
+      setOperationAction(ISD::ABS, MVT::i32, Custom);
   }
 
   if (Subtarget.hasStdExtZbt()) {
-    setOperationAction(ISD::FSHL, XLenVT, Custom);
-    setOperationAction(ISD::FSHR, XLenVT, Custom);
+    setOperationAction({ISD::FSHL, ISD::FSHR}, XLenVT, Custom);
     setOperationAction(ISD::SELECT, XLenVT, Legal);
 
-    if (Subtarget.is64Bit()) {
-      setOperationAction(ISD::FSHL, MVT::i32, Custom);
-      setOperationAction(ISD::FSHR, MVT::i32, Custom);
-    }
+    if (Subtarget.is64Bit())
+      setOperationAction({ISD::FSHL, ISD::FSHR}, MVT::i32, Custom);
   } else {
     setOperationAction(ISD::SELECT, XLenVT, Custom);
   }
 
+  static constexpr ISD::NodeType FPLegalNodeTypes[] = {
+      ISD::FMINNUM,        ISD::FMAXNUM,       ISD::LRINT,
+      ISD::LLRINT,         ISD::LROUND,        ISD::LLROUND,
+      ISD::STRICT_LRINT,   ISD::STRICT_LLRINT, ISD::STRICT_LROUND,
+      ISD::STRICT_LLROUND, ISD::STRICT_FMA,    ISD::STRICT_FADD,
+      ISD::STRICT_FSUB,    ISD::STRICT_FMUL,   ISD::STRICT_FDIV,
+      ISD::STRICT_FSQRT,   ISD::STRICT_FSETCC, ISD::STRICT_FSETCCS};
+
   static const ISD::CondCode FPCCToExpand[] = {
       ISD::SETOGT, ISD::SETOGE, ISD::SETONE, ISD::SETUEQ, ISD::SETUGT,
       ISD::SETUGE, ISD::SETULT, ISD::SETULE, ISD::SETUNE, ISD::SETGT,
@@ -331,50 +315,21 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
     setOperationAction(ISD::BITCAST, MVT::i16, Custom);
 
   if (Subtarget.hasStdExtZfh()) {
-    setOperationAction(ISD::FMINNUM, MVT::f16, Legal);
-    setOperationAction(ISD::FMAXNUM, MVT::f16, Legal);
-    setOperationAction(ISD::LRINT, MVT::f16, Legal);
-    setOperationAction(ISD::LLRINT, MVT::f16, Legal);
-    setOperationAction(ISD::LROUND, MVT::f16, Legal);
-    setOperationAction(ISD::LLROUND, MVT::f16, Legal);
-    setOperationAction(ISD::STRICT_LRINT, MVT::f16, Legal);
-    setOperationAction(ISD::STRICT_LLRINT, MVT::f16, Legal);
-    setOperationAction(ISD::STRICT_LROUND, MVT::f16, Legal);
-    setOperationAction(ISD::STRICT_LLROUND, MVT::f16, Legal);
-    setOperationAction(ISD::STRICT_FADD, MVT::f16, Legal);
-    setOperationAction(ISD::STRICT_FMA, MVT::f16, Legal);
-    setOperationAction(ISD::STRICT_FSUB, MVT::f16, Legal);
-    setOperationAction(ISD::STRICT_FMUL, MVT::f16, Legal);
-    setOperationAction(ISD::STRICT_FDIV, MVT::f16, Legal);
+    for (auto NT : FPLegalNodeTypes)
+      setOperationAction(NT, MVT::f16, Legal);
     setOperationAction(ISD::STRICT_FP_ROUND, MVT::f16, Legal);
     setOperationAction(ISD::STRICT_FP_EXTEND, MVT::f32, Legal);
-    setOperationAction(ISD::STRICT_FSQRT, MVT::f16, Legal);
-    setOperationAction(ISD::STRICT_FSETCC, MVT::f16, Legal);
-    setOperationAction(ISD::STRICT_FSETCCS, MVT::f16, Legal);
-    for (auto CC : FPCCToExpand)
-      setCondCodeAction(CC, MVT::f16, Expand);
+    setCondCodeAction(FPCCToExpand, MVT::f16, Expand);
     setOperationAction(ISD::SELECT_CC, MVT::f16, Expand);
     setOperationAction(ISD::SELECT, MVT::f16, Custom);
     setOperationAction(ISD::BR_CC, MVT::f16, Expand);
 
-    setOperationAction(ISD::FREM, MVT::f16, Promote);
-    setOperationAction(ISD::FCEIL, MVT::f16, Promote);
-    setOperationAction(ISD::FFLOOR, MVT::f16, Promote);
-    setOperationAction(ISD::FNEARBYINT, MVT::f16, Promote);
-    setOperationAction(ISD::FRINT, MVT::f16, Promote);
-    setOperationAction(ISD::FROUND, MVT::f16, Promote);
-    setOperationAction(ISD::FROUNDEVEN, MVT::f16, Promote);
-    setOperationAction(ISD::FTRUNC, MVT::f16, Promote);
-    setOperationAction(ISD::FPOW, MVT::f16, Promote);
-    setOperationAction(ISD::FPOWI, MVT::f16, Promote);
-    setOperationAction(ISD::FCOS, MVT::f16, Promote);
-    setOperationAction(ISD::FSIN, MVT::f16, Promote);
-    setOperationAction(ISD::FSINCOS, MVT::f16, Promote);
-    setOperationAction(ISD::FEXP, MVT::f16, Promote);
-    setOperationAction(ISD::FEXP2, MVT::f16, Promote);
-    setOperationAction(ISD::FLOG, MVT::f16, Promote);
-    setOperationAction(ISD::FLOG2, MVT::f16, Promote);
-    setOperationAction(ISD::FLOG10, MVT::f16, Promote);
+    setOperationAction({ISD::FREM, ISD::FCEIL, ISD::FFLOOR, ISD::FNEARBYINT,
+                        ISD::FRINT, ISD::FROUND, ISD::FROUNDEVEN, ISD::FTRUNC,
+                        ISD::FPOW, ISD::FPOWI, ISD::FCOS, ISD::FSIN,
+                        ISD::FSINCOS, ISD::FEXP, ISD::FEXP2, ISD::FLOG,
+                        ISD::FLOG2, ISD::FLOG10},
+                       MVT::f16, Promote);
 
     // FIXME: Need to promote f16 STRICT_* to f32 libcalls, but we don't have
     // complete support for all operations in LegalizeDAG.
```
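The dominant change in the hunks so far is mechanical: runs of single `setOperationAction`/`setLoadExtAction`/`setCondCodeAction` calls collapse into one call over a list of opcodes and/or value types, plus shared tables like `FPLegalNodeTypes`. Below is a minimal standalone sketch of that overload pattern — illustrative names only, not LLVM's real signatures or types:

```cpp
// Sketch of list-taking overloads layered over a scalar setter, the shape the
// refactor above relies on. All names here are hypothetical stand-ins.
#include <cstdio>
#include <initializer_list>

enum Op { ADD, SUB, SHL, NUM_OPS };
enum VT { i32, i64, NUM_VTS };
enum Action { Legal, Custom, Expand };

static Action Table[NUM_OPS][NUM_VTS]; // zero-initialized, i.e. Legal

// Scalar form, as the code used before this commit.
void setOperationAction(Op O, VT T, Action A) { Table[O][T] = A; }

// List forms, as used after: one call configures a whole group.
void setOperationAction(std::initializer_list<Op> Os, VT T, Action A) {
  for (Op O : Os)
    setOperationAction(O, T, A);
}
void setOperationAction(Op O, std::initializer_list<VT> Ts, Action A) {
  for (VT T : Ts)
    setOperationAction(O, T, A);
}

int main() {
  setOperationAction({ADD, SUB, SHL}, i32, Custom); // was three calls
  setOperationAction(SHL, {i32, i64}, Expand);      // was two calls
  std::printf("ADD/i32=%d SHL/i64=%d\n", Table[ADD][i32], Table[SHL][i64]);
}
```

The behavior is unchanged; the table-driven form just makes the opcode groups visible at a glance and keeps diffs like this one short. The diff continues: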
```diff
@@ -385,26 +340,9 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
   }
 
   if (Subtarget.hasStdExtF()) {
-    setOperationAction(ISD::FMINNUM, MVT::f32, Legal);
-    setOperationAction(ISD::FMAXNUM, MVT::f32, Legal);
-    setOperationAction(ISD::LRINT, MVT::f32, Legal);
-    setOperationAction(ISD::LLRINT, MVT::f32, Legal);
-    setOperationAction(ISD::LROUND, MVT::f32, Legal);
-    setOperationAction(ISD::LLROUND, MVT::f32, Legal);
-    setOperationAction(ISD::STRICT_LRINT, MVT::f32, Legal);
-    setOperationAction(ISD::STRICT_LLRINT, MVT::f32, Legal);
-    setOperationAction(ISD::STRICT_LROUND, MVT::f32, Legal);
-    setOperationAction(ISD::STRICT_LLROUND, MVT::f32, Legal);
-    setOperationAction(ISD::STRICT_FADD, MVT::f32, Legal);
-    setOperationAction(ISD::STRICT_FMA, MVT::f32, Legal);
-    setOperationAction(ISD::STRICT_FSUB, MVT::f32, Legal);
-    setOperationAction(ISD::STRICT_FMUL, MVT::f32, Legal);
-    setOperationAction(ISD::STRICT_FDIV, MVT::f32, Legal);
-    setOperationAction(ISD::STRICT_FSQRT, MVT::f32, Legal);
-    setOperationAction(ISD::STRICT_FSETCC, MVT::f32, Legal);
-    setOperationAction(ISD::STRICT_FSETCCS, MVT::f32, Legal);
-    for (auto CC : FPCCToExpand)
-      setCondCodeAction(CC, MVT::f32, Expand);
+    for (auto NT : FPLegalNodeTypes)
+      setOperationAction(NT, MVT::f32, Legal);
+    setCondCodeAction(FPCCToExpand, MVT::f32, Expand);
     setOperationAction(ISD::SELECT_CC, MVT::f32, Expand);
     setOperationAction(ISD::SELECT, MVT::f32, Custom);
     setOperationAction(ISD::BR_CC, MVT::f32, Expand);
@@ -418,28 +356,11 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
     setOperationAction(ISD::BITCAST, MVT::i32, Custom);
 
   if (Subtarget.hasStdExtD()) {
-    setOperationAction(ISD::FMINNUM, MVT::f64, Legal);
-    setOperationAction(ISD::FMAXNUM, MVT::f64, Legal);
-    setOperationAction(ISD::LRINT, MVT::f64, Legal);
-    setOperationAction(ISD::LLRINT, MVT::f64, Legal);
-    setOperationAction(ISD::LROUND, MVT::f64, Legal);
-    setOperationAction(ISD::LLROUND, MVT::f64, Legal);
-    setOperationAction(ISD::STRICT_LRINT, MVT::f64, Legal);
-    setOperationAction(ISD::STRICT_LLRINT, MVT::f64, Legal);
-    setOperationAction(ISD::STRICT_LROUND, MVT::f64, Legal);
-    setOperationAction(ISD::STRICT_LLROUND, MVT::f64, Legal);
-    setOperationAction(ISD::STRICT_FMA, MVT::f64, Legal);
-    setOperationAction(ISD::STRICT_FADD, MVT::f64, Legal);
-    setOperationAction(ISD::STRICT_FSUB, MVT::f64, Legal);
-    setOperationAction(ISD::STRICT_FMUL, MVT::f64, Legal);
-    setOperationAction(ISD::STRICT_FDIV, MVT::f64, Legal);
+    for (auto NT : FPLegalNodeTypes)
+      setOperationAction(NT, MVT::f64, Legal);
     setOperationAction(ISD::STRICT_FP_ROUND, MVT::f32, Legal);
     setOperationAction(ISD::STRICT_FP_EXTEND, MVT::f64, Legal);
-    setOperationAction(ISD::STRICT_FSQRT, MVT::f64, Legal);
-    setOperationAction(ISD::STRICT_FSETCC, MVT::f64, Legal);
-    setOperationAction(ISD::STRICT_FSETCCS, MVT::f64, Legal);
-    for (auto CC : FPCCToExpand)
-      setCondCodeAction(CC, MVT::f64, Expand);
+    setCondCodeAction(FPCCToExpand, MVT::f64, Expand);
     setOperationAction(ISD::SELECT_CC, MVT::f64, Expand);
     setOperationAction(ISD::SELECT, MVT::f64, Custom);
     setOperationAction(ISD::BR_CC, MVT::f64, Expand);
@@ -451,40 +372,38 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
     setTruncStoreAction(MVT::f64, MVT::f16, Expand);
   }
 
-  if (Subtarget.is64Bit()) {
-    setOperationAction(ISD::FP_TO_UINT, MVT::i32, Custom);
-    setOperationAction(ISD::FP_TO_SINT, MVT::i32, Custom);
-    setOperationAction(ISD::STRICT_FP_TO_UINT, MVT::i32, Custom);
-    setOperationAction(ISD::STRICT_FP_TO_SINT, MVT::i32, Custom);
-  }
+  if (Subtarget.is64Bit())
+    setOperationAction({ISD::FP_TO_UINT, ISD::FP_TO_SINT,
+                        ISD::STRICT_FP_TO_UINT, ISD::STRICT_FP_TO_SINT},
+                       MVT::i32, Custom);
 
   if (Subtarget.hasStdExtF()) {
-    setOperationAction(ISD::FP_TO_UINT_SAT, XLenVT, Custom);
-    setOperationAction(ISD::FP_TO_SINT_SAT, XLenVT, Custom);
+    setOperationAction({ISD::FP_TO_UINT_SAT, ISD::FP_TO_SINT_SAT}, XLenVT,
+                       Custom);
 
-    setOperationAction(ISD::STRICT_FP_TO_UINT, XLenVT, Legal);
-    setOperationAction(ISD::STRICT_FP_TO_SINT, XLenVT, Legal);
-    setOperationAction(ISD::STRICT_UINT_TO_FP, XLenVT, Legal);
-    setOperationAction(ISD::STRICT_SINT_TO_FP, XLenVT, Legal);
+    setOperationAction({ISD::STRICT_FP_TO_UINT, ISD::STRICT_FP_TO_SINT,
+                        ISD::STRICT_UINT_TO_FP, ISD::STRICT_SINT_TO_FP},
+                       XLenVT, Legal);
 
     setOperationAction(ISD::FLT_ROUNDS_, XLenVT, Custom);
     setOperationAction(ISD::SET_ROUNDING, MVT::Other, Custom);
   }
 
-  setOperationAction(ISD::GlobalAddress, XLenVT, Custom);
-  setOperationAction(ISD::BlockAddress, XLenVT, Custom);
-  setOperationAction(ISD::ConstantPool, XLenVT, Custom);
-  setOperationAction(ISD::JumpTable, XLenVT, Custom);
+  setOperationAction({ISD::GlobalAddress, ISD::BlockAddress, ISD::ConstantPool,
+                      ISD::JumpTable},
+                     XLenVT, Custom);
 
   setOperationAction(ISD::GlobalTLSAddress, XLenVT, Custom);
 
+  if (Subtarget.is64Bit())
+    setOperationAction(ISD::Constant, MVT::i64, Custom);
+
   // TODO: On M-mode only targets, the cycle[h] CSR may not be present.
   // Unfortunately this can't be determined just from the ISA naming string.
   setOperationAction(ISD::READCYCLECOUNTER, MVT::i64,
                      Subtarget.is64Bit() ? Legal : Custom);
 
-  setOperationAction(ISD::TRAP, MVT::Other, Legal);
-  setOperationAction(ISD::DEBUGTRAP, MVT::Other, Legal);
+  setOperationAction({ISD::TRAP, ISD::DEBUGTRAP}, MVT::Other, Legal);
   setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::Other, Custom);
   if (Subtarget.is64Bit())
     setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::i32, Custom);
@@ -505,19 +424,16 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
 
     // RVV intrinsics may have illegal operands.
     // We also need to custom legalize vmv.x.s.
-    setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::i8, Custom);
-    setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::i16, Custom);
-    setOperationAction(ISD::INTRINSIC_W_CHAIN, MVT::i8, Custom);
-    setOperationAction(ISD::INTRINSIC_W_CHAIN, MVT::i16, Custom);
-    if (Subtarget.is64Bit()) {
+    setOperationAction({ISD::INTRINSIC_WO_CHAIN, ISD::INTRINSIC_W_CHAIN},
+                       {MVT::i8, MVT::i16}, Custom);
+    if (Subtarget.is64Bit())
       setOperationAction(ISD::INTRINSIC_W_CHAIN, MVT::i32, Custom);
-    } else {
-      setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::i64, Custom);
-      setOperationAction(ISD::INTRINSIC_W_CHAIN, MVT::i64, Custom);
-    }
+    else
+      setOperationAction({ISD::INTRINSIC_WO_CHAIN, ISD::INTRINSIC_W_CHAIN},
+                         MVT::i64, Custom);
 
-    setOperationAction(ISD::INTRINSIC_W_CHAIN, MVT::Other, Custom);
-    setOperationAction(ISD::INTRINSIC_VOID, MVT::Other, Custom);
+    setOperationAction({ISD::INTRINSIC_W_CHAIN, ISD::INTRINSIC_VOID},
+                       MVT::Other, Custom);
 
     static const unsigned IntegerVPOps[] = {
         ISD::VP_ADD,         ISD::VP_SUB,         ISD::VP_MUL,
@@ -527,191 +443,175 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
         ISD::VP_SHL,         ISD::VP_REDUCE_ADD,  ISD::VP_REDUCE_AND,
         ISD::VP_REDUCE_OR,   ISD::VP_REDUCE_XOR,  ISD::VP_REDUCE_SMAX,
         ISD::VP_REDUCE_SMIN, ISD::VP_REDUCE_UMAX, ISD::VP_REDUCE_UMIN,
-        ISD::VP_MERGE,       ISD::VP_SELECT};
+        ISD::VP_MERGE,       ISD::VP_SELECT,      ISD::VP_FPTOSI,
+        ISD::VP_FPTOUI,      ISD::VP_SETCC,       ISD::VP_SIGN_EXTEND,
+        ISD::VP_ZERO_EXTEND, ISD::VP_TRUNCATE};
 
     static const unsigned FloatingPointVPOps[] = {
-        ISD::VP_FADD,        ISD::VP_FSUB,        ISD::VP_FMUL,
-        ISD::VP_FDIV,        ISD::VP_REDUCE_FADD, ISD::VP_REDUCE_SEQ_FADD,
-        ISD::VP_REDUCE_FMIN, ISD::VP_REDUCE_FMAX, ISD::VP_MERGE,
-        ISD::VP_SELECT};
+        ISD::VP_FADD,        ISD::VP_FSUB,
+        ISD::VP_FMUL,        ISD::VP_FDIV,
+        ISD::VP_FNEG,        ISD::VP_FMA,
+        ISD::VP_REDUCE_FADD, ISD::VP_REDUCE_SEQ_FADD,
+        ISD::VP_REDUCE_FMIN, ISD::VP_REDUCE_FMAX,
+        ISD::VP_MERGE,       ISD::VP_SELECT,
+        ISD::VP_SITOFP,      ISD::VP_UITOFP,
+        ISD::VP_SETCC,       ISD::VP_FP_ROUND,
+        ISD::VP_FP_EXTEND};
 
     if (!Subtarget.is64Bit()) {
      // We must custom-lower certain vXi64 operations on RV32 due to the vector
      // element type being illegal.
-      setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::i64, Custom);
-      setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::i64, Custom);
-
-      setOperationAction(ISD::VECREDUCE_ADD, MVT::i64, Custom);
-      setOperationAction(ISD::VECREDUCE_AND, MVT::i64, Custom);
-      setOperationAction(ISD::VECREDUCE_OR, MVT::i64, Custom);
-      setOperationAction(ISD::VECREDUCE_XOR, MVT::i64, Custom);
-      setOperationAction(ISD::VECREDUCE_SMAX, MVT::i64, Custom);
-      setOperationAction(ISD::VECREDUCE_SMIN, MVT::i64, Custom);
-      setOperationAction(ISD::VECREDUCE_UMAX, MVT::i64, Custom);
-      setOperationAction(ISD::VECREDUCE_UMIN, MVT::i64, Custom);
-
-      setOperationAction(ISD::VP_REDUCE_ADD, MVT::i64, Custom);
-      setOperationAction(ISD::VP_REDUCE_AND, MVT::i64, Custom);
-      setOperationAction(ISD::VP_REDUCE_OR, MVT::i64, Custom);
-      setOperationAction(ISD::VP_REDUCE_XOR, MVT::i64, Custom);
-      setOperationAction(ISD::VP_REDUCE_SMAX, MVT::i64, Custom);
-      setOperationAction(ISD::VP_REDUCE_SMIN, MVT::i64, Custom);
-      setOperationAction(ISD::VP_REDUCE_UMAX, MVT::i64, Custom);
-      setOperationAction(ISD::VP_REDUCE_UMIN, MVT::i64, Custom);
+      setOperationAction({ISD::INSERT_VECTOR_ELT, ISD::EXTRACT_VECTOR_ELT},
+                         MVT::i64, Custom);
+
+      setOperationAction({ISD::VECREDUCE_ADD, ISD::VECREDUCE_AND,
+                          ISD::VECREDUCE_OR, ISD::VECREDUCE_XOR,
+                          ISD::VECREDUCE_SMAX, ISD::VECREDUCE_SMIN,
+                          ISD::VECREDUCE_UMAX, ISD::VECREDUCE_UMIN},
+                         MVT::i64, Custom);
+
+      setOperationAction({ISD::VP_REDUCE_ADD, ISD::VP_REDUCE_AND,
+                          ISD::VP_REDUCE_OR, ISD::VP_REDUCE_XOR,
+                          ISD::VP_REDUCE_SMAX, ISD::VP_REDUCE_SMIN,
+                          ISD::VP_REDUCE_UMAX, ISD::VP_REDUCE_UMIN},
+                         MVT::i64, Custom);
     }
 
     for (MVT VT : BoolVecVTs) {
+      if (!isTypeLegal(VT))
+        continue;
+
       setOperationAction(ISD::SPLAT_VECTOR, VT, Custom);
 
       // Mask VTs are custom-expanded into a series of standard nodes
-      setOperationAction(ISD::TRUNCATE, VT, Custom);
-      setOperationAction(ISD::CONCAT_VECTORS, VT, Custom);
-      setOperationAction(ISD::INSERT_SUBVECTOR, VT, Custom);
-      setOperationAction(ISD::EXTRACT_SUBVECTOR, VT, Custom);
+      setOperationAction({ISD::TRUNCATE, ISD::CONCAT_VECTORS,
+                          ISD::INSERT_SUBVECTOR, ISD::EXTRACT_SUBVECTOR},
+                         VT, Custom);
 
-      setOperationAction(ISD::INSERT_VECTOR_ELT, VT, Custom);
-      setOperationAction(ISD::EXTRACT_VECTOR_ELT, VT, Custom);
+      setOperationAction({ISD::INSERT_VECTOR_ELT, ISD::EXTRACT_VECTOR_ELT}, VT,
+                         Custom);
 
       setOperationAction(ISD::SELECT, VT, Custom);
-      setOperationAction(ISD::SELECT_CC, VT, Expand);
-      setOperationAction(ISD::VSELECT, VT, Expand);
-      setOperationAction(ISD::VP_MERGE, VT, Expand);
-      setOperationAction(ISD::VP_SELECT, VT, Expand);
+      setOperationAction(
+          {ISD::SELECT_CC, ISD::VSELECT, ISD::VP_MERGE, ISD::VP_SELECT}, VT,
+          Expand);
 
-      setOperationAction(ISD::VP_AND, VT, Custom);
-      setOperationAction(ISD::VP_OR, VT, Custom);
-      setOperationAction(ISD::VP_XOR, VT, Custom);
+      setOperationAction({ISD::VP_AND, ISD::VP_OR, ISD::VP_XOR}, VT, Custom);
 
-      setOperationAction(ISD::VECREDUCE_AND, VT, Custom);
-      setOperationAction(ISD::VECREDUCE_OR, VT, Custom);
-      setOperationAction(ISD::VECREDUCE_XOR, VT, Custom);
+      setOperationAction(
+          {ISD::VECREDUCE_AND, ISD::VECREDUCE_OR, ISD::VECREDUCE_XOR}, VT,
+          Custom);
 
-      setOperationAction(ISD::VP_REDUCE_AND, VT, Custom);
-      setOperationAction(ISD::VP_REDUCE_OR, VT, Custom);
-      setOperationAction(ISD::VP_REDUCE_XOR, VT, Custom);
+      setOperationAction(
+          {ISD::VP_REDUCE_AND, ISD::VP_REDUCE_OR, ISD::VP_REDUCE_XOR}, VT,
+          Custom);
 
       // RVV has native int->float & float->int conversions where the
       // element type sizes are within one power-of-two of each other. Any
       // wider distances between type sizes have to be lowered as sequences
       // which progressively narrow the gap in stages.
-      setOperationAction(ISD::SINT_TO_FP, VT, Custom);
-      setOperationAction(ISD::UINT_TO_FP, VT, Custom);
-      setOperationAction(ISD::FP_TO_SINT, VT, Custom);
-      setOperationAction(ISD::FP_TO_UINT, VT, Custom);
+      setOperationAction(
+          {ISD::SINT_TO_FP, ISD::UINT_TO_FP, ISD::FP_TO_SINT, ISD::FP_TO_UINT},
+          VT, Custom);
 
       // Expand all extending loads to types larger than this, and truncating
       // stores from types larger than this.
       for (MVT OtherVT : MVT::integer_scalable_vector_valuetypes()) {
         setTruncStoreAction(OtherVT, VT, Expand);
-        setLoadExtAction(ISD::EXTLOAD, OtherVT, VT, Expand);
-        setLoadExtAction(ISD::SEXTLOAD, OtherVT, VT, Expand);
-        setLoadExtAction(ISD::ZEXTLOAD, OtherVT, VT, Expand);
+        setLoadExtAction({ISD::EXTLOAD, ISD::SEXTLOAD, ISD::ZEXTLOAD}, OtherVT,
+                         VT, Expand);
       }
+
+      setOperationAction(
+          {ISD::VP_FPTOSI, ISD::VP_FPTOUI, ISD::VP_TRUNCATE, ISD::VP_SETCC}, VT,
+          Custom);
+
+      setOperationAction(ISD::VECTOR_REVERSE, VT, Custom);
     }
 
     for (MVT VT : IntVecVTs) {
-      if (VT.getVectorElementType() == MVT::i64 &&
-          !Subtarget.hasVInstructionsI64())
+      if (!isTypeLegal(VT))
         continue;
 
       setOperationAction(ISD::SPLAT_VECTOR, VT, Legal);
       setOperationAction(ISD::SPLAT_VECTOR_PARTS, VT, Custom);
 
       // Vectors implement MULHS/MULHU.
-      setOperationAction(ISD::SMUL_LOHI, VT, Expand);
-      setOperationAction(ISD::UMUL_LOHI, VT, Expand);
+      setOperationAction({ISD::SMUL_LOHI, ISD::UMUL_LOHI}, VT, Expand);
 
       // nxvXi64 MULHS/MULHU requires the V extension instead of Zve64*.
-      if (VT.getVectorElementType() == MVT::i64 && !Subtarget.hasStdExtV()) {
-        setOperationAction(ISD::MULHU, VT, Expand);
-        setOperationAction(ISD::MULHS, VT, Expand);
-      }
+      if (VT.getVectorElementType() == MVT::i64 && !Subtarget.hasStdExtV())
+        setOperationAction({ISD::MULHU, ISD::MULHS}, VT, Expand);
 
-      setOperationAction(ISD::SMIN, VT, Legal);
-      setOperationAction(ISD::SMAX, VT, Legal);
-      setOperationAction(ISD::UMIN, VT, Legal);
-      setOperationAction(ISD::UMAX, VT, Legal);
+      setOperationAction({ISD::SMIN, ISD::SMAX, ISD::UMIN, ISD::UMAX}, VT,
+                         Legal);
 
-      setOperationAction(ISD::ROTL, VT, Expand);
-      setOperationAction(ISD::ROTR, VT, Expand);
+      setOperationAction({ISD::ROTL, ISD::ROTR}, VT, Expand);
 
-      setOperationAction(ISD::CTTZ, VT, Expand);
-      setOperationAction(ISD::CTLZ, VT, Expand);
-      setOperationAction(ISD::CTPOP, VT, Expand);
+      setOperationAction({ISD::CTTZ, ISD::CTLZ, ISD::CTPOP, ISD::BSWAP}, VT,
+                         Expand);
       setOperationAction(ISD::BSWAP, VT, Expand);
 
       // Custom-lower extensions and truncations from/to mask types.
-      setOperationAction(ISD::ANY_EXTEND, VT, Custom);
-      setOperationAction(ISD::SIGN_EXTEND, VT, Custom);
-      setOperationAction(ISD::ZERO_EXTEND, VT, Custom);
+      setOperationAction({ISD::ANY_EXTEND, ISD::SIGN_EXTEND, ISD::ZERO_EXTEND},
+                         VT, Custom);
 
       // RVV has native int->float & float->int conversions where the
       // element type sizes are within one power-of-two of each other. Any
       // wider distances between type sizes have to be lowered as sequences
       // which progressively narrow the gap in stages.
-      setOperationAction(ISD::SINT_TO_FP, VT, Custom);
-      setOperationAction(ISD::UINT_TO_FP, VT, Custom);
-      setOperationAction(ISD::FP_TO_SINT, VT, Custom);
-      setOperationAction(ISD::FP_TO_UINT, VT, Custom);
+      setOperationAction(
+          {ISD::SINT_TO_FP, ISD::UINT_TO_FP, ISD::FP_TO_SINT, ISD::FP_TO_UINT},
+          VT, Custom);
 
-      setOperationAction(ISD::SADDSAT, VT, Legal);
-      setOperationAction(ISD::UADDSAT, VT, Legal);
-      setOperationAction(ISD::SSUBSAT, VT, Legal);
-      setOperationAction(ISD::USUBSAT, VT, Legal);
+      setOperationAction(
+          {ISD::SADDSAT, ISD::UADDSAT, ISD::SSUBSAT, ISD::USUBSAT}, VT, Legal);
 
       // Integer VTs are lowered as a series of "RISCVISD::TRUNCATE_VECTOR_VL"
       // nodes which truncate by one power of two at a time.
       setOperationAction(ISD::TRUNCATE, VT, Custom);
 
       // Custom-lower insert/extract operations to simplify patterns.
-      setOperationAction(ISD::INSERT_VECTOR_ELT, VT, Custom);
-      setOperationAction(ISD::EXTRACT_VECTOR_ELT, VT, Custom);
+      setOperationAction({ISD::INSERT_VECTOR_ELT, ISD::EXTRACT_VECTOR_ELT}, VT,
+                         Custom);
 
       // Custom-lower reduction operations to set up the corresponding custom
       // nodes' operands.
-      setOperationAction(ISD::VECREDUCE_ADD, VT, Custom);
-      setOperationAction(ISD::VECREDUCE_AND, VT, Custom);
-      setOperationAction(ISD::VECREDUCE_OR, VT, Custom);
-      setOperationAction(ISD::VECREDUCE_XOR, VT, Custom);
-      setOperationAction(ISD::VECREDUCE_SMAX, VT, Custom);
-      setOperationAction(ISD::VECREDUCE_SMIN, VT, Custom);
-      setOperationAction(ISD::VECREDUCE_UMAX, VT, Custom);
-      setOperationAction(ISD::VECREDUCE_UMIN, VT, Custom);
-
-      for (unsigned VPOpc : IntegerVPOps)
-        setOperationAction(VPOpc, VT, Custom);
-
-      setOperationAction(ISD::LOAD, VT, Custom);
-      setOperationAction(ISD::STORE, VT, Custom);
-
-      setOperationAction(ISD::MLOAD, VT, Custom);
-      setOperationAction(ISD::MSTORE, VT, Custom);
-      setOperationAction(ISD::MGATHER, VT, Custom);
-      setOperationAction(ISD::MSCATTER, VT, Custom);
-
-      setOperationAction(ISD::VP_LOAD, VT, Custom);
-      setOperationAction(ISD::VP_STORE, VT, Custom);
-      setOperationAction(ISD::VP_GATHER, VT, Custom);
-      setOperationAction(ISD::VP_SCATTER, VT, Custom);
-
-      setOperationAction(ISD::CONCAT_VECTORS, VT, Custom);
-      setOperationAction(ISD::INSERT_SUBVECTOR, VT, Custom);
-      setOperationAction(ISD::EXTRACT_SUBVECTOR, VT, Custom);
+      setOperationAction({ISD::VECREDUCE_ADD, ISD::VECREDUCE_AND,
+                          ISD::VECREDUCE_OR, ISD::VECREDUCE_XOR,
+                          ISD::VECREDUCE_SMAX, ISD::VECREDUCE_SMIN,
+                          ISD::VECREDUCE_UMAX, ISD::VECREDUCE_UMIN},
+                         VT, Custom);
+
+      setOperationAction(IntegerVPOps, VT, Custom);
+
+      setOperationAction({ISD::LOAD, ISD::STORE}, VT, Custom);
+
+      setOperationAction({ISD::MLOAD, ISD::MSTORE, ISD::MGATHER, ISD::MSCATTER},
+                         VT, Custom);
+
+      setOperationAction(
+          {ISD::VP_LOAD, ISD::VP_STORE, ISD::VP_GATHER, ISD::VP_SCATTER}, VT,
+          Custom);
+
+      setOperationAction(
+          {ISD::CONCAT_VECTORS, ISD::INSERT_SUBVECTOR, ISD::EXTRACT_SUBVECTOR},
+          VT, Custom);
 
       setOperationAction(ISD::SELECT, VT, Custom);
       setOperationAction(ISD::SELECT_CC, VT, Expand);
 
-      setOperationAction(ISD::STEP_VECTOR, VT, Custom);
-      setOperationAction(ISD::VECTOR_REVERSE, VT, Custom);
+      setOperationAction({ISD::STEP_VECTOR, ISD::VECTOR_REVERSE}, VT, Custom);
 
       for (MVT OtherVT : MVT::integer_scalable_vector_valuetypes()) {
         setTruncStoreAction(VT, OtherVT, Expand);
-        setLoadExtAction(ISD::EXTLOAD, OtherVT, VT, Expand);
-        setLoadExtAction(ISD::SEXTLOAD, OtherVT, VT, Expand);
-        setLoadExtAction(ISD::ZEXTLOAD, OtherVT, VT, Expand);
+        setLoadExtAction({ISD::EXTLOAD, ISD::SEXTLOAD, ISD::ZEXTLOAD}, OtherVT,
+                         VT, Expand);
       }
 
+      // Splice
+      setOperationAction(ISD::VECTOR_SPLICE, VT, Custom);
+
       // Lower CTLZ_ZERO_UNDEF and CTTZ_ZERO_UNDEF if we have a floating point
       // type that can represent the value exactly.
       if (VT.getVectorElementType() != MVT::i64) {
@@ -719,8 +619,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
             VT.getVectorElementType() == MVT::i32 ? MVT::f64 : MVT::f32;
         EVT FloatVT = MVT::getVectorVT(FloatEltVT, VT.getVectorElementCount());
         if (isTypeLegal(FloatVT)) {
-          setOperationAction(ISD::CTLZ_ZERO_UNDEF, VT, Custom);
-          setOperationAction(ISD::CTTZ_ZERO_UNDEF, VT, Custom);
+          setOperationAction({ISD::CTLZ_ZERO_UNDEF, ISD::CTTZ_ZERO_UNDEF}, VT,
+                             Custom);
         }
       }
     }
@@ -745,21 +645,35 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
       // sizes are within one power-of-two of each other. Therefore conversions
      // between vXf16 and vXf64 must be lowered as sequences which convert via
      // vXf32.
-      setOperationAction(ISD::FP_ROUND, VT, Custom);
-      setOperationAction(ISD::FP_EXTEND, VT, Custom);
+      setOperationAction({ISD::FP_ROUND, ISD::FP_EXTEND}, VT, Custom);
 
       // Custom-lower insert/extract operations to simplify patterns.
-      setOperationAction(ISD::INSERT_VECTOR_ELT, VT, Custom);
-      setOperationAction(ISD::EXTRACT_VECTOR_ELT, VT, Custom);
+      setOperationAction({ISD::INSERT_VECTOR_ELT, ISD::EXTRACT_VECTOR_ELT}, VT,
+                         Custom);
 
       // Expand various condition codes (explained above).
-      for (auto CC : VFPCCToExpand)
-        setCondCodeAction(CC, VT, Expand);
-
-      setOperationAction(ISD::FMINNUM, VT, Legal);
-      setOperationAction(ISD::FMAXNUM, VT, Legal);
-
-      setOperationAction(ISD::FTRUNC, VT, Custom);
-      setOperationAction(ISD::FCEIL, VT, Custom);
-      setOperationAction(ISD::FFLOOR, VT, Custom);
+      setCondCodeAction(VFPCCToExpand, VT, Expand);
+
+      setOperationAction({ISD::FMINNUM, ISD::FMAXNUM}, VT, Legal);
+
+      setOperationAction({ISD::FTRUNC, ISD::FCEIL, ISD::FFLOOR, ISD::FROUND},
+                         VT, Custom);
+
+      setOperationAction({ISD::VECREDUCE_FADD, ISD::VECREDUCE_SEQ_FADD,
+                          ISD::VECREDUCE_FMIN, ISD::VECREDUCE_FMAX},
+                         VT, Custom);
+
+      // Expand FP operations that need libcalls.
+      setOperationAction(ISD::FREM, VT, Expand);
+      setOperationAction(ISD::FPOW, VT, Expand);
+      setOperationAction(ISD::FCOS, VT, Expand);
+      setOperationAction(ISD::FSIN, VT, Expand);
+      setOperationAction(ISD::FSINCOS, VT, Expand);
+      setOperationAction(ISD::FEXP, VT, Expand);
+      setOperationAction(ISD::FEXP2, VT, Expand);
+      setOperationAction(ISD::FLOG, VT, Expand);
+      setOperationAction(ISD::FLOG2, VT, Expand);
+      setOperationAction(ISD::FLOG10, VT, Expand);
+      setOperationAction(ISD::FRINT, VT, Expand);
+      setOperationAction(ISD::FNEARBYINT, VT, Expand);
 
       setOperationAction(ISD::VECREDUCE_FADD, VT, Custom);
       setOperationAction(ISD::VECREDUCE_SEQ_FADD, VT, Custom);
@@ -768,30 +682,25 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
 
       setOperationAction(ISD::FCOPYSIGN, VT, Legal);
 
-      setOperationAction(ISD::LOAD, VT, Custom);
-      setOperationAction(ISD::STORE, VT, Custom);
+      setOperationAction({ISD::LOAD, ISD::STORE}, VT, Custom);
 
-      setOperationAction(ISD::MLOAD, VT, Custom);
-      setOperationAction(ISD::MSTORE, VT, Custom);
-      setOperationAction(ISD::MGATHER, VT, Custom);
-      setOperationAction(ISD::MSCATTER, VT, Custom);
+      setOperationAction({ISD::MLOAD, ISD::MSTORE, ISD::MGATHER, ISD::MSCATTER},
+                         VT, Custom);
 
-      setOperationAction(ISD::VP_LOAD, VT, Custom);
-      setOperationAction(ISD::VP_STORE, VT, Custom);
-      setOperationAction(ISD::VP_GATHER, VT, Custom);
-      setOperationAction(ISD::VP_SCATTER, VT, Custom);
+      setOperationAction(
+          {ISD::VP_LOAD, ISD::VP_STORE, ISD::VP_GATHER, ISD::VP_SCATTER}, VT,
+          Custom);
 
       setOperationAction(ISD::SELECT, VT, Custom);
       setOperationAction(ISD::SELECT_CC, VT, Expand);
 
-      setOperationAction(ISD::CONCAT_VECTORS, VT, Custom);
-      setOperationAction(ISD::INSERT_SUBVECTOR, VT, Custom);
-      setOperationAction(ISD::EXTRACT_SUBVECTOR, VT, Custom);
+      setOperationAction(
+          {ISD::CONCAT_VECTORS, ISD::INSERT_SUBVECTOR, ISD::EXTRACT_SUBVECTOR},
+          VT, Custom);
 
-      setOperationAction(ISD::VECTOR_REVERSE, VT, Custom);
+      setOperationAction({ISD::VECTOR_REVERSE, ISD::VECTOR_SPLICE}, VT, Custom);
 
-      for (unsigned VPOpc : FloatingPointVPOps)
-        setOperationAction(VPOpc, VT, Custom);
+      setOperationAction(FloatingPointVPOps, VT, Custom);
     };
 
     // Sets common extload/truncstore actions on RVV floating-point vector
@@ -804,21 +713,31 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
       }
     };
 
-    if (Subtarget.hasVInstructionsF16())
-      for (MVT VT : F16VecVTs)
+    if (Subtarget.hasVInstructionsF16()) {
+      for (MVT VT : F16VecVTs) {
+        if (!isTypeLegal(VT))
+          continue;
         SetCommonVFPActions(VT);
+      }
+    }
 
-    for (MVT VT : F32VecVTs) {
-      if (Subtarget.hasVInstructionsF32())
+    if (Subtarget.hasVInstructionsF32()) {
+      for (MVT VT : F32VecVTs) {
+        if (!isTypeLegal(VT))
+          continue;
         SetCommonVFPActions(VT);
-      SetCommonVFPExtLoadTruncStoreActions(VT, F16VecVTs);
+        SetCommonVFPExtLoadTruncStoreActions(VT, F16VecVTs);
+      }
     }
 
-    for (MVT VT : F64VecVTs) {
-      if (Subtarget.hasVInstructionsF64())
+    if (Subtarget.hasVInstructionsF64()) {
+      for (MVT VT : F64VecVTs) {
+        if (!isTypeLegal(VT))
+          continue;
         SetCommonVFPActions(VT);
-      SetCommonVFPExtLoadTruncStoreActions(VT, F16VecVTs);
-      SetCommonVFPExtLoadTruncStoreActions(VT, F32VecVTs);
+        SetCommonVFPExtLoadTruncStoreActions(VT, F16VecVTs);
+        SetCommonVFPExtLoadTruncStoreActions(VT, F32VecVTs);
+      }
     }
 
    if (Subtarget.useRVVForFixedLengthVectors()) {
@@ -831,23 +750,21 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
           setOperationAction(Op, VT, Expand);
         for (MVT OtherVT : MVT::integer_fixedlen_vector_valuetypes()) {
           setTruncStoreAction(VT, OtherVT, Expand);
-          setLoadExtAction(ISD::EXTLOAD, OtherVT, VT, Expand);
-          setLoadExtAction(ISD::SEXTLOAD, OtherVT, VT, Expand);
-          setLoadExtAction(ISD::ZEXTLOAD, OtherVT, VT, Expand);
+          setLoadExtAction({ISD::EXTLOAD, ISD::SEXTLOAD, ISD::ZEXTLOAD},
+                           OtherVT, VT, Expand);
        }
 
        // We use EXTRACT_SUBVECTOR as a "cast" from scalable to fixed.
-        setOperationAction(ISD::INSERT_SUBVECTOR, VT, Custom);
-        setOperationAction(ISD::EXTRACT_SUBVECTOR, VT, Custom);
+        setOperationAction({ISD::INSERT_SUBVECTOR, ISD::EXTRACT_SUBVECTOR}, VT,
+                           Custom);
 
-        setOperationAction(ISD::BUILD_VECTOR, VT, Custom);
-        setOperationAction(ISD::CONCAT_VECTORS, VT, Custom);
+        setOperationAction({ISD::BUILD_VECTOR, ISD::CONCAT_VECTORS}, VT,
+                           Custom);
 
-        setOperationAction(ISD::INSERT_VECTOR_ELT, VT, Custom);
-        setOperationAction(ISD::EXTRACT_VECTOR_ELT, VT, Custom);
+        setOperationAction({ISD::INSERT_VECTOR_ELT, ISD::EXTRACT_VECTOR_ELT},
+                           VT, Custom);
 
-        setOperationAction(ISD::LOAD, VT, Custom);
-        setOperationAction(ISD::STORE, VT, Custom);
+        setOperationAction({ISD::LOAD, ISD::STORE}, VT, Custom);
 
        setOperationAction(ISD::SETCC, VT, Custom);
@@ -857,100 +774,80 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
 
        setOperationAction(ISD::BITCAST, VT, Custom);
 
-        setOperationAction(ISD::VECREDUCE_AND, VT, Custom);
-        setOperationAction(ISD::VECREDUCE_OR, VT, Custom);
-        setOperationAction(ISD::VECREDUCE_XOR, VT, Custom);
+        setOperationAction(
+            {ISD::VECREDUCE_AND, ISD::VECREDUCE_OR, ISD::VECREDUCE_XOR}, VT,
+            Custom);
 
-        setOperationAction(ISD::VP_REDUCE_AND, VT, Custom);
-        setOperationAction(ISD::VP_REDUCE_OR, VT, Custom);
-        setOperationAction(ISD::VP_REDUCE_XOR, VT, Custom);
+        setOperationAction(
+            {ISD::VP_REDUCE_AND, ISD::VP_REDUCE_OR, ISD::VP_REDUCE_XOR}, VT,
+            Custom);
 
-        setOperationAction(ISD::SINT_TO_FP, VT, Custom);
-        setOperationAction(ISD::UINT_TO_FP, VT, Custom);
-        setOperationAction(ISD::FP_TO_SINT, VT, Custom);
-        setOperationAction(ISD::FP_TO_UINT, VT, Custom);
+        setOperationAction({ISD::SINT_TO_FP, ISD::UINT_TO_FP, ISD::FP_TO_SINT,
+                            ISD::FP_TO_UINT},
+                           VT, Custom);
 
        // Operations below are different for between masks and other vectors.
        if (VT.getVectorElementType() == MVT::i1) {
-          setOperationAction(ISD::VP_AND, VT, Custom);
-          setOperationAction(ISD::VP_OR, VT, Custom);
-          setOperationAction(ISD::VP_XOR, VT, Custom);
-          setOperationAction(ISD::AND, VT, Custom);
-          setOperationAction(ISD::OR, VT, Custom);
-          setOperationAction(ISD::XOR, VT, Custom);
+          setOperationAction({ISD::VP_AND, ISD::VP_OR, ISD::VP_XOR, ISD::AND,
+                              ISD::OR, ISD::XOR},
+                             VT, Custom);
+
+          setOperationAction(
+              {ISD::VP_FPTOSI, ISD::VP_FPTOUI, ISD::VP_SETCC, ISD::VP_TRUNCATE},
+              VT, Custom);
          continue;
        }
 
-        // Use SPLAT_VECTOR to prevent type legalization from destroying the
-        // splats when type legalizing i64 scalar on RV32.
+        // Make SPLAT_VECTOR Legal so DAGCombine will convert splat vectors to
+        // it before type legalization for i64 vectors on RV32. It will then be
+        // type legalized to SPLAT_VECTOR_PARTS which we need to Custom handle.
        // FIXME: Use SPLAT_VECTOR for all types? DAGCombine probably needs
        // improvements first.
        if (!Subtarget.is64Bit() && VT.getVectorElementType() == MVT::i64) {
-          setOperationAction(ISD::SPLAT_VECTOR, VT, Custom);
+          setOperationAction(ISD::SPLAT_VECTOR, VT, Legal);
          setOperationAction(ISD::SPLAT_VECTOR_PARTS, VT, Custom);
        }
 
        setOperationAction(ISD::VECTOR_SHUFFLE, VT, Custom);
        setOperationAction(ISD::INSERT_VECTOR_ELT, VT, Custom);
 
-        setOperationAction(ISD::MLOAD, VT, Custom);
-        setOperationAction(ISD::MSTORE, VT, Custom);
-        setOperationAction(ISD::MGATHER, VT, Custom);
-        setOperationAction(ISD::MSCATTER, VT, Custom);
-
-        setOperationAction(ISD::VP_LOAD, VT, Custom);
-        setOperationAction(ISD::VP_STORE, VT, Custom);
-        setOperationAction(ISD::VP_GATHER, VT, Custom);
-        setOperationAction(ISD::VP_SCATTER, VT, Custom);
-
-        setOperationAction(ISD::ADD, VT, Custom);
-        setOperationAction(ISD::MUL, VT, Custom);
-        setOperationAction(ISD::SUB, VT, Custom);
-        setOperationAction(ISD::AND, VT, Custom);
-        setOperationAction(ISD::OR, VT, Custom);
-        setOperationAction(ISD::XOR, VT, Custom);
-        setOperationAction(ISD::SDIV, VT, Custom);
-        setOperationAction(ISD::SREM, VT, Custom);
-        setOperationAction(ISD::UDIV, VT, Custom);
-        setOperationAction(ISD::UREM, VT, Custom);
-        setOperationAction(ISD::SHL, VT, Custom);
-        setOperationAction(ISD::SRA, VT, Custom);
-        setOperationAction(ISD::SRL, VT, Custom);
-
-        setOperationAction(ISD::SMIN, VT, Custom);
-        setOperationAction(ISD::SMAX, VT, Custom);
-        setOperationAction(ISD::UMIN, VT, Custom);
-        setOperationAction(ISD::UMAX, VT, Custom);
-        setOperationAction(ISD::ABS, VT, Custom);
+        setOperationAction(
+            {ISD::MLOAD, ISD::MSTORE, ISD::MGATHER, ISD::MSCATTER}, VT, Custom);
+
+        setOperationAction(
+            {ISD::VP_LOAD, ISD::VP_STORE, ISD::VP_GATHER, ISD::VP_SCATTER}, VT,
+            Custom);
+
+        setOperationAction({ISD::ADD, ISD::MUL, ISD::SUB, ISD::AND, ISD::OR,
+                            ISD::XOR, ISD::SDIV, ISD::SREM, ISD::UDIV,
+                            ISD::UREM, ISD::SHL, ISD::SRA, ISD::SRL},
+                           VT, Custom);
+
+        setOperationAction(
+            {ISD::SMIN, ISD::SMAX, ISD::UMIN, ISD::UMAX, ISD::ABS}, VT, Custom);
 
        // vXi64 MULHS/MULHU requires the V extension instead of Zve64*.
-        if (VT.getVectorElementType() != MVT::i64 || Subtarget.hasStdExtV()) {
-          setOperationAction(ISD::MULHS, VT, Custom);
-          setOperationAction(ISD::MULHU, VT, Custom);
-        }
+        if (VT.getVectorElementType() != MVT::i64 || Subtarget.hasStdExtV())
+          setOperationAction({ISD::MULHS, ISD::MULHU}, VT, Custom);
 
-        setOperationAction(ISD::SADDSAT, VT, Custom);
-        setOperationAction(ISD::UADDSAT, VT, Custom);
-        setOperationAction(ISD::SSUBSAT, VT, Custom);
-        setOperationAction(ISD::USUBSAT, VT, Custom);
+        setOperationAction(
+            {ISD::SADDSAT, ISD::UADDSAT, ISD::SSUBSAT, ISD::USUBSAT}, VT,
+            Custom);
 
        setOperationAction(ISD::VSELECT, VT, Custom);
        setOperationAction(ISD::SELECT_CC, VT, Expand);
 
-        setOperationAction(ISD::ANY_EXTEND, VT, Custom);
-        setOperationAction(ISD::SIGN_EXTEND, VT, Custom);
-        setOperationAction(ISD::ZERO_EXTEND, VT, Custom);
+        setOperationAction(
+            {ISD::ANY_EXTEND, ISD::SIGN_EXTEND, ISD::ZERO_EXTEND}, VT, Custom);
 
        // Custom-lower reduction operations to set up the corresponding custom
        // nodes' operands.
-        setOperationAction(ISD::VECREDUCE_ADD, VT, Custom);
-        setOperationAction(ISD::VECREDUCE_SMAX, VT, Custom);
-        setOperationAction(ISD::VECREDUCE_SMIN, VT, Custom);
-        setOperationAction(ISD::VECREDUCE_UMAX, VT, Custom);
-        setOperationAction(ISD::VECREDUCE_UMIN, VT, Custom);
+        setOperationAction({ISD::VECREDUCE_ADD, ISD::VECREDUCE_SMAX,
+                            ISD::VECREDUCE_SMIN, ISD::VECREDUCE_UMAX,
+                            ISD::VECREDUCE_UMIN},
+                           VT, Custom);
 
-        for (unsigned VPOpc : IntegerVPOps)
-          setOperationAction(VPOpc, VT, Custom);
+        setOperationAction(IntegerVPOps, VT, Custom);
 
        // Lower CTLZ_ZERO_UNDEF and CTTZ_ZERO_UNDEF if we have a floating point
        // type that can represent the value exactly.
@@ -959,10 +856,9 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
              VT.getVectorElementType() == MVT::i32 ? MVT::f64 : MVT::f32;
          EVT FloatVT =
              MVT::getVectorVT(FloatEltVT, VT.getVectorElementCount());
-          if (isTypeLegal(FloatVT)) {
-            setOperationAction(ISD::CTLZ_ZERO_UNDEF, VT, Custom);
-            setOperationAction(ISD::CTTZ_ZERO_UNDEF, VT, Custom);
-          }
+          if (isTypeLegal(FloatVT))
+            setOperationAction({ISD::CTLZ_ZERO_UNDEF, ISD::CTTZ_ZERO_UNDEF},
+                               VT, Custom);
        }
      }
 
@@ -979,69 +875,50 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
        }
 
        // We use EXTRACT_SUBVECTOR as a "cast" from scalable to fixed.
-        setOperationAction(ISD::INSERT_SUBVECTOR, VT, Custom);
-        setOperationAction(ISD::EXTRACT_SUBVECTOR, VT, Custom);
+        setOperationAction({ISD::INSERT_SUBVECTOR, ISD::EXTRACT_SUBVECTOR}, VT,
+                           Custom);
 
-        setOperationAction(ISD::BUILD_VECTOR, VT, Custom);
-        setOperationAction(ISD::CONCAT_VECTORS, VT, Custom);
-        setOperationAction(ISD::VECTOR_SHUFFLE, VT, Custom);
-        setOperationAction(ISD::INSERT_VECTOR_ELT, VT, Custom);
-        setOperationAction(ISD::EXTRACT_VECTOR_ELT, VT, Custom);
-
-        setOperationAction(ISD::LOAD, VT, Custom);
-        setOperationAction(ISD::STORE, VT, Custom);
-        setOperationAction(ISD::MLOAD, VT, Custom);
-        setOperationAction(ISD::MSTORE, VT, Custom);
-        setOperationAction(ISD::MGATHER, VT, Custom);
-        setOperationAction(ISD::MSCATTER, VT, Custom);
-
-        setOperationAction(ISD::VP_LOAD, VT, Custom);
-        setOperationAction(ISD::VP_STORE, VT, Custom);
-        setOperationAction(ISD::VP_GATHER, VT, Custom);
-        setOperationAction(ISD::VP_SCATTER, VT, Custom);
-
-        setOperationAction(ISD::FADD, VT, Custom);
-        setOperationAction(ISD::FSUB, VT, Custom);
-        setOperationAction(ISD::FMUL, VT, Custom);
-        setOperationAction(ISD::FDIV, VT, Custom);
-        setOperationAction(ISD::FNEG, VT, Custom);
-        setOperationAction(ISD::FABS, VT, Custom);
-        setOperationAction(ISD::FCOPYSIGN, VT, Custom);
-        setOperationAction(ISD::FSQRT, VT, Custom);
-        setOperationAction(ISD::FMA, VT, Custom);
-        setOperationAction(ISD::FMINNUM, VT, Custom);
-        setOperationAction(ISD::FMAXNUM, VT, Custom);
-
-        setOperationAction(ISD::FP_ROUND, VT, Custom);
-        setOperationAction(ISD::FP_EXTEND, VT, Custom);
-
-        setOperationAction(ISD::FTRUNC, VT, Custom);
-        setOperationAction(ISD::FCEIL, VT, Custom);
-        setOperationAction(ISD::FFLOOR, VT, Custom);
+        setOperationAction({ISD::BUILD_VECTOR, ISD::CONCAT_VECTORS,
+                            ISD::VECTOR_SHUFFLE, ISD::INSERT_VECTOR_ELT,
+                            ISD::EXTRACT_VECTOR_ELT},
+                           VT, Custom);
+
+        setOperationAction({ISD::LOAD, ISD::STORE, ISD::MLOAD, ISD::MSTORE,
+                            ISD::MGATHER, ISD::MSCATTER},
+                           VT, Custom);
+
+        setOperationAction(
+            {ISD::VP_LOAD, ISD::VP_STORE, ISD::VP_GATHER, ISD::VP_SCATTER}, VT,
+            Custom);
+
+        setOperationAction({ISD::FADD, ISD::FSUB, ISD::FMUL, ISD::FDIV,
+                            ISD::FNEG, ISD::FABS, ISD::FCOPYSIGN, ISD::FSQRT,
+                            ISD::FMA, ISD::FMINNUM, ISD::FMAXNUM},
+                           VT, Custom);
+
+        setOperationAction({ISD::FP_ROUND, ISD::FP_EXTEND}, VT, Custom);
+
+        setOperationAction({ISD::FTRUNC, ISD::FCEIL, ISD::FFLOOR, ISD::FROUND},
+                           VT, Custom);
 
        for (auto CC : VFPCCToExpand)
          setCondCodeAction(CC, VT, Expand);
 
-        setOperationAction(ISD::VSELECT, VT, Custom);
-        setOperationAction(ISD::SELECT, VT, Custom);
+        setOperationAction({ISD::VSELECT, ISD::SELECT}, VT, Custom);
        setOperationAction(ISD::SELECT_CC, VT, Expand);
 
        setOperationAction(ISD::BITCAST, VT, Custom);
 
-        setOperationAction(ISD::VECREDUCE_FADD, VT, Custom);
-        setOperationAction(ISD::VECREDUCE_SEQ_FADD, VT, Custom);
-        setOperationAction(ISD::VECREDUCE_FMIN, VT, Custom);
-        setOperationAction(ISD::VECREDUCE_FMAX, VT, Custom);
+        setOperationAction({ISD::VECREDUCE_FADD, ISD::VECREDUCE_SEQ_FADD,
+                            ISD::VECREDUCE_FMIN, ISD::VECREDUCE_FMAX},
+                           VT, Custom);
 
-        for (unsigned VPOpc : FloatingPointVPOps)
-          setOperationAction(VPOpc, VT, Custom);
+        setOperationAction(FloatingPointVPOps, VT, Custom);
      }
 
      // Custom-legalize bitcasts from fixed-length vectors to scalar types.
-      setOperationAction(ISD::BITCAST, MVT::i8, Custom);
-      setOperationAction(ISD::BITCAST, MVT::i16, Custom);
-      setOperationAction(ISD::BITCAST, MVT::i32, Custom);
-      setOperationAction(ISD::BITCAST, MVT::i64, Custom);
+      setOperationAction(ISD::BITCAST, {MVT::i8, MVT::i16, MVT::i32, MVT::i64},
+                         Custom);
      if (Subtarget.hasStdExtZfh())
        setOperationAction(ISD::BITCAST, MVT::f16, Custom);
      if (Subtarget.hasStdExtF())
@@ -1061,30 +938,33 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
   // Jumps are expensive, compared to logic
   setJumpIsExpensive();
 
-  setTargetDAGCombine(ISD::ADD);
-  setTargetDAGCombine(ISD::SUB);
-  setTargetDAGCombine(ISD::AND);
-  setTargetDAGCombine(ISD::OR);
-  setTargetDAGCombine(ISD::XOR);
-  setTargetDAGCombine(ISD::ANY_EXTEND);
-  if (Subtarget.hasStdExtF()) {
-    setTargetDAGCombine(ISD::ZERO_EXTEND);
-    setTargetDAGCombine(ISD::FP_TO_SINT);
-    setTargetDAGCombine(ISD::FP_TO_UINT);
-    setTargetDAGCombine(ISD::FP_TO_SINT_SAT);
-    setTargetDAGCombine(ISD::FP_TO_UINT_SAT);
-  }
-  if (Subtarget.hasVInstructions()) {
-    setTargetDAGCombine(ISD::FCOPYSIGN);
-    setTargetDAGCombine(ISD::MGATHER);
-    setTargetDAGCombine(ISD::MSCATTER);
-    setTargetDAGCombine(ISD::VP_GATHER);
-    setTargetDAGCombine(ISD::VP_SCATTER);
+  setTargetDAGCombine({ISD::INTRINSIC_WO_CHAIN, ISD::ADD, ISD::SUB, ISD::AND,
+                       ISD::OR, ISD::XOR});
+  if (Subtarget.is64Bit())
     setTargetDAGCombine(ISD::SRA);
-    setTargetDAGCombine(ISD::SRL);
-    setTargetDAGCombine(ISD::SHL);
-    setTargetDAGCombine(ISD::STORE);
-  }
+
+  if (Subtarget.hasStdExtF())
+    setTargetDAGCombine({ISD::FADD, ISD::FMAXNUM, ISD::FMINNUM});
+
+  if (Subtarget.hasStdExtZbp())
+    setTargetDAGCombine({ISD::ROTL, ISD::ROTR});
+
+  if (Subtarget.hasStdExtZbb())
+    setTargetDAGCombine({ISD::UMAX, ISD::UMIN, ISD::SMAX, ISD::SMIN});
+
+  if (Subtarget.hasStdExtZbkb())
+    setTargetDAGCombine(ISD::BITREVERSE);
+  if (Subtarget.hasStdExtZfh() || Subtarget.hasStdExtZbb())
+    setTargetDAGCombine(ISD::SIGN_EXTEND_INREG);
+  if (Subtarget.hasStdExtF())
+    setTargetDAGCombine({ISD::ZERO_EXTEND, ISD::FP_TO_SINT, ISD::FP_TO_UINT,
+                         ISD::FP_TO_SINT_SAT, ISD::FP_TO_UINT_SAT});
+  if (Subtarget.hasVInstructions())
+    setTargetDAGCombine({ISD::FCOPYSIGN, ISD::MGATHER, ISD::MSCATTER,
+                         ISD::VP_GATHER, ISD::VP_SCATTER, ISD::SRA, ISD::SRL,
+                         ISD::SHL, ISD::STORE, ISD::SPLAT_VECTOR});
+  if (Subtarget.useRVVForFixedLengthVectors())
+    setTargetDAGCombine(ISD::BITCAST);
 
   setLibcallName(RTLIB::FPEXT_F16_F32, "__extendhfsf2");
   setLibcallName(RTLIB::FPROUND_F32_F16, "__truncsfhf2");
@@ -1149,6 +1029,24 @@ bool RISCVTargetLowering::getTgtMemIntrinsic(IntrinsicInfo &Info,
     Info.size = MemoryLocation::UnknownSize;
     Info.flags |= MachineMemOperand::MOStore;
     return true;
+  case Intrinsic::riscv_seg2_load:
+  case Intrinsic::riscv_seg3_load:
+  case Intrinsic::riscv_seg4_load:
+  case Intrinsic::riscv_seg5_load:
+  case Intrinsic::riscv_seg6_load:
+  case Intrinsic::riscv_seg7_load:
+  case Intrinsic::riscv_seg8_load:
+    Info.opc = ISD::INTRINSIC_W_CHAIN;
+    Info.ptrVal = I.getArgOperand(0);
+    Info.memVT =
+        getValueType(DL, I.getType()->getStructElementType(0)->getScalarType());
+    Info.align =
+        Align(DL.getTypeSizeInBits(
+                  I.getType()->getStructElementType(0)->getScalarType()) /
+              8);
+    Info.size = MemoryLocation::UnknownSize;
+    Info.flags |= MachineMemOperand::MOLoad;
+    return true;
   }
 }
 
@@ -1160,6 +1058,10 @@ bool RISCVTargetLowering::isLegalAddressingMode(const DataLayout &DL,
   if (AM.BaseGV)
     return false;
 
+  // RVV instructions only support register addressing.
+  if (Subtarget.hasVInstructions() && isa<VectorType>(Ty))
+    return AM.HasBaseReg && AM.Scale == 0 && !AM.BaseOffs;
+
   // Require a 12-bit signed offset.
   if (!isInt<12>(AM.BaseOffs))
     return false;
@@ -1225,6 +1127,10 @@ bool RISCVTargetLowering::isSExtCheaperThanZExt(EVT SrcVT, EVT DstVT) const {
   return Subtarget.is64Bit() && SrcVT == MVT::i32 && DstVT == MVT::i64;
 }
 
+bool RISCVTargetLowering::signExtendConstant(const ConstantInt *CI) const {
+  return Subtarget.is64Bit() && CI->getType()->isIntegerTy(32);
+}
+
 bool RISCVTargetLowering::isCheapToSpeculateCttz() const {
   return Subtarget.hasStdExtZbb();
 }
@@ -1245,6 +1151,36 @@ bool RISCVTargetLowering::hasAndNotCompare(SDValue Y) const {
          !isa<ConstantSDNode>(Y);
 }
 
+bool RISCVTargetLowering::hasBitTest(SDValue X, SDValue Y) const {
+  // We can use ANDI+SEQZ/SNEZ as a bit test. Y contains the bit position.
+  auto *C = dyn_cast<ConstantSDNode>(Y);
+  return C && C->getAPIntValue().ule(10);
+}
+
+bool RISCVTargetLowering::
+    shouldProduceAndByConstByHoistingConstFromShiftsLHSOfAnd(
+        SDValue X, ConstantSDNode *XC, ConstantSDNode *CC, SDValue Y,
+        unsigned OldShiftOpcode, unsigned NewShiftOpcode,
+        SelectionDAG &DAG) const {
+  // One interesting pattern that we'd want to form is 'bit extract':
+  //   ((1 >> Y) & 1) ==/!= 0
+  // But we also need to be careful not to try to reverse that fold.
+
+  // Is this '((1 >> Y) & 1)'?
+  if (XC && OldShiftOpcode == ISD::SRL && XC->isOne())
+    return false; // Keep the 'bit extract' pattern.
+
+  // Will this be '((1 >> Y) & 1)' after the transform?
+  if (NewShiftOpcode == ISD::SRL && CC->isOne())
+    return true; // Do form the 'bit extract' pattern.
+
+  // If 'X' is a constant, and we transform, then we will immediately
+  // try to undo the fold, thus causing endless combine loop.
+  // So only do the transform if X is not a constant. This matches the default
+  // implementation of this function.
+  return !XC;
+}
+
 /// Check if sinking \p I's operands to I's basic block is profitable, because
 /// the operands can be folded into a target instruction, e.g.
 /// splats of scalars can fold into vector instructions.
```
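One detail worth unpacking from the hunks above: the new `hasBitTest` only accepts constant bit positions with `C->getAPIntValue().ule(10)`. The in-code comment says the test is formed as ANDI+SEQZ/SNEZ, and the bound presumably follows from ANDI taking a 12-bit signed immediate: the single-bit mask `1 << pos` must fit in [-2048, 2047], and `1 << 11 = 2048` just misses. A standalone check of that reasoning (plain C++, no LLVM; the simm12 rationale is my inference, hedged accordingly):

```cpp
// Illustration of why bit positions 0..10 qualify for an ANDI-based bit test,
// assuming the ule(10) bound exists because ANDI's immediate is simm12.
#include <cstdint>
#include <cstdio>

bool fitsSImm12(int64_t v) { return v >= -2048 && v <= 2047; }

bool bitTest(uint64_t x, unsigned pos) { return (x >> pos) & 1; }

int main() {
  for (unsigned pos : {0u, 10u, 11u}) {
    int64_t mask = int64_t(1) << pos;
    std::printf("pos %2u: mask %4lld fits simm12? %s\n", pos, (long long)mask,
                fitsSImm12(mask) ? "yes" : "no");
  }
  // Both forms of the test agree; the constant-mask form is what
  // `andi tmp, x, 1<<pos` + seqz/snez can encode directly.
  uint64_t x = 0x400; // bit 10 set
  std::printf("bitTest=%d masked=%d\n", bitTest(x, 10),
              (x & (uint64_t(1) << 10)) != 0);
}
```

The diff continues with the splat-sinking and FROUND lowering changes: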
@@ -1282,6 +1218,7 @@ bool RISCVTargetLowering::shouldSinkOperands( if (auto *II = dyn_cast<IntrinsicInst>(I)) { switch (II->getIntrinsicID()) { case Intrinsic::fma: + case Intrinsic::vp_fma: return Operand == 0 || Operand == 1; // FIXME: Our patterns can only match vx/vf instructions when the splat // it on the RHS, because TableGen doesn't recognize our VP operations @@ -1345,6 +1282,15 @@ bool RISCVTargetLowering::shouldSinkOperands( return true; } +bool RISCVTargetLowering::isOffsetFoldingLegal( + const GlobalAddressSDNode *GA) const { + // In order to maximise the opportunity for common subexpression elimination, + // keep a separate ADD node for the global address offset instead of folding + // it in the global address node. Later peephole optimisations may choose to + // fold it back in when profitable. + return false; +} + bool RISCVTargetLowering::isFPImmLegal(const APFloat &Imm, EVT VT, bool ForCodeSize) const { // FIXME: Change to Zfhmin once f16 becomes a legal type with Zfhmin. @@ -1583,7 +1529,7 @@ static bool useRVVForFixedLengthVectorVT(MVT VT, if (VT.getFixedSizeInBits() > 1024 * 8) return false; - unsigned MinVLen = Subtarget.getMinRVVVectorSizeInBits(); + unsigned MinVLen = Subtarget.getRealMinVLen(); MVT EltVT = VT.getVectorElementType(); @@ -1621,7 +1567,7 @@ static bool useRVVForFixedLengthVectorVT(MVT VT, } // Reject elements larger than ELEN. - if (EltVT.getSizeInBits() > Subtarget.getMaxELENForFixedLengthVectors()) + if (EltVT.getSizeInBits() > Subtarget.getELEN()) return false; unsigned LMul = divideCeil(VT.getSizeInBits(), MinVLen); @@ -1649,8 +1595,8 @@ static MVT getContainerForFixedLengthVector(const TargetLowering &TLI, MVT VT, useRVVForFixedLengthVectorVT(VT, Subtarget)) && "Expected legal fixed length vector!"); - unsigned MinVLen = Subtarget.getMinRVVVectorSizeInBits(); - unsigned MaxELen = Subtarget.getMaxELENForFixedLengthVectors(); + unsigned MinVLen = Subtarget.getRealMinVLen(); + unsigned MaxELen = Subtarget.getELEN(); MVT EltVT = VT.getVectorElementType(); switch (EltVT.SimpleTy) { @@ -1710,6 +1656,23 @@ static SDValue convertFromScalableVector(EVT VT, SDValue V, SelectionDAG &DAG, return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, V, Zero); } +/// Return the type of the mask type suitable for masking the provided +/// vector type. This is simply an i1 element type vector of the same +/// (possibly scalable) length. +static MVT getMaskTypeFor(EVT VecVT) { + assert(VecVT.isVector()); + ElementCount EC = VecVT.getVectorElementCount(); + return MVT::getVectorVT(MVT::i1, EC); +} + +/// Creates an all ones mask suitable for masking a vector of type VecTy with +/// vector length VL. . +static SDValue getAllOnesMask(MVT VecVT, SDValue VL, SDLoc DL, + SelectionDAG &DAG) { + MVT MaskVT = getMaskTypeFor(VecVT); + return DAG.getNode(RISCVISD::VMSET_VL, DL, MaskVT, VL); +} + // Gets the two common "VL" operands: an all-ones mask and the vector length. // VecVT is a vector type, either fixed-length or scalable, and ContainerVT is // the vector type that it is contained in. @@ -1720,9 +1683,8 @@ getDefaultVLOps(MVT VecVT, MVT ContainerVT, SDLoc DL, SelectionDAG &DAG, MVT XLenVT = Subtarget.getXLenVT(); SDValue VL = VecVT.isFixedLengthVector() ? 
DAG.getConstant(VecVT.getVectorNumElements(), DL, XLenVT) - : DAG.getTargetConstant(RISCV::VLMaxSentinel, DL, XLenVT); - MVT MaskVT = MVT::getVectorVT(MVT::i1, ContainerVT.getVectorElementCount()); - SDValue Mask = DAG.getNode(RISCVISD::VMSET_VL, DL, MaskVT, VL); + : DAG.getRegister(RISCV::X0, XLenVT); + SDValue Mask = getAllOnesMask(ContainerVT, VL, DL, DAG); return {Mask, VL}; } @@ -1747,14 +1709,6 @@ bool RISCVTargetLowering::shouldExpandBuildVectorWithShuffles( return false; } -bool RISCVTargetLowering::isShuffleMaskLegal(ArrayRef<int> M, EVT VT) const { - // Only splats are currently supported. - if (ShuffleVectorSDNode::isSplatMask(M.data(), VT)) - return true; - - return false; -} - static SDValue lowerFP_TO_INT_SAT(SDValue Op, SelectionDAG &DAG, const RISCVSubtarget &Subtarget) { // RISCV FP-to-int conversions saturate to the destination register size, but @@ -1796,7 +1750,7 @@ static SDValue lowerFTRUNC_FCEIL_FFLOOR(SDValue Op, SelectionDAG &DAG) { SDLoc DL(Op); // Freeze the source since we are increasing the number of uses. - SDValue Src = DAG.getNode(ISD::FREEZE, DL, VT, Op.getOperand(0)); + SDValue Src = DAG.getFreeze(Op.getOperand(0)); // Truncate to integer and convert back to FP. MVT IntVT = VT.changeVectorElementTypeToInteger(); @@ -1844,21 +1798,56 @@ static SDValue lowerFTRUNC_FCEIL_FFLOOR(SDValue Op, SelectionDAG &DAG) { return DAG.getSelect(DL, VT, Setcc, Truncated, Src); } -static SDValue lowerSPLAT_VECTOR(SDValue Op, SelectionDAG &DAG, - const RISCVSubtarget &Subtarget) { +// ISD::FROUND is defined to round to nearest with ties rounding away from 0. +// This mode isn't supported in vector hardware on RISCV. But as long as we +// aren't compiling with trapping math, we can emulate this with +// floor(X + copysign(nextafter(0.5, 0.0), X)). +// FIXME: Could be shorter by changing rounding mode, but we don't have FRM +// dependencies modeled yet. +// FIXME: Use masked operations to avoid final merge. +static SDValue lowerFROUND(SDValue Op, SelectionDAG &DAG) { MVT VT = Op.getSimpleValueType(); - assert(VT.isFixedLengthVector() && "Unexpected vector!"); - - MVT ContainerVT = getContainerForFixedLengthVector(DAG, VT, Subtarget); + assert(VT.isVector() && "Unexpected type"); SDLoc DL(Op); - SDValue Mask, VL; - std::tie(Mask, VL) = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget); - unsigned Opc = - VT.isFloatingPoint() ? RISCVISD::VFMV_V_F_VL : RISCVISD::VMV_V_X_VL; - SDValue Splat = DAG.getNode(Opc, DL, ContainerVT, Op.getOperand(0), VL); - return convertFromScalableVector(VT, Splat, DAG, Subtarget); + // Freeze the source since we are increasing the number of uses. + SDValue Src = DAG.getFreeze(Op.getOperand(0)); + + // We do the conversion on the absolute value and fix the sign at the end. + SDValue Abs = DAG.getNode(ISD::FABS, DL, VT, Src); + + const fltSemantics &FltSem = DAG.EVTToAPFloatSemantics(VT); + bool Ignored; + APFloat Point5Pred = APFloat(0.5f); + Point5Pred.convert(FltSem, APFloat::rmNearestTiesToEven, &Ignored); + Point5Pred.next(/*nextDown*/ true); + + // Add the adjustment. + SDValue Adjust = DAG.getNode(ISD::FADD, DL, VT, Abs, + DAG.getConstantFP(Point5Pred, DL, VT)); + + // Truncate to integer and convert back to fp. + MVT IntVT = VT.changeVectorElementTypeToInteger(); + SDValue Truncated = DAG.getNode(ISD::FP_TO_SINT, DL, IntVT, Adjust); + Truncated = DAG.getNode(ISD::SINT_TO_FP, DL, VT, Truncated); + + // Restore the original sign. 
+ Truncated = DAG.getNode(ISD::FCOPYSIGN, DL, VT, Truncated, Src); + + // Determine the largest integer that can be represented exactly. This and + // values larger than it don't have any fractional bits so don't need to + // be converted. + unsigned Precision = APFloat::semanticsPrecision(FltSem); + APFloat MaxVal = APFloat(FltSem); + MaxVal.convertFromAPInt(APInt::getOneBitSet(Precision, Precision - 1), + /*IsSigned*/ false, APFloat::rmNearestTiesToEven); + SDValue MaxValNode = DAG.getConstantFP(MaxVal, DL, VT); + + // If abs(Src) was larger than MaxVal or nan, keep it. + MVT SetccVT = MVT::getVectorVT(MVT::i1, VT.getVectorElementCount()); + SDValue Setcc = DAG.getSetCC(DL, SetccVT, Abs, MaxValNode, ISD::SETOLT); + return DAG.getSelect(DL, VT, Setcc, Truncated, Src); } struct VIDSequence { @@ -1908,37 +1897,27 @@ static Optional<VIDSequence> isSimpleVIDSequence(SDValue Op) { // A zero-value value difference means that we're somewhere in the middle // of a fractional step, e.g. <0,0,0*,0,1,1,1,1>. Wait until we notice a // step change before evaluating the sequence. - if (ValDiff != 0) { - int64_t Remainder = ValDiff % IdxDiff; - // Normalize the step if it's greater than 1. - if (Remainder != ValDiff) { - // The difference must cleanly divide the element span. - if (Remainder != 0) - return None; - ValDiff /= IdxDiff; - IdxDiff = 1; - } - - if (!SeqStepNum) - SeqStepNum = ValDiff; - else if (ValDiff != SeqStepNum) - return None; + if (ValDiff == 0) + continue; - if (!SeqStepDenom) - SeqStepDenom = IdxDiff; - else if (IdxDiff != *SeqStepDenom) + int64_t Remainder = ValDiff % IdxDiff; + // Normalize the step if it's greater than 1. + if (Remainder != ValDiff) { + // The difference must cleanly divide the element span. + if (Remainder != 0) return None; + ValDiff /= IdxDiff; + IdxDiff = 1; } - } - // Record and/or check any addend. - if (SeqStepNum && SeqStepDenom) { - uint64_t ExpectedVal = - (int64_t)(Idx * (uint64_t)*SeqStepNum) / *SeqStepDenom; - int64_t Addend = SignExtend64(Val - ExpectedVal, EltSizeInBits); - if (!SeqAddend) - SeqAddend = Addend; - else if (SeqAddend != Addend) + if (!SeqStepNum) + SeqStepNum = ValDiff; + else if (ValDiff != SeqStepNum) + return None; + + if (!SeqStepDenom) + SeqStepDenom = IdxDiff; + else if (IdxDiff != *SeqStepDenom) return None; } @@ -1946,14 +1925,68 @@ static Optional<VIDSequence> isSimpleVIDSequence(SDValue Op) { if (!PrevElt || PrevElt->first != Val) PrevElt = std::make_pair(Val, Idx); } - // We need to have logged both a step and an addend for this to count as - // a legal index sequence. - if (!SeqStepNum || !SeqStepDenom || !SeqAddend) + + // We need to have logged a step for this to count as a legal index sequence. + if (!SeqStepNum || !SeqStepDenom) return None; + // Loop back through the sequence and validate elements we might have skipped + // while waiting for a valid step. While doing this, log any sequence addend. 
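// A worked instance of the recognizer (an illustrative sketch, independent of
// the DAG): <1,3,5,7> steps by ValDiff/IdxDiff = 2/1 between its defined
// neighbours, so SeqStepNum/SeqStepDenom = 2/1, and the loop below then checks
// each element against Idx * StepNum / StepDenom plus a single addend.
static bool matchesVIDTimes2Plus1(const long long *Vals, unsigned NumElts) {
  for (unsigned Idx = 0; Idx < NumElts; Idx++)
    if (Vals[Idx] != (long long)Idx * 2 / 1 + 1)
      return false; // {1,3,5,7} passes; {1,3,6,7} fails at Idx == 2
  return true;
}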
+ for (unsigned Idx = 0; Idx < NumElts; Idx++) { + if (Op.getOperand(Idx).isUndef()) + continue; + uint64_t Val = Op.getConstantOperandVal(Idx) & + maskTrailingOnes<uint64_t>(EltSizeInBits); + uint64_t ExpectedVal = + (int64_t)(Idx * (uint64_t)*SeqStepNum) / *SeqStepDenom; + int64_t Addend = SignExtend64(Val - ExpectedVal, EltSizeInBits); + if (!SeqAddend) + SeqAddend = Addend; + else if (Addend != SeqAddend) + return None; + } + + assert(SeqAddend && "Must have an addend if we have a step"); + + return VIDSequence{*SeqStepNum, *SeqStepDenom, *SeqAddend}; } +// Match a splatted value (SPLAT_VECTOR/BUILD_VECTOR) of an EXTRACT_VECTOR_ELT +// and lower it as a VRGATHER_VX_VL from the source vector. +static SDValue matchSplatAsGather(SDValue SplatVal, MVT VT, const SDLoc &DL, + SelectionDAG &DAG, + const RISCVSubtarget &Subtarget) { + if (SplatVal.getOpcode() != ISD::EXTRACT_VECTOR_ELT) + return SDValue(); + SDValue Vec = SplatVal.getOperand(0); + // Only perform this optimization on vectors of the same size for simplicity. + // Don't perform this optimization for i1 vectors. + // FIXME: Support i1 vectors, maybe by promoting to i8? + if (Vec.getValueType() != VT || VT.getVectorElementType() == MVT::i1) + return SDValue(); + SDValue Idx = SplatVal.getOperand(1); + // The index must be a legal type. + if (Idx.getValueType() != Subtarget.getXLenVT()) + return SDValue(); + + MVT ContainerVT = VT; + if (VT.isFixedLengthVector()) { + ContainerVT = getContainerForFixedLengthVector(DAG, VT, Subtarget); + Vec = convertToScalableVector(ContainerVT, Vec, DAG, Subtarget); + } + + SDValue Mask, VL; + std::tie(Mask, VL) = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget); + + SDValue Gather = DAG.getNode(RISCVISD::VRGATHER_VX_VL, DL, ContainerVT, Vec, + Idx, Mask, DAG.getUNDEF(ContainerVT), VL); + + if (!VT.isFixedLengthVector()) + return Gather; + + return convertFromScalableVector(VT, Gather, DAG, Subtarget); +} + static SDValue lowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG, const RISCVSubtarget &Subtarget) { MVT VT = Op.getSimpleValueType(); @@ -1989,8 +2022,7 @@ static SDValue lowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG, // codegen across RV32 and RV64. unsigned NumViaIntegerBits = std::min(std::max(NumElts, 8u), Subtarget.getXLen()); - NumViaIntegerBits = std::min(NumViaIntegerBits, - Subtarget.getMaxELENForFixedLengthVectors()); + NumViaIntegerBits = std::min(NumViaIntegerBits, Subtarget.getELEN()); if (ISD::isBuildVectorOfConstantSDNodes(Op.getNode())) { // If we have to use more than one INSERT_VECTOR_ELT then this // optimization is likely to increase code size; avoid performing it in // such a case. @@ -2012,7 +2044,7 @@ static SDValue lowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG, // our vector and clear our accumulated data. if (I != 0 && I % NumViaIntegerBits == 0) { if (NumViaIntegerBits <= 32) - Bits = SignExtend64(Bits, 32); + Bits = SignExtend64<32>(Bits); SDValue Elt = DAG.getConstant(Bits, DL, XLenVT); Vec = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, IntegerViaVecVT, Vec, Elt, DAG.getConstant(IntegerEltIdx, DL, XLenVT)); @@ -2028,7 +2060,7 @@ static SDValue lowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG, // Insert the (remaining) scalar value into position in our integer // vector type.
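// The packing performed above, as a standalone sketch (the helper name is
// illustrative; assumes NumElts <= 64): a constant v8i1 mask becomes a single
// scalar whose bit I holds element I, so one materialized constant plus a
// bitcast replaces up to NumViaIntegerBits element insertions.
static unsigned long long packMaskBitsSketch(const bool *Elts,
                                             unsigned NumElts) {
  unsigned long long Bits = 0;
  for (unsigned I = 0; I < NumElts; ++I)
    Bits |= (unsigned long long)Elts[I] << I; // element I -> bit I
  return Bits; // {1,0,1,1,0,0,0,1} -> 0x8D
}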
if (NumViaIntegerBits <= 32) - Bits = SignExtend64(Bits, 32); + Bits = SignExtend64<32>(Bits); SDValue Elt = DAG.getConstant(Bits, DL, XLenVT); Vec = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, IntegerViaVecVT, Vec, Elt, DAG.getConstant(IntegerEltIdx, DL, XLenVT)); @@ -2077,9 +2109,12 @@ static SDValue lowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG, } if (SDValue Splat = cast<BuildVectorSDNode>(Op)->getSplatValue()) { + if (auto Gather = matchSplatAsGather(Splat, VT, DL, DAG, Subtarget)) + return Gather; unsigned Opc = VT.isFloatingPoint() ? RISCVISD::VFMV_V_F_VL : RISCVISD::VMV_V_X_VL; - Splat = DAG.getNode(Opc, DL, ContainerVT, Splat, VL); + Splat = + DAG.getNode(Opc, DL, ContainerVT, DAG.getUNDEF(ContainerVT), Splat, VL); return convertFromScalableVector(VT, Splat, DAG, Subtarget); } @@ -2109,7 +2144,8 @@ static SDValue lowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG, // a single addi instruction. if (((StepOpcode == ISD::MUL && isInt<12>(SplatStepVal)) || (StepOpcode == ISD::SHL && isUInt<5>(SplatStepVal))) && - isPowerOf2_32(StepDenominator) && isInt<5>(Addend)) { + isPowerOf2_32(StepDenominator) && + (SplatStepVal >= 0 || StepDenominator == 1) && isInt<5>(Addend)) { SDValue VID = DAG.getNode(RISCVISD::VID_VL, DL, ContainerVT, Mask, VL); // Convert right out of the scalable type so we can use standard ISD // nodes for the rest of the computation. If we used scalable types with @@ -2118,18 +2154,18 @@ static SDValue lowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG, VID = convertFromScalableVector(VT, VID, DAG, Subtarget); if ((StepOpcode == ISD::MUL && SplatStepVal != 1) || (StepOpcode == ISD::SHL && SplatStepVal != 0)) { - SDValue SplatStep = DAG.getSplatVector( + SDValue SplatStep = DAG.getSplatBuildVector( VT, DL, DAG.getConstant(SplatStepVal, DL, XLenVT)); VID = DAG.getNode(StepOpcode, DL, VT, VID, SplatStep); } if (StepDenominator != 1) { - SDValue SplatStep = DAG.getSplatVector( + SDValue SplatStep = DAG.getSplatBuildVector( VT, DL, DAG.getConstant(Log2_64(StepDenominator), DL, XLenVT)); VID = DAG.getNode(ISD::SRL, DL, VT, VID, SplatStep); } if (Addend != 0 || Negate) { - SDValue SplatAddend = - DAG.getSplatVector(VT, DL, DAG.getConstant(Addend, DL, XLenVT)); + SDValue SplatAddend = DAG.getSplatBuildVector( + VT, DL, DAG.getConstant(Addend, DL, XLenVT)); VID = DAG.getNode(Negate ? ISD::SUB : ISD::ADD, DL, VT, SplatAddend, VID); } return VID; @@ -2172,7 +2208,7 @@ static SDValue lowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG, // On RV64, sign-extend from 32 to 64 bits where possible in order to // achieve better constant materialization. if (Subtarget.is64Bit() && ViaIntVT == MVT::i32) - SplatValue = SignExtend64(SplatValue, 32); + SplatValue = SignExtend64<32>(SplatValue); // Since we can't introduce illegal i64 types at this stage, we can only // perform an i64 splat on RV32 if it is its own sign-extended value.
That @@ -2187,6 +2223,7 @@ static SDValue lowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG, getContainerForFixedLengthVector(DAG, ViaVecVT, Subtarget); SDValue Splat = DAG.getNode(RISCVISD::VMV_V_X_VL, DL, ViaContainerVT, + DAG.getUNDEF(ViaContainerVT), DAG.getConstant(SplatValue, DL, XLenVT), ViaVL); Splat = convertFromScalableVector(ViaVecVT, Splat, DAG, Subtarget); return DAG.getBitcast(VT, Splat); @@ -2274,57 +2311,66 @@ static SDValue lowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG, return SDValue(); } -static SDValue splatPartsI64WithVL(const SDLoc &DL, MVT VT, SDValue Lo, - SDValue Hi, SDValue VL, SelectionDAG &DAG) { +static SDValue splatPartsI64WithVL(const SDLoc &DL, MVT VT, SDValue Passthru, + SDValue Lo, SDValue Hi, SDValue VL, + SelectionDAG &DAG) { + if (!Passthru) + Passthru = DAG.getUNDEF(VT); if (isa<ConstantSDNode>(Lo) && isa<ConstantSDNode>(Hi)) { int32_t LoC = cast<ConstantSDNode>(Lo)->getSExtValue(); int32_t HiC = cast<ConstantSDNode>(Hi)->getSExtValue(); // If Hi constant is all the same sign bit as Lo, lower this as a custom // node in order to try and match RVV vector/scalar instructions. if ((LoC >> 31) == HiC) - return DAG.getNode(RISCVISD::VMV_V_X_VL, DL, VT, Lo, VL); + return DAG.getNode(RISCVISD::VMV_V_X_VL, DL, VT, Passthru, Lo, VL); - // If vl is equal to VLMax and Hi constant is equal to Lo, we could use + // If vl is equal to XLEN_MAX and Hi constant is equal to Lo, we could use // vmv.v.x whose EEW = 32 to lower it. auto *Const = dyn_cast<ConstantSDNode>(VL); - if (LoC == HiC && Const && Const->getSExtValue() == RISCV::VLMaxSentinel) { + if (LoC == HiC && Const && Const->isAllOnesValue()) { MVT InterVT = MVT::getVectorVT(MVT::i32, VT.getVectorElementCount() * 2); // TODO: if vl <= min(VLMAX), we can also do this. But we could not // access the subtarget here now. - auto InterVec = DAG.getNode(RISCVISD::VMV_V_X_VL, DL, InterVT, Lo, VL); + auto InterVec = DAG.getNode( + RISCVISD::VMV_V_X_VL, DL, InterVT, DAG.getUNDEF(InterVT), Lo, + DAG.getRegister(RISCV::X0, MVT::i32)); return DAG.getNode(ISD::BITCAST, DL, VT, InterVec); } } // Fall back to a stack store and stride x0 vector load. - return DAG.getNode(RISCVISD::SPLAT_VECTOR_SPLIT_I64_VL, DL, VT, Lo, Hi, VL); + return DAG.getNode(RISCVISD::SPLAT_VECTOR_SPLIT_I64_VL, DL, VT, Passthru, Lo, + Hi, VL); } // Called by type legalization to handle splat of i64 on RV32. // FIXME: We can optimize this when the type has sign or zero bits in one // of the halves. -static SDValue splatSplitI64WithVL(const SDLoc &DL, MVT VT, SDValue Scalar, - SDValue VL, SelectionDAG &DAG) { +static SDValue splatSplitI64WithVL(const SDLoc &DL, MVT VT, SDValue Passthru, + SDValue Scalar, SDValue VL, + SelectionDAG &DAG) { assert(Scalar.getValueType() == MVT::i64 && "Unexpected VT!"); SDValue Lo = DAG.getNode(ISD::EXTRACT_ELEMENT, DL, MVT::i32, Scalar, DAG.getConstant(0, DL, MVT::i32)); SDValue Hi = DAG.getNode(ISD::EXTRACT_ELEMENT, DL, MVT::i32, Scalar, DAG.getConstant(1, DL, MVT::i32)); - return splatPartsI64WithVL(DL, VT, Lo, Hi, VL, DAG); + return splatPartsI64WithVL(DL, VT, Passthru, Lo, Hi, VL, DAG); } // This function lowers a splat of a scalar operand Splat with the vector // length VL. It ensures the final sequence is type legal, which is useful when // lowering a splat after type legalization. 
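// The (LoC >> 31) == HiC test in splatPartsI64WithVL above, restated as a
// scalar sketch (illustrative only): an i64 splat on RV32 can use a single
// 32-bit vmv.v.x exactly when the high word is a pure sign-extension of the
// low word, since the hardware sign-extends the XLEN scalar up to SEW.
#include <cstdint>
static bool hiIsSignExtensionOfLo(int32_t LoC, int32_t HiC) {
  return (LoC >> 31) == HiC; // e.g. Lo = -5 needs Hi = -1; Lo = 7 needs Hi = 0
}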
-static SDValue lowerScalarSplat(SDValue Scalar, SDValue VL, MVT VT, SDLoc DL, - SelectionDAG &DAG, +static SDValue lowerScalarSplat(SDValue Passthru, SDValue Scalar, SDValue VL, + MVT VT, SDLoc DL, SelectionDAG &DAG, const RISCVSubtarget &Subtarget) { + bool HasPassthru = Passthru && !Passthru.isUndef(); + if (!HasPassthru && !Passthru) + Passthru = DAG.getUNDEF(VT); if (VT.isFloatingPoint()) { // If VL is 1, we could use vfmv.s.f. if (isOneConstant(VL)) - return DAG.getNode(RISCVISD::VFMV_S_F_VL, DL, VT, DAG.getUNDEF(VT), - Scalar, VL); - return DAG.getNode(RISCVISD::VFMV_V_F_VL, DL, VT, Scalar, VL); + return DAG.getNode(RISCVISD::VFMV_S_F_VL, DL, VT, Passthru, Scalar, VL); + return DAG.getNode(RISCVISD::VFMV_V_F_VL, DL, VT, Passthru, Scalar, VL); } MVT XLenVT = Subtarget.getXLenVT(); @@ -2343,55 +2389,25 @@ static SDValue lowerScalarSplat(SDValue Scalar, SDValue VL, MVT VT, SDLoc DL, // use vmv.s.x. if (isOneConstant(VL) && (!Const || isNullConstant(Scalar) || !isInt<5>(Const->getSExtValue()))) - return DAG.getNode(RISCVISD::VMV_S_X_VL, DL, VT, DAG.getUNDEF(VT), Scalar, - VL); - return DAG.getNode(RISCVISD::VMV_V_X_VL, DL, VT, Scalar, VL); + return DAG.getNode(RISCVISD::VMV_S_X_VL, DL, VT, Passthru, Scalar, VL); + return DAG.getNode(RISCVISD::VMV_V_X_VL, DL, VT, Passthru, Scalar, VL); } assert(XLenVT == MVT::i32 && Scalar.getValueType() == MVT::i64 && "Unexpected scalar for splat lowering!"); if (isOneConstant(VL) && isNullConstant(Scalar)) - return DAG.getNode(RISCVISD::VMV_S_X_VL, DL, VT, DAG.getUNDEF(VT), + return DAG.getNode(RISCVISD::VMV_S_X_VL, DL, VT, Passthru, DAG.getConstant(0, DL, XLenVT), VL); // Otherwise use the more complicated splatting algorithm. - return splatSplitI64WithVL(DL, VT, Scalar, VL, DAG); -} - -// Is the mask a slidedown that shifts in undefs. -static int matchShuffleAsSlideDown(ArrayRef<int> Mask) { - int Size = Mask.size(); - - // Elements shifted in should be undef. - auto CheckUndefs = [&](int Shift) { - for (int i = Size - Shift; i != Size; ++i) - if (Mask[i] >= 0) - return false; - return true; - }; - - // Elements should be shifted or undef. - auto MatchShift = [&](int Shift) { - for (int i = 0; i != Size - Shift; ++i) - if (Mask[i] >= 0 && Mask[i] != Shift + i) - return false; - return true; - }; - - // Try all possible shifts. - for (int Shift = 1; Shift != Size; ++Shift) - if (CheckUndefs(Shift) && MatchShift(Shift)) - return Shift; - - // No match. - return -1; + return splatSplitI64WithVL(DL, VT, Passthru, Scalar, VL, DAG); } static bool isInterleaveShuffle(ArrayRef<int> Mask, MVT VT, bool &SwapSources, const RISCVSubtarget &Subtarget) { // We need to be able to widen elements to the next larger integer type. - if (VT.getScalarSizeInBits() >= Subtarget.getMaxELENForFixedLengthVectors()) + if (VT.getScalarSizeInBits() >= Subtarget.getELEN()) return false; int Size = Mask.size(); @@ -2430,6 +2446,79 @@ static bool isInterleaveShuffle(ArrayRef<int> Mask, MVT VT, bool &SwapSources, return true; } +/// Match shuffles that concatenate two vectors, rotate the concatenation, +/// and then extract the original number of elements from the rotated result. +/// This is equivalent to vector.splice or X86's PALIGNR instruction. The +/// returned rotation amount is for a rotate right, where elements move from +/// higher elements to lower elements. \p LoSrc indicates the first source +/// vector of the rotate or -1 for undef. \p HiSrc indicates the second vector +/// of the rotate or -1 for undef. 
At least one of \p LoSrc and \p HiSrc will be +/// 0 or 1 if a rotation is found. +/// +/// NOTE: We talk about rotate to the right which matches how bit shift and +/// rotate instructions are described where LSBs are on the right, but LLVM IR +/// and the table below write vectors with the lowest elements on the left. +static int isElementRotate(int &LoSrc, int &HiSrc, ArrayRef<int> Mask) { + int Size = Mask.size(); + + // We need to detect various ways of spelling a rotation: + // [11, 12, 13, 14, 15, 0, 1, 2] + // [-1, 12, 13, 14, -1, -1, 1, -1] + // [-1, -1, -1, -1, -1, -1, 1, 2] + // [ 3, 4, 5, 6, 7, 8, 9, 10] + // [-1, 4, 5, 6, -1, -1, 9, -1] + // [-1, 4, 5, 6, -1, -1, -1, -1] + int Rotation = 0; + LoSrc = -1; + HiSrc = -1; + for (int i = 0; i != Size; ++i) { + int M = Mask[i]; + if (M < 0) + continue; + + // Determine where a rotate vector would have started. + int StartIdx = i - (M % Size); + // The identity rotation isn't interesting, stop. + if (StartIdx == 0) + return -1; + + // If we found the tail of a vector, the rotation must be the missing + // front. If we found the head of a vector, the rotation is how much of + // the head we found. + int CandidateRotation = StartIdx < 0 ? -StartIdx : Size - StartIdx; + + if (Rotation == 0) + Rotation = CandidateRotation; + else if (Rotation != CandidateRotation) + // The rotations don't match, so we can't match this mask. + return -1; + + // Compute which value this mask is pointing at. + int MaskSrc = M < Size ? 0 : 1; + + // Compute which of the two target values this index should be assigned to. + // This reflects whether the high elements are remaining or the low elements + // are remaining. + int &TargetSrc = StartIdx < 0 ? HiSrc : LoSrc; + + // Either set up this value if we've not encountered it before, or check + // that it remains consistent. + if (TargetSrc < 0) + TargetSrc = MaskSrc; + else if (TargetSrc != MaskSrc) + // This may be a rotation, but it pulls from the inputs in some + // unsupported interleaving. + return -1; + } + + // Check that we successfully analyzed the mask, and normalize the results. + assert(Rotation != 0 && "Failed to locate a viable rotation!"); + assert((LoSrc >= 0 || HiSrc >= 0) && + "Failed to find a rotated input vector!"); + + return Rotation; +} + static SDValue lowerVECTOR_SHUFFLE(SDValue Op, SelectionDAG &DAG, const RISCVSubtarget &Subtarget) { SDValue V1 = Op.getOperand(0); @@ -2506,33 +2595,59 @@ static SDValue lowerVECTOR_SHUFFLE(SDValue Op, SelectionDAG &DAG, unsigned Opc = VT.isFloatingPoint() ? RISCVISD::VFMV_V_F_VL : RISCVISD::VMV_V_X_VL; - SDValue Splat = DAG.getNode(Opc, DL, ContainerVT, V, VL); + SDValue Splat = + DAG.getNode(Opc, DL, ContainerVT, DAG.getUNDEF(ContainerVT), V, VL); return convertFromScalableVector(VT, Splat, DAG, Subtarget); } V1 = convertToScalableVector(ContainerVT, V1, DAG, Subtarget); assert(Lane < (int)NumElts && "Unexpected lane!"); - SDValue Gather = - DAG.getNode(RISCVISD::VRGATHER_VX_VL, DL, ContainerVT, V1, - DAG.getConstant(Lane, DL, XLenVT), TrueMask, VL); + SDValue Gather = DAG.getNode(RISCVISD::VRGATHER_VX_VL, DL, ContainerVT, + V1, DAG.getConstant(Lane, DL, XLenVT), + TrueMask, DAG.getUNDEF(ContainerVT), VL); return convertFromScalableVector(VT, Gather, DAG, Subtarget); } } ArrayRef<int> Mask = SVN->getMask(); - // Try to match as a slidedown. - int SlideAmt = matchShuffleAsSlideDown(Mask); - if (SlideAmt >= 0) { - // TODO: Should we reduce the VL to account for the upper undef elements? - // Requires additional vsetvlis, but might be faster to execute.
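// A generator for masks that the rotate recognizer above accepts (an
// illustrative sketch; Rot must lie strictly between 0 and Size, since the
// identity rotation is rejected): result lane I reads lane (I + Rot) % Size
// of a single source, which isElementRotate reports as a right-rotation by
// Rot with LoSrc == HiSrc == 0.
static void makeRotateMaskSketch(int *Mask, int Size, int Rot) {
  for (int I = 0; I < Size; ++I)
    Mask[I] = (I + Rot) % Size; // Size = 8, Rot = 3 -> [3,4,5,6,7,0,1,2]
}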
- V1 = convertToScalableVector(ContainerVT, V1, DAG, Subtarget); - SDValue SlideDown = - DAG.getNode(RISCVISD::VSLIDEDOWN_VL, DL, ContainerVT, - DAG.getUNDEF(ContainerVT), V1, - DAG.getConstant(SlideAmt, DL, XLenVT), - TrueMask, VL); - return convertFromScalableVector(VT, SlideDown, DAG, Subtarget); + // Lower rotations to a SLIDEDOWN and a SLIDEUP. One of the source vectors may + // be undef which can be handled with a single SLIDEDOWN/UP. + int LoSrc, HiSrc; + int Rotation = isElementRotate(LoSrc, HiSrc, Mask); + if (Rotation > 0) { + SDValue LoV, HiV; + if (LoSrc >= 0) { + LoV = LoSrc == 0 ? V1 : V2; + LoV = convertToScalableVector(ContainerVT, LoV, DAG, Subtarget); + } + if (HiSrc >= 0) { + HiV = HiSrc == 0 ? V1 : V2; + HiV = convertToScalableVector(ContainerVT, HiV, DAG, Subtarget); + } + + // We found a rotation. We need to slide HiV down by Rotation. Then we need + // to slide LoV up by (NumElts - Rotation). + unsigned InvRotate = NumElts - Rotation; + + SDValue Res = DAG.getUNDEF(ContainerVT); + if (HiV) { + // If we are doing a SLIDEDOWN+SLIDEUP, reduce the VL for the SLIDEDOWN. + // FIXME: If we are only doing a SLIDEDOWN, don't reduce the VL as it + // causes multiple vsetvlis in some test cases such as lowering + // reduce.mul + SDValue DownVL = VL; + if (LoV) + DownVL = DAG.getConstant(InvRotate, DL, XLenVT); + Res = + DAG.getNode(RISCVISD::VSLIDEDOWN_VL, DL, ContainerVT, Res, HiV, + DAG.getConstant(Rotation, DL, XLenVT), TrueMask, DownVL); + } + if (LoV) + Res = DAG.getNode(RISCVISD::VSLIDEUP_VL, DL, ContainerVT, Res, LoV, + DAG.getConstant(InvRotate, DL, XLenVT), TrueMask, VL); + + return convertFromScalableVector(VT, Res, DAG, Subtarget); } // Detect an interleave shuffle and lower to @@ -2576,18 +2691,17 @@ static SDValue lowerVECTOR_SHUFFLE(SDValue Op, SelectionDAG &DAG, // Freeze V2 since we use it twice and we need to be sure that the add and // multiply see the same value. - V2 = DAG.getNode(ISD::FREEZE, DL, IntHalfVT, V2); + V2 = DAG.getFreeze(V2); // Recreate TrueMask using the widened type's element count. - MVT MaskVT = - MVT::getVectorVT(MVT::i1, HalfContainerVT.getVectorElementCount()); - TrueMask = DAG.getNode(RISCVISD::VMSET_VL, DL, MaskVT, VL); + TrueMask = getAllOnesMask(HalfContainerVT, VL, DL, DAG); // Widen V1 and V2 with 0s and add one copy of V2 to V1. SDValue Add = DAG.getNode(RISCVISD::VWADDU_VL, DL, WideIntContainerVT, V1, V2, TrueMask, VL); // Create 2^eltbits - 1 copies of V2 by multiplying by the largest integer. SDValue Multiplier = DAG.getNode(RISCVISD::VMV_V_X_VL, DL, IntHalfVT, + DAG.getUNDEF(IntHalfVT), DAG.getAllOnesConstant(DL, XLenVT)); SDValue WidenMul = DAG.getNode(RISCVISD::VWMULU_VL, DL, WideIntContainerVT, V2, Multiplier, TrueMask, VL); @@ -2691,7 +2805,8 @@ static SDValue lowerVECTOR_SHUFFLE(SDValue Op, SelectionDAG &DAG, // TODO: This doesn't trigger for i64 vectors on RV32, since there we // encounter a bitcasted BUILD_VECTOR with low/high i32 values. if (SDValue SplatValue = DAG.getSplatValue(V1, /*LegalTypes*/ true)) { - Gather = lowerScalarSplat(SplatValue, VL, ContainerVT, DL, DAG, Subtarget); + Gather = lowerScalarSplat(SDValue(), SplatValue, VL, ContainerVT, DL, DAG, + Subtarget); } else { V1 = convertToScalableVector(ContainerVT, V1, DAG, Subtarget); // If only one index is used, we can use a "splat" vrgather. @@ -2699,16 +2814,16 @@ static SDValue lowerVECTOR_SHUFFLE(SDValue Op, SelectionDAG &DAG, // that's beneficial. 
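// Scalar model of the vrgather these shuffles lower to (a sketch of the RVV
// semantics, not DAG code): every destination element selects Src[Idx[I]],
// and an out-of-range index yields 0.
static void vrgatherModel(const int *Src, const unsigned *Idx, int *Dst,
                          unsigned VL, unsigned VLMax) {
  for (unsigned I = 0; I < VL; ++I)
    Dst[I] = Idx[I] < VLMax ? Src[Idx[I]] : 0;
}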
if (LHSIndexCounts.size() == 1) { int SplatIndex = LHSIndexCounts.begin()->getFirst(); - Gather = - DAG.getNode(GatherVXOpc, DL, ContainerVT, V1, - DAG.getConstant(SplatIndex, DL, XLenVT), TrueMask, VL); + Gather = DAG.getNode(GatherVXOpc, DL, ContainerVT, V1, + DAG.getConstant(SplatIndex, DL, XLenVT), TrueMask, + DAG.getUNDEF(ContainerVT), VL); } else { SDValue LHSIndices = DAG.getBuildVector(IndexVT, DL, GatherIndicesLHS); LHSIndices = convertToScalableVector(IndexContainerVT, LHSIndices, DAG, Subtarget); Gather = DAG.getNode(GatherVVOpc, DL, ContainerVT, V1, LHSIndices, - TrueMask, VL); + TrueMask, DAG.getUNDEF(ContainerVT), VL); } } @@ -2716,45 +2831,46 @@ static SDValue lowerVECTOR_SHUFFLE(SDValue Op, SelectionDAG &DAG, // additional vrgather. if (!V2.isUndef()) { V2 = convertToScalableVector(ContainerVT, V2, DAG, Subtarget); + + MVT MaskContainerVT = ContainerVT.changeVectorElementType(MVT::i1); + SelectMask = + convertToScalableVector(MaskContainerVT, SelectMask, DAG, Subtarget); + // If only one index is used, we can use a "splat" vrgather. // TODO: We can splat the most-common index and fix-up any stragglers, if // that's beneficial. if (RHSIndexCounts.size() == 1) { int SplatIndex = RHSIndexCounts.begin()->getFirst(); - V2 = DAG.getNode(GatherVXOpc, DL, ContainerVT, V2, - DAG.getConstant(SplatIndex, DL, XLenVT), TrueMask, VL); + Gather = DAG.getNode(GatherVXOpc, DL, ContainerVT, V2, + DAG.getConstant(SplatIndex, DL, XLenVT), SelectMask, + Gather, VL); } else { SDValue RHSIndices = DAG.getBuildVector(IndexVT, DL, GatherIndicesRHS); RHSIndices = convertToScalableVector(IndexContainerVT, RHSIndices, DAG, Subtarget); - V2 = DAG.getNode(GatherVVOpc, DL, ContainerVT, V2, RHSIndices, TrueMask, - VL); + Gather = DAG.getNode(GatherVVOpc, DL, ContainerVT, V2, RHSIndices, + SelectMask, Gather, VL); } - - MVT MaskContainerVT = ContainerVT.changeVectorElementType(MVT::i1); - SelectMask = - convertToScalableVector(MaskContainerVT, SelectMask, DAG, Subtarget); - - Gather = DAG.getNode(RISCVISD::VSELECT_VL, DL, ContainerVT, SelectMask, V2, - Gather, VL); } return convertFromScalableVector(VT, Gather, DAG, Subtarget); } -static SDValue getRVVFPExtendOrRound(SDValue Op, MVT VT, MVT ContainerVT, - SDLoc DL, SelectionDAG &DAG, - const RISCVSubtarget &Subtarget) { - if (VT.isScalableVector()) - return DAG.getFPExtendOrRound(Op, DL, VT); - assert(VT.isFixedLengthVector() && - "Unexpected value type for RVV FP extend/round lowering"); - SDValue Mask, VL; - std::tie(Mask, VL) = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget); - unsigned RVVOpc = ContainerVT.bitsGT(Op.getSimpleValueType()) - ? RISCVISD::FP_EXTEND_VL - : RISCVISD::FP_ROUND_VL; - return DAG.getNode(RVVOpc, DL, ContainerVT, Op, Mask, VL); +bool RISCVTargetLowering::isShuffleMaskLegal(ArrayRef<int> M, EVT VT) const { + // Support splats for any type. These should type legalize well. + if (ShuffleVectorSDNode::isSplatMask(M.data(), VT)) + return true; + + // Only support legal VTs for other shuffles for now. 
+ if (!isTypeLegal(VT)) + return false; + + MVT SVT = VT.getSimpleVT(); + + bool SwapSources; + int LoSrc, HiSrc; + return (isElementRotate(LoSrc, HiSrc, M) > 0) || + isInterleaveShuffle(M, SVT, SwapSources, Subtarget); } // Lower CTLZ_ZERO_UNDEF or CTTZ_ZERO_UNDEF by converting to FP and extracting @@ -2868,6 +2984,32 @@ SDValue RISCVTargetLowering::expandUnalignedRVVStore(SDValue Op, Store->getMemOperand()->getFlags()); } +static SDValue lowerConstant(SDValue Op, SelectionDAG &DAG, + const RISCVSubtarget &Subtarget) { + assert(Op.getValueType() == MVT::i64 && "Unexpected VT"); + + int64_t Imm = cast<ConstantSDNode>(Op)->getSExtValue(); + + // All simm32 constants should be handled by isel. + // NOTE: The getMaxBuildIntsCost call below should return a value >= 2 making + // this check redundant, but small immediates are common so this check + // should have better compile time. + if (isInt<32>(Imm)) + return Op; + + // We only need to cost the immediate, if constant pool lowering is enabled. + if (!Subtarget.useConstantPoolForLargeInts()) + return Op; + + RISCVMatInt::InstSeq Seq = + RISCVMatInt::generateInstSeq(Imm, Subtarget.getFeatureBits()); + if (Seq.size() <= Subtarget.getMaxBuildIntsCost()) + return Op; + + // Expand to a constant pool using the default expansion code. + return SDValue(); +} + SDValue RISCVTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const { switch (Op.getOpcode()) { @@ -2883,6 +3025,8 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op, return lowerJumpTable(Op, DAG); case ISD::GlobalTLSAddress: return lowerGlobalTLSAddress(Op, DAG); + case ISD::Constant: + return lowerConstant(Op, DAG, Subtarget); case ISD::SELECT: return lowerSELECT(Op, DAG); case ISD::BRCOND: @@ -2905,6 +3049,30 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op, SDValue Op0 = Op.getOperand(0); EVT Op0VT = Op0.getValueType(); MVT XLenVT = Subtarget.getXLenVT(); + if (VT == MVT::f16 && Op0VT == MVT::i16 && Subtarget.hasStdExtZfh()) { + SDValue NewOp0 = DAG.getNode(ISD::ANY_EXTEND, DL, XLenVT, Op0); + SDValue FPConv = DAG.getNode(RISCVISD::FMV_H_X, DL, MVT::f16, NewOp0); + return FPConv; + } + if (VT == MVT::f32 && Op0VT == MVT::i32 && Subtarget.is64Bit() && + Subtarget.hasStdExtF()) { + SDValue NewOp0 = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i64, Op0); + SDValue FPConv = + DAG.getNode(RISCVISD::FMV_W_X_RV64, DL, MVT::f32, NewOp0); + return FPConv; + } + + // Consider other scalar<->scalar casts as legal if the types are legal. + // Otherwise expand them. + if (!VT.isVector() && !Op0VT.isVector()) { + if (isTypeLegal(VT) && isTypeLegal(Op0VT)) + return Op; + return SDValue(); + } + + assert(!VT.isScalableVector() && !Op0VT.isScalableVector() && + "Unexpected types"); + if (VT.isFixedLengthVector()) { // We can handle fixed length vector bitcasts with a simple replacement // in isel. 
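// Scalar model of the bit-preserving moves used in the BITCAST case above
// (illustrative only): FMV_H_X and FMV_W_X_RV64 reinterpret the low bits of
// an integer register as a float, exactly like a memcpy-based bitcast.
#include <cstdint>
#include <cstring>
static float bitcastI32ToF32(uint32_t Bits) {
  float F;
  std::memcpy(&F, &Bits, sizeof F); // same bits, new type; no conversion
  return F;
}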
@@ -2934,18 +3102,6 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op, return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, BVec, DAG.getConstant(0, DL, XLenVT)); } - if (VT == MVT::f16 && Op0VT == MVT::i16 && Subtarget.hasStdExtZfh()) { - SDValue NewOp0 = DAG.getNode(ISD::ANY_EXTEND, DL, XLenVT, Op0); - SDValue FPConv = DAG.getNode(RISCVISD::FMV_H_X, DL, MVT::f16, NewOp0); - return FPConv; - } - if (VT == MVT::f32 && Op0VT == MVT::i32 && Subtarget.is64Bit() && - Subtarget.hasStdExtF()) { - SDValue NewOp0 = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i64, Op0); - SDValue FPConv = - DAG.getNode(RISCVISD::FMV_W_X_RV64, DL, MVT::f32, NewOp0); - return FPConv; - } return SDValue(); } case ISD::INTRINSIC_WO_CHAIN: @@ -3002,55 +3158,11 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op, } return DAG.getNode(Opc, DL, VT, Op0, Op1, ShAmt); } - case ISD::TRUNCATE: { - SDLoc DL(Op); - MVT VT = Op.getSimpleValueType(); + case ISD::TRUNCATE: // Only custom-lower vector truncates - if (!VT.isVector()) + if (!Op.getSimpleValueType().isVector()) return Op; - - // Truncates to mask types are handled differently - if (VT.getVectorElementType() == MVT::i1) - return lowerVectorMaskTrunc(Op, DAG); - - // RVV only has truncates which operate from SEW*2->SEW, so lower arbitrary - // truncates as a series of "RISCVISD::TRUNCATE_VECTOR_VL" nodes which - // truncate by one power of two at a time. - MVT DstEltVT = VT.getVectorElementType(); - - SDValue Src = Op.getOperand(0); - MVT SrcVT = Src.getSimpleValueType(); - MVT SrcEltVT = SrcVT.getVectorElementType(); - - assert(DstEltVT.bitsLT(SrcEltVT) && - isPowerOf2_64(DstEltVT.getSizeInBits()) && - isPowerOf2_64(SrcEltVT.getSizeInBits()) && - "Unexpected vector truncate lowering"); - - MVT ContainerVT = SrcVT; - if (SrcVT.isFixedLengthVector()) { - ContainerVT = getContainerForFixedLengthVector(SrcVT); - Src = convertToScalableVector(ContainerVT, Src, DAG, Subtarget); - } - - SDValue Result = Src; - SDValue Mask, VL; - std::tie(Mask, VL) = - getDefaultVLOps(SrcVT, ContainerVT, DL, DAG, Subtarget); - LLVMContext &Context = *DAG.getContext(); - const ElementCount Count = ContainerVT.getVectorElementCount(); - do { - SrcEltVT = MVT::getIntegerVT(SrcEltVT.getSizeInBits() / 2); - EVT ResultVT = EVT::getVectorVT(Context, SrcEltVT, Count); - Result = DAG.getNode(RISCVISD::TRUNCATE_VECTOR_VL, DL, ResultVT, Result, - Mask, VL); - } while (SrcEltVT != DstEltVT); - - if (SrcVT.isFixedLengthVector()) - Result = convertFromScalableVector(VT, Result, DAG, Subtarget); - - return Result; - } + return lowerVectorTruncLike(Op, DAG); case ISD::ANY_EXTEND: case ISD::ZERO_EXTEND: if (Op.getOperand(0).getValueType().isVector() && @@ -3076,28 +3188,26 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op, // minimum size. e.g. <vscale x 2 x i32>. VLENB is in bytes so we calculate // vscale as VLENB / 8. static_assert(RISCV::RVVBitsPerBlock == 64, "Unexpected bits per block!"); - if (Subtarget.getMinVLen() < RISCV::RVVBitsPerBlock) + if (Subtarget.getRealMinVLen() < RISCV::RVVBitsPerBlock) report_fatal_error("Support for VLEN==32 is incomplete."); - if (isa<ConstantSDNode>(Op.getOperand(0))) { - // We assume VLENB is a multiple of 8. We manually choose the best shift - // here because SimplifyDemandedBits isn't always able to simplify it. 
- uint64_t Val = Op.getConstantOperandVal(0); - if (isPowerOf2_64(Val)) { - uint64_t Log2 = Log2_64(Val); - if (Log2 < 3) - return DAG.getNode(ISD::SRL, DL, VT, VLENB, - DAG.getConstant(3 - Log2, DL, VT)); - if (Log2 > 3) - return DAG.getNode(ISD::SHL, DL, VT, VLENB, - DAG.getConstant(Log2 - 3, DL, VT)); - return VLENB; - } - // If the multiplier is a multiple of 8, scale it down to avoid needing - // to shift the VLENB value. - if ((Val % 8) == 0) - return DAG.getNode(ISD::MUL, DL, VT, VLENB, - DAG.getConstant(Val / 8, DL, VT)); - } + // We assume VLENB is a multiple of 8. We manually choose the best shift + // here because SimplifyDemandedBits isn't always able to simplify it. + uint64_t Val = Op.getConstantOperandVal(0); + if (isPowerOf2_64(Val)) { + uint64_t Log2 = Log2_64(Val); + if (Log2 < 3) + return DAG.getNode(ISD::SRL, DL, VT, VLENB, + DAG.getConstant(3 - Log2, DL, VT)); + if (Log2 > 3) + return DAG.getNode(ISD::SHL, DL, VT, VLENB, + DAG.getConstant(Log2 - 3, DL, VT)); + return VLENB; + } + // If the multiplier is a multiple of 8, scale it down to avoid needing + // to shift the VLENB value. + if ((Val % 8) == 0) + return DAG.getNode(ISD::MUL, DL, VT, VLENB, + DAG.getConstant(Val / 8, DL, VT)); SDValue VScale = DAG.getNode(ISD::SRL, DL, VT, VLENB, DAG.getConstant(3, DL, VT)); @@ -3117,88 +3227,11 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op, } return SDValue(); } - case ISD::FP_EXTEND: { - // RVV can only do fp_extend to types double the size as the source. We - // custom-lower f16->f64 extensions to two hops of ISD::FP_EXTEND, going - // via f32. - SDLoc DL(Op); - MVT VT = Op.getSimpleValueType(); - SDValue Src = Op.getOperand(0); - MVT SrcVT = Src.getSimpleValueType(); - - // Prepare any fixed-length vector operands. - MVT ContainerVT = VT; - if (SrcVT.isFixedLengthVector()) { - ContainerVT = getContainerForFixedLengthVector(VT); - MVT SrcContainerVT = - ContainerVT.changeVectorElementType(SrcVT.getVectorElementType()); - Src = convertToScalableVector(SrcContainerVT, Src, DAG, Subtarget); - } - - if (!VT.isVector() || VT.getVectorElementType() != MVT::f64 || - SrcVT.getVectorElementType() != MVT::f16) { - // For scalable vectors, we only need to close the gap between - // vXf16->vXf64. - if (!VT.isFixedLengthVector()) - return Op; - // For fixed-length vectors, lower the FP_EXTEND to a custom "VL" version. - Src = getRVVFPExtendOrRound(Src, VT, ContainerVT, DL, DAG, Subtarget); - return convertFromScalableVector(VT, Src, DAG, Subtarget); - } - - MVT InterVT = VT.changeVectorElementType(MVT::f32); - MVT InterContainerVT = ContainerVT.changeVectorElementType(MVT::f32); - SDValue IntermediateExtend = getRVVFPExtendOrRound( - Src, InterVT, InterContainerVT, DL, DAG, Subtarget); - - SDValue Extend = getRVVFPExtendOrRound(IntermediateExtend, VT, ContainerVT, - DL, DAG, Subtarget); - if (VT.isFixedLengthVector()) - return convertFromScalableVector(VT, Extend, DAG, Subtarget); - return Extend; - } - case ISD::FP_ROUND: { - // RVV can only do fp_round to types half the size as the source. We - // custom-lower f64->f16 rounds via RVV's round-to-odd float - // conversion instruction. - SDLoc DL(Op); - MVT VT = Op.getSimpleValueType(); - SDValue Src = Op.getOperand(0); - MVT SrcVT = Src.getSimpleValueType(); - - // Prepare any fixed-length vector operands. 
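// Why the f64->f16 direction narrows through f32 with round-to-odd
// (VFNCVT_ROD_VL) rather than nearest-even: rounding nearest-even twice can
// double-round. A scalar sketch of round-to-odd (illustrative only; assumes
// an environment where fesetround is honoured for the cast):
#include <cfenv>
#include <cstdint>
#include <cstring>
static float narrowRoundToOddSketch(double X) {
  const int OldMode = std::fegetround();
  std::fesetround(FE_TOWARDZERO);
  float F = (float)X; // truncated narrowing
  std::fesetround(OldMode);
  if ((double)F != X) { // inexact: force the sticky low mantissa bit
    uint32_t Bits;
    std::memcpy(&Bits, &F, sizeof F);
    Bits |= 1u;
    std::memcpy(&F, &Bits, sizeof F);
  }
  return F; // a following nearest-even f32->f16 round then matches f64->f16
}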
- MVT ContainerVT = VT; - if (VT.isFixedLengthVector()) { - MVT SrcContainerVT = getContainerForFixedLengthVector(SrcVT); - ContainerVT = - SrcContainerVT.changeVectorElementType(VT.getVectorElementType()); - Src = convertToScalableVector(SrcContainerVT, Src, DAG, Subtarget); - } - - if (!VT.isVector() || VT.getVectorElementType() != MVT::f16 || - SrcVT.getVectorElementType() != MVT::f64) { - // For scalable vectors, we only need to close the gap between - // vXf64<->vXf16. - if (!VT.isFixedLengthVector()) - return Op; - // For fixed-length vectors, lower the FP_ROUND to a custom "VL" version. - Src = getRVVFPExtendOrRound(Src, VT, ContainerVT, DL, DAG, Subtarget); - return convertFromScalableVector(VT, Src, DAG, Subtarget); - } - - SDValue Mask, VL; - std::tie(Mask, VL) = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget); - - MVT InterVT = ContainerVT.changeVectorElementType(MVT::f32); - SDValue IntermediateRound = - DAG.getNode(RISCVISD::VFNCVT_ROD_VL, DL, InterVT, Src, Mask, VL); - SDValue Round = getRVVFPExtendOrRound(IntermediateRound, VT, ContainerVT, - DL, DAG, Subtarget); - - if (VT.isFixedLengthVector()) - return convertFromScalableVector(VT, Round, DAG, Subtarget); - return Round; - } + case ISD::FP_EXTEND: + case ISD::FP_ROUND: + if (!Op.getValueType().isVector()) + return Op; + return lowerVectorFPExtendOrRoundLike(Op, DAG); case ISD::FP_TO_SINT: case ISD::FP_TO_UINT: case ISD::SINT_TO_FP: @@ -3221,10 +3254,10 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op, bool IsInt2FP = SrcEltVT.isInteger(); // Widening conversions - if (EltSize > SrcEltSize && (EltSize / SrcEltSize >= 4)) { + if (EltSize > (2 * SrcEltSize)) { if (IsInt2FP) { // Do a regular integer sign/zero extension then convert to float. - MVT IVecVT = MVT::getVectorVT(MVT::getIntegerVT(EltVT.getSizeInBits()), + MVT IVecVT = MVT::getVectorVT(MVT::getIntegerVT(EltSize), VT.getVectorElementCount()); unsigned ExtOpcode = Op.getOpcode() == ISD::UINT_TO_FP ? ISD::ZERO_EXTEND @@ -3242,7 +3275,7 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op, } // Narrowing conversions - if (SrcEltSize > EltSize && (SrcEltSize / EltSize >= 4)) { + if (SrcEltSize > (2 * EltSize)) { if (IsInt2FP) { // One narrowing int_to_fp, then an fp_round. assert(EltVT == MVT::f16 && "Unexpected [US]_TO_FP lowering"); @@ -3253,9 +3286,8 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op, // FP2Int // One narrowing fp_to_int, then truncate the integer. If the float isn't // representable by the integer, the result is poison. 
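// Scalar model of that narrowing path (illustrative only): an f32 -> i8
// element conversion becomes one fp_to_sint at half the source element width
// followed by an integer truncate; inputs the integer type cannot represent
// are poison, so discarding the high bits is fine.
static signed char fpToI8Sketch(float X) {
  short Half = (short)X;    // narrowing FP_TO_SINT at SrcEltSize / 2
  return (signed char)Half; // the ISD::TRUNCATE
}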
- MVT IVecVT = - MVT::getVectorVT(MVT::getIntegerVT(SrcEltVT.getSizeInBits() / 2), - VT.getVectorElementCount()); + MVT IVecVT = MVT::getVectorVT(MVT::getIntegerVT(SrcEltSize / 2), + VT.getVectorElementCount()); SDValue FP2Int = DAG.getNode(Op.getOpcode(), DL, IVecVT, Src); return DAG.getNode(ISD::TRUNCATE, DL, VT, FP2Int); } @@ -3309,6 +3341,8 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op, case ISD::FCEIL: case ISD::FFLOOR: return lowerFTRUNC_FCEIL_FFLOOR(Op, DAG); + case ISD::FROUND: + return lowerFROUND(Op, DAG); case ISD::VECREDUCE_ADD: case ISD::VECREDUCE_UMAX: case ISD::VECREDUCE_SMAX: @@ -3350,12 +3384,14 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op, return lowerSTEP_VECTOR(Op, DAG); case ISD::VECTOR_REVERSE: return lowerVECTOR_REVERSE(Op, DAG); + case ISD::VECTOR_SPLICE: + return lowerVECTOR_SPLICE(Op, DAG); case ISD::BUILD_VECTOR: return lowerBUILD_VECTOR(Op, DAG, Subtarget); case ISD::SPLAT_VECTOR: if (Op.getValueType().getVectorElementType() == MVT::i1) return lowerVectorMaskSplat(Op, DAG); - return lowerSPLAT_VECTOR(Op, DAG, Subtarget); + return SDValue(); case ISD::VECTOR_SHUFFLE: return lowerVECTOR_SHUFFLE(Op, DAG, Subtarget); case ISD::CONCAT_VECTORS: { @@ -3455,7 +3491,7 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op, case ISD::FSQRT: return lowerToScalableOp(Op, DAG, RISCVISD::FSQRT_VL); case ISD::FMA: - return lowerToScalableOp(Op, DAG, RISCVISD::FMA_VL); + return lowerToScalableOp(Op, DAG, RISCVISD::VFMADD_VL); case ISD::SMIN: return lowerToScalableOp(Op, DAG, RISCVISD::SMIN_VL); case ISD::SMAX: @@ -3487,6 +3523,8 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op, return lowerGET_ROUNDING(Op, DAG); case ISD::SET_ROUNDING: return lowerSET_ROUNDING(Op, DAG); + case ISD::EH_DWARF_CFA: + return lowerEH_DWARF_CFA(Op, DAG); case ISD::VP_SELECT: return lowerVPOp(Op, DAG, RISCVISD::VSELECT_VL); case ISD::VP_MERGE: @@ -3525,6 +3563,35 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op, return lowerVPOp(Op, DAG, RISCVISD::FMUL_VL); case ISD::VP_FDIV: return lowerVPOp(Op, DAG, RISCVISD::FDIV_VL); + case ISD::VP_FNEG: + return lowerVPOp(Op, DAG, RISCVISD::FNEG_VL); + case ISD::VP_FMA: + return lowerVPOp(Op, DAG, RISCVISD::VFMADD_VL); + case ISD::VP_SIGN_EXTEND: + case ISD::VP_ZERO_EXTEND: + if (Op.getOperand(0).getSimpleValueType().getVectorElementType() == MVT::i1) + return lowerVPExtMaskOp(Op, DAG); + return lowerVPOp(Op, DAG, + Op.getOpcode() == ISD::VP_SIGN_EXTEND + ? RISCVISD::VSEXT_VL + : RISCVISD::VZEXT_VL); + case ISD::VP_TRUNCATE: + return lowerVectorTruncLike(Op, DAG); + case ISD::VP_FP_EXTEND: + case ISD::VP_FP_ROUND: + return lowerVectorFPExtendOrRoundLike(Op, DAG); + case ISD::VP_FPTOSI: + return lowerVPFPIntConvOp(Op, DAG, RISCVISD::FP_TO_SINT_VL); + case ISD::VP_FPTOUI: + return lowerVPFPIntConvOp(Op, DAG, RISCVISD::FP_TO_UINT_VL); + case ISD::VP_SITOFP: + return lowerVPFPIntConvOp(Op, DAG, RISCVISD::SINT_TO_FP_VL); + case ISD::VP_UITOFP: + return lowerVPFPIntConvOp(Op, DAG, RISCVISD::UINT_TO_FP_VL); + case ISD::VP_SETCC: + if (Op.getOperand(0).getSimpleValueType().getVectorElementType() == MVT::i1) + return lowerVPSetCCMaskOp(Op, DAG); + return lowerVPOp(Op, DAG, RISCVISD::SETCC_VL); } } @@ -3562,12 +3629,21 @@ SDValue RISCVTargetLowering::getAddr(NodeTy *N, SelectionDAG &DAG, // Use PC-relative addressing to access the symbol. This generates the // pattern (PseudoLLA sym), which expands to (addi (auipc %pcrel_hi(sym)) // %pcrel_lo(auipc)). 
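// How a %pcrel_hi/%pcrel_lo pair reassembles an offset (a sketch of the
// standard RISC-V HI/LO split, not code from this file): ADDI sign-extends
// its 12-bit immediate, so the AUIPC/LUI immediate must carry a compensating
// bias (assumes the biased offset does not overflow int):
static void splitHiLoSketch(int Offset, int &Hi20, int &Lo12) {
  Lo12 = (Offset & 0xFFF) - ((Offset & 0x800) ? 0x1000 : 0); // sext low 12
  Hi20 = (Offset - Lo12) / 4096; // exact multiple, equals (Offset+0x800)>>12
}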
- return SDValue(DAG.getMachineNode(RISCV::PseudoLLA, DL, Ty, Addr), 0); + return DAG.getNode(RISCVISD::LLA, DL, Ty, Addr); // Use PC-relative addressing to access the GOT for this symbol, then load // the address from the GOT. This generates the pattern (PseudoLA sym), // which expands to (ld (addi (auipc %got_pcrel_hi(sym)) %pcrel_lo(auipc))). - return SDValue(DAG.getMachineNode(RISCV::PseudoLA, DL, Ty, Addr), 0); + MachineFunction &MF = DAG.getMachineFunction(); + MachineMemOperand *MemOp = MF.getMachineMemOperand( + MachinePointerInfo::getGOT(MF), + MachineMemOperand::MOLoad | MachineMemOperand::MODereferenceable | + MachineMemOperand::MOInvariant, + LLT(Ty.getSimpleVT()), Align(Ty.getFixedSizeInBits() / 8)); + SDValue Load = + DAG.getMemIntrinsicNode(RISCVISD::LA, DL, DAG.getVTList(Ty, MVT::Other), + {DAG.getEntryNode(), Addr}, Ty, MemOp); + return Load; } switch (getTargetMachine().getCodeModel()) { @@ -3578,15 +3654,15 @@ SDValue RISCVTargetLowering::getAddr(NodeTy *N, SelectionDAG &DAG, // address space. This generates the pattern (addi (lui %hi(sym)) %lo(sym)). SDValue AddrHi = getTargetNode(N, DL, Ty, DAG, RISCVII::MO_HI); SDValue AddrLo = getTargetNode(N, DL, Ty, DAG, RISCVII::MO_LO); - SDValue MNHi = SDValue(DAG.getMachineNode(RISCV::LUI, DL, Ty, AddrHi), 0); - return SDValue(DAG.getMachineNode(RISCV::ADDI, DL, Ty, MNHi, AddrLo), 0); + SDValue MNHi = DAG.getNode(RISCVISD::HI, DL, Ty, AddrHi); + return DAG.getNode(RISCVISD::ADD_LO, DL, Ty, MNHi, AddrLo); } case CodeModel::Medium: { // Generate a sequence for accessing addresses within any 2GiB range within // the address space. This generates the pattern (PseudoLLA sym), which // expands to (addi (auipc %pcrel_hi(sym)) %pcrel_lo(auipc)). SDValue Addr = getTargetNode(N, DL, Ty, DAG, 0); - return SDValue(DAG.getMachineNode(RISCV::PseudoLLA, DL, Ty, Addr), 0); + return DAG.getNode(RISCVISD::LLA, DL, Ty, Addr); } } } @@ -3594,23 +3670,12 @@ SDValue RISCVTargetLowering::getAddr(NodeTy *N, SelectionDAG &DAG, SDValue RISCVTargetLowering::lowerGlobalAddress(SDValue Op, SelectionDAG &DAG) const { SDLoc DL(Op); - EVT Ty = Op.getValueType(); GlobalAddressSDNode *N = cast<GlobalAddressSDNode>(Op); - int64_t Offset = N->getOffset(); - MVT XLenVT = Subtarget.getXLenVT(); + assert(N->getOffset() == 0 && "unexpected offset in global node"); const GlobalValue *GV = N->getGlobal(); bool IsLocal = getTargetMachine().shouldAssumeDSOLocal(*GV->getParent(), GV); - SDValue Addr = getAddr(N, DAG, IsLocal); - - // In order to maximise the opportunity for common subexpression elimination, - // emit a separate ADD node for the global address offset instead of folding - // it in the global address node. Later peephole optimisations may choose to - // fold it back in when profitable. - if (Offset != 0) - return DAG.getNode(ISD::ADD, DL, Ty, Addr, - DAG.getConstant(Offset, DL, XLenVT)); - return Addr; + return getAddr(N, DAG, IsLocal); } SDValue RISCVTargetLowering::lowerBlockAddress(SDValue Op, @@ -3648,8 +3713,15 @@ SDValue RISCVTargetLowering::getStaticTLSAddr(GlobalAddressSDNode *N, // the pattern (PseudoLA_TLS_IE sym), which expands to // (ld (auipc %tls_ie_pcrel_hi(sym)) %pcrel_lo(auipc)). 
SDValue Addr = DAG.getTargetGlobalAddress(GV, DL, Ty, 0, 0); - SDValue Load = - SDValue(DAG.getMachineNode(RISCV::PseudoLA_TLS_IE, DL, Ty, Addr), 0); + MachineFunction &MF = DAG.getMachineFunction(); + MachineMemOperand *MemOp = MF.getMachineMemOperand( + MachinePointerInfo::getGOT(MF), + MachineMemOperand::MOLoad | MachineMemOperand::MODereferenceable | + MachineMemOperand::MOInvariant, + LLT(Ty.getSimpleVT()), Align(Ty.getFixedSizeInBits() / 8)); + SDValue Load = DAG.getMemIntrinsicNode( + RISCVISD::LA_TLS_IE, DL, DAG.getVTList(Ty, MVT::Other), + {DAG.getEntryNode(), Addr}, Ty, MemOp); // Add the thread pointer. SDValue TPReg = DAG.getRegister(RISCV::X4, XLenVT); @@ -3667,12 +3739,11 @@ SDValue RISCVTargetLowering::getStaticTLSAddr(GlobalAddressSDNode *N, SDValue AddrLo = DAG.getTargetGlobalAddress(GV, DL, Ty, 0, RISCVII::MO_TPREL_LO); - SDValue MNHi = SDValue(DAG.getMachineNode(RISCV::LUI, DL, Ty, AddrHi), 0); + SDValue MNHi = DAG.getNode(RISCVISD::HI, DL, Ty, AddrHi); SDValue TPReg = DAG.getRegister(RISCV::X4, XLenVT); - SDValue MNAdd = SDValue( - DAG.getMachineNode(RISCV::PseudoAddTPRel, DL, Ty, MNHi, TPReg, AddrAdd), - 0); - return SDValue(DAG.getMachineNode(RISCV::ADDI, DL, Ty, MNAdd, AddrLo), 0); + SDValue MNAdd = + DAG.getNode(RISCVISD::ADD_TPREL, DL, Ty, MNHi, TPReg, AddrAdd); + return DAG.getNode(RISCVISD::ADD_LO, DL, Ty, MNAdd, AddrLo); } SDValue RISCVTargetLowering::getDynamicTLSAddr(GlobalAddressSDNode *N, @@ -3686,8 +3757,7 @@ SDValue RISCVTargetLowering::getDynamicTLSAddr(GlobalAddressSDNode *N, // This generates the pattern (PseudoLA_TLS_GD sym), which expands to // (addi (auipc %tls_gd_pcrel_hi(sym)) %pcrel_lo(auipc)). SDValue Addr = DAG.getTargetGlobalAddress(GV, DL, Ty, 0, 0); - SDValue Load = - SDValue(DAG.getMachineNode(RISCV::PseudoLA_TLS_GD, DL, Ty, Addr), 0); + SDValue Load = DAG.getNode(RISCVISD::LA_TLS_GD, DL, Ty, Addr); // Prepare argument list to generate call. ArgListTy Args; @@ -3710,10 +3780,8 @@ SDValue RISCVTargetLowering::getDynamicTLSAddr(GlobalAddressSDNode *N, SDValue RISCVTargetLowering::lowerGlobalTLSAddress(SDValue Op, SelectionDAG &DAG) const { SDLoc DL(Op); - EVT Ty = Op.getValueType(); GlobalAddressSDNode *N = cast<GlobalAddressSDNode>(Op); - int64_t Offset = N->getOffset(); - MVT XLenVT = Subtarget.getXLenVT(); + assert(N->getOffset() == 0 && "unexpected offset in global node"); TLSModel::Model Model = getTargetMachine().getTLSModel(N->getGlobal()); @@ -3735,13 +3803,6 @@ SDValue RISCVTargetLowering::lowerGlobalTLSAddress(SDValue Op, break; } - // In order to maximise the opportunity for common subexpression elimination, - // emit a separate ADD node for the global address offset instead of folding - // it in the global address node. Later peephole optimisations may choose to - // fold it back in when profitable. 
- if (Offset != 0) - return DAG.getNode(ISD::ADD, DL, Ty, Addr, - DAG.getConstant(Offset, DL, XLenVT)); return Addr; } @@ -3911,7 +3972,7 @@ SDValue RISCVTargetLowering::lowerShiftLeftParts(SDValue Op, // if Shamt-XLEN < 0: // Shamt < XLEN // Lo = Lo << Shamt - // Hi = (Hi << Shamt) | ((Lo >>u 1) >>u (XLEN-1 - Shamt)) + // Hi = (Hi << Shamt) | ((Lo >>u 1) >>u (XLEN-1 ^ Shamt)) // else: // Lo = 0 // Hi = Lo << (Shamt-XLEN) @@ -3921,7 +3982,7 @@ SDValue RISCVTargetLowering::lowerShiftLeftParts(SDValue Op, SDValue MinusXLen = DAG.getConstant(-(int)Subtarget.getXLen(), DL, VT); SDValue XLenMinus1 = DAG.getConstant(Subtarget.getXLen() - 1, DL, VT); SDValue ShamtMinusXLen = DAG.getNode(ISD::ADD, DL, VT, Shamt, MinusXLen); - SDValue XLenMinus1Shamt = DAG.getNode(ISD::SUB, DL, VT, XLenMinus1, Shamt); + SDValue XLenMinus1Shamt = DAG.getNode(ISD::XOR, DL, VT, Shamt, XLenMinus1); SDValue LoTrue = DAG.getNode(ISD::SHL, DL, VT, Lo, Shamt); SDValue ShiftRight1Lo = DAG.getNode(ISD::SRL, DL, VT, Lo, One); @@ -3950,7 +4011,7 @@ SDValue RISCVTargetLowering::lowerShiftRightParts(SDValue Op, SelectionDAG &DAG, // SRA expansion: // if Shamt-XLEN < 0: // Shamt < XLEN - // Lo = (Lo >>u Shamt) | ((Hi << 1) << (XLEN-1 - Shamt)) + // Lo = (Lo >>u Shamt) | ((Hi << 1) << (Shamt ^ XLEN-1)) // Hi = Hi >>s Shamt // else: // Lo = Hi >>s (Shamt-XLEN); // Hi = Hi >>s (XLEN-1) // // SRL expansion: // if Shamt-XLEN < 0: // Shamt < XLEN - // Lo = (Lo >>u Shamt) | ((Hi << 1) << (XLEN-1 - Shamt)) + // Lo = (Lo >>u Shamt) | ((Hi << 1) << (Shamt ^ XLEN-1)) // Hi = Hi >>u Shamt // else: // Lo = Hi >>u (Shamt-XLEN); @@ -3971,7 +4032,7 @@ SDValue RISCVTargetLowering::lowerShiftRightParts(SDValue Op, SelectionDAG &DAG, SDValue MinusXLen = DAG.getConstant(-(int)Subtarget.getXLen(), DL, VT); SDValue XLenMinus1 = DAG.getConstant(Subtarget.getXLen() - 1, DL, VT); SDValue ShamtMinusXLen = DAG.getNode(ISD::ADD, DL, VT, Shamt, MinusXLen); - SDValue XLenMinus1Shamt = DAG.getNode(ISD::SUB, DL, VT, XLenMinus1, Shamt); + SDValue XLenMinus1Shamt = DAG.getNode(ISD::XOR, DL, VT, Shamt, XLenMinus1); SDValue ShiftRightLo = DAG.getNode(ISD::SRL, DL, VT, Lo, Shamt); SDValue ShiftLeftHi1 = DAG.getNode(ISD::SHL, DL, VT, Hi, One); @@ -4022,7 +4083,7 @@ SDValue RISCVTargetLowering::lowerVectorMaskSplat(SDValue Op, // Custom-lower a SPLAT_VECTOR_PARTS where XLEN<SEW, as the SEW element type is // illegal (currently only vXi64 RV32). // FIXME: We could also catch non-constant sign-extended i32 values and lower -// them to SPLAT_VECTOR_I64 +// them to VMV_V_X_VL. SDValue RISCVTargetLowering::lowerSPLAT_VECTOR_PARTS(SDValue Op, SelectionDAG &DAG) const { SDLoc DL(Op); @@ -4041,7 +4102,8 @@ SDValue RISCVTargetLowering::lowerSPLAT_VECTOR_PARTS(SDValue Op, std::tie(Mask, VL) = getDefaultVLOps(VecVT, ContainerVT, DL, DAG, Subtarget); - SDValue Res = splatPartsI64WithVL(DL, ContainerVT, Lo, Hi, VL, DAG); + SDValue Res = + splatPartsI64WithVL(DL, ContainerVT, SDValue(), Lo, Hi, VL, DAG); return convertFromScalableVector(VecVT, Res, DAG, Subtarget); } @@ -4051,18 +4113,21 @@ SDValue RISCVTargetLowering::lowerSPLAT_VECTOR_PARTS(SDValue Op, // If Hi constant is all the same sign bit as Lo, lower this as a custom // node in order to try and match RVV vector/scalar instructions.
if ((LoC >> 31) == HiC) - return DAG.getNode(RISCVISD::SPLAT_VECTOR_I64, DL, VecVT, Lo); + return DAG.getNode(RISCVISD::VMV_V_X_VL, DL, VecVT, DAG.getUNDEF(VecVT), + Lo, DAG.getRegister(RISCV::X0, MVT::i32)); } // Detect cases where Hi is (SRA Lo, 31) which means Hi is Lo sign extended. if (Hi.getOpcode() == ISD::SRA && Hi.getOperand(0) == Lo && isa<ConstantSDNode>(Hi.getOperand(1)) && Hi.getConstantOperandVal(1) == 31) - return DAG.getNode(RISCVISD::SPLAT_VECTOR_I64, DL, VecVT, Lo); + return DAG.getNode(RISCVISD::VMV_V_X_VL, DL, VecVT, DAG.getUNDEF(VecVT), Lo, + DAG.getRegister(RISCV::X0, MVT::i32)); // Fall back to use a stack store and stride x0 vector load. Use X0 as VL. - return DAG.getNode(RISCVISD::SPLAT_VECTOR_SPLIT_I64_VL, DL, VecVT, Lo, Hi, - DAG.getTargetConstant(RISCV::VLMaxSentinel, DL, MVT::i64)); + return DAG.getNode(RISCVISD::SPLAT_VECTOR_SPLIT_I64_VL, DL, VecVT, + DAG.getUNDEF(VecVT), Lo, Hi, + DAG.getRegister(RISCV::X0, MVT::i32)); } // Custom-lower extensions from mask vectors by using a vselect either with 1 @@ -4078,27 +4143,9 @@ SDValue RISCVTargetLowering::lowerVectorMaskExt(SDValue Op, SelectionDAG &DAG, assert(Src.getValueType().isVector() && Src.getValueType().getVectorElementType() == MVT::i1); - MVT XLenVT = Subtarget.getXLenVT(); - SDValue SplatZero = DAG.getConstant(0, DL, XLenVT); - SDValue SplatTrueVal = DAG.getConstant(ExtTrueVal, DL, XLenVT); - if (VecVT.isScalableVector()) { - // Be careful not to introduce illegal scalar types at this stage, and be - // careful also about splatting constants as on RV32, vXi64 SPLAT_VECTOR is - // illegal and must be expanded. Since we know that the constants are - // sign-extended 32-bit values, we use SPLAT_VECTOR_I64 directly. - bool IsRV32E64 = - !Subtarget.is64Bit() && VecVT.getVectorElementType() == MVT::i64; - - if (!IsRV32E64) { - SplatZero = DAG.getSplatVector(VecVT, DL, SplatZero); - SplatTrueVal = DAG.getSplatVector(VecVT, DL, SplatTrueVal); - } else { - SplatZero = DAG.getNode(RISCVISD::SPLAT_VECTOR_I64, DL, VecVT, SplatZero); - SplatTrueVal = - DAG.getNode(RISCVISD::SPLAT_VECTOR_I64, DL, VecVT, SplatTrueVal); - } - + SDValue SplatZero = DAG.getConstant(0, DL, VecVT); + SDValue SplatTrueVal = DAG.getConstant(ExtTrueVal, DL, VecVT); return DAG.getNode(ISD::VSELECT, DL, VecVT, Src, SplatTrueVal, SplatZero); } @@ -4111,9 +4158,14 @@ SDValue RISCVTargetLowering::lowerVectorMaskExt(SDValue Op, SelectionDAG &DAG, SDValue Mask, VL; std::tie(Mask, VL) = getDefaultVLOps(VecVT, ContainerVT, DL, DAG, Subtarget); - SplatZero = DAG.getNode(RISCVISD::VMV_V_X_VL, DL, ContainerVT, SplatZero, VL); - SplatTrueVal = - DAG.getNode(RISCVISD::VMV_V_X_VL, DL, ContainerVT, SplatTrueVal, VL); + MVT XLenVT = Subtarget.getXLenVT(); + SDValue SplatZero = DAG.getConstant(0, DL, XLenVT); + SDValue SplatTrueVal = DAG.getConstant(ExtTrueVal, DL, XLenVT); + + SplatZero = DAG.getNode(RISCVISD::VMV_V_X_VL, DL, ContainerVT, + DAG.getUNDEF(ContainerVT), SplatZero, VL); + SplatTrueVal = DAG.getNode(RISCVISD::VMV_V_X_VL, DL, ContainerVT, + DAG.getUNDEF(ContainerVT), SplatTrueVal, VL); SDValue Select = DAG.getNode(RISCVISD::VSELECT_VL, DL, ContainerVT, CC, SplatTrueVal, SplatZero, VL); @@ -4151,8 +4203,9 @@ SDValue RISCVTargetLowering::lowerFixedLengthVectorExtendToRVV( // Custom-lower truncations from vectors to mask vectors by using a mask and a // setcc operation: // (vXi1 = trunc vXiN vec) -> (vXi1 = setcc (and vec, 1), 0, ne) -SDValue RISCVTargetLowering::lowerVectorMaskTrunc(SDValue Op, - SelectionDAG &DAG) const { +SDValue 
RISCVTargetLowering::lowerVectorMaskTruncLike(SDValue Op, + SelectionDAG &DAG) const { + bool IsVPTrunc = Op.getOpcode() == ISD::VP_TRUNCATE; SDLoc DL(Op); EVT MaskVT = Op.getValueType(); // Only expect to custom-lower truncations to mask types @@ -4160,34 +4213,176 @@ SDValue RISCVTargetLowering::lowerVectorMaskTrunc(SDValue Op, "Unexpected type for vector mask lowering"); SDValue Src = Op.getOperand(0); MVT VecVT = Src.getSimpleValueType(); - + SDValue Mask, VL; + if (IsVPTrunc) { + Mask = Op.getOperand(1); + VL = Op.getOperand(2); + } // If this is a fixed vector, we need to convert it to a scalable vector. MVT ContainerVT = VecVT; + if (VecVT.isFixedLengthVector()) { ContainerVT = getContainerForFixedLengthVector(VecVT); Src = convertToScalableVector(ContainerVT, Src, DAG, Subtarget); + if (IsVPTrunc) { + MVT MaskContainerVT = + getContainerForFixedLengthVector(Mask.getSimpleValueType()); + Mask = convertToScalableVector(MaskContainerVT, Mask, DAG, Subtarget); + } + } + + if (!IsVPTrunc) { + std::tie(Mask, VL) = + getDefaultVLOps(VecVT, ContainerVT, DL, DAG, Subtarget); } SDValue SplatOne = DAG.getConstant(1, DL, Subtarget.getXLenVT()); SDValue SplatZero = DAG.getConstant(0, DL, Subtarget.getXLenVT()); - SplatOne = DAG.getNode(RISCVISD::VMV_V_X_VL, DL, ContainerVT, SplatOne); - SplatZero = DAG.getNode(RISCVISD::VMV_V_X_VL, DL, ContainerVT, SplatZero); - - if (VecVT.isScalableVector()) { - SDValue Trunc = DAG.getNode(ISD::AND, DL, VecVT, Src, SplatOne); - return DAG.getSetCC(DL, MaskVT, Trunc, SplatZero, ISD::SETNE); - } - - SDValue Mask, VL; - std::tie(Mask, VL) = getDefaultVLOps(VecVT, ContainerVT, DL, DAG, Subtarget); + SplatOne = DAG.getNode(RISCVISD::VMV_V_X_VL, DL, ContainerVT, + DAG.getUNDEF(ContainerVT), SplatOne, VL); + SplatZero = DAG.getNode(RISCVISD::VMV_V_X_VL, DL, ContainerVT, + DAG.getUNDEF(ContainerVT), SplatZero, VL); MVT MaskContainerVT = ContainerVT.changeVectorElementType(MVT::i1); SDValue Trunc = DAG.getNode(RISCVISD::AND_VL, DL, ContainerVT, Src, SplatOne, Mask, VL); Trunc = DAG.getNode(RISCVISD::SETCC_VL, DL, MaskContainerVT, Trunc, SplatZero, DAG.getCondCode(ISD::SETNE), Mask, VL); - return convertFromScalableVector(MaskVT, Trunc, DAG, Subtarget); + if (MaskVT.isFixedLengthVector()) + Trunc = convertFromScalableVector(MaskVT, Trunc, DAG, Subtarget); + return Trunc; +} + +SDValue RISCVTargetLowering::lowerVectorTruncLike(SDValue Op, + SelectionDAG &DAG) const { + bool IsVPTrunc = Op.getOpcode() == ISD::VP_TRUNCATE; + SDLoc DL(Op); + + MVT VT = Op.getSimpleValueType(); + // Only custom-lower vector truncates + assert(VT.isVector() && "Unexpected type for vector truncate lowering"); + + // Truncates to mask types are handled differently + if (VT.getVectorElementType() == MVT::i1) + return lowerVectorMaskTruncLike(Op, DAG); + + // RVV only has truncates which operate from SEW*2->SEW, so lower arbitrary + // truncates as a series of "RISCVISD::TRUNCATE_VECTOR_VL" nodes which + // truncate by one power of two at a time. 
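// The halving chain built by the loop below, as a scalar sketch (illustrative
// only; FromBits and ToBits are powers of two with ToBits < FromBits <= 64):
static unsigned long long truncByHalvesSketch(unsigned long long X,
                                              unsigned FromBits,
                                              unsigned ToBits) {
  while (FromBits != ToBits) {
    FromBits /= 2;               // one TRUNCATE_VECTOR_VL hop: SEW*2 -> SEW
    X &= (1ULL << FromBits) - 1; // an i64 -> i8 truncate takes three hops
  }
  return X;
}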
+  MVT DstEltVT = VT.getVectorElementType();
+
+  SDValue Src = Op.getOperand(0);
+  MVT SrcVT = Src.getSimpleValueType();
+  MVT SrcEltVT = SrcVT.getVectorElementType();
+
+  assert(DstEltVT.bitsLT(SrcEltVT) && isPowerOf2_64(DstEltVT.getSizeInBits()) &&
+         isPowerOf2_64(SrcEltVT.getSizeInBits()) &&
+         "Unexpected vector truncate lowering");
+
+  MVT ContainerVT = SrcVT;
+  SDValue Mask, VL;
+  if (IsVPTrunc) {
+    Mask = Op.getOperand(1);
+    VL = Op.getOperand(2);
+  }
+  if (SrcVT.isFixedLengthVector()) {
+    ContainerVT = getContainerForFixedLengthVector(SrcVT);
+    Src = convertToScalableVector(ContainerVT, Src, DAG, Subtarget);
+    if (IsVPTrunc) {
+      MVT MaskVT = getMaskTypeFor(ContainerVT);
+      Mask = convertToScalableVector(MaskVT, Mask, DAG, Subtarget);
+    }
+  }
+
+  SDValue Result = Src;
+  if (!IsVPTrunc) {
+    std::tie(Mask, VL) =
+        getDefaultVLOps(SrcVT, ContainerVT, DL, DAG, Subtarget);
+  }
+
+  LLVMContext &Context = *DAG.getContext();
+  const ElementCount Count = ContainerVT.getVectorElementCount();
+  do {
+    SrcEltVT = MVT::getIntegerVT(SrcEltVT.getSizeInBits() / 2);
+    EVT ResultVT = EVT::getVectorVT(Context, SrcEltVT, Count);
+    Result = DAG.getNode(RISCVISD::TRUNCATE_VECTOR_VL, DL, ResultVT, Result,
+                         Mask, VL);
+  } while (SrcEltVT != DstEltVT);
+
+  if (SrcVT.isFixedLengthVector())
+    Result = convertFromScalableVector(VT, Result, DAG, Subtarget);
+
+  return Result;
+}
+
+SDValue
+RISCVTargetLowering::lowerVectorFPExtendOrRoundLike(SDValue Op,
+                                                    SelectionDAG &DAG) const {
+  bool IsVP =
+      Op.getOpcode() == ISD::VP_FP_ROUND || Op.getOpcode() == ISD::VP_FP_EXTEND;
+  bool IsExtend =
+      Op.getOpcode() == ISD::VP_FP_EXTEND || Op.getOpcode() == ISD::FP_EXTEND;
+  // RVV can only truncate fp to types half the size of the source. We
+  // custom-lower f64->f16 rounds via RVV's round-to-odd float
+  // conversion instruction.
+  SDLoc DL(Op);
+  MVT VT = Op.getSimpleValueType();
+
+  assert(VT.isVector() && "Unexpected type for vector truncate lowering");
+
+  SDValue Src = Op.getOperand(0);
+  MVT SrcVT = Src.getSimpleValueType();
+
+  bool IsDirectExtend = IsExtend && (VT.getVectorElementType() != MVT::f64 ||
+                                     SrcVT.getVectorElementType() != MVT::f16);
+  bool IsDirectTrunc = !IsExtend && (VT.getVectorElementType() != MVT::f16 ||
+                                     SrcVT.getVectorElementType() != MVT::f64);
+
+  bool IsDirectConv = IsDirectExtend || IsDirectTrunc;
+
+  // Prepare any fixed-length vector operands.
+  MVT ContainerVT = VT;
+  SDValue Mask, VL;
+  if (IsVP) {
+    Mask = Op.getOperand(1);
+    VL = Op.getOperand(2);
+  }
+  if (VT.isFixedLengthVector()) {
+    MVT SrcContainerVT = getContainerForFixedLengthVector(SrcVT);
+    ContainerVT =
+        SrcContainerVT.changeVectorElementType(VT.getVectorElementType());
+    Src = convertToScalableVector(SrcContainerVT, Src, DAG, Subtarget);
+    if (IsVP) {
+      MVT MaskVT = getMaskTypeFor(ContainerVT);
+      Mask = convertToScalableVector(MaskVT, Mask, DAG, Subtarget);
+    }
+  }
+
+  if (!IsVP)
+    std::tie(Mask, VL) =
+        getDefaultVLOps(SrcVT, ContainerVT, DL, DAG, Subtarget);
+
+  unsigned ConvOpc = IsExtend ? RISCVISD::FP_EXTEND_VL : RISCVISD::FP_ROUND_VL;
+
+  if (IsDirectConv) {
+    Src = DAG.getNode(ConvOpc, DL, ContainerVT, Src, Mask, VL);
+    if (VT.isFixedLengthVector())
+      Src = convertFromScalableVector(VT, Src, DAG, Subtarget);
+    return Src;
+  }
+
+  unsigned InterConvOpc =
+      IsExtend ?
RISCVISD::FP_EXTEND_VL : RISCVISD::VFNCVT_ROD_VL; + + MVT InterVT = ContainerVT.changeVectorElementType(MVT::f32); + SDValue IntermediateConv = + DAG.getNode(InterConvOpc, DL, InterVT, Src, Mask, VL); + SDValue Result = + DAG.getNode(ConvOpc, DL, ContainerVT, IntermediateConv, Mask, VL); + if (VT.isFixedLengthVector()) + return convertFromScalableVector(VT, Result, DAG, Subtarget); + return Result; } // Custom-legalize INSERT_VECTOR_ELT so that the value is inserted into the @@ -4268,13 +4463,15 @@ SDValue RISCVTargetLowering::lowerINSERT_VECTOR_ELT(SDValue Op, SDValue InsertI64VL = DAG.getConstant(2, DL, XLenVT); // Note: We can't pass a UNDEF to the first VSLIDE1UP_VL since an untied // undef doesn't obey the earlyclobber constraint. Just splat a zero value. - ValInVec = DAG.getNode(RISCVISD::VMV_V_X_VL, DL, I32ContainerVT, Zero, - InsertI64VL); + ValInVec = DAG.getNode(RISCVISD::VMV_V_X_VL, DL, I32ContainerVT, + DAG.getUNDEF(I32ContainerVT), Zero, InsertI64VL); // First slide in the hi value, then the lo in underneath it. - ValInVec = DAG.getNode(RISCVISD::VSLIDE1UP_VL, DL, I32ContainerVT, ValInVec, - ValHi, I32Mask, InsertI64VL); - ValInVec = DAG.getNode(RISCVISD::VSLIDE1UP_VL, DL, I32ContainerVT, ValInVec, - ValLo, I32Mask, InsertI64VL); + ValInVec = DAG.getNode(RISCVISD::VSLIDE1UP_VL, DL, I32ContainerVT, + DAG.getUNDEF(I32ContainerVT), ValInVec, ValHi, + I32Mask, InsertI64VL); + ValInVec = DAG.getNode(RISCVISD::VSLIDE1UP_VL, DL, I32ContainerVT, + DAG.getUNDEF(I32ContainerVT), ValInVec, ValLo, + I32Mask, InsertI64VL); // Bitcast back to the right container type. ValInVec = DAG.getBitcast(ContainerVT, ValInVec); } @@ -4310,7 +4507,7 @@ SDValue RISCVTargetLowering::lowerEXTRACT_VECTOR_ELT(SDValue Op, unsigned WidenVecLen; SDValue ExtractElementIdx; SDValue ExtractBitIdx; - unsigned MaxEEW = Subtarget.getMaxELENForFixedLengthVectors(); + unsigned MaxEEW = Subtarget.getELEN(); MVT LargestEltVT = MVT::getIntegerVT( std::min(MaxEEW, unsigned(XLenVT.getSizeInBits()))); if (NumElts <= LargestEltVT.getSizeInBits()) { @@ -4360,8 +4557,7 @@ SDValue RISCVTargetLowering::lowerEXTRACT_VECTOR_ELT(SDValue Op, if (!isNullConstant(Idx)) { // Use a VL of 1 to avoid processing more elements than we need. SDValue VL = DAG.getConstant(1, DL, XLenVT); - MVT MaskVT = MVT::getVectorVT(MVT::i1, ContainerVT.getVectorElementCount()); - SDValue Mask = DAG.getNode(RISCVISD::VMSET_VL, DL, MaskVT, VL); + SDValue Mask = getAllOnesMask(ContainerVT, VL, DL, DAG); Vec = DAG.getNode(RISCVISD::VSLIDEDOWN_VL, DL, ContainerVT, DAG.getUNDEF(ContainerVT), Vec, Idx, Mask, VL); } @@ -4378,8 +4574,8 @@ SDValue RISCVTargetLowering::lowerEXTRACT_VECTOR_ELT(SDValue Op, // Some RVV intrinsics may claim that they want an integer operand to be // promoted or expanded. 
-static SDValue lowerVectorIntrinsicSplats(SDValue Op, SelectionDAG &DAG, - const RISCVSubtarget &Subtarget) { +static SDValue lowerVectorIntrinsicScalars(SDValue Op, SelectionDAG &DAG, + const RISCVSubtarget &Subtarget) { assert((Op.getOpcode() == ISD::INTRINSIC_WO_CHAIN || Op.getOpcode() == ISD::INTRINSIC_W_CHAIN) && "Unexpected opcode"); @@ -4393,10 +4589,10 @@ static SDValue lowerVectorIntrinsicSplats(SDValue Op, SelectionDAG &DAG, const RISCVVIntrinsicsTable::RISCVVIntrinsicInfo *II = RISCVVIntrinsicsTable::getRISCVVIntrinsicInfo(IntNo); - if (!II || !II->hasSplatOperand()) + if (!II || !II->hasScalarOperand()) return SDValue(); - unsigned SplatOp = II->SplatOperand + 1 + HasChain; + unsigned SplatOp = II->ScalarOperand + 1 + HasChain; assert(SplatOp < Op.getNumOperands()); SmallVector<SDValue, 8> Operands(Op->op_begin(), Op->op_end()); @@ -4426,28 +4622,141 @@ static SDValue lowerVectorIntrinsicSplats(SDValue Op, SelectionDAG &DAG, // that a widening operation never uses SEW=64. // NOTE: If this fails the below assert, we can probably just find the // element count from any operand or result and use it to construct the VT. - assert(II->SplatOperand > 0 && "Unexpected splat operand!"); + assert(II->ScalarOperand > 0 && "Unexpected splat operand!"); MVT VT = Op.getOperand(SplatOp - 1).getSimpleValueType(); // The more complex case is when the scalar is larger than XLenVT. assert(XLenVT == MVT::i32 && OpVT == MVT::i64 && VT.getVectorElementType() == MVT::i64 && "Unexpected VTs!"); - // If this is a sign-extended 32-bit constant, we can truncate it and rely - // on the instruction to sign-extend since SEW>XLEN. - if (auto *CVal = dyn_cast<ConstantSDNode>(ScalarOp)) { - if (isInt<32>(CVal->getSExtValue())) { - ScalarOp = DAG.getConstant(CVal->getSExtValue(), DL, MVT::i32); - return DAG.getNode(Op->getOpcode(), DL, Op->getVTList(), Operands); + // If this is a sign-extended 32-bit value, we can truncate it and rely on the + // instruction to sign-extend since SEW>XLEN. + if (DAG.ComputeNumSignBits(ScalarOp) > 32) { + ScalarOp = DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, ScalarOp); + return DAG.getNode(Op->getOpcode(), DL, Op->getVTList(), Operands); + } + + switch (IntNo) { + case Intrinsic::riscv_vslide1up: + case Intrinsic::riscv_vslide1down: + case Intrinsic::riscv_vslide1up_mask: + case Intrinsic::riscv_vslide1down_mask: { + // We need to special case these when the scalar is larger than XLen. + unsigned NumOps = Op.getNumOperands(); + bool IsMasked = NumOps == 7; + + // Convert the vector source to the equivalent nxvXi32 vector. + MVT I32VT = MVT::getVectorVT(MVT::i32, VT.getVectorElementCount() * 2); + SDValue Vec = DAG.getBitcast(I32VT, Operands[2]); + + SDValue ScalarLo = DAG.getNode(ISD::EXTRACT_ELEMENT, DL, MVT::i32, ScalarOp, + DAG.getConstant(0, DL, XLenVT)); + SDValue ScalarHi = DAG.getNode(ISD::EXTRACT_ELEMENT, DL, MVT::i32, ScalarOp, + DAG.getConstant(1, DL, XLenVT)); + + // Double the VL since we halved SEW. 
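[Editor's aside: a scalar sketch, not part of the imported patch, of the halved-SEW trick used in this hunk; splitI64 and vlmax are hypothetical helpers. The i64 scalar is handled as two i32 halves and the VL is doubled; the constant-AVL fast path below bounds the doubled VL by VLMAX, with RVVBitsPerBlock being 64 on RISC-V.]

#include <cstdint>

static void splitI64(uint64_t V, uint32_t &Lo, uint32_t &Hi) {
  Lo = static_cast<uint32_t>(V);       // ScalarLo above
  Hi = static_cast<uint32_t>(V >> 32); // ScalarHi above
}

// VLMAX = (VLEN / SEW) * LMUL, with LMUL = MinSize / RVVBitsPerBlock.
static unsigned vlmax(unsigned VLenBits, unsigned SewBits,
                      unsigned MinSizeBits) {
  return (VLenBits / SewBits) * MinSizeBits / 64;
}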
+  SDValue AVL = getVLOperand(Op);
+  SDValue I32VL;
+
+  // Optimize for constant AVL
+  if (isa<ConstantSDNode>(AVL)) {
+    unsigned EltSize = VT.getScalarSizeInBits();
+    unsigned MinSize = VT.getSizeInBits().getKnownMinValue();
+
+    unsigned VectorBitsMax = Subtarget.getRealMaxVLen();
+    unsigned MaxVLMAX =
+        RISCVTargetLowering::computeVLMAX(VectorBitsMax, EltSize, MinSize);
+
+    unsigned VectorBitsMin = Subtarget.getRealMinVLen();
+    unsigned MinVLMAX =
+        RISCVTargetLowering::computeVLMAX(VectorBitsMin, EltSize, MinSize);
+
+    uint64_t AVLInt = cast<ConstantSDNode>(AVL)->getZExtValue();
+    if (AVLInt <= MinVLMAX) {
+      I32VL = DAG.getConstant(2 * AVLInt, DL, XLenVT);
+    } else if (AVLInt >= 2 * MaxVLMAX) {
+      // Just set vl to VLMAX in this situation
+      RISCVII::VLMUL Lmul = RISCVTargetLowering::getLMUL(I32VT);
+      SDValue LMUL = DAG.getConstant(Lmul, DL, XLenVT);
+      unsigned Sew = RISCVVType::encodeSEW(I32VT.getScalarSizeInBits());
+      SDValue SEW = DAG.getConstant(Sew, DL, XLenVT);
+      SDValue SETVLMAX = DAG.getTargetConstant(
+          Intrinsic::riscv_vsetvlimax_opt, DL, MVT::i32);
+      I32VL = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, XLenVT, SETVLMAX, SEW,
+                          LMUL);
+    } else {
+      // For AVL between (MinVLMAX, 2 * MaxVLMAX), the actual working vl
+      // is related to the hardware implementation.
+      // So let the following code handle it.
+    }
+  }
+  if (!I32VL) {
+    RISCVII::VLMUL Lmul = RISCVTargetLowering::getLMUL(VT);
+    SDValue LMUL = DAG.getConstant(Lmul, DL, XLenVT);
+    unsigned Sew = RISCVVType::encodeSEW(VT.getScalarSizeInBits());
+    SDValue SEW = DAG.getConstant(Sew, DL, XLenVT);
+    SDValue SETVL =
+        DAG.getTargetConstant(Intrinsic::riscv_vsetvli_opt, DL, MVT::i32);
+    // Use the vsetvli instruction to get the actually used length, which is
+    // related to the hardware implementation.
+    SDValue VL = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, XLenVT, SETVL, AVL,
+                             SEW, LMUL);
+    I32VL =
+        DAG.getNode(ISD::SHL, DL, XLenVT, VL, DAG.getConstant(1, DL, XLenVT));
+  }
+
+  SDValue I32Mask = getAllOnesMask(I32VT, I32VL, DL, DAG);
+
+  // Shift the two scalar parts in using SEW=32 slide1up/slide1down
+  // instructions.
+  SDValue Passthru;
+  if (IsMasked)
+    Passthru = DAG.getUNDEF(I32VT);
+  else
+    Passthru = DAG.getBitcast(I32VT, Operands[1]);
+
+  if (IntNo == Intrinsic::riscv_vslide1up ||
+      IntNo == Intrinsic::riscv_vslide1up_mask) {
+    Vec = DAG.getNode(RISCVISD::VSLIDE1UP_VL, DL, I32VT, Passthru, Vec,
+                      ScalarHi, I32Mask, I32VL);
+    Vec = DAG.getNode(RISCVISD::VSLIDE1UP_VL, DL, I32VT, Passthru, Vec,
+                      ScalarLo, I32Mask, I32VL);
+  } else {
+    Vec = DAG.getNode(RISCVISD::VSLIDE1DOWN_VL, DL, I32VT, Passthru, Vec,
+                      ScalarLo, I32Mask, I32VL);
+    Vec = DAG.getNode(RISCVISD::VSLIDE1DOWN_VL, DL, I32VT, Passthru, Vec,
+                      ScalarHi, I32Mask, I32VL);
+  }
+
+  // Convert back to nxvXi64.
+  Vec = DAG.getBitcast(VT, Vec);
+
+  if (!IsMasked)
+    return Vec;
+  // Apply mask after the operation.
+  SDValue Mask = Operands[NumOps - 3];
+  SDValue MaskedOff = Operands[1];
+  // Assume Policy operand is the last operand.
+  uint64_t Policy =
+      cast<ConstantSDNode>(Operands[NumOps - 1])->getZExtValue();
+  // We don't need to select maskedoff if it's undef.
+  if (MaskedOff.isUndef())
+    return Vec;
+  // TAMU
+  if (Policy == RISCVII::TAIL_AGNOSTIC)
+    return DAG.getNode(RISCVISD::VSELECT_VL, DL, VT, Mask, Vec, MaskedOff,
+                       AVL);
+  // TUMA or TUMU: Currently we always emit tumu policy regardless of tuma.
+  // It's fine because vmerge does not care about mask policy.
+ return DAG.getNode(RISCVISD::VP_MERGE_VL, DL, VT, Mask, Vec, MaskedOff, + AVL); + } } // We need to convert the scalar to a splat vector. - // FIXME: Can we implicitly truncate the scalar if it is known to - // be sign extended? SDValue VL = getVLOperand(Op); assert(VL.getValueType() == XLenVT); - ScalarOp = splatSplitI64WithVL(DL, VT, ScalarOp, VL, DAG); + ScalarOp = splatSplitI64WithVL(DL, VT, SDValue(), ScalarOp, VL, DAG); return DAG.getNode(Op->getOpcode(), DL, Op->getVTList(), Operands); } @@ -4481,7 +4790,7 @@ SDValue RISCVTargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op, case Intrinsic::riscv_zip: case Intrinsic::riscv_unzip: { // Lower to the SHFLI encoding for zip or the UNSHFLI encoding for unzip. - // For i32 the immdiate is 15. For i64 the immediate is 31. + // For i32 the immediate is 15. For i64 the immediate is 31. unsigned Opc = IntNo == Intrinsic::riscv_zip ? RISCVISD::SHFL : RISCVISD::UNSHFL; unsigned BitWidth = Op.getValueSizeInBits(); @@ -4516,10 +4825,11 @@ SDValue RISCVTargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op, Op.getOperand(1)); case Intrinsic::riscv_vmv_v_x: return lowerScalarSplat(Op.getOperand(1), Op.getOperand(2), - Op.getSimpleValueType(), DL, DAG, Subtarget); + Op.getOperand(3), Op.getSimpleValueType(), DL, DAG, + Subtarget); case Intrinsic::riscv_vfmv_v_f: return DAG.getNode(RISCVISD::VFMV_V_F_VL, DL, Op.getValueType(), - Op.getOperand(1), Op.getOperand(2)); + Op.getOperand(1), Op.getOperand(2), Op.getOperand(3)); case Intrinsic::riscv_vmv_s_x: { SDValue Scalar = Op.getOperand(2); @@ -4533,7 +4843,7 @@ SDValue RISCVTargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op, // This is an i64 value that lives in two scalar registers. We have to // insert this in a convoluted way. First we build vXi64 splat containing - // the/ two values that we assemble using some bit math. Next we'll use + // the two values that we assemble using some bit math. Next we'll use // vid.v and vmseq to build a mask with bit 0 set. Then we'll use that mask // to merge element 0 from our splat into the source vector. // FIXME: This is probably not the best way to do this, but it is @@ -4550,12 +4860,15 @@ SDValue RISCVTargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op, SDValue Vec = Op.getOperand(1); SDValue VL = getVLOperand(Op); - SDValue SplattedVal = splatSplitI64WithVL(DL, VT, Scalar, VL, DAG); - SDValue SplattedIdx = DAG.getNode(RISCVISD::VMV_V_X_VL, DL, VT, - DAG.getConstant(0, DL, MVT::i32), VL); + SDValue SplattedVal = splatSplitI64WithVL(DL, VT, SDValue(), Scalar, VL, DAG); + if (Op.getOperand(1).isUndef()) + return SplattedVal; + SDValue SplattedIdx = + DAG.getNode(RISCVISD::VMV_V_X_VL, DL, VT, DAG.getUNDEF(VT), + DAG.getConstant(0, DL, MVT::i32), VL); - MVT MaskVT = MVT::getVectorVT(MVT::i1, VT.getVectorElementCount()); - SDValue Mask = DAG.getNode(RISCVISD::VMSET_VL, DL, MaskVT, VL); + MVT MaskVT = getMaskTypeFor(VT); + SDValue Mask = getAllOnesMask(VT, VL, DL, DAG); SDValue VID = DAG.getNode(RISCVISD::VID_VL, DL, VT, Mask, VL); SDValue SelectCond = DAG.getNode(RISCVISD::SETCC_VL, DL, MaskVT, VID, SplattedIdx, @@ -4563,73 +4876,9 @@ SDValue RISCVTargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op, return DAG.getNode(RISCVISD::VSELECT_VL, DL, VT, SelectCond, SplattedVal, Vec, VL); } - case Intrinsic::riscv_vslide1up: - case Intrinsic::riscv_vslide1down: - case Intrinsic::riscv_vslide1up_mask: - case Intrinsic::riscv_vslide1down_mask: { - // We need to special case these when the scalar is larger than XLen. 
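[Editor's aside: a small sketch, not part of the imported patch, of the zip/unzip SHFLI control chosen earlier in this hunk; shflZipImm is a hypothetical helper.]

static unsigned shflZipImm(unsigned BitWidth) {
  return BitWidth / 2 - 1; // 15 for i32, 31 for i64, as the comment above notes
}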
- unsigned NumOps = Op.getNumOperands(); - bool IsMasked = NumOps == 7; - unsigned OpOffset = IsMasked ? 1 : 0; - SDValue Scalar = Op.getOperand(2 + OpOffset); - if (Scalar.getValueType().bitsLE(XLenVT)) - break; - - // Splatting a sign extended constant is fine. - if (auto *CVal = dyn_cast<ConstantSDNode>(Scalar)) - if (isInt<32>(CVal->getSExtValue())) - break; - - MVT VT = Op.getSimpleValueType(); - assert(VT.getVectorElementType() == MVT::i64 && - Scalar.getValueType() == MVT::i64 && "Unexpected VTs"); - - // Convert the vector source to the equivalent nxvXi32 vector. - MVT I32VT = MVT::getVectorVT(MVT::i32, VT.getVectorElementCount() * 2); - SDValue Vec = DAG.getBitcast(I32VT, Op.getOperand(1 + OpOffset)); - - SDValue ScalarLo = DAG.getNode(ISD::EXTRACT_ELEMENT, DL, MVT::i32, Scalar, - DAG.getConstant(0, DL, XLenVT)); - SDValue ScalarHi = DAG.getNode(ISD::EXTRACT_ELEMENT, DL, MVT::i32, Scalar, - DAG.getConstant(1, DL, XLenVT)); - - // Double the VL since we halved SEW. - SDValue VL = getVLOperand(Op); - SDValue I32VL = - DAG.getNode(ISD::SHL, DL, XLenVT, VL, DAG.getConstant(1, DL, XLenVT)); - - MVT I32MaskVT = MVT::getVectorVT(MVT::i1, I32VT.getVectorElementCount()); - SDValue I32Mask = DAG.getNode(RISCVISD::VMSET_VL, DL, I32MaskVT, VL); - - // Shift the two scalar parts in using SEW=32 slide1up/slide1down - // instructions. - if (IntNo == Intrinsic::riscv_vslide1up || - IntNo == Intrinsic::riscv_vslide1up_mask) { - Vec = DAG.getNode(RISCVISD::VSLIDE1UP_VL, DL, I32VT, Vec, ScalarHi, - I32Mask, I32VL); - Vec = DAG.getNode(RISCVISD::VSLIDE1UP_VL, DL, I32VT, Vec, ScalarLo, - I32Mask, I32VL); - } else { - Vec = DAG.getNode(RISCVISD::VSLIDE1DOWN_VL, DL, I32VT, Vec, ScalarLo, - I32Mask, I32VL); - Vec = DAG.getNode(RISCVISD::VSLIDE1DOWN_VL, DL, I32VT, Vec, ScalarHi, - I32Mask, I32VL); - } - - // Convert back to nxvXi64. - Vec = DAG.getBitcast(VT, Vec); - - if (!IsMasked) - return Vec; - - // Apply mask after the operation. 
- SDValue Mask = Op.getOperand(NumOps - 3); - SDValue MaskedOff = Op.getOperand(1); - return DAG.getNode(RISCVISD::VSELECT_VL, DL, VT, Mask, Vec, MaskedOff, VL); - } } - return lowerVectorIntrinsicSplats(Op, DAG, Subtarget); + return lowerVectorIntrinsicScalars(Op, DAG, Subtarget); } SDValue RISCVTargetLowering::LowerINTRINSIC_W_CHAIN(SDValue Op, @@ -4652,8 +4901,7 @@ SDValue RISCVTargetLowering::LowerINTRINSIC_W_CHAIN(SDValue Op, SDValue PassThru = Op.getOperand(2); if (!IsUnmasked) { - MVT MaskVT = - MVT::getVectorVT(MVT::i1, ContainerVT.getVectorElementCount()); + MVT MaskVT = getMaskTypeFor(ContainerVT); Mask = convertToScalableVector(MaskVT, Mask, DAG, Subtarget); PassThru = convertToScalableVector(ContainerVT, PassThru, DAG, Subtarget); } @@ -4688,9 +4936,48 @@ SDValue RISCVTargetLowering::LowerINTRINSIC_W_CHAIN(SDValue Op, Result = convertFromScalableVector(VT, Result, DAG, Subtarget); return DAG.getMergeValues({Result, Chain}, DL); } + case Intrinsic::riscv_seg2_load: + case Intrinsic::riscv_seg3_load: + case Intrinsic::riscv_seg4_load: + case Intrinsic::riscv_seg5_load: + case Intrinsic::riscv_seg6_load: + case Intrinsic::riscv_seg7_load: + case Intrinsic::riscv_seg8_load: { + SDLoc DL(Op); + static const Intrinsic::ID VlsegInts[7] = { + Intrinsic::riscv_vlseg2, Intrinsic::riscv_vlseg3, + Intrinsic::riscv_vlseg4, Intrinsic::riscv_vlseg5, + Intrinsic::riscv_vlseg6, Intrinsic::riscv_vlseg7, + Intrinsic::riscv_vlseg8}; + unsigned NF = Op->getNumValues() - 1; + assert(NF >= 2 && NF <= 8 && "Unexpected seg number"); + MVT XLenVT = Subtarget.getXLenVT(); + MVT VT = Op->getSimpleValueType(0); + MVT ContainerVT = getContainerForFixedLengthVector(VT); + + SDValue VL = DAG.getConstant(VT.getVectorNumElements(), DL, XLenVT); + SDValue IntID = DAG.getTargetConstant(VlsegInts[NF - 2], DL, XLenVT); + auto *Load = cast<MemIntrinsicSDNode>(Op); + SmallVector<EVT, 9> ContainerVTs(NF, ContainerVT); + ContainerVTs.push_back(MVT::Other); + SDVTList VTs = DAG.getVTList(ContainerVTs); + SmallVector<SDValue, 12> Ops = {Load->getChain(), IntID}; + Ops.insert(Ops.end(), NF, DAG.getUNDEF(ContainerVT)); + Ops.push_back(Op.getOperand(2)); + Ops.push_back(VL); + SDValue Result = + DAG.getMemIntrinsicNode(ISD::INTRINSIC_W_CHAIN, DL, VTs, Ops, + Load->getMemoryVT(), Load->getMemOperand()); + SmallVector<SDValue, 9> Results; + for (unsigned int RetIdx = 0; RetIdx < NF; RetIdx++) + Results.push_back(convertFromScalableVector(VT, Result.getValue(RetIdx), + DAG, Subtarget)); + Results.push_back(Result.getValue(NF)); + return DAG.getMergeValues(Results, DL); + } } - return lowerVectorIntrinsicSplats(Op, DAG, Subtarget); + return lowerVectorIntrinsicScalars(Op, DAG, Subtarget); } SDValue RISCVTargetLowering::LowerINTRINSIC_VOID(SDValue Op, @@ -4714,8 +5001,7 @@ SDValue RISCVTargetLowering::LowerINTRINSIC_VOID(SDValue Op, Val = convertToScalableVector(ContainerVT, Val, DAG, Subtarget); if (!IsUnmasked) { - MVT MaskVT = - MVT::getVectorVT(MVT::i1, ContainerVT.getVectorElementCount()); + MVT MaskVT = getMaskTypeFor(ContainerVT); Mask = convertToScalableVector(MaskVT, Mask, DAG, Subtarget); } @@ -4898,8 +5184,9 @@ SDValue RISCVTargetLowering::lowerVECREDUCE(SDValue Op, SDValue NeutralElem = DAG.getNeutralElement(BaseOpc, DL, VecEltVT, SDNodeFlags()); - SDValue IdentitySplat = lowerScalarSplat( - NeutralElem, DAG.getConstant(1, DL, XLenVT), M1VT, DL, DAG, Subtarget); + SDValue IdentitySplat = + lowerScalarSplat(SDValue(), NeutralElem, DAG.getConstant(1, DL, XLenVT), + M1VT, DL, DAG, Subtarget); SDValue Reduction = 
      DAG.getNode(RVVOpcode, DL, M1VT, DAG.getUNDEF(M1VT), Vec,
                  IdentitySplat, Mask, VL);
  SDValue Elt0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VecEltVT, Reduction,
@@ -4960,8 +5247,9 @@ SDValue RISCVTargetLowering::lowerFPVECREDUCE(SDValue Op,
  SDValue Mask, VL;
  std::tie(Mask, VL) = getDefaultVLOps(VecVT, ContainerVT, DL, DAG, Subtarget);
-  SDValue ScalarSplat = lowerScalarSplat(
-      ScalarVal, DAG.getConstant(1, DL, XLenVT), M1VT, DL, DAG, Subtarget);
+  SDValue ScalarSplat =
+      lowerScalarSplat(SDValue(), ScalarVal, DAG.getConstant(1, DL, XLenVT),
+                       M1VT, DL, DAG, Subtarget);
  SDValue Reduction = DAG.getNode(RVVOpcode, DL, M1VT, DAG.getUNDEF(M1VT),
                                  VectorVal, ScalarSplat, Mask, VL);
  return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VecEltVT, Reduction,
@@ -5027,9 +5315,9 @@ SDValue RISCVTargetLowering::lowerVPREDUCE(SDValue Op,
  MVT XLenVT = Subtarget.getXLenVT();
  MVT ResVT = !VecVT.isInteger() || VecEltVT.bitsGE(XLenVT) ? VecEltVT : XLenVT;
-  SDValue StartSplat =
-      lowerScalarSplat(Op.getOperand(0), DAG.getConstant(1, DL, XLenVT), M1VT,
-                       DL, DAG, Subtarget);
+  SDValue StartSplat = lowerScalarSplat(SDValue(), Op.getOperand(0),
+                                        DAG.getConstant(1, DL, XLenVT), M1VT,
+                                        DL, DAG, Subtarget);
  SDValue Reduction =
      DAG.getNode(RVVOpcode, DL, M1VT, StartSplat, Vec, StartSplat, Mask, VL);
  SDValue Elt0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ResVT, Reduction,
@@ -5331,13 +5619,13 @@ SDValue RISCVTargetLowering::lowerSTEP_VECTOR(SDValue Op,
  if (StepValImm != 1) {
    if (isPowerOf2_64(StepValImm)) {
      SDValue StepVal =
-          DAG.getNode(RISCVISD::VMV_V_X_VL, DL, VT,
+          DAG.getNode(RISCVISD::VMV_V_X_VL, DL, VT, DAG.getUNDEF(VT),
                      DAG.getConstant(Log2_64(StepValImm), DL, XLenVT));
      StepVec = DAG.getNode(ISD::SHL, DL, VT, StepVec, StepVal);
    } else {
      SDValue StepVal = lowerScalarSplat(
-          DAG.getConstant(StepValImm, DL, VT.getVectorElementType()), VL, VT,
-          DL, DAG, Subtarget);
+          SDValue(), DAG.getConstant(StepValImm, DL, VT.getVectorElementType()),
+          VL, VT, DL, DAG, Subtarget);
      StepVec = DAG.getNode(ISD::MUL, DL, VT, StepVec, StepVal);
    }
  }
@@ -5353,22 +5641,26 @@ SDValue RISCVTargetLowering::lowerVECTOR_REVERSE(SDValue Op,
                                                 SelectionDAG &DAG) const {
  SDLoc DL(Op);
  MVT VecVT = Op.getSimpleValueType();
+  if (VecVT.getVectorElementType() == MVT::i1) {
+    MVT WidenVT = MVT::getVectorVT(MVT::i8, VecVT.getVectorElementCount());
+    SDValue Op1 = DAG.getNode(ISD::ZERO_EXTEND, DL, WidenVT, Op.getOperand(0));
+    SDValue Op2 = DAG.getNode(ISD::VECTOR_REVERSE, DL, WidenVT, Op1);
+    return DAG.getNode(ISD::TRUNCATE, DL, VecVT, Op2);
+  }
  unsigned EltSize = VecVT.getScalarSizeInBits();
  unsigned MinSize = VecVT.getSizeInBits().getKnownMinValue();
-
-  unsigned MaxVLMAX = 0;
-  unsigned VectorBitsMax = Subtarget.getMaxRVVVectorSizeInBits();
-  if (VectorBitsMax != 0)
-    MaxVLMAX = ((VectorBitsMax / EltSize) * MinSize) / RISCV::RVVBitsPerBlock;
+  unsigned VectorBitsMax = Subtarget.getRealMaxVLen();
+  unsigned MaxVLMAX =
+      RISCVTargetLowering::computeVLMAX(VectorBitsMax, EltSize, MinSize);
 
  unsigned GatherOpc = RISCVISD::VRGATHER_VV_VL;
  MVT IntVT = VecVT.changeVectorElementTypeToInteger();
 
-  // If this is SEW=8 and VLMAX is unknown or more than 256, we need
+  // If this is SEW=8 and VLMAX is potentially more than 256, we need
  // to use vrgatherei16.vv.
  // TODO: It's also possible to use vrgatherei16.vv for other types to
  // decrease register width for the index calculation.
-  if ((MaxVLMAX == 0 || MaxVLMAX > 256) && EltSize == 8) {
+  if (MaxVLMAX > 256 && EltSize == 8) {
    // If this is LMUL=8, we have to split before we can use vrgatherei16.vv.
// Reverse each half, then reassemble them in reverse order. // NOTE: It's also possible that after splitting that VLMAX no longer @@ -5413,13 +5705,51 @@ SDValue RISCVTargetLowering::lowerVECTOR_REVERSE(SDValue Op, if (!IsRV32E64) SplatVL = DAG.getSplatVector(IntVT, DL, VLMinus1); else - SplatVL = DAG.getNode(RISCVISD::SPLAT_VECTOR_I64, DL, IntVT, VLMinus1); + SplatVL = DAG.getNode(RISCVISD::VMV_V_X_VL, DL, IntVT, DAG.getUNDEF(IntVT), + VLMinus1, DAG.getRegister(RISCV::X0, XLenVT)); SDValue VID = DAG.getNode(RISCVISD::VID_VL, DL, IntVT, Mask, VL); SDValue Indices = DAG.getNode(RISCVISD::SUB_VL, DL, IntVT, SplatVL, VID, Mask, VL); - return DAG.getNode(GatherOpc, DL, VecVT, Op.getOperand(0), Indices, Mask, VL); + return DAG.getNode(GatherOpc, DL, VecVT, Op.getOperand(0), Indices, Mask, + DAG.getUNDEF(VecVT), VL); +} + +SDValue RISCVTargetLowering::lowerVECTOR_SPLICE(SDValue Op, + SelectionDAG &DAG) const { + SDLoc DL(Op); + SDValue V1 = Op.getOperand(0); + SDValue V2 = Op.getOperand(1); + MVT XLenVT = Subtarget.getXLenVT(); + MVT VecVT = Op.getSimpleValueType(); + + unsigned MinElts = VecVT.getVectorMinNumElements(); + SDValue VLMax = DAG.getNode(ISD::VSCALE, DL, XLenVT, + DAG.getConstant(MinElts, DL, XLenVT)); + + int64_t ImmValue = cast<ConstantSDNode>(Op.getOperand(2))->getSExtValue(); + SDValue DownOffset, UpOffset; + if (ImmValue >= 0) { + // The operand is a TargetConstant, we need to rebuild it as a regular + // constant. + DownOffset = DAG.getConstant(ImmValue, DL, XLenVT); + UpOffset = DAG.getNode(ISD::SUB, DL, XLenVT, VLMax, DownOffset); + } else { + // The operand is a TargetConstant, we need to rebuild it as a regular + // constant rather than negating the original operand. + UpOffset = DAG.getConstant(-ImmValue, DL, XLenVT); + DownOffset = DAG.getNode(ISD::SUB, DL, XLenVT, VLMax, UpOffset); + } + + SDValue TrueMask = getAllOnesMask(VecVT, VLMax, DL, DAG); + + SDValue SlideDown = + DAG.getNode(RISCVISD::VSLIDEDOWN_VL, DL, VecVT, DAG.getUNDEF(VecVT), V1, + DownOffset, TrueMask, UpOffset); + return DAG.getNode(RISCVISD::VSLIDEUP_VL, DL, VecVT, SlideDown, V2, UpOffset, + TrueMask, + DAG.getTargetConstant(RISCV::VLMaxSentinel, DL, XLenVT)); } SDValue @@ -5434,18 +5764,26 @@ RISCVTargetLowering::lowerFixedLengthVectorLoadToRVV(SDValue Op, "Expecting a correctly-aligned load"); MVT VT = Op.getSimpleValueType(); + MVT XLenVT = Subtarget.getXLenVT(); MVT ContainerVT = getContainerForFixedLengthVector(VT); - SDValue VL = - DAG.getConstant(VT.getVectorNumElements(), DL, Subtarget.getXLenVT()); + SDValue VL = DAG.getConstant(VT.getVectorNumElements(), DL, XLenVT); + bool IsMaskOp = VT.getVectorElementType() == MVT::i1; + SDValue IntID = DAG.getTargetConstant( + IsMaskOp ? 
Intrinsic::riscv_vlm : Intrinsic::riscv_vle, DL, XLenVT);
+  SmallVector<SDValue, 4> Ops{Load->getChain(), IntID};
+  if (!IsMaskOp)
+    Ops.push_back(DAG.getUNDEF(ContainerVT));
+  Ops.push_back(Load->getBasePtr());
+  Ops.push_back(VL);
  SDVTList VTs = DAG.getVTList({ContainerVT, MVT::Other});
-  SDValue NewLoad = DAG.getMemIntrinsicNode(
-      RISCVISD::VLE_VL, DL, VTs, {Load->getChain(), Load->getBasePtr(), VL},
-      Load->getMemoryVT(), Load->getMemOperand());
+  SDValue NewLoad =
+      DAG.getMemIntrinsicNode(ISD::INTRINSIC_W_CHAIN, DL, VTs, Ops,
+                              Load->getMemoryVT(), Load->getMemOperand());
  SDValue Result = convertFromScalableVector(VT, NewLoad, DAG, Subtarget);
-  return DAG.getMergeValues({Result, Load->getChain()}, DL);
+  return DAG.getMergeValues({Result, NewLoad.getValue(1)}, DL);
 }
 
 SDValue
@@ -5461,6 +5799,7 @@ RISCVTargetLowering::lowerFixedLengthVectorStoreToRVV(SDValue Op,
 
  SDValue StoreVal = Store->getValue();
  MVT VT = StoreVal.getSimpleValueType();
+  MVT XLenVT = Subtarget.getXLenVT();
 
  // If the size is less than a byte, we need to pad with zeros to make a byte.
  if (VT.getVectorElementType() == MVT::i1 && VT.getVectorNumElements() < 8) {
@@ -5472,14 +5811,17 @@ RISCVTargetLowering::lowerFixedLengthVectorStoreToRVV(SDValue Op,
 
  MVT ContainerVT = getContainerForFixedLengthVector(VT);
 
-  SDValue VL =
-      DAG.getConstant(VT.getVectorNumElements(), DL, Subtarget.getXLenVT());
+  SDValue VL = DAG.getConstant(VT.getVectorNumElements(), DL, XLenVT);
 
  SDValue NewValue =
      convertToScalableVector(ContainerVT, StoreVal, DAG, Subtarget);
+
+  bool IsMaskOp = VT.getVectorElementType() == MVT::i1;
+  SDValue IntID = DAG.getTargetConstant(
+      IsMaskOp ? Intrinsic::riscv_vsm : Intrinsic::riscv_vse, DL, XLenVT);
  return DAG.getMemIntrinsicNode(
-      RISCVISD::VSE_VL, DL, DAG.getVTList(MVT::Other),
-      {Store->getChain(), NewValue, Store->getBasePtr(), VL},
+      ISD::INTRINSIC_VOID, DL, DAG.getVTList(MVT::Other),
+      {Store->getChain(), IntID, NewValue, Store->getBasePtr(), VL},
      Store->getMemoryVT(), Store->getMemOperand());
 }
 
@@ -5514,8 +5856,7 @@ SDValue RISCVTargetLowering::lowerMaskedLoad(SDValue Op,
    ContainerVT = getContainerForFixedLengthVector(VT);
    PassThru = convertToScalableVector(ContainerVT, PassThru, DAG, Subtarget);
    if (!IsUnmasked) {
-      MVT MaskVT =
-          MVT::getVectorVT(MVT::i1, ContainerVT.getVectorElementCount());
+      MVT MaskVT = getMaskTypeFor(ContainerVT);
      Mask = convertToScalableVector(MaskVT, Mask, DAG, Subtarget);
    }
  }
@@ -5581,8 +5922,7 @@ SDValue RISCVTargetLowering::lowerMaskedStore(SDValue Op,
    Val = convertToScalableVector(ContainerVT, Val, DAG, Subtarget);
 
    if (!IsUnmasked) {
-      MVT MaskVT =
-          MVT::getVectorVT(MVT::i1, ContainerVT.getVectorElementCount());
+      MVT MaskVT = getMaskTypeFor(ContainerVT);
      Mask = convertToScalableVector(MaskVT, Mask, DAG, Subtarget);
    }
  }
@@ -5620,8 +5960,8 @@ RISCVTargetLowering::lowerFixedLengthVectorSetccToRVV(SDValue Op,
  SDValue VL =
      DAG.getConstant(VT.getVectorNumElements(), DL, Subtarget.getXLenVT());
 
-  MVT MaskVT = MVT::getVectorVT(MVT::i1, ContainerVT.getVectorElementCount());
-  SDValue Mask = DAG.getNode(RISCVISD::VMSET_VL, DL, MaskVT, VL);
+  MVT MaskVT = getMaskTypeFor(ContainerVT);
+  SDValue Mask = getAllOnesMask(ContainerVT, VL, DL, DAG);
  SDValue Cmp =
      DAG.getNode(RISCVISD::SETCC_VL, DL, MaskVT, Op1, Op2, Op.getOperand(2),
                  Mask, VL);
@@ -5667,9 +6007,9 @@ SDValue RISCVTargetLowering::lowerABS(SDValue Op, SelectionDAG &DAG) const {
  SDValue Mask, VL;
  std::tie(Mask, VL) = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget);
 
-  SDValue SplatZero =
-      DAG.getNode(RISCVISD::VMV_V_X_VL,
DL, ContainerVT, - DAG.getConstant(0, DL, Subtarget.getXLenVT())); + SDValue SplatZero = DAG.getNode( + RISCVISD::VMV_V_X_VL, DL, ContainerVT, DAG.getUNDEF(ContainerVT), + DAG.getConstant(0, DL, Subtarget.getXLenVT())); SDValue NegX = DAG.getNode(RISCVISD::SUB_VL, DL, ContainerVT, SplatZero, X, Mask, VL); SDValue Max = @@ -5787,15 +6127,260 @@ SDValue RISCVTargetLowering::lowerVPOp(SDValue Op, SelectionDAG &DAG, } if (!VT.isFixedLengthVector()) - return DAG.getNode(RISCVISDOpc, DL, VT, Ops); + return DAG.getNode(RISCVISDOpc, DL, VT, Ops, Op->getFlags()); MVT ContainerVT = getContainerForFixedLengthVector(VT); - SDValue VPOp = DAG.getNode(RISCVISDOpc, DL, ContainerVT, Ops); + SDValue VPOp = DAG.getNode(RISCVISDOpc, DL, ContainerVT, Ops, Op->getFlags()); return convertFromScalableVector(VT, VPOp, DAG, Subtarget); } +SDValue RISCVTargetLowering::lowerVPExtMaskOp(SDValue Op, + SelectionDAG &DAG) const { + SDLoc DL(Op); + MVT VT = Op.getSimpleValueType(); + + SDValue Src = Op.getOperand(0); + // NOTE: Mask is dropped. + SDValue VL = Op.getOperand(2); + + MVT ContainerVT = VT; + if (VT.isFixedLengthVector()) { + ContainerVT = getContainerForFixedLengthVector(VT); + MVT SrcVT = MVT::getVectorVT(MVT::i1, ContainerVT.getVectorElementCount()); + Src = convertToScalableVector(SrcVT, Src, DAG, Subtarget); + } + + MVT XLenVT = Subtarget.getXLenVT(); + SDValue Zero = DAG.getConstant(0, DL, XLenVT); + SDValue ZeroSplat = DAG.getNode(RISCVISD::VMV_V_X_VL, DL, ContainerVT, + DAG.getUNDEF(ContainerVT), Zero, VL); + + SDValue SplatValue = DAG.getConstant( + Op.getOpcode() == ISD::VP_ZERO_EXTEND ? 1 : -1, DL, XLenVT); + SDValue Splat = DAG.getNode(RISCVISD::VMV_V_X_VL, DL, ContainerVT, + DAG.getUNDEF(ContainerVT), SplatValue, VL); + + SDValue Result = DAG.getNode(RISCVISD::VSELECT_VL, DL, ContainerVT, Src, + Splat, ZeroSplat, VL); + if (!VT.isFixedLengthVector()) + return Result; + return convertFromScalableVector(VT, Result, DAG, Subtarget); +} + +SDValue RISCVTargetLowering::lowerVPSetCCMaskOp(SDValue Op, + SelectionDAG &DAG) const { + SDLoc DL(Op); + MVT VT = Op.getSimpleValueType(); + + SDValue Op1 = Op.getOperand(0); + SDValue Op2 = Op.getOperand(1); + ISD::CondCode Condition = cast<CondCodeSDNode>(Op.getOperand(2))->get(); + // NOTE: Mask is dropped. 
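[Editor's aside: a scalar model, not part of the imported patch, of the lowerABS sequence earlier in this hunk, which computes max(X, 0 - X); absViaMax is a hypothetical helper.]

#include <algorithm>
#include <cstdint>

static int32_t absViaMax(int32_t X) {
  // 0 - X is done in unsigned arithmetic to sidestep signed-overflow UB.
  int32_t NegX = static_cast<int32_t>(0u - static_cast<uint32_t>(X));
  return std::max(X, NegX); // INT32_MIN maps to itself, as two's complement abs does
}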
+  SDValue VL = Op.getOperand(4);
+
+  MVT ContainerVT = VT;
+  if (VT.isFixedLengthVector()) {
+    ContainerVT = getContainerForFixedLengthVector(VT);
+    Op1 = convertToScalableVector(ContainerVT, Op1, DAG, Subtarget);
+    Op2 = convertToScalableVector(ContainerVT, Op2, DAG, Subtarget);
+  }
+
+  SDValue Result;
+  SDValue AllOneMask = DAG.getNode(RISCVISD::VMSET_VL, DL, ContainerVT, VL);
+
+  switch (Condition) {
+  default:
+    break;
+  // X != Y --> (X^Y)
+  case ISD::SETNE:
+    Result = DAG.getNode(RISCVISD::VMXOR_VL, DL, ContainerVT, Op1, Op2, VL);
+    break;
+  // X == Y --> ~(X^Y)
+  case ISD::SETEQ: {
+    SDValue Temp =
+        DAG.getNode(RISCVISD::VMXOR_VL, DL, ContainerVT, Op1, Op2, VL);
+    Result =
+        DAG.getNode(RISCVISD::VMXOR_VL, DL, ContainerVT, Temp, AllOneMask, VL);
+    break;
+  }
+  // X >s Y --> X == 0 & Y == 1 --> ~X & Y
+  // X <u Y --> X == 0 & Y == 1 --> ~X & Y
+  case ISD::SETGT:
+  case ISD::SETULT: {
+    SDValue Temp =
+        DAG.getNode(RISCVISD::VMXOR_VL, DL, ContainerVT, Op1, AllOneMask, VL);
+    Result = DAG.getNode(RISCVISD::VMAND_VL, DL, ContainerVT, Temp, Op2, VL);
+    break;
+  }
+  // X <s Y --> X == 1 & Y == 0 --> ~Y & X
+  // X >u Y --> X == 1 & Y == 0 --> ~Y & X
+  case ISD::SETLT:
+  case ISD::SETUGT: {
+    SDValue Temp =
+        DAG.getNode(RISCVISD::VMXOR_VL, DL, ContainerVT, Op2, AllOneMask, VL);
+    Result = DAG.getNode(RISCVISD::VMAND_VL, DL, ContainerVT, Op1, Temp, VL);
+    break;
+  }
+  // X >=s Y --> X == 0 | Y == 1 --> ~X | Y
+  // X <=u Y --> X == 0 | Y == 1 --> ~X | Y
+  case ISD::SETGE:
+  case ISD::SETULE: {
+    SDValue Temp =
+        DAG.getNode(RISCVISD::VMXOR_VL, DL, ContainerVT, Op1, AllOneMask, VL);
+    Result = DAG.getNode(RISCVISD::VMOR_VL, DL, ContainerVT, Temp, Op2, VL);
+    break;
+  }
+  // X <=s Y --> X == 1 | Y == 0 --> ~Y | X
+  // X >=u Y --> X == 1 | Y == 0 --> ~Y | X
+  case ISD::SETLE:
+  case ISD::SETUGE: {
+    SDValue Temp =
+        DAG.getNode(RISCVISD::VMXOR_VL, DL, ContainerVT, Op2, AllOneMask, VL);
+    Result = DAG.getNode(RISCVISD::VMOR_VL, DL, ContainerVT, Temp, Op1, VL);
+    break;
+  }
+  }
+
+  if (!VT.isFixedLengthVector())
+    return Result;
+  return convertFromScalableVector(VT, Result, DAG, Subtarget);
+}
+
+// Lower Floating-Point/Integer Type-Convert VP SDNodes
+SDValue RISCVTargetLowering::lowerVPFPIntConvOp(SDValue Op, SelectionDAG &DAG,
+                                                unsigned RISCVISDOpc) const {
+  SDLoc DL(Op);
+
+  SDValue Src = Op.getOperand(0);
+  SDValue Mask = Op.getOperand(1);
+  SDValue VL = Op.getOperand(2);
+
+  MVT DstVT = Op.getSimpleValueType();
+  MVT SrcVT = Src.getSimpleValueType();
+  if (DstVT.isFixedLengthVector()) {
+    DstVT = getContainerForFixedLengthVector(DstVT);
+    SrcVT = getContainerForFixedLengthVector(SrcVT);
+    Src = convertToScalableVector(SrcVT, Src, DAG, Subtarget);
+    MVT MaskVT = getMaskTypeFor(DstVT);
+    Mask = convertToScalableVector(MaskVT, Mask, DAG, Subtarget);
+  }
+
+  unsigned RISCVISDExtOpc = (RISCVISDOpc == RISCVISD::SINT_TO_FP_VL ||
+                             RISCVISDOpc == RISCVISD::FP_TO_SINT_VL)
+                                ? RISCVISD::VSEXT_VL
+                                : RISCVISD::VZEXT_VL;
+
+  unsigned DstEltSize = DstVT.getScalarSizeInBits();
+  unsigned SrcEltSize = SrcVT.getScalarSizeInBits();
+
+  SDValue Result;
+  if (DstEltSize >= SrcEltSize) { // Single-width and widening conversion.
+    if (SrcVT.isInteger()) {
+      assert(DstVT.isFloatingPoint() && "Wrong input/output vector types");
+
+      // Do we need to do any pre-widening before converting?
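[Editor's aside: the i1 setcc identities used earlier in this hunk, checked on plain bools; not part of the imported patch. For i1, true reads as -1 when signed, so it is simultaneously the largest unsigned and the smallest signed value.]

static bool ne(bool X, bool Y) { return X ^ Y; }    // vmxor
static bool eq(bool X, bool Y) { return !(X ^ Y); } // vmxor with all-ones
static bool sgt(bool X, bool Y) { return !X && Y; } // X >s Y and X <u Y: ~X & Y
static bool sge(bool X, bool Y) { return !X || Y; } // X >=s Y and X <=u Y: ~X | Y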
+ if (SrcEltSize == 1) { + MVT IntVT = DstVT.changeVectorElementTypeToInteger(); + MVT XLenVT = Subtarget.getXLenVT(); + SDValue Zero = DAG.getConstant(0, DL, XLenVT); + SDValue ZeroSplat = DAG.getNode(RISCVISD::VMV_V_X_VL, DL, IntVT, + DAG.getUNDEF(IntVT), Zero, VL); + SDValue One = DAG.getConstant( + RISCVISDExtOpc == RISCVISD::VZEXT_VL ? 1 : -1, DL, XLenVT); + SDValue OneSplat = DAG.getNode(RISCVISD::VMV_V_X_VL, DL, IntVT, + DAG.getUNDEF(IntVT), One, VL); + Src = DAG.getNode(RISCVISD::VSELECT_VL, DL, IntVT, Src, OneSplat, + ZeroSplat, VL); + } else if (DstEltSize > (2 * SrcEltSize)) { + // Widen before converting. + MVT IntVT = MVT::getVectorVT(MVT::getIntegerVT(DstEltSize / 2), + DstVT.getVectorElementCount()); + Src = DAG.getNode(RISCVISDExtOpc, DL, IntVT, Src, Mask, VL); + } + + Result = DAG.getNode(RISCVISDOpc, DL, DstVT, Src, Mask, VL); + } else { + assert(SrcVT.isFloatingPoint() && DstVT.isInteger() && + "Wrong input/output vector types"); + + // Convert f16 to f32 then convert f32 to i64. + if (DstEltSize > (2 * SrcEltSize)) { + assert(SrcVT.getVectorElementType() == MVT::f16 && "Unexpected type!"); + MVT InterimFVT = + MVT::getVectorVT(MVT::f32, DstVT.getVectorElementCount()); + Src = + DAG.getNode(RISCVISD::FP_EXTEND_VL, DL, InterimFVT, Src, Mask, VL); + } + + Result = DAG.getNode(RISCVISDOpc, DL, DstVT, Src, Mask, VL); + } + } else { // Narrowing + Conversion + if (SrcVT.isInteger()) { + assert(DstVT.isFloatingPoint() && "Wrong input/output vector types"); + // First do a narrowing convert to an FP type half the size, then round + // the FP type to a small FP type if needed. + + MVT InterimFVT = DstVT; + if (SrcEltSize > (2 * DstEltSize)) { + assert(SrcEltSize == (4 * DstEltSize) && "Unexpected types!"); + assert(DstVT.getVectorElementType() == MVT::f16 && "Unexpected type!"); + InterimFVT = MVT::getVectorVT(MVT::f32, DstVT.getVectorElementCount()); + } + + Result = DAG.getNode(RISCVISDOpc, DL, InterimFVT, Src, Mask, VL); + + if (InterimFVT != DstVT) { + Src = Result; + Result = DAG.getNode(RISCVISD::FP_ROUND_VL, DL, DstVT, Src, Mask, VL); + } + } else { + assert(SrcVT.isFloatingPoint() && DstVT.isInteger() && + "Wrong input/output vector types"); + // First do a narrowing conversion to an integer half the size, then + // truncate if needed. + + if (DstEltSize == 1) { + // First convert to the same size integer, then convert to mask using + // setcc. + assert(SrcEltSize >= 16 && "Unexpected FP type!"); + MVT InterimIVT = MVT::getVectorVT(MVT::getIntegerVT(SrcEltSize), + DstVT.getVectorElementCount()); + Result = DAG.getNode(RISCVISDOpc, DL, InterimIVT, Src, Mask, VL); + + // Compare the integer result to 0. The integer should be 0 or 1/-1, + // otherwise the conversion was undefined. 
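[Editor's aside: a scalar model, not part of the imported patch, of the FP-to-mask path the comment above describes; fpToMask is a hypothetical helper.]

static bool fpToMask(float F) {
  int I = static_cast<int>(F); // FP -> same-width integer convert first
  return I != 0;               // then setcc against zero yields the i1 mask
}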
+ MVT XLenVT = Subtarget.getXLenVT(); + SDValue SplatZero = DAG.getConstant(0, DL, XLenVT); + SplatZero = DAG.getNode(RISCVISD::VMV_V_X_VL, DL, InterimIVT, + DAG.getUNDEF(InterimIVT), SplatZero); + Result = DAG.getNode(RISCVISD::SETCC_VL, DL, DstVT, Result, SplatZero, + DAG.getCondCode(ISD::SETNE), Mask, VL); + } else { + MVT InterimIVT = MVT::getVectorVT(MVT::getIntegerVT(SrcEltSize / 2), + DstVT.getVectorElementCount()); + + Result = DAG.getNode(RISCVISDOpc, DL, InterimIVT, Src, Mask, VL); + + while (InterimIVT != DstVT) { + SrcEltSize /= 2; + Src = Result; + InterimIVT = MVT::getVectorVT(MVT::getIntegerVT(SrcEltSize / 2), + DstVT.getVectorElementCount()); + Result = DAG.getNode(RISCVISD::TRUNCATE_VECTOR_VL, DL, InterimIVT, + Src, Mask, VL); + } + } + } + } + + MVT VT = Op.getSimpleValueType(); + if (!VT.isFixedLengthVector()) + return Result; + return convertFromScalableVector(VT, Result, DAG, Subtarget); +} + SDValue RISCVTargetLowering::lowerLogicVPOp(SDValue Op, SelectionDAG &DAG, unsigned MaskOpc, unsigned VecOpc) const { @@ -5876,23 +6461,14 @@ SDValue RISCVTargetLowering::lowerMaskedGather(SDValue Op, MVT ContainerVT = VT; if (VT.isFixedLengthVector()) { - // We need to use the larger of the result and index type to determine the - // scalable type to use so we don't increase LMUL for any operand/result. - if (VT.bitsGE(IndexVT)) { - ContainerVT = getContainerForFixedLengthVector(VT); - IndexVT = MVT::getVectorVT(IndexVT.getVectorElementType(), - ContainerVT.getVectorElementCount()); - } else { - IndexVT = getContainerForFixedLengthVector(IndexVT); - ContainerVT = MVT::getVectorVT(ContainerVT.getVectorElementType(), - IndexVT.getVectorElementCount()); - } + ContainerVT = getContainerForFixedLengthVector(VT); + IndexVT = MVT::getVectorVT(IndexVT.getVectorElementType(), + ContainerVT.getVectorElementCount()); Index = convertToScalableVector(IndexVT, Index, DAG, Subtarget); if (!IsUnmasked) { - MVT MaskVT = - MVT::getVectorVT(MVT::i1, ContainerVT.getVectorElementCount()); + MVT MaskVT = getMaskTypeFor(ContainerVT); Mask = convertToScalableVector(MaskVT, Mask, DAG, Subtarget); PassThru = convertToScalableVector(ContainerVT, PassThru, DAG, Subtarget); } @@ -5987,24 +6563,15 @@ SDValue RISCVTargetLowering::lowerMaskedScatter(SDValue Op, MVT ContainerVT = VT; if (VT.isFixedLengthVector()) { - // We need to use the larger of the value and index type to determine the - // scalable type to use so we don't increase LMUL for any operand/result. 
- if (VT.bitsGE(IndexVT)) { - ContainerVT = getContainerForFixedLengthVector(VT); - IndexVT = MVT::getVectorVT(IndexVT.getVectorElementType(), - ContainerVT.getVectorElementCount()); - } else { - IndexVT = getContainerForFixedLengthVector(IndexVT); - ContainerVT = MVT::getVectorVT(VT.getVectorElementType(), - IndexVT.getVectorElementCount()); - } + ContainerVT = getContainerForFixedLengthVector(VT); + IndexVT = MVT::getVectorVT(IndexVT.getVectorElementType(), + ContainerVT.getVectorElementCount()); Index = convertToScalableVector(IndexVT, Index, DAG, Subtarget); Val = convertToScalableVector(ContainerVT, Val, DAG, Subtarget); if (!IsUnmasked) { - MVT MaskVT = - MVT::getVectorVT(MVT::i1, ContainerVT.getVectorElementCount()); + MVT MaskVT = getMaskTypeFor(ContainerVT); Mask = convertToScalableVector(MaskVT, Mask, DAG, Subtarget); } } @@ -6095,14 +6662,21 @@ SDValue RISCVTargetLowering::lowerSET_ROUNDING(SDValue Op, RMValue); } +SDValue RISCVTargetLowering::lowerEH_DWARF_CFA(SDValue Op, + SelectionDAG &DAG) const { + MachineFunction &MF = DAG.getMachineFunction(); + + bool isRISCV64 = Subtarget.is64Bit(); + EVT PtrVT = getPointerTy(DAG.getDataLayout()); + + int FI = MF.getFrameInfo().CreateFixedObject(isRISCV64 ? 8 : 4, 0, false); + return DAG.getFrameIndex(FI, PtrVT); +} + static RISCVISD::NodeType getRISCVWOpcodeByIntr(unsigned IntNo) { switch (IntNo) { default: llvm_unreachable("Unexpected Intrinsic"); - case Intrinsic::riscv_grev: - return RISCVISD::GREVW; - case Intrinsic::riscv_gorc: - return RISCVISD::GORCW; case Intrinsic::riscv_bcompress: return RISCVISD::BCOMPRESSW; case Intrinsic::riscv_bdecompress: @@ -6121,9 +6695,12 @@ static SDValue customLegalizeToWOpByIntr(SDNode *N, SelectionDAG &DAG, unsigned IntNo) { SDLoc DL(N); RISCVISD::NodeType WOpcode = getRISCVWOpcodeByIntr(IntNo); - SDValue NewOp1 = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i64, N->getOperand(1)); - SDValue NewOp2 = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i64, N->getOperand(2)); - SDValue NewRes = DAG.getNode(WOpcode, DL, MVT::i64, NewOp1, NewOp2); + // Deal with the Instruction Operands + SmallVector<SDValue, 3> NewOps; + for (SDValue Op : drop_begin(N->ops())) + // Promote the operand to i64 type + NewOps.push_back(DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i64, Op)); + SDValue NewRes = DAG.getNode(WOpcode, DL, MVT::i64, NewOps); // ReplaceNodeResults requires we maintain the same type for the return value. return DAG.getNode(ISD::TRUNCATE, DL, N->getValueType(0), NewRes); } @@ -6150,10 +6727,6 @@ static RISCVISD::NodeType getRISCVWOpcode(unsigned Opcode) { return RISCVISD::ROLW; case ISD::ROTR: return RISCVISD::RORW; - case RISCVISD::GREV: - return RISCVISD::GREVW; - case RISCVISD::GORC: - return RISCVISD::GORCW; } } @@ -6309,6 +6882,10 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N, assert(N->getValueType(0) == MVT::i32 && Subtarget.is64Bit() && "Unexpected custom legalisation"); if (N->getOperand(1).getOpcode() != ISD::Constant) { + // If we can use a BSET instruction, allow default promotion to apply. + if (N->getOpcode() == ISD::SHL && Subtarget.hasStdExtZbs() && + isOneConstant(N->getOperand(0))) + break; Results.push_back(customLegalizeToWOp(N, DAG)); break; } @@ -6388,12 +6965,23 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N, Res = DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, MVT::i64, Res, DAG.getValueType(MVT::i32)); - // Sign extend the LHS and perform an unsigned compare with the ADDW result. 
-  // Since the inputs are sign extended from i32, this is equivalent to
-  // comparing the lower 32 bits.
-  LHS = DAG.getNode(ISD::SIGN_EXTEND, DL, MVT::i64, N->getOperand(0));
-  SDValue Overflow = DAG.getSetCC(DL, N->getValueType(1), Res, LHS,
-                                  IsAdd ? ISD::SETULT : ISD::SETUGT);
+  SDValue Overflow;
+  if (IsAdd && isOneConstant(RHS)) {
+    // Special case: uaddo X, 1 overflows iff the addition result is 0.
+    // The general case (X + C) < C is not necessarily beneficial. Although we
+    // reduce the live range of X, we may introduce the materialization of
+    // constant C, especially when the setcc result is used by a branch. We
+    // have no compare-with-constant-and-branch instructions.
+    Overflow = DAG.getSetCC(DL, N->getValueType(1), Res,
+                            DAG.getConstant(0, DL, MVT::i64), ISD::SETEQ);
+  } else {
+    // Sign extend the LHS and perform an unsigned compare with the ADDW
+    // result. Since the inputs are sign extended from i32, this is equivalent
+    // to comparing the lower 32 bits.
+    LHS = DAG.getNode(ISD::SIGN_EXTEND, DL, MVT::i64, N->getOperand(0));
+    Overflow = DAG.getSetCC(DL, N->getValueType(1), Res, LHS,
+                            IsAdd ? ISD::SETULT : ISD::SETUGT);
+  }
 
  Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, Res));
  Results.push_back(Overflow);
@@ -6421,6 +7009,33 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N,
    Results.push_back(expandAddSubSat(N, DAG));
    return;
  }
+  case ISD::ABS: {
+    assert(N->getValueType(0) == MVT::i32 && Subtarget.is64Bit() &&
+           "Unexpected custom legalisation");
+
+    // Expand abs to Y = (sraiw X, 31); subw(xor(X, Y), Y)
+
+    SDValue Src = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i64, N->getOperand(0));
+
+    // Freeze the source so we can increase its use count.
+    Src = DAG.getFreeze(Src);
+
+    // Copy sign bit to all bits using the sraiw pattern.
+    SDValue SignFill = DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, MVT::i64, Src,
+                                   DAG.getValueType(MVT::i32));
+    SignFill = DAG.getNode(ISD::SRA, DL, MVT::i64, SignFill,
+                           DAG.getConstant(31, DL, MVT::i64));
+
+    SDValue NewRes = DAG.getNode(ISD::XOR, DL, MVT::i64, Src, SignFill);
+    NewRes = DAG.getNode(ISD::SUB, DL, MVT::i64, NewRes, SignFill);
+
+    // NOTE: The result is only required to be anyextended, but sext is
+    // consistent with type legalization of sub.
+    NewRes = DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, MVT::i64, NewRes,
+                         DAG.getValueType(MVT::i32));
+    Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, NewRes));
+    return;
+  }
  case ISD::BITCAST: {
    EVT VT = N->getValueType(0);
    assert(VT.isInteger() && !VT.isVector() && "Unexpected VT!");
@@ -6451,37 +7066,24 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N,
    break;
  }
  case RISCVISD::GREV:
-  case RISCVISD::GORC: {
-    assert(N->getValueType(0) == MVT::i32 && Subtarget.is64Bit() &&
-           "Unexpected custom legalisation");
-    assert(isa<ConstantSDNode>(N->getOperand(1)) && "Expected constant");
-    // This is similar to customLegalizeToWOp, except that we pass the second
-    // operand (a TargetConstant) straight through: it is already of type
-    // XLenVT.
-    RISCVISD::NodeType WOpcode = getRISCVWOpcode(N->getOpcode());
-    SDValue NewOp0 =
-        DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i64, N->getOperand(0));
-    SDValue NewOp1 =
-        DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i64, N->getOperand(1));
-    SDValue NewRes = DAG.getNode(WOpcode, DL, MVT::i64, NewOp0, NewOp1);
-    // ReplaceNodeResults requires we maintain the same type for the return
-    // value.
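[Editor's aside: not part of the imported patch; a scalar check of the uaddo special case above. Unsigned X + 1 wraps exactly when X is the all-ones value, so the sum is 0 iff the addition overflowed.]

#include <cstdint>

static bool uaddo1(uint32_t X, uint32_t &Sum) {
  Sum = X + 1;
  return Sum == 0; // overflow iff the result wrapped to zero
}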
- Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, NewRes)); - break; - } + case RISCVISD::GORC: case RISCVISD::SHFL: { - // There is no SHFLIW instruction, but we can just promote the operation. - assert(N->getValueType(0) == MVT::i32 && Subtarget.is64Bit() && + MVT VT = N->getSimpleValueType(0); + MVT XLenVT = Subtarget.getXLenVT(); + assert((VT == MVT::i16 || (VT == MVT::i32 && Subtarget.is64Bit())) && "Unexpected custom legalisation"); assert(isa<ConstantSDNode>(N->getOperand(1)) && "Expected constant"); - SDValue NewOp0 = - DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i64, N->getOperand(0)); + assert((Subtarget.hasStdExtZbp() || + (Subtarget.hasStdExtZbkb() && N->getOpcode() == RISCVISD::GREV && + N->getConstantOperandVal(1) == 7)) && + "Unexpected extension"); + SDValue NewOp0 = DAG.getNode(ISD::ANY_EXTEND, DL, XLenVT, N->getOperand(0)); SDValue NewOp1 = - DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i64, N->getOperand(1)); - SDValue NewRes = DAG.getNode(RISCVISD::SHFL, DL, MVT::i64, NewOp0, NewOp1); + DAG.getNode(ISD::ZERO_EXTEND, DL, XLenVT, N->getOperand(1)); + SDValue NewRes = DAG.getNode(N->getOpcode(), DL, XLenVT, NewOp0, NewOp1); // ReplaceNodeResults requires we maintain the same type for the return // value. - Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, NewRes)); + Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, VT, NewRes)); break; } case ISD::BSWAP: @@ -6496,9 +7098,8 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N, // If this is BSWAP rather than BITREVERSE, clear the lower 3 bits. if (N->getOpcode() == ISD::BSWAP) Imm &= ~0x7U; - unsigned Opc = Subtarget.is64Bit() ? RISCVISD::GREVW : RISCVISD::GREV; - SDValue GREVI = - DAG.getNode(Opc, DL, XLenVT, NewOp0, DAG.getConstant(Imm, DL, XLenVT)); + SDValue GREVI = DAG.getNode(RISCVISD::GREV, DL, XLenVT, NewOp0, + DAG.getConstant(Imm, DL, XLenVT)); // ReplaceNodeResults requires we maintain the same type for the return // value. Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, VT, GREVI)); @@ -6564,9 +7165,8 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N, MVT XLenVT = Subtarget.getXLenVT(); // Use a VL of 1 to avoid processing more elements than we need. - MVT MaskVT = MVT::getVectorVT(MVT::i1, ContainerVT.getVectorElementCount()); SDValue VL = DAG.getConstant(1, DL, XLenVT); - SDValue Mask = DAG.getNode(RISCVISD::VMSET_VL, DL, MaskVT, VL); + SDValue Mask = getAllOnesMask(ContainerVT, VL, DL, DAG); // Unless the index is known to be 0, we must slide the vector down to get // the desired element into index 0. @@ -6581,6 +7181,7 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N, // To extract the upper XLEN bits of the vector element, shift the first // element right by 32 bits and re-extract the lower XLEN bits. 
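[Editor's aside: a small sketch, not part of the imported patch, of the GREVI control computed in the BSWAP/BITREVERSE legalization above; greviImm is a hypothetical helper.]

static unsigned greviImm(unsigned BitWidth, bool IsBswap) {
  unsigned Imm = BitWidth - 1; // full bit reversal
  if (IsBswap)
    Imm &= ~0x7U; // drop the intra-byte stages: 24 for i32, 56 for i64
  return Imm;
}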
      SDValue ThirtyTwoV = DAG.getNode(RISCVISD::VMV_V_X_VL, DL, ContainerVT,
+                                       DAG.getUNDEF(ContainerVT),
                                       DAG.getConstant(32, DL, XLenVT), VL);
      SDValue LShr32 = DAG.getNode(RISCVISD::SRL_VL, DL, ContainerVT, Vec,
                                   ThirtyTwoV, Mask, VL);
@@ -6597,38 +7198,42 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N,
      llvm_unreachable(
          "Don't know how to custom type legalize this intrinsic!");
    case Intrinsic::riscv_grev:
-    case Intrinsic::riscv_gorc:
-    case Intrinsic::riscv_bcompress:
-    case Intrinsic::riscv_bdecompress:
-    case Intrinsic::riscv_bfp: {
+    case Intrinsic::riscv_gorc: {
      assert(N->getValueType(0) == MVT::i32 && Subtarget.is64Bit() &&
             "Unexpected custom legalisation");
-      Results.push_back(customLegalizeToWOpByIntr(N, DAG, IntNo));
+      SDValue NewOp1 =
+          DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i64, N->getOperand(1));
+      SDValue NewOp2 =
+          DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i64, N->getOperand(2));
+      unsigned Opc =
+          IntNo == Intrinsic::riscv_grev ? RISCVISD::GREVW : RISCVISD::GORCW;
+      // If the control is a constant, promote the node by clearing any extra
+      // bits in the control. isel will form greviw/gorciw if the result is
+      // sign extended.
+      if (isa<ConstantSDNode>(NewOp2)) {
+        NewOp2 = DAG.getNode(ISD::AND, DL, MVT::i64, NewOp2,
+                             DAG.getConstant(0x1f, DL, MVT::i64));
+        Opc = IntNo == Intrinsic::riscv_grev ? RISCVISD::GREV : RISCVISD::GORC;
+      }
+      SDValue Res = DAG.getNode(Opc, DL, MVT::i64, NewOp1, NewOp2);
+      Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, Res));
      break;
    }
+    case Intrinsic::riscv_bcompress:
+    case Intrinsic::riscv_bdecompress:
+    case Intrinsic::riscv_bfp:
    case Intrinsic::riscv_fsl:
    case Intrinsic::riscv_fsr: {
      assert(N->getValueType(0) == MVT::i32 && Subtarget.is64Bit() &&
             "Unexpected custom legalisation");
-      SDValue NewOp1 =
-          DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i64, N->getOperand(1));
-      SDValue NewOp2 =
-          DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i64, N->getOperand(2));
-      SDValue NewOp3 =
-          DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i64, N->getOperand(3));
-      unsigned Opc = getRISCVWOpcodeByIntr(IntNo);
-      SDValue Res = DAG.getNode(Opc, DL, MVT::i64, NewOp1, NewOp2, NewOp3);
-      Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, Res));
+      Results.push_back(customLegalizeToWOpByIntr(N, DAG, IntNo));
      break;
    }
    case Intrinsic::riscv_orc_b: {
      // Lower to the GORCI encoding for orc.b with the operand extended.
      SDValue NewOp =
          DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i64, N->getOperand(1));
-      // If Zbp is enabled, use GORCIW which will sign extend the result.
-      unsigned Opc =
-          Subtarget.hasStdExtZbp() ? RISCVISD::GORCW : RISCVISD::GORC;
-      SDValue Res = DAG.getNode(Opc, DL, MVT::i64, NewOp,
+      SDValue Res = DAG.getNode(RISCVISD::GORC, DL, MVT::i64, NewOp,
                                DAG.getConstant(7, DL, MVT::i64));
      Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, Res));
      return;
@@ -6681,10 +7286,11 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N,
      // To extract the upper XLEN bits of the vector element, shift the first
      // element right by 32 bits and re-extract the lower XLEN bits.
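[Editor's aside: not part of the imported patch; a scalar model of gorc (generalized OR-combine), following the Bitmanip specification's reference pseudocode. orc.b above is gorc with control 7, OR-combining each bit across its byte.]

#include <cstdint>

static uint32_t gorc32(uint32_t X, unsigned Ctrl) {
  if (Ctrl & 1)
    X |= ((X & 0x55555555u) << 1) | ((X & 0xAAAAAAAAu) >> 1);
  if (Ctrl & 2)
    X |= ((X & 0x33333333u) << 2) | ((X & 0xCCCCCCCCu) >> 2);
  if (Ctrl & 4)
    X |= ((X & 0x0F0F0F0Fu) << 4) | ((X & 0xF0F0F0F0u) >> 4);
  if (Ctrl & 8)
    X |= ((X & 0x00FF00FFu) << 8) | ((X & 0xFF00FF00u) >> 8);
  if (Ctrl & 16)
    X |= ((X & 0x0000FFFFu) << 16) | ((X & 0xFFFF0000u) >> 16);
  return X; // gorc32(X, 7): each byte becomes 0xFF if any of its bits was set
}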
      SDValue VL = DAG.getConstant(1, DL, XLenVT);
-      MVT MaskVT = MVT::getVectorVT(MVT::i1, VecVT.getVectorElementCount());
-      SDValue Mask = DAG.getNode(RISCVISD::VMSET_VL, DL, MaskVT, VL);
-      SDValue ThirtyTwoV = DAG.getNode(RISCVISD::VMV_V_X_VL, DL, VecVT,
-                                       DAG.getConstant(32, DL, XLenVT), VL);
+      SDValue Mask = getAllOnesMask(VecVT, VL, DL, DAG);
+
+      SDValue ThirtyTwoV =
+          DAG.getNode(RISCVISD::VMV_V_X_VL, DL, VecVT, DAG.getUNDEF(VecVT),
+                      DAG.getConstant(32, DL, XLenVT), VL);
      SDValue LShr32 = DAG.getNode(RISCVISD::SRL_VL, DL, VecVT, Vec, ThirtyTwoV,
                                   Mask, VL);
      SDValue EltHi = DAG.getNode(RISCVISD::VMV_X_S, DL, XLenVT, LShr32);
@@ -6840,6 +7446,110 @@ static Optional<RISCVBitmanipPat> matchGREVIPat(SDValue Op) {
  return matchRISCVBitmanipPat(Op, BitmanipMasks);
 }
 
+// Try to fold (<bop> x, (reduction.<bop> vec, start))
+static SDValue combineBinOpToReduce(SDNode *N, SelectionDAG &DAG) {
+  auto BinOpToRVVReduce = [](unsigned Opc) {
+    switch (Opc) {
+    default:
+      llvm_unreachable("Unhandled binary to transform reduction");
+    case ISD::ADD:
+      return RISCVISD::VECREDUCE_ADD_VL;
+    case ISD::UMAX:
+      return RISCVISD::VECREDUCE_UMAX_VL;
+    case ISD::SMAX:
+      return RISCVISD::VECREDUCE_SMAX_VL;
+    case ISD::UMIN:
+      return RISCVISD::VECREDUCE_UMIN_VL;
+    case ISD::SMIN:
+      return RISCVISD::VECREDUCE_SMIN_VL;
+    case ISD::AND:
+      return RISCVISD::VECREDUCE_AND_VL;
+    case ISD::OR:
+      return RISCVISD::VECREDUCE_OR_VL;
+    case ISD::XOR:
+      return RISCVISD::VECREDUCE_XOR_VL;
+    case ISD::FADD:
+      return RISCVISD::VECREDUCE_FADD_VL;
+    case ISD::FMAXNUM:
+      return RISCVISD::VECREDUCE_FMAX_VL;
+    case ISD::FMINNUM:
+      return RISCVISD::VECREDUCE_FMIN_VL;
+    }
+  };
+
+  auto IsReduction = [&BinOpToRVVReduce](SDValue V, unsigned Opc) {
+    return V.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
+           isNullConstant(V.getOperand(1)) &&
+           V.getOperand(0).getOpcode() == BinOpToRVVReduce(Opc);
+  };
+
+  unsigned Opc = N->getOpcode();
+  unsigned ReduceIdx;
+  if (IsReduction(N->getOperand(0), Opc))
+    ReduceIdx = 0;
+  else if (IsReduction(N->getOperand(1), Opc))
+    ReduceIdx = 1;
+  else
+    return SDValue();
+
+  // Skip if FADD disallows reassociation but the combiner needs it.
+  if (Opc == ISD::FADD && !N->getFlags().hasAllowReassociation())
+    return SDValue();
+
+  SDValue Extract = N->getOperand(ReduceIdx);
+  SDValue Reduce = Extract.getOperand(0);
+  if (!Reduce.hasOneUse())
+    return SDValue();
+
+  SDValue ScalarV = Reduce.getOperand(2);
+
+  // Make sure that ScalarV is a splat with VL=1.
+  if (ScalarV.getOpcode() != RISCVISD::VFMV_S_F_VL &&
+      ScalarV.getOpcode() != RISCVISD::VMV_S_X_VL &&
+      ScalarV.getOpcode() != RISCVISD::VMV_V_X_VL)
+    return SDValue();
+
+  if (!isOneConstant(ScalarV.getOperand(2)))
+    return SDValue();
+
+  // TODO: Deal with values other than the neutral element.
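[Editor's aside: not part of the imported patch; a scalar statement of the identity behind combineBinOpToReduce. With a neutral start element, the extra scalar operand can be folded in as the reduction's new start value.]

#include <numeric>
#include <vector>

// add(X, vecreduce_add(Vec, /*start=*/0)) == vecreduce_add(Vec, /*start=*/X)
static int foldedReduce(const std::vector<int> &Vec, int X) {
  return std::accumulate(Vec.begin(), Vec.end(), X);
}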
+  auto IsRVVNeutralElement = [Opc, &DAG](SDNode *N, SDValue V) {
+    if (Opc == ISD::FADD && N->getFlags().hasNoSignedZeros() &&
+        isNullFPConstant(V))
+      return true;
+    return DAG.getNeutralElement(Opc, SDLoc(V), V.getSimpleValueType(),
+                                 N->getFlags()) == V;
+  };
+
+  // Check that the scalar of ScalarV is the neutral element.
+  if (!IsRVVNeutralElement(N, ScalarV.getOperand(1)))
+    return SDValue();
+
+  if (!ScalarV.hasOneUse())
+    return SDValue();
+
+  EVT SplatVT = ScalarV.getValueType();
+  SDValue NewStart = N->getOperand(1 - ReduceIdx);
+  unsigned SplatOpc = RISCVISD::VFMV_S_F_VL;
+  if (SplatVT.isInteger()) {
+    auto *C = dyn_cast<ConstantSDNode>(NewStart.getNode());
+    if (!C || C->isZero() || !isInt<5>(C->getSExtValue()))
+      SplatOpc = RISCVISD::VMV_S_X_VL;
+    else
+      SplatOpc = RISCVISD::VMV_V_X_VL;
+  }
+
+  SDValue NewScalarV =
+      DAG.getNode(SplatOpc, SDLoc(N), SplatVT, ScalarV.getOperand(0), NewStart,
+                  ScalarV.getOperand(2));
+  SDValue NewReduce =
+      DAG.getNode(Reduce.getOpcode(), SDLoc(Reduce), Reduce.getValueType(),
+                  Reduce.getOperand(0), Reduce.getOperand(1), NewScalarV,
+                  Reduce.getOperand(3), Reduce.getOperand(4));
+  return DAG.getNode(Extract.getOpcode(), SDLoc(Extract),
+                     Extract.getValueType(), NewReduce, Extract.getOperand(1));
+}
+
 // Match the following pattern as a GREVI(W) operation
 // (or (BITMANIP_SHL x), (BITMANIP_SRL x))
 static SDValue combineORToGREV(SDValue Op, SelectionDAG &DAG,
@@ -7066,11 +7776,70 @@ static SDValue transformAddShlImm(SDNode *N, SelectionDAG &DAG,
   return DAG.getNode(ISD::SHL, DL, VT, NA1, DAG.getConstant(Bits, DL, VT));
 }
 
+// Combine
+// ROTR ((GREVI x, 24), 16) -> (GREVI x, 8) for RV32
+// ROTL ((GREVI x, 24), 16) -> (GREVI x, 8) for RV32
+// ROTR ((GREVI x, 56), 32) -> (GREVI x, 24) for RV64
+// ROTL ((GREVI x, 56), 32) -> (GREVI x, 24) for RV64
+// RORW ((GREVI x, 24), 16) -> (GREVIW x, 8) for RV64
+// ROLW ((GREVI x, 24), 16) -> (GREVIW x, 8) for RV64
+// The grev patterns represent BSWAP.
+// FIXME: This can be generalized to any GREV. We just need to toggle the MSB
+// of the grev shift amount.
+static SDValue combineROTR_ROTL_RORW_ROLW(SDNode *N, SelectionDAG &DAG,
+                                          const RISCVSubtarget &Subtarget) {
+  bool IsWInstruction =
+      N->getOpcode() == RISCVISD::RORW || N->getOpcode() == RISCVISD::ROLW;
+  assert((N->getOpcode() == ISD::ROTR || N->getOpcode() == ISD::ROTL ||
+          IsWInstruction) &&
+         "Unexpected opcode!");
+  SDValue Src = N->getOperand(0);
+  EVT VT = N->getValueType(0);
+  SDLoc DL(N);
+
+  if (!Subtarget.hasStdExtZbp() || Src.getOpcode() != RISCVISD::GREV)
+    return SDValue();
+
+  if (!isa<ConstantSDNode>(N->getOperand(1)) ||
+      !isa<ConstantSDNode>(Src.getOperand(1)))
+    return SDValue();
+
+  unsigned BitWidth = IsWInstruction ? 32 : VT.getSizeInBits();
+  assert(isPowerOf2_32(BitWidth) && "Expected a power of 2");
+
+  // Needs to be a rotate by half the bitwidth for ROTR/ROTL or by 16 for
+  // RORW/ROLW. And the grev should be the encoding for bswap for this width.
+  unsigned ShAmt1 = N->getConstantOperandVal(1);
+  unsigned ShAmt2 = Src.getConstantOperandVal(1);
+  if (BitWidth < 32 || ShAmt1 != (BitWidth / 2) || ShAmt2 != (BitWidth - 8))
+    return SDValue();
+
+  Src = Src.getOperand(0);
+
+  // Toggle the MSB of the shift amount.
+  unsigned CombinedShAmt = ShAmt1 ^ ShAmt2;
+  if (CombinedShAmt == 0)
+    return Src;
+
+  SDValue Res = DAG.getNode(
+      RISCVISD::GREV, DL, VT, Src,
+      DAG.getConstant(CombinedShAmt, DL, N->getOperand(1).getValueType()));
+  if (!IsWInstruction)
+    return Res;
+
+  // Sign extend the result to match the behavior of the rotate. This will be
+  // selected to GREVIW in isel.
+  return DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, VT, Res,
+                     DAG.getValueType(MVT::i32));
+}
+
 // Combine (GREVI (GREVI x, C2), C1) -> (GREVI x, C1^C2) when C1^C2 is
 // non-zero, and to x when it is zero. Any repeated GREVI stage undoes itself.
 // Combine (GORCI (GORCI x, C2), C1) -> (GORCI x, C1|C2). Repeated stages do
 // not undo themselves, but they are redundant.
 static SDValue combineGREVI_GORCI(SDNode *N, SelectionDAG &DAG) {
+  bool IsGORC = N->getOpcode() == RISCVISD::GORC;
+  assert((IsGORC || N->getOpcode() == RISCVISD::GREV) && "Unexpected opcode");
   SDValue Src = N->getOperand(0);
 
   if (Src.getOpcode() != N->getOpcode())
@@ -7085,7 +7854,7 @@ static SDValue combineGREVI_GORCI(SDNode *N, SelectionDAG &DAG) {
   Src = Src.getOperand(0);
 
   unsigned CombinedShAmt;
-  if (N->getOpcode() == RISCVISD::GORC || N->getOpcode() == RISCVISD::GORCW)
+  if (IsGORC)
     CombinedShAmt = ShAmt1 | ShAmt2;
   else
     CombinedShAmt = ShAmt1 ^ ShAmt2;
@@ -7203,6 +7972,11 @@ static SDValue transformAddImmMulImm(SDNode *N, SelectionDAG &DAG,
   auto *N1C = dyn_cast<ConstantSDNode>(N->getOperand(1));
   if (!N0C || !N1C)
     return SDValue();
+  // If N0C has multiple uses it's possible one of the cases in
+  // DAGCombiner::isMulAddWithConstProfitable will be true, which would result
+  // in an infinite loop.
+  if (!N0C->hasOneUse())
+    return SDValue();
   int64_t C0 = N0C->getSExtValue();
   int64_t C1 = N1C->getSExtValue();
   int64_t CA, CB;
@@ -7238,6 +8012,8 @@ static SDValue performADDCombine(SDNode *N, SelectionDAG &DAG,
     return V;
   if (SDValue V = transformAddShlImm(N, DAG, Subtarget))
     return V;
+  if (SDValue V = combineBinOpToReduce(N, DAG))
+    return V;
   // fold (add (select lhs, rhs, cc, 0, y), x) ->
   //      (select lhs, rhs, cc, x, (add x, y))
   return combineSelectAndUseCommutative(N, DAG, /*AllOnes*/ false);
@@ -7251,7 +8027,30 @@ static SDValue performSUBCombine(SDNode *N, SelectionDAG &DAG) {
   return combineSelectAndUse(N, N1, N0, DAG, /*AllOnes*/ false);
 }
 
-static SDValue performANDCombine(SDNode *N, SelectionDAG &DAG) {
+static SDValue performANDCombine(SDNode *N, SelectionDAG &DAG,
+                                 const RISCVSubtarget &Subtarget) {
+  SDValue N0 = N->getOperand(0);
+  // Pre-promote (i32 (and (srl X, Y), 1)) on RV64 with Zbs without zero
+  // extending X. This is safe since we only need the LSB after the shift and
+  // shift amounts larger than 31 would produce poison. If we wait until
+  // type legalization, we'll create RISCVISD::SRLW and we can't recover it
+  // to use a BEXT instruction.
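Both GREV facts used above, that a rotate of the bswap pattern is another GREV and that GREV amounts compose by XOR, can be sanity-checked in isolation (a sketch; grev32 mirrors the stage network that computeGREVOrGORC implements further down in this patch):

#include <cassert>
#include <cstdint>

// 32-bit generalized bit reverse: one butterfly stage per set control bit.
static uint32_t grev32(uint32_t X, unsigned ShAmt) {
  static const uint32_t Masks[] = {0x55555555, 0x33333333, 0x0F0F0F0F,
                                   0x00FF00FF, 0x0000FFFF};
  for (unsigned Stage = 0; Stage != 5; ++Stage) {
    unsigned Shift = 1u << Stage;
    if (ShAmt & Shift)
      X = ((X & Masks[Stage]) << Shift) | ((X >> Shift) & Masks[Stage]);
  }
  return X;
}

static uint32_t rotr32(uint32_t X, unsigned N) {
  return (X >> N) | (X << (32 - N));
}

int main() {
  uint32_t X = 0x12345678;
  // ROTR((GREVI x, 24), 16) == GREVI x, 8: rotating bswap by half the width
  // toggles the MSB of the grev amount (24 ^ 16 == 8).
  assert(rotr32(grev32(X, 24), 16) == grev32(X, 8));
  // (GREVI (GREVI x, C2), C1) == (GREVI x, C1 ^ C2).
  assert(grev32(grev32(X, 9), 3) == grev32(X, 9 ^ 3));
}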
+ if (Subtarget.is64Bit() && Subtarget.hasStdExtZbs() && + N->getValueType(0) == MVT::i32 && isOneConstant(N->getOperand(1)) && + N0.getOpcode() == ISD::SRL && !isa<ConstantSDNode>(N0.getOperand(1)) && + N0.hasOneUse()) { + SDLoc DL(N); + SDValue Op0 = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i64, N0.getOperand(0)); + SDValue Op1 = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i64, N0.getOperand(1)); + SDValue Srl = DAG.getNode(ISD::SRL, DL, MVT::i64, Op0, Op1); + SDValue And = DAG.getNode(ISD::AND, DL, MVT::i64, Srl, + DAG.getConstant(1, DL, MVT::i64)); + return DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, And); + } + + if (SDValue V = combineBinOpToReduce(N, DAG)) + return V; + // fold (and (select lhs, rhs, cc, -1, y), x) -> // (select lhs, rhs, cc, x, (and x, y)) return combineSelectAndUseCommutative(N, DAG, /*AllOnes*/ true); @@ -7268,99 +8067,197 @@ static SDValue performORCombine(SDNode *N, SelectionDAG &DAG, return SHFL; } + if (SDValue V = combineBinOpToReduce(N, DAG)) + return V; // fold (or (select cond, 0, y), x) -> // (select cond, x, (or x, y)) return combineSelectAndUseCommutative(N, DAG, /*AllOnes*/ false); } static SDValue performXORCombine(SDNode *N, SelectionDAG &DAG) { + SDValue N0 = N->getOperand(0); + SDValue N1 = N->getOperand(1); + + // fold (xor (sllw 1, x), -1) -> (rolw ~1, x) + // NOTE: Assumes ROL being legal means ROLW is legal. + const TargetLowering &TLI = DAG.getTargetLoweringInfo(); + if (N0.getOpcode() == RISCVISD::SLLW && + isAllOnesConstant(N1) && isOneConstant(N0.getOperand(0)) && + TLI.isOperationLegal(ISD::ROTL, MVT::i64)) { + SDLoc DL(N); + return DAG.getNode(RISCVISD::ROLW, DL, MVT::i64, + DAG.getConstant(~1, DL, MVT::i64), N0.getOperand(1)); + } + + if (SDValue V = combineBinOpToReduce(N, DAG)) + return V; // fold (xor (select cond, 0, y), x) -> // (select cond, x, (xor x, y)) return combineSelectAndUseCommutative(N, DAG, /*AllOnes*/ false); } -// Attempt to turn ANY_EXTEND into SIGN_EXTEND if the input to the ANY_EXTEND -// has users that require SIGN_EXTEND and the SIGN_EXTEND can be done for free -// by an instruction like ADDW/SUBW/MULW. Without this the ANY_EXTEND would be -// removed during type legalization leaving an ADD/SUB/MUL use that won't use -// ADDW/SUBW/MULW. -static SDValue performANY_EXTENDCombine(SDNode *N, - TargetLowering::DAGCombinerInfo &DCI, - const RISCVSubtarget &Subtarget) { - if (!Subtarget.is64Bit()) - return SDValue(); - - SelectionDAG &DAG = DCI.DAG; - +static SDValue +performSIGN_EXTEND_INREGCombine(SDNode *N, SelectionDAG &DAG, + const RISCVSubtarget &Subtarget) { SDValue Src = N->getOperand(0); EVT VT = N->getValueType(0); - if (VT != MVT::i64 || Src.getValueType() != MVT::i32) - return SDValue(); - // The opcode must be one that can implicitly sign_extend. - // FIXME: Additional opcodes. - switch (Src.getOpcode()) { - default: - return SDValue(); - case ISD::MUL: - if (!Subtarget.hasStdExtM()) - return SDValue(); - LLVM_FALLTHROUGH; - case ISD::ADD: - case ISD::SUB: - break; + // Fold (sext_inreg (fmv_x_anyexth X), i16) -> (fmv_x_signexth X) + if (Src.getOpcode() == RISCVISD::FMV_X_ANYEXTH && + cast<VTSDNode>(N->getOperand(1))->getVT().bitsGE(MVT::i16)) + return DAG.getNode(RISCVISD::FMV_X_SIGNEXTH, SDLoc(N), VT, + Src.getOperand(0)); + + // Fold (i64 (sext_inreg (abs X), i32)) -> + // (i64 (smax (sext_inreg (neg X), i32), X)) if X has more than 32 sign bits. + // The (sext_inreg (neg X), i32) will be selected to negw by isel. 
This + // pattern occurs after type legalization of (i32 (abs X)) on RV64 if the user + // of the (i32 (abs X)) is a sext or setcc or something else that causes type + // legalization to add a sext_inreg after the abs. The (i32 (abs X)) will have + // been type legalized to (i64 (abs (sext_inreg X, i32))), but the sext_inreg + // may get combined into an earlier operation so we need to use + // ComputeNumSignBits. + // NOTE: (i64 (sext_inreg (abs X), i32)) can also be created for + // (i64 (ashr (shl (abs X), 32), 32)) without any type legalization so + // we can't assume that X has 33 sign bits. We must check. + if (Subtarget.hasStdExtZbb() && Subtarget.is64Bit() && + Src.getOpcode() == ISD::ABS && Src.hasOneUse() && VT == MVT::i64 && + cast<VTSDNode>(N->getOperand(1))->getVT() == MVT::i32 && + DAG.ComputeNumSignBits(Src.getOperand(0)) > 32) { + SDLoc DL(N); + SDValue Freeze = DAG.getFreeze(Src.getOperand(0)); + SDValue Neg = + DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, MVT::i64), Freeze); + Neg = DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, MVT::i64, Neg, + DAG.getValueType(MVT::i32)); + return DAG.getNode(ISD::SMAX, DL, MVT::i64, Freeze, Neg); } - // Only handle cases where the result is used by a CopyToReg. That likely - // means the value is a liveout of the basic block. This helps prevent - // infinite combine loops like PR51206. - if (none_of(N->uses(), - [](SDNode *User) { return User->getOpcode() == ISD::CopyToReg; })) - return SDValue(); + return SDValue(); +} - SmallVector<SDNode *, 4> SetCCs; - for (SDNode::use_iterator UI = Src.getNode()->use_begin(), - UE = Src.getNode()->use_end(); - UI != UE; ++UI) { - SDNode *User = *UI; - if (User == N) - continue; - if (UI.getUse().getResNo() != Src.getResNo()) - continue; - // All i32 setccs are legalized by sign extending operands. - if (User->getOpcode() == ISD::SETCC) { - SetCCs.push_back(User); - continue; - } - // We don't know if we can extend this user. - break; +// Try to form vwadd(u).wv/wx or vwsub(u).wv/wx. It might later be optimized to +// vwadd(u).vv/vx or vwsub(u).vv/vx. +static SDValue combineADDSUB_VLToVWADDSUB_VL(SDNode *N, SelectionDAG &DAG, + bool Commute = false) { + assert((N->getOpcode() == RISCVISD::ADD_VL || + N->getOpcode() == RISCVISD::SUB_VL) && + "Unexpected opcode"); + bool IsAdd = N->getOpcode() == RISCVISD::ADD_VL; + SDValue Op0 = N->getOperand(0); + SDValue Op1 = N->getOperand(1); + if (Commute) + std::swap(Op0, Op1); + + MVT VT = N->getSimpleValueType(0); + + // Determine the narrow size for a widening add/sub. + unsigned NarrowSize = VT.getScalarSizeInBits() / 2; + MVT NarrowVT = MVT::getVectorVT(MVT::getIntegerVT(NarrowSize), + VT.getVectorElementCount()); + + SDValue Mask = N->getOperand(2); + SDValue VL = N->getOperand(3); + + SDLoc DL(N); + + // If the RHS is a sext or zext, we can form a widening op. + if ((Op1.getOpcode() == RISCVISD::VZEXT_VL || + Op1.getOpcode() == RISCVISD::VSEXT_VL) && + Op1.hasOneUse() && Op1.getOperand(1) == Mask && Op1.getOperand(2) == VL) { + unsigned ExtOpc = Op1.getOpcode(); + Op1 = Op1.getOperand(0); + // Re-introduce narrower extends if needed. + if (Op1.getValueType() != NarrowVT) + Op1 = DAG.getNode(ExtOpc, DL, NarrowVT, Op1, Mask, VL); + + unsigned WOpc; + if (ExtOpc == RISCVISD::VSEXT_VL) + WOpc = IsAdd ? RISCVISD::VWADD_W_VL : RISCVISD::VWSUB_W_VL; + else + WOpc = IsAdd ? RISCVISD::VWADDU_W_VL : RISCVISD::VWSUBU_W_VL; + + return DAG.getNode(WOpc, DL, VT, Op0, Op1, Mask, VL); } - // If we don't have any SetCCs, this isn't worthwhile. 
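Returning to the abs fold introduced above: it rests on a checkable identity (a sketch; sextInreg32 models the (sext_inreg V, i32) node, and the inputs all have more than 32 sign bits as required by the combine):

#include <algorithm>
#include <cassert>
#include <cstdint>

static int64_t sextInreg32(int64_t V) {
  return static_cast<int64_t>(static_cast<uint64_t>(V) << 32) >> 32;
}

int main() {
  // For X with more than 32 sign bits (so X == sextInreg32(X)):
  //   sext_inreg(abs X, i32) == smax(X, sext_inreg(neg X, i32))
  for (int64_t X : {INT64_C(-5), INT64_C(0), INT64_C(7),
                    INT64_C(-2147483648), INT64_C(2147483647)}) {
    int64_t LHS = sextInreg32(X < 0 ? -X : X);
    int64_t RHS = std::max(X, sextInreg32(-X));
    assert(LHS == RHS);
  }
}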
-  if (SetCCs.empty())
-    return SDValue();
+  // FIXME: Is it useful to form a vwadd.wx or vwsub.wx if it removes a scalar
+  // sext/zext?
+
+  return SDValue();
+}
+
+// Try to convert vwadd(u).wv/wx or vwsub(u).wv/wx to vwadd(u).vv/vx or
+// vwsub(u).vv/vx.
+static SDValue combineVWADD_W_VL_VWSUB_W_VL(SDNode *N, SelectionDAG &DAG) {
+  SDValue Op0 = N->getOperand(0);
+  SDValue Op1 = N->getOperand(1);
+  SDValue Mask = N->getOperand(2);
+  SDValue VL = N->getOperand(3);
+
+  MVT VT = N->getSimpleValueType(0);
+  MVT NarrowVT = Op1.getSimpleValueType();
+  unsigned NarrowSize = NarrowVT.getScalarSizeInBits();
+
+  unsigned VOpc;
+  switch (N->getOpcode()) {
+  default: llvm_unreachable("Unexpected opcode");
+  case RISCVISD::VWADD_W_VL:  VOpc = RISCVISD::VWADD_VL;  break;
+  case RISCVISD::VWSUB_W_VL:  VOpc = RISCVISD::VWSUB_VL;  break;
+  case RISCVISD::VWADDU_W_VL: VOpc = RISCVISD::VWADDU_VL; break;
+  case RISCVISD::VWSUBU_W_VL: VOpc = RISCVISD::VWSUBU_VL; break;
+  }
+
+  bool IsSigned = N->getOpcode() == RISCVISD::VWADD_W_VL ||
+                  N->getOpcode() == RISCVISD::VWSUB_W_VL;
 
   SDLoc DL(N);
-  SDValue SExt = DAG.getNode(ISD::SIGN_EXTEND, DL, MVT::i64, Src);
-  DCI.CombineTo(N, SExt);
 
-  // Promote all the setccs.
-  for (SDNode *SetCC : SetCCs) {
-    SmallVector<SDValue, 4> Ops;
+  // If the LHS is a sext or zext, we can narrow this op to the same size as
+  // the RHS.
+  if (((Op0.getOpcode() == RISCVISD::VZEXT_VL && !IsSigned) ||
+       (Op0.getOpcode() == RISCVISD::VSEXT_VL && IsSigned)) &&
+      Op0.hasOneUse() && Op0.getOperand(1) == Mask && Op0.getOperand(2) == VL) {
+    unsigned ExtOpc = Op0.getOpcode();
+    Op0 = Op0.getOperand(0);
+    // Re-introduce narrower extends if needed.
+    if (Op0.getValueType() != NarrowVT)
+      Op0 = DAG.getNode(ExtOpc, DL, NarrowVT, Op0, Mask, VL);
+    return DAG.getNode(VOpc, DL, VT, Op0, Op1, Mask, VL);
+  }
 
-    for (unsigned j = 0; j != 2; ++j) {
-      SDValue SOp = SetCC->getOperand(j);
-      if (SOp == Src)
-        Ops.push_back(SExt);
-      else
-        Ops.push_back(DAG.getNode(ISD::SIGN_EXTEND, DL, MVT::i64, SOp));
+  bool IsAdd = N->getOpcode() == RISCVISD::VWADD_W_VL ||
+               N->getOpcode() == RISCVISD::VWADDU_W_VL;
+
+  // Look for splats on the left hand side of a vwadd(u).wv. We might be able
+  // to commute and use a vwadd(u).vx instead.
+  if (IsAdd && Op0.getOpcode() == RISCVISD::VMV_V_X_VL &&
+      Op0.getOperand(0).isUndef() && Op0.getOperand(2) == VL) {
+    Op0 = Op0.getOperand(1);
+
+    // See if we have enough sign bits or zero bits in the scalar to use a
+    // widening add/sub by splatting to smaller element size.
+    unsigned EltBits = VT.getScalarSizeInBits();
+    unsigned ScalarBits = Op0.getValueSizeInBits();
+    // Make sure we're getting all element bits from the scalar register.
+    // FIXME: Support implicit sign extension of vmv.v.x?
+    if (ScalarBits < EltBits)
+      return SDValue();
+
+    if (IsSigned) {
+      if (DAG.ComputeNumSignBits(Op0) <= (ScalarBits - NarrowSize))
+        return SDValue();
+    } else {
+      APInt Mask = APInt::getBitsSetFrom(ScalarBits, NarrowSize);
+      if (!DAG.MaskedValueIsZero(Op0, Mask))
+        return SDValue();
     }
-    Ops.push_back(SetCC->getOperand(2));
-    DCI.CombineTo(SetCC,
-                  DAG.getNode(ISD::SETCC, DL, SetCC->getValueType(0), Ops));
+    Op0 = DAG.getNode(RISCVISD::VMV_V_X_VL, DL, NarrowVT,
+                      DAG.getUNDEF(NarrowVT), Op0, VL);
+    return DAG.getNode(VOpc, DL, VT, Op1, Op0, Mask, VL);
   }
-  return SDValue(N, 0);
+
+  return SDValue();
 }
 
 // Try to form VWMUL, VWMULU or VWMULSU.
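Element-wise, the .wv -> .vv narrowing performed above amounts to this (a minimal sketch with scalar stand-ins for the narrow and wide element types):

#include <cassert>
#include <cstdint>

int main() {
  int16_t A = -1234, B = 567;
  // vwadd.wv form: the wide operand is itself a sign extension of A.
  int32_t WideOp = static_cast<int32_t>(A);
  int32_t WV = WideOp + static_cast<int32_t>(B);
  // vwadd.vv form: one widening add of the two narrow values.
  int32_t VV = static_cast<int32_t>(A) + static_cast<int32_t>(B);
  assert(WV == VV);
}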
@@ -7408,12 +8305,15 @@ static SDValue combineMUL_VLToVWMUL_VL(SDNode *N, SelectionDAG &DAG,
   } else if (Op1.getOpcode() == RISCVISD::VMV_V_X_VL) {
     // The operand is a splat of a scalar.
 
+    // The passthru must be undef for tail agnostic.
+    if (!Op1.getOperand(0).isUndef())
+      return SDValue();
     // The VL must be the same.
-    if (Op1.getOperand(1) != VL)
+    if (Op1.getOperand(2) != VL)
       return SDValue();
 
     // Get the scalar value.
-    Op1 = Op1.getOperand(0);
+    Op1 = Op1.getOperand(1);
 
     // See if we have enough sign bits or zero bits in the scalar to use a
     // widening multiply by splatting to smaller element size.
@@ -7424,16 +8324,20 @@ static SDValue combineMUL_VLToVWMUL_VL(SDNode *N, SelectionDAG &DAG,
     if (ScalarBits < EltBits)
       return SDValue();
 
-    if (IsSignExt) {
-      if (DAG.ComputeNumSignBits(Op1) <= (ScalarBits - NarrowSize))
-        return SDValue();
+    // If the LHS is a sign extend, try to use vwmul.
+    if (IsSignExt && DAG.ComputeNumSignBits(Op1) > (ScalarBits - NarrowSize)) {
+      // Can use vwmul.
     } else {
+      // Otherwise try to use vwmulu or vwmulsu.
       APInt Mask = APInt::getBitsSetFrom(ScalarBits, NarrowSize);
-      if (!DAG.MaskedValueIsZero(Op1, Mask))
+      if (DAG.MaskedValueIsZero(Op1, Mask))
+        IsVWMULSU = IsSignExt;
+      else
         return SDValue();
     }
 
-    Op1 = DAG.getNode(RISCVISD::VMV_V_X_VL, DL, NarrowVT, Op1, VL);
+    Op1 = DAG.getNode(RISCVISD::VMV_V_X_VL, DL, NarrowVT,
+                      DAG.getUNDEF(NarrowVT), Op1, VL);
   } else
     return SDValue();
 
@@ -7443,6 +8347,8 @@ static SDValue combineMUL_VLToVWMUL_VL(SDNode *N, SelectionDAG &DAG,
   unsigned ExtOpc = IsSignExt ? RISCVISD::VSEXT_VL : RISCVISD::VZEXT_VL;
   if (Op0.getValueType() != NarrowVT)
     Op0 = DAG.getNode(ExtOpc, DL, NarrowVT, Op0, Mask, VL);
+  // vwmulsu requires the second operand to be zero extended.
+  ExtOpc = IsVWMULSU ? RISCVISD::VZEXT_VL : ExtOpc;
   if (Op1.getValueType() != NarrowVT)
     Op1 = DAG.getNode(ExtOpc, DL, NarrowVT, Op1, Mask, VL);
 
@@ -7569,6 +8475,133 @@ static SDValue performFP_TO_INT_SATCombine(SDNode *N,
   return DAG.getSelectCC(DL, Src, Src, ZeroInt, FpToInt, ISD::CondCode::SETUO);
 }
 
+// Combine (bitreverse (bswap X)) to the BREV8 GREVI encoding if the type is
+// smaller than XLenVT.
+static SDValue performBITREVERSECombine(SDNode *N, SelectionDAG &DAG,
+                                        const RISCVSubtarget &Subtarget) {
+  assert(Subtarget.hasStdExtZbkb() && "Unexpected extension");
+
+  SDValue Src = N->getOperand(0);
+  if (Src.getOpcode() != ISD::BSWAP)
+    return SDValue();
+
+  EVT VT = N->getValueType(0);
+  if (!VT.isScalarInteger() || VT.getSizeInBits() >= Subtarget.getXLen() ||
+      !isPowerOf2_32(VT.getSizeInBits()))
+    return SDValue();
+
+  SDLoc DL(N);
+  return DAG.getNode(RISCVISD::GREV, DL, VT, Src.getOperand(0),
+                     DAG.getConstant(7, DL, VT));
+}
+
+// Convert from one FMA opcode to another based on whether we are negating the
+// multiply result and/or the accumulator.
+// NOTE: Only supports RVV operations with VL.
+static unsigned negateFMAOpcode(unsigned Opcode, bool NegMul, bool NegAcc) {
+  assert((NegMul || NegAcc) && "Not negating anything?");
+
+  // Negating the multiply result changes ADD<->SUB and toggles 'N'.
+  if (NegMul) {
+    // clang-format off
+    switch (Opcode) {
+    default: llvm_unreachable("Unexpected opcode");
+    case RISCVISD::VFMADD_VL:  Opcode = RISCVISD::VFNMSUB_VL; break;
+    case RISCVISD::VFNMSUB_VL: Opcode = RISCVISD::VFMADD_VL;  break;
+    case RISCVISD::VFNMADD_VL: Opcode = RISCVISD::VFMSUB_VL;  break;
+    case RISCVISD::VFMSUB_VL:  Opcode = RISCVISD::VFNMADD_VL; break;
+    }
+    // clang-format on
+  }
+
+  // Negating the accumulator changes ADD<->SUB.
+  if (NegAcc) {
+    // clang-format off
+    switch (Opcode) {
+    default: llvm_unreachable("Unexpected opcode");
+    case RISCVISD::VFMADD_VL:  Opcode = RISCVISD::VFMSUB_VL;  break;
+    case RISCVISD::VFMSUB_VL:  Opcode = RISCVISD::VFMADD_VL;  break;
+    case RISCVISD::VFNMADD_VL: Opcode = RISCVISD::VFNMSUB_VL; break;
+    case RISCVISD::VFNMSUB_VL: Opcode = RISCVISD::VFNMADD_VL; break;
+    }
+    // clang-format on
+  }
+
+  return Opcode;
+}
+
+// Combine (sra (shl X, 32), 32 - C) -> (shl (sext_inreg X, i32), C)
+// FIXME: Should this be a generic combine? There's a similar combine on X86.
+//
+// Also try these folds where an add or sub is in the middle.
+// (sra (add (shl X, 32), C1), 32 - C) -> (shl (sext_inreg (add X, C1), i32), C)
+// (sra (sub C1, (shl X, 32)), 32 - C) -> (shl (sext_inreg (sub C1, X), i32), C)
+static SDValue performSRACombine(SDNode *N, SelectionDAG &DAG,
+                                 const RISCVSubtarget &Subtarget) {
+  assert(N->getOpcode() == ISD::SRA && "Unexpected opcode");
+
+  if (N->getValueType(0) != MVT::i64 || !Subtarget.is64Bit())
+    return SDValue();
+
+  auto *ShAmtC = dyn_cast<ConstantSDNode>(N->getOperand(1));
+  if (!ShAmtC || ShAmtC->getZExtValue() > 32)
+    return SDValue();
+
+  SDValue N0 = N->getOperand(0);
+
+  SDValue Shl;
+  ConstantSDNode *AddC = nullptr;
+
+  // We might have an ADD or SUB between the SRA and SHL.
+  bool IsAdd = N0.getOpcode() == ISD::ADD;
+  if ((IsAdd || N0.getOpcode() == ISD::SUB)) {
+    if (!N0.hasOneUse())
+      return SDValue();
+    // Other operand needs to be a constant we can modify.
+    AddC = dyn_cast<ConstantSDNode>(N0.getOperand(IsAdd ? 1 : 0));
+    if (!AddC)
+      return SDValue();
+
+    // AddC needs to have at least 32 trailing zeros.
+    if (AddC->getAPIntValue().countTrailingZeros() < 32)
+      return SDValue();
+
+    Shl = N0.getOperand(IsAdd ? 0 : 1);
+  } else {
+    // Not an ADD or SUB.
+    Shl = N0;
+  }
+
+  // Look for a shift left by 32.
+  if (Shl.getOpcode() != ISD::SHL || !Shl.hasOneUse() ||
+      !isa<ConstantSDNode>(Shl.getOperand(1)) ||
+      Shl.getConstantOperandVal(1) != 32)
+    return SDValue();
+
+  SDLoc DL(N);
+  SDValue In = Shl.getOperand(0);
+
+  // If we looked through an ADD or SUB, we need to rebuild it with the shifted
+  // constant.
+  if (AddC) {
+    SDValue ShiftedAddC =
+        DAG.getConstant(AddC->getAPIntValue().lshr(32), DL, MVT::i64);
+    if (IsAdd)
+      In = DAG.getNode(ISD::ADD, DL, MVT::i64, In, ShiftedAddC);
+    else
+      In = DAG.getNode(ISD::SUB, DL, MVT::i64, ShiftedAddC, In);
+  }
+
+  SDValue SExt = DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, MVT::i64, In,
+                             DAG.getValueType(MVT::i32));
+  if (ShAmtC->getZExtValue() == 32)
+    return SExt;
+
+  return DAG.getNode(
+      ISD::SHL, DL, MVT::i64, SExt,
+      DAG.getConstant(32 - ShAmtC->getZExtValue(), DL, MVT::i64));
+}
+
 SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
                                                DAGCombinerInfo &DCI) const {
   SelectionDAG &DAG = DCI.DAG;
@@ -7597,6 +8630,12 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
     if (Op0->getOpcode() == RISCVISD::BuildPairF64)
       return DCI.CombineTo(N, Op0.getOperand(0), Op0.getOperand(1));
 
+    if (Op0->isUndef()) {
+      SDValue Lo = DAG.getUNDEF(MVT::i32);
+      SDValue Hi = DAG.getUNDEF(MVT::i32);
+      return DCI.CombineTo(N, Lo, Hi);
+    }
+
     SDLoc DL(N);
 
     // It's cheaper to materialise two 32-bit integers than to load a double
@@ -7634,15 +8673,27 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
   }
   case RISCVISD::SLLW:
   case RISCVISD::SRAW:
-  case RISCVISD::SRLW:
-  case RISCVISD::ROLW:
-  case RISCVISD::RORW: {
+  case RISCVISD::SRLW: {
     // Only the lower 32 bits of LHS and lower 5 bits of RHS are read.
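Before the demanded-bits cases continue, the shl/sra fold in performSRACombine above can be validated exhaustively over the shift amounts it accepts (a sketch):

#include <cassert>
#include <cstdint>

static int64_t sextInreg32(int64_t V) {
  return static_cast<int64_t>(static_cast<uint64_t>(V) << 32) >> 32;
}

int main() {
  int64_t X = 0x123456789ABCDEF0;
  for (unsigned ShAmt = 0; ShAmt <= 32; ++ShAmt) {
    // (sra (shl X, 32), ShAmt), where ShAmt == 32 - C ...
    int64_t LHS =
        static_cast<int64_t>(static_cast<uint64_t>(X) << 32) >> ShAmt;
    // ... equals (shl (sext_inreg X, i32), C) with C == 32 - ShAmt.
    int64_t RHS = static_cast<int64_t>(
        static_cast<uint64_t>(sextInreg32(X)) << (32 - ShAmt));
    assert(LHS == RHS);
  }
}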
    if (SimplifyDemandedLowBitsHelper(0, 32) ||
        SimplifyDemandedLowBitsHelper(1, 5))
      return SDValue(N, 0);
 
+    break;
   }
+  case ISD::ROTR:
+  case ISD::ROTL:
+  case RISCVISD::RORW:
+  case RISCVISD::ROLW: {
+    if (N->getOpcode() == RISCVISD::RORW || N->getOpcode() == RISCVISD::ROLW) {
+      // Only the lower 32 bits of LHS and lower 5 bits of RHS are read.
+      if (SimplifyDemandedLowBitsHelper(0, 32) ||
+          SimplifyDemandedLowBitsHelper(1, 5))
+        return SDValue(N, 0);
+    }
+
+    return combineROTR_ROTL_RORW_ROLW(N, DAG, Subtarget);
+  }
   case RISCVISD::CLZW:
   case RISCVISD::CTZW: {
     // Only the lower 32 bits of the first operand are read
@@ -7667,7 +8718,7 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
         SimplifyDemandedLowBitsHelper(1, 5))
       return SDValue(N, 0);
 
-    return combineGREVI_GORCI(N, DAG);
+    break;
   }
   case RISCVISD::SHFL:
   case RISCVISD::UNSHFL: {
@@ -7682,10 +8733,6 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
   case RISCVISD::SHFLW:
   case RISCVISD::UNSHFLW: {
     // Only the lower 32 bits of LHS and lower 4 bits of RHS are read.
-    SDValue LHS = N->getOperand(0);
-    SDValue RHS = N->getOperand(1);
-    APInt LHSMask = APInt::getLowBitsSet(LHS.getValueSizeInBits(), 32);
-    APInt RHSMask = APInt::getLowBitsSet(RHS.getValueSizeInBits(), 4);
     if (SimplifyDemandedLowBitsHelper(0, 32) ||
         SimplifyDemandedLowBitsHelper(1, 4))
       return SDValue(N, 0);
@@ -7701,6 +8748,21 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
     break;
   }
+  case RISCVISD::FSR:
+  case RISCVISD::FSL:
+  case RISCVISD::FSRW:
+  case RISCVISD::FSLW: {
+    bool IsWInstruction =
+        N->getOpcode() == RISCVISD::FSRW || N->getOpcode() == RISCVISD::FSLW;
+    unsigned BitWidth =
+        IsWInstruction ? 32 : N->getSimpleValueType(0).getSizeInBits();
+    assert(isPowerOf2_32(BitWidth) && "Unexpected bit width");
+    // Only the lower log2(BitWidth)+1 bits of the shift amount are read.
+    if (SimplifyDemandedLowBitsHelper(1, Log2_32(BitWidth) + 1))
+      return SDValue(N, 0);
+
+    break;
+  }
   case RISCVISD::FMV_X_ANYEXTH:
   case RISCVISD::FMV_X_ANYEXTW_RV64: {
     SDLoc DL(N);
@@ -7727,7 +8789,7 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
       break;
     SDValue NewFMV = DAG.getNode(N->getOpcode(), DL, VT, Op0.getOperand(0));
     unsigned FPBits = N->getOpcode() == RISCVISD::FMV_X_ANYEXTW_RV64 ? 32 : 16;
-    APInt SignBit = APInt::getSignMask(FPBits).sextOrSelf(VT.getSizeInBits());
+    APInt SignBit = APInt::getSignMask(FPBits).sext(VT.getSizeInBits());
     if (Op0.getOpcode() == ISD::FNEG)
       return DAG.getNode(ISD::XOR, DL, VT, NewFMV,
                          DAG.getConstant(SignBit, DL, VT));
@@ -7741,13 +8803,21 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
   case ISD::SUB:
     return performSUBCombine(N, DAG);
   case ISD::AND:
-    return performANDCombine(N, DAG);
+    return performANDCombine(N, DAG, Subtarget);
   case ISD::OR:
     return performORCombine(N, DAG, Subtarget);
   case ISD::XOR:
     return performXORCombine(N, DAG);
-  case ISD::ANY_EXTEND:
-    return performANY_EXTENDCombine(N, DCI, Subtarget);
+  case ISD::FADD:
+  case ISD::UMAX:
+  case ISD::UMIN:
+  case ISD::SMAX:
+  case ISD::SMIN:
+  case ISD::FMAXNUM:
+  case ISD::FMINNUM:
+    return combineBinOpToReduce(N, DAG);
+  case ISD::SIGN_EXTEND_INREG:
+    return performSIGN_EXTEND_INREGCombine(N, DAG, Subtarget);
   case ISD::ZERO_EXTEND:
     // Fold (zero_extend (fp_to_uint X)) to prevent forming fcvt+zexti32 during
    // type legalization.
This is safe because fp_to_uint produces poison if @@ -7879,6 +8949,8 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N, } break; } + case ISD::BITREVERSE: + return performBITREVERSECombine(N, DAG, Subtarget); case ISD::FP_TO_SINT: case ISD::FP_TO_UINT: return performFP_TO_INTCombine(N, DCI, Subtarget); @@ -7952,40 +9024,41 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N, DL, IndexVT, Index); } - unsigned Scale = cast<ConstantSDNode>(ScaleOp)->getZExtValue(); - if (IsIndexScaled && Scale != 1) { - // Manually scale the indices by the element size. + if (IsIndexScaled) { + // Manually scale the indices. // TODO: Sanitize the scale operand here? // TODO: For VP nodes, should we use VP_SHL here? + unsigned Scale = cast<ConstantSDNode>(ScaleOp)->getZExtValue(); assert(isPowerOf2_32(Scale) && "Expecting power-of-two types"); SDValue SplatScale = DAG.getConstant(Log2_32(Scale), DL, IndexVT); Index = DAG.getNode(ISD::SHL, DL, IndexVT, Index, SplatScale); + ScaleOp = DAG.getTargetConstant(1, DL, ScaleOp.getValueType()); } - ISD::MemIndexType NewIndexTy = ISD::UNSIGNED_UNSCALED; + ISD::MemIndexType NewIndexTy = ISD::UNSIGNED_SCALED; if (const auto *VPGN = dyn_cast<VPGatherSDNode>(N)) return DAG.getGatherVP(N->getVTList(), VPGN->getMemoryVT(), DL, {VPGN->getChain(), VPGN->getBasePtr(), Index, - VPGN->getScale(), VPGN->getMask(), + ScaleOp, VPGN->getMask(), VPGN->getVectorLength()}, VPGN->getMemOperand(), NewIndexTy); if (const auto *VPSN = dyn_cast<VPScatterSDNode>(N)) return DAG.getScatterVP(N->getVTList(), VPSN->getMemoryVT(), DL, {VPSN->getChain(), VPSN->getValue(), - VPSN->getBasePtr(), Index, VPSN->getScale(), + VPSN->getBasePtr(), Index, ScaleOp, VPSN->getMask(), VPSN->getVectorLength()}, VPSN->getMemOperand(), NewIndexTy); if (const auto *MGN = dyn_cast<MaskedGatherSDNode>(N)) return DAG.getMaskedGather( N->getVTList(), MGN->getMemoryVT(), DL, {MGN->getChain(), MGN->getPassThru(), MGN->getMask(), - MGN->getBasePtr(), Index, MGN->getScale()}, + MGN->getBasePtr(), Index, ScaleOp}, MGN->getMemOperand(), NewIndexTy, MGN->getExtensionType()); const auto *MSN = cast<MaskedScatterSDNode>(N); return DAG.getMaskedScatter( N->getVTList(), MSN->getMemoryVT(), DL, {MSN->getChain(), MSN->getValue(), MSN->getMask(), MSN->getBasePtr(), - Index, MSN->getScale()}, + Index, ScaleOp}, MSN->getMemOperand(), NewIndexTy, MSN->isTruncatingStore()); } case RISCVISD::SRA_VL: @@ -7997,14 +9070,17 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N, SDLoc DL(N); SDValue VL = N->getOperand(3); EVT VT = N->getValueType(0); - ShAmt = - DAG.getNode(RISCVISD::VMV_V_X_VL, DL, VT, ShAmt.getOperand(0), VL); + ShAmt = DAG.getNode(RISCVISD::VMV_V_X_VL, DL, VT, DAG.getUNDEF(VT), + ShAmt.getOperand(1), VL); return DAG.getNode(N->getOpcode(), DL, VT, N->getOperand(0), ShAmt, N->getOperand(2), N->getOperand(3)); } break; } case ISD::SRA: + if (SDValue V = performSRACombine(N, DAG, Subtarget)) + return V; + LLVM_FALLTHROUGH; case ISD::SRL: case ISD::SHL: { SDValue ShAmt = N->getOperand(1); @@ -8012,17 +9088,63 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N, // We don't need the upper 32 bits of a 64-bit element for a shift amount. 
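As for the gather/scatter index rewrite above, the address math it preserves is simple (a sketch; Scale must be a power of two, as the assert requires):

#include <cassert>
#include <cstdint>

int main() {
  uint64_t Base = 0x1000, Index = 37;
  const uint64_t Scale = 8; // element size; a power of two
  // Scaled form: Base + Index * Scale.
  uint64_t Scaled = Base + Index * Scale;
  // Rewritten form: shift the index left by log2(Scale) == 3 and use scale 1.
  uint64_t Unscaled = Base + (Index << 3) * 1;
  assert(Scaled == Unscaled);
}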
SDLoc DL(N); EVT VT = N->getValueType(0); - ShAmt = - DAG.getNode(RISCVISD::SPLAT_VECTOR_I64, DL, VT, ShAmt.getOperand(0)); + ShAmt = DAG.getNode(RISCVISD::VMV_V_X_VL, DL, VT, DAG.getUNDEF(VT), + ShAmt.getOperand(1), + DAG.getRegister(RISCV::X0, Subtarget.getXLenVT())); return DAG.getNode(N->getOpcode(), DL, VT, N->getOperand(0), ShAmt); } break; } + case RISCVISD::ADD_VL: + if (SDValue V = combineADDSUB_VLToVWADDSUB_VL(N, DAG, /*Commute*/ false)) + return V; + return combineADDSUB_VLToVWADDSUB_VL(N, DAG, /*Commute*/ true); + case RISCVISD::SUB_VL: + return combineADDSUB_VLToVWADDSUB_VL(N, DAG); + case RISCVISD::VWADD_W_VL: + case RISCVISD::VWADDU_W_VL: + case RISCVISD::VWSUB_W_VL: + case RISCVISD::VWSUBU_W_VL: + return combineVWADD_W_VL_VWSUB_W_VL(N, DAG); case RISCVISD::MUL_VL: if (SDValue V = combineMUL_VLToVWMUL_VL(N, DAG, /*Commute*/ false)) return V; // Mul is commutative. return combineMUL_VLToVWMUL_VL(N, DAG, /*Commute*/ true); + case RISCVISD::VFMADD_VL: + case RISCVISD::VFNMADD_VL: + case RISCVISD::VFMSUB_VL: + case RISCVISD::VFNMSUB_VL: { + // Fold FNEG_VL into FMA opcodes. + SDValue A = N->getOperand(0); + SDValue B = N->getOperand(1); + SDValue C = N->getOperand(2); + SDValue Mask = N->getOperand(3); + SDValue VL = N->getOperand(4); + + auto invertIfNegative = [&Mask, &VL](SDValue &V) { + if (V.getOpcode() == RISCVISD::FNEG_VL && V.getOperand(1) == Mask && + V.getOperand(2) == VL) { + // Return the negated input. + V = V.getOperand(0); + return true; + } + + return false; + }; + + bool NegA = invertIfNegative(A); + bool NegB = invertIfNegative(B); + bool NegC = invertIfNegative(C); + + // If no operands are negated, we're done. + if (!NegA && !NegB && !NegC) + return SDValue(); + + unsigned NewOpcode = negateFMAOpcode(N->getOpcode(), NegA != NegB, NegC); + return DAG.getNode(NewOpcode, SDLoc(N), N->getValueType(0), A, B, C, Mask, + VL); + } case ISD::STORE: { auto *Store = cast<StoreSDNode>(N); SDValue Val = Store->getValue(); @@ -8035,7 +9157,7 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N, // The memory VT and the element type must match. if (VecVT.getVectorElementType() == MemVT) { SDLoc DL(N); - MVT MaskVT = MVT::getVectorVT(MVT::i1, VecVT.getVectorElementCount()); + MVT MaskVT = getMaskTypeFor(VecVT); return DAG.getStoreVP( Store->getChain(), DL, Src, Store->getBasePtr(), Store->getOffset(), DAG.getConstant(1, DL, MaskVT), @@ -8047,6 +9169,73 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N, break; } + case ISD::SPLAT_VECTOR: { + EVT VT = N->getValueType(0); + // Only perform this combine on legal MVT types. + if (!isTypeLegal(VT)) + break; + if (auto Gather = matchSplatAsGather(N->getOperand(0), VT.getSimpleVT(), N, + DAG, Subtarget)) + return Gather; + break; + } + case RISCVISD::VMV_V_X_VL: { + // Tail agnostic VMV.V.X only demands the vector element bitwidth from the + // scalar input. + unsigned ScalarSize = N->getOperand(1).getValueSizeInBits(); + unsigned EltWidth = N->getValueType(0).getScalarSizeInBits(); + if (ScalarSize > EltWidth && N->getOperand(0).isUndef()) + if (SimplifyDemandedLowBitsHelper(1, EltWidth)) + return SDValue(N, 0); + + break; + } + case ISD::INTRINSIC_WO_CHAIN: { + unsigned IntNo = N->getConstantOperandVal(0); + switch (IntNo) { + // By default we do not combine any intrinsic. 
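The opcode table used by negateFMAOpcode for the FNEG_VL folds above corresponds to these scalar identities (a sketch; the helpers model the RVV opcode semantics, they are not the intrinsics):

#include <cassert>

static double vfmadd(double A, double B, double C) { return A * B + C; }
static double vfmsub(double A, double B, double C) { return A * B - C; }
static double vfnmadd(double A, double B, double C) { return -(A * B) - C; }
static double vfnmsub(double A, double B, double C) { return -(A * B) + C; }

int main() {
  double A = 1.5, B = -2.0, C = 0.25;
  // Negating the multiply result: ADD<->SUB and toggles 'N'.
  assert(vfmadd(-A, B, C) == vfnmsub(A, B, C));
  assert(vfmsub(-A, B, C) == vfnmadd(A, B, C));
  // Negating the accumulator: ADD<->SUB only.
  assert(vfmadd(A, B, -C) == vfmsub(A, B, C));
  assert(vfnmsub(A, B, -C) == vfnmadd(A, B, C));
}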
+ default: + return SDValue(); + case Intrinsic::riscv_vcpop: + case Intrinsic::riscv_vcpop_mask: + case Intrinsic::riscv_vfirst: + case Intrinsic::riscv_vfirst_mask: { + SDValue VL = N->getOperand(2); + if (IntNo == Intrinsic::riscv_vcpop_mask || + IntNo == Intrinsic::riscv_vfirst_mask) + VL = N->getOperand(3); + if (!isNullConstant(VL)) + return SDValue(); + // If VL is 0, vcpop -> li 0, vfirst -> li -1. + SDLoc DL(N); + EVT VT = N->getValueType(0); + if (IntNo == Intrinsic::riscv_vfirst || + IntNo == Intrinsic::riscv_vfirst_mask) + return DAG.getConstant(-1, DL, VT); + return DAG.getConstant(0, DL, VT); + } + } + } + case ISD::BITCAST: { + assert(Subtarget.useRVVForFixedLengthVectors()); + SDValue N0 = N->getOperand(0); + EVT VT = N->getValueType(0); + EVT SrcVT = N0.getValueType(); + // If this is a bitcast between a MVT::v4i1/v2i1/v1i1 and an illegal integer + // type, widen both sides to avoid a trip through memory. + if ((SrcVT == MVT::v1i1 || SrcVT == MVT::v2i1 || SrcVT == MVT::v4i1) && + VT.isScalarInteger()) { + unsigned NumConcats = 8 / SrcVT.getVectorNumElements(); + SmallVector<SDValue, 4> Ops(NumConcats, DAG.getUNDEF(SrcVT)); + Ops[0] = N0; + SDLoc DL(N); + N0 = DAG.getNode(ISD::CONCAT_VECTORS, DL, MVT::v8i1, Ops); + N0 = DAG.getBitcast(MVT::i8, N0); + return DAG.getNode(ISD::TRUNCATE, DL, VT, N0); + } + + return SDValue(); + } } return SDValue(); @@ -8182,22 +9371,23 @@ bool RISCVTargetLowering::targetShrinkDemandedConstant( return UseMask(NewMask); } -static void computeGREV(APInt &Src, unsigned ShAmt) { - ShAmt &= Src.getBitWidth() - 1; - uint64_t x = Src.getZExtValue(); - if (ShAmt & 1) - x = ((x & 0x5555555555555555LL) << 1) | ((x & 0xAAAAAAAAAAAAAAAALL) >> 1); - if (ShAmt & 2) - x = ((x & 0x3333333333333333LL) << 2) | ((x & 0xCCCCCCCCCCCCCCCCLL) >> 2); - if (ShAmt & 4) - x = ((x & 0x0F0F0F0F0F0F0F0FLL) << 4) | ((x & 0xF0F0F0F0F0F0F0F0LL) >> 4); - if (ShAmt & 8) - x = ((x & 0x00FF00FF00FF00FFLL) << 8) | ((x & 0xFF00FF00FF00FF00LL) >> 8); - if (ShAmt & 16) - x = ((x & 0x0000FFFF0000FFFFLL) << 16) | ((x & 0xFFFF0000FFFF0000LL) >> 16); - if (ShAmt & 32) - x = ((x & 0x00000000FFFFFFFFLL) << 32) | ((x & 0xFFFFFFFF00000000LL) >> 32); - Src = x; +static uint64_t computeGREVOrGORC(uint64_t x, unsigned ShAmt, bool IsGORC) { + static const uint64_t GREVMasks[] = { + 0x5555555555555555ULL, 0x3333333333333333ULL, 0x0F0F0F0F0F0F0F0FULL, + 0x00FF00FF00FF00FFULL, 0x0000FFFF0000FFFFULL, 0x00000000FFFFFFFFULL}; + + for (unsigned Stage = 0; Stage != 6; ++Stage) { + unsigned Shift = 1 << Stage; + if (ShAmt & Shift) { + uint64_t Mask = GREVMasks[Stage]; + uint64_t Res = ((x & Mask) << Shift) | ((x >> Shift) & Mask); + if (IsGORC) + Res |= x; + x = Res; + } + } + + return x; } void RISCVTargetLowering::computeKnownBitsForTargetNode(const SDValue Op, @@ -8263,28 +9453,28 @@ void RISCVTargetLowering::computeKnownBitsForTargetNode(const SDValue Op, break; } case RISCVISD::GREV: - case RISCVISD::GREVW: { + case RISCVISD::GORC: { if (auto *C = dyn_cast<ConstantSDNode>(Op.getOperand(1))) { Known = DAG.computeKnownBits(Op.getOperand(0), Depth + 1); - if (Opc == RISCVISD::GREVW) - Known = Known.trunc(32); - unsigned ShAmt = C->getZExtValue(); - computeGREV(Known.Zero, ShAmt); - computeGREV(Known.One, ShAmt); - if (Opc == RISCVISD::GREVW) - Known = Known.sext(BitWidth); + unsigned ShAmt = C->getZExtValue() & (Known.getBitWidth() - 1); + bool IsGORC = Op.getOpcode() == RISCVISD::GORC; + // To compute zeros, we need to invert the value and invert it back after. 
+      Known.Zero =
+          ~computeGREVOrGORC(~Known.Zero.getZExtValue(), ShAmt, IsGORC);
+      Known.One = computeGREVOrGORC(Known.One.getZExtValue(), ShAmt, IsGORC);
     }
     break;
   }
   case RISCVISD::READ_VLENB: {
-    // If we know the minimum VLen from Zvl extensions, we can use that to
-    // determine the trailing zeros of VLENB.
-    // FIXME: Limit to 128 bit vectors until we have more testing.
-    unsigned MinVLenB = std::min(128U, Subtarget.getMinVLen()) / 8;
-    if (MinVLenB > 0)
-      Known.Zero.setLowBits(Log2_32(MinVLenB));
-    // We assume VLENB is no more than 65536 / 8 bytes.
-    Known.Zero.setBitsFrom(14);
+    // We can use the minimum and maximum VLEN values to bound VLENB. We
+    // know VLEN must be a power of two.
+    const unsigned MinVLenB = Subtarget.getRealMinVLen() / 8;
+    const unsigned MaxVLenB = Subtarget.getRealMaxVLen() / 8;
+    assert(MinVLenB > 0 && "READ_VLENB without vector extension enabled?");
+    Known.Zero.setLowBits(Log2_32(MinVLenB));
+    Known.Zero.setBitsFrom(Log2_32(MaxVLenB) + 1);
+    if (MaxVLenB == MinVLenB)
+      Known.One.setBit(Log2_32(MinVLenB));
     break;
   }
   case ISD::INTRINSIC_W_CHAIN:
@@ -8381,6 +9571,51 @@ unsigned RISCVTargetLowering::ComputeNumSignBitsForTargetNode(
   return 1;
 }
 
+const Constant *
+RISCVTargetLowering::getTargetConstantFromLoad(LoadSDNode *Ld) const {
+  assert(Ld && "Unexpected null LoadSDNode");
+  if (!ISD::isNormalLoad(Ld))
+    return nullptr;
+
+  SDValue Ptr = Ld->getBasePtr();
+
+  // Only constant pools with no offset are supported.
+  auto GetSupportedConstantPool = [](SDValue Ptr) -> ConstantPoolSDNode * {
+    auto *CNode = dyn_cast<ConstantPoolSDNode>(Ptr);
+    if (!CNode || CNode->isMachineConstantPoolEntry() ||
+        CNode->getOffset() != 0)
+      return nullptr;
+
+    return CNode;
+  };
+
+  // Simple case, LLA.
+  if (Ptr.getOpcode() == RISCVISD::LLA) {
+    auto *CNode = GetSupportedConstantPool(Ptr);
+    if (!CNode || CNode->getTargetFlags() != 0)
+      return nullptr;
+
+    return CNode->getConstVal();
+  }
+
+  // Look for a HI and ADD_LO pair.
+  if (Ptr.getOpcode() != RISCVISD::ADD_LO ||
+      Ptr.getOperand(0).getOpcode() != RISCVISD::HI)
+    return nullptr;
+
+  auto *CNodeLo = GetSupportedConstantPool(Ptr.getOperand(1));
+  auto *CNodeHi = GetSupportedConstantPool(Ptr.getOperand(0).getOperand(0));
+
+  if (!CNodeLo || CNodeLo->getTargetFlags() != RISCVII::MO_LO ||
+      !CNodeHi || CNodeHi->getTargetFlags() != RISCVII::MO_HI)
+    return nullptr;
+
+  if (CNodeLo->getConstVal() != CNodeHi->getConstVal())
+    return nullptr;
+
+  return CNodeLo->getConstVal();
+}
+
 static MachineBasicBlock *emitReadCycleWidePseudo(MachineInstr &MI,
                                                   MachineBasicBlock *BB) {
   assert(MI.getOpcode() == RISCV::ReadCycleWide && "Unexpected instruction");
@@ -8559,6 +9794,109 @@ static MachineBasicBlock *emitQuietFCMP(MachineInstr &MI, MachineBasicBlock *BB,
   return BB;
 }
 
+static MachineBasicBlock *
+EmitLoweredCascadedSelect(MachineInstr &First, MachineInstr &Second,
+                          MachineBasicBlock *ThisMBB,
+                          const RISCVSubtarget &Subtarget) {
+  // Select_FPRX_ (rs1, rs2, imm, rs4, (Select_FPRX_ rs1, rs2, imm, rs4, rs5))
+  // Without this, custom-inserter would have generated:
+  //
+  //   A
+  //   | \
+  //   |  B
+  //   | /
+  //   C
+  //   | \
+  //   |  D
+  //   | /
+  //   E
+  //
+  // A: X = ...; Y = ...
+  // B: empty
+  // C: Z = PHI [X, A], [Y, B]
+  // D: empty
+  // E: PHI [X, C], [Z, D]
+  //
+  // If we lower both Select_FPRX_ in a single step, we can instead generate:
+  //
+  //   A
+  //   | \
+  //   |  C
+  //   | /|
+  //   |/ |
+  //   |  |
+  //   |  D
+  //   | /
+  //   E
+  //
+  // A: X = ...; Y = ...
+ // D: empty + // E: PHI [X, A], [X, C], [Y, D] + + const RISCVInstrInfo &TII = *Subtarget.getInstrInfo(); + const DebugLoc &DL = First.getDebugLoc(); + const BasicBlock *LLVM_BB = ThisMBB->getBasicBlock(); + MachineFunction *F = ThisMBB->getParent(); + MachineBasicBlock *FirstMBB = F->CreateMachineBasicBlock(LLVM_BB); + MachineBasicBlock *SecondMBB = F->CreateMachineBasicBlock(LLVM_BB); + MachineBasicBlock *SinkMBB = F->CreateMachineBasicBlock(LLVM_BB); + MachineFunction::iterator It = ++ThisMBB->getIterator(); + F->insert(It, FirstMBB); + F->insert(It, SecondMBB); + F->insert(It, SinkMBB); + + // Transfer the remainder of ThisMBB and its successor edges to SinkMBB. + SinkMBB->splice(SinkMBB->begin(), ThisMBB, + std::next(MachineBasicBlock::iterator(First)), + ThisMBB->end()); + SinkMBB->transferSuccessorsAndUpdatePHIs(ThisMBB); + + // Fallthrough block for ThisMBB. + ThisMBB->addSuccessor(FirstMBB); + // Fallthrough block for FirstMBB. + FirstMBB->addSuccessor(SecondMBB); + ThisMBB->addSuccessor(SinkMBB); + FirstMBB->addSuccessor(SinkMBB); + // This is fallthrough. + SecondMBB->addSuccessor(SinkMBB); + + auto FirstCC = static_cast<RISCVCC::CondCode>(First.getOperand(3).getImm()); + Register FLHS = First.getOperand(1).getReg(); + Register FRHS = First.getOperand(2).getReg(); + // Insert appropriate branch. + BuildMI(ThisMBB, DL, TII.getBrCond(FirstCC)) + .addReg(FLHS) + .addReg(FRHS) + .addMBB(SinkMBB); + + Register SLHS = Second.getOperand(1).getReg(); + Register SRHS = Second.getOperand(2).getReg(); + Register Op1Reg4 = First.getOperand(4).getReg(); + Register Op1Reg5 = First.getOperand(5).getReg(); + + auto SecondCC = static_cast<RISCVCC::CondCode>(Second.getOperand(3).getImm()); + // Insert appropriate branch. + BuildMI(FirstMBB, DL, TII.getBrCond(SecondCC)) + .addReg(SLHS) + .addReg(SRHS) + .addMBB(SinkMBB); + + Register DestReg = Second.getOperand(0).getReg(); + Register Op2Reg4 = Second.getOperand(4).getReg(); + BuildMI(*SinkMBB, SinkMBB->begin(), DL, TII.get(RISCV::PHI), DestReg) + .addReg(Op1Reg4) + .addMBB(ThisMBB) + .addReg(Op2Reg4) + .addMBB(FirstMBB) + .addReg(Op1Reg5) + .addMBB(SecondMBB); + + // Now remove the Select_FPRX_s. + First.eraseFromParent(); + Second.eraseFromParent(); + return SinkMBB; +} + static MachineBasicBlock *emitSelectPseudo(MachineInstr &MI, MachineBasicBlock *BB, const RISCVSubtarget &Subtarget) { @@ -8586,6 +9924,10 @@ static MachineBasicBlock *emitSelectPseudo(MachineInstr &MI, // previous selects in the sequence. // These conditions could be further relaxed. See the X86 target for a // related approach and more information. + // + // Select_FPRX_ (rs1, rs2, imm, rs4, (Select_FPRX_ rs1, rs2, imm, rs4, rs5)) + // is checked here and handled by a separate function - + // EmitLoweredCascadedSelect. 
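In scalar terms, the single-step lowering above computes the following (an illustrative model of the emitted control flow, not the MIR itself):

// The sink block's PHI picks one of three values depending on which edge was
// taken: one branch per condition feeding a single three-way PHI.
static int cascadedSelect(bool FirstCC, bool SecondCC, int TrueV1, int TrueV2,
                          int FalseV) {
  if (FirstCC)
    return TrueV1; // ThisMBB -> SinkMBB edge
  if (SecondCC)
    return TrueV2; // FirstMBB -> SinkMBB edge
  return FalseV;   // SecondMBB fallthrough edge
}

int main() { return cascadedSelect(false, true, 1, 2, 3) == 2 ? 0 : 1; }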
Register LHS = MI.getOperand(1).getReg(); Register RHS = MI.getOperand(2).getReg(); auto CC = static_cast<RISCVCC::CondCode>(MI.getOperand(3).getImm()); @@ -8595,12 +9937,19 @@ static MachineBasicBlock *emitSelectPseudo(MachineInstr &MI, SelectDests.insert(MI.getOperand(0).getReg()); MachineInstr *LastSelectPseudo = &MI; + auto Next = next_nodbg(MI.getIterator(), BB->instr_end()); + if (MI.getOpcode() != RISCV::Select_GPR_Using_CC_GPR && Next != BB->end() && + Next->getOpcode() == MI.getOpcode() && + Next->getOperand(5).getReg() == MI.getOperand(0).getReg() && + Next->getOperand(5).isKill()) { + return EmitLoweredCascadedSelect(MI, *Next, BB, Subtarget); + } for (auto E = BB->end(), SequenceMBBI = MachineBasicBlock::iterator(MI); SequenceMBBI != E; ++SequenceMBBI) { if (SequenceMBBI->isDebugInstr()) continue; - else if (isSelectPseudo(*SequenceMBBI)) { + if (isSelectPseudo(*SequenceMBBI)) { if (SequenceMBBI->getOperand(1).getReg() != LHS || SequenceMBBI->getOperand(2).getReg() != RHS || SequenceMBBI->getOperand(3).getImm() != CC || @@ -8831,7 +10180,7 @@ static unsigned allocateRVVReg(MVT ValVT, unsigned ValNo, // Assign the first mask argument to V0. // This is an interim calling convention and it may be changed in the // future. - if (FirstMaskArgument.hasValue() && ValNo == FirstMaskArgument.getValue()) + if (FirstMaskArgument && ValNo == *FirstMaskArgument) return State.AllocateReg(RISCV::V0); return State.AllocateReg(ArgVRs); } @@ -10112,6 +11461,13 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const { NODE_NAME_CASE(BuildPairF64) NODE_NAME_CASE(SplitF64) NODE_NAME_CASE(TAIL) + NODE_NAME_CASE(ADD_LO) + NODE_NAME_CASE(HI) + NODE_NAME_CASE(LLA) + NODE_NAME_CASE(ADD_TPREL) + NODE_NAME_CASE(LA) + NODE_NAME_CASE(LA_TLS_IE) + NODE_NAME_CASE(LA_TLS_GD) NODE_NAME_CASE(MULHSU) NODE_NAME_CASE(SLLW) NODE_NAME_CASE(SRAW) @@ -10129,6 +11485,7 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const { NODE_NAME_CASE(FSR) NODE_NAME_CASE(FMV_H_X) NODE_NAME_CASE(FMV_X_ANYEXTH) + NODE_NAME_CASE(FMV_X_SIGNEXTH) NODE_NAME_CASE(FMV_W_X_RV64) NODE_NAME_CASE(FMV_X_ANYEXTW_RV64) NODE_NAME_CASE(FCVT_X) @@ -10157,7 +11514,6 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const { NODE_NAME_CASE(VMV_X_S) NODE_NAME_CASE(VMV_S_X_VL) NODE_NAME_CASE(VFMV_S_F_VL) - NODE_NAME_CASE(SPLAT_VECTOR_I64) NODE_NAME_CASE(SPLAT_VECTOR_SPLIT_I64_VL) NODE_NAME_CASE(READ_VLENB) NODE_NAME_CASE(TRUNCATE_VECTOR_VL) @@ -10203,7 +11559,10 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const { NODE_NAME_CASE(FNEG_VL) NODE_NAME_CASE(FABS_VL) NODE_NAME_CASE(FSQRT_VL) - NODE_NAME_CASE(FMA_VL) + NODE_NAME_CASE(VFMADD_VL) + NODE_NAME_CASE(VFNMADD_VL) + NODE_NAME_CASE(VFMSUB_VL) + NODE_NAME_CASE(VFNMSUB_VL) NODE_NAME_CASE(FCOPYSIGN_VL) NODE_NAME_CASE(SMIN_VL) NODE_NAME_CASE(SMAX_VL) @@ -10222,7 +11581,14 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const { NODE_NAME_CASE(VWMUL_VL) NODE_NAME_CASE(VWMULU_VL) NODE_NAME_CASE(VWMULSU_VL) + NODE_NAME_CASE(VWADD_VL) NODE_NAME_CASE(VWADDU_VL) + NODE_NAME_CASE(VWSUB_VL) + NODE_NAME_CASE(VWSUBU_VL) + NODE_NAME_CASE(VWADD_W_VL) + NODE_NAME_CASE(VWADDU_W_VL) + NODE_NAME_CASE(VWSUB_W_VL) + NODE_NAME_CASE(VWSUBU_W_VL) NODE_NAME_CASE(SETCC_VL) NODE_NAME_CASE(VSELECT_VL) NODE_NAME_CASE(VP_MERGE_VL) @@ -10237,8 +11603,6 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const { NODE_NAME_CASE(VSEXT_VL) NODE_NAME_CASE(VZEXT_VL) NODE_NAME_CASE(VCPOP_VL) - 
NODE_NAME_CASE(VLE_VL)
-  NODE_NAME_CASE(VSE_VL)
   NODE_NAME_CASE(READ_CSR)
   NODE_NAME_CASE(WRITE_CSR)
   NODE_NAME_CASE(SWAP_CSR)
@@ -10459,7 +11823,18 @@ RISCVTargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
     }
   }
 
-  return TargetLowering::getRegForInlineAsmConstraint(TRI, Constraint, VT);
+  std::pair<Register, const TargetRegisterClass *> Res =
+      TargetLowering::getRegForInlineAsmConstraint(TRI, Constraint, VT);
+
+  // If we picked one of the Zfinx register classes, remap it to the GPR class.
+  // FIXME: When Zfinx is supported in CodeGen this will need to take the
+  // Subtarget into account.
+  if (Res.second == &RISCV::GPRF16RegClass ||
+      Res.second == &RISCV::GPRF32RegClass ||
+      Res.second == &RISCV::GPRF64RegClass)
+    return std::make_pair(Res.first, &RISCV::GPRRegClass);
+
+  return Res;
 }
 
 unsigned
@@ -10681,7 +12056,8 @@ Value *RISCVTargetLowering::emitMaskedAtomicCmpXchgIntrinsic(
   return Result;
 }
 
-bool RISCVTargetLowering::shouldRemoveExtendFromGSIndex(EVT VT) const {
+bool RISCVTargetLowering::shouldRemoveExtendFromGSIndex(EVT IndexVT,
+                                                        EVT DataVT) const {
   return false;
 }
 
@@ -10797,7 +12173,7 @@ bool RISCVTargetLowering::decomposeMulByConstant(LLVMContext &Context, EVT VT,
       APInt ImmS = Imm.ashr(Imm.countTrailingZeros());
       if ((ImmS + 1).isPowerOf2() || (ImmS - 1).isPowerOf2() ||
           (1 - ImmS).isPowerOf2())
-          return true;
+        return true;
     }
   }
 
@@ -10805,8 +12181,8 @@ bool RISCVTargetLowering::decomposeMulByConstant(LLVMContext &Context, EVT VT,
   return false;
 }
 
-bool RISCVTargetLowering::isMulAddWithConstProfitable(
-    const SDValue &AddNode, const SDValue &ConstNode) const {
+bool RISCVTargetLowering::isMulAddWithConstProfitable(SDValue AddNode,
+                                                      SDValue ConstNode) const {
   // Let the DAGCombiner decide for vectors.
   EVT VT = AddNode.getValueType();
   if (VT.isVector())
@@ -10831,9 +12207,13 @@ bool RISCVTargetLowering::isMulAddWithConstProfitable(
 bool RISCVTargetLowering::allowsMisalignedMemoryAccesses(
     EVT VT, unsigned AddrSpace, Align Alignment, MachineMemOperand::Flags Flags,
     bool *Fast) const {
-  if (!VT.isVector())
-    return false;
+  if (!VT.isVector()) {
+    if (Fast)
+      *Fast = false;
+    return Subtarget.enableUnalignedScalarMem();
+  }
 
+  // All vector implementations must support element alignment
   EVT ElemVT = VT.getVectorElementType();
   if (Alignment >= ElemVT.getStoreSize()) {
     if (Fast)
@@ -10847,7 +12227,7 @@ bool RISCVTargetLowering::allowsMisalignedMemoryAccesses(
 bool RISCVTargetLowering::splitValueIntoRegisterParts(
     SelectionDAG &DAG, const SDLoc &DL, SDValue Val, SDValue *Parts,
     unsigned NumParts, MVT PartVT, Optional<CallingConv::ID> CC) const {
-  bool IsABIRegCopy = CC.hasValue();
+  bool IsABIRegCopy = CC.has_value();
   EVT ValueVT = Val.getValueType();
   if (IsABIRegCopy && ValueVT == MVT::f16 && PartVT == MVT::f32) {
     // Cast the f16 to i16, extend to i32, pad with ones to make a float nan,
@@ -10901,7 +12281,7 @@ bool RISCVTargetLowering::splitValueIntoRegisterParts(
 SDValue RISCVTargetLowering::joinRegisterPartsIntoValue(
     SelectionDAG &DAG, const SDLoc &DL, const SDValue *Parts, unsigned NumParts,
     MVT PartVT, EVT ValueVT, Optional<CallingConv::ID> CC) const {
-  bool IsABIRegCopy = CC.hasValue();
+  bool IsABIRegCopy = CC.has_value();
   if (IsABIRegCopy && ValueVT == MVT::f16 && PartVT == MVT::f32) {
     SDValue Val = Parts[0];
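The "pad with ones to make a float nan" step in this last hunk has a compact bit-level model (a sketch of the NaN boxing, with a hypothetical helper name):

#include <cstdint>
#include <cstring>

// Place the f16 payload in the low 16 bits and set the upper 16 bits to all
// ones; any such pattern is a NaN when interpreted as an f32 (exponent all
// ones, nonzero mantissa).
static float nanBoxF16(uint16_t HalfBits) {
  uint32_t Bits = 0xFFFF0000u | HalfBits;
  float F;
  std::memcpy(&F, &Bits, sizeof(F));
  return F;
}

int main() {
  float F = nanBoxF16(0x3C00); // f16 1.0, boxed
  return F != F ? 0 : 1;       // a NaN compares unequal to itself
}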