Make AidlNode visitable

Pull up DispatchVisit/TraverseChildren to AidlNode.
AidlConstantValue::Visitor is replaced with AidlVisitor as well.

Bug: none
Test: aidl_unittests
Change-Id: I51e0e3fdb06f7deeac9a28d7797e5fe738c28bb2
diff --git a/aidl_language.cpp b/aidl_language.cpp
index 9f8e638..5389a4f 100644
--- a/aidl_language.cpp
+++ b/aidl_language.cpp
@@ -188,14 +188,16 @@
     std::map<std::string, std::shared_ptr<AidlConstantValue>>&& parameters)
     : AidlNode(location), schema_(schema), parameters_(std::move(parameters)) {}
 
-struct ConstReferenceFinder : AidlConstantValue::Visitor {
+struct ConstReferenceFinder : AidlVisitor {
   const AidlConstantReference* found;
-  void Visit(const AidlConstantValue&) override {}
-  void Visit(const AidlUnaryConstExpression&) override {}
-  void Visit(const AidlBinaryConstExpression&) override {}
   void Visit(const AidlConstantReference& ref) override {
     if (!found) found = &ref;
   }
+  static const AidlConstantReference* Find(const AidlConstantValue& c) {
+    ConstReferenceFinder finder;
+    VisitTopDown(finder, c);
+    return finder.found;
+  }
 };
 
 bool AidlAnnotation::CheckValid() const {
@@ -216,11 +218,10 @@
       return false;
     }
 
-    ConstReferenceFinder finder;
-    param->Accept(finder);
-    if (finder.found) {
-      AIDL_ERROR(finder.found) << "Value must be a constant expression but contains reference to "
-                               << finder.found->GetFieldName() << ".";
+    const auto& found = ConstReferenceFinder::Find(*param);
+    if (found) {
+      AIDL_ERROR(found) << "Value must be a constant expression but contains reference to "
+                        << found->GetFieldName() << ".";
       return false;
     }
 
@@ -288,6 +289,13 @@
   }
 }
 
+void AidlAnnotation::TraverseChildren(std::function<void(const AidlNode&)> traverse) const {
+  for (const auto& [name, value] : parameters_) {
+    (void)name;
+    traverse(*value);
+  }
+}
+
 static const AidlAnnotation* GetAnnotation(const vector<AidlAnnotation>& annotations,
                                            AidlAnnotation::Type type) {
   for (const auto& a : annotations) {
@@ -705,6 +713,14 @@
   }
 }
 
+void AidlVariableDeclaration::TraverseChildren(
+    std::function<void(const AidlNode&)> traverse) const {
+  traverse(GetType());
+  if (IsDefaultUserSpecified()) {
+    traverse(*GetDefaultValue());
+  }
+}
+
 AidlArgument::AidlArgument(const AidlLocation& location, AidlArgument::Direction direction,
                            AidlTypeSpecifier* type, const std::string& name)
     : AidlVariableDeclaration(location, type, name),
diff --git a/aidl_language.h b/aidl_language.h
index 0ea5ce4..15428c4 100644
--- a/aidl_language.h
+++ b/aidl_language.h
@@ -76,6 +76,7 @@
 bool ParseFloating(std::string_view sv, float* parsed);
 
 class AidlDocument;
+class AidlImport;
 class AidlInterface;
 class AidlParcelable;
 class AidlStructuredParcelable;
@@ -86,6 +87,11 @@
 class AidlEnumerator;
 class AidlMethod;
 class AidlArgument;
+class AidlConstantValue;
+class AidlConstantReference;
+class AidlUnaryConstExpression;
+class AidlBinaryConstExpression;
+class AidlAnnotation;
 
 // Interface for visitors that can traverse AidlTraversable nodes.
 class AidlVisitor {
@@ -103,6 +109,12 @@
   virtual void Visit(const AidlConstantDeclaration&) {}
   virtual void Visit(const AidlArgument&) {}
   virtual void Visit(const AidlTypeSpecifier&) {}
+  virtual void Visit(const AidlConstantValue&) {}
+  virtual void Visit(const AidlConstantReference&) {}
+  virtual void Visit(const AidlUnaryConstExpression&) {}
+  virtual void Visit(const AidlBinaryConstExpression&) {}
+  virtual void Visit(const AidlAnnotation&) {}
+  virtual void Visit(const AidlImport&) {}
 };
 
 // Anything that is locatable in a .aidl file.
