diff options
Diffstat (limited to 'contrib/llvm-project/llvm/lib/Target/WebAssembly/WebAssemblyLowerEmscriptenEHSjLj.cpp')
-rw-r--r-- | contrib/llvm-project/llvm/lib/Target/WebAssembly/WebAssemblyLowerEmscriptenEHSjLj.cpp | 905 |
1 files changed, 679 insertions, 226 deletions
diff --git a/contrib/llvm-project/llvm/lib/Target/WebAssembly/WebAssemblyLowerEmscriptenEHSjLj.cpp b/contrib/llvm-project/llvm/lib/Target/WebAssembly/WebAssemblyLowerEmscriptenEHSjLj.cpp index 599829a9e474..4eacc921b6cd 100644 --- a/contrib/llvm-project/llvm/lib/Target/WebAssembly/WebAssemblyLowerEmscriptenEHSjLj.cpp +++ b/contrib/llvm-project/llvm/lib/Target/WebAssembly/WebAssemblyLowerEmscriptenEHSjLj.cpp @@ -7,15 +7,12 @@ //===----------------------------------------------------------------------===// /// /// \file -/// This file lowers exception-related instructions and setjmp/longjmp -/// function calls in order to use Emscripten's JavaScript try and catch -/// mechanism. +/// This file lowers exception-related instructions and setjmp/longjmp function +/// calls to use Emscripten's library functions. The pass uses JavaScript's try +/// and catch mechanism in case of Emscripten EH/SjLj and Wasm EH intrinsics in +/// case of Emscripten SjLJ. /// -/// To handle exceptions and setjmp/longjmps, this scheme relies on JavaScript's -/// try and catch syntax and relevant exception-related libraries implemented -/// in JavaScript glue code that will be produced by Emscripten. -/// -/// * Exception handling +/// * Emscripten exception handling /// This pass lowers invokes and landingpads into library functions in JS glue /// code. Invokes are lowered into function wrappers called invoke wrappers that /// exist in JS side, which wraps the original function call with JS try-catch. @@ -23,7 +20,7 @@ /// variables (see below) so we can check whether an exception occurred from /// wasm code and handle it appropriately. /// -/// * Setjmp-longjmp handling +/// * Emscripten setjmp-longjmp handling /// This pass lowers setjmp to a reasonably-performant approach for emscripten. /// The idea is that each block with a setjmp is broken up into two parts: the /// part containing setjmp and the part right after the setjmp. The latter part @@ -52,7 +49,7 @@ /// __threwValue is 0 for exceptions, and the argument to longjmp in case of /// longjmp. /// -/// * Exception handling +/// * Emscripten exception handling /// /// 2) We assume the existence of setThrew and setTempRet0/getTempRet0 functions /// at link time. setThrew exists in Emscripten's compiler-rt: @@ -121,16 +118,16 @@ /// call @llvm_eh_typeid_for(type) /// llvm_eh_typeid_for function will be generated in JS glue code. /// -/// * Setjmp / Longjmp handling +/// * Emscripten setjmp / longjmp handling /// -/// In case calls to longjmp() exists +/// If there are calls to longjmp() /// /// 1) Lower -/// longjmp(buf, value) +/// longjmp(env, val) /// into -/// emscripten_longjmp(buf, value) +/// emscripten_longjmp(env, val) /// -/// In case calls to setjmp() exists +/// If there are calls to setjmp() /// /// 2) In the function entry that calls setjmp, initialize setjmpTable and /// sejmpTableSize as follows: @@ -141,9 +138,9 @@ /// Emscripten compiler-rt. /// /// 3) Lower -/// setjmp(buf) +/// setjmp(env) /// into -/// setjmpTable = saveSetjmp(buf, label, setjmpTable, setjmpTableSize); +/// setjmpTable = saveSetjmp(env, label, setjmpTable, setjmpTableSize); /// setjmpTableSize = getTempRet0(); /// For each dynamic setjmp call, setjmpTable stores its ID (a number which /// is incrementally assigned from 0) and its label (a unique number that @@ -151,10 +148,9 @@ /// setjmpTable, it is reallocated in saveSetjmp() in Emscripten's /// compiler-rt and it will return the new table address, and assign the new /// table size in setTempRet0(). saveSetjmp also stores the setjmp's ID into -/// the buffer buf. A BB with setjmp is split into two after setjmp call in +/// the buffer 'env'. A BB with setjmp is split into two after setjmp call in /// order to make the post-setjmp BB the possible destination of longjmp BB. /// -/// /// 4) Lower every call that might longjmp into /// __THREW__ = 0; /// call @__invoke_SIG(func, arg1, arg2) @@ -171,7 +167,7 @@ /// %label = -1; /// } /// longjmp_result = getTempRet0(); -/// switch label { +/// switch %label { /// label 1: goto post-setjmp BB 1 /// label 2: goto post-setjmp BB 2 /// ... @@ -188,23 +184,114 @@ /// occurred. Otherwise we jump to the right post-setjmp BB based on the /// label. /// +/// * Wasm setjmp / longjmp handling +/// This mode still uses some Emscripten library functions but not JavaScript's +/// try-catch mechanism. It instead uses Wasm exception handling intrinsics, +/// which will be lowered to exception handling instructions. +/// +/// If there are calls to longjmp() +/// +/// 1) Lower +/// longjmp(env, val) +/// into +/// __wasm_longjmp(env, val) +/// +/// If there are calls to setjmp() +/// +/// 2) and 3): The same as 2) and 3) in Emscripten SjLj. +/// (setjmpTable/setjmpTableSize initialization + setjmp callsite +/// transformation) +/// +/// 4) Create a catchpad with a wasm.catch() intrinsic, which returns the value +/// thrown by __wasm_longjmp function. In Emscripten library, we have this +/// struct: +/// +/// struct __WasmLongjmpArgs { +/// void *env; +/// int val; +/// }; +/// struct __WasmLongjmpArgs __wasm_longjmp_args; +/// +/// The thrown value here is a pointer to __wasm_longjmp_args struct object. We +/// use this struct to transfer two values by throwing a single value. Wasm +/// throw and catch instructions are capable of throwing and catching multiple +/// values, but it also requires multivalue support that is currently not very +/// reliable. +/// TODO Switch to throwing and catching two values without using the struct +/// +/// All longjmpable function calls will be converted to an invoke that will +/// unwind to this catchpad in case a longjmp occurs. Within the catchpad, we +/// test the thrown values using testSetjmp function as we do for Emscripten +/// SjLj. The main difference is, in Emscripten SjLj, we need to transform every +/// longjmpable callsite into a sequence of code including testSetjmp() call; in +/// Wasm SjLj we do the testing in only one place, in this catchpad. +/// +/// After testing calling testSetjmp(), if the longjmp does not correspond to +/// one of the setjmps within the current function, it rethrows the longjmp +/// by calling __wasm_longjmp(). If it corresponds to one of setjmps in the +/// function, we jump to the beginning of the function, which contains a switch +/// to each post-setjmp BB. Again, in Emscripten SjLj, this switch is added for +/// every longjmpable callsite; in Wasm SjLj we do this only once at the top of +/// the function. (after setjmpTable/setjmpTableSize initialization) +/// +/// The below is the pseudocode for what we have described +/// +/// entry: +/// Initialize setjmpTable and setjmpTableSize +/// +/// setjmp.dispatch: +/// switch %label { +/// label 1: goto post-setjmp BB 1 +/// label 2: goto post-setjmp BB 2 +/// ... +/// default: goto splitted next BB +/// } +/// ... +/// +/// bb: +/// invoke void @foo() ;; foo is a longjmpable function +/// to label %next unwind label %catch.dispatch.longjmp +/// ... +/// +/// catch.dispatch.longjmp: +/// %0 = catchswitch within none [label %catch.longjmp] unwind to caller +/// +/// catch.longjmp: +/// %longjmp.args = wasm.catch() ;; struct __WasmLongjmpArgs +/// %env = load 'env' field from __WasmLongjmpArgs +/// %val = load 'val' field from __WasmLongjmpArgs +/// %label = testSetjmp(mem[%env], setjmpTable, setjmpTableSize); +/// if (%label == 0) +/// __wasm_longjmp(%env, %val) +/// catchret to %setjmp.dispatch +/// ///===----------------------------------------------------------------------===// #include "WebAssembly.h" #include "WebAssemblyTargetMachine.h" #include "llvm/ADT/StringExtras.h" #include "llvm/CodeGen/TargetPassConfig.h" +#include "llvm/CodeGen/WasmEHFuncInfo.h" #include "llvm/IR/DebugInfoMetadata.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/IRBuilder.h" +#include "llvm/IR/IntrinsicsWebAssembly.h" #include "llvm/Support/CommandLine.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/SSAUpdater.h" +#include "llvm/Transforms/Utils/SSAUpdaterBulk.h" using namespace llvm; #define DEBUG_TYPE "wasm-lower-em-ehsjlj" +// Emscripten's asm.js-style exception handling +extern cl::opt<bool> WasmEnableEmEH; +// Emscripten's asm.js-style setjmp/longjmp handling +extern cl::opt<bool> WasmEnableEmSjLj; +// Wasm setjmp/longjmp handling using wasm EH instructions +extern cl::opt<bool> WasmEnableSjLj; + static cl::list<std::string> EHAllowlist("emscripten-cxx-exceptions-allowed", cl::desc("The list of function names in which Emscripten-style " @@ -214,19 +301,25 @@ static cl::list<std::string> namespace { class WebAssemblyLowerEmscriptenEHSjLj final : public ModulePass { - bool EnableEH; // Enable exception handling - bool EnableSjLj; // Enable setjmp/longjmp handling - bool DoSjLj; // Whether we actually perform setjmp/longjmp handling - - GlobalVariable *ThrewGV = nullptr; - GlobalVariable *ThrewValueGV = nullptr; - Function *GetTempRet0Func = nullptr; - Function *SetTempRet0Func = nullptr; - Function *ResumeF = nullptr; - Function *EHTypeIDF = nullptr; - Function *EmLongjmpF = nullptr; - Function *SaveSetjmpF = nullptr; - Function *TestSetjmpF = nullptr; + bool EnableEmEH; // Enable Emscripten exception handling + bool EnableEmSjLj; // Enable Emscripten setjmp/longjmp handling + bool EnableWasmSjLj; // Enable Wasm setjmp/longjmp handling + bool DoSjLj; // Whether we actually perform setjmp/longjmp handling + + GlobalVariable *ThrewGV = nullptr; // __THREW__ (Emscripten) + GlobalVariable *ThrewValueGV = nullptr; // __threwValue (Emscripten) + Function *GetTempRet0F = nullptr; // getTempRet0() (Emscripten) + Function *SetTempRet0F = nullptr; // setTempRet0() (Emscripten) + Function *ResumeF = nullptr; // __resumeException() (Emscripten) + Function *EHTypeIDF = nullptr; // llvm.eh.typeid.for() (intrinsic) + Function *EmLongjmpF = nullptr; // emscripten_longjmp() (Emscripten) + Function *SaveSetjmpF = nullptr; // saveSetjmp() (Emscripten) + Function *TestSetjmpF = nullptr; // testSetjmp() (Emscripten) + Function *WasmLongjmpF = nullptr; // __wasm_longjmp() (Emscripten) + Function *CatchF = nullptr; // wasm.catch() (intrinsic) + + // type of 'struct __WasmLongjmpArgs' defined in emscripten + Type *LongjmpArgsTy = nullptr; // __cxa_find_matching_catch_N functions. // Indexed by the number of clauses in an original landingpad instruction. @@ -242,31 +335,47 @@ class WebAssemblyLowerEmscriptenEHSjLj final : public ModulePass { return "WebAssembly Lower Emscripten Exceptions"; } + using InstVector = SmallVectorImpl<Instruction *>; bool runEHOnFunction(Function &F); bool runSjLjOnFunction(Function &F); + void handleLongjmpableCallsForEmscriptenSjLj( + Function &F, InstVector &SetjmpTableInsts, + InstVector &SetjmpTableSizeInsts, + SmallVectorImpl<PHINode *> &SetjmpRetPHIs); + void + handleLongjmpableCallsForWasmSjLj(Function &F, InstVector &SetjmpTableInsts, + InstVector &SetjmpTableSizeInsts, + SmallVectorImpl<PHINode *> &SetjmpRetPHIs); Function *getFindMatchingCatch(Module &M, unsigned NumClauses); Value *wrapInvoke(CallBase *CI); void wrapTestSetjmp(BasicBlock *BB, DebugLoc DL, Value *Threw, Value *SetjmpTable, Value *SetjmpTableSize, Value *&Label, - Value *&LongjmpResult, BasicBlock *&EndBB); + Value *&LongjmpResult, BasicBlock *&CallEmLongjmpBB, + PHINode *&CallEmLongjmpBBThrewPHI, + PHINode *&CallEmLongjmpBBThrewValuePHI, + BasicBlock *&EndBB); Function *getInvokeWrapper(CallBase *CI); bool areAllExceptionsAllowed() const { return EHAllowlistSet.empty(); } - bool canLongjmp(Module &M, const Value *Callee) const; - bool isEmAsmCall(Module &M, const Value *Callee) const; bool supportsException(const Function *F) const { - return EnableEH && (areAllExceptionsAllowed() || - EHAllowlistSet.count(std::string(F->getName()))); + return EnableEmEH && (areAllExceptionsAllowed() || + EHAllowlistSet.count(std::string(F->getName()))); } + void replaceLongjmpWith(Function *LongjmpF, Function *NewF); void rebuildSSA(Function &F); public: static char ID; - WebAssemblyLowerEmscriptenEHSjLj(bool EnableEH = true, bool EnableSjLj = true) - : ModulePass(ID), EnableEH(EnableEH), EnableSjLj(EnableSjLj) { + WebAssemblyLowerEmscriptenEHSjLj() + : ModulePass(ID), EnableEmEH(WasmEnableEmEH), + EnableEmSjLj(WasmEnableEmSjLj), EnableWasmSjLj(WasmEnableSjLj) { + assert(!(EnableEmSjLj && EnableWasmSjLj) && + "Two SjLj modes cannot be turned on at the same time"); + assert(!(EnableEmEH && EnableWasmSjLj) && + "Wasm SjLj should be only used with Wasm EH"); EHAllowlistSet.insert(EHAllowlist.begin(), EHAllowlist.end()); } bool runOnModule(Module &M) override; @@ -282,9 +391,8 @@ INITIALIZE_PASS(WebAssemblyLowerEmscriptenEHSjLj, DEBUG_TYPE, "WebAssembly Lower Emscripten Exceptions / Setjmp / Longjmp", false, false) -ModulePass *llvm::createWebAssemblyLowerEmscriptenEHSjLj(bool EnableEH, - bool EnableSjLj) { - return new WebAssemblyLowerEmscriptenEHSjLj(EnableEH, EnableSjLj); +ModulePass *llvm::createWebAssemblyLowerEmscriptenEHSjLj() { + return new WebAssemblyLowerEmscriptenEHSjLj(); } static bool canThrow(const Value *V) { @@ -353,12 +461,12 @@ static Function *getEmscriptenFunction(FunctionType *Ty, const Twine &Name, if (!F->hasFnAttribute("wasm-import-module")) { llvm::AttrBuilder B; B.addAttribute("wasm-import-module", "env"); - F->addAttributes(llvm::AttributeList::FunctionIndex, B); + F->addFnAttrs(B); } if (!F->hasFnAttribute("wasm-import-name")) { llvm::AttrBuilder B; B.addAttribute("wasm-import-name", F->getName()); - F->addAttributes(llvm::AttributeList::FunctionIndex, B); + F->addFnAttrs(B); } return F; } @@ -415,15 +523,6 @@ Value *WebAssemblyLowerEmscriptenEHSjLj::wrapInvoke(CallBase *CI) { Module *M = CI->getModule(); LLVMContext &C = M->getContext(); - // If we are calling a function that is noreturn, we must remove that - // attribute. The code we insert here does expect it to return, after we - // catch the exception. - if (CI->doesNotReturn()) { - if (auto *F = CI->getCalledFunction()) - F->removeFnAttr(Attribute::NoReturn); - CI->removeAttribute(AttributeList::FunctionIndex, Attribute::NoReturn); - } - IRBuilder<> IRB(C); IRB.SetInsertPoint(CI); @@ -450,10 +549,10 @@ Value *WebAssemblyLowerEmscriptenEHSjLj::wrapInvoke(CallBase *CI) { // No attributes for the callee pointer. ArgAttributes.push_back(AttributeSet()); // Copy the argument attributes from the original - for (unsigned I = 0, E = CI->getNumArgOperands(); I < E; ++I) - ArgAttributes.push_back(InvokeAL.getParamAttributes(I)); + for (unsigned I = 0, E = CI->arg_size(); I < E; ++I) + ArgAttributes.push_back(InvokeAL.getParamAttrs(I)); - AttrBuilder FnAttrs(InvokeAL.getFnAttributes()); + AttrBuilder FnAttrs(InvokeAL.getFnAttrs()); if (FnAttrs.contains(Attribute::AllocSize)) { // The allocsize attribute (if any) referes to parameters by index and needs // to be adjusted. @@ -467,9 +566,8 @@ Value *WebAssemblyLowerEmscriptenEHSjLj::wrapInvoke(CallBase *CI) { } // Reconstruct the AttributesList based on the vector we constructed. - AttributeList NewCallAL = - AttributeList::get(C, AttributeSet::get(C, FnAttrs), - InvokeAL.getRetAttributes(), ArgAttributes); + AttributeList NewCallAL = AttributeList::get( + C, AttributeSet::get(C, FnAttrs), InvokeAL.getRetAttrs(), ArgAttributes); NewCall->setAttributes(NewCallAL); CI->replaceAllUsesWith(NewCall); @@ -504,8 +602,7 @@ Function *WebAssemblyLowerEmscriptenEHSjLj::getInvokeWrapper(CallBase *CI) { return F; } -bool WebAssemblyLowerEmscriptenEHSjLj::canLongjmp(Module &M, - const Value *Callee) const { +static bool canLongjmp(const Value *Callee) { if (auto *CalleeF = dyn_cast<Function>(Callee)) if (CalleeF->isIntrinsic()) return false; @@ -543,8 +640,7 @@ bool WebAssemblyLowerEmscriptenEHSjLj::canLongjmp(Module &M, return true; } -bool WebAssemblyLowerEmscriptenEHSjLj::isEmAsmCall(Module &M, - const Value *Callee) const { +static bool isEmAsmCall(const Value *Callee) { StringRef CalleeName = Callee->getName(); // This is an exhaustive list from Emscripten's <emscripten/em_asm.h>. return CalleeName == "emscripten_asm_const_int" || @@ -558,7 +654,7 @@ bool WebAssemblyLowerEmscriptenEHSjLj::isEmAsmCall(Module &M, // The code this generates is equivalent to the following JavaScript code: // %__threwValue.val = __threwValue; // if (%__THREW__.val != 0 & %__threwValue.val != 0) { -// %label = _testSetjmp(mem[%__THREW__.val], setjmpTable, setjmpTableSize); +// %label = testSetjmp(mem[%__THREW__.val], setjmpTable, setjmpTableSize); // if (%label == 0) // emscripten_longjmp(%__THREW__.val, %__threwValue.val); // setTempRet0(%__threwValue.val); @@ -572,7 +668,8 @@ bool WebAssemblyLowerEmscriptenEHSjLj::isEmAsmCall(Module &M, void WebAssemblyLowerEmscriptenEHSjLj::wrapTestSetjmp( BasicBlock *BB, DebugLoc DL, Value *Threw, Value *SetjmpTable, Value *SetjmpTableSize, Value *&Label, Value *&LongjmpResult, - BasicBlock *&EndBB) { + BasicBlock *&CallEmLongjmpBB, PHINode *&CallEmLongjmpBBThrewPHI, + PHINode *&CallEmLongjmpBBThrewValuePHI, BasicBlock *&EndBB) { Function *F = BB->getParent(); Module *M = F->getParent(); LLVMContext &C = M->getContext(); @@ -591,10 +688,27 @@ void WebAssemblyLowerEmscriptenEHSjLj::wrapTestSetjmp( Value *Cmp1 = IRB.CreateAnd(ThrewCmp, ThrewValueCmp, "cmp1"); IRB.CreateCondBr(Cmp1, ThenBB1, ElseBB1); - // %label = _testSetjmp(mem[%__THREW__.val], _setjmpTable, _setjmpTableSize); + // Generate call.em.longjmp BB once and share it within the function + if (!CallEmLongjmpBB) { + // emscripten_longjmp(%__THREW__.val, %__threwValue.val); + CallEmLongjmpBB = BasicBlock::Create(C, "call.em.longjmp", F); + IRB.SetInsertPoint(CallEmLongjmpBB); + CallEmLongjmpBBThrewPHI = IRB.CreatePHI(getAddrIntType(M), 4, "threw.phi"); + CallEmLongjmpBBThrewValuePHI = + IRB.CreatePHI(IRB.getInt32Ty(), 4, "threwvalue.phi"); + CallEmLongjmpBBThrewPHI->addIncoming(Threw, ThenBB1); + CallEmLongjmpBBThrewValuePHI->addIncoming(ThrewValue, ThenBB1); + IRB.CreateCall(EmLongjmpF, + {CallEmLongjmpBBThrewPHI, CallEmLongjmpBBThrewValuePHI}); + IRB.CreateUnreachable(); + } else { + CallEmLongjmpBBThrewPHI->addIncoming(Threw, ThenBB1); + CallEmLongjmpBBThrewValuePHI->addIncoming(ThrewValue, ThenBB1); + } + + // %label = testSetjmp(mem[%__THREW__.val], setjmpTable, setjmpTableSize); // if (%label == 0) IRB.SetInsertPoint(ThenBB1); - BasicBlock *ThenBB2 = BasicBlock::Create(C, "if.then2", F); BasicBlock *EndBB2 = BasicBlock::Create(C, "if.end2", F); Value *ThrewPtr = IRB.CreateIntToPtr(Threw, getAddrPtrType(M), Threw->getName() + ".p"); @@ -603,16 +717,11 @@ void WebAssemblyLowerEmscriptenEHSjLj::wrapTestSetjmp( Value *ThenLabel = IRB.CreateCall( TestSetjmpF, {LoadedThrew, SetjmpTable, SetjmpTableSize}, "label"); Value *Cmp2 = IRB.CreateICmpEQ(ThenLabel, IRB.getInt32(0)); - IRB.CreateCondBr(Cmp2, ThenBB2, EndBB2); - - // emscripten_longjmp(%__THREW__.val, %__threwValue.val); - IRB.SetInsertPoint(ThenBB2); - IRB.CreateCall(EmLongjmpF, {Threw, ThrewValue}); - IRB.CreateUnreachable(); + IRB.CreateCondBr(Cmp2, CallEmLongjmpBB, EndBB2); // setTempRet0(%__threwValue.val); IRB.SetInsertPoint(EndBB2); - IRB.CreateCall(SetTempRet0Func, ThrewValue); + IRB.CreateCall(SetTempRet0F, ThrewValue); IRB.CreateBr(EndBB1); IRB.SetInsertPoint(ElseBB1); @@ -628,53 +737,67 @@ void WebAssemblyLowerEmscriptenEHSjLj::wrapTestSetjmp( // Output parameter assignment Label = LabelPHI; EndBB = EndBB1; - LongjmpResult = IRB.CreateCall(GetTempRet0Func, None, "longjmp_result"); + LongjmpResult = IRB.CreateCall(GetTempRet0F, None, "longjmp_result"); } void WebAssemblyLowerEmscriptenEHSjLj::rebuildSSA(Function &F) { DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>(F).getDomTree(); DT.recalculate(F); // CFG has been changed - SSAUpdater SSA; + + SSAUpdaterBulk SSA; for (BasicBlock &BB : F) { for (Instruction &I : BB) { - SSA.Initialize(I.getType(), I.getName()); - SSA.AddAvailableValue(&BB, &I); - for (auto UI = I.use_begin(), UE = I.use_end(); UI != UE;) { - Use &U = *UI; - ++UI; + unsigned VarID = SSA.AddVariable(I.getName(), I.getType()); + // If a value is defined by an invoke instruction, it is only available in + // its normal destination and not in its unwind destination. + if (auto *II = dyn_cast<InvokeInst>(&I)) + SSA.AddAvailableValue(VarID, II->getNormalDest(), II); + else + SSA.AddAvailableValue(VarID, &BB, &I); + for (auto &U : I.uses()) { auto *User = cast<Instruction>(U.getUser()); if (auto *UserPN = dyn_cast<PHINode>(User)) if (UserPN->getIncomingBlock(U) == &BB) continue; - if (DT.dominates(&I, User)) continue; - SSA.RewriteUseAfterInsertions(U); + SSA.AddUse(VarID, &U); } } } + SSA.RewriteAllUses(&DT); } -// Replace uses of longjmp with emscripten_longjmp. emscripten_longjmp takes -// arguments of type {i32, i32} (wasm32) / {i64, i32} (wasm64) and longjmp takes -// {jmp_buf*, i32}, so we need a ptrtoint instruction here to make the type -// match. jmp_buf* will eventually be lowered to i32 in the wasm backend. -static void replaceLongjmpWithEmscriptenLongjmp(Function *LongjmpF, - Function *EmLongjmpF) { +// Replace uses of longjmp with a new longjmp function in Emscripten library. +// In Emscripten SjLj, the new function is +// void emscripten_longjmp(uintptr_t, i32) +// In Wasm SjLj, the new function is +// void __wasm_longjmp(i8*, i32) +// Because the original libc longjmp function takes (jmp_buf*, i32), we need a +// ptrtoint/bitcast instruction here to make the type match. jmp_buf* will +// eventually be lowered to i32/i64 in the wasm backend. +void WebAssemblyLowerEmscriptenEHSjLj::replaceLongjmpWith(Function *LongjmpF, + Function *NewF) { + assert(NewF == EmLongjmpF || NewF == WasmLongjmpF); Module *M = LongjmpF->getParent(); SmallVector<CallInst *, 8> ToErase; LLVMContext &C = LongjmpF->getParent()->getContext(); IRBuilder<> IRB(C); - // For calls to longjmp, replace it with emscripten_longjmp and cast its first - // argument (jmp_buf*) to int + // For calls to longjmp, replace it with emscripten_longjmp/__wasm_longjmp and + // cast its first argument (jmp_buf*) appropriately for (User *U : LongjmpF->users()) { auto *CI = dyn_cast<CallInst>(U); if (CI && CI->getCalledFunction() == LongjmpF) { IRB.SetInsertPoint(CI); - Value *Jmpbuf = - IRB.CreatePtrToInt(CI->getArgOperand(0), getAddrIntType(M), "jmpbuf"); - IRB.CreateCall(EmLongjmpF, {Jmpbuf, CI->getArgOperand(1)}); + Value *Env = nullptr; + if (NewF == EmLongjmpF) + Env = + IRB.CreatePtrToInt(CI->getArgOperand(0), getAddrIntType(M), "env"); + else // WasmLongjmpF + Env = + IRB.CreateBitCast(CI->getArgOperand(0), IRB.getInt8PtrTy(), "env"); + IRB.CreateCall(NewF, {Env, CI->getArgOperand(1)}); ToErase.push_back(CI); } } @@ -682,14 +805,23 @@ static void replaceLongjmpWithEmscriptenLongjmp(Function *LongjmpF, I->eraseFromParent(); // If we have any remaining uses of longjmp's function pointer, replace it - // with (int(*)(jmp_buf*, int))emscripten_longjmp. + // with (void(*)(jmp_buf*, int))emscripten_longjmp / __wasm_longjmp. if (!LongjmpF->uses().empty()) { - Value *EmLongjmp = - IRB.CreateBitCast(EmLongjmpF, LongjmpF->getType(), "em_longjmp"); - LongjmpF->replaceAllUsesWith(EmLongjmp); + Value *NewLongjmp = + IRB.CreateBitCast(NewF, LongjmpF->getType(), "longjmp.cast"); + LongjmpF->replaceAllUsesWith(NewLongjmp); } } +static bool containsLongjmpableCalls(const Function *F) { + for (const auto &BB : *F) + for (const auto &I : BB) + if (const auto *CB = dyn_cast<CallBase>(&I)) + if (canLongjmp(CB->getCalledOperand())) + return true; + return false; +} + bool WebAssemblyLowerEmscriptenEHSjLj::runOnModule(Module &M) { LLVM_DEBUG(dbgs() << "********** Lower Emscripten EH & SjLj **********\n"); @@ -698,39 +830,60 @@ bool WebAssemblyLowerEmscriptenEHSjLj::runOnModule(Module &M) { Function *SetjmpF = M.getFunction("setjmp"); Function *LongjmpF = M.getFunction("longjmp"); - bool SetjmpUsed = SetjmpF && !SetjmpF->use_empty(); - bool LongjmpUsed = LongjmpF && !LongjmpF->use_empty(); - DoSjLj = EnableSjLj && (SetjmpUsed || LongjmpUsed); + + // In some platforms _setjmp and _longjmp are used instead. Change these to + // use setjmp/longjmp instead, because we later detect these functions by + // their names. + Function *SetjmpF2 = M.getFunction("_setjmp"); + Function *LongjmpF2 = M.getFunction("_longjmp"); + if (SetjmpF2) { + if (SetjmpF) { + if (SetjmpF->getFunctionType() != SetjmpF2->getFunctionType()) + report_fatal_error("setjmp and _setjmp have different function types"); + } else { + SetjmpF = Function::Create(SetjmpF2->getFunctionType(), + GlobalValue::ExternalLinkage, "setjmp", M); + } + SetjmpF2->replaceAllUsesWith(SetjmpF); + } + if (LongjmpF2) { + if (LongjmpF) { + if (LongjmpF->getFunctionType() != LongjmpF2->getFunctionType()) + report_fatal_error( + "longjmp and _longjmp have different function types"); + } else { + LongjmpF = Function::Create(LongjmpF2->getFunctionType(), + GlobalValue::ExternalLinkage, "setjmp", M); + } + LongjmpF2->replaceAllUsesWith(LongjmpF); + } auto *TPC = getAnalysisIfAvailable<TargetPassConfig>(); assert(TPC && "Expected a TargetPassConfig"); auto &TM = TPC->getTM<WebAssemblyTargetMachine>(); - if (EnableEH && TM.Options.ExceptionModel == ExceptionHandling::Wasm) - report_fatal_error("-exception-model=wasm not allowed with " - "-enable-emscripten-cxx-exceptions"); - // Declare (or get) global variables __THREW__, __threwValue, and // getTempRet0/setTempRet0 function which are used in common for both // exception handling and setjmp/longjmp handling ThrewGV = getGlobalVariable(M, getAddrIntType(&M), TM, "__THREW__"); ThrewValueGV = getGlobalVariable(M, IRB.getInt32Ty(), TM, "__threwValue"); - GetTempRet0Func = getEmscriptenFunction( + GetTempRet0F = getEmscriptenFunction( FunctionType::get(IRB.getInt32Ty(), false), "getTempRet0", &M); - SetTempRet0Func = getEmscriptenFunction( + SetTempRet0F = getEmscriptenFunction( FunctionType::get(IRB.getVoidTy(), IRB.getInt32Ty(), false), "setTempRet0", &M); - GetTempRet0Func->setDoesNotThrow(); - SetTempRet0Func->setDoesNotThrow(); + GetTempRet0F->setDoesNotThrow(); + SetTempRet0F->setDoesNotThrow(); bool Changed = false; // Function registration for exception handling - if (EnableEH) { + if (EnableEmEH) { // Register __resumeException function FunctionType *ResumeFTy = FunctionType::get(IRB.getVoidTy(), IRB.getInt8PtrTy(), false); ResumeF = getEmscriptenFunction(ResumeFTy, "__resumeException", &M); + ResumeF->addFnAttr(Attribute::NoReturn); // Register llvm_eh_typeid_for function FunctionType *EHTypeIDTy = @@ -738,20 +891,55 @@ bool WebAssemblyLowerEmscriptenEHSjLj::runOnModule(Module &M) { EHTypeIDF = getEmscriptenFunction(EHTypeIDTy, "llvm_eh_typeid_for", &M); } + if ((EnableEmSjLj || EnableWasmSjLj) && SetjmpF) { + // Precompute setjmp users + for (User *U : SetjmpF->users()) { + if (auto *CB = dyn_cast<CallBase>(U)) { + auto *UserF = CB->getFunction(); + // If a function that calls setjmp does not contain any other calls that + // can longjmp, we don't need to do any transformation on that function, + // so can ignore it + if (containsLongjmpableCalls(UserF)) + SetjmpUsers.insert(UserF); + } else { + std::string S; + raw_string_ostream SS(S); + SS << *U; + report_fatal_error(Twine("Indirect use of setjmp is not supported: ") + + SS.str()); + } + } + } + + bool SetjmpUsed = SetjmpF && !SetjmpUsers.empty(); + bool LongjmpUsed = LongjmpF && !LongjmpF->use_empty(); + DoSjLj = (EnableEmSjLj | EnableWasmSjLj) && (SetjmpUsed || LongjmpUsed); + // Function registration and data pre-gathering for setjmp/longjmp handling if (DoSjLj) { - // Register emscripten_longjmp function - FunctionType *FTy = FunctionType::get( - IRB.getVoidTy(), {getAddrIntType(&M), IRB.getInt32Ty()}, false); - EmLongjmpF = getEmscriptenFunction(FTy, "emscripten_longjmp", &M); + assert(EnableEmSjLj || EnableWasmSjLj); + if (EnableEmSjLj) { + // Register emscripten_longjmp function + FunctionType *FTy = FunctionType::get( + IRB.getVoidTy(), {getAddrIntType(&M), IRB.getInt32Ty()}, false); + EmLongjmpF = getEmscriptenFunction(FTy, "emscripten_longjmp", &M); + EmLongjmpF->addFnAttr(Attribute::NoReturn); + } else { // EnableWasmSjLj + // Register __wasm_longjmp function, which calls __builtin_wasm_longjmp. + FunctionType *FTy = FunctionType::get( + IRB.getVoidTy(), {IRB.getInt8PtrTy(), IRB.getInt32Ty()}, false); + WasmLongjmpF = getEmscriptenFunction(FTy, "__wasm_longjmp", &M); + WasmLongjmpF->addFnAttr(Attribute::NoReturn); + } if (SetjmpF) { // Register saveSetjmp function FunctionType *SetjmpFTy = SetjmpF->getFunctionType(); - FTy = FunctionType::get(Type::getInt32PtrTy(C), - {SetjmpFTy->getParamType(0), IRB.getInt32Ty(), - Type::getInt32PtrTy(C), IRB.getInt32Ty()}, - false); + FunctionType *FTy = + FunctionType::get(Type::getInt32PtrTy(C), + {SetjmpFTy->getParamType(0), IRB.getInt32Ty(), + Type::getInt32PtrTy(C), IRB.getInt32Ty()}, + false); SaveSetjmpF = getEmscriptenFunction(FTy, "saveSetjmp", &M); // Register testSetjmp function @@ -761,16 +949,18 @@ bool WebAssemblyLowerEmscriptenEHSjLj::runOnModule(Module &M) { false); TestSetjmpF = getEmscriptenFunction(FTy, "testSetjmp", &M); - // Precompute setjmp users - for (User *U : SetjmpF->users()) { - auto *UI = cast<Instruction>(U); - SetjmpUsers.insert(UI->getFunction()); - } + // wasm.catch() will be lowered down to wasm 'catch' instruction in + // instruction selection. + CatchF = Intrinsic::getDeclaration(&M, Intrinsic::wasm_catch); + // Type for struct __WasmLongjmpArgs + LongjmpArgsTy = StructType::get(IRB.getInt8PtrTy(), // env + IRB.getInt32Ty() // val + ); } } // Exception handling transformation - if (EnableEH) { + if (EnableEmEH) { for (Function &F : M) { if (F.isDeclaration()) continue; @@ -782,7 +972,7 @@ bool WebAssemblyLowerEmscriptenEHSjLj::runOnModule(Module &M) { if (DoSjLj) { Changed = true; // We have setjmp or longjmp somewhere if (LongjmpF) - replaceLongjmpWithEmscriptenLongjmp(LongjmpF, EmLongjmpF); + replaceLongjmpWith(LongjmpF, EnableEmSjLj ? EmLongjmpF : WasmLongjmpF); // Only traverse functions that uses setjmp in order not to insert // unnecessary prep / cleanup code in every function if (SetjmpF) @@ -816,6 +1006,12 @@ bool WebAssemblyLowerEmscriptenEHSjLj::runEHOnFunction(Function &F) { SmallVector<Instruction *, 64> ToErase; SmallPtrSet<LandingPadInst *, 32> LandingPads; + // rethrow.longjmp BB that will be shared within the function. + BasicBlock *RethrowLongjmpBB = nullptr; + // PHI node for the loaded value of __THREW__ global variable in + // rethrow.longjmp BB + PHINode *RethrowLongjmpBBThrewPHI = nullptr; + for (BasicBlock &BB : F) { auto *II = dyn_cast<InvokeInst>(BB.getTerminator()); if (!II) @@ -836,37 +1032,48 @@ bool WebAssemblyLowerEmscriptenEHSjLj::runEHOnFunction(Function &F) { // setjmp, it will be appropriately handled in runSjLjOnFunction. But even // if the function does not contain setjmp calls, we shouldn't silently // ignore longjmps; we should rethrow them so they can be correctly - // handled in somewhere up the call chain where setjmp is. - // __THREW__'s value is 0 when nothing happened, 1 when an exception is - // thrown, other values when longjmp is thrown. + // handled in somewhere up the call chain where setjmp is. __THREW__'s + // value is 0 when nothing happened, 1 when an exception is thrown, and + // other values when longjmp is thrown. // // if (%__THREW__.val == 0 || %__THREW__.val == 1) // goto %tail // else // goto %longjmp.rethrow // - // longjmp.rethrow: ;; This is longjmp. Rethrow it + // rethrow.longjmp: ;; This is longjmp. Rethrow it // %__threwValue.val = __threwValue // emscripten_longjmp(%__THREW__.val, %__threwValue.val); // // tail: ;; Nothing happened or an exception is thrown // ... Continue exception handling ... - if (DoSjLj && !SetjmpUsers.count(&F) && canLongjmp(M, Callee)) { + if (DoSjLj && EnableEmSjLj && !SetjmpUsers.count(&F) && + canLongjmp(Callee)) { + // Create longjmp.rethrow BB once and share it within the function + if (!RethrowLongjmpBB) { + RethrowLongjmpBB = BasicBlock::Create(C, "rethrow.longjmp", &F); + IRB.SetInsertPoint(RethrowLongjmpBB); + RethrowLongjmpBBThrewPHI = + IRB.CreatePHI(getAddrIntType(&M), 4, "threw.phi"); + RethrowLongjmpBBThrewPHI->addIncoming(Threw, &BB); + Value *ThrewValue = IRB.CreateLoad(IRB.getInt32Ty(), ThrewValueGV, + ThrewValueGV->getName() + ".val"); + IRB.CreateCall(EmLongjmpF, {RethrowLongjmpBBThrewPHI, ThrewValue}); + IRB.CreateUnreachable(); + } else { + RethrowLongjmpBBThrewPHI->addIncoming(Threw, &BB); + } + + IRB.SetInsertPoint(II); // Restore the insert point back BasicBlock *Tail = BasicBlock::Create(C, "tail", &F); - BasicBlock *RethrowBB = BasicBlock::Create(C, "longjmp.rethrow", &F); Value *CmpEqOne = IRB.CreateICmpEQ(Threw, getAddrSizeInt(&M, 1), "cmp.eq.one"); Value *CmpEqZero = IRB.CreateICmpEQ(Threw, getAddrSizeInt(&M, 0), "cmp.eq.zero"); Value *Or = IRB.CreateOr(CmpEqZero, CmpEqOne, "or"); - IRB.CreateCondBr(Or, Tail, RethrowBB); - IRB.SetInsertPoint(RethrowBB); - Value *ThrewValue = IRB.CreateLoad(IRB.getInt32Ty(), ThrewValueGV, - ThrewValueGV->getName() + ".val"); - IRB.CreateCall(EmLongjmpF, {Threw, ThrewValue}); - - IRB.CreateUnreachable(); + IRB.CreateCondBr(Or, Tail, RethrowLongjmpBB); IRB.SetInsertPoint(Tail); + BB.replaceSuccessorsPhiUsesWith(&BB, Tail); } // Insert a branch based on __THREW__ variable @@ -961,7 +1168,7 @@ bool WebAssemblyLowerEmscriptenEHSjLj::runEHOnFunction(Function &F) { CallInst *FMCI = IRB.CreateCall(FMCF, FMCArgs, "fmc"); Value *Undef = UndefValue::get(LPI->getType()); Value *Pair0 = IRB.CreateInsertValue(Undef, FMCI, 0, "pair0"); - Value *TempRet0 = IRB.CreateCall(GetTempRet0Func, None, "tempret0"); + Value *TempRet0 = IRB.CreateCall(GetTempRet0F, None, "tempret0"); Value *Pair1 = IRB.CreateInsertValue(Pair0, TempRet0, 1, "pair1"); LPI->replaceAllUsesWith(Pair1); @@ -997,14 +1204,15 @@ static DebugLoc getOrCreateDebugLoc(const Instruction *InsertBefore, } bool WebAssemblyLowerEmscriptenEHSjLj::runSjLjOnFunction(Function &F) { + assert(EnableEmSjLj || EnableWasmSjLj); Module &M = *F.getParent(); LLVMContext &C = F.getContext(); IRBuilder<> IRB(C); SmallVector<Instruction *, 64> ToErase; // Vector of %setjmpTable values - std::vector<Instruction *> SetjmpTableInsts; + SmallVector<Instruction *, 4> SetjmpTableInsts; // Vector of %setjmpTableSize values - std::vector<Instruction *> SetjmpTableSizeInsts; + SmallVector<Instruction *, 4> SetjmpTableSizeInsts; // Setjmp preparation @@ -1012,11 +1220,13 @@ bool WebAssemblyLowerEmscriptenEHSjLj::runSjLjOnFunction(Function &F) { // We create this as an instruction intentionally, and we don't want to fold // this instruction to a constant 4, because this value will be used in // SSAUpdater.AddAvailableValue(...) later. - BasicBlock &EntryBB = F.getEntryBlock(); - DebugLoc FirstDL = getOrCreateDebugLoc(&*EntryBB.begin(), F.getSubprogram()); - BinaryOperator *SetjmpTableSize = BinaryOperator::Create( - Instruction::Add, IRB.getInt32(4), IRB.getInt32(0), "setjmpTableSize", - &*EntryBB.getFirstInsertionPt()); + BasicBlock *Entry = &F.getEntryBlock(); + DebugLoc FirstDL = getOrCreateDebugLoc(&*Entry->begin(), F.getSubprogram()); + SplitBlock(Entry, &*Entry->getFirstInsertionPt()); + + BinaryOperator *SetjmpTableSize = + BinaryOperator::Create(Instruction::Add, IRB.getInt32(4), IRB.getInt32(0), + "setjmpTableSize", Entry->getTerminator()); SetjmpTableSize->setDebugLoc(FirstDL); // setjmpTable = (int *) malloc(40); Instruction *SetjmpTable = CallInst::CreateMalloc( @@ -1036,13 +1246,14 @@ bool WebAssemblyLowerEmscriptenEHSjLj::runSjLjOnFunction(Function &F) { SetjmpTableSizeInsts.push_back(SetjmpTableSize); // Setjmp transformation - std::vector<PHINode *> SetjmpRetPHIs; + SmallVector<PHINode *, 4> SetjmpRetPHIs; Function *SetjmpF = M.getFunction("setjmp"); for (User *U : SetjmpF->users()) { auto *CI = dyn_cast<CallInst>(U); + // FIXME 'invoke' to setjmp can happen when we use Wasm EH + Wasm SjLj, but + // we don't support two being used together yet. if (!CI) - report_fatal_error("Does not support indirect calls to setjmp"); - + report_fatal_error("Wasm EH + Wasm SjLj is not fully supported yet"); BasicBlock *BB = CI->getParent(); if (BB->getParent() != &F) // in other function continue; @@ -1072,14 +1283,136 @@ bool WebAssemblyLowerEmscriptenEHSjLj::runSjLjOnFunction(Function &F) { Instruction *NewSetjmpTable = IRB.CreateCall(SaveSetjmpF, Args, "setjmpTable"); Instruction *NewSetjmpTableSize = - IRB.CreateCall(GetTempRet0Func, None, "setjmpTableSize"); + IRB.CreateCall(GetTempRet0F, None, "setjmpTableSize"); SetjmpTableInsts.push_back(NewSetjmpTable); SetjmpTableSizeInsts.push_back(NewSetjmpTableSize); ToErase.push_back(CI); } - // Update each call that can longjmp so it can return to a setjmp where - // relevant. + // Handle longjmpable calls. + if (EnableEmSjLj) + handleLongjmpableCallsForEmscriptenSjLj( + F, SetjmpTableInsts, SetjmpTableSizeInsts, SetjmpRetPHIs); + else // EnableWasmSjLj + handleLongjmpableCallsForWasmSjLj(F, SetjmpTableInsts, SetjmpTableSizeInsts, + SetjmpRetPHIs); + + // Erase everything we no longer need in this function + for (Instruction *I : ToErase) + I->eraseFromParent(); + + // Free setjmpTable buffer before each return instruction + function-exiting + // call + SmallVector<Instruction *, 16> ExitingInsts; + for (BasicBlock &BB : F) { + Instruction *TI = BB.getTerminator(); + if (isa<ReturnInst>(TI)) + ExitingInsts.push_back(TI); + // Any 'call' instruction with 'noreturn' attribute exits the function at + // this point. If this throws but unwinds to another EH pad within this + // function instead of exiting, this would have been an 'invoke', which + // happens if we use Wasm EH or Wasm SjLJ. + for (auto &I : BB) { + if (auto *CI = dyn_cast<CallInst>(&I)) { + bool IsNoReturn = CI->hasFnAttr(Attribute::NoReturn); + if (Function *CalleeF = CI->getCalledFunction()) + IsNoReturn |= CalleeF->hasFnAttribute(Attribute::NoReturn); + if (IsNoReturn) + ExitingInsts.push_back(&I); + } + } + } + for (auto *I : ExitingInsts) { + DebugLoc DL = getOrCreateDebugLoc(I, F.getSubprogram()); + // If this existing instruction is a call within a catchpad, we should add + // it as "funclet" to the operand bundle of 'free' call + SmallVector<OperandBundleDef, 1> Bundles; + if (auto *CB = dyn_cast<CallBase>(I)) + if (auto Bundle = CB->getOperandBundle(LLVMContext::OB_funclet)) + Bundles.push_back(OperandBundleDef(*Bundle)); + auto *Free = CallInst::CreateFree(SetjmpTable, Bundles, I); + Free->setDebugLoc(DL); + // CallInst::CreateFree may create a bitcast instruction if its argument + // types mismatch. We need to set the debug loc for the bitcast too. + if (auto *FreeCallI = dyn_cast<CallInst>(Free)) { + if (auto *BitCastI = dyn_cast<BitCastInst>(FreeCallI->getArgOperand(0))) + BitCastI->setDebugLoc(DL); + } + } + + // Every call to saveSetjmp can change setjmpTable and setjmpTableSize + // (when buffer reallocation occurs) + // entry: + // setjmpTableSize = 4; + // setjmpTable = (int *) malloc(40); + // setjmpTable[0] = 0; + // ... + // somebb: + // setjmpTable = saveSetjmp(env, label, setjmpTable, setjmpTableSize); + // setjmpTableSize = getTempRet0(); + // So we need to make sure the SSA for these variables is valid so that every + // saveSetjmp and testSetjmp calls have the correct arguments. + SSAUpdater SetjmpTableSSA; + SSAUpdater SetjmpTableSizeSSA; + SetjmpTableSSA.Initialize(Type::getInt32PtrTy(C), "setjmpTable"); + SetjmpTableSizeSSA.Initialize(Type::getInt32Ty(C), "setjmpTableSize"); + for (Instruction *I : SetjmpTableInsts) + SetjmpTableSSA.AddAvailableValue(I->getParent(), I); + for (Instruction *I : SetjmpTableSizeInsts) + SetjmpTableSizeSSA.AddAvailableValue(I->getParent(), I); + + for (auto &U : make_early_inc_range(SetjmpTable->uses())) + if (auto *I = dyn_cast<Instruction>(U.getUser())) + if (I->getParent() != Entry) + SetjmpTableSSA.RewriteUse(U); + for (auto &U : make_early_inc_range(SetjmpTableSize->uses())) + if (auto *I = dyn_cast<Instruction>(U.getUser())) + if (I->getParent() != Entry) + SetjmpTableSizeSSA.RewriteUse(U); + + // Finally, our modifications to the cfg can break dominance of SSA variables. + // For example, in this code, + // if (x()) { .. setjmp() .. } + // if (y()) { .. longjmp() .. } + // We must split the longjmp block, and it can jump into the block splitted + // from setjmp one. But that means that when we split the setjmp block, it's + // first part no longer dominates its second part - there is a theoretically + // possible control flow path where x() is false, then y() is true and we + // reach the second part of the setjmp block, without ever reaching the first + // part. So, we rebuild SSA form here. + rebuildSSA(F); + return true; +} + +// Update each call that can longjmp so it can return to the corresponding +// setjmp. Refer to 4) of "Emscripten setjmp/longjmp handling" section in the +// comments at top of the file for details. +void WebAssemblyLowerEmscriptenEHSjLj::handleLongjmpableCallsForEmscriptenSjLj( + Function &F, InstVector &SetjmpTableInsts, InstVector &SetjmpTableSizeInsts, + SmallVectorImpl<PHINode *> &SetjmpRetPHIs) { + Module &M = *F.getParent(); + LLVMContext &C = F.getContext(); + IRBuilder<> IRB(C); + SmallVector<Instruction *, 64> ToErase; + + // We need to pass setjmpTable and setjmpTableSize to testSetjmp function. + // These values are defined in the beginning of the function and also in each + // setjmp callsite, but we don't know which values we should use at this + // point. So here we arbitraily use the ones defined in the beginning of the + // function, and SSAUpdater will later update them to the correct values. + Instruction *SetjmpTable = *SetjmpTableInsts.begin(); + Instruction *SetjmpTableSize = *SetjmpTableSizeInsts.begin(); + + // call.em.longjmp BB that will be shared within the function. + BasicBlock *CallEmLongjmpBB = nullptr; + // PHI node for the loaded value of __THREW__ global variable in + // call.em.longjmp BB + PHINode *CallEmLongjmpBBThrewPHI = nullptr; + // PHI node for the loaded value of __threwValue global variable in + // call.em.longjmp BB + PHINode *CallEmLongjmpBBThrewValuePHI = nullptr; + // rethrow.exn BB that will be shared within the function. + BasicBlock *RethrowExnBB = nullptr; // Because we are creating new BBs while processing and don't want to make // all these newly created BBs candidates again for longjmp processing, we @@ -1092,15 +1425,18 @@ bool WebAssemblyLowerEmscriptenEHSjLj::runSjLjOnFunction(Function &F) { for (unsigned I = 0; I < BBs.size(); I++) { BasicBlock *BB = BBs[I]; for (Instruction &I : *BB) { - assert(!isa<InvokeInst>(&I)); + if (isa<InvokeInst>(&I)) + report_fatal_error("When using Wasm EH with Emscripten SjLj, there is " + "a restriction that `setjmp` function call and " + "exception cannot be used within the same function"); auto *CI = dyn_cast<CallInst>(&I); if (!CI) continue; const Value *Callee = CI->getCalledOperand(); - if (!canLongjmp(M, Callee)) + if (!canLongjmp(Callee)) continue; - if (isEmAsmCall(M, Callee)) + if (isEmAsmCall(Callee)) report_fatal_error("Cannot use EM_ASM* alongside setjmp/longjmp in " + F.getName() + ". Please consider using EM_JS, or move the " @@ -1171,19 +1507,26 @@ bool WebAssemblyLowerEmscriptenEHSjLj::runSjLjOnFunction(Function &F) { // tail: // ... if (supportsException(&F) && canThrow(Callee)) { - IRB.SetInsertPoint(CI); // We will add a new conditional branch. So remove the branch created // when we split the BB ToErase.push_back(BB->getTerminator()); + + // Generate rethrow.exn BB once and share it within the function + if (!RethrowExnBB) { + RethrowExnBB = BasicBlock::Create(C, "rethrow.exn", &F); + IRB.SetInsertPoint(RethrowExnBB); + CallInst *Exn = + IRB.CreateCall(getFindMatchingCatch(M, 0), {}, "exn"); + IRB.CreateCall(ResumeF, {Exn}); + IRB.CreateUnreachable(); + } + + IRB.SetInsertPoint(CI); BasicBlock *NormalBB = BasicBlock::Create(C, "normal", &F); - BasicBlock *RethrowBB = BasicBlock::Create(C, "eh.rethrow", &F); Value *CmpEqOne = IRB.CreateICmpEQ(Threw, getAddrSizeInt(&M, 1), "cmp.eq.one"); - IRB.CreateCondBr(CmpEqOne, RethrowBB, NormalBB); - IRB.SetInsertPoint(RethrowBB); - CallInst *Exn = IRB.CreateCall(getFindMatchingCatch(M, 0), {}, "exn"); - IRB.CreateCall(ResumeF, {Exn}); - IRB.CreateUnreachable(); + IRB.CreateCondBr(CmpEqOne, RethrowExnBB, NormalBB); + IRB.SetInsertPoint(NormalBB); IRB.CreateBr(Tail); BB = NormalBB; // New insertion point to insert testSetjmp() @@ -1202,7 +1545,9 @@ bool WebAssemblyLowerEmscriptenEHSjLj::runSjLjOnFunction(Function &F) { Value *LongjmpResult = nullptr; BasicBlock *EndBB = nullptr; wrapTestSetjmp(BB, CI->getDebugLoc(), Threw, SetjmpTable, SetjmpTableSize, - Label, LongjmpResult, EndBB); + Label, LongjmpResult, CallEmLongjmpBB, + CallEmLongjmpBBThrewPHI, CallEmLongjmpBBThrewValuePHI, + EndBB); assert(Label && LongjmpResult && EndBB); // Create switch instruction @@ -1224,76 +1569,184 @@ bool WebAssemblyLowerEmscriptenEHSjLj::runSjLjOnFunction(Function &F) { } } - // Erase everything we no longer need in this function for (Instruction *I : ToErase) I->eraseFromParent(); +} - // Free setjmpTable buffer before each return instruction - for (BasicBlock &BB : F) { - Instruction *TI = BB.getTerminator(); - if (isa<ReturnInst>(TI)) { - DebugLoc DL = getOrCreateDebugLoc(TI, F.getSubprogram()); - auto *Free = CallInst::CreateFree(SetjmpTable, TI); - Free->setDebugLoc(DL); - // CallInst::CreateFree may create a bitcast instruction if its argument - // types mismatch. We need to set the debug loc for the bitcast too. - if (auto *FreeCallI = dyn_cast<CallInst>(Free)) { - if (auto *BitCastI = dyn_cast<BitCastInst>(FreeCallI->getArgOperand(0))) - BitCastI->setDebugLoc(DL); - } - } +// Create a catchpad in which we catch a longjmp's env and val arguments, test +// if the longjmp corresponds to one of setjmps in the current function, and if +// so, jump to the setjmp dispatch BB from which we go to one of post-setjmp +// BBs. Refer to 4) of "Wasm setjmp/longjmp handling" section in the comments at +// top of the file for details. +void WebAssemblyLowerEmscriptenEHSjLj::handleLongjmpableCallsForWasmSjLj( + Function &F, InstVector &SetjmpTableInsts, InstVector &SetjmpTableSizeInsts, + SmallVectorImpl<PHINode *> &SetjmpRetPHIs) { + Module &M = *F.getParent(); + LLVMContext &C = F.getContext(); + IRBuilder<> IRB(C); + + // A function with catchswitch/catchpad instruction should have a personality + // function attached to it. Search for the wasm personality function, and if + // it exists, use it, and if it doesn't, create a dummy personality function. + // (SjLj is not going to call it anyway.) + if (!F.hasPersonalityFn()) { + StringRef PersName = getEHPersonalityName(EHPersonality::Wasm_CXX); + FunctionType *PersType = + FunctionType::get(IRB.getInt32Ty(), /* isVarArg */ true); + Value *PersF = M.getOrInsertFunction(PersName, PersType).getCallee(); + F.setPersonalityFn( + cast<Constant>(IRB.CreateBitCast(PersF, IRB.getInt8PtrTy()))); } - // Every call to saveSetjmp can change setjmpTable and setjmpTableSize - // (when buffer reallocation occurs) + // Use the entry BB's debugloc as a fallback + BasicBlock *Entry = &F.getEntryBlock(); + DebugLoc FirstDL = getOrCreateDebugLoc(&*Entry->begin(), F.getSubprogram()); + IRB.SetCurrentDebugLocation(FirstDL); + + // Arbitrarily use the ones defined in the beginning of the function. + // SSAUpdater will later update them to the correct values. + Instruction *SetjmpTable = *SetjmpTableInsts.begin(); + Instruction *SetjmpTableSize = *SetjmpTableSizeInsts.begin(); + + // Add setjmp.dispatch BB right after the entry block. Because we have + // initialized setjmpTable/setjmpTableSize in the entry block and split the + // rest into another BB, here 'OrigEntry' is the function's original entry + // block before the transformation. + // // entry: - // setjmpTableSize = 4; - // setjmpTable = (int *) malloc(40); - // setjmpTable[0] = 0; - // ... - // somebb: - // setjmpTable = saveSetjmp(buf, label, setjmpTable, setjmpTableSize); - // setjmpTableSize = getTempRet0(); - // So we need to make sure the SSA for these variables is valid so that every - // saveSetjmp and testSetjmp calls have the correct arguments. - SSAUpdater SetjmpTableSSA; - SSAUpdater SetjmpTableSizeSSA; - SetjmpTableSSA.Initialize(Type::getInt32PtrTy(C), "setjmpTable"); - SetjmpTableSizeSSA.Initialize(Type::getInt32Ty(C), "setjmpTableSize"); - for (Instruction *I : SetjmpTableInsts) - SetjmpTableSSA.AddAvailableValue(I->getParent(), I); - for (Instruction *I : SetjmpTableSizeInsts) - SetjmpTableSizeSSA.AddAvailableValue(I->getParent(), I); + // setjmpTable / setjmpTableSize initialization + // setjmp.dispatch: + // switch will be inserted here later + // entry.split: (OrigEntry) + // the original function starts here + BasicBlock *OrigEntry = Entry->getNextNode(); + BasicBlock *SetjmpDispatchBB = + BasicBlock::Create(C, "setjmp.dispatch", &F, OrigEntry); + cast<BranchInst>(Entry->getTerminator())->setSuccessor(0, SetjmpDispatchBB); + + // Create catch.dispatch.longjmp BB a catchswitch instruction + BasicBlock *CatchSwitchBB = + BasicBlock::Create(C, "catch.dispatch.longjmp", &F); + IRB.SetInsertPoint(CatchSwitchBB); + CatchSwitchInst *CatchSwitch = + IRB.CreateCatchSwitch(ConstantTokenNone::get(C), nullptr, 1); + + // Create catch.longjmp BB and a catchpad instruction + BasicBlock *CatchLongjmpBB = BasicBlock::Create(C, "catch.longjmp", &F); + CatchSwitch->addHandler(CatchLongjmpBB); + IRB.SetInsertPoint(CatchLongjmpBB); + CatchPadInst *CatchPad = IRB.CreateCatchPad(CatchSwitch, {}); + + // Wasm throw and catch instructions can throw and catch multiple values, but + // that requires multivalue support in the toolchain, which is currently not + // very reliable. We instead throw and catch a pointer to a struct value of + // type 'struct __WasmLongjmpArgs', which is defined in Emscripten. + Instruction *CatchCI = + IRB.CreateCall(CatchF, {IRB.getInt32(WebAssembly::C_LONGJMP)}, "thrown"); + Value *LongjmpArgs = + IRB.CreateBitCast(CatchCI, LongjmpArgsTy->getPointerTo(), "longjmp.args"); + Value *EnvField = + IRB.CreateConstGEP2_32(LongjmpArgsTy, LongjmpArgs, 0, 0, "env_gep"); + Value *ValField = + IRB.CreateConstGEP2_32(LongjmpArgsTy, LongjmpArgs, 0, 1, "val_gep"); + // void *env = __wasm_longjmp_args.env; + Instruction *Env = IRB.CreateLoad(IRB.getInt8PtrTy(), EnvField, "env"); + // int val = __wasm_longjmp_args.val; + Instruction *Val = IRB.CreateLoad(IRB.getInt32Ty(), ValField, "val"); + + // %label = testSetjmp(mem[%env], setjmpTable, setjmpTableSize); + // if (%label == 0) + // __wasm_longjmp(%env, %val) + // catchret to %setjmp.dispatch + BasicBlock *ThenBB = BasicBlock::Create(C, "if.then", &F); + BasicBlock *EndBB = BasicBlock::Create(C, "if.end", &F); + Value *EnvP = IRB.CreateBitCast(Env, getAddrPtrType(&M), "env.p"); + Value *SetjmpID = IRB.CreateLoad(getAddrIntType(&M), EnvP, "setjmp.id"); + Value *Label = + IRB.CreateCall(TestSetjmpF, {SetjmpID, SetjmpTable, SetjmpTableSize}, + OperandBundleDef("funclet", CatchPad), "label"); + Value *Cmp = IRB.CreateICmpEQ(Label, IRB.getInt32(0)); + IRB.CreateCondBr(Cmp, ThenBB, EndBB); + + IRB.SetInsertPoint(ThenBB); + CallInst *WasmLongjmpCI = IRB.CreateCall( + WasmLongjmpF, {Env, Val}, OperandBundleDef("funclet", CatchPad)); + IRB.CreateUnreachable(); - for (auto UI = SetjmpTable->use_begin(), UE = SetjmpTable->use_end(); - UI != UE;) { - // Grab the use before incrementing the iterator. - Use &U = *UI; - // Increment the iterator before removing the use from the list. - ++UI; - if (auto *I = dyn_cast<Instruction>(U.getUser())) - if (I->getParent() != &EntryBB) - SetjmpTableSSA.RewriteUse(U); + IRB.SetInsertPoint(EndBB); + // Jump to setjmp.dispatch block + IRB.CreateCatchRet(CatchPad, SetjmpDispatchBB); + + // Go back to setjmp.dispatch BB + // setjmp.dispatch: + // switch %label { + // label 1: goto post-setjmp BB 1 + // label 2: goto post-setjmp BB 2 + // ... + // default: goto splitted next BB + // } + IRB.SetInsertPoint(SetjmpDispatchBB); + PHINode *LabelPHI = IRB.CreatePHI(IRB.getInt32Ty(), 2, "label.phi"); + LabelPHI->addIncoming(Label, EndBB); + LabelPHI->addIncoming(IRB.getInt32(-1), Entry); + SwitchInst *SI = IRB.CreateSwitch(LabelPHI, OrigEntry, SetjmpRetPHIs.size()); + // -1 means no longjmp happened, continue normally (will hit the default + // switch case). 0 means a longjmp that is not ours to handle, needs a + // rethrow. Otherwise the index is the same as the index in P+1 (to avoid + // 0). + for (unsigned I = 0; I < SetjmpRetPHIs.size(); I++) { + SI->addCase(IRB.getInt32(I + 1), SetjmpRetPHIs[I]->getParent()); + SetjmpRetPHIs[I]->addIncoming(Val, SetjmpDispatchBB); } - for (auto UI = SetjmpTableSize->use_begin(), UE = SetjmpTableSize->use_end(); - UI != UE;) { - Use &U = *UI; - ++UI; - if (auto *I = dyn_cast<Instruction>(U.getUser())) - if (I->getParent() != &EntryBB) - SetjmpTableSizeSSA.RewriteUse(U); + + // Convert all longjmpable call instructions to invokes that unwind to the + // newly created catch.dispatch.longjmp BB. + SmallVector<Instruction *, 64> ToErase; + for (auto *BB = &*F.begin(); BB; BB = BB->getNextNode()) { + for (Instruction &I : *BB) { + auto *CI = dyn_cast<CallInst>(&I); + if (!CI) + continue; + const Value *Callee = CI->getCalledOperand(); + if (!canLongjmp(Callee)) + continue; + if (isEmAsmCall(Callee)) + report_fatal_error("Cannot use EM_ASM* alongside setjmp/longjmp in " + + F.getName() + + ". Please consider using EM_JS, or move the " + "EM_ASM into another function.", + false); + // This is __wasm_longjmp() call we inserted in this function, which + // rethrows the longjmp when the longjmp does not correspond to one of + // setjmps in this function. We should not convert this call to an invoke. + if (CI == WasmLongjmpCI) + continue; + ToErase.push_back(CI); + + // Even if the callee function has attribute 'nounwind', which is true for + // all C functions, it can longjmp, which means it can throw a Wasm + // exception now. + CI->removeFnAttr(Attribute::NoUnwind); + if (Function *CalleeF = CI->getCalledFunction()) { + CalleeF->removeFnAttr(Attribute::NoUnwind); + } + + IRB.SetInsertPoint(CI); + BasicBlock *Tail = SplitBlock(BB, CI->getNextNode()); + // We will add a new invoke. So remove the branch created when we split + // the BB + ToErase.push_back(BB->getTerminator()); + SmallVector<Value *, 8> Args(CI->args()); + InvokeInst *II = + IRB.CreateInvoke(CI->getFunctionType(), CI->getCalledOperand(), Tail, + CatchSwitchBB, Args); + II->takeName(CI); + II->setDebugLoc(CI->getDebugLoc()); + II->setAttributes(CI->getAttributes()); + CI->replaceAllUsesWith(II); + } } - // Finally, our modifications to the cfg can break dominance of SSA variables. - // For example, in this code, - // if (x()) { .. setjmp() .. } - // if (y()) { .. longjmp() .. } - // We must split the longjmp block, and it can jump into the block splitted - // from setjmp one. But that means that when we split the setjmp block, it's - // first part no longer dominates its second part - there is a theoretically - // possible control flow path where x() is false, then y() is true and we - // reach the second part of the setjmp block, without ever reaching the first - // part. So, we rebuild SSA form here. - rebuildSSA(F); - return true; + for (Instruction *I : ToErase) + I->eraseFromParent(); } |