Merge "Constant Expressions have same tree structure as Type"
diff --git a/AST.cpp b/AST.cpp
index 6602ba4..ba3e250 100644
--- a/AST.cpp
+++ b/AST.cpp
@@ -93,14 +93,31 @@
     return OK;
 }
 
+status_t AST::constantExpressionRecursivePass(
+    const std::function<status_t(ConstantExpression*)>& func) {
+    std::unordered_set<const Type*> visitedTypes;  // handed to Type::recursivePass below
+    std::unordered_set<const ConstantExpression*> visitedCE;  // shared by every type in the pass
+    return mRootScope.recursivePass(
+        [&](Type* type) -> status_t {
+            for (auto* ce : type->getConstantExpressions()) {
+                status_t err = ce->recursivePass(func, &visitedCE);
+                if (err != OK) return err;  // first failure ends this type's callback
+            }
+            return OK;
+        },
+        &visitedTypes);
+}
+
 status_t AST::resolveInheritance() {
     std::unordered_set<const Type*> visited;
     return mRootScope.recursivePass(&Type::resolveInheritance, &visited);
 }
 
 status_t AST::evaluate() {
-    std::unordered_set<const Type*> visited;
-    return mRootScope.recursivePass(&Type::evaluate, &visited);
+    return constantExpressionRecursivePass([](ConstantExpression* ce) {
+        ce->evaluate();
+        return OK;
+    });
 }
 
 status_t AST::validate() const {
diff --git a/AST.h b/AST.h
index 9dd2717..951cd24 100644
--- a/AST.h
+++ b/AST.h
@@ -20,6 +20,7 @@
 
 #include <android-base/macros.h>
 #include <hidl-util/FQName.h>
+#include <functional>
 #include <map>
 #include <set>
 #include <string>
@@ -31,6 +32,7 @@
 namespace android {
 
 struct Coordinator;
+struct ConstantExpression;
 struct EnumValue;
 struct Formatter;
 struct Interface;
@@ -71,6 +73,10 @@
     // being ready to generate output.
     status_t postParse();
 
+    // Recursive pass over the constant expression tree; calls func on each expression reached
+    status_t constantExpressionRecursivePass(
+        const std::function<status_t(ConstantExpression*)>& func);
+
     // Recursive tree pass that completes type declarations
     // that depend on super types
     status_t resolveInheritance();
diff --git a/Annotation.cpp b/Annotation.cpp
index 7e519e9..cade080 100644
--- a/Annotation.cpp
+++ b/Annotation.cpp
@@ -29,12 +29,8 @@
     return mName;
 }
 
-status_t AnnotationParam::evaluate() {
-    return OK;
-}
-
-status_t AnnotationParam::validate() const {
-    return OK;
+std::vector<ConstantExpression*> AnnotationParam::getConstantExpressions() const {
+    return {};
 }
 
 std::string AnnotationParam::getSingleString() const {
@@ -99,11 +95,8 @@
     return convertToString(mValues->at(0));
 }
 
-status_t ConstantExpressionAnnotationParam::evaluate() {
-    for (auto* value : *mValues) {
-        value->evaluate();
-    }
-    return AnnotationParam::evaluate();
+std::vector<ConstantExpression*> ConstantExpressionAnnotationParam::getConstantExpressions() const {
+    return *mValues;
 }
 
 Annotation::Annotation(const char* name, AnnotationParamVector* params)
@@ -118,7 +111,7 @@
 }
 
 const AnnotationParam *Annotation::getParam(const std::string &name) const {
-    for (auto *i: *mParams) {
+    for (const auto* i : *mParams) {
         if (i->getName() == name) {
             return i;
         }
@@ -127,20 +120,13 @@
     return nullptr;
 }
 
-status_t Annotation::evaluate() {
-    for (auto* param : *mParams) {
-        status_t err = param->evaluate();
-        if (err != OK) return err;
-    }
-    return OK;
-}
-
-status_t Annotation::validate() const {
+std::vector<ConstantExpression*> Annotation::getConstantExpressions() const {
+    std::vector<ConstantExpression*> ret;
     for (const auto* param : *mParams) {
-        status_t err = param->validate();
-        if (err != OK) return err;
+        const auto& retParam = param->getConstantExpressions();
+        ret.insert(ret.end(), retParam.begin(), retParam.end());
     }
-    return OK;
+    return ret;
 }
 
 void Annotation::dump(Formatter &out) const {
diff --git a/Annotation.h b/Annotation.h
index e41d304..0688b05 100644
--- a/Annotation.h
+++ b/Annotation.h
@@ -21,6 +21,7 @@
 #include <android-base/macros.h>
 #include <map>
 #include <string>
+#include <vector>
 
 #include "ConstantExpression.h"
 
@@ -42,8 +43,7 @@
     /* Returns value interpretted as a boolean */
     bool getSingleBool() const;
 
-    virtual status_t evaluate();
-    virtual status_t validate() const;
+    virtual std::vector<ConstantExpression*> getConstantExpressions() const;
 
    protected:
     const std::string mName;
@@ -68,7 +68,7 @@
     std::vector<std::string> getValues() const override;
     std::string getSingleValue() const override;
 
-    status_t evaluate() override;
+    std::vector<ConstantExpression*> getConstantExpressions() const override;
 
    private:
     std::vector<ConstantExpression*>* const mValues;
@@ -83,8 +83,7 @@
     const AnnotationParamVector &params() const;
     const AnnotationParam *getParam(const std::string &name) const;
 
-    status_t evaluate();
-    status_t validate() const;
+    std::vector<ConstantExpression*> getConstantExpressions() const;
 
     void dump(Formatter &out) const;
 
diff --git a/ArrayType.cpp b/ArrayType.cpp
index 81d3313..2ef07b0 100644
--- a/ArrayType.cpp
+++ b/ArrayType.cpp
@@ -71,12 +71,8 @@
     return {mElementType};
 }
 
-status_t ArrayType::evaluate() {
-    for (auto* size : mSizes) {
-        size->evaluate();
-    }
-
-    return Type::evaluate();
+std::vector<ConstantExpression*> ArrayType::getConstantExpressions() const {
+    return mSizes;
 }
 
 status_t ArrayType::validate() const {
diff --git a/ArrayType.h b/ArrayType.h
index 95f2da3..97b83a0 100644
--- a/ArrayType.h
+++ b/ArrayType.h
@@ -46,7 +46,8 @@
 
     std::vector<Reference<Type>> getReferences() const override;
 
-    status_t evaluate() override;
+    std::vector<ConstantExpression*> getConstantExpressions() const override;
+
     status_t validate() const override;
 
     std::string getCppType(StorageMode mode,
diff --git a/ConstantExpression.cpp b/ConstantExpression.cpp
index be8688d..d4df274 100644
--- a/ConstantExpression.cpp
+++ b/ConstantExpression.cpp
@@ -57,6 +57,7 @@
 namespace android {
 
 static inline bool isSupported(ScalarType::Kind kind) {
+    // TODO(b/64358435) move isSupported to EnumValue
     return SK(BOOL) == kind || ScalarType(kind).isValidEnumStorageType();
 }
 
@@ -236,7 +237,7 @@
 
 void UnaryConstantExpression::evaluate() {
     if (isEvaluated()) return;
-    mUnary->evaluate();
+    CHECK(mUnary->isEvaluated());
     mIsEvaluated = true;
 
     mExpr = std::string("(") + mOp + mUnary->description() + ")";
@@ -251,8 +252,8 @@
 
 void BinaryConstantExpression::evaluate() {
     if (isEvaluated()) return;
-    mLval->evaluate();
-    mRval->evaluate();
+    CHECK(mLval->isEvaluated());
+    CHECK(mRval->isEvaluated());
     mIsEvaluated = true;
 
     mExpr = std::string("(") + mLval->description() + " " + mOp + " " + mRval->description() + ")";
@@ -310,15 +311,15 @@
 
 void TernaryConstantExpression::evaluate() {
     if (isEvaluated()) return;
-    mCond->evaluate();
-    mTrueVal->evaluate();
-    mFalseVal->evaluate();
+    CHECK(mCond->isEvaluated());
+    CHECK(mTrueVal->isEvaluated());
+    CHECK(mFalseVal->isEvaluated());
     mIsEvaluated = true;
 
     mExpr = std::string("(") + mCond->description() + "?" + mTrueVal->description() + ":" +
             mFalseVal->description() + ")";
 
-    // note: for ?:, unlike arithmetic ops, integral promotion is not necessary.
+    // note: for ?:, unlike arithmetic ops, integral promotion is not processed.
     mValueKind = usualArithmeticConversion(mTrueVal->mValueKind, mFalseVal->mValueKind);
 
 #define CASE_TERNARY(__type__)                                           \
@@ -331,10 +332,10 @@
 
 void ReferenceConstantExpression::evaluate() {
     if (isEvaluated()) return;
+    CHECK(mReference->constExpr() != nullptr);
 
     ConstantExpression* expr = mReference->constExpr();
-    CHECK(expr != nullptr);
-    expr->evaluate();
+    CHECK(expr->isEvaluated());
 
     mValueKind = expr->mValueKind;
     mValue = expr->mValue;
@@ -437,18 +438,51 @@
     return this->cast<size_t>();
 }
 
+status_t ConstantExpression::recursivePass(const std::function<status_t(ConstantExpression*)>& func,
+                                           std::unordered_set<const ConstantExpression*>* visited) {
+    if (visited->find(this) != visited->end()) return OK;  // already handled (or in progress)
+    visited->insert(this);  // marked before descending, so a revisit returns early
+
+    for (auto* nextCE : getConstantExpressions()) {
+        status_t err = nextCE->recursivePass(func, visited);
+        if (err != OK) return err;
+    }
+
+    // Unlike types, constant expressions need to be processed after their dependencies
+    status_t err = func(this);
+    if (err != OK) return err;
+
+    return OK;
+}
+
+std::vector<ConstantExpression*> LiteralConstantExpression::getConstantExpressions() const {
+    return {};  // no child expressions to visit
+}
+
 UnaryConstantExpression::UnaryConstantExpression(const std::string& op, ConstantExpression* value)
     : mUnary(value), mOp(op) {}
 
+std::vector<ConstantExpression*> UnaryConstantExpression::getConstantExpressions() const {
+    return {mUnary};  // the single operand
+}
+
 BinaryConstantExpression::BinaryConstantExpression(ConstantExpression* lval, const std::string& op,
                                                    ConstantExpression* rval)
     : mLval(lval), mRval(rval), mOp(op) {}
 
+std::vector<ConstantExpression*> BinaryConstantExpression::getConstantExpressions() const {
+    return {mLval, mRval};  // left operand, then right operand
+}
+
 TernaryConstantExpression::TernaryConstantExpression(ConstantExpression* cond,
                                                      ConstantExpression* trueVal,
                                                      ConstantExpression* falseVal)
     : mCond(cond), mTrueVal(trueVal), mFalseVal(falseVal) {}
 
+std::vector<ConstantExpression*> TernaryConstantExpression::getConstantExpressions() const {
+    return {mCond, mTrueVal, mFalseVal};  // condition, then both branches
+}
+
 ReferenceConstantExpression::ReferenceConstantExpression(const Reference<LocalIdentifier>& value,
                                                          const std::string& expr)
     : mReference(value) {
@@ -456,6 +490,11 @@
     mTrivialDescription = mExpr.empty();
 }
 
+std::vector<ConstantExpression*> ReferenceConstantExpression::getConstantExpressions() const {
+    CHECK(mReference->constExpr() != nullptr);  // the referenced expression must exist by now
+    return {mReference->constExpr()};  // the referenced expression is the only child
+}
+
 /*
 
 Evaluating expressions in HIDL language
diff --git a/ConstantExpression.h b/ConstantExpression.h
index ebb6c48..5f8df93 100644
--- a/ConstantExpression.h
+++ b/ConstantExpression.h
@@ -19,8 +19,11 @@
 #define CONSTANT_EXPRESSION_H_
 
 #include <android-base/macros.h>
+#include <functional>
 #include <memory>
 #include <string>
+#include <unordered_set>
+#include <vector>
 
 #include "Reference.h"
 #include "ScalarType.h"
@@ -45,13 +48,18 @@
 
     virtual ~ConstantExpression() {}
 
-    /*
-     * Runs recursive evaluation.
-     * Provides sort of lazy computation,
-     * mainly used for forward identifier reference.
-     */
+    // Performs a recursive pass
+    // Makes sure to visit each node only once
+    // Used to provide lookup and lazy evaluation
+    status_t recursivePass(const std::function<status_t(ConstantExpression*)>& func,
+                           std::unordered_set<const ConstantExpression*>* visited);
+
+    // Evaluates current constant expression
+    // Doesn't call recursive evaluation, so must be called after dependencies
     virtual void evaluate() = 0;
 
+    virtual std::vector<ConstantExpression*> getConstantExpressions() const = 0;
+
     /* Returns true iff the value has already been evaluated. */
     bool isEvaluated() const;
     /* Evaluated result in a string form. */
@@ -78,7 +86,6 @@
    private:
     /* If the result value has been evaluated. */
     bool mIsEvaluated = false;
-
     /* The formatted expression. */
     std::string mExpr;
     /* The kind of the result value. */
@@ -114,11 +121,13 @@
     LiteralConstantExpression(ScalarType::Kind kind, uint64_t value);
     LiteralConstantExpression(const std::string& value);
     void evaluate() override;
+    std::vector<ConstantExpression*> getConstantExpressions() const override;
 };
 
 struct UnaryConstantExpression : public ConstantExpression {
     UnaryConstantExpression(const std::string& mOp, ConstantExpression* value);
     void evaluate() override;
+    std::vector<ConstantExpression*> getConstantExpressions() const override;
 
    private:
     ConstantExpression* const mUnary;
@@ -129,6 +138,7 @@
     BinaryConstantExpression(ConstantExpression* lval, const std::string& op,
                              ConstantExpression* rval);
     void evaluate() override;
+    std::vector<ConstantExpression*> getConstantExpressions() const override;
 
    private:
     ConstantExpression* const mLval;
@@ -140,6 +150,7 @@
     TernaryConstantExpression(ConstantExpression* cond, ConstantExpression* trueVal,
                               ConstantExpression* falseVal);
     void evaluate() override;
+    std::vector<ConstantExpression*> getConstantExpressions() const override;
 
    private:
     ConstantExpression* const mCond;
@@ -150,6 +161,7 @@
 struct ReferenceConstantExpression : public ConstantExpression {
     ReferenceConstantExpression(const Reference<LocalIdentifier>& value, const std::string& expr);
     void evaluate() override;
+    std::vector<ConstantExpression*> getConstantExpressions() const override;
 
    private:
     Reference<LocalIdentifier> mReference;
diff --git a/EnumType.cpp b/EnumType.cpp
index 7b0f805..bc5f3f3 100644
--- a/EnumType.cpp
+++ b/EnumType.cpp
@@ -72,22 +72,17 @@
     return {mStorageType};
 }
 
-status_t EnumType::evaluate() {
-    for (auto* value : mValues) {
-        status_t err = value->evaluate();
-        if (err != OK) return err;
+std::vector<ConstantExpression*> EnumType::getConstantExpressions() const {
+    std::vector<ConstantExpression*> ret;
+    for (const auto* value : mValues) {
+        ret.push_back(value->constExpr());
     }
-    return Scope::evaluate();
+    return ret;
 }
 
 status_t EnumType::validate() const {
     CHECK(getSubTypes().empty());
 
-    for (auto* value : mValues) {
-        status_t err = value->validate();
-        if (err != OK) return err;
-    }
-
     if (!isElidableType() || !mStorageType->isValidEnumStorageType()) {
         std::cerr << "ERROR: Invalid enum storage type (" << (mStorageType)->typeName()
                   << ") specified at " << mStorageType.location() << "\n";
@@ -792,16 +787,6 @@
     return true;
 }
 
-status_t EnumValue::evaluate() {
-    mValue->evaluate();
-    return OK;
-}
-
-status_t EnumValue::validate() const {
-    // TODO(b/64358435) move isSupported from ConstantExpression
-    return OK;
-}
-
 const Location& EnumValue::location() const {
     return mLocation;
 }
diff --git a/EnumType.h b/EnumType.h
index 0fb215f..05a85d5 100644
--- a/EnumType.h
+++ b/EnumType.h
@@ -61,9 +61,9 @@
     BitFieldType *getBitfieldType() const;
 
     std::vector<Reference<Type>> getReferences() const override;
+    std::vector<ConstantExpression*> getConstantExpressions() const override;
 
     status_t resolveInheritance() override;
-    status_t evaluate() override;
     status_t validate() const override;
     status_t validateUniqueNames() const;
 
@@ -145,9 +145,6 @@
     bool isAutoFill() const;
     bool isEnumValue() const override;
 
-    status_t evaluate() override;
-    status_t validate() const override;
-
     const Location& location() const;
 
    private:
diff --git a/Interface.cpp b/Interface.cpp
index 79824ae..10fb89d 100644
--- a/Interface.cpp
+++ b/Interface.cpp
@@ -475,6 +475,15 @@
     return ret;
 }
 
+std::vector<ConstantExpression*> Interface::getConstantExpressions() const {
+    std::vector<ConstantExpression*> ret;  // concatenation over all methods
+    for (const auto* method : methods()) {
+        const auto& retMethod = method->getConstantExpressions();
+        ret.insert(ret.end(), retMethod.begin(), retMethod.end());  // append in iteration order
+    }
+    return ret;
+}
+
 status_t Interface::resolveInheritance() {
     size_t serial = FIRST_CALL_TRANSACTION;
     for (const auto* ancestor : superTypeChain()) {
@@ -496,15 +505,6 @@
     return Scope::resolveInheritance();
 }
 
-status_t Interface::evaluate() {
-    for (auto* method : methods()) {
-        status_t err = method->evaluate();
-        if (err != OK) return err;
-    }
-
-    return Scope::evaluate();
-}
-
 status_t Interface::validate() const {
     CHECK(isIBase() == mSuperType.isEmptyReference());
 
@@ -513,11 +513,6 @@
         return UNKNOWN_ERROR;
     }
 
-    for (const auto* method : methods()) {
-        status_t err = method->validate();
-        if (err != OK) return err;
-    }
-
     status_t err = validateUniqueNames();
     if (err != OK) return err;
 
diff --git a/Interface.h b/Interface.h
index b664fdf..7b7eb33 100644
--- a/Interface.h
+++ b/Interface.h
@@ -88,8 +88,9 @@
 
     std::vector<Reference<Type>> getReferences() const override;
 
+    std::vector<ConstantExpression*> getConstantExpressions() const override;
+
     status_t resolveInheritance() override;
-    status_t evaluate() override;
     status_t validate() const override;
     status_t validateUniqueNames() const;
 
diff --git a/Method.cpp b/Method.cpp
index 07f429d..3c353f8 100644
--- a/Method.cpp
+++ b/Method.cpp
@@ -17,6 +17,7 @@
 #include "Method.h"
 
 #include "Annotation.h"
+#include "ConstantExpression.h"
 #include "ScalarType.h"
 #include "Type.h"
 
@@ -79,22 +80,13 @@
     return ret;
 }
 
-status_t Method::evaluate() {
-    for (auto* annotaion : *mAnnotations) {
-        status_t err = annotaion->evaluate();
-        if (err != OK) return err;
+std::vector<ConstantExpression*> Method::getConstantExpressions() const {
+    std::vector<ConstantExpression*> ret;
+    for (const auto* annotation : *mAnnotations) {
+        const auto& retAnnotation = annotation->getConstantExpressions();
+        ret.insert(ret.end(), retAnnotation.begin(), retAnnotation.end());
     }
-
-    return OK;
-}
-
-status_t Method::validate() const {
-    for (const auto* annotaion : *mAnnotations) {
-        status_t err = annotaion->validate();
-        if (err != OK) return err;
-    }
-
-    return OK;
+    return ret;
 }
 
 void Method::cppImpl(MethodImplType type, Formatter &out) const {
diff --git a/Method.h b/Method.h
index 941f7b7..8a4db1a 100644
--- a/Method.h
+++ b/Method.h
@@ -33,6 +33,7 @@
 namespace android {
 
 struct Annotation;
+struct ConstantExpression;
 struct Formatter;
 struct ScalarType;
 struct Type;
@@ -67,8 +68,7 @@
 
     std::vector<Reference<Type>> getReferences() const;
 
-    status_t evaluate();
-    status_t validate() const;
+    std::vector<ConstantExpression*> getConstantExpressions() const;
 
     // Make a copy with the same name, args, results, oneway, annotations.
     // Implementations, serial are not copied.
diff --git a/Scope.cpp b/Scope.cpp
index 31ce843..fb9d339 100644
--- a/Scope.cpp
+++ b/Scope.cpp
@@ -17,6 +17,7 @@
 #include "Scope.h"
 
 #include "Annotation.h"
+#include "ConstantExpression.h"
 #include "Interface.h"
 
 #include <android-base/logging.h>
@@ -122,6 +123,15 @@
     return ret;
 }
 
+std::vector<ConstantExpression*> Scope::getConstantExpressions() const {
+    std::vector<ConstantExpression*> ret;  // concatenation over mAnnotations
+    for (const auto* annotation : mAnnotations) {
+        const auto& retAnnotation = annotation->getConstantExpressions();
+        ret.insert(ret.end(), retAnnotation.begin(), retAnnotation.end());  // append in order
+    }
+    return ret;
+}
+
 status_t Scope::forEachType(const std::function<status_t(Type *)> &func) const {
     for (size_t i = 0; i < mTypes.size(); ++i) {
         status_t err = func(mTypes[i]);
@@ -134,24 +144,6 @@
     return OK;
 }
 
-status_t Scope::evaluate() {
-    for (auto* annotation : mAnnotations) {
-        status_t err = annotation->evaluate();
-        if (err != OK) return err;
-    }
-
-    return NamedType::evaluate();
-}
-
-status_t Scope::validate() const {
-    for (const auto* annotation : mAnnotations) {
-        status_t err = annotation->validate();
-        if (err != OK) return err;
-    }
-
-    return NamedType::validate();
-}
-
 status_t Scope::emitTypeDeclarations(Formatter &out) const {
     return forEachType([&](Type *type) {
         return type->emitTypeDeclarations(out);
@@ -249,14 +241,5 @@
     return nullptr;
 }
 
-status_t LocalIdentifier::evaluate() {
-    return OK;
-}
-
-status_t LocalIdentifier::validate() const {
-    CHECK(isEnumValue());
-    return OK;
-}
-
 }  // namespace android
 
diff --git a/Scope.h b/Scope.h
index 2c3e177..15bf01b 100644
--- a/Scope.h
+++ b/Scope.h
@@ -56,8 +56,7 @@
 
     std::vector<Type*> getDefinedTypes() const override;
 
-    virtual status_t evaluate() override;
-    virtual status_t validate() const override;
+    std::vector<ConstantExpression*> getConstantExpressions() const override;
 
     status_t emitTypeDeclarations(Formatter &out) const override;
     status_t emitGlobalTypeDeclarations(Formatter &out) const override;
@@ -102,9 +101,6 @@
     virtual ~LocalIdentifier();
     virtual bool isEnumValue() const;
 
-    virtual status_t evaluate();
-    virtual status_t validate() const;
-
     virtual ConstantExpression* constExpr() const;
 };
 
diff --git a/Type.cpp b/Type.cpp
index 65a7bc6..bf48072 100644
--- a/Type.cpp
+++ b/Type.cpp
@@ -16,6 +16,7 @@
 
 #include "Type.h"
 
+#include "ConstantExpression.h"
 #include "ScalarType.h"
 
 #include <hidl-util/Formatter.h>
@@ -99,6 +100,10 @@
     return {};
 }
 
+std::vector<ConstantExpression*> Type::getConstantExpressions() const {
+    return {};  // base default: none; subclasses with expressions override this
+}
+
 status_t Type::recursivePass(const std::function<status_t(Type*)>& func,
                              std::unordered_set<const Type*>* visited) {
     if (visited->find(this) != visited->end()) return OK;
@@ -146,10 +151,6 @@
     return OK;
 }
 
-status_t Type::evaluate() {
-    return OK;
-}
-
 status_t Type::validate() const {
     return OK;
 }
diff --git a/Type.h b/Type.h
index 8b385b7..3c2cbd3 100644
--- a/Type.h
+++ b/Type.h
@@ -29,9 +29,10 @@
 
 namespace android {
 
+struct ConstantExpression;
 struct Formatter;
-struct ScalarType;
 struct FQName;
+struct ScalarType;
 
 struct Type {
     Type();
@@ -60,6 +61,9 @@
     // All types referenced in this type.
     virtual std::vector<Reference<Type>> getReferences() const;
 
+    // All constant expressions referenced in this type.
+    virtual std::vector<ConstantExpression*> getConstantExpressions() const;
+
     // Proceeds recursive pass
     // Makes sure to visit each node only once.
     status_t recursivePass(const std::function<status_t(Type*)>& func,
@@ -71,9 +75,6 @@
     // that depend on super types
     virtual status_t resolveInheritance();
 
-    // Recursive tree pass that evaluates constant expressions
-    virtual status_t evaluate();
-
     // Recursive tree pass that validates all type-related
     // syntax restrictions
     virtual status_t validate() const;