@@ -122,6 +134,8 @@
   friend std::string android::aidl::java::dump_location(const AidlNode&);
 
   const AidlLocation& GetLocation() const { return location_; }
+  virtual void TraverseChildren(std::function<void(const AidlNode&)> traverse) const = 0;
+  virtual void DispatchVisit(AidlVisitor&) const = 0;
 
  private:
   std::string PrintLine() const;
@@ -129,15 +143,6 @@
   const AidlLocation location_;
 };
 
-// Anything that is traversable by the AidlVisitor
-class AidlTraversable {
- public:
-  virtual ~AidlTraversable() = default;
-
-  virtual void TraverseChildren(std::function<void(const AidlTraversable&)> traverse) const = 0;
-  virtual void DispatchVisit(AidlVisitor&) const = 0;
-};
-
 // unique_ptr<AidlTypeSpecifier> for type arugment,
 // std::string for type parameter(T, U, and so on).
 template <typename T>
@@ -222,6 +227,8 @@
       const ConstantValueDecorator& decorator) const;
   const string& GetComments() const { return comments_; }
   void SetComments(const string& comments) { comments_ = comments; }
+  void TraverseChildren(std::function<void(const AidlNode&)> traverse) const override;
+  void DispatchVisit(AidlVisitor& v) const override { v.Visit(*this); }
 
  private:
   struct Schema {
@@ -292,6 +299,11 @@
 
   const vector<AidlAnnotation>& GetAnnotations() const { return annotations_; }
   bool CheckValid(const AidlTypenames&) const;
+  void TraverseChildren(std::function<void(const AidlNode&)> traverse) const override {
+    for (const auto& annot : GetAnnotations()) {
+      traverse(annot);
+    }
+  }
 
  protected:
   virtual std::set<AidlAnnotation::Type> GetSupportedAnnotations() const = 0;
@@ -303,7 +315,6 @@
 // AidlTypeSpecifier represents a reference to either a built-in type,
 // a defined type, or a variant (e.g., array of generic) of a type.
 class AidlTypeSpecifier final : public AidlAnnotatable,
-                                public AidlTraversable,
                                 public AidlParameterizable<unique_ptr<AidlTypeSpecifier>> {
  public:
   AidlTypeSpecifier(const AidlLocation& location, const string& unresolved_name, bool is_array,
@@ -368,7 +379,8 @@
   const AidlNode& AsAidlNode() const override { return *this; }
 
   const AidlDefinedType* GetDefinedType() const;
-  void TraverseChildren(std::function<void(const AidlTraversable&)> traverse) const override {
+  void TraverseChildren(std::function<void(const AidlNode&)> traverse) const override {
+    AidlAnnotatable::TraverseChildren(traverse);
     if (IsGeneric()) {
       for (const auto& tp : GetTypeParameters()) {
         traverse(*tp);
@@ -392,12 +404,7 @@
 // Returns the universal value unaltered.
 std::string AidlConstantValueDecorator(const AidlTypeSpecifier& type, const std::string& raw_value);
 
-class AidlConstantValue;
-class AidlMethod;
-class AidlConstantDeclaration;
-class AidlVariableDeclaration;
-
-class AidlMember : public AidlNode, public AidlTraversable {
+class AidlMember : public AidlNode {
  public:
   AidlMember(const AidlLocation& location);
   virtual ~AidlMember() = default;
@@ -423,9 +430,6 @@
     return const_cast<AidlVariableDeclaration*>(
         const_cast<const AidlMember*>(this)->AsVariableDeclaration());
   }
-
-  void TraverseChildren(std::function<void(const AidlTraversable&)> traverse) const = 0;
-  void DispatchVisit(AidlVisitor& v) const = 0;
 };
 
 // TODO: This class is used for method arguments and also parcelable fields,
@@ -475,9 +479,7 @@
 
   std::string ValueString(const ConstantValueDecorator& decorator) const;
 
-  void TraverseChildren(std::function<void(const AidlTraversable&)> traverse) const override {
-    traverse(GetType());
-  }
+  void TraverseChildren(std::function<void(const AidlNode&)> traverse) const override;
   void DispatchVisit(AidlVisitor& v) const override { v.Visit(*this); }
 
  private:
@@ -514,7 +516,7 @@
   // e.g) "in @utf8InCpp String[] names"
   std::string ToString() const;
 
-  void TraverseChildren(std::function<void(const AidlTraversable&)> traverse) const override {
+  void TraverseChildren(std::function<void(const AidlNode&)> traverse) const override {
     traverse(GetType());
   }
   void DispatchVisit(AidlVisitor& v) const override { v.Visit(*this); }
@@ -547,14 +549,6 @@
     ERROR,
   };
 
-  struct Visitor {
-    virtual ~Visitor() {}
-    virtual void Visit(const AidlConstantValue&) = 0;
-    virtual void Visit(const AidlConstantReference&) = 0;
-    virtual void Visit(const AidlUnaryConstExpression&) = 0;
-    virtual void Visit(const AidlBinaryConstExpression&) = 0;
-  };
-
   // Returns the evaluated value. T> should match to the actual type.
   template <typename T>
   T EvaluatedValue() const {
@@ -620,14 +614,15 @@
 
   // Raw value of type (currently valid in C++ and Java). Empty string on error.
   string ValueString(const AidlTypeSpecifier& type, const ConstantValueDecorator& decorator) const;
-  virtual void Accept(Visitor& visitor) {
-    visitor.Visit(*this);
+
+  void TraverseChildren(std::function<void(const AidlNode&)> traverse) const {
     if (type_ == Type::ARRAY) {
       for (const auto& v : values_) {
-        v.get()->Accept(visitor);
+        traverse(*v);
       }
     }
   }
+  void DispatchVisit(AidlVisitor& visitor) const override { visitor.Visit(*this); }
 
  private:
   AidlConstantValue(const AidlLocation& location, Type parsed_type, int64_t parsed_value,
@@ -672,7 +667,10 @@
   const std::string& GetComments() const { return comments_; }
 
   bool CheckValid() const override;
-  void Accept(Visitor& visitor) override { visitor.Visit(*this); }
+  void TraverseChildren(std::function<void(const AidlNode&)>) const override {
+    // resolved_ is not my child.
+  }
+  void DispatchVisit(AidlVisitor& v) const override { v.Visit(*this); }
   const AidlConstantValue* Resolve(const AidlDefinedType* scope) const;
 
  private:
@@ -691,10 +689,10 @@
 
   static bool IsCompatibleType(Type type, const string& op);
   bool CheckValid() const override;
-  void Accept(Visitor& visitor) override {
-    visitor.Visit(*this);
-    unary_->Accept(visitor);
+  void TraverseChildren(std::function<void(const AidlNode&)> traverse) const override {
+    traverse(*unary_);
   }
+  void DispatchVisit(AidlVisitor& v) const override { v.Visit(*this); }
 
  private:
   bool evaluate() const override;
@@ -715,11 +713,11 @@
   static Type UsualArithmeticConversion(Type left, Type right);
   // Returns the promoted integral type where INT32 is the smallest type
   static Type IntegralPromotion(Type in);
-  void Accept(Visitor& visitor) override {
-    visitor.Visit(*this);
-    left_val_->Accept(visitor);
-    right_val_->Accept(visitor);
+  void TraverseChildren(std::function<void(const AidlNode&)> traverse) const override {
+    traverse(*left_val_);
+    traverse(*right_val_);
   }
+  void DispatchVisit(AidlVisitor& v) const override { v.Visit(*this); }
 
  private:
   bool evaluate() const override;
@@ -770,8 +768,9 @@
 
   const AidlConstantDeclaration* AsConstantDeclaration() const override { return this; }
 
-  void TraverseChildren(std::function<void(const AidlTraversable&)> traverse) const override {
+  void TraverseChildren(std::function<void(const AidlNode&)> traverse) const override {
     traverse(GetType());
+    traverse(GetValue());
   }
   void DispatchVisit(AidlVisitor& v) const override { v.Visit(*this); }
 
@@ -837,7 +836,7 @@
   // e.g) "foo(int, String)"
   std::string Signature() const;
 
-  void TraverseChildren(std::function<void(const AidlTraversable&)> traverse) const override {
+  void TraverseChildren(std::function<void(const AidlNode&)> traverse) const override {
     traverse(GetType());
     for (const auto& a : GetArguments()) {
       traverse(*a);
@@ -858,15 +857,9 @@
   bool is_user_defined_ = true;
 };
 
-class AidlDefinedType;
-class AidlInterface;
-class AidlParcelable;
-class AidlStructuredParcelable;
-class AidlUnionDecl;
-
 // AidlDefinedType represents either an interface, parcelable, or enum that is
 // defined in the source file.
-class AidlDefinedType : public AidlAnnotatable, public AidlTraversable {
+class AidlDefinedType : public AidlAnnotatable {
  public:
   AidlDefinedType(const AidlLocation& location, const std::string& name,
                   const std::string& comments, const std::string& package,
@@ -947,9 +940,12 @@
   const std::vector<std::unique_ptr<AidlMethod>>& GetMethods() const { return methods_; }
   void AddMethod(std::unique_ptr<AidlMethod> method) { methods_.push_back(std::move(method)); }
   const std::vector<const AidlMember*>& GetMembers() const { return members_; }
-
-  void TraverseChildren(std::function<void(const AidlTraversable&)>) const = 0;
-  void DispatchVisit(AidlVisitor& v) const = 0;
+  void TraverseChildren(std::function<void(const AidlNode&)> traverse) const override {
+    AidlAnnotatable::TraverseChildren(traverse);
+    for (const auto c : GetMembers()) {
+      traverse(*c);
+    }
+  }
 
  protected:
   // utility for subclasses with getter names
@@ -995,11 +991,6 @@
 
   void Dump(CodeWriter* writer) const override;
 
-  void TraverseChildren(std::function<void(const AidlTraversable&)> traverse) const override {
-    for (const auto c : GetMembers()) {
-      traverse(*c);
-    }
-  }
   void DispatchVisit(AidlVisitor& v) const override { v.Visit(*this); }
 
  private:
@@ -1030,15 +1021,10 @@
   bool LanguageSpecificCheckValid(const AidlTypenames& typenames,
                                   Options::Language lang) const override;
 
-  void TraverseChildren(std::function<void(const AidlTraversable&)> traverse) const override {
-    for (const auto c : GetMembers()) {
-      traverse(*c);
-    }
-  }
   void DispatchVisit(AidlVisitor& v) const override { v.Visit(*this); }
 };
 
-class AidlEnumerator : public AidlNode, public AidlTraversable {
+class AidlEnumerator : public AidlNode {
  public:
   AidlEnumerator(const AidlLocation& location, const std::string& name, AidlConstantValue* value,
                  const std::string& comments);
@@ -1061,8 +1047,8 @@
   void SetValue(std::unique_ptr<AidlConstantValue> value) { value_ = std::move(value); }
   bool IsValueUserSpecified() const { return value_user_specified_; }
 
-  void TraverseChildren(std::function<void(const AidlTraversable&)>) const override {
-    // no children to traverse
+  void TraverseChildren(std::function<void(const AidlNode&)> traverse) const override {
+    traverse(*value_);
   }
   void DispatchVisit(AidlVisitor& v) const override { v.Visit(*this); }
 
@@ -1102,7 +1088,8 @@
 
   const AidlEnumDeclaration* AsEnumDeclaration() const override { return this; }
 
-  void TraverseChildren(std::function<void(const AidlTraversable&)> traverse) const override {
+  void TraverseChildren(std::function<void(const AidlNode&)> traverse) const override {
+    AidlDefinedType::TraverseChildren(traverse);
     for (const auto& c : GetEnumerators()) {
       traverse(*c);
     }
@@ -1139,12 +1126,6 @@
 
   void Dump(CodeWriter* writer) const override;
   const AidlUnionDecl* AsUnionDeclaration() const override { return this; }
-
-  void TraverseChildren(std::function<void(const AidlTraversable&)> traverse) const override {
-    for (const auto c : GetMembers()) {
-      traverse(*c);
-    }
-  }
   void DispatchVisit(AidlVisitor& v) const override { v.Visit(*this); }
 };
 
@@ -1172,12 +1153,6 @@
                                   Options::Language lang) const override;
 
   std::string GetDescriptor() const;
-
-  void TraverseChildren(std::function<void(const AidlTraversable&)> traverse) const override {
-    for (const auto c : GetMembers()) {
-      traverse(*c);
-    }
-  }
   void DispatchVisit(AidlVisitor& v) const override { v.Visit(*this); }
 };
 
@@ -1193,13 +1168,15 @@
   AidlImport& operator=(AidlImport&&) = delete;
 
   const std::string& GetNeededClass() const { return needed_class_; }
+  void TraverseChildren(std::function<void(const AidlNode&)>) const {}
+  void DispatchVisit(AidlVisitor& v) const { v.Visit(*this); }
 
  private:
   std::string needed_class_;
 };
 
 // AidlDocument models an AIDL file
-class AidlDocument : public AidlNode, public AidlTraversable {
+class AidlDocument : public AidlNode {
  public:
   AidlDocument(const AidlLocation& location, std::vector<std::unique_ptr<AidlImport>>& imports,
                std::vector<std::unique_ptr<AidlDefinedType>>&& defined_types)
@@ -1220,7 +1197,10 @@
     return defined_types_;
   }
 
-  void TraverseChildren(std::function<void(const AidlTraversable&)> traverse) const override {
+  void TraverseChildren(std::function<void(const AidlNode&)> traverse) const override {
+    for (const auto& i : Imports()) {
+      traverse(*i);
+    }
     for (const auto& t : DefinedTypes()) {
       traverse(*t);
     }
@@ -1240,3 +1220,16 @@
   }
   return it->second->EvaluatedValue<T>();
 }
+
+// Utility to make a visitor to visit AST tree in top-down order
+// Given:       foo
+//              / \
+//            bar baz
+// VisitTopDown(v, foo) makes v visit foo -> bar -> baz.
+inline void VisitTopDown(AidlVisitor& v, const AidlNode& node) {
+  std::function<void(const AidlNode&)> top_down = [&](const AidlNode& n) {
+    n.DispatchVisit(v);
+    n.TraverseChildren(top_down);
+  };
+  top_down(node);
+}
\ No newline at end of file
diff --git a/diagnostics.cpp b/diagnostics.cpp
index ca738b3..1e00eef 100644
--- a/diagnostics.cpp
+++ b/diagnostics.cpp
@@ -109,14 +109,14 @@
     };
     Hook suppress{std::bind(&DiagnosticsContext::Suppress, &diag, _1)};
     Hook restore{std::bind(&DiagnosticsContext::Restore, &diag, _1)};
-    std::function<void(const AidlTraversable&)> topDown =
-        [&topDown, &suppress, &restore, visitor](const AidlTraversable& a) {
+    std::function<void(const AidlNode&)> top_down = [&top_down, &suppress, &restore,
+                                                     visitor](const AidlNode& a) {
       a.DispatchVisit(suppress);
       a.DispatchVisit(*visitor);
-      a.TraverseChildren(topDown);
+      a.TraverseChildren(top_down);
       a.DispatchVisit(restore);
     };
-    topDown(doc);
+    top_down(doc);
   }
  protected:
   DiagnosticsContext& diag;
diff --git a/parser.cpp b/parser.cpp
index 897efa0..083e75c 100644
--- a/parser.cpp
+++ b/parser.cpp
@@ -61,14 +61,11 @@
   }
 }
 
-class ConstantReferenceResolver : public AidlConstantValue::Visitor {
+class ConstantReferenceResolver : public AidlVisitor {
  public:
   ConstantReferenceResolver(const AidlDefinedType* scope, const AidlTypenames& typenames,
                             TypeResolver& resolver, bool* success)
       : scope_(scope), typenames_(typenames), resolver_(resolver), success_(success) {}
-  void Visit(const AidlConstantValue&) override {}
-  void Visit(const AidlUnaryConstExpression&) override {}
-  void Visit(const AidlBinaryConstExpression&) override {}
   void Visit(const AidlConstantReference& v) override {
     if (IsCircularReference(&v)) {
       *success_ = false;
@@ -91,7 +88,7 @@
 
     // resolve recursive references
     Push(&v);
-    const_cast<AidlConstantValue*>(resolved)->Accept(*this);
+    VisitTopDown(*this, *resolved);
     Pop();
   }
 
@@ -149,22 +146,7 @@
   // resolve "field references" as well.
   for (const auto& type : document_->DefinedTypes()) {
     ConstantReferenceResolver ref_resolver{type.get(), typenames_, type_resolver, &success};
-    if (auto enum_type = type->AsEnumDeclaration(); enum_type) {
-      for (const auto& enumerator : enum_type->GetEnumerators()) {
-        if (auto value = enumerator->GetValue(); value) {
-          value->Accept(ref_resolver);
-        }
-      }
-    } else {
-      for (const auto& constant : type->GetConstantDeclarations()) {
-        const_cast<AidlConstantValue&>(constant->GetValue()).Accept(ref_resolver);
-      }
-      for (const auto& field : type->GetFields()) {
-        if (field->IsDefaultUserSpecified()) {
-          const_cast<AidlConstantValue*>(field->GetDefaultValue())->Accept(ref_resolver);
-        }
-      }
-    }
+    VisitTopDown(ref_resolver, *type);
   }
 
   return success;