aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Target/DirectX/DXILResource.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Target/DirectX/DXILResource.cpp')
-rw-r--r--llvm/lib/Target/DirectX/DXILResource.cpp420
1 files changed, 420 insertions, 0 deletions
diff --git a/llvm/lib/Target/DirectX/DXILResource.cpp b/llvm/lib/Target/DirectX/DXILResource.cpp
new file mode 100644
index 000000000000..763432911dbf
--- /dev/null
+++ b/llvm/lib/Target/DirectX/DXILResource.cpp
@@ -0,0 +1,420 @@
+//===- DXILResource.cpp - DXIL Resource helper objects --------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file This file contains helper objects for working with DXIL Resources.
+///
+//===----------------------------------------------------------------------===//
+
+#include "DXILResource.h"
+#include "CBufferDataLayout.h"
+#include "llvm/ADT/StringSwitch.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Metadata.h"
+#include "llvm/IR/Module.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/Format.h"
+
+using namespace llvm;
+using namespace llvm::dxil;
+using namespace llvm::hlsl;
+
+template <typename T> void ResourceTable<T>::collect(Module &M) {
+ NamedMDNode *Entry = M.getNamedMetadata(MDName);
+ if (!Entry || Entry->getNumOperands() == 0)
+ return;
+
+ uint32_t Counter = 0;
+ for (auto *Res : Entry->operands()) {
+ Data.push_back(T(Counter++, FrontendResource(cast<MDNode>(Res))));
+ }
+}
+
+template <> void ResourceTable<ConstantBuffer>::collect(Module &M) {
+ NamedMDNode *Entry = M.getNamedMetadata(MDName);
+ if (!Entry || Entry->getNumOperands() == 0)
+ return;
+
+ uint32_t Counter = 0;
+ for (auto *Res : Entry->operands()) {
+ Data.push_back(
+ ConstantBuffer(Counter++, FrontendResource(cast<MDNode>(Res))));
+ }
+ // FIXME: share CBufferDataLayout with CBuffer load lowering.
+ // See https://github.com/llvm/llvm-project/issues/58381
+ CBufferDataLayout CBDL(M.getDataLayout(), /*IsLegacy*/ true);
+ for (auto &CB : Data)
+ CB.setSize(CBDL);
+}
+
+void Resources::collect(Module &M) {
+ UAVs.collect(M);
+ CBuffers.collect(M);
+}
+
+ResourceBase::ResourceBase(uint32_t I, FrontendResource R)
+ : ID(I), GV(R.getGlobalVariable()), Name(""), Space(R.getSpace()),
+ LowerBound(R.getResourceIndex()), RangeSize(1) {
+ if (auto *ArrTy = dyn_cast<ArrayType>(GV->getValueType()))
+ RangeSize = ArrTy->getNumElements();
+}
+
+StringRef ResourceBase::getComponentTypeName(ComponentType CompType) {
+ switch (CompType) {
+ case ComponentType::LastEntry:
+ case ComponentType::Invalid:
+ return "invalid";
+ case ComponentType::I1:
+ return "i1";
+ case ComponentType::I16:
+ return "i16";
+ case ComponentType::U16:
+ return "u16";
+ case ComponentType::I32:
+ return "i32";
+ case ComponentType::U32:
+ return "u32";
+ case ComponentType::I64:
+ return "i64";
+ case ComponentType::U64:
+ return "u64";
+ case ComponentType::F16:
+ return "f16";
+ case ComponentType::F32:
+ return "f32";
+ case ComponentType::F64:
+ return "f64";
+ case ComponentType::SNormF16:
+ return "snorm_f16";
+ case ComponentType::UNormF16:
+ return "unorm_f16";
+ case ComponentType::SNormF32:
+ return "snorm_f32";
+ case ComponentType::UNormF32:
+ return "unorm_f32";
+ case ComponentType::SNormF64:
+ return "snorm_f64";
+ case ComponentType::UNormF64:
+ return "unorm_f64";
+ case ComponentType::PackedS8x32:
+ return "p32i8";
+ case ComponentType::PackedU8x32:
+ return "p32u8";
+ }
+}
+
+void ResourceBase::printComponentType(Kinds Kind, ComponentType CompType,
+ unsigned Alignment, raw_ostream &OS) {
+ switch (Kind) {
+ default:
+ // TODO: add vector size.
+ OS << right_justify(getComponentTypeName(CompType), Alignment);
+ break;
+ case Kinds::RawBuffer:
+ OS << right_justify("byte", Alignment);
+ break;
+ case Kinds::StructuredBuffer:
+ OS << right_justify("struct", Alignment);
+ break;
+ case Kinds::CBuffer:
+ case Kinds::Sampler:
+ OS << right_justify("NA", Alignment);
+ break;
+ case Kinds::Invalid:
+ case Kinds::NumEntries:
+ break;
+ }
+}
+
+StringRef ResourceBase::getKindName(Kinds Kind) {
+ switch (Kind) {
+ case Kinds::NumEntries:
+ case Kinds::Invalid:
+ return "invalid";
+ case Kinds::Texture1D:
+ return "1d";
+ case Kinds::Texture2D:
+ return "2d";
+ case Kinds::Texture2DMS:
+ return "2dMS";
+ case Kinds::Texture3D:
+ return "3d";
+ case Kinds::TextureCube:
+ return "cube";
+ case Kinds::Texture1DArray:
+ return "1darray";
+ case Kinds::Texture2DArray:
+ return "2darray";
+ case Kinds::Texture2DMSArray:
+ return "2darrayMS";
+ case Kinds::TextureCubeArray:
+ return "cubearray";
+ case Kinds::TypedBuffer:
+ return "buf";
+ case Kinds::RawBuffer:
+ return "rawbuf";
+ case Kinds::StructuredBuffer:
+ return "structbuf";
+ case Kinds::CBuffer:
+ return "cbuffer";
+ case Kinds::Sampler:
+ return "sampler";
+ case Kinds::TBuffer:
+ return "tbuffer";
+ case Kinds::RTAccelerationStructure:
+ return "ras";
+ case Kinds::FeedbackTexture2D:
+ return "fbtex2d";
+ case Kinds::FeedbackTexture2DArray:
+ return "fbtex2darray";
+ }
+}
+
+void ResourceBase::printKind(Kinds Kind, unsigned Alignment, raw_ostream &OS,
+ bool SRV, bool HasCounter, uint32_t SampleCount) {
+ switch (Kind) {
+ default:
+ OS << right_justify(getKindName(Kind), Alignment);
+ break;
+
+ case Kinds::RawBuffer:
+ case Kinds::StructuredBuffer:
+ if (SRV)
+ OS << right_justify("r/o", Alignment);
+ else {
+ if (!HasCounter)
+ OS << right_justify("r/w", Alignment);
+ else
+ OS << right_justify("r/w+cnt", Alignment);
+ }
+ break;
+ case Kinds::TypedBuffer:
+ OS << right_justify("buf", Alignment);
+ break;
+ case Kinds::Texture2DMS:
+ case Kinds::Texture2DMSArray: {
+ std::string DimName = getKindName(Kind).str();
+ if (SampleCount)
+ DimName += std::to_string(SampleCount);
+ OS << right_justify(DimName, Alignment);
+ } break;
+ case Kinds::CBuffer:
+ case Kinds::Sampler:
+ OS << right_justify("NA", Alignment);
+ break;
+ case Kinds::Invalid:
+ case Kinds::NumEntries:
+ break;
+ }
+}
+
+void ResourceBase::print(raw_ostream &OS, StringRef IDPrefix,
+ StringRef BindingPrefix) const {
+ std::string ResID = IDPrefix.str();
+ ResID += std::to_string(ID);
+ OS << right_justify(ResID, 8);
+
+ std::string Bind = BindingPrefix.str();
+ Bind += std::to_string(LowerBound);
+ if (Space)
+ Bind += ",space" + std::to_string(Space);
+
+ OS << right_justify(Bind, 15);
+ if (RangeSize != UINT_MAX)
+ OS << right_justify(std::to_string(RangeSize), 6) << "\n";
+ else
+ OS << right_justify("unbounded", 6) << "\n";
+}
+
+UAVResource::UAVResource(uint32_t I, FrontendResource R)
+ : ResourceBase(I, R),
+ Shape(static_cast<ResourceBase::Kinds>(R.getResourceKind())),
+ GloballyCoherent(false), HasCounter(false), IsROV(false), ExtProps() {
+ parseSourceType(R.getSourceType());
+}
+
+void UAVResource::print(raw_ostream &OS) const {
+ OS << "; " << left_justify(Name, 31);
+
+ OS << right_justify("UAV", 10);
+
+ printComponentType(
+ Shape, ExtProps.ElementType.value_or(ComponentType::Invalid), 8, OS);
+
+ // FIXME: support SampleCount.
+ // See https://github.com/llvm/llvm-project/issues/58175
+ printKind(Shape, 12, OS, /*SRV*/ false, HasCounter);
+ // Print the binding part.
+ ResourceBase::print(OS, "U", "u");
+}
+
+// FIXME: Capture this in HLSL source. I would go do this right now, but I want
+// to get this in first so that I can make sure to capture all the extra
+// information we need to remove the source type string from here (See issue:
+// https://github.com/llvm/llvm-project/issues/57991).
+void UAVResource::parseSourceType(StringRef S) {
+ IsROV = S.startswith("RasterizerOrdered");
+ if (IsROV)
+ S = S.substr(strlen("RasterizerOrdered"));
+ if (S.startswith("RW"))
+ S = S.substr(strlen("RW"));
+
+ // Note: I'm deliberately not handling any of the Texture buffer types at the
+ // moment. I want to resolve the issue above before adding Texture or Sampler
+ // support.
+ Shape = StringSwitch<ResourceBase::Kinds>(S)
+ .StartsWith("Buffer<", Kinds::TypedBuffer)
+ .StartsWith("ByteAddressBuffer<", Kinds::RawBuffer)
+ .StartsWith("StructuredBuffer<", Kinds::StructuredBuffer)
+ .Default(Kinds::Invalid);
+ assert(Shape != Kinds::Invalid && "Unsupported buffer type");
+
+ S = S.substr(S.find("<") + 1);
+
+ constexpr size_t PrefixLen = StringRef("vector<").size();
+ if (S.startswith("vector<"))
+ S = S.substr(PrefixLen, S.find(",") - PrefixLen);
+ else
+ S = S.substr(0, S.find(">"));
+
+ ComponentType ElTy = StringSwitch<ResourceBase::ComponentType>(S)
+ .Case("bool", ComponentType::I1)
+ .Case("int16_t", ComponentType::I16)
+ .Case("uint16_t", ComponentType::U16)
+ .Case("int32_t", ComponentType::I32)
+ .Case("uint32_t", ComponentType::U32)
+ .Case("int64_t", ComponentType::I64)
+ .Case("uint64_t", ComponentType::U64)
+ .Case("half", ComponentType::F16)
+ .Case("float", ComponentType::F32)
+ .Case("double", ComponentType::F64)
+ .Default(ComponentType::Invalid);
+ if (ElTy != ComponentType::Invalid)
+ ExtProps.ElementType = ElTy;
+}
+
+ConstantBuffer::ConstantBuffer(uint32_t I, hlsl::FrontendResource R)
+ : ResourceBase(I, R) {}
+
+void ConstantBuffer::setSize(CBufferDataLayout &DL) {
+ CBufferSizeInBytes = DL.getTypeAllocSizeInBytes(GV->getValueType());
+}
+
+void ConstantBuffer::print(raw_ostream &OS) const {
+ OS << "; " << left_justify(Name, 31);
+
+ OS << right_justify("cbuffer", 10);
+
+ printComponentType(Kinds::CBuffer, ComponentType::Invalid, 8, OS);
+
+ printKind(Kinds::CBuffer, 12, OS, /*SRV*/ false, /*HasCounter*/ false);
+ // Print the binding part.
+ ResourceBase::print(OS, "CB", "cb");
+}
+
+template <typename T> void ResourceTable<T>::print(raw_ostream &OS) const {
+ for (auto &Res : Data)
+ Res.print(OS);
+}
+
+MDNode *ResourceBase::ExtendedProperties::write(LLVMContext &Ctx) const {
+ IRBuilder<> B(Ctx);
+ SmallVector<Metadata *> Entries;
+ if (ElementType) {
+ Entries.emplace_back(
+ ConstantAsMetadata::get(B.getInt32(TypedBufferElementType)));
+ Entries.emplace_back(ConstantAsMetadata::get(
+ B.getInt32(static_cast<uint32_t>(*ElementType))));
+ }
+ if (Entries.empty())
+ return nullptr;
+ return MDNode::get(Ctx, Entries);
+}
+
+void ResourceBase::write(LLVMContext &Ctx,
+ MutableArrayRef<Metadata *> Entries) const {
+ IRBuilder<> B(Ctx);
+ Entries[0] = ConstantAsMetadata::get(B.getInt32(ID));
+ Entries[1] = ConstantAsMetadata::get(GV);
+ Entries[2] = MDString::get(Ctx, Name);
+ Entries[3] = ConstantAsMetadata::get(B.getInt32(Space));
+ Entries[4] = ConstantAsMetadata::get(B.getInt32(LowerBound));
+ Entries[5] = ConstantAsMetadata::get(B.getInt32(RangeSize));
+}
+
+MDNode *UAVResource::write() const {
+ auto &Ctx = GV->getContext();
+ IRBuilder<> B(Ctx);
+ Metadata *Entries[11];
+ ResourceBase::write(Ctx, Entries);
+ Entries[6] =
+ ConstantAsMetadata::get(B.getInt32(static_cast<uint32_t>(Shape)));
+ Entries[7] = ConstantAsMetadata::get(B.getInt1(GloballyCoherent));
+ Entries[8] = ConstantAsMetadata::get(B.getInt1(HasCounter));
+ Entries[9] = ConstantAsMetadata::get(B.getInt1(IsROV));
+ Entries[10] = ExtProps.write(Ctx);
+ return MDNode::get(Ctx, Entries);
+}
+
+MDNode *ConstantBuffer::write() const {
+ auto &Ctx = GV->getContext();
+ IRBuilder<> B(Ctx);
+ Metadata *Entries[7];
+ ResourceBase::write(Ctx, Entries);
+
+ Entries[6] = ConstantAsMetadata::get(B.getInt32(CBufferSizeInBytes));
+ return MDNode::get(Ctx, Entries);
+}
+
+template <typename T> MDNode *ResourceTable<T>::write(Module &M) const {
+ if (Data.empty())
+ return nullptr;
+ SmallVector<Metadata *> MDs;
+ for (auto &Res : Data)
+ MDs.emplace_back(Res.write());
+
+ NamedMDNode *Entry = M.getNamedMetadata(MDName);
+ if (Entry)
+ Entry->eraseFromParent();
+
+ return MDNode::get(M.getContext(), MDs);
+}
+
+void Resources::write(Module &M) const {
+ Metadata *ResourceMDs[4] = {nullptr, nullptr, nullptr, nullptr};
+
+ ResourceMDs[1] = UAVs.write(M);
+
+ ResourceMDs[2] = CBuffers.write(M);
+
+ bool HasResource = ResourceMDs[0] != nullptr || ResourceMDs[1] != nullptr ||
+ ResourceMDs[2] != nullptr || ResourceMDs[3] != nullptr;
+
+ if (HasResource) {
+ NamedMDNode *DXResMD = M.getOrInsertNamedMetadata("dx.resources");
+ DXResMD->addOperand(MDNode::get(M.getContext(), ResourceMDs));
+ }
+
+ NamedMDNode *Entry = M.getNamedMetadata("hlsl.uavs");
+ if (Entry)
+ Entry->eraseFromParent();
+}
+
+void Resources::print(raw_ostream &O) const {
+ O << ";\n"
+ << "; Resource Bindings:\n"
+ << ";\n"
+ << "; Name Type Format Dim "
+ "ID HLSL Bind Count\n"
+ << "; ------------------------------ ---------- ------- ----------- "
+ "------- -------------- ------\n";
+
+ CBuffers.print(O);
+ UAVs.print(O);
+}
+
+void Resources::dump() const { print(dbgs()); }