diff options
Diffstat (limited to 'llvm/lib/Target/RISCV/RISCVOptWInstrs.cpp')
| -rw-r--r-- | llvm/lib/Target/RISCV/RISCVOptWInstrs.cpp | 101 |
1 file changed, 80 insertions, 21 deletions
diff --git a/llvm/lib/Target/RISCV/RISCVOptWInstrs.cpp b/llvm/lib/Target/RISCV/RISCVOptWInstrs.cpp index 7014755b6706..2c2b34bb5b77 100644 --- a/llvm/lib/Target/RISCV/RISCVOptWInstrs.cpp +++ b/llvm/lib/Target/RISCV/RISCVOptWInstrs.cpp @@ -12,16 +12,21 @@ // extended bits aren't consumed or because the input was already sign extended // by an earlier instruction. // -// Then it removes the -w suffix from each addiw and slliw instructions -// whenever all users are dependent only on the lower word of the result of the -// instruction. We do this only for addiw, slliw, and mulw because the -w forms -// are less compressible. +// Then it removes the -w suffix from opw instructions whenever all users are +// dependent only on the lower word of the result of the instruction. +// The cases handled are: +// * addw because c.add has a larger register encoding than c.addw. +// * addiw because it helps reduce test differences between RV32 and RV64 +// w/o being a pessimization. +// * mulw because c.mulw doesn't exist but c.mul does (w/ zcb) +// * slliw because c.slliw doesn't exist and c.slli does // //===---------------------------------------------------------------------===// #include "RISCV.h" #include "RISCVMachineFunctionInfo.h" #include "RISCVSubtarget.h" +#include "llvm/ADT/SmallSet.h" #include "llvm/ADT/Statistic.h" #include "llvm/CodeGen/MachineFunctionPass.h" #include "llvm/CodeGen/TargetInstrInfo.h" @@ -48,9 +53,7 @@ class RISCVOptWInstrs : public MachineFunctionPass { public: static char ID; - RISCVOptWInstrs() : MachineFunctionPass(ID) { - initializeRISCVOptWInstrsPass(*PassRegistry::getPassRegistry()); - } + RISCVOptWInstrs() : MachineFunctionPass(ID) {} bool runOnMachineFunction(MachineFunction &MF) override; bool removeSExtWInstrs(MachineFunction &MF, const RISCVInstrInfo &TII, @@ -76,6 +79,29 @@ FunctionPass *llvm::createRISCVOptWInstrsPass() { return new RISCVOptWInstrs(); } +static bool vectorPseudoHasAllNBitUsers(const MachineOperand &UserOp, + unsigned 
Bits) { + const MachineInstr &MI = *UserOp.getParent(); + unsigned MCOpcode = RISCV::getRVVMCOpcode(MI.getOpcode()); + + if (!MCOpcode) + return false; + + const MCInstrDesc &MCID = MI.getDesc(); + const uint64_t TSFlags = MCID.TSFlags; + if (!RISCVII::hasSEWOp(TSFlags)) + return false; + assert(RISCVII::hasVLOp(TSFlags)); + const unsigned Log2SEW = MI.getOperand(RISCVII::getSEWOpNum(MCID)).getImm(); + + if (UserOp.getOperandNo() == RISCVII::getVLOpNum(MCID)) + return false; + + auto NumDemandedBits = + RISCV::getVectorLowDemandedScalarBits(MCOpcode, Log2SEW); + return NumDemandedBits && Bits >= *NumDemandedBits; +} + // Checks if all users only demand the lower \p OrigBits of the original // instruction's result. // TODO: handle multiple interdependent transformations @@ -100,12 +126,14 @@ static bool hasAllNBitUsers(const MachineInstr &OrigMI, if (MI->getNumExplicitDefs() != 1) return false; - for (auto &UserOp : MRI.use_operands(MI->getOperand(0).getReg())) { + for (auto &UserOp : MRI.use_nodbg_operands(MI->getOperand(0).getReg())) { const MachineInstr *UserMI = UserOp.getParent(); unsigned OpIdx = UserOp.getOperandNo(); switch (UserMI->getOpcode()) { default: + if (vectorPseudoHasAllNBitUsers(UserOp, Bits)) + break; return false; case RISCV::ADDIW: @@ -283,6 +311,8 @@ static bool hasAllNBitUsers(const MachineInstr &OrigMI, Worklist.push_back(std::make_pair(UserMI, Bits)); break; + case RISCV::CZERO_EQZ: + case RISCV::CZERO_NEZ: case RISCV::VT_MASKC: case RISCV::VT_MASKCN: if (OpIdx != 1) @@ -327,9 +357,27 @@ static bool isSignExtendingOpW(const MachineInstr &MI, // An ORI with an >11 bit immediate (negative 12-bit) will set bits 63:11. case RISCV::ORI: return !isUInt<11>(MI.getOperand(2).getImm()); + // A bseti with X0 is sign extended if the immediate is less than 31. + case RISCV::BSETI: + return MI.getOperand(2).getImm() < 31 && + MI.getOperand(1).getReg() == RISCV::X0; // Copying from X0 produces zero. 
case RISCV::COPY: return MI.getOperand(1).getReg() == RISCV::X0; + case RISCV::PseudoAtomicLoadNand32: + return true; + case RISCV::PseudoVMV_X_S_MF8: + case RISCV::PseudoVMV_X_S_MF4: + case RISCV::PseudoVMV_X_S_MF2: + case RISCV::PseudoVMV_X_S_M1: + case RISCV::PseudoVMV_X_S_M2: + case RISCV::PseudoVMV_X_S_M4: + case RISCV::PseudoVMV_X_S_M8: { + // vmv.x.s has at least 33 sign bits if log2(sew) <= 5. + int64_t Log2SEW = MI.getOperand(2).getImm(); + assert(Log2SEW >= 3 && Log2SEW <= 6 && "Unexpected Log2SEW"); + return Log2SEW <= 5; + } } return false; @@ -348,6 +396,11 @@ static bool isSignExtendedW(Register SrcReg, const RISCVSubtarget &ST, MachineInstr *SrcMI = MRI.getVRegDef(SrcReg); if (!SrcMI) return false; + // Code assumes the register is operand 0. + // TODO: Maybe the worklist should store register? + if (!SrcMI->getOperand(0).isReg() || + SrcMI->getOperand(0).getReg() != SrcReg) + return false; // Add SrcMI to the worklist. Worklist.push_back(SrcMI); return true; @@ -446,9 +499,16 @@ static bool isSignExtendedW(Register SrcReg, const RISCVSubtarget &ST, break; case RISCV::PseudoCCADDW: + case RISCV::PseudoCCADDIW: case RISCV::PseudoCCSUBW: - // Returns operand 4 or an ADDW/SUBW of operands 5 and 6. We only need to - // check if operand 4 is sign extended. + case RISCV::PseudoCCSLLW: + case RISCV::PseudoCCSRLW: + case RISCV::PseudoCCSRAW: + case RISCV::PseudoCCSLLIW: + case RISCV::PseudoCCSRLIW: + case RISCV::PseudoCCSRAIW: + // Returns operand 4 or an ADDW/SUBW/etc. of operands 5 and 6. We only + // need to check if operand 4 is sign extended. if (!AddRegDefToWorkList(MI->getOperand(4).getReg())) return false; break; @@ -504,6 +564,8 @@ static bool isSignExtendedW(Register SrcReg, const RISCVSubtarget &ST, break; } + case RISCV::CZERO_EQZ: + case RISCV::CZERO_NEZ: case RISCV::VT_MASKC: case RISCV::VT_MASKCN: // Instructions return zero or operand 1. 
Result is sign extended if @@ -567,25 +629,23 @@ bool RISCVOptWInstrs::removeSExtWInstrs(MachineFunction &MF, bool MadeChange = false; for (MachineBasicBlock &MBB : MF) { - for (auto I = MBB.begin(), IE = MBB.end(); I != IE;) { - MachineInstr *MI = &*I++; - + for (MachineInstr &MI : llvm::make_early_inc_range(MBB)) { // We're looking for the sext.w pattern ADDIW rd, rs1, 0. - if (!RISCV::isSEXT_W(*MI)) + if (!RISCV::isSEXT_W(MI)) continue; - Register SrcReg = MI->getOperand(1).getReg(); + Register SrcReg = MI.getOperand(1).getReg(); SmallPtrSet<MachineInstr *, 4> FixableDefs; // If all users only use the lower bits, this sext.w is redundant. // Or if all definitions reaching MI sign-extend their output, // then sext.w is redundant. - if (!hasAllWUsers(*MI, ST, MRI) && + if (!hasAllWUsers(MI, ST, MRI) && !isSignExtendedW(SrcReg, ST, MRI, FixableDefs)) continue; - Register DstReg = MI->getOperand(0).getReg(); + Register DstReg = MI.getOperand(0).getReg(); if (!MRI.constrainRegClass(SrcReg, MRI.getRegClass(DstReg))) continue; @@ -603,7 +663,7 @@ bool RISCVOptWInstrs::removeSExtWInstrs(MachineFunction &MF, LLVM_DEBUG(dbgs() << "Removing redundant sign-extension\n"); MRI.replaceRegWith(DstReg, SrcReg); MRI.clearKillFlags(SrcReg); - MI->eraseFromParent(); + MI.eraseFromParent(); ++NumRemovedSExtW; MadeChange = true; } @@ -621,14 +681,13 @@ bool RISCVOptWInstrs::stripWSuffixes(MachineFunction &MF, bool MadeChange = false; for (MachineBasicBlock &MBB : MF) { - for (auto I = MBB.begin(), IE = MBB.end(); I != IE; ++I) { - MachineInstr &MI = *I; - + for (MachineInstr &MI : MBB) { unsigned Opc; switch (MI.getOpcode()) { default: continue; case RISCV::ADDW: Opc = RISCV::ADD; break; + case RISCV::ADDIW: Opc = RISCV::ADDI; break; case RISCV::MULW: Opc = RISCV::MUL; break; case RISCV::SLLIW: Opc = RISCV::SLLI; break; } |
