Add support for changing roots through the root visitor callback.

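Root visitors now return the (possibly relocated) object, and each
caller writes the returned pointer back into the root it visited.
This is needed for copying collectors, which move objects and must
update the roots that refer to them.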

Change-Id: Icc4a342a57e0cfb79587edb02ef8c85e08808877
diff --git a/runtime/thread.cc b/runtime/thread.cc
index a454195..d7d4b1f 100644
--- a/runtime/thread.cc
+++ b/runtime/thread.cc
@@ -1012,9 +1012,10 @@
   }
 }
 
-static void MonitorExitVisitor(const mirror::Object* object, void* arg) NO_THREAD_SAFETY_ANALYSIS {
+static mirror::Object* MonitorExitVisitor(mirror::Object* object, void* arg)
+    NO_THREAD_SAFETY_ANALYSIS {
   Thread* self = reinterpret_cast<Thread*>(arg);
-  mirror::Object* entered_monitor = const_cast<mirror::Object*>(object);
+  mirror::Object* entered_monitor = object;
   if (self->HoldsLock(entered_monitor)) {
     LOG(WARNING) << "Calling MonitorExit on object "
                  << object << " (" << PrettyTypeOf(object) << ")"
@@ -1022,6 +1023,7 @@
                  << *Thread::Current() << " which is detaching";
     entered_monitor->MonitorExit(self);
   }
+  return object;
 }
 
 void Thread::Destroy() {
@@ -1151,8 +1153,12 @@
     size_t num_refs = cur->NumberOfReferences();
     for (size_t j = 0; j < num_refs; j++) {
       mirror::Object* object = cur->GetReference(j);
-      if (object != NULL) {
-        visitor(object, arg);
+      if (object != nullptr) {
+        const mirror::Object* new_obj = visitor(object, arg);
+        DCHECK(new_obj != nullptr);
+        if (new_obj != object) {
+          cur->SetReference(j, const_cast<mirror::Object*>(new_obj));
+        }
       }
     }
   }
@@ -2019,8 +2025,11 @@
         // SIRT for JNI or References for interpreter.
         for (size_t reg = 0; reg < num_regs; ++reg) {
           mirror::Object* ref = shadow_frame->GetVRegReference(reg);
-          if (ref != NULL) {
-            visitor_(ref, reg, this);
+          if (ref != nullptr) {
+            mirror::Object* new_ref = visitor_(ref, reg, this);
+            if (new_ref != ref) {
+             shadow_frame->SetVRegReference(reg, new_ref);
+            }
           }
         }
       } else {
@@ -2040,8 +2049,11 @@
         for (size_t reg = 0; reg < num_regs; ++reg) {
           if (TestBitmap(reg, reg_bitmap)) {
             mirror::Object* ref = shadow_frame->GetVRegReference(reg);
-            if (ref != NULL) {
-              visitor_(ref, reg, this);
+            if (ref != nullptr) {
+              mirror::Object* new_ref = visitor_(ref, reg, this);
+              if (new_ref != ref) {
+               shadow_frame->SetVRegReference(reg, new_ref);
+              }
             }
           }
         }
@@ -2072,19 +2084,25 @@
             // Does this register hold a reference?
             if (TestBitmap(reg, reg_bitmap)) {
               uint32_t vmap_offset;
-              mirror::Object* ref;
               if (vmap_table.IsInContext(reg, kReferenceVReg, &vmap_offset)) {
-                uintptr_t val = GetGPR(vmap_table.ComputeRegister(core_spills, vmap_offset,
-                                                                  kReferenceVReg));
-                ref = reinterpret_cast<mirror::Object*>(val);
+                int vmap_reg = vmap_table.ComputeRegister(core_spills, vmap_offset, kReferenceVReg);
+                mirror::Object* ref = reinterpret_cast<mirror::Object*>(GetGPR(vmap_reg));
+                if (ref != nullptr) {
+                  mirror::Object* new_ref = visitor_(ref, reg, this);
+                  if (ref != new_ref) {
+                    SetGPR(vmap_reg, reinterpret_cast<uintptr_t>(new_ref));
+                  }
+                }
               } else {
-                ref = reinterpret_cast<mirror::Object*>(GetVReg(cur_quick_frame, code_item,
-                                                                core_spills, fp_spills, frame_size,
-                                                                reg));
-              }
-
-              if (ref != NULL) {
-                visitor_(ref, reg, this);
+                uint32_t* reg_addr =
+                    GetVRegAddr(cur_quick_frame, code_item, core_spills, fp_spills, frame_size, reg);
+                mirror::Object* ref = reinterpret_cast<mirror::Object*>(*reg_addr);
+                if (ref != nullptr) {
+                  mirror::Object* new_ref = visitor_(ref, reg, this);
+                  if (ref != new_ref) {
+                    *reg_addr = reinterpret_cast<uint32_t>(new_ref);
+                  }
+                }
               }
             }
           }
@@ -2110,8 +2128,8 @@
  public:
   RootCallbackVisitor(RootVisitor* visitor, void* arg) : visitor_(visitor), arg_(arg) {}
 
-  void operator()(const mirror::Object* obj, size_t, const StackVisitor*) const {
-    visitor_(obj, arg_);
+  mirror::Object* operator()(mirror::Object* obj, size_t, const StackVisitor*) const {
+    return visitor_(obj, arg_);
   }
 
  private:
@@ -2135,67 +2153,17 @@
   void* const arg_;
 };
 
-struct VerifyRootWrapperArg {
-  VerifyRootVisitor* visitor;
-  void* arg;
-};
-
-static void VerifyRootWrapperCallback(const mirror::Object* root, void* arg) {
-  VerifyRootWrapperArg* wrapperArg = reinterpret_cast<VerifyRootWrapperArg*>(arg);
-  wrapperArg->visitor(root, wrapperArg->arg, 0, NULL);
-}
-
-void Thread::VerifyRoots(VerifyRootVisitor* visitor, void* arg) {
-  // We need to map from a RootVisitor to VerifyRootVisitor, so pass in nulls for arguments we
-  // don't have.
-  VerifyRootWrapperArg wrapperArg;
-  wrapperArg.arg = arg;
-  wrapperArg.visitor = visitor;
-
-  if (opeer_ != NULL) {
-    VerifyRootWrapperCallback(opeer_, &wrapperArg);
-  }
-  if (exception_ != NULL) {
-    VerifyRootWrapperCallback(exception_, &wrapperArg);
-  }
-  throw_location_.VisitRoots(VerifyRootWrapperCallback, &wrapperArg);
-  if (class_loader_override_ != NULL) {
-    VerifyRootWrapperCallback(class_loader_override_, &wrapperArg);
-  }
-  jni_env_->locals.VisitRoots(VerifyRootWrapperCallback, &wrapperArg);
-  jni_env_->monitors.VisitRoots(VerifyRootWrapperCallback, &wrapperArg);
-
-  SirtVisitRoots(VerifyRootWrapperCallback, &wrapperArg);
-
-  // Visit roots on this thread's stack
-  Context* context = GetLongJumpContext();
-  VerifyCallbackVisitor visitorToCallback(visitor, arg);
-  ReferenceMapVisitor<VerifyCallbackVisitor> mapper(this, context, visitorToCallback);
-  mapper.WalkStack();
-  ReleaseLongJumpContext(context);
-
-  std::deque<instrumentation::InstrumentationStackFrame>* instrumentation_stack = GetInstrumentationStack();
-  typedef std::deque<instrumentation::InstrumentationStackFrame>::const_iterator It;
-  for (It it = instrumentation_stack->begin(), end = instrumentation_stack->end(); it != end; ++it) {
-    mirror::Object* this_object = (*it).this_object_;
-    if (this_object != NULL) {
-      VerifyRootWrapperCallback(this_object, &wrapperArg);
-    }
-    mirror::ArtMethod* method = (*it).method_;
-    VerifyRootWrapperCallback(method, &wrapperArg);
-  }
-}
-
 void Thread::VisitRoots(RootVisitor* visitor, void* arg) {
-  if (opeer_ != NULL) {
-    visitor(opeer_, arg);
+  if (opeer_ != nullptr) {
+    opeer_ = visitor(opeer_, arg);
   }
-  if (exception_ != NULL) {
-    visitor(exception_, arg);
+  if (exception_ != nullptr) {
+    exception_ = reinterpret_cast<mirror::Throwable*>(visitor(exception_, arg));
   }
   throw_location_.VisitRoots(visitor, arg);
-  if (class_loader_override_ != NULL) {
-    visitor(class_loader_override_, arg);
+  if (class_loader_override_ != nullptr) {
+    class_loader_override_ = reinterpret_cast<mirror::ClassLoader*>(
+        visitor(class_loader_override_, arg));
   }
   jni_env_->locals.VisitRoots(visitor, arg);
   jni_env_->monitors.VisitRoots(visitor, arg);
@@ -2209,24 +2177,26 @@
   mapper.WalkStack();
   ReleaseLongJumpContext(context);
 
-  for (const instrumentation::InstrumentationStackFrame& frame : *GetInstrumentationStack()) {
-    mirror::Object* this_object = frame.this_object_;
-    if (this_object != NULL) {
-      visitor(this_object, arg);
+  for (instrumentation::InstrumentationStackFrame& frame : *GetInstrumentationStack()) {
+    if (frame.this_object_ != nullptr) {
+      frame.this_object_ = visitor(frame.this_object_, arg);
+      DCHECK(frame.this_object_ != nullptr);
     }
-    mirror::ArtMethod* method = frame.method_;
-    visitor(method, arg);
+    frame.method_ = reinterpret_cast<mirror::ArtMethod*>(visitor(frame.method_, arg));
+    DCHECK(frame.method_ != nullptr);
   }
 }
 
-static void VerifyObject(const mirror::Object* root, void* arg) {
-  gc::Heap* heap = reinterpret_cast<gc::Heap*>(arg);
-  heap->VerifyObject(root);
+static mirror::Object* VerifyRoot(mirror::Object* root, void* arg) {
+  DCHECK(root != nullptr);
+  DCHECK(arg != nullptr);
+  reinterpret_cast<gc::Heap*>(arg)->VerifyObject(root);
+  return root;
 }
 
 void Thread::VerifyStackImpl() {
   UniquePtr<Context> context(Context::Create());
-  RootCallbackVisitor visitorToCallback(VerifyObject, Runtime::Current()->GetHeap());
+  RootCallbackVisitor visitorToCallback(VerifyRoot, Runtime::Current()->GetHeap());
   ReferenceMapVisitor<RootCallbackVisitor> mapper(this, context.get(), visitorToCallback);
   mapper.WalkStack();
 }