Refactor unit-testing of core/service_impl.cc

This improves the unit tests of ServiceImpl by
introducing shared MockProducer/MockConsumer classes.
It also renames (without adding any new behavior)
the existing methods Producer::OnTracing{Start,Stop}
as follows:
- OnTracingStop -> removed as it's currently unsupported
  (see b/77532839)
- OnTracingStart -> renamed to OnTracingSetup, because
  that is what it does. The old name also clashed with
  Consumer::OnTracingStart, which has different semantics
  and is invoked at a different time.

This CL also renames Consumer::OnTracingStop to
OnTracingDisabled, to match the {Enable,Disable}Tracing methods.

Change-Id: Ided455d3b37cfefdfbc3eda94e5feccaeeb15a5d
diff --git a/src/tracing/core/service_impl_unittest.cc b/src/tracing/core/service_impl_unittest.cc
index 77256ac..1ea2534 100644
--- a/src/tracing/core/service_impl_unittest.cc
+++ b/src/tracing/core/service_impl_unittest.cc
@@ -30,47 +30,26 @@
 #include "perfetto/tracing/core/trace_packet.h"
 #include "perfetto/tracing/core/trace_writer.h"
 #include "src/base/test/test_task_runner.h"
+#include "src/tracing/test/mock_consumer.h"
+#include "src/tracing/test/mock_producer.h"
 #include "src/tracing/test/test_shared_memory.h"
 
 #include "perfetto/trace/test_event.pbzero.h"
 #include "perfetto/trace/trace.pb.h"
+#include "perfetto/trace/trace_packet.pb.h"
 #include "perfetto/trace/trace_packet.pbzero.h"
 
-namespace perfetto {
 using ::testing::_;
+using ::testing::Contains;
+using ::testing::Eq;
 using ::testing::InSequence;
 using ::testing::Invoke;
+using ::testing::InvokeWithoutArgs;
 using ::testing::Mock;
+using ::testing::Property;
+using ::testing::StrictMock;
 
-namespace {
-
-class MockProducer : public Producer {
- public:
-  ~MockProducer() override {}
-
-  // Producer implementation.
-  MOCK_METHOD0(OnConnect, void());
-  MOCK_METHOD0(OnDisconnect, void());
-  MOCK_METHOD2(CreateDataSourceInstance,
-               void(DataSourceInstanceID, const DataSourceConfig&));
-  MOCK_METHOD1(TearDownDataSourceInstance, void(DataSourceInstanceID));
-  MOCK_METHOD0(OnTracingStart, void());
-  MOCK_METHOD0(OnTracingStop, void());
-};
-
-class MockConsumer : public Consumer {
- public:
-  ~MockConsumer() override {}
-
-  // Consumer implementation.
-  MOCK_METHOD0(OnConnect, void());
-  MOCK_METHOD0(OnDisconnect, void());
-  MOCK_METHOD0(OnTracingStop, void());
-
-  void OnTraceData(std::vector<TracePacket> packets, bool has_more) override {}
-};
-
-}  // namespace
+namespace perfetto {
 
 class ServiceImplTest : public testing::Test {
  public:
@@ -82,350 +61,220 @@
             .release()));
   }
 
