Move lookup from parsing

Enables forward references!

Bug: 31827278

Test: mma
Test: tests in topic
Test: boots
Change-Id: I6c0599c3656db2a6c2246459330e4aed9fd2538b
diff --git a/AST.cpp b/AST.cpp
index 3966d52..2b40a4a 100644
--- a/AST.cpp
+++ b/AST.cpp
@@ -84,9 +84,13 @@
 
 status_t AST::postParse() {
     status_t err;
-    // validateDefinedTypesUniqueNames is the first call,
-    // as other errors could appear because user meant
-    // different type than we assumed.
+
+    // lookupTypes is the first pass.
+    err = lookupTypes();
+    if (err != OK) return err;
+    // validateDefinedTypesUniqueNames is the first call
+    // after lookup, as other errors could appear because
+    // user meant different type than we assumed.
     err = validateDefinedTypesUniqueNames();
     if (err != OK) return err;
     // checkAcyclicTypes is before resolveInheritance, as we
@@ -95,6 +99,8 @@
     if (err != OK) return err;
     err = resolveInheritance();
     if (err != OK) return err;
+    err = lookupLocalIdentifiers();
+    if (err != OK) return err;
     // checkAcyclicConstantExpressions is after resolveInheritance,
     // as resolveInheritance autofills enum values.
     err = checkAcyclicConstantExpressions();
@@ -108,10 +114,12 @@
 
     // Make future packages not to call passes
     // for processed types and expressions
-    constantExpressionRecursivePass([](ConstantExpression* ce) {
-        ce->setPostParseCompleted();
-        return OK;
-    });
+    constantExpressionRecursivePass(
+        [](ConstantExpression* ce) {
+            ce->setPostParseCompleted();
+            return OK;
+        },
+        true /* processBeforeDependencies */);
     std::unordered_set<const Type*> visited;
     mRootScope.recursivePass(
         [](Type* type) {
@@ -124,13 +132,13 @@
 }
 
 status_t AST::constantExpressionRecursivePass(
-    const std::function<status_t(ConstantExpression*)>& func) {
+    const std::function<status_t(ConstantExpression*)>& func, bool processBeforeDependencies) {
     std::unordered_set<const Type*> visitedTypes;
     std::unordered_set<const ConstantExpression*> visitedCE;
     return mRootScope.recursivePass(
         [&](Type* type) -> status_t {
             for (auto* ce : type->getConstantExpressions()) {
-                status_t err = ce->recursivePass(func, &visitedCE);
+                status_t err = ce->recursivePass(func, &visitedCE, processBeforeDependencies);
                 if (err != OK) return err;
             }
             return OK;
@@ -138,6 +146,59 @@
         &visitedTypes);
 }
 
+status_t AST::lookupTypes() {
+    std::unordered_set<const Type*> visited;
+    return mRootScope.recursivePass(
+        [&](Type* type) -> status_t {
+            Scope* scope = type->isScope() ? static_cast<Scope*>(type) : type->parent();
+
+            for (auto* nextRef : type->getReferences()) {
+                if (nextRef->isResolved()) continue;
+
+                Type* nextType = lookupType(nextRef->getLookupFqName(), scope);
+                if (nextType == nullptr) {
+                    std::cerr << "ERROR: Failed to lookup type '"
+                              << nextRef->getLookupFqName().string() << "' at "
+                              << nextRef->location() << "\n";
+                    return UNKNOWN_ERROR;
+                }
+                nextRef->set(nextType);
+            }
+
+            return OK;
+        },
+        &visited);
+}
+
+status_t AST::lookupLocalIdentifiers() {
+    std::unordered_set<const Type*> visitedTypes;
+    std::unordered_set<const ConstantExpression*> visitedCE;
+
+    return mRootScope.recursivePass(
+        [&](Type* type) -> status_t {
+            Scope* scope = type->isScope() ? static_cast<Scope*>(type) : type->parent();
+
+            for (auto* ce : type->getConstantExpressions()) {
+                status_t err = ce->recursivePass(
+                    [&](ConstantExpression* ce) {
+                        for (auto* nextRef : ce->getReferences()) {
+                            if (nextRef->isResolved()) continue;
+
+                            LocalIdentifier* iden = lookupLocalIdentifier(*nextRef, scope);
+                            if (iden == nullptr) return UNKNOWN_ERROR;
+                            nextRef->set(iden);
+                        }
+                        return OK;
+                    },
+                    &visitedCE, true /* processBeforeDependencies */);
+                if (err != OK) return err;
+            }
+
+            return OK;
+        },
+        &visitedTypes);
+}
+
 status_t AST::validateDefinedTypesUniqueNames() const {
     std::unordered_set<const Type*> visited;
     return mRootScope.recursivePass(
@@ -157,10 +218,12 @@
 }
 
 status_t AST::evaluate() {
-    return constantExpressionRecursivePass([](ConstantExpression* ce) {
-        ce->evaluate();
-        return OK;
-    });
+    return constantExpressionRecursivePass(
+        [](ConstantExpression* ce) {
+            ce->evaluate();
+            return OK;
+        },
+        false /* processBeforeDependencies */);
 }
 
 status_t AST::validate() const {
@@ -319,6 +382,28 @@
     mDefinedTypesByFullName[type->fqName()] = type;
 }
 
+LocalIdentifier* AST::lookupLocalIdentifier(const Reference<LocalIdentifier>& ref, Scope* scope) {
+    const FQName& fqName = ref.getLookupFqName();
+
+    if (fqName.isIdentifier()) {
+        LocalIdentifier* iden = scope->lookupIdentifier(fqName.name());
+        if (iden == nullptr) {
+            std::cerr << "ERROR: identifier " << fqName.string() << " could not be found at "
+                      << ref.location() << "\n";
+            return nullptr;
+        }
+        return iden;
+    } else {
+        std::string errorMsg;
+        EnumValue* enumValue = lookupEnumValue(fqName, &errorMsg, scope);
+        if (enumValue == nullptr) {
+            std::cerr << "ERROR: " << errorMsg << " at " << ref.location() << "\n";
+            return nullptr;
+        }
+        return enumValue;
+    }
+}
+
 EnumValue* AST::lookupEnumValue(const FQName& fqName, std::string* errorMsg, Scope* scope) {
     FQName enumTypeName = fqName.typeName();
     std::string enumValueName = fqName.valueName();
diff --git a/AST.h b/AST.h
index 0f27c85..f3cf63c 100644
--- a/AST.h
+++ b/AST.h
@@ -61,6 +61,10 @@
 
     const std::string &getFilename() const;
 
+    // Look up local identifier.
+    // It could be plain identifier or enum value as described by lookupEnumValue.
+    LocalIdentifier* lookupLocalIdentifier(const Reference<LocalIdentifier>& ref, Scope* scope);
+
     // Look up an enum value by "FQName:valueName".
     EnumValue* lookupEnumValue(const FQName& fqName, std::string* errorMsg, Scope* scope);
 
@@ -77,7 +81,13 @@
 
     // Recursive pass on constant expression tree
     status_t constantExpressionRecursivePass(
-        const std::function<status_t(ConstantExpression*)>& func);
+        const std::function<status_t(ConstantExpression*)>& func, bool processBeforeDependencies);
+
+    // Recursive tree pass that looks up all referenced types
+    status_t lookupTypes();
+
+    // Recursive tree pass that looks up all referenced local identifiers
+    status_t lookupLocalIdentifiers();
 
     // Recursive tree pass that validates that all defined types
     // have unique names in their scopes.
diff --git a/ConstantExpression.cpp b/ConstantExpression.cpp
index 5f24f1f..1b2979d 100644
--- a/ConstantExpression.cpp
+++ b/ConstantExpression.cpp
@@ -466,55 +466,67 @@
 }
 
 status_t ConstantExpression::recursivePass(const std::function<status_t(ConstantExpression*)>& func,
-                                           std::unordered_set<const ConstantExpression*>* visited) {
+                                           std::unordered_set<const ConstantExpression*>* visited,
+                                           bool processBeforeDependencies) {
     if (mIsPostParseCompleted) return OK;
 
     if (visited->find(this) != visited->end()) return OK;
     visited->insert(this);
 
+    if (processBeforeDependencies) {
+        status_t err = func(this);
+        if (err != OK) return err;
+    }
+
     for (auto* nextCE : getConstantExpressions()) {
-        status_t err = nextCE->recursivePass(func, visited);
+        status_t err = nextCE->recursivePass(func, visited, processBeforeDependencies);
         if (err != OK) return err;
     }
 
     for (auto* nextRef : getReferences()) {
         auto* nextCE = nextRef->shallowGet()->constExpr();
         CHECK(nextCE != nullptr) << "Local identifier is not a constant expression";
-        status_t err = nextCE->recursivePass(func, visited);
+        status_t err = nextCE->recursivePass(func, visited, processBeforeDependencies);
         if (err != OK) return err;
     }
 
-    // Unlike types, constant expressions need to be proceeded after dependencies
-    status_t err = func(this);
-    if (err != OK) return err;
+    if (!processBeforeDependencies) {
+        status_t err = func(this);
+        if (err != OK) return err;
+    }
 
     return OK;
 }
 
 status_t ConstantExpression::recursivePass(
     const std::function<status_t(const ConstantExpression*)>& func,
-    std::unordered_set<const ConstantExpression*>* visited) const {
-
+    std::unordered_set<const ConstantExpression*>* visited, bool processBeforeDependencies) const {
     if (mIsPostParseCompleted) return OK;
 
     if (visited->find(this) != visited->end()) return OK;
     visited->insert(this);
 
+    if (processBeforeDependencies) {
+        status_t err = func(this);
+        if (err != OK) return err;
+    }
+
     for (const auto* nextCE : getConstantExpressions()) {
-        status_t err = nextCE->recursivePass(func, visited);
+        status_t err = nextCE->recursivePass(func, visited, processBeforeDependencies);
         if (err != OK) return err;
     }
 
     for (const auto* nextRef : getReferences()) {
         const auto* nextCE = nextRef->shallowGet()->constExpr();
         CHECK(nextCE != nullptr) << "Local identifier is not a constant expression";
-        status_t err = nextCE->recursivePass(func, visited);
+        status_t err = nextCE->recursivePass(func, visited, processBeforeDependencies);
         if (err != OK) return err;
     }
 
-    // Unlike types, constant expressions need to be proceeded after dependencies
-    status_t err = func(this);
-    if (err != OK) return err;
+    if (!processBeforeDependencies) {
+        status_t err = func(this);
+        if (err != OK) return err;
+    }
 
     return OK;
 }
diff --git a/ConstantExpression.h b/ConstantExpression.h
index 3d23d61..0dc9693 100644
--- a/ConstantExpression.h
+++ b/ConstantExpression.h
@@ -54,9 +54,11 @@
     // Makes sure to visit each node only once
     // Used to provide lookup and lazy evaluation
     status_t recursivePass(const std::function<status_t(ConstantExpression*)>& func,
-                           std::unordered_set<const ConstantExpression*>* visited);
+                           std::unordered_set<const ConstantExpression*>* visited,
+                           bool processBeforeDependencies);
     status_t recursivePass(const std::function<status_t(const ConstantExpression*)>& func,
-                           std::unordered_set<const ConstantExpression*>* visited) const;
+                           std::unordered_set<const ConstantExpression*>* visited,
+                           bool processBeforeDependencies) const;
 
     // Evaluates current constant expression
     // Doesn't call recursive evaluation, so must be called after dependencies
diff --git a/Interface.cpp b/Interface.cpp
index 0fd702c..86087e0 100644
--- a/Interface.cpp
+++ b/Interface.cpp
@@ -461,7 +461,7 @@
 std::vector<const Reference<Type>*> Interface::getReferences() const {
     std::vector<const Reference<Type>*> ret;
 
-    if (superType() != nullptr) {
+    if (!isIBase()) {
         ret.push_back(&mSuperType);
     }
 
@@ -488,7 +488,7 @@
     // not necessary for other references.
 
     std::vector<const Reference<Type>*> ret;
-    if (superType() != nullptr) {
+    if (!isIBase()) {
         ret.push_back(&mSuperType);
     }
 
diff --git a/hidl-gen_y.yy b/hidl-gen_y.yy
index 11c685e..745211c 100644
--- a/hidl-gen_y.yy
+++ b/hidl-gen_y.yy
@@ -488,17 +488,6 @@
     : fqname
       {
           $$ = new Reference<Type>(*$1, convertYYLoc(@1));
-
-          Type* type = ast->lookupType($$->getLookupFqName(), *scope);
-          if (type == nullptr) {
-              std::cerr << "ERROR: Failed to lookup type '" << $1->string() << "' at "
-                        << @1
-                        << "\n";
-
-              YYERROR;
-          }
-
-          $$->set(type);
       }
     | TYPE
       {
@@ -655,9 +644,6 @@
 
               if (superType == nullptr) {
                   superType = new Reference<Type>(gIBaseFqName, convertYYLoc(@$));
-                  Type* type = ast->lookupType(superType->getLookupFqName(), *scope);
-                  CHECK(type != nullptr && type->isInterface());
-                  superType->set(type);
               }
           }
 
@@ -721,27 +707,8 @@
               YYERROR;
           }
 
-          if($1->isIdentifier()) {
-              std::string identifier = $1->name();
-              LocalIdentifier *iden = (*scope)->lookupIdentifier(identifier);
-              if(!iden) {
-                  std::cerr << "ERROR: identifier " << $1->string()
-                            << " could not be found at " << @1 << ".\n";
-                  YYERROR;
-              }
-              $$ = new ReferenceConstantExpression(
-                  Reference<LocalIdentifier>(iden, convertYYLoc(@1)), $1->string());
-          } else {
-              std::string errorMsg;
-              EnumValue* v = ast->lookupEnumValue(*$1, &errorMsg, *scope);
-              if(v == nullptr) {
-                  std::cerr << "ERROR: " << errorMsg << " at " << @1 << ".\n";
-                  YYERROR;
-              }
-
-              $$ = new ReferenceConstantExpression(
-                  Reference<LocalIdentifier>(v, convertYYLoc(@1)), $1->string());
-          }
+          $$ = new ReferenceConstantExpression(
+              Reference<LocalIdentifier>(*$1, convertYYLoc(@1)), $1->string());
       }
     | const_expr '?' const_expr ':' const_expr
       {
@@ -1038,15 +1005,6 @@
       {
           // "interface" is a synonym of android.hidl.base@1.0::IBase
           $$ = new Reference<Type>(gIBaseFqName, convertYYLoc(@1));
-          Type* type = ast->lookupType($$->getLookupFqName(), *scope);
-          if (type == nullptr) {
-              std::cerr << "ERROR: Cannot find "
-                        << gIBaseFqName.string()
-                        << " at " << @1 << "\n";
-
-              YYERROR;
-          }
-          $$->set(type);
       }
     ;