Fix JNI thread state transitions.

Thread state transitions need correct memory fencing: when transitioning to
kRunnable, the state store must be fenced from the suspend-count poll that
follows it, and when leaving kRunnable a fence must precede the state store.
This change introduces the fences in the JNI compiler's generated code and
makes the behaviour match that of Thread::SetState.
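
For reference, the barrier discipline the new ChangeThreadState helper
mirrors, written out as plain C++. This is a sketch of the intended ordering
only; the Thread struct, MemoryBarrierSketch() and the field names state_ and
suspend_count_ are illustrative stand-ins, not the exact runtime members:

    #include <atomic>

    // Minimal stand-ins for the sketch only.
    struct Thread {
      enum State { kRunnable, kNative };
      std::atomic<State> state_;
      std::atomic<int> suspend_count_;
      void (*pCheckSuspendFromCode)(Thread*);
    };

    inline void MemoryBarrierSketch() {
      std::atomic_thread_fence(std::memory_order_seq_cst);
    }

    void ChangeThreadStateSketch(Thread* self, Thread::State new_state) {
      if (new_state == Thread::kRunnable) {
        // Becoming runnable: publish the state first, then fence, so the
        // suspend-count poll cannot be reordered before the store and a
        // suspension requested by another thread (e.g. for GC) is not missed.
        self->state_.store(Thread::kRunnable, std::memory_order_relaxed);
        MemoryBarrierSketch();
        if (self->suspend_count_.load(std::memory_order_relaxed) != 0) {
          self->pCheckSuspendFromCode(self);  // slow path: block until resumed
        }
      } else {
        // Leaving runnable: fence first, then store, so prior memory accesses
        // are visible before other threads see us as "asleep".
        MemoryBarrierSketch();
        self->state_.store(new_state, std::memory_order_relaxed);
      }
    }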

Change-Id: Ia0ff68e2493ae153cf24d251e610b02b3f39d93e
diff --git a/src/assembler.h b/src/assembler.h
index 4d085e0..0808e06 100644
--- a/src/assembler.h
+++ b/src/assembler.h
@@ -360,6 +360,8 @@
   virtual void Copy(FrameOffset dest, FrameOffset src, ManagedRegister scratch,
                     unsigned int size) = 0;
 
+  virtual void MemoryBarrier(ManagedRegister scratch) = 0;
+
   // Exploit fast access in managed code to Thread::Current()
   virtual void GetCurrentThread(ManagedRegister tr) = 0;
   virtual void GetCurrentThread(FrameOffset dest_offset,
diff --git a/src/assembler_arm.cc b/src/assembler_arm.cc
index 93d08b2..6d51867 100644
--- a/src/assembler_arm.cc
+++ b/src/assembler_arm.cc
@@ -1667,6 +1667,24 @@
   }
 }
 
+void ArmAssembler::MemoryBarrier(ManagedRegister mscratch) {
+#if ANDROID_SMP != 0
+#if defined(__ARM_HAVE_DMB)
+  int32_t encoding = 0xf57ff05f;  // dmb
+  Emit(encoding);
+#elif defined(__ARM_HAVE_LDREX_STREX)
+  CHECK(mscratch.AsArm().AsCoreRegister() == R12);
+  LoadImmediate(R12, 0);
+  int32_t encoding = 0xee07cfba;  // mcr p15, 0, r12, c7, c10, 5
+  Emit(encoding);
+#else
+  CHECK(mscratch.AsArm().AsCoreRegister() == R12);
+  LoadImmediate(R12, 0xffff0fa0);  // kuser_memory_barrier
+  blx(R12);
+#endif
+#endif
+}
+
 void ArmAssembler::CreateSirtEntry(ManagedRegister mout_reg,
                                    FrameOffset sirt_offset,
                                    ManagedRegister min_reg, bool null_allowed) {
@@ -1796,10 +1814,9 @@
   __ Bind(&entry_);
   // Save return value
   __ Store(return_save_location_, return_register_, return_size_);
-  // Pass top of stack as argument
-  __ mov(R0, ShifterOperand(SP));
-  __ LoadFromOffset(kLoadWord, R12, TR,
-                         Thread::SuspendCountEntryPointOffset().Int32Value());
+  // Pass thread as argument
+  __ mov(R0, ShifterOperand(TR));
+  __ LoadFromOffset(kLoadWord, R12, TR, OFFSETOF_MEMBER(Thread, pCheckSuspendFromCode));
   // Note: assume that link register will be spilled/filled on method entry/exit
   __ blx(R12);
   // Reload return value
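
The three fallbacks in ArmAssembler::MemoryBarrier above correspond roughly
to the following compiler-level forms (a sketch for orientation only; the
assembler emits raw encodings because generated code cannot depend on how the
runtime itself was compiled):

    // ARMv7 (__ARM_HAVE_DMB): a dedicated data memory barrier instruction.
    __asm__ __volatile__("dmb" ::: "memory");

    // ARMv6 (__ARM_HAVE_LDREX_STREX): DMB spelled as a CP15 write to
    // c7/c10/5; the source register value should be zero.
    int zero = 0;
    __asm__ __volatile__("mcr p15, 0, %0, c7, c10, 5" :: "r"(zero) : "memory");

    // Older kernels/CPUs: the Linux kuser memory-barrier helper at a fixed
    // address in the vector page.
    (reinterpret_cast<void (*)()>(0xffff0fa0))();
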
diff --git a/src/assembler_arm.h b/src/assembler_arm.h
index 1a722f7..577fd55 100644
--- a/src/assembler_arm.h
+++ b/src/assembler_arm.h
@@ -481,6 +481,8 @@
   virtual void Copy(FrameOffset dest, FrameOffset src, ManagedRegister scratch,
                     unsigned int size);
 
+  virtual void MemoryBarrier(ManagedRegister scratch);
+
   // Exploit fast access in managed code to Thread::Current()
   virtual void GetCurrentThread(ManagedRegister tr);
   virtual void GetCurrentThread(FrameOffset dest_offset,
diff --git a/src/assembler_x86.cc b/src/assembler_x86.cc
index e126d88..2aae7de 100644
--- a/src/assembler_x86.cc
+++ b/src/assembler_x86.cc
@@ -1585,6 +1585,14 @@
   }
 }
 
+void X86Assembler::MemoryBarrier(ManagedRegister) {
+#if ANDROID_SMP != 0
+  EmitUint8(0x0F);  // mfence is encoded NP 0F AE /6
+  EmitUint8(0xAE);
+  EmitOperand(6, Operand(EAX));  // reg field 6 selects MFENCE; rm (EAX) is ignored
+#endif
+}
+
 void X86Assembler::CreateSirtEntry(ManagedRegister mout_reg,
                                    FrameOffset sirt_offset,
                                    ManagedRegister min_reg, bool null_allowed) {
@@ -1699,9 +1707,9 @@
   __ Bind(&entry_);
   // Save return value
   __ Store(return_save_location_, return_register_, return_size_);
-  // Pass top of stack as argument
-  __ pushl(ESP);
-  __ fs()->call(Address::Absolute(Thread::SuspendCountEntryPointOffset()));
+  // Pass Thread::Current as argument
+  __ fs()->pushl(Address::Absolute(Thread::SelfOffset()));
+  __ fs()->call(Address::Absolute(OFFSETOF_MEMBER(Thread, pCheckSuspendFromCode)));
   // Release argument
   __ addl(ESP, Immediate(kPointerSize));
   // Reload return value
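
A note on the mfence bytes above: mfence is encoded NP 0F AE /6, so the ModRM
byte must carry 6 in its reg field; the rm field is ignored, and using EAX
there yields the canonical 0F AE F0. A standalone check of that arithmetic
(not part of the patch):

    #include <cassert>
    #include <cstdint>

    int main() {
      // ModRM = mod (2 bits) | reg (3 bits) | rm (3 bits).
      const uint8_t mod = 3;  // register-direct addressing
      const uint8_t reg = 6;  // /6 selects MFENCE within the 0F AE group
      const uint8_t rm = 0;   // EAX; ignored by MFENCE
      const uint8_t modrm = static_cast<uint8_t>((mod << 6) | (reg << 3) | rm);
      assert(modrm == 0xF0);  // NP 0F AE F0 == mfence
      return 0;
    }
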
diff --git a/src/assembler_x86.h b/src/assembler_x86.h
index 86069be..7152e14 100644
--- a/src/assembler_x86.h
+++ b/src/assembler_x86.h
@@ -505,6 +505,8 @@
   virtual void Copy(FrameOffset dest, FrameOffset src, ManagedRegister scratch,
                     size_t size);
 
+  virtual void MemoryBarrier(ManagedRegister);
+
   // Exploit fast access in managed code to Thread::Current()
   virtual void GetCurrentThread(ManagedRegister tr);
   virtual void GetCurrentThread(FrameOffset dest_offset,
diff --git a/src/jni_compiler.cc b/src/jni_compiler.cc
index 4d454cf..b072833 100644
--- a/src/jni_compiler.cc
+++ b/src/jni_compiler.cc
@@ -140,12 +140,12 @@
   // 4. Transition from being in managed to native code. Save the top_of_managed_stack_
   // so that the managed stack can be crawled while in native code. Clear the corresponding
   // PC value that has no meaning for this frame.
-  // TODO: ensure the transition to native follow a store fence.
   __ StoreStackPointerToThread(Thread::TopOfManagedStackOffset());
   __ StoreImmediateToThread(Thread::TopOfManagedStackPcOffset(), 0,
                             mr_conv->InterproceduralScratchRegister());
-  __ StoreImmediateToThread(Thread::StateOffset(), Thread::kNative,
-                            mr_conv->InterproceduralScratchRegister());
+  ChangeThreadState(jni_asm.get(), Thread::kNative,
+                    mr_conv->InterproceduralScratchRegister(),
+                    ManagedRegister::NoRegister(), FrameOffset(0), 0);
 
   // 5. Move frame down to allow space for outgoing args. Do this for as
   //    short a time as possible to aid profiling.
@@ -337,21 +337,21 @@
 
   // 12. Transition from being in native to managed code, possibly entering a
   //     safepoint
-  CHECK(!jni_conv->InterproceduralScratchRegister()
-        .Equals(jni_conv->ReturnRegister()));  // don't clobber result
+  // Don't clobber result
+  CHECK(!jni_conv->InterproceduralScratchRegister().Equals(jni_conv->ReturnRegister()));
   // Location to preserve result on slow path, ensuring it's within the frame
   FrameOffset return_save_location = jni_conv->ReturnValueSaveLocation();
   CHECK(return_save_location.Uint32Value() < frame_size ||
         jni_conv->SizeOfReturnValue() == 0);
-  __ SuspendPoll(jni_conv->InterproceduralScratchRegister(),
-                 jni_conv->ReturnRegister(), return_save_location,
-                 jni_conv->SizeOfReturnValue());
+  ChangeThreadState(jni_asm.get(), Thread::kRunnable,
+                    jni_conv->InterproceduralScratchRegister(),
+                    jni_conv->ReturnRegister(), return_save_location,
+                    jni_conv->SizeOfReturnValue());
+
+  // 13. Check for a pending exception and forward it if present
   __ ExceptionPoll(jni_conv->InterproceduralScratchRegister());
-  __ StoreImmediateToThread(Thread::StateOffset(), Thread::kRunnable,
-                            jni_conv->InterproceduralScratchRegister());
 
-
-  // 13. Place result in correct register possibly loading from indirect
+  // 14. Place the result in the correct register, possibly loading from the indirect
   //     reference table
   if (jni_conv->IsReturnAReference()) {
     __ IncreaseFrameSize(out_arg_size);
@@ -381,18 +381,18 @@
   }
   __ Move(mr_conv->ReturnRegister(), jni_conv->ReturnRegister());
 
-  // 14. Remove SIRT from thread
+  // 15. Remove SIRT from thread
   __ CopyRawPtrToThread(Thread::TopSirtOffset(), jni_conv->SirtLinkOffset(),
                         jni_conv->InterproceduralScratchRegister());
 
-  // 15. Remove activation
+  // 16. Remove activation
   if (native_method->IsSynchronized()) {
     __ RemoveFrame(frame_size, callee_save_regs);
   } else {
     __ RemoveFrame(frame_size, std::vector<ManagedRegister>());
   }
 
-  // 16. Finalize code generation
+  // 17. Finalize code generation
   __ EmitSlowPaths();
   size_t cs = __ CodeSize();
   ByteArray* managed_code = ByteArray::Alloc(cs);
@@ -520,4 +520,33 @@
 #undef __
 }
 
+void JniCompiler::ChangeThreadState(Assembler* jni_asm, Thread::State new_state,
+                                    ManagedRegister scratch, ManagedRegister return_reg,
+                                    FrameOffset return_save_location,
+                                    size_t return_size) {
+  /*
+   * This code mirrors Thread::SetState, which explains in detail why the
+   * barriers occur where they do.
+   */
+#define __ jni_asm->
+  if (new_state == Thread::kRunnable) {
+    /*
+     * Change our status to Thread::kRunnable.  The transition requires
+     * that we check for pending suspension, because the VM considers
+     * us to be "asleep" in all other states, and another thread could
+     * be performing a GC now.
+     */
+    __ StoreImmediateToThread(Thread::StateOffset(), Thread::kRunnable, scratch);
+    __ MemoryBarrier(scratch);
+    __ SuspendPoll(scratch, return_reg, return_save_location, return_size);
+  } else {
+    /*
+     * Not transitioning to Thread::kRunnable, so no suspend poll is needed.
+     * The barrier ensures our prior memory accesses complete before other
+     * threads see us as "asleep".
+     */
+    __ MemoryBarrier(scratch);
+    __ StoreImmediateToThread(Thread::StateOffset(), new_state, scratch);
+  }
+#undef __
+}
+
 }  // namespace art
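
The SuspendPoll slow paths (see the assembler_arm.cc and assembler_x86.cc
hunks above) now pass the thread itself rather than the top of stack, and
call through Thread::pCheckSuspendFromCode. What they do, as plain C++;
SaveReturnValue and ReloadReturnValue stand in for the per-architecture
store/load of the result around the call:

    struct Thread {
      void (*pCheckSuspendFromCode)(Thread*);
    };
    void SaveReturnValue();    // stand-in: __ Store(return_save_location_, ...)
    void ReloadReturnValue();  // stand-in: reload result before the fast path

    void SuspendPollSlowPathSketch(Thread* thread) {
      SaveReturnValue();
      // Pass the Thread* itself; previously the top of stack (Method**) was
      // passed to the suspend-count entry point.
      thread->pCheckSuspendFromCode(thread);  // may block until resumed
      ReloadReturnValue();
    }
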
diff --git a/src/jni_compiler.h b/src/jni_compiler.h
index d78404b..086d96e 100644
--- a/src/jni_compiler.h
+++ b/src/jni_compiler.h
@@ -40,6 +40,11 @@
                           JniCallingConvention* jni_conv,
                           ManagedRegister in_reg);
 
+  void ChangeThreadState(Assembler* jni_asm, Thread::State new_state,
+                         ManagedRegister scratch, ManagedRegister return_reg,
+                         FrameOffset return_save_location,
+                         size_t return_size);
+
   // Architecture to generate code for
   InstructionSet instruction_set_;
 
diff --git a/src/jni_compiler_test.cc b/src/jni_compiler_test.cc
index 35a76c9..482ddcb 100644
--- a/src/jni_compiler_test.cc
+++ b/src/jni_compiler_test.cc
@@ -414,40 +414,6 @@
   EXPECT_EQ(7, gJava_MyClass_fooSSIOO_calls);
 }
 
-// TODO: this is broken now we have thread suspend implemented.
-int gSuspendCounterHandler_calls;
-void SuspendCountHandler(Method** frame) {
-  // Check we came here in the native state then transition to runnable to work
-  // on the Object*
-  EXPECT_EQ(Thread::kNative, Thread::Current()->GetState());
-  ScopedJniThreadState ts(Thread::Current()->GetJniEnv());
-
-  EXPECT_TRUE((*frame)->GetName()->Equals("fooI"));
-  gSuspendCounterHandler_calls++;
-  //Thread::Current()->DecrementSuspendCount();
-}
-
-TEST_F(JniCompilerTest, DISABLED_SuspendCountAcknowledgement) {
-  SetupForTest(false, "fooI", "(I)I",
-               reinterpret_cast<void*>(&Java_MyClass_fooI));
-  Thread::Current()->RegisterSuspendCountEntryPoint(&SuspendCountHandler);
-
-  gJava_MyClass_fooI_calls = 0;
-  jint result = env_->CallNonvirtualIntMethod(jobj_, jklass_, jmethod_, 42);
-  EXPECT_EQ(42, result);
-  EXPECT_EQ(1, gJava_MyClass_fooI_calls);
-  EXPECT_EQ(0, gSuspendCounterHandler_calls);
-  //Thread::Current()->IncrementSuspendCount();
-  result = env_->CallNonvirtualIntMethod(jobj_, jklass_, jmethod_, 42);
-  EXPECT_EQ(42, result);
-  EXPECT_EQ(2, gJava_MyClass_fooI_calls);
-  EXPECT_EQ(1, gSuspendCounterHandler_calls);
-  result = env_->CallNonvirtualIntMethod(jobj_, jklass_, jmethod_, 42);
-  EXPECT_EQ(42, result);
-  EXPECT_EQ(3, gJava_MyClass_fooI_calls);
-  EXPECT_EQ(1, gSuspendCounterHandler_calls);
-}
-
 void Java_MyClass_throwException(JNIEnv* env, jobject) {
   jclass c = env->FindClass("java/lang/RuntimeException");
   env->ThrowNew(c, "hello");
diff --git a/src/thread.h b/src/thread.h
index de65d42..6f1bcc3 100644
--- a/src/thread.h
+++ b/src/thread.h
@@ -422,10 +422,6 @@
     NotifyLocked();
   }
 
-  void RegisterSuspendCountEntryPoint(void (*handler)(Method**)) {
-    suspend_count_entry_point_ = handler;
-  }
-
   // Linked list recording transitions from native to managed code
   void PushNativeToManagedRecord(NativeToManagedRecord* record) {
     record->last_top_of_managed_stack_ = reinterpret_cast<void*>(top_of_managed_stack_.GetSP());
@@ -509,10 +505,6 @@
     return ThreadOffset(OFFSETOF_MEMBER(Thread, top_sirt_));
   }
 
-  static ThreadOffset SuspendCountEntryPointOffset() {
-    return ThreadOffset(OFFSETOF_MEMBER(Thread, suspend_count_entry_point_));
-  }
-
  private:
   Thread();
   ~Thread();
@@ -622,9 +614,6 @@
   // TLS key used to retrieve the VM thread object.
   static pthread_key_t pthread_key_self_;
 
-  // Entry point called when suspend_count_ is non-zero
-  void (*suspend_count_entry_point_)(Method** frame);
-
   DISALLOW_COPY_AND_ASSIGN(Thread);
 };
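
With suspend_count_entry_point_ and its registration hook removed, generated
code reaches the suspend handler only through the pCheckSuspendFromCode entry
point used by the slow paths above. A plausible shape for that entry point,
with a hypothetical name and helper for illustration only:

    // Hypothetical signature; the actual runtime entry point may differ.
    struct Thread {
      static Thread* Current();
      void FullSuspendCheck();  // hypothetical: block while suspended
    };

    extern "C" void artCheckSuspendFromCode(Thread* thread) {
      // Called from generated code after a transition to kRunnable when the
      // suspend count is non-zero; returns once the thread is resumed.
      thread->FullSuspendCheck();
    }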