Diffstat (limited to 'llvm/include/llvm/ExecutionEngine')
52 files changed, 14784 insertions, 0 deletions
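For orientation before the headers themselves: the bulk of this diff adds the classic ExecutionEngine/MCJIT client interface (ExecutionEngine.h, EngineBuilder) plus the newer JITLink types. The snippet below is a minimal illustrative sketch, not part of the diff, of how a client typically drives the EngineBuilder API declared in ExecutionEngine.h: construct a builder over a Module, create the engine, finalize the object, and look up a function address. It assumes a Module that already defines an `int add(int, int)`, and it pulls in llvm/Support/TargetSelect.h and the MCJIT force-link header from elsewhere in the tree (assumed available alongside these headers).

// Minimal sketch, assuming an IR Module `M` that defines `int add(int, int)`.
// Error handling is reduced to returning -1.
#include "llvm/ExecutionEngine/ExecutionEngine.h"
#include "llvm/ExecutionEngine/MCJIT.h"   // force-links MCJIT (assumed available)
#include "llvm/IR/Module.h"
#include "llvm/Support/TargetSelect.h"
#include <memory>
#include <string>

int callAdd(std::unique_ptr<llvm::Module> M) {
  // Register the host target and asm printer so MCJIT can emit native code.
  llvm::InitializeNativeTarget();
  llvm::InitializeNativeTargetAsmPrinter();

  std::string Err;
  std::unique_ptr<llvm::ExecutionEngine> EE(
      llvm::EngineBuilder(std::move(M))
          .setEngineKind(llvm::EngineKind::JIT)
          .setErrorStr(&Err)
          .create());
  if (!EE)
    return -1; // Err describes why engine creation failed.

  // Apply relocations and make the generated code executable.
  EE->finalizeObject();

  // Look up the compiled function and call it through a typed pointer.
  auto *Add =
      reinterpret_cast<int (*)(int, int)>(EE->getFunctionAddress("add"));
  return Add ? Add(2, 3) : -1;
}

As the comments in ExecutionEngine.h below note, getFunctionAddress is the preferred path for MCJIT clients; runFunction is retained mainly for the interpreter and for "main-like" signatures.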
diff --git a/llvm/include/llvm/ExecutionEngine/ExecutionEngine.h b/llvm/include/llvm/ExecutionEngine/ExecutionEngine.h new file mode 100644 index 000000000000..4fb6dad96387 --- /dev/null +++ b/llvm/include/llvm/ExecutionEngine/ExecutionEngine.h @@ -0,0 +1,677 @@ +//===- ExecutionEngine.h - Abstract Execution Engine Interface --*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the abstract interface that implements execution support +// for LLVM. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_EXECUTIONENGINE_EXECUTIONENGINE_H +#define LLVM_EXECUTIONENGINE_EXECUTIONENGINE_H + +#include "llvm-c/ExecutionEngine.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringMap.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ExecutionEngine/JITSymbol.h" +#include "llvm/ExecutionEngine/OrcV1Deprecation.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/Module.h" +#include "llvm/Object/Binary.h" +#include "llvm/Support/CBindingWrapping.h" +#include "llvm/Support/CodeGen.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/Mutex.h" +#include "llvm/Target/TargetMachine.h" +#include "llvm/Target/TargetOptions.h" +#include <algorithm> +#include <cstdint> +#include <functional> +#include <map> +#include <memory> +#include <string> +#include <vector> + +namespace llvm { + +class Constant; +class Function; +struct GenericValue; +class GlobalValue; +class GlobalVariable; +class JITEventListener; +class MCJITMemoryManager; +class ObjectCache; +class RTDyldMemoryManager; +class Triple; +class Type; + +namespace object { + +class Archive; +class ObjectFile; + +} // end namespace object + +/// Helper class for helping synchronize access to the global address map +/// table. Access to this class should be serialized under a mutex. +class ExecutionEngineState { +public: + using GlobalAddressMapTy = StringMap<uint64_t>; + +private: + /// GlobalAddressMap - A mapping between LLVM global symbol names values and + /// their actualized version... + GlobalAddressMapTy GlobalAddressMap; + + /// GlobalAddressReverseMap - This is the reverse mapping of GlobalAddressMap, + /// used to convert raw addresses into the LLVM global value that is emitted + /// at the address. This map is not computed unless getGlobalValueAtAddress + /// is called at some point. + std::map<uint64_t, std::string> GlobalAddressReverseMap; + +public: + GlobalAddressMapTy &getGlobalAddressMap() { + return GlobalAddressMap; + } + + std::map<uint64_t, std::string> &getGlobalAddressReverseMap() { + return GlobalAddressReverseMap; + } + + /// Erase an entry from the mapping table. + /// + /// \returns The address that \p ToUnmap was happed to. + uint64_t RemoveMapping(StringRef Name); +}; + +using FunctionCreator = std::function<void *(const std::string &)>; + +/// Abstract interface for implementation execution of LLVM modules, +/// designed to support both interpreter and just-in-time (JIT) compiler +/// implementations. +class ExecutionEngine { + /// The state object holding the global address mapping, which must be + /// accessed synchronously. 
+ // + // FIXME: There is no particular need the entire map needs to be + // synchronized. Wouldn't a reader-writer design be better here? + ExecutionEngineState EEState; + + /// The target data for the platform for which execution is being performed. + /// + /// Note: the DataLayout is LLVMContext specific because it has an + /// internal cache based on type pointers. It makes unsafe to reuse the + /// ExecutionEngine across context, we don't enforce this rule but undefined + /// behavior can occurs if the user tries to do it. + const DataLayout DL; + + /// Whether lazy JIT compilation is enabled. + bool CompilingLazily; + + /// Whether JIT compilation of external global variables is allowed. + bool GVCompilationDisabled; + + /// Whether the JIT should perform lookups of external symbols (e.g., + /// using dlsym). + bool SymbolSearchingDisabled; + + /// Whether the JIT should verify IR modules during compilation. + bool VerifyModules; + + friend class EngineBuilder; // To allow access to JITCtor and InterpCtor. + +protected: + /// The list of Modules that we are JIT'ing from. We use a SmallVector to + /// optimize for the case where there is only one module. + SmallVector<std::unique_ptr<Module>, 1> Modules; + + /// getMemoryforGV - Allocate memory for a global variable. + virtual char *getMemoryForGV(const GlobalVariable *GV); + + static ExecutionEngine *(*MCJITCtor)( + std::unique_ptr<Module> M, std::string *ErrorStr, + std::shared_ptr<MCJITMemoryManager> MM, + std::shared_ptr<LegacyJITSymbolResolver> SR, + std::unique_ptr<TargetMachine> TM); + + static ExecutionEngine *(*OrcMCJITReplacementCtor)( + std::string *ErrorStr, std::shared_ptr<MCJITMemoryManager> MM, + std::shared_ptr<LegacyJITSymbolResolver> SR, + std::unique_ptr<TargetMachine> TM); + + static ExecutionEngine *(*InterpCtor)(std::unique_ptr<Module> M, + std::string *ErrorStr); + + /// LazyFunctionCreator - If an unknown function is needed, this function + /// pointer is invoked to create it. If this returns null, the JIT will + /// abort. + FunctionCreator LazyFunctionCreator; + + /// getMangledName - Get mangled name. + std::string getMangledName(const GlobalValue *GV); + +public: + /// lock - This lock protects the ExecutionEngine and MCJIT classes. It must + /// be held while changing the internal state of any of those classes. + sys::Mutex lock; + + //===--------------------------------------------------------------------===// + // ExecutionEngine Startup + //===--------------------------------------------------------------------===// + + virtual ~ExecutionEngine(); + + /// Add a Module to the list of modules that we can JIT from. + virtual void addModule(std::unique_ptr<Module> M) { + Modules.push_back(std::move(M)); + } + + /// addObjectFile - Add an ObjectFile to the execution engine. + /// + /// This method is only supported by MCJIT. MCJIT will immediately load the + /// object into memory and adds its symbols to the list used to resolve + /// external symbols while preparing other objects for execution. + /// + /// Objects added using this function will not be made executable until + /// needed by another object. + /// + /// MCJIT will take ownership of the ObjectFile. + virtual void addObjectFile(std::unique_ptr<object::ObjectFile> O); + virtual void addObjectFile(object::OwningBinary<object::ObjectFile> O); + + /// addArchive - Add an Archive to the execution engine. + /// + /// This method is only supported by MCJIT. MCJIT will use the archive to + /// resolve external symbols in objects it is loading. 
If a symbol is found + /// in the Archive the contained object file will be extracted (in memory) + /// and loaded for possible execution. + virtual void addArchive(object::OwningBinary<object::Archive> A); + + //===--------------------------------------------------------------------===// + + const DataLayout &getDataLayout() const { return DL; } + + /// removeModule - Removes a Module from the list of modules, but does not + /// free the module's memory. Returns true if M is found, in which case the + /// caller assumes responsibility for deleting the module. + // + // FIXME: This stealth ownership transfer is horrible. This will probably be + // fixed by deleting ExecutionEngine. + virtual bool removeModule(Module *M); + + /// FindFunctionNamed - Search all of the active modules to find the function that + /// defines FnName. This is very slow operation and shouldn't be used for + /// general code. + virtual Function *FindFunctionNamed(StringRef FnName); + + /// FindGlobalVariableNamed - Search all of the active modules to find the global variable + /// that defines Name. This is very slow operation and shouldn't be used for + /// general code. + virtual GlobalVariable *FindGlobalVariableNamed(StringRef Name, bool AllowInternal = false); + + /// runFunction - Execute the specified function with the specified arguments, + /// and return the result. + /// + /// For MCJIT execution engines, clients are encouraged to use the + /// "GetFunctionAddress" method (rather than runFunction) and cast the + /// returned uint64_t to the desired function pointer type. However, for + /// backwards compatibility MCJIT's implementation can execute 'main-like' + /// function (i.e. those returning void or int, and taking either no + /// arguments or (int, char*[])). + virtual GenericValue runFunction(Function *F, + ArrayRef<GenericValue> ArgValues) = 0; + + /// getPointerToNamedFunction - This method returns the address of the + /// specified function by using the dlsym function call. As such it is only + /// useful for resolving library symbols, not code generated symbols. + /// + /// If AbortOnFailure is false and no function with the given name is + /// found, this function silently returns a null pointer. Otherwise, + /// it prints a message to stderr and aborts. + /// + /// This function is deprecated for the MCJIT execution engine. + virtual void *getPointerToNamedFunction(StringRef Name, + bool AbortOnFailure = true) = 0; + + /// mapSectionAddress - map a section to its target address space value. + /// Map the address of a JIT section as returned from the memory manager + /// to the address in the target process as the running code will see it. + /// This is the address which will be used for relocation resolution. + virtual void mapSectionAddress(const void *LocalAddress, + uint64_t TargetAddress) { + llvm_unreachable("Re-mapping of section addresses not supported with this " + "EE!"); + } + + /// generateCodeForModule - Run code generation for the specified module and + /// load it into memory. + /// + /// When this function has completed, all code and data for the specified + /// module, and any module on which this module depends, will be generated + /// and loaded into memory, but relocations will not yet have been applied + /// and all memory will be readable and writable but not executable. + /// + /// This function is primarily useful when generating code for an external + /// target, allowing the client an opportunity to remap section addresses + /// before relocations are applied. 
Clients that intend to execute code + /// locally can use the getFunctionAddress call, which will generate code + /// and apply final preparations all in one step. + /// + /// This method has no effect for the interpeter. + virtual void generateCodeForModule(Module *M) {} + + /// finalizeObject - ensure the module is fully processed and is usable. + /// + /// It is the user-level function for completing the process of making the + /// object usable for execution. It should be called after sections within an + /// object have been relocated using mapSectionAddress. When this method is + /// called the MCJIT execution engine will reapply relocations for a loaded + /// object. This method has no effect for the interpeter. + virtual void finalizeObject() {} + + /// runStaticConstructorsDestructors - This method is used to execute all of + /// the static constructors or destructors for a program. + /// + /// \param isDtors - Run the destructors instead of constructors. + virtual void runStaticConstructorsDestructors(bool isDtors); + + /// This method is used to execute all of the static constructors or + /// destructors for a particular module. + /// + /// \param isDtors - Run the destructors instead of constructors. + void runStaticConstructorsDestructors(Module &module, bool isDtors); + + + /// runFunctionAsMain - This is a helper function which wraps runFunction to + /// handle the common task of starting up main with the specified argc, argv, + /// and envp parameters. + int runFunctionAsMain(Function *Fn, const std::vector<std::string> &argv, + const char * const * envp); + + + /// addGlobalMapping - Tell the execution engine that the specified global is + /// at the specified location. This is used internally as functions are JIT'd + /// and as global variables are laid out in memory. It can and should also be + /// used by clients of the EE that want to have an LLVM global overlay + /// existing data in memory. Values to be mapped should be named, and have + /// external or weak linkage. Mappings are automatically removed when their + /// GlobalValue is destroyed. + void addGlobalMapping(const GlobalValue *GV, void *Addr); + void addGlobalMapping(StringRef Name, uint64_t Addr); + + /// clearAllGlobalMappings - Clear all global mappings and start over again, + /// for use in dynamic compilation scenarios to move globals. + void clearAllGlobalMappings(); + + /// clearGlobalMappingsFromModule - Clear all global mappings that came from a + /// particular module, because it has been removed from the JIT. + void clearGlobalMappingsFromModule(Module *M); + + /// updateGlobalMapping - Replace an existing mapping for GV with a new + /// address. This updates both maps as required. If "Addr" is null, the + /// entry for the global is removed from the mappings. This returns the old + /// value of the pointer, or null if it was not in the map. + uint64_t updateGlobalMapping(const GlobalValue *GV, void *Addr); + uint64_t updateGlobalMapping(StringRef Name, uint64_t Addr); + + /// getAddressToGlobalIfAvailable - This returns the address of the specified + /// global symbol. + uint64_t getAddressToGlobalIfAvailable(StringRef S); + + /// getPointerToGlobalIfAvailable - This returns the address of the specified + /// global value if it is has already been codegen'd, otherwise it returns + /// null. + void *getPointerToGlobalIfAvailable(StringRef S); + void *getPointerToGlobalIfAvailable(const GlobalValue *GV); + + /// getPointerToGlobal - This returns the address of the specified global + /// value. 
This may involve code generation if it's a function. + /// + /// This function is deprecated for the MCJIT execution engine. Use + /// getGlobalValueAddress instead. + void *getPointerToGlobal(const GlobalValue *GV); + + /// getPointerToFunction - The different EE's represent function bodies in + /// different ways. They should each implement this to say what a function + /// pointer should look like. When F is destroyed, the ExecutionEngine will + /// remove its global mapping and free any machine code. Be sure no threads + /// are running inside F when that happens. + /// + /// This function is deprecated for the MCJIT execution engine. Use + /// getFunctionAddress instead. + virtual void *getPointerToFunction(Function *F) = 0; + + /// getPointerToFunctionOrStub - If the specified function has been + /// code-gen'd, return a pointer to the function. If not, compile it, or use + /// a stub to implement lazy compilation if available. See + /// getPointerToFunction for the requirements on destroying F. + /// + /// This function is deprecated for the MCJIT execution engine. Use + /// getFunctionAddress instead. + virtual void *getPointerToFunctionOrStub(Function *F) { + // Default implementation, just codegen the function. + return getPointerToFunction(F); + } + + /// getGlobalValueAddress - Return the address of the specified global + /// value. This may involve code generation. + /// + /// This function should not be called with the interpreter engine. + virtual uint64_t getGlobalValueAddress(const std::string &Name) { + // Default implementation for the interpreter. MCJIT will override this. + // JIT and interpreter clients should use getPointerToGlobal instead. + return 0; + } + + /// getFunctionAddress - Return the address of the specified function. + /// This may involve code generation. + virtual uint64_t getFunctionAddress(const std::string &Name) { + // Default implementation for the interpreter. MCJIT will override this. + // Interpreter clients should use getPointerToFunction instead. + return 0; + } + + /// getGlobalValueAtAddress - Return the LLVM global value object that starts + /// at the specified address. + /// + const GlobalValue *getGlobalValueAtAddress(void *Addr); + + /// StoreValueToMemory - Stores the data in Val of type Ty at address Ptr. + /// Ptr is the address of the memory at which to store Val, cast to + /// GenericValue *. It is not a pointer to a GenericValue containing the + /// address at which to store Val. + void StoreValueToMemory(const GenericValue &Val, GenericValue *Ptr, + Type *Ty); + + void InitializeMemory(const Constant *Init, void *Addr); + + /// getOrEmitGlobalVariable - Return the address of the specified global + /// variable, possibly emitting it to memory if needed. This is used by the + /// Emitter. + /// + /// This function is deprecated for the MCJIT execution engine. Use + /// getGlobalValueAddress instead. + virtual void *getOrEmitGlobalVariable(const GlobalVariable *GV) { + return getPointerToGlobal((const GlobalValue *)GV); + } + + /// Registers a listener to be called back on various events within + /// the JIT. See JITEventListener.h for more details. Does not + /// take ownership of the argument. The argument may be NULL, in + /// which case these functions do nothing. + virtual void RegisterJITEventListener(JITEventListener *) {} + virtual void UnregisterJITEventListener(JITEventListener *) {} + + /// Sets the pre-compiled object cache. The ownership of the ObjectCache is + /// not changed. 
Supported by MCJIT but not the interpreter. + virtual void setObjectCache(ObjectCache *) { + llvm_unreachable("No support for an object cache"); + } + + /// setProcessAllSections (MCJIT Only): By default, only sections that are + /// "required for execution" are passed to the RTDyldMemoryManager, and other + /// sections are discarded. Passing 'true' to this method will cause + /// RuntimeDyld to pass all sections to its RTDyldMemoryManager regardless + /// of whether they are "required to execute" in the usual sense. + /// + /// Rationale: Some MCJIT clients want to be able to inspect metadata + /// sections (e.g. Dwarf, Stack-maps) to enable functionality or analyze + /// performance. Passing these sections to the memory manager allows the + /// client to make policy about the relevant sections, rather than having + /// MCJIT do it. + virtual void setProcessAllSections(bool ProcessAllSections) { + llvm_unreachable("No support for ProcessAllSections option"); + } + + /// Return the target machine (if available). + virtual TargetMachine *getTargetMachine() { return nullptr; } + + /// DisableLazyCompilation - When lazy compilation is off (the default), the + /// JIT will eagerly compile every function reachable from the argument to + /// getPointerToFunction. If lazy compilation is turned on, the JIT will only + /// compile the one function and emit stubs to compile the rest when they're + /// first called. If lazy compilation is turned off again while some lazy + /// stubs are still around, and one of those stubs is called, the program will + /// abort. + /// + /// In order to safely compile lazily in a threaded program, the user must + /// ensure that 1) only one thread at a time can call any particular lazy + /// stub, and 2) any thread modifying LLVM IR must hold the JIT's lock + /// (ExecutionEngine::lock) or otherwise ensure that no other thread calls a + /// lazy stub. See http://llvm.org/PR5184 for details. + void DisableLazyCompilation(bool Disabled = true) { + CompilingLazily = !Disabled; + } + bool isCompilingLazily() const { + return CompilingLazily; + } + + /// DisableGVCompilation - If called, the JIT will abort if it's asked to + /// allocate space and populate a GlobalVariable that is not internal to + /// the module. + void DisableGVCompilation(bool Disabled = true) { + GVCompilationDisabled = Disabled; + } + bool isGVCompilationDisabled() const { + return GVCompilationDisabled; + } + + /// DisableSymbolSearching - If called, the JIT will not try to lookup unknown + /// symbols with dlsym. A client can still use InstallLazyFunctionCreator to + /// resolve symbols in a custom way. + void DisableSymbolSearching(bool Disabled = true) { + SymbolSearchingDisabled = Disabled; + } + bool isSymbolSearchingDisabled() const { + return SymbolSearchingDisabled; + } + + /// Enable/Disable IR module verification. + /// + /// Note: Module verification is enabled by default in Debug builds, and + /// disabled by default in Release. Use this method to override the default. + void setVerifyModules(bool Verify) { + VerifyModules = Verify; + } + bool getVerifyModules() const { + return VerifyModules; + } + + /// InstallLazyFunctionCreator - If an unknown function is needed, the + /// specified function pointer is invoked to create it. If it returns null, + /// the JIT will abort. 
+ void InstallLazyFunctionCreator(FunctionCreator C) { + LazyFunctionCreator = std::move(C); + } + +protected: + ExecutionEngine(DataLayout DL) : DL(std::move(DL)) {} + explicit ExecutionEngine(DataLayout DL, std::unique_ptr<Module> M); + explicit ExecutionEngine(std::unique_ptr<Module> M); + + void emitGlobals(); + + void EmitGlobalVariable(const GlobalVariable *GV); + + GenericValue getConstantValue(const Constant *C); + void LoadValueFromMemory(GenericValue &Result, GenericValue *Ptr, + Type *Ty); + +private: + void Init(std::unique_ptr<Module> M); +}; + +namespace EngineKind { + + // These are actually bitmasks that get or-ed together. + enum Kind { + JIT = 0x1, + Interpreter = 0x2 + }; + const static Kind Either = (Kind)(JIT | Interpreter); + +} // end namespace EngineKind + +/// Builder class for ExecutionEngines. Use this by stack-allocating a builder, +/// chaining the various set* methods, and terminating it with a .create() +/// call. +class EngineBuilder { +private: + std::unique_ptr<Module> M; + EngineKind::Kind WhichEngine; + std::string *ErrorStr; + CodeGenOpt::Level OptLevel; + std::shared_ptr<MCJITMemoryManager> MemMgr; + std::shared_ptr<LegacyJITSymbolResolver> Resolver; + TargetOptions Options; + Optional<Reloc::Model> RelocModel; + Optional<CodeModel::Model> CMModel; + std::string MArch; + std::string MCPU; + SmallVector<std::string, 4> MAttrs; + bool VerifyModules; + bool UseOrcMCJITReplacement; + bool EmulatedTLS = true; + +public: + /// Default constructor for EngineBuilder. + EngineBuilder(); + + /// Constructor for EngineBuilder. + EngineBuilder(std::unique_ptr<Module> M); + + // Out-of-line since we don't have the def'n of RTDyldMemoryManager here. + ~EngineBuilder(); + + /// setEngineKind - Controls whether the user wants the interpreter, the JIT, + /// or whichever engine works. This option defaults to EngineKind::Either. + EngineBuilder &setEngineKind(EngineKind::Kind w) { + WhichEngine = w; + return *this; + } + + /// setMCJITMemoryManager - Sets the MCJIT memory manager to use. This allows + /// clients to customize their memory allocation policies for the MCJIT. This + /// is only appropriate for the MCJIT; setting this and configuring the builder + /// to create anything other than MCJIT will cause a runtime error. If create() + /// is called and is successful, the created engine takes ownership of the + /// memory manager. This option defaults to NULL. + EngineBuilder &setMCJITMemoryManager(std::unique_ptr<RTDyldMemoryManager> mcjmm); + + EngineBuilder& + setMemoryManager(std::unique_ptr<MCJITMemoryManager> MM); + + EngineBuilder &setSymbolResolver(std::unique_ptr<LegacyJITSymbolResolver> SR); + + /// setErrorStr - Set the error string to write to on error. This option + /// defaults to NULL. + EngineBuilder &setErrorStr(std::string *e) { + ErrorStr = e; + return *this; + } + + /// setOptLevel - Set the optimization level for the JIT. This option + /// defaults to CodeGenOpt::Default. + EngineBuilder &setOptLevel(CodeGenOpt::Level l) { + OptLevel = l; + return *this; + } + + /// setTargetOptions - Set the target options that the ExecutionEngine + /// target is using. Defaults to TargetOptions(). + EngineBuilder &setTargetOptions(const TargetOptions &Opts) { + Options = Opts; + return *this; + } + + /// setRelocationModel - Set the relocation model that the ExecutionEngine + /// target is using. Defaults to target specific default "Reloc::Default". 
+ EngineBuilder &setRelocationModel(Reloc::Model RM) { + RelocModel = RM; + return *this; + } + + /// setCodeModel - Set the CodeModel that the ExecutionEngine target + /// data is using. Defaults to target specific default + /// "CodeModel::JITDefault". + EngineBuilder &setCodeModel(CodeModel::Model M) { + CMModel = M; + return *this; + } + + /// setMArch - Override the architecture set by the Module's triple. + EngineBuilder &setMArch(StringRef march) { + MArch.assign(march.begin(), march.end()); + return *this; + } + + /// setMCPU - Target a specific cpu type. + EngineBuilder &setMCPU(StringRef mcpu) { + MCPU.assign(mcpu.begin(), mcpu.end()); + return *this; + } + + /// setVerifyModules - Set whether the JIT implementation should verify + /// IR modules during compilation. + EngineBuilder &setVerifyModules(bool Verify) { + VerifyModules = Verify; + return *this; + } + + /// setMAttrs - Set cpu-specific attributes. + template<typename StringSequence> + EngineBuilder &setMAttrs(const StringSequence &mattrs) { + MAttrs.clear(); + MAttrs.append(mattrs.begin(), mattrs.end()); + return *this; + } + + // Use OrcMCJITReplacement instead of MCJIT. Off by default. + LLVM_ATTRIBUTE_DEPRECATED( + inline void setUseOrcMCJITReplacement(bool UseOrcMCJITReplacement), + "ORCv1 utilities (including OrcMCJITReplacement) are deprecated. Please " + "use ORCv2/LLJIT instead (see docs/ORCv2.rst)"); + + void setUseOrcMCJITReplacement(ORCv1DeprecationAcknowledgement, + bool UseOrcMCJITReplacement) { + this->UseOrcMCJITReplacement = UseOrcMCJITReplacement; + } + + void setEmulatedTLS(bool EmulatedTLS) { + this->EmulatedTLS = EmulatedTLS; + } + + TargetMachine *selectTarget(); + + /// selectTarget - Pick a target either via -march or by guessing the native + /// arch. Add any CPU features specified via -mcpu or -mattr. + TargetMachine *selectTarget(const Triple &TargetTriple, + StringRef MArch, + StringRef MCPU, + const SmallVectorImpl<std::string>& MAttrs); + + ExecutionEngine *create() { + return create(selectTarget()); + } + + ExecutionEngine *create(TargetMachine *TM); +}; + +void EngineBuilder::setUseOrcMCJITReplacement(bool UseOrcMCJITReplacement) { + this->UseOrcMCJITReplacement = UseOrcMCJITReplacement; +} + +// Create wrappers for C Binding types (see CBindingWrapping.h). +DEFINE_SIMPLE_CONVERSION_FUNCTIONS(ExecutionEngine, LLVMExecutionEngineRef) + +} // end namespace llvm + +#endif // LLVM_EXECUTIONENGINE_EXECUTIONENGINE_H diff --git a/llvm/include/llvm/ExecutionEngine/GenericValue.h b/llvm/include/llvm/ExecutionEngine/GenericValue.h new file mode 100644 index 000000000000..1ca989da1b7e --- /dev/null +++ b/llvm/include/llvm/ExecutionEngine/GenericValue.h @@ -0,0 +1,54 @@ +//===- GenericValue.h - Represent any type of LLVM value --------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// The GenericValue class is used to represent an LLVM value of arbitrary type. 
+// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_EXECUTIONENGINE_GENERICVALUE_H +#define LLVM_EXECUTIONENGINE_GENERICVALUE_H + +#include "llvm/ADT/APInt.h" +#include <vector> + +namespace llvm { + +using PointerTy = void *; + +struct GenericValue { + struct IntPair { + unsigned int first; + unsigned int second; + }; + union { + double DoubleVal; + float FloatVal; + PointerTy PointerVal; + struct IntPair UIntPairVal; + unsigned char Untyped[8]; + }; + APInt IntVal; // also used for long doubles. + // For aggregate data types. + std::vector<GenericValue> AggregateVal; + + // to make code faster, set GenericValue to zero could be omitted, but it is + // potentially can cause problems, since GenericValue to store garbage + // instead of zero. + GenericValue() : IntVal(1, 0) { + UIntPairVal.first = 0; + UIntPairVal.second = 0; + } + explicit GenericValue(void *V) : PointerVal(V), IntVal(1, 0) {} +}; + +inline GenericValue PTOGV(void *P) { return GenericValue(P); } +inline void *GVTOP(const GenericValue &GV) { return GV.PointerVal; } + +} // end namespace llvm + +#endif // LLVM_EXECUTIONENGINE_GENERICVALUE_H diff --git a/llvm/include/llvm/ExecutionEngine/Interpreter.h b/llvm/include/llvm/ExecutionEngine/Interpreter.h new file mode 100644 index 000000000000..0749409766e3 --- /dev/null +++ b/llvm/include/llvm/ExecutionEngine/Interpreter.h @@ -0,0 +1,27 @@ +//===-- Interpreter.h - Abstract Execution Engine Interface -----*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file forces the interpreter to link in on certain operating systems. +// (Windows). +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_EXECUTIONENGINE_INTERPRETER_H +#define LLVM_EXECUTIONENGINE_INTERPRETER_H + +#include "llvm/ExecutionEngine/ExecutionEngine.h" + +extern "C" void LLVMLinkInInterpreter(); + +namespace { + struct ForceInterpreterLinking { + ForceInterpreterLinking() { LLVMLinkInInterpreter(); } + } ForceInterpreterLinking; +} + +#endif diff --git a/llvm/include/llvm/ExecutionEngine/JITEventListener.h b/llvm/include/llvm/ExecutionEngine/JITEventListener.h new file mode 100644 index 000000000000..606b6f7cc128 --- /dev/null +++ b/llvm/include/llvm/ExecutionEngine/JITEventListener.h @@ -0,0 +1,117 @@ +//===- JITEventListener.h - Exposes events from JIT compilation -*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the JITEventListener interface, which lets users get +// callbacks when significant events happen during the JIT compilation process. 
+// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_EXECUTIONENGINE_JITEVENTLISTENER_H +#define LLVM_EXECUTIONENGINE_JITEVENTLISTENER_H + +#include "llvm-c/ExecutionEngine.h" +#include "llvm/Config/llvm-config.h" +#include "llvm/ExecutionEngine/RuntimeDyld.h" +#include "llvm/IR/DebugLoc.h" +#include "llvm/Support/CBindingWrapping.h" +#include <cstdint> +#include <vector> + +namespace llvm { + +class IntelJITEventsWrapper; +class MachineFunction; +class OProfileWrapper; + +namespace object { + +class ObjectFile; + +} // end namespace object + +/// JITEventListener - Abstract interface for use by the JIT to notify clients +/// about significant events during compilation. For example, to notify +/// profilers and debuggers that need to know where functions have been emitted. +/// +/// The default implementation of each method does nothing. +class JITEventListener { +public: + using ObjectKey = uint64_t; + + JITEventListener() = default; + virtual ~JITEventListener() = default; + + /// notifyObjectLoaded - Called after an object has had its sections allocated + /// and addresses assigned to all symbols. Note: Section memory will not have + /// been relocated yet. notifyFunctionLoaded will not be called for + /// individual functions in the object. + /// + /// ELF-specific information + /// The ObjectImage contains the generated object image + /// with section headers updated to reflect the address at which sections + /// were loaded and with relocations performed in-place on debug sections. + virtual void notifyObjectLoaded(ObjectKey K, const object::ObjectFile &Obj, + const RuntimeDyld::LoadedObjectInfo &L) {} + + /// notifyFreeingObject - Called just before the memory associated with + /// a previously emitted object is released. + virtual void notifyFreeingObject(ObjectKey K) {} + + // Get a pointe to the GDB debugger registration listener. 
+ static JITEventListener *createGDBRegistrationListener(); + +#if LLVM_USE_INTEL_JITEVENTS + // Construct an IntelJITEventListener + static JITEventListener *createIntelJITEventListener(); + + // Construct an IntelJITEventListener with a test Intel JIT API implementation + static JITEventListener *createIntelJITEventListener( + IntelJITEventsWrapper* AlternativeImpl); +#else + static JITEventListener *createIntelJITEventListener() { return nullptr; } + + static JITEventListener *createIntelJITEventListener( + IntelJITEventsWrapper* AlternativeImpl) { + return nullptr; + } +#endif // USE_INTEL_JITEVENTS + +#if LLVM_USE_OPROFILE + // Construct an OProfileJITEventListener + static JITEventListener *createOProfileJITEventListener(); + + // Construct an OProfileJITEventListener with a test opagent implementation + static JITEventListener *createOProfileJITEventListener( + OProfileWrapper* AlternativeImpl); +#else + static JITEventListener *createOProfileJITEventListener() { return nullptr; } + + static JITEventListener *createOProfileJITEventListener( + OProfileWrapper* AlternativeImpl) { + return nullptr; + } +#endif // USE_OPROFILE + +#if LLVM_USE_PERF + static JITEventListener *createPerfJITEventListener(); +#else + static JITEventListener *createPerfJITEventListener() + { + return nullptr; + } +#endif // USE_PERF + +private: + virtual void anchor(); +}; + +DEFINE_SIMPLE_CONVERSION_FUNCTIONS(JITEventListener, LLVMJITEventListenerRef) + +} // end namespace llvm + +#endif // LLVM_EXECUTIONENGINE_JITEVENTLISTENER_H diff --git a/llvm/include/llvm/ExecutionEngine/JITLink/EHFrameSupport.h b/llvm/include/llvm/ExecutionEngine/JITLink/EHFrameSupport.h new file mode 100644 index 000000000000..72687682f606 --- /dev/null +++ b/llvm/include/llvm/ExecutionEngine/JITLink/EHFrameSupport.h @@ -0,0 +1,91 @@ +//===--------- EHFrameSupport.h - JITLink eh-frame utils --------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// EHFrame registration support for JITLink. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_EXECUTIONENGINE_JITLINK_EHFRAMESUPPORT_H +#define LLVM_EXECUTIONENGINE_JITLINK_EHFRAMESUPPORT_H + +#include "llvm/ADT/Triple.h" +#include "llvm/ExecutionEngine/JITLink/JITLink.h" +#include "llvm/ExecutionEngine/JITSymbol.h" +#include "llvm/Support/Error.h" + +namespace llvm { +namespace jitlink { + +/// Registers all FDEs in the given eh-frame section with the current process. +Error registerEHFrameSection(const void *EHFrameSectionAddr, + size_t EHFrameSectionSize); + +/// Deregisters all FDEs in the given eh-frame section with the current process. +Error deregisterEHFrameSection(const void *EHFrameSectionAddr, + size_t EHFrameSectionSize); + +/// Supports registration/deregistration of EH-frames in a target process. +class EHFrameRegistrar { +public: + virtual ~EHFrameRegistrar(); + virtual Error registerEHFrames(JITTargetAddress EHFrameSectionAddr, + size_t EHFrameSectionSize) = 0; + virtual Error deregisterEHFrames(JITTargetAddress EHFrameSectionAddr, + size_t EHFrameSectionSize) = 0; +}; + +/// Registers / Deregisters EH-frames in the current process. 
+class InProcessEHFrameRegistrar final : public EHFrameRegistrar { +public: + /// Get a reference to the InProcessEHFrameRegistrar singleton. + static InProcessEHFrameRegistrar &getInstance(); + + InProcessEHFrameRegistrar(const InProcessEHFrameRegistrar &) = delete; + InProcessEHFrameRegistrar & + operator=(const InProcessEHFrameRegistrar &) = delete; + + InProcessEHFrameRegistrar(InProcessEHFrameRegistrar &&) = delete; + InProcessEHFrameRegistrar &operator=(InProcessEHFrameRegistrar &&) = delete; + + Error registerEHFrames(JITTargetAddress EHFrameSectionAddr, + size_t EHFrameSectionSize) override { + return registerEHFrameSection( + jitTargetAddressToPointer<void *>(EHFrameSectionAddr), + EHFrameSectionSize); + } + + Error deregisterEHFrames(JITTargetAddress EHFrameSectionAddr, + size_t EHFrameSectionSize) override { + return deregisterEHFrameSection( + jitTargetAddressToPointer<void *>(EHFrameSectionAddr), + EHFrameSectionSize); + } + +private: + InProcessEHFrameRegistrar(); +}; + +using StoreFrameRangeFunction = + std::function<void(JITTargetAddress EHFrameSectionAddr, + size_t EHFrameSectionSize)>; + +/// Creates a pass that records the address and size of the EH frame section. +/// If no eh-frame section is found then the address and size will both be given +/// as zero. +/// +/// Authors of JITLinkContexts can use this function to register a post-fixup +/// pass that records the range of the eh-frame section. This range can +/// be used after finalization to register and deregister the frame. +LinkGraphPassFunction +createEHFrameRecorderPass(const Triple &TT, + StoreFrameRangeFunction StoreFrameRange); + +} // end namespace jitlink +} // end namespace llvm + +#endif // LLVM_EXECUTIONENGINE_JITLINK_EHFRAMESUPPORT_H diff --git a/llvm/include/llvm/ExecutionEngine/JITLink/JITLink.h b/llvm/include/llvm/ExecutionEngine/JITLink/JITLink.h new file mode 100644 index 000000000000..b531127cf892 --- /dev/null +++ b/llvm/include/llvm/ExecutionEngine/JITLink/JITLink.h @@ -0,0 +1,1044 @@ +//===------------ JITLink.h - JIT linker functionality ----------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Contains generic JIT-linker types. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_EXECUTIONENGINE_JITLINK_JITLINK_H +#define LLVM_EXECUTIONENGINE_JITLINK_JITLINK_H + +#include "JITLinkMemoryManager.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/Triple.h" +#include "llvm/ExecutionEngine/JITSymbol.h" +#include "llvm/Support/Allocator.h" +#include "llvm/Support/Endian.h" +#include "llvm/Support/Error.h" +#include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/MathExtras.h" +#include "llvm/Support/Memory.h" +#include "llvm/Support/MemoryBuffer.h" + +#include <map> +#include <string> +#include <system_error> + +namespace llvm { +namespace jitlink { + +class Symbol; +class Section; + +/// Base class for errors originating in JIT linker, e.g. missing relocation +/// support. 
+class JITLinkError : public ErrorInfo<JITLinkError> { +public: + static char ID; + + JITLinkError(Twine ErrMsg) : ErrMsg(ErrMsg.str()) {} + + void log(raw_ostream &OS) const override; + const std::string &getErrorMessage() const { return ErrMsg; } + std::error_code convertToErrorCode() const override; + +private: + std::string ErrMsg; +}; + +/// Represents fixups and constraints in the LinkGraph. +class Edge { +public: + using Kind = uint8_t; + + enum GenericEdgeKind : Kind { + Invalid, // Invalid edge value. + FirstKeepAlive, // Keeps target alive. Offset/addend zero. + KeepAlive = FirstKeepAlive, // Tag first edge kind that preserves liveness. + FirstRelocation // First architecture specific relocation. + }; + + using OffsetT = uint32_t; + using AddendT = int64_t; + + Edge(Kind K, OffsetT Offset, Symbol &Target, AddendT Addend) + : Target(&Target), Offset(Offset), Addend(Addend), K(K) {} + + OffsetT getOffset() const { return Offset; } + Kind getKind() const { return K; } + void setKind(Kind K) { this->K = K; } + bool isRelocation() const { return K >= FirstRelocation; } + Kind getRelocation() const { + assert(isRelocation() && "Not a relocation edge"); + return K - FirstRelocation; + } + bool isKeepAlive() const { return K >= FirstKeepAlive; } + Symbol &getTarget() const { return *Target; } + void setTarget(Symbol &Target) { this->Target = &Target; } + AddendT getAddend() const { return Addend; } + void setAddend(AddendT Addend) { this->Addend = Addend; } + +private: + Symbol *Target = nullptr; + OffsetT Offset = 0; + AddendT Addend = 0; + Kind K = 0; +}; + +/// Returns the string name of the given generic edge kind, or "unknown" +/// otherwise. Useful for debugging. +const char *getGenericEdgeKindName(Edge::Kind K); + +/// Base class for Addressable entities (externals, absolutes, blocks). +class Addressable { + friend class LinkGraph; + +protected: + Addressable(JITTargetAddress Address, bool IsDefined) + : Address(Address), IsDefined(IsDefined), IsAbsolute(false) {} + + Addressable(JITTargetAddress Address) + : Address(Address), IsDefined(false), IsAbsolute(true) { + assert(!(IsDefined && IsAbsolute) && + "Block cannot be both defined and absolute"); + } + +public: + Addressable(const Addressable &) = delete; + Addressable &operator=(const Addressable &) = default; + Addressable(Addressable &&) = delete; + Addressable &operator=(Addressable &&) = default; + + JITTargetAddress getAddress() const { return Address; } + void setAddress(JITTargetAddress Address) { this->Address = Address; } + + /// Returns true if this is a defined addressable, in which case you + /// can downcast this to a . + bool isDefined() const { return static_cast<bool>(IsDefined); } + bool isAbsolute() const { return static_cast<bool>(IsAbsolute); } + +private: + JITTargetAddress Address = 0; + uint64_t IsDefined : 1; + uint64_t IsAbsolute : 1; +}; + +using BlockOrdinal = unsigned; +using SectionOrdinal = unsigned; + +/// An Addressable with content and edges. +class Block : public Addressable { + friend class LinkGraph; + +private: + /// Create a zero-fill defined addressable. 
+ Block(Section &Parent, BlockOrdinal Ordinal, JITTargetAddress Size, + JITTargetAddress Address, uint64_t Alignment, uint64_t AlignmentOffset) + : Addressable(Address, true), Parent(Parent), Size(Size), + Ordinal(Ordinal) { + assert(isPowerOf2_64(Alignment) && "Alignment must be power of 2"); + assert(AlignmentOffset < Alignment && + "Alignment offset cannot exceed alignment"); + assert(AlignmentOffset <= MaxAlignmentOffset && + "Alignment offset exceeds maximum"); + P2Align = Alignment ? countTrailingZeros(Alignment) : 0; + this->AlignmentOffset = AlignmentOffset; + } + + /// Create a defined addressable for the given content. + Block(Section &Parent, BlockOrdinal Ordinal, StringRef Content, + JITTargetAddress Address, uint64_t Alignment, uint64_t AlignmentOffset) + : Addressable(Address, true), Parent(Parent), Data(Content.data()), + Size(Content.size()), Ordinal(Ordinal) { + assert(isPowerOf2_64(Alignment) && "Alignment must be power of 2"); + assert(AlignmentOffset < Alignment && + "Alignment offset cannot exceed alignment"); + assert(AlignmentOffset <= MaxAlignmentOffset && + "Alignment offset exceeds maximum"); + P2Align = Alignment ? countTrailingZeros(Alignment) : 0; + this->AlignmentOffset = AlignmentOffset; + } + +public: + using EdgeVector = std::vector<Edge>; + using edge_iterator = EdgeVector::iterator; + using const_edge_iterator = EdgeVector::const_iterator; + + Block(const Block &) = delete; + Block &operator=(const Block &) = delete; + Block(Block &&) = delete; + Block &operator=(Block &&) = delete; + + /// Return the parent section for this block. + Section &getSection() const { return Parent; } + + /// Return the ordinal for this block. + BlockOrdinal getOrdinal() const { return Ordinal; } + + /// Returns true if this is a zero-fill block. + /// + /// If true, getSize is callable but getContent is not (the content is + /// defined to be a sequence of zero bytes of length Size). + bool isZeroFill() const { return !Data; } + + /// Returns the size of this defined addressable. + size_t getSize() const { return Size; } + + /// Get the content for this block. Block must not be a zero-fill block. + StringRef getContent() const { + assert(Data && "Section does not contain content"); + return StringRef(Data, Size); + } + + /// Set the content for this block. + /// Caller is responsible for ensuring the underlying bytes are not + /// deallocated while pointed to by this block. + void setContent(StringRef Content) { + Data = Content.data(); + Size = Content.size(); + } + + /// Get the alignment for this content. + uint64_t getAlignment() const { return 1ull << P2Align; } + + /// Get the alignment offset for this content. + uint64_t getAlignmentOffset() const { return AlignmentOffset; } + + /// Add an edge to this block. + void addEdge(Edge::Kind K, Edge::OffsetT Offset, Symbol &Target, + Edge::AddendT Addend) { + Edges.push_back(Edge(K, Offset, Target, Addend)); + } + + /// Return the list of edges attached to this content. + iterator_range<edge_iterator> edges() { + return make_range(Edges.begin(), Edges.end()); + } + + /// Returns the list of edges attached to this content. + iterator_range<const_edge_iterator> edges() const { + return make_range(Edges.begin(), Edges.end()); + } + + /// Return the size of the edges list. + size_t edges_size() const { return Edges.size(); } + + /// Returns true if the list of edges is empty. 
+ bool edges_empty() const { return Edges.empty(); } + +private: + static constexpr uint64_t MaxAlignmentOffset = (1ULL << 57) - 1; + + uint64_t P2Align : 5; + uint64_t AlignmentOffset : 57; + Section &Parent; + const char *Data = nullptr; + size_t Size = 0; + BlockOrdinal Ordinal = 0; + std::vector<Edge> Edges; +}; + +/// Describes symbol linkage. This can be used to make resolve definition +/// clashes. +enum class Linkage : uint8_t { + Strong, + Weak, +}; + +/// For errors and debugging output. +const char *getLinkageName(Linkage L); + +/// Defines the scope in which this symbol should be visible: +/// Default -- Visible in the public interface of the linkage unit. +/// Hidden -- Visible within the linkage unit, but not exported from it. +/// Local -- Visible only within the LinkGraph. +enum class Scope : uint8_t { Default, Hidden, Local }; + +/// For debugging output. +const char *getScopeName(Scope S); + +raw_ostream &operator<<(raw_ostream &OS, const Block &B); + +/// Symbol representation. +/// +/// Symbols represent locations within Addressable objects. +/// They can be either Named or Anonymous. +/// Anonymous symbols have neither linkage nor visibility, and must point at +/// ContentBlocks. +/// Named symbols may be in one of four states: +/// - Null: Default initialized. Assignable, but otherwise unusable. +/// - Defined: Has both linkage and visibility and points to a ContentBlock +/// - Common: Has both linkage and visibility, points to a null Addressable. +/// - External: Has neither linkage nor visibility, points to an external +/// Addressable. +/// +class Symbol { + friend class LinkGraph; + +private: + Symbol(Addressable &Base, JITTargetAddress Offset, StringRef Name, + JITTargetAddress Size, Linkage L, Scope S, bool IsLive, + bool IsCallable) + : Name(Name), Base(&Base), Offset(Offset), Size(Size) { + setLinkage(L); + setScope(S); + setLive(IsLive); + setCallable(IsCallable); + } + + static Symbol &constructCommon(void *SymStorage, Block &Base, StringRef Name, + JITTargetAddress Size, Scope S, bool IsLive) { + assert(SymStorage && "Storage cannot be null"); + assert(!Name.empty() && "Common symbol name cannot be empty"); + assert(Base.isDefined() && + "Cannot create common symbol from undefined block"); + assert(static_cast<Block &>(Base).getSize() == Size && + "Common symbol size should match underlying block size"); + auto *Sym = reinterpret_cast<Symbol *>(SymStorage); + new (Sym) Symbol(Base, 0, Name, Size, Linkage::Weak, S, IsLive, false); + return *Sym; + } + + static Symbol &constructExternal(void *SymStorage, Addressable &Base, + StringRef Name, JITTargetAddress Size) { + assert(SymStorage && "Storage cannot be null"); + assert(!Base.isDefined() && + "Cannot create external symbol from defined block"); + assert(!Name.empty() && "External symbol name cannot be empty"); + auto *Sym = reinterpret_cast<Symbol *>(SymStorage); + new (Sym) Symbol(Base, 0, Name, Size, Linkage::Strong, Scope::Default, + false, false); + return *Sym; + } + + static Symbol &constructAbsolute(void *SymStorage, Addressable &Base, + StringRef Name, JITTargetAddress Size, + Linkage L, Scope S, bool IsLive) { + assert(SymStorage && "Storage cannot be null"); + assert(!Base.isDefined() && + "Cannot create absolute symbol from a defined block"); + auto *Sym = reinterpret_cast<Symbol *>(SymStorage); + new (Sym) Symbol(Base, 0, Name, Size, L, S, IsLive, false); + return *Sym; + } + + static Symbol &constructAnonDef(void *SymStorage, Block &Base, + JITTargetAddress Offset, + JITTargetAddress Size, 
bool IsCallable, + bool IsLive) { + assert(SymStorage && "Storage cannot be null"); + auto *Sym = reinterpret_cast<Symbol *>(SymStorage); + new (Sym) Symbol(Base, Offset, StringRef(), Size, Linkage::Strong, + Scope::Local, IsLive, IsCallable); + return *Sym; + } + + static Symbol &constructNamedDef(void *SymStorage, Block &Base, + JITTargetAddress Offset, StringRef Name, + JITTargetAddress Size, Linkage L, Scope S, + bool IsLive, bool IsCallable) { + assert(SymStorage && "Storage cannot be null"); + assert(!Name.empty() && "Name cannot be empty"); + auto *Sym = reinterpret_cast<Symbol *>(SymStorage); + new (Sym) Symbol(Base, Offset, Name, Size, L, S, IsLive, IsCallable); + return *Sym; + } + +public: + /// Create a null Symbol. This allows Symbols to be default initialized for + /// use in containers (e.g. as map values). Null symbols are only useful for + /// assigning to. + Symbol() = default; + + // Symbols are not movable or copyable. + Symbol(const Symbol &) = delete; + Symbol &operator=(const Symbol &) = delete; + Symbol(Symbol &&) = delete; + Symbol &operator=(Symbol &&) = delete; + + /// Returns true if this symbol has a name. + bool hasName() const { return !Name.empty(); } + + /// Returns the name of this symbol (empty if the symbol is anonymous). + StringRef getName() const { + assert((!Name.empty() || getScope() == Scope::Local) && + "Anonymous symbol has non-local scope"); + return Name; + } + + /// Returns true if this Symbol has content (potentially) defined within this + /// object file (i.e. is anything but an external or absolute symbol). + bool isDefined() const { + assert(Base && "Attempt to access null symbol"); + return Base->isDefined(); + } + + /// Returns true if this symbol is live (i.e. should be treated as a root for + /// dead stripping). + bool isLive() const { + assert(Base && "Attempting to access null symbol"); + return IsLive; + } + + /// Set this symbol's live bit. + void setLive(bool IsLive) { this->IsLive = IsLive; } + + /// Returns true is this symbol is callable. + bool isCallable() const { return IsCallable; } + + /// Set this symbol's callable bit. + void setCallable(bool IsCallable) { this->IsCallable = IsCallable; } + + /// Returns true if the underlying addressable is an unresolved external. + bool isExternal() const { + assert(Base && "Attempt to access null symbol"); + return !Base->isDefined() && !Base->isAbsolute(); + } + + /// Returns true if the underlying addressable is an absolute symbol. + bool isAbsolute() const { + assert(Base && "Attempt to access null symbol"); + return !Base->isDefined() && Base->isAbsolute(); + } + + /// Return the addressable that this symbol points to. + Addressable &getAddressable() { + assert(Base && "Cannot get underlying addressable for null symbol"); + return *Base; + } + + /// Return the addressable that thsi symbol points to. + const Addressable &getAddressable() const { + assert(Base && "Cannot get underlying addressable for null symbol"); + return *Base; + } + + /// Return the Block for this Symbol (Symbol must be defined). + Block &getBlock() { + assert(Base && "Cannot get block for null symbol"); + assert(Base->isDefined() && "Not a defined symbol"); + return static_cast<Block &>(*Base); + } + + /// Return the Block for this Symbol (Symbol must be defined). 
+ const Block &getBlock() const { + assert(Base && "Cannot get block for null symbol"); + assert(Base->isDefined() && "Not a defined symbol"); + return static_cast<const Block &>(*Base); + } + + /// Returns the offset for this symbol within the underlying addressable. + JITTargetAddress getOffset() const { return Offset; } + + /// Returns the address of this symbol. + JITTargetAddress getAddress() const { return Base->getAddress() + Offset; } + + /// Returns the size of this symbol. + JITTargetAddress getSize() const { return Size; } + + /// Returns true if this symbol is backed by a zero-fill block. + /// This method may only be called on defined symbols. + bool isSymbolZeroFill() const { return getBlock().isZeroFill(); } + + /// Returns the content in the underlying block covered by this symbol. + /// This method may only be called on defined non-zero-fill symbols. + StringRef getSymbolContent() const { + return getBlock().getContent().substr(Offset, Size); + } + + /// Get the linkage for this Symbol. + Linkage getLinkage() const { return static_cast<Linkage>(L); } + + /// Set the linkage for this Symbol. + void setLinkage(Linkage L) { + assert((L == Linkage::Strong || (Base->isDefined() && !Name.empty())) && + "Linkage can only be applied to defined named symbols"); + this->L = static_cast<uint8_t>(L); + } + + /// Get the visibility for this Symbol. + Scope getScope() const { return static_cast<Scope>(S); } + + /// Set the visibility for this Symbol. + void setScope(Scope S) { + assert((S == Scope::Default || Base->isDefined() || Base->isAbsolute()) && + "Invalid visibility for symbol type"); + this->S = static_cast<uint8_t>(S); + } + +private: + void makeExternal(Addressable &A) { + assert(!A.isDefined() && "Attempting to make external with defined block"); + Base = &A; + Offset = 0; + setLinkage(Linkage::Strong); + setScope(Scope::Default); + IsLive = 0; + // note: Size and IsCallable fields left unchanged. + } + + static constexpr uint64_t MaxOffset = (1ULL << 59) - 1; + + // FIXME: A char* or SymbolStringPtr may pack better. + StringRef Name; + Addressable *Base = nullptr; + uint64_t Offset : 59; + uint64_t L : 1; + uint64_t S : 2; + uint64_t IsLive : 1; + uint64_t IsCallable : 1; + JITTargetAddress Size = 0; +}; + +raw_ostream &operator<<(raw_ostream &OS, const Symbol &A); + +void printEdge(raw_ostream &OS, const Block &B, const Edge &E, + StringRef EdgeKindName); + +/// Represents an object file section. +class Section { + friend class LinkGraph; + +private: + Section(StringRef Name, sys::Memory::ProtectionFlags Prot, + SectionOrdinal SecOrdinal) + : Name(Name), Prot(Prot), SecOrdinal(SecOrdinal) {} + + using SymbolSet = DenseSet<Symbol *>; + using BlockSet = DenseSet<Block *>; + +public: + using symbol_iterator = SymbolSet::iterator; + using const_symbol_iterator = SymbolSet::const_iterator; + + using block_iterator = BlockSet::iterator; + using const_block_iterator = BlockSet::const_iterator; + + ~Section(); + + /// Returns the name of this section. + StringRef getName() const { return Name; } + + /// Returns the protection flags for this section. + sys::Memory::ProtectionFlags getProtectionFlags() const { return Prot; } + + /// Returns the ordinal for this section. + SectionOrdinal getOrdinal() const { return SecOrdinal; } + + /// Returns an iterator over the symbols defined in this section. + iterator_range<symbol_iterator> symbols() { + return make_range(Symbols.begin(), Symbols.end()); + } + + /// Returns an iterator over the symbols defined in this section. 
+ iterator_range<const_symbol_iterator> symbols() const { + return make_range(Symbols.begin(), Symbols.end()); + } + + /// Return the number of symbols in this section. + SymbolSet::size_type symbols_size() { return Symbols.size(); } + + /// Return true if this section contains no symbols. + bool symbols_empty() const { return Symbols.empty(); } + + /// Returns the ordinal for the next block. + BlockOrdinal getNextBlockOrdinal() { return NextBlockOrdinal++; } + +private: + void addSymbol(Symbol &Sym) { + assert(!Symbols.count(&Sym) && "Symbol is already in this section"); + Symbols.insert(&Sym); + } + + void removeSymbol(Symbol &Sym) { + assert(Symbols.count(&Sym) && "symbol is not in this section"); + Symbols.erase(&Sym); + } + + StringRef Name; + sys::Memory::ProtectionFlags Prot; + SectionOrdinal SecOrdinal = 0; + BlockOrdinal NextBlockOrdinal = 0; + SymbolSet Symbols; +}; + +/// Represents a section address range via a pair of Block pointers +/// to the first and last Blocks in the section. +class SectionRange { +public: + SectionRange() = default; + SectionRange(const Section &Sec) { + if (Sec.symbols_empty()) + return; + First = Last = *Sec.symbols().begin(); + for (auto *Sym : Sec.symbols()) { + if (Sym->getAddress() < First->getAddress()) + First = Sym; + if (Sym->getAddress() > Last->getAddress()) + Last = Sym; + } + } + Symbol *getFirstSymbol() const { + assert((!Last || First) && "First can not be null if end is non-null"); + return First; + } + Symbol *getLastSymbol() const { + assert((First || !Last) && "Last can not be null if start is non-null"); + return Last; + } + bool isEmpty() const { + assert((First || !Last) && "Last can not be null if start is non-null"); + return !First; + } + JITTargetAddress getStart() const { + return First ? First->getBlock().getAddress() : 0; + } + JITTargetAddress getEnd() const { + return Last ? Last->getBlock().getAddress() + Last->getBlock().getSize() + : 0; + } + uint64_t getSize() const { return getEnd() - getStart(); } + +private: + Symbol *First = nullptr; + Symbol *Last = nullptr; +}; + +class LinkGraph { +private: + using SectionList = std::vector<std::unique_ptr<Section>>; + using ExternalSymbolSet = DenseSet<Symbol *>; + using BlockSet = DenseSet<Block *>; + + template <typename... ArgTs> + Addressable &createAddressable(ArgTs &&... Args) { + Addressable *A = + reinterpret_cast<Addressable *>(Allocator.Allocate<Addressable>()); + new (A) Addressable(std::forward<ArgTs>(Args)...); + return *A; + } + + void destroyAddressable(Addressable &A) { + A.~Addressable(); + Allocator.Deallocate(&A); + } + + template <typename... ArgTs> Block &createBlock(ArgTs &&... 
Args) { + Block *B = reinterpret_cast<Block *>(Allocator.Allocate<Block>()); + new (B) Block(std::forward<ArgTs>(Args)...); + Blocks.insert(B); + return *B; + } + + void destroyBlock(Block &B) { + Blocks.erase(&B); + B.~Block(); + Allocator.Deallocate(&B); + } + + void destroySymbol(Symbol &S) { + S.~Symbol(); + Allocator.Deallocate(&S); + } + +public: + using external_symbol_iterator = ExternalSymbolSet::iterator; + + using block_iterator = BlockSet::iterator; + + using section_iterator = pointee_iterator<SectionList::iterator>; + using const_section_iterator = pointee_iterator<SectionList::const_iterator>; + + template <typename SectionItrT, typename SymbolItrT, typename T> + class defined_symbol_iterator_impl + : public iterator_facade_base< + defined_symbol_iterator_impl<SectionItrT, SymbolItrT, T>, + std::forward_iterator_tag, T> { + public: + defined_symbol_iterator_impl() = default; + + defined_symbol_iterator_impl(SectionItrT SecI, SectionItrT SecE) + : SecI(SecI), SecE(SecE), + SymI(SecI != SecE ? SecI->symbols().begin() : SymbolItrT()) { + moveToNextSymbolOrEnd(); + } + + bool operator==(const defined_symbol_iterator_impl &RHS) const { + return (SecI == RHS.SecI) && (SymI == RHS.SymI); + } + + T operator*() const { + assert(SymI != SecI->symbols().end() && "Dereferencing end?"); + return *SymI; + } + + defined_symbol_iterator_impl operator++() { + ++SymI; + moveToNextSymbolOrEnd(); + return *this; + } + + private: + void moveToNextSymbolOrEnd() { + while (SecI != SecE && SymI == SecI->symbols().end()) { + ++SecI; + SymI = SecI == SecE ? SymbolItrT() : SecI->symbols().begin(); + } + } + + SectionItrT SecI, SecE; + SymbolItrT SymI; + }; + + using defined_symbol_iterator = + defined_symbol_iterator_impl<const_section_iterator, + Section::symbol_iterator, Symbol *>; + + using const_defined_symbol_iterator = defined_symbol_iterator_impl< + const_section_iterator, Section::const_symbol_iterator, const Symbol *>; + + LinkGraph(std::string Name, unsigned PointerSize, + support::endianness Endianness) + : Name(std::move(Name)), PointerSize(PointerSize), + Endianness(Endianness) {} + + ~LinkGraph(); + + /// Returns the name of this graph (usually the name of the original + /// underlying MemoryBuffer). + const std::string &getName() { return Name; } + + /// Returns the pointer size for use in this graph. + unsigned getPointerSize() const { return PointerSize; } + + /// Returns the endianness of content in this graph. + support::endianness getEndianness() const { return Endianness; } + + /// Create a section with the given name, protection flags, and alignment. + Section &createSection(StringRef Name, sys::Memory::ProtectionFlags Prot) { + std::unique_ptr<Section> Sec(new Section(Name, Prot, Sections.size())); + Sections.push_back(std::move(Sec)); + return *Sections.back(); + } + + /// Create a content block. + Block &createContentBlock(Section &Parent, StringRef Content, + uint64_t Address, uint64_t Alignment, + uint64_t AlignmentOffset) { + return createBlock(Parent, Parent.getNextBlockOrdinal(), Content, Address, + Alignment, AlignmentOffset); + } + + /// Create a zero-fill block. + Block &createZeroFillBlock(Section &Parent, uint64_t Size, uint64_t Address, + uint64_t Alignment, uint64_t AlignmentOffset) { + return createBlock(Parent, Parent.getNextBlockOrdinal(), Size, Address, + Alignment, AlignmentOffset); + } + + /// Add an external symbol. + /// Some formats (e.g. ELF) allow Symbols to have sizes. For Symbols whose + /// size is not known, you should substitute '0'. 
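+ ///
+ /// For example (an illustrative sketch; G is assumed to be a LinkGraph and
+ /// the symbol name is arbitrary):
+ /// \code
+ ///   // Declare a reference to a symbol defined outside this graph. The
+ ///   // size is unknown here, so 0 is passed.
+ ///   Symbol &Printf = G.addExternalSymbol("_printf", 0);
+ /// \endcode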
+ Symbol &addExternalSymbol(StringRef Name, uint64_t Size) { + auto &Sym = Symbol::constructExternal( + Allocator.Allocate<Symbol>(), createAddressable(0, false), Name, Size); + ExternalSymbols.insert(&Sym); + return Sym; + } + + /// Add an absolute symbol. + Symbol &addAbsoluteSymbol(StringRef Name, JITTargetAddress Address, + uint64_t Size, Linkage L, Scope S, bool IsLive) { + auto &Sym = Symbol::constructAbsolute(Allocator.Allocate<Symbol>(), + createAddressable(Address), Name, + Size, L, S, IsLive); + AbsoluteSymbols.insert(&Sym); + return Sym; + } + + /// Convenience method for adding a weak zero-fill symbol. + Symbol &addCommonSymbol(StringRef Name, Scope S, Section &Section, + JITTargetAddress Address, uint64_t Size, + uint64_t Alignment, bool IsLive) { + auto &Sym = Symbol::constructCommon( + Allocator.Allocate<Symbol>(), + createBlock(Section, Section.getNextBlockOrdinal(), Address, Size, + Alignment, 0), + Name, Size, S, IsLive); + Section.addSymbol(Sym); + return Sym; + } + + /// Add an anonymous symbol. + Symbol &addAnonymousSymbol(Block &Content, JITTargetAddress Offset, + JITTargetAddress Size, bool IsCallable, + bool IsLive) { + auto &Sym = Symbol::constructAnonDef(Allocator.Allocate<Symbol>(), Content, + Offset, Size, IsCallable, IsLive); + Content.getSection().addSymbol(Sym); + return Sym; + } + + /// Add a named symbol. + Symbol &addDefinedSymbol(Block &Content, JITTargetAddress Offset, + StringRef Name, JITTargetAddress Size, Linkage L, + Scope S, bool IsCallable, bool IsLive) { + auto &Sym = + Symbol::constructNamedDef(Allocator.Allocate<Symbol>(), Content, Offset, + Name, Size, L, S, IsLive, IsCallable); + Content.getSection().addSymbol(Sym); + return Sym; + } + + iterator_range<section_iterator> sections() { + return make_range(section_iterator(Sections.begin()), + section_iterator(Sections.end())); + } + + /// Returns the section with the given name if it exists, otherwise returns + /// null. + Section *findSectionByName(StringRef Name) { + for (auto &S : sections()) + if (S.getName() == Name) + return &S; + return nullptr; + } + + iterator_range<external_symbol_iterator> external_symbols() { + return make_range(ExternalSymbols.begin(), ExternalSymbols.end()); + } + + iterator_range<external_symbol_iterator> absolute_symbols() { + return make_range(AbsoluteSymbols.begin(), AbsoluteSymbols.end()); + } + + iterator_range<defined_symbol_iterator> defined_symbols() { + return make_range(defined_symbol_iterator(Sections.begin(), Sections.end()), + defined_symbol_iterator(Sections.end(), Sections.end())); + } + + iterator_range<const_defined_symbol_iterator> defined_symbols() const { + return make_range( + const_defined_symbol_iterator(Sections.begin(), Sections.end()), + const_defined_symbol_iterator(Sections.end(), Sections.end())); + } + + iterator_range<block_iterator> blocks() { + return make_range(Blocks.begin(), Blocks.end()); + } + + /// Turn a defined symbol into an external one. + void makeExternal(Symbol &Sym) { + if (Sym.getAddressable().isAbsolute()) { + assert(AbsoluteSymbols.count(&Sym) && + "Sym is not in the absolute symbols set"); + AbsoluteSymbols.erase(&Sym); + } else { + assert(Sym.isDefined() && "Sym is not a defined symbol"); + Section &Sec = Sym.getBlock().getSection(); + Sec.removeSymbol(Sym); + } + Sym.makeExternal(createAddressable(false)); + ExternalSymbols.insert(&Sym); + } + + /// Removes an external symbol. Also removes the underlying Addressable. 
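+ ///
+ /// A minimal sketch (G is assumed to be a LinkGraph, and Sym to have been
+ /// created via addExternalSymbol on G):
+ /// \code
+ ///   G.removeExternalSymbol(Sym); // Both Sym and its Addressable are freed.
+ /// \endcode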
+ void removeExternalSymbol(Symbol &Sym) { + assert(!Sym.isDefined() && !Sym.isAbsolute() && + "Sym is not an external symbol"); + assert(ExternalSymbols.count(&Sym) && "Symbol is not in the externals set"); + ExternalSymbols.erase(&Sym); + Addressable &Base = *Sym.Base; + destroySymbol(Sym); + destroyAddressable(Base); + } + + /// Remove an absolute symbol. Also removes the underlying Addressable. + void removeAbsoluteSymbol(Symbol &Sym) { + assert(!Sym.isDefined() && Sym.isAbsolute() && + "Sym is not an absolute symbol"); + assert(AbsoluteSymbols.count(&Sym) && + "Symbol is not in the absolute symbols set"); + AbsoluteSymbols.erase(&Sym); + Addressable &Base = *Sym.Base; + destroySymbol(Sym); + destroyAddressable(Base); + } + + /// Removes defined symbols. Does not remove the underlying block. + void removeDefinedSymbol(Symbol &Sym) { + assert(Sym.isDefined() && "Sym is not a defined symbol"); + Sym.getBlock().getSection().removeSymbol(Sym); + destroySymbol(Sym); + } + + /// Remove a block. + void removeBlock(Block &B) { + Blocks.erase(&B); + destroyBlock(B); + } + + /// Dump the graph. + /// + /// If supplied, the EdgeKindToName function will be used to name edge + /// kinds in the debug output. Otherwise raw edge kind numbers will be + /// displayed. + void dump(raw_ostream &OS, + std::function<StringRef(Edge::Kind)> EdegKindToName = + std::function<StringRef(Edge::Kind)>()); + +private: + // Put the BumpPtrAllocator first so that we don't free any of the underlying + // memory until the Symbol/Addressable destructors have been run. + BumpPtrAllocator Allocator; + + std::string Name; + unsigned PointerSize; + support::endianness Endianness; + BlockSet Blocks; + SectionList Sections; + ExternalSymbolSet ExternalSymbols; + ExternalSymbolSet AbsoluteSymbols; +}; + +/// A function for mutating LinkGraphs. +using LinkGraphPassFunction = std::function<Error(LinkGraph &)>; + +/// A list of LinkGraph passes. +using LinkGraphPassList = std::vector<LinkGraphPassFunction>; + +/// An LinkGraph pass configuration, consisting of a list of pre-prune, +/// post-prune, and post-fixup passes. +struct PassConfiguration { + + /// Pre-prune passes. + /// + /// These passes are called on the graph after it is built, and before any + /// symbols have been pruned. + /// + /// Notable use cases: Marking symbols live or should-discard. + LinkGraphPassList PrePrunePasses; + + /// Post-prune passes. + /// + /// These passes are called on the graph after dead stripping, but before + /// fixups are applied. + /// + /// Notable use cases: Building GOT, stub, and TLV symbols. + LinkGraphPassList PostPrunePasses; + + /// Post-fixup passes. + /// + /// These passes are called on the graph after block contents has been copied + /// to working memory, and fixups applied. + /// + /// Notable use cases: Testing and validation. + LinkGraphPassList PostFixupPasses; +}; + +/// A map of symbol names to resolved addresses. +using AsyncLookupResult = DenseMap<StringRef, JITEvaluatedSymbol>; + +/// A function object to call with a resolved symbol map (See AsyncLookupResult) +/// or an error if resolution failed. +class JITLinkAsyncLookupContinuation { +public: + virtual ~JITLinkAsyncLookupContinuation() {} + virtual void run(Expected<AsyncLookupResult> LR) = 0; + +private: + virtual void anchor(); +}; + +/// Create a lookup continuation from a function object. 
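+///
+/// For example (an illustrative sketch):
+/// \code
+///   auto LC = createLookupContinuation(
+///       [](Expected<AsyncLookupResult> R) {
+///         if (!R) {
+///           consumeError(R.takeError()); // Or report the failure.
+///           return;
+///         }
+///         // Addresses for the requested symbols are now available in *R.
+///       });
+/// \endcode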
+template <typename Continuation> +std::unique_ptr<JITLinkAsyncLookupContinuation> +createLookupContinuation(Continuation Cont) { + + class Impl final : public JITLinkAsyncLookupContinuation { + public: + Impl(Continuation C) : C(std::move(C)) {} + void run(Expected<AsyncLookupResult> LR) override { C(std::move(LR)); } + + private: + Continuation C; + }; + + return std::make_unique<Impl>(std::move(Cont)); +} + +/// Holds context for a single jitLink invocation. +class JITLinkContext { +public: + /// Destroy a JITLinkContext. + virtual ~JITLinkContext(); + + /// Return the MemoryManager to be used for this link. + virtual JITLinkMemoryManager &getMemoryManager() = 0; + + /// Returns a StringRef for the object buffer. + /// This method can not be called once takeObjectBuffer has been called. + virtual MemoryBufferRef getObjectBuffer() const = 0; + + /// Notify this context that linking failed. + /// Called by JITLink if linking cannot be completed. + virtual void notifyFailed(Error Err) = 0; + + /// Called by JITLink to resolve external symbols. This method is passed a + /// lookup continutation which it must call with a result to continue the + /// linking process. + virtual void lookup(const DenseSet<StringRef> &Symbols, + std::unique_ptr<JITLinkAsyncLookupContinuation> LC) = 0; + + /// Called by JITLink once all defined symbols in the graph have been assigned + /// their final memory locations in the target process. At this point the + /// LinkGraph can be inspected to build a symbol table, however the block + /// content will not generally have been copied to the target location yet. + virtual void notifyResolved(LinkGraph &G) = 0; + + /// Called by JITLink to notify the context that the object has been + /// finalized (i.e. emitted to memory and memory permissions set). If all of + /// this objects dependencies have also been finalized then the code is ready + /// to run. + virtual void + notifyFinalized(std::unique_ptr<JITLinkMemoryManager::Allocation> A) = 0; + + /// Called by JITLink prior to linking to determine whether default passes for + /// the target should be added. The default implementation returns true. + /// If subclasses override this method to return false for any target then + /// they are required to fully configure the pass pipeline for that target. + virtual bool shouldAddDefaultTargetPasses(const Triple &TT) const; + + /// Returns the mark-live pass to be used for this link. If no pass is + /// returned (the default) then the target-specific linker implementation will + /// choose a conservative default (usually marking all symbols live). + /// This function is only called if shouldAddDefaultTargetPasses returns true, + /// otherwise the JITContext is responsible for adding a mark-live pass in + /// modifyPassConfig. + virtual LinkGraphPassFunction getMarkLivePass(const Triple &TT) const; + + /// Called by JITLink to modify the pass pipeline prior to linking. + /// The default version performs no modification. + virtual Error modifyPassConfig(const Triple &TT, PassConfiguration &Config); +}; + +/// Marks all symbols in a graph live. This can be used as a default, +/// conservative mark-live implementation. +Error markAllSymbolsLive(LinkGraph &G); + +/// Basic JITLink implementation. +/// +/// This function will use sensible defaults for GOT and Stub handling. 
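+///
+/// A minimal sketch (MyJITLinkContext is a hypothetical JITLinkContext
+/// subclass that implements the pure virtual methods above):
+/// \code
+///   jitLink(std::make_unique<MyJITLinkContext>(/*...*/));
+/// \endcode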
+void jitLink(std::unique_ptr<JITLinkContext> Ctx); + +} // end namespace jitlink +} // end namespace llvm + +#endif // LLVM_EXECUTIONENGINE_JITLINK_JITLINK_H diff --git a/llvm/include/llvm/ExecutionEngine/JITLink/JITLinkMemoryManager.h b/llvm/include/llvm/ExecutionEngine/JITLink/JITLinkMemoryManager.h new file mode 100644 index 000000000000..ac5a593bb77b --- /dev/null +++ b/llvm/include/llvm/ExecutionEngine/JITLink/JITLinkMemoryManager.h @@ -0,0 +1,98 @@ +//===-- JITLinkMemoryManager.h - JITLink mem manager interface --*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Contains the JITLinkMemoryManager interface. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_EXECUTIONENGINE_JITLINK_JITLINKMEMORYMANAGER_H +#define LLVM_EXECUTIONENGINE_JITLINK_JITLINKMEMORYMANAGER_H + +#include "llvm/ADT/DenseMap.h" +#include "llvm/ExecutionEngine/JITSymbol.h" +#include "llvm/Support/Error.h" +#include "llvm/Support/Memory.h" +#include <cstdint> + +namespace llvm { +namespace jitlink { + +/// Manages allocations of JIT memory. +/// +/// Instances of this class may be accessed concurrently from multiple threads +/// and their implemetations should include any necessary synchronization. +class JITLinkMemoryManager { +public: + using ProtectionFlags = sys::Memory::ProtectionFlags; + + class SegmentRequest { + public: + SegmentRequest() = default; + SegmentRequest(uint64_t Alignment, size_t ContentSize, + uint64_t ZeroFillSize) + : Alignment(Alignment), ContentSize(ContentSize), + ZeroFillSize(ZeroFillSize) { + assert(isPowerOf2_32(Alignment) && "Alignment must be power of 2"); + } + uint64_t getAlignment() const { return Alignment; } + size_t getContentSize() const { return ContentSize; } + uint64_t getZeroFillSize() const { return ZeroFillSize; } + private: + uint64_t Alignment = 0; + size_t ContentSize = 0; + uint64_t ZeroFillSize = 0; + }; + + using SegmentsRequestMap = DenseMap<unsigned, SegmentRequest>; + + /// Represents an allocation created by the memory manager. + /// + /// An allocation object is responsible for allocating and owning jit-linker + /// working and target memory, and for transfering from working to target + /// memory. + /// + class Allocation { + public: + using FinalizeContinuation = std::function<void(Error)>; + + virtual ~Allocation(); + + /// Should return the address of linker working memory for the segment with + /// the given protection flags. + virtual MutableArrayRef<char> getWorkingMemory(ProtectionFlags Seg) = 0; + + /// Should return the final address in the target process where the segment + /// will reside. + virtual JITTargetAddress getTargetMemory(ProtectionFlags Seg) = 0; + + /// Should transfer from working memory to target memory, and release + /// working memory. + virtual void finalizeAsync(FinalizeContinuation OnFinalize) = 0; + + /// Should deallocate target memory. + virtual Error deallocate() = 0; + }; + + virtual ~JITLinkMemoryManager(); + + /// Create an Allocation object. + virtual Expected<std::unique_ptr<Allocation>> + allocate(const SegmentsRequestMap &Request) = 0; +}; + +/// A JITLinkMemoryManager that allocates in-process memory. 
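+///
+/// For example (an illustrative sketch; the protection flags, alignment and
+/// sizes are arbitrary):
+/// \code
+///   InProcessMemoryManager MemMgr;
+///   JITLinkMemoryManager::SegmentsRequestMap Request;
+///   Request[sys::Memory::MF_READ | sys::Memory::MF_EXEC] =
+///       JITLinkMemoryManager::SegmentRequest(16, 4096, 0);
+///   if (auto Alloc = MemMgr.allocate(Request)) {
+///     // Write code into (*Alloc)->getWorkingMemory(...), then call
+///     // (*Alloc)->finalizeAsync(...) to transfer it to target memory.
+///   } else
+///     consumeError(Alloc.takeError());
+/// \endcode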
+class InProcessMemoryManager : public JITLinkMemoryManager { +public: + Expected<std::unique_ptr<Allocation>> + allocate(const SegmentsRequestMap &Request) override; +}; + +} // end namespace jitlink +} // end namespace llvm + +#endif // LLVM_EXECUTIONENGINE_JITLINK_JITLINK_H diff --git a/llvm/include/llvm/ExecutionEngine/JITLink/MachO.h b/llvm/include/llvm/ExecutionEngine/JITLink/MachO.h new file mode 100644 index 000000000000..7facb657a51c --- /dev/null +++ b/llvm/include/llvm/ExecutionEngine/JITLink/MachO.h @@ -0,0 +1,30 @@ +//===------- MachO.h - Generic JIT link function for MachO ------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Generic jit-link functions for MachO. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_EXECUTIONENGINE_JITLINK_MACHO_H +#define LLVM_EXECUTIONENGINE_JITLINK_MACHO_H + +#include "llvm/ExecutionEngine/JITLink/JITLink.h" + +namespace llvm { +namespace jitlink { + +/// jit-link the given ObjBuffer, which must be a MachO object file. +/// +/// Uses conservative defaults for GOT and stub handling based on the target +/// platform. +void jitLink_MachO(std::unique_ptr<JITLinkContext> Ctx); + +} // end namespace jitlink +} // end namespace llvm + +#endif // LLVM_EXECUTIONENGINE_JITLINK_MACHO_H diff --git a/llvm/include/llvm/ExecutionEngine/JITLink/MachO_arm64.h b/llvm/include/llvm/ExecutionEngine/JITLink/MachO_arm64.h new file mode 100644 index 000000000000..d70b545fff86 --- /dev/null +++ b/llvm/include/llvm/ExecutionEngine/JITLink/MachO_arm64.h @@ -0,0 +1,60 @@ +//===---- MachO_arm64.h - JIT link functions for MachO/arm64 ----*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// jit-link functions for MachO/arm64. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_EXECUTIONENGINE_JITLINK_MACHO_ARM64_H +#define LLVM_EXECUTIONENGINE_JITLINK_MACHO_ARM64_H + +#include "llvm/ExecutionEngine/JITLink/JITLink.h" + +namespace llvm { +namespace jitlink { + +namespace MachO_arm64_Edges { + +enum MachOARM64RelocationKind : Edge::Kind { + Branch26 = Edge::FirstRelocation, + Pointer32, + Pointer64, + Pointer64Anon, + Page21, + PageOffset12, + GOTPage21, + GOTPageOffset12, + PointerToGOT, + PairedAddend, + LDRLiteral19, + Delta32, + Delta64, + NegDelta32, + NegDelta64, +}; + +} // namespace MachO_arm64_Edges + +/// jit-link the given object buffer, which must be a MachO arm64 object file. +/// +/// If PrePrunePasses is empty then a default mark-live pass will be inserted +/// that will mark all exported atoms live. If PrePrunePasses is not empty, the +/// caller is responsible for including a pass to mark atoms as live. +/// +/// If PostPrunePasses is empty then a default GOT-and-stubs insertion pass will +/// be inserted. If PostPrunePasses is not empty then the caller is responsible +/// for including a pass to insert GOT and stub edges. 
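+///
+/// A minimal sketch (Ctx is assumed to be a std::unique_ptr<JITLinkContext>
+/// wrapping a MachO/arm64 object):
+/// \code
+///   jitLink_MachO_arm64(std::move(Ctx));
+/// \endcode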
+void jitLink_MachO_arm64(std::unique_ptr<JITLinkContext> Ctx); + +/// Return the string name of the given MachO arm64 edge kind. +StringRef getMachOARM64RelocationKindName(Edge::Kind R); + +} // end namespace jitlink +} // end namespace llvm + +#endif // LLVM_EXECUTIONENGINE_JITLINK_MACHO_ARM64_H diff --git a/llvm/include/llvm/ExecutionEngine/JITLink/MachO_x86_64.h b/llvm/include/llvm/ExecutionEngine/JITLink/MachO_x86_64.h new file mode 100644 index 000000000000..00a7feb86e83 --- /dev/null +++ b/llvm/include/llvm/ExecutionEngine/JITLink/MachO_x86_64.h @@ -0,0 +1,64 @@ +//===--- MachO_x86_64.h - JIT link functions for MachO/x86-64 ---*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// jit-link functions for MachO/x86-64. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_EXECUTIONENGINE_JITLINK_MACHO_X86_64_H +#define LLVM_EXECUTIONENGINE_JITLINK_MACHO_X86_64_H + +#include "llvm/ExecutionEngine/JITLink/JITLink.h" + +namespace llvm { +namespace jitlink { + +namespace MachO_x86_64_Edges { + +enum MachOX86RelocationKind : Edge::Kind { + Branch32 = Edge::FirstRelocation, + Pointer32, + Pointer64, + Pointer64Anon, + PCRel32, + PCRel32Minus1, + PCRel32Minus2, + PCRel32Minus4, + PCRel32Anon, + PCRel32Minus1Anon, + PCRel32Minus2Anon, + PCRel32Minus4Anon, + PCRel32GOTLoad, + PCRel32GOT, + PCRel32TLV, + Delta32, + Delta64, + NegDelta32, + NegDelta64, +}; + +} // namespace MachO_x86_64_Edges + +/// jit-link the given object buffer, which must be a MachO x86-64 object file. +/// +/// If PrePrunePasses is empty then a default mark-live pass will be inserted +/// that will mark all exported atoms live. If PrePrunePasses is not empty, the +/// caller is responsible for including a pass to mark atoms as live. +/// +/// If PostPrunePasses is empty then a default GOT-and-stubs insertion pass will +/// be inserted. If PostPrunePasses is not empty then the caller is responsible +/// for including a pass to insert GOT and stub edges. +void jitLink_MachO_x86_64(std::unique_ptr<JITLinkContext> Ctx); + +/// Return the string name of the given MachO x86-64 edge kind. +StringRef getMachOX86RelocationKindName(Edge::Kind R); + +} // end namespace jitlink +} // end namespace llvm + +#endif // LLVM_EXECUTIONENGINE_JITLINK_MACHO_X86_64_H diff --git a/llvm/include/llvm/ExecutionEngine/JITSymbol.h b/llvm/include/llvm/ExecutionEngine/JITSymbol.h new file mode 100644 index 000000000000..c0f1ca4b9876 --- /dev/null +++ b/llvm/include/llvm/ExecutionEngine/JITSymbol.h @@ -0,0 +1,391 @@ +//===- JITSymbol.h - JIT symbol abstraction ---------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Abstraction for target process addresses. 
+// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_EXECUTIONENGINE_JITSYMBOL_H +#define LLVM_EXECUTIONENGINE_JITSYMBOL_H + +#include <algorithm> +#include <cassert> +#include <cstddef> +#include <cstdint> +#include <functional> +#include <map> +#include <set> +#include <string> + +#include "llvm/ADT/BitmaskEnum.h" +#include "llvm/ADT/FunctionExtras.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Error.h" + +namespace llvm { + +class GlobalValue; + +namespace object { + +class SymbolRef; + +} // end namespace object + +/// Represents an address in the target process's address space. +using JITTargetAddress = uint64_t; + +/// Convert a JITTargetAddress to a pointer. +template <typename T> T jitTargetAddressToPointer(JITTargetAddress Addr) { + static_assert(std::is_pointer<T>::value, "T must be a pointer type"); + uintptr_t IntPtr = static_cast<uintptr_t>(Addr); + assert(IntPtr == Addr && "JITTargetAddress value out of range for uintptr_t"); + return reinterpret_cast<T>(IntPtr); +} + +template <typename T> JITTargetAddress pointerToJITTargetAddress(T *Ptr) { + return static_cast<JITTargetAddress>(reinterpret_cast<uintptr_t>(Ptr)); +} + +/// Flags for symbols in the JIT. +class JITSymbolFlags { +public: + using UnderlyingType = uint8_t; + using TargetFlagsType = uint8_t; + + enum FlagNames : UnderlyingType { + None = 0, + HasError = 1U << 0, + Weak = 1U << 1, + Common = 1U << 2, + Absolute = 1U << 3, + Exported = 1U << 4, + Callable = 1U << 5, + LLVM_MARK_AS_BITMASK_ENUM(/* LargestValue = */ Callable) + }; + + /// Default-construct a JITSymbolFlags instance. + JITSymbolFlags() = default; + + /// Construct a JITSymbolFlags instance from the given flags. + JITSymbolFlags(FlagNames Flags) : Flags(Flags) {} + + /// Construct a JITSymbolFlags instance from the given flags and target + /// flags. + JITSymbolFlags(FlagNames Flags, TargetFlagsType TargetFlags) + : TargetFlags(TargetFlags), Flags(Flags) {} + + /// Implicitly convert to bool. Returs true if any flag is set. + explicit operator bool() const { return Flags != None || TargetFlags != 0; } + + /// Compare for equality. + bool operator==(const JITSymbolFlags &RHS) const { + return Flags == RHS.Flags && TargetFlags == RHS.TargetFlags; + } + + /// Bitwise AND-assignment for FlagNames. + JITSymbolFlags &operator&=(const FlagNames &RHS) { + Flags &= RHS; + return *this; + } + + /// Bitwise OR-assignment for FlagNames. + JITSymbolFlags &operator|=(const FlagNames &RHS) { + Flags |= RHS; + return *this; + } + + /// Return true if there was an error retrieving this symbol. + bool hasError() const { + return (Flags & HasError) == HasError; + } + + /// Returns true if the Weak flag is set. + bool isWeak() const { + return (Flags & Weak) == Weak; + } + + /// Returns true if the Common flag is set. + bool isCommon() const { + return (Flags & Common) == Common; + } + + /// Returns true if the symbol isn't weak or common. + bool isStrong() const { + return !isWeak() && !isCommon(); + } + + /// Returns true if the Exported flag is set. + bool isExported() const { + return (Flags & Exported) == Exported; + } + + /// Returns true if the given symbol is known to be callable. + bool isCallable() const { return (Flags & Callable) == Callable; } + + /// Get the underlying flags value as an integer. + UnderlyingType getRawFlagsValue() const { + return static_cast<UnderlyingType>(Flags); + } + + /// Return a reference to the target-specific flags. 
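+ ///
+ /// For example (an illustrative sketch; ARMJITSymbolFlags is declared later
+ /// in this header):
+ /// \code
+ ///   JITSymbolFlags Flags(JITSymbolFlags::Exported);
+ ///   Flags.getTargetFlags() = ARMJITSymbolFlags::Thumb;
+ /// \endcode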
+ TargetFlagsType& getTargetFlags() { return TargetFlags; } + + /// Return a reference to the target-specific flags. + const TargetFlagsType& getTargetFlags() const { return TargetFlags; } + + /// Construct a JITSymbolFlags value based on the flags of the given global + /// value. + static JITSymbolFlags fromGlobalValue(const GlobalValue &GV); + + /// Construct a JITSymbolFlags value based on the flags of the given libobject + /// symbol. + static Expected<JITSymbolFlags> + fromObjectSymbol(const object::SymbolRef &Symbol); + +private: + TargetFlagsType TargetFlags = 0; + FlagNames Flags = None; +}; + +inline JITSymbolFlags operator&(const JITSymbolFlags &LHS, + const JITSymbolFlags::FlagNames &RHS) { + JITSymbolFlags Tmp = LHS; + Tmp &= RHS; + return Tmp; +} + +inline JITSymbolFlags operator|(const JITSymbolFlags &LHS, + const JITSymbolFlags::FlagNames &RHS) { + JITSymbolFlags Tmp = LHS; + Tmp |= RHS; + return Tmp; +} + +/// ARM-specific JIT symbol flags. +/// FIXME: This should be moved into a target-specific header. +class ARMJITSymbolFlags { +public: + ARMJITSymbolFlags() = default; + + enum FlagNames { + None = 0, + Thumb = 1 << 0 + }; + + operator JITSymbolFlags::TargetFlagsType&() { return Flags; } + + static ARMJITSymbolFlags fromObjectSymbol(const object::SymbolRef &Symbol); + +private: + JITSymbolFlags::TargetFlagsType Flags = 0; +}; + +/// Represents a symbol that has been evaluated to an address already. +class JITEvaluatedSymbol { +public: + JITEvaluatedSymbol() = default; + + /// Create a 'null' symbol. + JITEvaluatedSymbol(std::nullptr_t) {} + + /// Create a symbol for the given address and flags. + JITEvaluatedSymbol(JITTargetAddress Address, JITSymbolFlags Flags) + : Address(Address), Flags(Flags) {} + + /// An evaluated symbol converts to 'true' if its address is non-zero. + explicit operator bool() const { return Address != 0; } + + /// Return the address of this symbol. + JITTargetAddress getAddress() const { return Address; } + + /// Return the flags for this symbol. + JITSymbolFlags getFlags() const { return Flags; } + + /// Set the flags for this symbol. + void setFlags(JITSymbolFlags Flags) { this->Flags = std::move(Flags); } + +private: + JITTargetAddress Address = 0; + JITSymbolFlags Flags; +}; + +/// Represents a symbol in the JIT. +class JITSymbol { +public: + using GetAddressFtor = unique_function<Expected<JITTargetAddress>()>; + + /// Create a 'null' symbol, used to represent a "symbol not found" + /// result from a successful (non-erroneous) lookup. + JITSymbol(std::nullptr_t) + : CachedAddr(0) {} + + /// Create a JITSymbol representing an error in the symbol lookup + /// process (e.g. a network failure during a remote lookup). + JITSymbol(Error Err) + : Err(std::move(Err)), Flags(JITSymbolFlags::HasError) {} + + /// Create a symbol for a definition with a known address. + JITSymbol(JITTargetAddress Addr, JITSymbolFlags Flags) + : CachedAddr(Addr), Flags(Flags) {} + + /// Construct a JITSymbol from a JITEvaluatedSymbol. + JITSymbol(JITEvaluatedSymbol Sym) + : CachedAddr(Sym.getAddress()), Flags(Sym.getFlags()) {} + + /// Create a symbol for a definition that doesn't have a known address + /// yet. + /// @param GetAddress A functor to materialize a definition (fixing the + /// address) on demand. + /// + /// This constructor allows a JIT layer to provide a reference to a symbol + /// definition without actually materializing the definition up front. The + /// user can materialize the definition at any time by calling the getAddress + /// method. 
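+ ///
+ /// For example (an illustrative sketch; materializeFunction is a
+ /// hypothetical helper that compiles the body and returns its address):
+ /// \code
+ ///   JITSymbol LazySym(
+ ///       []() -> Expected<JITTargetAddress> { return materializeFunction(); },
+ ///       JITSymbolFlags::Exported | JITSymbolFlags::Callable);
+ /// \endcode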
+ JITSymbol(GetAddressFtor GetAddress, JITSymbolFlags Flags) + : GetAddress(std::move(GetAddress)), CachedAddr(0), Flags(Flags) {} + + JITSymbol(const JITSymbol&) = delete; + JITSymbol& operator=(const JITSymbol&) = delete; + + JITSymbol(JITSymbol &&Other) + : GetAddress(std::move(Other.GetAddress)), Flags(std::move(Other.Flags)) { + if (Flags.hasError()) + Err = std::move(Other.Err); + else + CachedAddr = std::move(Other.CachedAddr); + } + + JITSymbol& operator=(JITSymbol &&Other) { + GetAddress = std::move(Other.GetAddress); + Flags = std::move(Other.Flags); + if (Flags.hasError()) + Err = std::move(Other.Err); + else + CachedAddr = std::move(Other.CachedAddr); + return *this; + } + + ~JITSymbol() { + if (Flags.hasError()) + Err.~Error(); + else + CachedAddr.~JITTargetAddress(); + } + + /// Returns true if the symbol exists, false otherwise. + explicit operator bool() const { + return !Flags.hasError() && (CachedAddr || GetAddress); + } + + /// Move the error field value out of this JITSymbol. + Error takeError() { + if (Flags.hasError()) + return std::move(Err); + return Error::success(); + } + + /// Get the address of the symbol in the target address space. Returns + /// '0' if the symbol does not exist. + Expected<JITTargetAddress> getAddress() { + assert(!Flags.hasError() && "getAddress called on error value"); + if (GetAddress) { + if (auto CachedAddrOrErr = GetAddress()) { + GetAddress = nullptr; + CachedAddr = *CachedAddrOrErr; + assert(CachedAddr && "Symbol could not be materialized."); + } else + return CachedAddrOrErr.takeError(); + } + return CachedAddr; + } + + JITSymbolFlags getFlags() const { return Flags; } + +private: + GetAddressFtor GetAddress; + union { + JITTargetAddress CachedAddr; + Error Err; + }; + JITSymbolFlags Flags; +}; + +/// Symbol resolution interface. +/// +/// Allows symbol flags and addresses to be looked up by name. +/// Symbol queries are done in bulk (i.e. you request resolution of a set of +/// symbols, rather than a single one) to reduce IPC overhead in the case of +/// remote JITing, and expose opportunities for parallel compilation. +class JITSymbolResolver { +public: + using LookupSet = std::set<StringRef>; + using LookupResult = std::map<StringRef, JITEvaluatedSymbol>; + using OnResolvedFunction = unique_function<void(Expected<LookupResult>)>; + + virtual ~JITSymbolResolver() = default; + + /// Returns the fully resolved address and flags for each of the given + /// symbols. + /// + /// This method will return an error if any of the given symbols can not be + /// resolved, or if the resolution process itself triggers an error. + virtual void lookup(const LookupSet &Symbols, + OnResolvedFunction OnResolved) = 0; + + /// Returns the subset of the given symbols that should be materialized by + /// the caller. Only weak/common symbols should be looked up, as strong + /// definitions are implicitly always part of the caller's responsibility. + virtual Expected<LookupSet> + getResponsibilitySet(const LookupSet &Symbols) = 0; + +private: + virtual void anchor(); +}; + +/// Legacy symbol resolution interface. +class LegacyJITSymbolResolver : public JITSymbolResolver { +public: + /// Performs lookup by, for each symbol, first calling + /// findSymbolInLogicalDylib and if that fails calling + /// findSymbol. + void lookup(const LookupSet &Symbols, OnResolvedFunction OnResolved) final; + + /// Performs flags lookup by calling findSymbolInLogicalDylib and + /// returning the flags value for that symbol. 
+ Expected<LookupSet> getResponsibilitySet(const LookupSet &Symbols) final; + + /// This method returns the address of the specified symbol if it exists + /// within the logical dynamic library represented by this JITSymbolResolver. + /// Unlike findSymbol, queries through this interface should return addresses + /// for hidden symbols. + /// + /// This is of particular importance for the Orc JIT APIs, which support lazy + /// compilation by breaking up modules: Each of those broken out modules + /// must be able to resolve hidden symbols provided by the others. Clients + /// writing memory managers for MCJIT can usually ignore this method. + /// + /// This method will be queried by RuntimeDyld when checking for previous + /// definitions of common symbols. + virtual JITSymbol findSymbolInLogicalDylib(const std::string &Name) = 0; + + /// This method returns the address of the specified function or variable. + /// It is used to resolve symbols during module linking. + /// + /// If the returned symbol's address is equal to ~0ULL then RuntimeDyld will + /// skip all relocations for that symbol, and the client will be responsible + /// for handling them manually. + virtual JITSymbol findSymbol(const std::string &Name) = 0; + +private: + virtual void anchor(); +}; + +} // end namespace llvm + +#endif // LLVM_EXECUTIONENGINE_JITSYMBOL_H diff --git a/llvm/include/llvm/ExecutionEngine/MCJIT.h b/llvm/include/llvm/ExecutionEngine/MCJIT.h new file mode 100644 index 000000000000..8253bf98963b --- /dev/null +++ b/llvm/include/llvm/ExecutionEngine/MCJIT.h @@ -0,0 +1,37 @@ +//===-- MCJIT.h - MC-Based Just-In-Time Execution Engine --------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file forces the MCJIT to link in on certain operating systems. +// (Windows). +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_EXECUTIONENGINE_MCJIT_H +#define LLVM_EXECUTIONENGINE_MCJIT_H + +#include "llvm/ExecutionEngine/ExecutionEngine.h" +#include <cstdlib> + +extern "C" void LLVMLinkInMCJIT(); + +namespace { + struct ForceMCJITLinking { + ForceMCJITLinking() { + // We must reference MCJIT in such a way that compilers will not + // delete it all as dead code, even with whole program optimization, + // yet is effectively a NO-OP. As the compiler isn't smart enough + // to know that getenv() never returns -1, this will do the job. + if (std::getenv("bar") != (char*) -1) + return; + + LLVMLinkInMCJIT(); + } + } ForceMCJITLinking; +} + +#endif diff --git a/llvm/include/llvm/ExecutionEngine/OProfileWrapper.h b/llvm/include/llvm/ExecutionEngine/OProfileWrapper.h new file mode 100644 index 000000000000..b13d7f6e245b --- /dev/null +++ b/llvm/include/llvm/ExecutionEngine/OProfileWrapper.h @@ -0,0 +1,123 @@ +//===-- OProfileWrapper.h - OProfile JIT API Wrapper ------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// This file defines a OProfileWrapper object that detects if the oprofile +// daemon is running, and provides wrappers for opagent functions used to +// communicate with the oprofile JIT interface. The dynamic library libopagent +// does not need to be linked directly as this object lazily loads the library +// when the first op_ function is called. +// +// See http://oprofile.sourceforge.net/doc/devel/jit-interface.html for the +// definition of the interface. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_EXECUTIONENGINE_OPROFILEWRAPPER_H +#define LLVM_EXECUTIONENGINE_OPROFILEWRAPPER_H + +#include "llvm/Support/DataTypes.h" +#include <opagent.h> + +namespace llvm { + + +class OProfileWrapper { + typedef op_agent_t (*op_open_agent_ptr_t)(); + typedef int (*op_close_agent_ptr_t)(op_agent_t); + typedef int (*op_write_native_code_ptr_t)(op_agent_t, + const char*, + uint64_t, + void const*, + const unsigned int); + typedef int (*op_write_debug_line_info_ptr_t)(op_agent_t, + void const*, + size_t, + struct debug_line_info const*); + typedef int (*op_unload_native_code_ptr_t)(op_agent_t, uint64_t); + + // Also used for op_minor_version function which has the same signature + typedef int (*op_major_version_ptr_t)(); + + // This is not a part of the opagent API, but is useful nonetheless + typedef bool (*IsOProfileRunningPtrT)(); + + + op_agent_t Agent; + op_open_agent_ptr_t OpenAgentFunc; + op_close_agent_ptr_t CloseAgentFunc; + op_write_native_code_ptr_t WriteNativeCodeFunc; + op_write_debug_line_info_ptr_t WriteDebugLineInfoFunc; + op_unload_native_code_ptr_t UnloadNativeCodeFunc; + op_major_version_ptr_t MajorVersionFunc; + op_major_version_ptr_t MinorVersionFunc; + IsOProfileRunningPtrT IsOProfileRunningFunc; + + bool Initialized; + +public: + OProfileWrapper(); + + // For testing with a mock opagent implementation, skips the dynamic load and + // the function resolution. + OProfileWrapper(op_open_agent_ptr_t OpenAgentImpl, + op_close_agent_ptr_t CloseAgentImpl, + op_write_native_code_ptr_t WriteNativeCodeImpl, + op_write_debug_line_info_ptr_t WriteDebugLineInfoImpl, + op_unload_native_code_ptr_t UnloadNativeCodeImpl, + op_major_version_ptr_t MajorVersionImpl, + op_major_version_ptr_t MinorVersionImpl, + IsOProfileRunningPtrT MockIsOProfileRunningImpl = 0) + : OpenAgentFunc(OpenAgentImpl), + CloseAgentFunc(CloseAgentImpl), + WriteNativeCodeFunc(WriteNativeCodeImpl), + WriteDebugLineInfoFunc(WriteDebugLineInfoImpl), + UnloadNativeCodeFunc(UnloadNativeCodeImpl), + MajorVersionFunc(MajorVersionImpl), + MinorVersionFunc(MinorVersionImpl), + IsOProfileRunningFunc(MockIsOProfileRunningImpl), + Initialized(true) + { + } + + // Calls op_open_agent in the oprofile JIT library and saves the returned + // op_agent_t handle internally so it can be used when calling all the other + // op_* functions. Callers of this class do not need to keep track of + // op_agent_t objects. 
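+ //
+ // Typical usage (an illustrative sketch; Addr, Code and Size describe a
+ // jitted function):
+ //   OProfileWrapper OP;
+ //   if (OP.isAgentAvailable() && OP.op_open_agent()) {
+ //     OP.op_write_native_code("my_jitted_fn", Addr, Code, Size);
+ //     OP.op_close_agent();
+ //   }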
+ bool op_open_agent(); + + int op_close_agent(); + int op_write_native_code(const char* name, + uint64_t addr, + void const* code, + const unsigned int size); + int op_write_debug_line_info(void const* code, + size_t num_entries, + struct debug_line_info const* info); + int op_unload_native_code(uint64_t addr); + int op_major_version(); + int op_minor_version(); + + // Returns true if the oprofiled process is running, the opagent library is + // loaded and a connection to the agent has been established, and false + // otherwise. + bool isAgentAvailable(); + +private: + // Loads the libopagent library and initializes this wrapper if the oprofile + // daemon is running + bool initialize(); + + // Searches /proc for the oprofile daemon and returns true if the process if + // found, or false otherwise. + bool checkForOProfileProcEntry(); + + bool isOProfileRunning(); +}; + +} // namespace llvm + +#endif // LLVM_EXECUTIONENGINE_OPROFILEWRAPPER_H diff --git a/llvm/include/llvm/ExecutionEngine/ObjectCache.h b/llvm/include/llvm/ExecutionEngine/ObjectCache.h new file mode 100644 index 000000000000..47e94f18a1c7 --- /dev/null +++ b/llvm/include/llvm/ExecutionEngine/ObjectCache.h @@ -0,0 +1,41 @@ +//===-- ObjectCache.h - Class definition for the ObjectCache ----*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_EXECUTIONENGINE_OBJECTCACHE_H +#define LLVM_EXECUTIONENGINE_OBJECTCACHE_H + +#include "llvm/Support/MemoryBuffer.h" +#include <memory> + +namespace llvm { + +class Module; + +/// This is the base ObjectCache type which can be provided to an +/// ExecutionEngine for the purpose of avoiding compilation for Modules that +/// have already been compiled and an object file is available. +class ObjectCache { + virtual void anchor(); + +public: + ObjectCache() = default; + + virtual ~ObjectCache() = default; + + /// notifyObjectCompiled - Provides a pointer to compiled code for Module M. + virtual void notifyObjectCompiled(const Module *M, MemoryBufferRef Obj) = 0; + + /// Returns a pointer to a newly allocated MemoryBuffer that contains the + /// object which corresponds with Module M, or 0 if an object is not + /// available. + virtual std::unique_ptr<MemoryBuffer> getObject(const Module* M) = 0; +}; + +} // end namespace llvm + +#endif // LLVM_EXECUTIONENGINE_OBJECTCACHE_H diff --git a/llvm/include/llvm/ExecutionEngine/Orc/CompileOnDemandLayer.h b/llvm/include/llvm/ExecutionEngine/Orc/CompileOnDemandLayer.h new file mode 100644 index 000000000000..7946b5b7b209 --- /dev/null +++ b/llvm/include/llvm/ExecutionEngine/Orc/CompileOnDemandLayer.h @@ -0,0 +1,769 @@ +//===- CompileOnDemandLayer.h - Compile each function on demand -*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// JIT layer for breaking up modules and inserting callbacks to allow +// individual functions to be compiled on demand. 
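+//
+// A typical ORCv2 setup (an illustrative sketch; ES, CompileLayer, LCTMgr and
+// BuildISM are assumed to exist in the surrounding code) looks like:
+//
+//   CompileOnDemandLayer CODLayer(ES, CompileLayer, LCTMgr, BuildISM);
+//   CODLayer.setPartitionFunction(CompileOnDemandLayer::compileRequested);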
+// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_EXECUTIONENGINE_ORC_COMPILEONDEMANDLAYER_H +#define LLVM_EXECUTIONENGINE_ORC_COMPILEONDEMANDLAYER_H + +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/Twine.h" +#include "llvm/ExecutionEngine/JITSymbol.h" +#include "llvm/ExecutionEngine/Orc/IndirectionUtils.h" +#include "llvm/ExecutionEngine/Orc/LambdaResolver.h" +#include "llvm/ExecutionEngine/Orc/Layer.h" +#include "llvm/ExecutionEngine/Orc/LazyReexports.h" +#include "llvm/ExecutionEngine/Orc/Legacy.h" +#include "llvm/ExecutionEngine/Orc/OrcError.h" +#include "llvm/ExecutionEngine/Orc/Speculation.h" +#include "llvm/ExecutionEngine/RuntimeDyld.h" +#include "llvm/IR/Attributes.h" +#include "llvm/IR/Constant.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/GlobalAlias.h" +#include "llvm/IR/GlobalValue.h" +#include "llvm/IR/GlobalVariable.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Mangler.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Type.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Utils/ValueMapper.h" +#include <algorithm> +#include <cassert> +#include <functional> +#include <iterator> +#include <list> +#include <memory> +#include <set> +#include <string> +#include <utility> +#include <vector> + +namespace llvm { + +class Value; + +namespace orc { + +class ExtractingIRMaterializationUnit; + +class CompileOnDemandLayer : public IRLayer { + friend class PartitioningIRMaterializationUnit; + +public: + /// Builder for IndirectStubsManagers. + using IndirectStubsManagerBuilder = + std::function<std::unique_ptr<IndirectStubsManager>()>; + + using GlobalValueSet = std::set<const GlobalValue *>; + + /// Partitioning function. + using PartitionFunction = + std::function<Optional<GlobalValueSet>(GlobalValueSet Requested)>; + + /// Off-the-shelf partitioning which compiles all requested symbols (usually + /// a single function at a time). + static Optional<GlobalValueSet> compileRequested(GlobalValueSet Requested); + + /// Off-the-shelf partitioning which compiles whole modules whenever any + /// symbol in them is requested. + static Optional<GlobalValueSet> compileWholeModule(GlobalValueSet Requested); + + /// Construct a CompileOnDemandLayer. + CompileOnDemandLayer(ExecutionSession &ES, IRLayer &BaseLayer, + LazyCallThroughManager &LCTMgr, + IndirectStubsManagerBuilder BuildIndirectStubsManager); + + /// Sets the partition function. + void setPartitionFunction(PartitionFunction Partition); + + /// Sets the ImplSymbolMap + void setImplMap(ImplSymbolMap *Imp); + /// Emits the given module. This should not be called by clients: it will be + /// called by the JIT when a definition added via the add method is requested. 
+ void emit(MaterializationResponsibility R, ThreadSafeModule TSM) override; + +private: + struct PerDylibResources { + public: + PerDylibResources(JITDylib &ImplD, + std::unique_ptr<IndirectStubsManager> ISMgr) + : ImplD(ImplD), ISMgr(std::move(ISMgr)) {} + JITDylib &getImplDylib() { return ImplD; } + IndirectStubsManager &getISManager() { return *ISMgr; } + + private: + JITDylib &ImplD; + std::unique_ptr<IndirectStubsManager> ISMgr; + }; + + using PerDylibResourcesMap = std::map<const JITDylib *, PerDylibResources>; + + PerDylibResources &getPerDylibResources(JITDylib &TargetD); + + void cleanUpModule(Module &M); + + void expandPartition(GlobalValueSet &Partition); + + void emitPartition(MaterializationResponsibility R, ThreadSafeModule TSM, + IRMaterializationUnit::SymbolNameToDefinitionMap Defs); + + mutable std::mutex CODLayerMutex; + + IRLayer &BaseLayer; + LazyCallThroughManager &LCTMgr; + IndirectStubsManagerBuilder BuildIndirectStubsManager; + PerDylibResourcesMap DylibResources; + PartitionFunction Partition = compileRequested; + SymbolLinkagePromoter PromoteSymbols; + ImplSymbolMap *AliaseeImpls = nullptr; +}; + +/// Compile-on-demand layer. +/// +/// When a module is added to this layer a stub is created for each of its +/// function definitions. The stubs and other global values are immediately +/// added to the layer below. When a stub is called it triggers the extraction +/// of the function body from the original module. The extracted body is then +/// compiled and executed. +template <typename BaseLayerT, + typename CompileCallbackMgrT = JITCompileCallbackManager, + typename IndirectStubsMgrT = IndirectStubsManager> +class LegacyCompileOnDemandLayer { +private: + template <typename MaterializerFtor> + class LambdaMaterializer final : public ValueMaterializer { + public: + LambdaMaterializer(MaterializerFtor M) : M(std::move(M)) {} + + Value *materialize(Value *V) final { return M(V); } + + private: + MaterializerFtor M; + }; + + template <typename MaterializerFtor> + LambdaMaterializer<MaterializerFtor> + createLambdaMaterializer(MaterializerFtor M) { + return LambdaMaterializer<MaterializerFtor>(std::move(M)); + } + + // Provide type-erasure for the Modules and MemoryManagers. 
+ template <typename ResourceT> + class ResourceOwner { + public: + ResourceOwner() = default; + ResourceOwner(const ResourceOwner &) = delete; + ResourceOwner &operator=(const ResourceOwner &) = delete; + virtual ~ResourceOwner() = default; + + virtual ResourceT& getResource() const = 0; + }; + + template <typename ResourceT, typename ResourcePtrT> + class ResourceOwnerImpl : public ResourceOwner<ResourceT> { + public: + ResourceOwnerImpl(ResourcePtrT ResourcePtr) + : ResourcePtr(std::move(ResourcePtr)) {} + + ResourceT& getResource() const override { return *ResourcePtr; } + + private: + ResourcePtrT ResourcePtr; + }; + + template <typename ResourceT, typename ResourcePtrT> + std::unique_ptr<ResourceOwner<ResourceT>> + wrapOwnership(ResourcePtrT ResourcePtr) { + using RO = ResourceOwnerImpl<ResourceT, ResourcePtrT>; + return std::make_unique<RO>(std::move(ResourcePtr)); + } + + struct LogicalDylib { + struct SourceModuleEntry { + std::unique_ptr<Module> SourceMod; + std::set<Function*> StubsToClone; + }; + + using SourceModulesList = std::vector<SourceModuleEntry>; + using SourceModuleHandle = typename SourceModulesList::size_type; + + LogicalDylib() = default; + + LogicalDylib(VModuleKey K, std::shared_ptr<SymbolResolver> BackingResolver, + std::unique_ptr<IndirectStubsMgrT> StubsMgr) + : K(std::move(K)), BackingResolver(std::move(BackingResolver)), + StubsMgr(std::move(StubsMgr)) {} + + SourceModuleHandle addSourceModule(std::unique_ptr<Module> M) { + SourceModuleHandle H = SourceModules.size(); + SourceModules.push_back(SourceModuleEntry()); + SourceModules.back().SourceMod = std::move(M); + return H; + } + + Module& getSourceModule(SourceModuleHandle H) { + return *SourceModules[H].SourceMod; + } + + std::set<Function*>& getStubsToClone(SourceModuleHandle H) { + return SourceModules[H].StubsToClone; + } + + JITSymbol findSymbol(BaseLayerT &BaseLayer, const std::string &Name, + bool ExportedSymbolsOnly) { + if (auto Sym = StubsMgr->findStub(Name, ExportedSymbolsOnly)) + return Sym; + for (auto BLK : BaseLayerVModuleKeys) + if (auto Sym = BaseLayer.findSymbolIn(BLK, Name, ExportedSymbolsOnly)) + return Sym; + else if (auto Err = Sym.takeError()) + return std::move(Err); + return nullptr; + } + + Error removeModulesFromBaseLayer(BaseLayerT &BaseLayer) { + for (auto &BLK : BaseLayerVModuleKeys) + if (auto Err = BaseLayer.removeModule(BLK)) + return Err; + return Error::success(); + } + + VModuleKey K; + std::shared_ptr<SymbolResolver> BackingResolver; + std::unique_ptr<IndirectStubsMgrT> StubsMgr; + SymbolLinkagePromoter PromoteSymbols; + SourceModulesList SourceModules; + std::vector<VModuleKey> BaseLayerVModuleKeys; + }; + +public: + + /// Module partitioning functor. + using PartitioningFtor = std::function<std::set<Function*>(Function&)>; + + /// Builder for IndirectStubsManagers. + using IndirectStubsManagerBuilderT = + std::function<std::unique_ptr<IndirectStubsMgrT>()>; + + using SymbolResolverGetter = + std::function<std::shared_ptr<SymbolResolver>(VModuleKey K)>; + + using SymbolResolverSetter = + std::function<void(VModuleKey K, std::shared_ptr<SymbolResolver> R)>; + + /// Construct a compile-on-demand layer instance. 
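+ ///
+ /// A sketch of the deprecation-acknowledging overload below (all arguments
+ /// are assumed to exist in the surrounding code):
+ /// \code
+ ///   LegacyCompileOnDemandLayer<decltype(CompileLayer)> CODLayer(
+ ///       AcknowledgeORCv1Deprecation, ES, CompileLayer, GetResolver,
+ ///       SetResolver, Partition, *CCMgr, CreateISM);
+ /// \endcode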
+ LLVM_ATTRIBUTE_DEPRECATED( + LegacyCompileOnDemandLayer( + ExecutionSession &ES, BaseLayerT &BaseLayer, + SymbolResolverGetter GetSymbolResolver, + SymbolResolverSetter SetSymbolResolver, PartitioningFtor Partition, + CompileCallbackMgrT &CallbackMgr, + IndirectStubsManagerBuilderT CreateIndirectStubsManager, + bool CloneStubsIntoPartitions = true), + "ORCv1 layers (layers with the 'Legacy' prefix) are deprecated. Please " + "use " + "the ORCv2 LegacyCompileOnDemandLayer instead"); + + /// Legacy layer constructor with deprecation acknowledgement. + LegacyCompileOnDemandLayer( + ORCv1DeprecationAcknowledgement, ExecutionSession &ES, + BaseLayerT &BaseLayer, SymbolResolverGetter GetSymbolResolver, + SymbolResolverSetter SetSymbolResolver, PartitioningFtor Partition, + CompileCallbackMgrT &CallbackMgr, + IndirectStubsManagerBuilderT CreateIndirectStubsManager, + bool CloneStubsIntoPartitions = true) + : ES(ES), BaseLayer(BaseLayer), + GetSymbolResolver(std::move(GetSymbolResolver)), + SetSymbolResolver(std::move(SetSymbolResolver)), + Partition(std::move(Partition)), CompileCallbackMgr(CallbackMgr), + CreateIndirectStubsManager(std::move(CreateIndirectStubsManager)), + CloneStubsIntoPartitions(CloneStubsIntoPartitions) {} + + ~LegacyCompileOnDemandLayer() { + // FIXME: Report error on log. + while (!LogicalDylibs.empty()) + consumeError(removeModule(LogicalDylibs.begin()->first)); + } + + /// Add a module to the compile-on-demand layer. + Error addModule(VModuleKey K, std::unique_ptr<Module> M) { + + assert(!LogicalDylibs.count(K) && "VModuleKey K already in use"); + auto I = LogicalDylibs.insert( + LogicalDylibs.end(), + std::make_pair(K, LogicalDylib(K, GetSymbolResolver(K), + CreateIndirectStubsManager()))); + + return addLogicalModule(I->second, std::move(M)); + } + + /// Add extra modules to an existing logical module. + Error addExtraModule(VModuleKey K, std::unique_ptr<Module> M) { + return addLogicalModule(LogicalDylibs[K], std::move(M)); + } + + /// Remove the module represented by the given key. + /// + /// This will remove all modules in the layers below that were derived from + /// the module represented by K. + Error removeModule(VModuleKey K) { + auto I = LogicalDylibs.find(K); + assert(I != LogicalDylibs.end() && "VModuleKey K not valid here"); + auto Err = I->second.removeModulesFromBaseLayer(BaseLayer); + LogicalDylibs.erase(I); + return Err; + } + + /// Search for the given named symbol. + /// @param Name The name of the symbol to search for. + /// @param ExportedSymbolsOnly If true, search only for exported symbols. + /// @return A handle for the given named symbol, if it exists. + JITSymbol findSymbol(StringRef Name, bool ExportedSymbolsOnly) { + for (auto &KV : LogicalDylibs) { + if (auto Sym = KV.second.StubsMgr->findStub(Name, ExportedSymbolsOnly)) + return Sym; + if (auto Sym = findSymbolIn(KV.first, Name, ExportedSymbolsOnly)) + return Sym; + else if (auto Err = Sym.takeError()) + return std::move(Err); + } + return BaseLayer.findSymbol(Name, ExportedSymbolsOnly); + } + + /// Get the address of a symbol provided by this layer, or some layer + /// below this one. + JITSymbol findSymbolIn(VModuleKey K, const std::string &Name, + bool ExportedSymbolsOnly) { + assert(LogicalDylibs.count(K) && "VModuleKey K is not valid here"); + return LogicalDylibs[K].findSymbol(BaseLayer, Name, ExportedSymbolsOnly); + } + + /// Update the stub for the given function to point at FnBodyAddr. + /// This can be used to support re-optimization. 
+ /// @return true if the function exists and the stub is updated, false + /// otherwise. + // + // FIXME: We should track and free associated resources (unused compile + // callbacks, uncompiled IR, and no-longer-needed/reachable function + // implementations). + Error updatePointer(std::string FuncName, JITTargetAddress FnBodyAddr) { + //Find out which logical dylib contains our symbol + auto LDI = LogicalDylibs.begin(); + for (auto LDE = LogicalDylibs.end(); LDI != LDE; ++LDI) { + if (auto LMResources = + LDI->getLogicalModuleResourcesForSymbol(FuncName, false)) { + Module &SrcM = LMResources->SourceModule->getResource(); + std::string CalledFnName = mangle(FuncName, SrcM.getDataLayout()); + if (auto Err = LMResources->StubsMgr->updatePointer(CalledFnName, + FnBodyAddr)) + return Err; + return Error::success(); + } + } + return make_error<JITSymbolNotFound>(FuncName); + } + +private: + Error addLogicalModule(LogicalDylib &LD, std::unique_ptr<Module> SrcMPtr) { + + // Rename anonymous globals and promote linkage to ensure that everything + // will resolve properly after we partition SrcM. + LD.PromoteSymbols(*SrcMPtr); + + // Create a logical module handle for SrcM within the logical dylib. + Module &SrcM = *SrcMPtr; + auto LMId = LD.addSourceModule(std::move(SrcMPtr)); + + // Create stub functions. + const DataLayout &DL = SrcM.getDataLayout(); + { + typename IndirectStubsMgrT::StubInitsMap StubInits; + for (auto &F : SrcM) { + // Skip declarations. + if (F.isDeclaration()) + continue; + + // Skip weak functions for which we already have definitions. + auto MangledName = mangle(F.getName(), DL); + if (F.hasWeakLinkage() || F.hasLinkOnceLinkage()) { + if (auto Sym = LD.findSymbol(BaseLayer, MangledName, false)) + continue; + else if (auto Err = Sym.takeError()) + return std::move(Err); + } + + // Record all functions defined by this module. + if (CloneStubsIntoPartitions) + LD.getStubsToClone(LMId).insert(&F); + + // Create a callback, associate it with the stub for the function, + // and set the compile action to compile the partition containing the + // function. + auto CompileAction = [this, &LD, LMId, &F]() -> JITTargetAddress { + if (auto FnImplAddrOrErr = this->extractAndCompile(LD, LMId, F)) + return *FnImplAddrOrErr; + else { + // FIXME: Report error, return to 'abort' or something similar. + consumeError(FnImplAddrOrErr.takeError()); + return 0; + } + }; + if (auto CCAddr = + CompileCallbackMgr.getCompileCallback(std::move(CompileAction))) + StubInits[MangledName] = + std::make_pair(*CCAddr, JITSymbolFlags::fromGlobalValue(F)); + else + return CCAddr.takeError(); + } + + if (auto Err = LD.StubsMgr->createStubs(StubInits)) + return Err; + } + + // If this module doesn't contain any globals, aliases, or module flags then + // we can bail out early and avoid the overhead of creating and managing an + // empty globals module. + if (SrcM.global_empty() && SrcM.alias_empty() && + !SrcM.getModuleFlagsMetadata()) + return Error::success(); + + // Create the GlobalValues module. + auto GVsM = std::make_unique<Module>((SrcM.getName() + ".globals").str(), + SrcM.getContext()); + GVsM->setDataLayout(DL); + + ValueToValueMapTy VMap; + + // Clone global variable decls. + for (auto &GV : SrcM.globals()) + if (!GV.isDeclaration() && !VMap.count(&GV)) + cloneGlobalVariableDecl(*GVsM, GV, &VMap); + + // And the aliases. + for (auto &A : SrcM.aliases()) + if (!VMap.count(&A)) + cloneGlobalAliasDecl(*GVsM, A, VMap); + + // Clone the module flags. 
+ cloneModuleFlagsMetadata(*GVsM, SrcM, VMap); + + // Now we need to clone the GV and alias initializers. + + // Initializers may refer to functions declared (but not defined) in this + // module. Build a materializer to clone decls on demand. + auto Materializer = createLambdaMaterializer( + [&LD, &GVsM](Value *V) -> Value* { + if (auto *F = dyn_cast<Function>(V)) { + // Decls in the original module just get cloned. + if (F->isDeclaration()) + return cloneFunctionDecl(*GVsM, *F); + + // Definitions in the original module (which we have emitted stubs + // for at this point) get turned into a constant alias to the stub + // instead. + const DataLayout &DL = GVsM->getDataLayout(); + std::string FName = mangle(F->getName(), DL); + unsigned PtrBitWidth = DL.getPointerTypeSizeInBits(F->getType()); + JITTargetAddress StubAddr = + LD.StubsMgr->findStub(FName, false).getAddress(); + + ConstantInt *StubAddrCI = + ConstantInt::get(GVsM->getContext(), APInt(PtrBitWidth, StubAddr)); + Constant *Init = ConstantExpr::getCast(Instruction::IntToPtr, + StubAddrCI, F->getType()); + return GlobalAlias::create(F->getFunctionType(), + F->getType()->getAddressSpace(), + F->getLinkage(), F->getName(), + Init, GVsM.get()); + } + // else.... + return nullptr; + }); + + // Clone the global variable initializers. + for (auto &GV : SrcM.globals()) + if (!GV.isDeclaration()) + moveGlobalVariableInitializer(GV, VMap, &Materializer); + + // Clone the global alias initializers. + for (auto &A : SrcM.aliases()) { + auto *NewA = cast<GlobalAlias>(VMap[&A]); + assert(NewA && "Alias not cloned?"); + Value *Init = MapValue(A.getAliasee(), VMap, RF_None, nullptr, + &Materializer); + NewA->setAliasee(cast<Constant>(Init)); + } + + // Build a resolver for the globals module and add it to the base layer. 
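+    // LegacyLookup resolves names against this logical dylib's stubs first,
+    // then against anything already emitted to the base layer, so references
+    // to function definitions bind to their lazy-compilation stubs.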
+ auto LegacyLookup = [this, &LD](const std::string &Name) -> JITSymbol { + if (auto Sym = LD.StubsMgr->findStub(Name, false)) + return Sym; + + if (auto Sym = LD.findSymbol(BaseLayer, Name, false)) + return Sym; + else if (auto Err = Sym.takeError()) + return std::move(Err); + + return nullptr; + }; + + auto GVsResolver = createSymbolResolver( + [&LD, LegacyLookup](const SymbolNameSet &Symbols) { + auto RS = getResponsibilitySetWithLegacyFn(Symbols, LegacyLookup); + + if (!RS) { + logAllUnhandledErrors( + RS.takeError(), errs(), + "CODLayer/GVsResolver responsibility set lookup failed: "); + return SymbolNameSet(); + } + + if (RS->size() == Symbols.size()) + return *RS; + + SymbolNameSet NotFoundViaLegacyLookup; + for (auto &S : Symbols) + if (!RS->count(S)) + NotFoundViaLegacyLookup.insert(S); + auto RS2 = + LD.BackingResolver->getResponsibilitySet(NotFoundViaLegacyLookup); + + for (auto &S : RS2) + (*RS).insert(S); + + return *RS; + }, + [this, &LD, + LegacyLookup](std::shared_ptr<AsynchronousSymbolQuery> Query, + SymbolNameSet Symbols) { + auto NotFoundViaLegacyLookup = + lookupWithLegacyFn(ES, *Query, Symbols, LegacyLookup); + return LD.BackingResolver->lookup(Query, NotFoundViaLegacyLookup); + }); + + SetSymbolResolver(LD.K, std::move(GVsResolver)); + + if (auto Err = BaseLayer.addModule(LD.K, std::move(GVsM))) + return Err; + + LD.BaseLayerVModuleKeys.push_back(LD.K); + + return Error::success(); + } + + static std::string mangle(StringRef Name, const DataLayout &DL) { + std::string MangledName; + { + raw_string_ostream MangledNameStream(MangledName); + Mangler::getNameWithPrefix(MangledNameStream, Name, DL); + } + return MangledName; + } + + Expected<JITTargetAddress> + extractAndCompile(LogicalDylib &LD, + typename LogicalDylib::SourceModuleHandle LMId, + Function &F) { + Module &SrcM = LD.getSourceModule(LMId); + + // If F is a declaration we must already have compiled it. + if (F.isDeclaration()) + return 0; + + // Grab the name of the function being called here. + std::string CalledFnName = mangle(F.getName(), SrcM.getDataLayout()); + + JITTargetAddress CalledAddr = 0; + auto Part = Partition(F); + if (auto PartKeyOrErr = emitPartition(LD, LMId, Part)) { + auto &PartKey = *PartKeyOrErr; + for (auto *SubF : Part) { + std::string FnName = mangle(SubF->getName(), SrcM.getDataLayout()); + if (auto FnBodySym = BaseLayer.findSymbolIn(PartKey, FnName, false)) { + if (auto FnBodyAddrOrErr = FnBodySym.getAddress()) { + JITTargetAddress FnBodyAddr = *FnBodyAddrOrErr; + + // If this is the function we're calling record the address so we can + // return it from this function. + if (SubF == &F) + CalledAddr = FnBodyAddr; + + // Update the function body pointer for the stub. + if (auto EC = LD.StubsMgr->updatePointer(FnName, FnBodyAddr)) + return 0; + + } else + return FnBodyAddrOrErr.takeError(); + } else if (auto Err = FnBodySym.takeError()) + return std::move(Err); + else + llvm_unreachable("Function not emitted for partition"); + } + + LD.BaseLayerVModuleKeys.push_back(PartKey); + } else + return PartKeyOrErr.takeError(); + + return CalledAddr; + } + + template <typename PartitionT> + Expected<VModuleKey> + emitPartition(LogicalDylib &LD, + typename LogicalDylib::SourceModuleHandle LMId, + const PartitionT &Part) { + Module &SrcM = LD.getSourceModule(LMId); + + // Create the module. 
+ std::string NewName = SrcM.getName(); + for (auto *F : Part) { + NewName += "."; + NewName += F->getName(); + } + + auto M = std::make_unique<Module>(NewName, SrcM.getContext()); + M->setDataLayout(SrcM.getDataLayout()); + ValueToValueMapTy VMap; + + auto Materializer = createLambdaMaterializer([&LD, &LMId, + &M](Value *V) -> Value * { + if (auto *GV = dyn_cast<GlobalVariable>(V)) + return cloneGlobalVariableDecl(*M, *GV); + + if (auto *F = dyn_cast<Function>(V)) { + // Check whether we want to clone an available_externally definition. + if (!LD.getStubsToClone(LMId).count(F)) + return cloneFunctionDecl(*M, *F); + + // Ok - we want an inlinable stub. For that to work we need a decl + // for the stub pointer. + auto *StubPtr = createImplPointer(*F->getType(), *M, + F->getName() + "$stub_ptr", nullptr); + auto *ClonedF = cloneFunctionDecl(*M, *F); + makeStub(*ClonedF, *StubPtr); + ClonedF->setLinkage(GlobalValue::AvailableExternallyLinkage); + ClonedF->addFnAttr(Attribute::AlwaysInline); + return ClonedF; + } + + if (auto *A = dyn_cast<GlobalAlias>(V)) { + auto *Ty = A->getValueType(); + if (Ty->isFunctionTy()) + return Function::Create(cast<FunctionType>(Ty), + GlobalValue::ExternalLinkage, A->getName(), + M.get()); + + return new GlobalVariable(*M, Ty, false, GlobalValue::ExternalLinkage, + nullptr, A->getName(), nullptr, + GlobalValue::NotThreadLocal, + A->getType()->getAddressSpace()); + } + + return nullptr; + }); + + // Create decls in the new module. + for (auto *F : Part) + cloneFunctionDecl(*M, *F, &VMap); + + // Move the function bodies. + for (auto *F : Part) + moveFunctionBody(*F, VMap, &Materializer); + + auto K = ES.allocateVModule(); + + auto LegacyLookup = [this, &LD](const std::string &Name) -> JITSymbol { + return LD.findSymbol(BaseLayer, Name, false); + }; + + // Create memory manager and symbol resolver. 
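+    // The resolver below is built from two callbacks: the first computes the
+    // responsibility set (merging legacy-lookup results with the backing
+    // resolver's), the second performs the actual lookup, deferring any names
+    // the legacy path misses to the backing resolver.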
+ auto Resolver = createSymbolResolver( + [&LD, LegacyLookup](const SymbolNameSet &Symbols) { + auto RS = getResponsibilitySetWithLegacyFn(Symbols, LegacyLookup); + if (!RS) { + logAllUnhandledErrors( + RS.takeError(), errs(), + "CODLayer/SubResolver responsibility set lookup failed: "); + return SymbolNameSet(); + } + + if (RS->size() == Symbols.size()) + return *RS; + + SymbolNameSet NotFoundViaLegacyLookup; + for (auto &S : Symbols) + if (!RS->count(S)) + NotFoundViaLegacyLookup.insert(S); + + auto RS2 = + LD.BackingResolver->getResponsibilitySet(NotFoundViaLegacyLookup); + + for (auto &S : RS2) + (*RS).insert(S); + + return *RS; + }, + [this, &LD, LegacyLookup](std::shared_ptr<AsynchronousSymbolQuery> Q, + SymbolNameSet Symbols) { + auto NotFoundViaLegacyLookup = + lookupWithLegacyFn(ES, *Q, Symbols, LegacyLookup); + return LD.BackingResolver->lookup(Q, + std::move(NotFoundViaLegacyLookup)); + }); + SetSymbolResolver(K, std::move(Resolver)); + + if (auto Err = BaseLayer.addModule(std::move(K), std::move(M))) + return std::move(Err); + + return K; + } + + ExecutionSession &ES; + BaseLayerT &BaseLayer; + SymbolResolverGetter GetSymbolResolver; + SymbolResolverSetter SetSymbolResolver; + PartitioningFtor Partition; + CompileCallbackMgrT &CompileCallbackMgr; + IndirectStubsManagerBuilderT CreateIndirectStubsManager; + + std::map<VModuleKey, LogicalDylib> LogicalDylibs; + bool CloneStubsIntoPartitions; +}; + +template <typename BaseLayerT, typename CompileCallbackMgrT, + typename IndirectStubsMgrT> +LegacyCompileOnDemandLayer<BaseLayerT, CompileCallbackMgrT, IndirectStubsMgrT>:: + LegacyCompileOnDemandLayer( + ExecutionSession &ES, BaseLayerT &BaseLayer, + SymbolResolverGetter GetSymbolResolver, + SymbolResolverSetter SetSymbolResolver, PartitioningFtor Partition, + CompileCallbackMgrT &CallbackMgr, + IndirectStubsManagerBuilderT CreateIndirectStubsManager, + bool CloneStubsIntoPartitions) + : ES(ES), BaseLayer(BaseLayer), + GetSymbolResolver(std::move(GetSymbolResolver)), + SetSymbolResolver(std::move(SetSymbolResolver)), + Partition(std::move(Partition)), CompileCallbackMgr(CallbackMgr), + CreateIndirectStubsManager(std::move(CreateIndirectStubsManager)), + CloneStubsIntoPartitions(CloneStubsIntoPartitions) {} + +} // end namespace orc +} // end namespace llvm + +#endif // LLVM_EXECUTIONENGINE_ORC_COMPILEONDEMANDLAYER_H diff --git a/llvm/include/llvm/ExecutionEngine/Orc/CompileUtils.h b/llvm/include/llvm/ExecutionEngine/Orc/CompileUtils.h new file mode 100644 index 000000000000..eb6d84e8cbb4 --- /dev/null +++ b/llvm/include/llvm/ExecutionEngine/Orc/CompileUtils.h @@ -0,0 +1,94 @@ +//===- CompileUtils.h - Utilities for compiling IR in the JIT ---*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Contains utilities for compiling IR to object files. 
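+//
+// A minimal usage sketch (illustrative only; assumes an existing
+// TargetMachine TM and Module M, with error handling elided):
+//
+//   SimpleCompiler Compile(TM);
+//   std::unique_ptr<MemoryBuffer> Obj = Compile(M);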
+// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_EXECUTIONENGINE_ORC_COMPILEUTILS_H +#define LLVM_EXECUTIONENGINE_ORC_COMPILEUTILS_H + +#include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h" +#include <memory> + +namespace llvm { + +class MCContext; +class MemoryBuffer; +class Module; +class ObjectCache; +class TargetMachine; + +namespace orc { + +class JITTargetMachineBuilder; + +/// Simple compile functor: Takes a single IR module and returns an ObjectFile. +/// This compiler supports a single compilation thread and LLVMContext only. +/// For multithreaded compilation, use ConcurrentIRCompiler below. +class SimpleCompiler { +public: + using CompileResult = std::unique_ptr<MemoryBuffer>; + + /// Construct a simple compile functor with the given target. + SimpleCompiler(TargetMachine &TM, ObjectCache *ObjCache = nullptr) + : TM(TM), ObjCache(ObjCache) {} + + /// Set an ObjectCache to query before compiling. + void setObjectCache(ObjectCache *NewCache) { ObjCache = NewCache; } + + /// Compile a Module to an ObjectFile. + CompileResult operator()(Module &M); + +private: + CompileResult tryToLoadFromObjectCache(const Module &M); + void notifyObjectCompiled(const Module &M, const MemoryBuffer &ObjBuffer); + + TargetMachine &TM; + ObjectCache *ObjCache = nullptr; +}; + +/// A SimpleCompiler that owns its TargetMachine. +/// +/// This convenient for clients who don't want to own their TargetMachines, +/// e.g. LLJIT. +class TMOwningSimpleCompiler : public SimpleCompiler { +public: + TMOwningSimpleCompiler(std::unique_ptr<TargetMachine> TM, + ObjectCache *ObjCache = nullptr) + : SimpleCompiler(*TM, ObjCache), TM(std::move(TM)) {} + +private: + // FIXME: shared because std::functions (and consequently + // IRCompileLayer::CompileFunction) are not moveable. + std::shared_ptr<llvm::TargetMachine> TM; +}; + +/// A thread-safe version of SimpleCompiler. +/// +/// This class creates a new TargetMachine and SimpleCompiler instance for each +/// compile. +class ConcurrentIRCompiler { +public: + ConcurrentIRCompiler(JITTargetMachineBuilder JTMB, + ObjectCache *ObjCache = nullptr); + + void setObjectCache(ObjectCache *ObjCache) { this->ObjCache = ObjCache; } + + std::unique_ptr<MemoryBuffer> operator()(Module &M); + +private: + JITTargetMachineBuilder JTMB; + ObjectCache *ObjCache = nullptr; +}; + +} // end namespace orc + +} // end namespace llvm + +#endif // LLVM_EXECUTIONENGINE_ORC_COMPILEUTILS_H diff --git a/llvm/include/llvm/ExecutionEngine/Orc/Core.h b/llvm/include/llvm/ExecutionEngine/Orc/Core.h new file mode 100644 index 000000000000..4f22a4c38796 --- /dev/null +++ b/llvm/include/llvm/ExecutionEngine/Orc/Core.h @@ -0,0 +1,1029 @@ +//===------ Core.h -- Core ORC APIs (Layer, JITDylib, etc.) -----*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Contains core ORC APIs. 
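+//
+// A minimal usage sketch (illustrative only; the address below is a
+// placeholder): define an absolute symbol in the main JITDylib, then look it
+// up.
+//
+//   ExecutionSession ES;
+//   auto &JD = ES.getMainJITDylib();
+//   auto Foo = ES.intern("foo");
+//   JITEvaluatedSymbol FooSym(0x1000, JITSymbolFlags::Exported);
+//   if (auto Err = JD.define(absoluteSymbols({{Foo, FooSym}})))
+//     ES.reportError(std::move(Err));
+//   auto Sym = ES.lookup(JITDylibSearchList({{&JD, false}}), Foo);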
+// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_EXECUTIONENGINE_ORC_CORE_H +#define LLVM_EXECUTIONENGINE_ORC_CORE_H + +#include "llvm/ADT/BitmaskEnum.h" +#include "llvm/ADT/FunctionExtras.h" +#include "llvm/ExecutionEngine/JITSymbol.h" +#include "llvm/ExecutionEngine/Orc/SymbolStringPool.h" +#include "llvm/ExecutionEngine/OrcV1Deprecation.h" +#include "llvm/IR/Module.h" +#include "llvm/Support/Debug.h" + +#include <memory> +#include <vector> + +#define DEBUG_TYPE "orc" + +namespace llvm { +namespace orc { + +// Forward declare some classes. +class AsynchronousSymbolQuery; +class ExecutionSession; +class MaterializationUnit; +class MaterializationResponsibility; +class JITDylib; +enum class SymbolState : uint8_t; + +/// VModuleKey provides a unique identifier (allocated and managed by +/// ExecutionSessions) for a module added to the JIT. +using VModuleKey = uint64_t; + +/// A set of symbol names (represented by SymbolStringPtrs for +// efficiency). +using SymbolNameSet = DenseSet<SymbolStringPtr>; + +/// A map from symbol names (as SymbolStringPtrs) to JITSymbols +/// (address/flags pairs). +using SymbolMap = DenseMap<SymbolStringPtr, JITEvaluatedSymbol>; + +/// A map from symbol names (as SymbolStringPtrs) to JITSymbolFlags. +using SymbolFlagsMap = DenseMap<SymbolStringPtr, JITSymbolFlags>; + +/// A map from JITDylibs to sets of symbols. +using SymbolDependenceMap = DenseMap<JITDylib *, SymbolNameSet>; + +/// A list of (JITDylib*, bool) pairs. +using JITDylibSearchList = std::vector<std::pair<JITDylib *, bool>>; + +struct SymbolAliasMapEntry { + SymbolAliasMapEntry() = default; + SymbolAliasMapEntry(SymbolStringPtr Aliasee, JITSymbolFlags AliasFlags) + : Aliasee(std::move(Aliasee)), AliasFlags(AliasFlags) {} + + SymbolStringPtr Aliasee; + JITSymbolFlags AliasFlags; +}; + +/// A map of Symbols to (Symbol, Flags) pairs. +using SymbolAliasMap = DenseMap<SymbolStringPtr, SymbolAliasMapEntry>; + +/// Render a SymbolStringPtr. +raw_ostream &operator<<(raw_ostream &OS, const SymbolStringPtr &Sym); + +/// Render a SymbolNameSet. +raw_ostream &operator<<(raw_ostream &OS, const SymbolNameSet &Symbols); + +/// Render a SymbolFlagsMap entry. +raw_ostream &operator<<(raw_ostream &OS, const SymbolFlagsMap::value_type &KV); + +/// Render a SymbolMap entry. +raw_ostream &operator<<(raw_ostream &OS, const SymbolMap::value_type &KV); + +/// Render a SymbolFlagsMap. +raw_ostream &operator<<(raw_ostream &OS, const SymbolFlagsMap &SymbolFlags); + +/// Render a SymbolMap. +raw_ostream &operator<<(raw_ostream &OS, const SymbolMap &Symbols); + +/// Render a SymbolDependenceMap entry. +raw_ostream &operator<<(raw_ostream &OS, + const SymbolDependenceMap::value_type &KV); + +/// Render a SymbolDependendeMap. +raw_ostream &operator<<(raw_ostream &OS, const SymbolDependenceMap &Deps); + +/// Render a MaterializationUnit. +raw_ostream &operator<<(raw_ostream &OS, const MaterializationUnit &MU); + +/// Render a JITDylibSearchList. +raw_ostream &operator<<(raw_ostream &OS, const JITDylibSearchList &JDs); + +/// Render a SymbolAliasMap. +raw_ostream &operator<<(raw_ostream &OS, const SymbolAliasMap &Aliases); + +/// Render a SymbolState. +raw_ostream &operator<<(raw_ostream &OS, const SymbolState &S); + +/// Callback to notify client that symbols have been resolved. +using SymbolsResolvedCallback = unique_function<void(Expected<SymbolMap>)>; + +/// Callback to register the dependencies for a given query. 
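+/// For illustration, one possible implementation (a sketch; assumes a
+/// MaterializationResponsibility R is in scope) simply forwards the map:
+/// \code{.cpp}
+///   RegisterDependenciesFunction RegisterDeps =
+///       [&R](const SymbolDependenceMap &Deps) {
+///         R.addDependenciesForAll(Deps);
+///       };
+/// \endcode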
+using RegisterDependenciesFunction = + std::function<void(const SymbolDependenceMap &)>; + +/// This can be used as the value for a RegisterDependenciesFunction if there +/// are no dependants to register with. +extern RegisterDependenciesFunction NoDependenciesToRegister; + +/// Used to notify a JITDylib that the given set of symbols failed to +/// materialize. +class FailedToMaterialize : public ErrorInfo<FailedToMaterialize> { +public: + static char ID; + + FailedToMaterialize(std::shared_ptr<SymbolDependenceMap> Symbols); + std::error_code convertToErrorCode() const override; + void log(raw_ostream &OS) const override; + const SymbolDependenceMap &getSymbols() const { return *Symbols; } + +private: + std::shared_ptr<SymbolDependenceMap> Symbols; +}; + +/// Used to notify clients when symbols can not be found during a lookup. +class SymbolsNotFound : public ErrorInfo<SymbolsNotFound> { +public: + static char ID; + + SymbolsNotFound(SymbolNameSet Symbols); + std::error_code convertToErrorCode() const override; + void log(raw_ostream &OS) const override; + const SymbolNameSet &getSymbols() const { return Symbols; } + +private: + SymbolNameSet Symbols; +}; + +/// Used to notify clients that a set of symbols could not be removed. +class SymbolsCouldNotBeRemoved : public ErrorInfo<SymbolsCouldNotBeRemoved> { +public: + static char ID; + + SymbolsCouldNotBeRemoved(SymbolNameSet Symbols); + std::error_code convertToErrorCode() const override; + void log(raw_ostream &OS) const override; + const SymbolNameSet &getSymbols() const { return Symbols; } + +private: + SymbolNameSet Symbols; +}; + +/// Tracks responsibility for materialization, and mediates interactions between +/// MaterializationUnits and JDs. +/// +/// An instance of this class is passed to MaterializationUnits when their +/// materialize method is called. It allows MaterializationUnits to resolve and +/// emit symbols, or abandon materialization by notifying any unmaterialized +/// symbols of an error. +class MaterializationResponsibility { + friend class MaterializationUnit; +public: + MaterializationResponsibility(MaterializationResponsibility &&) = default; + MaterializationResponsibility & + operator=(MaterializationResponsibility &&) = delete; + + /// Destruct a MaterializationResponsibility instance. In debug mode + /// this asserts that all symbols being tracked have been either + /// emitted or notified of an error. + ~MaterializationResponsibility(); + + /// Returns the target JITDylib that these symbols are being materialized + /// into. + JITDylib &getTargetJITDylib() const { return JD; } + + /// Returns the VModuleKey for this instance. + VModuleKey getVModuleKey() const { return K; } + + /// Returns the symbol flags map for this responsibility instance. + /// Note: The returned flags may have transient flags (Lazy, Materializing) + /// set. These should be stripped with JITSymbolFlags::stripTransientFlags + /// before using. + const SymbolFlagsMap &getSymbols() const { return SymbolFlags; } + + /// Returns the names of any symbols covered by this + /// MaterializationResponsibility object that have queries pending. This + /// information can be used to return responsibility for unrequested symbols + /// back to the JITDylib via the delegate method. + SymbolNameSet getRequestedSymbols() const; + + /// Notifies the target JITDylib that the given symbols have been resolved. + /// This will update the given symbols' addresses in the JITDylib, and notify + /// any pending queries on the given symbols of their resolution. 
The given + /// symbols must be ones covered by this MaterializationResponsibility + /// instance. Individual calls to this method may resolve a subset of the + /// symbols, but all symbols must have been resolved prior to calling emit. + /// + /// This method will return an error if any symbols being resolved have been + /// moved to the error state due to the failure of a dependency. If this + /// method returns an error then clients should log it and call + /// failMaterialize. If no dependencies have been registered for the + /// symbols covered by this MaterializationResponsibiility then this method + /// is guaranteed to return Error::success() and can be wrapped with cantFail. + Error notifyResolved(const SymbolMap &Symbols); + + /// Notifies the target JITDylib (and any pending queries on that JITDylib) + /// that all symbols covered by this MaterializationResponsibility instance + /// have been emitted. + /// + /// This method will return an error if any symbols being resolved have been + /// moved to the error state due to the failure of a dependency. If this + /// method returns an error then clients should log it and call + /// failMaterialize. If no dependencies have been registered for the + /// symbols covered by this MaterializationResponsibiility then this method + /// is guaranteed to return Error::success() and can be wrapped with cantFail. + Error notifyEmitted(); + + /// Adds new symbols to the JITDylib and this responsibility instance. + /// JITDylib entries start out in the materializing state. + /// + /// This method can be used by materialization units that want to add + /// additional symbols at materialization time (e.g. stubs, compile + /// callbacks, metadata). + Error defineMaterializing(const SymbolFlagsMap &SymbolFlags); + + /// Notify all not-yet-emitted covered by this MaterializationResponsibility + /// instance that an error has occurred. + /// This will remove all symbols covered by this MaterializationResponsibilty + /// from the target JITDylib, and send an error to any queries waiting on + /// these symbols. + void failMaterialization(); + + /// Transfers responsibility to the given MaterializationUnit for all + /// symbols defined by that MaterializationUnit. This allows + /// materializers to break up work based on run-time information (e.g. + /// by introspecting which symbols have actually been looked up and + /// materializing only those). + void replace(std::unique_ptr<MaterializationUnit> MU); + + /// Delegates responsibility for the given symbols to the returned + /// materialization responsibility. Useful for breaking up work between + /// threads, or different kinds of materialization processes. + MaterializationResponsibility delegate(const SymbolNameSet &Symbols, + VModuleKey NewKey = VModuleKey()); + + void addDependencies(const SymbolStringPtr &Name, + const SymbolDependenceMap &Dependencies); + + /// Add dependencies that apply to all symbols covered by this instance. + void addDependenciesForAll(const SymbolDependenceMap &Dependencies); + +private: + /// Create a MaterializationResponsibility for the given JITDylib and + /// initial symbols. + MaterializationResponsibility(JITDylib &JD, SymbolFlagsMap SymbolFlags, + VModuleKey K); + + JITDylib &JD; + SymbolFlagsMap SymbolFlags; + VModuleKey K; +}; + +/// A MaterializationUnit represents a set of symbol definitions that can +/// be materialized as a group, or individually discarded (when +/// overriding definitions are encountered). 
+/// +/// MaterializationUnits are used when providing lazy definitions of symbols to +/// JITDylibs. The JITDylib will call materialize when the address of a symbol +/// is requested via the lookup method. The JITDylib will call discard if a +/// stronger definition is added or already present. +class MaterializationUnit { +public: + MaterializationUnit(SymbolFlagsMap InitalSymbolFlags, VModuleKey K) + : SymbolFlags(std::move(InitalSymbolFlags)), K(std::move(K)) {} + + virtual ~MaterializationUnit() {} + + /// Return the name of this materialization unit. Useful for debugging + /// output. + virtual StringRef getName() const = 0; + + /// Return the set of symbols that this source provides. + const SymbolFlagsMap &getSymbols() const { return SymbolFlags; } + + /// Called by materialization dispatchers (see + /// ExecutionSession::DispatchMaterializationFunction) to trigger + /// materialization of this MaterializationUnit. + void doMaterialize(JITDylib &JD) { + materialize(MaterializationResponsibility(JD, std::move(SymbolFlags), + std::move(K))); + } + + /// Called by JITDylibs to notify MaterializationUnits that the given symbol + /// has been overridden. + void doDiscard(const JITDylib &JD, const SymbolStringPtr &Name) { + SymbolFlags.erase(Name); + discard(JD, std::move(Name)); + } + +protected: + SymbolFlagsMap SymbolFlags; + VModuleKey K; + +private: + virtual void anchor(); + + /// Implementations of this method should materialize all symbols + /// in the materialzation unit, except for those that have been + /// previously discarded. + virtual void materialize(MaterializationResponsibility R) = 0; + + /// Implementations of this method should discard the given symbol + /// from the source (e.g. if the source is an LLVM IR Module and the + /// symbol is a function, delete the function body or mark it available + /// externally). + virtual void discard(const JITDylib &JD, const SymbolStringPtr &Name) = 0; +}; + +using MaterializationUnitList = + std::vector<std::unique_ptr<MaterializationUnit>>; + +/// A MaterializationUnit implementation for pre-existing absolute symbols. +/// +/// All symbols will be resolved and marked ready as soon as the unit is +/// materialized. +class AbsoluteSymbolsMaterializationUnit : public MaterializationUnit { +public: + AbsoluteSymbolsMaterializationUnit(SymbolMap Symbols, VModuleKey K); + + StringRef getName() const override; + +private: + void materialize(MaterializationResponsibility R) override; + void discard(const JITDylib &JD, const SymbolStringPtr &Name) override; + static SymbolFlagsMap extractFlags(const SymbolMap &Symbols); + + SymbolMap Symbols; +}; + +/// Create an AbsoluteSymbolsMaterializationUnit with the given symbols. +/// Useful for inserting absolute symbols into a JITDylib. E.g.: +/// \code{.cpp} +/// JITDylib &JD = ...; +/// SymbolStringPtr Foo = ...; +/// JITEvaluatedSymbol FooSym = ...; +/// if (auto Err = JD.define(absoluteSymbols({{Foo, FooSym}}))) +/// return Err; +/// \endcode +/// +inline std::unique_ptr<AbsoluteSymbolsMaterializationUnit> +absoluteSymbols(SymbolMap Symbols, VModuleKey K = VModuleKey()) { + return std::make_unique<AbsoluteSymbolsMaterializationUnit>( + std::move(Symbols), std::move(K)); +} + +/// A materialization unit for symbol aliases. Allows existing symbols to be +/// aliased with alternate flags. 
+class ReExportsMaterializationUnit : public MaterializationUnit { +public: + /// SourceJD is allowed to be nullptr, in which case the source JITDylib is + /// taken to be whatever JITDylib these definitions are materialized in (and + /// MatchNonExported has no effect). This is useful for defining aliases + /// within a JITDylib. + /// + /// Note: Care must be taken that no sets of aliases form a cycle, as such + /// a cycle will result in a deadlock when any symbol in the cycle is + /// resolved. + ReExportsMaterializationUnit(JITDylib *SourceJD, bool MatchNonExported, + SymbolAliasMap Aliases, VModuleKey K); + + StringRef getName() const override; + +private: + void materialize(MaterializationResponsibility R) override; + void discard(const JITDylib &JD, const SymbolStringPtr &Name) override; + static SymbolFlagsMap extractFlags(const SymbolAliasMap &Aliases); + + JITDylib *SourceJD = nullptr; + bool MatchNonExported = false; + SymbolAliasMap Aliases; +}; + +/// Create a ReExportsMaterializationUnit with the given aliases. +/// Useful for defining symbol aliases.: E.g., given a JITDylib JD containing +/// symbols "foo" and "bar", we can define aliases "baz" (for "foo") and "qux" +/// (for "bar") with: \code{.cpp} +/// SymbolStringPtr Baz = ...; +/// SymbolStringPtr Qux = ...; +/// if (auto Err = JD.define(symbolAliases({ +/// {Baz, { Foo, JITSymbolFlags::Exported }}, +/// {Qux, { Bar, JITSymbolFlags::Weak }}})) +/// return Err; +/// \endcode +inline std::unique_ptr<ReExportsMaterializationUnit> +symbolAliases(SymbolAliasMap Aliases, VModuleKey K = VModuleKey()) { + return std::make_unique<ReExportsMaterializationUnit>( + nullptr, true, std::move(Aliases), std::move(K)); +} + +/// Create a materialization unit for re-exporting symbols from another JITDylib +/// with alternative names/flags. +/// If MatchNonExported is true then non-exported symbols from SourceJD can be +/// re-exported. If it is false, attempts to re-export a non-exported symbol +/// will result in a "symbol not found" error. +inline std::unique_ptr<ReExportsMaterializationUnit> +reexports(JITDylib &SourceJD, SymbolAliasMap Aliases, + bool MatchNonExported = false, VModuleKey K = VModuleKey()) { + return std::make_unique<ReExportsMaterializationUnit>( + &SourceJD, MatchNonExported, std::move(Aliases), std::move(K)); +} + +/// Build a SymbolAliasMap for the common case where you want to re-export +/// symbols from another JITDylib with the same linkage/flags. +Expected<SymbolAliasMap> +buildSimpleReexportsAliasMap(JITDylib &SourceJD, const SymbolNameSet &Symbols); + +/// Represents the state that a symbol has reached during materialization. +enum class SymbolState : uint8_t { + Invalid, /// No symbol should be in this state. + NeverSearched, /// Added to the symbol table, never queried. + Materializing, /// Queried, materialization begun. + Resolved, /// Assigned address, still materializing. + Emitted, /// Emitted to memory, but waiting on transitive dependencies. + Ready = 0x3f /// Ready and safe for clients to access. +}; + +/// A symbol query that returns results via a callback when results are +/// ready. +/// +/// makes a callback when all symbols are available. +class AsynchronousSymbolQuery { + friend class ExecutionSession; + friend class JITDylib; + friend class JITSymbolResolverAdapter; + +public: + /// Create a query for the given symbols. The NotifyComplete + /// callback will be called once all queried symbols reach the given + /// minimum state. 
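+  ///
+  /// For illustration (a sketch; assumes an interned SymbolNameSet Names):
+  /// \code{.cpp}
+  ///   auto Q = std::make_shared<AsynchronousSymbolQuery>(
+  ///       Names, SymbolState::Ready,
+  ///       [](Expected<SymbolMap> Result) {
+  ///         if (!Result)
+  ///           logAllUnhandledErrors(Result.takeError(), errs(), "query: ");
+  ///       });
+  /// \endcode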
+ AsynchronousSymbolQuery(const SymbolNameSet &Symbols, + SymbolState RequiredState, + SymbolsResolvedCallback NotifyComplete); + + /// Notify the query that a requested symbol has reached the required state. + void notifySymbolMetRequiredState(const SymbolStringPtr &Name, + JITEvaluatedSymbol Sym); + + /// Returns true if all symbols covered by this query have been + /// resolved. + bool isComplete() const { return OutstandingSymbolsCount == 0; } + + /// Call the NotifyComplete callback. + /// + /// This should only be called if all symbols covered by the query have + /// reached the specified state. + void handleComplete(); + +private: + SymbolState getRequiredState() { return RequiredState; } + + void addQueryDependence(JITDylib &JD, SymbolStringPtr Name); + + void removeQueryDependence(JITDylib &JD, const SymbolStringPtr &Name); + + bool canStillFail(); + + void handleFailed(Error Err); + + void detach(); + + SymbolsResolvedCallback NotifyComplete; + SymbolDependenceMap QueryRegistrations; + SymbolMap ResolvedSymbols; + size_t OutstandingSymbolsCount; + SymbolState RequiredState; +}; + +/// A symbol table that supports asynchoronous symbol queries. +/// +/// Represents a virtual shared object. Instances can not be copied or moved, so +/// their addresses may be used as keys for resource management. +/// JITDylib state changes must be made via an ExecutionSession to guarantee +/// that they are synchronized with respect to other JITDylib operations. +class JITDylib { + friend class AsynchronousSymbolQuery; + friend class ExecutionSession; + friend class MaterializationResponsibility; +public: + class DefinitionGenerator { + public: + virtual ~DefinitionGenerator(); + virtual Expected<SymbolNameSet> + tryToGenerate(JITDylib &Parent, const SymbolNameSet &Names) = 0; + }; + + using AsynchronousSymbolQuerySet = + std::set<std::shared_ptr<AsynchronousSymbolQuery>>; + + JITDylib(const JITDylib &) = delete; + JITDylib &operator=(const JITDylib &) = delete; + JITDylib(JITDylib &&) = delete; + JITDylib &operator=(JITDylib &&) = delete; + + /// Get the name for this JITDylib. + const std::string &getName() const { return JITDylibName; } + + /// Get a reference to the ExecutionSession for this JITDylib. + ExecutionSession &getExecutionSession() const { return ES; } + + /// Adds a definition generator to this JITDylib and returns a referenece to + /// it. + /// + /// When JITDylibs are searched during lookup, if no existing definition of + /// a symbol is found, then any generators that have been added are run (in + /// the order that they were added) to potentially generate a definition. + template <typename GeneratorT> + GeneratorT &addGenerator(std::unique_ptr<GeneratorT> DefGenerator); + + /// Remove a definition generator from this JITDylib. + /// + /// The given generator must exist in this JITDylib's generators list (i.e. + /// have been added and not yet removed). + void removeGenerator(DefinitionGenerator &G); + + /// Set the search order to be used when fixing up definitions in JITDylib. + /// This will replace the previous search order, and apply to any symbol + /// resolutions made for definitions in this JITDylib after the call to + /// setSearchOrder (even if the definition itself was added before the + /// call). + /// + /// If SearchThisJITDylibFirst is set, which by default it is, then this + /// JITDylib will add itself to the beginning of the SearchOrder (Clients + /// should *not* put this JITDylib in the list in this case, to avoid + /// redundant lookups). 
+ /// + /// If SearchThisJITDylibFirst is false then the search order will be used as + /// given. The main motivation for this feature is to support deliberate + /// shadowing of symbols in this JITDylib by a facade JITDylib. For example, + /// the facade may resolve function names to stubs, and the stubs may compile + /// lazily by looking up symbols in this dylib. Adding the facade dylib + /// as the first in the search order (instead of this dylib) ensures that + /// definitions within this dylib resolve to the lazy-compiling stubs, + /// rather than immediately materializing the definitions in this dylib. + void setSearchOrder(JITDylibSearchList NewSearchOrder, + bool SearchThisJITDylibFirst = true, + bool MatchNonExportedInThisDylib = true); + + /// Add the given JITDylib to the search order for definitions in this + /// JITDylib. + void addToSearchOrder(JITDylib &JD, bool MatcNonExported = false); + + /// Replace OldJD with NewJD in the search order if OldJD is present. + /// Otherwise this operation is a no-op. + void replaceInSearchOrder(JITDylib &OldJD, JITDylib &NewJD, + bool MatchNonExported = false); + + /// Remove the given JITDylib from the search order for this JITDylib if it is + /// present. Otherwise this operation is a no-op. + void removeFromSearchOrder(JITDylib &JD); + + /// Do something with the search order (run under the session lock). + template <typename Func> + auto withSearchOrderDo(Func &&F) + -> decltype(F(std::declval<const JITDylibSearchList &>())); + + /// Define all symbols provided by the materialization unit to be part of this + /// JITDylib. + /// + /// This overload always takes ownership of the MaterializationUnit. If any + /// errors occur, the MaterializationUnit consumed. + template <typename MaterializationUnitType> + Error define(std::unique_ptr<MaterializationUnitType> &&MU); + + /// Define all symbols provided by the materialization unit to be part of this + /// JITDylib. + /// + /// This overload only takes ownership of the MaterializationUnit no error is + /// generated. If an error occurs, ownership remains with the caller. This + /// may allow the caller to modify the MaterializationUnit to correct the + /// issue, then re-call define. + template <typename MaterializationUnitType> + Error define(std::unique_ptr<MaterializationUnitType> &MU); + + /// Tries to remove the given symbols. + /// + /// If any symbols are not defined in this JITDylib this method will return + /// a SymbolsNotFound error covering the missing symbols. + /// + /// If all symbols are found but some symbols are in the process of being + /// materialized this method will return a SymbolsCouldNotBeRemoved error. + /// + /// On success, all symbols are removed. On failure, the JITDylib state is + /// left unmodified (no symbols are removed). + Error remove(const SymbolNameSet &Names); + + /// Search the given JITDylib for the symbols in Symbols. If found, store + /// the flags for each symbol in Flags. Returns any unresolved symbols. + Expected<SymbolFlagsMap> lookupFlags(const SymbolNameSet &Names); + + /// Dump current JITDylib state to OS. + void dump(raw_ostream &OS); + + /// FIXME: Remove this when we remove the old ORC layers. + /// Search the given JITDylibs in order for the symbols in Symbols. Results + /// (once they become available) will be returned via the given Query. + /// + /// If any symbol is not found then the unresolved symbols will be returned, + /// and the query will not be applied. 
The Query is not failed and can be + /// re-used in a subsequent lookup once the symbols have been added, or + /// manually failed. + Expected<SymbolNameSet> + legacyLookup(std::shared_ptr<AsynchronousSymbolQuery> Q, SymbolNameSet Names); + +private: + using AsynchronousSymbolQueryList = + std::vector<std::shared_ptr<AsynchronousSymbolQuery>>; + + struct UnmaterializedInfo { + UnmaterializedInfo(std::unique_ptr<MaterializationUnit> MU) + : MU(std::move(MU)) {} + + std::unique_ptr<MaterializationUnit> MU; + }; + + using UnmaterializedInfosMap = + DenseMap<SymbolStringPtr, std::shared_ptr<UnmaterializedInfo>>; + + struct MaterializingInfo { + SymbolDependenceMap Dependants; + SymbolDependenceMap UnemittedDependencies; + + void addQuery(std::shared_ptr<AsynchronousSymbolQuery> Q); + void removeQuery(const AsynchronousSymbolQuery &Q); + AsynchronousSymbolQueryList takeQueriesMeeting(SymbolState RequiredState); + AsynchronousSymbolQueryList takeAllPendingQueries() { + return std::move(PendingQueries); + } + bool hasQueriesPending() const { return !PendingQueries.empty(); } + const AsynchronousSymbolQueryList &pendingQueries() const { + return PendingQueries; + } + private: + AsynchronousSymbolQueryList PendingQueries; + }; + + using MaterializingInfosMap = DenseMap<SymbolStringPtr, MaterializingInfo>; + + class SymbolTableEntry { + public: + SymbolTableEntry() = default; + SymbolTableEntry(JITSymbolFlags Flags) + : Flags(Flags), State(static_cast<uint8_t>(SymbolState::NeverSearched)), + MaterializerAttached(false), PendingRemoval(false) {} + + JITTargetAddress getAddress() const { return Addr; } + JITSymbolFlags getFlags() const { return Flags; } + SymbolState getState() const { return static_cast<SymbolState>(State); } + + bool isInMaterializationPhase() const { + return getState() == SymbolState::Materializing || + getState() == SymbolState::Resolved; + } + + bool hasMaterializerAttached() const { return MaterializerAttached; } + bool isPendingRemoval() const { return PendingRemoval; } + + void setAddress(JITTargetAddress Addr) { this->Addr = Addr; } + void setFlags(JITSymbolFlags Flags) { this->Flags = Flags; } + void setState(SymbolState State) { + assert(static_cast<uint8_t>(State) < (1 << 6) && + "State does not fit in bitfield"); + this->State = static_cast<uint8_t>(State); + } + + void setMaterializerAttached(bool MaterializerAttached) { + this->MaterializerAttached = MaterializerAttached; + } + + void setPendingRemoval(bool PendingRemoval) { + this->PendingRemoval = PendingRemoval; + } + + JITEvaluatedSymbol getSymbol() const { + return JITEvaluatedSymbol(Addr, Flags); + } + + private: + JITTargetAddress Addr = 0; + JITSymbolFlags Flags; + uint8_t State : 6; + uint8_t MaterializerAttached : 1; + uint8_t PendingRemoval : 1; + }; + + using SymbolTable = DenseMap<SymbolStringPtr, SymbolTableEntry>; + + JITDylib(ExecutionSession &ES, std::string Name); + + Error defineImpl(MaterializationUnit &MU); + + Expected<SymbolNameSet> lookupFlagsImpl(SymbolFlagsMap &Flags, + const SymbolNameSet &Names); + + Error lodgeQuery(std::shared_ptr<AsynchronousSymbolQuery> &Q, + SymbolNameSet &Unresolved, bool MatchNonExported, + MaterializationUnitList &MUs); + + Error lodgeQueryImpl(std::shared_ptr<AsynchronousSymbolQuery> &Q, + SymbolNameSet &Unresolved, bool MatchNonExported, + MaterializationUnitList &MUs); + + bool lookupImpl(std::shared_ptr<AsynchronousSymbolQuery> &Q, + std::vector<std::unique_ptr<MaterializationUnit>> &MUs, + SymbolNameSet &Unresolved); + + void 
detachQueryHelper(AsynchronousSymbolQuery &Q, + const SymbolNameSet &QuerySymbols); + + void transferEmittedNodeDependencies(MaterializingInfo &DependantMI, + const SymbolStringPtr &DependantName, + MaterializingInfo &EmittedMI); + + Error defineMaterializing(const SymbolFlagsMap &SymbolFlags); + + void replace(std::unique_ptr<MaterializationUnit> MU); + + SymbolNameSet getRequestedSymbols(const SymbolFlagsMap &SymbolFlags) const; + + void addDependencies(const SymbolStringPtr &Name, + const SymbolDependenceMap &Dependants); + + Error resolve(const SymbolMap &Resolved); + + Error emit(const SymbolFlagsMap &Emitted); + + using FailedSymbolsWorklist = + std::vector<std::pair<JITDylib *, SymbolStringPtr>>; + static void notifyFailed(FailedSymbolsWorklist FailedSymbols); + + ExecutionSession &ES; + std::string JITDylibName; + SymbolTable Symbols; + UnmaterializedInfosMap UnmaterializedInfos; + MaterializingInfosMap MaterializingInfos; + std::vector<std::unique_ptr<DefinitionGenerator>> DefGenerators; + JITDylibSearchList SearchOrder; +}; + +/// An ExecutionSession represents a running JIT program. +class ExecutionSession { + // FIXME: Remove this when we remove the old ORC layers. + friend class JITDylib; + +public: + /// For reporting errors. + using ErrorReporter = std::function<void(Error)>; + + /// For dispatching MaterializationUnit::materialize calls. + using DispatchMaterializationFunction = std::function<void( + JITDylib &JD, std::unique_ptr<MaterializationUnit> MU)>; + + /// Construct an ExecutionSession. + /// + /// SymbolStringPools may be shared between ExecutionSessions. + ExecutionSession(std::shared_ptr<SymbolStringPool> SSP = nullptr); + + /// Add a symbol name to the SymbolStringPool and return a pointer to it. + SymbolStringPtr intern(StringRef SymName) { return SSP->intern(SymName); } + + /// Returns a shared_ptr to the SymbolStringPool for this ExecutionSession. + std::shared_ptr<SymbolStringPool> getSymbolStringPool() const { return SSP; } + + /// Run the given lambda with the session mutex locked. + template <typename Func> auto runSessionLocked(Func &&F) -> decltype(F()) { + std::lock_guard<std::recursive_mutex> Lock(SessionMutex); + return F(); + } + + /// Get the "main" JITDylib, which is created automatically on construction of + /// the ExecutionSession. + JITDylib &getMainJITDylib(); + + /// Return a pointer to the "name" JITDylib. + /// Ownership of JITDylib remains within Execution Session + JITDylib *getJITDylibByName(StringRef Name); + + /// Add a new JITDylib to this ExecutionSession. + /// + /// The JITDylib Name is required to be unique. Clients should verify that + /// names are not being re-used (e.g. by calling getJITDylibByName) if names + /// are based on user input. + JITDylib &createJITDylib(std::string Name, + bool AddToMainDylibSearchOrder = true); + + /// Allocate a module key for a new module to add to the JIT. + VModuleKey allocateVModule() { + return runSessionLocked([this]() { return ++LastKey; }); + } + + /// Return a module key to the ExecutionSession so that it can be + /// re-used. This should only be done once all resources associated + /// with the original key have been released. + void releaseVModule(VModuleKey Key) { /* FIXME: Recycle keys */ + } + + /// Set the error reporter function. + ExecutionSession &setErrorReporter(ErrorReporter ReportError) { + this->ReportError = std::move(ReportError); + return *this; + } + + /// Report a error for this execution session. + /// + /// Unhandled errors can be sent here to log them. 
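+  ///
+  /// For illustration (a sketch; assumes a JITDylib JD and a
+  /// std::unique_ptr<MaterializationUnit> MU):
+  /// \code{.cpp}
+  ///   if (auto Err = JD.define(std::move(MU)))
+  ///     ES.reportError(std::move(Err));
+  /// \endcode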
+ void reportError(Error Err) { ReportError(std::move(Err)); } + + /// Set the materialization dispatch function. + ExecutionSession &setDispatchMaterialization( + DispatchMaterializationFunction DispatchMaterialization) { + this->DispatchMaterialization = std::move(DispatchMaterialization); + return *this; + } + + void legacyFailQuery(AsynchronousSymbolQuery &Q, Error Err); + + using LegacyAsyncLookupFunction = std::function<SymbolNameSet( + std::shared_ptr<AsynchronousSymbolQuery> Q, SymbolNameSet Names)>; + + /// A legacy lookup function for JITSymbolResolverAdapter. + /// Do not use -- this will be removed soon. + Expected<SymbolMap> + legacyLookup(LegacyAsyncLookupFunction AsyncLookup, SymbolNameSet Names, + SymbolState RequiredState, + RegisterDependenciesFunction RegisterDependencies); + + /// Search the given JITDylib list for the given symbols. + /// + /// SearchOrder lists the JITDylibs to search. For each dylib, the associated + /// boolean indicates whether the search should match against non-exported + /// (hidden visibility) symbols in that dylib (true means match against + /// non-exported symbols, false means do not match). + /// + /// The NotifyComplete callback will be called once all requested symbols + /// reach the required state. + /// + /// If all symbols are found, the RegisterDependencies function will be called + /// while the session lock is held. This gives clients a chance to register + /// dependencies for on the queried symbols for any symbols they are + /// materializing (if a MaterializationResponsibility instance is present, + /// this can be implemented by calling + /// MaterializationResponsibility::addDependencies). If there are no + /// dependenant symbols for this query (e.g. it is being made by a top level + /// client to get an address to call) then the value NoDependenciesToRegister + /// can be used. + void lookup(const JITDylibSearchList &SearchOrder, SymbolNameSet Symbols, + SymbolState RequiredState, SymbolsResolvedCallback NotifyComplete, + RegisterDependenciesFunction RegisterDependencies); + + /// Blocking version of lookup above. Returns the resolved symbol map. + /// If WaitUntilReady is true (the default), will not return until all + /// requested symbols are ready (or an error occurs). If WaitUntilReady is + /// false, will return as soon as all requested symbols are resolved, + /// or an error occurs. If WaitUntilReady is false and an error occurs + /// after resolution, the function will return a success value, but the + /// error will be reported via reportErrors. + Expected<SymbolMap> lookup(const JITDylibSearchList &SearchOrder, + const SymbolNameSet &Symbols, + SymbolState RequiredState = SymbolState::Ready, + RegisterDependenciesFunction RegisterDependencies = + NoDependenciesToRegister); + + /// Convenience version of blocking lookup. + /// Searches each of the JITDylibs in the search order in turn for the given + /// symbol. + Expected<JITEvaluatedSymbol> lookup(const JITDylibSearchList &SearchOrder, + SymbolStringPtr Symbol); + + /// Convenience version of blocking lookup. + /// Searches each of the JITDylibs in the search order in turn for the given + /// symbol. The search will not find non-exported symbols. + Expected<JITEvaluatedSymbol> lookup(ArrayRef<JITDylib *> SearchOrder, + SymbolStringPtr Symbol); + + /// Convenience version of blocking lookup. + /// Searches each of the JITDylibs in the search order in turn for the given + /// symbol. The search will not find non-exported symbols. 
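+  ///
+  /// For illustration (a sketch; assumes a JITDylib JD containing a JIT'd
+  /// definition of "main"):
+  /// \code{.cpp}
+  ///   auto MainSym = ES.lookup({&JD}, "main");
+  ///   if (!MainSym)
+  ///     return MainSym.takeError();
+  ///   JITTargetAddress MainAddr = MainSym->getAddress();
+  /// \endcode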
+ Expected<JITEvaluatedSymbol> lookup(ArrayRef<JITDylib *> SearchOrder, + StringRef Symbol); + + /// Materialize the given unit. + void dispatchMaterialization(JITDylib &JD, + std::unique_ptr<MaterializationUnit> MU) { + LLVM_DEBUG({ + runSessionLocked([&]() { + dbgs() << "Dispatching " << *MU << " for " << JD.getName() << "\n"; + }); + }); + DispatchMaterialization(JD, std::move(MU)); + } + + /// Dump the state of all the JITDylibs in this session. + void dump(raw_ostream &OS); + +private: + static void logErrorsToStdErr(Error Err) { + logAllUnhandledErrors(std::move(Err), errs(), "JIT session error: "); + } + + static void + materializeOnCurrentThread(JITDylib &JD, + std::unique_ptr<MaterializationUnit> MU) { + MU->doMaterialize(JD); + } + + void runOutstandingMUs(); + + mutable std::recursive_mutex SessionMutex; + std::shared_ptr<SymbolStringPool> SSP; + VModuleKey LastKey = 0; + ErrorReporter ReportError = logErrorsToStdErr; + DispatchMaterializationFunction DispatchMaterialization = + materializeOnCurrentThread; + + std::vector<std::unique_ptr<JITDylib>> JDs; + + // FIXME: Remove this (and runOutstandingMUs) once the linking layer works + // with callbacks from asynchronous queries. + mutable std::recursive_mutex OutstandingMUsMutex; + std::vector<std::pair<JITDylib *, std::unique_ptr<MaterializationUnit>>> + OutstandingMUs; +}; + +template <typename GeneratorT> +GeneratorT &JITDylib::addGenerator(std::unique_ptr<GeneratorT> DefGenerator) { + auto &G = *DefGenerator; + ES.runSessionLocked( + [&]() { DefGenerators.push_back(std::move(DefGenerator)); }); + return G; +} + +template <typename Func> +auto JITDylib::withSearchOrderDo(Func &&F) + -> decltype(F(std::declval<const JITDylibSearchList &>())) { + return ES.runSessionLocked([&]() { return F(SearchOrder); }); +} + +template <typename MaterializationUnitType> +Error JITDylib::define(std::unique_ptr<MaterializationUnitType> &&MU) { + assert(MU && "Can not define with a null MU"); + return ES.runSessionLocked([&, this]() -> Error { + if (auto Err = defineImpl(*MU)) + return Err; + + /// defineImpl succeeded. + auto UMI = std::make_shared<UnmaterializedInfo>(std::move(MU)); + for (auto &KV : UMI->MU->getSymbols()) + UnmaterializedInfos[KV.first] = UMI; + + return Error::success(); + }); +} + +template <typename MaterializationUnitType> +Error JITDylib::define(std::unique_ptr<MaterializationUnitType> &MU) { + assert(MU && "Can not define with a null MU"); + + return ES.runSessionLocked([&, this]() -> Error { + if (auto Err = defineImpl(*MU)) + return Err; + + /// defineImpl succeeded. + auto UMI = std::make_shared<UnmaterializedInfo>(std::move(MU)); + for (auto &KV : UMI->MU->getSymbols()) + UnmaterializedInfos[KV.first] = UMI; + + return Error::success(); + }); +} + +/// ReexportsGenerator can be used with JITDylib::setGenerator to automatically +/// re-export a subset of the source JITDylib's symbols in the target. +class ReexportsGenerator : public JITDylib::DefinitionGenerator { +public: + using SymbolPredicate = std::function<bool(SymbolStringPtr)>; + + /// Create a reexports generator. If an Allow predicate is passed, only + /// symbols for which the predicate returns true will be reexported. If no + /// Allow predicate is passed, all symbols will be exported. 
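+  ///
+  /// For illustration (a sketch; assumes JITDylibs SourceJD and TargetJD):
+  /// \code{.cpp}
+  ///   TargetJD.addGenerator(
+  ///       std::make_unique<ReexportsGenerator>(SourceJD));
+  /// \endcode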
+ ReexportsGenerator(JITDylib &SourceJD, bool MatchNonExported = false, + SymbolPredicate Allow = SymbolPredicate()); + + Expected<SymbolNameSet> tryToGenerate(JITDylib &JD, + const SymbolNameSet &Names) override; + +private: + JITDylib &SourceJD; + bool MatchNonExported = false; + SymbolPredicate Allow; +}; + +/// Mangles symbol names then uniques them in the context of an +/// ExecutionSession. +class MangleAndInterner { +public: + MangleAndInterner(ExecutionSession &ES, const DataLayout &DL); + SymbolStringPtr operator()(StringRef Name); + +private: + ExecutionSession &ES; + const DataLayout &DL; +}; + +} // End namespace orc +} // End namespace llvm + +#undef DEBUG_TYPE // "orc" + +#endif // LLVM_EXECUTIONENGINE_ORC_CORE_H diff --git a/llvm/include/llvm/ExecutionEngine/Orc/ExecutionUtils.h b/llvm/include/llvm/ExecutionEngine/Orc/ExecutionUtils.h new file mode 100644 index 000000000000..cf0a428662ef --- /dev/null +++ b/llvm/include/llvm/ExecutionEngine/Orc/ExecutionUtils.h @@ -0,0 +1,317 @@ +//===- ExecutionUtils.h - Utilities for executing code in Orc ---*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Contains utilities for executing code in Orc. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_EXECUTIONENGINE_ORC_EXECUTIONUTILS_H +#define LLVM_EXECUTIONENGINE_ORC_EXECUTIONUTILS_H + +#include "llvm/ADT/StringMap.h" +#include "llvm/ADT/iterator_range.h" +#include "llvm/ExecutionEngine/JITSymbol.h" +#include "llvm/ExecutionEngine/Orc/Core.h" +#include "llvm/ExecutionEngine/Orc/OrcError.h" +#include "llvm/ExecutionEngine/RuntimeDyld.h" +#include "llvm/Object/Archive.h" +#include "llvm/Support/DynamicLibrary.h" +#include <algorithm> +#include <cstdint> +#include <string> +#include <utility> +#include <vector> + +namespace llvm { + +class ConstantArray; +class GlobalVariable; +class Function; +class Module; +class TargetMachine; +class Value; + +namespace orc { + +class ObjectLayer; + +/// This iterator provides a convenient way to iterate over the elements +/// of an llvm.global_ctors/llvm.global_dtors instance. +/// +/// The easiest way to get hold of instances of this class is to use the +/// getConstructors/getDestructors functions. +class CtorDtorIterator { +public: + /// Accessor for an element of the global_ctors/global_dtors array. + /// + /// This class provides a read-only view of the element with any casts on + /// the function stripped away. + struct Element { + Element(unsigned Priority, Function *Func, Value *Data) + : Priority(Priority), Func(Func), Data(Data) {} + + unsigned Priority; + Function *Func; + Value *Data; + }; + + /// Construct an iterator instance. If End is true then this iterator + /// acts as the end of the range, otherwise it is the beginning. + CtorDtorIterator(const GlobalVariable *GV, bool End); + + /// Test iterators for equality. + bool operator==(const CtorDtorIterator &Other) const; + + /// Test iterators for inequality. + bool operator!=(const CtorDtorIterator &Other) const; + + /// Pre-increment iterator. + CtorDtorIterator& operator++(); + + /// Post-increment iterator. + CtorDtorIterator operator++(int); + + /// Dereference iterator. 
The resulting value provides a read-only view + /// of this element of the global_ctors/global_dtors list. + Element operator*() const; + +private: + const ConstantArray *InitList; + unsigned I; +}; + +/// Create an iterator range over the entries of the llvm.global_ctors +/// array. +iterator_range<CtorDtorIterator> getConstructors(const Module &M); + +/// Create an iterator range over the entries of the llvm.global_ctors +/// array. +iterator_range<CtorDtorIterator> getDestructors(const Module &M); + +/// Convenience class for recording constructor/destructor names for +/// later execution. +template <typename JITLayerT> +class LegacyCtorDtorRunner { +public: + /// Construct a CtorDtorRunner for the given range using the given + /// name mangling function. + LLVM_ATTRIBUTE_DEPRECATED( + LegacyCtorDtorRunner(std::vector<std::string> CtorDtorNames, + VModuleKey K), + "ORCv1 utilities (utilities with the 'Legacy' prefix) are deprecated. " + "Please use the ORCv2 CtorDtorRunner utility instead"); + + LegacyCtorDtorRunner(ORCv1DeprecationAcknowledgement, + std::vector<std::string> CtorDtorNames, VModuleKey K) + : CtorDtorNames(std::move(CtorDtorNames)), K(K) {} + + /// Run the recorded constructors/destructors through the given JIT + /// layer. + Error runViaLayer(JITLayerT &JITLayer) const { + using CtorDtorTy = void (*)(); + + for (const auto &CtorDtorName : CtorDtorNames) { + if (auto CtorDtorSym = JITLayer.findSymbolIn(K, CtorDtorName, false)) { + if (auto AddrOrErr = CtorDtorSym.getAddress()) { + CtorDtorTy CtorDtor = + reinterpret_cast<CtorDtorTy>(static_cast<uintptr_t>(*AddrOrErr)); + CtorDtor(); + } else + return AddrOrErr.takeError(); + } else { + if (auto Err = CtorDtorSym.takeError()) + return Err; + else + return make_error<JITSymbolNotFound>(CtorDtorName); + } + } + return Error::success(); + } + +private: + std::vector<std::string> CtorDtorNames; + orc::VModuleKey K; +}; + +template <typename JITLayerT> +LegacyCtorDtorRunner<JITLayerT>::LegacyCtorDtorRunner( + std::vector<std::string> CtorDtorNames, VModuleKey K) + : CtorDtorNames(std::move(CtorDtorNames)), K(K) {} + +class CtorDtorRunner { +public: + CtorDtorRunner(JITDylib &JD) : JD(JD) {} + void add(iterator_range<CtorDtorIterator> CtorDtors); + Error run(); + +private: + using CtorDtorList = std::vector<SymbolStringPtr>; + using CtorDtorPriorityMap = std::map<unsigned, CtorDtorList>; + + JITDylib &JD; + CtorDtorPriorityMap CtorDtorsByPriority; +}; + +/// Support class for static dtor execution. For hosted (in-process) JITs +/// only! +/// +/// If a __cxa_atexit function isn't found C++ programs that use static +/// destructors will fail to link. However, we don't want to use the host +/// process's __cxa_atexit, because it will schedule JIT'd destructors to run +/// after the JIT has been torn down, which is no good. This class makes it easy +/// to override __cxa_atexit (and the related __dso_handle). +/// +/// To use, clients should manually call searchOverrides from their symbol +/// resolver. This should generally be done after attempting symbol resolution +/// inside the JIT, but before searching the host process's symbol table. When +/// the client determines that destructors should be run (generally at JIT +/// teardown or after a return from main), the runDestructors method should be +/// called. +class LocalCXXRuntimeOverridesBase { +public: + /// Run any destructors recorded by the overriden __cxa_atexit function + /// (CXAAtExitOverride). 
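+  ///
+  /// For illustration with the LocalCXXRuntimeOverrides subclass below (a
+  /// sketch; assumes a JITDylib JD and a MangleAndInterner Mangle):
+  /// \code{.cpp}
+  ///   LocalCXXRuntimeOverrides CXXOverrides;
+  ///   if (auto Err = CXXOverrides.enable(JD, Mangle))
+  ///     return Err;
+  ///   // ... run JIT'd code containing static destructors ...
+  ///   CXXOverrides.runDestructors();
+  /// \endcode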
+ void runDestructors(); + +protected: + template <typename PtrTy> JITTargetAddress toTargetAddress(PtrTy *P) { + return static_cast<JITTargetAddress>(reinterpret_cast<uintptr_t>(P)); + } + + using DestructorPtr = void (*)(void *); + using CXXDestructorDataPair = std::pair<DestructorPtr, void *>; + using CXXDestructorDataPairList = std::vector<CXXDestructorDataPair>; + CXXDestructorDataPairList DSOHandleOverride; + static int CXAAtExitOverride(DestructorPtr Destructor, void *Arg, + void *DSOHandle); +}; + +class LegacyLocalCXXRuntimeOverrides : public LocalCXXRuntimeOverridesBase { +public: + /// Create a runtime-overrides class. + template <typename MangleFtorT> + LLVM_ATTRIBUTE_DEPRECATED( + LegacyLocalCXXRuntimeOverrides(const MangleFtorT &Mangle), + "ORCv1 utilities (utilities with the 'Legacy' prefix) are deprecated. " + "Please use the ORCv2 LocalCXXRuntimeOverrides utility instead"); + + template <typename MangleFtorT> + LegacyLocalCXXRuntimeOverrides(ORCv1DeprecationAcknowledgement, + const MangleFtorT &Mangle) { + addOverride(Mangle("__dso_handle"), toTargetAddress(&DSOHandleOverride)); + addOverride(Mangle("__cxa_atexit"), toTargetAddress(&CXAAtExitOverride)); + } + + /// Search overrided symbols. + JITEvaluatedSymbol searchOverrides(const std::string &Name) { + auto I = CXXRuntimeOverrides.find(Name); + if (I != CXXRuntimeOverrides.end()) + return JITEvaluatedSymbol(I->second, JITSymbolFlags::Exported); + return nullptr; + } + +private: + void addOverride(const std::string &Name, JITTargetAddress Addr) { + CXXRuntimeOverrides.insert(std::make_pair(Name, Addr)); + } + + StringMap<JITTargetAddress> CXXRuntimeOverrides; +}; + +template <typename MangleFtorT> +LegacyLocalCXXRuntimeOverrides::LegacyLocalCXXRuntimeOverrides( + const MangleFtorT &Mangle) { + addOverride(Mangle("__dso_handle"), toTargetAddress(&DSOHandleOverride)); + addOverride(Mangle("__cxa_atexit"), toTargetAddress(&CXAAtExitOverride)); +} + +class LocalCXXRuntimeOverrides : public LocalCXXRuntimeOverridesBase { +public: + Error enable(JITDylib &JD, MangleAndInterner &Mangler); +}; + +/// A utility class to expose symbols found via dlsym to the JIT. +/// +/// If an instance of this class is attached to a JITDylib as a fallback +/// definition generator, then any symbol found in the given DynamicLibrary that +/// passes the 'Allow' predicate will be added to the JITDylib. +class DynamicLibrarySearchGenerator : public JITDylib::DefinitionGenerator { +public: + using SymbolPredicate = std::function<bool(SymbolStringPtr)>; + + /// Create a DynamicLibrarySearchGenerator that searches for symbols in the + /// given sys::DynamicLibrary. + /// + /// If the Allow predicate is given then only symbols matching the predicate + /// will be searched for. If the predicate is not given then all symbols will + /// be searched for. + DynamicLibrarySearchGenerator(sys::DynamicLibrary Dylib, char GlobalPrefix, + SymbolPredicate Allow = SymbolPredicate()); + + /// Permanently loads the library at the given path and, on success, returns + /// a DynamicLibrarySearchGenerator that will search it for symbol definitions + /// in the library. On failure returns the reason the library failed to load. + static Expected<std::unique_ptr<DynamicLibrarySearchGenerator>> + Load(const char *FileName, char GlobalPrefix, + SymbolPredicate Allow = SymbolPredicate()); + + /// Creates a DynamicLibrarySearchGenerator that searches for symbols in + /// the current process. 
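// [Editorial usage sketch -- not part of this header] For the ORCv2
// LocalCXXRuntimeOverrides class above. Assumes `JD` is the JITDylib that will
// hold the JIT'd code and `Mangle` is a MangleAndInterner for the session.
//
//   LocalCXXRuntimeOverrides CXXOverrides;
//   if (auto Err = CXXOverrides.enable(JD, Mangle)) // define __dso_handle and
//     return Err;                                   // __cxa_atexit in JD
//   // ... run JIT'd code that registers static destructors ...
//   CXXOverrides.runDestructors(); // at teardown, before destroying the JIT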
+ static Expected<std::unique_ptr<DynamicLibrarySearchGenerator>> + GetForCurrentProcess(char GlobalPrefix, + SymbolPredicate Allow = SymbolPredicate()) { + return Load(nullptr, GlobalPrefix, std::move(Allow)); + } + + Expected<SymbolNameSet> tryToGenerate(JITDylib &JD, + const SymbolNameSet &Names) override; + +private: + sys::DynamicLibrary Dylib; + SymbolPredicate Allow; + char GlobalPrefix; +}; + +/// A utility class to expose symbols from a static library. +/// +/// If an instance of this class is attached to a JITDylib as a fallback +/// definition generator, then any symbol found in the archive will result in +/// the containing object being added to the JITDylib. +class StaticLibraryDefinitionGenerator : public JITDylib::DefinitionGenerator { +public: + /// Try to create a StaticLibraryDefinitionGenerator from the given path. + /// + /// This call will succeed if the file at the given path is a static library + /// is a valid archive, otherwise it will return an error. + static Expected<std::unique_ptr<StaticLibraryDefinitionGenerator>> + Load(ObjectLayer &L, const char *FileName); + + /// Try to create a StaticLibrarySearchGenerator from the given memory buffer. + /// Thhis call will succeed if the buffer contains a valid archive, otherwise + /// it will return an error. + static Expected<std::unique_ptr<StaticLibraryDefinitionGenerator>> + Create(ObjectLayer &L, std::unique_ptr<MemoryBuffer> ArchiveBuffer); + + Expected<SymbolNameSet> tryToGenerate(JITDylib &JD, + const SymbolNameSet &Names) override; + +private: + StaticLibraryDefinitionGenerator(ObjectLayer &L, + std::unique_ptr<MemoryBuffer> ArchiveBuffer, + Error &Err); + + ObjectLayer &L; + std::unique_ptr<MemoryBuffer> ArchiveBuffer; + object::Archive Archive; + size_t UnrealizedObjects = 0; +}; + +} // end namespace orc +} // end namespace llvm + +#endif // LLVM_EXECUTIONENGINE_ORC_EXECUTIONUTILS_H diff --git a/llvm/include/llvm/ExecutionEngine/Orc/GlobalMappingLayer.h b/llvm/include/llvm/ExecutionEngine/Orc/GlobalMappingLayer.h new file mode 100644 index 000000000000..a4e43d4e1c9c --- /dev/null +++ b/llvm/include/llvm/ExecutionEngine/Orc/GlobalMappingLayer.h @@ -0,0 +1,111 @@ +//===- GlobalMappingLayer.h - Run all IR through a functor ------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Convenience layer for injecting symbols that will appear in calls to +// findSymbol. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_EXECUTIONENGINE_ORC_GLOBALMAPPINGLAYER_H +#define LLVM_EXECUTIONENGINE_ORC_GLOBALMAPPINGLAYER_H + +#include "llvm/ExecutionEngine/JITSymbol.h" +#include <map> +#include <memory> +#include <string> + +namespace llvm { + +class Module; +class JITSymbolResolver; + +namespace orc { + +/// Global mapping layer. +/// +/// This layer overrides the findSymbol method to first search a local symbol +/// table that the client can define. It can be used to inject new symbol +/// mappings into the JIT. Beware, however: symbols within a single IR module or +/// object file will still resolve locally (via RuntimeDyld's symbol table) - +/// such internal references cannot be overriden via this layer. +template <typename BaseLayerT> +class GlobalMappingLayer { +public: + + /// Handle to an added module. 
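// [Editorial usage sketch -- not part of this header] For the
// DynamicLibrarySearchGenerator and StaticLibraryDefinitionGenerator classes
// above. Assumes `DL` is the process DataLayout, `JD` a JITDylib, `ObjLayer` an
// ObjectLayer, and that the generator is installed via JITDylib::setGenerator
// (the installation API declared in Core.h in this snapshot).
//
//   auto ProcessSymsGen =
//       DynamicLibrarySearchGenerator::GetForCurrentProcess(DL.getGlobalPrefix());
//   if (!ProcessSymsGen)
//     return ProcessSymsGen.takeError();
//   JD.setGenerator(std::move(*ProcessSymsGen)); // fall back to process symbols
//
//   // For a static archive instead:
//   //   auto ArchiveGen =
//   //       StaticLibraryDefinitionGenerator::Load(ObjLayer, "libfoo.a");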
+ using ModuleHandleT = typename BaseLayerT::ModuleHandleT; + + /// Construct an GlobalMappingLayer with the given BaseLayer + GlobalMappingLayer(BaseLayerT &BaseLayer) : BaseLayer(BaseLayer) {} + + /// Add the given module to the JIT. + /// @return A handle for the added modules. + Expected<ModuleHandleT> + addModule(std::shared_ptr<Module> M, + std::shared_ptr<JITSymbolResolver> Resolver) { + return BaseLayer.addModule(std::move(M), std::move(Resolver)); + } + + /// Remove the module set associated with the handle H. + Error removeModule(ModuleHandleT H) { return BaseLayer.removeModule(H); } + + /// Manually set the address to return for the given symbol. + void setGlobalMapping(const std::string &Name, JITTargetAddress Addr) { + SymbolTable[Name] = Addr; + } + + /// Remove the given symbol from the global mapping. + void eraseGlobalMapping(const std::string &Name) { + SymbolTable.erase(Name); + } + + /// Search for the given named symbol. + /// + /// This method will first search the local symbol table, returning + /// any symbol found there. If the symbol is not found in the local + /// table then this call will be passed through to the base layer. + /// + /// @param Name The name of the symbol to search for. + /// @param ExportedSymbolsOnly If true, search only for exported symbols. + /// @return A handle for the given named symbol, if it exists. + JITSymbol findSymbol(const std::string &Name, bool ExportedSymbolsOnly) { + auto I = SymbolTable.find(Name); + if (I != SymbolTable.end()) + return JITSymbol(I->second, JITSymbolFlags::Exported); + return BaseLayer.findSymbol(Name, ExportedSymbolsOnly); + } + + /// Get the address of the given symbol in the context of the of the + /// module represented by the handle H. This call is forwarded to the + /// base layer's implementation. + /// @param H The handle for the module to search in. + /// @param Name The name of the symbol to search for. + /// @param ExportedSymbolsOnly If true, search only for exported symbols. + /// @return A handle for the given named symbol, if it is found in the + /// given module. + JITSymbol findSymbolIn(ModuleHandleT H, const std::string &Name, + bool ExportedSymbolsOnly) { + return BaseLayer.findSymbolIn(H, Name, ExportedSymbolsOnly); + } + + /// Immediately emit and finalize the module set represented by the + /// given handle. + /// @param H Handle for module set to emit/finalize. + Error emitAndFinalize(ModuleHandleT H) { + return BaseLayer.emitAndFinalize(H); + } + +private: + BaseLayerT &BaseLayer; + std::map<std::string, JITTargetAddress> SymbolTable; +}; + +} // end namespace orc +} // end namespace llvm + +#endif // LLVM_EXECUTIONENGINE_ORC_GLOBALMAPPINGLAYER_H diff --git a/llvm/include/llvm/ExecutionEngine/Orc/IRCompileLayer.h b/llvm/include/llvm/ExecutionEngine/Orc/IRCompileLayer.h new file mode 100644 index 000000000000..52223a83ad42 --- /dev/null +++ b/llvm/include/llvm/ExecutionEngine/Orc/IRCompileLayer.h @@ -0,0 +1,145 @@ +//===- IRCompileLayer.h -- Eagerly compile IR for JIT -----------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Contains the definition for a basic, eagerly compiling layer of the JIT. 
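// [Editorial usage sketch -- not part of this header] For the GlobalMappingLayer
// template defined above (GlobalMappingLayer.h). `BaseLayer`, its type
// `BaseLayerT`, and the symbol name and address used here are placeholders
// supplied by the client.
//
//   GlobalMappingLayer<BaseLayerT> MappingLayer(BaseLayer);
//   JITTargetAddress HelperAddr = ...;  // address of a host-process function
//   MappingLayer.setGlobalMapping("_host_helper", HelperAddr);
//   auto Sym = MappingLayer.findSymbol("_host_helper",
//                                      /*ExportedSymbolsOnly=*/true);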
+// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_EXECUTIONENGINE_ORC_IRCOMPILELAYER_H +#define LLVM_EXECUTIONENGINE_ORC_IRCOMPILELAYER_H + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ExecutionEngine/JITSymbol.h" +#include "llvm/ExecutionEngine/Orc/Layer.h" +#include "llvm/Support/Error.h" +#include "llvm/Support/MemoryBuffer.h" +#include <memory> +#include <string> + +namespace llvm { + +class Module; + +namespace orc { + +class IRCompileLayer : public IRLayer { +public: + using CompileFunction = + std::function<Expected<std::unique_ptr<MemoryBuffer>>(Module &)>; + + using NotifyCompiledFunction = + std::function<void(VModuleKey K, ThreadSafeModule TSM)>; + + IRCompileLayer(ExecutionSession &ES, ObjectLayer &BaseLayer, + CompileFunction Compile); + + void setNotifyCompiled(NotifyCompiledFunction NotifyCompiled); + + void emit(MaterializationResponsibility R, ThreadSafeModule TSM) override; + +private: + mutable std::mutex IRLayerMutex; + ObjectLayer &BaseLayer; + CompileFunction Compile; + NotifyCompiledFunction NotifyCompiled = NotifyCompiledFunction(); +}; + +/// Eager IR compiling layer. +/// +/// This layer immediately compiles each IR module added via addModule to an +/// object file and adds this module file to the layer below, which must +/// implement the object layer concept. +template <typename BaseLayerT, typename CompileFtor> +class LegacyIRCompileLayer { +public: + /// Callback type for notifications when modules are compiled. + using NotifyCompiledCallback = + std::function<void(VModuleKey K, std::unique_ptr<Module>)>; + + /// Construct an LegacyIRCompileLayer with the given BaseLayer, which must + /// implement the ObjectLayer concept. + LLVM_ATTRIBUTE_DEPRECATED( + LegacyIRCompileLayer( + BaseLayerT &BaseLayer, CompileFtor Compile, + NotifyCompiledCallback NotifyCompiled = NotifyCompiledCallback()), + "ORCv1 layers (layers with the 'Legacy' prefix) are deprecated. Please " + "use " + "the ORCv2 IRCompileLayer instead"); + + /// Legacy layer constructor with deprecation acknowledgement. + LegacyIRCompileLayer( + ORCv1DeprecationAcknowledgement, BaseLayerT &BaseLayer, + CompileFtor Compile, + NotifyCompiledCallback NotifyCompiled = NotifyCompiledCallback()) + : BaseLayer(BaseLayer), Compile(std::move(Compile)), + NotifyCompiled(std::move(NotifyCompiled)) {} + + /// Get a reference to the compiler functor. + CompileFtor& getCompiler() { return Compile; } + + /// (Re)set the NotifyCompiled callback. + void setNotifyCompiled(NotifyCompiledCallback NotifyCompiled) { + this->NotifyCompiled = std::move(NotifyCompiled); + } + + /// Compile the module, and add the resulting object to the base layer + /// along with the given memory manager and symbol resolver. + Error addModule(VModuleKey K, std::unique_ptr<Module> M) { + if (auto Err = BaseLayer.addObject(std::move(K), Compile(*M))) + return Err; + if (NotifyCompiled) + NotifyCompiled(std::move(K), std::move(M)); + return Error::success(); + } + + /// Remove the module associated with the VModuleKey K. + Error removeModule(VModuleKey K) { return BaseLayer.removeObject(K); } + + /// Search for the given named symbol. + /// @param Name The name of the symbol to search for. + /// @param ExportedSymbolsOnly If true, search only for exported symbols. + /// @return A handle for the given named symbol, if it exists. 
+ JITSymbol findSymbol(const std::string &Name, bool ExportedSymbolsOnly) { + return BaseLayer.findSymbol(Name, ExportedSymbolsOnly); + } + + /// Get the address of the given symbol in compiled module represented + /// by the handle H. This call is forwarded to the base layer's + /// implementation. + /// @param K The VModuleKey for the module to search in. + /// @param Name The name of the symbol to search for. + /// @param ExportedSymbolsOnly If true, search only for exported symbols. + /// @return A handle for the given named symbol, if it is found in the + /// given module. + JITSymbol findSymbolIn(VModuleKey K, const std::string &Name, + bool ExportedSymbolsOnly) { + return BaseLayer.findSymbolIn(K, Name, ExportedSymbolsOnly); + } + + /// Immediately emit and finalize the module represented by the given + /// handle. + /// @param K The VModuleKey for the module to emit/finalize. + Error emitAndFinalize(VModuleKey K) { return BaseLayer.emitAndFinalize(K); } + +private: + BaseLayerT &BaseLayer; + CompileFtor Compile; + NotifyCompiledCallback NotifyCompiled; +}; + +template <typename BaseLayerT, typename CompileFtor> +LegacyIRCompileLayer<BaseLayerT, CompileFtor>::LegacyIRCompileLayer( + BaseLayerT &BaseLayer, CompileFtor Compile, + NotifyCompiledCallback NotifyCompiled) + : BaseLayer(BaseLayer), Compile(std::move(Compile)), + NotifyCompiled(std::move(NotifyCompiled)) {} + +} // end namespace orc +} // end namespace llvm + +#endif // LLVM_EXECUTIONENGINE_ORC_IRCOMPILINGLAYER_H diff --git a/llvm/include/llvm/ExecutionEngine/Orc/IRTransformLayer.h b/llvm/include/llvm/ExecutionEngine/Orc/IRTransformLayer.h new file mode 100644 index 000000000000..b71e5b339711 --- /dev/null +++ b/llvm/include/llvm/ExecutionEngine/Orc/IRTransformLayer.h @@ -0,0 +1,130 @@ +//===- IRTransformLayer.h - Run all IR through a functor --------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Run all IR passed in through a user supplied functor. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_EXECUTIONENGINE_ORC_IRTRANSFORMLAYER_H +#define LLVM_EXECUTIONENGINE_ORC_IRTRANSFORMLAYER_H + +#include "llvm/ExecutionEngine/JITSymbol.h" +#include "llvm/ExecutionEngine/Orc/Layer.h" +#include <memory> +#include <string> + +namespace llvm { +class Module; +namespace orc { + +/// A layer that applies a transform to emitted modules. +/// The transform function is responsible for locking the ThreadSafeContext +/// before operating on the module. +class IRTransformLayer : public IRLayer { +public: + using TransformFunction = std::function<Expected<ThreadSafeModule>( + ThreadSafeModule, const MaterializationResponsibility &R)>; + + IRTransformLayer(ExecutionSession &ES, IRLayer &BaseLayer, + TransformFunction Transform = identityTransform); + + void setTransform(TransformFunction Transform) { + this->Transform = std::move(Transform); + } + + void emit(MaterializationResponsibility R, ThreadSafeModule TSM) override; + + static ThreadSafeModule + identityTransform(ThreadSafeModule TSM, + const MaterializationResponsibility &R) { + return TSM; + } + +private: + IRLayer &BaseLayer; + TransformFunction Transform; +}; + +/// IR mutating layer. 
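// [Editorial usage sketch -- not part of this header] For the ORCv2
// IRTransformLayer above. Assumes `ES` is an ExecutionSession and `CompileLayer`
// an IRLayer to forward the transformed module to.
//
//   IRTransformLayer OptLayer(ES, CompileLayer,
//       [](ThreadSafeModule TSM, const MaterializationResponsibility &R)
//           -> Expected<ThreadSafeModule> {
//         // Lock the ThreadSafeContext, run an optimization pipeline over the
//         // module, then hand it on to the base layer.
//         return std::move(TSM);
//       });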
+/// +/// This layer applies a user supplied transform to each module that is added, +/// then adds the transformed module to the layer below. +template <typename BaseLayerT, typename TransformFtor> +class LegacyIRTransformLayer { +public: + + /// Construct an LegacyIRTransformLayer with the given BaseLayer + LLVM_ATTRIBUTE_DEPRECATED( + LegacyIRTransformLayer(BaseLayerT &BaseLayer, + TransformFtor Transform = TransformFtor()), + "ORCv1 layers (layers with the 'Legacy' prefix) are deprecated. Please " + "use " + "the ORCv2 IRTransformLayer instead"); + + /// Legacy layer constructor with deprecation acknowledgement. + LegacyIRTransformLayer(ORCv1DeprecationAcknowledgement, BaseLayerT &BaseLayer, + TransformFtor Transform = TransformFtor()) + : BaseLayer(BaseLayer), Transform(std::move(Transform)) {} + + /// Apply the transform functor to the module, then add the module to + /// the layer below, along with the memory manager and symbol resolver. + /// + /// @return A handle for the added modules. + Error addModule(VModuleKey K, std::unique_ptr<Module> M) { + return BaseLayer.addModule(std::move(K), Transform(std::move(M))); + } + + /// Remove the module associated with the VModuleKey K. + Error removeModule(VModuleKey K) { return BaseLayer.removeModule(K); } + + /// Search for the given named symbol. + /// @param Name The name of the symbol to search for. + /// @param ExportedSymbolsOnly If true, search only for exported symbols. + /// @return A handle for the given named symbol, if it exists. + JITSymbol findSymbol(const std::string &Name, bool ExportedSymbolsOnly) { + return BaseLayer.findSymbol(Name, ExportedSymbolsOnly); + } + + /// Get the address of the given symbol in the context of the module + /// represented by the VModuleKey K. This call is forwarded to the base + /// layer's implementation. + /// @param K The VModuleKey for the module to search in. + /// @param Name The name of the symbol to search for. + /// @param ExportedSymbolsOnly If true, search only for exported symbols. + /// @return A handle for the given named symbol, if it is found in the + /// given module. + JITSymbol findSymbolIn(VModuleKey K, const std::string &Name, + bool ExportedSymbolsOnly) { + return BaseLayer.findSymbolIn(K, Name, ExportedSymbolsOnly); + } + + /// Immediately emit and finalize the module represented by the given + /// VModuleKey. + /// @param K The VModuleKey for the module to emit/finalize. + Error emitAndFinalize(VModuleKey K) { return BaseLayer.emitAndFinalize(K); } + + /// Access the transform functor directly. + TransformFtor& getTransform() { return Transform; } + + /// Access the mumate functor directly. 
+ const TransformFtor& getTransform() const { return Transform; } + +private: + BaseLayerT &BaseLayer; + TransformFtor Transform; +}; + +template <typename BaseLayerT, typename TransformFtor> +LegacyIRTransformLayer<BaseLayerT, TransformFtor>::LegacyIRTransformLayer( + BaseLayerT &BaseLayer, TransformFtor Transform) + : BaseLayer(BaseLayer), Transform(std::move(Transform)) {} + +} // end namespace orc +} // end namespace llvm + +#endif // LLVM_EXECUTIONENGINE_ORC_IRTRANSFORMLAYER_H diff --git a/llvm/include/llvm/ExecutionEngine/Orc/IndirectionUtils.h b/llvm/include/llvm/ExecutionEngine/Orc/IndirectionUtils.h new file mode 100644 index 000000000000..a7ed5372d1e4 --- /dev/null +++ b/llvm/include/llvm/ExecutionEngine/Orc/IndirectionUtils.h @@ -0,0 +1,493 @@ +//===- IndirectionUtils.h - Utilities for adding indirections ---*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Contains utilities for adding indirections and breaking up modules. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_EXECUTIONENGINE_ORC_INDIRECTIONUTILS_H +#define LLVM_EXECUTIONENGINE_ORC_INDIRECTIONUTILS_H + +#include "llvm/ADT/StringMap.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/Twine.h" +#include "llvm/ExecutionEngine/JITSymbol.h" +#include "llvm/ExecutionEngine/Orc/Core.h" +#include "llvm/Support/Error.h" +#include "llvm/Support/Memory.h" +#include "llvm/Support/Process.h" +#include "llvm/Transforms/Utils/ValueMapper.h" +#include <algorithm> +#include <cassert> +#include <cstdint> +#include <functional> +#include <map> +#include <memory> +#include <system_error> +#include <utility> +#include <vector> + +namespace llvm { + +class Constant; +class Function; +class FunctionType; +class GlobalAlias; +class GlobalVariable; +class Module; +class PointerType; +class Triple; +class Value; + +namespace orc { + +/// Base class for pools of compiler re-entry trampolines. +/// These trampolines are callable addresses that save all register state +/// before calling a supplied function to return the trampoline landing +/// address, then restore all state before jumping to that address. They +/// are used by various ORC APIs to support lazy compilation +class TrampolinePool { +public: + virtual ~TrampolinePool() {} + + /// Get an available trampoline address. + /// Returns an error if no trampoline can be created. + virtual Expected<JITTargetAddress> getTrampoline() = 0; + +private: + virtual void anchor(); +}; + +/// A trampoline pool for trampolines within the current process. +template <typename ORCABI> class LocalTrampolinePool : public TrampolinePool { +public: + using GetTrampolineLandingFunction = + std::function<JITTargetAddress(JITTargetAddress TrampolineAddr)>; + + /// Creates a LocalTrampolinePool with the given RunCallback function. + /// Returns an error if this function is unable to correctly allocate, write + /// and protect the resolver code block. 
+ static Expected<std::unique_ptr<LocalTrampolinePool>> + Create(GetTrampolineLandingFunction GetTrampolineLanding) { + Error Err = Error::success(); + + auto LTP = std::unique_ptr<LocalTrampolinePool>( + new LocalTrampolinePool(std::move(GetTrampolineLanding), Err)); + + if (Err) + return std::move(Err); + return std::move(LTP); + } + + /// Get a free trampoline. Returns an error if one can not be provide (e.g. + /// because the pool is empty and can not be grown). + Expected<JITTargetAddress> getTrampoline() override { + std::lock_guard<std::mutex> Lock(LTPMutex); + if (AvailableTrampolines.empty()) { + if (auto Err = grow()) + return std::move(Err); + } + assert(!AvailableTrampolines.empty() && "Failed to grow trampoline pool"); + auto TrampolineAddr = AvailableTrampolines.back(); + AvailableTrampolines.pop_back(); + return TrampolineAddr; + } + + /// Returns the given trampoline to the pool for re-use. + void releaseTrampoline(JITTargetAddress TrampolineAddr) { + std::lock_guard<std::mutex> Lock(LTPMutex); + AvailableTrampolines.push_back(TrampolineAddr); + } + +private: + static JITTargetAddress reenter(void *TrampolinePoolPtr, void *TrampolineId) { + LocalTrampolinePool<ORCABI> *TrampolinePool = + static_cast<LocalTrampolinePool *>(TrampolinePoolPtr); + return TrampolinePool->GetTrampolineLanding(static_cast<JITTargetAddress>( + reinterpret_cast<uintptr_t>(TrampolineId))); + } + + LocalTrampolinePool(GetTrampolineLandingFunction GetTrampolineLanding, + Error &Err) + : GetTrampolineLanding(std::move(GetTrampolineLanding)) { + + ErrorAsOutParameter _(&Err); + + /// Try to set up the resolver block. + std::error_code EC; + ResolverBlock = sys::OwningMemoryBlock(sys::Memory::allocateMappedMemory( + ORCABI::ResolverCodeSize, nullptr, + sys::Memory::MF_READ | sys::Memory::MF_WRITE, EC)); + if (EC) { + Err = errorCodeToError(EC); + return; + } + + ORCABI::writeResolverCode(static_cast<uint8_t *>(ResolverBlock.base()), + &reenter, this); + + EC = sys::Memory::protectMappedMemory(ResolverBlock.getMemoryBlock(), + sys::Memory::MF_READ | + sys::Memory::MF_EXEC); + if (EC) { + Err = errorCodeToError(EC); + return; + } + } + + Error grow() { + assert(this->AvailableTrampolines.empty() && "Growing prematurely?"); + + std::error_code EC; + auto TrampolineBlock = + sys::OwningMemoryBlock(sys::Memory::allocateMappedMemory( + sys::Process::getPageSizeEstimate(), nullptr, + sys::Memory::MF_READ | sys::Memory::MF_WRITE, EC)); + if (EC) + return errorCodeToError(EC); + + unsigned NumTrampolines = + (sys::Process::getPageSizeEstimate() - ORCABI::PointerSize) / + ORCABI::TrampolineSize; + + uint8_t *TrampolineMem = static_cast<uint8_t *>(TrampolineBlock.base()); + ORCABI::writeTrampolines(TrampolineMem, ResolverBlock.base(), + NumTrampolines); + + for (unsigned I = 0; I < NumTrampolines; ++I) + this->AvailableTrampolines.push_back( + static_cast<JITTargetAddress>(reinterpret_cast<uintptr_t>( + TrampolineMem + (I * ORCABI::TrampolineSize)))); + + if (auto EC = sys::Memory::protectMappedMemory( + TrampolineBlock.getMemoryBlock(), + sys::Memory::MF_READ | sys::Memory::MF_EXEC)) + return errorCodeToError(EC); + + TrampolineBlocks.push_back(std::move(TrampolineBlock)); + return Error::success(); + } + + GetTrampolineLandingFunction GetTrampolineLanding; + + std::mutex LTPMutex; + sys::OwningMemoryBlock ResolverBlock; + std::vector<sys::OwningMemoryBlock> TrampolineBlocks; + std::vector<JITTargetAddress> AvailableTrampolines; +}; + +/// Target-independent base class for compile callback management. 
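// [Editorial usage sketch -- not part of this header] For the LocalTrampolinePool
// template above. `OrcX86_64_SysV` is one of the ORC ABI classes (see
// OrcABISupport.h); `getLandingAddressFor` is a hypothetical client callback
// that decides where a given trampoline should land.
//
//   auto TP = LocalTrampolinePool<OrcX86_64_SysV>::Create(
//       [](JITTargetAddress TrampolineAddr) -> JITTargetAddress {
//         return getLandingAddressFor(TrampolineAddr);
//       });
//   if (!TP)
//     return TP.takeError();
//   Expected<JITTargetAddress> Trampoline = (*TP)->getTrampoline();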
+class JITCompileCallbackManager { +public: + using CompileFunction = std::function<JITTargetAddress()>; + + virtual ~JITCompileCallbackManager() = default; + + /// Reserve a compile callback. + Expected<JITTargetAddress> getCompileCallback(CompileFunction Compile); + + /// Execute the callback for the given trampoline id. Called by the JIT + /// to compile functions on demand. + JITTargetAddress executeCompileCallback(JITTargetAddress TrampolineAddr); + +protected: + /// Construct a JITCompileCallbackManager. + JITCompileCallbackManager(std::unique_ptr<TrampolinePool> TP, + ExecutionSession &ES, + JITTargetAddress ErrorHandlerAddress) + : TP(std::move(TP)), ES(ES), + CallbacksJD(ES.createJITDylib("<Callbacks>")), + ErrorHandlerAddress(ErrorHandlerAddress) {} + + void setTrampolinePool(std::unique_ptr<TrampolinePool> TP) { + this->TP = std::move(TP); + } + +private: + std::mutex CCMgrMutex; + std::unique_ptr<TrampolinePool> TP; + ExecutionSession &ES; + JITDylib &CallbacksJD; + JITTargetAddress ErrorHandlerAddress; + std::map<JITTargetAddress, SymbolStringPtr> AddrToSymbol; + size_t NextCallbackId = 0; +}; + +/// Manage compile callbacks for in-process JITs. +template <typename ORCABI> +class LocalJITCompileCallbackManager : public JITCompileCallbackManager { +public: + /// Create a new LocalJITCompileCallbackManager. + static Expected<std::unique_ptr<LocalJITCompileCallbackManager>> + Create(ExecutionSession &ES, JITTargetAddress ErrorHandlerAddress) { + Error Err = Error::success(); + auto CCMgr = std::unique_ptr<LocalJITCompileCallbackManager>( + new LocalJITCompileCallbackManager(ES, ErrorHandlerAddress, Err)); + if (Err) + return std::move(Err); + return std::move(CCMgr); + } + +private: + /// Construct a InProcessJITCompileCallbackManager. + /// @param ErrorHandlerAddress The address of an error handler in the target + /// process to be used if a compile callback fails. + LocalJITCompileCallbackManager(ExecutionSession &ES, + JITTargetAddress ErrorHandlerAddress, + Error &Err) + : JITCompileCallbackManager(nullptr, ES, ErrorHandlerAddress) { + ErrorAsOutParameter _(&Err); + auto TP = LocalTrampolinePool<ORCABI>::Create( + [this](JITTargetAddress TrampolineAddr) { + return executeCompileCallback(TrampolineAddr); + }); + + if (!TP) { + Err = TP.takeError(); + return; + } + + setTrampolinePool(std::move(*TP)); + } +}; + +/// Base class for managing collections of named indirect stubs. +class IndirectStubsManager { +public: + /// Map type for initializing the manager. See init. + using StubInitsMap = StringMap<std::pair<JITTargetAddress, JITSymbolFlags>>; + + virtual ~IndirectStubsManager() = default; + + /// Create a single stub with the given name, target address and flags. + virtual Error createStub(StringRef StubName, JITTargetAddress StubAddr, + JITSymbolFlags StubFlags) = 0; + + /// Create StubInits.size() stubs with the given names, target + /// addresses, and flags. + virtual Error createStubs(const StubInitsMap &StubInits) = 0; + + /// Find the stub with the given name. If ExportedStubsOnly is true, + /// this will only return a result if the stub's flags indicate that it + /// is exported. + virtual JITEvaluatedSymbol findStub(StringRef Name, bool ExportedStubsOnly) = 0; + + /// Find the implementation-pointer for the stub. + virtual JITEvaluatedSymbol findPointer(StringRef Name) = 0; + + /// Change the value of the implementation pointer for the stub. 
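// [Editorial usage sketch -- not part of this header] For the
// LocalJITCompileCallbackManager template above. `OrcX86_64_SysV` is an ORC ABI
// class from OrcABISupport.h and `compileBodyAndGetAddress` a hypothetical
// client function.
//
//   auto CCMgr = LocalJITCompileCallbackManager<OrcX86_64_SysV>::Create(
//       ES, /*ErrorHandlerAddress=*/0);
//   if (!CCMgr)
//     return CCMgr.takeError();
//   auto CallbackAddr = (*CCMgr)->getCompileCallback([]() -> JITTargetAddress {
//     // Runs the first time the callback's trampoline is hit: produce the real
//     // function, update its stub, and return the address to resume at.
//     return compileBodyAndGetAddress();
//   });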
+ virtual Error updatePointer(StringRef Name, JITTargetAddress NewAddr) = 0; + +private: + virtual void anchor(); +}; + +/// IndirectStubsManager implementation for the host architecture, e.g. +/// OrcX86_64. (See OrcArchitectureSupport.h). +template <typename TargetT> +class LocalIndirectStubsManager : public IndirectStubsManager { +public: + Error createStub(StringRef StubName, JITTargetAddress StubAddr, + JITSymbolFlags StubFlags) override { + std::lock_guard<std::mutex> Lock(StubsMutex); + if (auto Err = reserveStubs(1)) + return Err; + + createStubInternal(StubName, StubAddr, StubFlags); + + return Error::success(); + } + + Error createStubs(const StubInitsMap &StubInits) override { + std::lock_guard<std::mutex> Lock(StubsMutex); + if (auto Err = reserveStubs(StubInits.size())) + return Err; + + for (auto &Entry : StubInits) + createStubInternal(Entry.first(), Entry.second.first, + Entry.second.second); + + return Error::success(); + } + + JITEvaluatedSymbol findStub(StringRef Name, bool ExportedStubsOnly) override { + std::lock_guard<std::mutex> Lock(StubsMutex); + auto I = StubIndexes.find(Name); + if (I == StubIndexes.end()) + return nullptr; + auto Key = I->second.first; + void *StubAddr = IndirectStubsInfos[Key.first].getStub(Key.second); + assert(StubAddr && "Missing stub address"); + auto StubTargetAddr = + static_cast<JITTargetAddress>(reinterpret_cast<uintptr_t>(StubAddr)); + auto StubSymbol = JITEvaluatedSymbol(StubTargetAddr, I->second.second); + if (ExportedStubsOnly && !StubSymbol.getFlags().isExported()) + return nullptr; + return StubSymbol; + } + + JITEvaluatedSymbol findPointer(StringRef Name) override { + std::lock_guard<std::mutex> Lock(StubsMutex); + auto I = StubIndexes.find(Name); + if (I == StubIndexes.end()) + return nullptr; + auto Key = I->second.first; + void *PtrAddr = IndirectStubsInfos[Key.first].getPtr(Key.second); + assert(PtrAddr && "Missing pointer address"); + auto PtrTargetAddr = + static_cast<JITTargetAddress>(reinterpret_cast<uintptr_t>(PtrAddr)); + return JITEvaluatedSymbol(PtrTargetAddr, I->second.second); + } + + Error updatePointer(StringRef Name, JITTargetAddress NewAddr) override { + using AtomicIntPtr = std::atomic<uintptr_t>; + + std::lock_guard<std::mutex> Lock(StubsMutex); + auto I = StubIndexes.find(Name); + assert(I != StubIndexes.end() && "No stub pointer for symbol"); + auto Key = I->second.first; + AtomicIntPtr *AtomicStubPtr = reinterpret_cast<AtomicIntPtr *>( + IndirectStubsInfos[Key.first].getPtr(Key.second)); + *AtomicStubPtr = static_cast<uintptr_t>(NewAddr); + return Error::success(); + } + +private: + Error reserveStubs(unsigned NumStubs) { + if (NumStubs <= FreeStubs.size()) + return Error::success(); + + unsigned NewStubsRequired = NumStubs - FreeStubs.size(); + unsigned NewBlockId = IndirectStubsInfos.size(); + typename TargetT::IndirectStubsInfo ISI; + if (auto Err = + TargetT::emitIndirectStubsBlock(ISI, NewStubsRequired, nullptr)) + return Err; + for (unsigned I = 0; I < ISI.getNumStubs(); ++I) + FreeStubs.push_back(std::make_pair(NewBlockId, I)); + IndirectStubsInfos.push_back(std::move(ISI)); + return Error::success(); + } + + void createStubInternal(StringRef StubName, JITTargetAddress InitAddr, + JITSymbolFlags StubFlags) { + auto Key = FreeStubs.back(); + FreeStubs.pop_back(); + *IndirectStubsInfos[Key.first].getPtr(Key.second) = + reinterpret_cast<void *>(static_cast<uintptr_t>(InitAddr)); + StubIndexes[StubName] = std::make_pair(Key, StubFlags); + } + + std::mutex StubsMutex; + std::vector<typename 
TargetT::IndirectStubsInfo> IndirectStubsInfos; + using StubKey = std::pair<uint16_t, uint16_t>; + std::vector<StubKey> FreeStubs; + StringMap<std::pair<StubKey, JITSymbolFlags>> StubIndexes; +}; + +/// Create a local compile callback manager. +/// +/// The given target triple will determine the ABI, and the given +/// ErrorHandlerAddress will be used by the resulting compile callback +/// manager if a compile callback fails. +Expected<std::unique_ptr<JITCompileCallbackManager>> +createLocalCompileCallbackManager(const Triple &T, ExecutionSession &ES, + JITTargetAddress ErrorHandlerAddress); + +/// Create a local indriect stubs manager builder. +/// +/// The given target triple will determine the ABI. +std::function<std::unique_ptr<IndirectStubsManager>()> +createLocalIndirectStubsManagerBuilder(const Triple &T); + +/// Build a function pointer of FunctionType with the given constant +/// address. +/// +/// Usage example: Turn a trampoline address into a function pointer constant +/// for use in a stub. +Constant *createIRTypedAddress(FunctionType &FT, JITTargetAddress Addr); + +/// Create a function pointer with the given type, name, and initializer +/// in the given Module. +GlobalVariable *createImplPointer(PointerType &PT, Module &M, const Twine &Name, + Constant *Initializer); + +/// Turn a function declaration into a stub function that makes an +/// indirect call using the given function pointer. +void makeStub(Function &F, Value &ImplPointer); + +/// Promotes private symbols to global hidden, and renames to prevent clashes +/// with other promoted symbols. The same SymbolPromoter instance should be +/// used for all symbols to be added to a single JITDylib. +class SymbolLinkagePromoter { +public: + /// Promote symbols in the given module. Returns the set of global values + /// that have been renamed/promoted. + std::vector<GlobalValue *> operator()(Module &M); + +private: + unsigned NextId = 0; +}; + +/// Clone a function declaration into a new module. +/// +/// This function can be used as the first step towards creating a callback +/// stub (see makeStub), or moving a function body (see moveFunctionBody). +/// +/// If the VMap argument is non-null, a mapping will be added between F and +/// the new declaration, and between each of F's arguments and the new +/// declaration's arguments. This map can then be passed in to moveFunction to +/// move the function body if required. Note: When moving functions between +/// modules with these utilities, all decls should be cloned (and added to a +/// single VMap) before any bodies are moved. This will ensure that references +/// between functions all refer to the versions in the new module. +Function *cloneFunctionDecl(Module &Dst, const Function &F, + ValueToValueMapTy *VMap = nullptr); + +/// Move the body of function 'F' to a cloned function declaration in a +/// different module (See related cloneFunctionDecl). +/// +/// If the target function declaration is not supplied via the NewF parameter +/// then it will be looked up via the VMap. +/// +/// This will delete the body of function 'F' from its original parent module, +/// but leave its declaration. +void moveFunctionBody(Function &OrigF, ValueToValueMapTy &VMap, + ValueMaterializer *Materializer = nullptr, + Function *NewF = nullptr); + +/// Clone a global variable declaration into a new module. 
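// [Editorial usage sketch -- not part of this header] For the
// IndirectStubsManager interface and createLocalIndirectStubsManagerBuilder
// above. Assumes `TT` is the target Triple and `CallbackAddr` / `CompiledAddr`
// are addresses obtained by the client (e.g. from a compile callback manager).
//
//   auto ISMBuilder = createLocalIndirectStubsManagerBuilder(TT);
//   std::unique_ptr<IndirectStubsManager> ISM = ISMBuilder();
//   if (auto Err = ISM->createStub("foo_stub", CallbackAddr,
//                                  JITSymbolFlags::Exported))
//     return Err;
//   // Once the real definition has been compiled:
//   if (auto Err = ISM->updatePointer("foo_stub", CompiledAddr))
//     return Err;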
+GlobalVariable *cloneGlobalVariableDecl(Module &Dst, const GlobalVariable &GV, + ValueToValueMapTy *VMap = nullptr); + +/// Move global variable GV from its parent module to cloned global +/// declaration in a different module. +/// +/// If the target global declaration is not supplied via the NewGV parameter +/// then it will be looked up via the VMap. +/// +/// This will delete the initializer of GV from its original parent module, +/// but leave its declaration. +void moveGlobalVariableInitializer(GlobalVariable &OrigGV, + ValueToValueMapTy &VMap, + ValueMaterializer *Materializer = nullptr, + GlobalVariable *NewGV = nullptr); + +/// Clone a global alias declaration into a new module. +GlobalAlias *cloneGlobalAliasDecl(Module &Dst, const GlobalAlias &OrigA, + ValueToValueMapTy &VMap); + +/// Clone module flags metadata into the destination module. +void cloneModuleFlagsMetadata(Module &Dst, const Module &Src, + ValueToValueMapTy &VMap); + +} // end namespace orc + +} // end namespace llvm + +#endif // LLVM_EXECUTIONENGINE_ORC_INDIRECTIONUTILS_H diff --git a/llvm/include/llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h b/llvm/include/llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h new file mode 100644 index 000000000000..bcbd72e68f15 --- /dev/null +++ b/llvm/include/llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h @@ -0,0 +1,129 @@ +//===- JITTargetMachineBuilder.h - Build TargetMachines for JIT -*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// A utitily for building TargetMachines for JITs. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_EXECUTIONENGINE_ORC_JITTARGETMACHINEBUILDER_H +#define LLVM_EXECUTIONENGINE_ORC_JITTARGETMACHINEBUILDER_H + +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/Triple.h" +#include "llvm/MC/SubtargetFeature.h" +#include "llvm/Support/CodeGen.h" +#include "llvm/Support/Error.h" +#include "llvm/Target/TargetMachine.h" +#include "llvm/Target/TargetOptions.h" +#include <memory> +#include <string> +#include <vector> + +namespace llvm { +namespace orc { + +/// A utility class for building TargetMachines for JITs. +class JITTargetMachineBuilder { +public: + /// Create a JITTargetMachineBuilder based on the given triple. + /// + /// Note: TargetOptions is default-constructed, then EmulatedTLS and + /// ExplicitEmulatedTLS are set to true. If EmulatedTLS is not + /// required, these values should be reset before calling + /// createTargetMachine. + JITTargetMachineBuilder(Triple TT); + + /// Create a JITTargetMachineBuilder for the host system. + /// + /// Note: TargetOptions is default-constructed, then EmulatedTLS and + /// ExplicitEmulatedTLS are set to true. If EmulatedTLS is not + /// required, these values should be reset before calling + /// createTargetMachine. + static Expected<JITTargetMachineBuilder> detectHost(); + + /// Create a TargetMachine. + /// + /// This operation will fail if the requested target is not registered, + /// in which case see llvm/Support/TargetSelect.h. To JIT IR the Target and + /// the target's AsmPrinter must both be registered. To JIT assembly + /// (including inline and module level assembly) the target's AsmParser must + /// also be registered. 
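// [Editorial usage sketch -- not part of this header] For the cloneFunctionDecl
// and moveFunctionBody utilities above, e.g. when splitting a function out of a
// source module into `DstM`. `F` is assumed to be a Function in the source
// module.
//
//   ValueToValueMapTy VMap;
//   Function *NewFDecl = cloneFunctionDecl(DstM, *F, &VMap);
//   // Clone any other declarations F refers to into DstM (reusing VMap) before
//   // moving bodies, so cross-references resolve to the new module:
//   moveFunctionBody(*F, VMap, /*Materializer=*/nullptr, NewFDecl);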
+ Expected<std::unique_ptr<TargetMachine>> createTargetMachine(); + + /// Get the default DataLayout for the target. + /// + /// Note: This is reasonably expensive, as it creates a temporary + /// TargetMachine instance under the hood. It is only suitable for use during + /// JIT setup. + Expected<DataLayout> getDefaultDataLayoutForTarget() { + auto TM = createTargetMachine(); + if (!TM) + return TM.takeError(); + return (*TM)->createDataLayout(); + } + + /// Set the CPU string. + JITTargetMachineBuilder &setCPU(std::string CPU) { + this->CPU = std::move(CPU); + return *this; + } + + /// Set the relocation model. + JITTargetMachineBuilder &setRelocationModel(Optional<Reloc::Model> RM) { + this->RM = std::move(RM); + return *this; + } + + /// Set the code model. + JITTargetMachineBuilder &setCodeModel(Optional<CodeModel::Model> CM) { + this->CM = std::move(CM); + return *this; + } + + /// Set the LLVM CodeGen optimization level. + JITTargetMachineBuilder &setCodeGenOptLevel(CodeGenOpt::Level OptLevel) { + this->OptLevel = OptLevel; + return *this; + } + + /// Add subtarget features. + JITTargetMachineBuilder & + addFeatures(const std::vector<std::string> &FeatureVec); + + /// Access subtarget features. + SubtargetFeatures &getFeatures() { return Features; } + + /// Access subtarget features. + const SubtargetFeatures &getFeatures() const { return Features; } + + /// Access TargetOptions. + TargetOptions &getOptions() { return Options; } + + /// Access TargetOptions. + const TargetOptions &getOptions() const { return Options; } + + /// Access Triple. + Triple &getTargetTriple() { return TT; } + + /// Access Triple. + const Triple &getTargetTriple() const { return TT; } + +private: + Triple TT; + std::string CPU; + SubtargetFeatures Features; + TargetOptions Options; + Optional<Reloc::Model> RM; + Optional<CodeModel::Model> CM; + CodeGenOpt::Level OptLevel = CodeGenOpt::None; +}; + +} // end namespace orc +} // end namespace llvm + +#endif // LLVM_EXECUTIONENGINE_ORC_JITTARGETMACHINEBUILDER_H diff --git a/llvm/include/llvm/ExecutionEngine/Orc/LLJIT.h b/llvm/include/llvm/ExecutionEngine/Orc/LLJIT.h new file mode 100644 index 000000000000..b1e47d77557c --- /dev/null +++ b/llvm/include/llvm/ExecutionEngine/Orc/LLJIT.h @@ -0,0 +1,335 @@ +//===----- LLJIT.h -- An ORC-based JIT for compiling LLVM IR ----*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// An ORC-based JIT for compiling LLVM IR. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_EXECUTIONENGINE_ORC_LLJIT_H +#define LLVM_EXECUTIONENGINE_ORC_LLJIT_H + +#include "llvm/ExecutionEngine/Orc/CompileOnDemandLayer.h" +#include "llvm/ExecutionEngine/Orc/CompileUtils.h" +#include "llvm/ExecutionEngine/Orc/ExecutionUtils.h" +#include "llvm/ExecutionEngine/Orc/IRCompileLayer.h" +#include "llvm/ExecutionEngine/Orc/IRTransformLayer.h" +#include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h" +#include "llvm/ExecutionEngine/Orc/ObjectTransformLayer.h" +#include "llvm/ExecutionEngine/Orc/ThreadSafeModule.h" +#include "llvm/Support/ThreadPool.h" + +namespace llvm { +namespace orc { + +class LLJITBuilderState; +class LLLazyJITBuilderState; + +/// A pre-fabricated ORC JIT stack that can serve as an alternative to MCJIT. 
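// [Editorial usage sketch -- not part of this header] For the
// JITTargetMachineBuilder above, configured for the host process.
//
//   auto JTMB = JITTargetMachineBuilder::detectHost();
//   if (!JTMB)
//     return JTMB.takeError();
//   JTMB->setCodeGenOptLevel(CodeGenOpt::Default);
//   auto DL = JTMB->getDefaultDataLayoutForTarget(); // Expected<DataLayout>
//   auto TM = JTMB->createTargetMachine();           // Expected<unique_ptr<TargetMachine>>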
+/// +/// Create instances using LLJITBuilder. +class LLJIT { + template <typename, typename, typename> friend class LLJITBuilderSetters; + +public: + static Expected<std::unique_ptr<LLJIT>> Create(LLJITBuilderState &S); + + /// Destruct this instance. If a multi-threaded instance, waits for all + /// compile threads to complete. + ~LLJIT(); + + /// Returns the ExecutionSession for this instance. + ExecutionSession &getExecutionSession() { return *ES; } + + /// Returns a reference to the DataLayout for this instance. + const DataLayout &getDataLayout() const { return DL; } + + /// Returns a reference to the JITDylib representing the JIT'd main program. + JITDylib &getMainJITDylib() { return Main; } + + /// Returns the JITDylib with the given name, or nullptr if no JITDylib with + /// that name exists. + JITDylib *getJITDylibByName(StringRef Name) { + return ES->getJITDylibByName(Name); + } + + /// Create a new JITDylib with the given name and return a reference to it. + /// + /// JITDylib names must be unique. If the given name is derived from user + /// input or elsewhere in the environment then the client should check + /// (e.g. by calling getJITDylibByName) that the given name is not already in + /// use. + JITDylib &createJITDylib(std::string Name) { + return ES->createJITDylib(std::move(Name)); + } + + /// Convenience method for defining an absolute symbol. + Error defineAbsolute(StringRef Name, JITEvaluatedSymbol Address); + + /// Adds an IR module to the given JITDylib. + Error addIRModule(JITDylib &JD, ThreadSafeModule TSM); + + /// Adds an IR module to the Main JITDylib. + Error addIRModule(ThreadSafeModule TSM) { + return addIRModule(Main, std::move(TSM)); + } + + /// Adds an object file to the given JITDylib. + Error addObjectFile(JITDylib &JD, std::unique_ptr<MemoryBuffer> Obj); + + /// Adds an object file to the given JITDylib. + Error addObjectFile(std::unique_ptr<MemoryBuffer> Obj) { + return addObjectFile(Main, std::move(Obj)); + } + + /// Look up a symbol in JITDylib JD by the symbol's linker-mangled name (to + /// look up symbols based on their IR name use the lookup function instead). + Expected<JITEvaluatedSymbol> lookupLinkerMangled(JITDylib &JD, + StringRef Name); + + /// Look up a symbol in the main JITDylib by the symbol's linker-mangled name + /// (to look up symbols based on their IR name use the lookup function + /// instead). + Expected<JITEvaluatedSymbol> lookupLinkerMangled(StringRef Name) { + return lookupLinkerMangled(Main, Name); + } + + /// Look up a symbol in JITDylib JD based on its IR symbol name. + Expected<JITEvaluatedSymbol> lookup(JITDylib &JD, StringRef UnmangledName) { + return lookupLinkerMangled(JD, mangle(UnmangledName)); + } + + /// Look up a symbol in the main JITDylib based on its IR symbol name. + Expected<JITEvaluatedSymbol> lookup(StringRef UnmangledName) { + return lookup(Main, UnmangledName); + } + + /// Runs all not-yet-run static constructors. + Error runConstructors() { return CtorRunner.run(); } + + /// Runs all not-yet-run static destructors. + Error runDestructors() { return DtorRunner.run(); } + + /// Returns a reference to the ObjLinkingLayer + ObjectLayer &getObjLinkingLayer() { return *ObjLinkingLayer; } + +protected: + static std::unique_ptr<ObjectLayer> + createObjectLinkingLayer(LLJITBuilderState &S, ExecutionSession &ES); + + static Expected<IRCompileLayer::CompileFunction> + createCompileFunction(LLJITBuilderState &S, JITTargetMachineBuilder JTMB); + + /// Create an LLJIT instance with a single compile thread. 
+ LLJIT(LLJITBuilderState &S, Error &Err); + + std::string mangle(StringRef UnmangledName); + + Error applyDataLayout(Module &M); + + void recordCtorDtors(Module &M); + + std::unique_ptr<ExecutionSession> ES; + JITDylib &Main; + + DataLayout DL; + std::unique_ptr<ThreadPool> CompileThreads; + + std::unique_ptr<ObjectLayer> ObjLinkingLayer; + std::unique_ptr<IRCompileLayer> CompileLayer; + + CtorDtorRunner CtorRunner, DtorRunner; +}; + +/// An extended version of LLJIT that supports lazy function-at-a-time +/// compilation of LLVM IR. +class LLLazyJIT : public LLJIT { + template <typename, typename, typename> friend class LLJITBuilderSetters; + +public: + + /// Set an IR transform (e.g. pass manager pipeline) to run on each function + /// when it is compiled. + void setLazyCompileTransform(IRTransformLayer::TransformFunction Transform) { + TransformLayer->setTransform(std::move(Transform)); + } + + /// Sets the partition function. + void + setPartitionFunction(CompileOnDemandLayer::PartitionFunction Partition) { + CODLayer->setPartitionFunction(std::move(Partition)); + } + + /// Add a module to be lazily compiled to JITDylib JD. + Error addLazyIRModule(JITDylib &JD, ThreadSafeModule M); + + /// Add a module to be lazily compiled to the main JITDylib. + Error addLazyIRModule(ThreadSafeModule M) { + return addLazyIRModule(Main, std::move(M)); + } + +private: + + // Create a single-threaded LLLazyJIT instance. + LLLazyJIT(LLLazyJITBuilderState &S, Error &Err); + + std::unique_ptr<LazyCallThroughManager> LCTMgr; + std::unique_ptr<IRTransformLayer> TransformLayer; + std::unique_ptr<CompileOnDemandLayer> CODLayer; +}; + +class LLJITBuilderState { +public: + using ObjectLinkingLayerCreator = std::function<std::unique_ptr<ObjectLayer>( + ExecutionSession &, const Triple &TT)>; + + using CompileFunctionCreator = + std::function<Expected<IRCompileLayer::CompileFunction>( + JITTargetMachineBuilder JTMB)>; + + std::unique_ptr<ExecutionSession> ES; + Optional<JITTargetMachineBuilder> JTMB; + ObjectLinkingLayerCreator CreateObjectLinkingLayer; + CompileFunctionCreator CreateCompileFunction; + unsigned NumCompileThreads = 0; + + /// Called prior to JIT class construcion to fix up defaults. + Error prepareForConstruction(); +}; + +template <typename JITType, typename SetterImpl, typename State> +class LLJITBuilderSetters { +public: + /// Set the JITTargetMachineBuilder for this instance. + /// + /// If this method is not called, JITTargetMachineBuilder::detectHost will be + /// used to construct a default target machine builder for the host platform. + SetterImpl &setJITTargetMachineBuilder(JITTargetMachineBuilder JTMB) { + impl().JTMB = std::move(JTMB); + return impl(); + } + + /// Return a reference to the JITTargetMachineBuilder. + /// + Optional<JITTargetMachineBuilder> &getJITTargetMachineBuilder() { + return impl().JTMB; + } + + /// Set an ObjectLinkingLayer creation function. + /// + /// If this method is not called, a default creation function will be used + /// that will construct an RTDyldObjectLinkingLayer. + SetterImpl &setObjectLinkingLayerCreator( + LLJITBuilderState::ObjectLinkingLayerCreator CreateObjectLinkingLayer) { + impl().CreateObjectLinkingLayer = std::move(CreateObjectLinkingLayer); + return impl(); + } + + /// Set a CompileFunctionCreator. 
+ /// + /// If this method is not called, a default creation function wil be used + /// that will construct a basic IR compile function that is compatible with + /// the selected number of threads (SimpleCompiler for '0' compile threads, + /// ConcurrentIRCompiler otherwise). + SetterImpl &setCompileFunctionCreator( + LLJITBuilderState::CompileFunctionCreator CreateCompileFunction) { + impl().CreateCompileFunction = std::move(CreateCompileFunction); + return impl(); + } + + /// Set the number of compile threads to use. + /// + /// If set to zero, compilation will be performed on the execution thread when + /// JITing in-process. If set to any other number N, a thread pool of N + /// threads will be created for compilation. + /// + /// If this method is not called, behavior will be as if it were called with + /// a zero argument. + SetterImpl &setNumCompileThreads(unsigned NumCompileThreads) { + impl().NumCompileThreads = NumCompileThreads; + return impl(); + } + + /// Create an instance of the JIT. + Expected<std::unique_ptr<JITType>> create() { + if (auto Err = impl().prepareForConstruction()) + return std::move(Err); + + Error Err = Error::success(); + std::unique_ptr<JITType> J(new JITType(impl(), Err)); + if (Err) + return std::move(Err); + return std::move(J); + } + +protected: + SetterImpl &impl() { return static_cast<SetterImpl &>(*this); } +}; + +/// Constructs LLJIT instances. +class LLJITBuilder + : public LLJITBuilderState, + public LLJITBuilderSetters<LLJIT, LLJITBuilder, LLJITBuilderState> {}; + +class LLLazyJITBuilderState : public LLJITBuilderState { + friend class LLLazyJIT; + +public: + using IndirectStubsManagerBuilderFunction = + std::function<std::unique_ptr<IndirectStubsManager>()>; + + Triple TT; + JITTargetAddress LazyCompileFailureAddr = 0; + std::unique_ptr<LazyCallThroughManager> LCTMgr; + IndirectStubsManagerBuilderFunction ISMBuilder; + + Error prepareForConstruction(); +}; + +template <typename JITType, typename SetterImpl, typename State> +class LLLazyJITBuilderSetters + : public LLJITBuilderSetters<JITType, SetterImpl, State> { +public: + /// Set the address in the target address to call if a lazy compile fails. + /// + /// If this method is not called then the value will default to 0. + SetterImpl &setLazyCompileFailureAddr(JITTargetAddress Addr) { + this->impl().LazyCompileFailureAddr = Addr; + return this->impl(); + } + + /// Set the lazy-callthrough manager. + /// + /// If this method is not called then a default, in-process lazy callthrough + /// manager for the host platform will be used. + SetterImpl & + setLazyCallthroughManager(std::unique_ptr<LazyCallThroughManager> LCTMgr) { + this->impl().LCTMgr = std::move(LCTMgr); + return this->impl(); + } + + /// Set the IndirectStubsManager builder function. + /// + /// If this method is not called then a default, in-process + /// IndirectStubsManager builder for the host platform will be used. + SetterImpl &setIndirectStubsManagerBuilder( + LLLazyJITBuilderState::IndirectStubsManagerBuilderFunction ISMBuilder) { + this->impl().ISMBuilder = std::move(ISMBuilder); + return this->impl(); + } +}; + +/// Constructs LLLazyJIT instances. 
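// [Editorial usage sketch -- not part of this header] For the LLJITBuilder and
// LLLazyJITBuilder classes above. Assumes `TSM` is a ThreadSafeModule prepared
// by the client.
//
//   auto J = LLJITBuilder().setNumCompileThreads(2).create();
//   if (!J)
//     return J.takeError();
//   if (auto Err = (*J)->addIRModule(std::move(TSM)))
//     return Err;
//   if (auto Err = (*J)->runConstructors())
//     return Err;
//   auto MainSym = (*J)->lookup("main");
//
//   // LLLazyJITBuilder().create() works the same way, with addLazyIRModule
//   // deferring compilation of each function until it is first called.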
+class LLLazyJITBuilder + : public LLLazyJITBuilderState, + public LLLazyJITBuilderSetters<LLLazyJIT, LLLazyJITBuilder, + LLLazyJITBuilderState> {}; + +} // End namespace orc +} // End namespace llvm + +#endif // LLVM_EXECUTIONENGINE_ORC_LLJIT_H diff --git a/llvm/include/llvm/ExecutionEngine/Orc/LambdaResolver.h b/llvm/include/llvm/ExecutionEngine/Orc/LambdaResolver.h new file mode 100644 index 000000000000..b31914f12a0d --- /dev/null +++ b/llvm/include/llvm/ExecutionEngine/Orc/LambdaResolver.h @@ -0,0 +1,84 @@ +//===- LambdaResolverMM - Redirect symbol lookup via a functor --*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Defines a RuntimeDyld::SymbolResolver subclass that uses a user-supplied +// functor for symbol resolution. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_EXECUTIONENGINE_ORC_LAMBDARESOLVER_H +#define LLVM_EXECUTIONENGINE_ORC_LAMBDARESOLVER_H + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ExecutionEngine/JITSymbol.h" +#include "llvm/ExecutionEngine/OrcV1Deprecation.h" +#include <memory> + +namespace llvm { +namespace orc { + +template <typename DylibLookupFtorT, typename ExternalLookupFtorT> +class LambdaResolver : public LegacyJITSymbolResolver { +public: + LLVM_ATTRIBUTE_DEPRECATED( + LambdaResolver(DylibLookupFtorT DylibLookupFtor, + ExternalLookupFtorT ExternalLookupFtor), + "ORCv1 utilities (including resolvers) are deprecated and will be " + "removed " + "in the next release. Please use ORCv2 (see docs/ORCv2.rst)"); + + LambdaResolver(ORCv1DeprecationAcknowledgement, + DylibLookupFtorT DylibLookupFtor, + ExternalLookupFtorT ExternalLookupFtor) + : DylibLookupFtor(DylibLookupFtor), + ExternalLookupFtor(ExternalLookupFtor) {} + + JITSymbol findSymbolInLogicalDylib(const std::string &Name) final { + return DylibLookupFtor(Name); + } + + JITSymbol findSymbol(const std::string &Name) final { + return ExternalLookupFtor(Name); + } + +private: + DylibLookupFtorT DylibLookupFtor; + ExternalLookupFtorT ExternalLookupFtor; +}; + +template <typename DylibLookupFtorT, typename ExternalLookupFtorT> +LambdaResolver<DylibLookupFtorT, ExternalLookupFtorT>::LambdaResolver( + DylibLookupFtorT DylibLookupFtor, ExternalLookupFtorT ExternalLookupFtor) + : DylibLookupFtor(DylibLookupFtor), ExternalLookupFtor(ExternalLookupFtor) { +} + +template <typename DylibLookupFtorT, + typename ExternalLookupFtorT> +std::shared_ptr<LambdaResolver<DylibLookupFtorT, ExternalLookupFtorT>> +createLambdaResolver(DylibLookupFtorT DylibLookupFtor, + ExternalLookupFtorT ExternalLookupFtor) { + using LR = LambdaResolver<DylibLookupFtorT, ExternalLookupFtorT>; + return std::make_unique<LR>(std::move(DylibLookupFtor), + std::move(ExternalLookupFtor)); +} + +template <typename DylibLookupFtorT, typename ExternalLookupFtorT> +std::shared_ptr<LambdaResolver<DylibLookupFtorT, ExternalLookupFtorT>> +createLambdaResolver(ORCv1DeprecationAcknowledgement, + DylibLookupFtorT DylibLookupFtor, + ExternalLookupFtorT ExternalLookupFtor) { + using LR = LambdaResolver<DylibLookupFtorT, ExternalLookupFtorT>; + return std::make_unique<LR>(AcknowledgeORCv1Deprecation, + std::move(DylibLookupFtor), + std::move(ExternalLookupFtor)); +} + +} // end namespace orc +} // end namespace llvm + +#endif // 
LLVM_EXECUTIONENGINE_ORC_LAMBDARESOLVER_H diff --git a/llvm/include/llvm/ExecutionEngine/Orc/Layer.h b/llvm/include/llvm/ExecutionEngine/Orc/Layer.h new file mode 100644 index 000000000000..8f9bd704395e --- /dev/null +++ b/llvm/include/llvm/ExecutionEngine/Orc/Layer.h @@ -0,0 +1,166 @@ +//===---------------- Layer.h -- Layer interfaces --------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Layer interfaces. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_EXECUTIONENGINE_ORC_LAYER_H +#define LLVM_EXECUTIONENGINE_ORC_LAYER_H + +#include "llvm/ExecutionEngine/Orc/Core.h" +#include "llvm/ExecutionEngine/Orc/ThreadSafeModule.h" +#include "llvm/IR/Module.h" +#include "llvm/Support/MemoryBuffer.h" + +namespace llvm { +namespace orc { + +/// Interface for layers that accept LLVM IR. +class IRLayer { +public: + IRLayer(ExecutionSession &ES); + virtual ~IRLayer(); + + /// Returns the ExecutionSession for this layer. + ExecutionSession &getExecutionSession() { return ES; } + + /// Sets the CloneToNewContextOnEmit flag (false by default). + /// + /// When set, IR modules added to this layer will be cloned on to a new + /// context before emit is called. This can be used by clients who want + /// to load all IR using one LLVMContext (to save memory via type and + /// constant uniquing), but want to move Modules to fresh contexts before + /// compiling them to enable concurrent compilation. + /// Single threaded clients, or clients who load every module on a new + /// context, need not set this. + void setCloneToNewContextOnEmit(bool CloneToNewContextOnEmit) { + this->CloneToNewContextOnEmit = CloneToNewContextOnEmit; + } + + /// Returns the current value of the CloneToNewContextOnEmit flag. + bool getCloneToNewContextOnEmit() const { return CloneToNewContextOnEmit; } + + /// Adds a MaterializationUnit representing the given IR to the given + /// JITDylib. + virtual Error add(JITDylib &JD, ThreadSafeModule TSM, + VModuleKey K = VModuleKey()); + + /// Emit should materialize the given IR. + virtual void emit(MaterializationResponsibility R, ThreadSafeModule TSM) = 0; + +private: + bool CloneToNewContextOnEmit = false; + ExecutionSession &ES; +}; + +/// IRMaterializationUnit is a convenient base class for MaterializationUnits +/// wrapping LLVM IR. Represents materialization responsibility for all symbols +/// in the given module. If symbols are overridden by other definitions, then +/// their linkage is changed to available-externally. +class IRMaterializationUnit : public MaterializationUnit { +public: + using SymbolNameToDefinitionMap = std::map<SymbolStringPtr, GlobalValue *>; + + /// Create an IRMaterializationLayer. Scans the module to build the + /// SymbolFlags and SymbolToDefinition maps. + IRMaterializationUnit(ExecutionSession &ES, ThreadSafeModule TSM, + VModuleKey K); + + /// Create an IRMaterializationLayer from a module, and pre-existing + /// SymbolFlags and SymbolToDefinition maps. The maps must provide + /// entries for each definition in M. + /// This constructor is useful for delegating work from one + /// IRMaterializationUnit to another. 
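  // -------------------------------------------------------------------------
  // Editorial usage sketch (not part of the original header): adding a module
  // to a JITDylib through any concrete IRLayer (e.g. an IRCompileLayer), using
  // the IRLayer interface documented above. JD and TSM are assumed to be
  // supplied by the caller.
  //
  //   Error addToLayerExample(IRLayer &L, JITDylib &JD, ThreadSafeModule TSM) {
  //     // Clone modules onto fresh contexts at emit time so they can be
  //     // compiled concurrently (see setCloneToNewContextOnEmit above).
  //     L.setCloneToNewContextOnEmit(true);
  //     return L.add(JD, std::move(TSM));
  //   }
  // -------------------------------------------------------------------------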
+ IRMaterializationUnit(ThreadSafeModule TSM, VModuleKey K, + SymbolFlagsMap SymbolFlags, + SymbolNameToDefinitionMap SymbolToDefinition); + + /// Return the ModuleIdentifier as the name for this MaterializationUnit. + StringRef getName() const override; + + const ThreadSafeModule &getModule() const { return TSM; } + +protected: + ThreadSafeModule TSM; + SymbolNameToDefinitionMap SymbolToDefinition; + +private: + void discard(const JITDylib &JD, const SymbolStringPtr &Name) override; +}; + +/// MaterializationUnit that materializes modules by calling the 'emit' method +/// on the given IRLayer. +class BasicIRLayerMaterializationUnit : public IRMaterializationUnit { +public: + BasicIRLayerMaterializationUnit(IRLayer &L, VModuleKey K, + ThreadSafeModule TSM); + +private: + + void materialize(MaterializationResponsibility R) override; + + IRLayer &L; + VModuleKey K; +}; + +/// Interface for Layers that accept object files. +class ObjectLayer { +public: + ObjectLayer(ExecutionSession &ES); + virtual ~ObjectLayer(); + + /// Returns the execution session for this layer. + ExecutionSession &getExecutionSession() { return ES; } + + /// Adds a MaterializationUnit representing the given IR to the given + /// JITDylib. + virtual Error add(JITDylib &JD, std::unique_ptr<MemoryBuffer> O, + VModuleKey K = VModuleKey()); + + /// Emit should materialize the given IR. + virtual void emit(MaterializationResponsibility R, + std::unique_ptr<MemoryBuffer> O) = 0; + +private: + ExecutionSession &ES; +}; + +/// Materializes the given object file (represented by a MemoryBuffer +/// instance) by calling 'emit' on the given ObjectLayer. +class BasicObjectLayerMaterializationUnit : public MaterializationUnit { +public: + static Expected<std::unique_ptr<BasicObjectLayerMaterializationUnit>> + Create(ObjectLayer &L, VModuleKey K, std::unique_ptr<MemoryBuffer> O); + + BasicObjectLayerMaterializationUnit(ObjectLayer &L, VModuleKey K, + std::unique_ptr<MemoryBuffer> O, + SymbolFlagsMap SymbolFlags); + + /// Return the buffer's identifier as the name for this MaterializationUnit. + StringRef getName() const override; + +private: + + void materialize(MaterializationResponsibility R) override; + void discard(const JITDylib &JD, const SymbolStringPtr &Name) override; + + ObjectLayer &L; + std::unique_ptr<MemoryBuffer> O; +}; + +/// Returns a SymbolFlagsMap for the object file represented by the given +/// buffer, or an error if the buffer does not contain a valid object file. +// FIXME: Maybe move to Core.h? +Expected<SymbolFlagsMap> getObjectSymbolFlags(ExecutionSession &ES, + MemoryBufferRef ObjBuffer); + +} // End namespace orc +} // End namespace llvm + +#endif // LLVM_EXECUTIONENGINE_ORC_LAYER_H diff --git a/llvm/include/llvm/ExecutionEngine/Orc/LazyEmittingLayer.h b/llvm/include/llvm/ExecutionEngine/Orc/LazyEmittingLayer.h new file mode 100644 index 000000000000..b67a9feed523 --- /dev/null +++ b/llvm/include/llvm/ExecutionEngine/Orc/LazyEmittingLayer.h @@ -0,0 +1,267 @@ +//===- LazyEmittingLayer.h - Lazily emit IR to lower JIT layers -*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Contains the definition for a lazy-emitting layer for the JIT. 
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_EXECUTIONENGINE_ORC_LAZYEMITTINGLAYER_H
+#define LLVM_EXECUTIONENGINE_ORC_LAZYEMITTINGLAYER_H
+
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/StringMap.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/ExecutionEngine/JITSymbol.h"
+#include "llvm/ExecutionEngine/Orc/Core.h"
+#include "llvm/IR/GlobalValue.h"
+#include "llvm/IR/Mangler.h"
+#include "llvm/IR/Module.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/raw_ostream.h"
+#include <algorithm>
+#include <cassert>
+#include <list>
+#include <memory>
+#include <string>
+
+namespace llvm {
+namespace orc {
+
+/// Lazy-emitting IR layer.
+///
+/// This layer accepts LLVM IR Modules (via addModule) but does not
+/// immediately emit them to the layer below. Instead, emission to the base
+/// layer is deferred until the first time the client requests the address (via
+/// JITSymbol::getAddress) for a symbol contained in this layer.
+template <typename BaseLayerT> class LazyEmittingLayer {
+private:
+  class EmissionDeferredModule {
+  public:
+    EmissionDeferredModule(VModuleKey K, std::unique_ptr<Module> M)
+        : K(std::move(K)), M(std::move(M)) {}
+
+    JITSymbol find(StringRef Name, bool ExportedSymbolsOnly, BaseLayerT &B) {
+      switch (EmitState) {
+      case NotEmitted:
+        if (auto GV = searchGVs(Name, ExportedSymbolsOnly)) {
+          JITSymbolFlags Flags = JITSymbolFlags::fromGlobalValue(*GV);
+          auto GetAddress = [this, ExportedSymbolsOnly, Name = Name.str(),
+                             &B]() -> Expected<JITTargetAddress> {
+            if (this->EmitState == Emitting)
+              return 0;
+            else if (this->EmitState == NotEmitted) {
+              this->EmitState = Emitting;
+              if (auto Err = this->emitToBaseLayer(B))
+                return std::move(Err);
+              this->EmitState = Emitted;
+            }
+            if (auto Sym = B.findSymbolIn(K, Name, ExportedSymbolsOnly))
+              return Sym.getAddress();
+            else if (auto Err = Sym.takeError())
+              return std::move(Err);
+            else
+              llvm_unreachable("Successful symbol lookup should return "
+                               "definition address here");
+          };
+          return JITSymbol(std::move(GetAddress), Flags);
+        } else
+          return nullptr;
+      case Emitting:
+        // Calling "emit" can trigger a recursive call to 'find' (e.g. to check
+        // for pre-existing definitions of common symbols), but any symbol in
+        // this module would already have been found internally (in the
+        // RuntimeDyld that did the lookup), so just return a nullptr here.
+        return nullptr;
+      case Emitted:
+        return B.findSymbolIn(K, Name, ExportedSymbolsOnly);
+      }
+      llvm_unreachable("Invalid emit-state.");
+    }
+
+    Error removeModuleFromBaseLayer(BaseLayerT &BaseLayer) {
+      return EmitState != NotEmitted ? BaseLayer.removeModule(K)
+                                     : Error::success();
+    }
+
+    Error emitAndFinalize(BaseLayerT &BaseLayer) {
+      assert(EmitState != Emitting &&
+             "Cannot emitAndFinalize while already emitting");
+      if (EmitState == NotEmitted) {
+        EmitState = Emitting;
+        if (auto Err = emitToBaseLayer(BaseLayer))
+          return Err;
+        EmitState = Emitted;
+      }
+      return BaseLayer.emitAndFinalize(K);
+    }
+
+  private:
+
+    const GlobalValue* searchGVs(StringRef Name,
+                                 bool ExportedSymbolsOnly) const {
+      // FIXME: We could clean all this up if we had a way to reliably demangle
+      //        names: We could just demangle name and search, rather than
+      //        mangling everything else.
+
+      // If we have already built the mangled name set then just search it.
+ if (MangledSymbols) { + auto VI = MangledSymbols->find(Name); + if (VI == MangledSymbols->end()) + return nullptr; + auto GV = VI->second; + if (!ExportedSymbolsOnly || GV->hasDefaultVisibility()) + return GV; + return nullptr; + } + + // If we haven't built the mangled name set yet, try to build it. As an + // optimization this will leave MangledNames set to nullptr if we find + // Name in the process of building the set. + return buildMangledSymbols(Name, ExportedSymbolsOnly); + } + + Error emitToBaseLayer(BaseLayerT &BaseLayer) { + // We don't need the mangled names set any more: Once we've emitted this + // to the base layer we'll just look for symbols there. + MangledSymbols.reset(); + return BaseLayer.addModule(std::move(K), std::move(M)); + } + + // If the mangled name of the given GlobalValue matches the given search + // name (and its visibility conforms to the ExportedSymbolsOnly flag) then + // return the symbol. Otherwise, add the mangled name to the Names map and + // return nullptr. + const GlobalValue* addGlobalValue(StringMap<const GlobalValue*> &Names, + const GlobalValue &GV, + const Mangler &Mang, StringRef SearchName, + bool ExportedSymbolsOnly) const { + // Modules don't "provide" decls or common symbols. + if (GV.isDeclaration() || GV.hasCommonLinkage()) + return nullptr; + + // Mangle the GV name. + std::string MangledName; + { + raw_string_ostream MangledNameStream(MangledName); + Mang.getNameWithPrefix(MangledNameStream, &GV, false); + } + + // Check whether this is the name we were searching for, and if it is then + // bail out early. + if (MangledName == SearchName) + if (!ExportedSymbolsOnly || GV.hasDefaultVisibility()) + return &GV; + + // Otherwise add this to the map for later. + Names[MangledName] = &GV; + return nullptr; + } + + // Build the MangledSymbols map. Bails out early (with MangledSymbols left set + // to nullptr) if the given SearchName is found while building the map. + const GlobalValue* buildMangledSymbols(StringRef SearchName, + bool ExportedSymbolsOnly) const { + assert(!MangledSymbols && "Mangled symbols map already exists?"); + + auto Symbols = std::make_unique<StringMap<const GlobalValue*>>(); + + Mangler Mang; + + for (const auto &GO : M->global_objects()) + if (auto GV = addGlobalValue(*Symbols, GO, Mang, SearchName, + ExportedSymbolsOnly)) + return GV; + + MangledSymbols = std::move(Symbols); + return nullptr; + } + + enum { NotEmitted, Emitting, Emitted } EmitState = NotEmitted; + VModuleKey K; + std::unique_ptr<Module> M; + mutable std::unique_ptr<StringMap<const GlobalValue*>> MangledSymbols; + }; + + BaseLayerT &BaseLayer; + std::map<VModuleKey, std::unique_ptr<EmissionDeferredModule>> ModuleMap; + +public: + + /// Construct a lazy emitting layer. + LLVM_ATTRIBUTE_DEPRECATED( + LazyEmittingLayer(BaseLayerT &BaseLayer), + "ORCv1 layers (including LazyEmittingLayer) are deprecated. Please use " + "ORCv2, where lazy emission is the default"); + + /// Construct a lazy emitting layer. + LazyEmittingLayer(ORCv1DeprecationAcknowledgement, BaseLayerT &BaseLayer) + : BaseLayer(BaseLayer) {} + + /// Add the given module to the lazy emitting layer. + Error addModule(VModuleKey K, std::unique_ptr<Module> M) { + assert(!ModuleMap.count(K) && "VModuleKey K already in use"); + ModuleMap[K] = + std::make_unique<EmissionDeferredModule>(std::move(K), std::move(M)); + return Error::success(); + } + + /// Remove the module represented by the given handle. 
+  ///
+  /// This method will free the memory associated with the given module, both
+  /// in this layer, and the base layer.
+  Error removeModule(VModuleKey K) {
+    auto I = ModuleMap.find(K);
+    assert(I != ModuleMap.end() && "VModuleKey K not valid here");
+    auto EDM = std::move(I->second);
+    ModuleMap.erase(I);
+    return EDM->removeModuleFromBaseLayer(BaseLayer);
+  }
+
+  /// Search for the given named symbol.
+  /// @param Name The name of the symbol to search for.
+  /// @param ExportedSymbolsOnly If true, search only for exported symbols.
+  /// @return A handle for the given named symbol, if it exists.
+  JITSymbol findSymbol(const std::string &Name, bool ExportedSymbolsOnly) {
+    // Look for the symbol among existing definitions.
+    if (auto Symbol = BaseLayer.findSymbol(Name, ExportedSymbolsOnly))
+      return Symbol;
+
+    // If not found then search the deferred modules. If any of these contain a
+    // definition of 'Name' then they will return a JITSymbol that will emit
+    // the corresponding module when the symbol address is requested.
+    for (auto &KV : ModuleMap)
+      if (auto Symbol = KV.second->find(Name, ExportedSymbolsOnly, BaseLayer))
+        return Symbol;
+
+    // If no definition found anywhere return a null symbol.
+    return nullptr;
+  }
+
+  /// Get the address of the given symbol in the context of the set of
+  /// compiled modules represented by the key K.
+  JITSymbol findSymbolIn(VModuleKey K, const std::string &Name,
+                         bool ExportedSymbolsOnly) {
+    assert(ModuleMap.count(K) && "VModuleKey K not valid here");
+    return ModuleMap[K]->find(Name, ExportedSymbolsOnly, BaseLayer);
+  }
+
+  /// Immediately emit and finalize the module represented by the given
+  /// key.
+  Error emitAndFinalize(VModuleKey K) {
+    assert(ModuleMap.count(K) && "VModuleKey K not valid here");
+    return ModuleMap[K]->emitAndFinalize(BaseLayer);
+  }
+};
+
+template <typename BaseLayerT>
+LazyEmittingLayer<BaseLayerT>::LazyEmittingLayer(BaseLayerT &BaseLayer)
+    : BaseLayer(BaseLayer) {}
+
+} // end namespace orc
+} // end namespace llvm
+
+#endif // LLVM_EXECUTIONENGINE_ORC_LAZYEMITTINGLAYER_H
diff --git a/llvm/include/llvm/ExecutionEngine/Orc/LazyReexports.h b/llvm/include/llvm/ExecutionEngine/Orc/LazyReexports.h
new file mode 100644
index 000000000000..311ed59b1549
--- /dev/null
+++ b/llvm/include/llvm/ExecutionEngine/Orc/LazyReexports.h
@@ -0,0 +1,197 @@
+//===------ LazyReexports.h -- Utilities for lazy reexports -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Lazy re-exports are similar to normal re-exports, except that for callable
+// symbols the definitions are replaced with trampolines that will look up and
+// call through to the re-exported symbol at runtime. This can be used to
+// enable lazy compilation.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_EXECUTIONENGINE_ORC_LAZYREEXPORTS_H
+#define LLVM_EXECUTIONENGINE_ORC_LAZYREEXPORTS_H
+
+#include "llvm/ExecutionEngine/Orc/Core.h"
+#include "llvm/ExecutionEngine/Orc/IndirectionUtils.h"
+#include "llvm/ExecutionEngine/Orc/Speculation.h"
+
+namespace llvm {
+
+class Triple;
+
+namespace orc {
+
+/// Manages a set of 'lazy call-through' trampolines.
These are compiler +/// re-entry trampolines that are pre-bound to look up a given symbol in a given +/// JITDylib, then jump to that address. Since compilation of symbols is +/// triggered on first lookup, these call-through trampolines can be used to +/// implement lazy compilation. +/// +/// The easiest way to construct these call-throughs is using the lazyReexport +/// function. +class LazyCallThroughManager { +public: + /// Clients will want to take some action on first resolution, e.g. updating + /// a stub pointer. Instances of this class can be used to implement this. + class NotifyResolvedFunction { + public: + virtual ~NotifyResolvedFunction() {} + + /// Called the first time a lazy call through is executed and the target + /// symbol resolved. + virtual Error operator()(JITDylib &SourceJD, + const SymbolStringPtr &SymbolName, + JITTargetAddress ResolvedAddr) = 0; + + private: + virtual void anchor(); + }; + + template <typename NotifyResolvedImpl> + class NotifyResolvedFunctionImpl : public NotifyResolvedFunction { + public: + NotifyResolvedFunctionImpl(NotifyResolvedImpl NotifyResolved) + : NotifyResolved(std::move(NotifyResolved)) {} + Error operator()(JITDylib &SourceJD, const SymbolStringPtr &SymbolName, + JITTargetAddress ResolvedAddr) { + return NotifyResolved(SourceJD, SymbolName, ResolvedAddr); + } + + private: + NotifyResolvedImpl NotifyResolved; + }; + + /// Create a shared NotifyResolvedFunction from a given type that is + /// callable with the correct signature. + template <typename NotifyResolvedImpl> + static std::unique_ptr<NotifyResolvedFunction> + createNotifyResolvedFunction(NotifyResolvedImpl NotifyResolved) { + return std::make_unique<NotifyResolvedFunctionImpl<NotifyResolvedImpl>>( + std::move(NotifyResolved)); + } + + // Return a free call-through trampoline and bind it to look up and call + // through to the given symbol. + Expected<JITTargetAddress> getCallThroughTrampoline( + JITDylib &SourceJD, SymbolStringPtr SymbolName, + std::shared_ptr<NotifyResolvedFunction> NotifyResolved); + +protected: + LazyCallThroughManager(ExecutionSession &ES, + JITTargetAddress ErrorHandlerAddr, + std::unique_ptr<TrampolinePool> TP); + + JITTargetAddress callThroughToSymbol(JITTargetAddress TrampolineAddr); + + void setTrampolinePool(std::unique_ptr<TrampolinePool> TP) { + this->TP = std::move(TP); + } + +private: + using ReexportsMap = + std::map<JITTargetAddress, std::pair<JITDylib *, SymbolStringPtr>>; + + using NotifiersMap = + std::map<JITTargetAddress, std::shared_ptr<NotifyResolvedFunction>>; + + std::mutex LCTMMutex; + ExecutionSession &ES; + JITTargetAddress ErrorHandlerAddr; + std::unique_ptr<TrampolinePool> TP; + ReexportsMap Reexports; + NotifiersMap Notifiers; +}; + +/// A lazy call-through manager that builds trampolines in the current process. +class LocalLazyCallThroughManager : public LazyCallThroughManager { +private: + LocalLazyCallThroughManager(ExecutionSession &ES, + JITTargetAddress ErrorHandlerAddr) + : LazyCallThroughManager(ES, ErrorHandlerAddr, nullptr) {} + + template <typename ORCABI> Error init() { + auto TP = LocalTrampolinePool<ORCABI>::Create( + [this](JITTargetAddress TrampolineAddr) { + return callThroughToSymbol(TrampolineAddr); + }); + + if (!TP) + return TP.takeError(); + + setTrampolinePool(std::move(*TP)); + return Error::success(); + } + +public: + /// Create a LocalLazyCallThroughManager using the given ABI. See + /// createLocalLazyCallThroughManager. 
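  // -------------------------------------------------------------------------
  // Editorial usage sketch (not part of the original header): a
  // NotifyResolvedFunction, built with createNotifyResolvedFunction above, that
  // redirects an indirect stub at the resolved address the first time the
  // call-through fires (roughly what the lazy re-export utilities below do).
  // The IndirectStubsManager ISM is assumed to outlive the functor.
  //
  //   auto OnResolved = LazyCallThroughManager::createNotifyResolvedFunction(
  //       [&ISM](JITDylib &SourceJD, const SymbolStringPtr &SymbolName,
  //              JITTargetAddress ResolvedAddr) -> Error {
  //         // Subsequent calls now go straight to the resolved symbol rather
  //         // than bouncing through the call-through trampoline.
  //         return ISM.updatePointer(*SymbolName, ResolvedAddr);
  //       });
  // -------------------------------------------------------------------------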
+ template <typename ORCABI> + static Expected<std::unique_ptr<LocalLazyCallThroughManager>> + Create(ExecutionSession &ES, JITTargetAddress ErrorHandlerAddr) { + auto LLCTM = std::unique_ptr<LocalLazyCallThroughManager>( + new LocalLazyCallThroughManager(ES, ErrorHandlerAddr)); + + if (auto Err = LLCTM->init<ORCABI>()) + return std::move(Err); + + return std::move(LLCTM); + } +}; + +/// Create a LocalLazyCallThroughManager from the given triple and execution +/// session. +Expected<std::unique_ptr<LazyCallThroughManager>> +createLocalLazyCallThroughManager(const Triple &T, ExecutionSession &ES, + JITTargetAddress ErrorHandlerAddr); + +/// A materialization unit that builds lazy re-exports. These are callable +/// entry points that call through to the given symbols. +/// Unlike a 'true' re-export, the address of the lazy re-export will not +/// match the address of the re-exported symbol, but calling it will behave +/// the same as calling the re-exported symbol. +class LazyReexportsMaterializationUnit : public MaterializationUnit { +public: + LazyReexportsMaterializationUnit(LazyCallThroughManager &LCTManager, + IndirectStubsManager &ISManager, + JITDylib &SourceJD, + SymbolAliasMap CallableAliases, + ImplSymbolMap *SrcJDLoc, VModuleKey K); + + StringRef getName() const override; + +private: + void materialize(MaterializationResponsibility R) override; + void discard(const JITDylib &JD, const SymbolStringPtr &Name) override; + static SymbolFlagsMap extractFlags(const SymbolAliasMap &Aliases); + + LazyCallThroughManager &LCTManager; + IndirectStubsManager &ISManager; + JITDylib &SourceJD; + SymbolAliasMap CallableAliases; + std::shared_ptr<LazyCallThroughManager::NotifyResolvedFunction> + NotifyResolved; + ImplSymbolMap *AliaseeTable; +}; + +/// Define lazy-reexports based on the given SymbolAliasMap. Each lazy re-export +/// is a callable symbol that will look up and dispatch to the given aliasee on +/// first call. All subsequent calls will go directly to the aliasee. +inline std::unique_ptr<LazyReexportsMaterializationUnit> +lazyReexports(LazyCallThroughManager &LCTManager, + IndirectStubsManager &ISManager, JITDylib &SourceJD, + SymbolAliasMap CallableAliases, ImplSymbolMap *SrcJDLoc = nullptr, + VModuleKey K = VModuleKey()) { + return std::make_unique<LazyReexportsMaterializationUnit>( + LCTManager, ISManager, SourceJD, std::move(CallableAliases), SrcJDLoc, + std::move(K)); +} + +} // End namespace orc +} // End namespace llvm + +#endif // LLVM_EXECUTIONENGINE_ORC_LAZYREEXPORTS_H diff --git a/llvm/include/llvm/ExecutionEngine/Orc/Legacy.h b/llvm/include/llvm/ExecutionEngine/Orc/Legacy.h new file mode 100644 index 000000000000..148e260c9569 --- /dev/null +++ b/llvm/include/llvm/ExecutionEngine/Orc/Legacy.h @@ -0,0 +1,215 @@ +//===--- Legacy.h -- Adapters for ExecutionEngine API interop ---*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Contains core ORC APIs. 
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_EXECUTIONENGINE_ORC_LEGACY_H
+#define LLVM_EXECUTIONENGINE_ORC_LEGACY_H
+
+#include "llvm/ExecutionEngine/JITSymbol.h"
+#include "llvm/ExecutionEngine/Orc/Core.h"
+
+namespace llvm {
+namespace orc {
+
+/// SymbolResolver is a composable interface for looking up symbol flags
+/// and addresses using the AsynchronousSymbolQuery type. It will
+/// eventually replace the LegacyJITSymbolResolver interface as the
+/// standard ORC symbol resolver type.
+///
+/// FIXME: SymbolResolvers should go away and be replaced with VSOs with
+///        definition generators.
+class SymbolResolver {
+public:
+  virtual ~SymbolResolver() = default;
+
+  /// Returns the subset of the given symbols that the caller is responsible
+  /// for materializing.
+  virtual SymbolNameSet getResponsibilitySet(const SymbolNameSet &Symbols) = 0;
+
+  /// For each symbol in Symbols that can be found, assigns that symbol's
+  /// value in Query. Returns the set of symbols that could not be found.
+  virtual SymbolNameSet lookup(std::shared_ptr<AsynchronousSymbolQuery> Query,
+                               SymbolNameSet Symbols) = 0;
+
+private:
+  virtual void anchor();
+};
+
+/// Implements SymbolResolver with a pair of supplied function objects
+/// for convenience. See createSymbolResolver.
+template <typename GetResponsibilitySetFn, typename LookupFn>
+class LambdaSymbolResolver final : public SymbolResolver {
+public:
+  template <typename GetResponsibilitySetFnRef, typename LookupFnRef>
+  LambdaSymbolResolver(GetResponsibilitySetFnRef &&GetResponsibilitySet,
+                       LookupFnRef &&Lookup)
+      : GetResponsibilitySet(
+            std::forward<GetResponsibilitySetFnRef>(GetResponsibilitySet)),
+        Lookup(std::forward<LookupFnRef>(Lookup)) {}
+
+  SymbolNameSet getResponsibilitySet(const SymbolNameSet &Symbols) final {
+    return GetResponsibilitySet(Symbols);
+  }
+
+  SymbolNameSet lookup(std::shared_ptr<AsynchronousSymbolQuery> Query,
+                       SymbolNameSet Symbols) final {
+    return Lookup(std::move(Query), std::move(Symbols));
+  }
+
+private:
+  GetResponsibilitySetFn GetResponsibilitySet;
+  LookupFn Lookup;
+};
+
+/// Creates a SymbolResolver implementation from the pair of supplied
+/// function objects.
+template <typename GetResponsibilitySetFn, typename LookupFn>
+std::unique_ptr<LambdaSymbolResolver<
+    typename std::remove_cv<
+        typename std::remove_reference<GetResponsibilitySetFn>::type>::type,
+    typename std::remove_cv<
+        typename std::remove_reference<LookupFn>::type>::type>>
+createSymbolResolver(GetResponsibilitySetFn &&GetResponsibilitySet,
+                     LookupFn &&Lookup) {
+  using LambdaSymbolResolverImpl = LambdaSymbolResolver<
+      typename std::remove_cv<
+          typename std::remove_reference<GetResponsibilitySetFn>::type>::type,
+      typename std::remove_cv<
+          typename std::remove_reference<LookupFn>::type>::type>;
+  return std::make_unique<LambdaSymbolResolverImpl>(
+      std::forward<GetResponsibilitySetFn>(GetResponsibilitySet),
+      std::forward<LookupFn>(Lookup));
+}
+
+/// Legacy adapter. Remove once we kill off the old ORC layers.
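// ---------------------------------------------------------------------------
// Editorial usage sketch (not part of the original header): building a
// SymbolResolver from two lambdas with createSymbolResolver above. The bodies
// are deliberately trivial placeholders; a real resolver would search a
// JITDylib or other symbol source instead of refusing every request.
inline std::unique_ptr<SymbolResolver> createRefuseAllResolverExample() {
  return createSymbolResolver(
      [](const SymbolNameSet &Symbols) {
        return SymbolNameSet(); // Claim responsibility for nothing.
      },
      [](std::shared_ptr<AsynchronousSymbolQuery> Query,
         SymbolNameSet Symbols) {
        return Symbols; // Resolve nothing: report every name as not found.
      });
}
// ---------------------------------------------------------------------------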
+class JITSymbolResolverAdapter : public JITSymbolResolver { +public: + JITSymbolResolverAdapter(ExecutionSession &ES, SymbolResolver &R, + MaterializationResponsibility *MR); + Expected<LookupSet> getResponsibilitySet(const LookupSet &Symbols) override; + void lookup(const LookupSet &Symbols, OnResolvedFunction OnResolved) override; + +private: + ExecutionSession &ES; + std::set<SymbolStringPtr> ResolvedStrings; + SymbolResolver &R; + MaterializationResponsibility *MR; +}; + +/// Use the given legacy-style FindSymbol function (i.e. a function that takes +/// a const std::string& or StringRef and returns a JITSymbol) to get the +/// subset of symbols that the caller is responsible for materializing. If any +/// JITSymbol returned by FindSymbol is in an error state the function returns +/// immediately with that error. +/// +/// Useful for implementing getResponsibilitySet bodies that query legacy +/// resolvers. +template <typename FindSymbolFn> +Expected<SymbolNameSet> +getResponsibilitySetWithLegacyFn(const SymbolNameSet &Symbols, + FindSymbolFn FindSymbol) { + SymbolNameSet Result; + + for (auto &S : Symbols) { + if (JITSymbol Sym = FindSymbol(*S)) { + if (!Sym.getFlags().isStrong()) + Result.insert(S); + } else if (auto Err = Sym.takeError()) + return std::move(Err); + } + + return Result; +} + +/// Use the given legacy-style FindSymbol function (i.e. a function that +/// takes a const std::string& or StringRef and returns a JITSymbol) to +/// find the address and flags for each symbol in Symbols and store the +/// result in Query. If any JITSymbol returned by FindSymbol is in an +/// error then Query.notifyFailed(...) is called with that error and the +/// function returns immediately. On success, returns the set of symbols +/// not found. +/// +/// Useful for implementing lookup bodies that query legacy resolvers. 
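// ---------------------------------------------------------------------------
// Editorial usage sketch (not part of the original header): driving the helper
// above with a legacy findSymbol-style functor. The functor here is a stand-in
// that finds nothing; a real client would consult its own symbol tables (see
// also createLegacyLookupResolver at the end of this header).
inline Expected<SymbolNameSet>
getResponsibilitySetExample(const SymbolNameSet &Symbols) {
  auto FindSymbol = [](StringRef Name) -> JITSymbol { return nullptr; };
  return getResponsibilitySetWithLegacyFn(Symbols, FindSymbol);
}
// ---------------------------------------------------------------------------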
+template <typename FindSymbolFn> +SymbolNameSet +lookupWithLegacyFn(ExecutionSession &ES, AsynchronousSymbolQuery &Query, + const SymbolNameSet &Symbols, FindSymbolFn FindSymbol) { + SymbolNameSet SymbolsNotFound; + bool NewSymbolsResolved = false; + + for (auto &S : Symbols) { + if (JITSymbol Sym = FindSymbol(*S)) { + if (auto Addr = Sym.getAddress()) { + Query.notifySymbolMetRequiredState( + S, JITEvaluatedSymbol(*Addr, Sym.getFlags())); + NewSymbolsResolved = true; + } else { + ES.legacyFailQuery(Query, Addr.takeError()); + return SymbolNameSet(); + } + } else if (auto Err = Sym.takeError()) { + ES.legacyFailQuery(Query, std::move(Err)); + return SymbolNameSet(); + } else + SymbolsNotFound.insert(S); + } + + if (NewSymbolsResolved && Query.isComplete()) + Query.handleComplete(); + + return SymbolsNotFound; +} + +/// An ORC SymbolResolver implementation that uses a legacy +/// findSymbol-like function to perform lookup; +template <typename LegacyLookupFn> +class LegacyLookupFnResolver final : public SymbolResolver { +public: + using ErrorReporter = std::function<void(Error)>; + + LegacyLookupFnResolver(ExecutionSession &ES, LegacyLookupFn LegacyLookup, + ErrorReporter ReportError) + : ES(ES), LegacyLookup(std::move(LegacyLookup)), + ReportError(std::move(ReportError)) {} + + SymbolNameSet getResponsibilitySet(const SymbolNameSet &Symbols) final { + if (auto ResponsibilitySet = + getResponsibilitySetWithLegacyFn(Symbols, LegacyLookup)) + return std::move(*ResponsibilitySet); + else { + ReportError(ResponsibilitySet.takeError()); + return SymbolNameSet(); + } + } + + SymbolNameSet lookup(std::shared_ptr<AsynchronousSymbolQuery> Query, + SymbolNameSet Symbols) final { + return lookupWithLegacyFn(ES, *Query, Symbols, LegacyLookup); + } + +private: + ExecutionSession &ES; + LegacyLookupFn LegacyLookup; + ErrorReporter ReportError; +}; + +template <typename LegacyLookupFn> +std::shared_ptr<LegacyLookupFnResolver<LegacyLookupFn>> +createLegacyLookupResolver(ExecutionSession &ES, LegacyLookupFn LegacyLookup, + std::function<void(Error)> ErrorReporter) { + return std::make_shared<LegacyLookupFnResolver<LegacyLookupFn>>( + ES, std::move(LegacyLookup), std::move(ErrorReporter)); +} + +} // End namespace orc +} // End namespace llvm + +#endif // LLVM_EXECUTIONENGINE_ORC_LEGACY_H diff --git a/llvm/include/llvm/ExecutionEngine/Orc/NullResolver.h b/llvm/include/llvm/ExecutionEngine/Orc/NullResolver.h new file mode 100644 index 000000000000..ffa37a13d064 --- /dev/null +++ b/llvm/include/llvm/ExecutionEngine/Orc/NullResolver.h @@ -0,0 +1,43 @@ +//===------ NullResolver.h - Reject symbol lookup requests ------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Defines a RuntimeDyld::SymbolResolver subclass that rejects all symbol +// resolution requests, for clients that have no cross-object fixups. 
+// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_EXECUTIONENGINE_ORC_NULLRESOLVER_H +#define LLVM_EXECUTIONENGINE_ORC_NULLRESOLVER_H + +#include "llvm/ExecutionEngine/Orc/Legacy.h" +#include "llvm/ExecutionEngine/RuntimeDyld.h" + +namespace llvm { +namespace orc { + +class NullResolver : public SymbolResolver { +public: + SymbolNameSet getResponsibilitySet(const SymbolNameSet &Symbols) final; + + SymbolNameSet lookup(std::shared_ptr<AsynchronousSymbolQuery> Query, + SymbolNameSet Symbols) final; +}; + +/// SymbolResolver impliementation that rejects all resolution requests. +/// Useful for clients that have no cross-object fixups. +class NullLegacyResolver : public LegacyJITSymbolResolver { +public: + JITSymbol findSymbol(const std::string &Name) final; + + JITSymbol findSymbolInLogicalDylib(const std::string &Name) final; +}; + +} // End namespace orc. +} // End namespace llvm. + +#endif // LLVM_EXECUTIONENGINE_ORC_NULLRESOLVER_H diff --git a/llvm/include/llvm/ExecutionEngine/Orc/ObjectLinkingLayer.h b/llvm/include/llvm/ExecutionEngine/Orc/ObjectLinkingLayer.h new file mode 100644 index 000000000000..caf8e707516d --- /dev/null +++ b/llvm/include/llvm/ExecutionEngine/Orc/ObjectLinkingLayer.h @@ -0,0 +1,182 @@ +//===-- ObjectLinkingLayer.h - JITLink-based jit linking layer --*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Contains the definition for an JITLink-based, in-process object linking +// layer. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_EXECUTIONENGINE_ORC_OBJECTLINKINGLAYER_H +#define LLVM_EXECUTIONENGINE_ORC_OBJECTLINKINGLAYER_H + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringMap.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ExecutionEngine/JITLink/JITLink.h" +#include "llvm/ExecutionEngine/JITSymbol.h" +#include "llvm/ExecutionEngine/Orc/Core.h" +#include "llvm/ExecutionEngine/Orc/Layer.h" +#include "llvm/Support/Error.h" +#include <algorithm> +#include <cassert> +#include <functional> +#include <list> +#include <memory> +#include <string> +#include <utility> +#include <vector> + +namespace llvm { + +namespace jitlink { +class EHFrameRegistrar; +} // namespace jitlink + +namespace object { +class ObjectFile; +} // namespace object + +namespace orc { + +class ObjectLinkingLayerJITLinkContext; + +/// An ObjectLayer implementation built on JITLink. +/// +/// Clients can use this class to add relocatable object files to an +/// ExecutionSession, and it typically serves as the base layer (underneath +/// a compiling layer like IRCompileLayer) for the rest of the JIT. +class ObjectLinkingLayer : public ObjectLayer { + friend class ObjectLinkingLayerJITLinkContext; + +public: + /// Plugin instances can be added to the ObjectLinkingLayer to receive + /// callbacks when code is loaded or emitted, and when JITLink is being + /// configured. 
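  // -------------------------------------------------------------------------
  // Editorial usage sketch (not part of the original header): a minimal Plugin
  // (interface declared immediately below) that just counts emitted objects;
  // real plugins register EH frames, record debug info, and so on. It would be
  // attached with ObjectLinkingLayer::addPlugin.
  //
  //   class CountingPlugin : public ObjectLinkingLayer::Plugin {
  //   public:
  //     Error notifyEmitted(MaterializationResponsibility &MR) override {
  //       ++EmittedObjects;
  //       return Error::success();
  //     }
  //     size_t EmittedObjects = 0;
  //   };
  //
  //   ObjLinkingLayer.addPlugin(std::make_unique<CountingPlugin>());
  // -------------------------------------------------------------------------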
+ class Plugin { + public: + virtual ~Plugin(); + virtual void modifyPassConfig(MaterializationResponsibility &MR, + const Triple &TT, + jitlink::PassConfiguration &Config) {} + virtual void notifyLoaded(MaterializationResponsibility &MR) {} + virtual Error notifyEmitted(MaterializationResponsibility &MR) { + return Error::success(); + } + virtual Error notifyRemovingModule(VModuleKey K) { + return Error::success(); + } + virtual Error notifyRemovingAllModules() { return Error::success(); } + }; + + using ReturnObjectBufferFunction = + std::function<void(std::unique_ptr<MemoryBuffer>)>; + + /// Construct an ObjectLinkingLayer with the given NotifyLoaded, + /// and NotifyEmitted functors. + ObjectLinkingLayer(ExecutionSession &ES, + jitlink::JITLinkMemoryManager &MemMgr); + + /// Destruct an ObjectLinkingLayer. + ~ObjectLinkingLayer(); + + /// Set an object buffer return function. By default object buffers are + /// deleted once the JIT has linked them. If a return function is set then + /// it will be called to transfer ownership of the buffer instead. + void setReturnObjectBuffer(ReturnObjectBufferFunction ReturnObjectBuffer) { + this->ReturnObjectBuffer = std::move(ReturnObjectBuffer); + } + + /// Add a pass-config modifier. + ObjectLinkingLayer &addPlugin(std::unique_ptr<Plugin> P) { + std::lock_guard<std::mutex> Lock(LayerMutex); + Plugins.push_back(std::move(P)); + return *this; + } + + /// Emit the object. + void emit(MaterializationResponsibility R, + std::unique_ptr<MemoryBuffer> O) override; + + /// Instructs this ObjectLinkingLayer instance to override the symbol flags + /// found in the AtomGraph with the flags supplied by the + /// MaterializationResponsibility instance. This is a workaround to support + /// symbol visibility in COFF, which does not use the libObject's + /// SF_Exported flag. Use only when generating / adding COFF object files. + /// + /// FIXME: We should be able to remove this if/when COFF properly tracks + /// exported symbols. + ObjectLinkingLayer & + setOverrideObjectFlagsWithResponsibilityFlags(bool OverrideObjectFlags) { + this->OverrideObjectFlags = OverrideObjectFlags; + return *this; + } + + /// If set, this ObjectLinkingLayer instance will claim responsibility + /// for any symbols provided by a given object file that were not already in + /// the MaterializationResponsibility instance. Setting this flag allows + /// higher-level program representations (e.g. LLVM IR) to be added based on + /// only a subset of the symbols they provide, without having to write + /// intervening layers to scan and add the additional symbols. This trades + /// diagnostic quality for convenience however: If all symbols are enumerated + /// up-front then clashes can be detected and reported early (and usually + /// deterministically). If this option is set, clashes for the additional + /// symbols may not be detected until late, and detection may depend on + /// the flow of control through JIT'd code. Use with care. 
+ ObjectLinkingLayer & + setAutoClaimResponsibilityForObjectSymbols(bool AutoClaimObjectSymbols) { + this->AutoClaimObjectSymbols = AutoClaimObjectSymbols; + return *this; + } + +private: + using AllocPtr = std::unique_ptr<jitlink::JITLinkMemoryManager::Allocation>; + + void modifyPassConfig(MaterializationResponsibility &MR, const Triple &TT, + jitlink::PassConfiguration &PassConfig); + void notifyLoaded(MaterializationResponsibility &MR); + Error notifyEmitted(MaterializationResponsibility &MR, AllocPtr Alloc); + + Error removeModule(VModuleKey K); + Error removeAllModules(); + + mutable std::mutex LayerMutex; + jitlink::JITLinkMemoryManager &MemMgr; + bool OverrideObjectFlags = false; + bool AutoClaimObjectSymbols = false; + ReturnObjectBufferFunction ReturnObjectBuffer; + DenseMap<VModuleKey, AllocPtr> TrackedAllocs; + std::vector<AllocPtr> UntrackedAllocs; + std::vector<std::unique_ptr<Plugin>> Plugins; +}; + +class EHFrameRegistrationPlugin : public ObjectLinkingLayer::Plugin { +public: + EHFrameRegistrationPlugin(jitlink::EHFrameRegistrar &Registrar); + Error notifyEmitted(MaterializationResponsibility &MR) override; + void modifyPassConfig(MaterializationResponsibility &MR, const Triple &TT, + jitlink::PassConfiguration &PassConfig) override; + Error notifyRemovingModule(VModuleKey K) override; + Error notifyRemovingAllModules() override; + +private: + + struct EHFrameRange { + JITTargetAddress Addr = 0; + size_t Size; + }; + + jitlink::EHFrameRegistrar &Registrar; + DenseMap<MaterializationResponsibility *, EHFrameRange> InProcessLinks; + DenseMap<VModuleKey, EHFrameRange> TrackedEHFrameRanges; + std::vector<EHFrameRange> UntrackedEHFrameRanges; +}; + +} // end namespace orc +} // end namespace llvm + +#endif // LLVM_EXECUTIONENGINE_ORC_OBJECTLINKINGLAYER_H diff --git a/llvm/include/llvm/ExecutionEngine/Orc/ObjectTransformLayer.h b/llvm/include/llvm/ExecutionEngine/Orc/ObjectTransformLayer.h new file mode 100644 index 000000000000..eac1cc3e097a --- /dev/null +++ b/llvm/include/llvm/ExecutionEngine/Orc/ObjectTransformLayer.h @@ -0,0 +1,127 @@ +//===- ObjectTransformLayer.h - Run all objects through functor -*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Run all objects passed in through a user supplied functor. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_EXECUTIONENGINE_ORC_OBJECTTRANSFORMLAYER_H +#define LLVM_EXECUTIONENGINE_ORC_OBJECTTRANSFORMLAYER_H + +#include "llvm/ExecutionEngine/JITSymbol.h" +#include "llvm/ExecutionEngine/Orc/Layer.h" +#include <algorithm> +#include <memory> +#include <string> + +namespace llvm { +namespace orc { + +class ObjectTransformLayer : public ObjectLayer { +public: + using TransformFunction = + std::function<Expected<std::unique_ptr<MemoryBuffer>>( + std::unique_ptr<MemoryBuffer>)>; + + ObjectTransformLayer(ExecutionSession &ES, ObjectLayer &BaseLayer, + TransformFunction Transform); + + void emit(MaterializationResponsibility R, + std::unique_ptr<MemoryBuffer> O) override; + +private: + ObjectLayer &BaseLayer; + TransformFunction Transform; +}; + +/// Object mutating layer. +/// +/// This layer accepts sets of ObjectFiles (via addObject). 
It +/// immediately applies the user supplied functor to each object, then adds +/// the set of transformed objects to the layer below. +template <typename BaseLayerT, typename TransformFtor> +class LegacyObjectTransformLayer { +public: + /// Construct an ObjectTransformLayer with the given BaseLayer + LLVM_ATTRIBUTE_DEPRECATED( + LegacyObjectTransformLayer(BaseLayerT &BaseLayer, + TransformFtor Transform = TransformFtor()), + "ORCv1 layers (layers with the 'Legacy' prefix) are deprecated. Please " + "use " + "the ORCv2 ObjectTransformLayer instead"); + + /// Legacy layer constructor with deprecation acknowledgement. + LegacyObjectTransformLayer(ORCv1DeprecationAcknowledgement, + BaseLayerT &BaseLayer, + TransformFtor Transform = TransformFtor()) + : BaseLayer(BaseLayer), Transform(std::move(Transform)) {} + + /// Apply the transform functor to each object in the object set, then + /// add the resulting set of objects to the base layer, along with the + /// memory manager and symbol resolver. + /// + /// @return A handle for the added objects. + template <typename ObjectPtr> Error addObject(VModuleKey K, ObjectPtr Obj) { + return BaseLayer.addObject(std::move(K), Transform(std::move(Obj))); + } + + /// Remove the object set associated with the VModuleKey K. + Error removeObject(VModuleKey K) { return BaseLayer.removeObject(K); } + + /// Search for the given named symbol. + /// @param Name The name of the symbol to search for. + /// @param ExportedSymbolsOnly If true, search only for exported symbols. + /// @return A handle for the given named symbol, if it exists. + JITSymbol findSymbol(const std::string &Name, bool ExportedSymbolsOnly) { + return BaseLayer.findSymbol(Name, ExportedSymbolsOnly); + } + + /// Get the address of the given symbol in the context of the set of + /// objects represented by the VModuleKey K. This call is forwarded to + /// the base layer's implementation. + /// @param K The VModuleKey associated with the object set to search in. + /// @param Name The name of the symbol to search for. + /// @param ExportedSymbolsOnly If true, search only for exported symbols. + /// @return A handle for the given named symbol, if it is found in the + /// given object set. + JITSymbol findSymbolIn(VModuleKey K, const std::string &Name, + bool ExportedSymbolsOnly) { + return BaseLayer.findSymbolIn(K, Name, ExportedSymbolsOnly); + } + + /// Immediately emit and finalize the object set represented by the + /// given VModuleKey K. + Error emitAndFinalize(VModuleKey K) { return BaseLayer.emitAndFinalize(K); } + + /// Map section addresses for the objects associated with the + /// VModuleKey K. + void mapSectionAddress(VModuleKey K, const void *LocalAddress, + JITTargetAddress TargetAddr) { + BaseLayer.mapSectionAddress(K, LocalAddress, TargetAddr); + } + + /// Access the transform functor directly. + TransformFtor &getTransform() { return Transform; } + + /// Access the mumate functor directly. 
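  // -------------------------------------------------------------------------
  // Editorial usage sketch (not part of the original header): a
  // TransformFunction for the ORCv2 ObjectTransformLayer declared earlier in
  // this file. This one passes the buffer through unchanged; a real transform
  // might instrument, rewrite, or archive the object before handing it on.
  //
  //   auto IdentityTransform =
  //       [](std::unique_ptr<MemoryBuffer> Obj)
  //           -> Expected<std::unique_ptr<MemoryBuffer>> {
  //     return std::move(Obj);
  //   };
  //   ObjectTransformLayer TL(ES, BaseObjectLayer, std::move(IdentityTransform));
  // -------------------------------------------------------------------------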
+ const TransformFtor &getTransform() const { return Transform; } + +private: + BaseLayerT &BaseLayer; + TransformFtor Transform; +}; + +template <typename BaseLayerT, typename TransformFtor> +LegacyObjectTransformLayer<BaseLayerT, TransformFtor>:: + LegacyObjectTransformLayer(BaseLayerT &BaseLayer, TransformFtor Transform) + : BaseLayer(BaseLayer), Transform(std::move(Transform)) {} + +} // end namespace orc +} // end namespace llvm + +#endif // LLVM_EXECUTIONENGINE_ORC_OBJECTTRANSFORMLAYER_H diff --git a/llvm/include/llvm/ExecutionEngine/Orc/OrcABISupport.h b/llvm/include/llvm/ExecutionEngine/Orc/OrcABISupport.h new file mode 100644 index 000000000000..38246bc480b6 --- /dev/null +++ b/llvm/include/llvm/ExecutionEngine/Orc/OrcABISupport.h @@ -0,0 +1,314 @@ +//===- OrcABISupport.h - ABI support code -----------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// ABI specific code for Orc, e.g. callback assembly. +// +// ABI classes should be part of the JIT *target* process, not the host +// process (except where you're doing hosted JITing and the two are one and the +// same). +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_EXECUTIONENGINE_ORC_ORCABISUPPORT_H +#define LLVM_EXECUTIONENGINE_ORC_ORCABISUPPORT_H + +#include "llvm/ExecutionEngine/JITSymbol.h" +#include "llvm/Support/Error.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/Memory.h" +#include <algorithm> +#include <cstdint> + +namespace llvm { +namespace orc { + +/// Generic ORC ABI support. +/// +/// This class can be substituted as the target architecure support class for +/// ORC templates that require one (e.g. IndirectStubsManagers). It does not +/// support lazy JITing however, and any attempt to use that functionality +/// will result in execution of an llvm_unreachable. +class OrcGenericABI { +public: + static const unsigned PointerSize = sizeof(uintptr_t); + static const unsigned TrampolineSize = 1; + static const unsigned ResolverCodeSize = 1; + + using JITReentryFn = JITTargetAddress (*)(void *CallbackMgr, + void *TrampolineId); + + static void writeResolverCode(uint8_t *ResolveMem, JITReentryFn Reentry, + void *CallbackMgr) { + llvm_unreachable("writeResolverCode is not supported by the generic host " + "support class"); + } + + static void writeTrampolines(uint8_t *TrampolineMem, void *ResolverAddr, + unsigned NumTrampolines) { + llvm_unreachable("writeTrampolines is not supported by the generic host " + "support class"); + } + + class IndirectStubsInfo { + public: + const static unsigned StubSize = 1; + + unsigned getNumStubs() const { llvm_unreachable("Not supported"); } + void *getStub(unsigned Idx) const { llvm_unreachable("Not supported"); } + void **getPtr(unsigned Idx) const { llvm_unreachable("Not supported"); } + }; + + static Error emitIndirectStubsBlock(IndirectStubsInfo &StubsInfo, + unsigned MinStubs, void *InitialPtrVal) { + llvm_unreachable("emitIndirectStubsBlock is not supported by the generic " + "host support class"); + } +}; + +/// Provide information about stub blocks generated by the +/// makeIndirectStubsBlock function. 
+template <unsigned StubSizeVal> class GenericIndirectStubsInfo { +public: + const static unsigned StubSize = StubSizeVal; + + GenericIndirectStubsInfo() = default; + GenericIndirectStubsInfo(unsigned NumStubs, sys::OwningMemoryBlock StubsMem) + : NumStubs(NumStubs), StubsMem(std::move(StubsMem)) {} + GenericIndirectStubsInfo(GenericIndirectStubsInfo &&Other) + : NumStubs(Other.NumStubs), StubsMem(std::move(Other.StubsMem)) { + Other.NumStubs = 0; + } + + GenericIndirectStubsInfo &operator=(GenericIndirectStubsInfo &&Other) { + NumStubs = Other.NumStubs; + Other.NumStubs = 0; + StubsMem = std::move(Other.StubsMem); + return *this; + } + + /// Number of stubs in this block. + unsigned getNumStubs() const { return NumStubs; } + + /// Get a pointer to the stub at the given index, which must be in + /// the range 0 .. getNumStubs() - 1. + void *getStub(unsigned Idx) const { + return static_cast<char *>(StubsMem.base()) + Idx * StubSize; + } + + /// Get a pointer to the implementation-pointer at the given index, + /// which must be in the range 0 .. getNumStubs() - 1. + void **getPtr(unsigned Idx) const { + char *PtrsBase = static_cast<char *>(StubsMem.base()) + NumStubs * StubSize; + return reinterpret_cast<void **>(PtrsBase) + Idx; + } + +private: + unsigned NumStubs = 0; + sys::OwningMemoryBlock StubsMem; +}; + +class OrcAArch64 { +public: + static const unsigned PointerSize = 8; + static const unsigned TrampolineSize = 12; + static const unsigned ResolverCodeSize = 0x120; + + using IndirectStubsInfo = GenericIndirectStubsInfo<8>; + + using JITReentryFn = JITTargetAddress (*)(void *CallbackMgr, + void *TrampolineId); + + /// Write the resolver code into the given memory. The user is be + /// responsible for allocating the memory and setting permissions. + static void writeResolverCode(uint8_t *ResolveMem, JITReentryFn Reentry, + void *CallbackMgr); + + /// Write the requsted number of trampolines into the given memory, + /// which must be big enough to hold 1 pointer, plus NumTrampolines + /// trampolines. + static void writeTrampolines(uint8_t *TrampolineMem, void *ResolverAddr, + unsigned NumTrampolines); + + /// Emit at least MinStubs worth of indirect call stubs, rounded out to + /// the nearest page size. + /// + /// E.g. Asking for 4 stubs on x86-64, where stubs are 8-bytes, with 4k + /// pages will return a block of 512 stubs (4096 / 8 = 512). Asking for 513 + /// will return a block of 1024 (2-pages worth). + static Error emitIndirectStubsBlock(IndirectStubsInfo &StubsInfo, + unsigned MinStubs, void *InitialPtrVal); +}; + +/// X86_64 code that's common to all ABIs. +/// +/// X86_64 supports lazy JITing. +class OrcX86_64_Base { +public: + static const unsigned PointerSize = 8; + static const unsigned TrampolineSize = 8; + + using IndirectStubsInfo = GenericIndirectStubsInfo<8>; + + /// Write the requsted number of trampolines into the given memory, + /// which must be big enough to hold 1 pointer, plus NumTrampolines + /// trampolines. + static void writeTrampolines(uint8_t *TrampolineMem, void *ResolverAddr, + unsigned NumTrampolines); + + /// Emit at least MinStubs worth of indirect call stubs, rounded out to + /// the nearest page size. + /// + /// E.g. Asking for 4 stubs on x86-64, where stubs are 8-bytes, with 4k + /// pages will return a block of 512 stubs (4096 / 8 = 512). Asking for 513 + /// will return a block of 1024 (2-pages worth). 
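  // -------------------------------------------------------------------------
  // Editorial usage sketch (not part of the original header): how a client
  // (for example LocalIndirectStubsManager in IndirectionUtils.h) typically
  // drives an ABI class's stub API. ORCABI stands for any of the concrete
  // classes in this file (OrcAArch64, OrcX86_64_SysV, OrcI386, ...).
  //
  //   template <typename ORCABI>
  //   Error initStubsExample(typename ORCABI::IndirectStubsInfo &Stubs,
  //                          void *InitialTarget) {
  //     // Request at least one stub; the ABI rounds the block up to a page.
  //     return ORCABI::emitIndirectStubsBlock(Stubs, 1, InitialTarget);
  //   }
  // -------------------------------------------------------------------------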
+ static Error emitIndirectStubsBlock(IndirectStubsInfo &StubsInfo, + unsigned MinStubs, void *InitialPtrVal); +}; + +/// X86_64 support for SysV ABI (Linux, MacOSX). +/// +/// X86_64_SysV supports lazy JITing. +class OrcX86_64_SysV : public OrcX86_64_Base { +public: + static const unsigned ResolverCodeSize = 0x6C; + + using JITReentryFn = JITTargetAddress (*)(void *CallbackMgr, + void *TrampolineId); + + /// Write the resolver code into the given memory. The user is be + /// responsible for allocating the memory and setting permissions. + static void writeResolverCode(uint8_t *ResolveMem, JITReentryFn Reentry, + void *CallbackMgr); +}; + +/// X86_64 support for Win32. +/// +/// X86_64_Win32 supports lazy JITing. +class OrcX86_64_Win32 : public OrcX86_64_Base { +public: + static const unsigned ResolverCodeSize = 0x74; + + using JITReentryFn = JITTargetAddress (*)(void *CallbackMgr, + void *TrampolineId); + + /// Write the resolver code into the given memory. The user is be + /// responsible for allocating the memory and setting permissions. + static void writeResolverCode(uint8_t *ResolveMem, JITReentryFn Reentry, + void *CallbackMgr); +}; + +/// I386 support. +/// +/// I386 supports lazy JITing. +class OrcI386 { +public: + static const unsigned PointerSize = 4; + static const unsigned TrampolineSize = 8; + static const unsigned ResolverCodeSize = 0x4a; + + using IndirectStubsInfo = GenericIndirectStubsInfo<8>; + + using JITReentryFn = JITTargetAddress (*)(void *CallbackMgr, + void *TrampolineId); + + /// Write the resolver code into the given memory. The user is be + /// responsible for allocating the memory and setting permissions. + static void writeResolverCode(uint8_t *ResolveMem, JITReentryFn Reentry, + void *CallbackMgr); + + /// Write the requsted number of trampolines into the given memory, + /// which must be big enough to hold 1 pointer, plus NumTrampolines + /// trampolines. + static void writeTrampolines(uint8_t *TrampolineMem, void *ResolverAddr, + unsigned NumTrampolines); + + /// Emit at least MinStubs worth of indirect call stubs, rounded out to + /// the nearest page size. + /// + /// E.g. Asking for 4 stubs on i386, where stubs are 8-bytes, with 4k + /// pages will return a block of 512 stubs (4096 / 8 = 512). Asking for 513 + /// will return a block of 1024 (2-pages worth). + static Error emitIndirectStubsBlock(IndirectStubsInfo &StubsInfo, + unsigned MinStubs, void *InitialPtrVal); +}; + +// @brief Mips32 support. +// +// Mips32 supports lazy JITing. +class OrcMips32_Base { +public: + static const unsigned PointerSize = 4; + static const unsigned TrampolineSize = 20; + static const unsigned ResolverCodeSize = 0xfc; + using IndirectStubsInfo = GenericIndirectStubsInfo<16>; + + using JITReentryFn = JITTargetAddress (*)(void *CallbackMgr, + void *TrampolineId); + /// @brief Write the requsted number of trampolines into the given memory, + /// which must be big enough to hold 1 pointer, plus NumTrampolines + /// trampolines. + static void writeTrampolines(uint8_t *TrampolineMem, void *ResolverAddr,unsigned NumTrampolines); + + /// @brief Write the resolver code into the given memory. The user is be + /// responsible for allocating the memory and setting permissions. + static void writeResolverCode(uint8_t *ResolveMem, JITReentryFn Reentry,void *CallbackMgr, bool isBigEndian); + /// @brief Emit at least MinStubs worth of indirect call stubs, rounded out to + /// the nearest page size. + /// + /// E.g. 
Asking for 4 stubs on Mips32, where stubs are 8-bytes, with 4k + /// pages will return a block of 512 stubs (4096 / 8 = 512). Asking for 513 + /// will return a block of 1024 (2-pages worth). + static Error emitIndirectStubsBlock(IndirectStubsInfo &StubsInfo,unsigned MinStubs, void *InitialPtrVal); +}; + + +class OrcMips32Le : public OrcMips32_Base { +public: + static void writeResolverCode(uint8_t *ResolveMem, JITReentryFn Reentry,void *CallbackMgr) + { OrcMips32_Base::writeResolverCode(ResolveMem, Reentry, CallbackMgr, false); } +}; + +class OrcMips32Be : public OrcMips32_Base { +public: + static void writeResolverCode(uint8_t *ResolveMem, JITReentryFn Reentry,void *CallbackMgr) + { OrcMips32_Base::writeResolverCode(ResolveMem, Reentry, CallbackMgr, true); } +}; + +// @brief Mips64 support. +// +// Mips64 supports lazy JITing. +class OrcMips64 { +public: + static const unsigned PointerSize = 8; + static const unsigned TrampolineSize = 40; + static const unsigned ResolverCodeSize = 0x120; + + using IndirectStubsInfo = GenericIndirectStubsInfo<32>; + using JITReentryFn = JITTargetAddress (*)(void *CallbackMgr, + void *TrampolineId); + /// @brief Write the resolver code into the given memory. The user is be + /// responsible for allocating the memory and setting permissions. + static void writeResolverCode(uint8_t *ResolveMem, JITReentryFn Reentry,void *CallbackMgr); + + /// @brief Write the requsted number of trampolines into the given memory, + /// which must be big enough to hold 1 pointer, plus NumTrampolines + /// trampolines. + static void writeTrampolines(uint8_t *TrampolineMem, void *ResolverAddr,unsigned NumTrampolines); + + /// @brief Emit at least MinStubs worth of indirect call stubs, rounded out to + /// the nearest page size. + /// + /// E.g. Asking for 4 stubs on Mips64, where stubs are 8-bytes, with 4k + /// pages will return a block of 512 stubs (4096 / 8 = 512). Asking for 513 + /// will return a block of 1024 (2-pages worth). + static Error emitIndirectStubsBlock(IndirectStubsInfo &StubsInfo,unsigned MinStubs, void *InitialPtrVal); +}; + + } // end namespace orc + } // end namespace llvm +#endif // LLVM_EXECUTIONENGINE_ORC_ORCABISUPPORT_H diff --git a/llvm/include/llvm/ExecutionEngine/Orc/OrcError.h b/llvm/include/llvm/ExecutionEngine/Orc/OrcError.h new file mode 100644 index 000000000000..e5d6a3eca85f --- /dev/null +++ b/llvm/include/llvm/ExecutionEngine/Orc/OrcError.h @@ -0,0 +1,70 @@ +//===------ OrcError.h - Reject symbol lookup requests ------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Define an error category, error codes, and helper utilities for Orc. 
+// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_EXECUTIONENGINE_ORC_ORCERROR_H +#define LLVM_EXECUTIONENGINE_ORC_ORCERROR_H + +#include "llvm/Support/Error.h" +#include <system_error> + +namespace llvm { +namespace orc { + +enum class OrcErrorCode : int { + // RPC Errors + UnknownORCError = 1, + DuplicateDefinition, + JITSymbolNotFound, + RemoteAllocatorDoesNotExist, + RemoteAllocatorIdAlreadyInUse, + RemoteMProtectAddrUnrecognized, + RemoteIndirectStubsOwnerDoesNotExist, + RemoteIndirectStubsOwnerIdAlreadyInUse, + RPCConnectionClosed, + RPCCouldNotNegotiateFunction, + RPCResponseAbandoned, + UnexpectedRPCCall, + UnexpectedRPCResponse, + UnknownErrorCodeFromRemote, + UnknownResourceHandle +}; + +std::error_code orcError(OrcErrorCode ErrCode); + +class DuplicateDefinition : public ErrorInfo<DuplicateDefinition> { +public: + static char ID; + + DuplicateDefinition(std::string SymbolName); + std::error_code convertToErrorCode() const override; + void log(raw_ostream &OS) const override; + const std::string &getSymbolName() const; +private: + std::string SymbolName; +}; + +class JITSymbolNotFound : public ErrorInfo<JITSymbolNotFound> { +public: + static char ID; + + JITSymbolNotFound(std::string SymbolName); + std::error_code convertToErrorCode() const override; + void log(raw_ostream &OS) const override; + const std::string &getSymbolName() const; +private: + std::string SymbolName; +}; + +} // End namespace orc. +} // End namespace llvm. + +#endif // LLVM_EXECUTIONENGINE_ORC_ORCERROR_H diff --git a/llvm/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetClient.h b/llvm/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetClient.h new file mode 100644 index 000000000000..86e8d5df3ad9 --- /dev/null +++ b/llvm/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetClient.h @@ -0,0 +1,702 @@ +//===- OrcRemoteTargetClient.h - Orc Remote-target Client -------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the OrcRemoteTargetClient class and helpers. This class +// can be used to communicate over an RawByteChannel with an +// OrcRemoteTargetServer instance to support remote-JITing. 
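The error types declared just above follow LLVM's usual ErrorInfo pattern, so they can be created with make_error and matched by type with handleAllErrors. A minimal sketch of that usage, assuming a build against LLVM; findSymbolAddress is a hypothetical helper introduced only for illustration:

#include "llvm/ADT/StringRef.h"
#include "llvm/ExecutionEngine/JITSymbol.h"
#include "llvm/ExecutionEngine/Orc/OrcError.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/raw_ostream.h"

using namespace llvm;
using namespace llvm::orc;

// Hypothetical lookup that fails with the JITSymbolNotFound type above.
static Expected<JITTargetAddress> findSymbolAddress(StringRef Name) {
  return make_error<JITSymbolNotFound>(Name.str());
}

int main() {
  if (auto Addr = findSymbolAddress("missing_symbol")) {
    outs() << "found at " << *Addr << "\n";
  } else {
    // Recover the typed error and report which symbol was missing.
    handleAllErrors(Addr.takeError(), [](const JITSymbolNotFound &E) {
      errs() << "lookup failed for: " << E.getSymbolName() << "\n";
    });
  }
  return 0;
}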
+// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_EXECUTIONENGINE_ORC_ORCREMOTETARGETCLIENT_H +#define LLVM_EXECUTIONENGINE_ORC_ORCREMOTETARGETCLIENT_H + +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringMap.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ExecutionEngine/JITSymbol.h" +#include "llvm/ExecutionEngine/Orc/IndirectionUtils.h" +#include "llvm/ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.h" +#include "llvm/ExecutionEngine/RuntimeDyld.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/Error.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/Format.h" +#include "llvm/Support/MathExtras.h" +#include "llvm/Support/Memory.h" +#include "llvm/Support/raw_ostream.h" +#include <algorithm> +#include <cassert> +#include <cstdint> +#include <memory> +#include <string> +#include <tuple> +#include <utility> +#include <vector> + +#define DEBUG_TYPE "orc-remote" + +namespace llvm { +namespace orc { +namespace remote { + +/// This class provides utilities (including memory manager, indirect stubs +/// manager, and compile callback manager types) that support remote JITing +/// in ORC. +/// +/// Each of the utility classes talks to a JIT server (an instance of the +/// OrcRemoteTargetServer class) via an RPC system (see RPCUtils.h) to carry out +/// its actions. +class OrcRemoteTargetClient + : public rpc::SingleThreadedRPCEndpoint<rpc::RawByteChannel> { +public: + /// Remote-mapped RuntimeDyld-compatible memory manager. + class RemoteRTDyldMemoryManager : public RuntimeDyld::MemoryManager { + friend class OrcRemoteTargetClient; + + public: + ~RemoteRTDyldMemoryManager() { + Client.destroyRemoteAllocator(Id); + LLVM_DEBUG(dbgs() << "Destroyed remote allocator " << Id << "\n"); + } + + RemoteRTDyldMemoryManager(const RemoteRTDyldMemoryManager &) = delete; + RemoteRTDyldMemoryManager & + operator=(const RemoteRTDyldMemoryManager &) = delete; + RemoteRTDyldMemoryManager(RemoteRTDyldMemoryManager &&) = default; + RemoteRTDyldMemoryManager &operator=(RemoteRTDyldMemoryManager &&) = delete; + + uint8_t *allocateCodeSection(uintptr_t Size, unsigned Alignment, + unsigned SectionID, + StringRef SectionName) override { + Unmapped.back().CodeAllocs.emplace_back(Size, Alignment); + uint8_t *Alloc = reinterpret_cast<uint8_t *>( + Unmapped.back().CodeAllocs.back().getLocalAddress()); + LLVM_DEBUG(dbgs() << "Allocator " << Id << " allocated code for " + << SectionName << ": " << Alloc << " (" << Size + << " bytes, alignment " << Alignment << ")\n"); + return Alloc; + } + + uint8_t *allocateDataSection(uintptr_t Size, unsigned Alignment, + unsigned SectionID, StringRef SectionName, + bool IsReadOnly) override { + if (IsReadOnly) { + Unmapped.back().RODataAllocs.emplace_back(Size, Alignment); + uint8_t *Alloc = reinterpret_cast<uint8_t *>( + Unmapped.back().RODataAllocs.back().getLocalAddress()); + LLVM_DEBUG(dbgs() << "Allocator " << Id << " allocated ro-data for " + << SectionName << ": " << Alloc << " (" << Size + << " bytes, alignment " << Alignment << ")\n"); + return Alloc; + } // else... 
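+      // Writable data sections fall through to the read-write allocation path below.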
+ + Unmapped.back().RWDataAllocs.emplace_back(Size, Alignment); + uint8_t *Alloc = reinterpret_cast<uint8_t *>( + Unmapped.back().RWDataAllocs.back().getLocalAddress()); + LLVM_DEBUG(dbgs() << "Allocator " << Id << " allocated rw-data for " + << SectionName << ": " << Alloc << " (" << Size + << " bytes, alignment " << Alignment << ")\n"); + return Alloc; + } + + void reserveAllocationSpace(uintptr_t CodeSize, uint32_t CodeAlign, + uintptr_t RODataSize, uint32_t RODataAlign, + uintptr_t RWDataSize, + uint32_t RWDataAlign) override { + Unmapped.push_back(ObjectAllocs()); + + LLVM_DEBUG(dbgs() << "Allocator " << Id << " reserved:\n"); + + if (CodeSize != 0) { + Unmapped.back().RemoteCodeAddr = + Client.reserveMem(Id, CodeSize, CodeAlign); + + LLVM_DEBUG( + dbgs() << " code: " + << format("0x%016" PRIx64, Unmapped.back().RemoteCodeAddr) + << " (" << CodeSize << " bytes, alignment " << CodeAlign + << ")\n"); + } + + if (RODataSize != 0) { + Unmapped.back().RemoteRODataAddr = + Client.reserveMem(Id, RODataSize, RODataAlign); + + LLVM_DEBUG( + dbgs() << " ro-data: " + << format("0x%016" PRIx64, Unmapped.back().RemoteRODataAddr) + << " (" << RODataSize << " bytes, alignment " << RODataAlign + << ")\n"); + } + + if (RWDataSize != 0) { + Unmapped.back().RemoteRWDataAddr = + Client.reserveMem(Id, RWDataSize, RWDataAlign); + + LLVM_DEBUG( + dbgs() << " rw-data: " + << format("0x%016" PRIx64, Unmapped.back().RemoteRWDataAddr) + << " (" << RWDataSize << " bytes, alignment " << RWDataAlign + << ")\n"); + } + } + + bool needsToReserveAllocationSpace() override { return true; } + + void registerEHFrames(uint8_t *Addr, uint64_t LoadAddr, + size_t Size) override { + UnfinalizedEHFrames.push_back({LoadAddr, Size}); + } + + void deregisterEHFrames() override { + for (auto &Frame : RegisteredEHFrames) { + // FIXME: Add error poll. + Client.deregisterEHFrames(Frame.Addr, Frame.Size); + } + } + + void notifyObjectLoaded(RuntimeDyld &Dyld, + const object::ObjectFile &Obj) override { + LLVM_DEBUG(dbgs() << "Allocator " << Id << " applied mappings:\n"); + for (auto &ObjAllocs : Unmapped) { + mapAllocsToRemoteAddrs(Dyld, ObjAllocs.CodeAllocs, + ObjAllocs.RemoteCodeAddr); + mapAllocsToRemoteAddrs(Dyld, ObjAllocs.RODataAllocs, + ObjAllocs.RemoteRODataAddr); + mapAllocsToRemoteAddrs(Dyld, ObjAllocs.RWDataAllocs, + ObjAllocs.RemoteRWDataAddr); + Unfinalized.push_back(std::move(ObjAllocs)); + } + Unmapped.clear(); + } + + bool finalizeMemory(std::string *ErrMsg = nullptr) override { + LLVM_DEBUG(dbgs() << "Allocator " << Id << " finalizing:\n"); + + for (auto &ObjAllocs : Unfinalized) { + if (copyAndProtect(ObjAllocs.CodeAllocs, ObjAllocs.RemoteCodeAddr, + sys::Memory::MF_READ | sys::Memory::MF_EXEC)) + return true; + + if (copyAndProtect(ObjAllocs.RODataAllocs, ObjAllocs.RemoteRODataAddr, + sys::Memory::MF_READ)) + return true; + + if (copyAndProtect(ObjAllocs.RWDataAllocs, ObjAllocs.RemoteRWDataAddr, + sys::Memory::MF_READ | sys::Memory::MF_WRITE)) + return true; + } + Unfinalized.clear(); + + for (auto &EHFrame : UnfinalizedEHFrames) { + if (auto Err = Client.registerEHFrames(EHFrame.Addr, EHFrame.Size)) { + // FIXME: Replace this once finalizeMemory can return an Error. 
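+          // Render each EH-frame registration failure into the caller-supplied ErrMsg string (if one was provided) via ErrorInfoBase::log.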
+ handleAllErrors(std::move(Err), [&](ErrorInfoBase &EIB) { + if (ErrMsg) { + raw_string_ostream ErrOut(*ErrMsg); + EIB.log(ErrOut); + } + }); + return false; + } + } + RegisteredEHFrames = std::move(UnfinalizedEHFrames); + UnfinalizedEHFrames = {}; + + return false; + } + + private: + class Alloc { + public: + Alloc(uint64_t Size, unsigned Align) + : Size(Size), Align(Align), Contents(new char[Size + Align - 1]) {} + + Alloc(const Alloc &) = delete; + Alloc &operator=(const Alloc &) = delete; + Alloc(Alloc &&) = default; + Alloc &operator=(Alloc &&) = default; + + uint64_t getSize() const { return Size; } + + unsigned getAlign() const { return Align; } + + char *getLocalAddress() const { + uintptr_t LocalAddr = reinterpret_cast<uintptr_t>(Contents.get()); + LocalAddr = alignTo(LocalAddr, Align); + return reinterpret_cast<char *>(LocalAddr); + } + + void setRemoteAddress(JITTargetAddress RemoteAddr) { + this->RemoteAddr = RemoteAddr; + } + + JITTargetAddress getRemoteAddress() const { return RemoteAddr; } + + private: + uint64_t Size; + unsigned Align; + std::unique_ptr<char[]> Contents; + JITTargetAddress RemoteAddr = 0; + }; + + struct ObjectAllocs { + ObjectAllocs() = default; + ObjectAllocs(const ObjectAllocs &) = delete; + ObjectAllocs &operator=(const ObjectAllocs &) = delete; + ObjectAllocs(ObjectAllocs &&) = default; + ObjectAllocs &operator=(ObjectAllocs &&) = default; + + JITTargetAddress RemoteCodeAddr = 0; + JITTargetAddress RemoteRODataAddr = 0; + JITTargetAddress RemoteRWDataAddr = 0; + std::vector<Alloc> CodeAllocs, RODataAllocs, RWDataAllocs; + }; + + RemoteRTDyldMemoryManager(OrcRemoteTargetClient &Client, + ResourceIdMgr::ResourceId Id) + : Client(Client), Id(Id) { + LLVM_DEBUG(dbgs() << "Created remote allocator " << Id << "\n"); + } + + // Maps all allocations in Allocs to aligned blocks + void mapAllocsToRemoteAddrs(RuntimeDyld &Dyld, std::vector<Alloc> &Allocs, + JITTargetAddress NextAddr) { + for (auto &Alloc : Allocs) { + NextAddr = alignTo(NextAddr, Alloc.getAlign()); + Dyld.mapSectionAddress(Alloc.getLocalAddress(), NextAddr); + LLVM_DEBUG( + dbgs() << " " << static_cast<void *>(Alloc.getLocalAddress()) + << " -> " << format("0x%016" PRIx64, NextAddr) << "\n"); + Alloc.setRemoteAddress(NextAddr); + + // Only advance NextAddr if it was non-null to begin with, + // otherwise leave it as null. + if (NextAddr) + NextAddr += Alloc.getSize(); + } + } + + // Copies data for each alloc in the list, then set permissions on the + // segment. + bool copyAndProtect(const std::vector<Alloc> &Allocs, + JITTargetAddress RemoteSegmentAddr, + unsigned Permissions) { + if (RemoteSegmentAddr) { + assert(!Allocs.empty() && "No sections in allocated segment"); + + for (auto &Alloc : Allocs) { + LLVM_DEBUG(dbgs() << " copying section: " + << static_cast<void *>(Alloc.getLocalAddress()) + << " -> " + << format("0x%016" PRIx64, Alloc.getRemoteAddress()) + << " (" << Alloc.getSize() << " bytes)\n";); + + if (Client.writeMem(Alloc.getRemoteAddress(), Alloc.getLocalAddress(), + Alloc.getSize())) + return true; + } + + LLVM_DEBUG(dbgs() << " setting " + << (Permissions & sys::Memory::MF_READ ? 'R' : '-') + << (Permissions & sys::Memory::MF_WRITE ? 'W' : '-') + << (Permissions & sys::Memory::MF_EXEC ? 
'X' : '-') + << " permissions on block: " + << format("0x%016" PRIx64, RemoteSegmentAddr) + << "\n"); + if (Client.setProtections(Id, RemoteSegmentAddr, Permissions)) + return true; + } + return false; + } + + OrcRemoteTargetClient &Client; + ResourceIdMgr::ResourceId Id; + std::vector<ObjectAllocs> Unmapped; + std::vector<ObjectAllocs> Unfinalized; + + struct EHFrame { + JITTargetAddress Addr; + uint64_t Size; + }; + std::vector<EHFrame> UnfinalizedEHFrames; + std::vector<EHFrame> RegisteredEHFrames; + }; + + /// Remote indirect stubs manager. + class RemoteIndirectStubsManager : public IndirectStubsManager { + public: + RemoteIndirectStubsManager(OrcRemoteTargetClient &Client, + ResourceIdMgr::ResourceId Id) + : Client(Client), Id(Id) {} + + ~RemoteIndirectStubsManager() override { + Client.destroyIndirectStubsManager(Id); + } + + Error createStub(StringRef StubName, JITTargetAddress StubAddr, + JITSymbolFlags StubFlags) override { + if (auto Err = reserveStubs(1)) + return Err; + + return createStubInternal(StubName, StubAddr, StubFlags); + } + + Error createStubs(const StubInitsMap &StubInits) override { + if (auto Err = reserveStubs(StubInits.size())) + return Err; + + for (auto &Entry : StubInits) + if (auto Err = createStubInternal(Entry.first(), Entry.second.first, + Entry.second.second)) + return Err; + + return Error::success(); + } + + JITEvaluatedSymbol findStub(StringRef Name, bool ExportedStubsOnly) override { + auto I = StubIndexes.find(Name); + if (I == StubIndexes.end()) + return nullptr; + auto Key = I->second.first; + auto Flags = I->second.second; + auto StubSymbol = JITEvaluatedSymbol(getStubAddr(Key), Flags); + if (ExportedStubsOnly && !StubSymbol.getFlags().isExported()) + return nullptr; + return StubSymbol; + } + + JITEvaluatedSymbol findPointer(StringRef Name) override { + auto I = StubIndexes.find(Name); + if (I == StubIndexes.end()) + return nullptr; + auto Key = I->second.first; + auto Flags = I->second.second; + return JITEvaluatedSymbol(getPtrAddr(Key), Flags); + } + + Error updatePointer(StringRef Name, JITTargetAddress NewAddr) override { + auto I = StubIndexes.find(Name); + assert(I != StubIndexes.end() && "No stub pointer for symbol"); + auto Key = I->second.first; + return Client.writePointer(getPtrAddr(Key), NewAddr); + } + + private: + struct RemoteIndirectStubsInfo { + JITTargetAddress StubBase; + JITTargetAddress PtrBase; + unsigned NumStubs; + }; + + using StubKey = std::pair<uint16_t, uint16_t>; + + Error reserveStubs(unsigned NumStubs) { + if (NumStubs <= FreeStubs.size()) + return Error::success(); + + unsigned NewStubsRequired = NumStubs - FreeStubs.size(); + JITTargetAddress StubBase; + JITTargetAddress PtrBase; + unsigned NumStubsEmitted; + + if (auto StubInfoOrErr = Client.emitIndirectStubs(Id, NewStubsRequired)) + std::tie(StubBase, PtrBase, NumStubsEmitted) = *StubInfoOrErr; + else + return StubInfoOrErr.takeError(); + + unsigned NewBlockId = RemoteIndirectStubsInfos.size(); + RemoteIndirectStubsInfos.push_back({StubBase, PtrBase, NumStubsEmitted}); + + for (unsigned I = 0; I < NumStubsEmitted; ++I) + FreeStubs.push_back(std::make_pair(NewBlockId, I)); + + return Error::success(); + } + + Error createStubInternal(StringRef StubName, JITTargetAddress InitAddr, + JITSymbolFlags StubFlags) { + auto Key = FreeStubs.back(); + FreeStubs.pop_back(); + StubIndexes[StubName] = std::make_pair(Key, StubFlags); + return Client.writePointer(getPtrAddr(Key), InitAddr); + } + + JITTargetAddress getStubAddr(StubKey K) { + 
assert(RemoteIndirectStubsInfos[K.first].StubBase != 0 && + "Missing stub address"); + return RemoteIndirectStubsInfos[K.first].StubBase + + K.second * Client.getIndirectStubSize(); + } + + JITTargetAddress getPtrAddr(StubKey K) { + assert(RemoteIndirectStubsInfos[K.first].PtrBase != 0 && + "Missing pointer address"); + return RemoteIndirectStubsInfos[K.first].PtrBase + + K.second * Client.getPointerSize(); + } + + OrcRemoteTargetClient &Client; + ResourceIdMgr::ResourceId Id; + std::vector<RemoteIndirectStubsInfo> RemoteIndirectStubsInfos; + std::vector<StubKey> FreeStubs; + StringMap<std::pair<StubKey, JITSymbolFlags>> StubIndexes; + }; + + class RemoteTrampolinePool : public TrampolinePool { + public: + RemoteTrampolinePool(OrcRemoteTargetClient &Client) : Client(Client) {} + + Expected<JITTargetAddress> getTrampoline() override { + std::lock_guard<std::mutex> Lock(RTPMutex); + if (AvailableTrampolines.empty()) { + if (auto Err = grow()) + return std::move(Err); + } + assert(!AvailableTrampolines.empty() && "Failed to grow trampoline pool"); + auto TrampolineAddr = AvailableTrampolines.back(); + AvailableTrampolines.pop_back(); + return TrampolineAddr; + } + + private: + Error grow() { + JITTargetAddress BlockAddr = 0; + uint32_t NumTrampolines = 0; + if (auto TrampolineInfoOrErr = Client.emitTrampolineBlock()) + std::tie(BlockAddr, NumTrampolines) = *TrampolineInfoOrErr; + else + return TrampolineInfoOrErr.takeError(); + + uint32_t TrampolineSize = Client.getTrampolineSize(); + for (unsigned I = 0; I < NumTrampolines; ++I) + this->AvailableTrampolines.push_back(BlockAddr + (I * TrampolineSize)); + + return Error::success(); + } + + std::mutex RTPMutex; + OrcRemoteTargetClient &Client; + std::vector<JITTargetAddress> AvailableTrampolines; + }; + + /// Remote compile callback manager. + class RemoteCompileCallbackManager : public JITCompileCallbackManager { + public: + RemoteCompileCallbackManager(OrcRemoteTargetClient &Client, + ExecutionSession &ES, + JITTargetAddress ErrorHandlerAddress) + : JITCompileCallbackManager( + std::make_unique<RemoteTrampolinePool>(Client), ES, + ErrorHandlerAddress) {} + }; + + /// Create an OrcRemoteTargetClient. + /// Channel is the ChannelT instance to communicate on. It is assumed that + /// the channel is ready to be read from and written to. + static Expected<std::unique_ptr<OrcRemoteTargetClient>> + Create(rpc::RawByteChannel &Channel, ExecutionSession &ES) { + Error Err = Error::success(); + auto Client = std::unique_ptr<OrcRemoteTargetClient>( + new OrcRemoteTargetClient(Channel, ES, Err)); + if (Err) + return std::move(Err); + return std::move(Client); + } + + /// Call the int(void) function at the given address in the target and return + /// its result. + Expected<int> callIntVoid(JITTargetAddress Addr) { + LLVM_DEBUG(dbgs() << "Calling int(*)(void) " + << format("0x%016" PRIx64, Addr) << "\n"); + return callB<exec::CallIntVoid>(Addr); + } + + /// Call the int(int, char*[]) function at the given address in the target and + /// return its result. + Expected<int> callMain(JITTargetAddress Addr, + const std::vector<std::string> &Args) { + LLVM_DEBUG(dbgs() << "Calling int(*)(int, char*[]) " + << format("0x%016" PRIx64, Addr) << "\n"); + return callB<exec::CallMain>(Addr, Args); + } + + /// Call the void() function at the given address in the target and wait for + /// it to finish. 
+ Error callVoidVoid(JITTargetAddress Addr) { + LLVM_DEBUG(dbgs() << "Calling void(*)(void) " + << format("0x%016" PRIx64, Addr) << "\n"); + return callB<exec::CallVoidVoid>(Addr); + } + + /// Create a RemoteRTDyldMemoryManager which will allocate its memory on the + /// remote target. + Expected<std::unique_ptr<RemoteRTDyldMemoryManager>> + createRemoteMemoryManager() { + auto Id = AllocatorIds.getNext(); + if (auto Err = callB<mem::CreateRemoteAllocator>(Id)) + return std::move(Err); + return std::unique_ptr<RemoteRTDyldMemoryManager>( + new RemoteRTDyldMemoryManager(*this, Id)); + } + + /// Create a RemoteIndirectStubsManager that will allocate stubs on the + /// remote target. + Expected<std::unique_ptr<RemoteIndirectStubsManager>> + createIndirectStubsManager() { + auto Id = IndirectStubOwnerIds.getNext(); + if (auto Err = callB<stubs::CreateIndirectStubsOwner>(Id)) + return std::move(Err); + return std::make_unique<RemoteIndirectStubsManager>(*this, Id); + } + + Expected<RemoteCompileCallbackManager &> + enableCompileCallbacks(JITTargetAddress ErrorHandlerAddress) { + assert(!CallbackManager && "CallbackManager already obtained"); + + // Emit the resolver block on the JIT server. + if (auto Err = callB<stubs::EmitResolverBlock>()) + return std::move(Err); + + // Create the callback manager. + CallbackManager.emplace(*this, ES, ErrorHandlerAddress); + RemoteCompileCallbackManager &Mgr = *CallbackManager; + return Mgr; + } + + /// Search for symbols in the remote process. Note: This should be used by + /// symbol resolvers *after* they've searched the local symbol table in the + /// JIT stack. + Expected<JITTargetAddress> getSymbolAddress(StringRef Name) { + return callB<utils::GetSymbolAddress>(Name); + } + + /// Get the triple for the remote target. + const std::string &getTargetTriple() const { return RemoteTargetTriple; } + + Error terminateSession() { return callB<utils::TerminateSession>(); } + +private: + OrcRemoteTargetClient(rpc::RawByteChannel &Channel, ExecutionSession &ES, + Error &Err) + : rpc::SingleThreadedRPCEndpoint<rpc::RawByteChannel>(Channel, true), + ES(ES) { + ErrorAsOutParameter EAO(&Err); + + addHandler<utils::RequestCompile>( + [this](JITTargetAddress Addr) -> JITTargetAddress { + if (CallbackManager) + return CallbackManager->executeCompileCallback(Addr); + return 0; + }); + + if (auto RIOrErr = callB<utils::GetRemoteInfo>()) { + std::tie(RemoteTargetTriple, RemotePointerSize, RemotePageSize, + RemoteTrampolineSize, RemoteIndirectStubSize) = *RIOrErr; + Err = Error::success(); + } else + Err = RIOrErr.takeError(); + } + + void deregisterEHFrames(JITTargetAddress Addr, uint32_t Size) { + if (auto Err = callB<eh::DeregisterEHFrames>(Addr, Size)) + ES.reportError(std::move(Err)); + } + + void destroyRemoteAllocator(ResourceIdMgr::ResourceId Id) { + if (auto Err = callB<mem::DestroyRemoteAllocator>(Id)) { + // FIXME: This will be triggered by a removeModuleSet call: Propagate + // error return up through that. 
+ llvm_unreachable("Failed to destroy remote allocator."); + AllocatorIds.release(Id); + } + } + + void destroyIndirectStubsManager(ResourceIdMgr::ResourceId Id) { + IndirectStubOwnerIds.release(Id); + if (auto Err = callB<stubs::DestroyIndirectStubsOwner>(Id)) + ES.reportError(std::move(Err)); + } + + Expected<std::tuple<JITTargetAddress, JITTargetAddress, uint32_t>> + emitIndirectStubs(ResourceIdMgr::ResourceId Id, uint32_t NumStubsRequired) { + return callB<stubs::EmitIndirectStubs>(Id, NumStubsRequired); + } + + Expected<std::tuple<JITTargetAddress, uint32_t>> emitTrampolineBlock() { + return callB<stubs::EmitTrampolineBlock>(); + } + + uint32_t getIndirectStubSize() const { return RemoteIndirectStubSize; } + uint32_t getPageSize() const { return RemotePageSize; } + uint32_t getPointerSize() const { return RemotePointerSize; } + + uint32_t getTrampolineSize() const { return RemoteTrampolineSize; } + + Expected<std::vector<uint8_t>> readMem(char *Dst, JITTargetAddress Src, + uint64_t Size) { + return callB<mem::ReadMem>(Src, Size); + } + + Error registerEHFrames(JITTargetAddress &RAddr, uint32_t Size) { + // FIXME: Duplicate error and report it via ReportError too? + return callB<eh::RegisterEHFrames>(RAddr, Size); + } + + JITTargetAddress reserveMem(ResourceIdMgr::ResourceId Id, uint64_t Size, + uint32_t Align) { + if (auto AddrOrErr = callB<mem::ReserveMem>(Id, Size, Align)) + return *AddrOrErr; + else { + ES.reportError(AddrOrErr.takeError()); + return 0; + } + } + + bool setProtections(ResourceIdMgr::ResourceId Id, + JITTargetAddress RemoteSegAddr, unsigned ProtFlags) { + if (auto Err = callB<mem::SetProtections>(Id, RemoteSegAddr, ProtFlags)) { + ES.reportError(std::move(Err)); + return true; + } else + return false; + } + + bool writeMem(JITTargetAddress Addr, const char *Src, uint64_t Size) { + if (auto Err = callB<mem::WriteMem>(DirectBufferWriter(Src, Addr, Size))) { + ES.reportError(std::move(Err)); + return true; + } else + return false; + } + + Error writePointer(JITTargetAddress Addr, JITTargetAddress PtrVal) { + return callB<mem::WritePtr>(Addr, PtrVal); + } + + static Error doNothing() { return Error::success(); } + + ExecutionSession &ES; + std::function<void(Error)> ReportError; + std::string RemoteTargetTriple; + uint32_t RemotePointerSize = 0; + uint32_t RemotePageSize = 0; + uint32_t RemoteTrampolineSize = 0; + uint32_t RemoteIndirectStubSize = 0; + ResourceIdMgr AllocatorIds, IndirectStubOwnerIds; + Optional<RemoteCompileCallbackManager> CallbackManager; +}; + +} // end namespace remote +} // end namespace orc +} // end namespace llvm + +#undef DEBUG_TYPE + +#endif // LLVM_EXECUTIONENGINE_ORC_ORCREMOTETARGETCLIENT_H diff --git a/llvm/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.h b/llvm/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.h new file mode 100644 index 000000000000..e7b598d8f812 --- /dev/null +++ b/llvm/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.h @@ -0,0 +1,375 @@ +//===- OrcRemoteTargetRPCAPI.h - Orc Remote-target RPC API ------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the Orc remote-target RPC API. It should not be used +// directly, but is used by the RemoteTargetClient and RemoteTargetServer +// classes. 
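Taken together, the client API above is typically driven in a short sequence: create the client over a connected channel, look up a symbol on the remote, call it, then terminate the session. A hedged sketch of that flow, assuming an rpc::RawByteChannel and ExecutionSession set up elsewhere by the hosting tool (this is not a complete program):

#include "llvm/ExecutionEngine/Orc/Core.h"
#include "llvm/ExecutionEngine/Orc/OrcRemoteTargetClient.h"
#include "llvm/Support/Error.h"
#include <string>
#include <vector>

using namespace llvm;
using namespace llvm::orc;

Expected<int> runRemoteMain(rpc::RawByteChannel &Channel, ExecutionSession &ES,
                            const std::vector<std::string> &Args) {
  // Handshake with the server and fetch the remote target info.
  auto Client = remote::OrcRemoteTargetClient::Create(Channel, ES);
  if (!Client)
    return Client.takeError();

  // Ask the remote process for the address of "main"...
  auto MainAddr = (*Client)->getSymbolAddress("main");
  if (!MainAddr)
    return MainAddr.takeError();

  // ...call it there with the given arguments...
  auto Result = (*Client)->callMain(*MainAddr, Args);
  if (!Result)
    return Result.takeError();

  // ...and ask the server to wind the session down.
  if (auto Err = (*Client)->terminateSession())
    return std::move(Err);

  return *Result;
}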
+// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_EXECUTIONENGINE_ORC_ORCREMOTETARGETRPCAPI_H +#define LLVM_EXECUTIONENGINE_ORC_ORCREMOTETARGETRPCAPI_H + +#include "llvm/ExecutionEngine/JITSymbol.h" +#include "llvm/ExecutionEngine/Orc/RPCUtils.h" +#include "llvm/ExecutionEngine/Orc/RawByteChannel.h" + +namespace llvm { +namespace orc { + +namespace remote { + +/// Template error for missing resources. +template <typename ResourceIdT> +class ResourceNotFound + : public ErrorInfo<ResourceNotFound<ResourceIdT>> { +public: + static char ID; + + ResourceNotFound(ResourceIdT ResourceId, + std::string ResourceDescription = "") + : ResourceId(std::move(ResourceId)), + ResourceDescription(std::move(ResourceDescription)) {} + + std::error_code convertToErrorCode() const override { + return orcError(OrcErrorCode::UnknownResourceHandle); + } + + void log(raw_ostream &OS) const override { + OS << (ResourceDescription.empty() + ? "Remote resource with id " + : ResourceDescription) + << " " << ResourceId << " not found"; + } + +private: + ResourceIdT ResourceId; + std::string ResourceDescription; +}; + +template <typename ResourceIdT> +char ResourceNotFound<ResourceIdT>::ID = 0; + +class DirectBufferWriter { +public: + DirectBufferWriter() = default; + DirectBufferWriter(const char *Src, JITTargetAddress Dst, uint64_t Size) + : Src(Src), Dst(Dst), Size(Size) {} + + const char *getSrc() const { return Src; } + JITTargetAddress getDst() const { return Dst; } + uint64_t getSize() const { return Size; } + +private: + const char *Src; + JITTargetAddress Dst; + uint64_t Size; +}; + +} // end namespace remote + +namespace rpc { + +template <> +class RPCTypeName<JITSymbolFlags> { +public: + static const char *getName() { return "JITSymbolFlags"; } +}; + +template <typename ChannelT> +class SerializationTraits<ChannelT, JITSymbolFlags> { +public: + + static Error serialize(ChannelT &C, const JITSymbolFlags &Flags) { + return serializeSeq(C, Flags.getRawFlagsValue(), Flags.getTargetFlags()); + } + + static Error deserialize(ChannelT &C, JITSymbolFlags &Flags) { + JITSymbolFlags::UnderlyingType JITFlags; + JITSymbolFlags::TargetFlagsType TargetFlags; + if (auto Err = deserializeSeq(C, JITFlags, TargetFlags)) + return Err; + Flags = JITSymbolFlags(static_cast<JITSymbolFlags::FlagNames>(JITFlags), + TargetFlags); + return Error::success(); + } +}; + +template <> class RPCTypeName<remote::DirectBufferWriter> { +public: + static const char *getName() { return "DirectBufferWriter"; } +}; + +template <typename ChannelT> +class SerializationTraits< + ChannelT, remote::DirectBufferWriter, remote::DirectBufferWriter, + typename std::enable_if< + std::is_base_of<RawByteChannel, ChannelT>::value>::type> { +public: + static Error serialize(ChannelT &C, const remote::DirectBufferWriter &DBW) { + if (auto EC = serializeSeq(C, DBW.getDst())) + return EC; + if (auto EC = serializeSeq(C, DBW.getSize())) + return EC; + return C.appendBytes(DBW.getSrc(), DBW.getSize()); + } + + static Error deserialize(ChannelT &C, remote::DirectBufferWriter &DBW) { + JITTargetAddress Dst; + if (auto EC = deserializeSeq(C, Dst)) + return EC; + uint64_t Size; + if (auto EC = deserializeSeq(C, Size)) + return EC; + char *Addr = reinterpret_cast<char *>(static_cast<uintptr_t>(Dst)); + + DBW = remote::DirectBufferWriter(nullptr, Dst, Size); + + return C.readBytes(Addr, Size); + } +}; + +} // end namespace rpc + +namespace remote { + +class ResourceIdMgr { +public: + using ResourceId = uint64_t; + static 
const ResourceId InvalidId = ~0U; + + ResourceIdMgr() = default; + explicit ResourceIdMgr(ResourceId FirstValidId) + : NextId(std::move(FirstValidId)) {} + + ResourceId getNext() { + if (!FreeIds.empty()) { + ResourceId I = FreeIds.back(); + FreeIds.pop_back(); + return I; + } + assert(NextId + 1 != ~0ULL && "All ids allocated"); + return NextId++; + } + + void release(ResourceId I) { FreeIds.push_back(I); } + +private: + ResourceId NextId = 1; + std::vector<ResourceId> FreeIds; +}; + +/// Registers EH frames on the remote. +namespace eh { + + /// Registers EH frames on the remote. + class RegisterEHFrames + : public rpc::Function<RegisterEHFrames, + void(JITTargetAddress Addr, uint32_t Size)> { + public: + static const char *getName() { return "RegisterEHFrames"; } + }; + + /// Deregisters EH frames on the remote. + class DeregisterEHFrames + : public rpc::Function<DeregisterEHFrames, + void(JITTargetAddress Addr, uint32_t Size)> { + public: + static const char *getName() { return "DeregisterEHFrames"; } + }; + +} // end namespace eh + +/// RPC functions for executing remote code. +namespace exec { + + /// Call an 'int32_t()'-type function on the remote, returns the called + /// function's return value. + class CallIntVoid + : public rpc::Function<CallIntVoid, int32_t(JITTargetAddress Addr)> { + public: + static const char *getName() { return "CallIntVoid"; } + }; + + /// Call an 'int32_t(int32_t, char**)'-type function on the remote, returns the + /// called function's return value. + class CallMain + : public rpc::Function<CallMain, int32_t(JITTargetAddress Addr, + std::vector<std::string> Args)> { + public: + static const char *getName() { return "CallMain"; } + }; + + /// Calls a 'void()'-type function on the remote, returns when the called + /// function completes. + class CallVoidVoid + : public rpc::Function<CallVoidVoid, void(JITTargetAddress FnAddr)> { + public: + static const char *getName() { return "CallVoidVoid"; } + }; + +} // end namespace exec + +/// RPC functions for remote memory management / inspection / modification. +namespace mem { + + /// Creates a memory allocator on the remote. + class CreateRemoteAllocator + : public rpc::Function<CreateRemoteAllocator, + void(ResourceIdMgr::ResourceId AllocatorID)> { + public: + static const char *getName() { return "CreateRemoteAllocator"; } + }; + + /// Destroys a remote allocator, freeing any memory allocated by it. + class DestroyRemoteAllocator + : public rpc::Function<DestroyRemoteAllocator, + void(ResourceIdMgr::ResourceId AllocatorID)> { + public: + static const char *getName() { return "DestroyRemoteAllocator"; } + }; + + /// Read a remote memory block. + class ReadMem + : public rpc::Function<ReadMem, std::vector<uint8_t>(JITTargetAddress Src, + uint64_t Size)> { + public: + static const char *getName() { return "ReadMem"; } + }; + + /// Reserve a block of memory on the remote via the given allocator. + class ReserveMem + : public rpc::Function<ReserveMem, + JITTargetAddress(ResourceIdMgr::ResourceId AllocID, + uint64_t Size, uint32_t Align)> { + public: + static const char *getName() { return "ReserveMem"; } + }; + + /// Set the memory protection on a memory block. + class SetProtections + : public rpc::Function<SetProtections, + void(ResourceIdMgr::ResourceId AllocID, + JITTargetAddress Dst, uint32_t ProtFlags)> { + public: + static const char *getName() { return "SetProtections"; } + }; + + /// Write to a remote memory block. 
+ class WriteMem + : public rpc::Function<WriteMem, void(remote::DirectBufferWriter DB)> { + public: + static const char *getName() { return "WriteMem"; } + }; + + /// Write to a remote pointer. + class WritePtr : public rpc::Function<WritePtr, void(JITTargetAddress Dst, + JITTargetAddress Val)> { + public: + static const char *getName() { return "WritePtr"; } + }; + +} // end namespace mem + +/// RPC functions for remote stub and trampoline management. +namespace stubs { + + /// Creates an indirect stub owner on the remote. + class CreateIndirectStubsOwner + : public rpc::Function<CreateIndirectStubsOwner, + void(ResourceIdMgr::ResourceId StubOwnerID)> { + public: + static const char *getName() { return "CreateIndirectStubsOwner"; } + }; + + /// RPC function for destroying an indirect stubs owner. + class DestroyIndirectStubsOwner + : public rpc::Function<DestroyIndirectStubsOwner, + void(ResourceIdMgr::ResourceId StubsOwnerID)> { + public: + static const char *getName() { return "DestroyIndirectStubsOwner"; } + }; + + /// EmitIndirectStubs result is (StubsBase, PtrsBase, NumStubsEmitted). + class EmitIndirectStubs + : public rpc::Function< + EmitIndirectStubs, + std::tuple<JITTargetAddress, JITTargetAddress, uint32_t>( + ResourceIdMgr::ResourceId StubsOwnerID, + uint32_t NumStubsRequired)> { + public: + static const char *getName() { return "EmitIndirectStubs"; } + }; + + /// RPC function to emit the resolver block and return its address. + class EmitResolverBlock : public rpc::Function<EmitResolverBlock, void()> { + public: + static const char *getName() { return "EmitResolverBlock"; } + }; + + /// EmitTrampolineBlock result is (BlockAddr, NumTrampolines). + class EmitTrampolineBlock + : public rpc::Function<EmitTrampolineBlock, + std::tuple<JITTargetAddress, uint32_t>()> { + public: + static const char *getName() { return "EmitTrampolineBlock"; } + }; + +} // end namespace stubs + +/// Miscelaneous RPC functions for dealing with remotes. +namespace utils { + + /// GetRemoteInfo result is (Triple, PointerSize, PageSize, TrampolineSize, + /// IndirectStubsSize). + class GetRemoteInfo + : public rpc::Function< + GetRemoteInfo, + std::tuple<std::string, uint32_t, uint32_t, uint32_t, uint32_t>()> { + public: + static const char *getName() { return "GetRemoteInfo"; } + }; + + /// Get the address of a remote symbol. + class GetSymbolAddress + : public rpc::Function<GetSymbolAddress, + JITTargetAddress(std::string SymbolName)> { + public: + static const char *getName() { return "GetSymbolAddress"; } + }; + + /// Request that the host execute a compile callback. + class RequestCompile + : public rpc::Function< + RequestCompile, JITTargetAddress(JITTargetAddress TrampolineAddr)> { + public: + static const char *getName() { return "RequestCompile"; } + }; + + /// Notify the remote and terminate the session. + class TerminateSession : public rpc::Function<TerminateSession, void()> { + public: + static const char *getName() { return "TerminateSession"; } + }; + +} // namespace utils + +class OrcRemoteTargetRPCAPI + : public rpc::SingleThreadedRPCEndpoint<rpc::RawByteChannel> { +public: + // FIXME: Remove constructors once MSVC supports synthesizing move-ops. 
+ OrcRemoteTargetRPCAPI(rpc::RawByteChannel &C) + : rpc::SingleThreadedRPCEndpoint<rpc::RawByteChannel>(C, true) {} +}; + +} // end namespace remote + +} // end namespace orc +} // end namespace llvm + +#endif // LLVM_EXECUTIONENGINE_ORC_ORCREMOTETARGETRPCAPI_H diff --git a/llvm/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetServer.h b/llvm/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetServer.h new file mode 100644 index 000000000000..4c8e2ea1a7be --- /dev/null +++ b/llvm/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetServer.h @@ -0,0 +1,449 @@ +//===- OrcRemoteTargetServer.h - Orc Remote-target Server -------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the OrcRemoteTargetServer class. It can be used to build a +// JIT server that can execute code sent from an OrcRemoteTargetClient. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_EXECUTIONENGINE_ORC_ORCREMOTETARGETSERVER_H +#define LLVM_EXECUTIONENGINE_ORC_ORCREMOTETARGETSERVER_H + +#include "llvm/ExecutionEngine/JITSymbol.h" +#include "llvm/ExecutionEngine/Orc/OrcError.h" +#include "llvm/ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/Error.h" +#include "llvm/Support/Format.h" +#include "llvm/Support/Host.h" +#include "llvm/Support/Memory.h" +#include "llvm/Support/Process.h" +#include "llvm/Support/raw_ostream.h" +#include <algorithm> +#include <cassert> +#include <cstddef> +#include <cstdint> +#include <functional> +#include <map> +#include <memory> +#include <string> +#include <system_error> +#include <tuple> +#include <type_traits> +#include <vector> + +#define DEBUG_TYPE "orc-remote" + +namespace llvm { +namespace orc { +namespace remote { + +template <typename ChannelT, typename TargetT> +class OrcRemoteTargetServer + : public rpc::SingleThreadedRPCEndpoint<rpc::RawByteChannel> { +public: + using SymbolLookupFtor = + std::function<JITTargetAddress(const std::string &Name)>; + + using EHFrameRegistrationFtor = + std::function<void(uint8_t *Addr, uint32_t Size)>; + + OrcRemoteTargetServer(ChannelT &Channel, SymbolLookupFtor SymbolLookup, + EHFrameRegistrationFtor EHFramesRegister, + EHFrameRegistrationFtor EHFramesDeregister) + : rpc::SingleThreadedRPCEndpoint<rpc::RawByteChannel>(Channel, true), + SymbolLookup(std::move(SymbolLookup)), + EHFramesRegister(std::move(EHFramesRegister)), + EHFramesDeregister(std::move(EHFramesDeregister)) { + using ThisT = typename std::remove_reference<decltype(*this)>::type; + addHandler<exec::CallIntVoid>(*this, &ThisT::handleCallIntVoid); + addHandler<exec::CallMain>(*this, &ThisT::handleCallMain); + addHandler<exec::CallVoidVoid>(*this, &ThisT::handleCallVoidVoid); + addHandler<mem::CreateRemoteAllocator>(*this, + &ThisT::handleCreateRemoteAllocator); + addHandler<mem::DestroyRemoteAllocator>( + *this, &ThisT::handleDestroyRemoteAllocator); + addHandler<mem::ReadMem>(*this, &ThisT::handleReadMem); + addHandler<mem::ReserveMem>(*this, &ThisT::handleReserveMem); + addHandler<mem::SetProtections>(*this, &ThisT::handleSetProtections); + addHandler<mem::WriteMem>(*this, &ThisT::handleWriteMem); + addHandler<mem::WritePtr>(*this, &ThisT::handleWritePtr); + addHandler<eh::RegisterEHFrames>(*this, 
&ThisT::handleRegisterEHFrames); + addHandler<eh::DeregisterEHFrames>(*this, &ThisT::handleDeregisterEHFrames); + addHandler<stubs::CreateIndirectStubsOwner>( + *this, &ThisT::handleCreateIndirectStubsOwner); + addHandler<stubs::DestroyIndirectStubsOwner>( + *this, &ThisT::handleDestroyIndirectStubsOwner); + addHandler<stubs::EmitIndirectStubs>(*this, + &ThisT::handleEmitIndirectStubs); + addHandler<stubs::EmitResolverBlock>(*this, + &ThisT::handleEmitResolverBlock); + addHandler<stubs::EmitTrampolineBlock>(*this, + &ThisT::handleEmitTrampolineBlock); + addHandler<utils::GetSymbolAddress>(*this, &ThisT::handleGetSymbolAddress); + addHandler<utils::GetRemoteInfo>(*this, &ThisT::handleGetRemoteInfo); + addHandler<utils::TerminateSession>(*this, &ThisT::handleTerminateSession); + } + + // FIXME: Remove move/copy ops once MSVC supports synthesizing move ops. + OrcRemoteTargetServer(const OrcRemoteTargetServer &) = delete; + OrcRemoteTargetServer &operator=(const OrcRemoteTargetServer &) = delete; + + OrcRemoteTargetServer(OrcRemoteTargetServer &&Other) = default; + OrcRemoteTargetServer &operator=(OrcRemoteTargetServer &&) = delete; + + Expected<JITTargetAddress> requestCompile(JITTargetAddress TrampolineAddr) { + return callB<utils::RequestCompile>(TrampolineAddr); + } + + bool receivedTerminate() const { return TerminateFlag; } + +private: + struct Allocator { + Allocator() = default; + Allocator(Allocator &&Other) : Allocs(std::move(Other.Allocs)) {} + + Allocator &operator=(Allocator &&Other) { + Allocs = std::move(Other.Allocs); + return *this; + } + + ~Allocator() { + for (auto &Alloc : Allocs) + sys::Memory::releaseMappedMemory(Alloc.second); + } + + Error allocate(void *&Addr, size_t Size, uint32_t Align) { + std::error_code EC; + sys::MemoryBlock MB = sys::Memory::allocateMappedMemory( + Size, nullptr, sys::Memory::MF_READ | sys::Memory::MF_WRITE, EC); + if (EC) + return errorCodeToError(EC); + + Addr = MB.base(); + assert(Allocs.find(MB.base()) == Allocs.end() && "Duplicate alloc"); + Allocs[MB.base()] = std::move(MB); + return Error::success(); + } + + Error setProtections(void *block, unsigned Flags) { + auto I = Allocs.find(block); + if (I == Allocs.end()) + return errorCodeToError(orcError(OrcErrorCode::RemoteMProtectAddrUnrecognized)); + return errorCodeToError( + sys::Memory::protectMappedMemory(I->second, Flags)); + } + + private: + std::map<void *, sys::MemoryBlock> Allocs; + }; + + static Error doNothing() { return Error::success(); } + + static JITTargetAddress reenter(void *JITTargetAddr, void *TrampolineAddr) { + auto T = static_cast<OrcRemoteTargetServer *>(JITTargetAddr); + auto AddrOrErr = T->requestCompile(static_cast<JITTargetAddress>( + reinterpret_cast<uintptr_t>(TrampolineAddr))); + // FIXME: Allow customizable failure substitution functions. 
+ assert(AddrOrErr && "Compile request failed"); + return *AddrOrErr; + } + + Expected<int32_t> handleCallIntVoid(JITTargetAddress Addr) { + using IntVoidFnTy = int (*)(); + + IntVoidFnTy Fn = + reinterpret_cast<IntVoidFnTy>(static_cast<uintptr_t>(Addr)); + + LLVM_DEBUG(dbgs() << " Calling " << format("0x%016x", Addr) << "\n"); + int Result = Fn(); + LLVM_DEBUG(dbgs() << " Result = " << Result << "\n"); + + return Result; + } + + Expected<int32_t> handleCallMain(JITTargetAddress Addr, + std::vector<std::string> Args) { + using MainFnTy = int (*)(int, const char *[]); + + MainFnTy Fn = reinterpret_cast<MainFnTy>(static_cast<uintptr_t>(Addr)); + int ArgC = Args.size() + 1; + int Idx = 1; + std::unique_ptr<const char *[]> ArgV(new const char *[ArgC + 1]); + ArgV[0] = "<jit process>"; + for (auto &Arg : Args) + ArgV[Idx++] = Arg.c_str(); + ArgV[ArgC] = 0; + LLVM_DEBUG(for (int Idx = 0; Idx < ArgC; ++Idx) { + llvm::dbgs() << "Arg " << Idx << ": " << ArgV[Idx] << "\n"; + }); + + LLVM_DEBUG(dbgs() << " Calling " << format("0x%016x", Addr) << "\n"); + int Result = Fn(ArgC, ArgV.get()); + LLVM_DEBUG(dbgs() << " Result = " << Result << "\n"); + + return Result; + } + + Error handleCallVoidVoid(JITTargetAddress Addr) { + using VoidVoidFnTy = void (*)(); + + VoidVoidFnTy Fn = + reinterpret_cast<VoidVoidFnTy>(static_cast<uintptr_t>(Addr)); + + LLVM_DEBUG(dbgs() << " Calling " << format("0x%016x", Addr) << "\n"); + Fn(); + LLVM_DEBUG(dbgs() << " Complete.\n"); + + return Error::success(); + } + + Error handleCreateRemoteAllocator(ResourceIdMgr::ResourceId Id) { + auto I = Allocators.find(Id); + if (I != Allocators.end()) + return errorCodeToError( + orcError(OrcErrorCode::RemoteAllocatorIdAlreadyInUse)); + LLVM_DEBUG(dbgs() << " Created allocator " << Id << "\n"); + Allocators[Id] = Allocator(); + return Error::success(); + } + + Error handleCreateIndirectStubsOwner(ResourceIdMgr::ResourceId Id) { + auto I = IndirectStubsOwners.find(Id); + if (I != IndirectStubsOwners.end()) + return errorCodeToError( + orcError(OrcErrorCode::RemoteIndirectStubsOwnerIdAlreadyInUse)); + LLVM_DEBUG(dbgs() << " Create indirect stubs owner " << Id << "\n"); + IndirectStubsOwners[Id] = ISBlockOwnerList(); + return Error::success(); + } + + Error handleDeregisterEHFrames(JITTargetAddress TAddr, uint32_t Size) { + uint8_t *Addr = reinterpret_cast<uint8_t *>(static_cast<uintptr_t>(TAddr)); + LLVM_DEBUG(dbgs() << " Registering EH frames at " + << format("0x%016x", TAddr) << ", Size = " << Size + << " bytes\n"); + EHFramesDeregister(Addr, Size); + return Error::success(); + } + + Error handleDestroyRemoteAllocator(ResourceIdMgr::ResourceId Id) { + auto I = Allocators.find(Id); + if (I == Allocators.end()) + return errorCodeToError( + orcError(OrcErrorCode::RemoteAllocatorDoesNotExist)); + Allocators.erase(I); + LLVM_DEBUG(dbgs() << " Destroyed allocator " << Id << "\n"); + return Error::success(); + } + + Error handleDestroyIndirectStubsOwner(ResourceIdMgr::ResourceId Id) { + auto I = IndirectStubsOwners.find(Id); + if (I == IndirectStubsOwners.end()) + return errorCodeToError( + orcError(OrcErrorCode::RemoteIndirectStubsOwnerDoesNotExist)); + IndirectStubsOwners.erase(I); + return Error::success(); + } + + Expected<std::tuple<JITTargetAddress, JITTargetAddress, uint32_t>> + handleEmitIndirectStubs(ResourceIdMgr::ResourceId Id, + uint32_t NumStubsRequired) { + LLVM_DEBUG(dbgs() << " ISMgr " << Id << " request " << NumStubsRequired + << " stubs.\n"); + + auto StubOwnerItr = IndirectStubsOwners.find(Id); + if (StubOwnerItr == 
IndirectStubsOwners.end()) + return errorCodeToError( + orcError(OrcErrorCode::RemoteIndirectStubsOwnerDoesNotExist)); + + typename TargetT::IndirectStubsInfo IS; + if (auto Err = + TargetT::emitIndirectStubsBlock(IS, NumStubsRequired, nullptr)) + return std::move(Err); + + JITTargetAddress StubsBase = static_cast<JITTargetAddress>( + reinterpret_cast<uintptr_t>(IS.getStub(0))); + JITTargetAddress PtrsBase = static_cast<JITTargetAddress>( + reinterpret_cast<uintptr_t>(IS.getPtr(0))); + uint32_t NumStubsEmitted = IS.getNumStubs(); + + auto &BlockList = StubOwnerItr->second; + BlockList.push_back(std::move(IS)); + + return std::make_tuple(StubsBase, PtrsBase, NumStubsEmitted); + } + + Error handleEmitResolverBlock() { + std::error_code EC; + ResolverBlock = sys::OwningMemoryBlock(sys::Memory::allocateMappedMemory( + TargetT::ResolverCodeSize, nullptr, + sys::Memory::MF_READ | sys::Memory::MF_WRITE, EC)); + if (EC) + return errorCodeToError(EC); + + TargetT::writeResolverCode(static_cast<uint8_t *>(ResolverBlock.base()), + &reenter, this); + + return errorCodeToError(sys::Memory::protectMappedMemory( + ResolverBlock.getMemoryBlock(), + sys::Memory::MF_READ | sys::Memory::MF_EXEC)); + } + + Expected<std::tuple<JITTargetAddress, uint32_t>> handleEmitTrampolineBlock() { + std::error_code EC; + auto TrampolineBlock = + sys::OwningMemoryBlock(sys::Memory::allocateMappedMemory( + sys::Process::getPageSizeEstimate(), nullptr, + sys::Memory::MF_READ | sys::Memory::MF_WRITE, EC)); + if (EC) + return errorCodeToError(EC); + + uint32_t NumTrampolines = + (sys::Process::getPageSizeEstimate() - TargetT::PointerSize) / + TargetT::TrampolineSize; + + uint8_t *TrampolineMem = static_cast<uint8_t *>(TrampolineBlock.base()); + TargetT::writeTrampolines(TrampolineMem, ResolverBlock.base(), + NumTrampolines); + + EC = sys::Memory::protectMappedMemory(TrampolineBlock.getMemoryBlock(), + sys::Memory::MF_READ | + sys::Memory::MF_EXEC); + + TrampolineBlocks.push_back(std::move(TrampolineBlock)); + + auto TrampolineBaseAddr = static_cast<JITTargetAddress>( + reinterpret_cast<uintptr_t>(TrampolineMem)); + + return std::make_tuple(TrampolineBaseAddr, NumTrampolines); + } + + Expected<JITTargetAddress> handleGetSymbolAddress(const std::string &Name) { + JITTargetAddress Addr = SymbolLookup(Name); + LLVM_DEBUG(dbgs() << " Symbol '" << Name + << "' = " << format("0x%016x", Addr) << "\n"); + return Addr; + } + + Expected<std::tuple<std::string, uint32_t, uint32_t, uint32_t, uint32_t>> + handleGetRemoteInfo() { + std::string ProcessTriple = sys::getProcessTriple(); + uint32_t PointerSize = TargetT::PointerSize; + uint32_t PageSize = sys::Process::getPageSizeEstimate(); + uint32_t TrampolineSize = TargetT::TrampolineSize; + uint32_t IndirectStubSize = TargetT::IndirectStubsInfo::StubSize; + LLVM_DEBUG(dbgs() << " Remote info:\n" + << " triple = '" << ProcessTriple << "'\n" + << " pointer size = " << PointerSize << "\n" + << " page size = " << PageSize << "\n" + << " trampoline size = " << TrampolineSize << "\n" + << " indirect stub size = " << IndirectStubSize + << "\n"); + return std::make_tuple(ProcessTriple, PointerSize, PageSize, TrampolineSize, + IndirectStubSize); + } + + Expected<std::vector<uint8_t>> handleReadMem(JITTargetAddress RSrc, + uint64_t Size) { + uint8_t *Src = reinterpret_cast<uint8_t *>(static_cast<uintptr_t>(RSrc)); + + LLVM_DEBUG(dbgs() << " Reading " << Size << " bytes from " + << format("0x%016x", RSrc) << "\n"); + + std::vector<uint8_t> Buffer; + Buffer.resize(Size); + for (uint8_t *P = Src; Size != 
0; --Size) + Buffer.push_back(*P++); + + return Buffer; + } + + Error handleRegisterEHFrames(JITTargetAddress TAddr, uint32_t Size) { + uint8_t *Addr = reinterpret_cast<uint8_t *>(static_cast<uintptr_t>(TAddr)); + LLVM_DEBUG(dbgs() << " Registering EH frames at " + << format("0x%016x", TAddr) << ", Size = " << Size + << " bytes\n"); + EHFramesRegister(Addr, Size); + return Error::success(); + } + + Expected<JITTargetAddress> handleReserveMem(ResourceIdMgr::ResourceId Id, + uint64_t Size, uint32_t Align) { + auto I = Allocators.find(Id); + if (I == Allocators.end()) + return errorCodeToError( + orcError(OrcErrorCode::RemoteAllocatorDoesNotExist)); + auto &Allocator = I->second; + void *LocalAllocAddr = nullptr; + if (auto Err = Allocator.allocate(LocalAllocAddr, Size, Align)) + return std::move(Err); + + LLVM_DEBUG(dbgs() << " Allocator " << Id << " reserved " << LocalAllocAddr + << " (" << Size << " bytes, alignment " << Align + << ")\n"); + + JITTargetAddress AllocAddr = static_cast<JITTargetAddress>( + reinterpret_cast<uintptr_t>(LocalAllocAddr)); + + return AllocAddr; + } + + Error handleSetProtections(ResourceIdMgr::ResourceId Id, + JITTargetAddress Addr, uint32_t Flags) { + auto I = Allocators.find(Id); + if (I == Allocators.end()) + return errorCodeToError( + orcError(OrcErrorCode::RemoteAllocatorDoesNotExist)); + auto &Allocator = I->second; + void *LocalAddr = reinterpret_cast<void *>(static_cast<uintptr_t>(Addr)); + LLVM_DEBUG(dbgs() << " Allocator " << Id << " set permissions on " + << LocalAddr << " to " + << (Flags & sys::Memory::MF_READ ? 'R' : '-') + << (Flags & sys::Memory::MF_WRITE ? 'W' : '-') + << (Flags & sys::Memory::MF_EXEC ? 'X' : '-') << "\n"); + return Allocator.setProtections(LocalAddr, Flags); + } + + Error handleTerminateSession() { + TerminateFlag = true; + return Error::success(); + } + + Error handleWriteMem(DirectBufferWriter DBW) { + LLVM_DEBUG(dbgs() << " Writing " << DBW.getSize() << " bytes to " + << format("0x%016x", DBW.getDst()) << "\n"); + return Error::success(); + } + + Error handleWritePtr(JITTargetAddress Addr, JITTargetAddress PtrVal) { + LLVM_DEBUG(dbgs() << " Writing pointer *" << format("0x%016x", Addr) + << " = " << format("0x%016x", PtrVal) << "\n"); + uintptr_t *Ptr = + reinterpret_cast<uintptr_t *>(static_cast<uintptr_t>(Addr)); + *Ptr = static_cast<uintptr_t>(PtrVal); + return Error::success(); + } + + SymbolLookupFtor SymbolLookup; + EHFrameRegistrationFtor EHFramesRegister, EHFramesDeregister; + std::map<ResourceIdMgr::ResourceId, Allocator> Allocators; + using ISBlockOwnerList = std::vector<typename TargetT::IndirectStubsInfo>; + std::map<ResourceIdMgr::ResourceId, ISBlockOwnerList> IndirectStubsOwners; + sys::OwningMemoryBlock ResolverBlock; + std::vector<sys::OwningMemoryBlock> TrampolineBlocks; + bool TerminateFlag = false; +}; + +} // end namespace remote +} // end namespace orc +} // end namespace llvm + +#undef DEBUG_TYPE + +#endif // LLVM_EXECUTIONENGINE_ORC_ORCREMOTETARGETSERVER_H diff --git a/llvm/include/llvm/ExecutionEngine/Orc/RPCSerialization.h b/llvm/include/llvm/ExecutionEngine/Orc/RPCSerialization.h new file mode 100644 index 000000000000..752a0a34e0a1 --- /dev/null +++ b/llvm/include/llvm/ExecutionEngine/Orc/RPCSerialization.h @@ -0,0 +1,703 @@ +//===- llvm/ExecutionEngine/Orc/RPCSerialization.h --------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_EXECUTIONENGINE_ORC_RPCSERIALIZATION_H +#define LLVM_EXECUTIONENGINE_ORC_RPCSERIALIZATION_H + +#include "OrcError.h" +#include "llvm/Support/thread.h" +#include <map> +#include <mutex> +#include <set> +#include <sstream> +#include <string> +#include <vector> + +namespace llvm { +namespace orc { +namespace rpc { + +template <typename T> +class RPCTypeName; + +/// TypeNameSequence is a utility for rendering sequences of types to a string +/// by rendering each type, separated by ", ". +template <typename... ArgTs> class RPCTypeNameSequence {}; + +/// Render an empty TypeNameSequence to an ostream. +template <typename OStream> +OStream &operator<<(OStream &OS, const RPCTypeNameSequence<> &V) { + return OS; +} + +/// Render a TypeNameSequence of a single type to an ostream. +template <typename OStream, typename ArgT> +OStream &operator<<(OStream &OS, const RPCTypeNameSequence<ArgT> &V) { + OS << RPCTypeName<ArgT>::getName(); + return OS; +} + +/// Render a TypeNameSequence of more than one type to an ostream. +template <typename OStream, typename ArgT1, typename ArgT2, typename... ArgTs> +OStream& +operator<<(OStream &OS, const RPCTypeNameSequence<ArgT1, ArgT2, ArgTs...> &V) { + OS << RPCTypeName<ArgT1>::getName() << ", " + << RPCTypeNameSequence<ArgT2, ArgTs...>(); + return OS; +} + +template <> +class RPCTypeName<void> { +public: + static const char* getName() { return "void"; } +}; + +template <> +class RPCTypeName<int8_t> { +public: + static const char* getName() { return "int8_t"; } +}; + +template <> +class RPCTypeName<uint8_t> { +public: + static const char* getName() { return "uint8_t"; } +}; + +template <> +class RPCTypeName<int16_t> { +public: + static const char* getName() { return "int16_t"; } +}; + +template <> +class RPCTypeName<uint16_t> { +public: + static const char* getName() { return "uint16_t"; } +}; + +template <> +class RPCTypeName<int32_t> { +public: + static const char* getName() { return "int32_t"; } +}; + +template <> +class RPCTypeName<uint32_t> { +public: + static const char* getName() { return "uint32_t"; } +}; + +template <> +class RPCTypeName<int64_t> { +public: + static const char* getName() { return "int64_t"; } +}; + +template <> +class RPCTypeName<uint64_t> { +public: + static const char* getName() { return "uint64_t"; } +}; + +template <> +class RPCTypeName<bool> { +public: + static const char* getName() { return "bool"; } +}; + +template <> +class RPCTypeName<std::string> { +public: + static const char* getName() { return "std::string"; } +}; + +template <> +class RPCTypeName<Error> { +public: + static const char* getName() { return "Error"; } +}; + +template <typename T> +class RPCTypeName<Expected<T>> { +public: + static const char* getName() { + static std::string Name = [] { + std::string Name; + raw_string_ostream(Name) << "Expected<" + << RPCTypeNameSequence<T>() + << ">"; + return Name; + }(); + return Name.data(); + } +}; + +template <typename T1, typename T2> +class RPCTypeName<std::pair<T1, T2>> { +public: + static const char* getName() { + static std::string Name = [] { + std::string Name; + raw_string_ostream(Name) << "std::pair<" << RPCTypeNameSequence<T1, T2>() + << ">"; + return Name; + }(); + return Name.data(); + } +}; + +template <typename... 
ArgTs> +class RPCTypeName<std::tuple<ArgTs...>> { +public: + static const char* getName() { + static std::string Name = [] { + std::string Name; + raw_string_ostream(Name) << "std::tuple<" + << RPCTypeNameSequence<ArgTs...>() << ">"; + return Name; + }(); + return Name.data(); + } +}; + +template <typename T> +class RPCTypeName<std::vector<T>> { +public: + static const char*getName() { + static std::string Name = [] { + std::string Name; + raw_string_ostream(Name) << "std::vector<" << RPCTypeName<T>::getName() + << ">"; + return Name; + }(); + return Name.data(); + } +}; + +template <typename T> class RPCTypeName<std::set<T>> { +public: + static const char *getName() { + static std::string Name = [] { + std::string Name; + raw_string_ostream(Name) + << "std::set<" << RPCTypeName<T>::getName() << ">"; + return Name; + }(); + return Name.data(); + } +}; + +template <typename K, typename V> class RPCTypeName<std::map<K, V>> { +public: + static const char *getName() { + static std::string Name = [] { + std::string Name; + raw_string_ostream(Name) + << "std::map<" << RPCTypeNameSequence<K, V>() << ">"; + return Name; + }(); + return Name.data(); + } +}; + +/// The SerializationTraits<ChannelT, T> class describes how to serialize and +/// deserialize an instance of type T to/from an abstract channel of type +/// ChannelT. It also provides a representation of the type's name via the +/// getName method. +/// +/// Specializations of this class should provide the following functions: +/// +/// @code{.cpp} +/// +/// static const char* getName(); +/// static Error serialize(ChannelT&, const T&); +/// static Error deserialize(ChannelT&, T&); +/// +/// @endcode +/// +/// The third argument of SerializationTraits is intended to support SFINAE. +/// E.g.: +/// +/// @code{.cpp} +/// +/// class MyVirtualChannel { ... }; +/// +/// template <DerivedChannelT> +/// class SerializationTraits<DerivedChannelT, bool, +/// typename std::enable_if< +/// std::is_base_of<VirtChannel, DerivedChannel>::value +/// >::type> { +/// public: +/// static const char* getName() { ... }; +/// } +/// +/// @endcode +template <typename ChannelT, typename WireType, + typename ConcreteType = WireType, typename = void> +class SerializationTraits; + +template <typename ChannelT> +class SequenceTraits { +public: + static Error emitSeparator(ChannelT &C) { return Error::success(); } + static Error consumeSeparator(ChannelT &C) { return Error::success(); } +}; + +/// Utility class for serializing sequences of values of varying types. +/// Specializations of this class contain 'serialize' and 'deserialize' methods +/// for the given channel. The ArgTs... list will determine the "over-the-wire" +/// types to be serialized. The serialize and deserialize methods take a list +/// CArgTs... ("caller arg types") which must be the same length as ArgTs..., +/// but may be different types from ArgTs, provided that for each CArgT there +/// is a SerializationTraits specialization +/// SerializeTraits<ChannelT, ArgT, CArgT> with methods that can serialize the +/// caller argument to over-the-wire value. +template <typename ChannelT, typename... 
ArgTs> +class SequenceSerialization; + +template <typename ChannelT> +class SequenceSerialization<ChannelT> { +public: + static Error serialize(ChannelT &C) { return Error::success(); } + static Error deserialize(ChannelT &C) { return Error::success(); } +}; + +template <typename ChannelT, typename ArgT> +class SequenceSerialization<ChannelT, ArgT> { +public: + + template <typename CArgT> + static Error serialize(ChannelT &C, CArgT &&CArg) { + return SerializationTraits<ChannelT, ArgT, + typename std::decay<CArgT>::type>:: + serialize(C, std::forward<CArgT>(CArg)); + } + + template <typename CArgT> + static Error deserialize(ChannelT &C, CArgT &CArg) { + return SerializationTraits<ChannelT, ArgT, CArgT>::deserialize(C, CArg); + } +}; + +template <typename ChannelT, typename ArgT, typename... ArgTs> +class SequenceSerialization<ChannelT, ArgT, ArgTs...> { +public: + + template <typename CArgT, typename... CArgTs> + static Error serialize(ChannelT &C, CArgT &&CArg, + CArgTs &&... CArgs) { + if (auto Err = + SerializationTraits<ChannelT, ArgT, typename std::decay<CArgT>::type>:: + serialize(C, std::forward<CArgT>(CArg))) + return Err; + if (auto Err = SequenceTraits<ChannelT>::emitSeparator(C)) + return Err; + return SequenceSerialization<ChannelT, ArgTs...>:: + serialize(C, std::forward<CArgTs>(CArgs)...); + } + + template <typename CArgT, typename... CArgTs> + static Error deserialize(ChannelT &C, CArgT &CArg, + CArgTs &... CArgs) { + if (auto Err = + SerializationTraits<ChannelT, ArgT, CArgT>::deserialize(C, CArg)) + return Err; + if (auto Err = SequenceTraits<ChannelT>::consumeSeparator(C)) + return Err; + return SequenceSerialization<ChannelT, ArgTs...>::deserialize(C, CArgs...); + } +}; + +template <typename ChannelT, typename... ArgTs> +Error serializeSeq(ChannelT &C, ArgTs &&... Args) { + return SequenceSerialization<ChannelT, typename std::decay<ArgTs>::type...>:: + serialize(C, std::forward<ArgTs>(Args)...); +} + +template <typename ChannelT, typename... ArgTs> +Error deserializeSeq(ChannelT &C, ArgTs &... Args) { + return SequenceSerialization<ChannelT, ArgTs...>::deserialize(C, Args...); +} + +template <typename ChannelT> +class SerializationTraits<ChannelT, Error> { +public: + + using WrappedErrorSerializer = + std::function<Error(ChannelT &C, const ErrorInfoBase&)>; + + using WrappedErrorDeserializer = + std::function<Error(ChannelT &C, Error &Err)>; + + template <typename ErrorInfoT, typename SerializeFtor, + typename DeserializeFtor> + static void registerErrorType(std::string Name, SerializeFtor Serialize, + DeserializeFtor Deserialize) { + assert(!Name.empty() && + "The empty string is reserved for the Success value"); + + const std::string *KeyName = nullptr; + { + // We're abusing the stability of std::map here: We take a reference to the + // key of the deserializers map to save us from duplicating the string in + // the serializer. This should be changed to use a stringpool if we switch + // to a map type that may move keys in memory. 
+ std::lock_guard<std::recursive_mutex> Lock(DeserializersMutex); + auto I = + Deserializers.insert(Deserializers.begin(), + std::make_pair(std::move(Name), + std::move(Deserialize))); + KeyName = &I->first; + } + + { + assert(KeyName != nullptr && "No keyname pointer"); + std::lock_guard<std::recursive_mutex> Lock(SerializersMutex); + Serializers[ErrorInfoT::classID()] = + [KeyName, Serialize = std::move(Serialize)]( + ChannelT &C, const ErrorInfoBase &EIB) -> Error { + assert(EIB.dynamicClassID() == ErrorInfoT::classID() && + "Serializer called for wrong error type"); + if (auto Err = serializeSeq(C, *KeyName)) + return Err; + return Serialize(C, static_cast<const ErrorInfoT &>(EIB)); + }; + } + } + + static Error serialize(ChannelT &C, Error &&Err) { + std::lock_guard<std::recursive_mutex> Lock(SerializersMutex); + + if (!Err) + return serializeSeq(C, std::string()); + + return handleErrors(std::move(Err), + [&C](const ErrorInfoBase &EIB) { + auto SI = Serializers.find(EIB.dynamicClassID()); + if (SI == Serializers.end()) + return serializeAsStringError(C, EIB); + return (SI->second)(C, EIB); + }); + } + + static Error deserialize(ChannelT &C, Error &Err) { + std::lock_guard<std::recursive_mutex> Lock(DeserializersMutex); + + std::string Key; + if (auto Err = deserializeSeq(C, Key)) + return Err; + + if (Key.empty()) { + ErrorAsOutParameter EAO(&Err); + Err = Error::success(); + return Error::success(); + } + + auto DI = Deserializers.find(Key); + assert(DI != Deserializers.end() && "No deserializer for error type"); + return (DI->second)(C, Err); + } + +private: + + static Error serializeAsStringError(ChannelT &C, const ErrorInfoBase &EIB) { + std::string ErrMsg; + { + raw_string_ostream ErrMsgStream(ErrMsg); + EIB.log(ErrMsgStream); + } + return serialize(C, make_error<StringError>(std::move(ErrMsg), + inconvertibleErrorCode())); + } + + static std::recursive_mutex SerializersMutex; + static std::recursive_mutex DeserializersMutex; + static std::map<const void*, WrappedErrorSerializer> Serializers; + static std::map<std::string, WrappedErrorDeserializer> Deserializers; +}; + +template <typename ChannelT> +std::recursive_mutex SerializationTraits<ChannelT, Error>::SerializersMutex; + +template <typename ChannelT> +std::recursive_mutex SerializationTraits<ChannelT, Error>::DeserializersMutex; + +template <typename ChannelT> +std::map<const void*, + typename SerializationTraits<ChannelT, Error>::WrappedErrorSerializer> +SerializationTraits<ChannelT, Error>::Serializers; + +template <typename ChannelT> +std::map<std::string, + typename SerializationTraits<ChannelT, Error>::WrappedErrorDeserializer> +SerializationTraits<ChannelT, Error>::Deserializers; + +/// Registers a serializer and deserializer for the given error type on the +/// given channel type. +template <typename ChannelT, typename ErrorInfoT, typename SerializeFtor, + typename DeserializeFtor> +void registerErrorSerialization(std::string Name, SerializeFtor &&Serialize, + DeserializeFtor &&Deserialize) { + SerializationTraits<ChannelT, Error>::template registerErrorType<ErrorInfoT>( + std::move(Name), + std::forward<SerializeFtor>(Serialize), + std::forward<DeserializeFtor>(Deserialize)); +} + +/// Registers serialization/deserialization for StringError. 
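+///
+/// A minimal usage sketch (hedged): MyChannel below is a hypothetical channel
+/// type satisfying the RPC channel interface; registering once during setup
+/// is assumed here to happen before any Error/Expected values are sent over
+/// channels of that type.
+///
+///   @code{.cpp}
+///
+///   registerStringError<MyChannel>();
+///
+///   @endcode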
+template <typename ChannelT> +void registerStringError() { + static bool AlreadyRegistered = false; + if (!AlreadyRegistered) { + registerErrorSerialization<ChannelT, StringError>( + "StringError", + [](ChannelT &C, const StringError &SE) { + return serializeSeq(C, SE.getMessage()); + }, + [](ChannelT &C, Error &Err) -> Error { + ErrorAsOutParameter EAO(&Err); + std::string Msg; + if (auto E2 = deserializeSeq(C, Msg)) + return E2; + Err = + make_error<StringError>(std::move(Msg), + orcError( + OrcErrorCode::UnknownErrorCodeFromRemote)); + return Error::success(); + }); + AlreadyRegistered = true; + } +} + +/// SerializationTraits for Expected<T1> from an Expected<T2>. +template <typename ChannelT, typename T1, typename T2> +class SerializationTraits<ChannelT, Expected<T1>, Expected<T2>> { +public: + + static Error serialize(ChannelT &C, Expected<T2> &&ValOrErr) { + if (ValOrErr) { + if (auto Err = serializeSeq(C, true)) + return Err; + return SerializationTraits<ChannelT, T1, T2>::serialize(C, *ValOrErr); + } + if (auto Err = serializeSeq(C, false)) + return Err; + return serializeSeq(C, ValOrErr.takeError()); + } + + static Error deserialize(ChannelT &C, Expected<T2> &ValOrErr) { + ExpectedAsOutParameter<T2> EAO(&ValOrErr); + bool HasValue; + if (auto Err = deserializeSeq(C, HasValue)) + return Err; + if (HasValue) + return SerializationTraits<ChannelT, T1, T2>::deserialize(C, *ValOrErr); + Error Err = Error::success(); + if (auto E2 = deserializeSeq(C, Err)) + return E2; + ValOrErr = std::move(Err); + return Error::success(); + } +}; + +/// SerializationTraits for Expected<T1> from a T2. +template <typename ChannelT, typename T1, typename T2> +class SerializationTraits<ChannelT, Expected<T1>, T2> { +public: + + static Error serialize(ChannelT &C, T2 &&Val) { + return serializeSeq(C, Expected<T2>(std::forward<T2>(Val))); + } +}; + +/// SerializationTraits for Expected<T1> from an Error. +template <typename ChannelT, typename T> +class SerializationTraits<ChannelT, Expected<T>, Error> { +public: + + static Error serialize(ChannelT &C, Error &&Err) { + return serializeSeq(C, Expected<T>(std::move(Err))); + } +}; + +/// SerializationTraits default specialization for std::pair. +template <typename ChannelT, typename T1, typename T2, typename T3, typename T4> +class SerializationTraits<ChannelT, std::pair<T1, T2>, std::pair<T3, T4>> { +public: + static Error serialize(ChannelT &C, const std::pair<T3, T4> &V) { + if (auto Err = SerializationTraits<ChannelT, T1, T3>::serialize(C, V.first)) + return Err; + return SerializationTraits<ChannelT, T2, T4>::serialize(C, V.second); + } + + static Error deserialize(ChannelT &C, std::pair<T3, T4> &V) { + if (auto Err = + SerializationTraits<ChannelT, T1, T3>::deserialize(C, V.first)) + return Err; + return SerializationTraits<ChannelT, T2, T4>::deserialize(C, V.second); + } +}; + +/// SerializationTraits default specialization for std::tuple. +template <typename ChannelT, typename... ArgTs> +class SerializationTraits<ChannelT, std::tuple<ArgTs...>> { +public: + + /// RPC channel serialization for std::tuple. + static Error serialize(ChannelT &C, const std::tuple<ArgTs...> &V) { + return serializeTupleHelper(C, V, std::index_sequence_for<ArgTs...>()); + } + + /// RPC channel deserialization for std::tuple. + static Error deserialize(ChannelT &C, std::tuple<ArgTs...> &V) { + return deserializeTupleHelper(C, V, std::index_sequence_for<ArgTs...>()); + } + +private: + // Serialization helper for std::tuple. + template <size_t... 
Is> + static Error serializeTupleHelper(ChannelT &C, const std::tuple<ArgTs...> &V, + std::index_sequence<Is...> _) { + return serializeSeq(C, std::get<Is>(V)...); + } + + // Serialization helper for std::tuple. + template <size_t... Is> + static Error deserializeTupleHelper(ChannelT &C, std::tuple<ArgTs...> &V, + std::index_sequence<Is...> _) { + return deserializeSeq(C, std::get<Is>(V)...); + } +}; + +/// SerializationTraits default specialization for std::vector. +template <typename ChannelT, typename T> +class SerializationTraits<ChannelT, std::vector<T>> { +public: + + /// Serialize a std::vector<T> from std::vector<T>. + static Error serialize(ChannelT &C, const std::vector<T> &V) { + if (auto Err = serializeSeq(C, static_cast<uint64_t>(V.size()))) + return Err; + + for (const auto &E : V) + if (auto Err = serializeSeq(C, E)) + return Err; + + return Error::success(); + } + + /// Deserialize a std::vector<T> to a std::vector<T>. + static Error deserialize(ChannelT &C, std::vector<T> &V) { + assert(V.empty() && + "Expected default-constructed vector to deserialize into"); + + uint64_t Count = 0; + if (auto Err = deserializeSeq(C, Count)) + return Err; + + V.resize(Count); + for (auto &E : V) + if (auto Err = deserializeSeq(C, E)) + return Err; + + return Error::success(); + } +}; + +template <typename ChannelT, typename T, typename T2> +class SerializationTraits<ChannelT, std::set<T>, std::set<T2>> { +public: + /// Serialize a std::set<T> from std::set<T2>. + static Error serialize(ChannelT &C, const std::set<T2> &S) { + if (auto Err = serializeSeq(C, static_cast<uint64_t>(S.size()))) + return Err; + + for (const auto &E : S) + if (auto Err = SerializationTraits<ChannelT, T, T2>::serialize(C, E)) + return Err; + + return Error::success(); + } + + /// Deserialize a std::set<T> to a std::set<T>. + static Error deserialize(ChannelT &C, std::set<T2> &S) { + assert(S.empty() && "Expected default-constructed set to deserialize into"); + + uint64_t Count = 0; + if (auto Err = deserializeSeq(C, Count)) + return Err; + + while (Count-- != 0) { + T2 Val; + if (auto Err = SerializationTraits<ChannelT, T, T2>::deserialize(C, Val)) + return Err; + + auto Added = S.insert(Val).second; + if (!Added) + return make_error<StringError>("Duplicate element in deserialized set", + orcError(OrcErrorCode::UnknownORCError)); + } + + return Error::success(); + } +}; + +template <typename ChannelT, typename K, typename V, typename K2, typename V2> +class SerializationTraits<ChannelT, std::map<K, V>, std::map<K2, V2>> { +public: + /// Serialize a std::map<K, V> from std::map<K2, V2>. + static Error serialize(ChannelT &C, const std::map<K2, V2> &M) { + if (auto Err = serializeSeq(C, static_cast<uint64_t>(M.size()))) + return Err; + + for (const auto &E : M) { + if (auto Err = + SerializationTraits<ChannelT, K, K2>::serialize(C, E.first)) + return Err; + if (auto Err = + SerializationTraits<ChannelT, V, V2>::serialize(C, E.second)) + return Err; + } + + return Error::success(); + } + + /// Deserialize a std::map<K, V> to a std::map<K, V>. 
+ static Error deserialize(ChannelT &C, std::map<K2, V2> &M) { + assert(M.empty() && "Expected default-constructed map to deserialize into"); + + uint64_t Count = 0; + if (auto Err = deserializeSeq(C, Count)) + return Err; + + while (Count-- != 0) { + std::pair<K2, V2> Val; + if (auto Err = + SerializationTraits<ChannelT, K, K2>::deserialize(C, Val.first)) + return Err; + + if (auto Err = + SerializationTraits<ChannelT, V, V2>::deserialize(C, Val.second)) + return Err; + + auto Added = M.insert(Val).second; + if (!Added) + return make_error<StringError>("Duplicate element in deserialized map", + orcError(OrcErrorCode::UnknownORCError)); + } + + return Error::success(); + } +}; + +} // end namespace rpc +} // end namespace orc +} // end namespace llvm + +#endif // LLVM_EXECUTIONENGINE_ORC_RPCSERIALIZATION_H diff --git a/llvm/include/llvm/ExecutionEngine/Orc/RPCUtils.h b/llvm/include/llvm/ExecutionEngine/Orc/RPCUtils.h new file mode 100644 index 000000000000..ee9c2cc69c30 --- /dev/null +++ b/llvm/include/llvm/ExecutionEngine/Orc/RPCUtils.h @@ -0,0 +1,1690 @@ +//===- RPCUtils.h - Utilities for building RPC APIs -------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Utilities to support construction of simple RPC APIs. +// +// The RPC utilities aim for ease of use (minimal conceptual overhead) for C++ +// programmers, high performance, low memory overhead, and efficient use of the +// communications channel. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_EXECUTIONENGINE_ORC_RPCUTILS_H +#define LLVM_EXECUTIONENGINE_ORC_RPCUTILS_H + +#include <map> +#include <thread> +#include <vector> + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ExecutionEngine/Orc/OrcError.h" +#include "llvm/ExecutionEngine/Orc/RPCSerialization.h" +#include "llvm/Support/MSVCErrorWorkarounds.h" + +#include <future> + +namespace llvm { +namespace orc { +namespace rpc { + +/// Base class of all fatal RPC errors (those that necessarily result in the +/// termination of the RPC session). +class RPCFatalError : public ErrorInfo<RPCFatalError> { +public: + static char ID; +}; + +/// RPCConnectionClosed is returned from RPC operations if the RPC connection +/// has already been closed due to either an error or graceful disconnection. +class ConnectionClosed : public ErrorInfo<ConnectionClosed> { +public: + static char ID; + std::error_code convertToErrorCode() const override; + void log(raw_ostream &OS) const override; +}; + +/// BadFunctionCall is returned from handleOne when the remote makes a call with +/// an unrecognized function id. +/// +/// This error is fatal because Orc RPC needs to know how to parse a function +/// call to know where the next call starts, and if it doesn't recognize the +/// function id it cannot parse the call. 
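+///
+/// A hedged handling sketch (Endpoint and closeConnection below are
+/// hypothetical, not part of this header): since this error is fatal, a
+/// caller of handleOne would typically tear the session down rather than
+/// retry.
+///
+///   @code{.cpp}
+///
+///   if (Error Err = Endpoint.handleOne()) {
+///     bool Fatal = Err.isA<RPCFatalError>();
+///     logAllUnhandledErrors(std::move(Err), errs(), "RPC: ");
+///     if (Fatal)
+///       closeConnection();
+///   }
+///
+///   @endcode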
+template <typename FnIdT, typename SeqNoT> +class BadFunctionCall + : public ErrorInfo<BadFunctionCall<FnIdT, SeqNoT>, RPCFatalError> { +public: + static char ID; + + BadFunctionCall(FnIdT FnId, SeqNoT SeqNo) + : FnId(std::move(FnId)), SeqNo(std::move(SeqNo)) {} + + std::error_code convertToErrorCode() const override { + return orcError(OrcErrorCode::UnexpectedRPCCall); + } + + void log(raw_ostream &OS) const override { + OS << "Call to invalid RPC function id '" << FnId << "' with " + "sequence number " << SeqNo; + } + +private: + FnIdT FnId; + SeqNoT SeqNo; +}; + +template <typename FnIdT, typename SeqNoT> +char BadFunctionCall<FnIdT, SeqNoT>::ID = 0; + +/// InvalidSequenceNumberForResponse is returned from handleOne when a response +/// call arrives with a sequence number that doesn't correspond to any in-flight +/// function call. +/// +/// This error is fatal because Orc RPC needs to know how to parse the rest of +/// the response call to know where the next call starts, and if it doesn't have +/// a result parser for this sequence number it can't do that. +template <typename SeqNoT> +class InvalidSequenceNumberForResponse + : public ErrorInfo<InvalidSequenceNumberForResponse<SeqNoT>, RPCFatalError> { +public: + static char ID; + + InvalidSequenceNumberForResponse(SeqNoT SeqNo) + : SeqNo(std::move(SeqNo)) {} + + std::error_code convertToErrorCode() const override { + return orcError(OrcErrorCode::UnexpectedRPCCall); + }; + + void log(raw_ostream &OS) const override { + OS << "Response has unknown sequence number " << SeqNo; + } +private: + SeqNoT SeqNo; +}; + +template <typename SeqNoT> +char InvalidSequenceNumberForResponse<SeqNoT>::ID = 0; + +/// This non-fatal error will be passed to asynchronous result handlers in place +/// of a result if the connection goes down before a result returns, or if the +/// function to be called cannot be negotiated with the remote. +class ResponseAbandoned : public ErrorInfo<ResponseAbandoned> { +public: + static char ID; + + std::error_code convertToErrorCode() const override; + void log(raw_ostream &OS) const override; +}; + +/// This error is returned if the remote does not have a handler installed for +/// the given RPC function. +class CouldNotNegotiate : public ErrorInfo<CouldNotNegotiate> { +public: + static char ID; + + CouldNotNegotiate(std::string Signature); + std::error_code convertToErrorCode() const override; + void log(raw_ostream &OS) const override; + const std::string &getSignature() const { return Signature; } +private: + std::string Signature; +}; + +template <typename DerivedFunc, typename FnT> class Function; + +// RPC Function class. +// DerivedFunc should be a user defined class with a static 'getName()' method +// returning a const char* representing the function's name. +template <typename DerivedFunc, typename RetT, typename... ArgTs> +class Function<DerivedFunc, RetT(ArgTs...)> { +public: + /// User defined function type. + using Type = RetT(ArgTs...); + + /// Return type. + using ReturnType = RetT; + + /// Returns the full function prototype as a string. + static const char *getPrototype() { + static std::string Name = [] { + std::string Name; + raw_string_ostream(Name) + << RPCTypeName<RetT>::getName() << " " << DerivedFunc::getName() + << "(" << llvm::orc::rpc::RPCTypeNameSequence<ArgTs...>() << ")"; + return Name; + }(); + return Name.data(); + } +}; + +/// Allocates RPC function ids during autonegotiation. 
+/// Specializations of this class must provide four members: +/// +/// static T getInvalidId(): +/// Should return a reserved id that will be used to represent missing +/// functions during autonegotiation. +/// +/// static T getResponseId(): +/// Should return a reserved id that will be used to send function responses +/// (return values). +/// +/// static T getNegotiateId(): +/// Should return a reserved id for the negotiate function, which will be used +/// to negotiate ids for user defined functions. +/// +/// template <typename Func> T allocate(): +/// Allocate a unique id for function Func. +template <typename T, typename = void> class RPCFunctionIdAllocator; + +/// This specialization of RPCFunctionIdAllocator provides a default +/// implementation for integral types. +template <typename T> +class RPCFunctionIdAllocator< + T, typename std::enable_if<std::is_integral<T>::value>::type> { +public: + static T getInvalidId() { return T(0); } + static T getResponseId() { return T(1); } + static T getNegotiateId() { return T(2); } + + template <typename Func> T allocate() { return NextId++; } + +private: + T NextId = 3; +}; + +namespace detail { + +/// Provides a typedef for a tuple containing the decayed argument types. +template <typename T> class FunctionArgsTuple; + +template <typename RetT, typename... ArgTs> +class FunctionArgsTuple<RetT(ArgTs...)> { +public: + using Type = std::tuple<typename std::decay< + typename std::remove_reference<ArgTs>::type>::type...>; +}; + +// ResultTraits provides typedefs and utilities specific to the return type +// of functions. +template <typename RetT> class ResultTraits { +public: + // The return type wrapped in llvm::Expected. + using ErrorReturnType = Expected<RetT>; + +#ifdef _MSC_VER + // The ErrorReturnType wrapped in a std::promise. + using ReturnPromiseType = std::promise<MSVCPExpected<RetT>>; + + // The ErrorReturnType wrapped in a std::future. + using ReturnFutureType = std::future<MSVCPExpected<RetT>>; +#else + // The ErrorReturnType wrapped in a std::promise. + using ReturnPromiseType = std::promise<ErrorReturnType>; + + // The ErrorReturnType wrapped in a std::future. + using ReturnFutureType = std::future<ErrorReturnType>; +#endif + + // Create a 'blank' value of the ErrorReturnType, ready and safe to + // overwrite. + static ErrorReturnType createBlankErrorReturnValue() { + return ErrorReturnType(RetT()); + } + + // Consume an abandoned ErrorReturnType. + static void consumeAbandoned(ErrorReturnType RetOrErr) { + consumeError(RetOrErr.takeError()); + } +}; + +// ResultTraits specialization for void functions. +template <> class ResultTraits<void> { +public: + // For void functions, ErrorReturnType is llvm::Error. + using ErrorReturnType = Error; + +#ifdef _MSC_VER + // The ErrorReturnType wrapped in a std::promise. + using ReturnPromiseType = std::promise<MSVCPError>; + + // The ErrorReturnType wrapped in a std::future. + using ReturnFutureType = std::future<MSVCPError>; +#else + // The ErrorReturnType wrapped in a std::promise. + using ReturnPromiseType = std::promise<ErrorReturnType>; + + // The ErrorReturnType wrapped in a std::future. + using ReturnFutureType = std::future<ErrorReturnType>; +#endif + + // Create a 'blank' value of the ErrorReturnType, ready and safe to + // overwrite. + static ErrorReturnType createBlankErrorReturnValue() { + return ErrorReturnType::success(); + } + + // Consume an abandoned ErrorReturnType. 
+ static void consumeAbandoned(ErrorReturnType Err) { + consumeError(std::move(Err)); + } +}; + +// ResultTraits<Error> is equivalent to ResultTraits<void>. This allows +// handlers for void RPC functions to return either void (in which case they +// implicitly succeed) or Error (in which case their error return is +// propagated). See usage in HandlerTraits::runHandlerHelper. +template <> class ResultTraits<Error> : public ResultTraits<void> {}; + +// ResultTraits<Expected<T>> is equivalent to ResultTraits<T>. This allows +// handlers for RPC functions returning a T to return either a T (in which +// case they implicitly succeed) or Expected<T> (in which case their error +// return is propagated). See usage in HandlerTraits::runHandlerHelper. +template <typename RetT> +class ResultTraits<Expected<RetT>> : public ResultTraits<RetT> {}; + +// Determines whether an RPC function's defined error return type supports +// error return value. +template <typename T> +class SupportsErrorReturn { +public: + static const bool value = false; +}; + +template <> +class SupportsErrorReturn<Error> { +public: + static const bool value = true; +}; + +template <typename T> +class SupportsErrorReturn<Expected<T>> { +public: + static const bool value = true; +}; + +// RespondHelper packages return values based on whether or not the declared +// RPC function return type supports error returns. +template <bool FuncSupportsErrorReturn> +class RespondHelper; + +// RespondHelper specialization for functions that support error returns. +template <> +class RespondHelper<true> { +public: + + // Send Expected<T>. + template <typename WireRetT, typename HandlerRetT, typename ChannelT, + typename FunctionIdT, typename SequenceNumberT> + static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId, + SequenceNumberT SeqNo, + Expected<HandlerRetT> ResultOrErr) { + if (!ResultOrErr && ResultOrErr.template errorIsA<RPCFatalError>()) + return ResultOrErr.takeError(); + + // Open the response message. + if (auto Err = C.startSendMessage(ResponseId, SeqNo)) + return Err; + + // Serialize the result. + if (auto Err = + SerializationTraits<ChannelT, WireRetT, + Expected<HandlerRetT>>::serialize( + C, std::move(ResultOrErr))) + return Err; + + // Close the response message. + if (auto Err = C.endSendMessage()) + return Err; + return C.send(); + } + + template <typename ChannelT, typename FunctionIdT, typename SequenceNumberT> + static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId, + SequenceNumberT SeqNo, Error Err) { + if (Err && Err.isA<RPCFatalError>()) + return Err; + if (auto Err2 = C.startSendMessage(ResponseId, SeqNo)) + return Err2; + if (auto Err2 = serializeSeq(C, std::move(Err))) + return Err2; + if (auto Err2 = C.endSendMessage()) + return Err2; + return C.send(); + } + +}; + +// RespondHelper specialization for functions that do not support error returns. +template <> +class RespondHelper<false> { +public: + + template <typename WireRetT, typename HandlerRetT, typename ChannelT, + typename FunctionIdT, typename SequenceNumberT> + static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId, + SequenceNumberT SeqNo, + Expected<HandlerRetT> ResultOrErr) { + if (auto Err = ResultOrErr.takeError()) + return Err; + + // Open the response message. + if (auto Err = C.startSendMessage(ResponseId, SeqNo)) + return Err; + + // Serialize the result. + if (auto Err = + SerializationTraits<ChannelT, WireRetT, HandlerRetT>::serialize( + C, *ResultOrErr)) + return Err; + + // End the response message. 
+ if (auto Err = C.endSendMessage()) + return Err; + + return C.send(); + } + + template <typename ChannelT, typename FunctionIdT, typename SequenceNumberT> + static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId, + SequenceNumberT SeqNo, Error Err) { + if (Err) + return Err; + if (auto Err2 = C.startSendMessage(ResponseId, SeqNo)) + return Err2; + if (auto Err2 = C.endSendMessage()) + return Err2; + return C.send(); + } + +}; + + +// Send a response of the given wire return type (WireRetT) over the +// channel, with the given sequence number. +template <typename WireRetT, typename HandlerRetT, typename ChannelT, + typename FunctionIdT, typename SequenceNumberT> +Error respond(ChannelT &C, const FunctionIdT &ResponseId, + SequenceNumberT SeqNo, Expected<HandlerRetT> ResultOrErr) { + return RespondHelper<SupportsErrorReturn<WireRetT>::value>:: + template sendResult<WireRetT>(C, ResponseId, SeqNo, std::move(ResultOrErr)); +} + +// Send an empty response message on the given channel to indicate that +// the handler ran. +template <typename WireRetT, typename ChannelT, typename FunctionIdT, + typename SequenceNumberT> +Error respond(ChannelT &C, const FunctionIdT &ResponseId, SequenceNumberT SeqNo, + Error Err) { + return RespondHelper<SupportsErrorReturn<WireRetT>::value>:: + sendResult(C, ResponseId, SeqNo, std::move(Err)); +} + +// Converts a given type to the equivalent error return type. +template <typename T> class WrappedHandlerReturn { +public: + using Type = Expected<T>; +}; + +template <typename T> class WrappedHandlerReturn<Expected<T>> { +public: + using Type = Expected<T>; +}; + +template <> class WrappedHandlerReturn<void> { +public: + using Type = Error; +}; + +template <> class WrappedHandlerReturn<Error> { +public: + using Type = Error; +}; + +template <> class WrappedHandlerReturn<ErrorSuccess> { +public: + using Type = Error; +}; + +// Traits class that strips the response function from the list of handler +// arguments. +template <typename FnT> class AsyncHandlerTraits; + +template <typename ResultT, typename... ArgTs> +class AsyncHandlerTraits<Error(std::function<Error(Expected<ResultT>)>, ArgTs...)> { +public: + using Type = Error(ArgTs...); + using ResultType = Expected<ResultT>; +}; + +template <typename... ArgTs> +class AsyncHandlerTraits<Error(std::function<Error(Error)>, ArgTs...)> { +public: + using Type = Error(ArgTs...); + using ResultType = Error; +}; + +template <typename... ArgTs> +class AsyncHandlerTraits<ErrorSuccess(std::function<Error(Error)>, ArgTs...)> { +public: + using Type = Error(ArgTs...); + using ResultType = Error; +}; + +template <typename... ArgTs> +class AsyncHandlerTraits<void(std::function<Error(Error)>, ArgTs...)> { +public: + using Type = Error(ArgTs...); + using ResultType = Error; +}; + +template <typename ResponseHandlerT, typename... ArgTs> +class AsyncHandlerTraits<Error(ResponseHandlerT, ArgTs...)> : + public AsyncHandlerTraits<Error(typename std::decay<ResponseHandlerT>::type, + ArgTs...)> {}; + +// This template class provides utilities related to RPC function handlers. +// The base case applies to non-function types (the template class is +// specialized for function types) and inherits from the appropriate +// speciilization for the given non-function type's call operator. +template <typename HandlerT> +class HandlerTraits : public HandlerTraits<decltype( + &std::remove_reference<HandlerT>::type::operator())> { +}; + +// Traits for handlers with a given function type. +template <typename RetT, typename... 
ArgTs> +class HandlerTraits<RetT(ArgTs...)> { +public: + // Function type of the handler. + using Type = RetT(ArgTs...); + + // Return type of the handler. + using ReturnType = RetT; + + // Call the given handler with the given arguments. + template <typename HandlerT, typename... TArgTs> + static typename WrappedHandlerReturn<RetT>::Type + unpackAndRun(HandlerT &Handler, std::tuple<TArgTs...> &Args) { + return unpackAndRunHelper(Handler, Args, + std::index_sequence_for<TArgTs...>()); + } + + // Call the given handler with the given arguments. + template <typename HandlerT, typename ResponderT, typename... TArgTs> + static Error unpackAndRunAsync(HandlerT &Handler, ResponderT &Responder, + std::tuple<TArgTs...> &Args) { + return unpackAndRunAsyncHelper(Handler, Responder, Args, + std::index_sequence_for<TArgTs...>()); + } + + // Call the given handler with the given arguments. + template <typename HandlerT> + static typename std::enable_if< + std::is_void<typename HandlerTraits<HandlerT>::ReturnType>::value, + Error>::type + run(HandlerT &Handler, ArgTs &&... Args) { + Handler(std::move(Args)...); + return Error::success(); + } + + template <typename HandlerT, typename... TArgTs> + static typename std::enable_if< + !std::is_void<typename HandlerTraits<HandlerT>::ReturnType>::value, + typename HandlerTraits<HandlerT>::ReturnType>::type + run(HandlerT &Handler, TArgTs... Args) { + return Handler(std::move(Args)...); + } + + // Serialize arguments to the channel. + template <typename ChannelT, typename... CArgTs> + static Error serializeArgs(ChannelT &C, const CArgTs... CArgs) { + return SequenceSerialization<ChannelT, ArgTs...>::serialize(C, CArgs...); + } + + // Deserialize arguments from the channel. + template <typename ChannelT, typename... CArgTs> + static Error deserializeArgs(ChannelT &C, std::tuple<CArgTs...> &Args) { + return deserializeArgsHelper(C, Args, std::index_sequence_for<CArgTs...>()); + } + +private: + template <typename ChannelT, typename... CArgTs, size_t... Indexes> + static Error deserializeArgsHelper(ChannelT &C, std::tuple<CArgTs...> &Args, + std::index_sequence<Indexes...> _) { + return SequenceSerialization<ChannelT, ArgTs...>::deserialize( + C, std::get<Indexes>(Args)...); + } + + template <typename HandlerT, typename ArgTuple, size_t... Indexes> + static typename WrappedHandlerReturn< + typename HandlerTraits<HandlerT>::ReturnType>::Type + unpackAndRunHelper(HandlerT &Handler, ArgTuple &Args, + std::index_sequence<Indexes...>) { + return run(Handler, std::move(std::get<Indexes>(Args))...); + } + + template <typename HandlerT, typename ResponderT, typename ArgTuple, + size_t... Indexes> + static typename WrappedHandlerReturn< + typename HandlerTraits<HandlerT>::ReturnType>::Type + unpackAndRunAsyncHelper(HandlerT &Handler, ResponderT &Responder, + ArgTuple &Args, std::index_sequence<Indexes...>) { + return run(Handler, Responder, std::move(std::get<Indexes>(Args))...); + } +}; + +// Handler traits for free functions. +template <typename RetT, typename... ArgTs> +class HandlerTraits<RetT(*)(ArgTs...)> + : public HandlerTraits<RetT(ArgTs...)> {}; + +// Handler traits for class methods (especially call operators for lambdas). +template <typename Class, typename RetT, typename... ArgTs> +class HandlerTraits<RetT (Class::*)(ArgTs...)> + : public HandlerTraits<RetT(ArgTs...)> {}; + +// Handler traits for const class methods (especially call operators for +// lambdas). +template <typename Class, typename RetT, typename... 
ArgTs> +class HandlerTraits<RetT (Class::*)(ArgTs...) const> + : public HandlerTraits<RetT(ArgTs...)> {}; + +// Utility to peel the Expected wrapper off a response handler error type. +template <typename HandlerT> class ResponseHandlerArg; + +template <typename ArgT> class ResponseHandlerArg<Error(Expected<ArgT>)> { +public: + using ArgType = Expected<ArgT>; + using UnwrappedArgType = ArgT; +}; + +template <typename ArgT> +class ResponseHandlerArg<ErrorSuccess(Expected<ArgT>)> { +public: + using ArgType = Expected<ArgT>; + using UnwrappedArgType = ArgT; +}; + +template <> class ResponseHandlerArg<Error(Error)> { +public: + using ArgType = Error; +}; + +template <> class ResponseHandlerArg<ErrorSuccess(Error)> { +public: + using ArgType = Error; +}; + +// ResponseHandler represents a handler for a not-yet-received function call +// result. +template <typename ChannelT> class ResponseHandler { +public: + virtual ~ResponseHandler() {} + + // Reads the function result off the wire and acts on it. The meaning of + // "act" will depend on how this method is implemented in any given + // ResponseHandler subclass but could, for example, mean running a + // user-specified handler or setting a promise value. + virtual Error handleResponse(ChannelT &C) = 0; + + // Abandons this outstanding result. + virtual void abandon() = 0; + + // Create an error instance representing an abandoned response. + static Error createAbandonedResponseError() { + return make_error<ResponseAbandoned>(); + } +}; + +// ResponseHandler subclass for RPC functions with non-void returns. +template <typename ChannelT, typename FuncRetT, typename HandlerT> +class ResponseHandlerImpl : public ResponseHandler<ChannelT> { +public: + ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {} + + // Handle the result by deserializing it from the channel then passing it + // to the user defined handler. + Error handleResponse(ChannelT &C) override { + using UnwrappedArgType = typename ResponseHandlerArg< + typename HandlerTraits<HandlerT>::Type>::UnwrappedArgType; + UnwrappedArgType Result; + if (auto Err = + SerializationTraits<ChannelT, FuncRetT, + UnwrappedArgType>::deserialize(C, Result)) + return Err; + if (auto Err = C.endReceiveMessage()) + return Err; + return Handler(std::move(Result)); + } + + // Abandon this response by calling the handler with an 'abandoned response' + // error. + void abandon() override { + if (auto Err = Handler(this->createAbandonedResponseError())) { + // Handlers should not fail when passed an abandoned response error. + report_fatal_error(std::move(Err)); + } + } + +private: + HandlerT Handler; +}; + +// ResponseHandler subclass for RPC functions with void returns. +template <typename ChannelT, typename HandlerT> +class ResponseHandlerImpl<ChannelT, void, HandlerT> + : public ResponseHandler<ChannelT> { +public: + ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {} + + // Handle the result (no actual value, just a notification that the function + // has completed on the remote end) by calling the user-defined handler with + // Error::success(). + Error handleResponse(ChannelT &C) override { + if (auto Err = C.endReceiveMessage()) + return Err; + return Handler(Error::success()); + } + + // Abandon this response by calling the handler with an 'abandoned response' + // error. + void abandon() override { + if (auto Err = Handler(this->createAbandonedResponseError())) { + // Handlers should not fail when passed an abandoned response error. 
+ report_fatal_error(std::move(Err)); + } + } + +private: + HandlerT Handler; +}; + +template <typename ChannelT, typename FuncRetT, typename HandlerT> +class ResponseHandlerImpl<ChannelT, Expected<FuncRetT>, HandlerT> + : public ResponseHandler<ChannelT> { +public: + ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {} + + // Handle the result by deserializing it from the channel then passing it + // to the user defined handler. + Error handleResponse(ChannelT &C) override { + using HandlerArgType = typename ResponseHandlerArg< + typename HandlerTraits<HandlerT>::Type>::ArgType; + HandlerArgType Result((typename HandlerArgType::value_type())); + + if (auto Err = + SerializationTraits<ChannelT, Expected<FuncRetT>, + HandlerArgType>::deserialize(C, Result)) + return Err; + if (auto Err = C.endReceiveMessage()) + return Err; + return Handler(std::move(Result)); + } + + // Abandon this response by calling the handler with an 'abandoned response' + // error. + void abandon() override { + if (auto Err = Handler(this->createAbandonedResponseError())) { + // Handlers should not fail when passed an abandoned response error. + report_fatal_error(std::move(Err)); + } + } + +private: + HandlerT Handler; +}; + +template <typename ChannelT, typename HandlerT> +class ResponseHandlerImpl<ChannelT, Error, HandlerT> + : public ResponseHandler<ChannelT> { +public: + ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {} + + // Handle the result by deserializing it from the channel then passing it + // to the user defined handler. + Error handleResponse(ChannelT &C) override { + Error Result = Error::success(); + if (auto Err = SerializationTraits<ChannelT, Error, Error>::deserialize( + C, Result)) { + consumeError(std::move(Result)); + return Err; + } + if (auto Err = C.endReceiveMessage()) { + consumeError(std::move(Result)); + return Err; + } + return Handler(std::move(Result)); + } + + // Abandon this response by calling the handler with an 'abandoned response' + // error. + void abandon() override { + if (auto Err = Handler(this->createAbandonedResponseError())) { + // Handlers should not fail when passed an abandoned response error. + report_fatal_error(std::move(Err)); + } + } + +private: + HandlerT Handler; +}; + +// Create a ResponseHandler from a given user handler. +template <typename ChannelT, typename FuncRetT, typename HandlerT> +std::unique_ptr<ResponseHandler<ChannelT>> createResponseHandler(HandlerT H) { + return std::make_unique<ResponseHandlerImpl<ChannelT, FuncRetT, HandlerT>>( + std::move(H)); +} + +// Helper for wrapping member functions up as functors. This is useful for +// installing methods as result handlers. +template <typename ClassT, typename RetT, typename... ArgTs> +class MemberFnWrapper { +public: + using MethodT = RetT (ClassT::*)(ArgTs...); + MemberFnWrapper(ClassT &Instance, MethodT Method) + : Instance(Instance), Method(Method) {} + RetT operator()(ArgTs &&... Args) { + return (Instance.*Method)(std::move(Args)...); + } + +private: + ClassT &Instance; + MethodT Method; +}; + +// Helper that provides a Functor for deserializing arguments. +template <typename... ArgTs> class ReadArgs { +public: + Error operator()() { return Error::success(); } +}; + +template <typename ArgT, typename... ArgTs> +class ReadArgs<ArgT, ArgTs...> : public ReadArgs<ArgTs...> { +public: + ReadArgs(ArgT &Arg, ArgTs &... Args) + : ReadArgs<ArgTs...>(Args...), Arg(Arg) {} + + Error operator()(ArgT &ArgVal, ArgTs &... 
ArgVals) { + this->Arg = std::move(ArgVal); + return ReadArgs<ArgTs...>::operator()(ArgVals...); + } + +private: + ArgT &Arg; +}; + +// Manage sequence numbers. +template <typename SequenceNumberT> class SequenceNumberManager { +public: + // Reset, making all sequence numbers available. + void reset() { + std::lock_guard<std::mutex> Lock(SeqNoLock); + NextSequenceNumber = 0; + FreeSequenceNumbers.clear(); + } + + // Get the next available sequence number. Will re-use numbers that have + // been released. + SequenceNumberT getSequenceNumber() { + std::lock_guard<std::mutex> Lock(SeqNoLock); + if (FreeSequenceNumbers.empty()) + return NextSequenceNumber++; + auto SequenceNumber = FreeSequenceNumbers.back(); + FreeSequenceNumbers.pop_back(); + return SequenceNumber; + } + + // Release a sequence number, making it available for re-use. + void releaseSequenceNumber(SequenceNumberT SequenceNumber) { + std::lock_guard<std::mutex> Lock(SeqNoLock); + FreeSequenceNumbers.push_back(SequenceNumber); + } + +private: + std::mutex SeqNoLock; + SequenceNumberT NextSequenceNumber = 0; + std::vector<SequenceNumberT> FreeSequenceNumbers; +}; + +// Checks that predicate P holds for each corresponding pair of type arguments +// from T1 and T2 tuple. +template <template <class, class> class P, typename T1Tuple, typename T2Tuple> +class RPCArgTypeCheckHelper; + +template <template <class, class> class P> +class RPCArgTypeCheckHelper<P, std::tuple<>, std::tuple<>> { +public: + static const bool value = true; +}; + +template <template <class, class> class P, typename T, typename... Ts, + typename U, typename... Us> +class RPCArgTypeCheckHelper<P, std::tuple<T, Ts...>, std::tuple<U, Us...>> { +public: + static const bool value = + P<T, U>::value && + RPCArgTypeCheckHelper<P, std::tuple<Ts...>, std::tuple<Us...>>::value; +}; + +template <template <class, class> class P, typename T1Sig, typename T2Sig> +class RPCArgTypeCheck { +public: + using T1Tuple = typename FunctionArgsTuple<T1Sig>::Type; + using T2Tuple = typename FunctionArgsTuple<T2Sig>::Type; + + static_assert(std::tuple_size<T1Tuple>::value >= + std::tuple_size<T2Tuple>::value, + "Too many arguments to RPC call"); + static_assert(std::tuple_size<T1Tuple>::value <= + std::tuple_size<T2Tuple>::value, + "Too few arguments to RPC call"); + + static const bool value = RPCArgTypeCheckHelper<P, T1Tuple, T2Tuple>::value; +}; + +template <typename ChannelT, typename WireT, typename ConcreteT> +class CanSerialize { +private: + using S = SerializationTraits<ChannelT, WireT, ConcreteT>; + + template <typename T> + static std::true_type + check(typename std::enable_if< + std::is_same<decltype(T::serialize(std::declval<ChannelT &>(), + std::declval<const ConcreteT &>())), + Error>::value, + void *>::type); + + template <typename> static std::false_type check(...); + +public: + static const bool value = decltype(check<S>(0))::value; +}; + +template <typename ChannelT, typename WireT, typename ConcreteT> +class CanDeserialize { +private: + using S = SerializationTraits<ChannelT, WireT, ConcreteT>; + + template <typename T> + static std::true_type + check(typename std::enable_if< + std::is_same<decltype(T::deserialize(std::declval<ChannelT &>(), + std::declval<ConcreteT &>())), + Error>::value, + void *>::type); + + template <typename> static std::false_type check(...); + +public: + static const bool value = decltype(check<S>(0))::value; +}; + +/// Contains primitive utilities for defining, calling and handling calls to +/// remote procedures. 
ChannelT is a bidirectional stream conforming to the +/// RPCChannel interface (see RPCChannel.h), FunctionIdT is a procedure +/// identifier type that must be serializable on ChannelT, and SequenceNumberT +/// is an integral type that will be used to number in-flight function calls. +/// +/// These utilities support the construction of very primitive RPC utilities. +/// Their intent is to ensure correct serialization and deserialization of +/// procedure arguments, and to keep the client and server's view of the API in +/// sync. +template <typename ImplT, typename ChannelT, typename FunctionIdT, + typename SequenceNumberT> +class RPCEndpointBase { +protected: + class OrcRPCInvalid : public Function<OrcRPCInvalid, void()> { + public: + static const char *getName() { return "__orc_rpc$invalid"; } + }; + + class OrcRPCResponse : public Function<OrcRPCResponse, void()> { + public: + static const char *getName() { return "__orc_rpc$response"; } + }; + + class OrcRPCNegotiate + : public Function<OrcRPCNegotiate, FunctionIdT(std::string)> { + public: + static const char *getName() { return "__orc_rpc$negotiate"; } + }; + + // Helper predicate for testing for the presence of SerializeTraits + // serializers. + template <typename WireT, typename ConcreteT> + class CanSerializeCheck : detail::CanSerialize<ChannelT, WireT, ConcreteT> { + public: + using detail::CanSerialize<ChannelT, WireT, ConcreteT>::value; + + static_assert(value, "Missing serializer for argument (Can't serialize the " + "first template type argument of CanSerializeCheck " + "from the second)"); + }; + + // Helper predicate for testing for the presence of SerializeTraits + // deserializers. + template <typename WireT, typename ConcreteT> + class CanDeserializeCheck + : detail::CanDeserialize<ChannelT, WireT, ConcreteT> { + public: + using detail::CanDeserialize<ChannelT, WireT, ConcreteT>::value; + + static_assert(value, "Missing deserializer for argument (Can't deserialize " + "the second template type argument of " + "CanDeserializeCheck from the first)"); + }; + +public: + /// Construct an RPC instance on a channel. + RPCEndpointBase(ChannelT &C, bool LazyAutoNegotiation) + : C(C), LazyAutoNegotiation(LazyAutoNegotiation) { + // Hold ResponseId in a special variable, since we expect Response to be + // called relatively frequently, and want to avoid the map lookup. + ResponseId = FnIdAllocator.getResponseId(); + RemoteFunctionIds[OrcRPCResponse::getPrototype()] = ResponseId; + + // Register the negotiate function id and handler. + auto NegotiateId = FnIdAllocator.getNegotiateId(); + RemoteFunctionIds[OrcRPCNegotiate::getPrototype()] = NegotiateId; + Handlers[NegotiateId] = wrapHandler<OrcRPCNegotiate>( + [this](const std::string &Name) { return handleNegotiate(Name); }); + } + + + /// Negotiate a function id for Func with the other end of the channel. + template <typename Func> Error negotiateFunction(bool Retry = false) { + return getRemoteFunctionId<Func>(true, Retry).takeError(); + } + + /// Append a call Func, does not call send on the channel. + /// The first argument specifies a user-defined handler to be run when the + /// function returns. The handler should take an Expected<Func::ReturnType>, + /// or an Error (if Func::ReturnType is void). The handler will be called + /// with an error if the return value is abandoned due to a channel error. + template <typename Func, typename HandlerT, typename... ArgTs> + Error appendCallAsync(HandlerT Handler, const ArgTs &... 
Args) {
+
+    static_assert(
+        detail::RPCArgTypeCheck<CanSerializeCheck, typename Func::Type,
+                                void(ArgTs...)>::value,
+        "");
+
+    // Look up the function ID.
+    FunctionIdT FnId;
+    if (auto FnIdOrErr = getRemoteFunctionId<Func>(LazyAutoNegotiation, false))
+      FnId = *FnIdOrErr;
+    else {
+      // Negotiation failed. Notify the handler then return the negotiate-failed
+      // error.
+      cantFail(Handler(make_error<ResponseAbandoned>()));
+      return FnIdOrErr.takeError();
+    }
+
+    SequenceNumberT SeqNo; // initialized in locked scope below.
+    {
+      // Lock the pending responses map and sequence number manager.
+      std::lock_guard<std::mutex> Lock(ResponsesMutex);
+
+      // Allocate a sequence number.
+      SeqNo = SequenceNumberMgr.getSequenceNumber();
+      assert(!PendingResponses.count(SeqNo) &&
+             "Sequence number already allocated");
+
+      // Install the user handler.
+      PendingResponses[SeqNo] =
+          detail::createResponseHandler<ChannelT, typename Func::ReturnType>(
+              std::move(Handler));
+    }
+
+    // Open the function call message.
+    if (auto Err = C.startSendMessage(FnId, SeqNo)) {
+      abandonPendingResponses();
+      return Err;
+    }
+
+    // Serialize the call arguments.
+    if (auto Err = detail::HandlerTraits<typename Func::Type>::serializeArgs(
+            C, Args...)) {
+      abandonPendingResponses();
+      return Err;
+    }
+
+    // Close the function call message.
+    if (auto Err = C.endSendMessage()) {
+      abandonPendingResponses();
+      return Err;
+    }
+
+    return Error::success();
+  }
+
+  Error sendAppendedCalls() { return C.send(); }
+
+  template <typename Func, typename HandlerT, typename... ArgTs>
+  Error callAsync(HandlerT Handler, const ArgTs &... Args) {
+    if (auto Err = appendCallAsync<Func>(std::move(Handler), Args...))
+      return Err;
+    return C.send();
+  }
+
+  /// Handle one incoming call.
+  Error handleOne() {
+    FunctionIdT FnId;
+    SequenceNumberT SeqNo;
+    if (auto Err = C.startReceiveMessage(FnId, SeqNo)) {
+      abandonPendingResponses();
+      return Err;
+    }
+    if (FnId == ResponseId)
+      return handleResponse(SeqNo);
+    auto I = Handlers.find(FnId);
+    if (I != Handlers.end())
+      return I->second(C, SeqNo);
+
+    // else: No handler found. Report error to client?
+    return make_error<BadFunctionCall<FunctionIdT, SequenceNumberT>>(FnId,
+                                                                     SeqNo);
+  }
+
+  /// Helper for handling setter procedures - this method returns a functor that
+  /// sets the variables referred to by Args... to values deserialized from the
+  /// channel.
+  /// E.g.
+  ///
+  ///   class Func1 : public Function<Func1, void(bool, int)> {
+  ///   public:
+  ///     static const char *getName() { return "Func1"; }
+  ///   };
+  ///
+  ///   ...
+  ///   bool B;
+  ///   int I;
+  ///   if (auto Err = expect<Func1>(Channel, readArgs(B, I)))
+  ///     /* Handle Args */ ;
+  ///
+  template <typename... ArgTs>
+  static detail::ReadArgs<ArgTs...> readArgs(ArgTs &... Args) {
+    return detail::ReadArgs<ArgTs...>(Args...);
+  }
+
+  /// Abandon all outstanding result handlers.
+  ///
+  /// This will call all currently registered result handlers to receive an
+  /// "abandoned" error as their argument. This is used internally by the RPC
+  /// in error situations, but can also be called directly by clients who are
+  /// disconnecting from the remote and don't or can't expect responses to their
+  /// outstanding calls. (Especially for outstanding blocking calls, calling
+  /// this function may be necessary to avoid dead threads).
+  void abandonPendingResponses() {
+    // Lock the pending responses map and sequence number manager.
+ std::lock_guard<std::mutex> Lock(ResponsesMutex); + + for (auto &KV : PendingResponses) + KV.second->abandon(); + PendingResponses.clear(); + SequenceNumberMgr.reset(); + } + + /// Remove the handler for the given function. + /// A handler must currently be registered for this function. + template <typename Func> + void removeHandler() { + auto IdItr = LocalFunctionIds.find(Func::getPrototype()); + assert(IdItr != LocalFunctionIds.end() && + "Function does not have a registered handler"); + auto HandlerItr = Handlers.find(IdItr->second); + assert(HandlerItr != Handlers.end() && + "Function does not have a registered handler"); + Handlers.erase(HandlerItr); + } + + /// Clear all handlers. + void clearHandlers() { + Handlers.clear(); + } + +protected: + + FunctionIdT getInvalidFunctionId() const { + return FnIdAllocator.getInvalidId(); + } + + /// Add the given handler to the handler map and make it available for + /// autonegotiation and execution. + template <typename Func, typename HandlerT> + void addHandlerImpl(HandlerT Handler) { + + static_assert(detail::RPCArgTypeCheck< + CanDeserializeCheck, typename Func::Type, + typename detail::HandlerTraits<HandlerT>::Type>::value, + ""); + + FunctionIdT NewFnId = FnIdAllocator.template allocate<Func>(); + LocalFunctionIds[Func::getPrototype()] = NewFnId; + Handlers[NewFnId] = wrapHandler<Func>(std::move(Handler)); + } + + template <typename Func, typename HandlerT> + void addAsyncHandlerImpl(HandlerT Handler) { + + static_assert(detail::RPCArgTypeCheck< + CanDeserializeCheck, typename Func::Type, + typename detail::AsyncHandlerTraits< + typename detail::HandlerTraits<HandlerT>::Type + >::Type>::value, + ""); + + FunctionIdT NewFnId = FnIdAllocator.template allocate<Func>(); + LocalFunctionIds[Func::getPrototype()] = NewFnId; + Handlers[NewFnId] = wrapAsyncHandler<Func>(std::move(Handler)); + } + + Error handleResponse(SequenceNumberT SeqNo) { + using Handler = typename decltype(PendingResponses)::mapped_type; + Handler PRHandler; + + { + // Lock the pending responses map and sequence number manager. + std::unique_lock<std::mutex> Lock(ResponsesMutex); + auto I = PendingResponses.find(SeqNo); + + if (I != PendingResponses.end()) { + PRHandler = std::move(I->second); + PendingResponses.erase(I); + SequenceNumberMgr.releaseSequenceNumber(SeqNo); + } else { + // Unlock the pending results map to prevent recursive lock. + Lock.unlock(); + abandonPendingResponses(); + return make_error< + InvalidSequenceNumberForResponse<SequenceNumberT>>(SeqNo); + } + } + + assert(PRHandler && + "If we didn't find a response handler we should have bailed out"); + + if (auto Err = PRHandler->handleResponse(C)) { + abandonPendingResponses(); + return Err; + } + + return Error::success(); + } + + FunctionIdT handleNegotiate(const std::string &Name) { + auto I = LocalFunctionIds.find(Name); + if (I == LocalFunctionIds.end()) + return getInvalidFunctionId(); + return I->second; + } + + // Find the remote FunctionId for the given function. + template <typename Func> + Expected<FunctionIdT> getRemoteFunctionId(bool NegotiateIfNotInMap, + bool NegotiateIfInvalid) { + bool DoNegotiate; + + // Check if we already have a function id... + auto I = RemoteFunctionIds.find(Func::getPrototype()); + if (I != RemoteFunctionIds.end()) { + // If it's valid there's nothing left to do. 
+ if (I->second != getInvalidFunctionId()) + return I->second; + DoNegotiate = NegotiateIfInvalid; + } else + DoNegotiate = NegotiateIfNotInMap; + + // We don't have a function id for Func yet, but we're allowed to try to + // negotiate one. + if (DoNegotiate) { + auto &Impl = static_cast<ImplT &>(*this); + if (auto RemoteIdOrErr = + Impl.template callB<OrcRPCNegotiate>(Func::getPrototype())) { + RemoteFunctionIds[Func::getPrototype()] = *RemoteIdOrErr; + if (*RemoteIdOrErr == getInvalidFunctionId()) + return make_error<CouldNotNegotiate>(Func::getPrototype()); + return *RemoteIdOrErr; + } else + return RemoteIdOrErr.takeError(); + } + + // No key was available in the map and we weren't allowed to try to + // negotiate one, so return an unknown function error. + return make_error<CouldNotNegotiate>(Func::getPrototype()); + } + + using WrappedHandlerFn = std::function<Error(ChannelT &, SequenceNumberT)>; + + // Wrap the given user handler in the necessary argument-deserialization code, + // result-serialization code, and call to the launch policy (if present). + template <typename Func, typename HandlerT> + WrappedHandlerFn wrapHandler(HandlerT Handler) { + return [this, Handler](ChannelT &Channel, + SequenceNumberT SeqNo) mutable -> Error { + // Start by deserializing the arguments. + using ArgsTuple = + typename detail::FunctionArgsTuple< + typename detail::HandlerTraits<HandlerT>::Type>::Type; + auto Args = std::make_shared<ArgsTuple>(); + + if (auto Err = + detail::HandlerTraits<typename Func::Type>::deserializeArgs( + Channel, *Args)) + return Err; + + // GCC 4.7 and 4.8 incorrectly issue a -Wunused-but-set-variable warning + // for RPCArgs. Void cast RPCArgs to work around this for now. + // FIXME: Remove this workaround once we can assume a working GCC version. + (void)Args; + + // End receieve message, unlocking the channel for reading. + if (auto Err = Channel.endReceiveMessage()) + return Err; + + using HTraits = detail::HandlerTraits<HandlerT>; + using FuncReturn = typename Func::ReturnType; + return detail::respond<FuncReturn>(Channel, ResponseId, SeqNo, + HTraits::unpackAndRun(Handler, *Args)); + }; + } + + // Wrap the given user handler in the necessary argument-deserialization code, + // result-serialization code, and call to the launch policy (if present). + template <typename Func, typename HandlerT> + WrappedHandlerFn wrapAsyncHandler(HandlerT Handler) { + return [this, Handler](ChannelT &Channel, + SequenceNumberT SeqNo) mutable -> Error { + // Start by deserializing the arguments. + using AHTraits = detail::AsyncHandlerTraits< + typename detail::HandlerTraits<HandlerT>::Type>; + using ArgsTuple = + typename detail::FunctionArgsTuple<typename AHTraits::Type>::Type; + auto Args = std::make_shared<ArgsTuple>(); + + if (auto Err = + detail::HandlerTraits<typename Func::Type>::deserializeArgs( + Channel, *Args)) + return Err; + + // GCC 4.7 and 4.8 incorrectly issue a -Wunused-but-set-variable warning + // for RPCArgs. Void cast RPCArgs to work around this for now. + // FIXME: Remove this workaround once we can assume a working GCC version. + (void)Args; + + // End receieve message, unlocking the channel for reading. 
+ if (auto Err = Channel.endReceiveMessage()) + return Err; + + using HTraits = detail::HandlerTraits<HandlerT>; + using FuncReturn = typename Func::ReturnType; + auto Responder = + [this, SeqNo](typename AHTraits::ResultType RetVal) -> Error { + return detail::respond<FuncReturn>(C, ResponseId, SeqNo, + std::move(RetVal)); + }; + + return HTraits::unpackAndRunAsync(Handler, Responder, *Args); + }; + } + + ChannelT &C; + + bool LazyAutoNegotiation; + + RPCFunctionIdAllocator<FunctionIdT> FnIdAllocator; + + FunctionIdT ResponseId; + std::map<std::string, FunctionIdT> LocalFunctionIds; + std::map<const char *, FunctionIdT> RemoteFunctionIds; + + std::map<FunctionIdT, WrappedHandlerFn> Handlers; + + std::mutex ResponsesMutex; + detail::SequenceNumberManager<SequenceNumberT> SequenceNumberMgr; + std::map<SequenceNumberT, std::unique_ptr<detail::ResponseHandler<ChannelT>>> + PendingResponses; +}; + +} // end namespace detail + +template <typename ChannelT, typename FunctionIdT = uint32_t, + typename SequenceNumberT = uint32_t> +class MultiThreadedRPCEndpoint + : public detail::RPCEndpointBase< + MultiThreadedRPCEndpoint<ChannelT, FunctionIdT, SequenceNumberT>, + ChannelT, FunctionIdT, SequenceNumberT> { +private: + using BaseClass = + detail::RPCEndpointBase< + MultiThreadedRPCEndpoint<ChannelT, FunctionIdT, SequenceNumberT>, + ChannelT, FunctionIdT, SequenceNumberT>; + +public: + MultiThreadedRPCEndpoint(ChannelT &C, bool LazyAutoNegotiation) + : BaseClass(C, LazyAutoNegotiation) {} + + /// Add a handler for the given RPC function. + /// This installs the given handler functor for the given RPC Function, and + /// makes the RPC function available for negotiation/calling from the remote. + template <typename Func, typename HandlerT> + void addHandler(HandlerT Handler) { + return this->template addHandlerImpl<Func>(std::move(Handler)); + } + + /// Add a class-method as a handler. + template <typename Func, typename ClassT, typename RetT, typename... ArgTs> + void addHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) { + addHandler<Func>( + detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method)); + } + + template <typename Func, typename HandlerT> + void addAsyncHandler(HandlerT Handler) { + return this->template addAsyncHandlerImpl<Func>(std::move(Handler)); + } + + /// Add a class-method as a handler. + template <typename Func, typename ClassT, typename RetT, typename... ArgTs> + void addAsyncHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) { + addAsyncHandler<Func>( + detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method)); + } + + /// Return type for non-blocking call primitives. + template <typename Func> + using NonBlockingCallResult = typename detail::ResultTraits< + typename Func::ReturnType>::ReturnFutureType; + + /// Call Func on Channel C. Does not block, does not call send. Returns a pair + /// of a future result and the sequence number assigned to the result. + /// + /// This utility function is primarily used for single-threaded mode support, + /// where the sequence number can be used to wait for the corresponding + /// result. In multi-threaded mode the appendCallNB method, which does not + /// return the sequence numeber, should be preferred. + template <typename Func, typename... ArgTs> + Expected<NonBlockingCallResult<Func>> appendCallNB(const ArgTs &... 
Args) { + using RTraits = detail::ResultTraits<typename Func::ReturnType>; + using ErrorReturn = typename RTraits::ErrorReturnType; + using ErrorReturnPromise = typename RTraits::ReturnPromiseType; + + ErrorReturnPromise Promise; + auto FutureResult = Promise.get_future(); + + if (auto Err = this->template appendCallAsync<Func>( + [Promise = std::move(Promise)](ErrorReturn RetOrErr) mutable { + Promise.set_value(std::move(RetOrErr)); + return Error::success(); + }, + Args...)) { + RTraits::consumeAbandoned(FutureResult.get()); + return std::move(Err); + } + return std::move(FutureResult); + } + + /// The same as appendCallNBWithSeq, except that it calls C.send() to + /// flush the channel after serializing the call. + template <typename Func, typename... ArgTs> + Expected<NonBlockingCallResult<Func>> callNB(const ArgTs &... Args) { + auto Result = appendCallNB<Func>(Args...); + if (!Result) + return Result; + if (auto Err = this->C.send()) { + this->abandonPendingResponses(); + detail::ResultTraits<typename Func::ReturnType>::consumeAbandoned( + std::move(Result->get())); + return std::move(Err); + } + return Result; + } + + /// Call Func on Channel C. Blocks waiting for a result. Returns an Error + /// for void functions or an Expected<T> for functions returning a T. + /// + /// This function is for use in threaded code where another thread is + /// handling responses and incoming calls. + template <typename Func, typename... ArgTs, + typename AltRetT = typename Func::ReturnType> + typename detail::ResultTraits<AltRetT>::ErrorReturnType + callB(const ArgTs &... Args) { + if (auto FutureResOrErr = callNB<Func>(Args...)) + return FutureResOrErr->get(); + else + return FutureResOrErr.takeError(); + } + + /// Handle incoming RPC calls. + Error handlerLoop() { + while (true) + if (auto Err = this->handleOne()) + return Err; + return Error::success(); + } +}; + +template <typename ChannelT, typename FunctionIdT = uint32_t, + typename SequenceNumberT = uint32_t> +class SingleThreadedRPCEndpoint + : public detail::RPCEndpointBase< + SingleThreadedRPCEndpoint<ChannelT, FunctionIdT, SequenceNumberT>, + ChannelT, FunctionIdT, SequenceNumberT> { +private: + using BaseClass = + detail::RPCEndpointBase< + SingleThreadedRPCEndpoint<ChannelT, FunctionIdT, SequenceNumberT>, + ChannelT, FunctionIdT, SequenceNumberT>; + +public: + SingleThreadedRPCEndpoint(ChannelT &C, bool LazyAutoNegotiation) + : BaseClass(C, LazyAutoNegotiation) {} + + template <typename Func, typename HandlerT> + void addHandler(HandlerT Handler) { + return this->template addHandlerImpl<Func>(std::move(Handler)); + } + + template <typename Func, typename ClassT, typename RetT, typename... ArgTs> + void addHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) { + addHandler<Func>( + detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method)); + } + + template <typename Func, typename HandlerT> + void addAsyncHandler(HandlerT Handler) { + return this->template addAsyncHandlerImpl<Func>(std::move(Handler)); + } + + /// Add a class-method as a handler. + template <typename Func, typename ClassT, typename RetT, typename... ArgTs> + void addAsyncHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) { + addAsyncHandler<Func>( + detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method)); + } + + template <typename Func, typename... ArgTs, + typename AltRetT = typename Func::ReturnType> + typename detail::ResultTraits<AltRetT>::ErrorReturnType + callB(const ArgTs &... 
Args) {
+    bool ReceivedResponse = false;
+    using ResultType = typename detail::ResultTraits<AltRetT>::ErrorReturnType;
+    auto Result = detail::ResultTraits<AltRetT>::createBlankErrorReturnValue();
+
+    // We have to 'Check' result (which we know is in a success state at this
+    // point) so that it can be overwritten in the async handler.
+    (void)!!Result;
+
+    if (auto Err = this->template appendCallAsync<Func>(
+            [&](ResultType R) {
+              Result = std::move(R);
+              ReceivedResponse = true;
+              return Error::success();
+            },
+            Args...)) {
+      detail::ResultTraits<typename Func::ReturnType>::consumeAbandoned(
+          std::move(Result));
+      return std::move(Err);
+    }
+
+    if (auto Err = this->C.send()) {
+      detail::ResultTraits<typename Func::ReturnType>::consumeAbandoned(
+          std::move(Result));
+      return std::move(Err);
+    }
+
+    while (!ReceivedResponse) {
+      if (auto Err = this->handleOne()) {
+        detail::ResultTraits<typename Func::ReturnType>::consumeAbandoned(
+            std::move(Result));
+        return std::move(Err);
+      }
+    }
+
+    return Result;
+  }
+};
+
+/// Asynchronous dispatch for a function on an RPC endpoint.
+template <typename RPCClass, typename Func>
+class RPCAsyncDispatch {
+public:
+  RPCAsyncDispatch(RPCClass &Endpoint) : Endpoint(Endpoint) {}
+
+  template <typename HandlerT, typename... ArgTs>
+  Error operator()(HandlerT Handler, const ArgTs &... Args) const {
+    return Endpoint.template appendCallAsync<Func>(std::move(Handler), Args...);
+  }
+
+private:
+  RPCClass &Endpoint;
+};
+
+/// Construct an asynchronous dispatcher from an RPC endpoint and a Func.
+template <typename Func, typename RPCEndpointT>
+RPCAsyncDispatch<RPCEndpointT, Func> rpcAsyncDispatch(RPCEndpointT &Endpoint) {
+  return RPCAsyncDispatch<RPCEndpointT, Func>(Endpoint);
+}
+
+/// Allows a set of asynchronous calls to be dispatched, and then
+/// waited on as a group.
+class ParallelCallGroup {
+public:
+
+  ParallelCallGroup() = default;
+  ParallelCallGroup(const ParallelCallGroup &) = delete;
+  ParallelCallGroup &operator=(const ParallelCallGroup &) = delete;
+
+  /// Make an asynchronous call.
+  template <typename AsyncDispatcher, typename HandlerT, typename... ArgTs>
+  Error call(const AsyncDispatcher &AsyncDispatch, HandlerT Handler,
+             const ArgTs &... Args) {
+    // Increment the count of outstanding calls. This has to happen before
+    // we invoke the call, as the handler may (depending on scheduling)
+    // be run immediately on another thread, and we don't want the decrement
+    // in the wrapped handler below to run before the increment.
+    {
+      std::unique_lock<std::mutex> Lock(M);
+      ++NumOutstandingCalls;
+    }
+
+    // Wrap the user handler in a lambda that will decrement the
+    // outstanding calls count, then poke the condition variable.
+    using ArgType = typename detail::ResponseHandlerArg<
+        typename detail::HandlerTraits<HandlerT>::Type>::ArgType;
+    auto WrappedHandler = [this, Handler = std::move(Handler)](ArgType Arg) {
+      auto Err = Handler(std::move(Arg));
+      std::unique_lock<std::mutex> Lock(M);
+      --NumOutstandingCalls;
+      CV.notify_all();
+      return Err;
+    };
+
+    return AsyncDispatch(std::move(WrappedHandler), Args...);
+  }
+
+  /// Blocks until all calls have been completed and their return value
+  /// handlers run.
+  void wait() {
+    std::unique_lock<std::mutex> Lock(M);
+    while (NumOutstandingCalls > 0)
+      CV.wait(Lock);
+  }
+
+private:
+  std::mutex M;
+  std::condition_variable CV;
+  uint32_t NumOutstandingCalls = 0;
+};
+
+/// Convenience class for grouping RPC Functions into APIs that can be
+/// negotiated as a block.
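+/// For example (an illustrative sketch, not part of this header: Add and Mul
+/// are hypothetical RPC functions declared elsewhere as rpc::Function
+/// subclasses, and Endpoint is an already-connected RPC endpoint):
+/// \code
+///   using MathAPI = APICalls<Add, Mul>;
+///   static_assert(MathAPI::Contains<Add>::value, "Add is part of MathAPI");
+///   if (auto Err = MathAPI::negotiate(Endpoint)) // negotiates Add, then Mul
+///     return Err;
+/// \endcode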
+/// +template <typename... Funcs> +class APICalls { +public: + + /// Test whether this API contains Function F. + template <typename F> + class Contains { + public: + static const bool value = false; + }; + + /// Negotiate all functions in this API. + template <typename RPCEndpoint> + static Error negotiate(RPCEndpoint &R) { + return Error::success(); + } +}; + +template <typename Func, typename... Funcs> +class APICalls<Func, Funcs...> { +public: + + template <typename F> + class Contains { + public: + static const bool value = std::is_same<F, Func>::value | + APICalls<Funcs...>::template Contains<F>::value; + }; + + template <typename RPCEndpoint> + static Error negotiate(RPCEndpoint &R) { + if (auto Err = R.template negotiateFunction<Func>()) + return Err; + return APICalls<Funcs...>::negotiate(R); + } + +}; + +template <typename... InnerFuncs, typename... Funcs> +class APICalls<APICalls<InnerFuncs...>, Funcs...> { +public: + + template <typename F> + class Contains { + public: + static const bool value = + APICalls<InnerFuncs...>::template Contains<F>::value | + APICalls<Funcs...>::template Contains<F>::value; + }; + + template <typename RPCEndpoint> + static Error negotiate(RPCEndpoint &R) { + if (auto Err = APICalls<InnerFuncs...>::negotiate(R)) + return Err; + return APICalls<Funcs...>::negotiate(R); + } + +}; + +} // end namespace rpc +} // end namespace orc +} // end namespace llvm + +#endif diff --git a/llvm/include/llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h b/llvm/include/llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h new file mode 100644 index 000000000000..c5106cf09ecc --- /dev/null +++ b/llvm/include/llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h @@ -0,0 +1,482 @@ +//===- RTDyldObjectLinkingLayer.h - RTDyld-based jit linking ---*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Contains the definition for an RTDyld-based, in-process object linking layer. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_EXECUTIONENGINE_ORC_RTDYLDOBJECTLINKINGLAYER_H +#define LLVM_EXECUTIONENGINE_ORC_RTDYLDOBJECTLINKINGLAYER_H + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringMap.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ExecutionEngine/JITSymbol.h" +#include "llvm/ExecutionEngine/Orc/Core.h" +#include "llvm/ExecutionEngine/Orc/Layer.h" +#include "llvm/ExecutionEngine/Orc/Legacy.h" +#include "llvm/ExecutionEngine/RuntimeDyld.h" +#include "llvm/Object/ObjectFile.h" +#include "llvm/Support/Error.h" +#include <algorithm> +#include <cassert> +#include <functional> +#include <list> +#include <memory> +#include <string> +#include <utility> +#include <vector> + +namespace llvm { +namespace orc { + +class RTDyldObjectLinkingLayer : public ObjectLayer { +public: + /// Functor for receiving object-loaded notifications. + using NotifyLoadedFunction = + std::function<void(VModuleKey, const object::ObjectFile &Obj, + const RuntimeDyld::LoadedObjectInfo &)>; + + /// Functor for receiving finalization notifications. 
+ using NotifyEmittedFunction = + std::function<void(VModuleKey, std::unique_ptr<MemoryBuffer>)>; + + using GetMemoryManagerFunction = + std::function<std::unique_ptr<RuntimeDyld::MemoryManager>()>; + + /// Construct an ObjectLinkingLayer with the given NotifyLoaded, + /// and NotifyEmitted functors. + RTDyldObjectLinkingLayer(ExecutionSession &ES, + GetMemoryManagerFunction GetMemoryManager); + + /// Emit the object. + void emit(MaterializationResponsibility R, + std::unique_ptr<MemoryBuffer> O) override; + + /// Set the NotifyLoaded callback. + RTDyldObjectLinkingLayer &setNotifyLoaded(NotifyLoadedFunction NotifyLoaded) { + this->NotifyLoaded = std::move(NotifyLoaded); + return *this; + } + + /// Set the NotifyEmitted callback. + RTDyldObjectLinkingLayer & + setNotifyEmitted(NotifyEmittedFunction NotifyEmitted) { + this->NotifyEmitted = std::move(NotifyEmitted); + return *this; + } + + /// Set the 'ProcessAllSections' flag. + /// + /// If set to true, all sections in each object file will be allocated using + /// the memory manager, rather than just the sections required for execution. + /// + /// This is kludgy, and may be removed in the future. + RTDyldObjectLinkingLayer &setProcessAllSections(bool ProcessAllSections) { + this->ProcessAllSections = ProcessAllSections; + return *this; + } + + /// Instructs this RTDyldLinkingLayer2 instance to override the symbol flags + /// returned by RuntimeDyld for any given object file with the flags supplied + /// by the MaterializationResponsibility instance. This is a workaround to + /// support symbol visibility in COFF, which does not use the libObject's + /// SF_Exported flag. Use only when generating / adding COFF object files. + /// + /// FIXME: We should be able to remove this if/when COFF properly tracks + /// exported symbols. + RTDyldObjectLinkingLayer & + setOverrideObjectFlagsWithResponsibilityFlags(bool OverrideObjectFlags) { + this->OverrideObjectFlags = OverrideObjectFlags; + return *this; + } + + /// If set, this RTDyldObjectLinkingLayer instance will claim responsibility + /// for any symbols provided by a given object file that were not already in + /// the MaterializationResponsibility instance. Setting this flag allows + /// higher-level program representations (e.g. LLVM IR) to be added based on + /// only a subset of the symbols they provide, without having to write + /// intervening layers to scan and add the additional symbols. This trades + /// diagnostic quality for convenience however: If all symbols are enumerated + /// up-front then clashes can be detected and reported early (and usually + /// deterministically). If this option is set, clashes for the additional + /// symbols may not be detected until late, and detection may depend on + /// the flow of control through JIT'd code. Use with care. 
+ RTDyldObjectLinkingLayer & + setAutoClaimResponsibilityForObjectSymbols(bool AutoClaimObjectSymbols) { + this->AutoClaimObjectSymbols = AutoClaimObjectSymbols; + return *this; + } + +private: + Error onObjLoad(VModuleKey K, MaterializationResponsibility &R, + object::ObjectFile &Obj, + std::unique_ptr<RuntimeDyld::LoadedObjectInfo> LoadedObjInfo, + std::map<StringRef, JITEvaluatedSymbol> Resolved, + std::set<StringRef> &InternalSymbols); + + void onObjEmit(VModuleKey K, std::unique_ptr<MemoryBuffer> ObjBuffer, + MaterializationResponsibility &R, Error Err); + + mutable std::mutex RTDyldLayerMutex; + GetMemoryManagerFunction GetMemoryManager; + NotifyLoadedFunction NotifyLoaded; + NotifyEmittedFunction NotifyEmitted; + bool ProcessAllSections = false; + bool OverrideObjectFlags = false; + bool AutoClaimObjectSymbols = false; + std::vector<std::unique_ptr<RuntimeDyld::MemoryManager>> MemMgrs; +}; + +class LegacyRTDyldObjectLinkingLayerBase { +public: + using ObjectPtr = std::unique_ptr<MemoryBuffer>; + +protected: + + /// Holds an object to be allocated/linked as a unit in the JIT. + /// + /// An instance of this class will be created for each object added + /// via JITObjectLayer::addObject. Deleting the instance (via + /// removeObject) frees its memory, removing all symbol definitions that + /// had been provided by this instance. Higher level layers are responsible + /// for taking any action required to handle the missing symbols. + class LinkedObject { + public: + LinkedObject() = default; + LinkedObject(const LinkedObject&) = delete; + void operator=(const LinkedObject&) = delete; + virtual ~LinkedObject() = default; + + virtual Error finalize() = 0; + + virtual JITSymbol::GetAddressFtor + getSymbolMaterializer(std::string Name) = 0; + + virtual void mapSectionAddress(const void *LocalAddress, + JITTargetAddress TargetAddr) const = 0; + + JITSymbol getSymbol(StringRef Name, bool ExportedSymbolsOnly) { + auto SymEntry = SymbolTable.find(Name); + if (SymEntry == SymbolTable.end()) + return nullptr; + if (!SymEntry->second.getFlags().isExported() && ExportedSymbolsOnly) + return nullptr; + if (!Finalized) + return JITSymbol(getSymbolMaterializer(Name), + SymEntry->second.getFlags()); + return JITSymbol(SymEntry->second); + } + + protected: + StringMap<JITEvaluatedSymbol> SymbolTable; + bool Finalized = false; + }; +}; + +/// Bare bones object linking layer. +/// +/// This class is intended to be used as the base layer for a JIT. It allows +/// object files to be loaded into memory, linked, and the addresses of their +/// symbols queried. All objects added to this layer can see each other's +/// symbols. +class LegacyRTDyldObjectLinkingLayer : public LegacyRTDyldObjectLinkingLayerBase { +public: + + using LegacyRTDyldObjectLinkingLayerBase::ObjectPtr; + + /// Functor for receiving object-loaded notifications. + using NotifyLoadedFtor = + std::function<void(VModuleKey, const object::ObjectFile &Obj, + const RuntimeDyld::LoadedObjectInfo &)>; + + /// Functor for receiving finalization notifications. + using NotifyFinalizedFtor = + std::function<void(VModuleKey, const object::ObjectFile &Obj, + const RuntimeDyld::LoadedObjectInfo &)>; + + /// Functor for receiving deallocation notifications. 
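+  /// A deallocation callback might, for example, drop client-side bookkeeping
+  /// keyed on the module key (an illustrative sketch; Bookkeeping is a
+  /// hypothetical client-side map):
+  /// \code
+  ///   [&Bookkeeping](VModuleKey K, const object::ObjectFile &Obj) {
+  ///     Bookkeeping.erase(K);
+  ///   }
+  /// \endcode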
+ using NotifyFreedFtor = std::function<void(VModuleKey, const object::ObjectFile &Obj)>; + +private: + using OwnedObject = object::OwningBinary<object::ObjectFile>; + + template <typename MemoryManagerPtrT> + class ConcreteLinkedObject : public LinkedObject { + public: + ConcreteLinkedObject(LegacyRTDyldObjectLinkingLayer &Parent, VModuleKey K, + OwnedObject Obj, MemoryManagerPtrT MemMgr, + std::shared_ptr<SymbolResolver> Resolver, + bool ProcessAllSections) + : K(std::move(K)), + Parent(Parent), + MemMgr(std::move(MemMgr)), + PFC(std::make_unique<PreFinalizeContents>( + std::move(Obj), std::move(Resolver), + ProcessAllSections)) { + buildInitialSymbolTable(PFC->Obj); + } + + ~ConcreteLinkedObject() override { + if (this->Parent.NotifyFreed && ObjForNotify.getBinary()) + this->Parent.NotifyFreed(K, *ObjForNotify.getBinary()); + + MemMgr->deregisterEHFrames(); + } + + Error finalize() override { + assert(PFC && "mapSectionAddress called on finalized LinkedObject"); + + JITSymbolResolverAdapter ResolverAdapter(Parent.ES, *PFC->Resolver, + nullptr); + PFC->RTDyld = std::make_unique<RuntimeDyld>(*MemMgr, ResolverAdapter); + PFC->RTDyld->setProcessAllSections(PFC->ProcessAllSections); + + Finalized = true; + + std::unique_ptr<RuntimeDyld::LoadedObjectInfo> Info = + PFC->RTDyld->loadObject(*PFC->Obj.getBinary()); + + // Copy the symbol table out of the RuntimeDyld instance. + { + auto SymTab = PFC->RTDyld->getSymbolTable(); + for (auto &KV : SymTab) + SymbolTable[KV.first] = KV.second; + } + + if (Parent.NotifyLoaded) + Parent.NotifyLoaded(K, *PFC->Obj.getBinary(), *Info); + + PFC->RTDyld->finalizeWithMemoryManagerLocking(); + + if (PFC->RTDyld->hasError()) + return make_error<StringError>(PFC->RTDyld->getErrorString(), + inconvertibleErrorCode()); + + if (Parent.NotifyFinalized) + Parent.NotifyFinalized(K, *PFC->Obj.getBinary(), *Info); + + // Release resources. + if (this->Parent.NotifyFreed) + ObjForNotify = std::move(PFC->Obj); // needed for callback + PFC = nullptr; + return Error::success(); + } + + JITSymbol::GetAddressFtor getSymbolMaterializer(std::string Name) override { + return [this, Name]() -> Expected<JITTargetAddress> { + // The symbol may be materialized between the creation of this lambda + // and its execution, so we need to double check. + if (!this->Finalized) + if (auto Err = this->finalize()) + return std::move(Err); + return this->getSymbol(Name, false).getAddress(); + }; + } + + void mapSectionAddress(const void *LocalAddress, + JITTargetAddress TargetAddr) const override { + assert(PFC && "mapSectionAddress called on finalized LinkedObject"); + assert(PFC->RTDyld && "mapSectionAddress called on raw LinkedObject"); + PFC->RTDyld->mapSectionAddress(LocalAddress, TargetAddr); + } + + private: + void buildInitialSymbolTable(const OwnedObject &Obj) { + for (auto &Symbol : Obj.getBinary()->symbols()) { + if (Symbol.getFlags() & object::SymbolRef::SF_Undefined) + continue; + Expected<StringRef> SymbolName = Symbol.getName(); + // FIXME: Raise an error for bad symbols. + if (!SymbolName) { + consumeError(SymbolName.takeError()); + continue; + } + // FIXME: Raise an error for bad symbols. + auto Flags = JITSymbolFlags::fromObjectSymbol(Symbol); + if (!Flags) { + consumeError(Flags.takeError()); + continue; + } + SymbolTable.insert( + std::make_pair(*SymbolName, JITEvaluatedSymbol(0, *Flags))); + } + } + + // Contains the information needed prior to finalization: the object files, + // memory manager, resolver, and flags needed for RuntimeDyld. 
+ struct PreFinalizeContents { + PreFinalizeContents(OwnedObject Obj, + std::shared_ptr<SymbolResolver> Resolver, + bool ProcessAllSections) + : Obj(std::move(Obj)), + Resolver(std::move(Resolver)), + ProcessAllSections(ProcessAllSections) {} + + OwnedObject Obj; + std::shared_ptr<SymbolResolver> Resolver; + bool ProcessAllSections; + std::unique_ptr<RuntimeDyld> RTDyld; + }; + + VModuleKey K; + LegacyRTDyldObjectLinkingLayer &Parent; + MemoryManagerPtrT MemMgr; + OwnedObject ObjForNotify; + std::unique_ptr<PreFinalizeContents> PFC; + }; + + template <typename MemoryManagerPtrT> + std::unique_ptr<ConcreteLinkedObject<MemoryManagerPtrT>> + createLinkedObject(LegacyRTDyldObjectLinkingLayer &Parent, VModuleKey K, + OwnedObject Obj, MemoryManagerPtrT MemMgr, + std::shared_ptr<SymbolResolver> Resolver, + bool ProcessAllSections) { + using LOS = ConcreteLinkedObject<MemoryManagerPtrT>; + return std::make_unique<LOS>(Parent, std::move(K), std::move(Obj), + std::move(MemMgr), std::move(Resolver), + ProcessAllSections); + } + +public: + struct Resources { + std::shared_ptr<RuntimeDyld::MemoryManager> MemMgr; + std::shared_ptr<SymbolResolver> Resolver; + }; + + using ResourcesGetter = std::function<Resources(VModuleKey)>; + + /// Construct an ObjectLinkingLayer with the given NotifyLoaded, + /// and NotifyFinalized functors. + LLVM_ATTRIBUTE_DEPRECATED( + LegacyRTDyldObjectLinkingLayer( + ExecutionSession &ES, ResourcesGetter GetResources, + NotifyLoadedFtor NotifyLoaded = NotifyLoadedFtor(), + NotifyFinalizedFtor NotifyFinalized = NotifyFinalizedFtor(), + NotifyFreedFtor NotifyFreed = NotifyFreedFtor()), + "ORCv1 layers (layers with the 'Legacy' prefix) are deprecated. Please " + "use " + "ORCv2 (see docs/ORCv2.rst)"); + + // Legacy layer constructor with deprecation acknowledgement. + LegacyRTDyldObjectLinkingLayer( + ORCv1DeprecationAcknowledgement, ExecutionSession &ES, + ResourcesGetter GetResources, + NotifyLoadedFtor NotifyLoaded = NotifyLoadedFtor(), + NotifyFinalizedFtor NotifyFinalized = NotifyFinalizedFtor(), + NotifyFreedFtor NotifyFreed = NotifyFreedFtor()) + : ES(ES), GetResources(std::move(GetResources)), + NotifyLoaded(std::move(NotifyLoaded)), + NotifyFinalized(std::move(NotifyFinalized)), + NotifyFreed(std::move(NotifyFreed)), ProcessAllSections(false) {} + + /// Set the 'ProcessAllSections' flag. + /// + /// If set to true, all sections in each object file will be allocated using + /// the memory manager, rather than just the sections required for execution. + /// + /// This is kludgy, and may be removed in the future. + void setProcessAllSections(bool ProcessAllSections) { + this->ProcessAllSections = ProcessAllSections; + } + + /// Add an object to the JIT. + Error addObject(VModuleKey K, ObjectPtr ObjBuffer) { + + auto Obj = + object::ObjectFile::createObjectFile(ObjBuffer->getMemBufferRef()); + if (!Obj) + return Obj.takeError(); + + assert(!LinkedObjects.count(K) && "VModuleKey already in use"); + + auto R = GetResources(K); + + LinkedObjects[K] = createLinkedObject( + *this, K, OwnedObject(std::move(*Obj), std::move(ObjBuffer)), + std::move(R.MemMgr), std::move(R.Resolver), ProcessAllSections); + + return Error::success(); + } + + /// Remove the object associated with VModuleKey K. + /// + /// All memory allocated for the object will be freed, and the sections and + /// symbols it provided will no longer be available. No attempt is made to + /// re-emit the missing symbols, and any use of these symbols (directly or + /// indirectly) will result in undefined behavior. 
If dependence tracking is + /// required to detect or resolve such issues it should be added at a higher + /// layer. + Error removeObject(VModuleKey K) { + assert(LinkedObjects.count(K) && "VModuleKey not associated with object"); + // How do we invalidate the symbols in H? + LinkedObjects.erase(K); + return Error::success(); + } + + /// Search for the given named symbol. + /// @param Name The name of the symbol to search for. + /// @param ExportedSymbolsOnly If true, search only for exported symbols. + /// @return A handle for the given named symbol, if it exists. + JITSymbol findSymbol(StringRef Name, bool ExportedSymbolsOnly) { + for (auto &KV : LinkedObjects) + if (auto Sym = KV.second->getSymbol(Name, ExportedSymbolsOnly)) + return Sym; + else if (auto Err = Sym.takeError()) + return std::move(Err); + + return nullptr; + } + + /// Search for the given named symbol in the context of the loaded + /// object represented by the VModuleKey K. + /// @param K The VModuleKey for the object to search in. + /// @param Name The name of the symbol to search for. + /// @param ExportedSymbolsOnly If true, search only for exported symbols. + /// @return A handle for the given named symbol, if it is found in the + /// given object. + JITSymbol findSymbolIn(VModuleKey K, StringRef Name, + bool ExportedSymbolsOnly) { + assert(LinkedObjects.count(K) && "VModuleKey not associated with object"); + return LinkedObjects[K]->getSymbol(Name, ExportedSymbolsOnly); + } + + /// Map section addresses for the object associated with the + /// VModuleKey K. + void mapSectionAddress(VModuleKey K, const void *LocalAddress, + JITTargetAddress TargetAddr) { + assert(LinkedObjects.count(K) && "VModuleKey not associated with object"); + LinkedObjects[K]->mapSectionAddress(LocalAddress, TargetAddr); + } + + /// Immediately emit and finalize the object represented by the given + /// VModuleKey. + /// @param K VModuleKey for object to emit/finalize. + Error emitAndFinalize(VModuleKey K) { + assert(LinkedObjects.count(K) && "VModuleKey not associated with object"); + return LinkedObjects[K]->finalize(); + } + +private: + ExecutionSession &ES; + + ResourcesGetter GetResources; + NotifyLoadedFtor NotifyLoaded; + NotifyFinalizedFtor NotifyFinalized; + NotifyFreedFtor NotifyFreed; + + // NB! `LinkedObjects` needs to be destroyed before `NotifyFreed` because + // `~ConcreteLinkedObject` calls `NotifyFreed` + std::map<VModuleKey, std::unique_ptr<LinkedObject>> LinkedObjects; + bool ProcessAllSections = false; +}; + +} // end namespace orc +} // end namespace llvm + +#endif // LLVM_EXECUTIONENGINE_ORC_RTDYLDOBJECTLINKINGLAYER_H diff --git a/llvm/include/llvm/ExecutionEngine/Orc/RawByteChannel.h b/llvm/include/llvm/ExecutionEngine/Orc/RawByteChannel.h new file mode 100644 index 000000000000..46b7c59450e6 --- /dev/null +++ b/llvm/include/llvm/ExecutionEngine/Orc/RawByteChannel.h @@ -0,0 +1,184 @@ +//===- llvm/ExecutionEngine/Orc/RawByteChannel.h ----------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_EXECUTIONENGINE_ORC_RAWBYTECHANNEL_H +#define LLVM_EXECUTIONENGINE_ORC_RAWBYTECHANNEL_H + +#include "llvm/ADT/StringRef.h" +#include "llvm/ExecutionEngine/Orc/RPCSerialization.h" +#include "llvm/Support/Endian.h" +#include "llvm/Support/Error.h" +#include <cstdint> +#include <mutex> +#include <string> +#include <type_traits> + +namespace llvm { +namespace orc { +namespace rpc { + +/// Interface for byte-streams to be used with RPC. +class RawByteChannel { +public: + virtual ~RawByteChannel() = default; + + /// Read Size bytes from the stream into *Dst. + virtual Error readBytes(char *Dst, unsigned Size) = 0; + + /// Read size bytes from *Src and append them to the stream. + virtual Error appendBytes(const char *Src, unsigned Size) = 0; + + /// Flush the stream if possible. + virtual Error send() = 0; + + /// Notify the channel that we're starting a message send. + /// Locks the channel for writing. + template <typename FunctionIdT, typename SequenceIdT> + Error startSendMessage(const FunctionIdT &FnId, const SequenceIdT &SeqNo) { + writeLock.lock(); + if (auto Err = serializeSeq(*this, FnId, SeqNo)) { + writeLock.unlock(); + return Err; + } + return Error::success(); + } + + /// Notify the channel that we're ending a message send. + /// Unlocks the channel for writing. + Error endSendMessage() { + writeLock.unlock(); + return Error::success(); + } + + /// Notify the channel that we're starting a message receive. + /// Locks the channel for reading. + template <typename FunctionIdT, typename SequenceNumberT> + Error startReceiveMessage(FunctionIdT &FnId, SequenceNumberT &SeqNo) { + readLock.lock(); + if (auto Err = deserializeSeq(*this, FnId, SeqNo)) { + readLock.unlock(); + return Err; + } + return Error::success(); + } + + /// Notify the channel that we're ending a message receive. + /// Unlocks the channel for reading. + Error endReceiveMessage() { + readLock.unlock(); + return Error::success(); + } + + /// Get the lock for stream reading. + std::mutex &getReadLock() { return readLock; } + + /// Get the lock for stream writing. + std::mutex &getWriteLock() { return writeLock; } + +private: + std::mutex readLock, writeLock; +}; + +template <typename ChannelT, typename T> +class SerializationTraits< + ChannelT, T, T, + typename std::enable_if< + std::is_base_of<RawByteChannel, ChannelT>::value && + (std::is_same<T, uint8_t>::value || std::is_same<T, int8_t>::value || + std::is_same<T, uint16_t>::value || std::is_same<T, int16_t>::value || + std::is_same<T, uint32_t>::value || std::is_same<T, int32_t>::value || + std::is_same<T, uint64_t>::value || std::is_same<T, int64_t>::value || + std::is_same<T, char>::value)>::type> { +public: + static Error serialize(ChannelT &C, T V) { + support::endian::byte_swap<T, support::big>(V); + return C.appendBytes(reinterpret_cast<const char *>(&V), sizeof(T)); + }; + + static Error deserialize(ChannelT &C, T &V) { + if (auto Err = C.readBytes(reinterpret_cast<char *>(&V), sizeof(T))) + return Err; + support::endian::byte_swap<T, support::big>(V); + return Error::success(); + }; +}; + +template <typename ChannelT> +class SerializationTraits<ChannelT, bool, bool, + typename std::enable_if<std::is_base_of< + RawByteChannel, ChannelT>::value>::type> { +public: + static Error serialize(ChannelT &C, bool V) { + uint8_t Tmp = V ? 
1 : 0; + if (auto Err = + C.appendBytes(reinterpret_cast<const char *>(&Tmp), 1)) + return Err; + return Error::success(); + } + + static Error deserialize(ChannelT &C, bool &V) { + uint8_t Tmp = 0; + if (auto Err = C.readBytes(reinterpret_cast<char *>(&Tmp), 1)) + return Err; + V = Tmp != 0; + return Error::success(); + } +}; + +template <typename ChannelT> +class SerializationTraits<ChannelT, std::string, StringRef, + typename std::enable_if<std::is_base_of< + RawByteChannel, ChannelT>::value>::type> { +public: + /// RPC channel serialization for std::strings. + static Error serialize(RawByteChannel &C, StringRef S) { + if (auto Err = serializeSeq(C, static_cast<uint64_t>(S.size()))) + return Err; + return C.appendBytes((const char *)S.data(), S.size()); + } +}; + +template <typename ChannelT, typename T> +class SerializationTraits<ChannelT, std::string, T, + typename std::enable_if< + std::is_base_of<RawByteChannel, ChannelT>::value && + (std::is_same<T, const char*>::value || + std::is_same<T, char*>::value)>::type> { +public: + static Error serialize(RawByteChannel &C, const char *S) { + return SerializationTraits<ChannelT, std::string, StringRef>::serialize(C, + S); + } +}; + +template <typename ChannelT> +class SerializationTraits<ChannelT, std::string, std::string, + typename std::enable_if<std::is_base_of< + RawByteChannel, ChannelT>::value>::type> { +public: + /// RPC channel serialization for std::strings. + static Error serialize(RawByteChannel &C, const std::string &S) { + return SerializationTraits<ChannelT, std::string, StringRef>::serialize(C, + S); + } + + /// RPC channel deserialization for std::strings. + static Error deserialize(RawByteChannel &C, std::string &S) { + uint64_t Count = 0; + if (auto Err = deserializeSeq(C, Count)) + return Err; + S.resize(Count); + return C.readBytes(&S[0], Count); + } +}; + +} // end namespace rpc +} // end namespace orc +} // end namespace llvm + +#endif // LLVM_EXECUTIONENGINE_ORC_RAWBYTECHANNEL_H diff --git a/llvm/include/llvm/ExecutionEngine/Orc/RemoteObjectLayer.h b/llvm/include/llvm/ExecutionEngine/Orc/RemoteObjectLayer.h new file mode 100644 index 000000000000..d7304cfcf931 --- /dev/null +++ b/llvm/include/llvm/ExecutionEngine/Orc/RemoteObjectLayer.h @@ -0,0 +1,564 @@ +//===------ RemoteObjectLayer.h - Forwards objs to a remote -----*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Forwards objects to a remote object layer via RPC. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_EXECUTIONENGINE_ORC_REMOTEOBJECTLAYER_H +#define LLVM_EXECUTIONENGINE_ORC_REMOTEOBJECTLAYER_H + +#include "llvm/ExecutionEngine/Orc/Core.h" +#include "llvm/ExecutionEngine/Orc/LambdaResolver.h" +#include "llvm/ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.h" +#include "llvm/Object/ObjectFile.h" +#include <map> + +namespace llvm { +namespace orc { + +/// RPC API needed by RemoteObjectClientLayer and RemoteObjectServerLayer. 
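+///
+/// In a typical (hypothetical) setup the two sides are wired up roughly as
+/// follows; the RPC endpoints, error reporter and base object layer are
+/// assumed to exist, and the channel/transport plumbing is elided:
+/// \code
+///   // Client process: forwards object-layer calls over ClientEP.
+///   RemoteObjectClientLayer<RPCEndpointT>
+///       Client(AcknowledgeORCv1Deprecation, ClientEP, ReportErr);
+///
+///   // Server process: services those calls using a real object layer.
+///   RemoteObjectServerLayer<decltype(ObjLayer), RPCEndpointT>
+///       Server(AcknowledgeORCv1Deprecation, ObjLayer, ServerEP, ReportErr);
+/// \endcode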
+class RemoteObjectLayerAPI { +public: + + using ObjHandleT = remote::ResourceIdMgr::ResourceId; + +protected: + + using RemoteSymbolId = remote::ResourceIdMgr::ResourceId; + using RemoteSymbol = std::pair<RemoteSymbolId, JITSymbolFlags>; + +public: + + using BadSymbolHandleError = remote::ResourceNotFound<RemoteSymbolId>; + using BadObjectHandleError = remote::ResourceNotFound<ObjHandleT>; + +protected: + + static const ObjHandleT InvalidObjectHandleId = 0; + static const RemoteSymbolId NullSymbolId = 0; + + class AddObject + : public rpc::Function<AddObject, Expected<ObjHandleT>(std::string)> { + public: + static const char *getName() { return "AddObject"; } + }; + + class RemoveObject + : public rpc::Function<RemoveObject, Error(ObjHandleT)> { + public: + static const char *getName() { return "RemoveObject"; } + }; + + class FindSymbol + : public rpc::Function<FindSymbol, Expected<RemoteSymbol>(std::string, + bool)> { + public: + static const char *getName() { return "FindSymbol"; } + }; + + class FindSymbolIn + : public rpc::Function<FindSymbolIn, + Expected<RemoteSymbol>(ObjHandleT, std::string, + bool)> { + public: + static const char *getName() { return "FindSymbolIn"; } + }; + + class EmitAndFinalize + : public rpc::Function<EmitAndFinalize, + Error(ObjHandleT)> { + public: + static const char *getName() { return "EmitAndFinalize"; } + }; + + class Lookup + : public rpc::Function<Lookup, + Expected<RemoteSymbol>(ObjHandleT, std::string)> { + public: + static const char *getName() { return "Lookup"; } + }; + + class LookupInLogicalDylib + : public rpc::Function<LookupInLogicalDylib, + Expected<RemoteSymbol>(ObjHandleT, std::string)> { + public: + static const char *getName() { return "LookupInLogicalDylib"; } + }; + + class ReleaseRemoteSymbol + : public rpc::Function<ReleaseRemoteSymbol, Error(RemoteSymbolId)> { + public: + static const char *getName() { return "ReleaseRemoteSymbol"; } + }; + + class MaterializeRemoteSymbol + : public rpc::Function<MaterializeRemoteSymbol, + Expected<JITTargetAddress>(RemoteSymbolId)> { + public: + static const char *getName() { return "MaterializeRemoteSymbol"; } + }; +}; + +/// Base class containing common utilities for RemoteObjectClientLayer and +/// RemoteObjectServerLayer. +template <typename RPCEndpoint> +class RemoteObjectLayer : public RemoteObjectLayerAPI { +public: + + RemoteObjectLayer(RPCEndpoint &Remote, + std::function<void(Error)> ReportError) + : Remote(Remote), ReportError(std::move(ReportError)), + SymbolIdMgr(NullSymbolId + 1) { + using ThisT = RemoteObjectLayer<RPCEndpoint>; + Remote.template addHandler<ReleaseRemoteSymbol>( + *this, &ThisT::handleReleaseRemoteSymbol); + Remote.template addHandler<MaterializeRemoteSymbol>( + *this, &ThisT::handleMaterializeRemoteSymbol); + } + +protected: + + /// This class is used as the symbol materializer for JITSymbols returned by + /// RemoteObjectLayerClient/RemoteObjectLayerServer -- the materializer knows + /// how to call back to the other RPC endpoint to get the address when + /// requested. + class RemoteSymbolMaterializer { + public: + + /// Construct a RemoteSymbolMaterializer for the given RemoteObjectLayer + /// with the given Id. + RemoteSymbolMaterializer(RemoteObjectLayer &C, + RemoteSymbolId Id) + : C(C), Id(Id) {} + + RemoteSymbolMaterializer(RemoteSymbolMaterializer &&Other) + : C(Other.C), Id(Other.Id) { + Other.Id = 0; + } + + RemoteSymbolMaterializer &operator=(RemoteSymbolMaterializer &&) = delete; + + /// Release the remote symbol. 
+ ~RemoteSymbolMaterializer() { + if (Id) + C.releaseRemoteSymbol(Id); + } + + /// Materialize the symbol on the remote and get its address. + Expected<JITTargetAddress> materialize() { + auto Addr = C.materializeRemoteSymbol(Id); + Id = 0; + return Addr; + } + + private: + RemoteObjectLayer &C; + RemoteSymbolId Id; + }; + + /// Convenience function for getting a null remote symbol value. + RemoteSymbol nullRemoteSymbol() { + return RemoteSymbol(0, JITSymbolFlags()); + } + + /// Creates a StringError that contains a copy of Err's log message, then + /// sends that StringError to ReportError. + /// + /// This allows us to locally log error messages for errors that will actually + /// be delivered to the remote. + Error teeLog(Error Err) { + return handleErrors(std::move(Err), + [this](std::unique_ptr<ErrorInfoBase> EIB) { + ReportError(make_error<StringError>( + EIB->message(), + EIB->convertToErrorCode())); + return Error(std::move(EIB)); + }); + } + + Error badRemoteSymbolIdError(RemoteSymbolId Id) { + return make_error<BadSymbolHandleError>(Id, "Remote JIT Symbol"); + } + + Error badObjectHandleError(ObjHandleT H) { + return make_error<RemoteObjectLayerAPI::BadObjectHandleError>( + H, "Bad object handle"); + } + + /// Create a RemoteSymbol wrapping the given JITSymbol. + Expected<RemoteSymbol> jitSymbolToRemote(JITSymbol Sym) { + if (Sym) { + auto Id = SymbolIdMgr.getNext(); + auto Flags = Sym.getFlags(); + assert(!InUseSymbols.count(Id) && "Symbol id already in use"); + InUseSymbols.insert(std::make_pair(Id, std::move(Sym))); + return RemoteSymbol(Id, Flags); + } else if (auto Err = Sym.takeError()) + return teeLog(std::move(Err)); + // else... + return nullRemoteSymbol(); + } + + /// Convert an Expected<RemoteSymbol> to a JITSymbol. + JITSymbol remoteToJITSymbol(Expected<RemoteSymbol> RemoteSymOrErr) { + if (RemoteSymOrErr) { + auto &RemoteSym = *RemoteSymOrErr; + if (RemoteSym == nullRemoteSymbol()) + return nullptr; + // else... + RemoteSymbolMaterializer RSM(*this, RemoteSym.first); + auto Sym = JITSymbol( + [RSM = std::move(RSM)]() mutable { return RSM.materialize(); }, + RemoteSym.second); + return Sym; + } else + return RemoteSymOrErr.takeError(); + } + + RPCEndpoint &Remote; + std::function<void(Error)> ReportError; + +private: + + /// Notify the remote to release the given JITSymbol. + void releaseRemoteSymbol(RemoteSymbolId Id) { + if (auto Err = Remote.template callB<ReleaseRemoteSymbol>(Id)) + ReportError(std::move(Err)); + } + + /// Notify the remote to materialize the JITSymbol with the given Id and + /// return its address. + Expected<JITTargetAddress> materializeRemoteSymbol(RemoteSymbolId Id) { + return Remote.template callB<MaterializeRemoteSymbol>(Id); + } + + /// Release the JITSymbol with the given Id. + Error handleReleaseRemoteSymbol(RemoteSymbolId Id) { + auto SI = InUseSymbols.find(Id); + if (SI != InUseSymbols.end()) { + InUseSymbols.erase(SI); + return Error::success(); + } else + return teeLog(badRemoteSymbolIdError(Id)); + } + + /// Run the materializer for the JITSymbol with the given Id and return its + /// address. 
+  Expected<JITTargetAddress> handleMaterializeRemoteSymbol(RemoteSymbolId Id) {
+    auto SI = InUseSymbols.find(Id);
+    if (SI != InUseSymbols.end()) {
+      auto AddrOrErr = SI->second.getAddress();
+      InUseSymbols.erase(SI);
+      SymbolIdMgr.release(Id);
+      if (AddrOrErr)
+        return *AddrOrErr;
+      else
+        return teeLog(AddrOrErr.takeError());
+    } else {
+      return teeLog(badRemoteSymbolIdError(Id));
+    }
+  }
+
+  remote::ResourceIdMgr SymbolIdMgr;
+  std::map<RemoteSymbolId, JITSymbol> InUseSymbols;
+};
+
+/// RemoteObjectClientLayer forwards the ORC Object Layer API over an RPC
+/// connection.
+///
+/// This class can be used as the base layer of a JIT stack on the client and
+/// will forward operations to a corresponding RemoteObjectServerLayer on the
+/// server (which can be composed on top of a "real" object layer like
+/// RTDyldObjectLinkingLayer to actually carry out the operations).
+///
+/// Sending relocatable objects to the server (rather than fully relocated
+/// bits) allows JIT'd code to be cached on the server side and re-used in
+/// subsequent JIT sessions.
+template <typename RPCEndpoint>
+class RemoteObjectClientLayer : public RemoteObjectLayer<RPCEndpoint> {
+private:
+
+  using AddObject = RemoteObjectLayerAPI::AddObject;
+  using RemoveObject = RemoteObjectLayerAPI::RemoveObject;
+  using FindSymbol = RemoteObjectLayerAPI::FindSymbol;
+  using FindSymbolIn = RemoteObjectLayerAPI::FindSymbolIn;
+  using EmitAndFinalize = RemoteObjectLayerAPI::EmitAndFinalize;
+  using Lookup = RemoteObjectLayerAPI::Lookup;
+  using LookupInLogicalDylib = RemoteObjectLayerAPI::LookupInLogicalDylib;
+
+  using RemoteObjectLayer<RPCEndpoint>::teeLog;
+  using RemoteObjectLayer<RPCEndpoint>::badObjectHandleError;
+  using RemoteObjectLayer<RPCEndpoint>::remoteToJITSymbol;
+
+public:
+
+  using ObjHandleT = RemoteObjectLayerAPI::ObjHandleT;
+  using RemoteSymbol = RemoteObjectLayerAPI::RemoteSymbol;
+
+  using ObjectPtr = std::unique_ptr<MemoryBuffer>;
+
+  /// Create a RemoteObjectClientLayer that communicates with a
+  /// RemoteObjectServerLayer instance via the given RPCEndpoint.
+  ///
+  /// The ReportError functor can be used to locally log errors that are
+  /// intended to be sent to the remote.
+  LLVM_ATTRIBUTE_DEPRECATED(
+      RemoteObjectClientLayer(RPCEndpoint &Remote,
+                              std::function<void(Error)> ReportError),
+      "ORCv1 layers (including RemoteObjectClientLayer) are deprecated. Please "
+      "use "
+      "ORCv2 (see docs/ORCv2.rst)");
+
+  RemoteObjectClientLayer(ORCv1DeprecationAcknowledgement, RPCEndpoint &Remote,
+                          std::function<void(Error)> ReportError)
+      : RemoteObjectLayer<RPCEndpoint>(Remote, std::move(ReportError)) {
+    using ThisT = RemoteObjectClientLayer<RPCEndpoint>;
+    Remote.template addHandler<Lookup>(*this, &ThisT::lookup);
+    Remote.template addHandler<LookupInLogicalDylib>(
+        *this, &ThisT::lookupInLogicalDylib);
+  }
+
+  /// Add an object to the JIT.
+  ///
+  /// @return A handle that can be used to refer to the loaded object (for
+  ///         symbol searching, finalization, freeing memory, etc.).
+  Expected<ObjHandleT>
+  addObject(ObjectPtr ObjBuffer,
+            std::shared_ptr<LegacyJITSymbolResolver> Resolver) {
+    if (auto HandleOrErr =
+            this->Remote.template callB<AddObject>(ObjBuffer->getBuffer())) {
+      auto &Handle = *HandleOrErr;
+      // FIXME: Return an error for this:
+      assert(!Resolvers.count(Handle) && "Handle already in use?");
+      Resolvers[Handle] = std::move(Resolver);
+      return Handle;
+    } else
+      return HandleOrErr.takeError();
+  }
+
+  /// Remove the given object from the JIT.
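+  ///
+  /// A typical add/use/remove sequence might look like the following
+  /// (illustrative sketch; Client, Obj, Resolver and ReportErr are assumed to
+  /// exist):
+  /// \code
+  ///   if (auto H = Client.addObject(std::move(Obj), Resolver)) {
+  ///     JITSymbol Main = Client.findSymbolIn(*H, "main", true);
+  ///     // ... use Main ...
+  ///     if (auto Err = Client.removeObject(*H))
+  ///       ReportErr(std::move(Err));
+  ///   } else
+  ///     ReportErr(H.takeError());
+  /// \endcode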
+  Error removeObject(ObjHandleT H) {
+    return this->Remote.template callB<RemoveObject>(H);
+  }
+
+  /// Search for the given named symbol.
+  JITSymbol findSymbol(StringRef Name, bool ExportedSymbolsOnly) {
+    return remoteToJITSymbol(
+             this->Remote.template callB<FindSymbol>(Name,
+                                                     ExportedSymbolsOnly));
+  }
+
+  /// Search for the given named symbol within the given context.
+  JITSymbol findSymbolIn(ObjHandleT H, StringRef Name, bool ExportedSymbolsOnly) {
+    return remoteToJITSymbol(
+             this->Remote.template callB<FindSymbolIn>(H, Name,
+                                                       ExportedSymbolsOnly));
+  }
+
+  /// Immediately emit and finalize the object with the given handle.
+  Error emitAndFinalize(ObjHandleT H) {
+    return this->Remote.template callB<EmitAndFinalize>(H);
+  }
+
+private:
+
+  Expected<RemoteSymbol> lookup(ObjHandleT H, const std::string &Name) {
+    auto RI = Resolvers.find(H);
+    if (RI != Resolvers.end()) {
+      return this->jitSymbolToRemote(RI->second->findSymbol(Name));
+    } else
+      return teeLog(badObjectHandleError(H));
+  }
+
+  Expected<RemoteSymbol> lookupInLogicalDylib(ObjHandleT H,
+                                              const std::string &Name) {
+    auto RI = Resolvers.find(H);
+    if (RI != Resolvers.end())
+      return this->jitSymbolToRemote(
+               RI->second->findSymbolInLogicalDylib(Name));
+    else
+      return teeLog(badObjectHandleError(H));
+  }
+
+  std::map<remote::ResourceIdMgr::ResourceId,
+           std::shared_ptr<LegacyJITSymbolResolver>>
+      Resolvers;
+};
+
+/// RemoteObjectServerLayer acts as a server, handling RPC calls for the
+/// object layer API from the given RPC connection.
+///
+/// This class can be composed on top of a 'real' object layer (e.g.
+/// RTDyldObjectLinkingLayer) to do the actual work of relocating objects
+/// and making them executable.
+template <typename BaseLayerT, typename RPCEndpoint>
+class RemoteObjectServerLayer : public RemoteObjectLayer<RPCEndpoint> {
+private:
+
+  using ObjHandleT = RemoteObjectLayerAPI::ObjHandleT;
+  using RemoteSymbol = RemoteObjectLayerAPI::RemoteSymbol;
+
+  using AddObject = RemoteObjectLayerAPI::AddObject;
+  using RemoveObject = RemoteObjectLayerAPI::RemoveObject;
+  using FindSymbol = RemoteObjectLayerAPI::FindSymbol;
+  using FindSymbolIn = RemoteObjectLayerAPI::FindSymbolIn;
+  using EmitAndFinalize = RemoteObjectLayerAPI::EmitAndFinalize;
+  using Lookup = RemoteObjectLayerAPI::Lookup;
+  using LookupInLogicalDylib = RemoteObjectLayerAPI::LookupInLogicalDylib;
+
+  using RemoteObjectLayer<RPCEndpoint>::teeLog;
+  using RemoteObjectLayer<RPCEndpoint>::badObjectHandleError;
+  using RemoteObjectLayer<RPCEndpoint>::remoteToJITSymbol;
+
+public:
+
+  /// Create a RemoteObjectServerLayer with the given base layer (which must be
+  /// an object layer), RPC endpoint, and error reporter function.
+  LLVM_ATTRIBUTE_DEPRECATED(
+      RemoteObjectServerLayer(BaseLayerT &BaseLayer, RPCEndpoint &Remote,
+                              std::function<void(Error)> ReportError),
+      "ORCv1 layers (including RemoteObjectServerLayer) are deprecated. 
Please " + "use " + "ORCv2 (see docs/ORCv2.rst)"); + + RemoteObjectServerLayer(ORCv1DeprecationAcknowledgement, + BaseLayerT &BaseLayer, RPCEndpoint &Remote, + std::function<void(Error)> ReportError) + : RemoteObjectLayer<RPCEndpoint>(Remote, std::move(ReportError)), + BaseLayer(BaseLayer), HandleIdMgr(1) { + using ThisT = RemoteObjectServerLayer<BaseLayerT, RPCEndpoint>; + + Remote.template addHandler<AddObject>(*this, &ThisT::addObject); + Remote.template addHandler<RemoveObject>(*this, &ThisT::removeObject); + Remote.template addHandler<FindSymbol>(*this, &ThisT::findSymbol); + Remote.template addHandler<FindSymbolIn>(*this, &ThisT::findSymbolIn); + Remote.template addHandler<EmitAndFinalize>(*this, &ThisT::emitAndFinalize); + } + +private: + + class StringMemoryBuffer : public MemoryBuffer { + public: + StringMemoryBuffer(std::string Buffer) + : Buffer(std::move(Buffer)) { + init(this->Buffer.data(), this->Buffer.data() + this->Buffer.size(), + false); + } + + BufferKind getBufferKind() const override { return MemoryBuffer_Malloc; } + private: + std::string Buffer; + }; + + JITSymbol lookup(ObjHandleT Id, const std::string &Name) { + return remoteToJITSymbol( + this->Remote.template callB<Lookup>(Id, Name)); + } + + JITSymbol lookupInLogicalDylib(ObjHandleT Id, const std::string &Name) { + return remoteToJITSymbol( + this->Remote.template callB<LookupInLogicalDylib>(Id, Name)); + } + + Expected<ObjHandleT> addObject(std::string ObjBuffer) { + auto Buffer = std::make_unique<StringMemoryBuffer>(std::move(ObjBuffer)); + auto Id = HandleIdMgr.getNext(); + assert(!BaseLayerHandles.count(Id) && "Id already in use?"); + + auto Resolver = createLambdaResolver( + AcknowledgeORCv1Deprecation, + [this, Id](const std::string &Name) { return lookup(Id, Name); }, + [this, Id](const std::string &Name) { + return lookupInLogicalDylib(Id, Name); + }); + + if (auto HandleOrErr = + BaseLayer.addObject(std::move(Buffer), std::move(Resolver))) { + BaseLayerHandles[Id] = std::move(*HandleOrErr); + return Id; + } else + return teeLog(HandleOrErr.takeError()); + } + + Error removeObject(ObjHandleT H) { + auto HI = BaseLayerHandles.find(H); + if (HI != BaseLayerHandles.end()) { + if (auto Err = BaseLayer.removeObject(HI->second)) + return teeLog(std::move(Err)); + return Error::success(); + } else + return teeLog(badObjectHandleError(H)); + } + + Expected<RemoteSymbol> findSymbol(const std::string &Name, + bool ExportedSymbolsOnly) { + if (auto Sym = BaseLayer.findSymbol(Name, ExportedSymbolsOnly)) + return this->jitSymbolToRemote(std::move(Sym)); + else if (auto Err = Sym.takeError()) + return teeLog(std::move(Err)); + return this->nullRemoteSymbol(); + } + + Expected<RemoteSymbol> findSymbolIn(ObjHandleT H, const std::string &Name, + bool ExportedSymbolsOnly) { + auto HI = BaseLayerHandles.find(H); + if (HI != BaseLayerHandles.end()) { + if (auto Sym = BaseLayer.findSymbolIn(HI->second, Name, ExportedSymbolsOnly)) + return this->jitSymbolToRemote(std::move(Sym)); + else if (auto Err = Sym.takeError()) + return teeLog(std::move(Err)); + return this->nullRemoteSymbol(); + } else + return teeLog(badObjectHandleError(H)); + } + + Error emitAndFinalize(ObjHandleT H) { + auto HI = BaseLayerHandles.find(H); + if (HI != BaseLayerHandles.end()) { + if (auto Err = BaseLayer.emitAndFinalize(HI->second)) + return teeLog(std::move(Err)); + return Error::success(); + } else + return teeLog(badObjectHandleError(H)); + } + + BaseLayerT &BaseLayer; + remote::ResourceIdMgr HandleIdMgr; + std::map<ObjHandleT, typename 
BaseLayerT::ObjHandleT> BaseLayerHandles; +}; + +template <typename RPCEndpoint> +RemoteObjectClientLayer<RPCEndpoint>::RemoteObjectClientLayer( + RPCEndpoint &Remote, std::function<void(Error)> ReportError) + : RemoteObjectLayer<RPCEndpoint>(Remote, std::move(ReportError)) { + using ThisT = RemoteObjectClientLayer<RPCEndpoint>; + Remote.template addHandler<Lookup>(*this, &ThisT::lookup); + Remote.template addHandler<LookupInLogicalDylib>( + *this, &ThisT::lookupInLogicalDylib); +} + +template <typename BaseLayerT, typename RPCEndpoint> +RemoteObjectServerLayer<BaseLayerT, RPCEndpoint>::RemoteObjectServerLayer( + BaseLayerT &BaseLayer, RPCEndpoint &Remote, + std::function<void(Error)> ReportError) + : RemoteObjectLayer<RPCEndpoint>(Remote, std::move(ReportError)), + BaseLayer(BaseLayer), HandleIdMgr(1) { + using ThisT = RemoteObjectServerLayer<BaseLayerT, RPCEndpoint>; + + Remote.template addHandler<AddObject>(*this, &ThisT::addObject); + Remote.template addHandler<RemoveObject>(*this, &ThisT::removeObject); + Remote.template addHandler<FindSymbol>(*this, &ThisT::findSymbol); + Remote.template addHandler<FindSymbolIn>(*this, &ThisT::findSymbolIn); + Remote.template addHandler<EmitAndFinalize>(*this, &ThisT::emitAndFinalize); +} + +} // end namespace orc +} // end namespace llvm + +#endif // LLVM_EXECUTIONENGINE_ORC_REMOTEOBJECTLAYER_H diff --git a/llvm/include/llvm/ExecutionEngine/Orc/SpeculateAnalyses.h b/llvm/include/llvm/ExecutionEngine/Orc/SpeculateAnalyses.h new file mode 100644 index 000000000000..cf57b63b6448 --- /dev/null +++ b/llvm/include/llvm/ExecutionEngine/Orc/SpeculateAnalyses.h @@ -0,0 +1,84 @@ +//===-- SpeculateAnalyses.h --*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// \file +/// Contains the Analyses and Result Interpretation to select likely functions +/// to Speculatively compile before they are called. [Purely Experimentation] +//===----------------------------------------------------------------------===// + +#ifndef LLVM_EXECUTIONENGINE_ORC_SPECULATEANALYSES_H +#define LLVM_EXECUTIONENGINE_ORC_SPECULATEANALYSES_H + +#include "llvm/Analysis/BranchProbabilityInfo.h" +#include "llvm/ExecutionEngine/Orc/Core.h" +#include "llvm/ExecutionEngine/Orc/Speculation.h" + +#include <vector> + +namespace llvm { + +namespace orc { + +// Provides common code. +class SpeculateQuery { +protected: + void findCalles(const BasicBlock *, DenseSet<StringRef> &); + bool isStraightLine(const Function &F); + +public: + using ResultTy = Optional<DenseMap<StringRef, DenseSet<StringRef>>>; +}; + +// Direct calls in high frequency basic blocks are extracted. +class BlockFreqQuery : public SpeculateQuery { + size_t numBBToGet(size_t); + +public: + // Find likely next executables based on IR Block Frequency + ResultTy operator()(Function &F); +}; + +// This Query generates a sequence of basic blocks which follows the order of +// execution. +// A handful of BB with higher block frequencies are taken, then path to entry +// and end BB are discovered by traversing up & down the CFG. 
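+//
+// A query object is callable on a Function and is typically installed as the
+// ResultEval of an IRSpeculationLayer (an illustrative sketch; the layer's
+// other constructor arguments are assumed to already exist):
+//
+//   IRSpeculationLayer SpecLayer(ES, CompileLayer, S, Mangle,
+//                                SequenceBBQuery());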
+class SequenceBBQuery : public SpeculateQuery { + struct WalkDirection { + bool Upward = true, Downward = true; + // the block associated contain a call + bool CallerBlock = false; + }; + +public: + using VisitedBlocksInfoTy = DenseMap<const BasicBlock *, WalkDirection>; + using BlockListTy = SmallVector<const BasicBlock *, 8>; + using BackEdgesInfoTy = + SmallVector<std::pair<const BasicBlock *, const BasicBlock *>, 8>; + using BlockFreqInfoTy = + SmallVector<std::pair<const BasicBlock *, uint64_t>, 8>; + +private: + std::size_t getHottestBlocks(std::size_t TotalBlocks); + BlockListTy rearrangeBB(const Function &, const BlockListTy &); + BlockListTy queryCFG(Function &, const BlockListTy &); + void traverseToEntryBlock(const BasicBlock *, const BlockListTy &, + const BackEdgesInfoTy &, + const BranchProbabilityInfo *, + VisitedBlocksInfoTy &); + void traverseToExitBlock(const BasicBlock *, const BlockListTy &, + const BackEdgesInfoTy &, + const BranchProbabilityInfo *, + VisitedBlocksInfoTy &); + +public: + ResultTy operator()(Function &F); +}; + +} // namespace orc +} // namespace llvm + +#endif // LLVM_EXECUTIONENGINE_ORC_SPECULATEANALYSES_H diff --git a/llvm/include/llvm/ExecutionEngine/Orc/Speculation.h b/llvm/include/llvm/ExecutionEngine/Orc/Speculation.h new file mode 100644 index 000000000000..766a6b070f12 --- /dev/null +++ b/llvm/include/llvm/ExecutionEngine/Orc/Speculation.h @@ -0,0 +1,207 @@ +//===-- Speculation.h - Speculative Compilation --*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Contains the definition to support speculative compilation when laziness is +// enabled. +//===----------------------------------------------------------------------===// + +#ifndef LLVM_EXECUTIONENGINE_ORC_SPECULATION_H +#define LLVM_EXECUTIONENGINE_ORC_SPECULATION_H + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/Optional.h" +#include "llvm/ExecutionEngine/Orc/Core.h" +#include "llvm/ExecutionEngine/Orc/IRCompileLayer.h" +#include "llvm/IR/PassManager.h" +#include "llvm/Passes/PassBuilder.h" +#include "llvm/Support/Debug.h" + +#include <mutex> +#include <type_traits> +#include <utility> +#include <vector> + +namespace llvm { +namespace orc { + +class Speculator; + +// Track the Impls (JITDylib,Symbols) of Symbols while lazy call through +// trampolines are created. 
Operations are guarded by locks tp ensure that Imap +// stays in consistent state after read/write + +class ImplSymbolMap { + friend class Speculator; + +public: + using AliaseeDetails = std::pair<SymbolStringPtr, JITDylib *>; + using Alias = SymbolStringPtr; + using ImapTy = DenseMap<Alias, AliaseeDetails>; + void trackImpls(SymbolAliasMap ImplMaps, JITDylib *SrcJD); + +private: + // FIX ME: find a right way to distinguish the pre-compile Symbols, and update + // the callsite + Optional<AliaseeDetails> getImplFor(const SymbolStringPtr &StubSymbol) { + std::lock_guard<std::mutex> Lockit(ConcurrentAccess); + auto Position = Maps.find(StubSymbol); + if (Position != Maps.end()) + return Position->getSecond(); + else + return None; + } + + std::mutex ConcurrentAccess; + ImapTy Maps; +}; + +// Defines Speculator Concept, +class Speculator { +public: + using TargetFAddr = JITTargetAddress; + using FunctionCandidatesMap = DenseMap<SymbolStringPtr, SymbolNameSet>; + using StubAddrLikelies = DenseMap<TargetFAddr, SymbolNameSet>; + +private: + void registerSymbolsWithAddr(TargetFAddr ImplAddr, + SymbolNameSet likelySymbols) { + std::lock_guard<std::mutex> Lockit(ConcurrentAccess); + GlobalSpecMap.insert({ImplAddr, std::move(likelySymbols)}); + } + + void launchCompile(JITTargetAddress FAddr) { + SymbolNameSet CandidateSet; + // Copy CandidateSet is necessary, to avoid unsynchronized access to + // the datastructure. + { + std::lock_guard<std::mutex> Lockit(ConcurrentAccess); + auto It = GlobalSpecMap.find(FAddr); + if (It == GlobalSpecMap.end()) + return; + CandidateSet = It->getSecond(); + } + + SymbolDependenceMap SpeculativeLookUpImpls; + + for (auto &Callee : CandidateSet) { + auto ImplSymbol = AliaseeImplTable.getImplFor(Callee); + // try to distinguish already compiled & library symbols + if (!ImplSymbol.hasValue()) + continue; + const auto &ImplSymbolName = ImplSymbol.getPointer()->first; + JITDylib *ImplJD = ImplSymbol.getPointer()->second; + auto &SymbolsInJD = SpeculativeLookUpImpls[ImplJD]; + SymbolsInJD.insert(ImplSymbolName); + } + + DEBUG_WITH_TYPE("orc", for (auto &I + : SpeculativeLookUpImpls) { + llvm::dbgs() << "\n In " << I.first->getName() << " JITDylib "; + for (auto &N : I.second) + llvm::dbgs() << "\n Likely Symbol : " << N; + }); + + // for a given symbol, there may be no symbol qualified for speculatively + // compile try to fix this before jumping to this code if possible. + for (auto &LookupPair : SpeculativeLookUpImpls) + ES.lookup(JITDylibSearchList({{LookupPair.first, true}}), + LookupPair.second, SymbolState::Ready, + [this](Expected<SymbolMap> Result) { + if (auto Err = Result.takeError()) + ES.reportError(std::move(Err)); + }, + NoDependenciesToRegister); + } + +public: + Speculator(ImplSymbolMap &Impl, ExecutionSession &ref) + : AliaseeImplTable(Impl), ES(ref), GlobalSpecMap(0) {} + Speculator(const Speculator &) = delete; + Speculator(Speculator &&) = delete; + Speculator &operator=(const Speculator &) = delete; + Speculator &operator=(Speculator &&) = delete; + + /// Define symbols for this Speculator object (__orc_speculator) and the + /// speculation runtime entry point symbol (__orc_speculate_for) in the + /// given JITDylib. + Error addSpeculationRuntime(JITDylib &JD, MangleAndInterner &Mangle); + + // Speculatively compile likely functions for the given Stub Address. + // destination of __orc_speculate_for jump + void speculateFor(TargetFAddr StubAddr) { launchCompile(StubAddr); } + + // FIXME : Register with Stub Address, after JITLink Fix. 
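How the pieces above are meant to be wired together: an ImplSymbolMap records which implementation symbol (and owning JITDylib) backs each lazy stub, a Speculator is constructed over it, and addSpeculationRuntime exposes __orc_speculator and __orc_speculate_for so that instrumented code can funnel back into speculateFor. A rough sketch under the assumption that the surrounding JIT stack already owns an ExecutionSession, a target JITDylib, and a MangleAndInterner (none of these are defined in this header):

    #include "llvm/ExecutionEngine/Orc/Speculation.h"
    #include <memory>

    using namespace llvm;
    using namespace llvm::orc;

    // Sketch: build a Speculator over an existing ImplSymbolMap and register the
    // speculation runtime symbols in JD so that JIT'd code can reach it.
    Expected<std::unique_ptr<Speculator>>
    makeSpeculator(ImplSymbolMap &SymMap, ExecutionSession &ES, JITDylib &JD,
                   MangleAndInterner &Mangle) {
      auto Spec = std::make_unique<Speculator>(SymMap, ES);
      if (Error Err = Spec->addSpeculationRuntime(JD, Mangle))
        return std::move(Err);
      // Later, __orc_speculate_for(StubAddr) ends up calling
      // Spec->speculateFor(StubAddr), which triggers the lookups shown above.
      return std::move(Spec);
    }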
+ void registerSymbols(FunctionCandidatesMap Candidates, JITDylib *JD) { + for (auto &SymPair : Candidates) { + auto Target = SymPair.first; + auto Likely = SymPair.second; + + auto OnReadyFixUp = [Likely, Target, + this](Expected<SymbolMap> ReadySymbol) { + if (ReadySymbol) { + auto RAddr = (*ReadySymbol)[Target].getAddress(); + registerSymbolsWithAddr(RAddr, std::move(Likely)); + } else + this->getES().reportError(ReadySymbol.takeError()); + }; + // Include non-exported symbols also. + ES.lookup(JITDylibSearchList({{JD, true}}), SymbolNameSet({Target}), + SymbolState::Ready, OnReadyFixUp, NoDependenciesToRegister); + } + } + + ExecutionSession &getES() { return ES; } + +private: + static void speculateForEntryPoint(Speculator *Ptr, uint64_t StubId); + std::mutex ConcurrentAccess; + ImplSymbolMap &AliaseeImplTable; + ExecutionSession &ES; + StubAddrLikelies GlobalSpecMap; +}; + +class IRSpeculationLayer : public IRLayer { +public: + using IRlikiesStrRef = Optional<DenseMap<StringRef, DenseSet<StringRef>>>; + using ResultEval = std::function<IRlikiesStrRef(Function &)>; + using TargetAndLikelies = DenseMap<SymbolStringPtr, SymbolNameSet>; + + IRSpeculationLayer(ExecutionSession &ES, IRCompileLayer &BaseLayer, + Speculator &Spec, MangleAndInterner &Mangle, + ResultEval Interpreter) + : IRLayer(ES), NextLayer(BaseLayer), S(Spec), Mangle(Mangle), + QueryAnalysis(Interpreter) {} + + void emit(MaterializationResponsibility R, ThreadSafeModule TSM); + +private: + TargetAndLikelies + internToJITSymbols(DenseMap<StringRef, DenseSet<StringRef>> IRNames) { + assert(!IRNames.empty() && "No IRNames received to Intern?"); + TargetAndLikelies InternedNames; + DenseSet<SymbolStringPtr> TargetJITNames; + for (auto &NamePair : IRNames) { + for (auto &TargetNames : NamePair.second) + TargetJITNames.insert(Mangle(TargetNames)); + + InternedNames[Mangle(NamePair.first)] = std::move(TargetJITNames); + } + return InternedNames; + } + + IRCompileLayer &NextLayer; + Speculator &S; + MangleAndInterner &Mangle; + ResultEval QueryAnalysis; +}; + +} // namespace orc +} // namespace llvm + +#endif // LLVM_EXECUTIONENGINE_ORC_SPECULATION_H diff --git a/llvm/include/llvm/ExecutionEngine/Orc/SymbolStringPool.h b/llvm/include/llvm/ExecutionEngine/Orc/SymbolStringPool.h new file mode 100644 index 000000000000..c354f6c3559c --- /dev/null +++ b/llvm/include/llvm/ExecutionEngine/Orc/SymbolStringPool.h @@ -0,0 +1,197 @@ +//===- SymbolStringPool.h - Multi-threaded pool for JIT symbols -*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Contains a multi-threaded string pool suitable for use with ORC. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_EXECUTIONENGINE_ORC_SYMBOLSTRINGPOOL_H +#define LLVM_EXECUTIONENGINE_ORC_SYMBOLSTRINGPOOL_H + +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/StringMap.h" +#include <atomic> +#include <mutex> + +namespace llvm { +namespace orc { + +class SymbolStringPtr; + +/// String pool for symbol names used by the JIT. +class SymbolStringPool { + friend class SymbolStringPtr; +public: + /// Destroy a SymbolStringPool. + ~SymbolStringPool(); + + /// Create a symbol string pointer from the given string. 
+ SymbolStringPtr intern(StringRef S); + + /// Remove from the pool any entries that are no longer referenced. + void clearDeadEntries(); + + /// Returns true if the pool is empty. + bool empty() const; +private: + using RefCountType = std::atomic<size_t>; + using PoolMap = StringMap<RefCountType>; + using PoolMapEntry = StringMapEntry<RefCountType>; + mutable std::mutex PoolMutex; + PoolMap Pool; +}; + +/// Pointer to a pooled string representing a symbol name. +class SymbolStringPtr { + friend class SymbolStringPool; + friend struct DenseMapInfo<SymbolStringPtr>; + +public: + SymbolStringPtr() = default; + SymbolStringPtr(const SymbolStringPtr &Other) + : S(Other.S) { + if (isRealPoolEntry(S)) + ++S->getValue(); + } + + SymbolStringPtr& operator=(const SymbolStringPtr &Other) { + if (isRealPoolEntry(S)) + --S->getValue(); + S = Other.S; + if (isRealPoolEntry(S)) + ++S->getValue(); + return *this; + } + + SymbolStringPtr(SymbolStringPtr &&Other) : S(nullptr) { + std::swap(S, Other.S); + } + + SymbolStringPtr& operator=(SymbolStringPtr &&Other) { + if (isRealPoolEntry(S)) + --S->getValue(); + S = nullptr; + std::swap(S, Other.S); + return *this; + } + + ~SymbolStringPtr() { + if (isRealPoolEntry(S)) + --S->getValue(); + } + + StringRef operator*() const { return S->first(); } + + friend bool operator==(const SymbolStringPtr &LHS, + const SymbolStringPtr &RHS) { + return LHS.S == RHS.S; + } + + friend bool operator!=(const SymbolStringPtr &LHS, + const SymbolStringPtr &RHS) { + return !(LHS == RHS); + } + + friend bool operator<(const SymbolStringPtr &LHS, + const SymbolStringPtr &RHS) { + return LHS.S < RHS.S; + } + +private: + using PoolEntryPtr = SymbolStringPool::PoolMapEntry *; + + SymbolStringPtr(SymbolStringPool::PoolMapEntry *S) + : S(S) { + if (isRealPoolEntry(S)) + ++S->getValue(); + } + + // Returns false for null, empty, and tombstone values, true otherwise. 
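A small, self-contained illustration of the pool semantics described here: interning the same string twice yields SymbolStringPtrs that compare equal, entries stay alive while any pointer references them, and clearDeadEntries only removes entries whose reference count has dropped to zero. (In an ORC stack the pool is normally owned by the ExecutionSession rather than created directly like this.)

    #include "llvm/ExecutionEngine/Orc/SymbolStringPool.h"
    #include <cassert>

    using namespace llvm::orc;

    void symbolStringPoolDemo() {
      SymbolStringPool Pool;
      {
        SymbolStringPtr Foo  = Pool.intern("foo");
        SymbolStringPtr Foo2 = Pool.intern("foo"); // same entry, refcount bumped
        assert(Foo == Foo2 && *Foo == "foo");
        Pool.clearDeadEntries();                   // no-op: "foo" is still referenced
        assert(!Pool.empty());
      }
      // The SymbolStringPtrs are gone, so the entry is dead and can be purged.
      Pool.clearDeadEntries();
      assert(Pool.empty());
    }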
+ bool isRealPoolEntry(PoolEntryPtr P) { + return ((reinterpret_cast<uintptr_t>(P) - 1) & InvalidPtrMask) != + InvalidPtrMask; + } + + static SymbolStringPtr getEmptyVal() { + return SymbolStringPtr(reinterpret_cast<PoolEntryPtr>(EmptyBitPattern)); + } + + static SymbolStringPtr getTombstoneVal() { + return SymbolStringPtr(reinterpret_cast<PoolEntryPtr>(TombstoneBitPattern)); + } + + constexpr static uintptr_t EmptyBitPattern = + std::numeric_limits<uintptr_t>::max() + << PointerLikeTypeTraits<PoolEntryPtr>::NumLowBitsAvailable; + + constexpr static uintptr_t TombstoneBitPattern = + (std::numeric_limits<uintptr_t>::max() - 1) + << PointerLikeTypeTraits<PoolEntryPtr>::NumLowBitsAvailable; + + constexpr static uintptr_t InvalidPtrMask = + (std::numeric_limits<uintptr_t>::max() - 3) + << PointerLikeTypeTraits<PoolEntryPtr>::NumLowBitsAvailable; + + PoolEntryPtr S = nullptr; +}; + +inline SymbolStringPool::~SymbolStringPool() { +#ifndef NDEBUG + clearDeadEntries(); + assert(Pool.empty() && "Dangling references at pool destruction time"); +#endif // NDEBUG +} + +inline SymbolStringPtr SymbolStringPool::intern(StringRef S) { + std::lock_guard<std::mutex> Lock(PoolMutex); + PoolMap::iterator I; + bool Added; + std::tie(I, Added) = Pool.try_emplace(S, 0); + return SymbolStringPtr(&*I); +} + +inline void SymbolStringPool::clearDeadEntries() { + std::lock_guard<std::mutex> Lock(PoolMutex); + for (auto I = Pool.begin(), E = Pool.end(); I != E;) { + auto Tmp = I++; + if (Tmp->second == 0) + Pool.erase(Tmp); + } +} + +inline bool SymbolStringPool::empty() const { + std::lock_guard<std::mutex> Lock(PoolMutex); + return Pool.empty(); +} + +} // end namespace orc + +template <> +struct DenseMapInfo<orc::SymbolStringPtr> { + + static orc::SymbolStringPtr getEmptyKey() { + return orc::SymbolStringPtr::getEmptyVal(); + } + + static orc::SymbolStringPtr getTombstoneKey() { + return orc::SymbolStringPtr::getTombstoneVal(); + } + + static unsigned getHashValue(const orc::SymbolStringPtr &V) { + return DenseMapInfo<orc::SymbolStringPtr::PoolEntryPtr>::getHashValue(V.S); + } + + static bool isEqual(const orc::SymbolStringPtr &LHS, + const orc::SymbolStringPtr &RHS) { + return LHS.S == RHS.S; + } +}; + +} // end namespace llvm + +#endif // LLVM_EXECUTIONENGINE_ORC_SYMBOLSTRINGPOOL_H diff --git a/llvm/include/llvm/ExecutionEngine/Orc/ThreadSafeModule.h b/llvm/include/llvm/ExecutionEngine/Orc/ThreadSafeModule.h new file mode 100644 index 000000000000..2347faed37a2 --- /dev/null +++ b/llvm/include/llvm/ExecutionEngine/Orc/ThreadSafeModule.h @@ -0,0 +1,175 @@ +//===----------- ThreadSafeModule.h -- Layer interfaces ---------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Thread safe wrappers and utilities for Module and LLVMContext. 
+// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_EXECUTIONENGINE_ORC_THREADSAFEMODULEWRAPPER_H +#define LLVM_EXECUTIONENGINE_ORC_THREADSAFEMODULEWRAPPER_H + +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Module.h" +#include "llvm/Support/Compiler.h" + +#include <functional> +#include <memory> +#include <mutex> + +namespace llvm { +namespace orc { + +/// An LLVMContext together with an associated mutex that can be used to lock +/// the context to prevent concurrent access by other threads. +class ThreadSafeContext { +private: + struct State { + State(std::unique_ptr<LLVMContext> Ctx) : Ctx(std::move(Ctx)) {} + + std::unique_ptr<LLVMContext> Ctx; + std::recursive_mutex Mutex; + }; + +public: + // RAII based lock for ThreadSafeContext. + class LLVM_NODISCARD Lock { + public: + Lock(std::shared_ptr<State> S) : S(std::move(S)), L(this->S->Mutex) {} + + private: + std::shared_ptr<State> S; + std::unique_lock<std::recursive_mutex> L; + }; + + /// Construct a null context. + ThreadSafeContext() = default; + + /// Construct a ThreadSafeContext from the given LLVMContext. + ThreadSafeContext(std::unique_ptr<LLVMContext> NewCtx) + : S(std::make_shared<State>(std::move(NewCtx))) { + assert(S->Ctx != nullptr && + "Can not construct a ThreadSafeContext from a nullptr"); + } + + /// Returns a pointer to the LLVMContext that was used to construct this + /// instance, or null if the instance was default constructed. + LLVMContext *getContext() { return S ? S->Ctx.get() : nullptr; } + + /// Returns a pointer to the LLVMContext that was used to construct this + /// instance, or null if the instance was default constructed. + const LLVMContext *getContext() const { return S ? S->Ctx.get() : nullptr; } + + Lock getLock() const { + assert(S && "Can not lock an empty ThreadSafeContext"); + return Lock(S); + } + +private: + std::shared_ptr<State> S; +}; + +/// An LLVM Module together with a shared ThreadSafeContext. +class ThreadSafeModule { +public: + /// Default construct a ThreadSafeModule. This results in a null module and + /// null context. + ThreadSafeModule() = default; + + ThreadSafeModule(ThreadSafeModule &&Other) = default; + + ThreadSafeModule &operator=(ThreadSafeModule &&Other) { + // We have to explicitly define this move operator to copy the fields in + // reverse order (i.e. module first) to ensure the dependencies are + // protected: The old module that is being overwritten must be destroyed + // *before* the context that it depends on. + // We also need to lock the context to make sure the module tear-down + // does not overlap any other work on the context. + if (M) { + auto L = TSCtx.getLock(); + M = nullptr; + } + M = std::move(Other.M); + TSCtx = std::move(Other.TSCtx); + return *this; + } + + /// Construct a ThreadSafeModule from a unique_ptr<Module> and a + /// unique_ptr<LLVMContext>. This creates a new ThreadSafeContext from the + /// given context. + ThreadSafeModule(std::unique_ptr<Module> M, std::unique_ptr<LLVMContext> Ctx) + : M(std::move(M)), TSCtx(std::move(Ctx)) {} + + /// Construct a ThreadSafeModule from a unique_ptr<Module> and an + /// existing ThreadSafeContext. + ThreadSafeModule(std::unique_ptr<Module> M, ThreadSafeContext TSCtx) + : M(std::move(M)), TSCtx(std::move(TSCtx)) {} + + ~ThreadSafeModule() { + // We need to lock the context while we destruct the module. 
+ if (M) { + auto L = TSCtx.getLock(); + M = nullptr; + } + } + + /// Boolean conversion: This ThreadSafeModule will evaluate to true if it + /// wraps a non-null module. + explicit operator bool() const { + if (M) { + assert(TSCtx.getContext() && + "Non-null module must have non-null context"); + return true; + } + return false; + } + + /// Locks the associated ThreadSafeContext and calls the given function + /// on the contained Module. + template <typename Func> + auto withModuleDo(Func &&F) -> decltype(F(std::declval<Module &>())) { + assert(M && "Can not call on null module"); + auto Lock = TSCtx.getLock(); + return F(*M); + } + + /// Locks the associated ThreadSafeContext and calls the given function + /// on the contained Module. + template <typename Func> + auto withModuleDo(Func &&F) const + -> decltype(F(std::declval<const Module &>())) { + auto Lock = TSCtx.getLock(); + return F(*M); + } + + /// Get a raw pointer to the contained module without locking the context. + Module *getModuleUnlocked() { return M.get(); } + + /// Get a raw pointer to the contained module without locking the context. + const Module *getModuleUnlocked() const { return M.get(); } + + /// Returns the context for this ThreadSafeModule. + ThreadSafeContext getContext() const { return TSCtx; } + +private: + std::unique_ptr<Module> M; + ThreadSafeContext TSCtx; +}; + +using GVPredicate = std::function<bool(const GlobalValue &)>; +using GVModifier = std::function<void(GlobalValue &)>; + +/// Clones the given module on to a new context. +ThreadSafeModule +cloneToNewContext(ThreadSafeModule &TSMW, + GVPredicate ShouldCloneDef = GVPredicate(), + GVModifier UpdateClonedDefSource = GVModifier()); + +} // End namespace orc +} // End namespace llvm + +#endif // LLVM_EXECUTIONENGINE_ORC_THREADSAFEMODULEWRAPPER_H diff --git a/llvm/include/llvm/ExecutionEngine/OrcMCJITReplacement.h b/llvm/include/llvm/ExecutionEngine/OrcMCJITReplacement.h new file mode 100644 index 000000000000..6cca1933f39f --- /dev/null +++ b/llvm/include/llvm/ExecutionEngine/OrcMCJITReplacement.h @@ -0,0 +1,37 @@ +//===---- OrcMCJITReplacement.h - Orc-based MCJIT replacement ---*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file forces OrcMCJITReplacement to link in on certain operating systems. +// (Windows). +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_EXECUTIONENGINE_ORCMCJITREPLACEMENT_H +#define LLVM_EXECUTIONENGINE_ORCMCJITREPLACEMENT_H + +#include "llvm/ExecutionEngine/ExecutionEngine.h" +#include <cstdlib> + +extern "C" void LLVMLinkInOrcMCJITReplacement(); + +namespace { + struct ForceOrcMCJITReplacementLinking { + ForceOrcMCJITReplacementLinking() { + // We must reference OrcMCJITReplacement in such a way that compilers will + // not delete it all as dead code, even with whole program optimization, + // yet is effectively a NO-OP. As the compiler isn't smart enough to know + // that getenv() never returns -1, this will do the job. 
+ if (std::getenv("bar") != (char*) -1) + return; + + LLVMLinkInOrcMCJITReplacement(); + } + } ForceOrcMCJITReplacementLinking; +} + +#endif diff --git a/llvm/include/llvm/ExecutionEngine/OrcV1Deprecation.h b/llvm/include/llvm/ExecutionEngine/OrcV1Deprecation.h new file mode 100644 index 000000000000..7ed254b3ee04 --- /dev/null +++ b/llvm/include/llvm/ExecutionEngine/OrcV1Deprecation.h @@ -0,0 +1,22 @@ +//===------ OrcV1Deprecation.h - Memory manager for MC-JIT ------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Tag for suppressing ORCv1 deprecation warnings. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_EXECUTIONENGINE_ORCV1DEPRECATION_H +#define LLVM_EXECUTIONENGINE_ORCV1DEPRECATION_H + +namespace llvm { + +enum ORCv1DeprecationAcknowledgement { AcknowledgeORCv1Deprecation }; + +} // namespace llvm + +#endif // LLVM_EXECUTIONENGINE_ORCV1DEPRECATION_H diff --git a/llvm/include/llvm/ExecutionEngine/RTDyldMemoryManager.h b/llvm/include/llvm/ExecutionEngine/RTDyldMemoryManager.h new file mode 100644 index 000000000000..c7c87ecdfa09 --- /dev/null +++ b/llvm/include/llvm/ExecutionEngine/RTDyldMemoryManager.h @@ -0,0 +1,158 @@ +//===-- RTDyldMemoryManager.cpp - Memory manager for MC-JIT -----*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Interface of the runtime dynamic memory manager base class. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_EXECUTIONENGINE_RTDYLDMEMORYMANAGER_H +#define LLVM_EXECUTIONENGINE_RTDYLDMEMORYMANAGER_H + +#include "llvm-c/ExecutionEngine.h" +#include "llvm/ExecutionEngine/JITSymbol.h" +#include "llvm/ExecutionEngine/RuntimeDyld.h" +#include "llvm/Support/CBindingWrapping.h" +#include <cstddef> +#include <cstdint> +#include <string> + +namespace llvm { + +class ExecutionEngine; + +namespace object { + class ObjectFile; +} // end namespace object + +class MCJITMemoryManager : public RuntimeDyld::MemoryManager { +public: + // Don't hide the notifyObjectLoaded method from RuntimeDyld::MemoryManager. + using RuntimeDyld::MemoryManager::notifyObjectLoaded; + + /// This method is called after an object has been loaded into memory but + /// before relocations are applied to the loaded sections. The object load + /// may have been initiated by MCJIT to resolve an external symbol for another + /// object that is being finalized. In that case, the object about which + /// the memory manager is being notified will be finalized immediately after + /// the memory manager returns from this call. + /// + /// Memory managers which are preparing code for execution in an external + /// address space can use this call to remap the section addresses for the + /// newly loaded object. + virtual void notifyObjectLoaded(ExecutionEngine *EE, + const object::ObjectFile &) {} + +private: + void anchor() override; +}; + +// RuntimeDyld clients often want to handle the memory management of +// what gets placed where. 
For JIT clients, this is the subset of +// JITMemoryManager required for dynamic loading of binaries. +// +// FIXME: As the RuntimeDyld fills out, additional routines will be needed +// for the varying types of objects to be allocated. +class RTDyldMemoryManager : public MCJITMemoryManager, + public LegacyJITSymbolResolver { +public: + RTDyldMemoryManager() = default; + RTDyldMemoryManager(const RTDyldMemoryManager&) = delete; + void operator=(const RTDyldMemoryManager&) = delete; + ~RTDyldMemoryManager() override; + + /// Register EH frames in the current process. + static void registerEHFramesInProcess(uint8_t *Addr, size_t Size); + + /// Deregister EH frames in the current proces. + static void deregisterEHFramesInProcess(uint8_t *Addr, size_t Size); + + void registerEHFrames(uint8_t *Addr, uint64_t LoadAddr, size_t Size) override; + void deregisterEHFrames() override; + + /// This method returns the address of the specified function or variable in + /// the current process. + static uint64_t getSymbolAddressInProcess(const std::string &Name); + + /// Legacy symbol lookup - DEPRECATED! Please override findSymbol instead. + /// + /// This method returns the address of the specified function or variable. + /// It is used to resolve symbols during module linking. + virtual uint64_t getSymbolAddress(const std::string &Name) { + return getSymbolAddressInProcess(Name); + } + + /// This method returns a RuntimeDyld::SymbolInfo for the specified function + /// or variable. It is used to resolve symbols during module linking. + /// + /// By default this falls back on the legacy lookup method: + /// 'getSymbolAddress'. The address returned by getSymbolAddress is treated as + /// a strong, exported symbol, consistent with historical treatment by + /// RuntimeDyld. + /// + /// Clients writing custom RTDyldMemoryManagers are encouraged to override + /// this method and return a SymbolInfo with the flags set correctly. This is + /// necessary for RuntimeDyld to correctly handle weak and non-exported symbols. + JITSymbol findSymbol(const std::string &Name) override { + return JITSymbol(getSymbolAddress(Name), JITSymbolFlags::Exported); + } + + /// Legacy symbol lookup -- DEPRECATED! Please override + /// findSymbolInLogicalDylib instead. + /// + /// Default to treating all modules as separate. + virtual uint64_t getSymbolAddressInLogicalDylib(const std::string &Name) { + return 0; + } + + /// Default to treating all modules as separate. + /// + /// By default this falls back on the legacy lookup method: + /// 'getSymbolAddressInLogicalDylib'. The address returned by + /// getSymbolAddressInLogicalDylib is treated as a strong, exported symbol, + /// consistent with historical treatment by RuntimeDyld. + /// + /// Clients writing custom RTDyldMemoryManagers are encouraged to override + /// this method and return a SymbolInfo with the flags set correctly. This is + /// necessary for RuntimeDyld to correctly handle weak and non-exported symbols. + JITSymbol + findSymbolInLogicalDylib(const std::string &Name) override { + return JITSymbol(getSymbolAddressInLogicalDylib(Name), + JITSymbolFlags::Exported); + } + + /// This method returns the address of the specified function. As such it is + /// only useful for resolving library symbols, not code generated symbols. + /// + /// If \p AbortOnFailure is false and no function with the given name is + /// found, this function returns a null pointer. Otherwise, it prints a + /// message to stderr and aborts. 
+ /// + /// This function is deprecated for memory managers to be used with + /// MCJIT or RuntimeDyld. Use getSymbolAddress instead. + virtual void *getPointerToNamedFunction(const std::string &Name, + bool AbortOnFailure = true); + +protected: + struct EHFrame { + uint8_t *Addr; + size_t Size; + }; + typedef std::vector<EHFrame> EHFrameInfos; + EHFrameInfos EHFrames; + +private: + void anchor() override; +}; + +// Create wrappers for C Binding types (see CBindingWrapping.h). +DEFINE_SIMPLE_CONVERSION_FUNCTIONS( + RTDyldMemoryManager, LLVMMCJITMemoryManagerRef) + +} // end namespace llvm + +#endif // LLVM_EXECUTIONENGINE_RTDYLDMEMORYMANAGER_H diff --git a/llvm/include/llvm/ExecutionEngine/RuntimeDyld.h b/llvm/include/llvm/ExecutionEngine/RuntimeDyld.h new file mode 100644 index 000000000000..ce7024a7f19b --- /dev/null +++ b/llvm/include/llvm/ExecutionEngine/RuntimeDyld.h @@ -0,0 +1,306 @@ +//===- RuntimeDyld.h - Run-time dynamic linker for MC-JIT -------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Interface for the runtime dynamic linker facilities of the MC-JIT. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_EXECUTIONENGINE_RUNTIMEDYLD_H +#define LLVM_EXECUTIONENGINE_RUNTIMEDYLD_H + +#include "llvm/ADT/FunctionExtras.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/DebugInfo/DIContext.h" +#include "llvm/ExecutionEngine/JITSymbol.h" +#include "llvm/Object/ObjectFile.h" +#include "llvm/Support/Error.h" +#include <algorithm> +#include <cassert> +#include <cstddef> +#include <cstdint> +#include <map> +#include <memory> +#include <string> +#include <system_error> + +namespace llvm { + +namespace object { + +template <typename T> class OwningBinary; + +} // end namespace object + +/// Base class for errors originating in RuntimeDyld, e.g. missing relocation +/// support. +class RuntimeDyldError : public ErrorInfo<RuntimeDyldError> { +public: + static char ID; + + RuntimeDyldError(std::string ErrMsg) : ErrMsg(std::move(ErrMsg)) {} + + void log(raw_ostream &OS) const override; + const std::string &getErrorMessage() const { return ErrMsg; } + std::error_code convertToErrorCode() const override; + +private: + std::string ErrMsg; +}; + +class RuntimeDyldImpl; + +class RuntimeDyld { +protected: + // Change the address associated with a section when resolving relocations. + // Any relocations already associated with the symbol will be re-resolved. + void reassignSectionAddress(unsigned SectionID, uint64_t Addr); + +public: + using NotifyStubEmittedFunction = std::function<void( + StringRef FileName, StringRef SectionName, StringRef SymbolName, + unsigned SectionID, uint32_t StubOffset)>; + + /// Information about the loaded object. 
+ class LoadedObjectInfo : public llvm::LoadedObjectInfo { + friend class RuntimeDyldImpl; + + public: + using ObjSectionToIDMap = std::map<object::SectionRef, unsigned>; + + LoadedObjectInfo(RuntimeDyldImpl &RTDyld, ObjSectionToIDMap ObjSecToIDMap) + : RTDyld(RTDyld), ObjSecToIDMap(std::move(ObjSecToIDMap)) {} + + virtual object::OwningBinary<object::ObjectFile> + getObjectForDebug(const object::ObjectFile &Obj) const = 0; + + uint64_t + getSectionLoadAddress(const object::SectionRef &Sec) const override; + + protected: + virtual void anchor(); + + RuntimeDyldImpl &RTDyld; + ObjSectionToIDMap ObjSecToIDMap; + }; + + /// Memory Management. + class MemoryManager { + friend class RuntimeDyld; + + public: + MemoryManager() = default; + virtual ~MemoryManager() = default; + + /// Allocate a memory block of (at least) the given size suitable for + /// executable code. The SectionID is a unique identifier assigned by the + /// RuntimeDyld instance, and optionally recorded by the memory manager to + /// access a loaded section. + virtual uint8_t *allocateCodeSection(uintptr_t Size, unsigned Alignment, + unsigned SectionID, + StringRef SectionName) = 0; + + /// Allocate a memory block of (at least) the given size suitable for data. + /// The SectionID is a unique identifier assigned by the JIT engine, and + /// optionally recorded by the memory manager to access a loaded section. + virtual uint8_t *allocateDataSection(uintptr_t Size, unsigned Alignment, + unsigned SectionID, + StringRef SectionName, + bool IsReadOnly) = 0; + + /// Inform the memory manager about the total amount of memory required to + /// allocate all sections to be loaded: + /// \p CodeSize - the total size of all code sections + /// \p DataSizeRO - the total size of all read-only data sections + /// \p DataSizeRW - the total size of all read-write data sections + /// + /// Note that by default the callback is disabled. To enable it + /// redefine the method needsToReserveAllocationSpace to return true. + virtual void reserveAllocationSpace(uintptr_t CodeSize, uint32_t CodeAlign, + uintptr_t RODataSize, + uint32_t RODataAlign, + uintptr_t RWDataSize, + uint32_t RWDataAlign) {} + + /// Override to return true to enable the reserveAllocationSpace callback. + virtual bool needsToReserveAllocationSpace() { return false; } + + /// Register the EH frames with the runtime so that c++ exceptions work. + /// + /// \p Addr parameter provides the local address of the EH frame section + /// data, while \p LoadAddr provides the address of the data in the target + /// address space. If the section has not been remapped (which will usually + /// be the case for local execution) these two values will be the same. + virtual void registerEHFrames(uint8_t *Addr, uint64_t LoadAddr, + size_t Size) = 0; + virtual void deregisterEHFrames() = 0; + + /// This method is called when object loading is complete and section page + /// permissions can be applied. It is up to the memory manager implementation + /// to decide whether or not to act on this method. The memory manager will + /// typically allocate all sections as read-write and then apply specific + /// permissions when this method is called. Code sections cannot be executed + /// until this function has been called. In addition, any cache coherency + /// operations needed to reliably use the memory are also performed. + /// + /// Returns true if an error occurred, false otherwise. 
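As the comment above notes, the reserveAllocationSpace callback is disabled unless a memory manager opts in via needsToReserveAllocationSpace. A minimal sketch of opting in, layered on SectionMemoryManager (declared later in this patch) purely to show the pair of overrides; the logging is illustrative only:

    #include "llvm/ExecutionEngine/SectionMemoryManager.h"
    #include "llvm/Support/raw_ostream.h"
    #include <cstdint>

    namespace {
    // Receives one up-front size/alignment estimate per object before the
    // individual allocateCodeSection/allocateDataSection calls are made.
    class ReservingMemoryManager : public llvm::SectionMemoryManager {
    public:
      bool needsToReserveAllocationSpace() override { return true; }

      void reserveAllocationSpace(uintptr_t CodeSize, uint32_t CodeAlign,
                                  uintptr_t RODataSize, uint32_t RODataAlign,
                                  uintptr_t RWDataSize,
                                  uint32_t RWDataAlign) override {
        // A real implementation would reserve one contiguous region here so all
        // sections of the object land close together.
        llvm::errs() << "reserve: code=" << CodeSize << " ro=" << RODataSize
                     << " rw=" << RWDataSize << "\n";
      }
    };
    } // end anonymous namespace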
+ virtual bool finalizeMemory(std::string *ErrMsg = nullptr) = 0; + + /// This method is called after an object has been loaded into memory but + /// before relocations are applied to the loaded sections. + /// + /// Memory managers which are preparing code for execution in an external + /// address space can use this call to remap the section addresses for the + /// newly loaded object. + /// + /// For clients that do not need access to an ExecutionEngine instance this + /// method should be preferred to its cousin + /// MCJITMemoryManager::notifyObjectLoaded as this method is compatible with + /// ORC JIT stacks. + virtual void notifyObjectLoaded(RuntimeDyld &RTDyld, + const object::ObjectFile &Obj) {} + + private: + virtual void anchor(); + + bool FinalizationLocked = false; + }; + + /// Construct a RuntimeDyld instance. + RuntimeDyld(MemoryManager &MemMgr, JITSymbolResolver &Resolver); + RuntimeDyld(const RuntimeDyld &) = delete; + RuntimeDyld &operator=(const RuntimeDyld &) = delete; + ~RuntimeDyld(); + + /// Add the referenced object file to the list of objects to be loaded and + /// relocated. + std::unique_ptr<LoadedObjectInfo> loadObject(const object::ObjectFile &O); + + /// Get the address of our local copy of the symbol. This may or may not + /// be the address used for relocation (clients can copy the data around + /// and resolve relocatons based on where they put it). + void *getSymbolLocalAddress(StringRef Name) const; + + /// Get the section ID for the section containing the given symbol. + unsigned getSymbolSectionID(StringRef Name) const; + + /// Get the target address and flags for the named symbol. + /// This address is the one used for relocation. + JITEvaluatedSymbol getSymbol(StringRef Name) const; + + /// Returns a copy of the symbol table. This can be used by on-finalized + /// callbacks to extract the symbol table before throwing away the + /// RuntimeDyld instance. Because the map keys (StringRefs) are backed by + /// strings inside the RuntimeDyld instance, the map should be processed + /// before the RuntimeDyld instance is discarded. + std::map<StringRef, JITEvaluatedSymbol> getSymbolTable() const; + + /// Resolve the relocations for all symbols we currently know about. + void resolveRelocations(); + + /// Map a section to its target address space value. + /// Map the address of a JIT section as returned from the memory manager + /// to the address in the target process as the running code will see it. + /// This is the address which will be used for relocation resolution. + void mapSectionAddress(const void *LocalAddress, uint64_t TargetAddress); + + /// Returns the section's working memory. + StringRef getSectionContent(unsigned SectionID) const; + + /// If the section was loaded, return the section's load address, + /// otherwise return None. + uint64_t getSectionLoadAddress(unsigned SectionID) const; + + /// Set the NotifyStubEmitted callback. This is used for debugging + /// purposes. A callback is made for each stub that is generated. + void setNotifyStubEmitted(NotifyStubEmittedFunction NotifyStubEmitted) { + this->NotifyStubEmitted = std::move(NotifyStubEmitted); + } + + /// Register any EH frame sections that have been loaded but not previously + /// registered with the memory manager. Note, RuntimeDyld is responsible + /// for identifying the EH frame and calling the memory manager with the + /// EH frame section data. However, the memory manager itself will handle + /// the actual target-specific EH frame registration. 
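Taken together, the typical local-execution flow for this interface is: construct a RuntimeDyld over a memory manager and a symbol resolver, load the objects, finalize, then look up symbols. A condensed sketch with error handling trimmed; SectionMemoryManager (declared later in this patch) serves here as both the memory manager and, through RTDyldMemoryManager, the resolver, and "entry" is just an illustrative symbol name:

    #include "llvm/ExecutionEngine/RuntimeDyld.h"
    #include "llvm/ExecutionEngine/SectionMemoryManager.h"
    #include "llvm/Object/ObjectFile.h"
    #include "llvm/Support/ErrorHandling.h"

    using namespace llvm;

    uint64_t linkAndFindEntry(const object::ObjectFile &Obj) {
      SectionMemoryManager MemMgr;      // memory manager + legacy symbol resolver
      RuntimeDyld Dyld(MemMgr, MemMgr);

      // The returned LoadedObjectInfo can be retained for debugger registration;
      // it is not needed for plain execution.
      Dyld.loadObject(Obj);

      // Applies relocations, registers EH frames and flips page permissions.
      Dyld.finalizeWithMemoryManagerLocking();

      if (Dyld.hasError())
        report_fatal_error(Dyld.getErrorString());
      return Dyld.getSymbol("entry").getAddress();
    }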
+ void registerEHFrames(); + + void deregisterEHFrames(); + + bool hasError(); + StringRef getErrorString(); + + /// By default, only sections that are "required for execution" are passed to + /// the RTDyldMemoryManager, and other sections are discarded. Passing 'true' + /// to this method will cause RuntimeDyld to pass all sections to its + /// memory manager regardless of whether they are "required to execute" in the + /// usual sense. This is useful for inspecting metadata sections that may not + /// contain relocations, E.g. Debug info, stackmaps. + /// + /// Must be called before the first object file is loaded. + void setProcessAllSections(bool ProcessAllSections) { + assert(!Dyld && "setProcessAllSections must be called before loadObject."); + this->ProcessAllSections = ProcessAllSections; + } + + /// Perform all actions needed to make the code owned by this RuntimeDyld + /// instance executable: + /// + /// 1) Apply relocations. + /// 2) Register EH frames. + /// 3) Update memory permissions*. + /// + /// * Finalization is potentially recursive**, and the 3rd step will only be + /// applied by the outermost call to finalize. This allows different + /// RuntimeDyld instances to share a memory manager without the innermost + /// finalization locking the memory and causing relocation fixup errors in + /// outer instances. + /// + /// ** Recursive finalization occurs when one RuntimeDyld instances needs the + /// address of a symbol owned by some other instance in order to apply + /// relocations. + /// + void finalizeWithMemoryManagerLocking(); + +private: + friend void + jitLinkForORC(object::ObjectFile &Obj, + std::unique_ptr<MemoryBuffer> UnderlyingBuffer, + RuntimeDyld::MemoryManager &MemMgr, JITSymbolResolver &Resolver, + bool ProcessAllSections, + unique_function<Error(std::unique_ptr<LoadedObjectInfo>, + std::map<StringRef, JITEvaluatedSymbol>)> + OnLoaded, + unique_function<void(Error)> OnEmitted); + + // RuntimeDyldImpl is the actual class. RuntimeDyld is just the public + // interface. + std::unique_ptr<RuntimeDyldImpl> Dyld; + MemoryManager &MemMgr; + JITSymbolResolver &Resolver; + bool ProcessAllSections; + NotifyStubEmittedFunction NotifyStubEmitted; +}; + +// Asynchronous JIT link for ORC. +// +// Warning: This API is experimental and probably should not be used by anyone +// but ORC's RTDyldObjectLinkingLayer2. Internally it constructs a RuntimeDyld +// instance and uses continuation passing to perform the fix-up and finalize +// steps asynchronously. +void jitLinkForORC( + object::ObjectFile &Obj, std::unique_ptr<MemoryBuffer> UnderlyingBuffer, + RuntimeDyld::MemoryManager &MemMgr, JITSymbolResolver &Resolver, + bool ProcessAllSections, + unique_function<Error(std::unique_ptr<RuntimeDyld::LoadedObjectInfo>, + std::map<StringRef, JITEvaluatedSymbol>)> + OnLoaded, + unique_function<void(Error)> OnEmitted); + +} // end namespace llvm + +#endif // LLVM_EXECUTIONENGINE_RUNTIMEDYLD_H diff --git a/llvm/include/llvm/ExecutionEngine/RuntimeDyldChecker.h b/llvm/include/llvm/ExecutionEngine/RuntimeDyldChecker.h new file mode 100644 index 000000000000..93ea09107bd1 --- /dev/null +++ b/llvm/include/llvm/ExecutionEngine/RuntimeDyldChecker.h @@ -0,0 +1,184 @@ +//===---- RuntimeDyldChecker.h - RuntimeDyld tester framework -----*- C++ -*-=// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_EXECUTIONENGINE_RUNTIMEDYLDCHECKER_H +#define LLVM_EXECUTIONENGINE_RUNTIMEDYLDCHECKER_H + +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/Optional.h" +#include "llvm/ExecutionEngine/JITSymbol.h" +#include "llvm/Support/Endian.h" + +#include <cstdint> +#include <memory> +#include <string> +#include <utility> + +namespace llvm { + +class StringRef; +class MCDisassembler; +class MemoryBuffer; +class MCInstPrinter; +class RuntimeDyld; +class RuntimeDyldCheckerImpl; +class raw_ostream; + +/// RuntimeDyld invariant checker for verifying that RuntimeDyld has +/// correctly applied relocations. +/// +/// The RuntimeDyldChecker class evaluates expressions against an attached +/// RuntimeDyld instance to verify that relocations have been applied +/// correctly. +/// +/// The expression language supports basic pointer arithmetic and bit-masking, +/// and has limited disassembler integration for accessing instruction +/// operands and the next PC (program counter) address for each instruction. +/// +/// The language syntax is: +/// +/// check = expr '=' expr +/// +/// expr = binary_expr +/// | sliceable_expr +/// +/// sliceable_expr = '*{' number '}' load_addr_expr [slice] +/// | '(' expr ')' [slice] +/// | ident_expr [slice] +/// | number [slice] +/// +/// slice = '[' high-bit-index ':' low-bit-index ']' +/// +/// load_addr_expr = symbol +/// | '(' symbol '+' number ')' +/// | '(' symbol '-' number ')' +/// +/// ident_expr = 'decode_operand' '(' symbol ',' operand-index ')' +/// | 'next_pc' '(' symbol ')' +/// | 'stub_addr' '(' stub-container-name ',' symbol ')' +/// | 'got_addr' '(' stub-container-name ',' symbol ')' +/// | symbol +/// +/// binary_expr = expr '+' expr +/// | expr '-' expr +/// | expr '&' expr +/// | expr '|' expr +/// | expr '<<' expr +/// | expr '>>' expr +/// +class RuntimeDyldChecker { +public: + class MemoryRegionInfo { + public: + MemoryRegionInfo() = default; + + /// Constructor for symbols/sections with content. + MemoryRegionInfo(StringRef Content, JITTargetAddress TargetAddress) + : ContentPtr(Content.data()), Size(Content.size()), + TargetAddress(TargetAddress) {} + + /// Constructor for zero-fill symbols/sections. + MemoryRegionInfo(uint64_t Size, JITTargetAddress TargetAddress) + : Size(Size), TargetAddress(TargetAddress) {} + + /// Returns true if this is a zero-fill symbol/section. + bool isZeroFill() const { + assert(Size && "setContent/setZeroFill must be called first"); + return !ContentPtr; + } + + /// Set the content for this memory region. + void setContent(StringRef Content) { + assert(!ContentPtr && !Size && "Content/zero-fill already set"); + ContentPtr = Content.data(); + Size = Content.size(); + } + + /// Set a zero-fill length for this memory region. + void setZeroFill(uint64_t Size) { + assert(!ContentPtr && !this->Size && "Content/zero-fill already set"); + this->Size = Size; + } + + /// Returns the content for this section if there is any. + StringRef getContent() const { + assert(!isZeroFill() && "Can't get content for a zero-fill section"); + return StringRef(ContentPtr, static_cast<size_t>(Size)); + } + + /// Returns the zero-fill length for this section. + uint64_t getZeroFillLength() const { + assert(isZeroFill() && "Can't get zero-fill length for content section"); + return Size; + } + + /// Set the target address for this region. 
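In practice, rules written in the grammar above are embedded as comments in the test input and handed to checkAllRulesInBuffer, each rule being an equality between two expressions evaluated against the linked image. Hedged examples follow; the symbol and file names are placeholders, the prefix is simply whatever RulePrefix the caller chooses, and Checker/MemBuf are assumed to exist in the surrounding code:

    # rtdyld-check: decode_operand(foo_call, 0) = stub_addr(test.o, bar)
    # rtdyld-check: *{8}(data_sym+8) = baz

    // Given a checker wired to a RuntimeDyld instance and a MemoryBuffer holding
    // the file above, evaluate every rule carrying that prefix:
    bool AllPassed = Checker.checkAllRulesInBuffer("# rtdyld-check:", MemBuf.get());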
+ void setTargetAddress(JITTargetAddress TargetAddress) { + assert(!this->TargetAddress && "TargetAddress already set"); + this->TargetAddress = TargetAddress; + } + + /// Return the target address for this region. + JITTargetAddress getTargetAddress() const { return TargetAddress; } + + private: + const char *ContentPtr = 0; + uint64_t Size = 0; + JITTargetAddress TargetAddress = 0; + }; + + using IsSymbolValidFunction = std::function<bool(StringRef Symbol)>; + using GetSymbolInfoFunction = + std::function<Expected<MemoryRegionInfo>(StringRef SymbolName)>; + using GetSectionInfoFunction = std::function<Expected<MemoryRegionInfo>( + StringRef FileName, StringRef SectionName)>; + using GetStubInfoFunction = std::function<Expected<MemoryRegionInfo>( + StringRef StubContainer, StringRef TargetName)>; + using GetGOTInfoFunction = std::function<Expected<MemoryRegionInfo>( + StringRef GOTContainer, StringRef TargetName)>; + + RuntimeDyldChecker(IsSymbolValidFunction IsSymbolValid, + GetSymbolInfoFunction GetSymbolInfo, + GetSectionInfoFunction GetSectionInfo, + GetStubInfoFunction GetStubInfo, + GetGOTInfoFunction GetGOTInfo, + support::endianness Endianness, + MCDisassembler *Disassembler, MCInstPrinter *InstPrinter, + raw_ostream &ErrStream); + ~RuntimeDyldChecker(); + + /// Check a single expression against the attached RuntimeDyld + /// instance. + bool check(StringRef CheckExpr) const; + + /// Scan the given memory buffer for lines beginning with the string + /// in RulePrefix. The remainder of the line is passed to the check + /// method to be evaluated as an expression. + bool checkAllRulesInBuffer(StringRef RulePrefix, MemoryBuffer *MemBuf) const; + + /// Returns the address of the requested section (or an error message + /// in the second element of the pair if the address cannot be found). + /// + /// if 'LocalAddress' is true, this returns the address of the section + /// within the linker's memory. If 'LocalAddress' is false it returns the + /// address within the target process (i.e. the load address). + std::pair<uint64_t, std::string> getSectionAddr(StringRef FileName, + StringRef SectionName, + bool LocalAddress); + + /// If there is a section at the given local address, return its load + /// address, otherwise return none. + Optional<uint64_t> getSectionLoadAddress(void *LocalAddress) const; + +private: + std::unique_ptr<RuntimeDyldCheckerImpl> Impl; +}; + +} // end namespace llvm + +#endif diff --git a/llvm/include/llvm/ExecutionEngine/SectionMemoryManager.h b/llvm/include/llvm/ExecutionEngine/SectionMemoryManager.h new file mode 100644 index 000000000000..d7316425da2f --- /dev/null +++ b/llvm/include/llvm/ExecutionEngine/SectionMemoryManager.h @@ -0,0 +1,194 @@ +//===- SectionMemoryManager.h - Memory manager for MCJIT/RtDyld -*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains the declaration of a section-based memory manager used by +// the MCJIT execution engine and RuntimeDyld. 
+// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_EXECUTIONENGINE_SECTIONMEMORYMANAGER_H +#define LLVM_EXECUTIONENGINE_SECTIONMEMORYMANAGER_H + +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ExecutionEngine/RTDyldMemoryManager.h" +#include "llvm/Support/Memory.h" +#include <cstdint> +#include <string> +#include <system_error> + +namespace llvm { + +/// This is a simple memory manager which implements the methods called by +/// the RuntimeDyld class to allocate memory for section-based loading of +/// objects, usually those generated by the MCJIT execution engine. +/// +/// This memory manager allocates all section memory as read-write. The +/// RuntimeDyld will copy JITed section memory into these allocated blocks +/// and perform any necessary linking and relocations. +/// +/// Any client using this memory manager MUST ensure that section-specific +/// page permissions have been applied before attempting to execute functions +/// in the JITed object. Permissions can be applied either by calling +/// MCJIT::finalizeObject or by calling SectionMemoryManager::finalizeMemory +/// directly. Clients of MCJIT should call MCJIT::finalizeObject. +class SectionMemoryManager : public RTDyldMemoryManager { +public: + /// This enum describes the various reasons to allocate pages from + /// allocateMappedMemory. + enum class AllocationPurpose { + Code, + ROData, + RWData, + }; + + /// Implementations of this interface are used by SectionMemoryManager to + /// request pages from the operating system. + class MemoryMapper { + public: + /// This method attempts to allocate \p NumBytes bytes of virtual memory for + /// \p Purpose. \p NearBlock may point to an existing allocation, in which + /// case an attempt is made to allocate more memory near the existing block. + /// The actual allocated address is not guaranteed to be near the requested + /// address. \p Flags is used to set the initial protection flags for the + /// block of the memory. \p EC [out] returns an object describing any error + /// that occurs. + /// + /// This method may allocate more than the number of bytes requested. The + /// actual number of bytes allocated is indicated in the returned + /// MemoryBlock. + /// + /// The start of the allocated block must be aligned with the system + /// allocation granularity (64K on Windows, page size on Linux). If the + /// address following \p NearBlock is not so aligned, it will be rounded up + /// to the next allocation granularity boundary. + /// + /// \r a non-null MemoryBlock if the function was successful, otherwise a + /// null MemoryBlock with \p EC describing the error. + virtual sys::MemoryBlock + allocateMappedMemory(AllocationPurpose Purpose, size_t NumBytes, + const sys::MemoryBlock *const NearBlock, + unsigned Flags, std::error_code &EC) = 0; + + /// This method sets the protection flags for a block of memory to the state + /// specified by \p Flags. The behavior is not specified if the memory was + /// not allocated using the allocateMappedMemory method. + /// \p Block describes the memory block to be protected. + /// \p Flags specifies the new protection state to be assigned to the block. + /// + /// If \p Flags is MF_WRITE, the actual behavior varies with the operating + /// system (i.e. MF_READ | MF_WRITE on Windows) and the target architecture + /// (i.e. MF_WRITE -> MF_READ | MF_WRITE on i386). 
+ /// + /// \r error_success if the function was successful, or an error_code + /// describing the failure if an error occurred. + virtual std::error_code protectMappedMemory(const sys::MemoryBlock &Block, + unsigned Flags) = 0; + + /// This method releases a block of memory that was allocated with the + /// allocateMappedMemory method. It should not be used to release any memory + /// block allocated any other way. + /// \p Block describes the memory to be released. + /// + /// \r error_success if the function was successful, or an error_code + /// describing the failure if an error occurred. + virtual std::error_code releaseMappedMemory(sys::MemoryBlock &M) = 0; + + virtual ~MemoryMapper(); + }; + + /// Creates a SectionMemoryManager instance with \p MM as the associated + /// memory mapper. If \p MM is nullptr then a default memory mapper is used + /// that directly calls into the operating system. + SectionMemoryManager(MemoryMapper *MM = nullptr); + SectionMemoryManager(const SectionMemoryManager &) = delete; + void operator=(const SectionMemoryManager &) = delete; + ~SectionMemoryManager() override; + + /// Allocates a memory block of (at least) the given size suitable for + /// executable code. + /// + /// The value of \p Alignment must be a power of two. If \p Alignment is zero + /// a default alignment of 16 will be used. + uint8_t *allocateCodeSection(uintptr_t Size, unsigned Alignment, + unsigned SectionID, + StringRef SectionName) override; + + /// Allocates a memory block of (at least) the given size suitable for + /// executable code. + /// + /// The value of \p Alignment must be a power of two. If \p Alignment is zero + /// a default alignment of 16 will be used. + uint8_t *allocateDataSection(uintptr_t Size, unsigned Alignment, + unsigned SectionID, StringRef SectionName, + bool isReadOnly) override; + + /// Update section-specific memory permissions and other attributes. + /// + /// This method is called when object loading is complete and section page + /// permissions can be applied. It is up to the memory manager implementation + /// to decide whether or not to act on this method. The memory manager will + /// typically allocate all sections as read-write and then apply specific + /// permissions when this method is called. Code sections cannot be executed + /// until this function has been called. In addition, any cache coherency + /// operations needed to reliably use the memory are also performed. + /// + /// \returns true if an error occurred, false otherwise. + bool finalizeMemory(std::string *ErrMsg = nullptr) override; + + /// Invalidate instruction cache for code sections. + /// + /// Some platforms with separate data cache and instruction cache require + /// explicit cache flush, otherwise JIT code manipulations (like resolved + /// relocations) will get to the data cache but not to the instruction cache. + /// + /// This method is called from finalizeMemory. 
+ virtual void invalidateInstructionCache(); + +private: + struct FreeMemBlock { + // The actual block of free memory + sys::MemoryBlock Free; + // If there is a pending allocation from the same reservation right before + // this block, store it's index in PendingMem, to be able to update the + // pending region if part of this block is allocated, rather than having to + // create a new one + unsigned PendingPrefixIndex; + }; + + struct MemoryGroup { + // PendingMem contains all blocks of memory (subblocks of AllocatedMem) + // which have not yet had their permissions applied, but have been given + // out to the user. FreeMem contains all block of memory, which have + // neither had their permissions applied, nor been given out to the user. + SmallVector<sys::MemoryBlock, 16> PendingMem; + SmallVector<FreeMemBlock, 16> FreeMem; + + // All memory blocks that have been requested from the system + SmallVector<sys::MemoryBlock, 16> AllocatedMem; + + sys::MemoryBlock Near; + }; + + uint8_t *allocateSection(AllocationPurpose Purpose, uintptr_t Size, + unsigned Alignment); + + std::error_code applyMemoryGroupPermissions(MemoryGroup &MemGroup, + unsigned Permissions); + + void anchor() override; + + MemoryGroup CodeMem; + MemoryGroup RWDataMem; + MemoryGroup RODataMem; + MemoryMapper &MMapper; +}; + +} // end namespace llvm + +#endif // LLVM_EXECUTION_ENGINE_SECTION_MEMORY_MANAGER_H |
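The MemoryMapper hook above makes it possible to redirect SectionMemoryManager's page-level allocations, for instance to add logging or to carve sections out of a pre-reserved arena. A minimal sketch that simply forwards to llvm::sys::Memory, which is effectively what the default mapper does, shown only to illustrate the interface:

    #include "llvm/ExecutionEngine/SectionMemoryManager.h"
    #include "llvm/Support/Memory.h"
    #include <system_error>

    namespace {
    // Forwarding mapper: identical in behaviour to the built-in default, but a
    // convenient place to hang logging or custom placement policies.
    class ForwardingMapper : public llvm::SectionMemoryManager::MemoryMapper {
    public:
      llvm::sys::MemoryBlock
      allocateMappedMemory(llvm::SectionMemoryManager::AllocationPurpose Purpose,
                           size_t NumBytes,
                           const llvm::sys::MemoryBlock *const NearBlock,
                           unsigned Flags, std::error_code &EC) override {
        // Purpose (Code/ROData/RWData) is ignored here; a custom mapper could
        // use it to place code and data in separate arenas.
        return llvm::sys::Memory::allocateMappedMemory(NumBytes, NearBlock,
                                                       Flags, EC);
      }

      std::error_code protectMappedMemory(const llvm::sys::MemoryBlock &Block,
                                          unsigned Flags) override {
        return llvm::sys::Memory::protectMappedMemory(Block, Flags);
      }

      std::error_code releaseMappedMemory(llvm::sys::MemoryBlock &Block) override {
        return llvm::sys::Memory::releaseMappedMemory(Block);
      }
    };
    } // end anonymous namespace

    // Usage: ForwardingMapper Mapper; llvm::SectionMemoryManager MemMgr(&Mapper);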