diff options
Diffstat (limited to 'contrib/llvm-project/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp')
| -rw-r--r-- | contrib/llvm-project/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp | 49 |
1 files changed, 49 insertions, 0 deletions
diff --git a/contrib/llvm-project/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp b/contrib/llvm-project/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp index 523fa2d3724b..54177564afbc 100644 --- a/contrib/llvm-project/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp +++ b/contrib/llvm-project/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp @@ -594,6 +594,8 @@ AMDGPUTargetLowering::AMDGPUTargetLowering(const TargetMachine &TM, setTargetDAGCombine(ISD::SRL); setTargetDAGCombine(ISD::TRUNCATE); setTargetDAGCombine(ISD::MUL); + setTargetDAGCombine(ISD::SMUL_LOHI); + setTargetDAGCombine(ISD::UMUL_LOHI); setTargetDAGCombine(ISD::MULHU); setTargetDAGCombine(ISD::MULHS); setTargetDAGCombine(ISD::SELECT); @@ -3462,6 +3464,50 @@ SDValue AMDGPUTargetLowering::performMulCombine(SDNode *N, return DAG.getSExtOrTrunc(Mul, DL, VT); } +SDValue +AMDGPUTargetLowering::performMulLoHiCombine(SDNode *N, + DAGCombinerInfo &DCI) const { + if (N->getValueType(0) != MVT::i32) + return SDValue(); + + SelectionDAG &DAG = DCI.DAG; + SDLoc DL(N); + + SDValue N0 = N->getOperand(0); + SDValue N1 = N->getOperand(1); + + // SimplifyDemandedBits has the annoying habit of turning useful zero_extends + // in the source into any_extends if the result of the mul is truncated. Since + // we can assume the high bits are whatever we want, use the underlying value + // to avoid the unknown high bits from interfering. + if (N0.getOpcode() == ISD::ANY_EXTEND) + N0 = N0.getOperand(0); + if (N1.getOpcode() == ISD::ANY_EXTEND) + N1 = N1.getOperand(0); + + // Try to use two fast 24-bit multiplies (one for each half of the result) + // instead of one slow extending multiply. + unsigned LoOpcode, HiOpcode; + if (Subtarget->hasMulU24() && isU24(N0, DAG) && isU24(N1, DAG)) { + N0 = DAG.getZExtOrTrunc(N0, DL, MVT::i32); + N1 = DAG.getZExtOrTrunc(N1, DL, MVT::i32); + LoOpcode = AMDGPUISD::MUL_U24; + HiOpcode = AMDGPUISD::MULHI_U24; + } else if (Subtarget->hasMulI24() && isI24(N0, DAG) && isI24(N1, DAG)) { + N0 = DAG.getSExtOrTrunc(N0, DL, MVT::i32); + N1 = DAG.getSExtOrTrunc(N1, DL, MVT::i32); + LoOpcode = AMDGPUISD::MUL_I24; + HiOpcode = AMDGPUISD::MULHI_I24; + } else { + return SDValue(); + } + + SDValue Lo = DAG.getNode(LoOpcode, DL, MVT::i32, N0, N1); + SDValue Hi = DAG.getNode(HiOpcode, DL, MVT::i32, N0, N1); + DCI.CombineTo(N, Lo, Hi); + return SDValue(N, 0); +} + SDValue AMDGPUTargetLowering::performMulhsCombine(SDNode *N, DAGCombinerInfo &DCI) const { EVT VT = N->getValueType(0); @@ -4103,6 +4149,9 @@ SDValue AMDGPUTargetLowering::PerformDAGCombine(SDNode *N, return performTruncateCombine(N, DCI); case ISD::MUL: return performMulCombine(N, DCI); + case ISD::SMUL_LOHI: + case ISD::UMUL_LOHI: + return performMulLoHiCombine(N, DCI); case ISD::MULHS: return performMulhsCombine(N, DCI); case ISD::MULHU: |
