aboutsummaryrefslogtreecommitdiff
path: root/contrib/llvm-project/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'contrib/llvm-project/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp')
-rw-r--r--contrib/llvm-project/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp49
1 files changed, 49 insertions, 0 deletions
diff --git a/contrib/llvm-project/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp b/contrib/llvm-project/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
index 523fa2d3724b..54177564afbc 100644
--- a/contrib/llvm-project/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
+++ b/contrib/llvm-project/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
@@ -594,6 +594,8 @@ AMDGPUTargetLowering::AMDGPUTargetLowering(const TargetMachine &TM,
setTargetDAGCombine(ISD::SRL);
setTargetDAGCombine(ISD::TRUNCATE);
setTargetDAGCombine(ISD::MUL);
+ setTargetDAGCombine(ISD::SMUL_LOHI);
+ setTargetDAGCombine(ISD::UMUL_LOHI);
setTargetDAGCombine(ISD::MULHU);
setTargetDAGCombine(ISD::MULHS);
setTargetDAGCombine(ISD::SELECT);
@@ -3462,6 +3464,50 @@ SDValue AMDGPUTargetLowering::performMulCombine(SDNode *N,
return DAG.getSExtOrTrunc(Mul, DL, VT);
}
+SDValue
+AMDGPUTargetLowering::performMulLoHiCombine(SDNode *N,
+ DAGCombinerInfo &DCI) const {
+ if (N->getValueType(0) != MVT::i32)
+ return SDValue();
+
+ SelectionDAG &DAG = DCI.DAG;
+ SDLoc DL(N);
+
+ SDValue N0 = N->getOperand(0);
+ SDValue N1 = N->getOperand(1);
+
+ // SimplifyDemandedBits has the annoying habit of turning useful zero_extends
+ // in the source into any_extends if the result of the mul is truncated. Since
+ // we can assume the high bits are whatever we want, use the underlying value
+ // to avoid the unknown high bits from interfering.
+ if (N0.getOpcode() == ISD::ANY_EXTEND)
+ N0 = N0.getOperand(0);
+ if (N1.getOpcode() == ISD::ANY_EXTEND)
+ N1 = N1.getOperand(0);
+
+ // Try to use two fast 24-bit multiplies (one for each half of the result)
+ // instead of one slow extending multiply.
+ unsigned LoOpcode, HiOpcode;
+ if (Subtarget->hasMulU24() && isU24(N0, DAG) && isU24(N1, DAG)) {
+ N0 = DAG.getZExtOrTrunc(N0, DL, MVT::i32);
+ N1 = DAG.getZExtOrTrunc(N1, DL, MVT::i32);
+ LoOpcode = AMDGPUISD::MUL_U24;
+ HiOpcode = AMDGPUISD::MULHI_U24;
+ } else if (Subtarget->hasMulI24() && isI24(N0, DAG) && isI24(N1, DAG)) {
+ N0 = DAG.getSExtOrTrunc(N0, DL, MVT::i32);
+ N1 = DAG.getSExtOrTrunc(N1, DL, MVT::i32);
+ LoOpcode = AMDGPUISD::MUL_I24;
+ HiOpcode = AMDGPUISD::MULHI_I24;
+ } else {
+ return SDValue();
+ }
+
+ SDValue Lo = DAG.getNode(LoOpcode, DL, MVT::i32, N0, N1);
+ SDValue Hi = DAG.getNode(HiOpcode, DL, MVT::i32, N0, N1);
+ DCI.CombineTo(N, Lo, Hi);
+ return SDValue(N, 0);
+}
+
SDValue AMDGPUTargetLowering::performMulhsCombine(SDNode *N,
DAGCombinerInfo &DCI) const {
EVT VT = N->getValueType(0);
@@ -4103,6 +4149,9 @@ SDValue AMDGPUTargetLowering::PerformDAGCombine(SDNode *N,
return performTruncateCombine(N, DCI);
case ISD::MUL:
return performMulCombine(N, DCI);
+ case ISD::SMUL_LOHI:
+ case ISD::UMUL_LOHI:
+ return performMulLoHiCombine(N, DCI);
case ISD::MULHS:
return performMulhsCombine(N, DCI);
case ISD::MULHU: