Propagate interface annotations to methods

Change-Id: I85f9686e4b5df7df0d9fb77e6a1b50a93ff3e2d7
Test: Unit tests pass
Bug: 26911508
diff --git a/aidl.cpp b/aidl.cpp
index 6cc25a0..c60626f 100644
--- a/aidl.cpp
+++ b/aidl.cpp
@@ -189,6 +189,12 @@
                 TypeNamespace* types) {
   int err = 0;
 
+  if (c->IsUtf8() && c->IsUtf8InCpp()) {
+    cerr << filename << ":" << c->GetLine()
+         << "Interface cannot be marked as both @utf8 and @utf8InCpp";
+    err = 1;
+  }
+
   // Has to be a pointer due to deleting copy constructor. No idea why.
   map<string, const AidlMethod*> method_names;
   for (const auto& m : c->GetMethods()) {
@@ -199,7 +205,7 @@
     }
 
     const ValidatableType* return_type =
-        types->GetReturnType(m->GetType(), filename);
+        types->GetReturnType(m->GetType(), filename, *c);
 
     if (!return_type) {
       err = 1;
@@ -221,7 +227,7 @@
       }
 
       const ValidatableType* arg_type =
-          types->GetArgType(*arg, index, filename);
+          types->GetArgType(*arg, index, filename, *c);
 
       if (!arg_type) {
         err = 1;
diff --git a/type_cpp.cpp b/type_cpp.cpp
index a436935..8b6c901 100644
--- a/type_cpp.cpp
+++ b/type_cpp.cpp
@@ -579,7 +579,8 @@
 
 const ValidatableType* TypeNamespace::GetArgType(const AidlArgument& a,
     int arg_index,
-    const std::string& filename) const {
+    const std::string& filename,
+    const AidlInterface& interface) const {
   const string error_prefix = StringPrintf(
       "In file %s line %d parameter %s (%d):\n    ",
       filename.c_str(), a.GetLine(), a.GetName().c_str(), arg_index);
@@ -591,7 +592,8 @@
     return nullptr;
   }
 
-  return ::android::aidl::TypeNamespace::GetArgType(a, arg_index, filename);
+  return ::android::aidl::TypeNamespace::GetArgType(a, arg_index, filename,
+                                                    interface);
 }
 
 }  // namespace cpp
diff --git a/type_cpp.h b/type_cpp.h
index 776a396..66b2ed5 100644
--- a/type_cpp.h
+++ b/type_cpp.h
@@ -125,7 +125,8 @@
   bool IsValidPackage(const std::string& package) const override;
   const ValidatableType* GetArgType(const AidlArgument& a,
                              int arg_index,
-                             const std::string& filename) const override;
+                             const std::string& filename,
+                             const AidlInterface& interface) const override;
 
   const Type* VoidType() const { return void_type_; }
   const Type* IBinderType() const { return ibinder_type_; }
