Make CheckJNI check varargs when possible

Several JNI functions, such as NewObject and Call*Method, take a variable
number of arguments. This patch makes CheckJNI perform (limited) dynamic
validation of these arguments, using the target method's shorty to
determine their types. Currently it only checks that object arguments
are valid and that primitive arguments do not have illegal values.

Bug: 20344596
Change-Id: I1f81d2bdd80215e0007fc636bac27d5bcc2dba96
diff --git a/runtime/check_jni.cc b/runtime/check_jni.cc
index 4172b89..b6ad547 100644
--- a/runtime/check_jni.cc
+++ b/runtime/check_jni.cc
@@ -66,6 +66,8 @@
 #define kFlag_Invocation    0x8000      // Part of the invocation interface (JavaVM*).
 
 #define kFlag_ForceTrace    0x80000000  // Add this to a JNI function's flags if you want to trace every call.
+
+class VarArgs;  // JniValueType below stores a const VarArgs* ('va'), so declare it early.
 /*
  * Java primitive types:
  * B - jbyte
@@ -126,6 +128,116 @@
   jshort S;
   const void* V;  // void
   jboolean Z;
+  const VarArgs* va;  // '.' - the arguments of a variable-argument JNI call.
+};
+
+/*
+ * Everything needed to validate the variable arguments of a JNI call: the jmethodID plus either
+ * a va_list (copied with va_copy) or a pointer to a jvalue array.
+ *
+ * GetValue() consumes arguments, so it must only be called on an owned copy made with Clone().
+ */
+class VarArgs {
+ public:
+  VarArgs(jmethodID m, va_list var) : m_(m), type_(kTypeVaList), cnt_(0) {
+    va_copy(vargs_, var);  // Private copy; reading it never advances the caller's va_list.
+  }
+
+  VarArgs(jmethodID m, const jvalue* vals) : m_(m), type_(kTypePtr), cnt_(0), ptr_(vals) {}
+
+  ~VarArgs() {
+    if (type_ == kTypeVaList) {
+      va_end(vargs_);  // Matches the va_copy() made by every va_list constructor.
+    }
+  }
+
+  VarArgs(VarArgs&& other) {  // Duplicates (va_copy) the va_list rather than stealing it.
+    m_ = other.m_;
+    cnt_ = other.cnt_;
+    type_ = other.type_;
+    if (other.type_ == kTypeVaList) {
+      va_copy(vargs_, other.vargs_);
+    } else {
+      ptr_ = other.ptr_;
+    }
+  }
+
+  // Returns an owned copy that can be read with GetValue(). Reading from a va_list consumes it,
+  // so VarArgs are passed around as 'const VarArgs*' and readers Clone() them first; this is
+  // why Clone() is const. The copy keeps m_ and cnt_, and holds either an independent va_copy()
+  // of the va_list or the same jvalue pointer.
+  VarArgs Clone() const {
+    if (type_ == kTypeVaList) {
+      // const_cast needed to make sure the compiler is okay with va_copy, which (being a macro) is
+      // messed up if the source argument is not the exact type 'va_list'.
+      return VarArgs(m_, cnt_, const_cast<VarArgs*>(this)->vargs_);
+    } else {
+      return VarArgs(m_, cnt_, ptr_);
+    }
+  }
+
+  jmethodID GetMethodID() const {
+    return m_;
+  }
+
+  JniValueType GetValue(char fmt) {  // fmt is a shorty char: Z, B, C, S, I, J, F, D or L.
+    JniValueType o;
+    if (type_ == kTypeVaList) {
+      switch (fmt) {  // Varargs promote sub-int types to int and float to double.
+        case 'Z': o.Z = static_cast<jboolean>(va_arg(vargs_, jint)); break;
+        case 'B': o.B = static_cast<jbyte>(va_arg(vargs_, jint)); break;
+        case 'C': o.C = static_cast<jchar>(va_arg(vargs_, jint)); break;
+        case 'S': o.S = static_cast<jshort>(va_arg(vargs_, jint)); break;
+        case 'I': o.I = va_arg(vargs_, jint); break;
+        case 'J': o.J = va_arg(vargs_, jlong); break;
+        case 'F': o.F = static_cast<jfloat>(va_arg(vargs_, jdouble)); break;
+        case 'D': o.D = va_arg(vargs_, jdouble); break;
+        case 'L': o.L = va_arg(vargs_, jobject); break;
+        default:
+          LOG(FATAL) << "Illegal type format char " << fmt;
+          UNREACHABLE();
+      }
+    } else {
+      CHECK(type_ == kTypePtr);
+      jvalue v = ptr_[cnt_];
+      cnt_++;  // Only the jvalue-array form tracks a position in cnt_.
+      switch (fmt) {
+        case 'Z': o.Z = v.z; break;
+        case 'B': o.B = v.b; break;
+        case 'C': o.C = v.c; break;
+        case 'S': o.S = v.s; break;
+        case 'I': o.I = v.i; break;
+        case 'J': o.J = v.j; break;
+        case 'F': o.F = v.f; break;
+        case 'D': o.D = v.d; break;
+        case 'L': o.L = v.l; break;
+        default:
+          LOG(FATAL) << "Illegal type format char " << fmt;
+          UNREACHABLE();
+      }
+    }
+    return o;
+  }
+
+ private:
+  VarArgs(jmethodID m, uint32_t cnt, va_list var) : m_(m), type_(kTypeVaList), cnt_(cnt) {
+    va_copy(vargs_, var);
+  }
+
+  VarArgs(jmethodID m, uint32_t cnt, const jvalue* vals) : m_(m), type_(kTypePtr), cnt_(cnt), ptr_(vals) {}
+
+  enum VarArgsType {
+    kTypeVaList,  // Arguments are read from vargs_.
+    kTypePtr,     // Arguments are read from ptr_[cnt_].
+  };
+
+  jmethodID m_;       // Method whose shorty gives the argument types.
+  VarArgsType type_;  // Selects the active member of the union below.
+  uint32_t cnt_;      // Index of the next jvalue; not advanced for va_list.
+  union {
+    va_list vargs_;      // Valid when type_ == kTypeVaList.
+    const jvalue* ptr_;  // Valid when type_ == kTypePtr.
+  };
 };
 
 class ScopedCheck {
@@ -339,7 +451,7 @@
    * z - jsize (for lengths; use i if negative values are okay)
    * v - JavaVM*
    * E - JNIEnv*
-   * . - no argument; just print "..." (used for varargs JNI calls)
+   * . - const VarArgs* holding the arguments of a JNI call that takes a va_list or jvalue array
    *
    * Use the kFlag_NullableUtf flag where 'u' field(s) are nullable.
    */
@@ -736,11 +848,35 @@
         return CheckThread(arg.E);
       case 'L':  // jobject
         return CheckInstance(soa, kObject, arg.L, true);
+      case '.':  // const VarArgs*: each argument is checked by its shorty type.
+        return CheckVarArgs(soa, arg.va);
       default:
         return CheckNonHeapValue(fmt, arg);
     }
   }
 
