profiling: Add unittest for wire_protocol

Fix bug.

Change-Id: Ide59ba95da4cc22970a25701a245516427e9c18e
diff --git a/Android.bp b/Android.bp
index b6dadc0..541e66f 100644
--- a/Android.bp
+++ b/Android.bp
@@ -3859,6 +3859,7 @@
     "src/profiling/memory/unwinding.cc",
     "src/profiling/memory/unwinding_unittest.cc",
     "src/profiling/memory/wire_protocol.cc",
+    "src/profiling/memory/wire_protocol_unittest.cc",
     "src/protozero/message.cc",
     "src/protozero/message_handle.cc",
     "src/protozero/message_handle_unittest.cc",
diff --git a/src/profiling/memory/BUILD.gn b/src/profiling/memory/BUILD.gn
index 663a059..369cb71 100644
--- a/src/profiling/memory/BUILD.gn
+++ b/src/profiling/memory/BUILD.gn
@@ -89,6 +89,7 @@
     "socket_listener_unittest.cc",
     "string_interner_unittest.cc",
     "unwinding_unittest.cc",
+    "wire_protocol_unittest.cc",
   ]
 }
 
diff --git a/src/profiling/memory/unwinding.cc b/src/profiling/memory/unwinding.cc
index 1dcc8c0..8ed2780 100644
--- a/src/profiling/memory/unwinding.cc
+++ b/src/profiling/memory/unwinding.cc
@@ -238,6 +238,8 @@
     FreeRecord& free_rec = rec->free_record;
     FreePageEntry* entries = free_rec.metadata->entries;
     uint64_t num_entries = free_rec.metadata->num_entries;
+    if (num_entries > kFreePageSize)
+      return;
     for (size_t i = 0; i < num_entries; ++i) {
       const FreePageEntry& entry = entries[i];
       metadata->heap_dump.RecordFree(entry.addr, entry.sequence_number);
diff --git a/src/profiling/memory/wire_protocol.cc b/src/profiling/memory/wire_protocol.cc
index 57cea48..255bae2 100644
--- a/src/profiling/memory/wire_protocol.cc
+++ b/src/profiling/memory/wire_protocol.cc
@@ -30,8 +30,8 @@
 bool ViewAndAdvance(char** ptr, T** out, const char* end) {
   if (end - sizeof(T) < *ptr)
     return false;
-  *out = reinterpret_cast<T*>(ptr);
-  ptr += sizeof(T);
+  *out = reinterpret_cast<T*>(*ptr);
+  *ptr += sizeof(T);
   return true;
 }
 }  // namespace
@@ -45,9 +45,11 @@
   iovecs[1].iov_base = const_cast<RecordType*>(&msg.record_type);
   iovecs[1].iov_len = sizeof(msg.record_type);
   if (msg.alloc_header) {
+    PERFETTO_DCHECK(msg.record_type == RecordType::Malloc);
     iovecs[2].iov_base = msg.alloc_header;
     iovecs[2].iov_len = sizeof(*msg.alloc_header);
   } else if (msg.free_header) {
+    PERFETTO_DCHECK(msg.record_type == RecordType::Free);
     iovecs[2].iov_base = msg.free_header;
     iovecs[2].iov_len = sizeof(*msg.free_header);
   } else {
@@ -78,23 +80,27 @@
   char* end = buf + size;
   if (!ViewAndAdvance<RecordType>(&buf, &record_type, end))
     return false;
-  switch (*record_type) {
-    case RecordType::Malloc:
-      if (!ViewAndAdvance<AllocMetadata>(&buf, &out->alloc_header, end))
-        return false;
-      out->payload = buf;
-      if (buf > end) {
-        PERFETTO_DCHECK(false);
-        return false;
-      }
-      out->payload_size = static_cast<size_t>(end - buf);
-      break;
-    case RecordType::Free:
-      if (!ViewAndAdvance<FreeMetadata>(&buf, &out->free_header, end))
-        return false;
-      break;
-  }
+
+  out->payload = nullptr;
+  out->payload_size = 0;
   out->record_type = *record_type;
+
+  if (*record_type == RecordType::Malloc) {
+    if (!ViewAndAdvance<AllocMetadata>(&buf, &out->alloc_header, end))
+      return false;
+    out->payload = buf;
+    if (buf > end) {
+      PERFETTO_DCHECK(false);
+      return false;
+    }
+    out->payload_size = static_cast<size_t>(end - buf);
+  } else if (*record_type == RecordType::Free) {
+    if (!ViewAndAdvance<FreeMetadata>(&buf, &out->free_header, end))
+      return false;
+  } else {
+    PERFETTO_DCHECK(false);
+    return false;
+  }
   return true;
 }
 
