aboutsummaryrefslogtreecommitdiff
path: root/clang/lib/Support/RISCVVIntrinsicUtils.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'clang/lib/Support/RISCVVIntrinsicUtils.cpp')
-rw-r--r--clang/lib/Support/RISCVVIntrinsicUtils.cpp226
1 files changed, 169 insertions, 57 deletions
diff --git a/clang/lib/Support/RISCVVIntrinsicUtils.cpp b/clang/lib/Support/RISCVVIntrinsicUtils.cpp
index 513e6376f5ae..25084dd98e5c 100644
--- a/clang/lib/Support/RISCVVIntrinsicUtils.cpp
+++ b/clang/lib/Support/RISCVVIntrinsicUtils.cpp
@@ -8,16 +8,15 @@
#include "clang/Support/RISCVVIntrinsicUtils.h"
#include "llvm/ADT/ArrayRef.h"
-#include "llvm/ADT/Optional.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/Twine.h"
+#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
#include <numeric>
-#include <set>
-#include <unordered_map>
+#include <optional>
using namespace llvm;
@@ -67,7 +66,7 @@ VScaleVal LMULType::getScale(unsigned ElementBitwidth) const {
}
// Illegal vscale result would be less than 1
if (Log2ScaleResult < 0)
- return llvm::None;
+ return std::nullopt;
return 1 << Log2ScaleResult;
}
@@ -114,7 +113,7 @@ bool RVVType::verifyType() const {
return false;
if (isFloat() && ElementBitwidth == 8)
return false;
- unsigned V = Scale.value();
+ unsigned V = *Scale;
switch (ElementBitwidth) {
case 1:
case 8:
@@ -364,7 +363,8 @@ void RVVType::applyBasicType() {
assert(ElementBitwidth != 0 && "Bad element bitwidth!");
}
-Optional<PrototypeDescriptor> PrototypeDescriptor::parsePrototypeDescriptor(
+std::optional<PrototypeDescriptor>
+PrototypeDescriptor::parsePrototypeDescriptor(
llvm::StringRef PrototypeDescriptorStr) {
PrototypeDescriptor PD;
BaseTypeModifier PT = BaseTypeModifier::Invalid;
@@ -435,7 +435,7 @@ Optional<PrototypeDescriptor> PrototypeDescriptor::parsePrototypeDescriptor(
uint32_t Log2EEW;
if (ComplexTT.second.getAsInteger(10, Log2EEW)) {
llvm_unreachable("Invalid Log2EEW value!");
- return None;
+ return std::nullopt;
}
switch (Log2EEW) {
case 3:
@@ -452,13 +452,13 @@ Optional<PrototypeDescriptor> PrototypeDescriptor::parsePrototypeDescriptor(
break;
default:
llvm_unreachable("Invalid Log2EEW value, should be [3-6]");
- return None;
+ return std::nullopt;
}
} else if (ComplexTT.first == "FixedSEW") {
uint32_t NewSEW;
if (ComplexTT.second.getAsInteger(10, NewSEW)) {
llvm_unreachable("Invalid FixedSEW value!");
- return None;
+ return std::nullopt;
}
switch (NewSEW) {
case 8:
@@ -475,13 +475,13 @@ Optional<PrototypeDescriptor> PrototypeDescriptor::parsePrototypeDescriptor(
break;
default:
llvm_unreachable("Invalid FixedSEW value, should be 8, 16, 32 or 64");
- return None;
+ return std::nullopt;
}
} else if (ComplexTT.first == "LFixedLog2LMUL") {
int32_t Log2LMUL;
if (ComplexTT.second.getAsInteger(10, Log2LMUL)) {
llvm_unreachable("Invalid LFixedLog2LMUL value!");
- return None;
+ return std::nullopt;
}
switch (Log2LMUL) {
case -3:
@@ -507,13 +507,13 @@ Optional<PrototypeDescriptor> PrototypeDescriptor::parsePrototypeDescriptor(
break;
default:
llvm_unreachable("Invalid LFixedLog2LMUL value, should be [-3, 3]");
- return None;
+ return std::nullopt;
}
} else if (ComplexTT.first == "SFixedLog2LMUL") {
int32_t Log2LMUL;
if (ComplexTT.second.getAsInteger(10, Log2LMUL)) {
llvm_unreachable("Invalid SFixedLog2LMUL value!");
- return None;
+ return std::nullopt;
}
switch (Log2LMUL) {
case -3:
@@ -539,7 +539,7 @@ Optional<PrototypeDescriptor> PrototypeDescriptor::parsePrototypeDescriptor(
break;
default:
llvm_unreachable("Invalid LFixedLog2LMUL value, should be [-3, 3]");
- return None;
+ return std::nullopt;
}
} else {
@@ -785,20 +785,20 @@ void RVVType::applyFixedLog2LMUL(int Log2LMUL, enum FixedLMULType Type) {
Scale = LMUL.getScale(ElementBitwidth);
}
-Optional<RVVTypes>
-RVVType::computeTypes(BasicType BT, int Log2LMUL, unsigned NF,
- ArrayRef<PrototypeDescriptor> Prototype) {
+std::optional<RVVTypes>
+RVVTypeCache::computeTypes(BasicType BT, int Log2LMUL, unsigned NF,
+ ArrayRef<PrototypeDescriptor> Prototype) {
// LMUL x NF must be less than or equal to 8.
if ((Log2LMUL >= 1) && (1 << Log2LMUL) * NF > 8)
- return llvm::None;
+ return std::nullopt;
RVVTypes Types;
for (const PrototypeDescriptor &Proto : Prototype) {
auto T = computeType(BT, Log2LMUL, Proto);
if (!T)
- return llvm::None;
+ return std::nullopt;
// Record legal type index
- Types.push_back(T.value());
+ Types.push_back(*T);
}
return Types;
}
@@ -816,11 +816,8 @@ static uint64_t computeRVVTypeHashValue(BasicType BT, int Log2LMUL,
((uint64_t)(Proto.VTM & 0xff) << 32);
}
-Optional<RVVTypePtr> RVVType::computeType(BasicType BT, int Log2LMUL,
- PrototypeDescriptor Proto) {
- // Concat BasicType, LMUL and Proto as key
- static std::unordered_map<uint64_t, RVVType> LegalTypes;
- static std::set<uint64_t> IllegalTypes;
+std::optional<RVVTypePtr> RVVTypeCache::computeType(BasicType BT, int Log2LMUL,
+ PrototypeDescriptor Proto) {
uint64_t Idx = computeRVVTypeHashValue(BT, Log2LMUL, Proto);
// Search first
auto It = LegalTypes.find(Idx);
@@ -828,34 +825,38 @@ Optional<RVVTypePtr> RVVType::computeType(BasicType BT, int Log2LMUL,
return &(It->second);
if (IllegalTypes.count(Idx))
- return llvm::None;
+ return std::nullopt;
// Compute type and record the result.
RVVType T(BT, Log2LMUL, Proto);
if (T.isValid()) {
// Record legal type index and value.
- LegalTypes.insert({Idx, T});
- return &(LegalTypes[Idx]);
+ std::pair<std::unordered_map<uint64_t, RVVType>::iterator, bool>
+ InsertResult = LegalTypes.insert({Idx, T});
+ return &(InsertResult.first->second);
}
// Record illegal type index.
IllegalTypes.insert(Idx);
- return llvm::None;
+ return std::nullopt;
}
//===----------------------------------------------------------------------===//
// RVVIntrinsic implementation
//===----------------------------------------------------------------------===//
-RVVIntrinsic::RVVIntrinsic(
- StringRef NewName, StringRef Suffix, StringRef NewOverloadedName,
- StringRef OverloadedSuffix, StringRef IRName, bool IsMasked,
- bool HasMaskedOffOperand, bool HasVL, PolicyScheme Scheme,
- bool HasUnMaskedOverloaded, bool HasBuiltinAlias, StringRef ManualCodegen,
- const RVVTypes &OutInTypes, const std::vector<int64_t> &NewIntrinsicTypes,
- const std::vector<StringRef> &RequiredFeatures, unsigned NF)
- : IRName(IRName), IsMasked(IsMasked), HasVL(HasVL), Scheme(Scheme),
- HasUnMaskedOverloaded(HasUnMaskedOverloaded),
- HasBuiltinAlias(HasBuiltinAlias), ManualCodegen(ManualCodegen.str()),
- NF(NF) {
+RVVIntrinsic::RVVIntrinsic(StringRef NewName, StringRef Suffix,
+ StringRef NewOverloadedName,
+ StringRef OverloadedSuffix, StringRef IRName,
+ bool IsMasked, bool HasMaskedOffOperand, bool HasVL,
+ PolicyScheme Scheme, bool SupportOverloading,
+ bool HasBuiltinAlias, StringRef ManualCodegen,
+ const RVVTypes &OutInTypes,
+ const std::vector<int64_t> &NewIntrinsicTypes,
+ const std::vector<StringRef> &RequiredFeatures,
+ unsigned NF, Policy NewPolicyAttrs)
+ : IRName(IRName), IsMasked(IsMasked),
+ HasMaskedOffOperand(HasMaskedOffOperand), HasVL(HasVL), Scheme(Scheme),
+ SupportOverloading(SupportOverloading), HasBuiltinAlias(HasBuiltinAlias),
+ ManualCodegen(ManualCodegen.str()), NF(NF), PolicyAttrs(NewPolicyAttrs) {
// Init BuiltinName, Name and OverloadedName
BuiltinName = NewName.str();
@@ -868,10 +869,9 @@ RVVIntrinsic::RVVIntrinsic(
Name += "_" + Suffix.str();
if (!OverloadedSuffix.empty())
OverloadedName += "_" + OverloadedSuffix.str();
- if (IsMasked) {
- BuiltinName += "_m";
- Name += "_m";
- }
+
+ updateNamesAndPolicy(IsMasked, hasPolicy(), Name, BuiltinName, OverloadedName,
+ PolicyAttrs);
// Init OutputType and InputTypes
OutputType = OutInTypes[0];
@@ -880,7 +880,7 @@ RVVIntrinsic::RVVIntrinsic(
// IntrinsicTypes is unmasked TA version index. Need to update it
// if there is merge operand (It is always in first operand).
IntrinsicTypes = NewIntrinsicTypes;
- if ((IsMasked && HasMaskedOffOperand) ||
+ if ((IsMasked && hasMaskedOffOperand()) ||
(!IsMasked && hasPassthruOperand())) {
for (auto &I : IntrinsicTypes) {
if (I >= 0)
@@ -899,36 +899,37 @@ std::string RVVIntrinsic::getBuiltinTypeStr() const {
}
std::string RVVIntrinsic::getSuffixStr(
- BasicType Type, int Log2LMUL,
+ RVVTypeCache &TypeCache, BasicType Type, int Log2LMUL,
llvm::ArrayRef<PrototypeDescriptor> PrototypeDescriptors) {
SmallVector<std::string> SuffixStrs;
for (auto PD : PrototypeDescriptors) {
- auto T = RVVType::computeType(Type, Log2LMUL, PD);
+ auto T = TypeCache.computeType(Type, Log2LMUL, PD);
SuffixStrs.push_back((*T)->getShortStr());
}
return join(SuffixStrs, "_");
}
-llvm::SmallVector<PrototypeDescriptor>
-RVVIntrinsic::computeBuiltinTypes(llvm::ArrayRef<PrototypeDescriptor> Prototype,
- bool IsMasked, bool HasMaskedOffOperand,
- bool HasVL, unsigned NF) {
+llvm::SmallVector<PrototypeDescriptor> RVVIntrinsic::computeBuiltinTypes(
+ llvm::ArrayRef<PrototypeDescriptor> Prototype, bool IsMasked,
+ bool HasMaskedOffOperand, bool HasVL, unsigned NF,
+ PolicyScheme DefaultScheme, Policy PolicyAttrs) {
SmallVector<PrototypeDescriptor> NewPrototype(Prototype.begin(),
Prototype.end());
+ bool HasPassthruOp = DefaultScheme == PolicyScheme::HasPassthruOperand;
if (IsMasked) {
- // If HasMaskedOffOperand, insert result type as first input operand.
- if (HasMaskedOffOperand) {
+ // If HasMaskedOffOperand, insert result type as first input operand if
+ // need.
+ if (HasMaskedOffOperand && !PolicyAttrs.isTAMAPolicy()) {
if (NF == 1) {
NewPrototype.insert(NewPrototype.begin() + 1, NewPrototype[0]);
- } else {
+ } else if (NF > 1) {
// Convert
// (void, op0 address, op1 address, ...)
// to
// (void, op0 address, op1 address, ..., maskedoff0, maskedoff1, ...)
PrototypeDescriptor MaskoffType = NewPrototype[1];
MaskoffType.TM &= ~static_cast<uint8_t>(TypeModifier::Pointer);
- for (unsigned I = 0; I < NF; ++I)
- NewPrototype.insert(NewPrototype.begin() + NF + 1, MaskoffType);
+ NewPrototype.insert(NewPrototype.begin() + NF + 1, NF, MaskoffType);
}
}
if (HasMaskedOffOperand && NF > 1) {
@@ -943,7 +944,21 @@ RVVIntrinsic::computeBuiltinTypes(llvm::ArrayRef<PrototypeDescriptor> Prototype,
// If IsMasked, insert PrototypeDescriptor:Mask as first input operand.
NewPrototype.insert(NewPrototype.begin() + 1, PrototypeDescriptor::Mask);
}
- }
+ } else {
+ if (NF == 1) {
+ if (PolicyAttrs.isTUPolicy() && HasPassthruOp)
+ NewPrototype.insert(NewPrototype.begin(), NewPrototype[0]);
+ } else if (PolicyAttrs.isTUPolicy() && HasPassthruOp) {
+ // NF > 1 cases for segment load operations.
+ // Convert
+ // (void, op0 address, op1 address, ...)
+ // to
+ // (void, op0 address, op1 address, maskedoff0, maskedoff1, ...)
+ PrototypeDescriptor MaskoffType = Prototype[1];
+ MaskoffType.TM &= ~static_cast<uint8_t>(TypeModifier::Pointer);
+ NewPrototype.insert(NewPrototype.begin() + NF + 1, NF, MaskoffType);
+ }
+ }
// If HasVL, append PrototypeDescriptor:VL to last operand
if (HasVL)
@@ -951,6 +966,99 @@ RVVIntrinsic::computeBuiltinTypes(llvm::ArrayRef<PrototypeDescriptor> Prototype,
return NewPrototype;
}
+llvm::SmallVector<Policy>
+RVVIntrinsic::getSupportedUnMaskedPolicies(bool HasTailPolicy,
+ bool HasMaskPolicy) {
+ return {
+ Policy(Policy::PolicyType::Undisturbed, HasTailPolicy,
+ HasMaskPolicy), // TU
+ Policy(Policy::PolicyType::Agnostic, HasTailPolicy, HasMaskPolicy)}; // TA
+}
+
+llvm::SmallVector<Policy>
+RVVIntrinsic::getSupportedMaskedPolicies(bool HasTailPolicy,
+ bool HasMaskPolicy) {
+ if (HasTailPolicy && HasMaskPolicy)
+ return {
+ Policy(Policy::PolicyType::Undisturbed, Policy::PolicyType::Agnostic,
+ HasTailPolicy, HasMaskPolicy), // TUMA
+ Policy(Policy::PolicyType::Agnostic, Policy::PolicyType::Agnostic,
+ HasTailPolicy, HasMaskPolicy), // TAMA
+ Policy(Policy::PolicyType::Undisturbed, Policy::PolicyType::Undisturbed,
+ HasTailPolicy, HasMaskPolicy), // TUMU
+ Policy(Policy::PolicyType::Agnostic, Policy::PolicyType::Undisturbed,
+ HasTailPolicy, HasMaskPolicy)}; // TAMU
+ if (HasTailPolicy && !HasMaskPolicy)
+ return {Policy(Policy::PolicyType::Undisturbed,
+ Policy::PolicyType::Agnostic, HasTailPolicy,
+ HasMaskPolicy), // TUM
+ Policy(Policy::PolicyType::Agnostic, Policy::PolicyType::Agnostic,
+ HasTailPolicy, HasMaskPolicy)}; // TAM
+ if (!HasTailPolicy && HasMaskPolicy)
+ return {Policy(Policy::PolicyType::Agnostic, Policy::PolicyType::Agnostic,
+ HasTailPolicy, HasMaskPolicy), // MA
+ Policy(Policy::PolicyType::Agnostic,
+ Policy::PolicyType::Undisturbed, HasTailPolicy,
+ HasMaskPolicy)}; // MU
+ llvm_unreachable("An RVV instruction should not be without both tail policy "
+ "and mask policy");
+}
+
+void RVVIntrinsic::updateNamesAndPolicy(bool IsMasked, bool HasPolicy,
+ std::string &Name,
+ std::string &BuiltinName,
+ std::string &OverloadedName,
+ Policy &PolicyAttrs) {
+
+ auto appendPolicySuffix = [&](const std::string &suffix) {
+ Name += suffix;
+ BuiltinName += suffix;
+ OverloadedName += suffix;
+ };
+
+ if (PolicyAttrs.isUnspecified()) {
+ PolicyAttrs.IsUnspecified = false;
+ if (IsMasked) {
+ Name += "_m";
+ if (HasPolicy)
+ BuiltinName += "_tama";
+ else
+ BuiltinName += "_m";
+ } else {
+ if (HasPolicy)
+ BuiltinName += "_ta";
+ }
+ } else {
+ if (IsMasked) {
+ if (PolicyAttrs.isTUMAPolicy() && !PolicyAttrs.hasMaskPolicy())
+ appendPolicySuffix("_tum");
+ else if (PolicyAttrs.isTAMAPolicy() && !PolicyAttrs.hasMaskPolicy())
+ appendPolicySuffix("_tam");
+ else if (PolicyAttrs.isMUPolicy() && !PolicyAttrs.hasTailPolicy())
+ appendPolicySuffix("_mu");
+ else if (PolicyAttrs.isMAPolicy() && !PolicyAttrs.hasTailPolicy())
+ appendPolicySuffix("_ma");
+ else if (PolicyAttrs.isTUMUPolicy())
+ appendPolicySuffix("_tumu");
+ else if (PolicyAttrs.isTAMUPolicy())
+ appendPolicySuffix("_tamu");
+ else if (PolicyAttrs.isTUMAPolicy())
+ appendPolicySuffix("_tuma");
+ else if (PolicyAttrs.isTAMAPolicy())
+ appendPolicySuffix("_tama");
+ else
+ llvm_unreachable("Unhandled policy condition");
+ } else {
+ if (PolicyAttrs.isTUPolicy())
+ appendPolicySuffix("_tu");
+ else if (PolicyAttrs.isTAPolicy())
+ appendPolicySuffix("_ta");
+ else
+ llvm_unreachable("Unhandled policy condition");
+ }
+ }
+}
+
SmallVector<PrototypeDescriptor> parsePrototypes(StringRef Prototypes) {
SmallVector<PrototypeDescriptor> PrototypeDescriptors;
const StringRef Primaries("evwqom0ztul");
@@ -993,6 +1101,10 @@ raw_ostream &operator<<(raw_ostream &OS, const RVVIntrinsicRecord &Record) {
OS << (int)Record.HasMasked << ",";
OS << (int)Record.HasVL << ",";
OS << (int)Record.HasMaskedOffOperand << ",";
+ OS << (int)Record.HasTailPolicy << ",";
+ OS << (int)Record.HasMaskPolicy << ",";
+ OS << (int)Record.UnMaskedPolicyScheme << ",";
+ OS << (int)Record.MaskedPolicyScheme << ",";
OS << "},\n";
return OS;
}