diff options
Diffstat (limited to 'llvm/lib/Target/X86/X86ISelLowering.cpp')
| -rw-r--r-- | llvm/lib/Target/X86/X86ISelLowering.cpp | 96 |
1 files changed, 56 insertions, 40 deletions
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index b080ab7e138c..7d0fc4e8a8c6 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -419,7 +419,7 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, setTruncStoreAction(VT, MVT::bf16, Expand); setOperationAction(ISD::BF16_TO_FP, VT, Expand); - setOperationAction(ISD::FP_TO_BF16, VT, Expand); + setOperationAction(ISD::FP_TO_BF16, VT, Custom); } setOperationAction(ISD::PARITY, MVT::i8, Custom); @@ -2494,6 +2494,10 @@ MVT X86TargetLowering::getRegisterTypeForCallingConv(LLVMContext &Context, !Subtarget.hasX87()) return MVT::i32; + if (VT.isVector() && VT.getVectorElementType() == MVT::bf16) + return getRegisterTypeForCallingConv(Context, CC, + VT.changeVectorElementTypeToInteger()); + return TargetLowering::getRegisterTypeForCallingConv(Context, CC, VT); } @@ -2525,6 +2529,10 @@ unsigned X86TargetLowering::getNumRegistersForCallingConv(LLVMContext &Context, return 3; } + if (VT.isVector() && VT.getVectorElementType() == MVT::bf16) + return getNumRegistersForCallingConv(Context, CC, + VT.changeVectorElementTypeToInteger()); + return TargetLowering::getNumRegistersForCallingConv(Context, CC, VT); } @@ -2733,6 +2741,40 @@ unsigned X86TargetLowering::getJumpTableEncoding() const { return TargetLowering::getJumpTableEncoding(); } +bool X86TargetLowering::splitValueIntoRegisterParts( + SelectionDAG &DAG, const SDLoc &DL, SDValue Val, SDValue *Parts, + unsigned NumParts, MVT PartVT, Optional<CallingConv::ID> CC) const { + bool IsABIRegCopy = CC.has_value(); + EVT ValueVT = Val.getValueType(); + if (IsABIRegCopy && ValueVT == MVT::bf16 && PartVT == MVT::f32) { + unsigned ValueBits = ValueVT.getSizeInBits(); + unsigned PartBits = PartVT.getSizeInBits(); + Val = DAG.getNode(ISD::BITCAST, DL, MVT::getIntegerVT(ValueBits), Val); + Val = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::getIntegerVT(PartBits), Val); + Val = DAG.getNode(ISD::BITCAST, DL, PartVT, Val); + Parts[0] = Val; + return true; + } + return false; +} + +SDValue X86TargetLowering::joinRegisterPartsIntoValue( + SelectionDAG &DAG, const SDLoc &DL, const SDValue *Parts, unsigned NumParts, + MVT PartVT, EVT ValueVT, Optional<CallingConv::ID> CC) const { + bool IsABIRegCopy = CC.has_value(); + if (IsABIRegCopy && ValueVT == MVT::bf16 && PartVT == MVT::f32) { + unsigned ValueBits = ValueVT.getSizeInBits(); + unsigned PartBits = PartVT.getSizeInBits(); + SDValue Val = Parts[0]; + + Val = DAG.getNode(ISD::BITCAST, DL, MVT::getIntegerVT(PartBits), Val); + Val = DAG.getNode(ISD::TRUNCATE, DL, MVT::getIntegerVT(ValueBits), Val); + Val = DAG.getNode(ISD::BITCAST, DL, ValueVT, Val); + return Val; + } + return SDValue(); +} + bool X86TargetLowering::useSoftFloat() const { return Subtarget.useSoftFloat(); } @@ -19304,44 +19346,6 @@ static bool canonicalizeShuffleMaskWithCommute(ArrayRef<int> Mask) { return false; } -static bool canCombineAsMaskOperation(SDValue V1, SDValue V2, - const X86Subtarget &Subtarget) { - if (!Subtarget.hasAVX512()) - return false; - - MVT VT = V1.getSimpleValueType().getScalarType(); - if ((VT == MVT::i16 || VT == MVT::i8) && !Subtarget.hasBWI()) - return false; - - // i8 is better to be widen to i16, because there is PBLENDW for vXi16 - // when the vector bit size is 128 or 256. - if (VT == MVT::i8 && V1.getSimpleValueType().getSizeInBits() < 512) - return false; - - auto HasMaskOperation = [&](SDValue V) { - // TODO: Currently we only check limited opcode. We probably extend - // it to all binary operation by checking TLI.isBinOp(). - switch (V->getOpcode()) { - default: - return false; - case ISD::ADD: - case ISD::SUB: - case ISD::AND: - case ISD::XOR: - break; - } - if (!V->hasOneUse()) - return false; - - return true; - }; - - if (HasMaskOperation(V1) || HasMaskOperation(V2)) - return true; - - return false; -} - // Forward declaration. static SDValue canonicalizeShuffleMaskWithHorizOp( MutableArrayRef<SDValue> Ops, MutableArrayRef<int> Mask, @@ -19417,7 +19421,6 @@ static SDValue lowerVECTOR_SHUFFLE(SDValue Op, const X86Subtarget &Subtarget, // integers to handle flipping the low and high halves of AVX 256-bit vectors. SmallVector<int, 16> WidenedMask; if (VT.getScalarSizeInBits() < 64 && !Is1BitVector && - !canCombineAsMaskOperation(V1, V2, Subtarget) && canWidenShuffleElements(OrigMask, Zeroable, V2IsZero, WidenedMask)) { // Shuffle mask widening should not interfere with a broadcast opportunity // by obfuscating the operands with bitcasts. @@ -23058,6 +23061,18 @@ static SDValue LowerFP_TO_FP16(SDValue Op, SelectionDAG &DAG) { return Res; } +SDValue X86TargetLowering::LowerFP_TO_BF16(SDValue Op, + SelectionDAG &DAG) const { + SDLoc DL(Op); + MakeLibCallOptions CallOptions; + RTLIB::Libcall LC = + RTLIB::getFPROUND(Op.getOperand(0).getValueType(), MVT::bf16); + SDValue Res = + makeLibCall(DAG, LC, MVT::f32, Op.getOperand(0), CallOptions, DL).first; + return DAG.getNode(ISD::TRUNCATE, DL, MVT::i16, + DAG.getBitcast(MVT::i32, Res)); +} + /// Depending on uarch and/or optimizing for size, we might prefer to use a /// vector operation in place of the typical scalar operation. static SDValue lowerAddSubToHorizontalOp(SDValue Op, SelectionDAG &DAG, @@ -32250,6 +32265,7 @@ SDValue X86TargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const { case ISD::STRICT_FP16_TO_FP: return LowerFP16_TO_FP(Op, DAG); case ISD::FP_TO_FP16: case ISD::STRICT_FP_TO_FP16: return LowerFP_TO_FP16(Op, DAG); + case ISD::FP_TO_BF16: return LowerFP_TO_BF16(Op, DAG); case ISD::LOAD: return LowerLoad(Op, Subtarget, DAG); case ISD::STORE: return LowerStore(Op, Subtarget, DAG); case ISD::FADD: |
