Simple support for reflecting rsForEach().

BUG=4203264

Change-Id: Idf722ee3fb07c8e46ac0c4628e753ff2fa6840cf
diff --git a/Android.mk b/Android.mk
index 1e0f5ce..2bb4fd9 100644
--- a/Android.mk
+++ b/Android.mk
@@ -209,10 +209,10 @@
 	slang_rs_export_element.cpp	\
 	slang_rs_export_var.cpp	\
 	slang_rs_export_func.cpp	\
+	slang_rs_export_foreach.cpp \
 	slang_rs_object_ref_count.cpp	\
 	slang_rs_reflection.cpp \
 	slang_rs_reflect_utils.cpp  \
-	slang_rs_root.cpp \
 	slang_rs_metadata_spec_encoder.cpp
 
 LOCAL_STATIC_LIBRARIES :=	\
diff --git a/slang_rs_context.cpp b/slang_rs_context.cpp
index 87b0001..5ee2a81 100644
--- a/slang_rs_context.cpp
+++ b/slang_rs_context.cpp
@@ -35,13 +35,13 @@
 
 #include "slang.h"
 #include "slang_assert.h"
+#include "slang_rs_export_foreach.h"
 #include "slang_rs_export_func.h"
 #include "slang_rs_export_type.h"
 #include "slang_rs_export_var.h"
 #include "slang_rs_exportable.h"
 #include "slang_rs_pragma_handler.h"
 #include "slang_rs_reflection.h"
-#include "slang_rs_root.h"
 
 namespace slang {
 
@@ -114,9 +114,19 @@
     return false;
   }
 
