Pivot source files into top-level src/ and include/, and unify test targets
See discussion in go/perfetto-build-files .
This disambiguates includes like
#include "base/logging.h"
when building in the Chrome tree.
This CL also unifies the test targets into two monolithic targets:
perfetto_tests and perfetto_benchmarks. This avoids ending up
with confusing binary names in the Chrome tree (e.g.,
ipc_unittests).
Bug: 68710794
Change-Id: I1768e15b661406052b2be060d7aab0f1e7443a98
diff --git a/src/ipc/BUILD.gn b/src/ipc/BUILD.gn
new file mode 100644
index 0000000..a42e2fe
--- /dev/null
+++ b/src/ipc/BUILD.gn
@@ -0,0 +1,77 @@
+# Copyright (C) 2017 The Android Open Source Project
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import("../../gn/perfetto.gni")
+import("../../gn/proto_library.gni")
+import("ipc_library.gni")
+
+# Core IPC implementation (client, host, frame (de)serialization, sockets).
+source_set("ipc") {
+  public_configs = [ "../../gn:default_config" ]
+  public_deps = [
+    "../../include/perfetto/ipc",
+  ]
+  deps = [
+    ":wire_protocol",
+    "../../gn:default_deps",
+    "../base",
+  ]
+
+  # Private headers are listed next to their .cc files so that "gn check" can
+  # attribute them to this target and validate who includes them.
+  sources = [
+    "buffered_frame_deserializer.cc",
+    "buffered_frame_deserializer.h",
+    "client_impl.cc",
+    "client_impl.h",
+    "deferred.cc",
+    "host_impl.cc",
+    "host_impl.h",
+    "service_proxy.cc",
+    "unix_socket.cc",
+    "unix_socket.h",
+  ]
+}
+
+# Unit and integration tests for the IPC layer.
+# NOTE(review): the commit message says test targets are being unified into
+# perfetto_tests / perfetto_benchmarks to avoid per-component binaries such as
+# "ipc_unittests". Confirm whether this standalone executable is still meant
+# to exist or should become a testonly source_set linked into perfetto_tests.
+executable("perfetto_ipc_unittests") {
+  testonly = true
+  sources = [
+    "buffered_frame_deserializer_unittest.cc",
+    "client_impl_unittest.cc",
+    "deferred_unittest.cc",
+    "host_impl_unittest.cc",
+    "test/ipc_integrationtest.cc",
+    "unix_socket_unittest.cc",
+  ]
+  deps = [
+    ":ipc",
+    ":test_messages",
+    ":wire_protocol",
+    "../../gn:default_deps",
+    "../../gn:gtest_deps",
+    "../base",
+    "../base:test_support",
+  ]
+}
+
+# Protobuf messages of the IPC wire protocol (e.g. Frame). C++ code includes
+# the generated header as "src/ipc/wire_protocol.pb.h".
+proto_library("wire_protocol") {
+  proto_in_dir = perfetto_root_path
+  proto_out_dir = "protos_lite"
+  sources = [
+    "wire_protocol.proto",
+  ]
+}
+
+# Messages and services exercised by the IPC tests (perfetto_ipc_unittests).
+ipc_library("test_messages") {
+  proto_in_dir = perfetto_root_path
+  sources = [
+    "test/client_unittest_messages.proto",
+    "test/deferred_unittest_messages.proto",
+    "test/greeter_service.proto",
+  ]
+}
diff --git a/src/ipc/buffered_frame_deserializer.cc b/src/ipc/buffered_frame_deserializer.cc
new file mode 100644
index 0000000..0f40186
--- /dev/null
+++ b/src/ipc/buffered_frame_deserializer.cc
@@ -0,0 +1,211 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "src/ipc/buffered_frame_deserializer.h"
+
+#include <inttypes.h>
+#include <sys/mman.h>
+
+#include <algorithm>
+#include <type_traits>
+#include <utility>
+
+#include "google/protobuf/io/zero_copy_stream_impl_lite.h"
+#include "perfetto/base/logging.h"
+#include "perfetto/base/utils.h"
+
+#include "src/ipc/wire_protocol.pb.h"
+
+namespace perfetto {
+namespace ipc {
+
+namespace {
+// NOTE(review): assumes 4 KiB pages. On kernels with larger pages (e.g. 16K or
+// 64K on some arm64 configurations) |buf_| + |capacity_| would not necessarily
+// be page-aligned and the mprotect() of the guard region could fail. Consider
+// sysconf(_SC_PAGESIZE) if such targets matter.
+constexpr size_t kPageSize = 4096;
+
+// Size of the PROT_NONE guard region, adjacent to the end of the buffer.
+// It's a safety net to spot any out-of-bounds writes early.
+constexpr size_t kGuardRegionSize = kPageSize;
+
+// The header is just the number of bytes of the Frame protobuf message,
+// stored as a little-endian uint32.
+constexpr size_t kHeaderSize = sizeof(uint32_t);
+}  // namespace
+
+// The buffer is mmap()-ed, madvise()-d and guarded at page granularity, hence
+// |max_capacity| must be a whole number of pages and more than the first page,
+// which BeginReceive() never hands back to the kernel.
+BufferedFrameDeserializer::BufferedFrameDeserializer(size_t max_capacity)
+    : capacity_(max_capacity) {
+  PERFETTO_CHECK(max_capacity % kPageSize == 0);
+  PERFETTO_CHECK(max_capacity > kPageSize);
+}
+
+// Unmaps the whole region (buffer + trailing guard), if BeginReceive() ever
+// created it.
+BufferedFrameDeserializer::~BufferedFrameDeserializer() {
+  if (buf_) {
+    int res = munmap(buf_, capacity_ + kGuardRegionSize);
+    PERFETTO_DCHECK(res == 0);
+  }
+}
+
+// Returns the free tail of the receive buffer (everything after the |size_|
+// bytes still pending from previous receives), lazily creating the mapping on
+// the first call.
+BufferedFrameDeserializer::ReceiveBuffer
+BufferedFrameDeserializer::BeginReceive() {
+  // Upon the first recv initialize the buffer to the max message size but
+  // release the physical memory for all but the first page. The kernel will
+  // automatically give us physical pages back as soon as we page-fault on them.
+  if (!buf_) {
+    PERFETTO_DCHECK(size_ == 0);
+    // Anonymous mappings must pass fd = -1: Linux ignores the fd, but other
+    // implementations (e.g. FreeBSD) reject any other value with EINVAL.
+    buf_ = reinterpret_cast<char*>(mmap(nullptr, capacity_ + kGuardRegionSize,
+                                        PROT_READ | PROT_WRITE,
+                                        MAP_ANONYMOUS | MAP_PRIVATE, -1, 0));
+    PERFETTO_CHECK(buf_ != MAP_FAILED);
+
+    // Surely we are going to use at least the first page. There is very little
+    // point in madvising that as well and immediately after telling the kernel
+    // that we want it back (via recv()).
+    int res = madvise(buf_ + kPageSize,
+                      capacity_ + kGuardRegionSize - kPageSize, MADV_DONTNEED);
+    PERFETTO_DCHECK(res == 0);
+
+    // Make the trailing guard region inaccessible so that overruns fault.
+    res = mprotect(buf_ + capacity_, kGuardRegionSize, PROT_NONE);
+    PERFETTO_DCHECK(res == 0);
+  }
+
+  PERFETTO_CHECK(capacity_ > size_);
+  return ReceiveBuffer{buf_ + size_, capacity_ - size_};
+}
+
+// Commits |recv_size| bytes that the caller wrote into the buffer returned by
+// BeginReceive(), decodes all the complete frames now available (via
+// DecodeFrame()) and shifts any incomplete trailing data to the start of
+// |buf_|. Returns false if a header announces a frame that, header included,
+// can never fit in |capacity_|; the caller should then drop the connection.
+bool BufferedFrameDeserializer::EndReceive(size_t recv_size) {
+ PERFETTO_CHECK(recv_size + size_ <= capacity_);
+ size_ += recv_size;
+
+ // At this point the contents of buf_ can contain:
+ // A) Only a fragment of the header (the size of the frame). E.g.,
+ // 03 00 00 (the header is 4 bytes, one is missing).
+ //
+ // B) A header and a part of the frame. E.g.,
+ // 05 00 00 00 11 22 33
+ // [ header, size=5 ] [ Partial frame ]
+ //
+ // C) One or more complete header+frame. E.g.,
+ // 05 00 00 00 11 22 33 44 55 03 00 00 00 AA BB CC
+ // [ header, size=5 ] [ Whole frame ] [ header, size=3 ] [ Whole frame ]
+ //
+ // D) Some complete header+frame(s) and a partial header or frame (C + A/B).
+ //
+ // C is the most likely case and the one we are optimizing for. A, B, D can
+ // happen because of the streaming nature of the socket.
+ // The invariant of this function is that, when it returns, buf_ is either
+ // empty (we drained all the complete frames) or starts with the header of the
+ // next, still incomplete, frame.
+
+ size_t consumed_size = 0;
+ for (;;) {
+ if (size_ < consumed_size + kHeaderSize)
+ break; // Case A, not enough data to read even the header.
+
+ // Read the header into |payload_size|.
+ uint32_t payload_size = 0;
+ const char* rd_ptr = buf_ + consumed_size;
+ memcpy(base::AssumeLittleEndian(&payload_size), rd_ptr, kHeaderSize);
+
+ // Saturate the |payload_size| to prevent overflows. The > capacity_ check
+ // below will abort the parsing.
+ size_t next_frame_size =
+ std::min(static_cast<size_t>(payload_size), capacity_);
+ next_frame_size += kHeaderSize;
+ rd_ptr += kHeaderSize;
+
+ if (size_ < consumed_size + next_frame_size) {
+ // Case B. We got the header but not the whole frame.
+ if (next_frame_size > capacity_) {
+ // The caller is expected to shut down the socket and give up at this
+ // point. If it doesn't do that and insists on going on, at some point it
+ // will hit the capacity check in BeginReceive().
+ PERFETTO_DLOG("Frame too large (size %zu)", next_frame_size);
+ return false;
+ }
+ break;
+ }
+
+ // Case C. We got at least one header and whole frame.
+ DecodeFrame(rd_ptr, payload_size);
+ consumed_size += next_frame_size;
+ }
+
+ PERFETTO_DCHECK(consumed_size <= size_);
+ if (consumed_size > 0) {
+ // Shift out the consumed data from the buffer. In the typical case (C)
+ // there is nothing to shift really, just setting size_ = 0 is enough.
+ // Shifting is only for the (unlikely) case D.
+ size_ -= consumed_size;
+ if (size_ > 0) {
+ // Case D. We consumed some frames but there is a leftover at the end of
+ // the buffer. Shift out the consumed bytes, so that on the next round
+ // |buf_| starts with the header of the next unconsumed frame.
+ const char* move_begin = buf_ + consumed_size;
+ PERFETTO_CHECK(move_begin > buf_);
+ PERFETTO_CHECK(move_begin + size_ <= buf_ + capacity_);
+ memmove(buf_, move_begin, size_);
+ }
+ // If we just finished decoding a large frame that used more than one page,
+ // release the extra memory in the buffer. Large frames should be quite
+ // rare. Everything from the first page boundary past the leftover data up
+ // to |capacity_| is returned to the kernel.
+ if (consumed_size > kPageSize) {
+ size_t size_rounded_up = (size_ / kPageSize + 1) * kPageSize;
+ if (size_rounded_up < capacity_) {
+ char* madvise_begin = buf_ + size_rounded_up;
+ const size_t madvise_size = capacity_ - size_rounded_up;
+ PERFETTO_CHECK(madvise_begin > buf_ + size_);
+ PERFETTO_CHECK(madvise_begin + madvise_size <= buf_ + capacity_);
+ int res = madvise(madvise_begin, madvise_size, MADV_DONTNEED);
+ PERFETTO_DCHECK(res == 0);
+ }
+ }
+ }
+ // At this point |size_| == 0 for case C, > 0 for cases A, B, D.
+ return true;
+}
+
+// Transfers ownership of the oldest decoded frame to the caller, or returns
+// nullptr when no decoded frame is pending.
+std::unique_ptr<Frame> BufferedFrameDeserializer::PopNextFrame() {
+  std::unique_ptr<Frame> next_frame;
+  if (!decoded_frames_.empty()) {
+    next_frame = std::move(decoded_frames_.front());
+    decoded_frames_.pop_front();
+  }
+  return next_frame;
+}
+
+// Parses the |size| bytes at |data| as a Frame proto and queues it for
+// PopNextFrame(). Zero-sized payloads and payloads that fail to parse are
+// dropped (never queued); the latter are reported in debug builds.
+void BufferedFrameDeserializer::DecodeFrame(const char* data, size_t size) {
+  if (size == 0)
+    return;
+  // The protobuf API takes an int. EndReceive() only passes payloads that fit
+  // in |capacity_|, so the cast below cannot overflow for sane capacities.
+  PERFETTO_DCHECK(size <= capacity_);
+  std::unique_ptr<Frame> frame(new Frame);
+  const int sz = static_cast<int>(size);
+  ::google::protobuf::io::ArrayInputStream stream(data, sz);
+  if (!frame->ParseFromBoundedZeroCopyStream(&stream, sz)) {
+    PERFETTO_DLOG("Dropping unparsable frame (size %zu)", size);
+    return;
+  }
+  decoded_frames_.push_back(std::move(frame));
+}
+
+// static
+// Produces the wire representation of |frame|: a 32-bit little-endian size
+// header followed by the proto-encoded payload.
+std::string BufferedFrameDeserializer::Serialize(const Frame& frame) {
+  std::string out;
+  out.reserve(1024);  // Educated guess, avoids trivial reallocations.
+  out.assign(kHeaderSize, '\0');  // Size placeholder, patched below.
+  frame.AppendToString(&out);
+  const uint32_t payload_size =
+      static_cast<uint32_t>(out.size() - kHeaderSize);
+  PERFETTO_DCHECK(payload_size == static_cast<uint32_t>(frame.GetCachedSize()));
+  memcpy(&out[0], base::AssumeLittleEndian(&payload_size), kHeaderSize);
+  return out;
+}
+
+} // namespace ipc
+} // namespace perfetto
diff --git a/src/ipc/buffered_frame_deserializer.h b/src/ipc/buffered_frame_deserializer.h
new file mode 100644
index 0000000..874a7a2
--- /dev/null
+++ b/src/ipc/buffered_frame_deserializer.h
@@ -0,0 +1,129 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef SRC_IPC_BUFFERED_FRAME_DESERIALIZER_H_
+#define SRC_IPC_BUFFERED_FRAME_DESERIALIZER_H_
+
+#include <stddef.h>
+
+#include <list>
+#include <memory>
+#include <string>
+
+#include "perfetto/base/utils.h"
+
+namespace perfetto {
+namespace ipc {
+
+class Frame;  // Defined in the protobuf autogenerated wire_protocol.pb.h.
+
+// Deserializes incoming frames, taking care of buffering and tokenization.
+// Used by both host and client to decode incoming frames.
+//
+// Which problem does it solve?
+// ----------------------------
+// The wire protocol is as follows:
+//   [32-bit frame size][proto-encoded Frame], e.g.:
+//   [07 00 00 00][00 11 22 33 44 55 66]
+//   [02 00 00 00][AA BB]
+//   [04 00 00 00][CC DD EE FF]
+// However, given that the socket works in SOCK_STREAM mode, the recv() calls
+// might see the following:
+//   07 00 00
+//   00 00 11 22 33 44 55
+//   66 02 00 00 00 ...
+// This class takes care of buffering efficiently the data received, without
+// making any assumption on how the incoming data will be chunked by the socket.
+// For instance, it is possible that a recv() doesn't produce any frame (because
+// it received only a part of the frame) or produces more than one frame.
+//
+// Usage
+// -----
+// Both host and client use this as follows:
+//
+//   auto buf = rpc_frame_decoder.BeginReceive();
+//   size_t rsize = socket.recv(buf.data, buf.size);
+//   if (!rpc_frame_decoder.EndReceive(rsize)) {
+//     ... oversized frame: shut down the socket and give up.
+//   }
+//   while (std::unique_ptr<Frame> frame = rpc_frame_decoder.PopNextFrame()) {
+//     ... process |frame|
+//   }
+//
+// Design goals:
+// -------------
+// - Optimize for the realistic case of each recv() receiving one or more
+//   whole frames. In this case no memmove is performed.
+// - Guarantee that frames lie in a virtually contiguous memory area.
+//   This allows to use the protobuf-lite deserialization API (scattered
+//   deserialization is supported only by libprotobuf-full).
+// - Put a hard boundary to the size of the incoming buffer. This prevents a
+//   malicious peer from sending an abnormally large frame and OOMing us.
+// - Simplicity: just use a linear mmap region. No reallocations or scattering.
+//   Takes care of madvise()-ing unused memory.
+class BufferedFrameDeserializer {
+ public:
+  // A writable, uninitialized region of |size| bytes starting at |data|.
+  struct ReceiveBuffer {
+    char* data;
+    size_t size;
+  };
+
+  // |max_capacity| bounds the size of a single frame (header included). It
+  // must be a multiple of the 4096-byte page size used by the implementation
+  // and larger than one page (both are CHECKed at construction).
+  explicit BufferedFrameDeserializer(size_t max_capacity = 128 * 1024);
+  ~BufferedFrameDeserializer();
+
+  // This function doesn't really belong here as it does serialization, unlike
+  // the rest of this class. However it is so small and has so many
+  // dependencies in common that it doesn't justify having its own class.
+  // Returns the 32-bit size header followed by the proto-encoded frame.
+  static std::string Serialize(const Frame&);
+
+  // Returns a buffer that can be passed to recv(). The buffer is deliberately
+  // not initialized.
+  ReceiveBuffer BeginReceive();
+
+  // Must be called soon after BeginReceive().
+  // |recv_size| is the number of valid bytes that have been written into the
+  // buffer previously returned by BeginReceive() (the return value of recv()).
+  // Returns false if a header announcing a frame (header included) larger
+  // than |max_capacity| is received, in which case the caller is expected to
+  // shut down the socket and terminate the IPC.
+  bool EndReceive(size_t recv_size) __attribute__((warn_unused_result));
+
+  // Returns the next decoded frame in the buffer if any, nullptr if no further
+  // frames have been decoded. Frames that failed to parse are never returned.
+  std::unique_ptr<Frame> PopNextFrame();
+
+  size_t capacity() const { return capacity_; }
+  size_t size() const { return size_; }
+
+ private:
+  BufferedFrameDeserializer(const BufferedFrameDeserializer&) = delete;
+  BufferedFrameDeserializer& operator=(const BufferedFrameDeserializer&) =
+      delete;
+
+  // If a valid frame is decoded it is added to |decoded_frames_|.
+  void DecodeFrame(const char*, size_t);
+
+  char* buf_ = nullptr;
+  const size_t capacity_ = 0;  // sizeof(|buf_|), excluding the guard region.
+
+  // The number of bytes in |buf_| that contain valid data (as a result of
+  // EndReceive()). This is always <= |capacity_|.
+  size_t size_ = 0;
+
+  // Frames decoded by EndReceive() and not yet handed out by PopNextFrame().
+  std::list<std::unique_ptr<Frame>> decoded_frames_;
+};
+
+} // namespace ipc
+} // namespace perfetto
+
+#endif // SRC_IPC_BUFFERED_FRAME_DESERIALIZER_H_
diff --git a/src/ipc/buffered_frame_deserializer_unittest.cc b/src/ipc/buffered_frame_deserializer_unittest.cc
new file mode 100644
index 0000000..3058da0
--- /dev/null
+++ b/src/ipc/buffered_frame_deserializer_unittest.cc
@@ -0,0 +1,381 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "src/ipc/buffered_frame_deserializer.h"
+
+#include <algorithm>
+#include <string>
+
+#include "gtest/gtest.h"
+#include "perfetto/base/logging.h"
+#include "perfetto/base/utils.h"
+
+#include "src/ipc/wire_protocol.pb.h"
+
+namespace perfetto {
+namespace ipc {
+namespace {
+
+// Size of the little-endian uint32 length prefix that precedes each frame.
+constexpr size_t kHeaderSize = sizeof(uint32_t);
+
+// Generates a parsable Frame of exactly |size| bytes (including header).
+std::vector<char> GetSimpleFrame(size_t size) {
+  // A bit of reverse math of the proto encoding: a Frame which has only the
+  // |data_for_testing| fields, will require for each data_for_testing that is
+  // up to 127 bytes:
+  // - 1 byte to write the field preamble (field type and id).
+  // - 1 byte to write the field size, if 0 < size <= 127.
+  // - N bytes for the actual content (|padding| below).
+  // So below we split the payload into chunks of <= 127 bytes, taking into
+  // account the extra 2 bytes for each chunk.
+  PERFETTO_CHECK(size >= kHeaderSize);  // Avoid underflowing |payload_size|.
+  Frame frame;
+  std::vector<char> padding;
+  char padding_char = '0';
+  const size_t payload_size = size - kHeaderSize;
+  for (size_t size_left = payload_size; size_left > 0;) {
+    PERFETTO_CHECK(size_left >= 2);  // We cannot produce frames < 2 bytes.
+    size_t padding_size;
+    if (size_left <= 127) {
+      padding_size = size_left - 2;
+      size_left = 0;
+    } else {
+      padding_size = 124;
+      size_left -= padding_size + 2;
+    }
+    padding.resize(padding_size);
+    for (size_t i = 0; i < padding_size; i++) {
+      padding_char = padding_char == 'z' ? '0' : padding_char + 1;
+      padding[i] = padding_char;
+    }
+    frame.add_data_for_testing(padding.data(), padding_size);
+  }
+  PERFETTO_CHECK(frame.ByteSize() == static_cast<int>(payload_size));
+  std::vector<char> encoded_frame(size);
+  char* enc_buf = encoded_frame.data();
+  PERFETTO_CHECK(frame.SerializeToArray(enc_buf + kHeaderSize, payload_size));
+  // The wire header is a 32-bit value: copy it from a uint32_t rather than
+  // taking the first 4 bytes of the (wider) size_t |payload_size|.
+  const uint32_t header = static_cast<uint32_t>(payload_size);
+  memcpy(enc_buf, base::AssumeLittleEndian(&header), kHeaderSize);
+  PERFETTO_CHECK(encoded_frame.size() == size);
+  return encoded_frame;
+}
+
+// Copies |encoded_frame| into |rbuf| starting at |offset|. If it doesn't fit,
+// records a fatal gtest failure and returns without copying anything.
+void CheckedMemcpy(BufferedFrameDeserializer::ReceiveBuffer rbuf,
+                   const std::vector<char>& encoded_frame,
+                   size_t offset = 0) {
+  ASSERT_GE(rbuf.size, encoded_frame.size() + offset);
+  std::copy(encoded_frame.begin(), encoded_frame.end(), rbuf.data + offset);
+}
+
+// Returns true if |frame|, once re-serialized, matches byte-for-byte the
+// payload of |expected_frame_with_header| (i.e. after skipping its header).
+// Taken by const reference: callers pass frames of up to the buffer capacity.
+bool FrameEq(const std::vector<char>& expected_frame_with_header,
+             const Frame& frame) {
+  EXPECT_GE(expected_frame_with_header.size(), kHeaderSize);
+  if (expected_frame_with_header.size() < kHeaderSize)
+    return false;  // Would underflow |expected_size| below.
+
+  const std::string reserialized_frame = frame.SerializeAsString();
+
+  const size_t expected_size = expected_frame_with_header.size() - kHeaderSize;
+  EXPECT_EQ(expected_size, reserialized_frame.size());
+  if (expected_size != reserialized_frame.size())
+    return false;
+
+  return memcmp(reserialized_frame.data(),
+                expected_frame_with_header.data() + kHeaderSize,
+                reserialized_frame.size()) == 0;
+}
+
+// Tests the simple case where each recv() just returns one whole header+frame
+// (case C in BufferedFrameDeserializer::EndReceive()), for growing sizes.
+TEST(BufferedFrameDeserializerTest, WholeMessages) {
+ BufferedFrameDeserializer bfd;
+ for (int i = 1; i <= 50; i++) {
+ const size_t size = i * 10;
+ BufferedFrameDeserializer::ReceiveBuffer rbuf = bfd.BeginReceive();
+
+ ASSERT_NE(nullptr, rbuf.data);
+ std::vector<char> frame = GetSimpleFrame(size);
+ CheckedMemcpy(rbuf, frame);
+ ASSERT_TRUE(bfd.EndReceive(frame.size()));
+
+ // Exactly one frame should be decoded, with no leftover buffer.
+ auto decoded_frame = bfd.PopNextFrame();
+ ASSERT_TRUE(decoded_frame);
+ ASSERT_EQ(static_cast<int32_t>(size - kHeaderSize),
+ decoded_frame->ByteSize());
+ ASSERT_FALSE(bfd.PopNextFrame());
+ ASSERT_EQ(0u, bfd.size());
+ }
+}
+
+// Sends first a simple test frame. Then creates a realistic Frame fragmenting
+// it in three chunks and tests that the decoded Frame matches the original one.
+// The recv() sequence is as follows:
+// 1. [ simple_frame ] [ frame_chunk1 ... ]
+// 2. [ ... frame_chunk2 ... ]
+// 3. [ ... frame_chunk3 ]
+TEST(BufferedFrameDeserializerTest, FragmentedFrameIsCorrectlyDeserialized) {
+  BufferedFrameDeserializer bfd;
+  Frame frame;
+  frame.set_request_id(42);
+  auto* bind_reply = frame.mutable_msg_bind_service_reply();
+  bind_reply->set_success(true);
+  bind_reply->set_service_id(0x4242);
+  auto* method = bind_reply->add_methods();
+  method->set_id(0x424242);
+  method->set_name("foo");
+  std::vector<char> serialized_frame;
+  uint32_t payload_size = frame.ByteSize();
+
+  serialized_frame.resize(kHeaderSize + payload_size);
+  ASSERT_TRUE(frame.SerializeToArray(serialized_frame.data() + kHeaderSize,
+                                     payload_size));
+  memcpy(serialized_frame.data(), base::AssumeLittleEndian(&payload_size),
+         kHeaderSize);
+
+  // The chunk boundaries below (5 and 10) must fall inside the frame.
+  ASSERT_GT(serialized_frame.size(), 10u);
+
+  std::vector<char> simple_frame = GetSimpleFrame(32);
+  std::vector<char> frame_chunk1(serialized_frame.begin(),
+                                 serialized_frame.begin() + 5);
+  BufferedFrameDeserializer::ReceiveBuffer rbuf = bfd.BeginReceive();
+  CheckedMemcpy(rbuf, simple_frame);
+  CheckedMemcpy(rbuf, frame_chunk1, simple_frame.size());
+  ASSERT_TRUE(bfd.EndReceive(simple_frame.size() + frame_chunk1.size()));
+
+  std::vector<char> frame_chunk2(serialized_frame.begin() + 5,
+                                 serialized_frame.begin() + 10);
+  rbuf = bfd.BeginReceive();
+  CheckedMemcpy(rbuf, frame_chunk2);
+  ASSERT_TRUE(bfd.EndReceive(frame_chunk2.size()));
+
+  std::vector<char> frame_chunk3(serialized_frame.begin() + 10,
+                                 serialized_frame.end());
+  rbuf = bfd.BeginReceive();
+  CheckedMemcpy(rbuf, frame_chunk3);
+  ASSERT_TRUE(bfd.EndReceive(frame_chunk3.size()));
+
+  // Validate both received frames: first the simple one, then the fragmented.
+  std::unique_ptr<Frame> decoded_simple_frame = bfd.PopNextFrame();
+  ASSERT_TRUE(decoded_simple_frame);
+  ASSERT_EQ(static_cast<int32_t>(simple_frame.size() - kHeaderSize),
+            decoded_simple_frame->ByteSize());
+
+  std::unique_ptr<Frame> decoded_frame = bfd.PopNextFrame();
+  ASSERT_TRUE(decoded_frame);
+  ASSERT_TRUE(FrameEq(serialized_frame, *decoded_frame));
+
+  // Nothing else should be decoded and no bytes should be left over.
+  ASSERT_FALSE(bfd.PopNextFrame());
+  ASSERT_EQ(0u, bfd.size());
+}
+
+// Tests the case of an EndReceive(0) while receiving a valid frame in chunks.
+TEST(BufferedFrameDeserializerTest, ZeroSizedReceive) {
+ BufferedFrameDeserializer bfd;
+ std::vector<char> frame = GetSimpleFrame(100);
+ std::vector<char> frame_chunk1(frame.begin(), frame.begin() + 50);
+ std::vector<char> frame_chunk2(frame.begin() + 50, frame.end());
+
+ BufferedFrameDeserializer::ReceiveBuffer rbuf = bfd.BeginReceive();
+ CheckedMemcpy(rbuf, frame_chunk1);
+ ASSERT_TRUE(bfd.EndReceive(frame_chunk1.size()));
+
+ // A zero-sized receive must leave the partially buffered frame intact.
+ rbuf = bfd.BeginReceive();
+ ASSERT_TRUE(bfd.EndReceive(0));
+
+ rbuf = bfd.BeginReceive();
+ CheckedMemcpy(rbuf, frame_chunk2);
+ ASSERT_TRUE(bfd.EndReceive(frame_chunk2.size()));
+
+ // Exactly one frame should be decoded, with no leftover buffer.
+ std::unique_ptr<Frame> decoded_frame = bfd.PopNextFrame();
+ ASSERT_TRUE(decoded_frame);
+ ASSERT_TRUE(FrameEq(frame, *decoded_frame));
+ ASSERT_FALSE(bfd.PopNextFrame());
+ ASSERT_EQ(0u, bfd.size());
+}
+
+// Tests the case of an EndReceive(4) where the header has no payload. The frame
+// should be just skipped and not returned by PopNextFrame().
+TEST(BufferedFrameDeserializerTest, EmptyPayload) {
+ BufferedFrameDeserializer bfd;
+ std::vector<char> frame = GetSimpleFrame(100);
+
+ // A header announcing a zero-byte payload.
+ BufferedFrameDeserializer::ReceiveBuffer rbuf = bfd.BeginReceive();
+ std::vector<char> empty_frame(kHeaderSize, 0);
+ CheckedMemcpy(rbuf, empty_frame);
+ ASSERT_TRUE(bfd.EndReceive(kHeaderSize));
+
+ rbuf = bfd.BeginReceive();
+ CheckedMemcpy(rbuf, frame);
+ ASSERT_TRUE(bfd.EndReceive(frame.size()));
+
+ // |frame| should be properly decoded.
+ std::unique_ptr<Frame> decoded_frame = bfd.PopNextFrame();
+ ASSERT_TRUE(decoded_frame);
+ ASSERT_TRUE(FrameEq(frame, *decoded_frame));
+ ASSERT_FALSE(bfd.PopNextFrame());
+}
+
+// Tests the case where a single recv() returns a batch of one or more whole
+// frames. See case C in the comments for BufferedFrameDeserializer::EndReceive().
+TEST(BufferedFrameDeserializerTest, MultipleFramesInOneReceive) {
+ BufferedFrameDeserializer bfd;
+ // Each inner vector holds the sizes (header included) of one recv() batch.
+ std::vector<std::vector<size_t>> frame_batch_sizes(
+ {{11}, {13, 17, 19}, {23}, {29, 31}});
+
+ for (std::vector<size_t>& batch : frame_batch_sizes) {
+ BufferedFrameDeserializer::ReceiveBuffer rbuf = bfd.BeginReceive();
+ size_t frame_offset_in_batch = 0;
+ for (size_t frame_size : batch) {
+ auto frame = GetSimpleFrame(frame_size);
+ CheckedMemcpy(rbuf, frame, frame_offset_in_batch);
+ frame_offset_in_batch += frame.size();
+ }
+ ASSERT_TRUE(bfd.EndReceive(frame_offset_in_batch));
+ for (size_t expected_size : batch) {
+ auto frame = bfd.PopNextFrame();
+ ASSERT_TRUE(frame);
+ ASSERT_EQ(static_cast<int32_t>(expected_size - kHeaderSize),
+ frame->ByteSize());
+ }
+ ASSERT_FALSE(bfd.PopNextFrame());
+ ASSERT_EQ(0u, bfd.size());
+ }
+}
+
+// A header announcing a ~4 GB frame (far beyond the default capacity) must make
+// EndReceive() return false instead of waiting to buffer the whole frame.
+TEST(BufferedFrameDeserializerTest, RejectVeryLargeFrames) {
+ BufferedFrameDeserializer bfd;
+ BufferedFrameDeserializer::ReceiveBuffer rbuf = bfd.BeginReceive();
+ const uint32_t kBigSize = std::numeric_limits<uint32_t>::max() - 2;
+ memcpy(rbuf.data, base::AssumeLittleEndian(&kBigSize), kHeaderSize);
+ memcpy(rbuf.data + kHeaderSize, "some initial payload", 20);
+ ASSERT_FALSE(bfd.EndReceive(kHeaderSize + 20));
+}
+
+// Tests the extreme case of recv() fragmentation. 50 valid frames of increasing
+// size are received, but each recv() delivers only one byte at a time. Covers
+// cases A and B commented in BufferedFrameDeserializer::EndReceive().
+TEST(BufferedFrameDeserializerTest, HighlyFragmentedFrames) {
+ BufferedFrameDeserializer bfd;
+ for (int i = 1; i <= 50; i++) {
+ std::vector<char> frame = GetSimpleFrame(i * 100);
+ for (size_t off = 0; off < frame.size(); off++) {
+ BufferedFrameDeserializer::ReceiveBuffer rbuf = bfd.BeginReceive();
+ CheckedMemcpy(rbuf, {frame[off]});
+
+ // The frame should be available only when receiving the last byte.
+ ASSERT_TRUE(bfd.EndReceive(1));
+ if (off < frame.size() - 1) {
+ ASSERT_FALSE(bfd.PopNextFrame()) << off << "/" << frame.size();
+ ASSERT_EQ(off + 1, bfd.size());
+ } else {
+ ASSERT_TRUE(bfd.PopNextFrame());
+ }
+ }
+ }
+}
+
+// A bunch of valid frames interleaved with frames that have a valid header
+// but unparsable payload. The expectation is that PopNextFrame() returns
+// nullptr for the unparsable frames but the other frames are decoded properly.
+TEST(BufferedFrameDeserializerTest, CanRecoverAfterUnparsableFrames) {
+ BufferedFrameDeserializer bfd;
+ for (int i = 1; i <= 50; i++) {
+ const size_t size = i * 10;
+ std::vector<char> frame = GetSimpleFrame(size);
+ // Every third frame gets its payload overwritten with 0xFF bytes.
+ const bool unparsable = (i % 3) == 1;
+ if (unparsable)
+ memset(frame.data() + kHeaderSize, 0xFF, size - kHeaderSize);
+
+ BufferedFrameDeserializer::ReceiveBuffer rbuf = bfd.BeginReceive();
+ CheckedMemcpy(rbuf, frame);
+ ASSERT_TRUE(bfd.EndReceive(frame.size()));
+
+ // Exactly one frame should be decoded unless |unparsable|. In any case no
+ // leftover bytes should be left in the buffer.
+ auto decoded_frame = bfd.PopNextFrame();
+ if (unparsable) {
+ ASSERT_FALSE(decoded_frame);
+ } else {
+ ASSERT_TRUE(decoded_frame);
+ ASSERT_EQ(static_cast<int32_t>(size - kHeaderSize),
+ decoded_frame->ByteSize());
+ }
+ ASSERT_EQ(0u, bfd.size());
+ }
+}
+
+// Tests that we can sustain recv()s which constantly max out the capacity.
+// It sets up four frames (all sizes below include the header):
+// |frame1|: small, 1024 bytes.
+// |frame2|: as big as the |kMaxCapacity|. Its recv() is split into two chunks.
+// |frame3|: together with the 2nd part of |frame2| it maxes out capacity again.
+// |frame4|: as big as the |kMaxCapacity|. Received in one recv(), no splits.
+//
+// Which are then recv()'d in a loop in the following way.
+// |------------ max recv capacity ------------|
+// 1. [ frame1 ] [ frame2_chunk1 ..... ]
+// 2. [ ... frame2_chunk2 ]
+// 3. [ frame3 ]
+// 4. [ frame 4 ]
+TEST(BufferedFrameDeserializerTest, FillCapacity) {
+ size_t kMaxCapacity = 1024 * 16;
+ BufferedFrameDeserializer bfd(kMaxCapacity);
+
+ for (int i = 0; i < 3; i++) {
+ std::vector<char> frame1 = GetSimpleFrame(1024);
+ std::vector<char> frame2 = GetSimpleFrame(kMaxCapacity);
+ std::vector<char> frame2_chunk1(
+ frame2.begin(), frame2.begin() + kMaxCapacity - frame1.size());
+ std::vector<char> frame2_chunk2(frame2.begin() + frame2_chunk1.size(),
+ frame2.end());
+ std::vector<char> frame3 =
+ GetSimpleFrame(kMaxCapacity - frame2_chunk2.size());
+ std::vector<char> frame4 = GetSimpleFrame(kMaxCapacity);
+ ASSERT_EQ(kMaxCapacity, frame1.size() + frame2_chunk1.size());
+ ASSERT_EQ(kMaxCapacity, frame2_chunk1.size() + frame2_chunk2.size());
+ ASSERT_EQ(kMaxCapacity, frame2_chunk2.size() + frame3.size());
+ ASSERT_EQ(kMaxCapacity, frame4.size());
+
+ BufferedFrameDeserializer::ReceiveBuffer rbuf = bfd.BeginReceive();
+ CheckedMemcpy(rbuf, frame1);
+ CheckedMemcpy(rbuf, frame2_chunk1, frame1.size());
+ ASSERT_TRUE(bfd.EndReceive(frame1.size() + frame2_chunk1.size()));
+
+ rbuf = bfd.BeginReceive();
+ CheckedMemcpy(rbuf, frame2_chunk2);
+ ASSERT_TRUE(bfd.EndReceive(frame2_chunk2.size()));
+
+ rbuf = bfd.BeginReceive();
+ CheckedMemcpy(rbuf, frame3);
+ ASSERT_TRUE(bfd.EndReceive(frame3.size()));
+
+ rbuf = bfd.BeginReceive();
+ CheckedMemcpy(rbuf, frame4);
+ ASSERT_TRUE(bfd.EndReceive(frame4.size()));
+
+ // All four frames must come out intact and in order.
+ std::unique_ptr<Frame> decoded_frame_1 = bfd.PopNextFrame();
+ ASSERT_TRUE(decoded_frame_1);
+ ASSERT_TRUE(FrameEq(frame1, *decoded_frame_1));
+
+ std::unique_ptr<Frame> decoded_frame_2 = bfd.PopNextFrame();
+ ASSERT_TRUE(decoded_frame_2);
+ ASSERT_TRUE(FrameEq(frame2, *decoded_frame_2));
+
+ std::unique_ptr<Frame> decoded_frame_3 = bfd.PopNextFrame();
+ ASSERT_TRUE(decoded_frame_3);
+ ASSERT_TRUE(FrameEq(frame3, *decoded_frame_3));
+
+ std::unique_ptr<Frame> decoded_frame_4 = bfd.PopNextFrame();
+ ASSERT_TRUE(decoded_frame_4);
+ ASSERT_TRUE(FrameEq(frame4, *decoded_frame_4));
+
+ ASSERT_FALSE(bfd.PopNextFrame());
+ }
+}
+
+} // namespace
+} // namespace ipc
+} // namespace perfetto
diff --git a/src/ipc/client_impl.cc b/src/ipc/client_impl.cc
new file mode 100644
index 0000000..e1c4a7c
--- /dev/null
+++ b/src/ipc/client_impl.cc
@@ -0,0 +1,258 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "src/ipc/client_impl.h"
+
+#include <inttypes.h>
+
+#include "perfetto/base/task_runner.h"
+#include "perfetto/base/utils.h"
+#include "perfetto/ipc/service_descriptor.h"
+#include "perfetto/ipc/service_proxy.h"
+
+// TODO(primiano): Add ThreadChecker everywhere.
+
+// TODO(primiano): Add timeouts.
+
+namespace perfetto {
+namespace ipc {
+
+// static
+// Factory for the public Client interface. The concrete implementation is
+// ClientImpl, whose constructor immediately starts connecting to
+// |socket_name|.
+std::unique_ptr<Client> Client::CreateInstance(const char* socket_name,
+                                               base::TaskRunner* task_runner) {
+  return std::unique_ptr<Client>(new ClientImpl(socket_name, task_runner));
+}
+
+// Verifies the protobuf runtime version and starts the connection attempt to
+// |socket_name|. The outcome is reported back through OnConnect().
+ClientImpl::ClientImpl(const char* socket_name, base::TaskRunner* task_runner)
+    : task_runner_(task_runner), weak_ptr_factory_(this) {
+  GOOGLE_PROTOBUF_VERIFY_VERSION;
+  // |this| is registered as the socket's event listener.
+  sock_ = UnixSocket::Connect(socket_name, /*event_listener=*/this,
+                              task_runner_);
+}
+
+// Reuses the disconnection path so that every bound ServiceProxy gets an
+// OnDisconnect() notification posted on the task runner.
+ClientImpl::~ClientImpl() {
+  // OnDisconnect() ignores its UnixSocket argument, hence nullptr.
+  OnDisconnect(/*sock=*/nullptr);
+}
+
+// Asks the host to bind |service_proxy| to the remote service with the same
+// name. If the socket is not connected yet, the request is queued and replayed
+// by OnConnect(). The outcome is reported through ServiceProxy::OnConnect(),
+// either synchronously (send failure) or when the host's reply arrives.
+void ClientImpl::BindService(base::WeakPtr<ServiceProxy> service_proxy) {
+  if (!service_proxy)
+    return;
+
+  if (!sock_->is_connected()) {
+    // Not a `return emplace_back(...)`: since C++17 emplace_back() returns a
+    // reference, which cannot be returned from a void function.
+    queued_bindings_.emplace_back(std::move(service_proxy));
+    return;
+  }
+
+  const RequestID request_id = ++last_request_id_;
+  Frame frame;
+  frame.set_request_id(request_id);
+  Frame::BindService* req = frame.mutable_msg_bind_service();
+  const char* const service_name = service_proxy->GetDescriptor().service_name;
+  req->set_service_name(service_name);
+  if (!SendFrame(frame)) {
+    PERFETTO_DLOG("BindService(%s) failed", service_name);
+    service_proxy->OnConnect(false /* success */);
+    return;
+  }
+
+  // Remember the request so OnFrameReceived() can route the host's reply back
+  // to |service_proxy|.
+  QueuedRequest qr;
+  qr.type = Frame::kMsgBindService;
+  qr.request_id = request_id;
+  qr.service_proxy = service_proxy;
+  queued_requests_.emplace(request_id, std::move(qr));
+}
+
+// Forgets the binding for |service_id|. Unknown IDs are silently ignored.
+void ClientImpl::UnbindService(ServiceID service_id) {
+  auto binding = service_bindings_.find(service_id);
+  if (binding != service_bindings_.end())
+    service_bindings_.erase(binding);
+}
+
+// Sends an InvokeMethod request for |remote_method_id| on |service_id| with
+// the serialized |method_args|. Returns the newly allocated request ID, or 0
+// if the arguments could not be serialized or the frame could not be sent.
+// On success the request is queued so the reply(ies) can be dispatched back
+// to |service_proxy| by OnFrameReceived().
+RequestID ClientImpl::BeginInvoke(ServiceID service_id,
+                                  const std::string& method_name,
+                                  MethodID remote_method_id,
+                                  const ProtoMessage& method_args,
+                                  base::WeakPtr<ServiceProxy> service_proxy) {
+  RequestID request_id = ++last_request_id_;
+  Frame frame;
+  frame.set_request_id(request_id);
+  Frame::InvokeMethod* req = frame.mutable_msg_invoke_method();
+  req->set_service_id(service_id);
+  req->set_method_id(remote_method_id);
+  // Serialize straight into the frame instead of going through a temporary
+  // string that would then be copied.
+  bool did_serialize = method_args.SerializeToString(req->mutable_args_proto());
+  if (!did_serialize || !SendFrame(frame)) {
+    PERFETTO_DLOG("BeginInvoke() failed while sending the frame");
+    return 0;
+  }
+  QueuedRequest qr;
+  qr.type = Frame::kMsgInvokeMethod;
+  qr.request_id = request_id;
+  qr.method_name = method_name;
+  qr.service_proxy = service_proxy;
+  queued_requests_.emplace(request_id, std::move(qr));
+  return request_id;
+}
+
+// Serializes |frame| (protobuf payload plus size header) and writes it to the
+// socket. Returns the result of UnixSocket::Send(). A failed send is fatal
+// unless the socket is no longer connected.
+bool ClientImpl::SendFrame(const Frame& frame) {
+  const std::string serialized = BufferedFrameDeserializer::Serialize(frame);
+
+  // TODO(primiano): remember that this is doing non-blocking I/O. What if the
+  // socket buffer is full? Maybe we just want to drop this on the floor? Or
+  // maybe throttle the send and PostTask the reply later?
+  const bool sent = sock_->Send(serialized.data(), serialized.size());
+  PERFETTO_CHECK(!sock_->is_connected() || sent);
+  return sent;
+}
+
+// Drains the BindService() calls that were queued before the connection with
+// the host was established. On success they are re-issued. On failure each
+// still-alive proxy is told that binding failed.
+void ClientImpl::OnConnect(UnixSocket*, bool connected) {
+  if (connected) {
+    for (base::WeakPtr<ServiceProxy>& pending : queued_bindings_)
+      BindService(pending);
+  } else {
+    for (base::WeakPtr<ServiceProxy>& pending : queued_bindings_) {
+      if (pending)
+        pending->OnConnect(false /* success */);
+    }
+  }
+  queued_bindings_.clear();
+}
+
+// Notifies every bound ServiceProxy of the disconnection through a posted task
+// and drops all bindings, including those still waiting for the connection
+// (the latter are dropped without any notification). Each task holds its own
+// WeakPtr copy and is a no-op if the proxy is gone by the time it runs.
+void ClientImpl::OnDisconnect(UnixSocket*) {
+  // Iterate by reference: iterating by value copied every map entry.
+  for (const auto& it : service_bindings_) {
+    const base::WeakPtr<ServiceProxy>& service_proxy = it.second;
+    task_runner_->PostTask([service_proxy] {
+      if (service_proxy)
+        service_proxy->OnDisconnect();
+    });
+  }
+  service_bindings_.clear();
+  queued_bindings_.clear();
+}
+
+// Reads everything currently available on the socket into the frame
+// deserializer (stopping when Receive() returns 0), stashes any received file
+// descriptor, and then dispatches every complete frame to OnFrameReceived().
+void ClientImpl::OnDataAvailable(UnixSocket*) {
+  for (;;) {
+    auto buf = frame_deserializer_.BeginReceive();
+    base::ScopedFile received;
+    const size_t num_bytes = sock_->Receive(buf.data, buf.size, &received);
+    if (received) {
+      // Any previously received FD must have been taken via TakeReceivedFD().
+      PERFETTO_DCHECK(!received_fd_);
+      received_fd_ = std::move(received);
+    }
+    if (!frame_deserializer_.EndReceive(num_bytes)) {
+      // The endpoint tried to send a frame that is way too large. Shutting the
+      // socket down in turn will trigger an OnDisconnect().
+      // TODO: check this.
+      sock_->Shutdown();
+      return;
+    }
+    if (num_bytes == 0)
+      break;
+  }
+
+  while (std::unique_ptr<Frame> frame = frame_deserializer_.PopNextFrame())
+    OnFrameReceived(*frame);
+}
+
+// Routes a frame received from the host to the pending request with the same
+// request_id. The request is dequeued here. Streaming invocations are
+// re-queued by OnInvokeMethodReply(). If the host answers with an error, or
+// with a reply that does not match the request type, the request is failed
+// explicitly, so the ServiceProxy is not left waiting forever.
+void ClientImpl::OnFrameReceived(const Frame& frame) {
+  auto queued_requests_it = queued_requests_.find(frame.request_id());
+  if (queued_requests_it == queued_requests_.end()) {
+    PERFETTO_DLOG("OnFrameReceived(): got invalid request_id=%" PRIu64,
+                  static_cast<uint64_t>(frame.request_id()));
+    return;
+  }
+  QueuedRequest req = std::move(queued_requests_it->second);
+  queued_requests_.erase(queued_requests_it);
+
+  if (req.type == Frame::kMsgBindService &&
+      frame.msg_case() == Frame::kMsgBindServiceReply) {
+    return OnBindServiceReply(std::move(req), frame.msg_bind_service_reply());
+  }
+  if (req.type == Frame::kMsgInvokeMethod &&
+      frame.msg_case() == Frame::kMsgInvokeMethodReply) {
+    return OnInvokeMethodReply(std::move(req), frame.msg_invoke_method_reply());
+  }
+
+  if (frame.msg_case() == Frame::kMsgRequestError) {
+    PERFETTO_DLOG("Host error: %s", frame.msg_request_error().error().c_str());
+  } else {
+    PERFETTO_DLOG(
+        "OnFrameReceived() request msg_type=%d, received msg_type=%d in reply "
+        "to "
+        "request_id=%" PRIu64,
+        req.type, frame.msg_case(), static_cast<uint64_t>(frame.request_id()));
+  }
+
+  // The request has already been removed from |queued_requests_| and will
+  // never be answered: fail it so that the caller's callback fires.
+  if (!req.service_proxy)
+    return;
+  if (req.type == Frame::kMsgBindService) {
+    req.service_proxy->OnConnect(false /* success */);
+  } else if (req.type == Frame::kMsgInvokeMethod) {
+    // A null reply is the failure signal, as in OnInvokeMethodReply().
+    req.service_proxy->EndInvoke(req.request_id,
+                                 std::unique_ptr<ProtoMessage>(),
+                                 false /* has_more */);
+  }
+}
+
+// Completes a BindService() request. On success this builds the remote
+// method-name -> method-ID table, initializes the proxy's binding, records it
+// in |service_bindings_| and reports OnConnect(true). The bind fails (with
+// OnConnect(false)) if the host rejected it or if a still-alive proxy is
+// already bound to the same service ID.
+void ClientImpl::OnBindServiceReply(QueuedRequest req,
+                                    const Frame::BindServiceReply& reply) {
+  base::WeakPtr<ServiceProxy>& proxy = req.service_proxy;
+  if (!proxy)
+    return;  // The proxy went away while the request was in flight.
+
+  const char* svc_name = proxy->GetDescriptor().service_name;
+  if (!reply.success()) {
+    PERFETTO_DLOG("BindService(): unknown service_name=\"%s\"", svc_name);
+    proxy->OnConnect(false /* success */);
+    return;
+  }
+
+  const auto service_id = reply.service_id();
+  auto existing = service_bindings_.find(service_id);
+  const bool already_bound =
+      existing != service_bindings_.end() && existing->second.get();
+  if (already_bound) {
+    PERFETTO_DLOG(
+        "BindService(): Trying to bind service \"%s\" but another service "
+        "named \"%s\" is already bound with the same ID.",
+        svc_name, existing->second->GetDescriptor().service_name);
+    proxy->OnConnect(false /* success */);
+    return;
+  }
+
+  // Keep only well-formed entries: non-empty name and a positive ID.
+  std::map<std::string, MethodID> remote_methods;
+  for (const auto& method : reply.methods()) {
+    const bool valid = !method.name().empty() && method.id() > 0;
+    if (!valid) {
+      PERFETTO_DLOG("OnBindServiceReply(): invalid method \"%s\" -> %" PRIu32,
+                    method.name().c_str(), method.id());
+      continue;
+    }
+    remote_methods[method.name()] = method.id();
+  }
+
+  proxy->InitializeBinding(weak_ptr_factory_.GetWeakPtr(), service_id,
+                           std::move(remote_methods));
+  service_bindings_[service_id] = proxy;
+  proxy->OnConnect(true /* success */);
+}
+
+// Completes (one step of) a method invocation. On success the reply payload is
+// decoded with the decoder registered for |req.method_name| in the proxy's
+// descriptor. A null decoded reply (failure or unknown method) is passed to
+// EndInvoke() as-is. Streaming requests (has_more) are put back into
+// |queued_requests_| to receive further replies.
+void ClientImpl::OnInvokeMethodReply(QueuedRequest req,
+                                     const Frame::InvokeMethodReply& reply) {
+  base::WeakPtr<ServiceProxy> service_proxy = req.service_proxy;
+  if (!service_proxy)
+    return;
+
+  std::unique_ptr<ProtoMessage> decoded_reply;
+  if (reply.success()) {
+    // TODO: this could be optimized, stop doing method name string lookups.
+    for (const auto& method : service_proxy->GetDescriptor().methods) {
+      if (req.method_name != method.name)
+        continue;
+      decoded_reply = method.reply_proto_decoder(reply.reply_proto());
+      break;
+    }
+  }
+
+  const RequestID request_id = req.request_id;
+  const bool has_more = reply.has_more();
+  service_proxy->EndInvoke(request_id, std::move(decoded_reply), has_more);
+
+  // Streaming method: further replies will follow for |request_id|.
+  if (has_more)
+    queued_requests_.emplace(request_id, std::move(req));
+}
+
+// Out-of-line defaulted constructor for the private QueuedRequest struct
+// declared in client_impl.h; members use their in-class initializers.
+ClientImpl::QueuedRequest::QueuedRequest() = default;
+
+// Transfers ownership of the last file descriptor received from the host (if
+// any) to the caller.
+base::ScopedFile ClientImpl::TakeReceivedFD() {
+  base::ScopedFile fd(std::move(received_fd_));
+  return fd;
+}
+
+} // namespace ipc
+} // namespace perfetto
diff --git a/src/ipc/client_impl.h b/src/ipc/client_impl.h
new file mode 100644
index 0000000..71934ed
--- /dev/null
+++ b/src/ipc/client_impl.h
@@ -0,0 +1,98 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef SRC_IPC_CLIENT_IMPL_H_
+#define SRC_IPC_CLIENT_IMPL_H_
+
+#include "perfetto/base/scoped_file.h"
+#include "perfetto/base/task_runner.h"
+#include "perfetto/ipc/client.h"
+#include "src/ipc/buffered_frame_deserializer.h"
+#include "src/ipc/unix_socket.h"
+
+#include "src/ipc/wire_protocol.pb.h"
+
+#include <list>
+#include <map>
+#include <memory>
+
+namespace perfetto {
+
+namespace base {
+class TaskRunner;
+} // namespace base
+
+namespace ipc {
+
+class ServiceDescriptor;
+
+// Client-side endpoint of the IPC layer. Connects to the host over a
+// UnixSocket, binds ServiceProxy objects to remote services and dispatches
+// method invocations and their replies.
+class ClientImpl : public Client, public UnixSocket::EventListener {
+ public:
+  ClientImpl(const char* socket_name, base::TaskRunner*);
+  ~ClientImpl() override;
+
+  // Client implementation.
+  void BindService(base::WeakPtr<ServiceProxy>) override;
+  void UnbindService(ServiceID) override;
+  base::ScopedFile TakeReceivedFD() override;
+
+  // UnixSocket::EventListener implementation.
+  void OnConnect(UnixSocket*, bool connected) override;
+  void OnDisconnect(UnixSocket*) override;
+  void OnDataAvailable(UnixSocket*) override;
+
+  // Sends an InvokeMethod request. Returns the request ID, or 0 on failure.
+  RequestID BeginInvoke(ServiceID,
+                        const std::string& method_name,
+                        MethodID remote_method_id,
+                        const ProtoMessage& method_args,
+                        base::WeakPtr<ServiceProxy>);
+
+ private:
+  // A request sent to the host that is still waiting for its reply, keyed by
+  // request ID in |queued_requests_|.
+  struct QueuedRequest {
+    QueuedRequest();
+    int type = 0;  // From Frame::msg_case(), see wire_protocol.proto.
+    RequestID request_id = 0;
+    base::WeakPtr<ServiceProxy> service_proxy;
+
+    // Only for type == kMsgInvokeMethod.
+    std::string method_name;
+  };
+
+  ClientImpl(const ClientImpl&) = delete;
+  ClientImpl& operator=(const ClientImpl&) = delete;
+
+  bool SendFrame(const Frame&);
+  void OnFrameReceived(const Frame&);
+  void OnBindServiceReply(QueuedRequest, const Frame::BindServiceReply&);
+  void OnInvokeMethodReply(QueuedRequest, const Frame::InvokeMethodReply&);
+
+  std::unique_ptr<UnixSocket> sock_;
+  base::TaskRunner* const task_runner_;
+  RequestID last_request_id_ = 0;
+  BufferedFrameDeserializer frame_deserializer_;
+  base::ScopedFile received_fd_;
+  std::map<RequestID, QueuedRequest> queued_requests_;
+  std::map<ServiceID, base::WeakPtr<ServiceProxy>> service_bindings_;
+
+  // Queue of calls to BindService() that happened before the socket connected.
+  std::list<base::WeakPtr<ServiceProxy>> queued_bindings_;
+
+  // Keep this last. Members are destroyed in reverse declaration order, so the
+  // factory goes first and the weak pointers it handed out stop resolving
+  // before the rest of the object is torn down (usual WeakPtrFactory
+  // convention).
+  base::WeakPtrFactory<Client> weak_ptr_factory_;
+};
+
+} // namespace ipc
+} // namespace perfetto
+
+#endif // SRC_IPC_CLIENT_IMPL_H_
diff --git a/src/ipc/client_impl_unittest.cc b/src/ipc/client_impl_unittest.cc
new file mode 100644
index 0000000..769ee47
--- /dev/null
+++ b/src/ipc/client_impl_unittest.cc
@@ -0,0 +1,499 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "src/ipc/client_impl.h"
+
+#include <stdio.h>
+#include <unistd.h>
+
+#include <string>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "perfetto/base/utils.h"
+#include "perfetto/ipc/service_descriptor.h"
+#include "perfetto/ipc/service_proxy.h"
+#include "src/base/test/test_task_runner.h"
+#include "src/ipc/buffered_frame_deserializer.h"
+#include "src/ipc/unix_socket.h"
+
+#include "src/ipc/test/client_unittest_messages.pb.h"
+
+namespace perfetto {
+namespace ipc {
+namespace {
+
+using ::testing::_;
+using ::testing::InSequence;
+using ::testing::Invoke;
+using ::testing::Mock;
+
+constexpr char kSockName[] = "/tmp/perfetto_client_impl_unittest.sock";
+
+// A fake ServiceProxy. This fakes the client-side class that would be
+// auto-generated from .proto-files. It exposes a single method, "FakeMethod1",
+// whose replies are decoded as ReplyProto.
+class FakeProxy : public ServiceProxy {
+ public:
+  FakeProxy(const char* service_name, ServiceProxy::EventListener* el)
+      : ServiceProxy(el), service_name_(service_name) {}
+
+  const ServiceDescriptor& GetDescriptor() override {
+    // The descriptor is filled lazily on first use.
+    if (descriptor_.service_name)
+      return descriptor_;
+    descriptor_.service_name = service_name_;
+    auto reply_decoder = [](const std::string& proto) {
+      std::unique_ptr<ProtoMessage> reply(new ReplyProto());
+      EXPECT_TRUE(reply->ParseFromString(proto));
+      return reply;
+    };
+    descriptor_.methods.push_back({"FakeMethod1", nullptr, reply_decoder});
+    return descriptor_;
+  }
+
+  const char* service_name_;
+  ServiceDescriptor descriptor_;
+};
+
+// gmock ServiceProxy::EventListener passed to every FakeProxy in the tests.
+// Tests set expectations on it to observe connect/disconnect events.
+class MockEventListener : public ServiceProxy::EventListener {
+ public:
+ MOCK_METHOD0(OnConnect, void());
+ MOCK_METHOD0(OnDisconnect, void());
+};
+
+// A fake host implementation. Listens on |kSockName| and replies to IPC
+// methods like a real one.
+class FakeHost : public UnixSocket::EventListener {
+ public:
+ // A method of a FakeService. Tests set gmock expectations on OnInvoke() to
+ // fill in the reply of each invocation.
+ struct FakeMethod {
+ MethodID id;
+ MOCK_METHOD2(OnInvoke,
+ void(const Frame::InvokeMethod&, Frame::InvokeMethodReply*));
+ }; // FakeMethod.
+
+ // A service exposed by the host. Method IDs are assigned sequentially,
+ // starting from 1, in AddFakeMethod() call order.
+ struct FakeService {
+ FakeMethod* AddFakeMethod(const std::string& name) {
+ auto it_and_inserted =
+ methods.emplace(name, std::unique_ptr<FakeMethod>(new FakeMethod()));
+ EXPECT_TRUE(it_and_inserted.second);
+ FakeMethod* method = it_and_inserted.first->second.get();
+ method->id = ++last_method_id;
+ return method;
+ }
+
+ ServiceID id;
+ std::map<std::string, std::unique_ptr<FakeMethod>> methods;
+ MethodID last_method_id = 0;
+ }; // FakeService.
+
+ // Removes any stale socket file and starts listening on |kSockName|.
+ explicit FakeHost(base::TaskRunner* task_runner) {
+ unlink(kSockName);
+ listening_sock = UnixSocket::Listen(kSockName, this, task_runner);
+ EXPECT_TRUE(listening_sock->is_listening());
+ }
+ ~FakeHost() override { unlink(kSockName); }
+
+ // Registers a service. Service IDs are assigned sequentially from 1.
+ FakeService* AddFakeService(const std::string& name) {
+ auto it_and_inserted =
+ services.emplace(name, std::unique_ptr<FakeService>(new FakeService()));
+ EXPECT_TRUE(it_and_inserted.second);
+ FakeService* svc = it_and_inserted.first->second.get();
+ svc->id = ++last_service_id;
+ return svc;
+ }
+
+ // UnixSocket::EventListener implementation.
+ // Only a single client connection is expected.
+ void OnNewIncomingConnection(
+ UnixSocket*,
+ std::unique_ptr<UnixSocket> new_connection) override {
+ ASSERT_FALSE(client_sock);
+ client_sock = std::move(new_connection);
+ }
+
+ // Performs a single Receive() per notification (no drain loop) and handles
+ // every complete frame decoded so far.
+ void OnDataAvailable(UnixSocket* sock) override {
+ if (sock != client_sock.get())
+ return;
+ auto buf = frame_deserializer.BeginReceive();
+ size_t rsize = client_sock->Receive(buf.data, buf.size);
+ EXPECT_TRUE(frame_deserializer.EndReceive(rsize));
+ while (std::unique_ptr<Frame> frame = frame_deserializer.PopNextFrame())
+ OnFrameReceived(*frame);
+ }
+
+ // BindService: replies with the service ID and its method table.
+ // InvokeMethod: calls the matching FakeMethod::OnInvoke() and sends its reply,
+ // repeating for as long as the method sets has_more (streaming).
+ void OnFrameReceived(const Frame& req) {
+ if (req.msg_case() == Frame::kMsgBindService) {
+ auto svc_it = services.find(req.msg_bind_service().service_name());
+ ASSERT_NE(services.end(), svc_it);
+ const FakeService& svc = *svc_it->second.get();
+ Frame reply;
+ reply.set_request_id(req.request_id());
+ reply.mutable_msg_bind_service_reply()->set_success(true);
+ reply.mutable_msg_bind_service_reply()->set_service_id(svc.id);
+ for (const auto& method_it : svc.methods) {
+ auto* method = reply.mutable_msg_bind_service_reply()->add_methods();
+ method->set_name(method_it.first);
+ method->set_id(method_it.second->id);
+ }
+ return Reply(reply);
+ } else if (req.msg_case() == Frame::kMsgInvokeMethod) {
+ // Lookup the service and method.
+ bool has_more = false;
+ do {
+ Frame reply;
+ reply.set_request_id(req.request_id());
+ for (const auto& svc : services) {
+ if (static_cast<int32_t>(svc.second->id) !=
+ req.msg_invoke_method().service_id())
+ continue;
+ for (const auto& method : svc.second->methods) {
+ if (static_cast<int32_t>(method.second->id) !=
+ req.msg_invoke_method().method_id())
+ continue;
+ method.second->OnInvoke(req.msg_invoke_method(),
+ reply.mutable_msg_invoke_method_reply());
+ has_more = reply.mutable_msg_invoke_method_reply()->has_more();
+ }
+ }
+ // If either the method or the service are not found, |success| will be
+ // false by default.
+ Reply(reply);
+ } while (has_more);
+ } else {
+ FAIL() << "Unknown request";
+ }
+ }
+
+ // Sends |frame| to the client, attaching |next_reply_fd| (if not -1) and then
+ // resetting it so that it is sent only once.
+ void Reply(const Frame& frame) {
+ auto buf = BufferedFrameDeserializer::Serialize(frame);
+ ASSERT_TRUE(client_sock->is_connected());
+ EXPECT_TRUE(client_sock->Send(buf.data(), buf.size(), next_reply_fd));
+ next_reply_fd = -1;
+ }
+
+ BufferedFrameDeserializer frame_deserializer;
+ std::unique_ptr<UnixSocket> listening_sock;
+ std::unique_ptr<UnixSocket> client_sock;
+ std::map<std::string, std::unique_ptr<FakeService>> services;
+ ServiceID last_service_id = 0;
+ int next_reply_fd = -1;
+}; // FakeHost.
+
+// Fixture: each test gets a fresh TestTaskRunner, a FakeHost listening on
+// |kSockName| and a Client created for that socket name.
+class ClientImplTest : public ::testing::Test {
+ public:
+ void SetUp() override {
+ task_runner_.reset(new base::TestTaskRunner());
+ host_.reset(new FakeHost(task_runner_.get()));
+ cli_ = Client::CreateInstance(kSockName, task_runner_.get());
+ }
+
+ // Order matters: destroy the client and host first, then flush the tasks
+ // they may have posted before the task runner itself goes away.
+ void TearDown() override {
+ cli_.reset();
+ host_.reset();
+ task_runner_->RunUntilIdle();
+ task_runner_.reset();
+ }
+
+ ::testing::StrictMock<MockEventListener> proxy_events_;
+ std::unique_ptr<base::TestTaskRunner> task_runner_;
+ std::unique_ptr<FakeHost> host_;
+ std::unique_ptr<Client> cli_;
+};
+
+// Binds a proxy to the fake host, checks that invoking a method the host
+// exposes succeeds, and checks that invoking an unknown method is resolved
+// with success() == false.
+TEST_F(ClientImplTest, BindAndInvokeMethod) {
+ auto* host_svc = host_->AddFakeService("FakeSvc");
+ auto* host_method = host_svc->AddFakeMethod("FakeMethod1");
+
+ std::unique_ptr<FakeProxy> proxy(new FakeProxy("FakeSvc", &proxy_events_));
+
+ // Bind |proxy| to the fake host.
+ cli_->BindService(proxy->GetWeakPtr());
+ auto on_connect = task_runner_->CreateCheckpoint("on_connect");
+ EXPECT_CALL(proxy_events_, OnConnect()).WillOnce(Invoke(on_connect));
+ task_runner_->RunUntilCheckpoint("on_connect");
+
+ // Invoke a valid method.
+ EXPECT_CALL(*host_method, OnInvoke(_, _))
+ .WillOnce(Invoke(
+ [](const Frame::InvokeMethod& req, Frame::InvokeMethodReply* reply) {
+ RequestProto req_args;
+ EXPECT_TRUE(req_args.ParseFromString(req.args_proto()));
+ EXPECT_EQ("req_data", req_args.data());
+ ReplyProto reply_args;
+ reply->set_reply_proto(reply_args.SerializeAsString());
+ reply->set_success(true);
+ }));
+
+ RequestProto req;
+ req.set_data("req_data");
+ auto on_invoke_reply = task_runner_->CreateCheckpoint("on_invoke_reply");
+ DeferredBase deferred_reply(
+ [on_invoke_reply](AsyncResult<ProtoMessage> reply) {
+ EXPECT_TRUE(reply.success());
+ on_invoke_reply();
+ });
+ proxy->BeginInvoke("FakeMethod1", req, std::move(deferred_reply));
+ task_runner_->RunUntilCheckpoint("on_invoke_reply");
+
+ // Invoke an invalid method.
+ auto on_invalid_invoke = task_runner_->CreateCheckpoint("on_invalid_invoke");
+ DeferredBase deferred_reply2(
+ [on_invalid_invoke](AsyncResult<ProtoMessage> reply) {
+ EXPECT_FALSE(reply.success());
+ on_invalid_invoke();
+ });
+ RequestProto empty_req;
+ proxy->BeginInvoke("InvalidMethod", empty_req, std::move(deferred_reply2));
+ task_runner_->RunUntilCheckpoint("on_invalid_invoke");
+}
+
+// Like BindAndInvokeMethod, but this time invoke a streaming method that
+// provides > 1 reply per invocation.
+TEST_F(ClientImplTest, BindAndInvokeStreamingMethod) {
+ auto* host_svc = host_->AddFakeService("FakeSvc");
+ auto* host_method = host_svc->AddFakeMethod("FakeMethod1");
+ const int kNumReplies = 3;
+
+ // Create and bind |proxy| to the fake host.
+ std::unique_ptr<FakeProxy> proxy(new FakeProxy("FakeSvc", &proxy_events_));
+ cli_->BindService(proxy->GetWeakPtr());
+ auto on_connect = task_runner_->CreateCheckpoint("on_connect");
+ EXPECT_CALL(proxy_events_, OnConnect()).WillOnce(Invoke(on_connect));
+ task_runner_->RunUntilCheckpoint("on_connect");
+
+ // Invoke a valid method, reply kNumReplies times.
+ // has_more is true on every reply except the last one, which makes FakeHost
+ // keep invoking the method.
+ int replies_left = kNumReplies;
+ EXPECT_CALL(*host_method, OnInvoke(_, _))
+ .Times(kNumReplies)
+ .WillRepeatedly(Invoke([&replies_left](const Frame::InvokeMethod& req,
+ Frame::InvokeMethodReply* reply) {
+ RequestProto req_args;
+ EXPECT_TRUE(req_args.ParseFromString(req.args_proto()));
+ reply->set_reply_proto(ReplyProto().SerializeAsString());
+ reply->set_success(true);
+ reply->set_has_more(--replies_left > 0);
+ }));
+
+ RequestProto req;
+ req.set_data("req_data");
+ auto on_last_reply = task_runner_->CreateCheckpoint("on_last_reply");
+ int replies_seen = 0;
+ // The same deferred callback must receive every reply of the stream.
+ DeferredBase deferred_reply(
+ [on_last_reply, &replies_seen](AsyncResult<ProtoMessage> reply) {
+ EXPECT_TRUE(reply.success());
+ replies_seen++;
+ if (!reply.has_more())
+ on_last_reply();
+ });
+ proxy->BeginInvoke("FakeMethod1", req, std::move(deferred_reply));
+ task_runner_->RunUntilCheckpoint("on_last_reply");
+ ASSERT_EQ(kNumReplies, replies_seen);
+}
+
+// Checks that a file descriptor sent by the host along with a method reply can
+// be retrieved by the client through TakeReceivedFD() and refers to the same
+// file contents.
+TEST_F(ClientImplTest, ReceiveFileDescriptor) {
+  auto* host_svc = host_->AddFakeService("FakeSvc");
+  auto* host_method = host_svc->AddFakeMethod("FakeMethod1");
+
+  // Create and bind |proxy| to the fake host.
+  std::unique_ptr<FakeProxy> proxy(new FakeProxy("FakeSvc", &proxy_events_));
+  cli_->BindService(proxy->GetWeakPtr());
+  auto on_connect = task_runner_->CreateCheckpoint("on_connect");
+  EXPECT_CALL(proxy_events_, OnConnect()).WillOnce(Invoke(on_connect));
+  task_runner_->RunUntilCheckpoint("on_connect");
+
+  FILE* tx_file = tmpfile();  // Automatically unlinked from the filesystem.
+  ASSERT_NE(nullptr, tx_file);
+  static constexpr char kFileContent[] = "shared file";
+  ASSERT_EQ(1u, fwrite(kFileContent, sizeof(kFileContent), 1, tx_file));
+  ASSERT_EQ(0, fflush(tx_file));
+  // FakeHost attaches this FD to its next reply.
+  host_->next_reply_fd = fileno(tx_file);
+
+  // Invoke a valid method, which replies once.
+  EXPECT_CALL(*host_method, OnInvoke(_, _))
+      .WillOnce(Invoke([](const Frame::InvokeMethod&,
+                          Frame::InvokeMethodReply* reply) {
+        reply->set_reply_proto(ReplyProto().SerializeAsString());
+        reply->set_success(true);
+      }));
+
+  RequestProto req;
+  auto on_reply = task_runner_->CreateCheckpoint("on_reply");
+  DeferredBase deferred_reply([on_reply](AsyncResult<ProtoMessage> reply) {
+    EXPECT_TRUE(reply.success());
+    on_reply();
+  });
+  proxy->BeginInvoke("FakeMethod1", req, std::move(deferred_reply));
+  task_runner_->RunUntilCheckpoint("on_reply");
+
+  // The client must still be able to read the file through its own received
+  // descriptor after the sender's copy is closed.
+  fclose(tx_file);
+  base::ScopedFile rx_fd = cli_->TakeReceivedFD();
+  ASSERT_TRUE(rx_fd);
+  char buf[sizeof(kFileContent)] = {};
+  ASSERT_EQ(0, lseek(*rx_fd, 0, SEEK_SET));
+  ASSERT_EQ(static_cast<long>(sizeof(buf)),
+            PERFETTO_EINTR(read(*rx_fd, buf, sizeof(buf))));
+  ASSERT_STREQ(kFileContent, buf);
+}
+
+// Binds three proxies for the same service, one at a time. Only the first bind
+// succeeds (OnConnect). The host returns the same service ID for the others,
+// which the client rejects; the test expects that rejection to surface as
+// OnDisconnect on the listener.
+TEST_F(ClientImplTest, BindSameServiceMultipleTimesShouldFail) {
+ host_->AddFakeService("FakeSvc");
+
+ std::unique_ptr<FakeProxy> proxy[3];
+ for (size_t i = 0; i < base::ArraySize(proxy); i++)
+ proxy[i].reset(new FakeProxy("FakeSvc", &proxy_events_));
+
+ // Bind to the host.
+ for (size_t i = 0; i < base::ArraySize(proxy); i++) {
+ auto checkpoint_name = "on_connect_or_disconnect" + std::to_string(i);
+ auto closure = task_runner_->CreateCheckpoint(checkpoint_name);
+ if (i == 0) {
+ // Only the first call should succeed.
+ EXPECT_CALL(proxy_events_, OnConnect()).WillOnce(Invoke(closure));
+ } else {
+ EXPECT_CALL(proxy_events_, OnDisconnect()).WillOnce(Invoke(closure));
+ }
+ cli_->BindService(proxy[i]->GetWeakPtr());
+ task_runner_->RunUntilCheckpoint(checkpoint_name);
+ }
+}
+
+// BindService() calls made before the socket connects are queued and then
+// replayed once connected. Both proxies must eventually get OnConnect().
+TEST_F(ClientImplTest, BindRequestsAreQueuedIfNotConnected) {
+ host_->AddFakeService("FakeSvc1");
+ host_->AddFakeService("FakeSvc2");
+
+ std::unique_ptr<FakeProxy> proxy1(new FakeProxy("FakeSvc1", &proxy_events_));
+ std::unique_ptr<FakeProxy> proxy2(new FakeProxy("FakeSvc2", &proxy_events_));
+
+ // Bind the services (in opposite order of creation) before running any task.
+ cli_->BindService(proxy2->GetWeakPtr());
+ cli_->BindService(proxy1->GetWeakPtr());
+
+ InSequence seq;
+ auto on_connect1 = task_runner_->CreateCheckpoint("on_connect1");
+ EXPECT_CALL(proxy_events_, OnConnect()).WillOnce(Invoke(on_connect1));
+
+ auto on_connect2 = task_runner_->CreateCheckpoint("on_connect2");
+ EXPECT_CALL(proxy_events_, OnConnect()).WillOnce(Invoke(on_connect2));
+
+ task_runner_->RunUntilCheckpoint("on_connect1");
+ task_runner_->RunUntilCheckpoint("on_connect2");
+}
+
+// The deferred callbacks for both binding a service and invoking a method
+// should be dropped if the ServiceProxy object is destroyed prematurely.
+TEST_F(ClientImplTest, DropCallbacksIfServiceProxyIsDestroyed) {
+  auto* host_svc = host_->AddFakeService("FakeSvc");
+  auto* host_method = host_svc->AddFakeMethod("FakeMethod1");
+
+  std::unique_ptr<FakeProxy> proxy(new FakeProxy("FakeSvc", &proxy_events_));
+
+  // First bind the service but destroy it before ClientImpl manages to run any
+  // tasks. No OnConnect() should be called.
+  cli_->BindService(proxy->GetWeakPtr());
+  proxy.reset();
+  task_runner_->RunUntilIdle();
+  ASSERT_TRUE(Mock::VerifyAndClearExpectations(&proxy_events_));
+
+  // Now bind it successfully, invoke a method but destroy the proxy before
+  // the method reply is dispatched. The DeferredReply should be rejected,
+  // despite the fact that the host gave a successful reply.
+  proxy.reset(new FakeProxy("FakeSvc", &proxy_events_));
+  auto on_connect = task_runner_->CreateCheckpoint("on_connect");
+  EXPECT_CALL(proxy_events_, OnConnect()).WillOnce(Invoke(on_connect));
+  cli_->BindService(proxy->GetWeakPtr());
+  task_runner_->RunUntilCheckpoint("on_connect");
+
+  RequestProto req;
+  auto on_reply_sent = task_runner_->CreateCheckpoint("on_reply_sent");
+  EXPECT_CALL(*host_method, OnInvoke(_, _))
+      .WillOnce(Invoke([on_reply_sent](const Frame::InvokeMethod&,
+                                       Frame::InvokeMethodReply* reply) {
+        reply->set_success(true);
+        on_reply_sent();
+      }));
+
+  auto on_reject = task_runner_->CreateCheckpoint("on_reject");
+  DeferredBase deferred_reply([on_reject](AsyncResult<ProtoMessage> res) {
+    ASSERT_FALSE(res.success());
+    on_reject();
+  });
+  proxy->BeginInvoke("FakeMethod1", req, std::move(deferred_reply));
+  proxy.reset();
+  task_runner_->RunUntilCheckpoint("on_reject");
+  task_runner_->RunUntilCheckpoint("on_reply_sent");
+}
+
+// If the Client object is destroyed before the ServiceProxy, the ServiceProxy
+// should see a Disconnect() call and any pending callback should be rejected.
+TEST_F(ClientImplTest, ClientDestroyedBeforeProxy) {
+ auto* host_svc = host_->AddFakeService("FakeSvc");
+ host_svc->AddFakeMethod("FakeMethod1");
+
+ std::unique_ptr<FakeProxy> proxy(new FakeProxy("FakeSvc", &proxy_events_));
+ auto on_connect = task_runner_->CreateCheckpoint("on_connect");
+ EXPECT_CALL(proxy_events_, OnConnect()).WillOnce(Invoke(on_connect));
+ cli_->BindService(proxy->GetWeakPtr());
+ task_runner_->RunUntilCheckpoint("on_connect");
+
+ auto on_reject = task_runner_->CreateCheckpoint("on_reject");
+ DeferredBase deferred_reply([on_reject](AsyncResult<ProtoMessage> res) {
+ ASSERT_FALSE(res.success());
+ on_reject();
+ });
+ RequestProto req;
+ proxy->BeginInvoke("FakeMethod1", req, std::move(deferred_reply));
+ // Destroying the client posts the OnDisconnect() notification for |proxy|.
+ EXPECT_CALL(proxy_events_, OnDisconnect());
+ cli_.reset();
+ task_runner_->RunUntilCheckpoint("on_reject");
+}
+
+// Test that OnDisconnect() is invoked if the host is not reachable.
+// The host is destroyed before binding, so the queued BindService() request is
+// failed when the connection attempt fails.
+TEST_F(ClientImplTest, HostNotReachable) {
+ host_.reset();
+
+ std::unique_ptr<FakeProxy> proxy(new FakeProxy("FakeSvc", &proxy_events_));
+
+ auto on_disconnect = task_runner_->CreateCheckpoint("on_disconnect");
+ EXPECT_CALL(proxy_events_, OnDisconnect()).WillOnce(Invoke(on_disconnect));
+ cli_->BindService(proxy->GetWeakPtr());
+ task_runner_->RunUntilCheckpoint("on_disconnect");
+}
+
+// Test that OnDisconnect() is invoked if the host shuts down prematurely.
+// The proxy is bound successfully first; destroying the host then closes the
+// connection.
+TEST_F(ClientImplTest, HostDisconnection) {
+ host_->AddFakeService("FakeSvc");
+
+ std::unique_ptr<FakeProxy> proxy(new FakeProxy("FakeSvc", &proxy_events_));
+
+ // Bind |proxy| to the fake host.
+ cli_->BindService(proxy->GetWeakPtr());
+ auto on_connect = task_runner_->CreateCheckpoint("on_connect");
+ EXPECT_CALL(proxy_events_, OnConnect()).WillOnce(Invoke(on_connect));
+ task_runner_->RunUntilCheckpoint("on_connect");
+
+ auto on_disconnect = task_runner_->CreateCheckpoint("on_disconnect");
+ EXPECT_CALL(proxy_events_, OnDisconnect()).WillOnce(Invoke(on_disconnect));
+ host_.reset();
+ task_runner_->RunUntilCheckpoint("on_disconnect");
+}
+
+// TODO(primiano): add the tests below.
+// TEST(ClientImplTest, UnparsableReply) {}
+
+} // namespace
+} // namespace ipc
+} // namespace perfetto
diff --git a/src/ipc/deferred.cc b/src/ipc/deferred.cc
new file mode 100644
index 0000000..bed3cc2
--- /dev/null
+++ b/src/ipc/deferred.cc
@@ -0,0 +1,79 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "perfetto/ipc/deferred.h"
+
+#include "google/protobuf/message_lite.h"
+#include "perfetto/base/logging.h"
+
+namespace perfetto {
+namespace ipc {
+
+// Creates a Deferred already bound to |callback| (which may be empty).
+DeferredBase::DeferredBase(
+    std::function<void(AsyncResult<ProtoMessage>)> callback)
+    : callback_(std::move(callback)) {}
+
+// A Deferred that is destroyed while still bound rejects its callback, so the
+// callback always runs at least once.
+DeferredBase::~DeferredBase() {
+  if (!callback_)
+    return;
+  Reject();
+}
+
+// Move constructor. Not "= default": moving a std::function does not
+// necessarily leave the source empty, and a leftover copy of the bind state in
+// |other| is undesirable. Move() clears |other| explicitly.
+DeferredBase::DeferredBase(DeferredBase&& other) noexcept : callback_() {
+  Move(other);
+}
+
+// Move assignment. A still-pending callback of |*this| is rejected before it
+// is replaced by |other|'s; |other| is left unbound.
+DeferredBase& DeferredBase::operator=(DeferredBase&& other) {
+  // Self-move-assignment must be a no-op. Without this guard the pending
+  // callback would be rejected and then lost.
+  if (this == &other)
+    return *this;
+  if (callback_)
+    Reject();
+  Move(other);
+  return *this;
+}
+
+// Takes over |other|'s callback. Swapping with an empty function guarantees
+// that |other| ends up unbound, with no leftover copy of the bind state.
+void DeferredBase::Move(DeferredBase& other) {
+  std::function<void(AsyncResult<ProtoMessage>)> taken;
+  taken.swap(other.callback_);
+  callback_ = std::move(taken);
+}
+
+// (Re)binds the Deferred to |callback|, replacing any previous callback
+// without invoking it.
+void DeferredBase::Bind(
+    std::function<void(AsyncResult<ProtoMessage>)> callback) {
+  callback_.swap(callback);
+}
+
+// True while a callback is attached, i.e. until the final (non-has_more)
+// Resolve()/Reject().
+bool DeferredBase::IsBound() const {
+  return static_cast<bool>(callback_);
+}
+
+// Invokes the bound callback with |async_result|. Unless the result carries
+// has_more, the callback is dropped afterwards and the Deferred becomes
+// unbound. Resolving an unbound Deferred is a programming error: it trips a
+// DCHECK and is otherwise ignored.
+// NOTE(review): callback_ is reset after it returns, so a callback that
+// re-binds or destroys this Deferred from within itself would be affected.
+// Confirm callers never do that.
+void DeferredBase::Resolve(AsyncResult<ProtoMessage> async_result) {
+ if (!callback_) {
+ PERFETTO_DCHECK(false);
+ return;
+ }
+ // Read has_more before |async_result| is moved into the callback.
+ bool has_more = async_result.has_more();
+ callback_(std::move(async_result));
+ if (!has_more)
+ callback_ = nullptr;
+}
+
+// Signals failure to |callback_| by resolving with a default-constructed
+// AsyncResult, i.e. one with a nullptr |msg_|.
+void DeferredBase::Reject() {
+  AsyncResult<ProtoMessage> failure{};
+  Resolve(std::move(failure));
+}
+
+} // namespace ipc
+} // namespace perfetto
diff --git a/src/ipc/deferred_unittest.cc b/src/ipc/deferred_unittest.cc
new file mode 100644
index 0000000..076cce6
--- /dev/null
+++ b/src/ipc/deferred_unittest.cc
@@ -0,0 +1,281 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "perfetto/ipc/deferred.h"
+
+#include "gtest/gtest.h"
+#include "perfetto/base/logging.h"
+
+#include "src/ipc/test/deferred_unittest_messages.pb.h"
+
+namespace perfetto {
+namespace ipc {
+namespace {
+
+// EXPECT_DCHECK(x): when DCHECKs are enabled, expects evaluating |x| to hit a
+// DCHECK (death test). Otherwise |x| is simply evaluated. Call sites supply
+// the trailing semicolon, so the macro must not add its own.
+#if PERFETTO_DCHECK_IS_ON()
+#define EXPECT_DCHECK(x) EXPECT_DEATH_IF_SUPPORTED((x), ".*")
+#else
+#define EXPECT_DCHECK(x) x
+#endif
+
+// Resolve() delivers message, fd and payload to the bound callback exactly
+// once.
+TEST(DeferredTest, BindAndResolve) {
+  Deferred<TestMessage> deferred;
+  std::shared_ptr<int> num_callbacks(new int{0});
+  deferred.Bind([num_callbacks](AsyncResult<TestMessage> msg) {
+    ASSERT_TRUE(msg.success());
+    ASSERT_TRUE(msg);
+    ASSERT_EQ(42, msg->num());
+    ASSERT_EQ(13, msg.fd());
+    ASSERT_EQ("foo", msg->str());
+    (*num_callbacks)++;
+  });
+
+  AsyncResult<TestMessage> res = AsyncResult<TestMessage>::Create();
+  res->set_num(42);
+  res.set_fd(13);
+  (*res).set_str("foo");
+  deferred.Resolve(std::move(res));
+
+  // A second call to Resolve() or Reject() shouldn't have any effect because
+  // we didn't set has_more. The callback has already been released, so both
+  // calls hit a DCHECK.
+  EXPECT_DCHECK(deferred.Resolve(std::move(res)));
+  EXPECT_DCHECK(deferred.Reject());
+
+  ASSERT_EQ(1, *num_callbacks);
+}
+
+// A Reject() delivers a failed result to the callback exactly once: no
+// message (nullptr), !success() and fd == -1. Any later Resolve()/Reject()
+// is a DCHECK failure.
+TEST(DeferredTest, BindAndFail) {
+ Deferred<TestMessage> deferred;
+ std::shared_ptr<int> num_callbacks(new int{0});
+ deferred.Bind([num_callbacks](AsyncResult<TestMessage> msg) {
+ ASSERT_EQ(-1, msg.fd());
+ ASSERT_FALSE(msg.success());
+ ASSERT_FALSE(msg);
+ ASSERT_EQ(nullptr, &*msg);
+ (*num_callbacks)++;
+ });
+
+ // |res| (with fd 42) is never delivered: Reject() below releases the
+ // callback first, so the Resolve() that follows only trips the DCHECK.
+ AsyncResult<TestMessage> res = AsyncResult<TestMessage>::Create();
+ res.set_fd(42);
+ deferred.Reject();
+ EXPECT_DCHECK(deferred.Resolve(std::move(res)));
+ EXPECT_DCHECK(deferred.Reject());
+ ASSERT_EQ(1, *num_callbacks);
+}
+
+// RAII: destroying a Deferred that is still bound rejects it, so the callback
+// receives exactly one failed result.
+TEST(DeferredTest, AutoRejectIfOutOfScope) {
+  auto calls = std::make_shared<int>(0);
+  {
+    Deferred<TestMessage> scoped;
+    scoped.Bind([calls](AsyncResult<TestMessage> result) {
+      ASSERT_FALSE(result.success());
+      ++*calls;
+    });
+  }  // |scoped| is destroyed here.
+  ASSERT_EQ(1, *calls);
+}
+
+// Binds two callbacks one after the other and tests that the bind state of the
+// first callback is released.
+TEST(DeferredTest, BindTwiceDoesNotHoldBindState) {
+  // Use shared_ptr's use_count() to infer the bind state of the callback.
+  std::shared_ptr<int> num_callbacks(new int{0});
+  Deferred<TestMessage> deferred;
+  deferred.Bind(
+      [num_callbacks](AsyncResult<TestMessage>) { (*num_callbacks)++; });
+
+  // At this point both the shared_ptr above and the callback in |deferred| are
+  // refcounting the bind state.
+  ASSERT_GE(num_callbacks.use_count(), 2);
+
+  // Re-binding the callback should release the bind state, without invoking the
+  // old callback.
+  deferred.Bind([](AsyncResult<TestMessage>) {});
+  ASSERT_EQ(1, num_callbacks.use_count());
+  ASSERT_EQ(0, *num_callbacks);
+
+  // Test that the new callback is invoked after re-binding.
+  deferred.Bind([num_callbacks](AsyncResult<TestMessage> msg) {
+    ASSERT_TRUE(msg.success());
+    ASSERT_EQ(4242, msg->num());
+    (*num_callbacks)++;
+  });
+  AsyncResult<TestMessage> res = AsyncResult<TestMessage>::Create();
+  res->set_num(4242);
+  deferred.Resolve(std::move(res));
+  ASSERT_EQ(1, *num_callbacks);
+  ASSERT_EQ(1, num_callbacks.use_count());
+}
+
+// Exercises move construction/assignment of both Deferred and AsyncResult and
+// checks that the callback fires once per valid resolution, and never for a
+// moved-from (unbound) Deferred.
+TEST(DeferredTest, MoveOperators) {
+ Deferred<TestMessage> deferred;
+ std::shared_ptr<int> num_callbacks(new int{0});
+ std::function<void(AsyncResult<TestMessage>)> callback =
+ [num_callbacks](AsyncResult<TestMessage> msg) {
+ ASSERT_TRUE(msg.success());
+ ASSERT_GE(msg->num(), 42);
+ ASSERT_LE(msg->num(), 43);
+ ASSERT_EQ(msg->num() * 10, msg.fd());
+ ASSERT_EQ(std::to_string(msg->num()), msg->str());
+ (*num_callbacks)++;
+ };
+ deferred.Bind(callback);
+
+ // Do a bit of std::move() dance with both the Deferred and the AsyncResult.
+ AsyncResult<TestMessage> res = AsyncResult<TestMessage>::Create();
+ res.set_fd(420);
+ res->set_num(42);
+ AsyncResult<TestMessage> res_moved(std::move(res));
+ res = std::move(res_moved);
+ res->set_str("42");
+ res_moved = std::move(res);
+
+ // Each move-assignment below targets an already moved-from (unbound)
+ // Deferred, so none of them rejects the callback.
+ Deferred<TestMessage> deferred_moved(std::move(deferred));
+ deferred = std::move(deferred_moved);
+ deferred_moved = std::move(deferred);
+
+ EXPECT_DCHECK(deferred.Reject()); // |deferred| has been cleared.
+ ASSERT_EQ(0, *num_callbacks);
+
+ deferred_moved.Resolve(std::move(res_moved)); // This, instead, should fire.
+ ASSERT_EQ(1, *num_callbacks);
+
+ // |deferred| and |res| have lost their state but should remain reusable.
+ deferred.Bind(callback);
+ res = AsyncResult<TestMessage>::Create();
+ res.set_fd(430);
+ res->set_num(43);
+ res->set_str("43");
+ deferred.Resolve(std::move(res));
+ ASSERT_EQ(2, *num_callbacks);
+
+ // Finally re-bind |deferred|, move it to a new scoped Deferred and verify
+ // that the moved-into object still auto-rejects the callback when destroyed.
+ deferred.Bind([num_callbacks](AsyncResult<TestMessage> msg) {
+ ASSERT_FALSE(msg.success());
+ (*num_callbacks)++;
+ });
+ { Deferred<TestMessage> scoped_deferred(std::move(deferred)); }
+ ASSERT_EQ(3, *num_callbacks);
+ // Drop the test's own copy of the callback. After that, nothing should still
+ // reference |num_callbacks|: no Deferred retained a copy of the bind state.
+ callback = nullptr;
+ ASSERT_EQ(1, num_callbacks.use_count());
+}
+
+// Covers the case of a streaming reply, where the deferred keeps being resolved
+// for as long as has_more == true. The final result (has_more == false)
+// releases the callback.
+TEST(DeferredTest, StreamingReply) {
+  Deferred<TestMessage> deferred;
+  std::shared_ptr<int> num_callbacks(new int{0});
+  std::function<void(AsyncResult<TestMessage>)> callback =
+      [num_callbacks](AsyncResult<TestMessage> msg) {
+        ASSERT_TRUE(msg.success());
+        ASSERT_EQ(*num_callbacks == 0 ? 13 : -1, msg.fd());
+        ASSERT_EQ(*num_callbacks, msg->num());
+        ASSERT_EQ(std::to_string(*num_callbacks), msg->str());
+        ASSERT_EQ(msg->num() < 3, msg.has_more());
+        (*num_callbacks)++;
+      };
+  deferred.Bind(callback);
+
+  for (int i = 0; i < 3; i++) {
+    AsyncResult<TestMessage> res = AsyncResult<TestMessage>::Create();
+    res.set_fd(i == 0 ? 13 : -1);
+    res->set_num(i);
+    res->set_str(std::to_string(i));
+    res.set_has_more(true);
+    AsyncResult<TestMessage> res_moved(std::move(res));
+    deferred.Resolve(std::move(res_moved));
+  }
+
+  // The streaming binding must survive a move of the Deferred.
+  Deferred<TestMessage> deferred_moved(std::move(deferred));
+  AsyncResult<TestMessage> res = AsyncResult<TestMessage>::Create();
+  res->set_num(3);
+  res->set_str(std::to_string(3));
+  res.set_has_more(false);
+  deferred_moved.Resolve(std::move(res));
+  ASSERT_EQ(4, *num_callbacks);
+
+  // The has_more == false result released the callback.
+  EXPECT_DCHECK(deferred_moved.Reject());
+  ASSERT_EQ(4, *num_callbacks);
+  callback = nullptr;
+  ASSERT_EQ(1, num_callbacks.use_count());
+}
+
+// Similar to the above, but checks that a Deferred that was only ever resolved
+// with has_more == true (i.e. never with a final has_more == false result)
+// is automatically rejected once destroyed.
+TEST(DeferredTest, StreamingReplyIsRejectedOutOfScope) {
+  std::shared_ptr<int> num_callbacks(new int{0});
+
+  {
+    Deferred<TestMessage> deferred;
+    deferred.Bind([num_callbacks](AsyncResult<TestMessage> msg) {
+      ASSERT_EQ((*num_callbacks) < 3, msg.success());
+      ASSERT_EQ(msg.success(), msg.has_more());
+      (*num_callbacks)++;
+    });
+
+    for (int i = 0; i < 3; i++) {
+      AsyncResult<TestMessage> res = AsyncResult<TestMessage>::Create();
+      res.set_has_more(true);
+      deferred.Resolve(std::move(res));
+    }
+
+    // |deferred_moved| going out of scope should cause a Reject().
+    { Deferred<TestMessage> deferred_moved = std::move(deferred); }
+    ASSERT_EQ(4, *num_callbacks);
+  }
+
+  // |deferred| going out of scope should do nothing, it has been std::move()'d.
+  ASSERT_EQ(4, *num_callbacks);
+  ASSERT_EQ(1, num_callbacks.use_count());
+}
+
+// Tests that a Deferred<Specialized> still behaves sanely after it has been
+// moved into a DeferredBase: the base can be resolved with a ProtoMessage-typed
+// result and the TestMessage-typed callback still sees the payload.
+TEST(DeferredTest, MoveAsBase) {
+ Deferred<TestMessage> deferred;
+ std::shared_ptr<int> num_callbacks(new int{0});
+ deferred.Bind([num_callbacks](AsyncResult<TestMessage> msg) {
+ ASSERT_TRUE(msg.success());
+ ASSERT_EQ(13, msg.fd());
+ ASSERT_EQ(42, msg->num());
+ ASSERT_EQ("foo", msg->str());
+ (*num_callbacks)++;
+ });
+
+ // The move transfers the callback and leaves |deferred| unbound.
+ DeferredBase deferred_base(std::move(deferred));
+ ASSERT_FALSE(deferred.IsBound());
+ ASSERT_TRUE(deferred_base.IsBound());
+
+ std::unique_ptr<TestMessage> msg(new TestMessage());
+ msg->set_num(42);
+ msg->set_str("foo");
+
+ AsyncResult<ProtoMessage> async_result_base(std::move(msg));
+ async_result_base.set_fd(13);
+ deferred_base.Resolve(std::move(async_result_base));
+
+ // has_more was not set, so the callback is gone and further calls DCHECK.
+ EXPECT_DCHECK(deferred_base.Resolve(std::move(async_result_base)));
+ EXPECT_DCHECK(deferred_base.Reject());
+
+ ASSERT_EQ(1, *num_callbacks);
+}
+
+} // namespace
+} // namespace ipc
+} // namespace perfetto
diff --git a/src/ipc/host_impl.cc b/src/ipc/host_impl.cc
new file mode 100644
index 0000000..7707825
--- /dev/null
+++ b/src/ipc/host_impl.cc
@@ -0,0 +1,256 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "src/ipc/host_impl.h"
+
+#include <inttypes.h>
+
+#include <algorithm>
+#include <utility>
+
+#include "perfetto/base/task_runner.h"
+#include "perfetto/base/utils.h"
+#include "perfetto/ipc/service.h"
+#include "perfetto/ipc/service_descriptor.h"
+
+#include "src/ipc/wire_protocol.pb.h"
+
+// TODO(primiano): put limits on #connections/uid and req. queue (b/69093705).
+
+namespace perfetto {
+namespace ipc {
+
+// static
+// Creates a HostImpl listening on |socket_name|. Returns nullptr if the
+// resulting socket is not in the listening state.
+std::unique_ptr<Host> Host::CreateInstance(const char* socket_name,
+                                           base::TaskRunner* task_runner) {
+  std::unique_ptr<HostImpl> host(new HostImpl(socket_name, task_runner));
+  const bool listening = host->sock()->is_listening();
+  if (listening)
+    return std::move(host);
+  return nullptr;
+}
+
+// Starts listening on |socket_name| immediately. Host::CreateInstance() checks
+// sock()->is_listening() afterwards to detect failure.
+HostImpl::HostImpl(const char* socket_name, base::TaskRunner* task_runner)
+ : task_runner_(task_runner), weak_ptr_factory_(this) {
+ GOOGLE_PROTOBUF_VERIFY_VERSION;
+ PERFETTO_DCHECK_THREAD(thread_checker_);
+ // |this| is registered as the EventListener of the listening socket.
+ sock_ = UnixSocket::Listen(socket_name, this, task_runner_);
+}
+
+// Members are destroyed in reverse declaration order (see host_impl.h).
+HostImpl::~HostImpl() = default;
+
+// Registers |service| under the name reported by its descriptor and assigns
+// it a fresh ServiceID. If a service with the same name is already exposed,
+// returns false and |service| is destroyed.
+bool HostImpl::ExposeService(std::unique_ptr<Service> service) {
+  PERFETTO_DCHECK_THREAD(thread_checker_);
+  const std::string& name = service->GetDescriptor().service_name;
+  if (GetServiceByName(name) != nullptr) {
+    PERFETTO_DLOG("Duplicate ExposeService(): %s", name.c_str());
+    return false;
+  }
+  const ServiceID new_id = ++last_service_id_;
+  services_.emplace(new_id, ExposedService(new_id, name, std::move(service)));
+  return true;
+}
+
+// Takes ownership of a freshly accepted connection. It is indexed both by a
+// new ClientID (owning, in |clients_|) and by its socket pointer
+// (non-owning, in |clients_by_socket_|).
+void HostImpl::OnNewIncomingConnection(UnixSocket*,
+                                       std::unique_ptr<UnixSocket> new_conn) {
+  PERFETTO_DCHECK_THREAD(thread_checker_);
+  const ClientID id = ++last_client_id_;
+  std::unique_ptr<ClientConnection> conn(new ClientConnection());
+  conn->id = id;
+  conn->sock = std::move(new_conn);
+  clients_by_socket_[conn->sock.get()] = conn.get();
+  clients_[id] = std::move(conn);
+}
+
+// Reads from |sock| into the client's frame deserializer until Receive()
+// returns 0, then dispatches every complete frame. If EndReceive() reports
+// failure, the client is disconnected instead.
+void HostImpl::OnDataAvailable(UnixSocket* sock) {
+  PERFETTO_DCHECK_THREAD(thread_checker_);
+  auto client_it = clients_by_socket_.find(sock);
+  if (client_it == clients_by_socket_.end())
+    return;
+  ClientConnection* const client = client_it->second;
+  BufferedFrameDeserializer& deserializer = client->frame_deserializer;
+
+  for (;;) {
+    auto rx_buf = deserializer.BeginReceive();
+    const size_t received = client->sock->Receive(rx_buf.data, rx_buf.size);
+    if (!deserializer.EndReceive(received)) {
+      OnDisconnect(client->sock.get());
+      return;
+    }
+    if (received == 0)
+      break;
+  }
+
+  while (std::unique_ptr<Frame> frame = deserializer.PopNextFrame())
+    OnReceivedFrame(client, *frame);
+}
+
+// Dispatches a request frame by type. Any other frame type gets a
+// |msg_request_error| reply carrying the same request id.
+void HostImpl::OnReceivedFrame(ClientConnection* client,
+                               const Frame& req_frame) {
+  const auto msg_case = req_frame.msg_case();
+  if (msg_case == Frame::kMsgBindService)
+    return OnBindService(client, req_frame);
+  if (msg_case == Frame::kMsgInvokeMethod)
+    return OnInvokeMethod(client, req_frame);
+
+  PERFETTO_DLOG("Received invalid RPC frame %u from client %" PRIu64,
+                msg_case, client->id);
+  Frame error_frame;
+  error_frame.set_request_id(req_frame.request_id());
+  error_frame.mutable_msg_request_error()->set_error("unknown request");
+  SendFrame(client, error_frame);
+}
+
+// Replies to a BindService request. For a known service, the reply carries
+// |success|, the service id, and the methods in descriptor order. An unknown
+// name yields a BindServiceReply without |success| set.
+void HostImpl::OnBindService(ClientConnection* client, const Frame& req_frame) {
+  const Frame::BindService& req = req_frame.msg_bind_service();
+  Frame reply_frame;
+  reply_frame.set_request_id(req_frame.request_id());
+  // Selects the bind_service_reply payload even when the lookup fails.
+  auto* reply = reply_frame.mutable_msg_bind_service_reply();
+  const ExposedService* service = GetServiceByName(req.service_name());
+  if (!service)
+    return SendFrame(client, reply_frame);
+
+  reply->set_success(true);
+  reply->set_service_id(service->id);
+  const auto& methods = service->instance->GetDescriptor().methods;
+  // Method ids are 1-based: id N refers to methods[N - 1].
+  for (size_t i = 0; i < methods.size(); ++i) {
+    Frame::BindServiceReply::MethodInfo* info = reply->add_methods();
+    info->set_name(methods[i].name);
+    info->set_id(static_cast<uint32_t>(i + 1));
+  }
+  SendFrame(client, reply_frame);
+}
+
+// Handles an InvokeMethod request. It validates the service and 1-based
+// method ids, decodes the arguments, and calls the method's invoker with a
+// Deferred whose resolution is routed back via ReplyToMethodInvocation().
+// Every early exit replies immediately with success == false.
+void HostImpl::OnInvokeMethod(ClientConnection* client,
+ const Frame& req_frame) {
+ const Frame::InvokeMethod& req = req_frame.msg_invoke_method();
+ Frame reply_frame;
+ RequestID request_id = req_frame.request_id();
+ reply_frame.set_request_id(request_id);
+ reply_frame.mutable_msg_invoke_method_reply()->set_success(false);
+ auto svc_it = services_.find(req.service_id());
+ if (svc_it == services_.end())
+ return SendFrame(client, reply_frame); // |success| == false by default.
+
+ Service* service = svc_it->second.instance.get();
+ const ServiceDescriptor& svc = service->GetDescriptor();
+ const auto& methods = svc.methods;
+ // Method ids start at 1 (see OnBindService()), hence the range check.
+ if (req.method_id() <= 0 ||
+ static_cast<uint32_t>(req.method_id()) > methods.size())
+ return SendFrame(client, reply_frame);
+
+ const ServiceDescriptor::Method& method = methods[req.method_id() - 1];
+ std::unique_ptr<ProtoMessage> decoded_req_args(
+ method.request_proto_decoder(req.args_proto()));
+ if (!decoded_req_args)
+ return SendFrame(client, reply_frame);
+
+ // The reply may be resolved asynchronously, after |client| or even this
+ // HostImpl is gone. So the callback captures ids (not |client|) and a
+ // WeakPtr, and ReplyToMethodInvocation() re-looks up the client.
+ Deferred<ProtoMessage> deferred_reply;
+ base::WeakPtr<HostImpl> host_weak_ptr = weak_ptr_factory_.GetWeakPtr();
+ ClientID client_id = client->id;
+ deferred_reply.Bind(
+ [host_weak_ptr, client_id, request_id](AsyncResult<ProtoMessage> reply) {
+ if (!host_weak_ptr)
+ return; // The reply came too late, the HostImpl has gone.
+ host_weak_ptr->ReplyToMethodInvocation(client_id, request_id,
+ std::move(reply));
+ });
+
+ // The caller's identity is exposed to the service only for the duration of
+ // the synchronous invoker call.
+ service->client_info_ = ClientInfo(client->id, client->sock->peer_uid());
+ method.invoker(service, *decoded_req_args, std::move(deferred_reply));
+ service->client_info_ = ClientInfo();
+}
+
+// Serializes a (possibly streamed) method result into an InvokeMethodReply
+// frame and sends it, along with reply.fd(), to |client_id|. |success| is set
+// only if the result succeeded and serialized. Replies for clients that have
+// already disconnected are dropped.
+void HostImpl::ReplyToMethodInvocation(ClientID client_id,
+                                       RequestID request_id,
+                                       AsyncResult<ProtoMessage> reply) {
+  const auto it = clients_.find(client_id);
+  if (it == clients_.end())
+    return;
+
+  Frame reply_frame;
+  reply_frame.set_request_id(request_id);
+  auto* invoke_reply = reply_frame.mutable_msg_invoke_method_reply();
+  invoke_reply->set_has_more(reply.has_more());
+
+  std::string serialized;
+  const bool ok = reply.success() && reply->SerializeToString(&serialized);
+  if (ok) {
+    invoke_reply->set_reply_proto(serialized);
+    invoke_reply->set_success(true);
+  }
+  SendFrame(it->second.get(), reply_frame, reply.fd());
+}
+
+// static
+// Serializes |frame| and writes it to the client's socket, passing |fd|
+// alongside it (the declaration defaults |fd| to -1).
+void HostImpl::SendFrame(ClientConnection* client, const Frame& frame, int fd) {
+  const std::string wire = BufferedFrameDeserializer::Serialize(frame);
+
+  // TODO(primiano): remember that this is doing non-blocking I/O. What if the
+  // socket buffer is full? Maybe we just want to drop this on the floor? Or
+  // maybe throttle the send and PostTask the reply later?
+  // NOTE(review): a failed Send() on a still-connected socket is a CHECK
+  // failure, so a peer that stops reading could presumably abort the host.
+  // Confirm the intended policy.
+  const bool sent = client->sock->Send(wire.data(), wire.size(), fd);
+  if (client->sock->is_connected()) {
+    PERFETTO_CHECK(sent);
+  }
+}
+
+// Forgets the client owning |sock|, which destroys its ClientConnection. It
+// then calls OnClientDisconnected() on every exposed service, with
+// |client_info_| temporarily describing the departed client.
+void HostImpl::OnDisconnect(UnixSocket* sock) {
+  PERFETTO_DCHECK_THREAD(thread_checker_);
+  const auto by_sock_it = clients_by_socket_.find(sock);
+  if (by_sock_it == clients_by_socket_.end())
+    return;
+  const ClientID departed_id = by_sock_it->second->id;
+  // Read peer_uid() now: erasing from |clients_| below destroys the socket.
+  ClientInfo departed(departed_id, sock->peer_uid());
+  clients_by_socket_.erase(by_sock_it);
+  PERFETTO_DCHECK(clients_.count(departed_id));
+  clients_.erase(departed_id);
+
+  for (auto& entry : services_) {
+    Service* svc = entry.second.instance.get();
+    svc->client_info_ = departed;
+    svc->OnClientDisconnected();
+    svc->client_info_ = ClientInfo();
+  }
+}
+
+// Linear lookup of an exposed service by name. Returns nullptr if absent.
+// A secondary name->ServiceID index is not worth it: this only serves
+// BindService and ExposeService, which are rare (once per client connection
+// and once per service instance).
+const HostImpl::ExposedService* HostImpl::GetServiceByName(
+    const std::string& name) {
+  auto match = std::find_if(
+      services_.begin(), services_.end(),
+      [&name](const std::pair<const ServiceID, ExposedService>& entry) {
+        return entry.second.name == name;
+      });
+  return match == services_.end() ? nullptr : &match->second;
+}
+
+// Records a registered service: its id, its lookup name, and the owned
+// instance.
+HostImpl::ExposedService::ExposedService(ServiceID service_id,
+                                         const std::string& service_name,
+                                         std::unique_ptr<Service> service)
+    : id(service_id), name(service_name), instance(std::move(service)) {}
+
+// Special members are defaulted out of line here rather than in host_impl.h.
+// Presumably this is so std::unique_ptr<Service> is destroyed where Service's
+// full definition (perfetto/ipc/service.h) is included. TODO confirm.
+HostImpl::ExposedService::ExposedService(ExposedService&&) noexcept = default;
+HostImpl::ExposedService& HostImpl::ExposedService::operator=(
+ HostImpl::ExposedService&&) = default;
+HostImpl::ExposedService::~ExposedService() = default;
+
+HostImpl::ClientConnection::~ClientConnection() = default;
+
+} // namespace ipc
+} // namespace perfetto
diff --git a/src/ipc/host_impl.h b/src/ipc/host_impl.h
new file mode 100644
index 0000000..f09b32d
--- /dev/null
+++ b/src/ipc/host_impl.h
@@ -0,0 +1,98 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef SRC_IPC_HOST_IMPL_H_
+#define SRC_IPC_HOST_IMPL_H_
+
+#include <map>
+#include <set>
+#include <string>
+#include <vector>
+
+#include "perfetto/base/task_runner.h"
+#include "perfetto/base/thread_checker.h"
+#include "perfetto/ipc/deferred.h"
+#include "perfetto/ipc/host.h"
+#include "src/ipc/buffered_frame_deserializer.h"
+#include "src/ipc/unix_socket.h"
+
+namespace perfetto {
+namespace ipc {
+
+class Frame;
+
+// Host-side endpoint of the IPC layer. It listens on a UNIX socket, owns one
+// ClientConnection per accepted client, and dispatches received
+// BindService / InvokeMethod frames to the Service instances registered via
+// ExposeService().
+class HostImpl : public Host, public UnixSocket::EventListener {
+ public:
+  // Starts listening on |socket_name|. Check sock()->is_listening() for
+  // failure, as Host::CreateInstance() does.
+  HostImpl(const char* socket_name, base::TaskRunner*);
+  ~HostImpl() override;
+
+  // Host implementation.
+  bool ExposeService(std::unique_ptr<Service>) override;
+
+  // UnixSocket::EventListener implementation.
+  void OnNewIncomingConnection(UnixSocket*,
+                               std::unique_ptr<UnixSocket>) override;
+  void OnDisconnect(UnixSocket*) override;
+  void OnDataAvailable(UnixSocket*) override;
+
+  // The listening socket.
+  const UnixSocket* sock() const { return sock_.get(); }
+
+ private:
+  // Owns the per-client socket and receive buffer (BufferedFrameDeserializer).
+  struct ClientConnection {
+    ~ClientConnection();
+    ClientID id;
+    std::unique_ptr<UnixSocket> sock;
+    BufferedFrameDeserializer frame_deserializer;
+  };
+  // A Service registered through ExposeService(), stored in |services_|.
+  struct ExposedService {
+    ExposedService(ServiceID, const std::string&, std::unique_ptr<Service>);
+    ~ExposedService();
+    ExposedService(ExposedService&&) noexcept;
+    ExposedService& operator=(ExposedService&&);
+
+    ServiceID id;
+    std::string name;
+    std::unique_ptr<Service> instance;
+  };
+
+  HostImpl(const HostImpl&) = delete;
+  HostImpl& operator=(const HostImpl&) = delete;
+
+  void OnReceivedFrame(ClientConnection*, const Frame&);
+  void OnBindService(ClientConnection*, const Frame&);
+  void OnInvokeMethod(ClientConnection*, const Frame&);
+  void ReplyToMethodInvocation(ClientID, RequestID, AsyncResult<ProtoMessage>);
+  const ExposedService* GetServiceByName(const std::string&);
+
+  static void SendFrame(ClientConnection*, const Frame&, int fd = -1);
+
+  base::TaskRunner* const task_runner_;
+  std::map<ServiceID, ExposedService> services_;
+  std::unique_ptr<UnixSocket> sock_;  // The listening socket.
+  std::map<ClientID, std::unique_ptr<ClientConnection>> clients_;
+  // Non-owning index into |clients_|, used to map socket callbacks to clients.
+  std::map<UnixSocket*, ClientConnection*> clients_by_socket_;
+  ServiceID last_service_id_ = 0;
+  ClientID last_client_id_ = 0;
+  PERFETTO_THREAD_CHECKER(thread_checker_)
+
+  // Keep last. Members are destroyed in reverse declaration order, so the
+  // factory, and with it the WeakPtr<HostImpl>s handed to pending Deferred
+  // replies, goes away before |clients_| and |services_|. Otherwise a Deferred
+  // rejected while |services_| is being destroyed would find a still-valid
+  // WeakPtr and call ReplyToMethodInvocation(), touching the already
+  // destroyed |clients_|.
+  base::WeakPtrFactory<HostImpl> weak_ptr_factory_;
+};
+
+} // namespace ipc
+} // namespace perfetto
+
+#endif // SRC_IPC_HOST_IMPL_H_
diff --git a/src/ipc/host_impl_unittest.cc b/src/ipc/host_impl_unittest.cc
new file mode 100644
index 0000000..72ab226
--- /dev/null
+++ b/src/ipc/host_impl_unittest.cc
@@ -0,0 +1,391 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "src/ipc/host_impl.h"
+
+#include <memory>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "perfetto/base/scoped_file.h"
+#include "perfetto/ipc/service.h"
+#include "perfetto/ipc/service_descriptor.h"
+#include "src/base/test/test_task_runner.h"
+#include "src/ipc/buffered_frame_deserializer.h"
+#include "src/ipc/unix_socket.h"
+
+#include "src/ipc/test/client_unittest_messages.pb.h"
+#include "src/ipc/wire_protocol.pb.h"
+
+namespace perfetto {
+namespace ipc {
+namespace {
+
+using ::testing::_;
+using ::testing::Invoke;
+using ::testing::InvokeWithoutArgs;
+
+// Socket path shared by the HostImpl under test and FakeClient.
+static constexpr const char kSockName[] = "/tmp/perfetto_host_impl_unittest.sock";
+
+// RequestProto and ReplyProto are defined in client_unittest_messages.proto.
+
+// Mock Service exposing a single method, "FakeMethod1" (method id 1). Its
+// invocations are forwarded to the OnFakeMethod1() mock.
+class FakeService : public Service {
+ public:
+  MOCK_METHOD0(Destroyed, void());
+  MOCK_METHOD2(OnFakeMethod1, void(const RequestProto&, DeferredBase*));
+
+  // ServiceDescriptor invoker for FakeMethod1: downcasts and forwards to the
+  // mock. |deferred_reply| is destroyed when this returns, so a reply not
+  // resolved by the mock action is rejected at that point.
+  static void Invoker(Service* service,
+                      const ProtoMessage& req,
+                      DeferredBase deferred_reply) {
+    static_cast<FakeService*>(service)->OnFakeMethod1(
+        static_cast<const RequestProto&>(req), &deferred_reply);
+  }
+
+  // ServiceDescriptor request decoder: parses |proto| as a RequestProto.
+  static std::unique_ptr<ProtoMessage> RequestDecoder(
+      const std::string& proto) {
+    std::unique_ptr<ProtoMessage> req(new RequestProto());
+    EXPECT_TRUE(req->ParseFromString(proto));
+    return req;
+  }
+
+  explicit FakeService(const char* service_name) {
+    descriptor_.service_name = service_name;
+    descriptor_.methods.push_back(
+        {"FakeMethod1", &RequestDecoder, nullptr, &Invoker});
+  }
+
+  const ServiceDescriptor& GetDescriptor() override { return descriptor_; }
+
+  ServiceDescriptor descriptor_;
+};
+
+// Minimal raw-socket client used to drive HostImpl directly. It writes request
+// frames by hand and reports every reply frame through the mocks below.
+class FakeClient : public UnixSocket::EventListener {
+ public:
+  MOCK_METHOD0(OnConnect, void());
+  MOCK_METHOD0(OnDisconnect, void());
+  MOCK_METHOD1(OnServiceBound, void(const Frame::BindServiceReply&));
+  MOCK_METHOD1(OnInvokeMethodReply, void(const Frame::InvokeMethodReply&));
+  MOCK_METHOD1(OnFileDescriptorReceived, void(int));
+  MOCK_METHOD0(OnRequestError, void());
+
+  explicit FakeClient(base::TaskRunner* task_runner) {
+    sock_ = UnixSocket::Connect(kSockName, this, task_runner);
+  }
+
+  ~FakeClient() override = default;
+
+  void BindService(const std::string& service_name) {
+    Frame frame;
+    frame.set_request_id(RegisterNewRequest());
+    frame.mutable_msg_bind_service()->set_service_name(service_name);
+    SendFrame(frame);
+  }
+
+  void InvokeMethod(ServiceID service_id,
+                    MethodID method_id,
+                    const ProtoMessage& args) {
+    Frame frame;
+    frame.set_request_id(RegisterNewRequest());
+    frame.mutable_msg_invoke_method()->set_service_id(service_id);
+    frame.mutable_msg_invoke_method()->set_method_id(method_id);
+    frame.mutable_msg_invoke_method()->set_args_proto(args.SerializeAsString());
+    SendFrame(frame);
+  }
+
+  // UnixSocket::EventListener implementation.
+  void OnConnect(UnixSocket*, bool success) override {
+    ASSERT_TRUE(success);
+    OnConnect();
+  }
+
+  void OnDisconnect(UnixSocket*) override { OnDisconnect(); }
+
+  // Performs one Receive() and forwards any received fd. It then dispatches
+  // *every* complete frame now buffered; an early return here would leave
+  // later frames stuck until the next read.
+  void OnDataAvailable(UnixSocket* sock) override {
+    ASSERT_EQ(sock_.get(), sock);
+    auto buf = frame_deserializer_.BeginReceive();
+    base::ScopedFile fd;
+    size_t rsize = sock->Receive(buf.data, buf.size, &fd);
+    ASSERT_TRUE(frame_deserializer_.EndReceive(rsize));
+    if (fd)
+      OnFileDescriptorReceived(*fd);
+    while (std::unique_ptr<Frame> frame = frame_deserializer_.PopNextFrame()) {
+      // Each request must get exactly one reply.
+      ASSERT_EQ(1u, requests_.count(frame->request_id()));
+      EXPECT_EQ(0, requests_[frame->request_id()]++);
+      if (frame->msg_case() == Frame::kMsgBindServiceReply) {
+        if (frame->msg_bind_service_reply().success())
+          last_bound_service_id_ = frame->msg_bind_service_reply().service_id();
+        OnServiceBound(frame->msg_bind_service_reply());
+      } else if (frame->msg_case() == Frame::kMsgInvokeMethodReply) {
+        OnInvokeMethodReply(frame->msg_invoke_method_reply());
+      } else if (frame->msg_case() == Frame::kMsgRequestError) {
+        OnRequestError();
+      } else {
+        FAIL() << "Unexpected frame received from host " << frame->msg_case();
+      }
+    }
+  }
+
+  void SendFrame(const Frame& frame) {
+    std::string buf = BufferedFrameDeserializer::Serialize(frame);
+    ASSERT_TRUE(sock_->Send(buf.data(), buf.size()));
+  }
+
+  // Allocates the next request id (1 for the first request, then one past the
+  // highest issued so far) and records it with zero replies received.
+  uint64_t RegisterNewRequest() {
+    uint64_t request_id = requests_.empty() ? 1 : requests_.rbegin()->first + 1;
+    requests_.emplace(request_id, 0);
+    return request_id;
+  }
+
+  BufferedFrameDeserializer frame_deserializer_;
+  std::unique_ptr<UnixSocket> sock_;
+  std::map<uint64_t /* request_id */, int /* num_replies_received */> requests_;
+  ServiceID last_bound_service_id_ = 0;
+};
+
+// Fixture: starts a HostImpl listening on kSockName and connects a FakeClient
+// to it. TearDown() destroys both and unlinks the socket path.
+class HostImplTest : public ::testing::Test {
+ public:
+  void SetUp() override {
+    unlink(kSockName);  // Remove a socket file possibly left by a prior run.
+    task_runner_.reset(new base::TestTaskRunner());
+    Host* host = Host::CreateInstance(kSockName, task_runner_.get()).release();
+    ASSERT_NE(nullptr, host);
+    host_.reset(static_cast<HostImpl*>(host));
+    cli_.reset(new FakeClient(task_runner_.get()));
+    auto on_connect = task_runner_->CreateCheckpoint("on_connect");
+    EXPECT_CALL(*cli_, OnConnect()).WillOnce(Invoke(on_connect));
+    task_runner_->RunUntilCheckpoint("on_connect");
+  }
+
+  void TearDown() override {
+    task_runner_->RunUntilIdle();
+    cli_.reset();
+    host_.reset();
+    task_runner_->RunUntilIdle();
+    task_runner_.reset();
+    unlink(kSockName);
+  }
+
+  std::unique_ptr<base::TestTaskRunner> task_runner_;
+  std::unique_ptr<HostImpl> host_;
+  std::unique_ptr<FakeClient> cli_;
+};
+
+// Binding fails before the service is exposed and succeeds afterwards.
+TEST_F(HostImplTest, BindService) {
+  // First try to bind the service before it exists and check that the
+  // BindService() request fails. Expectations are set before the request.
+  auto on_bind_failure = task_runner_->CreateCheckpoint("on_bind_failure");
+  EXPECT_CALL(*cli_, OnServiceBound(_))
+      .WillOnce(Invoke([on_bind_failure](const Frame::BindServiceReply& reply) {
+        ASSERT_FALSE(reply.success());
+        on_bind_failure();
+      }));
+  cli_->BindService("FakeService");  // FakeService does not exist yet.
+  task_runner_->RunUntilCheckpoint("on_bind_failure");
+
+  // Now expose the service and bind it.
+  ASSERT_TRUE(host_->ExposeService(
+      std::unique_ptr<Service>(new FakeService("FakeService"))));
+  auto on_bind_success = task_runner_->CreateCheckpoint("on_bind_success");
+  EXPECT_CALL(*cli_, OnServiceBound(_))
+      .WillOnce(Invoke([on_bind_success](const Frame::BindServiceReply& reply) {
+        ASSERT_TRUE(reply.success());
+        on_bind_success();
+      }));
+  cli_->BindService("FakeService");
+  task_runner_->RunUntilCheckpoint("on_bind_success");
+}
+
+// Invoking an out-of-range method id on a bound service yields a failed,
+// non-streaming InvokeMethodReply.
+TEST_F(HostImplTest, InvokeNonExistingMethod) {
+ FakeService* fake_service = new FakeService("FakeService");
+ ASSERT_TRUE(host_->ExposeService(std::unique_ptr<Service>(fake_service)));
+ auto on_bind = task_runner_->CreateCheckpoint("on_bind");
+ cli_->BindService("FakeService");
+ EXPECT_CALL(*cli_, OnServiceBound(_)).WillOnce(InvokeWithoutArgs(on_bind));
+ task_runner_->RunUntilCheckpoint("on_bind");
+
+ // FakeService only defines method id 1, so 42 is out of range.
+ auto on_invoke_failure = task_runner_->CreateCheckpoint("on_invoke_failure");
+ cli_->InvokeMethod(cli_->last_bound_service_id_, 42, RequestProto());
+ EXPECT_CALL(*cli_, OnInvokeMethodReply(_))
+ .WillOnce(
+ Invoke([on_invoke_failure](const Frame::InvokeMethodReply& reply) {
+ ASSERT_FALSE(reply.success());
+ ASSERT_FALSE(reply.has_more());
+ on_invoke_failure();
+ }));
+ task_runner_->RunUntilCheckpoint("on_invoke_failure");
+}
+
+// A method invocation reaches the service with the decoded arguments, and the
+// service's synchronous reply reaches the client intact.
+TEST_F(HostImplTest, InvokeMethod) {
+  FakeService* fake_service = new FakeService("FakeService");
+  ASSERT_TRUE(host_->ExposeService(std::unique_ptr<Service>(fake_service)));
+  auto on_bind = task_runner_->CreateCheckpoint("on_bind");
+  cli_->BindService("FakeService");
+  EXPECT_CALL(*cli_, OnServiceBound(_)).WillOnce(InvokeWithoutArgs(on_bind));
+  task_runner_->RunUntilCheckpoint("on_bind");
+
+  RequestProto req_args;
+  req_args.set_data("foo");
+  auto on_reply_sent = task_runner_->CreateCheckpoint("on_reply_sent");
+  EXPECT_CALL(*fake_service, OnFakeMethod1(_, _))
+      .WillOnce(
+          Invoke([on_reply_sent](const RequestProto& req, DeferredBase* reply) {
+            ASSERT_EQ("foo", req.data());
+            std::unique_ptr<ReplyProto> reply_args(new ReplyProto());
+            reply_args->set_data("bar");
+            reply->Resolve(AsyncResult<ProtoMessage>(
+                std::unique_ptr<ProtoMessage>(reply_args.release())));
+            on_reply_sent();
+          }));
+  cli_->InvokeMethod(cli_->last_bound_service_id_, 1, req_args);
+  task_runner_->RunUntilCheckpoint("on_reply_sent");
+
+  auto on_reply_received = task_runner_->CreateCheckpoint("on_reply_received");
+  EXPECT_CALL(*cli_, OnInvokeMethodReply(_))
+      .WillOnce(
+          Invoke([on_reply_received](const Frame::InvokeMethodReply& reply) {
+            ASSERT_TRUE(reply.success());
+            ASSERT_FALSE(reply.has_more());
+            ReplyProto reply_args;
+            ASSERT_TRUE(reply_args.ParseFromString(reply.reply_proto()));
+            ASSERT_EQ("bar", reply_args.data());
+            on_reply_received();
+          }));
+  task_runner_->RunUntilCheckpoint("on_reply_received");
+}
+
+// A file descriptor attached to a method reply is received by the client and
+// refers to the same file contents.
+TEST_F(HostImplTest, SendFileDescriptor) {
+  FakeService* fake_service = new FakeService("FakeService");
+  ASSERT_TRUE(host_->ExposeService(std::unique_ptr<Service>(fake_service)));
+  auto on_bind = task_runner_->CreateCheckpoint("on_bind");
+  cli_->BindService("FakeService");
+  EXPECT_CALL(*cli_, OnServiceBound(_)).WillOnce(InvokeWithoutArgs(on_bind));
+  task_runner_->RunUntilCheckpoint("on_bind");
+
+  static constexpr char kFileContent[] = "shared file";
+  RequestProto req_args;
+  cli_->InvokeMethod(cli_->last_bound_service_id_, 1, req_args);
+  auto on_reply_sent = task_runner_->CreateCheckpoint("on_reply_sent");
+  FILE* tx_file = tmpfile();
+  ASSERT_NE(nullptr, tx_file);
+  ASSERT_EQ(1u, fwrite(kFileContent, sizeof(kFileContent), 1, tx_file));
+  ASSERT_EQ(0, fflush(tx_file));
+  EXPECT_CALL(*fake_service, OnFakeMethod1(_, _))
+      .WillOnce(Invoke([on_reply_sent, tx_file](const RequestProto&,
+                                                DeferredBase* reply) {
+        std::unique_ptr<ReplyProto> reply_args(new ReplyProto());
+        auto async_res = AsyncResult<ProtoMessage>(
+            std::unique_ptr<ProtoMessage>(reply_args.release()));
+        async_res.set_fd(fileno(tx_file));
+        reply->Resolve(std::move(async_res));
+        on_reply_sent();
+      }));
+  task_runner_->RunUntilCheckpoint("on_reply_sent");
+  fclose(tx_file);
+
+  auto on_fd_received = task_runner_->CreateCheckpoint("on_fd_received");
+  EXPECT_CALL(*cli_, OnFileDescriptorReceived(_))
+      .WillOnce(Invoke([on_fd_received](int fd) {
+        char buf[sizeof(kFileContent)] = {};
+        ASSERT_EQ(0, lseek(fd, 0, SEEK_SET));
+        ASSERT_EQ(static_cast<int32_t>(sizeof(buf)),
+                  PERFETTO_EINTR(read(fd, buf, sizeof(buf))));
+        ASSERT_STREQ(kFileContent, buf);
+        on_fd_received();
+      }));
+  EXPECT_CALL(*cli_, OnInvokeMethodReply(_));
+  task_runner_->RunUntilCheckpoint("on_fd_received");
+}
+
+// Invoke a method and disconnect the client immediately afterwards.
+TEST_F(HostImplTest, OnClientDisconnect) {
+ FakeService* fake_service = new FakeService("FakeService");  // Owned by host_.
+ ASSERT_TRUE(host_->ExposeService(std::unique_ptr<Service>(fake_service)));
+ auto on_bind = task_runner_->CreateCheckpoint("on_bind");
+ cli_->BindService("FakeService");
+ EXPECT_CALL(*cli_, OnServiceBound(_)).WillOnce(InvokeWithoutArgs(on_bind));
+ task_runner_->RunUntilCheckpoint("on_bind");
+ // Send the request, then destroy the client before the host handles it.
+ RequestProto req_args;
+ req_args.set_data("foo");
+ cli_->InvokeMethod(cli_->last_bound_service_id_, 1, req_args);
+ EXPECT_CALL(*cli_, OnInvokeMethodReply(_)).Times(0);  // Verified on reset().
+ cli_.reset(); // Disconnect the client.
+ auto on_host_method = task_runner_->CreateCheckpoint("on_host_method");
+ EXPECT_CALL(*fake_service, OnFakeMethod1(_, _))
+ .WillOnce(Invoke(
+ [on_host_method](const RequestProto& req, DeferredBase* reply) {
+ ASSERT_EQ("foo", req.data());  // The host still gets the request.
+ on_host_method();
+ }));
+ task_runner_->RunUntilCheckpoint("on_host_method");
+}
+
+// Like InvokeMethod, but instead of resolving the Deferred reply within the
+// call stack, std::move()-s it out and replies asynchronously later.
+TEST_F(HostImplTest, MoveReplyObjectAndReplyAsynchronously) {
+ FakeService* fake_service = new FakeService("FakeService");  // Owned by host_.
+ ASSERT_TRUE(host_->ExposeService(std::unique_ptr<Service>(fake_service)));
+ auto on_bind = task_runner_->CreateCheckpoint("on_bind");
+ cli_->BindService("FakeService");
+ EXPECT_CALL(*cli_, OnServiceBound(_)).WillOnce(InvokeWithoutArgs(on_bind));
+ task_runner_->RunUntilCheckpoint("on_bind");
+
+ // Invokes the remote method and waits until the FakeService sees it. The
+ // reply is not resolved but just moved into |moved_reply|.
+ RequestProto req_args;
+ cli_->InvokeMethod(cli_->last_bound_service_id_, 1, req_args);
+ auto on_invoke = task_runner_->CreateCheckpoint("on_invoke");
+ DeferredBase moved_reply;  // Captured by reference by the task posted below.
+ EXPECT_CALL(*fake_service, OnFakeMethod1(_, _))
+ .WillOnce(Invoke([on_invoke, &moved_reply](const RequestProto& req,
+ DeferredBase* reply) {
+ moved_reply = std::move(*reply);
+ on_invoke();
+ }));
+ task_runner_->RunUntilCheckpoint("on_invoke");
+
+ // Check that the FakeClient doesn't see any reply yet.
+ EXPECT_CALL(*cli_, OnInvokeMethodReply(_)).Times(0);
+ task_runner_->RunUntilIdle();
+ ASSERT_TRUE(::testing::Mock::VerifyAndClearExpectations(cli_.get()));  // Clears Times(0).
+
+ // Resolve the reply asynchronously in a deferred task.
+ task_runner_->PostTask([&moved_reply] {
+ std::unique_ptr<ReplyProto> reply_args(new ReplyProto());
+ reply_args->set_data("bar");
+ moved_reply.Resolve(AsyncResult<ProtoMessage>(
+ std::unique_ptr<ProtoMessage>(reply_args.release())));
+ });
+
+ auto on_reply_received = task_runner_->CreateCheckpoint("on_reply_received");
+ EXPECT_CALL(*cli_, OnInvokeMethodReply(_))
+ .WillOnce(
+ Invoke([on_reply_received](const Frame::InvokeMethodReply& reply) {
+ ASSERT_TRUE(reply.success());
+ ASSERT_FALSE(reply.has_more());
+ ReplyProto reply_args;
+ reply_args.ParseFromString(reply.reply_proto());
+ ASSERT_EQ("bar", reply_args.data());
+ on_reply_received();
+ }));
+ task_runner_->RunUntilCheckpoint("on_reply_received");
+}
+
+// TODO(primiano): add the tests below in next CLs.
+// TEST(HostImplTest, ManyClients) {}
+// TEST(HostImplTest, OverlappingRequestsOutOfOrder) {}
+// TEST(HostImplTest, StreamingRequest) {}
+
+} // namespace
+} // namespace ipc
+} // namespace perfetto
diff --git a/src/ipc/ipc_library.gni b/src/ipc/ipc_library.gni
new file mode 100644
index 0000000..3a31c73
--- /dev/null
+++ b/src/ipc/ipc_library.gni
@@ -0,0 +1,49 @@
+# Copyright (C) 2017 The Android Open Source Project
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import("../../gn/perfetto.gni")
+import("//build_overrides/build.gni")
+
+if (!build_with_chromium) {
+ import("//build/proto_library.gni")
+} else {
+ import("//third_party/protobuf/proto_library.gni")
+}
+
+# Generates .ipc.{h,cc} stubs for IPC services defined in .proto files.
+template("ipc_library") {
+ proto_library(target_name) {
+ perfetto_root_path = invoker.perfetto_root_path  # Must end with "/".
+ generator_plugin_label =
+ perfetto_root_path + "src/ipc/protoc_plugin:ipc_plugin"
+ generator_plugin_suffix = ".ipc"
+ deps = [ "${perfetto_root_path}src/ipc" ]
+ # Defaults to "protos_lite". Overridden by hand (not forwarded) because
+ # forward_variables_from() errors if a listed variable is already defined.
+ proto_out_dir = "protos_lite"
+ if (defined(invoker.proto_out_dir)) {
+ proto_out_dir = invoker.proto_out_dir
+ }
+ forward_variables_from(invoker,
+ [
+ "defines",
+ "extra_configs",
+ "include_dirs",
+ "proto_in_dir",
+ "sources",
+ "testonly",
+ "visibility",
+ ])
+ }
+}
diff --git a/src/ipc/protoc_plugin/BUILD.gn b/src/ipc/protoc_plugin/BUILD.gn
new file mode 100644
index 0000000..0530f40
--- /dev/null
+++ b/src/ipc/protoc_plugin/BUILD.gn
@@ -0,0 +1,31 @@
+# Copyright (C) 2017 The Android Open Source Project
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+if (current_toolchain == host_toolchain) {  # Host-only: ipc_library's protoc plugin.
+ executable("ipc_plugin") {
+ sources = [
+ "ipc_generator.cc",
+ "ipc_generator.h",
+ "ipc_plugin.cc",
+ ]
+ deps = [
+ "../../../gn:default_deps",
+ "../../../gn:protoc_lib_deps",
+ ]
+ if (is_clang) {
+ # Internal protobuf headers trigger -Wunreachable-code.
+ cflags = [ "-Wno-unreachable-code" ]
+ }
+ }
+} # host_toolchain
diff --git a/src/ipc/protoc_plugin/ipc_generator.cc b/src/ipc/protoc_plugin/ipc_generator.cc
new file mode 100644
index 0000000..b34bcfe
--- /dev/null
+++ b/src/ipc/protoc_plugin/ipc_generator.cc
@@ -0,0 +1,274 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "src/ipc/protoc_plugin/ipc_generator.h"
+
+#include <functional>
+#include <memory>
+#include <set>
+#include <string>
+
+#include "google/protobuf/compiler/cpp/cpp_options.h"
+#include "google/protobuf/descriptor.h"
+#include "google/protobuf/descriptor.pb.h"
+#include "google/protobuf/io/printer.h"
+#include "google/protobuf/io/zero_copy_stream.h"
+#include "google/protobuf/stubs/strutil.h"
+
+namespace perfetto {
+namespace ipc {
+
+using google::protobuf::ServiceDescriptor;
+using google::protobuf::FileDescriptor;
+using google::protobuf::MethodDescriptor;
+using google::protobuf::compiler::GeneratorContext;
+using google::protobuf::io::Printer;
+using google::protobuf::io::ZeroCopyOutputStream;
+
+using google::protobuf::Split;
+using google::protobuf::StripString;
+using google::protobuf::StripSuffixString;
+using google::protobuf::UpperString;
+
+namespace {
+
+static const char kBanner[] = "// DO NOT EDIT. Autogenerated by Perfetto IPC\n";  // First line of each output.
+// Printer templates: $c$ = service, $m$ = method, $i$/$o$ = request/reply type.
+static const char kHeaderSvcClass[] = R"(
+class $c$ : public ::perfetto::ipc::Service {
+ private:
+ static ::perfetto::ipc::ServiceDescriptor* NewDescriptor();
+
+ public:
+ ~$c$() override;
+
+ static const ::perfetto::ipc::ServiceDescriptor& GetDescriptorStatic();
+
+ // Service implementation.
+ const ::perfetto::ipc::ServiceDescriptor& GetDescriptor() override;
+
+ // Methods from the .proto file
+)";  // GenerateServiceHeader() appends the method declarations and "};".
+
+static const char kHeaderProxyClass[] = R"(
+class $c$Proxy : public ::perfetto::ipc::ServiceProxy {
+ public:
+ explicit $c$Proxy(::perfetto::ipc::ServiceProxy::EventListener*);
+ ~$c$Proxy() override;
+
+ // ServiceProxy implementation.
+ const ::perfetto::ipc::ServiceDescriptor& GetDescriptor() override;
+
+ // Methods from the .proto file
+)";  // Likewise completed by GenerateServiceHeader().
+
+static const char kCppClassDefinitions[] = R"(
+const ::perfetto::ipc::ServiceDescriptor& $c$::GetDescriptorStatic() {
+ static auto* instance = NewDescriptor();
+ return *instance;
+}
+
+// Host-side definitions.
+$c$::~$c$() = default;
+
+const ::perfetto::ipc::ServiceDescriptor& $c$::GetDescriptor() {
+ return GetDescriptorStatic();
+}
+
+// Client-side definitions.
+$c$Proxy::$c$Proxy(::perfetto::ipc::ServiceProxy::EventListener* event_listener)
+ : ::perfetto::ipc::ServiceProxy(event_listener) {}
+
+$c$Proxy::~$c$Proxy() = default;
+
+const ::perfetto::ipc::ServiceDescriptor& $c$Proxy::GetDescriptor() {
+ return $c$::GetDescriptorStatic();
+}
+)";  // Emitted once per service by GenerateServiceCpp().
+
+static const char kCppMethodDescriptor[] = R"(
+ desc->methods.emplace_back(::perfetto::ipc::ServiceDescriptor::Method{
+ "$m$",
+ &_IPC_Decoder<$i$>,
+ &_IPC_Decoder<$o$>,
+ &_IPC_Invoker<$c$, $i$, $o$, &$c$::$m$>});
+)";  // One per method, inside the generated NewDescriptor().
+
+static const char kCppMethod[] = R"(
+void $c$Proxy::$m$(const $i$& request, Deferred$o$ reply) {
+ BeginInvoke("$m$", request, ::perfetto::ipc::DeferredBase(std::move(reply)));
+}
+)";  // Proxy method stub: forwards to ServiceProxy::BeginInvoke().
+
+std::string StripName(const FileDescriptor& file) {  // "a/b.proto" -> "a/b"
+ return StripSuffixString(file.name(), ".proto");
+}
+
+std::string GetStubName(const FileDescriptor& file) {  // "a/b.proto" -> "a/b.ipc"
+ return StripName(file) + ".ipc";
+}
+
+void ForEachMethod(const ServiceDescriptor& svc,
+ const std::function<void(const MethodDescriptor&,
+ const std::string&,
+ const std::string&)>& function) {  // Runs once per RPC.
+ for (int i = 0; i < svc.method_count(); i++) {
+ const MethodDescriptor& method = *svc.method(i);
+ // TODO if the input or output type are in a different namespace we need to
+ // emit the ::fully::qualified::name; only the short name is passed today.
+ const std::string& input_type = method.input_type()->name();
+ const std::string& output_type = method.output_type()->name();
+ function(method, input_type, output_type);
+ }
+}
+
+void GenerateServiceHeader(const FileDescriptor& file,
+ const ServiceDescriptor& svc,
+ Printer* printer) {  // Writes the .ipc.h declarations for |svc|.
+ printer->Print("\n");
+ std::vector<std::string> namespaces = Split(file.package(), ".");  // "a.b" -> {a, b}
+ for (const std::string& ns : namespaces)
+ printer->Print("namespace $ns$ {\n", "ns", ns);
+
+ // Generate the host-side declarations.
+ printer->Print(kHeaderSvcClass, "c", svc.name());
+ std::set<std::string> types_seen;  // Emit each Deferred$o$ alias only once.
+ ForEachMethod(svc, [&types_seen, printer](const MethodDescriptor& method,
+ const std::string& input_type,
+ const std::string& output_type) {
+ if (types_seen.count(output_type) == 0) {
+ printer->Print(" using Deferred$o$ = ::perfetto::ipc::Deferred<$o$>;\n",
+ "o", output_type);
+ types_seen.insert(output_type);
+ }
+ printer->Print(" virtual void $m$(const $i$&, Deferred$o$) = 0;\n\n", "m",
+ method.name(), "i", input_type, "o", output_type);
+ });
+ printer->Print("};\n\n");
+
+ // Generate the client-side declarations.
+ printer->Print(kHeaderProxyClass, "c", svc.name());
+ types_seen.clear();  // The aliases are class-scoped: redo them for the proxy.
+ ForEachMethod(svc, [&types_seen, printer](const MethodDescriptor& method,
+ const std::string& input_type,
+ const std::string& output_type) {
+ if (types_seen.count(output_type) == 0) {
+ printer->Print(" using Deferred$o$ = ::perfetto::ipc::Deferred<$o$>;\n",
+ "o", output_type);
+ types_seen.insert(output_type);
+ }
+ printer->Print(" void $m$(const $i$&, Deferred$o$);\n\n", "m",
+ method.name(), "i", input_type, "o", output_type);
+ });
+ printer->Print("};\n\n");
+
+ for (auto it = namespaces.rbegin(); it != namespaces.rend(); it++)  // Innermost first.
+ printer->Print("} // namespace $ns$\n", "ns", *it);
+
+ printer->Print("\n");
+}
+
+void GenerateServiceCpp(const FileDescriptor& file,
+ const ServiceDescriptor& svc,
+ Printer* printer) {  // Writes the .ipc.cc definitions for |svc|.
+ printer->Print("\n");
+
+ std::vector<std::string> namespaces = Split(file.package(), ".");
+ for (const std::string& ns : namespaces)
+ printer->Print("namespace $ns$ {\n", "ns", ns);
+
+ printer->Print("::perfetto::ipc::ServiceDescriptor* $c$::NewDescriptor() {\n",
+ "c", svc.name());
+ printer->Print(" auto* desc = new ::perfetto::ipc::ServiceDescriptor();\n");
+ printer->Print(" desc->service_name = \"$c$\";\n", "c", svc.name());
+
+ ForEachMethod(svc, [&svc, printer](const MethodDescriptor& method,  // One entry per RPC.
+ const std::string& input_type,
+ const std::string& output_type) {
+ printer->Print(kCppMethodDescriptor, "c", svc.name(), "i", input_type, "o",
+ output_type, "m", method.name());
+ });
+
+ printer->Print(" desc->methods.shrink_to_fit();\n");
+ printer->Print(" return desc;\n");
+ printer->Print("}\n\n");
+
+ printer->Print(kCppClassDefinitions, "c", svc.name());
+
+ ForEachMethod(svc, [&svc, printer](const MethodDescriptor& method,  // Proxy stubs.
+ const std::string& input_type,
+ const std::string& output_type) {
+ printer->Print(kCppMethod, "c", svc.name(), "m", method.name(), "i",
+ input_type, "o", output_type);
+ });
+
+ for (auto it = namespaces.rbegin(); it != namespaces.rend(); it++)  // Innermost first.
+ printer->Print("} // namespace $ns$\n", "ns", *it);
+}
+
+} // namespace
+
+IPCGenerator::IPCGenerator() = default;  // Declared in ipc_generator.h.
+IPCGenerator::~IPCGenerator() = default;
+
+bool IPCGenerator::Generate(const FileDescriptor* file,
+ const std::string& options,  // Unused.
+ GeneratorContext* context,
+ std::string* error) const {
+ if (file->options().cc_generic_services()) {  // Unsupported by this plugin.
+ *error = "Please set \"cc_generic_services = false\".";
+ return false;
+ }
+ // Outputs: <proto path minus ".proto">.ipc.h and .ipc.cc (see GetStubName()).
+ const std::unique_ptr<ZeroCopyOutputStream> h_fstream(
+ context->Open(GetStubName(*file) + ".h"));
+ const std::unique_ptr<ZeroCopyOutputStream> cc_fstream(
+ context->Open(GetStubName(*file) + ".cc"));
+
+ // Variables are delimited by $.
+ Printer h_printer(h_fstream.get(), '$');
+ Printer cc_printer(cc_fstream.get(), '$');
+ // Guard: "<package>_<file>_H_", uppercased, with any of . - / \ mapped to _.
+ std::string guard = file->package() + "_" + file->name() + "_H_";
+ UpperString(&guard);
+ StripString(&guard, ".-/\\", '_');
+
+ h_printer.Print(kBanner);
+ h_printer.Print("#ifndef $guard$\n#define $guard$\n\n", "guard", guard);
+ h_printer.Print("#include \"$h$\"\n", "h", StripName(*file) + ".pb.h");
+ h_printer.Print("#include \"perfetto/ipc/deferred.h\"\n");
+ h_printer.Print("#include \"perfetto/ipc/service.h\"\n");
+ h_printer.Print("#include \"perfetto/ipc/service_descriptor.h\"\n");
+ h_printer.Print("#include \"perfetto/ipc/service_proxy.h\"\n\n");
+
+ cc_printer.Print(kBanner);
+ cc_printer.Print("#include \"$h$\"\n", "h", GetStubName(*file) + ".h");
+ cc_printer.Print("#include \"perfetto/ipc/codegen_helpers.h\"\n\n");
+ cc_printer.Print("#include <memory>\n");
+ // One section per service in |file|, in both outputs.
+ for (int i = 0; i < file->service_count(); i++) {
+ const ServiceDescriptor* svc = file->service(i);
+ GenerateServiceHeader(*file, *svc, &h_printer);
+ GenerateServiceCpp(*file, *svc, &cc_printer);
+ }
+
+ h_printer.Print("#endif // $guard$\n", "guard", guard);
+
+ return true;
+}
+
+} // namespace ipc
+} // namespace perfetto
diff --git a/src/ipc/protoc_plugin/ipc_generator.h b/src/ipc/protoc_plugin/ipc_generator.h
new file mode 100644
index 0000000..bdd4d12
--- /dev/null
+++ b/src/ipc/protoc_plugin/ipc_generator.h
@@ -0,0 +1,42 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef SRC_IPC_PROTOC_PLUGIN_IPC_GENERATOR_H_
+#define SRC_IPC_PROTOC_PLUGIN_IPC_GENERATOR_H_
+
+#include <string>
+
+#include "google/protobuf/compiler/code_generator.h"
+
+namespace perfetto {
+namespace ipc {
+
+class IPCGenerator : public ::google::protobuf::compiler::CodeGenerator {
+ public:
+ IPCGenerator();
+ ~IPCGenerator() override;
+ // CodeGenerator implementation: for a/b.proto emits a/b.ipc.h and a/b.ipc.cc.
+ // Returns false, setting |error|, if the file enables cc_generic_services.
+ bool Generate(const google::protobuf::FileDescriptor* file,
+ const std::string& options,
+ google::protobuf::compiler::GeneratorContext* context,
+ std::string* error) const override;
+};
+
+} // namespace ipc
+} // namespace perfetto
+
+#endif // SRC_IPC_PROTOC_PLUGIN_IPC_GENERATOR_H_
diff --git a/src/ipc/protoc_plugin/ipc_plugin.cc b/src/ipc/protoc_plugin/ipc_plugin.cc
new file mode 100644
index 0000000..82d45a4
--- /dev/null
+++ b/src/ipc/protoc_plugin/ipc_plugin.cc
@@ -0,0 +1,23 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "google/protobuf/compiler/plugin.h"
+#include "src/ipc/protoc_plugin/ipc_generator.h"
+
+int main(int argc, char* argv[]) {
+ ::perfetto::ipc::IPCGenerator ipc_generator;  // Driven by protoc as a plugin.
+ return ::google::protobuf::compiler::PluginMain(argc, argv, &ipc_generator);
+}
diff --git a/src/ipc/service_proxy.cc b/src/ipc/service_proxy.cc
new file mode 100644
index 0000000..8cdb545
--- /dev/null
+++ b/src/ipc/service_proxy.cc
@@ -0,0 +1,107 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "perfetto/ipc/service_proxy.h"
+
+#include <utility>
+
+#include "google/protobuf/message_lite.h"
+#include "perfetto/base/logging.h"
+#include "perfetto/base/weak_ptr.h"
+#include "perfetto/ipc/service_descriptor.h"
+#include "src/ipc/client_impl.h"
+
+namespace perfetto {
+namespace ipc {
+
+ServiceProxy::ServiceProxy(EventListener* event_listener)  // Gets OnConnect()/OnDisconnect().
+ : weak_ptr_factory_(this), event_listener_(event_listener) {}
+
+ServiceProxy::~ServiceProxy() {
+ if (client_ && connected())  // client_ is a WeakPtr: skip if the Client is gone.
+ client_->UnbindService(service_id_);
+}
+
+void ServiceProxy::InitializeBinding(
+ base::WeakPtr<Client> client,
+ ServiceID service_id,
+ std::map<std::string, MethodID> remote_method_ids) {
+ client_ = std::move(client);  // By-value sink parameter: move, don't copy.
+ service_id_ = service_id;
+ remote_method_ids_ = std::move(remote_method_ids);  // Looked up by BeginInvoke().
+}
+
+void ServiceProxy::BeginInvoke(const std::string& method_name,
+ const ProtoMessage& request,
+ DeferredBase reply) {
+ // |reply| will auto-resolve if it gets out of scope early.
+ if (!connected()) {
+ PERFETTO_DCHECK(false);
+ return;
+ }
+ if (!client_)
+ return; // The Client object has been destroyed in the meantime.
+
+ // Map the name to its remote method ID (received via InitializeBinding()).
+ const auto method_it = remote_method_ids_.find(method_name);
+ if (method_it == remote_method_ids_.end()) {
+ PERFETTO_DLOG("Cannot find method \"%s\" on the host", method_name.c_str());
+ return;
+ }
+ auto* client_impl = static_cast<ClientImpl*>(client_.get());
+ const RequestID request_id =
+ client_impl->BeginInvoke(service_id_, method_name, method_it->second,
+ request, weak_ptr_factory_.GetWeakPtr());
+ if (!request_id)
+ return;
+ PERFETTO_DCHECK(pending_callbacks_.count(request_id) == 0);
+ pending_callbacks_.emplace(request_id, std::move(reply));
+}
+
+void ServiceProxy::EndInvoke(RequestID request_id,
+ std::unique_ptr<ProtoMessage> result,
+ bool has_more) {
+ const auto pending_it = pending_callbacks_.find(request_id);
+ if (pending_it == pending_callbacks_.end()) {
+ PERFETTO_DCHECK(false); // No pending request with this ID.
+ return;
+ }
+ pending_it->second.Resolve(
+ AsyncResult<ProtoMessage>(std::move(result), has_more));
+ // Keep the callback while more (streaming) replies are expected.
+ if (!has_more)
+ pending_callbacks_.erase(pending_it);
+}
+
+void ServiceProxy::OnConnect(bool success) {
+ if (!success)
+ return event_listener_->OnDisconnect();  // Binding failed.
+
+ PERFETTO_DCHECK(service_id_);  // Success implies a service ID was assigned.
+ event_listener_->OnConnect();
+}
+
+void ServiceProxy::OnDisconnect() {  // Fails in-flight calls, then notifies.
+ pending_callbacks_.clear(); // Will Reject() all the pending callbacks.
+ event_listener_->OnDisconnect();
+}
+
+base::WeakPtr<ServiceProxy> ServiceProxy::GetWeakPtr() const {  // For Client::BindService().
+ return weak_ptr_factory_.GetWeakPtr();
+}
+
+} // namespace ipc
+} // namespace perfetto
diff --git a/src/ipc/test/client_unittest_messages.proto b/src/ipc/test/client_unittest_messages.proto
new file mode 100644
index 0000000..6ff32e8
--- /dev/null
+++ b/src/ipc/test/client_unittest_messages.proto
@@ -0,0 +1,28 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+syntax = "proto3";
+option optimize_for = LITE_RUNTIME;
+
+package perfetto.ipc;
+
+message RequestProto {  // Request payload used by the IPC unit tests.
+ string data = 1;
+}
+
+message ReplyProto {  // Reply payload used by the IPC unit tests.
+ string data = 2; // 2 here is deliberately != RequestProto.data ID (1).
+}
diff --git a/src/ipc/test/deferred_unittest_messages.proto b/src/ipc/test/deferred_unittest_messages.proto
new file mode 100644
index 0000000..dc52a23
--- /dev/null
+++ b/src/ipc/test/deferred_unittest_messages.proto
@@ -0,0 +1,25 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+syntax = "proto3";
+option optimize_for = LITE_RUNTIME;
+
+package perfetto.ipc;
+
+message TestMessage {  // Payload for the Deferred<> unit tests.
+ int32 num = 1;
+ string str = 2;
+}
diff --git a/src/ipc/test/greeter_service.proto b/src/ipc/test/greeter_service.proto
new file mode 100644
index 0000000..1ac1167
--- /dev/null
+++ b/src/ipc/test/greeter_service.proto
@@ -0,0 +1,35 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+syntax = "proto3";
+option optimize_for = LITE_RUNTIME;
+
+// Deliberately uses a package other than perfetto.* so that namespace
+// dependency bugs in the autogenerated headers are caught.
+package ipc_test;
+
+service Greeter {  // Served and called by ipc_integrationtest.cc.
+ rpc SayHello(GreeterRequestMsg) returns (GreeterReplyMsg) {}
+ rpc WaveGoodbye(GreeterRequestMsg) returns (GreeterReplyMsg) {}
+}
+
+message GreeterRequestMsg {
+ string name = 1;  // Who to greet.
+}
+
+message GreeterReplyMsg {
+ string message = 1;  // The greeting text.
+}
diff --git a/src/ipc/test/ipc_integrationtest.cc b/src/ipc/test/ipc_integrationtest.cc
new file mode 100644
index 0000000..6addb28
--- /dev/null
+++ b/src/ipc/test/ipc_integrationtest.cc
@@ -0,0 +1,135 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "perfetto/ipc/client.h"
+#include "perfetto/ipc/host.h"
+#include "src/base/test/test_task_runner.h"
+
+#include "src/ipc/test/greeter_service.ipc.h"
+#include "src/ipc/test/greeter_service.pb.h"
+
+namespace ipc_test {
+namespace {
+
+using ::testing::_;
+using ::testing::Invoke;
+using ::perfetto::ipc::AsyncResult;
+using ::perfetto::ipc::Client;
+using ::perfetto::ipc::Deferred;
+using ::perfetto::ipc::Host;
+using ::perfetto::ipc::Service;
+using ::perfetto::ipc::ServiceProxy;
+
+constexpr char kSockName[] = "/tmp/perfetto_ipc_test.sock";  // Unlinked in SetUp/TearDown.
+
+class MockEventListener : public ServiceProxy::EventListener {  // Proxy connection events.
+ public:
+ MOCK_METHOD0(OnConnect, void());
+ MOCK_METHOD0(OnDisconnect, void());
+};
+
+class MockGreeterService : public ipc_test::Greeter {  // Routes RPCs to On*() mocks.
+ public:
+ MOCK_METHOD2(OnSayHello,
+ void(const GreeterRequestMsg&, DeferredGreeterReplyMsg*));
+ void SayHello(const GreeterRequestMsg& request,
+ DeferredGreeterReplyMsg reply) override {
+ OnSayHello(request, &reply);
+ }
+ // The |&reply| handed to the On*() mocks is only valid during the call.
+ MOCK_METHOD2(OnWaveGoodbye,
+ void(const GreeterRequestMsg&, DeferredGreeterReplyMsg*));
+ void WaveGoodbye(const GreeterRequestMsg& request,
+ DeferredGreeterReplyMsg reply) override {
+ OnWaveGoodbye(request, &reply);
+ }
+};
+
+class IPCIntegrationTest : public ::testing::Test {
+ protected:
+ void SetUp() override { unlink(kSockName); }  // Remove any stale socket file.
+ void TearDown() override { unlink(kSockName); }
+
+ perfetto::base::TestTaskRunner task_runner_;
+ MockEventListener svc_proxy_events_;
+};
+
+TEST_F(IPCIntegrationTest, SayHelloWaveGoodbye) {  // Host + Client + generated stubs.
+ std::unique_ptr<Host> host = Host::CreateInstance(kSockName, &task_runner_);
+ ASSERT_TRUE(host);
+
+ MockGreeterService* svc = new MockGreeterService();  // Owned by |host|.
+ ASSERT_TRUE(host->ExposeService(std::unique_ptr<Service>(svc)));
+
+ auto on_connect = task_runner_.CreateCheckpoint("on_connect");
+ EXPECT_CALL(svc_proxy_events_, OnConnect()).WillOnce(Invoke(on_connect));
+ std::unique_ptr<Client> cli =  // NOTE(review): not null-checked, unlike |host|.
+ Client::CreateInstance(kSockName, &task_runner_);
+ std::unique_ptr<GreeterProxy> svc_proxy(new GreeterProxy(&svc_proxy_events_));
+ cli->BindService(svc_proxy->GetWeakPtr());
+ task_runner_.RunUntilCheckpoint("on_connect");
+ // SayHello: the host-side mock answers "Hello " + name.
+ {
+ GreeterRequestMsg req;
+ req.set_name("Mr Bojangles");
+ auto on_reply = task_runner_.CreateCheckpoint("on_hello_reply");
+ Deferred<GreeterReplyMsg> deferred_reply(
+ [on_reply](AsyncResult<GreeterReplyMsg> reply) {
+ ASSERT_TRUE(reply.success());
+ ASSERT_FALSE(reply.has_more());
+ ASSERT_EQ("Hello Mr Bojangles", reply->message());
+ on_reply();
+ });
+
+ EXPECT_CALL(*svc, OnSayHello(_, _))
+ .WillOnce(Invoke([](const GreeterRequestMsg& host_req,
+ Deferred<GreeterReplyMsg>* host_reply) {
+ auto reply = AsyncResult<GreeterReplyMsg>::Create();
+ reply->set_message("Hello " + host_req.name());
+ host_reply->Resolve(std::move(reply));
+ }));
+ svc_proxy->SayHello(req, std::move(deferred_reply));
+ task_runner_.RunUntilCheckpoint("on_hello_reply");
+ }
+ // WaveGoodbye: same flow, answered with "Goodbye " + name.
+ {
+ GreeterRequestMsg req;
+ req.set_name("Mrs Bojangles");
+ auto on_reply = task_runner_.CreateCheckpoint("on_goodbye_reply");
+ Deferred<GreeterReplyMsg> deferred_reply(
+ [on_reply](AsyncResult<GreeterReplyMsg> reply) {
+ ASSERT_TRUE(reply.success());
+ ASSERT_FALSE(reply.has_more());
+ ASSERT_EQ("Goodbye Mrs Bojangles", reply->message());
+ on_reply();
+ });
+
+ EXPECT_CALL(*svc, OnWaveGoodbye(_, _))
+ .WillOnce(Invoke([](const GreeterRequestMsg& host_req,
+ Deferred<GreeterReplyMsg>* host_reply) {
+ auto reply = AsyncResult<GreeterReplyMsg>::Create();
+ reply->set_message("Goodbye " + host_req.name());
+ host_reply->Resolve(std::move(reply));
+ }));
+ svc_proxy->WaveGoodbye(req, std::move(deferred_reply));
+ task_runner_.RunUntilCheckpoint("on_goodbye_reply");
+ }
+}
+
+} // namespace
+} // namespace ipc_test
diff --git a/src/ipc/unix_socket.cc b/src/ipc/unix_socket.cc
new file mode 100644
index 0000000..e28f3dc
--- /dev/null
+++ b/src/ipc/unix_socket.cc
@@ -0,0 +1,450 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "src/ipc/unix_socket.h"
+
+#include <errno.h>
+#include <fcntl.h>
+#include <stdlib.h>
+#include <string.h>
+#include <sys/socket.h>
+#include <sys/stat.h>
+#include <sys/types.h>
+#include <sys/un.h>
+#include <unistd.h>
+
+#include <algorithm>
+#include <memory>
+
+#include "perfetto/base/build_config.h"
+#include "perfetto/base/logging.h"
+#include "perfetto/base/task_runner.h"
+#include "perfetto/base/utils.h"
+
+#if BUILDFLAG(OS_MACOSX)
+#include <sys/ucred.h>
+#endif
+
+namespace perfetto {
+namespace ipc {
+
+// TODO(primiano): Add ThreadChecker to methods of this class.
+
+namespace {
+
+// Flags passed to sendmsg()/recvmsg() to avoid SIGPIPE when the peer is gone.
+// MSG_NOSIGNAL is not supported on Mac OS X; there the socket is instead
+// configured with SO_NOSIGPIPE in the UnixSocket constructor.
+#if BUILDFLAG(OS_MACOSX)
+constexpr int kNoSigPipe = 0;
+#else
+constexpr int kNoSigPipe = MSG_NOSIGNAL;
+#endif
+
+// Type used when assigning msghdr::msg_controllen (the control buffer size):
+// size_t on Android, socklen_t on every other platform.
+#if BUILDFLAG(OS_ANDROID)
+using CBufLenType = size_t;
+#else
+using CBufLenType = socklen_t;
+#endif
+
+// Fills |addr| / |addr_size| with the AF_UNIX address for |socket_name|.
+// A leading '@' selects the (Linux/Android) abstract namespace: the '@' is
+// replaced by the '\0' marker and the name is NOT NUL-terminated, because for
+// abstract sockets every byte covered by |addr_size| (a trailing '\0'
+// included) becomes part of the socket name. Filesystem names are
+// NUL-terminated as usual.
+// Returns false and sets errno on failure: EINVAL for an empty name,
+// ENAMETOOLONG if the name (plus terminator) does not fit in sun_path.
+bool MakeSockAddr(const std::string& socket_name,
+                  sockaddr_un* addr,
+                  socklen_t* addr_size) {
+  memset(addr, 0, sizeof(*addr));
+  const size_t name_len = socket_name.size();
+  if (name_len == 0) {
+    errno = EINVAL;
+    return false;
+  }
+  // Keep one byte free for the trailing '\0' of filesystem names.
+  if (name_len >= sizeof(addr->sun_path)) {
+    errno = ENAMETOOLONG;
+    return false;
+  }
+  memcpy(addr->sun_path, socket_name.data(), name_len);
+  const bool is_abstract = addr->sun_path[0] == '@';
+  if (is_abstract)
+    addr->sun_path[0] = '\0';
+  addr->sun_family = AF_UNIX;
+  const size_t path_len = is_abstract ? name_len : name_len + 1;
+  *addr_size = static_cast<socklen_t>(
+      __builtin_offsetof(sockaddr_un, sun_path) + path_len);
+  return true;
+}
+
+} // namespace
+
+// static
+// Factory: builds a fresh socket and makes it listen on |socket_name|.
+// An instance is handed back even on failure; callers check is_listening()
+// and last_error().
+std::unique_ptr<UnixSocket> UnixSocket::Listen(const std::string& socket_name,
+                                               EventListener* event_listener,
+                                               base::TaskRunner* task_runner) {
+  auto listener = std::unique_ptr<UnixSocket>(
+      new UnixSocket(event_listener, task_runner));
+  listener->DoListen(socket_name);
+  return listener;
+}
+
+// static
+// Factory: builds a fresh socket and starts connecting it to |socket_name|.
+// The outcome is always reported later through EventListener::OnConnect().
+std::unique_ptr<UnixSocket> UnixSocket::Connect(const std::string& socket_name,
+                                                EventListener* event_listener,
+                                                base::TaskRunner* task_runner) {
+  auto client = std::unique_ptr<UnixSocket>(
+      new UnixSocket(event_listener, task_runner));
+  client->DoConnect(socket_name);
+  return client;
+}
+
+UnixSocket::UnixSocket(EventListener* event_listener,
+ base::TaskRunner* task_runner)
+ : UnixSocket(event_listener, task_runner, base::ScopedFile()) {}
+
+// Either adopts |adopt_fd| (an already-connected socket handed out by
+// OnEvent() on accept()) or creates a new AF_UNIX SOCK_STREAM socket.
+// On success the FD is made close-on-exec and non-blocking and is watched on
+// |task_runner|. If socket() fails, last_error_ holds errno and fd_ stays
+// empty; DoListen()/DoConnect() report that gracefully.
+UnixSocket::UnixSocket(EventListener* event_listener,
+                       base::TaskRunner* task_runner,
+                       base::ScopedFile adopt_fd)
+    : event_listener_(event_listener),
+      task_runner_(task_runner),
+      weak_ptr_factory_(this) {
+  if (adopt_fd) {
+    // Only in the case of OnNewIncomingConnection().
+    fd_ = std::move(adopt_fd);
+    state_ = State::kConnected;
+    ReadPeerCredentials();
+  } else {
+    fd_.reset(socket(AF_UNIX, SOCK_STREAM, 0));
+  }
+  if (!fd_) {
+    last_error_ = errno;
+    return;
+  }
+
+#if BUILDFLAG(OS_MACOSX)
+  const int no_sigpipe = 1;
+  setsockopt(*fd_, SOL_SOCKET, SO_NOSIGPIPE, &no_sigpipe, sizeof(no_sigpipe));
+#endif
+  // There is no reason why a socket should outlive the process in case of
+  // exec() by default, this is just working around a broken unix design.
+  int fcntl_res = fcntl(*fd_, F_SETFD, FD_CLOEXEC);
+  PERFETTO_CHECK(fcntl_res == 0);
+
+  // Set non-blocking mode. The F_GETFL result must be checked: OR-ing
+  // O_NONBLOCK into a -1 error return would hand an all-bits-set flag word
+  // to F_SETFL.
+  const int flags = fcntl(*fd_, F_GETFL, 0);
+  PERFETTO_CHECK(flags != -1);
+  fcntl_res = fcntl(*fd_, F_SETFL, flags | O_NONBLOCK);
+  PERFETTO_CHECK(fcntl_res == 0);
+
+  base::WeakPtr<UnixSocket> weak_ptr = weak_ptr_factory_.GetWeakPtr();
+  task_runner_->AddFileDescriptorWatch(*fd_, [weak_ptr]() {
+    if (weak_ptr)
+      weak_ptr->OnEvent();
+  });
+}
+
+UnixSocket::~UnixSocket() {
+  // Close the FD and stop watching it. Whatever Shutdown() posts is guarded
+  // by a weak pointer; destroying |weak_ptr_factory_| right after this body
+  // turns those pending callbacks into no-ops.
+  Shutdown();
+}
+
+// Called only by the Listen() static constructor.
+// Binds the socket to |socket_name| and starts listening. On failure the
+// state stays kDisconnected, last_error_ holds the reason, and the FD and its
+// task-runner watch are released. A socket that is neither listening nor
+// connected cannot be reused, and a watched unconnected stream socket may
+// keep waking the task runner (e.g. via POLLHUP). Shutdown() posts no
+// callback from kDisconnected, so this cleanup is silent.
+void UnixSocket::DoListen(const std::string& socket_name) {
+  PERFETTO_DCHECK(state_ == State::kDisconnected);
+  if (!fd_)
+    return;  // This is the only thing that can gracefully fail in the ctor.
+
+  sockaddr_un addr;
+  socklen_t addr_size;
+  if (!MakeSockAddr(socket_name, &addr, &addr_size)) {
+    last_error_ = errno;
+    Shutdown();
+    return;
+  }
+
+// Android takes an int as 3rd argument of bind() instead of socklen_t.
+#if BUILDFLAG(OS_ANDROID)
+  const int bind_size = static_cast<int>(addr_size);
+#else
+  const socklen_t bind_size = addr_size;
+#endif
+
+  // errno is logged before Shutdown() since the latter issues syscalls.
+  if (bind(*fd_, reinterpret_cast<sockaddr*>(&addr), bind_size)) {
+    last_error_ = errno;
+    PERFETTO_DPLOG("bind()");
+    Shutdown();
+    return;
+  }
+  if (listen(*fd_, SOMAXCONN)) {
+    last_error_ = errno;
+    PERFETTO_DPLOG("listen()");
+    Shutdown();
+    return;
+  }
+
+  last_error_ = 0;
+  state_ = State::kListening;
+}
+
+// Called only by the Connect() static constructor.
+// Starts a non-blocking connect() to |socket_name|. Every outcome ends in an
+// EventListener::OnConnect() call: immediate failures are posted through
+// NotifyConnectionState(false); otherwise the socket moves to kConnecting and
+// OnEvent() settles the result.
+void UnixSocket::DoConnect(const std::string& socket_name) {
+  PERFETTO_DCHECK(state_ == State::kDisconnected);
+
+  // No FD means socket() failed in the ctor (last_error_ is already set).
+  // That is the only ctor failure reported gracefully.
+  if (!fd_) {
+    NotifyConnectionState(false);
+    return;
+  }
+
+  sockaddr_un addr;
+  socklen_t addr_size;
+  if (!MakeSockAddr(socket_name, &addr, &addr_size)) {
+    last_error_ = errno;
+    NotifyConnectionState(false);
+    return;
+  }
+
+  const int connect_res = PERFETTO_EINTR(
+      connect(*fd_, reinterpret_cast<sockaddr*>(&addr), addr_size));
+  const bool connected_synchronously = connect_res == 0;
+  if (!connected_synchronously && errno != EINPROGRESS) {
+    last_error_ = errno;
+    NotifyConnectionState(false);
+    return;
+  }
+
+  // The connection either completed already or is in flight (EINPROGRESS).
+  last_error_ = 0;
+  state_ = State::kConnecting;
+
+  // A UNIX socket connect() may be acknowledged straight away even in
+  // non-blocking mode. The FD watch might not fire in that case, so schedule
+  // an OnEvent() ourselves; it reads SO_ERROR and moves the state to
+  // kConnected or kDisconnected.
+  if (!connected_synchronously)
+    return;
+  base::WeakPtr<UnixSocket> weak_self = weak_ptr_factory_.GetWeakPtr();
+  task_runner_->PostTask([weak_self]() {
+    if (weak_self)
+      weak_self->OnEvent();
+  });
+}
+
+// Stores the peer's user id in |peer_uid_|. Failure to read the credentials
+// is fatal (PERFETTO_CHECK).
+void UnixSocket::ReadPeerCredentials() {
+#if BUILDFLAG(OS_LINUX) || BUILDFLAG(OS_ANDROID)
+  // Linux / Android: SO_PEERCRED fills a struct ucred.
+  struct ucred peer_cred;
+  socklen_t cred_len = sizeof(peer_cred);
+  const int rv =
+      getsockopt(*fd_, SOL_SOCKET, SO_PEERCRED, &peer_cred, &cred_len);
+  PERFETTO_CHECK(rv == 0);
+  peer_uid_ = peer_cred.uid;
+#else
+  // Other platforms (Mac OS X, which includes <sys/ucred.h> above):
+  // LOCAL_PEERCRED fills a struct xucred whose version must match.
+  struct xucred peer_cred;
+  socklen_t cred_len = sizeof(peer_cred);
+  const int rv = getsockopt(*fd_, 0, LOCAL_PEERCRED, &peer_cred, &cred_len);
+  PERFETTO_CHECK(rv == 0 && peer_cred.cr_version == XUCRED_VERSION);
+  peer_uid_ = peer_cred.cr_uid;
+#endif
+}
+
+// Dispatches an FD-watch wakeup (or a manually posted one) according to the
+// current state:
+//  - kConnected: forwards to EventListener::OnDataAvailable().
+//  - kConnecting: reads SO_ERROR to settle the pending connect() and reports
+//    it with OnConnect(true/false).
+//  - kListening: accept()s every pending connection and hands each one to
+//    OnNewIncomingConnection().
+//  - kDisconnected: ignored.
+void UnixSocket::OnEvent() {
+  if (state_ == State::kDisconnected)
+    return;  // Some spurious event, typically queued just before Shutdown().
+
+  if (state_ == State::kConnected)
+    return event_listener_->OnDataAvailable(this);
+
+  if (state_ == State::kConnecting) {
+    PERFETTO_DCHECK(fd_);
+    int sock_err = EINVAL;
+    socklen_t err_len = sizeof(sock_err);
+    int res = getsockopt(*fd_, SOL_SOCKET, SO_ERROR, &sock_err, &err_len);
+    if (res == 0 && sock_err == EINPROGRESS)
+      return;  // Not connected yet, just a spurious FD watch wakeup.
+    if (res == 0 && sock_err == 0) {
+      ReadPeerCredentials();
+      state_ = State::kConnected;
+      return event_listener_->OnConnect(this, true /* connected */);
+    }
+    // The connection attempt failed. Record the reason first (errno if
+    // getsockopt() itself failed), because Shutdown() issues syscalls that may
+    // clobber errno. Then leave kConnecting and release the FD and its watch;
+    // otherwise every later wakeup would re-enter this branch and report
+    // OnConnect(false) again. The state is reset *before* Shutdown() so that
+    // Shutdown() does not post an extra OnConnect(false) of its own.
+    last_error_ = res == 0 ? sock_err : errno;
+    state_ = State::kDisconnected;
+    Shutdown();
+    return event_listener_->OnConnect(this, false /* connected */);
+  }
+
+  // New incoming connection.
+  if (state_ == State::kListening) {
+    // There could be more than one incoming connection behind each FD watch
+    // notification. Drain'em all.
+    for (;;) {
+      sockaddr_un cli_addr = {};
+      socklen_t size = sizeof(cli_addr);
+      base::ScopedFile new_fd(PERFETTO_EINTR(
+          accept(*fd_, reinterpret_cast<sockaddr*>(&cli_addr), &size)));
+      if (!new_fd)
+        return;
+      std::unique_ptr<UnixSocket> new_sock(
+          new UnixSocket(event_listener_, task_runner_, std::move(new_fd)));
+      event_listener_->OnNewIncomingConnection(this, std::move(new_sock));
+    }
+  }
+}
+
+// Sends the characters of |msg| followed by its '\0' terminator
+// (msg.size() + 1 bytes), so the peer receives a C string.
+bool UnixSocket::Send(const std::string& msg) {
+  const size_t len_with_terminator = msg.size() + 1;
+  return Send(msg.data(), len_with_terminator);
+}
+
+// Sends |len| bytes from |msg| and, if |send_fd| > -1, attaches that FD as
+// SCM_RIGHTS ancillary data. Returns true only when the whole message was
+// queued. On EAGAIN/EWOULDBLOCK returns false with last_error() == EAGAIN and
+// the socket stays connected. On any other error the socket is Shutdown()
+// (which posts OnDisconnect()). If not connected, returns false and sets both
+// errno and last_error() to ENOTCONN.
+bool UnixSocket::Send(const void* msg, size_t len, int send_fd) {
+ if (state_ != State::kConnected) {
+ errno = last_error_ = ENOTCONN;
+ return false;
+ }
+
+ msghdr msg_hdr = {};
+ iovec iov = {const_cast<void*>(msg), len};
+ msg_hdr.msg_iov = &iov;
+ msg_hdr.msg_iovlen = 1;
+ // Storage for the optional SCM_RIGHTS header, aligned as cmsghdr requires.
+ alignas(cmsghdr) char control_buf[256];
+
+ if (send_fd > -1) {
+ // Room for one cmsghdr carrying exactly one int (the FD), padding included.
+ const CBufLenType control_buf_len =
+ static_cast<CBufLenType>(CMSG_SPACE(sizeof(int)));
+ PERFETTO_CHECK(control_buf_len <= sizeof(control_buf));
+ memset(control_buf, 0, sizeof(control_buf));
+ msg_hdr.msg_control = control_buf;
+ msg_hdr.msg_controllen = control_buf_len;
+ struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg_hdr);
+ cmsg->cmsg_level = SOL_SOCKET;
+ cmsg->cmsg_type = SCM_RIGHTS;
+ cmsg->cmsg_len = CMSG_LEN(sizeof(int));
+ memcpy(CMSG_DATA(cmsg), &send_fd, sizeof(int));
+ // Shrink to the exact length of the single control message written above.
+ msg_hdr.msg_controllen = cmsg->cmsg_len;
+ }
+
+ const ssize_t sz = PERFETTO_EINTR(sendmsg(*fd_, &msg_hdr, kNoSigPipe));
+ if (sz >= 0) {
+ // There should be no way a non-blocking socket returns < |len|.
+ // If the queueing fails, sendmsg() must return -1 + errno = EWOULDBLOCK.
+ // NOTE(review): stream sockets are generally allowed to accept a partial
+ // write; if that can happen for large messages on the target kernels this
+ // CHECK would crash instead of retrying. Confirm before relying on sizes
+ // close to the socket buffer.
+ PERFETTO_CHECK(static_cast<size_t>(sz) == len);
+ last_error_ = 0;
+ return true;
+ }
+
+ if (errno == EAGAIN || errno == EWOULDBLOCK) {
+ // A genuine out-of-buffer. The client should retry or give up.
+ // Man pages specify that EAGAIN and EWOULDBLOCK have the same semantic here
+ // and clients should check for both.
+ last_error_ = EAGAIN;
+ return false;
+ }
+
+ // Either the other endpoint disconnected (ECONNRESET) or some other error
+ // happened. Either way the connection is torn down.
+ last_error_ = errno;
+ PERFETTO_DPLOG("sendmsg() failed");
+ Shutdown();
+ return false;
+}
+
+// Tears down the socket. A connected socket gets a posted OnDisconnect() and
+// a connecting one a posted OnConnect(false); listening/disconnected sockets
+// get no callback. The FD is then shut down, unwatched and closed, and the
+// state becomes kDisconnected.
+void UnixSocket::Shutdown() {
+  base::WeakPtr<UnixSocket> weak_self = weak_ptr_factory_.GetWeakPtr();
+
+  // The weak pointer makes each posted notification a no-op if this object is
+  // destroyed before the task runs.
+  switch (state_) {
+    case State::kConnected:
+      task_runner_->PostTask([weak_self]() {
+        if (weak_self)
+          weak_self->event_listener_->OnDisconnect(weak_self.get());
+      });
+      break;
+    case State::kConnecting:
+      task_runner_->PostTask([weak_self]() {
+        if (weak_self)
+          weak_self->event_listener_->OnConnect(weak_self.get(), false);
+      });
+      break;
+    case State::kDisconnected:
+    case State::kListening:
+      break;
+  }
+
+  if (fd_) {
+    const int raw_fd = *fd_;
+    shutdown(raw_fd, SHUT_RDWR);
+    task_runner_->RemoveFileDescriptorWatch(raw_fd);
+    fd_.reset();
+  }
+  state_ = State::kDisconnected;
+}
+
+// Reads up to |len| bytes into |msg| and returns how many were read, or 0 if
+// nothing was available (last_error() == EAGAIN) or on error/EOF, in which
+// case the socket is Shutdown(). A FD carried by the message is moved into
+// |*recv_fd| when |recv_fd| is non-null, and closed otherwise.
+size_t UnixSocket::Receive(void* msg, size_t len, base::ScopedFile* recv_fd) {
+  if (state_ != State::kConnected) {
+    last_error_ = ENOTCONN;
+    return 0;
+  }
+
+  msghdr msg_hdr = {};
+  iovec iov = {msg, len};
+  msg_hdr.msg_iov = &iov;
+  msg_hdr.msg_iovlen = 1;
+  alignas(cmsghdr) char control_buf[256];
+
+  // Always offer room for one FD, even when |recv_fd| is null. Without a
+  // control buffer the kernel discards any attached FD and sets MSG_CTRUNC,
+  // which the check below treats as fatal, shutting the socket down. With the
+  // buffer, an FD the caller did not ask for is received and then closed by
+  // the loop at the end of this function.
+  msg_hdr.msg_control = control_buf;
+  msg_hdr.msg_controllen = static_cast<CBufLenType>(CMSG_SPACE(sizeof(int)));
+  PERFETTO_CHECK(msg_hdr.msg_controllen <= sizeof(control_buf));
+
+  const ssize_t sz = PERFETTO_EINTR(recvmsg(*fd_, &msg_hdr, kNoSigPipe));
+  if (sz < 0 && (errno == EAGAIN || errno == EWOULDBLOCK)) {
+    last_error_ = EAGAIN;
+    return 0;
+  }
+  if (sz <= 0) {
+    // Either an error (sz < 0) or EOF from the peer (sz == 0).
+    last_error_ = errno;
+    Shutdown();
+    return 0;
+  }
+  PERFETTO_CHECK(static_cast<size_t>(sz) <= len);
+
+  // Locate the FDs (if any) carried by an SCM_RIGHTS control message.
+  int* fds = nullptr;
+  uint32_t fds_len = 0;
+
+  if (msg_hdr.msg_controllen > 0) {
+    for (cmsghdr* cmsg = CMSG_FIRSTHDR(&msg_hdr); cmsg;
+         cmsg = CMSG_NXTHDR(&msg_hdr, cmsg)) {
+      const size_t payload_len = cmsg->cmsg_len - CMSG_LEN(0);
+      if (cmsg->cmsg_level == SOL_SOCKET && cmsg->cmsg_type == SCM_RIGHTS) {
+        PERFETTO_DCHECK(payload_len % sizeof(int) == 0u);
+        PERFETTO_DCHECK(fds == nullptr);
+        fds = reinterpret_cast<int*>(CMSG_DATA(cmsg));
+        fds_len = static_cast<uint32_t>(payload_len / sizeof(int));
+      }
+    }
+  }
+
+  // Truncated payload or control data: close whatever FDs did arrive and give
+  // up on the connection.
+  if (msg_hdr.msg_flags & MSG_TRUNC || msg_hdr.msg_flags & MSG_CTRUNC) {
+    for (size_t i = 0; fds && i < fds_len; ++i)
+      close(fds[i]);
+    last_error_ = EMSGSIZE;
+    Shutdown();
+    return 0;
+  }
+
+  // Hand the first FD to the caller if it asked for one; close all the rest.
+  for (size_t i = 0; fds && i < fds_len; ++i) {
+    if (recv_fd && i == 0) {
+      recv_fd->reset(fds[i]);
+    } else {
+      close(fds[i]);
+    }
+  }
+
+  last_error_ = 0;
+  return static_cast<size_t>(sz);
+}
+
+// Test helper: receives up to |max_length| bytes and returns them as a string
+// cut at the first '\0'. The buffer has one spare byte so the result is always
+// terminated, even if the peer sent no terminator.
+std::string UnixSocket::ReceiveString(size_t max_length) {
+  std::unique_ptr<char[]> chars(new char[max_length + 1]);
+  const size_t received = Receive(chars.get(), max_length);
+  PERFETTO_CHECK(received <= max_length);
+  chars[received] = '\0';
+  return std::string(chars.get());
+}
+
+// Posts EventListener::OnConnect(|success|) on the task runner; dropped if
+// this object is destroyed first. Used by DoConnect() to report outcomes
+// decided synchronously.
+// On failure the FD and its watch are released right away: a socket whose
+// connect() failed can never become usable, and keeping the FD watched would
+// only produce useless wakeups. The state is forced to kDisconnected first so
+// that Shutdown() does not post a second OnConnect(false) of its own.
+// Shutdown() leaves last_error_ untouched.
+void UnixSocket::NotifyConnectionState(bool success) {
+  if (!success) {
+    state_ = State::kDisconnected;
+    Shutdown();
+  }
+  base::WeakPtr<UnixSocket> weak_ptr = weak_ptr_factory_.GetWeakPtr();
+  task_runner_->PostTask([weak_ptr, success]() {
+    if (weak_ptr)
+      weak_ptr->event_listener_->OnConnect(weak_ptr.get(), success);
+  });
+}
+
+// Default no-op implementations, so listeners override only what they need.
+// The default OnNewIncomingConnection() simply drops (destroys) the socket.
+UnixSocket::EventListener::~EventListener() = default;
+
+void UnixSocket::EventListener::OnNewIncomingConnection(
+    UnixSocket* /*self*/,
+    std::unique_ptr<UnixSocket> /*new_connection*/) {}
+
+void UnixSocket::EventListener::OnConnect(UnixSocket* /*self*/,
+                                          bool /*connected*/) {}
+
+void UnixSocket::EventListener::OnDisconnect(UnixSocket* /*self*/) {}
+
+void UnixSocket::EventListener::OnDataAvailable(UnixSocket* /*self*/) {}
+
+} // namespace ipc
+} // namespace perfetto
diff --git a/src/ipc/unix_socket.h b/src/ipc/unix_socket.h
new file mode 100644
index 0000000..807101f
--- /dev/null
+++ b/src/ipc/unix_socket.h
@@ -0,0 +1,196 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef SRC_IPC_UNIX_SOCKET_H_
+#define SRC_IPC_UNIX_SOCKET_H_
+
+#include <stdint.h>
+#include <sys/types.h>
+
+#include <memory>
+#include <string>
+
+#include "perfetto/base/logging.h"
+#include "perfetto/base/scoped_file.h"
+#include "perfetto/base/weak_ptr.h"
+
+namespace perfetto {
+
+namespace base {
+class TaskRunner;
+} // namespace base.
+
+namespace ipc {
+
+// A non-blocking UNIX domain socket in SOCK_STREAM mode that can also transfer
+// file descriptors. None of the methods in this class block.
+// The main design goal is API simplicity and strong guarantees on the
+// EventListener callbacks, in order to avoid ending in some undefined state.
+// On any error the socket is aggressively shut down and the failure is
+// notified with OnConnect(false) or OnDisconnect(), depending on the state of
+// the socket (see below).
+// EventListener callbacks stop happening as soon as the instance is destroyed.
+//
+// Lifecycle of a client socket:
+//
+//                           Connect()
+//                               |
+//            +------------------+------------------+
+//            | (success)                           | (failure or Shutdown())
+//            V                                     V
+//     OnConnect(true)                       OnConnect(false)
+//            |
+//            V
+//     OnDataAvailable()
+//            |
+//            V
+//     OnDisconnect()  (failure or shutdown)
+//
+//
+// Lifecycle of a server socket:
+//
+//    Listen()  --> always returns an instance; on failure is_listening() is
+//       |          false and last_error() holds the reason.
+//       V
+//    OnNewIncomingConnection(new_socket)
+//
+//          (|new_socket| inherits the same EventListener)
+//                    |
+//                    V
+//             OnDataAvailable()
+//                    | (failure or Shutdown())
+//                    V
+//             OnDisconnect()
+class UnixSocket {
+ public:
+ // Receives the notifications of a UnixSocket. Every method has a no-op
+ // default, so implementations override only what they need.
+ class EventListener {
+ public:
+ virtual ~EventListener();
+
+ // After Listen().
+ virtual void OnNewIncomingConnection(
+ UnixSocket* self,
+ std::unique_ptr<UnixSocket> new_connection);
+
+ // After Connect(), whether successful or not.
+ virtual void OnConnect(UnixSocket* self, bool connected);
+
+ // After a successful Connect() or OnNewIncomingConnection(). Either the
+ // other endpoint did disconnect or some other error happened.
+ virtual void OnDisconnect(UnixSocket* self);
+
+ // Whenever there is data available to Receive(). Note that spurious FD
+ // watch events are possible, so it is possible that Receive() soon after
+ // OnDataAvailable() returns 0 (just ignore those).
+ virtual void OnDataAvailable(UnixSocket* self);
+ };
+
+ enum class State {
+ kDisconnected = 0, // Failed connection, peer disconnection or Shutdown().
+ kConnecting, // Soon after Connect(), before it either succeeds or fails.
+ kConnected, // After a successful Connect(), or accepted via Listen().
+ kListening // After Listen(), until Shutdown().
+ };
+
+ // Creates a Unix domain socket and starts listening. If |socket_name|
+ // starts with a '@', an abstract socket will be created (Linux/Android only).
+ // Returns always an instance. In case of failure (e.g., another socket
+ // with the same name is already listening) the returned socket will have
+ // is_listening() == false and last_error() will contain the failure reason.
+ static std::unique_ptr<UnixSocket> Listen(const std::string& socket_name,
+ EventListener*,
+ base::TaskRunner*);
+
+ // Creates a Unix domain socket and connects to the listening endpoint.
+ // Returns always an instance. EventListener::OnConnect(bool success) will
+ // be called always, whether the connection succeeded or not.
+ static std::unique_ptr<UnixSocket> Connect(const std::string& socket_name,
+ EventListener*,
+ base::TaskRunner*);
+
+ // This class gives the hard guarantee that no callback is invoked on the
+ // passed EventListener after the object has been destroyed.
+ // Any queued callback will be silently dropped.
+ ~UnixSocket();
+
+ // Shuts down the current connection, if any. If the socket was Listen()-ing,
+ // stops listening. A connected socket then gets an OnDisconnect() and a
+ // connecting one an OnConnect(false). The FD is closed and the socket goes
+ // to the kDisconnected state. Listen() and Connect() are static factories,
+ // so a shut-down instance cannot be reused: create a new one instead.
+ void Shutdown();
+
+ // Returns true if the message was queued, false if there was no space in the
+ // output buffer, in which case the client should retry or give up.
+ // If any other error happens the socket will be shutdown and
+ // EventListener::OnDisconnect() will be called.
+ // If the socket is not connected, Send() will just return false.
+ // The (msg, len) overload sends exactly |len| bytes and never appends a
+ // terminator. The std::string overload sends the characters followed by the
+ // string's '\0' terminator (size() + 1 bytes).
+ bool Send(const void* msg, size_t len, int send_fd = -1);
+ bool Send(const std::string& msg);
+
+ // Returns the number of bytes (<= |len|) written in |msg| or 0 if there
+ // is no data in the buffer to read or an error occurs (in which case a
+ // EventListener::OnDisconnect() will follow).
+ // If the ScopedFile pointer is not null and a FD is received, it moves the
+ // received FD into that. If a FD is received but the ScopedFile pointer is
+ // null, the FD will be automatically closed.
+ size_t Receive(void* msg, size_t len, base::ScopedFile* = nullptr);
+
+ // Only for tests. This is slower than Receive() as it requires a heap
+ // allocation and a copy for the std::string. Guarantees that the returned
+ // string is null terminated even if the underlying message sent by the peer
+ // is not.
+ std::string ReceiveString(size_t max_length = 1024);
+
+ bool is_connected() const { return state_ == State::kConnected; }
+ bool is_listening() const { return state_ == State::kListening; }
+ int fd() const { return fd_.get(); }
+ // errno-style code of the last failed operation; 0 after a success.
+ int last_error() const { return last_error_; }
+
+ // User ID of the peer, as returned by the kernel. If the client disconnects
+ // and the socket goes into the kDisconnected state, it retains the uid of
+ // the last peer.
+ int peer_uid() const {
+ PERFETTO_DCHECK(!is_listening() && peer_uid_ >= 0);
+ return peer_uid_;
+ }
+
+ private:
+ UnixSocket(EventListener*, base::TaskRunner*);
+ // Adopts an already-connected FD (used for accepted connections).
+ UnixSocket(EventListener*, base::TaskRunner*, base::ScopedFile);
+ UnixSocket(const UnixSocket&) = delete;
+ UnixSocket& operator=(const UnixSocket&) = delete;
+
+ // Called once by the corresponding public static factory methods.
+ void DoConnect(const std::string& socket_name);
+ void DoListen(const std::string& socket_name);
+ void ReadPeerCredentials();
+
+ void OnEvent();
+ void NotifyConnectionState(bool success);
+
+ base::ScopedFile fd_;
+ State state_ = State::kDisconnected;
+ int last_error_ = 0;
+ int peer_uid_ = -1;
+ EventListener* event_listener_;
+ base::TaskRunner* task_runner_;
+ // Declared last so it is destroyed first, invalidating the weak pointers
+ // held by pending callbacks before the other members go away.
+ base::WeakPtrFactory<UnixSocket> weak_ptr_factory_;
+};
+
+} // namespace ipc
+} // namespace perfetto
+
+#endif // SRC_IPC_UNIX_SOCKET_H_
diff --git a/src/ipc/unix_socket_unittest.cc b/src/ipc/unix_socket_unittest.cc
new file mode 100644
index 0000000..e5586a8
--- /dev/null
+++ b/src/ipc/unix_socket_unittest.cc
@@ -0,0 +1,431 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "src/ipc/unix_socket.h"
+
+#include <sys/mman.h>
+
+#include <list>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "perfetto/base/build_config.h"
+#include "perfetto/base/logging.h"
+#include "perfetto/base/utils.h"
+#include "src/base/test/test_task_runner.h"
+
+namespace perfetto {
+namespace ipc {
+namespace {
+
+using ::testing::_;
+using ::testing::Invoke;
+using ::testing::Mock;
+
+// Mac OS X doesn't support abstract-namespace sockets (names starting with
+// '@'), so a filesystem path is used there. It is unlinked before and after
+// each test so a stale file cannot make Listen() fail. With the abstract name
+// used elsewhere no file is created, hence UnlinkSocket() is a no-op.
+#if BUILDFLAG(OS_MACOSX)
+static const char kSocketName[] = "/tmp/test_socket";
+void UnlinkSocket() {
+ unlink(kSocketName);
+}
+#else
+static const char kSocketName[] = "@test_socket";
+void UnlinkSocket() {}
+#endif
+
+// gMock listener used by all tests. Accepted connections are kept alive in
+// |incoming_connections_| and exposed to expectations as raw pointers.
+class MockEventListener : public UnixSocket::EventListener {
+ public:
+ // Mocked overload taking a raw pointer; it is called by the real
+ // unique_ptr override below.
+ MOCK_METHOD2(OnNewIncomingConnection, void(UnixSocket*, UnixSocket*));
+ MOCK_METHOD2(OnConnect, void(UnixSocket*, bool));
+ MOCK_METHOD1(OnDisconnect, void(UnixSocket*));
+ MOCK_METHOD1(OnDataAvailable, void(UnixSocket*));
+
+ // GMock doesn't support mocking methods with non-copiable args.
+ // Takes ownership of the new socket (so it stays open), then forwards the
+ // raw pointer to the mocked overload above.
+ void OnNewIncomingConnection(
+ UnixSocket* self,
+ std::unique_ptr<UnixSocket> new_connection) override {
+ incoming_connections_.emplace_back(std::move(new_connection));
+ OnNewIncomingConnection(self, incoming_connections_.back().get());
+ }
+
+ // Transfers ownership of the oldest accepted connection to the caller, or
+ // returns nullptr if there is none.
+ std::unique_ptr<UnixSocket> GetIncomingConnection() {
+ if (incoming_connections_.empty())
+ return nullptr;
+ std::unique_ptr<UnixSocket> sock = std::move(incoming_connections_.front());
+ incoming_connections_.pop_front();
+ return sock;
+ }
+
+ private:
+ std::list<std::unique_ptr<UnixSocket>> incoming_connections_;
+};
+
+// Fixture: a test task runner plus a shared mock listener. The socket name is
+// unlinked around each test (a no-op for abstract sockets).
+class UnixSocketTest : public ::testing::Test {
+ protected:
+ void SetUp() override { UnlinkSocket(); }
+ void TearDown() override { UnlinkSocket(); }
+
+ // Declaration order matters: |event_listener_| (and the accepted sockets it
+ // owns) is destroyed before |task_runner_|.
+ base::TestTaskRunner task_runner_;
+ MockEventListener event_listener_;
+};
+
+// Nobody listens on kSocketName, so Connect() must fail and the failure must
+// be reported asynchronously through OnConnect(false).
+TEST_F(UnixSocketTest, ConnectionFailureIfUnreachable) {
+  auto client =
+      UnixSocket::Connect(kSocketName, &event_listener_, &task_runner_);
+  ASSERT_FALSE(client->is_connected());
+  auto on_failure = task_runner_.CreateCheckpoint("failure");
+  EXPECT_CALL(event_listener_, OnConnect(client.get(), false))
+      .WillOnce(Invoke([on_failure](UnixSocket*, bool) { on_failure(); }));
+  task_runner_.RunUntilCheckpoint("failure");
+}
+
+// Both server and client should see an OnDisconnect() if the server drops
+// incoming connections immediately as they are created.
+TEST_F(UnixSocketTest, ConnectionImmediatelyDroppedByServer) {
+ auto srv = UnixSocket::Listen(kSocketName, &event_listener_, &task_runner_);
+ ASSERT_TRUE(srv->is_listening());
+
+ // The server will immediately shutdown the connection upon
+ // OnNewIncomingConnection().
+ // Shutdown() of the connected server-side socket posts OnDisconnect() for it.
+ auto srv_did_shutdown = task_runner_.CreateCheckpoint("srv_did_shutdown");
+ EXPECT_CALL(event_listener_, OnNewIncomingConnection(srv.get(), _))
+ .WillOnce(
+ Invoke([this, srv_did_shutdown](UnixSocket*, UnixSocket* new_conn) {
+ EXPECT_CALL(event_listener_, OnDisconnect(new_conn));
+ new_conn->Shutdown();
+ srv_did_shutdown();
+ }));
+
+ auto checkpoint = task_runner_.CreateCheckpoint("cli_connected");
+ auto cli = UnixSocket::Connect(kSocketName, &event_listener_, &task_runner_);
+ EXPECT_CALL(event_listener_, OnConnect(cli.get(), true))
+ .WillOnce(Invoke([checkpoint](UnixSocket*, bool) { checkpoint(); }));
+ task_runner_.RunUntilCheckpoint("cli_connected");
+ task_runner_.RunUntilCheckpoint("srv_did_shutdown");
+
+ // Trying to send something will trigger the disconnection notification.
+ // Send() fails because the peer is gone; its error path calls Shutdown(),
+ // which posts OnDisconnect() for the client.
+ auto cli_disconnected = task_runner_.CreateCheckpoint("cli_disconnected");
+ EXPECT_CALL(event_listener_, OnDisconnect(cli.get()))
+ .WillOnce(
+ Invoke([cli_disconnected](UnixSocket*) { cli_disconnected(); }));
+ EXPECT_FALSE(cli->Send("whatever"));
+ task_runner_.RunUntilCheckpoint("cli_disconnected");
+}
+
+// Connects one client, exchanges one message in each direction, then checks
+// that Send()/Receive() fail gracefully after the client shuts down and that
+// both ends get OnDisconnect().
+TEST_F(UnixSocketTest, ClientAndServerExchangeData) {
+ auto srv = UnixSocket::Listen(kSocketName, &event_listener_, &task_runner_);
+ ASSERT_TRUE(srv->is_listening());
+
+ auto cli = UnixSocket::Connect(kSocketName, &event_listener_, &task_runner_);
+ EXPECT_CALL(event_listener_, OnConnect(cli.get(), true));
+ auto cli_connected = task_runner_.CreateCheckpoint("cli_connected");
+ auto srv_disconnected = task_runner_.CreateCheckpoint("srv_disconnected");
+ // The server-side OnDisconnect() expectation is registered as soon as the
+ // accepted socket exists; it is satisfied at the very end of the test.
+ EXPECT_CALL(event_listener_, OnNewIncomingConnection(srv.get(), _))
+ .WillOnce(Invoke([this, cli_connected, srv_disconnected](
+ UnixSocket*, UnixSocket* srv_conn) {
+ EXPECT_CALL(event_listener_, OnDisconnect(srv_conn))
+ .WillOnce(Invoke(
+ [srv_disconnected](UnixSocket*) { srv_disconnected(); }));
+ cli_connected();
+ }));
+ task_runner_.RunUntilCheckpoint("cli_connected");
+
+ // Take ownership of the accepted server-side socket from the mock.
+ auto srv_conn = event_listener_.GetIncomingConnection();
+ ASSERT_TRUE(srv_conn);
+ ASSERT_TRUE(cli->is_connected());
+
+ auto cli_did_recv = task_runner_.CreateCheckpoint("cli_did_recv");
+ EXPECT_CALL(event_listener_, OnDataAvailable(cli.get()))
+ .WillOnce(Invoke([cli_did_recv](UnixSocket* s) {
+ ASSERT_EQ("srv>cli", s->ReceiveString());
+ cli_did_recv();
+ }));
+
+ auto srv_did_recv = task_runner_.CreateCheckpoint("srv_did_recv");
+ EXPECT_CALL(event_listener_, OnDataAvailable(srv_conn.get()))
+ .WillOnce(Invoke([srv_did_recv](UnixSocket* s) {
+ ASSERT_EQ("cli>srv", s->ReceiveString());
+ srv_did_recv();
+ }));
+ ASSERT_TRUE(cli->Send("cli>srv"));
+ ASSERT_TRUE(srv_conn->Send("srv>cli"));
+ task_runner_.RunUntilCheckpoint("cli_did_recv");
+ task_runner_.RunUntilCheckpoint("srv_did_recv");
+
+ // Check that Send/Receive() fails gracefully once the socket is closed.
+ // After cli->Shutdown() the client is kDisconnected, so its calls fail with
+ // ENOTCONN. The server side is expected to read EOF, shut itself down
+ // (posting its OnDisconnect()) and then fail as well.
+ auto cli_disconnected = task_runner_.CreateCheckpoint("cli_disconnected");
+ EXPECT_CALL(event_listener_, OnDisconnect(cli.get()))
+ .WillOnce(
+ Invoke([cli_disconnected](UnixSocket*) { cli_disconnected(); }));
+ cli->Shutdown();
+ char msg[4];
+ ASSERT_EQ(0u, cli->Receive(&msg, sizeof(msg)));
+ ASSERT_EQ("", cli->ReceiveString());
+ ASSERT_EQ(0u, srv_conn->Receive(&msg, sizeof(msg)));
+ ASSERT_EQ("", srv_conn->ReceiveString());
+ ASSERT_FALSE(cli->Send("foo"));
+ ASSERT_FALSE(srv_conn->Send("bar"));
+ srv->Shutdown();
+ task_runner_.RunUntilCheckpoint("cli_disconnected");
+ task_runner_.RunUntilCheckpoint("srv_disconnected");
+}
+
+// Mostly a stress test. Connects kNumClients clients to the same server and
+// tests that all can exchange data and can see the expected sequence of events.
+TEST_F(UnixSocketTest, SeveralClients) {
+  auto srv = UnixSocket::Listen(kSocketName, &event_listener_, &task_runner_);
+  ASSERT_TRUE(srv->is_listening());
+  constexpr size_t kNumClients = 32;
+  std::unique_ptr<UnixSocket> cli[kNumClients];
+
+  // Every accepted connection expects exactly one PING and answers PONG.
+  EXPECT_CALL(event_listener_, OnNewIncomingConnection(srv.get(), _))
+      .Times(kNumClients)
+      .WillRepeatedly(Invoke([this](UnixSocket*, UnixSocket* s) {
+        EXPECT_CALL(event_listener_, OnDataAvailable(s))
+            .WillOnce(Invoke([](UnixSocket* t) {
+              ASSERT_EQ("PING", t->ReceiveString());
+              ASSERT_TRUE(t->Send("PONG"));
+            }));
+      }));
+
+  // Every client sends PING once connected and reaches its own checkpoint
+  // when the PONG comes back.
+  for (size_t i = 0; i < kNumClients; i++) {
+    cli[i] = UnixSocket::Connect(kSocketName, &event_listener_, &task_runner_);
+    EXPECT_CALL(event_listener_, OnConnect(cli[i].get(), true))
+        .WillOnce(Invoke([](UnixSocket* s, bool success) {
+          ASSERT_TRUE(success);
+          ASSERT_TRUE(s->Send("PING"));
+        }));
+
+    auto checkpoint = task_runner_.CreateCheckpoint(std::to_string(i));
+    EXPECT_CALL(event_listener_, OnDataAvailable(cli[i].get()))
+        .WillOnce(Invoke([checkpoint](UnixSocket* s) {
+          ASSERT_EQ("PONG", s->ReceiveString());
+          checkpoint();
+        }));
+  }
+
+  for (size_t i = 0; i < kNumClients; i++)
+    task_runner_.RunUntilCheckpoint(std::to_string(i));
+
+  // |event_listener_| is the only mock object here, so it is what must be
+  // verified once every exchange is done. Verifying a UnixSocket*, which is
+  // not a mock, is a silent no-op.
+  ASSERT_TRUE(Mock::VerifyAndClearExpectations(&event_listener_));
+}
+
+// Creates two processes. The server process creates a file and passes it over
+// the socket to the client. Both processes mmap the file in shared mode and
+// check that they see the same contents.
+TEST_F(UnixSocketTest, SharedMemory) {
+  int pipes[2];
+  ASSERT_EQ(0, pipe(pipes));
+
+  pid_t pid = fork();
+  ASSERT_GE(pid, 0);
+  constexpr size_t kTmpSize = 4096;
+
+  if (pid == 0) {
+    // Child process (server). Its body runs inside a lambda because a failing
+    // ASSERT_* only returns from the enclosing function. Run inline, the
+    // forked child would return from this test and go on executing the rest
+    // of the test suite. The outcome is reported through the exit status,
+    // which the parent checks.
+    close(pipes[0]);
+    [&]() {
+      FILE* tmp = tmpfile();
+      ASSERT_NE(nullptr, tmp);
+      int tmp_fd = fileno(tmp);
+      ASSERT_FALSE(ftruncate(tmp_fd, kTmpSize));
+      void* map = mmap(nullptr, kTmpSize, PROT_READ | PROT_WRITE, MAP_SHARED,
+                       tmp_fd, 0);
+      ASSERT_NE(MAP_FAILED, map);  // mmap() signals failure with MAP_FAILED.
+      char* mem = static_cast<char*>(map);
+      memcpy(mem, "shm rocks", 10);
+
+      auto srv =
+          UnixSocket::Listen(kSocketName, &event_listener_, &task_runner_);
+      ASSERT_TRUE(srv->is_listening());
+      // Signal the other process that it can connect.
+      ASSERT_EQ(1, PERFETTO_EINTR(write(pipes[1], ".", 1)));
+      auto checkpoint = task_runner_.CreateCheckpoint("change_seen_by_server");
+      EXPECT_CALL(event_listener_, OnNewIncomingConnection(srv.get(), _))
+          .WillOnce(Invoke([this, tmp_fd, checkpoint, mem](
+                               UnixSocket*, UnixSocket* new_conn) {
+            ASSERT_EQ(geteuid(), static_cast<uint32_t>(new_conn->peer_uid()));
+            ASSERT_TRUE(new_conn->Send("txfd", 5, tmp_fd));
+            // Wait for the client to change this again.
+            EXPECT_CALL(event_listener_, OnDataAvailable(new_conn))
+                .WillOnce(Invoke([checkpoint, mem](UnixSocket* s) {
+                  ASSERT_EQ("change notify", s->ReceiveString());
+                  ASSERT_STREQ("rock more", mem);
+                  checkpoint();
+                }));
+          }));
+      task_runner_.RunUntilCheckpoint("change_seen_by_server");
+      ASSERT_TRUE(Mock::VerifyAndClearExpectations(&event_listener_));
+    }();
+    _exit(::testing::Test::HasFailure() ? 1 : 0);
+  } else {
+    // Parent process (client). Drop our copy of the write end so that read()
+    // reports EOF, instead of blocking forever, if the child dies before
+    // signalling.
+    close(pipes[1]);
+    char sync_cmd = '\0';
+    ASSERT_EQ(1, PERFETTO_EINTR(read(pipes[0], &sync_cmd, 1)));
+    close(pipes[0]);
+    ASSERT_EQ('.', sync_cmd);
+    auto cli =
+        UnixSocket::Connect(kSocketName, &event_listener_, &task_runner_);
+    EXPECT_CALL(event_listener_, OnConnect(cli.get(), true));
+    auto checkpoint = task_runner_.CreateCheckpoint("change_seen_by_client");
+    EXPECT_CALL(event_listener_, OnDataAvailable(cli.get()))
+        .WillOnce(Invoke([checkpoint](UnixSocket* s) {
+          char msg[32];
+          base::ScopedFile fd;
+          ASSERT_EQ(5u, s->Receive(msg, sizeof(msg), &fd));
+          ASSERT_STREQ("txfd", msg);
+          ASSERT_TRUE(fd);
+          void* map = mmap(nullptr, kTmpSize, PROT_READ | PROT_WRITE,
+                           MAP_SHARED, *fd, 0);
+          ASSERT_NE(MAP_FAILED, map);  // mmap() signals failure with MAP_FAILED.
+          char* mem = static_cast<char*>(map);
+          mem[9] = '\0';  // Just to get a clean error in case of test failure.
+          ASSERT_STREQ("shm rocks", mem);
+
+          // Now change the shared memory and ping the other process.
+          memcpy(mem, "rock more", 10);
+          ASSERT_TRUE(s->Send("change notify"));
+          checkpoint();
+        }));
+    task_runner_.RunUntilCheckpoint("change_seen_by_client");
+    int st = 0;
+    PERFETTO_EINTR(waitpid(pid, &st, 0));
+    ASSERT_FALSE(WIFSIGNALED(st)) << "Server died with signal " << WTERMSIG(st);
+    EXPECT_TRUE(WIFEXITED(st));
+    ASSERT_EQ(0, WEXITSTATUS(st));
+  }
+}
+
+constexpr size_t kAtomicWrites_FrameSize = 1123;
+
+// Makes one attempt at sending a kAtomicWrites_FrameSize-byte frame whose
+// bytes are all static_cast<char>(num_frame). Returns true if Send() accepted
+// the frame. Otherwise a retry with identical arguments is posted on
+// |task_runner| (the retry's own result is discarded) and false is returned.
+bool AtomicWrites_SendAttempt(UnixSocket* s,
+                              base::TaskRunner* task_runner,
+                              int num_frame) {
+  char frame[kAtomicWrites_FrameSize];
+  memset(frame, static_cast<char>(num_frame), sizeof(frame));
+  const bool sent = s->Send(frame, sizeof(frame));
+  if (!sent) {
+    task_runner->PostTask([s, task_runner, num_frame] {
+      AtomicWrites_SendAttempt(s, task_runner, num_frame);
+    });
+  }
+  return sent;
+}
+
+// Creates a client-server pair. The client continuously sends data to the
+// server. Upon each Send() attempt, the client sends a buffer which is memset()
+// with a unique number (0 to kNumFrames - 1). We are deliberately trying to
+// fill the socket output buffer, so we expect some of these Send()s to fail.
+// The client is extremely aggressive and, when a Send() fails, just keeps
+// re-posting it with the same unique number. The server verifies that we
+// receive exactly one copy of each buffer, without any gaps or truncation.
+TEST_F(UnixSocketTest, SendIsAtomic) {
+ static constexpr int kNumFrames = 127;
+
+ auto srv = UnixSocket::Listen(kSocketName, &event_listener_, &task_runner_);
+ ASSERT_TRUE(srv->is_listening());
+
+ auto cli = UnixSocket::Connect(kSocketName, &event_listener_, &task_runner_);
+
+ auto all_frames_done = task_runner_.CreateCheckpoint("all_frames_done");
+ std::set<int> received_iterations; // Frame numbers seen by the server.
+ EXPECT_CALL(event_listener_, OnNewIncomingConnection(srv.get(), _))
+ .WillOnce(Invoke([this, &received_iterations, all_frames_done](
+ UnixSocket*, UnixSocket* srv_conn) {
+ EXPECT_CALL(event_listener_, OnDataAvailable(srv_conn))
+ .WillRepeatedly(
+ Invoke([&received_iterations, all_frames_done](UnixSocket* s) {
+ char buf[kAtomicWrites_FrameSize];
+ size_t res = s->Receive(buf, sizeof(buf));
+ if (res == 0)
+ return; // Spurious select(), could happen.
+ ASSERT_EQ(kAtomicWrites_FrameSize, res);
+ // All bytes must match buf[0]; a mix means two truncated frames.
+ for (size_t i = 0; i < sizeof(buf); i++)
+ ASSERT_EQ(buf[0], buf[i]);
+ ASSERT_EQ(0u, received_iterations.count(buf[0])); // No duplicates.
+ received_iterations.insert(buf[0]);
+ if (received_iterations.size() == kNumFrames)
+ all_frames_done();
+ }));
+ }));
+
+ auto cli_connected = task_runner_.CreateCheckpoint("cli_connected");
+ EXPECT_CALL(event_listener_, OnConnect(cli.get(), true))
+ .WillOnce(
+ Invoke([cli_connected](UnixSocket*, bool) { cli_connected(); }));
+ task_runner_.RunUntilCheckpoint("cli_connected");
+ ASSERT_TRUE(cli->is_connected());
+ ASSERT_EQ(geteuid(), static_cast<uint32_t>(cli->peer_uid()));
+
+ bool did_requeue = false; // Set if any first Send() attempt failed.
+ for (int i = 0; i < kNumFrames; i++)
+ did_requeue |= !AtomicWrites_SendAttempt(cli.get(), &task_runner_, i);
+
+ // We expect that at least one of the kNumFrames didn't fit in the socket
+ // buffer and was re-posted, otherwise this entire test would be pointless.
+ ASSERT_TRUE(did_requeue);
+
+ task_runner_.RunUntilCheckpoint("all_frames_done"); // All kNumFrames seen.
+}
+
+// Checks that the peer_uid() is retained after the client disconnects. The IPC
+// layer needs to rely on this to validate messages received immediately before
+// a client disconnects.
+TEST_F(UnixSocketTest, PeerUidRetainedAfterDisconnect) {
+ auto srv = UnixSocket::Listen(kSocketName, &event_listener_, &task_runner_);
+ ASSERT_TRUE(srv->is_listening());
+ UnixSocket* srv_client_conn = nullptr; // Server-side end of the connection.
+ auto srv_connected = task_runner_.CreateCheckpoint("srv_connected");
+ EXPECT_CALL(event_listener_, OnNewIncomingConnection(srv.get(), _))
+ .WillOnce(Invoke(
+ [&srv_client_conn, srv_connected](UnixSocket*, UnixSocket* srv_conn) {
+ srv_client_conn = srv_conn;
+ EXPECT_EQ(geteuid(), static_cast<uint32_t>(srv_conn->peer_uid()));
+ srv_connected();
+ }));
+ auto cli_connected = task_runner_.CreateCheckpoint("cli_connected");
+ auto cli = UnixSocket::Connect(kSocketName, &event_listener_, &task_runner_);
+ EXPECT_CALL(event_listener_, OnConnect(cli.get(), true))
+ .WillOnce(
+ Invoke([cli_connected](UnixSocket*, bool) { cli_connected(); }));
+
+ task_runner_.RunUntilCheckpoint("cli_connected");
+ task_runner_.RunUntilCheckpoint("srv_connected");
+ ASSERT_NE(nullptr, srv_client_conn);
+ ASSERT_TRUE(srv_client_conn->is_connected());
+
+ auto cli_disconnected = task_runner_.CreateCheckpoint("cli_disconnected");
+ EXPECT_CALL(event_listener_, OnDisconnect(srv_client_conn))
+ .WillOnce(
+ Invoke([cli_disconnected](UnixSocket*) { cli_disconnected(); }));
+
+ // TODO(primiano): when a peer disconnects, the other end receives a
+ // spurious OnDataAvailable() that needs to be acked with a Receive() to read
+ // the EOF. See b/69536434.
+ EXPECT_CALL(event_listener_, OnDataAvailable(srv_client_conn))
+ .WillOnce(Invoke([](UnixSocket* sock) { sock->ReceiveString(); }));
+
+ cli.reset(); // Destroy the client; the server expects OnDisconnect().
+ task_runner_.RunUntilCheckpoint("cli_disconnected");
+ ASSERT_FALSE(srv_client_conn->is_connected());
+ EXPECT_EQ(geteuid(), static_cast<uint32_t>(srv_client_conn->peer_uid()));
+}
+
+// TODO(primiano): add a test to check that, when a peer sends an fd and the
+// other end just does a recv() without taking it, the fd is closed and not
+// left around.
+
+// TODO(primiano): add a test to check that a socket can be reused after
+// Shutdown().
+
+// TODO(primiano): add a test to check that OnDisconnect() is called in all
+// possible cases.
+
+// TODO(primiano): add tests that destroy the socket in all possible stages and
+// verify that no spurious EventListener callback is received.
+
+} // namespace
+} // namespace ipc
+} // namespace perfetto
diff --git a/src/ipc/wire_protocol.proto b/src/ipc/wire_protocol.proto
new file mode 100644
index 0000000..9a26042
--- /dev/null
+++ b/src/ipc/wire_protocol.proto
@@ -0,0 +1,70 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+syntax = "proto3";
+option optimize_for = LITE_RUNTIME;
+
+package perfetto.ipc;
+
+message Frame {
+ // Client -> Host. Names the service to bind; see BindServiceReply.
+ message BindService { string service_name = 1; }
+
+ // Host -> Client. Provides the service_id and method ids used by InvokeMethod.
+ message BindServiceReply {
+ message MethodInfo {
+ int32 id = 1;
+ string name = 2;
+ }
+ bool success = 1;
+ int32 service_id = 2;
+ repeated MethodInfo methods = 3;
+ }
+
+ // Client -> Host. Invokes method_id on service_id with serialized args.
+ message InvokeMethod {
+ int32 service_id = 1; // As returned by BindServiceReply.service_id.
+ int32 method_id = 2; // As returned by BindServiceReply.methods[].id.
+ bytes args_proto = 3; // Proto-encoded request argument.
+ }
+
+ // Host -> Client. Streaming replies set has_more (see request_id below).
+ message InvokeMethodReply {
+ bool success = 1;
+ bool has_more = 2; // Only for streaming RPCs.
+ bytes reply_proto = 3; // Proto-encoded response value.
+ }
+
+ // Host -> Client.
+ message RequestError { string error = 1; }
+
+ // The client is expected to send requests with monotonically increasing
+ // request_id. The host's replies carry the request_id of the request they
+ // answer. In the case of a streaming response (has_more = true) the host
+ // sends several InvokeMethodReply messages with the same request_id.
+ uint64 request_id = 2;
+
+ oneof msg {
+ BindService msg_bind_service = 3;
+ BindServiceReply msg_bind_service_reply = 4;
+ InvokeMethod msg_invoke_method = 5;
+ InvokeMethodReply msg_invoke_method_reply = 6;
+ RequestError msg_request_error = 7;
+ }
+
+ // Used only in unittests to generate a parsable message of arbitrary size.
+ repeated bytes data_for_testing = 1;
+};