diff options
Diffstat (limited to 'llvm/lib/Target/RISCV/RISCVFoldMasks.cpp')
| -rw-r--r-- | llvm/lib/Target/RISCV/RISCVFoldMasks.cpp | 216 |
1 files changed, 216 insertions, 0 deletions
diff --git a/llvm/lib/Target/RISCV/RISCVFoldMasks.cpp b/llvm/lib/Target/RISCV/RISCVFoldMasks.cpp new file mode 100644 index 000000000000..6ee006525df5 --- /dev/null +++ b/llvm/lib/Target/RISCV/RISCVFoldMasks.cpp @@ -0,0 +1,216 @@ +//===- RISCVFoldMasks.cpp - MI Vector Pseudo Mask Peepholes ---------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===---------------------------------------------------------------------===// +// +// This pass performs various peephole optimisations that fold masks into vector +// pseudo instructions after instruction selection. +// +// Currently it converts +// PseudoVMERGE_VVM %false, %false, %true, %allonesmask, %vl, %sew +// -> +// PseudoVMV_V_V %false, %true, %vl, %sew +// +//===---------------------------------------------------------------------===// + +#include "RISCV.h" +#include "RISCVISelDAGToDAG.h" +#include "RISCVSubtarget.h" +#include "llvm/CodeGen/MachineFunctionPass.h" +#include "llvm/CodeGen/MachineRegisterInfo.h" +#include "llvm/CodeGen/TargetInstrInfo.h" +#include "llvm/CodeGen/TargetRegisterInfo.h" + +using namespace llvm; + +#define DEBUG_TYPE "riscv-fold-masks" + +namespace { + +class RISCVFoldMasks : public MachineFunctionPass { +public: + static char ID; + const TargetInstrInfo *TII; + MachineRegisterInfo *MRI; + const TargetRegisterInfo *TRI; + RISCVFoldMasks() : MachineFunctionPass(ID) {} + + bool runOnMachineFunction(MachineFunction &MF) override; + MachineFunctionProperties getRequiredProperties() const override { + return MachineFunctionProperties().set( + MachineFunctionProperties::Property::IsSSA); + } + + StringRef getPassName() const override { return "RISC-V Fold Masks"; } + +private: + bool convertToUnmasked(MachineInstr &MI, MachineInstr *MaskDef); + bool convertVMergeToVMv(MachineInstr &MI, MachineInstr *MaskDef); + + bool isAllOnesMask(MachineInstr *MaskDef); +}; + +} // namespace + +char RISCVFoldMasks::ID = 0; + +INITIALIZE_PASS(RISCVFoldMasks, DEBUG_TYPE, "RISC-V Fold Masks", false, false) + +bool RISCVFoldMasks::isAllOnesMask(MachineInstr *MaskDef) { + if (!MaskDef) + return false; + assert(MaskDef->isCopy() && MaskDef->getOperand(0).getReg() == RISCV::V0); + Register SrcReg = TRI->lookThruCopyLike(MaskDef->getOperand(1).getReg(), MRI); + if (!SrcReg.isVirtual()) + return false; + MaskDef = MRI->getVRegDef(SrcReg); + if (!MaskDef) + return false; + + // TODO: Check that the VMSET is the expected bitwidth? The pseudo has + // undefined behaviour if it's the wrong bitwidth, so we could choose to + // assume that it's all-ones? Same applies to its VL. + switch (MaskDef->getOpcode()) { + case RISCV::PseudoVMSET_M_B1: + case RISCV::PseudoVMSET_M_B2: + case RISCV::PseudoVMSET_M_B4: + case RISCV::PseudoVMSET_M_B8: + case RISCV::PseudoVMSET_M_B16: + case RISCV::PseudoVMSET_M_B32: + case RISCV::PseudoVMSET_M_B64: + return true; + default: + return false; + } +} + +// Transform (VMERGE_VVM_<LMUL> false, false, true, allones, vl, sew) to +// (VMV_V_V_<LMUL> false, true, vl, sew). It may decrease uses of VMSET. +bool RISCVFoldMasks::convertVMergeToVMv(MachineInstr &MI, MachineInstr *V0Def) { +#define CASE_VMERGE_TO_VMV(lmul) \ + case RISCV::PseudoVMERGE_VVM_##lmul: \ + NewOpc = RISCV::PseudoVMV_V_V_##lmul; \ + break; + unsigned NewOpc; + switch (MI.getOpcode()) { + default: + return false; + CASE_VMERGE_TO_VMV(MF8) + CASE_VMERGE_TO_VMV(MF4) + CASE_VMERGE_TO_VMV(MF2) + CASE_VMERGE_TO_VMV(M1) + CASE_VMERGE_TO_VMV(M2) + CASE_VMERGE_TO_VMV(M4) + CASE_VMERGE_TO_VMV(M8) + } + + Register MergeReg = MI.getOperand(1).getReg(); + Register FalseReg = MI.getOperand(2).getReg(); + // Check merge == false (or merge == undef) + if (MergeReg != RISCV::NoRegister && TRI->lookThruCopyLike(MergeReg, MRI) != + TRI->lookThruCopyLike(FalseReg, MRI)) + return false; + + assert(MI.getOperand(4).isReg() && MI.getOperand(4).getReg() == RISCV::V0); + if (!isAllOnesMask(V0Def)) + return false; + + MI.setDesc(TII->get(NewOpc)); + MI.removeOperand(1); // Merge operand + MI.tieOperands(0, 1); // Tie false to dest + MI.removeOperand(3); // Mask operand + MI.addOperand( + MachineOperand::CreateImm(RISCVII::TAIL_UNDISTURBED_MASK_UNDISTURBED)); + + // vmv.v.v doesn't have a mask operand, so we may be able to inflate the + // register class for the destination and merge operands e.g. VRNoV0 -> VR + MRI->recomputeRegClass(MI.getOperand(0).getReg()); + MRI->recomputeRegClass(MI.getOperand(1).getReg()); + return true; +} + +bool RISCVFoldMasks::convertToUnmasked(MachineInstr &MI, + MachineInstr *MaskDef) { + const RISCV::RISCVMaskedPseudoInfo *I = + RISCV::getMaskedPseudoInfo(MI.getOpcode()); + if (!I) + return false; + + if (!isAllOnesMask(MaskDef)) + return false; + + // There are two classes of pseudos in the table - compares and + // everything else. See the comment on RISCVMaskedPseudo for details. + const unsigned Opc = I->UnmaskedPseudo; + const MCInstrDesc &MCID = TII->get(Opc); + const bool HasPolicyOp = RISCVII::hasVecPolicyOp(MCID.TSFlags); + const bool HasPassthru = RISCVII::isFirstDefTiedToFirstUse(MCID); +#ifndef NDEBUG + const MCInstrDesc &MaskedMCID = TII->get(MI.getOpcode()); + assert(RISCVII::hasVecPolicyOp(MaskedMCID.TSFlags) == + RISCVII::hasVecPolicyOp(MCID.TSFlags) && + "Masked and unmasked pseudos are inconsistent"); + assert(HasPolicyOp == HasPassthru && "Unexpected pseudo structure"); +#endif + (void)HasPolicyOp; + + MI.setDesc(MCID); + + // TODO: Increment all MaskOpIdxs in tablegen by num of explicit defs? + unsigned MaskOpIdx = I->MaskOpIdx + MI.getNumExplicitDefs(); + MI.removeOperand(MaskOpIdx); + + // The unmasked pseudo will no longer be constrained to the vrnov0 reg class, + // so try and relax it to vr. + MRI->recomputeRegClass(MI.getOperand(0).getReg()); + unsigned PassthruOpIdx = MI.getNumExplicitDefs(); + if (HasPassthru) { + if (MI.getOperand(PassthruOpIdx).getReg() != RISCV::NoRegister) + MRI->recomputeRegClass(MI.getOperand(PassthruOpIdx).getReg()); + } else + MI.removeOperand(PassthruOpIdx); + + return true; +} + +bool RISCVFoldMasks::runOnMachineFunction(MachineFunction &MF) { + if (skipFunction(MF.getFunction())) + return false; + + // Skip if the vector extension is not enabled. + const RISCVSubtarget &ST = MF.getSubtarget<RISCVSubtarget>(); + if (!ST.hasVInstructions()) + return false; + + TII = ST.getInstrInfo(); + MRI = &MF.getRegInfo(); + TRI = MRI->getTargetRegisterInfo(); + + bool Changed = false; + + // Masked pseudos coming out of isel will have their mask operand in the form: + // + // $v0:vr = COPY %mask:vr + // %x:vr = Pseudo_MASK %a:vr, %b:br, $v0:vr + // + // Because $v0 isn't in SSA, keep track of it so we can check the mask operand + // on each pseudo. + MachineInstr *CurrentV0Def; + for (MachineBasicBlock &MBB : MF) { + CurrentV0Def = nullptr; + for (MachineInstr &MI : MBB) { + Changed |= convertToUnmasked(MI, CurrentV0Def); + Changed |= convertVMergeToVMv(MI, CurrentV0Def); + + if (MI.definesRegister(RISCV::V0, TRI)) + CurrentV0Def = &MI; + } + } + + return Changed; +} + +FunctionPass *llvm::createRISCVFoldMasksPass() { return new RISCVFoldMasks(); } |
