summaryrefslogtreecommitdiff
path: root/compiler-rt/lib/orc/wrapper_function_utils.h
diff options
context:
space:
mode:
Diffstat (limited to 'compiler-rt/lib/orc/wrapper_function_utils.h')
-rw-r--r--compiler-rt/lib/orc/wrapper_function_utils.h176
1 files changed, 133 insertions, 43 deletions
diff --git a/compiler-rt/lib/orc/wrapper_function_utils.h b/compiler-rt/lib/orc/wrapper_function_utils.h
index 49faa03e5eb8..23385e1bd794 100644
--- a/compiler-rt/lib/orc/wrapper_function_utils.h
+++ b/compiler-rt/lib/orc/wrapper_function_utils.h
@@ -16,6 +16,7 @@
#include "c_api.h"
#include "common.h"
#include "error.h"
+#include "executor_address.h"
#include "simple_packed_serialization.h"
#include <type_traits>
@@ -61,7 +62,7 @@ public:
}
/// Get a pointer to the data contained in this instance.
- const char *data() const { return __orc_rt_CWrapperFunctionResultData(&R); }
+ char *data() { return __orc_rt_CWrapperFunctionResultData(&R); }
/// Returns the size of the data contained in this instance.
size_t size() const { return __orc_rt_CWrapperFunctionResultSize(&R); }
@@ -72,10 +73,10 @@ public:
/// Create a WrapperFunctionResult with the given size and return a pointer
/// to the underlying memory.
- static char *allocate(WrapperFunctionResult &R, size_t Size) {
- __orc_rt_DisposeCWrapperFunctionResult(&R.R);
- __orc_rt_CWrapperFunctionResultInit(&R.R);
- return __orc_rt_CWrapperFunctionResultAllocate(&R.R, Size);
+ static WrapperFunctionResult allocate(size_t Size) {
+ WrapperFunctionResult R;
+ R.R = __orc_rt_CWrapperFunctionResultAllocate(Size);
+ return R;
}
/// Copy from the given char range.
@@ -103,6 +104,16 @@ public:
return createOutOfBandError(Msg.c_str());
}
+ template <typename SPSArgListT, typename... ArgTs>
+ static WrapperFunctionResult fromSPSArgs(const ArgTs &...Args) {
+ auto Result = allocate(SPSArgListT::size(Args...));
+ SPSOutputBuffer OB(Result.data(), Result.size());
+ if (!SPSArgListT::serialize(OB, Args...))
+ return createOutOfBandError(
+ "Error serializing arguments to blob in call");
+ return Result;
+ }
+
/// If this value is an out-of-band error then this returns the error message,
/// otherwise returns nullptr.
const char *getOutOfBandError() const {
@@ -115,19 +126,6 @@ private:
namespace detail {
-template <typename SPSArgListT, typename... ArgTs>
-Expected<WrapperFunctionResult>
-serializeViaSPSToWrapperFunctionResult(const ArgTs &...Args) {
- WrapperFunctionResult Result;
- char *DataPtr =
- WrapperFunctionResult::allocate(Result, SPSArgListT::size(Args...));
- SPSOutputBuffer OB(DataPtr, Result.size());
- if (!SPSArgListT::serialize(OB, Args...))
- return make_error<StringError>(
- "Error serializing arguments to blob in call");
- return std::move(Result);
-}
-
template <typename RetT> class WrapperFunctionHandlerCaller {
public:
template <typename HandlerT, typename ArgTupleT, std::size_t... I>
@@ -173,12 +171,8 @@ public:
auto HandlerResult = WrapperFunctionHandlerCaller<RetT>::call(
std::forward<HandlerT>(H), Args, ArgIndices{});
- if (auto Result = ResultSerializer<decltype(HandlerResult)>::serialize(
- std::move(HandlerResult)))
- return std::move(*Result);
- else
- return WrapperFunctionResult::createOutOfBandError(
- toString(Result.takeError()));
+ return ResultSerializer<decltype(HandlerResult)>::serialize(
+ std::move(HandlerResult));
}
private:
@@ -188,13 +182,12 @@ private:
SPSInputBuffer IB(ArgData, ArgSize);
return SPSArgList<SPSTagTs...>::deserialize(IB, std::get<I>(Args)...);
}
-
};
-// Map function references to function types.
+// Map function pointers to function types.
template <typename RetT, typename... ArgTs,
template <typename> class ResultSerializer, typename... SPSTagTs>
-class WrapperFunctionHandlerHelper<RetT (&)(ArgTs...), ResultSerializer,
+class WrapperFunctionHandlerHelper<RetT (*)(ArgTs...), ResultSerializer,
SPSTagTs...>
: public WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer,
SPSTagTs...> {};
@@ -217,16 +210,15 @@ class WrapperFunctionHandlerHelper<RetT (ClassT::*)(ArgTs...) const,
template <typename SPSRetTagT, typename RetT> class ResultSerializer {
public:
- static Expected<WrapperFunctionResult> serialize(RetT Result) {
- return serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSRetTagT>>(
- Result);
+ static WrapperFunctionResult serialize(RetT Result) {
+ return WrapperFunctionResult::fromSPSArgs<SPSArgList<SPSRetTagT>>(Result);
}
};
template <typename SPSRetTagT> class ResultSerializer<SPSRetTagT, Error> {
public:
- static Expected<WrapperFunctionResult> serialize(Error Err) {
- return serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSRetTagT>>(
+ static WrapperFunctionResult serialize(Error Err) {
+ return WrapperFunctionResult::fromSPSArgs<SPSArgList<SPSRetTagT>>(
toSPSSerializable(std::move(Err)));
}
};
@@ -234,8 +226,8 @@ public:
template <typename SPSRetTagT, typename T>
class ResultSerializer<SPSRetTagT, Expected<T>> {
public:
- static Expected<WrapperFunctionResult> serialize(Expected<T> E) {
- return serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSRetTagT>>(
+ static WrapperFunctionResult serialize(Expected<T> E) {
+ return WrapperFunctionResult::fromSPSArgs<SPSArgList<SPSRetTagT>>(
toSPSSerializable(std::move(E)));
}
};
@@ -310,14 +302,12 @@ public:
return make_error<StringError>("__orc_rt_jit_dispatch not set");
auto ArgBuffer =
- detail::serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSTagTs...>>(
- Args...);
- if (!ArgBuffer)
- return ArgBuffer.takeError();
+ WrapperFunctionResult::fromSPSArgs<SPSArgList<SPSTagTs...>>(Args...);
+ if (const char *ErrMsg = ArgBuffer.getOutOfBandError())
+ return make_error<StringError>(ErrMsg);
- WrapperFunctionResult ResultBuffer =
- __orc_rt_jit_dispatch(&__orc_rt_jit_dispatch_ctx, FnTag,
- ArgBuffer->data(), ArgBuffer->size());
+ WrapperFunctionResult ResultBuffer = __orc_rt_jit_dispatch(
+ &__orc_rt_jit_dispatch_ctx, FnTag, ArgBuffer.data(), ArgBuffer.size());
if (auto ErrMsg = ResultBuffer.getOutOfBandError())
return make_error<StringError>(ErrMsg);
@@ -329,8 +319,8 @@ public:
static WrapperFunctionResult handle(const char *ArgData, size_t ArgSize,
HandlerT &&Handler) {
using WFHH =
- detail::WrapperFunctionHandlerHelper<HandlerT, ResultSerializer,
- SPSTagTs...>;
+ detail::WrapperFunctionHandlerHelper<std::remove_reference_t<HandlerT>,
+ ResultSerializer, SPSTagTs...>;
return WFHH::apply(std::forward<HandlerT>(Handler), ArgData, ArgSize);
}
@@ -362,6 +352,106 @@ public:
using WrapperFunction<SPSEmpty(SPSTagTs...)>::handle;
};
+/// A function object that takes an ExecutorAddr as its first argument,
+/// casts that address to a ClassT*, then calls the given method on that
+/// pointer passing in the remaining function arguments. This utility
+/// removes some of the boilerplate from writing wrappers for method calls.
+///
+/// @code{.cpp}
+/// class MyClass {
+/// public:
+/// void myMethod(uint32_t, bool) { ... }
+/// };
+///
+/// // SPS Method signature -- note MyClass object address as first argument.
+/// using SPSMyMethodWrapperSignature =
+/// SPSTuple<SPSExecutorAddr, uint32_t, bool>;
+///
+/// WrapperFunctionResult
+/// myMethodCallWrapper(const char *ArgData, size_t ArgSize) {
+/// return WrapperFunction<SPSMyMethodWrapperSignature>::handle(
+/// ArgData, ArgSize, makeMethodWrapperHandler(&MyClass::myMethod));
+/// }
+/// @endcode
+///
+template <typename RetT, typename ClassT, typename... ArgTs>
+class MethodWrapperHandler {
+public:
+ using MethodT = RetT (ClassT::*)(ArgTs...);
+ MethodWrapperHandler(MethodT M) : M(M) {}
+ RetT operator()(ExecutorAddr ObjAddr, ArgTs &...Args) {
+ return (ObjAddr.toPtr<ClassT *>()->*M)(std::forward<ArgTs>(Args)...);
+ }
+
+private:
+ MethodT M;
+};
+
+/// Create a MethodWrapperHandler object from the given method pointer.
+template <typename RetT, typename ClassT, typename... ArgTs>
+MethodWrapperHandler<RetT, ClassT, ArgTs...>
+makeMethodWrapperHandler(RetT (ClassT::*Method)(ArgTs...)) {
+ return MethodWrapperHandler<RetT, ClassT, ArgTs...>(Method);
+}
+
+/// Represents a call to a wrapper function.
+struct WrapperFunctionCall {
+ ExecutorAddr Func;
+ ExecutorAddrRange ArgData;
+
+ WrapperFunctionCall() = default;
+ WrapperFunctionCall(ExecutorAddr Func, ExecutorAddrRange ArgData)
+ : Func(Func), ArgData(ArgData) {}
+
+ /// Run and return result as WrapperFunctionResult.
+ WrapperFunctionResult run() {
+ WrapperFunctionResult WFR(
+ Func.toPtr<__orc_rt_CWrapperFunctionResult (*)(const char *, size_t)>()(
+ ArgData.Start.toPtr<const char *>(),
+ static_cast<size_t>(ArgData.size().getValue())));
+ return WFR;
+ }
+
+ /// Run call and deserialize result using SPS.
+ template <typename SPSRetT, typename RetT> Error runWithSPSRet(RetT &RetVal) {
+ auto WFR = run();
+ if (const char *ErrMsg = WFR.getOutOfBandError())
+ return make_error<StringError>(ErrMsg);
+ SPSInputBuffer IB(WFR.data(), WFR.size());
+ if (!SPSSerializationTraits<SPSRetT, RetT>::deserialize(IB, RetVal))
+ return make_error<StringError>("Could not deserialize result from "
+ "serialized wrapper function call");
+ return Error::success();
+ }
+
+ /// Overload for SPS functions returning void.
+ Error runWithSPSRet() {
+ SPSEmpty E;
+ return runWithSPSRet<SPSEmpty>(E);
+ }
+};
+
+class SPSWrapperFunctionCall {};
+
+template <>
+class SPSSerializationTraits<SPSWrapperFunctionCall, WrapperFunctionCall> {
+public:
+ static size_t size(const WrapperFunctionCall &WFC) {
+ return SPSArgList<SPSExecutorAddr, SPSExecutorAddrRange>::size(WFC.Func,
+ WFC.ArgData);
+ }
+
+ static bool serialize(SPSOutputBuffer &OB, const WrapperFunctionCall &WFC) {
+ return SPSArgList<SPSExecutorAddr, SPSExecutorAddrRange>::serialize(
+ OB, WFC.Func, WFC.ArgData);
+ }
+
+ static bool deserialize(SPSInputBuffer &IB, WrapperFunctionCall &WFC) {
+ return SPSArgList<SPSExecutorAddr, SPSExecutorAddrRange>::deserialize(
+ IB, WFC.Func, WFC.ArgData);
+ }
+};
+
} // end namespace __orc_rt
#endif // ORC_RT_WRAPPER_FUNCTION_UTILS_H