aboutsummaryrefslogtreecommitdiff
path: root/unittests/ExecutionEngine/Orc
diff options
context:
space:
mode:
Diffstat (limited to 'unittests/ExecutionEngine/Orc')
-rw-r--r--unittests/ExecutionEngine/Orc/CMakeLists.txt5
-rw-r--r--unittests/ExecutionEngine/Orc/IndirectionUtilsTest.cpp18
-rw-r--r--unittests/ExecutionEngine/Orc/ObjectTransformLayerTest.cpp4
-rw-r--r--unittests/ExecutionEngine/Orc/OrcTestCommon.cpp2
-rw-r--r--unittests/ExecutionEngine/Orc/OrcTestCommon.h22
-rw-r--r--unittests/ExecutionEngine/Orc/QueueChannel.cpp14
-rw-r--r--unittests/ExecutionEngine/Orc/QueueChannel.h146
-rw-r--r--unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp486
-rw-r--r--unittests/ExecutionEngine/Orc/RTDyldObjectLinkingLayerTest.cpp (renamed from unittests/ExecutionEngine/Orc/ObjectLinkingLayerTest.cpp)28
9 files changed, 614 insertions, 111 deletions
diff --git a/unittests/ExecutionEngine/Orc/CMakeLists.txt b/unittests/ExecutionEngine/Orc/CMakeLists.txt
index 68f6d0c28d7c..db40c4213bd7 100644
--- a/unittests/ExecutionEngine/Orc/CMakeLists.txt
+++ b/unittests/ExecutionEngine/Orc/CMakeLists.txt
@@ -14,11 +14,12 @@ add_llvm_unittest(OrcJITTests
IndirectionUtilsTest.cpp
GlobalMappingLayerTest.cpp
LazyEmittingLayerTest.cpp
- ObjectLinkingLayerTest.cpp
ObjectTransformLayerTest.cpp
OrcCAPITest.cpp
OrcTestCommon.cpp
+ QueueChannel.cpp
RPCUtilsTest.cpp
+ RTDyldObjectLinkingLayerTest.cpp
)
-target_link_libraries(OrcJITTests ${PTHREAD_LIB})
+target_link_libraries(OrcJITTests ${LLVM_PTHREAD_LIB})
diff --git a/unittests/ExecutionEngine/Orc/IndirectionUtilsTest.cpp b/unittests/ExecutionEngine/Orc/IndirectionUtilsTest.cpp
index ac847039d9fb..4af3aa707a90 100644
--- a/unittests/ExecutionEngine/Orc/IndirectionUtilsTest.cpp
+++ b/unittests/ExecutionEngine/Orc/IndirectionUtilsTest.cpp
@@ -20,17 +20,17 @@ TEST(IndirectionUtilsTest, MakeStub) {
LLVMContext Context;
ModuleBuilder MB(Context, "x86_64-apple-macosx10.10", "");
Function *F = MB.createFunctionDecl<void(DummyStruct, DummyStruct)>("");
- SmallVector<AttributeSet, 4> Attrs;
+ SmallVector<AttributeList, 4> Attrs;
Attrs.push_back(
- AttributeSet::get(MB.getModule()->getContext(), 1U,
- AttrBuilder().addAttribute(Attribute::StructRet)));
+ AttributeList::get(MB.getModule()->getContext(), 1U,
+ AttrBuilder().addAttribute(Attribute::StructRet)));
Attrs.push_back(
- AttributeSet::get(MB.getModule()->getContext(), 2U,
- AttrBuilder().addAttribute(Attribute::ByVal)));
+ AttributeList::get(MB.getModule()->getContext(), 2U,
+ AttrBuilder().addAttribute(Attribute::ByVal)));
Attrs.push_back(
- AttributeSet::get(MB.getModule()->getContext(), ~0U,
- AttrBuilder().addAttribute(Attribute::NoUnwind)));
- F->setAttributes(AttributeSet::get(MB.getModule()->getContext(), Attrs));
+ AttributeList::get(MB.getModule()->getContext(), ~0U,
+ AttrBuilder().addAttribute(Attribute::NoUnwind)));
+ F->setAttributes(AttributeList::get(MB.getModule()->getContext(), Attrs));
auto ImplPtr = orc::createImplPointer(*F->getType(), *MB.getModule(), "", nullptr);
orc::makeStub(*F, *ImplPtr);
@@ -42,7 +42,7 @@ TEST(IndirectionUtilsTest, MakeStub) {
EXPECT_TRUE(Call->isTailCall()) << "Indirect call from stub should be tail call.";
EXPECT_TRUE(Call->hasStructRetAttr())
<< "makeStub should propagate sret attr on 1st argument.";
- EXPECT_TRUE(Call->paramHasAttr(2U, Attribute::ByVal))
+ EXPECT_TRUE(Call->paramHasAttr(1U, Attribute::ByVal))
<< "makeStub should propagate byval attr on 2nd argument.";
}
diff --git a/unittests/ExecutionEngine/Orc/ObjectTransformLayerTest.cpp b/unittests/ExecutionEngine/Orc/ObjectTransformLayerTest.cpp
index 63b85dc82ca8..96214a368dce 100644
--- a/unittests/ExecutionEngine/Orc/ObjectTransformLayerTest.cpp
+++ b/unittests/ExecutionEngine/Orc/ObjectTransformLayerTest.cpp
@@ -12,7 +12,7 @@
#include "llvm/ExecutionEngine/Orc/CompileUtils.h"
#include "llvm/ExecutionEngine/Orc/IRCompileLayer.h"
#include "llvm/ExecutionEngine/Orc/NullResolver.h"
-#include "llvm/ExecutionEngine/Orc/ObjectLinkingLayer.h"
+#include "llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h"
#include "llvm/ExecutionEngine/Orc/ObjectTransformLayer.h"
#include "llvm/Object/ObjectFile.h"
#include "gtest/gtest.h"
@@ -309,7 +309,7 @@ TEST(ObjectTransformLayerTest, Main) {
};
// Construct the jit layers.
- ObjectLinkingLayer<> BaseLayer;
+ RTDyldObjectLinkingLayer<> BaseLayer;
auto IdentityTransform = [](
std::unique_ptr<llvm::object::OwningBinary<llvm::object::ObjectFile>>
Obj) { return Obj; };
diff --git a/unittests/ExecutionEngine/Orc/OrcTestCommon.cpp b/unittests/ExecutionEngine/Orc/OrcTestCommon.cpp
index 17d1e9c9276e..ccd2fc0fb189 100644
--- a/unittests/ExecutionEngine/Orc/OrcTestCommon.cpp
+++ b/unittests/ExecutionEngine/Orc/OrcTestCommon.cpp
@@ -15,7 +15,7 @@
using namespace llvm;
-bool OrcExecutionTest::NativeTargetInitialized = false;
+bool OrcNativeTarget::NativeTargetInitialized = false;
ModuleBuilder::ModuleBuilder(LLVMContext &Context, StringRef Triple,
StringRef Name)
diff --git a/unittests/ExecutionEngine/Orc/OrcTestCommon.h b/unittests/ExecutionEngine/Orc/OrcTestCommon.h
index f3972a3084e5..7fb26634c7a7 100644
--- a/unittests/ExecutionEngine/Orc/OrcTestCommon.h
+++ b/unittests/ExecutionEngine/Orc/OrcTestCommon.h
@@ -28,17 +28,29 @@
namespace llvm {
-// Base class for Orc tests that will execute code.
-class OrcExecutionTest {
+class OrcNativeTarget {
public:
-
- OrcExecutionTest() {
+ static void initialize() {
if (!NativeTargetInitialized) {
InitializeNativeTarget();
InitializeNativeTargetAsmParser();
InitializeNativeTargetAsmPrinter();
NativeTargetInitialized = true;
}
+ }
+
+private:
+ static bool NativeTargetInitialized;
+};
+
+// Base class for Orc tests that will execute code.
+class OrcExecutionTest {
+public:
+
+ OrcExecutionTest() {
+
+ // Initialize the native target if it hasn't been done already.
+ OrcNativeTarget::initialize();
// Try to select a TargetMachine for the host.
TM.reset(EngineBuilder().selectTarget());
@@ -56,8 +68,6 @@ public:
protected:
LLVMContext Context;
std::unique_ptr<TargetMachine> TM;
-private:
- static bool NativeTargetInitialized;
};
class ModuleBuilder {
diff --git a/unittests/ExecutionEngine/Orc/QueueChannel.cpp b/unittests/ExecutionEngine/Orc/QueueChannel.cpp
new file mode 100644
index 000000000000..e309a7e428c0
--- /dev/null
+++ b/unittests/ExecutionEngine/Orc/QueueChannel.cpp
@@ -0,0 +1,14 @@
+//===-------- QueueChannel.cpp - Unit tests the remote executors ----------===//
+//
+// The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+
+#include "QueueChannel.h"
+
+char llvm::QueueChannelError::ID;
+char llvm::QueueChannelClosedError::ID;
+
diff --git a/unittests/ExecutionEngine/Orc/QueueChannel.h b/unittests/ExecutionEngine/Orc/QueueChannel.h
new file mode 100644
index 000000000000..3d1058a83ebc
--- /dev/null
+++ b/unittests/ExecutionEngine/Orc/QueueChannel.h
@@ -0,0 +1,146 @@
+//===----------------------- Queue.h - RPC Queue ------------------*-c++-*-===//
+//
+// The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_UNITTESTS_EXECUTIONENGINE_ORC_QUEUECHANNEL_H
+#define LLVM_UNITTESTS_EXECUTIONENGINE_ORC_QUEUECHANNEL_H
+
+#include "llvm/ExecutionEngine/Orc/RawByteChannel.h"
+#include "llvm/Support/Error.h"
+
+#include <queue>
+#include <condition_variable>
+
+namespace llvm {
+
+class QueueChannelError : public ErrorInfo<QueueChannelError> {
+public:
+ static char ID;
+};
+
+class QueueChannelClosedError
+ : public ErrorInfo<QueueChannelClosedError, QueueChannelError> {
+public:
+ static char ID;
+ std::error_code convertToErrorCode() const override {
+ return inconvertibleErrorCode();
+ }
+
+ void log(raw_ostream &OS) const override {
+ OS << "Queue closed";
+ }
+};
+
+class Queue : public std::queue<char> {
+public:
+ using ErrorInjector = std::function<Error()>;
+
+ Queue()
+ : ReadError([]() { return Error::success(); }),
+ WriteError([]() { return Error::success(); }) {}
+
+ Queue(const Queue&) = delete;
+ Queue& operator=(const Queue&) = delete;
+ Queue(Queue&&) = delete;
+ Queue& operator=(Queue&&) = delete;
+
+ std::mutex &getMutex() { return M; }
+ std::condition_variable &getCondVar() { return CV; }
+ Error checkReadError() { return ReadError(); }
+ Error checkWriteError() { return WriteError(); }
+ void setReadError(ErrorInjector NewReadError) {
+ {
+ std::lock_guard<std::mutex> Lock(M);
+ ReadError = std::move(NewReadError);
+ }
+ CV.notify_one();
+ }
+ void setWriteError(ErrorInjector NewWriteError) {
+ std::lock_guard<std::mutex> Lock(M);
+ WriteError = std::move(NewWriteError);
+ }
+private:
+ std::mutex M;
+ std::condition_variable CV;
+ std::function<Error()> ReadError, WriteError;
+};
+
+class QueueChannel : public orc::rpc::RawByteChannel {
+public:
+ QueueChannel(std::shared_ptr<Queue> InQueue,
+ std::shared_ptr<Queue> OutQueue)
+ : InQueue(InQueue), OutQueue(OutQueue) {}
+
+ QueueChannel(const QueueChannel&) = delete;
+ QueueChannel& operator=(const QueueChannel&) = delete;
+ QueueChannel(QueueChannel&&) = delete;
+ QueueChannel& operator=(QueueChannel&&) = delete;
+
+ Error readBytes(char *Dst, unsigned Size) override {
+ std::unique_lock<std::mutex> Lock(InQueue->getMutex());
+ while (Size) {
+ {
+ Error Err = InQueue->checkReadError();
+ while (!Err && InQueue->empty()) {
+ InQueue->getCondVar().wait(Lock);
+ Err = InQueue->checkReadError();
+ }
+ if (Err)
+ return Err;
+ }
+ *Dst++ = InQueue->front();
+ --Size;
+ ++NumRead;
+ InQueue->pop();
+ }
+ return Error::success();
+ }
+
+ Error appendBytes(const char *Src, unsigned Size) override {
+ std::unique_lock<std::mutex> Lock(OutQueue->getMutex());
+ while (Size--) {
+ if (Error Err = OutQueue->checkWriteError())
+ return Err;
+ OutQueue->push(*Src++);
+ ++NumWritten;
+ }
+ OutQueue->getCondVar().notify_one();
+ return Error::success();
+ }
+
+ Error send() override { return Error::success(); }
+
+ void close() {
+ auto ChannelClosed = []() { return make_error<QueueChannelClosedError>(); };
+ InQueue->setReadError(ChannelClosed);
+ InQueue->setWriteError(ChannelClosed);
+ OutQueue->setReadError(ChannelClosed);
+ OutQueue->setWriteError(ChannelClosed);
+ }
+
+ uint64_t NumWritten = 0;
+ uint64_t NumRead = 0;
+
+private:
+
+ std::shared_ptr<Queue> InQueue;
+ std::shared_ptr<Queue> OutQueue;
+};
+
+inline std::pair<std::unique_ptr<QueueChannel>, std::unique_ptr<QueueChannel>>
+createPairedQueueChannels() {
+ auto Q1 = std::make_shared<Queue>();
+ auto Q2 = std::make_shared<Queue>();
+ auto C1 = llvm::make_unique<QueueChannel>(Q1, Q2);
+ auto C2 = llvm::make_unique<QueueChannel>(Q2, Q1);
+ return std::make_pair(std::move(C1), std::move(C2));
+}
+
+}
+
+#endif
diff --git a/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp b/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp
index 186c3d408486..1c9764b555fd 100644
--- a/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp
+++ b/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp
@@ -7,8 +7,8 @@
//
//===----------------------------------------------------------------------===//
-#include "llvm/ExecutionEngine/Orc/RawByteChannel.h"
#include "llvm/ExecutionEngine/Orc/RPCUtils.h"
+#include "QueueChannel.h"
#include "gtest/gtest.h"
#include <queue>
@@ -17,47 +17,6 @@ using namespace llvm;
using namespace llvm::orc;
using namespace llvm::orc::rpc;
-class Queue : public std::queue<char> {
-public:
- std::mutex &getMutex() { return M; }
- std::condition_variable &getCondVar() { return CV; }
-private:
- std::mutex M;
- std::condition_variable CV;
-};
-
-class QueueChannel : public RawByteChannel {
-public:
- QueueChannel(Queue &InQueue, Queue &OutQueue)
- : InQueue(InQueue), OutQueue(OutQueue) {}
-
- Error readBytes(char *Dst, unsigned Size) override {
- std::unique_lock<std::mutex> Lock(InQueue.getMutex());
- while (Size) {
- while (InQueue.empty())
- InQueue.getCondVar().wait(Lock);
- *Dst++ = InQueue.front();
- --Size;
- InQueue.pop();
- }
- return Error::success();
- }
-
- Error appendBytes(const char *Src, unsigned Size) override {
- std::unique_lock<std::mutex> Lock(OutQueue.getMutex());
- while (Size--)
- OutQueue.push(*Src++);
- OutQueue.getCondVar().notify_one();
- return Error::success();
- }
-
- Error send() override { return Error::success(); }
-
-private:
- Queue &InQueue;
- Queue &OutQueue;
-};
-
class RPCFoo {};
namespace llvm {
@@ -88,6 +47,54 @@ namespace rpc {
class RPCBar {};
+class DummyError : public ErrorInfo<DummyError> {
+public:
+
+ static char ID;
+
+ DummyError(uint32_t Val) : Val(Val) {}
+
+ std::error_code convertToErrorCode() const override {
+ // Use a nonsense error code - we want to verify that errors
+ // transmitted over the network are replaced with
+ // OrcErrorCode::UnknownErrorCodeFromRemote.
+ return orcError(OrcErrorCode::RemoteAllocatorDoesNotExist);
+ }
+
+ void log(raw_ostream &OS) const override {
+ OS << "Dummy error " << Val;
+ }
+
+ uint32_t getValue() const { return Val; }
+
+public:
+ uint32_t Val;
+};
+
+char DummyError::ID = 0;
+
+template <typename ChannelT>
+void registerDummyErrorSerialization() {
+ static bool AlreadyRegistered = false;
+ if (!AlreadyRegistered) {
+ SerializationTraits<ChannelT, Error>::
+ template registerErrorType<DummyError>(
+ "DummyError",
+ [](ChannelT &C, const DummyError &DE) {
+ return serializeSeq(C, DE.getValue());
+ },
+ [](ChannelT &C, Error &Err) -> Error {
+ ErrorAsOutParameter EAO(&Err);
+ uint32_t Val;
+ if (auto Err = deserializeSeq(C, Val))
+ return Err;
+ Err = make_error<DummyError>(Val);
+ return Error::success();
+ });
+ AlreadyRegistered = true;
+ }
+}
+
namespace llvm {
namespace orc {
namespace rpc {
@@ -120,6 +127,11 @@ namespace DummyRPCAPI {
static const char* getName() { return "IntInt"; }
};
+ class VoidString : public Function<VoidString, void(std::string)> {
+ public:
+ static const char* getName() { return "VoidString"; }
+ };
+
class AllTheTypes
: public Function<AllTheTypes,
void(int8_t, uint8_t, int16_t, uint16_t, int32_t,
@@ -134,21 +146,38 @@ namespace DummyRPCAPI {
static const char* getName() { return "CustomType"; }
};
+ class ErrorFunc : public Function<ErrorFunc, Error()> {
+ public:
+ static const char* getName() { return "ErrorFunc"; }
+ };
+
+ class ExpectedFunc : public Function<ExpectedFunc, Expected<uint32_t>()> {
+ public:
+ static const char* getName() { return "ExpectedFunc"; }
+ };
+
}
class DummyRPCEndpoint : public SingleThreadedRPCEndpoint<QueueChannel> {
public:
- DummyRPCEndpoint(Queue &Q1, Queue &Q2)
- : SingleThreadedRPCEndpoint(C, true), C(Q1, Q2) {}
-private:
- QueueChannel C;
+ DummyRPCEndpoint(QueueChannel &C)
+ : SingleThreadedRPCEndpoint(C, true) {}
};
-TEST(DummyRPC, TestAsyncVoidBool) {
- Queue Q1, Q2;
- DummyRPCEndpoint Client(Q1, Q2);
- DummyRPCEndpoint Server(Q2, Q1);
+void freeVoidBool(bool B) {
+}
+
+TEST(DummyRPC, TestFreeFunctionHandler) {
+ auto Channels = createPairedQueueChannels();
+ DummyRPCEndpoint Server(*Channels.first);
+ Server.addHandler<DummyRPCAPI::VoidBool>(freeVoidBool);
+}
+
+TEST(DummyRPC, TestCallAsyncVoidBool) {
+ auto Channels = createPairedQueueChannels();
+ DummyRPCEndpoint Client(*Channels.first);
+ DummyRPCEndpoint Server(*Channels.second);
std::thread ServerThread([&]() {
Server.addHandler<DummyRPCAPI::VoidBool>(
@@ -189,10 +218,10 @@ TEST(DummyRPC, TestAsyncVoidBool) {
ServerThread.join();
}
-TEST(DummyRPC, TestAsyncIntInt) {
- Queue Q1, Q2;
- DummyRPCEndpoint Client(Q1, Q2);
- DummyRPCEndpoint Server(Q2, Q1);
+TEST(DummyRPC, TestCallAsyncIntInt) {
+ auto Channels = createPairedQueueChannels();
+ DummyRPCEndpoint Client(*Channels.first);
+ DummyRPCEndpoint Server(*Channels.second);
std::thread ServerThread([&]() {
Server.addHandler<DummyRPCAPI::IntInt>(
@@ -234,10 +263,147 @@ TEST(DummyRPC, TestAsyncIntInt) {
ServerThread.join();
}
+TEST(DummyRPC, TestAsyncIntIntHandler) {
+ auto Channels = createPairedQueueChannels();
+ DummyRPCEndpoint Client(*Channels.first);
+ DummyRPCEndpoint Server(*Channels.second);
+
+ std::thread ServerThread([&]() {
+ Server.addAsyncHandler<DummyRPCAPI::IntInt>(
+ [](std::function<Error(Expected<int32_t>)> SendResult,
+ int32_t X) {
+ EXPECT_EQ(X, 21) << "Server int(int) receieved unexpected result";
+ return SendResult(2 * X);
+ });
+
+ {
+ // Poke the server to handle the negotiate call.
+ auto Err = Server.handleOne();
+ EXPECT_FALSE(!!Err) << "Server failed to handle call to negotiate";
+ }
+
+ {
+ // Poke the server to handle the VoidBool call.
+ auto Err = Server.handleOne();
+ EXPECT_FALSE(!!Err) << "Server failed to handle call to void(bool)";
+ }
+ });
+
+ {
+ auto Err = Client.callAsync<DummyRPCAPI::IntInt>(
+ [](Expected<int> Result) {
+ EXPECT_TRUE(!!Result) << "Async int(int) response handler failed";
+ EXPECT_EQ(*Result, 42)
+ << "Async int(int) response handler received incorrect result";
+ return Error::success();
+ }, 21);
+ EXPECT_FALSE(!!Err) << "Client.callAsync failed for int(int)";
+ }
+
+ {
+ // Poke the client to process the result.
+ auto Err = Client.handleOne();
+ EXPECT_FALSE(!!Err) << "Client failed to handle response from void(bool)";
+ }
+
+ ServerThread.join();
+}
+
+TEST(DummyRPC, TestAsyncIntIntHandlerMethod) {
+ auto Channels = createPairedQueueChannels();
+ DummyRPCEndpoint Client(*Channels.first);
+ DummyRPCEndpoint Server(*Channels.second);
+
+ class Dummy {
+ public:
+ Error handler(std::function<Error(Expected<int32_t>)> SendResult,
+ int32_t X) {
+ EXPECT_EQ(X, 21) << "Server int(int) receieved unexpected result";
+ return SendResult(2 * X);
+ }
+ };
+
+ std::thread ServerThread([&]() {
+ Dummy D;
+ Server.addAsyncHandler<DummyRPCAPI::IntInt>(D, &Dummy::handler);
+
+ {
+ // Poke the server to handle the negotiate call.
+ auto Err = Server.handleOne();
+ EXPECT_FALSE(!!Err) << "Server failed to handle call to negotiate";
+ }
+
+ {
+ // Poke the server to handle the VoidBool call.
+ auto Err = Server.handleOne();
+ EXPECT_FALSE(!!Err) << "Server failed to handle call to void(bool)";
+ }
+ });
+
+ {
+ auto Err = Client.callAsync<DummyRPCAPI::IntInt>(
+ [](Expected<int> Result) {
+ EXPECT_TRUE(!!Result) << "Async int(int) response handler failed";
+ EXPECT_EQ(*Result, 42)
+ << "Async int(int) response handler received incorrect result";
+ return Error::success();
+ }, 21);
+ EXPECT_FALSE(!!Err) << "Client.callAsync failed for int(int)";
+ }
+
+ {
+ // Poke the client to process the result.
+ auto Err = Client.handleOne();
+ EXPECT_FALSE(!!Err) << "Client failed to handle response from void(bool)";
+ }
+
+ ServerThread.join();
+}
+
+TEST(DummyRPC, TestCallAsyncVoidString) {
+ auto Channels = createPairedQueueChannels();
+ DummyRPCEndpoint Client(*Channels.first);
+ DummyRPCEndpoint Server(*Channels.second);
+
+ std::thread ServerThread([&]() {
+ Server.addHandler<DummyRPCAPI::VoidString>(
+ [](const std::string &S) {
+ EXPECT_EQ(S, "hello")
+ << "Server void(std::string) received unexpected result";
+ });
+
+ // Poke the server to handle the negotiate call.
+ for (int I = 0; I < 4; ++I) {
+ auto Err = Server.handleOne();
+ EXPECT_FALSE(!!Err) << "Server failed to handle call";
+ }
+ });
+
+ {
+ // Make an call using a std::string.
+ auto Err = Client.callB<DummyRPCAPI::VoidString>(std::string("hello"));
+ EXPECT_FALSE(!!Err) << "Client.callAsync failed for void(std::string)";
+ }
+
+ {
+ // Make an call using a std::string.
+ auto Err = Client.callB<DummyRPCAPI::VoidString>(StringRef("hello"));
+ EXPECT_FALSE(!!Err) << "Client.callAsync failed for void(std::string)";
+ }
+
+ {
+ // Make an call using a std::string.
+ auto Err = Client.callB<DummyRPCAPI::VoidString>("hello");
+ EXPECT_FALSE(!!Err) << "Client.callAsync failed for void(string)";
+ }
+
+ ServerThread.join();
+}
+
TEST(DummyRPC, TestSerialization) {
- Queue Q1, Q2;
- DummyRPCEndpoint Client(Q1, Q2);
- DummyRPCEndpoint Server(Q2, Q1);
+ auto Channels = createPairedQueueChannels();
+ DummyRPCEndpoint Client(*Channels.first);
+ DummyRPCEndpoint Server(*Channels.second);
std::thread ServerThread([&]() {
Server.addHandler<DummyRPCAPI::AllTheTypes>(
@@ -300,9 +466,9 @@ TEST(DummyRPC, TestSerialization) {
}
TEST(DummyRPC, TestCustomType) {
- Queue Q1, Q2;
- DummyRPCEndpoint Client(Q1, Q2);
- DummyRPCEndpoint Server(Q2, Q1);
+ auto Channels = createPairedQueueChannels();
+ DummyRPCEndpoint Client(*Channels.first);
+ DummyRPCEndpoint Server(*Channels.second);
std::thread ServerThread([&]() {
Server.addHandler<DummyRPCAPI::CustomType>(
@@ -343,9 +509,9 @@ TEST(DummyRPC, TestCustomType) {
}
TEST(DummyRPC, TestWithAltCustomType) {
- Queue Q1, Q2;
- DummyRPCEndpoint Client(Q1, Q2);
- DummyRPCEndpoint Server(Q2, Q1);
+ auto Channels = createPairedQueueChannels();
+ DummyRPCEndpoint Client(*Channels.first);
+ DummyRPCEndpoint Server(*Channels.second);
std::thread ServerThread([&]() {
Server.addHandler<DummyRPCAPI::CustomType>(
@@ -385,10 +551,144 @@ TEST(DummyRPC, TestWithAltCustomType) {
ServerThread.join();
}
+TEST(DummyRPC, ReturnErrorSuccess) {
+ registerDummyErrorSerialization<QueueChannel>();
+
+ auto Channels = createPairedQueueChannels();
+ DummyRPCEndpoint Client(*Channels.first);
+ DummyRPCEndpoint Server(*Channels.second);
+
+ std::thread ServerThread([&]() {
+ Server.addHandler<DummyRPCAPI::ErrorFunc>(
+ []() {
+ return Error::success();
+ });
+
+ // Handle the negotiate plus one call.
+ for (unsigned I = 0; I != 2; ++I)
+ cantFail(Server.handleOne());
+ });
+
+ cantFail(Client.callAsync<DummyRPCAPI::ErrorFunc>(
+ [&](Error Err) {
+ EXPECT_FALSE(!!Err) << "Expected success value";
+ return Error::success();
+ }));
+
+ cantFail(Client.handleOne());
+
+ ServerThread.join();
+}
+
+TEST(DummyRPC, ReturnErrorFailure) {
+ registerDummyErrorSerialization<QueueChannel>();
+
+ auto Channels = createPairedQueueChannels();
+ DummyRPCEndpoint Client(*Channels.first);
+ DummyRPCEndpoint Server(*Channels.second);
+
+ std::thread ServerThread([&]() {
+ Server.addHandler<DummyRPCAPI::ErrorFunc>(
+ []() {
+ return make_error<DummyError>(42);
+ });
+
+ // Handle the negotiate plus one call.
+ for (unsigned I = 0; I != 2; ++I)
+ cantFail(Server.handleOne());
+ });
+
+ cantFail(Client.callAsync<DummyRPCAPI::ErrorFunc>(
+ [&](Error Err) {
+ EXPECT_TRUE(Err.isA<DummyError>())
+ << "Incorrect error type";
+ return handleErrors(
+ std::move(Err),
+ [](const DummyError &DE) {
+ EXPECT_EQ(DE.getValue(), 42ULL)
+ << "Incorrect DummyError serialization";
+ });
+ }));
+
+ cantFail(Client.handleOne());
+
+ ServerThread.join();
+}
+
+TEST(DummyRPC, ReturnExpectedSuccess) {
+ registerDummyErrorSerialization<QueueChannel>();
+
+ auto Channels = createPairedQueueChannels();
+ DummyRPCEndpoint Client(*Channels.first);
+ DummyRPCEndpoint Server(*Channels.second);
+
+ std::thread ServerThread([&]() {
+ Server.addHandler<DummyRPCAPI::ExpectedFunc>(
+ []() -> uint32_t {
+ return 42;
+ });
+
+ // Handle the negotiate plus one call.
+ for (unsigned I = 0; I != 2; ++I)
+ cantFail(Server.handleOne());
+ });
+
+ cantFail(Client.callAsync<DummyRPCAPI::ExpectedFunc>(
+ [&](Expected<uint32_t> ValOrErr) {
+ EXPECT_TRUE(!!ValOrErr)
+ << "Expected success value";
+ EXPECT_EQ(*ValOrErr, 42ULL)
+ << "Incorrect Expected<uint32_t> deserialization";
+ return Error::success();
+ }));
+
+ cantFail(Client.handleOne());
+
+ ServerThread.join();
+}
+
+TEST(DummyRPC, ReturnExpectedFailure) {
+ registerDummyErrorSerialization<QueueChannel>();
+
+ auto Channels = createPairedQueueChannels();
+ DummyRPCEndpoint Client(*Channels.first);
+ DummyRPCEndpoint Server(*Channels.second);
+
+ std::thread ServerThread([&]() {
+ Server.addHandler<DummyRPCAPI::ExpectedFunc>(
+ []() -> Expected<uint32_t> {
+ return make_error<DummyError>(7);
+ });
+
+ // Handle the negotiate plus one call.
+ for (unsigned I = 0; I != 2; ++I)
+ cantFail(Server.handleOne());
+ });
+
+ cantFail(Client.callAsync<DummyRPCAPI::ExpectedFunc>(
+ [&](Expected<uint32_t> ValOrErr) {
+ EXPECT_FALSE(!!ValOrErr)
+ << "Expected failure value";
+ auto Err = ValOrErr.takeError();
+ EXPECT_TRUE(Err.isA<DummyError>())
+ << "Incorrect error type";
+ return handleErrors(
+ std::move(Err),
+ [](const DummyError &DE) {
+ EXPECT_EQ(DE.getValue(), 7ULL)
+ << "Incorrect DummyError serialization";
+ });
+ }));
+
+ cantFail(Client.handleOne());
+
+ ServerThread.join();
+}
+
TEST(DummyRPC, TestParallelCallGroup) {
- Queue Q1, Q2;
- DummyRPCEndpoint Client(Q1, Q2);
- DummyRPCEndpoint Server(Q2, Q1);
+ auto Channels = createPairedQueueChannels();
+ DummyRPCEndpoint Client(*Channels.first);
+ DummyRPCEndpoint Server(*Channels.second);
std::thread ServerThread([&]() {
Server.addHandler<DummyRPCAPI::IntInt>(
@@ -405,10 +705,11 @@ TEST(DummyRPC, TestParallelCallGroup) {
{
int A, B, C;
- ParallelCallGroup<DummyRPCEndpoint> PCG(Client);
+ ParallelCallGroup PCG;
{
- auto Err = PCG.appendCall<DummyRPCAPI::IntInt>(
+ auto Err = PCG.call(
+ rpcAsyncDispatch<DummyRPCAPI::IntInt>(Client),
[&A](Expected<int> Result) {
EXPECT_TRUE(!!Result) << "Async int(int) response handler failed";
A = *Result;
@@ -418,7 +719,8 @@ TEST(DummyRPC, TestParallelCallGroup) {
}
{
- auto Err = PCG.appendCall<DummyRPCAPI::IntInt>(
+ auto Err = PCG.call(
+ rpcAsyncDispatch<DummyRPCAPI::IntInt>(Client),
[&B](Expected<int> Result) {
EXPECT_TRUE(!!Result) << "Async int(int) response handler failed";
B = *Result;
@@ -428,7 +730,8 @@ TEST(DummyRPC, TestParallelCallGroup) {
}
{
- auto Err = PCG.appendCall<DummyRPCAPI::IntInt>(
+ auto Err = PCG.call(
+ rpcAsyncDispatch<DummyRPCAPI::IntInt>(Client),
[&C](Expected<int> Result) {
EXPECT_TRUE(!!Result) << "Async int(int) response handler failed";
C = *Result;
@@ -443,10 +746,7 @@ TEST(DummyRPC, TestParallelCallGroup) {
EXPECT_FALSE(!!Err) << "Client failed to handle response from void(bool)";
}
- {
- auto Err = PCG.wait();
- EXPECT_FALSE(!!Err) << "Third parallel call failed for int(int)";
- }
+ PCG.wait();
EXPECT_EQ(A, 2) << "First parallel call returned bogus result";
EXPECT_EQ(B, 4) << "Second parallel call returned bogus result";
@@ -468,9 +768,9 @@ TEST(DummyRPC, TestAPICalls) {
static_assert(!DummyCalls1::Contains<DummyRPCAPI::CustomType>::value,
"Contains<Func> template should return false here");
- Queue Q1, Q2;
- DummyRPCEndpoint Client(Q1, Q2);
- DummyRPCEndpoint Server(Q2, Q1);
+ auto Channels = createPairedQueueChannels();
+ DummyRPCEndpoint Client(*Channels.first);
+ DummyRPCEndpoint Server(*Channels.second);
std::thread ServerThread(
[&]() {
@@ -496,11 +796,37 @@ TEST(DummyRPC, TestAPICalls) {
{
auto Err = DummyCallsAll::negotiate(Client);
- EXPECT_EQ(errorToErrorCode(std::move(Err)).value(),
- static_cast<int>(OrcErrorCode::UnknownRPCFunction))
- << "Expected 'UnknownRPCFunction' error for attempted negotiate of "
+ EXPECT_TRUE(Err.isA<CouldNotNegotiate>())
+ << "Expected CouldNotNegotiate error for attempted negotiate of "
"unsupported function";
+ consumeError(std::move(Err));
}
ServerThread.join();
}
+
+TEST(DummyRPC, TestRemoveHandler) {
+ auto Channels = createPairedQueueChannels();
+ DummyRPCEndpoint Server(*Channels.second);
+
+ Server.addHandler<DummyRPCAPI::VoidBool>(
+ [](bool B) {
+ EXPECT_EQ(B, true)
+ << "Server void(bool) received unexpected result";
+ });
+
+ Server.removeHandler<DummyRPCAPI::VoidBool>();
+}
+
+TEST(DummyRPC, TestClearHandlers) {
+ auto Channels = createPairedQueueChannels();
+ DummyRPCEndpoint Server(*Channels.second);
+
+ Server.addHandler<DummyRPCAPI::VoidBool>(
+ [](bool B) {
+ EXPECT_EQ(B, true)
+ << "Server void(bool) received unexpected result";
+ });
+
+ Server.clearHandlers();
+}
diff --git a/unittests/ExecutionEngine/Orc/ObjectLinkingLayerTest.cpp b/unittests/ExecutionEngine/Orc/RTDyldObjectLinkingLayerTest.cpp
index 44b44f604159..de99c022fb9d 100644
--- a/unittests/ExecutionEngine/Orc/ObjectLinkingLayerTest.cpp
+++ b/unittests/ExecutionEngine/Orc/RTDyldObjectLinkingLayerTest.cpp
@@ -1,4 +1,4 @@
-//===-- ObjectLinkingLayerTest.cpp - Unit tests for object linking layer --===//
+//===- RTDyldObjectLinkingLayerTest.cpp - RTDyld linking layer unit tests -===//
//
// The LLVM Compiler Infrastructure
//
@@ -13,7 +13,7 @@
#include "llvm/ExecutionEngine/Orc/CompileUtils.h"
#include "llvm/ExecutionEngine/Orc/LambdaResolver.h"
#include "llvm/ExecutionEngine/Orc/NullResolver.h"
-#include "llvm/ExecutionEngine/Orc/ObjectLinkingLayer.h"
+#include "llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/LLVMContext.h"
#include "gtest/gtest.h"
@@ -23,8 +23,8 @@ using namespace llvm::orc;
namespace {
-class ObjectLinkingLayerExecutionTest : public testing::Test,
- public OrcExecutionTest {
+class RTDyldObjectLinkingLayerExecutionTest : public testing::Test,
+ public OrcExecutionTest {
};
@@ -44,7 +44,7 @@ public:
}
};
-TEST(ObjectLinkingLayerTest, TestSetProcessAllSections) {
+TEST(RTDyldObjectLinkingLayerTest, TestSetProcessAllSections) {
class SectionMemoryManagerWrapper : public SectionMemoryManager {
public:
SectionMemoryManagerWrapper(bool &DebugSeen) : DebugSeen(DebugSeen) {}
@@ -60,10 +60,10 @@ TEST(ObjectLinkingLayerTest, TestSetProcessAllSections) {
IsReadOnly);
}
private:
- bool DebugSeen;
+ bool &DebugSeen;
};
- ObjectLinkingLayer<> ObjLayer;
+ RTDyldObjectLinkingLayer<> ObjLayer;
LLVMContext Context;
auto M = llvm::make_unique<Module>("", Context);
@@ -75,6 +75,10 @@ TEST(ObjectLinkingLayerTest, TestSetProcessAllSections) {
GV->setSection(".debug_str");
+
+ // Initialize the native target in case this is the first unit test
+ // to try to build a TM.
+ OrcNativeTarget::initialize();
std::unique_ptr<TargetMachine> TM(
EngineBuilder().selectTarget(Triple(M->getTargetTriple()), "", "",
SmallVector<std::string, 1>()));
@@ -99,6 +103,7 @@ TEST(ObjectLinkingLayerTest, TestSetProcessAllSections) {
{
// Test with ProcessAllSections = false (the default).
auto H = ObjLayer.addObjectSet(Objs, &SMMW, &*Resolver);
+ ObjLayer.emitAndFinalize(H);
EXPECT_EQ(DebugSectionSeen, false)
<< "Unexpected debug info section";
ObjLayer.removeObjectSet(H);
@@ -108,17 +113,18 @@ TEST(ObjectLinkingLayerTest, TestSetProcessAllSections) {
// Test with ProcessAllSections = true.
ObjLayer.setProcessAllSections(true);
auto H = ObjLayer.addObjectSet(Objs, &SMMW, &*Resolver);
+ ObjLayer.emitAndFinalize(H);
EXPECT_EQ(DebugSectionSeen, true)
<< "Expected debug info section not seen";
ObjLayer.removeObjectSet(H);
}
}
-TEST_F(ObjectLinkingLayerExecutionTest, NoDuplicateFinalization) {
+TEST_F(RTDyldObjectLinkingLayerExecutionTest, NoDuplicateFinalization) {
if (!TM)
return;
- ObjectLinkingLayer<> ObjLayer;
+ RTDyldObjectLinkingLayer<> ObjLayer;
SimpleCompiler Compile(*TM);
// Create a pair of modules that will trigger recursive finalization:
@@ -183,11 +189,11 @@ TEST_F(ObjectLinkingLayerExecutionTest, NoDuplicateFinalization) {
<< "Extra call to finalize";
}
-TEST_F(ObjectLinkingLayerExecutionTest, NoPrematureAllocation) {
+TEST_F(RTDyldObjectLinkingLayerExecutionTest, NoPrematureAllocation) {
if (!TM)
return;
- ObjectLinkingLayer<> ObjLayer;
+ RTDyldObjectLinkingLayer<> ObjLayer;
SimpleCompiler Compile(*TM);
// Create a pair of unrelated modules: