diff options
Diffstat (limited to 'contrib/llvm-project/clang/lib/ASTMatchers')
9 files changed, 6867 insertions, 0 deletions
diff --git a/contrib/llvm-project/clang/lib/ASTMatchers/ASTMatchFinder.cpp b/contrib/llvm-project/clang/lib/ASTMatchers/ASTMatchFinder.cpp new file mode 100644 index 000000000000..0bac2ed63a92 --- /dev/null +++ b/contrib/llvm-project/clang/lib/ASTMatchers/ASTMatchFinder.cpp @@ -0,0 +1,1716 @@ +//===--- ASTMatchFinder.cpp - Structural query framework ------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Implements an algorithm to efficiently search for matches on AST nodes. +// Uses memoization to support recursive matches like HasDescendant. +// +// The general idea is to visit all AST nodes with a RecursiveASTVisitor, +// calling the Matches(...) method of each matcher we are running on each +// AST node. The matcher can recurse via the ASTMatchFinder interface. +// +//===----------------------------------------------------------------------===// + +#include "clang/ASTMatchers/ASTMatchFinder.h" +#include "clang/AST/ASTConsumer.h" +#include "clang/AST/ASTContext.h" +#include "clang/AST/DeclCXX.h" +#include "clang/AST/RecursiveASTVisitor.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/StringMap.h" +#include "llvm/Support/PrettyStackTrace.h" +#include "llvm/Support/Timer.h" +#include <deque> +#include <memory> +#include <set> + +namespace clang { +namespace ast_matchers { +namespace internal { +namespace { + +typedef MatchFinder::MatchCallback MatchCallback; + +// The maximum number of memoization entries to store. +// 10k has been experimentally found to give a good trade-off +// of performance vs. memory consumption by running matcher +// that match on every statement over a very large codebase. +// +// FIXME: Do some performance optimization in general and +// revisit this number; also, put up micro-benchmarks that we can +// optimize this on. +static const unsigned MaxMemoizationEntries = 10000; + +enum class MatchType { + Ancestors, + + Descendants, + Child, +}; + +// We use memoization to avoid running the same matcher on the same +// AST node twice. This struct is the key for looking up match +// result. It consists of an ID of the MatcherInterface (for +// identifying the matcher), a pointer to the AST node and the +// bound nodes before the matcher was executed. +// +// We currently only memoize on nodes whose pointers identify the +// nodes (\c Stmt and \c Decl, but not \c QualType or \c TypeLoc). +// For \c QualType and \c TypeLoc it is possible to implement +// generation of keys for each type. +// FIXME: Benchmark whether memoization of non-pointer typed nodes +// provides enough benefit for the additional amount of code. +struct MatchKey { + DynTypedMatcher::MatcherIDType MatcherID; + DynTypedNode Node; + BoundNodesTreeBuilder BoundNodes; + TraversalKind Traversal = TK_AsIs; + MatchType Type; + + bool operator<(const MatchKey &Other) const { + return std::tie(Traversal, Type, MatcherID, Node, BoundNodes) < + std::tie(Other.Traversal, Other.Type, Other.MatcherID, Other.Node, + Other.BoundNodes); + } +}; + +// Used to store the result of a match and possibly bound nodes. +struct MemoizedMatchResult { + bool ResultOfMatch; + BoundNodesTreeBuilder Nodes; +}; + +// A RecursiveASTVisitor that traverses all children or all descendants of +// a node. +class MatchChildASTVisitor + : public RecursiveASTVisitor<MatchChildASTVisitor> { +public: + typedef RecursiveASTVisitor<MatchChildASTVisitor> VisitorBase; + + // Creates an AST visitor that matches 'matcher' on all children or + // descendants of a traversed node. max_depth is the maximum depth + // to traverse: use 1 for matching the children and INT_MAX for + // matching the descendants. + MatchChildASTVisitor(const DynTypedMatcher *Matcher, ASTMatchFinder *Finder, + BoundNodesTreeBuilder *Builder, int MaxDepth, + bool IgnoreImplicitChildren, + ASTMatchFinder::BindKind Bind) + : Matcher(Matcher), Finder(Finder), Builder(Builder), CurrentDepth(0), + MaxDepth(MaxDepth), IgnoreImplicitChildren(IgnoreImplicitChildren), + Bind(Bind), Matches(false) {} + + // Returns true if a match is found in the subtree rooted at the + // given AST node. This is done via a set of mutually recursive + // functions. Here's how the recursion is done (the *wildcard can + // actually be Decl, Stmt, or Type): + // + // - Traverse(node) calls BaseTraverse(node) when it needs + // to visit the descendants of node. + // - BaseTraverse(node) then calls (via VisitorBase::Traverse*(node)) + // Traverse*(c) for each child c of 'node'. + // - Traverse*(c) in turn calls Traverse(c), completing the + // recursion. + bool findMatch(const DynTypedNode &DynNode) { + reset(); + if (const Decl *D = DynNode.get<Decl>()) + traverse(*D); + else if (const Stmt *S = DynNode.get<Stmt>()) + traverse(*S); + else if (const NestedNameSpecifier *NNS = + DynNode.get<NestedNameSpecifier>()) + traverse(*NNS); + else if (const NestedNameSpecifierLoc *NNSLoc = + DynNode.get<NestedNameSpecifierLoc>()) + traverse(*NNSLoc); + else if (const QualType *Q = DynNode.get<QualType>()) + traverse(*Q); + else if (const TypeLoc *T = DynNode.get<TypeLoc>()) + traverse(*T); + else if (const auto *C = DynNode.get<CXXCtorInitializer>()) + traverse(*C); + else if (const TemplateArgumentLoc *TALoc = + DynNode.get<TemplateArgumentLoc>()) + traverse(*TALoc); + else if (const Attr *A = DynNode.get<Attr>()) + traverse(*A); + // FIXME: Add other base types after adding tests. + + // It's OK to always overwrite the bound nodes, as if there was + // no match in this recursive branch, the result set is empty + // anyway. + *Builder = ResultBindings; + + return Matches; + } + + // The following are overriding methods from the base visitor class. + // They are public only to allow CRTP to work. They are *not *part + // of the public API of this class. + bool TraverseDecl(Decl *DeclNode) { + + if (DeclNode && DeclNode->isImplicit() && + Finder->isTraversalIgnoringImplicitNodes()) + return baseTraverse(*DeclNode); + + ScopedIncrement ScopedDepth(&CurrentDepth); + return (DeclNode == nullptr) || traverse(*DeclNode); + } + + Stmt *getStmtToTraverse(Stmt *StmtNode) { + Stmt *StmtToTraverse = StmtNode; + if (auto *ExprNode = dyn_cast_or_null<Expr>(StmtNode)) { + auto *LambdaNode = dyn_cast_or_null<LambdaExpr>(StmtNode); + if (LambdaNode && Finder->isTraversalIgnoringImplicitNodes()) + StmtToTraverse = LambdaNode; + else + StmtToTraverse = + Finder->getASTContext().getParentMapContext().traverseIgnored( + ExprNode); + } + return StmtToTraverse; + } + + bool TraverseStmt(Stmt *StmtNode, DataRecursionQueue *Queue = nullptr) { + // If we need to keep track of the depth, we can't perform data recursion. + if (CurrentDepth == 0 || (CurrentDepth <= MaxDepth && MaxDepth < INT_MAX)) + Queue = nullptr; + + ScopedIncrement ScopedDepth(&CurrentDepth); + Stmt *StmtToTraverse = getStmtToTraverse(StmtNode); + if (!StmtToTraverse) + return true; + + if (IgnoreImplicitChildren && isa<CXXDefaultArgExpr>(StmtNode)) + return true; + + if (!match(*StmtToTraverse)) + return false; + return VisitorBase::TraverseStmt(StmtToTraverse, Queue); + } + // We assume that the QualType and the contained type are on the same + // hierarchy level. Thus, we try to match either of them. + bool TraverseType(QualType TypeNode) { + if (TypeNode.isNull()) + return true; + ScopedIncrement ScopedDepth(&CurrentDepth); + // Match the Type. + if (!match(*TypeNode)) + return false; + // The QualType is matched inside traverse. + return traverse(TypeNode); + } + // We assume that the TypeLoc, contained QualType and contained Type all are + // on the same hierarchy level. Thus, we try to match all of them. + bool TraverseTypeLoc(TypeLoc TypeLocNode) { + if (TypeLocNode.isNull()) + return true; + ScopedIncrement ScopedDepth(&CurrentDepth); + // Match the Type. + if (!match(*TypeLocNode.getType())) + return false; + // Match the QualType. + if (!match(TypeLocNode.getType())) + return false; + // The TypeLoc is matched inside traverse. + return traverse(TypeLocNode); + } + bool TraverseNestedNameSpecifier(NestedNameSpecifier *NNS) { + ScopedIncrement ScopedDepth(&CurrentDepth); + return (NNS == nullptr) || traverse(*NNS); + } + bool TraverseNestedNameSpecifierLoc(NestedNameSpecifierLoc NNS) { + if (!NNS) + return true; + ScopedIncrement ScopedDepth(&CurrentDepth); + if (!match(*NNS.getNestedNameSpecifier())) + return false; + return traverse(NNS); + } + bool TraverseConstructorInitializer(CXXCtorInitializer *CtorInit) { + if (!CtorInit) + return true; + ScopedIncrement ScopedDepth(&CurrentDepth); + return traverse(*CtorInit); + } + bool TraverseTemplateArgumentLoc(TemplateArgumentLoc TAL) { + ScopedIncrement ScopedDepth(&CurrentDepth); + return traverse(TAL); + } + bool TraverseCXXForRangeStmt(CXXForRangeStmt *Node) { + if (!Finder->isTraversalIgnoringImplicitNodes()) + return VisitorBase::TraverseCXXForRangeStmt(Node); + if (!Node) + return true; + ScopedIncrement ScopedDepth(&CurrentDepth); + if (auto *Init = Node->getInit()) + if (!traverse(*Init)) + return false; + if (!match(*Node->getLoopVariable())) + return false; + if (match(*Node->getRangeInit())) + if (!VisitorBase::TraverseStmt(Node->getRangeInit())) + return false; + if (!match(*Node->getBody())) + return false; + return VisitorBase::TraverseStmt(Node->getBody()); + } + bool TraverseCXXRewrittenBinaryOperator(CXXRewrittenBinaryOperator *Node) { + if (!Finder->isTraversalIgnoringImplicitNodes()) + return VisitorBase::TraverseCXXRewrittenBinaryOperator(Node); + if (!Node) + return true; + ScopedIncrement ScopedDepth(&CurrentDepth); + + return match(*Node->getLHS()) && match(*Node->getRHS()); + } + bool TraverseAttr(Attr *A) { + if (A == nullptr || + (A->isImplicit() && + Finder->getASTContext().getParentMapContext().getTraversalKind() == + TK_IgnoreUnlessSpelledInSource)) + return true; + ScopedIncrement ScopedDepth(&CurrentDepth); + return traverse(*A); + } + bool TraverseLambdaExpr(LambdaExpr *Node) { + if (!Finder->isTraversalIgnoringImplicitNodes()) + return VisitorBase::TraverseLambdaExpr(Node); + if (!Node) + return true; + ScopedIncrement ScopedDepth(&CurrentDepth); + + for (unsigned I = 0, N = Node->capture_size(); I != N; ++I) { + const auto *C = Node->capture_begin() + I; + if (!C->isExplicit()) + continue; + if (Node->isInitCapture(C) && !match(*C->getCapturedVar())) + return false; + if (!match(*Node->capture_init_begin()[I])) + return false; + } + + if (const auto *TPL = Node->getTemplateParameterList()) { + for (const auto *TP : *TPL) { + if (!match(*TP)) + return false; + } + } + + for (const auto *P : Node->getCallOperator()->parameters()) { + if (!match(*P)) + return false; + } + + if (!match(*Node->getBody())) + return false; + + return VisitorBase::TraverseStmt(Node->getBody()); + } + + bool shouldVisitTemplateInstantiations() const { return true; } + bool shouldVisitImplicitCode() const { return !IgnoreImplicitChildren; } + +private: + // Used for updating the depth during traversal. + struct ScopedIncrement { + explicit ScopedIncrement(int *Depth) : Depth(Depth) { ++(*Depth); } + ~ScopedIncrement() { --(*Depth); } + + private: + int *Depth; + }; + + // Resets the state of this object. + void reset() { + Matches = false; + CurrentDepth = 0; + } + + // Forwards the call to the corresponding Traverse*() method in the + // base visitor class. + bool baseTraverse(const Decl &DeclNode) { + return VisitorBase::TraverseDecl(const_cast<Decl*>(&DeclNode)); + } + bool baseTraverse(const Stmt &StmtNode) { + return VisitorBase::TraverseStmt(const_cast<Stmt*>(&StmtNode)); + } + bool baseTraverse(QualType TypeNode) { + return VisitorBase::TraverseType(TypeNode); + } + bool baseTraverse(TypeLoc TypeLocNode) { + return VisitorBase::TraverseTypeLoc(TypeLocNode); + } + bool baseTraverse(const NestedNameSpecifier &NNS) { + return VisitorBase::TraverseNestedNameSpecifier( + const_cast<NestedNameSpecifier*>(&NNS)); + } + bool baseTraverse(NestedNameSpecifierLoc NNS) { + return VisitorBase::TraverseNestedNameSpecifierLoc(NNS); + } + bool baseTraverse(const CXXCtorInitializer &CtorInit) { + return VisitorBase::TraverseConstructorInitializer( + const_cast<CXXCtorInitializer *>(&CtorInit)); + } + bool baseTraverse(TemplateArgumentLoc TAL) { + return VisitorBase::TraverseTemplateArgumentLoc(TAL); + } + bool baseTraverse(const Attr &AttrNode) { + return VisitorBase::TraverseAttr(const_cast<Attr *>(&AttrNode)); + } + + // Sets 'Matched' to true if 'Matcher' matches 'Node' and: + // 0 < CurrentDepth <= MaxDepth. + // + // Returns 'true' if traversal should continue after this function + // returns, i.e. if no match is found or 'Bind' is 'BK_All'. + template <typename T> + bool match(const T &Node) { + if (CurrentDepth == 0 || CurrentDepth > MaxDepth) { + return true; + } + if (Bind != ASTMatchFinder::BK_All) { + BoundNodesTreeBuilder RecursiveBuilder(*Builder); + if (Matcher->matches(DynTypedNode::create(Node), Finder, + &RecursiveBuilder)) { + Matches = true; + ResultBindings.addMatch(RecursiveBuilder); + return false; // Abort as soon as a match is found. + } + } else { + BoundNodesTreeBuilder RecursiveBuilder(*Builder); + if (Matcher->matches(DynTypedNode::create(Node), Finder, + &RecursiveBuilder)) { + // After the first match the matcher succeeds. + Matches = true; + ResultBindings.addMatch(RecursiveBuilder); + } + } + return true; + } + + // Traverses the subtree rooted at 'Node'; returns true if the + // traversal should continue after this function returns. + template <typename T> + bool traverse(const T &Node) { + static_assert(IsBaseType<T>::value, + "traverse can only be instantiated with base type"); + if (!match(Node)) + return false; + return baseTraverse(Node); + } + + const DynTypedMatcher *const Matcher; + ASTMatchFinder *const Finder; + BoundNodesTreeBuilder *const Builder; + BoundNodesTreeBuilder ResultBindings; + int CurrentDepth; + const int MaxDepth; + const bool IgnoreImplicitChildren; + const ASTMatchFinder::BindKind Bind; + bool Matches; +}; + +// Controls the outermost traversal of the AST and allows to match multiple +// matchers. +class MatchASTVisitor : public RecursiveASTVisitor<MatchASTVisitor>, + public ASTMatchFinder { +public: + MatchASTVisitor(const MatchFinder::MatchersByType *Matchers, + const MatchFinder::MatchFinderOptions &Options) + : Matchers(Matchers), Options(Options), ActiveASTContext(nullptr) {} + + ~MatchASTVisitor() override { + if (Options.CheckProfiling) { + Options.CheckProfiling->Records = std::move(TimeByBucket); + } + } + + void onStartOfTranslationUnit() { + const bool EnableCheckProfiling = Options.CheckProfiling.has_value(); + TimeBucketRegion Timer; + for (MatchCallback *MC : Matchers->AllCallbacks) { + if (EnableCheckProfiling) + Timer.setBucket(&TimeByBucket[MC->getID()]); + MC->onStartOfTranslationUnit(); + } + } + + void onEndOfTranslationUnit() { + const bool EnableCheckProfiling = Options.CheckProfiling.has_value(); + TimeBucketRegion Timer; + for (MatchCallback *MC : Matchers->AllCallbacks) { + if (EnableCheckProfiling) + Timer.setBucket(&TimeByBucket[MC->getID()]); + MC->onEndOfTranslationUnit(); + } + } + + void set_active_ast_context(ASTContext *NewActiveASTContext) { + ActiveASTContext = NewActiveASTContext; + } + + // The following Visit*() and Traverse*() functions "override" + // methods in RecursiveASTVisitor. + + bool VisitTypedefNameDecl(TypedefNameDecl *DeclNode) { + // When we see 'typedef A B', we add name 'B' to the set of names + // A's canonical type maps to. This is necessary for implementing + // isDerivedFrom(x) properly, where x can be the name of the base + // class or any of its aliases. + // + // In general, the is-alias-of (as defined by typedefs) relation + // is tree-shaped, as you can typedef a type more than once. For + // example, + // + // typedef A B; + // typedef A C; + // typedef C D; + // typedef C E; + // + // gives you + // + // A + // |- B + // `- C + // |- D + // `- E + // + // It is wrong to assume that the relation is a chain. A correct + // implementation of isDerivedFrom() needs to recognize that B and + // E are aliases, even though neither is a typedef of the other. + // Therefore, we cannot simply walk through one typedef chain to + // find out whether the type name matches. + const Type *TypeNode = DeclNode->getUnderlyingType().getTypePtr(); + const Type *CanonicalType = // root of the typedef tree + ActiveASTContext->getCanonicalType(TypeNode); + TypeAliases[CanonicalType].insert(DeclNode); + return true; + } + + bool VisitObjCCompatibleAliasDecl(ObjCCompatibleAliasDecl *CAD) { + const ObjCInterfaceDecl *InterfaceDecl = CAD->getClassInterface(); + CompatibleAliases[InterfaceDecl].insert(CAD); + return true; + } + + bool TraverseDecl(Decl *DeclNode); + bool TraverseStmt(Stmt *StmtNode, DataRecursionQueue *Queue = nullptr); + bool TraverseType(QualType TypeNode); + bool TraverseTypeLoc(TypeLoc TypeNode); + bool TraverseNestedNameSpecifier(NestedNameSpecifier *NNS); + bool TraverseNestedNameSpecifierLoc(NestedNameSpecifierLoc NNS); + bool TraverseConstructorInitializer(CXXCtorInitializer *CtorInit); + bool TraverseTemplateArgumentLoc(TemplateArgumentLoc TAL); + bool TraverseAttr(Attr *AttrNode); + + bool dataTraverseNode(Stmt *S, DataRecursionQueue *Queue) { + if (auto *RF = dyn_cast<CXXForRangeStmt>(S)) { + { + ASTNodeNotAsIsSourceScope RAII(this, true); + TraverseStmt(RF->getInit()); + // Don't traverse under the loop variable + match(*RF->getLoopVariable()); + TraverseStmt(RF->getRangeInit()); + } + { + ASTNodeNotSpelledInSourceScope RAII(this, true); + for (auto *SubStmt : RF->children()) { + if (SubStmt != RF->getBody()) + TraverseStmt(SubStmt); + } + } + TraverseStmt(RF->getBody()); + return true; + } else if (auto *RBO = dyn_cast<CXXRewrittenBinaryOperator>(S)) { + { + ASTNodeNotAsIsSourceScope RAII(this, true); + TraverseStmt(const_cast<Expr *>(RBO->getLHS())); + TraverseStmt(const_cast<Expr *>(RBO->getRHS())); + } + { + ASTNodeNotSpelledInSourceScope RAII(this, true); + for (auto *SubStmt : RBO->children()) { + TraverseStmt(SubStmt); + } + } + return true; + } else if (auto *LE = dyn_cast<LambdaExpr>(S)) { + for (auto I : llvm::zip(LE->captures(), LE->capture_inits())) { + auto C = std::get<0>(I); + ASTNodeNotSpelledInSourceScope RAII( + this, TraversingASTNodeNotSpelledInSource || !C.isExplicit()); + TraverseLambdaCapture(LE, &C, std::get<1>(I)); + } + + { + ASTNodeNotSpelledInSourceScope RAII(this, true); + TraverseDecl(LE->getLambdaClass()); + } + { + ASTNodeNotAsIsSourceScope RAII(this, true); + + // We need to poke around to find the bits that might be explicitly + // written. + TypeLoc TL = LE->getCallOperator()->getTypeSourceInfo()->getTypeLoc(); + FunctionProtoTypeLoc Proto = TL.getAsAdjusted<FunctionProtoTypeLoc>(); + + if (auto *TPL = LE->getTemplateParameterList()) { + for (NamedDecl *D : *TPL) { + TraverseDecl(D); + } + if (Expr *RequiresClause = TPL->getRequiresClause()) { + TraverseStmt(RequiresClause); + } + } + + if (LE->hasExplicitParameters()) { + // Visit parameters. + for (ParmVarDecl *Param : Proto.getParams()) + TraverseDecl(Param); + } + + const auto *T = Proto.getTypePtr(); + for (const auto &E : T->exceptions()) + TraverseType(E); + + if (Expr *NE = T->getNoexceptExpr()) + TraverseStmt(NE, Queue); + + if (LE->hasExplicitResultType()) + TraverseTypeLoc(Proto.getReturnLoc()); + TraverseStmt(LE->getTrailingRequiresClause()); + } + + TraverseStmt(LE->getBody()); + return true; + } + return RecursiveASTVisitor<MatchASTVisitor>::dataTraverseNode(S, Queue); + } + + // Matches children or descendants of 'Node' with 'BaseMatcher'. + bool memoizedMatchesRecursively(const DynTypedNode &Node, ASTContext &Ctx, + const DynTypedMatcher &Matcher, + BoundNodesTreeBuilder *Builder, int MaxDepth, + BindKind Bind) { + // For AST-nodes that don't have an identity, we can't memoize. + if (!Node.getMemoizationData() || !Builder->isComparable()) + return matchesRecursively(Node, Matcher, Builder, MaxDepth, Bind); + + MatchKey Key; + Key.MatcherID = Matcher.getID(); + Key.Node = Node; + // Note that we key on the bindings *before* the match. + Key.BoundNodes = *Builder; + Key.Traversal = Ctx.getParentMapContext().getTraversalKind(); + // Memoize result even doing a single-level match, it might be expensive. + Key.Type = MaxDepth == 1 ? MatchType::Child : MatchType::Descendants; + MemoizationMap::iterator I = ResultCache.find(Key); + if (I != ResultCache.end()) { + *Builder = I->second.Nodes; + return I->second.ResultOfMatch; + } + + MemoizedMatchResult Result; + Result.Nodes = *Builder; + Result.ResultOfMatch = + matchesRecursively(Node, Matcher, &Result.Nodes, MaxDepth, Bind); + + MemoizedMatchResult &CachedResult = ResultCache[Key]; + CachedResult = std::move(Result); + + *Builder = CachedResult.Nodes; + return CachedResult.ResultOfMatch; + } + + // Matches children or descendants of 'Node' with 'BaseMatcher'. + bool matchesRecursively(const DynTypedNode &Node, + const DynTypedMatcher &Matcher, + BoundNodesTreeBuilder *Builder, int MaxDepth, + BindKind Bind) { + bool ScopedTraversal = TraversingASTNodeNotSpelledInSource || + TraversingASTChildrenNotSpelledInSource; + + bool IgnoreImplicitChildren = false; + + if (isTraversalIgnoringImplicitNodes()) { + IgnoreImplicitChildren = true; + } + + ASTNodeNotSpelledInSourceScope RAII(this, ScopedTraversal); + + MatchChildASTVisitor Visitor(&Matcher, this, Builder, MaxDepth, + IgnoreImplicitChildren, Bind); + return Visitor.findMatch(Node); + } + + bool classIsDerivedFrom(const CXXRecordDecl *Declaration, + const Matcher<NamedDecl> &Base, + BoundNodesTreeBuilder *Builder, + bool Directly) override; + +private: + bool + classIsDerivedFromImpl(const CXXRecordDecl *Declaration, + const Matcher<NamedDecl> &Base, + BoundNodesTreeBuilder *Builder, bool Directly, + llvm::SmallPtrSetImpl<const CXXRecordDecl *> &Visited); + +public: + bool objcClassIsDerivedFrom(const ObjCInterfaceDecl *Declaration, + const Matcher<NamedDecl> &Base, + BoundNodesTreeBuilder *Builder, + bool Directly) override; + +public: + // Implements ASTMatchFinder::matchesChildOf. + bool matchesChildOf(const DynTypedNode &Node, ASTContext &Ctx, + const DynTypedMatcher &Matcher, + BoundNodesTreeBuilder *Builder, BindKind Bind) override { + if (ResultCache.size() > MaxMemoizationEntries) + ResultCache.clear(); + return memoizedMatchesRecursively(Node, Ctx, Matcher, Builder, 1, Bind); + } + // Implements ASTMatchFinder::matchesDescendantOf. + bool matchesDescendantOf(const DynTypedNode &Node, ASTContext &Ctx, + const DynTypedMatcher &Matcher, + BoundNodesTreeBuilder *Builder, + BindKind Bind) override { + if (ResultCache.size() > MaxMemoizationEntries) + ResultCache.clear(); + return memoizedMatchesRecursively(Node, Ctx, Matcher, Builder, INT_MAX, + Bind); + } + // Implements ASTMatchFinder::matchesAncestorOf. + bool matchesAncestorOf(const DynTypedNode &Node, ASTContext &Ctx, + const DynTypedMatcher &Matcher, + BoundNodesTreeBuilder *Builder, + AncestorMatchMode MatchMode) override { + // Reset the cache outside of the recursive call to make sure we + // don't invalidate any iterators. + if (ResultCache.size() > MaxMemoizationEntries) + ResultCache.clear(); + if (MatchMode == AncestorMatchMode::AMM_ParentOnly) + return matchesParentOf(Node, Matcher, Builder); + return matchesAnyAncestorOf(Node, Ctx, Matcher, Builder); + } + + // Matches all registered matchers on the given node and calls the + // result callback for every node that matches. + void match(const DynTypedNode &Node) { + // FIXME: Improve this with a switch or a visitor pattern. + if (auto *N = Node.get<Decl>()) { + match(*N); + } else if (auto *N = Node.get<Stmt>()) { + match(*N); + } else if (auto *N = Node.get<Type>()) { + match(*N); + } else if (auto *N = Node.get<QualType>()) { + match(*N); + } else if (auto *N = Node.get<NestedNameSpecifier>()) { + match(*N); + } else if (auto *N = Node.get<NestedNameSpecifierLoc>()) { + match(*N); + } else if (auto *N = Node.get<TypeLoc>()) { + match(*N); + } else if (auto *N = Node.get<CXXCtorInitializer>()) { + match(*N); + } else if (auto *N = Node.get<TemplateArgumentLoc>()) { + match(*N); + } else if (auto *N = Node.get<Attr>()) { + match(*N); + } + } + + template <typename T> void match(const T &Node) { + matchDispatch(&Node); + } + + // Implements ASTMatchFinder::getASTContext. + ASTContext &getASTContext() const override { return *ActiveASTContext; } + + bool shouldVisitTemplateInstantiations() const { return true; } + bool shouldVisitImplicitCode() const { return true; } + + // We visit the lambda body explicitly, so instruct the RAV + // to not visit it on our behalf too. + bool shouldVisitLambdaBody() const { return false; } + + bool IsMatchingInASTNodeNotSpelledInSource() const override { + return TraversingASTNodeNotSpelledInSource; + } + bool isMatchingChildrenNotSpelledInSource() const override { + return TraversingASTChildrenNotSpelledInSource; + } + void setMatchingChildrenNotSpelledInSource(bool Set) override { + TraversingASTChildrenNotSpelledInSource = Set; + } + + bool IsMatchingInASTNodeNotAsIs() const override { + return TraversingASTNodeNotAsIs; + } + + bool TraverseTemplateInstantiations(ClassTemplateDecl *D) { + ASTNodeNotSpelledInSourceScope RAII(this, true); + return RecursiveASTVisitor<MatchASTVisitor>::TraverseTemplateInstantiations( + D); + } + + bool TraverseTemplateInstantiations(VarTemplateDecl *D) { + ASTNodeNotSpelledInSourceScope RAII(this, true); + return RecursiveASTVisitor<MatchASTVisitor>::TraverseTemplateInstantiations( + D); + } + + bool TraverseTemplateInstantiations(FunctionTemplateDecl *D) { + ASTNodeNotSpelledInSourceScope RAII(this, true); + return RecursiveASTVisitor<MatchASTVisitor>::TraverseTemplateInstantiations( + D); + } + +private: + bool TraversingASTNodeNotSpelledInSource = false; + bool TraversingASTNodeNotAsIs = false; + bool TraversingASTChildrenNotSpelledInSource = false; + + class CurMatchData { +// We don't have enough free low bits in 32bit builds to discriminate 8 pointer +// types in PointerUnion. so split the union in 2 using a free bit from the +// callback pointer. +#define CMD_TYPES_0 \ + const QualType *, const TypeLoc *, const NestedNameSpecifier *, \ + const NestedNameSpecifierLoc * +#define CMD_TYPES_1 \ + const CXXCtorInitializer *, const TemplateArgumentLoc *, const Attr *, \ + const DynTypedNode * + +#define IMPL(Index) \ + template <typename NodeType> \ + std::enable_if_t< \ + llvm::is_one_of<const NodeType *, CMD_TYPES_##Index>::value> \ + SetCallbackAndRawNode(const MatchCallback *CB, const NodeType &N) { \ + assertEmpty(); \ + Callback.setPointerAndInt(CB, Index); \ + Node##Index = &N; \ + } \ + \ + template <typename T> \ + std::enable_if_t<llvm::is_one_of<const T *, CMD_TYPES_##Index>::value, \ + const T *> \ + getNode() const { \ + assertHoldsState(); \ + return Callback.getInt() == (Index) ? Node##Index.dyn_cast<const T *>() \ + : nullptr; \ + } + + public: + CurMatchData() : Node0(nullptr) {} + + IMPL(0) + IMPL(1) + + const MatchCallback *getCallback() const { return Callback.getPointer(); } + + void SetBoundNodes(const BoundNodes &BN) { + assertHoldsState(); + BNodes = &BN; + } + + void clearBoundNodes() { + assertHoldsState(); + BNodes = nullptr; + } + + const BoundNodes *getBoundNodes() const { + assertHoldsState(); + return BNodes; + } + + void reset() { + assertHoldsState(); + Callback.setPointerAndInt(nullptr, 0); + Node0 = nullptr; + } + + private: + void assertHoldsState() const { + assert(Callback.getPointer() != nullptr && !Node0.isNull()); + } + + void assertEmpty() const { + assert(Callback.getPointer() == nullptr && Node0.isNull() && + BNodes == nullptr); + } + + llvm::PointerIntPair<const MatchCallback *, 1> Callback; + union { + llvm::PointerUnion<CMD_TYPES_0> Node0; + llvm::PointerUnion<CMD_TYPES_1> Node1; + }; + const BoundNodes *BNodes = nullptr; + +#undef CMD_TYPES_0 +#undef CMD_TYPES_1 +#undef IMPL + } CurMatchState; + + struct CurMatchRAII { + template <typename NodeType> + CurMatchRAII(MatchASTVisitor &MV, const MatchCallback *CB, + const NodeType &NT) + : MV(MV) { + MV.CurMatchState.SetCallbackAndRawNode(CB, NT); + } + + ~CurMatchRAII() { MV.CurMatchState.reset(); } + + private: + MatchASTVisitor &MV; + }; + +public: + class TraceReporter : llvm::PrettyStackTraceEntry { + static void dumpNode(const ASTContext &Ctx, const DynTypedNode &Node, + raw_ostream &OS) { + if (const auto *D = Node.get<Decl>()) { + OS << D->getDeclKindName() << "Decl "; + if (const auto *ND = dyn_cast<NamedDecl>(D)) { + ND->printQualifiedName(OS); + OS << " : "; + } else + OS << ": "; + D->getSourceRange().print(OS, Ctx.getSourceManager()); + } else if (const auto *S = Node.get<Stmt>()) { + OS << S->getStmtClassName() << " : "; + S->getSourceRange().print(OS, Ctx.getSourceManager()); + } else if (const auto *T = Node.get<Type>()) { + OS << T->getTypeClassName() << "Type : "; + QualType(T, 0).print(OS, Ctx.getPrintingPolicy()); + } else if (const auto *QT = Node.get<QualType>()) { + OS << "QualType : "; + QT->print(OS, Ctx.getPrintingPolicy()); + } else { + OS << Node.getNodeKind().asStringRef() << " : "; + Node.getSourceRange().print(OS, Ctx.getSourceManager()); + } + } + + static void dumpNodeFromState(const ASTContext &Ctx, + const CurMatchData &State, raw_ostream &OS) { + if (const DynTypedNode *MatchNode = State.getNode<DynTypedNode>()) { + dumpNode(Ctx, *MatchNode, OS); + } else if (const auto *QT = State.getNode<QualType>()) { + dumpNode(Ctx, DynTypedNode::create(*QT), OS); + } else if (const auto *TL = State.getNode<TypeLoc>()) { + dumpNode(Ctx, DynTypedNode::create(*TL), OS); + } else if (const auto *NNS = State.getNode<NestedNameSpecifier>()) { + dumpNode(Ctx, DynTypedNode::create(*NNS), OS); + } else if (const auto *NNSL = State.getNode<NestedNameSpecifierLoc>()) { + dumpNode(Ctx, DynTypedNode::create(*NNSL), OS); + } else if (const auto *CtorInit = State.getNode<CXXCtorInitializer>()) { + dumpNode(Ctx, DynTypedNode::create(*CtorInit), OS); + } else if (const auto *TAL = State.getNode<TemplateArgumentLoc>()) { + dumpNode(Ctx, DynTypedNode::create(*TAL), OS); + } else if (const auto *At = State.getNode<Attr>()) { + dumpNode(Ctx, DynTypedNode::create(*At), OS); + } + } + + public: + TraceReporter(const MatchASTVisitor &MV) : MV(MV) {} + void print(raw_ostream &OS) const override { + const CurMatchData &State = MV.CurMatchState; + const MatchCallback *CB = State.getCallback(); + if (!CB) { + OS << "ASTMatcher: Not currently matching\n"; + return; + } + + assert(MV.ActiveASTContext && + "ActiveASTContext should be set if there is a matched callback"); + + ASTContext &Ctx = MV.getASTContext(); + + if (const BoundNodes *Nodes = State.getBoundNodes()) { + OS << "ASTMatcher: Processing '" << CB->getID() << "' against:\n\t"; + dumpNodeFromState(Ctx, State, OS); + const BoundNodes::IDToNodeMap &Map = Nodes->getMap(); + if (Map.empty()) { + OS << "\nNo bound nodes\n"; + return; + } + OS << "\n--- Bound Nodes Begin ---\n"; + for (const auto &Item : Map) { + OS << " " << Item.first << " - { "; + dumpNode(Ctx, Item.second, OS); + OS << " }\n"; + } + OS << "--- Bound Nodes End ---\n"; + } else { + OS << "ASTMatcher: Matching '" << CB->getID() << "' against:\n\t"; + dumpNodeFromState(Ctx, State, OS); + OS << '\n'; + } + } + + private: + const MatchASTVisitor &MV; + }; + +private: + struct ASTNodeNotSpelledInSourceScope { + ASTNodeNotSpelledInSourceScope(MatchASTVisitor *V, bool B) + : MV(V), MB(V->TraversingASTNodeNotSpelledInSource) { + V->TraversingASTNodeNotSpelledInSource = B; + } + ~ASTNodeNotSpelledInSourceScope() { + MV->TraversingASTNodeNotSpelledInSource = MB; + } + + private: + MatchASTVisitor *MV; + bool MB; + }; + + struct ASTNodeNotAsIsSourceScope { + ASTNodeNotAsIsSourceScope(MatchASTVisitor *V, bool B) + : MV(V), MB(V->TraversingASTNodeNotAsIs) { + V->TraversingASTNodeNotAsIs = B; + } + ~ASTNodeNotAsIsSourceScope() { MV->TraversingASTNodeNotAsIs = MB; } + + private: + MatchASTVisitor *MV; + bool MB; + }; + + class TimeBucketRegion { + public: + TimeBucketRegion() = default; + ~TimeBucketRegion() { setBucket(nullptr); } + + /// Start timing for \p NewBucket. + /// + /// If there was a bucket already set, it will finish the timing for that + /// other bucket. + /// \p NewBucket will be timed until the next call to \c setBucket() or + /// until the \c TimeBucketRegion is destroyed. + /// If \p NewBucket is the same as the currently timed bucket, this call + /// does nothing. + void setBucket(llvm::TimeRecord *NewBucket) { + if (Bucket != NewBucket) { + auto Now = llvm::TimeRecord::getCurrentTime(true); + if (Bucket) + *Bucket += Now; + if (NewBucket) + *NewBucket -= Now; + Bucket = NewBucket; + } + } + + private: + llvm::TimeRecord *Bucket = nullptr; + }; + + /// Runs all the \p Matchers on \p Node. + /// + /// Used by \c matchDispatch() below. + template <typename T, typename MC> + void matchWithoutFilter(const T &Node, const MC &Matchers) { + const bool EnableCheckProfiling = Options.CheckProfiling.has_value(); + TimeBucketRegion Timer; + for (const auto &MP : Matchers) { + if (EnableCheckProfiling) + Timer.setBucket(&TimeByBucket[MP.second->getID()]); + BoundNodesTreeBuilder Builder; + CurMatchRAII RAII(*this, MP.second, Node); + if (MP.first.matches(Node, this, &Builder)) { + MatchVisitor Visitor(*this, ActiveASTContext, MP.second); + Builder.visitMatches(&Visitor); + } + } + } + + void matchWithFilter(const DynTypedNode &DynNode) { + auto Kind = DynNode.getNodeKind(); + auto it = MatcherFiltersMap.find(Kind); + const auto &Filter = + it != MatcherFiltersMap.end() ? it->second : getFilterForKind(Kind); + + if (Filter.empty()) + return; + + const bool EnableCheckProfiling = Options.CheckProfiling.has_value(); + TimeBucketRegion Timer; + auto &Matchers = this->Matchers->DeclOrStmt; + for (unsigned short I : Filter) { + auto &MP = Matchers[I]; + if (EnableCheckProfiling) + Timer.setBucket(&TimeByBucket[MP.second->getID()]); + BoundNodesTreeBuilder Builder; + + { + TraversalKindScope RAII(getASTContext(), MP.first.getTraversalKind()); + if (getASTContext().getParentMapContext().traverseIgnored(DynNode) != + DynNode) + continue; + } + + CurMatchRAII RAII(*this, MP.second, DynNode); + if (MP.first.matches(DynNode, this, &Builder)) { + MatchVisitor Visitor(*this, ActiveASTContext, MP.second); + Builder.visitMatches(&Visitor); + } + } + } + + const std::vector<unsigned short> &getFilterForKind(ASTNodeKind Kind) { + auto &Filter = MatcherFiltersMap[Kind]; + auto &Matchers = this->Matchers->DeclOrStmt; + assert((Matchers.size() < USHRT_MAX) && "Too many matchers."); + for (unsigned I = 0, E = Matchers.size(); I != E; ++I) { + if (Matchers[I].first.canMatchNodesOfKind(Kind)) { + Filter.push_back(I); + } + } + return Filter; + } + + /// @{ + /// Overloads to pair the different node types to their matchers. + void matchDispatch(const Decl *Node) { + return matchWithFilter(DynTypedNode::create(*Node)); + } + void matchDispatch(const Stmt *Node) { + return matchWithFilter(DynTypedNode::create(*Node)); + } + + void matchDispatch(const Type *Node) { + matchWithoutFilter(QualType(Node, 0), Matchers->Type); + } + void matchDispatch(const TypeLoc *Node) { + matchWithoutFilter(*Node, Matchers->TypeLoc); + } + void matchDispatch(const QualType *Node) { + matchWithoutFilter(*Node, Matchers->Type); + } + void matchDispatch(const NestedNameSpecifier *Node) { + matchWithoutFilter(*Node, Matchers->NestedNameSpecifier); + } + void matchDispatch(const NestedNameSpecifierLoc *Node) { + matchWithoutFilter(*Node, Matchers->NestedNameSpecifierLoc); + } + void matchDispatch(const CXXCtorInitializer *Node) { + matchWithoutFilter(*Node, Matchers->CtorInit); + } + void matchDispatch(const TemplateArgumentLoc *Node) { + matchWithoutFilter(*Node, Matchers->TemplateArgumentLoc); + } + void matchDispatch(const Attr *Node) { + matchWithoutFilter(*Node, Matchers->Attr); + } + void matchDispatch(const void *) { /* Do nothing. */ } + /// @} + + // Returns whether a direct parent of \p Node matches \p Matcher. + // Unlike matchesAnyAncestorOf there's no memoization: it doesn't save much. + bool matchesParentOf(const DynTypedNode &Node, const DynTypedMatcher &Matcher, + BoundNodesTreeBuilder *Builder) { + for (const auto &Parent : ActiveASTContext->getParents(Node)) { + BoundNodesTreeBuilder BuilderCopy = *Builder; + if (Matcher.matches(Parent, this, &BuilderCopy)) { + *Builder = std::move(BuilderCopy); + return true; + } + } + return false; + } + + // Returns whether an ancestor of \p Node matches \p Matcher. + // + // The order of matching (which can lead to different nodes being bound in + // case there are multiple matches) is breadth first search. + // + // To allow memoization in the very common case of having deeply nested + // expressions inside a template function, we first walk up the AST, memoizing + // the result of the match along the way, as long as there is only a single + // parent. + // + // Once there are multiple parents, the breadth first search order does not + // allow simple memoization on the ancestors. Thus, we only memoize as long + // as there is a single parent. + // + // We avoid a recursive implementation to prevent excessive stack use on + // very deep ASTs (similarly to RecursiveASTVisitor's data recursion). + bool matchesAnyAncestorOf(DynTypedNode Node, ASTContext &Ctx, + const DynTypedMatcher &Matcher, + BoundNodesTreeBuilder *Builder) { + + // Memoization keys that can be updated with the result. + // These are the memoizable nodes in the chain of unique parents, which + // terminates when a node has multiple parents, or matches, or is the root. + std::vector<MatchKey> Keys; + // When returning, update the memoization cache. + auto Finish = [&](bool Matched) { + for (const auto &Key : Keys) { + MemoizedMatchResult &CachedResult = ResultCache[Key]; + CachedResult.ResultOfMatch = Matched; + CachedResult.Nodes = *Builder; + } + return Matched; + }; + + // Loop while there's a single parent and we want to attempt memoization. + DynTypedNodeList Parents{ArrayRef<DynTypedNode>()}; // after loop: size != 1 + for (;;) { + // A cache key only makes sense if memoization is possible. + if (Builder->isComparable()) { + Keys.emplace_back(); + Keys.back().MatcherID = Matcher.getID(); + Keys.back().Node = Node; + Keys.back().BoundNodes = *Builder; + Keys.back().Traversal = Ctx.getParentMapContext().getTraversalKind(); + Keys.back().Type = MatchType::Ancestors; + + // Check the cache. + MemoizationMap::iterator I = ResultCache.find(Keys.back()); + if (I != ResultCache.end()) { + Keys.pop_back(); // Don't populate the cache for the matching node! + *Builder = I->second.Nodes; + return Finish(I->second.ResultOfMatch); + } + } + + Parents = ActiveASTContext->getParents(Node); + // Either no parents or multiple parents: leave chain+memoize mode and + // enter bfs+forgetful mode. + if (Parents.size() != 1) + break; + + // Check the next parent. + Node = *Parents.begin(); + BoundNodesTreeBuilder BuilderCopy = *Builder; + if (Matcher.matches(Node, this, &BuilderCopy)) { + *Builder = std::move(BuilderCopy); + return Finish(true); + } + } + // We reached the end of the chain. + + if (Parents.empty()) { + // Nodes may have no parents if: + // a) the node is the TranslationUnitDecl + // b) we have a limited traversal scope that excludes the parent edges + // c) there is a bug in the AST, and the node is not reachable + // Usually the traversal scope is the whole AST, which precludes b. + // Bugs are common enough that it's worthwhile asserting when we can. +#ifndef NDEBUG + if (!Node.get<TranslationUnitDecl>() && + /* Traversal scope is full AST if any of the bounds are the TU */ + llvm::any_of(ActiveASTContext->getTraversalScope(), [](Decl *D) { + return D->getKind() == Decl::TranslationUnit; + })) { + llvm::errs() << "Tried to match orphan node:\n"; + Node.dump(llvm::errs(), *ActiveASTContext); + llvm_unreachable("Parent map should be complete!"); + } +#endif + } else { + assert(Parents.size() > 1); + // BFS starting from the parents not yet considered. + // Memoization of newly visited nodes is not possible (but we still update + // results for the elements in the chain we found above). + std::deque<DynTypedNode> Queue(Parents.begin(), Parents.end()); + llvm::DenseSet<const void *> Visited; + while (!Queue.empty()) { + BoundNodesTreeBuilder BuilderCopy = *Builder; + if (Matcher.matches(Queue.front(), this, &BuilderCopy)) { + *Builder = std::move(BuilderCopy); + return Finish(true); + } + for (const auto &Parent : ActiveASTContext->getParents(Queue.front())) { + // Make sure we do not visit the same node twice. + // Otherwise, we'll visit the common ancestors as often as there + // are splits on the way down. + if (Visited.insert(Parent.getMemoizationData()).second) + Queue.push_back(Parent); + } + Queue.pop_front(); + } + } + return Finish(false); + } + + // Implements a BoundNodesTree::Visitor that calls a MatchCallback with + // the aggregated bound nodes for each match. + class MatchVisitor : public BoundNodesTreeBuilder::Visitor { + struct CurBoundScope { + CurBoundScope(MatchASTVisitor::CurMatchData &State, const BoundNodes &BN) + : State(State) { + State.SetBoundNodes(BN); + } + + ~CurBoundScope() { State.clearBoundNodes(); } + + private: + MatchASTVisitor::CurMatchData &State; + }; + + public: + MatchVisitor(MatchASTVisitor &MV, ASTContext *Context, + MatchFinder::MatchCallback *Callback) + : State(MV.CurMatchState), Context(Context), Callback(Callback) {} + + void visitMatch(const BoundNodes& BoundNodesView) override { + TraversalKindScope RAII(*Context, Callback->getCheckTraversalKind()); + CurBoundScope RAII2(State, BoundNodesView); + Callback->run(MatchFinder::MatchResult(BoundNodesView, Context)); + } + + private: + MatchASTVisitor::CurMatchData &State; + ASTContext* Context; + MatchFinder::MatchCallback* Callback; + }; + + // Returns true if 'TypeNode' has an alias that matches the given matcher. + bool typeHasMatchingAlias(const Type *TypeNode, + const Matcher<NamedDecl> &Matcher, + BoundNodesTreeBuilder *Builder) { + const Type *const CanonicalType = + ActiveASTContext->getCanonicalType(TypeNode); + auto Aliases = TypeAliases.find(CanonicalType); + if (Aliases == TypeAliases.end()) + return false; + for (const TypedefNameDecl *Alias : Aliases->second) { + BoundNodesTreeBuilder Result(*Builder); + if (Matcher.matches(*Alias, this, &Result)) { + *Builder = std::move(Result); + return true; + } + } + return false; + } + + bool + objcClassHasMatchingCompatibilityAlias(const ObjCInterfaceDecl *InterfaceDecl, + const Matcher<NamedDecl> &Matcher, + BoundNodesTreeBuilder *Builder) { + auto Aliases = CompatibleAliases.find(InterfaceDecl); + if (Aliases == CompatibleAliases.end()) + return false; + for (const ObjCCompatibleAliasDecl *Alias : Aliases->second) { + BoundNodesTreeBuilder Result(*Builder); + if (Matcher.matches(*Alias, this, &Result)) { + *Builder = std::move(Result); + return true; + } + } + return false; + } + + /// Bucket to record map. + /// + /// Used to get the appropriate bucket for each matcher. + llvm::StringMap<llvm::TimeRecord> TimeByBucket; + + const MatchFinder::MatchersByType *Matchers; + + /// Filtered list of matcher indices for each matcher kind. + /// + /// \c Decl and \c Stmt toplevel matchers usually apply to a specific node + /// kind (and derived kinds) so it is a waste to try every matcher on every + /// node. + /// We precalculate a list of matchers that pass the toplevel restrict check. + llvm::DenseMap<ASTNodeKind, std::vector<unsigned short>> MatcherFiltersMap; + + const MatchFinder::MatchFinderOptions &Options; + ASTContext *ActiveASTContext; + + // Maps a canonical type to its TypedefDecls. + llvm::DenseMap<const Type*, std::set<const TypedefNameDecl*> > TypeAliases; + + // Maps an Objective-C interface to its ObjCCompatibleAliasDecls. + llvm::DenseMap<const ObjCInterfaceDecl *, + llvm::SmallPtrSet<const ObjCCompatibleAliasDecl *, 2>> + CompatibleAliases; + + // Maps (matcher, node) -> the match result for memoization. + typedef std::map<MatchKey, MemoizedMatchResult> MemoizationMap; + MemoizationMap ResultCache; +}; + +static CXXRecordDecl * +getAsCXXRecordDeclOrPrimaryTemplate(const Type *TypeNode) { + if (auto *RD = TypeNode->getAsCXXRecordDecl()) + return RD; + + // Find the innermost TemplateSpecializationType that isn't an alias template. + auto *TemplateType = TypeNode->getAs<TemplateSpecializationType>(); + while (TemplateType && TemplateType->isTypeAlias()) + TemplateType = + TemplateType->getAliasedType()->getAs<TemplateSpecializationType>(); + + // If this is the name of a (dependent) template specialization, use the + // definition of the template, even though it might be specialized later. + if (TemplateType) + if (auto *ClassTemplate = dyn_cast_or_null<ClassTemplateDecl>( + TemplateType->getTemplateName().getAsTemplateDecl())) + return ClassTemplate->getTemplatedDecl(); + + return nullptr; +} + +// Returns true if the given C++ class is directly or indirectly derived +// from a base type with the given name. A class is not considered to be +// derived from itself. +bool MatchASTVisitor::classIsDerivedFrom(const CXXRecordDecl *Declaration, + const Matcher<NamedDecl> &Base, + BoundNodesTreeBuilder *Builder, + bool Directly) { + llvm::SmallPtrSet<const CXXRecordDecl *, 8> Visited; + return classIsDerivedFromImpl(Declaration, Base, Builder, Directly, Visited); +} + +bool MatchASTVisitor::classIsDerivedFromImpl( + const CXXRecordDecl *Declaration, const Matcher<NamedDecl> &Base, + BoundNodesTreeBuilder *Builder, bool Directly, + llvm::SmallPtrSetImpl<const CXXRecordDecl *> &Visited) { + if (!Declaration->hasDefinition()) + return false; + if (!Visited.insert(Declaration).second) + return false; + for (const auto &It : Declaration->bases()) { + const Type *TypeNode = It.getType().getTypePtr(); + + if (typeHasMatchingAlias(TypeNode, Base, Builder)) + return true; + + // FIXME: Going to the primary template here isn't really correct, but + // unfortunately we accept a Decl matcher for the base class not a Type + // matcher, so it's the best thing we can do with our current interface. + CXXRecordDecl *ClassDecl = getAsCXXRecordDeclOrPrimaryTemplate(TypeNode); + if (!ClassDecl) + continue; + if (ClassDecl == Declaration) { + // This can happen for recursive template definitions. + continue; + } + BoundNodesTreeBuilder Result(*Builder); + if (Base.matches(*ClassDecl, this, &Result)) { + *Builder = std::move(Result); + return true; + } + if (!Directly && + classIsDerivedFromImpl(ClassDecl, Base, Builder, Directly, Visited)) + return true; + } + return false; +} + +// Returns true if the given Objective-C class is directly or indirectly +// derived from a matching base class. A class is not considered to be derived +// from itself. +bool MatchASTVisitor::objcClassIsDerivedFrom( + const ObjCInterfaceDecl *Declaration, const Matcher<NamedDecl> &Base, + BoundNodesTreeBuilder *Builder, bool Directly) { + // Check if any of the superclasses of the class match. + for (const ObjCInterfaceDecl *ClassDecl = Declaration->getSuperClass(); + ClassDecl != nullptr; ClassDecl = ClassDecl->getSuperClass()) { + // Check if there are any matching compatibility aliases. + if (objcClassHasMatchingCompatibilityAlias(ClassDecl, Base, Builder)) + return true; + + // Check if there are any matching type aliases. + const Type *TypeNode = ClassDecl->getTypeForDecl(); + if (typeHasMatchingAlias(TypeNode, Base, Builder)) + return true; + + if (Base.matches(*ClassDecl, this, Builder)) + return true; + + // Not `return false` as a temporary workaround for PR43879. + if (Directly) + break; + } + + return false; +} + +bool MatchASTVisitor::TraverseDecl(Decl *DeclNode) { + if (!DeclNode) { + return true; + } + + bool ScopedTraversal = + TraversingASTNodeNotSpelledInSource || DeclNode->isImplicit(); + bool ScopedChildren = TraversingASTChildrenNotSpelledInSource; + + if (const auto *CTSD = dyn_cast<ClassTemplateSpecializationDecl>(DeclNode)) { + auto SK = CTSD->getSpecializationKind(); + if (SK == TSK_ExplicitInstantiationDeclaration || + SK == TSK_ExplicitInstantiationDefinition) + ScopedChildren = true; + } else if (const auto *FD = dyn_cast<FunctionDecl>(DeclNode)) { + if (FD->isDefaulted()) + ScopedChildren = true; + if (FD->isTemplateInstantiation()) + ScopedTraversal = true; + } else if (isa<BindingDecl>(DeclNode)) { + ScopedChildren = true; + } + + ASTNodeNotSpelledInSourceScope RAII1(this, ScopedTraversal); + ASTChildrenNotSpelledInSourceScope RAII2(this, ScopedChildren); + + match(*DeclNode); + return RecursiveASTVisitor<MatchASTVisitor>::TraverseDecl(DeclNode); +} + +bool MatchASTVisitor::TraverseStmt(Stmt *StmtNode, DataRecursionQueue *Queue) { + if (!StmtNode) { + return true; + } + bool ScopedTraversal = TraversingASTNodeNotSpelledInSource || + TraversingASTChildrenNotSpelledInSource; + + ASTNodeNotSpelledInSourceScope RAII(this, ScopedTraversal); + match(*StmtNode); + return RecursiveASTVisitor<MatchASTVisitor>::TraverseStmt(StmtNode, Queue); +} + +bool MatchASTVisitor::TraverseType(QualType TypeNode) { + match(TypeNode); + return RecursiveASTVisitor<MatchASTVisitor>::TraverseType(TypeNode); +} + +bool MatchASTVisitor::TraverseTypeLoc(TypeLoc TypeLocNode) { + // The RecursiveASTVisitor only visits types if they're not within TypeLocs. + // We still want to find those types via matchers, so we match them here. Note + // that the TypeLocs are structurally a shadow-hierarchy to the expressed + // type, so we visit all involved parts of a compound type when matching on + // each TypeLoc. + match(TypeLocNode); + match(TypeLocNode.getType()); + return RecursiveASTVisitor<MatchASTVisitor>::TraverseTypeLoc(TypeLocNode); +} + +bool MatchASTVisitor::TraverseNestedNameSpecifier(NestedNameSpecifier *NNS) { + match(*NNS); + return RecursiveASTVisitor<MatchASTVisitor>::TraverseNestedNameSpecifier(NNS); +} + +bool MatchASTVisitor::TraverseNestedNameSpecifierLoc( + NestedNameSpecifierLoc NNS) { + if (!NNS) + return true; + + match(NNS); + + // We only match the nested name specifier here (as opposed to traversing it) + // because the traversal is already done in the parallel "Loc"-hierarchy. + if (NNS.hasQualifier()) + match(*NNS.getNestedNameSpecifier()); + return + RecursiveASTVisitor<MatchASTVisitor>::TraverseNestedNameSpecifierLoc(NNS); +} + +bool MatchASTVisitor::TraverseConstructorInitializer( + CXXCtorInitializer *CtorInit) { + if (!CtorInit) + return true; + + bool ScopedTraversal = TraversingASTNodeNotSpelledInSource || + TraversingASTChildrenNotSpelledInSource; + + if (!CtorInit->isWritten()) + ScopedTraversal = true; + + ASTNodeNotSpelledInSourceScope RAII1(this, ScopedTraversal); + + match(*CtorInit); + + return RecursiveASTVisitor<MatchASTVisitor>::TraverseConstructorInitializer( + CtorInit); +} + +bool MatchASTVisitor::TraverseTemplateArgumentLoc(TemplateArgumentLoc Loc) { + match(Loc); + return RecursiveASTVisitor<MatchASTVisitor>::TraverseTemplateArgumentLoc(Loc); +} + +bool MatchASTVisitor::TraverseAttr(Attr *AttrNode) { + match(*AttrNode); + return RecursiveASTVisitor<MatchASTVisitor>::TraverseAttr(AttrNode); +} + +class MatchASTConsumer : public ASTConsumer { +public: + MatchASTConsumer(MatchFinder *Finder, + MatchFinder::ParsingDoneTestCallback *ParsingDone) + : Finder(Finder), ParsingDone(ParsingDone) {} + +private: + void HandleTranslationUnit(ASTContext &Context) override { + if (ParsingDone != nullptr) { + ParsingDone->run(); + } + Finder->matchAST(Context); + } + + MatchFinder *Finder; + MatchFinder::ParsingDoneTestCallback *ParsingDone; +}; + +} // end namespace +} // end namespace internal + +MatchFinder::MatchResult::MatchResult(const BoundNodes &Nodes, + ASTContext *Context) + : Nodes(Nodes), Context(Context), + SourceManager(&Context->getSourceManager()) {} + +MatchFinder::MatchCallback::~MatchCallback() {} +MatchFinder::ParsingDoneTestCallback::~ParsingDoneTestCallback() {} + +MatchFinder::MatchFinder(MatchFinderOptions Options) + : Options(std::move(Options)), ParsingDone(nullptr) {} + +MatchFinder::~MatchFinder() {} + +void MatchFinder::addMatcher(const DeclarationMatcher &NodeMatch, + MatchCallback *Action) { + std::optional<TraversalKind> TK; + if (Action) + TK = Action->getCheckTraversalKind(); + if (TK) + Matchers.DeclOrStmt.emplace_back(traverse(*TK, NodeMatch), Action); + else + Matchers.DeclOrStmt.emplace_back(NodeMatch, Action); + Matchers.AllCallbacks.insert(Action); +} + +void MatchFinder::addMatcher(const TypeMatcher &NodeMatch, + MatchCallback *Action) { + Matchers.Type.emplace_back(NodeMatch, Action); + Matchers.AllCallbacks.insert(Action); +} + +void MatchFinder::addMatcher(const StatementMatcher &NodeMatch, + MatchCallback *Action) { + std::optional<TraversalKind> TK; + if (Action) + TK = Action->getCheckTraversalKind(); + if (TK) + Matchers.DeclOrStmt.emplace_back(traverse(*TK, NodeMatch), Action); + else + Matchers.DeclOrStmt.emplace_back(NodeMatch, Action); + Matchers.AllCallbacks.insert(Action); +} + +void MatchFinder::addMatcher(const NestedNameSpecifierMatcher &NodeMatch, + MatchCallback *Action) { + Matchers.NestedNameSpecifier.emplace_back(NodeMatch, Action); + Matchers.AllCallbacks.insert(Action); +} + +void MatchFinder::addMatcher(const NestedNameSpecifierLocMatcher &NodeMatch, + MatchCallback *Action) { + Matchers.NestedNameSpecifierLoc.emplace_back(NodeMatch, Action); + Matchers.AllCallbacks.insert(Action); +} + +void MatchFinder::addMatcher(const TypeLocMatcher &NodeMatch, + MatchCallback *Action) { + Matchers.TypeLoc.emplace_back(NodeMatch, Action); + Matchers.AllCallbacks.insert(Action); +} + +void MatchFinder::addMatcher(const CXXCtorInitializerMatcher &NodeMatch, + MatchCallback *Action) { + Matchers.CtorInit.emplace_back(NodeMatch, Action); + Matchers.AllCallbacks.insert(Action); +} + +void MatchFinder::addMatcher(const TemplateArgumentLocMatcher &NodeMatch, + MatchCallback *Action) { + Matchers.TemplateArgumentLoc.emplace_back(NodeMatch, Action); + Matchers.AllCallbacks.insert(Action); +} + +void MatchFinder::addMatcher(const AttrMatcher &AttrMatch, + MatchCallback *Action) { + Matchers.Attr.emplace_back(AttrMatch, Action); + Matchers.AllCallbacks.insert(Action); +} + +bool MatchFinder::addDynamicMatcher(const internal::DynTypedMatcher &NodeMatch, + MatchCallback *Action) { + if (NodeMatch.canConvertTo<Decl>()) { + addMatcher(NodeMatch.convertTo<Decl>(), Action); + return true; + } else if (NodeMatch.canConvertTo<QualType>()) { + addMatcher(NodeMatch.convertTo<QualType>(), Action); + return true; + } else if (NodeMatch.canConvertTo<Stmt>()) { + addMatcher(NodeMatch.convertTo<Stmt>(), Action); + return true; + } else if (NodeMatch.canConvertTo<NestedNameSpecifier>()) { + addMatcher(NodeMatch.convertTo<NestedNameSpecifier>(), Action); + return true; + } else if (NodeMatch.canConvertTo<NestedNameSpecifierLoc>()) { + addMatcher(NodeMatch.convertTo<NestedNameSpecifierLoc>(), Action); + return true; + } else if (NodeMatch.canConvertTo<TypeLoc>()) { + addMatcher(NodeMatch.convertTo<TypeLoc>(), Action); + return true; + } else if (NodeMatch.canConvertTo<CXXCtorInitializer>()) { + addMatcher(NodeMatch.convertTo<CXXCtorInitializer>(), Action); + return true; + } else if (NodeMatch.canConvertTo<TemplateArgumentLoc>()) { + addMatcher(NodeMatch.convertTo<TemplateArgumentLoc>(), Action); + return true; + } else if (NodeMatch.canConvertTo<Attr>()) { + addMatcher(NodeMatch.convertTo<Attr>(), Action); + return true; + } + return false; +} + +std::unique_ptr<ASTConsumer> MatchFinder::newASTConsumer() { + return std::make_unique<internal::MatchASTConsumer>(this, ParsingDone); +} + +void MatchFinder::match(const clang::DynTypedNode &Node, ASTContext &Context) { + internal::MatchASTVisitor Visitor(&Matchers, Options); + Visitor.set_active_ast_context(&Context); + Visitor.match(Node); +} + +void MatchFinder::matchAST(ASTContext &Context) { + internal::MatchASTVisitor Visitor(&Matchers, Options); + internal::MatchASTVisitor::TraceReporter StackTrace(Visitor); + Visitor.set_active_ast_context(&Context); + Visitor.onStartOfTranslationUnit(); + Visitor.TraverseAST(Context); + Visitor.onEndOfTranslationUnit(); +} + +void MatchFinder::registerTestCallbackAfterParsing( + MatchFinder::ParsingDoneTestCallback *NewParsingDone) { + ParsingDone = NewParsingDone; +} + +StringRef MatchFinder::MatchCallback::getID() const { return "<unknown>"; } + +std::optional<TraversalKind> +MatchFinder::MatchCallback::getCheckTraversalKind() const { + return std::nullopt; +} + +} // end namespace ast_matchers +} // end namespace clang diff --git a/contrib/llvm-project/clang/lib/ASTMatchers/ASTMatchersInternal.cpp b/contrib/llvm-project/clang/lib/ASTMatchers/ASTMatchersInternal.cpp new file mode 100644 index 000000000000..bf87b1aa0992 --- /dev/null +++ b/contrib/llvm-project/clang/lib/ASTMatchers/ASTMatchersInternal.cpp @@ -0,0 +1,1099 @@ +//===- ASTMatchersInternal.cpp - Structural query framework ---------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Implements the base layer of the matcher framework. +// +//===----------------------------------------------------------------------===// + +#include "clang/ASTMatchers/ASTMatchersInternal.h" +#include "clang/AST/ASTContext.h" +#include "clang/AST/ASTTypeTraits.h" +#include "clang/AST/Decl.h" +#include "clang/AST/DeclTemplate.h" +#include "clang/AST/ParentMapContext.h" +#include "clang/AST/PrettyPrinter.h" +#include "clang/ASTMatchers/ASTMatchers.h" +#include "clang/Basic/LLVM.h" +#include "clang/Lex/Lexer.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/IntrusiveRefCntPtr.h" +#include "llvm/ADT/SmallString.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/ManagedStatic.h" +#include "llvm/Support/Regex.h" +#include "llvm/Support/WithColor.h" +#include "llvm/Support/raw_ostream.h" +#include <algorithm> +#include <cassert> +#include <cstddef> +#include <optional> +#include <string> +#include <utility> +#include <vector> + +namespace clang { +namespace ast_matchers { + +AST_MATCHER_P(ObjCMessageExpr, hasAnySelectorMatcher, std::vector<std::string>, + Matches) { + return llvm::is_contained(Matches, Node.getSelector().getAsString()); +} + +namespace internal { + +static bool notUnaryOperator(const DynTypedNode &DynNode, + ASTMatchFinder *Finder, + BoundNodesTreeBuilder *Builder, + ArrayRef<DynTypedMatcher> InnerMatchers); + +static bool allOfVariadicOperator(const DynTypedNode &DynNode, + ASTMatchFinder *Finder, + BoundNodesTreeBuilder *Builder, + ArrayRef<DynTypedMatcher> InnerMatchers); + +static bool eachOfVariadicOperator(const DynTypedNode &DynNode, + ASTMatchFinder *Finder, + BoundNodesTreeBuilder *Builder, + ArrayRef<DynTypedMatcher> InnerMatchers); + +static bool anyOfVariadicOperator(const DynTypedNode &DynNode, + ASTMatchFinder *Finder, + BoundNodesTreeBuilder *Builder, + ArrayRef<DynTypedMatcher> InnerMatchers); + +static bool optionallyVariadicOperator(const DynTypedNode &DynNode, + ASTMatchFinder *Finder, + BoundNodesTreeBuilder *Builder, + ArrayRef<DynTypedMatcher> InnerMatchers); + +bool matchesAnyBase(const CXXRecordDecl &Node, + const Matcher<CXXBaseSpecifier> &BaseSpecMatcher, + ASTMatchFinder *Finder, BoundNodesTreeBuilder *Builder) { + if (!Node.hasDefinition()) + return false; + + CXXBasePaths Paths; + Paths.setOrigin(&Node); + + const auto basePredicate = + [Finder, Builder, &BaseSpecMatcher](const CXXBaseSpecifier *BaseSpec, + CXXBasePath &IgnoredParam) { + BoundNodesTreeBuilder Result(*Builder); + if (BaseSpecMatcher.matches(*BaseSpec, Finder, &Result)) { + *Builder = std::move(Result); + return true; + } + return false; + }; + + return Node.lookupInBases(basePredicate, Paths, + /*LookupInDependent =*/true); +} + +void BoundNodesTreeBuilder::visitMatches(Visitor *ResultVisitor) { + if (Bindings.empty()) + Bindings.push_back(BoundNodesMap()); + for (BoundNodesMap &Binding : Bindings) { + ResultVisitor->visitMatch(BoundNodes(Binding)); + } +} + +namespace { + +using VariadicOperatorFunction = bool (*)( + const DynTypedNode &DynNode, ASTMatchFinder *Finder, + BoundNodesTreeBuilder *Builder, ArrayRef<DynTypedMatcher> InnerMatchers); + +template <VariadicOperatorFunction Func> +class VariadicMatcher : public DynMatcherInterface { +public: + VariadicMatcher(std::vector<DynTypedMatcher> InnerMatchers) + : InnerMatchers(std::move(InnerMatchers)) {} + + bool dynMatches(const DynTypedNode &DynNode, ASTMatchFinder *Finder, + BoundNodesTreeBuilder *Builder) const override { + return Func(DynNode, Finder, Builder, InnerMatchers); + } + +private: + std::vector<DynTypedMatcher> InnerMatchers; +}; + +class IdDynMatcher : public DynMatcherInterface { +public: + IdDynMatcher(StringRef ID, + IntrusiveRefCntPtr<DynMatcherInterface> InnerMatcher) + : ID(ID), InnerMatcher(std::move(InnerMatcher)) {} + + bool dynMatches(const DynTypedNode &DynNode, ASTMatchFinder *Finder, + BoundNodesTreeBuilder *Builder) const override { + bool Result = InnerMatcher->dynMatches(DynNode, Finder, Builder); + if (Result) Builder->setBinding(ID, DynNode); + return Result; + } + + std::optional<clang::TraversalKind> TraversalKind() const override { + return InnerMatcher->TraversalKind(); + } + +private: + const std::string ID; + const IntrusiveRefCntPtr<DynMatcherInterface> InnerMatcher; +}; + +/// A matcher that always returns true. +class TrueMatcherImpl : public DynMatcherInterface { +public: + TrueMatcherImpl() = default; + + bool dynMatches(const DynTypedNode &, ASTMatchFinder *, + BoundNodesTreeBuilder *) const override { + return true; + } +}; + +/// A matcher that specifies a particular \c TraversalKind. +/// +/// The kind provided to the constructor overrides any kind that may be +/// specified by the `InnerMatcher`. +class DynTraversalMatcherImpl : public DynMatcherInterface { +public: + explicit DynTraversalMatcherImpl( + clang::TraversalKind TK, + IntrusiveRefCntPtr<DynMatcherInterface> InnerMatcher) + : TK(TK), InnerMatcher(std::move(InnerMatcher)) {} + + bool dynMatches(const DynTypedNode &DynNode, ASTMatchFinder *Finder, + BoundNodesTreeBuilder *Builder) const override { + return this->InnerMatcher->dynMatches(DynNode, Finder, Builder); + } + + std::optional<clang::TraversalKind> TraversalKind() const override { + return TK; + } + +private: + clang::TraversalKind TK; + IntrusiveRefCntPtr<DynMatcherInterface> InnerMatcher; +}; + +} // namespace + +bool ASTMatchFinder::isTraversalIgnoringImplicitNodes() const { + return getASTContext().getParentMapContext().getTraversalKind() == + TK_IgnoreUnlessSpelledInSource; +} + +DynTypedMatcher +DynTypedMatcher::constructVariadic(DynTypedMatcher::VariadicOperator Op, + ASTNodeKind SupportedKind, + std::vector<DynTypedMatcher> InnerMatchers) { + assert(!InnerMatchers.empty() && "Array must not be empty."); + assert(llvm::all_of(InnerMatchers, + [SupportedKind](const DynTypedMatcher &M) { + return M.canConvertTo(SupportedKind); + }) && + "InnerMatchers must be convertible to SupportedKind!"); + + // We must relax the restrict kind here. + // The different operators might deal differently with a mismatch. + // Make it the same as SupportedKind, since that is the broadest type we are + // allowed to accept. + auto RestrictKind = SupportedKind; + + switch (Op) { + case VO_AllOf: + // In the case of allOf() we must pass all the checks, so making + // RestrictKind the most restrictive can save us time. This way we reject + // invalid types earlier and we can elide the kind checks inside the + // matcher. + for (auto &IM : InnerMatchers) { + RestrictKind = + ASTNodeKind::getMostDerivedType(RestrictKind, IM.RestrictKind); + } + return DynTypedMatcher( + SupportedKind, RestrictKind, + new VariadicMatcher<allOfVariadicOperator>(std::move(InnerMatchers))); + + case VO_AnyOf: + return DynTypedMatcher( + SupportedKind, RestrictKind, + new VariadicMatcher<anyOfVariadicOperator>(std::move(InnerMatchers))); + + case VO_EachOf: + return DynTypedMatcher( + SupportedKind, RestrictKind, + new VariadicMatcher<eachOfVariadicOperator>(std::move(InnerMatchers))); + + case VO_Optionally: + return DynTypedMatcher(SupportedKind, RestrictKind, + new VariadicMatcher<optionallyVariadicOperator>( + std::move(InnerMatchers))); + + case VO_UnaryNot: + // FIXME: Implement the Not operator to take a single matcher instead of a + // vector. + return DynTypedMatcher( + SupportedKind, RestrictKind, + new VariadicMatcher<notUnaryOperator>(std::move(InnerMatchers))); + } + llvm_unreachable("Invalid Op value."); +} + +DynTypedMatcher +DynTypedMatcher::constructRestrictedWrapper(const DynTypedMatcher &InnerMatcher, + ASTNodeKind RestrictKind) { + DynTypedMatcher Copy = InnerMatcher; + Copy.RestrictKind = RestrictKind; + return Copy; +} + +DynTypedMatcher DynTypedMatcher::withTraversalKind(TraversalKind TK) { + auto Copy = *this; + Copy.Implementation = + new DynTraversalMatcherImpl(TK, std::move(Copy.Implementation)); + return Copy; +} + +DynTypedMatcher DynTypedMatcher::trueMatcher(ASTNodeKind NodeKind) { + // We only ever need one instance of TrueMatcherImpl, so we create a static + // instance and reuse it to reduce the overhead of the matcher and increase + // the chance of cache hits. + static const llvm::IntrusiveRefCntPtr<TrueMatcherImpl> Instance = + new TrueMatcherImpl(); + return DynTypedMatcher(NodeKind, NodeKind, Instance); +} + +bool DynTypedMatcher::canMatchNodesOfKind(ASTNodeKind Kind) const { + return RestrictKind.isBaseOf(Kind); +} + +DynTypedMatcher DynTypedMatcher::dynCastTo(const ASTNodeKind Kind) const { + auto Copy = *this; + Copy.SupportedKind = Kind; + Copy.RestrictKind = ASTNodeKind::getMostDerivedType(Kind, RestrictKind); + return Copy; +} + +bool DynTypedMatcher::matches(const DynTypedNode &DynNode, + ASTMatchFinder *Finder, + BoundNodesTreeBuilder *Builder) const { + TraversalKindScope RAII(Finder->getASTContext(), + Implementation->TraversalKind()); + + if (Finder->isTraversalIgnoringImplicitNodes() && + Finder->IsMatchingInASTNodeNotSpelledInSource()) + return false; + + if (!Finder->isTraversalIgnoringImplicitNodes() && + Finder->IsMatchingInASTNodeNotAsIs()) + return false; + + auto N = + Finder->getASTContext().getParentMapContext().traverseIgnored(DynNode); + + if (RestrictKind.isBaseOf(N.getNodeKind()) && + Implementation->dynMatches(N, Finder, Builder)) { + return true; + } + // Delete all bindings when a matcher does not match. + // This prevents unexpected exposure of bound nodes in unmatches + // branches of the match tree. + Builder->removeBindings([](const BoundNodesMap &) { return true; }); + return false; +} + +bool DynTypedMatcher::matchesNoKindCheck(const DynTypedNode &DynNode, + ASTMatchFinder *Finder, + BoundNodesTreeBuilder *Builder) const { + TraversalKindScope raii(Finder->getASTContext(), + Implementation->TraversalKind()); + + if (Finder->isTraversalIgnoringImplicitNodes() && + Finder->IsMatchingInASTNodeNotSpelledInSource()) + return false; + + if (!Finder->isTraversalIgnoringImplicitNodes() && + Finder->IsMatchingInASTNodeNotAsIs()) + return false; + + auto N = + Finder->getASTContext().getParentMapContext().traverseIgnored(DynNode); + + assert(RestrictKind.isBaseOf(N.getNodeKind())); + if (Implementation->dynMatches(N, Finder, Builder)) { + return true; + } + // Delete all bindings when a matcher does not match. + // This prevents unexpected exposure of bound nodes in unmatches + // branches of the match tree. + Builder->removeBindings([](const BoundNodesMap &) { return true; }); + return false; +} + +std::optional<DynTypedMatcher> DynTypedMatcher::tryBind(StringRef ID) const { + if (!AllowBind) + return std::nullopt; + auto Result = *this; + Result.Implementation = + new IdDynMatcher(ID, std::move(Result.Implementation)); + return std::move(Result); +} + +bool DynTypedMatcher::canConvertTo(ASTNodeKind To) const { + const auto From = getSupportedKind(); + auto QualKind = ASTNodeKind::getFromNodeKind<QualType>(); + auto TypeKind = ASTNodeKind::getFromNodeKind<Type>(); + /// Mimic the implicit conversions of Matcher<>. + /// - From Matcher<Type> to Matcher<QualType> + if (From.isSame(TypeKind) && To.isSame(QualKind)) return true; + /// - From Matcher<Base> to Matcher<Derived> + return From.isBaseOf(To); +} + +void BoundNodesTreeBuilder::addMatch(const BoundNodesTreeBuilder &Other) { + Bindings.append(Other.Bindings.begin(), Other.Bindings.end()); +} + +static bool notUnaryOperator(const DynTypedNode &DynNode, + ASTMatchFinder *Finder, + BoundNodesTreeBuilder *Builder, + ArrayRef<DynTypedMatcher> InnerMatchers) { + if (InnerMatchers.size() != 1) + return false; + + // The 'unless' matcher will always discard the result: + // If the inner matcher doesn't match, unless returns true, + // but the inner matcher cannot have bound anything. + // If the inner matcher matches, the result is false, and + // any possible binding will be discarded. + // We still need to hand in all the bound nodes up to this + // point so the inner matcher can depend on bound nodes, + // and we need to actively discard the bound nodes, otherwise + // the inner matcher will reset the bound nodes if it doesn't + // match, but this would be inversed by 'unless'. + BoundNodesTreeBuilder Discard(*Builder); + return !InnerMatchers[0].matches(DynNode, Finder, &Discard); +} + +static bool allOfVariadicOperator(const DynTypedNode &DynNode, + ASTMatchFinder *Finder, + BoundNodesTreeBuilder *Builder, + ArrayRef<DynTypedMatcher> InnerMatchers) { + // allOf leads to one matcher for each alternative in the first + // matcher combined with each alternative in the second matcher. + // Thus, we can reuse the same Builder. + return llvm::all_of(InnerMatchers, [&](const DynTypedMatcher &InnerMatcher) { + return InnerMatcher.matchesNoKindCheck(DynNode, Finder, Builder); + }); +} + +static bool eachOfVariadicOperator(const DynTypedNode &DynNode, + ASTMatchFinder *Finder, + BoundNodesTreeBuilder *Builder, + ArrayRef<DynTypedMatcher> InnerMatchers) { + BoundNodesTreeBuilder Result; + bool Matched = false; + for (const DynTypedMatcher &InnerMatcher : InnerMatchers) { + BoundNodesTreeBuilder BuilderInner(*Builder); + if (InnerMatcher.matches(DynNode, Finder, &BuilderInner)) { + Matched = true; + Result.addMatch(BuilderInner); + } + } + *Builder = std::move(Result); + return Matched; +} + +static bool anyOfVariadicOperator(const DynTypedNode &DynNode, + ASTMatchFinder *Finder, + BoundNodesTreeBuilder *Builder, + ArrayRef<DynTypedMatcher> InnerMatchers) { + for (const DynTypedMatcher &InnerMatcher : InnerMatchers) { + BoundNodesTreeBuilder Result = *Builder; + if (InnerMatcher.matches(DynNode, Finder, &Result)) { + *Builder = std::move(Result); + return true; + } + } + return false; +} + +static bool +optionallyVariadicOperator(const DynTypedNode &DynNode, ASTMatchFinder *Finder, + BoundNodesTreeBuilder *Builder, + ArrayRef<DynTypedMatcher> InnerMatchers) { + if (InnerMatchers.size() != 1) + return false; + + BoundNodesTreeBuilder Result(*Builder); + if (InnerMatchers[0].matches(DynNode, Finder, &Result)) + *Builder = std::move(Result); + return true; +} + +inline static +std::vector<std::string> vectorFromRefs(ArrayRef<const StringRef *> NameRefs) { + std::vector<std::string> Names; + Names.reserve(NameRefs.size()); + for (auto *Name : NameRefs) + Names.emplace_back(*Name); + return Names; +} + +Matcher<NamedDecl> hasAnyNameFunc(ArrayRef<const StringRef *> NameRefs) { + return internal::Matcher<NamedDecl>( + new internal::HasNameMatcher(vectorFromRefs(NameRefs))); +} + +Matcher<ObjCMessageExpr> hasAnySelectorFunc( + ArrayRef<const StringRef *> NameRefs) { + return hasAnySelectorMatcher(vectorFromRefs(NameRefs)); +} + +HasOpNameMatcher hasAnyOperatorNameFunc(ArrayRef<const StringRef *> NameRefs) { + return HasOpNameMatcher(vectorFromRefs(NameRefs)); +} + +HasOverloadOpNameMatcher +hasAnyOverloadedOperatorNameFunc(ArrayRef<const StringRef *> NameRefs) { + return HasOverloadOpNameMatcher(vectorFromRefs(NameRefs)); +} + +HasNameMatcher::HasNameMatcher(std::vector<std::string> N) + : UseUnqualifiedMatch( + llvm::all_of(N, [](StringRef Name) { return !Name.contains("::"); })), + Names(std::move(N)) { +#ifndef NDEBUG + for (StringRef Name : Names) + assert(!Name.empty()); +#endif +} + +static bool consumeNameSuffix(StringRef &FullName, StringRef Suffix) { + StringRef Name = FullName; + if (!Name.ends_with(Suffix)) + return false; + Name = Name.drop_back(Suffix.size()); + if (!Name.empty()) { + if (!Name.ends_with("::")) + return false; + Name = Name.drop_back(2); + } + FullName = Name; + return true; +} + +static StringRef getNodeName(const NamedDecl &Node, + llvm::SmallString<128> &Scratch) { + // Simple name. + if (Node.getIdentifier()) + return Node.getName(); + + if (Node.getDeclName()) { + // Name needs to be constructed. + Scratch.clear(); + llvm::raw_svector_ostream OS(Scratch); + Node.printName(OS); + return OS.str(); + } + + return "(anonymous)"; +} + +static StringRef getNodeName(const RecordDecl &Node, + llvm::SmallString<128> &Scratch) { + if (Node.getIdentifier()) { + return Node.getName(); + } + Scratch.clear(); + return ("(anonymous " + Node.getKindName() + ")").toStringRef(Scratch); +} + +static StringRef getNodeName(const NamespaceDecl &Node, + llvm::SmallString<128> &Scratch) { + return Node.isAnonymousNamespace() ? "(anonymous namespace)" : Node.getName(); +} + +namespace { + +class PatternSet { +public: + PatternSet(ArrayRef<std::string> Names) { + Patterns.reserve(Names.size()); + for (StringRef Name : Names) + Patterns.push_back({Name, Name.starts_with("::")}); + } + + /// Consumes the name suffix from each pattern in the set and removes the ones + /// that didn't match. + /// Return true if there are still any patterns left. + bool consumeNameSuffix(StringRef NodeName, bool CanSkip) { + for (size_t I = 0; I < Patterns.size();) { + if (::clang::ast_matchers::internal::consumeNameSuffix(Patterns[I].P, + NodeName) || + CanSkip) { + ++I; + } else { + Patterns.erase(Patterns.begin() + I); + } + } + return !Patterns.empty(); + } + + /// Check if any of the patterns are a match. + /// A match will be a pattern that was fully consumed, that also matches the + /// 'fully qualified' requirement. + bool foundMatch(bool AllowFullyQualified) const { + return llvm::any_of(Patterns, [&](const Pattern &Pattern) { + return Pattern.P.empty() && + (AllowFullyQualified || !Pattern.IsFullyQualified); + }); + } + +private: + struct Pattern { + StringRef P; + bool IsFullyQualified; + }; + + llvm::SmallVector<Pattern, 8> Patterns; +}; + +} // namespace + +bool HasNameMatcher::matchesNodeUnqualified(const NamedDecl &Node) const { + assert(UseUnqualifiedMatch); + llvm::SmallString<128> Scratch; + StringRef NodeName = getNodeName(Node, Scratch); + return llvm::any_of(Names, [&](StringRef Name) { + return consumeNameSuffix(Name, NodeName) && Name.empty(); + }); +} + +bool HasNameMatcher::matchesNodeFullFast(const NamedDecl &Node) const { + PatternSet Patterns(Names); + llvm::SmallString<128> Scratch; + + // This function is copied and adapted from NamedDecl::printQualifiedName() + // By matching each part individually we optimize in a couple of ways: + // - We can exit early on the first failure. + // - We can skip inline/anonymous namespaces without another pass. + // - We print one name at a time, reducing the chance of overflowing the + // inlined space of the SmallString. + + // First, match the name. + if (!Patterns.consumeNameSuffix(getNodeName(Node, Scratch), + /*CanSkip=*/false)) + return false; + + // Try to match each declaration context. + // We are allowed to skip anonymous and inline namespaces if they don't match. + const DeclContext *Ctx = Node.getDeclContext(); + + if (Ctx->isFunctionOrMethod()) + return Patterns.foundMatch(/*AllowFullyQualified=*/false); + + for (; Ctx; Ctx = Ctx->getParent()) { + // Linkage Spec can just be ignored + // FIXME: Any other DeclContext kinds that can be safely disregarded + if (isa<LinkageSpecDecl>(Ctx)) + continue; + if (!isa<NamedDecl>(Ctx)) + break; + if (Patterns.foundMatch(/*AllowFullyQualified=*/false)) + return true; + + if (const auto *ND = dyn_cast<NamespaceDecl>(Ctx)) { + // If it matches (or we can skip it), continue. + if (Patterns.consumeNameSuffix(getNodeName(*ND, Scratch), + /*CanSkip=*/ND->isAnonymousNamespace() || + ND->isInline())) + continue; + return false; + } + if (const auto *RD = dyn_cast<RecordDecl>(Ctx)) { + if (!isa<ClassTemplateSpecializationDecl>(Ctx)) { + if (Patterns.consumeNameSuffix(getNodeName(*RD, Scratch), + /*CanSkip=*/false)) + continue; + + return false; + } + } + + // We don't know how to deal with this DeclContext. + // Fallback to the slow version of the code. + return matchesNodeFullSlow(Node); + } + + return Patterns.foundMatch(/*AllowFullyQualified=*/true); +} + +bool HasNameMatcher::matchesNodeFullSlow(const NamedDecl &Node) const { + const bool SkipUnwrittenCases[] = {false, true}; + for (bool SkipUnwritten : SkipUnwrittenCases) { + llvm::SmallString<128> NodeName = StringRef("::"); + llvm::raw_svector_ostream OS(NodeName); + + PrintingPolicy Policy = Node.getASTContext().getPrintingPolicy(); + Policy.SuppressUnwrittenScope = SkipUnwritten; + Policy.SuppressInlineNamespace = SkipUnwritten; + Node.printQualifiedName(OS, Policy); + + const StringRef FullName = OS.str(); + + for (const StringRef Pattern : Names) { + if (Pattern.starts_with("::")) { + if (FullName == Pattern) + return true; + } else if (FullName.ends_with(Pattern) && + FullName.drop_back(Pattern.size()).ends_with("::")) { + return true; + } + } + } + + return false; +} + +bool HasNameMatcher::matchesNode(const NamedDecl &Node) const { + assert(matchesNodeFullFast(Node) == matchesNodeFullSlow(Node)); + if (UseUnqualifiedMatch) { + assert(matchesNodeUnqualified(Node) == matchesNodeFullFast(Node)); + return matchesNodeUnqualified(Node); + } + return matchesNodeFullFast(Node); +} + +// Checks whether \p Loc points to a token with source text of \p TokenText. +static bool isTokenAtLoc(const SourceManager &SM, const LangOptions &LangOpts, + StringRef Text, SourceLocation Loc) { + llvm::SmallString<16> Buffer; + bool Invalid = false; + // Since `Loc` may point into an expansion buffer, which has no corresponding + // source, we need to look at the spelling location to read the actual source. + StringRef TokenText = Lexer::getSpelling(SM.getSpellingLoc(Loc), Buffer, SM, + LangOpts, &Invalid); + return !Invalid && Text == TokenText; +} + +std::optional<SourceLocation> +getExpansionLocOfMacro(StringRef MacroName, SourceLocation Loc, + const ASTContext &Context) { + auto &SM = Context.getSourceManager(); + const LangOptions &LangOpts = Context.getLangOpts(); + while (Loc.isMacroID()) { + SrcMgr::ExpansionInfo Expansion = + SM.getSLocEntry(SM.getFileID(Loc)).getExpansion(); + if (Expansion.isMacroArgExpansion()) + // Check macro argument for an expansion of the given macro. For example, + // `F(G(3))`, where `MacroName` is `G`. + if (std::optional<SourceLocation> ArgLoc = getExpansionLocOfMacro( + MacroName, Expansion.getSpellingLoc(), Context)) + return ArgLoc; + Loc = Expansion.getExpansionLocStart(); + if (isTokenAtLoc(SM, LangOpts, MacroName, Loc)) + return Loc; + } + return std::nullopt; +} + +std::shared_ptr<llvm::Regex> createAndVerifyRegex(StringRef Regex, + llvm::Regex::RegexFlags Flags, + StringRef MatcherID) { + assert(!Regex.empty() && "Empty regex string"); + auto SharedRegex = std::make_shared<llvm::Regex>(Regex, Flags); + std::string Error; + if (!SharedRegex->isValid(Error)) { + llvm::WithColor::error() + << "building matcher '" << MatcherID << "': " << Error << "\n"; + llvm::WithColor::note() << " input was '" << Regex << "'\n"; + } + return SharedRegex; +} +} // end namespace internal + +const internal::VariadicDynCastAllOfMatcher<Stmt, ObjCAutoreleasePoolStmt> + autoreleasePoolStmt; +const internal::VariadicDynCastAllOfMatcher<Decl, TranslationUnitDecl> + translationUnitDecl; +const internal::VariadicDynCastAllOfMatcher<Decl, TypedefDecl> typedefDecl; +const internal::VariadicDynCastAllOfMatcher<Decl, TypedefNameDecl> + typedefNameDecl; +const internal::VariadicDynCastAllOfMatcher<Decl, TypeAliasDecl> typeAliasDecl; +const internal::VariadicDynCastAllOfMatcher<Decl, TypeAliasTemplateDecl> + typeAliasTemplateDecl; +const internal::VariadicAllOfMatcher<Decl> decl; +const internal::VariadicDynCastAllOfMatcher<Decl, DecompositionDecl> decompositionDecl; +const internal::VariadicDynCastAllOfMatcher<Decl, BindingDecl> bindingDecl; +const internal::VariadicDynCastAllOfMatcher<Decl, LinkageSpecDecl> + linkageSpecDecl; +const internal::VariadicDynCastAllOfMatcher<Decl, NamedDecl> namedDecl; +const internal::VariadicDynCastAllOfMatcher<Decl, LabelDecl> labelDecl; +const internal::VariadicDynCastAllOfMatcher<Decl, NamespaceDecl> namespaceDecl; +const internal::VariadicDynCastAllOfMatcher<Decl, NamespaceAliasDecl> + namespaceAliasDecl; +const internal::VariadicDynCastAllOfMatcher<Decl, RecordDecl> recordDecl; +const internal::VariadicDynCastAllOfMatcher<Decl, CXXRecordDecl> cxxRecordDecl; +const internal::VariadicDynCastAllOfMatcher<Decl, ClassTemplateDecl> + classTemplateDecl; +const internal::VariadicDynCastAllOfMatcher<Decl, + ClassTemplateSpecializationDecl> + classTemplateSpecializationDecl; +const internal::VariadicDynCastAllOfMatcher< + Decl, ClassTemplatePartialSpecializationDecl> + classTemplatePartialSpecializationDecl; +const internal::VariadicDynCastAllOfMatcher<Decl, DeclaratorDecl> + declaratorDecl; +const internal::VariadicDynCastAllOfMatcher<Decl, ParmVarDecl> parmVarDecl; +const internal::VariadicDynCastAllOfMatcher<Decl, AccessSpecDecl> + accessSpecDecl; +const internal::VariadicAllOfMatcher<CXXBaseSpecifier> cxxBaseSpecifier; +const internal::VariadicAllOfMatcher<CXXCtorInitializer> cxxCtorInitializer; +const internal::VariadicAllOfMatcher<TemplateArgument> templateArgument; +const internal::VariadicAllOfMatcher<TemplateArgumentLoc> templateArgumentLoc; +const internal::VariadicAllOfMatcher<TemplateName> templateName; +const internal::VariadicDynCastAllOfMatcher<Decl, NonTypeTemplateParmDecl> + nonTypeTemplateParmDecl; +const internal::VariadicDynCastAllOfMatcher<Decl, TemplateTypeParmDecl> + templateTypeParmDecl; +const internal::VariadicDynCastAllOfMatcher<Decl, TemplateTemplateParmDecl> + templateTemplateParmDecl; + +const internal::VariadicAllOfMatcher<LambdaCapture> lambdaCapture; +const internal::VariadicAllOfMatcher<QualType> qualType; +const internal::VariadicAllOfMatcher<Type> type; +const internal::VariadicAllOfMatcher<TypeLoc> typeLoc; + +const internal::VariadicDynCastAllOfMatcher<TypeLoc, QualifiedTypeLoc> + qualifiedTypeLoc; +const internal::VariadicDynCastAllOfMatcher<TypeLoc, PointerTypeLoc> + pointerTypeLoc; +const internal::VariadicDynCastAllOfMatcher<TypeLoc, ReferenceTypeLoc> + referenceTypeLoc; +const internal::VariadicDynCastAllOfMatcher<TypeLoc, + TemplateSpecializationTypeLoc> + templateSpecializationTypeLoc; +const internal::VariadicDynCastAllOfMatcher<TypeLoc, ElaboratedTypeLoc> + elaboratedTypeLoc; + +const internal::VariadicDynCastAllOfMatcher<Stmt, UnaryExprOrTypeTraitExpr> + unaryExprOrTypeTraitExpr; +const internal::VariadicDynCastAllOfMatcher<Decl, ValueDecl> valueDecl; +const internal::VariadicDynCastAllOfMatcher<Decl, CXXConstructorDecl> + cxxConstructorDecl; +const internal::VariadicDynCastAllOfMatcher<Decl, CXXDestructorDecl> + cxxDestructorDecl; +const internal::VariadicDynCastAllOfMatcher<Decl, EnumDecl> enumDecl; +const internal::VariadicDynCastAllOfMatcher<Decl, EnumConstantDecl> + enumConstantDecl; +const internal::VariadicDynCastAllOfMatcher<Decl, TagDecl> tagDecl; +const internal::VariadicDynCastAllOfMatcher<Decl, CXXMethodDecl> cxxMethodDecl; +const internal::VariadicDynCastAllOfMatcher<Decl, CXXConversionDecl> + cxxConversionDecl; +const internal::VariadicDynCastAllOfMatcher<Decl, ConceptDecl> conceptDecl; +const internal::VariadicDynCastAllOfMatcher<Decl, VarDecl> varDecl; +const internal::VariadicDynCastAllOfMatcher<Decl, FieldDecl> fieldDecl; +const internal::VariadicDynCastAllOfMatcher<Decl, IndirectFieldDecl> + indirectFieldDecl; +const internal::VariadicDynCastAllOfMatcher<Decl, FunctionDecl> functionDecl; +const internal::VariadicDynCastAllOfMatcher<Decl, FunctionTemplateDecl> + functionTemplateDecl; +const internal::VariadicDynCastAllOfMatcher<Decl, FriendDecl> friendDecl; +const internal::VariadicAllOfMatcher<Stmt> stmt; +const internal::VariadicDynCastAllOfMatcher<Stmt, DeclStmt> declStmt; +const internal::VariadicDynCastAllOfMatcher<Stmt, MemberExpr> memberExpr; +const internal::VariadicDynCastAllOfMatcher<Stmt, UnresolvedMemberExpr> + unresolvedMemberExpr; +const internal::VariadicDynCastAllOfMatcher<Stmt, CXXDependentScopeMemberExpr> + cxxDependentScopeMemberExpr; +const internal::VariadicDynCastAllOfMatcher<Stmt, CallExpr> callExpr; +const internal::VariadicDynCastAllOfMatcher<Stmt, LambdaExpr> lambdaExpr; +const internal::VariadicDynCastAllOfMatcher<Stmt, CXXMemberCallExpr> + cxxMemberCallExpr; +const internal::VariadicDynCastAllOfMatcher<Stmt, ObjCMessageExpr> + objcMessageExpr; +const internal::VariadicDynCastAllOfMatcher<Decl, ObjCInterfaceDecl> + objcInterfaceDecl; +const internal::VariadicDynCastAllOfMatcher<Decl, ObjCImplementationDecl> + objcImplementationDecl; +const internal::VariadicDynCastAllOfMatcher<Decl, ObjCProtocolDecl> + objcProtocolDecl; +const internal::VariadicDynCastAllOfMatcher<Decl, ObjCCategoryDecl> + objcCategoryDecl; +const internal::VariadicDynCastAllOfMatcher<Decl, ObjCCategoryImplDecl> + objcCategoryImplDecl; +const internal::VariadicDynCastAllOfMatcher<Decl, ObjCMethodDecl> + objcMethodDecl; +const internal::VariadicDynCastAllOfMatcher<Decl, BlockDecl> + blockDecl; +const internal::VariadicDynCastAllOfMatcher<Decl, ObjCIvarDecl> objcIvarDecl; +const internal::VariadicDynCastAllOfMatcher<Decl, ObjCPropertyDecl> + objcPropertyDecl; +const internal::VariadicDynCastAllOfMatcher<Stmt, ObjCAtThrowStmt> + objcThrowStmt; +const internal::VariadicDynCastAllOfMatcher<Stmt, ObjCAtTryStmt> objcTryStmt; +const internal::VariadicDynCastAllOfMatcher<Stmt, ObjCAtCatchStmt> + objcCatchStmt; +const internal::VariadicDynCastAllOfMatcher<Stmt, ObjCAtFinallyStmt> + objcFinallyStmt; +const internal::VariadicDynCastAllOfMatcher<Stmt, ExprWithCleanups> + exprWithCleanups; +const internal::VariadicDynCastAllOfMatcher<Stmt, InitListExpr> initListExpr; +const internal::VariadicDynCastAllOfMatcher<Stmt, CXXStdInitializerListExpr> + cxxStdInitializerListExpr; +const internal::VariadicDynCastAllOfMatcher<Stmt, ImplicitValueInitExpr> + implicitValueInitExpr; +const internal::VariadicDynCastAllOfMatcher<Stmt, ParenListExpr> parenListExpr; +const internal::VariadicDynCastAllOfMatcher<Stmt, SubstNonTypeTemplateParmExpr> + substNonTypeTemplateParmExpr; +const internal::VariadicDynCastAllOfMatcher<Decl, UsingDecl> usingDecl; +const internal::VariadicDynCastAllOfMatcher<Decl, UsingEnumDecl> usingEnumDecl; +const internal::VariadicDynCastAllOfMatcher<Decl, UsingDirectiveDecl> + usingDirectiveDecl; +const internal::VariadicDynCastAllOfMatcher<Stmt, UnresolvedLookupExpr> + unresolvedLookupExpr; +const internal::VariadicDynCastAllOfMatcher<Decl, UnresolvedUsingValueDecl> + unresolvedUsingValueDecl; +const internal::VariadicDynCastAllOfMatcher<Decl, UnresolvedUsingTypenameDecl> + unresolvedUsingTypenameDecl; +const internal::VariadicDynCastAllOfMatcher<Stmt, ConstantExpr> constantExpr; +const internal::VariadicDynCastAllOfMatcher<Stmt, ParenExpr> parenExpr; +const internal::VariadicDynCastAllOfMatcher<Stmt, CXXConstructExpr> + cxxConstructExpr; +const internal::VariadicDynCastAllOfMatcher<Stmt, CXXUnresolvedConstructExpr> + cxxUnresolvedConstructExpr; +const internal::VariadicDynCastAllOfMatcher<Stmt, CXXThisExpr> cxxThisExpr; +const internal::VariadicDynCastAllOfMatcher<Stmt, CXXBindTemporaryExpr> + cxxBindTemporaryExpr; +const internal::VariadicDynCastAllOfMatcher<Stmt, MaterializeTemporaryExpr> + materializeTemporaryExpr; +const internal::VariadicDynCastAllOfMatcher<Stmt, CXXNewExpr> cxxNewExpr; +const internal::VariadicDynCastAllOfMatcher<Stmt, CXXDeleteExpr> cxxDeleteExpr; +const internal::VariadicDynCastAllOfMatcher<Stmt, CXXNoexceptExpr> + cxxNoexceptExpr; +const internal::VariadicDynCastAllOfMatcher<Stmt, ArraySubscriptExpr> + arraySubscriptExpr; +const internal::VariadicDynCastAllOfMatcher<Stmt, ArrayInitIndexExpr> + arrayInitIndexExpr; +const internal::VariadicDynCastAllOfMatcher<Stmt, ArrayInitLoopExpr> + arrayInitLoopExpr; +const internal::VariadicDynCastAllOfMatcher<Stmt, CXXDefaultArgExpr> + cxxDefaultArgExpr; +const internal::VariadicDynCastAllOfMatcher<Stmt, CXXOperatorCallExpr> + cxxOperatorCallExpr; +const internal::VariadicDynCastAllOfMatcher<Stmt, CXXRewrittenBinaryOperator> + cxxRewrittenBinaryOperator; +const internal::VariadicDynCastAllOfMatcher<Stmt, CXXFoldExpr> cxxFoldExpr; +const internal::VariadicDynCastAllOfMatcher<Stmt, Expr> expr; +const internal::VariadicDynCastAllOfMatcher<Stmt, DeclRefExpr> declRefExpr; +const internal::VariadicDynCastAllOfMatcher<Stmt, ObjCIvarRefExpr> objcIvarRefExpr; +const internal::VariadicDynCastAllOfMatcher<Stmt, BlockExpr> blockExpr; +const internal::VariadicDynCastAllOfMatcher<Stmt, IfStmt> ifStmt; +const internal::VariadicDynCastAllOfMatcher<Stmt, ForStmt> forStmt; +const internal::VariadicDynCastAllOfMatcher<Stmt, CXXForRangeStmt> + cxxForRangeStmt; +const internal::VariadicDynCastAllOfMatcher<Stmt, WhileStmt> whileStmt; +const internal::VariadicDynCastAllOfMatcher<Stmt, DoStmt> doStmt; +const internal::VariadicDynCastAllOfMatcher<Stmt, BreakStmt> breakStmt; +const internal::VariadicDynCastAllOfMatcher<Stmt, ContinueStmt> continueStmt; +const internal::VariadicDynCastAllOfMatcher<Stmt, CoreturnStmt> coreturnStmt; +const internal::VariadicDynCastAllOfMatcher<Stmt, ReturnStmt> returnStmt; +const internal::VariadicDynCastAllOfMatcher<Stmt, GotoStmt> gotoStmt; +const internal::VariadicDynCastAllOfMatcher<Stmt, LabelStmt> labelStmt; +const internal::VariadicDynCastAllOfMatcher<Stmt, AddrLabelExpr> addrLabelExpr; +const internal::VariadicDynCastAllOfMatcher<Stmt, SwitchStmt> switchStmt; +const internal::VariadicDynCastAllOfMatcher<Stmt, SwitchCase> switchCase; +const internal::VariadicDynCastAllOfMatcher<Stmt, CaseStmt> caseStmt; +const internal::VariadicDynCastAllOfMatcher<Stmt, DefaultStmt> defaultStmt; +const internal::VariadicDynCastAllOfMatcher<Stmt, CompoundStmt> compoundStmt; +const internal::VariadicDynCastAllOfMatcher<Stmt, CoroutineBodyStmt> + coroutineBodyStmt; +const internal::VariadicDynCastAllOfMatcher<Stmt, CXXCatchStmt> cxxCatchStmt; +const internal::VariadicDynCastAllOfMatcher<Stmt, CXXTryStmt> cxxTryStmt; +const internal::VariadicDynCastAllOfMatcher<Stmt, CXXThrowExpr> cxxThrowExpr; +const internal::VariadicDynCastAllOfMatcher<Stmt, NullStmt> nullStmt; +const internal::VariadicDynCastAllOfMatcher<Stmt, AsmStmt> asmStmt; +const internal::VariadicDynCastAllOfMatcher<Stmt, CXXBoolLiteralExpr> + cxxBoolLiteral; +const internal::VariadicDynCastAllOfMatcher<Stmt, StringLiteral> stringLiteral; +const internal::VariadicDynCastAllOfMatcher<Stmt, ObjCStringLiteral> objcStringLiteral; +const internal::VariadicDynCastAllOfMatcher<Stmt, CharacterLiteral> + characterLiteral; +const internal::VariadicDynCastAllOfMatcher<Stmt, IntegerLiteral> + integerLiteral; +const internal::VariadicDynCastAllOfMatcher<Stmt, FloatingLiteral> floatLiteral; +const internal::VariadicDynCastAllOfMatcher<Stmt, ImaginaryLiteral> imaginaryLiteral; +const internal::VariadicDynCastAllOfMatcher<Stmt, FixedPointLiteral> + fixedPointLiteral; +const internal::VariadicDynCastAllOfMatcher<Stmt, UserDefinedLiteral> + userDefinedLiteral; +const internal::VariadicDynCastAllOfMatcher<Stmt, CompoundLiteralExpr> + compoundLiteralExpr; +const internal::VariadicDynCastAllOfMatcher<Stmt, CXXNullPtrLiteralExpr> + cxxNullPtrLiteralExpr; +const internal::VariadicDynCastAllOfMatcher<Stmt, ChooseExpr> chooseExpr; +const internal::VariadicDynCastAllOfMatcher<Stmt, ConvertVectorExpr> + convertVectorExpr; +const internal::VariadicDynCastAllOfMatcher<Stmt, CoawaitExpr> + coawaitExpr; +const internal::VariadicDynCastAllOfMatcher<Stmt, DependentCoawaitExpr> + dependentCoawaitExpr; +const internal::VariadicDynCastAllOfMatcher<Stmt, CoyieldExpr> + coyieldExpr; +const internal::VariadicDynCastAllOfMatcher<Stmt, GNUNullExpr> gnuNullExpr; +const internal::VariadicDynCastAllOfMatcher<Stmt, GenericSelectionExpr> + genericSelectionExpr; +const internal::VariadicDynCastAllOfMatcher<Stmt, AtomicExpr> atomicExpr; +const internal::VariadicDynCastAllOfMatcher<Stmt, StmtExpr> stmtExpr; +const internal::VariadicDynCastAllOfMatcher<Stmt, BinaryOperator> + binaryOperator; +const internal::MapAnyOfMatcher<BinaryOperator, CXXOperatorCallExpr, + CXXRewrittenBinaryOperator> + binaryOperation; +const internal::MapAnyOfMatcher<CallExpr, CXXConstructExpr> invocation; +const internal::VariadicDynCastAllOfMatcher<Stmt, UnaryOperator> unaryOperator; +const internal::VariadicDynCastAllOfMatcher<Stmt, ConditionalOperator> + conditionalOperator; +const internal::VariadicDynCastAllOfMatcher<Stmt, BinaryConditionalOperator> + binaryConditionalOperator; +const internal::VariadicDynCastAllOfMatcher<Stmt, OpaqueValueExpr> + opaqueValueExpr; +const internal::VariadicDynCastAllOfMatcher<Decl, StaticAssertDecl> + staticAssertDecl; +const internal::VariadicDynCastAllOfMatcher<Stmt, CXXReinterpretCastExpr> + cxxReinterpretCastExpr; +const internal::VariadicDynCastAllOfMatcher<Stmt, CXXStaticCastExpr> + cxxStaticCastExpr; +const internal::VariadicDynCastAllOfMatcher<Stmt, CXXDynamicCastExpr> + cxxDynamicCastExpr; +const internal::VariadicDynCastAllOfMatcher<Stmt, CXXConstCastExpr> + cxxConstCastExpr; +const internal::VariadicDynCastAllOfMatcher<Stmt, CStyleCastExpr> + cStyleCastExpr; +const internal::VariadicDynCastAllOfMatcher<Stmt, ExplicitCastExpr> + explicitCastExpr; +const internal::VariadicDynCastAllOfMatcher<Stmt, ImplicitCastExpr> + implicitCastExpr; +const internal::VariadicDynCastAllOfMatcher<Stmt, CastExpr> castExpr; +const internal::VariadicDynCastAllOfMatcher<Stmt, CXXFunctionalCastExpr> + cxxFunctionalCastExpr; +const internal::VariadicDynCastAllOfMatcher<Stmt, CXXTemporaryObjectExpr> + cxxTemporaryObjectExpr; +const internal::VariadicDynCastAllOfMatcher<Stmt, PredefinedExpr> + predefinedExpr; +const internal::VariadicDynCastAllOfMatcher<Stmt, DesignatedInitExpr> + designatedInitExpr; +const internal::VariadicOperatorMatcherFunc< + 2, std::numeric_limits<unsigned>::max()> + eachOf = {internal::DynTypedMatcher::VO_EachOf}; +const internal::VariadicOperatorMatcherFunc< + 2, std::numeric_limits<unsigned>::max()> + anyOf = {internal::DynTypedMatcher::VO_AnyOf}; +const internal::VariadicOperatorMatcherFunc< + 2, std::numeric_limits<unsigned>::max()> + allOf = {internal::DynTypedMatcher::VO_AllOf}; +const internal::VariadicOperatorMatcherFunc<1, 1> optionally = { + internal::DynTypedMatcher::VO_Optionally}; +const internal::VariadicFunction<internal::Matcher<NamedDecl>, StringRef, + internal::hasAnyNameFunc> + hasAnyName = {}; + +const internal::VariadicFunction<internal::HasOpNameMatcher, StringRef, + internal::hasAnyOperatorNameFunc> + hasAnyOperatorName = {}; +const internal::VariadicFunction<internal::HasOverloadOpNameMatcher, StringRef, + internal::hasAnyOverloadedOperatorNameFunc> + hasAnyOverloadedOperatorName = {}; +const internal::VariadicFunction<internal::Matcher<ObjCMessageExpr>, StringRef, + internal::hasAnySelectorFunc> + hasAnySelector = {}; +const internal::ArgumentAdaptingMatcherFunc<internal::HasMatcher> has = {}; +const internal::ArgumentAdaptingMatcherFunc<internal::HasDescendantMatcher> + hasDescendant = {}; +const internal::ArgumentAdaptingMatcherFunc<internal::ForEachMatcher> forEach = + {}; +const internal::ArgumentAdaptingMatcherFunc<internal::ForEachDescendantMatcher> + forEachDescendant = {}; +const internal::ArgumentAdaptingMatcherFunc< + internal::HasParentMatcher, + internal::TypeList<Decl, NestedNameSpecifierLoc, Stmt, TypeLoc, Attr>, + internal::TypeList<Decl, NestedNameSpecifierLoc, Stmt, TypeLoc, Attr>> + hasParent = {}; +const internal::ArgumentAdaptingMatcherFunc< + internal::HasAncestorMatcher, + internal::TypeList<Decl, NestedNameSpecifierLoc, Stmt, TypeLoc, Attr>, + internal::TypeList<Decl, NestedNameSpecifierLoc, Stmt, TypeLoc, Attr>> + hasAncestor = {}; +const internal::VariadicOperatorMatcherFunc<1, 1> unless = { + internal::DynTypedMatcher::VO_UnaryNot}; +const internal::VariadicAllOfMatcher<NestedNameSpecifier> nestedNameSpecifier; +const internal::VariadicAllOfMatcher<NestedNameSpecifierLoc> + nestedNameSpecifierLoc; +const internal::VariadicAllOfMatcher<Attr> attr; +const internal::VariadicDynCastAllOfMatcher<Stmt, CUDAKernelCallExpr> + cudaKernelCallExpr; +const AstTypeMatcher<BuiltinType> builtinType; +const AstTypeMatcher<ArrayType> arrayType; +const AstTypeMatcher<ComplexType> complexType; +const AstTypeMatcher<ConstantArrayType> constantArrayType; +const AstTypeMatcher<DeducedTemplateSpecializationType> + deducedTemplateSpecializationType; +const AstTypeMatcher<DependentSizedArrayType> dependentSizedArrayType; +const AstTypeMatcher<DependentSizedExtVectorType> dependentSizedExtVectorType; +const AstTypeMatcher<IncompleteArrayType> incompleteArrayType; +const AstTypeMatcher<VariableArrayType> variableArrayType; +const AstTypeMatcher<AtomicType> atomicType; +const AstTypeMatcher<AutoType> autoType; +const AstTypeMatcher<DecltypeType> decltypeType; +const AstTypeMatcher<FunctionType> functionType; +const AstTypeMatcher<FunctionProtoType> functionProtoType; +const AstTypeMatcher<ParenType> parenType; +const AstTypeMatcher<BlockPointerType> blockPointerType; +const AstTypeMatcher<MacroQualifiedType> macroQualifiedType; +const AstTypeMatcher<MemberPointerType> memberPointerType; +const AstTypeMatcher<PointerType> pointerType; +const AstTypeMatcher<ObjCObjectPointerType> objcObjectPointerType; +const AstTypeMatcher<ReferenceType> referenceType; +const AstTypeMatcher<LValueReferenceType> lValueReferenceType; +const AstTypeMatcher<RValueReferenceType> rValueReferenceType; +const AstTypeMatcher<TypedefType> typedefType; +const AstTypeMatcher<EnumType> enumType; +const AstTypeMatcher<TemplateSpecializationType> templateSpecializationType; +const AstTypeMatcher<UnaryTransformType> unaryTransformType; +const AstTypeMatcher<RecordType> recordType; +const AstTypeMatcher<TagType> tagType; +const AstTypeMatcher<ElaboratedType> elaboratedType; +const AstTypeMatcher<UsingType> usingType; +const AstTypeMatcher<SubstTemplateTypeParmType> substTemplateTypeParmType; +const AstTypeMatcher<TemplateTypeParmType> templateTypeParmType; +const AstTypeMatcher<InjectedClassNameType> injectedClassNameType; +const AstTypeMatcher<DecayedType> decayedType; +AST_TYPELOC_TRAVERSE_MATCHER_DEF(hasElementType, + AST_POLYMORPHIC_SUPPORTED_TYPES(ArrayType, + ComplexType)); +AST_TYPELOC_TRAVERSE_MATCHER_DEF(hasValueType, + AST_POLYMORPHIC_SUPPORTED_TYPES(AtomicType)); +AST_TYPELOC_TRAVERSE_MATCHER_DEF( + pointee, + AST_POLYMORPHIC_SUPPORTED_TYPES(BlockPointerType, MemberPointerType, + PointerType, ReferenceType)); + +const internal::VariadicDynCastAllOfMatcher<Stmt, OMPExecutableDirective> + ompExecutableDirective; +const internal::VariadicDynCastAllOfMatcher<OMPClause, OMPDefaultClause> + ompDefaultClause; +const internal::VariadicDynCastAllOfMatcher<Decl, CXXDeductionGuideDecl> + cxxDeductionGuideDecl; + +} // end namespace ast_matchers +} // end namespace clang diff --git a/contrib/llvm-project/clang/lib/ASTMatchers/Dynamic/Diagnostics.cpp b/contrib/llvm-project/clang/lib/ASTMatchers/Dynamic/Diagnostics.cpp new file mode 100644 index 000000000000..41ab0ed70fda --- /dev/null +++ b/contrib/llvm-project/clang/lib/ASTMatchers/Dynamic/Diagnostics.cpp @@ -0,0 +1,231 @@ +//===--- Diagnostics.cpp - Helper class for error diagnostics ---*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "clang/ASTMatchers/Dynamic/Diagnostics.h" + +namespace clang { +namespace ast_matchers { +namespace dynamic { +Diagnostics::ArgStream Diagnostics::pushContextFrame(ContextType Type, + SourceRange Range) { + ContextStack.emplace_back(); + ContextFrame& data = ContextStack.back(); + data.Type = Type; + data.Range = Range; + return ArgStream(&data.Args); +} + +Diagnostics::Context::Context(ConstructMatcherEnum, Diagnostics *Error, + StringRef MatcherName, + SourceRange MatcherRange) + : Error(Error) { + Error->pushContextFrame(CT_MatcherConstruct, MatcherRange) << MatcherName; +} + +Diagnostics::Context::Context(MatcherArgEnum, Diagnostics *Error, + StringRef MatcherName, + SourceRange MatcherRange, + unsigned ArgNumber) + : Error(Error) { + Error->pushContextFrame(CT_MatcherArg, MatcherRange) << ArgNumber + << MatcherName; +} + +Diagnostics::Context::~Context() { Error->ContextStack.pop_back(); } + +Diagnostics::OverloadContext::OverloadContext(Diagnostics *Error) + : Error(Error), BeginIndex(Error->Errors.size()) {} + +Diagnostics::OverloadContext::~OverloadContext() { + // Merge all errors that happened while in this context. + if (BeginIndex < Error->Errors.size()) { + Diagnostics::ErrorContent &Dest = Error->Errors[BeginIndex]; + for (size_t i = BeginIndex + 1, e = Error->Errors.size(); i < e; ++i) { + Dest.Messages.push_back(Error->Errors[i].Messages[0]); + } + Error->Errors.resize(BeginIndex + 1); + } +} + +void Diagnostics::OverloadContext::revertErrors() { + // Revert the errors. + Error->Errors.resize(BeginIndex); +} + +Diagnostics::ArgStream &Diagnostics::ArgStream::operator<<(const Twine &Arg) { + Out->push_back(Arg.str()); + return *this; +} + +Diagnostics::ArgStream Diagnostics::addError(SourceRange Range, + ErrorType Error) { + Errors.emplace_back(); + ErrorContent &Last = Errors.back(); + Last.ContextStack = ContextStack; + Last.Messages.emplace_back(); + Last.Messages.back().Range = Range; + Last.Messages.back().Type = Error; + return ArgStream(&Last.Messages.back().Args); +} + +static StringRef contextTypeToFormatString(Diagnostics::ContextType Type) { + switch (Type) { + case Diagnostics::CT_MatcherConstruct: + return "Error building matcher $0."; + case Diagnostics::CT_MatcherArg: + return "Error parsing argument $0 for matcher $1."; + } + llvm_unreachable("Unknown ContextType value."); +} + +static StringRef errorTypeToFormatString(Diagnostics::ErrorType Type) { + switch (Type) { + case Diagnostics::ET_RegistryMatcherNotFound: + return "Matcher not found: $0"; + case Diagnostics::ET_RegistryWrongArgCount: + return "Incorrect argument count. (Expected = $0) != (Actual = $1)"; + case Diagnostics::ET_RegistryWrongArgType: + return "Incorrect type for arg $0. (Expected = $1) != (Actual = $2)"; + case Diagnostics::ET_RegistryNotBindable: + return "Matcher does not support binding."; + case Diagnostics::ET_RegistryAmbiguousOverload: + // TODO: Add type info about the overload error. + return "Ambiguous matcher overload."; + case Diagnostics::ET_RegistryValueNotFound: + return "Value not found: $0"; + case Diagnostics::ET_RegistryUnknownEnumWithReplace: + return "Unknown value '$1' for arg $0; did you mean '$2'"; + case Diagnostics::ET_RegistryNonNodeMatcher: + return "Matcher not a node matcher: $0"; + case Diagnostics::ET_RegistryMatcherNoWithSupport: + return "Matcher does not support with call."; + + case Diagnostics::ET_ParserStringError: + return "Error parsing string token: <$0>"; + case Diagnostics::ET_ParserNoOpenParen: + return "Error parsing matcher. Found token <$0> while looking for '('."; + case Diagnostics::ET_ParserNoCloseParen: + return "Error parsing matcher. Found end-of-code while looking for ')'."; + case Diagnostics::ET_ParserNoComma: + return "Error parsing matcher. Found token <$0> while looking for ','."; + case Diagnostics::ET_ParserNoCode: + return "End of code found while looking for token."; + case Diagnostics::ET_ParserNotAMatcher: + return "Input value is not a matcher expression."; + case Diagnostics::ET_ParserInvalidToken: + return "Invalid token <$0> found when looking for a value."; + case Diagnostics::ET_ParserMalformedBindExpr: + return "Malformed bind() expression."; + case Diagnostics::ET_ParserTrailingCode: + return "Expected end of code."; + case Diagnostics::ET_ParserNumberError: + return "Error parsing numeric literal: <$0>"; + case Diagnostics::ET_ParserOverloadedType: + return "Input value has unresolved overloaded type: $0"; + case Diagnostics::ET_ParserMalformedChainedExpr: + return "Period not followed by valid chained call."; + case Diagnostics::ET_ParserFailedToBuildMatcher: + return "Failed to build matcher: $0."; + + case Diagnostics::ET_None: + return "<N/A>"; + } + llvm_unreachable("Unknown ErrorType value."); +} + +static void formatErrorString(StringRef FormatString, + ArrayRef<std::string> Args, + llvm::raw_ostream &OS) { + while (!FormatString.empty()) { + std::pair<StringRef, StringRef> Pieces = FormatString.split("$"); + OS << Pieces.first.str(); + if (Pieces.second.empty()) break; + + const char Next = Pieces.second.front(); + FormatString = Pieces.second.drop_front(); + if (Next >= '0' && Next <= '9') { + const unsigned Index = Next - '0'; + if (Index < Args.size()) { + OS << Args[Index]; + } else { + OS << "<Argument_Not_Provided>"; + } + } + } +} + +static void maybeAddLineAndColumn(SourceRange Range, + llvm::raw_ostream &OS) { + if (Range.Start.Line > 0 && Range.Start.Column > 0) { + OS << Range.Start.Line << ":" << Range.Start.Column << ": "; + } +} + +static void printContextFrameToStream(const Diagnostics::ContextFrame &Frame, + llvm::raw_ostream &OS) { + maybeAddLineAndColumn(Frame.Range, OS); + formatErrorString(contextTypeToFormatString(Frame.Type), Frame.Args, OS); +} + +static void +printMessageToStream(const Diagnostics::ErrorContent::Message &Message, + const Twine Prefix, llvm::raw_ostream &OS) { + maybeAddLineAndColumn(Message.Range, OS); + OS << Prefix; + formatErrorString(errorTypeToFormatString(Message.Type), Message.Args, OS); +} + +static void printErrorContentToStream(const Diagnostics::ErrorContent &Content, + llvm::raw_ostream &OS) { + if (Content.Messages.size() == 1) { + printMessageToStream(Content.Messages[0], "", OS); + } else { + for (size_t i = 0, e = Content.Messages.size(); i != e; ++i) { + if (i != 0) OS << "\n"; + printMessageToStream(Content.Messages[i], + "Candidate " + Twine(i + 1) + ": ", OS); + } + } +} + +void Diagnostics::printToStream(llvm::raw_ostream &OS) const { + for (size_t i = 0, e = Errors.size(); i != e; ++i) { + if (i != 0) OS << "\n"; + printErrorContentToStream(Errors[i], OS); + } +} + +std::string Diagnostics::toString() const { + std::string S; + llvm::raw_string_ostream OS(S); + printToStream(OS); + return S; +} + +void Diagnostics::printToStreamFull(llvm::raw_ostream &OS) const { + for (size_t i = 0, e = Errors.size(); i != e; ++i) { + if (i != 0) OS << "\n"; + const ErrorContent &Error = Errors[i]; + for (size_t i = 0, e = Error.ContextStack.size(); i != e; ++i) { + printContextFrameToStream(Error.ContextStack[i], OS); + OS << "\n"; + } + printErrorContentToStream(Error, OS); + } +} + +std::string Diagnostics::toStringFull() const { + std::string S; + llvm::raw_string_ostream OS(S); + printToStreamFull(OS); + return S; +} + +} // namespace dynamic +} // namespace ast_matchers +} // namespace clang diff --git a/contrib/llvm-project/clang/lib/ASTMatchers/Dynamic/Marshallers.cpp b/contrib/llvm-project/clang/lib/ASTMatchers/Dynamic/Marshallers.cpp new file mode 100644 index 000000000000..37c91abb5c83 --- /dev/null +++ b/contrib/llvm-project/clang/lib/ASTMatchers/Dynamic/Marshallers.cpp @@ -0,0 +1,170 @@ +//===--- Marshallers.cpp ----------------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "Marshallers.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Regex.h" +#include <optional> +#include <string> + +static std::optional<std::string> +getBestGuess(llvm::StringRef Search, llvm::ArrayRef<llvm::StringRef> Allowed, + llvm::StringRef DropPrefix = "", unsigned MaxEditDistance = 3) { + if (MaxEditDistance != ~0U) + ++MaxEditDistance; + llvm::StringRef Res; + for (const llvm::StringRef &Item : Allowed) { + if (Item.equals_insensitive(Search)) { + assert(Item != Search && "This should be handled earlier on."); + MaxEditDistance = 1; + Res = Item; + continue; + } + unsigned Distance = Item.edit_distance(Search); + if (Distance < MaxEditDistance) { + MaxEditDistance = Distance; + Res = Item; + } + } + if (!Res.empty()) + return Res.str(); + if (!DropPrefix.empty()) { + --MaxEditDistance; // Treat dropping the prefix as 1 edit + for (const llvm::StringRef &Item : Allowed) { + auto NoPrefix = Item; + if (!NoPrefix.consume_front(DropPrefix)) + continue; + if (NoPrefix.equals_insensitive(Search)) { + if (NoPrefix == Search) + return Item.str(); + MaxEditDistance = 1; + Res = Item; + continue; + } + unsigned Distance = NoPrefix.edit_distance(Search); + if (Distance < MaxEditDistance) { + MaxEditDistance = Distance; + Res = Item; + } + } + if (!Res.empty()) + return Res.str(); + } + return std::nullopt; +} + +std::optional<std::string> +clang::ast_matchers::dynamic::internal::ArgTypeTraits< + clang::attr::Kind>::getBestGuess(const VariantValue &Value) { + static constexpr llvm::StringRef Allowed[] = { +#define ATTR(X) "attr::" #X, +#include "clang/Basic/AttrList.inc" + }; + if (Value.isString()) + return ::getBestGuess(Value.getString(), llvm::ArrayRef(Allowed), "attr::"); + return std::nullopt; +} + +std::optional<std::string> +clang::ast_matchers::dynamic::internal::ArgTypeTraits< + clang::CastKind>::getBestGuess(const VariantValue &Value) { + static constexpr llvm::StringRef Allowed[] = { +#define CAST_OPERATION(Name) "CK_" #Name, +#include "clang/AST/OperationKinds.def" + }; + if (Value.isString()) + return ::getBestGuess(Value.getString(), llvm::ArrayRef(Allowed), "CK_"); + return std::nullopt; +} + +std::optional<std::string> +clang::ast_matchers::dynamic::internal::ArgTypeTraits< + clang::OpenMPClauseKind>::getBestGuess(const VariantValue &Value) { + static constexpr llvm::StringRef Allowed[] = { +#define GEN_CLANG_CLAUSE_CLASS +#define CLAUSE_CLASS(Enum, Str, Class) #Enum, +#include "llvm/Frontend/OpenMP/OMP.inc" + }; + if (Value.isString()) + return ::getBestGuess(Value.getString(), llvm::ArrayRef(Allowed), "OMPC_"); + return std::nullopt; +} + +std::optional<std::string> +clang::ast_matchers::dynamic::internal::ArgTypeTraits< + clang::UnaryExprOrTypeTrait>::getBestGuess(const VariantValue &Value) { + static constexpr llvm::StringRef Allowed[] = { +#define UNARY_EXPR_OR_TYPE_TRAIT(Spelling, Name, Key) "UETT_" #Name, +#define CXX11_UNARY_EXPR_OR_TYPE_TRAIT(Spelling, Name, Key) "UETT_" #Name, +#include "clang/Basic/TokenKinds.def" + }; + if (Value.isString()) + return ::getBestGuess(Value.getString(), llvm::ArrayRef(Allowed), "UETT_"); + return std::nullopt; +} + +static constexpr std::pair<llvm::StringRef, llvm::Regex::RegexFlags> + RegexMap[] = { + {"NoFlags", llvm::Regex::RegexFlags::NoFlags}, + {"IgnoreCase", llvm::Regex::RegexFlags::IgnoreCase}, + {"Newline", llvm::Regex::RegexFlags::Newline}, + {"BasicRegex", llvm::Regex::RegexFlags::BasicRegex}, +}; + +static std::optional<llvm::Regex::RegexFlags> +getRegexFlag(llvm::StringRef Flag) { + for (const auto &StringFlag : RegexMap) { + if (Flag == StringFlag.first) + return StringFlag.second; + } + return std::nullopt; +} + +static std::optional<llvm::StringRef> getCloseRegexMatch(llvm::StringRef Flag) { + for (const auto &StringFlag : RegexMap) { + if (Flag.edit_distance(StringFlag.first) < 3) + return StringFlag.first; + } + return std::nullopt; +} + +std::optional<llvm::Regex::RegexFlags> +clang::ast_matchers::dynamic::internal::ArgTypeTraits< + llvm::Regex::RegexFlags>::getFlags(llvm::StringRef Flags) { + std::optional<llvm::Regex::RegexFlags> Flag; + SmallVector<StringRef, 4> Split; + Flags.split(Split, '|', -1, false); + for (StringRef OrFlag : Split) { + if (std::optional<llvm::Regex::RegexFlags> NextFlag = + getRegexFlag(OrFlag.trim())) + Flag = Flag.value_or(llvm::Regex::NoFlags) | *NextFlag; + else + return std::nullopt; + } + return Flag; +} + +std::optional<std::string> +clang::ast_matchers::dynamic::internal::ArgTypeTraits< + llvm::Regex::RegexFlags>::getBestGuess(const VariantValue &Value) { + if (!Value.isString()) + return std::nullopt; + SmallVector<StringRef, 4> Split; + llvm::StringRef(Value.getString()).split(Split, '|', -1, false); + for (llvm::StringRef &Flag : Split) { + if (std::optional<llvm::StringRef> BestGuess = + getCloseRegexMatch(Flag.trim())) + Flag = *BestGuess; + else + return std::nullopt; + } + if (Split.empty()) + return std::nullopt; + return llvm::join(Split, " | "); +} diff --git a/contrib/llvm-project/clang/lib/ASTMatchers/Dynamic/Marshallers.h b/contrib/llvm-project/clang/lib/ASTMatchers/Dynamic/Marshallers.h new file mode 100644 index 000000000000..0e640cbada72 --- /dev/null +++ b/contrib/llvm-project/clang/lib/ASTMatchers/Dynamic/Marshallers.h @@ -0,0 +1,1170 @@ +//===- Marshallers.h - Generic matcher function marshallers -----*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +/// \file +/// Functions templates and classes to wrap matcher construct functions. +/// +/// A collection of template function and classes that provide a generic +/// marshalling layer on top of matcher construct functions. +/// These are used by the registry to export all marshaller constructors with +/// the same generic interface. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_CLANG_LIB_ASTMATCHERS_DYNAMIC_MARSHALLERS_H +#define LLVM_CLANG_LIB_ASTMATCHERS_DYNAMIC_MARSHALLERS_H + +#include "clang/AST/ASTTypeTraits.h" +#include "clang/AST/OperationKinds.h" +#include "clang/ASTMatchers/ASTMatchersInternal.h" +#include "clang/ASTMatchers/Dynamic/Diagnostics.h" +#include "clang/ASTMatchers/Dynamic/VariantValue.h" +#include "clang/Basic/AttrKinds.h" +#include "clang/Basic/LLVM.h" +#include "clang/Basic/OpenMPKinds.h" +#include "clang/Basic/TypeTraits.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/StringSwitch.h" +#include "llvm/ADT/Twine.h" +#include "llvm/Support/Regex.h" +#include <cassert> +#include <cstddef> +#include <iterator> +#include <limits> +#include <memory> +#include <optional> +#include <string> +#include <utility> +#include <vector> + +namespace clang { +namespace ast_matchers { +namespace dynamic { +namespace internal { + +/// Helper template class to just from argument type to the right is/get +/// functions in VariantValue. +/// Used to verify and extract the matcher arguments below. +template <class T> struct ArgTypeTraits; +template <class T> struct ArgTypeTraits<const T &> : public ArgTypeTraits<T> { +}; + +template <> struct ArgTypeTraits<std::string> { + static bool hasCorrectType(const VariantValue &Value) { + return Value.isString(); + } + static bool hasCorrectValue(const VariantValue &Value) { return true; } + + static const std::string &get(const VariantValue &Value) { + return Value.getString(); + } + + static ArgKind getKind() { + return ArgKind(ArgKind::AK_String); + } + + static std::optional<std::string> getBestGuess(const VariantValue &) { + return std::nullopt; + } +}; + +template <> +struct ArgTypeTraits<StringRef> : public ArgTypeTraits<std::string> { +}; + +template <class T> struct ArgTypeTraits<ast_matchers::internal::Matcher<T>> { + static bool hasCorrectType(const VariantValue& Value) { + return Value.isMatcher(); + } + static bool hasCorrectValue(const VariantValue &Value) { + return Value.getMatcher().hasTypedMatcher<T>(); + } + + static ast_matchers::internal::Matcher<T> get(const VariantValue &Value) { + return Value.getMatcher().getTypedMatcher<T>(); + } + + static ArgKind getKind() { + return ArgKind::MakeMatcherArg(ASTNodeKind::getFromNodeKind<T>()); + } + + static std::optional<std::string> getBestGuess(const VariantValue &) { + return std::nullopt; + } +}; + +template <> struct ArgTypeTraits<bool> { + static bool hasCorrectType(const VariantValue &Value) { + return Value.isBoolean(); + } + static bool hasCorrectValue(const VariantValue &Value) { return true; } + + static bool get(const VariantValue &Value) { + return Value.getBoolean(); + } + + static ArgKind getKind() { + return ArgKind(ArgKind::AK_Boolean); + } + + static std::optional<std::string> getBestGuess(const VariantValue &) { + return std::nullopt; + } +}; + +template <> struct ArgTypeTraits<double> { + static bool hasCorrectType(const VariantValue &Value) { + return Value.isDouble(); + } + static bool hasCorrectValue(const VariantValue &Value) { return true; } + + static double get(const VariantValue &Value) { + return Value.getDouble(); + } + + static ArgKind getKind() { + return ArgKind(ArgKind::AK_Double); + } + + static std::optional<std::string> getBestGuess(const VariantValue &) { + return std::nullopt; + } +}; + +template <> struct ArgTypeTraits<unsigned> { + static bool hasCorrectType(const VariantValue &Value) { + return Value.isUnsigned(); + } + static bool hasCorrectValue(const VariantValue &Value) { return true; } + + static unsigned get(const VariantValue &Value) { + return Value.getUnsigned(); + } + + static ArgKind getKind() { + return ArgKind(ArgKind::AK_Unsigned); + } + + static std::optional<std::string> getBestGuess(const VariantValue &) { + return std::nullopt; + } +}; + +template <> struct ArgTypeTraits<attr::Kind> { +private: + static std::optional<attr::Kind> getAttrKind(llvm::StringRef AttrKind) { + if (!AttrKind.consume_front("attr::")) + return std::nullopt; + return llvm::StringSwitch<std::optional<attr::Kind>>(AttrKind) +#define ATTR(X) .Case(#X, attr::X) +#include "clang/Basic/AttrList.inc" + .Default(std::nullopt); + } + +public: + static bool hasCorrectType(const VariantValue &Value) { + return Value.isString(); + } + static bool hasCorrectValue(const VariantValue& Value) { + return getAttrKind(Value.getString()).has_value(); + } + + static attr::Kind get(const VariantValue &Value) { + return *getAttrKind(Value.getString()); + } + + static ArgKind getKind() { + return ArgKind(ArgKind::AK_String); + } + + static std::optional<std::string> getBestGuess(const VariantValue &Value); +}; + +template <> struct ArgTypeTraits<CastKind> { +private: + static std::optional<CastKind> getCastKind(llvm::StringRef AttrKind) { + if (!AttrKind.consume_front("CK_")) + return std::nullopt; + return llvm::StringSwitch<std::optional<CastKind>>(AttrKind) +#define CAST_OPERATION(Name) .Case(#Name, CK_##Name) +#include "clang/AST/OperationKinds.def" + .Default(std::nullopt); + } + +public: + static bool hasCorrectType(const VariantValue &Value) { + return Value.isString(); + } + static bool hasCorrectValue(const VariantValue& Value) { + return getCastKind(Value.getString()).has_value(); + } + + static CastKind get(const VariantValue &Value) { + return *getCastKind(Value.getString()); + } + + static ArgKind getKind() { + return ArgKind(ArgKind::AK_String); + } + + static std::optional<std::string> getBestGuess(const VariantValue &Value); +}; + +template <> struct ArgTypeTraits<llvm::Regex::RegexFlags> { +private: + static std::optional<llvm::Regex::RegexFlags> getFlags(llvm::StringRef Flags); + +public: + static bool hasCorrectType(const VariantValue &Value) { + return Value.isString(); + } + static bool hasCorrectValue(const VariantValue& Value) { + return getFlags(Value.getString()).has_value(); + } + + static llvm::Regex::RegexFlags get(const VariantValue &Value) { + return *getFlags(Value.getString()); + } + + static ArgKind getKind() { return ArgKind(ArgKind::AK_String); } + + static std::optional<std::string> getBestGuess(const VariantValue &Value); +}; + +template <> struct ArgTypeTraits<OpenMPClauseKind> { +private: + static std::optional<OpenMPClauseKind> + getClauseKind(llvm::StringRef ClauseKind) { + return llvm::StringSwitch<std::optional<OpenMPClauseKind>>(ClauseKind) +#define GEN_CLANG_CLAUSE_CLASS +#define CLAUSE_CLASS(Enum, Str, Class) .Case(#Enum, llvm::omp::Clause::Enum) +#include "llvm/Frontend/OpenMP/OMP.inc" + .Default(std::nullopt); + } + +public: + static bool hasCorrectType(const VariantValue &Value) { + return Value.isString(); + } + static bool hasCorrectValue(const VariantValue& Value) { + return getClauseKind(Value.getString()).has_value(); + } + + static OpenMPClauseKind get(const VariantValue &Value) { + return *getClauseKind(Value.getString()); + } + + static ArgKind getKind() { return ArgKind(ArgKind::AK_String); } + + static std::optional<std::string> getBestGuess(const VariantValue &Value); +}; + +template <> struct ArgTypeTraits<UnaryExprOrTypeTrait> { +private: + static std::optional<UnaryExprOrTypeTrait> + getUnaryOrTypeTraitKind(llvm::StringRef ClauseKind) { + if (!ClauseKind.consume_front("UETT_")) + return std::nullopt; + return llvm::StringSwitch<std::optional<UnaryExprOrTypeTrait>>(ClauseKind) +#define UNARY_EXPR_OR_TYPE_TRAIT(Spelling, Name, Key) .Case(#Name, UETT_##Name) +#define CXX11_UNARY_EXPR_OR_TYPE_TRAIT(Spelling, Name, Key) \ + .Case(#Name, UETT_##Name) +#include "clang/Basic/TokenKinds.def" + .Default(std::nullopt); + } + +public: + static bool hasCorrectType(const VariantValue &Value) { + return Value.isString(); + } + static bool hasCorrectValue(const VariantValue& Value) { + return getUnaryOrTypeTraitKind(Value.getString()).has_value(); + } + + static UnaryExprOrTypeTrait get(const VariantValue &Value) { + return *getUnaryOrTypeTraitKind(Value.getString()); + } + + static ArgKind getKind() { return ArgKind(ArgKind::AK_String); } + + static std::optional<std::string> getBestGuess(const VariantValue &Value); +}; + +/// Matcher descriptor interface. +/// +/// Provides a \c create() method that constructs the matcher from the provided +/// arguments, and various other methods for type introspection. +class MatcherDescriptor { +public: + virtual ~MatcherDescriptor() = default; + + virtual VariantMatcher create(SourceRange NameRange, + ArrayRef<ParserValue> Args, + Diagnostics *Error) const = 0; + + virtual ASTNodeKind nodeMatcherType() const { return ASTNodeKind(); } + + virtual bool isBuilderMatcher() const { return false; } + + virtual std::unique_ptr<MatcherDescriptor> + buildMatcherCtor(SourceRange NameRange, ArrayRef<ParserValue> Args, + Diagnostics *Error) const { + return {}; + } + + /// Returns whether the matcher is variadic. Variadic matchers can take any + /// number of arguments, but they must be of the same type. + virtual bool isVariadic() const = 0; + + /// Returns the number of arguments accepted by the matcher if not variadic. + virtual unsigned getNumArgs() const = 0; + + /// Given that the matcher is being converted to type \p ThisKind, append the + /// set of argument types accepted for argument \p ArgNo to \p ArgKinds. + // FIXME: We should provide the ability to constrain the output of this + // function based on the types of other matcher arguments. + virtual void getArgKinds(ASTNodeKind ThisKind, unsigned ArgNo, + std::vector<ArgKind> &ArgKinds) const = 0; + + /// Returns whether this matcher is convertible to the given type. If it is + /// so convertible, store in *Specificity a value corresponding to the + /// "specificity" of the converted matcher to the given context, and in + /// *LeastDerivedKind the least derived matcher kind which would result in the + /// same matcher overload. Zero specificity indicates that this conversion + /// would produce a trivial matcher that will either always or never match. + /// Such matchers are excluded from code completion results. + virtual bool + isConvertibleTo(ASTNodeKind Kind, unsigned *Specificity = nullptr, + ASTNodeKind *LeastDerivedKind = nullptr) const = 0; + + /// Returns whether the matcher will, given a matcher of any type T, yield a + /// matcher of type T. + virtual bool isPolymorphic() const { return false; } +}; + +inline bool isRetKindConvertibleTo(ArrayRef<ASTNodeKind> RetKinds, + ASTNodeKind Kind, unsigned *Specificity, + ASTNodeKind *LeastDerivedKind) { + for (const ASTNodeKind &NodeKind : RetKinds) { + if (ArgKind::MakeMatcherArg(NodeKind).isConvertibleTo( + ArgKind::MakeMatcherArg(Kind), Specificity)) { + if (LeastDerivedKind) + *LeastDerivedKind = NodeKind; + return true; + } + } + return false; +} + +/// Simple callback implementation. Marshaller and function are provided. +/// +/// This class wraps a function of arbitrary signature and a marshaller +/// function into a MatcherDescriptor. +/// The marshaller is in charge of taking the VariantValue arguments, checking +/// their types, unpacking them and calling the underlying function. +class FixedArgCountMatcherDescriptor : public MatcherDescriptor { +public: + using MarshallerType = VariantMatcher (*)(void (*Func)(), + StringRef MatcherName, + SourceRange NameRange, + ArrayRef<ParserValue> Args, + Diagnostics *Error); + + /// \param Marshaller Function to unpack the arguments and call \c Func + /// \param Func Matcher construct function. This is the function that + /// compile-time matcher expressions would use to create the matcher. + /// \param RetKinds The list of matcher types to which the matcher is + /// convertible. + /// \param ArgKinds The types of the arguments this matcher takes. + FixedArgCountMatcherDescriptor(MarshallerType Marshaller, void (*Func)(), + StringRef MatcherName, + ArrayRef<ASTNodeKind> RetKinds, + ArrayRef<ArgKind> ArgKinds) + : Marshaller(Marshaller), Func(Func), MatcherName(MatcherName), + RetKinds(RetKinds.begin(), RetKinds.end()), + ArgKinds(ArgKinds.begin(), ArgKinds.end()) {} + + VariantMatcher create(SourceRange NameRange, + ArrayRef<ParserValue> Args, + Diagnostics *Error) const override { + return Marshaller(Func, MatcherName, NameRange, Args, Error); + } + + bool isVariadic() const override { return false; } + unsigned getNumArgs() const override { return ArgKinds.size(); } + + void getArgKinds(ASTNodeKind ThisKind, unsigned ArgNo, + std::vector<ArgKind> &Kinds) const override { + Kinds.push_back(ArgKinds[ArgNo]); + } + + bool isConvertibleTo(ASTNodeKind Kind, unsigned *Specificity, + ASTNodeKind *LeastDerivedKind) const override { + return isRetKindConvertibleTo(RetKinds, Kind, Specificity, + LeastDerivedKind); + } + +private: + const MarshallerType Marshaller; + void (* const Func)(); + const std::string MatcherName; + const std::vector<ASTNodeKind> RetKinds; + const std::vector<ArgKind> ArgKinds; +}; + +/// Helper methods to extract and merge all possible typed matchers +/// out of the polymorphic object. +template <class PolyMatcher> +static void mergePolyMatchers(const PolyMatcher &Poly, + std::vector<DynTypedMatcher> &Out, + ast_matchers::internal::EmptyTypeList) {} + +template <class PolyMatcher, class TypeList> +static void mergePolyMatchers(const PolyMatcher &Poly, + std::vector<DynTypedMatcher> &Out, TypeList) { + Out.push_back(ast_matchers::internal::Matcher<typename TypeList::head>(Poly)); + mergePolyMatchers(Poly, Out, typename TypeList::tail()); +} + +/// Convert the return values of the functions into a VariantMatcher. +/// +/// There are 2 cases right now: The return value is a Matcher<T> or is a +/// polymorphic matcher. For the former, we just construct the VariantMatcher. +/// For the latter, we instantiate all the possible Matcher<T> of the poly +/// matcher. +inline VariantMatcher outvalueToVariantMatcher(const DynTypedMatcher &Matcher) { + return VariantMatcher::SingleMatcher(Matcher); +} + +template <typename T> +static VariantMatcher outvalueToVariantMatcher(const T &PolyMatcher, + typename T::ReturnTypes * = + nullptr) { + std::vector<DynTypedMatcher> Matchers; + mergePolyMatchers(PolyMatcher, Matchers, typename T::ReturnTypes()); + VariantMatcher Out = VariantMatcher::PolymorphicMatcher(std::move(Matchers)); + return Out; +} + +template <typename T> +inline void +buildReturnTypeVectorFromTypeList(std::vector<ASTNodeKind> &RetTypes) { + RetTypes.push_back(ASTNodeKind::getFromNodeKind<typename T::head>()); + buildReturnTypeVectorFromTypeList<typename T::tail>(RetTypes); +} + +template <> +inline void +buildReturnTypeVectorFromTypeList<ast_matchers::internal::EmptyTypeList>( + std::vector<ASTNodeKind> &RetTypes) {} + +template <typename T> +struct BuildReturnTypeVector { + static void build(std::vector<ASTNodeKind> &RetTypes) { + buildReturnTypeVectorFromTypeList<typename T::ReturnTypes>(RetTypes); + } +}; + +template <typename T> +struct BuildReturnTypeVector<ast_matchers::internal::Matcher<T>> { + static void build(std::vector<ASTNodeKind> &RetTypes) { + RetTypes.push_back(ASTNodeKind::getFromNodeKind<T>()); + } +}; + +template <typename T> +struct BuildReturnTypeVector<ast_matchers::internal::BindableMatcher<T>> { + static void build(std::vector<ASTNodeKind> &RetTypes) { + RetTypes.push_back(ASTNodeKind::getFromNodeKind<T>()); + } +}; + +/// Variadic marshaller function. +template <typename ResultT, typename ArgT, + ResultT (*Func)(ArrayRef<const ArgT *>)> +VariantMatcher +variadicMatcherDescriptor(StringRef MatcherName, SourceRange NameRange, + ArrayRef<ParserValue> Args, Diagnostics *Error) { + SmallVector<ArgT *, 8> InnerArgsPtr; + InnerArgsPtr.resize_for_overwrite(Args.size()); + SmallVector<ArgT, 8> InnerArgs; + InnerArgs.reserve(Args.size()); + + for (size_t i = 0, e = Args.size(); i != e; ++i) { + using ArgTraits = ArgTypeTraits<ArgT>; + + const ParserValue &Arg = Args[i]; + const VariantValue &Value = Arg.Value; + if (!ArgTraits::hasCorrectType(Value)) { + Error->addError(Arg.Range, Error->ET_RegistryWrongArgType) + << (i + 1) << ArgTraits::getKind().asString() << Value.getTypeAsString(); + return {}; + } + if (!ArgTraits::hasCorrectValue(Value)) { + if (std::optional<std::string> BestGuess = + ArgTraits::getBestGuess(Value)) { + Error->addError(Arg.Range, Error->ET_RegistryUnknownEnumWithReplace) + << i + 1 << Value.getString() << *BestGuess; + } else if (Value.isString()) { + Error->addError(Arg.Range, Error->ET_RegistryValueNotFound) + << Value.getString(); + } else { + // This isn't ideal, but it's better than reporting an empty string as + // the error in this case. + Error->addError(Arg.Range, Error->ET_RegistryWrongArgType) + << (i + 1) << ArgTraits::getKind().asString() + << Value.getTypeAsString(); + } + return {}; + } + assert(InnerArgs.size() < InnerArgs.capacity()); + InnerArgs.emplace_back(ArgTraits::get(Value)); + InnerArgsPtr[i] = &InnerArgs[i]; + } + return outvalueToVariantMatcher(Func(InnerArgsPtr)); +} + +/// Matcher descriptor for variadic functions. +/// +/// This class simply wraps a VariadicFunction with the right signature to export +/// it as a MatcherDescriptor. +/// This allows us to have one implementation of the interface for as many free +/// functions as we want, reducing the number of symbols and size of the +/// object file. +class VariadicFuncMatcherDescriptor : public MatcherDescriptor { +public: + using RunFunc = VariantMatcher (*)(StringRef MatcherName, + SourceRange NameRange, + ArrayRef<ParserValue> Args, + Diagnostics *Error); + + template <typename ResultT, typename ArgT, + ResultT (*F)(ArrayRef<const ArgT *>)> + VariadicFuncMatcherDescriptor( + ast_matchers::internal::VariadicFunction<ResultT, ArgT, F> Func, + StringRef MatcherName) + : Func(&variadicMatcherDescriptor<ResultT, ArgT, F>), + MatcherName(MatcherName.str()), + ArgsKind(ArgTypeTraits<ArgT>::getKind()) { + BuildReturnTypeVector<ResultT>::build(RetKinds); + } + + VariantMatcher create(SourceRange NameRange, + ArrayRef<ParserValue> Args, + Diagnostics *Error) const override { + return Func(MatcherName, NameRange, Args, Error); + } + + bool isVariadic() const override { return true; } + unsigned getNumArgs() const override { return 0; } + + void getArgKinds(ASTNodeKind ThisKind, unsigned ArgNo, + std::vector<ArgKind> &Kinds) const override { + Kinds.push_back(ArgsKind); + } + + bool isConvertibleTo(ASTNodeKind Kind, unsigned *Specificity, + ASTNodeKind *LeastDerivedKind) const override { + return isRetKindConvertibleTo(RetKinds, Kind, Specificity, + LeastDerivedKind); + } + + ASTNodeKind nodeMatcherType() const override { return RetKinds[0]; } + +private: + const RunFunc Func; + const std::string MatcherName; + std::vector<ASTNodeKind> RetKinds; + const ArgKind ArgsKind; +}; + +/// Return CK_Trivial when appropriate for VariadicDynCastAllOfMatchers. +class DynCastAllOfMatcherDescriptor : public VariadicFuncMatcherDescriptor { +public: + template <typename BaseT, typename DerivedT> + DynCastAllOfMatcherDescriptor( + ast_matchers::internal::VariadicDynCastAllOfMatcher<BaseT, DerivedT> Func, + StringRef MatcherName) + : VariadicFuncMatcherDescriptor(Func, MatcherName), + DerivedKind(ASTNodeKind::getFromNodeKind<DerivedT>()) {} + + bool isConvertibleTo(ASTNodeKind Kind, unsigned *Specificity, + ASTNodeKind *LeastDerivedKind) const override { + // If Kind is not a base of DerivedKind, either DerivedKind is a base of + // Kind (in which case the match will always succeed) or Kind and + // DerivedKind are unrelated (in which case it will always fail), so set + // Specificity to 0. + if (VariadicFuncMatcherDescriptor::isConvertibleTo(Kind, Specificity, + LeastDerivedKind)) { + if (Kind.isSame(DerivedKind) || !Kind.isBaseOf(DerivedKind)) { + if (Specificity) + *Specificity = 0; + } + return true; + } else { + return false; + } + } + + ASTNodeKind nodeMatcherType() const override { return DerivedKind; } + +private: + const ASTNodeKind DerivedKind; +}; + +/// Helper macros to check the arguments on all marshaller functions. +#define CHECK_ARG_COUNT(count) \ + if (Args.size() != count) { \ + Error->addError(NameRange, Error->ET_RegistryWrongArgCount) \ + << count << Args.size(); \ + return VariantMatcher(); \ + } + +#define CHECK_ARG_TYPE(index, type) \ + if (!ArgTypeTraits<type>::hasCorrectType(Args[index].Value)) { \ + Error->addError(Args[index].Range, Error->ET_RegistryWrongArgType) \ + << (index + 1) << ArgTypeTraits<type>::getKind().asString() \ + << Args[index].Value.getTypeAsString(); \ + return VariantMatcher(); \ + } \ + if (!ArgTypeTraits<type>::hasCorrectValue(Args[index].Value)) { \ + if (std::optional<std::string> BestGuess = \ + ArgTypeTraits<type>::getBestGuess(Args[index].Value)) { \ + Error->addError(Args[index].Range, \ + Error->ET_RegistryUnknownEnumWithReplace) \ + << index + 1 << Args[index].Value.getString() << *BestGuess; \ + } else if (Args[index].Value.isString()) { \ + Error->addError(Args[index].Range, Error->ET_RegistryValueNotFound) \ + << Args[index].Value.getString(); \ + } \ + return VariantMatcher(); \ + } + +/// 0-arg marshaller function. +template <typename ReturnType> +static VariantMatcher matcherMarshall0(void (*Func)(), StringRef MatcherName, + SourceRange NameRange, + ArrayRef<ParserValue> Args, + Diagnostics *Error) { + using FuncType = ReturnType (*)(); + CHECK_ARG_COUNT(0); + return outvalueToVariantMatcher(reinterpret_cast<FuncType>(Func)()); +} + +/// 1-arg marshaller function. +template <typename ReturnType, typename ArgType1> +static VariantMatcher matcherMarshall1(void (*Func)(), StringRef MatcherName, + SourceRange NameRange, + ArrayRef<ParserValue> Args, + Diagnostics *Error) { + using FuncType = ReturnType (*)(ArgType1); + CHECK_ARG_COUNT(1); + CHECK_ARG_TYPE(0, ArgType1); + return outvalueToVariantMatcher(reinterpret_cast<FuncType>(Func)( + ArgTypeTraits<ArgType1>::get(Args[0].Value))); +} + +/// 2-arg marshaller function. +template <typename ReturnType, typename ArgType1, typename ArgType2> +static VariantMatcher matcherMarshall2(void (*Func)(), StringRef MatcherName, + SourceRange NameRange, + ArrayRef<ParserValue> Args, + Diagnostics *Error) { + using FuncType = ReturnType (*)(ArgType1, ArgType2); + CHECK_ARG_COUNT(2); + CHECK_ARG_TYPE(0, ArgType1); + CHECK_ARG_TYPE(1, ArgType2); + return outvalueToVariantMatcher(reinterpret_cast<FuncType>(Func)( + ArgTypeTraits<ArgType1>::get(Args[0].Value), + ArgTypeTraits<ArgType2>::get(Args[1].Value))); +} + +#undef CHECK_ARG_COUNT +#undef CHECK_ARG_TYPE + +/// Helper class used to collect all the possible overloads of an +/// argument adaptative matcher function. +template <template <typename ToArg, typename FromArg> class ArgumentAdapterT, + typename FromTypes, typename ToTypes> +class AdaptativeOverloadCollector { +public: + AdaptativeOverloadCollector( + StringRef Name, std::vector<std::unique_ptr<MatcherDescriptor>> &Out) + : Name(Name), Out(Out) { + collect(FromTypes()); + } + +private: + using AdaptativeFunc = ast_matchers::internal::ArgumentAdaptingMatcherFunc< + ArgumentAdapterT, FromTypes, ToTypes>; + + /// End case for the recursion + static void collect(ast_matchers::internal::EmptyTypeList) {} + + /// Recursive case. Get the overload for the head of the list, and + /// recurse to the tail. + template <typename FromTypeList> + inline void collect(FromTypeList); + + StringRef Name; + std::vector<std::unique_ptr<MatcherDescriptor>> &Out; +}; + +/// MatcherDescriptor that wraps multiple "overloads" of the same +/// matcher. +/// +/// It will try every overload and generate appropriate errors for when none or +/// more than one overloads match the arguments. +class OverloadedMatcherDescriptor : public MatcherDescriptor { +public: + OverloadedMatcherDescriptor( + MutableArrayRef<std::unique_ptr<MatcherDescriptor>> Callbacks) + : Overloads(std::make_move_iterator(Callbacks.begin()), + std::make_move_iterator(Callbacks.end())) {} + + ~OverloadedMatcherDescriptor() override = default; + + VariantMatcher create(SourceRange NameRange, + ArrayRef<ParserValue> Args, + Diagnostics *Error) const override { + std::vector<VariantMatcher> Constructed; + Diagnostics::OverloadContext Ctx(Error); + for (const auto &O : Overloads) { + VariantMatcher SubMatcher = O->create(NameRange, Args, Error); + if (!SubMatcher.isNull()) { + Constructed.push_back(SubMatcher); + } + } + + if (Constructed.empty()) return VariantMatcher(); // No overload matched. + // We ignore the errors if any matcher succeeded. + Ctx.revertErrors(); + if (Constructed.size() > 1) { + // More than one constructed. It is ambiguous. + Error->addError(NameRange, Error->ET_RegistryAmbiguousOverload); + return VariantMatcher(); + } + return Constructed[0]; + } + + bool isVariadic() const override { + bool Overload0Variadic = Overloads[0]->isVariadic(); +#ifndef NDEBUG + for (const auto &O : Overloads) { + assert(Overload0Variadic == O->isVariadic()); + } +#endif + return Overload0Variadic; + } + + unsigned getNumArgs() const override { + unsigned Overload0NumArgs = Overloads[0]->getNumArgs(); +#ifndef NDEBUG + for (const auto &O : Overloads) { + assert(Overload0NumArgs == O->getNumArgs()); + } +#endif + return Overload0NumArgs; + } + + void getArgKinds(ASTNodeKind ThisKind, unsigned ArgNo, + std::vector<ArgKind> &Kinds) const override { + for (const auto &O : Overloads) { + if (O->isConvertibleTo(ThisKind)) + O->getArgKinds(ThisKind, ArgNo, Kinds); + } + } + + bool isConvertibleTo(ASTNodeKind Kind, unsigned *Specificity, + ASTNodeKind *LeastDerivedKind) const override { + for (const auto &O : Overloads) { + if (O->isConvertibleTo(Kind, Specificity, LeastDerivedKind)) + return true; + } + return false; + } + +private: + std::vector<std::unique_ptr<MatcherDescriptor>> Overloads; +}; + +template <typename ReturnType> +class RegexMatcherDescriptor : public MatcherDescriptor { +public: + RegexMatcherDescriptor(ReturnType (*WithFlags)(StringRef, + llvm::Regex::RegexFlags), + ReturnType (*NoFlags)(StringRef), + ArrayRef<ASTNodeKind> RetKinds) + : WithFlags(WithFlags), NoFlags(NoFlags), + RetKinds(RetKinds.begin(), RetKinds.end()) {} + bool isVariadic() const override { return true; } + unsigned getNumArgs() const override { return 0; } + + void getArgKinds(ASTNodeKind ThisKind, unsigned ArgNo, + std::vector<ArgKind> &Kinds) const override { + assert(ArgNo < 2); + Kinds.push_back(ArgKind::AK_String); + } + + bool isConvertibleTo(ASTNodeKind Kind, unsigned *Specificity, + ASTNodeKind *LeastDerivedKind) const override { + return isRetKindConvertibleTo(RetKinds, Kind, Specificity, + LeastDerivedKind); + } + + VariantMatcher create(SourceRange NameRange, ArrayRef<ParserValue> Args, + Diagnostics *Error) const override { + if (Args.size() < 1 || Args.size() > 2) { + Error->addError(NameRange, Diagnostics::ET_RegistryWrongArgCount) + << "1 or 2" << Args.size(); + return VariantMatcher(); + } + if (!ArgTypeTraits<StringRef>::hasCorrectType(Args[0].Value)) { + Error->addError(Args[0].Range, Error->ET_RegistryWrongArgType) + << 1 << ArgTypeTraits<StringRef>::getKind().asString() + << Args[0].Value.getTypeAsString(); + return VariantMatcher(); + } + if (Args.size() == 1) { + return outvalueToVariantMatcher( + NoFlags(ArgTypeTraits<StringRef>::get(Args[0].Value))); + } + if (!ArgTypeTraits<llvm::Regex::RegexFlags>::hasCorrectType( + Args[1].Value)) { + Error->addError(Args[1].Range, Error->ET_RegistryWrongArgType) + << 2 << ArgTypeTraits<llvm::Regex::RegexFlags>::getKind().asString() + << Args[1].Value.getTypeAsString(); + return VariantMatcher(); + } + if (!ArgTypeTraits<llvm::Regex::RegexFlags>::hasCorrectValue( + Args[1].Value)) { + if (std::optional<std::string> BestGuess = + ArgTypeTraits<llvm::Regex::RegexFlags>::getBestGuess( + Args[1].Value)) { + Error->addError(Args[1].Range, Error->ET_RegistryUnknownEnumWithReplace) + << 2 << Args[1].Value.getString() << *BestGuess; + } else { + Error->addError(Args[1].Range, Error->ET_RegistryValueNotFound) + << Args[1].Value.getString(); + } + return VariantMatcher(); + } + return outvalueToVariantMatcher( + WithFlags(ArgTypeTraits<StringRef>::get(Args[0].Value), + ArgTypeTraits<llvm::Regex::RegexFlags>::get(Args[1].Value))); + } + +private: + ReturnType (*const WithFlags)(StringRef, llvm::Regex::RegexFlags); + ReturnType (*const NoFlags)(StringRef); + const std::vector<ASTNodeKind> RetKinds; +}; + +/// Variadic operator marshaller function. +class VariadicOperatorMatcherDescriptor : public MatcherDescriptor { +public: + using VarOp = DynTypedMatcher::VariadicOperator; + + VariadicOperatorMatcherDescriptor(unsigned MinCount, unsigned MaxCount, + VarOp Op, StringRef MatcherName) + : MinCount(MinCount), MaxCount(MaxCount), Op(Op), + MatcherName(MatcherName) {} + + VariantMatcher create(SourceRange NameRange, + ArrayRef<ParserValue> Args, + Diagnostics *Error) const override { + if (Args.size() < MinCount || MaxCount < Args.size()) { + const std::string MaxStr = + (MaxCount == std::numeric_limits<unsigned>::max() ? "" + : Twine(MaxCount)) + .str(); + Error->addError(NameRange, Error->ET_RegistryWrongArgCount) + << ("(" + Twine(MinCount) + ", " + MaxStr + ")") << Args.size(); + return VariantMatcher(); + } + + std::vector<VariantMatcher> InnerArgs; + for (size_t i = 0, e = Args.size(); i != e; ++i) { + const ParserValue &Arg = Args[i]; + const VariantValue &Value = Arg.Value; + if (!Value.isMatcher()) { + Error->addError(Arg.Range, Error->ET_RegistryWrongArgType) + << (i + 1) << "Matcher<>" << Value.getTypeAsString(); + return VariantMatcher(); + } + InnerArgs.push_back(Value.getMatcher()); + } + return VariantMatcher::VariadicOperatorMatcher(Op, std::move(InnerArgs)); + } + + bool isVariadic() const override { return true; } + unsigned getNumArgs() const override { return 0; } + + void getArgKinds(ASTNodeKind ThisKind, unsigned ArgNo, + std::vector<ArgKind> &Kinds) const override { + Kinds.push_back(ArgKind::MakeMatcherArg(ThisKind)); + } + + bool isConvertibleTo(ASTNodeKind Kind, unsigned *Specificity, + ASTNodeKind *LeastDerivedKind) const override { + if (Specificity) + *Specificity = 1; + if (LeastDerivedKind) + *LeastDerivedKind = Kind; + return true; + } + + bool isPolymorphic() const override { return true; } + +private: + const unsigned MinCount; + const unsigned MaxCount; + const VarOp Op; + const StringRef MatcherName; +}; + +class MapAnyOfMatcherDescriptor : public MatcherDescriptor { + ASTNodeKind CladeNodeKind; + std::vector<ASTNodeKind> NodeKinds; + +public: + MapAnyOfMatcherDescriptor(ASTNodeKind CladeNodeKind, + std::vector<ASTNodeKind> NodeKinds) + : CladeNodeKind(CladeNodeKind), NodeKinds(std::move(NodeKinds)) {} + + VariantMatcher create(SourceRange NameRange, ArrayRef<ParserValue> Args, + Diagnostics *Error) const override { + + std::vector<DynTypedMatcher> NodeArgs; + + for (auto NK : NodeKinds) { + std::vector<DynTypedMatcher> InnerArgs; + + for (const auto &Arg : Args) { + if (!Arg.Value.isMatcher()) + return {}; + const VariantMatcher &VM = Arg.Value.getMatcher(); + if (VM.hasTypedMatcher(NK)) { + auto DM = VM.getTypedMatcher(NK); + InnerArgs.push_back(DM); + } + } + + if (InnerArgs.empty()) { + NodeArgs.push_back( + DynTypedMatcher::trueMatcher(NK).dynCastTo(CladeNodeKind)); + } else { + NodeArgs.push_back( + DynTypedMatcher::constructVariadic( + ast_matchers::internal::DynTypedMatcher::VO_AllOf, NK, + InnerArgs) + .dynCastTo(CladeNodeKind)); + } + } + + auto Result = DynTypedMatcher::constructVariadic( + ast_matchers::internal::DynTypedMatcher::VO_AnyOf, CladeNodeKind, + NodeArgs); + Result.setAllowBind(true); + return VariantMatcher::SingleMatcher(Result); + } + + bool isVariadic() const override { return true; } + unsigned getNumArgs() const override { return 0; } + + void getArgKinds(ASTNodeKind ThisKind, unsigned, + std::vector<ArgKind> &Kinds) const override { + Kinds.push_back(ArgKind::MakeMatcherArg(ThisKind)); + } + + bool isConvertibleTo(ASTNodeKind Kind, unsigned *Specificity, + ASTNodeKind *LeastDerivedKind) const override { + if (Specificity) + *Specificity = 1; + if (LeastDerivedKind) + *LeastDerivedKind = CladeNodeKind; + return true; + } +}; + +class MapAnyOfBuilderDescriptor : public MatcherDescriptor { +public: + VariantMatcher create(SourceRange, ArrayRef<ParserValue>, + Diagnostics *) const override { + return {}; + } + + bool isBuilderMatcher() const override { return true; } + + std::unique_ptr<MatcherDescriptor> + buildMatcherCtor(SourceRange, ArrayRef<ParserValue> Args, + Diagnostics *) const override { + + std::vector<ASTNodeKind> NodeKinds; + for (const auto &Arg : Args) { + if (!Arg.Value.isNodeKind()) + return {}; + NodeKinds.push_back(Arg.Value.getNodeKind()); + } + + if (NodeKinds.empty()) + return {}; + + ASTNodeKind CladeNodeKind = NodeKinds.front().getCladeKind(); + + for (auto NK : NodeKinds) + { + if (!NK.getCladeKind().isSame(CladeNodeKind)) + return {}; + } + + return std::make_unique<MapAnyOfMatcherDescriptor>(CladeNodeKind, + std::move(NodeKinds)); + } + + bool isVariadic() const override { return true; } + + unsigned getNumArgs() const override { return 0; } + + void getArgKinds(ASTNodeKind ThisKind, unsigned, + std::vector<ArgKind> &ArgKinds) const override { + ArgKinds.push_back(ArgKind::MakeNodeArg(ThisKind)); + } + bool isConvertibleTo(ASTNodeKind Kind, unsigned *Specificity = nullptr, + ASTNodeKind *LeastDerivedKind = nullptr) const override { + if (Specificity) + *Specificity = 1; + if (LeastDerivedKind) + *LeastDerivedKind = Kind; + return true; + } + + bool isPolymorphic() const override { return false; } +}; + +/// Helper functions to select the appropriate marshaller functions. +/// They detect the number of arguments, arguments types and return type. + +/// 0-arg overload +template <typename ReturnType> +std::unique_ptr<MatcherDescriptor> +makeMatcherAutoMarshall(ReturnType (*Func)(), StringRef MatcherName) { + std::vector<ASTNodeKind> RetTypes; + BuildReturnTypeVector<ReturnType>::build(RetTypes); + return std::make_unique<FixedArgCountMatcherDescriptor>( + matcherMarshall0<ReturnType>, reinterpret_cast<void (*)()>(Func), + MatcherName, RetTypes, std::nullopt); +} + +/// 1-arg overload +template <typename ReturnType, typename ArgType1> +std::unique_ptr<MatcherDescriptor> +makeMatcherAutoMarshall(ReturnType (*Func)(ArgType1), StringRef MatcherName) { + std::vector<ASTNodeKind> RetTypes; + BuildReturnTypeVector<ReturnType>::build(RetTypes); + ArgKind AK = ArgTypeTraits<ArgType1>::getKind(); + return std::make_unique<FixedArgCountMatcherDescriptor>( + matcherMarshall1<ReturnType, ArgType1>, + reinterpret_cast<void (*)()>(Func), MatcherName, RetTypes, AK); +} + +/// 2-arg overload +template <typename ReturnType, typename ArgType1, typename ArgType2> +std::unique_ptr<MatcherDescriptor> +makeMatcherAutoMarshall(ReturnType (*Func)(ArgType1, ArgType2), + StringRef MatcherName) { + std::vector<ASTNodeKind> RetTypes; + BuildReturnTypeVector<ReturnType>::build(RetTypes); + ArgKind AKs[] = { ArgTypeTraits<ArgType1>::getKind(), + ArgTypeTraits<ArgType2>::getKind() }; + return std::make_unique<FixedArgCountMatcherDescriptor>( + matcherMarshall2<ReturnType, ArgType1, ArgType2>, + reinterpret_cast<void (*)()>(Func), MatcherName, RetTypes, AKs); +} + +template <typename ReturnType> +std::unique_ptr<MatcherDescriptor> makeMatcherRegexMarshall( + ReturnType (*FuncFlags)(llvm::StringRef, llvm::Regex::RegexFlags), + ReturnType (*Func)(llvm::StringRef)) { + std::vector<ASTNodeKind> RetTypes; + BuildReturnTypeVector<ReturnType>::build(RetTypes); + return std::make_unique<RegexMatcherDescriptor<ReturnType>>(FuncFlags, Func, + RetTypes); +} + +/// Variadic overload. +template <typename ResultT, typename ArgT, + ResultT (*Func)(ArrayRef<const ArgT *>)> +std::unique_ptr<MatcherDescriptor> makeMatcherAutoMarshall( + ast_matchers::internal::VariadicFunction<ResultT, ArgT, Func> VarFunc, + StringRef MatcherName) { + return std::make_unique<VariadicFuncMatcherDescriptor>(VarFunc, MatcherName); +} + +/// Overload for VariadicDynCastAllOfMatchers. +/// +/// Not strictly necessary, but DynCastAllOfMatcherDescriptor gives us better +/// completion results for that type of matcher. +template <typename BaseT, typename DerivedT> +std::unique_ptr<MatcherDescriptor> makeMatcherAutoMarshall( + ast_matchers::internal::VariadicDynCastAllOfMatcher<BaseT, DerivedT> + VarFunc, + StringRef MatcherName) { + return std::make_unique<DynCastAllOfMatcherDescriptor>(VarFunc, MatcherName); +} + +/// Argument adaptative overload. +template <template <typename ToArg, typename FromArg> class ArgumentAdapterT, + typename FromTypes, typename ToTypes> +std::unique_ptr<MatcherDescriptor> makeMatcherAutoMarshall( + ast_matchers::internal::ArgumentAdaptingMatcherFunc<ArgumentAdapterT, + FromTypes, ToTypes>, + StringRef MatcherName) { + std::vector<std::unique_ptr<MatcherDescriptor>> Overloads; + AdaptativeOverloadCollector<ArgumentAdapterT, FromTypes, ToTypes>(MatcherName, + Overloads); + return std::make_unique<OverloadedMatcherDescriptor>(Overloads); +} + +template <template <typename ToArg, typename FromArg> class ArgumentAdapterT, + typename FromTypes, typename ToTypes> +template <typename FromTypeList> +inline void AdaptativeOverloadCollector<ArgumentAdapterT, FromTypes, + ToTypes>::collect(FromTypeList) { + Out.push_back(makeMatcherAutoMarshall( + &AdaptativeFunc::template create<typename FromTypeList::head>, Name)); + collect(typename FromTypeList::tail()); +} + +/// Variadic operator overload. +template <unsigned MinCount, unsigned MaxCount> +std::unique_ptr<MatcherDescriptor> makeMatcherAutoMarshall( + ast_matchers::internal::VariadicOperatorMatcherFunc<MinCount, MaxCount> + Func, + StringRef MatcherName) { + return std::make_unique<VariadicOperatorMatcherDescriptor>( + MinCount, MaxCount, Func.Op, MatcherName); +} + +template <typename CladeType, typename... MatcherT> +std::unique_ptr<MatcherDescriptor> makeMatcherAutoMarshall( + ast_matchers::internal::MapAnyOfMatcherImpl<CladeType, MatcherT...>, + StringRef MatcherName) { + return std::make_unique<MapAnyOfMatcherDescriptor>( + ASTNodeKind::getFromNodeKind<CladeType>(), + std::vector<ASTNodeKind>{ASTNodeKind::getFromNodeKind<MatcherT>()...}); +} + +} // namespace internal +} // namespace dynamic +} // namespace ast_matchers +} // namespace clang + +#endif // LLVM_CLANG_LIB_ASTMATCHERS_DYNAMIC_MARSHALLERS_H diff --git a/contrib/llvm-project/clang/lib/ASTMatchers/Dynamic/Parser.cpp b/contrib/llvm-project/clang/lib/ASTMatchers/Dynamic/Parser.cpp new file mode 100644 index 000000000000..6a16c2184fcf --- /dev/null +++ b/contrib/llvm-project/clang/lib/ASTMatchers/Dynamic/Parser.cpp @@ -0,0 +1,926 @@ +//===- Parser.cpp - Matcher expression parser -----------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// Recursive parser implementation for the matcher expression grammar. +/// +//===----------------------------------------------------------------------===// + +#include "clang/ASTMatchers/Dynamic/Parser.h" +#include "clang/ASTMatchers/ASTMatchersInternal.h" +#include "clang/ASTMatchers/Dynamic/Diagnostics.h" +#include "clang/ASTMatchers/Dynamic/Registry.h" +#include "clang/Basic/CharInfo.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/ManagedStatic.h" +#include <algorithm> +#include <cassert> +#include <cerrno> +#include <cstddef> +#include <cstdlib> +#include <optional> +#include <string> +#include <utility> +#include <vector> + +namespace clang { +namespace ast_matchers { +namespace dynamic { + +/// Simple structure to hold information for one token from the parser. +struct Parser::TokenInfo { + /// Different possible tokens. + enum TokenKind { + TK_Eof, + TK_NewLine, + TK_OpenParen, + TK_CloseParen, + TK_Comma, + TK_Period, + TK_Literal, + TK_Ident, + TK_InvalidChar, + TK_Error, + TK_CodeCompletion + }; + + /// Some known identifiers. + static const char* const ID_Bind; + static const char *const ID_With; + + TokenInfo() = default; + + StringRef Text; + TokenKind Kind = TK_Eof; + SourceRange Range; + VariantValue Value; +}; + +const char* const Parser::TokenInfo::ID_Bind = "bind"; +const char *const Parser::TokenInfo::ID_With = "with"; + +/// Simple tokenizer for the parser. +class Parser::CodeTokenizer { +public: + explicit CodeTokenizer(StringRef &MatcherCode, Diagnostics *Error) + : Code(MatcherCode), StartOfLine(MatcherCode), Error(Error) { + NextToken = getNextToken(); + } + + CodeTokenizer(StringRef &MatcherCode, Diagnostics *Error, + unsigned CodeCompletionOffset) + : Code(MatcherCode), StartOfLine(MatcherCode), Error(Error), + CodeCompletionLocation(MatcherCode.data() + CodeCompletionOffset) { + NextToken = getNextToken(); + } + + /// Returns but doesn't consume the next token. + const TokenInfo &peekNextToken() const { return NextToken; } + + /// Consumes and returns the next token. + TokenInfo consumeNextToken() { + TokenInfo ThisToken = NextToken; + NextToken = getNextToken(); + return ThisToken; + } + + TokenInfo SkipNewlines() { + while (NextToken.Kind == TokenInfo::TK_NewLine) + NextToken = getNextToken(); + return NextToken; + } + + TokenInfo consumeNextTokenIgnoreNewlines() { + SkipNewlines(); + if (NextToken.Kind == TokenInfo::TK_Eof) + return NextToken; + return consumeNextToken(); + } + + TokenInfo::TokenKind nextTokenKind() const { return NextToken.Kind; } + +private: + TokenInfo getNextToken() { + consumeWhitespace(); + TokenInfo Result; + Result.Range.Start = currentLocation(); + + if (CodeCompletionLocation && CodeCompletionLocation <= Code.data()) { + Result.Kind = TokenInfo::TK_CodeCompletion; + Result.Text = StringRef(CodeCompletionLocation, 0); + CodeCompletionLocation = nullptr; + return Result; + } + + if (Code.empty()) { + Result.Kind = TokenInfo::TK_Eof; + Result.Text = ""; + return Result; + } + + switch (Code[0]) { + case '#': + Code = Code.drop_until([](char c) { return c == '\n'; }); + return getNextToken(); + case ',': + Result.Kind = TokenInfo::TK_Comma; + Result.Text = Code.substr(0, 1); + Code = Code.drop_front(); + break; + case '.': + Result.Kind = TokenInfo::TK_Period; + Result.Text = Code.substr(0, 1); + Code = Code.drop_front(); + break; + case '\n': + ++Line; + StartOfLine = Code.drop_front(); + Result.Kind = TokenInfo::TK_NewLine; + Result.Text = Code.substr(0, 1); + Code = Code.drop_front(); + break; + case '(': + Result.Kind = TokenInfo::TK_OpenParen; + Result.Text = Code.substr(0, 1); + Code = Code.drop_front(); + break; + case ')': + Result.Kind = TokenInfo::TK_CloseParen; + Result.Text = Code.substr(0, 1); + Code = Code.drop_front(); + break; + + case '"': + case '\'': + // Parse a string literal. + consumeStringLiteral(&Result); + break; + + case '0': case '1': case '2': case '3': case '4': + case '5': case '6': case '7': case '8': case '9': + // Parse an unsigned and float literal. + consumeNumberLiteral(&Result); + break; + + default: + if (isAlphanumeric(Code[0])) { + // Parse an identifier + size_t TokenLength = 1; + while (true) { + // A code completion location in/immediately after an identifier will + // cause the portion of the identifier before the code completion + // location to become a code completion token. + if (CodeCompletionLocation == Code.data() + TokenLength) { + CodeCompletionLocation = nullptr; + Result.Kind = TokenInfo::TK_CodeCompletion; + Result.Text = Code.substr(0, TokenLength); + Code = Code.drop_front(TokenLength); + return Result; + } + if (TokenLength == Code.size() || !isAlphanumeric(Code[TokenLength])) + break; + ++TokenLength; + } + if (TokenLength == 4 && Code.starts_with("true")) { + Result.Kind = TokenInfo::TK_Literal; + Result.Value = true; + } else if (TokenLength == 5 && Code.starts_with("false")) { + Result.Kind = TokenInfo::TK_Literal; + Result.Value = false; + } else { + Result.Kind = TokenInfo::TK_Ident; + Result.Text = Code.substr(0, TokenLength); + } + Code = Code.drop_front(TokenLength); + } else { + Result.Kind = TokenInfo::TK_InvalidChar; + Result.Text = Code.substr(0, 1); + Code = Code.drop_front(1); + } + break; + } + + Result.Range.End = currentLocation(); + return Result; + } + + /// Consume an unsigned and float literal. + void consumeNumberLiteral(TokenInfo *Result) { + bool isFloatingLiteral = false; + unsigned Length = 1; + if (Code.size() > 1) { + // Consume the 'x' or 'b' radix modifier, if present. + switch (toLowercase(Code[1])) { + case 'x': case 'b': Length = 2; + } + } + while (Length < Code.size() && isHexDigit(Code[Length])) + ++Length; + + // Try to recognize a floating point literal. + while (Length < Code.size()) { + char c = Code[Length]; + if (c == '-' || c == '+' || c == '.' || isHexDigit(c)) { + isFloatingLiteral = true; + Length++; + } else { + break; + } + } + + Result->Text = Code.substr(0, Length); + Code = Code.drop_front(Length); + + if (isFloatingLiteral) { + char *end; + errno = 0; + std::string Text = Result->Text.str(); + double doubleValue = strtod(Text.c_str(), &end); + if (*end == 0 && errno == 0) { + Result->Kind = TokenInfo::TK_Literal; + Result->Value = doubleValue; + return; + } + } else { + unsigned Value; + if (!Result->Text.getAsInteger(0, Value)) { + Result->Kind = TokenInfo::TK_Literal; + Result->Value = Value; + return; + } + } + + SourceRange Range; + Range.Start = Result->Range.Start; + Range.End = currentLocation(); + Error->addError(Range, Error->ET_ParserNumberError) << Result->Text; + Result->Kind = TokenInfo::TK_Error; + } + + /// Consume a string literal. + /// + /// \c Code must be positioned at the start of the literal (the opening + /// quote). Consumed until it finds the same closing quote character. + void consumeStringLiteral(TokenInfo *Result) { + bool InEscape = false; + const char Marker = Code[0]; + for (size_t Length = 1, Size = Code.size(); Length != Size; ++Length) { + if (InEscape) { + InEscape = false; + continue; + } + if (Code[Length] == '\\') { + InEscape = true; + continue; + } + if (Code[Length] == Marker) { + Result->Kind = TokenInfo::TK_Literal; + Result->Text = Code.substr(0, Length + 1); + Result->Value = Code.substr(1, Length - 1); + Code = Code.drop_front(Length + 1); + return; + } + } + + StringRef ErrorText = Code; + Code = Code.drop_front(Code.size()); + SourceRange Range; + Range.Start = Result->Range.Start; + Range.End = currentLocation(); + Error->addError(Range, Error->ET_ParserStringError) << ErrorText; + Result->Kind = TokenInfo::TK_Error; + } + + /// Consume all leading whitespace from \c Code. + void consumeWhitespace() { + // Don't trim newlines. + Code = Code.ltrim(" \t\v\f\r"); + } + + SourceLocation currentLocation() { + SourceLocation Location; + Location.Line = Line; + Location.Column = Code.data() - StartOfLine.data() + 1; + return Location; + } + + StringRef &Code; + StringRef StartOfLine; + unsigned Line = 1; + Diagnostics *Error; + TokenInfo NextToken; + const char *CodeCompletionLocation = nullptr; +}; + +Parser::Sema::~Sema() = default; + +std::vector<ArgKind> Parser::Sema::getAcceptedCompletionTypes( + llvm::ArrayRef<std::pair<MatcherCtor, unsigned>> Context) { + return {}; +} + +std::vector<MatcherCompletion> +Parser::Sema::getMatcherCompletions(llvm::ArrayRef<ArgKind> AcceptedTypes) { + return {}; +} + +struct Parser::ScopedContextEntry { + Parser *P; + + ScopedContextEntry(Parser *P, MatcherCtor C) : P(P) { + P->ContextStack.push_back(std::make_pair(C, 0u)); + } + + ~ScopedContextEntry() { + P->ContextStack.pop_back(); + } + + void nextArg() { + ++P->ContextStack.back().second; + } +}; + +/// Parse expressions that start with an identifier. +/// +/// This function can parse named values and matchers. +/// In case of failure it will try to determine the user's intent to give +/// an appropriate error message. +bool Parser::parseIdentifierPrefixImpl(VariantValue *Value) { + const TokenInfo NameToken = Tokenizer->consumeNextToken(); + + if (Tokenizer->nextTokenKind() != TokenInfo::TK_OpenParen) { + // Parse as a named value. + if (const VariantValue NamedValue = + NamedValues ? NamedValues->lookup(NameToken.Text) + : VariantValue()) { + + if (Tokenizer->nextTokenKind() != TokenInfo::TK_Period) { + *Value = NamedValue; + return true; + } + + std::string BindID; + Tokenizer->consumeNextToken(); + TokenInfo ChainCallToken = Tokenizer->consumeNextToken(); + if (ChainCallToken.Kind == TokenInfo::TK_CodeCompletion) { + addCompletion(ChainCallToken, MatcherCompletion("bind(\"", "bind", 1)); + return false; + } + + if (ChainCallToken.Kind != TokenInfo::TK_Ident || + (ChainCallToken.Text != TokenInfo::ID_Bind && + ChainCallToken.Text != TokenInfo::ID_With)) { + Error->addError(ChainCallToken.Range, + Error->ET_ParserMalformedChainedExpr); + return false; + } + if (ChainCallToken.Text == TokenInfo::ID_With) { + + Diagnostics::Context Ctx(Diagnostics::Context::ConstructMatcher, Error, + NameToken.Text, NameToken.Range); + + Error->addError(ChainCallToken.Range, + Error->ET_RegistryMatcherNoWithSupport); + return false; + } + if (!parseBindID(BindID)) + return false; + + assert(NamedValue.isMatcher()); + std::optional<DynTypedMatcher> Result = + NamedValue.getMatcher().getSingleMatcher(); + if (Result) { + std::optional<DynTypedMatcher> Bound = Result->tryBind(BindID); + if (Bound) { + *Value = VariantMatcher::SingleMatcher(*Bound); + return true; + } + } + return false; + } + + if (Tokenizer->nextTokenKind() == TokenInfo::TK_NewLine) { + Error->addError(Tokenizer->peekNextToken().Range, + Error->ET_ParserNoOpenParen) + << "NewLine"; + return false; + } + + // If the syntax is correct and the name is not a matcher either, report + // unknown named value. + if ((Tokenizer->nextTokenKind() == TokenInfo::TK_Comma || + Tokenizer->nextTokenKind() == TokenInfo::TK_CloseParen || + Tokenizer->nextTokenKind() == TokenInfo::TK_NewLine || + Tokenizer->nextTokenKind() == TokenInfo::TK_Eof) && + !S->lookupMatcherCtor(NameToken.Text)) { + Error->addError(NameToken.Range, Error->ET_RegistryValueNotFound) + << NameToken.Text; + return false; + } + // Otherwise, fallback to the matcher parser. + } + + Tokenizer->SkipNewlines(); + + assert(NameToken.Kind == TokenInfo::TK_Ident); + TokenInfo OpenToken = Tokenizer->consumeNextToken(); + if (OpenToken.Kind != TokenInfo::TK_OpenParen) { + Error->addError(OpenToken.Range, Error->ET_ParserNoOpenParen) + << OpenToken.Text; + return false; + } + + std::optional<MatcherCtor> Ctor = S->lookupMatcherCtor(NameToken.Text); + + // Parse as a matcher expression. + return parseMatcherExpressionImpl(NameToken, OpenToken, Ctor, Value); +} + +bool Parser::parseBindID(std::string &BindID) { + // Parse the parenthesized argument to .bind("foo") + const TokenInfo OpenToken = Tokenizer->consumeNextToken(); + const TokenInfo IDToken = Tokenizer->consumeNextTokenIgnoreNewlines(); + const TokenInfo CloseToken = Tokenizer->consumeNextTokenIgnoreNewlines(); + + // TODO: We could use different error codes for each/some to be more + // explicit about the syntax error. + if (OpenToken.Kind != TokenInfo::TK_OpenParen) { + Error->addError(OpenToken.Range, Error->ET_ParserMalformedBindExpr); + return false; + } + if (IDToken.Kind != TokenInfo::TK_Literal || !IDToken.Value.isString()) { + Error->addError(IDToken.Range, Error->ET_ParserMalformedBindExpr); + return false; + } + if (CloseToken.Kind != TokenInfo::TK_CloseParen) { + Error->addError(CloseToken.Range, Error->ET_ParserMalformedBindExpr); + return false; + } + BindID = IDToken.Value.getString(); + return true; +} + +bool Parser::parseMatcherBuilder(MatcherCtor Ctor, const TokenInfo &NameToken, + const TokenInfo &OpenToken, + VariantValue *Value) { + std::vector<ParserValue> Args; + TokenInfo EndToken; + + Tokenizer->SkipNewlines(); + + { + ScopedContextEntry SCE(this, Ctor); + + while (Tokenizer->nextTokenKind() != TokenInfo::TK_Eof) { + if (Tokenizer->nextTokenKind() == TokenInfo::TK_CloseParen) { + // End of args. + EndToken = Tokenizer->consumeNextToken(); + break; + } + if (!Args.empty()) { + // We must find a , token to continue. + TokenInfo CommaToken = Tokenizer->consumeNextToken(); + if (CommaToken.Kind != TokenInfo::TK_Comma) { + Error->addError(CommaToken.Range, Error->ET_ParserNoComma) + << CommaToken.Text; + return false; + } + } + + Diagnostics::Context Ctx(Diagnostics::Context::MatcherArg, Error, + NameToken.Text, NameToken.Range, + Args.size() + 1); + ParserValue ArgValue; + Tokenizer->SkipNewlines(); + + if (Tokenizer->peekNextToken().Kind == TokenInfo::TK_CodeCompletion) { + addExpressionCompletions(); + return false; + } + + TokenInfo NodeMatcherToken = Tokenizer->consumeNextToken(); + + if (NodeMatcherToken.Kind != TokenInfo::TK_Ident) { + Error->addError(NameToken.Range, Error->ET_ParserFailedToBuildMatcher) + << NameToken.Text; + return false; + } + + ArgValue.Text = NodeMatcherToken.Text; + ArgValue.Range = NodeMatcherToken.Range; + + std::optional<MatcherCtor> MappedMatcher = + S->lookupMatcherCtor(ArgValue.Text); + + if (!MappedMatcher) { + Error->addError(NodeMatcherToken.Range, + Error->ET_RegistryMatcherNotFound) + << NodeMatcherToken.Text; + return false; + } + + ASTNodeKind NK = S->nodeMatcherType(*MappedMatcher); + + if (NK.isNone()) { + Error->addError(NodeMatcherToken.Range, + Error->ET_RegistryNonNodeMatcher) + << NodeMatcherToken.Text; + return false; + } + + ArgValue.Value = NK; + + Tokenizer->SkipNewlines(); + Args.push_back(ArgValue); + + SCE.nextArg(); + } + } + + if (EndToken.Kind == TokenInfo::TK_Eof) { + Error->addError(OpenToken.Range, Error->ET_ParserNoCloseParen); + return false; + } + + internal::MatcherDescriptorPtr BuiltCtor = + S->buildMatcherCtor(Ctor, NameToken.Range, Args, Error); + + if (!BuiltCtor.get()) { + Error->addError(NameToken.Range, Error->ET_ParserFailedToBuildMatcher) + << NameToken.Text; + return false; + } + + std::string BindID; + if (Tokenizer->peekNextToken().Kind == TokenInfo::TK_Period) { + Tokenizer->consumeNextToken(); + TokenInfo ChainCallToken = Tokenizer->consumeNextToken(); + if (ChainCallToken.Kind == TokenInfo::TK_CodeCompletion) { + addCompletion(ChainCallToken, MatcherCompletion("bind(\"", "bind", 1)); + addCompletion(ChainCallToken, MatcherCompletion("with(", "with", 1)); + return false; + } + if (ChainCallToken.Kind != TokenInfo::TK_Ident || + (ChainCallToken.Text != TokenInfo::ID_Bind && + ChainCallToken.Text != TokenInfo::ID_With)) { + Error->addError(ChainCallToken.Range, + Error->ET_ParserMalformedChainedExpr); + return false; + } + if (ChainCallToken.Text == TokenInfo::ID_Bind) { + if (!parseBindID(BindID)) + return false; + Diagnostics::Context Ctx(Diagnostics::Context::ConstructMatcher, Error, + NameToken.Text, NameToken.Range); + SourceRange MatcherRange = NameToken.Range; + MatcherRange.End = ChainCallToken.Range.End; + VariantMatcher Result = S->actOnMatcherExpression( + BuiltCtor.get(), MatcherRange, BindID, {}, Error); + if (Result.isNull()) + return false; + + *Value = Result; + return true; + } else if (ChainCallToken.Text == TokenInfo::ID_With) { + Tokenizer->SkipNewlines(); + + if (Tokenizer->nextTokenKind() != TokenInfo::TK_OpenParen) { + StringRef ErrTxt = Tokenizer->nextTokenKind() == TokenInfo::TK_Eof + ? StringRef("EOF") + : Tokenizer->peekNextToken().Text; + Error->addError(Tokenizer->peekNextToken().Range, + Error->ET_ParserNoOpenParen) + << ErrTxt; + return false; + } + + TokenInfo WithOpenToken = Tokenizer->consumeNextToken(); + + return parseMatcherExpressionImpl(NameToken, WithOpenToken, + BuiltCtor.get(), Value); + } + } + + Diagnostics::Context Ctx(Diagnostics::Context::ConstructMatcher, Error, + NameToken.Text, NameToken.Range); + SourceRange MatcherRange = NameToken.Range; + MatcherRange.End = EndToken.Range.End; + VariantMatcher Result = S->actOnMatcherExpression( + BuiltCtor.get(), MatcherRange, BindID, {}, Error); + if (Result.isNull()) + return false; + + *Value = Result; + return true; +} + +/// Parse and validate a matcher expression. +/// \return \c true on success, in which case \c Value has the matcher parsed. +/// If the input is malformed, or some argument has an error, it +/// returns \c false. +bool Parser::parseMatcherExpressionImpl(const TokenInfo &NameToken, + const TokenInfo &OpenToken, + std::optional<MatcherCtor> Ctor, + VariantValue *Value) { + if (!Ctor) { + Error->addError(NameToken.Range, Error->ET_RegistryMatcherNotFound) + << NameToken.Text; + // Do not return here. We need to continue to give completion suggestions. + } + + if (Ctor && *Ctor && S->isBuilderMatcher(*Ctor)) + return parseMatcherBuilder(*Ctor, NameToken, OpenToken, Value); + + std::vector<ParserValue> Args; + TokenInfo EndToken; + + Tokenizer->SkipNewlines(); + + { + ScopedContextEntry SCE(this, Ctor.value_or(nullptr)); + + while (Tokenizer->nextTokenKind() != TokenInfo::TK_Eof) { + if (Tokenizer->nextTokenKind() == TokenInfo::TK_CloseParen) { + // End of args. + EndToken = Tokenizer->consumeNextToken(); + break; + } + if (!Args.empty()) { + // We must find a , token to continue. + const TokenInfo CommaToken = Tokenizer->consumeNextToken(); + if (CommaToken.Kind != TokenInfo::TK_Comma) { + Error->addError(CommaToken.Range, Error->ET_ParserNoComma) + << CommaToken.Text; + return false; + } + } + + Diagnostics::Context Ctx(Diagnostics::Context::MatcherArg, Error, + NameToken.Text, NameToken.Range, + Args.size() + 1); + ParserValue ArgValue; + Tokenizer->SkipNewlines(); + ArgValue.Text = Tokenizer->peekNextToken().Text; + ArgValue.Range = Tokenizer->peekNextToken().Range; + if (!parseExpressionImpl(&ArgValue.Value)) { + return false; + } + + Tokenizer->SkipNewlines(); + Args.push_back(ArgValue); + SCE.nextArg(); + } + } + + if (EndToken.Kind == TokenInfo::TK_Eof) { + Error->addError(OpenToken.Range, Error->ET_ParserNoCloseParen); + return false; + } + + std::string BindID; + if (Tokenizer->peekNextToken().Kind == TokenInfo::TK_Period) { + Tokenizer->consumeNextToken(); + TokenInfo ChainCallToken = Tokenizer->consumeNextToken(); + if (ChainCallToken.Kind == TokenInfo::TK_CodeCompletion) { + addCompletion(ChainCallToken, MatcherCompletion("bind(\"", "bind", 1)); + return false; + } + + if (ChainCallToken.Kind != TokenInfo::TK_Ident) { + Error->addError(ChainCallToken.Range, + Error->ET_ParserMalformedChainedExpr); + return false; + } + if (ChainCallToken.Text == TokenInfo::ID_With) { + + Diagnostics::Context Ctx(Diagnostics::Context::ConstructMatcher, Error, + NameToken.Text, NameToken.Range); + + Error->addError(ChainCallToken.Range, + Error->ET_RegistryMatcherNoWithSupport); + return false; + } + if (ChainCallToken.Text != TokenInfo::ID_Bind) { + Error->addError(ChainCallToken.Range, + Error->ET_ParserMalformedChainedExpr); + return false; + } + if (!parseBindID(BindID)) + return false; + } + + if (!Ctor) + return false; + + // Merge the start and end infos. + Diagnostics::Context Ctx(Diagnostics::Context::ConstructMatcher, Error, + NameToken.Text, NameToken.Range); + SourceRange MatcherRange = NameToken.Range; + MatcherRange.End = EndToken.Range.End; + VariantMatcher Result = S->actOnMatcherExpression( + *Ctor, MatcherRange, BindID, Args, Error); + if (Result.isNull()) return false; + + *Value = Result; + return true; +} + +// If the prefix of this completion matches the completion token, add it to +// Completions minus the prefix. +void Parser::addCompletion(const TokenInfo &CompToken, + const MatcherCompletion& Completion) { + if (StringRef(Completion.TypedText).starts_with(CompToken.Text) && + Completion.Specificity > 0) { + Completions.emplace_back(Completion.TypedText.substr(CompToken.Text.size()), + Completion.MatcherDecl, Completion.Specificity); + } +} + +std::vector<MatcherCompletion> Parser::getNamedValueCompletions( + ArrayRef<ArgKind> AcceptedTypes) { + if (!NamedValues) return std::vector<MatcherCompletion>(); + std::vector<MatcherCompletion> Result; + for (const auto &Entry : *NamedValues) { + unsigned Specificity; + if (Entry.getValue().isConvertibleTo(AcceptedTypes, &Specificity)) { + std::string Decl = + (Entry.getValue().getTypeAsString() + " " + Entry.getKey()).str(); + Result.emplace_back(Entry.getKey(), Decl, Specificity); + } + } + return Result; +} + +void Parser::addExpressionCompletions() { + const TokenInfo CompToken = Tokenizer->consumeNextTokenIgnoreNewlines(); + assert(CompToken.Kind == TokenInfo::TK_CodeCompletion); + + // We cannot complete code if there is an invalid element on the context + // stack. + for (ContextStackTy::iterator I = ContextStack.begin(), + E = ContextStack.end(); + I != E; ++I) { + if (!I->first) + return; + } + + auto AcceptedTypes = S->getAcceptedCompletionTypes(ContextStack); + for (const auto &Completion : S->getMatcherCompletions(AcceptedTypes)) { + addCompletion(CompToken, Completion); + } + + for (const auto &Completion : getNamedValueCompletions(AcceptedTypes)) { + addCompletion(CompToken, Completion); + } +} + +/// Parse an <Expression> +bool Parser::parseExpressionImpl(VariantValue *Value) { + switch (Tokenizer->nextTokenKind()) { + case TokenInfo::TK_Literal: + *Value = Tokenizer->consumeNextToken().Value; + return true; + + case TokenInfo::TK_Ident: + return parseIdentifierPrefixImpl(Value); + + case TokenInfo::TK_CodeCompletion: + addExpressionCompletions(); + return false; + + case TokenInfo::TK_Eof: + Error->addError(Tokenizer->consumeNextToken().Range, + Error->ET_ParserNoCode); + return false; + + case TokenInfo::TK_Error: + // This error was already reported by the tokenizer. + return false; + case TokenInfo::TK_NewLine: + case TokenInfo::TK_OpenParen: + case TokenInfo::TK_CloseParen: + case TokenInfo::TK_Comma: + case TokenInfo::TK_Period: + case TokenInfo::TK_InvalidChar: + const TokenInfo Token = Tokenizer->consumeNextToken(); + Error->addError(Token.Range, Error->ET_ParserInvalidToken) + << (Token.Kind == TokenInfo::TK_NewLine ? "NewLine" : Token.Text); + return false; + } + + llvm_unreachable("Unknown token kind."); +} + +static llvm::ManagedStatic<Parser::RegistrySema> DefaultRegistrySema; + +Parser::Parser(CodeTokenizer *Tokenizer, Sema *S, + const NamedValueMap *NamedValues, Diagnostics *Error) + : Tokenizer(Tokenizer), S(S ? S : &*DefaultRegistrySema), + NamedValues(NamedValues), Error(Error) {} + +Parser::RegistrySema::~RegistrySema() = default; + +std::optional<MatcherCtor> +Parser::RegistrySema::lookupMatcherCtor(StringRef MatcherName) { + return Registry::lookupMatcherCtor(MatcherName); +} + +VariantMatcher Parser::RegistrySema::actOnMatcherExpression( + MatcherCtor Ctor, SourceRange NameRange, StringRef BindID, + ArrayRef<ParserValue> Args, Diagnostics *Error) { + if (BindID.empty()) { + return Registry::constructMatcher(Ctor, NameRange, Args, Error); + } else { + return Registry::constructBoundMatcher(Ctor, NameRange, BindID, Args, + Error); + } +} + +std::vector<ArgKind> Parser::RegistrySema::getAcceptedCompletionTypes( + ArrayRef<std::pair<MatcherCtor, unsigned>> Context) { + return Registry::getAcceptedCompletionTypes(Context); +} + +std::vector<MatcherCompletion> Parser::RegistrySema::getMatcherCompletions( + ArrayRef<ArgKind> AcceptedTypes) { + return Registry::getMatcherCompletions(AcceptedTypes); +} + +bool Parser::RegistrySema::isBuilderMatcher(MatcherCtor Ctor) const { + return Registry::isBuilderMatcher(Ctor); +} + +ASTNodeKind Parser::RegistrySema::nodeMatcherType(MatcherCtor Ctor) const { + return Registry::nodeMatcherType(Ctor); +} + +internal::MatcherDescriptorPtr +Parser::RegistrySema::buildMatcherCtor(MatcherCtor Ctor, SourceRange NameRange, + ArrayRef<ParserValue> Args, + Diagnostics *Error) const { + return Registry::buildMatcherCtor(Ctor, NameRange, Args, Error); +} + +bool Parser::parseExpression(StringRef &Code, Sema *S, + const NamedValueMap *NamedValues, + VariantValue *Value, Diagnostics *Error) { + CodeTokenizer Tokenizer(Code, Error); + if (!Parser(&Tokenizer, S, NamedValues, Error).parseExpressionImpl(Value)) + return false; + auto NT = Tokenizer.peekNextToken(); + if (NT.Kind != TokenInfo::TK_Eof && NT.Kind != TokenInfo::TK_NewLine) { + Error->addError(Tokenizer.peekNextToken().Range, + Error->ET_ParserTrailingCode); + return false; + } + return true; +} + +std::vector<MatcherCompletion> +Parser::completeExpression(StringRef &Code, unsigned CompletionOffset, Sema *S, + const NamedValueMap *NamedValues) { + Diagnostics Error; + CodeTokenizer Tokenizer(Code, &Error, CompletionOffset); + Parser P(&Tokenizer, S, NamedValues, &Error); + VariantValue Dummy; + P.parseExpressionImpl(&Dummy); + + // Sort by specificity, then by name. + llvm::sort(P.Completions, + [](const MatcherCompletion &A, const MatcherCompletion &B) { + if (A.Specificity != B.Specificity) + return A.Specificity > B.Specificity; + return A.TypedText < B.TypedText; + }); + + return P.Completions; +} + +std::optional<DynTypedMatcher> +Parser::parseMatcherExpression(StringRef &Code, Sema *S, + const NamedValueMap *NamedValues, + Diagnostics *Error) { + VariantValue Value; + if (!parseExpression(Code, S, NamedValues, &Value, Error)) + return std::nullopt; + if (!Value.isMatcher()) { + Error->addError(SourceRange(), Error->ET_ParserNotAMatcher); + return std::nullopt; + } + std::optional<DynTypedMatcher> Result = Value.getMatcher().getSingleMatcher(); + if (!Result) { + Error->addError(SourceRange(), Error->ET_ParserOverloadedType) + << Value.getTypeAsString(); + } + return Result; +} + +} // namespace dynamic +} // namespace ast_matchers +} // namespace clang diff --git a/contrib/llvm-project/clang/lib/ASTMatchers/Dynamic/Registry.cpp b/contrib/llvm-project/clang/lib/ASTMatchers/Dynamic/Registry.cpp new file mode 100644 index 000000000000..2c75e6beb743 --- /dev/null +++ b/contrib/llvm-project/clang/lib/ASTMatchers/Dynamic/Registry.cpp @@ -0,0 +1,829 @@ +//===- Registry.cpp - Matcher registry ------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +/// \file +/// Registry map populated at static initialization time. +// +//===----------------------------------------------------------------------===// + +#include "clang/ASTMatchers/Dynamic/Registry.h" +#include "Marshallers.h" +#include "clang/AST/ASTTypeTraits.h" +#include "clang/ASTMatchers/ASTMatchers.h" +#include "clang/ASTMatchers/Dynamic/Diagnostics.h" +#include "clang/ASTMatchers/Dynamic/VariantValue.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringMap.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/ManagedStatic.h" +#include "llvm/Support/raw_ostream.h" +#include <cassert> +#include <iterator> +#include <memory> +#include <optional> +#include <set> +#include <string> +#include <utility> +#include <vector> + +namespace clang { +namespace ast_matchers { +namespace dynamic { + +namespace { + +using internal::MatcherDescriptor; + +using ConstructorMap = + llvm::StringMap<std::unique_ptr<const MatcherDescriptor>>; + +class RegistryMaps { +public: + RegistryMaps(); + ~RegistryMaps(); + + const ConstructorMap &constructors() const { return Constructors; } + +private: + void registerMatcher(StringRef MatcherName, + std::unique_ptr<MatcherDescriptor> Callback); + + ConstructorMap Constructors; +}; + +} // namespace + +void RegistryMaps::registerMatcher( + StringRef MatcherName, std::unique_ptr<MatcherDescriptor> Callback) { + assert(!Constructors.contains(MatcherName)); + Constructors[MatcherName] = std::move(Callback); +} + +#define REGISTER_MATCHER(name) \ + registerMatcher(#name, internal::makeMatcherAutoMarshall( \ + ::clang::ast_matchers::name, #name)); + +#define REGISTER_MATCHER_OVERLOAD(name) \ + registerMatcher(#name, \ + std::make_unique<internal::OverloadedMatcherDescriptor>(name##Callbacks)) + +#define SPECIFIC_MATCHER_OVERLOAD(name, Id) \ + static_cast<::clang::ast_matchers::name##_Type##Id>( \ + ::clang::ast_matchers::name) + +#define MATCHER_OVERLOAD_ENTRY(name, Id) \ + internal::makeMatcherAutoMarshall(SPECIFIC_MATCHER_OVERLOAD(name, Id), \ + #name) + +#define REGISTER_OVERLOADED_2(name) \ + do { \ + std::unique_ptr<MatcherDescriptor> name##Callbacks[] = { \ + MATCHER_OVERLOAD_ENTRY(name, 0), \ + MATCHER_OVERLOAD_ENTRY(name, 1)}; \ + REGISTER_MATCHER_OVERLOAD(name); \ + } while (false) + +#define REGISTER_REGEX_MATCHER(name) \ + registerMatcher(#name, internal::makeMatcherRegexMarshall(name, name)) + +/// Generate a registry map with all the known matchers. +/// Please keep sorted alphabetically! +RegistryMaps::RegistryMaps() { + // TODO: Here is the list of the missing matchers, grouped by reason. + // + // Polymorphic + argument overload: + // findAll + // + // Other: + // equalsNode + + registerMatcher("mapAnyOf", + std::make_unique<internal::MapAnyOfBuilderDescriptor>()); + + REGISTER_OVERLOADED_2(callee); + REGISTER_OVERLOADED_2(hasPrefix); + REGISTER_OVERLOADED_2(hasType); + REGISTER_OVERLOADED_2(ignoringParens); + REGISTER_OVERLOADED_2(isDerivedFrom); + REGISTER_OVERLOADED_2(isDirectlyDerivedFrom); + REGISTER_OVERLOADED_2(isSameOrDerivedFrom); + REGISTER_OVERLOADED_2(loc); + REGISTER_OVERLOADED_2(pointsTo); + REGISTER_OVERLOADED_2(references); + REGISTER_OVERLOADED_2(thisPointerType); + + std::unique_ptr<MatcherDescriptor> equalsCallbacks[] = { + MATCHER_OVERLOAD_ENTRY(equals, 0), + MATCHER_OVERLOAD_ENTRY(equals, 1), + MATCHER_OVERLOAD_ENTRY(equals, 2), + }; + REGISTER_MATCHER_OVERLOAD(equals); + + REGISTER_REGEX_MATCHER(isExpansionInFileMatching); + REGISTER_REGEX_MATCHER(matchesName); + REGISTER_REGEX_MATCHER(matchesSelector); + + REGISTER_MATCHER(accessSpecDecl); + REGISTER_MATCHER(addrLabelExpr); + REGISTER_MATCHER(alignOfExpr); + REGISTER_MATCHER(allOf); + REGISTER_MATCHER(anyOf); + REGISTER_MATCHER(anything); + REGISTER_MATCHER(arrayInitIndexExpr); + REGISTER_MATCHER(arrayInitLoopExpr); + REGISTER_MATCHER(argumentCountIs); + REGISTER_MATCHER(argumentCountAtLeast); + REGISTER_MATCHER(arraySubscriptExpr); + REGISTER_MATCHER(arrayType); + REGISTER_MATCHER(asString); + REGISTER_MATCHER(asmStmt); + REGISTER_MATCHER(atomicExpr); + REGISTER_MATCHER(atomicType); + REGISTER_MATCHER(attr); + REGISTER_MATCHER(autoType); + REGISTER_MATCHER(autoreleasePoolStmt) + REGISTER_MATCHER(binaryConditionalOperator); + REGISTER_MATCHER(binaryOperator); + REGISTER_MATCHER(binaryOperation); + REGISTER_MATCHER(bindingDecl); + REGISTER_MATCHER(blockDecl); + REGISTER_MATCHER(blockExpr); + REGISTER_MATCHER(blockPointerType); + REGISTER_MATCHER(booleanType); + REGISTER_MATCHER(breakStmt); + REGISTER_MATCHER(builtinType); + REGISTER_MATCHER(cStyleCastExpr); + REGISTER_MATCHER(callExpr); + REGISTER_MATCHER(capturesThis); + REGISTER_MATCHER(capturesVar); + REGISTER_MATCHER(caseStmt); + REGISTER_MATCHER(castExpr); + REGISTER_MATCHER(characterLiteral); + REGISTER_MATCHER(chooseExpr); + REGISTER_MATCHER(classTemplateDecl); + REGISTER_MATCHER(classTemplatePartialSpecializationDecl); + REGISTER_MATCHER(classTemplateSpecializationDecl); + REGISTER_MATCHER(complexType); + REGISTER_MATCHER(compoundLiteralExpr); + REGISTER_MATCHER(compoundStmt); + REGISTER_MATCHER(coawaitExpr); + REGISTER_MATCHER(conceptDecl); + REGISTER_MATCHER(conditionalOperator); + REGISTER_MATCHER(constantArrayType); + REGISTER_MATCHER(constantExpr); + REGISTER_MATCHER(containsDeclaration); + REGISTER_MATCHER(continueStmt); + REGISTER_MATCHER(convertVectorExpr); + REGISTER_MATCHER(coreturnStmt); + REGISTER_MATCHER(coroutineBodyStmt); + REGISTER_MATCHER(coyieldExpr); + REGISTER_MATCHER(cudaKernelCallExpr); + REGISTER_MATCHER(cxxBaseSpecifier); + REGISTER_MATCHER(cxxBindTemporaryExpr); + REGISTER_MATCHER(cxxBoolLiteral); + REGISTER_MATCHER(cxxCatchStmt); + REGISTER_MATCHER(cxxConstCastExpr); + REGISTER_MATCHER(cxxConstructExpr); + REGISTER_MATCHER(cxxConstructorDecl); + REGISTER_MATCHER(cxxConversionDecl); + REGISTER_MATCHER(cxxCtorInitializer); + REGISTER_MATCHER(cxxDeductionGuideDecl); + REGISTER_MATCHER(cxxDefaultArgExpr); + REGISTER_MATCHER(cxxDeleteExpr); + REGISTER_MATCHER(cxxDependentScopeMemberExpr); + REGISTER_MATCHER(cxxDestructorDecl); + REGISTER_MATCHER(cxxDynamicCastExpr); + REGISTER_MATCHER(cxxFoldExpr); + REGISTER_MATCHER(cxxForRangeStmt); + REGISTER_MATCHER(cxxFunctionalCastExpr); + REGISTER_MATCHER(cxxMemberCallExpr); + REGISTER_MATCHER(cxxMethodDecl); + REGISTER_MATCHER(cxxNewExpr); + REGISTER_MATCHER(cxxNoexceptExpr); + REGISTER_MATCHER(cxxNullPtrLiteralExpr); + REGISTER_MATCHER(cxxOperatorCallExpr); + REGISTER_MATCHER(cxxRecordDecl); + REGISTER_MATCHER(cxxReinterpretCastExpr); + REGISTER_MATCHER(cxxRewrittenBinaryOperator); + REGISTER_MATCHER(cxxStaticCastExpr); + REGISTER_MATCHER(cxxStdInitializerListExpr); + REGISTER_MATCHER(cxxTemporaryObjectExpr); + REGISTER_MATCHER(cxxThisExpr); + REGISTER_MATCHER(cxxThrowExpr); + REGISTER_MATCHER(cxxTryStmt); + REGISTER_MATCHER(cxxUnresolvedConstructExpr); + REGISTER_MATCHER(decayedType); + REGISTER_MATCHER(decl); + REGISTER_MATCHER(decompositionDecl); + REGISTER_MATCHER(declCountIs); + REGISTER_MATCHER(declRefExpr); + REGISTER_MATCHER(declStmt); + REGISTER_MATCHER(declaratorDecl); + REGISTER_MATCHER(decltypeType); + REGISTER_MATCHER(deducedTemplateSpecializationType); + REGISTER_MATCHER(defaultStmt); + REGISTER_MATCHER(dependentCoawaitExpr); + REGISTER_MATCHER(dependentSizedArrayType); + REGISTER_MATCHER(dependentSizedExtVectorType); + REGISTER_MATCHER(designatedInitExpr); + REGISTER_MATCHER(designatorCountIs); + REGISTER_MATCHER(doStmt); + REGISTER_MATCHER(eachOf); + REGISTER_MATCHER(elaboratedType); + REGISTER_MATCHER(elaboratedTypeLoc); + REGISTER_MATCHER(usingType); + REGISTER_MATCHER(enumConstantDecl); + REGISTER_MATCHER(enumDecl); + REGISTER_MATCHER(enumType); + REGISTER_MATCHER(equalsBoundNode); + REGISTER_MATCHER(equalsIntegralValue); + REGISTER_MATCHER(explicitCastExpr); + REGISTER_MATCHER(expr); + REGISTER_MATCHER(exprWithCleanups); + REGISTER_MATCHER(fieldDecl); + REGISTER_MATCHER(fixedPointLiteral); + REGISTER_MATCHER(floatLiteral); + REGISTER_MATCHER(forCallable); + REGISTER_MATCHER(forDecomposition); + REGISTER_MATCHER(forEach); + REGISTER_MATCHER(forEachArgumentWithParam); + REGISTER_MATCHER(forEachArgumentWithParamType); + REGISTER_MATCHER(forEachConstructorInitializer); + REGISTER_MATCHER(forEachDescendant); + REGISTER_MATCHER(forEachLambdaCapture); + REGISTER_MATCHER(forEachOverridden); + REGISTER_MATCHER(forEachSwitchCase); + REGISTER_MATCHER(forEachTemplateArgument); + REGISTER_MATCHER(forField); + REGISTER_MATCHER(forFunction); + REGISTER_MATCHER(forStmt); + REGISTER_MATCHER(friendDecl); + REGISTER_MATCHER(functionDecl); + REGISTER_MATCHER(functionProtoType); + REGISTER_MATCHER(functionTemplateDecl); + REGISTER_MATCHER(functionType); + REGISTER_MATCHER(genericSelectionExpr); + REGISTER_MATCHER(gnuNullExpr); + REGISTER_MATCHER(gotoStmt); + REGISTER_MATCHER(has); + REGISTER_MATCHER(hasAncestor); + REGISTER_MATCHER(hasAnyArgument); + REGISTER_MATCHER(hasAnyBase); + REGISTER_MATCHER(hasAnyBinding); + REGISTER_MATCHER(hasAnyBody); + REGISTER_MATCHER(hasAnyCapture); + REGISTER_MATCHER(hasAnyClause); + REGISTER_MATCHER(hasAnyConstructorInitializer); + REGISTER_MATCHER(hasAnyDeclaration); + REGISTER_MATCHER(hasAnyName); + REGISTER_MATCHER(hasAnyOperatorName); + REGISTER_MATCHER(hasAnyOverloadedOperatorName); + REGISTER_MATCHER(hasAnyParameter); + REGISTER_MATCHER(hasAnyPlacementArg); + REGISTER_MATCHER(hasAnySelector); + REGISTER_MATCHER(hasAnySubstatement); + REGISTER_MATCHER(hasAnyTemplateArgument); + REGISTER_MATCHER(hasAnyTemplateArgumentLoc); + REGISTER_MATCHER(hasAnyUsingShadowDecl); + REGISTER_MATCHER(hasArgument); + REGISTER_MATCHER(hasArgumentOfType); + REGISTER_MATCHER(hasArraySize); + REGISTER_MATCHER(hasAttr); + REGISTER_MATCHER(hasAutomaticStorageDuration); + REGISTER_MATCHER(hasBase); + REGISTER_MATCHER(hasBinding); + REGISTER_MATCHER(hasBitWidth); + REGISTER_MATCHER(hasBody); + REGISTER_MATCHER(hasCanonicalType); + REGISTER_MATCHER(hasCaseConstant); + REGISTER_MATCHER(hasCastKind); + REGISTER_MATCHER(hasCondition); + REGISTER_MATCHER(hasConditionVariableStatement); + REGISTER_MATCHER(hasDecayedType); + REGISTER_MATCHER(hasDeclContext); + REGISTER_MATCHER(hasDeclaration); + REGISTER_MATCHER(hasDeducedType); + REGISTER_MATCHER(hasDefaultArgument); + REGISTER_MATCHER(hasDefinition); + REGISTER_MATCHER(hasDescendant); + REGISTER_MATCHER(hasDestinationType); + REGISTER_MATCHER(hasDirectBase); + REGISTER_MATCHER(hasDynamicExceptionSpec); + REGISTER_MATCHER(hasEitherOperand); + REGISTER_MATCHER(hasElementType); + REGISTER_MATCHER(hasElse); + REGISTER_MATCHER(hasExplicitSpecifier); + REGISTER_MATCHER(hasExternalFormalLinkage); + REGISTER_MATCHER(hasFalseExpression); + REGISTER_MATCHER(hasFoldInit); + REGISTER_MATCHER(hasGlobalStorage); + REGISTER_MATCHER(hasImplicitDestinationType); + REGISTER_MATCHER(hasInClassInitializer); + REGISTER_MATCHER(hasIncrement); + REGISTER_MATCHER(hasIndex); + REGISTER_MATCHER(hasInit); + REGISTER_MATCHER(hasInitializer); + REGISTER_MATCHER(hasInitStatement); + REGISTER_MATCHER(hasKeywordSelector); + REGISTER_MATCHER(hasLHS); + REGISTER_MATCHER(hasLocalQualifiers); + REGISTER_MATCHER(hasLocalStorage); + REGISTER_MATCHER(hasLoopInit); + REGISTER_MATCHER(hasLoopVariable); + REGISTER_MATCHER(hasMemberName); + REGISTER_MATCHER(hasMethod); + REGISTER_MATCHER(hasName); + REGISTER_MATCHER(hasNamedTypeLoc); + REGISTER_MATCHER(hasNullSelector); + REGISTER_MATCHER(hasObjectExpression); + REGISTER_MATCHER(hasOperands); + REGISTER_MATCHER(hasOperatorName); + REGISTER_MATCHER(hasOverloadedOperatorName); + REGISTER_MATCHER(hasParameter); + REGISTER_MATCHER(hasParent); + REGISTER_MATCHER(hasPattern); + REGISTER_MATCHER(hasPointeeLoc); + REGISTER_MATCHER(hasQualifier); + REGISTER_MATCHER(hasRHS); + REGISTER_MATCHER(hasRangeInit); + REGISTER_MATCHER(hasReceiver); + REGISTER_MATCHER(hasReceiverType); + REGISTER_MATCHER(hasReferentLoc); + REGISTER_MATCHER(hasReplacementType); + REGISTER_MATCHER(hasReturnTypeLoc); + REGISTER_MATCHER(hasReturnValue); + REGISTER_MATCHER(hasPlacementArg); + REGISTER_MATCHER(hasSelector); + REGISTER_MATCHER(hasSingleDecl); + REGISTER_MATCHER(hasSize); + REGISTER_MATCHER(hasSizeExpr); + REGISTER_MATCHER(hasSourceExpression); + REGISTER_MATCHER(hasSpecializedTemplate); + REGISTER_MATCHER(hasStaticStorageDuration); + REGISTER_MATCHER(hasStructuredBlock); + REGISTER_MATCHER(hasSyntacticForm); + REGISTER_MATCHER(hasTargetDecl); + REGISTER_MATCHER(hasTemplateArgument); + REGISTER_MATCHER(hasTemplateArgumentLoc); + REGISTER_MATCHER(hasThen); + REGISTER_MATCHER(hasThreadStorageDuration); + REGISTER_MATCHER(hasTrailingReturn); + REGISTER_MATCHER(hasTrueExpression); + REGISTER_MATCHER(hasTypeLoc); + REGISTER_MATCHER(hasUnaryOperand); + REGISTER_MATCHER(hasUnarySelector); + REGISTER_MATCHER(hasUnderlyingDecl); + REGISTER_MATCHER(hasUnderlyingType); + REGISTER_MATCHER(hasUnqualifiedDesugaredType); + REGISTER_MATCHER(hasUnqualifiedLoc); + REGISTER_MATCHER(hasValueType); + REGISTER_MATCHER(ifStmt); + REGISTER_MATCHER(ignoringElidableConstructorCall); + REGISTER_MATCHER(ignoringImpCasts); + REGISTER_MATCHER(ignoringImplicit); + REGISTER_MATCHER(ignoringParenCasts); + REGISTER_MATCHER(ignoringParenImpCasts); + REGISTER_MATCHER(imaginaryLiteral); + REGISTER_MATCHER(implicitCastExpr); + REGISTER_MATCHER(implicitValueInitExpr); + REGISTER_MATCHER(incompleteArrayType); + REGISTER_MATCHER(indirectFieldDecl); + REGISTER_MATCHER(initListExpr); + REGISTER_MATCHER(injectedClassNameType); + REGISTER_MATCHER(innerType); + REGISTER_MATCHER(integerLiteral); + REGISTER_MATCHER(invocation); + REGISTER_MATCHER(isAllowedToContainClauseKind); + REGISTER_MATCHER(isAnonymous); + REGISTER_MATCHER(isAnyCharacter); + REGISTER_MATCHER(isAnyPointer); + REGISTER_MATCHER(isArray); + REGISTER_MATCHER(isArrow); + REGISTER_MATCHER(isAssignmentOperator); + REGISTER_MATCHER(isAtPosition); + REGISTER_MATCHER(isBaseInitializer); + REGISTER_MATCHER(isBinaryFold); + REGISTER_MATCHER(isBitField); + REGISTER_MATCHER(isCatchAll); + REGISTER_MATCHER(isClass); + REGISTER_MATCHER(isClassMessage); + REGISTER_MATCHER(isClassMethod); + REGISTER_MATCHER(isComparisonOperator); + REGISTER_MATCHER(isConst); + REGISTER_MATCHER(isConstQualified); + REGISTER_MATCHER(isConsteval); + REGISTER_MATCHER(isConstexpr); + REGISTER_MATCHER(isConstinit); + REGISTER_MATCHER(isCopyAssignmentOperator); + REGISTER_MATCHER(isCopyConstructor); + REGISTER_MATCHER(isDefaultConstructor); + REGISTER_MATCHER(isDefaulted); + REGISTER_MATCHER(isDefinition); + REGISTER_MATCHER(isDelegatingConstructor); + REGISTER_MATCHER(isDeleted); + REGISTER_MATCHER(isEnum); + REGISTER_MATCHER(isExceptionVariable); + REGISTER_MATCHER(isExpandedFromMacro); + REGISTER_MATCHER(isExpansionInMainFile); + REGISTER_MATCHER(isExpansionInSystemHeader); + REGISTER_MATCHER(isExplicit); + REGISTER_MATCHER(isExplicitObjectMemberFunction); + REGISTER_MATCHER(isExplicitTemplateSpecialization); + REGISTER_MATCHER(isExpr); + REGISTER_MATCHER(isExternC); + REGISTER_MATCHER(isFinal); + REGISTER_MATCHER(isPrivateKind); + REGISTER_MATCHER(isFirstPrivateKind); + REGISTER_MATCHER(isImplicit); + REGISTER_MATCHER(isInAnonymousNamespace); + REGISTER_MATCHER(isInStdNamespace); + REGISTER_MATCHER(isInTemplateInstantiation); + REGISTER_MATCHER(isInitCapture); + REGISTER_MATCHER(isInline); + REGISTER_MATCHER(isInstanceMessage); + REGISTER_MATCHER(isInstanceMethod); + REGISTER_MATCHER(isInstantiated); + REGISTER_MATCHER(isInstantiationDependent); + REGISTER_MATCHER(isInteger); + REGISTER_MATCHER(isIntegral); + REGISTER_MATCHER(isLambda); + REGISTER_MATCHER(isLeftFold); + REGISTER_MATCHER(isListInitialization); + REGISTER_MATCHER(isMain); + REGISTER_MATCHER(isMemberInitializer); + REGISTER_MATCHER(isMoveAssignmentOperator); + REGISTER_MATCHER(isMoveConstructor); + REGISTER_MATCHER(isNoReturn); + REGISTER_MATCHER(isNoThrow); + REGISTER_MATCHER(isNoneKind); + REGISTER_MATCHER(isOverride); + REGISTER_MATCHER(isPrivate); + REGISTER_MATCHER(isProtected); + REGISTER_MATCHER(isPublic); + REGISTER_MATCHER(isPure); + REGISTER_MATCHER(isRightFold); + REGISTER_MATCHER(isScoped); + REGISTER_MATCHER(isSharedKind); + REGISTER_MATCHER(isSignedInteger); + REGISTER_MATCHER(isStandaloneDirective); + REGISTER_MATCHER(isStaticLocal); + REGISTER_MATCHER(isStaticStorageClass); + REGISTER_MATCHER(isStruct); + REGISTER_MATCHER(isTemplateInstantiation); + REGISTER_MATCHER(isTypeDependent); + REGISTER_MATCHER(isUnaryFold); + REGISTER_MATCHER(isUnion); + REGISTER_MATCHER(isUnsignedInteger); + REGISTER_MATCHER(isUserProvided); + REGISTER_MATCHER(isValueDependent); + REGISTER_MATCHER(isVariadic); + REGISTER_MATCHER(isVirtual); + REGISTER_MATCHER(isVirtualAsWritten); + REGISTER_MATCHER(isVolatileQualified); + REGISTER_MATCHER(isWeak); + REGISTER_MATCHER(isWritten); + REGISTER_MATCHER(lValueReferenceType); + REGISTER_MATCHER(labelDecl); + REGISTER_MATCHER(labelStmt); + REGISTER_MATCHER(lambdaCapture); + REGISTER_MATCHER(lambdaExpr); + REGISTER_MATCHER(linkageSpecDecl); + REGISTER_MATCHER(macroQualifiedType); + REGISTER_MATCHER(materializeTemporaryExpr); + REGISTER_MATCHER(member); + REGISTER_MATCHER(memberExpr); + REGISTER_MATCHER(memberHasSameNameAsBoundNode); + REGISTER_MATCHER(memberPointerType); + REGISTER_MATCHER(namedDecl); + REGISTER_MATCHER(namesType); + REGISTER_MATCHER(namespaceAliasDecl); + REGISTER_MATCHER(namespaceDecl); + REGISTER_MATCHER(nestedNameSpecifier); + REGISTER_MATCHER(nestedNameSpecifierLoc); + REGISTER_MATCHER(nonTypeTemplateParmDecl); + REGISTER_MATCHER(nullPointerConstant); + REGISTER_MATCHER(nullStmt); + REGISTER_MATCHER(numSelectorArgs); + REGISTER_MATCHER(objcCatchStmt); + REGISTER_MATCHER(objcCategoryDecl); + REGISTER_MATCHER(objcCategoryImplDecl); + REGISTER_MATCHER(objcFinallyStmt); + REGISTER_MATCHER(objcImplementationDecl); + REGISTER_MATCHER(objcInterfaceDecl); + REGISTER_MATCHER(objcIvarDecl); + REGISTER_MATCHER(objcIvarRefExpr); + REGISTER_MATCHER(objcMessageExpr); + REGISTER_MATCHER(objcMethodDecl); + REGISTER_MATCHER(objcObjectPointerType); + REGISTER_MATCHER(objcPropertyDecl); + REGISTER_MATCHER(objcProtocolDecl); + REGISTER_MATCHER(objcStringLiteral); + REGISTER_MATCHER(objcThrowStmt); + REGISTER_MATCHER(objcTryStmt); + REGISTER_MATCHER(ofClass); + REGISTER_MATCHER(ofKind); + REGISTER_MATCHER(ompDefaultClause); + REGISTER_MATCHER(ompExecutableDirective); + REGISTER_MATCHER(on); + REGISTER_MATCHER(onImplicitObjectArgument); + REGISTER_MATCHER(opaqueValueExpr); + REGISTER_MATCHER(optionally); + REGISTER_MATCHER(parameterCountIs); + REGISTER_MATCHER(parenExpr); + REGISTER_MATCHER(parenListExpr); + REGISTER_MATCHER(parenType); + REGISTER_MATCHER(parmVarDecl); + REGISTER_MATCHER(pointee); + REGISTER_MATCHER(pointerType); + REGISTER_MATCHER(pointerTypeLoc); + REGISTER_MATCHER(predefinedExpr); + REGISTER_MATCHER(qualType); + REGISTER_MATCHER(qualifiedTypeLoc); + REGISTER_MATCHER(rValueReferenceType); + REGISTER_MATCHER(realFloatingPointType); + REGISTER_MATCHER(recordDecl); + REGISTER_MATCHER(recordType); + REGISTER_MATCHER(referenceType); + REGISTER_MATCHER(referenceTypeLoc); + REGISTER_MATCHER(refersToDeclaration); + REGISTER_MATCHER(refersToIntegralType); + REGISTER_MATCHER(refersToTemplate); + REGISTER_MATCHER(refersToType); + REGISTER_MATCHER(requiresZeroInitialization); + REGISTER_MATCHER(returnStmt); + REGISTER_MATCHER(returns); + REGISTER_MATCHER(sizeOfExpr); + REGISTER_MATCHER(specifiesNamespace); + REGISTER_MATCHER(specifiesType); + REGISTER_MATCHER(specifiesTypeLoc); + REGISTER_MATCHER(statementCountIs); + REGISTER_MATCHER(staticAssertDecl); + REGISTER_MATCHER(stmt); + REGISTER_MATCHER(stmtExpr); + REGISTER_MATCHER(stringLiteral); + REGISTER_MATCHER(substNonTypeTemplateParmExpr); + REGISTER_MATCHER(substTemplateTypeParmType); + REGISTER_MATCHER(switchCase); + REGISTER_MATCHER(switchStmt); + REGISTER_MATCHER(tagDecl); + REGISTER_MATCHER(tagType); + REGISTER_MATCHER(templateArgument); + REGISTER_MATCHER(templateArgumentCountIs); + REGISTER_MATCHER(templateArgumentLoc); + REGISTER_MATCHER(templateName); + REGISTER_MATCHER(templateSpecializationType); + REGISTER_MATCHER(templateSpecializationTypeLoc); + REGISTER_MATCHER(templateTemplateParmDecl); + REGISTER_MATCHER(templateTypeParmDecl); + REGISTER_MATCHER(templateTypeParmType); + REGISTER_MATCHER(throughUsingDecl); + REGISTER_MATCHER(to); + REGISTER_MATCHER(translationUnitDecl); + REGISTER_MATCHER(type); + REGISTER_MATCHER(typeAliasDecl); + REGISTER_MATCHER(typeAliasTemplateDecl); + REGISTER_MATCHER(typeLoc); + REGISTER_MATCHER(typedefDecl); + REGISTER_MATCHER(typedefNameDecl); + REGISTER_MATCHER(typedefType); + REGISTER_MATCHER(unaryExprOrTypeTraitExpr); + REGISTER_MATCHER(unaryOperator); + REGISTER_MATCHER(unaryTransformType); + REGISTER_MATCHER(unless); + REGISTER_MATCHER(unresolvedLookupExpr); + REGISTER_MATCHER(unresolvedMemberExpr); + REGISTER_MATCHER(unresolvedUsingTypenameDecl); + REGISTER_MATCHER(unresolvedUsingValueDecl); + REGISTER_MATCHER(userDefinedLiteral); + REGISTER_MATCHER(usesADL); + REGISTER_MATCHER(usingDecl); + REGISTER_MATCHER(usingEnumDecl); + REGISTER_MATCHER(usingDirectiveDecl); + REGISTER_MATCHER(valueDecl); + REGISTER_MATCHER(varDecl); + REGISTER_MATCHER(variableArrayType); + REGISTER_MATCHER(voidType); + REGISTER_MATCHER(whileStmt); + REGISTER_MATCHER(withInitializer); +} + +RegistryMaps::~RegistryMaps() = default; + +static llvm::ManagedStatic<RegistryMaps> RegistryData; + +ASTNodeKind Registry::nodeMatcherType(MatcherCtor Ctor) { + return Ctor->nodeMatcherType(); +} + +internal::MatcherDescriptorPtr::MatcherDescriptorPtr(MatcherDescriptor *Ptr) + : Ptr(Ptr) {} + +internal::MatcherDescriptorPtr::~MatcherDescriptorPtr() { delete Ptr; } + +bool Registry::isBuilderMatcher(MatcherCtor Ctor) { + return Ctor->isBuilderMatcher(); +} + +internal::MatcherDescriptorPtr +Registry::buildMatcherCtor(MatcherCtor Ctor, SourceRange NameRange, + ArrayRef<ParserValue> Args, Diagnostics *Error) { + return internal::MatcherDescriptorPtr( + Ctor->buildMatcherCtor(NameRange, Args, Error).release()); +} + +// static +std::optional<MatcherCtor> Registry::lookupMatcherCtor(StringRef MatcherName) { + auto it = RegistryData->constructors().find(MatcherName); + return it == RegistryData->constructors().end() ? std::optional<MatcherCtor>() + : it->second.get(); +} + +static llvm::raw_ostream &operator<<(llvm::raw_ostream &OS, + const std::set<ASTNodeKind> &KS) { + unsigned Count = 0; + for (std::set<ASTNodeKind>::const_iterator I = KS.begin(), E = KS.end(); + I != E; ++I) { + if (I != KS.begin()) + OS << "|"; + if (Count++ == 3) { + OS << "..."; + break; + } + OS << *I; + } + return OS; +} + +std::vector<ArgKind> Registry::getAcceptedCompletionTypes( + ArrayRef<std::pair<MatcherCtor, unsigned>> Context) { + ASTNodeKind InitialTypes[] = { + ASTNodeKind::getFromNodeKind<Decl>(), + ASTNodeKind::getFromNodeKind<QualType>(), + ASTNodeKind::getFromNodeKind<Type>(), + ASTNodeKind::getFromNodeKind<Stmt>(), + ASTNodeKind::getFromNodeKind<NestedNameSpecifier>(), + ASTNodeKind::getFromNodeKind<NestedNameSpecifierLoc>(), + ASTNodeKind::getFromNodeKind<TypeLoc>()}; + + // Starting with the above seed of acceptable top-level matcher types, compute + // the acceptable type set for the argument indicated by each context element. + std::set<ArgKind> TypeSet; + for (auto IT : InitialTypes) { + TypeSet.insert(ArgKind::MakeMatcherArg(IT)); + } + for (const auto &CtxEntry : Context) { + MatcherCtor Ctor = CtxEntry.first; + unsigned ArgNumber = CtxEntry.second; + std::vector<ArgKind> NextTypeSet; + for (const ArgKind &Kind : TypeSet) { + if (Kind.getArgKind() == Kind.AK_Matcher && + Ctor->isConvertibleTo(Kind.getMatcherKind()) && + (Ctor->isVariadic() || ArgNumber < Ctor->getNumArgs())) + Ctor->getArgKinds(Kind.getMatcherKind(), ArgNumber, NextTypeSet); + } + TypeSet.clear(); + TypeSet.insert(NextTypeSet.begin(), NextTypeSet.end()); + } + return std::vector<ArgKind>(TypeSet.begin(), TypeSet.end()); +} + +std::vector<MatcherCompletion> +Registry::getMatcherCompletions(ArrayRef<ArgKind> AcceptedTypes) { + std::vector<MatcherCompletion> Completions; + + // Search the registry for acceptable matchers. + for (const auto &M : RegistryData->constructors()) { + const MatcherDescriptor& Matcher = *M.getValue(); + StringRef Name = M.getKey(); + + std::set<ASTNodeKind> RetKinds; + unsigned NumArgs = Matcher.isVariadic() ? 1 : Matcher.getNumArgs(); + bool IsPolymorphic = Matcher.isPolymorphic(); + std::vector<std::vector<ArgKind>> ArgsKinds(NumArgs); + unsigned MaxSpecificity = 0; + bool NodeArgs = false; + for (const ArgKind& Kind : AcceptedTypes) { + if (Kind.getArgKind() != Kind.AK_Matcher && + Kind.getArgKind() != Kind.AK_Node) { + continue; + } + + if (Kind.getArgKind() == Kind.AK_Node) { + NodeArgs = true; + unsigned Specificity; + ASTNodeKind LeastDerivedKind; + if (Matcher.isConvertibleTo(Kind.getNodeKind(), &Specificity, + &LeastDerivedKind)) { + if (MaxSpecificity < Specificity) + MaxSpecificity = Specificity; + RetKinds.insert(LeastDerivedKind); + for (unsigned Arg = 0; Arg != NumArgs; ++Arg) + Matcher.getArgKinds(Kind.getNodeKind(), Arg, ArgsKinds[Arg]); + if (IsPolymorphic) + break; + } + } else { + unsigned Specificity; + ASTNodeKind LeastDerivedKind; + if (Matcher.isConvertibleTo(Kind.getMatcherKind(), &Specificity, + &LeastDerivedKind)) { + if (MaxSpecificity < Specificity) + MaxSpecificity = Specificity; + RetKinds.insert(LeastDerivedKind); + for (unsigned Arg = 0; Arg != NumArgs; ++Arg) + Matcher.getArgKinds(Kind.getMatcherKind(), Arg, ArgsKinds[Arg]); + if (IsPolymorphic) + break; + } + } + } + + if (!RetKinds.empty() && MaxSpecificity > 0) { + std::string Decl; + llvm::raw_string_ostream OS(Decl); + + std::string TypedText = std::string(Name); + + if (NodeArgs) { + OS << Name; + } else { + + if (IsPolymorphic) { + OS << "Matcher<T> " << Name << "(Matcher<T>"; + } else { + OS << "Matcher<" << RetKinds << "> " << Name << "("; + for (const std::vector<ArgKind> &Arg : ArgsKinds) { + if (&Arg != &ArgsKinds[0]) + OS << ", "; + + bool FirstArgKind = true; + std::set<ASTNodeKind> MatcherKinds; + // Two steps. First all non-matchers, then matchers only. + for (const ArgKind &AK : Arg) { + if (AK.getArgKind() == ArgKind::AK_Matcher) { + MatcherKinds.insert(AK.getMatcherKind()); + } else { + if (!FirstArgKind) + OS << "|"; + FirstArgKind = false; + OS << AK.asString(); + } + } + if (!MatcherKinds.empty()) { + if (!FirstArgKind) OS << "|"; + OS << "Matcher<" << MatcherKinds << ">"; + } + } + } + if (Matcher.isVariadic()) + OS << "..."; + OS << ")"; + + TypedText += "("; + if (ArgsKinds.empty()) + TypedText += ")"; + else if (ArgsKinds[0][0].getArgKind() == ArgKind::AK_String) + TypedText += "\""; + } + + Completions.emplace_back(TypedText, OS.str(), MaxSpecificity); + } + } + + return Completions; +} + +VariantMatcher Registry::constructMatcher(MatcherCtor Ctor, + SourceRange NameRange, + ArrayRef<ParserValue> Args, + Diagnostics *Error) { + return Ctor->create(NameRange, Args, Error); +} + +VariantMatcher Registry::constructBoundMatcher(MatcherCtor Ctor, + SourceRange NameRange, + StringRef BindID, + ArrayRef<ParserValue> Args, + Diagnostics *Error) { + VariantMatcher Out = constructMatcher(Ctor, NameRange, Args, Error); + if (Out.isNull()) return Out; + + std::optional<DynTypedMatcher> Result = Out.getSingleMatcher(); + if (Result) { + std::optional<DynTypedMatcher> Bound = Result->tryBind(BindID); + if (Bound) { + return VariantMatcher::SingleMatcher(*Bound); + } + } + Error->addError(NameRange, Error->ET_RegistryNotBindable); + return VariantMatcher(); +} + +} // namespace dynamic +} // namespace ast_matchers +} // namespace clang diff --git a/contrib/llvm-project/clang/lib/ASTMatchers/Dynamic/VariantValue.cpp b/contrib/llvm-project/clang/lib/ASTMatchers/Dynamic/VariantValue.cpp new file mode 100644 index 000000000000..4f6b021b26f0 --- /dev/null +++ b/contrib/llvm-project/clang/lib/ASTMatchers/Dynamic/VariantValue.cpp @@ -0,0 +1,493 @@ +//===--- VariantValue.cpp - Polymorphic value type --------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// Polymorphic value type. +/// +//===----------------------------------------------------------------------===// + +#include "clang/ASTMatchers/Dynamic/VariantValue.h" +#include "clang/Basic/LLVM.h" +#include "llvm/ADT/STLExtras.h" +#include <optional> + +namespace clang { +namespace ast_matchers { +namespace dynamic { + +std::string ArgKind::asString() const { + switch (getArgKind()) { + case AK_Matcher: + return (Twine("Matcher<") + NodeKind.asStringRef() + ">").str(); + case AK_Node: + return NodeKind.asStringRef().str(); + case AK_Boolean: + return "boolean"; + case AK_Double: + return "double"; + case AK_Unsigned: + return "unsigned"; + case AK_String: + return "string"; + } + llvm_unreachable("unhandled ArgKind"); +} + +bool ArgKind::isConvertibleTo(ArgKind To, unsigned *Specificity) const { + if (K != To.K) + return false; + if (K != AK_Matcher && K != AK_Node) { + if (Specificity) + *Specificity = 1; + return true; + } + unsigned Distance; + if (!NodeKind.isBaseOf(To.NodeKind, &Distance)) + return false; + + if (Specificity) + *Specificity = 100 - Distance; + return true; +} + +bool +VariantMatcher::MatcherOps::canConstructFrom(const DynTypedMatcher &Matcher, + bool &IsExactMatch) const { + IsExactMatch = Matcher.getSupportedKind().isSame(NodeKind); + return Matcher.canConvertTo(NodeKind); +} + +DynTypedMatcher VariantMatcher::MatcherOps::convertMatcher( + const DynTypedMatcher &Matcher) const { + return Matcher.dynCastTo(NodeKind); +} + +std::optional<DynTypedMatcher> +VariantMatcher::MatcherOps::constructVariadicOperator( + DynTypedMatcher::VariadicOperator Op, + ArrayRef<VariantMatcher> InnerMatchers) const { + std::vector<DynTypedMatcher> DynMatchers; + for (const auto &InnerMatcher : InnerMatchers) { + // Abort if any of the inner matchers can't be converted to + // Matcher<T>. + if (!InnerMatcher.Value) + return std::nullopt; + std::optional<DynTypedMatcher> Inner = + InnerMatcher.Value->getTypedMatcher(*this); + if (!Inner) + return std::nullopt; + DynMatchers.push_back(*Inner); + } + return DynTypedMatcher::constructVariadic(Op, NodeKind, DynMatchers); +} + +VariantMatcher::Payload::~Payload() {} + +class VariantMatcher::SinglePayload : public VariantMatcher::Payload { +public: + SinglePayload(const DynTypedMatcher &Matcher) : Matcher(Matcher) {} + + std::optional<DynTypedMatcher> getSingleMatcher() const override { + return Matcher; + } + + std::string getTypeAsString() const override { + return (Twine("Matcher<") + Matcher.getSupportedKind().asStringRef() + ">") + .str(); + } + + std::optional<DynTypedMatcher> + getTypedMatcher(const MatcherOps &Ops) const override { + bool Ignore; + if (Ops.canConstructFrom(Matcher, Ignore)) + return Matcher; + return std::nullopt; + } + + bool isConvertibleTo(ASTNodeKind Kind, unsigned *Specificity) const override { + return ArgKind::MakeMatcherArg(Matcher.getSupportedKind()) + .isConvertibleTo(ArgKind::MakeMatcherArg(Kind), Specificity); + } + +private: + const DynTypedMatcher Matcher; +}; + +class VariantMatcher::PolymorphicPayload : public VariantMatcher::Payload { +public: + PolymorphicPayload(std::vector<DynTypedMatcher> MatchersIn) + : Matchers(std::move(MatchersIn)) {} + + ~PolymorphicPayload() override {} + + std::optional<DynTypedMatcher> getSingleMatcher() const override { + if (Matchers.size() != 1) + return std::nullopt; + return Matchers[0]; + } + + std::string getTypeAsString() const override { + std::string Inner; + for (size_t i = 0, e = Matchers.size(); i != e; ++i) { + if (i != 0) + Inner += "|"; + Inner += Matchers[i].getSupportedKind().asStringRef(); + } + return (Twine("Matcher<") + Inner + ">").str(); + } + + std::optional<DynTypedMatcher> + getTypedMatcher(const MatcherOps &Ops) const override { + bool FoundIsExact = false; + const DynTypedMatcher *Found = nullptr; + int NumFound = 0; + for (size_t i = 0, e = Matchers.size(); i != e; ++i) { + bool IsExactMatch; + if (Ops.canConstructFrom(Matchers[i], IsExactMatch)) { + if (Found) { + if (FoundIsExact) { + assert(!IsExactMatch && "We should not have two exact matches."); + continue; + } + } + Found = &Matchers[i]; + FoundIsExact = IsExactMatch; + ++NumFound; + } + } + // We only succeed if we found exactly one, or if we found an exact match. + if (Found && (FoundIsExact || NumFound == 1)) + return *Found; + return std::nullopt; + } + + bool isConvertibleTo(ASTNodeKind Kind, unsigned *Specificity) const override { + unsigned MaxSpecificity = 0; + for (const DynTypedMatcher &Matcher : Matchers) { + unsigned ThisSpecificity; + if (ArgKind::MakeMatcherArg(Matcher.getSupportedKind()) + .isConvertibleTo(ArgKind::MakeMatcherArg(Kind), + &ThisSpecificity)) { + MaxSpecificity = std::max(MaxSpecificity, ThisSpecificity); + } + } + if (Specificity) + *Specificity = MaxSpecificity; + return MaxSpecificity > 0; + } + + const std::vector<DynTypedMatcher> Matchers; +}; + +class VariantMatcher::VariadicOpPayload : public VariantMatcher::Payload { +public: + VariadicOpPayload(DynTypedMatcher::VariadicOperator Op, + std::vector<VariantMatcher> Args) + : Op(Op), Args(std::move(Args)) {} + + std::optional<DynTypedMatcher> getSingleMatcher() const override { + return std::nullopt; + } + + std::string getTypeAsString() const override { + std::string Inner; + for (size_t i = 0, e = Args.size(); i != e; ++i) { + if (i != 0) + Inner += "&"; + Inner += Args[i].getTypeAsString(); + } + return Inner; + } + + std::optional<DynTypedMatcher> + getTypedMatcher(const MatcherOps &Ops) const override { + return Ops.constructVariadicOperator(Op, Args); + } + + bool isConvertibleTo(ASTNodeKind Kind, unsigned *Specificity) const override { + for (const VariantMatcher &Matcher : Args) { + if (!Matcher.isConvertibleTo(Kind, Specificity)) + return false; + } + return true; + } + +private: + const DynTypedMatcher::VariadicOperator Op; + const std::vector<VariantMatcher> Args; +}; + +VariantMatcher::VariantMatcher() {} + +VariantMatcher VariantMatcher::SingleMatcher(const DynTypedMatcher &Matcher) { + return VariantMatcher(std::make_shared<SinglePayload>(Matcher)); +} + +VariantMatcher +VariantMatcher::PolymorphicMatcher(std::vector<DynTypedMatcher> Matchers) { + return VariantMatcher( + std::make_shared<PolymorphicPayload>(std::move(Matchers))); +} + +VariantMatcher VariantMatcher::VariadicOperatorMatcher( + DynTypedMatcher::VariadicOperator Op, + std::vector<VariantMatcher> Args) { + return VariantMatcher( + std::make_shared<VariadicOpPayload>(Op, std::move(Args))); +} + +std::optional<DynTypedMatcher> VariantMatcher::getSingleMatcher() const { + return Value ? Value->getSingleMatcher() : std::optional<DynTypedMatcher>(); +} + +void VariantMatcher::reset() { Value.reset(); } + +std::string VariantMatcher::getTypeAsString() const { + if (Value) return Value->getTypeAsString(); + return "<Nothing>"; +} + +VariantValue::VariantValue(const VariantValue &Other) : Type(VT_Nothing) { + *this = Other; +} + +VariantValue::VariantValue(bool Boolean) : Type(VT_Nothing) { + setBoolean(Boolean); +} + +VariantValue::VariantValue(double Double) : Type(VT_Nothing) { + setDouble(Double); +} + +VariantValue::VariantValue(unsigned Unsigned) : Type(VT_Nothing) { + setUnsigned(Unsigned); +} + +VariantValue::VariantValue(StringRef String) : Type(VT_Nothing) { + setString(String); +} + +VariantValue::VariantValue(ASTNodeKind NodeKind) : Type(VT_Nothing) { + setNodeKind(NodeKind); +} + +VariantValue::VariantValue(const VariantMatcher &Matcher) : Type(VT_Nothing) { + setMatcher(Matcher); +} + +VariantValue::~VariantValue() { reset(); } + +VariantValue &VariantValue::operator=(const VariantValue &Other) { + if (this == &Other) return *this; + reset(); + switch (Other.Type) { + case VT_Boolean: + setBoolean(Other.getBoolean()); + break; + case VT_Double: + setDouble(Other.getDouble()); + break; + case VT_Unsigned: + setUnsigned(Other.getUnsigned()); + break; + case VT_String: + setString(Other.getString()); + break; + case VT_NodeKind: + setNodeKind(Other.getNodeKind()); + break; + case VT_Matcher: + setMatcher(Other.getMatcher()); + break; + case VT_Nothing: + Type = VT_Nothing; + break; + } + return *this; +} + +void VariantValue::reset() { + switch (Type) { + case VT_String: + delete Value.String; + break; + case VT_Matcher: + delete Value.Matcher; + break; + case VT_NodeKind: + delete Value.NodeKind; + break; + // Cases that do nothing. + case VT_Boolean: + case VT_Double: + case VT_Unsigned: + case VT_Nothing: + break; + } + Type = VT_Nothing; +} + +bool VariantValue::isBoolean() const { + return Type == VT_Boolean; +} + +bool VariantValue::getBoolean() const { + assert(isBoolean()); + return Value.Boolean; +} + +void VariantValue::setBoolean(bool NewValue) { + reset(); + Type = VT_Boolean; + Value.Boolean = NewValue; +} + +bool VariantValue::isDouble() const { + return Type == VT_Double; +} + +double VariantValue::getDouble() const { + assert(isDouble()); + return Value.Double; +} + +void VariantValue::setDouble(double NewValue) { + reset(); + Type = VT_Double; + Value.Double = NewValue; +} + +bool VariantValue::isUnsigned() const { + return Type == VT_Unsigned; +} + +unsigned VariantValue::getUnsigned() const { + assert(isUnsigned()); + return Value.Unsigned; +} + +void VariantValue::setUnsigned(unsigned NewValue) { + reset(); + Type = VT_Unsigned; + Value.Unsigned = NewValue; +} + +bool VariantValue::isString() const { + return Type == VT_String; +} + +const std::string &VariantValue::getString() const { + assert(isString()); + return *Value.String; +} + +void VariantValue::setString(StringRef NewValue) { + reset(); + Type = VT_String; + Value.String = new std::string(NewValue); +} + +bool VariantValue::isNodeKind() const { return Type == VT_NodeKind; } + +const ASTNodeKind &VariantValue::getNodeKind() const { + assert(isNodeKind()); + return *Value.NodeKind; +} + +void VariantValue::setNodeKind(ASTNodeKind NewValue) { + reset(); + Type = VT_NodeKind; + Value.NodeKind = new ASTNodeKind(NewValue); +} + +bool VariantValue::isMatcher() const { + return Type == VT_Matcher; +} + +const VariantMatcher &VariantValue::getMatcher() const { + assert(isMatcher()); + return *Value.Matcher; +} + +void VariantValue::setMatcher(const VariantMatcher &NewValue) { + reset(); + Type = VT_Matcher; + Value.Matcher = new VariantMatcher(NewValue); +} + +bool VariantValue::isConvertibleTo(ArgKind Kind, unsigned *Specificity) const { + switch (Kind.getArgKind()) { + case ArgKind::AK_Boolean: + if (!isBoolean()) + return false; + *Specificity = 1; + return true; + + case ArgKind::AK_Double: + if (!isDouble()) + return false; + *Specificity = 1; + return true; + + case ArgKind::AK_Unsigned: + if (!isUnsigned()) + return false; + *Specificity = 1; + return true; + + case ArgKind::AK_String: + if (!isString()) + return false; + *Specificity = 1; + return true; + + case ArgKind::AK_Node: + if (!isNodeKind()) + return false; + return getMatcher().isConvertibleTo(Kind.getNodeKind(), Specificity); + + case ArgKind::AK_Matcher: + if (!isMatcher()) + return false; + return getMatcher().isConvertibleTo(Kind.getMatcherKind(), Specificity); + } + llvm_unreachable("Invalid Type"); +} + +bool VariantValue::isConvertibleTo(ArrayRef<ArgKind> Kinds, + unsigned *Specificity) const { + unsigned MaxSpecificity = 0; + for (const ArgKind& Kind : Kinds) { + unsigned ThisSpecificity; + if (!isConvertibleTo(Kind, &ThisSpecificity)) + continue; + MaxSpecificity = std::max(MaxSpecificity, ThisSpecificity); + } + if (Specificity && MaxSpecificity > 0) { + *Specificity = MaxSpecificity; + } + return MaxSpecificity > 0; +} + +std::string VariantValue::getTypeAsString() const { + switch (Type) { + case VT_String: return "String"; + case VT_Matcher: return getMatcher().getTypeAsString(); + case VT_Boolean: return "Boolean"; + case VT_Double: return "Double"; + case VT_Unsigned: return "Unsigned"; + case VT_NodeKind: + return getNodeKind().asStringRef().str(); + case VT_Nothing: return "Nothing"; + } + llvm_unreachable("Invalid Type"); +} + +} // end namespace dynamic +} // end namespace ast_matchers +} // end namespace clang diff --git a/contrib/llvm-project/clang/lib/ASTMatchers/GtestMatchers.cpp b/contrib/llvm-project/clang/lib/ASTMatchers/GtestMatchers.cpp new file mode 100644 index 000000000000..a556d8ef2da0 --- /dev/null +++ b/contrib/llvm-project/clang/lib/ASTMatchers/GtestMatchers.cpp @@ -0,0 +1,233 @@ +//===- GtestMatchers.cpp - AST Matchers for Gtest ---------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements several matchers for popular gtest macros. In general, +// AST matchers cannot match calls to macros. However, we can simulate such +// matches if the macro definition has identifiable elements that themselves can +// be matched. In that case, we can match on those elements and then check that +// the match occurs within an expansion of the desired macro. The more uncommon +// the identified elements, the more efficient this process will be. +// +//===----------------------------------------------------------------------===// + +#include "clang/ASTMatchers/GtestMatchers.h" +#include "clang/AST/ASTConsumer.h" +#include "clang/AST/ASTContext.h" +#include "clang/AST/RecursiveASTVisitor.h" +#include "clang/ASTMatchers/ASTMatchFinder.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/StringRef.h" + +namespace clang { +namespace ast_matchers { +namespace { + +enum class MacroType { + Expect, + Assert, + On, +}; + +} // namespace + +static DeclarationMatcher getComparisonDecl(GtestCmp Cmp) { + switch (Cmp) { + case GtestCmp::Eq: + return cxxMethodDecl(hasName("Compare"), + ofClass(cxxRecordDecl(isSameOrDerivedFrom( + hasName("::testing::internal::EqHelper"))))); + case GtestCmp::Ne: + return functionDecl(hasName("::testing::internal::CmpHelperNE")); + case GtestCmp::Ge: + return functionDecl(hasName("::testing::internal::CmpHelperGE")); + case GtestCmp::Gt: + return functionDecl(hasName("::testing::internal::CmpHelperGT")); + case GtestCmp::Le: + return functionDecl(hasName("::testing::internal::CmpHelperLE")); + case GtestCmp::Lt: + return functionDecl(hasName("::testing::internal::CmpHelperLT")); + } + llvm_unreachable("Unhandled GtestCmp enum"); +} + +static llvm::StringRef getMacroTypeName(MacroType Macro) { + switch (Macro) { + case MacroType::Expect: + return "EXPECT"; + case MacroType::Assert: + return "ASSERT"; + case MacroType::On: + return "ON"; + } + llvm_unreachable("Unhandled MacroType enum"); +} + +static llvm::StringRef getComparisonTypeName(GtestCmp Cmp) { + switch (Cmp) { + case GtestCmp::Eq: + return "EQ"; + case GtestCmp::Ne: + return "NE"; + case GtestCmp::Ge: + return "GE"; + case GtestCmp::Gt: + return "GT"; + case GtestCmp::Le: + return "LE"; + case GtestCmp::Lt: + return "LT"; + } + llvm_unreachable("Unhandled GtestCmp enum"); +} + +static std::string getMacroName(MacroType Macro, GtestCmp Cmp) { + return (getMacroTypeName(Macro) + "_" + getComparisonTypeName(Cmp)).str(); +} + +static std::string getMacroName(MacroType Macro, llvm::StringRef Operation) { + return (getMacroTypeName(Macro) + "_" + Operation).str(); +} + +// Under the hood, ON_CALL is expanded to a call to `InternalDefaultActionSetAt` +// to set a default action spec to the underlying function mocker, while +// EXPECT_CALL is expanded to a call to `InternalExpectedAt` to set a new +// expectation spec. +static llvm::StringRef getSpecSetterName(MacroType Macro) { + switch (Macro) { + case MacroType::On: + return "InternalDefaultActionSetAt"; + case MacroType::Expect: + return "InternalExpectedAt"; + default: + llvm_unreachable("Unhandled MacroType enum"); + } + llvm_unreachable("Unhandled MacroType enum"); +} + +// In general, AST matchers cannot match calls to macros. However, we can +// simulate such matches if the macro definition has identifiable elements that +// themselves can be matched. In that case, we can match on those elements and +// then check that the match occurs within an expansion of the desired +// macro. The more uncommon the identified elements, the more efficient this +// process will be. +// +// We use this approach to implement the derived matchers gtestAssert and +// gtestExpect. +static internal::BindableMatcher<Stmt> +gtestComparisonInternal(MacroType Macro, GtestCmp Cmp, StatementMatcher Left, + StatementMatcher Right) { + return callExpr(isExpandedFromMacro(getMacroName(Macro, Cmp)), + callee(getComparisonDecl(Cmp)), hasArgument(2, Left), + hasArgument(3, Right)); +} + +static internal::BindableMatcher<Stmt> +gtestThatInternal(MacroType Macro, StatementMatcher Actual, + StatementMatcher Matcher) { + return cxxOperatorCallExpr( + isExpandedFromMacro(getMacroName(Macro, "THAT")), + hasOverloadedOperatorName("()"), hasArgument(2, Actual), + hasArgument( + 0, expr(hasType(classTemplateSpecializationDecl(hasName( + "::testing::internal::PredicateFormatterFromMatcher"))), + ignoringImplicit( + callExpr(callee(functionDecl(hasName( + "::testing::internal::" + "MakePredicateFormatterFromMatcher"))), + hasArgument(0, ignoringImplicit(Matcher))))))); +} + +static internal::BindableMatcher<Stmt> +gtestCallInternal(MacroType Macro, StatementMatcher MockCall, MockArgs Args) { + // A ON_CALL or EXPECT_CALL macro expands to different AST structures + // depending on whether the mock method has arguments or not. + switch (Args) { + // For example, + // `ON_CALL(mock, TwoParamMethod)` is expanded to + // `mock.gmock_TwoArgsMethod(WithoutMatchers(), + // nullptr).InternalDefaultActionSetAt(...)`. + // EXPECT_CALL is the same except + // that it calls `InternalExpectedAt` instead of `InternalDefaultActionSetAt` + // in the end. + case MockArgs::None: + return cxxMemberCallExpr( + isExpandedFromMacro(getMacroName(Macro, "CALL")), + callee(functionDecl(hasName(getSpecSetterName(Macro)))), + onImplicitObjectArgument(ignoringImplicit(MockCall))); + // For example, + // `ON_CALL(mock, TwoParamMethod(m1, m2))` is expanded to + // `mock.gmock_TwoParamMethod(m1,m2)(WithoutMatchers(), + // nullptr).InternalDefaultActionSetAt(...)`. + // EXPECT_CALL is the same except that it calls `InternalExpectedAt` instead + // of `InternalDefaultActionSetAt` in the end. + case MockArgs::Some: + return cxxMemberCallExpr( + isExpandedFromMacro(getMacroName(Macro, "CALL")), + callee(functionDecl(hasName(getSpecSetterName(Macro)))), + onImplicitObjectArgument(ignoringImplicit(cxxOperatorCallExpr( + hasOverloadedOperatorName("()"), argumentCountIs(3), + hasArgument(0, ignoringImplicit(MockCall)))))); + } + llvm_unreachable("Unhandled MockArgs enum"); +} + +static internal::BindableMatcher<Stmt> +gtestCallInternal(MacroType Macro, StatementMatcher MockObject, + llvm::StringRef MockMethodName, MockArgs Args) { + return gtestCallInternal( + Macro, + cxxMemberCallExpr( + onImplicitObjectArgument(MockObject), + callee(functionDecl(hasName(("gmock_" + MockMethodName).str())))), + Args); +} + +internal::BindableMatcher<Stmt> gtestAssert(GtestCmp Cmp, StatementMatcher Left, + StatementMatcher Right) { + return gtestComparisonInternal(MacroType::Assert, Cmp, Left, Right); +} + +internal::BindableMatcher<Stmt> gtestExpect(GtestCmp Cmp, StatementMatcher Left, + StatementMatcher Right) { + return gtestComparisonInternal(MacroType::Expect, Cmp, Left, Right); +} + +internal::BindableMatcher<Stmt> gtestAssertThat(StatementMatcher Actual, + StatementMatcher Matcher) { + return gtestThatInternal(MacroType::Assert, Actual, Matcher); +} + +internal::BindableMatcher<Stmt> gtestExpectThat(StatementMatcher Actual, + StatementMatcher Matcher) { + return gtestThatInternal(MacroType::Expect, Actual, Matcher); +} + +internal::BindableMatcher<Stmt> gtestOnCall(StatementMatcher MockObject, + llvm::StringRef MockMethodName, + MockArgs Args) { + return gtestCallInternal(MacroType::On, MockObject, MockMethodName, Args); +} + +internal::BindableMatcher<Stmt> gtestOnCall(StatementMatcher MockCall, + MockArgs Args) { + return gtestCallInternal(MacroType::On, MockCall, Args); +} + +internal::BindableMatcher<Stmt> gtestExpectCall(StatementMatcher MockObject, + llvm::StringRef MockMethodName, + MockArgs Args) { + return gtestCallInternal(MacroType::Expect, MockObject, MockMethodName, Args); +} + +internal::BindableMatcher<Stmt> gtestExpectCall(StatementMatcher MockCall, + MockArgs Args) { + return gtestCallInternal(MacroType::Expect, MockCall, Args); +} + +} // end namespace ast_matchers +} // end namespace clang |