Make class loaders weak roots

Making the class loaders weak roots in the class linker prevents the
class linker from keeping them, and therefore their classes, alive.
However, we currently still mark them as strong roots to make sure no
accidental class unloading occurs until the logic to free memory from
linear alloc is complete.

Bug: 22720414

Change-Id: I57466236d9ce6fd064dda9a30ce8ab68094fb8b0
diff --git a/runtime/class_linker.cc b/runtime/class_linker.cc
index 5f2c944..73da2cb 100644
--- a/runtime/class_linker.cc
+++ b/runtime/class_linker.cc
@@ -1295,7 +1295,8 @@
 }
 
 void ClassLinker::VisitClassRoots(RootVisitor* visitor, VisitRootFlags flags) {
-  WriterMutexLock mu(Thread::Current(), *Locks::classlinker_classes_lock_);
+  Thread* const self = Thread::Current();  // Reused below by DecodeWeakGlobal.
+  WriterMutexLock mu(self, *Locks::classlinker_classes_lock_);
   BufferedRootVisitor<kDefaultBufferedRootCount> buffered_visitor(
       visitor, RootInfo(kRootStickyClass));
   if ((flags & kVisitRootFlagAllRoots) != 0) {
@@ -1315,9 +1316,13 @@
     // Need to make sure to not copy ArtMethods without doing read barriers since the roots are
     // marked concurrently and we don't hold the classlinker_classes_lock_ when we do the copy.
     boot_class_table_.VisitRoots(buffered_visitor);
-    for (GcRoot<mirror::ClassLoader>& root : class_loaders_) {
-      // May be null for boot ClassLoader.
-      root.VisitRoot(visitor, RootInfo(kRootVMInternal));
+    // TODO: Stop marking class loaders strongly here so that class unloading can happen.
+    JavaVMExt* const vm = Runtime::Current()->GetJavaVM();
+    for (jweak weak_root : class_loaders_) {
+      mirror::Object* class_loader =
+          down_cast<mirror::ClassLoader*>(vm->DecodeWeakGlobal(self, weak_root));
+      // Visiting a local copy is enough: SweepSystemWeaks updates the class loaders.
+      visitor->VisitRootIfNonNull(&class_loader, RootInfo(kRootVMInternal));
     }
   } else if ((flags & kVisitRootFlagNewRoots) != 0) {
     for (auto& root : new_class_roots_) {
@@ -1353,14 +1358,31 @@
   }
 }
 
+class VisitClassLoaderClassesVisitor : public ClassLoaderVisitor {
+ public:
+  explicit VisitClassLoaderClassesVisitor(ClassVisitor* visitor)
+      : visitor_(visitor),
+        done_(false) {}
+
+  void Visit(mirror::ClassLoader* class_loader)
+      SHARED_REQUIRES(Locks::classlinker_classes_lock_, Locks::mutator_lock_) OVERRIDE {
+    ClassTable* const class_table = class_loader->GetClassTable();
+    if (!done_ && class_table != nullptr && !class_table->Visit(visitor_)) {
+      // ClassTable::Visit returned false, so the visitor wants no more classes.
+      done_ = true;
+    }
+  }
+
+ private:
+  ClassVisitor* const visitor_;
+  // Set once a ClassTable::Visit returns false; later loaders' tables are then not visited.
+  bool done_;
+};
+
 void ClassLinker::VisitClassesInternal(ClassVisitor* visitor) {
   if (boot_class_table_.Visit(visitor)) {
-    for (GcRoot<mirror::ClassLoader>& root : class_loaders_) {
-      ClassTable* const class_table = root.Read()->GetClassTable();
-      if (class_table != nullptr && !class_table->Visit(visitor)) {
-        return;
-      }
-    }
+    VisitClassLoaderClassesVisitor loader_visitor(visitor);
+    VisitClassLoaders(&loader_visitor);
   }
 }
 
@@ -1479,10 +1501,17 @@
   mirror::LongArray::ResetArrayClass();
   mirror::ShortArray::ResetArrayClass();
   STLDeleteElements(&oat_files_);
-  for (GcRoot<mirror::ClassLoader>& root : class_loaders_) {
-    ClassTable* const class_table = root.Read()->GetClassTable();
-    delete class_table;
+  Thread* const self = Thread::Current();
+  JavaVMExt* const vm = Runtime::Current()->GetJavaVM();
+  for (jweak weak_root : class_loaders_) {
+    auto* const class_loader = down_cast<mirror::ClassLoader*>(
+        vm->DecodeWeakGlobal(self, weak_root));
+    if (class_loader != nullptr) {  // NOTE(review): a cleared loader's ClassTable is not freed.
+      delete class_loader->GetClassTable();
+    }
+    vm->DeleteWeakGlobalRef(self, weak_root);
   }
+  class_loaders_.clear();  // Every weak global ref it held was deleted above.
 }
 
 mirror::PointerArray* ClassLinker::AllocPointerArray(Thread* self, size_t length) {
@@ -2611,8 +2640,8 @@
                                                   bool allow_failure) {
   // Search assuming unique-ness of dex file.
   JavaVMExt* const vm = self->GetJniEnv()->vm;
-  for (jobject weak_root : dex_caches_) {
+  for (jweak weak_root : dex_caches_) {
     DCHECK_EQ(GetIndirectRefKind(weak_root), kWeakGlobal);
     mirror::DexCache* dex_cache = down_cast<mirror::DexCache*>(
         vm->DecodeWeakGlobal(self, weak_root));
     if (dex_cache != nullptr && dex_cache->GetDexFile() == &dex_file) {
@@ -2985,15 +3013,25 @@
   dex_cache_image_class_lookup_required_ = false;
 }
 
-void ClassLinker::MoveClassTableToPreZygote() {
-  WriterMutexLock mu(Thread::Current(), *Locks::classlinker_classes_lock_);
-  boot_class_table_.FreezeSnapshot();
-  for (GcRoot<mirror::ClassLoader>& root : class_loaders_) {
-    ClassTable* const class_table = root.Read()->GetClassTable();
+// Calls FreezeSnapshot() on the class table of every visited class loader that has one.
+class MoveClassTableToPreZygoteVisitor : public ClassLoaderVisitor {
+ public:
+  void Visit(mirror::ClassLoader* class_loader)
+      REQUIRES(Locks::classlinker_classes_lock_)
+      SHARED_REQUIRES(Locks::mutator_lock_) OVERRIDE {
+    ClassTable* const class_table = class_loader->GetClassTable();
     if (class_table != nullptr) {
       class_table->FreezeSnapshot();
     }
   }
+};
+
+void ClassLinker::MoveClassTableToPreZygote() {
+  // Exclusive lock: the visit below also erases cleared class loaders from class_loaders_.
+  WriterMutexLock mu(Thread::Current(), *Locks::classlinker_classes_lock_);
+  boot_class_table_.FreezeSnapshot();
+  MoveClassTableToPreZygoteVisitor visitor;
+  VisitClassLoadersAndRemoveClearedLoaders(&visitor);
 }
 
 mirror::Class* ClassLinker::LookupClassFromImage(const char* descriptor) {
@@ -3019,25 +3057,48 @@
   return nullptr;
 }
 
+// Looks up classes by hash and descriptor and appends every match to the result vector.
+class LookupClassesVisitor : public ClassLoaderVisitor {
+ public:
+  LookupClassesVisitor(const char* descriptor, size_t hash, std::vector<mirror::Class*>* result)
+      : descriptor_(descriptor),
+        hash_(hash),
+        result_(result) {}
+
+  void Visit(mirror::ClassLoader* class_loader)
+      SHARED_REQUIRES(Locks::classlinker_classes_lock_, Locks::mutator_lock_) OVERRIDE {
+    ClassTable* const class_table = class_loader->GetClassTable();
+    if (class_table == nullptr) {
+      // Skip loaders without a class table, like the other ClassLoaderVisitors do.
+      return;
+    }
+    // There can only be one class with the same descriptor per class loader.
+    mirror::Class* const klass = class_table->Lookup(descriptor_, hash_);
+    if (klass != nullptr) {
+      result_->push_back(klass);
+    }
+  }
+
+ private:
+  const char* const descriptor_;
+  const size_t hash_;
+  std::vector<mirror::Class*>* const result_;
+};
+
 void ClassLinker::LookupClasses(const char* descriptor, std::vector<mirror::Class*>& result) {
   result.clear();
   if (dex_cache_image_class_lookup_required_) {
     MoveImageClassesToClassTable();
   }
-  WriterMutexLock mu(Thread::Current(), *Locks::classlinker_classes_lock_);
+  Thread* const self = Thread::Current();
+  ReaderMutexLock mu(self, *Locks::classlinker_classes_lock_);
   const size_t hash = ComputeModifiedUtf8Hash(descriptor);
   mirror::Class* klass = boot_class_table_.Lookup(descriptor, hash);
   if (klass != nullptr) {
     result.push_back(klass);
   }
-  for (GcRoot<mirror::ClassLoader>& root : class_loaders_) {
-    // There can only be one class with the same descriptor per class loader.
-    ClassTable* const class_table = root.Read()->GetClassTable();
-    klass = class_table->Lookup(descriptor, hash);
-    if (klass != nullptr) {
-      result.push_back(klass);
-    }
-  }
+  LookupClassesVisitor visitor(descriptor, hash, &result);
+  VisitClassLoaders(&visitor);
 }

 void ClassLinker::VerifyClass(Thread* self, Handle<mirror::Class> klass) {
@@ -4109,7 +4165,8 @@
   ClassTable* class_table = class_loader->GetClassTable();
   if (class_table == nullptr) {
     class_table = new ClassTable;
-    class_loaders_.push_back(class_loader);
+    Thread* const self = Thread::Current();  // The loader is recorded as a weak global ref.
+    class_loaders_.push_back(self->GetJniEnv()->vm->AddWeakGlobalRef(self, class_loader));
     // Don't already have a class table, add it to the class loader.
     class_loader->SetClassTable(class_table);
   }
@@ -5875,26 +5932,33 @@
      << NumNonZygoteClasses() << "\n";
 }
 
-size_t ClassLinker::NumZygoteClasses() const {
-  size_t sum = boot_class_table_.NumZygoteClasses();
-  for (const GcRoot<mirror::ClassLoader>& root : class_loaders_) {
-    ClassTable* const class_table = root.Read()->GetClassTable();
+class CountClassesVisitor : public ClassLoaderVisitor {  // Sums per-loader class counts.
+ public:
+  CountClassesVisitor() : num_zygote_classes(0), num_non_zygote_classes(0) {}
+
+  void Visit(mirror::ClassLoader* class_loader)
+      SHARED_REQUIRES(Locks::classlinker_classes_lock_, Locks::mutator_lock_) OVERRIDE {
+    ClassTable* const class_table = class_loader->GetClassTable();
     if (class_table != nullptr) {
-      sum += class_table->NumZygoteClasses();
+      num_zygote_classes += class_table->NumZygoteClasses();
+      num_non_zygote_classes += class_table->NumNonZygoteClasses();
     }
   }
-  return sum;
+
+  size_t num_zygote_classes;  // Excludes the boot class table.
+  size_t num_non_zygote_classes;  // Excludes the boot class table.
+};
+
+size_t ClassLinker::NumZygoteClasses() const {
+  CountClassesVisitor visitor;
+  VisitClassLoaders(&visitor);
+  return visitor.num_zygote_classes + boot_class_table_.NumZygoteClasses();
 }
 
 size_t ClassLinker::NumNonZygoteClasses() const {
-  size_t sum = boot_class_table_.NumNonZygoteClasses();
-  for (const GcRoot<mirror::ClassLoader>& root : class_loaders_) {
-    ClassTable* const class_table = root.Read()->GetClassTable();
-    if (class_table != nullptr) {
-      sum += class_table->NumNonZygoteClasses();
-    }
-  }
-  return sum;
+  CountClassesVisitor visitor;
+  VisitClassLoaders(&visitor);
+  return visitor.num_non_zygote_classes + boot_class_table_.NumNonZygoteClasses();
 }
 
 size_t ClassLinker::NumLoadedClasses() {
@@ -6107,4 +6171,35 @@
   find_array_class_cache_next_victim_ = 0;
 }
 
+void ClassLinker::VisitClassLoadersAndRemoveClearedLoaders(ClassLoaderVisitor* visitor) {
+  Thread* const self = Thread::Current();
+  Locks::classlinker_classes_lock_->AssertExclusiveHeld(self);  // Erasing mutates the list.
+  JavaVMExt* const vm = self->GetJniEnv()->vm;
+  for (auto it = class_loaders_.begin(); it != class_loaders_.end();) {
+    const jweak weak_root = *it;
+    mirror::ClassLoader* const class_loader = down_cast<mirror::ClassLoader*>(
+        vm->DecodeWeakGlobal(self, weak_root));
+    if (class_loader != nullptr) {
+      visitor->Visit(class_loader);
+      ++it;
+    } else {
+      // Weak ref was cleared: delete the weak global and erase it from class_loaders_.
+      vm->DeleteWeakGlobalRef(self, weak_root);
+      it = class_loaders_.erase(it);
+    }
+  }
+}
+
+void ClassLinker::VisitClassLoaders(ClassLoaderVisitor* visitor) const {
+  Thread* const self = Thread::Current();
+  JavaVMExt* const vm = self->GetJniEnv()->vm;
+  for (jweak weak_root : class_loaders_) {
+    mirror::ClassLoader* const class_loader = down_cast<mirror::ClassLoader*>(
+        vm->DecodeWeakGlobal(self, weak_root));
+    if (class_loader != nullptr) {  // Cleared entries are skipped here, not erased.
+      visitor->Visit(class_loader);
+    }
+  }
+}
+
 }  // namespace art