ART: Change boot image class update

Compute the transitive closure of the image classes by walking object
references starting from the already known image classes, instead of
visiting every object on the heap.

Bug: 17632031
Change-Id: I66a2793517fc1380f20e12d624a76eed6975832c
diff --git a/compiler/driver/compiler_driver.cc b/compiler/driver/compiler_driver.cc
index e1b5984..aa4789b 100644
--- a/compiler/driver/compiler_driver.cc
+++ b/compiler/driver/compiler_driver.cc
@@ -19,6 +19,7 @@
 #define ATRACE_TAG ATRACE_TAG_DALVIK
 #include <utils/Trace.h>
 
+#include <unordered_set>
 #include <vector>
 #include <unistd.h>
 
@@ -53,6 +54,7 @@
 #include "ScopedLocalRef.h"
 #include "handle_scope-inl.h"
 #include "thread.h"
+#include "thread_list.h"
 #include "thread_pool.h"
 #include "trampolines/trampoline_compiler.h"
 #include "transaction.h"
@@ -773,23 +775,141 @@
   }
 }
 
-void CompilerDriver::FindClinitImageClassesCallback(mirror::Object* object, void* arg) {
-  DCHECK(object != nullptr);
-  DCHECK(arg != nullptr);
-  CompilerDriver* compiler_driver = reinterpret_cast<CompilerDriver*>(arg);
-  StackHandleScope<1> hs(Thread::Current());
-  MaybeAddToImageClasses(hs.NewHandle(object->GetClass()), compiler_driver->image_classes_.get());
-}
+// Bundles all state for the image class update and also serves as the reference visitor.
+// Note: raw mirror::Object pointers can be stored here because the caller suspends all threads.
+class ClinitImageUpdate {
+ public:
+  static ClinitImageUpdate* Create(std::set<std::string>* image_class_descriptors, Thread* self,
+                                   ClassLinker* linker, std::string* error_msg) {
+    std::unique_ptr<ClinitImageUpdate> res(new ClinitImageUpdate(image_class_descriptors, self,
+                                                                 linker));
+    if (res->art_method_class_ == nullptr) {
+      *error_msg = "Could not find ArtMethod class.";
+      return nullptr;
+    } else if (res->dex_cache_class_ == nullptr) {
+      *error_msg = "Could not find DexCache class.";
+      return nullptr;
+    }
+
+    return res.release();
+  }
+
+  ~ClinitImageUpdate() {
+    // Allow others to suspend again.
+    self_->EndAssertNoThreadSuspension(old_cause_);
+  }
+
+  // Visitor for VisitReferences.
+  void operator()(mirror::Object* object, MemberOffset field_offset, bool /* is_static */) const
+      SHARED_LOCKS_REQUIRED(Locks::mutator_lock_) {
+    mirror::Object* ref = object->GetFieldObject<mirror::Object>(field_offset);
+    if (ref != nullptr) {
+      VisitClinitClassesObject(ref);
+    }
+  }
+
+  // java.lang.Reference visitor for VisitReferences.
+  void operator()(mirror::Class* /*klass*/, mirror::Reference* ref) const {
+  }
+
+  void Walk() SHARED_LOCKS_REQUIRED(Locks::mutator_lock_) {
+    // Use the initial classes as roots for a search.
+    for (mirror::Class* klass_root : image_classes_) {
+      VisitClinitClassesObject(klass_root);
+    }
+  }
+
+ private:
+  ClinitImageUpdate(std::set<std::string>* image_class_descriptors, Thread* self,
+                    ClassLinker* linker)
+      SHARED_LOCKS_REQUIRED(Locks::mutator_lock_) :
+      image_class_descriptors_(image_class_descriptors), self_(self) {
+    CHECK(linker != nullptr);
+    CHECK(image_class_descriptors != nullptr);
+
+    // Make sure nobody interferes with us.
+    old_cause_ = self->StartAssertNoThreadSuspension("Boot image closure");
+
+    // Find the interesting classes.
+    art_method_class_ = linker->LookupClass(self, "Ljava/lang/reflect/ArtMethod;", nullptr);
+    dex_cache_class_ = linker->LookupClass(self, "Ljava/lang/DexCache;", nullptr);
+
+    // Find all the already-marked classes.
+    WriterMutexLock mu(self, *Locks::heap_bitmap_lock_);
+    linker->VisitClasses(FindImageClasses, this);
+  }
+
+  static bool FindImageClasses(mirror::Class* klass, void* arg)
+      SHARED_LOCKS_REQUIRED(Locks::mutator_lock_) {
+    ClinitImageUpdate* data = reinterpret_cast<ClinitImageUpdate*>(arg);
+    std::string temp;
+    const char* name = klass->GetDescriptor(&temp);
+    if (data->image_class_descriptors_->find(name) != data->image_class_descriptors_->end()) {
+      data->image_classes_.push_back(klass);
+    }
+
+    return true;
+  }
+
+  void VisitClinitClassesObject(mirror::Object* object) const
+      SHARED_LOCKS_REQUIRED(Locks::mutator_lock_) {
+    DCHECK(object != nullptr);
+    if (marked_objects_.find(object) != marked_objects_.end()) {
+      // Already processed.
+      return;
+    }
+
+    // Mark it.
+    marked_objects_.insert(object);
+
+    if (object->IsClass()) {
+      // If it is a class, add it.
+      StackHandleScope<1> hs(self_);
+      MaybeAddToImageClasses(hs.NewHandle(object->AsClass()), image_class_descriptors_);
+    } else {
+      // Else visit the object's class.
+      VisitClinitClassesObject(object->GetClass());
+    }
+
+    // If it is not a dex cache or an ArtMethod, visit all references.
+    mirror::Class* klass = object->GetClass();
+    if (klass != art_method_class_ && klass != dex_cache_class_) {
+      object->VisitReferences<false /* visit class */>(*this, *this);
+    }
+  }
+
+  mutable std::unordered_set<mirror::Object*> marked_objects_;
+  std::set<std::string>* const image_class_descriptors_;
+  std::vector<mirror::Class*> image_classes_;
+  const mirror::Class* art_method_class_;
+  const mirror::Class* dex_cache_class_;
+  Thread* const self_;
+  const char* old_cause_;
+
+  DISALLOW_COPY_AND_ASSIGN(ClinitImageUpdate);
+};
 
 void CompilerDriver::UpdateImageClasses(TimingLogger* timings) {
   if (IsImage()) {
     TimingLogger::ScopedTiming t("UpdateImageClasses", timings);
-    // Update image_classes_ with classes for objects created by <clinit> methods.
-    gc::Heap* heap = Runtime::Current()->GetHeap();
-    // TODO: Image spaces only?
-    ScopedObjectAccess soa(Thread::Current());
-    WriterMutexLock mu(soa.Self(), *Locks::heap_bitmap_lock_);
-    heap->VisitObjects(FindClinitImageClassesCallback, this);
+
+    Runtime* current = Runtime::Current();
+
+    // Suspend all threads.
+    current->GetThreadList()->SuspendAll();
+
+    std::string error_msg;
+    std::unique_ptr<ClinitImageUpdate> update(ClinitImageUpdate::Create(image_classes_.get(),
+                                                                        Thread::Current(),
+                                                                        current->GetClassLinker(),
+                                                                        &error_msg));
+    CHECK(update.get() != nullptr) << error_msg;  // TODO: Soft failure?
+
+    // Do the marking.
+    update->Walk();
+
+    // Resume threads.
+    current->GetThreadList()->ResumeAll();
   }
 }