Merge "ART: Add reference following code"
diff --git a/runtime/openjdkjvmti/OpenjdkJvmTi.cc b/runtime/openjdkjvmti/OpenjdkJvmTi.cc
index d9031ea..9d4b554 100644
--- a/runtime/openjdkjvmti/OpenjdkJvmTi.cc
+++ b/runtime/openjdkjvmti/OpenjdkJvmTi.cc
@@ -277,7 +277,13 @@
                                      jobject initial_object,
                                      const jvmtiHeapCallbacks* callbacks,
                                      const void* user_data) {
-    return ERR(NOT_IMPLEMENTED);
+    HeapUtil heap_util(&gObjectTagTable);
+    return heap_util.FollowReferences(env,
+                                      heap_filter,
+                                      klass,
+                                      initial_object,
+                                      callbacks,
+                                      user_data);
   }
 
   static jvmtiError IterateThroughHeap(jvmtiEnv* env,
diff --git a/runtime/openjdkjvmti/object_tagging.h b/runtime/openjdkjvmti/object_tagging.h
index 997cedb..0296f1a 100644
--- a/runtime/openjdkjvmti/object_tagging.h
+++ b/runtime/openjdkjvmti/object_tagging.h
@@ -34,7 +34,7 @@
 class ObjectTagTable : public art::gc::SystemWeakHolder {
  public:
   explicit ObjectTagTable(EventHandler* event_handler)
-      : art::gc::SystemWeakHolder(art::LockLevel::kAllocTrackerLock),
+      : art::gc::SystemWeakHolder(kTaggingLockLevel),
         update_since_last_sweep_(false),
         event_handler_(event_handler) {
   }
@@ -180,6 +180,10 @@
     }
   };
 
+  // The tag table is used while visiting roots, so it needs a low lock level (kAbortLock + 1).
+  static constexpr art::LockLevel kTaggingLockLevel =
+      static_cast<art::LockLevel>(art::LockLevel::kAbortLock + 1);
+
   std::unordered_map<art::GcRoot<art::mirror::Object>,
                      jlong,
                      HashGcRoot,
diff --git a/runtime/openjdkjvmti/ti_heap.cc b/runtime/openjdkjvmti/ti_heap.cc
index 6b20743..0eff469 100644
--- a/runtime/openjdkjvmti/ti_heap.cc
+++ b/runtime/openjdkjvmti/ti_heap.cc
@@ -16,19 +16,25 @@
 
 #include "ti_heap.h"
 
+#include "art_field-inl.h"
 #include "art_jvmti.h"
 #include "base/macros.h"
 #include "base/mutex.h"
 #include "class_linker.h"
 #include "gc/heap.h"
+#include "gc_root-inl.h"
 #include "jni_env_ext.h"
+#include "jni_internal.h"
 #include "mirror/class.h"
+#include "mirror/object-inl.h"
+#include "mirror/object_array-inl.h"
 #include "object_callbacks.h"
 #include "object_tagging.h"
 #include "obj_ptr-inl.h"
 #include "runtime.h"
 #include "scoped_thread_state_change-inl.h"
 #include "thread-inl.h"
+#include "thread_list.h"
 
 namespace openjdkjvmti {
 
@@ -165,6 +171,466 @@
   return ERR(NONE);
 }
 
