Merge "profiling: Unload malloc hooks on disconnect."
diff --git a/src/profiling/memory/client.cc b/src/profiling/memory/client.cc
index 039e22e..38f50be 100644
--- a/src/profiling/memory/client.cc
+++ b/src/profiling/memory/client.cc
@@ -282,11 +282,16 @@
 //               +------------+    |
 //               |  main      |    v
 // stackbase +-> +------------+ 0xffff
-void Client::RecordMalloc(uint64_t alloc_size,
+//
+// Returns true if the allocation was recorded. Returns false if the client is
+// not initialized, or if recording failed, in which case the client has been
+// shut down.
+bool Client::RecordMalloc(uint64_t alloc_size,
                           uint64_t total_size,
                           uint64_t alloc_address) {
-  if (!inited_.load(std::memory_order_acquire))
-    return;
+  if (!inited_.load(std::memory_order_acquire)) {
+    return false;
+  }
   AllocMetadata metadata;
   const char* stackbase = GetStackBase();
   const char* stacktop = reinterpret_cast<char*>(__builtin_frame_address(0));
@@ -294,7 +299,8 @@
 
   if (stackbase < stacktop) {
     PERFETTO_DFATAL("Stackbase >= stacktop.");
-    return;
+    Shutdown();
+    return false;
   }
 
   uint64_t stack_size = static_cast<uint64_t>(stackbase - stacktop);
@@ -318,36 +324,49 @@
     PERFETTO_PLOG("Failed to send wire message.");
     sock.Shutdown();
     Shutdown();
+    return false;
   }
+  return true;
 }
 
-void Client::RecordFree(uint64_t alloc_address) {
-  if (!inited_.load(std::memory_order_acquire))
-    return;
-  if (!free_page_.Add(
-          alloc_address,
-          1 + sequence_number_.fetch_add(1, std::memory_order_acq_rel),
-          &socket_pool_))
-    Shutdown();
-}
-
-size_t Client::ShouldSampleAlloc(uint64_t alloc_size,
-                                 void* (*unhooked_malloc)(size_t),
-                                 void (*unhooked_free)(void*)) {
+// Returns false if the client is not initialized or the free could not be
+// recorded, in which case the client has been shut down.
+bool Client::RecordFree(uint64_t alloc_address) {
   if (!inited_.load(std::memory_order_acquire))
     return false;
-  return SampleSize(pthread_key_.get(), alloc_size, client_config_.interval,
-                    unhooked_malloc, unhooked_free);
+  bool success = free_page_.Add(
+      alloc_address,
+      1 + sequence_number_.fetch_add(1, std::memory_order_acq_rel),
+      &socket_pool_);
+  if (!success)
+    Shutdown();
+  return success;
 }
 
-void Client::MaybeSampleAlloc(uint64_t alloc_size,
+// Returns -1 if the client is not initialized; otherwise the SampleSize()
+// result, where a positive value is the size to record for this allocation.
+ssize_t Client::ShouldSampleAlloc(uint64_t alloc_size,
+                                  void* (*unhooked_malloc)(size_t),
+                                  void (*unhooked_free)(void*)) {
+  if (!inited_.load(std::memory_order_acquire))
+    return -1;
+  return static_cast<ssize_t>(SampleSize(pthread_key_.get(), alloc_size,
+                                         client_config_.interval,
+                                         unhooked_malloc, unhooked_free));
+}
+
+// Returns false if the client is not initialized or a sampled allocation
+// could not be recorded; true if nothing needed recording or it succeeded.
+bool Client::MaybeSampleAlloc(uint64_t alloc_size,
                               uint64_t alloc_address,
                               void* (*unhooked_malloc)(size_t),
                               void (*unhooked_free)(void*)) {
-  size_t total_size =
+  ssize_t total_size =
       ShouldSampleAlloc(alloc_size, unhooked_malloc, unhooked_free);
   if (total_size > 0)
-    RecordMalloc(alloc_size, total_size, alloc_address);
+    return RecordMalloc(alloc_size, static_cast<uint64_t>(total_size),
+                        alloc_address);
+  return total_size != -1;
 }
 
 void Client::Shutdown() {
diff --git a/src/profiling/memory/client.h b/src/profiling/memory/client.h
index f71109b..b567f27 100644
--- a/src/profiling/memory/client.h
+++ b/src/profiling/memory/client.h
@@ -130,11 +130,17 @@
  public:
   Client(std::vector<base::UnixSocketRaw> sockets);
   Client(const std::string& sock_name, size_t conns);
-  void RecordMalloc(uint64_t alloc_size,
+  // Returns false if the client is not initialized or the allocation could
+  // not be recorded (the client is then shut down).
+  bool RecordMalloc(uint64_t alloc_size,
                     uint64_t total_size,
                     uint64_t alloc_address);
-  void RecordFree(uint64_t alloc_address);
-  void MaybeSampleAlloc(uint64_t alloc_size,
+  // Returns false if the client is not initialized or the free could not be
+  // recorded (the client is then shut down).
+  bool RecordFree(uint64_t alloc_address);
+  // Returns false if the client is not initialized or a sampled allocation
+  // could not be recorded; true if nothing needed recording or it succeeded.
+  bool MaybeSampleAlloc(uint64_t alloc_size,
                         uint64_t alloc_address,
                         void* (*unhooked_malloc)(size_t),
                         void (*unhooked_free)(void*));
@@ -144,9 +150,9 @@
   bool inited() { return inited_; }
 
  private:
-  size_t ShouldSampleAlloc(uint64_t alloc_size,
-                           void* (*unhooked_malloc)(size_t),
-                           void (*unhooked_free)(void*));
+  ssize_t ShouldSampleAlloc(uint64_t alloc_size,
+                            void* (*unhooked_malloc)(size_t),
+                            void (*unhooked_free)(void*));
   const char* GetStackBase();
 
   std::atomic<bool> inited_{false};
diff --git a/src/profiling/memory/malloc_hooks.cc b/src/profiling/memory/malloc_hooks.cc
index 28142d8..8096e7a 100644
--- a/src/profiling/memory/malloc_hooks.cc
+++ b/src/profiling/memory/malloc_hooks.cc
@@ -27,6 +27,7 @@
 
 #include <sys/system_properties.h>
 
+#include <private/bionic_malloc.h>
 #include <private/bionic_malloc_dispatch.h>
 
 #include "perfetto/base/build_config.h"
@@ -58,6 +59,12 @@
   return g_dispatch.load(std::memory_order_relaxed);
 }
 
+// Asks bionic to reset (unload) the malloc hooks via android_mallopt. The
+// dispatch argument is unused and the android_mallopt result is ignored.
+static void MallocDispatchReset(const MallocDispatch* /* dispatch */) {
+  android_mallopt(M_RESET_HOOKS, nullptr, 0);
+}
+
 // This is so we can make an so that we can swap out with the existing
 // libc_malloc_hooks.so
 #ifndef HEAPPROFD_PREFIX
@@ -262,6 +269,7 @@
   perfetto::profiling::Client* client = GetClient();
   if (client)
     client->Shutdown();
+  MallocDispatchReset(GetDispatch());
 }
 
 void HEAPPROFD_ADD_PREFIX(_dump_heap)(const char*) {}
@@ -289,8 +297,9 @@
   perfetto::profiling::Client* client = GetClient();
   void* addr = dispatch->malloc(size);
   if (client) {
-    client->MaybeSampleAlloc(size, reinterpret_cast<uint64_t>(addr),
-                             dispatch->malloc, dispatch->free);
+    if (!client->MaybeSampleAlloc(size, reinterpret_cast<uint64_t>(addr),
+                                  dispatch->malloc, dispatch->free))
+      MallocDispatchReset(GetDispatch());
   }
   return addr;
 }
@@ -299,7 +308,8 @@
   const MallocDispatch* dispatch = GetDispatch();
   perfetto::profiling::Client* client = GetClient();
   if (client)
-    client->RecordFree(reinterpret_cast<uint64_t>(pointer));
+    if (!client->RecordFree(reinterpret_cast<uint64_t>(pointer)))
+      MallocDispatchReset(GetDispatch());
   return dispatch->free(pointer);
 }
 
@@ -308,8 +318,9 @@
   perfetto::profiling::Client* client = GetClient();
   void* addr = dispatch->aligned_alloc(alignment, size);
   if (client) {
-    client->MaybeSampleAlloc(size, reinterpret_cast<uint64_t>(addr),
-                             dispatch->malloc, dispatch->free);
+    if (!client->MaybeSampleAlloc(size, reinterpret_cast<uint64_t>(addr),
+                                  dispatch->malloc, dispatch->free))
+      MallocDispatchReset(GetDispatch());
   }
   return addr;
 }
