summaryrefslogtreecommitdiff
path: root/lib/CodeGen/CGOpenMPRuntimeNVPTX.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'lib/CodeGen/CGOpenMPRuntimeNVPTX.cpp')
-rw-r--r--lib/CodeGen/CGOpenMPRuntimeNVPTX.cpp223
1 files changed, 216 insertions, 7 deletions
diff --git a/lib/CodeGen/CGOpenMPRuntimeNVPTX.cpp b/lib/CodeGen/CGOpenMPRuntimeNVPTX.cpp
index bc1458b1c203..6a6d832e33cd 100644
--- a/lib/CodeGen/CGOpenMPRuntimeNVPTX.cpp
+++ b/lib/CodeGen/CGOpenMPRuntimeNVPTX.cpp
@@ -26,8 +26,57 @@ enum OpenMPRTLFunctionNVPTX {
OMPRTL_NVPTX__kmpc_kernel_init,
/// \brief Call to void __kmpc_kernel_deinit();
OMPRTL_NVPTX__kmpc_kernel_deinit,
+ /// \brief Call to void __kmpc_kernel_prepare_parallel(void
+ /// *outlined_function);
+ OMPRTL_NVPTX__kmpc_kernel_prepare_parallel,
+ /// \brief Call to bool __kmpc_kernel_parallel(void **outlined_function);
+ OMPRTL_NVPTX__kmpc_kernel_parallel,
+ /// \brief Call to void __kmpc_kernel_end_parallel();
+ OMPRTL_NVPTX__kmpc_kernel_end_parallel,
+ /// Call to void __kmpc_serialized_parallel(ident_t *loc, kmp_int32
+ /// global_tid);
+ OMPRTL_NVPTX__kmpc_serialized_parallel,
+ /// Call to void __kmpc_end_serialized_parallel(ident_t *loc, kmp_int32
+ /// global_tid);
+ OMPRTL_NVPTX__kmpc_end_serialized_parallel,
};
-} // namespace
+
+/// Pre(post)-action for different OpenMP constructs specialized for NVPTX.
+class NVPTXActionTy final : public PrePostActionTy {
+ llvm::Value *EnterCallee;
+ ArrayRef<llvm::Value *> EnterArgs;
+ llvm::Value *ExitCallee;
+ ArrayRef<llvm::Value *> ExitArgs;
+ bool Conditional;
+ llvm::BasicBlock *ContBlock = nullptr;
+
+public:
+ NVPTXActionTy(llvm::Value *EnterCallee, ArrayRef<llvm::Value *> EnterArgs,
+ llvm::Value *ExitCallee, ArrayRef<llvm::Value *> ExitArgs,
+ bool Conditional = false)
+ : EnterCallee(EnterCallee), EnterArgs(EnterArgs), ExitCallee(ExitCallee),
+ ExitArgs(ExitArgs), Conditional(Conditional) {}
+ void Enter(CodeGenFunction &CGF) override {
+ llvm::Value *EnterRes = CGF.EmitRuntimeCall(EnterCallee, EnterArgs);
+ if (Conditional) {
+ llvm::Value *CallBool = CGF.Builder.CreateIsNotNull(EnterRes);
+ auto *ThenBlock = CGF.createBasicBlock("omp_if.then");
+ ContBlock = CGF.createBasicBlock("omp_if.end");
+ // Generate the branch (If-stmt)
+ CGF.Builder.CreateCondBr(CallBool, ThenBlock, ContBlock);
+ CGF.EmitBlock(ThenBlock);
+ }
+ }
+ void Done(CodeGenFunction &CGF) {
+ // Emit the rest of blocks/branches
+ CGF.EmitBranch(ContBlock);
+ CGF.EmitBlock(ContBlock, true);
+ }
+ void Exit(CodeGenFunction &CGF) override {
+ CGF.EmitRuntimeCall(ExitCallee, ExitArgs);
+ }
+};
+} // anonymous namespace
/// Get the GPU warp size.
static llvm::Value *getNVPTXWarpSize(CodeGenFunction &CGF) {
@@ -118,6 +167,7 @@ void CGOpenMPRuntimeNVPTX::emitGenericKernel(const OMPExecutableDirective &D,
const RegionCodeGenTy &CodeGen) {
EntryFunctionState EST;
WorkerFunctionState WST(CGM);
+ Work.clear();
// Emit target region as a standalone region.
class NVPTXPrePostActionTy : public PrePostActionTy {
@@ -246,7 +296,10 @@ void CGOpenMPRuntimeNVPTX::emitWorkerLoop(CodeGenFunction &CGF,
CGF.InitTempAlloca(ExecStatus, Bld.getInt8(/*C=*/0));
CGF.InitTempAlloca(WorkFn, llvm::Constant::getNullValue(CGF.Int8PtrTy));
- // TODO: Call into runtime to get parallel work.
+ llvm::Value *Args[] = {WorkFn.getPointer()};
+ llvm::Value *Ret = CGF.EmitRuntimeCall(
+ createNVPTXRuntimeFunction(OMPRTL_NVPTX__kmpc_kernel_parallel), Args);
+ Bld.CreateStore(Bld.CreateZExt(Ret, CGF.Int8Ty), ExecStatus);
// On termination condition (workid == 0), exit loop.
llvm::Value *ShouldTerminate =
@@ -261,10 +314,42 @@ void CGOpenMPRuntimeNVPTX::emitWorkerLoop(CodeGenFunction &CGF,
// Signal start of parallel region.
CGF.EmitBlock(ExecuteBB);
- // TODO: Add parallel work.
+
+ // Process work items: outlined parallel functions.
+ for (auto *W : Work) {
+ // Try to match this outlined function.
+ auto *ID = Bld.CreatePointerBitCastOrAddrSpaceCast(W, CGM.Int8PtrTy);
+
+ llvm::Value *WorkFnMatch =
+ Bld.CreateICmpEQ(Bld.CreateLoad(WorkFn), ID, "work_match");
+
+ llvm::BasicBlock *ExecuteFNBB = CGF.createBasicBlock(".execute.fn");
+ llvm::BasicBlock *CheckNextBB = CGF.createBasicBlock(".check.next");
+ Bld.CreateCondBr(WorkFnMatch, ExecuteFNBB, CheckNextBB);
+
+ // Execute this outlined function.
+ CGF.EmitBlock(ExecuteFNBB);
+
+ // Insert call to work function.
+ // FIXME: Pass arguments to outlined function from master thread.
+ auto *Fn = cast<llvm::Function>(W);
+ Address ZeroAddr =
+ CGF.CreateDefaultAlignTempAlloca(CGF.Int32Ty, /*Name=*/".zero.addr");
+ CGF.InitTempAlloca(ZeroAddr, CGF.Builder.getInt32(/*C=*/0));
+ llvm::Value *FnArgs[] = {ZeroAddr.getPointer(), ZeroAddr.getPointer()};
+ CGF.EmitCallOrInvoke(Fn, FnArgs);
+
+ // Go to end of parallel region.
+ CGF.EmitBranch(TerminateBB);
+
+ CGF.EmitBlock(CheckNextBB);
+ }
// Signal end of parallel region.
CGF.EmitBlock(TerminateBB);
+ CGF.EmitRuntimeCall(
+ createNVPTXRuntimeFunction(OMPRTL_NVPTX__kmpc_kernel_end_parallel),
+ llvm::None);
CGF.EmitBranch(BarrierBB);
// All active and inactive workers wait at a barrier after parallel region.
@@ -296,10 +381,53 @@ CGOpenMPRuntimeNVPTX::createNVPTXRuntimeFunction(unsigned Function) {
case OMPRTL_NVPTX__kmpc_kernel_deinit: {
// Build void __kmpc_kernel_deinit();
llvm::FunctionType *FnTy =
- llvm::FunctionType::get(CGM.VoidTy, {}, /*isVarArg*/ false);
+ llvm::FunctionType::get(CGM.VoidTy, llvm::None, /*isVarArg*/ false);
RTLFn = CGM.CreateRuntimeFunction(FnTy, "__kmpc_kernel_deinit");
break;
}
+ case OMPRTL_NVPTX__kmpc_kernel_prepare_parallel: {
+ /// Build void __kmpc_kernel_prepare_parallel(
+ /// void *outlined_function);
+ llvm::Type *TypeParams[] = {CGM.Int8PtrTy};
+ llvm::FunctionType *FnTy =
+ llvm::FunctionType::get(CGM.VoidTy, TypeParams, /*isVarArg*/ false);
+ RTLFn = CGM.CreateRuntimeFunction(FnTy, "__kmpc_kernel_prepare_parallel");
+ break;
+ }
+ case OMPRTL_NVPTX__kmpc_kernel_parallel: {
+ /// Build bool __kmpc_kernel_parallel(void **outlined_function);
+ llvm::Type *TypeParams[] = {CGM.Int8PtrPtrTy};
+ llvm::Type *RetTy = CGM.getTypes().ConvertType(CGM.getContext().BoolTy);
+ llvm::FunctionType *FnTy =
+ llvm::FunctionType::get(RetTy, TypeParams, /*isVarArg*/ false);
+ RTLFn = CGM.CreateRuntimeFunction(FnTy, "__kmpc_kernel_parallel");
+ break;
+ }
+ case OMPRTL_NVPTX__kmpc_kernel_end_parallel: {
+ /// Build void __kmpc_kernel_end_parallel();
+ llvm::FunctionType *FnTy =
+ llvm::FunctionType::get(CGM.VoidTy, llvm::None, /*isVarArg*/ false);
+ RTLFn = CGM.CreateRuntimeFunction(FnTy, "__kmpc_kernel_end_parallel");
+ break;
+ }
+ case OMPRTL_NVPTX__kmpc_serialized_parallel: {
+ // Build void __kmpc_serialized_parallel(ident_t *loc, kmp_int32
+ // global_tid);
+ llvm::Type *TypeParams[] = {getIdentTyPointerTy(), CGM.Int32Ty};
+ llvm::FunctionType *FnTy =
+ llvm::FunctionType::get(CGM.VoidTy, TypeParams, /*isVarArg*/ false);
+ RTLFn = CGM.CreateRuntimeFunction(FnTy, "__kmpc_serialized_parallel");
+ break;
+ }
+ case OMPRTL_NVPTX__kmpc_end_serialized_parallel: {
+ // Build void __kmpc_end_serialized_parallel(ident_t *loc, kmp_int32
+ // global_tid);
+ llvm::Type *TypeParams[] = {getIdentTyPointerTy(), CGM.Int32Ty};
+ llvm::FunctionType *FnTy =
+ llvm::FunctionType::get(CGM.VoidTy, TypeParams, /*isVarArg*/ false);
+ RTLFn = CGM.CreateRuntimeFunction(FnTy, "__kmpc_end_serialized_parallel");
+ break;
+ }
}
return RTLFn;
}
@@ -362,9 +490,12 @@ llvm::Value *CGOpenMPRuntimeNVPTX::emitParallelOrTeamsOutlinedFunction(
OutlinedFun = cast<llvm::Function>(OutlinedFunVal);
OutlinedFun->removeFnAttr(llvm::Attribute::NoInline);
OutlinedFun->addFnAttr(llvm::Attribute::AlwaysInline);
- } else
- llvm_unreachable("parallel directive is not yet supported for nvptx "
- "backend.");
+ } else {
+ llvm::Value *OutlinedFunVal =
+ CGOpenMPRuntime::emitParallelOrTeamsOutlinedFunction(
+ D, ThreadIDVar, InnermostKind, CodeGen);
+ OutlinedFun = cast<llvm::Function>(OutlinedFunVal);
+ }
return OutlinedFun;
}
@@ -387,3 +518,81 @@ void CGOpenMPRuntimeNVPTX::emitTeamsCall(CodeGenFunction &CGF,
OutlinedFnArgs.append(CapturedVars.begin(), CapturedVars.end());
CGF.EmitCallOrInvoke(OutlinedFn, OutlinedFnArgs);
}
+
+void CGOpenMPRuntimeNVPTX::emitParallelCall(
+ CodeGenFunction &CGF, SourceLocation Loc, llvm::Value *OutlinedFn,
+ ArrayRef<llvm::Value *> CapturedVars, const Expr *IfCond) {
+ if (!CGF.HaveInsertPoint())
+ return;
+
+ emitGenericParallelCall(CGF, Loc, OutlinedFn, CapturedVars, IfCond);
+}
+
+void CGOpenMPRuntimeNVPTX::emitGenericParallelCall(
+ CodeGenFunction &CGF, SourceLocation Loc, llvm::Value *OutlinedFn,
+ ArrayRef<llvm::Value *> CapturedVars, const Expr *IfCond) {
+ llvm::Function *Fn = cast<llvm::Function>(OutlinedFn);
+
+ auto &&L0ParallelGen = [this, Fn, &CapturedVars](CodeGenFunction &CGF,
+ PrePostActionTy &) {
+ CGBuilderTy &Bld = CGF.Builder;
+
+ // Prepare for parallel region. Indicate the outlined function.
+ llvm::Value *Args[] = {Bld.CreateBitOrPointerCast(Fn, CGM.Int8PtrTy)};
+ CGF.EmitRuntimeCall(
+ createNVPTXRuntimeFunction(OMPRTL_NVPTX__kmpc_kernel_prepare_parallel),
+ Args);
+
+ // Activate workers. This barrier is used by the master to signal
+ // work for the workers.
+ syncCTAThreads(CGF);
+
+ // OpenMP [2.5, Parallel Construct, p.49]
+ // There is an implied barrier at the end of a parallel region. After the
+ // end of a parallel region, only the master thread of the team resumes
+ // execution of the enclosing task region.
+ //
+ // The master waits at this barrier until all workers are done.
+ syncCTAThreads(CGF);
+
+ // Remember for post-processing in worker loop.
+ Work.push_back(Fn);
+ };
+
+ auto *RTLoc = emitUpdateLocation(CGF, Loc);
+ auto *ThreadID = getThreadID(CGF, Loc);
+ llvm::Value *Args[] = {RTLoc, ThreadID};
+
+ auto &&SeqGen = [this, Fn, &CapturedVars, &Args](CodeGenFunction &CGF,
+ PrePostActionTy &) {
+ auto &&CodeGen = [this, Fn, &CapturedVars, &Args](CodeGenFunction &CGF,
+ PrePostActionTy &Action) {
+ Action.Enter(CGF);
+
+ llvm::SmallVector<llvm::Value *, 16> OutlinedFnArgs;
+ OutlinedFnArgs.push_back(
+ llvm::ConstantPointerNull::get(CGM.Int32Ty->getPointerTo()));
+ OutlinedFnArgs.push_back(
+ llvm::ConstantPointerNull::get(CGM.Int32Ty->getPointerTo()));
+ OutlinedFnArgs.append(CapturedVars.begin(), CapturedVars.end());
+ CGF.EmitCallOrInvoke(Fn, OutlinedFnArgs);
+ };
+
+ RegionCodeGenTy RCG(CodeGen);
+ NVPTXActionTy Action(
+ createNVPTXRuntimeFunction(OMPRTL_NVPTX__kmpc_serialized_parallel),
+ Args,
+ createNVPTXRuntimeFunction(OMPRTL_NVPTX__kmpc_end_serialized_parallel),
+ Args);
+ RCG.setAction(Action);
+ RCG(CGF);
+ };
+
+ if (IfCond)
+ emitOMPIfClause(CGF, IfCond, L0ParallelGen, SeqGen);
+ else {
+ CodeGenFunction::RunCleanupsScope Scope(CGF);
+ RegionCodeGenTy ThenRCG(L0ParallelGen);
+ ThenRCG(CGF);
+ }
+}