diff options
Diffstat (limited to 'contrib/llvm-project/llvm/lib/Object/DXContainer.cpp')
| -rw-r--r-- | contrib/llvm-project/llvm/lib/Object/DXContainer.cpp | 195 |
1 files changed, 193 insertions, 2 deletions
diff --git a/contrib/llvm-project/llvm/lib/Object/DXContainer.cpp b/contrib/llvm-project/llvm/lib/Object/DXContainer.cpp index 48932afea84b..4aabe9cea3e5 100644 --- a/contrib/llvm-project/llvm/lib/Object/DXContainer.cpp +++ b/contrib/llvm-project/llvm/lib/Object/DXContainer.cpp @@ -9,6 +9,7 @@ #include "llvm/Object/DXContainer.h" #include "llvm/BinaryFormat/DXContainer.h" #include "llvm/Object/Error.h" +#include "llvm/Support/Alignment.h" #include "llvm/Support/FormatVariadic.h" using namespace llvm; @@ -100,6 +101,31 @@ Error DXContainer::parsePSVInfo(StringRef Part) { return Error::success(); } +Error DirectX::Signature::initialize(StringRef Part) { + dxbc::ProgramSignatureHeader SigHeader; + if (Error Err = readStruct(Part, Part.begin(), SigHeader)) + return Err; + size_t Size = sizeof(dxbc::ProgramSignatureElement) * SigHeader.ParamCount; + + if (Part.size() < Size + SigHeader.FirstParamOffset) + return parseFailed("Signature parameters extend beyond the part boundary"); + + Parameters.Data = Part.substr(SigHeader.FirstParamOffset, Size); + + StringTableOffset = SigHeader.FirstParamOffset + static_cast<uint32_t>(Size); + StringTable = Part.substr(SigHeader.FirstParamOffset + Size); + + for (const auto &Param : Parameters) { + if (Param.NameOffset < StringTableOffset) + return parseFailed("Invalid parameter name offset: name starts before " + "the first name offset"); + if (Param.NameOffset - StringTableOffset > StringTable.size()) + return parseFailed("Invalid parameter name offset: name starts after the " + "end of the part data"); + } + return Error::success(); +} + Error DXContainer::parsePartOffsets() { uint32_t LastOffset = sizeof(dxbc::Header) + (Header.PartCount * sizeof(uint32_t)); @@ -153,6 +179,18 @@ Error DXContainer::parsePartOffsets() { if (Error Err = parsePSVInfo(PartData)) return Err; break; + case dxbc::PartType::ISG1: + if (Error Err = InputSignature.initialize(PartData)) + return Err; + break; + case dxbc::PartType::OSG1: + if (Error Err = OutputSignature.initialize(PartData)) + return Err; + break; + case dxbc::PartType::PSG1: + if (Error Err = PatchConstantSignature.initialize(PartData)) + return Err; + break; case dxbc::PartType::Unknown: break; } @@ -223,14 +261,17 @@ Error DirectX::PSVRuntimeInfo::parse(uint16_t ShaderKind) { if (sys::IsBigEndianHost) Info.swapBytes(ShaderStage); BasicInfo = Info; - } else { + } else if (PSVVersion == 0) { v0::RuntimeInfo Info; if (Error Err = readStruct(PSVInfoData, Current, Info)) return Err; if (sys::IsBigEndianHost) Info.swapBytes(ShaderStage); BasicInfo = Info; - } + } else + return parseFailed( + "Cannot read PSV Runtime Info, unsupported PSV version."); + Current += Size; uint32_t ResourceCount = 0; @@ -251,7 +292,157 @@ Error DirectX::PSVRuntimeInfo::parse(uint16_t ShaderKind) { "Resource binding data extends beyond the bounds of the part"); Current += BindingDataSize; + } else + Resources.Stride = sizeof(v2::ResourceBindInfo); + + // PSV version 0 ends after the resource bindings. + if (PSVVersion == 0) + return Error::success(); + + // String table starts at a 4-byte offset. + Current = reinterpret_cast<const char *>( + alignTo<4>(reinterpret_cast<uintptr_t>(Current))); + + uint32_t StringTableSize = 0; + if (Error Err = readInteger(Data, Current, StringTableSize)) + return Err; + if (StringTableSize % 4 != 0) + return parseFailed("String table misaligned"); + Current += sizeof(uint32_t); + StringTable = StringRef(Current, StringTableSize); + + Current += StringTableSize; + + uint32_t SemanticIndexTableSize = 0; + if (Error Err = readInteger(Data, Current, SemanticIndexTableSize)) + return Err; + Current += sizeof(uint32_t); + + SemanticIndexTable.reserve(SemanticIndexTableSize); + for (uint32_t I = 0; I < SemanticIndexTableSize; ++I) { + uint32_t Index = 0; + if (Error Err = readInteger(Data, Current, Index)) + return Err; + Current += sizeof(uint32_t); + SemanticIndexTable.push_back(Index); + } + + uint8_t InputCount = getSigInputCount(); + uint8_t OutputCount = getSigOutputCount(); + uint8_t PatchOrPrimCount = getSigPatchOrPrimCount(); + + uint32_t ElementCount = InputCount + OutputCount + PatchOrPrimCount; + + if (ElementCount > 0) { + if (Error Err = readInteger(Data, Current, SigInputElements.Stride)) + return Err; + Current += sizeof(uint32_t); + // Assign the stride to all the arrays. + SigOutputElements.Stride = SigPatchOrPrimElements.Stride = + SigInputElements.Stride; + + if (Data.end() - Current < ElementCount * SigInputElements.Stride) + return parseFailed( + "Signature elements extend beyond the size of the part"); + + size_t InputSize = SigInputElements.Stride * InputCount; + SigInputElements.Data = Data.substr(Current - Data.begin(), InputSize); + Current += InputSize; + + size_t OutputSize = SigOutputElements.Stride * OutputCount; + SigOutputElements.Data = Data.substr(Current - Data.begin(), OutputSize); + Current += OutputSize; + + size_t PSize = SigPatchOrPrimElements.Stride * PatchOrPrimCount; + SigPatchOrPrimElements.Data = Data.substr(Current - Data.begin(), PSize); + Current += PSize; + } + + ArrayRef<uint8_t> OutputVectorCounts = getOutputVectorCounts(); + uint8_t PatchConstOrPrimVectorCount = getPatchConstOrPrimVectorCount(); + uint8_t InputVectorCount = getInputVectorCount(); + + auto maskDwordSize = [](uint8_t Vector) { + return (static_cast<uint32_t>(Vector) + 7) >> 3; + }; + + auto mapTableSize = [maskDwordSize](uint8_t X, uint8_t Y) { + return maskDwordSize(Y) * X * 4; + }; + + if (usesViewID()) { + for (uint32_t I = 0; I < OutputVectorCounts.size(); ++I) { + // The vector mask is one bit per component and 4 components per vector. + // We can compute the number of dwords required by rounding up to the next + // multiple of 8. + uint32_t NumDwords = + maskDwordSize(static_cast<uint32_t>(OutputVectorCounts[I])); + size_t NumBytes = NumDwords * sizeof(uint32_t); + OutputVectorMasks[I].Data = Data.substr(Current - Data.begin(), NumBytes); + Current += NumBytes; + } + + if (ShaderStage == Triple::Hull && PatchConstOrPrimVectorCount > 0) { + uint32_t NumDwords = maskDwordSize(PatchConstOrPrimVectorCount); + size_t NumBytes = NumDwords * sizeof(uint32_t); + PatchOrPrimMasks.Data = Data.substr(Current - Data.begin(), NumBytes); + Current += NumBytes; + } + } + + // Input/Output mapping table + for (uint32_t I = 0; I < OutputVectorCounts.size(); ++I) { + if (InputVectorCount == 0 || OutputVectorCounts[I] == 0) + continue; + uint32_t NumDwords = mapTableSize(InputVectorCount, OutputVectorCounts[I]); + size_t NumBytes = NumDwords * sizeof(uint32_t); + InputOutputMap[I].Data = Data.substr(Current - Data.begin(), NumBytes); + Current += NumBytes; + } + + // Hull shader: Input/Patch mapping table + if (ShaderStage == Triple::Hull && PatchConstOrPrimVectorCount > 0 && + InputVectorCount > 0) { + uint32_t NumDwords = + mapTableSize(InputVectorCount, PatchConstOrPrimVectorCount); + size_t NumBytes = NumDwords * sizeof(uint32_t); + InputPatchMap.Data = Data.substr(Current - Data.begin(), NumBytes); + Current += NumBytes; + } + + // Domain Shader: Patch/Output mapping table + if (ShaderStage == Triple::Domain && PatchConstOrPrimVectorCount > 0 && + OutputVectorCounts[0] > 0) { + uint32_t NumDwords = + mapTableSize(PatchConstOrPrimVectorCount, OutputVectorCounts[0]); + size_t NumBytes = NumDwords * sizeof(uint32_t); + PatchOutputMap.Data = Data.substr(Current - Data.begin(), NumBytes); + Current += NumBytes; } return Error::success(); } + +uint8_t DirectX::PSVRuntimeInfo::getSigInputCount() const { + if (const auto *P = std::get_if<dxbc::PSV::v2::RuntimeInfo>(&BasicInfo)) + return P->SigInputElements; + if (const auto *P = std::get_if<dxbc::PSV::v1::RuntimeInfo>(&BasicInfo)) + return P->SigInputElements; + return 0; +} + +uint8_t DirectX::PSVRuntimeInfo::getSigOutputCount() const { + if (const auto *P = std::get_if<dxbc::PSV::v2::RuntimeInfo>(&BasicInfo)) + return P->SigOutputElements; + if (const auto *P = std::get_if<dxbc::PSV::v1::RuntimeInfo>(&BasicInfo)) + return P->SigOutputElements; + return 0; +} + +uint8_t DirectX::PSVRuntimeInfo::getSigPatchOrPrimCount() const { + if (const auto *P = std::get_if<dxbc::PSV::v2::RuntimeInfo>(&BasicInfo)) + return P->SigPatchOrPrimElements; + if (const auto *P = std::get_if<dxbc::PSV::v1::RuntimeInfo>(&BasicInfo)) + return P->SigPatchOrPrimElements; + return 0; +} |
