summaryrefslogtreecommitdiff
path: root/llvm/lib/Target/AMDGPU/AMDGPUReplaceLDSUseWithPointer.cpp
diff options
context:
space:
mode:
authorDimitry Andric <dim@FreeBSD.org>2021-07-29 20:15:26 +0000
committerDimitry Andric <dim@FreeBSD.org>2021-07-29 20:15:26 +0000
commit344a3780b2e33f6ca763666c380202b18aab72a3 (patch)
treef0b203ee6eb71d7fdd792373e3c81eb18d6934dd /llvm/lib/Target/AMDGPU/AMDGPUReplaceLDSUseWithPointer.cpp
parentb60736ec1405bb0a8dd40989f67ef4c93da068ab (diff)
Diffstat (limited to 'llvm/lib/Target/AMDGPU/AMDGPUReplaceLDSUseWithPointer.cpp')
-rw-r--r--llvm/lib/Target/AMDGPU/AMDGPUReplaceLDSUseWithPointer.cpp460
1 files changed, 460 insertions, 0 deletions
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUReplaceLDSUseWithPointer.cpp b/llvm/lib/Target/AMDGPU/AMDGPUReplaceLDSUseWithPointer.cpp
new file mode 100644
index 000000000000..dabb4d006d99
--- /dev/null
+++ b/llvm/lib/Target/AMDGPU/AMDGPUReplaceLDSUseWithPointer.cpp
@@ -0,0 +1,460 @@
+//===-- AMDGPUReplaceLDSUseWithPointer.cpp --------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass replaces all the uses of LDS within non-kernel functions by
+// corresponding pointer counter-parts.
+//
+// The main motivation behind this pass is - to *avoid* subsequent LDS lowering
+// pass from directly packing LDS (assume large LDS) into a struct type which
+// would otherwise cause allocating huge memory for struct instance within every
+// kernel.
+//
+// Brief sketch of the algorithm implemented in this pass is as below:
+//
+// 1. Collect all the LDS defined in the module which qualify for pointer
+// replacement, say it is, LDSGlobals set.
+//
+// 2. Collect all the reachable callees for each kernel defined in the module,
+// say it is, KernelToCallees map.
+//
+// 3. FOR (each global GV from LDSGlobals set) DO
+// LDSUsedNonKernels = Collect all non-kernel functions which use GV.
+// FOR (each kernel K in KernelToCallees map) DO
+// ReachableCallees = KernelToCallees[K]
+// ReachableAndLDSUsedCallees =
+// SetIntersect(LDSUsedNonKernels, ReachableCallees)
+// IF (ReachableAndLDSUsedCallees is not empty) THEN
+// Pointer = Create a pointer to point-to GV if not created.
+// Initialize Pointer to point-to GV within kernel K.
+// ENDIF
+// ENDFOR
+// Replace all uses of GV within non kernel functions by Pointer.
+// ENFOR
+//
+// LLVM IR example:
+//
+// Input IR:
+//
+// @lds = internal addrspace(3) global [4 x i32] undef, align 16
+//
+// define internal void @f0() {
+// entry:
+// %gep = getelementptr inbounds [4 x i32], [4 x i32] addrspace(3)* @lds,
+// i32 0, i32 0
+// ret void
+// }
+//
+// define protected amdgpu_kernel void @k0() {
+// entry:
+// call void @f0()
+// ret void
+// }
+//
+// Output IR:
+//
+// @lds = internal addrspace(3) global [4 x i32] undef, align 16
+// @lds.ptr = internal unnamed_addr addrspace(3) global i16 undef, align 2
+//
+// define internal void @f0() {
+// entry:
+// %0 = load i16, i16 addrspace(3)* @lds.ptr, align 2
+// %1 = getelementptr i8, i8 addrspace(3)* null, i16 %0
+// %2 = bitcast i8 addrspace(3)* %1 to [4 x i32] addrspace(3)*
+// %gep = getelementptr inbounds [4 x i32], [4 x i32] addrspace(3)* %2,
+// i32 0, i32 0
+// ret void
+// }
+//
+// define protected amdgpu_kernel void @k0() {
+// entry:
+// store i16 ptrtoint ([4 x i32] addrspace(3)* @lds to i16),
+// i16 addrspace(3)* @lds.ptr, align 2
+// call void @f0()
+// ret void
+// }
+//
+//===----------------------------------------------------------------------===//
+
+#include "AMDGPU.h"
+#include "GCNSubtarget.h"
+#include "Utils/AMDGPUBaseInfo.h"
+#include "Utils/AMDGPULDSUtils.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SetOperations.h"
+#include "llvm/CodeGen/TargetPassConfig.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/InlineAsm.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicsAMDGPU.h"
+#include "llvm/IR/ReplaceConstant.h"
+#include "llvm/InitializePasses.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Target/TargetMachine.h"
+#include "llvm/Transforms/Utils/BasicBlockUtils.h"
+#include "llvm/Transforms/Utils/ModuleUtils.h"
+#include <algorithm>
+#include <vector>
+
+#define DEBUG_TYPE "amdgpu-replace-lds-use-with-pointer"
+
+using namespace llvm;
+
+namespace {
+
+class ReplaceLDSUseImpl {
+ Module &M;
+ LLVMContext &Ctx;
+ const DataLayout &DL;
+ Constant *LDSMemBaseAddr;
+
+ DenseMap<GlobalVariable *, GlobalVariable *> LDSToPointer;
+ DenseMap<GlobalVariable *, SmallPtrSet<Function *, 8>> LDSToNonKernels;
+ DenseMap<Function *, SmallPtrSet<Function *, 8>> KernelToCallees;
+ DenseMap<Function *, SmallPtrSet<GlobalVariable *, 8>> KernelToLDSPointers;
+ DenseMap<Function *, BasicBlock *> KernelToInitBB;
+ DenseMap<Function *, DenseMap<GlobalVariable *, Value *>>
+ FunctionToLDSToReplaceInst;
+
+ // Collect LDS which requires their uses to be replaced by pointer.
+ std::vector<GlobalVariable *> collectLDSRequiringPointerReplace() {
+ // Collect LDS which requires module lowering.
+ std::vector<GlobalVariable *> LDSGlobals = AMDGPU::findVariablesToLower(M);
+
+ // Remove LDS which don't qualify for replacement.
+ LDSGlobals.erase(std::remove_if(LDSGlobals.begin(), LDSGlobals.end(),
+ [&](GlobalVariable *GV) {
+ return shouldIgnorePointerReplacement(GV);
+ }),
+ LDSGlobals.end());
+
+ return LDSGlobals;
+ }
+
+ // Returns true if uses of given LDS global within non-kernel functions should
+ // be keep as it is without pointer replacement.
+ bool shouldIgnorePointerReplacement(GlobalVariable *GV) {
+ // LDS whose size is very small and doesn`t exceed pointer size is not worth
+ // replacing.
+ if (DL.getTypeAllocSize(GV->getValueType()) <= 2)
+ return true;
+
+ // LDS which is not used from non-kernel function scope or it is used from
+ // global scope does not qualify for replacement.
+ LDSToNonKernels[GV] = AMDGPU::collectNonKernelAccessorsOfLDS(GV);
+ return LDSToNonKernels[GV].empty();
+
+ // FIXME: When GV is used within all (or within most of the kernels), then
+ // it does not make sense to create a pointer for it.
+ }
+
+ // Insert new global LDS pointer which points to LDS.
+ GlobalVariable *createLDSPointer(GlobalVariable *GV) {
+ // LDS pointer which points to LDS is already created? return it.
+ auto PointerEntry = LDSToPointer.insert(std::make_pair(GV, nullptr));
+ if (!PointerEntry.second)
+ return PointerEntry.first->second;
+
+ // We need to create new LDS pointer which points to LDS.
+ //
+ // Each CU owns at max 64K of LDS memory, so LDS address ranges from 0 to
+ // 2^16 - 1. Hence 16 bit pointer is enough to hold the LDS address.
+ auto *I16Ty = Type::getInt16Ty(Ctx);
+ GlobalVariable *LDSPointer = new GlobalVariable(
+ M, I16Ty, false, GlobalValue::InternalLinkage, UndefValue::get(I16Ty),
+ GV->getName() + Twine(".ptr"), nullptr, GlobalVariable::NotThreadLocal,
+ AMDGPUAS::LOCAL_ADDRESS);
+
+ LDSPointer->setUnnamedAddr(GlobalValue::UnnamedAddr::Global);
+ LDSPointer->setAlignment(AMDGPU::getAlign(DL, LDSPointer));
+
+ // Mark that an associated LDS pointer is created for LDS.
+ LDSToPointer[GV] = LDSPointer;
+
+ return LDSPointer;
+ }
+
+ // Split entry basic block in such a way that only lane 0 of each wave does
+ // the LDS pointer initialization, and return newly created basic block.
+ BasicBlock *activateLaneZero(Function *K) {
+ // If the entry basic block of kernel K is already splitted, then return
+ // newly created basic block.
+ auto BasicBlockEntry = KernelToInitBB.insert(std::make_pair(K, nullptr));
+ if (!BasicBlockEntry.second)
+ return BasicBlockEntry.first->second;
+
+ // Split entry basic block of kernel K.
+ auto *EI = &(*(K->getEntryBlock().getFirstInsertionPt()));
+ IRBuilder<> Builder(EI);
+
+ Value *Mbcnt =
+ Builder.CreateIntrinsic(Intrinsic::amdgcn_mbcnt_lo, {},
+ {Builder.getInt32(-1), Builder.getInt32(0)});
+ Value *Cond = Builder.CreateICmpEQ(Mbcnt, Builder.getInt32(0));
+ Instruction *WB = cast<Instruction>(
+ Builder.CreateIntrinsic(Intrinsic::amdgcn_wave_barrier, {}, {}));
+
+ BasicBlock *NBB = SplitBlockAndInsertIfThen(Cond, WB, false)->getParent();
+
+ // Mark that the entry basic block of kernel K is splitted.
+ KernelToInitBB[K] = NBB;
+
+ return NBB;
+ }
+
+ // Within given kernel, initialize given LDS pointer to point to given LDS.
+ void initializeLDSPointer(Function *K, GlobalVariable *GV,
+ GlobalVariable *LDSPointer) {
+ // If LDS pointer is already initialized within K, then nothing to do.
+ auto PointerEntry = KernelToLDSPointers.insert(
+ std::make_pair(K, SmallPtrSet<GlobalVariable *, 8>()));
+ if (!PointerEntry.second)
+ if (PointerEntry.first->second.contains(LDSPointer))
+ return;
+
+ // Insert instructions at EI which initialize LDS pointer to point-to LDS
+ // within kernel K.
+ //
+ // That is, convert pointer type of GV to i16, and then store this converted
+ // i16 value within LDSPointer which is of type i16*.
+ auto *EI = &(*(activateLaneZero(K)->getFirstInsertionPt()));
+ IRBuilder<> Builder(EI);
+ Builder.CreateStore(Builder.CreatePtrToInt(GV, Type::getInt16Ty(Ctx)),
+ LDSPointer);
+
+ // Mark that LDS pointer is initialized within kernel K.
+ KernelToLDSPointers[K].insert(LDSPointer);
+ }
+
+ // We have created an LDS pointer for LDS, and initialized it to point-to LDS
+ // within all relevent kernels. Now replace all the uses of LDS within
+ // non-kernel functions by LDS pointer.
+ void replaceLDSUseByPointer(GlobalVariable *GV, GlobalVariable *LDSPointer) {
+ SmallVector<User *, 8> LDSUsers(GV->users());
+ for (auto *U : LDSUsers) {
+ // When `U` is a constant expression, it is possible that same constant
+ // expression exists within multiple instructions, and within multiple
+ // non-kernel functions. Collect all those non-kernel functions and all
+ // those instructions within which `U` exist.
+ auto FunctionToInsts =
+ AMDGPU::getFunctionToInstsMap(U, false /*=CollectKernelInsts*/);
+
+ for (auto FI = FunctionToInsts.begin(), FE = FunctionToInsts.end();
+ FI != FE; ++FI) {
+ Function *F = FI->first;
+ auto &Insts = FI->second;
+ for (auto *I : Insts) {
+ // If `U` is a constant expression, then we need to break the
+ // associated instruction into a set of separate instructions by
+ // converting constant expressions into instructions.
+ SmallPtrSet<Instruction *, 8> UserInsts;
+
+ if (U == I) {
+ // `U` is an instruction, conversion from constant expression to
+ // set of instructions is *not* required.
+ UserInsts.insert(I);
+ } else {
+ // `U` is a constant expression, convert it into corresponding set
+ // of instructions.
+ auto *CE = cast<ConstantExpr>(U);
+ convertConstantExprsToInstructions(I, CE, &UserInsts);
+ }
+
+ // Go through all the user instrutions, if LDS exist within them as an
+ // operand, then replace it by replace instruction.
+ for (auto *II : UserInsts) {
+ auto *ReplaceInst = getReplacementInst(F, GV, LDSPointer);
+ II->replaceUsesOfWith(GV, ReplaceInst);
+ }
+ }
+ }
+ }
+ }
+
+ // Create a set of replacement instructions which together replace LDS within
+ // non-kernel function F by accessing LDS indirectly using LDS pointer.
+ Value *getReplacementInst(Function *F, GlobalVariable *GV,
+ GlobalVariable *LDSPointer) {
+ // If the instruction which replaces LDS within F is already created, then
+ // return it.
+ auto LDSEntry = FunctionToLDSToReplaceInst.insert(
+ std::make_pair(F, DenseMap<GlobalVariable *, Value *>()));
+ if (!LDSEntry.second) {
+ auto ReplaceInstEntry =
+ LDSEntry.first->second.insert(std::make_pair(GV, nullptr));
+ if (!ReplaceInstEntry.second)
+ return ReplaceInstEntry.first->second;
+ }
+
+ // Get the instruction insertion point within the beginning of the entry
+ // block of current non-kernel function.
+ auto *EI = &(*(F->getEntryBlock().getFirstInsertionPt()));
+ IRBuilder<> Builder(EI);
+
+ // Insert required set of instructions which replace LDS within F.
+ auto *V = Builder.CreateBitCast(
+ Builder.CreateGEP(
+ Builder.getInt8Ty(), LDSMemBaseAddr,
+ Builder.CreateLoad(LDSPointer->getValueType(), LDSPointer)),
+ GV->getType());
+
+ // Mark that the replacement instruction which replace LDS within F is
+ // created.
+ FunctionToLDSToReplaceInst[F][GV] = V;
+
+ return V;
+ }
+
+public:
+ ReplaceLDSUseImpl(Module &M)
+ : M(M), Ctx(M.getContext()), DL(M.getDataLayout()) {
+ LDSMemBaseAddr = Constant::getIntegerValue(
+ PointerType::get(Type::getInt8Ty(M.getContext()),
+ AMDGPUAS::LOCAL_ADDRESS),
+ APInt(32, 0));
+ }
+
+ // Entry-point function which interface ReplaceLDSUseImpl with outside of the
+ // class.
+ bool replaceLDSUse();
+
+private:
+ // For a given LDS from collected LDS globals set, replace its non-kernel
+ // function scope uses by pointer.
+ bool replaceLDSUse(GlobalVariable *GV);
+};
+
+// For given LDS from collected LDS globals set, replace its non-kernel function
+// scope uses by pointer.
+bool ReplaceLDSUseImpl::replaceLDSUse(GlobalVariable *GV) {
+ // Holds all those non-kernel functions within which LDS is being accessed.
+ SmallPtrSet<Function *, 8> &LDSAccessors = LDSToNonKernels[GV];
+
+ // The LDS pointer which points to LDS and replaces all the uses of LDS.
+ GlobalVariable *LDSPointer = nullptr;
+
+ // Traverse through each kernel K, check and if required, initialize the
+ // LDS pointer to point to LDS within K.
+ for (auto KI = KernelToCallees.begin(), KE = KernelToCallees.end(); KI != KE;
+ ++KI) {
+ Function *K = KI->first;
+ SmallPtrSet<Function *, 8> Callees = KI->second;
+
+ // Compute reachable and LDS used callees for kernel K.
+ set_intersect(Callees, LDSAccessors);
+
+ // None of the LDS accessing non-kernel functions are reachable from
+ // kernel K. Hence, no need to initialize LDS pointer within kernel K.
+ if (Callees.empty())
+ continue;
+
+ // We have found reachable and LDS used callees for kernel K, and we need to
+ // initialize LDS pointer within kernel K, and we need to replace LDS use
+ // within those callees by LDS pointer.
+ //
+ // But, first check if LDS pointer is already created, if not create one.
+ LDSPointer = createLDSPointer(GV);
+
+ // Initialize LDS pointer to point to LDS within kernel K.
+ initializeLDSPointer(K, GV, LDSPointer);
+ }
+
+ // We have not found reachable and LDS used callees for any of the kernels,
+ // and hence we have not created LDS pointer.
+ if (!LDSPointer)
+ return false;
+
+ // We have created an LDS pointer for LDS, and initialized it to point-to LDS
+ // within all relevent kernels. Now replace all the uses of LDS within
+ // non-kernel functions by LDS pointer.
+ replaceLDSUseByPointer(GV, LDSPointer);
+
+ return true;
+}
+
+// Entry-point function which interface ReplaceLDSUseImpl with outside of the
+// class.
+bool ReplaceLDSUseImpl::replaceLDSUse() {
+ // Collect LDS which requires their uses to be replaced by pointer.
+ std::vector<GlobalVariable *> LDSGlobals =
+ collectLDSRequiringPointerReplace();
+
+ // No LDS to pointer-replace. Nothing to do.
+ if (LDSGlobals.empty())
+ return false;
+
+ // Collect reachable callee set for each kernel defined in the module.
+ AMDGPU::collectReachableCallees(M, KernelToCallees);
+
+ if (KernelToCallees.empty()) {
+ // Either module does not have any kernel definitions, or none of the kernel
+ // has a call to non-kernel functions, or we could not resolve any of the
+ // call sites to proper non-kernel functions, because of the situations like
+ // inline asm calls. Nothing to replace.
+ return false;
+ }
+
+ // For every LDS from collected LDS globals set, replace its non-kernel
+ // function scope use by pointer.
+ bool Changed = false;
+ for (auto *GV : LDSGlobals)
+ Changed |= replaceLDSUse(GV);
+
+ return Changed;
+}
+
+class AMDGPUReplaceLDSUseWithPointer : public ModulePass {
+public:
+ static char ID;
+
+ AMDGPUReplaceLDSUseWithPointer() : ModulePass(ID) {
+ initializeAMDGPUReplaceLDSUseWithPointerPass(
+ *PassRegistry::getPassRegistry());
+ }
+
+ bool runOnModule(Module &M) override;
+
+ void getAnalysisUsage(AnalysisUsage &AU) const override {
+ AU.addRequired<TargetPassConfig>();
+ }
+};
+
+} // namespace
+
+char AMDGPUReplaceLDSUseWithPointer::ID = 0;
+char &llvm::AMDGPUReplaceLDSUseWithPointerID =
+ AMDGPUReplaceLDSUseWithPointer::ID;
+
+INITIALIZE_PASS_BEGIN(
+ AMDGPUReplaceLDSUseWithPointer, DEBUG_TYPE,
+ "Replace within non-kernel function use of LDS with pointer",
+ false /*only look at the cfg*/, false /*analysis pass*/)
+INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
+INITIALIZE_PASS_END(
+ AMDGPUReplaceLDSUseWithPointer, DEBUG_TYPE,
+ "Replace within non-kernel function use of LDS with pointer",
+ false /*only look at the cfg*/, false /*analysis pass*/)
+
+bool AMDGPUReplaceLDSUseWithPointer::runOnModule(Module &M) {
+ ReplaceLDSUseImpl LDSUseReplacer{M};
+ return LDSUseReplacer.replaceLDSUse();
+}
+
+ModulePass *llvm::createAMDGPUReplaceLDSUseWithPointerPass() {
+ return new AMDGPUReplaceLDSUseWithPointer();
+}
+
+PreservedAnalyses
+AMDGPUReplaceLDSUseWithPointerPass::run(Module &M, ModuleAnalysisManager &AM) {
+ ReplaceLDSUseImpl LDSUseReplacer{M};
+ LDSUseReplacer.replaceLDSUse();
+ return PreservedAnalyses::all();
+}