ART: Add IterateThroughHeapExt

Add another heap extension. This is the same as IterateThroughHeap,
but passes the heap id of the current object to the callback as an
additional trailing parameter.

Bug: 37283268
Test: m test-art-host
Change-Id: I93c26ff3afe4205f00f2e4ed871384f862886746
diff --git a/runtime/openjdkjvmti/OpenjdkJvmTi.cc b/runtime/openjdkjvmti/OpenjdkJvmTi.cc
index 0921cea..9be486e 100644
--- a/runtime/openjdkjvmti/OpenjdkJvmTi.cc
+++ b/runtime/openjdkjvmti/OpenjdkJvmTi.cc
@@ -1205,6 +1205,30 @@
       return error;
     }
 
+    error = add_extension(
+        reinterpret_cast<jvmtiExtensionFunction>(HeapExtensions::IterateThroughHeapExt),
+        "com.android.art.heap.iterate_through_heap_ext",
+        "Iterate through a heap. This is equivalent to the standard IterateThroughHeap function,"
+        " except for additionally passing the heap id of the current object. The jvmtiHeapCallbacks"
+        " structure is reused, with the heap_iteration_callback field overloaded to a signature"
+        " of jint (*)(jlong, jlong, jlong*, jint length, void*, jint).",
+        4,  // Number of parameter descriptors in the list below.
+        {                                                          // NOLINT [whitespace/braces] [4]
+            { "heap_filter", JVMTI_KIND_IN, JVMTI_TYPE_JINT, false},
+            { "klass", JVMTI_KIND_IN, JVMTI_TYPE_JCLASS, true},
+            { "callbacks", JVMTI_KIND_IN_PTR, JVMTI_TYPE_CVOID, false},
+            { "user_data", JVMTI_KIND_IN_PTR, JVMTI_TYPE_CVOID, true}
+        },
+        3,  // Number of error codes in the list below.
+        {                                                          // NOLINT [whitespace/braces] [4]
+            JVMTI_ERROR_MUST_POSSESS_CAPABILITY,
+            JVMTI_ERROR_INVALID_CLASS,
+            JVMTI_ERROR_NULL_POINTER
+        });
+    if (error != ERR(NONE)) {
+      return error;
+    }
+
     // Copy into output buffer.
 
     *extension_count_ptr = ext_vector.size();
diff --git a/runtime/openjdkjvmti/ti_heap.cc b/runtime/openjdkjvmti/ti_heap.cc
index fc63ae5..99774c6 100644
--- a/runtime/openjdkjvmti/ti_heap.cc
+++ b/runtime/openjdkjvmti/ti_heap.cc
@@ -1433,6 +1433,33 @@
 static constexpr jint kHeapIdZygote = 2;
 static constexpr jint kHeapIdApp = 3;
 
