aboutsummaryrefslogtreecommitdiff
path: root/contrib/llvm-project/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'contrib/llvm-project/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp')
-rw-r--r--contrib/llvm-project/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp58
1 files changed, 44 insertions, 14 deletions
diff --git a/contrib/llvm-project/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp b/contrib/llvm-project/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
index 8fbc90a6db9f..0dbcaf5a1b13 100644
--- a/contrib/llvm-project/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
+++ b/contrib/llvm-project/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
@@ -387,17 +387,20 @@ AMDGPUTargetLowering::AMDGPUTargetLowering(const TargetMachine &TM,
MVT::v9i32, MVT::v9f32, MVT::v10i32, MVT::v10f32,
MVT::v11i32, MVT::v11f32, MVT::v12i32, MVT::v12f32},
Custom);
+
+ // FIXME: Why is v8f16/v8bf16 missing?
setOperationAction(
ISD::EXTRACT_SUBVECTOR,
- {MVT::v2f16, MVT::v2i16, MVT::v4f16, MVT::v4i16, MVT::v2f32,
- MVT::v2i32, MVT::v3f32, MVT::v3i32, MVT::v4f32, MVT::v4i32,
- MVT::v5f32, MVT::v5i32, MVT::v6f32, MVT::v6i32, MVT::v7f32,
- MVT::v7i32, MVT::v8f32, MVT::v8i32, MVT::v9f32, MVT::v9i32,
- MVT::v10i32, MVT::v10f32, MVT::v11i32, MVT::v11f32, MVT::v12i32,
- MVT::v12f32, MVT::v16f16, MVT::v16i16, MVT::v16f32, MVT::v16i32,
- MVT::v32f32, MVT::v32i32, MVT::v2f64, MVT::v2i64, MVT::v3f64,
- MVT::v3i64, MVT::v4f64, MVT::v4i64, MVT::v8f64, MVT::v8i64,
- MVT::v16f64, MVT::v16i64, MVT::v32i16, MVT::v32f16},
+ {MVT::v2f16, MVT::v2bf16, MVT::v2i16, MVT::v4f16, MVT::v4bf16,
+ MVT::v4i16, MVT::v2f32, MVT::v2i32, MVT::v3f32, MVT::v3i32,
+ MVT::v4f32, MVT::v4i32, MVT::v5f32, MVT::v5i32, MVT::v6f32,
+ MVT::v6i32, MVT::v7f32, MVT::v7i32, MVT::v8f32, MVT::v8i32,
+ MVT::v9f32, MVT::v9i32, MVT::v10i32, MVT::v10f32, MVT::v11i32,
+ MVT::v11f32, MVT::v12i32, MVT::v12f32, MVT::v16f16, MVT::v16bf16,
+ MVT::v16i16, MVT::v16f32, MVT::v16i32, MVT::v32f32, MVT::v32i32,
+ MVT::v2f64, MVT::v2i64, MVT::v3f64, MVT::v3i64, MVT::v4f64,
+ MVT::v4i64, MVT::v8f64, MVT::v8i64, MVT::v16f64, MVT::v16i64,
+ MVT::v32i16, MVT::v32f16, MVT::v32bf16},
Custom);
setOperationAction(ISD::FP16_TO_FP, MVT::f64, Expand);
@@ -3281,7 +3284,15 @@ SDValue AMDGPUTargetLowering::LowerUINT_TO_FP(SDValue Op,
return DAG.getNode(ISD::UINT_TO_FP, DL, DestVT, Ext);
}
- assert(SrcVT == MVT::i64 && "operation should be legal");
+ if (DestVT == MVT::bf16) {
+ SDLoc SL(Op);
+ SDValue ToF32 = DAG.getNode(ISD::UINT_TO_FP, SL, MVT::f32, Src);
+ SDValue FPRoundFlag = DAG.getIntPtrConstant(0, SL, /*isTarget=*/true);
+ return DAG.getNode(ISD::FP_ROUND, SL, MVT::bf16, ToF32, FPRoundFlag);
+ }
+
+ if (SrcVT != MVT::i64)
+ return Op;
if (Subtarget->has16BitInsts() && DestVT == MVT::f16) {
SDLoc DL(Op);
@@ -3319,7 +3330,15 @@ SDValue AMDGPUTargetLowering::LowerSINT_TO_FP(SDValue Op,
return DAG.getNode(ISD::SINT_TO_FP, DL, DestVT, Ext);
}
- assert(SrcVT == MVT::i64 && "operation should be legal");
+ if (DestVT == MVT::bf16) {
+ SDLoc SL(Op);
+ SDValue ToF32 = DAG.getNode(ISD::SINT_TO_FP, SL, MVT::f32, Src);
+ SDValue FPRoundFlag = DAG.getIntPtrConstant(0, SL, /*isTarget=*/true);
+ return DAG.getNode(ISD::FP_ROUND, SL, MVT::bf16, ToF32, FPRoundFlag);
+ }
+
+ if (SrcVT != MVT::i64)
+ return Op;
// TODO: Factor out code common with LowerUINT_TO_FP.
@@ -3517,7 +3536,7 @@ SDValue AMDGPUTargetLowering::LowerFP_TO_FP16(SDValue Op, SelectionDAG &DAG) con
return DAG.getZExtOrTrunc(V, DL, Op.getValueType());
}
-SDValue AMDGPUTargetLowering::LowerFP_TO_INT(SDValue Op,
+SDValue AMDGPUTargetLowering::LowerFP_TO_INT(const SDValue Op,
SelectionDAG &DAG) const {
SDValue Src = Op.getOperand(0);
unsigned OpOpcode = Op.getOpcode();
@@ -3528,6 +3547,12 @@ SDValue AMDGPUTargetLowering::LowerFP_TO_INT(SDValue Op,
if (SrcVT == MVT::f16 && DestVT == MVT::i16)
return Op;
+ if (SrcVT == MVT::bf16) {
+ SDLoc DL(Op);
+ SDValue PromotedSrc = DAG.getNode(ISD::FP_EXTEND, DL, MVT::f32, Src);
+ return DAG.getNode(Op.getOpcode(), DL, DestVT, PromotedSrc);
+ }
+
// Promote i16 to i32
if (DestVT == MVT::i16 && (SrcVT == MVT::f32 || SrcVT == MVT::f64)) {
SDLoc DL(Op);
@@ -3536,6 +3561,9 @@ SDValue AMDGPUTargetLowering::LowerFP_TO_INT(SDValue Op,
return DAG.getNode(ISD::TRUNCATE, DL, MVT::i16, FpToInt32);
}
+ if (DestVT != MVT::i64)
+ return Op;
+
if (SrcVT == MVT::f16 ||
(SrcVT == MVT::f32 && Src.getOpcode() == ISD::FP16_TO_FP)) {
SDLoc DL(Op);
@@ -3546,7 +3574,7 @@ SDValue AMDGPUTargetLowering::LowerFP_TO_INT(SDValue Op,
return DAG.getNode(Ext, DL, MVT::i64, FpToInt32);
}
- if (DestVT == MVT::i64 && (SrcVT == MVT::f32 || SrcVT == MVT::f64))
+ if (SrcVT == MVT::f32 || SrcVT == MVT::f64)
return LowerFP_TO_INT64(Op, DAG, OpOpcode == ISD::FP_TO_SINT);
return SDValue();
@@ -4947,7 +4975,9 @@ SDValue AMDGPUTargetLowering::PerformDAGCombine(SDNode *N,
// vnt1 = build_vector (t1 (bitcast t0:x)), (t1 (bitcast t0:y))
if (DestVT.isVector()) {
SDValue Src = N->getOperand(0);
- if (Src.getOpcode() == ISD::BUILD_VECTOR) {
+ if (Src.getOpcode() == ISD::BUILD_VECTOR &&
+ (DCI.getDAGCombineLevel() < AfterLegalizeDAG ||
+ isOperationLegal(ISD::BUILD_VECTOR, DestVT))) {
EVT SrcVT = Src.getValueType();
unsigned NElts = DestVT.getVectorNumElements();