aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp')
-rw-r--r--llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp45
1 files changed, 45 insertions, 0 deletions
diff --git a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
new file mode 100644
index 000000000000..66ff51c912b0
--- /dev/null
+++ b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
@@ -0,0 +1,45 @@
+//===- SPIRVISelLowering.cpp - SPIR-V DAG Lowering Impl ---------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the SPIRVTargetLowering class.
+//
+//===----------------------------------------------------------------------===//
+
+#include "SPIRVISelLowering.h"
+#include "SPIRV.h"
+
+#define DEBUG_TYPE "spirv-lower"
+
+using namespace llvm;
+
+unsigned SPIRVTargetLowering::getNumRegistersForCallingConv(
+ LLVMContext &Context, CallingConv::ID CC, EVT VT) const {
+ // This code avoids CallLowering fail inside getVectorTypeBreakdown
+ // on v3i1 arguments. Maybe we need to return 1 for all types.
+ // TODO: remove it once this case is supported by the default implementation.
+ if (VT.isVector() && VT.getVectorNumElements() == 3 &&
+ (VT.getVectorElementType() == MVT::i1 ||
+ VT.getVectorElementType() == MVT::i8))
+ return 1;
+ return getNumRegisters(Context, VT);
+}
+
+MVT SPIRVTargetLowering::getRegisterTypeForCallingConv(LLVMContext &Context,
+ CallingConv::ID CC,
+ EVT VT) const {
+ // This code avoids CallLowering fail inside getVectorTypeBreakdown
+ // on v3i1 arguments. Maybe we need to return i32 for all types.
+ // TODO: remove it once this case is supported by the default implementation.
+ if (VT.isVector() && VT.getVectorNumElements() == 3) {
+ if (VT.getVectorElementType() == MVT::i1)
+ return MVT::v4i1;
+ else if (VT.getVectorElementType() == MVT::i8)
+ return MVT::v4i8;
+ }
+ return getRegisterType(Context, VT);
+}