diff options
Diffstat (limited to 'contrib/llvm-project/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp')
-rw-r--r-- | contrib/llvm-project/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp | 164 |
1 files changed, 161 insertions, 3 deletions
diff --git a/contrib/llvm-project/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp b/contrib/llvm-project/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp index 605bf949187f..6d60bd5e3c97 100644 --- a/contrib/llvm-project/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp +++ b/contrib/llvm-project/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp @@ -21,6 +21,7 @@ #include "SPIRVUtils.h" #include "TargetInfo/SPIRVTargetInfo.h" #include "llvm/ADT/DenseMap.h" +#include "llvm/Analysis/ValueTracking.h" #include "llvm/CodeGen/AsmPrinter.h" #include "llvm/CodeGen/MachineConstantPool.h" #include "llvm/CodeGen/MachineFunctionPass.h" @@ -58,9 +59,14 @@ public: void outputModuleSection(SPIRV::ModuleSectionType MSType); void outputEntryPoints(); void outputDebugSourceAndStrings(const Module &M); + void outputOpExtInstImports(const Module &M); void outputOpMemoryModel(); void outputOpFunctionEnd(); void outputExtFuncDecls(); + void outputExecutionModeFromMDNode(Register Reg, MDNode *Node, + SPIRV::ExecutionMode EM); + void outputExecutionMode(const Module &M); + void outputAnnotations(const Module &M); void outputModuleSections(); void emitInstruction(const MachineInstr *MI) override; @@ -127,6 +133,8 @@ void SPIRVAsmPrinter::emitFunctionBodyEnd() { } void SPIRVAsmPrinter::emitOpLabel(const MachineBasicBlock &MBB) { + if (MAI->MBBsToSkip.contains(&MBB)) + return; MCInst LabelInst; LabelInst.setOpcode(SPIRV::OpLabel); LabelInst.addOperand(MCOperand::createReg(MAI->getOrCreateMBBRegister(MBB))); @@ -237,6 +245,13 @@ void SPIRVAsmPrinter::outputModuleSection(SPIRV::ModuleSectionType MSType) { } void SPIRVAsmPrinter::outputDebugSourceAndStrings(const Module &M) { + // Output OpSourceExtensions. + for (auto &Str : MAI->SrcExt) { + MCInst Inst; + Inst.setOpcode(SPIRV::OpSourceExtension); + addStringImm(Str.first(), Inst); + outputMCInst(Inst); + } // Output OpSource. MCInst Inst; Inst.setOpcode(SPIRV::OpSource); @@ -246,6 +261,19 @@ void SPIRVAsmPrinter::outputDebugSourceAndStrings(const Module &M) { outputMCInst(Inst); } +void SPIRVAsmPrinter::outputOpExtInstImports(const Module &M) { + for (auto &CU : MAI->ExtInstSetMap) { + unsigned Set = CU.first; + Register Reg = CU.second; + MCInst Inst; + Inst.setOpcode(SPIRV::OpExtInstImport); + Inst.addOperand(MCOperand::createReg(Reg)); + addStringImm(getExtInstSetName(static_cast<SPIRV::InstructionSet>(Set)), + Inst); + outputMCInst(Inst); + } +} + void SPIRVAsmPrinter::outputOpMemoryModel() { MCInst Inst; Inst.setOpcode(SPIRV::OpMemoryModel); @@ -301,6 +329,135 @@ void SPIRVAsmPrinter::outputExtFuncDecls() { } } +// Encode LLVM type by SPIR-V execution mode VecTypeHint. +static unsigned encodeVecTypeHint(Type *Ty) { + if (Ty->isHalfTy()) + return 4; + if (Ty->isFloatTy()) + return 5; + if (Ty->isDoubleTy()) + return 6; + if (IntegerType *IntTy = dyn_cast<IntegerType>(Ty)) { + switch (IntTy->getIntegerBitWidth()) { + case 8: + return 0; + case 16: + return 1; + case 32: + return 2; + case 64: + return 3; + default: + llvm_unreachable("invalid integer type"); + } + } + if (FixedVectorType *VecTy = dyn_cast<FixedVectorType>(Ty)) { + Type *EleTy = VecTy->getElementType(); + unsigned Size = VecTy->getNumElements(); + return Size << 16 | encodeVecTypeHint(EleTy); + } + llvm_unreachable("invalid type"); +} + +static void addOpsFromMDNode(MDNode *MDN, MCInst &Inst, + SPIRV::ModuleAnalysisInfo *MAI) { + for (const MDOperand &MDOp : MDN->operands()) { + if (auto *CMeta = dyn_cast<ConstantAsMetadata>(MDOp)) { + Constant *C = CMeta->getValue(); + if (ConstantInt *Const = dyn_cast<ConstantInt>(C)) { + Inst.addOperand(MCOperand::createImm(Const->getZExtValue())); + } else if (auto *CE = dyn_cast<Function>(C)) { + Register FuncReg = MAI->getFuncReg(CE->getName().str()); + assert(FuncReg.isValid()); + Inst.addOperand(MCOperand::createReg(FuncReg)); + } + } + } +} + +void SPIRVAsmPrinter::outputExecutionModeFromMDNode(Register Reg, MDNode *Node, + SPIRV::ExecutionMode EM) { + MCInst Inst; + Inst.setOpcode(SPIRV::OpExecutionMode); + Inst.addOperand(MCOperand::createReg(Reg)); + Inst.addOperand(MCOperand::createImm(static_cast<unsigned>(EM))); + addOpsFromMDNode(Node, Inst, MAI); + outputMCInst(Inst); +} + +void SPIRVAsmPrinter::outputExecutionMode(const Module &M) { + NamedMDNode *Node = M.getNamedMetadata("spirv.ExecutionMode"); + if (Node) { + for (unsigned i = 0; i < Node->getNumOperands(); i++) { + MCInst Inst; + Inst.setOpcode(SPIRV::OpExecutionMode); + addOpsFromMDNode(cast<MDNode>(Node->getOperand(i)), Inst, MAI); + outputMCInst(Inst); + } + } + for (auto FI = M.begin(), E = M.end(); FI != E; ++FI) { + const Function &F = *FI; + if (F.isDeclaration()) + continue; + Register FReg = MAI->getFuncReg(F.getGlobalIdentifier()); + assert(FReg.isValid()); + if (MDNode *Node = F.getMetadata("reqd_work_group_size")) + outputExecutionModeFromMDNode(FReg, Node, + SPIRV::ExecutionMode::LocalSize); + if (MDNode *Node = F.getMetadata("work_group_size_hint")) + outputExecutionModeFromMDNode(FReg, Node, + SPIRV::ExecutionMode::LocalSizeHint); + if (MDNode *Node = F.getMetadata("intel_reqd_sub_group_size")) + outputExecutionModeFromMDNode(FReg, Node, + SPIRV::ExecutionMode::SubgroupSize); + if (MDNode *Node = F.getMetadata("vec_type_hint")) { + MCInst Inst; + Inst.setOpcode(SPIRV::OpExecutionMode); + Inst.addOperand(MCOperand::createReg(FReg)); + unsigned EM = static_cast<unsigned>(SPIRV::ExecutionMode::VecTypeHint); + Inst.addOperand(MCOperand::createImm(EM)); + unsigned TypeCode = encodeVecTypeHint(getMDOperandAsType(Node, 0)); + Inst.addOperand(MCOperand::createImm(TypeCode)); + outputMCInst(Inst); + } + } +} + +void SPIRVAsmPrinter::outputAnnotations(const Module &M) { + outputModuleSection(SPIRV::MB_Annotations); + // Process llvm.global.annotations special global variable. + for (auto F = M.global_begin(), E = M.global_end(); F != E; ++F) { + if ((*F).getName() != "llvm.global.annotations") + continue; + const GlobalVariable *V = &(*F); + const ConstantArray *CA = cast<ConstantArray>(V->getOperand(0)); + for (Value *Op : CA->operands()) { + ConstantStruct *CS = cast<ConstantStruct>(Op); + // The first field of the struct contains a pointer to + // the annotated variable. + Value *AnnotatedVar = CS->getOperand(0)->stripPointerCasts(); + if (!isa<Function>(AnnotatedVar)) + llvm_unreachable("Unsupported value in llvm.global.annotations"); + Function *Func = cast<Function>(AnnotatedVar); + Register Reg = MAI->getFuncReg(Func->getGlobalIdentifier()); + + // The second field contains a pointer to a global annotation string. + GlobalVariable *GV = + cast<GlobalVariable>(CS->getOperand(1)->stripPointerCasts()); + + StringRef AnnotationString; + getConstantStringInfo(GV, AnnotationString); + MCInst Inst; + Inst.setOpcode(SPIRV::OpDecorate); + Inst.addOperand(MCOperand::createReg(Reg)); + unsigned Dec = static_cast<unsigned>(SPIRV::Decoration::UserSemantic); + Inst.addOperand(MCOperand::createImm(Dec)); + addStringImm(AnnotationString, Inst); + outputMCInst(Inst); + } + } +} + void SPIRVAsmPrinter::outputModuleSections() { const Module *M = MMI->getModule(); // Get the global subtarget to output module-level info. @@ -311,13 +468,14 @@ void SPIRVAsmPrinter::outputModuleSections() { // Output instructions according to the Logical Layout of a Module: // TODO: 1,2. All OpCapability instructions, then optional OpExtension // instructions. - // TODO: 3. Optional OpExtInstImport instructions. + // 3. Optional OpExtInstImport instructions. + outputOpExtInstImports(*M); // 4. The single required OpMemoryModel instruction. outputOpMemoryModel(); // 5. All entry point declarations, using OpEntryPoint. outputEntryPoints(); // 6. Execution-mode declarations, using OpExecutionMode or OpExecutionModeId. - // TODO: + outputExecutionMode(*M); // 7a. Debug: all OpString, OpSourceExtension, OpSource, and // OpSourceContinued, without forward references. outputDebugSourceAndStrings(*M); @@ -326,7 +484,7 @@ void SPIRVAsmPrinter::outputModuleSections() { // 7c. Debug: all OpModuleProcessed instructions. outputModuleSection(SPIRV::MB_DebugModuleProcessed); // 8. All annotation instructions (all decorations). - outputModuleSection(SPIRV::MB_Annotations); + outputAnnotations(*M); // 9. All type declarations (OpTypeXXX instructions), all constant // instructions, and all global variable declarations. This section is // the first section to allow use of: OpLine and OpNoLine debug information; |