service: Let producers choose whether to enable SMB scraping

Adds a setting to ConnectProducer / ProducerIPCClient which allows a
producer to instruct the service to enable or disable SMB scraping for
the producer.

Bug: 133268735
Change-Id: I941c963edc901c7946db8f85abcc25b70d6bf59d
diff --git a/include/perfetto/tracing/core/tracing_service.h b/include/perfetto/tracing/core/tracing_service.h
index 3d27698..fd10957 100644
--- a/include/perfetto/tracing/core/tracing_service.h
+++ b/include/perfetto/tracing/core/tracing_service.h
@@ -219,6 +219,20 @@
   using ProducerEndpoint = perfetto::ProducerEndpoint;
   using ConsumerEndpoint = perfetto::ConsumerEndpoint;
 
+  enum class ProducerSMBScrapingMode {
+    // Use service's default setting for SMB scraping. Currently, the default
+    // mode is to disable SMB scraping, but this may change in the future.
+    kDefault,
+
+    // Enable scraping of uncommitted chunks in producers' shared memory
+    // buffers.
+    kEnabled,
+
+    // Disable scraping of uncommitted chunks in producers' shared memory
+    // buffers.
+    kDisabled
+  };
+
   // Implemented in src/core/tracing_service_impl.cc .
   static std::unique_ptr<TracingService> CreateInstance(
       std::unique_ptr<SharedMemory::Factory>,
@@ -252,7 +266,9 @@
       uid_t uid,
       const std::string& name,
       size_t shared_memory_size_hint_bytes = 0,
-      bool in_process = false) = 0;
+      bool in_process = false,
+      ProducerSMBScrapingMode smb_scraping_mode =
+          ProducerSMBScrapingMode::kDefault) = 0;
 
   // Connects a Consumer instance and obtains a ConsumerEndpoint, which is
   // essentially a 1:1 channel between one Consumer and the Service.
diff --git a/include/perfetto/tracing/ipc/producer_ipc_client.h b/include/perfetto/tracing/ipc/producer_ipc_client.h
index d77143d..533c97f 100644
--- a/include/perfetto/tracing/ipc/producer_ipc_client.h
+++ b/include/perfetto/tracing/ipc/producer_ipc_client.h
@@ -45,7 +45,9 @@
       const char* service_sock_name,
       Producer*,
       const std::string& producer_name,
-      base::TaskRunner*);
+      base::TaskRunner*,
+      TracingService::ProducerSMBScrapingMode smb_scraping_mode =
+          TracingService::ProducerSMBScrapingMode::kDefault);
 
  protected:
   ProducerIPCClient() = delete;
diff --git a/protos/perfetto/ipc/producer_port.proto b/protos/perfetto/ipc/producer_port.proto
index 24efc27..e129afd 100644
--- a/protos/perfetto/ipc/producer_port.proto
+++ b/protos/perfetto/ipc/producer_port.proto
@@ -95,6 +95,22 @@
   // Required to match the producer config set by the service to the correct
   // producer.
   optional string producer_name = 3;
+
+  enum ProducerSMBScrapingMode {
+    // Use the service's default setting for SMB scraping.
+    SMB_SCRAPING_UNSPECIFIED = 0;
+
+    // Enable scraping of uncommitted chunks from the producer's shared memory
+    // buffer.
+    SMB_SCRAPING_ENABLED = 1;
+
+    // Disable scraping of uncommitted chunks from the producer's shared memory
+    // buffer.
+    SMB_SCRAPING_DISABLED = 2;
+  }
+
+  // If provided, overrides the service's SMB scraping setting for the producer.
+  optional ProducerSMBScrapingMode smb_scraping_mode = 4;
 }
 
 message InitializeConnectionResponse {
diff --git a/src/tracing/core/tracing_service_impl.cc b/src/tracing/core/tracing_service_impl.cc
index c6aac42..db27f46 100644
--- a/src/tracing/core/tracing_service_impl.cc
+++ b/src/tracing/core/tracing_service_impl.cc
@@ -144,7 +144,8 @@
                                     uid_t uid,
                                     const std::string& producer_name,
                                     size_t shared_memory_size_hint_bytes,
-                                    bool in_process) {
+                                    bool in_process,
+                                    ProducerSMBScrapingMode smb_scraping_mode) {
   PERFETTO_DCHECK_THREAD(thread_checker_);
 
   if (lockdown_mode_ && uid != geteuid()) {
@@ -160,8 +161,21 @@
   const ProducerID id = GetNextProducerID();
   PERFETTO_DLOG("Producer %" PRIu16 " connected", id);
 
+  bool smb_scraping_enabled = smb_scraping_enabled_;
+  switch (smb_scraping_mode) {
+    case ProducerSMBScrapingMode::kDefault:
+      break;
+    case ProducerSMBScrapingMode::kEnabled:
+      smb_scraping_enabled = true;
+      break;
+    case ProducerSMBScrapingMode::kDisabled:
+      smb_scraping_enabled = false;
+      break;
+  }
+
   std::unique_ptr<ProducerEndpointImpl> endpoint(new ProducerEndpointImpl(
-      id, uid, this, task_runner_, producer, producer_name, in_process));
+      id, uid, this, task_runner_, producer, producer_name, in_process,
+      smb_scraping_enabled));
   auto it_and_inserted = producers_.emplace(id, endpoint.get());
   PERFETTO_DCHECK(it_and_inserted.second);
   endpoint->shmem_size_hint_bytes_ = shared_memory_size_hint_bytes;
