Type-checking for ForEach with RS types.

BUG=4203264

Change-Id: I90e54cdf22fea76ffde9548617fb7b492ba9d643
diff --git a/slang_rs_reflection.cpp b/slang_rs_reflection.cpp
index 17c562a..22343ab 100644
--- a/slang_rs_reflection.cpp
+++ b/slang_rs_reflection.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright 2010, The Android Open Source Project
+ * Copyright 2010-2011, The Android Open Source Project
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -130,8 +130,7 @@
   const char **BaseElement = NULL;
 
   switch (EVT->getType()) {
-    case RSExportPrimitiveType::DataTypeSigned8:
-    case RSExportPrimitiveType::DataTypeBoolean: {
+    case RSExportPrimitiveType::DataTypeSigned8: {
       BaseElement = VectorTypeJavaNameMap[0];
       break;
     }
@@ -160,7 +159,7 @@
       break;
     }
     default: {
-      slangAssert(false && "RSReflection::genElementTypeName : Unsupported "
+      slangAssert(false && "RSReflection::GetVectorTypeName : Unsupported "
                            "vector element data type");
       break;
     }
@@ -168,7 +167,78 @@
 
   slangAssert((EVT->getNumElement() > 1) &&
               (EVT->getNumElement() <= 4) &&
-              "Number of element in vector type is invalid");
+              "Number of elements in vector type is invalid");
+
+  return BaseElement[EVT->getNumElement() - 2];
+}
+
+static const char *GetVectorElementName(const RSExportVectorType *EVT) {
+  static const char *VectorElementNameMap[][3] = {
+    /* 0 */ { "U8_2",   "U8_3",   "U8_4" },
+    /* 1 */ { "I8_2",   "I8_3",   "I8_4" },
+    /* 2 */ { "U16_2",  "U16_3",  "U16_4" },
+    /* 3 */ { "I16_2",  "I16_3",  "I16_4" },
+    /* 4 */ { "U32_2",  "U32_3",  "U32_4" },
+    /* 5 */ { "I32_2",  "I32_3",  "I32_4" },
+    /* 6 */ { "U64_2",  "U64_3",  "U64_4" },
+    /* 7 */ { "I64_2",  "I64_3",  "I64_4" },
+    /* 8 */ { "F32_2",  "F32_3",  "F32_4" },
+    /* 9 */ { "F64_2",  "F64_3",  "F64_4" },
+  };
+
+  const char **BaseElement = NULL;
+
+  switch (EVT->getType()) {
+    case RSExportPrimitiveType::DataTypeUnsigned8: {
+      BaseElement = VectorElementNameMap[0];
+      break;
+    }
+    case RSExportPrimitiveType::DataTypeSigned8: {
+      BaseElement = VectorElementNameMap[1];
+      break;
+    }
+    case RSExportPrimitiveType::DataTypeUnsigned16: {
+      BaseElement = VectorElementNameMap[2];
+      break;
+    }
+    case RSExportPrimitiveType::DataTypeSigned16: {
+      BaseElement = VectorElementNameMap[3];
+      break;
+    }
+    case RSExportPrimitiveType::DataTypeUnsigned32: {
+      BaseElement = VectorElementNameMap[4];
+      break;
+    }
+    case RSExportPrimitiveType::DataTypeSigned32: {
+      BaseElement = VectorElementNameMap[5];
+      break;
+    }
+    case RSExportPrimitiveType::DataTypeUnsigned64: {
+      BaseElement = VectorElementNameMap[6];
+      break;
+    }
+    case RSExportPrimitiveType::DataTypeSigned64: {
+      BaseElement = VectorElementNameMap[7];
+      break;
+    }
+    case RSExportPrimitiveType::DataTypeFloat32: {
+      BaseElement = VectorElementNameMap[8];
+      break;
+    }
+    case RSExportPrimitiveType::DataTypeFloat64: {
+      BaseElement = VectorElementNameMap[9];
+      break;
+    }
+    default: {
+      slangAssert(false && "RSReflection::GetVectorElementName : Unsupported "
+                           "vector element data type");
+      break;
+    }
+  }
+
+  slangAssert((EVT->getNumElement() > 1) &&
+              (EVT->getNumElement() <= 4) &&
+              "Number of elements in vector type is invalid");
 
   return BaseElement[EVT->getNumElement() - 2];
 }
