Add method to ignore invisible AST nodes
Reviewers: aaron.ballman
Subscribers: mgorny, cfe-commits
Tags: #clang
Differential Revision: https://reviews.llvm.org/D70613
diff --git a/clang/include/clang/AST/ASTNodeTraverser.h b/clang/include/clang/AST/ASTNodeTraverser.h
index b2e6d9e..f097d02 100644
--- a/clang/include/clang/AST/ASTNodeTraverser.h
+++ b/clang/include/clang/AST/ASTNodeTraverser.h
@@ -113,6 +113,9 @@
         case ast_type_traits::TK_IgnoreImplicitCastsAndParentheses:
           S = E->IgnoreParenImpCasts();
           break;
+        case ast_type_traits::TK_IgnoreUnlessSpelledInSource:
+          S = E->IgnoreUnlessSpelledInSource();
+          break;
         }
       }
 
diff --git a/clang/include/clang/AST/ASTTypeTraits.h b/clang/include/clang/AST/ASTTypeTraits.h
index 3e2f416..1a12281 100644
--- a/clang/include/clang/AST/ASTTypeTraits.h
+++ b/clang/include/clang/AST/ASTTypeTraits.h
@@ -43,7 +43,10 @@
 
   /// Will not traverse implicit casts and parentheses.
   /// Corresponds to Expr::IgnoreParenImpCasts()
-  TK_IgnoreImplicitCastsAndParentheses
+  TK_IgnoreImplicitCastsAndParentheses,
+
+  /// Ignore AST nodes not written in the source
+  TK_IgnoreUnlessSpelledInSource
 };
 
 /// Kind identifier.
diff --git a/clang/include/clang/AST/Expr.h b/clang/include/clang/AST/Expr.h
index ef0d1e5..fa95578 100644
--- a/clang/include/clang/AST/Expr.h
+++ b/clang/include/clang/AST/Expr.h
@@ -762,6 +762,15 @@
   /// member expression.
   static QualType findBoundMemberType(const Expr *expr);
 
+  /// Skip past any invisble AST nodes which might surround this
+  /// statement, such as ExprWithCleanups or ImplicitCastExpr nodes,
+  /// but also injected CXXMemberExpr and CXXConstructExpr which represent
+  /// implicit conversions.
+  Expr *IgnoreUnlessSpelledInSource();
+  const Expr *IgnoreUnlessSpelledInSource() const {
+    return const_cast<Expr *>(this)->IgnoreUnlessSpelledInSource();
+  }
+
   /// Skip past any implicit casts which might surround this expression until
   /// reaching a fixed point. Skips:
   /// * ImplicitCastExpr
diff --git a/clang/lib/AST/ASTContext.cpp b/clang/lib/AST/ASTContext.cpp
index da279fa..a09b1aca 100644
--- a/clang/lib/AST/ASTContext.cpp
+++ b/clang/lib/AST/ASTContext.cpp
@@ -112,6 +112,8 @@
     return E;
   case ast_type_traits::TK_IgnoreImplicitCastsAndParentheses:
     return E->IgnoreParenImpCasts();
+  case ast_type_traits::TK_IgnoreUnlessSpelledInSource:
+    return E->IgnoreUnlessSpelledInSource();
   }
   llvm_unreachable("Invalid Traversal type!");
 }
diff --git a/clang/lib/AST/Expr.cpp b/clang/lib/AST/Expr.cpp
index c4c8fcf..5c9ceac 100644
--- a/clang/lib/AST/Expr.cpp
+++ b/clang/lib/AST/Expr.cpp
@@ -3020,6 +3020,34 @@
   });
 }
 
+Expr *Expr::IgnoreUnlessSpelledInSource() {
+  Expr *E = this;
+
+  Expr *LastE = nullptr;
+  while (E != LastE) {
+    LastE = E;
+    E = E->IgnoreImplicit();
+
+    auto SR = E->getSourceRange();
+
+    if (auto *C = dyn_cast<CXXConstructExpr>(E)) {
+      if (C->getNumArgs() == 1) {
+        Expr *A = C->getArg(0);
+        if (A->getSourceRange() == SR || !isa<CXXTemporaryObjectExpr>(C))
+          E = A;
+      }
+    }
+
+    if (auto *C = dyn_cast<CXXMemberCallExpr>(E)) {
+      Expr *ExprNode = C->getImplicitObjectArgument()->IgnoreParenImpCasts();
+      if (ExprNode->getSourceRange() == SR)
+        E = ExprNode;
+    }
+  }
+
+  return E;
+}
+
 bool Expr::isDefaultArgument() const {
   const Expr *E = this;
   if (const MaterializeTemporaryExpr *M = dyn_cast<MaterializeTemporaryExpr>(E))
diff --git a/clang/unittests/AST/ASTTraverserTest.cpp b/clang/unittests/AST/ASTTraverserTest.cpp
index 6debdf1..c995f55 100644
--- a/clang/unittests/AST/ASTTraverserTest.cpp
+++ b/clang/unittests/AST/ASTTraverserTest.cpp
@@ -95,6 +95,21 @@
   return OS.str();
 }
 
