//===-- SPIRVRegularizer.cpp - regularize IR for SPIR-V ---------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This pass implements regularization of LLVM IR for SPIR-V. The prototype of // the pass was taken from SPIRV-LLVM translator. // //===----------------------------------------------------------------------===// #include "SPIRV.h" #include "SPIRVTargetMachine.h" #include "llvm/Demangle/Demangle.h" #include "llvm/IR/InstIterator.h" #include "llvm/IR/InstVisitor.h" #include "llvm/IR/PassManager.h" #include "llvm/Transforms/Utils/Cloning.h" #include #define DEBUG_TYPE "spirv-regularizer" using namespace llvm; namespace llvm { void initializeSPIRVRegularizerPass(PassRegistry &); } namespace { struct SPIRVRegularizer : public FunctionPass, InstVisitor { DenseMap Old2NewFuncs; public: static char ID; SPIRVRegularizer() : FunctionPass(ID) { initializeSPIRVRegularizerPass(*PassRegistry::getPassRegistry()); } bool runOnFunction(Function &F) override; StringRef getPassName() const override { return "SPIR-V Regularizer"; } void getAnalysisUsage(AnalysisUsage &AU) const override { FunctionPass::getAnalysisUsage(AU); } void visitCallInst(CallInst &CI); private: void visitCallScalToVec(CallInst *CI, StringRef MangledName, StringRef DemangledName); void runLowerConstExpr(Function &F); }; } // namespace char SPIRVRegularizer::ID = 0; INITIALIZE_PASS(SPIRVRegularizer, DEBUG_TYPE, "SPIR-V Regularizer", false, false) // Since SPIR-V cannot represent constant expression, constant expressions // in LLVM IR need to be lowered to instructions. For each function, // the constant expressions used by instructions of the function are replaced // by instructions placed in the entry block since it dominates all other BBs. // Each constant expression only needs to be lowered once in each function // and all uses of it by instructions in that function are replaced by // one instruction. // TODO: remove redundant instructions for common subexpression. void SPIRVRegularizer::runLowerConstExpr(Function &F) { LLVMContext &Ctx = F.getContext(); std::list WorkList; for (auto &II : instructions(F)) WorkList.push_back(&II); auto FBegin = F.begin(); while (!WorkList.empty()) { Instruction *II = WorkList.front(); auto LowerOp = [&II, &FBegin, &F](Value *V) -> Value * { if (isa(V)) return V; auto *CE = cast(V); LLVM_DEBUG(dbgs() << "[lowerConstantExpressions] " << *CE); auto ReplInst = CE->getAsInstruction(); auto InsPoint = II->getParent() == &*FBegin ? II : &FBegin->back(); ReplInst->insertBefore(InsPoint); LLVM_DEBUG(dbgs() << " -> " << *ReplInst << '\n'); std::vector Users; // Do not replace use during iteration of use. Do it in another loop. for (auto U : CE->users()) { LLVM_DEBUG(dbgs() << "[lowerConstantExpressions] Use: " << *U << '\n'); auto InstUser = dyn_cast(U); // Only replace users in scope of current function. if (InstUser && InstUser->getParent()->getParent() == &F) Users.push_back(InstUser); } for (auto &User : Users) { if (ReplInst->getParent() == User->getParent() && User->comesBefore(ReplInst)) ReplInst->moveBefore(User); User->replaceUsesOfWith(CE, ReplInst); } return ReplInst; }; WorkList.pop_front(); auto LowerConstantVec = [&II, &LowerOp, &WorkList, &Ctx](ConstantVector *Vec, unsigned NumOfOp) -> Value * { if (std::all_of(Vec->op_begin(), Vec->op_end(), [](Value *V) { return isa(V) || isa(V); })) { // Expand a vector of constexprs and construct it back with // series of insertelement instructions. std::list OpList; std::transform(Vec->op_begin(), Vec->op_end(), std::back_inserter(OpList), [LowerOp](Value *V) { return LowerOp(V); }); Value *Repl = nullptr; unsigned Idx = 0; auto *PhiII = dyn_cast(II); Instruction *InsPoint = PhiII ? &PhiII->getIncomingBlock(NumOfOp)->back() : II; std::list ReplList; for (auto V : OpList) { if (auto *Inst = dyn_cast(V)) ReplList.push_back(Inst); Repl = InsertElementInst::Create( (Repl ? Repl : PoisonValue::get(Vec->getType())), V, ConstantInt::get(Type::getInt32Ty(Ctx), Idx++), "", InsPoint); } WorkList.splice(WorkList.begin(), ReplList); return Repl; } return nullptr; }; for (unsigned OI = 0, OE = II->getNumOperands(); OI != OE; ++OI) { auto *Op = II->getOperand(OI); if (auto *Vec = dyn_cast(Op)) { Value *ReplInst = LowerConstantVec(Vec, OI); if (ReplInst) II->replaceUsesOfWith(Op, ReplInst); } else if (auto CE = dyn_cast(Op)) { WorkList.push_front(cast(LowerOp(CE))); } else if (auto MDAsVal = dyn_cast(Op)) { auto ConstMD = dyn_cast(MDAsVal->getMetadata()); if (!ConstMD) continue; Constant *C = ConstMD->getValue(); Value *ReplInst = nullptr; if (auto *Vec = dyn_cast(C)) ReplInst = LowerConstantVec(Vec, OI); if (auto *CE = dyn_cast(C)) ReplInst = LowerOp(CE); if (!ReplInst) continue; Metadata *RepMD = ValueAsMetadata::get(ReplInst); Value *RepMDVal = MetadataAsValue::get(Ctx, RepMD); II->setOperand(OI, RepMDVal); WorkList.push_front(cast(ReplInst)); } } } } // It fixes calls to OCL builtins that accept vector arguments and one of them // is actually a scalar splat. void SPIRVRegularizer::visitCallInst(CallInst &CI) { auto F = CI.getCalledFunction(); if (!F) return; auto MangledName = F->getName(); char *NameStr = itaniumDemangle(F->getName().data()); if (!NameStr) return; StringRef DemangledName(NameStr); // TODO: add support for other builtins. if (DemangledName.starts_with("fmin") || DemangledName.starts_with("fmax") || DemangledName.starts_with("min") || DemangledName.starts_with("max")) visitCallScalToVec(&CI, MangledName, DemangledName); free(NameStr); } void SPIRVRegularizer::visitCallScalToVec(CallInst *CI, StringRef MangledName, StringRef DemangledName) { // Check if all arguments have the same type - it's simple case. auto Uniform = true; Type *Arg0Ty = CI->getOperand(0)->getType(); auto IsArg0Vector = isa(Arg0Ty); for (unsigned I = 1, E = CI->arg_size(); Uniform && (I != E); ++I) Uniform = isa(CI->getOperand(I)->getType()) == IsArg0Vector; if (Uniform) return; auto *OldF = CI->getCalledFunction(); Function *NewF = nullptr; if (!Old2NewFuncs.count(OldF)) { AttributeList Attrs = CI->getCalledFunction()->getAttributes(); SmallVector ArgTypes = {OldF->getArg(0)->getType(), Arg0Ty}; auto *NewFTy = FunctionType::get(OldF->getReturnType(), ArgTypes, OldF->isVarArg()); NewF = Function::Create(NewFTy, OldF->getLinkage(), OldF->getName(), *OldF->getParent()); ValueToValueMapTy VMap; auto NewFArgIt = NewF->arg_begin(); for (auto &Arg : OldF->args()) { auto ArgName = Arg.getName(); NewFArgIt->setName(ArgName); VMap[&Arg] = &(*NewFArgIt++); } SmallVector Returns; CloneFunctionInto(NewF, OldF, VMap, CloneFunctionChangeType::LocalChangesOnly, Returns); NewF->setAttributes(Attrs); Old2NewFuncs[OldF] = NewF; } else { NewF = Old2NewFuncs[OldF]; } assert(NewF); // This produces an instruction sequence that implements a splat of // CI->getOperand(1) to a vector Arg0Ty. However, we use InsertElementInst // and ShuffleVectorInst to generate the same code as the SPIR-V translator. // For instance (transcoding/OpMin.ll), this call // call spir_func <2 x i32> @_Z3minDv2_ii(<2 x i32> , i32 5) // is translated to // %8 = OpUndef %v2uint // %14 = OpConstantComposite %v2uint %uint_1 %uint_10 // ... // %10 = OpCompositeInsert %v2uint %uint_5 %8 0 // %11 = OpVectorShuffle %v2uint %10 %8 0 0 // %call = OpExtInst %v2uint %1 s_min %14 %11 auto ConstInt = ConstantInt::get(IntegerType::get(CI->getContext(), 32), 0); PoisonValue *PVal = PoisonValue::get(Arg0Ty); Instruction *Inst = InsertElementInst::Create(PVal, CI->getOperand(1), ConstInt, "", CI); ElementCount VecElemCount = cast(Arg0Ty)->getElementCount(); Constant *ConstVec = ConstantVector::getSplat(VecElemCount, ConstInt); Value *NewVec = new ShuffleVectorInst(Inst, PVal, ConstVec, "", CI); CI->setOperand(1, NewVec); CI->replaceUsesOfWith(OldF, NewF); CI->mutateFunctionType(NewF->getFunctionType()); } bool SPIRVRegularizer::runOnFunction(Function &F) { runLowerConstExpr(F); visit(F); for (auto &OldNew : Old2NewFuncs) { Function *OldF = OldNew.first; Function *NewF = OldNew.second; NewF->takeName(OldF); OldF->eraseFromParent(); } return true; } FunctionPass *llvm::createSPIRVRegularizerPass() { return new SPIRVRegularizer(); }