@@ -488,6 +558,51 @@
     return NULL;
 }
 
+static const char *GetElementJavaTypeName(RSExportPrimitiveType::DataType DT) {
+  static const char *ElementJavaTypeNameMap[] = {
+    NULL,               // RSExportPrimitiveType::DataTypeFloat16
+    "F32",              // RSExportPrimitiveType::DataTypeFloat32
+    "F64",              // RSExportPrimitiveType::DataTypeFloat64
+    "I8",               // RSExportPrimitiveType::DataTypeSigned8
+    "I16",              // RSExportPrimitiveType::DataTypeSigned16
+    "I32",              // RSExportPrimitiveType::DataTypeSigned32
+    "I64",              // RSExportPrimitiveType::DataTypeSigned64
+    "U8",               // RSExportPrimitiveType::DataTypeUnsigned8
+    "U16",              // RSExportPrimitiveType::DataTypeUnsigned16
+    "U32",              // RSExportPrimitiveType::DataTypeUnsigned32
+    "U64",              // RSExportPrimitiveType::DataTypeUnsigned64
+    "BOOLEAN",          // RSExportPrimitiveType::DataTypeBoolean
+
+    "RGB_565",          // RSExportPrimitiveType::DataTypeUnsigned565
+    "RGBA_5551",        // RSExportPrimitiveType::DataTypeUnsigned5551
+    "RGBA_4444",        // RSExportPrimitiveType::DataTypeUnsigned4444
+
+    // DataTypeRSMatrix* must have been resolved in GetBuiltinElementConstruct()
+    NULL,               // (Dummy) RSExportPrimitiveType::DataTypeRSMatrix2x2
+    NULL,               // (Dummy) RSExportPrimitiveType::DataTypeRSMatrix3x3
+    NULL,               // (Dummy) RSExportPrimitiveType::DataTypeRSMatrix4x4
+
+    "ELEMENT",          // RSExportPrimitiveType::DataTypeRSElement
+    "TYPE",             // RSExportPrimitiveType::DataTypeRSType
+    "ALLOCATION",       // RSExportPrimitiveType::DataTypeRSAllocation
+    "SAMPLER",          // RSExportPrimitiveType::DataTypeRSSampler
+    "SCRIPT",           // RSExportPrimitiveType::DataTypeRSScript
+    "MESH",             // RSExportPrimitiveType::DataTypeRSMesh
+    "PROGRAM_FRAGMENT", // RSExportPrimitiveType::DataTypeRSProgramFragment
+    "PROGRAM_VERTEX",   // RSExportPrimitiveType::DataTypeRSProgramVertex
+    "PROGRAM_RASTER",   // RSExportPrimitiveType::DataTypeRSProgramRaster
+    "PROGRAM_STORE",    // RSExportPrimitiveType::DataTypeRSProgramStore
+    "FONT",             // RSExportPrimitiveType::DataTypeRSFont
+  };
+
+  if (static_cast<unsigned>(DT) <
+      (sizeof(ElementJavaTypeNameMap) / sizeof(const char*)))
+    return ElementJavaTypeNameMap[DT];
+  else
+    return NULL;
+}
+
+
 /********************** Methods to generate script class **********************/
 bool RSReflection::genScriptClass(Context &C,
                                   const std::string &ClassName,
@@ -551,8 +666,32 @@
       genInitExportVariable(C, EV->getType(), EV->getName(), EV->getInit());
   }
 
