Diffstat (limited to 'llvm/lib/Target/X86/X86ISelLowering.cpp')
 llvm/lib/Target/X86/X86ISelLowering.cpp | 96
 1 file changed, 56 insertions(+), 40 deletions(-)
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index b080ab7e138c..7d0fc4e8a8c6 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -419,7 +419,7 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
setTruncStoreAction(VT, MVT::bf16, Expand);
setOperationAction(ISD::BF16_TO_FP, VT, Expand);
- setOperationAction(ISD::FP_TO_BF16, VT, Expand);
+ setOperationAction(ISD::FP_TO_BF16, VT, Custom);
}
setOperationAction(ISD::PARITY, MVT::i8, Custom);
@@ -2494,6 +2494,10 @@ MVT X86TargetLowering::getRegisterTypeForCallingConv(LLVMContext &Context,
!Subtarget.hasX87())
return MVT::i32;
+ if (VT.isVector() && VT.getVectorElementType() == MVT::bf16)
+ return getRegisterTypeForCallingConv(Context, CC,
+ VT.changeVectorElementTypeToInteger());
+
return TargetLowering::getRegisterTypeForCallingConv(Context, CC, VT);
}
@@ -2525,6 +2529,10 @@ unsigned X86TargetLowering::getNumRegistersForCallingConv(LLVMContext &Context,
return 3;
}
+ if (VT.isVector() && VT.getVectorElementType() == MVT::bf16)
+ return getNumRegistersForCallingConv(Context, CC,
+ VT.changeVectorElementTypeToInteger());
+
return TargetLowering::getNumRegistersForCallingConv(Context, CC, VT);
}
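Note on the two calling-convention hooks changed above: a bf16 vector is classified for argument passing exactly like the same-width integer vector, so e.g. v8bf16 travels like v8i16 (one XMM register) and wider bf16 vectors split into as many register parts as their integer twins would. A minimal standalone sketch of that classification follows; the helper and the 128-bit cutoff are illustrative assumptions (an SSE-only target where 128-bit vectors are the widest legal type), not LLVM API:

    #include <cstdio>

    struct VecVT { int NumElts, EltBits; };

    // Hypothetical mirror of the redirection above: treat bf16 lanes as
    // same-width integers, then count 128-bit (XMM-sized) register parts.
    int numRegisterParts(VecVT V) {
      int TotalBits = V.NumElts * V.EltBits;
      return (TotalBits + 127) / 128;
    }

    int main() {
      std::printf("v8bf16  -> %d XMM part(s)\n", numRegisterParts({8, 16}));  // 1
      std::printf("v16bf16 -> %d XMM part(s)\n", numRegisterParts({16, 16})); // 2
    }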
@@ -2733,6 +2741,40 @@ unsigned X86TargetLowering::getJumpTableEncoding() const {
return TargetLowering::getJumpTableEncoding();
}
+bool X86TargetLowering::splitValueIntoRegisterParts(
+ SelectionDAG &DAG, const SDLoc &DL, SDValue Val, SDValue *Parts,
+ unsigned NumParts, MVT PartVT, Optional<CallingConv::ID> CC) const {
+ bool IsABIRegCopy = CC.has_value();
+ EVT ValueVT = Val.getValueType();
+ if (IsABIRegCopy && ValueVT == MVT::bf16 && PartVT == MVT::f32) {
+ unsigned ValueBits = ValueVT.getSizeInBits();
+ unsigned PartBits = PartVT.getSizeInBits();
+ Val = DAG.getNode(ISD::BITCAST, DL, MVT::getIntegerVT(ValueBits), Val);
+ Val = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::getIntegerVT(PartBits), Val);
+ Val = DAG.getNode(ISD::BITCAST, DL, PartVT, Val);
+ Parts[0] = Val;
+ return true;
+ }
+ return false;
+}
+
+SDValue X86TargetLowering::joinRegisterPartsIntoValue(
+ SelectionDAG &DAG, const SDLoc &DL, const SDValue *Parts, unsigned NumParts,
+ MVT PartVT, EVT ValueVT, Optional<CallingConv::ID> CC) const {
+ bool IsABIRegCopy = CC.has_value();
+ if (IsABIRegCopy && ValueVT == MVT::bf16 && PartVT == MVT::f32) {
+ unsigned ValueBits = ValueVT.getSizeInBits();
+ unsigned PartBits = PartVT.getSizeInBits();
+ SDValue Val = Parts[0];
+
+ Val = DAG.getNode(ISD::BITCAST, DL, MVT::getIntegerVT(PartBits), Val);
+ Val = DAG.getNode(ISD::TRUNCATE, DL, MVT::getIntegerVT(ValueBits), Val);
+ Val = DAG.getNode(ISD::BITCAST, DL, ValueVT, Val);
+ return Val;
+ }
+ return SDValue();
+}
+
bool X86TargetLowering::useSoftFloat() const {
return Subtarget.useSoftFloat();
}
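The splitValueIntoRegisterParts/joinRegisterPartsIntoValue pair added in the hunk above implements the scalar bf16 ABI: the bf16 value rides in the low 16 bits of an f32 register part, moved with bitcast/any-extend rather than a floating-point conversion, so no rounding ever occurs. A minimal sketch of the resulting bit pattern, using memcpy for the bitcasts (plain C++, not the DAG API):

    #include <cstdint>
    #include <cstring>
    #include <cstdio>

    // split: bf16 bits -> i16 -> ANY_EXTEND to i32 -> BITCAST to f32 part.
    float bf16BitsToF32Part(uint16_t Bf16Bits) {
      uint32_t Wide = Bf16Bits; // upper 16 bits are unspecified; zero here
      float Part;
      std::memcpy(&Part, &Wide, sizeof Part);
      return Part;
    }

    // join: f32 part -> BITCAST to i32 -> TRUNCATE to i16 -> bf16 bits.
    uint16_t f32PartToBf16Bits(float Part) {
      uint32_t Wide;
      std::memcpy(&Wide, &Part, sizeof Wide);
      return static_cast<uint16_t>(Wide);
    }

    int main() {
      uint16_t One = 0x3F80; // 1.0 in bfloat16
      std::printf("round trip: 0x%04X\n",
                  f32PartToBf16Bits(bf16BitsToF32Part(One)));
    }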
@@ -19304,44 +19346,6 @@ static bool canonicalizeShuffleMaskWithCommute(ArrayRef<int> Mask) {
return false;
}
-static bool canCombineAsMaskOperation(SDValue V1, SDValue V2,
- const X86Subtarget &Subtarget) {
- if (!Subtarget.hasAVX512())
- return false;
-
- MVT VT = V1.getSimpleValueType().getScalarType();
- if ((VT == MVT::i16 || VT == MVT::i8) && !Subtarget.hasBWI())
- return false;
-
- // i8 is better to be widen to i16, because there is PBLENDW for vXi16
- // when the vector bit size is 128 or 256.
- if (VT == MVT::i8 && V1.getSimpleValueType().getSizeInBits() < 512)
- return false;
-
- auto HasMaskOperation = [&](SDValue V) {
- // TODO: Currently we only check limited opcode. We probably extend
- // it to all binary operation by checking TLI.isBinOp().
- switch (V->getOpcode()) {
- default:
- return false;
- case ISD::ADD:
- case ISD::SUB:
- case ISD::AND:
- case ISD::XOR:
- break;
- }
- if (!V->hasOneUse())
- return false;
-
- return true;
- };
-
- if (HasMaskOperation(V1) || HasMaskOperation(V2))
- return true;
-
- return false;
-}
-
// Forward declaration.
static SDValue canonicalizeShuffleMaskWithHorizOp(
MutableArrayRef<SDValue> Ops, MutableArrayRef<int> Mask,
@@ -19417,7 +19421,6 @@ static SDValue lowerVECTOR_SHUFFLE(SDValue Op, const X86Subtarget &Subtarget,
// integers to handle flipping the low and high halves of AVX 256-bit vectors.
SmallVector<int, 16> WidenedMask;
if (VT.getScalarSizeInBits() < 64 && !Is1BitVector &&
- !canCombineAsMaskOperation(V1, V2, Subtarget) &&
canWidenShuffleElements(OrigMask, Zeroable, V2IsZero, WidenedMask)) {
// Shuffle mask widening should not interfere with a broadcast opportunity
// by obfuscating the operands with bitcasts.
@@ -23058,6 +23061,18 @@ static SDValue LowerFP_TO_FP16(SDValue Op, SelectionDAG &DAG) {
return Res;
}
+SDValue X86TargetLowering::LowerFP_TO_BF16(SDValue Op,
+ SelectionDAG &DAG) const {
+ SDLoc DL(Op);
+ MakeLibCallOptions CallOptions;
+ RTLIB::Libcall LC =
+ RTLIB::getFPROUND(Op.getOperand(0).getValueType(), MVT::bf16);
+ SDValue Res =
+ makeLibCall(DAG, LC, MVT::f32, Op.getOperand(0), CallOptions, DL).first;
+ return DAG.getNode(ISD::TRUNCATE, DL, MVT::i16,
+ DAG.getBitcast(MVT::i32, Res));
+}
+
/// Depending on uarch and/or optimizing for size, we might prefer to use a
/// vector operation in place of the typical scalar operation.
static SDValue lowerAddSubToHorizontalOp(SDValue Op, SelectionDAG &DAG,
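LowerFP_TO_BF16 added above expands FP_TO_BF16 into a libcall chosen by RTLIB::getFPROUND (for an f32 source this should resolve to __truncsfbf2 in compiler-rt), then truncates the f32-typed result down to the i16 that carries the bf16 bits. As a rough sketch of what such a truncation routine computes, here is a round-to-nearest-even f32 -> bf16 conversion; this is not copied from compiler-rt, and the real runtime's NaN handling may differ:

    #include <cstdint>
    #include <cstring>
    #include <cstdio>

    uint16_t truncF32ToBF16(float F) {
      uint32_t Bits;
      std::memcpy(&Bits, &F, sizeof Bits);
      if ((Bits & 0x7FFFFFFFu) > 0x7F800000u)  // NaN: quiet it, keep top payload
        return static_cast<uint16_t>((Bits >> 16) | 0x0040u);
      // Round to nearest, ties to even, at the discarded low 16 bits.
      uint32_t Bias = 0x7FFFu + ((Bits >> 16) & 1u);
      return static_cast<uint16_t>((Bits + Bias) >> 16);
    }

    int main() {
      std::printf("bf16(1.0f) = 0x%04X\n", truncF32ToBF16(1.0f)); // 0x3F80
      std::printf("bf16(0.1f) = 0x%04X\n", truncF32ToBF16(0.1f)); // 0x3DCD
    }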
@@ -32250,6 +32265,7 @@ SDValue X86TargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
case ISD::STRICT_FP16_TO_FP: return LowerFP16_TO_FP(Op, DAG);
case ISD::FP_TO_FP16:
case ISD::STRICT_FP_TO_FP16: return LowerFP_TO_FP16(Op, DAG);
+ case ISD::FP_TO_BF16: return LowerFP_TO_BF16(Op, DAG);
case ISD::LOAD: return LowerLoad(Op, Subtarget, DAG);
case ISD::STORE: return LowerStore(Op, Subtarget, DAG);
case ISD::FADD: