sdk: Support embedder-provided policy approving consumer connections

ChromeOS needs a way to deny (system) consumer connections when a
specific enterprise policy is set.

This patch adds a TracingPolicy object to InitTracingArgs which is
notified for each new consumer session, allowing the embedder to deny
them. Since Chrome can only check for this policy on a specific
thread, the approval can be asynchronous.

To support this, the SDK's ConsumerImpl object is created before the
service consumer connection is initialized. This works fine, because
subsequent interactions with the object already have to consider a
delayed service connection.

Bug: 183391449
Change-Id: Ibe70e01a69bb8cfe0447e986647a499de90c5012
diff --git a/Android.bp b/Android.bp
index c3e2ce7..10d3039 100644
--- a/Android.bp
+++ b/Android.bp
@@ -8437,6 +8437,7 @@
     "src/tracing/platform.cc",
     "src/tracing/traced_value.cc",
     "src/tracing/tracing.cc",
+    "src/tracing/tracing_policy.cc",
     "src/tracing/track.cc",
     "src/tracing/track_event_category_registry.cc",
     "src/tracing/track_event_legacy.cc",
diff --git a/BUILD b/BUILD
index 1699bdb..80fb3be 100644
--- a/BUILD
+++ b/BUILD
@@ -508,6 +508,7 @@
 filegroup(
     name = "include_perfetto_tracing_tracing",
     srcs = [
+        "include/perfetto/tracing/backend_type.h",
         "include/perfetto/tracing/buffer_exhausted_policy.h",
         "include/perfetto/tracing/console_interceptor.h",
         "include/perfetto/tracing/data_source.h",
@@ -535,6 +536,7 @@
         "include/perfetto/tracing/traced_value_forward.h",
         "include/perfetto/tracing/tracing.h",
         "include/perfetto/tracing/tracing_backend.h",
+        "include/perfetto/tracing/tracing_policy.h",
         "include/perfetto/tracing/track.h",
         "include/perfetto/tracing/track_event.h",
         "include/perfetto/tracing/track_event_category_registry.h",
@@ -1625,6 +1627,7 @@
         "src/tracing/platform.cc",
         "src/tracing/traced_value.cc",
         "src/tracing/tracing.cc",
+        "src/tracing/tracing_policy.cc",
         "src/tracing/track.cc",
         "src/tracing/track_event_category_registry.cc",
         "src/tracing/track_event_legacy.cc",
diff --git a/include/perfetto/tracing/BUILD.gn b/include/perfetto/tracing/BUILD.gn
index 607aa5f..c840df1 100644
--- a/include/perfetto/tracing/BUILD.gn
+++ b/include/perfetto/tracing/BUILD.gn
@@ -27,6 +27,7 @@
   ]
 
   sources = [
+    "backend_type.h",
     "buffer_exhausted_policy.h",
     "console_interceptor.h",
     "data_source.h",
@@ -54,6 +55,7 @@
     "traced_value_forward.h",
     "tracing.h",
     "tracing_backend.h",
+    "tracing_policy.h",
     "track.h",
     "track_event.h",
     "track_event_category_registry.h",
diff --git a/include/perfetto/tracing/backend_type.h b/include/perfetto/tracing/backend_type.h
new file mode 100644
index 0000000..e242650
--- /dev/null
+++ b/include/perfetto/tracing/backend_type.h
@@ -0,0 +1,44 @@
+/*
+ * Copyright (C) 2021 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef INCLUDE_PERFETTO_TRACING_BACKEND_TYPE_H_
+#define INCLUDE_PERFETTO_TRACING_BACKEND_TYPE_H_
+
+#include <stdint.h>
+
+namespace perfetto {
+
+enum BackendType : uint32_t {
+  kUnspecifiedBackend = 0,
+
+  // Connects to a previously-initialized perfetto tracing backend for
+  // in-process. If the in-process backend has not been previously initialized
+  // it will do so and create the tracing service on a dedicated thread.
+  kInProcessBackend = 1 << 0,
+
+  // Connects to the system tracing service (e.g. on Linux/Android/Mac uses a
+  // named UNIX socket).
+  kSystemBackend = 1 << 1,
+
+  // Used to provide a custom IPC transport to connect to the service.
+  // TracingInitArgs::custom_backend must be non-null and point to an
+  // indefinitely lived instance.
+  kCustomBackend = 1 << 2,
+};
+
+}  // namespace perfetto
+
+#endif  // INCLUDE_PERFETTO_TRACING_BACKEND_TYPE_H_
diff --git a/include/perfetto/tracing/tracing.h b/include/perfetto/tracing/tracing.h
index b859eb1..0b7fe6e 100644
--- a/include/perfetto/tracing/tracing.h
+++ b/include/perfetto/tracing/tracing.h
@@ -28,9 +28,11 @@
 #include "perfetto/base/compiler.h"
 #include "perfetto/base/export.h"
 #include "perfetto/base/logging.h"
+#include "perfetto/tracing/backend_type.h"
 #include "perfetto/tracing/core/forward_decls.h"
 #include "perfetto/tracing/internal/in_process_tracing_backend.h"
 #include "perfetto/tracing/internal/system_tracing_backend.h"
+#include "perfetto/tracing/tracing_policy.h"
 
 namespace perfetto {
 
@@ -42,24 +44,6 @@
 class Platform;
 class TracingSession;  // Declared below.
 
-enum BackendType : uint32_t {
-  kUnspecifiedBackend = 0,
-
-  // Connects to a previously-initialized perfetto tracing backend for
-  // in-process. If the in-process backend has not been previously initialized
-  // it will do so and create the tracing service on a dedicated thread.
-  kInProcessBackend = 1 << 0,
-
-  // Connects to the system tracing service (e.g. on Linux/Android/Mac uses a
-  // named UNIX socket).
-  kSystemBackend = 1 << 1,
-
-  // Used to provide a custom IPC transport to connect to the service.
-  // TracingInitArgs::custom_backend must be non-null and point to an
-  // indefinitely lived instance.
-  kCustomBackend = 1 << 2,
-};
-
 struct TracingError {
   enum ErrorCode : uint32_t {
     // Peer disconnection.
@@ -80,7 +64,7 @@
 };
 
 struct TracingInitArgs {
-  uint32_t backends = 0;                     // One or more BackendFlags.
+  uint32_t backends = 0;                     // One or more BackendTypes.
   TracingBackend* custom_backend = nullptr;  // [Optional].
 
   // [Optional] Platform implementation. It allows the embedder to take control
@@ -117,6 +101,12 @@
   // delay, i.e. commits will be sent to the service at the next opportunity.
   uint32_t shmem_batch_commits_duration_ms = 0;
 
+  // [Optional] If set, the policy object is notified when certain SDK events
+  // occur and may apply policy decisions, such as denying connections. The
+  // embedder is responsible for ensuring the object remains alive for the
+  // lifetime of the process.
+  TracingPolicy* tracing_policy = nullptr;
+
  protected:
   friend class Tracing;
   friend class internal::TracingMuxerImpl;
diff --git a/include/perfetto/tracing/tracing_policy.h b/include/perfetto/tracing/tracing_policy.h
new file mode 100644
index 0000000..39ca6f7
--- /dev/null
+++ b/include/perfetto/tracing/tracing_policy.h
@@ -0,0 +1,49 @@
+/*
+ * Copyright (C) 2021 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef INCLUDE_PERFETTO_TRACING_TRACING_POLICY_H_
+#define INCLUDE_PERFETTO_TRACING_TRACING_POLICY_H_
+
+#include <functional>
+
+#include "perfetto/tracing/backend_type.h"
+
+namespace perfetto {
+
+// Applies policy decisions, such as allowing or denying connections, when
+// certain tracing SDK events occur. All methods are called on an internal
+// perfetto thread.
+class TracingPolicy {
+ public:
+  virtual ~TracingPolicy();
+
+  // Called when the current process attempts to connect a new consumer to the
+  // backend of |backend_type| to check if the connection should be allowed. Its
+  // implementation should execute |result_callback| with the result of the
+  // check (synchronuosly or asynchronously on any thread). If the result is
+  // false, the consumer connection is aborted. Chrome uses this to restrict
+  // creating (system) tracing sessions based on an enterprise policy.
+  struct ShouldAllowConsumerSessionArgs {
+    BackendType backend_type;
+    std::function<void(bool /*allow*/)> result_callback;
+  };
+  virtual void ShouldAllowConsumerSession(
+      const ShouldAllowConsumerSessionArgs&) = 0;
+};
+
+}  // namespace perfetto
+
+#endif  // INCLUDE_PERFETTO_TRACING_TRACING_POLICY_H_
diff --git a/src/tracing/BUILD.gn b/src/tracing/BUILD.gn
index 43ce774..2615660 100644
--- a/src/tracing/BUILD.gn
+++ b/src/tracing/BUILD.gn
@@ -115,6 +115,7 @@
     "platform.cc",
     "traced_value.cc",
     "tracing.cc",