+  bool CheckVarArgs(ScopedObjectAccess& soa, const VarArgs* args_p)
+      SHARED_REQUIRES(Locks::mutator_lock_) {  // Returns false at the first invalid argument.
+    CHECK(args_p != nullptr);
+    VarArgs args(args_p->Clone());  // GetValue() consumes arguments; read a private copy.
+    ArtMethod* m = CheckMethodID(soa, args.GetMethodID());  // nullptr if the ID is rejected.
+    if (m == nullptr) {
+      return false;
+    }
+    uint32_t len = 0;
+    const char* shorty = m->GetShorty(&len);
+    // shorty[0] is the return type; only the remaining len - 1 chars describe arguments.
+    CHECK_GE(len, 1u);
+    len--;
+    shorty++;
+    for (uint32_t i = 0; i < len; i++) {
+      if (!CheckPossibleHeapValue(soa, shorty[i], args.GetValue(shorty[i]))) {
+        return false;
+      }
+    }
+    return true;
+  }
+
   bool CheckNonHeapValue(char fmt, JniValueType arg) {
     switch (fmt) {
       case 'p':  // TODO: pointer - null or readable?
@@ -833,6 +969,33 @@
         }
         break;
       }
+      case '.': {
+        // The caller has already appended ", " after the previous entry. Remove it, since each
+        // argument printed below is preceded by its own separator.
+        if (msg->length() >= 2 && msg->compare(msg->length() - 2, 2, ", ") == 0) {
+          msg->erase(msg->length() - 2);
+        }
+        const VarArgs* va = arg.va;
+        VarArgs args(va->Clone());
+        ArtMethod* m = soa.DecodeMethod(args.GetMethodID());
+        if (m == nullptr) {
+          // Unlike CheckVarArgs, this path does not validate the method ID; with a null one
+          // there is no shorty to decode the arguments with, so just print "...".
+          *msg += ", ...";
+          break;
+        }
+        uint32_t len = 0;
+        const char* shorty = m->GetShorty(&len);
+        CHECK_GE(len, 1u);
+        // Skip past return value.
+        len--;
+        shorty++;
+        for (uint32_t i = 0; i < len; i++) {
+          *msg += ", ";
+          TracePossibleHeapValue(soa, entry, shorty[i], args.GetValue(shorty[i]), msg);
+        }
+        break;
+      }
       default:
         TraceNonHeapValue(fmt, arg, msg);
         break;
