diff options
Diffstat (limited to 'llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp')
| -rw-r--r-- | llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp | 45 |
1 files changed, 45 insertions, 0 deletions
diff --git a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp new file mode 100644 index 000000000000..66ff51c912b0 --- /dev/null +++ b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp @@ -0,0 +1,45 @@ +//===- SPIRVISelLowering.cpp - SPIR-V DAG Lowering Impl ---------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the SPIRVTargetLowering class. +// +//===----------------------------------------------------------------------===// + +#include "SPIRVISelLowering.h" +#include "SPIRV.h" + +#define DEBUG_TYPE "spirv-lower" + +using namespace llvm; + +unsigned SPIRVTargetLowering::getNumRegistersForCallingConv( + LLVMContext &Context, CallingConv::ID CC, EVT VT) const { + // This code avoids CallLowering fail inside getVectorTypeBreakdown + // on v3i1 arguments. Maybe we need to return 1 for all types. + // TODO: remove it once this case is supported by the default implementation. + if (VT.isVector() && VT.getVectorNumElements() == 3 && + (VT.getVectorElementType() == MVT::i1 || + VT.getVectorElementType() == MVT::i8)) + return 1; + return getNumRegisters(Context, VT); +} + +MVT SPIRVTargetLowering::getRegisterTypeForCallingConv(LLVMContext &Context, + CallingConv::ID CC, + EVT VT) const { + // This code avoids CallLowering fail inside getVectorTypeBreakdown + // on v3i1 arguments. Maybe we need to return i32 for all types. + // TODO: remove it once this case is supported by the default implementation. + if (VT.isVector() && VT.getVectorNumElements() == 3) { + if (VT.getVectorElementType() == MVT::i1) + return MVT::v4i1; + else if (VT.getVectorElementType() == MVT::i8) + return MVT::v4i8; + } + return getRegisterType(Context, VT); +} |
