Only visit default arguments for template declarations when visiting the template declaration which introduced them. Patch by Yang Chen!


git-svn-id: https://llvm.org/svn/llvm-project/cfe/trunk@157723 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/include/clang/AST/RecursiveASTVisitor.h b/include/clang/AST/RecursiveASTVisitor.h
index 011dda5..72f24a0 100644
--- a/include/clang/AST/RecursiveASTVisitor.h
+++ b/include/clang/AST/RecursiveASTVisitor.h
@@ -1486,7 +1486,7 @@
     // D is the "T" in something like
     //   template <template <typename> class T> class container { };
     TRY_TO(TraverseDecl(D->getTemplatedDecl()));
-    if (D->hasDefaultArgument()) {
+    if (D->hasDefaultArgument() && !D->defaultArgumentWasInherited()) {
       TRY_TO(TraverseTemplateArgumentLoc(D->getDefaultArgument()));
     }
     TRY_TO(TraverseTemplateParameterListHelper(D->getTemplateParameters()));
@@ -1496,7 +1496,7 @@
     // D is the "T" in something like "template<typename T> class vector;"
     if (D->getTypeForDecl())
       TRY_TO(TraverseType(QualType(D->getTypeForDecl(), 0)));
-    if (D->hasDefaultArgument())
+    if (D->hasDefaultArgument() && !D->defaultArgumentWasInherited())
       TRY_TO(TraverseTypeLoc(D->getDefaultArgumentInfo()->getTypeLoc()));
   })
 
@@ -1766,7 +1766,8 @@
 DEF_TRAVERSE_DECL(NonTypeTemplateParmDecl, {
     // A non-type template parameter, e.g. "S" in template<int S> class Foo ...
     TRY_TO(TraverseDeclaratorHelper(D));
-    TRY_TO(TraverseStmt(D->getDefaultArgument()));
+    if (D->hasDefaultArgument() && !D->defaultArgumentWasInherited())
+      TRY_TO(TraverseStmt(D->getDefaultArgument()));
   })
 
 DEF_TRAVERSE_DECL(ParmVarDecl, {
diff --git a/unittests/Tooling/RecursiveASTVisitorTest.cpp b/unittests/Tooling/RecursiveASTVisitorTest.cpp
index 5a77673..9ad825f 100644
--- a/unittests/Tooling/RecursiveASTVisitorTest.cpp
+++ b/unittests/Tooling/RecursiveASTVisitorTest.cpp
@@ -185,6 +185,33 @@
   }
 };
 
+class TemplateArgumentLocTraverser
+  : public ExpectedLocationVisitor<TemplateArgumentLocTraverser> {
+public:
+  bool TraverseTemplateArgumentLoc(const TemplateArgumentLoc &ArgLoc) {
+    std::string ArgStr;
+    llvm::raw_string_ostream Stream(ArgStr);
+    const TemplateArgument &Arg = ArgLoc.getArgument();
+
+    Arg.print(Context->getPrintingPolicy(), Stream);
+    Match(Stream.str(), ArgLoc.getLocation());
+    return ExpectedLocationVisitor<TemplateArgumentLocTraverser>::
+      TraverseTemplateArgumentLoc(ArgLoc);
+  }
+};
+
+class CXXBoolLiteralExprVisitor 
+  : public ExpectedLocationVisitor<CXXBoolLiteralExprVisitor> {
+public:
+  bool VisitCXXBoolLiteralExpr(CXXBoolLiteralExpr *BE) {
+    if (BE->getValue())
+      Match("true", BE->getLocation());
+    else
+      Match("false", BE->getLocation());
+    return true;
+  }
+};
+
 TEST(RecursiveASTVisitor, VisitsBaseClassDeclarations) {
   TypeLocVisitor Visitor;
   Visitor.ExpectMatch("class X", 1, 30);
@@ -394,4 +421,31 @@
   EXPECT_TRUE(Visitor.runOver("int k = (4) + 9;\n"));
 }
 
+TEST(RecursiveASTVisitor, VisitsClassTemplateNonTypeParmDefaultArgument) {
+  CXXBoolLiteralExprVisitor Visitor;
+  Visitor.ExpectMatch("true", 2, 19);
+  EXPECT_TRUE(Visitor.runOver(
+    "template<bool B> class X;\n"
+    "template<bool B = true> class Y;\n"
+    "template<bool B> class Y {};\n"));
+}
+
+TEST(RecursiveASTVisitor, VisitsClassTemplateTypeParmDefaultArgument) {
+  TypeLocVisitor Visitor;
+  Visitor.ExpectMatch("class X", 2, 23);
+  EXPECT_TRUE(Visitor.runOver(
+    "class X;\n"
+    "template<typename T = X> class Y;\n"
+    "template<typename T> class Y {};\n"));
+}
+
+TEST(RecursiveASTVisitor, VisitsClassTemplateTemplateParmDefaultArgument) {
+  TemplateArgumentLocTraverser Visitor;
+  Visitor.ExpectMatch("X", 2, 40);
+  EXPECT_TRUE(Visitor.runOver(
+    "template<typename T> class X;\n"
+    "template<template <typename> class T = X> class Y;\n"
+    "template<template <typename> class T> class Y {};\n"));
+}
+
 } // end namespace clang