diff options
Diffstat (limited to 'llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp')
-rw-r--r-- | llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp | 176 |
1 files changed, 145 insertions, 31 deletions
diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp index e8fedfeffde7..8b618686ee7d 100644 --- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp @@ -14,12 +14,14 @@ #include "SPIRVCallLowering.h" #include "MCTargetDesc/SPIRVBaseInfo.h" #include "SPIRV.h" +#include "SPIRVBuiltins.h" #include "SPIRVGlobalRegistry.h" #include "SPIRVISelLowering.h" #include "SPIRVRegisterInfo.h" #include "SPIRVSubtarget.h" #include "SPIRVUtils.h" #include "llvm/CodeGen/FunctionLoweringInfo.h" +#include "llvm/Support/ModRef.h" using namespace llvm; @@ -48,19 +50,20 @@ bool SPIRVCallLowering::lowerReturn(MachineIRBuilder &MIRBuilder, // Based on the LLVM function attributes, get a SPIR-V FunctionControl. static uint32_t getFunctionControl(const Function &F) { + MemoryEffects MemEffects = F.getMemoryEffects(); + uint32_t FuncControl = static_cast<uint32_t>(SPIRV::FunctionControl::None); - if (F.hasFnAttribute(Attribute::AttrKind::AlwaysInline)) { + + if (F.hasFnAttribute(Attribute::AttrKind::NoInline)) + FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::DontInline); + else if (F.hasFnAttribute(Attribute::AttrKind::AlwaysInline)) FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::Inline); - } - if (F.hasFnAttribute(Attribute::AttrKind::ReadNone)) { + + if (MemEffects.doesNotAccessMemory()) FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::Pure); - } - if (F.hasFnAttribute(Attribute::AttrKind::ReadOnly)) { + else if (MemEffects.onlyReadsMemory()) FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::Const); - } - if (F.hasFnAttribute(Attribute::AttrKind::NoInline)) { - FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::DontInline); - } + return FuncControl; } @@ -114,6 +117,102 @@ static FunctionType *getOriginalFunctionType(const Function &F) { return FunctionType::get(RetTy, ArgTypes, F.isVarArg()); } +static MDString *getKernelArgAttribute(const Function &KernelFunction, + unsigned ArgIdx, + const StringRef AttributeName) { + assert(KernelFunction.getCallingConv() == CallingConv::SPIR_KERNEL && + "Kernel attributes are attached/belong only to kernel functions"); + + // Lookup the argument attribute in metadata attached to the kernel function. + MDNode *Node = KernelFunction.getMetadata(AttributeName); + if (Node && ArgIdx < Node->getNumOperands()) + return cast<MDString>(Node->getOperand(ArgIdx)); + + // Sometimes metadata containing kernel attributes is not attached to the + // function, but can be found in the named module-level metadata instead. + // For example: + // !opencl.kernels = !{!0} + // !0 = !{void ()* @someKernelFunction, !1, ...} + // !1 = !{!"kernel_arg_addr_space", ...} + // In this case the actual index of searched argument attribute is ArgIdx + 1, + // since the first metadata node operand is occupied by attribute name + // ("kernel_arg_addr_space" in the example above). + unsigned MDArgIdx = ArgIdx + 1; + NamedMDNode *OpenCLKernelsMD = + KernelFunction.getParent()->getNamedMetadata("opencl.kernels"); + if (!OpenCLKernelsMD || OpenCLKernelsMD->getNumOperands() == 0) + return nullptr; + + // KernelToMDNodeList contains kernel function declarations followed by + // corresponding MDNodes for each attribute. Search only MDNodes "belonging" + // to the currently lowered kernel function. + MDNode *KernelToMDNodeList = OpenCLKernelsMD->getOperand(0); + bool FoundLoweredKernelFunction = false; + for (const MDOperand &Operand : KernelToMDNodeList->operands()) { + ValueAsMetadata *MaybeValue = dyn_cast<ValueAsMetadata>(Operand); + if (MaybeValue && dyn_cast<Function>(MaybeValue->getValue())->getName() == + KernelFunction.getName()) { + FoundLoweredKernelFunction = true; + continue; + } + if (MaybeValue && FoundLoweredKernelFunction) + return nullptr; + + MDNode *MaybeNode = dyn_cast<MDNode>(Operand); + if (FoundLoweredKernelFunction && MaybeNode && + cast<MDString>(MaybeNode->getOperand(0))->getString() == + AttributeName && + MDArgIdx < MaybeNode->getNumOperands()) + return cast<MDString>(MaybeNode->getOperand(MDArgIdx)); + } + return nullptr; +} + +static SPIRV::AccessQualifier::AccessQualifier +getArgAccessQual(const Function &F, unsigned ArgIdx) { + if (F.getCallingConv() != CallingConv::SPIR_KERNEL) + return SPIRV::AccessQualifier::ReadWrite; + + MDString *ArgAttribute = + getKernelArgAttribute(F, ArgIdx, "kernel_arg_access_qual"); + if (!ArgAttribute) + return SPIRV::AccessQualifier::ReadWrite; + + if (ArgAttribute->getString().compare("read_only") == 0) + return SPIRV::AccessQualifier::ReadOnly; + if (ArgAttribute->getString().compare("write_only") == 0) + return SPIRV::AccessQualifier::WriteOnly; + return SPIRV::AccessQualifier::ReadWrite; +} + +static std::vector<SPIRV::Decoration::Decoration> +getKernelArgTypeQual(const Function &KernelFunction, unsigned ArgIdx) { + MDString *ArgAttribute = + getKernelArgAttribute(KernelFunction, ArgIdx, "kernel_arg_type_qual"); + if (ArgAttribute && ArgAttribute->getString().compare("volatile") == 0) + return {SPIRV::Decoration::Volatile}; + return {}; +} + +static Type *getArgType(const Function &F, unsigned ArgIdx) { + Type *OriginalArgType = getOriginalFunctionType(F)->getParamType(ArgIdx); + if (F.getCallingConv() != CallingConv::SPIR_KERNEL || + isSpecialOpaqueType(OriginalArgType)) + return OriginalArgType; + + MDString *MDKernelArgType = + getKernelArgAttribute(F, ArgIdx, "kernel_arg_type"); + if (!MDKernelArgType || !MDKernelArgType->getString().endswith("_t")) + return OriginalArgType; + + std::string KernelArgTypeStr = "opencl." + MDKernelArgType->getString().str(); + Type *ExistingOpaqueType = + StructType::getTypeByName(F.getContext(), KernelArgTypeStr); + return ExistingOpaqueType + ? ExistingOpaqueType + : StructType::create(F.getContext(), KernelArgTypeStr); +} + bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder, const Function &F, ArrayRef<ArrayRef<Register>> VRegs, @@ -131,17 +230,10 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder, // TODO: handle the case of multiple registers. if (VRegs[i].size() > 1) return false; - Type *ArgTy = FTy->getParamType(i); - SPIRV::AccessQualifier AQ = SPIRV::AccessQualifier::ReadWrite; - MDNode *Node = F.getMetadata("kernel_arg_access_qual"); - if (Node && i < Node->getNumOperands()) { - StringRef AQString = cast<MDString>(Node->getOperand(i))->getString(); - if (AQString.compare("read_only") == 0) - AQ = SPIRV::AccessQualifier::ReadOnly; - else if (AQString.compare("write_only") == 0) - AQ = SPIRV::AccessQualifier::WriteOnly; - } - auto *SpirvTy = GR->assignTypeToVReg(ArgTy, VRegs[i][0], MIRBuilder, AQ); + SPIRV::AccessQualifier::AccessQualifier ArgAccessQual = + getArgAccessQual(F, i); + auto *SpirvTy = GR->assignTypeToVReg(getArgType(F, i), VRegs[i][0], + MIRBuilder, ArgAccessQual); ArgTypeVRegs.push_back(SpirvTy); if (Arg.hasName()) @@ -176,14 +268,15 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder, buildOpDecorate(VRegs[i][0], MIRBuilder, SPIRV::Decoration::FuncParamAttr, {Attr}); } - Node = F.getMetadata("kernel_arg_type_qual"); - if (Node && i < Node->getNumOperands()) { - StringRef TypeQual = cast<MDString>(Node->getOperand(i))->getString(); - if (TypeQual.compare("volatile") == 0) - buildOpDecorate(VRegs[i][0], MIRBuilder, SPIRV::Decoration::Volatile, - {}); + + if (F.getCallingConv() == CallingConv::SPIR_KERNEL) { + std::vector<SPIRV::Decoration::Decoration> ArgTypeQualDecs = + getKernelArgTypeQual(F, i); + for (SPIRV::Decoration::Decoration Decoration : ArgTypeQualDecs) + buildOpDecorate(VRegs[i][0], MIRBuilder, Decoration, {}); } - Node = F.getMetadata("spirv.ParameterDecorations"); + + MDNode *Node = F.getMetadata("spirv.ParameterDecorations"); if (Node && i < Node->getNumOperands() && isa<MDNode>(Node->getOperand(i))) { MDNode *MD = cast<MDNode>(Node->getOperand(i)); @@ -192,7 +285,8 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder, assert(MD2 && "Metadata operand is expected"); ConstantInt *Const = getConstInt(MD2, 0); assert(Const && "MDOperand should be ConstantInt"); - auto Dec = static_cast<SPIRV::Decoration>(Const->getZExtValue()); + auto Dec = + static_cast<SPIRV::Decoration::Decoration>(Const->getZExtValue()); std::vector<uint32_t> DecVec; for (unsigned j = 1; j < MD2->getNumOperands(); j++) { ConstantInt *Const = getConstInt(MD2, j); @@ -282,6 +376,27 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder, Register ResVReg = Info.OrigRet.Regs.empty() ? Register(0) : Info.OrigRet.Regs[0]; + std::string FuncName = Info.Callee.getGlobal()->getName().str(); + std::string DemangledName = getOclOrSpirvBuiltinDemangledName(FuncName); + const auto *ST = static_cast<const SPIRVSubtarget *>(&MF.getSubtarget()); + // TODO: check that it's OCL builtin, then apply OpenCL_std. + if (!DemangledName.empty() && CF && CF->isDeclaration() && + ST->canUseExtInstSet(SPIRV::InstructionSet::OpenCL_std)) { + const Type *OrigRetTy = Info.OrigRet.Ty; + if (FTy) + OrigRetTy = FTy->getReturnType(); + SmallVector<Register, 8> ArgVRegs; + for (auto Arg : Info.OrigArgs) { + assert(Arg.Regs.size() == 1 && "Call arg has multiple VRegs"); + ArgVRegs.push_back(Arg.Regs[0]); + SPIRVType *SPIRVTy = GR->getOrCreateSPIRVType(Arg.Ty, MIRBuilder); + GR->assignSPIRVTypeToVReg(SPIRVTy, Arg.Regs[0], MIRBuilder.getMF()); + } + if (auto Res = SPIRV::lowerBuiltin( + DemangledName, SPIRV::InstructionSet::OpenCL_std, MIRBuilder, + ResVReg, OrigRetTy, ArgVRegs, GR)) + return *Res; + } if (CF && CF->isDeclaration() && !GR->find(CF, &MIRBuilder.getMF()).isValid()) { // Emit the type info and forward function declaration to the first MBB @@ -322,7 +437,6 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder, return false; MIB.addUse(Arg.Regs[0]); } - const auto &STI = MF.getSubtarget(); - return MIB.constrainAllUses(MIRBuilder.getTII(), *STI.getRegisterInfo(), - *STI.getRegBankInfo()); + return MIB.constrainAllUses(MIRBuilder.getTII(), *ST->getRegisterInfo(), + *ST->getRegBankInfo()); } |