diff options
Diffstat (limited to 'unittests/ExecutionEngine/Orc')
-rw-r--r-- | unittests/ExecutionEngine/Orc/CMakeLists.txt | 5 | ||||
-rw-r--r-- | unittests/ExecutionEngine/Orc/IndirectionUtilsTest.cpp | 18 | ||||
-rw-r--r-- | unittests/ExecutionEngine/Orc/ObjectTransformLayerTest.cpp | 4 | ||||
-rw-r--r-- | unittests/ExecutionEngine/Orc/OrcTestCommon.cpp | 2 | ||||
-rw-r--r-- | unittests/ExecutionEngine/Orc/OrcTestCommon.h | 22 | ||||
-rw-r--r-- | unittests/ExecutionEngine/Orc/QueueChannel.cpp | 14 | ||||
-rw-r--r-- | unittests/ExecutionEngine/Orc/QueueChannel.h | 146 | ||||
-rw-r--r-- | unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp | 486 | ||||
-rw-r--r-- | unittests/ExecutionEngine/Orc/RTDyldObjectLinkingLayerTest.cpp (renamed from unittests/ExecutionEngine/Orc/ObjectLinkingLayerTest.cpp) | 28 |
9 files changed, 614 insertions, 111 deletions
diff --git a/unittests/ExecutionEngine/Orc/CMakeLists.txt b/unittests/ExecutionEngine/Orc/CMakeLists.txt index 68f6d0c28d7c..db40c4213bd7 100644 --- a/unittests/ExecutionEngine/Orc/CMakeLists.txt +++ b/unittests/ExecutionEngine/Orc/CMakeLists.txt @@ -14,11 +14,12 @@ add_llvm_unittest(OrcJITTests IndirectionUtilsTest.cpp GlobalMappingLayerTest.cpp LazyEmittingLayerTest.cpp - ObjectLinkingLayerTest.cpp ObjectTransformLayerTest.cpp OrcCAPITest.cpp OrcTestCommon.cpp + QueueChannel.cpp RPCUtilsTest.cpp + RTDyldObjectLinkingLayerTest.cpp ) -target_link_libraries(OrcJITTests ${PTHREAD_LIB}) +target_link_libraries(OrcJITTests ${LLVM_PTHREAD_LIB}) diff --git a/unittests/ExecutionEngine/Orc/IndirectionUtilsTest.cpp b/unittests/ExecutionEngine/Orc/IndirectionUtilsTest.cpp index ac847039d9fb..4af3aa707a90 100644 --- a/unittests/ExecutionEngine/Orc/IndirectionUtilsTest.cpp +++ b/unittests/ExecutionEngine/Orc/IndirectionUtilsTest.cpp @@ -20,17 +20,17 @@ TEST(IndirectionUtilsTest, MakeStub) { LLVMContext Context; ModuleBuilder MB(Context, "x86_64-apple-macosx10.10", ""); Function *F = MB.createFunctionDecl<void(DummyStruct, DummyStruct)>(""); - SmallVector<AttributeSet, 4> Attrs; + SmallVector<AttributeList, 4> Attrs; Attrs.push_back( - AttributeSet::get(MB.getModule()->getContext(), 1U, - AttrBuilder().addAttribute(Attribute::StructRet))); + AttributeList::get(MB.getModule()->getContext(), 1U, + AttrBuilder().addAttribute(Attribute::StructRet))); Attrs.push_back( - AttributeSet::get(MB.getModule()->getContext(), 2U, - AttrBuilder().addAttribute(Attribute::ByVal))); + AttributeList::get(MB.getModule()->getContext(), 2U, + AttrBuilder().addAttribute(Attribute::ByVal))); Attrs.push_back( - AttributeSet::get(MB.getModule()->getContext(), ~0U, - AttrBuilder().addAttribute(Attribute::NoUnwind))); - F->setAttributes(AttributeSet::get(MB.getModule()->getContext(), Attrs)); + AttributeList::get(MB.getModule()->getContext(), ~0U, + AttrBuilder().addAttribute(Attribute::NoUnwind))); + F->setAttributes(AttributeList::get(MB.getModule()->getContext(), Attrs)); auto ImplPtr = orc::createImplPointer(*F->getType(), *MB.getModule(), "", nullptr); orc::makeStub(*F, *ImplPtr); @@ -42,7 +42,7 @@ TEST(IndirectionUtilsTest, MakeStub) { EXPECT_TRUE(Call->isTailCall()) << "Indirect call from stub should be tail call."; EXPECT_TRUE(Call->hasStructRetAttr()) << "makeStub should propagate sret attr on 1st argument."; - EXPECT_TRUE(Call->paramHasAttr(2U, Attribute::ByVal)) + EXPECT_TRUE(Call->paramHasAttr(1U, Attribute::ByVal)) << "makeStub should propagate byval attr on 2nd argument."; } diff --git a/unittests/ExecutionEngine/Orc/ObjectTransformLayerTest.cpp b/unittests/ExecutionEngine/Orc/ObjectTransformLayerTest.cpp index 63b85dc82ca8..96214a368dce 100644 --- a/unittests/ExecutionEngine/Orc/ObjectTransformLayerTest.cpp +++ b/unittests/ExecutionEngine/Orc/ObjectTransformLayerTest.cpp @@ -12,7 +12,7 @@ #include "llvm/ExecutionEngine/Orc/CompileUtils.h" #include "llvm/ExecutionEngine/Orc/IRCompileLayer.h" #include "llvm/ExecutionEngine/Orc/NullResolver.h" -#include "llvm/ExecutionEngine/Orc/ObjectLinkingLayer.h" +#include "llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h" #include "llvm/ExecutionEngine/Orc/ObjectTransformLayer.h" #include "llvm/Object/ObjectFile.h" #include "gtest/gtest.h" @@ -309,7 +309,7 @@ TEST(ObjectTransformLayerTest, Main) { }; // Construct the jit layers. - ObjectLinkingLayer<> BaseLayer; + RTDyldObjectLinkingLayer<> BaseLayer; auto IdentityTransform = []( std::unique_ptr<llvm::object::OwningBinary<llvm::object::ObjectFile>> Obj) { return Obj; }; diff --git a/unittests/ExecutionEngine/Orc/OrcTestCommon.cpp b/unittests/ExecutionEngine/Orc/OrcTestCommon.cpp index 17d1e9c9276e..ccd2fc0fb189 100644 --- a/unittests/ExecutionEngine/Orc/OrcTestCommon.cpp +++ b/unittests/ExecutionEngine/Orc/OrcTestCommon.cpp @@ -15,7 +15,7 @@ using namespace llvm; -bool OrcExecutionTest::NativeTargetInitialized = false; +bool OrcNativeTarget::NativeTargetInitialized = false; ModuleBuilder::ModuleBuilder(LLVMContext &Context, StringRef Triple, StringRef Name) diff --git a/unittests/ExecutionEngine/Orc/OrcTestCommon.h b/unittests/ExecutionEngine/Orc/OrcTestCommon.h index f3972a3084e5..7fb26634c7a7 100644 --- a/unittests/ExecutionEngine/Orc/OrcTestCommon.h +++ b/unittests/ExecutionEngine/Orc/OrcTestCommon.h @@ -28,17 +28,29 @@ namespace llvm { -// Base class for Orc tests that will execute code. -class OrcExecutionTest { +class OrcNativeTarget { public: - - OrcExecutionTest() { + static void initialize() { if (!NativeTargetInitialized) { InitializeNativeTarget(); InitializeNativeTargetAsmParser(); InitializeNativeTargetAsmPrinter(); NativeTargetInitialized = true; } + } + +private: + static bool NativeTargetInitialized; +}; + +// Base class for Orc tests that will execute code. +class OrcExecutionTest { +public: + + OrcExecutionTest() { + + // Initialize the native target if it hasn't been done already. + OrcNativeTarget::initialize(); // Try to select a TargetMachine for the host. TM.reset(EngineBuilder().selectTarget()); @@ -56,8 +68,6 @@ public: protected: LLVMContext Context; std::unique_ptr<TargetMachine> TM; -private: - static bool NativeTargetInitialized; }; class ModuleBuilder { diff --git a/unittests/ExecutionEngine/Orc/QueueChannel.cpp b/unittests/ExecutionEngine/Orc/QueueChannel.cpp new file mode 100644 index 000000000000..e309a7e428c0 --- /dev/null +++ b/unittests/ExecutionEngine/Orc/QueueChannel.cpp @@ -0,0 +1,14 @@ +//===-------- QueueChannel.cpp - Unit tests the remote executors ----------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// + +#include "QueueChannel.h" + +char llvm::QueueChannelError::ID; +char llvm::QueueChannelClosedError::ID; + diff --git a/unittests/ExecutionEngine/Orc/QueueChannel.h b/unittests/ExecutionEngine/Orc/QueueChannel.h new file mode 100644 index 000000000000..3d1058a83ebc --- /dev/null +++ b/unittests/ExecutionEngine/Orc/QueueChannel.h @@ -0,0 +1,146 @@ +//===----------------------- Queue.h - RPC Queue ------------------*-c++-*-===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_UNITTESTS_EXECUTIONENGINE_ORC_QUEUECHANNEL_H +#define LLVM_UNITTESTS_EXECUTIONENGINE_ORC_QUEUECHANNEL_H + +#include "llvm/ExecutionEngine/Orc/RawByteChannel.h" +#include "llvm/Support/Error.h" + +#include <queue> +#include <condition_variable> + +namespace llvm { + +class QueueChannelError : public ErrorInfo<QueueChannelError> { +public: + static char ID; +}; + +class QueueChannelClosedError + : public ErrorInfo<QueueChannelClosedError, QueueChannelError> { +public: + static char ID; + std::error_code convertToErrorCode() const override { + return inconvertibleErrorCode(); + } + + void log(raw_ostream &OS) const override { + OS << "Queue closed"; + } +}; + +class Queue : public std::queue<char> { +public: + using ErrorInjector = std::function<Error()>; + + Queue() + : ReadError([]() { return Error::success(); }), + WriteError([]() { return Error::success(); }) {} + + Queue(const Queue&) = delete; + Queue& operator=(const Queue&) = delete; + Queue(Queue&&) = delete; + Queue& operator=(Queue&&) = delete; + + std::mutex &getMutex() { return M; } + std::condition_variable &getCondVar() { return CV; } + Error checkReadError() { return ReadError(); } + Error checkWriteError() { return WriteError(); } + void setReadError(ErrorInjector NewReadError) { + { + std::lock_guard<std::mutex> Lock(M); + ReadError = std::move(NewReadError); + } + CV.notify_one(); + } + void setWriteError(ErrorInjector NewWriteError) { + std::lock_guard<std::mutex> Lock(M); + WriteError = std::move(NewWriteError); + } +private: + std::mutex M; + std::condition_variable CV; + std::function<Error()> ReadError, WriteError; +}; + +class QueueChannel : public orc::rpc::RawByteChannel { +public: + QueueChannel(std::shared_ptr<Queue> InQueue, + std::shared_ptr<Queue> OutQueue) + : InQueue(InQueue), OutQueue(OutQueue) {} + + QueueChannel(const QueueChannel&) = delete; + QueueChannel& operator=(const QueueChannel&) = delete; + QueueChannel(QueueChannel&&) = delete; + QueueChannel& operator=(QueueChannel&&) = delete; + + Error readBytes(char *Dst, unsigned Size) override { + std::unique_lock<std::mutex> Lock(InQueue->getMutex()); + while (Size) { + { + Error Err = InQueue->checkReadError(); + while (!Err && InQueue->empty()) { + InQueue->getCondVar().wait(Lock); + Err = InQueue->checkReadError(); + } + if (Err) + return Err; + } + *Dst++ = InQueue->front(); + --Size; + ++NumRead; + InQueue->pop(); + } + return Error::success(); + } + + Error appendBytes(const char *Src, unsigned Size) override { + std::unique_lock<std::mutex> Lock(OutQueue->getMutex()); + while (Size--) { + if (Error Err = OutQueue->checkWriteError()) + return Err; + OutQueue->push(*Src++); + ++NumWritten; + } + OutQueue->getCondVar().notify_one(); + return Error::success(); + } + + Error send() override { return Error::success(); } + + void close() { + auto ChannelClosed = []() { return make_error<QueueChannelClosedError>(); }; + InQueue->setReadError(ChannelClosed); + InQueue->setWriteError(ChannelClosed); + OutQueue->setReadError(ChannelClosed); + OutQueue->setWriteError(ChannelClosed); + } + + uint64_t NumWritten = 0; + uint64_t NumRead = 0; + +private: + + std::shared_ptr<Queue> InQueue; + std::shared_ptr<Queue> OutQueue; +}; + +inline std::pair<std::unique_ptr<QueueChannel>, std::unique_ptr<QueueChannel>> +createPairedQueueChannels() { + auto Q1 = std::make_shared<Queue>(); + auto Q2 = std::make_shared<Queue>(); + auto C1 = llvm::make_unique<QueueChannel>(Q1, Q2); + auto C2 = llvm::make_unique<QueueChannel>(Q2, Q1); + return std::make_pair(std::move(C1), std::move(C2)); +} + +} + +#endif diff --git a/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp b/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp index 186c3d408486..1c9764b555fd 100644 --- a/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp +++ b/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp @@ -7,8 +7,8 @@ // //===----------------------------------------------------------------------===// -#include "llvm/ExecutionEngine/Orc/RawByteChannel.h" #include "llvm/ExecutionEngine/Orc/RPCUtils.h" +#include "QueueChannel.h" #include "gtest/gtest.h" #include <queue> @@ -17,47 +17,6 @@ using namespace llvm; using namespace llvm::orc; using namespace llvm::orc::rpc; -class Queue : public std::queue<char> { -public: - std::mutex &getMutex() { return M; } - std::condition_variable &getCondVar() { return CV; } -private: - std::mutex M; - std::condition_variable CV; -}; - -class QueueChannel : public RawByteChannel { -public: - QueueChannel(Queue &InQueue, Queue &OutQueue) - : InQueue(InQueue), OutQueue(OutQueue) {} - - Error readBytes(char *Dst, unsigned Size) override { - std::unique_lock<std::mutex> Lock(InQueue.getMutex()); - while (Size) { - while (InQueue.empty()) - InQueue.getCondVar().wait(Lock); - *Dst++ = InQueue.front(); - --Size; - InQueue.pop(); - } - return Error::success(); - } - - Error appendBytes(const char *Src, unsigned Size) override { - std::unique_lock<std::mutex> Lock(OutQueue.getMutex()); - while (Size--) - OutQueue.push(*Src++); - OutQueue.getCondVar().notify_one(); - return Error::success(); - } - - Error send() override { return Error::success(); } - -private: - Queue &InQueue; - Queue &OutQueue; -}; - class RPCFoo {}; namespace llvm { @@ -88,6 +47,54 @@ namespace rpc { class RPCBar {}; +class DummyError : public ErrorInfo<DummyError> { +public: + + static char ID; + + DummyError(uint32_t Val) : Val(Val) {} + + std::error_code convertToErrorCode() const override { + // Use a nonsense error code - we want to verify that errors + // transmitted over the network are replaced with + // OrcErrorCode::UnknownErrorCodeFromRemote. + return orcError(OrcErrorCode::RemoteAllocatorDoesNotExist); + } + + void log(raw_ostream &OS) const override { + OS << "Dummy error " << Val; + } + + uint32_t getValue() const { return Val; } + +public: + uint32_t Val; +}; + +char DummyError::ID = 0; + +template <typename ChannelT> +void registerDummyErrorSerialization() { + static bool AlreadyRegistered = false; + if (!AlreadyRegistered) { + SerializationTraits<ChannelT, Error>:: + template registerErrorType<DummyError>( + "DummyError", + [](ChannelT &C, const DummyError &DE) { + return serializeSeq(C, DE.getValue()); + }, + [](ChannelT &C, Error &Err) -> Error { + ErrorAsOutParameter EAO(&Err); + uint32_t Val; + if (auto Err = deserializeSeq(C, Val)) + return Err; + Err = make_error<DummyError>(Val); + return Error::success(); + }); + AlreadyRegistered = true; + } +} + namespace llvm { namespace orc { namespace rpc { @@ -120,6 +127,11 @@ namespace DummyRPCAPI { static const char* getName() { return "IntInt"; } }; + class VoidString : public Function<VoidString, void(std::string)> { + public: + static const char* getName() { return "VoidString"; } + }; + class AllTheTypes : public Function<AllTheTypes, void(int8_t, uint8_t, int16_t, uint16_t, int32_t, @@ -134,21 +146,38 @@ namespace DummyRPCAPI { static const char* getName() { return "CustomType"; } }; + class ErrorFunc : public Function<ErrorFunc, Error()> { + public: + static const char* getName() { return "ErrorFunc"; } + }; + + class ExpectedFunc : public Function<ExpectedFunc, Expected<uint32_t>()> { + public: + static const char* getName() { return "ExpectedFunc"; } + }; + } class DummyRPCEndpoint : public SingleThreadedRPCEndpoint<QueueChannel> { public: - DummyRPCEndpoint(Queue &Q1, Queue &Q2) - : SingleThreadedRPCEndpoint(C, true), C(Q1, Q2) {} -private: - QueueChannel C; + DummyRPCEndpoint(QueueChannel &C) + : SingleThreadedRPCEndpoint(C, true) {} }; -TEST(DummyRPC, TestAsyncVoidBool) { - Queue Q1, Q2; - DummyRPCEndpoint Client(Q1, Q2); - DummyRPCEndpoint Server(Q2, Q1); +void freeVoidBool(bool B) { +} + +TEST(DummyRPC, TestFreeFunctionHandler) { + auto Channels = createPairedQueueChannels(); + DummyRPCEndpoint Server(*Channels.first); + Server.addHandler<DummyRPCAPI::VoidBool>(freeVoidBool); +} + +TEST(DummyRPC, TestCallAsyncVoidBool) { + auto Channels = createPairedQueueChannels(); + DummyRPCEndpoint Client(*Channels.first); + DummyRPCEndpoint Server(*Channels.second); std::thread ServerThread([&]() { Server.addHandler<DummyRPCAPI::VoidBool>( @@ -189,10 +218,10 @@ TEST(DummyRPC, TestAsyncVoidBool) { ServerThread.join(); } -TEST(DummyRPC, TestAsyncIntInt) { - Queue Q1, Q2; - DummyRPCEndpoint Client(Q1, Q2); - DummyRPCEndpoint Server(Q2, Q1); +TEST(DummyRPC, TestCallAsyncIntInt) { + auto Channels = createPairedQueueChannels(); + DummyRPCEndpoint Client(*Channels.first); + DummyRPCEndpoint Server(*Channels.second); std::thread ServerThread([&]() { Server.addHandler<DummyRPCAPI::IntInt>( @@ -234,10 +263,147 @@ TEST(DummyRPC, TestAsyncIntInt) { ServerThread.join(); } +TEST(DummyRPC, TestAsyncIntIntHandler) { + auto Channels = createPairedQueueChannels(); + DummyRPCEndpoint Client(*Channels.first); + DummyRPCEndpoint Server(*Channels.second); + + std::thread ServerThread([&]() { + Server.addAsyncHandler<DummyRPCAPI::IntInt>( + [](std::function<Error(Expected<int32_t>)> SendResult, + int32_t X) { + EXPECT_EQ(X, 21) << "Server int(int) receieved unexpected result"; + return SendResult(2 * X); + }); + + { + // Poke the server to handle the negotiate call. + auto Err = Server.handleOne(); + EXPECT_FALSE(!!Err) << "Server failed to handle call to negotiate"; + } + + { + // Poke the server to handle the VoidBool call. + auto Err = Server.handleOne(); + EXPECT_FALSE(!!Err) << "Server failed to handle call to void(bool)"; + } + }); + + { + auto Err = Client.callAsync<DummyRPCAPI::IntInt>( + [](Expected<int> Result) { + EXPECT_TRUE(!!Result) << "Async int(int) response handler failed"; + EXPECT_EQ(*Result, 42) + << "Async int(int) response handler received incorrect result"; + return Error::success(); + }, 21); + EXPECT_FALSE(!!Err) << "Client.callAsync failed for int(int)"; + } + + { + // Poke the client to process the result. + auto Err = Client.handleOne(); + EXPECT_FALSE(!!Err) << "Client failed to handle response from void(bool)"; + } + + ServerThread.join(); +} + +TEST(DummyRPC, TestAsyncIntIntHandlerMethod) { + auto Channels = createPairedQueueChannels(); + DummyRPCEndpoint Client(*Channels.first); + DummyRPCEndpoint Server(*Channels.second); + + class Dummy { + public: + Error handler(std::function<Error(Expected<int32_t>)> SendResult, + int32_t X) { + EXPECT_EQ(X, 21) << "Server int(int) receieved unexpected result"; + return SendResult(2 * X); + } + }; + + std::thread ServerThread([&]() { + Dummy D; + Server.addAsyncHandler<DummyRPCAPI::IntInt>(D, &Dummy::handler); + + { + // Poke the server to handle the negotiate call. + auto Err = Server.handleOne(); + EXPECT_FALSE(!!Err) << "Server failed to handle call to negotiate"; + } + + { + // Poke the server to handle the VoidBool call. + auto Err = Server.handleOne(); + EXPECT_FALSE(!!Err) << "Server failed to handle call to void(bool)"; + } + }); + + { + auto Err = Client.callAsync<DummyRPCAPI::IntInt>( + [](Expected<int> Result) { + EXPECT_TRUE(!!Result) << "Async int(int) response handler failed"; + EXPECT_EQ(*Result, 42) + << "Async int(int) response handler received incorrect result"; + return Error::success(); + }, 21); + EXPECT_FALSE(!!Err) << "Client.callAsync failed for int(int)"; + } + + { + // Poke the client to process the result. + auto Err = Client.handleOne(); + EXPECT_FALSE(!!Err) << "Client failed to handle response from void(bool)"; + } + + ServerThread.join(); +} + +TEST(DummyRPC, TestCallAsyncVoidString) { + auto Channels = createPairedQueueChannels(); + DummyRPCEndpoint Client(*Channels.first); + DummyRPCEndpoint Server(*Channels.second); + + std::thread ServerThread([&]() { + Server.addHandler<DummyRPCAPI::VoidString>( + [](const std::string &S) { + EXPECT_EQ(S, "hello") + << "Server void(std::string) received unexpected result"; + }); + + // Poke the server to handle the negotiate call. + for (int I = 0; I < 4; ++I) { + auto Err = Server.handleOne(); + EXPECT_FALSE(!!Err) << "Server failed to handle call"; + } + }); + + { + // Make an call using a std::string. + auto Err = Client.callB<DummyRPCAPI::VoidString>(std::string("hello")); + EXPECT_FALSE(!!Err) << "Client.callAsync failed for void(std::string)"; + } + + { + // Make an call using a std::string. + auto Err = Client.callB<DummyRPCAPI::VoidString>(StringRef("hello")); + EXPECT_FALSE(!!Err) << "Client.callAsync failed for void(std::string)"; + } + + { + // Make an call using a std::string. + auto Err = Client.callB<DummyRPCAPI::VoidString>("hello"); + EXPECT_FALSE(!!Err) << "Client.callAsync failed for void(string)"; + } + + ServerThread.join(); +} + TEST(DummyRPC, TestSerialization) { - Queue Q1, Q2; - DummyRPCEndpoint Client(Q1, Q2); - DummyRPCEndpoint Server(Q2, Q1); + auto Channels = createPairedQueueChannels(); + DummyRPCEndpoint Client(*Channels.first); + DummyRPCEndpoint Server(*Channels.second); std::thread ServerThread([&]() { Server.addHandler<DummyRPCAPI::AllTheTypes>( @@ -300,9 +466,9 @@ TEST(DummyRPC, TestSerialization) { } TEST(DummyRPC, TestCustomType) { - Queue Q1, Q2; - DummyRPCEndpoint Client(Q1, Q2); - DummyRPCEndpoint Server(Q2, Q1); + auto Channels = createPairedQueueChannels(); + DummyRPCEndpoint Client(*Channels.first); + DummyRPCEndpoint Server(*Channels.second); std::thread ServerThread([&]() { Server.addHandler<DummyRPCAPI::CustomType>( @@ -343,9 +509,9 @@ TEST(DummyRPC, TestCustomType) { } TEST(DummyRPC, TestWithAltCustomType) { - Queue Q1, Q2; - DummyRPCEndpoint Client(Q1, Q2); - DummyRPCEndpoint Server(Q2, Q1); + auto Channels = createPairedQueueChannels(); + DummyRPCEndpoint Client(*Channels.first); + DummyRPCEndpoint Server(*Channels.second); std::thread ServerThread([&]() { Server.addHandler<DummyRPCAPI::CustomType>( @@ -385,10 +551,144 @@ TEST(DummyRPC, TestWithAltCustomType) { ServerThread.join(); } +TEST(DummyRPC, ReturnErrorSuccess) { + registerDummyErrorSerialization<QueueChannel>(); + + auto Channels = createPairedQueueChannels(); + DummyRPCEndpoint Client(*Channels.first); + DummyRPCEndpoint Server(*Channels.second); + + std::thread ServerThread([&]() { + Server.addHandler<DummyRPCAPI::ErrorFunc>( + []() { + return Error::success(); + }); + + // Handle the negotiate plus one call. + for (unsigned I = 0; I != 2; ++I) + cantFail(Server.handleOne()); + }); + + cantFail(Client.callAsync<DummyRPCAPI::ErrorFunc>( + [&](Error Err) { + EXPECT_FALSE(!!Err) << "Expected success value"; + return Error::success(); + })); + + cantFail(Client.handleOne()); + + ServerThread.join(); +} + +TEST(DummyRPC, ReturnErrorFailure) { + registerDummyErrorSerialization<QueueChannel>(); + + auto Channels = createPairedQueueChannels(); + DummyRPCEndpoint Client(*Channels.first); + DummyRPCEndpoint Server(*Channels.second); + + std::thread ServerThread([&]() { + Server.addHandler<DummyRPCAPI::ErrorFunc>( + []() { + return make_error<DummyError>(42); + }); + + // Handle the negotiate plus one call. + for (unsigned I = 0; I != 2; ++I) + cantFail(Server.handleOne()); + }); + + cantFail(Client.callAsync<DummyRPCAPI::ErrorFunc>( + [&](Error Err) { + EXPECT_TRUE(Err.isA<DummyError>()) + << "Incorrect error type"; + return handleErrors( + std::move(Err), + [](const DummyError &DE) { + EXPECT_EQ(DE.getValue(), 42ULL) + << "Incorrect DummyError serialization"; + }); + })); + + cantFail(Client.handleOne()); + + ServerThread.join(); +} + +TEST(DummyRPC, ReturnExpectedSuccess) { + registerDummyErrorSerialization<QueueChannel>(); + + auto Channels = createPairedQueueChannels(); + DummyRPCEndpoint Client(*Channels.first); + DummyRPCEndpoint Server(*Channels.second); + + std::thread ServerThread([&]() { + Server.addHandler<DummyRPCAPI::ExpectedFunc>( + []() -> uint32_t { + return 42; + }); + + // Handle the negotiate plus one call. + for (unsigned I = 0; I != 2; ++I) + cantFail(Server.handleOne()); + }); + + cantFail(Client.callAsync<DummyRPCAPI::ExpectedFunc>( + [&](Expected<uint32_t> ValOrErr) { + EXPECT_TRUE(!!ValOrErr) + << "Expected success value"; + EXPECT_EQ(*ValOrErr, 42ULL) + << "Incorrect Expected<uint32_t> deserialization"; + return Error::success(); + })); + + cantFail(Client.handleOne()); + + ServerThread.join(); +} + +TEST(DummyRPC, ReturnExpectedFailure) { + registerDummyErrorSerialization<QueueChannel>(); + + auto Channels = createPairedQueueChannels(); + DummyRPCEndpoint Client(*Channels.first); + DummyRPCEndpoint Server(*Channels.second); + + std::thread ServerThread([&]() { + Server.addHandler<DummyRPCAPI::ExpectedFunc>( + []() -> Expected<uint32_t> { + return make_error<DummyError>(7); + }); + + // Handle the negotiate plus one call. + for (unsigned I = 0; I != 2; ++I) + cantFail(Server.handleOne()); + }); + + cantFail(Client.callAsync<DummyRPCAPI::ExpectedFunc>( + [&](Expected<uint32_t> ValOrErr) { + EXPECT_FALSE(!!ValOrErr) + << "Expected failure value"; + auto Err = ValOrErr.takeError(); + EXPECT_TRUE(Err.isA<DummyError>()) + << "Incorrect error type"; + return handleErrors( + std::move(Err), + [](const DummyError &DE) { + EXPECT_EQ(DE.getValue(), 7ULL) + << "Incorrect DummyError serialization"; + }); + })); + + cantFail(Client.handleOne()); + + ServerThread.join(); +} + TEST(DummyRPC, TestParallelCallGroup) { - Queue Q1, Q2; - DummyRPCEndpoint Client(Q1, Q2); - DummyRPCEndpoint Server(Q2, Q1); + auto Channels = createPairedQueueChannels(); + DummyRPCEndpoint Client(*Channels.first); + DummyRPCEndpoint Server(*Channels.second); std::thread ServerThread([&]() { Server.addHandler<DummyRPCAPI::IntInt>( @@ -405,10 +705,11 @@ TEST(DummyRPC, TestParallelCallGroup) { { int A, B, C; - ParallelCallGroup<DummyRPCEndpoint> PCG(Client); + ParallelCallGroup PCG; { - auto Err = PCG.appendCall<DummyRPCAPI::IntInt>( + auto Err = PCG.call( + rpcAsyncDispatch<DummyRPCAPI::IntInt>(Client), [&A](Expected<int> Result) { EXPECT_TRUE(!!Result) << "Async int(int) response handler failed"; A = *Result; @@ -418,7 +719,8 @@ TEST(DummyRPC, TestParallelCallGroup) { } { - auto Err = PCG.appendCall<DummyRPCAPI::IntInt>( + auto Err = PCG.call( + rpcAsyncDispatch<DummyRPCAPI::IntInt>(Client), [&B](Expected<int> Result) { EXPECT_TRUE(!!Result) << "Async int(int) response handler failed"; B = *Result; @@ -428,7 +730,8 @@ TEST(DummyRPC, TestParallelCallGroup) { } { - auto Err = PCG.appendCall<DummyRPCAPI::IntInt>( + auto Err = PCG.call( + rpcAsyncDispatch<DummyRPCAPI::IntInt>(Client), [&C](Expected<int> Result) { EXPECT_TRUE(!!Result) << "Async int(int) response handler failed"; C = *Result; @@ -443,10 +746,7 @@ TEST(DummyRPC, TestParallelCallGroup) { EXPECT_FALSE(!!Err) << "Client failed to handle response from void(bool)"; } - { - auto Err = PCG.wait(); - EXPECT_FALSE(!!Err) << "Third parallel call failed for int(int)"; - } + PCG.wait(); EXPECT_EQ(A, 2) << "First parallel call returned bogus result"; EXPECT_EQ(B, 4) << "Second parallel call returned bogus result"; @@ -468,9 +768,9 @@ TEST(DummyRPC, TestAPICalls) { static_assert(!DummyCalls1::Contains<DummyRPCAPI::CustomType>::value, "Contains<Func> template should return false here"); - Queue Q1, Q2; - DummyRPCEndpoint Client(Q1, Q2); - DummyRPCEndpoint Server(Q2, Q1); + auto Channels = createPairedQueueChannels(); + DummyRPCEndpoint Client(*Channels.first); + DummyRPCEndpoint Server(*Channels.second); std::thread ServerThread( [&]() { @@ -496,11 +796,37 @@ TEST(DummyRPC, TestAPICalls) { { auto Err = DummyCallsAll::negotiate(Client); - EXPECT_EQ(errorToErrorCode(std::move(Err)).value(), - static_cast<int>(OrcErrorCode::UnknownRPCFunction)) - << "Expected 'UnknownRPCFunction' error for attempted negotiate of " + EXPECT_TRUE(Err.isA<CouldNotNegotiate>()) + << "Expected CouldNotNegotiate error for attempted negotiate of " "unsupported function"; + consumeError(std::move(Err)); } ServerThread.join(); } + +TEST(DummyRPC, TestRemoveHandler) { + auto Channels = createPairedQueueChannels(); + DummyRPCEndpoint Server(*Channels.second); + + Server.addHandler<DummyRPCAPI::VoidBool>( + [](bool B) { + EXPECT_EQ(B, true) + << "Server void(bool) received unexpected result"; + }); + + Server.removeHandler<DummyRPCAPI::VoidBool>(); +} + +TEST(DummyRPC, TestClearHandlers) { + auto Channels = createPairedQueueChannels(); + DummyRPCEndpoint Server(*Channels.second); + + Server.addHandler<DummyRPCAPI::VoidBool>( + [](bool B) { + EXPECT_EQ(B, true) + << "Server void(bool) received unexpected result"; + }); + + Server.clearHandlers(); +} diff --git a/unittests/ExecutionEngine/Orc/ObjectLinkingLayerTest.cpp b/unittests/ExecutionEngine/Orc/RTDyldObjectLinkingLayerTest.cpp index 44b44f604159..de99c022fb9d 100644 --- a/unittests/ExecutionEngine/Orc/ObjectLinkingLayerTest.cpp +++ b/unittests/ExecutionEngine/Orc/RTDyldObjectLinkingLayerTest.cpp @@ -1,4 +1,4 @@ -//===-- ObjectLinkingLayerTest.cpp - Unit tests for object linking layer --===// +//===- RTDyldObjectLinkingLayerTest.cpp - RTDyld linking layer unit tests -===// // // The LLVM Compiler Infrastructure // @@ -13,7 +13,7 @@ #include "llvm/ExecutionEngine/Orc/CompileUtils.h" #include "llvm/ExecutionEngine/Orc/LambdaResolver.h" #include "llvm/ExecutionEngine/Orc/NullResolver.h" -#include "llvm/ExecutionEngine/Orc/ObjectLinkingLayer.h" +#include "llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h" #include "llvm/IR/Constants.h" #include "llvm/IR/LLVMContext.h" #include "gtest/gtest.h" @@ -23,8 +23,8 @@ using namespace llvm::orc; namespace { -class ObjectLinkingLayerExecutionTest : public testing::Test, - public OrcExecutionTest { +class RTDyldObjectLinkingLayerExecutionTest : public testing::Test, + public OrcExecutionTest { }; @@ -44,7 +44,7 @@ public: } }; -TEST(ObjectLinkingLayerTest, TestSetProcessAllSections) { +TEST(RTDyldObjectLinkingLayerTest, TestSetProcessAllSections) { class SectionMemoryManagerWrapper : public SectionMemoryManager { public: SectionMemoryManagerWrapper(bool &DebugSeen) : DebugSeen(DebugSeen) {} @@ -60,10 +60,10 @@ TEST(ObjectLinkingLayerTest, TestSetProcessAllSections) { IsReadOnly); } private: - bool DebugSeen; + bool &DebugSeen; }; - ObjectLinkingLayer<> ObjLayer; + RTDyldObjectLinkingLayer<> ObjLayer; LLVMContext Context; auto M = llvm::make_unique<Module>("", Context); @@ -75,6 +75,10 @@ TEST(ObjectLinkingLayerTest, TestSetProcessAllSections) { GV->setSection(".debug_str"); + + // Initialize the native target in case this is the first unit test + // to try to build a TM. + OrcNativeTarget::initialize(); std::unique_ptr<TargetMachine> TM( EngineBuilder().selectTarget(Triple(M->getTargetTriple()), "", "", SmallVector<std::string, 1>())); @@ -99,6 +103,7 @@ TEST(ObjectLinkingLayerTest, TestSetProcessAllSections) { { // Test with ProcessAllSections = false (the default). auto H = ObjLayer.addObjectSet(Objs, &SMMW, &*Resolver); + ObjLayer.emitAndFinalize(H); EXPECT_EQ(DebugSectionSeen, false) << "Unexpected debug info section"; ObjLayer.removeObjectSet(H); @@ -108,17 +113,18 @@ TEST(ObjectLinkingLayerTest, TestSetProcessAllSections) { // Test with ProcessAllSections = true. ObjLayer.setProcessAllSections(true); auto H = ObjLayer.addObjectSet(Objs, &SMMW, &*Resolver); + ObjLayer.emitAndFinalize(H); EXPECT_EQ(DebugSectionSeen, true) << "Expected debug info section not seen"; ObjLayer.removeObjectSet(H); } } -TEST_F(ObjectLinkingLayerExecutionTest, NoDuplicateFinalization) { +TEST_F(RTDyldObjectLinkingLayerExecutionTest, NoDuplicateFinalization) { if (!TM) return; - ObjectLinkingLayer<> ObjLayer; + RTDyldObjectLinkingLayer<> ObjLayer; SimpleCompiler Compile(*TM); // Create a pair of modules that will trigger recursive finalization: @@ -183,11 +189,11 @@ TEST_F(ObjectLinkingLayerExecutionTest, NoDuplicateFinalization) { << "Extra call to finalize"; } -TEST_F(ObjectLinkingLayerExecutionTest, NoPrematureAllocation) { +TEST_F(RTDyldObjectLinkingLayerExecutionTest, NoPrematureAllocation) { if (!TM) return; - ObjectLinkingLayer<> ObjLayer; + RTDyldObjectLinkingLayer<> ObjLayer; SimpleCompiler Compile(*TM); // Create a pair of unrelated modules: |