Client API: Allow data sources to handle Stop asynchronously

Exposes the ability to defer tracing session stop to the public
producer-side API. This feature was already part of the tracing
protocol but was not exposed to the public API surface.

Also makes some minor cleanups:
- Rename instance_state->started to trace_lambda_enabled
  to better reflect what the field does.
- Rename {android_,}client_api_example as the code is not really
  Android-specific.

Bug: 132678367
Bug: 137210068
Change-Id: I446d2ce9df1c9d4517a8b078813b9cc4138aef20
diff --git a/Android.bp b/Android.bp
index 52d9186..bbbdc3e 100644
--- a/Android.bp
+++ b/Android.bp
@@ -4256,7 +4256,7 @@
 cc_binary {
   name: "libperfetto_client_example",
   srcs: [
-    "test/android_client_api_example.cc",
+    "test/client_api_example.cc",
   ],
   static_libs: [
     "libperfetto_client_experimental",
diff --git a/Android.bp.extras b/Android.bp.extras
index a2db5ed..3d89bce 100644
--- a/Android.bp.extras
+++ b/Android.bp.extras
@@ -72,7 +72,7 @@
 cc_binary {
   name: "libperfetto_client_example",
   srcs: [
-    "test/android_client_api_example.cc",
+    "test/client_api_example.cc",
   ],
   static_libs: [
     "libperfetto_client_experimental",
diff --git a/BUILD.gn b/BUILD.gn
index f50d0c9..bf85e54 100644
--- a/BUILD.gn
+++ b/BUILD.gn
@@ -63,6 +63,7 @@
       "protos/perfetto/trace:merged_trace",  # For syntax-checking the proto.
       "src/ipc/protoc_plugin:ipc_plugin($host_toolchain)",
       "src/tracing:client_api",
+      "test:client_api_example",
       "tools:protoc_helper",
     ]
     if (perfetto_build_standalone) {
diff --git a/include/perfetto/tracing/data_source.h b/include/perfetto/tracing/data_source.h
index 9d2415d..b4338ab 100644
--- a/include/perfetto/tracing/data_source.h
+++ b/include/perfetto/tracing/data_source.h
@@ -52,20 +52,48 @@
  public:
   virtual ~DataSourceBase();
 
+  // TODO(primiano): change the const& args below to be pointers instead. It
+  // makes it more awkward to handle output arguments and require mutable(s).
+  // This requires synchronizing a breaking API change for existing embedders.
+
   // OnSetup() is invoked when tracing is configured. In most cases this happens
   // just before starting the trace. In the case of deferred start (see
   // deferred_start in trace_config.proto) start might happen later.
-  struct SetupArgs {
+  class SetupArgs {
+   public:
     // This is valid only within the scope of the OnSetup() call and must not
     // be retained.
     const DataSourceConfig* config = nullptr;
   };
   virtual void OnSetup(const SetupArgs&);
 
-  struct StartArgs {};
+  class StartArgs {};
   virtual void OnStart(const StartArgs&);
 
-  struct StopArgs {};
+  class StopArgs {
+   public:
+    virtual ~StopArgs();
+
+    // HandleAsynchronously() can optionally be called to defer the tracing
+    // session stop and write tracing data just before stopping.
+    // This function returns a closure that must be invoked after the last
+    // trace events have been emitted. The returned closure can be called from
+    // any thread. The caller also needs to explicitly call TraceContext.Flush()
+    // from the last Trace() lambda invocation because no other implicit flushes
+    // will happen after the stop signal.
+    // When this function is called, the tracing service will defer the stop of
+    // the tracing session until the returned closure is invoked.
+    // However, the caller cannot hang onto this closure for too long. The
+    // tracing service will forcefully stop the tracing session without waiting
+    // for pending producers after TraceConfig.data_source_stop_timeout_ms
+    // (default: 5s, can be overridden by Consumers when starting a trace).
+    // If the closure is called after this timeout an error will be logged and
+    // the trace data emitted will not be present in the trace. No other
+    // functional side effects (e.g. crashes or corruptions) will happen. In
+    // other words, it is fine to accidentally hold onto this closure for too
+    // long but, if that happens, some tracing data will be lost.
+    virtual std::function<void()> HandleStopAsynchronously() const = 0;
+  };
   virtual void OnStop(const StopArgs&);
 };
 
@@ -232,7 +260,7 @@
         instances =
             static_state_.valid_instances.load(std::memory_order_acquire);
         instance_state = static_state_.TryGetCached(instances, i);
-        if (!instance_state || !instance_state->started)
+        if (!instance_state || !instance_state->trace_lambda_enabled)
           return;
         tls_inst.backend_id = instance_state->backend_id;
         tls_inst.buffer_id = instance_state->buffer_id;
diff --git a/include/perfetto/tracing/internal/data_source_internal.h b/include/perfetto/tracing/internal/data_source_internal.h
index 9ce027a..dde8786 100644
--- a/include/perfetto/tracing/internal/data_source_internal.h
+++ b/include/perfetto/tracing/internal/data_source_internal.h
@@ -44,10 +44,20 @@
 // There is one of these object per DataSource instance (up to
 // kMaxDataSourceInstances).
 struct DataSourceState {
-  // If false the data source is initialized but not started yet (or stopped).
-  // This is set right before calling OnStart() and cleared right before calling
-  // OnStop()
-  bool started = false;
+  // This boolean flag determines whether the DataSource::Trace() method should
+  // do something or be a no-op. This flag doesn't give the full guarantee
+  // that tracing data will be visible in the trace, it just makes it so that
+  // the client attemps writing trace data and interacting with the service.
+  // For instance, when a tracing session ends the service will reject data
+  // commits that arrive too late even if the producer hasn't received the stop
+  // IPC message.
+  // This flag is set right before calling OnStart() and cleared right before
+  // calling OnStop(), unless using HandleStopAsynchronously() (see comments
+  // in data_source.h).
+  // Keep this flag as the first field. This allows the compiler to directly
+  // dereference the DataSourceState* pointer in the trace fast-path without
+  // doing extra pointr arithmetic.
+  bool trace_lambda_enabled = false;
 
   // The central buffer id that all TraceWriter(s) created by this data source
   // must target.
diff --git a/include/perfetto/tracing/internal/tracing_muxer.h b/include/perfetto/tracing/internal/tracing_muxer.h
index 773f1e6..1e9a6d7 100644
--- a/include/perfetto/tracing/internal/tracing_muxer.h
+++ b/include/perfetto/tracing/internal/tracing_muxer.h
@@ -98,7 +98,7 @@
   static TracingMuxer* instance_;
   Platform* const platform_ = nullptr;
 
-  // Incremented upon each data source stop. See comment in tracing_tls.h.
+  // Incremented every time a data source is destroyed. See tracing_tls.h.
   std::atomic<uint32_t> generation_{};
 };
 
diff --git a/src/tracing/api_integrationtest.cc b/src/tracing/api_integrationtest.cc
index 046bace..f57ff10 100644
--- a/src/tracing/api_integrationtest.cc
+++ b/src/tracing/api_integrationtest.cc
@@ -86,6 +86,8 @@
   WaitableTestEvent on_stop;
   MockDataSource* instance;
   perfetto::DataSourceConfig config;
+  bool handle_stop_asynchronously = false;
+  std::function<void()> async_stop_closure;
 };
 
 class MockDataSource : public perfetto::DataSource<MockDataSource> {
@@ -179,8 +181,10 @@
   handle_->on_start.Notify();
 }
 
-void MockDataSource::OnStop(const StopArgs&) {
+void MockDataSource::OnStop(const StopArgs& args) {
   EXPECT_NE(handle_, nullptr);
+  if (handle_->handle_stop_asynchronously)
+    handle_->async_stop_closure = args.HandleStopAsynchronously();
   handle_->on_stop.Notify();
 }
 
@@ -295,6 +299,70 @@
   EXPECT_TRUE(tracing_session->on_stop.notified());
 }
 
+TEST_F(PerfettoApiTest, WriteEventsAfterDeferredStop) {
+  auto* data_source = &data_sources_["my_data_source"];
+  data_source->handle_stop_asynchronously = true;
+
+  // Setup the trace config and start the tracing session.
+  perfetto::TraceConfig cfg;
+  cfg.set_duration_ms(500);
+  cfg.add_buffers()->set_size_kb(1024);
+  auto* ds_cfg = cfg.add_data_sources()->mutable_config();
+  ds_cfg->set_name("my_data_source");
+  auto* tracing_session = NewTrace(cfg);
+  tracing_session->get()->StartBlocking();
+
+  // Stop and wait for the producer to have seen the stop event.
+  WaitableTestEvent consumer_stop_signal;
+  tracing_session->get()->SetOnStopCallback(
+      [&consumer_stop_signal] { consumer_stop_signal.Notify(); });
+  tracing_session->get()->Stop();
+  data_source->on_stop.Wait();
+
+  // At this point tracing should be still allowed because of the
+  // HandleStopAsynchronously() call.
+  bool lambda_called = false;
+
+  // This usleep is here just to prevent that we accidentally pass the test
+  // just by virtue of hitting some race. We should be able to trace up until
+  // 5 seconds after seeing the stop when using the deferred stop mechanism.
+  usleep(250 * 1000);
+
+  MockDataSource::Trace([&lambda_called](MockDataSource::TraceContext ctx) {
+    auto packet = ctx.NewTracePacket();
+    packet->set_for_testing()->set_str("event written after OnStop");
+    packet->Finalize();
+    ctx.Flush();
+    lambda_called = true;
+  });
+  ASSERT_TRUE(lambda_called);
+
+  // Now call the async stop closure. This acks the stop to the service and
+  // disallows further Trace() calls.
+  EXPECT_TRUE(data_source->async_stop_closure);
+  data_source->async_stop_closure();
+
+  // Wait that the stop is propagated to the consumer.
+  consumer_stop_signal.Wait();
+
+  MockDataSource::Trace([](MockDataSource::TraceContext) {
+    FAIL() << "Should not be called after the stop is acked";
+  });
+
+  // Check the contents of the trace.
+  std::vector<char> raw_trace = tracing_session->get()->ReadTraceBlocking();
+  ASSERT_GE(raw_trace.size(), 0u);
+  perfetto::protos::Trace trace;
+  ASSERT_TRUE(trace.ParseFromArray(raw_trace.data(), int(raw_trace.size())));
+  int test_packet_found = 0;
+  for (const auto& packet : trace.packet()) {
+    if (!packet.has_for_testing())
+      continue;
+    EXPECT_EQ(packet.for_testing().str(), "event written after OnStop");
+    test_packet_found++;
+  }
+  EXPECT_EQ(test_packet_found, 1);
+}
 }  // namespace
 
 PERFETTO_DEFINE_DATA_SOURCE_STATIC_MEMBERS(MockDataSource);
diff --git a/src/tracing/data_source.cc b/src/tracing/data_source.cc
index edb00fd..418ba43 100644
--- a/src/tracing/data_source.cc
+++ b/src/tracing/data_source.cc
@@ -18,6 +18,7 @@
 
 namespace perfetto {
 
+DataSourceBase::StopArgs::~StopArgs() = default;
 DataSourceBase::~DataSourceBase() = default;
 void DataSourceBase::OnSetup(const SetupArgs&) {}
 void DataSourceBase::OnStart(const StartArgs&) {}
diff --git a/src/tracing/internal/tracing_muxer_impl.cc b/src/tracing/internal/tracing_muxer_impl.cc
index be96d0b..bb91811 100644
--- a/src/tracing/internal/tracing_muxer_impl.cc
+++ b/src/tracing/internal/tracing_muxer_impl.cc
@@ -41,6 +41,21 @@
 namespace perfetto {
 namespace internal {
 
+namespace {
+
+class StopArgsImpl : public DataSourceBase::StopArgs {
+ public:
+  std::function<void()> HandleStopAsynchronously() const override {
+    auto closure = std::move(async_stop_closure);
+    async_stop_closure = std::function<void()>();
+    return closure;
+  }
+
+  mutable std::function<void()> async_stop_closure;
+};
+
+}  // namespace
+
 // ----- Begin of TracingMuxerImpl::ProducerImpl
 TracingMuxerImpl::ProducerImpl::ProducerImpl(TracingMuxerImpl* muxer,
                                              TracingBackendId backend_id)
@@ -86,8 +101,7 @@
 
 void TracingMuxerImpl::ProducerImpl::StopDataSource(DataSourceInstanceID id) {
   PERFETTO_DCHECK_THREAD(thread_checker_);
-  muxer_->StopDataSource(backend_id_, id);
-  service_->NotifyDataSourceStopped(id);
+  muxer_->StopDataSource_AsyncBegin(backend_id_, id);
 }
 
 void TracingMuxerImpl::ProducerImpl::Flush(FlushRequestID flush_id,
@@ -484,65 +498,93 @@
                                        DataSourceInstanceID instance_id) {
   PERFETTO_DLOG("Starting data source %" PRIu64, instance_id);
   PERFETTO_DCHECK_THREAD(thread_checker_);
-  for (const auto& rds : data_sources_) {
-    DataSourceStaticState& static_state = *rds.static_state;
-    for (uint32_t i = 0; i < kMaxDataSourceInstances; i++) {
-      auto* internal_state = static_state.TryGet(i);
-      if (!internal_state)
-        continue;
 
-      if (internal_state->backend_id != backend_id ||
-          internal_state->data_source_instance_id != instance_id) {
-        continue;
-      }
-
-      std::lock_guard<std::mutex> guard(internal_state->lock);
-      internal_state->started = true;
-      internal_state->data_source->OnStart(DataSourceBase::StartArgs{});
-      return;
-    }
+  auto ds = FindDataSource(backend_id, instance_id);
+  if (!ds) {
+    PERFETTO_ELOG("Could not find data source to start");
+    return;
   }
-  PERFETTO_ELOG("Could not find data source to start");
+
+  std::lock_guard<std::mutex> guard(ds.internal_state->lock);
+  ds.internal_state->trace_lambda_enabled = true;
+  ds.internal_state->data_source->OnStart(DataSourceBase::StartArgs{});
 }
 
 // Called by the service of one of the backends.
-void TracingMuxerImpl::StopDataSource(TracingBackendId backend_id,
-                                      DataSourceInstanceID instance_id) {
+void TracingMuxerImpl::StopDataSource_AsyncBegin(
+    TracingBackendId backend_id,
+    DataSourceInstanceID instance_id) {
   PERFETTO_DLOG("Stopping data source %" PRIu64, instance_id);
   PERFETTO_DCHECK_THREAD(thread_checker_);
-  for (const auto& rds : data_sources_) {
-    DataSourceStaticState& static_state = *rds.static_state;
-    for (uint32_t i = 0; i < kMaxDataSourceInstances; i++) {
-      auto* internal_state = static_state.TryGet(i);
-      if (!internal_state)
-        continue;
 
-      if (internal_state->backend_id != backend_id ||
-          internal_state->data_source_instance_id != instance_id) {
-        continue;
-      }
-
-      static_state.valid_instances.fetch_and(~(1 << i),
-                                             std::memory_order_acq_rel);
-
-      // Take the mutex to prevent that the data source is in the middle of
-      // a Trace() execution where it called GetDataSourceLocked() while we
-      // destroy it.
-      {
-        std::lock_guard<std::mutex> guard(internal_state->lock);
-        internal_state->started = false;
-        internal_state->data_source->OnStop(DataSourceBase::StopArgs{});
-        internal_state->data_source.reset();
-      }
-
-      // The other fields of internal_state are deliberately *not* cleared.
-      // See races-related comments of DataSource::Trace().
-
-      TracingMuxer::generation_++;
-      return;
-    }
+  auto ds = FindDataSource(backend_id, instance_id);
+  if (!ds) {
+    PERFETTO_ELOG("Could not find data source to stop");
+    return;
   }
-  PERFETTO_ELOG("Could not find data source to stop");
+
+  StopArgsImpl stop_args{};
+  stop_args.async_stop_closure = [this, backend_id, instance_id] {
+    // TracingMuxerImpl is long lived, capturing |this| is okay.
+    // The notification closure can be moved out of the StopArgs by the
+    // embedder to handle stop asynchronously. The embedder might then
+    // call the closure on a different thread than the current one, hence
+    // this nested PostTask().
+    task_runner_->PostTask([this, backend_id, instance_id] {
+      StopDataSource_AsyncEnd(backend_id, instance_id);
+    });
+  };
+
+  {
+    std::lock_guard<std::mutex> guard(ds.internal_state->lock);
+    ds.internal_state->data_source->OnStop(stop_args);
+  }
+
+  // If the embedder hasn't called StopArgs.HandleStopAsynchronously() run the
+  // async closure here. In theory we could avoid the PostTask and call
+  // straight into CompleteDataSourceAsyncStop(). We keep that to reduce
+  // divergencies between the deferred-stop vs non-deferred-stop code paths.
+  if (stop_args.async_stop_closure)
+    std::move(stop_args.async_stop_closure)();
+}
+
+void TracingMuxerImpl::StopDataSource_AsyncEnd(
+    TracingBackendId backend_id,
+    DataSourceInstanceID instance_id) {
+  PERFETTO_DLOG("Ending async stop of data source %" PRIu64, instance_id);
+  PERFETTO_DCHECK_THREAD(thread_checker_);
+
+  auto ds = FindDataSource(backend_id, instance_id);
+  if (!ds) {
+    PERFETTO_ELOG(
+        "Async stop of data source %" PRIu64
+        " failed. This might be due to calling the async_stop_closure twice.",
+        instance_id);
+    return;
+  }
+
+  const uint32_t mask = ~(1 << ds.instance_idx);
+  ds.static_state->valid_instances.fetch_and(mask, std::memory_order_acq_rel);
+
+  // Take the mutex to prevent that the data source is in the middle of
+  // a Trace() execution where it called GetDataSourceLocked() while we
+  // destroy it.
+  {
+    std::lock_guard<std::mutex> guard(ds.internal_state->lock);
+    ds.internal_state->trace_lambda_enabled = false;
+    ds.internal_state->data_source.reset();
+  }
+
+  // The other fields of internal_state are deliberately *not* cleared.
+  // See races-related comments of DataSource::Trace().
+
+  TracingMuxer::generation_++;
+
+  // |backends_| is append-only, Backend instances are always valid.
+  PERFETTO_CHECK(backend_id < backends_.size());
+  ProducerImpl* producer = backends_[backend_id].producer.get();
+  if (producer && producer->connected_)
+    producer->service_->NotifyDataSourceStopped(instance_id);
 }
 
 void TracingMuxerImpl::DestroyStoppedTraceWritersForCurrentThread() {
@@ -709,6 +751,23 @@
   return nullptr;
 }
 
+TracingMuxerImpl::FindDataSourceRes TracingMuxerImpl::FindDataSource(
+    TracingBackendId backend_id,
+    DataSourceInstanceID instance_id) {
+  PERFETTO_DCHECK_THREAD(thread_checker_);
+  for (const auto& rds : data_sources_) {
+    DataSourceStaticState* static_state = rds.static_state;
+    for (uint32_t i = 0; i < kMaxDataSourceInstances; i++) {
+      auto* internal_state = static_state->TryGet(i);
+      if (internal_state && internal_state->backend_id == backend_id &&
+          internal_state->data_source_instance_id == instance_id) {
+        return FindDataSourceRes(static_state, internal_state, i);
+      }
+    }
+  }
+  return FindDataSourceRes();
+}
+
 // Can be called from any thread.
 std::unique_ptr<TraceWriterBase> TracingMuxerImpl::CreateTraceWriter(
     DataSourceState* data_source) {
diff --git a/src/tracing/internal/tracing_muxer_impl.h b/src/tracing/internal/tracing_muxer_impl.h
index 8ff4b9d..5402577 100644
--- a/src/tracing/internal/tracing_muxer_impl.h
+++ b/src/tracing/internal/tracing_muxer_impl.h
@@ -109,7 +109,8 @@
                        DataSourceInstanceID,
                        const DataSourceConfig&);
   void StartDataSource(TracingBackendId, DataSourceInstanceID);
-  void StopDataSource(TracingBackendId, DataSourceInstanceID);
+  void StopDataSource_AsyncBegin(TracingBackendId, DataSourceInstanceID);
+  void StopDataSource_AsyncEnd(TracingBackendId, DataSourceInstanceID);
 
   // Consumer-side bookkeeping methods.
   void SetupTracingSession(TracingSessionGlobalID,
@@ -271,6 +272,18 @@
   void Initialize(const TracingInitArgs& args);
   ConsumerImpl* FindConsumer(TracingSessionGlobalID session_id);
 
+  struct FindDataSourceRes {
+    FindDataSourceRes() = default;
+    FindDataSourceRes(DataSourceStaticState* a, DataSourceState* b, uint32_t c)
+        : static_state(a), internal_state(b), instance_idx(c) {}
+    explicit operator bool() const { return !!internal_state; }
+
+    DataSourceStaticState* static_state = nullptr;
+    DataSourceState* internal_state = nullptr;
+    uint32_t instance_idx = 0;
+  };
+  FindDataSourceRes FindDataSource(TracingBackendId, DataSourceInstanceID);
+
   std::unique_ptr<base::TaskRunner> task_runner_;
   std::vector<RegisteredDataSource> data_sources_;
   std::vector<RegisteredBackend> backends_;
diff --git a/test/BUILD.gn b/test/BUILD.gn
index 69a7513..81d4e32 100644
--- a/test/BUILD.gn
+++ b/test/BUILD.gn
@@ -53,6 +53,17 @@
   }
 }
 
+executable("client_api_example") {
+  sources = [
+    "client_api_example.cc",
+  ]
+  deps = [
+    "..:libperfetto_client_experimental",
+    "../gn:default_deps",
+    "../protos/perfetto/trace:zero",
+  ]
+}
+
 perfetto_fuzzer_test("end_to_end_shared_memory_fuzzer") {
   sources = [
     "end_to_end_shared_memory_fuzzer.cc",
diff --git a/test/android_client_api_example.cc b/test/client_api_example.cc
similarity index 67%
rename from test/android_client_api_example.cc
rename to test/client_api_example.cc
index 459dd5f..f43bd98 100644
--- a/test/android_client_api_example.cc
+++ b/test/client_api_example.cc
@@ -16,13 +16,16 @@
 
 #include "perfetto/tracing.h"
 
+#include <thread>
+
 #include "perfetto/trace/test_event.pbzero.h"
-#include "perfetto/trace/trace.pb.h"
 #include "perfetto/trace/trace_packet.pbzero.h"
 
 // Deliberately not pulling any non-public perfetto header to spot accidental
 // header public -> non-public dependency while building this file.
 
+namespace {
+
 class MyDataSource : public perfetto::DataSource<MyDataSource> {
  public:
   void OnSetup(const SetupArgs& args) override {
@@ -33,9 +36,32 @@
 
   void OnStart(const StartArgs&) override { PERFETTO_ILOG("OnStart called"); }
 
-  void OnStop(const StopArgs&) override { PERFETTO_ILOG("OnStop called"); }
+  void OnStop(const StopArgs& args) override {
+    PERFETTO_ILOG("OnStop called");
+
+    // Demonstrates the ability to defer stop and handle it asynchronously,
+    // writing data at the very end of the trace.
+    auto stop_closure = args.HandleStopAsynchronously();
+    std::thread another_thread([stop_closure] {
+      sleep(2);
+      MyDataSource::Trace([](MyDataSource::TraceContext ctx) {
+        PERFETTO_LOG("Tracing lambda called while stopping");
+        auto packet = ctx.NewTracePacket();
+        packet->set_for_testing()->set_str("event recorded while stopping");
+        packet->Finalize();  //  Required because of the Flush below.
+
+        // This explicit Flush() is required because the service doesn't issue
+        // any other flush requests after the Stop() signal.
+        ctx.Flush();
+      });
+      stop_closure();
+    });
+    another_thread.detach();
+  }
 };
 
+}  // namespace
+
 PERFETTO_DEFINE_DATA_SOURCE_STATIC_MEMBERS(MyDataSource);
 
 int main() {
diff --git a/test/configs/BUILD.gn b/test/configs/BUILD.gn
index fcbc2b6..80b5fc8 100644
--- a/test/configs/BUILD.gn
+++ b/test/configs/BUILD.gn
@@ -28,6 +28,7 @@
       "atrace.cfg",
       "background.cfg",
       "camera.cfg",
+      "client_api.cfg",
       "ftrace.cfg",
       "ftrace_largebuffer.cfg",
       "heapprofd.cfg",
diff --git a/test/configs/client_api.cfg b/test/configs/client_api.cfg
new file mode 100644
index 0000000..c610ab8
--- /dev/null
+++ b/test/configs/client_api.cfg
@@ -0,0 +1,12 @@
+# This config is for trying out test/client_api_example.cc .
+
+buffers {
+  size_kb: 1024
+  fill_policy: RING_BUFFER
+}
+
+data_sources {
+  config {
+    name: "com.example.mytrace"
+  }
+}