summaryrefslogtreecommitdiff
path: root/lib/Headers/__clang_cuda_runtime_wrapper.h
diff options
context:
space:
mode:
Diffstat (limited to 'lib/Headers/__clang_cuda_runtime_wrapper.h')
-rw-r--r--lib/Headers/__clang_cuda_runtime_wrapper.h36
1 files changed, 35 insertions, 1 deletions
diff --git a/lib/Headers/__clang_cuda_runtime_wrapper.h b/lib/Headers/__clang_cuda_runtime_wrapper.h
index 931d44b6965b7..a82a8490f3670 100644
--- a/lib/Headers/__clang_cuda_runtime_wrapper.h
+++ b/lib/Headers/__clang_cuda_runtime_wrapper.h
@@ -62,7 +62,7 @@
#include "cuda.h"
#if !defined(CUDA_VERSION)
#error "cuda.h did not define CUDA_VERSION"
-#elif CUDA_VERSION < 7000 || CUDA_VERSION > 8000
+#elif CUDA_VERSION < 7000 || CUDA_VERSION > 9000
#error "Unsupported CUDA version!"
#endif
@@ -86,7 +86,11 @@
#define __COMMON_FUNCTIONS_H__
#undef __CUDACC__
+#if CUDA_VERSION < 9000
#define __CUDABE__
+#else
+#define __CUDA_LIBDEVICE__
+#endif
// Disables definitions of device-side runtime support stubs in
// cuda_device_runtime_api.h
#include "driver_types.h"
@@ -94,6 +98,7 @@
#include "host_defines.h"
#undef __CUDABE__
+#undef __CUDA_LIBDEVICE__
#define __CUDACC__
#include "cuda_runtime.h"
@@ -105,7 +110,9 @@
#define __nvvm_memcpy(s, d, n, a) __builtin_memcpy(s, d, n)
#define __nvvm_memset(d, c, n, a) __builtin_memset(d, c, n)
+#if CUDA_VERSION < 9000
#include "crt/device_runtime.h"
+#endif
#include "crt/host_runtime.h"
// device_runtime.h defines __cxa_* macros that will conflict with
// cxxabi.h.
@@ -166,7 +173,18 @@ inline __host__ double __signbitd(double x) {
// __device__.
#pragma push_macro("__forceinline__")
#define __forceinline__ __device__ __inline__ __attribute__((always_inline))
+
+#pragma push_macro("__float2half_rn")
+#if CUDA_VERSION >= 9000
+// CUDA-9 has conflicting prototypes for __float2half_rn(float f) in
+// cuda_fp16.h[pp] and device_functions.hpp. We need to get the one in
+// device_functions.hpp out of the way.
+#define __float2half_rn __float2half_rn_disabled
+#endif
+
#include "device_functions.hpp"
+#pragma pop_macro("__float2half_rn")
+
// math_function.hpp uses the __USE_FAST_MATH__ macro to determine whether we
// get the slow-but-accurate or fast-but-inaccurate versions of functions like
@@ -247,7 +265,23 @@ static inline __device__ void __brkpt(int __c) { __brkpt(); }
#pragma push_macro("__GNUC__")
#undef __GNUC__
#define signbit __ignored_cuda_signbit
+
+// CUDA-9 omits device-side definitions of some math functions if it sees
+// include guard from math.h wrapper from libstdc++. We have to undo the header
+// guard temporarily to get the definitions we need.
+#pragma push_macro("_GLIBCXX_MATH_H")
+#pragma push_macro("_LIBCPP_VERSION")
+#if CUDA_VERSION >= 9000
+#undef _GLIBCXX_MATH_H
+// We also need to undo another guard that checks for libc++ 3.8+
+#ifdef _LIBCPP_VERSION
+#define _LIBCPP_VERSION 3700
+#endif
+#endif
+
#include "math_functions.hpp"
+#pragma pop_macro("_GLIBCXX_MATH_H")
+#pragma pop_macro("_LIBCPP_VERSION")
#pragma pop_macro("__GNUC__")
#pragma pop_macro("signbit")