+class FollowReferencesHelper FINAL {
+ public:
+  FollowReferencesHelper(HeapUtil* h,
+                         art::ObjPtr<art::mirror::Object> initial_object ATTRIBUTE_UNUSED,
+                         const jvmtiHeapCallbacks* callbacks,
+                         const void* user_data)
+      : tag_table_(h->GetTags()),
+        callbacks_(callbacks),
+        user_data_(user_data),
+        start_(0),
+        stop_reports_(false) {
+  }
+
+  void Init()
+      REQUIRES_SHARED(art::Locks::mutator_lock_)
+      REQUIRES(!*tag_table_->GetAllowDisallowLock()) {
+    CollectAndReportRootsVisitor carrv(this, tag_table_, &worklist_, &visited_);
+    art::Runtime::Current()->VisitRoots(&carrv);  // carrv reports each root, queues new ones.
+    art::Runtime::Current()->VisitImageRoots(&carrv);
+    stop_reports_ = carrv.IsStopReports();
+    // If a root callback requested an abort, drop the queued roots so Work() finds nothing.
+    if (stop_reports_) {
+      worklist_.clear();
+    }
+  }
+
+  void Work()
+      REQUIRES_SHARED(art::Locks::mutator_lock_)
+      REQUIRES(!*tag_table_->GetAllowDisallowLock()) {
+    // Breadth-first traversal over worklist_. start_ indexes the next unprocessed entry; the
+    // consumed prefix is only erased once it reaches kMaxStart entries, so that the vector is
+    // not shifted on every pop. (A DFS using the work list as a stack would be an alternative.)
+    // Visiting an object may append further objects to worklist_.
+    while (start_ < worklist_.size()) {
+      art::mirror::Object* cur_obj = worklist_[start_];
+      start_++;
+
+      if (start_ >= kMaxStart) {  // Compact: drop the already-consumed prefix.
+        worklist_.erase(worklist_.begin(), worklist_.begin() + start_);
+        start_ = 0;
+      }
+
+      VisitObject(cur_obj);
+
+      if (stop_reports_) {  // A callback during the visit requested an abort.
+        break;
+      }
+    }
+  }
+
+ private:
+  class CollectAndReportRootsVisitor FINAL : public art::RootVisitor {
+   public:
+    CollectAndReportRootsVisitor(FollowReferencesHelper* helper,
+                                 ObjectTagTable* tag_table,
+                                 std::vector<art::mirror::Object*>* worklist,
+                                 std::unordered_set<art::mirror::Object*>* visited)
+        : helper_(helper),
+          tag_table_(tag_table),
+          worklist_(worklist),
+          visited_(visited),
+          stop_reports_(false) {}
+
+    void VisitRoots(art::mirror::Object*** roots, size_t count, const art::RootInfo& info)
+        OVERRIDE
+        REQUIRES_SHARED(art::Locks::mutator_lock_)
+        REQUIRES(!*helper_->tag_table_->GetAllowDisallowLock()) {
+      for (size_t i = 0; i != count; ++i) {
+        AddRoot(*roots[i], info);
+      }
+    }
+
+    void VisitRoots(art::mirror::CompressedReference<art::mirror::Object>** roots,
+                    size_t count,
+                    const art::RootInfo& info)
+        OVERRIDE REQUIRES_SHARED(art::Locks::mutator_lock_)
+        REQUIRES(!*helper_->tag_table_->GetAllowDisallowLock()) {
+      for (size_t i = 0; i != count; ++i) {
+        AddRoot(roots[i]->AsMirrorPtr(), info);
+      }
+    }
+
+    bool IsStopReports() {
+      return stop_reports_;
+    }
+
+   private:
+    void AddRoot(art::mirror::Object* root_obj, const art::RootInfo& info)
+        REQUIRES_SHARED(art::Locks::mutator_lock_)
+        REQUIRES(!*tag_table_->GetAllowDisallowLock()) {
+      // After an abort, skip remaining roots. visited_ doubles as the set of known roots.
+      if (stop_reports_) return;
+      if (visited_->insert(root_obj).second) {
+        worklist_->push_back(root_obj);
+      }
+      ReportRoot(root_obj, info);
+    }
+
+    jvmtiHeapReferenceKind GetReferenceKind(const art::RootInfo& info,
+                                            jvmtiHeapReferenceInfo* ref_info)
+        REQUIRES_SHARED(art::Locks::mutator_lock_) {
+      // TODO: Fill in ref_info.
+      memset(ref_info, 0, sizeof(jvmtiHeapReferenceInfo));
+
+      switch (info.GetType()) {
+        case art::RootType::kRootJNIGlobal:
+          return JVMTI_HEAP_REFERENCE_JNI_GLOBAL;
+
+        case art::RootType::kRootJNILocal:
+          return JVMTI_HEAP_REFERENCE_JNI_LOCAL;
+
+        case art::RootType::kRootJavaFrame:
+          return JVMTI_HEAP_REFERENCE_STACK_LOCAL;
+
+        case art::RootType::kRootNativeStack:
+        case art::RootType::kRootThreadBlock:
+        case art::RootType::kRootThreadObject:
+          return JVMTI_HEAP_REFERENCE_THREAD;
+
+        case art::RootType::kRootStickyClass:
+        case art::RootType::kRootInternedString:
+          // Note: this isn't a root in the RI.
+          return JVMTI_HEAP_REFERENCE_SYSTEM_CLASS;
+
+        case art::RootType::kRootMonitorUsed:
+        case art::RootType::kRootJNIMonitor:
+          return JVMTI_HEAP_REFERENCE_MONITOR;
+
+        case art::RootType::kRootFinalizing:
+        case art::RootType::kRootDebugger:
+        case art::RootType::kRootReferenceCleanup:
+        case art::RootType::kRootVMInternal:
+        case art::RootType::kRootUnknown:
+          return JVMTI_HEAP_REFERENCE_OTHER;
+      }
+      LOG(FATAL) << "Unreachable";
+      UNREACHABLE();
+    }
+
+    void ReportRoot(art::mirror::Object* root_obj, const art::RootInfo& info)
+        REQUIRES_SHARED(art::Locks::mutator_lock_)
+        REQUIRES(!*tag_table_->GetAllowDisallowLock()) {
+      jvmtiHeapReferenceInfo ref_info;
+      jvmtiHeapReferenceKind kind = GetReferenceKind(info, &ref_info);
+      jint result = helper_->ReportReference(kind, &ref_info, nullptr, root_obj);
+      if ((result & JVMTI_VISIT_ABORT) != 0) {
+        stop_reports_ = true;
+      }
+    }
+
+   private:
+    FollowReferencesHelper* helper_;
+    ObjectTagTable* tag_table_;
+    std::vector<art::mirror::Object*>* worklist_;
+    std::unordered_set<art::mirror::Object*>* visited_;
+    bool stop_reports_;
+  };
+
+  void VisitObject(art::mirror::Object* obj)
+      REQUIRES_SHARED(art::Locks::mutator_lock_)
+      REQUIRES(!*tag_table_->GetAllowDisallowLock()) {
+    if (obj->IsClass()) {
+      VisitClass(obj->AsClass());
+      return;
+    }
+    if (obj->IsArrayInstance()) {
+      VisitArray(obj);
+      return;
+    }
+
+    // TODO: We'll probably have to rewrite this completely with our own visiting logic, if we
+    //       want to have a chance of getting the field indices computed halfway efficiently. For
+    //       now, ignore them altogether.
+
+    struct InstanceReferenceVisitor {
+      explicit InstanceReferenceVisitor(FollowReferencesHelper* helper_)
+          : helper(helper_), stop_reports(false) {}
+
+      void operator()(art::mirror::Object* src,
+                      art::MemberOffset field_offset,
+                      bool is_static ATTRIBUTE_UNUSED) const
+          REQUIRES_SHARED(art::Locks::mutator_lock_)
+          REQUIRES(!*helper->tag_table_->GetAllowDisallowLock()) {
+        if (stop_reports) {
+          return;
+        }
+
+        art::mirror::Object* trg = src->GetFieldObjectReferenceAddr(field_offset)->AsMirrorPtr();
+        jvmtiHeapReferenceInfo reference_info;
+        memset(&reference_info, 0, sizeof(reference_info));
+
+        // TODO: Implement spec-compliant numbering.
+        reference_info.field.index = field_offset.Int32Value();
+
+        jvmtiHeapReferenceKind kind =
+            field_offset.Int32Value() == art::mirror::Object::ClassOffset().Int32Value()
+                ? JVMTI_HEAP_REFERENCE_CLASS
+                : JVMTI_HEAP_REFERENCE_FIELD;
+        const jvmtiHeapReferenceInfo* reference_info_ptr =
+            kind == JVMTI_HEAP_REFERENCE_CLASS ? nullptr : &reference_info;
+
+        stop_reports = !helper->ReportReferenceMaybeEnqueue(kind, reference_info_ptr, src, trg);
+      }
+
+      void VisitRoot(art::mirror::CompressedReference<art::mirror::Object>* root ATTRIBUTE_UNUSED)
+          const {
+        LOG(FATAL) << "Unreachable";
+      }
+      void VisitRootIfNonNull(
+          art::mirror::CompressedReference<art::mirror::Object>* root ATTRIBUTE_UNUSED) const {
+        LOG(FATAL) << "Unreachable";
+      }
+
+      // "mutable" required by the visitor API.
+      mutable FollowReferencesHelper* helper;
+      mutable bool stop_reports;
+    };
+
+    InstanceReferenceVisitor visitor(this);
+    // Visit references, not native roots.
+    obj->VisitReferences<false>(visitor, art::VoidFunctor());
+
+    stop_reports_ = visitor.stop_reports;
+  }
+
+  void VisitArray(art::mirror::Object* array)
+      REQUIRES_SHARED(art::Locks::mutator_lock_)
+      REQUIRES(!*tag_table_->GetAllowDisallowLock()) {
+    stop_reports_ = !ReportReferenceMaybeEnqueue(JVMTI_HEAP_REFERENCE_CLASS,
+                                                 nullptr,
+                                                 array,
+                                                 array->GetClass());
+    if (stop_reports_ || !array->IsObjectArray()) {
+      return;
+    }
+    art::mirror::ObjectArray<art::mirror::Object>* obj_array =
+        array->AsObjectArray<art::mirror::Object>();
+    // Zero the whole info union once, like the field-reporting code; only the index varies.
+    jvmtiHeapReferenceInfo reference_info;
+    memset(&reference_info, 0, sizeof(reference_info));
+    const int32_t length = obj_array->GetLength();
+    for (int32_t i = 0; i != length; ++i) {
+      art::mirror::Object* elem = obj_array->GetWithoutChecks(i);
+      if (elem == nullptr) {
+        continue;  // Null slots are not reported.
+      }
+      reference_info.array.index = i;
+      stop_reports_ = !ReportReferenceMaybeEnqueue(JVMTI_HEAP_REFERENCE_ARRAY_ELEMENT,
+                                                   &reference_info,
+                                                   array,
+                                                   elem);
+      if (stop_reports_) {
+        break;
+      }
+    }
+  }
+
+  void VisitClass(art::mirror::Class* klass)
+      REQUIRES_SHARED(art::Locks::mutator_lock_)
+      REQUIRES(!*tag_table_->GetAllowDisallowLock()) {
+    // TODO: Are erroneous classes reported? Are non-prepared ones? For now, just use resolved ones.
+    if (!klass->IsResolved()) {
+      return;
+    }
+
+    // Superclass.
+    stop_reports_ = !ReportReferenceMaybeEnqueue(JVMTI_HEAP_REFERENCE_SUPERCLASS,
+                                                 nullptr,
+                                                 klass,
+                                                 klass->GetSuperClass());
+    if (stop_reports_) {
+      return;
+    }
+
+    // Directly implemented or extended interfaces.
+    art::Thread* self = art::Thread::Current();
+    art::StackHandleScope<1> hs(self);
+    art::Handle<art::mirror::Class> h_klass(hs.NewHandle<art::mirror::Class>(klass));
+    for (size_t i = 0; i < h_klass->NumDirectInterfaces(); ++i) {
+      art::ObjPtr<art::mirror::Class> inf_klass =
+          art::mirror::Class::GetDirectInterface(self, h_klass, i);
+      if (inf_klass == nullptr) {
+        // TODO: With a resolved class this should not happen...
+        self->ClearException();
+        break;
+      }
+
+      stop_reports_ = !ReportReferenceMaybeEnqueue(JVMTI_HEAP_REFERENCE_INTERFACE,
+                                                   nullptr,
+                                                   klass,
+                                                   inf_klass.Ptr());
+      if (stop_reports_) {
+        return;
+      }
+    }
+
+    // Classloader.
+    // TODO: What about the boot classpath loader? We'll skip for now, but do we have to find the
+    //       fake BootClassLoader?
+    if (klass->GetClassLoader() != nullptr) {
+      stop_reports_ = !ReportReferenceMaybeEnqueue(JVMTI_HEAP_REFERENCE_CLASS_LOADER,
+                                                   nullptr,
+                                                   klass,
+                                                   klass->GetClassLoader());
+      if (stop_reports_) {
+        return;
+      }
+    }
+    DCHECK_EQ(h_klass.Get(), klass);
+
+    // Declared static fields.
+    for (auto& field : klass->GetSFields()) {
+      if (!field.IsPrimitiveType()) {
+        art::ObjPtr<art::mirror::Object> field_value = field.GetObject(klass);
+        if (field_value != nullptr) {
+          jvmtiHeapReferenceInfo reference_info;
+          memset(&reference_info, 0, sizeof(reference_info));
+
+          // TODO: Implement spec-compliant numbering.
+          reference_info.field.index = field.GetOffset().Int32Value();
+
+          stop_reports_ = !ReportReferenceMaybeEnqueue(JVMTI_HEAP_REFERENCE_STATIC_FIELD,
+                                                       &reference_info,
+                                                       klass,
+                                                       field_value.Ptr());
+          if (stop_reports_) {
+            return;
+          }
+        }
+      }
+    }
+  }
+
+  void MaybeEnqueue(art::mirror::Object* obj) REQUIRES_SHARED(art::Locks::mutator_lock_) {
+    // insert() reports whether obj was newly added, so each object is queued at most once.
+    if (visited_.insert(obj).second) {
+      worklist_.push_back(obj);
+    }
+  }
+
+  bool ReportReferenceMaybeEnqueue(jvmtiHeapReferenceKind kind,
+                                   const jvmtiHeapReferenceInfo* reference_info,
+                                   art::mirror::Object* referrer,
+                                   art::mirror::Object* referree)
+      REQUIRES_SHARED(art::Locks::mutator_lock_)
+      REQUIRES(!*tag_table_->GetAllowDisallowLock()) {
+    // Reports referrer -> referree; returns false iff the agent requested JVMTI_VISIT_ABORT.
+    const jint result = ReportReference(kind, reference_info, referrer, referree);
+    if ((result & JVMTI_VISIT_ABORT) != 0) {
+      return false;
+    }
+    if ((result & JVMTI_VISIT_OBJECTS) != 0) {
+      MaybeEnqueue(referree);
+    }
+    return true;
+  }
+
+  jint ReportReference(jvmtiHeapReferenceKind kind,
+                       const jvmtiHeapReferenceInfo* reference_info,
+                       art::mirror::Object* referrer,
+                       art::mirror::Object* referree)
+      REQUIRES_SHARED(art::Locks::mutator_lock_)
+      REQUIRES(!*tag_table_->GetAllowDisallowLock()) {
+    // Invokes the agent's heap reference callback for the edge referrer -> referree (referrer is
+    // null for roots) and returns the callback's JVMTI_VISIT_* flags. Returns 0 without calling
+    // the agent for null targets or after an abort.
+    if (referree == nullptr || stop_reports_) {
+      return 0;
+    }
+    // The callback slot may legitimately be null; then nothing is reported but we keep walking.
+    if (callbacks_->heap_reference_callback == nullptr) {
+      return JVMTI_VISIT_OBJECTS;
+    }
+
+    const jlong class_tag = tag_table_->GetTagOrZero(referree->GetClass());
+    const jlong referrer_class_tag =
+        referrer == nullptr ? 0 : tag_table_->GetTagOrZero(referrer->GetClass());
+    const jlong size = static_cast<jlong>(referree->SizeOf());
+    jlong tag = tag_table_->GetTagOrZero(referree);
+    const jlong saved_tag = tag;
+    jlong referrer_tag = 0;
+    jlong saved_referrer_tag = 0;
+    // Roots pass no referrer tag; a self-reference shares the referree's tag slot so that the
+    // agent sees (and updates) a single value.
+    jlong* referrer_tag_ptr = nullptr;
+    if (referrer == referree) {
+      referrer_tag_ptr = &tag;
+    } else if (referrer != nullptr) {
+      referrer_tag = saved_referrer_tag = tag_table_->GetTagOrZero(referrer);
+      referrer_tag_ptr = &referrer_tag;
+    }
+    // Non-array objects are reported with length -1.
+    jint length = -1;
+    if (referree->IsArrayInstance()) {
+      length = referree->AsArray()->GetLength();
+    }
+
+    const jint result = callbacks_->heap_reference_callback(
+        kind, reference_info, class_tag, referrer_class_tag, size,
+        &tag, referrer_tag_ptr, length, const_cast<void*>(user_data_));
+
+    // Write back any tags the agent changed through the pointers.
+    if (tag != saved_tag) {
+      tag_table_->Set(referree, tag);
+    }
+    if (referrer_tag != saved_referrer_tag) {
+      tag_table_->Set(referrer, referrer_tag);
+    }
+    return result;
+  }
+
+  ObjectTagTable* tag_table_;
+  const jvmtiHeapCallbacks* callbacks_;
+  const void* user_data_;
+
+  std::vector<art::mirror::Object*> worklist_;
+  size_t start_;
+  static constexpr size_t kMaxStart = 1000000U;
+
+  std::unordered_set<art::mirror::Object*> visited_;
+
+  bool stop_reports_;
+
+  friend class CollectAndReportRootsVisitor;
+};
+
+jvmtiError HeapUtil::FollowReferences(jvmtiEnv* env ATTRIBUTE_UNUSED,
+                                      jint heap_filter ATTRIBUTE_UNUSED,
+                                      jclass klass ATTRIBUTE_UNUSED,
+                                      jobject initial_object,
+                                      const jvmtiHeapCallbacks* callbacks,
+                                      const void* user_data) {
+  if (callbacks == nullptr) {
+    return ERR(NULL_POINTER);
+  }
+  // Primitive array reporting is not supported yet; reject rather than silently skip it.
+  if (callbacks->array_primitive_value_callback != nullptr) {
+    // TODO: Implement.
+    return ERR(NOT_IMPLEMENTED);
+  }
+
+  art::Thread* self = art::Thread::Current();
+  art::ScopedObjectAccess soa(self);      // Now we know we have the shared lock.
+  // NOTE(review): presumably needed because the helper stores raw object pointers - confirm.
+  art::Runtime::Current()->GetHeap()->IncrementDisableMovingGC(self);
+  {
+    art::ObjPtr<art::mirror::Object> o_initial = soa.Decode<art::mirror::Object>(initial_object);
+    // Note: heap_filter and klass are ignored, and the helper does not use o_initial yet.
+    art::ScopedThreadSuspension sts(self, art::kWaitingForVisitObjects);
+    art::ScopedSuspendAll ssa("FollowReferences");
+    // Report the roots first, then follow references breadth-first from them.
+    FollowReferencesHelper frh(this, o_initial, callbacks, user_data);
+    frh.Init();
+    frh.Work();
+  }
+  art::Runtime::Current()->GetHeap()->DecrementDisableMovingGC(self);
+
+  return ERR(NONE);
+}
+
 jvmtiError HeapUtil::GetLoadedClasses(jvmtiEnv* env,
                                       jint* class_count_ptr,
                                       jclass** classes_ptr) {
@@ -215,5 +681,4 @@
 
   return ERR(NONE);
 }
