diff options
Diffstat (limited to 'llvm/lib/Analysis/VFABIDemangling.cpp')
-rw-r--r-- | llvm/lib/Analysis/VFABIDemangling.cpp | 108 |
1 files changed, 80 insertions, 28 deletions
diff --git a/llvm/lib/Analysis/VFABIDemangling.cpp b/llvm/lib/Analysis/VFABIDemangling.cpp index a331b95e818b..0192a216b2f7 100644 --- a/llvm/lib/Analysis/VFABIDemangling.cpp +++ b/llvm/lib/Analysis/VFABIDemangling.cpp @@ -70,6 +70,9 @@ ParseRet tryParseMask(StringRef &MangledName, bool &IsMasked) { /// ParseRet tryParseVLEN(StringRef &ParseString, unsigned &VF, bool &IsScalable) { if (ParseString.consume_front("x")) { + // Set VF to 0, to be later adjusted to a value grater than zero + // by looking at the signature of the vector function with + // `getECFromSignature`. VF = 0; IsScalable = true; return ParseRet::OK; @@ -78,6 +81,10 @@ ParseRet tryParseVLEN(StringRef &ParseString, unsigned &VF, bool &IsScalable) { if (ParseString.consumeInteger(10, VF)) return ParseRet::Error; + // The token `0` is invalid for VLEN. + if (VF == 0) + return ParseRet::Error; + IsScalable = false; return ParseRet::OK; } @@ -207,28 +214,6 @@ ParseRet tryParseLinearWithCompileTimeStep(StringRef &ParseString, return ParseRet::None; } -/// The function looks for the following strings at the beginning of -/// the input string `ParseString`: -/// -/// "u" <number> -/// -/// On success, it removes the parsed parameter from `ParseString`, -/// sets `PKind` to the correspondent enum value, sets `Pos` to -/// <number>, and return success. On a syntax error, it return a -/// parsing error. If nothing is parsed, it returns None. -ParseRet tryParseUniform(StringRef &ParseString, VFParamKind &PKind, int &Pos) { - // "u" <Pos> - const char *UniformToken = "u"; - if (ParseString.consume_front(UniformToken)) { - PKind = VFABI::getVFParamKindFromString(UniformToken); - if (ParseString.consumeInteger(10, Pos)) - return ParseRet::Error; - - return ParseRet::OK; - } - return ParseRet::None; -} - /// Looks into the <parameters> part of the mangled name in search /// for valid paramaters at the beginning of the string /// `ParseString`. @@ -245,6 +230,12 @@ ParseRet tryParseParameter(StringRef &ParseString, VFParamKind &PKind, return ParseRet::OK; } + if (ParseString.consume_front("u")) { + PKind = VFParamKind::OMP_Uniform; + StepOrPos = 0; + return ParseRet::OK; + } + const ParseRet HasLinearRuntime = tryParseLinearWithRuntimeStep(ParseString, PKind, StepOrPos); if (HasLinearRuntime != ParseRet::None) @@ -255,10 +246,6 @@ ParseRet tryParseParameter(StringRef &ParseString, VFParamKind &PKind, if (HasLinearCompileTime != ParseRet::None) return HasLinearCompileTime; - const ParseRet HasUniform = tryParseUniform(ParseString, PKind, StepOrPos); - if (HasUniform != ParseRet::None) - return HasUniform; - return ParseRet::None; } @@ -287,11 +274,50 @@ ParseRet tryParseAlign(StringRef &ParseString, Align &Alignment) { return ParseRet::None; } +#ifndef NDEBUG +// Verify the assumtion that all vectors in the signature of a vector +// function have the same number of elements. +bool verifyAllVectorsHaveSameWidth(FunctionType *Signature) { + SmallVector<VectorType *, 2> VecTys; + if (auto *RetTy = dyn_cast<VectorType>(Signature->getReturnType())) + VecTys.push_back(RetTy); + for (auto *Ty : Signature->params()) + if (auto *VTy = dyn_cast<VectorType>(Ty)) + VecTys.push_back(VTy); + + if (VecTys.size() <= 1) + return true; + + assert(VecTys.size() > 1 && "Invalid number of elements."); + const ElementCount EC = VecTys[0]->getElementCount(); + return llvm::all_of( + llvm::make_range(VecTys.begin() + 1, VecTys.end()), + [&EC](VectorType *VTy) { return (EC == VTy->getElementCount()); }); +} + +#endif // NDEBUG + +// Extract the VectorizationFactor from a given function signature, +// under the assumtion that all vectors have the same number of +// elements, i.e. same ElementCount.Min. +ElementCount getECFromSignature(FunctionType *Signature) { + assert(verifyAllVectorsHaveSameWidth(Signature) && + "Invalid vector signature."); + + if (auto *RetTy = dyn_cast<VectorType>(Signature->getReturnType())) + return RetTy->getElementCount(); + for (auto *Ty : Signature->params()) + if (auto *VTy = dyn_cast<VectorType>(Ty)) + return VTy->getElementCount(); + + return ElementCount(/*Min=*/1, /*Scalable=*/false); +} } // namespace // Format of the ABI name: // _ZGV<isa><mask><vlen><parameters>_<scalarname>[(<redirection>)] -Optional<VFInfo> VFABI::tryDemangleForVFABI(StringRef MangledName) { +Optional<VFInfo> VFABI::tryDemangleForVFABI(StringRef MangledName, + const Module &M) { const StringRef OriginalName = MangledName; // Assume there is no custom name <redirection>, and therefore the // vector name consists of @@ -402,8 +428,34 @@ Optional<VFInfo> VFABI::tryDemangleForVFABI(StringRef MangledName) { assert(Parameters.back().ParamKind == VFParamKind::GlobalPredicate && "The global predicate must be the last parameter"); + // Adjust the VF for scalable signatures. The EC.Min is not encoded + // in the name of the function, but it is encoded in the IR + // signature of the function. We need to extract this information + // because it is needed by the loop vectorizer, which reasons in + // terms of VectorizationFactor or ElementCount. In particular, we + // need to make sure that the VF field of the VFShape class is never + // set to 0. + if (IsScalable) { + const Function *F = M.getFunction(VectorName); + // The declaration of the function must be present in the module + // to be able to retrieve its signature. + if (!F) + return None; + const ElementCount EC = getECFromSignature(F->getFunctionType()); + VF = EC.Min; + } + + // Sanity checks. + // 1. We don't accept a zero lanes vectorization factor. + // 2. We don't accept the demangling if the vector function is not + // present in the module. + if (VF == 0) + return None; + if (!M.getFunction(VectorName)) + return None; + const VFShape Shape({VF, IsScalable, Parameters}); - return VFInfo({Shape, ScalarName, VectorName, ISA}); + return VFInfo({Shape, std::string(ScalarName), std::string(VectorName), ISA}); } VFParamKind VFABI::getVFParamKindFromString(const StringRef Token) { |