aboutsummaryrefslogtreecommitdiff
path: root/contrib/llvm-project/llvm/lib/Target/RISCV/RISCVSExtWRemoval.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'contrib/llvm-project/llvm/lib/Target/RISCV/RISCVSExtWRemoval.cpp')
-rw-r--r--contrib/llvm-project/llvm/lib/Target/RISCV/RISCVSExtWRemoval.cpp275
1 files changed, 249 insertions, 26 deletions
diff --git a/contrib/llvm-project/llvm/lib/Target/RISCV/RISCVSExtWRemoval.cpp b/contrib/llvm-project/llvm/lib/Target/RISCV/RISCVSExtWRemoval.cpp
index 715d92b036e3..dadf8f81a2c0 100644
--- a/contrib/llvm-project/llvm/lib/Target/RISCV/RISCVSExtWRemoval.cpp
+++ b/contrib/llvm-project/llvm/lib/Target/RISCV/RISCVSExtWRemoval.cpp
@@ -21,6 +21,8 @@ using namespace llvm;
#define DEBUG_TYPE "riscv-sextw-removal"
STATISTIC(NumRemovedSExtW, "Number of removed sign-extensions");
+STATISTIC(NumTransformedToWInstrs,
+ "Number of instructions transformed to W-ops");
static cl::opt<bool> DisableSExtWRemoval("riscv-disable-sextw-removal",
cl::desc("Disable removal of sext.w"),
@@ -55,11 +57,143 @@ FunctionPass *llvm::createRISCVSExtWRemovalPass() {
return new RISCVSExtWRemoval();
}
+// add uses of MI to the Worklist
+static void addUses(const MachineInstr &MI,
+ SmallVectorImpl<const MachineInstr *> &Worklist,
+ MachineRegisterInfo &MRI) {
+ for (auto &UserOp : MRI.reg_operands(MI.getOperand(0).getReg())) {
+ const auto *User = UserOp.getParent();
+ if (User == &MI) // ignore the def, current MI
+ continue;
+ Worklist.push_back(User);
+ }
+}
+
+// returns true if all uses of OrigMI only depend on the lower word of its
+// output, so we can transform OrigMI to the corresponding W-version.
+// TODO: handle multiple interdependent transformations
+static bool isAllUsesReadW(const MachineInstr &OrigMI,
+ MachineRegisterInfo &MRI) {
+
+ SmallPtrSet<const MachineInstr *, 4> Visited;
+ SmallVector<const MachineInstr *, 4> Worklist;
+
+ Visited.insert(&OrigMI);
+ addUses(OrigMI, Worklist, MRI);
+
+ while (!Worklist.empty()) {
+ const MachineInstr *MI = Worklist.pop_back_val();
+
+ if (!Visited.insert(MI).second) {
+ // If we've looped back to OrigMI through a PHI cycle, we can't transform
+ // LD or LWU, because these operations use all 64 bits of input.
+ if (MI == &OrigMI) {
+ unsigned opcode = MI->getOpcode();
+ if (opcode == RISCV::LD || opcode == RISCV::LWU)
+ return false;
+ }
+ continue;
+ }
+
+ switch (MI->getOpcode()) {
+ case RISCV::ADDIW:
+ case RISCV::ADDW:
+ case RISCV::DIVUW:
+ case RISCV::DIVW:
+ case RISCV::MULW:
+ case RISCV::REMUW:
+ case RISCV::REMW:
+ case RISCV::SLLIW:
+ case RISCV::SLLW:
+ case RISCV::SRAIW:
+ case RISCV::SRAW:
+ case RISCV::SRLIW:
+ case RISCV::SRLW:
+ case RISCV::SUBW:
+ case RISCV::ROLW:
+ case RISCV::RORW:
+ case RISCV::RORIW:
+ case RISCV::CLZW:
+ case RISCV::CTZW:
+ case RISCV::CPOPW:
+ case RISCV::SLLI_UW:
+ case RISCV::FCVT_S_W:
+ case RISCV::FCVT_S_WU:
+ case RISCV::FCVT_D_W:
+ case RISCV::FCVT_D_WU:
+ continue;
+
+ // these overwrite higher input bits, otherwise the lower word of output
+ // depends only on the lower word of input. So check their uses read W.
+ case RISCV::SLLI:
+ if (MI->getOperand(2).getImm() >= 32)
+ continue;
+ addUses(*MI, Worklist, MRI);
+ continue;
+ case RISCV::ANDI:
+ if (isUInt<11>(MI->getOperand(2).getImm()))
+ continue;
+ addUses(*MI, Worklist, MRI);
+ continue;
+ case RISCV::ORI:
+ if (!isUInt<11>(MI->getOperand(2).getImm()))
+ continue;
+ addUses(*MI, Worklist, MRI);
+ continue;
+
+ case RISCV::BEXTI:
+ if (MI->getOperand(2).getImm() >= 32)
+ return false;
+ continue;
+
+ // For these, lower word of output in these operations, depends only on
+ // the lower word of input. So, we check all uses only read lower word.
+ case RISCV::COPY:
+ case RISCV::PHI:
+
+ case RISCV::ADD:
+ case RISCV::ADDI:
+ case RISCV::AND:
+ case RISCV::MUL:
+ case RISCV::OR:
+ case RISCV::SLL:
+ case RISCV::SUB:
+ case RISCV::XOR:
+ case RISCV::XORI:
+
+ case RISCV::ADD_UW:
+ case RISCV::ANDN:
+ case RISCV::CLMUL:
+ case RISCV::ORC_B:
+ case RISCV::ORN:
+ case RISCV::SEXT_B:
+ case RISCV::SEXT_H:
+ case RISCV::SH1ADD:
+ case RISCV::SH1ADD_UW:
+ case RISCV::SH2ADD:
+ case RISCV::SH2ADD_UW:
+ case RISCV::SH3ADD:
+ case RISCV::SH3ADD_UW:
+ case RISCV::XNOR:
+ case RISCV::ZEXT_H_RV64:
+ addUses(*MI, Worklist, MRI);
+ continue;
+ default:
+ return false;
+ }
+ }
+ return true;
+}
+
// This function returns true if the machine instruction always outputs a value
// where bits 63:32 match bit 31.
+// Alternatively, if the instruction can be converted to W variant
+// (e.g. ADD->ADDW) and all of its uses only use the lower word of its output,
+// then return true and add the instr to FixableDef to be convereted later
// TODO: Allocate a bit in TSFlags for the W instructions?
// TODO: Add other W instructions.
-static bool isSignExtendingOpW(const MachineInstr &MI) {
+static bool isSignExtendingOpW(MachineInstr &MI, MachineRegisterInfo &MRI,
+ SmallPtrSetImpl<MachineInstr *> &FixableDef) {
switch (MI.getOpcode()) {
case RISCV::LUI:
case RISCV::LW:
@@ -89,8 +223,9 @@ static bool isSignExtendingOpW(const MachineInstr &MI) {
case RISCV::FCVT_WU_S:
case RISCV::FCVT_W_D:
case RISCV::FCVT_WU_D:
+ case RISCV::FMV_X_W:
// The following aren't W instructions, but are either sign extended from a
- // smaller size or put zeros in bits 63:31.
+ // smaller size, always outputs a small integer, or put zeros in bits 63:31.
case RISCV::LBU:
case RISCV::LHU:
case RISCV::LB:
@@ -102,6 +237,12 @@ static bool isSignExtendingOpW(const MachineInstr &MI) {
case RISCV::SEXT_B:
case RISCV::SEXT_H:
case RISCV::ZEXT_H_RV64:
+ case RISCV::FMV_X_H:
+ case RISCV::BEXT:
+ case RISCV::BEXTI:
+ case RISCV::CLZ:
+ case RISCV::CPOP:
+ case RISCV::CTZ:
return true;
// shifting right sufficiently makes the value 32-bit sign-extended
case RISCV::SRAI:
@@ -110,7 +251,14 @@ static bool isSignExtendingOpW(const MachineInstr &MI) {
return MI.getOperand(2).getImm() > 32;
// The LI pattern ADDI rd, X0, imm is sign extended.
case RISCV::ADDI:
- return MI.getOperand(1).isReg() && MI.getOperand(1).getReg() == RISCV::X0;
+ if (MI.getOperand(1).isReg() && MI.getOperand(1).getReg() == RISCV::X0)
+ return true;
+ if (isAllUsesReadW(MI, MRI)) {
+ // transform to ADDIW
+ FixableDef.insert(&MI);
+ return true;
+ }
+ return false;
// An ANDI with an 11 bit immediate will zero bits 63:11.
case RISCV::ANDI:
return isUInt<11>(MI.getOperand(2).getImm());
@@ -120,28 +268,45 @@ static bool isSignExtendingOpW(const MachineInstr &MI) {
// Copying from X0 produces zero.
case RISCV::COPY:
return MI.getOperand(1).getReg() == RISCV::X0;
+
+ // With these opcode, we can "fix" them with the W-version
+ // if we know all users of the result only rely on bits 31:0
+ case RISCV::SLLI:
+ // SLLIW reads the lowest 5 bits, while SLLI reads lowest 6 bits
+ if (MI.getOperand(2).getImm() >= 32)
+ return false;
+ LLVM_FALLTHROUGH;
+ case RISCV::ADD:
+ case RISCV::LD:
+ case RISCV::LWU:
+ case RISCV::MUL:
+ case RISCV::SUB:
+ if (isAllUsesReadW(MI, MRI)) {
+ FixableDef.insert(&MI);
+ return true;
+ }
}
return false;
}
-static bool isSignExtendedW(const MachineInstr &OrigMI,
- MachineRegisterInfo &MRI) {
+static bool isSignExtendedW(MachineInstr &OrigMI, MachineRegisterInfo &MRI,
+ SmallPtrSetImpl<MachineInstr *> &FixableDef) {
SmallPtrSet<const MachineInstr *, 4> Visited;
- SmallVector<const MachineInstr *, 4> Worklist;
+ SmallVector<MachineInstr *, 4> Worklist;
Worklist.push_back(&OrigMI);
while (!Worklist.empty()) {
- const MachineInstr *MI = Worklist.pop_back_val();
+ MachineInstr *MI = Worklist.pop_back_val();
// If we already visited this instruction, we don't need to check it again.
if (!Visited.insert(MI).second)
continue;
// If this is a sign extending operation we don't need to look any further.
- if (isSignExtendingOpW(*MI))
+ if (isSignExtendingOpW(*MI, MRI, FixableDef))
continue;
// Is this an instruction that propagates sign extend.
@@ -157,7 +322,7 @@ static bool isSignExtendedW(const MachineInstr &OrigMI,
// If this is a copy from another register, check its source instruction.
if (!SrcReg.isVirtual())
return false;
- const MachineInstr *SrcMI = MRI.getVRegDef(SrcReg);
+ MachineInstr *SrcMI = MRI.getVRegDef(SrcReg);
if (!SrcMI)
return false;
@@ -165,18 +330,25 @@ static bool isSignExtendedW(const MachineInstr &OrigMI,
Worklist.push_back(SrcMI);
break;
}
+
+ // For these, we just need to check if the 1st operand is sign extended.
+ case RISCV::BCLRI:
+ case RISCV::BINVI:
+ case RISCV::BSETI:
+ if (MI->getOperand(2).getImm() >= 31)
+ return false;
+ LLVM_FALLTHROUGH;
case RISCV::REM:
case RISCV::ANDI:
case RISCV::ORI:
case RISCV::XORI: {
// |Remainder| is always <= |Dividend|. If D is 32-bit, then so is R.
// DIV doesn't work because of the edge case 0xf..f 8000 0000 / (long)-1
- // Logical operations use a sign extended 12-bit immediate. We just need
- // to check if the other operand is sign extended.
+ // Logical operations use a sign extended 12-bit immediate.
Register SrcReg = MI->getOperand(1).getReg();
if (!SrcReg.isVirtual())
return false;
- const MachineInstr *SrcMI = MRI.getVRegDef(SrcReg);
+ MachineInstr *SrcMI = MRI.getVRegDef(SrcReg);
if (!SrcMI)
return false;
@@ -214,7 +386,7 @@ static bool isSignExtendedW(const MachineInstr &OrigMI,
Register SrcReg = MI->getOperand(I).getReg();
if (!SrcReg.isVirtual())
return false;
- const MachineInstr *SrcMI = MRI.getVRegDef(SrcReg);
+ MachineInstr *SrcMI = MRI.getVRegDef(SrcReg);
if (!SrcMI)
return false;
@@ -232,6 +404,26 @@ static bool isSignExtendedW(const MachineInstr &OrigMI,
return true;
}
+static unsigned getWOp(unsigned Opcode) {
+ switch (Opcode) {
+ case RISCV::ADDI:
+ return RISCV::ADDIW;
+ case RISCV::ADD:
+ return RISCV::ADDW;
+ case RISCV::LD:
+ case RISCV::LWU:
+ return RISCV::LW;
+ case RISCV::MUL:
+ return RISCV::MULW;
+ case RISCV::SLLI:
+ return RISCV::SLLIW;
+ case RISCV::SUB:
+ return RISCV::SUBW;
+ default:
+ llvm_unreachable("Unexpected opcode for replacement with W variant");
+ }
+}
+
bool RISCVSExtWRemoval::runOnMachineFunction(MachineFunction &MF) {
if (skipFunction(MF.getFunction()) || DisableSExtWRemoval)
return false;
@@ -242,7 +434,10 @@ bool RISCVSExtWRemoval::runOnMachineFunction(MachineFunction &MF) {
if (!ST.is64Bit())
return false;
- bool MadeChange = false;
+ SmallPtrSet<MachineInstr *, 4> SExtWRemovalCands;
+
+ // Replacing instructions invalidates the MI iterator
+ // we collect the candidates, then iterate over them separately.
for (MachineBasicBlock &MBB : MF) {
for (auto I = MBB.begin(), IE = MBB.end(); I != IE;) {
MachineInstr *MI = &*I++;
@@ -257,21 +452,49 @@ bool RISCVSExtWRemoval::runOnMachineFunction(MachineFunction &MF) {
if (!SrcReg.isVirtual())
continue;
- const MachineInstr &SrcMI = *MRI.getVRegDef(SrcReg);
- if (!isSignExtendedW(SrcMI, MRI))
- continue;
+ SExtWRemovalCands.insert(MI);
+ }
+ }
- Register DstReg = MI->getOperand(0).getReg();
- if (!MRI.constrainRegClass(SrcReg, MRI.getRegClass(DstReg)))
- continue;
+ bool MadeChange = false;
+ for (auto MI : SExtWRemovalCands) {
+ SmallPtrSet<MachineInstr *, 4> FixableDef;
+ Register SrcReg = MI->getOperand(1).getReg();
+ MachineInstr &SrcMI = *MRI.getVRegDef(SrcReg);
+
+ // If all definitions reaching MI sign-extend their output,
+ // then sext.w is redundant
+ if (!isSignExtendedW(SrcMI, MRI, FixableDef))
+ continue;
- LLVM_DEBUG(dbgs() << "Removing redundant sign-extension\n");
- MRI.replaceRegWith(DstReg, SrcReg);
- MRI.clearKillFlags(SrcReg);
- MI->eraseFromParent();
- ++NumRemovedSExtW;
- MadeChange = true;
+ Register DstReg = MI->getOperand(0).getReg();
+ if (!MRI.constrainRegClass(SrcReg, MRI.getRegClass(DstReg)))
+ continue;
+ // Replace Fixable instructions with their W versions.
+ for (MachineInstr *Fixable : FixableDef) {
+ MachineBasicBlock &MBB = *Fixable->getParent();
+ const DebugLoc &DL = Fixable->getDebugLoc();
+ unsigned Code = getWOp(Fixable->getOpcode());
+ MachineInstrBuilder Replacement =
+ BuildMI(MBB, Fixable, DL, ST.getInstrInfo()->get(Code));
+ for (auto Op : Fixable->operands())
+ Replacement.add(Op);
+ for (auto Op : Fixable->memoperands())
+ Replacement.addMemOperand(Op);
+
+ LLVM_DEBUG(dbgs() << "Replacing " << *Fixable);
+ LLVM_DEBUG(dbgs() << " with " << *Replacement);
+
+ Fixable->eraseFromParent();
+ ++NumTransformedToWInstrs;
}
+
+ LLVM_DEBUG(dbgs() << "Removing redundant sign-extension\n");
+ MRI.replaceRegWith(DstReg, SrcReg);
+ MRI.clearKillFlags(SrcReg);
+ MI->eraseFromParent();
+ ++NumRemovedSExtW;
+ MadeChange = true;
}
return MadeChange;