Type-checking for ForEach with RS types.
BUG=4203264
Change-Id: I90e54cdf22fea76ffde9548617fb7b492ba9d643
diff --git a/slang_rs_reflection.cpp b/slang_rs_reflection.cpp
index 17c562a..22343ab 100644
--- a/slang_rs_reflection.cpp
+++ b/slang_rs_reflection.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright 2010, The Android Open Source Project
+ * Copyright 2010-2011, The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -130,8 +130,7 @@
const char **BaseElement = NULL;
switch (EVT->getType()) {
- case RSExportPrimitiveType::DataTypeSigned8:
- case RSExportPrimitiveType::DataTypeBoolean: {
+ case RSExportPrimitiveType::DataTypeSigned8: {
BaseElement = VectorTypeJavaNameMap[0];
break;
}
@@ -160,7 +159,7 @@
break;
}
default: {
- slangAssert(false && "RSReflection::genElementTypeName : Unsupported "
+ slangAssert(false && "RSReflection::GetVectorTypeName : Unsupported "
"vector element data type");
break;
}
@@ -168,7 +167,78 @@
slangAssert((EVT->getNumElement() > 1) &&
(EVT->getNumElement() <= 4) &&
- "Number of element in vector type is invalid");
+ "Number of elements in vector type is invalid");
+
+ return BaseElement[EVT->getNumElement() - 2];
+}
+
+static const char *GetVectorElementName(const RSExportVectorType *EVT) {
+ static const char *VectorElementNameMap[][3] = {
+ /* 0 */ { "U8_2", "U8_3", "U8_4" },
+ /* 1 */ { "I8_2", "I8_3", "I8_4" },
+ /* 2 */ { "U16_2", "U16_3", "U16_4" },
+ /* 3 */ { "I16_2", "I16_3", "I16_4" },
+ /* 4 */ { "U32_2", "U32_3", "U32_4" },
+ /* 5 */ { "I32_2", "I32_3", "I32_4" },
+ /* 6 */ { "U64_2", "U64_3", "U64_4" },
+ /* 7 */ { "I64_2", "I64_3", "I64_4" },
+ /* 8 */ { "F32_2", "F32_3", "F32_4" },
+ /* 9 */ { "F64_2", "F64_3", "F64_4" },
+ };
+
+ const char **BaseElement = NULL;
+
+ switch (EVT->getType()) {
+ case RSExportPrimitiveType::DataTypeUnsigned8: {
+ BaseElement = VectorElementNameMap[0];
+ break;
+ }
+ case RSExportPrimitiveType::DataTypeSigned8: {
+ BaseElement = VectorElementNameMap[1];
+ break;
+ }
+ case RSExportPrimitiveType::DataTypeUnsigned16: {
+ BaseElement = VectorElementNameMap[2];
+ break;
+ }
+ case RSExportPrimitiveType::DataTypeSigned16: {
+ BaseElement = VectorElementNameMap[3];
+ break;
+ }
+ case RSExportPrimitiveType::DataTypeUnsigned32: {
+ BaseElement = VectorElementNameMap[4];
+ break;
+ }
+ case RSExportPrimitiveType::DataTypeSigned32: {
+ BaseElement = VectorElementNameMap[5];
+ break;
+ }
+ case RSExportPrimitiveType::DataTypeUnsigned64: {
+ BaseElement = VectorElementNameMap[6];
+ break;
+ }
+ case RSExportPrimitiveType::DataTypeSigned64: {
+ BaseElement = VectorElementNameMap[7];
+ break;
+ }
+ case RSExportPrimitiveType::DataTypeFloat32: {
+ BaseElement = VectorElementNameMap[8];
+ break;
+ }
+ case RSExportPrimitiveType::DataTypeFloat64: {
+ BaseElement = VectorElementNameMap[9];
+ break;
+ }
+ default: {
+ slangAssert(false && "RSReflection::GetVectorElementName : Unsupported "
+ "vector element data type");
+ break;
+ }
+ }
+
+ slangAssert((EVT->getNumElement() > 1) &&
+ (EVT->getNumElement() <= 4) &&
+ "Number of elements in vector type is invalid");
return BaseElement[EVT->getNumElement() - 2];
}
@@ -488,6 +558,51 @@
return NULL;
}
+static const char *GetElementJavaTypeName(RSExportPrimitiveType::DataType DT) {
+ static const char *ElementJavaTypeNameMap[] = {
+ NULL, // RSExportPrimitiveType::DataTypeFloat16
+ "F32", // RSExportPrimitiveType::DataTypeFloat32
+ "F64", // RSExportPrimitiveType::DataTypeFloat64
+ "I8", // RSExportPrimitiveType::DataTypeSigned8
+ "I16", // RSExportPrimitiveType::DataTypeSigned16
+ "I32", // RSExportPrimitiveType::DataTypeSigned32
+ "I64", // RSExportPrimitiveType::DataTypeSigned64
+ "U8", // RSExportPrimitiveType::DataTypeUnsigned8
+ "U16", // RSExportPrimitiveType::DataTypeUnsigned16
+ "U32", // RSExportPrimitiveType::DataTypeUnsigned32
+ "U64", // RSExportPrimitiveType::DataTypeUnsigned64
+ "BOOLEAN", // RSExportPrimitiveType::DataTypeBoolean
+
+ "RGB_565", // RSExportPrimitiveType::DataTypeUnsigned565
+ "RGBA_5551", // RSExportPrimitiveType::DataTypeUnsigned5551
+ "RGBA_4444", // RSExportPrimitiveType::DataTypeUnsigned4444
+
+ // DataTypeRSMatrix* must have been resolved in GetBuiltinElementConstruct()
+ NULL, // (Dummy) RSExportPrimitiveType::DataTypeRSMatrix2x2
+ NULL, // (Dummy) RSExportPrimitiveType::DataTypeRSMatrix3x3
+ NULL, // (Dummy) RSExportPrimitiveType::DataTypeRSMatrix4x4
+
+ "ELEMENT", // RSExportPrimitiveType::DataTypeRSElement
+ "TYPE", // RSExportPrimitiveType::DataTypeRSType
+ "ALLOCATION", // RSExportPrimitiveType::DataTypeRSAllocation
+ "SAMPLER", // RSExportPrimitiveType::DataTypeRSSampler
+ "SCRIPT", // RSExportPrimitiveType::DataTypeRSScript
+ "MESH", // RSExportPrimitiveType::DataTypeRSMesh
+ "PROGRAM_FRAGMENT", // RSExportPrimitiveType::DataTypeRSProgramFragment
+ "PROGRAM_VERTEX", // RSExportPrimitiveType::DataTypeRSProgramVertex
+ "PROGRAM_RASTER", // RSExportPrimitiveType::DataTypeRSProgramRaster
+ "PROGRAM_STORE", // RSExportPrimitiveType::DataTypeRSProgramStore
+ "FONT", // RSExportPrimitiveType::DataTypeRSFont
+ };
+
+ if (static_cast<unsigned>(DT) <
+ (sizeof(ElementJavaTypeNameMap) / sizeof(const char*)))
+ return ElementJavaTypeNameMap[DT];
+ else
+ return NULL;
+}
+
+
/********************** Methods to generate script class **********************/
bool RSReflection::genScriptClass(Context &C,
const std::string &ClassName,
@@ -551,8 +666,32 @@
genInitExportVariable(C, EV->getType(), EV->getName(), EV->getInit());
}
+ for (RSContext::const_export_foreach_iterator
+ I = mRSContext->export_foreach_begin(),
+ E = mRSContext->export_foreach_end();
+ I != E;
+ I++) {
+ const RSExportForEach *EF = *I;
+
+ const RSExportType *IET = EF->getInType();
+ if (IET) {
+ genTypeInstance(C, IET);
+ }
+ const RSExportType *OET = EF->getOutType();
+ if (OET) {
+ genTypeInstance(C, OET);
+ }
+ }
+
C.endFunction();
+ for (std::set<std::string>::iterator I = C.mTypesToCheck.begin(),
+ E = C.mTypesToCheck.end();
+ I != E;
+ I++) {
+ C.indent() << "private Element __" << *I << ";" << std::endl;
+ }
+
return;
}
@@ -812,10 +951,8 @@
if (EF->hasIn())
Args.push_back(std::make_pair("Allocation", "ain"));
- //Args.push_back(std::make_pair(GetTypeName(EF->getInType()), "ain"));
if (EF->hasOut())
Args.push_back(std::make_pair("Allocation", "aout"));
- //Args.push_back(std::make_pair(GetTypeName(EF->getOutType()), "aout"));
const RSExportRecordType *ERT = EF->getParamPacketType();
if (ERT) {
@@ -854,7 +991,8 @@
C.indent() << " (tIn.getZ() != tOut.getZ()) ||" << std::endl;
C.indent() << " (tIn.hasFaces() != tOut.hasFaces()) ||" << std::endl;
C.indent() << " (tIn.hasMipmaps() != tOut.hasMipmaps())) {" << std::endl;
- C.indent() << " throw new RSRuntimeException(\"Dimension mismatch\");";
+ C.indent() << " throw new RSRuntimeException(\"Dimension mismatch "
+ << "between input and output parameters!\");";
C.out() << std::endl;
C.indent() << "}" << std::endl;
}
@@ -888,10 +1026,129 @@
return;
}
+void RSReflection::genTypeInstance(Context &C,
+ const RSExportType *ET) {
+ if (ET->getClass() == RSExportType::ExportClassPointer) {
+ const RSExportPointerType *EPT =
+ static_cast<const RSExportPointerType*>(ET);
+ ET = EPT->getPointeeType();
+ switch (ET->getClass()) {
+ case RSExportType::ExportClassPrimitive: {
+ const RSExportPrimitiveType *EPT =
+ static_cast<const RSExportPrimitiveType*>(ET);
+ slangAssert(EPT);
+
+ switch (EPT->getKind()) {
+ case RSExportPrimitiveType::DataKindPixelL:
+ case RSExportPrimitiveType::DataKindPixelA:
+ case RSExportPrimitiveType::DataKindPixelLA:
+ case RSExportPrimitiveType::DataKindPixelRGB:
+ case RSExportPrimitiveType::DataKindPixelRGBA: {
+ break;
+ }
+
+ case RSExportPrimitiveType::DataKindUser:
+ default: {
+ std::string TypeName = GetElementJavaTypeName(EPT->getType());
+ if (C.mTypesToCheck.find(TypeName) == C.mTypesToCheck.end()) {
+ C.indent() << "__" << TypeName << " = Element." << TypeName
+ << "(rs);" << std::endl;
+ C.mTypesToCheck.insert(TypeName);
+ }
+ break;
+ }
+ }
+ break;
+ }
+
+ case RSExportType::ExportClassVector: {
+ const RSExportVectorType *EVT =
+ static_cast<const RSExportVectorType*>(ET);
+ slangAssert(EVT);
+
+ const char *TypeName = GetVectorElementName(EVT);
+ if (C.mTypesToCheck.find(TypeName) == C.mTypesToCheck.end()) {
+ C.indent() << "__" << TypeName << " = Element." << TypeName
+ << "(rs);" << std::endl;
+ C.mTypesToCheck.insert(TypeName);
+ }
+ break;
+ }
+
+ case RSExportType::ExportClassRecord: {
+ const RSExportRecordType *ERT =
+ static_cast<const RSExportRecordType*>(ET);
+ slangAssert(ERT);
+
+ std::string ClassName = RS_TYPE_CLASS_NAME_PREFIX + ERT->getName();
+ if (C.mTypesToCheck.find(ClassName) == C.mTypesToCheck.end()) {
+ C.indent() << "__" << ClassName << " = " << ClassName <<
+ ".createElement(rs);" << std::endl;
+ C.mTypesToCheck.insert(ClassName);
+ }
+ break;
+ }
+
+ default:
+ break;
+ }
+ }
+}
+
void RSReflection::genTypeCheck(Context &C,
const RSExportType *ET,
const char *VarName) {
C.indent() << "// check " << VarName << std::endl;
+
+ if (ET->getClass() == RSExportType::ExportClassPointer) {
+ const RSExportPointerType *EPT =
+ static_cast<const RSExportPointerType*>(ET);
+ ET = EPT->getPointeeType();
+ }
+
+ std::string TypeName;
+
+ switch (ET->getClass()) {
+ case RSExportType::ExportClassPrimitive: {
+ const RSExportPrimitiveType *EPT =
+ static_cast<const RSExportPrimitiveType*>(ET);
+ slangAssert(EPT);
+
+ if (EPT->getKind() == RSExportPrimitiveType::DataKindUser) {
+ TypeName = GetElementJavaTypeName(EPT->getType());
+ }
+ break;
+ }
+
+ case RSExportType::ExportClassVector: {
+ const RSExportVectorType *EVT =
+ static_cast<const RSExportVectorType*>(ET);
+ slangAssert(EVT);
+ TypeName = GetVectorElementName(EVT);
+ break;
+ }
+
+ case RSExportType::ExportClassRecord: {
+ const RSExportRecordType *ERT =
+ static_cast<const RSExportRecordType*>(ET);
+ slangAssert(ERT);
+ TypeName = RS_TYPE_CLASS_NAME_PREFIX + ERT->getName();
+ break;
+ }
+
+ default:
+ break;
+ }
+
+ if (!TypeName.empty()) {
+ C.indent() << "if (!" << VarName
+ << ".getType().getElement().isCompatible(__"
+ << TypeName << ")) {" << std::endl;
+ C.indent() << " throw new RSRuntimeException(\"Type mismatch with "
+ << TypeName << "!\");" << std::endl;
+ C.indent() << "}" << std::endl;
+ }
+
return;
}
diff --git a/slang_rs_reflection.h b/slang_rs_reflection.h
index 2da7484..29e9f18 100644
--- a/slang_rs_reflection.h
+++ b/slang_rs_reflection.h
@@ -1,5 +1,5 @@
/*
- * Copyright 2010, The Android Open Source Project
+ * Copyright 2010-2011, The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -20,6 +20,7 @@
#include <fstream>
#include <iostream>
#include <map>
+#include <set>
#include <string>
#include <vector>
@@ -102,6 +103,8 @@
bool mUseStdout;
mutable std::ofstream mOF;
+ std::set<std::string> mTypesToCheck;
+
static const char *AccessModifierStr(AccessModifier AM);
Context(const std::string &OutputPathBase,
@@ -241,9 +244,12 @@
void genExportForEach(Context &C,
const RSExportForEach *EF);
- void genTypeCheck(Context &C,
- const RSExportType *ET,
- const char *VarName);
+ static void genTypeCheck(Context &C,
+ const RSExportType *ET,
+ const char *VarName);
+
+ static void genTypeInstance(Context &C,
+ const RSExportType *ET);
bool genTypeClass(Context &C,
const RSExportRecordType *ERT,