diff options
author | Dimitry Andric <dim@FreeBSD.org> | 2022-07-03 14:10:23 +0000 |
---|---|---|
committer | Dimitry Andric <dim@FreeBSD.org> | 2022-07-03 14:10:23 +0000 |
commit | 145449b1e420787bb99721a429341fa6be3adfb6 (patch) | |
tree | 1d56ae694a6de602e348dd80165cf881a36600ed /llvm/lib/Target/SPIRV | |
parent | ecbca9f5fb7d7613d2b94982c4825eb0d33d6842 (diff) | |
download | src-145449b1e420787bb99721a429341fa6be3adfb6.tar.gz src-145449b1e420787bb99721a429341fa6be3adfb6.zip |
Diffstat (limited to 'llvm/lib/Target/SPIRV')
53 files changed, 9450 insertions, 0 deletions
diff --git a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVAsmBackend.cpp b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVAsmBackend.cpp new file mode 100644 index 000000000000..4156a0026411 --- /dev/null +++ b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVAsmBackend.cpp @@ -0,0 +1,63 @@ +//===-- SPIRVAsmBackend.cpp - SPIR-V Assembler Backend ---------*- C++ -*--===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "MCTargetDesc/SPIRVMCTargetDesc.h" +#include "llvm/MC/MCAsmBackend.h" +#include "llvm/MC/MCAssembler.h" +#include "llvm/MC/MCObjectWriter.h" +#include "llvm/Support/EndianStream.h" + +using namespace llvm; + +namespace { + +class SPIRVAsmBackend : public MCAsmBackend { // SPIR-V assembler backend: every hook below is a no-op or returns a constant. +public: + SPIRVAsmBackend(support::endianness Endian) : MCAsmBackend(Endian) {} // NOTE(review): single-argument constructor is not marked explicit. + + void applyFixup(const MCAssembler &Asm, const MCFixup &Fixup, + const MCValue &Target, MutableArrayRef<char> Data, + uint64_t Value, bool IsResolved, + const MCSubtargetInfo *STI) const override {} // Empty body: no fixup bytes are ever written into Data. + + std::unique_ptr<MCObjectTargetWriter> + createObjectTargetWriter() const override { + return createSPIRVObjectTargetWriter(); // Object emission is delegated to the SPIR-V object target writer. + } + + // No instruction requires relaxation. 
+ bool fixupNeedsRelaxation(const MCFixup &Fixup, uint64_t Value, + const MCRelaxableFragment *DF, + const MCAsmLayout &Layout) const override { + return false; + } + + unsigned getNumFixupKinds() const override { return 1; } // NOTE(review): reports one fixup kind although applyFixup() ignores every fixup; confirm intent. + + bool mayNeedRelaxation(const MCInst &Inst, + const MCSubtargetInfo &STI) const override { + return false; + } + + void relaxInstruction(MCInst &Inst, + const MCSubtargetInfo &STI) const override {} // Empty: mayNeedRelaxation() above always returns false. + + bool writeNopData(raw_ostream &OS, uint64_t Count, + const MCSubtargetInfo *STI) const override { + return false; // MCAsmBackend::writeNopData returns true on success, so this reports that no NOP padding can be emitted. + } +}; + +} // end anonymous namespace + +MCAsmBackend *llvm::createSPIRVAsmBackend(const Target &T, + const MCSubtargetInfo &STI, + const MCRegisterInfo &MRI, + const MCTargetOptions &) { + return new SPIRVAsmBackend(support::little); // Always little-endian; T, STI, MRI and the target options are unused. +} diff --git a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.cpp b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.cpp new file mode 100644 index 000000000000..1a3e35a5f901 --- /dev/null +++ b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.cpp @@ -0,0 +1,1072 @@ +//===-- SPIRVBaseInfo.cpp - Top level definitions for SPIRV ----*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains small standalone helper functions and enum definitions for +// the SPIRV target useful for the compiler back-end and the MC libraries. 
+// As such, it deliberately does not include references to LLVM core +// code gen types, passes, etc.. 
+// +//===----------------------------------------------------------------------===// + +#include "SPIRVBaseInfo.h" +#include "llvm/Support/ErrorHandling.h" + +namespace llvm { +namespace SPIRV { + +#define CASE(CLASS, ATTR) \ + case CLASS::ATTR: \ + return #ATTR; +#define CASE_SUF(CLASS, SF, ATTR) \ + case CLASS::SF##_##ATTR: \ + return #ATTR; + +// Implement getEnumName(Enum e) helper functions. +// TODO: re-implement all the functions using TableGen. +StringRef getCapabilityName(Capability e) { + switch (e) { + CASE(Capability, Matrix) + CASE(Capability, Shader) + CASE(Capability, Geometry) + CASE(Capability, Tessellation) + CASE(Capability, Addresses) + CASE(Capability, Linkage) + CASE(Capability, Kernel) + CASE(Capability, Vector16) + CASE(Capability, Float16Buffer) + CASE(Capability, Float16) + CASE(Capability, Float64) + CASE(Capability, Int64) + CASE(Capability, Int64Atomics) + CASE(Capability, ImageBasic) + CASE(Capability, ImageReadWrite) + CASE(Capability, ImageMipmap) + CASE(Capability, Pipes) + CASE(Capability, Groups) + CASE(Capability, DeviceEnqueue) + CASE(Capability, LiteralSampler) + CASE(Capability, AtomicStorage) + CASE(Capability, Int16) + CASE(Capability, TessellationPointSize) + CASE(Capability, GeometryPointSize) + CASE(Capability, ImageGatherExtended) + CASE(Capability, StorageImageMultisample) + CASE(Capability, UniformBufferArrayDynamicIndexing) + CASE(Capability, SampledImageArrayDymnamicIndexing) + CASE(Capability, ClipDistance) + CASE(Capability, CullDistance) + CASE(Capability, ImageCubeArray) + CASE(Capability, SampleRateShading) + CASE(Capability, ImageRect) + CASE(Capability, SampledRect) + CASE(Capability, GenericPointer) + CASE(Capability, Int8) + CASE(Capability, InputAttachment) + CASE(Capability, SparseResidency) + CASE(Capability, MinLod) + CASE(Capability, Sampled1D) + CASE(Capability, Image1D) + CASE(Capability, SampledCubeArray) + CASE(Capability, SampledBuffer) + CASE(Capability, ImageBuffer) + CASE(Capability, 
ImageMSArray) + CASE(Capability, StorageImageExtendedFormats) + CASE(Capability, ImageQuery) + CASE(Capability, DerivativeControl) + CASE(Capability, InterpolationFunction) + CASE(Capability, TransformFeedback) + CASE(Capability, GeometryStreams) + CASE(Capability, StorageImageReadWithoutFormat) + CASE(Capability, StorageImageWriteWithoutFormat) + CASE(Capability, MultiViewport) + CASE(Capability, SubgroupDispatch) + CASE(Capability, NamedBarrier) + CASE(Capability, PipeStorage) + CASE(Capability, GroupNonUniform) + CASE(Capability, GroupNonUniformVote) + CASE(Capability, GroupNonUniformArithmetic) + CASE(Capability, GroupNonUniformBallot) + CASE(Capability, GroupNonUniformShuffle) + CASE(Capability, GroupNonUniformShuffleRelative) + CASE(Capability, GroupNonUniformClustered) + CASE(Capability, GroupNonUniformQuad) + CASE(Capability, SubgroupBallotKHR) + CASE(Capability, DrawParameters) + CASE(Capability, SubgroupVoteKHR) + CASE(Capability, StorageBuffer16BitAccess) + CASE(Capability, StorageUniform16) + CASE(Capability, StoragePushConstant16) + CASE(Capability, StorageInputOutput16) + CASE(Capability, DeviceGroup) + CASE(Capability, MultiView) + CASE(Capability, VariablePointersStorageBuffer) + CASE(Capability, VariablePointers) + CASE(Capability, AtomicStorageOps) + CASE(Capability, SampleMaskPostDepthCoverage) + CASE(Capability, StorageBuffer8BitAccess) + CASE(Capability, UniformAndStorageBuffer8BitAccess) + CASE(Capability, StoragePushConstant8) + CASE(Capability, DenormPreserve) + CASE(Capability, DenormFlushToZero) + CASE(Capability, SignedZeroInfNanPreserve) + CASE(Capability, RoundingModeRTE) + CASE(Capability, RoundingModeRTZ) + CASE(Capability, Float16ImageAMD) + CASE(Capability, ImageGatherBiasLodAMD) + CASE(Capability, FragmentMaskAMD) + CASE(Capability, StencilExportEXT) + CASE(Capability, ImageReadWriteLodAMD) + CASE(Capability, SampleMaskOverrideCoverageNV) + CASE(Capability, GeometryShaderPassthroughNV) + CASE(Capability, 
ShaderViewportIndexLayerEXT) + CASE(Capability, ShaderViewportMaskNV) + CASE(Capability, ShaderStereoViewNV) + CASE(Capability, PerViewAttributesNV) + CASE(Capability, FragmentFullyCoveredEXT) + CASE(Capability, MeshShadingNV) + CASE(Capability, ShaderNonUniformEXT) + CASE(Capability, RuntimeDescriptorArrayEXT) + CASE(Capability, InputAttachmentArrayDynamicIndexingEXT) + CASE(Capability, UniformTexelBufferArrayDynamicIndexingEXT) + CASE(Capability, StorageTexelBufferArrayDynamicIndexingEXT) + CASE(Capability, UniformBufferArrayNonUniformIndexingEXT) + CASE(Capability, SampledImageArrayNonUniformIndexingEXT) + CASE(Capability, StorageBufferArrayNonUniformIndexingEXT) + CASE(Capability, StorageImageArrayNonUniformIndexingEXT) + CASE(Capability, InputAttachmentArrayNonUniformIndexingEXT) + CASE(Capability, UniformTexelBufferArrayNonUniformIndexingEXT) + CASE(Capability, StorageTexelBufferArrayNonUniformIndexingEXT) + CASE(Capability, RayTracingNV) + CASE(Capability, SubgroupShuffleINTEL) + CASE(Capability, SubgroupBufferBlockIOINTEL) + CASE(Capability, SubgroupImageBlockIOINTEL) + CASE(Capability, SubgroupImageMediaBlockIOINTEL) + CASE(Capability, SubgroupAvcMotionEstimationINTEL) + CASE(Capability, SubgroupAvcMotionEstimationIntraINTEL) + CASE(Capability, SubgroupAvcMotionEstimationChromaINTEL) + CASE(Capability, GroupNonUniformPartitionedNV) + CASE(Capability, VulkanMemoryModelKHR) + CASE(Capability, VulkanMemoryModelDeviceScopeKHR) + CASE(Capability, ImageFootprintNV) + CASE(Capability, FragmentBarycentricNV) + CASE(Capability, ComputeDerivativeGroupQuadsNV) + CASE(Capability, ComputeDerivativeGroupLinearNV) + CASE(Capability, FragmentDensityEXT) + CASE(Capability, PhysicalStorageBufferAddressesEXT) + CASE(Capability, CooperativeMatrixNV) + break; + } + llvm_unreachable("Unexpected operand"); +} + +StringRef getSourceLanguageName(SourceLanguage e) { + switch (e) { + CASE(SourceLanguage, Unknown) + CASE(SourceLanguage, ESSL) + CASE(SourceLanguage, GLSL) + 
CASE(SourceLanguage, OpenCL_C) + CASE(SourceLanguage, OpenCL_CPP) + CASE(SourceLanguage, HLSL) + break; + } + llvm_unreachable("Unexpected operand"); +} + +StringRef getExecutionModelName(ExecutionModel e) { + switch (e) { + CASE(ExecutionModel, Vertex) + CASE(ExecutionModel, TessellationControl) + CASE(ExecutionModel, TessellationEvaluation) + CASE(ExecutionModel, Geometry) + CASE(ExecutionModel, Fragment) + CASE(ExecutionModel, GLCompute) + CASE(ExecutionModel, Kernel) + CASE(ExecutionModel, TaskNV) + CASE(ExecutionModel, MeshNV) + CASE(ExecutionModel, RayGenerationNV) + CASE(ExecutionModel, IntersectionNV) + CASE(ExecutionModel, AnyHitNV) + CASE(ExecutionModel, ClosestHitNV) + CASE(ExecutionModel, MissNV) + CASE(ExecutionModel, CallableNV) + break; + } + llvm_unreachable("Unexpected operand"); +} + +StringRef getAddressingModelName(AddressingModel e) { + switch (e) { + CASE(AddressingModel, Logical) + CASE(AddressingModel, Physical32) + CASE(AddressingModel, Physical64) + CASE(AddressingModel, PhysicalStorageBuffer64EXT) + break; + } + llvm_unreachable("Unexpected operand"); +} + +StringRef getMemoryModelName(MemoryModel e) { + switch (e) { + CASE(MemoryModel, Simple) + CASE(MemoryModel, GLSL450) + CASE(MemoryModel, OpenCL) + CASE(MemoryModel, VulkanKHR) + break; + } + llvm_unreachable("Unexpected operand"); +} + +StringRef getExecutionModeName(ExecutionMode e) { + switch (e) { + CASE(ExecutionMode, Invocations) + CASE(ExecutionMode, SpacingEqual) + CASE(ExecutionMode, SpacingFractionalEven) + CASE(ExecutionMode, SpacingFractionalOdd) + CASE(ExecutionMode, VertexOrderCw) + CASE(ExecutionMode, VertexOrderCcw) + CASE(ExecutionMode, PixelCenterInteger) + CASE(ExecutionMode, OriginUpperLeft) + CASE(ExecutionMode, OriginLowerLeft) + CASE(ExecutionMode, EarlyFragmentTests) + CASE(ExecutionMode, PointMode) + CASE(ExecutionMode, Xfb) + CASE(ExecutionMode, DepthReplacing) + CASE(ExecutionMode, DepthGreater) + CASE(ExecutionMode, DepthLess) + CASE(ExecutionMode, 
DepthUnchanged) + CASE(ExecutionMode, LocalSize) + CASE(ExecutionMode, LocalSizeHint) + CASE(ExecutionMode, InputPoints) + CASE(ExecutionMode, InputLines) + CASE(ExecutionMode, InputLinesAdjacency) + CASE(ExecutionMode, Triangles) + CASE(ExecutionMode, InputTrianglesAdjacency) + CASE(ExecutionMode, Quads) + CASE(ExecutionMode, Isolines) + CASE(ExecutionMode, OutputVertices) + CASE(ExecutionMode, OutputPoints) + CASE(ExecutionMode, OutputLineStrip) + CASE(ExecutionMode, OutputTriangleStrip) + CASE(ExecutionMode, VecTypeHint) + CASE(ExecutionMode, ContractionOff) + CASE(ExecutionMode, Initializer) + CASE(ExecutionMode, Finalizer) + CASE(ExecutionMode, SubgroupSize) + CASE(ExecutionMode, SubgroupsPerWorkgroup) + CASE(ExecutionMode, SubgroupsPerWorkgroupId) + CASE(ExecutionMode, LocalSizeId) + CASE(ExecutionMode, LocalSizeHintId) + CASE(ExecutionMode, PostDepthCoverage) + CASE(ExecutionMode, DenormPreserve) + CASE(ExecutionMode, DenormFlushToZero) + CASE(ExecutionMode, SignedZeroInfNanPreserve) + CASE(ExecutionMode, RoundingModeRTE) + CASE(ExecutionMode, RoundingModeRTZ) + CASE(ExecutionMode, StencilRefReplacingEXT) + CASE(ExecutionMode, OutputLinesNV) + CASE(ExecutionMode, DerivativeGroupQuadsNV) + CASE(ExecutionMode, DerivativeGroupLinearNV) + CASE(ExecutionMode, OutputTrianglesNV) + break; + } + llvm_unreachable("Unexpected operand"); +} + +StringRef getStorageClassName(StorageClass e) { + switch (e) { + CASE(StorageClass, UniformConstant) + CASE(StorageClass, Input) + CASE(StorageClass, Uniform) + CASE(StorageClass, Output) + CASE(StorageClass, Workgroup) + CASE(StorageClass, CrossWorkgroup) + CASE(StorageClass, Private) + CASE(StorageClass, Function) + CASE(StorageClass, Generic) + CASE(StorageClass, PushConstant) + CASE(StorageClass, AtomicCounter) + CASE(StorageClass, Image) + CASE(StorageClass, StorageBuffer) + CASE(StorageClass, CallableDataNV) + CASE(StorageClass, IncomingCallableDataNV) + CASE(StorageClass, RayPayloadNV) + CASE(StorageClass, HitAttributeNV) 
+ CASE(StorageClass, IncomingRayPayloadNV) + CASE(StorageClass, ShaderRecordBufferNV) + CASE(StorageClass, PhysicalStorageBufferEXT) + break; + } + llvm_unreachable("Unexpected operand"); +} + +StringRef getDimName(Dim dim) { + switch (dim) { + CASE_SUF(Dim, DIM, 1D) + CASE_SUF(Dim, DIM, 2D) + CASE_SUF(Dim, DIM, 3D) + CASE_SUF(Dim, DIM, Cube) + CASE_SUF(Dim, DIM, Rect) + CASE_SUF(Dim, DIM, Buffer) + CASE_SUF(Dim, DIM, SubpassData) + break; + } + llvm_unreachable("Unexpected operand"); +} + +StringRef getSamplerAddressingModeName(SamplerAddressingMode e) { + switch (e) { + CASE(SamplerAddressingMode, None) + CASE(SamplerAddressingMode, ClampToEdge) + CASE(SamplerAddressingMode, Clamp) + CASE(SamplerAddressingMode, Repeat) + CASE(SamplerAddressingMode, RepeatMirrored) + break; + } + llvm_unreachable("Unexpected operand"); +} + +StringRef getSamplerFilterModeName(SamplerFilterMode e) { + switch (e) { + CASE(SamplerFilterMode, Nearest) + CASE(SamplerFilterMode, Linear) + break; + } + llvm_unreachable("Unexpected operand"); +} + +StringRef getImageFormatName(ImageFormat e) { + switch (e) { + CASE(ImageFormat, Unknown) + CASE(ImageFormat, Rgba32f) + CASE(ImageFormat, Rgba16f) + CASE(ImageFormat, R32f) + CASE(ImageFormat, Rgba8) + CASE(ImageFormat, Rgba8Snorm) + CASE(ImageFormat, Rg32f) + CASE(ImageFormat, Rg16f) + CASE(ImageFormat, R11fG11fB10f) + CASE(ImageFormat, R16f) + CASE(ImageFormat, Rgba16) + CASE(ImageFormat, Rgb10A2) + CASE(ImageFormat, Rg16) + CASE(ImageFormat, Rg8) + CASE(ImageFormat, R16) + CASE(ImageFormat, R8) + CASE(ImageFormat, Rgba16Snorm) + CASE(ImageFormat, Rg16Snorm) + CASE(ImageFormat, Rg8Snorm) + CASE(ImageFormat, R16Snorm) + CASE(ImageFormat, R8Snorm) + CASE(ImageFormat, Rgba32i) + CASE(ImageFormat, Rgba16i) + CASE(ImageFormat, Rgba8i) + CASE(ImageFormat, R32i) + CASE(ImageFormat, Rg32i) + CASE(ImageFormat, Rg16i) + CASE(ImageFormat, Rg8i) + CASE(ImageFormat, R16i) + CASE(ImageFormat, R8i) + CASE(ImageFormat, Rgba32ui) + CASE(ImageFormat, 
Rgba16ui) + CASE(ImageFormat, Rgba8ui) + CASE(ImageFormat, R32ui) + CASE(ImageFormat, Rgb10a2ui) + CASE(ImageFormat, Rg32ui) + CASE(ImageFormat, Rg16ui) + CASE(ImageFormat, Rg8ui) + CASE(ImageFormat, R16ui) + CASE(ImageFormat, R8ui) + break; + } + llvm_unreachable("Unexpected operand"); +} + +StringRef getImageChannelOrderName(ImageChannelOrder e) { + switch (e) { + CASE(ImageChannelOrder, R) + CASE(ImageChannelOrder, A) + CASE(ImageChannelOrder, RG) + CASE(ImageChannelOrder, RA) + CASE(ImageChannelOrder, RGB) + CASE(ImageChannelOrder, RGBA) + CASE(ImageChannelOrder, BGRA) + CASE(ImageChannelOrder, ARGB) + CASE(ImageChannelOrder, Intensity) + CASE(ImageChannelOrder, Luminance) + CASE(ImageChannelOrder, Rx) + CASE(ImageChannelOrder, RGx) + CASE(ImageChannelOrder, RGBx) + CASE(ImageChannelOrder, Depth) + CASE(ImageChannelOrder, DepthStencil) + CASE(ImageChannelOrder, sRGB) + CASE(ImageChannelOrder, sRGBx) + CASE(ImageChannelOrder, sRGBA) + CASE(ImageChannelOrder, sBGRA) + CASE(ImageChannelOrder, ABGR) + break; + } + llvm_unreachable("Unexpected operand"); +} + +StringRef getImageChannelDataTypeName(ImageChannelDataType e) { + switch (e) { + CASE(ImageChannelDataType, SnormInt8) + CASE(ImageChannelDataType, SnormInt16) + CASE(ImageChannelDataType, UnormInt8) + CASE(ImageChannelDataType, UnormInt16) + CASE(ImageChannelDataType, UnormShort565) + CASE(ImageChannelDataType, UnormShort555) + CASE(ImageChannelDataType, UnormInt101010) + CASE(ImageChannelDataType, SignedInt8) + CASE(ImageChannelDataType, SignedInt16) + CASE(ImageChannelDataType, SignedInt32) + CASE(ImageChannelDataType, UnsignedInt8) + CASE(ImageChannelDataType, UnsignedInt16) + CASE(ImageChannelDataType, UnsigendInt32) + CASE(ImageChannelDataType, HalfFloat) + CASE(ImageChannelDataType, Float) + CASE(ImageChannelDataType, UnormInt24) + CASE(ImageChannelDataType, UnormInt101010_2) + break; + } + llvm_unreachable("Unexpected operand"); +} + +std::string getImageOperandName(uint32_t e) { + std::string 
nameString = ""; + std::string sep = ""; + if (e == static_cast<uint32_t>(ImageOperand::None)) + return "None"; + if (e == static_cast<uint32_t>(ImageOperand::Bias)) + return "Bias"; + if (e & static_cast<uint32_t>(ImageOperand::Bias)) { + nameString += sep + "Bias"; + sep = "|"; + } + if (e == static_cast<uint32_t>(ImageOperand::Lod)) + return "Lod"; + if (e & static_cast<uint32_t>(ImageOperand::Lod)) { + nameString += sep + "Lod"; + sep = "|"; + } + if (e == static_cast<uint32_t>(ImageOperand::Grad)) + return "Grad"; + if (e & static_cast<uint32_t>(ImageOperand::Grad)) { + nameString += sep + "Grad"; + sep = "|"; + } + if (e == static_cast<uint32_t>(ImageOperand::ConstOffset)) + return "ConstOffset"; + if (e & static_cast<uint32_t>(ImageOperand::ConstOffset)) { + nameString += sep + "ConstOffset"; + sep = "|"; + } + if (e == static_cast<uint32_t>(ImageOperand::Offset)) + return "Offset"; + if (e & static_cast<uint32_t>(ImageOperand::Offset)) { + nameString += sep + "Offset"; + sep = "|"; + } + if (e == static_cast<uint32_t>(ImageOperand::ConstOffsets)) + return "ConstOffsets"; + if (e & static_cast<uint32_t>(ImageOperand::ConstOffsets)) { + nameString += sep + "ConstOffsets"; + sep = "|"; + } + if (e == static_cast<uint32_t>(ImageOperand::Sample)) + return "Sample"; + if (e & static_cast<uint32_t>(ImageOperand::Sample)) { + nameString += sep + "Sample"; + sep = "|"; + } + if (e == static_cast<uint32_t>(ImageOperand::MinLod)) + return "MinLod"; + if (e & static_cast<uint32_t>(ImageOperand::MinLod)) { + nameString += sep + "MinLod"; + sep = "|"; + } + if (e == static_cast<uint32_t>(ImageOperand::MakeTexelAvailableKHR)) + return "MakeTexelAvailableKHR"; + if (e & static_cast<uint32_t>(ImageOperand::MakeTexelAvailableKHR)) { + nameString += sep + "MakeTexelAvailableKHR"; + sep = "|"; + } + if (e == static_cast<uint32_t>(ImageOperand::MakeTexelVisibleKHR)) + return "MakeTexelVisibleKHR"; + if (e & static_cast<uint32_t>(ImageOperand::MakeTexelVisibleKHR)) { + 
nameString += sep + "MakeTexelVisibleKHR"; + sep = "|"; + } + if (e == static_cast<uint32_t>(ImageOperand::NonPrivateTexelKHR)) + return "NonPrivateTexelKHR"; + if (e & static_cast<uint32_t>(ImageOperand::NonPrivateTexelKHR)) { + nameString += sep + "NonPrivateTexelKHR"; + sep = "|"; + } + if (e == static_cast<uint32_t>(ImageOperand::VolatileTexelKHR)) + return "VolatileTexelKHR"; + if (e & static_cast<uint32_t>(ImageOperand::VolatileTexelKHR)) { + nameString += sep + "VolatileTexelKHR"; + sep = "|"; + } + if (e == static_cast<uint32_t>(ImageOperand::SignExtend)) + return "SignExtend"; + if (e & static_cast<uint32_t>(ImageOperand::SignExtend)) { + nameString += sep + "SignExtend"; + sep = "|"; + } + if (e == static_cast<uint32_t>(ImageOperand::ZeroExtend)) + return "ZeroExtend"; + if (e & static_cast<uint32_t>(ImageOperand::ZeroExtend)) { + nameString += sep + "ZeroExtend"; + sep = "|"; + }; + return nameString; +} + +std::string getFPFastMathModeName(uint32_t e) { + std::string nameString = ""; + std::string sep = ""; + if (e == static_cast<uint32_t>(FPFastMathMode::None)) + return "None"; + if (e == static_cast<uint32_t>(FPFastMathMode::NotNaN)) + return "NotNaN"; + if (e & static_cast<uint32_t>(FPFastMathMode::NotNaN)) { + nameString += sep + "NotNaN"; + sep = "|"; + } + if (e == static_cast<uint32_t>(FPFastMathMode::NotInf)) + return "NotInf"; + if (e & static_cast<uint32_t>(FPFastMathMode::NotInf)) { + nameString += sep + "NotInf"; + sep = "|"; + } + if (e == static_cast<uint32_t>(FPFastMathMode::NSZ)) + return "NSZ"; + if (e & static_cast<uint32_t>(FPFastMathMode::NSZ)) { + nameString += sep + "NSZ"; + sep = "|"; + } + if (e == static_cast<uint32_t>(FPFastMathMode::AllowRecip)) + return "AllowRecip"; + if (e & static_cast<uint32_t>(FPFastMathMode::AllowRecip)) { + nameString += sep + "AllowRecip"; + sep = "|"; + } + if (e == static_cast<uint32_t>(FPFastMathMode::Fast)) + return "Fast"; + if (e & static_cast<uint32_t>(FPFastMathMode::Fast)) { + nameString += 
sep + "Fast"; + sep = "|"; + }; + return nameString; +} + +StringRef getFPRoundingModeName(FPRoundingMode e) { + switch (e) { + CASE(FPRoundingMode, RTE) + CASE(FPRoundingMode, RTZ) + CASE(FPRoundingMode, RTP) + CASE(FPRoundingMode, RTN) + break; + } + llvm_unreachable("Unexpected operand"); +} + +StringRef getLinkageTypeName(LinkageType e) { + switch (e) { + CASE(LinkageType, Export) + CASE(LinkageType, Import) + break; + } + llvm_unreachable("Unexpected operand"); +} + +StringRef getAccessQualifierName(AccessQualifier e) { + switch (e) { + CASE(AccessQualifier, ReadOnly) + CASE(AccessQualifier, WriteOnly) + CASE(AccessQualifier, ReadWrite) + break; + } + llvm_unreachable("Unexpected operand"); +} + +StringRef getFunctionParameterAttributeName(FunctionParameterAttribute e) { + switch (e) { + CASE(FunctionParameterAttribute, Zext) + CASE(FunctionParameterAttribute, Sext) + CASE(FunctionParameterAttribute, ByVal) + CASE(FunctionParameterAttribute, Sret) + CASE(FunctionParameterAttribute, NoAlias) + CASE(FunctionParameterAttribute, NoCapture) + CASE(FunctionParameterAttribute, NoWrite) + CASE(FunctionParameterAttribute, NoReadWrite) + break; + } + llvm_unreachable("Unexpected operand"); +} + +StringRef getDecorationName(Decoration e) { + switch (e) { + CASE(Decoration, RelaxedPrecision) + CASE(Decoration, SpecId) + CASE(Decoration, Block) + CASE(Decoration, BufferBlock) + CASE(Decoration, RowMajor) + CASE(Decoration, ColMajor) + CASE(Decoration, ArrayStride) + CASE(Decoration, MatrixStride) + CASE(Decoration, GLSLShared) + CASE(Decoration, GLSLPacked) + CASE(Decoration, CPacked) + CASE(Decoration, BuiltIn) + CASE(Decoration, NoPerspective) + CASE(Decoration, Flat) + CASE(Decoration, Patch) + CASE(Decoration, Centroid) + CASE(Decoration, Sample) + CASE(Decoration, Invariant) + CASE(Decoration, Restrict) + CASE(Decoration, Aliased) + CASE(Decoration, Volatile) + CASE(Decoration, Constant) + CASE(Decoration, Coherent) + CASE(Decoration, NonWritable) + CASE(Decoration, 
NonReadable) + CASE(Decoration, Uniform) + CASE(Decoration, UniformId) + CASE(Decoration, SaturatedConversion) + CASE(Decoration, Stream) + CASE(Decoration, Location) + CASE(Decoration, Component) + CASE(Decoration, Index) + CASE(Decoration, Binding) + CASE(Decoration, DescriptorSet) + CASE(Decoration, Offset) + CASE(Decoration, XfbBuffer) + CASE(Decoration, XfbStride) + CASE(Decoration, FuncParamAttr) + CASE(Decoration, FPRoundingMode) + CASE(Decoration, FPFastMathMode) + CASE(Decoration, LinkageAttributes) + CASE(Decoration, NoContraction) + CASE(Decoration, InputAttachmentIndex) + CASE(Decoration, Alignment) + CASE(Decoration, MaxByteOffset) + CASE(Decoration, AlignmentId) + CASE(Decoration, MaxByteOffsetId) + CASE(Decoration, NoSignedWrap) + CASE(Decoration, NoUnsignedWrap) + CASE(Decoration, ExplicitInterpAMD) + CASE(Decoration, OverrideCoverageNV) + CASE(Decoration, PassthroughNV) + CASE(Decoration, ViewportRelativeNV) + CASE(Decoration, SecondaryViewportRelativeNV) + CASE(Decoration, PerPrimitiveNV) + CASE(Decoration, PerViewNV) + CASE(Decoration, PerVertexNV) + CASE(Decoration, NonUniformEXT) + CASE(Decoration, CountBuffer) + CASE(Decoration, UserSemantic) + CASE(Decoration, RestrictPointerEXT) + CASE(Decoration, AliasedPointerEXT) + break; + } + llvm_unreachable("Unexpected operand"); +} + +StringRef getBuiltInName(BuiltIn e) { + switch (e) { + CASE(BuiltIn, Position) + CASE(BuiltIn, PointSize) + CASE(BuiltIn, ClipDistance) + CASE(BuiltIn, CullDistance) + CASE(BuiltIn, VertexId) + CASE(BuiltIn, InstanceId) + CASE(BuiltIn, PrimitiveId) + CASE(BuiltIn, InvocationId) + CASE(BuiltIn, Layer) + CASE(BuiltIn, ViewportIndex) + CASE(BuiltIn, TessLevelOuter) + CASE(BuiltIn, TessLevelInner) + CASE(BuiltIn, TessCoord) + CASE(BuiltIn, PatchVertices) + CASE(BuiltIn, FragCoord) + CASE(BuiltIn, PointCoord) + CASE(BuiltIn, FrontFacing) + CASE(BuiltIn, SampleId) + CASE(BuiltIn, SamplePosition) + CASE(BuiltIn, SampleMask) + CASE(BuiltIn, FragDepth) + CASE(BuiltIn, 
HelperInvocation) + CASE(BuiltIn, NumWorkgroups) + CASE(BuiltIn, WorkgroupSize) + CASE(BuiltIn, WorkgroupId) + CASE(BuiltIn, LocalInvocationId) + CASE(BuiltIn, GlobalInvocationId) + CASE(BuiltIn, LocalInvocationIndex) + CASE(BuiltIn, WorkDim) + CASE(BuiltIn, GlobalSize) + CASE(BuiltIn, EnqueuedWorkgroupSize) + CASE(BuiltIn, GlobalOffset) + CASE(BuiltIn, GlobalLinearId) + CASE(BuiltIn, SubgroupSize) + CASE(BuiltIn, SubgroupMaxSize) + CASE(BuiltIn, NumSubgroups) + CASE(BuiltIn, NumEnqueuedSubgroups) + CASE(BuiltIn, SubgroupId) + CASE(BuiltIn, SubgroupLocalInvocationId) + CASE(BuiltIn, VertexIndex) + CASE(BuiltIn, InstanceIndex) + CASE(BuiltIn, SubgroupEqMask) + CASE(BuiltIn, SubgroupGeMask) + CASE(BuiltIn, SubgroupGtMask) + CASE(BuiltIn, SubgroupLeMask) + CASE(BuiltIn, SubgroupLtMask) + CASE(BuiltIn, BaseVertex) + CASE(BuiltIn, BaseInstance) + CASE(BuiltIn, DrawIndex) + CASE(BuiltIn, DeviceIndex) + CASE(BuiltIn, ViewIndex) + CASE(BuiltIn, BaryCoordNoPerspAMD) + CASE(BuiltIn, BaryCoordNoPerspCentroidAMD) + CASE(BuiltIn, BaryCoordNoPerspSampleAMD) + CASE(BuiltIn, BaryCoordSmoothAMD) + CASE(BuiltIn, BaryCoordSmoothCentroid) + CASE(BuiltIn, BaryCoordSmoothSample) + CASE(BuiltIn, BaryCoordPullModel) + CASE(BuiltIn, FragStencilRefEXT) + CASE(BuiltIn, ViewportMaskNV) + CASE(BuiltIn, SecondaryPositionNV) + CASE(BuiltIn, SecondaryViewportMaskNV) + CASE(BuiltIn, PositionPerViewNV) + CASE(BuiltIn, ViewportMaskPerViewNV) + CASE(BuiltIn, FullyCoveredEXT) + CASE(BuiltIn, TaskCountNV) + CASE(BuiltIn, PrimitiveCountNV) + CASE(BuiltIn, PrimitiveIndicesNV) + CASE(BuiltIn, ClipDistancePerViewNV) + CASE(BuiltIn, CullDistancePerViewNV) + CASE(BuiltIn, LayerPerViewNV) + CASE(BuiltIn, MeshViewCountNV) + CASE(BuiltIn, MeshViewIndices) + CASE(BuiltIn, BaryCoordNV) + CASE(BuiltIn, BaryCoordNoPerspNV) + CASE(BuiltIn, FragSizeEXT) + CASE(BuiltIn, FragInvocationCountEXT) + CASE(BuiltIn, LaunchIdNV) + CASE(BuiltIn, LaunchSizeNV) + CASE(BuiltIn, WorldRayOriginNV) + CASE(BuiltIn, 
WorldRayDirectionNV) + CASE(BuiltIn, ObjectRayOriginNV) + CASE(BuiltIn, ObjectRayDirectionNV) + CASE(BuiltIn, RayTminNV) + CASE(BuiltIn, RayTmaxNV) + CASE(BuiltIn, InstanceCustomIndexNV) + CASE(BuiltIn, ObjectToWorldNV) + CASE(BuiltIn, WorldToObjectNV) + CASE(BuiltIn, HitTNV) + CASE(BuiltIn, HitKindNV) + CASE(BuiltIn, IncomingRayFlagsNV) + break; + } + llvm_unreachable("Unexpected operand"); +} + +std::string getSelectionControlName(uint32_t e) { + std::string nameString = ""; + std::string sep = ""; + if (e == static_cast<uint32_t>(SelectionControl::None)) + return "None"; + if (e == static_cast<uint32_t>(SelectionControl::Flatten)) + return "Flatten"; + if (e & static_cast<uint32_t>(SelectionControl::Flatten)) { + nameString += sep + "Flatten"; + sep = "|"; + } + if (e == static_cast<uint32_t>(SelectionControl::DontFlatten)) + return "DontFlatten"; + if (e & static_cast<uint32_t>(SelectionControl::DontFlatten)) { + nameString += sep + "DontFlatten"; + sep = "|"; + }; + return nameString; +} + +std::string getLoopControlName(uint32_t e) { + std::string nameString = ""; + std::string sep = ""; + if (e == static_cast<uint32_t>(LoopControl::None)) + return "None"; + if (e == static_cast<uint32_t>(LoopControl::Unroll)) + return "Unroll"; + if (e & static_cast<uint32_t>(LoopControl::Unroll)) { + nameString += sep + "Unroll"; + sep = "|"; + } + if (e == static_cast<uint32_t>(LoopControl::DontUnroll)) + return "DontUnroll"; + if (e & static_cast<uint32_t>(LoopControl::DontUnroll)) { + nameString += sep + "DontUnroll"; + sep = "|"; + } + if (e == static_cast<uint32_t>(LoopControl::DependencyInfinite)) + return "DependencyInfinite"; + if (e & static_cast<uint32_t>(LoopControl::DependencyInfinite)) { + nameString += sep + "DependencyInfinite"; + sep = "|"; + } + if (e == static_cast<uint32_t>(LoopControl::DependencyLength)) + return "DependencyLength"; + if (e & static_cast<uint32_t>(LoopControl::DependencyLength)) { + nameString += sep + "DependencyLength"; + sep = "|"; + 
} + if (e == static_cast<uint32_t>(LoopControl::MinIterations)) + return "MinIterations"; + if (e & static_cast<uint32_t>(LoopControl::MinIterations)) { + nameString += sep + "MinIterations"; + sep = "|"; + } + if (e == static_cast<uint32_t>(LoopControl::MaxIterations)) + return "MaxIterations"; + if (e & static_cast<uint32_t>(LoopControl::MaxIterations)) { + nameString += sep + "MaxIterations"; + sep = "|"; + } + if (e == static_cast<uint32_t>(LoopControl::IterationMultiple)) + return "IterationMultiple"; + if (e & static_cast<uint32_t>(LoopControl::IterationMultiple)) { + nameString += sep + "IterationMultiple"; + sep = "|"; + } + if (e == static_cast<uint32_t>(LoopControl::PeelCount)) + return "PeelCount"; + if (e & static_cast<uint32_t>(LoopControl::PeelCount)) { + nameString += sep + "PeelCount"; + sep = "|"; + } + if (e == static_cast<uint32_t>(LoopControl::PartialCount)) + return "PartialCount"; + if (e & static_cast<uint32_t>(LoopControl::PartialCount)) { + nameString += sep + "PartialCount"; + sep = "|"; + }; + return nameString; +} + +std::string getFunctionControlName(uint32_t e) { + std::string nameString = ""; + std::string sep = ""; + if (e == static_cast<uint32_t>(FunctionControl::None)) + return "None"; + if (e == static_cast<uint32_t>(FunctionControl::Inline)) + return "Inline"; + if (e & static_cast<uint32_t>(FunctionControl::Inline)) { + nameString += sep + "Inline"; + sep = "|"; + } + if (e == static_cast<uint32_t>(FunctionControl::DontInline)) + return "DontInline"; + if (e & static_cast<uint32_t>(FunctionControl::DontInline)) { + nameString += sep + "DontInline"; + sep = "|"; + } + if (e == static_cast<uint32_t>(FunctionControl::Pure)) + return "Pure"; + if (e & static_cast<uint32_t>(FunctionControl::Pure)) { + nameString += sep + "Pure"; + sep = "|"; + } + if (e == static_cast<uint32_t>(FunctionControl::Const)) + return "Const"; + if (e & static_cast<uint32_t>(FunctionControl::Const)) { + nameString += sep + "Const"; + sep = "|"; + }; + 
return nameString; +} + +std::string getMemorySemanticsName(uint32_t e) { + std::string nameString = ""; + std::string sep = ""; + if (e == static_cast<uint32_t>(MemorySemantics::None)) + return "None"; + if (e == static_cast<uint32_t>(MemorySemantics::Acquire)) + return "Acquire"; + if (e & static_cast<uint32_t>(MemorySemantics::Acquire)) { + nameString += sep + "Acquire"; + sep = "|"; + } + if (e == static_cast<uint32_t>(MemorySemantics::Release)) + return "Release"; + if (e & static_cast<uint32_t>(MemorySemantics::Release)) { + nameString += sep + "Release"; + sep = "|"; + } + if (e == static_cast<uint32_t>(MemorySemantics::AcquireRelease)) + return "AcquireRelease"; + if (e & static_cast<uint32_t>(MemorySemantics::AcquireRelease)) { + nameString += sep + "AcquireRelease"; + sep = "|"; + } + if (e == static_cast<uint32_t>(MemorySemantics::SequentiallyConsistent)) + return "SequentiallyConsistent"; + if (e & static_cast<uint32_t>(MemorySemantics::SequentiallyConsistent)) { + nameString += sep + "SequentiallyConsistent"; + sep = "|"; + } + if (e == static_cast<uint32_t>(MemorySemantics::UniformMemory)) + return "UniformMemory"; + if (e & static_cast<uint32_t>(MemorySemantics::UniformMemory)) { + nameString += sep + "UniformMemory"; + sep = "|"; + } + if (e == static_cast<uint32_t>(MemorySemantics::SubgroupMemory)) + return "SubgroupMemory"; + if (e & static_cast<uint32_t>(MemorySemantics::SubgroupMemory)) { + nameString += sep + "SubgroupMemory"; + sep = "|"; + } + if (e == static_cast<uint32_t>(MemorySemantics::WorkgroupMemory)) + return "WorkgroupMemory"; + if (e & static_cast<uint32_t>(MemorySemantics::WorkgroupMemory)) { + nameString += sep + "WorkgroupMemory"; + sep = "|"; + } + if (e == static_cast<uint32_t>(MemorySemantics::CrossWorkgroupMemory)) + return "CrossWorkgroupMemory"; + if (e & static_cast<uint32_t>(MemorySemantics::CrossWorkgroupMemory)) { + nameString += sep + "CrossWorkgroupMemory"; + sep = "|"; + } + if (e == 
static_cast<uint32_t>(MemorySemantics::AtomicCounterMemory)) + return "AtomicCounterMemory"; + if (e & static_cast<uint32_t>(MemorySemantics::AtomicCounterMemory)) { + nameString += sep + "AtomicCounterMemory"; + sep = "|"; + } + if (e == static_cast<uint32_t>(MemorySemantics::ImageMemory)) + return "ImageMemory"; + if (e & static_cast<uint32_t>(MemorySemantics::ImageMemory)) { + nameString += sep + "ImageMemory"; + sep = "|"; + } + if (e == static_cast<uint32_t>(MemorySemantics::OutputMemoryKHR)) + return "OutputMemoryKHR"; + if (e & static_cast<uint32_t>(MemorySemantics::OutputMemoryKHR)) { + nameString += sep + "OutputMemoryKHR"; + sep = "|"; + } + if (e == static_cast<uint32_t>(MemorySemantics::MakeAvailableKHR)) + return "MakeAvailableKHR"; + if (e & static_cast<uint32_t>(MemorySemantics::MakeAvailableKHR)) { + nameString += sep + "MakeAvailableKHR"; + sep = "|"; + } + if (e == static_cast<uint32_t>(MemorySemantics::MakeVisibleKHR)) + return "MakeVisibleKHR"; + if (e & static_cast<uint32_t>(MemorySemantics::MakeVisibleKHR)) { + nameString += sep + "MakeVisibleKHR"; + sep = "|"; + }; + return nameString; +} + +std::string getMemoryOperandName(uint32_t e) { + std::string nameString = ""; + std::string sep = ""; + if (e == static_cast<uint32_t>(MemoryOperand::None)) + return "None"; + if (e == static_cast<uint32_t>(MemoryOperand::Volatile)) + return "Volatile"; + if (e & static_cast<uint32_t>(MemoryOperand::Volatile)) { + nameString += sep + "Volatile"; + sep = "|"; + } + if (e == static_cast<uint32_t>(MemoryOperand::Aligned)) + return "Aligned"; + if (e & static_cast<uint32_t>(MemoryOperand::Aligned)) { + nameString += sep + "Aligned"; + sep = "|"; + } + if (e == static_cast<uint32_t>(MemoryOperand::Nontemporal)) + return "Nontemporal"; + if (e & static_cast<uint32_t>(MemoryOperand::Nontemporal)) { + nameString += sep + "Nontemporal"; + sep = "|"; + } + if (e == static_cast<uint32_t>(MemoryOperand::MakePointerAvailableKHR)) + return "MakePointerAvailableKHR"; + 
if (e & static_cast<uint32_t>(MemoryOperand::MakePointerAvailableKHR)) { + nameString += sep + "MakePointerAvailableKHR"; + sep = "|"; + } + if (e == static_cast<uint32_t>(MemoryOperand::MakePointerVisibleKHR)) + return "MakePointerVisibleKHR"; + if (e & static_cast<uint32_t>(MemoryOperand::MakePointerVisibleKHR)) { + nameString += sep + "MakePointerVisibleKHR"; + sep = "|"; + } + if (e == static_cast<uint32_t>(MemoryOperand::NonPrivatePointerKHR)) + return "NonPrivatePointerKHR"; + if (e & static_cast<uint32_t>(MemoryOperand::NonPrivatePointerKHR)) { + nameString += sep + "NonPrivatePointerKHR"; + sep = "|"; + }; + return nameString; +} + +StringRef getScopeName(Scope e) { + switch (e) { + CASE(Scope, CrossDevice) + CASE(Scope, Device) + CASE(Scope, Workgroup) + CASE(Scope, Subgroup) + CASE(Scope, Invocation) + CASE(Scope, QueueFamilyKHR) + break; + } + llvm_unreachable("Unexpected operand"); +} + +StringRef getGroupOperationName(GroupOperation e) { + switch (e) { + CASE(GroupOperation, Reduce) + CASE(GroupOperation, InclusiveScan) + CASE(GroupOperation, ExclusiveScan) + CASE(GroupOperation, ClusteredReduce) + CASE(GroupOperation, PartitionedReduceNV) + CASE(GroupOperation, PartitionedInclusiveScanNV) + CASE(GroupOperation, PartitionedExclusiveScanNV) + break; + } + llvm_unreachable("Unexpected operand"); +} + +StringRef getKernelEnqueueFlagsName(KernelEnqueueFlags e) { + switch (e) { + CASE(KernelEnqueueFlags, NoWait) + CASE(KernelEnqueueFlags, WaitKernel) + CASE(KernelEnqueueFlags, WaitWorkGroup) + break; + } + llvm_unreachable("Unexpected operand"); +} + +StringRef getKernelProfilingInfoName(KernelProfilingInfo e) { + switch (e) { + CASE(KernelProfilingInfo, None) + CASE(KernelProfilingInfo, CmdExecTime) + break; + } + llvm_unreachable("Unexpected operand"); +} +} // namespace SPIRV +} // namespace llvm diff --git a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.h b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.h new file mode 100644 index 
000000000000..2aa9f076c78e --- /dev/null +++ b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.h @@ -0,0 +1,739 @@ +//===-- SPIRVBaseInfo.h - Top level definitions for SPIRV ------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains small standalone helper functions and enum definitions for +// the SPIRV target useful for the compiler back-end and the MC libraries. +// As such, it deliberately does not include references to LLVM core +// code gen types, passes, etc.. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_LIB_TARGET_SPIRV_MCTARGETDESC_SPIRVBASEINFO_H +#define LLVM_LIB_TARGET_SPIRV_MCTARGETDESC_SPIRVBASEINFO_H + +#include "llvm/ADT/StringRef.h" +#include <string> + +namespace llvm { +namespace SPIRV { +enum class Capability : uint32_t { + Matrix = 0, + Shader = 1, + Geometry = 2, + Tessellation = 3, + Addresses = 4, + Linkage = 5, + Kernel = 6, + Vector16 = 7, + Float16Buffer = 8, + Float16 = 9, + Float64 = 10, + Int64 = 11, + Int64Atomics = 12, + ImageBasic = 13, + ImageReadWrite = 14, + ImageMipmap = 15, + Pipes = 17, + Groups = 18, + DeviceEnqueue = 19, + LiteralSampler = 20, + AtomicStorage = 21, + Int16 = 22, + TessellationPointSize = 23, + GeometryPointSize = 24, + ImageGatherExtended = 25, + StorageImageMultisample = 27, + UniformBufferArrayDynamicIndexing = 28, + SampledImageArrayDymnamicIndexing = 29, + ClipDistance = 32, + CullDistance = 33, + ImageCubeArray = 34, + SampleRateShading = 35, + ImageRect = 36, + SampledRect = 37, + GenericPointer = 38, + Int8 = 39, + InputAttachment = 40, + SparseResidency = 41, + MinLod = 42, + Sampled1D = 43, + Image1D = 44, + SampledCubeArray = 45, + SampledBuffer = 46, + 
ImageBuffer = 47, + ImageMSArray = 48, + StorageImageExtendedFormats = 49, + ImageQuery = 50, + DerivativeControl = 51, + InterpolationFunction = 52, + TransformFeedback = 53, + GeometryStreams = 54, + StorageImageReadWithoutFormat = 55, + StorageImageWriteWithoutFormat = 56, + MultiViewport = 57, + SubgroupDispatch = 58, + NamedBarrier = 59, + PipeStorage = 60, + GroupNonUniform = 61, + GroupNonUniformVote = 62, + GroupNonUniformArithmetic = 63, + GroupNonUniformBallot = 64, + GroupNonUniformShuffle = 65, + GroupNonUniformShuffleRelative = 66, + GroupNonUniformClustered = 67, + GroupNonUniformQuad = 68, + SubgroupBallotKHR = 4423, + DrawParameters = 4427, + SubgroupVoteKHR = 4431, + StorageBuffer16BitAccess = 4433, + StorageUniform16 = 4434, + StoragePushConstant16 = 4435, + StorageInputOutput16 = 4436, + DeviceGroup = 4437, + MultiView = 4439, + VariablePointersStorageBuffer = 4441, + VariablePointers = 4442, + AtomicStorageOps = 4445, + SampleMaskPostDepthCoverage = 4447, + StorageBuffer8BitAccess = 4448, + UniformAndStorageBuffer8BitAccess = 4449, + StoragePushConstant8 = 4450, + DenormPreserve = 4464, + DenormFlushToZero = 4465, + SignedZeroInfNanPreserve = 4466, + RoundingModeRTE = 4467, + RoundingModeRTZ = 4468, + Float16ImageAMD = 5008, + ImageGatherBiasLodAMD = 5009, + FragmentMaskAMD = 5010, + StencilExportEXT = 5013, + ImageReadWriteLodAMD = 5015, + SampleMaskOverrideCoverageNV = 5249, + GeometryShaderPassthroughNV = 5251, + ShaderViewportIndexLayerEXT = 5254, + ShaderViewportMaskNV = 5255, + ShaderStereoViewNV = 5259, + PerViewAttributesNV = 5260, + FragmentFullyCoveredEXT = 5265, + MeshShadingNV = 5266, + ShaderNonUniformEXT = 5301, + RuntimeDescriptorArrayEXT = 5302, + InputAttachmentArrayDynamicIndexingEXT = 5303, + UniformTexelBufferArrayDynamicIndexingEXT = 5304, + StorageTexelBufferArrayDynamicIndexingEXT = 5305, + UniformBufferArrayNonUniformIndexingEXT = 5306, + SampledImageArrayNonUniformIndexingEXT = 5307, + 
StorageBufferArrayNonUniformIndexingEXT = 5308, + StorageImageArrayNonUniformIndexingEXT = 5309, + InputAttachmentArrayNonUniformIndexingEXT = 5310, + UniformTexelBufferArrayNonUniformIndexingEXT = 5311, + StorageTexelBufferArrayNonUniformIndexingEXT = 5312, + RayTracingNV = 5340, + SubgroupShuffleINTEL = 5568, + SubgroupBufferBlockIOINTEL = 5569, + SubgroupImageBlockIOINTEL = 5570, + SubgroupImageMediaBlockIOINTEL = 5579, + SubgroupAvcMotionEstimationINTEL = 5696, + SubgroupAvcMotionEstimationIntraINTEL = 5697, + SubgroupAvcMotionEstimationChromaINTEL = 5698, + GroupNonUniformPartitionedNV = 5297, + VulkanMemoryModelKHR = 5345, + VulkanMemoryModelDeviceScopeKHR = 5346, + ImageFootprintNV = 5282, + FragmentBarycentricNV = 5284, + ComputeDerivativeGroupQuadsNV = 5288, + ComputeDerivativeGroupLinearNV = 5350, + FragmentDensityEXT = 5291, + PhysicalStorageBufferAddressesEXT = 5347, + CooperativeMatrixNV = 5357, +}; +StringRef getCapabilityName(Capability e); + +enum class SourceLanguage : uint32_t { + Unknown = 0, + ESSL = 1, + GLSL = 2, + OpenCL_C = 3, + OpenCL_CPP = 4, + HLSL = 5, +}; +StringRef getSourceLanguageName(SourceLanguage e); + +enum class AddressingModel : uint32_t { + Logical = 0, + Physical32 = 1, + Physical64 = 2, + PhysicalStorageBuffer64EXT = 5348, +}; +StringRef getAddressingModelName(AddressingModel e); + +enum class ExecutionModel : uint32_t { + Vertex = 0, + TessellationControl = 1, + TessellationEvaluation = 2, + Geometry = 3, + Fragment = 4, + GLCompute = 5, + Kernel = 6, + TaskNV = 5267, + MeshNV = 5268, + RayGenerationNV = 5313, + IntersectionNV = 5314, + AnyHitNV = 5315, + ClosestHitNV = 5316, + MissNV = 5317, + CallableNV = 5318, +}; +StringRef getExecutionModelName(ExecutionModel e); + +enum class MemoryModel : uint32_t { + Simple = 0, + GLSL450 = 1, + OpenCL = 2, + VulkanKHR = 3, +}; +StringRef getMemoryModelName(MemoryModel e); + +enum class ExecutionMode : uint32_t { + Invocations = 0, + SpacingEqual = 1, + SpacingFractionalEven = 2, + 
SpacingFractionalOdd = 3, + VertexOrderCw = 4, + VertexOrderCcw = 5, + PixelCenterInteger = 6, + OriginUpperLeft = 7, + OriginLowerLeft = 8, + EarlyFragmentTests = 9, + PointMode = 10, + Xfb = 11, + DepthReplacing = 12, + DepthGreater = 14, + DepthLess = 15, + DepthUnchanged = 16, + LocalSize = 17, + LocalSizeHint = 18, + InputPoints = 19, + InputLines = 20, + InputLinesAdjacency = 21, + Triangles = 22, + InputTrianglesAdjacency = 23, + Quads = 24, + Isolines = 25, + OutputVertices = 26, + OutputPoints = 27, + OutputLineStrip = 28, + OutputTriangleStrip = 29, + VecTypeHint = 30, + ContractionOff = 31, + Initializer = 33, + Finalizer = 34, + SubgroupSize = 35, + SubgroupsPerWorkgroup = 36, + SubgroupsPerWorkgroupId = 37, + LocalSizeId = 38, + LocalSizeHintId = 39, + PostDepthCoverage = 4446, + DenormPreserve = 4459, + DenormFlushToZero = 4460, + SignedZeroInfNanPreserve = 4461, + RoundingModeRTE = 4462, + RoundingModeRTZ = 4463, + StencilRefReplacingEXT = 5027, + OutputLinesNV = 5269, + DerivativeGroupQuadsNV = 5289, + DerivativeGroupLinearNV = 5290, + OutputTrianglesNV = 5298, +}; +StringRef getExecutionModeName(ExecutionMode e); + +enum class StorageClass : uint32_t { + UniformConstant = 0, + Input = 1, + Uniform = 2, + Output = 3, + Workgroup = 4, + CrossWorkgroup = 5, + Private = 6, + Function = 7, + Generic = 8, + PushConstant = 9, + AtomicCounter = 10, + Image = 11, + StorageBuffer = 12, + CallableDataNV = 5328, + IncomingCallableDataNV = 5329, + RayPayloadNV = 5338, + HitAttributeNV = 5339, + IncomingRayPayloadNV = 5342, + ShaderRecordBufferNV = 5343, + PhysicalStorageBufferEXT = 5349, +}; +StringRef getStorageClassName(StorageClass e); + +enum class Dim : uint32_t { + DIM_1D = 0, + DIM_2D = 1, + DIM_3D = 2, + DIM_Cube = 3, + DIM_Rect = 4, + DIM_Buffer = 5, + DIM_SubpassData = 6, +}; +StringRef getDimName(Dim e); + +enum class SamplerAddressingMode : uint32_t { + None = 0, + ClampToEdge = 1, + Clamp = 2, + Repeat = 3, + RepeatMirrored = 4, +}; +StringRef 
getSamplerAddressingModeName(SamplerAddressingMode e); + +enum class SamplerFilterMode : uint32_t { + Nearest = 0, + Linear = 1, +}; +StringRef getSamplerFilterModeName(SamplerFilterMode e); + +enum class ImageFormat : uint32_t { + Unknown = 0, + Rgba32f = 1, + Rgba16f = 2, + R32f = 3, + Rgba8 = 4, + Rgba8Snorm = 5, + Rg32f = 6, + Rg16f = 7, + R11fG11fB10f = 8, + R16f = 9, + Rgba16 = 10, + Rgb10A2 = 11, + Rg16 = 12, + Rg8 = 13, + R16 = 14, + R8 = 15, + Rgba16Snorm = 16, + Rg16Snorm = 17, + Rg8Snorm = 18, + R16Snorm = 19, + R8Snorm = 20, + Rgba32i = 21, + Rgba16i = 22, + Rgba8i = 23, + R32i = 24, + Rg32i = 25, + Rg16i = 26, + Rg8i = 27, + R16i = 28, + R8i = 29, + Rgba32ui = 30, + Rgba16ui = 31, + Rgba8ui = 32, + R32ui = 33, + Rgb10a2ui = 34, + Rg32ui = 35, + Rg16ui = 36, + Rg8ui = 37, + R16ui = 38, + R8ui = 39, +}; +StringRef getImageFormatName(ImageFormat e); + +enum class ImageChannelOrder : uint32_t { + R = 0, + A = 1, + RG = 2, + RA = 3, + RGB = 4, + RGBA = 5, + BGRA = 6, + ARGB = 7, + Intensity = 8, + Luminance = 9, + Rx = 10, + RGx = 11, + RGBx = 12, + Depth = 13, + DepthStencil = 14, + sRGB = 15, + sRGBx = 16, + sRGBA = 17, + sBGRA = 18, + ABGR = 19, +}; +StringRef getImageChannelOrderName(ImageChannelOrder e); + +enum class ImageChannelDataType : uint32_t { + SnormInt8 = 0, + SnormInt16 = 1, + UnormInt8 = 2, + UnormInt16 = 3, + UnormShort565 = 4, + UnormShort555 = 5, + UnormInt101010 = 6, + SignedInt8 = 7, + SignedInt16 = 8, + SignedInt32 = 9, + UnsignedInt8 = 10, + UnsignedInt16 = 11, + UnsigendInt32 = 12, + HalfFloat = 13, + Float = 14, + UnormInt24 = 15, + UnormInt101010_2 = 16, +}; +StringRef getImageChannelDataTypeName(ImageChannelDataType e); + +enum class ImageOperand : uint32_t { + None = 0x0, + Bias = 0x1, + Lod = 0x2, + Grad = 0x4, + ConstOffset = 0x8, + Offset = 0x10, + ConstOffsets = 0x20, + Sample = 0x40, + MinLod = 0x80, + MakeTexelAvailableKHR = 0x100, + MakeTexelVisibleKHR = 0x200, + NonPrivateTexelKHR = 0x400, + VolatileTexelKHR = 0x800, + 
SignExtend = 0x1000, + ZeroExtend = 0x2000, +}; +std::string getImageOperandName(uint32_t e); + +enum class FPFastMathMode : uint32_t { + None = 0x0, + NotNaN = 0x1, + NotInf = 0x2, + NSZ = 0x4, + AllowRecip = 0x8, + Fast = 0x10, +}; +std::string getFPFastMathModeName(uint32_t e); + +enum class FPRoundingMode : uint32_t { + RTE = 0, + RTZ = 1, + RTP = 2, + RTN = 3, +}; +StringRef getFPRoundingModeName(FPRoundingMode e); + +enum class LinkageType : uint32_t { + Export = 0, + Import = 1, +}; +StringRef getLinkageTypeName(LinkageType e); + +enum class AccessQualifier : uint32_t { + ReadOnly = 0, + WriteOnly = 1, + ReadWrite = 2, +}; +StringRef getAccessQualifierName(AccessQualifier e); + +enum class FunctionParameterAttribute : uint32_t { + Zext = 0, + Sext = 1, + ByVal = 2, + Sret = 3, + NoAlias = 4, + NoCapture = 5, + NoWrite = 6, + NoReadWrite = 7, +}; +StringRef getFunctionParameterAttributeName(FunctionParameterAttribute e); + +enum class Decoration : uint32_t { + RelaxedPrecision = 0, + SpecId = 1, + Block = 2, + BufferBlock = 3, + RowMajor = 4, + ColMajor = 5, + ArrayStride = 6, + MatrixStride = 7, + GLSLShared = 8, + GLSLPacked = 9, + CPacked = 10, + BuiltIn = 11, + NoPerspective = 13, + Flat = 14, + Patch = 15, + Centroid = 16, + Sample = 17, + Invariant = 18, + Restrict = 19, + Aliased = 20, + Volatile = 21, + Constant = 22, + Coherent = 23, + NonWritable = 24, + NonReadable = 25, + Uniform = 26, + UniformId = 27, + SaturatedConversion = 28, + Stream = 29, + Location = 30, + Component = 31, + Index = 32, + Binding = 33, + DescriptorSet = 34, + Offset = 35, + XfbBuffer = 36, + XfbStride = 37, + FuncParamAttr = 38, + FPRoundingMode = 39, + FPFastMathMode = 40, + LinkageAttributes = 41, + NoContraction = 42, + InputAttachmentIndex = 43, + Alignment = 44, + MaxByteOffset = 45, + AlignmentId = 46, + MaxByteOffsetId = 47, + NoSignedWrap = 4469, + NoUnsignedWrap = 4470, + ExplicitInterpAMD = 4999, + OverrideCoverageNV = 5248, + PassthroughNV = 5250, + 
ViewportRelativeNV = 5252, + SecondaryViewportRelativeNV = 5256, + PerPrimitiveNV = 5271, + PerViewNV = 5272, + PerVertexNV = 5273, + NonUniformEXT = 5300, + CountBuffer = 5634, + UserSemantic = 5635, + RestrictPointerEXT = 5355, + AliasedPointerEXT = 5356, +}; +StringRef getDecorationName(Decoration e); + +enum class BuiltIn : uint32_t { + Position = 0, + PointSize = 1, + ClipDistance = 3, + CullDistance = 4, + VertexId = 5, + InstanceId = 6, + PrimitiveId = 7, + InvocationId = 8, + Layer = 9, + ViewportIndex = 10, + TessLevelOuter = 11, + TessLevelInner = 12, + TessCoord = 13, + PatchVertices = 14, + FragCoord = 15, + PointCoord = 16, + FrontFacing = 17, + SampleId = 18, + SamplePosition = 19, + SampleMask = 20, + FragDepth = 22, + HelperInvocation = 23, + NumWorkgroups = 24, + WorkgroupSize = 25, + WorkgroupId = 26, + LocalInvocationId = 27, + GlobalInvocationId = 28, + LocalInvocationIndex = 29, + WorkDim = 30, + GlobalSize = 31, + EnqueuedWorkgroupSize = 32, + GlobalOffset = 33, + GlobalLinearId = 34, + SubgroupSize = 36, + SubgroupMaxSize = 37, + NumSubgroups = 38, + NumEnqueuedSubgroups = 39, + SubgroupId = 40, + SubgroupLocalInvocationId = 41, + VertexIndex = 42, + InstanceIndex = 43, + SubgroupEqMask = 4416, + SubgroupGeMask = 4417, + SubgroupGtMask = 4418, + SubgroupLeMask = 4419, + SubgroupLtMask = 4420, + BaseVertex = 4424, + BaseInstance = 4425, + DrawIndex = 4426, + DeviceIndex = 4438, + ViewIndex = 4440, + BaryCoordNoPerspAMD = 4492, + BaryCoordNoPerspCentroidAMD = 4493, + BaryCoordNoPerspSampleAMD = 4494, + BaryCoordSmoothAMD = 4495, + BaryCoordSmoothCentroid = 4496, + BaryCoordSmoothSample = 4497, + BaryCoordPullModel = 4498, + FragStencilRefEXT = 5014, + ViewportMaskNV = 5253, + SecondaryPositionNV = 5257, + SecondaryViewportMaskNV = 5258, + PositionPerViewNV = 5261, + ViewportMaskPerViewNV = 5262, + FullyCoveredEXT = 5264, + TaskCountNV = 5274, + PrimitiveCountNV = 5275, + PrimitiveIndicesNV = 5276, + ClipDistancePerViewNV = 5277, + 
CullDistancePerViewNV = 5278, + LayerPerViewNV = 5279, + MeshViewCountNV = 5280, + MeshViewIndices = 5281, + BaryCoordNV = 5286, + BaryCoordNoPerspNV = 5287, + FragSizeEXT = 5292, + FragInvocationCountEXT = 5293, + LaunchIdNV = 5319, + LaunchSizeNV = 5320, + WorldRayOriginNV = 5321, + WorldRayDirectionNV = 5322, + ObjectRayOriginNV = 5323, + ObjectRayDirectionNV = 5324, + RayTminNV = 5325, + RayTmaxNV = 5326, + InstanceCustomIndexNV = 5327, + ObjectToWorldNV = 5330, + WorldToObjectNV = 5331, + HitTNV = 5332, + HitKindNV = 5333, + IncomingRayFlagsNV = 5351, +}; +StringRef getBuiltInName(BuiltIn e); + +enum class SelectionControl : uint32_t { + None = 0x0, + Flatten = 0x1, + DontFlatten = 0x2, +}; +std::string getSelectionControlName(uint32_t e); + +enum class LoopControl : uint32_t { + None = 0x0, + Unroll = 0x1, + DontUnroll = 0x2, + DependencyInfinite = 0x4, + DependencyLength = 0x8, + MinIterations = 0x10, + MaxIterations = 0x20, + IterationMultiple = 0x40, + PeelCount = 0x80, + PartialCount = 0x100, +}; +std::string getLoopControlName(uint32_t e); + +enum class FunctionControl : uint32_t { + None = 0x0, + Inline = 0x1, + DontInline = 0x2, + Pure = 0x4, + Const = 0x8, +}; +std::string getFunctionControlName(uint32_t e); + +enum class MemorySemantics : uint32_t { + None = 0x0, + Acquire = 0x2, + Release = 0x4, + AcquireRelease = 0x8, + SequentiallyConsistent = 0x10, + UniformMemory = 0x40, + SubgroupMemory = 0x80, + WorkgroupMemory = 0x100, + CrossWorkgroupMemory = 0x200, + AtomicCounterMemory = 0x400, + ImageMemory = 0x800, + OutputMemoryKHR = 0x1000, + MakeAvailableKHR = 0x2000, + MakeVisibleKHR = 0x4000, +}; +std::string getMemorySemanticsName(uint32_t e); + +enum class MemoryOperand : uint32_t { + None = 0x0, + Volatile = 0x1, + Aligned = 0x2, + Nontemporal = 0x4, + MakePointerAvailableKHR = 0x8, + MakePointerVisibleKHR = 0x10, + NonPrivatePointerKHR = 0x20, +}; +std::string getMemoryOperandName(uint32_t e); + +enum class Scope : uint32_t { + CrossDevice = 0, 
+ Device = 1, + Workgroup = 2, + Subgroup = 3, + Invocation = 4, + QueueFamilyKHR = 5, +}; +StringRef getScopeName(Scope e); + +enum class GroupOperation : uint32_t { + Reduce = 0, + InclusiveScan = 1, + ExclusiveScan = 2, + ClusteredReduce = 3, + PartitionedReduceNV = 6, + PartitionedInclusiveScanNV = 7, + PartitionedExclusiveScanNV = 8, +}; +StringRef getGroupOperationName(GroupOperation e); + +enum class KernelEnqueueFlags : uint32_t { + NoWait = 0, + WaitKernel = 1, + WaitWorkGroup = 2, +}; +StringRef getKernelEnqueueFlagsName(KernelEnqueueFlags e); + +enum class KernelProfilingInfo : uint32_t { + None = 0x0, + CmdExecTime = 0x1, +}; +StringRef getKernelProfilingInfoName(KernelProfilingInfo e); +} // namespace SPIRV +} // namespace llvm + +// Return a string representation of the operands from startIndex onwards. +// Templated to allow both MachineInstr and MCInst to use the same logic. +template <class InstType> +std::string getSPIRVStringOperand(const InstType &MI, unsigned StartIndex) { + std::string s; // Iteratively append to this string. + + const unsigned NumOps = MI.getNumOperands(); + bool IsFinished = false; + for (unsigned i = StartIndex; i < NumOps && !IsFinished; ++i) { + const auto &Op = MI.getOperand(i); + if (!Op.isImm()) // Stop if we hit a register operand. + break; + assert((Op.getImm() >> 32) == 0 && "Imm operand should be i32 word"); + const uint32_t Imm = Op.getImm(); // Each i32 word is up to 4 characters. + for (unsigned ShiftAmount = 0; ShiftAmount < 32; ShiftAmount += 8) { + char c = (Imm >> ShiftAmount) & 0xff; + if (c == 0) { // Stop if we hit a null-terminator character. + IsFinished = true; + break; + } else { + s += c; // Otherwise, append the character to the result string. 
+ } + } + } + return s; +} + +#endif // LLVM_LIB_TARGET_SPIRV_MCTARGETDESC_SPIRVBASEINFO_H diff --git a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp new file mode 100644 index 000000000000..3105baa02c90 --- /dev/null +++ b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp @@ -0,0 +1,556 @@ +//===-- SPIRVInstPrinter.cpp - Output SPIR-V MCInsts as ASM -----*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This class prints a SPIR-V MCInst to a .s file. +// +//===----------------------------------------------------------------------===// + +#include "SPIRVInstPrinter.h" +#include "SPIRV.h" +#include "SPIRVBaseInfo.h" +#include "llvm/CodeGen/Register.h" +#include "llvm/MC/MCAsmInfo.h" +#include "llvm/MC/MCExpr.h" +#include "llvm/MC/MCInst.h" +#include "llvm/MC/MCInstrInfo.h" +#include "llvm/MC/MCSymbol.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/FormattedStream.h" + +using namespace llvm; + +#define DEBUG_TYPE "asm-printer" + +// Include the auto-generated portion of the assembly writer. 
+#include "SPIRVGenAsmWriter.inc"
+
+// Print the operands of MI from StartIndex to the end, each preceded by a
+// space. SkipFirstSpace omits the space before the operand at StartIndex;
+// SkipImmediates leaves immediate operands out entirely.
+void SPIRVInstPrinter::printRemainingVariableOps(const MCInst *MI,
+                                                 unsigned StartIndex,
+                                                 raw_ostream &O,
+                                                 bool SkipFirstSpace,
+                                                 bool SkipImmediates) {
+  const unsigned NumOps = MI->getNumOperands();
+  for (unsigned i = StartIndex; i < NumOps; ++i) {
+    if (!SkipImmediates || !MI->getOperand(i).isImm()) {
+      if (!SkipFirstSpace || i != StartIndex)
+        O << ' ';
+      printOperand(MI, i, O);
+    }
+  }
+}
+
+// Print the literal operands of OpConstantI/OpConstantF. Exactly two trailing
+// operands are treated as the low and high 32-bit words of one 64-bit literal
+// and printed as a single unsigned integer; any other count is printed
+// operand by operand.
+void SPIRVInstPrinter::printOpConstantVarOps(const MCInst *MI,
+                                             unsigned StartIndex,
+                                             raw_ostream &O) {
+  O << ' ';
+  if (MI->getNumOperands() - StartIndex == 2) { // Handle 64 bit literals.
+    // Truncate each operand to its 32-bit word and combine in unsigned
+    // arithmetic: shifting the signed int64_t from getImm() left by 32 would
+    // overflow (undefined behaviour before C++20) whenever the high word has
+    // its top bit set, and a sign-extended low word would corrupt the upper
+    // half of the result.
+    const uint64_t Lo =
+        static_cast<uint32_t>(MI->getOperand(StartIndex).getImm());
+    const uint64_t Hi =
+        static_cast<uint32_t>(MI->getOperand(StartIndex + 1).getImm());
+    O << ((Hi << 32) | Lo);
+  } else {
+    printRemainingVariableOps(MI, StartIndex, O, true, false);
+  }
+}
+
+// Not implemented yet; reaching this is treated as unreachable.
+void SPIRVInstPrinter::recordOpExtInstImport(const MCInst *MI) {
+  llvm_unreachable("Unimplemented recordOpExtInstImport");
+}
+
+// Print MI: the TableGen-generated fixed part first, then any opcode-specific
+// or variadic trailing operands, then the annotation.
+void SPIRVInstPrinter::printInst(const MCInst *MI, uint64_t Address,
+                                 StringRef Annot, const MCSubtargetInfo &STI,
+                                 raw_ostream &OS) {
+  const unsigned OpCode = MI->getOpcode();
+  printInstruction(MI, Address, OS);
+
+  if (OpCode == SPIRV::OpDecorate) {
+    printOpDecorate(MI, OS);
+  } else if (OpCode == SPIRV::OpExtInstImport) {
+    recordOpExtInstImport(MI);
+  } else if (OpCode == SPIRV::OpExtInst) {
+    printOpExtInst(MI, OS);
+  } else {
+    // Print any extra operands for variadic instructions. MCInstrInfo::get
+    // returns a const reference, so bind to it instead of copying.
+    const MCInstrDesc &MCDesc = MII.get(OpCode);
+    if (MCDesc.isVariadic()) {
+      const unsigned NumFixedOps = MCDesc.getNumOperands();
+      const unsigned LastFixedIndex = NumFixedOps - 1;
+      const unsigned FirstVariableIndex = NumFixedOps;
+      // Optional trailing operands may be absent; only emit the separating
+      // space when there is an operand to print after it.
+      const bool HasVariableOps = MI->getNumOperands() > FirstVariableIndex;
+      if (NumFixedOps > 0 &&
+          MCDesc.OpInfo[LastFixedIndex].OperandType == MCOI::OPERAND_UNKNOWN) {
+        // For instructions where a custom type (not reg or immediate) comes as
+        // the last operand before the variable_ops. This is usually a StringImm
+        // operand, but there are a few other cases.
+        switch (OpCode) {
+        case SPIRV::OpTypeImage:
+          if (HasVariableOps) {
+            OS << ' ';
+            printAccessQualifier(MI, FirstVariableIndex, OS);
+          }
+          break;
+        case SPIRV::OpVariable:
+          if (HasVariableOps) {
+            OS << ' ';
+            printOperand(MI, FirstVariableIndex, OS);
+          }
+          break;
+        case SPIRV::OpEntryPoint:
+          // Print the interface ID operands, skipping the name's string
+          // literal.
+          printRemainingVariableOps(MI, NumFixedOps, OS, false, true);
+          break;
+        case SPIRV::OpExecutionMode:
+        case SPIRV::OpExecutionModeId:
+        case SPIRV::OpLoopMerge:
+          // Print any literals after the OPERAND_UNKNOWN argument normally.
+          printRemainingVariableOps(MI, NumFixedOps, OS);
+          break;
+        default:
+          break; // printStringImm has already been handled
+        }
+      } else {
+        // For instructions with no fixed ops or a reg/immediate as the final
+        // fixed operand, we can usually print the rest with "printOperand", but
+        // check for a few cases with custom types first.
+        switch (OpCode) {
+        case SPIRV::OpLoad:
+        case SPIRV::OpStore:
+          if (HasVariableOps) {
+            OS << ' ';
+            printMemoryOperand(MI, FirstVariableIndex, OS);
+            printRemainingVariableOps(MI, FirstVariableIndex + 1, OS);
+          }
+          break;
+        case SPIRV::OpImageSampleImplicitLod:
+        case SPIRV::OpImageSampleDrefImplicitLod:
+        case SPIRV::OpImageSampleProjImplicitLod:
+        case SPIRV::OpImageSampleProjDrefImplicitLod:
+        case SPIRV::OpImageFetch:
+        case SPIRV::OpImageGather:
+        case SPIRV::OpImageDrefGather:
+        case SPIRV::OpImageRead:
+        case SPIRV::OpImageWrite:
+        case SPIRV::OpImageSparseSampleImplicitLod:
+        case SPIRV::OpImageSparseSampleDrefImplicitLod:
+        case SPIRV::OpImageSparseSampleProjImplicitLod:
+        case SPIRV::OpImageSparseSampleProjDrefImplicitLod:
+        case SPIRV::OpImageSparseFetch:
+        case SPIRV::OpImageSparseGather:
+        case SPIRV::OpImageSparseDrefGather:
+        case SPIRV::OpImageSparseRead:
+        case SPIRV::OpImageSampleFootprintNV:
+          if (HasVariableOps) {
+            OS << ' ';
+            printImageOperand(MI, FirstVariableIndex, OS);
+            printRemainingVariableOps(MI, FirstVariableIndex + 1, OS);
+          }
+          break;
+        case SPIRV::OpCopyMemory:
+        case SPIRV::OpCopyMemorySized: {
+          // Each memory-operand mask with the Aligned bit set is followed by
+          // its alignment literal, which is printed right after the mask.
+          const unsigned NumOps = MI->getNumOperands();
+          for (unsigned i = NumFixedOps; i < NumOps; ++i) {
+            OS << ' ';
+            printMemoryOperand(MI, i, OS);
+            if (MI->getOperand(i).getImm() &
+                static_cast<unsigned>(SPIRV::MemoryOperand::Aligned)) {
+              assert(i + 1 < NumOps && "Missing alignment operand");
+              OS << ' ';
+              printOperand(MI, i + 1, OS);
+              i += 1;
+            }
+          }
+          break;
+        }
+        case SPIRV::OpConstantI:
+        case SPIRV::OpConstantF:
+          printOpConstantVarOps(MI, NumFixedOps, OS);
+          break;
+        default:
+          printRemainingVariableOps(MI, NumFixedOps, OS);
+          break;
+        }
+      }
+    }
+  }
+
+  printAnnotation(OS, Annot);
+}
+
+// Not implemented yet; reaching this is treated as unreachable.
+void SPIRVInstPrinter::printOpExtInst(const MCInst *MI, raw_ostream &O) {
+  llvm_unreachable("Unimplemented printOpExtInst");
+}
+
+void SPIRVInstPrinter::printOpDecorate(const MCInst *MI, raw_ostream &O) {
+  // The fixed operands have already been printed, so just need to decide what
+  // type of decoration operands to print based on the Decoration type.
+  const MCInstrDesc &MCDesc = MII.get(MI->getOpcode());
+  const unsigned NumFixedOps = MCDesc.getNumOperands();
+
+  if (NumFixedOps != MI->getNumOperands()) {
+    // The last fixed operand holds the Decoration kind.
+    const MCOperand &DecOp = MI->getOperand(NumFixedOps - 1);
+    auto Dec = static_cast<SPIRV::Decoration>(DecOp.getImm());
+
+    O << ' ';
+
+    switch (Dec) {
+    case SPIRV::Decoration::BuiltIn:
+      printBuiltIn(MI, NumFixedOps, O);
+      break;
+    case SPIRV::Decoration::UniformId:
+      printScope(MI, NumFixedOps, O);
+      break;
+    case SPIRV::Decoration::FuncParamAttr:
+      printFunctionParameterAttribute(MI, NumFixedOps, O);
+      break;
+    case SPIRV::Decoration::FPRoundingMode:
+      printFPRoundingMode(MI, NumFixedOps, O);
+      break;
+    case SPIRV::Decoration::FPFastMathMode:
+      printFPFastMathMode(MI, NumFixedOps, O);
+      break;
+    case SPIRV::Decoration::LinkageAttributes:
+    case SPIRV::Decoration::UserSemantic:
+      printStringImm(MI, NumFixedOps, O);
+      break;
+    default:
+      printRemainingVariableOps(MI, NumFixedOps, O, true);
+      break;
+    }
+  }
+}
+
+static void printExpr(const MCExpr *Expr, raw_ostream &O) {
+#ifndef NDEBUG
+  const
MCSymbolRefExpr *SRE; + + if (const MCBinaryExpr *BE = dyn_cast<MCBinaryExpr>(Expr)) + SRE = cast<MCSymbolRefExpr>(BE->getLHS()); + else + SRE = cast<MCSymbolRefExpr>(Expr); + + MCSymbolRefExpr::VariantKind Kind = SRE->getKind(); + + assert(Kind == MCSymbolRefExpr::VK_None); +#endif + O << *Expr; +} + +void SPIRVInstPrinter::printOperand(const MCInst *MI, unsigned OpNo, + raw_ostream &O, const char *Modifier) { + assert((Modifier == 0 || Modifier[0] == 0) && "No modifiers supported"); + if (OpNo < MI->getNumOperands()) { + const MCOperand &Op = MI->getOperand(OpNo); + if (Op.isReg()) + O << '%' << (Register::virtReg2Index(Op.getReg()) + 1); + else if (Op.isImm()) + O << formatImm((int64_t)Op.getImm()); + else if (Op.isDFPImm()) + O << formatImm((double)Op.getDFPImm()); + else if (Op.isExpr()) + printExpr(Op.getExpr(), O); + else + llvm_unreachable("Unexpected operand type"); + } +} + +void SPIRVInstPrinter::printStringImm(const MCInst *MI, unsigned OpNo, + raw_ostream &O) { + const unsigned NumOps = MI->getNumOperands(); + unsigned StrStartIndex = OpNo; + while (StrStartIndex < NumOps) { + if (MI->getOperand(StrStartIndex).isReg()) + break; + + std::string Str = getSPIRVStringOperand(*MI, OpNo); + if (StrStartIndex != OpNo) + O << ' '; // Add a space if we're starting a new string/argument. + O << '"'; + for (char c : Str) { + if (c == '"') + O.write('\\'); // Escape " characters (might break for complex UTF-8). + O.write(c); + } + O << '"'; + + unsigned numOpsInString = (Str.size() / 4) + 1; + StrStartIndex += numOpsInString; + + // Check for final Op of "OpDecorate %x %stringImm %linkageAttribute". 
+ if (MI->getOpcode() == SPIRV::OpDecorate && + MI->getOperand(1).getImm() == + static_cast<unsigned>(SPIRV::Decoration::LinkageAttributes)) { + O << ' '; + printLinkageType(MI, StrStartIndex, O); + break; + } + } +} + +void SPIRVInstPrinter::printExtInst(const MCInst *MI, unsigned OpNo, + raw_ostream &O) { + llvm_unreachable("Unimplemented printExtInst"); +} + +void SPIRVInstPrinter::printCapability(const MCInst *MI, unsigned OpNo, + raw_ostream &O) { + if (OpNo < MI->getNumOperands()) { + SPIRV::Capability e = + static_cast<SPIRV::Capability>(MI->getOperand(OpNo).getImm()); + O << SPIRV::getCapabilityName(e); + } +} + +void SPIRVInstPrinter::printSourceLanguage(const MCInst *MI, unsigned OpNo, + raw_ostream &O) { + if (OpNo < MI->getNumOperands()) { + SPIRV::SourceLanguage e = + static_cast<SPIRV::SourceLanguage>(MI->getOperand(OpNo).getImm()); + O << SPIRV::getSourceLanguageName(e); + } +} + +void SPIRVInstPrinter::printExecutionModel(const MCInst *MI, unsigned OpNo, + raw_ostream &O) { + if (OpNo < MI->getNumOperands()) { + SPIRV::ExecutionModel e = + static_cast<SPIRV::ExecutionModel>(MI->getOperand(OpNo).getImm()); + O << SPIRV::getExecutionModelName(e); + } +} + +void SPIRVInstPrinter::printAddressingModel(const MCInst *MI, unsigned OpNo, + raw_ostream &O) { + if (OpNo < MI->getNumOperands()) { + SPIRV::AddressingModel e = + static_cast<SPIRV::AddressingModel>(MI->getOperand(OpNo).getImm()); + O << SPIRV::getAddressingModelName(e); + } +} + +void SPIRVInstPrinter::printMemoryModel(const MCInst *MI, unsigned OpNo, + raw_ostream &O) { + if (OpNo < MI->getNumOperands()) { + SPIRV::MemoryModel e = + static_cast<SPIRV::MemoryModel>(MI->getOperand(OpNo).getImm()); + O << SPIRV::getMemoryModelName(e); + } +} + +void SPIRVInstPrinter::printExecutionMode(const MCInst *MI, unsigned OpNo, + raw_ostream &O) { + if (OpNo < MI->getNumOperands()) { + SPIRV::ExecutionMode e = + static_cast<SPIRV::ExecutionMode>(MI->getOperand(OpNo).getImm()); + O << 
SPIRV::getExecutionModeName(e); + } +} + +void SPIRVInstPrinter::printStorageClass(const MCInst *MI, unsigned OpNo, + raw_ostream &O) { + if (OpNo < MI->getNumOperands()) { + SPIRV::StorageClass e = + static_cast<SPIRV::StorageClass>(MI->getOperand(OpNo).getImm()); + O << SPIRV::getStorageClassName(e); + } +} + +void SPIRVInstPrinter::printDim(const MCInst *MI, unsigned OpNo, + raw_ostream &O) { + if (OpNo < MI->getNumOperands()) { + SPIRV::Dim e = static_cast<SPIRV::Dim>(MI->getOperand(OpNo).getImm()); + O << SPIRV::getDimName(e); + } +} + +void SPIRVInstPrinter::printSamplerAddressingMode(const MCInst *MI, + unsigned OpNo, + raw_ostream &O) { + if (OpNo < MI->getNumOperands()) { + SPIRV::SamplerAddressingMode e = static_cast<SPIRV::SamplerAddressingMode>( + MI->getOperand(OpNo).getImm()); + O << SPIRV::getSamplerAddressingModeName(e); + } +} + +void SPIRVInstPrinter::printSamplerFilterMode(const MCInst *MI, unsigned OpNo, + raw_ostream &O) { + if (OpNo < MI->getNumOperands()) { + SPIRV::SamplerFilterMode e = + static_cast<SPIRV::SamplerFilterMode>(MI->getOperand(OpNo).getImm()); + O << SPIRV::getSamplerFilterModeName(e); + } +} + +void SPIRVInstPrinter::printImageFormat(const MCInst *MI, unsigned OpNo, + raw_ostream &O) { + if (OpNo < MI->getNumOperands()) { + SPIRV::ImageFormat e = + static_cast<SPIRV::ImageFormat>(MI->getOperand(OpNo).getImm()); + O << SPIRV::getImageFormatName(e); + } +} + +void SPIRVInstPrinter::printImageChannelOrder(const MCInst *MI, unsigned OpNo, + raw_ostream &O) { + if (OpNo < MI->getNumOperands()) { + SPIRV::ImageChannelOrder e = + static_cast<SPIRV::ImageChannelOrder>(MI->getOperand(OpNo).getImm()); + O << SPIRV::getImageChannelOrderName(e); + } +} + +void SPIRVInstPrinter::printImageChannelDataType(const MCInst *MI, + unsigned OpNo, + raw_ostream &O) { + if (OpNo < MI->getNumOperands()) { + SPIRV::ImageChannelDataType e = + static_cast<SPIRV::ImageChannelDataType>(MI->getOperand(OpNo).getImm()); + O << 
SPIRV::getImageChannelDataTypeName(e); + } +} + +void SPIRVInstPrinter::printImageOperand(const MCInst *MI, unsigned OpNo, + raw_ostream &O) { + if (OpNo < MI->getNumOperands()) { + unsigned e = static_cast<unsigned>(MI->getOperand(OpNo).getImm()); + O << SPIRV::getImageOperandName(e); + } +} + +void SPIRVInstPrinter::printFPFastMathMode(const MCInst *MI, unsigned OpNo, + raw_ostream &O) { + if (OpNo < MI->getNumOperands()) { + unsigned e = static_cast<unsigned>(MI->getOperand(OpNo).getImm()); + O << SPIRV::getFPFastMathModeName(e); + } +} + +void SPIRVInstPrinter::printFPRoundingMode(const MCInst *MI, unsigned OpNo, + raw_ostream &O) { + if (OpNo < MI->getNumOperands()) { + SPIRV::FPRoundingMode e = + static_cast<SPIRV::FPRoundingMode>(MI->getOperand(OpNo).getImm()); + O << SPIRV::getFPRoundingModeName(e); + } +} + +void SPIRVInstPrinter::printLinkageType(const MCInst *MI, unsigned OpNo, + raw_ostream &O) { + if (OpNo < MI->getNumOperands()) { + SPIRV::LinkageType e = + static_cast<SPIRV::LinkageType>(MI->getOperand(OpNo).getImm()); + O << SPIRV::getLinkageTypeName(e); + } +} + +void SPIRVInstPrinter::printAccessQualifier(const MCInst *MI, unsigned OpNo, + raw_ostream &O) { + if (OpNo < MI->getNumOperands()) { + SPIRV::AccessQualifier e = + static_cast<SPIRV::AccessQualifier>(MI->getOperand(OpNo).getImm()); + O << SPIRV::getAccessQualifierName(e); + } +} + +void SPIRVInstPrinter::printFunctionParameterAttribute(const MCInst *MI, + unsigned OpNo, + raw_ostream &O) { + if (OpNo < MI->getNumOperands()) { + SPIRV::FunctionParameterAttribute e = + static_cast<SPIRV::FunctionParameterAttribute>( + MI->getOperand(OpNo).getImm()); + O << SPIRV::getFunctionParameterAttributeName(e); + } +} + +void SPIRVInstPrinter::printDecoration(const MCInst *MI, unsigned OpNo, + raw_ostream &O) { + if (OpNo < MI->getNumOperands()) { + SPIRV::Decoration e = + static_cast<SPIRV::Decoration>(MI->getOperand(OpNo).getImm()); + O << SPIRV::getDecorationName(e); + } +} + +void 
SPIRVInstPrinter::printBuiltIn(const MCInst *MI, unsigned OpNo, + raw_ostream &O) { + if (OpNo < MI->getNumOperands()) { + SPIRV::BuiltIn e = + static_cast<SPIRV::BuiltIn>(MI->getOperand(OpNo).getImm()); + O << SPIRV::getBuiltInName(e); + } +} + +void SPIRVInstPrinter::printSelectionControl(const MCInst *MI, unsigned OpNo, + raw_ostream &O) { + if (OpNo < MI->getNumOperands()) { + unsigned e = static_cast<unsigned>(MI->getOperand(OpNo).getImm()); + O << SPIRV::getSelectionControlName(e); + } +} + +void SPIRVInstPrinter::printLoopControl(const MCInst *MI, unsigned OpNo, + raw_ostream &O) { + if (OpNo < MI->getNumOperands()) { + unsigned e = static_cast<unsigned>(MI->getOperand(OpNo).getImm()); + O << SPIRV::getLoopControlName(e); + } +} + +void SPIRVInstPrinter::printFunctionControl(const MCInst *MI, unsigned OpNo, + raw_ostream &O) { + if (OpNo < MI->getNumOperands()) { + unsigned e = static_cast<unsigned>(MI->getOperand(OpNo).getImm()); + O << SPIRV::getFunctionControlName(e); + } +} + +void SPIRVInstPrinter::printMemorySemantics(const MCInst *MI, unsigned OpNo, + raw_ostream &O) { + if (OpNo < MI->getNumOperands()) { + unsigned e = static_cast<unsigned>(MI->getOperand(OpNo).getImm()); + O << SPIRV::getMemorySemanticsName(e); + } +} + +void SPIRVInstPrinter::printMemoryOperand(const MCInst *MI, unsigned OpNo, + raw_ostream &O) { + if (OpNo < MI->getNumOperands()) { + unsigned e = static_cast<unsigned>(MI->getOperand(OpNo).getImm()); + O << SPIRV::getMemoryOperandName(e); + } +} + +void SPIRVInstPrinter::printScope(const MCInst *MI, unsigned OpNo, + raw_ostream &O) { + if (OpNo < MI->getNumOperands()) { + SPIRV::Scope e = static_cast<SPIRV::Scope>(MI->getOperand(OpNo).getImm()); + O << SPIRV::getScopeName(e); + } +} + +void SPIRVInstPrinter::printGroupOperation(const MCInst *MI, unsigned OpNo, + raw_ostream &O) { + if (OpNo < MI->getNumOperands()) { + SPIRV::GroupOperation e = + static_cast<SPIRV::GroupOperation>(MI->getOperand(OpNo).getImm()); + O << 
SPIRV::getGroupOperationName(e); + } +} + +void SPIRVInstPrinter::printKernelEnqueueFlags(const MCInst *MI, unsigned OpNo, + raw_ostream &O) { + if (OpNo < MI->getNumOperands()) { + SPIRV::KernelEnqueueFlags e = + static_cast<SPIRV::KernelEnqueueFlags>(MI->getOperand(OpNo).getImm()); + O << SPIRV::getKernelEnqueueFlagsName(e); + } +} + +void SPIRVInstPrinter::printKernelProfilingInfo(const MCInst *MI, unsigned OpNo, + raw_ostream &O) { + if (OpNo < MI->getNumOperands()) { + SPIRV::KernelProfilingInfo e = + static_cast<SPIRV::KernelProfilingInfo>(MI->getOperand(OpNo).getImm()); + O << SPIRV::getKernelProfilingInfoName(e); + } +} diff --git a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.h b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.h new file mode 100644 index 000000000000..cd3b6f1e6d66 --- /dev/null +++ b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.h @@ -0,0 +1,94 @@ +//===-- SPIRVInstPrinter.h - Output SPIR-V MCInsts as ASM -------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This class prints a SPIR-V MCInst to a .s file. 
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIB_TARGET_SPIRV_INSTPRINTER_SPIRVINSTPRINTER_H
+#define LLVM_LIB_TARGET_SPIRV_INSTPRINTER_SPIRVINSTPRINTER_H
+
+#include "llvm/MC/MCInstPrinter.h"
+
+namespace llvm {
+/// Prints SPIR-V MCInsts as textual assembly. The TableGen'erated
+/// printInstruction() prints the fixed operands; printInst() then prints
+/// whatever follows them, chosen by opcode.
+class SPIRVInstPrinter : public MCInstPrinter {
+private:
+  // Not implemented yet: the definition reaches llvm_unreachable.
+  void recordOpExtInstImport(const MCInst *MI);
+
+public:
+  using MCInstPrinter::MCInstPrinter;
+
+  void printInst(const MCInst *MI, uint64_t Address, StringRef Annot,
+                 const MCSubtargetInfo &STI, raw_ostream &OS) override;
+  /// Prints operand OpNo, or nothing if OpNo is out of range. Registers print
+  /// as "%<virtual register index + 1>", immediates through formatImm(), and
+  /// expressions verbatim. Modifier must be null or empty.
+  void printOperand(const MCInst *MI, unsigned OpNo, raw_ostream &O,
+                    const char *Modifier = nullptr);
+
+  /// Prints the literal string(s) starting at operand OpNo in double quotes,
+  /// stopping at the first register operand. For an OpDecorate with
+  /// LinkageAttributes, the linkage type after the name is printed too.
+  void printStringImm(const MCInst *MI, unsigned OpNo, raw_ostream &O);
+
+  /// Prints the operands after OpDecorate's fixed operands, formatted
+  /// according to the Decoration kind.
+  void printOpDecorate(const MCInst *MI, raw_ostream &O);
+  // Not implemented yet: the definition reaches llvm_unreachable.
+  void printOpExtInst(const MCInst *MI, raw_ostream &O);
+  /// Prints operands StartIndex..end, each preceded by a space. The space
+  /// before operand StartIndex is omitted if SkipFirstSpace is set, and
+  /// immediate operands are skipped if SkipImmediates is set.
+  void printRemainingVariableOps(const MCInst *MI, unsigned StartIndex,
+                                 raw_ostream &O, bool SkipFirstSpace = false,
+                                 bool SkipImmediates = false);
+  /// Prints the value operands of OpConstantI/OpConstantF. Exactly two
+  /// remaining operands are printed as one 64-bit literal (low word first).
+  void printOpConstantVarOps(const MCInst *MI, unsigned StartIndex,
+                             raw_ostream &O);
+
+  // Not implemented yet: the definition reaches llvm_unreachable.
+  void printExtInst(const MCInst *MI, unsigned OpNo, raw_ostream &O);
+
+  // SPIR-V enumerations printing. Each prints the immediate at operand OpNo as
+  // the symbolic name of the corresponding SPIR-V enumerant, and prints
+  // nothing if OpNo is out of range.
+  void printCapability(const MCInst *MI, unsigned OpNo, raw_ostream &O);
+  void printSourceLanguage(const MCInst *MI, unsigned OpNo, raw_ostream &O);
+  void printExecutionModel(const MCInst *MI, unsigned OpNo, raw_ostream &O);
+  void printAddressingModel(const MCInst *MI, unsigned OpNo, raw_ostream &O);
+  void printMemoryModel(const MCInst *MI, unsigned OpNo, raw_ostream &O);
+  void printExecutionMode(const MCInst *MI, unsigned OpNo, raw_ostream &O);
+  void printStorageClass(const MCInst *MI, unsigned OpNo, raw_ostream &O);
+  void printDim(const MCInst *MI, unsigned OpNo, raw_ostream &O);
+
+  void printSamplerAddressingMode(const MCInst *MI, unsigned OpNo,
+                                  raw_ostream &O);
+  void printSamplerFilterMode(const MCInst *MI, unsigned OpNo, raw_ostream &O);
+
+  void printImageFormat(const MCInst *MI, unsigned OpNo, raw_ostream &O);
+  void printImageChannelOrder(const MCInst *MI, unsigned OpNo, raw_ostream &O);
+  void printImageChannelDataType(const MCInst *MI, unsigned OpNo,
+                                 raw_ostream &O);
+  void printImageOperand(const MCInst *MI, unsigned OpNo, raw_ostream &O);
+
+  void printFPFastMathMode(const MCInst *MI, unsigned OpNo, raw_ostream &O);
+  void printFPRoundingMode(const MCInst *MI, unsigned OpNo, raw_ostream &O);
+
+  void printLinkageType(const MCInst *MI, unsigned OpNo, raw_ostream &O);
+  void printAccessQualifier(const MCInst *MI, unsigned OpNo, raw_ostream &O);
+  void printFunctionParameterAttribute(const MCInst *MI, unsigned OpNo,
+                                       raw_ostream &O);
+
+  void printDecoration(const MCInst *MI, unsigned OpNo, raw_ostream &O);
+  void printBuiltIn(const MCInst *MI, unsigned OpNo, raw_ostream &O);
+
+  void printSelectionControl(const MCInst *MI, unsigned OpNo, raw_ostream &O);
+  void printLoopControl(const MCInst *MI, unsigned OpNo, raw_ostream &O);
+  void printFunctionControl(const MCInst *MI, unsigned OpNo, raw_ostream &O);
+
+  void printMemorySemantics(const MCInst *MI, unsigned OpNo, raw_ostream &O);
+  void printMemoryOperand(const MCInst *MI, unsigned OpNo,
+                          raw_ostream &O);
+
+  void printScope(const MCInst *MI, unsigned OpNo, raw_ostream &O);
+  void printGroupOperation(const MCInst *MI, unsigned OpNo, raw_ostream &O);
+
+  void printKernelEnqueueFlags(const MCInst *MI, unsigned OpNo, raw_ostream &O);
+  void printKernelProfilingInfo(const MCInst *MI, unsigned OpNo,
+                                raw_ostream &O);
+  // Autogenerated by tblgen (definitions come from the included
+  // SPIRVGenAsmWriter.inc).
+  std::pair<const char *, uint64_t> getMnemonic(const MCInst *MI) override;
+  void printInstruction(const MCInst *MI, uint64_t Address, raw_ostream &O);
+  static const char *getRegisterName(unsigned RegNo);
+};
+} // namespace llvm
+
+#endif // LLVM_LIB_TARGET_SPIRV_INSTPRINTER_SPIRVINSTPRINTER_H
diff --git a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVMCAsmInfo.cpp b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVMCAsmInfo.cpp
new file mode 100644
index 000000000000..2f3462f419e5
--- /dev/null
+++ b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVMCAsmInfo.cpp
@@ -0,0 +1,34 @@
+//===-- SPIRVMCAsmInfo.cpp - SPIR-V asm properties ------------*- C++ -*--====//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains the definitions of the SPIRVMCAsmInfo properties.
+//
+//===----------------------------------------------------------------------===//
+
+#include "SPIRVMCAsmInfo.h"
+#include "llvm/ADT/Triple.h"
+
+using namespace llvm;
+
+// The triple and target options are not consulted; every SPIR-V target gets
+// the same assembly properties.
+SPIRVMCAsmInfo::SPIRVMCAsmInfo(const Triple &TT,
+                               const MCTargetOptions &Options) {
+  // Byte order and size/alignment granularity (in bytes).
+  IsLittleEndian = true;
+  CodePointerSize = 4;
+  MinInstAlignment = 4;
+
+  // Textual assembly conventions: ';' comments and none of the .file/.type/
+  // .size/function-alignment directives.
+  CommentString = ";";
+  HasSingleParameterDotFile = false;
+  HasDotTypeDotSizeDirective = false;
+  HasFunctionAlignment = false;
+}
+
+// Section directives are omitted for every section, whatever its name.
+bool SPIRVMCAsmInfo::shouldOmitSectionDirective(StringRef SectionName) const {
+  return true;
+}
diff --git a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVMCAsmInfo.h b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVMCAsmInfo.h
new file mode 100644
index 000000000000..08e579e1c32c
--- /dev/null
+++ b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVMCAsmInfo.h
@@ -0,0 +1,29 @@
+//===-- SPIRVMCAsmInfo.h - SPIR-V asm properties --------------*- C++ -*--====//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains the declaration of the SPIRVMCAsmInfo class.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIB_TARGET_SPIRV_MCTARGETDESC_SPIRVMCASMINFO_H
+#define LLVM_LIB_TARGET_SPIRV_MCTARGETDESC_SPIRVMCASMINFO_H
+
+#include "llvm/MC/MCAsmInfo.h"
+
+namespace llvm {
+
+class Triple;
+
+/// MCAsmInfo shared by the SPIR-V targets. The property values are set in the
+/// constructor in SPIRVMCAsmInfo.cpp.
+class SPIRVMCAsmInfo : public MCAsmInfo {
+public:
+  explicit SPIRVMCAsmInfo(const Triple &TT, const MCTargetOptions &Options);
+  /// Returns true for every section name, i.e. section directives are never
+  /// emitted.
+  bool shouldOmitSectionDirective(StringRef SectionName) const override;
+};
+} // namespace llvm
+
+#endif // LLVM_LIB_TARGET_SPIRV_MCTARGETDESC_SPIRVMCASMINFO_H
diff --git a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVMCCodeEmitter.cpp b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVMCCodeEmitter.cpp
new file mode 100644
index 000000000000..d953bc590473
--- /dev/null
+++ b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVMCCodeEmitter.cpp
@@ -0,0 +1,132 @@
+//===-- SPIRVMCCodeEmitter.cpp - Emit SPIR-V machine code -------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the SPIRVMCCodeEmitter class.
+//
+//===----------------------------------------------------------------------===//
+
+#include "MCTargetDesc/SPIRVMCTargetDesc.h"
+#include "llvm/CodeGen/Register.h"
+#include "llvm/MC/MCCodeEmitter.h"
+#include "llvm/MC/MCFixup.h"
+#include "llvm/MC/MCInst.h"
+#include "llvm/MC/MCInstrInfo.h"
+#include "llvm/MC/MCRegisterInfo.h"
+#include "llvm/MC/MCSubtargetInfo.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/Endian.h"
+#include "llvm/Support/EndianStream.h"
+
+using namespace llvm;
+
+#define DEBUG_TYPE "spirv-mccodeemitter"
+
+namespace {
+
+/// Encodes each SPIR-V MCInst as a little-endian stream of 32-bit words: a
+/// header word followed by one word per operand.
+class SPIRVMCCodeEmitter : public MCCodeEmitter {
+  const MCInstrInfo &MCII;
+
+public:
+  explicit SPIRVMCCodeEmitter(const MCInstrInfo &mcii) : MCII(mcii) {}
+  SPIRVMCCodeEmitter(const SPIRVMCCodeEmitter &) = delete;
+  void operator=(const SPIRVMCCodeEmitter &) = delete;
+  ~SPIRVMCCodeEmitter() override = default;
+
+  // getBinaryCodeForInstr - TableGen'erated function for getting the
+  // binary encoding for an instruction.
+  uint64_t getBinaryCodeForInstr(const MCInst &MI,
+                                 SmallVectorImpl<MCFixup> &Fixups,
+                                 const MCSubtargetInfo &STI) const;
+
+  void encodeInstruction(const MCInst &MI, raw_ostream &OS,
+                         SmallVectorImpl<MCFixup> &Fixups,
+                         const MCSubtargetInfo &STI) const override;
+
+private:
+  FeatureBitset computeAvailableFeatures(const FeatureBitset &FB) const;
+  void
+  verifyInstructionPredicates(const MCInst &MI,
+                              const FeatureBitset &AvailableFeatures) const;
+};
+
+} // end anonymous namespace
+
+MCCodeEmitter *llvm::createSPIRVMCCodeEmitter(const MCInstrInfo &MCII,
+                                              MCContext &Ctx) {
+  return new SPIRVMCCodeEmitter(MCII);
+}
+
+using EndianWriter = support::endian::Writer;
+
+// Check if the instruction has a type argument for operand 1, and defines an ID
+// output register in operand 0. If so, we need to swap operands 0 and 1 so the
+// type comes first in the output, despite coming second in the MCInst.
+static bool hasType(const MCInst &MI, const MCInstrInfo &MII) {
+  const MCInstrDesc &MCDesc = MII.get(MI.getOpcode());
+  // If we define an output, and have at least one other argument.
+  if (MCDesc.getNumDefs() == 1 && MCDesc.getNumOperands() >= 2) {
+    // Check if we define an ID, and take a type as operand 1.
+    auto DefOpInfo = MCDesc.opInfo_begin();
+    auto FirstArgOpInfo = MCDesc.opInfo_begin() + 1;
+    return (DefOpInfo->RegClass == SPIRV::IDRegClassID ||
+            DefOpInfo->RegClass == SPIRV::ANYIDRegClassID) &&
+           FirstArgOpInfo->RegClass == SPIRV::TYPERegClassID;
+  }
+  return false;
+}
+
+// Writes one operand as a single 32-bit word: registers as their virtual
+// register index + 1, immediates truncated to 32 bits.
+static void emitOperand(const MCOperand &Op, EndianWriter &OSE) {
+  if (Op.isReg()) {
+    // Emit the id index starting at 1 (0 is an invalid index).
+    OSE.write<uint32_t>(Register::virtReg2Index(Op.getReg()) + 1);
+  } else if (Op.isImm()) {
+    OSE.write<uint32_t>(Op.getImm());
+  } else {
+    llvm_unreachable("Unexpected operand type in VReg");
+  }
+}
+
+// Emit the type in operand 1 before the ID in operand 0 it defines, and all
+// remaining operands in the order they come naturally.
+static void emitTypedInstrOperands(const MCInst &MI, EndianWriter &OSE) {
+  unsigned NumOps = MI.getNumOperands();
+  emitOperand(MI.getOperand(1), OSE);
+  emitOperand(MI.getOperand(0), OSE);
+  for (unsigned i = 2; i < NumOps; ++i)
+    emitOperand(MI.getOperand(i), OSE);
+}
+
+// Emit operands in the order they come naturally.
+static void emitUntypedInstrOperands(const MCInst &MI, EndianWriter &OSE) {
+  for (const auto &Op : MI)
+    emitOperand(Op, OSE);
+}
+
+void SPIRVMCCodeEmitter::encodeInstruction(const MCInst &MI, raw_ostream &OS,
+                                           SmallVectorImpl<MCFixup> &Fixups,
+                                           const MCSubtargetInfo &STI) const {
+  auto Features = computeAvailableFeatures(STI.getFeatureBits());
+  verifyInstructionPredicates(MI, Features);
+
+  EndianWriter OSE(OS, support::little);
+
+  // Encode the first 32-bit word: the total word count (this word plus one
+  // word per operand) in the high 16 bits and the opcode in the low 16 bits.
+  const uint64_t OpCode = getBinaryCodeForInstr(MI, Fixups, STI);
+  const uint32_t NumWords = MI.getNumOperands() + 1;
+  assert(NumWords <= 0xFFFF && "SPIR-V word count must fit in 16 bits");
+  const uint32_t FirstWord = (NumWords << 16) | OpCode;
+  OSE.write<uint32_t>(FirstWord);
+
+  // Emit the instruction arguments (emitting the output type first if present).
+  if (hasType(MI, MCII))
+    emitTypedInstrOperands(MI, OSE);
+  else
+    emitUntypedInstrOperands(MI, OSE);
+}
+
+#define ENABLE_INSTR_PREDICATE_VERIFIER
+#include "SPIRVGenMCCodeEmitter.inc"
diff --git a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVMCTargetDesc.cpp b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVMCTargetDesc.cpp
new file mode 100644
index 000000000000..6b8b4a73af92
--- /dev/null
+++ b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVMCTargetDesc.cpp
@@ -0,0 +1,102 @@
+//===-- SPIRVMCTargetDesc.cpp - SPIR-V Target Descriptions ----*- C++ -*---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file provides SPIR-V specific target descriptions.
+//
+//===----------------------------------------------------------------------===//
+
+#include "SPIRVMCTargetDesc.h"
+#include "SPIRVInstPrinter.h"
+#include "SPIRVMCAsmInfo.h"
+#include "SPIRVTargetStreamer.h"
+#include "TargetInfo/SPIRVTargetInfo.h"
+#include "llvm/MC/MCInstrAnalysis.h"
+#include "llvm/MC/MCInstrInfo.h"
+#include "llvm/MC/MCRegisterInfo.h"
+#include "llvm/MC/MCSubtargetInfo.h"
+#include "llvm/MC/TargetRegistry.h"
+
+#define GET_INSTRINFO_MC_DESC
+#include "SPIRVGenInstrInfo.inc"
+
+#define GET_SUBTARGETINFO_MC_DESC
+#include "SPIRVGenSubtargetInfo.inc"
+
+#define GET_REGINFO_MC_DESC
+#include "SPIRVGenRegisterInfo.inc"
+
+using namespace llvm;
+
+// Allocates an MCInstrInfo and fills it from the TableGen'erated tables.
+static MCInstrInfo *createSPIRVMCInstrInfo() {
+  auto *Info = new MCInstrInfo();
+  InitSPIRVMCInstrInfo(Info);
+  return Info;
+}
+
+// Allocates a default-constructed MCRegisterInfo.
+// NOTE(review): unlike createSPIRVMCInstrInfo, no TableGen'erated initializer
+// is run on it here - confirm this is intentional.
+static MCRegisterInfo *createSPIRVMCRegisterInfo(const Triple &TT) {
+  return new MCRegisterInfo();
+}
+
+// Builds the subtarget info, using CPU as both the CPU and the tuning CPU.
+static MCSubtargetInfo *
+createSPIRVMCSubtargetInfo(const Triple &TT, StringRef CPU, StringRef FS) {
+  const StringRef TuneCPU = CPU;
+  return createSPIRVMCSubtargetInfoImpl(TT, CPU, TuneCPU, FS);
+}
+
+// Object streamer factory: hands ownership of the backend, object writer and
+// code emitter over to createSPIRVStreamer.
+static MCStreamer *
+createSPIRVMCStreamer(const Triple &T, MCContext &Ctx,
+                      std::unique_ptr<MCAsmBackend> &&MAB,
+                      std::unique_ptr<MCObjectWriter> &&OW,
+                      std::unique_ptr<MCCodeEmitter> &&Emitter, bool RelaxAll) {
+  return createSPIRVStreamer(Ctx, std::move(MAB), std::move(OW),
+                             std::move(Emitter), RelaxAll);
+}
+
+// Target streamer for textual assembly output; only the owning streamer is
+// used, the remaining arguments are ignored.
+static MCTargetStreamer *createTargetAsmStreamer(MCStreamer &S,
+                                                 formatted_raw_ostream &,
+                                                 MCInstPrinter *, bool) {
+  return new SPIRVTargetStreamer(S);
+}
+
+// Only syntax variant 0 is supported.
+static MCInstPrinter *createSPIRVMCInstPrinter(const Triple &T,
+                                               unsigned SyntaxVariant,
+                                               const MCAsmInfo &MAI,
+                                               const MCInstrInfo &MII,
+                                               const MCRegisterInfo &MRI) {
+  assert(SyntaxVariant == 0);
+  return new SPIRVInstPrinter(MAI, MII, MRI);
+}
+
+namespace {
+
+// Instruction analysis for SPIR-V; it relies entirely on the generic
+// MCInstrAnalysis behaviour.
+class SPIRVMCInstrAnalysis : public MCInstrAnalysis {
+public:
+  explicit SPIRVMCInstrAnalysis(const MCInstrInfo *Info)
+      : MCInstrAnalysis(Info) {}
+};
+
+} // end anonymous namespace
+
+static MCInstrAnalysis *createSPIRVInstrAnalysis(const MCInstrInfo *Info) {
+  return new SPIRVMCInstrAnalysis(Info);
+}
+
+// Registers the same set of MC-layer components with both the SPIR-V 32 and
+// SPIR-V 64 targets.
+extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeSPIRVTargetMC() {
+  for (Target *T : {&getTheSPIRV32Target(), &getTheSPIRV64Target()}) {
+    RegisterMCAsmInfo<SPIRVMCAsmInfo> AsmInfoRegistration(*T);
+    TargetRegistry::RegisterMCInstrInfo(*T, createSPIRVMCInstrInfo);
+    TargetRegistry::RegisterMCRegInfo(*T, createSPIRVMCRegisterInfo);
+    TargetRegistry::RegisterMCSubtargetInfo(*T, createSPIRVMCSubtargetInfo);
+    TargetRegistry::RegisterSPIRVStreamer(*T, createSPIRVMCStreamer);
+    TargetRegistry::RegisterMCInstPrinter(*T, createSPIRVMCInstPrinter);
+    TargetRegistry::RegisterMCInstrAnalysis(*T, createSPIRVInstrAnalysis);
+    TargetRegistry::RegisterMCCodeEmitter(*T, createSPIRVMCCodeEmitter);
+    TargetRegistry::RegisterMCAsmBackend(*T, createSPIRVAsmBackend);
+    TargetRegistry::RegisterAsmTargetStreamer(*T, createTargetAsmStreamer);
+  }
+}
diff --git a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVMCTargetDesc.h b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVMCTargetDesc.h
new file mode 100644
index 000000000000..4009fa96aa68
--- /dev/null
+++ b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVMCTargetDesc.h
@@ -0,0 +1,52 @@
+//===-- SPIRVMCTargetDesc.h - SPIR-V Target Descriptions --------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file provides SPIR-V specific target descriptions.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIB_TARGET_SPIRV_MCTARGETDESC_SPIRVMCTARGETDESC_H
+#define LLVM_LIB_TARGET_SPIRV_MCTARGETDESC_SPIRVMCTARGETDESC_H
+
+#include "llvm/Support/DataTypes.h"
+#include <memory>
+
+namespace llvm {
+class MCAsmBackend;
+class MCCodeEmitter;
+class MCContext;
+class MCInstrInfo;
+class MCObjectTargetWriter;
+class MCRegisterInfo;
+class MCSubtargetInfo;
+class MCTargetOptions;
+class Target;
+
+// Returns a new SPIRVMCCodeEmitter (SPIRVMCCodeEmitter.cpp); Ctx is unused.
+MCCodeEmitter *createSPIRVMCCodeEmitter(const MCInstrInfo &MCII,
+                                        MCContext &Ctx);
+
+// Returns a new little-endian SPIRVAsmBackend (SPIRVAsmBackend.cpp), which
+// applies no fixups and never relaxes instructions.
+MCAsmBackend *createSPIRVAsmBackend(const Target &T, const MCSubtargetInfo &STI,
+                                    const MCRegisterInfo &MRI,
+                                    const MCTargetOptions &Options);
+
+// Returns the object target writer defined in SPIRVObjectTargetWriter.cpp.
+std::unique_ptr<MCObjectTargetWriter> createSPIRVObjectTargetWriter();
+} // namespace llvm
+
+// Defines symbolic names for SPIR-V registers. This defines a mapping from
+// register name to register number.
+#define GET_REGINFO_ENUM
+#include "SPIRVGenRegisterInfo.inc"
+
+// Defines symbolic names for the SPIR-V instructions.
+#define GET_INSTRINFO_ENUM
+#include "SPIRVGenInstrInfo.inc"
+
+// Defines the TableGen'erated subtarget enumerations.
+#define GET_SUBTARGETINFO_ENUM
+#include "SPIRVGenSubtargetInfo.inc"
+
+#endif // LLVM_LIB_TARGET_SPIRV_MCTARGETDESC_SPIRVMCTARGETDESC_H
diff --git a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVObjectTargetWriter.cpp b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVObjectTargetWriter.cpp
new file mode 100644
index 000000000000..685168b4073d
--- /dev/null
+++ b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVObjectTargetWriter.cpp
@@ -0,0 +1,25 @@
+//===- SPIRVObjectTargetWriter.cpp - SPIR-V Object Target Writer *- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "SPIRVMCTargetDesc.h"
+#include "llvm/MC/MCSPIRVObjectWriter.h"
+
+using namespace llvm;
+
+namespace {
+
+// Target-specific part of the SPIR-V object writer. It adds nothing on top of
+// MCSPIRVObjectTargetWriter.
+class SPIRVObjectTargetWriter : public MCSPIRVObjectTargetWriter {
+public:
+  SPIRVObjectTargetWriter() = default;
+};
+
+} // namespace
+
+std::unique_ptr<MCObjectTargetWriter> llvm::createSPIRVObjectTargetWriter() {
+  std::unique_ptr<MCObjectTargetWriter> Writer =
+      std::make_unique<SPIRVObjectTargetWriter>();
+  return Writer;
+}
diff --git a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVTargetStreamer.cpp b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVTargetStreamer.cpp
new file mode 100644
index 000000000000..0a318e0e01e5
--- /dev/null
+++ b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVTargetStreamer.cpp
@@ -0,0 +1,18 @@
+//=====- SPIRVTargetStreamer.cpp - SPIRVTargetStreamer class ------------=====//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the SPIRVTargetStreamer class.
+//
+//===----------------------------------------------------------------------===//
+
+#include "SPIRVTargetStreamer.h"
+
+using namespace llvm;
+
+SPIRVTargetStreamer::SPIRVTargetStreamer(MCStreamer &S) : MCTargetStreamer(S) {}
+
+// Nothing to release. Defining it out of line (under the Itanium C++ ABI)
+// makes this file the home of the class's vtable.
+SPIRVTargetStreamer::~SPIRVTargetStreamer() = default;
diff --git a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVTargetStreamer.h b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVTargetStreamer.h
new file mode 100644
index 000000000000..2cc8f50aba67
--- /dev/null
+++ b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVTargetStreamer.h
@@ -0,0 +1,28 @@
+//===-- SPIRVTargetStreamer.h - SPIRV Target Streamer ----------*- C++ -*--===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef LIB_TARGET_SPIRV_MCTARGETDESC_SPIRVTARGETSTREAMER_H +#define LIB_TARGET_SPIRV_MCTARGETDESC_SPIRVTARGETSTREAMER_H + +#include "llvm/MC/MCStreamer.h" + +namespace llvm { + +class MCSection; + +class SPIRVTargetStreamer : public MCTargetStreamer { +public: + SPIRVTargetStreamer(MCStreamer &S); + ~SPIRVTargetStreamer() override; + + void changeSection(const MCSection *CurSection, MCSection *Section, + const MCExpr *SubSection, raw_ostream &OS) override{}; +}; +} // namespace llvm + +#endif // LIB_TARGET_SPIRV_MCTARGETDESC_SPIRVTARGETSTREAMER_H_ diff --git a/llvm/lib/Target/SPIRV/SPIRV.h b/llvm/lib/Target/SPIRV/SPIRV.h new file mode 100644 index 000000000000..8da54a5d6e61 --- /dev/null +++ b/llvm/lib/Target/SPIRV/SPIRV.h @@ -0,0 +1,34 @@ +//===-- SPIRV.h - Top-level interface for SPIR-V representation -*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_LIB_TARGET_SPIRV_SPIRV_H +#define LLVM_LIB_TARGET_SPIRV_SPIRV_H + +#include "MCTargetDesc/SPIRVMCTargetDesc.h" +#include "llvm/CodeGen/MachineFunctionPass.h" +#include "llvm/Target/TargetMachine.h" + +namespace llvm { +class SPIRVTargetMachine; +class SPIRVSubtarget; +class InstructionSelector; +class RegisterBankInfo; + +FunctionPass *createSPIRVPreLegalizerPass(); +FunctionPass *createSPIRVEmitIntrinsicsPass(SPIRVTargetMachine *TM); +InstructionSelector * +createSPIRVInstructionSelector(const SPIRVTargetMachine &TM, + const SPIRVSubtarget &Subtarget, + const RegisterBankInfo &RBI); + +void initializeSPIRVModuleAnalysisPass(PassRegistry &); +void initializeSPIRVPreLegalizerPass(PassRegistry &); +void initializeSPIRVEmitIntrinsicsPass(PassRegistry &); +} // namespace llvm + +#endif // LLVM_LIB_TARGET_SPIRV_SPIRV_H diff --git a/llvm/lib/Target/SPIRV/SPIRV.td b/llvm/lib/Target/SPIRV/SPIRV.td new file mode 100644 index 000000000000..27374acb8882 --- /dev/null +++ b/llvm/lib/Target/SPIRV/SPIRV.td @@ -0,0 +1,43 @@ +//===-- SPIRV.td - Describe the SPIR-V Target Machine ------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +include "llvm/Target/Target.td" + +include "SPIRVRegisterInfo.td" +include "SPIRVRegisterBanks.td" +include "SPIRVInstrInfo.td" + +def SPIRVInstrInfo : InstrInfo; + +class Proc<string Name, list<SubtargetFeature> Features> + : Processor<Name, NoItineraries, Features>; + +def : Proc<"generic", []>; + +def SPIRV10 : SubtargetFeature<"spirv1.0", "SPIRVVersion", "10", + "Use SPIR-V version 1.0">; +def SPIRV11 : SubtargetFeature<"spirv1.1", "SPIRVVersion", "11", + "Use SPIR-V version 1.1">; +def SPIRV12 : SubtargetFeature<"spirv1.2", "SPIRVVersion", "12", + "Use SPIR-V version 1.2">; +def SPIRV13 : SubtargetFeature<"spirv1.3", "SPIRVVersion", "13", + "Use SPIR-V version 1.3">; +def SPIRV14 : SubtargetFeature<"spirv1.4", "SPIRVVersion", "14", + "Use SPIR-V version 1.4">; +def SPIRV15 : SubtargetFeature<"spirv1.5", "SPIRVVersion", "15", + "Use SPIR-V version 1.5">; + +def SPIRVInstPrinter : AsmWriter { + string AsmWriterClassName = "InstPrinter"; + bit isMCAsmWriter = 1; +} + +def SPIRV : Target { + let InstructionSet = SPIRVInstrInfo; + let AssemblyWriters = [SPIRVInstPrinter]; +} diff --git a/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp b/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp new file mode 100644 index 000000000000..0de232651377 --- /dev/null +++ b/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp @@ -0,0 +1,348 @@ +//===-- SPIRVAsmPrinter.cpp - SPIR-V LLVM assembly writer ------*- C++ -*--===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains a printer that converts from our internal representation +// of machine-dependent LLVM code to the SPIR-V assembly language. +// +//===----------------------------------------------------------------------===// + +#include "MCTargetDesc/SPIRVInstPrinter.h" +#include "SPIRV.h" +#include "SPIRVInstrInfo.h" +#include "SPIRVMCInstLower.h" +#include "SPIRVModuleAnalysis.h" +#include "SPIRVSubtarget.h" +#include "SPIRVTargetMachine.h" +#include "SPIRVUtils.h" +#include "TargetInfo/SPIRVTargetInfo.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/CodeGen/AsmPrinter.h" +#include "llvm/CodeGen/MachineConstantPool.h" +#include "llvm/CodeGen/MachineFunctionPass.h" +#include "llvm/CodeGen/MachineInstr.h" +#include "llvm/CodeGen/MachineModuleInfo.h" +#include "llvm/CodeGen/TargetLoweringObjectFileImpl.h" +#include "llvm/MC/MCAsmInfo.h" +#include "llvm/MC/MCInst.h" +#include "llvm/MC/MCStreamer.h" +#include "llvm/MC/MCSymbol.h" +#include "llvm/MC/TargetRegistry.h" +#include "llvm/Support/raw_ostream.h" + +using namespace llvm; + +#define DEBUG_TYPE "asm-printer" + +namespace { +class SPIRVAsmPrinter : public AsmPrinter { +public: + explicit SPIRVAsmPrinter(TargetMachine &TM, + std::unique_ptr<MCStreamer> Streamer) + : AsmPrinter(TM, std::move(Streamer)), ST(nullptr), TII(nullptr) {} + bool ModuleSectionsEmitted; + const SPIRVSubtarget *ST; + const SPIRVInstrInfo *TII; + + StringRef getPassName() const override { return "SPIRV Assembly Printer"; } + void printOperand(const MachineInstr *MI, int OpNum, raw_ostream &O); + bool PrintAsmOperand(const MachineInstr *MI, unsigned OpNo, + const char *ExtraCode, raw_ostream &O) override; + + void outputMCInst(MCInst &Inst); + void outputInstruction(const MachineInstr *MI); + void outputModuleSection(SPIRV::ModuleSectionType MSType); + void outputEntryPoints(); + void 
outputDebugSourceAndStrings(const Module &M); + void outputOpMemoryModel(); + void outputOpFunctionEnd(); + void outputExtFuncDecls(); + void outputModuleSections(); + + void emitInstruction(const MachineInstr *MI) override; + void emitFunctionEntryLabel() override {} + void emitFunctionHeader() override; + void emitFunctionBodyStart() override {} + void emitFunctionBodyEnd() override; + void emitBasicBlockStart(const MachineBasicBlock &MBB) override; + void emitBasicBlockEnd(const MachineBasicBlock &MBB) override {} + void emitGlobalVariable(const GlobalVariable *GV) override {} + void emitOpLabel(const MachineBasicBlock &MBB); + void emitEndOfAsmFile(Module &M) override; + bool doInitialization(Module &M) override; + + void getAnalysisUsage(AnalysisUsage &AU) const override; + SPIRV::ModuleAnalysisInfo *MAI; +}; +} // namespace + +void SPIRVAsmPrinter::getAnalysisUsage(AnalysisUsage &AU) const { + AU.addRequired<SPIRVModuleAnalysis>(); + AU.addPreserved<SPIRVModuleAnalysis>(); + AsmPrinter::getAnalysisUsage(AU); +} + +// If the module has no functions, we need output global info anyway. +void SPIRVAsmPrinter::emitEndOfAsmFile(Module &M) { + if (ModuleSectionsEmitted == false) { + outputModuleSections(); + ModuleSectionsEmitted = true; + } +} + +void SPIRVAsmPrinter::emitFunctionHeader() { + if (ModuleSectionsEmitted == false) { + outputModuleSections(); + ModuleSectionsEmitted = true; + } + // Get the subtarget from the current MachineFunction. 
+ ST = &MF->getSubtarget<SPIRVSubtarget>(); + TII = ST->getInstrInfo(); + const Function &F = MF->getFunction(); + + if (isVerbose()) { + OutStreamer->getCommentOS() + << "-- Begin function " + << GlobalValue::dropLLVMManglingEscape(F.getName()) << '\n'; + } + + auto Section = getObjFileLowering().SectionForGlobal(&F, TM); + MF->setSection(Section); +} + +void SPIRVAsmPrinter::outputOpFunctionEnd() { + MCInst FunctionEndInst; + FunctionEndInst.setOpcode(SPIRV::OpFunctionEnd); + outputMCInst(FunctionEndInst); +} + +// Emit OpFunctionEnd at the end of MF and clear BBNumToRegMap. +void SPIRVAsmPrinter::emitFunctionBodyEnd() { + outputOpFunctionEnd(); + MAI->BBNumToRegMap.clear(); +} + +void SPIRVAsmPrinter::emitOpLabel(const MachineBasicBlock &MBB) { + MCInst LabelInst; + LabelInst.setOpcode(SPIRV::OpLabel); + LabelInst.addOperand(MCOperand::createReg(MAI->getOrCreateMBBRegister(MBB))); + outputMCInst(LabelInst); +} + +void SPIRVAsmPrinter::emitBasicBlockStart(const MachineBasicBlock &MBB) { + // If it's the first MBB in MF, it has OpFunction and OpFunctionParameter, so + // OpLabel should be output after them. + if (MBB.getNumber() == MF->front().getNumber()) { + for (const MachineInstr &MI : MBB) + if (MI.getOpcode() == SPIRV::OpFunction) + return; + // TODO: this case should be checked by the verifier. 
+ report_fatal_error("OpFunction is expected in the front MBB of MF"); + } + emitOpLabel(MBB); +} + +void SPIRVAsmPrinter::printOperand(const MachineInstr *MI, int OpNum, + raw_ostream &O) { + const MachineOperand &MO = MI->getOperand(OpNum); + + switch (MO.getType()) { + case MachineOperand::MO_Register: + O << SPIRVInstPrinter::getRegisterName(MO.getReg()); + break; + + case MachineOperand::MO_Immediate: + O << MO.getImm(); + break; + + case MachineOperand::MO_FPImmediate: + O << MO.getFPImm(); + break; + + case MachineOperand::MO_MachineBasicBlock: + O << *MO.getMBB()->getSymbol(); + break; + + case MachineOperand::MO_GlobalAddress: + O << *getSymbol(MO.getGlobal()); + break; + + case MachineOperand::MO_BlockAddress: { + MCSymbol *BA = GetBlockAddressSymbol(MO.getBlockAddress()); + O << BA->getName(); + break; + } + + case MachineOperand::MO_ExternalSymbol: + O << *GetExternalSymbolSymbol(MO.getSymbolName()); + break; + + case MachineOperand::MO_JumpTableIndex: + case MachineOperand::MO_ConstantPoolIndex: + default: + llvm_unreachable("<unknown operand type>"); + } +} + +bool SPIRVAsmPrinter::PrintAsmOperand(const MachineInstr *MI, unsigned OpNo, + const char *ExtraCode, raw_ostream &O) { + if (ExtraCode && ExtraCode[0]) + return true; // Invalid instruction - SPIR-V does not have special modifiers + + printOperand(MI, OpNo, O); + return false; +} + +static bool isFuncOrHeaderInstr(const MachineInstr *MI, + const SPIRVInstrInfo *TII) { + return TII->isHeaderInstr(*MI) || MI->getOpcode() == SPIRV::OpFunction || + MI->getOpcode() == SPIRV::OpFunctionParameter; +} + +void SPIRVAsmPrinter::outputMCInst(MCInst &Inst) { + OutStreamer->emitInstruction(Inst, *OutContext.getSubtargetInfo()); +} + +void SPIRVAsmPrinter::outputInstruction(const MachineInstr *MI) { + SPIRVMCInstLower MCInstLowering; + MCInst TmpInst; + MCInstLowering.lower(MI, TmpInst, MAI); + outputMCInst(TmpInst); +} + +void SPIRVAsmPrinter::emitInstruction(const MachineInstr *MI) { + if 
(!MAI->getSkipEmission(MI)) + outputInstruction(MI); + + // Output OpLabel after OpFunction and OpFunctionParameter in the first MBB. + const MachineInstr *NextMI = MI->getNextNode(); + if (!MAI->hasMBBRegister(*MI->getParent()) && isFuncOrHeaderInstr(MI, TII) && + (!NextMI || !isFuncOrHeaderInstr(NextMI, TII))) { + assert(MI->getParent()->getNumber() == MF->front().getNumber() && + "OpFunction is not in the front MBB of MF"); + emitOpLabel(*MI->getParent()); + } +} + +void SPIRVAsmPrinter::outputModuleSection(SPIRV::ModuleSectionType MSType) { + for (MachineInstr *MI : MAI->getMSInstrs(MSType)) + outputInstruction(MI); +} + +void SPIRVAsmPrinter::outputDebugSourceAndStrings(const Module &M) { + // Output OpSource. + MCInst Inst; + Inst.setOpcode(SPIRV::OpSource); + Inst.addOperand(MCOperand::createImm(static_cast<unsigned>(MAI->SrcLang))); + Inst.addOperand( + MCOperand::createImm(static_cast<unsigned>(MAI->SrcLangVersion))); + outputMCInst(Inst); +} + +void SPIRVAsmPrinter::outputOpMemoryModel() { + MCInst Inst; + Inst.setOpcode(SPIRV::OpMemoryModel); + Inst.addOperand(MCOperand::createImm(static_cast<unsigned>(MAI->Addr))); + Inst.addOperand(MCOperand::createImm(static_cast<unsigned>(MAI->Mem))); + outputMCInst(Inst); +} + +// Before the OpEntryPoints' output, we need to add the entry point's +// interfaces. The interface is a list of IDs of global OpVariable instructions. +// These declare the set of global variables from a module that form +// the interface of this entry point. +void SPIRVAsmPrinter::outputEntryPoints() { + // Find all OpVariable IDs with required StorageClass. + DenseSet<Register> InterfaceIDs; + for (MachineInstr *MI : MAI->GlobalVarList) { + assert(MI->getOpcode() == SPIRV::OpVariable); + auto SC = static_cast<SPIRV::StorageClass>(MI->getOperand(2).getImm()); + // Before version 1.4, the interface's storage classes are limited to + // the Input and Output storage classes. 
Starting with version 1.4, + // the interface's storage classes are all storage classes used in + // declaring all global variables referenced by the entry point call tree. + if (ST->getSPIRVVersion() >= 14 || SC == SPIRV::StorageClass::Input || + SC == SPIRV::StorageClass::Output) { + MachineFunction *MF = MI->getMF(); + Register Reg = MAI->getRegisterAlias(MF, MI->getOperand(0).getReg()); + InterfaceIDs.insert(Reg); + } + } + + // Output OpEntryPoints adding interface args to all of them. + for (MachineInstr *MI : MAI->getMSInstrs(SPIRV::MB_EntryPoints)) { + SPIRVMCInstLower MCInstLowering; + MCInst TmpInst; + MCInstLowering.lower(MI, TmpInst, MAI); + for (Register Reg : InterfaceIDs) { + assert(Reg.isValid()); + TmpInst.addOperand(MCOperand::createReg(Reg)); + } + outputMCInst(TmpInst); + } +} + +void SPIRVAsmPrinter::outputExtFuncDecls() { + // Insert OpFunctionEnd after each declaration. + SmallVectorImpl<MachineInstr *>::iterator + I = MAI->getMSInstrs(SPIRV::MB_ExtFuncDecls).begin(), + E = MAI->getMSInstrs(SPIRV::MB_ExtFuncDecls).end(); + for (; I != E; ++I) { + outputInstruction(*I); + if ((I + 1) == E || (*(I + 1))->getOpcode() == SPIRV::OpFunction) + outputOpFunctionEnd(); + } +} + +void SPIRVAsmPrinter::outputModuleSections() { + const Module *M = MMI->getModule(); + // Get the global subtarget to output module-level info. + ST = static_cast<const SPIRVTargetMachine &>(TM).getSubtargetImpl(); + TII = ST->getInstrInfo(); + MAI = &SPIRVModuleAnalysis::MAI; + assert(ST && TII && MAI && M && "Module analysis is required"); + // Output instructions according to the Logical Layout of a Module: + // TODO: 1,2. All OpCapability instructions, then optional OpExtension + // instructions. + // TODO: 3. Optional OpExtInstImport instructions. + // 4. The single required OpMemoryModel instruction. + outputOpMemoryModel(); + // 5. All entry point declarations, using OpEntryPoint. + outputEntryPoints(); + // 6. 
Execution-mode declarations, using OpExecutionMode or OpExecutionModeId. + // TODO: + // 7a. Debug: all OpString, OpSourceExtension, OpSource, and + // OpSourceContinued, without forward references. + outputDebugSourceAndStrings(*M); + // 7b. Debug: all OpName and all OpMemberName. + outputModuleSection(SPIRV::MB_DebugNames); + // 7c. Debug: all OpModuleProcessed instructions. + outputModuleSection(SPIRV::MB_DebugModuleProcessed); + // 8. All annotation instructions (all decorations). + outputModuleSection(SPIRV::MB_Annotations); + // 9. All type declarations (OpTypeXXX instructions), all constant + // instructions, and all global variable declarations. This section is + // the first section to allow use of: OpLine and OpNoLine debug information; + // non-semantic instructions with OpExtInst. + outputModuleSection(SPIRV::MB_TypeConstVars); + // 10. All function declarations (functions without a body). + outputExtFuncDecls(); + // 11. All function definitions (functions with a body). + // This is done in regular function output. +} + +bool SPIRVAsmPrinter::doInitialization(Module &M) { + ModuleSectionsEmitted = false; + // We need to call the parent's one explicitly. + return AsmPrinter::doInitialization(M); +} + +// Force static initialization. +extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeSPIRVAsmPrinter() { + RegisterAsmPrinter<SPIRVAsmPrinter> X(getTheSPIRV32Target()); + RegisterAsmPrinter<SPIRVAsmPrinter> Y(getTheSPIRV64Target()); +} diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp new file mode 100644 index 000000000000..df07a126eeea --- /dev/null +++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp @@ -0,0 +1,223 @@ +//===--- SPIRVCallLowering.cpp - Call lowering ------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the lowering of LLVM calls to machine code calls for +// GlobalISel. +// +//===----------------------------------------------------------------------===// + +#include "SPIRVCallLowering.h" +#include "MCTargetDesc/SPIRVBaseInfo.h" +#include "SPIRV.h" +#include "SPIRVGlobalRegistry.h" +#include "SPIRVISelLowering.h" +#include "SPIRVRegisterInfo.h" +#include "SPIRVSubtarget.h" +#include "SPIRVUtils.h" +#include "llvm/CodeGen/FunctionLoweringInfo.h" + +using namespace llvm; + +SPIRVCallLowering::SPIRVCallLowering(const SPIRVTargetLowering &TLI, + const SPIRVSubtarget &ST, + SPIRVGlobalRegistry *GR) + : CallLowering(&TLI), ST(ST), GR(GR) {} + +bool SPIRVCallLowering::lowerReturn(MachineIRBuilder &MIRBuilder, + const Value *Val, ArrayRef<Register> VRegs, + FunctionLoweringInfo &FLI, + Register SwiftErrorVReg) const { + // Currently all return types should use a single register. + // TODO: handle the case of multiple registers. + if (VRegs.size() > 1) + return false; + if (Val) + return MIRBuilder.buildInstr(SPIRV::OpReturnValue) + .addUse(VRegs[0]) + .constrainAllUses(MIRBuilder.getTII(), *ST.getRegisterInfo(), + *ST.getRegBankInfo()); + MIRBuilder.buildInstr(SPIRV::OpReturn); + return true; +} + +// Based on the LLVM function attributes, get a SPIR-V FunctionControl. 
+static uint32_t getFunctionControl(const Function &F) { + uint32_t FuncControl = static_cast<uint32_t>(SPIRV::FunctionControl::None); + if (F.hasFnAttribute(Attribute::AttrKind::AlwaysInline)) { + FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::Inline); + } + if (F.hasFnAttribute(Attribute::AttrKind::ReadNone)) { + FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::Pure); + } + if (F.hasFnAttribute(Attribute::AttrKind::ReadOnly)) { + FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::Const); + } + if (F.hasFnAttribute(Attribute::AttrKind::NoInline)) { + FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::DontInline); + } + return FuncControl; +} + +bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder, + const Function &F, + ArrayRef<ArrayRef<Register>> VRegs, + FunctionLoweringInfo &FLI) const { + assert(GR && "Must initialize the SPIRV type registry before lowering args."); + + // Assign types and names to all args, and store their types for later. + SmallVector<Register, 4> ArgTypeVRegs; + if (VRegs.size() > 0) { + unsigned i = 0; + for (const auto &Arg : F.args()) { + // Currently formal args should use single registers. + // TODO: handle the case of multiple registers. 
+ if (VRegs[i].size() > 1) + return false; + auto *SpirvTy = + GR->assignTypeToVReg(Arg.getType(), VRegs[i][0], MIRBuilder); + ArgTypeVRegs.push_back(GR->getSPIRVTypeID(SpirvTy)); + + if (Arg.hasName()) + buildOpName(VRegs[i][0], Arg.getName(), MIRBuilder); + if (Arg.getType()->isPointerTy()) { + auto DerefBytes = static_cast<unsigned>(Arg.getDereferenceableBytes()); + if (DerefBytes != 0) + buildOpDecorate(VRegs[i][0], MIRBuilder, + SPIRV::Decoration::MaxByteOffset, {DerefBytes}); + } + if (Arg.hasAttribute(Attribute::Alignment)) { + buildOpDecorate(VRegs[i][0], MIRBuilder, SPIRV::Decoration::Alignment, + {static_cast<unsigned>(Arg.getParamAlignment())}); + } + if (Arg.hasAttribute(Attribute::ReadOnly)) { + auto Attr = + static_cast<unsigned>(SPIRV::FunctionParameterAttribute::NoWrite); + buildOpDecorate(VRegs[i][0], MIRBuilder, + SPIRV::Decoration::FuncParamAttr, {Attr}); + } + if (Arg.hasAttribute(Attribute::ZExt)) { + auto Attr = + static_cast<unsigned>(SPIRV::FunctionParameterAttribute::Zext); + buildOpDecorate(VRegs[i][0], MIRBuilder, + SPIRV::Decoration::FuncParamAttr, {Attr}); + } + ++i; + } + } + + // Generate a SPIR-V type for the function. + auto MRI = MIRBuilder.getMRI(); + Register FuncVReg = MRI->createGenericVirtualRegister(LLT::scalar(32)); + MRI->setRegClass(FuncVReg, &SPIRV::IDRegClass); + + auto *FTy = F.getFunctionType(); + auto FuncTy = GR->assignTypeToVReg(FTy, FuncVReg, MIRBuilder); + + // Build the OpTypeFunction declaring it. + Register ReturnTypeID = FuncTy->getOperand(1).getReg(); + uint32_t FuncControl = getFunctionControl(F); + + MIRBuilder.buildInstr(SPIRV::OpFunction) + .addDef(FuncVReg) + .addUse(ReturnTypeID) + .addImm(FuncControl) + .addUse(GR->getSPIRVTypeID(FuncTy)); + + // Add OpFunctionParameters. 
+ const unsigned NumArgs = ArgTypeVRegs.size(); + for (unsigned i = 0; i < NumArgs; ++i) { + assert(VRegs[i].size() == 1 && "Formal arg has multiple vregs"); + MRI->setRegClass(VRegs[i][0], &SPIRV::IDRegClass); + MIRBuilder.buildInstr(SPIRV::OpFunctionParameter) + .addDef(VRegs[i][0]) + .addUse(ArgTypeVRegs[i]); + } + // Name the function. + if (F.hasName()) + buildOpName(FuncVReg, F.getName(), MIRBuilder); + + // Handle entry points and function linkage. + if (F.getCallingConv() == CallingConv::SPIR_KERNEL) { + auto MIB = MIRBuilder.buildInstr(SPIRV::OpEntryPoint) + .addImm(static_cast<uint32_t>(SPIRV::ExecutionModel::Kernel)) + .addUse(FuncVReg); + addStringImm(F.getName(), MIB); + } else if (F.getLinkage() == GlobalValue::LinkageTypes::ExternalLinkage || + F.getLinkage() == GlobalValue::LinkOnceODRLinkage) { + auto LnkTy = F.isDeclaration() ? SPIRV::LinkageType::Import + : SPIRV::LinkageType::Export; + buildOpDecorate(FuncVReg, MIRBuilder, SPIRV::Decoration::LinkageAttributes, + {static_cast<uint32_t>(LnkTy)}, F.getGlobalIdentifier()); + } + + return true; +} + +bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder, + CallLoweringInfo &Info) const { + // Currently call returns should have single vregs. + // TODO: handle the case of multiple registers. + if (Info.OrigRet.Regs.size() > 1) + return false; + + Register ResVReg = + Info.OrigRet.Regs.empty() ? Register(0) : Info.OrigRet.Regs[0]; + // Emit a regular OpFunctionCall. If it's an externally declared function, + // be sure to emit its type and function declaration here. It will be + // hoisted globally later. + if (Info.Callee.isGlobal()) { + auto *CF = dyn_cast_or_null<const Function>(Info.Callee.getGlobal()); + // TODO: support constexpr casts and indirect calls. + if (CF == nullptr) + return false; + if (CF->isDeclaration()) { + // Emit the type info and forward function declaration to the first MBB + // to ensure VReg definition dependencies are valid across all MBBs. 
+ MachineBasicBlock::iterator OldII = MIRBuilder.getInsertPt(); + MachineBasicBlock &OldBB = MIRBuilder.getMBB(); + MachineBasicBlock &FirstBB = *MIRBuilder.getMF().getBlockNumbered(0); + MIRBuilder.setInsertPt(FirstBB, FirstBB.instr_end()); + + SmallVector<ArrayRef<Register>, 8> VRegArgs; + SmallVector<SmallVector<Register, 1>, 8> ToInsert; + for (const Argument &Arg : CF->args()) { + if (MIRBuilder.getDataLayout().getTypeStoreSize(Arg.getType()).isZero()) + continue; // Don't handle zero sized types. + ToInsert.push_back({MIRBuilder.getMRI()->createGenericVirtualRegister( + LLT::scalar(32))}); + VRegArgs.push_back(ToInsert.back()); + } + // TODO: Reuse FunctionLoweringInfo. + FunctionLoweringInfo FuncInfo; + lowerFormalArguments(MIRBuilder, *CF, VRegArgs, FuncInfo); + MIRBuilder.setInsertPt(OldBB, OldII); + } + } + + // Make sure there's a valid return reg, even for functions returning void. + if (!ResVReg.isValid()) { + ResVReg = MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::IDRegClass); + } + SPIRVType *RetType = + GR->assignTypeToVReg(Info.OrigRet.Ty, ResVReg, MIRBuilder); + + // Emit the OpFunctionCall and its args. + auto MIB = MIRBuilder.buildInstr(SPIRV::OpFunctionCall) + .addDef(ResVReg) + .addUse(GR->getSPIRVTypeID(RetType)) + .add(Info.Callee); + + for (const auto &Arg : Info.OrigArgs) { + // Currently call args should have single vregs. + if (Arg.Regs.size() > 1) + return false; + MIB.addUse(Arg.Regs[0]); + } + return MIB.constrainAllUses(MIRBuilder.getTII(), *ST.getRegisterInfo(), + *ST.getRegBankInfo()); +} diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.h b/llvm/lib/Target/SPIRV/SPIRVCallLowering.h new file mode 100644 index 000000000000..c179bb35154b --- /dev/null +++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.h @@ -0,0 +1,50 @@ +//===--- SPIRVCallLowering.h - Call lowering --------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file describes how to lower LLVM calls to machine code calls. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_LIB_TARGET_SPIRV_SPIRVCALLLOWERING_H +#define LLVM_LIB_TARGET_SPIRV_SPIRVCALLLOWERING_H + +#include "llvm/CodeGen/GlobalISel/CallLowering.h" + +namespace llvm { + +class SPIRVGlobalRegistry; +class SPIRVSubtarget; +class SPIRVTargetLowering; + +class SPIRVCallLowering : public CallLowering { +private: + const SPIRVSubtarget &ST; + // Used to create and assign function, argument, and return type information. + SPIRVGlobalRegistry *GR; + +public: + SPIRVCallLowering(const SPIRVTargetLowering &TLI, const SPIRVSubtarget &ST, + SPIRVGlobalRegistry *GR); + + // Built OpReturn or OpReturnValue. + bool lowerReturn(MachineIRBuilder &MIRBuiler, const Value *Val, + ArrayRef<Register> VRegs, FunctionLoweringInfo &FLI, + Register SwiftErrorVReg) const override; + + // Build OpFunction, OpFunctionParameter, and any EntryPoint or Linkage data. + bool lowerFormalArguments(MachineIRBuilder &MIRBuilder, const Function &F, + ArrayRef<ArrayRef<Register>> VRegs, + FunctionLoweringInfo &FLI) const override; + + // Build OpCall, or replace with a builtin function. 
+ bool lowerCall(MachineIRBuilder &MIRBuilder, + CallLoweringInfo &Info) const override; +}; +} // end namespace llvm + +#endif // LLVM_LIB_TARGET_SPIRV_SPIRVCALLLOWERING_H diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp new file mode 100644 index 000000000000..9624482e3622 --- /dev/null +++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp @@ -0,0 +1,433 @@ +//===-- SPIRVEmitIntrinsics.cpp - emit SPIRV intrinsics ---------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// The pass emits SPIRV intrinsics keeping essential high-level information for +// the translation of LLVM IR to SPIR-V. +// +//===----------------------------------------------------------------------===// + +#include "SPIRV.h" +#include "SPIRVTargetMachine.h" +#include "SPIRVUtils.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/InstIterator.h" +#include "llvm/IR/InstVisitor.h" +#include "llvm/IR/IntrinsicsSPIRV.h" + +#include <queue> + +// This pass performs the following transformation on LLVM IR level required +// for the following translation to SPIR-V: +// - replaces direct usages of aggregate constants with target-specific +// intrinsics; +// - replaces aggregates-related instructions (extract/insert, ld/st, etc) +// with a target-specific intrinsics; +// - emits intrinsics for the global variable initializers since IRTranslator +// doesn't handle them and it's not very convenient to translate them +// ourselves; +// - emits intrinsics to keep track of the string names assigned to the values; +// - emits intrinsics to keep track of constants (this is necessary to have an +// LLVM IR constant after the IRTranslation is completed) for their further +// deduplication; 
+// - emits intrinsics to keep track of original LLVM types of the values +// to be able to emit proper SPIR-V types eventually. +// +// TODO: consider removing spv.track.constant in favor of spv.assign.type. + +using namespace llvm; + +namespace llvm { +void initializeSPIRVEmitIntrinsicsPass(PassRegistry &); +} // namespace llvm + +namespace { +class SPIRVEmitIntrinsics + : public FunctionPass, + public InstVisitor<SPIRVEmitIntrinsics, Instruction *> { + SPIRVTargetMachine *TM = nullptr; + IRBuilder<> *IRB = nullptr; + Function *F = nullptr; + bool TrackConstants = true; + DenseMap<Instruction *, Constant *> AggrConsts; + DenseSet<Instruction *> AggrStores; + void preprocessCompositeConstants(); + CallInst *buildIntrWithMD(Intrinsic::ID IntrID, ArrayRef<Type *> Types, + Value *Arg, Value *Arg2) { + ConstantAsMetadata *CM = ValueAsMetadata::getConstant(Arg); + MDTuple *TyMD = MDNode::get(F->getContext(), CM); + MetadataAsValue *VMD = MetadataAsValue::get(F->getContext(), TyMD); + return IRB->CreateIntrinsic(IntrID, {Types}, {Arg2, VMD}); + } + void replaceMemInstrUses(Instruction *Old, Instruction *New); + void processInstrAfterVisit(Instruction *I); + void insertAssignTypeIntrs(Instruction *I); + void processGlobalValue(GlobalVariable &GV); + +public: + static char ID; + SPIRVEmitIntrinsics() : FunctionPass(ID) { + initializeSPIRVEmitIntrinsicsPass(*PassRegistry::getPassRegistry()); + } + SPIRVEmitIntrinsics(SPIRVTargetMachine *_TM) : FunctionPass(ID), TM(_TM) { + initializeSPIRVEmitIntrinsicsPass(*PassRegistry::getPassRegistry()); + } + Instruction *visitInstruction(Instruction &I) { return &I; } + Instruction *visitSwitchInst(SwitchInst &I); + Instruction *visitGetElementPtrInst(GetElementPtrInst &I); + Instruction *visitBitCastInst(BitCastInst &I); + Instruction *visitInsertElementInst(InsertElementInst &I); + Instruction *visitExtractElementInst(ExtractElementInst &I); + Instruction *visitInsertValueInst(InsertValueInst &I); + Instruction 
*visitExtractValueInst(ExtractValueInst &I);
+  Instruction *visitLoadInst(LoadInst &I);
+  Instruction *visitStoreInst(StoreInst &I);
+  Instruction *visitAllocaInst(AllocaInst &I);
+  bool runOnFunction(Function &F) override;
+};
+} // namespace
+
+char SPIRVEmitIntrinsics::ID = 0;
+
+INITIALIZE_PASS(SPIRVEmitIntrinsics, "emit-intrinsics", "SPIRV emit intrinsics",
+                false, false)
+
+// Return true if I is a call to the spv_assign_type intrinsic.
+static inline bool isAssignTypeInstr(const Instruction *I) {
+  return isa<IntrinsicInst>(I) &&
+         cast<IntrinsicInst>(I)->getIntrinsicID() == Intrinsic::spv_assign_type;
+}
+
+// Return true for the instruction kinds whose operand replaceMemInstrUses
+// rewrites in place: store, load, insertvalue and extractvalue.
+static bool isMemInstrToReplace(Instruction *I) {
+  return isa<StoreInst>(I) || isa<LoadInst>(I) || isa<InsertValueInst>(I) ||
+         isa<ExtractValueInst>(I);
+}
+
+// Return true if V is an aggregate constant of the kinds that
+// preprocessCompositeConstants turns into spv_const_composite calls:
+// ConstantAggregate, ConstantDataArray, or a non-vector zeroinitializer.
+static bool isAggrToReplace(const Value *V) {
+  return isa<ConstantAggregate>(V) || isa<ConstantDataArray>(V) ||
+         (isa<ConstantAggregateZero>(V) && !V->getType()->isVectorTy());
+}
+
+// Set B's insertion point before I. If I is a PHI node, use the first valid
+// insertion point of its block instead, because nothing may be inserted
+// among the PHIs.
+static void setInsertPointSkippingPhis(IRBuilder<> &B, Instruction *I) {
+  if (isa<PHINode>(I))
+    B.SetInsertPoint(I->getParent(), I->getParent()->getFirstInsertionPt());
+  else
+    B.SetInsertPoint(I);
+}
+
+// Return false for instructions that must not get an spv_assign_type (the
+// llvm.invariant.start/end intrinsics), and true for everything else.
+static bool requireAssignType(Instruction *I) {
+  IntrinsicInst *Intr = dyn_cast<IntrinsicInst>(I);
+  if (Intr) {
+    switch (Intr->getIntrinsicID()) {
+    case Intrinsic::invariant_start:
+    case Intrinsic::invariant_end:
+      return false;
+    }
+  }
+  return true;
+}
+
+// Redirect every user of Old to New, then erase Old.
+// - Memory/aggregate users and returns are patched in place.
+// - A user that is an spv_assign_type call is re-created on New (keeping its
+//   operand 1) and the old call is erased.
+// - Any other kind of user is an internal error.
+void SPIRVEmitIntrinsics::replaceMemInstrUses(Instruction *Old,
+                                              Instruction *New) {
+  while (!Old->user_empty()) {
+    auto *U = Old->user_back();
+    if (isMemInstrToReplace(U) || isa<ReturnInst>(U)) {
+      U->replaceUsesOfWith(Old, New);
+    } else if (isAssignTypeInstr(U)) {
+      IRB->SetInsertPoint(U);
+      SmallVector<Value *, 2> Args = {New, U->getOperand(1)};
+      IRB->CreateIntrinsic(Intrinsic::spv_assign_type, {New->getType()}, Args);
+      U->eraseFromParent();
+    } else {
+      llvm_unreachable("illegal aggregate intrinsic user");
+    }
+  }
+  Old->eraseFromParent();
+}
+
+// Replace aggregate constant operands of every instruction in F with
+// spv_const_composite calls. The handled kinds are ConstantAggregate,
+// ConstantDataArray and non-vector ConstantAggregateZero. The original
+// constant is remembered in AggrConsts.
+// An instruction stays at the front of the worklist until a pass over its
+// operands makes no replacement. Newly created calls are queued too, so
+// nested aggregates are expanded as well.
+void SPIRVEmitIntrinsics::preprocessCompositeConstants() {
+  std::queue<Instruction *> Worklist;
+  for (auto &I : instructions(F))
+    Worklist.push(&I);
+
+  while (!Worklist.empty()) {
+    auto *I = Worklist.front();
+    assert(I);
+    bool KeepInst = false;
+    for (const auto &Op : I->operands()) {
+      auto BuildCompositeIntrinsic = [&KeepInst, &Worklist, &I, &Op,
+                                      this](Constant *AggrC,
+                                            ArrayRef<Value *> Args) {
+        IRB->SetInsertPoint(I);
+        auto *CCI =
+            IRB->CreateIntrinsic(Intrinsic::spv_const_composite, {}, {Args});
+        Worklist.push(CCI);
+        I->replaceUsesOfWith(Op, CCI);
+        KeepInst = true;
+        AggrConsts[CCI] = AggrC;
+      };
+
+      if (auto *AggrC = dyn_cast<ConstantAggregate>(Op)) {
+        SmallVector<Value *> Args(AggrC->op_begin(), AggrC->op_end());
+        BuildCompositeIntrinsic(AggrC, Args);
+      } else if (auto *AggrC = dyn_cast<ConstantDataArray>(Op)) {
+        SmallVector<Value *> Args;
+        for (unsigned i = 0; i < AggrC->getNumElements(); ++i)
+          Args.push_back(AggrC->getElementAsConstant(i));
+        BuildCompositeIntrinsic(AggrC, Args);
+      } else if (isa<ConstantAggregateZero>(Op) &&
+                 !Op->getType()->isVectorTy()) {
+        auto *AggrC = cast<ConstantAggregateZero>(Op);
+        SmallVector<Value *> Args(AggrC->op_begin(), AggrC->op_end());
+        BuildCompositeIntrinsic(AggrC, Args);
+      }
+    }
+    if (!KeepInst)
+      Worklist.pop();
+  }
+}
+
+// Emit an spv_switch call carrying the switch's sized-type operands. The
+// original SwitchInst is kept.
+Instruction *SPIRVEmitIntrinsics::visitSwitchInst(SwitchInst &I) {
+  SmallVector<Value *, 4> Args;
+  for (auto &Op : I.operands())
+    if (Op.get()->getType()->isSized())
+      Args.push_back(Op);
+  IRB->CreateIntrinsic(Intrinsic::spv_switch, {I.getOperand(0)->getType()},
+                       {Args});
+  return &I;
+}
+
+// Replace the GEP with an spv_gep call. The call's arguments are the
+// inbounds flag (i1) followed by the GEP's operands.
+Instruction *SPIRVEmitIntrinsics::visitGetElementPtrInst(GetElementPtrInst &I) {
+  SmallVector<Type *, 2> Types = {I.getType(), I.getOperand(0)->getType()};
+  SmallVector<Value *, 4> Args;
+  Args.push_back(IRB->getInt1(I.isInBounds()));
+  for (auto &Op : I.operands())
+    Args.push_back(Op);
+  auto *NewI = IRB->CreateIntrinsic(Intrinsic::spv_gep, {Types}, {Args});
+  I.replaceAllUsesWith(NewI);
+  I.eraseFromParent();
+  return NewI;
+}
+
+// Replace the bitcast with an spv_bitcast call, preserving the value name.
+Instruction *SPIRVEmitIntrinsics::visitBitCastInst(BitCastInst &I) {
+  SmallVector<Type *, 2> Types = {I.getType(), I.getOperand(0)->getType()};
+  SmallVector<Value *> Args(I.op_begin(), I.op_end());
+  auto *NewI = IRB->CreateIntrinsic(Intrinsic::spv_bitcast, {Types}, {Args});
+  std::string InstName = I.hasName() ? I.getName().str() : "";
+  I.replaceAllUsesWith(NewI);
+  I.eraseFromParent();
+  NewI->setName(InstName);
+  return NewI;
+}
+
+// Replace the insertelement with an spv_insertelt call, preserving the name.
+Instruction *SPIRVEmitIntrinsics::visitInsertElementInst(InsertElementInst &I) {
+  SmallVector<Type *, 4> Types = {I.getType(), I.getOperand(0)->getType(),
+                                  I.getOperand(1)->getType(),
+                                  I.getOperand(2)->getType()};
+  SmallVector<Value *> Args(I.op_begin(), I.op_end());
+  auto *NewI = IRB->CreateIntrinsic(Intrinsic::spv_insertelt, {Types}, {Args});
+  std::string InstName = I.hasName() ? I.getName().str() : "";
+  I.replaceAllUsesWith(NewI);
+  I.eraseFromParent();
+  NewI->setName(InstName);
+  return NewI;
+}
+
+// Replace the extractelement with an spv_extractelt call, preserving the name.
+Instruction *
+SPIRVEmitIntrinsics::visitExtractElementInst(ExtractElementInst &I) {
+  SmallVector<Type *, 3> Types = {I.getType(), I.getVectorOperandType(),
+                                  I.getIndexOperand()->getType()};
+  SmallVector<Value *, 2> Args = {I.getVectorOperand(), I.getIndexOperand()};
+  auto *NewI = IRB->CreateIntrinsic(Intrinsic::spv_extractelt, {Types}, {Args});
+  std::string InstName = I.hasName() ? I.getName().str() : "";
+  I.replaceAllUsesWith(NewI);
+  I.eraseFromParent();
+  NewI->setName(InstName);
+  return NewI;
+}
+
+// Replace the insertvalue with an spv_insertv call. Undef operands become
+// i32 undef, and the constant indices are appended as i32 values.
+Instruction *SPIRVEmitIntrinsics::visitInsertValueInst(InsertValueInst &I) {
+  SmallVector<Type *, 1> Types = {I.getInsertedValueOperand()->getType()};
+  SmallVector<Value *> Args;
+  for (auto &Op : I.operands())
+    if (isa<UndefValue>(Op))
+      Args.push_back(UndefValue::get(IRB->getInt32Ty()));
+    else
+      Args.push_back(Op);
+  for (auto &Op : I.indices())
+    Args.push_back(IRB->getInt32(Op));
+  Instruction *NewI =
+      IRB->CreateIntrinsic(Intrinsic::spv_insertv, {Types}, {Args});
+  replaceMemInstrUses(&I, NewI);
+  return NewI;
+}
+
+// Replace the extractvalue with an spv_extractv call. Its constant indices
+// are appended as i32 values.
+Instruction *SPIRVEmitIntrinsics::visitExtractValueInst(ExtractValueInst &I) {
+  SmallVector<Value *> Args;
+  for (auto &Op : I.operands())
+    Args.push_back(Op);
+  for (auto &Op : I.indices())
+    Args.push_back(IRB->getInt32(Op));
+  auto *NewI =
+      IRB->CreateIntrinsic(Intrinsic::spv_extractv, {I.getType()}, {Args});
+  I.replaceAllUsesWith(NewI);
+  I.eraseFromParent();
+  return NewI;
+}
+
+// Aggregate loads only: replace with spv_load(ptr, memoperand flags as i16,
+// alignment as i8). This also disables constant tracking for the load.
+Instruction *SPIRVEmitIntrinsics::visitLoadInst(LoadInst &I) {
+  if (!I.getType()->isAggregateType())
+    return &I;
+  TrackConstants = false;
+  const auto *TLI = TM->getSubtargetImpl()->getTargetLowering();
+  MachineMemOperand::Flags Flags =
+      TLI->getLoadMemOperandFlags(I, F->getParent()->getDataLayout());
+  auto *NewI =
+      IRB->CreateIntrinsic(Intrinsic::spv_load, {I.getOperand(0)->getType()},
+                           {I.getPointerOperand(), IRB->getInt16(Flags),
+                            IRB->getInt8(I.getAlign().value())});
+  replaceMemInstrUses(&I, NewI);
+  return NewI;
+}
+
+// Only stores recorded in AggrStores (aggregate stored values) are replaced
+// with spv_store(value, ptr, memoperand flags as i16, alignment as i8).
+Instruction *SPIRVEmitIntrinsics::visitStoreInst(StoreInst &I) {
+  if (!AggrStores.contains(&I))
+    return &I;
+  TrackConstants = false;
+  const auto *TLI = TM->getSubtargetImpl()->getTargetLowering();
+  MachineMemOperand::Flags Flags =
+      TLI->getStoreMemOperandFlags(I, F->getParent()->getDataLayout());
+  auto *PtrOp = I.getPointerOperand();
+  auto *NewI =
+      IRB->CreateIntrinsic(Intrinsic::spv_store, {PtrOp->getType()},
+                           {I.getValueOperand(), PtrOp, IRB->getInt16(Flags),
+                            IRB->getInt8(I.getAlign().value())});
+  I.eraseFromParent();
+  return NewI;
+}
+
+// Allocas are kept as-is; only constant tracking is disabled for them.
+Instruction *SPIRVEmitIntrinsics::visitAllocaInst(AllocaInst &I) {
+  TrackConstants = false;
+  return &I;
+}
+
+// Emit spv_init_global for a global with a non-undef initializer. Aggregate
+// initializers go through an i32 placeholder when the call is created and
+// are then patched in with setArgOperand. Globals without a (non-undef)
+// initializer and without uses get spv_unref_global instead.
+void SPIRVEmitIntrinsics::processGlobalValue(GlobalVariable &GV) {
+  // Skip the special artificial variable llvm.global.annotations.
+  if (GV.getName() == "llvm.global.annotations")
+    return;
+  if (GV.hasInitializer() && !isa<UndefValue>(GV.getInitializer())) {
+    Constant *Init = GV.getInitializer();
+    Type *Ty = isAggrToReplace(Init) ? IRB->getInt32Ty() : Init->getType();
+    Constant *Const = isAggrToReplace(Init) ? IRB->getInt32(1) : Init;
+    auto *InitInst = IRB->CreateIntrinsic(Intrinsic::spv_init_global,
+                                          {GV.getType(), Ty}, {&GV, Const});
+    InitInst->setArgOperand(1, Init);
+  }
+  if ((!GV.hasInitializer() || isa<UndefValue>(GV.getInitializer())) &&
+      GV.getNumUses() == 0)
+    IRB->CreateIntrinsic(Intrinsic::spv_unref_global, GV.getType(), &GV);
+}
+
+// Attach spv_assign_type to every non-void result that requires one. For
+// spv_const_composite calls, the type of the original aggregate constant is
+// used. Null-pointer, undef and constant-GEP operands get one as well.
+// Stores of aggregate values are recorded in AggrStores.
+void SPIRVEmitIntrinsics::insertAssignTypeIntrs(Instruction *I) {
+  Type *Ty = I->getType();
+  if (!Ty->isVoidTy() && requireAssignType(I)) {
+    setInsertPointSkippingPhis(*IRB, I->getNextNode());
+    Type *TypeToAssign = Ty;
+    if (auto *II = dyn_cast<IntrinsicInst>(I)) {
+      if (II->getIntrinsicID() == Intrinsic::spv_const_composite) {
+        auto t = AggrConsts.find(II);
+        assert(t != AggrConsts.end());
+        TypeToAssign = t->second->getType();
+      }
+    }
+    Constant *Const = Constant::getNullValue(TypeToAssign);
+    buildIntrWithMD(Intrinsic::spv_assign_type, {Ty}, Const, I);
+  }
+  for (const auto &Op : I->operands()) {
+    if (isa<ConstantPointerNull>(Op) || isa<UndefValue>(Op) ||
+        // Check GetElementPtrConstantExpr case.
+        (isa<ConstantExpr>(Op) && isa<GEPOperator>(Op))) {
+      IRB->SetInsertPoint(I);
+      buildIntrWithMD(Intrinsic::spv_assign_type, {Op->getType()}, Op, Op);
+    }
+  }
+  // StoreInst's operand type can be changed in the next stage so we need to
+  // store it in the set.
+  if (isa<StoreInst>(I) &&
+      cast<StoreInst>(I)->getValueOperand()->getType()->isAggregateType())
+    AggrStores.insert(I);
+}
+
+// Post-visit fix-ups for I, in order:
+// - While TrackConstants is set, wrap spv_const_composite results in
+//   spv_track_constant.
+// - Wrap scalar ConstantData operands the same way. Skip immarg parameters
+//   and the leading inbounds flag of spv_gep.
+// - Emit spv_assign_name for named results.
+void SPIRVEmitIntrinsics::processInstrAfterVisit(Instruction *I) {
+  auto *II = dyn_cast<IntrinsicInst>(I);
+  if (II && II->getIntrinsicID() == Intrinsic::spv_const_composite &&
+      TrackConstants) {
+    IRB->SetInsertPoint(I->getNextNode());
+    Type *Ty = IRB->getInt32Ty();
+    auto t = AggrConsts.find(I);
+    assert(t != AggrConsts.end());
+    auto *NewOp =
+        buildIntrWithMD(Intrinsic::spv_track_constant, {Ty, Ty}, t->second, I);
+    I->replaceAllUsesWith(NewOp);
+    NewOp->setArgOperand(0, I);
+  }
+  for (const auto &Op : I->operands()) {
+    if ((isa<ConstantAggregateZero>(Op) && Op->getType()->isVectorTy()) ||
+        isa<PHINode>(I) || isa<SwitchInst>(I))
+      TrackConstants = false;
+    if (isa<ConstantData>(Op) && TrackConstants) {
+      unsigned OpNo = Op.getOperandNo();
+      if (II && ((II->getIntrinsicID() == Intrinsic::spv_gep && OpNo == 0) ||
+                 (II->paramHasAttr(OpNo, Attribute::ImmArg))))
+        continue;
+      IRB->SetInsertPoint(I);
+      auto *NewOp = buildIntrWithMD(Intrinsic::spv_track_constant,
+                                    {Op->getType(), Op->getType()}, Op, Op);
+      I->setOperand(OpNo, NewOp);
+    }
+  }
+  if (I->hasName()) {
+    setInsertPointSkippingPhis(*IRB, I->getNextNode());
+    std::vector<Value *> Args = {I};
+    addStringImm(I->getName(), *IRB, Args);
+    IRB->CreateIntrinsic(Intrinsic::spv_assign_name, {I->getType()}, Args);
+  }
+}
+
+// Rewrite the body of Func into SPIR-V intrinsics. Declarations are left
+// untouched (returns false); otherwise the function is reported as modified.
+bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
+  if (Func.isDeclaration())
+    return false;
+  F = &Func;
+  // The builder is only needed for this invocation. Keeping it on the stack
+  // releases it on return; a `new`-allocated builder here was never deleted
+  // and leaked once per processed function.
+  IRBuilder<> Builder(Func.getContext());
+  IRB = &Builder;
+  AggrConsts.clear();
+  AggrStores.clear();
+
+  // spv_init_global / spv_unref_global calls for the module's globals are
+  // emitted at the start of the entry block.
+  IRB->SetInsertPoint(&Func.getEntryBlock().front());
+
+  for (auto &GV : Func.getParent()->globals())
+    processGlobalValue(GV);
+
+  preprocessCompositeConstants();
+  SmallVector<Instruction *> Worklist;
+  for (auto &I : instructions(Func))
+    Worklist.push_back(&I);
+
+  for (auto &I : Worklist)
+    insertAssignTypeIntrs(I);
+
+  for (auto *I : Worklist) {
+    TrackConstants = true;
+    if (!I->getType()->isVoidTy() || isa<StoreInst>(I))
+      IRB->SetInsertPoint(I->getNextNode());
+    I = visit(*I);
+    processInstrAfterVisit(I);
+  }
+
+  // Builder is destroyed when this function returns; do not leave IRB
+  // pointing at it.
+  IRB = nullptr;
+  return true;
+}
+
+// Factory for the SPIR-V emit-intrinsics pass.
+FunctionPass *llvm::createSPIRVEmitIntrinsicsPass(SPIRVTargetMachine *TM) {
+  return new SPIRVEmitIntrinsics(TM);
+}
diff --git a/llvm/lib/Target/SPIRV/SPIRVEnums.td b/llvm/lib/Target/SPIRV/SPIRVEnums.td
new file mode 100644
index 000000000000..1d0c6ffd6e37
--- /dev/null
+++ b/llvm/lib/Target/SPIRV/SPIRVEnums.td
@@ -0,0 +1,51 @@
+//===-- SPIRVEnums.td - Describe SPIRV Enum Operands -------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// All SPIRV enums defined in SPIRVBaseInfo.h should have a corresponding enum
+// operand here. This enables the correct PrintMethod to be defined so
+// its name or mask bits can be automatically printed in SPIRVInstPrinter
+// when referred to in SPIRVInstrInfo.td.
+//
+//===----------------------------------------------------------------------===//
+
+// Common base for the SPIR-V enum operands below. Each one is an i32 operand.
+// Name selects its printer: PrintMethod is "print" # Name, the
+// SPIRVInstPrinter method that prints the operand's name or mask bits.
+class EnumOperand<string Name> : Operand<i32>{
+  let PrintMethod = "print"#Name;
+}
+
+def ExtInst : EnumOperand<"ExtInst">;
+
+def Capability : EnumOperand<"Capability">;
+def SourceLanguage : EnumOperand<"SourceLanguage">;
+def ExecutionModel : EnumOperand<"ExecutionModel">;
+def AddressingModel : EnumOperand<"AddressingModel">;
+def MemoryModel : EnumOperand<"MemoryModel">;
+def ExecutionMode : EnumOperand<"ExecutionMode">;
+def StorageClass : EnumOperand<"StorageClass">;
+def Dim : EnumOperand<"Dim">;
+def SamplerAddressingMode : EnumOperand<"SamplerAddressingMode">;
+def SamplerFilterMode : EnumOperand<"SamplerFilterMode">;
+def ImageFormat : EnumOperand<"ImageFormat">;
+def ImageChannelOrder : EnumOperand<"ImageChannelOrder">;
+def ImageChannelDataType : EnumOperand<"ImageChannelDataType">;
+def ImageOperand : EnumOperand<"ImageOperand">;
+def FPFastMathMode : EnumOperand<"FPFastMathMode">;
+// NOTE(review): this record is spelled "FProundingMode" (lower-case 'r'),
+// while its printer is printFPRoundingMode. The record name is left as-is
+// because other .td files may refer to it.
+def FProundingMode : EnumOperand<"FPRoundingMode">;
+def LinkageType : EnumOperand<"LinkageType">;
+def AccessQualifier : EnumOperand<"AccessQualifier">;
+def FunctionParameterAttribute : EnumOperand<"FunctionParameterAttribute">;
+def Decoration : EnumOperand<"Decoration">;
+def Builtin : EnumOperand<"Builtin">;
+def SelectionControl: EnumOperand<"SelectionControl">;
+def LoopControl: EnumOperand<"LoopControl">;
+def FunctionControl : EnumOperand<"FunctionControl">;
+def MemorySemantics : EnumOperand<"MemorySemantics">;
+def MemoryOperand : EnumOperand<"MemoryOperand">;
+def Scope : EnumOperand<"Scope">;
+def GroupOperation : EnumOperand<"GroupOperation">;
+def KernelEnqueueFlags : EnumOperand<"KernelEnqueueFlags">;
+def KernelProfilingInfo : EnumOperand<"KernelProfilingInfo">;
diff --git a/llvm/lib/Target/SPIRV/SPIRVFrameLowering.h b/llvm/lib/Target/SPIRV/SPIRVFrameLowering.h
new file mode 100644
index 000000000000..b98f8d0928e5
--- /dev/null
+++ 
b/llvm/lib/Target/SPIRV/SPIRVFrameLowering.h
@@ -0,0 +1,43 @@
+//===-- SPIRVFrameLowering.h - Define frame lowering for SPIR-V -*- C++-*--===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This class implements the SPIRV-specific bits of the TargetFrameLowering
+// class. The target uses only virtual registers. It does not operate on the
+// stack frame explicitly and does not generate function prologues/epilogues.
+// As a result, the frame lowering functionality does not need a substantial
+// implementation, and the hooks below are intentionally trivial.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIB_TARGET_SPIRV_SPIRVFRAMELOWERING_H
+#define LLVM_LIB_TARGET_SPIRV_SPIRVFRAMELOWERING_H
+
+#include "llvm/CodeGen/TargetFrameLowering.h"
+#include "llvm/Support/Alignment.h"
+
+namespace llvm {
+class SPIRVSubtarget;
+
+class SPIRVFrameLowering : public TargetFrameLowering {
+public:
+  // The subtarget is not consulted. The frame is described as growing down,
+  // with 8-byte stack alignment and a local area offset of 0.
+  explicit SPIRVFrameLowering(const SPIRVSubtarget &STI)
+      : TargetFrameLowering(TargetFrameLowering::StackGrowsDown, Align(8), 0) {}
+
+  // No prologue or epilogue code is emitted.
+  void emitPrologue(MachineFunction &MF,
+                    MachineBasicBlock &MBB) const override {}
+  void emitEpilogue(MachineFunction &MF,
+                    MachineBasicBlock &MBB) const override {}
+
+  // hasFP always reports false: no frame pointer is ever reserved.
+  bool hasFP(const MachineFunction &MF) const override { return false; }
+};
+} // namespace llvm
+#endif // LLVM_LIB_TARGET_SPIRV_SPIRVFRAMELOWERING_H
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
new file mode 100644
index 000000000000..02a6905a1abc
--- /dev/null
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -0,0 +1,459 @@
+//===-- SPIRVGlobalRegistry.cpp - SPIR-V Global Registry --------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the
Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains the implementation of the SPIRVGlobalRegistry class,
+// which is used to maintain rich type information required for SPIR-V even
+// after lowering from LLVM IR to GMIR. It can convert an llvm::Type into
+// an OpTypeXXX instruction, and map it to a virtual register. Also it builds
+// and supports consistency of constants and global variables.
+//
+//===----------------------------------------------------------------------===//
+
+#include "SPIRVGlobalRegistry.h"
+#include "SPIRV.h"
+#include "SPIRVSubtarget.h"
+#include "SPIRVTargetMachine.h"
+#include "SPIRVUtils.h"
+
+using namespace llvm;
+SPIRVGlobalRegistry::SPIRVGlobalRegistry(unsigned PointerSize)
+    : PointerSize(PointerSize) {}
+
+// Get or create the SPIR-V type for Type and record it as VReg's type.
+SPIRVType *SPIRVGlobalRegistry::assignTypeToVReg(
+    const Type *Type, Register VReg, MachineIRBuilder &MIRBuilder,
+    SPIRV::AccessQualifier AccessQual, bool EmitIR) {
+  SPIRVType *SpirvType =
+      getOrCreateSPIRVType(Type, MIRBuilder, AccessQual, EmitIR);
+  assignSPIRVTypeToVReg(SpirvType, VReg, MIRBuilder.getMF());
+  return SpirvType;
+}
+
+// Record SpirvType as the type of VReg within MF.
+void SPIRVGlobalRegistry::assignSPIRVTypeToVReg(SPIRVType *SpirvType,
+                                                Register VReg,
+                                                MachineFunction &MF) {
+  VRegToTypeMap[&MF][VReg] = SpirvType;
+}
+
+// Create a fresh 32-bit scalar vreg in the TYPE register class to hold the
+// result of an OpTypeXXX instruction.
+static Register createTypeVReg(MachineRegisterInfo &MRI) {
+  auto Res = MRI.createGenericVirtualRegister(LLT::scalar(32));
+  MRI.setRegClass(Res, &SPIRV::TYPERegClass);
+  return Res;
+}
+
+static Register createTypeVReg(MachineIRBuilder &MIRBuilder) {
+  return createTypeVReg(MIRBuilder.getMF().getRegInfo());
+}
+
+SPIRVType *SPIRVGlobalRegistry::getOpTypeBool(MachineIRBuilder &MIRBuilder) {
+  return MIRBuilder.buildInstr(SPIRV::OpTypeBool)
+      .addDef(createTypeVReg(MIRBuilder));
+}
+
+SPIRVType *SPIRVGlobalRegistry::getOpTypeInt(uint32_t Width,
+                                             MachineIRBuilder &MIRBuilder,
+                                             bool IsSigned) {
+  auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeInt)
+                 .addDef(createTypeVReg(MIRBuilder))
+                 .addImm(Width)
+                 .addImm(IsSigned ? 1 : 0);
+  return MIB;
+}
+
+SPIRVType *SPIRVGlobalRegistry::getOpTypeFloat(uint32_t Width,
+                                               MachineIRBuilder &MIRBuilder) {
+  auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeFloat)
+                 .addDef(createTypeVReg(MIRBuilder))
+                 .addImm(Width);
+  return MIB;
+}
+
+SPIRVType *SPIRVGlobalRegistry::getOpTypeVoid(MachineIRBuilder &MIRBuilder) {
+  return MIRBuilder.buildInstr(SPIRV::OpTypeVoid)
+      .addDef(createTypeVReg(MIRBuilder));
+}
+
+// Vector element types are restricted to int, float or bool.
+SPIRVType *SPIRVGlobalRegistry::getOpTypeVector(uint32_t NumElems,
+                                                SPIRVType *ElemType,
+                                                MachineIRBuilder &MIRBuilder) {
+  auto EleOpc = ElemType->getOpcode();
+  assert((EleOpc == SPIRV::OpTypeInt || EleOpc == SPIRV::OpTypeFloat ||
+          EleOpc == SPIRV::OpTypeBool) &&
+         "Invalid vector element type");
+
+  auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeVector)
+                 .addDef(createTypeVReg(MIRBuilder))
+                 .addUse(getSPIRVTypeID(ElemType))
+                 .addImm(NumElems);
+  return MIB;
+}
+
+// Build an integer constant of SpvType (i32 when SpvType is null) in a new
+// vreg.
+// - With EmitIR, a generic constant is built through MIRBuilder.
+// - Otherwise an OpConstantI instruction is emitted directly.
+Register SPIRVGlobalRegistry::buildConstantInt(uint64_t Val,
+                                               MachineIRBuilder &MIRBuilder,
+                                               SPIRVType *SpvType,
+                                               bool EmitIR) {
+  auto &MF = MIRBuilder.getMF();
+  const IntegerType *LLVMIntTy;
+  if (SpvType)
+    LLVMIntTy = cast<IntegerType>(getTypeForSPIRVType(SpvType));
+  else
+    LLVMIntTy = IntegerType::getInt32Ty(MF.getFunction().getContext());
+  // ConstantInt::get returns the uniqued IR constant, but a fresh vreg is
+  // still created for it on every call (no deduplication here yet).
+  const auto ConstInt =
+      ConstantInt::get(const_cast<IntegerType *>(LLVMIntTy), Val);
+  unsigned BitWidth = SpvType ? getScalarOrVectorBitWidth(SpvType) : 32;
+  Register Res =
+      MF.getRegInfo().createGenericVirtualRegister(LLT::scalar(BitWidth));
+  assignTypeToVReg(LLVMIntTy, Res, MIRBuilder);
+  if (EmitIR)
+    MIRBuilder.buildConstant(Res, *ConstInt);
+  else
+    MIRBuilder.buildInstr(SPIRV::OpConstantI)
+        .addDef(Res)
+        .addImm(ConstInt->getSExtValue());
+  return Res;
+}
+
+// Build a floating-point constant of SpvType (float when SpvType is null)
+// in a new vreg.
+Register SPIRVGlobalRegistry::buildConstantFP(APFloat Val,
+                                              MachineIRBuilder &MIRBuilder,
+                                              SPIRVType *SpvType) {
+  auto &MF = MIRBuilder.getMF();
+  const Type *LLVMFPTy;
+  if (SpvType) {
+    LLVMFPTy = getTypeForSPIRVType(SpvType);
+    assert(LLVMFPTy->isFloatingPointTy());
+  } else {
+    LLVMFPTy = Type::getFloatTy(MF.getFunction().getContext());
+  }
+  // ConstantFP::get returns the uniqued IR constant, but a fresh vreg is
+  // still created for it on every call (no deduplication here yet).
+  const auto ConstFP = ConstantFP::get(LLVMFPTy->getContext(), Val);
+  unsigned BitWidth = SpvType ? getScalarOrVectorBitWidth(SpvType) : 32;
+  Register Res =
+      MF.getRegInfo().createGenericVirtualRegister(LLT::scalar(BitWidth));
+  assignTypeToVReg(LLVMFPTy, Res, MIRBuilder);
+  MIRBuilder.buildFConstant(Res, *ConstFP);
+  return Res;
+}
+
+// Emit an OpVariable defining ResVReg. Then attach its OpName, plus its
+// Constant, Alignment and LinkageAttributes decorations where applicable.
+// When GV is null, the global is looked up by Name, and created with
+// external linkage if it does not exist. Returns the register actually
+// defined by the OpVariable. It can differ from ResVReg when called from
+// the instruction selector.
+Register SPIRVGlobalRegistry::buildGlobalVariable(
+    Register ResVReg, SPIRVType *BaseType, StringRef Name,
+    const GlobalValue *GV, SPIRV::StorageClass Storage,
+    const MachineInstr *Init, bool IsConst, bool HasLinkageTy,
+    SPIRV::LinkageType LinkageType, MachineIRBuilder &MIRBuilder,
+    bool IsInstSelector) {
+  const GlobalVariable *GVar = nullptr;
+  if (GV) {
+    GVar = cast<const GlobalVariable>(GV);
+  } else {
+    // If GV is not passed explicitly, use the name to find or construct
+    // the global variable.
+    Module *M = MIRBuilder.getMF().getFunction().getParent();
+    GVar = M->getGlobalVariable(Name);
+    if (GVar == nullptr) {
+      const Type *Ty = getTypeForSPIRVType(BaseType); // TODO: check type.
+      GVar = new GlobalVariable(*M, const_cast<Type *>(Ty), false,
+                                GlobalValue::ExternalLinkage, nullptr,
+                                Twine(Name));
+    }
+  }
+  auto MIB = MIRBuilder.buildInstr(SPIRV::OpVariable)
+                 .addDef(ResVReg)
+                 .addUse(getSPIRVTypeID(BaseType))
+                 .addImm(static_cast<uint32_t>(Storage));
+
+  // Optional initializer: the value defined by Init's first operand.
+  if (Init)
+    MIB.addUse(Init->getOperand(0).getReg());
+
+  // ISel may introduce a new register on this step, so we need to add it to
+  // DT and correct its type avoiding fails on the next stage.
+  if (IsInstSelector) {
+    const auto &Subtarget = CurMF->getSubtarget();
+    constrainSelectedInstRegOperands(*MIB, *Subtarget.getInstrInfo(),
+                                     *Subtarget.getRegisterInfo(),
+                                     *Subtarget.getRegBankInfo());
+  }
+  Register Reg = MIB->getOperand(0).getReg();
+
+  // Set to Reg the same type as ResVReg has.
+  auto MRI = MIRBuilder.getMRI();
+  assert(MRI->getType(ResVReg).isPointer() && "Pointer type is expected");
+  if (Reg != ResVReg) {
+    LLT RegLLTy = LLT::pointer(MRI->getType(ResVReg).getAddressSpace(), 32);
+    MRI->setType(Reg, RegLLTy);
+    assignSPIRVTypeToVReg(BaseType, Reg, MIRBuilder.getMF());
+  }
+
+  // If it's a global variable with name, output OpName for it.
+  if (GVar && GVar->hasName())
+    buildOpName(Reg, GVar->getName(), MIRBuilder);
+
+  // Output decorations for the GV.
+  // TODO: maybe move to GenerateDecorations pass.
+  if (IsConst)
+    buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::Constant, {});
+
+  if (GVar) {
+    const uint64_t Alignment = GVar->getAlign().valueOrOne().value();
+    if (Alignment != 1)
+      buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::Alignment,
+                      {static_cast<uint32_t>(Alignment)});
+  }
+
+  if (HasLinkageTy)
+    buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::LinkageAttributes,
+                    {static_cast<uint32_t>(LinkageType)}, Name);
+  return Reg;
+}
+
+// The array length operand is an i32 constant built via buildConstantInt,
+// which honours EmitIR.
+SPIRVType *SPIRVGlobalRegistry::getOpTypeArray(uint32_t NumElems,
+                                               SPIRVType *ElemType,
+                                               MachineIRBuilder &MIRBuilder,
+                                               bool EmitIR) {
+  assert((ElemType->getOpcode() != SPIRV::OpTypeVoid) &&
+         "Invalid array element type");
+  Register NumElementsVReg =
+      buildConstantInt(NumElems, MIRBuilder, nullptr, EmitIR);
+  auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeArray)
+                 .addDef(createTypeVReg(MIRBuilder))
+                 .addUse(getSPIRVTypeID(ElemType))
+                 .addUse(NumElementsVReg);
+  return MIB;
+}
+
+SPIRVType *SPIRVGlobalRegistry::getOpTypePointer(SPIRV::StorageClass SC,
+                                                 SPIRVType *ElemType,
+                                                 MachineIRBuilder &MIRBuilder) {
+  auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypePointer)
+                 .addDef(createTypeVReg(MIRBuilder))
+                 .addImm(static_cast<uint32_t>(SC))
+                 .addUse(getSPIRVTypeID(ElemType));
+  return MIB;
+}
+
+SPIRVType *SPIRVGlobalRegistry::getOpTypeFunction(
+    SPIRVType *RetType, const SmallVectorImpl<SPIRVType *> &ArgTypes,
+    MachineIRBuilder &MIRBuilder) {
+  auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeFunction)
+                 .addDef(createTypeVReg(MIRBuilder))
+                 .addUse(getSPIRVTypeID(RetType));
+  for (const SPIRVType *ArgType : ArgTypes)
+    MIB.addUse(getSPIRVTypeID(ArgType));
+  return MIB;
+}
+
+// Translate an LLVM IR type into a new OpTypeXXX instruction, creating the
+// component types first. The supported types are integer (i1 becomes
+// OpTypeBool), floating point, void, fixed vectors, arrays, functions and
+// pointers. Struct types are not supported yet.
+SPIRVType *SPIRVGlobalRegistry::createSPIRVType(const Type *Ty,
+                                                MachineIRBuilder &MIRBuilder,
+                                                SPIRV::AccessQualifier AccQual,
+                                                bool EmitIR) {
+  if (auto IType = dyn_cast<IntegerType>(Ty)) {
+    const unsigned Width = IType->getBitWidth();
+    return Width == 1 ? getOpTypeBool(MIRBuilder)
+                      : getOpTypeInt(Width, MIRBuilder, false);
+  }
+  if (Ty->isFloatingPointTy())
+    return getOpTypeFloat(Ty->getPrimitiveSizeInBits(), MIRBuilder);
+  if (Ty->isVoidTy())
+    return getOpTypeVoid(MIRBuilder);
+  if (Ty->isVectorTy()) {
+    auto El = getOrCreateSPIRVType(cast<FixedVectorType>(Ty)->getElementType(),
+                                   MIRBuilder);
+    return getOpTypeVector(cast<FixedVectorType>(Ty)->getNumElements(), El,
+                           MIRBuilder);
+  }
+  if (Ty->isArrayTy()) {
+    // Propagate EmitIR so that the lengths of nested arrays are built the
+    // same way as this array's length (as GMIR or as OpConstantI).
+    auto *El = getOrCreateSPIRVType(Ty->getArrayElementType(), MIRBuilder,
+                                    AccQual, EmitIR);
+    return getOpTypeArray(Ty->getArrayNumElements(), El, MIRBuilder, EmitIR);
+  }
+  assert(!isa<StructType>(Ty) && "Unsupported StructType");
+  if (auto FType = dyn_cast<FunctionType>(Ty)) {
+    SPIRVType *RetTy = getOrCreateSPIRVType(FType->getReturnType(), MIRBuilder);
+    SmallVector<SPIRVType *, 4> ParamTypes;
+    for (const auto &t : FType->params()) {
+      ParamTypes.push_back(getOrCreateSPIRVType(t, MIRBuilder));
+    }
+    return getOpTypeFunction(RetTy, ParamTypes, MIRBuilder);
+  }
+  if (auto PType = dyn_cast<PointerType>(Ty)) {
+    SPIRVType *SpvElementType;
+    // At the moment, all opaque pointers correspond to i8 element type.
+    // TODO: change the implementation once opaque pointers are supported
+    // in the SPIR-V specification.
+    if (PType->isOpaque()) {
+      SpvElementType = getOrCreateSPIRVIntegerType(8, MIRBuilder);
+    } else {
+      Type *ElemType = PType->getNonOpaquePointerElementType();
+      // TODO: support OpenCL and SPIRV builtins like image2d_t that are passed
+      // as pointers, but should be treated as custom types like OpTypeImage.
+      assert(!isa<StructType>(ElemType) && "Unsupported StructType pointer");
+
+      // Otherwise, treat it as a regular pointer type.
+      SpvElementType = getOrCreateSPIRVType(
+          ElemType, MIRBuilder, SPIRV::AccessQualifier::ReadWrite, EmitIR);
+    }
+    auto SC = addressSpaceToStorageClass(PType->getAddressSpace());
+    return getOpTypePointer(SC, SpvElementType, MIRBuilder);
+  }
+  llvm_unreachable("Unable to convert LLVM type to SPIRVType");
+}
+
+// Look up VReg's SPIR-V type in the current function (CurMF). Returns
+// nullptr if none has been assigned.
+SPIRVType *SPIRVGlobalRegistry::getSPIRVTypeForVReg(Register VReg) const {
+  auto t = VRegToTypeMap.find(CurMF);
+  if (t != VRegToTypeMap.end()) {
+    auto tt = t->second.find(VReg);
+    if (tt != t->second.end())
+      return tt->second;
+  }
+  return nullptr;
+}
+
+// Create the type instruction for Type and register it in both maps. Note
+// that no lookup of an existing equivalent type is performed here.
+SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVType(
+    const Type *Type, MachineIRBuilder &MIRBuilder,
+    SPIRV::AccessQualifier AccessQual, bool EmitIR) {
+  SPIRVType *SpirvType = createSPIRVType(Type, MIRBuilder, AccessQual, EmitIR);
+  VRegToTypeMap[&MIRBuilder.getMF()][getSPIRVTypeID(SpirvType)] = SpirvType;
+  SPIRVToLLVMType[SpirvType] = Type;
+  return SpirvType;
+}
+
+bool SPIRVGlobalRegistry::isScalarOfType(Register VReg,
+                                         unsigned TypeOpcode) const {
+  SPIRVType *Type = getSPIRVTypeForVReg(VReg);
+  assert(Type && "isScalarOfType VReg has no type assigned");
+  return Type->getOpcode() == TypeOpcode;
+}
+
+bool SPIRVGlobalRegistry::isScalarOrVectorOfType(Register VReg,
+                                                 unsigned TypeOpcode) const {
+  SPIRVType *Type = getSPIRVTypeForVReg(VReg);
+  assert(Type && "isScalarOrVectorOfType VReg has no type assigned");
+  if (Type->getOpcode() == TypeOpcode)
+    return true;
+  if (Type->getOpcode() == SPIRV::OpTypeVector) {
+    Register ScalarTypeVReg = Type->getOperand(1).getReg();
+    SPIRVType *ScalarType = getSPIRVTypeForVReg(ScalarTypeVReg);
+    assert(ScalarType && "Vector element type has no SPIR-V type assigned");
+    return ScalarType->getOpcode() == TypeOpcode;
+  }
+  return false;
+}
+
+unsigned
+SPIRVGlobalRegistry::getScalarOrVectorBitWidth(const SPIRVType *Type) const {
+  assert(Type && "Invalid Type pointer");
+  if (Type->getOpcode() == SPIRV::OpTypeVector) {
+    auto EleTypeReg = Type->getOperand(1).getReg();
+    Type = getSPIRVTypeForVReg(EleTypeReg);
+    assert(Type && "Vector element type has no SPIR-V type assigned");
+  }
+  if (Type->getOpcode() == SPIRV::OpTypeInt ||
+      Type->getOpcode() == SPIRV::OpTypeFloat)
+    return Type->getOperand(1).getImm();
+  if (Type->getOpcode() == SPIRV::OpTypeBool)
+    return 1;
+  llvm_unreachable("Attempting to get bit width of non-integer/float type.");
+}
+
+bool SPIRVGlobalRegistry::isScalarOrVectorSigned(const SPIRVType *Type) const {
+  assert(Type && "Invalid Type pointer");
+  if (Type->getOpcode() == SPIRV::OpTypeVector) {
+    auto EleTypeReg = Type->getOperand(1).getReg();
+    Type = getSPIRVTypeForVReg(EleTypeReg);
+    assert(Type && "Vector element type has no SPIR-V type assigned");
+  }
+  if (Type->getOpcode() == SPIRV::OpTypeInt)
+    return Type->getOperand(2).getImm() != 0;
+  llvm_unreachable("Attempting to get sign of non-integer type.");
+}
+
+SPIRV::StorageClass
+SPIRVGlobalRegistry::getPointerStorageClass(Register VReg) const {
+  SPIRVType *Type = getSPIRVTypeForVReg(VReg);
+  assert(Type && Type->getOpcode() == SPIRV::OpTypePointer &&
+         Type->getOperand(1).isImm() && "Pointer type is expected");
+  return static_cast<SPIRV::StorageClass>(Type->getOperand(1).getImm());
+}
+
+SPIRVType *
+SPIRVGlobalRegistry::getOrCreateSPIRVIntegerType(unsigned BitWidth,
+                                                 MachineIRBuilder &MIRBuilder) {
+  return getOrCreateSPIRVType(
+      IntegerType::get(MIRBuilder.getMF().getFunction().getContext(), BitWidth),
+      MIRBuilder);
+}
+
+// Register a type instruction built directly with BuildMI (outside
+// MachineIRBuilder) in the current function's type maps.
+SPIRVType *SPIRVGlobalRegistry::restOfCreateSPIRVType(Type *LLVMTy,
+                                                      MachineInstrBuilder MIB) {
+  SPIRVType *SpirvType = MIB;
+  VRegToTypeMap[CurMF][getSPIRVTypeID(SpirvType)] = SpirvType;
+  SPIRVToLLVMType[SpirvType] = LLVMTy;
+  return SpirvType;
+}
+
+// Variant for the instruction selector: builds an unsigned OpTypeInt before I.
+SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVIntegerType(
+    unsigned BitWidth, MachineInstr &I, const SPIRVInstrInfo &TII) {
+  Type *LLVMTy = IntegerType::get(CurMF->getFunction().getContext(), BitWidth);
+  MachineBasicBlock &BB = *I.getParent();
+  auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypeInt))
+                 .addDef(createTypeVReg(CurMF->getRegInfo()))
+                 .addImm(BitWidth)
+                 .addImm(0);
+  return restOfCreateSPIRVType(LLVMTy, MIB);
+}
+
+SPIRVType *
+SPIRVGlobalRegistry::getOrCreateSPIRVBoolType(MachineIRBuilder &MIRBuilder) {
+  return getOrCreateSPIRVType(
+      IntegerType::get(MIRBuilder.getMF().getFunction().getContext(), 1),
+      MIRBuilder);
+}
+
+SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVVectorType(
+    SPIRVType *BaseType, unsigned NumElements, MachineIRBuilder &MIRBuilder) {
+  return getOrCreateSPIRVType(
+      FixedVectorType::get(const_cast<Type *>(getTypeForSPIRVType(BaseType)),
+                           NumElements),
+      MIRBuilder);
+}
+
+// Variant for the instruction selector: builds OpTypeVector before I.
+SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVVectorType(
+    SPIRVType *BaseType, unsigned NumElements, MachineInstr &I,
+    const SPIRVInstrInfo &TII) {
+  Type *LLVMTy = FixedVectorType::get(
+      const_cast<Type *>(getTypeForSPIRVType(BaseType)), NumElements);
+  MachineBasicBlock &BB = *I.getParent();
+  auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypeVector))
+                 .addDef(createTypeVReg(CurMF->getRegInfo()))
+                 .addUse(getSPIRVTypeID(BaseType))
+                 .addImm(NumElements);
+  return restOfCreateSPIRVType(LLVMTy, MIB);
+}
+
+SPIRVType *
+SPIRVGlobalRegistry::getOrCreateSPIRVPointerType(SPIRVType *BaseType,
+                                                 MachineIRBuilder &MIRBuilder,
+                                                 SPIRV::StorageClass SClass) {
+  return getOrCreateSPIRVType(
+      PointerType::get(const_cast<Type *>(getTypeForSPIRVType(BaseType)),
+                       storageClassToAddressSpace(SClass)),
+      MIRBuilder);
+}
+
+// Variant for the instruction selector: builds OpTypePointer before I.
+SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerType(
+    SPIRVType *BaseType, MachineInstr &I, const SPIRVInstrInfo &TII,
+    SPIRV::StorageClass SC) {
+  Type *LLVMTy =
+      PointerType::get(const_cast<Type *>(getTypeForSPIRVType(BaseType)),
+                       storageClassToAddressSpace(SC));
+  MachineBasicBlock &BB = *I.getParent();
+  auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypePointer))
+                 .addDef(createTypeVReg(CurMF->getRegInfo()))
+                 .addImm(static_cast<uint32_t>(SC))
+                 .addUse(getSPIRVTypeID(BaseType));
+  return restOfCreateSPIRVType(LLVMTy, MIB);
+}
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
new file
mode 100644 index 000000000000..952ab4c13e29 --- /dev/null +++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h @@ -0,0 +1,174 @@ +//===-- SPIRVGlobalRegistry.h - SPIR-V Global Registry ----------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// SPIRVGlobalRegistry is used to maintain rich type information required for +// SPIR-V even after lowering from LLVM IR to GMIR. It can convert an llvm::Type +// into an OpTypeXXX instruction, and map it to a virtual register. Also it +// builds and supports consistency of constants and global variables. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_LIB_TARGET_SPIRV_SPIRVTYPEMANAGER_H +#define LLVM_LIB_TARGET_SPIRV_SPIRVTYPEMANAGER_H + +#include "MCTargetDesc/SPIRVBaseInfo.h" +#include "SPIRVInstrInfo.h" +#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h" + +namespace llvm { +using SPIRVType = const MachineInstr; + +class SPIRVGlobalRegistry { + // Registers holding values which have types associated with them. + // Initialized upon VReg definition in IRTranslator. + // Do not confuse this with DuplicatesTracker as DT maps Type* to <MF, Reg> + // where Reg = OpType... + // while VRegToTypeMap tracks SPIR-V type assigned to other regs (i.e. not + // type-declaring ones) + DenseMap<MachineFunction *, DenseMap<Register, SPIRVType *>> VRegToTypeMap; + + DenseMap<SPIRVType *, const Type *> SPIRVToLLVMType; + + // Number of bits pointers and size_t integers require. + const unsigned PointerSize; + + // Add a new OpTypeXXX instruction without checking for duplicates. 
+ SPIRVType * + createSPIRVType(const Type *Type, MachineIRBuilder &MIRBuilder, + SPIRV::AccessQualifier AQ = SPIRV::AccessQualifier::ReadWrite, + bool EmitIR = true); + +public: + SPIRVGlobalRegistry(unsigned PointerSize); + + MachineFunction *CurMF; + + // Get or create a SPIR-V type corresponding the given LLVM IR type, + // and map it to the given VReg by creating an ASSIGN_TYPE instruction. + SPIRVType *assignTypeToVReg( + const Type *Type, Register VReg, MachineIRBuilder &MIRBuilder, + SPIRV::AccessQualifier AQ = SPIRV::AccessQualifier::ReadWrite, + bool EmitIR = true); + + // In cases where the SPIR-V type is already known, this function can be + // used to map it to the given VReg via an ASSIGN_TYPE instruction. + void assignSPIRVTypeToVReg(SPIRVType *Type, Register VReg, + MachineFunction &MF); + + // Either generate a new OpTypeXXX instruction or return an existing one + // corresponding to the given LLVM IR type. + // EmitIR controls if we emit GMIR or SPV constants (e.g. for array sizes) + // because this method may be called from InstructionSelector and we don't + // want to emit extra IR instructions there. + SPIRVType *getOrCreateSPIRVType( + const Type *Type, MachineIRBuilder &MIRBuilder, + SPIRV::AccessQualifier AQ = SPIRV::AccessQualifier::ReadWrite, + bool EmitIR = true); + + const Type *getTypeForSPIRVType(const SPIRVType *Ty) const { + auto Res = SPIRVToLLVMType.find(Ty); + assert(Res != SPIRVToLLVMType.end()); + return Res->second; + } + + // Return the SPIR-V type instruction corresponding to the given VReg, or + // nullptr if no such type instruction exists. + SPIRVType *getSPIRVTypeForVReg(Register VReg) const; + + // Whether the given VReg has a SPIR-V type mapped to it yet. + bool hasSPIRVTypeForVReg(Register VReg) const { + return getSPIRVTypeForVReg(VReg) != nullptr; + } + + // Return the VReg holding the result of the given OpTypeXXX instruction. 
+ Register getSPIRVTypeID(const SPIRVType *SpirvType) const { + assert(SpirvType && "Attempting to get type id for nullptr type."); + return SpirvType->defs().begin()->getReg(); + } + + void setCurrentFunc(MachineFunction &MF) { CurMF = &MF; } + + // Whether the given VReg has an OpTypeXXX instruction mapped to it with the + // given opcode (e.g. OpTypeFloat). + bool isScalarOfType(Register VReg, unsigned TypeOpcode) const; + + // Return true if the given VReg's assigned SPIR-V type is either a scalar + // matching the given opcode, or a vector with an element type matching that + // opcode (e.g. OpTypeBool, or OpTypeVector %x 4, where %x is OpTypeBool). + bool isScalarOrVectorOfType(Register VReg, unsigned TypeOpcode) const; + + // For vectors or scalars of ints/floats, return the scalar type's bitwidth. + unsigned getScalarOrVectorBitWidth(const SPIRVType *Type) const; + + // For integer vectors or scalars, return whether the integers are signed. + bool isScalarOrVectorSigned(const SPIRVType *Type) const; + + // Gets the storage class of the pointer type assigned to this vreg. + SPIRV::StorageClass getPointerStorageClass(Register VReg) const; + + // Return the number of bits SPIR-V pointers and size_t variables require. 
+ unsigned getPointerSize() const { return PointerSize; } + +private: + SPIRVType *getOpTypeBool(MachineIRBuilder &MIRBuilder); + + SPIRVType *getOpTypeInt(uint32_t Width, MachineIRBuilder &MIRBuilder, + bool IsSigned = false); + + SPIRVType *getOpTypeFloat(uint32_t Width, MachineIRBuilder &MIRBuilder); + + SPIRVType *getOpTypeVoid(MachineIRBuilder &MIRBuilder); + + SPIRVType *getOpTypeVector(uint32_t NumElems, SPIRVType *ElemType, + MachineIRBuilder &MIRBuilder); + + SPIRVType *getOpTypeArray(uint32_t NumElems, SPIRVType *ElemType, + MachineIRBuilder &MIRBuilder, bool EmitIR = true); + + SPIRVType *getOpTypePointer(SPIRV::StorageClass SC, SPIRVType *ElemType, + MachineIRBuilder &MIRBuilder); + + SPIRVType *getOpTypeFunction(SPIRVType *RetType, + const SmallVectorImpl<SPIRVType *> &ArgTypes, + MachineIRBuilder &MIRBuilder); + SPIRVType *restOfCreateSPIRVType(Type *LLVMTy, MachineInstrBuilder MIB); + +public: + Register buildConstantInt(uint64_t Val, MachineIRBuilder &MIRBuilder, + SPIRVType *SpvType = nullptr, bool EmitIR = true); + Register buildConstantFP(APFloat Val, MachineIRBuilder &MIRBuilder, + SPIRVType *SpvType = nullptr); + Register + buildGlobalVariable(Register Reg, SPIRVType *BaseType, StringRef Name, + const GlobalValue *GV, SPIRV::StorageClass Storage, + const MachineInstr *Init, bool IsConst, bool HasLinkageTy, + SPIRV::LinkageType LinkageType, + MachineIRBuilder &MIRBuilder, bool IsInstSelector); + + // Convenient helpers for getting types with check for duplicates. 
+ SPIRVType *getOrCreateSPIRVIntegerType(unsigned BitWidth, + MachineIRBuilder &MIRBuilder); + SPIRVType *getOrCreateSPIRVIntegerType(unsigned BitWidth, MachineInstr &I, + const SPIRVInstrInfo &TII); + SPIRVType *getOrCreateSPIRVBoolType(MachineIRBuilder &MIRBuilder); + SPIRVType *getOrCreateSPIRVVectorType(SPIRVType *BaseType, + unsigned NumElements, + MachineIRBuilder &MIRBuilder); + SPIRVType *getOrCreateSPIRVVectorType(SPIRVType *BaseType, + unsigned NumElements, MachineInstr &I, + const SPIRVInstrInfo &TII); + + SPIRVType *getOrCreateSPIRVPointerType( + SPIRVType *BaseType, MachineIRBuilder &MIRBuilder, + SPIRV::StorageClass SClass = SPIRV::StorageClass::Function); + SPIRVType *getOrCreateSPIRVPointerType( + SPIRVType *BaseType, MachineInstr &I, const SPIRVInstrInfo &TII, + SPIRV::StorageClass SClass = SPIRV::StorageClass::Function); +}; +} // end namespace llvm +#endif // LLLVM_LIB_TARGET_SPIRV_SPIRVTYPEMANAGER_H diff --git a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp new file mode 100644 index 000000000000..66ff51c912b0 --- /dev/null +++ b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp @@ -0,0 +1,45 @@ +//===- SPIRVISelLowering.cpp - SPIR-V DAG Lowering Impl ---------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the SPIRVTargetLowering class. 
+// +//===----------------------------------------------------------------------===// + +#include "SPIRVISelLowering.h" +#include "SPIRV.h" + +#define DEBUG_TYPE "spirv-lower" + +using namespace llvm; + +unsigned SPIRVTargetLowering::getNumRegistersForCallingConv( + LLVMContext &Context, CallingConv::ID CC, EVT VT) const { + // This code avoids CallLowering fail inside getVectorTypeBreakdown + // on v3i1 arguments. Maybe we need to return 1 for all types. + // TODO: remove it once this case is supported by the default implementation. + if (VT.isVector() && VT.getVectorNumElements() == 3 && + (VT.getVectorElementType() == MVT::i1 || + VT.getVectorElementType() == MVT::i8)) + return 1; + return getNumRegisters(Context, VT); +} + +MVT SPIRVTargetLowering::getRegisterTypeForCallingConv(LLVMContext &Context, + CallingConv::ID CC, + EVT VT) const { + // This code avoids CallLowering fail inside getVectorTypeBreakdown + // on v3i1 arguments. Maybe we need to return i32 for all types. + // TODO: remove it once this case is supported by the default implementation. + if (VT.isVector() && VT.getVectorNumElements() == 3) { + if (VT.getVectorElementType() == MVT::i1) + return MVT::v4i1; + else if (VT.getVectorElementType() == MVT::i8) + return MVT::v4i8; + } + return getRegisterType(Context, VT); +} diff --git a/llvm/lib/Target/SPIRV/SPIRVISelLowering.h b/llvm/lib/Target/SPIRV/SPIRVISelLowering.h new file mode 100644 index 000000000000..bee9220f5248 --- /dev/null +++ b/llvm/lib/Target/SPIRV/SPIRVISelLowering.h @@ -0,0 +1,47 @@ +//===-- SPIRVISelLowering.h - SPIR-V DAG Lowering Interface -----*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the interfaces that SPIR-V uses to lower LLVM code into a
+// selection DAG.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIB_TARGET_SPIRV_SPIRVISELLOWERING_H
+#define LLVM_LIB_TARGET_SPIRV_SPIRVISELLOWERING_H
+
+#include "llvm/CodeGen/TargetLowering.h"
+
+namespace llvm {
+class SPIRVSubtarget;
+
+// TargetLowering hooks for the SPIR-V backend.
+class SPIRVTargetLowering : public TargetLowering {
+public:
+  // STI is currently unused; only TM is forwarded to TargetLowering.
+  explicit SPIRVTargetLowering(const TargetMachine &TM,
+                               const SPIRVSubtarget &STI)
+      : TargetLowering(TM) {}
+
+  // Stop IRTranslator breaking up FMA instrs to preserve types information.
+  // Reporting FMA as always faster keeps the fused operation intact.
+  bool isFMAFasterThanFMulAndFAdd(const MachineFunction &MF,
+                                  EVT) const override {
+    return true;
+  }
+
+  // This is to prevent sexts of non-i64 vector indices which are generated
+  // within general IRTranslator hence type generation for it is omitted.
+  // Vector indices therefore always use i32.
+  MVT getVectorIdxTy(const DataLayout &DL) const override {
+    return MVT::getIntegerVT(32);
+  }
+  // Calling-convention register count and type. These work around a
+  // getVectorTypeBreakdown failure on three-element i1/i8 vectors; they are
+  // implemented in SPIRVISelLowering.cpp.
+  unsigned getNumRegistersForCallingConv(LLVMContext &Context,
+                                         CallingConv::ID CC,
+                                         EVT VT) const override;
+  MVT getRegisterTypeForCallingConv(LLVMContext &Context, CallingConv::ID CC,
+                                    EVT VT) const override;
+};
+} // namespace llvm
+
+#endif // LLVM_LIB_TARGET_SPIRV_SPIRVISELLOWERING_H
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrFormats.td b/llvm/lib/Target/SPIRV/SPIRVInstrFormats.td
new file mode 100644
index 000000000000..c78c8ee11590
--- /dev/null
+++ b/llvm/lib/Target/SPIRV/SPIRVInstrFormats.td
@@ -0,0 +1,31 @@
+//===-- SPIRVInstrFormats.td - SPIR-V Instruction Formats --*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+// Operand class for literal-string immediates. It is printed by the
+// instruction printer's printStringImm method (presumably the string is
+// packed into i32 words — TODO confirm against the printer).
+def StringImm: Operand<i32>{
+  let PrintMethod="printStringImm";
+}
+
+// Base class for every SPIR-V instruction. Opcode is stored in the 16-bit
+// Inst field. outs/ins give the operand lists, asmstr the textual form and
+// pattern the (optional) selection pattern. All records live in the SPIRV
+// namespace and decoder namespace.
+class Op<bits<16> Opcode, dag outs, dag ins, string asmstr, list<dag> pattern = []>
+  : Instruction {
+  field bits<16> Inst;
+
+  let Inst = Opcode;
+
+  let Namespace = "SPIRV";
+  let DecoderNamespace = "SPIRV";
+
+  dag OutOperandList = outs;
+  dag InOperandList = ins;
+  let AsmString = asmstr;
+  let Pattern = pattern;
+}
+
+// Pseudo instructions: codegen-internal helpers with opcode 0 and an empty
+// asm string, marked isPseudo.
+class Pseudo<dag outs, dag ins> : Op<0, outs, ins, ""> {
+  let isPseudo = 1;
+}
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp
new file mode 100644
index 000000000000..754906308114
--- /dev/null
+++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp
@@ -0,0 +1,195 @@
+//===-- SPIRVInstrInfo.cpp - SPIR-V Instruction Information ------*- C++-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains the SPIR-V implementation of the TargetInstrInfo class.
+// +//===----------------------------------------------------------------------===// + +#include "SPIRVInstrInfo.h" +#include "SPIRV.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h" +#include "llvm/CodeGen/MachineBasicBlock.h" +#include "llvm/IR/DebugLoc.h" +#include "llvm/Support/ErrorHandling.h" + +#define GET_INSTRINFO_CTOR_DTOR +#include "SPIRVGenInstrInfo.inc" + +using namespace llvm; + +SPIRVInstrInfo::SPIRVInstrInfo() : SPIRVGenInstrInfo() {} + +bool SPIRVInstrInfo::isConstantInstr(const MachineInstr &MI) const { + switch (MI.getOpcode()) { + case SPIRV::OpConstantTrue: + case SPIRV::OpConstantFalse: + case SPIRV::OpConstantI: + case SPIRV::OpConstantF: + case SPIRV::OpConstantComposite: + case SPIRV::OpConstantSampler: + case SPIRV::OpConstantNull: + case SPIRV::OpSpecConstantTrue: + case SPIRV::OpSpecConstantFalse: + case SPIRV::OpSpecConstant: + case SPIRV::OpSpecConstantComposite: + case SPIRV::OpSpecConstantOp: + case SPIRV::OpUndef: + return true; + default: + return false; + } +} + +bool SPIRVInstrInfo::isTypeDeclInstr(const MachineInstr &MI) const { + auto &MRI = MI.getMF()->getRegInfo(); + if (MI.getNumDefs() >= 1 && MI.getOperand(0).isReg()) { + auto DefRegClass = MRI.getRegClassOrNull(MI.getOperand(0).getReg()); + return DefRegClass && DefRegClass->getID() == SPIRV::TYPERegClass.getID(); + } else { + return false; + } +} + +bool SPIRVInstrInfo::isDecorationInstr(const MachineInstr &MI) const { + switch (MI.getOpcode()) { + case SPIRV::OpDecorate: + case SPIRV::OpDecorateId: + case SPIRV::OpDecorateString: + case SPIRV::OpMemberDecorate: + case SPIRV::OpMemberDecorateString: + return true; + default: + return false; + } +} + +bool SPIRVInstrInfo::isHeaderInstr(const MachineInstr &MI) const { + switch (MI.getOpcode()) { + case SPIRV::OpCapability: + case SPIRV::OpExtension: + case SPIRV::OpExtInstImport: + case SPIRV::OpMemoryModel: + case SPIRV::OpEntryPoint: + case SPIRV::OpExecutionMode: + case 
SPIRV::OpExecutionModeId: + case SPIRV::OpString: + case SPIRV::OpSourceExtension: + case SPIRV::OpSource: + case SPIRV::OpSourceContinued: + case SPIRV::OpName: + case SPIRV::OpMemberName: + case SPIRV::OpModuleProcessed: + return true; + default: + return isTypeDeclInstr(MI) || isConstantInstr(MI) || isDecorationInstr(MI); + } +} + +// Analyze the branching code at the end of MBB, returning +// true if it cannot be understood (e.g. it's a switch dispatch or isn't +// implemented for a target). Upon success, this returns false and returns +// with the following information in various cases: +// +// 1. If this block ends with no branches (it just falls through to its succ) +// just return false, leaving TBB/FBB null. +// 2. If this block ends with only an unconditional branch, it sets TBB to be +// the destination block. +// 3. If this block ends with a conditional branch and it falls through to a +// successor block, it sets TBB to be the branch destination block and a +// list of operands that evaluate the condition. These operands can be +// passed to other TargetInstrInfo methods to create new branches. +// 4. If this block ends with a conditional branch followed by an +// unconditional branch, it returns the 'true' destination in TBB, the +// 'false' destination in FBB, and a list of operands that evaluate the +// condition. These operands can be passed to other TargetInstrInfo +// methods to create new branches. +// +// Note that removeBranch and insertBranch must be implemented to support +// cases where this method returns success. +// +// If AllowModify is true, then this routine is allowed to modify the basic +// block (e.g. delete instructions after the unconditional branch). +// +// The CFG information in MBB.Predecessors and MBB.Successors must be valid +// before calling this function. 
+bool SPIRVInstrInfo::analyzeBranch(MachineBasicBlock &MBB,
+                                   MachineBasicBlock *&TBB,
+                                   MachineBasicBlock *&FBB,
+                                   SmallVectorImpl<MachineOperand> &Cond,
+                                   bool AllowModify) const {
+  TBB = nullptr;
+  FBB = nullptr;
+  if (MBB.empty())
+    return false;
+  // getLastNonDebugInstr() returns end() (a non-null sentinel, so
+  // isValid() would still be true) when the block only contains debug
+  // instructions. Compare against end() before dereferencing.
+  auto MI = MBB.getLastNonDebugInstr();
+  if (MI == MBB.end())
+    return false;
+  switch (MI->getOpcode()) {
+  case SPIRV::OpBranch:
+    // Unconditional branch: its single target goes in TBB.
+    TBB = MI->getOperand(0).getMBB();
+    return false;
+  case SPIRV::OpBranchConditional:
+    // Per the SPIR-V spec the operands are Condition, True Label, False Label,
+    // optionally followed by branch weights. Weights make the operand count
+    // exceed 3, so the false target must not be tied to an exact count.
+    Cond.push_back(MI->getOperand(0));
+    TBB = MI->getOperand(1).getMBB();
+    if (MI->getNumOperands() >= 3 && MI->getOperand(2).isMBB())
+      FBB = MI->getOperand(2).getMBB();
+    return false;
+  default:
+    // Any other terminator is not understood by this analysis.
+    return true;
+  }
+}
+
+// Remove the branching code at the end of the specific MBB.
+// This is only invoked in cases where analyzeBranch returns success. It
+// returns the number of instructions that were removed.
+// If \p BytesRemoved is non-null, report the change in code size from the
+// removed instructions.
+// Not supported by this target: always reports a fatal error.
+unsigned SPIRVInstrInfo::removeBranch(MachineBasicBlock &MBB,
+                                      int *BytesRemoved) const {
+  report_fatal_error("Branch removal not supported, as MBB info not propagated"
+                     " to OpPhi instructions. Try using -O0 instead.");
+}
+
+// Insert branch code into the end of the specified MachineBasicBlock. The
+// operands to this method are the same as those returned by analyzeBranch.
+// This is only invoked in cases where analyzeBranch returns success. It
+// returns the number of instructions inserted. If \p BytesAdded is non-null,
+// report the change in code size from the added instructions.
+//
+// It is also invoked by tail merging to add unconditional branches in
+// cases where analyzeBranch doesn't apply because there was no original
+// branch to analyze. At least this much must be implemented, else tail
+// merging needs to be disabled.
+//
+// The CFG information in MBB.Predecessors and MBB.Successors must be valid
+// before calling this function.
+unsigned SPIRVInstrInfo::insertBranch( + MachineBasicBlock &MBB, MachineBasicBlock *TBB, MachineBasicBlock *FBB, + ArrayRef<MachineOperand> Cond, const DebugLoc &DL, int *BytesAdded) const { + report_fatal_error("Branch insertion not supported, as MBB info not " + "propagated to OpPhi instructions. Try using " + "-O0 instead."); +} + +void SPIRVInstrInfo::copyPhysReg(MachineBasicBlock &MBB, + MachineBasicBlock::iterator I, + const DebugLoc &DL, MCRegister DestReg, + MCRegister SrcReg, bool KillSrc) const { + // Actually we don't need this COPY instruction. However if we do nothing with + // it, post RA pseudo instrs expansion just removes it and we get the code + // with undef registers. Therefore, we need to replace all uses of dst with + // the src register. COPY instr itself will be safely removed later. + assert(I->isCopy() && "Copy instruction is expected"); + auto DstOp = I->getOperand(0); + auto SrcOp = I->getOperand(1); + assert(DstOp.isReg() && SrcOp.isReg() && + "Register operands are expected in COPY"); + auto &MRI = I->getMF()->getRegInfo(); + MRI.replaceRegWith(DstOp.getReg(), SrcOp.getReg()); +} diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.h b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.h new file mode 100644 index 000000000000..2600d9cfca2e --- /dev/null +++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.h @@ -0,0 +1,54 @@ +//===-- SPIRVInstrInfo.h - SPIR-V Instruction Information -------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains the SPIR-V implementation of the TargetInstrInfo class. 
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIB_TARGET_SPIRV_SPIRVINSTRINFO_H
+#define LLVM_LIB_TARGET_SPIRV_SPIRVINSTRINFO_H
+
+#include "SPIRVRegisterInfo.h"
+#include "llvm/CodeGen/TargetInstrInfo.h"
+
+#define GET_INSTRINFO_HEADER
+#include "SPIRVGenInstrInfo.inc"
+
+namespace llvm {
+
+// SPIR-V implementation of TargetInstrInfo, built on the TableGen-generated
+// SPIRVGenInstrInfo.
+class SPIRVInstrInfo : public SPIRVGenInstrInfo {
+  const SPIRVRegisterInfo RI;
+
+public:
+  SPIRVInstrInfo();
+
+  const SPIRVRegisterInfo &getRegisterInfo() const { return RI; }
+  // Module-level instructions (mode-setting, extension and debug opcodes),
+  // plus anything matched by the three predicates below.
+  bool isHeaderInstr(const MachineInstr &MI) const;
+  // OpConstant*/OpSpecConstant* and OpUndef.
+  bool isConstantInstr(const MachineInstr &MI) const;
+  // Instructions whose first def lives in the TYPE register class.
+  bool isTypeDeclInstr(const MachineInstr &MI) const;
+  // OpDecorate/OpMemberDecorate and their Id/String variants.
+  bool isDecorationInstr(const MachineInstr &MI) const;
+
+  // Understands OpBranch and OpBranchConditional terminators only.
+  bool analyzeBranch(MachineBasicBlock &MBB, MachineBasicBlock *&TBB,
+                     MachineBasicBlock *&FBB,
+                     SmallVectorImpl<MachineOperand> &Cond,
+                     bool AllowModify = false) const override;
+
+  // Unsupported: reports a fatal error (MBB info is not propagated to OpPhi).
+  unsigned removeBranch(MachineBasicBlock &MBB,
+                        int *BytesRemoved = nullptr) const override;
+
+  // Unsupported: reports a fatal error (MBB info is not propagated to OpPhi).
+  unsigned insertBranch(MachineBasicBlock &MBB, MachineBasicBlock *TBB,
+                        MachineBasicBlock *FBB, ArrayRef<MachineOperand> Cond,
+                        const DebugLoc &DL,
+                        int *BytesAdded = nullptr) const override;
+  // Emits no copy. Instead it replaces the COPY's destination register with
+  // its source register throughout the function.
+  void copyPhysReg(MachineBasicBlock &MBB, MachineBasicBlock::iterator I,
+                   const DebugLoc &DL, MCRegister DestReg, MCRegister SrcReg,
+                   bool KillSrc) const override;
+};
+} // namespace llvm
+
+#endif // LLVM_LIB_TARGET_SPIRV_SPIRVINSTRINFO_H
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
new file mode 100644
index 000000000000..d6fec5fd0785
--- /dev/null
+++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
@@ -0,0 +1,732 @@
+//===-- SPIRVInstrInfo.td - Target Description for SPIR-V Target ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file describes the SPIR-V instructions in TableGen format. +// +//===----------------------------------------------------------------------===// + +include "SPIRVInstrFormats.td" +include "SPIRVEnums.td" + +// Codegen only metadata instructions +let isCodeGenOnly=1 in { + def ASSIGN_TYPE: Pseudo<(outs ANYID:$dst_id), (ins ANYID:$src_id, TYPE:$src_ty)>; + def DECL_TYPE: Pseudo<(outs ANYID:$dst_id), (ins ANYID:$src_id, TYPE:$src_ty)>; + def GET_ID: Pseudo<(outs ID:$dst_id), (ins ANYID:$src)>; + def GET_fID: Pseudo<(outs fID:$dst_id), (ins ANYID:$src)>; + def GET_pID: Pseudo<(outs pID:$dst_id), (ins ANYID:$src)>; + def GET_vID: Pseudo<(outs vID:$dst_id), (ins ANYID:$src)>; + def GET_vfID: Pseudo<(outs vfID:$dst_id), (ins ANYID:$src)>; +} + +def SPVTypeBin : SDTypeProfile<1, 2, []>; + +def assigntype : SDNode<"SPIRVISD::AssignType", SPVTypeBin>; + +def : GINodeEquiv<ASSIGN_TYPE, assigntype>; + +class BinOp<string name, bits<16> opCode, list<dag> pattern=[]> + : Op<opCode, (outs ANYID:$dst), (ins TYPE:$src_ty, ANYID:$src, ANYID:$src2), + "$dst = "#name#" $src_ty $src $src2", pattern>; + +class BinOpTyped<string name, bits<16> opCode, RegisterClass CID, SDNode node> + : Op<opCode, (outs ID:$dst), (ins TYPE:$src_ty, CID:$src, CID:$src2), + "$dst = "#name#" $src_ty $src $src2", [(set ID:$dst, (assigntype (node CID:$src, CID:$src2), TYPE:$src_ty))]>; + +class TernOpTyped<string name, bits<16> opCode, RegisterClass CCond, RegisterClass CID, SDNode node> + : Op<opCode, (outs ID:$dst), (ins TYPE:$src_ty, CCond:$cond, CID:$src1, CID:$src2), + "$dst = "#name#" $src_ty $cond $src1 $src2", [(set ID:$dst, (assigntype (node CCond:$cond, CID:$src1, CID:$src2), TYPE:$src_ty))]>; + +multiclass BinOpTypedGen<string name, bits<16> opCode, SDNode node, bit genF = 0, bit genV = 0> { + if genF then + def S: BinOpTyped<name, opCode, 
fID, node>; + else + def S: BinOpTyped<name, opCode, ID, node>; + if genV then { + if genF then + def V: BinOpTyped<name, opCode, vfID, node>; + else + def V: BinOpTyped<name, opCode, vID, node>; + } +} + +multiclass TernOpTypedGen<string name, bits<16> opCode, SDNode node, bit genI = 1, bit genF = 0, bit genV = 0> { + if genF then { + def SFSCond: TernOpTyped<name, opCode, ID, fID, node>; + def SFVCond: TernOpTyped<name, opCode, vID, fID, node>; + } + if genI then { + def SISCond: TernOpTyped<name, opCode, ID, ID, node>; + def SIVCond: TernOpTyped<name, opCode, vID, ID, node>; + } + if genV then { + if genF then { + def VFSCond: TernOpTyped<name, opCode, ID, vfID, node>; + def VFVCond: TernOpTyped<name, opCode, vID, vfID, node>; + } + if genI then { + def VISCond: TernOpTyped<name, opCode, ID, vID, node>; + def VIVCond: TernOpTyped<name, opCode, vID, vID, node>; + } + } +} + +class UnOp<string name, bits<16> opCode, list<dag> pattern=[]> + : Op<opCode, (outs ANYID:$dst), (ins TYPE:$type, ANYID:$src), + "$dst = "#name#" $type $src", pattern>; +class UnOpTyped<string name, bits<16> opCode, RegisterClass CID, SDNode node> + : Op<opCode, (outs ID:$dst), (ins TYPE:$src_ty, CID:$src), + "$dst = "#name#" $src_ty $src", [(set ID:$dst, (assigntype (node CID:$src), TYPE:$src_ty))]>; + +class SimpleOp<string name, bits<16> opCode>: Op<opCode, (outs), (ins), name>; + +// 3.42.1 Miscellaneous Instructions + +def OpNop: SimpleOp<"OpNop", 0>; +def OpUndef: Op<1, (outs ID:$res), (ins TYPE:$type), "$res = OpUndef $type">; +def OpSizeOf: Op<321, (outs ID:$res), (ins TYPE:$ty, ID:$ptr), "$res = OpSizeOf $ty $ptr">; + +// 3.42.2 Debug Instructions + +def OpSourceContinued: Op<2, (outs), (ins StringImm:$str, variable_ops), + "OpSourceContinued $str">; +def OpSource: Op<3, (outs), (ins SourceLanguage:$lang, i32imm:$version, variable_ops), + "OpSource $lang $version">; +def OpSourceExtension: Op<4, (outs), (ins StringImm:$extension, variable_ops), + "OpSourceExtension $extension">; +def 
OpName: Op<5, (outs), (ins ANY:$tar, StringImm:$name, variable_ops), "OpName $tar $name">; +def OpMemberName: Op<6, (outs), (ins TYPE:$ty, i32imm:$mem, StringImm:$name, variable_ops), + "OpMemberName $ty $mem $name">; +def OpString: Op<7, (outs ID:$r), (ins StringImm:$s, variable_ops), "$r = OpString $s">; +def OpLine: Op<8, (outs), (ins ID:$file, i32imm:$ln, i32imm:$col), "OpLine $file $ln $col">; +def OpNoLine: Op<317, (outs), (ins), "OpNoLine">; +def OpModuleProcessed: Op<330, (outs), (ins StringImm:$process, variable_ops), + "OpModuleProcessed $process">; + +// 3.42.3 Annotation Instructions + +def OpDecorate: Op<71, (outs), (ins ANY:$target, Decoration:$dec, variable_ops), + "OpDecorate $target $dec">; +def OpMemberDecorate: Op<72, (outs), (ins TYPE:$t, i32imm:$m, Decoration:$d, variable_ops), + "OpMemberDecorate $t $m $d">; + +// TODO Currently some deprecated opcodes are missing: OpDecorationGroup, +// OpGroupDecorate and OpGroupMemberDecorate + +def OpDecorateId: Op<332, (outs), (ins ANY:$target, Decoration:$dec, variable_ops), + "OpDecorateId $target $dec">; +def OpDecorateString: Op<5632, (outs), (ins ANY:$t, Decoration:$d, StringImm:$s, variable_ops), + "OpDecorateString $t $d $s">; +def OpMemberDecorateString: Op<5633, (outs), + (ins TYPE:$ty, i32imm:$mem, Decoration:$dec, StringImm:$str, variable_ops), + "OpMemberDecorateString $ty $mem $dec $str">; + +// 3.42.4 Extension Instructions + +def OpExtension: Op<10, (outs), (ins StringImm:$name, variable_ops), "OpExtension $name">; +def OpExtInstImport: Op<11, (outs ID:$res), (ins StringImm:$extInstsName, variable_ops), + "$res = OpExtInstImport $extInstsName">; +def OpExtInst: Op<12, (outs ID:$res), (ins TYPE:$ty, ID:$set, ExtInst:$inst, variable_ops), + "$res = OpExtInst $ty $set $inst">; + +// 3.42.5 Mode-Setting Instructions + +def OpMemoryModel: Op<14, (outs), (ins AddressingModel:$addr, MemoryModel:$mem), + "OpMemoryModel $addr $mem">; +def OpEntryPoint: Op<15, (outs), + (ins ExecutionModel:$model, 
ID:$entry, StringImm:$name, variable_ops), + "OpEntryPoint $model $entry $name">; +def OpExecutionMode: Op<16, (outs), (ins ID:$entry, ExecutionMode:$mode, variable_ops), + "OpExecutionMode $entry $mode">; +def OpCapability: Op<17, (outs), (ins Capability:$cap), "OpCapability $cap">; +def OpExecutionModeId: Op<331, (outs), (ins ID:$entry, ExecutionMode:$mode, variable_ops), + "OpExecutionModeId $entry $mode">; + +// 3.42.6 Type-Declaration Instructions + +def OpTypeVoid: Op<19, (outs TYPE:$type), (ins), "$type = OpTypeVoid">; +def OpTypeBool: Op<20, (outs TYPE:$type), (ins), "$type = OpTypeBool">; +def OpTypeInt: Op<21, (outs TYPE:$type), (ins i32imm:$width, i32imm:$signedness), + "$type = OpTypeInt $width $signedness">; +def OpTypeFloat: Op<22, (outs TYPE:$type), (ins i32imm:$width), + "$type = OpTypeFloat $width">; +def OpTypeVector: Op<23, (outs TYPE:$type), (ins TYPE:$compType, i32imm:$compCount), + "$type = OpTypeVector $compType $compCount">; +def OpTypeMatrix: Op<24, (outs TYPE:$type), (ins TYPE:$colType, i32imm:$colCount), + "$type = OpTypeMatrix $colType $colCount">; +def OpTypeImage: Op<25, (outs TYPE:$res), (ins TYPE:$sampTy, Dim:$dim, i32imm:$depth, + i32imm:$arrayed, i32imm:$MS, i32imm:$sampled, ImageFormat:$imFormat, variable_ops), + "$res = OpTypeImage $sampTy $dim $depth $arrayed $MS $sampled $imFormat">; +def OpTypeSampler: Op<26, (outs TYPE:$res), (ins), "$res = OpTypeSampler">; +def OpTypeSampledImage: Op<27, (outs TYPE:$res), (ins TYPE:$imageType), + "$res = OpTypeSampledImage $imageType">; +def OpTypeArray: Op<28, (outs TYPE:$type), (ins TYPE:$elementType, ID:$length), + "$type = OpTypeArray $elementType $length">; +def OpTypeRuntimeArray: Op<29, (outs TYPE:$type), (ins TYPE:$elementType), + "$type = OpTypeRuntimeArray $elementType">; +def OpTypeStruct: Op<30, (outs TYPE:$res), (ins variable_ops), "$res = OpTypeStruct">; +def OpTypeOpaque: Op<31, (outs TYPE:$res), (ins StringImm:$name, variable_ops), + "$res = OpTypeOpaque $name">; +def 
OpTypePointer: Op<32, (outs TYPE:$res), (ins StorageClass:$storage, TYPE:$type), + "$res = OpTypePointer $storage $type">; +def OpTypeFunction: Op<33, (outs TYPE:$funcType), (ins TYPE:$returnType, variable_ops), + "$funcType = OpTypeFunction $returnType">; +def OpTypeEvent: Op<34, (outs TYPE:$res), (ins), "$res = OpTypeEvent">; +def OpTypeDeviceEvent: Op<35, (outs TYPE:$res), (ins), "$res = OpTypeDeviceEvent">; +def OpTypeReserveId: Op<36, (outs TYPE:$res), (ins), "$res = OpTypeReserveId">; +def OpTypeQueue: Op<37, (outs TYPE:$res), (ins), "$res = OpTypeQueue">; +def OpTypePipe: Op<38, (outs TYPE:$res), (ins AccessQualifier:$a), "$res = OpTypePipe $a">; +def OpTypeForwardPointer: Op<39, (outs), (ins TYPE:$ptrType, StorageClass:$storageClass), + "OpTypeForwardPointer $ptrType $storageClass">; +def OpTypePipeStorage: Op<322, (outs TYPE:$res), (ins), "$res = OpTypePipeStorage">; +def OpTypeNamedBarrier: Op<327, (outs TYPE:$res), (ins), "$res = OpTypeNamedBarrier">; +def OpTypeAccelerationStructureNV: Op<5341, (outs TYPE:$res), (ins), + "$res = OpTypeAccelerationStructureNV">; +def OpTypeCooperativeMatrixNV: Op<5358, (outs TYPE:$res), + (ins TYPE:$compType, ID:$scope, ID:$rows, ID:$cols), + "$res = OpTypeCooperativeMatrixNV $compType $scope $rows $cols">; + +// 3.42.7 Constant-Creation Instructions + +// imm_to_i32 / fimm_to_i32 re-emit an integer / FP immediate as a 32-bit +// target constant holding its bit pattern. The gi_bitcast_* records below +// bind them to the GlobalISel operand renderers renderImm32 / renderFImm32. 
+def imm_to_i32 : SDNodeXForm<imm, [{ +return CurDAG->getTargetConstant( + N->getValueAP().bitcastToAPInt().getZExtValue(), SDLoc(N), MVT::i32); +}]>; + +def fimm_to_i32 : SDNodeXForm<imm, [{ +return CurDAG->getTargetConstant( + N->getValueAPF().bitcastToAPInt().getZExtValue(), SDLoc(N), MVT::i32); +}]>; + +def gi_bitcast_fimm_to_i32 : GICustomOperandRenderer<"renderFImm32">, + GISDNodeXFormEquiv<fimm_to_i32>; + +def gi_bitcast_imm_to_i32 : GICustomOperandRenderer<"renderImm32">, + GISDNodeXFormEquiv<imm_to_i32>; + +def PseudoConstI: IntImmLeaf<i32, [{ return Imm.getBitWidth() <= 32; }], imm_to_i32>; +def PseudoConstF: FPImmLeaf<f32, [{ return true; }], fimm_to_i32>; + +// Immediate leaves used by OpConstantTrue/False/Null below: True and False +// match 1-bit values, Null matches an all-zero value. +def 
ConstPseudoTrue: IntImmLeaf<i32, [{ return Imm.getBitWidth() == 1 && Imm.getZExtValue() == 1; }]>; +def ConstPseudoFalse: IntImmLeaf<i32, [{ return Imm.getBitWidth() == 1 && Imm.getZExtValue() == 0; }]>; +def ConstPseudoNull: IntImmLeaf<i64, [{ return Imm.isZero(); }]>; + +// IntFPImm defines an integer (I) and a floating-point (F) form of one +// constant instruction sharing a single opcode; each selects +// (assigntype <imm leaf>, TYPE). +multiclass IntFPImm<bits<16> opCode, string name> { + def I: Op<opCode, (outs ID:$dst), (ins TYPE:$type, ID:$src, variable_ops), + "$dst = "#name#" $type $src", [(set ID:$dst, (assigntype PseudoConstI:$src, TYPE:$type))]>; + def F: Op<opCode, (outs ID:$dst), (ins TYPE:$type, fID:$src, variable_ops), + "$dst = "#name#" $type $src", [(set ID:$dst, (assigntype PseudoConstF:$src, TYPE:$type))]>; +} + +def OpConstantTrue: Op<41, (outs ID:$dst), (ins TYPE:$src_ty), "$dst = OpConstantTrue $src_ty", + [(set ID:$dst, (assigntype ConstPseudoTrue, TYPE:$src_ty))]>; +def OpConstantFalse: Op<42, (outs ID:$dst), (ins TYPE:$src_ty), "$dst = OpConstantFalse $src_ty", + [(set ID:$dst, (assigntype ConstPseudoFalse, TYPE:$src_ty))]>; + +defm OpConstant: IntFPImm<43, "OpConstant">; + +def OpConstantComposite: Op<44, (outs ID:$res), (ins TYPE:$type, variable_ops), + "$res = OpConstantComposite $type">; +def OpConstantSampler: Op<45, (outs ID:$res), + (ins TYPE:$t, SamplerAddressingMode:$s, i32imm:$p, SamplerFilterMode:$f), + "$res = OpConstantSampler $t $s $p $f">; +def OpConstantNull: Op<46, (outs ID:$dst), (ins TYPE:$src_ty), "$dst = OpConstantNull $src_ty", + [(set ID:$dst, (assigntype ConstPseudoNull, TYPE:$src_ty))]>; + +def OpSpecConstantTrue: Op<48, (outs ID:$r), (ins TYPE:$t), "$r = OpSpecConstantTrue $t">; +def OpSpecConstantFalse: Op<49, (outs ID:$r), (ins TYPE:$t), "$r = OpSpecConstantFalse $t">; +def OpSpecConstant: Op<50, (outs ID:$res), (ins TYPE:$type, i32imm:$imm, variable_ops), + "$res = OpSpecConstant $type $imm">; +def OpSpecConstantComposite: 
Op<51, (outs ID:$res), (ins TYPE:$type, variable_ops), + "$res = OpSpecConstantComposite $type">; +def OpSpecConstantOp: Op<52, (outs ID:$res), (ins 
TYPE:$t, i32imm:$c, ID:$o, variable_ops), + "$res = OpSpecConstantOp $t $c $o">; + +// 3.42.8 Memory Instructions + +def OpVariable: Op<59, (outs ID:$res), (ins TYPE:$type, StorageClass:$sc, variable_ops), + "$res = OpVariable $type $sc">; +def OpImageTexelPointer: Op<60, (outs ID:$res), + (ins TYPE:$resType, ID:$image, ID:$coord, ID:$sample), + "$res = OpImageTexelPointer $resType $image $coord $sample">; +def OpLoad: Op<61, (outs ID:$res), (ins TYPE:$resType, ID:$pointer, variable_ops), + "$res = OpLoad $resType $pointer">; +def OpStore: Op<62, (outs), (ins ID:$pointer, ID:$objectToStore, variable_ops), + "OpStore $pointer $objectToStore">; +def OpCopyMemory: Op<63, (outs), (ins ID:$dest, ID:$src, variable_ops), + "OpCopyMemory $dest $src">; +def OpCopyMemorySized: Op<64, (outs), (ins ID:$dest, ID:$src, ID:$size, variable_ops), + "OpCopyMemorySized $dest $src $size">; +def OpAccessChain: Op<65, (outs ID:$res), (ins TYPE:$type, ID:$base, variable_ops), + "$res = OpAccessChain $type $base">; +def OpInBoundsAccessChain: Op<66, (outs ID:$res), + (ins TYPE:$type, ID:$base, variable_ops), + "$res = OpInBoundsAccessChain $type $base">; +def OpPtrAccessChain: Op<67, (outs ID:$res), + (ins TYPE:$type, ID:$base, ID:$element, variable_ops), + "$res = OpPtrAccessChain $type $base $element">; +def OpArrayLength: Op<68, (outs ID:$res), (ins TYPE:$resTy, ID:$struct, i32imm:$arrayMember), + "$res = OpArrayLength $resTy $struct $arrayMember">; +def OpGenericPtrMemSemantics: Op<69, (outs ID:$res), (ins TYPE:$resType, ID:$pointer), + "$res = OpGenericPtrMemSemantics $resType $pointer">; +def OpInBoundsPtrAccessChain: Op<70, (outs ID:$res), + (ins TYPE:$type, ID:$base, ID:$element, variable_ops), + "$res = OpInBoundsPtrAccessChain $type $base $element">; +def OpPtrEqual: Op<401, (outs ID:$res), (ins TYPE:$resType, ID:$a, ID:$b), + "$res = OpPtrEqual $resType $a $b">; +def OpPtrNotEqual: Op<402, (outs ID:$res), (ins TYPE:$resType, ID:$a, ID:$b), + "$res = OpPtrNotEqual $resType $a 
$b">; +def OpPtrDiff: Op<403, (outs ID:$res), (ins TYPE:$resType, ID:$a, ID:$b), + "$res = OpPtrDiff $resType $a $b">; + +// 3.42.9 Function Instructions + +def OpFunction: Op<54, (outs ID:$func), + (ins TYPE:$resType, FunctionControl:$funcControl, TYPE:$funcType), + "$func = OpFunction $resType $funcControl $funcType">; +def OpFunctionParameter: Op<55, (outs ID:$arg), (ins TYPE:$type), + "$arg = OpFunctionParameter $type">; +def OpFunctionEnd: Op<56, (outs), (ins), "OpFunctionEnd"> { + let isTerminator=1; +} +def OpFunctionCall: Op<57, (outs ID:$res), (ins TYPE:$resType, ID:$function, variable_ops), + "$res = OpFunctionCall $resType $function">; + +// 3.42.10 Image Instructions + +def OpSampledImage: BinOp<"OpSampledImage", 86>; + +def OpImageSampleImplicitLod: Op<87, (outs ID:$res), + (ins TYPE:$type, ID:$sampledImage, ID:$coord, variable_ops), + "$res = OpImageSampleImplicitLod $type $sampledImage $coord">; +def OpImageSampleExplicitLod: Op<88, (outs ID:$res), + (ins TYPE:$ty, ID:$sImage, ID:$uv, ImageOperand:$op, ID:$i, variable_ops), + "$res = OpImageSampleExplicitLod $ty $sImage $uv $op $i">; + +// Dref forms: the asm string lists operands in the same order as the ins +// list (sampled image, coordinate, Dref), which is the SPIR-V operand order. 
+def OpImageSampleDrefImplicitLod: Op<89, (outs ID:$res), + (ins TYPE:$type, ID:$sampledImage, ID:$coord, ID:$dref, variable_ops), + "$res = OpImageSampleDrefImplicitLod $type $sampledImage $coord $dref">; +def OpImageSampleDrefExplicitLod: Op<90, (outs ID:$res), + (ins TYPE:$ty, ID:$im, ID:$uv, ID:$d, ImageOperand:$op, ID:$i, variable_ops), + "$res = OpImageSampleDrefExplicitLod $ty $im $uv $d $op $i">; + +def OpImageSampleProjImplicitLod: Op<91, (outs ID:$res), + (ins TYPE:$type, ID:$sampledImage, ID:$coord, variable_ops), + "$res = OpImageSampleProjImplicitLod $type $sampledImage $coord">; +// Proj*ExplicitLod takes no Dref operand; only the Proj*Dref* forms do. +def OpImageSampleProjExplicitLod: Op<92, (outs ID:$res), + (ins TYPE:$ty, ID:$im, ID:$uv, ImageOperand:$op, ID:$i, variable_ops), + "$res = OpImageSampleProjExplicitLod $ty $im $uv $op $i">; + +def OpImageSampleProjDrefImplicitLod: Op<93, (outs ID:$res), + (ins TYPE:$type, 
ID:$sampledImage, ID:$coord, ID:$dref, variable_ops), + "$res = OpImageSampleProjDrefImplicitLod $type $sampledImage $coord $dref">; +def OpImageSampleProjDrefExplicitLod: Op<94, (outs ID:$res), + (ins TYPE:$ty, ID:$im, ID:$uv, ID:$d, ImageOperand:$op, ID:$i, variable_ops), + "$res = OpImageSampleProjDrefExplicitLod $ty $im $uv $d $op $i">; + +def OpImageFetch: Op<95, (outs ID:$res), + (ins TYPE:$type, ID:$image, ID:$coord, variable_ops), + "$res = OpImageFetch $type $image $coord">; +def OpImageGather: Op<96, (outs ID:$res), + (ins TYPE:$type, ID:$sampledImage, ID:$coord, ID:$component, variable_ops), + "$res = OpImageGather $type $sampledImage $coord $component">; +def OpImageDrefGather: Op<97, (outs ID:$res), + (ins TYPE:$type, ID:$sampledImage, ID:$coord, ID:$dref, variable_ops), + "$res = OpImageDrefGather $type $sampledImage $coord $dref">; + +def OpImageRead: Op<98, (outs ID:$res), + (ins TYPE:$type, ID:$image, ID:$coord, variable_ops), + "$res = OpImageRead $type $image $coord">; +def OpImageWrite: Op<99, (outs), (ins ID:$image, ID:$coord, ID:$texel, variable_ops), + "OpImageWrite $image $coord $texel">; + +def OpImage: UnOp<"OpImage", 100>; +def OpImageQueryFormat: UnOp<"OpImageQueryFormat", 101>; +def OpImageQueryOrder: UnOp<"OpImageQueryOrder", 102>; +def OpImageQuerySizeLod: BinOp<"OpImageQuerySizeLod", 103>; +def OpImageQuerySize: UnOp<"OpImageQuerySize", 104>; +def OpImageQueryLod: BinOp<"OpImageQueryLod", 105>; +def OpImageQueryLevels: UnOp<"OpImageQueryLevels", 106>; +def OpImageQuerySamples: UnOp<"OpImageQuerySamples", 107>; + +def OpImageSparseSampleImplicitLod: Op<305, (outs ID:$res), + (ins TYPE:$type, ID:$sampledImage, ID:$coord, variable_ops), + "$res = OpImageSparseSampleImplicitLod $type $sampledImage $coord">; +def OpImageSparseSampleExplicitLod: Op<306, (outs ID:$res), + (ins TYPE:$ty, ID:$sImage, ID:$uv, ImageOperand:$op, ID:$i, variable_ops), + "$res = OpImageSparseSampleExplicitLod $ty $sImage $uv $op $i">; + +// Sparse variants mirror the operand 
lists of the non-sparse forms above. +def 
OpImageSparseSampleDrefImplicitLod: Op<307, (outs ID:$res), + (ins TYPE:$type, ID:$sampledImg, ID:$coord, ID:$dref, variable_ops), + "$res = OpImageSparseSampleDrefImplicitLod $type $sampledImg $coord $dref">; +def OpImageSparseSampleDrefExplicitLod: Op<308, (outs ID:$res), + (ins TYPE:$ty, ID:$im, ID:$uv, ID:$d, ImageOperand:$op, ID:$i, variable_ops), + "$res = OpImageSparseSampleDrefExplicitLod $ty $im $uv $d $op $i">; + +def OpImageSparseSampleProjImplicitLod: Op<309, (outs ID:$res), + (ins TYPE:$type, ID:$sampledImage, ID:$coord, variable_ops), + "$res = OpImageSparseSampleProjImplicitLod $type $sampledImage $coord">; +def OpImageSparseSampleProjExplicitLod: Op<310, (outs ID:$res), + (ins TYPE:$ty, ID:$im, ID:$uv, ImageOperand:$op, ID:$i, variable_ops), + "$res = OpImageSparseSampleProjExplicitLod $ty $im $uv $op $i">; + +def OpImageSparseSampleProjDrefImplicitLod: Op<311, (outs ID:$res), + (ins TYPE:$type, ID:$sImage, ID:$coord, ID:$dref, variable_ops), + "$res = OpImageSparseSampleProjDrefImplicitLod $type $sImage $coord $dref">; +def OpImageSparseSampleProjDrefExplicitLod: Op<312, (outs ID:$res), + (ins TYPE:$ty, ID:$im, ID:$uv, ID:$d, ImageOperand:$op, ID:$i, variable_ops), + "$res = OpImageSparseSampleProjDrefExplicitLod $ty $im $uv $d $op $i">; + +def OpImageSparseFetch: Op<313, (outs ID:$res), + (ins TYPE:$type, ID:$image, ID:$coord, variable_ops), + "$res = OpImageSparseFetch $type $image $coord">; +def OpImageSparseGather: Op<314, (outs ID:$res), + (ins TYPE:$type, ID:$sampledImage, ID:$coord, ID:$component, variable_ops), + "$res = OpImageSparseGather $type $sampledImage $coord $component">; +def OpImageSparseDrefGather: Op<315, (outs ID:$res), + (ins TYPE:$type, ID:$sampledImage, ID:$coord, ID:$dref, variable_ops), + "$res = OpImageSparseDrefGather $type $sampledImage $coord $dref">; + +def OpImageSparseTexelsResident: UnOp<"OpImageSparseTexelsResident", 316>; + +def OpImageSparseRead: Op<320, (outs ID:$res), + (ins TYPE:$type, ID:$image, 
ID:$coord, variable_ops), + "$res = OpImageSparseRead $type $image $coord">; + +def OpImageSampleFootprintNV: Op<5283, (outs ID:$res), + (ins TYPE:$ty, ID:$sImg, ID:$uv, ID:$granularity, ID:$coarse, variable_ops), + "$res = OpImageSampleFootprintNV $ty $sImg $uv $granularity $coarse">; + +// 3.42.11 Conversion instructions + +def OpConvertFToU : UnOp<"OpConvertFToU", 109>; +def OpConvertFToS : UnOp<"OpConvertFToS", 110>; +def OpConvertSToF : UnOp<"OpConvertSToF", 111>; +def OpConvertUToF : UnOp<"OpConvertUToF", 112>; + +def OpUConvert : UnOp<"OpUConvert", 113>; +def OpSConvert : UnOp<"OpSConvert", 114>; +def OpFConvert : UnOp<"OpFConvert", 115>; + +def OpQuantizeToF16 : UnOp<"OpQuantizeToF16", 116>; + +def OpConvertPtrToU : UnOp<"OpConvertPtrToU", 117>; + +def OpSatConvertSToU : UnOp<"OpSatConvertSToU", 118>; +def OpSatConvertUToS : UnOp<"OpSatConvertUToS", 119>; + +def OpConvertUToPtr : UnOp<"OpConvertUToPtr", 120>; +def OpPtrCastToGeneric : UnOp<"OpPtrCastToGeneric", 121>; +def OpGenericCastToPtr : UnOp<"OpGenericCastToPtr", 122>; +// Unlike the UnOp casts around it, this form also takes an explicit +// StorageClass operand. +def OpGenericCastToPtrExplicit : Op<123, (outs ID:$r), (ins TYPE:$t, ID:$p, StorageClass:$s), + "$r = OpGenericCastToPtrExplicit $t $p $s">; +def OpBitcast : UnOp<"OpBitcast", 124>; + +// 3.42.12 Composite Instructions + +// The only composite instruction here with a selection pattern: it matches +// extractelt wrapped in assigntype. 
+def OpVectorExtractDynamic: Op<77, (outs ID:$res), (ins TYPE:$type, vID:$vec, ID:$idx), + "$res = OpVectorExtractDynamic $type $vec $idx", [(set ID:$res, (assigntype (extractelt vID:$vec, ID:$idx), TYPE:$type))]>; + +def OpVectorInsertDynamic: Op<78, (outs ID:$res), (ins TYPE:$ty, ID:$vec, ID:$comp, ID:$idx), + "$res = OpVectorInsertDynamic $ty $vec $comp $idx">; +def OpVectorShuffle: Op<79, (outs ID:$res), (ins TYPE:$ty, ID:$v1, ID:$v2, variable_ops), + "$res = OpVectorShuffle $ty $v1 $v2">; +def OpCompositeConstruct: Op<80, (outs ID:$res), (ins TYPE:$type, variable_ops), + "$res = OpCompositeConstruct $type">; +def OpCompositeExtract: Op<81, (outs ID:$res), (ins TYPE:$type, ID:$base, variable_ops), + "$res = 
OpCompositeExtract $type $base">; +def OpCompositeInsert: Op<82, (outs ID:$r), (ins TYPE:$ty, ID:$obj, ID:$base, variable_ops), + "$r = OpCompositeInsert $ty $obj $base">; +def OpCopyObject: UnOp<"OpCopyObject", 83>; +def OpTranspose: UnOp<"OpTranspose", 84>; +def OpCopyLogical: UnOp<"OpCopyLogical", 400>; + +// 3.42.13 Arithmetic Instructions + +// NOTE(review): the first trailing BinOpTypedGen flag is 1 for the OpF* forms +// and 0 for the integer forms, so it presumably selects FP vs integer operand +// classes -- confirm against the multiclass definition. +def OpSNegate: UnOp<"OpSNegate", 126>; +def OpFNegate: UnOpTyped<"OpFNegate", 127, fID, fneg>; +defm OpIAdd: BinOpTypedGen<"OpIAdd", 128, add, 0, 1>; +defm OpFAdd: BinOpTypedGen<"OpFAdd", 129, fadd, 1, 1>; + +defm OpISub: BinOpTypedGen<"OpISub", 130, sub, 0, 1>; +defm OpFSub: BinOpTypedGen<"OpFSub", 131, fsub, 1, 1>; + +defm OpIMul: BinOpTypedGen<"OpIMul", 132, mul, 0, 1>; +defm OpFMul: BinOpTypedGen<"OpFMul", 133, fmul, 1, 1>; + +defm OpUDiv: BinOpTypedGen<"OpUDiv", 134, udiv, 0, 1>; +defm OpSDiv: BinOpTypedGen<"OpSDiv", 135, sdiv, 0, 1>; +defm OpFDiv: BinOpTypedGen<"OpFDiv", 136, fdiv, 1, 1>; + +defm OpUMod: BinOpTypedGen<"OpUMod", 137, urem, 0, 1>; +defm OpSRem: BinOpTypedGen<"OpSRem", 138, srem, 0, 1>; + +// OpSMod / OpFMod carry no SDNode pattern; srem / frem are attached to +// OpSRem / OpFRem instead. +def OpSMod: BinOp<"OpSMod", 139>; + +defm OpFRem: BinOpTypedGen<"OpFRem", 140, frem, 1, 1>; +def OpFMod: BinOp<"OpFMod", 141>; + +def OpVectorTimesScalar: BinOp<"OpVectorTimesScalar", 142>; +def OpMatrixTimesScalar: BinOp<"OpMatrixTimesScalar", 143>; +def OpVectorTimesMatrix: BinOp<"OpVectorTimesMatrix", 144>; +def OpMatrixTimesVector: BinOp<"OpMatrixTimesVector", 145>; +def OpMatrixTimesMatrix: BinOp<"OpMatrixTimesMatrix", 146>; + +def OpOuterProduct: BinOp<"OpOuterProduct", 147>; +def OpDot: BinOp<"OpDot", 148>; + +// NOTE(review): addc/subc are glue-based carry nodes while SPIR-V's +// OpIAddCarry/OpISubBorrow 
return a two-member struct -- confirm this mapping. +def OpIAddCarry: BinOpTyped<"OpIAddCarry", 149, ID, addc>; +def OpISubBorrow: BinOpTyped<"OpISubBorrow", 150, ID, subc>; +def OpUMulExtended: BinOp<"OpUMulExtended", 151>; +def OpSMulExtended: BinOp<"OpSMulExtended", 152>; + +// 3.42.14 Bit Instructions + +defm OpShiftRightLogical: BinOpTypedGen<"OpShiftRightLogical", 194, srl, 0, 1>; +defm OpShiftRightArithmetic: BinOpTypedGen<"OpShiftRightArithmetic", 
195, sra, 0, 1>; +defm OpShiftLeftLogical: BinOpTypedGen<"OpShiftLeftLogical", 196, shl, 0, 1>; + +defm OpBitwiseOr: BinOpTypedGen<"OpBitwiseOr", 197, or, 0, 1>; +defm OpBitwiseXor: BinOpTypedGen<"OpBitwiseXor", 198, xor, 0, 1>; +defm OpBitwiseAnd: BinOpTypedGen<"OpBitwiseAnd", 199, and, 0, 1>; +def OpNot: UnOp<"OpNot", 200>; + +def OpBitFieldInsert: Op<201, (outs ID:$res), + (ins TYPE:$ty, ID:$base, ID:$insert, ID:$offset, ID:$count), + "$res = OpBitFieldInsert $ty $base $insert $offset $count">; +def OpBitFieldSExtract: Op<202, (outs ID:$res), + (ins TYPE:$ty, ID:$base, ID:$offset, ID:$count), + "$res = OpBitFieldSExtract $ty $base $offset $count">; +def OpBitFieldUExtract: Op<203, (outs ID:$res), + (ins TYPE:$ty, ID:$base, ID:$offset, ID:$count), + "$res = OpBitFieldUExtract $ty $base $offset $count">; +def OpBitReverse: Op<204, (outs ID:$r), (ins TYPE:$ty, ID:$b), "$r = OpBitReverse $ty $b">; +def OpBitCount: Op<205, (outs ID:$r), (ins TYPE:$ty, ID:$b), "$r = OpBitCount $ty $b">; + +// 3.42.15 Relational and Logical Instructions + +def OpAny: Op<154, (outs ID:$res), (ins TYPE:$ty, ID:$vec), + "$res = OpAny $ty $vec">; +def OpAll: Op<155, (outs ID:$res), (ins TYPE:$ty, ID:$vec), + "$res = OpAll $ty $vec">; + +def OpIsNan: UnOp<"OpIsNan", 156>; +def OpIsInf: UnOp<"OpIsInf", 157>; +def OpIsFinite: UnOp<"OpIsFinite", 158>; +def OpIsNormal: UnOp<"OpIsNormal", 159>; +def OpSignBitSet: UnOp<"OpSignBitSet", 160>; + +def OpLessOrGreater: BinOp<"OpLessOrGreater", 161>; +def OpOrdered: BinOp<"OpOrdered", 162>; +def OpUnordered: BinOp<"OpUnordered", 163>; + +def OpLogicalEqual: BinOp<"OpLogicalEqual", 164>; +def OpLogicalNotEqual: BinOp<"OpLogicalNotEqual", 165>; +def OpLogicalOr: BinOp<"OpLogicalOr", 166>; +def OpLogicalAnd: BinOp<"OpLogicalAnd", 167>; +def OpLogicalNot: UnOp<"OpLogicalNot", 168>; + +// OpSelect is the only entry in this section with an SDNode (select) attached; +// the comparisons below are plain BinOps. 
+defm OpSelect: TernOpTypedGen<"OpSelect", 169, select, 1, 1, 1>; + +def OpIEqual: BinOp<"OpIEqual", 170>; +def OpINotEqual: BinOp<"OpINotEqual", 171>; + +def OpUGreaterThan: 
BinOp<"OpUGreaterThan", 172>; +def OpSGreaterThan: BinOp<"OpSGreaterThan", 173>; +def OpUGreaterThanEqual: BinOp<"OpUGreaterThanEqual", 174>; +def OpSGreaterThanEqual: BinOp<"OpSGreaterThanEqual", 175>; +def OpULessThan: BinOp<"OpULessThan", 176>; +def OpSLessThan: BinOp<"OpSLessThan", 177>; +def OpULessThanEqual: BinOp<"OpULessThanEqual", 178>; +def OpSLessThanEqual: BinOp<"OpSLessThanEqual", 179>; + +def OpFOrdEqual: BinOp<"OpFOrdEqual", 180>; +def OpFUnordEqual: BinOp<"OpFUnordEqual", 181>; +def OpFOrdNotEqual: BinOp<"OpFOrdNotEqual", 182>; +def OpFUnordNotEqual: BinOp<"OpFUnordNotEqual", 183>; + +def OpFOrdLessThan: BinOp<"OpFOrdLessThan", 184>; +def OpFUnordLessThan: BinOp<"OpFUnordLessThan", 185>; +def OpFOrdGreaterThan: BinOp<"OpFOrdGreaterThan", 186>; +def OpFUnordGreaterThan: BinOp<"OpFUnordGreaterThan", 187>; + +def OpFOrdLessThanEqual: BinOp<"OpFOrdLessThanEqual", 188>; +def OpFUnordLessThanEqual: BinOp<"OpFUnordLessThanEqual", 189>; +def OpFOrdGreaterThanEqual: BinOp<"OpFOrdGreaterThanEqual", 190>; +def OpFUnordGreaterThanEqual: BinOp<"OpFUnordGreaterThanEqual", 191>; + +// 3.42.16 Derivative Instructions + +def OpDPdx: UnOp<"OpDPdx", 207>; +def OpDPdy: UnOp<"OpDPdy", 208>; +def OpFwidth: UnOp<"OpFwidth", 209>; + +def OpDPdxFine: UnOp<"OpDPdxFine", 210>; +def OpDPdyFine: UnOp<"OpDPdyFine", 211>; +def OpFwidthFine: UnOp<"OpFwidthFine", 212>; + +def OpDPdxCoarse: UnOp<"OpDPdxCoarse", 213>; +def OpDPdyCoarse: UnOp<"OpDPdyCoarse", 214>; +def OpFwidthCoarse: UnOp<"OpFwidthCoarse", 215>; + +// 3.42.17 Control-Flow Instructions + +def OpPhi: Op<245, (outs ID:$res), (ins TYPE:$type, ID:$var0, ID:$block0, variable_ops), + "$res = OpPhi $type $var0 $block0">; +// Operands: merge block, continue target, loop-control mask; any trailing +// variable_ops follow the mask. Each operand is printed exactly once. 
+def OpLoopMerge: Op<246, (outs), (ins ID:$merge, ID:$continue, LoopControl:$lc, variable_ops), + "OpLoopMerge $merge $continue $lc">; +def OpSelectionMerge: Op<247, (outs), (ins ID:$merge, SelectionControl:$sc), + "OpSelectionMerge $merge $sc">; +def OpLabel: Op<248, (outs ID:$label), (ins), "$label 
= OpLabel">; +let isTerminator=1 in { + def OpBranch: Op<249, (outs), (ins ID:$label), "OpBranch $label">; + def OpBranchConditional: Op<250, (outs), (ins ID:$cond, ID:$true, ID:$false, variable_ops), + "OpBranchConditional $cond $true $false">; + def OpSwitch: Op<251, (outs), (ins ID:$sel, ID:$dflt, variable_ops), "OpSwitch $sel $dflt">; +} +let isReturn = 1, hasDelaySlot=0, isBarrier = 0, isTerminator=1, isNotDuplicable = 1 in { + def OpKill: SimpleOp<"OpKill", 252>; + def OpReturn: SimpleOp<"OpReturn", 253>; + def OpReturnValue: Op<254, (outs), (ins ANYID:$ret), "OpReturnValue $ret">; + def OpUnreachable: SimpleOp<"OpUnreachable", 255>; +} +// Pointer and literal size; operands are space-separated like every other +// asm string in this file. +def OpLifetimeStart: Op<256, (outs), (ins ID:$ptr, i32imm:$sz), "OpLifetimeStart $ptr $sz">; +def OpLifetimeStop: Op<257, (outs), (ins ID:$ptr, i32imm:$sz), "OpLifetimeStop $ptr $sz">; + +// 3.42.18 Atomic Instructions + +// AtomicOp: result type, pointer, scope and semantics; AtomicOpVal adds a +// value operand. +class AtomicOp<string name, bits<16> opCode>: Op<opCode, (outs ID:$res), + (ins TYPE:$ty, ID:$ptr, ID:$sc, ID:$sem), + "$res = "#name#" $ty $ptr $sc $sem">; + +class AtomicOpVal<string name, bits<16> opCode>: Op<opCode, (outs ID:$res), + (ins TYPE:$ty, ID:$ptr, ID:$sc, ID:$sem, ID:$val), + "$res = "#name#" $ty $ptr $sc $sem $val">; + +def OpAtomicLoad: AtomicOp<"OpAtomicLoad", 227>; + +def OpAtomicStore: Op<228, (outs), (ins ID:$ptr, ID:$sc, ID:$sem, ID:$val), + "OpAtomicStore $ptr $sc $sem $val">; +def OpAtomicExchange: Op<229, (outs ID:$res), + (ins TYPE:$ty, ID:$ptr, ID:$sc, ID:$sem, ID:$val), + "$res = OpAtomicExchange $ty $ptr $sc $sem $val">; +def OpAtomicCompareExchange: Op<230, (outs ID:$res), + (ins TYPE:$ty, ID:$ptr, ID:$sc, ID:$eq, + ID:$neq, ID:$val, ID:$cmp), + "$res = OpAtomicCompareExchange $ty $ptr $sc $eq $neq $val $cmp">; +// TODO Currently the following deprecated opcode is missing: +// 
OpAtomicCompareExchangeWeak + +def OpAtomicIIncrement: AtomicOp<"OpAtomicIIncrement", 232>; +def OpAtomicIDecrement: AtomicOp<"OpAtomicIDecrement", 233>; + +def OpAtomicIAdd: AtomicOpVal<"OpAtomicIAdd", 
234>; +def OpAtomicISub: AtomicOpVal<"OpAtomicISub", 235>; + +def OpAtomicSMin: AtomicOpVal<"OpAtomicSMin", 236>; +def OpAtomicUMin: AtomicOpVal<"OpAtomicUMin", 237>; +def OpAtomicSMax: AtomicOpVal<"OpAtomicSMax", 238>; +def OpAtomicUMax: AtomicOpVal<"OpAtomicUMax", 239>; + +def OpAtomicAnd: AtomicOpVal<"OpAtomicAnd", 240>; +def OpAtomicOr: AtomicOpVal<"OpAtomicOr", 241>; +def OpAtomicXor: AtomicOpVal<"OpAtomicXor", 242>; + + +def OpAtomicFlagTestAndSet: AtomicOp<"OpAtomicFlagTestAndSet", 318>; +// OpAtomicFlagClear produces no result, so it is spelled out here instead of +// reusing AtomicOp. +def OpAtomicFlagClear: Op<319, (outs), (ins ID:$ptr, ID:$sc, ID:$sem), + "OpAtomicFlagClear $ptr $sc $sem">; + +// 3.42.19 Primitive Instructions + +def OpEmitVertex: SimpleOp<"OpEmitVertex", 218>; +def OpEndPrimitive: SimpleOp<"OpEndPrimitive", 219>; +def OpEmitStreamVertex: Op<220, (outs), (ins ID:$stream), "OpEmitStreamVertex $stream">; +def OpEndStreamPrimitive: Op<221, (outs), (ins ID:$stream), "OpEndStreamPrimitive $stream">; + +// 3.42.20 Barrier Instructions + +def OpControlBarrier: Op<224, (outs), (ins ID:$exec, ID:$mem, ID:$sem), + "OpControlBarrier $exec $mem $sem">; +def OpMemoryBarrier: Op<225, (outs), (ins ID:$mem, ID:$sem), + "OpMemoryBarrier $mem $sem">; +def OpNamedBarrierInitialize: UnOp<"OpNamedBarrierInitialize", 328>; +def OpMemoryNamedBarrier: Op<329, (outs), (ins ID:$barr, ID:$mem, ID:$sem), + "OpMemoryNamedBarrier $barr $mem $sem">; + +// 3.42.21. 
Group and Subgroup Instructions + +def OpGroupAll: Op<261, (outs ID:$res), (ins TYPE:$ty, ID:$scope, ID:$pr), + "$res = OpGroupAll $ty $scope $pr">; +def OpGroupAny: Op<262, (outs ID:$res), (ins TYPE:$ty, ID:$scope, ID:$pr), + "$res = OpGroupAny $ty $scope $pr">; +def OpGroupBroadcast: Op<263, (outs ID:$res), (ins TYPE:$ty, ID:$scope, + ID:$val, ID:$id), + "$res = OpGroupBroadcast $ty $scope $val $id">; +// OpGroup<name>: mnemonic OpGroup#name with result type, scope, a +// GroupOperation and one value operand. +class OpGroup<string name, bits<16> opCode>: Op<opCode, (outs ID:$res), + (ins TYPE:$ty, ID:$scope, GroupOperation:$groupOp, ID:$x), + "$res = OpGroup"#name#" $ty $scope $groupOp $x">; +def OpGroupIAdd: OpGroup<"IAdd", 264>; +def OpGroupFAdd: OpGroup<"FAdd", 265>; +def OpGroupFMin: OpGroup<"FMin", 266>; +def OpGroupUMin: OpGroup<"UMin", 267>; +def OpGroupSMin: OpGroup<"SMin", 268>; +def OpGroupFMax: OpGroup<"FMax", 269>; +def OpGroupUMax: OpGroup<"UMax", 270>; +def OpGroupSMax: OpGroup<"SMax", 271>; + +// TODO: 3.42.22. Device-Side Enqueue Instructions +// TODO: 3.42.23. Pipe Instructions + +// 3.42.24. 
Non-Uniform Instructions + +def OpGroupNonUniformElect: Op<333, (outs ID:$res), (ins TYPE:$ty, ID:$scope), + "$res = OpGroupNonUniformElect $ty $scope">; +// OpGroupNU3: result type, scope and one operand. OpGroupNU4 adds a second +// ID operand ($id). Both prefix the mnemonic with OpGroupNonUniform. +class OpGroupNU3<string name, bits<16> opCode>: Op<opCode, + (outs ID:$res), (ins TYPE:$ty, ID:$scope, ID:$pred), + "$res = OpGroupNonUniform"#name#" $ty $scope $pred">; +class OpGroupNU4<string name, bits<16> opCode>: Op<opCode, + (outs ID:$res), (ins TYPE:$ty, ID:$scope, ID:$val, ID:$id), + "$res = OpGroupNonUniform"#name#" $ty $scope $val $id">; +def OpGroupNonUniformAll: OpGroupNU3<"All", 334>; +def OpGroupNonUniformAny: OpGroupNU3<"Any", 335>; +def OpGroupNonUniformAllEqual: OpGroupNU3<"AllEqual", 336>; +def OpGroupNonUniformBroadcast: OpGroupNU4<"Broadcast", 337>; +def OpGroupNonUniformBroadcastFirst: OpGroupNU3<"BroadcastFirst", 338>; +def OpGroupNonUniformBallot: OpGroupNU3<"Ballot", 339>; +def OpGroupNonUniformInverseBallot: OpGroupNU3<"InverseBallot", 340>; +def OpGroupNonUniformBallotBitExtract: OpGroupNU4<"BallotBitExtract", 341>; +def OpGroupNonUniformBallotBitCount: Op<342, (outs ID:$res), + (ins TYPE:$ty, ID:$scope, GroupOperation:$groupOp, ID:$val), + "$res = OpGroupNonUniformBallotBitCount " + "$ty $scope $groupOp $val">; +def OpGroupNonUniformBallotFindLSB: OpGroupNU3<"BallotFindLSB", 343>; +def OpGroupNonUniformBallotFindMSB: OpGroupNU3<"BallotFindMSB", 344>; +def OpGroupNonUniformShuffle: OpGroupNU4<"Shuffle", 345>; +def OpGroupNonUniformShuffleXor: OpGroupNU4<"ShuffleXor", 346>; +def OpGroupNonUniformShuffleUp: OpGroupNU4<"ShuffleUp", 347>; +def OpGroupNonUniformShuffleDown: OpGroupNU4<"ShuffleDown", 348>; +// OpGroupNUGroup: result type, scope, GroupOperation and value; NOTE(review): +// the trailing variable_ops presumably carry the optional ClusterSize operand. 
+class OpGroupNUGroup<string name, bits<16> opCode>: Op<opCode, (outs ID:$res), + (ins TYPE:$ty, ID:$scope, GroupOperation:$groupOp, + ID:$val, variable_ops), + "$res = OpGroupNonUniform"#name#" $ty $scope $groupOp $val">; +def OpGroupNonUniformIAdd: OpGroupNUGroup<"IAdd", 349>; +def OpGroupNonUniformFAdd: OpGroupNUGroup<"FAdd", 350>; +def OpGroupNonUniformIMul: OpGroupNUGroup<"IMul", 351>; 
+def OpGroupNonUniformFMul: OpGroupNUGroup<"FMul", 352>; +def OpGroupNonUniformSMin: OpGroupNUGroup<"SMin", 353>; +def OpGroupNonUniformUMin: OpGroupNUGroup<"UMin", 354>; +def OpGroupNonUniformFMin: OpGroupNUGroup<"FMin", 355>; +def OpGroupNonUniformSMax: OpGroupNUGroup<"SMax", 356>; +def OpGroupNonUniformUMax: OpGroupNUGroup<"UMax", 357>; +def OpGroupNonUniformFMax: OpGroupNUGroup<"FMax", 358>; +def OpGroupNonUniformBitwiseAnd: OpGroupNUGroup<"BitwiseAnd", 359>; +def OpGroupNonUniformBitwiseOr: OpGroupNUGroup<"BitwiseOr", 360>; +def OpGroupNonUniformBitwiseXor: OpGroupNUGroup<"BitwiseXor", 361>; +def OpGroupNonUniformLogicalAnd: OpGroupNUGroup<"LogicalAnd", 362>; +def OpGroupNonUniformLogicalOr: OpGroupNUGroup<"LogicalOr", 363>; +def OpGroupNonUniformLogicalXor: OpGroupNUGroup<"LogicalXor", 364>; diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp new file mode 100644 index 000000000000..9294a60506a8 --- /dev/null +++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp @@ -0,0 +1,1268 @@ +//===- SPIRVInstructionSelector.cpp ------------------------------*- C++ -*-==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the targeting of the InstructionSelector class for +// SPIRV. +// TODO: This should be generated by TableGen. 
+// +//===----------------------------------------------------------------------===// + +#include "SPIRV.h" +#include "SPIRVGlobalRegistry.h" +#include "SPIRVInstrInfo.h" +#include "SPIRVRegisterBankInfo.h" +#include "SPIRVRegisterInfo.h" +#include "SPIRVTargetMachine.h" +#include "SPIRVUtils.h" +#include "llvm/ADT/APFloat.h" +#include "llvm/CodeGen/GlobalISel/InstructionSelector.h" +#include "llvm/CodeGen/GlobalISel/InstructionSelectorImpl.h" +#include "llvm/CodeGen/MachineInstrBuilder.h" +#include "llvm/CodeGen/MachineRegisterInfo.h" +#include "llvm/IR/IntrinsicsSPIRV.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "spirv-isel" + +using namespace llvm; + +namespace { + +#define GET_GLOBALISEL_PREDICATE_BITSET +#include "SPIRVGenGlobalISel.inc" +#undef GET_GLOBALISEL_PREDICATE_BITSET + +// GlobalISel instruction selector for SPIR-V. select() does the bookkeeping +// common to every instruction; spvSelect() switches on the generic opcode and +// calls the select* helpers declared below. +class SPIRVInstructionSelector : public InstructionSelector { + // Set in the constructor: RBI is passed in, the others come from the + // subtarget. + const SPIRVSubtarget &STI; + const SPIRVInstrInfo &TII; + const SPIRVRegisterInfo &TRI; + const RegisterBankInfo &RBI; + SPIRVGlobalRegistry &GR; + // Re-pointed at the current function's register info in setupMF(). + MachineRegisterInfo *MRI; + +public: + SPIRVInstructionSelector(const SPIRVTargetMachine &TM, + const SPIRVSubtarget &ST, + const RegisterBankInfo &RBI); + void setupMF(MachineFunction &MF, GISelKnownBits *KB, + CodeGenCoverage &CoverageInfo, ProfileSummaryInfo *PSI, + BlockFrequencyInfo *BFI) override; + // Common selection code. Instruction-specific selection occurs in spvSelect. 
+ bool select(MachineInstr &I) override; + static const char *getName() { return DEBUG_TYPE; } + +#define GET_GLOBALISEL_PREDICATES_DECL +#include "SPIRVGenGlobalISel.inc" +#undef GET_GLOBALISEL_PREDICATES_DECL + +#define GET_GLOBALISEL_TEMPORARIES_DECL +#include "SPIRVGenGlobalISel.inc" +#undef GET_GLOBALISEL_TEMPORARIES_DECL + +private: + // tblgen-erated 'select' implementation, used as the initial selector for + // the patterns that don't require complex C++.
+ bool selectImpl(MachineInstr &I, CodeGenCoverage &CoverageInfo) const; + + // All instruction-specific selection that didn't happen in "select()". + // Is basically a large Switch/Case delegating to all other select method. + bool spvSelect(Register ResVReg, const SPIRVType *ResType, + MachineInstr &I) const; + + bool selectGlobalValue(Register ResVReg, MachineInstr &I, + const MachineInstr *Init = nullptr) const; + + bool selectUnOpWithSrc(Register ResVReg, const SPIRVType *ResType, + MachineInstr &I, Register SrcReg, + unsigned Opcode) const; + bool selectUnOp(Register ResVReg, const SPIRVType *ResType, MachineInstr &I, + unsigned Opcode) const; + + bool selectLoad(Register ResVReg, const SPIRVType *ResType, + MachineInstr &I) const; + bool selectStore(MachineInstr &I) const; + + bool selectMemOperation(Register ResVReg, MachineInstr &I) const; + + bool selectAtomicRMW(Register ResVReg, const SPIRVType *ResType, + MachineInstr &I, unsigned NewOpcode) const; + + bool selectAtomicCmpXchg(Register ResVReg, const SPIRVType *ResType, + MachineInstr &I) const; + + bool selectFence(MachineInstr &I) const; + + bool selectAddrSpaceCast(Register ResVReg, const SPIRVType *ResType, + MachineInstr &I) const; + + bool selectBitreverse(Register ResVReg, const SPIRVType *ResType, + MachineInstr &I) const; + + bool selectConstVector(Register ResVReg, const SPIRVType *ResType, + MachineInstr &I) const; + + bool selectCmp(Register ResVReg, const SPIRVType *ResType, + unsigned comparisonOpcode, MachineInstr &I) const; + + bool selectICmp(Register ResVReg, const SPIRVType *ResType, + MachineInstr &I) const; + bool selectFCmp(Register ResVReg, const SPIRVType *ResType, + MachineInstr &I) const; + + // Custom operand renderers, referenced by name from the + // GICustomOperandRenderer records in the TableGen instruction definitions. 
+ void renderImm32(MachineInstrBuilder &MIB, const MachineInstr &I, + int OpIdx) const; + void renderFImm32(MachineInstrBuilder &MIB, const MachineInstr &I, + int OpIdx) const; + + bool selectConst(Register ResVReg, const SPIRVType *ResType, const APInt &Imm, + MachineInstr &I) const; + 
bool selectSelect(Register ResVReg, const SPIRVType *ResType, MachineInstr &I, + bool IsSigned) const; + bool selectIToF(Register ResVReg, const SPIRVType *ResType, MachineInstr &I, + bool IsSigned, unsigned Opcode) const; + bool selectExt(Register ResVReg, const SPIRVType *ResType, MachineInstr &I, + bool IsSigned) const; + + bool selectTrunc(Register ResVReg, const SPIRVType *ResType, + MachineInstr &I) const; + + bool selectIntToBool(Register IntReg, Register ResVReg, + const SPIRVType *intTy, const SPIRVType *boolTy, + MachineInstr &I) const; + + bool selectOpUndef(Register ResVReg, const SPIRVType *ResType, + MachineInstr &I) const; + bool selectIntrinsic(Register ResVReg, const SPIRVType *ResType, + MachineInstr &I) const; + bool selectExtractVal(Register ResVReg, const SPIRVType *ResType, + MachineInstr &I) const; + bool selectInsertVal(Register ResVReg, const SPIRVType *ResType, + MachineInstr &I) const; + bool selectExtractElt(Register ResVReg, const SPIRVType *ResType, + MachineInstr &I) const; + bool selectInsertElt(Register ResVReg, const SPIRVType *ResType, + MachineInstr &I) const; + bool selectGEP(Register ResVReg, const SPIRVType *ResType, + MachineInstr &I) const; + + bool selectFrameIndex(Register ResVReg, const SPIRVType *ResType, + MachineInstr &I) const; + + bool selectBranch(MachineInstr &I) const; + bool selectBranchCond(MachineInstr &I) const; + + bool selectPhi(Register ResVReg, const SPIRVType *ResType, + MachineInstr &I) const; + + // Constant-materialization helpers. NOTE(review): each presumably returns the + // vreg defining the constant; their definitions are outside this view. + Register buildI32Constant(uint32_t Val, MachineInstr &I, + const SPIRVType *ResType = nullptr) const; + + Register buildZerosVal(const SPIRVType *ResType, MachineInstr &I) const; + Register buildOnesVal(bool AllOnes, const SPIRVType *ResType, + MachineInstr &I) const; +}; + +} // end anonymous namespace + +#define GET_GLOBALISEL_IMPL +#include 
"SPIRVGenGlobalISel.inc" +#undef GET_GLOBALISEL_IMPL + +SPIRVInstructionSelector::SPIRVInstructionSelector(const SPIRVTargetMachine &TM, + const SPIRVSubtarget &ST, + 
const RegisterBankInfo &RBI) + : InstructionSelector(), STI(ST), TII(*ST.getInstrInfo()), + TRI(*ST.getRegisterInfo()), RBI(RBI), GR(*ST.getSPIRVGlobalRegistry()), +#define GET_GLOBALISEL_PREDICATES_INIT +#include "SPIRVGenGlobalISel.inc" +#undef GET_GLOBALISEL_PREDICATES_INIT +#define GET_GLOBALISEL_TEMPORARIES_INIT +#include "SPIRVGenGlobalISel.inc" +#undef GET_GLOBALISEL_TEMPORARIES_INIT +{ +} + +void SPIRVInstructionSelector::setupMF(MachineFunction &MF, GISelKnownBits *KB, + CodeGenCoverage &CoverageInfo, + ProfileSummaryInfo *PSI, + BlockFrequencyInfo *BFI) { + MRI = &MF.getRegInfo(); + GR.setCurrentFunc(MF); + InstructionSelector::setupMF(MF, KB, CoverageInfo, PSI, BFI); +} + +// Defined in SPIRVLegalizerInfo.cpp. +extern bool isTypeFoldingSupported(unsigned Opcode); + +bool SPIRVInstructionSelector::select(MachineInstr &I) { + assert(I.getParent() && "Instruction should be in a basic block!"); + assert(I.getParent()->getParent() && "Instruction should be in a function!"); + + Register Opcode = I.getOpcode(); + // If it's not a GMIR instruction, we've selected it already. + if (!isPreISelGenericOpcode(Opcode)) { + if (Opcode == SPIRV::ASSIGN_TYPE) { // These pseudos aren't needed any more. + auto *Def = MRI->getVRegDef(I.getOperand(1).getReg()); + if (isTypeFoldingSupported(Def->getOpcode())) { + auto Res = selectImpl(I, *CoverageInfo); + assert(Res || Def->getOpcode() == TargetOpcode::G_CONSTANT); + if (Res) + return Res; + } + MRI->replaceRegWith(I.getOperand(1).getReg(), I.getOperand(0).getReg()); + I.removeFromParent(); + } else if (I.getNumDefs() == 1) { + // Make all vregs 32 bits (for SPIR-V IDs). + MRI->setType(I.getOperand(0).getReg(), LLT::scalar(32)); + } + return true; + } + + if (I.getNumOperands() != I.getNumExplicitOperands()) { + LLVM_DEBUG(errs() << "Generic instr has unexpected implicit operands\n"); + return false; + } + + // Common code for getting return reg+type, and removing selected instr + // from parent occurs here. 
Instr-specific selection happens in spvSelect(). + bool HasDefs = I.getNumDefs() > 0; + Register ResVReg = HasDefs ? I.getOperand(0).getReg() : Register(0); + SPIRVType *ResType = HasDefs ? GR.getSPIRVTypeForVReg(ResVReg) : nullptr; + assert(!HasDefs || ResType || I.getOpcode() == TargetOpcode::G_GLOBAL_VALUE); + if (spvSelect(ResVReg, ResType, I)) { + if (HasDefs) // Make all vregs 32 bits (for SPIR-V IDs). + MRI->setType(ResVReg, LLT::scalar(32)); + I.removeFromParent(); + return true; + } + return false; +} + +bool SPIRVInstructionSelector::spvSelect(Register ResVReg, + const SPIRVType *ResType, + MachineInstr &I) const { + assert(!isTypeFoldingSupported(I.getOpcode()) || + I.getOpcode() == TargetOpcode::G_CONSTANT); + const unsigned Opcode = I.getOpcode(); + switch (Opcode) { + case TargetOpcode::G_CONSTANT: + return selectConst(ResVReg, ResType, I.getOperand(1).getCImm()->getValue(), + I); + case TargetOpcode::G_GLOBAL_VALUE: + return selectGlobalValue(ResVReg, I); + case TargetOpcode::G_IMPLICIT_DEF: + return selectOpUndef(ResVReg, ResType, I); + + case TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS: + return selectIntrinsic(ResVReg, ResType, I); + case TargetOpcode::G_BITREVERSE: + return selectBitreverse(ResVReg, ResType, I); + + case TargetOpcode::G_BUILD_VECTOR: + return selectConstVector(ResVReg, ResType, I); + + case TargetOpcode::G_SHUFFLE_VECTOR: { + MachineBasicBlock &BB = *I.getParent(); + auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpVectorShuffle)) + .addDef(ResVReg) + .addUse(GR.getSPIRVTypeID(ResType)) + .addUse(I.getOperand(1).getReg()) + .addUse(I.getOperand(2).getReg()); + for (auto V : I.getOperand(3).getShuffleMask()) + MIB.addImm(V); + return MIB.constrainAllUses(TII, TRI, RBI); + } + case TargetOpcode::G_MEMMOVE: + case TargetOpcode::G_MEMCPY: + return selectMemOperation(ResVReg, I); + + case TargetOpcode::G_ICMP: + return selectICmp(ResVReg, ResType, I); + case TargetOpcode::G_FCMP: + return selectFCmp(ResVReg, ResType, I); + + 
case TargetOpcode::G_FRAME_INDEX: + return selectFrameIndex(ResVReg, ResType, I); + + case TargetOpcode::G_LOAD: + return selectLoad(ResVReg, ResType, I); + case TargetOpcode::G_STORE: + return selectStore(I); + + case TargetOpcode::G_BR: + return selectBranch(I); + case TargetOpcode::G_BRCOND: + return selectBranchCond(I); + + case TargetOpcode::G_PHI: + return selectPhi(ResVReg, ResType, I); + + case TargetOpcode::G_FPTOSI: + return selectUnOp(ResVReg, ResType, I, SPIRV::OpConvertFToS); + case TargetOpcode::G_FPTOUI: + return selectUnOp(ResVReg, ResType, I, SPIRV::OpConvertFToU); + + case TargetOpcode::G_SITOFP: + return selectIToF(ResVReg, ResType, I, true, SPIRV::OpConvertSToF); + case TargetOpcode::G_UITOFP: + return selectIToF(ResVReg, ResType, I, false, SPIRV::OpConvertUToF); + + case TargetOpcode::G_CTPOP: + return selectUnOp(ResVReg, ResType, I, SPIRV::OpBitCount); + + case TargetOpcode::G_SEXT: + return selectExt(ResVReg, ResType, I, true); + case TargetOpcode::G_ANYEXT: + case TargetOpcode::G_ZEXT: + return selectExt(ResVReg, ResType, I, false); + case TargetOpcode::G_TRUNC: + return selectTrunc(ResVReg, ResType, I); + case TargetOpcode::G_FPTRUNC: + case TargetOpcode::G_FPEXT: + return selectUnOp(ResVReg, ResType, I, SPIRV::OpFConvert); + + case TargetOpcode::G_PTRTOINT: + return selectUnOp(ResVReg, ResType, I, SPIRV::OpConvertPtrToU); + case TargetOpcode::G_INTTOPTR: + return selectUnOp(ResVReg, ResType, I, SPIRV::OpConvertUToPtr); + case TargetOpcode::G_BITCAST: + return selectUnOp(ResVReg, ResType, I, SPIRV::OpBitcast); + case TargetOpcode::G_ADDRSPACE_CAST: + return selectAddrSpaceCast(ResVReg, ResType, I); + + case TargetOpcode::G_ATOMICRMW_OR: + return selectAtomicRMW(ResVReg, ResType, I, SPIRV::OpAtomicOr); + case TargetOpcode::G_ATOMICRMW_ADD: + return selectAtomicRMW(ResVReg, ResType, I, SPIRV::OpAtomicIAdd); + case TargetOpcode::G_ATOMICRMW_AND: + return selectAtomicRMW(ResVReg, ResType, I, SPIRV::OpAtomicAnd); + case 
TargetOpcode::G_ATOMICRMW_MAX: + return selectAtomicRMW(ResVReg, ResType, I, SPIRV::OpAtomicSMax); + case TargetOpcode::G_ATOMICRMW_MIN: + return selectAtomicRMW(ResVReg, ResType, I, SPIRV::OpAtomicSMin); + case TargetOpcode::G_ATOMICRMW_SUB: + return selectAtomicRMW(ResVReg, ResType, I, SPIRV::OpAtomicISub); + case TargetOpcode::G_ATOMICRMW_XOR: + return selectAtomicRMW(ResVReg, ResType, I, SPIRV::OpAtomicXor); + case TargetOpcode::G_ATOMICRMW_UMAX: + return selectAtomicRMW(ResVReg, ResType, I, SPIRV::OpAtomicUMax); + case TargetOpcode::G_ATOMICRMW_UMIN: + return selectAtomicRMW(ResVReg, ResType, I, SPIRV::OpAtomicUMin); + case TargetOpcode::G_ATOMICRMW_XCHG: + return selectAtomicRMW(ResVReg, ResType, I, SPIRV::OpAtomicExchange); + case TargetOpcode::G_ATOMIC_CMPXCHG: + return selectAtomicCmpXchg(ResVReg, ResType, I); + + case TargetOpcode::G_FENCE: + return selectFence(I); + + default: + return false; + } +} + +bool SPIRVInstructionSelector::selectUnOpWithSrc(Register ResVReg, + const SPIRVType *ResType, + MachineInstr &I, + Register SrcReg, + unsigned Opcode) const { + return BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(Opcode)) + .addDef(ResVReg) + .addUse(GR.getSPIRVTypeID(ResType)) + .addUse(SrcReg) + .constrainAllUses(TII, TRI, RBI); +} + +bool SPIRVInstructionSelector::selectUnOp(Register ResVReg, + const SPIRVType *ResType, + MachineInstr &I, + unsigned Opcode) const { + return selectUnOpWithSrc(ResVReg, ResType, I, I.getOperand(1).getReg(), + Opcode); +} + +static SPIRV::MemorySemantics getMemSemantics(AtomicOrdering Ord) { + switch (Ord) { + case AtomicOrdering::Acquire: + return SPIRV::MemorySemantics::Acquire; + case AtomicOrdering::Release: + return SPIRV::MemorySemantics::Release; + case AtomicOrdering::AcquireRelease: + return SPIRV::MemorySemantics::AcquireRelease; + case AtomicOrdering::SequentiallyConsistent: + return SPIRV::MemorySemantics::SequentiallyConsistent; + case AtomicOrdering::Unordered: + case AtomicOrdering::Monotonic: + case 
AtomicOrdering::NotAtomic: + return SPIRV::MemorySemantics::None; + } +} + +static SPIRV::Scope getScope(SyncScope::ID Ord) { + switch (Ord) { + case SyncScope::SingleThread: + return SPIRV::Scope::Invocation; + case SyncScope::System: + return SPIRV::Scope::Device; + default: + llvm_unreachable("Unsupported synchronization Scope ID."); + } +} + +static void addMemoryOperands(MachineMemOperand *MemOp, + MachineInstrBuilder &MIB) { + uint32_t SpvMemOp = static_cast<uint32_t>(SPIRV::MemoryOperand::None); + if (MemOp->isVolatile()) + SpvMemOp |= static_cast<uint32_t>(SPIRV::MemoryOperand::Volatile); + if (MemOp->isNonTemporal()) + SpvMemOp |= static_cast<uint32_t>(SPIRV::MemoryOperand::Nontemporal); + if (MemOp->getAlign().value()) + SpvMemOp |= static_cast<uint32_t>(SPIRV::MemoryOperand::Aligned); + + if (SpvMemOp != static_cast<uint32_t>(SPIRV::MemoryOperand::None)) { + MIB.addImm(SpvMemOp); + if (SpvMemOp & static_cast<uint32_t>(SPIRV::MemoryOperand::Aligned)) + MIB.addImm(MemOp->getAlign().value()); + } +} + +static void addMemoryOperands(uint64_t Flags, MachineInstrBuilder &MIB) { + uint32_t SpvMemOp = static_cast<uint32_t>(SPIRV::MemoryOperand::None); + if (Flags & MachineMemOperand::Flags::MOVolatile) + SpvMemOp |= static_cast<uint32_t>(SPIRV::MemoryOperand::Volatile); + if (Flags & MachineMemOperand::Flags::MONonTemporal) + SpvMemOp |= static_cast<uint32_t>(SPIRV::MemoryOperand::Nontemporal); + + if (SpvMemOp != static_cast<uint32_t>(SPIRV::MemoryOperand::None)) + MIB.addImm(SpvMemOp); +} + +bool SPIRVInstructionSelector::selectLoad(Register ResVReg, + const SPIRVType *ResType, + MachineInstr &I) const { + unsigned OpOffset = + I.getOpcode() == TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS ? 
1 : 0; + Register Ptr = I.getOperand(1 + OpOffset).getReg(); + auto MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(SPIRV::OpLoad)) + .addDef(ResVReg) + .addUse(GR.getSPIRVTypeID(ResType)) + .addUse(Ptr); + if (!I.getNumMemOperands()) { + assert(I.getOpcode() == TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS); + addMemoryOperands(I.getOperand(2 + OpOffset).getImm(), MIB); + } else { + addMemoryOperands(*I.memoperands_begin(), MIB); + } + return MIB.constrainAllUses(TII, TRI, RBI); +} + +bool SPIRVInstructionSelector::selectStore(MachineInstr &I) const { + unsigned OpOffset = + I.getOpcode() == TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS ? 1 : 0; + Register StoreVal = I.getOperand(0 + OpOffset).getReg(); + Register Ptr = I.getOperand(1 + OpOffset).getReg(); + MachineBasicBlock &BB = *I.getParent(); + auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpStore)) + .addUse(Ptr) + .addUse(StoreVal); + if (!I.getNumMemOperands()) { + assert(I.getOpcode() == TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS); + addMemoryOperands(I.getOperand(2 + OpOffset).getImm(), MIB); + } else { + addMemoryOperands(*I.memoperands_begin(), MIB); + } + return MIB.constrainAllUses(TII, TRI, RBI); +} + +bool SPIRVInstructionSelector::selectMemOperation(Register ResVReg, + MachineInstr &I) const { + MachineBasicBlock &BB = *I.getParent(); + auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpCopyMemorySized)) + .addDef(I.getOperand(0).getReg()) + .addUse(I.getOperand(1).getReg()) + .addUse(I.getOperand(2).getReg()); + if (I.getNumMemOperands()) + addMemoryOperands(*I.memoperands_begin(), MIB); + bool Result = MIB.constrainAllUses(TII, TRI, RBI); + if (ResVReg.isValid() && ResVReg != MIB->getOperand(0).getReg()) { + BuildMI(BB, I, I.getDebugLoc(), TII.get(TargetOpcode::COPY), ResVReg) + .addUse(MIB->getOperand(0).getReg()); + } + return Result; +} + +bool SPIRVInstructionSelector::selectAtomicRMW(Register ResVReg, + const SPIRVType *ResType, + MachineInstr &I, + unsigned NewOpcode) 
const { + assert(I.hasOneMemOperand()); + const MachineMemOperand *MemOp = *I.memoperands_begin(); + uint32_t Scope = static_cast<uint32_t>(getScope(MemOp->getSyncScopeID())); + Register ScopeReg = buildI32Constant(Scope, I); + + Register Ptr = I.getOperand(1).getReg(); + // TODO: Changed as it's implemented in the translator. See test/atomicrmw.ll + // auto ScSem = + // getMemSemanticsForStorageClass(GR.getPointerStorageClass(Ptr)); + AtomicOrdering AO = MemOp->getSuccessOrdering(); + uint32_t MemSem = static_cast<uint32_t>(getMemSemantics(AO)); + Register MemSemReg = buildI32Constant(MemSem /*| ScSem*/, I); + + return BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(NewOpcode)) + .addDef(ResVReg) + .addUse(GR.getSPIRVTypeID(ResType)) + .addUse(Ptr) + .addUse(ScopeReg) + .addUse(MemSemReg) + .addUse(I.getOperand(2).getReg()) + .constrainAllUses(TII, TRI, RBI); +} + +bool SPIRVInstructionSelector::selectFence(MachineInstr &I) const { + AtomicOrdering AO = AtomicOrdering(I.getOperand(0).getImm()); + uint32_t MemSem = static_cast<uint32_t>(getMemSemantics(AO)); + Register MemSemReg = buildI32Constant(MemSem, I); + SyncScope::ID Ord = SyncScope::ID(I.getOperand(1).getImm()); + uint32_t Scope = static_cast<uint32_t>(getScope(Ord)); + Register ScopeReg = buildI32Constant(Scope, I); + MachineBasicBlock &BB = *I.getParent(); + return BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpMemoryBarrier)) + .addUse(ScopeReg) + .addUse(MemSemReg) + .constrainAllUses(TII, TRI, RBI); +} + +bool SPIRVInstructionSelector::selectAtomicCmpXchg(Register ResVReg, + const SPIRVType *ResType, + MachineInstr &I) const { + assert(I.hasOneMemOperand()); + const MachineMemOperand *MemOp = *I.memoperands_begin(); + uint32_t Scope = static_cast<uint32_t>(getScope(MemOp->getSyncScopeID())); + Register ScopeReg = buildI32Constant(Scope, I); + + Register Ptr = I.getOperand(2).getReg(); + Register Cmp = I.getOperand(3).getReg(); + Register Val = I.getOperand(4).getReg(); + + SPIRVType *SpvValTy = 
GR.getSPIRVTypeForVReg(Val); + SPIRV::StorageClass SC = GR.getPointerStorageClass(Ptr); + uint32_t ScSem = static_cast<uint32_t>(getMemSemanticsForStorageClass(SC)); + AtomicOrdering AO = MemOp->getSuccessOrdering(); + uint32_t MemSemEq = static_cast<uint32_t>(getMemSemantics(AO)) | ScSem; + Register MemSemEqReg = buildI32Constant(MemSemEq, I); + AtomicOrdering FO = MemOp->getFailureOrdering(); + uint32_t MemSemNeq = static_cast<uint32_t>(getMemSemantics(FO)) | ScSem; + Register MemSemNeqReg = + MemSemEq == MemSemNeq ? MemSemEqReg : buildI32Constant(MemSemNeq, I); + const DebugLoc &DL = I.getDebugLoc(); + return BuildMI(*I.getParent(), I, DL, TII.get(SPIRV::OpAtomicCompareExchange)) + .addDef(ResVReg) + .addUse(GR.getSPIRVTypeID(SpvValTy)) + .addUse(Ptr) + .addUse(ScopeReg) + .addUse(MemSemEqReg) + .addUse(MemSemNeqReg) + .addUse(Val) + .addUse(Cmp) + .constrainAllUses(TII, TRI, RBI); +} + +static bool isGenericCastablePtr(SPIRV::StorageClass SC) { + switch (SC) { + case SPIRV::StorageClass::Workgroup: + case SPIRV::StorageClass::CrossWorkgroup: + case SPIRV::StorageClass::Function: + return true; + default: + return false; + } +} + +// In SPIR-V address space casting can only happen to and from the Generic +// storage class. We can also only case Workgroup, CrossWorkgroup, or Function +// pointers to and from Generic pointers. As such, we can convert e.g. from +// Workgroup to Function by going via a Generic pointer as an intermediary. All +// other combinations can only be done by a bitcast, and are probably not safe. +bool SPIRVInstructionSelector::selectAddrSpaceCast(Register ResVReg, + const SPIRVType *ResType, + MachineInstr &I) const { + Register SrcPtr = I.getOperand(1).getReg(); + SPIRVType *SrcPtrTy = GR.getSPIRVTypeForVReg(SrcPtr); + SPIRV::StorageClass SrcSC = GR.getPointerStorageClass(SrcPtr); + SPIRV::StorageClass DstSC = GR.getPointerStorageClass(ResVReg); + + // Casting from an eligable pointer to Generic. 
+ if (DstSC == SPIRV::StorageClass::Generic && isGenericCastablePtr(SrcSC)) + return selectUnOp(ResVReg, ResType, I, SPIRV::OpPtrCastToGeneric); + // Casting from Generic to an eligable pointer. + if (SrcSC == SPIRV::StorageClass::Generic && isGenericCastablePtr(DstSC)) + return selectUnOp(ResVReg, ResType, I, SPIRV::OpGenericCastToPtr); + // Casting between 2 eligable pointers using Generic as an intermediary. + if (isGenericCastablePtr(SrcSC) && isGenericCastablePtr(DstSC)) { + Register Tmp = MRI->createVirtualRegister(&SPIRV::IDRegClass); + SPIRVType *GenericPtrTy = GR.getOrCreateSPIRVPointerType( + SrcPtrTy, I, TII, SPIRV::StorageClass::Generic); + MachineBasicBlock &BB = *I.getParent(); + const DebugLoc &DL = I.getDebugLoc(); + bool Success = BuildMI(BB, I, DL, TII.get(SPIRV::OpPtrCastToGeneric)) + .addDef(Tmp) + .addUse(GR.getSPIRVTypeID(GenericPtrTy)) + .addUse(SrcPtr) + .constrainAllUses(TII, TRI, RBI); + return Success && BuildMI(BB, I, DL, TII.get(SPIRV::OpGenericCastToPtr)) + .addDef(ResVReg) + .addUse(GR.getSPIRVTypeID(ResType)) + .addUse(Tmp) + .constrainAllUses(TII, TRI, RBI); + } + // TODO Should this case just be disallowed completely? + // We're casting 2 other arbitrary address spaces, so have to bitcast. 
+ return selectUnOp(ResVReg, ResType, I, SPIRV::OpBitcast); +} + +static unsigned getFCmpOpcode(unsigned PredNum) { + auto Pred = static_cast<CmpInst::Predicate>(PredNum); + switch (Pred) { + case CmpInst::FCMP_OEQ: + return SPIRV::OpFOrdEqual; + case CmpInst::FCMP_OGE: + return SPIRV::OpFOrdGreaterThanEqual; + case CmpInst::FCMP_OGT: + return SPIRV::OpFOrdGreaterThan; + case CmpInst::FCMP_OLE: + return SPIRV::OpFOrdLessThanEqual; + case CmpInst::FCMP_OLT: + return SPIRV::OpFOrdLessThan; + case CmpInst::FCMP_ONE: + return SPIRV::OpFOrdNotEqual; + case CmpInst::FCMP_ORD: + return SPIRV::OpOrdered; + case CmpInst::FCMP_UEQ: + return SPIRV::OpFUnordEqual; + case CmpInst::FCMP_UGE: + return SPIRV::OpFUnordGreaterThanEqual; + case CmpInst::FCMP_UGT: + return SPIRV::OpFUnordGreaterThan; + case CmpInst::FCMP_ULE: + return SPIRV::OpFUnordLessThanEqual; + case CmpInst::FCMP_ULT: + return SPIRV::OpFUnordLessThan; + case CmpInst::FCMP_UNE: + return SPIRV::OpFUnordNotEqual; + case CmpInst::FCMP_UNO: + return SPIRV::OpUnordered; + default: + llvm_unreachable("Unknown predicate type for FCmp"); + } +} + +static unsigned getICmpOpcode(unsigned PredNum) { + auto Pred = static_cast<CmpInst::Predicate>(PredNum); + switch (Pred) { + case CmpInst::ICMP_EQ: + return SPIRV::OpIEqual; + case CmpInst::ICMP_NE: + return SPIRV::OpINotEqual; + case CmpInst::ICMP_SGE: + return SPIRV::OpSGreaterThanEqual; + case CmpInst::ICMP_SGT: + return SPIRV::OpSGreaterThan; + case CmpInst::ICMP_SLE: + return SPIRV::OpSLessThanEqual; + case CmpInst::ICMP_SLT: + return SPIRV::OpSLessThan; + case CmpInst::ICMP_UGE: + return SPIRV::OpUGreaterThanEqual; + case CmpInst::ICMP_UGT: + return SPIRV::OpUGreaterThan; + case CmpInst::ICMP_ULE: + return SPIRV::OpULessThanEqual; + case CmpInst::ICMP_ULT: + return SPIRV::OpULessThan; + default: + llvm_unreachable("Unknown predicate type for ICmp"); + } +} + +static unsigned getPtrCmpOpcode(unsigned Pred) { + switch (static_cast<CmpInst::Predicate>(Pred)) { + case 
CmpInst::ICMP_EQ: + return SPIRV::OpPtrEqual; + case CmpInst::ICMP_NE: + return SPIRV::OpPtrNotEqual; + default: + llvm_unreachable("Unknown predicate type for pointer comparison"); + } +} + +// Return the logical operation, or abort if none exists. +static unsigned getBoolCmpOpcode(unsigned PredNum) { + auto Pred = static_cast<CmpInst::Predicate>(PredNum); + switch (Pred) { + case CmpInst::ICMP_EQ: + return SPIRV::OpLogicalEqual; + case CmpInst::ICMP_NE: + return SPIRV::OpLogicalNotEqual; + default: + llvm_unreachable("Unknown predicate type for Bool comparison"); + } +} + +bool SPIRVInstructionSelector::selectBitreverse(Register ResVReg, + const SPIRVType *ResType, + MachineInstr &I) const { + MachineBasicBlock &BB = *I.getParent(); + return BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpBitReverse)) + .addDef(ResVReg) + .addUse(GR.getSPIRVTypeID(ResType)) + .addUse(I.getOperand(1).getReg()) + .constrainAllUses(TII, TRI, RBI); +} + +bool SPIRVInstructionSelector::selectConstVector(Register ResVReg, + const SPIRVType *ResType, + MachineInstr &I) const { + // TODO: only const case is supported for now. 
+ assert(std::all_of( + I.operands_begin(), I.operands_end(), [this](const MachineOperand &MO) { + if (MO.isDef()) + return true; + if (!MO.isReg()) + return false; + SPIRVType *ConstTy = this->MRI->getVRegDef(MO.getReg()); + assert(ConstTy && ConstTy->getOpcode() == SPIRV::ASSIGN_TYPE && + ConstTy->getOperand(1).isReg()); + Register ConstReg = ConstTy->getOperand(1).getReg(); + const MachineInstr *Const = this->MRI->getVRegDef(ConstReg); + assert(Const); + return (Const->getOpcode() == TargetOpcode::G_CONSTANT || + Const->getOpcode() == TargetOpcode::G_FCONSTANT); + })); + + auto MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(), + TII.get(SPIRV::OpConstantComposite)) + .addDef(ResVReg) + .addUse(GR.getSPIRVTypeID(ResType)); + for (unsigned i = I.getNumExplicitDefs(); i < I.getNumExplicitOperands(); ++i) + MIB.addUse(I.getOperand(i).getReg()); + return MIB.constrainAllUses(TII, TRI, RBI); +} + +bool SPIRVInstructionSelector::selectCmp(Register ResVReg, + const SPIRVType *ResType, + unsigned CmpOpc, + MachineInstr &I) const { + Register Cmp0 = I.getOperand(2).getReg(); + Register Cmp1 = I.getOperand(3).getReg(); + assert(GR.getSPIRVTypeForVReg(Cmp0)->getOpcode() == + GR.getSPIRVTypeForVReg(Cmp1)->getOpcode() && + "CMP operands should have the same type"); + return BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(CmpOpc)) + .addDef(ResVReg) + .addUse(GR.getSPIRVTypeID(ResType)) + .addUse(Cmp0) + .addUse(Cmp1) + .constrainAllUses(TII, TRI, RBI); +} + +bool SPIRVInstructionSelector::selectICmp(Register ResVReg, + const SPIRVType *ResType, + MachineInstr &I) const { + auto Pred = I.getOperand(1).getPredicate(); + unsigned CmpOpc; + + Register CmpOperand = I.getOperand(2).getReg(); + if (GR.isScalarOfType(CmpOperand, SPIRV::OpTypePointer)) + CmpOpc = getPtrCmpOpcode(Pred); + else if (GR.isScalarOrVectorOfType(CmpOperand, SPIRV::OpTypeBool)) + CmpOpc = getBoolCmpOpcode(Pred); + else + CmpOpc = getICmpOpcode(Pred); + return selectCmp(ResVReg, ResType, CmpOpc, I); +} + 
+void SPIRVInstructionSelector::renderFImm32(MachineInstrBuilder &MIB, + const MachineInstr &I, + int OpIdx) const { + assert(I.getOpcode() == TargetOpcode::G_FCONSTANT && OpIdx == -1 && + "Expected G_FCONSTANT"); + const ConstantFP *FPImm = I.getOperand(1).getFPImm(); + addNumImm(FPImm->getValueAPF().bitcastToAPInt(), MIB); +} + +void SPIRVInstructionSelector::renderImm32(MachineInstrBuilder &MIB, + const MachineInstr &I, + int OpIdx) const { + assert(I.getOpcode() == TargetOpcode::G_CONSTANT && OpIdx == -1 && + "Expected G_CONSTANT"); + addNumImm(I.getOperand(1).getCImm()->getValue(), MIB); +} + +Register +SPIRVInstructionSelector::buildI32Constant(uint32_t Val, MachineInstr &I, + const SPIRVType *ResType) const { + const SPIRVType *SpvI32Ty = + ResType ? ResType : GR.getOrCreateSPIRVIntegerType(32, I, TII); + Register NewReg; + NewReg = MRI->createGenericVirtualRegister(LLT::scalar(32)); + MachineInstr *MI; + MachineBasicBlock &BB = *I.getParent(); + if (Val == 0) { + MI = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantNull)) + .addDef(NewReg) + .addUse(GR.getSPIRVTypeID(SpvI32Ty)); + } else { + MI = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantI)) + .addDef(NewReg) + .addUse(GR.getSPIRVTypeID(SpvI32Ty)) + .addImm(APInt(32, Val).getZExtValue()); + } + constrainSelectedInstRegOperands(*MI, TII, TRI, RBI); + return NewReg; +} + +bool SPIRVInstructionSelector::selectFCmp(Register ResVReg, + const SPIRVType *ResType, + MachineInstr &I) const { + unsigned CmpOp = getFCmpOpcode(I.getOperand(1).getPredicate()); + return selectCmp(ResVReg, ResType, CmpOp, I); +} + +Register SPIRVInstructionSelector::buildZerosVal(const SPIRVType *ResType, + MachineInstr &I) const { + return buildI32Constant(0, I, ResType); +} + +Register SPIRVInstructionSelector::buildOnesVal(bool AllOnes, + const SPIRVType *ResType, + MachineInstr &I) const { + unsigned BitWidth = GR.getScalarOrVectorBitWidth(ResType); + APInt One = AllOnes ? 
APInt::getAllOnesValue(BitWidth) + : APInt::getOneBitSet(BitWidth, 0); + Register OneReg = buildI32Constant(One.getZExtValue(), I, ResType); + if (ResType->getOpcode() == SPIRV::OpTypeVector) { + const unsigned NumEles = ResType->getOperand(2).getImm(); + Register OneVec = MRI->createVirtualRegister(&SPIRV::IDRegClass); + unsigned Opcode = SPIRV::OpConstantComposite; + auto MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(Opcode)) + .addDef(OneVec) + .addUse(GR.getSPIRVTypeID(ResType)); + for (unsigned i = 0; i < NumEles; ++i) + MIB.addUse(OneReg); + constrainSelectedInstRegOperands(*MIB, TII, TRI, RBI); + return OneVec; + } + return OneReg; +} + +bool SPIRVInstructionSelector::selectSelect(Register ResVReg, + const SPIRVType *ResType, + MachineInstr &I, + bool IsSigned) const { + // To extend a bool, we need to use OpSelect between constants. + Register ZeroReg = buildZerosVal(ResType, I); + Register OneReg = buildOnesVal(IsSigned, ResType, I); + bool IsScalarBool = + GR.isScalarOfType(I.getOperand(1).getReg(), SPIRV::OpTypeBool); + unsigned Opcode = + IsScalarBool ? SPIRV::OpSelectSISCond : SPIRV::OpSelectSIVCond; + return BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(Opcode)) + .addDef(ResVReg) + .addUse(GR.getSPIRVTypeID(ResType)) + .addUse(I.getOperand(1).getReg()) + .addUse(OneReg) + .addUse(ZeroReg) + .constrainAllUses(TII, TRI, RBI); +} + +bool SPIRVInstructionSelector::selectIToF(Register ResVReg, + const SPIRVType *ResType, + MachineInstr &I, bool IsSigned, + unsigned Opcode) const { + Register SrcReg = I.getOperand(1).getReg(); + // We can convert bool value directly to float type without OpConvert*ToF, + // however the translator generates OpSelect+OpConvert*ToF, so we do the same. 
+ if (GR.isScalarOrVectorOfType(I.getOperand(1).getReg(), SPIRV::OpTypeBool)) { + unsigned BitWidth = GR.getScalarOrVectorBitWidth(ResType); + SPIRVType *TmpType = GR.getOrCreateSPIRVIntegerType(BitWidth, I, TII); + if (ResType->getOpcode() == SPIRV::OpTypeVector) { + const unsigned NumElts = ResType->getOperand(2).getImm(); + TmpType = GR.getOrCreateSPIRVVectorType(TmpType, NumElts, I, TII); + } + SrcReg = MRI->createVirtualRegister(&SPIRV::IDRegClass); + selectSelect(SrcReg, TmpType, I, false); + } + return selectUnOpWithSrc(ResVReg, ResType, I, SrcReg, Opcode); +} + +bool SPIRVInstructionSelector::selectExt(Register ResVReg, + const SPIRVType *ResType, + MachineInstr &I, bool IsSigned) const { + if (GR.isScalarOrVectorOfType(I.getOperand(1).getReg(), SPIRV::OpTypeBool)) + return selectSelect(ResVReg, ResType, I, IsSigned); + unsigned Opcode = IsSigned ? SPIRV::OpSConvert : SPIRV::OpUConvert; + return selectUnOp(ResVReg, ResType, I, Opcode); +} + +bool SPIRVInstructionSelector::selectIntToBool(Register IntReg, + Register ResVReg, + const SPIRVType *IntTy, + const SPIRVType *BoolTy, + MachineInstr &I) const { + // To truncate to a bool, we use OpBitwiseAnd 1 and OpINotEqual to zero. + Register BitIntReg = MRI->createVirtualRegister(&SPIRV::IDRegClass); + bool IsVectorTy = IntTy->getOpcode() == SPIRV::OpTypeVector; + unsigned Opcode = IsVectorTy ? 
SPIRV::OpBitwiseAndV : SPIRV::OpBitwiseAndS; + Register Zero = buildZerosVal(IntTy, I); + Register One = buildOnesVal(false, IntTy, I); + MachineBasicBlock &BB = *I.getParent(); + BuildMI(BB, I, I.getDebugLoc(), TII.get(Opcode)) + .addDef(BitIntReg) + .addUse(GR.getSPIRVTypeID(IntTy)) + .addUse(IntReg) + .addUse(One) + .constrainAllUses(TII, TRI, RBI); + return BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpINotEqual)) + .addDef(ResVReg) + .addUse(GR.getSPIRVTypeID(BoolTy)) + .addUse(BitIntReg) + .addUse(Zero) + .constrainAllUses(TII, TRI, RBI); +} + +bool SPIRVInstructionSelector::selectTrunc(Register ResVReg, + const SPIRVType *ResType, + MachineInstr &I) const { + if (GR.isScalarOrVectorOfType(ResVReg, SPIRV::OpTypeBool)) { + Register IntReg = I.getOperand(1).getReg(); + const SPIRVType *ArgType = GR.getSPIRVTypeForVReg(IntReg); + return selectIntToBool(IntReg, ResVReg, ArgType, ResType, I); + } + bool IsSigned = GR.isScalarOrVectorSigned(ResType); + unsigned Opcode = IsSigned ? SPIRV::OpSConvert : SPIRV::OpUConvert; + return selectUnOp(ResVReg, ResType, I, Opcode); +} + +bool SPIRVInstructionSelector::selectConst(Register ResVReg, + const SPIRVType *ResType, + const APInt &Imm, + MachineInstr &I) const { + assert(ResType->getOpcode() != SPIRV::OpTypePointer || Imm.isNullValue()); + MachineBasicBlock &BB = *I.getParent(); + if (ResType->getOpcode() == SPIRV::OpTypePointer && Imm.isNullValue()) { + return BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantNull)) + .addDef(ResVReg) + .addUse(GR.getSPIRVTypeID(ResType)) + .constrainAllUses(TII, TRI, RBI); + } + auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantI)) + .addDef(ResVReg) + .addUse(GR.getSPIRVTypeID(ResType)); + // <=32-bit integers should be caught by the sdag pattern. 
+ assert(Imm.getBitWidth() > 32); + addNumImm(Imm, MIB); + return MIB.constrainAllUses(TII, TRI, RBI); +} + +bool SPIRVInstructionSelector::selectOpUndef(Register ResVReg, + const SPIRVType *ResType, + MachineInstr &I) const { + return BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(SPIRV::OpUndef)) + .addDef(ResVReg) + .addUse(GR.getSPIRVTypeID(ResType)) + .constrainAllUses(TII, TRI, RBI); +} + +static bool isImm(const MachineOperand &MO, MachineRegisterInfo *MRI) { + assert(MO.isReg()); + const SPIRVType *TypeInst = MRI->getVRegDef(MO.getReg()); + if (TypeInst->getOpcode() != SPIRV::ASSIGN_TYPE) + return false; + assert(TypeInst->getOperand(1).isReg()); + MachineInstr *ImmInst = MRI->getVRegDef(TypeInst->getOperand(1).getReg()); + return ImmInst->getOpcode() == TargetOpcode::G_CONSTANT; +} + +static int64_t foldImm(const MachineOperand &MO, MachineRegisterInfo *MRI) { + const SPIRVType *TypeInst = MRI->getVRegDef(MO.getReg()); + MachineInstr *ImmInst = MRI->getVRegDef(TypeInst->getOperand(1).getReg()); + assert(ImmInst->getOpcode() == TargetOpcode::G_CONSTANT); + return ImmInst->getOperand(1).getCImm()->getZExtValue(); +} + +bool SPIRVInstructionSelector::selectInsertVal(Register ResVReg, + const SPIRVType *ResType, + MachineInstr &I) const { + MachineBasicBlock &BB = *I.getParent(); + return BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpCompositeInsert)) + .addDef(ResVReg) + .addUse(GR.getSPIRVTypeID(ResType)) + // object to insert + .addUse(I.getOperand(3).getReg()) + // composite to insert into + .addUse(I.getOperand(2).getReg()) + // TODO: support arbitrary number of indices + .addImm(foldImm(I.getOperand(4), MRI)) + .constrainAllUses(TII, TRI, RBI); +} + +bool SPIRVInstructionSelector::selectExtractVal(Register ResVReg, + const SPIRVType *ResType, + MachineInstr &I) const { + MachineBasicBlock &BB = *I.getParent(); + return BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpCompositeExtract)) + .addDef(ResVReg) + .addUse(GR.getSPIRVTypeID(ResType)) + 
.addUse(I.getOperand(2).getReg()) + // TODO: support arbitrary number of indices + .addImm(foldImm(I.getOperand(3), MRI)) + .constrainAllUses(TII, TRI, RBI); +} + +bool SPIRVInstructionSelector::selectInsertElt(Register ResVReg, + const SPIRVType *ResType, + MachineInstr &I) const { + if (isImm(I.getOperand(4), MRI)) + return selectInsertVal(ResVReg, ResType, I); + MachineBasicBlock &BB = *I.getParent(); + return BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpVectorInsertDynamic)) + .addDef(ResVReg) + .addUse(GR.getSPIRVTypeID(ResType)) + .addUse(I.getOperand(2).getReg()) + .addUse(I.getOperand(3).getReg()) + .addUse(I.getOperand(4).getReg()) + .constrainAllUses(TII, TRI, RBI); +} + +bool SPIRVInstructionSelector::selectExtractElt(Register ResVReg, + const SPIRVType *ResType, + MachineInstr &I) const { + if (isImm(I.getOperand(3), MRI)) + return selectExtractVal(ResVReg, ResType, I); + MachineBasicBlock &BB = *I.getParent(); + return BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpVectorExtractDynamic)) + .addDef(ResVReg) + .addUse(GR.getSPIRVTypeID(ResType)) + .addUse(I.getOperand(2).getReg()) + .addUse(I.getOperand(3).getReg()) + .constrainAllUses(TII, TRI, RBI); +} + +bool SPIRVInstructionSelector::selectGEP(Register ResVReg, + const SPIRVType *ResType, + MachineInstr &I) const { + // In general we should also support OpAccessChain instrs here (i.e. not + // PtrAccessChain) but SPIRV-LLVM Translator doesn't emit them at all and so + // do we to stay compliant with its test and more importantly consumers. + unsigned Opcode = I.getOperand(2).getImm() ? SPIRV::OpInBoundsPtrAccessChain + : SPIRV::OpPtrAccessChain; + auto Res = BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(Opcode)) + .addDef(ResVReg) + .addUse(GR.getSPIRVTypeID(ResType)) + // Object to get a pointer to. + .addUse(I.getOperand(3).getReg()); + // Adding indices. 
+ for (unsigned i = 4; i < I.getNumExplicitOperands(); ++i) + Res.addUse(I.getOperand(i).getReg()); + return Res.constrainAllUses(TII, TRI, RBI); +} + +bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg, + const SPIRVType *ResType, + MachineInstr &I) const { + MachineBasicBlock &BB = *I.getParent(); + switch (I.getIntrinsicID()) { + case Intrinsic::spv_load: + return selectLoad(ResVReg, ResType, I); + break; + case Intrinsic::spv_store: + return selectStore(I); + break; + case Intrinsic::spv_extractv: + return selectExtractVal(ResVReg, ResType, I); + break; + case Intrinsic::spv_insertv: + return selectInsertVal(ResVReg, ResType, I); + break; + case Intrinsic::spv_extractelt: + return selectExtractElt(ResVReg, ResType, I); + break; + case Intrinsic::spv_insertelt: + return selectInsertElt(ResVReg, ResType, I); + break; + case Intrinsic::spv_gep: + return selectGEP(ResVReg, ResType, I); + break; + case Intrinsic::spv_unref_global: + case Intrinsic::spv_init_global: { + MachineInstr *MI = MRI->getVRegDef(I.getOperand(1).getReg()); + MachineInstr *Init = I.getNumExplicitOperands() > 2 + ? MRI->getVRegDef(I.getOperand(2).getReg()) + : nullptr; + assert(MI); + return selectGlobalValue(MI->getOperand(0).getReg(), *MI, Init); + } break; + case Intrinsic::spv_const_composite: { + // If no values are attached, the composite is null constant. + bool IsNull = I.getNumExplicitDefs() + 1 == I.getNumExplicitOperands(); + unsigned Opcode = + IsNull ? 
SPIRV::OpConstantNull : SPIRV::OpConstantComposite; + auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(Opcode)) + .addDef(ResVReg) + .addUse(GR.getSPIRVTypeID(ResType)); + // skip type MD node we already used when generated assign.type for this + if (!IsNull) { + for (unsigned i = I.getNumExplicitDefs() + 1; + i < I.getNumExplicitOperands(); ++i) { + MIB.addUse(I.getOperand(i).getReg()); + } + } + return MIB.constrainAllUses(TII, TRI, RBI); + } break; + case Intrinsic::spv_assign_name: { + auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpName)); + MIB.addUse(I.getOperand(I.getNumExplicitDefs() + 1).getReg()); + for (unsigned i = I.getNumExplicitDefs() + 2; + i < I.getNumExplicitOperands(); ++i) { + MIB.addImm(I.getOperand(i).getImm()); + } + return MIB.constrainAllUses(TII, TRI, RBI); + } break; + case Intrinsic::spv_switch: { + auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpSwitch)); + for (unsigned i = 1; i < I.getNumExplicitOperands(); ++i) { + if (I.getOperand(i).isReg()) + MIB.addReg(I.getOperand(i).getReg()); + else if (I.getOperand(i).isCImm()) + addNumImm(I.getOperand(i).getCImm()->getValue(), MIB); + else if (I.getOperand(i).isMBB()) + MIB.addMBB(I.getOperand(i).getMBB()); + else + llvm_unreachable("Unexpected OpSwitch operand"); + } + return MIB.constrainAllUses(TII, TRI, RBI); + } break; + default: + llvm_unreachable("Intrinsic selection not implemented"); + } + return true; +} + +bool SPIRVInstructionSelector::selectFrameIndex(Register ResVReg, + const SPIRVType *ResType, + MachineInstr &I) const { + return BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(SPIRV::OpVariable)) + .addDef(ResVReg) + .addUse(GR.getSPIRVTypeID(ResType)) + .addImm(static_cast<uint32_t>(SPIRV::StorageClass::Function)) + .constrainAllUses(TII, TRI, RBI); +} + +bool SPIRVInstructionSelector::selectBranch(MachineInstr &I) const { + // InstructionSelector walks backwards through the instructions. 
We can use + // both a G_BR and a G_BRCOND to create an OpBranchConditional. We hit G_BR + // first, so can generate an OpBranchConditional here. If there is no + // G_BRCOND, we just use OpBranch for a regular unconditional branch. + const MachineInstr *PrevI = I.getPrevNode(); + MachineBasicBlock &MBB = *I.getParent(); + if (PrevI != nullptr && PrevI->getOpcode() == TargetOpcode::G_BRCOND) { + return BuildMI(MBB, I, I.getDebugLoc(), TII.get(SPIRV::OpBranchConditional)) + .addUse(PrevI->getOperand(0).getReg()) + .addMBB(PrevI->getOperand(1).getMBB()) + .addMBB(I.getOperand(0).getMBB()) + .constrainAllUses(TII, TRI, RBI); + } + return BuildMI(MBB, I, I.getDebugLoc(), TII.get(SPIRV::OpBranch)) + .addMBB(I.getOperand(0).getMBB()) + .constrainAllUses(TII, TRI, RBI); +} + +bool SPIRVInstructionSelector::selectBranchCond(MachineInstr &I) const { + // InstructionSelector walks backwards through the instructions. For an + // explicit conditional branch with no fallthrough, we use both a G_BR and a + // G_BRCOND to create an OpBranchConditional. We should hit G_BR first, and + // generate the OpBranchConditional in selectBranch above. + // + // If an OpBranchConditional has been generated, we simply return, as the work + // is alread done. If there is no OpBranchConditional, LLVM must be relying on + // implicit fallthrough to the next basic block, so we need to create an + // OpBranchConditional with an explicit "false" argument pointing to the next + // basic block that LLVM would fall through to. + const MachineInstr *NextI = I.getNextNode(); + // Check if this has already been successfully selected. + if (NextI != nullptr && NextI->getOpcode() == SPIRV::OpBranchConditional) + return true; + // Must be relying on implicit block fallthrough, so generate an + // OpBranchConditional with the "next" basic block as the "false" target. 
+ MachineBasicBlock &MBB = *I.getParent(); + unsigned NextMBBNum = MBB.getNextNode()->getNumber(); + MachineBasicBlock *NextMBB = I.getMF()->getBlockNumbered(NextMBBNum); + return BuildMI(MBB, I, I.getDebugLoc(), TII.get(SPIRV::OpBranchConditional)) + .addUse(I.getOperand(0).getReg()) + .addMBB(I.getOperand(1).getMBB()) + .addMBB(NextMBB) + .constrainAllUses(TII, TRI, RBI); +} + +bool SPIRVInstructionSelector::selectPhi(Register ResVReg, + const SPIRVType *ResType, + MachineInstr &I) const { + auto MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(SPIRV::OpPhi)) + .addDef(ResVReg) + .addUse(GR.getSPIRVTypeID(ResType)); + const unsigned NumOps = I.getNumOperands(); + for (unsigned i = 1; i < NumOps; i += 2) { + MIB.addUse(I.getOperand(i + 0).getReg()); + MIB.addMBB(I.getOperand(i + 1).getMBB()); + } + return MIB.constrainAllUses(TII, TRI, RBI); +} + +bool SPIRVInstructionSelector::selectGlobalValue( + Register ResVReg, MachineInstr &I, const MachineInstr *Init) const { + // FIXME: don't use MachineIRBuilder here, replace it with BuildMI. + MachineIRBuilder MIRBuilder(I); + const GlobalValue *GV = I.getOperand(1).getGlobal(); + SPIRVType *ResType = GR.getOrCreateSPIRVType( + GV->getType(), MIRBuilder, SPIRV::AccessQualifier::ReadWrite, false); + + std::string GlobalIdent = GV->getGlobalIdentifier(); + // TODO: suport @llvm.global.annotations. + auto GlobalVar = cast<GlobalVariable>(GV); + + bool HasInit = GlobalVar->hasInitializer() && + !isa<UndefValue>(GlobalVar->getInitializer()); + // Skip empty declaration for GVs with initilaizers till we get the decl with + // passed initializer. + if (HasInit && !Init) + return true; + + unsigned AddrSpace = GV->getAddressSpace(); + SPIRV::StorageClass Storage = addressSpaceToStorageClass(AddrSpace); + bool HasLnkTy = GV->getLinkage() != GlobalValue::InternalLinkage && + Storage != SPIRV::StorageClass::Function; + SPIRV::LinkageType LnkType = + (GV->isDeclaration() || GV->hasAvailableExternallyLinkage()) + ? 
SPIRV::LinkageType::Import + : SPIRV::LinkageType::Export; + + Register Reg = GR.buildGlobalVariable(ResVReg, ResType, GlobalIdent, GV, + Storage, Init, GlobalVar->isConstant(), + HasLnkTy, LnkType, MIRBuilder, true); + return Reg.isValid(); +} + +namespace llvm { +InstructionSelector * +createSPIRVInstructionSelector(const SPIRVTargetMachine &TM, + const SPIRVSubtarget &Subtarget, + const RegisterBankInfo &RBI) { + return new SPIRVInstructionSelector(TM, Subtarget, RBI); +} +} // namespace llvm diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp new file mode 100644 index 000000000000..87f9e9545dd3 --- /dev/null +++ b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp @@ -0,0 +1,301 @@ +//===- SPIRVLegalizerInfo.cpp --- SPIR-V Legalization Rules ------*- C++ -*-==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the targeting of the Machinelegalizer class for SPIR-V. 
+//
+//===----------------------------------------------------------------------===//
+
+#include "SPIRVLegalizerInfo.h"
+#include "SPIRV.h"
+#include "SPIRVGlobalRegistry.h"
+#include "SPIRVSubtarget.h"
+#include "llvm/CodeGen/GlobalISel/LegalizerHelper.h"
+#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
+#include "llvm/CodeGen/MachineInstr.h"
+#include "llvm/CodeGen/MachineRegisterInfo.h"
+#include "llvm/CodeGen/TargetOpcodes.h"
+
+using namespace llvm;
+using namespace llvm::LegalizeActions;
+using namespace llvm::LegalityPredicates;
+
+// Generic opcodes that are routed to legalizeCustom() for every type (see the
+// constructor). The custom hook currently accepts them unchanged.
+static const std::set<unsigned> TypeFoldingSupportingOpcs = {
+    TargetOpcode::G_ADD,
+    TargetOpcode::G_FADD,
+    TargetOpcode::G_SUB,
+    TargetOpcode::G_FSUB,
+    TargetOpcode::G_MUL,
+    TargetOpcode::G_FMUL,
+    TargetOpcode::G_SDIV,
+    TargetOpcode::G_UDIV,
+    TargetOpcode::G_FDIV,
+    TargetOpcode::G_SREM,
+    TargetOpcode::G_UREM,
+    TargetOpcode::G_FREM,
+    TargetOpcode::G_FNEG,
+    TargetOpcode::G_CONSTANT,
+    TargetOpcode::G_FCONSTANT,
+    TargetOpcode::G_AND,
+    TargetOpcode::G_OR,
+    TargetOpcode::G_XOR,
+    TargetOpcode::G_SHL,
+    TargetOpcode::G_ASHR,
+    TargetOpcode::G_LSHR,
+    TargetOpcode::G_SELECT,
+    TargetOpcode::G_EXTRACT_VECTOR_ELT,
+};
+
+// Returns true if Opcode is one of TypeFoldingSupportingOpcs.
+bool isTypeFoldingSupported(unsigned Opcode) {
+  return TypeFoldingSupportingOpcs.count(Opcode) > 0;
+}
+
+// Builds the legalization rule tables for the SPIR-V subtarget ST. Pointer
+// LLTs use the subtarget pointer size; address spaces 0-5 are mapped to the
+// storage classes noted next to each pointer type below.
+SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
+  using namespace TargetOpcode;
+
+  this->ST = &ST;
+  GR = ST.getSPIRVGlobalRegistry();
+
+  const LLT s1 = LLT::scalar(1);
+  const LLT s8 = LLT::scalar(8);
+  const LLT s16 = LLT::scalar(16);
+  const LLT s32 = LLT::scalar(32);
+  const LLT s64 = LLT::scalar(64);
+
+  const LLT v16s64 = LLT::fixed_vector(16, 64);
+  const LLT v16s32 = LLT::fixed_vector(16, 32);
+  const LLT v16s16 = LLT::fixed_vector(16, 16);
+  const LLT v16s8 = LLT::fixed_vector(16, 8);
+  const LLT v16s1 = LLT::fixed_vector(16, 1);
+
+  const LLT v8s64 = LLT::fixed_vector(8, 64);
+  const LLT v8s32 = LLT::fixed_vector(8, 32);
+  const LLT v8s16 = LLT::fixed_vector(8, 16);
+  const LLT v8s8 = LLT::fixed_vector(8, 8);
+  const LLT v8s1 = LLT::fixed_vector(8, 1);
+
+  const LLT v4s64 = LLT::fixed_vector(4, 64);
+  const LLT v4s32 = LLT::fixed_vector(4, 32);
+  const LLT v4s16 = LLT::fixed_vector(4, 16);
+  const LLT v4s8 = LLT::fixed_vector(4, 8);
+  const LLT v4s1 = LLT::fixed_vector(4, 1);
+
+  const LLT v3s64 = LLT::fixed_vector(3, 64);
+  const LLT v3s32 = LLT::fixed_vector(3, 32);
+  const LLT v3s16 = LLT::fixed_vector(3, 16);
+  const LLT v3s8 = LLT::fixed_vector(3, 8);
+  const LLT v3s1 = LLT::fixed_vector(3, 1);
+
+  const LLT v2s64 = LLT::fixed_vector(2, 64);
+  const LLT v2s32 = LLT::fixed_vector(2, 32);
+  const LLT v2s16 = LLT::fixed_vector(2, 16);
+  const LLT v2s8 = LLT::fixed_vector(2, 8);
+  const LLT v2s1 = LLT::fixed_vector(2, 1);
+
+  const unsigned PSize = ST.getPointerSize();
+  const LLT p0 = LLT::pointer(0, PSize); // Function
+  const LLT p1 = LLT::pointer(1, PSize); // CrossWorkgroup
+  const LLT p2 = LLT::pointer(2, PSize); // UniformConstant
+  const LLT p3 = LLT::pointer(3, PSize); // Workgroup
+  const LLT p4 = LLT::pointer(4, PSize); // Generic
+  const LLT p5 = LLT::pointer(5, PSize); // Input
+
+  // TODO: remove copy-pasting here by using concatenation in some way.
+  auto allPtrsScalarsAndVectors = {
+      p0,    p1,    p2,    p3,    p4,    p5,     s1,     s8,     s16,
+      s32,   s64,   v2s1,  v2s8,  v2s16, v2s32,  v2s64,  v3s1,   v3s8,
+      v3s16, v3s32, v3s64, v4s1,  v4s8,  v4s16,  v4s32,  v4s64,  v8s1,
+      v8s8,  v8s16, v8s32, v8s64, v16s1, v16s8,  v16s16, v16s32, v16s64};
+
+  auto allScalarsAndVectors = {
+      s1,   s8,   s16,   s32,   s64,   v2s1,  v2s8,  v2s16,  v2s32,  v2s64,
+      v3s1, v3s8, v3s16, v3s32, v3s64, v4s1,  v4s8,  v4s16,  v4s32,  v4s64,
+      v8s1, v8s8, v8s16, v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64};
+
+  auto allIntScalarsAndVectors = {s8,    s16,   s32,   s64,    v2s8,   v2s16,
+                                  v2s32, v2s64, v3s8,  v3s16,  v3s32,  v3s64,
+                                  v4s8,  v4s16, v4s32, v4s64,  v8s8,   v8s16,
+                                  v8s32, v8s64, v16s8, v16s16, v16s32, v16s64};
+
+  auto allBoolScalarsAndVectors = {s1, v2s1, v3s1, v4s1, v8s1, v16s1};
+
+  auto allIntScalars = {s8, s16, s32, s64};
+
+  auto allFloatScalarsAndVectors = {
+      s16,   s32,   s64,   v2s16, v2s32, v2s64, v3s16,  v3s32,  v3s64,
+      v4s16, v4s32, v4s64, v8s16, v8s32, v8s64, v16s16, v16s32, v16s64};
+
+  // NOTE(review): despite its name this aliases the integer-only list, so
+  // floating-point G_ATOMICRMW_XCHG is not declared legal — confirm intent.
+  auto allFloatAndIntScalars = allIntScalars;
+
+  auto allPtrs = {p0, p1, p2, p3, p4, p5};
+  // Address spaces usable as destinations of writes; excludes
+  // UniformConstant (p2) and Input (p5).
+  auto allWritablePtrs = {p0, p1, p3, p4};
+
+  // Type-folding opcodes are handed to legalizeCustom() unconditionally.
+  for (auto Opc : TypeFoldingSupportingOpcs)
+    getActionDefinitionsBuilder(Opc).custom();
+
+  getActionDefinitionsBuilder(G_GLOBAL_VALUE).alwaysLegal();
+
+  // TODO: add proper rules for vectors legalization.
+  getActionDefinitionsBuilder({G_BUILD_VECTOR, G_SHUFFLE_VECTOR}).alwaysLegal();
+
+  getActionDefinitionsBuilder({G_MEMCPY, G_MEMMOVE})
+      .legalIf(all(typeInSet(0, allWritablePtrs), typeInSet(1, allPtrs)));
+
+  getActionDefinitionsBuilder(G_ADDRSPACE_CAST)
+      .legalForCartesianProduct(allPtrs, allPtrs);
+
+  getActionDefinitionsBuilder({G_LOAD, G_STORE}).legalIf(typeInSet(1, allPtrs));
+
+  // G_BITREVERSE, like SPIR-V OpBitReverse, is defined on integer scalars and
+  // vectors only; declaring it legal for float types left every real
+  // (integer) bit-reverse unlegalizable.
+  getActionDefinitionsBuilder(G_BITREVERSE).legalFor(allIntScalarsAndVectors);
+
+  getActionDefinitionsBuilder(G_FMA).legalFor(allFloatScalarsAndVectors);
+
+  getActionDefinitionsBuilder({G_FPTOSI, G_FPTOUI})
+      .legalForCartesianProduct(allIntScalarsAndVectors,
+                                allFloatScalarsAndVectors);
+
+  getActionDefinitionsBuilder({G_SITOFP, G_UITOFP})
+      .legalForCartesianProduct(allFloatScalarsAndVectors,
+                                allScalarsAndVectors);
+
+  getActionDefinitionsBuilder({G_SMIN, G_SMAX, G_UMIN, G_UMAX, G_ABS})
+      .legalFor(allIntScalarsAndVectors);
+
+  getActionDefinitionsBuilder(G_CTPOP).legalForCartesianProduct(
+      allIntScalarsAndVectors, allIntScalarsAndVectors);
+
+  getActionDefinitionsBuilder(G_PHI).legalFor(allPtrsScalarsAndVectors);
+
+  // Bitcasts are legal between any two supported types of equal bit width.
+  getActionDefinitionsBuilder(G_BITCAST).legalIf(all(
+      typeInSet(0, allPtrsScalarsAndVectors),
+      typeInSet(1, allPtrsScalarsAndVectors),
+      LegalityPredicate(([=](const LegalityQuery &Query) {
+        return Query.Types[0].getSizeInBits() == Query.Types[1].getSizeInBits();
+      }))));
+
+  getActionDefinitionsBuilder(G_IMPLICIT_DEF).alwaysLegal();
+
+  getActionDefinitionsBuilder(G_INTTOPTR)
+      .legalForCartesianProduct(allPtrs, allIntScalars);
+  getActionDefinitionsBuilder(G_PTRTOINT)
+      .legalForCartesianProduct(allIntScalars, allPtrs);
+  getActionDefinitionsBuilder(G_PTR_ADD).legalForCartesianProduct(
+      allPtrs, allIntScalars);
+
+  // ST.canDirectlyComparePointers() for pointer args is supported in
+  // legalizeCustom().
+  getActionDefinitionsBuilder(G_ICMP).customIf(
+      all(typeInSet(0, allBoolScalarsAndVectors),
+          typeInSet(1, allPtrsScalarsAndVectors)));
+
+  getActionDefinitionsBuilder(G_FCMP).legalIf(
+      all(typeInSet(0, allBoolScalarsAndVectors),
+          typeInSet(1, allFloatScalarsAndVectors)));
+
+  getActionDefinitionsBuilder({G_ATOMICRMW_OR, G_ATOMICRMW_ADD, G_ATOMICRMW_AND,
+                               G_ATOMICRMW_MAX, G_ATOMICRMW_MIN,
+                               G_ATOMICRMW_SUB, G_ATOMICRMW_XOR,
+                               G_ATOMICRMW_UMAX, G_ATOMICRMW_UMIN})
+      .legalForCartesianProduct(allIntScalars, allWritablePtrs);
+
+  getActionDefinitionsBuilder(G_ATOMICRMW_XCHG)
+      .legalForCartesianProduct(allFloatAndIntScalars, allWritablePtrs);
+
+  getActionDefinitionsBuilder(G_ATOMIC_CMPXCHG_WITH_SUCCESS).lower();
+  // TODO: add proper legalization rules.
+  getActionDefinitionsBuilder(G_ATOMIC_CMPXCHG).alwaysLegal();
+
+  getActionDefinitionsBuilder({G_UADDO, G_USUBO, G_SMULO, G_UMULO})
+      .alwaysLegal();
+
+  // Extensions.
+  getActionDefinitionsBuilder({G_TRUNC, G_ZEXT, G_SEXT, G_ANYEXT})
+      .legalForCartesianProduct(allScalarsAndVectors);
+
+  // FP conversions.
+  getActionDefinitionsBuilder({G_FPTRUNC, G_FPEXT})
+      .legalForCartesianProduct(allFloatScalarsAndVectors);
+
+  // Pointer-handling.
+  getActionDefinitionsBuilder(G_FRAME_INDEX).legalFor({p0});
+
+  // Control-flow.
+  getActionDefinitionsBuilder(G_BRCOND).legalFor({s1});
+
+  getActionDefinitionsBuilder({G_FPOW,
+                               G_FEXP,
+                               G_FEXP2,
+                               G_FLOG,
+                               G_FLOG2,
+                               G_FABS,
+                               G_FMINNUM,
+                               G_FMAXNUM,
+                               G_FCEIL,
+                               G_FCOS,
+                               G_FSIN,
+                               G_FSQRT,
+                               G_FFLOOR,
+                               G_FRINT,
+                               G_FNEARBYINT,
+                               G_INTRINSIC_ROUND,
+                               G_INTRINSIC_TRUNC,
+                               G_FMINIMUM,
+                               G_FMAXIMUM,
+                               G_INTRINSIC_ROUNDEVEN})
+      .legalFor(allFloatScalarsAndVectors);
+
+  getActionDefinitionsBuilder(G_FCOPYSIGN)
+      .legalForCartesianProduct(allFloatScalarsAndVectors,
+                                allFloatScalarsAndVectors);
+
+  getActionDefinitionsBuilder(G_FPOWI).legalForCartesianProduct(
+      allFloatScalarsAndVectors, allIntScalarsAndVectors);
+
+  getLegacyLegalizerInfo().computeTables();
+  verify(*ST.getInstrInfo());
+}
+
+// Emits `G_PTRTOINT Reg` into a new generic vreg of type ConvTy at the
+// MIRBuilder's current insertion point, records SpirvType for that vreg in
+// GR, and returns the new register.
+static Register convertPtrToInt(Register Reg, LLT ConvTy, SPIRVType *SpirvType,
+                                LegalizerHelper &Helper,
+                                MachineRegisterInfo &MRI,
+                                SPIRVGlobalRegistry *GR) {
+  Register ConvReg = MRI.createGenericVirtualRegister(ConvTy);
+  GR->assignSPIRVTypeToVReg(SpirvType, ConvReg, Helper.MIRBuilder.getMF());
+  Helper.MIRBuilder.buildInstr(TargetOpcode::G_PTRTOINT)
+      .addDef(ConvReg)
+      .addUse(Reg);
+  return ConvReg;
+}
+
+// Custom legalization hook. Opcodes in TypeFoldingSupportingOpcs are accepted
+// unchanged. The only other custom opcode is G_ICMP: when both compared
+// operands are pointers and either the subtarget cannot compare pointers
+// directly or the predicate is not EQ/NE, both operands are replaced by
+// pointer-width integers produced with G_PTRTOINT. Always reports success.
+bool SPIRVLegalizerInfo::legalizeCustom(LegalizerHelper &Helper,
+                                        MachineInstr &MI) const {
+  auto Opc = MI.getOpcode();
+  MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
+  if (!isTypeFoldingSupported(Opc)) {
+    assert(Opc == TargetOpcode::G_ICMP);
+    assert(GR->getSPIRVTypeForVReg(MI.getOperand(0).getReg()));
+    // G_ICMP operands: 0 = result, 1 = predicate, 2/3 = compared values.
+    auto &Op0 = MI.getOperand(2);
+    auto &Op1 = MI.getOperand(3);
+    Register Reg0 = Op0.getReg();
+    Register Reg1 = Op1.getReg();
+    CmpInst::Predicate Cond =
+        static_cast<CmpInst::Predicate>(MI.getOperand(1).getPredicate());
+    if ((!ST->canDirectlyComparePointers() ||
+         (Cond != CmpInst::ICMP_EQ && Cond != CmpInst::ICMP_NE)) &&
+        MRI.getType(Reg0).isPointer() && MRI.getType(Reg1).isPointer()) {
+      LLT ConvT = LLT::scalar(ST->getPointerSize());
+      Type *LLVMTy = IntegerType::get(MI.getMF()->getFunction().getContext(),
+                                      ST->getPointerSize());
+      SPIRVType *SpirvTy = GR->getOrCreateSPIRVType(LLVMTy, Helper.MIRBuilder);
+      Op0.setReg(convertPtrToInt(Reg0, ConvT, SpirvTy, Helper, MRI, GR));
+      Op1.setReg(convertPtrToInt(Reg1, ConvT, SpirvTy, Helper, MRI, GR));
+    }
+    return true;
+  }
+  // TODO: implement legalization for other opcodes.
+  return true;
+}
diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.h b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.h
new file mode 100644
index 000000000000..2541ff29edb0
--- /dev/null
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.h
@@ -0,0 +1,36 @@
+//===- SPIRVLegalizerInfo.h --- SPIR-V Legalization Rules --------*- C++ -*-==//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the targeting of the MachineLegalizer class for SPIR-V.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIB_TARGET_SPIRV_SPIRVMACHINELEGALIZER_H
+#define LLVM_LIB_TARGET_SPIRV_SPIRVMACHINELEGALIZER_H
+
+#include "SPIRVGlobalRegistry.h"
+#include "llvm/CodeGen/GlobalISel/LegalizerInfo.h"
+
+bool isTypeFoldingSupported(unsigned Opcode);
+
+namespace llvm {
+
+class LLVMContext;
+class SPIRVSubtarget;
+
+// This class provides the information for legalizing SPIR-V instructions.
+class SPIRVLegalizerInfo : public LegalizerInfo {
+  // Subtarget the rules were built for; consulted for the pointer size and
+  // for whether pointers can be compared directly.
+  const SPIRVSubtarget *ST;
+  // Registry used to attach SPIR-V types to vregs created while legalizing.
+  SPIRVGlobalRegistry *GR;
+
+public:
+  // Builds the legalization rule tables for the given subtarget.
+  SPIRVLegalizerInfo(const SPIRVSubtarget &ST);
+
+  // Handles the opcodes whose rules are marked custom()/customIf().
+  bool legalizeCustom(LegalizerHelper &Helper, MachineInstr &MI) const override;
+};
+} // namespace llvm
+#endif // LLVM_LIB_TARGET_SPIRV_SPIRVMACHINELEGALIZER_H
diff --git a/llvm/lib/Target/SPIRV/SPIRVMCInstLower.cpp b/llvm/lib/Target/SPIRV/SPIRVMCInstLower.cpp
new file mode 100644
index 000000000000..8e4ab973bf07
--- /dev/null
+++ b/llvm/lib/Target/SPIRV/SPIRVMCInstLower.cpp
@@ -0,0 +1,58 @@
+//=- SPIRVMCInstLower.cpp - Convert SPIR-V MachineInstr to MCInst -*- C++ -*-=//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains code to lower SPIR-V MachineInstrs to their corresponding
+// MCInst records.
+//
+//===----------------------------------------------------------------------===//
+
+#include "SPIRVMCInstLower.h"
+#include "SPIRV.h"
+#include "SPIRVModuleAnalysis.h"
+#include "SPIRVUtils.h"
+#include "llvm/CodeGen/MachineInstr.h"
+#include "llvm/IR/Constants.h"
+
+using namespace llvm;
+
+// Converts one operand of an instruction from MF into an MCOperand. Globals,
+// basic blocks and registers are rewritten to the module-wide IDs recorded in
+// MAI; immediates are copied through.
+static MCOperand lowerOperand(const MachineOperand &MO,
+                              const MachineFunction *MF,
+                              SPIRV::ModuleAnalysisInfo *MAI) {
+  switch (MO.getType()) {
+  case MachineOperand::MO_GlobalAddress: {
+    // Globals are looked up by their global identifier in the function map.
+    Register FuncReg = MAI->getFuncReg(MO.getGlobal()->getGlobalIdentifier());
+    assert(FuncReg.isValid() && "Cannot find function Id");
+    return MCOperand::createReg(FuncReg);
+  }
+  case MachineOperand::MO_MachineBasicBlock:
+    return MCOperand::createReg(MAI->getOrCreateMBBRegister(*MO.getMBB()));
+  case MachineOperand::MO_Register: {
+    // Use the global alias when one exists, otherwise keep the local register.
+    const Register Alias = MAI->getRegisterAlias(MF, MO.getReg());
+    return MCOperand::createReg(Alias.isValid() ? Alias : MO.getReg());
+  }
+  case MachineOperand::MO_Immediate:
+    return MCOperand::createImm(MO.getImm());
+  case MachineOperand::MO_FPImmediate:
+    // NOTE(review): MCOperand::createDFPImm is documented to take the bit
+    // pattern of a double, but a float *value* is passed here and converted
+    // numerically (any fractional part is lost). Per the APFloat docs,
+    // convertToFloat also requires float-representable semantics. Left as-is
+    // because the consumer of DFPImm operands is not visible here and may
+    // decode it the same way — confirm before changing.
+    return MCOperand::createDFPImm(
+        MO.getFPImm()->getValueAPF().convertToFloat());
+  default:
+    llvm_unreachable("unknown operand type");
+  }
+}
+
+void SPIRVMCInstLower::lower(const MachineInstr *MI, MCInst &OutMI,
+                             SPIRV::ModuleAnalysisInfo *MAI) const {
+  OutMI.setOpcode(MI->getOpcode());
+  const MachineFunction *MF = MI->getMF();
+  // Operands are lowered one-for-one, preserving their order.
+  for (const MachineOperand &MO : MI->operands())
+    OutMI.addOperand(lowerOperand(MO, MF, MAI));
+}
diff --git a/llvm/lib/Target/SPIRV/SPIRVMCInstLower.h b/llvm/lib/Target/SPIRV/SPIRVMCInstLower.h
new file mode 100644
index 000000000000..8392656ed067
--- /dev/null
+++ b/llvm/lib/Target/SPIRV/SPIRVMCInstLower.h
@@ -0,0 +1,29 @@
+//=- SPIRVMCInstLower.h -- Convert SPIR-V MachineInstr to MCInst --*- C++ -*-=//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIB_TARGET_SPIRV_SPIRVMCINSTLOWER_H
+#define LLVM_LIB_TARGET_SPIRV_SPIRVMCINSTLOWER_H
+
+#include "llvm/Support/Compiler.h"
+
+namespace llvm {
+class MachineInstr;
+class MCInst;
+
+namespace SPIRV {
+struct ModuleAnalysisInfo;
+} // namespace SPIRV
+
+// Lowers SPIR-V MachineInstrs to MCInsts, renaming register, basic-block and
+// global operands to the module-wide IDs held in a ModuleAnalysisInfo.
+class LLVM_LIBRARY_VISIBILITY SPIRVMCInstLower {
+public:
+  // Fills OutMI with MI's opcode and its lowered operands, using MAI for the
+  // register aliases and function/basic-block IDs.
+  void lower(const MachineInstr *MI, MCInst &OutMI,
+             SPIRV::ModuleAnalysisInfo *MAI) const;
+};
+} // namespace llvm
+
+#endif // LLVM_LIB_TARGET_SPIRV_SPIRVMCINSTLOWER_H
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
new file mode 100644
index 000000000000..fa78dd7942c6
--- /dev/null
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -0,0 +1,250 @@
+//===- SPIRVModuleAnalysis.cpp - analysis of global instrs & regs - C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// The analysis collects instructions that should be output at the module level
+// and performs the global register numbering.
+//
+// The results of this analysis are used in AsmPrinter to rename registers
+// globally and to output required instructions at the module level.
+//
+//===----------------------------------------------------------------------===//
+
+#include "SPIRVModuleAnalysis.h"
+#include "SPIRV.h"
+#include "SPIRVGlobalRegistry.h"
+#include "SPIRVSubtarget.h"
+#include "SPIRVTargetMachine.h"
+#include "SPIRVUtils.h"
+#include "TargetInfo/SPIRVTargetInfo.h"
+#include "llvm/CodeGen/MachineModuleInfo.h"
+#include "llvm/CodeGen/TargetPassConfig.h"
+
+using namespace llvm;
+
+#define DEBUG_TYPE "spirv-module-analysis"
+
+char llvm::SPIRVModuleAnalysis::ID = 0;
+
+namespace llvm {
+void initializeSPIRVModuleAnalysisPass(PassRegistry &);
+} // namespace llvm
+
+INITIALIZE_PASS(SPIRVModuleAnalysis, DEBUG_TYPE, "SPIRV module analysis", true,
+                true)
+
+// Retrieve an unsigned from an MDNode with a list of them as operands.
+// Returns DefaultVal when MdNode is null or has no operand at OpIndex.
+static unsigned getMetadataUInt(MDNode *MdNode, unsigned OpIndex,
+                                unsigned DefaultVal = 0) {
+  if (MdNode && OpIndex < MdNode->getNumOperands()) {
+    const auto &Op = MdNode->getOperand(OpIndex);
+    return mdconst::extract<ConstantInt>(Op)->getZExtValue();
+  }
+  return DefaultVal;
+}
+
+// Resets every result held in MAI, then fills in the module-wide properties:
+// memory model, source language, addressing model (chosen from the subtarget
+// pointer size) and the source-language version read from the
+// "opencl.ocl.version" named metadata, when present.
+void SPIRVModuleAnalysis::setBaseInfo(const Module &M) {
+  MAI.MaxID = 0;
+  for (int i = 0; i < SPIRV::NUM_MODULE_SECTIONS; i++)
+    MAI.MS[i].clear();
+  MAI.RegisterAliasTable.clear();
+  MAI.InstrsToDelete.clear();
+  MAI.FuncNameMap.clear();
+  MAI.GlobalVarList.clear();
+
+  // TODO: determine memory model and source language from the configuration.
+  MAI.Mem = SPIRV::MemoryModel::OpenCL;
+  MAI.SrcLang = SPIRV::SourceLanguage::OpenCL_C;
+  unsigned PtrSize = ST->getPointerSize();
+  MAI.Addr = PtrSize == 32   ? SPIRV::AddressingModel::Physical32
+             : PtrSize == 64 ? SPIRV::AddressingModel::Physical64
+                             : SPIRV::AddressingModel::Logical;
+  // Get the OpenCL version number from metadata.
+  // TODO: support other source languages.
+  MAI.SrcLangVersion = 0;
+  NamedMDNode *VerNode = M.getNamedMetadata("opencl.ocl.version");
+  // A named node without operands carries no version, and getOperand(0) on it
+  // would be out of bounds; keep the version at 0 in that case.
+  if (VerNode && VerNode->getNumOperands() > 0) {
+    // Construct version literal according to OpenCL 2.2 environment spec:
+    // (Major << 16) | (Minor << 8) | Revision, with Major defaulting to 2.
+    MDNode *VersionMD = VerNode->getOperand(0);
+    unsigned MajorNum = getMetadataUInt(VersionMD, 0, 2);
+    unsigned MinorNum = getMetadataUInt(VersionMD, 1);
+    unsigned RevNum = getMetadataUInt(VersionMD, 2);
+    MAI.SrcLangVersion = (MajorNum << 16) | (MinorNum << 8) | RevNum;
+  }
+}
+
+// True if there is an instruction in the MS list with all the same operands as
+// the given instruction has (after the given starting index).
+// Register operands are compared through their global aliases. NOTE(review):
+// getRegisterAlias returns Register(0) for unaliased registers, so this relies
+// on numberRegistersGlobally having run first (as runOnModule does).
+// When UpdateRegAliases is set and a match is found, A's operand 0 is aliased
+// to the global register of the matching instruction's operand 0.
+// TODO: maybe it needs to check Opcodes too.
+static bool findSameInstrInMS(const MachineInstr &A,
+                              SPIRV::ModuleSectionType MSType,
+                              SPIRV::ModuleAnalysisInfo &MAI,
+                              bool UpdateRegAliases,
+                              unsigned StartOpIndex = 0) {
+  const unsigned NumAOps = A.getNumOperands();
+  for (const auto *B : MAI.MS[MSType]) {
+    if (NumAOps != B->getNumOperands() || A.getNumDefs() != B->getNumDefs())
+      continue;
+    bool AllOpsMatch = true;
+    for (unsigned i = StartOpIndex; i < NumAOps && AllOpsMatch; ++i) {
+      if (A.getOperand(i).isReg() && B->getOperand(i).isReg()) {
+        Register RegA = A.getOperand(i).getReg();
+        Register RegB = B->getOperand(i).getReg();
+        AllOpsMatch = MAI.getRegisterAlias(A.getMF(), RegA) ==
+                      MAI.getRegisterAlias(B->getMF(), RegB);
+      } else {
+        AllOpsMatch = A.getOperand(i).isIdenticalTo(B->getOperand(i));
+      }
+    }
+    if (!AllOpsMatch)
+      continue;
+    if (UpdateRegAliases) {
+      assert(A.getOperand(0).isReg() && B->getOperand(0).isReg());
+      Register LocalReg = A.getOperand(0).getReg();
+      Register GlobalReg =
+          MAI.getRegisterAlias(B->getMF(), B->getOperand(0).getReg());
+      MAI.setRegisterAlias(A.getMF(), LocalReg, GlobalReg);
+    }
+    return true;
+  }
+  return false;
+}
+
+// Look for IDs declared with Import linkage, and map the imported name string
+// to the register defining that variable (which will usually be the result of
+// an OpFunction). This lets us call externally imported functions using
+// the correct ID registers.
+void SPIRVModuleAnalysis::collectFuncNames(MachineInstr &MI,
+                                           const Function &F) {
+  const unsigned Opcode = MI.getOpcode();
+
+  if (Opcode == SPIRV::OpFunction) {
+    // Record all internal OpFunction declarations under the IR function's
+    // global identifier, pointing at the global alias of the result register.
+    Register LocalReg = MI.defs().begin()->getReg();
+    Register FuncReg = MAI.getRegisterAlias(MI.getMF(), LocalReg);
+    assert(FuncReg.isValid());
+    // TODO: check that it does not conflict with existing entries.
+    MAI.FuncNameMap[F.getGlobalIdentifier()] = FuncReg;
+    return;
+  }
+
+  if (Opcode != SPIRV::OpDecorate)
+    return;
+
+  // Only LinkageAttributes decorations whose last operand is Import matter.
+  const int64_t Decoration = MI.getOperand(1).getImm();
+  if (Decoration !=
+      static_cast<unsigned>(SPIRV::Decoration::LinkageAttributes))
+    return;
+  const int64_t Linkage = MI.getOperand(MI.getNumOperands() - 1).getImm();
+  if (Linkage != static_cast<unsigned>(SPIRV::LinkageType::Import))
+    return;
+
+  // Operand 0 is the decorated target; the imported name is read by
+  // getStringImm starting at operand 2.
+  std::string ImportedName = getStringImm(MI, 2);
+  Register Target = MI.getOperand(0).getReg();
+  // TODO: check defs from different MFs.
+  MAI.FuncNameMap[ImportedName] = MAI.getRegisterAlias(MI.getMF(), Target);
+}
+
+// Collect the given instruction in the specified MS. We assume global register
+// numbering has already occurred by this point. We can directly compare reg
+// arguments when detecting duplicates.
+static void collectOtherInstr(MachineInstr &MI, SPIRV::ModuleAnalysisInfo &MAI,
+                              SPIRV::ModuleSectionType MSType,
+                              bool IsConstOrType = false) {
+  // The instruction is never emitted in its function body; if it is not a
+  // duplicate it is queued in the module section instead.
+  MAI.setSkipEmission(&MI);
+  // Types/constants are compared without their result register (operand 0);
+  // on a match that register is aliased to the existing definition.
+  const unsigned FirstComparedOp = IsConstOrType ? 1 : 0;
+  const bool IsDuplicate =
+      findSameInstrInMS(MI, MSType, MAI, IsConstOrType, FirstComparedOp);
+  if (!IsDuplicate)
+    MAI.MS[MSType].push_back(&MI);
+}
+
+// Some global instructions make reference to function-local ID regs, so cannot
+// be correctly collected until these registers are globally numbered.
+void SPIRVModuleAnalysis::processOtherInstrs(const Module &M) {
+  for (const Function &Fn : M) {
+    if (Fn.isDeclaration())
+      continue;
+    MachineFunction *MF = MMI->getMachineFunction(Fn);
+    assert(MF);
+    // OpFunction instructions seen so far in this MF (counted even for
+    // skipped instructions). Once more than one has been seen, further
+    // OpFunction/OpFunctionParameter go to MB_ExtFuncDecls — presumably
+    // declarations of external callees.
+    unsigned NumOpFunctionsSeen = 0;
+    for (MachineBasicBlock &MBB : *MF) {
+      for (MachineInstr &MI : MBB) {
+        const unsigned OpCode = MI.getOpcode();
+        if (OpCode == SPIRV::OpFunction)
+          ++NumOpFunctionsSeen;
+        if (MAI.getSkipEmission(&MI))
+          continue;
+
+        const bool IsFuncOrParm =
+            OpCode == SPIRV::OpFunction || OpCode == SPIRV::OpFunctionParameter;
+        const bool IsConstOrType =
+            TII->isConstantInstr(MI) || TII->isTypeDeclInstr(MI);
+
+        if (OpCode == SPIRV::OpName || OpCode == SPIRV::OpMemberName) {
+          collectOtherInstr(MI, MAI, SPIRV::MB_DebugNames);
+          continue;
+        }
+        if (OpCode == SPIRV::OpEntryPoint) {
+          collectOtherInstr(MI, MAI, SPIRV::MB_EntryPoints);
+          continue;
+        }
+        if (TII->isDecorationInstr(MI)) {
+          collectOtherInstr(MI, MAI, SPIRV::MB_Annotations);
+          collectFuncNames(MI, Fn);
+          continue;
+        }
+        if (IsConstOrType || (NumOpFunctionsSeen > 1 && IsFuncOrParm)) {
+          // Now OpSpecConstant*s are not in DT,
+          // but they need to be collected anyway.
+          const SPIRV::ModuleSectionType Section =
+              IsFuncOrParm ? SPIRV::MB_ExtFuncDecls : SPIRV::MB_TypeConstVars;
+          collectOtherInstr(MI, MAI, Section, IsConstOrType);
+          continue;
+        }
+        if (OpCode == SPIRV::OpFunction)
+          collectFuncNames(MI, Fn);
+      }
+    }
+  }
+}
+
+// Number registers in all functions globally from 0 onwards and store
+// the result in global register alias table.
+void SPIRVModuleAnalysis::numberRegistersGlobally(const Module &M) {
+  for (const Function &Fn : M) {
+    if (Fn.isDeclaration())
+      continue;
+    MachineFunction *MF = MMI->getMachineFunction(Fn);
+    assert(MF);
+    for (MachineBasicBlock &MBB : *MF)
+      for (MachineInstr &MI : MBB)
+        for (const MachineOperand &Op : MI.operands()) {
+          // Each register of this MF receives the next module-wide ID the
+          // first time it is seen; later occurrences reuse that alias.
+          if (!Op.isReg() || MAI.hasRegisterAlias(MF, Op.getReg()))
+            continue;
+          const Register GlobalReg = Register::index2VirtReg(MAI.getNextID());
+          MAI.setRegisterAlias(MF, Op.getReg(), GlobalReg);
+        }
+  }
+}
+
+// Definition of the static analysis-result storage shared by the pass.
+SPIRV::ModuleAnalysisInfo SPIRVModuleAnalysis::MAI;
+
+// Needs the pass config (to reach the SPIR-V target machine) and the machine
+// module info (to reach each function's MachineFunction).
+void SPIRVModuleAnalysis::getAnalysisUsage(AnalysisUsage &AU) const {
+  AU.addRequired<TargetPassConfig>();
+  AU.addRequired<MachineModuleInfoWrapperPass>();
+}
+
+bool SPIRVModuleAnalysis::runOnModule(Module &M) {
+  // Cache the subtarget-level helpers used by the collection routines.
+  auto &TM = getAnalysis<TargetPassConfig>().getTM<SPIRVTargetMachine>();
+  ST = TM.getSubtargetImpl();
+  GR = ST->getSPIRVGlobalRegistry();
+  TII = ST->getInstrInfo();
+  MMI = &getAnalysis<MachineModuleInfoWrapperPass>().getMMI();
+
+  setBaseInfo(M);
+
+  // TODO: Process type/const/global var/func decl instructions, number their
+  // destination registers from 0 to N, collect Extensions and Capabilities.
+
+  // Number rest of registers from N+1 onwards. This must precede collection:
+  // duplicate detection compares registers through their global aliases.
+  numberRegistersGlobally(M);
+
+  // Collect OpName, OpEntryPoint, OpDecorate etc, process other instructions.
+  processOtherInstrs(M);
+
+  // Only MAI is populated; the IR itself is not changed.
+  return false;
+}
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.h b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.h
new file mode 100644
index 000000000000..1bef13d458c1
--- /dev/null
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.h
@@ -0,0 +1,137 @@
+//===- SPIRVModuleAnalysis.h - analysis of global instrs & regs -*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// The analysis collects instructions that should be output at the module level
+// and performs the global register numbering.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIB_TARGET_SPIRV_SPIRVMODULEANALYSIS_H
+#define LLVM_LIB_TARGET_SPIRV_SPIRVMODULEANALYSIS_H
+
+#include "MCTargetDesc/SPIRVBaseInfo.h"
+#include "SPIRVSubtarget.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringMap.h"
+#include "llvm/ADT/StringRef.h"
+#include <map>
+
+namespace llvm {
+class MachineFunction;
+class MachineModuleInfo;
+
+namespace SPIRV {
+// The enum contains logical module sections for the instruction collection.
+enum ModuleSectionType {
+  //  MB_Capabilities, MB_Extensions, MB_ExtInstImports, MB_MemoryModel,
+  MB_EntryPoints, // All OpEntryPoint instructions (if any).
+  //  MB_ExecutionModes, MB_DebugSourceAndStrings,
+  MB_DebugNames,           // All OpName and OpMemberName instructions.
+  MB_DebugModuleProcessed, // All OpModuleProcessed instructions.
+  MB_Annotations,          // OpDecorate, OpMemberDecorate etc.
+  MB_TypeConstVars,        // OpTypeXXX, OpConstantXXX, and global OpVariables.
+  MB_ExtFuncDecls,         // OpFunction etc. to declare for external funcs.
+  NUM_MODULE_SECTIONS      // Total number of sections requiring basic blocks.
+};
+
+using InstrList = SmallVector<MachineInstr *>;
+// Maps a local register to the corresponding global alias.
+using LocalToGlobalRegTable = std::map<Register, Register>;
+using RegisterAliasMapTy =
+    std::map<const MachineFunction *, LocalToGlobalRegTable>;
+
+// The struct contains results of the module analysis and methods
+// to access them.
+struct ModuleAnalysisInfo {
+  // Module-wide memory model, addressing model and source language info
+  // (presumably filled in by SPIRVModuleAnalysis::setBaseInfo).
+  SPIRV::MemoryModel Mem;
+  SPIRV::AddressingModel Addr;
+  SPIRV::SourceLanguage SrcLang;
+  unsigned SrcLangVersion;
+  // Contains the list of all global OpVariables in the module.
+  SmallVector<MachineInstr *, 4> GlobalVarList;
+  // Maps function names to corresponding function ID registers.
+  StringMap<Register> FuncNameMap;
+  // The set contains machine instructions which are necessary
+  // for correct MIR but will not be emitted in function bodies.
+  DenseSet<MachineInstr *> InstrsToDelete;
+  // The table contains global aliases of local registers for each machine
+  // function. The aliases are used to substitute local registers during
+  // code emission.
+  RegisterAliasMapTy RegisterAliasTable;
+  // The counter holds the maximum ID we have in the module.
+  unsigned MaxID;
+  // The array contains lists of MIs for each module section.
+  InstrList MS[NUM_MODULE_SECTIONS];
+  // The table maps MBB number to SPIR-V unique ID register.
+  DenseMap<int, Register> BBNumToRegMap;
+
+  // Returns the ID register recorded for FuncName. The function must have
+  // been recorded (asserted).
+  Register getFuncReg(StringRef FuncName) {
+    auto FuncReg = FuncNameMap.find(FuncName);
+    assert(FuncReg != FuncNameMap.end() && "Cannot find function Id");
+    return FuncReg->second;
+  }
+  InstrList &getMSInstrs(unsigned MSType) { return MS[MSType]; }
+  void setSkipEmission(MachineInstr *MI) { InstrsToDelete.insert(MI); }
+  bool getSkipEmission(const MachineInstr *MI) {
+    return InstrsToDelete.contains(MI);
+  }
+  void setRegisterAlias(const MachineFunction *MF, Register Reg,
+                        Register AliasReg) {
+    RegisterAliasTable[MF][Reg] = AliasReg;
+  }
+  // Returns the global alias recorded for Reg in MF, or Register(0) if there
+  // is none. Uses find() so that querying an unknown MF does not insert an
+  // empty table for it.
+  Register getRegisterAlias(const MachineFunction *MF, Register Reg) {
+    auto FuncIt = RegisterAliasTable.find(MF);
+    if (FuncIt == RegisterAliasTable.end())
+      return Register(0);
+    auto RegIt = FuncIt->second.find(Reg);
+    if (RegIt == FuncIt->second.end())
+      return Register(0);
+    return RegIt->second;
+  }
+  // Returns true if a global alias has been recorded for Reg in MF.
+  bool hasRegisterAlias(const MachineFunction *MF, Register Reg) {
+    auto FuncIt = RegisterAliasTable.find(MF);
+    return FuncIt != RegisterAliasTable.end() &&
+           FuncIt->second.find(Reg) != FuncIt->second.end();
+  }
+  // Returns the next free module-wide ID and advances the counter.
+  unsigned getNextID() { return MaxID++; }
+  bool hasMBBRegister(const MachineBasicBlock &MBB) {
+    return BBNumToRegMap.find(MBB.getNumber()) != BBNumToRegMap.end();
+  }
+  // Convert MBB's number to corresponding ID register, allocating a fresh
+  // module-wide ID the first time MBB is seen.
+  Register getOrCreateMBBRegister(const MachineBasicBlock &MBB) {
+    auto It = BBNumToRegMap.find(MBB.getNumber());
+    if (It != BBNumToRegMap.end())
+      return It->second;
+    Register NewReg = Register::index2VirtReg(getNextID());
+    BBNumToRegMap[MBB.getNumber()] = NewReg;
+    return NewReg;
+  }
+};
+} // namespace SPIRV
+
+// Module pass that collects module-level instructions and numbers registers
+// globally; its results are stored in the static MAI member.
+struct SPIRVModuleAnalysis : public ModulePass {
+  static char ID;
+
+public:
+  SPIRVModuleAnalysis() : ModulePass(ID) {}
+
+  bool runOnModule(Module &M) override;
+  void getAnalysisUsage(AnalysisUsage &AU) const override;
+  static struct SPIRV::ModuleAnalysisInfo MAI;
+
+private:
+  void setBaseInfo(const Module &M);
+  template <typename T> void collectTypesConstsVars();
+  void processDefInstrs(const Module &M);
+  void collectFuncNames(MachineInstr &MI, const Function &F);
+  void processOtherInstrs(const Module &M);
+  void numberRegistersGlobally(const Module &M);
+
+  const SPIRVSubtarget *ST;
+  SPIRVGlobalRegistry *GR;
+  const SPIRVInstrInfo *TII;
+  MachineModuleInfo *MMI;
+};
+} // namespace llvm
+#endif // LLVM_LIB_TARGET_SPIRV_SPIRVMODULEANALYSIS_H
diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
new file mode 100644
index 000000000000..687f84046650
--- /dev/null
+++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
@@ -0,0 +1,440 @@
+//===-- SPIRVPreLegalizer.cpp - prepare IR for legalization -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// The pass prepares IR for legalization: it assigns SPIR-V types to registers
+// and removes intrinsics which held these types during IR translation.
+// Also it processes constants and registers them in GR to avoid duplication.
+//
+//===----------------------------------------------------------------------===//
+
+#include "SPIRV.h"
+#include "SPIRVGlobalRegistry.h"
+#include "SPIRVSubtarget.h"
+#include "SPIRVUtils.h"
+#include "llvm/ADT/PostOrderIterator.h"
+#include "llvm/Analysis/OptimizationRemarkEmitter.h"
+#include "llvm/IR/Attributes.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/DebugInfoMetadata.h"
+#include "llvm/IR/IntrinsicsSPIRV.h"
+#include "llvm/Target/TargetIntrinsicInfo.h"
+
+#define DEBUG_TYPE "spirv-prelegalizer"
+
+using namespace llvm;
+
+namespace {
+class SPIRVPreLegalizer : public MachineFunctionPass {
+public:
+  static char ID;
+  SPIRVPreLegalizer() : MachineFunctionPass(ID) {
+    initializeSPIRVPreLegalizerPass(*PassRegistry::getPassRegistry());
+  }
+  bool runOnMachineFunction(MachineFunction &MF) override;
+};
+} // namespace
+
+// Returns true if MI is a G_INTRINSIC_W_SIDE_EFFECTS call of IntrinsicID.
+static bool isSpvIntrinsic(MachineInstr &MI, Intrinsic::ID IntrinsicID) {
+  return MI.getOpcode() == TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS &&
+         MI.getIntrinsicID() == IntrinsicID;
+}
+
+// For every spv_assign_name intrinsic, replaces the register operands that
+// follow the first AssignNameOperandShift operands after the defs (presumably
+// the intrinsic ID and the named value) with immediates holding the
+// zero-extended value of the defining G_CONSTANT. Constants left without any
+// use are erased afterwards.
+static void foldConstantsIntoIntrinsics(MachineFunction &MF) {
+  SmallVector<MachineInstr *, 10> ToErase;
+  MachineRegisterInfo &MRI = MF.getRegInfo();
+  const unsigned AssignNameOperandShift = 2;
+  for (MachineBasicBlock &MBB : MF) {
+    for (MachineInstr &MI : MBB) {
+      if (!isSpvIntrinsic(MI, Intrinsic::spv_assign_name))
+        continue;
+      unsigned NumOp = MI.getNumExplicitDefs() + AssignNameOperandShift;
+      // Each iteration removes the register at NumOp and appends an
+      // immediate, so the loop ends once a non-register operand (at the
+      // latest, the first appended immediate) reaches NumOp. The bound check
+      // covers an intrinsic with no operands past the shift, where
+      // getOperand(NumOp) would be out of range.
+      while (NumOp < MI.getNumOperands() && MI.getOperand(NumOp).isReg()) {
+        MachineOperand &MOp = MI.getOperand(NumOp);
+        MachineInstr *ConstMI = MRI.getVRegDef(MOp.getReg());
+        assert(ConstMI->getOpcode() == TargetOpcode::G_CONSTANT);
+        MI.removeOperand(NumOp);
+        MI.addOperand(MachineOperand::CreateImm(
+            ConstMI->getOperand(1).getCImm()->getZExtValue()));
+        if (MRI.use_empty(ConstMI->getOperand(0).getReg()))
+          ToErase.push_back(ConstMI);
+      }
+    }
+  }
+  for (MachineInstr *MI : ToErase)
+    MI->eraseFromParent();
+}
+
+// Replaces every spv_bitcast intrinsic with a bitcast (MIB.buildBitcast) of
+// its operand 2 into its operand 0, then erases the intrinsics.
+static void insertBitcasts(MachineFunction &MF, SPIRVGlobalRegistry *GR,
+                           MachineIRBuilder MIB) {
+  SmallVector<MachineInstr *, 10> ToErase;
+  for (MachineBasicBlock &MBB : MF) {
+    for (MachineInstr &MI : MBB) {
+      if (!isSpvIntrinsic(MI, Intrinsic::spv_bitcast))
+        continue;
+      assert(MI.getOperand(2).isReg());
+      MIB.setInsertPt(*MI.getParent(), MI);
+      MIB.buildBitcast(MI.getOperand(0).getReg(), MI.getOperand(2).getReg());
+      ToErase.push_back(&MI);
+    }
+  }
+  for (MachineInstr *MI : ToErase)
+    MI->eraseFromParent();
+}
+
+// Translating GV, IRTranslator sometimes generates following IR:
+//   %1 = G_GLOBAL_VALUE
+//   %2 = COPY %1
+//   %3 = G_ADDRSPACE_CAST %2
+// New registers have no SPIRVType and no register class info.
+//
+// Set SPIRVType for GV, propagate it from GV to other instructions,
+// also set register classes.
+static SPIRVType *propagateSPIRVType(MachineInstr *MI, SPIRVGlobalRegistry *GR,
+                                     MachineRegisterInfo &MRI,
+                                     MachineIRBuilder &MIB) {
+  SPIRVType *SpirvTy = nullptr;
+  assert(MI && "Machine instr is expected");
+  if (MI->getOperand(0).isReg()) {
+    Register Reg = MI->getOperand(0).getReg();
+    SpirvTy = GR->getSPIRVTypeForVReg(Reg);
+    if (!SpirvTy) {
+      switch (MI->getOpcode()) {
+      case TargetOpcode::G_CONSTANT: {
+        MIB.setInsertPt(*MI->getParent(), MI);
+        Type *Ty = MI->getOperand(1).getCImm()->getType();
+        SpirvTy = GR->getOrCreateSPIRVType(Ty, MIB);
+        break;
+      }
+      case TargetOpcode::G_GLOBAL_VALUE: {
+        MIB.setInsertPt(*MI->getParent(), MI);
+        Type *Ty = MI->getOperand(1).getGlobal()->getType();
+        SpirvTy = GR->getOrCreateSPIRVType(Ty, MIB);
+        break;
+      }
+      case TargetOpcode::G_TRUNC:
+      case TargetOpcode::G_ADDRSPACE_CAST:
+      case TargetOpcode::COPY: {
+        MachineOperand &Op = MI->getOperand(1);
+        MachineInstr *Def = Op.isReg() ? MRI.getVRegDef(Op.getReg()) : nullptr;
+        if (Def)
+          SpirvTy = propagateSPIRVType(Def, GR, MRI, MIB);
+        break;
+      }
+      default:
+        break;
+      }
+      if (SpirvTy)
+        GR->assignSPIRVTypeToVReg(SpirvTy, Reg, MIB.getMF());
+      if (!MRI.getRegClassOrNull(Reg))
+        MRI.setRegClass(Reg, &SPIRV::IDRegClass);
+    }
+  }
+  return SpirvTy;
+}
+
+// Insert ASSIGN_TYPE instruction between Reg and its definition, set NewReg as
+// a dst of the definition, assign SPIRVType to both registers. If SpirvTy is
+// provided, use it as SPIRVType in ASSIGN_TYPE, otherwise create it from Ty.
+// TODO: maybe move to SPIRVUtils.
+static Register insertAssignInstr(Register Reg, Type *Ty, SPIRVType *SpirvTy,
+                                  SPIRVGlobalRegistry *GR,
+                                  MachineIRBuilder &MIB,
+                                  MachineRegisterInfo &MRI) {
+  MachineInstr *Def = MRI.getVRegDef(Reg);
+  assert((Ty || SpirvTy) && "Either LLVM or SPIRV type is expected.");
+  MIB.setInsertPt(*Def->getParent(),
+                  (Def->getNextNode() ? Def->getNextNode()->getIterator()
+                                      : Def->getParent()->end()));
+  Register NewReg = MRI.createGenericVirtualRegister(MRI.getType(Reg));
+  if (auto *RC = MRI.getRegClassOrNull(Reg))
+    MRI.setRegClass(NewReg, RC);
+  SpirvTy = SpirvTy ? SpirvTy : GR->getOrCreateSPIRVType(Ty, MIB);
+  GR->assignSPIRVTypeToVReg(SpirvTy, Reg, MIB.getMF());
+  // This is to make it convenient for Legalizer to get the SPIRVType
+  // when processing the actual MI (i.e. not pseudo one).
+  GR->assignSPIRVTypeToVReg(SpirvTy, NewReg, MIB.getMF());
+  MIB.buildInstr(SPIRV::ASSIGN_TYPE)
+      .addDef(Reg)
+      .addUse(NewReg)
+      .addUse(GR->getSPIRVTypeID(SpirvTy));
+  Def->getOperand(0).setReg(NewReg);
+  MRI.setRegClass(Reg, &SPIRV::ANYIDRegClass);
+  return NewReg;
+}
+
+// Returns true if Reg has exactly one use and that use is an spv_assign_type
+// or spv_assign_name intrinsic.
+static bool isOnlyUsedByAssignIntrinsic(Register Reg,
+                                        MachineRegisterInfo &MRI) {
+  if (!MRI.hasOneUse(Reg))
+    return false;
+  MachineInstr &UseMI = *MRI.use_instr_begin(Reg);
+  return isSpvIntrinsic(UseMI, Intrinsic::spv_assign_type) ||
+         isSpvIntrinsic(UseMI, Intrinsic::spv_assign_name);
+}
+
+// Returns the LLVM IR type of the value produced by MI, which must be a
+// G_CONSTANT, a G_FCONSTANT or a G_BUILD_VECTOR whose first element is
+// defined by one of those two.
+static Type *getLLVMTypeOfConstantInstr(MachineInstr &MI,
+                                        MachineRegisterInfo &MRI) {
+  if (MI.getOpcode() == TargetOpcode::G_CONSTANT)
+    return MI.getOperand(1).getCImm()->getType();
+  if (MI.getOpcode() == TargetOpcode::G_FCONSTANT)
+    return MI.getOperand(1).getFPImm()->getType();
+  assert(MI.getOpcode() == TargetOpcode::G_BUILD_VECTOR);
+  Type *ElemTy = nullptr;
+  MachineInstr *ElemMI = MRI.getVRegDef(MI.getOperand(1).getReg());
+  assert(ElemMI);
+
+  if (ElemMI->getOpcode() == TargetOpcode::G_CONSTANT)
+    ElemTy = ElemMI->getOperand(1).getCImm()->getType();
+  else if (ElemMI->getOpcode() == TargetOpcode::G_FCONSTANT)
+    ElemTy = ElemMI->getOperand(1).getFPImm()->getType();
+  else
+    llvm_unreachable("Unexpected opcode");
+  unsigned NumElts = MI.getNumExplicitOperands() - MI.getNumExplicitDefs();
+  return VectorType::get(ElemTy, NumElts, false);
+}
+
+// Visits the blocks in post-order, and each block from its last instruction
+// to its first, and:
+//  - lowers spv_assign_type intrinsics to ASSIGN_TYPE (except when the value
+//    is defined by G_GLOBAL_VALUE) and erases the intrinsics at the end;
+//  - wraps G_CONSTANT/G_FCONSTANT/G_BUILD_VECTOR results in ASSIGN_TYPE;
+//  - propagates SPIR-V types through G_TRUNC, G_GLOBAL_VALUE, COPY and
+//    G_ADDRSPACE_CAST.
+static void generateAssignInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
+                                 MachineIRBuilder MIB) {
+  MachineRegisterInfo &MRI = MF.getRegInfo();
+  SmallVector<MachineInstr *, 10> ToErase;
+
+  for (MachineBasicBlock *MBB : post_order(&MF)) {
+    if (MBB->empty())
+      continue;
+
+    bool ReachedBegin = false;
+    // NOTE: the iterator is stepped only at the bottom of the loop body, so
+    // the body must never `continue`, or the same MI is revisited forever.
+    for (auto MII = std::prev(MBB->end()), Begin = MBB->begin();
+         !ReachedBegin;) {
+      MachineInstr &MI = *MII;
+
+      if (isSpvIntrinsic(MI, Intrinsic::spv_assign_type)) {
+        Register Reg = MI.getOperand(1).getReg();
+        Type *Ty = getMDOperandAsType(MI.getOperand(2).getMetadata(), 0);
+        MachineInstr *Def = MRI.getVRegDef(Reg);
+        assert(Def && "Expecting an instruction that defines the register");
+        // G_GLOBAL_VALUE already has type info.
+        if (Def->getOpcode() != TargetOpcode::G_GLOBAL_VALUE)
+          insertAssignInstr(Reg, Ty, nullptr, GR, MIB, MF.getRegInfo());
+        ToErase.push_back(&MI);
+      } else if (MI.getOpcode() == TargetOpcode::G_CONSTANT ||
+                 MI.getOpcode() == TargetOpcode::G_FCONSTANT ||
+                 MI.getOpcode() == TargetOpcode::G_BUILD_VECTOR) {
+        // %rc = G_CONSTANT ty Val
+        // ===>
+        // %cty = OpType* ty
+        // %rctmp = G_CONSTANT ty Val
+        // %rc = ASSIGN_TYPE %rctmp, %cty
+        //
+        // Constants whose only user is an spv_assign_type/spv_assign_name
+        // intrinsic are left alone (guarded rather than `continue`d, see the
+        // note on the loop above).
+        Register Reg = MI.getOperand(0).getReg();
+        if (!isOnlyUsedByAssignIntrinsic(Reg, MRI)) {
+          Type *Ty = getLLVMTypeOfConstantInstr(MI, MRI);
+          insertAssignInstr(Reg, Ty, nullptr, GR, MIB, MRI);
+        }
+      } else if (MI.getOpcode() == TargetOpcode::G_TRUNC ||
+                 MI.getOpcode() == TargetOpcode::G_GLOBAL_VALUE ||
+                 MI.getOpcode() == TargetOpcode::COPY ||
+                 MI.getOpcode() == TargetOpcode::G_ADDRSPACE_CAST) {
+        propagateSPIRVType(&MI, GR, MRI, MIB);
+      }
+
+      if (MII == Begin)
+        ReachedBegin = true;
+      else
+        --MII;
+    }
+  }
+  for (MachineInstr *MI : ToErase)
+    MI->eraseFromParent();
+}
+
+// Creates a new ID virtual register to stand in for ValReg and returns it
+// together with the GET_*ID opcode that produces it. The LLT and register
+// class follow ValReg: pointer -> p0 32-bit / pID, vector -> <2 x s32> /
+// vID or vfID, otherwise s32 / ID or fID (float-ness taken from ValReg's
+// SPIR-V type). Opcode is currently unused.
+static std::pair<Register, unsigned>
+createNewIdReg(Register ValReg, unsigned Opcode, MachineRegisterInfo &MRI,
+               const SPIRVGlobalRegistry &GR) {
+  LLT NewT = LLT::scalar(32);
+  SPIRVType *SpvType = GR.getSPIRVTypeForVReg(ValReg);
+  assert(SpvType && "VReg is expected to have SPIRV type");
+  bool IsFloat = SpvType->getOpcode() == SPIRV::OpTypeFloat;
+  bool IsVectorFloat =
+      SpvType->getOpcode() == SPIRV::OpTypeVector &&
+      GR.getSPIRVTypeForVReg(SpvType->getOperand(1).getReg())->getOpcode() ==
+          SPIRV::OpTypeFloat;
+  IsFloat |= IsVectorFloat;
+  auto GetIdOp = IsFloat ? SPIRV::GET_fID : SPIRV::GET_ID;
+  auto DstClass = IsFloat ? &SPIRV::fIDRegClass : &SPIRV::IDRegClass;
+  if (MRI.getType(ValReg).isPointer()) {
+    NewT = LLT::pointer(0, 32);
+    GetIdOp = SPIRV::GET_pID;
+    DstClass = &SPIRV::pIDRegClass;
+  } else if (MRI.getType(ValReg).isVector()) {
+    NewT = LLT::fixed_vector(2, NewT);
+    GetIdOp = IsFloat ? SPIRV::GET_vfID : SPIRV::GET_vID;
+    DstClass = IsFloat ? &SPIRV::vfIDRegClass : &SPIRV::vIDRegClass;
+  }
+  Register IdReg = MRI.createGenericVirtualRegister(NewT);
+  MRI.setRegClass(IdReg, DstClass);
+  return {IdReg, GetIdOp};
+}
+
+// MI must have a def with exactly one use (assumed to be its ASSIGN_TYPE).
+// Rewrites MI's def, and that user's operand 1, to a new ID register. Each
+// register use of MI is replaced by a new ID register defined by a GET_*ID
+// instruction inserted after MI.
+static void processInstr(MachineInstr &MI, MachineIRBuilder &MIB,
+                         MachineRegisterInfo &MRI, SPIRVGlobalRegistry *GR) {
+  unsigned Opc = MI.getOpcode();
+  assert(MI.getNumDefs() > 0 && MRI.hasOneUse(MI.getOperand(0).getReg()));
+  MachineInstr &AssignTypeInst =
+      *(MRI.use_instr_begin(MI.getOperand(0).getReg()));
+  auto NewReg = createNewIdReg(MI.getOperand(0).getReg(), Opc, MRI, *GR).first;
+  AssignTypeInst.getOperand(1).setReg(NewReg);
+  MI.getOperand(0).setReg(NewReg);
+  MIB.setInsertPt(*MI.getParent(),
+                  (MI.getNextNode() ? MI.getNextNode()->getIterator()
+                                    : MI.getParent()->end()));
+  for (auto &Op : MI.operands()) {
+    if (!Op.isReg() || Op.isDef())
+      continue;
+    auto IdOpInfo = createNewIdReg(Op.getReg(), Opc, MRI, *GR);
+    MIB.buildInstr(IdOpInfo.second).addDef(IdOpInfo.first).addUse(Op.getReg());
+    Op.setReg(IdOpInfo.first);
+  }
+}
+
+// Defined in SPIRVLegalizerInfo.cpp.
+extern bool isTypeFoldingSupported(unsigned Opcode);
+
+// Applies processInstr to every instruction whose opcode supports type
+// folding.
+static void processInstrsWithTypeFolding(MachineFunction &MF,
+                                         SPIRVGlobalRegistry *GR,
+                                         MachineIRBuilder MIB) {
+  MachineRegisterInfo &MRI = MF.getRegInfo();
+  for (MachineBasicBlock &MBB : MF) {
+    for (MachineInstr &MI : MBB) {
+      if (isTypeFoldingSupported(MI.getOpcode()))
+        processInstr(MI, MIB, MRI, GR);
+    }
+  }
+}
+
+static void processSwitches(MachineFunction &MF, SPIRVGlobalRegistry *GR,
+                            MachineIRBuilder MIB) {
+  DenseMap<Register, SmallDenseMap<uint64_t, MachineBasicBlock *>>
+      SwitchRegToMBB;
+  DenseMap<Register, MachineBasicBlock *> DefaultMBBs;
+  DenseSet<Register> SwitchRegs;
+  MachineRegisterInfo &MRI = MF.getRegInfo();
+  // Before IRTranslator pass, spv_switch calls are inserted before each
+  // switch instruction. IRTranslator lowers switches to ICMP+CBr+Br triples.
+  // A switch with two cases may be translated to this MIR sequence:
+  //   intrinsic(@llvm.spv.switch), %CmpReg, %Const0, %Const1
+  //   %Dst0 = G_ICMP intpred(eq), %CmpReg, %Const0
+  //   G_BRCOND %Dst0, %bb.2
+  //   G_BR %bb.5
+  // bb.5.entry:
+  //   %Dst1 = G_ICMP intpred(eq), %CmpReg, %Const1
+  //   G_BRCOND %Dst1, %bb.3
+  //   G_BR %bb.4
+  // bb.2.sw.bb:
+  //   ...
+  // bb.3.sw.bb1:
+  //   ...
+  // bb.4.sw.epilog:
+  //   ...
+  // Walk MIs and collect information about destination MBBs to update
+  // spv_switch call. We assume that all spv_switch precede corresponding ICMPs.
+  for (MachineBasicBlock &MBB : MF) {
+    for (MachineInstr &MI : MBB) {
+      if (isSpvIntrinsic(MI, Intrinsic::spv_switch)) {
+        assert(MI.getOperand(1).isReg());
+        Register Reg = MI.getOperand(1).getReg();
+        SwitchRegs.insert(Reg);
+        // Set the first successor as default MBB to support empty switches.
+        assert(!MBB.succ_empty() && "spv_switch block has no successors");
+        DefaultMBBs[Reg] = *MBB.succ_begin();
+      }
+      // Process only ICMPs that relate to spv_switches.
+      if (MI.getOpcode() == TargetOpcode::G_ICMP && MI.getOperand(2).isReg() &&
+          SwitchRegs.contains(MI.getOperand(2).getReg())) {
+        assert(MI.getOperand(0).isReg() && MI.getOperand(1).isPredicate() &&
+               MI.getOperand(3).isReg());
+        Register Dst = MI.getOperand(0).getReg();
+        // Set type info for destination register of switch's ICMP instruction.
+        if (GR->getSPIRVTypeForVReg(Dst) == nullptr) {
+          MIB.setInsertPt(*MI.getParent(), MI);
+          Type *LLVMTy = IntegerType::get(MF.getFunction().getContext(), 1);
+          SPIRVType *SpirvTy = GR->getOrCreateSPIRVType(LLVMTy, MIB);
+          MRI.setRegClass(Dst, &SPIRV::IDRegClass);
+          GR->assignSPIRVTypeToVReg(SpirvTy, Dst, MIB.getMF());
+        }
+        Register CmpReg = MI.getOperand(2).getReg();
+        MachineOperand &PredOp = MI.getOperand(1);
+        // Only read by the assert below; avoid an unused-variable warning in
+        // builds without assertions.
+        [[maybe_unused]] const auto CC =
+            static_cast<CmpInst::Predicate>(PredOp.getPredicate());
+        assert(CC == CmpInst::ICMP_EQ && MRI.hasOneUse(Dst) &&
+               MRI.hasOneDef(CmpReg));
+        uint64_t Val = getIConstVal(MI.getOperand(3).getReg(), &MRI);
+        MachineInstr *CBr = MRI.use_begin(Dst)->getParent();
+        assert(CBr->getOpcode() == SPIRV::G_BRCOND &&
+               CBr->getOperand(1).isMBB());
+        SwitchRegToMBB[CmpReg][Val] = CBr->getOperand(1).getMBB();
+        // The next MI is always BR to either the next case or the default.
+        MachineInstr *NextMI = CBr->getNextNode();
+        assert(NextMI->getOpcode() == SPIRV::G_BR &&
+               NextMI->getOperand(0).isMBB());
+        MachineBasicBlock *NextMBB = NextMI->getOperand(0).getMBB();
+        assert(NextMBB != nullptr);
+        // The default MBB is not started by ICMP with switch's cmp register.
+        if (NextMBB->front().getOpcode() != SPIRV::G_ICMP ||
+            (NextMBB->front().getOperand(2).isReg() &&
+             NextMBB->front().getOperand(2).getReg() != CmpReg))
+          DefaultMBBs[CmpReg] = NextMBB;
+      }
+    }
+  }
+  // Modify spv_switch's operands by collected values. For the example above,
+  // the result will be like this:
+  //   intrinsic(@llvm.spv.switch), %CmpReg, %bb.4, i32 0, %bb.2, i32 1, %bb.3
+  // Note that ICMP+CBr+Br sequences are not removed, but ModuleAnalysis marks
+  // them as skipped and AsmPrinter does not output them.
+  for (MachineBasicBlock &MBB : MF) {
+    for (MachineInstr &MI : MBB) {
+      if (!isSpvIntrinsic(MI, Intrinsic::spv_switch))
+        continue;
+      assert(MI.getOperand(1).isReg());
+      Register Reg = MI.getOperand(1).getReg();
+      unsigned NumOp = MI.getNumExplicitOperands();
+      SmallVector<const ConstantInt *, 3> Vals;
+      SmallVector<MachineBasicBlock *, 3> MBBs;
+      for (unsigned i = 2; i < NumOp; i++) {
+        Register CReg = MI.getOperand(i).getReg();
+        uint64_t Val = getIConstVal(CReg, &MRI);
+        MachineInstr *ConstInstr = getDefInstrMaybeConstant(CReg, &MRI);
+        Vals.push_back(ConstInstr->getOperand(1).getCImm());
+        MBBs.push_back(SwitchRegToMBB[Reg][Val]);
+      }
+      for (unsigned i = MI.getNumExplicitOperands() - 1; i > 1; i--)
+        MI.removeOperand(i);
+      MI.addOperand(MachineOperand::CreateMBB(DefaultMBBs[Reg]));
+      for (unsigned i = 0; i < Vals.size(); i++) {
+        MI.addOperand(MachineOperand::CreateCImm(Vals[i]));
+        MI.addOperand(MachineOperand::CreateMBB(MBBs[i]));
+      }
+    }
+  }
+}
+
+bool SPIRVPreLegalizer::runOnMachineFunction(MachineFunction &MF) {
+  // Initialize the type registry.
+  const SPIRVSubtarget &ST = MF.getSubtarget<SPIRVSubtarget>();
+  SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry();
+  GR->setCurrentFunc(MF);
+  MachineIRBuilder MIB(MF);
+  foldConstantsIntoIntrinsics(MF);
+  insertBitcasts(MF, GR, MIB);
+  generateAssignInstrs(MF, GR, MIB);
+  processInstrsWithTypeFolding(MF, GR, MIB);
+  processSwitches(MF, GR, MIB);
+
+  return true;
+}
+
+INITIALIZE_PASS(SPIRVPreLegalizer, DEBUG_TYPE, "SPIRV pre legalizer", false,
+                false)
+
+char SPIRVPreLegalizer::ID = 0;
+
+FunctionPass *llvm::createSPIRVPreLegalizerPass() {
+  return new SPIRVPreLegalizer();
+}
diff --git a/llvm/lib/Target/SPIRV/SPIRVRegisterBankInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVRegisterBankInfo.cpp
new file mode 100644
index 000000000000..9bf9d7fe5b39
--- /dev/null
+++ b/llvm/lib/Target/SPIRV/SPIRVRegisterBankInfo.cpp
@@ -0,0 +1,47 @@
+//===- SPIRVRegisterBankInfo.cpp ------------------------------*- C++ -*---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the targeting of the RegisterBankInfo class for SPIR-V.
+//
+//===----------------------------------------------------------------------===//
+
+#include "SPIRVRegisterBankInfo.h"
+#include "SPIRVRegisterInfo.h"
+#include "llvm/CodeGen/RegisterBank.h"
+
+#define GET_REGINFO_ENUM
+#include "SPIRVGenRegisterInfo.inc"
+
+#define GET_TARGET_REGBANK_IMPL
+#include "SPIRVGenRegisterBank.inc"
+
+using namespace llvm;
+
+// This is required for .td selection patterns to work, or we'd end up with
+// RegClass checks being redundant as all the classes would be mapped to the
+// same bank.
+const RegisterBank &
+SPIRVRegisterBankInfo::getRegBankFromRegClass(const TargetRegisterClass &RC,
+                                              LLT Ty) const {
+  // Each SPIR-V register class resolves to exactly one bank. The scalar ID
+  // bank also serves pointer IDs and the catch-all ANYID/ANY classes.
+  switch (RC.getID()) {
+  case SPIRV::IDRegClassID:
+  case SPIRV::pIDRegClassID:
+  case SPIRV::ANYIDRegClassID:
+  case SPIRV::ANYRegClassID:
+    return SPIRV::IDRegBank;
+  case SPIRV::fIDRegClassID:
+    return SPIRV::fIDRegBank;
+  case SPIRV::vIDRegClassID:
+    return SPIRV::vIDRegBank;
+  case SPIRV::vfIDRegClassID:
+    return SPIRV::vfIDRegBank;
+  case SPIRV::TYPERegClassID:
+    return SPIRV::TYPERegBank;
+  default:
+    llvm_unreachable("Unknown register class");
+  }
+}
diff --git a/llvm/lib/Target/SPIRV/SPIRVRegisterBankInfo.h b/llvm/lib/Target/SPIRV/SPIRVRegisterBankInfo.h
new file mode 100644
index 000000000000..67ddcdefb7dd
--- /dev/null
+++ b/llvm/lib/Target/SPIRV/SPIRVRegisterBankInfo.h
@@ -0,0 +1,38 @@
+//===- SPIRVRegisterBankInfo.h -----------------------------------*- C++ -*-==//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the targeting of the RegisterBankInfo class for SPIR-V.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIB_TARGET_SPIRV_SPIRVREGISTERBANKINFO_H
+#define LLVM_LIB_TARGET_SPIRV_SPIRVREGISTERBANKINFO_H
+
+#include "llvm/CodeGen/RegisterBankInfo.h"
+
+#define GET_REGBANK_DECLARATIONS
+#include "SPIRVGenRegisterBank.inc"
+
+namespace llvm {
+
+class TargetRegisterInfo;
+
+class SPIRVGenRegisterBankInfo : public RegisterBankInfo {
+protected:
+#define GET_TARGET_REGBANK_CLASS
+#include "SPIRVGenRegisterBank.inc"
+};
+
+// This class provides the information for the target register banks.
+class SPIRVRegisterBankInfo final : public SPIRVGenRegisterBankInfo {
+public:
+  // Returns the register bank used for registers of class RC; the mapping is
+  // implemented in SPIRVRegisterBankInfo.cpp.
+  const RegisterBank &getRegBankFromRegClass(const TargetRegisterClass &RC,
+                                             LLT Ty) const override;
+};
+} // namespace llvm
+#endif // LLVM_LIB_TARGET_SPIRV_SPIRVREGISTERBANKINFO_H
diff --git a/llvm/lib/Target/SPIRV/SPIRVRegisterBanks.td b/llvm/lib/Target/SPIRV/SPIRVRegisterBanks.td
new file mode 100644
index 000000000000..90c7f3a6e672
--- /dev/null
+++ b/llvm/lib/Target/SPIRV/SPIRVRegisterBanks.td
@@ -0,0 +1,15 @@
+//===-- SPIRVRegisterBanks.td - Describe SPIR-V RegBanks ---*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+// Although RegisterBankSelection is disabled, we need to distinguish the banks
+// because the InstructionSelector's RegClass checking code relies on them.
+// The pID, ANYID and ANY register classes have no bank of their own;
+// SPIRVRegisterBankInfo::getRegBankFromRegClass maps them to IDRegBank.
+def IDRegBank : RegisterBank<"IDBank", [ID]>;
+def fIDRegBank : RegisterBank<"fIDBank", [fID]>;
+def vIDRegBank : RegisterBank<"vIDBank", [vID]>;
+def vfIDRegBank : RegisterBank<"vfIDBank", [vfID]>;
+def TYPERegBank : RegisterBank<"TYPEBank", [TYPE]>;
diff --git a/llvm/lib/Target/SPIRV/SPIRVRegisterInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVRegisterInfo.cpp
new file mode 100644
index 000000000000..cf8a967d59c4
--- /dev/null
+++ b/llvm/lib/Target/SPIRV/SPIRVRegisterInfo.cpp
@@ -0,0 +1,32 @@
+//===-- SPIRVRegisterInfo.cpp - SPIR-V Register Information -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains the SPIR-V implementation of the TargetRegisterInfo class.
+//
+//===----------------------------------------------------------------------===//
+
+#include "SPIRVRegisterInfo.h"
+#include "SPIRV.h"
+#include "SPIRVSubtarget.h"
+#include "llvm/CodeGen/MachineFunction.h"
+
+#define GET_REGINFO_TARGET_DESC
+#include "SPIRVGenRegisterInfo.inc"
+using namespace llvm;
+
+// SPIRV::ID0 is handed to the generated base class constructor (presumably
+// as its return-address register argument; SPIR-V has no such register).
+SPIRVRegisterInfo::SPIRVRegisterInfo() : SPIRVGenRegisterInfo(SPIRV::ID0) {}
+
+// Nothing is reserved: an all-clear bit vector the size of the register file.
+BitVector SPIRVRegisterInfo::getReservedRegs(const MachineFunction &MF) const {
+  BitVector Reserved(getNumRegs());
+  return Reserved;
+}
+
+// There are no callee-saved registers: the list only holds its 0 terminator.
+const MCPhysReg *
+SPIRVRegisterInfo::getCalleeSavedRegs(const MachineFunction *MF) const {
+  static const MCPhysReg NoCalleeSavedRegs[] = {0};
+  return NoCalleeSavedRegs;
+}
diff --git a/llvm/lib/Target/SPIRV/SPIRVRegisterInfo.h b/llvm/lib/Target/SPIRV/SPIRVRegisterInfo.h
new file mode 100644
index 000000000000..f6f22b81e0bc
--- /dev/null
+++ b/llvm/lib/Target/SPIRV/SPIRVRegisterInfo.h
@@ -0,0 +1,36 @@
+//===-- SPIRVRegisterInfo.h - SPIR-V Register Information -------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains the SPIR-V implementation of the TargetRegisterInfo class.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIB_TARGET_SPIRV_SPIRVREGISTERINFO_H
+#define LLVM_LIB_TARGET_SPIRV_SPIRVREGISTERINFO_H
+
+#include "llvm/CodeGen/TargetRegisterInfo.h"
+
+#define GET_REGINFO_HEADER
+#include "SPIRVGenRegisterInfo.inc"
+
+namespace llvm {
+
+// Register information for SPIR-V. The register file only contains one dummy
+// register per register class (see SPIRVRegisterInfo.td), so the hooks below
+// are trivial.
+struct SPIRVRegisterInfo : public SPIRVGenRegisterInfo {
+  SPIRVRegisterInfo();
+  // Returns an empty (0-terminated) callee-saved register list.
+  const MCPhysReg *getCalleeSavedRegs(const MachineFunction *MF) const override;
+  // Returns a bit vector with no register marked as reserved.
+  BitVector getReservedRegs(const MachineFunction &MF) const override;
+  // Frame-index elimination is a no-op.
+  void eliminateFrameIndex(MachineBasicBlock::iterator MI, int SPAdj,
+                           unsigned FIOperandNum,
+                           RegScavenger *RS = nullptr) const override {}
+  // There is no frame register; 0 is returned.
+  Register getFrameRegister(const MachineFunction &MF) const override {
+    return 0;
+  }
+};
+} // namespace llvm
+
+#endif // LLVM_LIB_TARGET_SPIRV_SPIRVREGISTERINFO_H
diff --git a/llvm/lib/Target/SPIRV/SPIRVRegisterInfo.td b/llvm/lib/Target/SPIRV/SPIRVRegisterInfo.td
new file mode 100644
index 000000000000..d0b64b6895d0
--- /dev/null
+++ b/llvm/lib/Target/SPIRV/SPIRVRegisterInfo.td
@@ -0,0 +1,39 @@
+//===-- SPIRVRegisterInfo.td - SPIR-V Register defs --------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Declarations that describe the SPIR-V register file.
+//
+//===----------------------------------------------------------------------===//
+
+let Namespace = "SPIRV" in {
+  def p0 : PtrValueType <i32, 0>;
+  // All registers are for 32-bit identifiers, so have a single dummy register
+  // per register class. Each dummy register's name string matches its record
+  // name so that the registers can be told apart.
+
+  // Class for registers that are the result of OpTypeXXX instructions
+  def TYPE0 : Register<"TYPE0">;
+  def TYPE : RegisterClass<"SPIRV", [i32], 32, (add TYPE0)>;
+
+  // Class for every other non-type ID
+  def ID0 : Register<"ID0">;
+  def ID : RegisterClass<"SPIRV", [i32], 32, (add ID0)>;
+  def fID0 : Register<"fID0">;
+  def fID : RegisterClass<"SPIRV", [f32], 32, (add fID0)>;
+  def pID0 : Register<"pID0">;
+  def pID : RegisterClass<"SPIRV", [p0], 32, (add pID0)>;
+  def vID0 : Register<"vID0">;
+  def vID : RegisterClass<"SPIRV", [v2i32], 32, (add vID0)>;
+  def vfID0 : Register<"vfID0">;
+  def vfID : RegisterClass<"SPIRV", [v2f32], 32, (add vfID0)>;
+
+  // Union of all non-type ID classes.
+  def ANYID : RegisterClass<"SPIRV", [i32, f32, p0, v2i32, v2f32], 32,
+                            (add ID, fID, pID, vID, vfID)>;
+
+  // A few instructions like OpName can take ids from both type and non-type
+  // instructions, so we need a super-class to allow for both to count as valid
+  // arguments for these instructions.
+  def ANY : RegisterClass<"SPIRV", [i32], 32, (add TYPE, ID)>;
+}
diff --git a/llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp b/llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp
new file mode 100644
index 000000000000..cdf3a160f373
--- /dev/null
+++ b/llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp
@@ -0,0 +1,68 @@
+//===-- SPIRVSubtarget.cpp - SPIR-V Subtarget Information ------*- C++ -*--===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the SPIR-V specific subclass of TargetSubtargetInfo.
+//
+//===----------------------------------------------------------------------===//
+
+#include "SPIRVSubtarget.h"
+#include "SPIRV.h"
+#include "SPIRVGlobalRegistry.h"
+#include "SPIRVLegalizerInfo.h"
+#include "SPIRVRegisterBankInfo.h"
+#include "SPIRVTargetMachine.h"
+#include "llvm/MC/TargetRegistry.h"
+#include "llvm/Support/Host.h"
+
+using namespace llvm;
+
+#define DEBUG_TYPE "spirv-subtarget"
+
+#define GET_SUBTARGETINFO_TARGET_DESC
+#define GET_SUBTARGETINFO_CTOR
+#include "SPIRVGenSubtargetInfo.inc"
+
+// Returns true if Target is at least VerToCompareTo. A Target of 0 means the
+// version is unspecified and satisfies any requirement.
+static bool isAtLeastVer(uint32_t Target, uint32_t VerToCompareTo) {
+  if (Target == 0)
+    return true;
+  return Target >= VerToCompareTo;
+}
+
+// Pointer width in bits for a SPIR-V triple: 32 for spirv32, 64 otherwise.
+static unsigned computePointerSize(const Triple &TT) {
+  // TODO: unify this with pointers legalization.
+  assert(TT.isSPIRV());
+  if (TT.getArch() == Triple::spirv32)
+    return 32;
+  return 64;
+}
+
+// FrameLowering is initialized from initSubtargetDependencies(), which parses
+// the feature string and settles SPIRVVersion before the remaining members
+// are constructed.
+SPIRVSubtarget::SPIRVSubtarget(const Triple &TT, const std::string &CPU,
+                               const std::string &FS,
+                               const SPIRVTargetMachine &TM)
+    : SPIRVGenSubtargetInfo(TT, CPU, /*TuneCPU=*/CPU, FS),
+      PointerSize(computePointerSize(TT)), SPIRVVersion(0), InstrInfo(),
+      FrameLowering(initSubtargetDependencies(CPU, FS)), TLInfo(TM, *this) {
+  // The global registry comes first: call lowering is handed a pointer to it.
+  GR = std::make_unique<SPIRVGlobalRegistry>(PointerSize);
+  CallLoweringInfo =
+      std::make_unique<SPIRVCallLowering>(TLInfo, *this, GR.get());
+  Legalizer = std::make_unique<SPIRVLegalizerInfo>(*this);
+  RegBankInfo = std::make_unique<SPIRVRegisterBankInfo>();
+  InstSelector.reset(createSPIRVInstructionSelector(TM, *this, *RegBankInfo));
+}
+
+SPIRVSubtarget &SPIRVSubtarget::initSubtargetDependencies(StringRef CPU,
+                                                          StringRef FS) {
+  ParseSubtargetFeatures(CPU, /*TuneCPU=*/CPU, FS);
+  // Default to 14 (i.e. 1.4) if the version is still unspecified.
+  if (!SPIRVVersion)
+    SPIRVVersion = 14;
+  return *this;
+}
+
+// If the SPIR-V version is >= 1.4 we can call OpPtrEqual and OpPtrNotEqual.
+bool SPIRVSubtarget::canDirectlyComparePointers() const { + return isAtLeastVer(SPIRVVersion, 14); +} diff --git a/llvm/lib/Target/SPIRV/SPIRVSubtarget.h b/llvm/lib/Target/SPIRV/SPIRVSubtarget.h new file mode 100644 index 000000000000..a6332cfefa8e --- /dev/null +++ b/llvm/lib/Target/SPIRV/SPIRVSubtarget.h @@ -0,0 +1,93 @@ +//===-- SPIRVSubtarget.h - SPIR-V Subtarget Information --------*- C++ -*--===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares the SPIR-V specific subclass of TargetSubtargetInfo. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_LIB_TARGET_SPIRV_SPIRVSUBTARGET_H +#define LLVM_LIB_TARGET_SPIRV_SPIRVSUBTARGET_H + +#include "SPIRVCallLowering.h" +#include "SPIRVFrameLowering.h" +#include "SPIRVISelLowering.h" +#include "SPIRVInstrInfo.h" +#include "llvm/CodeGen/GlobalISel/CallLowering.h" +#include "llvm/CodeGen/GlobalISel/InstructionSelector.h" +#include "llvm/CodeGen/GlobalISel/LegalizerInfo.h" +#include "llvm/CodeGen/SelectionDAGTargetInfo.h" +#include "llvm/CodeGen/TargetSubtargetInfo.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/Target/TargetMachine.h" + +#define GET_SUBTARGETINFO_HEADER +#include "SPIRVGenSubtargetInfo.inc" + +namespace llvm { +class StringRef; +class SPIRVGlobalRegistry; +class SPIRVTargetMachine; + +class SPIRVSubtarget : public SPIRVGenSubtargetInfo { +private: + const unsigned PointerSize; + uint32_t SPIRVVersion; + + std::unique_ptr<SPIRVGlobalRegistry> GR; + + SPIRVInstrInfo InstrInfo; + SPIRVFrameLowering FrameLowering; + SPIRVTargetLowering TLInfo; + + // GlobalISel related APIs. 
+ std::unique_ptr<CallLowering> CallLoweringInfo; + std::unique_ptr<RegisterBankInfo> RegBankInfo; + std::unique_ptr<LegalizerInfo> Legalizer; + std::unique_ptr<InstructionSelector> InstSelector; + +public: + // This constructor initializes the data members to match that + // of the specified triple. + SPIRVSubtarget(const Triple &TT, const std::string &CPU, + const std::string &FS, const SPIRVTargetMachine &TM); + SPIRVSubtarget &initSubtargetDependencies(StringRef CPU, StringRef FS); + + // Parses features string setting specified subtarget options. + // The definition of this function is auto generated by tblgen. + void ParseSubtargetFeatures(StringRef CPU, StringRef TuneCPU, StringRef FS); + unsigned getPointerSize() const { return PointerSize; } + bool canDirectlyComparePointers() const; + uint32_t getSPIRVVersion() const { return SPIRVVersion; }; + SPIRVGlobalRegistry *getSPIRVGlobalRegistry() const { return GR.get(); } + + const CallLowering *getCallLowering() const override { + return CallLoweringInfo.get(); + } + const RegisterBankInfo *getRegBankInfo() const override { + return RegBankInfo.get(); + } + const LegalizerInfo *getLegalizerInfo() const override { + return Legalizer.get(); + } + InstructionSelector *getInstructionSelector() const override { + return InstSelector.get(); + } + const SPIRVInstrInfo *getInstrInfo() const override { return &InstrInfo; } + const SPIRVFrameLowering *getFrameLowering() const override { + return &FrameLowering; + } + const SPIRVTargetLowering *getTargetLowering() const override { + return &TLInfo; + } + const SPIRVRegisterInfo *getRegisterInfo() const override { + return &InstrInfo.getRegisterInfo(); + } +}; +} // namespace llvm + +#endif // LLVM_LIB_TARGET_SPIRV_SPIRVSUBTARGET_H diff --git a/llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp b/llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp new file mode 100644 index 000000000000..f7c88a5c6d4a --- /dev/null +++ b/llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp @@ -0,0 +1,186 @@ 
+//===- SPIRVTargetMachine.cpp - Define TargetMachine for SPIR-V -*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Implements the SPIR-V TargetMachine and its codegen pass configuration.
+//
+//===----------------------------------------------------------------------===//
+
+#include "SPIRVTargetMachine.h"
+#include "SPIRV.h"
+#include "SPIRVCallLowering.h"
+#include "SPIRVGlobalRegistry.h"
+#include "SPIRVLegalizerInfo.h"
+#include "SPIRVTargetObjectFile.h"
+#include "SPIRVTargetTransformInfo.h"
+#include "TargetInfo/SPIRVTargetInfo.h"
+#include "llvm/CodeGen/GlobalISel/IRTranslator.h"
+#include "llvm/CodeGen/GlobalISel/InstructionSelect.h"
+#include "llvm/CodeGen/GlobalISel/Legalizer.h"
+#include "llvm/CodeGen/GlobalISel/RegBankSelect.h"
+#include "llvm/CodeGen/Passes.h"
+#include "llvm/CodeGen/TargetLoweringObjectFileImpl.h"
+#include "llvm/CodeGen/TargetPassConfig.h"
+#include "llvm/IR/LegacyPassManager.h"
+#include "llvm/InitializePasses.h"
+#include "llvm/MC/TargetRegistry.h"
+#include "llvm/Pass.h"
+#include "llvm/Target/TargetOptions.h"
+
+using namespace llvm;
+
+extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeSPIRVTarget() {
+  // Register the target machine for both the 32- and 64-bit SPIR-V targets.
+  RegisterTargetMachine<SPIRVTargetMachine> X(getTheSPIRV32Target());
+  RegisterTargetMachine<SPIRVTargetMachine> Y(getTheSPIRV64Target());
+
+  PassRegistry &PR = *PassRegistry::getPassRegistry();
+  initializeGlobalISel(PR);
+  initializeSPIRVModuleAnalysisPass(PR);
+}
+
+static std::string computeDataLayout(const Triple &TT) {
+  const auto Arch = TT.getArch();
+  if (Arch == Triple::spirv32)
+    return "e-p:32:32-i64:64-v16:16-v24:32-v32:32-v48:64-"
+           "v96:128-v192:256-v256:256-v512:512-v1024:1024";
+  return "e-i64:64-v16:16-v24:32-v32:32-v48:64-"
+         "v96:128-v192:256-v256:256-v512:512-v1024:1024";
+}
+
+static Reloc::Model getEffectiveRelocModel(Optional<Reloc::Model> RM) {
+  if (!RM)
+    return Reloc::PIC_;
+  return *RM;
+}
+
+// Pin SPIRVTargetObjectFile's vtables here (NOTE(review): TLOF below is ELF).
+SPIRVTargetObjectFile::~SPIRVTargetObjectFile() {}
+
+SPIRVTargetMachine::SPIRVTargetMachine(const Target &T, const Triple &TT,
+                                       StringRef CPU, StringRef FS,
+                                       const TargetOptions &Options,
+                                       Optional<Reloc::Model> RM,
+                                       Optional<CodeModel::Model> CM,
+                                       CodeGenOpt::Level OL, bool JIT)
+    : LLVMTargetMachine(T, computeDataLayout(TT), TT, CPU, FS, Options,
+                        getEffectiveRelocModel(RM),
+                        getEffectiveCodeModel(CM, CodeModel::Small), OL),
+      TLOF(std::make_unique<TargetLoweringObjectFileELF>()),
+      Subtarget(TT, CPU.str(), FS.str(), *this) {
+  initAsmInfo();
+  setGlobalISel(true);
+  setFastISel(false);
+  setO0WantsFastISel(false);
+  setRequiresStructuredCFG(false);
+}
+
+namespace {
+// SPIR-V Code Generator Pass Configuration Options.
+class SPIRVPassConfig : public TargetPassConfig {
+public:
+  SPIRVPassConfig(SPIRVTargetMachine &TM, PassManagerBase &PM)
+      : TargetPassConfig(TM, PM) {}
+
+  SPIRVTargetMachine &getSPIRVTargetMachine() const {
+    return getTM<SPIRVTargetMachine>();
+  }
+  void addIRPasses() override;
+  void addISelPrepare() override;
+
+  bool addIRTranslator() override;
+  void addPreLegalizeMachineIR() override;
+  bool addLegalizeMachineIR() override;
+  bool addRegBankSelect() override;
+  bool addGlobalInstructionSelect() override;
+
+  FunctionPass *createTargetRegisterAllocator(bool) override;
+  void addFastRegAlloc() override {}
+  void addOptimizedRegAlloc() override {}
+
+  void addPostRegAlloc() override;
+};
+} // namespace
+
+// We do not use physical registers, and maintain virtual registers throughout
+// the entire pipeline, so return nullptr to disable register allocation.
+FunctionPass *SPIRVPassConfig::createTargetRegisterAllocator(bool) {
+  return nullptr;
+}
+
+// Disable post-RA passes that assume no virtual registers remain.
+void SPIRVPassConfig::addPostRegAlloc() {
+  // These passes expect physical registers, but SPIR-V keeps only vregs.
+  disablePass(&MachineCopyPropagationID);
+  disablePass(&PostRAMachineSinkingID);
+  disablePass(&PostRASchedulerID);
+  disablePass(&FuncletLayoutID);
+  disablePass(&StackMapLivenessID);
+  disablePass(&PatchableFunctionID);
+  disablePass(&ShrinkWrapID);
+  disablePass(&LiveDebugValuesID);
+
+  // These passes do not work with OpPhi.
+  disablePass(&BranchFolderPassID);
+  disablePass(&MachineBlockPlacementID);
+
+  TargetPassConfig::addPostRegAlloc();
+}
+
+TargetTransformInfo
+SPIRVTargetMachine::getTargetTransformInfo(const Function &F) const {
+  return TargetTransformInfo(SPIRVTTIImpl(this, F));
+}
+
+TargetPassConfig *SPIRVTargetMachine::createPassConfig(PassManagerBase &PM) {
+  return new SPIRVPassConfig(*this, PM);
+}
+
+void SPIRVPassConfig::addIRPasses() { TargetPassConfig::addIRPasses(); }
+
+void SPIRVPassConfig::addISelPrepare() {
+  addPass(createSPIRVEmitIntrinsicsPass(&getTM<SPIRVTargetMachine>()));
+  TargetPassConfig::addISelPrepare();
+}
+
+bool SPIRVPassConfig::addIRTranslator() {
+  addPass(new IRTranslator(getOptLevel()));
+  return false;
+}
+
+void SPIRVPassConfig::addPreLegalizeMachineIR() {
+  addPass(createSPIRVPreLegalizerPass());
+}
+
+// Use the standard GlobalISel Legalizer pass.
+bool SPIRVPassConfig::addLegalizeMachineIR() {
+  addPass(new Legalizer());
+  return false;
+}
+
+// Do not add a RegBankSelect pass, as we only ever need virtual registers.
+bool SPIRVPassConfig::addRegBankSelect() {
+  disablePass(&RegBankSelect::ID);
+  return false;
+}
+
+namespace {
+// A custom subclass of InstructionSelect that behaves the same, except that
+// it does not require RegBankSelect to have run first.
+class SPIRVInstructionSelect : public InstructionSelect {
+  // We don't use register banks, so drop the RegBankSelected requirement.
+  MachineFunctionProperties getRequiredProperties() const override {
+    return InstructionSelect::getRequiredProperties().reset(
+        MachineFunctionProperties::Property::RegBankSelected);
+  }
+};
+} // namespace
+
+bool SPIRVPassConfig::addGlobalInstructionSelect() {
+  addPass(new SPIRVInstructionSelect());
+  return false;
+}
diff --git a/llvm/lib/Target/SPIRV/SPIRVTargetMachine.h b/llvm/lib/Target/SPIRV/SPIRVTargetMachine.h
new file mode 100644
index 000000000000..f3597971bc95
--- /dev/null
+++ b/llvm/lib/Target/SPIRV/SPIRVTargetMachine.h
@@ -0,0 +1,47 @@
+//===-- SPIRVTargetMachine.h - Define TargetMachine for SPIR-V -*- C++ -*--===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the SPIR-V specific subclass of TargetMachine.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIB_TARGET_SPIRV_SPIRVTARGETMACHINE_H
+#define LLVM_LIB_TARGET_SPIRV_SPIRVTARGETMACHINE_H
+
+#include "SPIRVSubtarget.h"
+#include "llvm/Target/TargetMachine.h"
+
+namespace llvm {
+class SPIRVTargetMachine : public LLVMTargetMachine {
+  std::unique_ptr<TargetLoweringObjectFile> TLOF;
+  SPIRVSubtarget Subtarget; // Shared by all functions; see getSubtargetImpl.
+
+public:
+  SPIRVTargetMachine(const Target &T, const Triple &TT, StringRef CPU,
+                     StringRef FS, const TargetOptions &Options,
+                     Optional<Reloc::Model> RM, Optional<CodeModel::Model> CM,
+                     CodeGenOpt::Level OL, bool JIT);
+
+  const SPIRVSubtarget *getSubtargetImpl() const { return &Subtarget; }
+
+  const SPIRVSubtarget *getSubtargetImpl(const Function &) const override {
+    return &Subtarget;
+  }
+
+  TargetTransformInfo getTargetTransformInfo(const Function &F) const override;
+
+  TargetPassConfig *createPassConfig(PassManagerBase &PM) override;
+  bool usesPhysRegsForValues() const override { return false; }
+
+  TargetLoweringObjectFile *getObjFileLowering() const override {
+    return TLOF.get();
+  }
+};
+} // namespace llvm
+
+#endif // LLVM_LIB_TARGET_SPIRV_SPIRVTARGETMACHINE_H
diff --git a/llvm/lib/Target/SPIRV/SPIRVTargetObjectFile.h b/llvm/lib/Target/SPIRV/SPIRVTargetObjectFile.h
new file mode 100644
index 000000000000..00c456971ef1
--- /dev/null
+++ b/llvm/lib/Target/SPIRV/SPIRVTargetObjectFile.h
@@ -0,0 +1,45 @@
+//===-- SPIRVTargetObjectFile.h - SPIRV Object Info -------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIB_TARGET_SPIRV_SPIRVTARGETOBJECTFILE_H
+#define LLVM_LIB_TARGET_SPIRV_SPIRVTARGETOBJECTFILE_H
+
+#include "llvm/MC/MCSection.h"
+#include "llvm/MC/SectionKind.h"
+#include "llvm/Target/TargetLoweringObjectFile.h"
+
+namespace llvm {
+
+class SPIRVTargetObjectFile : public TargetLoweringObjectFile {
+public:
+  ~SPIRVTargetObjectFile() override; // Defined out of line to pin the vtable.
+
+  void Initialize(MCContext &ctx, const TargetMachine &TM) override {
+    TargetLoweringObjectFile::Initialize(ctx, TM);
+  }
+  // All words in a SPIR-V module (except the first 5) form a linear sequence
+  // of instructions in a specific order, so every constant and global is
+  // placed in the single text section.
+  MCSection *getSectionForConstant(const DataLayout &DL, SectionKind Kind,
+                                   const Constant *C,
+                                   Align &Alignment) const override {
+    return TextSection;
+  }
+  MCSection *getExplicitSectionGlobal(const GlobalObject *GO, SectionKind Kind,
+                                      const TargetMachine &TM) const override {
+    return TextSection;
+  }
+  MCSection *SelectSectionForGlobal(const GlobalObject *GO, SectionKind Kind,
+                                    const TargetMachine &TM) const override {
+    return TextSection;
+  }
+};
+
+} // end namespace llvm
+
+#endif // LLVM_LIB_TARGET_SPIRV_SPIRVTARGETOBJECTFILE_H
diff --git a/llvm/lib/Target/SPIRV/SPIRVTargetTransformInfo.h b/llvm/lib/Target/SPIRV/SPIRVTargetTransformInfo.h
new file mode 100644
index 000000000000..ac351cf42f5c
--- /dev/null
+++ b/llvm/lib/Target/SPIRV/SPIRVTargetTransformInfo.h
@@ -0,0 +1,44 @@
+//===- SPIRVTargetTransformInfo.h - SPIR-V specific TTI ---------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+// \file
+// This file contains a TargetTransformInfo::Concept conforming object specific
+// to the SPIRV target machine. At present it only supplies the subtarget and
+// target-lowering hooks BasicTTIImplBase needs; all TTI queries are answered
+// by the target independent and default TTI implementations.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIB_TARGET_SPIRV_SPIRVTARGETTRANSFORMINFO_H
+#define LLVM_LIB_TARGET_SPIRV_SPIRVTARGETTRANSFORMINFO_H
+
+#include "SPIRV.h"
+#include "SPIRVTargetMachine.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
+#include "llvm/CodeGen/BasicTTIImpl.h"
+
+namespace llvm {
+class SPIRVTTIImpl : public BasicTTIImplBase<SPIRVTTIImpl> {
+  using BaseT = BasicTTIImplBase<SPIRVTTIImpl>;
+
+  friend BaseT;
+
+  const SPIRVSubtarget *ST;
+  const SPIRVTargetLowering *TLI;
+
+  const TargetSubtargetInfo *getST() const { return ST; }
+  const SPIRVTargetLowering *getTLI() const { return TLI; }
+
+public:
+  explicit SPIRVTTIImpl(const SPIRVTargetMachine *TM, const Function &F)
+      : BaseT(TM, F.getParent()->getDataLayout()), ST(TM->getSubtargetImpl(F)),
+        TLI(ST->getTargetLowering()) {}
+};
+
+} // namespace llvm
+
+#endif // LLVM_LIB_TARGET_SPIRV_SPIRVTARGETTRANSFORMINFO_H
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
new file mode 100644
index 000000000000..b92dc12735f8
--- /dev/null
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
@@ -0,0 +1,207 @@
+//===--- SPIRVUtils.cpp ---- SPIR-V Utility Functions -----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains miscellaneous utility functions.
+//
+//===----------------------------------------------------------------------===//
+
+#include "SPIRVUtils.h"
+#include "MCTargetDesc/SPIRVBaseInfo.h"
+#include "SPIRV.h"
+#include "SPIRVInstrInfo.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
+#include "llvm/CodeGen/MachineInstr.h"
+#include "llvm/CodeGen/MachineInstrBuilder.h"
+#include "llvm/IR/IntrinsicsSPIRV.h"
+
+using namespace llvm;
+
+// The following functions are used to add these string literals as a series of
+// 32-bit integer operands with the correct format, and unpack them if necessary
+// when making string comparisons in compiler passes.
+// SPIR-V requires null-terminated UTF-8 strings padded to 32-bit alignment.
+static uint32_t convertCharsToWord(const StringRef &Str, unsigned i) {
+  uint32_t Word = 0u; // Build up this 32-bit word from 4 8-bit chars.
+  for (unsigned WordIndex = 0; WordIndex < 4; ++WordIndex) {
+    unsigned StrIndex = i + WordIndex;
+    uint8_t CharToAdd = 0;     // Initialize char as padding/null.
+    if (StrIndex < Str.size()) // If it's within the string, get a real char.
+      CharToAdd = Str[StrIndex];
+    // Widen first so the shift is unsigned (uint8_t would promote to int).
+    Word |= static_cast<uint32_t>(CharToAdd) << (WordIndex * 8);
+  }
+  return Word;
+}
+
+// Get length including padding and null terminator.
+static size_t getPaddedLen(const StringRef &Str) {
+  const size_t Len = Str.size() + 1;
+  return (Len % 4 == 0) ? Len : Len + (4 - (Len % 4));
+}
+
+void addStringImm(const StringRef &Str, MachineInstrBuilder &MIB) {
+  const size_t PaddedLen = getPaddedLen(Str);
+  for (unsigned i = 0; i < PaddedLen; i += 4) {
+    // Add an operand for the 32-bits of chars or padding.
+    MIB.addImm(convertCharsToWord(Str, i));
+  }
+}
+
+void addStringImm(const StringRef &Str, IRBuilder<> &B,
+                  std::vector<Value *> &Args) {
+  const size_t PaddedLen = getPaddedLen(Str);
+  for (unsigned i = 0; i < PaddedLen; i += 4) {
+    // Add a vector element for the 32-bits of chars or padding.
+    Args.push_back(B.getInt32(convertCharsToWord(Str, i)));
+  }
+}
+
+std::string getStringImm(const MachineInstr &MI, unsigned StartIndex) {
+  return getSPIRVStringOperand(MI, StartIndex);
+}
+
+void addNumImm(const APInt &Imm, MachineInstrBuilder &MIB) {
+  const auto Bitwidth = Imm.getBitWidth();
+  switch (Bitwidth) {
+  case 1:
+    break; // Already handled.
+  case 8:
+  case 16:
+  case 32:
+    MIB.addImm(Imm.getZExtValue());
+    break;
+  case 64: {
+    uint64_t FullImm = Imm.getZExtValue();
+    uint32_t LowBits = FullImm & 0xffffffff;
+    uint32_t HighBits = (FullImm >> 32) & 0xffffffff;
+    MIB.addImm(LowBits).addImm(HighBits);
+    break;
+  }
+  default:
+    report_fatal_error("Unsupported constant bitwidth");
+  }
+}
+
+void buildOpName(Register Target, const StringRef &Name,
+                 MachineIRBuilder &MIRBuilder) {
+  if (!Name.empty()) {
+    auto MIB = MIRBuilder.buildInstr(SPIRV::OpName).addUse(Target);
+    addStringImm(Name, MIB);
+  }
+}
+
+static void finishBuildOpDecorate(MachineInstrBuilder &MIB,
+                                  const std::vector<uint32_t> &DecArgs,
+                                  StringRef StrImm) {
+  if (!StrImm.empty())
+    addStringImm(StrImm, MIB);
+  for (const auto &DecArg : DecArgs)
+    MIB.addImm(DecArg);
+}
+
+void buildOpDecorate(Register Reg, MachineIRBuilder &MIRBuilder,
+                     llvm::SPIRV::Decoration Dec,
+                     const std::vector<uint32_t> &DecArgs, StringRef StrImm) {
+  auto MIB = MIRBuilder.buildInstr(SPIRV::OpDecorate)
+                 .addUse(Reg)
+                 .addImm(static_cast<uint32_t>(Dec));
+  finishBuildOpDecorate(MIB, DecArgs, StrImm);
+}
+
+void buildOpDecorate(Register Reg, MachineInstr &I, const SPIRVInstrInfo &TII,
+                     llvm::SPIRV::Decoration Dec,
+                     const std::vector<uint32_t> &DecArgs, StringRef StrImm) {
+  MachineBasicBlock &MBB = *I.getParent();
+  auto MIB = BuildMI(MBB, I, I.getDebugLoc(), TII.get(SPIRV::OpDecorate))
+                 .addUse(Reg)
+                 .addImm(static_cast<uint32_t>(Dec));
+  finishBuildOpDecorate(MIB, DecArgs, StrImm);
+}
+
+// TODO: maybe the following two functions should be handled in the subtarget
+// to allow for different OpenCL vs Vulkan handling.
+unsigned storageClassToAddressSpace(SPIRV::StorageClass SC) {
+  switch (SC) {
+  case SPIRV::StorageClass::Function:
+    return 0;
+  case SPIRV::StorageClass::CrossWorkgroup:
+    return 1;
+  case SPIRV::StorageClass::UniformConstant:
+    return 2;
+  case SPIRV::StorageClass::Workgroup:
+    return 3;
+  case SPIRV::StorageClass::Generic:
+    return 4;
+  case SPIRV::StorageClass::Input:
+    return 7;
+  default:
+    llvm_unreachable("Unable to get address space id");
+  }
+}
+
+SPIRV::StorageClass addressSpaceToStorageClass(unsigned AddrSpace) {
+  switch (AddrSpace) {
+  case 0:
+    return SPIRV::StorageClass::Function;
+  case 1:
+    return SPIRV::StorageClass::CrossWorkgroup;
+  case 2:
+    return SPIRV::StorageClass::UniformConstant;
+  case 3:
+    return SPIRV::StorageClass::Workgroup;
+  case 4:
+    return SPIRV::StorageClass::Generic;
+  case 7:
+    return SPIRV::StorageClass::Input;
+  default:
+    llvm_unreachable("Unknown address space");
+  }
+}
+
+SPIRV::MemorySemantics getMemSemanticsForStorageClass(SPIRV::StorageClass SC) {
+  switch (SC) {
+  case SPIRV::StorageClass::StorageBuffer:
+  case SPIRV::StorageClass::Uniform:
+    return SPIRV::MemorySemantics::UniformMemory;
+  case SPIRV::StorageClass::Workgroup:
+    return SPIRV::MemorySemantics::WorkgroupMemory;
+  case SPIRV::StorageClass::CrossWorkgroup:
+    return SPIRV::MemorySemantics::CrossWorkgroupMemory;
+  case SPIRV::StorageClass::AtomicCounter:
+    return SPIRV::MemorySemantics::AtomicCounterMemory;
+  case SPIRV::StorageClass::Image:
+    return SPIRV::MemorySemantics::ImageMemory;
+  default:
+    return SPIRV::MemorySemantics::None;
+  }
+}
+
+MachineInstr *getDefInstrMaybeConstant(Register &ConstReg,
+                                       const MachineRegisterInfo *MRI) {
+  MachineInstr *ConstInstr = MRI->getVRegDef(ConstReg);
+  if (ConstInstr->getOpcode() == TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS &&
+      ConstInstr->getIntrinsicID() == Intrinsic::spv_track_constant) {
+    ConstReg = ConstInstr->getOperand(2).getReg();
+    ConstInstr = MRI->getVRegDef(ConstReg);
+  } else if (ConstInstr->getOpcode() == SPIRV::ASSIGN_TYPE) {
+    ConstReg = ConstInstr->getOperand(1).getReg();
+    ConstInstr = MRI->getVRegDef(ConstReg);
+  }
+  return ConstInstr;
+}
+
+uint64_t getIConstVal(Register ConstReg, const MachineRegisterInfo *MRI) {
+  const MachineInstr *MI = getDefInstrMaybeConstant(ConstReg, MRI);
+  assert(MI && MI->getOpcode() == TargetOpcode::G_CONSTANT);
+  return MI->getOperand(1).getCImm()->getValue().getZExtValue();
+}
+
+Type *getMDOperandAsType(const MDNode *N, unsigned I) {
+  return cast<ValueAsMetadata>(N->getOperand(I))->getType();
+}
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.h b/llvm/lib/Target/SPIRV/SPIRVUtils.h
new file mode 100644
index 000000000000..ffa82c9c1fe4
--- /dev/null
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.h
@@ -0,0 +1,83 @@
+//===--- SPIRVUtils.h ---- SPIR-V Utility Functions -------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains miscellaneous utility functions.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIB_TARGET_SPIRV_SPIRVUTILS_H
+#define LLVM_LIB_TARGET_SPIRV_SPIRVUTILS_H
+
+#include "MCTargetDesc/SPIRVBaseInfo.h"
+#include "llvm/IR/IRBuilder.h"
+#include <string>
+
+namespace llvm {
+class MCInst;
+class MachineFunction;
+class MachineInstr;
+class MachineInstrBuilder;
+class MachineIRBuilder;
+class MachineRegisterInfo;
+class Register;
+class StringRef;
+class SPIRVInstrInfo;
+} // namespace llvm
+
+// Add the given string as a series of integer operands, inserting a null
+// terminator and zero padding so that the operands form whole 32-bit
+// little-endian words.
+void addStringImm(const llvm::StringRef &Str, llvm::MachineInstrBuilder &MIB);
+void addStringImm(const llvm::StringRef &Str, llvm::IRBuilder<> &B,
+                  std::vector<llvm::Value *> &Args);
+
+// Read the series of integer operands back as a null-terminated string using
+// the reverse of the logic in addStringImm.
+std::string getStringImm(const llvm::MachineInstr &MI, unsigned StartIndex);
+
+// Add Imm to MIB: one operand for 8/16/32-bit values, two for 64-bit.
+void addNumImm(const llvm::APInt &Imm, llvm::MachineInstrBuilder &MIB);
+
+// Add an OpName instruction for the given target register.
+void buildOpName(llvm::Register Target, const llvm::StringRef &Name,
+                 llvm::MachineIRBuilder &MIRBuilder);
+
+// Add an OpDecorate instruction for the given Reg.
+void buildOpDecorate(llvm::Register Reg, llvm::MachineIRBuilder &MIRBuilder,
+                     llvm::SPIRV::Decoration Dec,
+                     const std::vector<uint32_t> &DecArgs,
+                     llvm::StringRef StrImm = "");
+void buildOpDecorate(llvm::Register Reg, llvm::MachineInstr &I,
+                     const llvm::SPIRVInstrInfo &TII,
+                     llvm::SPIRV::Decoration Dec,
+                     const std::vector<uint32_t> &DecArgs,
+                     llvm::StringRef StrImm = "");
+
+// Convert a SPIR-V storage class to the corresponding LLVM IR address space.
+unsigned storageClassToAddressSpace(llvm::SPIRV::StorageClass SC);
+
+// Convert an LLVM IR address space to a SPIR-V storage class.
+llvm::SPIRV::StorageClass addressSpaceToStorageClass(unsigned AddrSpace);
+
+llvm::SPIRV::MemorySemantics
+getMemSemanticsForStorageClass(llvm::SPIRV::StorageClass SC);
+
+// Find the def instruction for ConstReg, looking through one level of
+// spv_track_constant or ASSIGN_TYPE. On return, ConstReg is the register
+// defined by the returned instruction.
+llvm::MachineInstr *
+getDefInstrMaybeConstant(llvm::Register &ConstReg,
+                         const llvm::MachineRegisterInfo *MRI);
+
+// Get the zero-extended value of the G_CONSTANT found for ConstReg.
+uint64_t getIConstVal(llvm::Register ConstReg,
+                      const llvm::MachineRegisterInfo *MRI);
+
+// Get type of i-th operand of the metadata node.
+llvm::Type *getMDOperandAsType(const llvm::MDNode *N, unsigned I);
+#endif // LLVM_LIB_TARGET_SPIRV_SPIRVUTILS_H
diff --git a/llvm/lib/Target/SPIRV/TargetInfo/SPIRVTargetInfo.cpp b/llvm/lib/Target/SPIRV/TargetInfo/SPIRVTargetInfo.cpp
new file mode 100644
index 000000000000..fb7cab4fe779
--- /dev/null
+++ b/llvm/lib/Target/SPIRV/TargetInfo/SPIRVTargetInfo.cpp
@@ -0,0 +1,28 @@
+//===-- SPIRVTargetInfo.cpp - SPIR-V Target Implementation ----*- C++ -*---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "TargetInfo/SPIRVTargetInfo.h"
+#include "llvm/MC/TargetRegistry.h"
+
+using namespace llvm;
+
+Target &llvm::getTheSPIRV32Target() {
+  static Target TheSPIRV32Target; // Constructed on first use.
+  return TheSPIRV32Target;
+}
+Target &llvm::getTheSPIRV64Target() {
+  static Target TheSPIRV64Target; // Constructed on first use.
+  return TheSPIRV64Target;
+}
+
+extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeSPIRVTargetInfo() {
+  RegisterTarget<Triple::spirv32> X(getTheSPIRV32Target(), "spirv32",
+                                    "SPIR-V 32-bit", "SPIRV");
+  RegisterTarget<Triple::spirv64> Y(getTheSPIRV64Target(), "spirv64",
+                                    "SPIR-V 64-bit", "SPIRV");
+}
diff --git a/llvm/lib/Target/SPIRV/TargetInfo/SPIRVTargetInfo.h b/llvm/lib/Target/SPIRV/TargetInfo/SPIRVTargetInfo.h
new file mode 100644
index 000000000000..4353258e1d1a
--- /dev/null
+++ b/llvm/lib/Target/SPIRV/TargetInfo/SPIRVTargetInfo.h
@@ -0,0 +1,21 @@
+//===-- SPIRVTargetInfo.h - SPIRV Target Implementation ---------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIB_TARGET_SPIRV_TARGETINFO_SPIRVTARGETINFO_H
+#define LLVM_LIB_TARGET_SPIRV_TARGETINFO_SPIRVTARGETINFO_H
+
+namespace llvm {
+
+class Target;
+
+Target &getTheSPIRV32Target();
+Target &getTheSPIRV64Target();
+
+} // namespace llvm
+
+#endif // LLVM_LIB_TARGET_SPIRV_TARGETINFO_SPIRVTARGETINFO_H