Sanity checks in JNI code for closure creation

b/20728113

In case the requested size for memory allocation overflows, or memory
allocation fails.

Change-Id: I8dac132dd4d0210938660ffbb82cbe44000d2a90
diff --git a/rs/java/android/renderscript/RenderScript.java b/rs/java/android/renderscript/RenderScript.java
index 6b1939c..24bc2c1 100644
--- a/rs/java/android/renderscript/RenderScript.java
+++ b/rs/java/android/renderscript/RenderScript.java
@@ -302,8 +302,12 @@
         long[] fieldIDs, long[] values, int[] sizes, long[] depClosures,
         long[] depFieldIDs) {
       validate();
-      return rsnClosureCreate(mContext, kernelID, returnValue, fieldIDs, values,
+      long c = rsnClosureCreate(mContext, kernelID, returnValue, fieldIDs, values,
           sizes, depClosures, depFieldIDs);
+      if (c == 0) {
+          throw new RSRuntimeException("Failed creating closure.");
+      }
+      return c;
     }
 
     native long rsnInvokeClosureCreate(long con, long invokeID, byte[] params,
@@ -311,8 +315,12 @@
     synchronized long nInvokeClosureCreate(long invokeID, byte[] params,
         long[] fieldIDs, long[] values, int[] sizes) {
       validate();
-      return rsnInvokeClosureCreate(mContext, invokeID, params, fieldIDs,
+      long c = rsnInvokeClosureCreate(mContext, invokeID, params, fieldIDs,
           values, sizes);
+      if (c == 0) {
+          throw new RSRuntimeException("Failed creating closure.");
+      }
+      return c;
     }
 
     native void rsnClosureSetArg(long con, long closureID, int index,
@@ -337,7 +345,11 @@
     synchronized long nScriptGroup2Create(String name, String cachePath,
                                           long[] closures) {
       validate();
-      return rsnScriptGroup2Create(mContext, name, cachePath, closures);
+      long g = rsnScriptGroup2Create(mContext, name, cachePath, closures);
+      if (g == 0) {
+          throw new RSRuntimeException("Failed creating script group.");
+      }
+      return g;
     }
 
     native void rsnScriptGroup2Execute(long con, long groupID);
diff --git a/rs/jni/android_renderscript_RenderScript.cpp b/rs/jni/android_renderscript_RenderScript.cpp
index 6f6729b..d339a0d 100644
--- a/rs/jni/android_renderscript_RenderScript.cpp
+++ b/rs/jni/android_renderscript_RenderScript.cpp
@@ -44,6 +44,9 @@
 
 //#define LOG_API ALOGE
 static constexpr bool kLogApi = false;
+static constexpr size_t kMaxNumberArgsAndBindings = 1000;
+static constexpr size_t kMaxNumberClosuresInScriptGroup = 1000000;
+static constexpr size_t kMaxNumberKernelArguments = 256;
 
 using namespace android;
 
@@ -333,79 +336,167 @@
                jlong returnValue, jlongArray fieldIDArray,
                jlongArray valueArray, jintArray sizeArray,
                jlongArray depClosureArray, jlongArray depFieldIDArray) {
+  jlong ret = 0;
+
   jlong* jFieldIDs = _env->GetLongArrayElements(fieldIDArray, nullptr);
   jsize fieldIDs_length = _env->GetArrayLength(fieldIDArray);
-  RsScriptFieldID* fieldIDs =
-      (RsScriptFieldID*)alloca(sizeof(RsScriptFieldID) * fieldIDs_length);
-  for (int i = 0; i< fieldIDs_length; i++) {
-    fieldIDs[i] = (RsScriptFieldID)jFieldIDs[i];
-  }
-
   jlong* jValues = _env->GetLongArrayElements(valueArray, nullptr);
   jsize values_length = _env->GetArrayLength(valueArray);
-  uintptr_t* values = (uintptr_t*)alloca(sizeof(uintptr_t) * values_length);
-  for (int i = 0; i < values_length; i++) {
-    values[i] = (uintptr_t)jValues[i];
-  }
-
-  jint* sizes = _env->GetIntArrayElements(sizeArray, nullptr);
+  jint* jSizes = _env->GetIntArrayElements(sizeArray, nullptr);
   jsize sizes_length = _env->GetArrayLength(sizeArray);
-
   jlong* jDepClosures =
       _env->GetLongArrayElements(depClosureArray, nullptr);
   jsize depClosures_length = _env->GetArrayLength(depClosureArray);
-  RsClosure* depClosures =
-      (RsClosure*)alloca(sizeof(RsClosure) * depClosures_length);
-  for (int i = 0; i < depClosures_length; i++) {
-    depClosures[i] = (RsClosure)jDepClosures[i];
-  }
-
   jlong* jDepFieldIDs =
       _env->GetLongArrayElements(depFieldIDArray, nullptr);
   jsize depFieldIDs_length = _env->GetArrayLength(depFieldIDArray);
-  RsScriptFieldID* depFieldIDs =
-      (RsScriptFieldID*)alloca(sizeof(RsScriptFieldID) * depFieldIDs_length);
-  for (int i = 0; i < depClosures_length; i++) {
+
+  size_t numValues, numDependencies;
+  RsScriptFieldID* fieldIDs;
+  uintptr_t* values;
+  RsClosure* depClosures;
+  RsScriptFieldID* depFieldIDs;
+
+  if (fieldIDs_length != values_length || values_length != sizes_length) {
+      ALOGE("Unmatched field IDs, values, and sizes in closure creation.");
+      goto exit;
+  }
+
+  numValues = (size_t)fieldIDs_length;
+
+  if (depClosures_length != depFieldIDs_length) {
+      ALOGE("Unmatched closures and field IDs for dependencies in closure creation.");
+      goto exit;
+  }
+
+  numDependencies = (size_t)depClosures_length;
+
+  if (numDependencies > numValues) {
+      ALOGE("Unexpected number of dependencies in closure creation");
+      goto exit;
+  }
+
+  if (numValues > kMaxNumberArgsAndBindings) {
+      ALOGE("Too many arguments or globals in closure creation");
+      goto exit;
+  }
+
+  fieldIDs = (RsScriptFieldID*)alloca(sizeof(RsScriptFieldID) * numValues);
+  if (fieldIDs == nullptr) {
+      goto exit;
+  }
+
+  for (size_t i = 0; i < numValues; i++) {
+    fieldIDs[i] = (RsScriptFieldID)jFieldIDs[i];
+  }
+
+  values = (uintptr_t*)alloca(sizeof(uintptr_t) * numValues);
+  if (values == nullptr) {
+      goto exit;
+  }
+
+  for (size_t i = 0; i < numValues; i++) {
+    values[i] = (uintptr_t)jValues[i];
+  }
+
+  depClosures = (RsClosure*)alloca(sizeof(RsClosure) * numDependencies);
+  if (depClosures == nullptr) {
+      goto exit;
+  }
+
+  for (size_t i = 0; i < numDependencies; i++) {
+    depClosures[i] = (RsClosure)jDepClosures[i];
+  }
+
+  depFieldIDs = (RsScriptFieldID*)alloca(sizeof(RsScriptFieldID) * numDependencies);
+  if (depFieldIDs == nullptr) {
+      goto exit;
+  }
+
+  for (size_t i = 0; i < numDependencies; i++) {
     depFieldIDs[i] = (RsClosure)jDepFieldIDs[i];
   }
 
-  return (jlong)(uintptr_t)rsClosureCreate(
+  ret = (jlong)(uintptr_t)rsClosureCreate(
       (RsContext)con, (RsScriptKernelID)kernelID, (RsAllocation)returnValue,
-      fieldIDs, (size_t)fieldIDs_length, values, (size_t)values_length,
-      (int*)sizes, (size_t)sizes_length,
-      depClosures, (size_t)depClosures_length,
-      depFieldIDs, (size_t)depFieldIDs_length);
+      fieldIDs, numValues, values, numValues,
+      (int*)jSizes, numValues,
+      depClosures, numDependencies,
+      depFieldIDs, numDependencies);
+
+exit:
+
+  _env->ReleaseLongArrayElements(depFieldIDArray, jDepFieldIDs, JNI_ABORT);
+  _env->ReleaseLongArrayElements(depClosureArray, jDepClosures, JNI_ABORT);
+  _env->ReleaseIntArrayElements (sizeArray,       jSizes,       JNI_ABORT);
+  _env->ReleaseLongArrayElements(valueArray,      jValues,      JNI_ABORT);
+  _env->ReleaseLongArrayElements(fieldIDArray,    jFieldIDs,    JNI_ABORT);
+
+  return ret;
 }
 
 static jlong
 nInvokeClosureCreate(JNIEnv *_env, jobject _this, jlong con, jlong invokeID,
                      jbyteArray paramArray, jlongArray fieldIDArray, jlongArray valueArray,
                      jintArray sizeArray) {
+  jlong ret = 0;
+
   jbyte* jParams = _env->GetByteArrayElements(paramArray, nullptr);
   jsize jParamLength = _env->GetArrayLength(paramArray);
-
   jlong* jFieldIDs = _env->GetLongArrayElements(fieldIDArray, nullptr);
   jsize fieldIDs_length = _env->GetArrayLength(fieldIDArray);
-  RsScriptFieldID* fieldIDs =
-      (RsScriptFieldID*)alloca(sizeof(RsScriptFieldID) * fieldIDs_length);
-  for (int i = 0; i< fieldIDs_length; i++) {
+  jlong* jValues = _env->GetLongArrayElements(valueArray, nullptr);
+  jsize values_length = _env->GetArrayLength(valueArray);
+  jint* jSizes = _env->GetIntArrayElements(sizeArray, nullptr);
+  jsize sizes_length = _env->GetArrayLength(sizeArray);
+
+  size_t numValues;
+  RsScriptFieldID* fieldIDs;
+  uintptr_t* values;
+
+  if (fieldIDs_length != values_length || values_length != sizes_length) {
+      ALOGE("Unmatched field IDs, values, and sizes in closure creation.");
+      goto exit;
+  }
+
+  numValues = (size_t) fieldIDs_length;
+
+  if (numValues > kMaxNumberArgsAndBindings) {
+      ALOGE("Too many arguments or globals in closure creation");
+      goto exit;
+  }
+
+  fieldIDs = (RsScriptFieldID*)alloca(sizeof(RsScriptFieldID) * numValues);
+  if (fieldIDs == nullptr) {
+      goto exit;
+  }
+
+  for (size_t i = 0; i< numValues; i++) {
     fieldIDs[i] = (RsScriptFieldID)jFieldIDs[i];
   }
 
-  jlong* jValues = _env->GetLongArrayElements(valueArray, nullptr);
-  jsize values_length = _env->GetArrayLength(valueArray);
-  uintptr_t* values = (uintptr_t*)alloca(sizeof(uintptr_t) * values_length);
-  for (int i = 0; i < values_length; i++) {
+  values = (uintptr_t*)alloca(sizeof(uintptr_t) * numValues);
+  if (values == nullptr) {
+      goto exit;
+  }
+
+  for (size_t i = 0; i < numValues; i++) {
     values[i] = (uintptr_t)jValues[i];
   }
 
-  jint* sizes = _env->GetIntArrayElements(sizeArray, nullptr);
-  jsize sizes_length = _env->GetArrayLength(sizeArray);
-
-  return (jlong)(uintptr_t)rsInvokeClosureCreate(
+  ret = (jlong)(uintptr_t)rsInvokeClosureCreate(
       (RsContext)con, (RsScriptInvokeID)invokeID, jParams, jParamLength,
-      fieldIDs, (size_t)fieldIDs_length, values, (size_t)values_length,
-      (int*)sizes, (size_t)sizes_length);
+      fieldIDs, numValues, values, numValues,
+      (int*)jSizes, numValues);
+
+exit:
+
+  _env->ReleaseIntArrayElements (sizeArray,       jSizes,       JNI_ABORT);
+  _env->ReleaseLongArrayElements(valueArray,      jValues,      JNI_ABORT);
+  _env->ReleaseLongArrayElements(fieldIDArray,    jFieldIDs,    JNI_ABORT);
+  _env->ReleaseByteArrayElements(paramArray,      jParams,      JNI_ABORT);
+
+  return ret;
 }
 
 static void
@@ -425,20 +516,40 @@
 static long
 nScriptGroup2Create(JNIEnv *_env, jobject _this, jlong con, jstring name,
                     jstring cacheDir, jlongArray closureArray) {
+  jlong ret = 0;
+
   AutoJavaStringToUTF8 nameUTF(_env, name);
   AutoJavaStringToUTF8 cacheDirUTF(_env, cacheDir);
 
   jlong* jClosures = _env->GetLongArrayElements(closureArray, nullptr);
   jsize numClosures = _env->GetArrayLength(closureArray);
-  RsClosure* closures = (RsClosure*)alloca(sizeof(RsClosure) * numClosures);
+
+  RsClosure* closures;
+
+  if (numClosures > (jsize) kMaxNumberClosuresInScriptGroup) {
+    ALOGE("Too many closures in script group");
+    goto exit;
+  }
+
+  closures = (RsClosure*)alloca(sizeof(RsClosure) * numClosures);
+  if (closures == nullptr) {
+      goto exit;
+  }
+
   for (int i = 0; i < numClosures; i++) {
     closures[i] = (RsClosure)jClosures[i];
   }
 
-  return (jlong)(uintptr_t)rsScriptGroup2Create(
+  ret = (jlong)(uintptr_t)rsScriptGroup2Create(
       (RsContext)con, nameUTF.c_str(), nameUTF.length(),
       cacheDirUTF.c_str(), cacheDirUTF.length(),
       closures, numClosures);
+
+exit:
+
+  _env->ReleaseLongArrayElements(closureArray, jClosures, JNI_ABORT);
+
+  return ret;
 }
 
 static void
@@ -1766,8 +1877,14 @@
 
     if (ains != nullptr) {
         in_len = _env->GetArrayLength(ains);
-        in_ptr = _env->GetLongArrayElements(ains, nullptr);
+        if (in_len > (jint)kMaxNumberKernelArguments) {
+            ALOGE("Too many arguments in kernel launch.");
+            // TODO (b/20758983): Report back to Java and throw an exception
+            return;
+        }
 
+        // TODO (b/20760800): Check in_ptr is not null
+        in_ptr = _env->GetLongArrayElements(ains, nullptr);
         if (sizeof(RsAllocation) == sizeof(jlong)) {
             in_allocs = (RsAllocation*)in_ptr;
 
@@ -1775,6 +1892,11 @@
             // Convert from 64-bit jlong types to the native pointer type.
 
             in_allocs = (RsAllocation*)alloca(in_len * sizeof(RsAllocation));
+            if (in_allocs == nullptr) {
+                ALOGE("Failed launching kernel for lack of memory.");
+                _env->ReleaseLongArrayElements(ains, in_ptr, JNI_ABORT);
+                return;
+            }
 
             for (int index = in_len; --index >= 0;) {
                 in_allocs[index] = (RsAllocation)in_ptr[index];