[clangd] Start using SyntaxTrees for folding ranges feature

This is an initial attempt to start using Syntax Trees in clangd while improving state of folding ranges feature and experimenting with Syntax Tree capabilities.

Reviewed By: sammccall

Differential Revision: https://reviews.llvm.org/D88553
diff --git a/clang-tools-extra/clangd/SemanticSelection.cpp b/clang-tools-extra/clangd/SemanticSelection.cpp
index cfce152..b855d7b 100644
--- a/clang-tools-extra/clangd/SemanticSelection.cpp
+++ b/clang-tools-extra/clangd/SemanticSelection.cpp
@@ -5,6 +5,7 @@
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
 //===----------------------------------------------------------------------===//
+
 #include "SemanticSelection.h"
 #include "FindSymbols.h"
 #include "ParsedAST.h"
@@ -13,8 +14,16 @@
 #include "SourceCode.h"
 #include "clang/AST/DeclBase.h"
 #include "clang/Basic/SourceLocation.h"
+#include "clang/Basic/SourceManager.h"
+#include "clang/Basic/TokenKinds.h"
+#include "clang/Tooling/Syntax/BuildTree.h"
+#include "clang/Tooling/Syntax/Nodes.h"
+#include "clang/Tooling/Syntax/Tree.h"
 #include "llvm/ADT/ArrayRef.h"
+#include "llvm/Support/Casting.h"
 #include "llvm/Support/Error.h"
+#include <queue>
+#include <vector>
 
 namespace clang {
 namespace clangd {
@@ -28,17 +37,65 @@
   }
 }
 
-// Recursively collects FoldingRange from a symbol and its children.
-void collectFoldingRanges(DocumentSymbol Symbol,
-                          std::vector<FoldingRange> &Result) {
+llvm::Optional<FoldingRange> toFoldingRange(SourceRange SR,
+                                            const SourceManager &SM) {
+  const auto Begin = SM.getDecomposedLoc(SR.getBegin()),
+             End = SM.getDecomposedLoc(SR.getEnd());
+  // Do not produce folding ranges if either range ends is not within the main
+  // file. Macros have their own FileID so this also checks if locations are not
+  // within the macros.
+  if ((Begin.first != SM.getMainFileID()) || (End.first != SM.getMainFileID()))
+    return llvm::None;
   FoldingRange Range;
-  Range.startLine = Symbol.range.start.line;
-  Range.startCharacter = Symbol.range.start.character;
-  Range.endLine = Symbol.range.end.line;
-  Range.endCharacter = Symbol.range.end.character;
-  Result.push_back(Range);
-  for (const auto &Child : Symbol.children)
-    collectFoldingRanges(Child, Result);
+  Range.startCharacter = SM.getColumnNumber(Begin.first, Begin.second) - 1;
+  Range.startLine = SM.getLineNumber(Begin.first, Begin.second) - 1;
+  Range.endCharacter = SM.getColumnNumber(End.first, End.second) - 1;
+  Range.endLine = SM.getLineNumber(End.first, End.second) - 1;
+  return Range;
+}
+
+llvm::Optional<FoldingRange> extractFoldingRange(const syntax::Node *Node,
+                                                 const SourceManager &SM) {
+  if (const auto *Stmt = dyn_cast<syntax::CompoundStatement>(Node)) {
+    const auto *LBrace = cast_or_null<syntax::Leaf>(
+        Stmt->findChild(syntax::NodeRole::OpenParen));
+    // FIXME(kirillbobyrev): This should find the last child. Compound
+    // statements have only one pair of braces so this is valid but for other
+    // node kinds it might not be correct.
+    const auto *RBrace = cast_or_null<syntax::Leaf>(
+        Stmt->findChild(syntax::NodeRole::CloseParen));
+    if (!LBrace || !RBrace)
+      return llvm::None;
+    // Fold the entire range within braces, including whitespace.
+    const SourceLocation LBraceLocInfo = LBrace->getToken()->endLocation(),
+                         RBraceLocInfo = RBrace->getToken()->location();
+    auto Range = toFoldingRange(SourceRange(LBraceLocInfo, RBraceLocInfo), SM);
+    // Do not generate folding range for compound statements without any
+    // nodes and newlines.
+    if (Range && Range->startLine != Range->endLine)
+      return Range;
+  }
+  return llvm::None;
+}
+
+// Traverse the tree and collect folding ranges along the way.
+std::vector<FoldingRange> collectFoldingRanges(const syntax::Node *Root,
+                                               const SourceManager &SM) {
+  std::queue<const syntax::Node *> Nodes;
+  Nodes.push(Root);
+  std::vector<FoldingRange> Result;
+  while (!Nodes.empty()) {
+    const syntax::Node *Node = Nodes.front();
+    Nodes.pop();
+    const auto Range = extractFoldingRange(Node, SM);
+    if (Range)
+      Result.push_back(*Range);
+    if (const auto *T = dyn_cast<syntax::Tree>(Node))
+      for (const auto *NextNode = T->getFirstChild(); NextNode;
+           NextNode = NextNode->getNextSibling())
+        Nodes.push(NextNode);
+  }
+  return Result;
 }
 
 } // namespace
