diff options
Diffstat (limited to 'contrib/llvm-project/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp')
-rw-r--r-- | contrib/llvm-project/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp | 58 |
1 files changed, 44 insertions, 14 deletions
diff --git a/contrib/llvm-project/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp b/contrib/llvm-project/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp index 8fbc90a6db9f..0dbcaf5a1b13 100644 --- a/contrib/llvm-project/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp +++ b/contrib/llvm-project/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp @@ -387,17 +387,20 @@ AMDGPUTargetLowering::AMDGPUTargetLowering(const TargetMachine &TM, MVT::v9i32, MVT::v9f32, MVT::v10i32, MVT::v10f32, MVT::v11i32, MVT::v11f32, MVT::v12i32, MVT::v12f32}, Custom); + + // FIXME: Why is v8f16/v8bf16 missing? setOperationAction( ISD::EXTRACT_SUBVECTOR, - {MVT::v2f16, MVT::v2i16, MVT::v4f16, MVT::v4i16, MVT::v2f32, - MVT::v2i32, MVT::v3f32, MVT::v3i32, MVT::v4f32, MVT::v4i32, - MVT::v5f32, MVT::v5i32, MVT::v6f32, MVT::v6i32, MVT::v7f32, - MVT::v7i32, MVT::v8f32, MVT::v8i32, MVT::v9f32, MVT::v9i32, - MVT::v10i32, MVT::v10f32, MVT::v11i32, MVT::v11f32, MVT::v12i32, - MVT::v12f32, MVT::v16f16, MVT::v16i16, MVT::v16f32, MVT::v16i32, - MVT::v32f32, MVT::v32i32, MVT::v2f64, MVT::v2i64, MVT::v3f64, - MVT::v3i64, MVT::v4f64, MVT::v4i64, MVT::v8f64, MVT::v8i64, - MVT::v16f64, MVT::v16i64, MVT::v32i16, MVT::v32f16}, + {MVT::v2f16, MVT::v2bf16, MVT::v2i16, MVT::v4f16, MVT::v4bf16, + MVT::v4i16, MVT::v2f32, MVT::v2i32, MVT::v3f32, MVT::v3i32, + MVT::v4f32, MVT::v4i32, MVT::v5f32, MVT::v5i32, MVT::v6f32, + MVT::v6i32, MVT::v7f32, MVT::v7i32, MVT::v8f32, MVT::v8i32, + MVT::v9f32, MVT::v9i32, MVT::v10i32, MVT::v10f32, MVT::v11i32, + MVT::v11f32, MVT::v12i32, MVT::v12f32, MVT::v16f16, MVT::v16bf16, + MVT::v16i16, MVT::v16f32, MVT::v16i32, MVT::v32f32, MVT::v32i32, + MVT::v2f64, MVT::v2i64, MVT::v3f64, MVT::v3i64, MVT::v4f64, + MVT::v4i64, MVT::v8f64, MVT::v8i64, MVT::v16f64, MVT::v16i64, + MVT::v32i16, MVT::v32f16, MVT::v32bf16}, Custom); setOperationAction(ISD::FP16_TO_FP, MVT::f64, Expand); @@ -3281,7 +3284,15 @@ SDValue AMDGPUTargetLowering::LowerUINT_TO_FP(SDValue Op, return DAG.getNode(ISD::UINT_TO_FP, DL, DestVT, Ext); } - assert(SrcVT == MVT::i64 && "operation should be legal"); + if (DestVT == MVT::bf16) { + SDLoc SL(Op); + SDValue ToF32 = DAG.getNode(ISD::UINT_TO_FP, SL, MVT::f32, Src); + SDValue FPRoundFlag = DAG.getIntPtrConstant(0, SL, /*isTarget=*/true); + return DAG.getNode(ISD::FP_ROUND, SL, MVT::bf16, ToF32, FPRoundFlag); + } + + if (SrcVT != MVT::i64) + return Op; if (Subtarget->has16BitInsts() && DestVT == MVT::f16) { SDLoc DL(Op); @@ -3319,7 +3330,15 @@ SDValue AMDGPUTargetLowering::LowerSINT_TO_FP(SDValue Op, return DAG.getNode(ISD::SINT_TO_FP, DL, DestVT, Ext); } - assert(SrcVT == MVT::i64 && "operation should be legal"); + if (DestVT == MVT::bf16) { + SDLoc SL(Op); + SDValue ToF32 = DAG.getNode(ISD::SINT_TO_FP, SL, MVT::f32, Src); + SDValue FPRoundFlag = DAG.getIntPtrConstant(0, SL, /*isTarget=*/true); + return DAG.getNode(ISD::FP_ROUND, SL, MVT::bf16, ToF32, FPRoundFlag); + } + + if (SrcVT != MVT::i64) + return Op; // TODO: Factor out code common with LowerUINT_TO_FP. @@ -3517,7 +3536,7 @@ SDValue AMDGPUTargetLowering::LowerFP_TO_FP16(SDValue Op, SelectionDAG &DAG) con return DAG.getZExtOrTrunc(V, DL, Op.getValueType()); } -SDValue AMDGPUTargetLowering::LowerFP_TO_INT(SDValue Op, +SDValue AMDGPUTargetLowering::LowerFP_TO_INT(const SDValue Op, SelectionDAG &DAG) const { SDValue Src = Op.getOperand(0); unsigned OpOpcode = Op.getOpcode(); @@ -3528,6 +3547,12 @@ SDValue AMDGPUTargetLowering::LowerFP_TO_INT(SDValue Op, if (SrcVT == MVT::f16 && DestVT == MVT::i16) return Op; + if (SrcVT == MVT::bf16) { + SDLoc DL(Op); + SDValue PromotedSrc = DAG.getNode(ISD::FP_EXTEND, DL, MVT::f32, Src); + return DAG.getNode(Op.getOpcode(), DL, DestVT, PromotedSrc); + } + // Promote i16 to i32 if (DestVT == MVT::i16 && (SrcVT == MVT::f32 || SrcVT == MVT::f64)) { SDLoc DL(Op); @@ -3536,6 +3561,9 @@ SDValue AMDGPUTargetLowering::LowerFP_TO_INT(SDValue Op, return DAG.getNode(ISD::TRUNCATE, DL, MVT::i16, FpToInt32); } + if (DestVT != MVT::i64) + return Op; + if (SrcVT == MVT::f16 || (SrcVT == MVT::f32 && Src.getOpcode() == ISD::FP16_TO_FP)) { SDLoc DL(Op); @@ -3546,7 +3574,7 @@ SDValue AMDGPUTargetLowering::LowerFP_TO_INT(SDValue Op, return DAG.getNode(Ext, DL, MVT::i64, FpToInt32); } - if (DestVT == MVT::i64 && (SrcVT == MVT::f32 || SrcVT == MVT::f64)) + if (SrcVT == MVT::f32 || SrcVT == MVT::f64) return LowerFP_TO_INT64(Op, DAG, OpOpcode == ISD::FP_TO_SINT); return SDValue(); @@ -4947,7 +4975,9 @@ SDValue AMDGPUTargetLowering::PerformDAGCombine(SDNode *N, // vnt1 = build_vector (t1 (bitcast t0:x)), (t1 (bitcast t0:y)) if (DestVT.isVector()) { SDValue Src = N->getOperand(0); - if (Src.getOpcode() == ISD::BUILD_VECTOR) { + if (Src.getOpcode() == ISD::BUILD_VECTOR && + (DCI.getDAGCombineLevel() < AfterLegalizeDAG || + isOperationLegal(ISD::BUILD_VECTOR, DestVT))) { EVT SrcVT = Src.getValueType(); unsigned NElts = DestVT.getVectorNumElements(); |