+  std::unique_ptr<MockProducer> CreateMockProducer() {
+    return std::unique_ptr<MockProducer>(
+        new StrictMock<MockProducer>(&task_runner));  // Stray calls fail.
+  }
+
+  std::unique_ptr<MockConsumer> CreateMockConsumer() {
+    return std::unique_ptr<MockConsumer>(
+        new StrictMock<MockConsumer>(&task_runner));  // Stray calls fail.
+  }
+
   base::TestTaskRunner task_runner;
   std::unique_ptr<ServiceImpl> svc;
 };
 
 TEST_F(ServiceImplTest, RegisterAndUnregister) {
-  MockProducer mock_producer_1;
-  MockProducer mock_producer_2;
-  std::unique_ptr<Service::ProducerEndpoint> producer_endpoint_1 =
-      svc->ConnectProducer(&mock_producer_1, 123u /* uid */, "mock_producer_1");
-  std::unique_ptr<Service::ProducerEndpoint> producer_endpoint_2 =
-      svc->ConnectProducer(&mock_producer_2, 456u /* uid */, "mock_producer_2");
+  std::unique_ptr<MockProducer> mock_producer_1 = CreateMockProducer();
+  std::unique_ptr<MockProducer> mock_producer_2 = CreateMockProducer();
 
-  ASSERT_TRUE(producer_endpoint_1);
-  ASSERT_TRUE(producer_endpoint_2);
-
-  InSequence seq;
-  EXPECT_CALL(mock_producer_1, OnConnect());
-  EXPECT_CALL(mock_producer_2, OnConnect());
-  task_runner.RunUntilIdle();
+  mock_producer_1->Connect(svc.get(), "mock_producer_1", 123u /* uid */);
+  mock_producer_2->Connect(svc.get(), "mock_producer_2", 456u /* uid */);
 
   ASSERT_EQ(2u, svc->num_producers());
-  ASSERT_EQ(producer_endpoint_1.get(), svc->GetProducer(1));
-  ASSERT_EQ(producer_endpoint_2.get(), svc->GetProducer(2));
+  ASSERT_EQ(mock_producer_1->endpoint(), svc->GetProducer(1));
+  ASSERT_EQ(mock_producer_2->endpoint(), svc->GetProducer(2));
   ASSERT_EQ(123u, svc->GetProducer(1)->uid_);
   ASSERT_EQ(456u, svc->GetProducer(2)->uid_);
 
-  DataSourceDescriptor ds_desc1;
-  ds_desc1.set_name("foo");
-  producer_endpoint_1->RegisterDataSource(ds_desc1);
+  mock_producer_1->RegisterDataSource("foo");
+  mock_producer_2->RegisterDataSource("bar");
 
-  DataSourceDescriptor ds_desc2;
-  ds_desc2.set_name("bar");
-  producer_endpoint_2->RegisterDataSource(ds_desc2);
+  mock_producer_1->UnregisterDataSource("foo");
+  mock_producer_2->UnregisterDataSource("bar");
 
-  task_runner.RunUntilIdle();
-
-  producer_endpoint_1->UnregisterDataSource("foo");
-  producer_endpoint_2->UnregisterDataSource("bar");
-
-  task_runner.RunUntilIdle();
-
-  EXPECT_CALL(mock_producer_1, OnDisconnect());
-  producer_endpoint_1.reset();
-  task_runner.RunUntilIdle();
-  Mock::VerifyAndClearExpectations(&mock_producer_1);
-
+  mock_producer_1.reset();  // Disconnects producer 1.
   ASSERT_EQ(1u, svc->num_producers());
   ASSERT_EQ(nullptr, svc->GetProducer(1));
 
-  EXPECT_CALL(mock_producer_2, OnDisconnect());
-  producer_endpoint_2.reset();
-  task_runner.RunUntilIdle();
-  Mock::VerifyAndClearExpectations(&mock_producer_2);
+  mock_producer_2.reset();  // Disconnects producer 2.
+  ASSERT_EQ(nullptr, svc->GetProducer(2));
 
   ASSERT_EQ(0u, svc->num_producers());
 }
 
 TEST_F(ServiceImplTest, EnableAndDisableTracing) {
-  MockProducer mock_producer;
-  std::unique_ptr<Service::ProducerEndpoint> producer_endpoint =
-      svc->ConnectProducer(&mock_producer, 123u /* uid */, "mock_producer");
-  MockConsumer mock_consumer;
-  std::unique_ptr<Service::ConsumerEndpoint> consumer_endpoint =
-      svc->ConnectConsumer(&mock_consumer);
+  std::unique_ptr<MockConsumer> consumer = CreateMockConsumer();
+  consumer->Connect(svc.get());
 
-  InSequence seq;
-  EXPECT_CALL(mock_producer, OnConnect());
-  EXPECT_CALL(mock_consumer, OnConnect());
-  task_runner.RunUntilIdle();
+  std::unique_ptr<MockProducer> producer = CreateMockProducer();
+  producer->Connect(svc.get(), "mock_producer");
+  producer->RegisterDataSource("data_source");
 
-  DataSourceDescriptor ds_desc;
-  ds_desc.set_name("foo");
-  producer_endpoint->RegisterDataSource(ds_desc);
-
-  task_runner.RunUntilIdle();
-
-  EXPECT_CALL(mock_producer, CreateDataSourceInstance(_, _));
-  EXPECT_CALL(mock_producer, TearDownDataSourceInstance(_));
   TraceConfig trace_config;
-  trace_config.add_buffers()->set_size_kb(4096 * 10);
+  trace_config.add_buffers()->set_size_kb(128);
   auto* ds_config = trace_config.add_data_sources()->mutable_config();
-  ds_config->set_name("foo");
-  ds_config->set_target_buffer(0);
-  consumer_endpoint->EnableTracing(trace_config);
-  task_runner.RunUntilIdle();
+  ds_config->set_name("data_source");
+  consumer->EnableTracing(trace_config);
 
-  EXPECT_CALL(mock_producer, OnDisconnect());
-  EXPECT_CALL(mock_consumer, OnDisconnect());
-  consumer_endpoint->DisableTracing();
-  producer_endpoint.reset();
-  consumer_endpoint.reset();
-  task_runner.RunUntilIdle();
-  Mock::VerifyAndClearExpectations(&mock_producer);
-  Mock::VerifyAndClearExpectations(&mock_consumer);
+  producer->WaitForTracingSetup();
+  producer->WaitForDataSourceStart("data_source");
+
+  consumer->DisableTracing();
+  producer->WaitForDataSourceStop("data_source");
+  consumer->WaitForTracingDisabled();
 }
 
 TEST_F(ServiceImplTest, LockdownMode) {
-  MockConsumer mock_consumer;
-  EXPECT_CALL(mock_consumer, OnConnect());
-  std::unique_ptr<Service::ConsumerEndpoint> consumer_endpoint =
-      svc->ConnectConsumer(&mock_consumer);
+  std::unique_ptr<MockConsumer> consumer = CreateMockConsumer();
+  consumer->Connect(svc.get());
+
+  std::unique_ptr<MockProducer> producer = CreateMockProducer();
+  producer->Connect(svc.get(), "mock_producer_sameuid", geteuid());
+  producer->RegisterDataSource("data_source");
 
   TraceConfig trace_config;
+  trace_config.add_buffers()->set_size_kb(128);
+  auto* ds_config = trace_config.add_data_sources()->mutable_config();
+  ds_config->set_name("data_source");
   trace_config.set_lockdown_mode(
       TraceConfig::LockdownModeOperation::LOCKDOWN_SET);
-  consumer_endpoint->EnableTracing(trace_config);
+  consumer->EnableTracing(trace_config);
+
+  producer->WaitForTracingSetup();
+  producer->WaitForDataSourceStart("data_source");
+
+  std::unique_ptr<MockProducer> producer_otheruid = CreateMockProducer();
+  auto otheruid_endpoint = svc->ConnectProducer(
+      producer_otheruid.get(), geteuid() + 1, "mock_producer_ouid");
+  EXPECT_CALL(*producer_otheruid, OnConnect()).Times(0);  // Rejected.
   task_runner.RunUntilIdle();
+  Mock::VerifyAndClearExpectations(producer_otheruid.get());
 
-  InSequence seq;
-
-  MockProducer mock_producer;
-  std::unique_ptr<Service::ProducerEndpoint> producer_endpoint =
-      svc->ConnectProducer(&mock_producer, geteuid() + 1 /* uid */,
-                           "mock_producer");
-
-  MockProducer mock_producer_sameuid;
-  std::unique_ptr<Service::ProducerEndpoint> producer_endpoint_sameuid =
-      svc->ConnectProducer(&mock_producer_sameuid, geteuid() /* uid */,
-                           "mock_producer_sameuid");
-
-  EXPECT_CALL(mock_producer, OnConnect()).Times(0);
-  EXPECT_CALL(mock_producer_sameuid, OnConnect());
-  task_runner.RunUntilIdle();
-
-  Mock::VerifyAndClearExpectations(&mock_producer);
-
-  consumer_endpoint->DisableTracing();
-  task_runner.RunUntilIdle();
+  consumer->DisableTracing();
+  consumer->FreeBuffers();
+  producer->WaitForDataSourceStop("data_source");
+  consumer->WaitForTracingDisabled();
 
   trace_config.set_lockdown_mode(
       TraceConfig::LockdownModeOperation::LOCKDOWN_CLEAR);
-  consumer_endpoint->EnableTracing(trace_config);
-  task_runner.RunUntilIdle();
+  consumer->EnableTracing(trace_config);
+  producer->WaitForDataSourceStart("data_source");
 
-  EXPECT_CALL(mock_producer_sameuid, OnDisconnect());
-  EXPECT_CALL(mock_producer, OnConnect());
-  producer_endpoint_sameuid =
-      svc->ConnectProducer(&mock_producer, geteuid() + 1, "mock_producer");
+  std::unique_ptr<MockProducer> producer_otheruid2 = CreateMockProducer();
+  producer_otheruid2->Connect(svc.get(), "mock_producer_ouid2", geteuid() + 1);
 
-  EXPECT_CALL(mock_producer, OnDisconnect());
-  task_runner.RunUntilIdle();
+  consumer->DisableTracing();
+  producer->WaitForDataSourceStop("data_source");
+  consumer->WaitForTracingDisabled();
 }
 
 TEST_F(ServiceImplTest, DisconnectConsumerWhileTracing) {
-  MockProducer mock_producer;
-  std::unique_ptr<Service::ProducerEndpoint> producer_endpoint =
-      svc->ConnectProducer(&mock_producer, 123u /* uid */, "mock_producer");
-  MockConsumer mock_consumer;
-  std::unique_ptr<Service::ConsumerEndpoint> consumer_endpoint =
-      svc->ConnectConsumer(&mock_consumer);
+  std::unique_ptr<MockConsumer> consumer = CreateMockConsumer();
+  consumer->Connect(svc.get());
 
-  InSequence seq;
-  EXPECT_CALL(mock_producer, OnConnect());
-  EXPECT_CALL(mock_consumer, OnConnect());
-  task_runner.RunUntilIdle();
+  std::unique_ptr<MockProducer> producer = CreateMockProducer();
+  producer->Connect(svc.get(), "mock_producer");
+  producer->RegisterDataSource("data_source");
 
-  DataSourceDescriptor ds_desc;
-  ds_desc.set_name("foo");
-  producer_endpoint->RegisterDataSource(ds_desc);
-  task_runner.RunUntilIdle();
+  TraceConfig trace_config;
+  trace_config.add_buffers()->set_size_kb(128);
+  auto* ds_config = trace_config.add_data_sources()->mutable_config();
+  ds_config->set_name("data_source");
+  consumer->EnableTracing(trace_config);
+
+  producer->WaitForTracingSetup();
+  producer->WaitForDataSourceStart("data_source");
 
   // Disconnecting the consumer while tracing should trigger data source
   // teardown.
-  EXPECT_CALL(mock_producer, CreateDataSourceInstance(_, _));
-  EXPECT_CALL(mock_producer, TearDownDataSourceInstance(_));
-  TraceConfig trace_config;
-  trace_config.add_buffers()->set_size_kb(4096 * 10);
-  auto* ds_config = trace_config.add_data_sources()->mutable_config();
-  ds_config->set_name("foo");
-  ds_config->set_target_buffer(0);
-  consumer_endpoint->EnableTracing(trace_config);
-  task_runner.RunUntilIdle();
-
-  EXPECT_CALL(mock_consumer, OnDisconnect());
-  consumer_endpoint.reset();
-  task_runner.RunUntilIdle();
-
-  EXPECT_CALL(mock_producer, OnDisconnect());
-  producer_endpoint.reset();
-  Mock::VerifyAndClearExpectations(&mock_producer);
-  Mock::VerifyAndClearExpectations(&mock_consumer);
+  consumer.reset();
+  producer->WaitForDataSourceStop("data_source");
 }
 
 TEST_F(ServiceImplTest, ReconnectProducerWhileTracing) {
-  MockProducer mock_producer;
-  std::unique_ptr<Service::ProducerEndpoint> producer_endpoint =
-      svc->ConnectProducer(&mock_producer, 123u /* uid */, "mock_producer");
-  MockConsumer mock_consumer;
-  std::unique_ptr<Service::ConsumerEndpoint> consumer_endpoint =
-      svc->ConnectConsumer(&mock_consumer);
+  std::unique_ptr<MockConsumer> consumer = CreateMockConsumer();
+  consumer->Connect(svc.get());
 
-  InSequence seq;
-  EXPECT_CALL(mock_producer, OnConnect());
-  EXPECT_CALL(mock_consumer, OnConnect());
-  task_runner.RunUntilIdle();
+  std::unique_ptr<MockProducer> producer = CreateMockProducer();
+  producer->Connect(svc.get(), "mock_producer");
+  producer->RegisterDataSource("data_source");
 
-  DataSourceDescriptor ds_desc;
-  ds_desc.set_name("foo");
-  producer_endpoint->RegisterDataSource(ds_desc);
-  task_runner.RunUntilIdle();
-
-  // Disconnecting the producer while tracing should trigger data source
-  // teardown.
-  EXPECT_CALL(mock_producer, CreateDataSourceInstance(_, _));
-  EXPECT_CALL(mock_producer, TearDownDataSourceInstance(_));
-  EXPECT_CALL(mock_producer, OnDisconnect());
   TraceConfig trace_config;
-  trace_config.add_buffers()->set_size_kb(4096 * 10);
+  trace_config.add_buffers()->set_size_kb(128);
   auto* ds_config = trace_config.add_data_sources()->mutable_config();
-  ds_config->set_name("foo");
-  ds_config->set_target_buffer(0);
-  consumer_endpoint->EnableTracing(trace_config);
-  producer_endpoint.reset();
-  task_runner.RunUntilIdle();
+  ds_config->set_name("data_source");
+  consumer->EnableTracing(trace_config);
 
-  // Reconnecting a producer with a matching data source should see that data
-  // source getting enabled.
-  EXPECT_CALL(mock_producer, OnConnect());
-  producer_endpoint =
-      svc->ConnectProducer(&mock_producer, 123u /* uid */, "mock_producer");
-  task_runner.RunUntilIdle();
-  EXPECT_CALL(mock_producer, CreateDataSourceInstance(_, _));
-  EXPECT_CALL(mock_producer, TearDownDataSourceInstance(_));
-  producer_endpoint->RegisterDataSource(ds_desc);
-  task_runner.RunUntilIdle();
+  producer->WaitForTracingSetup();
+  producer->WaitForDataSourceStart("data_source");
 
-  EXPECT_CALL(mock_consumer, OnDisconnect());
-  consumer_endpoint->DisableTracing();
-  consumer_endpoint.reset();
-  task_runner.RunUntilIdle();
-
-  EXPECT_CALL(mock_producer, OnDisconnect());
-  producer_endpoint.reset();
-  Mock::VerifyAndClearExpectations(&mock_producer);
-  Mock::VerifyAndClearExpectations(&mock_consumer);
+  // Disconnect the producer and connect a new one with the same data source.
+  // The new producer should get tracing set up and the data source started.
+  producer.reset();
+  producer = CreateMockProducer();
+  producer->Connect(svc.get(), "mock_producer_2");
+  producer->RegisterDataSource("data_source");
+  producer->WaitForTracingSetup();
+  producer->WaitForDataSourceStart("data_source");
 }
 
 TEST_F(ServiceImplTest, ProducerIDWrapping) {
-  base::TestTaskRunner task_runner;
-  auto shm_factory =
-      std::unique_ptr<SharedMemory::Factory>(new TestSharedMemory::Factory());
-  std::unique_ptr<ServiceImpl> svc(static_cast<ServiceImpl*>(
-      Service::CreateInstance(std::move(shm_factory), &task_runner).release()));
+  std::vector<std::unique_ptr<MockProducer>> producers;
+  producers.push_back(nullptr);  // Pad: producers[1..4] get IDs 1..4.
 
-  std::map<ProducerID, std::pair<std::unique_ptr<MockProducer>,
-                                 std::unique_ptr<Service::ProducerEndpoint>>>
-      producers;
-
-  auto ConnectProducerAndWait = [&task_runner, &svc, &producers]() {
-    char checkpoint_name[32];
-    static int checkpoint_num = 0;
-    sprintf(checkpoint_name, "on_connect_%d", checkpoint_num++);
-    auto on_connect = task_runner.CreateCheckpoint(checkpoint_name);
-    std::unique_ptr<MockProducer> producer(new MockProducer());
-    std::unique_ptr<Service::ProducerEndpoint> producer_endpoint =
-        svc->ConnectProducer(producer.get(), 123u /* uid */, "mock_producer");
-    EXPECT_CALL(*producer, OnConnect()).WillOnce(Invoke(on_connect));
-    task_runner.RunUntilCheckpoint(checkpoint_name);
-    EXPECT_EQ(&*producer_endpoint, svc->GetProducer(svc->last_producer_id_));
-    const ProducerID pr_id = svc->last_producer_id_;
-    producers.emplace(pr_id, std::make_pair(std::move(producer),
-                                            std::move(producer_endpoint)));
-    return pr_id;
-  };
-
-  auto DisconnectProducerAndWait = [&task_runner,
-                                    &producers](ProducerID pr_id) {
-    char checkpoint_name[32];
-    static int checkpoint_num = 0;
-    sprintf(checkpoint_name, "on_disconnect_%d", checkpoint_num++);
-    auto on_disconnect = task_runner.CreateCheckpoint(checkpoint_name);
-    auto it = producers.find(pr_id);
-    PERFETTO_CHECK(it != producers.end());
-    EXPECT_CALL(*it->second.first, OnDisconnect())
-        .WillOnce(Invoke(on_disconnect));
-    producers.erase(pr_id);
-    task_runner.RunUntilCheckpoint(checkpoint_name);
+  auto connect_producer_and_get_id = [&producers,
+                                      this](const std::string& name) {
+    producers.emplace_back(CreateMockProducer());
+    producers.back()->Connect(svc.get(), "mock_producer_" + name);
+    return svc->last_producer_id_;
   };
 
   // Connect producers 1-4.
   for (ProducerID i = 1; i <= 4; i++)
-    ASSERT_EQ(i, ConnectProducerAndWait());
+    ASSERT_EQ(i, connect_producer_and_get_id(std::to_string(i)));
 
   // Disconnect producers 1,3.
-  DisconnectProducerAndWait(1);
-  DisconnectProducerAndWait(3);
+  producers[1].reset();
+  producers[3].reset();
 
   svc->last_producer_id_ = kMaxProducerID - 1;
-  ASSERT_EQ(kMaxProducerID, ConnectProducerAndWait());
-  ASSERT_EQ(1u, ConnectProducerAndWait());
-  ASSERT_EQ(3u, ConnectProducerAndWait());
-  ASSERT_EQ(5u, ConnectProducerAndWait());
-  ASSERT_EQ(6u, ConnectProducerAndWait());
-
-  // Disconnect all producers to mute spurious callbacks.
-  DisconnectProducerAndWait(kMaxProducerID);
-  for (ProducerID i = 1; i <= 6; i++)
-    DisconnectProducerAndWait(i);
+  ASSERT_EQ(kMaxProducerID, connect_producer_and_get_id("maxid"));
+  ASSERT_EQ(1u, connect_producer_and_get_id("1_again"));
+  ASSERT_EQ(3u, connect_producer_and_get_id("3_again"));
+  ASSERT_EQ(5u, connect_producer_and_get_id("5"));
+  ASSERT_EQ(6u, connect_producer_and_get_id("6"));
 }
 
 TEST_F(ServiceImplTest, WriteIntoFileAndStopOnMaxSize) {
-  MockProducer mock_producer;
-  std::unique_ptr<Service::ProducerEndpoint> producer_endpoint =
-      svc->ConnectProducer(&mock_producer, 123u /* uid */, "mock_producer");
-  MockConsumer mock_consumer;
-  std::unique_ptr<Service::ConsumerEndpoint> consumer_endpoint =
-      svc->ConnectConsumer(&mock_consumer);
+  std::unique_ptr<MockConsumer> consumer = CreateMockConsumer();
+  consumer->Connect(svc.get());
 
-  EXPECT_CALL(mock_producer, OnConnect());
-  EXPECT_CALL(mock_consumer, OnConnect());
-  task_runner.RunUntilIdle();
+  std::unique_ptr<MockProducer> producer = CreateMockProducer();
+  producer->Connect(svc.get(), "mock_producer");
+  producer->RegisterDataSource("data_source");
 
-  DataSourceDescriptor ds_desc;
-  ds_desc.set_name("datasource");
-  producer_endpoint->RegisterDataSource(ds_desc);
-  task_runner.RunUntilIdle();
-
-  static const char kPayload[] = "1234567890abcdef-";
-  static const int kNumPackets = 10;
   TraceConfig trace_config;
   trace_config.add_buffers()->set_size_kb(4096);
   auto* ds_config = trace_config.add_data_sources()->mutable_config();
-  ds_config->set_name("datasource");
+  ds_config->set_name("data_source");
   ds_config->set_target_buffer(0);
   trace_config.set_write_into_file(true);
   trace_config.set_file_write_period_ms(1);
   const uint64_t kMaxFileSize = 512;
   trace_config.set_max_file_size_bytes(kMaxFileSize);
   base::TempFile tmp_file = base::TempFile::Create();
-  auto on_tracing_start = task_runner.CreateCheckpoint("on_tracing_start");
-  BufferID buf_id = 0;
-  EXPECT_CALL(mock_producer, OnTracingStart());
-  EXPECT_CALL(mock_producer, CreateDataSourceInstance(_, _))
-      .WillOnce(Invoke([on_tracing_start, &buf_id](
-                           DataSourceInstanceID, const DataSourceConfig& cfg) {
-        buf_id = static_cast<BufferID>(cfg.target_buffer());
-        on_tracing_start();
-      }));
-  consumer_endpoint->EnableTracing(trace_config,
-                                   base::ScopedFile(dup(tmp_file.fd())));
-  task_runner.RunUntilCheckpoint("on_tracing_start");
+  consumer->EnableTracing(trace_config, base::ScopedFile(dup(tmp_file.fd())));
+
+  producer->WaitForTracingSetup();
+  producer->WaitForDataSourceStart("data_source");
+
+  static const char kPayload[] = "1234567890abcdef-";
+  static const int kNumPackets = 10;
 
   std::unique_ptr<TraceWriter> writer =
-      producer_endpoint->CreateTraceWriter(buf_id);
+      producer->CreateTraceWriter("data_source");
   // All these packets should fit within kMaxFileSize.
   for (int i = 0; i < kNumPackets; i++) {
     auto tp = writer->NewTracePacket();
@@ -444,17 +293,9 @@
   writer->Flush();
   writer.reset();
 
-  auto on_tracing_stop = task_runner.CreateCheckpoint("on_tracing_stop");
-  EXPECT_CALL(mock_producer, TearDownDataSourceInstance(_));
-  EXPECT_CALL(mock_consumer, OnTracingStop()).WillOnce(Invoke(on_tracing_stop));
-  task_runner.RunUntilCheckpoint("on_tracing_stop");
-
-  EXPECT_CALL(mock_consumer, OnDisconnect());
-  EXPECT_CALL(mock_producer, OnDisconnect());
-  consumer_endpoint->DisableTracing();
-  consumer_endpoint.reset();
-  producer_endpoint.reset();
-  task_runner.RunUntilIdle();
+  // Don't call DisableTracing(): hitting kMaxFileSize must stop tracing.
+  producer->WaitForDataSourceStop("data_source");
+  consumer->WaitForTracingDisabled();
 
   // Verify the contents of the file.
   std::string trace_raw;
@@ -470,6 +311,6 @@
     ASSERT_EQ(kPayload + std::to_string(num_testing_packet++),
               tp.for_testing().str());
   }
-}  // namespace perfetto
+}
 
 }  // namespace perfetto