Type checking for C++ API.

Bug: 10427951

Change-Id: I76a6093237a3b27a0c9e6ce38997cb1e0128efb9
diff --git a/slang_rs_reflection_base.cpp b/slang_rs_reflection_base.cpp
index 65da696..fe93054 100644
--- a/slang_rs_reflection_base.cpp
+++ b/slang_rs_reflection_base.cpp
@@ -194,5 +194,14 @@
   return tmp.str();
 }
 
+bool RSReflectionBase::addTypeNameForElement(
+    const std::string &TypeName) {
+  if (mTypesToCheck.find(TypeName) == mTypesToCheck.end()) {
+    mTypesToCheck.insert(TypeName);
+    return true;
+  } else {
+    return false;
+  }
+}
 
 }
diff --git a/slang_rs_reflection_base.h b/slang_rs_reflection_base.h
index 8667e0d..100f5eb 100644
--- a/slang_rs_reflection_base.h
+++ b/slang_rs_reflection_base.h
@@ -39,7 +39,8 @@
 protected:
     const RSContext *mRSContext;
 
-
+    // Generated RS Elements for type-checking code.
+    std::set<std::string> mTypesToCheck;
 
     RSReflectionBase(const RSContext *);
 
@@ -67,6 +68,7 @@
 
     bool writeFile(const std::string &filename, const std::vector< std::string > &txt);
 
+    bool addTypeNameForElement(const std::string &TypeName);
 
 private:
 
diff --git a/slang_rs_reflection_cpp.cpp b/slang_rs_reflection_cpp.cpp
index a6a921a..ea4227d 100644
--- a/slang_rs_reflection_cpp.cpp
+++ b/slang_rs_reflection_cpp.cpp
@@ -43,6 +43,8 @@
 
 #define RS_TYPE_ITEM_CLASS_NAME          "Item"
 