@@ -100,20 +157,12 @@
 // FIXME(kirillbobyrev): Collect comments, PP conditional regions, includes and
 // other code regions (e.g. public/private/protected sections of classes,
 // control flow statement bodies).
-// Related issue:
-// https://github.com/clangd/clangd/issues/310
+// Related issue: https://github.com/clangd/clangd/issues/310
 llvm::Expected<std::vector<FoldingRange>> getFoldingRanges(ParsedAST &AST) {
-  // FIXME(kirillbobyrev): getDocumentSymbols() is conveniently available but
-  // limited (e.g. doesn't yield blocks inside functions and provides ranges for
-  // nodes themselves instead of their contents which is less useful). Replace
-  // this with a more general RecursiveASTVisitor implementation instead.
-  auto DocumentSymbols = getDocumentSymbols(AST);
-  if (!DocumentSymbols)
-    return DocumentSymbols.takeError();
-  std::vector<FoldingRange> Result;
-  for (const auto &Symbol : *DocumentSymbols)
-    collectFoldingRanges(Symbol, Result);
-  return Result;
+  syntax::Arena A(AST.getSourceManager(), AST.getLangOpts(), AST.getTokens());
+  const auto *SyntaxTree =
+      syntax::buildSyntaxTree(A, *AST.getASTContext().getTranslationUnitDecl());
+  return collectFoldingRanges(SyntaxTree, AST.getSourceManager());
 }
 
 } // namespace clangd
diff --git a/clang-tools-extra/clangd/unittests/SemanticSelectionTests.cpp b/clang-tools-extra/clangd/unittests/SemanticSelectionTests.cpp
index 5c1a80a..1138ce7 100644
--- a/clang-tools-extra/clangd/unittests/SemanticSelectionTests.cpp
+++ b/clang-tools-extra/clangd/unittests/SemanticSelectionTests.cpp
@@ -203,26 +203,61 @@
 TEST(FoldingRanges, All) {
   const char *Tests[] = {
       R"cpp(
-        [[int global_variable]];
+        #define FOO int foo() {\
+          int Variable = 42; \
+        }
 
-        [[void func() {
-          int v = 100;
-        }]]
+        // Do not generate folding range for braces within macro expansion.
+        FOO
+
+        // Do not generate folding range within macro arguments.
+        #define FUNCTOR(functor) functor
+        void func() {[[
+          FUNCTOR([](){});
+        ]]}
+
+        // Do not generate folding range with a brace coming from macro.
+        #define LBRACE {
+        void bar() LBRACE
+          int X = 42;
+        }
       )cpp",
       R"cpp(