+    "tracing_policy.cc",
     "track.cc",
     "track_event_category_registry.cc",
     "track_event_legacy.cc",
diff --git a/src/tracing/internal/tracing_muxer_impl.cc b/src/tracing/internal/tracing_muxer_impl.cc
index 6c4edb3..280a990 100644
--- a/src/tracing/internal/tracing_muxer_impl.cc
+++ b/src/tracing/internal/tracing_muxer_impl.cc
@@ -653,6 +653,8 @@
 void TracingMuxerImpl::Initialize(const TracingInitArgs& args) {
   PERFETTO_DCHECK_THREAD(thread_checker_);  // Rebind the thread checker.
 
+  policy_ = args.tracing_policy;
+
   auto add_backend = [this, &args](TracingBackend* backend, BackendType type) {
     if (!backend) {
       // We skip the log in release builds because the *_backend_fake.cc code
@@ -1361,7 +1363,6 @@
   for (RegisteredBackend& backend : backends_) {
     for (auto& consumer : backend.consumers) {
       if (consumer->session_id_ == session_id) {
-        PERFETTO_DCHECK(consumer->service_);
         return consumer.get();
       }
     }
@@ -1369,6 +1370,24 @@
   return nullptr;
 }
 
+void TracingMuxerImpl::InitializeConsumer(TracingSessionGlobalID session_id) {
+  PERFETTO_DCHECK_THREAD(thread_checker_);
+
+  auto* consumer = FindConsumer(session_id);
+  if (!consumer)
+    return;
+
+  TracingBackendId backend_id = consumer->backend_id_;
+  // |backends_| is append-only, Backend instances are always valid.
+  PERFETTO_CHECK(backend_id < backends_.size());
+  RegisteredBackend& backend = backends_[backend_id];
+
+  TracingBackend::ConnectConsumerArgs conn_args;
+  conn_args.consumer = consumer;
+  conn_args.task_runner = task_runner_.get();
+  consumer->Initialize(backend.backend->ConnectConsumer(conn_args));
+}
+
 void TracingMuxerImpl::OnConsumerDisconnected(ConsumerImpl* consumer) {
   PERFETTO_DCHECK_THREAD(thread_checker_);
   for (RegisteredBackend& backend : backends_) {
@@ -1483,21 +1502,53 @@
         continue;
       }
 
+      TracingBackendId backend_id = backend.id;
+
+      // Create the consumer now, even if we have to ask the embedder below, so
+      // that any other tasks executing after this one can find the consumer and
+      // change its pending attributes.
+      backend.consumers.emplace_back(
+          new ConsumerImpl(this, backend.type, backend.id, session_id));
+
       // The last registered backend in |backends_| is the unsupported backend
       // without a valid type.
       if (!backend.type) {
         PERFETTO_ELOG(
             "No tracing backend ready for type=%d, consumer will disconnect",
             requested_backend_type);
+        InitializeConsumer(session_id);
+        return;
       }
 
-      backend.consumers.emplace_back(
-          new ConsumerImpl(this, backend.type, backend.id, session_id));
-      auto& consumer = backend.consumers.back();
-      TracingBackend::ConnectConsumerArgs conn_args;
-      conn_args.consumer = consumer.get();
-      conn_args.task_runner = task_runner_.get();
-      consumer->Initialize(backend.backend->ConnectConsumer(conn_args));
+      // Check if the embedder wants to be asked for permission before
+      // connecting the consumer.
+      if (!policy_) {
+        InitializeConsumer(session_id);
+        return;
+      }
+
+      TracingPolicy::ShouldAllowConsumerSessionArgs args;
+      args.backend_type = backend.type;
+      args.result_callback = [this, backend_id, session_id](bool allow) {
+        task_runner_->PostTask([this, backend_id, session_id, allow] {
+          if (allow) {
+            InitializeConsumer(session_id);
+            return;
+          }
+
+          PERFETTO_ELOG(
+              "Consumer session for backend type type=%d forbidden, "
+              "consumer will disconnect",
+              backends_[backend_id].type);
+
+          auto* consumer = FindConsumer(session_id);
+          if (!consumer)
+            return;
+
+          consumer->OnDisconnect();
+        });
+      };
+      policy_->ShouldAllowConsumerSession(args);
       return;
     }
     PERFETTO_DFATAL("Not reached");
diff --git a/src/tracing/internal/tracing_muxer_impl.h b/src/tracing/internal/tracing_muxer_impl.h
index 0c0f551..d0898a7 100644
--- a/src/tracing/internal/tracing_muxer_impl.h
+++ b/src/tracing/internal/tracing_muxer_impl.h
@@ -380,6 +380,7 @@
   explicit TracingMuxerImpl(const TracingInitArgs&);
   void Initialize(const TracingInitArgs& args);
   ConsumerImpl* FindConsumer(TracingSessionGlobalID session_id);
+  void InitializeConsumer(TracingSessionGlobalID session_id);
   void OnConsumerDisconnected(ConsumerImpl* consumer);
   void OnProducerDisconnected(ProducerImpl* producer);
 
@@ -399,6 +400,7 @@
   std::vector<RegisteredDataSource> data_sources_;
   std::vector<RegisteredBackend> backends_;
   std::vector<RegisteredInterceptor> interceptors_;
+  TracingPolicy* policy_ = nullptr;
 
   std::atomic<TracingSessionGlobalID> next_tracing_session_id_{};
 
diff --git a/src/tracing/test/api_integrationtest.cc b/src/tracing/test/api_integrationtest.cc
index 62115e8..cd65a55 100644
--- a/src/tracing/test/api_integrationtest.cc
+++ b/src/tracing/test/api_integrationtest.cc
@@ -174,6 +174,7 @@
 
 namespace {
 
+using perfetto::TracingInitArgs;
 using ::testing::_;
 using ::testing::ContainerEq;
 using ::testing::ElementsAre;
@@ -343,6 +344,19 @@
   }
 };
 
+class TestTracingPolicy : public perfetto::TracingPolicy {
+ public:
+  void ShouldAllowConsumerSession(
+      const ShouldAllowConsumerSessionArgs& args) override {
+    EXPECT_NE(args.backend_type, perfetto::BackendType::kUnspecifiedBackend);
+    args.result_callback(should_allow_consumer_connection);
+  }
+
+  bool should_allow_consumer_connection = true;
+};
+
+TestTracingPolicy* g_test_tracing_policy = new TestTracingPolicy();  // Leaked.
+
 // -------------------------
 // Declaration of test class
 // -------------------------
@@ -352,6 +366,7 @@
 
   void SetUp() override {
     instance = this;
+    g_test_tracing_policy->should_allow_consumer_connection = true;
 
     // Start a fresh system service for this test, tearing down any previous
     // service that was running.
@@ -376,8 +391,9 @@
     // Since the client API can only be initialized once per process, initialize
     // both the in-process and system backends for every test here. The actual
     // service to be used is chosen by the test parameter.
-    perfetto::TracingInitArgs args;
+    TracingInitArgs args;
     args.backends = supported_backends;
+    args.tracing_policy = g_test_tracing_policy;
     perfetto::Tracing::Initialize(args);
     RegisterDataSource<MockDataSource>("my_data_source");
     perfetto::TrackEvent::Register();
@@ -724,7 +740,10 @@
   tracing_session->get()->StopBlocking();
 }
 
-TEST_P(PerfettoApiTest, TrackEventStartStopAndDestroy) {
+// Disabled by default because it leaks tracing sessions into subsequent tests,
+// which can result in the per-uid tracing session limit (5) to be hit in later
+// tests.
+TEST_P(PerfettoApiTest, DISABLED_TrackEventStartStopAndDestroy) {
   // This test used to cause a use after free as the tracing session got
   // destroyed. It needed to be run approximately 2000 times to catch it so test
   // with --gtest_repeat=3000 (less if running under GDB).
@@ -2865,6 +2884,31 @@
   tracing_session->get()->StopBlocking();
 }
 
+TEST_P(PerfettoApiTest, ForbiddenConsumer) {
+  g_test_tracing_policy->should_allow_consumer_connection = false;
+
+  // Create a new trace session while consumer connections are forbidden.
+  perfetto::TraceConfig cfg;
+  cfg.add_buffers()->set_size_kb(1024);
+  auto* tracing_session = NewTrace(cfg);
+
+  // Creating the consumer should cause an asynchronous disconnect error.
+  WaitableTestEvent got_error;
+  tracing_session->get()->SetOnErrorCallback([&](perfetto::TracingError error) {
+    EXPECT_EQ(perfetto::TracingError::kDisconnected, error.code);
+    EXPECT_FALSE(error.message.empty());
+    got_error.Notify();
+  });
+  got_error.Wait();
+
+  // Clear the callback for test tear down.
+  tracing_session->get()->SetOnErrorCallback(nullptr);
+  // Synchronize the consumer channel to ensure the callback has propagated.
+  tracing_session->get()->StopBlocking();
+
+  g_test_tracing_policy->should_allow_consumer_connection = true;
+}
+
 TEST_P(PerfettoApiTest, GetTraceStats) {
   perfetto::TraceConfig cfg;
   cfg.set_duration_ms(500);
diff --git a/src/tracing/tracing_policy.cc b/src/tracing/tracing_policy.cc
new file mode 100644
index 0000000..15df1a0
--- /dev/null
+++ b/src/tracing/tracing_policy.cc
@@ -0,0 +1,23 @@
+/*
+ * Copyright (C) 2021 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "perfetto/tracing/tracing_policy.h"
+
+namespace perfetto {
+
+TracingPolicy::~TracingPolicy() = default;
+
+}  // namespace perfetto