diff options
| author | Dimitry Andric <dim@FreeBSD.org> | 2017-12-18 20:10:56 +0000 | 
|---|---|---|
| committer | Dimitry Andric <dim@FreeBSD.org> | 2017-12-18 20:10:56 +0000 | 
| commit | 044eb2f6afba375a914ac9d8024f8f5142bb912e (patch) | |
| tree | 1475247dc9f9fe5be155ebd4c9069c75aadf8c20 /utils/TableGen/CodeGenDAGPatterns.cpp | |
| parent | eb70dddbd77e120e5d490bd8fbe7ff3f8fa81c6b (diff) | |
Notes
Diffstat (limited to 'utils/TableGen/CodeGenDAGPatterns.cpp')
| -rw-r--r-- | utils/TableGen/CodeGenDAGPatterns.cpp | 2170 | 
1 files changed, 1402 insertions, 768 deletions
| diff --git a/utils/TableGen/CodeGenDAGPatterns.cpp b/utils/TableGen/CodeGenDAGPatterns.cpp index e48ba3845326..51473f06da79 100644 --- a/utils/TableGen/CodeGenDAGPatterns.cpp +++ b/utils/TableGen/CodeGenDAGPatterns.cpp @@ -13,9 +13,12 @@  //===----------------------------------------------------------------------===//  #include "CodeGenDAGPatterns.h" +#include "llvm/ADT/DenseSet.h"  #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallSet.h"  #include "llvm/ADT/SmallString.h"  #include "llvm/ADT/StringExtras.h" +#include "llvm/ADT/StringMap.h"  #include "llvm/ADT/Twine.h"  #include "llvm/Support/Debug.h"  #include "llvm/Support/ErrorHandling.h" @@ -28,752 +31,1116 @@ using namespace llvm;  #define DEBUG_TYPE "dag-patterns" -//===----------------------------------------------------------------------===// -//  EEVT::TypeSet Implementation -//===----------------------------------------------------------------------===// - -static inline bool isInteger(MVT::SimpleValueType VT) { -  return MVT(VT).isInteger(); +static inline bool isIntegerOrPtr(MVT VT) { +  return VT.isInteger() || VT == MVT::iPTR;  } -static inline bool isFloatingPoint(MVT::SimpleValueType VT) { -  return MVT(VT).isFloatingPoint(); +static inline bool isFloatingPoint(MVT VT) { +  return VT.isFloatingPoint();  } -static inline bool isVector(MVT::SimpleValueType VT) { -  return MVT(VT).isVector(); +static inline bool isVector(MVT VT) { +  return VT.isVector();  } -static inline bool isScalar(MVT::SimpleValueType VT) { -  return !MVT(VT).isVector(); +static inline bool isScalar(MVT VT) { +  return !VT.isVector();  } -EEVT::TypeSet::TypeSet(MVT::SimpleValueType VT, TreePattern &TP) { -  if (VT == MVT::iAny) -    EnforceInteger(TP); -  else if (VT == MVT::fAny) -    EnforceFloatingPoint(TP); -  else if (VT == MVT::vAny) -    EnforceVector(TP); -  else { -    assert((VT < MVT::LAST_VALUETYPE || VT == MVT::iPTR || -            VT == MVT::iPTRAny || VT == MVT::Any) && "Not a concrete type!"); -    TypeVec.push_back(VT); +template <typename Predicate> +static bool berase_if(MachineValueTypeSet &S, Predicate P) { +  bool Erased = false; +  // It is ok to iterate over MachineValueTypeSet and remove elements from it +  // at the same time. +  for (MVT T : S) { +    if (!P(T)) +      continue; +    Erased = true; +    S.erase(T);    } +  return Erased;  } +// --- TypeSetByHwMode -EEVT::TypeSet::TypeSet(ArrayRef<MVT::SimpleValueType> VTList) { -  assert(!VTList.empty() && "empty list?"); -  TypeVec.append(VTList.begin(), VTList.end()); - -  if (!VTList.empty()) -    assert(VTList[0] != MVT::iAny && VTList[0] != MVT::vAny && -           VTList[0] != MVT::fAny); +// This is a parameterized type-set class. For each mode there is a list +// of types that are currently possible for a given tree node. Type +// inference will apply to each mode separately. -  // Verify no duplicates. -  array_pod_sort(TypeVec.begin(), TypeVec.end()); -  assert(std::unique(TypeVec.begin(), TypeVec.end()) == TypeVec.end()); +TypeSetByHwMode::TypeSetByHwMode(ArrayRef<ValueTypeByHwMode> VTList) { +  for (const ValueTypeByHwMode &VVT : VTList) +    insert(VVT);  } -/// FillWithPossibleTypes - Set to all legal types and return true, only valid -/// on completely unknown type sets. -bool EEVT::TypeSet::FillWithPossibleTypes(TreePattern &TP, -                                          bool (*Pred)(MVT::SimpleValueType), -                                          const char *PredicateName) { -  assert(isCompletelyUnknown()); -  ArrayRef<MVT::SimpleValueType> LegalTypes = -    TP.getDAGPatterns().getTargetInfo().getLegalValueTypes(); - -  if (TP.hasError()) -    return false; - -  for (MVT::SimpleValueType VT : LegalTypes) -    if (!Pred || Pred(VT)) -      TypeVec.push_back(VT); - -  // If we have nothing that matches the predicate, bail out. -  if (TypeVec.empty()) { -    TP.error("Type inference contradiction found, no " + -             std::string(PredicateName) + " types found"); -    return false; +bool TypeSetByHwMode::isValueTypeByHwMode(bool AllowEmpty) const { +  for (const auto &I : *this) { +    if (I.second.size() > 1) +      return false; +    if (!AllowEmpty && I.second.empty()) +      return false;    } -  // No need to sort with one element. -  if (TypeVec.size() == 1) return true; - -  // Remove duplicates. -  array_pod_sort(TypeVec.begin(), TypeVec.end()); -  TypeVec.erase(std::unique(TypeVec.begin(), TypeVec.end()), TypeVec.end()); -    return true;  } -/// hasIntegerTypes - Return true if this TypeSet contains iAny or an -/// integer value type. -bool EEVT::TypeSet::hasIntegerTypes() const { -  return any_of(TypeVec, isInteger); +ValueTypeByHwMode TypeSetByHwMode::getValueTypeByHwMode() const { +  assert(isValueTypeByHwMode(true) && +         "The type set has multiple types for at least one HW mode"); +  ValueTypeByHwMode VVT; +  for (const auto &I : *this) { +    MVT T = I.second.empty() ? MVT::Other : *I.second.begin(); +    VVT.getOrCreateTypeForMode(I.first, T); +  } +  return VVT;  } -/// hasFloatingPointTypes - Return true if this TypeSet contains an fAny or -/// a floating point value type. -bool EEVT::TypeSet::hasFloatingPointTypes() const { -  return any_of(TypeVec, isFloatingPoint); +bool TypeSetByHwMode::isPossible() const { +  for (const auto &I : *this) +    if (!I.second.empty()) +      return true; +  return false;  } -/// hasScalarTypes - Return true if this TypeSet contains a scalar value type. -bool EEVT::TypeSet::hasScalarTypes() const { -  return any_of(TypeVec, isScalar); +bool TypeSetByHwMode::insert(const ValueTypeByHwMode &VVT) { +  bool Changed = false; +  SmallDenseSet<unsigned, 4> Modes; +  for (const auto &P : VVT) { +    unsigned M = P.first; +    Modes.insert(M); +    // Make sure there exists a set for each specific mode from VVT. +    Changed |= getOrCreate(M).insert(P.second).second; +  } + +  // If VVT has a default mode, add the corresponding type to all +  // modes in "this" that do not exist in VVT. +  if (Modes.count(DefaultMode)) { +    MVT DT = VVT.getType(DefaultMode); +    for (auto &I : *this) +      if (!Modes.count(I.first)) +        Changed |= I.second.insert(DT).second; +  } +  return Changed;  } -/// hasVectorTypes - Return true if this TypeSet contains a vAny or a vector -/// value type. -bool EEVT::TypeSet::hasVectorTypes() const { -  return any_of(TypeVec, isVector); +// Constrain the type set to be the intersection with VTS. +bool TypeSetByHwMode::constrain(const TypeSetByHwMode &VTS) { +  bool Changed = false; +  if (hasDefault()) { +    for (const auto &I : VTS) { +      unsigned M = I.first; +      if (M == DefaultMode || hasMode(M)) +        continue; +      Map.insert({M, Map.at(DefaultMode)}); +      Changed = true; +    } +  } + +  for (auto &I : *this) { +    unsigned M = I.first; +    SetType &S = I.second; +    if (VTS.hasMode(M) || VTS.hasDefault()) { +      Changed |= intersect(I.second, VTS.get(M)); +    } else if (!S.empty()) { +      S.clear(); +      Changed = true; +    } +  } +  return Changed;  } +template <typename Predicate> +bool TypeSetByHwMode::constrain(Predicate P) { +  bool Changed = false; +  for (auto &I : *this) +    Changed |= berase_if(I.second, [&P](MVT VT) { return !P(VT); }); +  return Changed; +} -std::string EEVT::TypeSet::getName() const { -  if (TypeVec.empty()) return "<empty>"; +template <typename Predicate> +bool TypeSetByHwMode::assign_if(const TypeSetByHwMode &VTS, Predicate P) { +  assert(empty()); +  for (const auto &I : VTS) { +    SetType &S = getOrCreate(I.first); +    for (auto J : I.second) +      if (P(J)) +        S.insert(J); +  } +  return !empty(); +} -  std::string Result; +void TypeSetByHwMode::writeToStream(raw_ostream &OS) const { +  SmallVector<unsigned, 4> Modes; +  Modes.reserve(Map.size()); -  for (unsigned i = 0, e = TypeVec.size(); i != e; ++i) { -    std::string VTName = llvm::getEnumName(TypeVec[i]); -    // Strip off MVT:: prefix if present. -    if (VTName.substr(0,5) == "MVT::") -      VTName = VTName.substr(5); -    if (i) Result += ':'; -    Result += VTName; +  for (const auto &I : *this) +    Modes.push_back(I.first); +  if (Modes.empty()) { +    OS << "{}"; +    return;    } +  array_pod_sort(Modes.begin(), Modes.end()); -  if (TypeVec.size() == 1) -    return Result; -  return "{" + Result + "}"; +  OS << '{'; +  for (unsigned M : Modes) { +    OS << ' ' << getModeName(M) << ':'; +    writeToStream(get(M), OS); +  } +  OS << " }";  } -/// MergeInTypeInfo - This merges in type information from the specified -/// argument.  If 'this' changes, it returns true.  If the two types are -/// contradictory (e.g. merge f32 into i32) then this flags an error. -bool EEVT::TypeSet::MergeInTypeInfo(const EEVT::TypeSet &InVT, TreePattern &TP){ -  if (InVT.isCompletelyUnknown() || *this == InVT || TP.hasError()) -    return false; +void TypeSetByHwMode::writeToStream(const SetType &S, raw_ostream &OS) { +  SmallVector<MVT, 4> Types(S.begin(), S.end()); +  array_pod_sort(Types.begin(), Types.end()); -  if (isCompletelyUnknown()) { -    *this = InVT; -    return true; +  OS << '['; +  for (unsigned i = 0, e = Types.size(); i != e; ++i) { +    OS << ValueTypeByHwMode::getMVTName(Types[i]); +    if (i != e-1) +      OS << ' ';    } +  OS << ']'; +} -  assert(!TypeVec.empty() && !InVT.TypeVec.empty() && "No unknowns"); +bool TypeSetByHwMode::operator==(const TypeSetByHwMode &VTS) const { +  bool HaveDefault = hasDefault(); +  if (HaveDefault != VTS.hasDefault()) +    return false; -  // Handle the abstract cases, seeing if we can resolve them better. -  switch (TypeVec[0]) { -  default: break; -  case MVT::iPTR: -  case MVT::iPTRAny: -    if (InVT.hasIntegerTypes()) { -      EEVT::TypeSet InCopy(InVT); -      InCopy.EnforceInteger(TP); -      InCopy.EnforceScalar(TP); +  if (isSimple()) { +    if (VTS.isSimple()) +      return *begin() == *VTS.begin(); +    return false; +  } -      if (InCopy.isConcrete()) { -        // If the RHS has one integer type, upgrade iPTR to i32. -        TypeVec[0] = InVT.TypeVec[0]; -        return true; -      } +  SmallDenseSet<unsigned, 4> Modes; +  for (auto &I : *this) +    Modes.insert(I.first); +  for (const auto &I : VTS) +    Modes.insert(I.first); -      // If the input has multiple scalar integers, this doesn't add any info. -      if (!InCopy.isCompletelyUnknown()) +  if (HaveDefault) { +    // Both sets have default mode. +    for (unsigned M : Modes) { +      if (get(M) != VTS.get(M)) +        return false; +    } +  } else { +    // Neither set has default mode. +    for (unsigned M : Modes) { +      // If there is no default mode, an empty set is equivalent to not having +      // the corresponding mode. +      bool NoModeThis = !hasMode(M) || get(M).empty(); +      bool NoModeVTS = !VTS.hasMode(M) || VTS.get(M).empty(); +      if (NoModeThis != NoModeVTS)          return false; +      if (!NoModeThis) +        if (get(M) != VTS.get(M)) +          return false;      } -    break;    } -  // If the input constraint is iAny/iPTR and this is an integer type list, -  // remove non-integer types from the list. -  if ((InVT.TypeVec[0] == MVT::iPTR || InVT.TypeVec[0] == MVT::iPTRAny) && -      hasIntegerTypes()) { -    bool MadeChange = EnforceInteger(TP); +  return true; +} -    // If we're merging in iPTR/iPTRAny and the node currently has a list of -    // multiple different integer types, replace them with a single iPTR. -    if ((InVT.TypeVec[0] == MVT::iPTR || InVT.TypeVec[0] == MVT::iPTRAny) && -        TypeVec.size() != 1) { -      TypeVec.assign(1, InVT.TypeVec[0]); -      MadeChange = true; -    } +namespace llvm { +  raw_ostream &operator<<(raw_ostream &OS, const TypeSetByHwMode &T) { +    T.writeToStream(OS); +    return OS; +  } +} -    return MadeChange; +LLVM_DUMP_METHOD +void TypeSetByHwMode::dump() const { +  dbgs() << *this << '\n'; +} + +bool TypeSetByHwMode::intersect(SetType &Out, const SetType &In) { +  bool OutP = Out.count(MVT::iPTR), InP = In.count(MVT::iPTR); +  auto Int = [&In](MVT T) -> bool { return !In.count(T); }; + +  if (OutP == InP) +    return berase_if(Out, Int); + +  // Compute the intersection of scalars separately to account for only +  // one set containing iPTR. +  // The itersection of iPTR with a set of integer scalar types that does not +  // include iPTR will result in the most specific scalar type: +  // - iPTR is more specific than any set with two elements or more +  // - iPTR is less specific than any single integer scalar type. +  // For example +  // { iPTR } * { i32 }     -> { i32 } +  // { iPTR } * { i32 i64 } -> { iPTR } +  // and +  // { iPTR i32 } * { i32 }          -> { i32 } +  // { iPTR i32 } * { i32 i64 }      -> { i32 i64 } +  // { iPTR i32 } * { i32 i64 i128 } -> { iPTR i32 } + +  // Compute the difference between the two sets in such a way that the +  // iPTR is in the set that is being subtracted. This is to see if there +  // are any extra scalars in the set without iPTR that are not in the +  // set containing iPTR. Then the iPTR could be considered a "wildcard" +  // matching these scalars. If there is only one such scalar, it would +  // replace the iPTR, if there are more, the iPTR would be retained. +  SetType Diff; +  if (InP) { +    Diff = Out; +    berase_if(Diff, [&In](MVT T) { return In.count(T); }); +    // Pre-remove these elements and rely only on InP/OutP to determine +    // whether a change has been made. +    berase_if(Out, [&Diff](MVT T) { return Diff.count(T); }); +  } else { +    Diff = In; +    berase_if(Diff, [&Out](MVT T) { return Out.count(T); }); +    Out.erase(MVT::iPTR); +  } + +  // The actual intersection. +  bool Changed = berase_if(Out, Int); +  unsigned NumD = Diff.size(); +  if (NumD == 0) +    return Changed; + +  if (NumD == 1) { +    Out.insert(*Diff.begin()); +    // This is a change only if Out was the one with iPTR (which is now +    // being replaced). +    Changed |= OutP; +  } else { +    // Multiple elements from Out are now replaced with iPTR. +    Out.insert(MVT::iPTR); +    Changed |= !OutP;    } +  return Changed; +} -  // If this is a type list and the RHS is a typelist as well, eliminate entries -  // from this list that aren't in the other one. -  TypeSet InputSet(*this); +void TypeSetByHwMode::validate() const { +#ifndef NDEBUG +  if (empty()) +    return; +  bool AllEmpty = true; +  for (const auto &I : *this) +    AllEmpty &= I.second.empty(); +  assert(!AllEmpty && +          "type set is empty for each HW mode: type contradiction?"); +#endif +} -  TypeVec.clear(); -  std::set_intersection(InputSet.TypeVec.begin(), InputSet.TypeVec.end(), -                        InVT.TypeVec.begin(), InVT.TypeVec.end(), -                        std::back_inserter(TypeVec)); +// --- TypeInfer -  // If the intersection is the same size as the original set then we're done. -  if (TypeVec.size() == InputSet.TypeVec.size()) +bool TypeInfer::MergeInTypeInfo(TypeSetByHwMode &Out, +                                const TypeSetByHwMode &In) { +  ValidateOnExit _1(Out); +  In.validate(); +  if (In.empty() || Out == In || TP.hasError())      return false; - -  // If we removed all of our types, we have a type contradiction. -  if (!TypeVec.empty()) +  if (Out.empty()) { +    Out = In;      return true; +  } -  // FIXME: Really want an SMLoc here! -  TP.error("Type inference contradiction found, merging '" + -           InVT.getName() + "' into '" + InputSet.getName() + "'"); -  return false; +  bool Changed = Out.constrain(In); +  if (Changed && Out.empty()) +    TP.error("Type contradiction"); + +  return Changed;  } -/// EnforceInteger - Remove all non-integer types from this set. -bool EEVT::TypeSet::EnforceInteger(TreePattern &TP) { +bool TypeInfer::forceArbitrary(TypeSetByHwMode &Out) { +  ValidateOnExit _1(Out);    if (TP.hasError())      return false; -  // If we know nothing, then get the full set. -  if (TypeVec.empty()) -    return FillWithPossibleTypes(TP, isInteger, "integer"); - -  if (!hasFloatingPointTypes()) -    return false; +  assert(!Out.empty() && "cannot pick from an empty set"); -  TypeSet InputSet(*this); - -  // Filter out all the fp types. -  TypeVec.erase(remove_if(TypeVec, std::not1(std::ptr_fun(isInteger))), -                TypeVec.end()); - -  if (TypeVec.empty()) { -    TP.error("Type inference contradiction found, '" + -             InputSet.getName() + "' needs to be integer"); -    return false; +  bool Changed = false; +  for (auto &I : Out) { +    TypeSetByHwMode::SetType &S = I.second; +    if (S.size() <= 1) +      continue; +    MVT T = *S.begin(); // Pick the first element. +    S.clear(); +    S.insert(T); +    Changed = true;    } -  return true; +  return Changed;  } -/// EnforceFloatingPoint - Remove all integer types from this set. -bool EEVT::TypeSet::EnforceFloatingPoint(TreePattern &TP) { +bool TypeInfer::EnforceInteger(TypeSetByHwMode &Out) { +  ValidateOnExit _1(Out);    if (TP.hasError())      return false; -  // If we know nothing, then get the full set. -  if (TypeVec.empty()) -    return FillWithPossibleTypes(TP, isFloatingPoint, "floating point"); +  if (!Out.empty()) +    return Out.constrain(isIntegerOrPtr); -  if (!hasIntegerTypes()) -    return false; +  return Out.assign_if(getLegalTypes(), isIntegerOrPtr); +} -  TypeSet InputSet(*this); +bool TypeInfer::EnforceFloatingPoint(TypeSetByHwMode &Out) { +  ValidateOnExit _1(Out); +  if (TP.hasError()) +    return false; +  if (!Out.empty()) +    return Out.constrain(isFloatingPoint); -  // Filter out all the integer types. -  TypeVec.erase(remove_if(TypeVec, std::not1(std::ptr_fun(isFloatingPoint))), -                TypeVec.end()); +  return Out.assign_if(getLegalTypes(), isFloatingPoint); +} -  if (TypeVec.empty()) { -    TP.error("Type inference contradiction found, '" + -             InputSet.getName() + "' needs to be floating point"); +bool TypeInfer::EnforceScalar(TypeSetByHwMode &Out) { +  ValidateOnExit _1(Out); +  if (TP.hasError())      return false; -  } -  return true; +  if (!Out.empty()) +    return Out.constrain(isScalar); + +  return Out.assign_if(getLegalTypes(), isScalar);  } -/// EnforceScalar - Remove all vector types from this. -bool EEVT::TypeSet::EnforceScalar(TreePattern &TP) { +bool TypeInfer::EnforceVector(TypeSetByHwMode &Out) { +  ValidateOnExit _1(Out);    if (TP.hasError())      return false; +  if (!Out.empty()) +    return Out.constrain(isVector); -  // If we know nothing, then get the full set. -  if (TypeVec.empty()) -    return FillWithPossibleTypes(TP, isScalar, "scalar"); +  return Out.assign_if(getLegalTypes(), isVector); +} -  if (!hasVectorTypes()) +bool TypeInfer::EnforceAny(TypeSetByHwMode &Out) { +  ValidateOnExit _1(Out); +  if (TP.hasError() || !Out.empty())      return false; -  TypeSet InputSet(*this); +  Out = getLegalTypes(); +  return true; +} -  // Filter out all the vector types. -  TypeVec.erase(remove_if(TypeVec, std::not1(std::ptr_fun(isScalar))), -                TypeVec.end()); +template <typename Iter, typename Pred, typename Less> +static Iter min_if(Iter B, Iter E, Pred P, Less L) { +  if (B == E) +    return E; +  Iter Min = E; +  for (Iter I = B; I != E; ++I) { +    if (!P(*I)) +      continue; +    if (Min == E || L(*I, *Min)) +      Min = I; +  } +  return Min; +} -  if (TypeVec.empty()) { -    TP.error("Type inference contradiction found, '" + -             InputSet.getName() + "' needs to be scalar"); -    return false; +template <typename Iter, typename Pred, typename Less> +static Iter max_if(Iter B, Iter E, Pred P, Less L) { +  if (B == E) +    return E; +  Iter Max = E; +  for (Iter I = B; I != E; ++I) { +    if (!P(*I)) +      continue; +    if (Max == E || L(*Max, *I)) +      Max = I;    } -  return true; +  return Max;  } -/// EnforceVector - Remove all vector types from this. -bool EEVT::TypeSet::EnforceVector(TreePattern &TP) { +/// Make sure that for each type in Small, there exists a larger type in Big. +bool TypeInfer::EnforceSmallerThan(TypeSetByHwMode &Small, +                                   TypeSetByHwMode &Big) { +  ValidateOnExit _1(Small), _2(Big);    if (TP.hasError())      return false; +  bool Changed = false; + +  if (Small.empty()) +    Changed |= EnforceAny(Small); +  if (Big.empty()) +    Changed |= EnforceAny(Big); + +  assert(Small.hasDefault() && Big.hasDefault()); + +  std::vector<unsigned> Modes = union_modes(Small, Big); + +  // 1. Only allow integer or floating point types and make sure that +  //    both sides are both integer or both floating point. +  // 2. Make sure that either both sides have vector types, or neither +  //    of them does. +  for (unsigned M : Modes) { +    TypeSetByHwMode::SetType &S = Small.get(M); +    TypeSetByHwMode::SetType &B = Big.get(M); + +    if (any_of(S, isIntegerOrPtr) && any_of(S, isIntegerOrPtr)) { +      auto NotInt = [](MVT VT) { return !isIntegerOrPtr(VT); }; +      Changed |= berase_if(S, NotInt) | +                 berase_if(B, NotInt); +    } else if (any_of(S, isFloatingPoint) && any_of(B, isFloatingPoint)) { +      auto NotFP = [](MVT VT) { return !isFloatingPoint(VT); }; +      Changed |= berase_if(S, NotFP) | +                 berase_if(B, NotFP); +    } else if (S.empty() || B.empty()) { +      Changed = !S.empty() || !B.empty(); +      S.clear(); +      B.clear(); +    } else { +      TP.error("Incompatible types"); +      return Changed; +    } -  // If we know nothing, then get the full set. -  if (TypeVec.empty()) -    return FillWithPossibleTypes(TP, isVector, "vector"); - -  TypeSet InputSet(*this); -  bool MadeChange = false; - -  // Filter out all the scalar types. -  TypeVec.erase(remove_if(TypeVec, std::not1(std::ptr_fun(isVector))), -                TypeVec.end()); - -  if (TypeVec.empty()) { -    TP.error("Type inference contradiction found, '" + -             InputSet.getName() + "' needs to be a vector"); -    return false; +    if (none_of(S, isVector) || none_of(B, isVector)) { +      Changed |= berase_if(S, isVector) | +                 berase_if(B, isVector); +    }    } -  return MadeChange; -} +  auto LT = [](MVT A, MVT B) -> bool { +    return A.getScalarSizeInBits() < B.getScalarSizeInBits() || +           (A.getScalarSizeInBits() == B.getScalarSizeInBits() && +            A.getSizeInBits() < B.getSizeInBits()); +  }; +  auto LE = [](MVT A, MVT B) -> bool { +    // This function is used when removing elements: when a vector is compared +    // to a non-vector, it should return false (to avoid removal). +    if (A.isVector() != B.isVector()) +      return false; +    // Note on the < comparison below: +    // X86 has patterns like +    //   (set VR128X:$dst, (v16i8 (X86vtrunc (v4i32 VR128X:$src1)))), +    // where the truncated vector is given a type v16i8, while the source +    // vector has type v4i32. They both have the same size in bits. +    // The minimal type in the result is obviously v16i8, and when we remove +    // all types from the source that are smaller-or-equal than v8i16, the +    // only source type would also be removed (since it's equal in size). +    return A.getScalarSizeInBits() <= B.getScalarSizeInBits() || +           A.getSizeInBits() < B.getSizeInBits(); +  }; + +  for (unsigned M : Modes) { +    TypeSetByHwMode::SetType &S = Small.get(M); +    TypeSetByHwMode::SetType &B = Big.get(M); +    // MinS = min scalar in Small, remove all scalars from Big that are +    // smaller-or-equal than MinS. +    auto MinS = min_if(S.begin(), S.end(), isScalar, LT); +    if (MinS != S.end()) +      Changed |= berase_if(B, std::bind(LE, std::placeholders::_1, *MinS)); + +    // MaxS = max scalar in Big, remove all scalars from Small that are +    // larger than MaxS. +    auto MaxS = max_if(B.begin(), B.end(), isScalar, LT); +    if (MaxS != B.end()) +      Changed |= berase_if(S, std::bind(LE, *MaxS, std::placeholders::_1)); + +    // MinV = min vector in Small, remove all vectors from Big that are +    // smaller-or-equal than MinV. +    auto MinV = min_if(S.begin(), S.end(), isVector, LT); +    if (MinV != S.end()) +      Changed |= berase_if(B, std::bind(LE, std::placeholders::_1, *MinV)); + +    // MaxV = max vector in Big, remove all vectors from Small that are +    // larger than MaxV. +    auto MaxV = max_if(B.begin(), B.end(), isVector, LT); +    if (MaxV != B.end()) +      Changed |= berase_if(S, std::bind(LE, *MaxV, std::placeholders::_1)); +  } + +  return Changed; +} -/// EnforceSmallerThan - 'this' must be a smaller VT than Other. For vectors -/// this should be based on the element type. Update this and other based on -/// this information. -bool EEVT::TypeSet::EnforceSmallerThan(EEVT::TypeSet &Other, TreePattern &TP) { +/// 1. Ensure that for each type T in Vec, T is a vector type, and that +///    for each type U in Elem, U is a scalar type. +/// 2. Ensure that for each (scalar) type U in Elem, there exists a (vector) +///    type T in Vec, such that U is the element type of T. +bool TypeInfer::EnforceVectorEltTypeIs(TypeSetByHwMode &Vec, +                                       TypeSetByHwMode &Elem) { +  ValidateOnExit _1(Vec), _2(Elem);    if (TP.hasError())      return false; +  bool Changed = false; + +  if (Vec.empty()) +    Changed |= EnforceVector(Vec); +  if (Elem.empty()) +    Changed |= EnforceScalar(Elem); + +  for (unsigned M : union_modes(Vec, Elem)) { +    TypeSetByHwMode::SetType &V = Vec.get(M); +    TypeSetByHwMode::SetType &E = Elem.get(M); + +    Changed |= berase_if(V, isScalar);  // Scalar = !vector +    Changed |= berase_if(E, isVector);  // Vector = !scalar +    assert(!V.empty() && !E.empty()); + +    SmallSet<MVT,4> VT, ST; +    // Collect element types from the "vector" set. +    for (MVT T : V) +      VT.insert(T.getVectorElementType()); +    // Collect scalar types from the "element" set. +    for (MVT T : E) +      ST.insert(T); + +    // Remove from V all (vector) types whose element type is not in S. +    Changed |= berase_if(V, [&ST](MVT T) -> bool { +                              return !ST.count(T.getVectorElementType()); +                            }); +    // Remove from E all (scalar) types, for which there is no corresponding +    // type in V. +    Changed |= berase_if(E, [&VT](MVT T) -> bool { return !VT.count(T); }); +  } + +  return Changed; +} -  // Both operands must be integer or FP, but we don't care which. -  bool MadeChange = false; - -  if (isCompletelyUnknown()) -    MadeChange = FillWithPossibleTypes(TP); - -  if (Other.isCompletelyUnknown()) -    MadeChange = Other.FillWithPossibleTypes(TP); - -  // If one side is known to be integer or known to be FP but the other side has -  // no information, get at least the type integrality info in there. -  if (!hasFloatingPointTypes()) -    MadeChange |= Other.EnforceInteger(TP); -  else if (!hasIntegerTypes()) -    MadeChange |= Other.EnforceFloatingPoint(TP); -  if (!Other.hasFloatingPointTypes()) -    MadeChange |= EnforceInteger(TP); -  else if (!Other.hasIntegerTypes()) -    MadeChange |= EnforceFloatingPoint(TP); - -  assert(!isCompletelyUnknown() && !Other.isCompletelyUnknown() && -         "Should have a type list now"); - -  // If one contains vectors but the other doesn't pull vectors out. -  if (!hasVectorTypes()) -    MadeChange |= Other.EnforceScalar(TP); -  else if (!hasScalarTypes()) -    MadeChange |= Other.EnforceVector(TP); -  if (!Other.hasVectorTypes()) -    MadeChange |= EnforceScalar(TP); -  else if (!Other.hasScalarTypes()) -    MadeChange |= EnforceVector(TP); - -  // This code does not currently handle nodes which have multiple types, -  // where some types are integer, and some are fp.  Assert that this is not -  // the case. -  assert(!(hasIntegerTypes() && hasFloatingPointTypes()) && -         !(Other.hasIntegerTypes() && Other.hasFloatingPointTypes()) && -         "SDTCisOpSmallerThanOp does not handle mixed int/fp types!"); +bool TypeInfer::EnforceVectorEltTypeIs(TypeSetByHwMode &Vec, +                                       const ValueTypeByHwMode &VVT) { +  TypeSetByHwMode Tmp(VVT); +  ValidateOnExit _1(Vec), _2(Tmp); +  return EnforceVectorEltTypeIs(Vec, Tmp); +} +/// Ensure that for each type T in Sub, T is a vector type, and there +/// exists a type U in Vec such that U is a vector type with the same +/// element type as T and at least as many elements as T. +bool TypeInfer::EnforceVectorSubVectorTypeIs(TypeSetByHwMode &Vec, +                                             TypeSetByHwMode &Sub) { +  ValidateOnExit _1(Vec), _2(Sub);    if (TP.hasError())      return false; -  // Okay, find the smallest type from current set and remove anything the -  // same or smaller from the other set. We need to ensure that the scalar -  // type size is smaller than the scalar size of the smallest type. For -  // vectors, we also need to make sure that the total size is no larger than -  // the size of the smallest type. -  { -    TypeSet InputSet(Other); -    MVT Smallest = *std::min_element(TypeVec.begin(), TypeVec.end(), -      [](MVT A, MVT B) { -        return A.getScalarSizeInBits() < B.getScalarSizeInBits() || -               (A.getScalarSizeInBits() == B.getScalarSizeInBits() && -                A.getSizeInBits() < B.getSizeInBits()); -      }); - -    auto I = remove_if(Other.TypeVec, [Smallest](MVT OtherVT) { -      // Don't compare vector and non-vector types. -      if (OtherVT.isVector() != Smallest.isVector()) -        return false; -      // The getSizeInBits() check here is only needed for vectors, but is -      // a subset of the scalar check for scalars so no need to qualify. -      return OtherVT.getScalarSizeInBits() <= Smallest.getScalarSizeInBits() || -             OtherVT.getSizeInBits() < Smallest.getSizeInBits(); -    }); -    MadeChange |= I != Other.TypeVec.end(); // If we're about to remove types. -    Other.TypeVec.erase(I, Other.TypeVec.end()); - -    if (Other.TypeVec.empty()) { -      TP.error("Type inference contradiction found, '" + InputSet.getName() + -               "' has nothing larger than '" + getName() +"'!"); +  /// Return true if B is a suB-vector of P, i.e. P is a suPer-vector of B. +  auto IsSubVec = [](MVT B, MVT P) -> bool { +    if (!B.isVector() || !P.isVector())        return false; -    } -  } +    // Logically a <4 x i32> is a valid subvector of <n x 4 x i32> +    // but until there are obvious use-cases for this, keep the +    // types separate. +    if (B.isScalableVector() != P.isScalableVector()) +      return false; +    if (B.getVectorElementType() != P.getVectorElementType()) +      return false; +    return B.getVectorNumElements() < P.getVectorNumElements(); +  }; + +  /// Return true if S has no element (vector type) that T is a sub-vector of, +  /// i.e. has the same element type as T and more elements. +  auto NoSubV = [&IsSubVec](const TypeSetByHwMode::SetType &S, MVT T) -> bool { +    for (const auto &I : S) +      if (IsSubVec(T, I)) +        return false; +    return true; +  }; -  // Okay, find the largest type from the other set and remove anything the -  // same or smaller from the current set. We need to ensure that the scalar -  // type size is larger than the scalar size of the largest type. For -  // vectors, we also need to make sure that the total size is no smaller than -  // the size of the largest type. -  { -    TypeSet InputSet(*this); -    MVT Largest = *std::max_element(Other.TypeVec.begin(), Other.TypeVec.end(), -      [](MVT A, MVT B) { -        return A.getScalarSizeInBits() < B.getScalarSizeInBits() || -               (A.getScalarSizeInBits() == B.getScalarSizeInBits() && -                A.getSizeInBits() < B.getSizeInBits()); -      }); -    auto I = remove_if(TypeVec, [Largest](MVT OtherVT) { -      // Don't compare vector and non-vector types. -      if (OtherVT.isVector() != Largest.isVector()) +  /// Return true if S has no element (vector type) that T is a super-vector +  /// of, i.e. has the same element type as T and fewer elements. +  auto NoSupV = [&IsSubVec](const TypeSetByHwMode::SetType &S, MVT T) -> bool { +    for (const auto &I : S) +      if (IsSubVec(I, T))          return false; -      return OtherVT.getScalarSizeInBits() >= Largest.getScalarSizeInBits() || -             OtherVT.getSizeInBits() > Largest.getSizeInBits(); -    }); -    MadeChange |= I != TypeVec.end(); // If we're about to remove types. -    TypeVec.erase(I, TypeVec.end()); - -    if (TypeVec.empty()) { -      TP.error("Type inference contradiction found, '" + InputSet.getName() + -               "' has nothing smaller than '" + Other.getName() +"'!"); -      return false; -    } -  } +    return true; +  }; -  return MadeChange; -} +  bool Changed = false; -/// EnforceVectorEltTypeIs - 'this' is now constrained to be a vector type -/// whose element is specified by VTOperand. -bool EEVT::TypeSet::EnforceVectorEltTypeIs(MVT::SimpleValueType VT, -                                           TreePattern &TP) { -  bool MadeChange = false; +  if (Vec.empty()) +    Changed |= EnforceVector(Vec); +  if (Sub.empty()) +    Changed |= EnforceVector(Sub); -  MadeChange |= EnforceVector(TP); +  for (unsigned M : union_modes(Vec, Sub)) { +    TypeSetByHwMode::SetType &S = Sub.get(M); +    TypeSetByHwMode::SetType &V = Vec.get(M); -  TypeSet InputSet(*this); +    Changed |= berase_if(S, isScalar); -  // Filter out all the types which don't have the right element type. -  auto I = remove_if(TypeVec, [VT](MVT VVT) { -    return VVT.getVectorElementType().SimpleTy != VT; -  }); -  MadeChange |= I != TypeVec.end(); -  TypeVec.erase(I, TypeVec.end()); +    // Erase all types from S that are not sub-vectors of a type in V. +    Changed |= berase_if(S, std::bind(NoSubV, V, std::placeholders::_1)); -  if (TypeVec.empty()) {  // FIXME: Really want an SMLoc here! -    TP.error("Type inference contradiction found, forcing '" + -             InputSet.getName() + "' to have a vector element of type " + -             getEnumName(VT)); -    return false; +    // Erase all types from V that are not super-vectors of a type in S. +    Changed |= berase_if(V, std::bind(NoSupV, S, std::placeholders::_1));    } -  return MadeChange; +  return Changed;  } -/// EnforceVectorEltTypeIs - 'this' is now constrained to be a vector type -/// whose element is specified by VTOperand. -bool EEVT::TypeSet::EnforceVectorEltTypeIs(EEVT::TypeSet &VTOperand, -                                           TreePattern &TP) { +/// 1. Ensure that V has a scalar type iff W has a scalar type. +/// 2. Ensure that for each vector type T in V, there exists a vector +///    type U in W, such that T and U have the same number of elements. +/// 3. Ensure that for each vector type U in W, there exists a vector +///    type T in V, such that T and U have the same number of elements +///    (reverse of 2). +bool TypeInfer::EnforceSameNumElts(TypeSetByHwMode &V, TypeSetByHwMode &W) { +  ValidateOnExit _1(V), _2(W);    if (TP.hasError())      return false; -  // "This" must be a vector and "VTOperand" must be a scalar. -  bool MadeChange = false; -  MadeChange |= EnforceVector(TP); -  MadeChange |= VTOperand.EnforceScalar(TP); - -  // If we know the vector type, it forces the scalar to agree. -  if (isConcrete()) { -    MVT IVT = getConcrete(); -    IVT = IVT.getVectorElementType(); -    return MadeChange || VTOperand.MergeInTypeInfo(IVT.SimpleTy, TP); -  } - -  // If the scalar type is known, filter out vector types whose element types -  // disagree. -  if (!VTOperand.isConcrete()) -    return MadeChange; - -  MVT::SimpleValueType VT = VTOperand.getConcrete(); - -  MadeChange |= EnforceVectorEltTypeIs(VT, TP); - -  return MadeChange; +  bool Changed = false; +  if (V.empty()) +    Changed |= EnforceAny(V); +  if (W.empty()) +    Changed |= EnforceAny(W); + +  // An actual vector type cannot have 0 elements, so we can treat scalars +  // as zero-length vectors. This way both vectors and scalars can be +  // processed identically. +  auto NoLength = [](const SmallSet<unsigned,2> &Lengths, MVT T) -> bool { +    return !Lengths.count(T.isVector() ? T.getVectorNumElements() : 0); +  }; + +  for (unsigned M : union_modes(V, W)) { +    TypeSetByHwMode::SetType &VS = V.get(M); +    TypeSetByHwMode::SetType &WS = W.get(M); + +    SmallSet<unsigned,2> VN, WN; +    for (MVT T : VS) +      VN.insert(T.isVector() ? T.getVectorNumElements() : 0); +    for (MVT T : WS) +      WN.insert(T.isVector() ? T.getVectorNumElements() : 0); + +    Changed |= berase_if(VS, std::bind(NoLength, WN, std::placeholders::_1)); +    Changed |= berase_if(WS, std::bind(NoLength, VN, std::placeholders::_1)); +  } +  return Changed;  } -/// EnforceVectorSubVectorTypeIs - 'this' is now constrained to be a -/// vector type specified by VTOperand. -bool EEVT::TypeSet::EnforceVectorSubVectorTypeIs(EEVT::TypeSet &VTOperand, -                                                 TreePattern &TP) { +/// 1. Ensure that for each type T in A, there exists a type U in B, +///    such that T and U have equal size in bits. +/// 2. Ensure that for each type U in B, there exists a type T in A +///    such that T and U have equal size in bits (reverse of 1). +bool TypeInfer::EnforceSameSize(TypeSetByHwMode &A, TypeSetByHwMode &B) { +  ValidateOnExit _1(A), _2(B);    if (TP.hasError())      return false; +  bool Changed = false; +  if (A.empty()) +    Changed |= EnforceAny(A); +  if (B.empty()) +    Changed |= EnforceAny(B); -  // "This" must be a vector and "VTOperand" must be a vector. -  bool MadeChange = false; -  MadeChange |= EnforceVector(TP); -  MadeChange |= VTOperand.EnforceVector(TP); - -  // If one side is known to be integer or known to be FP but the other side has -  // no information, get at least the type integrality info in there. -  if (!hasFloatingPointTypes()) -    MadeChange |= VTOperand.EnforceInteger(TP); -  else if (!hasIntegerTypes()) -    MadeChange |= VTOperand.EnforceFloatingPoint(TP); -  if (!VTOperand.hasFloatingPointTypes()) -    MadeChange |= EnforceInteger(TP); -  else if (!VTOperand.hasIntegerTypes()) -    MadeChange |= EnforceFloatingPoint(TP); - -  assert(!isCompletelyUnknown() && !VTOperand.isCompletelyUnknown() && -         "Should have a type list now"); - -  // If we know the vector type, it forces the scalar types to agree. -  // Also force one vector to have more elements than the other. -  if (isConcrete()) { -    MVT IVT = getConcrete(); -    unsigned NumElems = IVT.getVectorNumElements(); -    IVT = IVT.getVectorElementType(); - -    EEVT::TypeSet EltTypeSet(IVT.SimpleTy, TP); -    MadeChange |= VTOperand.EnforceVectorEltTypeIs(EltTypeSet, TP); - -    // Only keep types that have less elements than VTOperand. -    TypeSet InputSet(VTOperand); - -    auto I = remove_if(VTOperand.TypeVec, [NumElems](MVT VVT) { -      return VVT.getVectorNumElements() >= NumElems; -    }); -    MadeChange |= I != VTOperand.TypeVec.end(); -    VTOperand.TypeVec.erase(I, VTOperand.TypeVec.end()); - -    if (VTOperand.TypeVec.empty()) {  // FIXME: Really want an SMLoc here! -      TP.error("Type inference contradiction found, forcing '" + -               InputSet.getName() + "' to have less vector elements than '" + -               getName() + "'"); -      return false; -    } -  } else if (VTOperand.isConcrete()) { -    MVT IVT = VTOperand.getConcrete(); -    unsigned NumElems = IVT.getVectorNumElements(); -    IVT = IVT.getVectorElementType(); - -    EEVT::TypeSet EltTypeSet(IVT.SimpleTy, TP); -    MadeChange |= EnforceVectorEltTypeIs(EltTypeSet, TP); +  auto NoSize = [](const SmallSet<unsigned,2> &Sizes, MVT T) -> bool { +    return !Sizes.count(T.getSizeInBits()); +  }; -    // Only keep types that have more elements than 'this'. -    TypeSet InputSet(*this); +  for (unsigned M : union_modes(A, B)) { +    TypeSetByHwMode::SetType &AS = A.get(M); +    TypeSetByHwMode::SetType &BS = B.get(M); +    SmallSet<unsigned,2> AN, BN; -    auto I = remove_if(TypeVec, [NumElems](MVT VVT) { -      return VVT.getVectorNumElements() <= NumElems; -    }); -    MadeChange |= I != TypeVec.end(); -    TypeVec.erase(I, TypeVec.end()); +    for (MVT T : AS) +      AN.insert(T.getSizeInBits()); +    for (MVT T : BS) +      BN.insert(T.getSizeInBits()); -    if (TypeVec.empty()) {  // FIXME: Really want an SMLoc here! -      TP.error("Type inference contradiction found, forcing '" + -               InputSet.getName() + "' to have more vector elements than '" + -               VTOperand.getName() + "'"); -      return false; -    } +    Changed |= berase_if(AS, std::bind(NoSize, BN, std::placeholders::_1)); +    Changed |= berase_if(BS, std::bind(NoSize, AN, std::placeholders::_1));    } -  return MadeChange; +  return Changed;  } -/// EnforceameNumElts - If VTOperand is a scalar, then 'this' is a scalar. If -/// VTOperand is a vector, then 'this' must have the same number of elements. -bool EEVT::TypeSet::EnforceSameNumElts(EEVT::TypeSet &VTOperand, -                                       TreePattern &TP) { -  if (TP.hasError()) -    return false; - -  bool MadeChange = false; +void TypeInfer::expandOverloads(TypeSetByHwMode &VTS) { +  ValidateOnExit _1(VTS); +  TypeSetByHwMode Legal = getLegalTypes(); +  bool HaveLegalDef = Legal.hasDefault(); -  if (isCompletelyUnknown()) -    MadeChange = FillWithPossibleTypes(TP); - -  if (VTOperand.isCompletelyUnknown()) -    MadeChange = VTOperand.FillWithPossibleTypes(TP); - -  // If one contains vectors but the other doesn't pull vectors out. -  if (!hasVectorTypes()) -    MadeChange |= VTOperand.EnforceScalar(TP); -  else if (!hasScalarTypes()) -    MadeChange |= VTOperand.EnforceVector(TP); -  if (!VTOperand.hasVectorTypes()) -    MadeChange |= EnforceScalar(TP); -  else if (!VTOperand.hasScalarTypes()) -    MadeChange |= EnforceVector(TP); - -  // If one type is a vector, make sure the other has the same element count. -  // If this a scalar, then we are already done with the above. -  if (isConcrete()) { -    MVT IVT = getConcrete(); -    if (IVT.isVector()) { -      unsigned NumElems = IVT.getVectorNumElements(); - -      // Only keep types that have same elements as 'this'. -      TypeSet InputSet(VTOperand); - -      auto I = remove_if(VTOperand.TypeVec, [NumElems](MVT VVT) { -        return VVT.getVectorNumElements() != NumElems; -      }); -      MadeChange |= I != VTOperand.TypeVec.end(); -      VTOperand.TypeVec.erase(I, VTOperand.TypeVec.end()); - -      if (VTOperand.TypeVec.empty()) {  // FIXME: Really want an SMLoc here! -        TP.error("Type inference contradiction found, forcing '" + -                 InputSet.getName() + "' to have same number elements as '" + -                 getName() + "'"); -        return false; -      } +  for (auto &I : VTS) { +    unsigned M = I.first; +    if (!Legal.hasMode(M) && !HaveLegalDef) { +      TP.error("Invalid mode " + Twine(M)); +      return;      } -  } else if (VTOperand.isConcrete()) { -    MVT IVT = VTOperand.getConcrete(); -    if (IVT.isVector()) { -      unsigned NumElems = IVT.getVectorNumElements(); - -      // Only keep types that have same elements as VTOperand. -      TypeSet InputSet(*this); +    expandOverloads(I.second, Legal.get(M)); +  } +} -      auto I = remove_if(TypeVec, [NumElems](MVT VVT) { -        return VVT.getVectorNumElements() != NumElems; -      }); -      MadeChange |= I != TypeVec.end(); -      TypeVec.erase(I, TypeVec.end()); +void TypeInfer::expandOverloads(TypeSetByHwMode::SetType &Out, +                                const TypeSetByHwMode::SetType &Legal) { +  std::set<MVT> Ovs; +  for (MVT T : Out) { +    if (!T.isOverloaded()) +      continue; -      if (TypeVec.empty()) {  // FIXME: Really want an SMLoc here! -        TP.error("Type inference contradiction found, forcing '" + -                 InputSet.getName() + "' to have same number elements than '" + -                 VTOperand.getName() + "'"); -        return false; -      } +    Ovs.insert(T); +    // MachineValueTypeSet allows iteration and erasing. +    Out.erase(T); +  } + +  for (MVT Ov : Ovs) { +    switch (Ov.SimpleTy) { +      case MVT::iPTRAny: +        Out.insert(MVT::iPTR); +        return; +      case MVT::iAny: +        for (MVT T : MVT::integer_valuetypes()) +          if (Legal.count(T)) +            Out.insert(T); +        for (MVT T : MVT::integer_vector_valuetypes()) +          if (Legal.count(T)) +            Out.insert(T); +        return; +      case MVT::fAny: +        for (MVT T : MVT::fp_valuetypes()) +          if (Legal.count(T)) +            Out.insert(T); +        for (MVT T : MVT::fp_vector_valuetypes()) +          if (Legal.count(T)) +            Out.insert(T); +        return; +      case MVT::vAny: +        for (MVT T : MVT::vector_valuetypes()) +          if (Legal.count(T)) +            Out.insert(T); +        return; +      case MVT::Any: +        for (MVT T : MVT::all_valuetypes()) +          if (Legal.count(T)) +            Out.insert(T); +        return; +      default: +        break;      }    } +} -  return MadeChange; +TypeSetByHwMode TypeInfer::getLegalTypes() { +  if (!LegalTypesCached) { +    // Stuff all types from all modes into the default mode. +    const TypeSetByHwMode <S = TP.getDAGPatterns().getLegalTypes(); +    for (const auto &I : LTS) +      LegalCache.insert(I.second); +    LegalTypesCached = true; +  } +  TypeSetByHwMode VTS; +  VTS.getOrCreate(DefaultMode) = LegalCache; +  return VTS;  } -/// EnforceSameSize - 'this' is now constrained to be same size as VTOperand. -bool EEVT::TypeSet::EnforceSameSize(EEVT::TypeSet &VTOperand, -                                    TreePattern &TP) { -  if (TP.hasError()) -    return false; +//===----------------------------------------------------------------------===// +// TreePredicateFn Implementation +//===----------------------------------------------------------------------===// -  bool MadeChange = false; +/// TreePredicateFn constructor.  Here 'N' is a subclass of PatFrag. +TreePredicateFn::TreePredicateFn(TreePattern *N) : PatFragRec(N) { +  assert( +      (!hasPredCode() || !hasImmCode()) && +      ".td file corrupt: can't have a node predicate *and* an imm predicate"); +} -  if (isCompletelyUnknown()) -    MadeChange = FillWithPossibleTypes(TP); +bool TreePredicateFn::hasPredCode() const { +  return isLoad() || isStore() || isAtomic() || +         !PatFragRec->getRecord()->getValueAsString("PredicateCode").empty(); +} -  if (VTOperand.isCompletelyUnknown()) -    MadeChange = VTOperand.FillWithPossibleTypes(TP); +std::string TreePredicateFn::getPredCode() const { +  std::string Code = ""; -  // If we know one of the types, it forces the other type agree. -  if (isConcrete()) { -    MVT IVT = getConcrete(); -    unsigned Size = IVT.getSizeInBits(); +  if (!isLoad() && !isStore() && !isAtomic()) { +    Record *MemoryVT = getMemoryVT(); -    // Only keep types that have the same size as 'this'. -    TypeSet InputSet(VTOperand); +    if (MemoryVT) +      PrintFatalError(getOrigPatFragRecord()->getRecord()->getLoc(), +                      "MemoryVT requires IsLoad or IsStore"); +  } -    auto I = remove_if(VTOperand.TypeVec, -                       [&](MVT VT) { return VT.getSizeInBits() != Size; }); -    MadeChange |= I != VTOperand.TypeVec.end(); -    VTOperand.TypeVec.erase(I, VTOperand.TypeVec.end()); +  if (!isLoad() && !isStore()) { +    if (isUnindexed()) +      PrintFatalError(getOrigPatFragRecord()->getRecord()->getLoc(), +                      "IsUnindexed requires IsLoad or IsStore"); -    if (VTOperand.TypeVec.empty()) {  // FIXME: Really want an SMLoc here! -      TP.error("Type inference contradiction found, forcing '" + -               InputSet.getName() + "' to have same size as '" + -               getName() + "'"); -      return false; -    } -  } else if (VTOperand.isConcrete()) { -    MVT IVT = VTOperand.getConcrete(); -    unsigned Size = IVT.getSizeInBits(); +    Record *ScalarMemoryVT = getScalarMemoryVT(); -    // Only keep types that have the same size as VTOperand. -    TypeSet InputSet(*this); +    if (ScalarMemoryVT) +      PrintFatalError(getOrigPatFragRecord()->getRecord()->getLoc(), +                      "ScalarMemoryVT requires IsLoad or IsStore"); +  } -    auto I = -        remove_if(TypeVec, [&](MVT VT) { return VT.getSizeInBits() != Size; }); -    MadeChange |= I != TypeVec.end(); -    TypeVec.erase(I, TypeVec.end()); +  if (isLoad() + isStore() + isAtomic() > 1) +    PrintFatalError(getOrigPatFragRecord()->getRecord()->getLoc(), +                    "IsLoad, IsStore, and IsAtomic are mutually exclusive"); -    if (TypeVec.empty()) {  // FIXME: Really want an SMLoc here! -      TP.error("Type inference contradiction found, forcing '" + -               InputSet.getName() + "' to have same size as '" + -               VTOperand.getName() + "'"); -      return false; +  if (isLoad()) { +    if (!isUnindexed() && !isNonExtLoad() && !isAnyExtLoad() && +        !isSignExtLoad() && !isZeroExtLoad() && getMemoryVT() == nullptr && +        getScalarMemoryVT() == nullptr) +      PrintFatalError(getOrigPatFragRecord()->getRecord()->getLoc(), +                      "IsLoad cannot be used by itself"); +  } else { +    if (isNonExtLoad()) +      PrintFatalError(getOrigPatFragRecord()->getRecord()->getLoc(), +                      "IsNonExtLoad requires IsLoad"); +    if (isAnyExtLoad()) +      PrintFatalError(getOrigPatFragRecord()->getRecord()->getLoc(), +                      "IsAnyExtLoad requires IsLoad"); +    if (isSignExtLoad()) +      PrintFatalError(getOrigPatFragRecord()->getRecord()->getLoc(), +                      "IsSignExtLoad requires IsLoad"); +    if (isZeroExtLoad()) +      PrintFatalError(getOrigPatFragRecord()->getRecord()->getLoc(), +                      "IsZeroExtLoad requires IsLoad"); +  } + +  if (isStore()) { +    if (!isUnindexed() && !isTruncStore() && !isNonTruncStore() && +        getMemoryVT() == nullptr && getScalarMemoryVT() == nullptr) +      PrintFatalError(getOrigPatFragRecord()->getRecord()->getLoc(), +                      "IsStore cannot be used by itself"); +  } else { +    if (isNonTruncStore()) +      PrintFatalError(getOrigPatFragRecord()->getRecord()->getLoc(), +                      "IsNonTruncStore requires IsStore"); +    if (isTruncStore()) +      PrintFatalError(getOrigPatFragRecord()->getRecord()->getLoc(), +                      "IsTruncStore requires IsStore"); +  } + +  if (isAtomic()) { +    if (getMemoryVT() == nullptr && !isAtomicOrderingMonotonic() && +        !isAtomicOrderingAcquire() && !isAtomicOrderingRelease() && +        !isAtomicOrderingAcquireRelease() && +        !isAtomicOrderingSequentiallyConsistent() && +        !isAtomicOrderingAcquireOrStronger() && +        !isAtomicOrderingReleaseOrStronger() && +        !isAtomicOrderingWeakerThanAcquire() && +        !isAtomicOrderingWeakerThanRelease()) +      PrintFatalError(getOrigPatFragRecord()->getRecord()->getLoc(), +                      "IsAtomic cannot be used by itself"); +  } else { +    if (isAtomicOrderingMonotonic()) +      PrintFatalError(getOrigPatFragRecord()->getRecord()->getLoc(), +                      "IsAtomicOrderingMonotonic requires IsAtomic"); +    if (isAtomicOrderingAcquire()) +      PrintFatalError(getOrigPatFragRecord()->getRecord()->getLoc(), +                      "IsAtomicOrderingAcquire requires IsAtomic"); +    if (isAtomicOrderingRelease()) +      PrintFatalError(getOrigPatFragRecord()->getRecord()->getLoc(), +                      "IsAtomicOrderingRelease requires IsAtomic"); +    if (isAtomicOrderingAcquireRelease()) +      PrintFatalError(getOrigPatFragRecord()->getRecord()->getLoc(), +                      "IsAtomicOrderingAcquireRelease requires IsAtomic"); +    if (isAtomicOrderingSequentiallyConsistent()) +      PrintFatalError(getOrigPatFragRecord()->getRecord()->getLoc(), +                      "IsAtomicOrderingSequentiallyConsistent requires IsAtomic"); +    if (isAtomicOrderingAcquireOrStronger()) +      PrintFatalError(getOrigPatFragRecord()->getRecord()->getLoc(), +                      "IsAtomicOrderingAcquireOrStronger requires IsAtomic"); +    if (isAtomicOrderingReleaseOrStronger()) +      PrintFatalError(getOrigPatFragRecord()->getRecord()->getLoc(), +                      "IsAtomicOrderingReleaseOrStronger requires IsAtomic"); +    if (isAtomicOrderingWeakerThanAcquire()) +      PrintFatalError(getOrigPatFragRecord()->getRecord()->getLoc(), +                      "IsAtomicOrderingWeakerThanAcquire requires IsAtomic"); +  } + +  if (isLoad() || isStore() || isAtomic()) { +    StringRef SDNodeName = +        isLoad() ? "LoadSDNode" : isStore() ? "StoreSDNode" : "AtomicSDNode"; + +    Record *MemoryVT = getMemoryVT(); + +    if (MemoryVT) +      Code += ("if (cast<" + SDNodeName + ">(N)->getMemoryVT() != MVT::" + +               MemoryVT->getName() + ") return false;\n") +                  .str(); +  } + +  if (isAtomic() && isAtomicOrderingMonotonic()) +    Code += "if (cast<AtomicSDNode>(N)->getOrdering() != " +            "AtomicOrdering::Monotonic) return false;\n"; +  if (isAtomic() && isAtomicOrderingAcquire()) +    Code += "if (cast<AtomicSDNode>(N)->getOrdering() != " +            "AtomicOrdering::Acquire) return false;\n"; +  if (isAtomic() && isAtomicOrderingRelease()) +    Code += "if (cast<AtomicSDNode>(N)->getOrdering() != " +            "AtomicOrdering::Release) return false;\n"; +  if (isAtomic() && isAtomicOrderingAcquireRelease()) +    Code += "if (cast<AtomicSDNode>(N)->getOrdering() != " +            "AtomicOrdering::AcquireRelease) return false;\n"; +  if (isAtomic() && isAtomicOrderingSequentiallyConsistent()) +    Code += "if (cast<AtomicSDNode>(N)->getOrdering() != " +            "AtomicOrdering::SequentiallyConsistent) return false;\n"; + +  if (isAtomic() && isAtomicOrderingAcquireOrStronger()) +    Code += "if (!isAcquireOrStronger(cast<AtomicSDNode>(N)->getOrdering())) " +            "return false;\n"; +  if (isAtomic() && isAtomicOrderingWeakerThanAcquire()) +    Code += "if (isAcquireOrStronger(cast<AtomicSDNode>(N)->getOrdering())) " +            "return false;\n"; + +  if (isAtomic() && isAtomicOrderingReleaseOrStronger()) +    Code += "if (!isReleaseOrStronger(cast<AtomicSDNode>(N)->getOrdering())) " +            "return false;\n"; +  if (isAtomic() && isAtomicOrderingWeakerThanRelease()) +    Code += "if (isReleaseOrStronger(cast<AtomicSDNode>(N)->getOrdering())) " +            "return false;\n"; + +  if (isLoad() || isStore()) { +    StringRef SDNodeName = isLoad() ? "LoadSDNode" : "StoreSDNode"; + +    if (isUnindexed()) +      Code += ("if (cast<" + SDNodeName + +               ">(N)->getAddressingMode() != ISD::UNINDEXED) " +               "return false;\n") +                  .str(); + +    if (isLoad()) { +      if ((isNonExtLoad() + isAnyExtLoad() + isSignExtLoad() + +           isZeroExtLoad()) > 1) +        PrintFatalError(getOrigPatFragRecord()->getRecord()->getLoc(), +                        "IsNonExtLoad, IsAnyExtLoad, IsSignExtLoad, and " +                        "IsZeroExtLoad are mutually exclusive"); +      if (isNonExtLoad()) +        Code += "if (cast<LoadSDNode>(N)->getExtensionType() != " +                "ISD::NON_EXTLOAD) return false;\n"; +      if (isAnyExtLoad()) +        Code += "if (cast<LoadSDNode>(N)->getExtensionType() != ISD::EXTLOAD) " +                "return false;\n"; +      if (isSignExtLoad()) +        Code += "if (cast<LoadSDNode>(N)->getExtensionType() != ISD::SEXTLOAD) " +                "return false;\n"; +      if (isZeroExtLoad()) +        Code += "if (cast<LoadSDNode>(N)->getExtensionType() != ISD::ZEXTLOAD) " +                "return false;\n"; +    } else { +      if ((isNonTruncStore() + isTruncStore()) > 1) +        PrintFatalError( +            getOrigPatFragRecord()->getRecord()->getLoc(), +            "IsNonTruncStore, and IsTruncStore are mutually exclusive"); +      if (isNonTruncStore()) +        Code += +            " if (cast<StoreSDNode>(N)->isTruncatingStore()) return false;\n"; +      if (isTruncStore()) +        Code += +            " if (!cast<StoreSDNode>(N)->isTruncatingStore()) return false;\n";      } + +    Record *ScalarMemoryVT = getScalarMemoryVT(); + +    if (ScalarMemoryVT) +      Code += ("if (cast<" + SDNodeName + +               ">(N)->getMemoryVT().getScalarType() != MVT::" + +               ScalarMemoryVT->getName() + ") return false;\n") +                  .str();    } -  return MadeChange; -} +  std::string PredicateCode = PatFragRec->getRecord()->getValueAsString("PredicateCode"); -//===----------------------------------------------------------------------===// -// Helpers for working with extended types. +  Code += PredicateCode; -/// Dependent variable map for CodeGenDAGPattern variant generation -typedef std::map<std::string, int> DepVarMap; +  if (PredicateCode.empty() && !Code.empty()) +    Code += "return true;\n"; -static void FindDepVarsOf(TreePatternNode *N, DepVarMap &DepMap) { -  if (N->isLeaf()) { -    if (isa<DefInit>(N->getLeafValue())) -      DepMap[N->getName()]++; -  } else { -    for (size_t i = 0, e = N->getNumChildren(); i != e; ++i) -      FindDepVarsOf(N->getChild(i), DepMap); -  } -} -   -/// Find dependent variables within child patterns -static void FindDepVars(TreePatternNode *N, MultipleUseVarSet &DepVars) { -  DepVarMap depcounts; -  FindDepVarsOf(N, depcounts); -  for (const std::pair<std::string, int> &Pair : depcounts) { -    if (Pair.second > 1) -      DepVars.insert(Pair.first); -  } +  return Code;  } -#ifndef NDEBUG -/// Dump the dependent variable set: -static void DumpDepVars(MultipleUseVarSet &DepVars) { -  if (DepVars.empty()) { -    DEBUG(errs() << "<empty set>"); -  } else { -    DEBUG(errs() << "[ "); -    for (const std::string &DepVar : DepVars) { -      DEBUG(errs() << DepVar << " "); -    } -    DEBUG(errs() << "]"); -  } +bool TreePredicateFn::hasImmCode() const { +  return !PatFragRec->getRecord()->getValueAsString("ImmediateCode").empty();  } -#endif +std::string TreePredicateFn::getImmCode() const { +  return PatFragRec->getRecord()->getValueAsString("ImmediateCode"); +} -//===----------------------------------------------------------------------===// -// TreePredicateFn Implementation -//===----------------------------------------------------------------------===// +bool TreePredicateFn::immCodeUsesAPInt() const { +  return getOrigPatFragRecord()->getRecord()->getValueAsBit("IsAPInt"); +} -/// TreePredicateFn constructor.  Here 'N' is a subclass of PatFrag. -TreePredicateFn::TreePredicateFn(TreePattern *N) : PatFragRec(N) { -  assert((getPredCode().empty() || getImmCode().empty()) && -        ".td file corrupt: can't have a node predicate *and* an imm predicate"); +bool TreePredicateFn::immCodeUsesAPFloat() const { +  bool Unset; +  // The return value will be false when IsAPFloat is unset. +  return getOrigPatFragRecord()->getRecord()->getValueAsBitOrUnset("IsAPFloat", +                                                                   Unset);  } -std::string TreePredicateFn::getPredCode() const { -  return PatFragRec->getRecord()->getValueAsString("PredicateCode"); +bool TreePredicateFn::isPredefinedPredicateEqualTo(StringRef Field, +                                                   bool Value) const { +  bool Unset; +  bool Result = +      getOrigPatFragRecord()->getRecord()->getValueAsBitOrUnset(Field, Unset); +  if (Unset) +    return false; +  return Result == Value; +} +bool TreePredicateFn::isLoad() const { +  return isPredefinedPredicateEqualTo("IsLoad", true); +} +bool TreePredicateFn::isStore() const { +  return isPredefinedPredicateEqualTo("IsStore", true); +} +bool TreePredicateFn::isAtomic() const { +  return isPredefinedPredicateEqualTo("IsAtomic", true); +} +bool TreePredicateFn::isUnindexed() const { +  return isPredefinedPredicateEqualTo("IsUnindexed", true); +} +bool TreePredicateFn::isNonExtLoad() const { +  return isPredefinedPredicateEqualTo("IsNonExtLoad", true); +} +bool TreePredicateFn::isAnyExtLoad() const { +  return isPredefinedPredicateEqualTo("IsAnyExtLoad", true); +} +bool TreePredicateFn::isSignExtLoad() const { +  return isPredefinedPredicateEqualTo("IsSignExtLoad", true); +} +bool TreePredicateFn::isZeroExtLoad() const { +  return isPredefinedPredicateEqualTo("IsZeroExtLoad", true); +} +bool TreePredicateFn::isNonTruncStore() const { +  return isPredefinedPredicateEqualTo("IsTruncStore", false); +} +bool TreePredicateFn::isTruncStore() const { +  return isPredefinedPredicateEqualTo("IsTruncStore", true); +} +bool TreePredicateFn::isAtomicOrderingMonotonic() const { +  return isPredefinedPredicateEqualTo("IsAtomicOrderingMonotonic", true); +} +bool TreePredicateFn::isAtomicOrderingAcquire() const { +  return isPredefinedPredicateEqualTo("IsAtomicOrderingAcquire", true); +} +bool TreePredicateFn::isAtomicOrderingRelease() const { +  return isPredefinedPredicateEqualTo("IsAtomicOrderingRelease", true); +} +bool TreePredicateFn::isAtomicOrderingAcquireRelease() const { +  return isPredefinedPredicateEqualTo("IsAtomicOrderingAcquireRelease", true); +} +bool TreePredicateFn::isAtomicOrderingSequentiallyConsistent() const { +  return isPredefinedPredicateEqualTo("IsAtomicOrderingSequentiallyConsistent", +                                      true); +} +bool TreePredicateFn::isAtomicOrderingAcquireOrStronger() const { +  return isPredefinedPredicateEqualTo("IsAtomicOrderingAcquireOrStronger", true); +} +bool TreePredicateFn::isAtomicOrderingWeakerThanAcquire() const { +  return isPredefinedPredicateEqualTo("IsAtomicOrderingAcquireOrStronger", false); +} +bool TreePredicateFn::isAtomicOrderingReleaseOrStronger() const { +  return isPredefinedPredicateEqualTo("IsAtomicOrderingReleaseOrStronger", true); +} +bool TreePredicateFn::isAtomicOrderingWeakerThanRelease() const { +  return isPredefinedPredicateEqualTo("IsAtomicOrderingReleaseOrStronger", false); +} +Record *TreePredicateFn::getMemoryVT() const { +  Record *R = getOrigPatFragRecord()->getRecord(); +  if (R->isValueUnset("MemoryVT")) +    return nullptr; +  return R->getValueAsDef("MemoryVT"); +} +Record *TreePredicateFn::getScalarMemoryVT() const { +  Record *R = getOrigPatFragRecord()->getRecord(); +  if (R->isValueUnset("ScalarMemoryVT")) +    return nullptr; +  return R->getValueAsDef("ScalarMemoryVT");  } -std::string TreePredicateFn::getImmCode() const { -  return PatFragRec->getRecord()->getValueAsString("ImmediateCode"); +StringRef TreePredicateFn::getImmType() const { +  if (immCodeUsesAPInt()) +    return "const APInt &"; +  if (immCodeUsesAPFloat()) +    return "const APFloat &"; +  return "int64_t";  } +StringRef TreePredicateFn::getImmTypeIdentifier() const { +  if (immCodeUsesAPInt()) +    return "APInt"; +  else if (immCodeUsesAPFloat()) +    return "APFloat"; +  return "I64"; +}  /// isAlwaysTrue - Return true if this is a noop predicate.  bool TreePredicateFn::isAlwaysTrue() const { -  return getPredCode().empty() && getImmCode().empty(); +  return !hasPredCode() && !hasImmCode();  }  /// Return the name to use in the generated code to reference this, this is @@ -790,14 +1157,61 @@ std::string TreePredicateFn::getCodeToRunOnSDNode() const {    // Handle immediate predicates first.    std::string ImmCode = getImmCode();    if (!ImmCode.empty()) { -    std::string Result = -      "    int64_t Imm = cast<ConstantSDNode>(Node)->getSExtValue();\n"; +    if (isLoad()) +      PrintFatalError(getOrigPatFragRecord()->getRecord()->getLoc(), +                      "IsLoad cannot be used with ImmLeaf or its subclasses"); +    if (isStore()) +      PrintFatalError(getOrigPatFragRecord()->getRecord()->getLoc(), +                      "IsStore cannot be used with ImmLeaf or its subclasses"); +    if (isUnindexed()) +      PrintFatalError( +          getOrigPatFragRecord()->getRecord()->getLoc(), +          "IsUnindexed cannot be used with ImmLeaf or its subclasses"); +    if (isNonExtLoad()) +      PrintFatalError( +          getOrigPatFragRecord()->getRecord()->getLoc(), +          "IsNonExtLoad cannot be used with ImmLeaf or its subclasses"); +    if (isAnyExtLoad()) +      PrintFatalError( +          getOrigPatFragRecord()->getRecord()->getLoc(), +          "IsAnyExtLoad cannot be used with ImmLeaf or its subclasses"); +    if (isSignExtLoad()) +      PrintFatalError( +          getOrigPatFragRecord()->getRecord()->getLoc(), +          "IsSignExtLoad cannot be used with ImmLeaf or its subclasses"); +    if (isZeroExtLoad()) +      PrintFatalError( +          getOrigPatFragRecord()->getRecord()->getLoc(), +          "IsZeroExtLoad cannot be used with ImmLeaf or its subclasses"); +    if (isNonTruncStore()) +      PrintFatalError( +          getOrigPatFragRecord()->getRecord()->getLoc(), +          "IsNonTruncStore cannot be used with ImmLeaf or its subclasses"); +    if (isTruncStore()) +      PrintFatalError( +          getOrigPatFragRecord()->getRecord()->getLoc(), +          "IsTruncStore cannot be used with ImmLeaf or its subclasses"); +    if (getMemoryVT()) +      PrintFatalError(getOrigPatFragRecord()->getRecord()->getLoc(), +                      "MemoryVT cannot be used with ImmLeaf or its subclasses"); +    if (getScalarMemoryVT()) +      PrintFatalError( +          getOrigPatFragRecord()->getRecord()->getLoc(), +          "ScalarMemoryVT cannot be used with ImmLeaf or its subclasses"); + +    std::string Result = ("    " + getImmType() + " Imm = ").str(); +    if (immCodeUsesAPFloat()) +      Result += "cast<ConstantFPSDNode>(Node)->getValueAPF();\n"; +    else if (immCodeUsesAPInt()) +      Result += "cast<ConstantSDNode>(Node)->getAPIntValue();\n"; +    else +      Result += "cast<ConstantSDNode>(Node)->getSExtValue();\n";      return Result + ImmCode;    } -   +    // Handle arbitrary node predicates. -  assert(!getPredCode().empty() && "Don't have any predicate code!"); -  std::string ClassName; +  assert(hasPredCode() && "Don't have any predicate code!"); +  StringRef ClassName;    if (PatFragRec->getOnlyTree()->isLeaf())      ClassName = "SDNode";    else { @@ -808,8 +1222,8 @@ std::string TreePredicateFn::getCodeToRunOnSDNode() const {    if (ClassName == "SDNode")      Result = "    SDNode *N = Node;\n";    else -    Result = "    auto *N = cast<" + ClassName + ">(Node);\n"; -   +    Result = "    auto *N = cast<" + ClassName.str() + ">(Node);\n"; +    return Result + getPredCode();  } @@ -817,7 +1231,6 @@ std::string TreePredicateFn::getCodeToRunOnSDNode() const {  // PatternToMatch implementation  // -  /// getPatternSize - Return the 'size' of this pattern.  We want to match large  /// patterns before small ones.  This is used to determine the size of a  /// pattern. @@ -829,10 +1242,8 @@ static unsigned getPatternSize(const TreePatternNode *P,    if (P->isLeaf() && isa<IntInit>(P->getLeafValue()))      Size += 2; -  const ComplexPattern *AM = P->getComplexPatternInfo(CGP); -  if (AM) { +  if (const ComplexPattern *AM = P->getComplexPatternInfo(CGP)) {      Size += AM->getComplexity(); -      // We don't want to count any children twice, so return early.      return Size;    } @@ -844,11 +1255,17 @@ static unsigned getPatternSize(const TreePatternNode *P,    // Count children in the count if they are also nodes.    for (unsigned i = 0, e = P->getNumChildren(); i != e; ++i) { -    TreePatternNode *Child = P->getChild(i); -    if (!Child->isLeaf() && Child->getNumTypes() && -        Child->getType(0) != MVT::Other) -      Size += getPatternSize(Child, CGP); -    else if (Child->isLeaf()) { +    const TreePatternNode *Child = P->getChild(i); +    if (!Child->isLeaf() && Child->getNumTypes()) { +      const TypeSetByHwMode &T0 = Child->getType(0); +      // At this point, all variable type sets should be simple, i.e. only +      // have a default mode. +      if (T0.getMachineValueType() != MVT::Other) { +        Size += getPatternSize(Child, CGP); +        continue; +      } +    } +    if (Child->isLeaf()) {        if (isa<IntInit>(Child->getLeafValue()))          Size += 5;  // Matches a ConstantSDNode (+3) and a specific value (+2).        else if (Child->getComplexPatternInfo(CGP)) @@ -868,52 +1285,37 @@ getPatternComplexity(const CodeGenDAGPatterns &CGP) const {    return getPatternSize(getSrcPattern(), CGP) + getAddedComplexity();  } -  /// getPredicateCheck - Return a single string containing all of this  /// pattern's predicates concatenated with "&&" operators.  ///  std::string PatternToMatch::getPredicateCheck() const { -  SmallVector<Record *, 4> PredicateRecs; -  for (Init *I : Predicates->getValues()) { -    if (DefInit *Pred = dyn_cast<DefInit>(I)) { -      Record *Def = Pred->getDef(); -      if (!Def->isSubClassOf("Predicate")) { -#ifndef NDEBUG -        Def->dump(); -#endif -        llvm_unreachable("Unknown predicate type!"); -      } -      PredicateRecs.push_back(Def); -    } -  } -  // Sort so that different orders get canonicalized to the same string. -  std::sort(PredicateRecs.begin(), PredicateRecs.end(), LessRecord()); - -  SmallString<128> PredicateCheck; -  for (Record *Pred : PredicateRecs) { -    if (!PredicateCheck.empty()) -      PredicateCheck += " && "; -    PredicateCheck += "("; -    PredicateCheck += Pred->getValueAsString("CondString"); -    PredicateCheck += ")"; -  } - -  return PredicateCheck.str(); +  SmallVector<const Predicate*,4> PredList; +  for (const Predicate &P : Predicates) +    PredList.push_back(&P); +  std::sort(PredList.begin(), PredList.end(), deref<llvm::less>()); + +  std::string Check; +  for (unsigned i = 0, e = PredList.size(); i != e; ++i) { +    if (i != 0) +      Check += " && "; +    Check += '(' + PredList[i]->getCondString() + ')'; +  } +  return Check;  }  //===----------------------------------------------------------------------===//  // SDTypeConstraint implementation  // -SDTypeConstraint::SDTypeConstraint(Record *R) { +SDTypeConstraint::SDTypeConstraint(Record *R, const CodeGenHwModes &CGH) {    OperandNo = R->getValueAsInt("OperandNum");    if (R->isSubClassOf("SDTCisVT")) {      ConstraintType = SDTCisVT; -    x.SDTCisVT_Info.VT = getValueType(R->getValueAsDef("VT")); -    if (x.SDTCisVT_Info.VT == MVT::isVoid) -      PrintFatalError(R->getLoc(), "Cannot use 'Void' as type to SDTCisVT"); - +    VVT = getValueTypeByHwMode(R->getValueAsDef("VT"), CGH); +    for (const auto &P : VVT) +      if (P.second == MVT::isVoid) +        PrintFatalError(R->getLoc(), "Cannot use 'Void' as type to SDTCisVT");    } else if (R->isSubClassOf("SDTCisPtrTy")) {      ConstraintType = SDTCisPtrTy;    } else if (R->isSubClassOf("SDTCisInt")) { @@ -942,13 +1344,16 @@ SDTypeConstraint::SDTypeConstraint(Record *R) {        R->getValueAsInt("OtherOpNum");    } else if (R->isSubClassOf("SDTCVecEltisVT")) {      ConstraintType = SDTCVecEltisVT; -    x.SDTCVecEltisVT_Info.VT = getValueType(R->getValueAsDef("VT")); -    if (MVT(x.SDTCVecEltisVT_Info.VT).isVector()) -      PrintFatalError(R->getLoc(), "Cannot use vector type as SDTCVecEltisVT"); -    if (!MVT(x.SDTCVecEltisVT_Info.VT).isInteger() && -        !MVT(x.SDTCVecEltisVT_Info.VT).isFloatingPoint()) -      PrintFatalError(R->getLoc(), "Must use integer or floating point type " -                                   "as SDTCVecEltisVT"); +    VVT = getValueTypeByHwMode(R->getValueAsDef("VT"), CGH); +    for (const auto &P : VVT) { +      MVT T = P.second; +      if (T.isVector()) +        PrintFatalError(R->getLoc(), +                        "Cannot use vector type as SDTCVecEltisVT"); +      if (!T.isInteger() && !T.isFloatingPoint()) +        PrintFatalError(R->getLoc(), "Must use integer or floating point type " +                                     "as SDTCVecEltisVT"); +    }    } else if (R->isSubClassOf("SDTCisSameNumEltsAs")) {      ConstraintType = SDTCisSameNumEltsAs;      x.SDTCisSameNumEltsAs_Info.OtherOperandNum = @@ -998,23 +1403,24 @@ bool SDTypeConstraint::ApplyTypeConstraint(TreePatternNode *N,    unsigned ResNo = 0; // The result number being referenced.    TreePatternNode *NodeToApply = getOperandNum(OperandNo, N, NodeInfo, ResNo); +  TypeInfer &TI = TP.getInfer();    switch (ConstraintType) {    case SDTCisVT:      // Operand must be a particular type. -    return NodeToApply->UpdateNodeType(ResNo, x.SDTCisVT_Info.VT, TP); +    return NodeToApply->UpdateNodeType(ResNo, VVT, TP);    case SDTCisPtrTy:      // Operand must be same as target pointer type.      return NodeToApply->UpdateNodeType(ResNo, MVT::iPTR, TP);    case SDTCisInt:      // Require it to be one of the legal integer VTs. -    return NodeToApply->getExtType(ResNo).EnforceInteger(TP); +     return TI.EnforceInteger(NodeToApply->getExtType(ResNo));    case SDTCisFP:      // Require it to be one of the legal fp VTs. -    return NodeToApply->getExtType(ResNo).EnforceFloatingPoint(TP); +    return TI.EnforceFloatingPoint(NodeToApply->getExtType(ResNo));    case SDTCisVec:      // Require it to be one of the legal vector VTs. -    return NodeToApply->getExtType(ResNo).EnforceVector(TP); +    return TI.EnforceVector(NodeToApply->getExtType(ResNo));    case SDTCisSameAs: {      unsigned OResNo = 0;      TreePatternNode *OtherNode = @@ -1032,36 +1438,35 @@ bool SDTypeConstraint::ApplyTypeConstraint(TreePatternNode *N,        TP.error(N->getOperator()->getName() + " expects a VT operand!");        return false;      } -    MVT::SimpleValueType VT = -     getValueType(static_cast<DefInit*>(NodeToApply->getLeafValue())->getDef()); - -    EEVT::TypeSet TypeListTmp(VT, TP); +    DefInit *DI = static_cast<DefInit*>(NodeToApply->getLeafValue()); +    const CodeGenTarget &T = TP.getDAGPatterns().getTargetInfo(); +    auto VVT = getValueTypeByHwMode(DI->getDef(), T.getHwModes()); +    TypeSetByHwMode TypeListTmp(VVT);      unsigned OResNo = 0;      TreePatternNode *OtherNode =        getOperandNum(x.SDTCisVTSmallerThanOp_Info.OtherOperandNum, N, NodeInfo,                      OResNo); -    return TypeListTmp.EnforceSmallerThan(OtherNode->getExtType(OResNo), TP); +    return TI.EnforceSmallerThan(TypeListTmp, OtherNode->getExtType(OResNo));    }    case SDTCisOpSmallerThanOp: {      unsigned BResNo = 0;      TreePatternNode *BigOperand =        getOperandNum(x.SDTCisOpSmallerThanOp_Info.BigOperandNum, N, NodeInfo,                      BResNo); -    return NodeToApply->getExtType(ResNo). -                  EnforceSmallerThan(BigOperand->getExtType(BResNo), TP); +    return TI.EnforceSmallerThan(NodeToApply->getExtType(ResNo), +                                 BigOperand->getExtType(BResNo));    }    case SDTCisEltOfVec: {      unsigned VResNo = 0;      TreePatternNode *VecOperand =        getOperandNum(x.SDTCisEltOfVec_Info.OtherOperandNum, N, NodeInfo,                      VResNo); -      // Filter vector types out of VecOperand that don't have the right element      // type. -    return VecOperand->getExtType(VResNo). -      EnforceVectorEltTypeIs(NodeToApply->getExtType(ResNo), TP); +    return TI.EnforceVectorEltTypeIs(VecOperand->getExtType(VResNo), +                                     NodeToApply->getExtType(ResNo));    }    case SDTCisSubVecOfVec: {      unsigned VResNo = 0; @@ -1071,28 +1476,27 @@ bool SDTypeConstraint::ApplyTypeConstraint(TreePatternNode *N,      // Filter vector types out of BigVecOperand that don't have the      // right subvector type. -    return BigVecOperand->getExtType(VResNo). -      EnforceVectorSubVectorTypeIs(NodeToApply->getExtType(ResNo), TP); +    return TI.EnforceVectorSubVectorTypeIs(BigVecOperand->getExtType(VResNo), +                                           NodeToApply->getExtType(ResNo));    }    case SDTCVecEltisVT: { -    return NodeToApply->getExtType(ResNo). -      EnforceVectorEltTypeIs(x.SDTCVecEltisVT_Info.VT, TP); +    return TI.EnforceVectorEltTypeIs(NodeToApply->getExtType(ResNo), VVT);    }    case SDTCisSameNumEltsAs: {      unsigned OResNo = 0;      TreePatternNode *OtherNode =        getOperandNum(x.SDTCisSameNumEltsAs_Info.OtherOperandNum,                      N, NodeInfo, OResNo); -    return OtherNode->getExtType(OResNo). -      EnforceSameNumElts(NodeToApply->getExtType(ResNo), TP); +    return TI.EnforceSameNumElts(OtherNode->getExtType(OResNo), +                                 NodeToApply->getExtType(ResNo));    }    case SDTCisSameSizeAs: {      unsigned OResNo = 0;      TreePatternNode *OtherNode =        getOperandNum(x.SDTCisSameSizeAs_Info.OtherOperandNum,                      N, NodeInfo, OResNo); -    return OtherNode->getExtType(OResNo). -      EnforceSameSize(NodeToApply->getExtType(ResNo), TP); +    return TI.EnforceSameSize(OtherNode->getExtType(OResNo), +                              NodeToApply->getExtType(ResNo));    }    }    llvm_unreachable("Invalid ConstraintType!"); @@ -1110,9 +1514,11 @@ bool TreePatternNode::UpdateNodeTypeFromInst(unsigned ResNo,      return false;    // The Operand class specifies a type directly. -  if (Operand->isSubClassOf("Operand")) -    return UpdateNodeType(ResNo, getValueType(Operand->getValueAsDef("Type")), -                          TP); +  if (Operand->isSubClassOf("Operand")) { +    Record *R = Operand->getValueAsDef("Type"); +    const CodeGenTarget &T = TP.getDAGPatterns().getTargetInfo(); +    return UpdateNodeType(ResNo, getValueTypeByHwMode(R, T.getHwModes()), TP); +  }    // PointerLikeRegClass has a type that is determined at runtime.    if (Operand->isSubClassOf("PointerLikeRegClass")) @@ -1131,11 +1537,53 @@ bool TreePatternNode::UpdateNodeTypeFromInst(unsigned ResNo,    return UpdateNodeType(ResNo, Tgt.getRegisterClass(RC).getValueTypes(), TP);  } +bool TreePatternNode::ContainsUnresolvedType(TreePattern &TP) const { +  for (unsigned i = 0, e = Types.size(); i != e; ++i) +    if (!TP.getInfer().isConcrete(Types[i], true)) +      return true; +  for (unsigned i = 0, e = getNumChildren(); i != e; ++i) +    if (getChild(i)->ContainsUnresolvedType(TP)) +      return true; +  return false; +} + +bool TreePatternNode::hasProperTypeByHwMode() const { +  for (const TypeSetByHwMode &S : Types) +    if (!S.isDefaultOnly()) +      return true; +  for (TreePatternNode *C : Children) +    if (C->hasProperTypeByHwMode()) +      return true; +  return false; +} + +bool TreePatternNode::hasPossibleType() const { +  for (const TypeSetByHwMode &S : Types) +    if (!S.isPossible()) +      return false; +  for (TreePatternNode *C : Children) +    if (!C->hasPossibleType()) +      return false; +  return true; +} + +bool TreePatternNode::setDefaultMode(unsigned Mode) { +  for (TypeSetByHwMode &S : Types) { +    S.makeSimple(Mode); +    // Check if the selected mode had a type conflict. +    if (S.get(DefaultMode).empty()) +      return false; +  } +  for (TreePatternNode *C : Children) +    if (!C->setDefaultMode(Mode)) +      return false; +  return true; +}  //===----------------------------------------------------------------------===//  // SDNodeInfo implementation  // -SDNodeInfo::SDNodeInfo(Record *R) : Def(R) { +SDNodeInfo::SDNodeInfo(Record *R, const CodeGenHwModes &CGH) : Def(R) {    EnumName    = R->getValueAsString("Opcode");    SDClassName = R->getValueAsString("SDClass");    Record *TypeProfile = R->getValueAsDef("TypeProfile"); @@ -1178,7 +1626,8 @@ SDNodeInfo::SDNodeInfo(Record *R) : Def(R) {    // Parse the type constraints.    std::vector<Record*> ConstraintList =      TypeProfile->getValueAsListOfDefs("Constraints"); -  TypeConstraints.assign(ConstraintList.begin(), ConstraintList.end()); +  for (Record *R : ConstraintList) +    TypeConstraints.emplace_back(R, CGH);  }  /// getKnownType - If the type constraints on this node imply a fixed type @@ -1198,7 +1647,9 @@ MVT::SimpleValueType SDNodeInfo::getKnownType(unsigned ResNo) const {      switch (Constraint.ConstraintType) {      default: break;      case SDTypeConstraint::SDTCisVT: -      return Constraint.x.SDTCisVT_Info.VT; +      if (Constraint.VVT.isSimple()) +        return Constraint.VVT.getSimple().SimpleTy; +      break;      case SDTypeConstraint::SDTCisPtrTy:        return MVT::iPTR;      } @@ -1284,8 +1735,10 @@ void TreePatternNode::print(raw_ostream &OS) const {    else      OS << '(' << getOperator()->getName(); -  for (unsigned i = 0, e = Types.size(); i != e; ++i) -    OS << ':' << getExtType(i).getName(); +  for (unsigned i = 0, e = Types.size(); i != e; ++i) { +    OS << ':'; +    getExtType(i).writeToStream(OS); +  }    if (!isLeaf()) {      if (getNumChildren() != 0) { @@ -1368,7 +1821,7 @@ TreePatternNode *TreePatternNode::clone() const {  /// RemoveAllTypes - Recursively strip all the types of this tree.  void TreePatternNode::RemoveAllTypes() {    // Reset to unknown type. -  std::fill(Types.begin(), Types.end(), EEVT::TypeSet()); +  std::fill(Types.begin(), Types.end(), TypeSetByHwMode());    if (isLeaf()) return;    for (unsigned i = 0, e = getNumChildren(); i != e; ++i)      getChild(i)->RemoveAllTypes(); @@ -1485,18 +1938,20 @@ TreePatternNode *TreePatternNode::InlinePatternFragments(TreePattern &TP) {  /// When Unnamed is false, return the type of a named DAG operand such as the  /// GPR:$src operand above.  /// -static EEVT::TypeSet getImplicitType(Record *R, unsigned ResNo, -                                     bool NotRegisters, -                                     bool Unnamed, -                                     TreePattern &TP) { +static TypeSetByHwMode getImplicitType(Record *R, unsigned ResNo, +                                       bool NotRegisters, +                                       bool Unnamed, +                                       TreePattern &TP) { +  CodeGenDAGPatterns &CDP = TP.getDAGPatterns(); +    // Check to see if this is a register operand.    if (R->isSubClassOf("RegisterOperand")) {      assert(ResNo == 0 && "Regoperand ref only has one result!");      if (NotRegisters) -      return EEVT::TypeSet(); // Unknown. +      return TypeSetByHwMode(); // Unknown.      Record *RegClass = R->getValueAsDef("RegClass");      const CodeGenTarget &T = TP.getDAGPatterns().getTargetInfo(); -    return EEVT::TypeSet(T.getRegisterClass(RegClass).getValueTypes()); +    return TypeSetByHwMode(T.getRegisterClass(RegClass).getValueTypes());    }    // Check to see if this is a register or a register class. @@ -1505,33 +1960,33 @@ static EEVT::TypeSet getImplicitType(Record *R, unsigned ResNo,      // An unnamed register class represents itself as an i32 immediate, for      // example on a COPY_TO_REGCLASS instruction.      if (Unnamed) -      return EEVT::TypeSet(MVT::i32, TP); +      return TypeSetByHwMode(MVT::i32);      // In a named operand, the register class provides the possible set of      // types.      if (NotRegisters) -      return EEVT::TypeSet(); // Unknown. +      return TypeSetByHwMode(); // Unknown.      const CodeGenTarget &T = TP.getDAGPatterns().getTargetInfo(); -    return EEVT::TypeSet(T.getRegisterClass(R).getValueTypes()); +    return TypeSetByHwMode(T.getRegisterClass(R).getValueTypes());    }    if (R->isSubClassOf("PatFrag")) {      assert(ResNo == 0 && "FIXME: PatFrag with multiple results?");      // Pattern fragment types will be resolved when they are inlined. -    return EEVT::TypeSet(); // Unknown. +    return TypeSetByHwMode(); // Unknown.    }    if (R->isSubClassOf("Register")) {      assert(ResNo == 0 && "Registers only produce one result!");      if (NotRegisters) -      return EEVT::TypeSet(); // Unknown. +      return TypeSetByHwMode(); // Unknown.      const CodeGenTarget &T = TP.getDAGPatterns().getTargetInfo(); -    return EEVT::TypeSet(T.getRegisterVTs(R)); +    return TypeSetByHwMode(T.getRegisterVTs(R));    }    if (R->isSubClassOf("SubRegIndex")) {      assert(ResNo == 0 && "SubRegisterIndices only produce one result!"); -    return EEVT::TypeSet(MVT::i32, TP); +    return TypeSetByHwMode(MVT::i32);    }    if (R->isSubClassOf("ValueType")) { @@ -1541,46 +1996,51 @@ static EEVT::TypeSet getImplicitType(Record *R, unsigned ResNo,      //   (sext_inreg GPR:$src, i16)      //                         ~~~      if (Unnamed) -      return EEVT::TypeSet(MVT::Other, TP); +      return TypeSetByHwMode(MVT::Other);      // With a name, the ValueType simply provides the type of the named      // variable.      //      //   (sext_inreg i32:$src, i16)      //               ~~~~~~~~      if (NotRegisters) -      return EEVT::TypeSet(); // Unknown. -    return EEVT::TypeSet(getValueType(R), TP); +      return TypeSetByHwMode(); // Unknown. +    const CodeGenHwModes &CGH = CDP.getTargetInfo().getHwModes(); +    return TypeSetByHwMode(getValueTypeByHwMode(R, CGH));    }    if (R->isSubClassOf("CondCode")) {      assert(ResNo == 0 && "This node only has one result!");      // Using a CondCodeSDNode. -    return EEVT::TypeSet(MVT::Other, TP); +    return TypeSetByHwMode(MVT::Other);    }    if (R->isSubClassOf("ComplexPattern")) {      assert(ResNo == 0 && "FIXME: ComplexPattern with multiple results?");      if (NotRegisters) -      return EEVT::TypeSet(); // Unknown. -   return EEVT::TypeSet(TP.getDAGPatterns().getComplexPattern(R).getValueType(), -                         TP); +      return TypeSetByHwMode(); // Unknown. +    return TypeSetByHwMode(CDP.getComplexPattern(R).getValueType());    }    if (R->isSubClassOf("PointerLikeRegClass")) {      assert(ResNo == 0 && "Regclass can only have one result!"); -    return EEVT::TypeSet(MVT::iPTR, TP); +    TypeSetByHwMode VTS(MVT::iPTR); +    TP.getInfer().expandOverloads(VTS); +    return VTS;    }    if (R->getName() == "node" || R->getName() == "srcvalue" ||        R->getName() == "zero_reg") {      // Placeholder. -    return EEVT::TypeSet(); // Unknown. +    return TypeSetByHwMode(); // Unknown.    } -  if (R->isSubClassOf("Operand")) -    return EEVT::TypeSet(getValueType(R->getValueAsDef("Type"))); +  if (R->isSubClassOf("Operand")) { +    const CodeGenHwModes &CGH = CDP.getTargetInfo().getHwModes(); +    Record *T = R->getValueAsDef("Type"); +    return TypeSetByHwMode(getValueTypeByHwMode(T, CGH)); +  }    TP.error("Unknown node flavor used in pattern: " + R->getName()); -  return EEVT::TypeSet(MVT::Other, TP); +  return TypeSetByHwMode(MVT::Other);  } @@ -1722,29 +2182,34 @@ bool TreePatternNode::ApplyTypeConstraints(TreePattern &TP, bool NotRegisters) {        assert(Types.size() == 1 && "Invalid IntInit");        // Int inits are always integers. :) -      bool MadeChange = Types[0].EnforceInteger(TP); - -      if (!Types[0].isConcrete()) -        return MadeChange; +      bool MadeChange = TP.getInfer().EnforceInteger(Types[0]); -      MVT::SimpleValueType VT = getType(0); -      if (VT == MVT::iPTR || VT == MVT::iPTRAny) +      if (!TP.getInfer().isConcrete(Types[0], false))          return MadeChange; -      unsigned Size = MVT(VT).getSizeInBits(); -      // Make sure that the value is representable for this type. -      if (Size >= 32) return MadeChange; - -      // Check that the value doesn't use more bits than we have. It must either -      // be a sign- or zero-extended equivalent of the original. -      int64_t SignBitAndAbove = II->getValue() >> (Size - 1); -      if (SignBitAndAbove == -1 || SignBitAndAbove == 0 || SignBitAndAbove == 1) -        return MadeChange; +      ValueTypeByHwMode VVT = TP.getInfer().getConcrete(Types[0], false); +      for (auto &P : VVT) { +        MVT::SimpleValueType VT = P.second.SimpleTy; +        if (VT == MVT::iPTR || VT == MVT::iPTRAny) +          continue; +        unsigned Size = MVT(VT).getSizeInBits(); +        // Make sure that the value is representable for this type. +        if (Size >= 32) +          continue; +        // Check that the value doesn't use more bits than we have. It must +        // either be a sign- or zero-extended equivalent of the original. +        int64_t SignBitAndAbove = II->getValue() >> (Size - 1); +        if (SignBitAndAbove == -1 || SignBitAndAbove == 0 || +            SignBitAndAbove == 1) +          continue; -      TP.error("Integer value '" + itostr(II->getValue()) + -               "' is out of range for type '" + getEnumName(getType(0)) + "'!"); -      return false; +        TP.error("Integer value '" + itostr(II->getValue()) + +                 "' is out of range for type '" + getEnumName(VT) + "'!"); +        break; +      } +      return MadeChange;      } +      return false;    } @@ -1773,7 +2238,7 @@ bool TreePatternNode::ApplyTypeConstraints(TreePattern &TP, bool NotRegisters) {      bool MadeChange = false;      for (unsigned i = 0; i < getNumChildren(); ++i) -      MadeChange = getChild(i)->ApplyTypeConstraints(TP, NotRegisters); +      MadeChange |= getChild(i)->ApplyTypeConstraints(TP, NotRegisters);      return MadeChange;    } @@ -1818,9 +2283,10 @@ bool TreePatternNode::ApplyTypeConstraints(TreePattern &TP, bool NotRegisters) {        return false;      } -    bool MadeChange = NI.ApplyTypeConstraints(this, TP); +    bool MadeChange = false;      for (unsigned i = 0, e = getNumChildren(); i != e; ++i)        MadeChange |= getChild(i)->ApplyTypeConstraints(TP, NotRegisters); +    MadeChange |= NI.ApplyTypeConstraints(this, TP);      return MadeChange;    } @@ -1975,18 +2441,6 @@ bool TreePatternNode::ApplyTypeConstraints(TreePattern &TP, bool NotRegisters) {    }    bool MadeChange = getChild(0)->ApplyTypeConstraints(TP, NotRegisters); - - -  // If either the output or input of the xform does not have exact -  // type info. We assume they must be the same. Otherwise, it is perfectly -  // legal to transform from one type to a completely different type. -#if 0 -  if (!hasTypeSet() || !getChild(0)->hasTypeSet()) { -    bool MadeChange = UpdateNodeType(getChild(0)->getExtType(), TP); -    MadeChange |= getChild(0)->UpdateNodeType(getExtType(), TP); -    return MadeChange; -  } -#endif    return MadeChange;  } @@ -2050,20 +2504,23 @@ bool TreePatternNode::canPatternMatch(std::string &Reason,  TreePattern::TreePattern(Record *TheRec, ListInit *RawPat, bool isInput,                           CodeGenDAGPatterns &cdp) : TheRecord(TheRec), CDP(cdp), -                         isInputPattern(isInput), HasError(false) { +                         isInputPattern(isInput), HasError(false), +                         Infer(*this) {    for (Init *I : RawPat->getValues())      Trees.push_back(ParseTreePattern(I, ""));  }  TreePattern::TreePattern(Record *TheRec, DagInit *Pat, bool isInput,                           CodeGenDAGPatterns &cdp) : TheRecord(TheRec), CDP(cdp), -                         isInputPattern(isInput), HasError(false) { +                         isInputPattern(isInput), HasError(false), +                         Infer(*this) {    Trees.push_back(ParseTreePattern(Pat, ""));  }  TreePattern::TreePattern(Record *TheRec, TreePatternNode *Pat, bool isInput,                           CodeGenDAGPatterns &cdp) : TheRecord(TheRec), CDP(cdp), -                         isInputPattern(isInput), HasError(false) { +                         isInputPattern(isInput), HasError(false), +                         Infer(*this) {    Trees.push_back(Pat);  } @@ -2158,7 +2615,8 @@ TreePatternNode *TreePattern::ParseTreePattern(Init *TheInit, StringRef OpName){      // Apply the type cast.      assert(New->getNumTypes() == 1 && "FIXME: Unhandled"); -    New->UpdateNodeType(0, getValueType(Operator), *this); +    const CodeGenHwModes &CGH = getDAGPatterns().getTargetInfo().getHwModes(); +    New->UpdateNodeType(0, getValueTypeByHwMode(Operator, CGH), *this);      if (!OpName.empty())        error("ValueType cast should not have a name!"); @@ -2273,7 +2731,7 @@ static bool SimplifyTree(TreePatternNode *&N) {    // If we have a bitconvert with a resolved type and if the source and    // destination types are the same, then the bitconvert is useless, remove it.    if (N->getOperator()->getName() == "bitconvert" && -      N->getExtType(0).isConcrete() && +      N->getExtType(0).isValueTypeByHwMode(false) &&        N->getExtType(0) == N->getChild(0)->getExtType(0) &&        N->getName().empty()) {      N = N->getChild(0); @@ -2304,7 +2762,7 @@ InferAllTypes(const StringMap<SmallVector<TreePatternNode*,1> > *InNamedTypes) {    bool MadeChange = true;    while (MadeChange) {      MadeChange = false; -    for (TreePatternNode *Tree : Trees) { +    for (TreePatternNode *&Tree : Trees) {        MadeChange |= Tree->ApplyTypeConstraints(*this, false);        MadeChange |= SimplifyTree(Tree);      } @@ -2364,7 +2822,7 @@ InferAllTypes(const StringMap<SmallVector<TreePatternNode*,1> > *InNamedTypes) {    bool HasUnresolvedTypes = false;    for (const TreePatternNode *Tree : Trees) -    HasUnresolvedTypes |= Tree->ContainsUnresolvedType(); +    HasUnresolvedTypes |= Tree->ContainsUnresolvedType(*this);    return !HasUnresolvedTypes;  } @@ -2396,8 +2854,10 @@ void TreePattern::dump() const { print(errs()); }  // CodeGenDAGPatterns implementation  // -CodeGenDAGPatterns::CodeGenDAGPatterns(RecordKeeper &R) : -  Records(R), Target(R) { +CodeGenDAGPatterns::CodeGenDAGPatterns(RecordKeeper &R, +                                       PatternRewriterFn PatternRewriter) +    : Records(R), Target(R), LegalVTS(Target.getLegalValueTypes()), +      PatternRewriter(PatternRewriter) {    Intrinsics = CodeGenIntrinsicTable(Records, false);    TgtIntrinsics = CodeGenIntrinsicTable(Records, true); @@ -2410,6 +2870,11 @@ CodeGenDAGPatterns::CodeGenDAGPatterns(RecordKeeper &R) :    ParsePatternFragments(/*OutFrags*/true);    ParsePatterns(); +  // Break patterns with parameterized types into a series of patterns, +  // where each one has a fixed type and is predicated on the conditions +  // of the associated HW mode. +  ExpandHwModeBasedTypes(); +    // Generate variants.  For example, commutative patterns can match    // multiple ways.  Add them to PatternsToMatch as well.    GenerateVariants(); @@ -2434,8 +2899,11 @@ Record *CodeGenDAGPatterns::getSDNodeNamed(const std::string &Name) const {  // Parse all of the SDNode definitions for the target, populating SDNodes.  void CodeGenDAGPatterns::ParseNodeInfo() {    std::vector<Record*> Nodes = Records.getAllDerivedDefinitions("SDNode"); +  const CodeGenHwModes &CGH = getTargetInfo().getHwModes(); +    while (!Nodes.empty()) { -    SDNodes.insert(std::make_pair(Nodes.back(), Nodes.back())); +    Record *R = Nodes.back(); +    SDNodes.insert(std::make_pair(R, SDNodeInfo(R, CGH)));      Nodes.pop_back();    } @@ -2489,7 +2957,10 @@ void CodeGenDAGPatterns::ParsePatternFragments(bool OutFrags) {      // Validate the argument list, converting it to set, to discard duplicates.      std::vector<std::string> &Args = P->getArgList(); -    std::set<std::string> OperandsSet(Args.begin(), Args.end()); +    // Copy the args so we can take StringRefs to them. +    auto ArgsCopy = Args; +    SmallDenseSet<StringRef, 4> OperandsSet; +    OperandsSet.insert(ArgsCopy.begin(), ArgsCopy.end());      if (OperandsSet.count(""))        P->error("Cannot have unnamed 'node' values in pattern fragment!"); @@ -2589,7 +3060,7 @@ void CodeGenDAGPatterns::ParseDefaultOperands() {        while (TPN->ApplyTypeConstraints(P, false))          /* Resolve all types */; -      if (TPN->ContainsUnresolvedType()) { +      if (TPN->ContainsUnresolvedType(P)) {          PrintFatalError("Value #" + Twine(i) + " of OperandWithDefaultOps '" +                          DefaultOps[i]->getName() +                          "' doesn't have a concrete type!"); @@ -2981,17 +3452,20 @@ const DAGInstruction &CodeGenDAGPatterns::parseInstructionPattern(    // Verify that the top-level forms in the instruction are of void type, and    // fill in the InstResults map. +  SmallString<32> TypesString;    for (unsigned j = 0, e = I->getNumTrees(); j != e; ++j) { +    TypesString.clear();      TreePatternNode *Pat = I->getTree(j);      if (Pat->getNumTypes() != 0) { -      std::string Types; +      raw_svector_ostream OS(TypesString);        for (unsigned k = 0, ke = Pat->getNumTypes(); k != ke; ++k) {          if (k > 0) -          Types += ", "; -        Types += Pat->getExtType(k).getName(); +          OS << ", "; +        Pat->getExtType(k).writeToStream(OS);        }        I->error("Top-level forms in instruction pattern should have" -               " void types, has types " + Types); +               " void types, has types " + +               OS.str());      }      // Find inputs and outputs, and verify the structure of the uses/defs. @@ -3174,6 +3648,8 @@ void CodeGenDAGPatterns::ParseInstructions() {      TreePattern *I = TheInst.getPattern();      if (!I) continue;  // No pattern. +    if (PatternRewriter) +      PatternRewriter(I);      // FIXME: Assume only the first tree is the pattern. The others are clobber      // nodes.      TreePatternNode *Pattern = I->getTree(0); @@ -3186,14 +3662,13 @@ void CodeGenDAGPatterns::ParseInstructions() {      }      Record *Instr = Entry.first; -    AddPatternToMatch(I, -                      PatternToMatch(Instr, -                                     Instr->getValueAsListInit("Predicates"), -                                     SrcPattern, -                                     TheInst.getResultPattern(), -                                     TheInst.getImpResults(), -                                     Instr->getValueAsInt("AddedComplexity"), -                                     Instr->getID())); +    ListInit *Preds = Instr->getValueAsListInit("Predicates"); +    int Complexity = Instr->getValueAsInt("AddedComplexity"); +    AddPatternToMatch( +        I, +        PatternToMatch(Instr, makePredList(Preds), SrcPattern, +                       TheInst.getResultPattern(), TheInst.getImpResults(), +                       Complexity, Instr->getID()));    }  } @@ -3219,6 +3694,20 @@ static void FindNames(const TreePatternNode *P,    }  } +std::vector<Predicate> CodeGenDAGPatterns::makePredList(ListInit *L) { +  std::vector<Predicate> Preds; +  for (Init *I : L->getValues()) { +    if (DefInit *Pred = dyn_cast<DefInit>(I)) +      Preds.push_back(Pred->getDef()); +    else +      llvm_unreachable("Non-def on the list"); +  } + +  // Sort so that different orders get canonicalized to the same string. +  std::sort(Preds.begin(), Preds.end()); +  return Preds; +} +  void CodeGenDAGPatterns::AddPatternToMatch(TreePattern *Pattern,                                             PatternToMatch &&PTM) {    // Do some sanity checking on the pattern we're about to match. @@ -3262,8 +3751,6 @@ void CodeGenDAGPatterns::AddPatternToMatch(TreePattern *Pattern,    PatternsToMatch.push_back(std::move(PTM));  } - -  void CodeGenDAGPatterns::InferInstructionFlags() {    ArrayRef<const CodeGenInstruction*> Instructions =      Target.getInstructionsByEnumValue(); @@ -3425,12 +3912,13 @@ static bool ForceArbitraryInstResultType(TreePatternNode *N, TreePattern &TP) {    // If this type is already concrete or completely unknown we can't do    // anything. +  TypeInfer &TI = TP.getInfer();    for (unsigned i = 0, e = N->getNumTypes(); i != e; ++i) { -    if (N->getExtType(i).isCompletelyUnknown() || N->getExtType(i).isConcrete()) +    if (N->getExtType(i).empty() || TI.isConcrete(N->getExtType(i), false))        continue; -    // Otherwise, force its type to the first possibility (an arbitrary choice). -    if (N->getExtType(i).MergeInTypeInfo(N->getExtType(i).getTypeList()[0], TP)) +    // Otherwise, force its type to an arbitrary choice. +    if (TI.forceArbitrary(N->getExtType(i)))        return true;    } @@ -3551,15 +4039,156 @@ void CodeGenDAGPatterns::ParsePatterns() {      TreePattern Temp(Result.getRecord(), DstPattern, false, *this);      Temp.InferAllTypes(); -    AddPatternToMatch( -        Pattern, -        PatternToMatch( -            CurPattern, CurPattern->getValueAsListInit("Predicates"), -            Pattern->getTree(0), Temp.getOnlyTree(), std::move(InstImpResults), -            CurPattern->getValueAsInt("AddedComplexity"), CurPattern->getID())); +    // A pattern may end up with an "impossible" type, i.e. a situation +    // where all types have been eliminated for some node in this pattern. +    // This could occur for intrinsics that only make sense for a specific +    // value type, and use a specific register class. If, for some mode, +    // that register class does not accept that type, the type inference +    // will lead to a contradiction, which is not an error however, but +    // a sign that this pattern will simply never match. +    if (Pattern->getTree(0)->hasPossibleType() && +        Temp.getOnlyTree()->hasPossibleType()) { +      ListInit *Preds = CurPattern->getValueAsListInit("Predicates"); +      int Complexity = CurPattern->getValueAsInt("AddedComplexity"); +      if (PatternRewriter) +        PatternRewriter(Pattern); +      AddPatternToMatch( +          Pattern, +          PatternToMatch( +              CurPattern, makePredList(Preds), Pattern->getTree(0), +              Temp.getOnlyTree(), std::move(InstImpResults), Complexity, +              CurPattern->getID())); +    }    }  } +static void collectModes(std::set<unsigned> &Modes, const TreePatternNode *N) { +  for (const TypeSetByHwMode &VTS : N->getExtTypes()) +    for (const auto &I : VTS) +      Modes.insert(I.first); + +  for (unsigned i = 0, e = N->getNumChildren(); i != e; ++i) +    collectModes(Modes, N->getChild(i)); +} + +void CodeGenDAGPatterns::ExpandHwModeBasedTypes() { +  const CodeGenHwModes &CGH = getTargetInfo().getHwModes(); +  std::map<unsigned,std::vector<Predicate>> ModeChecks; +  std::vector<PatternToMatch> Copy = PatternsToMatch; +  PatternsToMatch.clear(); + +  auto AppendPattern = [this,&ModeChecks](PatternToMatch &P, unsigned Mode) { +    TreePatternNode *NewSrc = P.SrcPattern->clone(); +    TreePatternNode *NewDst = P.DstPattern->clone(); +    if (!NewSrc->setDefaultMode(Mode) || !NewDst->setDefaultMode(Mode)) { +      delete NewSrc; +      delete NewDst; +      return; +    } + +    std::vector<Predicate> Preds = P.Predicates; +    const std::vector<Predicate> &MC = ModeChecks[Mode]; +    Preds.insert(Preds.end(), MC.begin(), MC.end()); +    PatternsToMatch.emplace_back(P.getSrcRecord(), Preds, NewSrc, NewDst, +                                 P.getDstRegs(), P.getAddedComplexity(), +                                 Record::getNewUID(), Mode); +  }; + +  for (PatternToMatch &P : Copy) { +    TreePatternNode *SrcP = nullptr, *DstP = nullptr; +    if (P.SrcPattern->hasProperTypeByHwMode()) +      SrcP = P.SrcPattern; +    if (P.DstPattern->hasProperTypeByHwMode()) +      DstP = P.DstPattern; +    if (!SrcP && !DstP) { +      PatternsToMatch.push_back(P); +      continue; +    } + +    std::set<unsigned> Modes; +    if (SrcP) +      collectModes(Modes, SrcP); +    if (DstP) +      collectModes(Modes, DstP); + +    // The predicate for the default mode needs to be constructed for each +    // pattern separately. +    // Since not all modes must be present in each pattern, if a mode m is +    // absent, then there is no point in constructing a check for m. If such +    // a check was created, it would be equivalent to checking the default +    // mode, except not all modes' predicates would be a part of the checking +    // code. The subsequently generated check for the default mode would then +    // have the exact same patterns, but a different predicate code. To avoid +    // duplicated patterns with different predicate checks, construct the +    // default check as a negation of all predicates that are actually present +    // in the source/destination patterns. +    std::vector<Predicate> DefaultPred; + +    for (unsigned M : Modes) { +      if (M == DefaultMode) +        continue; +      if (ModeChecks.find(M) != ModeChecks.end()) +        continue; + +      // Fill the map entry for this mode. +      const HwMode &HM = CGH.getMode(M); +      ModeChecks[M].emplace_back(Predicate(HM.Features, true)); + +      // Add negations of the HM's predicates to the default predicate. +      DefaultPred.emplace_back(Predicate(HM.Features, false)); +    } + +    for (unsigned M : Modes) { +      if (M == DefaultMode) +        continue; +      AppendPattern(P, M); +    } + +    bool HasDefault = Modes.count(DefaultMode); +    if (HasDefault) +      AppendPattern(P, DefaultMode); +  } +} + +/// Dependent variable map for CodeGenDAGPattern variant generation +typedef StringMap<int> DepVarMap; + +static void FindDepVarsOf(TreePatternNode *N, DepVarMap &DepMap) { +  if (N->isLeaf()) { +    if (N->hasName() && isa<DefInit>(N->getLeafValue())) +      DepMap[N->getName()]++; +  } else { +    for (size_t i = 0, e = N->getNumChildren(); i != e; ++i) +      FindDepVarsOf(N->getChild(i), DepMap); +  } +} + +/// Find dependent variables within child patterns +static void FindDepVars(TreePatternNode *N, MultipleUseVarSet &DepVars) { +  DepVarMap depcounts; +  FindDepVarsOf(N, depcounts); +  for (const auto &Pair : depcounts) { +    if (Pair.getValue() > 1) +      DepVars.insert(Pair.getKey()); +  } +} + +#ifndef NDEBUG +/// Dump the dependent variable set: +static void DumpDepVars(MultipleUseVarSet &DepVars) { +  if (DepVars.empty()) { +    DEBUG(errs() << "<empty set>"); +  } else { +    DEBUG(errs() << "[ "); +    for (const auto &DepVar : DepVars) { +      DEBUG(errs() << DepVar.getKey() << " "); +    } +    DEBUG(errs() << "]"); +  } +} +#endif + +  /// CombineChildVariants - Given a bunch of permutations of each child of the  /// 'operator' node, put them together in all possible ways.  static void CombineChildVariants(TreePatternNode *Orig, @@ -3744,7 +4373,7 @@ static void GenerateVariantsOf(TreePatternNode *N,    // If this node is commutative, consider the commuted order.    bool isCommIntrinsic = N->isCommutativeIntrinsic(CDP);    if (NodeInfo.hasProperty(SDNPCommutative) || isCommIntrinsic) { -    assert((N->getNumChildren()==2 || isCommIntrinsic) && +    assert((N->getNumChildren()>=2 || isCommIntrinsic) &&             "Commutative but doesn't have 2 children!");      // Don't count children which are actually register references.      unsigned NC = 0; @@ -3772,9 +4401,14 @@ static void GenerateVariantsOf(TreePatternNode *N,        for (unsigned i = 3; i != NC; ++i)          Variants.push_back(ChildVariants[i]);        CombineChildVariants(N, Variants, OutVariants, CDP, DepVars); -    } else if (NC == 2) -      CombineChildVariants(N, ChildVariants[1], ChildVariants[0], -                           OutVariants, CDP, DepVars); +    } else if (NC == N->getNumChildren()) { +      std::vector<std::vector<TreePatternNode*> > Variants; +      Variants.push_back(ChildVariants[1]); +      Variants.push_back(ChildVariants[0]); +      for (unsigned i = 2; i != NC; ++i) +        Variants.push_back(ChildVariants[i]); +      CombineChildVariants(N, Variants, OutVariants, CDP, DepVars); +    }    }  } | 
