aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Target/RISCV/RISCVInsertReadWriteCSR.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Target/RISCV/RISCVInsertReadWriteCSR.cpp')
-rw-r--r--llvm/lib/Target/RISCV/RISCVInsertReadWriteCSR.cpp79
1 files changed, 27 insertions, 52 deletions
diff --git a/llvm/lib/Target/RISCV/RISCVInsertReadWriteCSR.cpp b/llvm/lib/Target/RISCV/RISCVInsertReadWriteCSR.cpp
index 4b26c27bb4f8..b807abcc5681 100644
--- a/llvm/lib/Target/RISCV/RISCVInsertReadWriteCSR.cpp
+++ b/llvm/lib/Target/RISCV/RISCVInsertReadWriteCSR.cpp
@@ -8,8 +8,9 @@
// This file implements the machine function pass to insert read/write of CSR-s
// of the RISC-V instructions.
//
-// Currently the pass implements naive insertion of a write to vxrm before an
-// RVV fixed-point instruction.
+// Currently the pass implements:
+// -Writing and saving frm before an RVV floating-point instruction with a
+// static rounding mode and restores the value after.
//
//===----------------------------------------------------------------------===//
@@ -30,9 +31,7 @@ class RISCVInsertReadWriteCSR : public MachineFunctionPass {
public:
static char ID;
- RISCVInsertReadWriteCSR() : MachineFunctionPass(ID) {
- initializeRISCVInsertReadWriteCSRPass(*PassRegistry::getPassRegistry());
- }
+ RISCVInsertReadWriteCSR() : MachineFunctionPass(ID) {}
bool runOnMachineFunction(MachineFunction &MF) override;
@@ -56,60 +55,36 @@ char RISCVInsertReadWriteCSR::ID = 0;
INITIALIZE_PASS(RISCVInsertReadWriteCSR, DEBUG_TYPE,
RISCV_INSERT_READ_WRITE_CSR_NAME, false, false)
-// Returns the index to the rounding mode immediate value if any, otherwise the
-// function will return None.
-static std::optional<unsigned> getRoundModeIdx(const MachineInstr &MI) {
- uint64_t TSFlags = MI.getDesc().TSFlags;
- if (!RISCVII::hasRoundModeOp(TSFlags))
- return std::nullopt;
-
- // The operand order
- // -------------------------------------
- // | n-1 (if any) | n-2 | n-3 | n-4 |
- // | policy | sew | vl | rm |
- // -------------------------------------
- return MI.getNumExplicitOperands() - RISCVII::hasVecPolicyOp(TSFlags) - 3;
-}
-
-// This function inserts a write to vxrm when encountering an RVV fixed-point
-// instruction.
+// This function also swaps frm and restores it when encountering an RVV
+// floating point instruction with a static rounding mode.
bool RISCVInsertReadWriteCSR::emitWriteRoundingMode(MachineBasicBlock &MBB) {
bool Changed = false;
for (MachineInstr &MI : MBB) {
- if (auto RoundModeIdx = getRoundModeIdx(MI)) {
- if (RISCVII::usesVXRM(MI.getDesc().TSFlags)) {
- unsigned VXRMImm = MI.getOperand(*RoundModeIdx).getImm();
-
- Changed = true;
+ int FRMIdx = RISCVII::getFRMOpNum(MI.getDesc());
+ if (FRMIdx < 0)
+ continue;
- BuildMI(MBB, MI, MI.getDebugLoc(), TII->get(RISCV::WriteVXRMImm))
- .addImm(VXRMImm);
- MI.addOperand(MachineOperand::CreateReg(RISCV::VXRM, /*IsDef*/ false,
- /*IsImp*/ true));
- } else { // FRM
- unsigned FRMImm = MI.getOperand(*RoundModeIdx).getImm();
+ unsigned FRMImm = MI.getOperand(FRMIdx).getImm();
- // The value is a hint to this pass to not alter the frm value.
- if (FRMImm == RISCVFPRndMode::DYN)
- continue;
+ // The value is a hint to this pass to not alter the frm value.
+ if (FRMImm == RISCVFPRndMode::DYN)
+ continue;
- Changed = true;
+ Changed = true;
- // Save
- MachineRegisterInfo *MRI = &MBB.getParent()->getRegInfo();
- Register SavedFRM = MRI->createVirtualRegister(&RISCV::GPRRegClass);
- BuildMI(MBB, MI, MI.getDebugLoc(), TII->get(RISCV::SwapFRMImm),
- SavedFRM)
- .addImm(FRMImm);
- MI.addOperand(MachineOperand::CreateReg(RISCV::FRM, /*IsDef*/ false,
- /*IsImp*/ true));
- // Restore
- MachineInstrBuilder MIB =
- BuildMI(*MBB.getParent(), {}, TII->get(RISCV::WriteFRM))
- .addReg(SavedFRM);
- MBB.insertAfter(MI, MIB);
- }
- }
+ // Save
+ MachineRegisterInfo *MRI = &MBB.getParent()->getRegInfo();
+ Register SavedFRM = MRI->createVirtualRegister(&RISCV::GPRRegClass);
+ BuildMI(MBB, MI, MI.getDebugLoc(), TII->get(RISCV::SwapFRMImm),
+ SavedFRM)
+ .addImm(FRMImm);
+ MI.addOperand(MachineOperand::CreateReg(RISCV::FRM, /*IsDef*/ false,
+ /*IsImp*/ true));
+ // Restore
+ MachineInstrBuilder MIB =
+ BuildMI(*MBB.getParent(), {}, TII->get(RISCV::WriteFRM))
+ .addReg(SavedFRM);
+ MBB.insertAfter(MI, MIB);
}
return Changed;
}