+  for (RSContext::const_export_foreach_iterator
+           I = mRSContext->export_foreach_begin(),
+           E = mRSContext->export_foreach_end();
+       I != E;
+       I++) {
+    const RSExportForEach *EF = *I;
+
+    const RSExportType *IET = EF->getInType();
+    if (IET) {
+      genTypeInstance(C, IET);
+    }
+    const RSExportType *OET = EF->getOutType();
+    if (OET) {
+      genTypeInstance(C, OET);
+    }
+  }
+
   C.endFunction();
 
+  for (std::set<std::string>::iterator I = C.mTypesToCheck.begin(),
+                                       E = C.mTypesToCheck.end();
+       I != E;
+       I++) {
+    C.indent() << "private Element __" << *I << ";" << std::endl;
+  }
+
   return;
 }
 
@@ -812,10 +951,8 @@
 
   if (EF->hasIn())
     Args.push_back(std::make_pair("Allocation", "ain"));
-    //Args.push_back(std::make_pair(GetTypeName(EF->getInType()), "ain"));
   if (EF->hasOut())
     Args.push_back(std::make_pair("Allocation", "aout"));
-    //Args.push_back(std::make_pair(GetTypeName(EF->getOutType()), "aout"));
 
   const RSExportRecordType *ERT = EF->getParamPacketType();
   if (ERT) {
@@ -854,7 +991,8 @@
     C.indent() << "    (tIn.getZ() != tOut.getZ()) ||" << std::endl;
     C.indent() << "    (tIn.hasFaces() != tOut.hasFaces()) ||" << std::endl;
     C.indent() << "    (tIn.hasMipmaps() != tOut.hasMipmaps())) {" << std::endl;
-    C.indent() << "    throw new RSRuntimeException(\"Dimension mismatch\");";
+    C.indent() << "    throw new RSRuntimeException(\"Dimension mismatch "
+               << "between input and output parameters!\");";
     C.out()    << std::endl;
     C.indent() << "}" << std::endl;
   }
@@ -888,10 +1026,129 @@
   return;
 }
 
