Implement has(), hasDescendant(), forEach() and forEachDescendant() for
Types, QualTypes and TypeLocs.

Review: http://llvm-reviews.chandlerc.com/D83

git-svn-id: https://llvm.org/svn/llvm-project/cfe/trunk@166917 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/ASTMatchers/ASTMatchFinder.cpp b/lib/ASTMatchers/ASTMatchFinder.cpp
index 218b781..38df2a1 100644
--- a/lib/ASTMatchers/ASTMatchFinder.cpp
+++ b/lib/ASTMatchers/ASTMatchFinder.cpp
@@ -124,7 +124,7 @@
       : Matcher(Matcher),
         Finder(Finder),
         Builder(Builder),
-        CurrentDepth(-1),
+        CurrentDepth(0),
         MaxDepth(MaxDepth),
         Traversal(Traversal),
         Bind(Bind),
@@ -147,6 +147,10 @@
       traverse(*D);
     else if (const Stmt *S = DynNode.get<Stmt>())
       traverse(*S);
+    else if (const QualType *Q = DynNode.get<QualType>())
+      traverse(*Q);
+    else if (const TypeLoc *T = DynNode.get<TypeLoc>())
+      traverse(*T);
     // FIXME: Add other base types after adding tests.
     return Matches;
   }
@@ -155,9 +159,11 @@
   // They are public only to allow CRTP to work. They are *not *part
   // of the public API of this class.
   bool TraverseDecl(Decl *DeclNode) {
+    ScopedIncrement ScopedDepth(&CurrentDepth);
     return (DeclNode == NULL) || traverse(*DeclNode);
   }
   bool TraverseStmt(Stmt *StmtNode) {
+    ScopedIncrement ScopedDepth(&CurrentDepth);
     const Stmt *StmtToTraverse = StmtNode;
     if (Traversal ==
         ASTMatchFinder::TK_IgnoreImplicitCastsAndParentheses) {
@@ -168,9 +174,37 @@
     }
     return (StmtToTraverse == NULL) || traverse(*StmtToTraverse);
   }
+  // We assume that the QualType and the contained type are on the same
+  // hierarchy level. Thus, we try to match either of them.
   bool TraverseType(QualType TypeNode) {
+    // A null QualType has no Type to dereference below; skip it and keep
+    // traversing (returning true means "continue").
+    if (TypeNode.isNull())
+      return true;
+    ScopedIncrement ScopedDepth(&CurrentDepth);
+    // Match the Type.
+    if (!match(*TypeNode))
+      return false;
+    // The QualType is matched inside traverse.
     return traverse(TypeNode);
   }
+  // We assume that the TypeLoc, contained QualType and contained Type all are
+  // on the same hierarchy level. Thus, we try to match all of them.
+  bool TraverseTypeLoc(TypeLoc TypeLocNode) {
+    // A null TypeLoc has no QualType/Type to match; skip it and keep
+    // traversing (returning true means "continue").
+    if (TypeLocNode.isNull())
+      return true;
+    ScopedIncrement ScopedDepth(&CurrentDepth);
+    // Match the Type.
+    if (!match(*TypeLocNode.getType()))
+      return false;
+    // Match the QualType.
+    if (!match(TypeLocNode.getType()))
+      return false;
+    // The TypeLoc is matched inside traverse.
+    return traverse(TypeLocNode);
+  }
 
   bool shouldVisitTemplateInstantiations() const { return true; }
   bool shouldVisitImplicitCode() const { return true; }
@@ -188,7 +214,7 @@
   // Resets the state of this object.
   void reset() {
     Matches = false;
-    CurrentDepth = -1;
+    CurrentDepth = 0;
   }
 
   // Forwards the call to the corresponding Traverse*() method in the
@@ -202,18 +228,19 @@
   bool baseTraverse(QualType TypeNode) {
     return VisitorBase::TraverseType(TypeNode);
   }
+  bool baseTraverse(TypeLoc TypeLocNode) {
+    return VisitorBase::TraverseTypeLoc(TypeLocNode);
+  }
 
-  // Traverses the subtree rooted at 'node'; returns true if the
-  // traversal should continue after this function returns; also sets
-  // matched_ to true if a match is found during the traversal.
+  // Sets 'Matches' to true if 'Matcher' matches 'Node' and:
+  //   0 < CurrentDepth <= MaxDepth.
+  //
+  // Returns 'true' if traversal should continue after this function
+  // returns, i.e. if no match is found or 'Bind' is 'BK_All'.
   template <typename T>
-  bool traverse(const T &Node) {
-    TOOLING_COMPILE_ASSERT(IsBaseType<T>::value,
-                           traverse_can_only_be_instantiated_with_base_type);
-    ScopedIncrement ScopedDepth(&CurrentDepth);
-    if (CurrentDepth == 0) {
-      // We don't want to match the root node, so just recurse.
-      return baseTraverse(Node);
+  bool match(const T &Node) {
+    if (CurrentDepth == 0 || CurrentDepth > MaxDepth) {
+      return true;
     }
     if (Bind != ASTMatchFinder::BK_All) {
       if (Matcher->matches(ast_type_traits::DynTypedNode::create(Node),
@@ -221,15 +248,6 @@
         Matches = true;
         return false;  // Abort as soon as a match is found.
       }
-      if (CurrentDepth < MaxDepth) {
-        // The current node doesn't match, and we haven't reached the
-        // maximum depth yet, so recurse.
-        return baseTraverse(Node);
-      }
-      // The current node doesn't match, and we have reached the
-      // maximum depth, so don't recurse (but continue the traversal
-      // such that other nodes at the current level can be visited).
-      return true;
     } else {
       BoundNodesTreeBuilder RecursiveBuilder;
       if (Matcher->matches(ast_type_traits::DynTypedNode::create(Node),
@@ -238,12 +256,19 @@
         Matches = true;
         Builder->addMatch(RecursiveBuilder.build());
       }
-      if (CurrentDepth < MaxDepth) {
-        baseTraverse(Node);
-      }
-      // In kBindAll mode we always search for more matches.
-      return true;
     }
+    return true;
+  }
+
+  // Traverses the subtree rooted at 'Node'; returns true if the
+  // traversal should continue after this function returns.
+  template <typename T>
+  bool traverse(const T &Node) {
+    TOOLING_COMPILE_ASSERT(IsBaseType<T>::value,
+                           traverse_can_only_be_instantiated_with_base_type);
+    if (!match(Node))
+      return false;
+    return baseTraverse(Node);
   }
 
   const DynTypedMatcher *const Matcher;
@@ -322,8 +347,12 @@
                                   BoundNodesTreeBuilder *Builder, int MaxDepth,
                                   TraversalKind Traversal, BindKind Bind) {
     const UntypedMatchInput input(Matcher.getID(), Node.getMemoizationData());
-    assert(input.second &&
-           "Fix getMemoizationData once more types allow recursive matching.");
+
+    // For AST-nodes that don't have an identity, we can't memoize.
+    if (!input.second)
+      return matchesRecursively(Node, Matcher, Builder, MaxDepth, Traversal,
+                                Bind);
+
     std::pair<MemoizationMap::iterator, bool> InsertResult
       = ResultCache.insert(std::make_pair(input, MemoizedMatchResult()));
     if (InsertResult.second) {