Runtime access checks on virtual method calls

At verification time we may not know whether an illegal-access or
method-not-found error should be raised, so we defer the decision to
runtime. When the decision is deferred, we perform an appropriate slow
path method invocation that checks for access violations.

This change also attempts to reduce code duplication, improve the
diagnostic information in exceptions, slightly clean up the field slow
paths, and move the slow path calls lower in the Thread class so that
they don't affect the offsets of data items when calls are added or
removed.

Change-Id: I8376b83dcd7e302cbbddf44c1a55a25687b9dcdb
diff --git a/src/runtime_support.cc b/src/runtime_support.cc
index 7c47a8f..1ac92fd 100644
--- a/src/runtime_support.cc
+++ b/src/runtime_support.cc
@@ -34,6 +34,81 @@
   self->SetTopOfStack(sp, 0);
 }
 
+// Throws IllegalAccessError reporting that class 'referrer' cannot access class 'accessed'.
+static void ThrowNewIllegalAccessErrorClass(Thread* self, Class* referrer, Class* accessed) {
+  self->ThrowNewExceptionF("Ljava/lang/IllegalAccessError;",
+                           "illegal class access: '%s' -> '%s'",
+                           PrettyDescriptor(referrer).c_str(),
+                           PrettyDescriptor(accessed).c_str());
+}
+
+// Throws IllegalAccessError for a class access failure found while dispatching 'called' from
+// 'caller'; is_interface/is_super only select the kind of invoke named in the message.
+static void ThrowNewIllegalAccessErrorClassForMethodDispatch(Thread* self, Class* referrer,
+                                                             Class* accessed, const Method* caller,
+                                                             const Method* called,
+                                                             bool is_interface, bool is_super) {
+  self->ThrowNewExceptionF("Ljava/lang/IllegalAccessError;",
+                           "illegal class access ('%s' -> '%s') "
+                           "in attempt to invoke %s method '%s' from '%s'",
+                           PrettyDescriptor(referrer).c_str(),
+                           PrettyDescriptor(accessed).c_str(),
+                           (is_interface ? "interface" : (is_super ? "super class" : "virtual")),
+                           PrettyMethod(called).c_str(),
+                           PrettyMethod(caller).c_str());
+}
+
+// Throws IncompatibleClassChangeError reporting that the class of 'this_object' does not
+// implement the interface declaring 'interface_method', invoked from 'referrer'.
+static void ThrowNewIncompatibleClassChangeErrorClassForInterfaceDispatch(Thread* self,
+                                                                          const Method* referrer,
+                                                                          const Method* interface_method,
+                                                                          Object* this_object) {
+  self->ThrowNewExceptionF("Ljava/lang/IncompatibleClassChangeError;",
+      "class '%s' does not implement interface '%s' in call to '%s' from '%s'",
+      PrettyDescriptor(this_object->GetClass()).c_str(),
+      PrettyDescriptor(interface_method->GetDeclaringClass()).c_str(),
+      PrettyMethod(interface_method).c_str(), PrettyMethod(referrer).c_str());
+}
+
+// Throws IllegalAccessError reporting that class 'referrer' cannot access field 'accessed'.
+static void ThrowNewIllegalAccessErrorField(Thread* self, Class* referrer, Field* accessed) {
+  self->ThrowNewExceptionF("Ljava/lang/IllegalAccessError;",
+                           "Field '%s' is inaccessible to class '%s'",
+                           PrettyField(accessed, false).c_str(),
+                           PrettyDescriptor(referrer).c_str());
+}
+
+// Throws IllegalAccessError reporting that class 'referrer' cannot access method 'accessed'.
+static void ThrowNewIllegalAccessErrorMethod(Thread* self, Class* referrer, Method* accessed) {
+  self->ThrowNewExceptionF("Ljava/lang/IllegalAccessError;",
+                           "Method '%s' is inaccessible to class '%s'",
+                           PrettyMethod(accessed).c_str(),
+                           PrettyDescriptor(referrer).c_str());
+}
+
+// Throws NullPointerException for a read (is_read) or write of 'field' on a null reference.
+static void ThrowNullPointerExceptionForFieldAccess(Thread* self, Field* field, bool is_read) {
+  self->ThrowNewExceptionF("Ljava/lang/NullPointerException;",
+                           "Attempt to %s field '%s' on a null object reference",
+                           is_read ? "read from" : "write to",
+                           PrettyField(field, true).c_str());
+}
+
+// Throws NullPointerException for invoking method 'method_idx' (named via the dex file of the
+// caller's declaring class) on a null receiver.
+static void ThrowNullPointerExceptionForMethodAccess(Thread* self, Method* caller,
+                                                     uint32_t method_idx, bool is_interface,
+                                                     bool is_super) {
+  const DexFile& dex_file =
+      Runtime::Current()->GetClassLinker()->FindDexFile(caller->GetDeclaringClass()->GetDexCache());
+  self->ThrowNewExceptionF("Ljava/lang/NullPointerException;",
+                           "Attempt to invoke %s method '%s' from '%s' on a null object reference",
+                           (is_interface ? "interface" : (is_super ? "super class" : "virtual")),
+                           PrettyMethod(method_idx, dex_file, true).c_str(),
+                           PrettyMethod(caller).c_str());
+}
+
 /*
  * Report location to debugger.  Note: dalvikPC is the current offset within
  * the method.  However, because the offset alone cannot distinguish between
@@ -482,6 +547,7 @@
   return resolved_field;
 }
 
+
 // Slow path field resolution and declaring class initialization
 Field* FindFieldFromCode(uint32_t field_idx, const Method* referrer, Thread* self,
                          bool is_static, bool is_primitive, size_t expected_size) {
@@ -491,14 +557,11 @@
     Class* fields_class = resolved_field->GetDeclaringClass();
     Class* referring_class = referrer->GetDeclaringClass();
     if (UNLIKELY(!referring_class->CanAccess(fields_class))) {
-      self->ThrowNewExceptionF("Ljava/lang/IllegalAccessError;","%s tried to access class %s",
-                               PrettyMethod(referrer).c_str(),
-                               PrettyDescriptor(fields_class).c_str());
+      ThrowNewIllegalAccessErrorClass(self, referring_class, fields_class);
+      return NULL;  // failure
     } else if (UNLIKELY(!referring_class->CanAccessMember(fields_class,
                                                           resolved_field->GetAccessFlags()))) {
-      self->ThrowNewExceptionF("Ljava/lang/IllegalAccessError;","%s tried to access field %s",
-                               PrettyMethod(referrer).c_str(),
-                               PrettyField(resolved_field, false).c_str());
+      ThrowNewIllegalAccessErrorField(self, referring_class, resolved_field);
       return NULL;  // failure
     }
     FieldHelper fh(resolved_field);
@@ -527,13 +589,6 @@
   return NULL;
 }
 
-static void ThrowNullPointerExceptionForFieldAccess(Thread* self, Field* field, bool is_read) {
-  self->ThrowNewExceptionF("Ljava/lang/NullPointerException;",
-                           "Attempt to %s field '%s' of a null object",
-                           is_read ? "read from" : "write to",
-                           PrettyField(field, true).c_str());
-}
-
 extern "C" uint32_t artGet32StaticFromCode(uint32_t field_idx, const Method* referrer,
                                            Thread* self, Method** sp) {
   Field* field = FindFieldFast(field_idx, referrer, true, sizeof(int32_t));
@@ -766,10 +821,7 @@
     }
     Class* referrer = method->GetDeclaringClass();
     if (UNLIKELY(!referrer->CanAccess(klass))) {
-      self->ThrowNewExceptionF("Ljava/lang/IllegalAccessError;",
-                               "illegal class access: '%s' -> '%s'",
-                               PrettyDescriptor(referrer).c_str(),
-                               PrettyDescriptor(klass).c_str());
+      ThrowNewIllegalAccessErrorClass(self, referrer, klass);
       return NULL;  // Failure
     }
   }
@@ -815,10 +867,7 @@
   if (access_check) {
     Class* referrer = method->GetDeclaringClass();
     if (UNLIKELY(!referrer->CanAccess(klass))) {
-      self->ThrowNewExceptionF("Ljava/lang/IllegalAccessError;",
-                               "illegal class access: '%s' -> '%s'",
-                               PrettyDescriptor(referrer).c_str(),
-                               PrettyDescriptor(klass).c_str());
+      ThrowNewIllegalAccessErrorClass(self, referrer, klass);
       return NULL;  // Failure
     }
   }
@@ -868,10 +917,7 @@
     if (access_check) {
       Class* referrer = method->GetDeclaringClass();
       if (UNLIKELY(!referrer->CanAccess(klass))) {
-        self->ThrowNewExceptionF("Ljava/lang/IllegalAccessError;",
-                                 "illegal class access: '%s' -> '%s'",
-                                 PrettyDescriptor(referrer).c_str(),
-                                 PrettyDescriptor(klass).c_str());
+        ThrowNewIllegalAccessErrorClass(self, referrer, klass);
         return NULL;  // Failure
       }
     }
@@ -945,11 +991,9 @@
     return NULL;  // Failure - Indicate to caller to deliver exception
   }
   // Perform access check if necessary.
-  if (verify_access && !referrer->GetDeclaringClass()->CanAccess(klass)) {
-    self->ThrowNewExceptionF("Ljava/lang/IllegalAccessError;",
-                             "Class %s is inaccessible to method %s",
-                             PrettyDescriptor(klass).c_str(),
-                             PrettyMethod(referrer, true).c_str());
+  Class* referring_class = referrer->GetDeclaringClass();
+  if (verify_access && UNLIKELY(!referring_class->CanAccess(klass))) {
+    ThrowNewIllegalAccessErrorClass(self, referring_class, klass);
     return NULL;  // Failure - Indicate to caller to deliver exception
   }
   // If we're just implementing const-class, we shouldn't call <clinit>.
@@ -960,7 +1004,7 @@
   //
   // Do not set the DexCache InitializedStaticStorage, since that implies <clinit> has finished
   // running.
-  if (klass == referrer->GetDeclaringClass() && MethodHelper(referrer).IsClassInitializer()) {
+  if (klass == referring_class && MethodHelper(referrer).IsClassInitializer()) {
     return klass;
   }
   if (!class_linker->EnsureInitialized(klass, true)) {
@@ -1086,46 +1130,177 @@
   return 0;  // Success
 }
 
-// See comments in runtime_support_asm.S
-extern "C" uint64_t artFindInterfaceMethodInCacheFromCode(uint32_t method_idx,
-                                                          Object* this_object,
-                                                          Method* caller_method,
-                                                          Thread* thread, Method** sp) {
-  Method* interface_method = caller_method->GetDexCacheResolvedMethods()->Get(method_idx);
-  Method* found_method = NULL;  // The found method
-  if (LIKELY(interface_method != NULL && this_object != NULL)) {
-    found_method = this_object->GetClass()->FindVirtualMethodForInterface(interface_method, false);
+// Fast path method resolution that can't throw exceptions. Returns NULL whenever the slow path
+// (FindMethodFromCode) must be taken instead: NULL receiver, unresolved method, possible
+// illegal access, or a dispatch that cannot be completed without reporting an error.
+static Method* FindMethodFast(uint32_t method_idx, Object* this_object, const Method* referrer,
+                              bool access_check, bool is_interface, bool is_super) {
+  if (UNLIKELY(this_object == NULL)) {
+    return NULL;
   }
-  if (UNLIKELY(found_method == NULL)) {
-    FinishCalleeSaveFrameSetup(thread, sp, Runtime::kRefsAndArgs);
-    if (this_object == NULL) {
-      thread->ThrowNewExceptionF("Ljava/lang/NullPointerException;",
-          "null receiver during interface dispatch");
-      return 0;
+  Method* resolved_method =
+      referrer->GetDeclaringClass()->GetDexCache()->GetResolvedMethod(method_idx);
+  if (UNLIKELY(resolved_method == NULL)) {
+    return NULL;
+  }
+  if (access_check) {
+    Class* methods_class = resolved_method->GetDeclaringClass();
+    Class* referring_class = referrer->GetDeclaringClass();
+    if (UNLIKELY(!referring_class->CanAccess(methods_class) ||
+                 !referring_class->CanAccessMember(methods_class,
+                                                   resolved_method->GetAccessFlags()))) {
+      // potential illegal access
+      return NULL;
     }
-    if (interface_method == NULL) {
-      ClassLinker* class_linker = Runtime::Current()->GetClassLinker();
-      interface_method = class_linker->ResolveMethod(method_idx, caller_method, false);
-      if (interface_method == NULL) {
-        // Could not resolve interface method. Throw error and unwind
-        CHECK(thread->IsExceptionPending());
-        return 0;
+  }
+  if (is_interface) {
+    return this_object->GetClass()->FindVirtualMethodForInterface(resolved_method);
+  }
+  // Virtual or super dispatch: check for a missing super class or an out-of-range vtable index
+  // here and return NULL, leaving the error report to the slow path.
+  Class* dispatch_class = is_super ? referrer->GetDeclaringClass()->GetSuperClass()
+                                   : this_object->GetClass();
+  ObjectArray<Method>* vtable = (dispatch_class != NULL) ? dispatch_class->GetVTable() : NULL;
+  uint16_t vtable_index = resolved_method->GetMethodIndex();
+  if (UNLIKELY(vtable == NULL || vtable_index >= static_cast<uint32_t>(vtable->GetLength()))) {
+    return NULL;
+  }
+  return vtable->GetWithoutChecks(vtable_index);
+}
+
+// Slow path method resolution. Resolves 'method_idx' (which may throw), performs the access
+// checks deferred from verification when 'access_check' is set, then dispatches on
+// 'this_object', which must be non-NULL (artInvokeCommon rejects NULL receivers first).
+// Returns the method to invoke, or NULL with an exception pending on 'self'.
+static Method* FindMethodFromCode(uint32_t method_idx, Object* this_object, const Method* referrer,
+                                  Thread* self, bool access_check, bool is_interface,
+                                  bool is_super) {
+  ClassLinker* class_linker = Runtime::Current()->GetClassLinker();
+  Method* resolved_method = class_linker->ResolveMethod(method_idx, referrer, false);
+  if (UNLIKELY(resolved_method == NULL)) {
+    DCHECK(self->IsExceptionPending());  // Throw exception and unwind
+    return NULL;
+  }
+  Class* referring_class = referrer->GetDeclaringClass();
+  if (access_check) {
+    Class* methods_class = resolved_method->GetDeclaringClass();
+    if (UNLIKELY(!referring_class->CanAccess(methods_class) ||
+                 !referring_class->CanAccessMember(methods_class,
+                                                   resolved_method->GetAccessFlags()))) {
+      // The referring class can't access the resolved method, this may occur as a result of a
+      // protected method being made public by implementing an interface that re-declares the
+      // method public. Resort to the dex file to determine the correct class for the access check.
+      const DexFile& dex_file = class_linker->FindDexFile(referring_class->GetDexCache());
+      methods_class = class_linker->ResolveType(dex_file,
+                                                dex_file.GetMethodId(method_idx).class_idx_,
+                                                referring_class);
+      if (UNLIKELY(methods_class == NULL)) {
+        // Type resolution failed; presumably an exception is now pending (artInvokeCommon
+        // CHECKs that one is whenever this function returns NULL).
+        return NULL;  // failure
+      } else if (UNLIKELY(!referring_class->CanAccess(methods_class))) {
+        ThrowNewIllegalAccessErrorClassForMethodDispatch(self, referring_class, methods_class,
+                                                         referrer, resolved_method, is_interface,
+                                                         is_super);
+        return NULL;  // failure
+      } else if (UNLIKELY(!referring_class->CanAccessMember(methods_class,
+                                                            resolved_method->GetAccessFlags()))) {
+        ThrowNewIllegalAccessErrorMethod(self, referring_class, resolved_method);
+        return NULL;  // failure
       }
     }
-    found_method = this_object->GetClass()->FindVirtualMethodForInterface(interface_method, true);
-    if (found_method == NULL) {
-      CHECK(thread->IsExceptionPending());
-      return 0;
+  }
+  if (is_interface) {
+    Method* interface_method =
+        this_object->GetClass()->FindVirtualMethodForInterface(resolved_method);
+    if (UNLIKELY(interface_method == NULL)) {
+      ThrowNewIncompatibleClassChangeErrorClassForInterfaceDispatch(self, referrer,
+                                                                    resolved_method,
+                                                                    this_object);
+      return NULL;  // failure
+    }
+    return interface_method;
+  }
+  // Virtual or super dispatch through the vtable. A missing super class or an out-of-range
+  // vtable index means the wrong form of dispatch was used.
+  Class* dispatch_class = is_super ? referring_class->GetSuperClass() : this_object->GetClass();
+  ObjectArray<Method>* vtable = (dispatch_class != NULL) ? dispatch_class->GetVTable() : NULL;
+  uint16_t vtable_index = resolved_method->GetMethodIndex();
+  if (UNLIKELY(vtable == NULL || vtable_index >= static_cast<uint32_t>(vtable->GetLength()))) {
+    // Behavior to agree with that of the verifier
+    self->ThrowNewExceptionF("Ljava/lang/NoSuchMethodError;",
+                             "attempt to invoke %s method '%s' from '%s'"
+                             " using incorrect form of method dispatch",
+                             (is_super ? "super class" : "virtual"),
+                             PrettyMethod(resolved_method).c_str(),
+                             PrettyMethod(referrer).c_str());
+    return NULL;  // failure
+  }
+  return vtable->GetWithoutChecks(vtable_index);
+}
+
+// Shared body of the invoke trampolines below. Tries FindMethodFast; on a miss it completes the
+// callee-save frame, then throws NullPointerException for a NULL receiver or defers to
+// FindMethodFromCode. Returns the method in the low 32 bits and its code pointer in the high
+// 32 bits, or 0 with an exception pending on 'self'.
+static uint64_t artInvokeCommon(uint32_t method_idx, Object* this_object, Method* caller_method,
+                                Thread* self, Method** sp, bool access_check, bool is_interface,
+                                bool is_super) {
+  Method* method = FindMethodFast(method_idx, this_object, caller_method, access_check,
+                                  is_interface, is_super);
+  if (UNLIKELY(method == NULL)) {
+    FinishCalleeSaveFrameSetup(self, sp, Runtime::kRefsAndArgs);
+    if (UNLIKELY(this_object == NULL)) {
+      ThrowNullPointerExceptionForMethodAccess(self, caller_method, method_idx, is_interface,
+                                               is_super);
+      return 0;  // failure
+    }
+    method = FindMethodFromCode(method_idx, this_object, caller_method, self, access_check,
+                                is_interface, is_super);
+    if (UNLIKELY(method == NULL)) {
+      CHECK(self->IsExceptionPending());
+      return 0;  // failure
     }
   }
-  const void* code = found_method->GetCode();
+  // TODO: DCHECK
+  CHECK(!self->IsExceptionPending());
+  const void* code = method->GetCode();
 
-  uint32_t method_uint = reinterpret_cast<uint32_t>(found_method);
+  uint32_t method_uint = reinterpret_cast<uint32_t>(method);
   uint64_t code_uint = reinterpret_cast<uint32_t>(code);
   uint64_t result = ((code_uint << 32) | method_uint);
   return result;
 }
 
+// See comments in runtime_support_asm.S
+extern "C" uint64_t artInvokeInterfaceTrampoline(uint32_t method_idx, Object* this_object,
+                                                 Method* caller_method, Thread* self,
+                                                 Method** sp) {
+  return artInvokeCommon(method_idx, this_object, caller_method, self, sp, false, true, false);
+}
+
+// The WithAccessCheck variants perform the access checks that verification deferred to runtime.
+extern "C" uint64_t artInvokeInterfaceTrampolineWithAccessCheck(uint32_t method_idx,
+                                                                Object* this_object,
+                                                                Method* caller_method, Thread* self,
+                                                                Method** sp) {
+  return artInvokeCommon(method_idx, this_object, caller_method, self, sp, true, true, false);
+}
+
+extern "C" uint64_t artInvokeSuperTrampolineWithAccessCheck(uint32_t method_idx,
+                                                            Object* this_object,
+                                                            Method* caller_method, Thread* self,
+                                                            Method** sp) {
+  return artInvokeCommon(method_idx, this_object, caller_method, self, sp, true, false, true);
+}
+
+extern "C" uint64_t artInvokeVirtualTrampolineWithAccessCheck(uint32_t method_idx,
+                                                              Object* this_object,
+                                                              Method* caller_method, Thread* self,
+                                                              Method** sp) {
+  return artInvokeCommon(method_idx, this_object, caller_method, self, sp, true, false, false);
+}
+
 static void ThrowNewUndeclaredThrowableException(Thread* self, JNIEnv* env, Throwable* exception) {
   ScopedLocalRef<jclass> jlr_UTE_class(env,
       env->FindClass("java/lang/reflect/UndeclaredThrowableException"));