Clean up forEach reflection code.

BUG=4203264

Change-Id: I8196608408fe333bd8e875d9517b8e875bdce17d
diff --git a/slang_rs_context.cpp b/slang_rs_context.cpp
index 5ee2a81..e6a0770 100644
--- a/slang_rs_context.cpp
+++ b/slang_rs_context.cpp
@@ -115,9 +115,6 @@
   }
 
   if (RSExportForEach::isRSForEachFunc(FD)) {
-    if (!RSExportForEach::validateSpecialFuncDecl(getDiagnostics(), FD)) {
-      return false;
-    }
     RSExportForEach *EFE = RSExportForEach::Create(this, FD);
     if (EFE == NULL)
       return false;
@@ -125,7 +122,7 @@
       mExportForEach.push_back(EFE);
     return true;
   } else if (RSExportForEach::isSpecialRSFunc(FD)) {
-    // Do not reflect specialized RS functions like init/root.
+    // Do not reflect specialized RS functions like init or graphics root.
     if (!RSExportForEach::validateSpecialFuncDecl(getDiagnostics(), FD)) {
       return false;
     }
diff --git a/slang_rs_export_foreach.cpp b/slang_rs_export_foreach.cpp
index 5286457..19e0649 100644
--- a/slang_rs_export_foreach.cpp
+++ b/slang_rs_export_foreach.cpp
@@ -20,6 +20,7 @@
 
 #include "clang/AST/ASTContext.h"
 #include "clang/AST/Decl.h"
+#include "clang/AST/TypeLoc.h"
 
 #include "llvm/DerivedTypes.h"
 #include "llvm/Target/TargetData.h"
@@ -30,36 +31,212 @@
 
 namespace slang {
 
+namespace {
+
+static void ReportNameError(clang::Diagnostic *Diags,
+                            const clang::ParmVarDecl *PVD) {
+  slangAssert(Diags && PVD);
+  const clang::SourceManager &SM = Diags->getSourceManager();
+
+  Diags->Report(clang::FullSourceLoc(PVD->getLocation(), SM),
+                Diags->getCustomDiagID(clang::Diagnostic::Error,
+                "Duplicate parameter entry (by position/name): '%0'"))
+       << PVD->getName();
+  return;
+}
+
+}  // namespace
+
+// This function takes care of additional validation and construction of
+// parameters related to forEach_* reflection.
+bool RSExportForEach::validateAndConstructParams(
+    RSContext *Context, const clang::FunctionDecl *FD) {
+  slangAssert(Context && FD);
+  bool valid = true;
+  clang::ASTContext &C = Context->getASTContext();
+  clang::Diagnostic *Diags = Context->getDiagnostics();
+
+  if (!isRootRSFunc(FD)) {
+    slangAssert(false && "must be called on compute root function!");
+  }
+
+  numParams = FD->getNumParams();
+  slangAssert(numParams > 0);
+
+  // Compute root functions are required to return a void type for now
+  if (FD->getResultType().getCanonicalType() != C.VoidTy) {
+    Diags->Report(
+        clang::FullSourceLoc(FD->getLocation(), Diags->getSourceManager()),
+        Diags->getCustomDiagID(clang::Diagnostic::Error,
+                               "compute root() is required to return a "
+                               "void type"));
+    valid = false;
+  }
+
+  // Validate remaining parameter types
+  // TODO(all): Add support for LOD/face when we have them
+
+  for (size_t i = 0; i < numParams; i++) {
+    const clang::ParmVarDecl *PVD = FD->getParamDecl(i);
+    clang::QualType QT = PVD->getType().getCanonicalType();
+    llvm::StringRef ParamName = PVD->getName();
+
+    if (QT->isPointerType()) {
+      if (QT->getPointeeType().isConstQualified()) {
+        // const T1 *in
+        // const T3 *usrData
+        if (ParamName.equals("in")) {
+          if (mIn) {
+            ReportNameError(Diags, PVD);
+            valid = false;
+          } else {
+            mIn = PVD;
+          }
+        } else if (ParamName.equals("usrData")) {
+          if (mUsrData) {
+            ReportNameError(Diags, PVD);
+            valid = false;
+          } else {
+            mUsrData = PVD;
+          }
+        } else {
+          // Issue warning about positional parameter usage
+          if (!mIn) {
+            mIn = PVD;
+          } else if (!mUsrData) {
+            mUsrData = PVD;
+          } else {
+            Diags->Report(
+                clang::FullSourceLoc(PVD->getLocation(),
+                                     Diags->getSourceManager()),
+                Diags->getCustomDiagID(clang::Diagnostic::Error,
+                                       "Unexpected root() parameter '%0' "
+                                       "of type '%1'"))
+                << PVD->getName() << PVD->getType().getAsString();
+            valid = false;
+          }
+        }
+      } else {
+        // T2 *out
+        if (ParamName.equals("out")) {
+          if (mOut) {
+            ReportNameError(Diags, PVD);
+            valid = false;
+          } else {
+            mOut = PVD;
+          }
+        } else {
+          if (!mOut) {
+            mOut = PVD;
+          } else {
+            Diags->Report(
+                clang::FullSourceLoc(PVD->getLocation(),
+                                     Diags->getSourceManager()),
+                Diags->getCustomDiagID(clang::Diagnostic::Error,
+                                       "Unexpected root() parameter '%0' "
+                                       "of type '%1'"))
+                << PVD->getName() << PVD->getType().getAsString();
+            valid = false;
+          }
+        }
+      }
+    } else if (QT.getUnqualifiedType() == C.UnsignedIntTy) {
+      if (ParamName.equals("x")) {
+        if (mX) {
+          ReportNameError(Diags, PVD);
+          valid = false;
+        } else {
+          mX = PVD;
+        }
+      } else if (ParamName.equals("y")) {
+        if (mY) {
+          ReportNameError(Diags, PVD);
+          valid = false;
+        } else {
+          mY = PVD;
+        }
+      } else if (ParamName.equals("z")) {
+        if (mZ) {
+          ReportNameError(Diags, PVD);
+          valid = false;
+        } else {
+          mZ = PVD;
+        }
+      } else if (ParamName.equals("ar")) {
+        if (mAr) {
+          ReportNameError(Diags, PVD);
+          valid = false;
+        } else {
+          mAr = PVD;
+        }
+      } else {
+        if (!mX) {
+          mX = PVD;
+        } else if (!mY) {
+          mY = PVD;
+        } else if (!mZ) {
+          mZ = PVD;
+        } else if (!mAr) {
+          mAr = PVD;
+        } else {
+          Diags->Report(
+              clang::FullSourceLoc(PVD->getLocation(),
+                                   Diags->getSourceManager()),
+              Diags->getCustomDiagID(clang::Diagnostic::Error,
+                                     "Unexpected root() parameter '%0' "
+                                     "of type '%1'"))
+              << PVD->getName() << PVD->getType().getAsString();
+          valid = false;
+        }
+      }
+    } else {
+      Diags->Report(
+          clang::FullSourceLoc(
+              PVD->getTypeSourceInfo()->getTypeLoc().getBeginLoc(),
+              Diags->getSourceManager()),
+          Diags->getCustomDiagID(clang::Diagnostic::Error,
+              "Unexpected root() parameter type '%0'"))
+          << PVD->getType().getAsString();
+      valid = false;
+    }
+  }
+
+  if (!mIn && !mOut) {
+    Diags->Report(
+        clang::FullSourceLoc(FD->getLocation(),
+                             Diags->getSourceManager()),
+        Diags->getCustomDiagID(clang::Diagnostic::Error,
+                               "Compute root() must have at least one "
+                               "parameter for in or out"));
+    valid = false;
+  }
+
+  return valid;
+}
+
 RSExportForEach *RSExportForEach::Create(RSContext *Context,
                                          const clang::FunctionDecl *FD) {
+  slangAssert(Context && FD);
   llvm::StringRef Name = FD->getName();
-  RSExportForEach *F;
+  RSExportForEach *FE;
 
   slangAssert(!Name.empty() && "Function must have a name");
 
-  F = new RSExportForEach(Context, Name, FD);
+  FE = new RSExportForEach(Context, Name, FD);
 
-  F->numParams = FD->getNumParams();
-
-  if (F->numParams == 0) {
-    slangAssert(false && "Should have at least one parameter for root");
+  if (!FE->validateAndConstructParams(Context, FD)) {
+    delete FE;
+    return NULL;
   }
 
   clang::ASTContext &Ctx = Context->getASTContext();
 
   std::string Id(DUMMY_RS_TYPE_NAME_PREFIX"helper_foreach_param:");
-  Id.append(F->getName()).append(DUMMY_RS_TYPE_NAME_POSTFIX);
-
-  clang::RecordDecl *RD =
-      clang::RecordDecl::Create(Ctx, clang::TTK_Struct,
-                                Ctx.getTranslationUnitDecl(),
-                                clang::SourceLocation(),
-                                clang::SourceLocation(),
-                                &Ctx.Idents.get(Id));
+  Id.append(FE->getName()).append(DUMMY_RS_TYPE_NAME_POSTFIX);
 
   // Extract the usrData parameter (if we have one)
-  if (F->numParams >= 3) {
-    const clang::ParmVarDecl *PVD = FD->getParamDecl(2);
+  if (FE->mUsrData) {
+    const clang::ParmVarDecl *PVD = FE->mUsrData;
     clang::QualType QT = PVD->getType().getCanonicalType();
     slangAssert(QT->isPointerType() &&
                 QT->getPointeeType().isConstQualified());
@@ -69,8 +246,15 @@
         C.VoidTy) {
       // In the case of using const void*, we can't reflect an appopriate
       // Java type, so we fall back to just reflecting the ain/aout parameters
-      F->numParams = 2;
+      FE->mUsrData = NULL;
     } else {
+      clang::RecordDecl *RD =
+          clang::RecordDecl::Create(Ctx, clang::TTK_Struct,
+                                    Ctx.getTranslationUnitDecl(),
+                                    clang::SourceLocation(),
+                                    clang::SourceLocation(),
+                                    &Ctx.Idents.get(Id));
+
       llvm::StringRef ParamName = PVD->getName();
       clang::FieldDecl *FD =
           clang::FieldDecl::Create(Ctx,
@@ -83,33 +267,40 @@
                                    /* BitWidth = */NULL,
                                    /* Mutable = */false);
       RD->addDecl(FD);
+      RD->completeDefinition();
+
+      // Create an export type iff we have a valid usrData type
+      clang::QualType T = Ctx.getTagDeclType(RD);
+      slangAssert(!T.isNull());
+
+      RSExportType *ET = RSExportType::Create(Context, T.getTypePtr());
+
+      if (ET == NULL) {
+        fprintf(stderr, "Failed to export the function %s. There's at least "
+                        "one parameter whose type is not supported by the "
+                        "reflection\n", FE->getName().c_str());
+        delete FE;
+        return NULL;
+      }
+
+      slangAssert((ET->getClass() == RSExportType::ExportClassRecord) &&
+                  "Parameter packet must be a record");
+
+      FE->mParamPacketType = static_cast<RSExportRecordType *>(ET);
     }
   }
-  RD->completeDefinition();
 
-  if (F->numParams >= 3) {
-    // Create an export type iff we have a valid usrData type
-    clang::QualType T = Ctx.getTagDeclType(RD);
-    slangAssert(!T.isNull());
-
-    RSExportType *ET =
-      RSExportType::Create(Context, T.getTypePtr());
-
-    if (ET == NULL) {
-      fprintf(stderr, "Failed to export the function %s. There's at least one "
-                      "parameter whose type is not supported by the "
-                      "reflection\n", F->getName().c_str());
-      delete F;
-      return NULL;
-    }
-
-    slangAssert((ET->getClass() == RSExportType::ExportClassRecord) &&
-           "Parameter packet must be a record");
-
-    F->mParamPacketType = static_cast<RSExportRecordType *>(ET);
+  if (FE->mIn) {
+    const clang::Type *T = FE->mIn->getType().getCanonicalType().getTypePtr();
+    FE->mInType = RSExportType::Create(Context, T);
   }
 
-  return F;
+  if (FE->mOut) {
+    const clang::Type *T = FE->mOut->getType().getCanonicalType().getTypePtr();
+    FE->mOutType = RSExportType::Create(Context, T);
+  }
+
+  return FE;
 }
 
 bool RSExportForEach::isRSForEachFunc(const clang::FunctionDecl *FD) {
@@ -118,9 +309,7 @@
     return false;
   }
 
-  const clang::ASTContext &C = FD->getASTContext();
-  if (FD->getNumParams() == 0 &&
-      FD->getResultType().getCanonicalType() == C.IntTy) {
+  if (FD->getNumParams() == 0) {
     // Graphics compute function
     return false;
   }
@@ -129,10 +318,7 @@
 
 bool RSExportForEach::validateSpecialFuncDecl(clang::Diagnostic *Diags,
                                               const clang::FunctionDecl *FD) {
-  if (!FD) {
-    return false;
-  }
-
+  slangAssert(Diags && FD);
   bool valid = true;
   const clang::ASTContext &C = FD->getASTContext();
 
@@ -149,80 +335,8 @@
         valid = false;
       }
     } else {
-      // Compute root functions are required to return a void type for now
-      if (FD->getResultType().getCanonicalType() != C.VoidTy) {
-        Diags->Report(
-            clang::FullSourceLoc(FD->getLocation(), Diags->getSourceManager()),
-            Diags->getCustomDiagID(clang::Diagnostic::Error,
-                                   "compute root() is required to return a "
-                                   "void type"));
-        valid = false;
-      }
-
-      // Validate remaining parameter types
-      const clang::ParmVarDecl *tooManyParams = NULL;
-      for (unsigned int i = 0; i < numParams; i++) {
-        const clang::ParmVarDecl *PVD = FD->getParamDecl(i);
-        clang::QualType QT = PVD->getType().getCanonicalType();
-        switch (i) {
-          case 0:     // const T1 *ain
-          case 2: {   // const T3 *usrData
-            if (!QT->isPointerType() ||
-                !QT->getPointeeType().isConstQualified()) {
-              Diags->Report(
-                  clang::FullSourceLoc(PVD->getLocation(),
-                                       Diags->getSourceManager()),
-                  Diags->getCustomDiagID(clang::Diagnostic::Error,
-                                         "compute root() parameter must be a "
-                                         "const pointer type"));
-              valid = false;
-            }
-            break;
-          }
-          case 1: {   // T2 *aout
-            if (!QT->isPointerType()) {
-              Diags->Report(
-                  clang::FullSourceLoc(PVD->getLocation(),
-                                       Diags->getSourceManager()),
-                  Diags->getCustomDiagID(clang::Diagnostic::Error,
-                                         "compute root() parameter must be a "
-                                         "pointer type"));
-              valid = false;
-            }
-            break;
-          }
-          case 3:     // unsigned int x
-          case 4:     // unsigned int y
-          case 5:     // unsigned int z
-          case 6: {   // unsigned int ar
-            if (QT.getUnqualifiedType() != C.UnsignedIntTy) {
-              Diags->Report(
-                  clang::FullSourceLoc(PVD->getLocation(),
-                                       Diags->getSourceManager()),
-                  Diags->getCustomDiagID(clang::Diagnostic::Error,
-                                         "compute root() parameter must be a "
-                                         "uint32_t type"));
-              valid = false;
-            }
-            break;
-          }
-          default: {
-            if (!tooManyParams) {
-              tooManyParams = PVD;
-            }
-            break;
-          }
-        }
-      }
-      if (tooManyParams) {
-        Diags->Report(
-            clang::FullSourceLoc(tooManyParams->getLocation(),
-                                 Diags->getSourceManager()),
-            Diags->getCustomDiagID(clang::Diagnostic::Error,
-                                   "too many compute root() parameters "
-                                   "specified"));
-        valid = false;
-      }
+      slangAssert(false &&
+          "Should not call validateSpecialFuncDecl() on compute root()");
     }
   } else if (isInitRSFunc(FD)) {
     if (FD->getNumParams() != 0) {
diff --git a/slang_rs_export_foreach.h b/slang_rs_export_foreach.h
index 22cc180..650510a 100644
--- a/slang_rs_export_foreach.h
+++ b/slang_rs_export_foreach.h
@@ -39,17 +39,31 @@
  private:
   std::string mName;
   RSExportRecordType *mParamPacketType;
+  RSExportType *mInType;
+  RSExportType *mOutType;
   size_t numParams;
 
+  const clang::ParmVarDecl *mIn;
+  const clang::ParmVarDecl *mOut;
+  const clang::ParmVarDecl *mUsrData;
+  const clang::ParmVarDecl *mX;
+  const clang::ParmVarDecl *mY;
+  const clang::ParmVarDecl *mZ;
+  const clang::ParmVarDecl *mAr;
+
+  // TODO(all): Add support for LOD/face when we have them
   RSExportForEach(RSContext *Context, const llvm::StringRef &Name,
          const clang::FunctionDecl *FD)
     : RSExportable(Context, RSExportable::EX_FOREACH),
-      mName(Name.data(), Name.size()),
-      mParamPacketType(NULL),
-      numParams(0) {
+      mName(Name.data(), Name.size()), mParamPacketType(NULL), mInType(NULL),
+      mOutType(NULL), numParams(0), mIn(NULL), mOut(NULL), mUsrData(NULL),
+      mX(NULL), mY(NULL), mZ(NULL), mAr(NULL) {
     return;
   }
 
+  bool validateAndConstructParams(RSContext *Context,
+                                  const clang::FunctionDecl *FD);
+
  public:
   static RSExportForEach *Create(RSContext *Context,
                                  const clang::FunctionDecl *FD);
@@ -62,8 +76,29 @@
     return numParams;
   }
 
-  inline const RSExportRecordType *getParamPacketType() const
-    { return mParamPacketType; }
+  inline bool hasIn() const {
+    return (mIn != NULL);
+  }
+
+  inline bool hasOut() const {
+    return (mOut != NULL);
+  }
+
+  inline bool hasUsrData() const {
+    return (mUsrData != NULL);
+  }
+
+  inline const RSExportType *getInType() const {
+    return mInType;
+  }
+
+  inline const RSExportType *getOutType() const {
+    return mOutType;
+  }
+
+  inline const RSExportRecordType *getParamPacketType() const {
+    return mParamPacketType;
+  }
 
   typedef RSExportRecordType::const_field_iterator const_param_iterator;
 
@@ -72,6 +107,7 @@
                 "Get parameter from export foreach having no parameter!");
     return mParamPacketType->fields_begin();
   }
+
   inline const_param_iterator params_end() const {
     slangAssert((mParamPacketType != NULL) &&
                 "Get parameter from export foreach having no parameter!");
diff --git a/slang_rs_reflection.cpp b/slang_rs_reflection.cpp
index dcee519..52b2a47 100644
--- a/slang_rs_reflection.cpp
+++ b/slang_rs_reflection.cpp
@@ -800,20 +800,21 @@
              << EF->getName() << " = " << C.getNextExportForEachSlot() << ";"
              << std::endl;
 
-  // for_each_*()
+  // forEach_*()
   Context::ArgTy Args;
 
-  std::string FieldPackerName = EF->getName() + "_fp";
-  size_t numParams = EF->getNumParameters();
+  slangAssert(EF->getNumParameters() > 0);
 
-  slangAssert(numParams >= 1);
-  Args.push_back(std::make_pair("Allocation", "ain"));
-  //GetTypeName(RSExportPrimitiveType::DataTypeRSAllocation), "ain");
-  if (numParams >= 2) {
+  if (EF->hasIn())
+    Args.push_back(std::make_pair("Allocation", "ain"));
+    //Args.push_back(std::make_pair(GetTypeName(EF->getInType()), "ain"));
+  if (EF->hasOut())
     Args.push_back(std::make_pair("Allocation", "aout"));
-  }
-  if (numParams >= 3) {
-    for (RSExportFunc::const_param_iterator I = EF->params_begin(),
+    //Args.push_back(std::make_pair(GetTypeName(EF->getOutType()), "aout"));
+
+  const RSExportRecordType *ERT = EF->getParamPacketType();
+  if (ERT) {
+    for (RSExportForEach::const_param_iterator I = EF->params_begin(),
              E = EF->params_end();
          I != E;
          I++) {
@@ -828,32 +829,68 @@
                   "forEach_" + EF->getName(),
                   Args);
 
-  if (numParams >= 3) {
-    const RSExportRecordType *ERT = EF->getParamPacketType();
+  const RSExportType *IET = EF->getInType();
+  if (IET) {
+    genTypeCheck(C, IET, "ain");
+  }
 
-    if (genCreateFieldPacker(C, ERT, FieldPackerName.c_str()))
+  const RSExportType *OET = EF->getOutType();
+  if (OET) {
+    genTypeCheck(C, OET, "aout");
+  }
+
+  if (EF->hasIn() && EF->hasOut()) {
+    C.indent() << "// Verify dimensions" << std::endl;
+    C.indent() << "Type tIn = ain.getType();" << std::endl;
+    C.indent() << "Type tOut = aout.getType();" << std::endl;
+    C.indent() << "if ((tIn.getCount() != tOut.getCount()) ||" << std::endl;
+    C.indent() << "    (tIn.getX() != tOut.getX()) ||" << std::endl;
+    C.indent() << "    (tIn.getY() != tOut.getY()) ||" << std::endl;
+    C.indent() << "    (tIn.getZ() != tOut.getZ()) ||" << std::endl;
+    C.indent() << "    (tIn.hasFaces() != tOut.hasFaces()) ||" << std::endl;
+    C.indent() << "    (tIn.hasMipmaps() != tOut.hasMipmaps())) {" << std::endl;
+    C.indent() << "    throw new RSRuntimeException(\"Dimension mismatch\");";
+    C.out()    << std::endl;
+    C.indent() << "}" << std::endl;
+  }
+
+  std::string FieldPackerName = EF->getName() + "_fp";
+  if (ERT) {
+    if (genCreateFieldPacker(C, ERT, FieldPackerName.c_str())) {
       genPackVarOfType(C, ERT, NULL, FieldPackerName.c_str());
+    }
   }
-  C.indent() << "forEach("RS_EXPORT_FOREACH_INDEX_PREFIX << EF->getName()
-             << ", ain";
+  C.indent() << "forEach("RS_EXPORT_FOREACH_INDEX_PREFIX << EF->getName();
 
-  switch (numParams) {
-    case 1:
-      C.out() << ", null, null);" << std::endl;
-      break;
-    case 2:
-      C.out() << ", aout, null);" << std::endl;
-      break;
-    case 3:
-    default:
-      C.out() << ", aout, " << FieldPackerName << ");" << std::endl;
-      break;
-  }
+  if (EF->hasIn())
+    C.out() << ", ain";
+  else
+    C.out() << ", null";
+
+  if (EF->hasOut())
+    C.out() << ", aout";
+  else
+    C.out() << ", null";
+
+  if (EF->hasUsrData())
+    C.out() << ", " << FieldPackerName;
+  else
+    C.out() << ", null";
+
+  C.out() << ");" << std::endl;
 
   C.endFunction();
   return;
 }
 
+void RSReflection::genTypeCheck(Context &C,
+                                const RSExportType *ET,
+                                const char *VarName) {
+  C.indent() << "// check " << VarName << std::endl;
+  return;
+}
+
+
 void RSReflection::genPrimitiveTypeExportVariable(
     Context &C,
     const RSExportVar *EV) {
diff --git a/slang_rs_reflection.h b/slang_rs_reflection.h
index c26f82f..2da7484 100644
--- a/slang_rs_reflection.h
+++ b/slang_rs_reflection.h
@@ -241,6 +241,10 @@
   void genExportForEach(Context &C,
                         const RSExportForEach *EF);
 
+  void genTypeCheck(Context &C,
+                    const RSExportType *ET,
+                    const char *VarName);
+
   bool genTypeClass(Context &C,
                     const RSExportRecordType *ERT,
                     std::string &ErrorMsg);
diff --git a/tests/F_root_compute_non_const_non_ptr_ain/root_compute_non_const_non_ptr_ain.rs b/tests/F_root_compute_int_in/root_compute_int_in.rs
similarity index 67%
rename from tests/F_root_compute_non_const_non_ptr_ain/root_compute_non_const_non_ptr_ain.rs
rename to tests/F_root_compute_int_in/root_compute_int_in.rs
index 364b7c8..b2560b4 100644
--- a/tests/F_root_compute_non_const_non_ptr_ain/root_compute_non_const_non_ptr_ain.rs
+++ b/tests/F_root_compute_int_in/root_compute_int_in.rs
@@ -1,5 +1,5 @@
 #pragma version(1)
 #pragma rs java_package_name(foo)
 
-void root(int ain) {
+void root(const int in) {
 }
diff --git a/tests/F_root_compute_int_in/stderr.txt.expect b/tests/F_root_compute_int_in/stderr.txt.expect
new file mode 100644
index 0000000..803bc75
--- /dev/null
+++ b/tests/F_root_compute_int_in/stderr.txt.expect
@@ -0,0 +1,2 @@
+root_compute_int_in.rs:4:17: error: Unexpected root() parameter type 'const int'
+root_compute_int_in.rs:4:6: error: Compute root() must have at least one parameter for in or out
diff --git a/tests/F_root_compute_non_ptr_ain/stdout.txt.expect b/tests/F_root_compute_int_in/stdout.txt.expect
similarity index 100%
rename from tests/F_root_compute_non_ptr_ain/stdout.txt.expect
rename to tests/F_root_compute_int_in/stdout.txt.expect
diff --git a/tests/F_root_compute_non_const_ain/root_compute_non_const_ain.rs b/tests/F_root_compute_non_const_ain/root_compute_non_const_ain.rs
deleted file mode 100644
index 5cbf52b..0000000
--- a/tests/F_root_compute_non_const_ain/root_compute_non_const_ain.rs
+++ /dev/null
@@ -1,5 +0,0 @@
-#pragma version(1)
-#pragma rs java_package_name(foo)
-
-void root(int *ain) {
-}
diff --git a/tests/F_root_compute_non_const_ain/stderr.txt.expect b/tests/F_root_compute_non_const_ain/stderr.txt.expect
deleted file mode 100644
index 539e033..0000000
--- a/tests/F_root_compute_non_const_ain/stderr.txt.expect
+++ /dev/null
@@ -1 +0,0 @@
-root_compute_non_const_ain.rs:4:16: error: compute root() parameter must be a const pointer type
diff --git a/tests/F_root_compute_non_const_ain/stdout.txt.expect b/tests/F_root_compute_non_const_ain/stdout.txt.expect
deleted file mode 100644
index e69de29..0000000
--- a/tests/F_root_compute_non_const_ain/stdout.txt.expect
+++ /dev/null
diff --git a/tests/F_root_compute_non_const_non_ptr_ain/stderr.txt.expect b/tests/F_root_compute_non_const_non_ptr_ain/stderr.txt.expect
deleted file mode 100644
index f774ed2..0000000
--- a/tests/F_root_compute_non_const_non_ptr_ain/stderr.txt.expect
+++ /dev/null
@@ -1 +0,0 @@
-root_compute_non_const_non_ptr_ain.rs:4:15: error: compute root() parameter must be a const pointer type
diff --git a/tests/F_root_compute_non_const_non_ptr_ain/stdout.txt.expect b/tests/F_root_compute_non_const_non_ptr_ain/stdout.txt.expect
deleted file mode 100644
index e69de29..0000000
--- a/tests/F_root_compute_non_const_non_ptr_ain/stdout.txt.expect
+++ /dev/null
diff --git a/tests/F_root_compute_non_const_usrData/stderr.txt.expect b/tests/F_root_compute_non_const_usrData/stderr.txt.expect
index 053353d..7cd483c 100644
--- a/tests/F_root_compute_non_const_usrData/stderr.txt.expect
+++ b/tests/F_root_compute_non_const_usrData/stderr.txt.expect
@@ -1 +1 @@
-root_compute_non_const_usrData.rs:4:44: error: compute root() parameter must be a const pointer type
+root_compute_non_const_usrData.rs:4:44: error: Unexpected root() parameter 'usrData' of type 'void *'
diff --git a/tests/F_root_compute_non_ptr_ain/root_compute_non_ptr_ain.rs b/tests/F_root_compute_non_ptr_ain/root_compute_non_ptr_ain.rs
deleted file mode 100644
index a311d47..0000000
--- a/tests/F_root_compute_non_ptr_ain/root_compute_non_ptr_ain.rs
+++ /dev/null
@@ -1,5 +0,0 @@
-#pragma version(1)
-#pragma rs java_package_name(foo)
-
-void root(const int ain) {
-}
diff --git a/tests/F_root_compute_non_ptr_ain/stderr.txt.expect b/tests/F_root_compute_non_ptr_ain/stderr.txt.expect
deleted file mode 100644
index 3ffb17a..0000000
--- a/tests/F_root_compute_non_ptr_ain/stderr.txt.expect
+++ /dev/null
@@ -1 +0,0 @@
-root_compute_non_ptr_ain.rs:4:21: error: compute root() parameter must be a const pointer type
diff --git a/tests/F_root_compute_non_ptr_aout/root_compute_non_ptr_aout.rs b/tests/F_root_compute_non_ptr_aout/root_compute_non_ptr_aout.rs
deleted file mode 100644
index 351e335..0000000
--- a/tests/F_root_compute_non_ptr_aout/root_compute_non_ptr_aout.rs
+++ /dev/null
@@ -1,5 +0,0 @@
-#pragma version(1)
-#pragma rs java_package_name(foo)
-
-void root(const int *ain, int aout) {
-}
diff --git a/tests/F_root_compute_non_ptr_aout/stderr.txt.expect b/tests/F_root_compute_non_ptr_aout/stderr.txt.expect
deleted file mode 100644
index 0d256c0..0000000
--- a/tests/F_root_compute_non_ptr_aout/stderr.txt.expect
+++ /dev/null
@@ -1 +0,0 @@
-root_compute_non_ptr_aout.rs:4:31: error: compute root() parameter must be a pointer type
diff --git a/tests/F_root_compute_non_ptr_aout/stdout.txt.expect b/tests/F_root_compute_non_ptr_aout/stdout.txt.expect
deleted file mode 100644
index e69de29..0000000
--- a/tests/F_root_compute_non_ptr_aout/stdout.txt.expect
+++ /dev/null
diff --git a/tests/F_root_compute_non_ptr_usrData/stderr.txt.expect b/tests/F_root_compute_non_ptr_usrData/stderr.txt.expect
index 44940b8..766186f 100644
--- a/tests/F_root_compute_non_ptr_usrData/stderr.txt.expect
+++ b/tests/F_root_compute_non_ptr_usrData/stderr.txt.expect
@@ -1 +1 @@
-root_compute_non_ptr_usrData.rs:4:48: error: compute root() parameter must be a const pointer type
+root_compute_non_ptr_usrData.rs:4:44: error: Unexpected root() parameter type 'const int'
diff --git a/tests/F_root_compute_non_uint32_t_xyzar/stderr.txt.expect b/tests/F_root_compute_non_uint32_t_xyzar/stderr.txt.expect
index beaadfc..33f4c6e 100644
--- a/tests/F_root_compute_non_uint32_t_xyzar/stderr.txt.expect
+++ b/tests/F_root_compute_non_uint32_t_xyzar/stderr.txt.expect
@@ -1,4 +1,4 @@
-root_compute_non_uint32_t_xyzar.rs:5:15: error: compute root() parameter must be a uint32_t type
-root_compute_non_uint32_t_xyzar.rs:5:24: error: compute root() parameter must be a uint32_t type
-root_compute_non_uint32_t_xyzar.rs:5:34: error: compute root() parameter must be a uint32_t type
-root_compute_non_uint32_t_xyzar.rs:5:43: error: compute root() parameter must be a uint32_t type
+root_compute_non_uint32_t_xyzar.rs:5:11: error: Unexpected root() parameter type 'int'
+root_compute_non_uint32_t_xyzar.rs:5:18: error: Unexpected root() parameter type 'float'
+root_compute_non_uint32_t_xyzar.rs:5:27: error: Unexpected root() parameter type 'double'
+root_compute_non_uint32_t_xyzar.rs:5:37: error: Unexpected root() parameter type 'uchar'
diff --git a/tests/F_root_compute_non_void_ret/root_compute_non_void_ret.rs b/tests/F_root_compute_non_void_ret/root_compute_non_void_ret.rs
index fe0c3cc..b1dd4fa 100644
--- a/tests/F_root_compute_non_void_ret/root_compute_non_void_ret.rs
+++ b/tests/F_root_compute_non_void_ret/root_compute_non_void_ret.rs
@@ -1,6 +1,6 @@
 #pragma version(1)
 #pragma rs java_package_name(foo)
 
-int root(const int *ain, int *aout, const void *usrData) {
+int root(const int *in, int *out, const void *usrData) {
     return 10;
 }
diff --git a/tests/F_root_compute_really_bad/stderr.txt.expect b/tests/F_root_compute_really_bad/stderr.txt.expect
index 7babf29..9e64c99 100644
--- a/tests/F_root_compute_really_bad/stderr.txt.expect
+++ b/tests/F_root_compute_really_bad/stderr.txt.expect
@@ -1,9 +1,9 @@
 root_compute_really_bad.rs:4:5: error: compute root() is required to return a void type
-root_compute_really_bad.rs:4:14: error: compute root() parameter must be a const pointer type
-root_compute_really_bad.rs:4:23: error: compute root() parameter must be a pointer type
-root_compute_really_bad.rs:4:33: error: compute root() parameter must be a const pointer type
-root_compute_really_bad.rs:4:48: error: compute root() parameter must be a uint32_t type
-root_compute_really_bad.rs:4:58: error: compute root() parameter must be a uint32_t type
-root_compute_really_bad.rs:4:67: error: compute root() parameter must be a uint32_t type
-root_compute_really_bad.rs:4:77: error: compute root() parameter must be a uint32_t type
-root_compute_really_bad.rs:5:19: error: too many compute root() parameters specified
+root_compute_really_bad.rs:4:10: error: Unexpected root() parameter type 'int'
+root_compute_really_bad.rs:4:19: error: Unexpected root() parameter type 'int'
+root_compute_really_bad.rs:4:29: error: Unexpected root() parameter type 'int'
+root_compute_really_bad.rs:4:42: error: Unexpected root() parameter type 'float'
+root_compute_really_bad.rs:4:51: error: Unexpected root() parameter type 'double'
+root_compute_really_bad.rs:4:61: error: Unexpected root() parameter type 'uchar'
+root_compute_really_bad.rs:4:70: error: Unexpected root() parameter type 'ushort'
+root_compute_really_bad.rs:4:5: error: Compute root() must have at least one parameter for in or out
diff --git a/tests/F_root_compute_too_many_args/root_compute_too_many_args.rs b/tests/F_root_compute_too_many_args/root_compute_too_many_args.rs
index 655fba3..7400f87 100644
--- a/tests/F_root_compute_too_many_args/root_compute_too_many_args.rs
+++ b/tests/F_root_compute_too_many_args/root_compute_too_many_args.rs
@@ -1,7 +1,7 @@
 #pragma version(1)
 #pragma rs java_package_name(foo)
 
-void root(const int *ain, int *aout, const void *usrData,
+void root(const int *in, int *out, const void *usrData,
           uint32_t x, uint32_t y, uint32_t z, uint32_t ar,
           uint32_t extra1, uint32_t extra2) {
 }
diff --git a/tests/F_root_compute_too_many_args/stderr.txt.expect b/tests/F_root_compute_too_many_args/stderr.txt.expect
index 78b3266..70f9499 100644
--- a/tests/F_root_compute_too_many_args/stderr.txt.expect
+++ b/tests/F_root_compute_too_many_args/stderr.txt.expect
@@ -1 +1,2 @@
-root_compute_too_many_args.rs:6:20: error: too many compute root() parameters specified
+root_compute_too_many_args.rs:6:20: error: Unexpected root() parameter 'extra1' of type 'uint32_t'
+root_compute_too_many_args.rs:6:37: error: Unexpected root() parameter 'extra2' of type 'uint32_t'