-  // Do not reflect specialized RS functions like init/root.
-  if (RSRoot::isSpecialRSFunc(FD)) {
-    if (!RSRoot::validateSpecialFuncDecl(getDiagnostics(), FD)) {
+  if (RSExportForEach::isRSForEachFunc(FD)) {
+    if (!RSExportForEach::validateSpecialFuncDecl(getDiagnostics(), FD)) {
+      return false;
+    }
+    RSExportForEach *EFE = RSExportForEach::Create(this, FD);
+    if (EFE == NULL)
+      return false;
+    else
+      mExportForEach.push_back(EFE);
+    return true;
+  } else if (RSExportForEach::isSpecialRSFunc(FD)) {
+    // Do not reflect specialized RS functions like init/root.
+    if (!RSExportForEach::validateSpecialFuncDecl(getDiagnostics(), FD)) {
       return false;
     }
     return true;
@@ -178,6 +188,11 @@
 
 bool RSContext::processExport() {
   bool valid = true;
+
+  if (getDiagnostics()->hasErrorOccurred()) {
+    return false;
+  }
+
   // Export variable
   clang::TranslationUnitDecl *TUDecl = mCtx.getTranslationUnitDecl();
   for (clang::DeclContext::decl_iterator DI = TUDecl->decls_begin(),
diff --git a/slang_rs_context.h b/slang_rs_context.h
index 91b1aa1..c49f2d5 100644
--- a/slang_rs_context.h
+++ b/slang_rs_context.h
@@ -48,6 +48,7 @@
   class RSExportable;
   class RSExportVar;
   class RSExportFunc;
+  class RSExportForEach;
   class RSExportType;
 
 class RSContext {
@@ -59,6 +60,7 @@
   typedef std::list<RSExportable*> ExportableList;
   typedef std::list<RSExportVar*> ExportVarList;
   typedef std::list<RSExportFunc*> ExportFuncList;
+  typedef std::list<RSExportForEach*> ExportForEachList;
   typedef llvm::StringMap<RSExportType*> ExportTypeMap;
 
  private:
@@ -88,6 +90,7 @@
 
   ExportVarList mExportVars;
   ExportFuncList mExportFuncs;
+  ExportForEachList mExportForEach;
   ExportTypeMap mExportTypes;
 
  public:
@@ -162,6 +165,15 @@
   }
   inline bool hasExportFunc() const { return !mExportFuncs.empty(); }
 
+  typedef ExportForEachList::const_iterator const_export_foreach_iterator;
+  const_export_foreach_iterator export_foreach_begin() const {
+    return mExportForEach.begin();
+  }
+  const_export_foreach_iterator export_foreach_end() const {
+    return mExportForEach.end();
+  }
+  inline bool hasExportForEach() const { return !mExportForEach.empty(); }
+
   typedef ExportTypeMap::iterator export_type_iterator;
   typedef ExportTypeMap::const_iterator const_export_type_iterator;
   export_type_iterator export_types_begin() { return mExportTypes.begin(); }
diff --git a/slang_rs_root.cpp b/slang_rs_export_foreach.cpp
similarity index 61%
rename from slang_rs_root.cpp
rename to slang_rs_export_foreach.cpp
index 2309a8d..5286457 100644
--- a/slang_rs_root.cpp
+++ b/slang_rs_export_foreach.cpp
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 
-#include "slang_rs_root.h"
+#include "slang_rs_export_foreach.h"
 
 #include <string>
 
@@ -26,22 +26,109 @@
 
 #include "slang_assert.h"
 #include "slang_rs_context.h"
+#include "slang_rs_export_type.h"
 
 namespace slang {
 
-RSRoot *RSRoot::Create(RSContext *Context, const clang::FunctionDecl *FD) {
+RSExportForEach *RSExportForEach::Create(RSContext *Context,
+                                         const clang::FunctionDecl *FD) {
   llvm::StringRef Name = FD->getName();
-  RSRoot *F;
+  RSExportForEach *F;
 
   slangAssert(!Name.empty() && "Function must have a name");
 
-  F = new RSRoot(Context, Name, FD);
+  F = new RSExportForEach(Context, Name, FD);
+
+  F->numParams = FD->getNumParams();
+
+  if (F->numParams == 0) {
+    slangAssert(false && "Should have at least one parameter for root");
+  }
+
+  clang::ASTContext &Ctx = Context->getASTContext();
+
+  std::string Id(DUMMY_RS_TYPE_NAME_PREFIX"helper_foreach_param:");
+  Id.append(F->getName()).append(DUMMY_RS_TYPE_NAME_POSTFIX);
+
+  clang::RecordDecl *RD =
+      clang::RecordDecl::Create(Ctx, clang::TTK_Struct,
+                                Ctx.getTranslationUnitDecl(),
+                                clang::SourceLocation(),
+                                clang::SourceLocation(),
+                                &Ctx.Idents.get(Id));
+
+  // Extract the usrData parameter (if we have one)
+  if (F->numParams >= 3) {
+    const clang::ParmVarDecl *PVD = FD->getParamDecl(2);
+    clang::QualType QT = PVD->getType().getCanonicalType();
+    slangAssert(QT->isPointerType() &&
+                QT->getPointeeType().isConstQualified());
+
+    const clang::ASTContext &C = Context->getASTContext();
+    if (QT->getPointeeType().getCanonicalType().getUnqualifiedType() ==
+        C.VoidTy) {
+      // In the case of using const void*, we can't reflect an appopriate
+      // Java type, so we fall back to just reflecting the ain/aout parameters
+      F->numParams = 2;
+    } else {
+      llvm::StringRef ParamName = PVD->getName();
+      clang::FieldDecl *FD =
+          clang::FieldDecl::Create(Ctx,
+                                   RD,
+                                   clang::SourceLocation(),
+                                   clang::SourceLocation(),
+                                   PVD->getIdentifier(),
+                                   QT->getPointeeType(),
+                                   NULL,
+                                   /* BitWidth = */NULL,
+                                   /* Mutable = */false);
+      RD->addDecl(FD);
+    }
+  }
+  RD->completeDefinition();
+
+  if (F->numParams >= 3) {
+    // Create an export type iff we have a valid usrData type
+    clang::QualType T = Ctx.getTagDeclType(RD);
+    slangAssert(!T.isNull());
+
+    RSExportType *ET =
+      RSExportType::Create(Context, T.getTypePtr());
+
+    if (ET == NULL) {
+      fprintf(stderr, "Failed to export the function %s. There's at least one "
+                      "parameter whose type is not supported by the "
+                      "reflection\n", F->getName().c_str());
+      delete F;
+      return NULL;
+    }
+
+    slangAssert((ET->getClass() == RSExportType::ExportClassRecord) &&
+           "Parameter packet must be a record");
+
+    F->mParamPacketType = static_cast<RSExportRecordType *>(ET);
+  }
 
   return F;
 }
 
-bool RSRoot::validateSpecialFuncDecl(clang::Diagnostic *Diags,
-                                     const clang::FunctionDecl *FD) {
+bool RSExportForEach::isRSForEachFunc(const clang::FunctionDecl *FD) {
+  // We currently support only compute root() being exported via forEach
+  if (!isRootRSFunc(FD)) {
+    return false;
+  }
+
+  const clang::ASTContext &C = FD->getASTContext();
+  if (FD->getNumParams() == 0 &&
+      FD->getResultType().getCanonicalType() == C.IntTy) {
+    // Graphics compute function
+    return false;
+  }
+  return true;
+}
+
+bool RSExportForEach::validateSpecialFuncDecl(clang::Diagnostic *Diags,
+                                              const clang::FunctionDecl *FD) {
   if (!FD) {
     return false;
   }
diff --git a/slang_rs_export_foreach.h b/slang_rs_export_foreach.h
new file mode 100644
index 0000000..22cc180
--- /dev/null
+++ b/slang_rs_export_foreach.h
@@ -0,0 +1,111 @@
+/*
+ * Copyright 2011, The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef _FRAMEWORKS_COMPILE_SLANG_SLANG_RS_EXPORT_FOREACH_H_  // NOLINT
+#define _FRAMEWORKS_COMPILE_SLANG_SLANG_RS_EXPORT_FOREACH_H_
+
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/raw_ostream.h"
+
+#include "clang/AST/Decl.h"
+
+#include "slang_assert.h"
+#include "slang_rs_context.h"
+#include "slang_rs_exportable.h"
+#include "slang_rs_export_type.h"
+
+namespace clang {
+  class FunctionDecl;
+}  // namespace clang
+
+namespace slang {
+
+// Base class for reflecting control-side forEach (currently for root()
+// functions that fit appropriate criteria)
+class RSExportForEach : public RSExportable {
+ private:
+  std::string mName;
+  RSExportRecordType *mParamPacketType;
+  size_t numParams;
+
+  RSExportForEach(RSContext *Context, const llvm::StringRef &Name,
+         const clang::FunctionDecl *FD)
+    : RSExportable(Context, RSExportable::EX_FOREACH),
+      mName(Name.data(), Name.size()),
+      mParamPacketType(NULL),
+      numParams(0) {
+    return;
+  }
+
+ public:
+  static RSExportForEach *Create(RSContext *Context,
+                                 const clang::FunctionDecl *FD);
+
+  inline const std::string &getName() const {
+    return mName;
+  }
+
+  inline size_t getNumParameters() const {
+    return numParams;
+  }
+
+  inline const RSExportRecordType *getParamPacketType() const
+    { return mParamPacketType; }
+
+  typedef RSExportRecordType::const_field_iterator const_param_iterator;
+
+  inline const_param_iterator params_begin() const {
+    slangAssert((mParamPacketType != NULL) &&
+                "Get parameter from export foreach having no parameter!");
+    return mParamPacketType->fields_begin();
+  }
+  inline const_param_iterator params_end() const {
+    slangAssert((mParamPacketType != NULL) &&
+                "Get parameter from export foreach having no parameter!");
+    return mParamPacketType->fields_end();
+  }
+
+  inline static bool isInitRSFunc(const clang::FunctionDecl *FD) {
+    if (!FD) {
+      return false;
+    }
+    const llvm::StringRef Name = FD->getName();
+    static llvm::StringRef FuncInit("init");
+    return Name.equals(FuncInit);
+  }
+
+  inline static bool isRootRSFunc(const clang::FunctionDecl *FD) {
+    if (!FD) {
+      return false;
+    }
+    const llvm::StringRef Name = FD->getName();
+    static llvm::StringRef FuncRoot("root");
+    return Name.equals(FuncRoot);
+  }
+
+  static bool isRSForEachFunc(const clang::FunctionDecl *FD);
+
+  inline static bool isSpecialRSFunc(const clang::FunctionDecl *FD) {
+    return isRootRSFunc(FD) || isInitRSFunc(FD);
+  }
+
+  static bool validateSpecialFuncDecl(clang::Diagnostic *Diags,
+                                      const clang::FunctionDecl *FD);
+};  // RSExportForEach
+
+}  // namespace slang
+
+#endif  // _FRAMEWORKS_COMPILE_SLANG_SLANG_RS_EXPORT_FOREACH_H_  NOLINT
diff --git a/slang_rs_exportable.h b/slang_rs_exportable.h
index c64dbf6..0a2c066 100644
--- a/slang_rs_exportable.h
+++ b/slang_rs_exportable.h
@@ -26,7 +26,8 @@
   enum Kind {
     EX_FUNC,
     EX_TYPE,
-    EX_VAR
+    EX_VAR,
+    EX_FOREACH
   };
 
  private:
diff --git a/slang_rs_reflection.cpp b/slang_rs_reflection.cpp
index 756e577..dcee519 100644
--- a/slang_rs_reflection.cpp
+++ b/slang_rs_reflection.cpp
@@ -31,6 +31,7 @@
 #include "os_sep.h"
 #include "slang_rs_context.h"
 #include "slang_rs_export_var.h"
+#include "slang_rs_export_foreach.h"
 #include "slang_rs_export_func.h"
 #include "slang_rs_reflect_utils.h"
 #include "slang_utils.h"
@@ -50,6 +51,7 @@
 #define RS_EXPORT_VAR_PREFIX             "mExportVar_"
 
 #define RS_EXPORT_FUNC_INDEX_PREFIX      "mExportFuncIdx_"
+#define RS_EXPORT_FOREACH_INDEX_PREFIX   "mExportForEachIdx_"
 
 #define RS_EXPORT_VAR_ALLOCATION_PREFIX  "mAlloction_"
 #define RS_EXPORT_VAR_DATA_STORAGE_PREFIX "mData_"
@@ -501,6 +503,13 @@
        I++)
     genExportVariable(C, *I);
 
+  // Reflect export for each functions
+  for (RSContext::const_export_foreach_iterator
+           I = mRSContext->export_foreach_begin(),
+           E = mRSContext->export_foreach_end();
+       I != E; I++)
+    genExportForEach(C, *I);
+
   // Reflect export function
   for (RSContext::const_export_func_iterator
            I = mRSContext->export_funcs_begin(),
@@ -786,6 +795,65 @@
   return;
 }
 
+void RSReflection::genExportForEach(Context &C, const RSExportForEach *EF) {
+  C.indent() << "private final static int "RS_EXPORT_FOREACH_INDEX_PREFIX
+             << EF->getName() << " = " << C.getNextExportForEachSlot() << ";"
+             << std::endl;
+
+  // for_each_*()
+  Context::ArgTy Args;
+
+  std::string FieldPackerName = EF->getName() + "_fp";
+  size_t numParams = EF->getNumParameters();
+
+  slangAssert(numParams >= 1);
+  Args.push_back(std::make_pair("Allocation", "ain"));
+  //GetTypeName(RSExportPrimitiveType::DataTypeRSAllocation), "ain");
+  if (numParams >= 2) {
+    Args.push_back(std::make_pair("Allocation", "aout"));
+  }
+  if (numParams >= 3) {
+    for (RSExportFunc::const_param_iterator I = EF->params_begin(),
+             E = EF->params_end();
+         I != E;
+         I++) {
+      Args.push_back(std::make_pair(GetTypeName((*I)->getType()),
+                                    (*I)->getName()));
+    }
+  }
+
+  C.startFunction(Context::AM_Public,
+                  false,
+                  "void",
+                  "forEach_" + EF->getName(),
+                  Args);
+
+  if (numParams >= 3) {
+    const RSExportRecordType *ERT = EF->getParamPacketType();
+
+    if (genCreateFieldPacker(C, ERT, FieldPackerName.c_str()))
+      genPackVarOfType(C, ERT, NULL, FieldPackerName.c_str());
+  }
+  C.indent() << "forEach("RS_EXPORT_FOREACH_INDEX_PREFIX << EF->getName()
+             << ", ain";
+
+  switch (numParams) {
+    case 1:
+      C.out() << ", null, null);" << std::endl;
+      break;
+    case 2:
+      C.out() << ", aout, null);" << std::endl;
+      break;
+    case 3:
+    default:
+      C.out() << ", aout, " << FieldPackerName << ");" << std::endl;
+      break;
+  }
+
+  C.endFunction();
+  return;
+}
+
 void RSReflection::genPrimitiveTypeExportVariable(
     Context &C,
     const RSExportVar *EV) {
diff --git a/slang_rs_reflection.h b/slang_rs_reflection.h
index 2f9e18c..c26f82f 100644
--- a/slang_rs_reflection.h
+++ b/slang_rs_reflection.h
@@ -33,6 +33,7 @@
   class RSContext;
   class RSExportVar;
   class RSExportFunc;
+  class RSExportForEach;
 
 class RSReflection {
  private:
@@ -68,6 +69,7 @@
 
     int mNextExportVarSlot;
     int mNextExportFuncSlot;
+    int mNextExportForEachSlot;
 
     // A mapping from a field in a record type to its index in the rsType
     // instance. Only used when generates TypeClass (ScriptField_*).
@@ -83,6 +85,7 @@
       mPaddingFieldIndex = 1;
       mNextExportVarSlot = 0;
       mNextExportFuncSlot = 0;
+      mNextExportForEachSlot = 0;
       return;
     }
 
@@ -143,6 +146,7 @@
     inline int getNextExportVarSlot() { return mNextExportVarSlot++; }
 
     inline int getNextExportFuncSlot() { return mNextExportFuncSlot++; }
+    inline int getNextExportForEachSlot() { return mNextExportForEachSlot++; }
 
     // Will remove later due to field name information is not necessary for
     // C-reflect-to-Java
@@ -234,6 +238,9 @@
   void genExportFunction(Context &C,
                          const RSExportFunc *EF);
 
+  void genExportForEach(Context &C,
+                        const RSExportForEach *EF);
+
   bool genTypeClass(Context &C,
                     const RSExportRecordType *ERT,
                     std::string &ErrorMsg);
diff --git a/slang_rs_root.h b/slang_rs_root.h
deleted file mode 100644
index 1c7efa5..0000000
--- a/slang_rs_root.h
+++ /dev/null
@@ -1,93 +0,0 @@
-/*
- * Copyright 2011, The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef _FRAMEWORKS_COMPILE_SLANG_SLANG_RS_ROOT_H_  // NOLINT
-#define _FRAMEWORKS_COMPILE_SLANG_SLANG_RS_ROOT_H_
-
-#include "llvm/ADT/StringRef.h"
-#include "llvm/Support/raw_ostream.h"
-
-#include "clang/AST/Decl.h"
-
-#include "slang_assert.h"
-#include "slang_rs_context.h"
-
-namespace clang {
-  class FunctionDecl;
-}  // namespace clang
-
-namespace slang {
-
-// Base class for handling root() functions (including reflection of
-// control-side code for issuing rsForEach).
-class RSRoot {
- private:
-  std::string mName;
-  std::string mMangledName;
-  bool mShouldMangle;
-
-  RSRoot(RSContext *Context, const llvm::StringRef &Name,
-         const clang::FunctionDecl *FD)
-    : mName(Name.data(), Name.size()),
-      mMangledName(),
-      mShouldMangle(false) {
-    mShouldMangle = Context->getMangleContext().shouldMangleDeclName(FD);
-
-    if (mShouldMangle) {
-      llvm::raw_string_ostream BufStm(mMangledName);
-      Context->getMangleContext().mangleName(FD, BufStm);
-      BufStm.flush();
-    }
-
-    return;
-  }
-
- public:
-  static RSRoot *Create(RSContext *Context, const clang::FunctionDecl *FD);
-
-  inline const std::string &getName(bool mangle = true) const {
-    return (mShouldMangle && mangle) ? mMangledName : mName;
-  }
-
-  inline static bool isInitRSFunc(const clang::FunctionDecl *FD) {
-    if (!FD) {
-      return false;
-    }
-    const llvm::StringRef Name = FD->getName();
-    static llvm::StringRef FuncInit("init");
-    return Name.equals(FuncInit);
-  }
-
-  inline static bool isRootRSFunc(const clang::FunctionDecl *FD) {
-    if (!FD) {
-      return false;
-    }
-    const llvm::StringRef Name = FD->getName();
-    static llvm::StringRef FuncRoot("root");
-    return Name.equals(FuncRoot);
-  }
-
-  inline static bool isSpecialRSFunc(const clang::FunctionDecl *FD) {
-    return isRootRSFunc(FD) || isInitRSFunc(FD);
-  }
-
-  static bool validateSpecialFuncDecl(clang::Diagnostic *Diags,
-                                      const clang::FunctionDecl *FD);
-};  // RSRoot
-
-}  // namespace slang
-
-#endif  // _FRAMEWORKS_COMPILE_SLANG_SLANG_RS_ROOT_H_  NOLINT