-
 }  // namespace openjdkjvmti
diff --git a/runtime/openjdkjvmti/ti_heap.h b/runtime/openjdkjvmti/ti_heap.h
index 570dd0c..72ee097 100644
--- a/runtime/openjdkjvmti/ti_heap.h
+++ b/runtime/openjdkjvmti/ti_heap.h
@@ -36,6 +36,13 @@
                                 const jvmtiHeapCallbacks* callbacks,
                                 const void* user_data);
 
+  jvmtiError FollowReferences(jvmtiEnv* env,
+                              jint heap_filter,
+                              jclass klass,
+                              jobject initial_object,
+                              const jvmtiHeapCallbacks* callbacks,
+                              const void* user_data);
+
   static jvmtiError ForceGarbageCollection(jvmtiEnv* env);
 
   ObjectTagTable* GetTags() {
diff --git a/test/913-heaps/expected.txt b/test/913-heaps/expected.txt
index 77791a4..dc6e67d 100644
--- a/test/913-heaps/expected.txt
+++ b/test/913-heaps/expected.txt
@@ -1,2 +1,98 @@
 ---
 true true
+root@root --(stack-local)--> 1@1000 [size=16, length=-1]
+root@root --(stack-local)--> 3000@0 [size=132, length=-1]
+root@root --(thread)--> 3000@0 [size=132, length=-1]
+1@1000 --(class)--> 1000@0 [size=123, length=-1]
+1@1000 --(field@8)--> 2@1000 [size=16, length=-1]
+1@1000 --(field@12)--> 3@1001 [size=24, length=-1]
+0@0 --(array-element@0)--> 1@1000 [size=16, length=-1]
+2@1000 --(class)--> 1000@0 [size=123, length=-1]
+3@1001 --(class)--> 1001@0 [size=123, length=-1]
+3@1001 --(field@16)--> 4@1000 [size=16, length=-1]
+3@1001 --(field@20)--> 5@1002 [size=32, length=-1]
+1001@0 --(superclass)--> 1000@0 [size=123, length=-1]
+4@1000 --(class)--> 1000@0 [size=123, length=-1]
+5@1002 --(class)--> 1002@0 [size=123, length=-1]
+5@1002 --(field@24)--> 6@1000 [size=16, length=-1]
+5@1002 --(field@28)--> 1@1000 [size=16, length=-1]
+1002@0 --(superclass)--> 1001@0 [size=123, length=-1]
+1002@0 --(interface)--> 2001@0 [size=132, length=-1]
+6@1000 --(class)--> 1000@0 [size=123, length=-1]
+2001@0 --(interface)--> 2000@0 [size=132, length=-1]
+---
+root@root --(stack-local)--> 1@1000 [size=16, length=-1]
+root@root --(stack-local)--> 1@1000 [size=16, length=-1]
+root@root --(stack-local)--> 1@1000 [size=16, length=-1]
+root@root --(stack-local)--> 2@1000 [size=16, length=-1]
+root@root --(stack-local)--> 3000@0 [size=132, length=-1]
+root@root --(thread)--> 2@1000 [size=16, length=-1]
+root@root --(thread)--> 3000@0 [size=132, length=-1]
+2@1000 --(class)--> 1000@0 [size=123, length=-1]
+1@1000 --(class)--> 1000@0 [size=123, length=-1]
+1@1000 --(field@8)--> 2@1000 [size=16, length=-1]
+1@1000 --(field@12)--> 3@1001 [size=24, length=-1]
+3@1001 --(class)--> 1001@0 [size=123, length=-1]
+3@1001 --(field@16)--> 4@1000 [size=16, length=-1]
+3@1001 --(field@20)--> 5@1002 [size=32, length=-1]
+0@0 --(array-element@0)--> 1@1000 [size=16, length=-1]
+1001@0 --(superclass)--> 1000@0 [size=123, length=-1]
+4@1000 --(class)--> 1000@0 [size=123, length=-1]
+5@1002 --(class)--> 1002@0 [size=123, length=-1]
+5@1002 --(field@24)--> 6@1000 [size=16, length=-1]
+5@1002 --(field@28)--> 1@1000 [size=16, length=-1]
+1002@0 --(superclass)--> 1001@0 [size=123, length=-1]
+1002@0 --(interface)--> 2001@0 [size=132, length=-1]
+6@1000 --(class)--> 1000@0 [size=123, length=-1]
+2001@0 --(interface)--> 2000@0 [size=132, length=-1]
+---
+root@root --(jni-global)--> 1@1000 [size=16, length=-1]
+root@root --(jni-local)--> 1@1000 [size=16, length=-1]
+root@root --(stack-local)--> 1@1000 [size=16, length=-1]
+root@root --(stack-local)--> 1@1000 [size=16, length=-1]
+root@root --(thread)--> 1@1000 [size=16, length=-1]
+root@root --(thread)--> 3000@0 [size=132, length=-1]
+1@1000 --(class)--> 1000@0 [size=123, length=-1]
+1@1000 --(field@8)--> 2@1000 [size=16, length=-1]
+1@1000 --(field@12)--> 3@1001 [size=24, length=-1]
+2@1000 --(class)--> 1000@0 [size=123, length=-1]
+3@1001 --(class)--> 1001@0 [size=123, length=-1]
+3@1001 --(field@16)--> 4@1000 [size=16, length=-1]
+3@1001 --(field@20)--> 5@1002 [size=32, length=-1]
+1001@0 --(superclass)--> 1000@0 [size=123, length=-1]
+4@1000 --(class)--> 1000@0 [size=123, length=-1]
+5@1002 --(class)--> 1002@0 [size=123, length=-1]
+5@1002 --(field@24)--> 6@1000 [size=16, length=-1]
+5@1002 --(field@28)--> 1@1000 [size=16, length=-1]
+1002@0 --(superclass)--> 1001@0 [size=123, length=-1]
+1002@0 --(interface)--> 2001@0 [size=132, length=-1]
+6@1000 --(class)--> 1000@0 [size=123, length=-1]
+2001@0 --(interface)--> 2000@0 [size=132, length=-1]
+---
+root@root --(jni-global)--> 1@1000 [size=16, length=-1]
+root@root --(jni-local)--> 1@1000 [size=16, length=-1]
+root@root --(stack-local)--> 1@1000 [size=16, length=-1]
+root@root --(stack-local)--> 1@1000 [size=16, length=-1]
+root@root --(stack-local)--> 1@1000 [size=16, length=-1]
+root@root --(stack-local)--> 1@1000 [size=16, length=-1]
+root@root --(stack-local)--> 2@1000 [size=16, length=-1]
+root@root --(thread)--> 1@1000 [size=16, length=-1]
+root@root --(thread)--> 2@1000 [size=16, length=-1]
+root@root --(thread)--> 3000@0 [size=132, length=-1]
+1@1000 --(class)--> 1000@0 [size=123, length=-1]
+1@1000 --(field@8)--> 2@1000 [size=16, length=-1]
+1@1000 --(field@12)--> 3@1001 [size=24, length=-1]
+2@1000 --(class)--> 1000@0 [size=123, length=-1]
+3@1001 --(class)--> 1001@0 [size=123, length=-1]
+3@1001 --(field@16)--> 4@1000 [size=16, length=-1]
+3@1001 --(field@20)--> 5@1002 [size=32, length=-1]
+1001@0 --(superclass)--> 1000@0 [size=123, length=-1]
+4@1000 --(class)--> 1000@0 [size=123, length=-1]
+5@1002 --(class)--> 1002@0 [size=123, length=-1]
+5@1002 --(field@24)--> 6@1000 [size=16, length=-1]
+5@1002 --(field@28)--> 1@1000 [size=16, length=-1]
+1002@0 --(superclass)--> 1001@0 [size=123, length=-1]
+1002@0 --(interface)--> 2001@0 [size=132, length=-1]
+6@1000 --(class)--> 1000@0 [size=123, length=-1]
+2001@0 --(interface)--> 2000@0 [size=132, length=-1]
+---
diff --git a/test/913-heaps/heaps.cc b/test/913-heaps/heaps.cc
index 437779a..d74026c 100644
--- a/test/913-heaps/heaps.cc
+++ b/test/913-heaps/heaps.cc
@@ -16,13 +16,18 @@
 
 #include "heaps.h"
 
+#include <inttypes.h>
 #include <stdio.h>
 #include <string.h>
 
+#include <vector>
+
+#include "base/logging.h"
 #include "base/macros.h"
+#include "base/stringprintf.h"
 #include "jni.h"
 #include "openjdkjvmti/jvmti.h"
-
+#include "ti-agent/common_helper.h"
 #include "ti-agent/common_load.h"
 
 namespace art {
@@ -38,6 +43,230 @@
   }
 }
 