+// Returns the kHeapId* constant for the space holding obj, or -1 if obj is null.
+static jint GetHeapId(art::ObjPtr<art::mirror::Object> obj)
+    REQUIRES_SHARED(art::Locks::mutator_lock_) {
+  if (obj == nullptr) {
+    return -1;  // Sentinel: there is no object to classify.
+  }
+  art::gc::Heap* const heap = art::Runtime::Current()->GetHeap();
+  const art::gc::space::ContinuousSpace* const space =
+      heap->FindContinuousSpaceFromObject(obj, true);
+  jint heap_type = kHeapIdApp;  // Default when none of the checks below match.
+  if (space != nullptr) {
+    if (space->IsZygoteSpace()) {
+      heap_type = kHeapIdZygote;
+    } else if (space->IsImageSpace() && heap->ObjectIsInBootImageSpace(obj)) {
+      // Only count objects in the boot image as HPROF_HEAP_IMAGE, this leaves app image objects
+      // as HPROF_HEAP_APP. b/35762934
+      heap_type = kHeapIdImage;
+    }
+  } else {  // No continuous space was found for obj; consult the large-object space.
+    const auto* los = heap->GetLargeObjectsSpace();
+    if (los->Contains(obj.Ptr()) && los->IsZygoteLargeObject(art::Thread::Current(), obj.Ptr())) {
+      heap_type = kHeapIdZygote;
+    }
+  }
+  return heap_type;
+}
+
 jvmtiError HeapExtensions::GetObjectHeapId(jvmtiEnv* env, jlong tag, jint* heap_id, ...) {
   if (heap_id == nullptr) {
     return ERR(NULL_POINTER);
@@ -1443,28 +1470,10 @@
   auto work = [&]() REQUIRES_SHARED(art::Locks::mutator_lock_) {
     ObjectTagTable* tag_table = ArtJvmTiEnv::AsArtJvmTiEnv(env)->object_tag_table.get();
     art::ObjPtr<art::mirror::Object> obj = tag_table->Find(tag);
-    if (obj == nullptr) {
+    jint heap_type = GetHeapId(obj);
+    if (heap_type == -1) {
       return ERR(NOT_FOUND);
     }
-
-    art::gc::Heap* const heap = art::Runtime::Current()->GetHeap();
-    const art::gc::space::ContinuousSpace* const space =
-        heap->FindContinuousSpaceFromObject(obj, true);
-    jint heap_type = kHeapIdApp;
-    if (space != nullptr) {
-      if (space->IsZygoteSpace()) {
-        heap_type = kHeapIdZygote;
-      } else if (space->IsImageSpace() && heap->ObjectIsInBootImageSpace(obj)) {
-        // Only count objects in the boot image as HPROF_HEAP_IMAGE, this leaves app image objects
-        // as HPROF_HEAP_APP. b/35762934
-        heap_type = kHeapIdImage;
-      }
-    } else {
-      const auto* los = heap->GetLargeObjectsSpace();
-      if (los->Contains(obj.Ptr()) && los->IsZygoteLargeObject(self, obj.Ptr())) {
-        heap_type = kHeapIdZygote;
-      }
-    }
     *heap_id = heap_type;
     return ERR(NONE);
   };
@@ -1518,4 +1527,36 @@
   }
 }
 
+// ART extension of IterateThroughHeap: the heap iteration callback also receives the heap id.
+jvmtiError HeapExtensions::IterateThroughHeapExt(jvmtiEnv* env,
+                                                 jint heap_filter,
+                                                 jclass klass,
+                                                 const jvmtiHeapCallbacks* callbacks,
+                                                 const void* user_data) {
+  // Refuse to run unless the environment holds the can_tag_objects capability.
+  if (ArtJvmTiEnv::AsArtJvmTiEnv(env)->capabilities.can_tag_objects != 1) {
+    return ERR(MUST_POSSESS_CAPABILITY);
+  }
+  // Call heap_iteration_callback through the extended signature, appending the heap id.
+  auto ArtIterateHeap = [](art::mirror::Object* obj,
+                           const jvmtiHeapCallbacks* cb_callbacks,
+                           jlong class_tag,
+                           jlong size,
+                           jlong* tag,
+                           jint length,
+                           void* cb_user_data)
+      REQUIRES_SHARED(art::Locks::mutator_lock_) {
+    jint heap_id = GetHeapId(obj);
+    using ArtExtensionAPI = jint (*)(jlong, jlong, jlong*, jint length, void*, jint);
+    return reinterpret_cast<ArtExtensionAPI>(cb_callbacks->heap_iteration_callback)(
+        class_tag, size, tag, length, cb_user_data, heap_id);
+  };
+  return DoIterateThroughHeap(ArtIterateHeap,
+                              env,
+                              ArtJvmTiEnv::AsArtJvmTiEnv(env)->object_tag_table.get(),
+                              heap_filter,
+                              klass,
+                              callbacks,
+                              user_data);
+}
+
+
 }  // namespace openjdkjvmti
diff --git a/runtime/openjdkjvmti/ti_heap.h b/runtime/openjdkjvmti/ti_heap.h
index b4b71ba..0c973db 100644
--- a/runtime/openjdkjvmti/ti_heap.h
+++ b/runtime/openjdkjvmti/ti_heap.h
@@ -60,6 +60,12 @@
  public:
   static jvmtiError JNICALL GetObjectHeapId(jvmtiEnv* env, jlong tag, jint* heap_id, ...);
   static jvmtiError JNICALL GetHeapName(jvmtiEnv* env, jint heap_id, char** heap_name, ...);
+  // Extension variant of IterateThroughHeap; the callback also receives the heap id.
+  static jvmtiError JNICALL IterateThroughHeapExt(jvmtiEnv* env,
+                                                  jint heap_filter,
+                                                  jclass klass,
+                                                  const jvmtiHeapCallbacks* callbacks,
+                                                  const void* user_data);
 };
 
 }  // namespace openjdkjvmti