@@ -1147,7 +1161,7 @@
 void TracingServiceImpl::ScrapeSharedMemoryBuffers(
     TracingSession* tracing_session,
     ProducerEndpointImpl* producer) {
-  if (!smb_scraping_enabled_)
+  if (!producer->smb_scraping_enabled_)
     return;
 
   // Can't copy chunks if we don't know about any trace writers.
@@ -2480,7 +2494,8 @@
     base::TaskRunner* task_runner,
     Producer* producer,
     const std::string& producer_name,
-    bool in_process)
+    bool in_process,
+    bool smb_scraping_enabled)
     : id_(id),
       uid_(uid),
       service_(service),
@@ -2488,6 +2503,7 @@
       producer_(producer),
       name_(producer_name),
       in_process_(in_process),
+      smb_scraping_enabled_(smb_scraping_enabled),
       weak_ptr_factory_(this) {}
 
 TracingServiceImpl::ProducerEndpointImpl::~ProducerEndpointImpl() {
diff --git a/src/tracing/core/tracing_service_impl.h b/src/tracing/core/tracing_service_impl.h
index fbae3b6..10ca797 100644
--- a/src/tracing/core/tracing_service_impl.h
+++ b/src/tracing/core/tracing_service_impl.h
@@ -76,7 +76,8 @@
                          base::TaskRunner*,
                          Producer*,
                          const std::string& producer_name,
-                         bool in_process);
+                         bool in_process,
+                         bool smb_scraping_enabled);
     ~ProducerEndpointImpl() override;
 
     // TracingService::ProducerEndpoint implementation.
@@ -118,6 +119,7 @@
    private:
     friend class TracingServiceImpl;
     friend class TracingServiceImplTest;
+    friend class TracingIntegrationTest;
     ProducerEndpointImpl(const ProducerEndpointImpl&) = delete;
     ProducerEndpointImpl& operator=(const ProducerEndpointImpl&) = delete;
 
@@ -132,6 +134,7 @@
     size_t shmem_size_hint_bytes_ = 0;
     const std::string name_;
     bool in_process_;
+    bool smb_scraping_enabled_;
 
     // Set of the global target_buffer IDs that the producer is configured to
     // write into in any active tracing session.
@@ -260,12 +263,16 @@
       uid_t uid,
       const std::string& producer_name,
       size_t shared_memory_size_hint_bytes = 0,
-      bool in_process = false) override;
+      bool in_process = false,
+      ProducerSMBScrapingMode smb_scraping_mode =
+          ProducerSMBScrapingMode::kDefault) override;
 
   std::unique_ptr<TracingService::ConsumerEndpoint> ConnectConsumer(
       Consumer*,
       uid_t) override;
 
+  // Set whether SMB scraping should be enabled by default or not. Producers can
+  // override this setting for their own SMBs.
   void SetSMBScrapingEnabled(bool enabled) override {
     smb_scraping_enabled_ = enabled;
   }
@@ -278,6 +285,7 @@
 
  private:
   friend class TracingServiceImplTest;
