diff options
Diffstat (limited to 'clang/lib/Support/RISCVVIntrinsicUtils.cpp')
-rw-r--r-- | clang/lib/Support/RISCVVIntrinsicUtils.cpp | 226 |
1 files changed, 169 insertions, 57 deletions
diff --git a/clang/lib/Support/RISCVVIntrinsicUtils.cpp b/clang/lib/Support/RISCVVIntrinsicUtils.cpp index 513e6376f5ae..25084dd98e5c 100644 --- a/clang/lib/Support/RISCVVIntrinsicUtils.cpp +++ b/clang/lib/Support/RISCVVIntrinsicUtils.cpp @@ -8,16 +8,15 @@ #include "clang/Support/RISCVVIntrinsicUtils.h" #include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/Optional.h" #include "llvm/ADT/SmallSet.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringMap.h" #include "llvm/ADT/StringSet.h" #include "llvm/ADT/Twine.h" +#include "llvm/Support/ErrorHandling.h" #include "llvm/Support/raw_ostream.h" #include <numeric> -#include <set> -#include <unordered_map> +#include <optional> using namespace llvm; @@ -67,7 +66,7 @@ VScaleVal LMULType::getScale(unsigned ElementBitwidth) const { } // Illegal vscale result would be less than 1 if (Log2ScaleResult < 0) - return llvm::None; + return std::nullopt; return 1 << Log2ScaleResult; } @@ -114,7 +113,7 @@ bool RVVType::verifyType() const { return false; if (isFloat() && ElementBitwidth == 8) return false; - unsigned V = Scale.value(); + unsigned V = *Scale; switch (ElementBitwidth) { case 1: case 8: @@ -364,7 +363,8 @@ void RVVType::applyBasicType() { assert(ElementBitwidth != 0 && "Bad element bitwidth!"); } -Optional<PrototypeDescriptor> PrototypeDescriptor::parsePrototypeDescriptor( +std::optional<PrototypeDescriptor> +PrototypeDescriptor::parsePrototypeDescriptor( llvm::StringRef PrototypeDescriptorStr) { PrototypeDescriptor PD; BaseTypeModifier PT = BaseTypeModifier::Invalid; @@ -435,7 +435,7 @@ Optional<PrototypeDescriptor> PrototypeDescriptor::parsePrototypeDescriptor( uint32_t Log2EEW; if (ComplexTT.second.getAsInteger(10, Log2EEW)) { llvm_unreachable("Invalid Log2EEW value!"); - return None; + return std::nullopt; } switch (Log2EEW) { case 3: @@ -452,13 +452,13 @@ Optional<PrototypeDescriptor> PrototypeDescriptor::parsePrototypeDescriptor( break; default: llvm_unreachable("Invalid Log2EEW value, should be [3-6]"); - return None; + return std::nullopt; } } else if (ComplexTT.first == "FixedSEW") { uint32_t NewSEW; if (ComplexTT.second.getAsInteger(10, NewSEW)) { llvm_unreachable("Invalid FixedSEW value!"); - return None; + return std::nullopt; } switch (NewSEW) { case 8: @@ -475,13 +475,13 @@ Optional<PrototypeDescriptor> PrototypeDescriptor::parsePrototypeDescriptor( break; default: llvm_unreachable("Invalid FixedSEW value, should be 8, 16, 32 or 64"); - return None; + return std::nullopt; } } else if (ComplexTT.first == "LFixedLog2LMUL") { int32_t Log2LMUL; if (ComplexTT.second.getAsInteger(10, Log2LMUL)) { llvm_unreachable("Invalid LFixedLog2LMUL value!"); - return None; + return std::nullopt; } switch (Log2LMUL) { case -3: @@ -507,13 +507,13 @@ Optional<PrototypeDescriptor> PrototypeDescriptor::parsePrototypeDescriptor( break; default: llvm_unreachable("Invalid LFixedLog2LMUL value, should be [-3, 3]"); - return None; + return std::nullopt; } } else if (ComplexTT.first == "SFixedLog2LMUL") { int32_t Log2LMUL; if (ComplexTT.second.getAsInteger(10, Log2LMUL)) { llvm_unreachable("Invalid SFixedLog2LMUL value!"); - return None; + return std::nullopt; } switch (Log2LMUL) { case -3: @@ -539,7 +539,7 @@ Optional<PrototypeDescriptor> PrototypeDescriptor::parsePrototypeDescriptor( break; default: llvm_unreachable("Invalid LFixedLog2LMUL value, should be [-3, 3]"); - return None; + return std::nullopt; } } else { @@ -785,20 +785,20 @@ void RVVType::applyFixedLog2LMUL(int Log2LMUL, enum FixedLMULType Type) { Scale = LMUL.getScale(ElementBitwidth); } -Optional<RVVTypes> -RVVType::computeTypes(BasicType BT, int Log2LMUL, unsigned NF, - ArrayRef<PrototypeDescriptor> Prototype) { +std::optional<RVVTypes> +RVVTypeCache::computeTypes(BasicType BT, int Log2LMUL, unsigned NF, + ArrayRef<PrototypeDescriptor> Prototype) { // LMUL x NF must be less than or equal to 8. if ((Log2LMUL >= 1) && (1 << Log2LMUL) * NF > 8) - return llvm::None; + return std::nullopt; RVVTypes Types; for (const PrototypeDescriptor &Proto : Prototype) { auto T = computeType(BT, Log2LMUL, Proto); if (!T) - return llvm::None; + return std::nullopt; // Record legal type index - Types.push_back(T.value()); + Types.push_back(*T); } return Types; } @@ -816,11 +816,8 @@ static uint64_t computeRVVTypeHashValue(BasicType BT, int Log2LMUL, ((uint64_t)(Proto.VTM & 0xff) << 32); } -Optional<RVVTypePtr> RVVType::computeType(BasicType BT, int Log2LMUL, - PrototypeDescriptor Proto) { - // Concat BasicType, LMUL and Proto as key - static std::unordered_map<uint64_t, RVVType> LegalTypes; - static std::set<uint64_t> IllegalTypes; +std::optional<RVVTypePtr> RVVTypeCache::computeType(BasicType BT, int Log2LMUL, + PrototypeDescriptor Proto) { uint64_t Idx = computeRVVTypeHashValue(BT, Log2LMUL, Proto); // Search first auto It = LegalTypes.find(Idx); @@ -828,34 +825,38 @@ Optional<RVVTypePtr> RVVType::computeType(BasicType BT, int Log2LMUL, return &(It->second); if (IllegalTypes.count(Idx)) - return llvm::None; + return std::nullopt; // Compute type and record the result. RVVType T(BT, Log2LMUL, Proto); if (T.isValid()) { // Record legal type index and value. - LegalTypes.insert({Idx, T}); - return &(LegalTypes[Idx]); + std::pair<std::unordered_map<uint64_t, RVVType>::iterator, bool> + InsertResult = LegalTypes.insert({Idx, T}); + return &(InsertResult.first->second); } // Record illegal type index. IllegalTypes.insert(Idx); - return llvm::None; + return std::nullopt; } //===----------------------------------------------------------------------===// // RVVIntrinsic implementation //===----------------------------------------------------------------------===// -RVVIntrinsic::RVVIntrinsic( - StringRef NewName, StringRef Suffix, StringRef NewOverloadedName, - StringRef OverloadedSuffix, StringRef IRName, bool IsMasked, - bool HasMaskedOffOperand, bool HasVL, PolicyScheme Scheme, - bool HasUnMaskedOverloaded, bool HasBuiltinAlias, StringRef ManualCodegen, - const RVVTypes &OutInTypes, const std::vector<int64_t> &NewIntrinsicTypes, - const std::vector<StringRef> &RequiredFeatures, unsigned NF) - : IRName(IRName), IsMasked(IsMasked), HasVL(HasVL), Scheme(Scheme), - HasUnMaskedOverloaded(HasUnMaskedOverloaded), - HasBuiltinAlias(HasBuiltinAlias), ManualCodegen(ManualCodegen.str()), - NF(NF) { +RVVIntrinsic::RVVIntrinsic(StringRef NewName, StringRef Suffix, + StringRef NewOverloadedName, + StringRef OverloadedSuffix, StringRef IRName, + bool IsMasked, bool HasMaskedOffOperand, bool HasVL, + PolicyScheme Scheme, bool SupportOverloading, + bool HasBuiltinAlias, StringRef ManualCodegen, + const RVVTypes &OutInTypes, + const std::vector<int64_t> &NewIntrinsicTypes, + const std::vector<StringRef> &RequiredFeatures, + unsigned NF, Policy NewPolicyAttrs) + : IRName(IRName), IsMasked(IsMasked), + HasMaskedOffOperand(HasMaskedOffOperand), HasVL(HasVL), Scheme(Scheme), + SupportOverloading(SupportOverloading), HasBuiltinAlias(HasBuiltinAlias), + ManualCodegen(ManualCodegen.str()), NF(NF), PolicyAttrs(NewPolicyAttrs) { // Init BuiltinName, Name and OverloadedName BuiltinName = NewName.str(); @@ -868,10 +869,9 @@ RVVIntrinsic::RVVIntrinsic( Name += "_" + Suffix.str(); if (!OverloadedSuffix.empty()) OverloadedName += "_" + OverloadedSuffix.str(); - if (IsMasked) { - BuiltinName += "_m"; - Name += "_m"; - } + + updateNamesAndPolicy(IsMasked, hasPolicy(), Name, BuiltinName, OverloadedName, + PolicyAttrs); // Init OutputType and InputTypes OutputType = OutInTypes[0]; @@ -880,7 +880,7 @@ RVVIntrinsic::RVVIntrinsic( // IntrinsicTypes is unmasked TA version index. Need to update it // if there is merge operand (It is always in first operand). IntrinsicTypes = NewIntrinsicTypes; - if ((IsMasked && HasMaskedOffOperand) || + if ((IsMasked && hasMaskedOffOperand()) || (!IsMasked && hasPassthruOperand())) { for (auto &I : IntrinsicTypes) { if (I >= 0) @@ -899,36 +899,37 @@ std::string RVVIntrinsic::getBuiltinTypeStr() const { } std::string RVVIntrinsic::getSuffixStr( - BasicType Type, int Log2LMUL, + RVVTypeCache &TypeCache, BasicType Type, int Log2LMUL, llvm::ArrayRef<PrototypeDescriptor> PrototypeDescriptors) { SmallVector<std::string> SuffixStrs; for (auto PD : PrototypeDescriptors) { - auto T = RVVType::computeType(Type, Log2LMUL, PD); + auto T = TypeCache.computeType(Type, Log2LMUL, PD); SuffixStrs.push_back((*T)->getShortStr()); } return join(SuffixStrs, "_"); } -llvm::SmallVector<PrototypeDescriptor> -RVVIntrinsic::computeBuiltinTypes(llvm::ArrayRef<PrototypeDescriptor> Prototype, - bool IsMasked, bool HasMaskedOffOperand, - bool HasVL, unsigned NF) { +llvm::SmallVector<PrototypeDescriptor> RVVIntrinsic::computeBuiltinTypes( + llvm::ArrayRef<PrototypeDescriptor> Prototype, bool IsMasked, + bool HasMaskedOffOperand, bool HasVL, unsigned NF, + PolicyScheme DefaultScheme, Policy PolicyAttrs) { SmallVector<PrototypeDescriptor> NewPrototype(Prototype.begin(), Prototype.end()); + bool HasPassthruOp = DefaultScheme == PolicyScheme::HasPassthruOperand; if (IsMasked) { - // If HasMaskedOffOperand, insert result type as first input operand. - if (HasMaskedOffOperand) { + // If HasMaskedOffOperand, insert result type as first input operand if + // need. + if (HasMaskedOffOperand && !PolicyAttrs.isTAMAPolicy()) { if (NF == 1) { NewPrototype.insert(NewPrototype.begin() + 1, NewPrototype[0]); - } else { + } else if (NF > 1) { // Convert // (void, op0 address, op1 address, ...) // to // (void, op0 address, op1 address, ..., maskedoff0, maskedoff1, ...) PrototypeDescriptor MaskoffType = NewPrototype[1]; MaskoffType.TM &= ~static_cast<uint8_t>(TypeModifier::Pointer); - for (unsigned I = 0; I < NF; ++I) - NewPrototype.insert(NewPrototype.begin() + NF + 1, MaskoffType); + NewPrototype.insert(NewPrototype.begin() + NF + 1, NF, MaskoffType); } } if (HasMaskedOffOperand && NF > 1) { @@ -943,7 +944,21 @@ RVVIntrinsic::computeBuiltinTypes(llvm::ArrayRef<PrototypeDescriptor> Prototype, // If IsMasked, insert PrototypeDescriptor:Mask as first input operand. NewPrototype.insert(NewPrototype.begin() + 1, PrototypeDescriptor::Mask); } - } + } else { + if (NF == 1) { + if (PolicyAttrs.isTUPolicy() && HasPassthruOp) + NewPrototype.insert(NewPrototype.begin(), NewPrototype[0]); + } else if (PolicyAttrs.isTUPolicy() && HasPassthruOp) { + // NF > 1 cases for segment load operations. + // Convert + // (void, op0 address, op1 address, ...) + // to + // (void, op0 address, op1 address, maskedoff0, maskedoff1, ...) + PrototypeDescriptor MaskoffType = Prototype[1]; + MaskoffType.TM &= ~static_cast<uint8_t>(TypeModifier::Pointer); + NewPrototype.insert(NewPrototype.begin() + NF + 1, NF, MaskoffType); + } + } // If HasVL, append PrototypeDescriptor:VL to last operand if (HasVL) @@ -951,6 +966,99 @@ RVVIntrinsic::computeBuiltinTypes(llvm::ArrayRef<PrototypeDescriptor> Prototype, return NewPrototype; } +llvm::SmallVector<Policy> +RVVIntrinsic::getSupportedUnMaskedPolicies(bool HasTailPolicy, + bool HasMaskPolicy) { + return { + Policy(Policy::PolicyType::Undisturbed, HasTailPolicy, + HasMaskPolicy), // TU + Policy(Policy::PolicyType::Agnostic, HasTailPolicy, HasMaskPolicy)}; // TA +} + +llvm::SmallVector<Policy> +RVVIntrinsic::getSupportedMaskedPolicies(bool HasTailPolicy, + bool HasMaskPolicy) { + if (HasTailPolicy && HasMaskPolicy) + return { + Policy(Policy::PolicyType::Undisturbed, Policy::PolicyType::Agnostic, + HasTailPolicy, HasMaskPolicy), // TUMA + Policy(Policy::PolicyType::Agnostic, Policy::PolicyType::Agnostic, + HasTailPolicy, HasMaskPolicy), // TAMA + Policy(Policy::PolicyType::Undisturbed, Policy::PolicyType::Undisturbed, + HasTailPolicy, HasMaskPolicy), // TUMU + Policy(Policy::PolicyType::Agnostic, Policy::PolicyType::Undisturbed, + HasTailPolicy, HasMaskPolicy)}; // TAMU + if (HasTailPolicy && !HasMaskPolicy) + return {Policy(Policy::PolicyType::Undisturbed, + Policy::PolicyType::Agnostic, HasTailPolicy, + HasMaskPolicy), // TUM + Policy(Policy::PolicyType::Agnostic, Policy::PolicyType::Agnostic, + HasTailPolicy, HasMaskPolicy)}; // TAM + if (!HasTailPolicy && HasMaskPolicy) + return {Policy(Policy::PolicyType::Agnostic, Policy::PolicyType::Agnostic, + HasTailPolicy, HasMaskPolicy), // MA + Policy(Policy::PolicyType::Agnostic, + Policy::PolicyType::Undisturbed, HasTailPolicy, + HasMaskPolicy)}; // MU + llvm_unreachable("An RVV instruction should not be without both tail policy " + "and mask policy"); +} + +void RVVIntrinsic::updateNamesAndPolicy(bool IsMasked, bool HasPolicy, + std::string &Name, + std::string &BuiltinName, + std::string &OverloadedName, + Policy &PolicyAttrs) { + + auto appendPolicySuffix = [&](const std::string &suffix) { + Name += suffix; + BuiltinName += suffix; + OverloadedName += suffix; + }; + + if (PolicyAttrs.isUnspecified()) { + PolicyAttrs.IsUnspecified = false; + if (IsMasked) { + Name += "_m"; + if (HasPolicy) + BuiltinName += "_tama"; + else + BuiltinName += "_m"; + } else { + if (HasPolicy) + BuiltinName += "_ta"; + } + } else { + if (IsMasked) { + if (PolicyAttrs.isTUMAPolicy() && !PolicyAttrs.hasMaskPolicy()) + appendPolicySuffix("_tum"); + else if (PolicyAttrs.isTAMAPolicy() && !PolicyAttrs.hasMaskPolicy()) + appendPolicySuffix("_tam"); + else if (PolicyAttrs.isMUPolicy() && !PolicyAttrs.hasTailPolicy()) + appendPolicySuffix("_mu"); + else if (PolicyAttrs.isMAPolicy() && !PolicyAttrs.hasTailPolicy()) + appendPolicySuffix("_ma"); + else if (PolicyAttrs.isTUMUPolicy()) + appendPolicySuffix("_tumu"); + else if (PolicyAttrs.isTAMUPolicy()) + appendPolicySuffix("_tamu"); + else if (PolicyAttrs.isTUMAPolicy()) + appendPolicySuffix("_tuma"); + else if (PolicyAttrs.isTAMAPolicy()) + appendPolicySuffix("_tama"); + else + llvm_unreachable("Unhandled policy condition"); + } else { + if (PolicyAttrs.isTUPolicy()) + appendPolicySuffix("_tu"); + else if (PolicyAttrs.isTAPolicy()) + appendPolicySuffix("_ta"); + else + llvm_unreachable("Unhandled policy condition"); + } + } +} + SmallVector<PrototypeDescriptor> parsePrototypes(StringRef Prototypes) { SmallVector<PrototypeDescriptor> PrototypeDescriptors; const StringRef Primaries("evwqom0ztul"); @@ -993,6 +1101,10 @@ raw_ostream &operator<<(raw_ostream &OS, const RVVIntrinsicRecord &Record) { OS << (int)Record.HasMasked << ","; OS << (int)Record.HasVL << ","; OS << (int)Record.HasMaskedOffOperand << ","; + OS << (int)Record.HasTailPolicy << ","; + OS << (int)Record.HasMaskPolicy << ","; + OS << (int)Record.UnMaskedPolicyScheme << ","; + OS << (int)Record.MaskedPolicyScheme << ","; OS << "},\n"; return OS; } |