diff options
Diffstat (limited to 'llvm/lib/Target/RISCV/RISCVInsertReadWriteCSR.cpp')
| -rw-r--r-- | llvm/lib/Target/RISCV/RISCVInsertReadWriteCSR.cpp | 79 |
1 files changed, 27 insertions, 52 deletions
diff --git a/llvm/lib/Target/RISCV/RISCVInsertReadWriteCSR.cpp b/llvm/lib/Target/RISCV/RISCVInsertReadWriteCSR.cpp index 4b26c27bb4f8..b807abcc5681 100644 --- a/llvm/lib/Target/RISCV/RISCVInsertReadWriteCSR.cpp +++ b/llvm/lib/Target/RISCV/RISCVInsertReadWriteCSR.cpp @@ -8,8 +8,9 @@ // This file implements the machine function pass to insert read/write of CSR-s // of the RISC-V instructions. // -// Currently the pass implements naive insertion of a write to vxrm before an -// RVV fixed-point instruction. +// Currently the pass implements: +// -Writing and saving frm before an RVV floating-point instruction with a +// static rounding mode and restores the value after. // //===----------------------------------------------------------------------===// @@ -30,9 +31,7 @@ class RISCVInsertReadWriteCSR : public MachineFunctionPass { public: static char ID; - RISCVInsertReadWriteCSR() : MachineFunctionPass(ID) { - initializeRISCVInsertReadWriteCSRPass(*PassRegistry::getPassRegistry()); - } + RISCVInsertReadWriteCSR() : MachineFunctionPass(ID) {} bool runOnMachineFunction(MachineFunction &MF) override; @@ -56,60 +55,36 @@ char RISCVInsertReadWriteCSR::ID = 0; INITIALIZE_PASS(RISCVInsertReadWriteCSR, DEBUG_TYPE, RISCV_INSERT_READ_WRITE_CSR_NAME, false, false) -// Returns the index to the rounding mode immediate value if any, otherwise the -// function will return None. -static std::optional<unsigned> getRoundModeIdx(const MachineInstr &MI) { - uint64_t TSFlags = MI.getDesc().TSFlags; - if (!RISCVII::hasRoundModeOp(TSFlags)) - return std::nullopt; - - // The operand order - // ------------------------------------- - // | n-1 (if any) | n-2 | n-3 | n-4 | - // | policy | sew | vl | rm | - // ------------------------------------- - return MI.getNumExplicitOperands() - RISCVII::hasVecPolicyOp(TSFlags) - 3; -} - -// This function inserts a write to vxrm when encountering an RVV fixed-point -// instruction. +// This function also swaps frm and restores it when encountering an RVV +// floating point instruction with a static rounding mode. bool RISCVInsertReadWriteCSR::emitWriteRoundingMode(MachineBasicBlock &MBB) { bool Changed = false; for (MachineInstr &MI : MBB) { - if (auto RoundModeIdx = getRoundModeIdx(MI)) { - if (RISCVII::usesVXRM(MI.getDesc().TSFlags)) { - unsigned VXRMImm = MI.getOperand(*RoundModeIdx).getImm(); - - Changed = true; + int FRMIdx = RISCVII::getFRMOpNum(MI.getDesc()); + if (FRMIdx < 0) + continue; - BuildMI(MBB, MI, MI.getDebugLoc(), TII->get(RISCV::WriteVXRMImm)) - .addImm(VXRMImm); - MI.addOperand(MachineOperand::CreateReg(RISCV::VXRM, /*IsDef*/ false, - /*IsImp*/ true)); - } else { // FRM - unsigned FRMImm = MI.getOperand(*RoundModeIdx).getImm(); + unsigned FRMImm = MI.getOperand(FRMIdx).getImm(); - // The value is a hint to this pass to not alter the frm value. - if (FRMImm == RISCVFPRndMode::DYN) - continue; + // The value is a hint to this pass to not alter the frm value. + if (FRMImm == RISCVFPRndMode::DYN) + continue; - Changed = true; + Changed = true; - // Save - MachineRegisterInfo *MRI = &MBB.getParent()->getRegInfo(); - Register SavedFRM = MRI->createVirtualRegister(&RISCV::GPRRegClass); - BuildMI(MBB, MI, MI.getDebugLoc(), TII->get(RISCV::SwapFRMImm), - SavedFRM) - .addImm(FRMImm); - MI.addOperand(MachineOperand::CreateReg(RISCV::FRM, /*IsDef*/ false, - /*IsImp*/ true)); - // Restore - MachineInstrBuilder MIB = - BuildMI(*MBB.getParent(), {}, TII->get(RISCV::WriteFRM)) - .addReg(SavedFRM); - MBB.insertAfter(MI, MIB); - } - } + // Save + MachineRegisterInfo *MRI = &MBB.getParent()->getRegInfo(); + Register SavedFRM = MRI->createVirtualRegister(&RISCV::GPRRegClass); + BuildMI(MBB, MI, MI.getDebugLoc(), TII->get(RISCV::SwapFRMImm), + SavedFRM) + .addImm(FRMImm); + MI.addOperand(MachineOperand::CreateReg(RISCV::FRM, /*IsDef*/ false, + /*IsImp*/ true)); + // Restore + MachineInstrBuilder MIB = + BuildMI(*MBB.getParent(), {}, TII->get(RISCV::WriteFRM)) + .addReg(SavedFRM); + MBB.insertAfter(MI, MIB); } return Changed; } |
