diff options
Diffstat (limited to 'contrib/llvm-project/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp')
-rw-r--r-- | contrib/llvm-project/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp | 459 |
1 files changed, 459 insertions, 0 deletions
diff --git a/contrib/llvm-project/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/contrib/llvm-project/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp new file mode 100644 index 000000000000..02a6905a1abc --- /dev/null +++ b/contrib/llvm-project/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp @@ -0,0 +1,459 @@ +//===-- SPIRVGlobalRegistry.cpp - SPIR-V Global Registry --------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains the implementation of the SPIRVGlobalRegistry class, +// which is used to maintain rich type information required for SPIR-V even +// after lowering from LLVM IR to GMIR. It can convert an llvm::Type into +// an OpTypeXXX instruction, and map it to a virtual register. Also it builds +// and supports consistency of constants and global variables. +// +//===----------------------------------------------------------------------===// + +#include "SPIRVGlobalRegistry.h" +#include "SPIRV.h" +#include "SPIRVSubtarget.h" +#include "SPIRVTargetMachine.h" +#include "SPIRVUtils.h" + +using namespace llvm; +SPIRVGlobalRegistry::SPIRVGlobalRegistry(unsigned PointerSize) + : PointerSize(PointerSize) {} + +SPIRVType *SPIRVGlobalRegistry::assignTypeToVReg( + const Type *Type, Register VReg, MachineIRBuilder &MIRBuilder, + SPIRV::AccessQualifier AccessQual, bool EmitIR) { + + SPIRVType *SpirvType = + getOrCreateSPIRVType(Type, MIRBuilder, AccessQual, EmitIR); + assignSPIRVTypeToVReg(SpirvType, VReg, MIRBuilder.getMF()); + return SpirvType; +} + +void SPIRVGlobalRegistry::assignSPIRVTypeToVReg(SPIRVType *SpirvType, + Register VReg, + MachineFunction &MF) { + VRegToTypeMap[&MF][VReg] = SpirvType; +} + +static Register createTypeVReg(MachineIRBuilder &MIRBuilder) { + auto &MRI = MIRBuilder.getMF().getRegInfo(); + auto Res = MRI.createGenericVirtualRegister(LLT::scalar(32)); + MRI.setRegClass(Res, &SPIRV::TYPERegClass); + return Res; +} + +static Register createTypeVReg(MachineRegisterInfo &MRI) { + auto Res = MRI.createGenericVirtualRegister(LLT::scalar(32)); + MRI.setRegClass(Res, &SPIRV::TYPERegClass); + return Res; +} + +SPIRVType *SPIRVGlobalRegistry::getOpTypeBool(MachineIRBuilder &MIRBuilder) { + return MIRBuilder.buildInstr(SPIRV::OpTypeBool) + .addDef(createTypeVReg(MIRBuilder)); +} + +SPIRVType *SPIRVGlobalRegistry::getOpTypeInt(uint32_t Width, + MachineIRBuilder &MIRBuilder, + bool IsSigned) { + auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeInt) + .addDef(createTypeVReg(MIRBuilder)) + .addImm(Width) + .addImm(IsSigned ? 1 : 0); + return MIB; +} + +SPIRVType *SPIRVGlobalRegistry::getOpTypeFloat(uint32_t Width, + MachineIRBuilder &MIRBuilder) { + auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeFloat) + .addDef(createTypeVReg(MIRBuilder)) + .addImm(Width); + return MIB; +} + +SPIRVType *SPIRVGlobalRegistry::getOpTypeVoid(MachineIRBuilder &MIRBuilder) { + return MIRBuilder.buildInstr(SPIRV::OpTypeVoid) + .addDef(createTypeVReg(MIRBuilder)); +} + +SPIRVType *SPIRVGlobalRegistry::getOpTypeVector(uint32_t NumElems, + SPIRVType *ElemType, + MachineIRBuilder &MIRBuilder) { + auto EleOpc = ElemType->getOpcode(); + assert((EleOpc == SPIRV::OpTypeInt || EleOpc == SPIRV::OpTypeFloat || + EleOpc == SPIRV::OpTypeBool) && + "Invalid vector element type"); + + auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeVector) + .addDef(createTypeVReg(MIRBuilder)) + .addUse(getSPIRVTypeID(ElemType)) + .addImm(NumElems); + return MIB; +} + +Register SPIRVGlobalRegistry::buildConstantInt(uint64_t Val, + MachineIRBuilder &MIRBuilder, + SPIRVType *SpvType, + bool EmitIR) { + auto &MF = MIRBuilder.getMF(); + Register Res; + const IntegerType *LLVMIntTy; + if (SpvType) + LLVMIntTy = cast<IntegerType>(getTypeForSPIRVType(SpvType)); + else + LLVMIntTy = IntegerType::getInt32Ty(MF.getFunction().getContext()); + // Find a constant in DT or build a new one. + const auto ConstInt = + ConstantInt::get(const_cast<IntegerType *>(LLVMIntTy), Val); + unsigned BitWidth = SpvType ? getScalarOrVectorBitWidth(SpvType) : 32; + Res = MF.getRegInfo().createGenericVirtualRegister(LLT::scalar(BitWidth)); + assignTypeToVReg(LLVMIntTy, Res, MIRBuilder); + if (EmitIR) + MIRBuilder.buildConstant(Res, *ConstInt); + else + MIRBuilder.buildInstr(SPIRV::OpConstantI) + .addDef(Res) + .addImm(ConstInt->getSExtValue()); + return Res; +} + +Register SPIRVGlobalRegistry::buildConstantFP(APFloat Val, + MachineIRBuilder &MIRBuilder, + SPIRVType *SpvType) { + auto &MF = MIRBuilder.getMF(); + Register Res; + const Type *LLVMFPTy; + if (SpvType) { + LLVMFPTy = getTypeForSPIRVType(SpvType); + assert(LLVMFPTy->isFloatingPointTy()); + } else { + LLVMFPTy = IntegerType::getFloatTy(MF.getFunction().getContext()); + } + // Find a constant in DT or build a new one. + const auto ConstFP = ConstantFP::get(LLVMFPTy->getContext(), Val); + unsigned BitWidth = SpvType ? getScalarOrVectorBitWidth(SpvType) : 32; + Res = MF.getRegInfo().createGenericVirtualRegister(LLT::scalar(BitWidth)); + assignTypeToVReg(LLVMFPTy, Res, MIRBuilder); + MIRBuilder.buildFConstant(Res, *ConstFP); + return Res; +} + +Register SPIRVGlobalRegistry::buildGlobalVariable( + Register ResVReg, SPIRVType *BaseType, StringRef Name, + const GlobalValue *GV, SPIRV::StorageClass Storage, + const MachineInstr *Init, bool IsConst, bool HasLinkageTy, + SPIRV::LinkageType LinkageType, MachineIRBuilder &MIRBuilder, + bool IsInstSelector) { + const GlobalVariable *GVar = nullptr; + if (GV) + GVar = cast<const GlobalVariable>(GV); + else { + // If GV is not passed explicitly, use the name to find or construct + // the global variable. + Module *M = MIRBuilder.getMF().getFunction().getParent(); + GVar = M->getGlobalVariable(Name); + if (GVar == nullptr) { + const Type *Ty = getTypeForSPIRVType(BaseType); // TODO: check type. + GVar = new GlobalVariable(*M, const_cast<Type *>(Ty), false, + GlobalValue::ExternalLinkage, nullptr, + Twine(Name)); + } + GV = GVar; + } + Register Reg; + auto MIB = MIRBuilder.buildInstr(SPIRV::OpVariable) + .addDef(ResVReg) + .addUse(getSPIRVTypeID(BaseType)) + .addImm(static_cast<uint32_t>(Storage)); + + if (Init != 0) { + MIB.addUse(Init->getOperand(0).getReg()); + } + + // ISel may introduce a new register on this step, so we need to add it to + // DT and correct its type avoiding fails on the next stage. + if (IsInstSelector) { + const auto &Subtarget = CurMF->getSubtarget(); + constrainSelectedInstRegOperands(*MIB, *Subtarget.getInstrInfo(), + *Subtarget.getRegisterInfo(), + *Subtarget.getRegBankInfo()); + } + Reg = MIB->getOperand(0).getReg(); + + // Set to Reg the same type as ResVReg has. + auto MRI = MIRBuilder.getMRI(); + assert(MRI->getType(ResVReg).isPointer() && "Pointer type is expected"); + if (Reg != ResVReg) { + LLT RegLLTy = LLT::pointer(MRI->getType(ResVReg).getAddressSpace(), 32); + MRI->setType(Reg, RegLLTy); + assignSPIRVTypeToVReg(BaseType, Reg, MIRBuilder.getMF()); + } + + // If it's a global variable with name, output OpName for it. + if (GVar && GVar->hasName()) + buildOpName(Reg, GVar->getName(), MIRBuilder); + + // Output decorations for the GV. + // TODO: maybe move to GenerateDecorations pass. + if (IsConst) + buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::Constant, {}); + + if (GVar && GVar->getAlign().valueOrOne().value() != 1) + buildOpDecorate( + Reg, MIRBuilder, SPIRV::Decoration::Alignment, + {static_cast<uint32_t>(GVar->getAlign().valueOrOne().value())}); + + if (HasLinkageTy) + buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::LinkageAttributes, + {static_cast<uint32_t>(LinkageType)}, Name); + return Reg; +} + +SPIRVType *SPIRVGlobalRegistry::getOpTypeArray(uint32_t NumElems, + SPIRVType *ElemType, + MachineIRBuilder &MIRBuilder, + bool EmitIR) { + assert((ElemType->getOpcode() != SPIRV::OpTypeVoid) && + "Invalid array element type"); + Register NumElementsVReg = + buildConstantInt(NumElems, MIRBuilder, nullptr, EmitIR); + auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeArray) + .addDef(createTypeVReg(MIRBuilder)) + .addUse(getSPIRVTypeID(ElemType)) + .addUse(NumElementsVReg); + return MIB; +} + +SPIRVType *SPIRVGlobalRegistry::getOpTypePointer(SPIRV::StorageClass SC, + SPIRVType *ElemType, + MachineIRBuilder &MIRBuilder) { + auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypePointer) + .addDef(createTypeVReg(MIRBuilder)) + .addImm(static_cast<uint32_t>(SC)) + .addUse(getSPIRVTypeID(ElemType)); + return MIB; +} + +SPIRVType *SPIRVGlobalRegistry::getOpTypeFunction( + SPIRVType *RetType, const SmallVectorImpl<SPIRVType *> &ArgTypes, + MachineIRBuilder &MIRBuilder) { + auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeFunction) + .addDef(createTypeVReg(MIRBuilder)) + .addUse(getSPIRVTypeID(RetType)); + for (const SPIRVType *ArgType : ArgTypes) + MIB.addUse(getSPIRVTypeID(ArgType)); + return MIB; +} + +SPIRVType *SPIRVGlobalRegistry::createSPIRVType(const Type *Ty, + MachineIRBuilder &MIRBuilder, + SPIRV::AccessQualifier AccQual, + bool EmitIR) { + if (auto IType = dyn_cast<IntegerType>(Ty)) { + const unsigned Width = IType->getBitWidth(); + return Width == 1 ? getOpTypeBool(MIRBuilder) + : getOpTypeInt(Width, MIRBuilder, false); + } + if (Ty->isFloatingPointTy()) + return getOpTypeFloat(Ty->getPrimitiveSizeInBits(), MIRBuilder); + if (Ty->isVoidTy()) + return getOpTypeVoid(MIRBuilder); + if (Ty->isVectorTy()) { + auto El = getOrCreateSPIRVType(cast<FixedVectorType>(Ty)->getElementType(), + MIRBuilder); + return getOpTypeVector(cast<FixedVectorType>(Ty)->getNumElements(), El, + MIRBuilder); + } + if (Ty->isArrayTy()) { + auto *El = getOrCreateSPIRVType(Ty->getArrayElementType(), MIRBuilder); + return getOpTypeArray(Ty->getArrayNumElements(), El, MIRBuilder, EmitIR); + } + assert(!isa<StructType>(Ty) && "Unsupported StructType"); + if (auto FType = dyn_cast<FunctionType>(Ty)) { + SPIRVType *RetTy = getOrCreateSPIRVType(FType->getReturnType(), MIRBuilder); + SmallVector<SPIRVType *, 4> ParamTypes; + for (const auto &t : FType->params()) { + ParamTypes.push_back(getOrCreateSPIRVType(t, MIRBuilder)); + } + return getOpTypeFunction(RetTy, ParamTypes, MIRBuilder); + } + if (auto PType = dyn_cast<PointerType>(Ty)) { + SPIRVType *SpvElementType; + // At the moment, all opaque pointers correspond to i8 element type. + // TODO: change the implementation once opaque pointers are supported + // in the SPIR-V specification. + if (PType->isOpaque()) { + SpvElementType = getOrCreateSPIRVIntegerType(8, MIRBuilder); + } else { + Type *ElemType = PType->getNonOpaquePointerElementType(); + // TODO: support OpenCL and SPIRV builtins like image2d_t that are passed + // as pointers, but should be treated as custom types like OpTypeImage. + assert(!isa<StructType>(ElemType) && "Unsupported StructType pointer"); + + // Otherwise, treat it as a regular pointer type. + SpvElementType = getOrCreateSPIRVType( + ElemType, MIRBuilder, SPIRV::AccessQualifier::ReadWrite, EmitIR); + } + auto SC = addressSpaceToStorageClass(PType->getAddressSpace()); + return getOpTypePointer(SC, SpvElementType, MIRBuilder); + } + llvm_unreachable("Unable to convert LLVM type to SPIRVType"); +} + +SPIRVType *SPIRVGlobalRegistry::getSPIRVTypeForVReg(Register VReg) const { + auto t = VRegToTypeMap.find(CurMF); + if (t != VRegToTypeMap.end()) { + auto tt = t->second.find(VReg); + if (tt != t->second.end()) + return tt->second; + } + return nullptr; +} + +SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVType( + const Type *Type, MachineIRBuilder &MIRBuilder, + SPIRV::AccessQualifier AccessQual, bool EmitIR) { + SPIRVType *SpirvType = createSPIRVType(Type, MIRBuilder, AccessQual, EmitIR); + VRegToTypeMap[&MIRBuilder.getMF()][getSPIRVTypeID(SpirvType)] = SpirvType; + SPIRVToLLVMType[SpirvType] = Type; + return SpirvType; +} + +bool SPIRVGlobalRegistry::isScalarOfType(Register VReg, + unsigned TypeOpcode) const { + SPIRVType *Type = getSPIRVTypeForVReg(VReg); + assert(Type && "isScalarOfType VReg has no type assigned"); + return Type->getOpcode() == TypeOpcode; +} + +bool SPIRVGlobalRegistry::isScalarOrVectorOfType(Register VReg, + unsigned TypeOpcode) const { + SPIRVType *Type = getSPIRVTypeForVReg(VReg); + assert(Type && "isScalarOrVectorOfType VReg has no type assigned"); + if (Type->getOpcode() == TypeOpcode) + return true; + if (Type->getOpcode() == SPIRV::OpTypeVector) { + Register ScalarTypeVReg = Type->getOperand(1).getReg(); + SPIRVType *ScalarType = getSPIRVTypeForVReg(ScalarTypeVReg); + return ScalarType->getOpcode() == TypeOpcode; + } + return false; +} + +unsigned +SPIRVGlobalRegistry::getScalarOrVectorBitWidth(const SPIRVType *Type) const { + assert(Type && "Invalid Type pointer"); + if (Type->getOpcode() == SPIRV::OpTypeVector) { + auto EleTypeReg = Type->getOperand(1).getReg(); + Type = getSPIRVTypeForVReg(EleTypeReg); + } + if (Type->getOpcode() == SPIRV::OpTypeInt || + Type->getOpcode() == SPIRV::OpTypeFloat) + return Type->getOperand(1).getImm(); + if (Type->getOpcode() == SPIRV::OpTypeBool) + return 1; + llvm_unreachable("Attempting to get bit width of non-integer/float type."); +} + +bool SPIRVGlobalRegistry::isScalarOrVectorSigned(const SPIRVType *Type) const { + assert(Type && "Invalid Type pointer"); + if (Type->getOpcode() == SPIRV::OpTypeVector) { + auto EleTypeReg = Type->getOperand(1).getReg(); + Type = getSPIRVTypeForVReg(EleTypeReg); + } + if (Type->getOpcode() == SPIRV::OpTypeInt) + return Type->getOperand(2).getImm() != 0; + llvm_unreachable("Attempting to get sign of non-integer type."); +} + +SPIRV::StorageClass +SPIRVGlobalRegistry::getPointerStorageClass(Register VReg) const { + SPIRVType *Type = getSPIRVTypeForVReg(VReg); + assert(Type && Type->getOpcode() == SPIRV::OpTypePointer && + Type->getOperand(1).isImm() && "Pointer type is expected"); + return static_cast<SPIRV::StorageClass>(Type->getOperand(1).getImm()); +} + +SPIRVType * +SPIRVGlobalRegistry::getOrCreateSPIRVIntegerType(unsigned BitWidth, + MachineIRBuilder &MIRBuilder) { + return getOrCreateSPIRVType( + IntegerType::get(MIRBuilder.getMF().getFunction().getContext(), BitWidth), + MIRBuilder); +} + +SPIRVType *SPIRVGlobalRegistry::restOfCreateSPIRVType(Type *LLVMTy, + MachineInstrBuilder MIB) { + SPIRVType *SpirvType = MIB; + VRegToTypeMap[CurMF][getSPIRVTypeID(SpirvType)] = SpirvType; + SPIRVToLLVMType[SpirvType] = LLVMTy; + return SpirvType; +} + +SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVIntegerType( + unsigned BitWidth, MachineInstr &I, const SPIRVInstrInfo &TII) { + Type *LLVMTy = IntegerType::get(CurMF->getFunction().getContext(), BitWidth); + MachineBasicBlock &BB = *I.getParent(); + auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypeInt)) + .addDef(createTypeVReg(CurMF->getRegInfo())) + .addImm(BitWidth) + .addImm(0); + return restOfCreateSPIRVType(LLVMTy, MIB); +} + +SPIRVType * +SPIRVGlobalRegistry::getOrCreateSPIRVBoolType(MachineIRBuilder &MIRBuilder) { + return getOrCreateSPIRVType( + IntegerType::get(MIRBuilder.getMF().getFunction().getContext(), 1), + MIRBuilder); +} + +SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVVectorType( + SPIRVType *BaseType, unsigned NumElements, MachineIRBuilder &MIRBuilder) { + return getOrCreateSPIRVType( + FixedVectorType::get(const_cast<Type *>(getTypeForSPIRVType(BaseType)), + NumElements), + MIRBuilder); +} + +SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVVectorType( + SPIRVType *BaseType, unsigned NumElements, MachineInstr &I, + const SPIRVInstrInfo &TII) { + Type *LLVMTy = FixedVectorType::get( + const_cast<Type *>(getTypeForSPIRVType(BaseType)), NumElements); + MachineBasicBlock &BB = *I.getParent(); + auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypeVector)) + .addDef(createTypeVReg(CurMF->getRegInfo())) + .addUse(getSPIRVTypeID(BaseType)) + .addImm(NumElements); + return restOfCreateSPIRVType(LLVMTy, MIB); +} + +SPIRVType * +SPIRVGlobalRegistry::getOrCreateSPIRVPointerType(SPIRVType *BaseType, + MachineIRBuilder &MIRBuilder, + SPIRV::StorageClass SClass) { + return getOrCreateSPIRVType( + PointerType::get(const_cast<Type *>(getTypeForSPIRVType(BaseType)), + storageClassToAddressSpace(SClass)), + MIRBuilder); +} + +SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerType( + SPIRVType *BaseType, MachineInstr &I, const SPIRVInstrInfo &TII, + SPIRV::StorageClass SC) { + Type *LLVMTy = + PointerType::get(const_cast<Type *>(getTypeForSPIRVType(BaseType)), + storageClassToAddressSpace(SC)); + MachineBasicBlock &BB = *I.getParent(); + auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypePointer)) + .addDef(createTypeVReg(CurMF->getRegInfo())) + .addImm(static_cast<uint32_t>(SC)) + .addUse(getSPIRVTypeID(BaseType)); + return restOfCreateSPIRVType(LLVMTy, MIB); +} |