diff options
| author | Dimitry Andric <dim@FreeBSD.org> | 2021-11-19 20:06:13 +0000 |
|---|---|---|
| committer | Dimitry Andric <dim@FreeBSD.org> | 2021-11-19 20:06:13 +0000 |
| commit | c0981da47d5696fe36474fcf86b4ce03ae3ff818 (patch) | |
| tree | f42add1021b9f2ac6a69ac7cf6c4499962739a45 /compiler-rt/lib/orc/wrapper_function_utils.h | |
| parent | 344a3780b2e33f6ca763666c380202b18aab72a3 (diff) | |
Diffstat (limited to 'compiler-rt/lib/orc/wrapper_function_utils.h')
| -rw-r--r-- | compiler-rt/lib/orc/wrapper_function_utils.h | 176 |
1 files changed, 133 insertions, 43 deletions
diff --git a/compiler-rt/lib/orc/wrapper_function_utils.h b/compiler-rt/lib/orc/wrapper_function_utils.h index 49faa03e5eb8..23385e1bd794 100644 --- a/compiler-rt/lib/orc/wrapper_function_utils.h +++ b/compiler-rt/lib/orc/wrapper_function_utils.h @@ -16,6 +16,7 @@ #include "c_api.h" #include "common.h" #include "error.h" +#include "executor_address.h" #include "simple_packed_serialization.h" #include <type_traits> @@ -61,7 +62,7 @@ public: } /// Get a pointer to the data contained in this instance. - const char *data() const { return __orc_rt_CWrapperFunctionResultData(&R); } + char *data() { return __orc_rt_CWrapperFunctionResultData(&R); } /// Returns the size of the data contained in this instance. size_t size() const { return __orc_rt_CWrapperFunctionResultSize(&R); } @@ -72,10 +73,10 @@ public: /// Create a WrapperFunctionResult with the given size and return a pointer /// to the underlying memory. - static char *allocate(WrapperFunctionResult &R, size_t Size) { - __orc_rt_DisposeCWrapperFunctionResult(&R.R); - __orc_rt_CWrapperFunctionResultInit(&R.R); - return __orc_rt_CWrapperFunctionResultAllocate(&R.R, Size); + static WrapperFunctionResult allocate(size_t Size) { + WrapperFunctionResult R; + R.R = __orc_rt_CWrapperFunctionResultAllocate(Size); + return R; } /// Copy from the given char range. @@ -103,6 +104,16 @@ public: return createOutOfBandError(Msg.c_str()); } + template <typename SPSArgListT, typename... ArgTs> + static WrapperFunctionResult fromSPSArgs(const ArgTs &...Args) { + auto Result = allocate(SPSArgListT::size(Args...)); + SPSOutputBuffer OB(Result.data(), Result.size()); + if (!SPSArgListT::serialize(OB, Args...)) + return createOutOfBandError( + "Error serializing arguments to blob in call"); + return Result; + } + /// If this value is an out-of-band error then this returns the error message, /// otherwise returns nullptr. const char *getOutOfBandError() const { @@ -115,19 +126,6 @@ private: namespace detail { -template <typename SPSArgListT, typename... ArgTs> -Expected<WrapperFunctionResult> -serializeViaSPSToWrapperFunctionResult(const ArgTs &...Args) { - WrapperFunctionResult Result; - char *DataPtr = - WrapperFunctionResult::allocate(Result, SPSArgListT::size(Args...)); - SPSOutputBuffer OB(DataPtr, Result.size()); - if (!SPSArgListT::serialize(OB, Args...)) - return make_error<StringError>( - "Error serializing arguments to blob in call"); - return std::move(Result); -} - template <typename RetT> class WrapperFunctionHandlerCaller { public: template <typename HandlerT, typename ArgTupleT, std::size_t... I> @@ -173,12 +171,8 @@ public: auto HandlerResult = WrapperFunctionHandlerCaller<RetT>::call( std::forward<HandlerT>(H), Args, ArgIndices{}); - if (auto Result = ResultSerializer<decltype(HandlerResult)>::serialize( - std::move(HandlerResult))) - return std::move(*Result); - else - return WrapperFunctionResult::createOutOfBandError( - toString(Result.takeError())); + return ResultSerializer<decltype(HandlerResult)>::serialize( + std::move(HandlerResult)); } private: @@ -188,13 +182,12 @@ private: SPSInputBuffer IB(ArgData, ArgSize); return SPSArgList<SPSTagTs...>::deserialize(IB, std::get<I>(Args)...); } - }; -// Map function references to function types. +// Map function pointers to function types. template <typename RetT, typename... ArgTs, template <typename> class ResultSerializer, typename... SPSTagTs> -class WrapperFunctionHandlerHelper<RetT (&)(ArgTs...), ResultSerializer, +class WrapperFunctionHandlerHelper<RetT (*)(ArgTs...), ResultSerializer, SPSTagTs...> : public WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer, SPSTagTs...> {}; @@ -217,16 +210,15 @@ class WrapperFunctionHandlerHelper<RetT (ClassT::*)(ArgTs...) const, template <typename SPSRetTagT, typename RetT> class ResultSerializer { public: - static Expected<WrapperFunctionResult> serialize(RetT Result) { - return serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSRetTagT>>( - Result); + static WrapperFunctionResult serialize(RetT Result) { + return WrapperFunctionResult::fromSPSArgs<SPSArgList<SPSRetTagT>>(Result); } }; template <typename SPSRetTagT> class ResultSerializer<SPSRetTagT, Error> { public: - static Expected<WrapperFunctionResult> serialize(Error Err) { - return serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSRetTagT>>( + static WrapperFunctionResult serialize(Error Err) { + return WrapperFunctionResult::fromSPSArgs<SPSArgList<SPSRetTagT>>( toSPSSerializable(std::move(Err))); } }; @@ -234,8 +226,8 @@ public: template <typename SPSRetTagT, typename T> class ResultSerializer<SPSRetTagT, Expected<T>> { public: - static Expected<WrapperFunctionResult> serialize(Expected<T> E) { - return serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSRetTagT>>( + static WrapperFunctionResult serialize(Expected<T> E) { + return WrapperFunctionResult::fromSPSArgs<SPSArgList<SPSRetTagT>>( toSPSSerializable(std::move(E))); } }; @@ -310,14 +302,12 @@ public: return make_error<StringError>("__orc_rt_jit_dispatch not set"); auto ArgBuffer = - detail::serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSTagTs...>>( - Args...); - if (!ArgBuffer) - return ArgBuffer.takeError(); + WrapperFunctionResult::fromSPSArgs<SPSArgList<SPSTagTs...>>(Args...); + if (const char *ErrMsg = ArgBuffer.getOutOfBandError()) + return make_error<StringError>(ErrMsg); - WrapperFunctionResult ResultBuffer = - __orc_rt_jit_dispatch(&__orc_rt_jit_dispatch_ctx, FnTag, - ArgBuffer->data(), ArgBuffer->size()); + WrapperFunctionResult ResultBuffer = __orc_rt_jit_dispatch( + &__orc_rt_jit_dispatch_ctx, FnTag, ArgBuffer.data(), ArgBuffer.size()); if (auto ErrMsg = ResultBuffer.getOutOfBandError()) return make_error<StringError>(ErrMsg); @@ -329,8 +319,8 @@ public: static WrapperFunctionResult handle(const char *ArgData, size_t ArgSize, HandlerT &&Handler) { using WFHH = - detail::WrapperFunctionHandlerHelper<HandlerT, ResultSerializer, - SPSTagTs...>; + detail::WrapperFunctionHandlerHelper<std::remove_reference_t<HandlerT>, + ResultSerializer, SPSTagTs...>; return WFHH::apply(std::forward<HandlerT>(Handler), ArgData, ArgSize); } @@ -362,6 +352,106 @@ public: using WrapperFunction<SPSEmpty(SPSTagTs...)>::handle; }; +/// A function object that takes an ExecutorAddr as its first argument, +/// casts that address to a ClassT*, then calls the given method on that +/// pointer passing in the remaining function arguments. This utility +/// removes some of the boilerplate from writing wrappers for method calls. +/// +/// @code{.cpp} +/// class MyClass { +/// public: +/// void myMethod(uint32_t, bool) { ... } +/// }; +/// +/// // SPS Method signature -- note MyClass object address as first argument. +/// using SPSMyMethodWrapperSignature = +/// SPSTuple<SPSExecutorAddr, uint32_t, bool>; +/// +/// WrapperFunctionResult +/// myMethodCallWrapper(const char *ArgData, size_t ArgSize) { +/// return WrapperFunction<SPSMyMethodWrapperSignature>::handle( +/// ArgData, ArgSize, makeMethodWrapperHandler(&MyClass::myMethod)); +/// } +/// @endcode +/// +template <typename RetT, typename ClassT, typename... ArgTs> +class MethodWrapperHandler { +public: + using MethodT = RetT (ClassT::*)(ArgTs...); + MethodWrapperHandler(MethodT M) : M(M) {} + RetT operator()(ExecutorAddr ObjAddr, ArgTs &...Args) { + return (ObjAddr.toPtr<ClassT *>()->*M)(std::forward<ArgTs>(Args)...); + } + +private: + MethodT M; +}; + +/// Create a MethodWrapperHandler object from the given method pointer. +template <typename RetT, typename ClassT, typename... ArgTs> +MethodWrapperHandler<RetT, ClassT, ArgTs...> +makeMethodWrapperHandler(RetT (ClassT::*Method)(ArgTs...)) { + return MethodWrapperHandler<RetT, ClassT, ArgTs...>(Method); +} + +/// Represents a call to a wrapper function. +struct WrapperFunctionCall { + ExecutorAddr Func; + ExecutorAddrRange ArgData; + + WrapperFunctionCall() = default; + WrapperFunctionCall(ExecutorAddr Func, ExecutorAddrRange ArgData) + : Func(Func), ArgData(ArgData) {} + + /// Run and return result as WrapperFunctionResult. + WrapperFunctionResult run() { + WrapperFunctionResult WFR( + Func.toPtr<__orc_rt_CWrapperFunctionResult (*)(const char *, size_t)>()( + ArgData.Start.toPtr<const char *>(), + static_cast<size_t>(ArgData.size().getValue()))); + return WFR; + } + + /// Run call and deserialize result using SPS. + template <typename SPSRetT, typename RetT> Error runWithSPSRet(RetT &RetVal) { + auto WFR = run(); + if (const char *ErrMsg = WFR.getOutOfBandError()) + return make_error<StringError>(ErrMsg); + SPSInputBuffer IB(WFR.data(), WFR.size()); + if (!SPSSerializationTraits<SPSRetT, RetT>::deserialize(IB, RetVal)) + return make_error<StringError>("Could not deserialize result from " + "serialized wrapper function call"); + return Error::success(); + } + + /// Overload for SPS functions returning void. + Error runWithSPSRet() { + SPSEmpty E; + return runWithSPSRet<SPSEmpty>(E); + } +}; + +class SPSWrapperFunctionCall {}; + +template <> +class SPSSerializationTraits<SPSWrapperFunctionCall, WrapperFunctionCall> { +public: + static size_t size(const WrapperFunctionCall &WFC) { + return SPSArgList<SPSExecutorAddr, SPSExecutorAddrRange>::size(WFC.Func, + WFC.ArgData); + } + + static bool serialize(SPSOutputBuffer &OB, const WrapperFunctionCall &WFC) { + return SPSArgList<SPSExecutorAddr, SPSExecutorAddrRange>::serialize( + OB, WFC.Func, WFC.ArgData); + } + + static bool deserialize(SPSInputBuffer &IB, WrapperFunctionCall &WFC) { + return SPSArgList<SPSExecutorAddr, SPSExecutorAddrRange>::deserialize( + IB, WFC.Func, WFC.ArgData); + } +}; + } // end namespace __orc_rt #endif // ORC_RT_WRAPPER_FUNCTION_UTILS_H |
