Add ASTMatchRefactorer and ReplaceNodeWithTemplate to RefactoringCallbacks

Summary: This is the first change toward a clang-query-based search-and-replace tool.

Reviewers: klimek, bkramer, ioeric, sbenza, jbangert

Reviewed By: ioeric, jbangert

Subscribers: sbenza, ioeric, cfe-commits

Patch by Julian Bangert!

Differential Revision: https://reviews.llvm.org/D29621

llvm-svn: 302624
diff --git a/clang/lib/Tooling/RefactoringCallbacks.cpp b/clang/lib/Tooling/RefactoringCallbacks.cpp
index e900c23..6d3c016 100644
--- a/clang/lib/Tooling/RefactoringCallbacks.cpp
+++ b/clang/lib/Tooling/RefactoringCallbacks.cpp
@@ -9,8 +9,13 @@
 //
 //
 //===----------------------------------------------------------------------===//
-#include "clang/Lex/Lexer.h"
 #include "clang/Tooling/RefactoringCallbacks.h"
+#include "clang/ASTMatchers/ASTMatchFinder.h"
+#include "clang/Basic/SourceLocation.h"
+#include "clang/Lex/Lexer.h"
+
+using llvm::StringError;
+using llvm::make_error;
 
 namespace clang {
 namespace tooling {
@@ -20,18 +25,76 @@
   return Replace;
 }
 
-static Replacement replaceStmtWithText(SourceManager &Sources,
-                                       const Stmt &From,
-                                       StringRef Text) {
-  return tooling::Replacement(Sources, CharSourceRange::getTokenRange(
-      From.getSourceRange()), Text);
+/// Binds the refactorer to \p FileToReplaces, the per-file map that
+/// RefactoringASTConsumer fills after matching each translation unit.
+ASTMatchRefactorer::ASTMatchRefactorer(
+    std::map<std::string, Replacements> &FileToReplaces)
+    : FileToReplaces(FileToReplaces) {}
+
+/// Registers \p Matcher with the underlying MatchFinder and records
+/// \p Callback so that its replacements are collected after matching.
+void ASTMatchRefactorer::addDynamicMatcher(
+    const ast_matchers::internal::DynTypedMatcher &Matcher,
+    RefactoringCallback *Callback) {
+  MatchFinder.addDynamicMatcher(Matcher, Callback);
+  Callbacks.push_back(Callback);
 }
-static Replacement replaceStmtWithStmt(SourceManager &Sources,
-                                       const Stmt &From,
+
+/// ASTConsumer that runs the refactorer's matchers over a translation unit and
+/// merges the replacements produced by its callbacks into FileToReplaces.
+class RefactoringASTConsumer : public ASTConsumer {
+public:
+  RefactoringASTConsumer(ASTMatchRefactorer &Refactoring)
+      : Refactoring(Refactoring) {}
+
+  void HandleTranslationUnit(ASTContext &Context) override {
+    // The ASTMatchRefactorer is re-used between translation units.
+    // Clear the replacements accumulated by each callback so that each
+    // Replacement is only emitted once.
+    for (const auto &Callback : Refactoring.Callbacks) {
+      Callback->getReplacements().clear();
+    }
+    Refactoring.MatchFinder.matchAST(Context);
+    for (const auto &Callback : Refactoring.Callbacks) {
+      for (const auto &Replacement : Callback->getReplacements()) {
+        llvm::Error Err =
+            Refactoring.FileToReplaces[Replacement.getFilePath()].add(
+                Replacement);
+        // A replacement that cannot be added is reported and skipped; the
+        // remaining replacements are still processed.
+        if (Err) {
+          llvm::errs() << "Skipping replacement " << Replacement.toString()
+                       << " due to this error:\n"
+                       << llvm::toString(std::move(Err)) << "\n";
+        }
+      }
+    }
+  }
+
+private:
+  ASTMatchRefactorer &Refactoring;
+};
+
+/// Returns a consumer holding a reference to this refactorer, which therefore
+/// has to outlive the consumer.
+std::unique_ptr<ASTConsumer> ASTMatchRefactorer::newASTConsumer() {
+  return llvm::make_unique<RefactoringASTConsumer>(*this);
+}
+
+/// Builds a Replacement that swaps the token range of \p From for \p Text.
+static Replacement replaceStmtWithText(SourceManager &Sources, const Stmt &From,
+                                       StringRef Text) {
+  return tooling::Replacement(
+      Sources, CharSourceRange::getTokenRange(From.getSourceRange()), Text);
+}
+/// Builds a Replacement that swaps \p From for the source text of \p To, read
+/// with default-constructed LangOptions.
+static Replacement replaceStmtWithStmt(SourceManager &Sources, const Stmt &From,
                                        const Stmt &To) {
-  return replaceStmtWithText(Sources, From, Lexer::getSourceText(
-      CharSourceRange::getTokenRange(To.getSourceRange()),
-      Sources, LangOptions()));
+  return replaceStmtWithText(
+      Sources, From,
+      Lexer::getSourceText(CharSourceRange::getTokenRange(To.getSourceRange()),
+                           Sources, LangOptions()));
 }
 
 ReplaceStmtWithText::ReplaceStmtWithText(StringRef FromId, StringRef ToText)
@@ -103,5 +152,110 @@
   }
 }
 
+/// Stores the id of the node to replace and takes over the parsed template.
+/// \p Template is a named rvalue reference (an lvalue inside this body), so it
+/// must be passed through std::move to be moved rather than copied.
+ReplaceNodeWithTemplate::ReplaceNodeWithTemplate(
+    llvm::StringRef FromId, std::vector<TemplateElement> &&Template)
+    : FromId(FromId), Template(std::move(Template)) {}
+
+/// Parses \p ToTemplate and returns a callback that replaces the node bound to
+/// \p FromId with the expanded template.
+///
+/// Template syntax: "$$" yields a literal '$', "${name}" is substituted with
+/// the source text of the node bound to "name", and any other text is copied
+/// verbatim. A StringError carrying std::errc::bad_message is returned for an
+/// unterminated "${" or for a '$' not followed by '$' or '{'.
+llvm::Expected<std::unique_ptr<ReplaceNodeWithTemplate>>
+ReplaceNodeWithTemplate::create(StringRef FromId, StringRef ToTemplate) {
+  std::vector<TemplateElement> ParsedTemplate;
+  for (size_t Index = 0; Index < ToTemplate.size();) {
+    if (ToTemplate[Index] == '$') {
+      if (ToTemplate.substr(Index, 2) == "$$") {
+        Index += 2;
+        ParsedTemplate.push_back(
+            TemplateElement{TemplateElement::Literal, "$"});
+      } else if (ToTemplate.substr(Index, 2) == "${") {
+        size_t EndOfIdentifier = ToTemplate.find('}', Index);
+        if (EndOfIdentifier == StringRef::npos) {
+          return make_error<StringError>(
+              "Unterminated ${...} in replacement template near " +
+                  ToTemplate.substr(Index),
+              std::make_error_code(std::errc::bad_message));
+        }
+        // Strip the leading "${" and the trailing "}" from the identifier.
+        std::string SourceNodeName =
+            ToTemplate.substr(Index + 2, EndOfIdentifier - Index - 2);
+        ParsedTemplate.push_back(
+            TemplateElement{TemplateElement::Identifier, SourceNodeName});
+        Index = EndOfIdentifier + 1;
+      } else {
+        return make_error<StringError>(
+            "Invalid $ in replacement template near " +
+                ToTemplate.substr(Index),
+            std::make_error_code(std::errc::bad_message));
+      }
+    } else {
+      // Literal run up to the next '$'. When there is none, find() returns
+      // npos: substr() clamps to the end of the string and the loop ends.
+      size_t NextIndex = ToTemplate.find('$', Index + 1);
+      ParsedTemplate.push_back(
+          TemplateElement{TemplateElement::Literal,
+                          ToTemplate.substr(Index, NextIndex - Index)});
+      Index = NextIndex;
+    }
+  }
+  return std::unique_ptr<ReplaceNodeWithTemplate>(
+      new ReplaceNodeWithTemplate(FromId, std::move(ParsedTemplate)));
+}
+
+/// Expands the template against the nodes bound in \p Result and records a
+/// replacement of the node bound to FromId with the expanded text.
+///
+/// Calls llvm::report_fatal_error if a template identifier or FromId is not
+/// bound in the match, or if the replacement cannot be added.
+void ReplaceNodeWithTemplate::run(
+    const ast_matchers::MatchFinder::MatchResult &Result) {
+  const auto &NodeMap = Result.Nodes.getMap();
+
+  std::string ToText;
+  for (const auto &Element : Template) {
+    switch (Element.Type) {
+    case TemplateElement::Literal:
+      ToText += Element.Value;
+      break;
+    case TemplateElement::Identifier: {
+      auto NodeIter = NodeMap.find(Element.Value);
+      if (NodeIter == NodeMap.end()) {
+        llvm::errs() << "Node " << Element.Value
+                     << " used in replacement template not bound in Matcher \n";
+        llvm::report_fatal_error("Unbound node in replacement template.");
+      }
+      CharSourceRange Source =
+          CharSourceRange::getTokenRange(NodeIter->second.getSourceRange());
+      ToText += Lexer::getSourceText(Source, *Result.SourceManager,
+                                     Result.Context->getLangOpts());
+      break;
+    }
+    }
+  }
+  // Single lookup of the node to replace; the iterator is reused below.
+  const auto FromIter = NodeMap.find(FromId);
+  if (FromIter == NodeMap.end()) {
+    llvm::errs() << "Node to be replaced " << FromId
+                 << " not bound in query.\n";
+    llvm::report_fatal_error("FromId node not bound in MatchResult");
+  }
+  auto Replacement =
+      tooling::Replacement(*Result.SourceManager, &FromIter->second, ToText,
+                           Result.Context->getLangOpts());
+  llvm::Error Err = Replace.add(Replacement);
+  if (Err) {
+    llvm::errs() << "Query and replace failed in " << Replacement.getFilePath()
+                 << "! " << llvm::toString(std::move(Err)) << "\n";
+    llvm::report_fatal_error("Replacement failed");
+  }
+}
+
 } // end namespace tooling
 } // end namespace clang