diff options
Diffstat (limited to 'llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp')
| -rw-r--r-- | llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp | 32 |
1 files changed, 31 insertions, 1 deletions
diff --git a/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp b/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp index d07c0bcdf9af..27da0f21f157 100644 --- a/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp @@ -66,6 +66,9 @@ public: void outputExtFuncDecls(); void outputExecutionModeFromMDNode(Register Reg, MDNode *Node, SPIRV::ExecutionMode::ExecutionMode EM); + void outputExecutionModeFromNumthreadsAttribute( + const Register &Reg, const Attribute &Attr, + SPIRV::ExecutionMode::ExecutionMode EM); void outputExecutionMode(const Module &M); void outputAnnotations(const Module &M); void outputModuleSections(); @@ -412,6 +415,29 @@ void SPIRVAsmPrinter::outputExecutionModeFromMDNode( outputMCInst(Inst); } +void SPIRVAsmPrinter::outputExecutionModeFromNumthreadsAttribute( + const Register &Reg, const Attribute &Attr, + SPIRV::ExecutionMode::ExecutionMode EM) { + assert(Attr.isValid() && "Function called with an invalid attribute."); + + MCInst Inst; + Inst.setOpcode(SPIRV::OpExecutionMode); + Inst.addOperand(MCOperand::createReg(Reg)); + Inst.addOperand(MCOperand::createImm(static_cast<unsigned>(EM))); + + SmallVector<StringRef> NumThreads; + Attr.getValueAsString().split(NumThreads, ','); + assert(NumThreads.size() == 3 && "invalid numthreads"); + for (uint32_t i = 0; i < 3; ++i) { + uint32_t V; + [[maybe_unused]] bool Result = NumThreads[i].getAsInteger(10, V); + assert(!Result && "Failed to parse numthreads"); + Inst.addOperand(MCOperand::createImm(V)); + } + + outputMCInst(Inst); +} + void SPIRVAsmPrinter::outputExecutionMode(const Module &M) { NamedMDNode *Node = M.getNamedMetadata("spirv.ExecutionMode"); if (Node) { @@ -431,6 +457,9 @@ void SPIRVAsmPrinter::outputExecutionMode(const Module &M) { if (MDNode *Node = F.getMetadata("reqd_work_group_size")) outputExecutionModeFromMDNode(FReg, Node, SPIRV::ExecutionMode::LocalSize); + if (Attribute Attr = F.getFnAttribute("hlsl.numthreads"); Attr.isValid()) + outputExecutionModeFromNumthreadsAttribute( + FReg, Attr, SPIRV::ExecutionMode::LocalSize); if (MDNode *Node = F.getMetadata("work_group_size_hint")) outputExecutionModeFromMDNode(FReg, Node, SPIRV::ExecutionMode::LocalSizeHint); @@ -447,7 +476,7 @@ void SPIRVAsmPrinter::outputExecutionMode(const Module &M) { Inst.addOperand(MCOperand::createImm(TypeCode)); outputMCInst(Inst); } - if (!M.getNamedMetadata("spirv.ExecutionMode") && + if (ST->isOpenCLEnv() && !M.getNamedMetadata("spirv.ExecutionMode") && !M.getNamedMetadata("opencl.enable.FP_CONTRACT")) { MCInst Inst; Inst.setOpcode(SPIRV::OpExecutionMode); @@ -542,4 +571,5 @@ bool SPIRVAsmPrinter::doInitialization(Module &M) { extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeSPIRVAsmPrinter() { RegisterAsmPrinter<SPIRVAsmPrinter> X(getTheSPIRV32Target()); RegisterAsmPrinter<SPIRVAsmPrinter> Y(getTheSPIRV64Target()); + RegisterAsmPrinter<SPIRVAsmPrinter> Z(getTheSPIRVLogicalTarget()); } |