+template <typename... NodeType>
+std::string dumpASTString(ast_type_traits::TraversalKind TK, NodeType &&... N) {
+  std::string Buffer;
+  llvm::raw_string_ostream OS(Buffer);
+
+  TestASTDumper Dumper(OS);
+  Dumper.SetTraversalKind(TK);
+
+  OS << "\n";
+
+  Dumper.Visit(std::forward<NodeType &&>(N)...);
+
+  return OS.str();
+}
+
 const FunctionDecl *getFunctionNode(clang::ASTUnit *AST,
                                     const std::string &Name) {
   auto Result = ast_matchers::match(functionDecl(hasName(Name)).bind("fn"),
@@ -244,4 +259,224 @@
   EXPECT_EQ(toTargetAddressSpace(static_cast<LangAS>(AS->getAddressSpace())),
             19u);
 }
+
+TEST(Traverse, IgnoreUnlessSpelledInSource) {
+
+  auto AST = buildASTFromCode(R"cpp(
+
+struct A
+{
+};
+
+struct B
+{
+  B(int);
+  B(A const& a);
+  B();
+};
+
+struct C
+{
+  operator B();
+};
+
+B func1() {
+  return 42;
+}
+
+B func2() {
+  return B{42};
+}
+
+B func3() {
+  return B(42);
+}
+
+B func4() {
+  return B();
+}
+
+B func5() {
+  return B{};
+}
+
+B func6() {
+  return C();
+}
+
+B func7() {
+  return A();
+}
+
+B func8() {
+  return C{};
+}
+
+B func9() {
+  return A{};
+}
+
+B func10() {
+  A a;
+  return a;
+}
+
+B func11() {
+  B b;
+  return b;
+}
+
+B func12() {
+  C c;
+  return c;
+}
+
+)cpp");
+
+  auto getFunctionNode = [&AST](const std::string &name) {
+    auto BN = ast_matchers::match(functionDecl(hasName(name)).bind("fn"),
+                                  AST->getASTContext());
+    EXPECT_EQ(BN.size(), 1u);
+    return BN[0].getNodeAs<Decl>("fn");
+  };
+
+  {
+    auto FN = getFunctionNode("func1");
+
+    EXPECT_EQ(dumpASTString(ast_type_traits::TK_AsIs, FN),
+              R"cpp(
+FunctionDecl 'func1'
+`-CompoundStmt
+  `-ReturnStmt
+    `-ExprWithCleanups
+      `-CXXConstructExpr
+        `-MaterializeTemporaryExpr
+          `-ImplicitCastExpr
+            `-CXXConstructExpr
+              `-IntegerLiteral
+)cpp");
+
+    EXPECT_EQ(
+        dumpASTString(ast_type_traits::TK_IgnoreUnlessSpelledInSource, FN),
+        R"cpp(
+FunctionDecl 'func1'
+`-CompoundStmt
+  `-ReturnStmt
+    `-IntegerLiteral
+)cpp");
+  }
+
+  EXPECT_EQ(dumpASTString(ast_type_traits::TK_IgnoreUnlessSpelledInSource,
+                          getFunctionNode("func2")),
+            R"cpp(
+FunctionDecl 'func2'
+`-CompoundStmt
+  `-ReturnStmt
+    `-CXXTemporaryObjectExpr
+      `-IntegerLiteral
+)cpp");
+
+  EXPECT_EQ(dumpASTString(ast_type_traits::TK_IgnoreUnlessSpelledInSource,
+                          getFunctionNode("func3")),
+            R"cpp(
+FunctionDecl 'func3'
+`-CompoundStmt
+  `-ReturnStmt
+    `-CXXFunctionalCastExpr
+      `-IntegerLiteral
+)cpp");
+
+  EXPECT_EQ(dumpASTString(ast_type_traits::TK_IgnoreUnlessSpelledInSource,
+                          getFunctionNode("func4")),
+            R"cpp(
+FunctionDecl 'func4'
+`-CompoundStmt
+  `-ReturnStmt
+    `-CXXTemporaryObjectExpr
+)cpp");
+
+  EXPECT_EQ(dumpASTString(ast_type_traits::TK_IgnoreUnlessSpelledInSource,
+                          getFunctionNode("func5")),
+            R"cpp(
+FunctionDecl 'func5'
+`-CompoundStmt
+  `-ReturnStmt
+    `-CXXTemporaryObjectExpr
+)cpp");
+
+  EXPECT_EQ(dumpASTString(ast_type_traits::TK_IgnoreUnlessSpelledInSource,
+                          getFunctionNode("func6")),
+            R"cpp(
+FunctionDecl 'func6'
+`-CompoundStmt
+  `-ReturnStmt
+    `-CXXTemporaryObjectExpr
+)cpp");
+
+  EXPECT_EQ(dumpASTString(ast_type_traits::TK_IgnoreUnlessSpelledInSource,
+                          getFunctionNode("func7")),
+            R"cpp(
+FunctionDecl 'func7'
+`-CompoundStmt
+  `-ReturnStmt
+    `-CXXTemporaryObjectExpr
+)cpp");
+
+  EXPECT_EQ(dumpASTString(ast_type_traits::TK_IgnoreUnlessSpelledInSource,
+                          getFunctionNode("func8")),
+            R"cpp(
+FunctionDecl 'func8'
+`-CompoundStmt
+  `-ReturnStmt
+    `-CXXFunctionalCastExpr
+      `-InitListExpr
+)cpp");
+
+  EXPECT_EQ(dumpASTString(ast_type_traits::TK_IgnoreUnlessSpelledInSource,
+                          getFunctionNode("func9")),
+            R"cpp(
+FunctionDecl 'func9'
+`-CompoundStmt
+  `-ReturnStmt
+    `-CXXFunctionalCastExpr
+      `-InitListExpr
+)cpp");
+
+  EXPECT_EQ(dumpASTString(ast_type_traits::TK_IgnoreUnlessSpelledInSource,
+                          getFunctionNode("func10")),
+            R"cpp(
+FunctionDecl 'func10'
+`-CompoundStmt
+  |-DeclStmt
+  | `-VarDecl 'a'
+  |   `-CXXConstructExpr
+  `-ReturnStmt
+    `-DeclRefExpr 'a'
+)cpp");
+
+  EXPECT_EQ(dumpASTString(ast_type_traits::TK_IgnoreUnlessSpelledInSource,
+                          getFunctionNode("func11")),
+            R"cpp(
+FunctionDecl 'func11'
+`-CompoundStmt
+  |-DeclStmt
+  | `-VarDecl 'b'
+  |   `-CXXConstructExpr
+  `-ReturnStmt
+    `-DeclRefExpr 'b'
+)cpp");
+
+  EXPECT_EQ(dumpASTString(ast_type_traits::TK_IgnoreUnlessSpelledInSource,
+                          getFunctionNode("func12")),
+            R"cpp(
+FunctionDecl 'func12'
+`-CompoundStmt
+  |-DeclStmt
+  | `-VarDecl 'c'
+  |   `-CXXConstructExpr
+  `-ReturnStmt
+    `-DeclRefExpr 'c'
+)cpp");
+}
+
 } // namespace clang
diff --git a/clang/unittests/AST/CMakeLists.txt b/clang/unittests/AST/CMakeLists.txt
index 4d463d1..be585ef 100644
--- a/clang/unittests/AST/CMakeLists.txt
+++ b/clang/unittests/AST/CMakeLists.txt
@@ -15,6 +15,7 @@
   ASTImporterVisibilityTest.cpp
   ASTTraverserTest.cpp
   ASTTypeTraitsTest.cpp
+  ASTTraverserTest.cpp
   ASTVectorTest.cpp
   CommentLexer.cpp
   CommentParser.cpp
diff --git a/clang/unittests/ASTMatchers/ASTMatchersTraversalTest.cpp b/clang/unittests/ASTMatchers/ASTMatchersTraversalTest.cpp
index f67e4d9..a21ed04 100644
--- a/clang/unittests/ASTMatchers/ASTMatchersTraversalTest.cpp
+++ b/clang/unittests/ASTMatchers/ASTMatchersTraversalTest.cpp
@@ -1680,6 +1680,149 @@
                          functionDecl(hasDescendant(Matcher)))))));
 }
 