@@ -319,8 +330,9 @@
   perfetto::profiling::Client* client = GetClient();
   void* addr = dispatch->memalign(alignment, size);
   if (client) {
-    client->MaybeSampleAlloc(size, reinterpret_cast<uint64_t>(addr),
-                             dispatch->malloc, dispatch->free);
+    if (!client->MaybeSampleAlloc(size, reinterpret_cast<uint64_t>(addr),
+                                  dispatch->malloc, dispatch->free))
+      MallocDispatchReset(GetDispatch());
   }
   return addr;
 }
@@ -329,11 +341,13 @@
   const MallocDispatch* dispatch = GetDispatch();
   perfetto::profiling::Client* client = GetClient();
   if (client && pointer)
-    client->RecordFree(reinterpret_cast<uint64_t>(pointer));
+    if (!client->RecordFree(reinterpret_cast<uint64_t>(pointer)))
+      MallocDispatchReset(GetDispatch());
   void* addr = dispatch->realloc(pointer, size);
   if (client && size > 0) {
-    client->MaybeSampleAlloc(size, reinterpret_cast<uint64_t>(addr),
-                             dispatch->malloc, dispatch->free);
+    if (!client->MaybeSampleAlloc(size, reinterpret_cast<uint64_t>(addr),
+                                  dispatch->malloc, dispatch->free))
+      MallocDispatchReset(GetDispatch());
   }
   return addr;
 }
@@ -343,8 +357,11 @@
   perfetto::profiling::Client* client = GetClient();
   void* addr = dispatch->calloc(nmemb, size);
   if (client) {
-    client->MaybeSampleAlloc(size, reinterpret_cast<uint64_t>(addr),
-                             dispatch->malloc, dispatch->free);
+    // calloc allocates nmemb * size bytes, so sample on the full size.
+    if (!client->MaybeSampleAlloc(nmemb * size,
+                                  reinterpret_cast<uint64_t>(addr),
+                                  dispatch->malloc, dispatch->free))
+      MallocDispatchReset(GetDispatch());
   }
   return addr;
 }
@@ -366,8 +383,9 @@
   perfetto::profiling::Client* client = GetClient();
   int res = dispatch->posix_memalign(memptr, alignment, size);
   if (res == 0 && client) {
-    client->MaybeSampleAlloc(size, reinterpret_cast<uint64_t>(*memptr),
-                             dispatch->malloc, dispatch->free);
+    if (!client->MaybeSampleAlloc(size, reinterpret_cast<uint64_t>(*memptr),
+                                  dispatch->malloc, dispatch->free))
+      MallocDispatchReset(GetDispatch());
   }
   return res;
 }