diff options
Diffstat (limited to 'contrib/llvm-project/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp')
| -rw-r--r-- | contrib/llvm-project/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp | 330 |
1 files changed, 330 insertions, 0 deletions
diff --git a/contrib/llvm-project/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp b/contrib/llvm-project/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp index 4c35649cec6c..4f7bf3f7d35e 100644 --- a/contrib/llvm-project/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp +++ b/contrib/llvm-project/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp @@ -3048,6 +3048,336 @@ bool AMDGPUDAGToDAGISel::SelectWMMAOpSelVOP3PMods(SDValue In, return true; } +static MachineSDNode *buildRegSequence32(SmallVectorImpl<SDValue> &Elts, + llvm::SelectionDAG *CurDAG, + const SDLoc &DL) { + unsigned DstRegClass; + EVT DstTy; + switch (Elts.size()) { + case 8: + DstRegClass = AMDGPU::VReg_256RegClassID; + DstTy = MVT::v8i32; + break; + case 4: + DstRegClass = AMDGPU::VReg_128RegClassID; + DstTy = MVT::v4i32; + break; + case 2: + DstRegClass = AMDGPU::VReg_64RegClassID; + DstTy = MVT::v2i32; + break; + default: + llvm_unreachable("unhandled Reg sequence size"); + } + + SmallVector<SDValue, 17> Ops; + Ops.push_back(CurDAG->getTargetConstant(DstRegClass, DL, MVT::i32)); + for (unsigned i = 0; i < Elts.size(); ++i) { + Ops.push_back(Elts[i]); + Ops.push_back(CurDAG->getTargetConstant( + SIRegisterInfo::getSubRegFromChannel(i), DL, MVT::i32)); + } + return CurDAG->getMachineNode(TargetOpcode::REG_SEQUENCE, DL, DstTy, Ops); +} + +static MachineSDNode *buildRegSequence16(SmallVectorImpl<SDValue> &Elts, + llvm::SelectionDAG *CurDAG, + const SDLoc &DL) { + SmallVector<SDValue, 8> PackedElts; + assert("unhandled Reg sequence size" && + (Elts.size() == 8 || Elts.size() == 16)); + + // Pack 16-bit elements in pairs into 32-bit register. If both elements are + // unpacked from 32-bit source use it, otherwise pack them using v_perm. + for (unsigned i = 0; i < Elts.size(); i += 2) { + SDValue LoSrc = stripExtractLoElt(stripBitcast(Elts[i])); + SDValue HiSrc; + if (isExtractHiElt(Elts[i + 1], HiSrc) && LoSrc == HiSrc) { + PackedElts.push_back(HiSrc); + } else { + SDValue PackLoLo = CurDAG->getTargetConstant(0x05040100, DL, MVT::i32); + MachineSDNode *Packed = + CurDAG->getMachineNode(AMDGPU::V_PERM_B32_e64, DL, MVT::i32, + {Elts[i + 1], Elts[i], PackLoLo}); + PackedElts.push_back(SDValue(Packed, 0)); + } + } + + return buildRegSequence32(PackedElts, CurDAG, DL); +} + +static MachineSDNode *buildRegSequence(SmallVectorImpl<SDValue> &Elts, + llvm::SelectionDAG *CurDAG, + const SDLoc &DL, unsigned ElementSize) { + if (ElementSize == 16) + return buildRegSequence16(Elts, CurDAG, DL); + if (ElementSize == 32) + return buildRegSequence32(Elts, CurDAG, DL); + llvm_unreachable("Unhandled element size"); +} + +static void selectWMMAModsNegAbs(unsigned ModOpcode, unsigned &Mods, + SmallVectorImpl<SDValue> &Elts, SDValue &Src, + llvm::SelectionDAG *CurDAG, const SDLoc &DL, + unsigned ElementSize) { + if (ModOpcode == ISD::FNEG) { + Mods |= SISrcMods::NEG; + // Check if all elements also have abs modifier + SmallVector<SDValue, 8> NegAbsElts; + for (auto El : Elts) { + if (El.getOpcode() != ISD::FABS) + break; + NegAbsElts.push_back(El->getOperand(0)); + } + if (Elts.size() != NegAbsElts.size()) { + // Neg + Src = SDValue(buildRegSequence(Elts, CurDAG, DL, ElementSize), 0); + } else { + // Neg and Abs + Mods |= SISrcMods::NEG_HI; + Src = SDValue(buildRegSequence(NegAbsElts, CurDAG, DL, ElementSize), 0); + } + } else { + assert(ModOpcode == ISD::FABS); + // Abs + Mods |= SISrcMods::NEG_HI; + Src = SDValue(buildRegSequence(Elts, CurDAG, DL, ElementSize), 0); + } +} + +// Check all f16 elements for modifiers while looking through b32 and v2b16 +// build vector, stop if element does not satisfy ModifierCheck. +static void +checkWMMAElementsModifiersF16(BuildVectorSDNode *BV, + std::function<bool(SDValue)> ModifierCheck) { + for (unsigned i = 0; i < BV->getNumOperands(); ++i) { + if (auto *F16Pair = + dyn_cast<BuildVectorSDNode>(stripBitcast(BV->getOperand(i)))) { + for (unsigned i = 0; i < F16Pair->getNumOperands(); ++i) { + SDValue ElF16 = stripBitcast(F16Pair->getOperand(i)); + if (!ModifierCheck(ElF16)) + break; + } + } + } +} + +bool AMDGPUDAGToDAGISel::SelectWMMAModsF16Neg(SDValue In, SDValue &Src, + SDValue &SrcMods) const { + Src = In; + unsigned Mods = SISrcMods::OP_SEL_1; + + // mods are on f16 elements + if (auto *BV = dyn_cast<BuildVectorSDNode>(stripBitcast(In))) { + SmallVector<SDValue, 8> EltsF16; + + checkWMMAElementsModifiersF16(BV, [&](SDValue Element) -> bool { + if (Element.getOpcode() != ISD::FNEG) + return false; + EltsF16.push_back(Element.getOperand(0)); + return true; + }); + + // All elements have neg modifier + if (BV->getNumOperands() * 2 == EltsF16.size()) { + Src = SDValue(buildRegSequence16(EltsF16, CurDAG, SDLoc(In)), 0); + Mods |= SISrcMods::NEG; + Mods |= SISrcMods::NEG_HI; + } + } + + // mods are on v2f16 elements + if (auto *BV = dyn_cast<BuildVectorSDNode>(stripBitcast(In))) { + SmallVector<SDValue, 8> EltsV2F16; + for (unsigned i = 0; i < BV->getNumOperands(); ++i) { + SDValue ElV2f16 = stripBitcast(BV->getOperand(i)); + // Based on first element decide which mod we match, neg or abs + if (ElV2f16.getOpcode() != ISD::FNEG) + break; + EltsV2F16.push_back(ElV2f16.getOperand(0)); + } + + // All pairs of elements have neg modifier + if (BV->getNumOperands() == EltsV2F16.size()) { + Src = SDValue(buildRegSequence32(EltsV2F16, CurDAG, SDLoc(In)), 0); + Mods |= SISrcMods::NEG; + Mods |= SISrcMods::NEG_HI; + } + } + + SrcMods = CurDAG->getTargetConstant(Mods, SDLoc(In), MVT::i32); + return true; +} + +bool AMDGPUDAGToDAGISel::SelectWMMAModsF16NegAbs(SDValue In, SDValue &Src, + SDValue &SrcMods) const { + Src = In; + unsigned Mods = SISrcMods::OP_SEL_1; + unsigned ModOpcode; + + // mods are on f16 elements + if (auto *BV = dyn_cast<BuildVectorSDNode>(stripBitcast(In))) { + SmallVector<SDValue, 8> EltsF16; + checkWMMAElementsModifiersF16(BV, [&](SDValue ElF16) -> bool { + // Based on first element decide which mod we match, neg or abs + if (EltsF16.empty()) + ModOpcode = (ElF16.getOpcode() == ISD::FNEG) ? ISD::FNEG : ISD::FABS; + if (ElF16.getOpcode() != ModOpcode) + return false; + EltsF16.push_back(ElF16.getOperand(0)); + return true; + }); + + // All elements have ModOpcode modifier + if (BV->getNumOperands() * 2 == EltsF16.size()) + selectWMMAModsNegAbs(ModOpcode, Mods, EltsF16, Src, CurDAG, SDLoc(In), + 16); + } + + // mods are on v2f16 elements + if (auto *BV = dyn_cast<BuildVectorSDNode>(stripBitcast(In))) { + SmallVector<SDValue, 8> EltsV2F16; + + for (unsigned i = 0; i < BV->getNumOperands(); ++i) { + SDValue ElV2f16 = stripBitcast(BV->getOperand(i)); + // Based on first element decide which mod we match, neg or abs + if (EltsV2F16.empty()) + ModOpcode = (ElV2f16.getOpcode() == ISD::FNEG) ? ISD::FNEG : ISD::FABS; + if (ElV2f16->getOpcode() != ModOpcode) + break; + EltsV2F16.push_back(ElV2f16->getOperand(0)); + } + + // All elements have ModOpcode modifier + if (BV->getNumOperands() == EltsV2F16.size()) + selectWMMAModsNegAbs(ModOpcode, Mods, EltsV2F16, Src, CurDAG, SDLoc(In), + 32); + } + + SrcMods = CurDAG->getTargetConstant(Mods, SDLoc(In), MVT::i32); + return true; +} + +bool AMDGPUDAGToDAGISel::SelectWMMAModsF32NegAbs(SDValue In, SDValue &Src, + SDValue &SrcMods) const { + Src = In; + unsigned Mods = SISrcMods::OP_SEL_1; + unsigned ModOpcode; + SmallVector<SDValue, 8> EltsF32; + + if (auto *BV = dyn_cast<BuildVectorSDNode>(stripBitcast(In))) { + for (unsigned i = 0; i < BV->getNumOperands(); ++i) { + SDValue ElF32 = stripBitcast(BV->getOperand(i)); + // Based on first element decide which mod we match, neg or abs + if (EltsF32.empty()) + ModOpcode = (ElF32.getOpcode() == ISD::FNEG) ? ISD::FNEG : ISD::FABS; + if (ElF32.getOpcode() != ModOpcode) + break; + EltsF32.push_back(ElF32.getOperand(0)); + } + + // All elements had ModOpcode modifier + if (BV->getNumOperands() == EltsF32.size()) + selectWMMAModsNegAbs(ModOpcode, Mods, EltsF32, Src, CurDAG, SDLoc(In), + 32); + } + + SrcMods = CurDAG->getTargetConstant(Mods, SDLoc(In), MVT::i32); + return true; +} + +bool AMDGPUDAGToDAGISel::SelectWMMAVISrc(SDValue In, SDValue &Src) const { + if (auto *BV = dyn_cast<BuildVectorSDNode>(In)) { + BitVector UndefElements; + if (SDValue Splat = BV->getSplatValue(&UndefElements)) + if (isInlineImmediate(Splat.getNode())) { + if (const ConstantSDNode *C = dyn_cast<ConstantSDNode>(Splat)) { + unsigned Imm = C->getAPIntValue().getSExtValue(); + Src = CurDAG->getTargetConstant(Imm, SDLoc(In), MVT::i32); + return true; + } + if (const ConstantFPSDNode *C = dyn_cast<ConstantFPSDNode>(Splat)) { + unsigned Imm = C->getValueAPF().bitcastToAPInt().getSExtValue(); + Src = CurDAG->getTargetConstant(Imm, SDLoc(In), MVT::i32); + return true; + } + llvm_unreachable("unhandled Constant node"); + } + } + + // 16 bit splat + SDValue SplatSrc32 = stripBitcast(In); + if (auto *SplatSrc32BV = dyn_cast<BuildVectorSDNode>(SplatSrc32)) { + if (SDValue Splat32 = SplatSrc32BV->getSplatValue()) { + SDValue SplatSrc16 = stripBitcast(Splat32); + if (auto *SplatSrc16BV = dyn_cast<BuildVectorSDNode>(SplatSrc16)) { + if (SDValue Splat = SplatSrc16BV->getSplatValue()) { + + // f16 + if (isInlineImmediate(Splat.getNode())) { + const ConstantFPSDNode *C = dyn_cast<ConstantFPSDNode>(Splat); + int64_t Imm = C->getValueAPF().bitcastToAPInt().getSExtValue(); + Src = CurDAG->getTargetConstant(Imm, SDLoc(In), MVT::i16); + return true; + } + + // bf16 + if (const ConstantSDNode *C = dyn_cast<ConstantSDNode>(Splat)) { + const SIInstrInfo *TII = Subtarget->getInstrInfo(); + APInt BF16Value = C->getAPIntValue(); + APInt F32Value = BF16Value.zext(32).shl(16); + if (TII->isInlineConstant(F32Value)) { + int64_t Imm = F32Value.getSExtValue(); + Src = CurDAG->getTargetConstant(Imm, SDLoc(In), MVT::i32); + return true; + } + } + } + } + } + } + + return false; +} + +bool AMDGPUDAGToDAGISel::SelectSWMMACIndex8(SDValue In, SDValue &Src, + SDValue &IndexKey) const { + unsigned Key = 0; + Src = In; + + if (In.getOpcode() == ISD::SRL) { + const llvm::SDValue &ShiftSrc = In.getOperand(0); + ConstantSDNode *ShiftAmt = dyn_cast<ConstantSDNode>(In.getOperand(1)); + if (ShiftSrc.getValueType().getSizeInBits() == 32 && ShiftAmt && + ShiftAmt->getZExtValue() % 8 == 0) { + Key = ShiftAmt->getZExtValue() / 8; + Src = ShiftSrc; + } + } + + IndexKey = CurDAG->getTargetConstant(Key, SDLoc(In), MVT::i32); + return true; +} + +bool AMDGPUDAGToDAGISel::SelectSWMMACIndex16(SDValue In, SDValue &Src, + SDValue &IndexKey) const { + unsigned Key = 0; + Src = In; + + if (In.getOpcode() == ISD::SRL) { + const llvm::SDValue &ShiftSrc = In.getOperand(0); + ConstantSDNode *ShiftAmt = dyn_cast<ConstantSDNode>(In.getOperand(1)); + if (ShiftSrc.getValueType().getSizeInBits() == 32 && ShiftAmt && + ShiftAmt->getZExtValue() == 16) { + Key = 1; + Src = ShiftSrc; + } + } + + IndexKey = CurDAG->getTargetConstant(Key, SDLoc(In), MVT::i32); + return true; +} + bool AMDGPUDAGToDAGISel::SelectVOP3OpSel(SDValue In, SDValue &Src, SDValue &SrcMods) const { Src = In; |
