service: Track writer registrations and target buffer associations.

Also verify that each commit request from a producer targets the same
buffer as the earlier registration (if any) of the writer that made the
commit.

Bug: 73828976
Change-Id: I0bcd357f5d36cedb124f8e8379fd52da64995735
diff --git a/src/tracing/core/service_impl_unittest.cc b/src/tracing/core/service_impl_unittest.cc
index accb28c..cbf2da6 100644
--- a/src/tracing/core/service_impl_unittest.cc
+++ b/src/tracing/core/service_impl_unittest.cc
@@ -31,6 +31,7 @@
 #include "perfetto/tracing/core/trace_packet.h"
 #include "perfetto/tracing/core/trace_writer.h"
 #include "src/base/test/test_task_runner.h"
+#include "src/tracing/core/trace_writer_impl.h"
 #include "src/tracing/test/mock_consumer.h"
 #include "src/tracing/test/mock_producer.h"
 #include "src/tracing/test/test_shared_memory.h"
@@ -96,6 +97,10 @@
     return svc->GetProducer(producer_id)->allowed_target_buffers_;
   }
 
+  const std::map<WriterID, BufferID>& GetWriters(ProducerID pid) {
+    return svc->GetProducer(pid)->writers_;  // WriterID -> target BufferID.
+  }
+
   size_t GetNumPendingFlushes() {
     return tracing_session()->pending_flushes.size();
   }
@@ -111,6 +116,24 @@
     }
   }
 
+  // Runs |task_runner| until the writer map of |producer_id| changes.
+  void WaitForTraceWritersChanged(ProducerID producer_id) {
+    static int counter = 0;
+    const std::string name = "writers_changed_" + std::to_string(producer_id) +
+                             "_" + std::to_string(counter++);
+    auto on_changed = task_runner.CreateCheckpoint(name);
+    const std::map<WriterID, BufferID> snapshot = GetWriters(producer_id);
+    std::function<void()> poll;  // Re-posts itself until the map changes.
+    poll = [&poll, snapshot, on_changed, producer_id, this]() {
+      if (snapshot != GetWriters(producer_id))
+        on_changed();
+      else
+        task_runner.PostDelayedTask(poll, 1);
+    };
+    task_runner.PostDelayedTask(poll, 1);
+    task_runner.RunUntilCheckpoint(name);
+  }
+
   base::TestTaskRunner task_runner;
   std::unique_ptr<TracingServiceImpl> svc;
 };
@@ -1058,4 +1081,65 @@
 }
 #endif  // !PERFETTO_DCHECK_IS_ON()
 
+TEST_F(TracingServiceImplTest, RegisterAndUnregisterTraceWriter) {
+  std::unique_ptr<MockConsumer> consumer = CreateMockConsumer();
+  consumer->Connect(svc.get());
+
+  std::unique_ptr<MockProducer> producer = CreateMockProducer();
+  producer->Connect(svc.get(), "mock_producer");
+  const ProducerID producer_id = *last_producer_id();
+  producer->RegisterDataSource("data_source");
+
+  // No trace writer exists yet, so the service has none registered.
+  EXPECT_TRUE(GetWriters(producer_id).empty());
+
+  TraceConfig trace_config;
+  trace_config.add_buffers()->set_size_kb(128);
+  auto* ds_config = trace_config.add_data_sources()->mutable_config();
+  ds_config->set_name("data_source");
+  ds_config->set_target_buffer(0);
+  consumer->EnableTracing(trace_config);
+
+  producer->WaitForTracingSetup();
+  producer->WaitForDataSourceSetup("data_source");
+  producer->WaitForDataSourceStart("data_source");
+
+  // Calling StartTracing() should be a no-op since |deferred_start| is unset.
+  consumer->StartTracing();
+
+  // Creating a writer for the session's buffer 0 (the one |ds_config|
+  // targets) should register it with the service under that buffer.
+  const BufferID target_buffer = tracing_session()->buffers_index[0];
+  std::unique_ptr<TraceWriter> writer =
+      producer->endpoint()->CreateTraceWriter(target_buffer);
+  WaitForTraceWritersChanged(producer_id);
+  const std::map<WriterID, BufferID> expected_writers = {
+      {writer->writer_id(), target_buffer}};
+  EXPECT_EQ(expected_writers, GetWriters(producer_id));
+
+  // Verify writing works.
+  {
+    auto packet = writer->NewTracePacket();
+    packet->set_for_testing()->set_str("payload");
+  }
+
+  auto flush_request = consumer->Flush();
+  producer->WaitForFlush(writer.get());
+  ASSERT_TRUE(flush_request.WaitForReply());
+
+  // Destroying the writer should unregister it.
+  writer.reset();
+  WaitForTraceWritersChanged(producer_id);
+  EXPECT_TRUE(GetWriters(producer_id).empty());
+
+  consumer->DisableTracing();
+  producer->WaitForDataSourceStop("data_source");
+  consumer->WaitForTracingDisabled();
+
+  EXPECT_THAT(consumer->ReadBuffers(),
+              Contains(Property(&protos::TracePacket::for_testing,
+                                Property(&protos::TestEvent::str,
+                                         Eq("payload")))));
+}
+
 }  // namespace perfetto