+class IterationConfig {
+ public:
+  IterationConfig() {}
+  virtual ~IterationConfig() {}
+
+  virtual jint Handle(jvmtiHeapReferenceKind reference_kind,
+                      const jvmtiHeapReferenceInfo* reference_info,
+                      jlong class_tag,
+                      jlong referrer_class_tag,
+                      jlong size,
+                      jlong* tag_ptr,
+                      jlong* referrer_tag_ptr,
+                      jint length,
+                      void* user_data) = 0;
+};
+
+static jint JNICALL HeapReferenceCallback(jvmtiHeapReferenceKind reference_kind,
+                                          const jvmtiHeapReferenceInfo* reference_info,
+                                          jlong class_tag,
+                                          jlong referrer_class_tag,
+                                          jlong size,
+                                          jlong* tag_ptr,
+                                          jlong* referrer_tag_ptr,
+                                          jint length,
+                                          void* user_data) {
+  IterationConfig* config = reinterpret_cast<IterationConfig*>(user_data);
+  return config->Handle(reference_kind,
+                        reference_info,
+                        class_tag,
+                        referrer_class_tag,
+                        size,
+                        tag_ptr,
+                        referrer_tag_ptr,
+                        length,
+                        user_data);
+}
+
+static bool Run(jint heap_filter,
+                jclass klass_filter,
+                jobject initial_object,
+                IterationConfig* config) {
+  // Runs FollowReferences with config as user data. Prints the error name on failure.
+  jvmtiHeapCallbacks callbacks;
+  memset(&callbacks, 0, sizeof(jvmtiHeapCallbacks));
+  callbacks.heap_reference_callback = HeapReferenceCallback;
+
+  jvmtiError ret =
+      jvmti_env->FollowReferences(heap_filter, klass_filter, initial_object, &callbacks, config);
+  if (ret == JVMTI_ERROR_NONE) {
+    return true;
+  }
+  // The name buffer is allocated by JVMTI and must be released; Deallocate ignores null.
+  char* err = nullptr;
+  jvmti_env->GetErrorName(ret, &err);
+  printf("Failure running FollowReferences: %s\n", err != nullptr ? err : "<unknown>");
+  jvmti_env->Deallocate(reinterpret_cast<unsigned char*>(err));
+  return false;
+}
+
+extern "C" JNIEXPORT jobjectArray JNICALL Java_Main_followReferences(JNIEnv* env,
+                                                                     jclass klass ATTRIBUTE_UNUSED,
+                                                                     jint heap_filter,
+                                                                     jclass klass_filter,
+                                                                     jobject initial_object,
+                                                                     jint stop_after,
+                                                                     jint follow_set,
+                                                                     jobject jniRef) {
+  class PrintIterationConfig FINAL : public IterationConfig {
+   public:
+    PrintIterationConfig(jint _stop_after, jint _follow_set)
+        : counter_(0),
+          stop_after_(_stop_after),
+          follow_set_(_follow_set) {
+    }
+
+    jint Handle(jvmtiHeapReferenceKind reference_kind,
+                const jvmtiHeapReferenceInfo* reference_info,
+                jlong class_tag,
+                jlong referrer_class_tag,
+                jlong size,
+                jlong* tag_ptr,
+                jlong* referrer_tag_ptr,
+                jint length,
+                void* user_data ATTRIBUTE_UNUSED) OVERRIDE {
+      jlong tag = *tag_ptr;
+      // Only check tagged objects.
+      if (tag == 0) {
+        return JVMTI_VISIT_OBJECTS;
+      }
+
+      Print(reference_kind,
+            reference_info,
+            class_tag,
+            referrer_class_tag,
+            size,
+            tag_ptr,
+            referrer_tag_ptr,
+            length);
+
+      counter_++;
+      if (counter_ == stop_after_) {
+        return JVMTI_VISIT_ABORT;
+      }
+
+      if (tag > 0 && tag < 32) {
+        bool should_visit_references = (follow_set_ & (1 << static_cast<int32_t>(tag))) != 0;
+        return should_visit_references ? JVMTI_VISIT_OBJECTS : 0;
+      }
+
+      return JVMTI_VISIT_OBJECTS;
+    }
+
+    void Print(jvmtiHeapReferenceKind reference_kind,
+               const jvmtiHeapReferenceInfo* reference_info,
+               jlong class_tag,
+               jlong referrer_class_tag,
+               jlong size,
+               jlong* tag_ptr,
+               jlong* referrer_tag_ptr,
+               jint length) {
+      std::string referrer_str;
+      if (referrer_tag_ptr == nullptr) {
+        referrer_str = "root@root";
+      } else {
+        referrer_str = StringPrintf("%" PRId64 "@%" PRId64, *referrer_tag_ptr, referrer_class_tag);
+      }
+
+      jlong adapted_size = size;
+      if (*tag_ptr >= 1000) {
+        // This is a class or interface, the size of which will be dependent on the architecture.
+        // Do not print the size, but detect known values and "normalize" for the golden file.
+        if ((sizeof(void*) == 4 && size == 180) || (sizeof(void*) == 8 && size == 232)) {
+          adapted_size = 123;
+        }
+      }
+
+      lines_.push_back(
+          StringPrintf("%s --(%s)--> %" PRId64 "@%" PRId64 " [size=%" PRId64 ", length=%d]",
+                       referrer_str.c_str(),
+                       GetReferenceTypeStr(reference_kind, reference_info).c_str(),
+                       *tag_ptr,
+                       class_tag,
+                       adapted_size,
+                       length));
+    }
+
+    static std::string GetReferenceTypeStr(jvmtiHeapReferenceKind reference_kind,
+                                           const jvmtiHeapReferenceInfo* reference_info) {
+      switch (reference_kind) {
+        case JVMTI_HEAP_REFERENCE_CLASS:
+          return "class";
+        case JVMTI_HEAP_REFERENCE_FIELD:
+          return StringPrintf("field@%d", reference_info->field.index);
+        case JVMTI_HEAP_REFERENCE_ARRAY_ELEMENT:
+          return StringPrintf("array-element@%d", reference_info->array.index);
+        case JVMTI_HEAP_REFERENCE_CLASS_LOADER:
+          return "classloader";
+        case JVMTI_HEAP_REFERENCE_SIGNERS:
+          return "signers";
+        case JVMTI_HEAP_REFERENCE_PROTECTION_DOMAIN:
+          return "protection-domain";
+        case JVMTI_HEAP_REFERENCE_INTERFACE:
+          return "interface";
+        case JVMTI_HEAP_REFERENCE_STATIC_FIELD:
+          return StringPrintf("static-field@%d", reference_info->field.index);
+        case JVMTI_HEAP_REFERENCE_CONSTANT_POOL:
+          return "constant-pool";
+        case JVMTI_HEAP_REFERENCE_SUPERCLASS:
+          return "superclass";
+        case JVMTI_HEAP_REFERENCE_JNI_GLOBAL:
+          return "jni-global";
+        case JVMTI_HEAP_REFERENCE_SYSTEM_CLASS:
+          return "system-class";
+        case JVMTI_HEAP_REFERENCE_MONITOR:
+          return "monitor";
+        case JVMTI_HEAP_REFERENCE_STACK_LOCAL:
+          return "stack-local";
+        case JVMTI_HEAP_REFERENCE_JNI_LOCAL:
+          return "jni-local";
+        case JVMTI_HEAP_REFERENCE_THREAD:
+          return "thread";
+        case JVMTI_HEAP_REFERENCE_OTHER:
+          return "other";
+      }
+      return "unknown";
+    }
+
+    const std::vector<std::string>& GetLines() const {
+      return lines_;
+    }
+
+   private:
+    jint counter_;
+    const jint stop_after_;
+    const jint follow_set_;
+    std::vector<std::string> lines_;
+  };
+
+  // If jniRef isn't null, add a local and a global ref.
+  ScopedLocalRef<jobject> jni_local_ref(env, nullptr);
+  jobject jni_global_ref = nullptr;
+  if (jniRef != nullptr) {
+    jni_local_ref.reset(env->NewLocalRef(jniRef));
+    jni_global_ref = env->NewGlobalRef(jniRef);
+  }
+
+  PrintIterationConfig config(stop_after, follow_set);
+  Run(heap_filter, klass_filter, initial_object, &config);
+
+  const std::vector<std::string>& lines = config.GetLines();
+  jobjectArray ret = CreateObjectArray(env,
+                                       static_cast<jint>(lines.size()),
+                                       "java/lang/String",
+                                       [&](jint i) {
+                                         return env->NewStringUTF(lines[i].c_str());
+                                       });
+
+  if (jni_global_ref != nullptr) {
+    env->DeleteGlobalRef(jni_global_ref);
+  }
+
+  return ret;
+}
+
 // Don't do anything
 jint OnLoad(JavaVM* vm,
             char* options ATTRIBUTE_UNUSED,
diff --git a/test/913-heaps/src/Main.java b/test/913-heaps/src/Main.java
index 4d77a48..f463429 100644
--- a/test/913-heaps/src/Main.java
+++ b/test/913-heaps/src/Main.java
@@ -15,12 +15,14 @@
  */
 
 import java.util.ArrayList;
+import java.util.Collections;
 
 public class Main {
   public static void main(String[] args) throws Exception {
     System.loadLibrary(args[1]);
 
     doTest();
+    doFollowReferencesTest();
   }
 
   public static void doTest() throws Exception {
@@ -43,10 +45,161 @@
   }
 
   private static void printStats() {
-      System.out.println("---");
-      int s = getGcStarts();
-      int f = getGcFinishes();
-      System.out.println((s > 0) + " " + (f > 0));
+    System.out.println("---");  // Separator line in the test output.
+    int s = getGcStarts();
+    int f = getGcFinishes();
+    System.out.println((s > 0) + " " + (f > 0));  // Only whether each count is positive.
+  }
+
+  public static void doFollowReferencesTest() throws Exception {
+    // Request two GCs so leftover garbage is cleaned up before the heap is walked.
+    Runtime.getRuntime().gc();
+    Runtime.getRuntime().gc();
+
+    tagClasses();
+    setTag(Thread.currentThread(), 3000);  // The current thread gets tag 3000.
+
+    {
+      ArrayList<Object> tmpStorage = new ArrayList<>();
+      doFollowReferencesTestNonRoot(tmpStorage);  // Scenario 1: jniRef is null.
+      tmpStorage = null;
+    }
+
+    // Request GCs again between the two scenarios.
+    Runtime.getRuntime().gc();
+    Runtime.getRuntime().gc();
+
+    doFollowReferencesTestRoot();  // Scenario 2: the tree is also passed as jniRef.
+
+    // Request GCs once more after the second scenario.
+    Runtime.getRuntime().gc();
+    Runtime.getRuntime().gc();
+  }
+
+  private static void doFollowReferencesTestNonRoot(ArrayList<Object> tmpStorage) {
+    A a = createTree();
+    tmpStorage.add(a);  // The tree is now also referenced from the caller's list.
+    doFollowReferencesTestImpl(null, Integer.MAX_VALUE, -1, null);  // Null initial object.
+    doFollowReferencesTestImpl(a, Integer.MAX_VALUE, -1, null);  // Initial object: a.foo.
+    tmpStorage.clear();
+  }
+
+  private static void doFollowReferencesTestRoot() {
+    A a = createTree();  // Passed below as asRoot (the native jniRef argument).
+    doFollowReferencesTestImpl(null, Integer.MAX_VALUE, -1, a);  // Null initial object.
+    doFollowReferencesTestImpl(a, Integer.MAX_VALUE, -1, a);  // Initial object: a.foo.
+  }
+
+  private static void doFollowReferencesTestImpl(A root, int stopAfter, int followSet,
+      Object asRoot) {  // asRoot is forwarded as followReferences' jniRef argument.
+    String[] lines =  // Initial object is root.foo, or null when root is null.
+        followReferences(0, null, root == null ? null : root.foo, stopAfter, followSet, asRoot);
+    // Note: sort the roots, as stack locals visit order isn't defined, so may depend on compiled
+    //       code. Do not sort non-roots, as the order here needs to be verified (elements are
+    //       finished before a reference is followed). The test setup (and root visit order)
+    //       luckily ensures that this is deterministic.
+
+    int i = 0;
+    ArrayList<String> rootLines = new ArrayList<>();
+    while (i < lines.length) {
+      if (lines[i].startsWith("root")) {
+        rootLines.add(lines[i]);
+      } else {
+        break;  // First non-root line ends the root prefix; i stays on it.
+      }
+      i++;
+    }
+    Collections.sort(rootLines);
+    for (String l : rootLines) {
+      System.out.println(l);
+    }
+
+    // Print every remaining line (from the first non-root line on) in the reported order.
+    while (i < lines.length) {
+      System.out.println(lines[i]);
+      i++;
+    }
+
+    System.out.println("---");
+
+    // TODO: Test filters.
+  }
+
+  private static void tagClasses() {  // Classes get tags 1000-1002, interfaces 2000-2001.
+    setTag(A.class, 1000);
+    setTag(B.class, 1001);
+    setTag(C.class, 1002);
+    setTag(I1.class, 2000);
+    setTag(I2.class, 2001);
+  }
+
+  private static A createTree() {  // Builds the tagged test graph (tags 1-6); returns its root.
+    A root = new A();
+    setTag(root, 1);
+
+    A foo = new A();
+    setTag(foo, 2);
+    root.foo = foo;  // 1 -> 2
+
+    B foo2 = new B();
+    setTag(foo2, 3);
+    root.foo2 = foo2;  // 1 -> 3
+
+    A bar = new A();
+    setTag(bar, 4);
+    foo2.bar = bar;  // 3 -> 4
+
+    C bar2 = new C();
+    setTag(bar2, 5);
+    foo2.bar2 = bar2;  // 3 -> 5
+
+    A baz = new A();
+    setTag(baz, 6);
+    bar2.baz = baz;  // 5 -> 6
+    bar2.baz2 = root;  // 5 -> 1: back edge, so the "tree" actually contains a cycle.
+
+    return root;
+  }
+
+  public static class A {  // Base node type of the test graph: two reference fields.
+    public A foo;
+    public A foo2;
+
+    public A() {}  // Leaves both fields null.
+    public A(A a, A b) {  // Sets foo and foo2 from the arguments.
+      foo = a;
+      foo2 = b;
+    }
+  }
+
+  public static class B extends A {  // Adds bar/bar2 on top of A's foo/foo2.
+    public A bar;
+    public A bar2;
+
+    public B() {}  // Leaves all fields null.
+    public B(A a, A b) {  // Sets only bar and bar2; the inherited fields stay null.
+      bar = a;
+      bar2 = b;
+    }
+  }
+
+  public static interface I1 {  // Interface with a single constant; tagged 2000 in tagClasses().
+    public final static int i1Field = 1;
+  }
+
+  public static interface I2 extends I1 {  // Sub-interface of I1; tagged 2001 in tagClasses().
+    public final static int i2Field = 2;
+  }
+
+  public static class C extends B implements I2 {  // Adds baz/baz2; implements I2 (hence I1).
+    public A baz;
+    public A baz2;
+
+    public C() {}  // Leaves all fields null.
+    public C(A a, A b) {  // Sets only baz and baz2; the inherited fields stay null.
+      baz = a;
+      baz2 = b;
+    }
   }
 
   private static native void setupGcCallback();
@@ -54,4 +207,10 @@
   private static native int getGcStarts();
   private static native int getGcFinishes();
   private static native void forceGarbageCollection();
+
+  private static native void setTag(Object o, long tag);
+  private static native long getTag(Object o);
+
+  private static native String[] followReferences(int heapFilter, Class<?> klassFilter,
+      Object initialObject, int stopAfter, int followSet, Object jniRef);
 }