diff options
author | Dimitry Andric <dim@FreeBSD.org> | 2023-02-11 12:38:04 +0000 |
---|---|---|
committer | Dimitry Andric <dim@FreeBSD.org> | 2023-02-11 12:38:11 +0000 |
commit | e3b557809604d036af6e00c60f012c2025b59a5e (patch) | |
tree | 8a11ba2269a3b669601e2fd41145b174008f4da8 /clang/lib/CodeGen/CGHLSLRuntime.cpp | |
parent | 08e8dd7b9db7bb4a9de26d44c1cbfd24e869c014 (diff) |
Diffstat (limited to 'clang/lib/CodeGen/CGHLSLRuntime.cpp')
-rw-r--r-- | clang/lib/CodeGen/CGHLSLRuntime.cpp | 413 |
1 files changed, 410 insertions, 3 deletions
diff --git a/clang/lib/CodeGen/CGHLSLRuntime.cpp b/clang/lib/CodeGen/CGHLSLRuntime.cpp index 7dfcc65969a8..5882f491d597 100644 --- a/clang/lib/CodeGen/CGHLSLRuntime.cpp +++ b/clang/lib/CodeGen/CGHLSLRuntime.cpp @@ -13,16 +13,22 @@ //===----------------------------------------------------------------------===// #include "CGHLSLRuntime.h" +#include "CGDebugInfo.h" #include "CodeGenModule.h" +#include "clang/AST/Decl.h" #include "clang/Basic/TargetOptions.h" +#include "llvm/IR/IntrinsicsDirectX.h" #include "llvm/IR/Metadata.h" #include "llvm/IR/Module.h" +#include "llvm/Support/FormatVariadic.h" using namespace clang; using namespace CodeGen; +using namespace clang::hlsl; using namespace llvm; namespace { + void addDxilValVersion(StringRef ValVersionStr, llvm::Module &M) { // The validation of ValVersionStr is done at HLSLToolChain::TranslateArgs. // Assume ValVersionStr is legal here. @@ -39,14 +45,415 @@ void addDxilValVersion(StringRef ValVersionStr, llvm::Module &M) { IRBuilder<> B(M.getContext()); MDNode *Val = MDNode::get(Ctx, {ConstantAsMetadata::get(B.getInt32(Major)), ConstantAsMetadata::get(B.getInt32(Minor))}); - StringRef DxilValKey = "dx.valver"; - M.addModuleFlag(llvm::Module::ModFlagBehavior::AppendUnique, DxilValKey, Val); + StringRef DXILValKey = "dx.valver"; + auto *DXILValMD = M.getOrInsertNamedMetadata(DXILValKey); + DXILValMD->addOperand(Val); +} +void addDisableOptimizations(llvm::Module &M) { + StringRef Key = "dx.disable_optimizations"; + M.addModuleFlag(llvm::Module::ModFlagBehavior::Override, Key, 1); +} +// cbuffer will be translated into global variable in special address space. +// If translate into C, +// cbuffer A { +// float a; +// float b; +// } +// float foo() { return a + b; } +// +// will be translated into +// +// struct A { +// float a; +// float b; +// } cbuffer_A __attribute__((address_space(4))); +// float foo() { return cbuffer_A.a + cbuffer_A.b; } +// +// layoutBuffer will create the struct A type. +// replaceBuffer will replace use of global variable a and b with cbuffer_A.a +// and cbuffer_A.b. +// +void layoutBuffer(CGHLSLRuntime::Buffer &Buf, const DataLayout &DL) { + if (Buf.Constants.empty()) + return; + + std::vector<llvm::Type *> EltTys; + for (auto &Const : Buf.Constants) { + GlobalVariable *GV = Const.first; + Const.second = EltTys.size(); + llvm::Type *Ty = GV->getValueType(); + EltTys.emplace_back(Ty); + } + Buf.LayoutStruct = llvm::StructType::get(EltTys[0]->getContext(), EltTys); +} + +GlobalVariable *replaceBuffer(CGHLSLRuntime::Buffer &Buf) { + // Create global variable for CB. + GlobalVariable *CBGV = new GlobalVariable( + Buf.LayoutStruct, /*isConstant*/ true, + GlobalValue::LinkageTypes::ExternalLinkage, nullptr, + llvm::formatv("{0}{1}", Buf.Name, Buf.IsCBuffer ? ".cb." : ".tb."), + GlobalValue::NotThreadLocal); + + IRBuilder<> B(CBGV->getContext()); + Value *ZeroIdx = B.getInt32(0); + // Replace Const use with CB use. + for (auto &[GV, Offset] : Buf.Constants) { + Value *GEP = + B.CreateGEP(Buf.LayoutStruct, CBGV, {ZeroIdx, B.getInt32(Offset)}); + + assert(Buf.LayoutStruct->getElementType(Offset) == GV->getValueType() && + "constant type mismatch"); + + // Replace. + GV->replaceAllUsesWith(GEP); + // Erase GV. + GV->removeDeadConstantUsers(); + GV->eraseFromParent(); + } + return CBGV; } + } // namespace +void CGHLSLRuntime::addConstant(VarDecl *D, Buffer &CB) { + if (D->getStorageClass() == SC_Static) { + // For static inside cbuffer, take as global static. + // Don't add to cbuffer. + CGM.EmitGlobal(D); + return; + } + + auto *GV = cast<GlobalVariable>(CGM.GetAddrOfGlobalVar(D)); + // Add debug info for constVal. + if (CGDebugInfo *DI = CGM.getModuleDebugInfo()) + if (CGM.getCodeGenOpts().getDebugInfo() >= + codegenoptions::DebugInfoKind::LimitedDebugInfo) + DI->EmitGlobalVariable(cast<GlobalVariable>(GV), D); + + // FIXME: support packoffset. + // See https://github.com/llvm/llvm-project/issues/57914. + uint32_t Offset = 0; + bool HasUserOffset = false; + + unsigned LowerBound = HasUserOffset ? Offset : UINT_MAX; + CB.Constants.emplace_back(std::make_pair(GV, LowerBound)); +} + +void CGHLSLRuntime::addBufferDecls(const DeclContext *DC, Buffer &CB) { + for (Decl *it : DC->decls()) { + if (auto *ConstDecl = dyn_cast<VarDecl>(it)) { + addConstant(ConstDecl, CB); + } else if (isa<CXXRecordDecl, EmptyDecl>(it)) { + // Nothing to do for this declaration. + } else if (isa<FunctionDecl>(it)) { + // A function within an cbuffer is effectively a top-level function, + // as it only refers to globally scoped declarations. + CGM.EmitTopLevelDecl(it); + } + } +} + +void CGHLSLRuntime::addBuffer(const HLSLBufferDecl *D) { + Buffers.emplace_back(Buffer(D)); + addBufferDecls(D, Buffers.back()); +} + void CGHLSLRuntime::finishCodeGen() { auto &TargetOpts = CGM.getTarget().getTargetOpts(); + llvm::Module &M = CGM.getModule(); + Triple T(M.getTargetTriple()); + if (T.getArch() == Triple::ArchType::dxil) + addDxilValVersion(TargetOpts.DxilValidatorVersion, M); + + generateGlobalCtorDtorCalls(); + if (CGM.getCodeGenOpts().OptimizationLevel == 0) + addDisableOptimizations(M); + + const DataLayout &DL = M.getDataLayout(); + + for (auto &Buf : Buffers) { + layoutBuffer(Buf, DL); + GlobalVariable *GV = replaceBuffer(Buf); + M.getGlobalList().push_back(GV); + llvm::hlsl::ResourceClass RC = Buf.IsCBuffer + ? llvm::hlsl::ResourceClass::CBuffer + : llvm::hlsl::ResourceClass::SRV; + llvm::hlsl::ResourceKind RK = Buf.IsCBuffer + ? llvm::hlsl::ResourceKind::CBuffer + : llvm::hlsl::ResourceKind::TBuffer; + std::string TyName = + Buf.Name.str() + (Buf.IsCBuffer ? ".cb." : ".tb.") + "ty"; + addBufferResourceAnnotation(GV, TyName, RC, RK, Buf.Binding); + } +} + +CGHLSLRuntime::Buffer::Buffer(const HLSLBufferDecl *D) + : Name(D->getName()), IsCBuffer(D->isCBuffer()), + Binding(D->getAttr<HLSLResourceBindingAttr>()) {} +void CGHLSLRuntime::addBufferResourceAnnotation(llvm::GlobalVariable *GV, + llvm::StringRef TyName, + llvm::hlsl::ResourceClass RC, + llvm::hlsl::ResourceKind RK, + BufferResBinding &Binding) { llvm::Module &M = CGM.getModule(); - addDxilValVersion(TargetOpts.DxilValidatorVersion, M); + + NamedMDNode *ResourceMD = nullptr; + switch (RC) { + case llvm::hlsl::ResourceClass::UAV: + ResourceMD = M.getOrInsertNamedMetadata("hlsl.uavs"); + break; + case llvm::hlsl::ResourceClass::SRV: + ResourceMD = M.getOrInsertNamedMetadata("hlsl.srvs"); + break; + case llvm::hlsl::ResourceClass::CBuffer: + ResourceMD = M.getOrInsertNamedMetadata("hlsl.cbufs"); + break; + default: + assert(false && "Unsupported buffer type!"); + return; + } + + assert(ResourceMD != nullptr && + "ResourceMD must have been set by the switch above."); + + llvm::hlsl::FrontendResource Res( + GV, TyName, RK, Binding.Reg.value_or(UINT_MAX), Binding.Space); + ResourceMD->addOperand(Res.getMetadata()); +} + +static llvm::hlsl::ResourceKind +castResourceShapeToResourceKind(HLSLResourceAttr::ResourceKind RK) { + switch (RK) { + case HLSLResourceAttr::ResourceKind::Texture1D: + return llvm::hlsl::ResourceKind::Texture1D; + case HLSLResourceAttr::ResourceKind::Texture2D: + return llvm::hlsl::ResourceKind::Texture2D; + case HLSLResourceAttr::ResourceKind::Texture2DMS: + return llvm::hlsl::ResourceKind::Texture2DMS; + case HLSLResourceAttr::ResourceKind::Texture3D: + return llvm::hlsl::ResourceKind::Texture3D; + case HLSLResourceAttr::ResourceKind::TextureCube: + return llvm::hlsl::ResourceKind::TextureCube; + case HLSLResourceAttr::ResourceKind::Texture1DArray: + return llvm::hlsl::ResourceKind::Texture1DArray; + case HLSLResourceAttr::ResourceKind::Texture2DArray: + return llvm::hlsl::ResourceKind::Texture2DArray; + case HLSLResourceAttr::ResourceKind::Texture2DMSArray: + return llvm::hlsl::ResourceKind::Texture2DMSArray; + case HLSLResourceAttr::ResourceKind::TextureCubeArray: + return llvm::hlsl::ResourceKind::TextureCubeArray; + case HLSLResourceAttr::ResourceKind::TypedBuffer: + return llvm::hlsl::ResourceKind::TypedBuffer; + case HLSLResourceAttr::ResourceKind::RawBuffer: + return llvm::hlsl::ResourceKind::RawBuffer; + case HLSLResourceAttr::ResourceKind::StructuredBuffer: + return llvm::hlsl::ResourceKind::StructuredBuffer; + case HLSLResourceAttr::ResourceKind::CBufferKind: + return llvm::hlsl::ResourceKind::CBuffer; + case HLSLResourceAttr::ResourceKind::SamplerKind: + return llvm::hlsl::ResourceKind::Sampler; + case HLSLResourceAttr::ResourceKind::TBuffer: + return llvm::hlsl::ResourceKind::TBuffer; + case HLSLResourceAttr::ResourceKind::RTAccelerationStructure: + return llvm::hlsl::ResourceKind::RTAccelerationStructure; + case HLSLResourceAttr::ResourceKind::FeedbackTexture2D: + return llvm::hlsl::ResourceKind::FeedbackTexture2D; + case HLSLResourceAttr::ResourceKind::FeedbackTexture2DArray: + return llvm::hlsl::ResourceKind::FeedbackTexture2DArray; + } + // Make sure to update HLSLResourceAttr::ResourceKind when add new Kind to + // hlsl::ResourceKind. Assume FeedbackTexture2DArray is the last enum for + // HLSLResourceAttr::ResourceKind. + static_assert( + static_cast<uint32_t>( + HLSLResourceAttr::ResourceKind::FeedbackTexture2DArray) == + (static_cast<uint32_t>(llvm::hlsl::ResourceKind::NumEntries) - 2)); + llvm_unreachable("all switch cases should be covered"); +} + +void CGHLSLRuntime::annotateHLSLResource(const VarDecl *D, GlobalVariable *GV) { + const Type *Ty = D->getType()->getPointeeOrArrayElementType(); + if (!Ty) + return; + const auto *RD = Ty->getAsCXXRecordDecl(); + if (!RD) + return; + const auto *Attr = RD->getAttr<HLSLResourceAttr>(); + if (!Attr) + return; + + HLSLResourceAttr::ResourceClass RC = Attr->getResourceType(); + llvm::hlsl::ResourceKind RK = + castResourceShapeToResourceKind(Attr->getResourceShape()); + + QualType QT(Ty, 0); + BufferResBinding Binding(D->getAttr<HLSLResourceBindingAttr>()); + addBufferResourceAnnotation(GV, QT.getAsString(), + static_cast<llvm::hlsl::ResourceClass>(RC), RK, + Binding); +} + +CGHLSLRuntime::BufferResBinding::BufferResBinding( + HLSLResourceBindingAttr *Binding) { + if (Binding) { + llvm::APInt RegInt(64, 0); + Binding->getSlot().substr(1).getAsInteger(10, RegInt); + Reg = RegInt.getLimitedValue(); + llvm::APInt SpaceInt(64, 0); + Binding->getSpace().substr(5).getAsInteger(10, SpaceInt); + Space = SpaceInt.getLimitedValue(); + } else { + Space = 0; + } +} + +void clang::CodeGen::CGHLSLRuntime::setHLSLEntryAttributes( + const FunctionDecl *FD, llvm::Function *Fn) { + const auto *ShaderAttr = FD->getAttr<HLSLShaderAttr>(); + assert(ShaderAttr && "All entry functions must have a HLSLShaderAttr"); + const StringRef ShaderAttrKindStr = "hlsl.shader"; + Fn->addFnAttr(ShaderAttrKindStr, + ShaderAttr->ConvertShaderTypeToStr(ShaderAttr->getType())); + if (HLSLNumThreadsAttr *NumThreadsAttr = FD->getAttr<HLSLNumThreadsAttr>()) { + const StringRef NumThreadsKindStr = "hlsl.numthreads"; + std::string NumThreadsStr = + formatv("{0},{1},{2}", NumThreadsAttr->getX(), NumThreadsAttr->getY(), + NumThreadsAttr->getZ()); + Fn->addFnAttr(NumThreadsKindStr, NumThreadsStr); + } +} + +static Value *buildVectorInput(IRBuilder<> &B, Function *F, llvm::Type *Ty) { + if (const auto *VT = dyn_cast<FixedVectorType>(Ty)) { + Value *Result = PoisonValue::get(Ty); + for (unsigned I = 0; I < VT->getNumElements(); ++I) { + Value *Elt = B.CreateCall(F, {B.getInt32(I)}); + Result = B.CreateInsertElement(Result, Elt, I); + } + return Result; + } + return B.CreateCall(F, {B.getInt32(0)}); +} + +llvm::Value *CGHLSLRuntime::emitInputSemantic(IRBuilder<> &B, + const ParmVarDecl &D, + llvm::Type *Ty) { + assert(D.hasAttrs() && "Entry parameter missing annotation attribute!"); + if (D.hasAttr<HLSLSV_GroupIndexAttr>()) { + llvm::Function *DxGroupIndex = + CGM.getIntrinsic(Intrinsic::dx_flattened_thread_id_in_group); + return B.CreateCall(FunctionCallee(DxGroupIndex)); + } + if (D.hasAttr<HLSLSV_DispatchThreadIDAttr>()) { + llvm::Function *DxThreadID = CGM.getIntrinsic(Intrinsic::dx_thread_id); + return buildVectorInput(B, DxThreadID, Ty); + } + assert(false && "Unhandled parameter attribute"); + return nullptr; +} + +void CGHLSLRuntime::emitEntryFunction(const FunctionDecl *FD, + llvm::Function *Fn) { + llvm::Module &M = CGM.getModule(); + llvm::LLVMContext &Ctx = M.getContext(); + auto *EntryTy = llvm::FunctionType::get(llvm::Type::getVoidTy(Ctx), false); + Function *EntryFn = + Function::Create(EntryTy, Function::ExternalLinkage, FD->getName(), &M); + + // Copy function attributes over, we have no argument or return attributes + // that can be valid on the real entry. + AttributeList NewAttrs = AttributeList::get(Ctx, AttributeList::FunctionIndex, + Fn->getAttributes().getFnAttrs()); + EntryFn->setAttributes(NewAttrs); + setHLSLEntryAttributes(FD, EntryFn); + + // Set the called function as internal linkage. + Fn->setLinkage(GlobalValue::InternalLinkage); + + BasicBlock *BB = BasicBlock::Create(Ctx, "entry", EntryFn); + IRBuilder<> B(BB); + llvm::SmallVector<Value *> Args; + // FIXME: support struct parameters where semantics are on members. + // See: https://github.com/llvm/llvm-project/issues/57874 + unsigned SRetOffset = 0; + for (const auto &Param : Fn->args()) { + if (Param.hasStructRetAttr()) { + // FIXME: support output. + // See: https://github.com/llvm/llvm-project/issues/57874 + SRetOffset = 1; + Args.emplace_back(PoisonValue::get(Param.getType())); + continue; + } + const ParmVarDecl *PD = FD->getParamDecl(Param.getArgNo() - SRetOffset); + Args.push_back(emitInputSemantic(B, *PD, Param.getType())); + } + + CallInst *CI = B.CreateCall(FunctionCallee(Fn), Args); + (void)CI; + // FIXME: Handle codegen for return type semantics. + // See: https://github.com/llvm/llvm-project/issues/57875 + B.CreateRetVoid(); +} + +static void gatherFunctions(SmallVectorImpl<Function *> &Fns, llvm::Module &M, + bool CtorOrDtor) { + const auto *GV = + M.getNamedGlobal(CtorOrDtor ? "llvm.global_ctors" : "llvm.global_dtors"); + if (!GV) + return; + const auto *CA = dyn_cast<ConstantArray>(GV->getInitializer()); + if (!CA) + return; + // The global_ctor array elements are a struct [Priority, Fn *, COMDat]. + // HLSL neither supports priorities or COMDat values, so we will check those + // in an assert but not handle them. + + llvm::SmallVector<Function *> CtorFns; + for (const auto &Ctor : CA->operands()) { + if (isa<ConstantAggregateZero>(Ctor)) + continue; + ConstantStruct *CS = cast<ConstantStruct>(Ctor); + + assert(cast<ConstantInt>(CS->getOperand(0))->getValue() == 65535 && + "HLSL doesn't support setting priority for global ctors."); + assert(isa<ConstantPointerNull>(CS->getOperand(2)) && + "HLSL doesn't support COMDat for global ctors."); + Fns.push_back(cast<Function>(CS->getOperand(1))); + } +} + +void CGHLSLRuntime::generateGlobalCtorDtorCalls() { + llvm::Module &M = CGM.getModule(); + SmallVector<Function *> CtorFns; + SmallVector<Function *> DtorFns; + gatherFunctions(CtorFns, M, true); + gatherFunctions(DtorFns, M, false); + + // Insert a call to the global constructor at the beginning of the entry block + // to externally exported functions. This is a bit of a hack, but HLSL allows + // global constructors, but doesn't support driver initialization of globals. + for (auto &F : M.functions()) { + if (!F.hasFnAttribute("hlsl.shader")) + continue; + IRBuilder<> B(&F.getEntryBlock(), F.getEntryBlock().begin()); + for (auto *Fn : CtorFns) + B.CreateCall(FunctionCallee(Fn)); + + // Insert global dtors before the terminator of the last instruction + B.SetInsertPoint(F.back().getTerminator()); + for (auto *Fn : DtorFns) + B.CreateCall(FunctionCallee(Fn)); + } + + // No need to keep global ctors/dtors for non-lib profile after call to + // ctors/dtors added for entry. + Triple T(M.getTargetTriple()); + if (T.getEnvironment() != Triple::EnvironmentType::Library) { + if (auto *GV = M.getNamedGlobal("llvm.global_ctors")) + GV->eraseFromParent(); + if (auto *GV = M.getNamedGlobal("llvm.global_dtors")) + GV->eraseFromParent(); + } } |