summaryrefslogtreecommitdiff
path: root/include/llvm/CodeGen/GlobalISel/LegalizationArtifactCombiner.h
diff options
context:
space:
mode:
Diffstat (limited to 'include/llvm/CodeGen/GlobalISel/LegalizationArtifactCombiner.h')
-rw-r--r--include/llvm/CodeGen/GlobalISel/LegalizationArtifactCombiner.h92
1 files changed, 55 insertions, 37 deletions
diff --git a/include/llvm/CodeGen/GlobalISel/LegalizationArtifactCombiner.h b/include/llvm/CodeGen/GlobalISel/LegalizationArtifactCombiner.h
index a22778b8848c..7f960e727846 100644
--- a/include/llvm/CodeGen/GlobalISel/LegalizationArtifactCombiner.h
+++ b/include/llvm/CodeGen/GlobalISel/LegalizationArtifactCombiner.h
@@ -47,8 +47,7 @@ public:
bool tryCombineAnyExt(MachineInstr &MI,
SmallVectorImpl<MachineInstr *> &DeadInsts) {
- if (MI.getOpcode() != TargetOpcode::G_ANYEXT)
- return false;
+ assert(MI.getOpcode() == TargetOpcode::G_ANYEXT);
Builder.setInstr(MI);
Register DstReg = MI.getOperand(0).getReg();
@@ -93,9 +92,7 @@ public:
bool tryCombineZExt(MachineInstr &MI,
SmallVectorImpl<MachineInstr *> &DeadInsts) {
-
- if (MI.getOpcode() != TargetOpcode::G_ZEXT)
- return false;
+ assert(MI.getOpcode() == TargetOpcode::G_ZEXT);
Builder.setInstr(MI);
Register DstReg = MI.getOperand(0).getReg();
@@ -136,32 +133,24 @@ public:
bool tryCombineSExt(MachineInstr &MI,
SmallVectorImpl<MachineInstr *> &DeadInsts) {
-
- if (MI.getOpcode() != TargetOpcode::G_SEXT)
- return false;
+ assert(MI.getOpcode() == TargetOpcode::G_SEXT);
Builder.setInstr(MI);
Register DstReg = MI.getOperand(0).getReg();
Register SrcReg = lookThroughCopyInstrs(MI.getOperand(1).getReg());
- // sext(trunc x) - > ashr (shl (aext/copy/trunc x), c), c
+ // sext(trunc x) - > (sext_inreg (aext/copy/trunc x), c)
Register TruncSrc;
if (mi_match(SrcReg, MRI, m_GTrunc(m_Reg(TruncSrc)))) {
LLT DstTy = MRI.getType(DstReg);
- // Guess on the RHS shift amount type, which should be re-legalized if
- // applicable.
- if (isInstUnsupported({TargetOpcode::G_SHL, {DstTy, DstTy}}) ||
- isInstUnsupported({TargetOpcode::G_ASHR, {DstTy, DstTy}}) ||
- isConstantUnsupported(DstTy))
+ if (isInstUnsupported({TargetOpcode::G_SEXT_INREG, {DstTy}}))
return false;
LLVM_DEBUG(dbgs() << ".. Combine MI: " << MI;);
LLT SrcTy = MRI.getType(SrcReg);
- unsigned ShAmt = DstTy.getScalarSizeInBits() - SrcTy.getScalarSizeInBits();
- auto MIBShAmt = Builder.buildConstant(DstTy, ShAmt);
- auto MIBShl = Builder.buildInstr(
- TargetOpcode::G_SHL, {DstTy},
- {Builder.buildAnyExtOrTrunc(DstTy, TruncSrc), MIBShAmt});
- Builder.buildInstr(TargetOpcode::G_ASHR, {DstReg}, {MIBShl, MIBShAmt});
+ uint64_t SizeInBits = SrcTy.getScalarSizeInBits();
+ Builder.buildInstr(
+ TargetOpcode::G_SEXT_INREG, {DstReg},
+ {Builder.buildAnyExtOrTrunc(DstTy, TruncSrc), SizeInBits});
markInstAndDefDead(MI, *MRI.getVRegDef(SrcReg), DeadInsts);
return true;
}
@@ -172,9 +161,8 @@ public:
bool tryFoldImplicitDef(MachineInstr &MI,
SmallVectorImpl<MachineInstr *> &DeadInsts) {
unsigned Opcode = MI.getOpcode();
- if (Opcode != TargetOpcode::G_ANYEXT && Opcode != TargetOpcode::G_ZEXT &&
- Opcode != TargetOpcode::G_SEXT)
- return false;
+ assert(Opcode == TargetOpcode::G_ANYEXT || Opcode == TargetOpcode::G_ZEXT ||
+ Opcode == TargetOpcode::G_SEXT);
if (MachineInstr *DefMI = getOpcodeDef(TargetOpcode::G_IMPLICIT_DEF,
MI.getOperand(1).getReg(), MRI)) {
@@ -203,21 +191,38 @@ public:
return false;
}
- static unsigned getMergeOpcode(LLT OpTy, LLT DestTy) {
+ static unsigned canFoldMergeOpcode(unsigned MergeOp, unsigned ConvertOp,
+ LLT OpTy, LLT DestTy) {
if (OpTy.isVector() && DestTy.isVector())
- return TargetOpcode::G_CONCAT_VECTORS;
+ return MergeOp == TargetOpcode::G_CONCAT_VECTORS;
+
+ if (OpTy.isVector() && !DestTy.isVector()) {
+ if (MergeOp == TargetOpcode::G_BUILD_VECTOR)
+ return true;
- if (OpTy.isVector() && !DestTy.isVector())
- return TargetOpcode::G_BUILD_VECTOR;
+ if (MergeOp == TargetOpcode::G_CONCAT_VECTORS) {
+ if (ConvertOp == 0)
+ return true;
- return TargetOpcode::G_MERGE_VALUES;
+ const unsigned OpEltSize = OpTy.getElementType().getSizeInBits();
+
+ // Don't handle scalarization with a cast that isn't in the same
+ // direction as the vector cast. This could be handled, but it would
+ // require more intermediate unmerges.
+ if (ConvertOp == TargetOpcode::G_TRUNC)
+ return DestTy.getSizeInBits() <= OpEltSize;
+ return DestTy.getSizeInBits() >= OpEltSize;
+ }
+
+ return false;
+ }
+
+ return MergeOp == TargetOpcode::G_MERGE_VALUES;
}
bool tryCombineMerges(MachineInstr &MI,
SmallVectorImpl<MachineInstr *> &DeadInsts) {
-
- if (MI.getOpcode() != TargetOpcode::G_UNMERGE_VALUES)
- return false;
+ assert(MI.getOpcode() == TargetOpcode::G_UNMERGE_VALUES);
unsigned NumDefs = MI.getNumOperands() - 1;
MachineInstr *SrcDef =
@@ -237,16 +242,14 @@ public:
MergeI = getDefIgnoringCopies(SrcDef->getOperand(1).getReg(), MRI);
}
- // FIXME: Handle scalarizing concat_vectors (scalar result type with vector
- // source)
- unsigned MergingOpcode = getMergeOpcode(OpTy, DestTy);
- if (!MergeI || MergeI->getOpcode() != MergingOpcode)
+ if (!MergeI || !canFoldMergeOpcode(MergeI->getOpcode(),
+ ConvertOp, OpTy, DestTy))
return false;
const unsigned NumMergeRegs = MergeI->getNumOperands() - 1;
if (NumMergeRegs < NumDefs) {
- if (ConvertOp != 0 || NumDefs % NumMergeRegs != 0)
+ if (NumDefs % NumMergeRegs != 0)
return false;
Builder.setInstr(MI);
@@ -264,7 +267,22 @@ public:
++j, ++DefIdx)
DstRegs.push_back(MI.getOperand(DefIdx).getReg());
- Builder.buildUnmerge(DstRegs, MergeI->getOperand(Idx + 1).getReg());
+ if (ConvertOp) {
+ SmallVector<Register, 2> TmpRegs;
+ // This is a vector that is being scalarized and casted. Extract to
+ // the element type, and do the conversion on the scalars.
+ LLT MergeEltTy
+ = MRI.getType(MergeI->getOperand(0).getReg()).getElementType();
+ for (unsigned j = 0; j < NumMergeRegs; ++j)
+ TmpRegs.push_back(MRI.createGenericVirtualRegister(MergeEltTy));
+
+ Builder.buildUnmerge(TmpRegs, MergeI->getOperand(Idx + 1).getReg());
+
+ for (unsigned j = 0; j < NumMergeRegs; ++j)
+ Builder.buildInstr(ConvertOp, {DstRegs[j]}, {TmpRegs[j]});
+ } else {
+ Builder.buildUnmerge(DstRegs, MergeI->getOperand(Idx + 1).getReg());
+ }
}
} else if (NumMergeRegs > NumDefs) {