summaryrefslogtreecommitdiff
path: root/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp')
-rw-r--r--llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp32
1 files changed, 31 insertions, 1 deletions
diff --git a/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp b/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp
index d07c0bcdf9af..27da0f21f157 100644
--- a/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp
@@ -66,6 +66,9 @@ public:
void outputExtFuncDecls();
void outputExecutionModeFromMDNode(Register Reg, MDNode *Node,
SPIRV::ExecutionMode::ExecutionMode EM);
+ void outputExecutionModeFromNumthreadsAttribute(
+ const Register &Reg, const Attribute &Attr,
+ SPIRV::ExecutionMode::ExecutionMode EM);
void outputExecutionMode(const Module &M);
void outputAnnotations(const Module &M);
void outputModuleSections();
@@ -412,6 +415,29 @@ void SPIRVAsmPrinter::outputExecutionModeFromMDNode(
outputMCInst(Inst);
}
+void SPIRVAsmPrinter::outputExecutionModeFromNumthreadsAttribute(
+ const Register &Reg, const Attribute &Attr,
+ SPIRV::ExecutionMode::ExecutionMode EM) {
+ assert(Attr.isValid() && "Function called with an invalid attribute.");
+
+ MCInst Inst;
+ Inst.setOpcode(SPIRV::OpExecutionMode);
+ Inst.addOperand(MCOperand::createReg(Reg));
+ Inst.addOperand(MCOperand::createImm(static_cast<unsigned>(EM)));
+
+ SmallVector<StringRef> NumThreads;
+ Attr.getValueAsString().split(NumThreads, ',');
+ assert(NumThreads.size() == 3 && "invalid numthreads");
+ for (uint32_t i = 0; i < 3; ++i) {
+ uint32_t V;
+ [[maybe_unused]] bool Result = NumThreads[i].getAsInteger(10, V);
+ assert(!Result && "Failed to parse numthreads");
+ Inst.addOperand(MCOperand::createImm(V));
+ }
+
+ outputMCInst(Inst);
+}
+
void SPIRVAsmPrinter::outputExecutionMode(const Module &M) {
NamedMDNode *Node = M.getNamedMetadata("spirv.ExecutionMode");
if (Node) {
@@ -431,6 +457,9 @@ void SPIRVAsmPrinter::outputExecutionMode(const Module &M) {
if (MDNode *Node = F.getMetadata("reqd_work_group_size"))
outputExecutionModeFromMDNode(FReg, Node,
SPIRV::ExecutionMode::LocalSize);
+ if (Attribute Attr = F.getFnAttribute("hlsl.numthreads"); Attr.isValid())
+ outputExecutionModeFromNumthreadsAttribute(
+ FReg, Attr, SPIRV::ExecutionMode::LocalSize);
if (MDNode *Node = F.getMetadata("work_group_size_hint"))
outputExecutionModeFromMDNode(FReg, Node,
SPIRV::ExecutionMode::LocalSizeHint);
@@ -447,7 +476,7 @@ void SPIRVAsmPrinter::outputExecutionMode(const Module &M) {
Inst.addOperand(MCOperand::createImm(TypeCode));
outputMCInst(Inst);
}
- if (!M.getNamedMetadata("spirv.ExecutionMode") &&
+ if (ST->isOpenCLEnv() && !M.getNamedMetadata("spirv.ExecutionMode") &&
!M.getNamedMetadata("opencl.enable.FP_CONTRACT")) {
MCInst Inst;
Inst.setOpcode(SPIRV::OpExecutionMode);
@@ -542,4 +571,5 @@ bool SPIRVAsmPrinter::doInitialization(Module &M) {
extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeSPIRVAsmPrinter() {
RegisterAsmPrinter<SPIRVAsmPrinter> X(getTheSPIRV32Target());
RegisterAsmPrinter<SPIRVAsmPrinter> Y(getTheSPIRV64Target());
+ RegisterAsmPrinter<SPIRVAsmPrinter> Z(getTheSPIRVLogicalTarget());
}