+  friend class TracingIntegrationTest;
 
   struct RegisteredDataSource {
     ProducerID producer_id;
diff --git a/src/tracing/ipc/producer/producer_ipc_client_impl.cc b/src/tracing/ipc/producer/producer_ipc_client_impl.cc
index 54bcb80..bf6e1b0 100644
--- a/src/tracing/ipc/producer/producer_ipc_client_impl.cc
+++ b/src/tracing/ipc/producer/producer_ipc_client_impl.cc
@@ -41,21 +41,25 @@
     const char* service_sock_name,
     Producer* producer,
     const std::string& producer_name,
-    base::TaskRunner* task_runner) {
+    base::TaskRunner* task_runner,
+    TracingService::ProducerSMBScrapingMode smb_scraping_mode) {
   return std::unique_ptr<TracingService::ProducerEndpoint>(
       new ProducerIPCClientImpl(service_sock_name, producer, producer_name,
-                                task_runner));
+                                task_runner, smb_scraping_mode));
 }
 
-ProducerIPCClientImpl::ProducerIPCClientImpl(const char* service_sock_name,
-                                             Producer* producer,
-                                             const std::string& producer_name,
-                                             base::TaskRunner* task_runner)
+ProducerIPCClientImpl::ProducerIPCClientImpl(
+    const char* service_sock_name,
+    Producer* producer,
+    const std::string& producer_name,
+    base::TaskRunner* task_runner,
+    TracingService::ProducerSMBScrapingMode smb_scraping_mode)
     : producer_(producer),
       task_runner_(task_runner),
       ipc_channel_(ipc::Client::CreateInstance(service_sock_name, task_runner)),
       producer_port_(this /* event_listener */),
-      name_(producer_name) {
+      name_(producer_name),
+      smb_scraping_mode_(smb_scraping_mode) {
   ipc_channel_->BindService(producer_port_.GetWeakPtr());
   PERFETTO_DCHECK_THREAD(thread_checker_);
 }
@@ -77,6 +81,20 @@
       });
   protos::InitializeConnectionRequest req;
   req.set_producer_name(name_);
+  switch (smb_scraping_mode_) {
+    case TracingService::ProducerSMBScrapingMode::kDefault:
+      // No need to set the mode, it defaults to use the service default if
+      // unspecified.
+      break;
+    case TracingService::ProducerSMBScrapingMode::kEnabled:
+      req.set_smb_scraping_mode(
+          protos::InitializeConnectionRequest::SMB_SCRAPING_ENABLED);
+      break;
+    case TracingService::ProducerSMBScrapingMode::kDisabled:
+      req.set_smb_scraping_mode(
+          protos::InitializeConnectionRequest::SMB_SCRAPING_DISABLED);
+      break;
+  }
   producer_port_.InitializeConnection(req, std::move(on_init));
 
   // Create the back channel to receive commands from the Service.
diff --git a/src/tracing/ipc/producer/producer_ipc_client_impl.h b/src/tracing/ipc/producer/producer_ipc_client_impl.h
index f020747..ba43421 100644
--- a/src/tracing/ipc/producer/producer_ipc_client_impl.h
+++ b/src/tracing/ipc/producer/producer_ipc_client_impl.h
@@ -55,7 +55,8 @@
   ProducerIPCClientImpl(const char* service_sock_name,
                         Producer*,
                         const std::string& producer_name,
-                        base::TaskRunner*);
+                        base::TaskRunner*,
+                        TracingService::ProducerSMBScrapingMode);
   ~ProducerIPCClientImpl() override;
 
   // TracingService::ProducerEndpoint implementation.
@@ -108,6 +109,7 @@
   std::set<DataSourceInstanceID> data_sources_setup_;
   bool connected_ = false;
   std::string const name_;
+  TracingService::ProducerSMBScrapingMode const smb_scraping_mode_;
   PERFETTO_THREAD_CHECKER(thread_checker_)
 };
 
diff --git a/src/tracing/ipc/service/producer_ipc_service.cc b/src/tracing/ipc/service/producer_ipc_service.cc
index 9051058..15f7808 100644
--- a/src/tracing/ipc/service/producer_ipc_service.cc
+++ b/src/tracing/ipc/service/producer_ipc_service.cc
@@ -65,10 +65,24 @@
   // Create a new entry.
   std::unique_ptr<RemoteProducer> producer(new RemoteProducer());
 