diff --git a/type_namespace.cpp b/type_namespace.cpp
index 49b96ac..d3f4aaf 100644
--- a/type_namespace.cpp
+++ b/type_namespace.cpp
@@ -98,9 +98,11 @@
 }
 
 const ValidatableType* TypeNamespace::GetReturnType(
-    const AidlType& raw_type, const string& filename) const {
+    const AidlType& raw_type, const string& filename,
+    const AidlInterface& interface) const {
   string error_msg;
-  const ValidatableType* return_type = GetValidatableType(raw_type, &error_msg);
+  const ValidatableType* return_type = GetValidatableType(raw_type, &error_msg,
+                                                          interface);
   if (return_type == nullptr) {
     LOG(ERROR) << StringPrintf("In file %s line %d return type %s:\n    ",
                                filename.c_str(), raw_type.GetLine(),
@@ -113,14 +115,16 @@
 }
 
 const ValidatableType* TypeNamespace::GetArgType(
-    const AidlArgument& a, int arg_index, const string& filename) const {
+    const AidlArgument& a, int arg_index, const string& filename,
+    const AidlInterface& interface) const {
   string error_prefix = StringPrintf(
       "In file %s line %d parameter %s (argument %d):\n    ",
       filename.c_str(), a.GetLine(), a.GetName().c_str(), arg_index);
 
   // check the arg type
   string error_msg;
-  const ValidatableType* t = GetValidatableType(a.GetType(), &error_msg);
+  const ValidatableType* t = GetValidatableType(a.GetType(), &error_msg,
+                                                interface);
   if (t == nullptr) {
     LOG(ERROR) << error_prefix << error_msg;
     return nullptr;
diff --git a/type_namespace.h b/type_namespace.h
index ded96a2..caea1fa 100644
--- a/type_namespace.h
+++ b/type_namespace.h
@@ -113,13 +113,16 @@
   // if this is an invalid return type.
   virtual const ValidatableType* GetReturnType(
       const AidlType& raw_type,
-      const std::string& filename) const;
+      const std::string& filename,
+      const AidlInterface& interface) const;
 
   // Returns a pointer to a type corresponding to |a| or nullptr if |a|
   // has an invalid argument type.
-  virtual const ValidatableType* GetArgType(const AidlArgument& a,
-                                            int arg_index,
-                                            const std::string& filename) const;
+  virtual const ValidatableType* GetArgType(
+      const AidlArgument& a,
+      int arg_index,
+      const std::string& filename,
+      const AidlInterface& interface) const;
 
   // Returns a pointer to a type corresponding to |interface|.
   virtual const ValidatableType* GetInterfaceType(
@@ -130,7 +133,8 @@
   virtual ~TypeNamespace() = default;
 
   virtual const ValidatableType* GetValidatableType(
-      const AidlType& type, std::string* error_msg) const = 0;
+      const AidlType& type, std::string* error_msg,
+      const AidlInterface& interface) const = 0;
 
  private:
   DISALLOW_COPY_AND_ASSIGN(TypeNamespace);
@@ -182,7 +186,8 @@
   bool IsContainerType(const std::string& type_name) const;
 
   const ValidatableType* GetValidatableType(
-      const AidlType& type, std::string* error_msg) const override;
+      const AidlType& type, std::string* error_msg,
+      const AidlInterface& interface) const override;
 
   std::vector<std::unique_ptr<const T>> types_;
 
@@ -380,7 +385,8 @@
 
 template<typename T>
 const ValidatableType* LanguageTypeNamespace<T>::GetValidatableType(
-    const AidlType& aidl_type, std::string* error_msg) const {
+    const AidlType& aidl_type, std::string* error_msg,
+    const AidlInterface& interface) const {
   using android::base::StringPrintf;
 
   const ValidatableType* type = Find(aidl_type);
@@ -410,34 +416,43 @@
     return nullptr;
   }
 
+  bool utf8 = aidl_type.IsUtf8();
+  bool utf8InCpp = aidl_type.IsUtf8InCpp();
+
   // Strings inside containers get remapped to appropriate utf8 versions when
   // we convert the container name to its canonical form and the look up the
   // type.  However, for non-compound types (i.e. those not in a container) we
   // must patch them up here.
-  if (!IsContainerType(type->CanonicalName()) &&
-      (aidl_type.IsUtf8() || aidl_type.IsUtf8InCpp())) {
+  if (IsContainerType(type->CanonicalName())) {
+    utf8 = false;
+    utf8InCpp = false;
+  } else if (aidl_type.GetName() == "String" ||
+             aidl_type.GetName() == "java.lang.String") {
+    utf8 = utf8 || interface.IsUtf8();
+    utf8InCpp = utf8InCpp || interface.IsUtf8InCpp();
+  } else if (utf8 || utf8InCpp) {
     const char* annotation_literal =
-        (aidl_type.IsUtf8()) ? kUtf8Annotation : kUtf8InCppAnnotation;
-    if (aidl_type.GetName() != "String" &&
-        aidl_type.GetName() != "java.lang.String") {
-      *error_msg = StringPrintf("type '%s' may not be annotated as %s.",
-                                aidl_type.GetName().c_str(),
-                                annotation_literal);
-      return nullptr;
-    }
+        (utf8) ? kUtf8Annotation : kUtf8InCppAnnotation;
+    *error_msg = StringPrintf("type '%s' may not be annotated as %s.",
+                              aidl_type.GetName().c_str(),
+                              annotation_literal);
+    return nullptr;
+  }
 
-    if (aidl_type.IsUtf8()) {
-      type = FindTypeByCanonicalName(kUtf8StringCanonicalName);
-    } else {  // aidl_type.IsUtf8InCpp()
-      type = FindTypeByCanonicalName(kUtf8InCppStringCanonicalName);
-    }
+  if (utf8) {
+    type = FindTypeByCanonicalName(kUtf8StringCanonicalName);
+  } else if (utf8InCpp) {
+    type = FindTypeByCanonicalName(kUtf8InCppStringCanonicalName);
+  }
 
-    if (type == nullptr) {
-      *error_msg = StringPrintf(
-          "%s is unsupported when generating code for this language.",
-          annotation_literal);
-      return nullptr;
-    }
+  // One of our UTF8 transforms made type null
+  if (type == nullptr) {
+    const char* annotation_literal =
+        (utf8) ? kUtf8Annotation : kUtf8InCppAnnotation;
+    *error_msg = StringPrintf(
+        "%s is unsupported when generating code for this language.",
+        annotation_literal);
+    return nullptr;
   }
 
   if (!type->CanWriteToParcel()) {
@@ -454,6 +469,14 @@
     }
   }
 
+  if (interface.IsNullable()) {
+    const ValidatableType* nullableType = type->NullableType();
+
+    if (nullableType) {
+      return nullableType;
+    }
+  }
+
   if (aidl_type.IsNullable()) {
     type = type->NullableType();
     if (!type) {