diff options
Diffstat (limited to 'unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp')
-rw-r--r-- | unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp | 486 |
1 files changed, 406 insertions, 80 deletions
diff --git a/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp b/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp index 186c3d408486..1c9764b555fd 100644 --- a/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp +++ b/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp @@ -7,8 +7,8 @@ // //===----------------------------------------------------------------------===// -#include "llvm/ExecutionEngine/Orc/RawByteChannel.h" #include "llvm/ExecutionEngine/Orc/RPCUtils.h" +#include "QueueChannel.h" #include "gtest/gtest.h" #include <queue> @@ -17,47 +17,6 @@ using namespace llvm; using namespace llvm::orc; using namespace llvm::orc::rpc; -class Queue : public std::queue<char> { -public: - std::mutex &getMutex() { return M; } - std::condition_variable &getCondVar() { return CV; } -private: - std::mutex M; - std::condition_variable CV; -}; - -class QueueChannel : public RawByteChannel { -public: - QueueChannel(Queue &InQueue, Queue &OutQueue) - : InQueue(InQueue), OutQueue(OutQueue) {} - - Error readBytes(char *Dst, unsigned Size) override { - std::unique_lock<std::mutex> Lock(InQueue.getMutex()); - while (Size) { - while (InQueue.empty()) - InQueue.getCondVar().wait(Lock); - *Dst++ = InQueue.front(); - --Size; - InQueue.pop(); - } - return Error::success(); - } - - Error appendBytes(const char *Src, unsigned Size) override { - std::unique_lock<std::mutex> Lock(OutQueue.getMutex()); - while (Size--) - OutQueue.push(*Src++); - OutQueue.getCondVar().notify_one(); - return Error::success(); - } - - Error send() override { return Error::success(); } - -private: - Queue &InQueue; - Queue &OutQueue; -}; - class RPCFoo {}; namespace llvm { @@ -88,6 +47,54 @@ namespace rpc { class RPCBar {}; +class DummyError : public ErrorInfo<DummyError> { +public: + + static char ID; + + DummyError(uint32_t Val) : Val(Val) {} + + std::error_code convertToErrorCode() const override { + // Use a nonsense error code - we want to verify that errors + // transmitted over the network are replaced with + // OrcErrorCode::UnknownErrorCodeFromRemote. + return orcError(OrcErrorCode::RemoteAllocatorDoesNotExist); + } + + void log(raw_ostream &OS) const override { + OS << "Dummy error " << Val; + } + + uint32_t getValue() const { return Val; } + +public: + uint32_t Val; +}; + +char DummyError::ID = 0; + +template <typename ChannelT> +void registerDummyErrorSerialization() { + static bool AlreadyRegistered = false; + if (!AlreadyRegistered) { + SerializationTraits<ChannelT, Error>:: + template registerErrorType<DummyError>( + "DummyError", + [](ChannelT &C, const DummyError &DE) { + return serializeSeq(C, DE.getValue()); + }, + [](ChannelT &C, Error &Err) -> Error { + ErrorAsOutParameter EAO(&Err); + uint32_t Val; + if (auto Err = deserializeSeq(C, Val)) + return Err; + Err = make_error<DummyError>(Val); + return Error::success(); + }); + AlreadyRegistered = true; + } +} + namespace llvm { namespace orc { namespace rpc { @@ -120,6 +127,11 @@ namespace DummyRPCAPI { static const char* getName() { return "IntInt"; } }; + class VoidString : public Function<VoidString, void(std::string)> { + public: + static const char* getName() { return "VoidString"; } + }; + class AllTheTypes : public Function<AllTheTypes, void(int8_t, uint8_t, int16_t, uint16_t, int32_t, @@ -134,21 +146,38 @@ namespace DummyRPCAPI { static const char* getName() { return "CustomType"; } }; + class ErrorFunc : public Function<ErrorFunc, Error()> { + public: + static const char* getName() { return "ErrorFunc"; } + }; + + class ExpectedFunc : public Function<ExpectedFunc, Expected<uint32_t>()> { + public: + static const char* getName() { return "ExpectedFunc"; } + }; + } class DummyRPCEndpoint : public SingleThreadedRPCEndpoint<QueueChannel> { public: - DummyRPCEndpoint(Queue &Q1, Queue &Q2) - : SingleThreadedRPCEndpoint(C, true), C(Q1, Q2) {} -private: - QueueChannel C; + DummyRPCEndpoint(QueueChannel &C) + : SingleThreadedRPCEndpoint(C, true) {} }; -TEST(DummyRPC, TestAsyncVoidBool) { - Queue Q1, Q2; - DummyRPCEndpoint Client(Q1, Q2); - DummyRPCEndpoint Server(Q2, Q1); +void freeVoidBool(bool B) { +} + +TEST(DummyRPC, TestFreeFunctionHandler) { + auto Channels = createPairedQueueChannels(); + DummyRPCEndpoint Server(*Channels.first); + Server.addHandler<DummyRPCAPI::VoidBool>(freeVoidBool); +} + +TEST(DummyRPC, TestCallAsyncVoidBool) { + auto Channels = createPairedQueueChannels(); + DummyRPCEndpoint Client(*Channels.first); + DummyRPCEndpoint Server(*Channels.second); std::thread ServerThread([&]() { Server.addHandler<DummyRPCAPI::VoidBool>( @@ -189,10 +218,10 @@ TEST(DummyRPC, TestAsyncVoidBool) { ServerThread.join(); } -TEST(DummyRPC, TestAsyncIntInt) { - Queue Q1, Q2; - DummyRPCEndpoint Client(Q1, Q2); - DummyRPCEndpoint Server(Q2, Q1); +TEST(DummyRPC, TestCallAsyncIntInt) { + auto Channels = createPairedQueueChannels(); + DummyRPCEndpoint Client(*Channels.first); + DummyRPCEndpoint Server(*Channels.second); std::thread ServerThread([&]() { Server.addHandler<DummyRPCAPI::IntInt>( @@ -234,10 +263,147 @@ TEST(DummyRPC, TestAsyncIntInt) { ServerThread.join(); } +TEST(DummyRPC, TestAsyncIntIntHandler) { + auto Channels = createPairedQueueChannels(); + DummyRPCEndpoint Client(*Channels.first); + DummyRPCEndpoint Server(*Channels.second); + + std::thread ServerThread([&]() { + Server.addAsyncHandler<DummyRPCAPI::IntInt>( + [](std::function<Error(Expected<int32_t>)> SendResult, + int32_t X) { + EXPECT_EQ(X, 21) << "Server int(int) receieved unexpected result"; + return SendResult(2 * X); + }); + + { + // Poke the server to handle the negotiate call. + auto Err = Server.handleOne(); + EXPECT_FALSE(!!Err) << "Server failed to handle call to negotiate"; + } + + { + // Poke the server to handle the VoidBool call. + auto Err = Server.handleOne(); + EXPECT_FALSE(!!Err) << "Server failed to handle call to void(bool)"; + } + }); + + { + auto Err = Client.callAsync<DummyRPCAPI::IntInt>( + [](Expected<int> Result) { + EXPECT_TRUE(!!Result) << "Async int(int) response handler failed"; + EXPECT_EQ(*Result, 42) + << "Async int(int) response handler received incorrect result"; + return Error::success(); + }, 21); + EXPECT_FALSE(!!Err) << "Client.callAsync failed for int(int)"; + } + + { + // Poke the client to process the result. + auto Err = Client.handleOne(); + EXPECT_FALSE(!!Err) << "Client failed to handle response from void(bool)"; + } + + ServerThread.join(); +} + +TEST(DummyRPC, TestAsyncIntIntHandlerMethod) { + auto Channels = createPairedQueueChannels(); + DummyRPCEndpoint Client(*Channels.first); + DummyRPCEndpoint Server(*Channels.second); + + class Dummy { + public: + Error handler(std::function<Error(Expected<int32_t>)> SendResult, + int32_t X) { + EXPECT_EQ(X, 21) << "Server int(int) receieved unexpected result"; + return SendResult(2 * X); + } + }; + + std::thread ServerThread([&]() { + Dummy D; + Server.addAsyncHandler<DummyRPCAPI::IntInt>(D, &Dummy::handler); + + { + // Poke the server to handle the negotiate call. + auto Err = Server.handleOne(); + EXPECT_FALSE(!!Err) << "Server failed to handle call to negotiate"; + } + + { + // Poke the server to handle the VoidBool call. + auto Err = Server.handleOne(); + EXPECT_FALSE(!!Err) << "Server failed to handle call to void(bool)"; + } + }); + + { + auto Err = Client.callAsync<DummyRPCAPI::IntInt>( + [](Expected<int> Result) { + EXPECT_TRUE(!!Result) << "Async int(int) response handler failed"; + EXPECT_EQ(*Result, 42) + << "Async int(int) response handler received incorrect result"; + return Error::success(); + }, 21); + EXPECT_FALSE(!!Err) << "Client.callAsync failed for int(int)"; + } + + { + // Poke the client to process the result. + auto Err = Client.handleOne(); + EXPECT_FALSE(!!Err) << "Client failed to handle response from void(bool)"; + } + + ServerThread.join(); +} + +TEST(DummyRPC, TestCallAsyncVoidString) { + auto Channels = createPairedQueueChannels(); + DummyRPCEndpoint Client(*Channels.first); + DummyRPCEndpoint Server(*Channels.second); + + std::thread ServerThread([&]() { + Server.addHandler<DummyRPCAPI::VoidString>( + [](const std::string &S) { + EXPECT_EQ(S, "hello") + << "Server void(std::string) received unexpected result"; + }); + + // Poke the server to handle the negotiate call. + for (int I = 0; I < 4; ++I) { + auto Err = Server.handleOne(); + EXPECT_FALSE(!!Err) << "Server failed to handle call"; + } + }); + + { + // Make an call using a std::string. + auto Err = Client.callB<DummyRPCAPI::VoidString>(std::string("hello")); + EXPECT_FALSE(!!Err) << "Client.callAsync failed for void(std::string)"; + } + + { + // Make an call using a std::string. + auto Err = Client.callB<DummyRPCAPI::VoidString>(StringRef("hello")); + EXPECT_FALSE(!!Err) << "Client.callAsync failed for void(std::string)"; + } + + { + // Make an call using a std::string. + auto Err = Client.callB<DummyRPCAPI::VoidString>("hello"); + EXPECT_FALSE(!!Err) << "Client.callAsync failed for void(string)"; + } + + ServerThread.join(); +} + TEST(DummyRPC, TestSerialization) { - Queue Q1, Q2; - DummyRPCEndpoint Client(Q1, Q2); - DummyRPCEndpoint Server(Q2, Q1); + auto Channels = createPairedQueueChannels(); + DummyRPCEndpoint Client(*Channels.first); + DummyRPCEndpoint Server(*Channels.second); std::thread ServerThread([&]() { Server.addHandler<DummyRPCAPI::AllTheTypes>( @@ -300,9 +466,9 @@ TEST(DummyRPC, TestSerialization) { } TEST(DummyRPC, TestCustomType) { - Queue Q1, Q2; - DummyRPCEndpoint Client(Q1, Q2); - DummyRPCEndpoint Server(Q2, Q1); + auto Channels = createPairedQueueChannels(); + DummyRPCEndpoint Client(*Channels.first); + DummyRPCEndpoint Server(*Channels.second); std::thread ServerThread([&]() { Server.addHandler<DummyRPCAPI::CustomType>( @@ -343,9 +509,9 @@ TEST(DummyRPC, TestCustomType) { } TEST(DummyRPC, TestWithAltCustomType) { - Queue Q1, Q2; - DummyRPCEndpoint Client(Q1, Q2); - DummyRPCEndpoint Server(Q2, Q1); + auto Channels = createPairedQueueChannels(); + DummyRPCEndpoint Client(*Channels.first); + DummyRPCEndpoint Server(*Channels.second); std::thread ServerThread([&]() { Server.addHandler<DummyRPCAPI::CustomType>( @@ -385,10 +551,144 @@ TEST(DummyRPC, TestWithAltCustomType) { ServerThread.join(); } +TEST(DummyRPC, ReturnErrorSuccess) { + registerDummyErrorSerialization<QueueChannel>(); + + auto Channels = createPairedQueueChannels(); + DummyRPCEndpoint Client(*Channels.first); + DummyRPCEndpoint Server(*Channels.second); + + std::thread ServerThread([&]() { + Server.addHandler<DummyRPCAPI::ErrorFunc>( + []() { + return Error::success(); + }); + + // Handle the negotiate plus one call. + for (unsigned I = 0; I != 2; ++I) + cantFail(Server.handleOne()); + }); + + cantFail(Client.callAsync<DummyRPCAPI::ErrorFunc>( + [&](Error Err) { + EXPECT_FALSE(!!Err) << "Expected success value"; + return Error::success(); + })); + + cantFail(Client.handleOne()); + + ServerThread.join(); +} + +TEST(DummyRPC, ReturnErrorFailure) { + registerDummyErrorSerialization<QueueChannel>(); + + auto Channels = createPairedQueueChannels(); + DummyRPCEndpoint Client(*Channels.first); + DummyRPCEndpoint Server(*Channels.second); + + std::thread ServerThread([&]() { + Server.addHandler<DummyRPCAPI::ErrorFunc>( + []() { + return make_error<DummyError>(42); + }); + + // Handle the negotiate plus one call. + for (unsigned I = 0; I != 2; ++I) + cantFail(Server.handleOne()); + }); + + cantFail(Client.callAsync<DummyRPCAPI::ErrorFunc>( + [&](Error Err) { + EXPECT_TRUE(Err.isA<DummyError>()) + << "Incorrect error type"; + return handleErrors( + std::move(Err), + [](const DummyError &DE) { + EXPECT_EQ(DE.getValue(), 42ULL) + << "Incorrect DummyError serialization"; + }); + })); + + cantFail(Client.handleOne()); + + ServerThread.join(); +} + +TEST(DummyRPC, ReturnExpectedSuccess) { + registerDummyErrorSerialization<QueueChannel>(); + + auto Channels = createPairedQueueChannels(); + DummyRPCEndpoint Client(*Channels.first); + DummyRPCEndpoint Server(*Channels.second); + + std::thread ServerThread([&]() { + Server.addHandler<DummyRPCAPI::ExpectedFunc>( + []() -> uint32_t { + return 42; + }); + + // Handle the negotiate plus one call. + for (unsigned I = 0; I != 2; ++I) + cantFail(Server.handleOne()); + }); + + cantFail(Client.callAsync<DummyRPCAPI::ExpectedFunc>( + [&](Expected<uint32_t> ValOrErr) { + EXPECT_TRUE(!!ValOrErr) + << "Expected success value"; + EXPECT_EQ(*ValOrErr, 42ULL) + << "Incorrect Expected<uint32_t> deserialization"; + return Error::success(); + })); + + cantFail(Client.handleOne()); + + ServerThread.join(); +} + +TEST(DummyRPC, ReturnExpectedFailure) { + registerDummyErrorSerialization<QueueChannel>(); + + auto Channels = createPairedQueueChannels(); + DummyRPCEndpoint Client(*Channels.first); + DummyRPCEndpoint Server(*Channels.second); + + std::thread ServerThread([&]() { + Server.addHandler<DummyRPCAPI::ExpectedFunc>( + []() -> Expected<uint32_t> { + return make_error<DummyError>(7); + }); + + // Handle the negotiate plus one call. + for (unsigned I = 0; I != 2; ++I) + cantFail(Server.handleOne()); + }); + + cantFail(Client.callAsync<DummyRPCAPI::ExpectedFunc>( + [&](Expected<uint32_t> ValOrErr) { + EXPECT_FALSE(!!ValOrErr) + << "Expected failure value"; + auto Err = ValOrErr.takeError(); + EXPECT_TRUE(Err.isA<DummyError>()) + << "Incorrect error type"; + return handleErrors( + std::move(Err), + [](const DummyError &DE) { + EXPECT_EQ(DE.getValue(), 7ULL) + << "Incorrect DummyError serialization"; + }); + })); + + cantFail(Client.handleOne()); + + ServerThread.join(); +} + TEST(DummyRPC, TestParallelCallGroup) { - Queue Q1, Q2; - DummyRPCEndpoint Client(Q1, Q2); - DummyRPCEndpoint Server(Q2, Q1); + auto Channels = createPairedQueueChannels(); + DummyRPCEndpoint Client(*Channels.first); + DummyRPCEndpoint Server(*Channels.second); std::thread ServerThread([&]() { Server.addHandler<DummyRPCAPI::IntInt>( @@ -405,10 +705,11 @@ TEST(DummyRPC, TestParallelCallGroup) { { int A, B, C; - ParallelCallGroup<DummyRPCEndpoint> PCG(Client); + ParallelCallGroup PCG; { - auto Err = PCG.appendCall<DummyRPCAPI::IntInt>( + auto Err = PCG.call( + rpcAsyncDispatch<DummyRPCAPI::IntInt>(Client), [&A](Expected<int> Result) { EXPECT_TRUE(!!Result) << "Async int(int) response handler failed"; A = *Result; @@ -418,7 +719,8 @@ TEST(DummyRPC, TestParallelCallGroup) { } { - auto Err = PCG.appendCall<DummyRPCAPI::IntInt>( + auto Err = PCG.call( + rpcAsyncDispatch<DummyRPCAPI::IntInt>(Client), [&B](Expected<int> Result) { EXPECT_TRUE(!!Result) << "Async int(int) response handler failed"; B = *Result; @@ -428,7 +730,8 @@ TEST(DummyRPC, TestParallelCallGroup) { } { - auto Err = PCG.appendCall<DummyRPCAPI::IntInt>( + auto Err = PCG.call( + rpcAsyncDispatch<DummyRPCAPI::IntInt>(Client), [&C](Expected<int> Result) { EXPECT_TRUE(!!Result) << "Async int(int) response handler failed"; C = *Result; @@ -443,10 +746,7 @@ TEST(DummyRPC, TestParallelCallGroup) { EXPECT_FALSE(!!Err) << "Client failed to handle response from void(bool)"; } - { - auto Err = PCG.wait(); - EXPECT_FALSE(!!Err) << "Third parallel call failed for int(int)"; - } + PCG.wait(); EXPECT_EQ(A, 2) << "First parallel call returned bogus result"; EXPECT_EQ(B, 4) << "Second parallel call returned bogus result"; @@ -468,9 +768,9 @@ TEST(DummyRPC, TestAPICalls) { static_assert(!DummyCalls1::Contains<DummyRPCAPI::CustomType>::value, "Contains<Func> template should return false here"); - Queue Q1, Q2; - DummyRPCEndpoint Client(Q1, Q2); - DummyRPCEndpoint Server(Q2, Q1); + auto Channels = createPairedQueueChannels(); + DummyRPCEndpoint Client(*Channels.first); + DummyRPCEndpoint Server(*Channels.second); std::thread ServerThread( [&]() { @@ -496,11 +796,37 @@ TEST(DummyRPC, TestAPICalls) { { auto Err = DummyCallsAll::negotiate(Client); - EXPECT_EQ(errorToErrorCode(std::move(Err)).value(), - static_cast<int>(OrcErrorCode::UnknownRPCFunction)) - << "Expected 'UnknownRPCFunction' error for attempted negotiate of " + EXPECT_TRUE(Err.isA<CouldNotNegotiate>()) + << "Expected CouldNotNegotiate error for attempted negotiate of " "unsupported function"; + consumeError(std::move(Err)); } ServerThread.join(); } + +TEST(DummyRPC, TestRemoveHandler) { + auto Channels = createPairedQueueChannels(); + DummyRPCEndpoint Server(*Channels.second); + + Server.addHandler<DummyRPCAPI::VoidBool>( + [](bool B) { + EXPECT_EQ(B, true) + << "Server void(bool) received unexpected result"; + }); + + Server.removeHandler<DummyRPCAPI::VoidBool>(); +} + +TEST(DummyRPC, TestClearHandlers) { + auto Channels = createPairedQueueChannels(); + DummyRPCEndpoint Server(*Channels.second); + + Server.addHandler<DummyRPCAPI::VoidBool>( + [](bool B) { + EXPECT_EQ(B, true) + << "Server void(bool) received unexpected result"; + }); + + Server.clearHandlers(); +} |