+  TracingService::ProducerSMBScrapingMode smb_scraping_mode =
+      TracingService::ProducerSMBScrapingMode::kDefault;
+  switch (req.smb_scraping_mode()) {
+    case protos::InitializeConnectionRequest::SMB_SCRAPING_UNSPECIFIED:
+      break;
+    case protos::InitializeConnectionRequest::SMB_SCRAPING_DISABLED:
+      smb_scraping_mode = TracingService::ProducerSMBScrapingMode::kDisabled;
+      break;
+    case protos::InitializeConnectionRequest::SMB_SCRAPING_ENABLED:
+      smb_scraping_mode = TracingService::ProducerSMBScrapingMode::kEnabled;
+      break;
+  }
+
   // ConnectProducer will call OnConnect() on the next task.
   producer->service_endpoint = core_service_->ConnectProducer(
       producer.get(), client_info.uid(), req.producer_name(),
-      req.shared_memory_size_hint_bytes());
+      req.shared_memory_size_hint_bytes(), /*in_process=*/false,
+      smb_scraping_mode);
 
   // Could happen if the service has too many producers connected.
   if (!producer->service_endpoint)
diff --git a/src/tracing/test/tracing_integration_test.cc b/src/tracing/test/tracing_integration_test.cc
index 41fd1d4..16a2181 100644
--- a/src/tracing/test/tracing_integration_test.cc
+++ b/src/tracing/test/tracing_integration_test.cc
@@ -33,6 +33,7 @@
 #include "perfetto/tracing/ipc/service_ipc_host.h"
 #include "src/base/test/test_task_runner.h"
 #include "src/ipc/test/test_socket.h"
+#include "src/tracing/core/tracing_service_impl.h"
 
 #include "perfetto/config/trace_config.pb.h"
 #include "perfetto/trace/test_event.pbzero.h"
@@ -112,6 +113,8 @@
   EXPECT_EQ(0u, buf_stats.abi_violations());
 }
 
+}  // namespace
+
 class TracingIntegrationTest : public ::testing::Test {
  public:
   void SetUp() override {
@@ -126,7 +129,7 @@
     // Create and connect a Producer.
     producer_endpoint_ = ProducerIPCClient::Connect(
         kProducerSockName, &producer_, "perfetto.mock_producer",
-        task_runner_.get());
+        task_runner_.get(), GetProducerSMBScrapingMode());
     auto on_producer_connect =
         task_runner_->CreateCheckpoint("on_producer_connect");
     EXPECT_CALL(producer_, OnConnect()).WillOnce(Invoke(on_producer_connect));
@@ -175,6 +178,39 @@
     DESTROY_TEST_SOCK(kConsumerSockName);
   }
 
+  virtual TracingService::ProducerSMBScrapingMode GetProducerSMBScrapingMode() {
+    return TracingService::ProducerSMBScrapingMode::kDefault;
+  }
+
+  void WaitForTraceWritersChanged(ProducerID producer_id) {
+    static int i = 0;
+    auto checkpoint_name = "writers_changed_" + std::to_string(producer_id) +
+                           "_" + std::to_string(i++);
+    auto writers_changed = task_runner_->CreateCheckpoint(checkpoint_name);
+    auto writers = GetWriters(producer_id);
+    std::function<void()> task;
+    task = [&task, writers, writers_changed, producer_id, this]() {
+      if (writers != GetWriters(producer_id)) {
+        writers_changed();
+        return;
+      }
+      task_runner_->PostDelayedTask(task, 1);
+    };
+    task_runner_->PostDelayedTask(task, 1);
+    task_runner_->RunUntilCheckpoint(checkpoint_name);
+  }
+
+  const std::map<WriterID, BufferID>& GetWriters(ProducerID producer_id) {
+    return reinterpret_cast<TracingServiceImpl*>(svc_->service())
+        ->GetProducer(producer_id)
+        ->writers_;
+  }
+
+  ProducerID* last_producer_id() {
+    return &reinterpret_cast<TracingServiceImpl*>(svc_->service())
+                ->last_producer_id_;
+  }
+
   std::unique_ptr<base::TestTaskRunner> task_runner_;
   std::unique_ptr<ServiceIPCHost> svc_;
   std::unique_ptr<TracingService::ProducerEndpoint> producer_endpoint_;
@@ -406,6 +442,103 @@
   ASSERT_GT(num_system_info_packet, 0u);
 }
 
