Make AidlNode visitable

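Pull DispatchVisit/TraverseChildren up into AidlNode and remove the
separate AidlTraversable interface. AidlConstantValue::Visitor is
replaced with AidlVisitor as well, so constant values, annotations and
imports become visitable too.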

Bug: none
Test: aidl_unittests
Change-Id: I51e0e3fdb06f7deeac9a28d7797e5fe738c28bb2
diff --git a/aidl_language.h b/aidl_language.h
index 0ea5ce4..15428c4 100644
--- a/aidl_language.h
+++ b/aidl_language.h
@@ -76,6 +76,7 @@
 bool ParseFloating(std::string_view sv, float* parsed);
 
 class AidlDocument;
+class AidlImport;
 class AidlInterface;
 class AidlParcelable;
 class AidlStructuredParcelable;
@@ -86,6 +87,11 @@
 class AidlEnumerator;
 class AidlMethod;
 class AidlArgument;
+class AidlConstantValue;
+class AidlConstantReference;
+class AidlUnaryConstExpression;
+class AidlBinaryConstExpression;
+class AidlAnnotation;
 
-// Interface for visitors that can traverse AidlTraversable nodes.
+// Interface for visitors of AidlNode trees; see AidlNode::DispatchVisit().
 class AidlVisitor {
@@ -103,6 +109,12 @@
   virtual void Visit(const AidlConstantDeclaration&) {}
   virtual void Visit(const AidlArgument&) {}
   virtual void Visit(const AidlTypeSpecifier&) {}
+  virtual void Visit(const AidlConstantValue&) {}
+  virtual void Visit(const AidlConstantReference&) {}
+  virtual void Visit(const AidlUnaryConstExpression&) {}
+  virtual void Visit(const AidlBinaryConstExpression&) {}
+  virtual void Visit(const AidlAnnotation&) {}
+  virtual void Visit(const AidlImport&) {}
 };
 
-// Anything that is locatable in a .aidl file.
+// Anything that is locatable in a .aidl file and visitable by AidlVisitor.
@@ -122,6 +134,8 @@
   friend std::string android::aidl::java::dump_location(const AidlNode&);
 
   const AidlLocation& GetLocation() const { return location_; }
+  virtual void TraverseChildren(std::function<void(const AidlNode&)> traverse) const = 0;
+  virtual void DispatchVisit(AidlVisitor&) const = 0;  // Calls the matching AidlVisitor::Visit().
 
  private:
   std::string PrintLine() const;
@@ -129,15 +143,6 @@
   const AidlLocation location_;
 };
 
-// Anything that is traversable by the AidlVisitor
-class AidlTraversable {
- public:
-  virtual ~AidlTraversable() = default;
-
-  virtual void TraverseChildren(std::function<void(const AidlTraversable&)> traverse) const = 0;
-  virtual void DispatchVisit(AidlVisitor&) const = 0;
-};
-
 // unique_ptr<AidlTypeSpecifier> for type arugment,
 // std::string for type parameter(T, U, and so on).
 template <typename T>
@@ -222,6 +227,8 @@
       const ConstantValueDecorator& decorator) const;
   const string& GetComments() const { return comments_; }
   void SetComments(const string& comments) { comments_ = comments; }
+  void TraverseChildren(std::function<void(const AidlNode&)> traverse) const override;
+  void DispatchVisit(AidlVisitor& v) const override { v.Visit(*this); }
 
  private:
   struct Schema {
@@ -292,6 +299,11 @@
 
   const vector<AidlAnnotation>& GetAnnotations() const { return annotations_; }
   bool CheckValid(const AidlTypenames&) const;
+  void TraverseChildren(std::function<void(const AidlNode&)> traverse) const override {
+    for (const auto& annot : GetAnnotations()) {  // annotations are this class's children
+      traverse(annot);
+    }
+  }
 
  protected:
   virtual std::set<AidlAnnotation::Type> GetSupportedAnnotations() const = 0;
@@ -303,7 +315,6 @@
 // AidlTypeSpecifier represents a reference to either a built-in type,
 // a defined type, or a variant (e.g., array of generic) of a type.
 class AidlTypeSpecifier final : public AidlAnnotatable,
-                                public AidlTraversable,
                                 public AidlParameterizable<unique_ptr<AidlTypeSpecifier>> {
  public:
   AidlTypeSpecifier(const AidlLocation& location, const string& unresolved_name, bool is_array,
@@ -368,7 +379,8 @@
   const AidlNode& AsAidlNode() const override { return *this; }
 
   const AidlDefinedType* GetDefinedType() const;
-  void TraverseChildren(std::function<void(const AidlTraversable&)> traverse) const override {
+  void TraverseChildren(std::function<void(const AidlNode&)> traverse) const override {
+    AidlAnnotatable::TraverseChildren(traverse);  // annotations first, then type parameters
     if (IsGeneric()) {
       for (const auto& tp : GetTypeParameters()) {
         traverse(*tp);
@@ -392,12 +404,7 @@
 // Returns the universal value unaltered.
 std::string AidlConstantValueDecorator(const AidlTypeSpecifier& type, const std::string& raw_value);
 
-class AidlConstantValue;
-class AidlMethod;
-class AidlConstantDeclaration;
-class AidlVariableDeclaration;
-
-class AidlMember : public AidlNode, public AidlTraversable {
+class AidlMember : public AidlNode {
  public:
   AidlMember(const AidlLocation& location);
   virtual ~AidlMember() = default;
@@ -423,9 +430,6 @@
     return const_cast<AidlVariableDeclaration*>(
         const_cast<const AidlMember*>(this)->AsVariableDeclaration());
   }
-
-  void TraverseChildren(std::function<void(const AidlTraversable&)> traverse) const = 0;
-  void DispatchVisit(AidlVisitor& v) const = 0;
 };
 
 // TODO: This class is used for method arguments and also parcelable fields,
@@ -475,9 +479,7 @@
 
   std::string ValueString(const ConstantValueDecorator& decorator) const;
 
-  void TraverseChildren(std::function<void(const AidlTraversable&)> traverse) const override {
-    traverse(GetType());
-  }
+  void TraverseChildren(std::function<void(const AidlNode&)> traverse) const override;
   void DispatchVisit(AidlVisitor& v) const override { v.Visit(*this); }
 
  private:
@@ -514,7 +516,7 @@
   // e.g) "in @utf8InCpp String[] names"
   std::string ToString() const;
 
-  void TraverseChildren(std::function<void(const AidlTraversable&)> traverse) const override {
+  void TraverseChildren(std::function<void(const AidlNode&)> traverse) const override {
     traverse(GetType());
   }
   void DispatchVisit(AidlVisitor& v) const override { v.Visit(*this); }
@@ -547,14 +549,6 @@
     ERROR,
   };
 
-  struct Visitor {
-    virtual ~Visitor() {}
-    virtual void Visit(const AidlConstantValue&) = 0;
-    virtual void Visit(const AidlConstantReference&) = 0;
-    virtual void Visit(const AidlUnaryConstExpression&) = 0;
-    virtual void Visit(const AidlBinaryConstExpression&) = 0;
-  };
-
   // Returns the evaluated value. T> should match to the actual type.
   template <typename T>
   T EvaluatedValue() const {
@@ -620,14 +614,15 @@
 
   // Raw value of type (currently valid in C++ and Java). Empty string on error.
   string ValueString(const AidlTypeSpecifier& type, const ConstantValueDecorator& decorator) const;
-  virtual void Accept(Visitor& visitor) {
-    visitor.Visit(*this);
+
+  void TraverseChildren(std::function<void(const AidlNode&)> traverse) const override {
     if (type_ == Type::ARRAY) {
       for (const auto& v : values_) {
-        v.get()->Accept(visitor);
+        traverse(*v);
       }
     }
   }
+  void DispatchVisit(AidlVisitor& visitor) const override { visitor.Visit(*this); }
 
  private:
   AidlConstantValue(const AidlLocation& location, Type parsed_type, int64_t parsed_value,
@@ -672,7 +667,10 @@
   const std::string& GetComments() const { return comments_; }
 
   bool CheckValid() const override;
-  void Accept(Visitor& visitor) override { visitor.Visit(*this); }
+  void TraverseChildren(std::function<void(const AidlNode&)>) const override {
+    // resolved_ is not a child of this node, so there is nothing to traverse.
+  }
+  void DispatchVisit(AidlVisitor& v) const override { v.Visit(*this); }
   const AidlConstantValue* Resolve(const AidlDefinedType* scope) const;
 
  private:
@@ -691,10 +689,10 @@
 
   static bool IsCompatibleType(Type type, const string& op);
   bool CheckValid() const override;
-  void Accept(Visitor& visitor) override {
-    visitor.Visit(*this);
-    unary_->Accept(visitor);
+  void TraverseChildren(std::function<void(const AidlNode&)> traverse) const override {
+    traverse(*unary_);
   }
+  void DispatchVisit(AidlVisitor& v) const override { v.Visit(*this); }
 
  private:
   bool evaluate() const override;
@@ -715,11 +713,11 @@
   static Type UsualArithmeticConversion(Type left, Type right);
   // Returns the promoted integral type where INT32 is the smallest type
   static Type IntegralPromotion(Type in);
-  void Accept(Visitor& visitor) override {
-    visitor.Visit(*this);
-    left_val_->Accept(visitor);
-    right_val_->Accept(visitor);
+  void TraverseChildren(std::function<void(const AidlNode&)> traverse) const override {
+    traverse(*left_val_);
+    traverse(*right_val_);
   }
+  void DispatchVisit(AidlVisitor& v) const override { v.Visit(*this); }
 
  private:
   bool evaluate() const override;
@@ -770,8 +768,9 @@
 
   const AidlConstantDeclaration* AsConstantDeclaration() const override { return this; }
 
-  void TraverseChildren(std::function<void(const AidlTraversable&)> traverse) const override {
+  void TraverseChildren(std::function<void(const AidlNode&)> traverse) const override {
     traverse(GetType());
+    traverse(GetValue());
   }
   void DispatchVisit(AidlVisitor& v) const override { v.Visit(*this); }
 
@@ -837,7 +836,7 @@
   // e.g) "foo(int, String)"
   std::string Signature() const;
 
-  void TraverseChildren(std::function<void(const AidlTraversable&)> traverse) const override {
+  void TraverseChildren(std::function<void(const AidlNode&)> traverse) const override {
     traverse(GetType());
     for (const auto& a : GetArguments()) {
       traverse(*a);
@@ -858,15 +857,9 @@
   bool is_user_defined_ = true;
 };
 
-class AidlDefinedType;
-class AidlInterface;
-class AidlParcelable;
-class AidlStructuredParcelable;
-class AidlUnionDecl;
-
 // AidlDefinedType represents either an interface, parcelable, or enum that is
 // defined in the source file.
-class AidlDefinedType : public AidlAnnotatable, public AidlTraversable {
+class AidlDefinedType : public AidlAnnotatable {
  public:
   AidlDefinedType(const AidlLocation& location, const std::string& name,
                   const std::string& comments, const std::string& package,
@@ -947,9 +940,12 @@
   const std::vector<std::unique_ptr<AidlMethod>>& GetMethods() const { return methods_; }
   void AddMethod(std::unique_ptr<AidlMethod> method) { methods_.push_back(std::move(method)); }
   const std::vector<const AidlMember*>& GetMembers() const { return members_; }
-
-  void TraverseChildren(std::function<void(const AidlTraversable&)>) const = 0;
-  void DispatchVisit(AidlVisitor& v) const = 0;
+  void TraverseChildren(std::function<void(const AidlNode&)> traverse) const override {
+    AidlAnnotatable::TraverseChildren(traverse);  // annotations first, then members
+    for (const auto c : GetMembers()) {
+      traverse(*c);
+    }
+  }
 
  protected:
   // utility for subclasses with getter names
@@ -995,11 +991,6 @@
 
   void Dump(CodeWriter* writer) const override;
 
-  void TraverseChildren(std::function<void(const AidlTraversable&)> traverse) const override {
-    for (const auto c : GetMembers()) {
-      traverse(*c);
-    }
-  }
   void DispatchVisit(AidlVisitor& v) const override { v.Visit(*this); }
 
  private:
@@ -1030,15 +1021,10 @@
   bool LanguageSpecificCheckValid(const AidlTypenames& typenames,
                                   Options::Language lang) const override;
 
-  void TraverseChildren(std::function<void(const AidlTraversable&)> traverse) const override {
-    for (const auto c : GetMembers()) {
-      traverse(*c);
-    }
-  }
   void DispatchVisit(AidlVisitor& v) const override { v.Visit(*this); }
 };
 
-class AidlEnumerator : public AidlNode, public AidlTraversable {
+class AidlEnumerator : public AidlNode {
  public:
   AidlEnumerator(const AidlLocation& location, const std::string& name, AidlConstantValue* value,
                  const std::string& comments);
@@ -1061,8 +1047,8 @@
   void SetValue(std::unique_ptr<AidlConstantValue> value) { value_ = std::move(value); }
   bool IsValueUserSpecified() const { return value_user_specified_; }
 
-  void TraverseChildren(std::function<void(const AidlTraversable&)>) const override {
-    // no children to traverse
+  void TraverseChildren(std::function<void(const AidlNode&)> traverse) const override {
+    if (value_ != nullptr) traverse(*value_);  // value_ may be unset until SetValue() fills it
   }
   void DispatchVisit(AidlVisitor& v) const override { v.Visit(*this); }
 
@@ -1102,7 +1088,8 @@
 
   const AidlEnumDeclaration* AsEnumDeclaration() const override { return this; }
 
-  void TraverseChildren(std::function<void(const AidlTraversable&)> traverse) const override {
+  void TraverseChildren(std::function<void(const AidlNode&)> traverse) const override {
+    AidlDefinedType::TraverseChildren(traverse);  // annotations and members, then enumerators
     for (const auto& c : GetEnumerators()) {
       traverse(*c);
     }
@@ -1139,12 +1126,6 @@
 
   void Dump(CodeWriter* writer) const override;
   const AidlUnionDecl* AsUnionDeclaration() const override { return this; }
-
-  void TraverseChildren(std::function<void(const AidlTraversable&)> traverse) const override {
-    for (const auto c : GetMembers()) {
-      traverse(*c);
-    }
-  }
   void DispatchVisit(AidlVisitor& v) const override { v.Visit(*this); }
 };
 
@@ -1172,12 +1153,6 @@
                                   Options::Language lang) const override;
 
   std::string GetDescriptor() const;
-
-  void TraverseChildren(std::function<void(const AidlTraversable&)> traverse) const override {
-    for (const auto c : GetMembers()) {
-      traverse(*c);
-    }
-  }
   void DispatchVisit(AidlVisitor& v) const override { v.Visit(*this); }
 };
 
@@ -1193,13 +1168,15 @@
   AidlImport& operator=(AidlImport&&) = delete;
 
   const std::string& GetNeededClass() const { return needed_class_; }
+  void TraverseChildren(std::function<void(const AidlNode&)>) const override {}  // no children
+  void DispatchVisit(AidlVisitor& v) const override { v.Visit(*this); }
 
  private:
   std::string needed_class_;
 };
 
 // AidlDocument models an AIDL file
-class AidlDocument : public AidlNode, public AidlTraversable {
+class AidlDocument : public AidlNode {
  public:
   AidlDocument(const AidlLocation& location, std::vector<std::unique_ptr<AidlImport>>& imports,
                std::vector<std::unique_ptr<AidlDefinedType>>&& defined_types)
@@ -1220,7 +1197,10 @@
     return defined_types_;
   }
 
-  void TraverseChildren(std::function<void(const AidlTraversable&)> traverse) const override {
+  void TraverseChildren(std::function<void(const AidlNode&)> traverse) const override {
+    for (const auto& i : Imports()) {  // imports are visited before defined types
+      traverse(*i);
+    }
     for (const auto& t : DefinedTypes()) {
       traverse(*t);
     }
@@ -1240,3 +1220,18 @@
   }
   return it->second->EvaluatedValue<T>();
 }
+
+// Utility to make a visitor visit an AST in top-down (pre-order) order:
+// each node is visited before its children, and children are visited in
+// the order reported by the node's TraverseChildren(). Given:
+//   foo
+//   |-- bar
+//   `-- baz
+// VisitTopDown(v, foo) makes v visit foo -> bar -> baz.
+inline void VisitTopDown(AidlVisitor& v, const AidlNode& node) {
+  std::function<void(const AidlNode&)> top_down = [&](const AidlNode& n) {
+    n.DispatchVisit(v);
+    n.TraverseChildren(top_down);
+  };
+  top_down(node);
+}