+#define RS_ELEM_PREFIX "__rs_elem_"
+
 static const char *GetMatrixTypeName(const RSExportMatrixType *EMT) {
   static const char *MatrixTypeCNameMap[] = {
     "rs_matrix2x2",
@@ -170,12 +172,32 @@
   for (RSContext::const_export_var_iterator I = mRSContext->export_vars_begin(),
          E = mRSContext->export_vars_end(); I != E; I++, slot++) {
     const RSExportVar *ev = *I;
-    RSReflectionTypeData rtd;
-    ev->getType()->convertToRTD(&rtd);
     if (!ev->isConst()) {
       write(GetTypeName(ev->getType()) + " __" + ev->getName() + ";");
     }
   }
+  for (RSContext::const_export_foreach_iterator
+           I = mRSContext->export_foreach_begin(),
+           E = mRSContext->export_foreach_end(); I != E; I++) {
+    const RSExportForEach *EF = *I;
+    const RSExportType *IET = EF->getInType();
+    const RSExportType *OET = EF->getOutType();
+    if (IET) {
+      genTypeInstanceFromPointer(IET);
+    }
+    if (OET) {
+      genTypeInstanceFromPointer(OET);
+    }
+  }
+
+  for (std::set<std::string>::iterator I = mTypesToCheck.begin(),
+                                       E = mTypesToCheck.end();
+       I != E;
+       I++) {
+    write("android::RSC::sp<const android::RSC::Element> " RS_ELEM_PREFIX
+          + *I + ";");
+  }
+
   decIndent();
 
   write("public:");
@@ -298,7 +320,12 @@
      << ", \"/data/data/" << packageName << "/app\", sizeof(\"" << packageName << "\")) {";
   write(ss);
   incIndent();
-  //...
+  for (std::set<std::string>::iterator I = mTypesToCheck.begin(),
+                                       E = mTypesToCheck.end();
+       I != E;
+       I++) {
+    write(RS_ELEM_PREFIX + *I + " = android::RSC::Element::" + *I + "(mRS);");
+  }
   decIndent();
   write("}");
   write("");
@@ -347,6 +374,19 @@
     write(tmp);
     tmp.str("");
 
+    const RSExportType *IET = ef->getInType();
+    const RSExportType *OET = ef->getOutType();
+
+    incIndent();
+    if (IET) {
+      genTypeCheck(IET, "ain");
+    }
+
+    if (OET) {
+      genTypeCheck(OET, "aout");
+    }
+    decIndent();
+
     std::string FieldPackerName = ef->getName() + "_fp";
     if (ERT) {
       if (genCreateFieldPacker(ERT, FieldPackerName.c_str())) {
@@ -367,6 +407,7 @@
       tmp << "NULL, ";
     }
 
+    // FIXME (no support for usrData with C++ kernels)
     tmp << "NULL, 0);";
     write(tmp);
 
@@ -596,8 +637,6 @@
     case RSExportType::ExportClassVector:
     case RSExportType::ExportClassPointer:
     case RSExportType::ExportClassMatrix: {
-      RSReflectionTypeData rtd;
-      ET->convertToRTD(&rtd);
       ss << "    " << FieldPackerName << ".add(" << VarName << ");";
       write(ss);
       break;
@@ -690,4 +729,79 @@
   }
 }
 
+
+void RSReflectionCpp::genTypeCheck(const RSExportType *ET,
+                                   const char *VarName) {
+  stringstream tmp;
+  tmp << "// Type check for " << VarName;
+  write(tmp);
+  tmp.str("");
+
+  if (ET->getClass() == RSExportType::ExportClassPointer) {
+    const RSExportPointerType *EPT =
+        static_cast<const RSExportPointerType*>(ET);
+    ET = EPT->getPointeeType();
+  }
+
+  std::string TypeName;
+  switch (ET->getClass()) {
+    case RSExportType::ExportClassPrimitive:
+    case RSExportType::ExportClassVector:
+    case RSExportType::ExportClassRecord: {
+      TypeName = ET->getElementName();
+      break;
+    }
+
+    default:
+      break;
+  }
+
+  if (!TypeName.empty()) {
+    //tmp << "// TypeName: " << TypeName;
+    tmp << "if (!" << VarName
+        << "->getType()->getElement()->isCompatible("
+        << RS_ELEM_PREFIX
+        << TypeName << ")) {";
+    write(tmp);
+
+    incIndent();
+    write("mRS->throwError(RS_ERROR_RUNTIME_ERROR, "
+          "\"Incompatible type\");");
+    write("return;");
+    decIndent();
+
+    write("}");
+  }
+}
+
+void RSReflectionCpp::genTypeInstanceFromPointer(const RSExportType *ET) {
+  if (ET->getClass() == RSExportType::ExportClassPointer) {
+    // For pointer parameters to original forEach kernels.
+    const RSExportPointerType *EPT =
+        static_cast<const RSExportPointerType*>(ET);
+    genTypeInstance(EPT->getPointeeType());
+  } else {
+    // For handling pass-by-value kernel parameters.
+    genTypeInstance(ET);
+  }
+}
+
+void RSReflectionCpp::genTypeInstance(const RSExportType *ET) {
+  switch (ET->getClass()) {
+    case RSExportType::ExportClassPrimitive:
+    case RSExportType::ExportClassVector:
+    case RSExportType::ExportClassConstantArray:
+    case RSExportType::ExportClassRecord: {
+      std::string TypeName = ET->getElementName();
+      addTypeNameForElement(TypeName);
+      break;
+    }
+
+    default:
+      break;
+  }
+}
+
+
+
 }  // namespace slang
diff --git a/slang_rs_reflection_cpp.h b/slang_rs_reflection_cpp.h
index de7e824..948d624 100644
--- a/slang_rs_reflection_cpp.h
+++ b/slang_rs_reflection_cpp.h
@@ -19,6 +19,9 @@
 
 #include "slang_rs_reflection_base.h"
 
+#include <set>
+#include <string>
+
 namespace slang {
 
 class RSReflectionCpp : public RSReflectionBase {
@@ -40,6 +43,7 @@
     mNextExportVarSlot = 0;
     mNextExportFuncSlot = 0;
     mNextExportForEachSlot = 0;
+    mTypesToCheck.clear();
   }
 
   inline unsigned int getNextExportVarSlot() {
@@ -83,6 +87,14 @@
   void genPackVarOfType(const RSExportType *ET,
                         const char *VarName,
                         const char *FieldPackerName);
+
+  // Generate a runtime type check for VarName.
+  void genTypeCheck(const RSExportType *ET, const char *VarName);
+
+  // Generate a type instance for a given forEach argument type.
+  void genTypeInstanceFromPointer(const RSExportType *ET);
+  void genTypeInstance(const RSExportType *ET);
+
 };  // class RSReflectionCpp
 
 }   // namespace slang