diff options
Diffstat (limited to 'contrib/llvm-project/llvm/lib/Target/SPIRV/SPIRVPrepareFunctions.cpp')
-rw-r--r-- | contrib/llvm-project/llvm/lib/Target/SPIRV/SPIRVPrepareFunctions.cpp | 288 |
1 files changed, 288 insertions, 0 deletions
diff --git a/contrib/llvm-project/llvm/lib/Target/SPIRV/SPIRVPrepareFunctions.cpp b/contrib/llvm-project/llvm/lib/Target/SPIRV/SPIRVPrepareFunctions.cpp new file mode 100644 index 000000000000..13c3c12c1b41 --- /dev/null +++ b/contrib/llvm-project/llvm/lib/Target/SPIRV/SPIRVPrepareFunctions.cpp @@ -0,0 +1,288 @@ +//===-- SPIRVPrepareFunctions.cpp - modify function signatures --*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This pass modifies function signatures containing aggregate arguments +// and/or return value. Also it substitutes some llvm intrinsic calls by +// function calls, generating these functions as the translator does. +// +// NOTE: this pass is a module-level one due to the necessity to modify +// GVs/functions. +// +//===----------------------------------------------------------------------===// + +#include "SPIRV.h" +#include "SPIRVTargetMachine.h" +#include "SPIRVUtils.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/Transforms/Utils/Cloning.h" +#include "llvm/Transforms/Utils/LowerMemIntrinsics.h" + +using namespace llvm; + +namespace llvm { +void initializeSPIRVPrepareFunctionsPass(PassRegistry &); +} + +namespace { + +class SPIRVPrepareFunctions : public ModulePass { + Function *processFunctionSignature(Function *F); + +public: + static char ID; + SPIRVPrepareFunctions() : ModulePass(ID) { + initializeSPIRVPrepareFunctionsPass(*PassRegistry::getPassRegistry()); + } + + bool runOnModule(Module &M) override; + + StringRef getPassName() const override { return "SPIRV prepare functions"; } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + ModulePass::getAnalysisUsage(AU); + } +}; + +} // namespace + +char SPIRVPrepareFunctions::ID = 0; + +INITIALIZE_PASS(SPIRVPrepareFunctions, "prepare-functions", + "SPIRV prepare functions", false, false) + +Function *SPIRVPrepareFunctions::processFunctionSignature(Function *F) { + IRBuilder<> B(F->getContext()); + + bool IsRetAggr = F->getReturnType()->isAggregateType(); + bool HasAggrArg = + std::any_of(F->arg_begin(), F->arg_end(), [](Argument &Arg) { + return Arg.getType()->isAggregateType(); + }); + bool DoClone = IsRetAggr || HasAggrArg; + if (!DoClone) + return F; + SmallVector<std::pair<int, Type *>, 4> ChangedTypes; + Type *RetType = IsRetAggr ? B.getInt32Ty() : F->getReturnType(); + if (IsRetAggr) + ChangedTypes.push_back(std::pair<int, Type *>(-1, F->getReturnType())); + SmallVector<Type *, 4> ArgTypes; + for (const auto &Arg : F->args()) { + if (Arg.getType()->isAggregateType()) { + ArgTypes.push_back(B.getInt32Ty()); + ChangedTypes.push_back( + std::pair<int, Type *>(Arg.getArgNo(), Arg.getType())); + } else + ArgTypes.push_back(Arg.getType()); + } + FunctionType *NewFTy = + FunctionType::get(RetType, ArgTypes, F->getFunctionType()->isVarArg()); + Function *NewF = + Function::Create(NewFTy, F->getLinkage(), F->getName(), *F->getParent()); + + ValueToValueMapTy VMap; + auto NewFArgIt = NewF->arg_begin(); + for (auto &Arg : F->args()) { + StringRef ArgName = Arg.getName(); + NewFArgIt->setName(ArgName); + VMap[&Arg] = &(*NewFArgIt++); + } + SmallVector<ReturnInst *, 8> Returns; + + CloneFunctionInto(NewF, F, VMap, CloneFunctionChangeType::LocalChangesOnly, + Returns); + NewF->takeName(F); + + NamedMDNode *FuncMD = + F->getParent()->getOrInsertNamedMetadata("spv.cloned_funcs"); + SmallVector<Metadata *, 2> MDArgs; + MDArgs.push_back(MDString::get(B.getContext(), NewF->getName())); + for (auto &ChangedTyP : ChangedTypes) + MDArgs.push_back(MDNode::get( + B.getContext(), + {ConstantAsMetadata::get(B.getInt32(ChangedTyP.first)), + ValueAsMetadata::get(Constant::getNullValue(ChangedTyP.second))})); + MDNode *ThisFuncMD = MDNode::get(B.getContext(), MDArgs); + FuncMD->addOperand(ThisFuncMD); + + for (auto *U : make_early_inc_range(F->users())) { + if (auto *CI = dyn_cast<CallInst>(U)) + CI->mutateFunctionType(NewF->getFunctionType()); + U->replaceUsesOfWith(F, NewF); + } + return NewF; +} + +std::string lowerLLVMIntrinsicName(IntrinsicInst *II) { + Function *IntrinsicFunc = II->getCalledFunction(); + assert(IntrinsicFunc && "Missing function"); + std::string FuncName = IntrinsicFunc->getName().str(); + std::replace(FuncName.begin(), FuncName.end(), '.', '_'); + FuncName = "spirv." + FuncName; + return FuncName; +} + +static Function *getOrCreateFunction(Module *M, Type *RetTy, + ArrayRef<Type *> ArgTypes, + StringRef Name) { + FunctionType *FT = FunctionType::get(RetTy, ArgTypes, false); + Function *F = M->getFunction(Name); + if (F && F->getFunctionType() == FT) + return F; + Function *NewF = Function::Create(FT, GlobalValue::ExternalLinkage, Name, M); + if (F) + NewF->setDSOLocal(F->isDSOLocal()); + NewF->setCallingConv(CallingConv::SPIR_FUNC); + return NewF; +} + +static void lowerFunnelShifts(Module *M, IntrinsicInst *FSHIntrinsic) { + // Get a separate function - otherwise, we'd have to rework the CFG of the + // current one. Then simply replace the intrinsic uses with a call to the new + // function. + // Generate LLVM IR for i* @spirv.llvm_fsh?_i* (i* %a, i* %b, i* %c) + FunctionType *FSHFuncTy = FSHIntrinsic->getFunctionType(); + Type *FSHRetTy = FSHFuncTy->getReturnType(); + const std::string FuncName = lowerLLVMIntrinsicName(FSHIntrinsic); + Function *FSHFunc = + getOrCreateFunction(M, FSHRetTy, FSHFuncTy->params(), FuncName); + + if (!FSHFunc->empty()) { + FSHIntrinsic->setCalledFunction(FSHFunc); + return; + } + BasicBlock *RotateBB = BasicBlock::Create(M->getContext(), "rotate", FSHFunc); + IRBuilder<> IRB(RotateBB); + Type *Ty = FSHFunc->getReturnType(); + // Build the actual funnel shift rotate logic. + // In the comments, "int" is used interchangeably with "vector of int + // elements". + FixedVectorType *VectorTy = dyn_cast<FixedVectorType>(Ty); + Type *IntTy = VectorTy ? VectorTy->getElementType() : Ty; + unsigned BitWidth = IntTy->getIntegerBitWidth(); + ConstantInt *BitWidthConstant = IRB.getInt({BitWidth, BitWidth}); + Value *BitWidthForInsts = + VectorTy + ? IRB.CreateVectorSplat(VectorTy->getNumElements(), BitWidthConstant) + : BitWidthConstant; + Value *RotateModVal = + IRB.CreateURem(/*Rotate*/ FSHFunc->getArg(2), BitWidthForInsts); + Value *FirstShift = nullptr, *SecShift = nullptr; + if (FSHIntrinsic->getIntrinsicID() == Intrinsic::fshr) { + // Shift the less significant number right, the "rotate" number of bits + // will be 0-filled on the left as a result of this regular shift. + FirstShift = IRB.CreateLShr(FSHFunc->getArg(1), RotateModVal); + } else { + // Shift the more significant number left, the "rotate" number of bits + // will be 0-filled on the right as a result of this regular shift. + FirstShift = IRB.CreateShl(FSHFunc->getArg(0), RotateModVal); + } + // We want the "rotate" number of the more significant int's LSBs (MSBs) to + // occupy the leftmost (rightmost) "0 space" left by the previous operation. + // Therefore, subtract the "rotate" number from the integer bitsize... + Value *SubRotateVal = IRB.CreateSub(BitWidthForInsts, RotateModVal); + if (FSHIntrinsic->getIntrinsicID() == Intrinsic::fshr) { + // ...and left-shift the more significant int by this number, zero-filling + // the LSBs. + SecShift = IRB.CreateShl(FSHFunc->getArg(0), SubRotateVal); + } else { + // ...and right-shift the less significant int by this number, zero-filling + // the MSBs. + SecShift = IRB.CreateLShr(FSHFunc->getArg(1), SubRotateVal); + } + // A simple binary addition of the shifted ints yields the final result. + IRB.CreateRet(IRB.CreateOr(FirstShift, SecShift)); + + FSHIntrinsic->setCalledFunction(FSHFunc); +} + +static void buildUMulWithOverflowFunc(Module *M, Function *UMulFunc) { + // The function body is already created. + if (!UMulFunc->empty()) + return; + + BasicBlock *EntryBB = BasicBlock::Create(M->getContext(), "entry", UMulFunc); + IRBuilder<> IRB(EntryBB); + // Build the actual unsigned multiplication logic with the overflow + // indication. Do unsigned multiplication Mul = A * B. Then check + // if unsigned division Div = Mul / A is not equal to B. If so, + // then overflow has happened. + Value *Mul = IRB.CreateNUWMul(UMulFunc->getArg(0), UMulFunc->getArg(1)); + Value *Div = IRB.CreateUDiv(Mul, UMulFunc->getArg(0)); + Value *Overflow = IRB.CreateICmpNE(UMulFunc->getArg(0), Div); + + // umul.with.overflow intrinsic return a structure, where the first element + // is the multiplication result, and the second is an overflow bit. + Type *StructTy = UMulFunc->getReturnType(); + Value *Agg = IRB.CreateInsertValue(UndefValue::get(StructTy), Mul, {0}); + Value *Res = IRB.CreateInsertValue(Agg, Overflow, {1}); + IRB.CreateRet(Res); +} + +static void lowerUMulWithOverflow(Module *M, IntrinsicInst *UMulIntrinsic) { + // Get a separate function - otherwise, we'd have to rework the CFG of the + // current one. Then simply replace the intrinsic uses with a call to the new + // function. + FunctionType *UMulFuncTy = UMulIntrinsic->getFunctionType(); + Type *FSHLRetTy = UMulFuncTy->getReturnType(); + const std::string FuncName = lowerLLVMIntrinsicName(UMulIntrinsic); + Function *UMulFunc = + getOrCreateFunction(M, FSHLRetTy, UMulFuncTy->params(), FuncName); + buildUMulWithOverflowFunc(M, UMulFunc); + UMulIntrinsic->setCalledFunction(UMulFunc); +} + +static void substituteIntrinsicCalls(Module *M, Function *F) { + for (BasicBlock &BB : *F) { + for (Instruction &I : BB) { + auto Call = dyn_cast<CallInst>(&I); + if (!Call) + continue; + Call->setTailCall(false); + Function *CF = Call->getCalledFunction(); + if (!CF || !CF->isIntrinsic()) + continue; + auto *II = cast<IntrinsicInst>(Call); + if (II->getIntrinsicID() == Intrinsic::fshl || + II->getIntrinsicID() == Intrinsic::fshr) + lowerFunnelShifts(M, II); + else if (II->getIntrinsicID() == Intrinsic::umul_with_overflow) + lowerUMulWithOverflow(M, II); + } + } +} + +bool SPIRVPrepareFunctions::runOnModule(Module &M) { + for (Function &F : M) + substituteIntrinsicCalls(&M, &F); + + std::vector<Function *> FuncsWorklist; + bool Changed = false; + for (auto &F : M) + FuncsWorklist.push_back(&F); + + for (auto *Func : FuncsWorklist) { + Function *F = processFunctionSignature(Func); + + bool CreatedNewF = F != Func; + + if (Func->isDeclaration()) { + Changed |= CreatedNewF; + continue; + } + + if (CreatedNewF) + Func->eraseFromParent(); + } + + return Changed; +} + +ModulePass *llvm::createSPIRVPrepareFunctionsPass() { + return new SPIRVPrepareFunctions(); +} |