Make AidlNode visitable
Pull up DispatchVisit/TraverseChildren to AidlNode.
AidlConstantValue::Visitor is replaced with AidlVisitor as well.
Bug: none
Test: aidl_unittests
Change-Id: I51e0e3fdb06f7deeac9a28d7797e5fe738c28bb2
diff --git a/aidl_language.cpp b/aidl_language.cpp
index 9f8e638..5389a4f 100644
--- a/aidl_language.cpp
+++ b/aidl_language.cpp
@@ -188,14 +188,16 @@
std::map<std::string, std::shared_ptr<AidlConstantValue>>&& parameters)
: AidlNode(location), schema_(schema), parameters_(std::move(parameters)) {}
-struct ConstReferenceFinder : AidlConstantValue::Visitor {
+struct ConstReferenceFinder : AidlVisitor {
const AidlConstantReference* found;
- void Visit(const AidlConstantValue&) override {}
- void Visit(const AidlUnaryConstExpression&) override {}
- void Visit(const AidlBinaryConstExpression&) override {}
void Visit(const AidlConstantReference& ref) override {
if (!found) found = &ref;
}
+ static const AidlConstantReference* Find(const AidlConstantValue& c) {
+ ConstReferenceFinder finder;
+ VisitTopDown(finder, c);
+ return finder.found;
+ }
};
bool AidlAnnotation::CheckValid() const {
@@ -216,11 +218,10 @@
return false;
}
- ConstReferenceFinder finder;
- param->Accept(finder);
- if (finder.found) {
- AIDL_ERROR(finder.found) << "Value must be a constant expression but contains reference to "
- << finder.found->GetFieldName() << ".";
+ const auto& found = ConstReferenceFinder::Find(*param);
+ if (found) {
+ AIDL_ERROR(found) << "Value must be a constant expression but contains reference to "
+ << found->GetFieldName() << ".";
return false;
}
@@ -288,6 +289,13 @@
}
}
+void AidlAnnotation::TraverseChildren(std::function<void(const AidlNode&)> traverse) const {
+ for (const auto& [name, value] : parameters_) {
+ (void)name;
+ traverse(*value);
+ }
+}
+
static const AidlAnnotation* GetAnnotation(const vector<AidlAnnotation>& annotations,
AidlAnnotation::Type type) {
for (const auto& a : annotations) {
@@ -705,6 +713,14 @@
}
}
+void AidlVariableDeclaration::TraverseChildren(
+ std::function<void(const AidlNode&)> traverse) const {
+ traverse(GetType());
+ if (IsDefaultUserSpecified()) {
+ traverse(*GetDefaultValue());
+ }
+}
+
AidlArgument::AidlArgument(const AidlLocation& location, AidlArgument::Direction direction,
AidlTypeSpecifier* type, const std::string& name)
: AidlVariableDeclaration(location, type, name),
diff --git a/aidl_language.h b/aidl_language.h
index 0ea5ce4..15428c4 100644
--- a/aidl_language.h
+++ b/aidl_language.h
@@ -76,6 +76,7 @@
bool ParseFloating(std::string_view sv, float* parsed);
class AidlDocument;
+class AidlImport;
class AidlInterface;
class AidlParcelable;
class AidlStructuredParcelable;
@@ -86,6 +87,11 @@
class AidlEnumerator;
class AidlMethod;
class AidlArgument;
+class AidlConstantValue;
+class AidlConstantReference;
+class AidlUnaryConstExpression;
+class AidlBinaryConstExpression;
+class AidlAnnotation;
// Interface for visitors that can traverse AidlTraversable nodes.
class AidlVisitor {
@@ -103,6 +109,12 @@
virtual void Visit(const AidlConstantDeclaration&) {}
virtual void Visit(const AidlArgument&) {}
virtual void Visit(const AidlTypeSpecifier&) {}
+ virtual void Visit(const AidlConstantValue&) {}
+ virtual void Visit(const AidlConstantReference&) {}
+ virtual void Visit(const AidlUnaryConstExpression&) {}
+ virtual void Visit(const AidlBinaryConstExpression&) {}
+ virtual void Visit(const AidlAnnotation&) {}
+ virtual void Visit(const AidlImport&) {}
};
// Anything that is locatable in a .aidl file.
@@ -122,6 +134,8 @@
friend std::string android::aidl::java::dump_location(const AidlNode&);
const AidlLocation& GetLocation() const { return location_; }
+ virtual void TraverseChildren(std::function<void(const AidlNode&)> traverse) const = 0;
+ virtual void DispatchVisit(AidlVisitor&) const = 0;
private:
std::string PrintLine() const;
@@ -129,15 +143,6 @@
const AidlLocation location_;
};
-// Anything that is traversable by the AidlVisitor
-class AidlTraversable {
- public:
- virtual ~AidlTraversable() = default;
-
- virtual void TraverseChildren(std::function<void(const AidlTraversable&)> traverse) const = 0;
- virtual void DispatchVisit(AidlVisitor&) const = 0;
-};
-
// unique_ptr<AidlTypeSpecifier> for type arugment,
// std::string for type parameter(T, U, and so on).
template <typename T>
@@ -222,6 +227,8 @@
const ConstantValueDecorator& decorator) const;
const string& GetComments() const { return comments_; }
void SetComments(const string& comments) { comments_ = comments; }
+ void TraverseChildren(std::function<void(const AidlNode&)> traverse) const override;
+ void DispatchVisit(AidlVisitor& v) const override { v.Visit(*this); }
private:
struct Schema {
@@ -292,6 +299,11 @@
const vector<AidlAnnotation>& GetAnnotations() const { return annotations_; }
bool CheckValid(const AidlTypenames&) const;
+ void TraverseChildren(std::function<void(const AidlNode&)> traverse) const override {
+ for (const auto& annot : GetAnnotations()) {
+ traverse(annot);
+ }
+ }
protected:
virtual std::set<AidlAnnotation::Type> GetSupportedAnnotations() const = 0;
@@ -303,7 +315,6 @@
// AidlTypeSpecifier represents a reference to either a built-in type,
// a defined type, or a variant (e.g., array of generic) of a type.
class AidlTypeSpecifier final : public AidlAnnotatable,
- public AidlTraversable,
public AidlParameterizable<unique_ptr<AidlTypeSpecifier>> {
public:
AidlTypeSpecifier(const AidlLocation& location, const string& unresolved_name, bool is_array,
@@ -368,7 +379,8 @@
const AidlNode& AsAidlNode() const override { return *this; }
const AidlDefinedType* GetDefinedType() const;
- void TraverseChildren(std::function<void(const AidlTraversable&)> traverse) const override {
+ void TraverseChildren(std::function<void(const AidlNode&)> traverse) const override {
+ AidlAnnotatable::TraverseChildren(traverse);
if (IsGeneric()) {
for (const auto& tp : GetTypeParameters()) {
traverse(*tp);
@@ -392,12 +404,7 @@
// Returns the universal value unaltered.
std::string AidlConstantValueDecorator(const AidlTypeSpecifier& type, const std::string& raw_value);
-class AidlConstantValue;
-class AidlMethod;
-class AidlConstantDeclaration;
-class AidlVariableDeclaration;
-
-class AidlMember : public AidlNode, public AidlTraversable {
+class AidlMember : public AidlNode {
public:
AidlMember(const AidlLocation& location);
virtual ~AidlMember() = default;
@@ -423,9 +430,6 @@
return const_cast<AidlVariableDeclaration*>(
const_cast<const AidlMember*>(this)->AsVariableDeclaration());
}
-
- void TraverseChildren(std::function<void(const AidlTraversable&)> traverse) const = 0;
- void DispatchVisit(AidlVisitor& v) const = 0;
};
// TODO: This class is used for method arguments and also parcelable fields,
@@ -475,9 +479,7 @@
std::string ValueString(const ConstantValueDecorator& decorator) const;
- void TraverseChildren(std::function<void(const AidlTraversable&)> traverse) const override {
- traverse(GetType());
- }
+ void TraverseChildren(std::function<void(const AidlNode&)> traverse) const override;
void DispatchVisit(AidlVisitor& v) const override { v.Visit(*this); }
private:
@@ -514,7 +516,7 @@
// e.g) "in @utf8InCpp String[] names"
std::string ToString() const;
- void TraverseChildren(std::function<void(const AidlTraversable&)> traverse) const override {
+ void TraverseChildren(std::function<void(const AidlNode&)> traverse) const override {
traverse(GetType());
}
void DispatchVisit(AidlVisitor& v) const override { v.Visit(*this); }
@@ -547,14 +549,6 @@
ERROR,
};
- struct Visitor {
- virtual ~Visitor() {}
- virtual void Visit(const AidlConstantValue&) = 0;
- virtual void Visit(const AidlConstantReference&) = 0;
- virtual void Visit(const AidlUnaryConstExpression&) = 0;
- virtual void Visit(const AidlBinaryConstExpression&) = 0;
- };
-
// Returns the evaluated value. T> should match to the actual type.
template <typename T>
T EvaluatedValue() const {
@@ -620,14 +614,15 @@
// Raw value of type (currently valid in C++ and Java). Empty string on error.
string ValueString(const AidlTypeSpecifier& type, const ConstantValueDecorator& decorator) const;
- virtual void Accept(Visitor& visitor) {
- visitor.Visit(*this);
+
+ void TraverseChildren(std::function<void(const AidlNode&)> traverse) const {
if (type_ == Type::ARRAY) {
for (const auto& v : values_) {
- v.get()->Accept(visitor);
+ traverse(*v);
}
}
}
+ void DispatchVisit(AidlVisitor& visitor) const override { visitor.Visit(*this); }
private:
AidlConstantValue(const AidlLocation& location, Type parsed_type, int64_t parsed_value,
@@ -672,7 +667,10 @@
const std::string& GetComments() const { return comments_; }
bool CheckValid() const override;
- void Accept(Visitor& visitor) override { visitor.Visit(*this); }
+ void TraverseChildren(std::function<void(const AidlNode&)>) const override {
+ // resolved_ is not my child.
+ }
+ void DispatchVisit(AidlVisitor& v) const override { v.Visit(*this); }
const AidlConstantValue* Resolve(const AidlDefinedType* scope) const;
private:
@@ -691,10 +689,10 @@
static bool IsCompatibleType(Type type, const string& op);
bool CheckValid() const override;
- void Accept(Visitor& visitor) override {
- visitor.Visit(*this);
- unary_->Accept(visitor);
+ void TraverseChildren(std::function<void(const AidlNode&)> traverse) const override {
+ traverse(*unary_);
}
+ void DispatchVisit(AidlVisitor& v) const override { v.Visit(*this); }
private:
bool evaluate() const override;
@@ -715,11 +713,11 @@
static Type UsualArithmeticConversion(Type left, Type right);
// Returns the promoted integral type where INT32 is the smallest type
static Type IntegralPromotion(Type in);
- void Accept(Visitor& visitor) override {
- visitor.Visit(*this);
- left_val_->Accept(visitor);
- right_val_->Accept(visitor);
+ void TraverseChildren(std::function<void(const AidlNode&)> traverse) const override {
+ traverse(*left_val_);
+ traverse(*right_val_);
}
+ void DispatchVisit(AidlVisitor& v) const override { v.Visit(*this); }
private:
bool evaluate() const override;
@@ -770,8 +768,9 @@
const AidlConstantDeclaration* AsConstantDeclaration() const override { return this; }
- void TraverseChildren(std::function<void(const AidlTraversable&)> traverse) const override {
+ void TraverseChildren(std::function<void(const AidlNode&)> traverse) const override {
traverse(GetType());
+ traverse(GetValue());
}
void DispatchVisit(AidlVisitor& v) const override { v.Visit(*this); }
@@ -837,7 +836,7 @@
// e.g) "foo(int, String)"
std::string Signature() const;
- void TraverseChildren(std::function<void(const AidlTraversable&)> traverse) const override {
+ void TraverseChildren(std::function<void(const AidlNode&)> traverse) const override {
traverse(GetType());
for (const auto& a : GetArguments()) {
traverse(*a);
@@ -858,15 +857,9 @@
bool is_user_defined_ = true;
};
-class AidlDefinedType;
-class AidlInterface;
-class AidlParcelable;
-class AidlStructuredParcelable;
-class AidlUnionDecl;
-
// AidlDefinedType represents either an interface, parcelable, or enum that is
// defined in the source file.
-class AidlDefinedType : public AidlAnnotatable, public AidlTraversable {
+class AidlDefinedType : public AidlAnnotatable {
public:
AidlDefinedType(const AidlLocation& location, const std::string& name,
const std::string& comments, const std::string& package,
@@ -947,9 +940,12 @@
const std::vector<std::unique_ptr<AidlMethod>>& GetMethods() const { return methods_; }
void AddMethod(std::unique_ptr<AidlMethod> method) { methods_.push_back(std::move(method)); }
const std::vector<const AidlMember*>& GetMembers() const { return members_; }
-
- void TraverseChildren(std::function<void(const AidlTraversable&)>) const = 0;
- void DispatchVisit(AidlVisitor& v) const = 0;
+ void TraverseChildren(std::function<void(const AidlNode&)> traverse) const override {
+ AidlAnnotatable::TraverseChildren(traverse);
+ for (const auto c : GetMembers()) {
+ traverse(*c);
+ }
+ }
protected:
// utility for subclasses with getter names
@@ -995,11 +991,6 @@
void Dump(CodeWriter* writer) const override;
- void TraverseChildren(std::function<void(const AidlTraversable&)> traverse) const override {
- for (const auto c : GetMembers()) {
- traverse(*c);
- }
- }
void DispatchVisit(AidlVisitor& v) const override { v.Visit(*this); }
private:
@@ -1030,15 +1021,10 @@
bool LanguageSpecificCheckValid(const AidlTypenames& typenames,
Options::Language lang) const override;
- void TraverseChildren(std::function<void(const AidlTraversable&)> traverse) const override {
- for (const auto c : GetMembers()) {
- traverse(*c);
- }
- }
void DispatchVisit(AidlVisitor& v) const override { v.Visit(*this); }
};
-class AidlEnumerator : public AidlNode, public AidlTraversable {
+class AidlEnumerator : public AidlNode {
public:
AidlEnumerator(const AidlLocation& location, const std::string& name, AidlConstantValue* value,
const std::string& comments);
@@ -1061,8 +1047,8 @@
void SetValue(std::unique_ptr<AidlConstantValue> value) { value_ = std::move(value); }
bool IsValueUserSpecified() const { return value_user_specified_; }
- void TraverseChildren(std::function<void(const AidlTraversable&)>) const override {
- // no children to traverse
+ void TraverseChildren(std::function<void(const AidlNode&)> traverse) const override {
+ traverse(*value_);
}
void DispatchVisit(AidlVisitor& v) const override { v.Visit(*this); }
@@ -1102,7 +1088,8 @@
const AidlEnumDeclaration* AsEnumDeclaration() const override { return this; }
- void TraverseChildren(std::function<void(const AidlTraversable&)> traverse) const override {
+ void TraverseChildren(std::function<void(const AidlNode&)> traverse) const override {
+ AidlDefinedType::TraverseChildren(traverse);
for (const auto& c : GetEnumerators()) {
traverse(*c);
}
@@ -1139,12 +1126,6 @@
void Dump(CodeWriter* writer) const override;
const AidlUnionDecl* AsUnionDeclaration() const override { return this; }
-
- void TraverseChildren(std::function<void(const AidlTraversable&)> traverse) const override {
- for (const auto c : GetMembers()) {
- traverse(*c);
- }
- }
void DispatchVisit(AidlVisitor& v) const override { v.Visit(*this); }
};
@@ -1172,12 +1153,6 @@
Options::Language lang) const override;
std::string GetDescriptor() const;
-
- void TraverseChildren(std::function<void(const AidlTraversable&)> traverse) const override {
- for (const auto c : GetMembers()) {
- traverse(*c);
- }
- }
void DispatchVisit(AidlVisitor& v) const override { v.Visit(*this); }
};
@@ -1193,13 +1168,15 @@
AidlImport& operator=(AidlImport&&) = delete;
const std::string& GetNeededClass() const { return needed_class_; }
+ void TraverseChildren(std::function<void(const AidlNode&)>) const {}
+ void DispatchVisit(AidlVisitor& v) const { v.Visit(*this); }
private:
std::string needed_class_;
};
// AidlDocument models an AIDL file
-class AidlDocument : public AidlNode, public AidlTraversable {
+class AidlDocument : public AidlNode {
public:
AidlDocument(const AidlLocation& location, std::vector<std::unique_ptr<AidlImport>>& imports,
std::vector<std::unique_ptr<AidlDefinedType>>&& defined_types)
@@ -1220,7 +1197,10 @@
return defined_types_;
}
- void TraverseChildren(std::function<void(const AidlTraversable&)> traverse) const override {
+ void TraverseChildren(std::function<void(const AidlNode&)> traverse) const override {
+ for (const auto& i : Imports()) {
+ traverse(*i);
+ }
for (const auto& t : DefinedTypes()) {
traverse(*t);
}
@@ -1240,3 +1220,16 @@
}
return it->second->EvaluatedValue<T>();
}
+
+// Utility to make a visitor to visit AST tree in top-down order
+// Given: foo
+// / \
+// bar baz
+// VisitTopDown(v, foo) makes v visit foo -> bar -> baz.
+inline void VisitTopDown(AidlVisitor& v, const AidlNode& node) {
+ std::function<void(const AidlNode&)> top_down = [&](const AidlNode& n) {
+ n.DispatchVisit(v);
+ n.TraverseChildren(top_down);
+ };
+ top_down(node);
+}
\ No newline at end of file
diff --git a/diagnostics.cpp b/diagnostics.cpp
index ca738b3..1e00eef 100644
--- a/diagnostics.cpp
+++ b/diagnostics.cpp
@@ -109,14 +109,14 @@
};
Hook suppress{std::bind(&DiagnosticsContext::Suppress, &diag, _1)};
Hook restore{std::bind(&DiagnosticsContext::Restore, &diag, _1)};
- std::function<void(const AidlTraversable&)> topDown =
- [&topDown, &suppress, &restore, visitor](const AidlTraversable& a) {
+ std::function<void(const AidlNode&)> top_down = [&top_down, &suppress, &restore,
+ visitor](const AidlNode& a) {
a.DispatchVisit(suppress);
a.DispatchVisit(*visitor);
- a.TraverseChildren(topDown);
+ a.TraverseChildren(top_down);
a.DispatchVisit(restore);
};
- topDown(doc);
+ top_down(doc);
}
protected:
DiagnosticsContext& diag;
diff --git a/parser.cpp b/parser.cpp
index 897efa0..083e75c 100644
--- a/parser.cpp
+++ b/parser.cpp
@@ -61,14 +61,11 @@
}
}
-class ConstantReferenceResolver : public AidlConstantValue::Visitor {
+class ConstantReferenceResolver : public AidlVisitor {
public:
ConstantReferenceResolver(const AidlDefinedType* scope, const AidlTypenames& typenames,
TypeResolver& resolver, bool* success)
: scope_(scope), typenames_(typenames), resolver_(resolver), success_(success) {}
- void Visit(const AidlConstantValue&) override {}
- void Visit(const AidlUnaryConstExpression&) override {}
- void Visit(const AidlBinaryConstExpression&) override {}
void Visit(const AidlConstantReference& v) override {
if (IsCircularReference(&v)) {
*success_ = false;
@@ -91,7 +88,7 @@
// resolve recursive references
Push(&v);
- const_cast<AidlConstantValue*>(resolved)->Accept(*this);
+ VisitTopDown(*this, *resolved);
Pop();
}
@@ -149,22 +146,7 @@
// resolve "field references" as well.
for (const auto& type : document_->DefinedTypes()) {
ConstantReferenceResolver ref_resolver{type.get(), typenames_, type_resolver, &success};
- if (auto enum_type = type->AsEnumDeclaration(); enum_type) {
- for (const auto& enumerator : enum_type->GetEnumerators()) {
- if (auto value = enumerator->GetValue(); value) {
- value->Accept(ref_resolver);
- }
- }
- } else {
- for (const auto& constant : type->GetConstantDeclarations()) {
- const_cast<AidlConstantValue&>(constant->GetValue()).Accept(ref_resolver);
- }
- for (const auto& field : type->GetFields()) {
- if (field->IsDefaultUserSpecified()) {
- const_cast<AidlConstantValue*>(field->GetDefaultValue())->Accept(ref_resolver);
- }
- }
- }
+ VisitTopDown(ref_resolver, *type);
}
return success;