service: Implement producer and data source unregistration

Bug: 73806120
Change-Id: If3ae078340b4f0d212d6624ca8a20b9665b6de2d
diff --git a/src/tracing/core/service_impl_unittest.cc b/src/tracing/core/service_impl_unittest.cc
index 8f13777..3ab0b11 100644
--- a/src/tracing/core/service_impl_unittest.cc
+++ b/src/tracing/core/service_impl_unittest.cc
@@ -20,10 +20,12 @@
 
 #include "gmock/gmock.h"
 #include "gtest/gtest.h"
+#include "perfetto/tracing/core/consumer.h"
 #include "perfetto/tracing/core/data_source_config.h"
 #include "perfetto/tracing/core/data_source_descriptor.h"
 #include "perfetto/tracing/core/producer.h"
 #include "perfetto/tracing/core/shared_memory.h"
+#include "perfetto/tracing/core/trace_packet.h"
 #include "src/base/test/test_task_runner.h"
 #include "src/tracing/test/test_shared_memory.h"
 
@@ -46,14 +48,34 @@
   MOCK_METHOD1(TearDownDataSourceInstance, void(DataSourceInstanceID));
 };
 
+class MockConsumer : public Consumer {  // gmock test double for Consumer.
+ public:
+  ~MockConsumer() override {}
+
+  // Consumer implementation: connection callbacks are mocked for EXPECT_CALL.
+  MOCK_METHOD0(OnConnect, void());
+  MOCK_METHOD0(OnDisconnect, void());
+  // Trace data is dropped (empty body); no visible test here inspects it.
+  void OnTraceData(std::vector<TracePacket> packets, bool has_more) override {}
+};
+
 }  // namespace
 