@@ -1836,8 +1990,9 @@
   static jobject NewObjectV(JNIEnv* env, jclass c, jmethodID mid, va_list vargs) {
     ScopedObjectAccess soa(env);
     ScopedCheck sc(kFlag_Default, __FUNCTION__);
-    JniValueType args[3] = {{.E = env}, {.c = c}, {.m = mid}};
-    if (sc.Check(soa, true, "Ecm", args) && sc.CheckInstantiableNonArray(soa, c) &&
+    VarArgs rest(mid, vargs);  // Checks a va_copy(); vargs stays unread for NewObjectV.
+    JniValueType args[4] = {{.E = env}, {.c = c}, {.m = mid}, {.va = &rest}};
+    if (sc.Check(soa, true, "Ecm.", args) && sc.CheckInstantiableNonArray(soa, c) &&
         sc.CheckConstructor(soa, mid)) {
       JniValueType result;
       result.L = baseEnv(env)->NewObjectV(env, c, mid, vargs);
@@ -1859,8 +2014,9 @@
   static jobject NewObjectA(JNIEnv* env, jclass c, jmethodID mid, jvalue* vargs) {
     ScopedObjectAccess soa(env);
     ScopedCheck sc(kFlag_Default, __FUNCTION__);
-    JniValueType args[3] = {{.E = env}, {.c = c}, {.m = mid}};
-    if (sc.Check(soa, true, "Ecm", args) && sc.CheckInstantiableNonArray(soa, c) &&
+    VarArgs rest(mid, vargs);  // Only wraps the jvalue pointer; nothing is copied.
+    JniValueType args[4] = {{.E = env}, {.c = c}, {.m = mid}, {.va = &rest}};
+    if (sc.Check(soa, true, "Ecm.", args) && sc.CheckInstantiableNonArray(soa, c) &&
         sc.CheckConstructor(soa, mid)) {
       JniValueType result;
       result.L = baseEnv(env)->NewObjectA(env, c, mid, vargs);
@@ -2689,25 +2845,25 @@
   }
 
   static bool CheckCallArgs(ScopedObjectAccess& soa, ScopedCheck& sc, JNIEnv* env, jobject obj,
-                            jclass c, jmethodID mid, InvokeType invoke)
+                            jclass c, jmethodID mid, InvokeType invoke, const VarArgs* vargs)
       SHARED_REQUIRES(Locks::mutator_lock_) {
     bool checked;
     switch (invoke) {
       case kVirtual: {
         DCHECK(c == nullptr);
-        JniValueType args[3] = {{.E = env}, {.L = obj}, {.m = mid}};
-        checked = sc.Check(soa, true, "ELm", args);
+        JniValueType args[4] = {{.E = env}, {.L = obj}, {.m = mid}, {.va = vargs}};
+        checked = sc.Check(soa, true, "ELm.", args);  // '.' checks the arguments in vargs.
         break;
       }
       case kDirect: {
-        JniValueType args[4] = {{.E = env}, {.L = obj}, {.c = c}, {.m = mid}};
-        checked = sc.Check(soa, true, "ELcm", args);
+        JniValueType args[5] = {{.E = env}, {.L = obj}, {.c = c}, {.m = mid}, {.va = vargs}};
+        checked = sc.Check(soa, true, "ELcm.", args);  // '.' checks the arguments in vargs.
         break;
       }
       case kStatic: {
         DCHECK(obj == nullptr);
-        JniValueType args[3] = {{.E = env}, {.c = c}, {.m = mid}};
-        checked = sc.Check(soa, true, "Ecm", args);
+        JniValueType args[4] = {{.E = env}, {.c = c}, {.m = mid}, {.va = vargs}};
+        checked = sc.Check(soa, true, "Ecm.", args);  // '.' checks the arguments in vargs.
         break;
       }
       default:
@@ -2724,7 +2880,8 @@
     ScopedObjectAccess soa(env);
     ScopedCheck sc(kFlag_Default, function_name);
     JniValueType result;
-    if (CheckCallArgs(soa, sc, env, obj, c, mid, invoke) &&
+    VarArgs rest(mid, vargs);  // Validated through the '.' entry in CheckCallArgs.
+    if (CheckCallArgs(soa, sc, env, obj, c, mid, invoke, &rest) &&
         sc.CheckMethodAndSig(soa, obj, c, mid, type, invoke)) {
       const char* result_check;
       switch (type) {
@@ -2907,7 +3064,8 @@
     ScopedObjectAccess soa(env);
     ScopedCheck sc(kFlag_Default, function_name);
     JniValueType result;
-    if (CheckCallArgs(soa, sc, env, obj, c, mid, invoke) &&
+    VarArgs rest(mid, vargs);  // Validated through the '.' entry in CheckCallArgs.
+    if (CheckCallArgs(soa, sc, env, obj, c, mid, invoke, &rest) &&
         sc.CheckMethodAndSig(soa, obj, c, mid, type, invoke)) {
       const char* result_check;
       switch (type) {