+class TracingIntegrationTestWithSMBScrapingProducer
+    : public TracingIntegrationTest {
+ public:
+  TracingService::ProducerSMBScrapingMode GetProducerSMBScrapingMode()
+      override {
+    return TracingService::ProducerSMBScrapingMode::kEnabled;
+  }
+};
+
+TEST_F(TracingIntegrationTestWithSMBScrapingProducer, ScrapeOnFlush) {
+  // Start tracing.
+  TraceConfig trace_config;
+  trace_config.add_buffers()->set_size_kb(4096 * 10);
+  auto* ds_config = trace_config.add_data_sources()->mutable_config();
+  ds_config->set_name("perfetto.test");
+  ds_config->set_target_buffer(0);
+  consumer_endpoint_->EnableTracing(trace_config);
+
+  // At this point, the Producer should be asked to turn its data source on.
+
+  BufferID global_buf_id = 0;
+  auto on_create_ds_instance =
+      task_runner_->CreateCheckpoint("on_create_ds_instance");
+  EXPECT_CALL(producer_, OnTracingSetup());
+
+  EXPECT_CALL(producer_, SetupDataSource(_, _));
+  EXPECT_CALL(producer_, StartDataSource(_, _))
+      .WillOnce(Invoke([on_create_ds_instance, &global_buf_id](
+                           DataSourceInstanceID, const DataSourceConfig& cfg) {
+        global_buf_id = static_cast<BufferID>(cfg.target_buffer());
+        on_create_ds_instance();
+      }));
+  task_runner_->RunUntilCheckpoint("on_create_ds_instance");
+
+  // Create writer, which will post a task to register the writer with the
+  // service.
+  std::unique_ptr<TraceWriter> writer =
+      producer_endpoint_->CreateTraceWriter(global_buf_id);
+  ASSERT_TRUE(writer);
+
+  // Wait for the writer to be registered.
+  WaitForTraceWritersChanged(*last_producer_id());
+
+  // Write a few trace packets.
+  writer->NewTracePacket()->set_for_testing()->set_str("payload1");
+  writer->NewTracePacket()->set_for_testing()->set_str("payload2");
+  writer->NewTracePacket()->set_for_testing()->set_str("payload3");
+
+  // Ask the service to flush, but don't flush our trace writer. This should
+  // cause our uncommitted SMB chunk to be scraped.
+  auto on_flush_complete = task_runner_->CreateCheckpoint("on_flush_complete");
+  consumer_endpoint_->Flush(5000, [on_flush_complete](bool success) {
+    EXPECT_TRUE(success);
+    on_flush_complete();
+  });
+  EXPECT_CALL(producer_, Flush(_, _, _))
+      .WillOnce(Invoke([this](FlushRequestID flush_req_id,
+                              const DataSourceInstanceID*, size_t) {
+        producer_endpoint_->NotifyFlushComplete(flush_req_id);
+      }));
+  task_runner_->RunUntilCheckpoint("on_flush_complete");
+
+  // Read the log buffer. We should only see the first two written trace
+  // packets, because the service can't be sure the last one was written
+  // completely by the trace writer.
+  consumer_endpoint_->ReadBuffers();
+
+  size_t num_test_pack_rx = 0;
+  auto all_packets_rx = task_runner_->CreateCheckpoint("all_packets_rx");
+  EXPECT_CALL(consumer_, OnTracePackets(_, _))
+      .WillRepeatedly(
+          Invoke([&num_test_pack_rx, all_packets_rx](
+                     std::vector<TracePacket>* packets, bool has_more) {
+            for (auto& encoded_packet : *packets) {
+              protos::TracePacket packet;
+              ASSERT_TRUE(encoded_packet.Decode(&packet));
+              if (packet.has_for_testing()) {
+                num_test_pack_rx++;
+              }
+            }
+            if (!has_more)
+              all_packets_rx();
+          }));
+  task_runner_->RunUntilCheckpoint("all_packets_rx");
+  ASSERT_EQ(2, num_test_pack_rx);
+
+  // Disable tracing.
+  consumer_endpoint_->DisableTracing();
+
+  auto on_tracing_disabled =
+      task_runner_->CreateCheckpoint("on_tracing_disabled");
+  EXPECT_CALL(producer_, StopDataSource(_));
+  EXPECT_CALL(consumer_, OnTracingDisabled())
+      .WillOnce(Invoke(on_tracing_disabled));
+  task_runner_->RunUntilCheckpoint("on_tracing_disabled");
+}
+
 // TODO(primiano): add tests to cover:
 // - unknown fields preserved end-to-end.
 // - >1 data source.
@@ -416,5 +549,4 @@
 // - Out of order Enable/Disable/FreeBuffers calls.
 // - DisableTracing does actually freeze the buffers.
 
-}  // namespace
 }  // namespace perfetto