-TEST(ServiceImplTest, RegisterAndUnregister) {
+class ServiceImplTest : public testing::Test {  // One per TEST_F: fresh svc.
+ public:
+  ServiceImplTest() {  // Builds |svc| over a TestSharedMemory factory.
+    auto shm_factory =
+        std::unique_ptr<SharedMemory::Factory>(new TestSharedMemory::Factory());
+    svc.reset(static_cast<ServiceImpl*>(  // Downcast; tests use svc->num_producers().
+        Service::CreateInstance(std::move(shm_factory), &task_runner)
+            .release()));
+  }
+  // |task_runner| is declared before |svc|, so it outlives the service.
   base::TestTaskRunner task_runner;
-  auto shm_factory =
-      std::unique_ptr<SharedMemory::Factory>(new TestSharedMemory::Factory());
-  std::unique_ptr<ServiceImpl> svc(static_cast<ServiceImpl*>(
-      Service::CreateInstance(std::move(shm_factory), &task_runner).release()));
+  std::unique_ptr<ServiceImpl> svc;  // Created by the constructor above.
+};
+
+TEST_F(ServiceImplTest, RegisterAndUnregister) {
   MockProducer mock_producer_1;
   MockProducer mock_producer_2;
   std::unique_ptr<Service::ProducerEndpoint> producer_endpoint_1 =
@@ -78,7 +100,7 @@
   DataSourceDescriptor ds_desc1;
   ds_desc1.set_name("foo");
   producer_endpoint_1->RegisterDataSource(
-      ds_desc1, [&task_runner, &producer_endpoint_1](DataSourceID id) {
+      ds_desc1, [this, &producer_endpoint_1](DataSourceID id) {
         EXPECT_EQ(1u, id);
         task_runner.PostTask(
             std::bind(&Service::ProducerEndpoint::UnregisterDataSource,
@@ -88,7 +110,7 @@
   DataSourceDescriptor ds_desc2;
   ds_desc2.set_name("bar");
   producer_endpoint_2->RegisterDataSource(
-      ds_desc2, [&task_runner, &producer_endpoint_2](DataSourceID id) {
+      ds_desc2, [this, &producer_endpoint_2](DataSourceID id) {
         EXPECT_EQ(1u, id);
         task_runner.PostTask(
             std::bind(&Service::ProducerEndpoint::UnregisterDataSource,
@@ -113,4 +135,136 @@
   ASSERT_EQ(0u, svc->num_producers());
 }
 
+TEST_F(ServiceImplTest, EnableAndDisableTracing) {  // Full enable/disable cycle.
+  MockProducer mock_producer;
+  std::unique_ptr<Service::ProducerEndpoint> producer_endpoint =
+      svc->ConnectProducer(&mock_producer, 123u /* uid */);
+  MockConsumer mock_consumer;
+  std::unique_ptr<Service::ConsumerEndpoint> consumer_endpoint =
+      svc->ConnectConsumer(&mock_consumer);
+
+  InSequence seq;  // EXPECT_CALLs below must be met in declaration order.
+  EXPECT_CALL(mock_producer, OnConnect());
+  EXPECT_CALL(mock_consumer, OnConnect());
+  task_runner.RunUntilIdle();  // Both OnConnect()s are expected while draining.
+
+  DataSourceDescriptor ds_desc;
+  ds_desc.set_name("foo");
+  producer_endpoint->RegisterDataSource(ds_desc, [](DataSourceID) {});  // ID unused.
+
+  task_runner.RunUntilIdle();
+  // Enabling "foo" should create an instance, which is later torn down.
+  EXPECT_CALL(mock_producer, CreateDataSourceInstance(_, _));
+  EXPECT_CALL(mock_producer, TearDownDataSourceInstance(_));
+  TraceConfig trace_config;
+  trace_config.add_buffers()->set_size_kb(4096 * 10);  // 40960 KB.
+  auto* ds_config = trace_config.add_data_sources()->mutable_config();
+  ds_config->set_name("foo");  // Matches ds_desc above.
+  ds_config->set_target_buffer(0);  // Presumably the buffer added above.
+  consumer_endpoint->EnableTracing(trace_config);
+  task_runner.RunUntilIdle();
+  // Disable, then drop both endpoints; each mock expects OnDisconnect().
+  EXPECT_CALL(mock_producer, OnDisconnect());
+  EXPECT_CALL(mock_consumer, OnDisconnect());
+  consumer_endpoint->DisableTracing();
+  producer_endpoint.reset();
+  consumer_endpoint.reset();
+  task_runner.RunUntilIdle();
+  Mock::VerifyAndClearExpectations(&mock_producer);
+  Mock::VerifyAndClearExpectations(&mock_consumer);
+}
+
+TEST_F(ServiceImplTest, DisconnectConsumerWhileTracing) {  // Consumer drops mid-trace.
+  MockProducer mock_producer;
+  std::unique_ptr<Service::ProducerEndpoint> producer_endpoint =
+      svc->ConnectProducer(&mock_producer, 123u /* uid */);
+  MockConsumer mock_consumer;
+  std::unique_ptr<Service::ConsumerEndpoint> consumer_endpoint =
+      svc->ConnectConsumer(&mock_consumer);
+
+  InSequence seq;  // EXPECT_CALLs below must be met in declaration order.
+  EXPECT_CALL(mock_producer, OnConnect());
+  EXPECT_CALL(mock_consumer, OnConnect());
+  task_runner.RunUntilIdle();  // Both OnConnect()s are expected while draining.
+
+  DataSourceDescriptor ds_desc;
+  ds_desc.set_name("foo");
+  producer_endpoint->RegisterDataSource(ds_desc, [](DataSourceID) {});  // ID unused.
+  task_runner.RunUntilIdle();
+
+  // Disconnecting the consumer while tracing should trigger data source
+  // teardown, even though DisableTracing() is never called.
+  EXPECT_CALL(mock_producer, CreateDataSourceInstance(_, _));
+  EXPECT_CALL(mock_producer, TearDownDataSourceInstance(_));
+  TraceConfig trace_config;
+  trace_config.add_buffers()->set_size_kb(4096 * 10);  // 40960 KB.
+  auto* ds_config = trace_config.add_data_sources()->mutable_config();
+  ds_config->set_name("foo");  // Matches ds_desc above.
+  ds_config->set_target_buffer(0);  // Presumably the buffer added above.
+  consumer_endpoint->EnableTracing(trace_config);
+  task_runner.RunUntilIdle();
+
+  EXPECT_CALL(mock_consumer, OnDisconnect());
+  consumer_endpoint.reset();  // Tracing is still enabled at this point.
+  task_runner.RunUntilIdle();
+  // NOTE(review): no RunUntilIdle() before Verify; assumes sync OnDisconnect().
+  EXPECT_CALL(mock_producer, OnDisconnect());
+  producer_endpoint.reset();
+  Mock::VerifyAndClearExpectations(&mock_producer);
+  Mock::VerifyAndClearExpectations(&mock_consumer);
+}
+
+TEST_F(ServiceImplTest, ReconnectProducerWhileTracing) {  // Producer leaves, returns.
+  MockProducer mock_producer;
+  std::unique_ptr<Service::ProducerEndpoint> producer_endpoint =
+      svc->ConnectProducer(&mock_producer, 123u /* uid */);
+  MockConsumer mock_consumer;
+  std::unique_ptr<Service::ConsumerEndpoint> consumer_endpoint =
+      svc->ConnectConsumer(&mock_consumer);
+
+  InSequence seq;  // EXPECT_CALLs below must be met in declaration order.
+  EXPECT_CALL(mock_producer, OnConnect());
+  EXPECT_CALL(mock_consumer, OnConnect());
+  task_runner.RunUntilIdle();  // Both OnConnect()s are expected while draining.
+
+  DataSourceDescriptor ds_desc;
+  ds_desc.set_name("foo");
+  producer_endpoint->RegisterDataSource(ds_desc, [](DataSourceID) {});  // ID unused.
+  task_runner.RunUntilIdle();
+
+  // Disconnecting the producer while tracing should trigger data source
+  // teardown.
+  EXPECT_CALL(mock_producer, CreateDataSourceInstance(_, _));
+  EXPECT_CALL(mock_producer, TearDownDataSourceInstance(_));
+  EXPECT_CALL(mock_producer, OnDisconnect());  // From the reset() below.
+  TraceConfig trace_config;
+  trace_config.add_buffers()->set_size_kb(4096 * 10);  // 40960 KB.
+  auto* ds_config = trace_config.add_data_sources()->mutable_config();
+  ds_config->set_name("foo");  // Matches ds_desc above.
+  ds_config->set_target_buffer(0);  // Presumably the buffer added above.
+  consumer_endpoint->EnableTracing(trace_config);
+  producer_endpoint.reset();  // No RunUntilIdle() since EnableTracing().
+  task_runner.RunUntilIdle();
+
+  // Reconnecting a producer with a matching data source should see that data
+  // source getting enabled.
+  EXPECT_CALL(mock_producer, OnConnect());  // Same mock, new endpoint.
+  producer_endpoint = svc->ConnectProducer(&mock_producer, 123u /* uid */);
+  task_runner.RunUntilIdle();
+  EXPECT_CALL(mock_producer, CreateDataSourceInstance(_, _));  // Still tracing.
+  EXPECT_CALL(mock_producer, TearDownDataSourceInstance(_));
+  producer_endpoint->RegisterDataSource(ds_desc, [](DataSourceID) {});  // "foo" again.
+  task_runner.RunUntilIdle();
+
+  EXPECT_CALL(mock_consumer, OnDisconnect());
+  consumer_endpoint->DisableTracing();
+  consumer_endpoint.reset();
+  task_runner.RunUntilIdle();
+  // NOTE(review): no RunUntilIdle() before Verify; assumes sync OnDisconnect().
+  EXPECT_CALL(mock_producer, OnDisconnect());
+  producer_endpoint.reset();
+  Mock::VerifyAndClearExpectations(&mock_producer);
+  Mock::VerifyAndClearExpectations(&mock_consumer);
+}
+
 }  // namespace perfetto