-        [[class Foo {
+        void func() {[[
+          int Variable = 100;
+
+          if (Variable > 5) {[[
+            Variable += 42;
+          ]]} else if (Variable++)
+            ++Variable;
+          else {[[
+            Variable--;
+          ]]}
+
+          // Do not generate FoldingRange for empty CompoundStmts.
+          for (;;) {}
+
+          // If there are newlines between {}, we should generate one.
+          for (;;) {[[
+
+          ]]}
+        ]]}
+      )cpp",
+      R"cpp(
+        class Foo {
         public:
-          [[Foo() {
+          Foo() {[[
             int X = 1;
-          }]]
+          ]]}
 
         private:
-          [[int getBar() {
+          int getBar() {[[
             return 42;
-          }]]
+          ]]}
 
-          [[void getFooBar() { }]]
-        }]];
+          // Braces are located at the same line: no folding range here.
+          void getFooBar() { }
+        };
       )cpp",
   };
   for (const char *Test : Tests) {
diff --git a/clang/include/clang/Tooling/Syntax/Tree.h b/clang/include/clang/Tooling/Syntax/Tree.h
index c643045..e1fd3a2 100644
--- a/clang/include/clang/Tooling/Syntax/Tree.h
+++ b/clang/include/clang/Tooling/Syntax/Tree.h
@@ -168,20 +168,24 @@
   Node *getFirstChild() { return FirstChild; }
   const Node *getFirstChild() const { return FirstChild; }
 
-  Leaf *findFirstLeaf();
-  const Leaf *findFirstLeaf() const {
-    return const_cast<Tree *>(this)->findFirstLeaf();
+  const Leaf *findFirstLeaf() const;
+  Leaf *findFirstLeaf() {
+    return const_cast<Leaf *>(const_cast<const Tree *>(this)->findFirstLeaf());
   }
 
-  Leaf *findLastLeaf();
-  const Leaf *findLastLeaf() const {
-    return const_cast<Tree *>(this)->findLastLeaf();
+  const Leaf *findLastLeaf() const;
+  Leaf *findLastLeaf() {
+    return const_cast<Leaf *>(const_cast<const Tree *>(this)->findLastLeaf());
+  }
+
+  /// Find the first node with a corresponding role.
+  const Node *findChild(NodeRole R) const;
+  Node *findChild(NodeRole R) {
+    return const_cast<Node *>(const_cast<const Tree *>(this)->findChild(R));
   }
 
 protected:
   using Node::Node;
-  /// Find the first node with a corresponding role.
-  Node *findChild(NodeRole R);
 
 private:
   /// Prepend \p Child to the list of children and and sets the parent pointer.
diff --git a/clang/lib/Tooling/Syntax/Tree.cpp b/clang/lib/Tooling/Syntax/Tree.cpp
index 87526ad..9904d14 100644
--- a/clang/lib/Tooling/Syntax/Tree.cpp
+++ b/clang/lib/Tooling/Syntax/Tree.cpp
@@ -271,29 +271,29 @@
 #endif
 }
 
-syntax::Leaf *syntax::Tree::findFirstLeaf() {
-  for (auto *C = getFirstChild(); C; C = C->getNextSibling()) {
-    if (auto *L = dyn_cast<syntax::Leaf>(C))
+const syntax::Leaf *syntax::Tree::findFirstLeaf() const {
+  for (const auto *C = getFirstChild(); C; C = C->getNextSibling()) {
+    if (const auto *L = dyn_cast<syntax::Leaf>(C))
       return L;
-    if (auto *L = cast<syntax::Tree>(C)->findFirstLeaf())
+    if (const auto *L = cast<syntax::Tree>(C)->findFirstLeaf())
       return L;
   }
   return nullptr;
 }
 
-syntax::Leaf *syntax::Tree::findLastLeaf() {
-  syntax::Leaf *Last = nullptr;
-  for (auto *C = getFirstChild(); C; C = C->getNextSibling()) {
-    if (auto *L = dyn_cast<syntax::Leaf>(C))
+const syntax::Leaf *syntax::Tree::findLastLeaf() const {
+  const syntax::Leaf *Last = nullptr;
+  for (const auto *C = getFirstChild(); C; C = C->getNextSibling()) {
+    if (const auto *L = dyn_cast<syntax::Leaf>(C))
       Last = L;
-    else if (auto *L = cast<syntax::Tree>(C)->findLastLeaf())
+    else if (const auto *L = cast<syntax::Tree>(C)->findLastLeaf())
       Last = L;
   }
   return Last;
 }
 
-syntax::Node *syntax::Tree::findChild(NodeRole R) {
-  for (auto *C = FirstChild; C; C = C->getNextSibling()) {
+const syntax::Node *syntax::Tree::findChild(NodeRole R) const {
+  for (const auto *C = FirstChild; C; C = C->getNextSibling()) {
     if (C->getRole() == R)
       return C;
   }