aboutsummaryrefslogtreecommitdiff
path: root/contrib/llvm-project/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'contrib/llvm-project/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp')
-rw-r--r--contrib/llvm-project/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp330
1 files changed, 330 insertions, 0 deletions
diff --git a/contrib/llvm-project/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp b/contrib/llvm-project/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp
index 4c35649cec6c..4f7bf3f7d35e 100644
--- a/contrib/llvm-project/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp
+++ b/contrib/llvm-project/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp
@@ -3048,6 +3048,336 @@ bool AMDGPUDAGToDAGISel::SelectWMMAOpSelVOP3PMods(SDValue In,
return true;
}
+static MachineSDNode *buildRegSequence32(SmallVectorImpl<SDValue> &Elts,
+ llvm::SelectionDAG *CurDAG,
+ const SDLoc &DL) {
+ unsigned DstRegClass;
+ EVT DstTy;
+ switch (Elts.size()) {
+ case 8:
+ DstRegClass = AMDGPU::VReg_256RegClassID;
+ DstTy = MVT::v8i32;
+ break;
+ case 4:
+ DstRegClass = AMDGPU::VReg_128RegClassID;
+ DstTy = MVT::v4i32;
+ break;
+ case 2:
+ DstRegClass = AMDGPU::VReg_64RegClassID;
+ DstTy = MVT::v2i32;
+ break;
+ default:
+ llvm_unreachable("unhandled Reg sequence size");
+ }
+
+ SmallVector<SDValue, 17> Ops;
+ Ops.push_back(CurDAG->getTargetConstant(DstRegClass, DL, MVT::i32));
+ for (unsigned i = 0; i < Elts.size(); ++i) {
+ Ops.push_back(Elts[i]);
+ Ops.push_back(CurDAG->getTargetConstant(
+ SIRegisterInfo::getSubRegFromChannel(i), DL, MVT::i32));
+ }
+ return CurDAG->getMachineNode(TargetOpcode::REG_SEQUENCE, DL, DstTy, Ops);
+}
+
+static MachineSDNode *buildRegSequence16(SmallVectorImpl<SDValue> &Elts,
+ llvm::SelectionDAG *CurDAG,
+ const SDLoc &DL) {
+ SmallVector<SDValue, 8> PackedElts;
+ assert("unhandled Reg sequence size" &&
+ (Elts.size() == 8 || Elts.size() == 16));
+
+ // Pack 16-bit elements in pairs into 32-bit register. If both elements are
+ // unpacked from 32-bit source use it, otherwise pack them using v_perm.
+ for (unsigned i = 0; i < Elts.size(); i += 2) {
+ SDValue LoSrc = stripExtractLoElt(stripBitcast(Elts[i]));
+ SDValue HiSrc;
+ if (isExtractHiElt(Elts[i + 1], HiSrc) && LoSrc == HiSrc) {
+ PackedElts.push_back(HiSrc);
+ } else {
+ SDValue PackLoLo = CurDAG->getTargetConstant(0x05040100, DL, MVT::i32);
+ MachineSDNode *Packed =
+ CurDAG->getMachineNode(AMDGPU::V_PERM_B32_e64, DL, MVT::i32,
+ {Elts[i + 1], Elts[i], PackLoLo});
+ PackedElts.push_back(SDValue(Packed, 0));
+ }
+ }
+
+ return buildRegSequence32(PackedElts, CurDAG, DL);
+}
+
+static MachineSDNode *buildRegSequence(SmallVectorImpl<SDValue> &Elts,
+ llvm::SelectionDAG *CurDAG,
+ const SDLoc &DL, unsigned ElementSize) {
+ if (ElementSize == 16)
+ return buildRegSequence16(Elts, CurDAG, DL);
+ if (ElementSize == 32)
+ return buildRegSequence32(Elts, CurDAG, DL);
+ llvm_unreachable("Unhandled element size");
+}
+
+static void selectWMMAModsNegAbs(unsigned ModOpcode, unsigned &Mods,
+ SmallVectorImpl<SDValue> &Elts, SDValue &Src,
+ llvm::SelectionDAG *CurDAG, const SDLoc &DL,
+ unsigned ElementSize) {
+ if (ModOpcode == ISD::FNEG) {
+ Mods |= SISrcMods::NEG;
+ // Check if all elements also have abs modifier
+ SmallVector<SDValue, 8> NegAbsElts;
+ for (auto El : Elts) {
+ if (El.getOpcode() != ISD::FABS)
+ break;
+ NegAbsElts.push_back(El->getOperand(0));
+ }
+ if (Elts.size() != NegAbsElts.size()) {
+ // Neg
+ Src = SDValue(buildRegSequence(Elts, CurDAG, DL, ElementSize), 0);
+ } else {
+ // Neg and Abs
+ Mods |= SISrcMods::NEG_HI;
+ Src = SDValue(buildRegSequence(NegAbsElts, CurDAG, DL, ElementSize), 0);
+ }
+ } else {
+ assert(ModOpcode == ISD::FABS);
+ // Abs
+ Mods |= SISrcMods::NEG_HI;
+ Src = SDValue(buildRegSequence(Elts, CurDAG, DL, ElementSize), 0);
+ }
+}
+
+// Check all f16 elements for modifiers while looking through b32 and v2b16
+// build vector, stop if element does not satisfy ModifierCheck.
+static void
+checkWMMAElementsModifiersF16(BuildVectorSDNode *BV,
+ std::function<bool(SDValue)> ModifierCheck) {
+ for (unsigned i = 0; i < BV->getNumOperands(); ++i) {
+ if (auto *F16Pair =
+ dyn_cast<BuildVectorSDNode>(stripBitcast(BV->getOperand(i)))) {
+ for (unsigned i = 0; i < F16Pair->getNumOperands(); ++i) {
+ SDValue ElF16 = stripBitcast(F16Pair->getOperand(i));
+ if (!ModifierCheck(ElF16))
+ break;
+ }
+ }
+ }
+}
+
+bool AMDGPUDAGToDAGISel::SelectWMMAModsF16Neg(SDValue In, SDValue &Src,
+ SDValue &SrcMods) const {
+ Src = In;
+ unsigned Mods = SISrcMods::OP_SEL_1;
+
+ // mods are on f16 elements
+ if (auto *BV = dyn_cast<BuildVectorSDNode>(stripBitcast(In))) {
+ SmallVector<SDValue, 8> EltsF16;
+
+ checkWMMAElementsModifiersF16(BV, [&](SDValue Element) -> bool {
+ if (Element.getOpcode() != ISD::FNEG)
+ return false;
+ EltsF16.push_back(Element.getOperand(0));
+ return true;
+ });
+
+ // All elements have neg modifier
+ if (BV->getNumOperands() * 2 == EltsF16.size()) {
+ Src = SDValue(buildRegSequence16(EltsF16, CurDAG, SDLoc(In)), 0);
+ Mods |= SISrcMods::NEG;
+ Mods |= SISrcMods::NEG_HI;
+ }
+ }
+
+ // mods are on v2f16 elements
+ if (auto *BV = dyn_cast<BuildVectorSDNode>(stripBitcast(In))) {
+ SmallVector<SDValue, 8> EltsV2F16;
+ for (unsigned i = 0; i < BV->getNumOperands(); ++i) {
+ SDValue ElV2f16 = stripBitcast(BV->getOperand(i));
+ // Based on first element decide which mod we match, neg or abs
+ if (ElV2f16.getOpcode() != ISD::FNEG)
+ break;
+ EltsV2F16.push_back(ElV2f16.getOperand(0));
+ }
+
+ // All pairs of elements have neg modifier
+ if (BV->getNumOperands() == EltsV2F16.size()) {
+ Src = SDValue(buildRegSequence32(EltsV2F16, CurDAG, SDLoc(In)), 0);
+ Mods |= SISrcMods::NEG;
+ Mods |= SISrcMods::NEG_HI;
+ }
+ }
+
+ SrcMods = CurDAG->getTargetConstant(Mods, SDLoc(In), MVT::i32);
+ return true;
+}
+
+bool AMDGPUDAGToDAGISel::SelectWMMAModsF16NegAbs(SDValue In, SDValue &Src,
+ SDValue &SrcMods) const {
+ Src = In;
+ unsigned Mods = SISrcMods::OP_SEL_1;
+ unsigned ModOpcode;
+
+ // mods are on f16 elements
+ if (auto *BV = dyn_cast<BuildVectorSDNode>(stripBitcast(In))) {
+ SmallVector<SDValue, 8> EltsF16;
+ checkWMMAElementsModifiersF16(BV, [&](SDValue ElF16) -> bool {
+ // Based on first element decide which mod we match, neg or abs
+ if (EltsF16.empty())
+ ModOpcode = (ElF16.getOpcode() == ISD::FNEG) ? ISD::FNEG : ISD::FABS;
+ if (ElF16.getOpcode() != ModOpcode)
+ return false;
+ EltsF16.push_back(ElF16.getOperand(0));
+ return true;
+ });
+
+ // All elements have ModOpcode modifier
+ if (BV->getNumOperands() * 2 == EltsF16.size())
+ selectWMMAModsNegAbs(ModOpcode, Mods, EltsF16, Src, CurDAG, SDLoc(In),
+ 16);
+ }
+
+ // mods are on v2f16 elements
+ if (auto *BV = dyn_cast<BuildVectorSDNode>(stripBitcast(In))) {
+ SmallVector<SDValue, 8> EltsV2F16;
+
+ for (unsigned i = 0; i < BV->getNumOperands(); ++i) {
+ SDValue ElV2f16 = stripBitcast(BV->getOperand(i));
+ // Based on first element decide which mod we match, neg or abs
+ if (EltsV2F16.empty())
+ ModOpcode = (ElV2f16.getOpcode() == ISD::FNEG) ? ISD::FNEG : ISD::FABS;
+ if (ElV2f16->getOpcode() != ModOpcode)
+ break;
+ EltsV2F16.push_back(ElV2f16->getOperand(0));
+ }
+
+ // All elements have ModOpcode modifier
+ if (BV->getNumOperands() == EltsV2F16.size())
+ selectWMMAModsNegAbs(ModOpcode, Mods, EltsV2F16, Src, CurDAG, SDLoc(In),
+ 32);
+ }
+
+ SrcMods = CurDAG->getTargetConstant(Mods, SDLoc(In), MVT::i32);
+ return true;
+}
+
+bool AMDGPUDAGToDAGISel::SelectWMMAModsF32NegAbs(SDValue In, SDValue &Src,
+ SDValue &SrcMods) const {
+ Src = In;
+ unsigned Mods = SISrcMods::OP_SEL_1;
+ unsigned ModOpcode;
+ SmallVector<SDValue, 8> EltsF32;
+
+ if (auto *BV = dyn_cast<BuildVectorSDNode>(stripBitcast(In))) {
+ for (unsigned i = 0; i < BV->getNumOperands(); ++i) {
+ SDValue ElF32 = stripBitcast(BV->getOperand(i));
+ // Based on first element decide which mod we match, neg or abs
+ if (EltsF32.empty())
+ ModOpcode = (ElF32.getOpcode() == ISD::FNEG) ? ISD::FNEG : ISD::FABS;
+ if (ElF32.getOpcode() != ModOpcode)
+ break;
+ EltsF32.push_back(ElF32.getOperand(0));
+ }
+
+ // All elements had ModOpcode modifier
+ if (BV->getNumOperands() == EltsF32.size())
+ selectWMMAModsNegAbs(ModOpcode, Mods, EltsF32, Src, CurDAG, SDLoc(In),
+ 32);
+ }
+
+ SrcMods = CurDAG->getTargetConstant(Mods, SDLoc(In), MVT::i32);
+ return true;
+}
+
+bool AMDGPUDAGToDAGISel::SelectWMMAVISrc(SDValue In, SDValue &Src) const {
+ if (auto *BV = dyn_cast<BuildVectorSDNode>(In)) {
+ BitVector UndefElements;
+ if (SDValue Splat = BV->getSplatValue(&UndefElements))
+ if (isInlineImmediate(Splat.getNode())) {
+ if (const ConstantSDNode *C = dyn_cast<ConstantSDNode>(Splat)) {
+ unsigned Imm = C->getAPIntValue().getSExtValue();
+ Src = CurDAG->getTargetConstant(Imm, SDLoc(In), MVT::i32);
+ return true;
+ }
+ if (const ConstantFPSDNode *C = dyn_cast<ConstantFPSDNode>(Splat)) {
+ unsigned Imm = C->getValueAPF().bitcastToAPInt().getSExtValue();
+ Src = CurDAG->getTargetConstant(Imm, SDLoc(In), MVT::i32);
+ return true;
+ }
+ llvm_unreachable("unhandled Constant node");
+ }
+ }
+
+ // 16 bit splat
+ SDValue SplatSrc32 = stripBitcast(In);
+ if (auto *SplatSrc32BV = dyn_cast<BuildVectorSDNode>(SplatSrc32)) {
+ if (SDValue Splat32 = SplatSrc32BV->getSplatValue()) {
+ SDValue SplatSrc16 = stripBitcast(Splat32);
+ if (auto *SplatSrc16BV = dyn_cast<BuildVectorSDNode>(SplatSrc16)) {
+ if (SDValue Splat = SplatSrc16BV->getSplatValue()) {
+
+ // f16
+ if (isInlineImmediate(Splat.getNode())) {
+ const ConstantFPSDNode *C = dyn_cast<ConstantFPSDNode>(Splat);
+ int64_t Imm = C->getValueAPF().bitcastToAPInt().getSExtValue();
+ Src = CurDAG->getTargetConstant(Imm, SDLoc(In), MVT::i16);
+ return true;
+ }
+
+ // bf16
+ if (const ConstantSDNode *C = dyn_cast<ConstantSDNode>(Splat)) {
+ const SIInstrInfo *TII = Subtarget->getInstrInfo();
+ APInt BF16Value = C->getAPIntValue();
+ APInt F32Value = BF16Value.zext(32).shl(16);
+ if (TII->isInlineConstant(F32Value)) {
+ int64_t Imm = F32Value.getSExtValue();
+ Src = CurDAG->getTargetConstant(Imm, SDLoc(In), MVT::i32);
+ return true;
+ }
+ }
+ }
+ }
+ }
+ }
+
+ return false;
+}
+
+bool AMDGPUDAGToDAGISel::SelectSWMMACIndex8(SDValue In, SDValue &Src,
+ SDValue &IndexKey) const {
+ unsigned Key = 0;
+ Src = In;
+
+ if (In.getOpcode() == ISD::SRL) {
+ const llvm::SDValue &ShiftSrc = In.getOperand(0);
+ ConstantSDNode *ShiftAmt = dyn_cast<ConstantSDNode>(In.getOperand(1));
+ if (ShiftSrc.getValueType().getSizeInBits() == 32 && ShiftAmt &&
+ ShiftAmt->getZExtValue() % 8 == 0) {
+ Key = ShiftAmt->getZExtValue() / 8;
+ Src = ShiftSrc;
+ }
+ }
+
+ IndexKey = CurDAG->getTargetConstant(Key, SDLoc(In), MVT::i32);
+ return true;
+}
+
+bool AMDGPUDAGToDAGISel::SelectSWMMACIndex16(SDValue In, SDValue &Src,
+ SDValue &IndexKey) const {
+ unsigned Key = 0;
+ Src = In;
+
+ if (In.getOpcode() == ISD::SRL) {
+ const llvm::SDValue &ShiftSrc = In.getOperand(0);
+ ConstantSDNode *ShiftAmt = dyn_cast<ConstantSDNode>(In.getOperand(1));
+ if (ShiftSrc.getValueType().getSizeInBits() == 32 && ShiftAmt &&
+ ShiftAmt->getZExtValue() == 16) {
+ Key = 1;
+ Src = ShiftSrc;
+ }
+ }
+
+ IndexKey = CurDAG->getTargetConstant(Key, SDLoc(In), MVT::i32);
+ return true;
+}
+
bool AMDGPUDAGToDAGISel::SelectVOP3OpSel(SDValue In, SDValue &Src,
SDValue &SrcMods) const {
Src = In;