diff options
Diffstat (limited to 'utils/TableGen/CodeGenDAGPatterns.cpp')
-rw-r--r-- | utils/TableGen/CodeGenDAGPatterns.cpp | 2170 |
1 files changed, 1402 insertions, 768 deletions
diff --git a/utils/TableGen/CodeGenDAGPatterns.cpp b/utils/TableGen/CodeGenDAGPatterns.cpp index e48ba3845326..51473f06da79 100644 --- a/utils/TableGen/CodeGenDAGPatterns.cpp +++ b/utils/TableGen/CodeGenDAGPatterns.cpp @@ -13,9 +13,12 @@ //===----------------------------------------------------------------------===// #include "CodeGenDAGPatterns.h" +#include "llvm/ADT/DenseSet.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallString.h" #include "llvm/ADT/StringExtras.h" +#include "llvm/ADT/StringMap.h" #include "llvm/ADT/Twine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/ErrorHandling.h" @@ -28,752 +31,1116 @@ using namespace llvm; #define DEBUG_TYPE "dag-patterns" -//===----------------------------------------------------------------------===// -// EEVT::TypeSet Implementation -//===----------------------------------------------------------------------===// - -static inline bool isInteger(MVT::SimpleValueType VT) { - return MVT(VT).isInteger(); +static inline bool isIntegerOrPtr(MVT VT) { + return VT.isInteger() || VT == MVT::iPTR; } -static inline bool isFloatingPoint(MVT::SimpleValueType VT) { - return MVT(VT).isFloatingPoint(); +static inline bool isFloatingPoint(MVT VT) { + return VT.isFloatingPoint(); } -static inline bool isVector(MVT::SimpleValueType VT) { - return MVT(VT).isVector(); +static inline bool isVector(MVT VT) { + return VT.isVector(); } -static inline bool isScalar(MVT::SimpleValueType VT) { - return !MVT(VT).isVector(); +static inline bool isScalar(MVT VT) { + return !VT.isVector(); } -EEVT::TypeSet::TypeSet(MVT::SimpleValueType VT, TreePattern &TP) { - if (VT == MVT::iAny) - EnforceInteger(TP); - else if (VT == MVT::fAny) - EnforceFloatingPoint(TP); - else if (VT == MVT::vAny) - EnforceVector(TP); - else { - assert((VT < MVT::LAST_VALUETYPE || VT == MVT::iPTR || - VT == MVT::iPTRAny || VT == MVT::Any) && "Not a concrete type!"); - TypeVec.push_back(VT); +template <typename Predicate> +static bool berase_if(MachineValueTypeSet &S, Predicate P) { + bool Erased = false; + // It is ok to iterate over MachineValueTypeSet and remove elements from it + // at the same time. + for (MVT T : S) { + if (!P(T)) + continue; + Erased = true; + S.erase(T); } + return Erased; } +// --- TypeSetByHwMode -EEVT::TypeSet::TypeSet(ArrayRef<MVT::SimpleValueType> VTList) { - assert(!VTList.empty() && "empty list?"); - TypeVec.append(VTList.begin(), VTList.end()); - - if (!VTList.empty()) - assert(VTList[0] != MVT::iAny && VTList[0] != MVT::vAny && - VTList[0] != MVT::fAny); +// This is a parameterized type-set class. For each mode there is a list +// of types that are currently possible for a given tree node. Type +// inference will apply to each mode separately. - // Verify no duplicates. - array_pod_sort(TypeVec.begin(), TypeVec.end()); - assert(std::unique(TypeVec.begin(), TypeVec.end()) == TypeVec.end()); +TypeSetByHwMode::TypeSetByHwMode(ArrayRef<ValueTypeByHwMode> VTList) { + for (const ValueTypeByHwMode &VVT : VTList) + insert(VVT); } -/// FillWithPossibleTypes - Set to all legal types and return true, only valid -/// on completely unknown type sets. -bool EEVT::TypeSet::FillWithPossibleTypes(TreePattern &TP, - bool (*Pred)(MVT::SimpleValueType), - const char *PredicateName) { - assert(isCompletelyUnknown()); - ArrayRef<MVT::SimpleValueType> LegalTypes = - TP.getDAGPatterns().getTargetInfo().getLegalValueTypes(); - - if (TP.hasError()) - return false; - - for (MVT::SimpleValueType VT : LegalTypes) - if (!Pred || Pred(VT)) - TypeVec.push_back(VT); - - // If we have nothing that matches the predicate, bail out. - if (TypeVec.empty()) { - TP.error("Type inference contradiction found, no " + - std::string(PredicateName) + " types found"); - return false; +bool TypeSetByHwMode::isValueTypeByHwMode(bool AllowEmpty) const { + for (const auto &I : *this) { + if (I.second.size() > 1) + return false; + if (!AllowEmpty && I.second.empty()) + return false; } - // No need to sort with one element. - if (TypeVec.size() == 1) return true; - - // Remove duplicates. - array_pod_sort(TypeVec.begin(), TypeVec.end()); - TypeVec.erase(std::unique(TypeVec.begin(), TypeVec.end()), TypeVec.end()); - return true; } -/// hasIntegerTypes - Return true if this TypeSet contains iAny or an -/// integer value type. -bool EEVT::TypeSet::hasIntegerTypes() const { - return any_of(TypeVec, isInteger); +ValueTypeByHwMode TypeSetByHwMode::getValueTypeByHwMode() const { + assert(isValueTypeByHwMode(true) && + "The type set has multiple types for at least one HW mode"); + ValueTypeByHwMode VVT; + for (const auto &I : *this) { + MVT T = I.second.empty() ? MVT::Other : *I.second.begin(); + VVT.getOrCreateTypeForMode(I.first, T); + } + return VVT; } -/// hasFloatingPointTypes - Return true if this TypeSet contains an fAny or -/// a floating point value type. -bool EEVT::TypeSet::hasFloatingPointTypes() const { - return any_of(TypeVec, isFloatingPoint); +bool TypeSetByHwMode::isPossible() const { + for (const auto &I : *this) + if (!I.second.empty()) + return true; + return false; } -/// hasScalarTypes - Return true if this TypeSet contains a scalar value type. -bool EEVT::TypeSet::hasScalarTypes() const { - return any_of(TypeVec, isScalar); +bool TypeSetByHwMode::insert(const ValueTypeByHwMode &VVT) { + bool Changed = false; + SmallDenseSet<unsigned, 4> Modes; + for (const auto &P : VVT) { + unsigned M = P.first; + Modes.insert(M); + // Make sure there exists a set for each specific mode from VVT. + Changed |= getOrCreate(M).insert(P.second).second; + } + + // If VVT has a default mode, add the corresponding type to all + // modes in "this" that do not exist in VVT. + if (Modes.count(DefaultMode)) { + MVT DT = VVT.getType(DefaultMode); + for (auto &I : *this) + if (!Modes.count(I.first)) + Changed |= I.second.insert(DT).second; + } + return Changed; } -/// hasVectorTypes - Return true if this TypeSet contains a vAny or a vector -/// value type. -bool EEVT::TypeSet::hasVectorTypes() const { - return any_of(TypeVec, isVector); +// Constrain the type set to be the intersection with VTS. +bool TypeSetByHwMode::constrain(const TypeSetByHwMode &VTS) { + bool Changed = false; + if (hasDefault()) { + for (const auto &I : VTS) { + unsigned M = I.first; + if (M == DefaultMode || hasMode(M)) + continue; + Map.insert({M, Map.at(DefaultMode)}); + Changed = true; + } + } + + for (auto &I : *this) { + unsigned M = I.first; + SetType &S = I.second; + if (VTS.hasMode(M) || VTS.hasDefault()) { + Changed |= intersect(I.second, VTS.get(M)); + } else if (!S.empty()) { + S.clear(); + Changed = true; + } + } + return Changed; } +template <typename Predicate> +bool TypeSetByHwMode::constrain(Predicate P) { + bool Changed = false; + for (auto &I : *this) + Changed |= berase_if(I.second, [&P](MVT VT) { return !P(VT); }); + return Changed; +} -std::string EEVT::TypeSet::getName() const { - if (TypeVec.empty()) return "<empty>"; +template <typename Predicate> +bool TypeSetByHwMode::assign_if(const TypeSetByHwMode &VTS, Predicate P) { + assert(empty()); + for (const auto &I : VTS) { + SetType &S = getOrCreate(I.first); + for (auto J : I.second) + if (P(J)) + S.insert(J); + } + return !empty(); +} - std::string Result; +void TypeSetByHwMode::writeToStream(raw_ostream &OS) const { + SmallVector<unsigned, 4> Modes; + Modes.reserve(Map.size()); - for (unsigned i = 0, e = TypeVec.size(); i != e; ++i) { - std::string VTName = llvm::getEnumName(TypeVec[i]); - // Strip off MVT:: prefix if present. - if (VTName.substr(0,5) == "MVT::") - VTName = VTName.substr(5); - if (i) Result += ':'; - Result += VTName; + for (const auto &I : *this) + Modes.push_back(I.first); + if (Modes.empty()) { + OS << "{}"; + return; } + array_pod_sort(Modes.begin(), Modes.end()); - if (TypeVec.size() == 1) - return Result; - return "{" + Result + "}"; + OS << '{'; + for (unsigned M : Modes) { + OS << ' ' << getModeName(M) << ':'; + writeToStream(get(M), OS); + } + OS << " }"; } -/// MergeInTypeInfo - This merges in type information from the specified -/// argument. If 'this' changes, it returns true. If the two types are -/// contradictory (e.g. merge f32 into i32) then this flags an error. -bool EEVT::TypeSet::MergeInTypeInfo(const EEVT::TypeSet &InVT, TreePattern &TP){ - if (InVT.isCompletelyUnknown() || *this == InVT || TP.hasError()) - return false; +void TypeSetByHwMode::writeToStream(const SetType &S, raw_ostream &OS) { + SmallVector<MVT, 4> Types(S.begin(), S.end()); + array_pod_sort(Types.begin(), Types.end()); - if (isCompletelyUnknown()) { - *this = InVT; - return true; + OS << '['; + for (unsigned i = 0, e = Types.size(); i != e; ++i) { + OS << ValueTypeByHwMode::getMVTName(Types[i]); + if (i != e-1) + OS << ' '; } + OS << ']'; +} - assert(!TypeVec.empty() && !InVT.TypeVec.empty() && "No unknowns"); +bool TypeSetByHwMode::operator==(const TypeSetByHwMode &VTS) const { + bool HaveDefault = hasDefault(); + if (HaveDefault != VTS.hasDefault()) + return false; - // Handle the abstract cases, seeing if we can resolve them better. - switch (TypeVec[0]) { - default: break; - case MVT::iPTR: - case MVT::iPTRAny: - if (InVT.hasIntegerTypes()) { - EEVT::TypeSet InCopy(InVT); - InCopy.EnforceInteger(TP); - InCopy.EnforceScalar(TP); + if (isSimple()) { + if (VTS.isSimple()) + return *begin() == *VTS.begin(); + return false; + } - if (InCopy.isConcrete()) { - // If the RHS has one integer type, upgrade iPTR to i32. - TypeVec[0] = InVT.TypeVec[0]; - return true; - } + SmallDenseSet<unsigned, 4> Modes; + for (auto &I : *this) + Modes.insert(I.first); + for (const auto &I : VTS) + Modes.insert(I.first); - // If the input has multiple scalar integers, this doesn't add any info. - if (!InCopy.isCompletelyUnknown()) + if (HaveDefault) { + // Both sets have default mode. + for (unsigned M : Modes) { + if (get(M) != VTS.get(M)) + return false; + } + } else { + // Neither set has default mode. + for (unsigned M : Modes) { + // If there is no default mode, an empty set is equivalent to not having + // the corresponding mode. + bool NoModeThis = !hasMode(M) || get(M).empty(); + bool NoModeVTS = !VTS.hasMode(M) || VTS.get(M).empty(); + if (NoModeThis != NoModeVTS) return false; + if (!NoModeThis) + if (get(M) != VTS.get(M)) + return false; } - break; } - // If the input constraint is iAny/iPTR and this is an integer type list, - // remove non-integer types from the list. - if ((InVT.TypeVec[0] == MVT::iPTR || InVT.TypeVec[0] == MVT::iPTRAny) && - hasIntegerTypes()) { - bool MadeChange = EnforceInteger(TP); + return true; +} - // If we're merging in iPTR/iPTRAny and the node currently has a list of - // multiple different integer types, replace them with a single iPTR. - if ((InVT.TypeVec[0] == MVT::iPTR || InVT.TypeVec[0] == MVT::iPTRAny) && - TypeVec.size() != 1) { - TypeVec.assign(1, InVT.TypeVec[0]); - MadeChange = true; - } +namespace llvm { + raw_ostream &operator<<(raw_ostream &OS, const TypeSetByHwMode &T) { + T.writeToStream(OS); + return OS; + } +} - return MadeChange; +LLVM_DUMP_METHOD +void TypeSetByHwMode::dump() const { + dbgs() << *this << '\n'; +} + +bool TypeSetByHwMode::intersect(SetType &Out, const SetType &In) { + bool OutP = Out.count(MVT::iPTR), InP = In.count(MVT::iPTR); + auto Int = [&In](MVT T) -> bool { return !In.count(T); }; + + if (OutP == InP) + return berase_if(Out, Int); + + // Compute the intersection of scalars separately to account for only + // one set containing iPTR. + // The itersection of iPTR with a set of integer scalar types that does not + // include iPTR will result in the most specific scalar type: + // - iPTR is more specific than any set with two elements or more + // - iPTR is less specific than any single integer scalar type. + // For example + // { iPTR } * { i32 } -> { i32 } + // { iPTR } * { i32 i64 } -> { iPTR } + // and + // { iPTR i32 } * { i32 } -> { i32 } + // { iPTR i32 } * { i32 i64 } -> { i32 i64 } + // { iPTR i32 } * { i32 i64 i128 } -> { iPTR i32 } + + // Compute the difference between the two sets in such a way that the + // iPTR is in the set that is being subtracted. This is to see if there + // are any extra scalars in the set without iPTR that are not in the + // set containing iPTR. Then the iPTR could be considered a "wildcard" + // matching these scalars. If there is only one such scalar, it would + // replace the iPTR, if there are more, the iPTR would be retained. + SetType Diff; + if (InP) { + Diff = Out; + berase_if(Diff, [&In](MVT T) { return In.count(T); }); + // Pre-remove these elements and rely only on InP/OutP to determine + // whether a change has been made. + berase_if(Out, [&Diff](MVT T) { return Diff.count(T); }); + } else { + Diff = In; + berase_if(Diff, [&Out](MVT T) { return Out.count(T); }); + Out.erase(MVT::iPTR); + } + + // The actual intersection. + bool Changed = berase_if(Out, Int); + unsigned NumD = Diff.size(); + if (NumD == 0) + return Changed; + + if (NumD == 1) { + Out.insert(*Diff.begin()); + // This is a change only if Out was the one with iPTR (which is now + // being replaced). + Changed |= OutP; + } else { + // Multiple elements from Out are now replaced with iPTR. + Out.insert(MVT::iPTR); + Changed |= !OutP; } + return Changed; +} - // If this is a type list and the RHS is a typelist as well, eliminate entries - // from this list that aren't in the other one. - TypeSet InputSet(*this); +void TypeSetByHwMode::validate() const { +#ifndef NDEBUG + if (empty()) + return; + bool AllEmpty = true; + for (const auto &I : *this) + AllEmpty &= I.second.empty(); + assert(!AllEmpty && + "type set is empty for each HW mode: type contradiction?"); +#endif +} - TypeVec.clear(); - std::set_intersection(InputSet.TypeVec.begin(), InputSet.TypeVec.end(), - InVT.TypeVec.begin(), InVT.TypeVec.end(), - std::back_inserter(TypeVec)); +// --- TypeInfer - // If the intersection is the same size as the original set then we're done. - if (TypeVec.size() == InputSet.TypeVec.size()) +bool TypeInfer::MergeInTypeInfo(TypeSetByHwMode &Out, + const TypeSetByHwMode &In) { + ValidateOnExit _1(Out); + In.validate(); + if (In.empty() || Out == In || TP.hasError()) return false; - - // If we removed all of our types, we have a type contradiction. - if (!TypeVec.empty()) + if (Out.empty()) { + Out = In; return true; + } - // FIXME: Really want an SMLoc here! - TP.error("Type inference contradiction found, merging '" + - InVT.getName() + "' into '" + InputSet.getName() + "'"); - return false; + bool Changed = Out.constrain(In); + if (Changed && Out.empty()) + TP.error("Type contradiction"); + + return Changed; } -/// EnforceInteger - Remove all non-integer types from this set. -bool EEVT::TypeSet::EnforceInteger(TreePattern &TP) { +bool TypeInfer::forceArbitrary(TypeSetByHwMode &Out) { + ValidateOnExit _1(Out); if (TP.hasError()) return false; - // If we know nothing, then get the full set. - if (TypeVec.empty()) - return FillWithPossibleTypes(TP, isInteger, "integer"); - - if (!hasFloatingPointTypes()) - return false; + assert(!Out.empty() && "cannot pick from an empty set"); - TypeSet InputSet(*this); - - // Filter out all the fp types. - TypeVec.erase(remove_if(TypeVec, std::not1(std::ptr_fun(isInteger))), - TypeVec.end()); - - if (TypeVec.empty()) { - TP.error("Type inference contradiction found, '" + - InputSet.getName() + "' needs to be integer"); - return false; + bool Changed = false; + for (auto &I : Out) { + TypeSetByHwMode::SetType &S = I.second; + if (S.size() <= 1) + continue; + MVT T = *S.begin(); // Pick the first element. + S.clear(); + S.insert(T); + Changed = true; } - return true; + return Changed; } -/// EnforceFloatingPoint - Remove all integer types from this set. -bool EEVT::TypeSet::EnforceFloatingPoint(TreePattern &TP) { +bool TypeInfer::EnforceInteger(TypeSetByHwMode &Out) { + ValidateOnExit _1(Out); if (TP.hasError()) return false; - // If we know nothing, then get the full set. - if (TypeVec.empty()) - return FillWithPossibleTypes(TP, isFloatingPoint, "floating point"); + if (!Out.empty()) + return Out.constrain(isIntegerOrPtr); - if (!hasIntegerTypes()) - return false; + return Out.assign_if(getLegalTypes(), isIntegerOrPtr); +} - TypeSet InputSet(*this); +bool TypeInfer::EnforceFloatingPoint(TypeSetByHwMode &Out) { + ValidateOnExit _1(Out); + if (TP.hasError()) + return false; + if (!Out.empty()) + return Out.constrain(isFloatingPoint); - // Filter out all the integer types. - TypeVec.erase(remove_if(TypeVec, std::not1(std::ptr_fun(isFloatingPoint))), - TypeVec.end()); + return Out.assign_if(getLegalTypes(), isFloatingPoint); +} - if (TypeVec.empty()) { - TP.error("Type inference contradiction found, '" + - InputSet.getName() + "' needs to be floating point"); +bool TypeInfer::EnforceScalar(TypeSetByHwMode &Out) { + ValidateOnExit _1(Out); + if (TP.hasError()) return false; - } - return true; + if (!Out.empty()) + return Out.constrain(isScalar); + + return Out.assign_if(getLegalTypes(), isScalar); } -/// EnforceScalar - Remove all vector types from this. -bool EEVT::TypeSet::EnforceScalar(TreePattern &TP) { +bool TypeInfer::EnforceVector(TypeSetByHwMode &Out) { + ValidateOnExit _1(Out); if (TP.hasError()) return false; + if (!Out.empty()) + return Out.constrain(isVector); - // If we know nothing, then get the full set. - if (TypeVec.empty()) - return FillWithPossibleTypes(TP, isScalar, "scalar"); + return Out.assign_if(getLegalTypes(), isVector); +} - if (!hasVectorTypes()) +bool TypeInfer::EnforceAny(TypeSetByHwMode &Out) { + ValidateOnExit _1(Out); + if (TP.hasError() || !Out.empty()) return false; - TypeSet InputSet(*this); + Out = getLegalTypes(); + return true; +} - // Filter out all the vector types. - TypeVec.erase(remove_if(TypeVec, std::not1(std::ptr_fun(isScalar))), - TypeVec.end()); +template <typename Iter, typename Pred, typename Less> +static Iter min_if(Iter B, Iter E, Pred P, Less L) { + if (B == E) + return E; + Iter Min = E; + for (Iter I = B; I != E; ++I) { + if (!P(*I)) + continue; + if (Min == E || L(*I, *Min)) + Min = I; + } + return Min; +} - if (TypeVec.empty()) { - TP.error("Type inference contradiction found, '" + - InputSet.getName() + "' needs to be scalar"); - return false; +template <typename Iter, typename Pred, typename Less> +static Iter max_if(Iter B, Iter E, Pred P, Less L) { + if (B == E) + return E; + Iter Max = E; + for (Iter I = B; I != E; ++I) { + if (!P(*I)) + continue; + if (Max == E || L(*Max, *I)) + Max = I; } - return true; + return Max; } -/// EnforceVector - Remove all vector types from this. -bool EEVT::TypeSet::EnforceVector(TreePattern &TP) { +/// Make sure that for each type in Small, there exists a larger type in Big. +bool TypeInfer::EnforceSmallerThan(TypeSetByHwMode &Small, + TypeSetByHwMode &Big) { + ValidateOnExit _1(Small), _2(Big); if (TP.hasError()) return false; + bool Changed = false; + + if (Small.empty()) + Changed |= EnforceAny(Small); + if (Big.empty()) + Changed |= EnforceAny(Big); + + assert(Small.hasDefault() && Big.hasDefault()); + + std::vector<unsigned> Modes = union_modes(Small, Big); + + // 1. Only allow integer or floating point types and make sure that + // both sides are both integer or both floating point. + // 2. Make sure that either both sides have vector types, or neither + // of them does. + for (unsigned M : Modes) { + TypeSetByHwMode::SetType &S = Small.get(M); + TypeSetByHwMode::SetType &B = Big.get(M); + + if (any_of(S, isIntegerOrPtr) && any_of(S, isIntegerOrPtr)) { + auto NotInt = [](MVT VT) { return !isIntegerOrPtr(VT); }; + Changed |= berase_if(S, NotInt) | + berase_if(B, NotInt); + } else if (any_of(S, isFloatingPoint) && any_of(B, isFloatingPoint)) { + auto NotFP = [](MVT VT) { return !isFloatingPoint(VT); }; + Changed |= berase_if(S, NotFP) | + berase_if(B, NotFP); + } else if (S.empty() || B.empty()) { + Changed = !S.empty() || !B.empty(); + S.clear(); + B.clear(); + } else { + TP.error("Incompatible types"); + return Changed; + } - // If we know nothing, then get the full set. - if (TypeVec.empty()) - return FillWithPossibleTypes(TP, isVector, "vector"); - - TypeSet InputSet(*this); - bool MadeChange = false; - - // Filter out all the scalar types. - TypeVec.erase(remove_if(TypeVec, std::not1(std::ptr_fun(isVector))), - TypeVec.end()); - - if (TypeVec.empty()) { - TP.error("Type inference contradiction found, '" + - InputSet.getName() + "' needs to be a vector"); - return false; + if (none_of(S, isVector) || none_of(B, isVector)) { + Changed |= berase_if(S, isVector) | + berase_if(B, isVector); + } } - return MadeChange; -} + auto LT = [](MVT A, MVT B) -> bool { + return A.getScalarSizeInBits() < B.getScalarSizeInBits() || + (A.getScalarSizeInBits() == B.getScalarSizeInBits() && + A.getSizeInBits() < B.getSizeInBits()); + }; + auto LE = [](MVT A, MVT B) -> bool { + // This function is used when removing elements: when a vector is compared + // to a non-vector, it should return false (to avoid removal). + if (A.isVector() != B.isVector()) + return false; + // Note on the < comparison below: + // X86 has patterns like + // (set VR128X:$dst, (v16i8 (X86vtrunc (v4i32 VR128X:$src1)))), + // where the truncated vector is given a type v16i8, while the source + // vector has type v4i32. They both have the same size in bits. + // The minimal type in the result is obviously v16i8, and when we remove + // all types from the source that are smaller-or-equal than v8i16, the + // only source type would also be removed (since it's equal in size). + return A.getScalarSizeInBits() <= B.getScalarSizeInBits() || + A.getSizeInBits() < B.getSizeInBits(); + }; + + for (unsigned M : Modes) { + TypeSetByHwMode::SetType &S = Small.get(M); + TypeSetByHwMode::SetType &B = Big.get(M); + // MinS = min scalar in Small, remove all scalars from Big that are + // smaller-or-equal than MinS. + auto MinS = min_if(S.begin(), S.end(), isScalar, LT); + if (MinS != S.end()) + Changed |= berase_if(B, std::bind(LE, std::placeholders::_1, *MinS)); + + // MaxS = max scalar in Big, remove all scalars from Small that are + // larger than MaxS. + auto MaxS = max_if(B.begin(), B.end(), isScalar, LT); + if (MaxS != B.end()) + Changed |= berase_if(S, std::bind(LE, *MaxS, std::placeholders::_1)); + + // MinV = min vector in Small, remove all vectors from Big that are + // smaller-or-equal than MinV. + auto MinV = min_if(S.begin(), S.end(), isVector, LT); + if (MinV != S.end()) + Changed |= berase_if(B, std::bind(LE, std::placeholders::_1, *MinV)); + + // MaxV = max vector in Big, remove all vectors from Small that are + // larger than MaxV. + auto MaxV = max_if(B.begin(), B.end(), isVector, LT); + if (MaxV != B.end()) + Changed |= berase_if(S, std::bind(LE, *MaxV, std::placeholders::_1)); + } + + return Changed; +} -/// EnforceSmallerThan - 'this' must be a smaller VT than Other. For vectors -/// this should be based on the element type. Update this and other based on -/// this information. -bool EEVT::TypeSet::EnforceSmallerThan(EEVT::TypeSet &Other, TreePattern &TP) { +/// 1. Ensure that for each type T in Vec, T is a vector type, and that +/// for each type U in Elem, U is a scalar type. +/// 2. Ensure that for each (scalar) type U in Elem, there exists a (vector) +/// type T in Vec, such that U is the element type of T. +bool TypeInfer::EnforceVectorEltTypeIs(TypeSetByHwMode &Vec, + TypeSetByHwMode &Elem) { + ValidateOnExit _1(Vec), _2(Elem); if (TP.hasError()) return false; + bool Changed = false; + + if (Vec.empty()) + Changed |= EnforceVector(Vec); + if (Elem.empty()) + Changed |= EnforceScalar(Elem); + + for (unsigned M : union_modes(Vec, Elem)) { + TypeSetByHwMode::SetType &V = Vec.get(M); + TypeSetByHwMode::SetType &E = Elem.get(M); + + Changed |= berase_if(V, isScalar); // Scalar = !vector + Changed |= berase_if(E, isVector); // Vector = !scalar + assert(!V.empty() && !E.empty()); + + SmallSet<MVT,4> VT, ST; + // Collect element types from the "vector" set. + for (MVT T : V) + VT.insert(T.getVectorElementType()); + // Collect scalar types from the "element" set. + for (MVT T : E) + ST.insert(T); + + // Remove from V all (vector) types whose element type is not in S. + Changed |= berase_if(V, [&ST](MVT T) -> bool { + return !ST.count(T.getVectorElementType()); + }); + // Remove from E all (scalar) types, for which there is no corresponding + // type in V. + Changed |= berase_if(E, [&VT](MVT T) -> bool { return !VT.count(T); }); + } + + return Changed; +} - // Both operands must be integer or FP, but we don't care which. - bool MadeChange = false; - - if (isCompletelyUnknown()) - MadeChange = FillWithPossibleTypes(TP); - - if (Other.isCompletelyUnknown()) - MadeChange = Other.FillWithPossibleTypes(TP); - - // If one side is known to be integer or known to be FP but the other side has - // no information, get at least the type integrality info in there. - if (!hasFloatingPointTypes()) - MadeChange |= Other.EnforceInteger(TP); - else if (!hasIntegerTypes()) - MadeChange |= Other.EnforceFloatingPoint(TP); - if (!Other.hasFloatingPointTypes()) - MadeChange |= EnforceInteger(TP); - else if (!Other.hasIntegerTypes()) - MadeChange |= EnforceFloatingPoint(TP); - - assert(!isCompletelyUnknown() && !Other.isCompletelyUnknown() && - "Should have a type list now"); - - // If one contains vectors but the other doesn't pull vectors out. - if (!hasVectorTypes()) - MadeChange |= Other.EnforceScalar(TP); - else if (!hasScalarTypes()) - MadeChange |= Other.EnforceVector(TP); - if (!Other.hasVectorTypes()) - MadeChange |= EnforceScalar(TP); - else if (!Other.hasScalarTypes()) - MadeChange |= EnforceVector(TP); - - // This code does not currently handle nodes which have multiple types, - // where some types are integer, and some are fp. Assert that this is not - // the case. - assert(!(hasIntegerTypes() && hasFloatingPointTypes()) && - !(Other.hasIntegerTypes() && Other.hasFloatingPointTypes()) && - "SDTCisOpSmallerThanOp does not handle mixed int/fp types!"); +bool TypeInfer::EnforceVectorEltTypeIs(TypeSetByHwMode &Vec, + const ValueTypeByHwMode &VVT) { + TypeSetByHwMode Tmp(VVT); + ValidateOnExit _1(Vec), _2(Tmp); + return EnforceVectorEltTypeIs(Vec, Tmp); +} +/// Ensure that for each type T in Sub, T is a vector type, and there +/// exists a type U in Vec such that U is a vector type with the same +/// element type as T and at least as many elements as T. +bool TypeInfer::EnforceVectorSubVectorTypeIs(TypeSetByHwMode &Vec, + TypeSetByHwMode &Sub) { + ValidateOnExit _1(Vec), _2(Sub); if (TP.hasError()) return false; - // Okay, find the smallest type from current set and remove anything the - // same or smaller from the other set. We need to ensure that the scalar - // type size is smaller than the scalar size of the smallest type. For - // vectors, we also need to make sure that the total size is no larger than - // the size of the smallest type. - { - TypeSet InputSet(Other); - MVT Smallest = *std::min_element(TypeVec.begin(), TypeVec.end(), - [](MVT A, MVT B) { - return A.getScalarSizeInBits() < B.getScalarSizeInBits() || - (A.getScalarSizeInBits() == B.getScalarSizeInBits() && - A.getSizeInBits() < B.getSizeInBits()); - }); - - auto I = remove_if(Other.TypeVec, [Smallest](MVT OtherVT) { - // Don't compare vector and non-vector types. - if (OtherVT.isVector() != Smallest.isVector()) - return false; - // The getSizeInBits() check here is only needed for vectors, but is - // a subset of the scalar check for scalars so no need to qualify. - return OtherVT.getScalarSizeInBits() <= Smallest.getScalarSizeInBits() || - OtherVT.getSizeInBits() < Smallest.getSizeInBits(); - }); - MadeChange |= I != Other.TypeVec.end(); // If we're about to remove types. - Other.TypeVec.erase(I, Other.TypeVec.end()); - - if (Other.TypeVec.empty()) { - TP.error("Type inference contradiction found, '" + InputSet.getName() + - "' has nothing larger than '" + getName() +"'!"); + /// Return true if B is a suB-vector of P, i.e. P is a suPer-vector of B. + auto IsSubVec = [](MVT B, MVT P) -> bool { + if (!B.isVector() || !P.isVector()) return false; - } - } + // Logically a <4 x i32> is a valid subvector of <n x 4 x i32> + // but until there are obvious use-cases for this, keep the + // types separate. + if (B.isScalableVector() != P.isScalableVector()) + return false; + if (B.getVectorElementType() != P.getVectorElementType()) + return false; + return B.getVectorNumElements() < P.getVectorNumElements(); + }; + + /// Return true if S has no element (vector type) that T is a sub-vector of, + /// i.e. has the same element type as T and more elements. + auto NoSubV = [&IsSubVec](const TypeSetByHwMode::SetType &S, MVT T) -> bool { + for (const auto &I : S) + if (IsSubVec(T, I)) + return false; + return true; + }; - // Okay, find the largest type from the other set and remove anything the - // same or smaller from the current set. We need to ensure that the scalar - // type size is larger than the scalar size of the largest type. For - // vectors, we also need to make sure that the total size is no smaller than - // the size of the largest type. - { - TypeSet InputSet(*this); - MVT Largest = *std::max_element(Other.TypeVec.begin(), Other.TypeVec.end(), - [](MVT A, MVT B) { - return A.getScalarSizeInBits() < B.getScalarSizeInBits() || - (A.getScalarSizeInBits() == B.getScalarSizeInBits() && - A.getSizeInBits() < B.getSizeInBits()); - }); - auto I = remove_if(TypeVec, [Largest](MVT OtherVT) { - // Don't compare vector and non-vector types. - if (OtherVT.isVector() != Largest.isVector()) + /// Return true if S has no element (vector type) that T is a super-vector + /// of, i.e. has the same element type as T and fewer elements. + auto NoSupV = [&IsSubVec](const TypeSetByHwMode::SetType &S, MVT T) -> bool { + for (const auto &I : S) + if (IsSubVec(I, T)) return false; - return OtherVT.getScalarSizeInBits() >= Largest.getScalarSizeInBits() || - OtherVT.getSizeInBits() > Largest.getSizeInBits(); - }); - MadeChange |= I != TypeVec.end(); // If we're about to remove types. - TypeVec.erase(I, TypeVec.end()); - - if (TypeVec.empty()) { - TP.error("Type inference contradiction found, '" + InputSet.getName() + - "' has nothing smaller than '" + Other.getName() +"'!"); - return false; - } - } + return true; + }; - return MadeChange; -} + bool Changed = false; -/// EnforceVectorEltTypeIs - 'this' is now constrained to be a vector type -/// whose element is specified by VTOperand. -bool EEVT::TypeSet::EnforceVectorEltTypeIs(MVT::SimpleValueType VT, - TreePattern &TP) { - bool MadeChange = false; + if (Vec.empty()) + Changed |= EnforceVector(Vec); + if (Sub.empty()) + Changed |= EnforceVector(Sub); - MadeChange |= EnforceVector(TP); + for (unsigned M : union_modes(Vec, Sub)) { + TypeSetByHwMode::SetType &S = Sub.get(M); + TypeSetByHwMode::SetType &V = Vec.get(M); - TypeSet InputSet(*this); + Changed |= berase_if(S, isScalar); - // Filter out all the types which don't have the right element type. - auto I = remove_if(TypeVec, [VT](MVT VVT) { - return VVT.getVectorElementType().SimpleTy != VT; - }); - MadeChange |= I != TypeVec.end(); - TypeVec.erase(I, TypeVec.end()); + // Erase all types from S that are not sub-vectors of a type in V. + Changed |= berase_if(S, std::bind(NoSubV, V, std::placeholders::_1)); - if (TypeVec.empty()) { // FIXME: Really want an SMLoc here! - TP.error("Type inference contradiction found, forcing '" + - InputSet.getName() + "' to have a vector element of type " + - getEnumName(VT)); - return false; + // Erase all types from V that are not super-vectors of a type in S. + Changed |= berase_if(V, std::bind(NoSupV, S, std::placeholders::_1)); } - return MadeChange; + return Changed; } -/// EnforceVectorEltTypeIs - 'this' is now constrained to be a vector type -/// whose element is specified by VTOperand. -bool EEVT::TypeSet::EnforceVectorEltTypeIs(EEVT::TypeSet &VTOperand, - TreePattern &TP) { +/// 1. Ensure that V has a scalar type iff W has a scalar type. +/// 2. Ensure that for each vector type T in V, there exists a vector +/// type U in W, such that T and U have the same number of elements. +/// 3. Ensure that for each vector type U in W, there exists a vector +/// type T in V, such that T and U have the same number of elements +/// (reverse of 2). +bool TypeInfer::EnforceSameNumElts(TypeSetByHwMode &V, TypeSetByHwMode &W) { + ValidateOnExit _1(V), _2(W); if (TP.hasError()) return false; - // "This" must be a vector and "VTOperand" must be a scalar. - bool MadeChange = false; - MadeChange |= EnforceVector(TP); - MadeChange |= VTOperand.EnforceScalar(TP); - - // If we know the vector type, it forces the scalar to agree. - if (isConcrete()) { - MVT IVT = getConcrete(); - IVT = IVT.getVectorElementType(); - return MadeChange || VTOperand.MergeInTypeInfo(IVT.SimpleTy, TP); - } - - // If the scalar type is known, filter out vector types whose element types - // disagree. - if (!VTOperand.isConcrete()) - return MadeChange; - - MVT::SimpleValueType VT = VTOperand.getConcrete(); - - MadeChange |= EnforceVectorEltTypeIs(VT, TP); - - return MadeChange; + bool Changed = false; + if (V.empty()) + Changed |= EnforceAny(V); + if (W.empty()) + Changed |= EnforceAny(W); + + // An actual vector type cannot have 0 elements, so we can treat scalars + // as zero-length vectors. This way both vectors and scalars can be + // processed identically. + auto NoLength = [](const SmallSet<unsigned,2> &Lengths, MVT T) -> bool { + return !Lengths.count(T.isVector() ? T.getVectorNumElements() : 0); + }; + + for (unsigned M : union_modes(V, W)) { + TypeSetByHwMode::SetType &VS = V.get(M); + TypeSetByHwMode::SetType &WS = W.get(M); + + SmallSet<unsigned,2> VN, WN; + for (MVT T : VS) + VN.insert(T.isVector() ? T.getVectorNumElements() : 0); + for (MVT T : WS) + WN.insert(T.isVector() ? T.getVectorNumElements() : 0); + + Changed |= berase_if(VS, std::bind(NoLength, WN, std::placeholders::_1)); + Changed |= berase_if(WS, std::bind(NoLength, VN, std::placeholders::_1)); + } + return Changed; } -/// EnforceVectorSubVectorTypeIs - 'this' is now constrained to be a -/// vector type specified by VTOperand. -bool EEVT::TypeSet::EnforceVectorSubVectorTypeIs(EEVT::TypeSet &VTOperand, - TreePattern &TP) { +/// 1. Ensure that for each type T in A, there exists a type U in B, +/// such that T and U have equal size in bits. +/// 2. Ensure that for each type U in B, there exists a type T in A +/// such that T and U have equal size in bits (reverse of 1). +bool TypeInfer::EnforceSameSize(TypeSetByHwMode &A, TypeSetByHwMode &B) { + ValidateOnExit _1(A), _2(B); if (TP.hasError()) return false; + bool Changed = false; + if (A.empty()) + Changed |= EnforceAny(A); + if (B.empty()) + Changed |= EnforceAny(B); - // "This" must be a vector and "VTOperand" must be a vector. - bool MadeChange = false; - MadeChange |= EnforceVector(TP); - MadeChange |= VTOperand.EnforceVector(TP); - - // If one side is known to be integer or known to be FP but the other side has - // no information, get at least the type integrality info in there. - if (!hasFloatingPointTypes()) - MadeChange |= VTOperand.EnforceInteger(TP); - else if (!hasIntegerTypes()) - MadeChange |= VTOperand.EnforceFloatingPoint(TP); - if (!VTOperand.hasFloatingPointTypes()) - MadeChange |= EnforceInteger(TP); - else if (!VTOperand.hasIntegerTypes()) - MadeChange |= EnforceFloatingPoint(TP); - - assert(!isCompletelyUnknown() && !VTOperand.isCompletelyUnknown() && - "Should have a type list now"); - - // If we know the vector type, it forces the scalar types to agree. - // Also force one vector to have more elements than the other. - if (isConcrete()) { - MVT IVT = getConcrete(); - unsigned NumElems = IVT.getVectorNumElements(); - IVT = IVT.getVectorElementType(); - - EEVT::TypeSet EltTypeSet(IVT.SimpleTy, TP); - MadeChange |= VTOperand.EnforceVectorEltTypeIs(EltTypeSet, TP); - - // Only keep types that have less elements than VTOperand. - TypeSet InputSet(VTOperand); - - auto I = remove_if(VTOperand.TypeVec, [NumElems](MVT VVT) { - return VVT.getVectorNumElements() >= NumElems; - }); - MadeChange |= I != VTOperand.TypeVec.end(); - VTOperand.TypeVec.erase(I, VTOperand.TypeVec.end()); - - if (VTOperand.TypeVec.empty()) { // FIXME: Really want an SMLoc here! - TP.error("Type inference contradiction found, forcing '" + - InputSet.getName() + "' to have less vector elements than '" + - getName() + "'"); - return false; - } - } else if (VTOperand.isConcrete()) { - MVT IVT = VTOperand.getConcrete(); - unsigned NumElems = IVT.getVectorNumElements(); - IVT = IVT.getVectorElementType(); - - EEVT::TypeSet EltTypeSet(IVT.SimpleTy, TP); - MadeChange |= EnforceVectorEltTypeIs(EltTypeSet, TP); + auto NoSize = [](const SmallSet<unsigned,2> &Sizes, MVT T) -> bool { + return !Sizes.count(T.getSizeInBits()); + }; - // Only keep types that have more elements than 'this'. - TypeSet InputSet(*this); + for (unsigned M : union_modes(A, B)) { + TypeSetByHwMode::SetType &AS = A.get(M); + TypeSetByHwMode::SetType &BS = B.get(M); + SmallSet<unsigned,2> AN, BN; - auto I = remove_if(TypeVec, [NumElems](MVT VVT) { - return VVT.getVectorNumElements() <= NumElems; - }); - MadeChange |= I != TypeVec.end(); - TypeVec.erase(I, TypeVec.end()); + for (MVT T : AS) + AN.insert(T.getSizeInBits()); + for (MVT T : BS) + BN.insert(T.getSizeInBits()); - if (TypeVec.empty()) { // FIXME: Really want an SMLoc here! - TP.error("Type inference contradiction found, forcing '" + - InputSet.getName() + "' to have more vector elements than '" + - VTOperand.getName() + "'"); - return false; - } + Changed |= berase_if(AS, std::bind(NoSize, BN, std::placeholders::_1)); + Changed |= berase_if(BS, std::bind(NoSize, AN, std::placeholders::_1)); } - return MadeChange; + return Changed; } -/// EnforceameNumElts - If VTOperand is a scalar, then 'this' is a scalar. If -/// VTOperand is a vector, then 'this' must have the same number of elements. -bool EEVT::TypeSet::EnforceSameNumElts(EEVT::TypeSet &VTOperand, - TreePattern &TP) { - if (TP.hasError()) - return false; - - bool MadeChange = false; +void TypeInfer::expandOverloads(TypeSetByHwMode &VTS) { + ValidateOnExit _1(VTS); + TypeSetByHwMode Legal = getLegalTypes(); + bool HaveLegalDef = Legal.hasDefault(); - if (isCompletelyUnknown()) - MadeChange = FillWithPossibleTypes(TP); - - if (VTOperand.isCompletelyUnknown()) - MadeChange = VTOperand.FillWithPossibleTypes(TP); - - // If one contains vectors but the other doesn't pull vectors out. - if (!hasVectorTypes()) - MadeChange |= VTOperand.EnforceScalar(TP); - else if (!hasScalarTypes()) - MadeChange |= VTOperand.EnforceVector(TP); - if (!VTOperand.hasVectorTypes()) - MadeChange |= EnforceScalar(TP); - else if (!VTOperand.hasScalarTypes()) - MadeChange |= EnforceVector(TP); - - // If one type is a vector, make sure the other has the same element count. - // If this a scalar, then we are already done with the above. - if (isConcrete()) { - MVT IVT = getConcrete(); - if (IVT.isVector()) { - unsigned NumElems = IVT.getVectorNumElements(); - - // Only keep types that have same elements as 'this'. - TypeSet InputSet(VTOperand); - - auto I = remove_if(VTOperand.TypeVec, [NumElems](MVT VVT) { - return VVT.getVectorNumElements() != NumElems; - }); - MadeChange |= I != VTOperand.TypeVec.end(); - VTOperand.TypeVec.erase(I, VTOperand.TypeVec.end()); - - if (VTOperand.TypeVec.empty()) { // FIXME: Really want an SMLoc here! - TP.error("Type inference contradiction found, forcing '" + - InputSet.getName() + "' to have same number elements as '" + - getName() + "'"); - return false; - } + for (auto &I : VTS) { + unsigned M = I.first; + if (!Legal.hasMode(M) && !HaveLegalDef) { + TP.error("Invalid mode " + Twine(M)); + return; } - } else if (VTOperand.isConcrete()) { - MVT IVT = VTOperand.getConcrete(); - if (IVT.isVector()) { - unsigned NumElems = IVT.getVectorNumElements(); - - // Only keep types that have same elements as VTOperand. - TypeSet InputSet(*this); + expandOverloads(I.second, Legal.get(M)); + } +} - auto I = remove_if(TypeVec, [NumElems](MVT VVT) { - return VVT.getVectorNumElements() != NumElems; - }); - MadeChange |= I != TypeVec.end(); - TypeVec.erase(I, TypeVec.end()); +void TypeInfer::expandOverloads(TypeSetByHwMode::SetType &Out, + const TypeSetByHwMode::SetType &Legal) { + std::set<MVT> Ovs; + for (MVT T : Out) { + if (!T.isOverloaded()) + continue; - if (TypeVec.empty()) { // FIXME: Really want an SMLoc here! - TP.error("Type inference contradiction found, forcing '" + - InputSet.getName() + "' to have same number elements than '" + - VTOperand.getName() + "'"); - return false; - } + Ovs.insert(T); + // MachineValueTypeSet allows iteration and erasing. + Out.erase(T); + } + + for (MVT Ov : Ovs) { + switch (Ov.SimpleTy) { + case MVT::iPTRAny: + Out.insert(MVT::iPTR); + return; + case MVT::iAny: + for (MVT T : MVT::integer_valuetypes()) + if (Legal.count(T)) + Out.insert(T); + for (MVT T : MVT::integer_vector_valuetypes()) + if (Legal.count(T)) + Out.insert(T); + return; + case MVT::fAny: + for (MVT T : MVT::fp_valuetypes()) + if (Legal.count(T)) + Out.insert(T); + for (MVT T : MVT::fp_vector_valuetypes()) + if (Legal.count(T)) + Out.insert(T); + return; + case MVT::vAny: + for (MVT T : MVT::vector_valuetypes()) + if (Legal.count(T)) + Out.insert(T); + return; + case MVT::Any: + for (MVT T : MVT::all_valuetypes()) + if (Legal.count(T)) + Out.insert(T); + return; + default: + break; } } +} - return MadeChange; +TypeSetByHwMode TypeInfer::getLegalTypes() { + if (!LegalTypesCached) { + // Stuff all types from all modes into the default mode. + const TypeSetByHwMode <S = TP.getDAGPatterns().getLegalTypes(); + for (const auto &I : LTS) + LegalCache.insert(I.second); + LegalTypesCached = true; + } + TypeSetByHwMode VTS; + VTS.getOrCreate(DefaultMode) = LegalCache; + return VTS; } -/// EnforceSameSize - 'this' is now constrained to be same size as VTOperand. -bool EEVT::TypeSet::EnforceSameSize(EEVT::TypeSet &VTOperand, - TreePattern &TP) { - if (TP.hasError()) - return false; +//===----------------------------------------------------------------------===// +// TreePredicateFn Implementation +//===----------------------------------------------------------------------===// - bool MadeChange = false; +/// TreePredicateFn constructor. Here 'N' is a subclass of PatFrag. +TreePredicateFn::TreePredicateFn(TreePattern *N) : PatFragRec(N) { + assert( + (!hasPredCode() || !hasImmCode()) && + ".td file corrupt: can't have a node predicate *and* an imm predicate"); +} - if (isCompletelyUnknown()) - MadeChange = FillWithPossibleTypes(TP); +bool TreePredicateFn::hasPredCode() const { + return isLoad() || isStore() || isAtomic() || + !PatFragRec->getRecord()->getValueAsString("PredicateCode").empty(); +} - if (VTOperand.isCompletelyUnknown()) - MadeChange = VTOperand.FillWithPossibleTypes(TP); +std::string TreePredicateFn::getPredCode() const { + std::string Code = ""; - // If we know one of the types, it forces the other type agree. - if (isConcrete()) { - MVT IVT = getConcrete(); - unsigned Size = IVT.getSizeInBits(); + if (!isLoad() && !isStore() && !isAtomic()) { + Record *MemoryVT = getMemoryVT(); - // Only keep types that have the same size as 'this'. - TypeSet InputSet(VTOperand); + if (MemoryVT) + PrintFatalError(getOrigPatFragRecord()->getRecord()->getLoc(), + "MemoryVT requires IsLoad or IsStore"); + } - auto I = remove_if(VTOperand.TypeVec, - [&](MVT VT) { return VT.getSizeInBits() != Size; }); - MadeChange |= I != VTOperand.TypeVec.end(); - VTOperand.TypeVec.erase(I, VTOperand.TypeVec.end()); + if (!isLoad() && !isStore()) { + if (isUnindexed()) + PrintFatalError(getOrigPatFragRecord()->getRecord()->getLoc(), + "IsUnindexed requires IsLoad or IsStore"); - if (VTOperand.TypeVec.empty()) { // FIXME: Really want an SMLoc here! - TP.error("Type inference contradiction found, forcing '" + - InputSet.getName() + "' to have same size as '" + - getName() + "'"); - return false; - } - } else if (VTOperand.isConcrete()) { - MVT IVT = VTOperand.getConcrete(); - unsigned Size = IVT.getSizeInBits(); + Record *ScalarMemoryVT = getScalarMemoryVT(); - // Only keep types that have the same size as VTOperand. - TypeSet InputSet(*this); + if (ScalarMemoryVT) + PrintFatalError(getOrigPatFragRecord()->getRecord()->getLoc(), + "ScalarMemoryVT requires IsLoad or IsStore"); + } - auto I = - remove_if(TypeVec, [&](MVT VT) { return VT.getSizeInBits() != Size; }); - MadeChange |= I != TypeVec.end(); - TypeVec.erase(I, TypeVec.end()); + if (isLoad() + isStore() + isAtomic() > 1) + PrintFatalError(getOrigPatFragRecord()->getRecord()->getLoc(), + "IsLoad, IsStore, and IsAtomic are mutually exclusive"); - if (TypeVec.empty()) { // FIXME: Really want an SMLoc here! - TP.error("Type inference contradiction found, forcing '" + - InputSet.getName() + "' to have same size as '" + - VTOperand.getName() + "'"); - return false; + if (isLoad()) { + if (!isUnindexed() && !isNonExtLoad() && !isAnyExtLoad() && + !isSignExtLoad() && !isZeroExtLoad() && getMemoryVT() == nullptr && + getScalarMemoryVT() == nullptr) + PrintFatalError(getOrigPatFragRecord()->getRecord()->getLoc(), + "IsLoad cannot be used by itself"); + } else { + if (isNonExtLoad()) + PrintFatalError(getOrigPatFragRecord()->getRecord()->getLoc(), + "IsNonExtLoad requires IsLoad"); + if (isAnyExtLoad()) + PrintFatalError(getOrigPatFragRecord()->getRecord()->getLoc(), + "IsAnyExtLoad requires IsLoad"); + if (isSignExtLoad()) + PrintFatalError(getOrigPatFragRecord()->getRecord()->getLoc(), + "IsSignExtLoad requires IsLoad"); + if (isZeroExtLoad()) + PrintFatalError(getOrigPatFragRecord()->getRecord()->getLoc(), + "IsZeroExtLoad requires IsLoad"); + } + + if (isStore()) { + if (!isUnindexed() && !isTruncStore() && !isNonTruncStore() && + getMemoryVT() == nullptr && getScalarMemoryVT() == nullptr) + PrintFatalError(getOrigPatFragRecord()->getRecord()->getLoc(), + "IsStore cannot be used by itself"); + } else { + if (isNonTruncStore()) + PrintFatalError(getOrigPatFragRecord()->getRecord()->getLoc(), + "IsNonTruncStore requires IsStore"); + if (isTruncStore()) + PrintFatalError(getOrigPatFragRecord()->getRecord()->getLoc(), + "IsTruncStore requires IsStore"); + } + + if (isAtomic()) { + if (getMemoryVT() == nullptr && !isAtomicOrderingMonotonic() && + !isAtomicOrderingAcquire() && !isAtomicOrderingRelease() && + !isAtomicOrderingAcquireRelease() && + !isAtomicOrderingSequentiallyConsistent() && + !isAtomicOrderingAcquireOrStronger() && + !isAtomicOrderingReleaseOrStronger() && + !isAtomicOrderingWeakerThanAcquire() && + !isAtomicOrderingWeakerThanRelease()) + PrintFatalError(getOrigPatFragRecord()->getRecord()->getLoc(), + "IsAtomic cannot be used by itself"); + } else { + if (isAtomicOrderingMonotonic()) + PrintFatalError(getOrigPatFragRecord()->getRecord()->getLoc(), + "IsAtomicOrderingMonotonic requires IsAtomic"); + if (isAtomicOrderingAcquire()) + PrintFatalError(getOrigPatFragRecord()->getRecord()->getLoc(), + "IsAtomicOrderingAcquire requires IsAtomic"); + if (isAtomicOrderingRelease()) + PrintFatalError(getOrigPatFragRecord()->getRecord()->getLoc(), + "IsAtomicOrderingRelease requires IsAtomic"); + if (isAtomicOrderingAcquireRelease()) + PrintFatalError(getOrigPatFragRecord()->getRecord()->getLoc(), + "IsAtomicOrderingAcquireRelease requires IsAtomic"); + if (isAtomicOrderingSequentiallyConsistent()) + PrintFatalError(getOrigPatFragRecord()->getRecord()->getLoc(), + "IsAtomicOrderingSequentiallyConsistent requires IsAtomic"); + if (isAtomicOrderingAcquireOrStronger()) + PrintFatalError(getOrigPatFragRecord()->getRecord()->getLoc(), + "IsAtomicOrderingAcquireOrStronger requires IsAtomic"); + if (isAtomicOrderingReleaseOrStronger()) + PrintFatalError(getOrigPatFragRecord()->getRecord()->getLoc(), + "IsAtomicOrderingReleaseOrStronger requires IsAtomic"); + if (isAtomicOrderingWeakerThanAcquire()) + PrintFatalError(getOrigPatFragRecord()->getRecord()->getLoc(), + "IsAtomicOrderingWeakerThanAcquire requires IsAtomic"); + } + + if (isLoad() || isStore() || isAtomic()) { + StringRef SDNodeName = + isLoad() ? "LoadSDNode" : isStore() ? "StoreSDNode" : "AtomicSDNode"; + + Record *MemoryVT = getMemoryVT(); + + if (MemoryVT) + Code += ("if (cast<" + SDNodeName + ">(N)->getMemoryVT() != MVT::" + + MemoryVT->getName() + ") return false;\n") + .str(); + } + + if (isAtomic() && isAtomicOrderingMonotonic()) + Code += "if (cast<AtomicSDNode>(N)->getOrdering() != " + "AtomicOrdering::Monotonic) return false;\n"; + if (isAtomic() && isAtomicOrderingAcquire()) + Code += "if (cast<AtomicSDNode>(N)->getOrdering() != " + "AtomicOrdering::Acquire) return false;\n"; + if (isAtomic() && isAtomicOrderingRelease()) + Code += "if (cast<AtomicSDNode>(N)->getOrdering() != " + "AtomicOrdering::Release) return false;\n"; + if (isAtomic() && isAtomicOrderingAcquireRelease()) + Code += "if (cast<AtomicSDNode>(N)->getOrdering() != " + "AtomicOrdering::AcquireRelease) return false;\n"; + if (isAtomic() && isAtomicOrderingSequentiallyConsistent()) + Code += "if (cast<AtomicSDNode>(N)->getOrdering() != " + "AtomicOrdering::SequentiallyConsistent) return false;\n"; + + if (isAtomic() && isAtomicOrderingAcquireOrStronger()) + Code += "if (!isAcquireOrStronger(cast<AtomicSDNode>(N)->getOrdering())) " + "return false;\n"; + if (isAtomic() && isAtomicOrderingWeakerThanAcquire()) + Code += "if (isAcquireOrStronger(cast<AtomicSDNode>(N)->getOrdering())) " + "return false;\n"; + + if (isAtomic() && isAtomicOrderingReleaseOrStronger()) + Code += "if (!isReleaseOrStronger(cast<AtomicSDNode>(N)->getOrdering())) " + "return false;\n"; + if (isAtomic() && isAtomicOrderingWeakerThanRelease()) + Code += "if (isReleaseOrStronger(cast<AtomicSDNode>(N)->getOrdering())) " + "return false;\n"; + + if (isLoad() || isStore()) { + StringRef SDNodeName = isLoad() ? "LoadSDNode" : "StoreSDNode"; + + if (isUnindexed()) + Code += ("if (cast<" + SDNodeName + + ">(N)->getAddressingMode() != ISD::UNINDEXED) " + "return false;\n") + .str(); + + if (isLoad()) { + if ((isNonExtLoad() + isAnyExtLoad() + isSignExtLoad() + + isZeroExtLoad()) > 1) + PrintFatalError(getOrigPatFragRecord()->getRecord()->getLoc(), + "IsNonExtLoad, IsAnyExtLoad, IsSignExtLoad, and " + "IsZeroExtLoad are mutually exclusive"); + if (isNonExtLoad()) + Code += "if (cast<LoadSDNode>(N)->getExtensionType() != " + "ISD::NON_EXTLOAD) return false;\n"; + if (isAnyExtLoad()) + Code += "if (cast<LoadSDNode>(N)->getExtensionType() != ISD::EXTLOAD) " + "return false;\n"; + if (isSignExtLoad()) + Code += "if (cast<LoadSDNode>(N)->getExtensionType() != ISD::SEXTLOAD) " + "return false;\n"; + if (isZeroExtLoad()) + Code += "if (cast<LoadSDNode>(N)->getExtensionType() != ISD::ZEXTLOAD) " + "return false;\n"; + } else { + if ((isNonTruncStore() + isTruncStore()) > 1) + PrintFatalError( + getOrigPatFragRecord()->getRecord()->getLoc(), + "IsNonTruncStore, and IsTruncStore are mutually exclusive"); + if (isNonTruncStore()) + Code += + " if (cast<StoreSDNode>(N)->isTruncatingStore()) return false;\n"; + if (isTruncStore()) + Code += + " if (!cast<StoreSDNode>(N)->isTruncatingStore()) return false;\n"; } + + Record *ScalarMemoryVT = getScalarMemoryVT(); + + if (ScalarMemoryVT) + Code += ("if (cast<" + SDNodeName + + ">(N)->getMemoryVT().getScalarType() != MVT::" + + ScalarMemoryVT->getName() + ") return false;\n") + .str(); } - return MadeChange; -} + std::string PredicateCode = PatFragRec->getRecord()->getValueAsString("PredicateCode"); -//===----------------------------------------------------------------------===// -// Helpers for working with extended types. + Code += PredicateCode; -/// Dependent variable map for CodeGenDAGPattern variant generation -typedef std::map<std::string, int> DepVarMap; + if (PredicateCode.empty() && !Code.empty()) + Code += "return true;\n"; -static void FindDepVarsOf(TreePatternNode *N, DepVarMap &DepMap) { - if (N->isLeaf()) { - if (isa<DefInit>(N->getLeafValue())) - DepMap[N->getName()]++; - } else { - for (size_t i = 0, e = N->getNumChildren(); i != e; ++i) - FindDepVarsOf(N->getChild(i), DepMap); - } -} - -/// Find dependent variables within child patterns -static void FindDepVars(TreePatternNode *N, MultipleUseVarSet &DepVars) { - DepVarMap depcounts; - FindDepVarsOf(N, depcounts); - for (const std::pair<std::string, int> &Pair : depcounts) { - if (Pair.second > 1) - DepVars.insert(Pair.first); - } + return Code; } -#ifndef NDEBUG -/// Dump the dependent variable set: -static void DumpDepVars(MultipleUseVarSet &DepVars) { - if (DepVars.empty()) { - DEBUG(errs() << "<empty set>"); - } else { - DEBUG(errs() << "[ "); - for (const std::string &DepVar : DepVars) { - DEBUG(errs() << DepVar << " "); - } - DEBUG(errs() << "]"); - } +bool TreePredicateFn::hasImmCode() const { + return !PatFragRec->getRecord()->getValueAsString("ImmediateCode").empty(); } -#endif +std::string TreePredicateFn::getImmCode() const { + return PatFragRec->getRecord()->getValueAsString("ImmediateCode"); +} -//===----------------------------------------------------------------------===// -// TreePredicateFn Implementation -//===----------------------------------------------------------------------===// +bool TreePredicateFn::immCodeUsesAPInt() const { + return getOrigPatFragRecord()->getRecord()->getValueAsBit("IsAPInt"); +} -/// TreePredicateFn constructor. Here 'N' is a subclass of PatFrag. -TreePredicateFn::TreePredicateFn(TreePattern *N) : PatFragRec(N) { - assert((getPredCode().empty() || getImmCode().empty()) && - ".td file corrupt: can't have a node predicate *and* an imm predicate"); +bool TreePredicateFn::immCodeUsesAPFloat() const { + bool Unset; + // The return value will be false when IsAPFloat is unset. + return getOrigPatFragRecord()->getRecord()->getValueAsBitOrUnset("IsAPFloat", + Unset); } -std::string TreePredicateFn::getPredCode() const { - return PatFragRec->getRecord()->getValueAsString("PredicateCode"); +bool TreePredicateFn::isPredefinedPredicateEqualTo(StringRef Field, + bool Value) const { + bool Unset; + bool Result = + getOrigPatFragRecord()->getRecord()->getValueAsBitOrUnset(Field, Unset); + if (Unset) + return false; + return Result == Value; +} +bool TreePredicateFn::isLoad() const { + return isPredefinedPredicateEqualTo("IsLoad", true); +} +bool TreePredicateFn::isStore() const { + return isPredefinedPredicateEqualTo("IsStore", true); +} +bool TreePredicateFn::isAtomic() const { + return isPredefinedPredicateEqualTo("IsAtomic", true); +} +bool TreePredicateFn::isUnindexed() const { + return isPredefinedPredicateEqualTo("IsUnindexed", true); +} +bool TreePredicateFn::isNonExtLoad() const { + return isPredefinedPredicateEqualTo("IsNonExtLoad", true); +} +bool TreePredicateFn::isAnyExtLoad() const { + return isPredefinedPredicateEqualTo("IsAnyExtLoad", true); +} +bool TreePredicateFn::isSignExtLoad() const { + return isPredefinedPredicateEqualTo("IsSignExtLoad", true); +} +bool TreePredicateFn::isZeroExtLoad() const { + return isPredefinedPredicateEqualTo("IsZeroExtLoad", true); +} +bool TreePredicateFn::isNonTruncStore() const { + return isPredefinedPredicateEqualTo("IsTruncStore", false); +} +bool TreePredicateFn::isTruncStore() const { + return isPredefinedPredicateEqualTo("IsTruncStore", true); +} +bool TreePredicateFn::isAtomicOrderingMonotonic() const { + return isPredefinedPredicateEqualTo("IsAtomicOrderingMonotonic", true); +} +bool TreePredicateFn::isAtomicOrderingAcquire() const { + return isPredefinedPredicateEqualTo("IsAtomicOrderingAcquire", true); +} +bool TreePredicateFn::isAtomicOrderingRelease() const { + return isPredefinedPredicateEqualTo("IsAtomicOrderingRelease", true); +} +bool TreePredicateFn::isAtomicOrderingAcquireRelease() const { + return isPredefinedPredicateEqualTo("IsAtomicOrderingAcquireRelease", true); +} +bool TreePredicateFn::isAtomicOrderingSequentiallyConsistent() const { + return isPredefinedPredicateEqualTo("IsAtomicOrderingSequentiallyConsistent", + true); +} +bool TreePredicateFn::isAtomicOrderingAcquireOrStronger() const { + return isPredefinedPredicateEqualTo("IsAtomicOrderingAcquireOrStronger", true); +} +bool TreePredicateFn::isAtomicOrderingWeakerThanAcquire() const { + return isPredefinedPredicateEqualTo("IsAtomicOrderingAcquireOrStronger", false); +} +bool TreePredicateFn::isAtomicOrderingReleaseOrStronger() const { + return isPredefinedPredicateEqualTo("IsAtomicOrderingReleaseOrStronger", true); +} +bool TreePredicateFn::isAtomicOrderingWeakerThanRelease() const { + return isPredefinedPredicateEqualTo("IsAtomicOrderingReleaseOrStronger", false); +} +Record *TreePredicateFn::getMemoryVT() const { + Record *R = getOrigPatFragRecord()->getRecord(); + if (R->isValueUnset("MemoryVT")) + return nullptr; + return R->getValueAsDef("MemoryVT"); +} +Record *TreePredicateFn::getScalarMemoryVT() const { + Record *R = getOrigPatFragRecord()->getRecord(); + if (R->isValueUnset("ScalarMemoryVT")) + return nullptr; + return R->getValueAsDef("ScalarMemoryVT"); } -std::string TreePredicateFn::getImmCode() const { - return PatFragRec->getRecord()->getValueAsString("ImmediateCode"); +StringRef TreePredicateFn::getImmType() const { + if (immCodeUsesAPInt()) + return "const APInt &"; + if (immCodeUsesAPFloat()) + return "const APFloat &"; + return "int64_t"; } +StringRef TreePredicateFn::getImmTypeIdentifier() const { + if (immCodeUsesAPInt()) + return "APInt"; + else if (immCodeUsesAPFloat()) + return "APFloat"; + return "I64"; +} /// isAlwaysTrue - Return true if this is a noop predicate. bool TreePredicateFn::isAlwaysTrue() const { - return getPredCode().empty() && getImmCode().empty(); + return !hasPredCode() && !hasImmCode(); } /// Return the name to use in the generated code to reference this, this is @@ -790,14 +1157,61 @@ std::string TreePredicateFn::getCodeToRunOnSDNode() const { // Handle immediate predicates first. std::string ImmCode = getImmCode(); if (!ImmCode.empty()) { - std::string Result = - " int64_t Imm = cast<ConstantSDNode>(Node)->getSExtValue();\n"; + if (isLoad()) + PrintFatalError(getOrigPatFragRecord()->getRecord()->getLoc(), + "IsLoad cannot be used with ImmLeaf or its subclasses"); + if (isStore()) + PrintFatalError(getOrigPatFragRecord()->getRecord()->getLoc(), + "IsStore cannot be used with ImmLeaf or its subclasses"); + if (isUnindexed()) + PrintFatalError( + getOrigPatFragRecord()->getRecord()->getLoc(), + "IsUnindexed cannot be used with ImmLeaf or its subclasses"); + if (isNonExtLoad()) + PrintFatalError( + getOrigPatFragRecord()->getRecord()->getLoc(), + "IsNonExtLoad cannot be used with ImmLeaf or its subclasses"); + if (isAnyExtLoad()) + PrintFatalError( + getOrigPatFragRecord()->getRecord()->getLoc(), + "IsAnyExtLoad cannot be used with ImmLeaf or its subclasses"); + if (isSignExtLoad()) + PrintFatalError( + getOrigPatFragRecord()->getRecord()->getLoc(), + "IsSignExtLoad cannot be used with ImmLeaf or its subclasses"); + if (isZeroExtLoad()) + PrintFatalError( + getOrigPatFragRecord()->getRecord()->getLoc(), + "IsZeroExtLoad cannot be used with ImmLeaf or its subclasses"); + if (isNonTruncStore()) + PrintFatalError( + getOrigPatFragRecord()->getRecord()->getLoc(), + "IsNonTruncStore cannot be used with ImmLeaf or its subclasses"); + if (isTruncStore()) + PrintFatalError( + getOrigPatFragRecord()->getRecord()->getLoc(), + "IsTruncStore cannot be used with ImmLeaf or its subclasses"); + if (getMemoryVT()) + PrintFatalError(getOrigPatFragRecord()->getRecord()->getLoc(), + "MemoryVT cannot be used with ImmLeaf or its subclasses"); + if (getScalarMemoryVT()) + PrintFatalError( + getOrigPatFragRecord()->getRecord()->getLoc(), + "ScalarMemoryVT cannot be used with ImmLeaf or its subclasses"); + + std::string Result = (" " + getImmType() + " Imm = ").str(); + if (immCodeUsesAPFloat()) + Result += "cast<ConstantFPSDNode>(Node)->getValueAPF();\n"; + else if (immCodeUsesAPInt()) + Result += "cast<ConstantSDNode>(Node)->getAPIntValue();\n"; + else + Result += "cast<ConstantSDNode>(Node)->getSExtValue();\n"; return Result + ImmCode; } - + // Handle arbitrary node predicates. - assert(!getPredCode().empty() && "Don't have any predicate code!"); - std::string ClassName; + assert(hasPredCode() && "Don't have any predicate code!"); + StringRef ClassName; if (PatFragRec->getOnlyTree()->isLeaf()) ClassName = "SDNode"; else { @@ -808,8 +1222,8 @@ std::string TreePredicateFn::getCodeToRunOnSDNode() const { if (ClassName == "SDNode") Result = " SDNode *N = Node;\n"; else - Result = " auto *N = cast<" + ClassName + ">(Node);\n"; - + Result = " auto *N = cast<" + ClassName.str() + ">(Node);\n"; + return Result + getPredCode(); } @@ -817,7 +1231,6 @@ std::string TreePredicateFn::getCodeToRunOnSDNode() const { // PatternToMatch implementation // - /// getPatternSize - Return the 'size' of this pattern. We want to match large /// patterns before small ones. This is used to determine the size of a /// pattern. @@ -829,10 +1242,8 @@ static unsigned getPatternSize(const TreePatternNode *P, if (P->isLeaf() && isa<IntInit>(P->getLeafValue())) Size += 2; - const ComplexPattern *AM = P->getComplexPatternInfo(CGP); - if (AM) { + if (const ComplexPattern *AM = P->getComplexPatternInfo(CGP)) { Size += AM->getComplexity(); - // We don't want to count any children twice, so return early. return Size; } @@ -844,11 +1255,17 @@ static unsigned getPatternSize(const TreePatternNode *P, // Count children in the count if they are also nodes. for (unsigned i = 0, e = P->getNumChildren(); i != e; ++i) { - TreePatternNode *Child = P->getChild(i); - if (!Child->isLeaf() && Child->getNumTypes() && - Child->getType(0) != MVT::Other) - Size += getPatternSize(Child, CGP); - else if (Child->isLeaf()) { + const TreePatternNode *Child = P->getChild(i); + if (!Child->isLeaf() && Child->getNumTypes()) { + const TypeSetByHwMode &T0 = Child->getType(0); + // At this point, all variable type sets should be simple, i.e. only + // have a default mode. + if (T0.getMachineValueType() != MVT::Other) { + Size += getPatternSize(Child, CGP); + continue; + } + } + if (Child->isLeaf()) { if (isa<IntInit>(Child->getLeafValue())) Size += 5; // Matches a ConstantSDNode (+3) and a specific value (+2). else if (Child->getComplexPatternInfo(CGP)) @@ -868,52 +1285,37 @@ getPatternComplexity(const CodeGenDAGPatterns &CGP) const { return getPatternSize(getSrcPattern(), CGP) + getAddedComplexity(); } - /// getPredicateCheck - Return a single string containing all of this /// pattern's predicates concatenated with "&&" operators. /// std::string PatternToMatch::getPredicateCheck() const { - SmallVector<Record *, 4> PredicateRecs; - for (Init *I : Predicates->getValues()) { - if (DefInit *Pred = dyn_cast<DefInit>(I)) { - Record *Def = Pred->getDef(); - if (!Def->isSubClassOf("Predicate")) { -#ifndef NDEBUG - Def->dump(); -#endif - llvm_unreachable("Unknown predicate type!"); - } - PredicateRecs.push_back(Def); - } - } - // Sort so that different orders get canonicalized to the same string. - std::sort(PredicateRecs.begin(), PredicateRecs.end(), LessRecord()); - - SmallString<128> PredicateCheck; - for (Record *Pred : PredicateRecs) { - if (!PredicateCheck.empty()) - PredicateCheck += " && "; - PredicateCheck += "("; - PredicateCheck += Pred->getValueAsString("CondString"); - PredicateCheck += ")"; - } - - return PredicateCheck.str(); + SmallVector<const Predicate*,4> PredList; + for (const Predicate &P : Predicates) + PredList.push_back(&P); + std::sort(PredList.begin(), PredList.end(), deref<llvm::less>()); + + std::string Check; + for (unsigned i = 0, e = PredList.size(); i != e; ++i) { + if (i != 0) + Check += " && "; + Check += '(' + PredList[i]->getCondString() + ')'; + } + return Check; } //===----------------------------------------------------------------------===// // SDTypeConstraint implementation // -SDTypeConstraint::SDTypeConstraint(Record *R) { +SDTypeConstraint::SDTypeConstraint(Record *R, const CodeGenHwModes &CGH) { OperandNo = R->getValueAsInt("OperandNum"); if (R->isSubClassOf("SDTCisVT")) { ConstraintType = SDTCisVT; - x.SDTCisVT_Info.VT = getValueType(R->getValueAsDef("VT")); - if (x.SDTCisVT_Info.VT == MVT::isVoid) - PrintFatalError(R->getLoc(), "Cannot use 'Void' as type to SDTCisVT"); - + VVT = getValueTypeByHwMode(R->getValueAsDef("VT"), CGH); + for (const auto &P : VVT) + if (P.second == MVT::isVoid) + PrintFatalError(R->getLoc(), "Cannot use 'Void' as type to SDTCisVT"); } else if (R->isSubClassOf("SDTCisPtrTy")) { ConstraintType = SDTCisPtrTy; } else if (R->isSubClassOf("SDTCisInt")) { @@ -942,13 +1344,16 @@ SDTypeConstraint::SDTypeConstraint(Record *R) { R->getValueAsInt("OtherOpNum"); } else if (R->isSubClassOf("SDTCVecEltisVT")) { ConstraintType = SDTCVecEltisVT; - x.SDTCVecEltisVT_Info.VT = getValueType(R->getValueAsDef("VT")); - if (MVT(x.SDTCVecEltisVT_Info.VT).isVector()) - PrintFatalError(R->getLoc(), "Cannot use vector type as SDTCVecEltisVT"); - if (!MVT(x.SDTCVecEltisVT_Info.VT).isInteger() && - !MVT(x.SDTCVecEltisVT_Info.VT).isFloatingPoint()) - PrintFatalError(R->getLoc(), "Must use integer or floating point type " - "as SDTCVecEltisVT"); + VVT = getValueTypeByHwMode(R->getValueAsDef("VT"), CGH); + for (const auto &P : VVT) { + MVT T = P.second; + if (T.isVector()) + PrintFatalError(R->getLoc(), + "Cannot use vector type as SDTCVecEltisVT"); + if (!T.isInteger() && !T.isFloatingPoint()) + PrintFatalError(R->getLoc(), "Must use integer or floating point type " + "as SDTCVecEltisVT"); + } } else if (R->isSubClassOf("SDTCisSameNumEltsAs")) { ConstraintType = SDTCisSameNumEltsAs; x.SDTCisSameNumEltsAs_Info.OtherOperandNum = @@ -998,23 +1403,24 @@ bool SDTypeConstraint::ApplyTypeConstraint(TreePatternNode *N, unsigned ResNo = 0; // The result number being referenced. TreePatternNode *NodeToApply = getOperandNum(OperandNo, N, NodeInfo, ResNo); + TypeInfer &TI = TP.getInfer(); switch (ConstraintType) { case SDTCisVT: // Operand must be a particular type. - return NodeToApply->UpdateNodeType(ResNo, x.SDTCisVT_Info.VT, TP); + return NodeToApply->UpdateNodeType(ResNo, VVT, TP); case SDTCisPtrTy: // Operand must be same as target pointer type. return NodeToApply->UpdateNodeType(ResNo, MVT::iPTR, TP); case SDTCisInt: // Require it to be one of the legal integer VTs. - return NodeToApply->getExtType(ResNo).EnforceInteger(TP); + return TI.EnforceInteger(NodeToApply->getExtType(ResNo)); case SDTCisFP: // Require it to be one of the legal fp VTs. - return NodeToApply->getExtType(ResNo).EnforceFloatingPoint(TP); + return TI.EnforceFloatingPoint(NodeToApply->getExtType(ResNo)); case SDTCisVec: // Require it to be one of the legal vector VTs. - return NodeToApply->getExtType(ResNo).EnforceVector(TP); + return TI.EnforceVector(NodeToApply->getExtType(ResNo)); case SDTCisSameAs: { unsigned OResNo = 0; TreePatternNode *OtherNode = @@ -1032,36 +1438,35 @@ bool SDTypeConstraint::ApplyTypeConstraint(TreePatternNode *N, TP.error(N->getOperator()->getName() + " expects a VT operand!"); return false; } - MVT::SimpleValueType VT = - getValueType(static_cast<DefInit*>(NodeToApply->getLeafValue())->getDef()); - - EEVT::TypeSet TypeListTmp(VT, TP); + DefInit *DI = static_cast<DefInit*>(NodeToApply->getLeafValue()); + const CodeGenTarget &T = TP.getDAGPatterns().getTargetInfo(); + auto VVT = getValueTypeByHwMode(DI->getDef(), T.getHwModes()); + TypeSetByHwMode TypeListTmp(VVT); unsigned OResNo = 0; TreePatternNode *OtherNode = getOperandNum(x.SDTCisVTSmallerThanOp_Info.OtherOperandNum, N, NodeInfo, OResNo); - return TypeListTmp.EnforceSmallerThan(OtherNode->getExtType(OResNo), TP); + return TI.EnforceSmallerThan(TypeListTmp, OtherNode->getExtType(OResNo)); } case SDTCisOpSmallerThanOp: { unsigned BResNo = 0; TreePatternNode *BigOperand = getOperandNum(x.SDTCisOpSmallerThanOp_Info.BigOperandNum, N, NodeInfo, BResNo); - return NodeToApply->getExtType(ResNo). - EnforceSmallerThan(BigOperand->getExtType(BResNo), TP); + return TI.EnforceSmallerThan(NodeToApply->getExtType(ResNo), + BigOperand->getExtType(BResNo)); } case SDTCisEltOfVec: { unsigned VResNo = 0; TreePatternNode *VecOperand = getOperandNum(x.SDTCisEltOfVec_Info.OtherOperandNum, N, NodeInfo, VResNo); - // Filter vector types out of VecOperand that don't have the right element // type. - return VecOperand->getExtType(VResNo). - EnforceVectorEltTypeIs(NodeToApply->getExtType(ResNo), TP); + return TI.EnforceVectorEltTypeIs(VecOperand->getExtType(VResNo), + NodeToApply->getExtType(ResNo)); } case SDTCisSubVecOfVec: { unsigned VResNo = 0; @@ -1071,28 +1476,27 @@ bool SDTypeConstraint::ApplyTypeConstraint(TreePatternNode *N, // Filter vector types out of BigVecOperand that don't have the // right subvector type. - return BigVecOperand->getExtType(VResNo). - EnforceVectorSubVectorTypeIs(NodeToApply->getExtType(ResNo), TP); + return TI.EnforceVectorSubVectorTypeIs(BigVecOperand->getExtType(VResNo), + NodeToApply->getExtType(ResNo)); } case SDTCVecEltisVT: { - return NodeToApply->getExtType(ResNo). - EnforceVectorEltTypeIs(x.SDTCVecEltisVT_Info.VT, TP); + return TI.EnforceVectorEltTypeIs(NodeToApply->getExtType(ResNo), VVT); } case SDTCisSameNumEltsAs: { unsigned OResNo = 0; TreePatternNode *OtherNode = getOperandNum(x.SDTCisSameNumEltsAs_Info.OtherOperandNum, N, NodeInfo, OResNo); - return OtherNode->getExtType(OResNo). - EnforceSameNumElts(NodeToApply->getExtType(ResNo), TP); + return TI.EnforceSameNumElts(OtherNode->getExtType(OResNo), + NodeToApply->getExtType(ResNo)); } case SDTCisSameSizeAs: { unsigned OResNo = 0; TreePatternNode *OtherNode = getOperandNum(x.SDTCisSameSizeAs_Info.OtherOperandNum, N, NodeInfo, OResNo); - return OtherNode->getExtType(OResNo). - EnforceSameSize(NodeToApply->getExtType(ResNo), TP); + return TI.EnforceSameSize(OtherNode->getExtType(OResNo), + NodeToApply->getExtType(ResNo)); } } llvm_unreachable("Invalid ConstraintType!"); @@ -1110,9 +1514,11 @@ bool TreePatternNode::UpdateNodeTypeFromInst(unsigned ResNo, return false; // The Operand class specifies a type directly. - if (Operand->isSubClassOf("Operand")) - return UpdateNodeType(ResNo, getValueType(Operand->getValueAsDef("Type")), - TP); + if (Operand->isSubClassOf("Operand")) { + Record *R = Operand->getValueAsDef("Type"); + const CodeGenTarget &T = TP.getDAGPatterns().getTargetInfo(); + return UpdateNodeType(ResNo, getValueTypeByHwMode(R, T.getHwModes()), TP); + } // PointerLikeRegClass has a type that is determined at runtime. if (Operand->isSubClassOf("PointerLikeRegClass")) @@ -1131,11 +1537,53 @@ bool TreePatternNode::UpdateNodeTypeFromInst(unsigned ResNo, return UpdateNodeType(ResNo, Tgt.getRegisterClass(RC).getValueTypes(), TP); } +bool TreePatternNode::ContainsUnresolvedType(TreePattern &TP) const { + for (unsigned i = 0, e = Types.size(); i != e; ++i) + if (!TP.getInfer().isConcrete(Types[i], true)) + return true; + for (unsigned i = 0, e = getNumChildren(); i != e; ++i) + if (getChild(i)->ContainsUnresolvedType(TP)) + return true; + return false; +} + +bool TreePatternNode::hasProperTypeByHwMode() const { + for (const TypeSetByHwMode &S : Types) + if (!S.isDefaultOnly()) + return true; + for (TreePatternNode *C : Children) + if (C->hasProperTypeByHwMode()) + return true; + return false; +} + +bool TreePatternNode::hasPossibleType() const { + for (const TypeSetByHwMode &S : Types) + if (!S.isPossible()) + return false; + for (TreePatternNode *C : Children) + if (!C->hasPossibleType()) + return false; + return true; +} + +bool TreePatternNode::setDefaultMode(unsigned Mode) { + for (TypeSetByHwMode &S : Types) { + S.makeSimple(Mode); + // Check if the selected mode had a type conflict. + if (S.get(DefaultMode).empty()) + return false; + } + for (TreePatternNode *C : Children) + if (!C->setDefaultMode(Mode)) + return false; + return true; +} //===----------------------------------------------------------------------===// // SDNodeInfo implementation // -SDNodeInfo::SDNodeInfo(Record *R) : Def(R) { +SDNodeInfo::SDNodeInfo(Record *R, const CodeGenHwModes &CGH) : Def(R) { EnumName = R->getValueAsString("Opcode"); SDClassName = R->getValueAsString("SDClass"); Record *TypeProfile = R->getValueAsDef("TypeProfile"); @@ -1178,7 +1626,8 @@ SDNodeInfo::SDNodeInfo(Record *R) : Def(R) { // Parse the type constraints. std::vector<Record*> ConstraintList = TypeProfile->getValueAsListOfDefs("Constraints"); - TypeConstraints.assign(ConstraintList.begin(), ConstraintList.end()); + for (Record *R : ConstraintList) + TypeConstraints.emplace_back(R, CGH); } /// getKnownType - If the type constraints on this node imply a fixed type @@ -1198,7 +1647,9 @@ MVT::SimpleValueType SDNodeInfo::getKnownType(unsigned ResNo) const { switch (Constraint.ConstraintType) { default: break; case SDTypeConstraint::SDTCisVT: - return Constraint.x.SDTCisVT_Info.VT; + if (Constraint.VVT.isSimple()) + return Constraint.VVT.getSimple().SimpleTy; + break; case SDTypeConstraint::SDTCisPtrTy: return MVT::iPTR; } @@ -1284,8 +1735,10 @@ void TreePatternNode::print(raw_ostream &OS) const { else OS << '(' << getOperator()->getName(); - for (unsigned i = 0, e = Types.size(); i != e; ++i) - OS << ':' << getExtType(i).getName(); + for (unsigned i = 0, e = Types.size(); i != e; ++i) { + OS << ':'; + getExtType(i).writeToStream(OS); + } if (!isLeaf()) { if (getNumChildren() != 0) { @@ -1368,7 +1821,7 @@ TreePatternNode *TreePatternNode::clone() const { /// RemoveAllTypes - Recursively strip all the types of this tree. void TreePatternNode::RemoveAllTypes() { // Reset to unknown type. - std::fill(Types.begin(), Types.end(), EEVT::TypeSet()); + std::fill(Types.begin(), Types.end(), TypeSetByHwMode()); if (isLeaf()) return; for (unsigned i = 0, e = getNumChildren(); i != e; ++i) getChild(i)->RemoveAllTypes(); @@ -1485,18 +1938,20 @@ TreePatternNode *TreePatternNode::InlinePatternFragments(TreePattern &TP) { /// When Unnamed is false, return the type of a named DAG operand such as the /// GPR:$src operand above. /// -static EEVT::TypeSet getImplicitType(Record *R, unsigned ResNo, - bool NotRegisters, - bool Unnamed, - TreePattern &TP) { +static TypeSetByHwMode getImplicitType(Record *R, unsigned ResNo, + bool NotRegisters, + bool Unnamed, + TreePattern &TP) { + CodeGenDAGPatterns &CDP = TP.getDAGPatterns(); + // Check to see if this is a register operand. if (R->isSubClassOf("RegisterOperand")) { assert(ResNo == 0 && "Regoperand ref only has one result!"); if (NotRegisters) - return EEVT::TypeSet(); // Unknown. + return TypeSetByHwMode(); // Unknown. Record *RegClass = R->getValueAsDef("RegClass"); const CodeGenTarget &T = TP.getDAGPatterns().getTargetInfo(); - return EEVT::TypeSet(T.getRegisterClass(RegClass).getValueTypes()); + return TypeSetByHwMode(T.getRegisterClass(RegClass).getValueTypes()); } // Check to see if this is a register or a register class. @@ -1505,33 +1960,33 @@ static EEVT::TypeSet getImplicitType(Record *R, unsigned ResNo, // An unnamed register class represents itself as an i32 immediate, for // example on a COPY_TO_REGCLASS instruction. if (Unnamed) - return EEVT::TypeSet(MVT::i32, TP); + return TypeSetByHwMode(MVT::i32); // In a named operand, the register class provides the possible set of // types. if (NotRegisters) - return EEVT::TypeSet(); // Unknown. + return TypeSetByHwMode(); // Unknown. const CodeGenTarget &T = TP.getDAGPatterns().getTargetInfo(); - return EEVT::TypeSet(T.getRegisterClass(R).getValueTypes()); + return TypeSetByHwMode(T.getRegisterClass(R).getValueTypes()); } if (R->isSubClassOf("PatFrag")) { assert(ResNo == 0 && "FIXME: PatFrag with multiple results?"); // Pattern fragment types will be resolved when they are inlined. - return EEVT::TypeSet(); // Unknown. + return TypeSetByHwMode(); // Unknown. } if (R->isSubClassOf("Register")) { assert(ResNo == 0 && "Registers only produce one result!"); if (NotRegisters) - return EEVT::TypeSet(); // Unknown. + return TypeSetByHwMode(); // Unknown. const CodeGenTarget &T = TP.getDAGPatterns().getTargetInfo(); - return EEVT::TypeSet(T.getRegisterVTs(R)); + return TypeSetByHwMode(T.getRegisterVTs(R)); } if (R->isSubClassOf("SubRegIndex")) { assert(ResNo == 0 && "SubRegisterIndices only produce one result!"); - return EEVT::TypeSet(MVT::i32, TP); + return TypeSetByHwMode(MVT::i32); } if (R->isSubClassOf("ValueType")) { @@ -1541,46 +1996,51 @@ static EEVT::TypeSet getImplicitType(Record *R, unsigned ResNo, // (sext_inreg GPR:$src, i16) // ~~~ if (Unnamed) - return EEVT::TypeSet(MVT::Other, TP); + return TypeSetByHwMode(MVT::Other); // With a name, the ValueType simply provides the type of the named // variable. // // (sext_inreg i32:$src, i16) // ~~~~~~~~ if (NotRegisters) - return EEVT::TypeSet(); // Unknown. - return EEVT::TypeSet(getValueType(R), TP); + return TypeSetByHwMode(); // Unknown. + const CodeGenHwModes &CGH = CDP.getTargetInfo().getHwModes(); + return TypeSetByHwMode(getValueTypeByHwMode(R, CGH)); } if (R->isSubClassOf("CondCode")) { assert(ResNo == 0 && "This node only has one result!"); // Using a CondCodeSDNode. - return EEVT::TypeSet(MVT::Other, TP); + return TypeSetByHwMode(MVT::Other); } if (R->isSubClassOf("ComplexPattern")) { assert(ResNo == 0 && "FIXME: ComplexPattern with multiple results?"); if (NotRegisters) - return EEVT::TypeSet(); // Unknown. - return EEVT::TypeSet(TP.getDAGPatterns().getComplexPattern(R).getValueType(), - TP); + return TypeSetByHwMode(); // Unknown. + return TypeSetByHwMode(CDP.getComplexPattern(R).getValueType()); } if (R->isSubClassOf("PointerLikeRegClass")) { assert(ResNo == 0 && "Regclass can only have one result!"); - return EEVT::TypeSet(MVT::iPTR, TP); + TypeSetByHwMode VTS(MVT::iPTR); + TP.getInfer().expandOverloads(VTS); + return VTS; } if (R->getName() == "node" || R->getName() == "srcvalue" || R->getName() == "zero_reg") { // Placeholder. - return EEVT::TypeSet(); // Unknown. + return TypeSetByHwMode(); // Unknown. } - if (R->isSubClassOf("Operand")) - return EEVT::TypeSet(getValueType(R->getValueAsDef("Type"))); + if (R->isSubClassOf("Operand")) { + const CodeGenHwModes &CGH = CDP.getTargetInfo().getHwModes(); + Record *T = R->getValueAsDef("Type"); + return TypeSetByHwMode(getValueTypeByHwMode(T, CGH)); + } TP.error("Unknown node flavor used in pattern: " + R->getName()); - return EEVT::TypeSet(MVT::Other, TP); + return TypeSetByHwMode(MVT::Other); } @@ -1722,29 +2182,34 @@ bool TreePatternNode::ApplyTypeConstraints(TreePattern &TP, bool NotRegisters) { assert(Types.size() == 1 && "Invalid IntInit"); // Int inits are always integers. :) - bool MadeChange = Types[0].EnforceInteger(TP); - - if (!Types[0].isConcrete()) - return MadeChange; + bool MadeChange = TP.getInfer().EnforceInteger(Types[0]); - MVT::SimpleValueType VT = getType(0); - if (VT == MVT::iPTR || VT == MVT::iPTRAny) + if (!TP.getInfer().isConcrete(Types[0], false)) return MadeChange; - unsigned Size = MVT(VT).getSizeInBits(); - // Make sure that the value is representable for this type. - if (Size >= 32) return MadeChange; - - // Check that the value doesn't use more bits than we have. It must either - // be a sign- or zero-extended equivalent of the original. - int64_t SignBitAndAbove = II->getValue() >> (Size - 1); - if (SignBitAndAbove == -1 || SignBitAndAbove == 0 || SignBitAndAbove == 1) - return MadeChange; + ValueTypeByHwMode VVT = TP.getInfer().getConcrete(Types[0], false); + for (auto &P : VVT) { + MVT::SimpleValueType VT = P.second.SimpleTy; + if (VT == MVT::iPTR || VT == MVT::iPTRAny) + continue; + unsigned Size = MVT(VT).getSizeInBits(); + // Make sure that the value is representable for this type. + if (Size >= 32) + continue; + // Check that the value doesn't use more bits than we have. It must + // either be a sign- or zero-extended equivalent of the original. + int64_t SignBitAndAbove = II->getValue() >> (Size - 1); + if (SignBitAndAbove == -1 || SignBitAndAbove == 0 || + SignBitAndAbove == 1) + continue; - TP.error("Integer value '" + itostr(II->getValue()) + - "' is out of range for type '" + getEnumName(getType(0)) + "'!"); - return false; + TP.error("Integer value '" + itostr(II->getValue()) + + "' is out of range for type '" + getEnumName(VT) + "'!"); + break; + } + return MadeChange; } + return false; } @@ -1773,7 +2238,7 @@ bool TreePatternNode::ApplyTypeConstraints(TreePattern &TP, bool NotRegisters) { bool MadeChange = false; for (unsigned i = 0; i < getNumChildren(); ++i) - MadeChange = getChild(i)->ApplyTypeConstraints(TP, NotRegisters); + MadeChange |= getChild(i)->ApplyTypeConstraints(TP, NotRegisters); return MadeChange; } @@ -1818,9 +2283,10 @@ bool TreePatternNode::ApplyTypeConstraints(TreePattern &TP, bool NotRegisters) { return false; } - bool MadeChange = NI.ApplyTypeConstraints(this, TP); + bool MadeChange = false; for (unsigned i = 0, e = getNumChildren(); i != e; ++i) MadeChange |= getChild(i)->ApplyTypeConstraints(TP, NotRegisters); + MadeChange |= NI.ApplyTypeConstraints(this, TP); return MadeChange; } @@ -1975,18 +2441,6 @@ bool TreePatternNode::ApplyTypeConstraints(TreePattern &TP, bool NotRegisters) { } bool MadeChange = getChild(0)->ApplyTypeConstraints(TP, NotRegisters); - - - // If either the output or input of the xform does not have exact - // type info. We assume they must be the same. Otherwise, it is perfectly - // legal to transform from one type to a completely different type. -#if 0 - if (!hasTypeSet() || !getChild(0)->hasTypeSet()) { - bool MadeChange = UpdateNodeType(getChild(0)->getExtType(), TP); - MadeChange |= getChild(0)->UpdateNodeType(getExtType(), TP); - return MadeChange; - } -#endif return MadeChange; } @@ -2050,20 +2504,23 @@ bool TreePatternNode::canPatternMatch(std::string &Reason, TreePattern::TreePattern(Record *TheRec, ListInit *RawPat, bool isInput, CodeGenDAGPatterns &cdp) : TheRecord(TheRec), CDP(cdp), - isInputPattern(isInput), HasError(false) { + isInputPattern(isInput), HasError(false), + Infer(*this) { for (Init *I : RawPat->getValues()) Trees.push_back(ParseTreePattern(I, "")); } TreePattern::TreePattern(Record *TheRec, DagInit *Pat, bool isInput, CodeGenDAGPatterns &cdp) : TheRecord(TheRec), CDP(cdp), - isInputPattern(isInput), HasError(false) { + isInputPattern(isInput), HasError(false), + Infer(*this) { Trees.push_back(ParseTreePattern(Pat, "")); } TreePattern::TreePattern(Record *TheRec, TreePatternNode *Pat, bool isInput, CodeGenDAGPatterns &cdp) : TheRecord(TheRec), CDP(cdp), - isInputPattern(isInput), HasError(false) { + isInputPattern(isInput), HasError(false), + Infer(*this) { Trees.push_back(Pat); } @@ -2158,7 +2615,8 @@ TreePatternNode *TreePattern::ParseTreePattern(Init *TheInit, StringRef OpName){ // Apply the type cast. assert(New->getNumTypes() == 1 && "FIXME: Unhandled"); - New->UpdateNodeType(0, getValueType(Operator), *this); + const CodeGenHwModes &CGH = getDAGPatterns().getTargetInfo().getHwModes(); + New->UpdateNodeType(0, getValueTypeByHwMode(Operator, CGH), *this); if (!OpName.empty()) error("ValueType cast should not have a name!"); @@ -2273,7 +2731,7 @@ static bool SimplifyTree(TreePatternNode *&N) { // If we have a bitconvert with a resolved type and if the source and // destination types are the same, then the bitconvert is useless, remove it. if (N->getOperator()->getName() == "bitconvert" && - N->getExtType(0).isConcrete() && + N->getExtType(0).isValueTypeByHwMode(false) && N->getExtType(0) == N->getChild(0)->getExtType(0) && N->getName().empty()) { N = N->getChild(0); @@ -2304,7 +2762,7 @@ InferAllTypes(const StringMap<SmallVector<TreePatternNode*,1> > *InNamedTypes) { bool MadeChange = true; while (MadeChange) { MadeChange = false; - for (TreePatternNode *Tree : Trees) { + for (TreePatternNode *&Tree : Trees) { MadeChange |= Tree->ApplyTypeConstraints(*this, false); MadeChange |= SimplifyTree(Tree); } @@ -2364,7 +2822,7 @@ InferAllTypes(const StringMap<SmallVector<TreePatternNode*,1> > *InNamedTypes) { bool HasUnresolvedTypes = false; for (const TreePatternNode *Tree : Trees) - HasUnresolvedTypes |= Tree->ContainsUnresolvedType(); + HasUnresolvedTypes |= Tree->ContainsUnresolvedType(*this); return !HasUnresolvedTypes; } @@ -2396,8 +2854,10 @@ void TreePattern::dump() const { print(errs()); } // CodeGenDAGPatterns implementation // -CodeGenDAGPatterns::CodeGenDAGPatterns(RecordKeeper &R) : - Records(R), Target(R) { +CodeGenDAGPatterns::CodeGenDAGPatterns(RecordKeeper &R, + PatternRewriterFn PatternRewriter) + : Records(R), Target(R), LegalVTS(Target.getLegalValueTypes()), + PatternRewriter(PatternRewriter) { Intrinsics = CodeGenIntrinsicTable(Records, false); TgtIntrinsics = CodeGenIntrinsicTable(Records, true); @@ -2410,6 +2870,11 @@ CodeGenDAGPatterns::CodeGenDAGPatterns(RecordKeeper &R) : ParsePatternFragments(/*OutFrags*/true); ParsePatterns(); + // Break patterns with parameterized types into a series of patterns, + // where each one has a fixed type and is predicated on the conditions + // of the associated HW mode. + ExpandHwModeBasedTypes(); + // Generate variants. For example, commutative patterns can match // multiple ways. Add them to PatternsToMatch as well. GenerateVariants(); @@ -2434,8 +2899,11 @@ Record *CodeGenDAGPatterns::getSDNodeNamed(const std::string &Name) const { // Parse all of the SDNode definitions for the target, populating SDNodes. void CodeGenDAGPatterns::ParseNodeInfo() { std::vector<Record*> Nodes = Records.getAllDerivedDefinitions("SDNode"); + const CodeGenHwModes &CGH = getTargetInfo().getHwModes(); + while (!Nodes.empty()) { - SDNodes.insert(std::make_pair(Nodes.back(), Nodes.back())); + Record *R = Nodes.back(); + SDNodes.insert(std::make_pair(R, SDNodeInfo(R, CGH))); Nodes.pop_back(); } @@ -2489,7 +2957,10 @@ void CodeGenDAGPatterns::ParsePatternFragments(bool OutFrags) { // Validate the argument list, converting it to set, to discard duplicates. std::vector<std::string> &Args = P->getArgList(); - std::set<std::string> OperandsSet(Args.begin(), Args.end()); + // Copy the args so we can take StringRefs to them. + auto ArgsCopy = Args; + SmallDenseSet<StringRef, 4> OperandsSet; + OperandsSet.insert(ArgsCopy.begin(), ArgsCopy.end()); if (OperandsSet.count("")) P->error("Cannot have unnamed 'node' values in pattern fragment!"); @@ -2589,7 +3060,7 @@ void CodeGenDAGPatterns::ParseDefaultOperands() { while (TPN->ApplyTypeConstraints(P, false)) /* Resolve all types */; - if (TPN->ContainsUnresolvedType()) { + if (TPN->ContainsUnresolvedType(P)) { PrintFatalError("Value #" + Twine(i) + " of OperandWithDefaultOps '" + DefaultOps[i]->getName() + "' doesn't have a concrete type!"); @@ -2981,17 +3452,20 @@ const DAGInstruction &CodeGenDAGPatterns::parseInstructionPattern( // Verify that the top-level forms in the instruction are of void type, and // fill in the InstResults map. + SmallString<32> TypesString; for (unsigned j = 0, e = I->getNumTrees(); j != e; ++j) { + TypesString.clear(); TreePatternNode *Pat = I->getTree(j); if (Pat->getNumTypes() != 0) { - std::string Types; + raw_svector_ostream OS(TypesString); for (unsigned k = 0, ke = Pat->getNumTypes(); k != ke; ++k) { if (k > 0) - Types += ", "; - Types += Pat->getExtType(k).getName(); + OS << ", "; + Pat->getExtType(k).writeToStream(OS); } I->error("Top-level forms in instruction pattern should have" - " void types, has types " + Types); + " void types, has types " + + OS.str()); } // Find inputs and outputs, and verify the structure of the uses/defs. @@ -3174,6 +3648,8 @@ void CodeGenDAGPatterns::ParseInstructions() { TreePattern *I = TheInst.getPattern(); if (!I) continue; // No pattern. + if (PatternRewriter) + PatternRewriter(I); // FIXME: Assume only the first tree is the pattern. The others are clobber // nodes. TreePatternNode *Pattern = I->getTree(0); @@ -3186,14 +3662,13 @@ void CodeGenDAGPatterns::ParseInstructions() { } Record *Instr = Entry.first; - AddPatternToMatch(I, - PatternToMatch(Instr, - Instr->getValueAsListInit("Predicates"), - SrcPattern, - TheInst.getResultPattern(), - TheInst.getImpResults(), - Instr->getValueAsInt("AddedComplexity"), - Instr->getID())); + ListInit *Preds = Instr->getValueAsListInit("Predicates"); + int Complexity = Instr->getValueAsInt("AddedComplexity"); + AddPatternToMatch( + I, + PatternToMatch(Instr, makePredList(Preds), SrcPattern, + TheInst.getResultPattern(), TheInst.getImpResults(), + Complexity, Instr->getID())); } } @@ -3219,6 +3694,20 @@ static void FindNames(const TreePatternNode *P, } } +std::vector<Predicate> CodeGenDAGPatterns::makePredList(ListInit *L) { + std::vector<Predicate> Preds; + for (Init *I : L->getValues()) { + if (DefInit *Pred = dyn_cast<DefInit>(I)) + Preds.push_back(Pred->getDef()); + else + llvm_unreachable("Non-def on the list"); + } + + // Sort so that different orders get canonicalized to the same string. + std::sort(Preds.begin(), Preds.end()); + return Preds; +} + void CodeGenDAGPatterns::AddPatternToMatch(TreePattern *Pattern, PatternToMatch &&PTM) { // Do some sanity checking on the pattern we're about to match. @@ -3262,8 +3751,6 @@ void CodeGenDAGPatterns::AddPatternToMatch(TreePattern *Pattern, PatternsToMatch.push_back(std::move(PTM)); } - - void CodeGenDAGPatterns::InferInstructionFlags() { ArrayRef<const CodeGenInstruction*> Instructions = Target.getInstructionsByEnumValue(); @@ -3425,12 +3912,13 @@ static bool ForceArbitraryInstResultType(TreePatternNode *N, TreePattern &TP) { // If this type is already concrete or completely unknown we can't do // anything. + TypeInfer &TI = TP.getInfer(); for (unsigned i = 0, e = N->getNumTypes(); i != e; ++i) { - if (N->getExtType(i).isCompletelyUnknown() || N->getExtType(i).isConcrete()) + if (N->getExtType(i).empty() || TI.isConcrete(N->getExtType(i), false)) continue; - // Otherwise, force its type to the first possibility (an arbitrary choice). - if (N->getExtType(i).MergeInTypeInfo(N->getExtType(i).getTypeList()[0], TP)) + // Otherwise, force its type to an arbitrary choice. + if (TI.forceArbitrary(N->getExtType(i))) return true; } @@ -3551,15 +4039,156 @@ void CodeGenDAGPatterns::ParsePatterns() { TreePattern Temp(Result.getRecord(), DstPattern, false, *this); Temp.InferAllTypes(); - AddPatternToMatch( - Pattern, - PatternToMatch( - CurPattern, CurPattern->getValueAsListInit("Predicates"), - Pattern->getTree(0), Temp.getOnlyTree(), std::move(InstImpResults), - CurPattern->getValueAsInt("AddedComplexity"), CurPattern->getID())); + // A pattern may end up with an "impossible" type, i.e. a situation + // where all types have been eliminated for some node in this pattern. + // This could occur for intrinsics that only make sense for a specific + // value type, and use a specific register class. If, for some mode, + // that register class does not accept that type, the type inference + // will lead to a contradiction, which is not an error however, but + // a sign that this pattern will simply never match. + if (Pattern->getTree(0)->hasPossibleType() && + Temp.getOnlyTree()->hasPossibleType()) { + ListInit *Preds = CurPattern->getValueAsListInit("Predicates"); + int Complexity = CurPattern->getValueAsInt("AddedComplexity"); + if (PatternRewriter) + PatternRewriter(Pattern); + AddPatternToMatch( + Pattern, + PatternToMatch( + CurPattern, makePredList(Preds), Pattern->getTree(0), + Temp.getOnlyTree(), std::move(InstImpResults), Complexity, + CurPattern->getID())); + } } } +static void collectModes(std::set<unsigned> &Modes, const TreePatternNode *N) { + for (const TypeSetByHwMode &VTS : N->getExtTypes()) + for (const auto &I : VTS) + Modes.insert(I.first); + + for (unsigned i = 0, e = N->getNumChildren(); i != e; ++i) + collectModes(Modes, N->getChild(i)); +} + +void CodeGenDAGPatterns::ExpandHwModeBasedTypes() { + const CodeGenHwModes &CGH = getTargetInfo().getHwModes(); + std::map<unsigned,std::vector<Predicate>> ModeChecks; + std::vector<PatternToMatch> Copy = PatternsToMatch; + PatternsToMatch.clear(); + + auto AppendPattern = [this,&ModeChecks](PatternToMatch &P, unsigned Mode) { + TreePatternNode *NewSrc = P.SrcPattern->clone(); + TreePatternNode *NewDst = P.DstPattern->clone(); + if (!NewSrc->setDefaultMode(Mode) || !NewDst->setDefaultMode(Mode)) { + delete NewSrc; + delete NewDst; + return; + } + + std::vector<Predicate> Preds = P.Predicates; + const std::vector<Predicate> &MC = ModeChecks[Mode]; + Preds.insert(Preds.end(), MC.begin(), MC.end()); + PatternsToMatch.emplace_back(P.getSrcRecord(), Preds, NewSrc, NewDst, + P.getDstRegs(), P.getAddedComplexity(), + Record::getNewUID(), Mode); + }; + + for (PatternToMatch &P : Copy) { + TreePatternNode *SrcP = nullptr, *DstP = nullptr; + if (P.SrcPattern->hasProperTypeByHwMode()) + SrcP = P.SrcPattern; + if (P.DstPattern->hasProperTypeByHwMode()) + DstP = P.DstPattern; + if (!SrcP && !DstP) { + PatternsToMatch.push_back(P); + continue; + } + + std::set<unsigned> Modes; + if (SrcP) + collectModes(Modes, SrcP); + if (DstP) + collectModes(Modes, DstP); + + // The predicate for the default mode needs to be constructed for each + // pattern separately. + // Since not all modes must be present in each pattern, if a mode m is + // absent, then there is no point in constructing a check for m. If such + // a check was created, it would be equivalent to checking the default + // mode, except not all modes' predicates would be a part of the checking + // code. The subsequently generated check for the default mode would then + // have the exact same patterns, but a different predicate code. To avoid + // duplicated patterns with different predicate checks, construct the + // default check as a negation of all predicates that are actually present + // in the source/destination patterns. + std::vector<Predicate> DefaultPred; + + for (unsigned M : Modes) { + if (M == DefaultMode) + continue; + if (ModeChecks.find(M) != ModeChecks.end()) + continue; + + // Fill the map entry for this mode. + const HwMode &HM = CGH.getMode(M); + ModeChecks[M].emplace_back(Predicate(HM.Features, true)); + + // Add negations of the HM's predicates to the default predicate. + DefaultPred.emplace_back(Predicate(HM.Features, false)); + } + + for (unsigned M : Modes) { + if (M == DefaultMode) + continue; + AppendPattern(P, M); + } + + bool HasDefault = Modes.count(DefaultMode); + if (HasDefault) + AppendPattern(P, DefaultMode); + } +} + +/// Dependent variable map for CodeGenDAGPattern variant generation +typedef StringMap<int> DepVarMap; + +static void FindDepVarsOf(TreePatternNode *N, DepVarMap &DepMap) { + if (N->isLeaf()) { + if (N->hasName() && isa<DefInit>(N->getLeafValue())) + DepMap[N->getName()]++; + } else { + for (size_t i = 0, e = N->getNumChildren(); i != e; ++i) + FindDepVarsOf(N->getChild(i), DepMap); + } +} + +/// Find dependent variables within child patterns +static void FindDepVars(TreePatternNode *N, MultipleUseVarSet &DepVars) { + DepVarMap depcounts; + FindDepVarsOf(N, depcounts); + for (const auto &Pair : depcounts) { + if (Pair.getValue() > 1) + DepVars.insert(Pair.getKey()); + } +} + +#ifndef NDEBUG +/// Dump the dependent variable set: +static void DumpDepVars(MultipleUseVarSet &DepVars) { + if (DepVars.empty()) { + DEBUG(errs() << "<empty set>"); + } else { + DEBUG(errs() << "[ "); + for (const auto &DepVar : DepVars) { + DEBUG(errs() << DepVar.getKey() << " "); + } + DEBUG(errs() << "]"); + } +} +#endif + + /// CombineChildVariants - Given a bunch of permutations of each child of the /// 'operator' node, put them together in all possible ways. static void CombineChildVariants(TreePatternNode *Orig, @@ -3744,7 +4373,7 @@ static void GenerateVariantsOf(TreePatternNode *N, // If this node is commutative, consider the commuted order. bool isCommIntrinsic = N->isCommutativeIntrinsic(CDP); if (NodeInfo.hasProperty(SDNPCommutative) || isCommIntrinsic) { - assert((N->getNumChildren()==2 || isCommIntrinsic) && + assert((N->getNumChildren()>=2 || isCommIntrinsic) && "Commutative but doesn't have 2 children!"); // Don't count children which are actually register references. unsigned NC = 0; @@ -3772,9 +4401,14 @@ static void GenerateVariantsOf(TreePatternNode *N, for (unsigned i = 3; i != NC; ++i) Variants.push_back(ChildVariants[i]); CombineChildVariants(N, Variants, OutVariants, CDP, DepVars); - } else if (NC == 2) - CombineChildVariants(N, ChildVariants[1], ChildVariants[0], - OutVariants, CDP, DepVars); + } else if (NC == N->getNumChildren()) { + std::vector<std::vector<TreePatternNode*> > Variants; + Variants.push_back(ChildVariants[1]); + Variants.push_back(ChildVariants[0]); + for (unsigned i = 2; i != NC; ++i) + Variants.push_back(ChildVariants[i]); + CombineChildVariants(N, Variants, OutVariants, CDP, DepVars); + } } } |