aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Target/RISCV/RISCVOptWInstrs.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Target/RISCV/RISCVOptWInstrs.cpp')
-rw-r--r--llvm/lib/Target/RISCV/RISCVOptWInstrs.cpp101
1 files changed, 80 insertions, 21 deletions
diff --git a/llvm/lib/Target/RISCV/RISCVOptWInstrs.cpp b/llvm/lib/Target/RISCV/RISCVOptWInstrs.cpp
index 7014755b6706..2c2b34bb5b77 100644
--- a/llvm/lib/Target/RISCV/RISCVOptWInstrs.cpp
+++ b/llvm/lib/Target/RISCV/RISCVOptWInstrs.cpp
@@ -12,16 +12,21 @@
// extended bits aren't consumed or because the input was already sign extended
// by an earlier instruction.
//
-// Then it removes the -w suffix from each addiw and slliw instructions
-// whenever all users are dependent only on the lower word of the result of the
-// instruction. We do this only for addiw, slliw, and mulw because the -w forms
-// are less compressible.
+// Then it removes the -w suffix from opw instructions whenever all users are
+// dependent only on the lower word of the result of the instruction.
+// The cases handled are:
+// * addw because c.add has a larger register encoding than c.addw.
+// * addiw because it helps reduce test differences between RV32 and RV64
+// without being a pessimization.
+// * mulw because c.mulw doesn't exist but c.mul does (with Zcb).
+// * slliw because c.slliw doesn't exist and c.slli does.
//
//===---------------------------------------------------------------------===//
#include "RISCV.h"
#include "RISCVMachineFunctionInfo.h"
#include "RISCVSubtarget.h"
+#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/CodeGen/TargetInstrInfo.h"
@@ -48,9 +53,7 @@ class RISCVOptWInstrs : public MachineFunctionPass {
public:
static char ID;
- RISCVOptWInstrs() : MachineFunctionPass(ID) {
- initializeRISCVOptWInstrsPass(*PassRegistry::getPassRegistry());
- }
+ RISCVOptWInstrs() : MachineFunctionPass(ID) {}
bool runOnMachineFunction(MachineFunction &MF) override;
bool removeSExtWInstrs(MachineFunction &MF, const RISCVInstrInfo &TII,
@@ -76,6 +79,29 @@ FunctionPass *llvm::createRISCVOptWInstrsPass() {
return new RISCVOptWInstrs();
}
+static bool vectorPseudoHasAllNBitUsers(const MachineOperand &UserOp,
+ unsigned Bits) { // True if UserOp's instruction demands at most the low Bits bits of it.
+ const MachineInstr &MI = *UserOp.getParent();
+ unsigned MCOpcode = RISCV::getRVVMCOpcode(MI.getOpcode());
+ // A zero MCOpcode is treated as "not analyzable" (presumably not an RVV pseudo -- confirm).
+ if (!MCOpcode)
+ return false;
+ // Log2SEW is read from the SEW operand below, so users without one are rejected.
+ const MCInstrDesc &MCID = MI.getDesc();
+ const uint64_t TSFlags = MCID.TSFlags;
+ if (!RISCVII::hasSEWOp(TSFlags))
+ return false;
+ assert(RISCVII::hasVLOp(TSFlags)); // a pseudo with a SEW operand is expected to have a VL operand too
+ const unsigned Log2SEW = MI.getOperand(RISCVII::getSEWOpNum(MCID)).getImm();
+ // A use in the VL operand position is never accepted as a low-bits-only use.
+ if (UserOp.getOperandNo() == RISCVII::getVLOpNum(MCID))
+ return false;
+ // Accept only when a demanded-bit count is available and it does not exceed Bits.
+ auto NumDemandedBits =
+ RISCV::getVectorLowDemandedScalarBits(MCOpcode, Log2SEW);
+ return NumDemandedBits && Bits >= *NumDemandedBits;
+}
+
// Checks if all users only demand the lower \p OrigBits of the original
// instruction's result.
// TODO: handle multiple interdependent transformations
@@ -100,12 +126,14 @@ static bool hasAllNBitUsers(const MachineInstr &OrigMI,
if (MI->getNumExplicitDefs() != 1)
return false;
- for (auto &UserOp : MRI.use_operands(MI->getOperand(0).getReg())) {
+ for (auto &UserOp : MRI.use_nodbg_operands(MI->getOperand(0).getReg())) {
const MachineInstr *UserMI = UserOp.getParent();
unsigned OpIdx = UserOp.getOperandNo();
switch (UserMI->getOpcode()) {
default:
+ if (vectorPseudoHasAllNBitUsers(UserOp, Bits))
+ break;
return false;
case RISCV::ADDIW:
@@ -283,6 +311,8 @@ static bool hasAllNBitUsers(const MachineInstr &OrigMI,
Worklist.push_back(std::make_pair(UserMI, Bits));
break;
+ case RISCV::CZERO_EQZ:
+ case RISCV::CZERO_NEZ:
case RISCV::VT_MASKC:
case RISCV::VT_MASKCN:
if (OpIdx != 1)
@@ -327,9 +357,27 @@ static bool isSignExtendingOpW(const MachineInstr &MI,
// An ORI with an >11 bit immediate (negative 12-bit) will set bits 63:11.
case RISCV::ORI:
return !isUInt<11>(MI.getOperand(2).getImm());
+ // A bseti with X0 is sign extended if the immediate is less than 31.
+ case RISCV::BSETI:
+ return MI.getOperand(2).getImm() < 31 &&
+ MI.getOperand(1).getReg() == RISCV::X0;
// Copying from X0 produces zero.
case RISCV::COPY:
return MI.getOperand(1).getReg() == RISCV::X0;
+ case RISCV::PseudoAtomicLoadNand32:
+ return true;
+ case RISCV::PseudoVMV_X_S_MF8:
+ case RISCV::PseudoVMV_X_S_MF4:
+ case RISCV::PseudoVMV_X_S_MF2:
+ case RISCV::PseudoVMV_X_S_M1:
+ case RISCV::PseudoVMV_X_S_M2:
+ case RISCV::PseudoVMV_X_S_M4:
+ case RISCV::PseudoVMV_X_S_M8: {
+ // vmv.x.s has at least 33 sign bits if log2(sew) <= 5.
+ int64_t Log2SEW = MI.getOperand(2).getImm();
+ assert(Log2SEW >= 3 && Log2SEW <= 6 && "Unexpected Log2SEW");
+ return Log2SEW <= 5;
+ }
}
return false;
@@ -348,6 +396,11 @@ static bool isSignExtendedW(Register SrcReg, const RISCVSubtarget &ST,
MachineInstr *SrcMI = MRI.getVRegDef(SrcReg);
if (!SrcMI)
return false;
+ // Code assumes the register is operand 0.
+ // TODO: Maybe the worklist should store register?
+ if (!SrcMI->getOperand(0).isReg() ||
+ SrcMI->getOperand(0).getReg() != SrcReg)
+ return false;
// Add SrcMI to the worklist.
Worklist.push_back(SrcMI);
return true;
@@ -446,9 +499,16 @@ static bool isSignExtendedW(Register SrcReg, const RISCVSubtarget &ST,
break;
case RISCV::PseudoCCADDW:
+ case RISCV::PseudoCCADDIW:
case RISCV::PseudoCCSUBW:
- // Returns operand 4 or an ADDW/SUBW of operands 5 and 6. We only need to
- // check if operand 4 is sign extended.
+ case RISCV::PseudoCCSLLW:
+ case RISCV::PseudoCCSRLW:
+ case RISCV::PseudoCCSRAW:
+ case RISCV::PseudoCCSLLIW:
+ case RISCV::PseudoCCSRLIW:
+ case RISCV::PseudoCCSRAIW:
+ // Returns operand 4 or an ADDW/SUBW/etc. of operands 5 and 6. We only
+ // need to check if operand 4 is sign extended.
if (!AddRegDefToWorkList(MI->getOperand(4).getReg()))
return false;
break;
@@ -504,6 +564,8 @@ static bool isSignExtendedW(Register SrcReg, const RISCVSubtarget &ST,
break;
}
+ case RISCV::CZERO_EQZ:
+ case RISCV::CZERO_NEZ:
case RISCV::VT_MASKC:
case RISCV::VT_MASKCN:
// Instructions return zero or operand 1. Result is sign extended if
@@ -567,25 +629,23 @@ bool RISCVOptWInstrs::removeSExtWInstrs(MachineFunction &MF,
bool MadeChange = false;
for (MachineBasicBlock &MBB : MF) {
- for (auto I = MBB.begin(), IE = MBB.end(); I != IE;) {
- MachineInstr *MI = &*I++;
-
+ for (MachineInstr &MI : llvm::make_early_inc_range(MBB)) {
// We're looking for the sext.w pattern ADDIW rd, rs1, 0.
- if (!RISCV::isSEXT_W(*MI))
+ if (!RISCV::isSEXT_W(MI))
continue;
- Register SrcReg = MI->getOperand(1).getReg();
+ Register SrcReg = MI.getOperand(1).getReg();
SmallPtrSet<MachineInstr *, 4> FixableDefs;
// If all users only use the lower bits, this sext.w is redundant.
// Or if all definitions reaching MI sign-extend their output,
// then sext.w is redundant.
- if (!hasAllWUsers(*MI, ST, MRI) &&
+ if (!hasAllWUsers(MI, ST, MRI) &&
!isSignExtendedW(SrcReg, ST, MRI, FixableDefs))
continue;
- Register DstReg = MI->getOperand(0).getReg();
+ Register DstReg = MI.getOperand(0).getReg();
if (!MRI.constrainRegClass(SrcReg, MRI.getRegClass(DstReg)))
continue;
@@ -603,7 +663,7 @@ bool RISCVOptWInstrs::removeSExtWInstrs(MachineFunction &MF,
LLVM_DEBUG(dbgs() << "Removing redundant sign-extension\n");
MRI.replaceRegWith(DstReg, SrcReg);
MRI.clearKillFlags(SrcReg);
- MI->eraseFromParent();
+ MI.eraseFromParent();
++NumRemovedSExtW;
MadeChange = true;
}
@@ -621,14 +681,13 @@ bool RISCVOptWInstrs::stripWSuffixes(MachineFunction &MF,
bool MadeChange = false;
for (MachineBasicBlock &MBB : MF) {
- for (auto I = MBB.begin(), IE = MBB.end(); I != IE; ++I) {
- MachineInstr &MI = *I;
-
+ for (MachineInstr &MI : MBB) {
unsigned Opc;
switch (MI.getOpcode()) {
default:
continue;
case RISCV::ADDW: Opc = RISCV::ADD; break;
+ case RISCV::ADDIW: Opc = RISCV::ADDI; break;
case RISCV::MULW: Opc = RISCV::MUL; break;
case RISCV::SLLIW: Opc = RISCV::SLLI; break;
}