Diffstat (limited to 'llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp')
-rw-r--r-- | llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp | 50
1 file changed, 36 insertions, 14 deletions
diff --git a/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp b/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
index 2c94f87804ac..ad0c0c8315dc 100644
--- a/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
@@ -697,14 +697,16 @@ bool CombinerHelper::matchCombineLoadWithAndMask(MachineInstr &MI,
     return false;
 
   Register SrcReg = MI.getOperand(1).getReg();
-  GAnyLoad *LoadMI = getOpcodeDef<GAnyLoad>(SrcReg, MRI);
-  if (!LoadMI || !MRI.hasOneNonDBGUse(LoadMI->getDstReg()) ||
-      !LoadMI->isSimple())
+  // Don't use getOpcodeDef() here since intermediate instructions may have
+  // multiple users.
+  GAnyLoad *LoadMI = dyn_cast<GAnyLoad>(MRI.getVRegDef(SrcReg));
+  if (!LoadMI || !MRI.hasOneNonDBGUse(LoadMI->getDstReg()))
     return false;
 
   Register LoadReg = LoadMI->getDstReg();
-  LLT LoadTy = MRI.getType(LoadReg);
+  LLT RegTy = MRI.getType(LoadReg);
   Register PtrReg = LoadMI->getPointerReg();
+  unsigned RegSize = RegTy.getSizeInBits();
   uint64_t LoadSizeBits = LoadMI->getMemSizeInBits();
   unsigned MaskSizeBits = MaskVal.countTrailingOnes();
 
@@ -715,7 +717,7 @@ bool CombinerHelper::matchCombineLoadWithAndMask(MachineInstr &MI,
 
   // If the mask covers the whole destination register, there's nothing to
   // extend
-  if (MaskSizeBits >= LoadTy.getSizeInBits())
+  if (MaskSizeBits >= RegSize)
     return false;
 
   // Most targets cannot deal with loads of size < 8 and need to re-legalize to
@@ -725,17 +727,26 @@ bool CombinerHelper::matchCombineLoadWithAndMask(MachineInstr &MI,
 
   const MachineMemOperand &MMO = LoadMI->getMMO();
   LegalityQuery::MemDesc MemDesc(MMO);
-  MemDesc.MemoryTy = LLT::scalar(MaskSizeBits);
+
+  // Don't modify the memory access size if this is atomic/volatile, but we can
+  // still adjust the opcode to indicate the high bit behavior.
+  if (LoadMI->isSimple())
+    MemDesc.MemoryTy = LLT::scalar(MaskSizeBits);
+  else if (LoadSizeBits > MaskSizeBits || LoadSizeBits == RegSize)
+    return false;
+
+  // TODO: Could check if it's legal with the reduced or original memory size.
   if (!isLegalOrBeforeLegalizer(
-          {TargetOpcode::G_ZEXTLOAD, {LoadTy, MRI.getType(PtrReg)}, {MemDesc}}))
+          {TargetOpcode::G_ZEXTLOAD, {RegTy, MRI.getType(PtrReg)}, {MemDesc}}))
     return false;
 
   MatchInfo = [=](MachineIRBuilder &B) {
     B.setInstrAndDebugLoc(*LoadMI);
     auto &MF = B.getMF();
     auto PtrInfo = MMO.getPointerInfo();
-    auto *NewMMO = MF.getMachineMemOperand(&MMO, PtrInfo, MaskSizeBits / 8);
+    auto *NewMMO = MF.getMachineMemOperand(&MMO, PtrInfo, MemDesc.MemoryTy);
     B.buildLoadInstr(TargetOpcode::G_ZEXTLOAD, Dst, PtrReg, *NewMMO);
+    LoadMI->eraseFromParent();
   };
   return true;
 }
@@ -805,21 +816,24 @@ bool CombinerHelper::matchSextInRegOfLoad(
     MachineInstr &MI, std::tuple<Register, unsigned> &MatchInfo) {
   assert(MI.getOpcode() == TargetOpcode::G_SEXT_INREG);
+  Register DstReg = MI.getOperand(0).getReg();
+  LLT RegTy = MRI.getType(DstReg);
+
   // Only supports scalars for now.
-  if (MRI.getType(MI.getOperand(0).getReg()).isVector())
+  if (RegTy.isVector())
     return false;
 
   Register SrcReg = MI.getOperand(1).getReg();
   auto *LoadDef = getOpcodeDef<GLoad>(SrcReg, MRI);
-  if (!LoadDef || !MRI.hasOneNonDBGUse(LoadDef->getOperand(0).getReg()) ||
-      !LoadDef->isSimple())
+  if (!LoadDef || !MRI.hasOneNonDBGUse(DstReg))
     return false;
 
+  uint64_t MemBits = LoadDef->getMemSizeInBits();
+
   // If the sign extend extends from a narrower width than the load's width,
   // then we can narrow the load width when we combine to a G_SEXTLOAD.
   // Avoid widening the load at all.
-  unsigned NewSizeBits = std::min((uint64_t)MI.getOperand(2).getImm(),
-                                  LoadDef->getMemSizeInBits());
+  unsigned NewSizeBits = std::min((uint64_t)MI.getOperand(2).getImm(), MemBits);
 
   // Don't generate G_SEXTLOADs with a < 1 byte width.
   if (NewSizeBits < 8)
     return false;
@@ -831,7 +845,15 @@ bool CombinerHelper::matchSextInRegOfLoad(
 
   const MachineMemOperand &MMO = LoadDef->getMMO();
   LegalityQuery::MemDesc MMDesc(MMO);
-  MMDesc.MemoryTy = LLT::scalar(NewSizeBits);
+
+  // Don't modify the memory access size if this is atomic/volatile, but we can
+  // still adjust the opcode to indicate the high bit behavior.
+  if (LoadDef->isSimple())
+    MMDesc.MemoryTy = LLT::scalar(NewSizeBits);
+  else if (MemBits > NewSizeBits || MemBits == RegTy.getSizeInBits())
+    return false;
+
+  // TODO: Could check if it's legal with the reduced or original memory size.
   if (!isLegalOrBeforeLegalizer({TargetOpcode::G_SEXTLOAD,
                                  {MRI.getType(LoadDef->getDstReg()),
                                   MRI.getType(LoadDef->getPointerReg())},
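
Note: both hunks apply the same gating rule. A simple load may have its memory type narrowed to the extend width; an atomic or volatile load must keep its original access size, so only the opcode may change, and only when high register bits remain for the new zext/sext semantics to define. Below is a minimal standalone C++ sketch of that shared predicate; every name in it (LoadInfo, mayFoldIntoExtLoad, the field names) is hypothetical and not part of the LLVM API.

#include <cstdint>

// Hypothetical mirror of the patch's gating logic; none of these names
// exist in LLVM.
struct LoadInfo {
  uint64_t MemSizeBits; // width of the in-memory access
  unsigned RegSizeBits; // width of the loaded register
  bool IsSimple;        // neither atomic nor volatile
};

// ExtBits is the mask width (G_AND case) or the sign-extension width
// (G_SEXT_INREG case). Returns true if the load may be folded into a
// G_ZEXTLOAD/G_SEXTLOAD; CanNarrowMemory reports whether the memory type
// may shrink to ExtBits.
bool mayFoldIntoExtLoad(const LoadInfo &LI, unsigned ExtBits,
                        bool &CanNarrowMemory) {
  CanNarrowMemory = LI.IsSimple;
  if (LI.IsSimple)
    return true; // free to shrink the access to ExtBits
  // Atomic/volatile: the access size must stay fixed, so the fold only
  // helps when the load already reads at most ExtBits and still leaves
  // high register bits for the extending opcode to define.
  return LI.MemSizeBits <= ExtBits && LI.MemSizeBits != LI.RegSizeBits;
}

In the patch itself, ExtBits corresponds to MaskSizeBits in matchCombineLoadWithAndMask and NewSizeBits in matchSextInRegOfLoad; the MemSizeBits != RegSizeBits condition reflects that when the access already fills the whole register, there are no high bits left for the opcode change to describe.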