+void RSReflection::genTypeInstance(Context &C,
+                                   const RSExportType *ET) {
+  if (ET->getClass() == RSExportType::ExportClassPointer) {
+    const RSExportPointerType *EPT =
+        static_cast<const RSExportPointerType*>(ET);
+    ET = EPT->getPointeeType();
+    switch (ET->getClass()) {
+      case RSExportType::ExportClassPrimitive: {
+        const RSExportPrimitiveType *EPT =
+            static_cast<const RSExportPrimitiveType*>(ET);
+        slangAssert(EPT);
+
+        switch (EPT->getKind()) {
+          case RSExportPrimitiveType::DataKindPixelL:
+          case RSExportPrimitiveType::DataKindPixelA:
+          case RSExportPrimitiveType::DataKindPixelLA:
+          case RSExportPrimitiveType::DataKindPixelRGB:
+          case RSExportPrimitiveType::DataKindPixelRGBA: {
+            break;
+          }
+
+          case RSExportPrimitiveType::DataKindUser:
+          default: {
+            std::string TypeName = GetElementJavaTypeName(EPT->getType());
+            if (C.mTypesToCheck.find(TypeName) == C.mTypesToCheck.end()) {
+              C.indent() << "__" << TypeName << " = Element." << TypeName
+                         << "(rs);" << std::endl;
+              C.mTypesToCheck.insert(TypeName);
+            }
+            break;
+          }
+        }
+        break;
+      }
+
+      case RSExportType::ExportClassVector: {
+        const RSExportVectorType *EVT =
+            static_cast<const RSExportVectorType*>(ET);
+        slangAssert(EVT);
+
+        const char *TypeName = GetVectorElementName(EVT);
+        if (C.mTypesToCheck.find(TypeName) == C.mTypesToCheck.end()) {
+          C.indent() << "__" << TypeName << " = Element." << TypeName
+                     << "(rs);" << std::endl;
+          C.mTypesToCheck.insert(TypeName);
+        }
+        break;
+      }
+
+      case RSExportType::ExportClassRecord: {
+        const RSExportRecordType *ERT =
+            static_cast<const RSExportRecordType*>(ET);
+        slangAssert(ERT);
+
+        std::string ClassName = RS_TYPE_CLASS_NAME_PREFIX + ERT->getName();
+        if (C.mTypesToCheck.find(ClassName) == C.mTypesToCheck.end()) {
+          C.indent() << "__" << ClassName << " = " << ClassName <<
+                        ".createElement(rs);" << std::endl;
+          C.mTypesToCheck.insert(ClassName);
+        }
+        break;
+      }
+
+      default:
+        break;
+    }
+  }
+}
+
 void RSReflection::genTypeCheck(Context &C,
                                 const RSExportType *ET,
                                 const char *VarName) {
   C.indent() << "// check " << VarName << std::endl;
+
+  if (ET->getClass() == RSExportType::ExportClassPointer) {
+    const RSExportPointerType *EPT =
+        static_cast<const RSExportPointerType*>(ET);
+    ET = EPT->getPointeeType();
+  }
+
+  std::string TypeName;
+
+  switch (ET->getClass()) {
+    case RSExportType::ExportClassPrimitive: {
+      const RSExportPrimitiveType *EPT =
+          static_cast<const RSExportPrimitiveType*>(ET);
+      slangAssert(EPT);
+
+      if (EPT->getKind() == RSExportPrimitiveType::DataKindUser) {
+        TypeName = GetElementJavaTypeName(EPT->getType());
+      }
+      break;
+    }
+
+    case RSExportType::ExportClassVector: {
+      const RSExportVectorType *EVT =
+          static_cast<const RSExportVectorType*>(ET);
+      slangAssert(EVT);
+      TypeName = GetVectorElementName(EVT);
+      break;
+    }
+
+    case RSExportType::ExportClassRecord: {
+      const RSExportRecordType *ERT =
+          static_cast<const RSExportRecordType*>(ET);
+      slangAssert(ERT);
+      TypeName = RS_TYPE_CLASS_NAME_PREFIX + ERT->getName();
+      break;
+    }
+
+    default:
+      break;
+  }
+
+  if (!TypeName.empty()) {
+    C.indent() << "if (!" << VarName
+               << ".getType().getElement().isCompatible(__"
+               << TypeName << ")) {" << std::endl;
+    C.indent() << "    throw new RSRuntimeException(\"Type mismatch with "
+               << TypeName << "!\");" << std::endl;
+    C.indent() << "}" << std::endl;
+  }
+
   return;
 }
 
diff --git a/slang_rs_reflection.h b/slang_rs_reflection.h
index 2da7484..29e9f18 100644
--- a/slang_rs_reflection.h
+++ b/slang_rs_reflection.h
@@ -1,5 +1,5 @@
 /*
- * Copyright 2010, The Android Open Source Project
+ * Copyright 2010-2011, The Android Open Source Project
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -20,6 +20,7 @@
 #include <fstream>
 #include <iostream>
 #include <map>
+#include <set>
 #include <string>
 #include <vector>
 
@@ -102,6 +103,8 @@
     bool mUseStdout;
     mutable std::ofstream mOF;
 
+    std::set<std::string> mTypesToCheck;
+
     static const char *AccessModifierStr(AccessModifier AM);
 
     Context(const std::string &OutputPathBase,
@@ -241,9 +244,12 @@
   void genExportForEach(Context &C,
                         const RSExportForEach *EF);
 
-  void genTypeCheck(Context &C,
-                    const RSExportType *ET,
-                    const char *VarName);
+  static void genTypeCheck(Context &C,
+                           const RSExportType *ET,
+                           const char *VarName);
+
+  static void genTypeInstance(Context &C,
+                              const RSExportType *ET);
 
   bool genTypeClass(Context &C,
                     const RSExportRecordType *ERT,