+TEST(Traversal, traverseUnlessSpelledInSource) {
+
+  StringRef Code = R"cpp(
+
+struct A
+{
+};
+
+struct B
+{
+  B(int);
+  B(A const& a);
+  B();
+};
+
+struct C
+{
+  operator B();
+};
+
+B func1() {
+  return 42;
+}
+
+B func2() {
+  return B{42};
+}
+
+B func3() {
+  return B(42);
+}
+
+B func4() {
+  return B();
+}
+
+B func5() {
+  return B{};
+}
+
+B func6() {
+  return C();
+}
+
+B func7() {
+  return A();
+}
+
+B func8() {
+  return C{};
+}
+
+B func9() {
+  return A{};
+}
+
+B func10() {
+  A a;
+  return a;
+}
+
+B func11() {
+  B b;
+  return b;
+}
+
+B func12() {
+  C c;
+  return c;
+}
+
+)cpp";
+
+  EXPECT_TRUE(matches(
+      Code, traverse(ast_type_traits::TK_IgnoreUnlessSpelledInSource,
+                     returnStmt(forFunction(functionDecl(hasName("func1"))),
+                                hasReturnValue(integerLiteral(equals(42)))))));
+
+  EXPECT_TRUE(matches(
+      Code,
+      traverse(ast_type_traits::TK_IgnoreUnlessSpelledInSource,
+               returnStmt(forFunction(functionDecl(hasName("func2"))),
+                          hasReturnValue(cxxTemporaryObjectExpr(
+                              hasArgument(0, integerLiteral(equals(42)))))))));
+
+  EXPECT_TRUE(matches(
+      Code, traverse(ast_type_traits::TK_IgnoreUnlessSpelledInSource,
+                     returnStmt(forFunction(functionDecl(hasName("func3"))),
+                                hasReturnValue(
+                                    cxxFunctionalCastExpr(hasSourceExpression(
+                                        integerLiteral(equals(42)))))))));
+
+  EXPECT_TRUE(matches(
+      Code, traverse(ast_type_traits::TK_IgnoreUnlessSpelledInSource,
+                     returnStmt(forFunction(functionDecl(hasName("func4"))),
+                                hasReturnValue(cxxTemporaryObjectExpr())))));
+
+  EXPECT_TRUE(matches(
+      Code, traverse(ast_type_traits::TK_IgnoreUnlessSpelledInSource,
+                     returnStmt(forFunction(functionDecl(hasName("func5"))),
+                                hasReturnValue(cxxTemporaryObjectExpr())))));
+
+  EXPECT_TRUE(matches(
+      Code, traverse(ast_type_traits::TK_IgnoreUnlessSpelledInSource,
+                     returnStmt(forFunction(functionDecl(hasName("func6"))),
+                                hasReturnValue(cxxTemporaryObjectExpr())))));
+
+  EXPECT_TRUE(matches(
+      Code, traverse(ast_type_traits::TK_IgnoreUnlessSpelledInSource,
+                     returnStmt(forFunction(functionDecl(hasName("func7"))),
+                                hasReturnValue(cxxTemporaryObjectExpr())))));
+
+  EXPECT_TRUE(matches(
+      Code, traverse(ast_type_traits::TK_IgnoreUnlessSpelledInSource,
+                     returnStmt(forFunction(functionDecl(hasName("func8"))),
+                                hasReturnValue(cxxFunctionalCastExpr(
+                                    hasSourceExpression(initListExpr())))))));
+
+  EXPECT_TRUE(matches(
+      Code, traverse(ast_type_traits::TK_IgnoreUnlessSpelledInSource,
+                     returnStmt(forFunction(functionDecl(hasName("func9"))),
+                                hasReturnValue(cxxFunctionalCastExpr(
+                                    hasSourceExpression(initListExpr())))))));
+
+  EXPECT_TRUE(matches(
+      Code, traverse(ast_type_traits::TK_IgnoreUnlessSpelledInSource,
+                     returnStmt(forFunction(functionDecl(hasName("func10"))),
+                                hasReturnValue(
+                                    declRefExpr(to(varDecl(hasName("a")))))))));
+
+  EXPECT_TRUE(matches(
+      Code, traverse(ast_type_traits::TK_IgnoreUnlessSpelledInSource,
+                     returnStmt(forFunction(functionDecl(hasName("func11"))),
+                                hasReturnValue(
+                                    declRefExpr(to(varDecl(hasName("b")))))))));
+
+  EXPECT_TRUE(matches(
+      Code, traverse(ast_type_traits::TK_IgnoreUnlessSpelledInSource,
+                     returnStmt(forFunction(functionDecl(hasName("func12"))),
+                                hasReturnValue(
+                                    declRefExpr(to(varDecl(hasName("c")))))))));
+}
+
 TEST(IgnoringImpCasts, MatchesImpCasts) {
   // This test checks that ignoringImpCasts matches when implicit casts are
   // present and its inner matcher alone does not match.