diff --git a/src/profiling/memory/wire_protocol_unittest.cc b/src/profiling/memory/wire_protocol_unittest.cc
new file mode 100644
index 0000000..c9e8b22
--- /dev/null
+++ b/src/profiling/memory/wire_protocol_unittest.cc
@@ -0,0 +1,140 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "src/profiling/memory/wire_protocol.h"
+#include "perfetto/base/logging.h"
+#include "perfetto/base/scoped_file.h"
+#include "src/profiling/memory/record_reader.h"
+
+#include <sys/socket.h>
+#include <sys/types.h>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace perfetto {
+
+bool operator==(const AllocMetadata& one, const AllocMetadata& other);
+bool operator==(const AllocMetadata& one, const AllocMetadata& other) {
+  return std::tie(one.sequence_number, one.alloc_size, one.alloc_address,
+                  one.stack_pointer, one.stack_pointer_offset, one.arch) ==
+             std::tie(other.sequence_number, other.alloc_size,
+                      other.alloc_address, other.stack_pointer,
+                      other.stack_pointer_offset, other.arch) &&
+         memcmp(one.register_data, other.register_data, kMaxRegisterDataSize) ==
+             0;
+}
+
+bool operator==(const FreeMetadata& one, const FreeMetadata& other);
+bool operator==(const FreeMetadata& one, const FreeMetadata& other) {
+  if (one.num_entries != other.num_entries)
+    return false;
+  for (size_t i = 0; i < one.num_entries; ++i) {
+    if (std::tie(one.entries[i].sequence_number, one.entries[i].addr) !=
+        std::tie(other.entries[i].sequence_number, other.entries[i].addr))
+      return false;
+  }
+  return true;
+}
+
+namespace {
+
+RecordReader::Record ReceiveAll(int sock) {
+  RecordReader record_reader;
+  RecordReader::Record record;
+  bool received = false;
+  while (!received) {
+    RecordReader::ReceiveBuffer buf = record_reader.BeginReceive();
+    ssize_t rd = PERFETTO_EINTR(read(sock, buf.data, buf.size));
+    PERFETTO_CHECK(rd > 0);
+    auto status = record_reader.EndReceive(static_cast<size_t>(rd), &record);
+    switch (status) {
+      case (RecordReader::Result::Noop):
+        break;
+      case (RecordReader::Result::RecordReceived):
+        received = true;
+        break;
+      case (RecordReader::Result::KillConnection):
+        PERFETTO_CHECK(false);
+        break;
+    }
+  }
+  return record;
+}
+
+TEST(WireProtocolTest, AllocMessage) {
+  char payload[] = {0x77, 0x77, 0x77, 0x00};
+  WireMessage msg = {};
+  msg.record_type = RecordType::Malloc;
+  AllocMetadata metadata = {};
+  metadata.sequence_number = 0xA1A2A3A4A5A6A7A8;
+  metadata.alloc_size = 0xB1B2B3B4B5B6B7B8;
+  metadata.alloc_address = 0xC1C2C3C4C5C6C7C8;
+  metadata.stack_pointer = 0xD1D2D3D4D5D6D7D8;
+  metadata.stack_pointer_offset = 0xE1E2E3E4E5E6E7E8;
+  metadata.arch = unwindstack::ARCH_X86;
+  for (size_t i = 0; i < kMaxRegisterDataSize; ++i)
+    metadata.register_data[i] = 0x66;
+  msg.alloc_header = &metadata;
+  msg.payload = payload;
+  msg.payload_size = sizeof(payload);
+
+  int sv[2];
+  ASSERT_EQ(socketpair(AF_UNIX, SOCK_STREAM, 0, sv), 0);
+  base::ScopedFile send_sock(sv[0]);
+  base::ScopedFile recv_sock(sv[1]);
+  ASSERT_TRUE(SendWireMessage(*send_sock, msg));
+
+  RecordReader::Record record = ReceiveAll(*recv_sock);
+
+  WireMessage recv_msg;
+  ASSERT_TRUE(ReceiveWireMessage(reinterpret_cast<char*>(record.data.get()),
+                                 record.size, &recv_msg));
+  ASSERT_EQ(recv_msg.record_type, msg.record_type);
+  ASSERT_EQ(*recv_msg.alloc_header, *msg.alloc_header);
+  ASSERT_EQ(recv_msg.payload_size, msg.payload_size);
+  ASSERT_STREQ(recv_msg.payload, msg.payload);
+}
+
+TEST(WireProtocolTest, FreeMessage) {
+  WireMessage msg = {};
+  msg.record_type = RecordType::Free;
+  FreeMetadata metadata = {};
+  metadata.num_entries = kFreePageSize;
+  for (size_t i = 0; i < kFreePageSize; ++i) {
+    metadata.entries[i].sequence_number = 0x111111111111111;
+    metadata.entries[i].addr = 0x222222222222222;
+  }
+  msg.free_header = &metadata;
+
+  int sv[2];
+  ASSERT_EQ(socketpair(AF_UNIX, SOCK_STREAM, 0, sv), 0);
+  base::ScopedFile send_sock(sv[0]);
+  base::ScopedFile recv_sock(sv[1]);
+  ASSERT_TRUE(SendWireMessage(*send_sock, msg));
+
+  RecordReader::Record record = ReceiveAll(*recv_sock);
+
+  WireMessage recv_msg;
+  ASSERT_TRUE(ReceiveWireMessage(reinterpret_cast<char*>(record.data.get()),
+                                 record.size, &recv_msg));
+  ASSERT_EQ(recv_msg.record_type, msg.record_type);
+  ASSERT_EQ(*recv_msg.free_header, *msg.free_header);
+  ASSERT_EQ(recv_msg.payload_size, msg.payload_size);
+}
+
+}  // namespace
+}  // namespace perfetto