aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp')
-rw-r--r--llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp176
1 files changed, 145 insertions, 31 deletions
diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
index e8fedfeffde7..8b618686ee7d 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
@@ -14,12 +14,14 @@
#include "SPIRVCallLowering.h"
#include "MCTargetDesc/SPIRVBaseInfo.h"
#include "SPIRV.h"
+#include "SPIRVBuiltins.h"
#include "SPIRVGlobalRegistry.h"
#include "SPIRVISelLowering.h"
#include "SPIRVRegisterInfo.h"
#include "SPIRVSubtarget.h"
#include "SPIRVUtils.h"
#include "llvm/CodeGen/FunctionLoweringInfo.h"
+#include "llvm/Support/ModRef.h"
using namespace llvm;
@@ -48,19 +50,20 @@ bool SPIRVCallLowering::lowerReturn(MachineIRBuilder &MIRBuilder,
// Based on the LLVM function attributes, get a SPIR-V FunctionControl.
static uint32_t getFunctionControl(const Function &F) {
+ MemoryEffects MemEffects = F.getMemoryEffects();
+
uint32_t FuncControl = static_cast<uint32_t>(SPIRV::FunctionControl::None);
- if (F.hasFnAttribute(Attribute::AttrKind::AlwaysInline)) {
+
+ if (F.hasFnAttribute(Attribute::AttrKind::NoInline))
+ FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::DontInline);
+ else if (F.hasFnAttribute(Attribute::AttrKind::AlwaysInline))
FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::Inline);
- }
- if (F.hasFnAttribute(Attribute::AttrKind::ReadNone)) {
+
+ if (MemEffects.doesNotAccessMemory())
FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::Pure);
- }
- if (F.hasFnAttribute(Attribute::AttrKind::ReadOnly)) {
+ else if (MemEffects.onlyReadsMemory())
FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::Const);
- }
- if (F.hasFnAttribute(Attribute::AttrKind::NoInline)) {
- FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::DontInline);
- }
+
return FuncControl;
}
@@ -114,6 +117,102 @@ static FunctionType *getOriginalFunctionType(const Function &F) {
return FunctionType::get(RetTy, ArgTypes, F.isVarArg());
}
+static MDString *getKernelArgAttribute(const Function &KernelFunction,
+ unsigned ArgIdx,
+ const StringRef AttributeName) {
+ assert(KernelFunction.getCallingConv() == CallingConv::SPIR_KERNEL &&
+ "Kernel attributes are attached/belong only to kernel functions");
+
+ // Lookup the argument attribute in metadata attached to the kernel function.
+ MDNode *Node = KernelFunction.getMetadata(AttributeName);
+ if (Node && ArgIdx < Node->getNumOperands())
+ return cast<MDString>(Node->getOperand(ArgIdx));
+
+ // Sometimes metadata containing kernel attributes is not attached to the
+ // function, but can be found in the named module-level metadata instead.
+ // For example:
+ // !opencl.kernels = !{!0}
+ // !0 = !{void ()* @someKernelFunction, !1, ...}
+ // !1 = !{!"kernel_arg_addr_space", ...}
+ // In this case the actual index of searched argument attribute is ArgIdx + 1,
+ // since the first metadata node operand is occupied by attribute name
+ // ("kernel_arg_addr_space" in the example above).
+ unsigned MDArgIdx = ArgIdx + 1;
+ NamedMDNode *OpenCLKernelsMD =
+ KernelFunction.getParent()->getNamedMetadata("opencl.kernels");
+ if (!OpenCLKernelsMD || OpenCLKernelsMD->getNumOperands() == 0)
+ return nullptr;
+
+ // KernelToMDNodeList contains kernel function declarations followed by
+ // corresponding MDNodes for each attribute. Search only MDNodes "belonging"
+ // to the currently lowered kernel function.
+ MDNode *KernelToMDNodeList = OpenCLKernelsMD->getOperand(0);
+ bool FoundLoweredKernelFunction = false;
+ for (const MDOperand &Operand : KernelToMDNodeList->operands()) {
+ ValueAsMetadata *MaybeValue = dyn_cast<ValueAsMetadata>(Operand);
+ if (MaybeValue && dyn_cast<Function>(MaybeValue->getValue())->getName() ==
+ KernelFunction.getName()) {
+ FoundLoweredKernelFunction = true;
+ continue;
+ }
+ if (MaybeValue && FoundLoweredKernelFunction)
+ return nullptr;
+
+ MDNode *MaybeNode = dyn_cast<MDNode>(Operand);
+ if (FoundLoweredKernelFunction && MaybeNode &&
+ cast<MDString>(MaybeNode->getOperand(0))->getString() ==
+ AttributeName &&
+ MDArgIdx < MaybeNode->getNumOperands())
+ return cast<MDString>(MaybeNode->getOperand(MDArgIdx));
+ }
+ return nullptr;
+}
+
+static SPIRV::AccessQualifier::AccessQualifier
+getArgAccessQual(const Function &F, unsigned ArgIdx) {
+ if (F.getCallingConv() != CallingConv::SPIR_KERNEL)
+ return SPIRV::AccessQualifier::ReadWrite;
+
+ MDString *ArgAttribute =
+ getKernelArgAttribute(F, ArgIdx, "kernel_arg_access_qual");
+ if (!ArgAttribute)
+ return SPIRV::AccessQualifier::ReadWrite;
+
+ if (ArgAttribute->getString().compare("read_only") == 0)
+ return SPIRV::AccessQualifier::ReadOnly;
+ if (ArgAttribute->getString().compare("write_only") == 0)
+ return SPIRV::AccessQualifier::WriteOnly;
+ return SPIRV::AccessQualifier::ReadWrite;
+}
+
+static std::vector<SPIRV::Decoration::Decoration>
+getKernelArgTypeQual(const Function &KernelFunction, unsigned ArgIdx) {
+ MDString *ArgAttribute =
+ getKernelArgAttribute(KernelFunction, ArgIdx, "kernel_arg_type_qual");
+ if (ArgAttribute && ArgAttribute->getString().compare("volatile") == 0)
+ return {SPIRV::Decoration::Volatile};
+ return {};
+}
+
+static Type *getArgType(const Function &F, unsigned ArgIdx) {
+ Type *OriginalArgType = getOriginalFunctionType(F)->getParamType(ArgIdx);
+ if (F.getCallingConv() != CallingConv::SPIR_KERNEL ||
+ isSpecialOpaqueType(OriginalArgType))
+ return OriginalArgType;
+
+ MDString *MDKernelArgType =
+ getKernelArgAttribute(F, ArgIdx, "kernel_arg_type");
+ if (!MDKernelArgType || !MDKernelArgType->getString().endswith("_t"))
+ return OriginalArgType;
+
+ std::string KernelArgTypeStr = "opencl." + MDKernelArgType->getString().str();
+ Type *ExistingOpaqueType =
+ StructType::getTypeByName(F.getContext(), KernelArgTypeStr);
+ return ExistingOpaqueType
+ ? ExistingOpaqueType
+ : StructType::create(F.getContext(), KernelArgTypeStr);
+}
+
bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
const Function &F,
ArrayRef<ArrayRef<Register>> VRegs,
@@ -131,17 +230,10 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
// TODO: handle the case of multiple registers.
if (VRegs[i].size() > 1)
return false;
- Type *ArgTy = FTy->getParamType(i);
- SPIRV::AccessQualifier AQ = SPIRV::AccessQualifier::ReadWrite;
- MDNode *Node = F.getMetadata("kernel_arg_access_qual");
- if (Node && i < Node->getNumOperands()) {
- StringRef AQString = cast<MDString>(Node->getOperand(i))->getString();
- if (AQString.compare("read_only") == 0)
- AQ = SPIRV::AccessQualifier::ReadOnly;
- else if (AQString.compare("write_only") == 0)
- AQ = SPIRV::AccessQualifier::WriteOnly;
- }
- auto *SpirvTy = GR->assignTypeToVReg(ArgTy, VRegs[i][0], MIRBuilder, AQ);
+ SPIRV::AccessQualifier::AccessQualifier ArgAccessQual =
+ getArgAccessQual(F, i);
+ auto *SpirvTy = GR->assignTypeToVReg(getArgType(F, i), VRegs[i][0],
+ MIRBuilder, ArgAccessQual);
ArgTypeVRegs.push_back(SpirvTy);
if (Arg.hasName())
@@ -176,14 +268,15 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
buildOpDecorate(VRegs[i][0], MIRBuilder,
SPIRV::Decoration::FuncParamAttr, {Attr});
}
- Node = F.getMetadata("kernel_arg_type_qual");
- if (Node && i < Node->getNumOperands()) {
- StringRef TypeQual = cast<MDString>(Node->getOperand(i))->getString();
- if (TypeQual.compare("volatile") == 0)
- buildOpDecorate(VRegs[i][0], MIRBuilder, SPIRV::Decoration::Volatile,
- {});
+
+ if (F.getCallingConv() == CallingConv::SPIR_KERNEL) {
+ std::vector<SPIRV::Decoration::Decoration> ArgTypeQualDecs =
+ getKernelArgTypeQual(F, i);
+ for (SPIRV::Decoration::Decoration Decoration : ArgTypeQualDecs)
+ buildOpDecorate(VRegs[i][0], MIRBuilder, Decoration, {});
}
- Node = F.getMetadata("spirv.ParameterDecorations");
+
+ MDNode *Node = F.getMetadata("spirv.ParameterDecorations");
if (Node && i < Node->getNumOperands() &&
isa<MDNode>(Node->getOperand(i))) {
MDNode *MD = cast<MDNode>(Node->getOperand(i));
@@ -192,7 +285,8 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
assert(MD2 && "Metadata operand is expected");
ConstantInt *Const = getConstInt(MD2, 0);
assert(Const && "MDOperand should be ConstantInt");
- auto Dec = static_cast<SPIRV::Decoration>(Const->getZExtValue());
+ auto Dec =
+ static_cast<SPIRV::Decoration::Decoration>(Const->getZExtValue());
std::vector<uint32_t> DecVec;
for (unsigned j = 1; j < MD2->getNumOperands(); j++) {
ConstantInt *Const = getConstInt(MD2, j);
@@ -282,6 +376,27 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
Register ResVReg =
Info.OrigRet.Regs.empty() ? Register(0) : Info.OrigRet.Regs[0];
+ std::string FuncName = Info.Callee.getGlobal()->getName().str();
+ std::string DemangledName = getOclOrSpirvBuiltinDemangledName(FuncName);
+ const auto *ST = static_cast<const SPIRVSubtarget *>(&MF.getSubtarget());
+ // TODO: check that it's OCL builtin, then apply OpenCL_std.
+ if (!DemangledName.empty() && CF && CF->isDeclaration() &&
+ ST->canUseExtInstSet(SPIRV::InstructionSet::OpenCL_std)) {
+ const Type *OrigRetTy = Info.OrigRet.Ty;
+ if (FTy)
+ OrigRetTy = FTy->getReturnType();
+ SmallVector<Register, 8> ArgVRegs;
+ for (auto Arg : Info.OrigArgs) {
+ assert(Arg.Regs.size() == 1 && "Call arg has multiple VRegs");
+ ArgVRegs.push_back(Arg.Regs[0]);
+ SPIRVType *SPIRVTy = GR->getOrCreateSPIRVType(Arg.Ty, MIRBuilder);
+ GR->assignSPIRVTypeToVReg(SPIRVTy, Arg.Regs[0], MIRBuilder.getMF());
+ }
+ if (auto Res = SPIRV::lowerBuiltin(
+ DemangledName, SPIRV::InstructionSet::OpenCL_std, MIRBuilder,
+ ResVReg, OrigRetTy, ArgVRegs, GR))
+ return *Res;
+ }
if (CF && CF->isDeclaration() &&
!GR->find(CF, &MIRBuilder.getMF()).isValid()) {
// Emit the type info and forward function declaration to the first MBB
@@ -322,7 +437,6 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
return false;
MIB.addUse(Arg.Regs[0]);
}
- const auto &STI = MF.getSubtarget();
- return MIB.constrainAllUses(MIRBuilder.getTII(), *STI.getRegisterInfo(),
- *STI.getRegBankInfo());
+ return MIB.constrainAllUses(MIRBuilder.getTII(), *ST->getRegisterInfo(),
+ *ST->getRegBankInfo());
}