Reuse SIRT for C++ references

Change-Id: I8310e55da42f55f7ec60f6b17face436c77a979f
diff --git a/src/jni_compiler_test.cc b/src/jni_compiler_test.cc
index ecc2f88..938d733 100644
--- a/src/jni_compiler_test.cc
+++ b/src/jni_compiler_test.cc
@@ -21,14 +21,11 @@
 
 class JniCompilerTest : public CommonTest {
  protected:
-  virtual void SetUp() {
-    CommonTest::SetUp();
-    class_loader_ = LoadDex("MyClassNatives");
-  }
 
-  void CompileForTest(bool direct, const char* method_name, const char* method_sig) {
+  void CompileForTest(ClassLoader* class_loader, bool direct,
+                      const char* method_name, const char* method_sig) {
     // Compile the native method before starting the runtime
-    Class* c = class_linker_->FindClass("LMyClass;", class_loader_);
+    Class* c = class_linker_->FindClass("LMyClass;", class_loader);
     Method* method;
     if (direct) {
       method = c->FindDirectMethod(method_name, method_sig);
@@ -43,9 +40,10 @@
     ASSERT_TRUE(method->GetCode() != NULL);
   }
 
-  void SetupForTest(bool direct, const char* method_name, const char* method_sig,
+  void SetupForTest(ClassLoader* class_loader, bool direct,
+                    const char* method_name, const char* method_sig,
                     void* native_fnptr) {
-    CompileForTest(direct, method_name, method_sig);
+    CompileForTest(class_loader, direct, method_name, method_sig);
     if (!runtime_->IsStarted()) {
       runtime_->Start();
     }
@@ -78,7 +76,6 @@
   static jclass jklass_;
   static jobject jobj_;
  protected:
-  const ClassLoader* class_loader_;
   JNIEnv* env_;
   jmethodID jmethod_;
 };
@@ -88,7 +85,8 @@
 
 int gJava_MyClass_foo_calls = 0;
 void Java_MyClass_foo(JNIEnv* env, jobject thisObj) {
-  EXPECT_EQ(1u, Thread::Current()->NumSirtReferences());
+  // 2 = SirtRef<ClassLoader> + thisObj
+  EXPECT_EQ(2U, Thread::Current()->NumSirtReferences());
   EXPECT_EQ(Thread::kNative, Thread::Current()->GetState());
   EXPECT_EQ(Thread::Current()->GetJniEnv(), env);
   EXPECT_TRUE(thisObj != NULL);
@@ -97,7 +95,9 @@
 }
 
 TEST_F(JniCompilerTest, CompileAndRunNoArgMethod) {
-  SetupForTest(false, "foo", "()V", reinterpret_cast<void*>(&Java_MyClass_foo));
+  SirtRef<ClassLoader> class_loader(LoadDex("MyClassNatives"));
+  SetupForTest(class_loader.get(), false, "foo", "()V",
+               reinterpret_cast<void*>(&Java_MyClass_foo));
 
   EXPECT_EQ(0, gJava_MyClass_foo_calls);
   env_->CallNonvirtualVoidMethod(jobj_, jklass_, jmethod_);
@@ -107,15 +107,13 @@
 }
 
 TEST_F(JniCompilerTest, CompileAndRunIntMethodThroughStub) {
-  SetupForTest(false,
-               "bar",
-               "(I)I",
+  SirtRef<ClassLoader> class_loader(LoadDex("MyClassNatives"));
+  SetupForTest(class_loader.get(), false, "bar", "(I)I",
                NULL /* dlsym will find &Java_MyClass_bar later */);
 
   std::string path("libarttest.so");
   std::string reason;
-  ASSERT_TRUE(Runtime::Current()->GetJavaVM()->LoadNativeLibrary(
-      path, const_cast<ClassLoader*>(class_loader_), reason))
+  ASSERT_TRUE(Runtime::Current()->GetJavaVM()->LoadNativeLibrary(path, class_loader.get(), reason))
       << path << ": " << reason;
 
   jint result = env_->CallNonvirtualIntMethod(jobj_, jklass_, jmethod_, 24);
@@ -124,7 +122,8 @@
 
 int gJava_MyClass_fooI_calls = 0;
 jint Java_MyClass_fooI(JNIEnv* env, jobject thisObj, jint x) {
-  EXPECT_EQ(1u, Thread::Current()->NumSirtReferences());
+  // 2 = SirtRef<ClassLoader> + thisObj
+  EXPECT_EQ(2U, Thread::Current()->NumSirtReferences());
   EXPECT_EQ(Thread::kNative, Thread::Current()->GetState());
   EXPECT_EQ(Thread::Current()->GetJniEnv(), env);
   EXPECT_TRUE(thisObj != NULL);
@@ -134,7 +133,8 @@
 }
 
 TEST_F(JniCompilerTest, CompileAndRunIntMethod) {
-  SetupForTest(false, "fooI", "(I)I",
+  SirtRef<ClassLoader> class_loader(LoadDex("MyClassNatives"));
+  SetupForTest(class_loader.get(), false, "fooI", "(I)I",
                reinterpret_cast<void*>(&Java_MyClass_fooI));
 
   EXPECT_EQ(0, gJava_MyClass_fooI_calls);
@@ -148,7 +148,8 @@
 
 int gJava_MyClass_fooII_calls = 0;
 jint Java_MyClass_fooII(JNIEnv* env, jobject thisObj, jint x, jint y) {
-  EXPECT_EQ(1u, Thread::Current()->NumSirtReferences());
+  // 2 = SirtRef<ClassLoader> + thisObj
+  EXPECT_EQ(2U, Thread::Current()->NumSirtReferences());
   EXPECT_EQ(Thread::kNative, Thread::Current()->GetState());
   EXPECT_EQ(Thread::Current()->GetJniEnv(), env);
   EXPECT_TRUE(thisObj != NULL);
@@ -158,7 +159,8 @@
 }
 
 TEST_F(JniCompilerTest, CompileAndRunIntIntMethod) {
-  SetupForTest(false, "fooII", "(II)I",
+  SirtRef<ClassLoader> class_loader(LoadDex("MyClassNatives"));
+  SetupForTest(class_loader.get(), false, "fooII", "(II)I",
                reinterpret_cast<void*>(&Java_MyClass_fooII));
 
   EXPECT_EQ(0, gJava_MyClass_fooII_calls);
@@ -173,7 +175,8 @@
 
 int gJava_MyClass_fooJJ_calls = 0;
 jlong Java_MyClass_fooJJ(JNIEnv* env, jobject thisObj, jlong x, jlong y) {
-  EXPECT_EQ(1u, Thread::Current()->NumSirtReferences());
+  // 2 = SirtRef<ClassLoader> + thisObj
+  EXPECT_EQ(2U, Thread::Current()->NumSirtReferences());
   EXPECT_EQ(Thread::kNative, Thread::Current()->GetState());
   EXPECT_EQ(Thread::Current()->GetJniEnv(), env);
   EXPECT_TRUE(thisObj != NULL);
@@ -183,7 +186,8 @@
 }
 
 TEST_F(JniCompilerTest, CompileAndRunLongLongMethod) {
-  SetupForTest(false, "fooJJ", "(JJ)J",
+  SirtRef<ClassLoader> class_loader(LoadDex("MyClassNatives"));
+  SetupForTest(class_loader.get(), false, "fooJJ", "(JJ)J",
                reinterpret_cast<void*>(&Java_MyClass_fooJJ));
 
   EXPECT_EQ(0, gJava_MyClass_fooJJ_calls);
@@ -199,7 +203,8 @@
 
 int gJava_MyClass_fooDD_calls = 0;
 jdouble Java_MyClass_fooDD(JNIEnv* env, jobject thisObj, jdouble x, jdouble y) {
-  EXPECT_EQ(1u, Thread::Current()->NumSirtReferences());
+  // 2 = SirtRef<ClassLoader> + thisObj
+  EXPECT_EQ(2U, Thread::Current()->NumSirtReferences());
   EXPECT_EQ(Thread::kNative, Thread::Current()->GetState());
   EXPECT_EQ(Thread::Current()->GetJniEnv(), env);
   EXPECT_TRUE(thisObj != NULL);
@@ -209,7 +214,8 @@
 }
 
 TEST_F(JniCompilerTest, CompileAndRunDoubleDoubleMethod) {
-  SetupForTest(false, "fooDD", "(DD)D",
+  SirtRef<ClassLoader> class_loader(LoadDex("MyClassNatives"));
+  SetupForTest(class_loader.get(), false, "fooDD", "(DD)D",
                reinterpret_cast<void*>(&Java_MyClass_fooDD));
 
   EXPECT_EQ(0, gJava_MyClass_fooDD_calls);
@@ -227,7 +233,8 @@
 int gJava_MyClass_fooIOO_calls = 0;
 jobject Java_MyClass_fooIOO(JNIEnv* env, jobject thisObj, jint x, jobject y,
                             jobject z) {
-  EXPECT_EQ(3u, Thread::Current()->NumSirtReferences());
+  // 4 = SirtRef<ClassLoader> + this + y + z
+  EXPECT_EQ(4U, Thread::Current()->NumSirtReferences());
   EXPECT_EQ(Thread::kNative, Thread::Current()->GetState());
   EXPECT_EQ(Thread::Current()->GetJniEnv(), env);
   EXPECT_TRUE(thisObj != NULL);
@@ -244,7 +251,8 @@
 }
 
 TEST_F(JniCompilerTest, CompileAndRunIntObjectObjectMethod) {
-  SetupForTest(false, "fooIOO",
+  SirtRef<ClassLoader> class_loader(LoadDex("MyClassNatives"));
+  SetupForTest(class_loader.get(), false, "fooIOO",
                "(ILjava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;",
                reinterpret_cast<void*>(&Java_MyClass_fooIOO));
 
@@ -276,7 +284,8 @@
 
 int gJava_MyClass_fooSII_calls = 0;
 jint Java_MyClass_fooSII(JNIEnv* env, jclass klass, jint x, jint y) {
-  EXPECT_EQ(1u, Thread::Current()->NumSirtReferences());
+  // 2 = SirtRef<ClassLoader> + klass
+  EXPECT_EQ(2U, Thread::Current()->NumSirtReferences());
   EXPECT_EQ(Thread::kNative, Thread::Current()->GetState());
   EXPECT_EQ(Thread::Current()->GetJniEnv(), env);
   EXPECT_TRUE(klass != NULL);
@@ -286,8 +295,8 @@
 }
 
 TEST_F(JniCompilerTest, CompileAndRunStaticIntIntMethod) {
-  SetupForTest(true, "fooSII",
-               "(II)I",
+  SirtRef<ClassLoader> class_loader(LoadDex("MyClassNatives"));
+  SetupForTest(class_loader.get(), true, "fooSII", "(II)I",
                reinterpret_cast<void*>(&Java_MyClass_fooSII));
 
   EXPECT_EQ(0, gJava_MyClass_fooSII_calls);
@@ -298,7 +307,8 @@
 
 int gJava_MyClass_fooSDD_calls = 0;
 jdouble Java_MyClass_fooSDD(JNIEnv* env, jclass klass, jdouble x, jdouble y) {
-  EXPECT_EQ(1u, Thread::Current()->NumSirtReferences());
+  // 2 = SirtRef<ClassLoader> + klass
+  EXPECT_EQ(2U, Thread::Current()->NumSirtReferences());
   EXPECT_EQ(Thread::kNative, Thread::Current()->GetState());
   EXPECT_EQ(Thread::Current()->GetJniEnv(), env);
   EXPECT_TRUE(klass != NULL);
@@ -308,7 +318,8 @@
 }
 
 TEST_F(JniCompilerTest, CompileAndRunStaticDoubleDoubleMethod) {
-  SetupForTest(true, "fooSDD", "(DD)D",
+  SirtRef<ClassLoader> class_loader(LoadDex("MyClassNatives"));
+  SetupForTest(class_loader.get(), true, "fooSDD", "(DD)D",
                reinterpret_cast<void*>(&Java_MyClass_fooSDD));
 
   EXPECT_EQ(0, gJava_MyClass_fooSDD_calls);
@@ -325,7 +336,8 @@
 int gJava_MyClass_fooSIOO_calls = 0;
 jobject Java_MyClass_fooSIOO(JNIEnv* env, jclass klass, jint x, jobject y,
                              jobject z) {
-  EXPECT_EQ(3u, Thread::Current()->NumSirtReferences());
+  // 4 = SirtRef<ClassLoader> + klass + y + z
+  EXPECT_EQ(4U, Thread::Current()->NumSirtReferences());
   EXPECT_EQ(Thread::kNative, Thread::Current()->GetState());
   EXPECT_EQ(Thread::Current()->GetJniEnv(), env);
   EXPECT_TRUE(klass != NULL);
@@ -343,7 +355,8 @@
 
 
 TEST_F(JniCompilerTest, CompileAndRunStaticIntObjectObjectMethod) {
-  SetupForTest(true, "fooSIOO",
+  SirtRef<ClassLoader> class_loader(LoadDex("MyClassNatives"));
+  SetupForTest(class_loader.get(), true, "fooSIOO",
                "(ILjava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;",
                reinterpret_cast<void*>(&Java_MyClass_fooSIOO));
 
@@ -376,7 +389,8 @@
 int gJava_MyClass_fooSSIOO_calls = 0;
 jobject Java_MyClass_fooSSIOO(JNIEnv* env, jclass klass, jint x, jobject y,
                              jobject z) {
-  EXPECT_EQ(3u, Thread::Current()->NumSirtReferences());
+  // 4 = SirtRef<ClassLoader> + klass + y + z
+  EXPECT_EQ(4U, Thread::Current()->NumSirtReferences());
   EXPECT_EQ(Thread::kNative, Thread::Current()->GetState());
   EXPECT_EQ(Thread::Current()->GetJniEnv(), env);
   EXPECT_TRUE(klass != NULL);
@@ -393,7 +407,8 @@
 }
 
 TEST_F(JniCompilerTest, CompileAndRunStaticSynchronizedIntObjectObjectMethod) {
-  SetupForTest(true, "fooSSIOO",
+  SirtRef<ClassLoader> class_loader(LoadDex("MyClassNatives"));
+  SetupForTest(class_loader.get(), true, "fooSSIOO",
                "(ILjava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;",
                reinterpret_cast<void*>(&Java_MyClass_fooSSIOO));
 
@@ -429,22 +444,25 @@
 }
 
 TEST_F(JniCompilerTest, ExceptionHandling) {
+  SirtRef<ClassLoader> class_loader(LoadDex("MyClassNatives"));
+
   // all compilation needs to happen before SetupForTest calls Runtime::Start
-  CompileForTest(false, "foo", "()V");
-  CompileForTest(false, "throwException", "()V");
-  CompileForTest(false, "foo", "()V");
+  CompileForTest(class_loader.get(), false, "foo", "()V");
+  CompileForTest(class_loader.get(), false, "throwException", "()V");
+  CompileForTest(class_loader.get(), false, "foo", "()V");
 
   gJava_MyClass_foo_calls = 0;
 
   // Check a single call of a JNI method is ok
-  SetupForTest(false, "foo", "()V", reinterpret_cast<void*>(&Java_MyClass_foo));
+  SetupForTest(class_loader.get(), false, "foo", "()V", reinterpret_cast<void*>(&Java_MyClass_foo));
   env_->CallNonvirtualVoidMethod(jobj_, jklass_, jmethod_);
   EXPECT_EQ(1, gJava_MyClass_foo_calls);
   EXPECT_FALSE(Thread::Current()->IsExceptionPending());
 
   // Get class for exception we expect to be thrown
-  Class* jlre = class_linker_->FindClass("Ljava/lang/RuntimeException;", class_loader_);
-  SetupForTest(false, "throwException", "()V", reinterpret_cast<void*>(&Java_MyClass_throwException));
+  Class* jlre = class_linker_->FindClass("Ljava/lang/RuntimeException;", class_loader.get());
+  SetupForTest(class_loader.get(), false, "throwException", "()V",
+               reinterpret_cast<void*>(&Java_MyClass_throwException));
   // Call Java_MyClass_throwException (JNI method that throws exception)
   env_->CallNonvirtualVoidMethod(jobj_, jklass_, jmethod_);
   EXPECT_EQ(1, gJava_MyClass_foo_calls);
@@ -453,7 +471,7 @@
   Thread::Current()->ClearException();
 
   // Check a single call of a JNI method is ok
-  SetupForTest(false, "foo", "()V", reinterpret_cast<void*>(&Java_MyClass_foo));
+  SetupForTest(class_loader.get(), false, "foo", "()V", reinterpret_cast<void*>(&Java_MyClass_foo));
   env_->CallNonvirtualVoidMethod(jobj_, jklass_, jmethod_);
   EXPECT_EQ(2, gJava_MyClass_foo_calls);
 }
@@ -497,7 +515,9 @@
 }
 
 TEST_F(JniCompilerTest, NativeStackTraceElement) {
-  SetupForTest(false, "fooI", "(I)I", reinterpret_cast<void*>(&Java_MyClass_nativeUpCall));
+  SirtRef<ClassLoader> class_loader(LoadDex("MyClassNatives"));
+  SetupForTest(class_loader.get(), false, "fooI", "(I)I",
+               reinterpret_cast<void*>(&Java_MyClass_nativeUpCall));
   jint result = env_->CallNonvirtualIntMethod(jobj_, jklass_, jmethod_, 10);
   EXPECT_EQ(10+9+8+7+6+5+4+3+2+1, result);
 }
@@ -507,7 +527,8 @@
 }
 
 TEST_F(JniCompilerTest, ReturnGlobalRef) {
-  SetupForTest(false, "fooO", "(Ljava/lang/Object;)Ljava/lang/Object;",
+  SirtRef<ClassLoader> class_loader(LoadDex("MyClassNatives"));
+  SetupForTest(class_loader.get(), false, "fooO", "(Ljava/lang/Object;)Ljava/lang/Object;",
                reinterpret_cast<void*>(&Java_MyClass_fooO));
   jobject result = env_->CallNonvirtualObjectMethod(jobj_, jklass_, jmethod_, jobj_);
   EXPECT_EQ(JNILocalRefType, env_->GetObjectRefType(result));
@@ -523,7 +544,8 @@
 }
 
 TEST_F(JniCompilerTest, LocalReferenceTableClearingTest) {
-  SetupForTest(false, "fooI", "(I)I", reinterpret_cast<void*>(&local_ref_test));
+  SirtRef<ClassLoader> class_loader(LoadDex("MyClassNatives"));
+  SetupForTest(class_loader.get(), false, "fooI", "(I)I", reinterpret_cast<void*>(&local_ref_test));
   // 1000 invocations of a method that adds 10 local references
   for (int i=0; i < 1000; i++) {
     jint result = env_->CallIntMethod(jobj_, jmethod_, i);
@@ -541,7 +563,8 @@
 }
 
 TEST_F(JniCompilerTest, JavaLangSystemArrayCopy) {
-  SetupForTest(true, "arraycopy", "(Ljava/lang/Object;ILjava/lang/Object;II)V",
+  SirtRef<ClassLoader> class_loader(LoadDex("MyClassNatives"));
+  SetupForTest(class_loader.get(), true, "arraycopy", "(Ljava/lang/Object;ILjava/lang/Object;II)V",
                reinterpret_cast<void*>(&my_arraycopy));
   env_->CallStaticVoidMethod(jklass_, jmethod_, jobj_, 1234, jklass_, 5678, 9876);
 }
@@ -556,7 +579,8 @@
 }
 
 TEST_F(JniCompilerTest, CompareAndSwapInt) {
-  SetupForTest(false, "compareAndSwapInt", "(Ljava/lang/Object;JII)Z",
+  SirtRef<ClassLoader> class_loader(LoadDex("MyClassNatives"));
+  SetupForTest(class_loader.get(), false, "compareAndSwapInt", "(Ljava/lang/Object;JII)Z",
                reinterpret_cast<void*>(&my_casi));
   jboolean result = env_->CallBooleanMethod(jobj_, jmethod_, jobj_, 0x12345678ABCDEF88ll, 0xCAFEF00D, 0xEBADF00D);
   EXPECT_EQ(result, JNI_TRUE);