diff options
Diffstat (limited to 'include/llvm/CodeGen/GlobalISel/LegalizationArtifactCombiner.h')
-rw-r--r-- | include/llvm/CodeGen/GlobalISel/LegalizationArtifactCombiner.h | 92 |
1 file changed, 55 insertions, 37 deletions
diff --git a/include/llvm/CodeGen/GlobalISel/LegalizationArtifactCombiner.h b/include/llvm/CodeGen/GlobalISel/LegalizationArtifactCombiner.h index a22778b8848c..7f960e727846 100644 --- a/include/llvm/CodeGen/GlobalISel/LegalizationArtifactCombiner.h +++ b/include/llvm/CodeGen/GlobalISel/LegalizationArtifactCombiner.h @@ -47,8 +47,7 @@ public: bool tryCombineAnyExt(MachineInstr &MI, SmallVectorImpl<MachineInstr *> &DeadInsts) { - if (MI.getOpcode() != TargetOpcode::G_ANYEXT) - return false; + assert(MI.getOpcode() == TargetOpcode::G_ANYEXT); Builder.setInstr(MI); Register DstReg = MI.getOperand(0).getReg(); @@ -93,9 +92,7 @@ public: bool tryCombineZExt(MachineInstr &MI, SmallVectorImpl<MachineInstr *> &DeadInsts) { - - if (MI.getOpcode() != TargetOpcode::G_ZEXT) - return false; + assert(MI.getOpcode() == TargetOpcode::G_ZEXT); Builder.setInstr(MI); Register DstReg = MI.getOperand(0).getReg(); @@ -136,32 +133,24 @@ public: bool tryCombineSExt(MachineInstr &MI, SmallVectorImpl<MachineInstr *> &DeadInsts) { - - if (MI.getOpcode() != TargetOpcode::G_SEXT) - return false; + assert(MI.getOpcode() == TargetOpcode::G_SEXT); Builder.setInstr(MI); Register DstReg = MI.getOperand(0).getReg(); Register SrcReg = lookThroughCopyInstrs(MI.getOperand(1).getReg()); - // sext(trunc x) - > ashr (shl (aext/copy/trunc x), c), c + // sext(trunc x) - > (sext_inreg (aext/copy/trunc x), c) Register TruncSrc; if (mi_match(SrcReg, MRI, m_GTrunc(m_Reg(TruncSrc)))) { LLT DstTy = MRI.getType(DstReg); - // Guess on the RHS shift amount type, which should be re-legalized if - // applicable. - if (isInstUnsupported({TargetOpcode::G_SHL, {DstTy, DstTy}}) || - isInstUnsupported({TargetOpcode::G_ASHR, {DstTy, DstTy}}) || - isConstantUnsupported(DstTy)) + if (isInstUnsupported({TargetOpcode::G_SEXT_INREG, {DstTy}})) return false;
LLVM_DEBUG(dbgs() << ".. Combine MI: " << MI;); LLT SrcTy = MRI.getType(SrcReg); - unsigned ShAmt = DstTy.getScalarSizeInBits() - SrcTy.getScalarSizeInBits(); - auto MIBShAmt = Builder.buildConstant(DstTy, ShAmt); - auto MIBShl = Builder.buildInstr( - TargetOpcode::G_SHL, {DstTy}, - {Builder.buildAnyExtOrTrunc(DstTy, TruncSrc), MIBShAmt}); - Builder.buildInstr(TargetOpcode::G_ASHR, {DstReg}, {MIBShl, MIBShAmt}); + uint64_t SizeInBits = SrcTy.getScalarSizeInBits(); + Builder.buildInstr( + TargetOpcode::G_SEXT_INREG, {DstReg}, + {Builder.buildAnyExtOrTrunc(DstTy, TruncSrc), SizeInBits}); markInstAndDefDead(MI, *MRI.getVRegDef(SrcReg), DeadInsts); return true; } @@ -172,9 +161,8 @@ public: bool tryFoldImplicitDef(MachineInstr &MI, SmallVectorImpl<MachineInstr *> &DeadInsts) { unsigned Opcode = MI.getOpcode(); - if (Opcode != TargetOpcode::G_ANYEXT && Opcode != TargetOpcode::G_ZEXT && - Opcode != TargetOpcode::G_SEXT) - return false; + assert(Opcode == TargetOpcode::G_ANYEXT || Opcode == TargetOpcode::G_ZEXT || + Opcode == TargetOpcode::G_SEXT); if (MachineInstr *DefMI = getOpcodeDef(TargetOpcode::G_IMPLICIT_DEF, MI.getOperand(1).getReg(), MRI)) { @@ -203,21 +191,38 @@ public: return false; } - static unsigned getMergeOpcode(LLT OpTy, LLT DestTy) { + static unsigned canFoldMergeOpcode(unsigned MergeOp, unsigned ConvertOp, + LLT OpTy, LLT DestTy) { if (OpTy.isVector() && DestTy.isVector()) - return TargetOpcode::G_CONCAT_VECTORS; + return MergeOp == TargetOpcode::G_CONCAT_VECTORS; + + if (OpTy.isVector() && !DestTy.isVector()) { + if (MergeOp == TargetOpcode::G_BUILD_VECTOR) + return true; - if (OpTy.isVector() && !DestTy.isVector()) - return TargetOpcode::G_BUILD_VECTOR; + if (MergeOp == TargetOpcode::G_CONCAT_VECTORS) { + if (ConvertOp == 0) + return true; - return TargetOpcode::G_MERGE_VALUES; + const unsigned OpEltSize = OpTy.getElementType().getSizeInBits(); + + // Don't handle scalarization with a cast that isn't in the same + // direction as the vector cast. This could be handled, but it would + // require more intermediate unmerges.
+ if (ConvertOp == TargetOpcode::G_TRUNC) + return DestTy.getSizeInBits() <= OpEltSize; + return DestTy.getSizeInBits() >= OpEltSize; + } + + return false; + } + + return MergeOp == TargetOpcode::G_MERGE_VALUES; } bool tryCombineMerges(MachineInstr &MI, SmallVectorImpl<MachineInstr *> &DeadInsts) { - - if (MI.getOpcode() != TargetOpcode::G_UNMERGE_VALUES) - return false; + assert(MI.getOpcode() == TargetOpcode::G_UNMERGE_VALUES); unsigned NumDefs = MI.getNumOperands() - 1; MachineInstr *SrcDef = @@ -237,16 +242,14 @@ public: MergeI = getDefIgnoringCopies(SrcDef->getOperand(1).getReg(), MRI); } - // FIXME: Handle scalarizing concat_vectors (scalar result type with vector - // source) - unsigned MergingOpcode = getMergeOpcode(OpTy, DestTy); - if (!MergeI || MergeI->getOpcode() != MergingOpcode) + if (!MergeI || !canFoldMergeOpcode(MergeI->getOpcode(), + ConvertOp, OpTy, DestTy)) return false; const unsigned NumMergeRegs = MergeI->getNumOperands() - 1; if (NumMergeRegs < NumDefs) { - if (ConvertOp != 0 || NumDefs % NumMergeRegs != 0) + if (NumDefs % NumMergeRegs != 0) return false; Builder.setInstr(MI); @@ -264,7 +267,22 @@ public: ++j, ++DefIdx) DstRegs.push_back(MI.getOperand(DefIdx).getReg()); - Builder.buildUnmerge(DstRegs, MergeI->getOperand(Idx + 1).getReg()); + if (ConvertOp) { + SmallVector<Register, 2> TmpRegs; + // This is a vector that is being scalarized and casted. Extract to + // the element type, and do the conversion on the scalars.
+ LLT MergeEltTy + = MRI.getType(MergeI->getOperand(0).getReg()).getElementType(); + for (unsigned j = 0; j < NumMergeRegs; ++j) + TmpRegs.push_back(MRI.createGenericVirtualRegister(MergeEltTy)); + + Builder.buildUnmerge(TmpRegs, MergeI->getOperand(Idx + 1).getReg()); + + for (unsigned j = 0; j < NumMergeRegs; ++j) + Builder.buildInstr(ConvertOp, {DstRegs[j]}, {TmpRegs[j]}); + } else { + Builder.buildUnmerge(DstRegs, MergeI->getOperand(Idx + 1).getReg()); + } } } else if (NumMergeRegs > NumDefs) { |