diff options
Diffstat (limited to 'contrib/llvm-project/llvm/lib/CodeGen/ReplaceWithVeclib.cpp')
| -rw-r--r-- | contrib/llvm-project/llvm/lib/CodeGen/ReplaceWithVeclib.cpp | 160 |
1 file changed, 90 insertions, 70 deletions
diff --git a/contrib/llvm-project/llvm/lib/CodeGen/ReplaceWithVeclib.cpp b/contrib/llvm-project/llvm/lib/CodeGen/ReplaceWithVeclib.cpp index 893aa4a91828..56025aa5c45f 100644 --- a/contrib/llvm-project/llvm/lib/CodeGen/ReplaceWithVeclib.cpp +++ b/contrib/llvm-project/llvm/lib/CodeGen/ReplaceWithVeclib.cpp @@ -6,9 +6,9 @@ // //===----------------------------------------------------------------------===// // -// Replaces calls to LLVM vector intrinsics (i.e., calls to LLVM intrinsics -// with vector operands) with matching calls to functions from a vector -// library (e.g., libmvec, SVML) according to TargetLibraryInfo. +// Replaces LLVM IR instructions with vector operands (i.e., the frem +// instruction or calls to LLVM intrinsics) with matching calls to functions +// from a vector library (e.g libmvec, SVML) using TargetLibraryInfo interface. // //===----------------------------------------------------------------------===// @@ -69,88 +69,98 @@ Function *getTLIFunction(Module *M, FunctionType *VectorFTy, return TLIFunc; } -/// Replace the call to the vector intrinsic ( \p CalltoReplace ) with a call to -/// the corresponding function from the vector library ( \p TLIVecFunc ). -static void replaceWithTLIFunction(CallInst &CalltoReplace, VFInfo &Info, +/// Replace the instruction \p I with a call to the corresponding function from +/// the vector library (\p TLIVecFunc). +static void replaceWithTLIFunction(Instruction &I, VFInfo &Info, Function *TLIVecFunc) { - IRBuilder<> IRBuilder(&CalltoReplace); - SmallVector<Value *> Args(CalltoReplace.args()); + IRBuilder<> IRBuilder(&I); + auto *CI = dyn_cast<CallInst>(&I); + SmallVector<Value *> Args(CI ? 
CI->args() : I.operands()); if (auto OptMaskpos = Info.getParamIndexForOptionalMask()) { - auto *MaskTy = VectorType::get(Type::getInt1Ty(CalltoReplace.getContext()), - Info.Shape.VF); + auto *MaskTy = + VectorType::get(Type::getInt1Ty(I.getContext()), Info.Shape.VF); Args.insert(Args.begin() + OptMaskpos.value(), Constant::getAllOnesValue(MaskTy)); } - // Preserve the operand bundles. + // If it is a call instruction, preserve the operand bundles. SmallVector<OperandBundleDef, 1> OpBundles; - CalltoReplace.getOperandBundlesAsDefs(OpBundles); - CallInst *Replacement = IRBuilder.CreateCall(TLIVecFunc, Args, OpBundles); - CalltoReplace.replaceAllUsesWith(Replacement); + if (CI) + CI->getOperandBundlesAsDefs(OpBundles); + + auto *Replacement = IRBuilder.CreateCall(TLIVecFunc, Args, OpBundles); + I.replaceAllUsesWith(Replacement); // Preserve fast math flags for FP math. if (isa<FPMathOperator>(Replacement)) - Replacement->copyFastMathFlags(&CalltoReplace); + Replacement->copyFastMathFlags(&I); } -/// Returns true when successfully replaced \p CallToReplace with a suitable -/// function taking vector arguments, based on available mappings in the \p TLI. -/// Currently only works when \p CallToReplace is a call to vectorized -/// intrinsic. +/// Returns true when successfully replaced \p I with a suitable function taking +/// vector arguments, based on available mappings in the \p TLI. Currently only +/// works when \p I is a call to vectorized intrinsic or the frem instruction. static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI, - CallInst &CallToReplace) { - if (!CallToReplace.getCalledFunction()) - return false; + Instruction &I) { + // At the moment VFABI assumes the return type is always widened unless it is + // a void type. + auto *VTy = dyn_cast<VectorType>(I.getType()); + ElementCount EC(VTy ? 
VTy->getElementCount() : ElementCount::getFixed(0)); - auto IntrinsicID = CallToReplace.getCalledFunction()->getIntrinsicID(); - // Replacement is only performed for intrinsic functions. - if (IntrinsicID == Intrinsic::not_intrinsic) - return false; - - // Compute arguments types of the corresponding scalar call. Additionally - // checks if in the vector call, all vector operands have the same EC. - ElementCount VF = ElementCount::getFixed(0); - SmallVector<Type *> ScalarArgTypes; - for (auto Arg : enumerate(CallToReplace.args())) { - auto *ArgTy = Arg.value()->getType(); - if (isVectorIntrinsicWithScalarOpAtArg(IntrinsicID, Arg.index())) { - ScalarArgTypes.push_back(ArgTy); - } else if (auto *VectorArgTy = dyn_cast<VectorType>(ArgTy)) { - ScalarArgTypes.push_back(ArgTy->getScalarType()); - // Disallow vector arguments with different VFs. When processing the first - // vector argument, store it's VF, and for the rest ensure that they match - // it. - if (VF.isZero()) - VF = VectorArgTy->getElementCount(); - else if (VF != VectorArgTy->getElementCount()) + // Compute the argument types of the corresponding scalar call and the scalar + // function name. For calls, it additionally finds the function to replace + // and checks that all vector operands match the previously found EC. 
+ SmallVector<Type *, 8> ScalarArgTypes; + std::string ScalarName; + Function *FuncToReplace = nullptr; + if (auto *CI = dyn_cast<CallInst>(&I)) { + FuncToReplace = CI->getCalledFunction(); + Intrinsic::ID IID = FuncToReplace->getIntrinsicID(); + assert(IID != Intrinsic::not_intrinsic && "Not an intrinsic"); + for (auto Arg : enumerate(CI->args())) { + auto *ArgTy = Arg.value()->getType(); + if (isVectorIntrinsicWithScalarOpAtArg(IID, Arg.index())) { + ScalarArgTypes.push_back(ArgTy); + } else if (auto *VectorArgTy = dyn_cast<VectorType>(ArgTy)) { + ScalarArgTypes.push_back(VectorArgTy->getElementType()); + // When return type is void, set EC to the first vector argument, and + // disallow vector arguments with different ECs. + if (EC.isZero()) + EC = VectorArgTy->getElementCount(); + else if (EC != VectorArgTy->getElementCount()) + return false; + } else + // Exit when it is supposed to be a vector argument but it isn't. return false; - } else - // Exit when it is supposed to be a vector argument but it isn't. + } + // Try to reconstruct the name for the scalar version of the instruction, + // using scalar argument types. + ScalarName = Intrinsic::isOverloaded(IID) + ? Intrinsic::getName(IID, ScalarArgTypes, I.getModule()) + : Intrinsic::getName(IID).str(); + } else { + assert(VTy && "Return type must be a vector"); + auto *ScalarTy = VTy->getScalarType(); + LibFunc Func; + if (!TLI.getLibFunc(I.getOpcode(), ScalarTy, Func)) return false; + ScalarName = TLI.getName(Func); + ScalarArgTypes = {ScalarTy, ScalarTy}; } - // Try to reconstruct the name for the scalar version of this intrinsic using - // the intrinsic ID and the argument types converted to scalar above. - std::string ScalarName = - (Intrinsic::isOverloaded(IntrinsicID) - ? 
Intrinsic::getName(IntrinsicID, ScalarArgTypes, - CallToReplace.getModule()) - : Intrinsic::getName(IntrinsicID).str()); - // Try to find the mapping for the scalar version of this intrinsic and the // exact vector width of the call operands in the TargetLibraryInfo. First, // check with a non-masked variant, and if that fails try with a masked one. const VecDesc *VD = - TLI.getVectorMappingInfo(ScalarName, VF, /*Masked*/ false); - if (!VD && !(VD = TLI.getVectorMappingInfo(ScalarName, VF, /*Masked*/ true))) + TLI.getVectorMappingInfo(ScalarName, EC, /*Masked*/ false); + if (!VD && !(VD = TLI.getVectorMappingInfo(ScalarName, EC, /*Masked*/ true))) return false; LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Found TLI mapping from: `" << ScalarName - << "` and vector width " << VF << " to: `" + << "` and vector width " << EC << " to: `" << VD->getVectorFnName() << "`.\n"); // Replace the call to the intrinsic with a call to the vector library // function. - Type *ScalarRetTy = CallToReplace.getType()->getScalarType(); + Type *ScalarRetTy = I.getType()->getScalarType(); FunctionType *ScalarFTy = FunctionType::get(ScalarRetTy, ScalarArgTypes, /*isVarArg*/ false); const std::string MangledName = VD->getVectorFunctionABIVariantString(); @@ -162,27 +172,37 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI, if (!VectorFTy) return false; - Function *FuncToReplace = CallToReplace.getCalledFunction(); - Function *TLIFunc = getTLIFunction(CallToReplace.getModule(), VectorFTy, + Function *TLIFunc = getTLIFunction(I.getModule(), VectorFTy, VD->getVectorFnName(), FuncToReplace); - replaceWithTLIFunction(CallToReplace, *OptInfo, TLIFunc); - - LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Replaced call to `" - << FuncToReplace->getName() << "` with call to `" - << TLIFunc->getName() << "`.\n"); + replaceWithTLIFunction(I, *OptInfo, TLIFunc); + LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Replaced call to `" << ScalarName + << "` with call to `" << TLIFunc->getName() << "`.\n"); 
++NumCallsReplaced; return true; } +/// Supported instruction \p I must be a vectorized frem or a call to an +/// intrinsic that returns either void or a vector. +static bool isSupportedInstruction(Instruction *I) { + Type *Ty = I->getType(); + if (auto *CI = dyn_cast<CallInst>(I)) + return (Ty->isVectorTy() || Ty->isVoidTy()) && CI->getCalledFunction() && + CI->getCalledFunction()->getIntrinsicID() != + Intrinsic::not_intrinsic; + if (I->getOpcode() == Instruction::FRem && Ty->isVectorTy()) + return true; + return false; +} + static bool runImpl(const TargetLibraryInfo &TLI, Function &F) { bool Changed = false; - SmallVector<CallInst *> ReplacedCalls; + SmallVector<Instruction *> ReplacedCalls; for (auto &I : instructions(F)) { - if (auto *CI = dyn_cast<CallInst>(&I)) { - if (replaceWithCallToVeclib(TLI, *CI)) { - ReplacedCalls.push_back(CI); - Changed = true; - } + if (!isSupportedInstruction(&I)) + continue; + if (replaceWithCallToVeclib(TLI, I)) { + ReplacedCalls.push_back(&I); + Changed = true; } } // Erase the calls to the intrinsics that have been replaced |
