Implement tooling::Replacements as a class.

Summary:
- Implement clang::tooling::Replacements as a class that provides interfaces to
  control how replacements for a single file are combined and guarantees the
  order in which replacements are applied.
- tooling::Replacements now only contains replacements for a single file.
  Use std::map<std::string, tooling::Replacements> (file path to replacements)
  to represent replacements across multiple files.
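- Replacements::add() returns an llvm::Error if the new replacement refers to
  a different file or conflicts with an existing replacement. Error handling
  for this interface change will be improved in follow-up patches.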

Reviewers: djasper, klimek

Subscribers: cfe-commits

Differential Revision: https://reviews.llvm.org/D21748

llvm-svn: 277335
diff --git a/clang/lib/Tooling/Core/Replacement.cpp b/clang/lib/Tooling/Core/Replacement.cpp
index 02aab2d..5e22f57 100644
--- a/clang/lib/Tooling/Core/Replacement.cpp
+++ b/clang/lib/Tooling/Core/Replacement.cpp
@@ -137,200 +137,35 @@
                         ReplacementText);
 }

-template <typename T>
-unsigned shiftedCodePositionInternal(const T &Replaces, unsigned Position) {
-  unsigned Offset = 0;
-  for (const auto& R : Replaces) {
-    if (R.getOffset() + R.getLength() <= Position) {
-      Offset += R.getReplacementText().size() - R.getLength();
-      continue;
+// Adds \p R to this set of replacements. Returns an error if \p R refers to a
+// different file than the replacements already in the set, or if it starts at
+// the same offset as, or overlaps, an existing replacement. Replacements whose
+// offset is UINT_MAX are inserted without these checks.
+llvm::Error Replacements::add(const Replacement &R) {
+  if (R.getOffset() != UINT_MAX) {
+    for (const auto &Replace : Replaces) {
+      if (R.getFilePath() != Replace.getFilePath())
+        return llvm::make_error<llvm::StringError>(
+            "All replacements must have the same file path. New replacement: " +
+                R.getFilePath() + ", existing replacements: " +
+                Replace.getFilePath(),
+            llvm::inconvertibleErrorCode());
+      if (R.getOffset() == Replace.getOffset() ||
+          Range(R.getOffset(), R.getLength())
+              .overlapsWith(Range(Replace.getOffset(), Replace.getLength())))
+        return llvm::make_error<llvm::StringError>(
+            "New replacement:\n" + R.toString() +
+                "\nconflicts with existing replacement:\n" + Replace.toString(),
+            llvm::inconvertibleErrorCode());
     }
-    if (R.getOffset() < Position &&
-        R.getOffset() + R.getReplacementText().size() <= Position) {
-      Position = R.getOffset() + R.getReplacementText().size() - 1;
-    }
-    break;
-  }
-  return Position + Offset;
-}
+  }

-unsigned shiftedCodePosition(const Replacements &Replaces, unsigned Position) {
-  return shiftedCodePositionInternal(Replaces, Position);
-}
-
-// FIXME: Remove this function when Replacements is implemented as std::vector
-// instead of std::set.
-unsigned shiftedCodePosition(const std::vector<Replacement> &Replaces,
-                             unsigned Position) {
-  return shiftedCodePositionInternal(Replaces, Position);
-}
-
-void deduplicate(std::vector<Replacement> &Replaces,
-                 std::vector<Range> &Conflicts) {
-  if (Replaces.empty())
-    return;
-
-  auto LessNoPath = [](const Replacement &LHS, const Replacement &RHS) {
-    if (LHS.getOffset() != RHS.getOffset())
-      return LHS.getOffset() < RHS.getOffset();
-    if (LHS.getLength() != RHS.getLength())
-      return LHS.getLength() < RHS.getLength();
-    return LHS.getReplacementText() < RHS.getReplacementText();
-  };
-
-  auto EqualNoPath = [](const Replacement &LHS, const Replacement &RHS) {
-    return LHS.getOffset() == RHS.getOffset() &&
-           LHS.getLength() == RHS.getLength() &&
-           LHS.getReplacementText() == RHS.getReplacementText();
-  };
-
-  // Deduplicate. We don't want to deduplicate based on the path as we assume
-  // that all replacements refer to the same file (or are symlinks).
-  std::sort(Replaces.begin(), Replaces.end(), LessNoPath);
-  Replaces.erase(std::unique(Replaces.begin(), Replaces.end(), EqualNoPath),
-                 Replaces.end());
-
-  // Detect conflicts
-  Range ConflictRange(Replaces.front().getOffset(),
-                      Replaces.front().getLength());
-  unsigned ConflictStart = 0;
-  unsigned ConflictLength = 1;
-  for (unsigned i = 1; i < Replaces.size(); ++i) {
-    Range Current(Replaces[i].getOffset(), Replaces[i].getLength());
-    if (ConflictRange.overlapsWith(Current)) {
-      // Extend conflicted range
-      ConflictRange = Range(ConflictRange.getOffset(),
-                            std::max(ConflictRange.getLength(),
-                                     Current.getOffset() + Current.getLength() -
-                                         ConflictRange.getOffset()));
-      ++ConflictLength;
-    } else {
-      if (ConflictLength > 1)
-        Conflicts.push_back(Range(ConflictStart, ConflictLength));
-      ConflictRange = Current;
-      ConflictStart = i;
-      ConflictLength = 1;
-    }
-  }
-
-  if (ConflictLength > 1)
-    Conflicts.push_back(Range(ConflictStart, ConflictLength));
-}
-
-bool applyAllReplacements(const Replacements &Replaces, Rewriter &Rewrite) {
-  bool Result = true;
-  for (Replacements::const_iterator I = Replaces.begin(),
-                                    E = Replaces.end();
-       I != E; ++I) {
-    if (I->isApplicable()) {
-      Result = I->apply(Rewrite) && Result;
-    } else {
-      Result = false;
-    }
-  }
-  return Result;
-}
-
-// FIXME: Remove this function when Replacements is implemented as std::vector
-// instead of std::set.
-bool applyAllReplacements(const std::vector<Replacement> &Replaces,
-                          Rewriter &Rewrite) {
-  bool Result = true;
-  for (std::vector<Replacement>::const_iterator I = Replaces.begin(),
-                                                E = Replaces.end();
-       I != E; ++I) {
-    if (I->isApplicable()) {
-      Result = I->apply(Rewrite) && Result;
-    } else {
-      Result = false;
-    }
-  }
-  return Result;
-}
-
-llvm::Expected<std::string> applyAllReplacements(StringRef Code,
-                                                const Replacements &Replaces) {
-  if (Replaces.empty())
-    return Code.str();
-
-  IntrusiveRefCntPtr<vfs::InMemoryFileSystem> InMemoryFileSystem(
-      new vfs::InMemoryFileSystem);
-  FileManager Files(FileSystemOptions(), InMemoryFileSystem);
-  DiagnosticsEngine Diagnostics(
-      IntrusiveRefCntPtr<DiagnosticIDs>(new DiagnosticIDs),
-      new DiagnosticOptions);
-  SourceManager SourceMgr(Diagnostics, Files);
-  Rewriter Rewrite(SourceMgr, LangOptions());
-  InMemoryFileSystem->addFile(
-      "<stdin>", 0, llvm::MemoryBuffer::getMemBuffer(Code, "<stdin>"));
-  FileID ID = SourceMgr.createFileID(Files.getFile("<stdin>"), SourceLocation(),
-                                     clang::SrcMgr::C_User);
-  for (Replacements::const_iterator I = Replaces.begin(), E = Replaces.end();
-       I != E; ++I) {
-    Replacement Replace("<stdin>", I->getOffset(), I->getLength(),
-                        I->getReplacementText());
-    if (!Replace.apply(Rewrite))
-      return llvm::make_error<llvm::StringError>(
-          "Failed to apply replacement: " + Replace.toString(),
-          llvm::inconvertibleErrorCode());
-  }
-  std::string Result;
-  llvm::raw_string_ostream OS(Result);
-  Rewrite.getEditBuffer(ID).write(OS);
-  OS.flush();
-  return Result;
-}
-
-// Merge and sort overlapping ranges in \p Ranges.
-static std::vector<Range> mergeAndSortRanges(std::vector<Range> Ranges) {
-  std::sort(Ranges.begin(), Ranges.end(),
-            [](const Range &LHS, const Range &RHS) {
-              if (LHS.getOffset() != RHS.getOffset())
-                return LHS.getOffset() < RHS.getOffset();
-              return LHS.getLength() < RHS.getLength();
-            });
-  std::vector<Range> Result;
-  for (const auto &R : Ranges) {
-    if (Result.empty() ||
-        Result.back().getOffset() + Result.back().getLength() < R.getOffset()) {
-      Result.push_back(R);
-    } else {
-      unsigned NewEnd =
-          std::max(Result.back().getOffset() + Result.back().getLength(),
-                   R.getOffset() + R.getLength());
-      Result[Result.size() - 1] =
-          Range(Result.back().getOffset(), NewEnd - Result.back().getOffset());
-    }
-  }
-  return Result;
-}
-
-std::vector<Range> calculateChangedRanges(const Replacements &Replaces) {
-  std::vector<Range> ChangedRanges;
-  int Shift = 0;
-  for (const Replacement &R : Replaces) {
-    unsigned Offset = R.getOffset() + Shift;
-    unsigned Length = R.getReplacementText().size();
-    Shift += Length - R.getLength();
-    ChangedRanges.push_back(Range(Offset, Length));
-  }
-  return mergeAndSortRanges(ChangedRanges);
-}
-
-std::vector<Range>
-calculateRangesAfterReplacements(const Replacements &Replaces,
-                                 const std::vector<Range> &Ranges) {
-  auto MergedRanges = mergeAndSortRanges(Ranges);
-  tooling::Replacements FakeReplaces;
-  for (const auto &R : MergedRanges)
-    FakeReplaces.insert(Replacement(Replaces.begin()->getFilePath(),
-                                    R.getOffset(), R.getLength(),
-                                    std::string(R.getLength(), ' ')));
-  tooling::Replacements NewReplaces = mergeReplacements(FakeReplaces, Replaces);
-  return calculateChangedRanges(NewReplaces);
+  Replaces.insert(R);
+  return llvm::Error::success();
 }

 namespace {
+
 // Represents a merged replacement, i.e. a replacement consisting of multiple
 // overlapping replacements from 'First' and 'Second' in mergeReplacements.
 //
@@ -424,26 +254,22 @@
   unsigned Length;
   std::string Text;
 };
+
 } // namespace

-std::map<std::string, Replacements>
-groupReplacementsByFile(const Replacements &Replaces) {
-  std::map<std::string, Replacements> FileToReplaces;
-  for (const auto &Replace : Replaces) {
-    FileToReplaces[Replace.getFilePath()].insert(Replace);
-  }
-  return FileToReplaces;
-}
+// Merges \p ReplacesToMerge, whose offsets refer to the code after this set
+// has been applied, with this set. The offsets of the returned replacements
+// refer to the original code.
+Replacements Replacements::merge(const Replacements &ReplacesToMerge) const {
+  if (empty() || ReplacesToMerge.empty())
+    return empty() ? ReplacesToMerge : *this;

-Replacements mergeReplacements(const Replacements &First,
-                               const Replacements &Second) {
-  if (First.empty() || Second.empty())
-    return First.empty() ? Second : First;
-
+  auto &First = Replaces;
+  auto &Second = ReplacesToMerge.Replaces;
   // Delta is the amount of characters that replacements from 'Second' need to
   // be shifted so that their offsets refer to the original text.
   int Delta = 0;
-  Replacements Result;
+  ReplacementsImpl Result;

   // Iterate over both sets and always add the next element (smallest total
   // Offset) from either 'First' or 'Second'. Merge that element with
@@ -469,8 +295,161 @@
     Delta -= Merged.deltaFirst();
     Result.insert(Merged.asReplacement());
   }
+  return Replacements(Result.begin(), Result.end());
+}
+
+// Combines overlapping ranges in \p Ranges and sorts the combined ranges.
+// Returns a set of non-overlapping and sorted ranges that is equivalent to
+// \p Ranges.
+static std::vector<Range> combineAndSortRanges(std::vector<Range> Ranges) {
+  std::sort(Ranges.begin(), Ranges.end(),
+            [](const Range &LHS, const Range &RHS) {
+              if (LHS.getOffset() != RHS.getOffset())
+                return LHS.getOffset() < RHS.getOffset();
+              return LHS.getLength() < RHS.getLength();
+            });
+  std::vector<Range> Result;
+  for (const auto &R : Ranges) {
+    if (Result.empty() ||
+        Result.back().getOffset() + Result.back().getLength() < R.getOffset()) {
+      Result.push_back(R);
+    } else {
+      unsigned NewEnd =
+          std::max(Result.back().getOffset() + Result.back().getLength(),
+                   R.getOffset() + R.getLength());
+      Result[Result.size() - 1] =
+          Range(Result.back().getOffset(), NewEnd - Result.back().getOffset());
+    }
+  }
   return Result;
 }

+std::vector<Range>
+calculateRangesAfterReplacements(const Replacements &Replaces,
+                                 const std::vector<Range> &Ranges) {
+  // To calculate the new ranges,
+  //   - Turn \p Ranges into Replacements at (offset, length) with a
+  //     placeholder replacement text (content unimportant) of length "length".
+  //   - Merge with \p Replaces.
+  //   - The new ranges will be the affected ranges of the merged replacements.
+  auto MergedRanges = combineAndSortRanges(Ranges);
+  // With no replacements the ranges do not move. Returning early also avoids
+  // dereferencing Replaces.begin() (used for the file path below) when empty.
+  if (Replaces.empty())
+    return MergedRanges;
+  tooling::Replacements FakeReplaces;
+  for (const auto &R : MergedRanges) {
+    auto Err = FakeReplaces.add(Replacement(Replaces.begin()->getFilePath(),
+                                            R.getOffset(), R.getLength(),
+                                            std::string(R.getLength(), ' ')));
+    assert(!Err &&
+           "Replacements must not conflict since ranges have been merged.");
+    (void)Err;
+  }
+  return FakeReplaces.merge(Replaces).getAffectedRanges();
+}
+
+// Returns the ranges covered by replacement text in the code after applying
+// these replacements, combined and sorted.
+std::vector<Range> Replacements::getAffectedRanges() const {
+  std::vector<Range> ChangedRanges;
+  int Shift = 0;
+  for (const Replacement &R : Replaces) {
+    unsigned Offset = R.getOffset() + Shift;
+    unsigned Length = R.getReplacementText().size();
+    Shift += Length - R.getLength();
+    ChangedRanges.push_back(Range(Offset, Length));
+  }
+  return combineAndSortRanges(ChangedRanges);
+}
+
+// Maps \p Position in the original code to the corresponding position in the
+// code after applying these replacements. A position inside a replaced range
+// that lies past the end of the new text is moved to the new text's last
+// character (or to its start if the new text is empty).
+unsigned Replacements::getShiftedCodePosition(unsigned Position) const {
+  unsigned Offset = 0;
+  for (const auto& R : Replaces) {
+    if (R.getOffset() + R.getLength() <= Position) {
+      Offset += R.getReplacementText().size() - R.getLength();
+      continue;
+    }
+    if (R.getOffset() < Position &&
+        R.getOffset() + R.getReplacementText().size() <= Position) {
+      Position = R.getOffset() + R.getReplacementText().size();
+      if (R.getReplacementText().size() > 0)
+        Position--;
+    }
+    break;
+  }
+  return Position + Offset;
+}
+
+// Applies each replacement in \p Replaces to \p Rewrite. Returns false if any
+// replacement is not applicable or fails to apply.
+bool applyAllReplacements(const Replacements &Replaces, Rewriter &Rewrite) {
+  bool Result = true;
+  for (Replacements::const_iterator I = Replaces.begin(),
+                                    E = Replaces.end();
+       I != E; ++I) {
+    if (I->isApplicable()) {
+      Result = I->apply(Rewrite) && Result;
+    } else {
+      Result = false;
+    }
+  }
+  return Result;
+}
+
+// Applies \p Replaces to \p Code, treated as the contents of an in-memory
+// "<stdin>" file, and returns the rewritten code or an error.
+llvm::Expected<std::string> applyAllReplacements(StringRef Code,
+                                                const Replacements &Replaces) {
+  if (Replaces.empty())
+    return Code.str();
+
+  IntrusiveRefCntPtr<vfs::InMemoryFileSystem> InMemoryFileSystem(
+      new vfs::InMemoryFileSystem);
+  FileManager Files(FileSystemOptions(), InMemoryFileSystem);
+  DiagnosticsEngine Diagnostics(
+      IntrusiveRefCntPtr<DiagnosticIDs>(new DiagnosticIDs),
+      new DiagnosticOptions);
+  SourceManager SourceMgr(Diagnostics, Files);
+  Rewriter Rewrite(SourceMgr, LangOptions());
+  InMemoryFileSystem->addFile(
+      "<stdin>", 0, llvm::MemoryBuffer::getMemBuffer(Code, "<stdin>"));
+  FileID ID = SourceMgr.createFileID(Files.getFile("<stdin>"), SourceLocation(),
+                                     clang::SrcMgr::C_User);
+  for (Replacements::const_iterator I = Replaces.begin(), E = Replaces.end();
+       I != E; ++I) {
+    Replacement Replace("<stdin>", I->getOffset(), I->getLength(),
+                        I->getReplacementText());
+    if (!Replace.apply(Rewrite))
+      return llvm::make_error<llvm::StringError>(
+          "Failed to apply replacement: " + Replace.toString(),
+          llvm::inconvertibleErrorCode());
+  }
+  std::string Result;
+  llvm::raw_string_ostream OS(Result);
+  Rewrite.getEditBuffer(ID).write(OS);
+  OS.flush();
+  return Result;
+}
+
+std::map<std::string, Replacements>
+groupReplacementsByFile(const Replacements &Replaces) {
+  std::map<std::string, Replacements> FileToReplaces;
+  for (const auto &Replace : Replaces) {
+    // Replaces is already conflict-free, so adding to a per-file group cannot
+    // fail; the returned llvm::Error must still be checked, even on success.
+    auto Err = FileToReplaces[Replace.getFilePath()].add(Replace);
+    assert(!Err &&
+           "Replacements must not conflict since Replaces is conflict-free.");
+    (void)Err;
+  }
+  return FileToReplaces;
+}
+
 } // end namespace tooling
 } // end namespace clang
diff --git a/clang/lib/Tooling/Refactoring.cpp b/clang/lib/Tooling/Refactoring.cpp
index d48713c..5565b54 100644
--- a/clang/lib/Tooling/Refactoring.cpp
+++ b/clang/lib/Tooling/Refactoring.cpp
@@ -30,7 +30,10 @@
     std::shared_ptr<PCHContainerOperations> PCHContainerOps)
     : ClangTool(Compilations, SourcePaths, PCHContainerOps) {}

-Replacements &RefactoringTool::getReplacements() { return Replace; }
+// Returns the replacements collected by this tool, grouped by file path.
+std::map<std::string, Replacements> &RefactoringTool::getReplacements() {
+  return FileToReplaces;
+}

 int RefactoringTool::runAndSave(FrontendActionFactory *ActionFactory) {
   if (int Result = run(ActionFactory)) {
@@ -54,20 +57,26 @@
 }

 bool RefactoringTool::applyAllReplacements(Rewriter &Rewrite) {
-  return tooling::applyAllReplacements(Replace, Rewrite);
+  // Apply the replacements of every file, even after a failure, and report
+  // whether all of them were applied successfully.
+  bool AllApplied = true;
+  for (const auto &FileAndReplaces : FileToReplaces) {
+    if (!tooling::applyAllReplacements(FileAndReplaces.second, Rewrite))
+      AllApplied = false;
+  }
+  return AllApplied;
 }

 int RefactoringTool::saveRewrittenFiles(Rewriter &Rewrite) {
   return Rewrite.overwriteChangedFiles() ? 1 : 0;
 }

-bool formatAndApplyAllReplacements(const Replacements &Replaces,
-                                   Rewriter &Rewrite, StringRef Style) {
+bool formatAndApplyAllReplacements(
+    const std::map<std::string, Replacements> &FileToReplaces, Rewriter &Rewrite,
+    StringRef Style) {
   SourceManager &SM = Rewrite.getSourceMgr();
   FileManager &Files = SM.getFileManager();

-  auto FileToReplaces = groupReplacementsByFile(Replaces);
-
   bool Result = true;
   for (const auto &FileAndReplaces : FileToReplaces) {
     const std::string &FilePath = FileAndReplaces.first;
diff --git a/clang/lib/Tooling/RefactoringCallbacks.cpp b/clang/lib/Tooling/RefactoringCallbacks.cpp
index 4de125e..af25fd8 100644
--- a/clang/lib/Tooling/RefactoringCallbacks.cpp
+++ b/clang/lib/Tooling/RefactoringCallbacks.cpp
@@ -40,10 +40,15 @@
 void ReplaceStmtWithText::run(
     const ast_matchers::MatchFinder::MatchResult &Result) {
   if (const Stmt *FromMatch = Result.Nodes.getStmtAs<Stmt>(FromId)) {
-    Replace.insert(tooling::Replacement(
+    auto Err = Replace.add(tooling::Replacement(
         *Result.SourceManager,
-        CharSourceRange::getTokenRange(FromMatch->getSourceRange()),
-        ToText));
+        CharSourceRange::getTokenRange(FromMatch->getSourceRange()), ToText));
+    // FIXME: better error handling. For now, print the error message, and
+    // assert in builds with assertions enabled.
+    if (Err) {
+      llvm::errs() << llvm::toString(std::move(Err)) << "\n";
+      assert(false && "Failed to add replacement");
+    }
   }
 }

@@ -54,9 +59,16 @@
     const ast_matchers::MatchFinder::MatchResult &Result) {
   const Stmt *FromMatch = Result.Nodes.getStmtAs<Stmt>(FromId);
   const Stmt *ToMatch = Result.Nodes.getStmtAs<Stmt>(ToId);
-  if (FromMatch && ToMatch)
-    Replace.insert(replaceStmtWithStmt(
-        *Result.SourceManager, *FromMatch, *ToMatch));
+  if (FromMatch && ToMatch) {
+    auto Err = Replace.add(
+        replaceStmtWithStmt(*Result.SourceManager, *FromMatch, *ToMatch));
+    // FIXME: better error handling. For now, print the error message, and
+    // assert in builds with assertions enabled.
+    if (Err) {
+      llvm::errs() << llvm::toString(std::move(Err)) << "\n";
+      assert(false && "Failed to add replacement");
+    }
+  }
 }

 ReplaceIfStmtWithItsBody::ReplaceIfStmtWithItsBody(StringRef Id,
@@ -68,11 +80,25 @@
   if (const IfStmt *Node = Result.Nodes.getStmtAs<IfStmt>(Id)) {
     const Stmt *Body = PickTrueBranch ? Node->getThen() : Node->getElse();
     if (Body) {
-      Replace.insert(replaceStmtWithStmt(*Result.SourceManager, *Node, *Body));
+      auto Err =
+          Replace.add(replaceStmtWithStmt(*Result.SourceManager, *Node, *Body));
+      // FIXME: better error handling. For now, print the error message, and
+      // assert in builds with assertions enabled.
+      if (Err) {
+        llvm::errs() << llvm::toString(std::move(Err)) << "\n";
+        assert(false && "Failed to add replacement");
+      }
     } else if (!PickTrueBranch) {
       // If we want to use the 'else'-branch, but it doesn't exist, delete
       // the whole 'if'.
-      Replace.insert(replaceStmtWithText(*Result.SourceManager, *Node, ""));
+      auto Err =
+          Replace.add(replaceStmtWithText(*Result.SourceManager, *Node, ""));
+      // FIXME: better error handling. For now, print the error message, and
+      // assert in builds with assertions enabled.
+      if (Err) {
+        llvm::errs() << llvm::toString(std::move(Err)) << "\n";
+        assert(false && "Failed to add replacement");
+      }
     }
   }
 }