TurnCustomizer - an interface for modifying STUN messages sent by TurnPort

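This patch adds an interface that allows modification of STUN messages
sent by TurnPort. A user can inject a TurnCustomizer on the RTCConfig
and the TurnCustomizer will be invoked by TurnPort before sending a
message. This allows the user to, e.g., add custom attributes as
described in RFC 5389. The customizer can also force TurnPort to use a
Send indication instead of ChannelData, and STUN attributes in the
0xC000-0xFFFF range are now parsed as byte strings instead of rejected.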

BUG=webrtc:8313

Change-Id: I6f4333e9f8ff7fd20f32677be19285f15e1180b6
Reviewed-on: https://webrtc-review.googlesource.com/7618
Reviewed-by: Guido Urdaneta <guidou@webrtc.org>
Reviewed-by: Taylor Brandstetter <deadbeef@webrtc.org>
Reviewed-by: Sami Kalliomäki <sakal@webrtc.org>
Commit-Queue: Jonas Oreland <jonaso@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#20233}
diff --git a/p2p/BUILD.gn b/p2p/BUILD.gn
index ec007b0..9f655a7 100644
--- a/p2p/BUILD.gn
+++ b/p2p/BUILD.gn
@@ -148,6 +148,7 @@
       "base/mockicetransport.h",
       "base/testrelayserver.h",
       "base/teststunserver.h",
+      "base/testturncustomizer.h",
       "base/testturnserver.h",
     ]
     deps = [
diff --git a/p2p/base/port_unittest.cc b/p2p/base/port_unittest.cc
index 36af5de..af853ca 100644
--- a/p2p/base/port_unittest.cc
+++ b/p2p/base/port_unittest.cc
@@ -536,7 +536,8 @@
     return TurnPort::Create(
         &main_, socket_factory, MakeNetwork(addr), 0, 0, username_, password_,
         ProtocolAddress(server_addr, int_proto), kRelayCredentials, 0,
-        std::string(), std::vector<std::string>(), std::vector<std::string>());
+        std::string(), std::vector<std::string>(), std::vector<std::string>(),
+        nullptr);
   }
   RelayPort* CreateGturnPort(const SocketAddress& addr,
                              ProtocolType int_proto, ProtocolType ext_proto) {
diff --git a/p2p/base/portallocator.cc b/p2p/base/portallocator.cc
index 97dedc5..109472b 100644
--- a/p2p/base/portallocator.cc
+++ b/p2p/base/portallocator.cc
@@ -33,7 +33,8 @@
     const ServerAddresses& stun_servers,
     const std::vector<RelayServerConfig>& turn_servers,
     int candidate_pool_size,
-    bool prune_turn_ports) {
+    bool prune_turn_ports,
+    webrtc::TurnCustomizer* turn_customizer) {
   bool ice_servers_changed =
       (stun_servers != stun_servers_ || turn_servers != turn_servers_);
   stun_servers_ = stun_servers;
@@ -62,6 +63,8 @@
     pooled_sessions_.clear();
   }
 
+  turn_customizer_ = turn_customizer;
+
   // If |candidate_pool_size_| is less than the number of pooled sessions, get
   // rid of the extras.
   while (candidate_pool_size_ < static_cast<int>(pooled_sessions_.size())) {
diff --git a/p2p/base/portallocator.h b/p2p/base/portallocator.h
index de8d2d9..5c0d303 100644
--- a/p2p/base/portallocator.h
+++ b/p2p/base/portallocator.h
@@ -25,6 +25,7 @@
 
 namespace webrtc {
 class MetricsObserverInterface;
+class TurnCustomizer;
 }
 
 namespace cricket {
@@ -362,7 +363,8 @@
   bool SetConfiguration(const ServerAddresses& stun_servers,
                         const std::vector<RelayServerConfig>& turn_servers,
                         int candidate_pool_size,
-                        bool prune_turn_ports);
+                        bool prune_turn_ports,
+                        webrtc::TurnCustomizer* turn_customizer = nullptr);
 
   const ServerAddresses& stun_servers() const { return stun_servers_; }
 
@@ -477,6 +479,10 @@
     metrics_observer_ = observer;
   }
 
+  webrtc::TurnCustomizer* turn_customizer() const {
+    return turn_customizer_;
+  }
+
  protected:
   virtual PortAllocatorSession* CreateSessionInternal(
       const std::string& content_name,
@@ -512,6 +518,11 @@
   bool prune_turn_ports_ = false;
 
   webrtc::MetricsObserverInterface* metrics_observer_ = nullptr;
+
+  // Customizer for TURN messages.
+  // The instance is owned by application and will be shared among
+  // all TurnPort(s) created.
+  webrtc::TurnCustomizer* turn_customizer_ = nullptr;
 };
 
 }  // namespace cricket
diff --git a/p2p/base/stun.cc b/p2p/base/stun.cc
index 9ac2554..e7d1244 100644
--- a/p2p/base/stun.cc
+++ b/p2p/base/stun.cc
@@ -67,9 +67,18 @@
   return true;
 }
 
+// True for attribute types 0xC000-0xFFFF (all comprehension-optional).
+static bool ImplementationDefinedRange(int attr_type) {
+  return attr_type >= 0xC000 && attr_type <= 0xFFFF;
+}
+
 void StunMessage::AddAttribute(std::unique_ptr<StunAttribute> attr) {
-  // Fail any attributes that aren't valid for this type of message.
-  RTC_DCHECK_EQ(attr->value_type(), GetAttributeValueType(attr->type()));
+  // Fail any attributes that aren't valid for this type of message,
+  // but allow any type for the range that is "implementation defined"
+  // in the RFC.
+  if (!ImplementationDefinedRange(attr->type())) {
+    RTC_DCHECK_EQ(attr->value_type(), GetAttributeValueType(attr->type()));
+  }
 
   attr->SetOwner(this);
   size_t attr_length = attr->length();
@@ -398,8 +407,16 @@
 
 StunAttribute* StunMessage::CreateAttribute(int type, size_t length) /*const*/ {
   StunAttributeValueType value_type = GetAttributeValueType(type);
-  return StunAttribute::Create(value_type, type, static_cast<uint16_t>(length),
-                               this);
+  if (value_type != STUN_VALUE_UNKNOWN) {
+    return StunAttribute::Create(value_type, type,
+                                 static_cast<uint16_t>(length), this);
+  }
+  // Unknown types in the 0xC000-0xFFFF range are read as byte strings.
+  if (ImplementationDefinedRange(type)) {
+    return StunAttribute::Create(STUN_VALUE_BYTE_STRING, type,
+                                 static_cast<uint16_t>(length), this);
+  }
+  return nullptr;
 }
 
 const StunAttribute* StunMessage::GetAttribute(int type) const {
diff --git a/p2p/base/stunrequest.cc b/p2p/base/stunrequest.cc
index 8afd1a0..cec9ce3 100644
--- a/p2p/base/stunrequest.cc
+++ b/p2p/base/stunrequest.cc
@@ -207,6 +207,10 @@
   return msg_;
 }
 
+StunMessage* StunRequest::mutable_msg() {
+  return msg_;  // Same object msg() returns, but writable by the caller.
+}
+
 int StunRequest::Elapsed() const {
   return static_cast<int>(rtc::TimeMillis() - tstamp_);
 }
diff --git a/p2p/base/stunrequest.h b/p2p/base/stunrequest.h
index 00f317f..e77242e 100644
--- a/p2p/base/stunrequest.h
+++ b/p2p/base/stunrequest.h
@@ -105,6 +105,9 @@
   // Returns a const pointer to |msg_|.
   const StunMessage* msg() const;
 
+  // Returns a mutable pointer to |msg_|.
+  StunMessage* mutable_msg();
+
   // Time elapsed since last send (in ms)
   int Elapsed() const;
 
diff --git a/p2p/base/testturncustomizer.h b/p2p/base/testturncustomizer.h
new file mode 100644
index 0000000..33f48fb
--- /dev/null
+++ b/p2p/base/testturncustomizer.h
@@ -0,0 +1,57 @@
+/*
+ *  Copyright 2017 The WebRTC Project Authors. All rights reserved.
+ *
+ *  Use of this source code is governed by a BSD-style license
+ *  that can be found in the LICENSE file in the root of the source
+ *  tree. An additional intellectual property rights grant can be found
+ *  in the file PATENTS.  All contributing project authors may
+ *  be found in the AUTHORS file in the root of the source tree.
+ */
+
+#ifndef P2P_BASE_TESTTURNCUSTOMIZER_H_
+#define P2P_BASE_TESTTURNCUSTOMIZER_H_
+
+#include "api/turncustomizer.h"
+#include "p2p/base/stun.h"
+#include "rtc_base/ptr_util.h"
+
+namespace cricket {
+
+// Test TurnCustomizer that counts how often it is invoked. It can also tag
+// outgoing STUN messages and force Send indications instead of ChannelData.
+class TestTurnCustomizer : public webrtc::TurnCustomizer {
+ public:
+  TestTurnCustomizer() {}
+  virtual ~TestTurnCustomizer() {}
+
+  enum TestTurnAttributeExtensions {
+    // Test-only attribute in the 0xC000-0xFFFF range, sent as a uint32.
+    STUN_ATTR_COUNTER = 0xFF02
+  };
+
+  void MaybeModifyOutgoingStunMessage(cricket::PortInterface* port,
+                                      cricket::StunMessage* message) override {
+    modify_cnt_++;
+    if (add_counter_) {
+      message->AddAttribute(rtc::MakeUnique<cricket::StunUInt32Attribute>(
+          STUN_ATTR_COUNTER, modify_cnt_));
+    }
+  }
+
+  bool AllowChannelData(cricket::PortInterface* port,
+                        const void* data,
+                        size_t size,
+                        bool payload) override {
+    allow_channel_data_cnt_++;
+    return allow_channel_data_;
+  }
+
+  bool add_counter_ = false;
+  bool allow_channel_data_ = true;
+  unsigned int modify_cnt_ = 0;
+  unsigned int allow_channel_data_cnt_ = 0;
+};
+
+}  // namespace cricket
+
+#endif  // P2P_BASE_TESTTURNCUSTOMIZER_H_
diff --git a/p2p/base/turnport.cc b/p2p/base/turnport.cc
index bf1351b..a0b98bb 100644
--- a/p2p/base/turnport.cc
+++ b/p2p/base/turnport.cc
@@ -196,7 +196,8 @@
                    const ProtocolAddress& server_address,
                    const RelayCredentials& credentials,
                    int server_priority,
-                   const std::string& origin)
+                   const std::string& origin,
+                   webrtc::TurnCustomizer* customizer)
     : Port(thread,
            RELAY_PORT_TYPE,
            factory,
@@ -212,7 +213,8 @@
       next_channel_number_(TURN_CHANNEL_NUMBER_START),
       state_(STATE_CONNECTING),
       server_priority_(server_priority),
-      allocate_mismatch_retries_(0) {
+      allocate_mismatch_retries_(0),
+      turn_customizer_(customizer) {
   request_manager_.SignalSendPacket.connect(this, &TurnPort::OnSendStunPacket);
   request_manager_.set_origin(origin);
 }
@@ -229,7 +231,8 @@
                    int server_priority,
                    const std::string& origin,
                    const std::vector<std::string>& tls_alpn_protocols,
-                   const std::vector<std::string>& tls_elliptic_curves)
+                   const std::vector<std::string>& tls_elliptic_curves,
+                   webrtc::TurnCustomizer* customizer)
     : Port(thread,
            RELAY_PORT_TYPE,
            factory,
@@ -249,7 +252,8 @@
       next_channel_number_(TURN_CHANNEL_NUMBER_START),
       state_(STATE_CONNECTING),
       server_priority_(server_priority),
-      allocate_mismatch_retries_(0) {
+      allocate_mismatch_retries_(0),
+      turn_customizer_(customizer) {
   request_manager_.SignalSendPacket.connect(this, &TurnPort::OnSendStunPacket);
   request_manager_.set_origin(origin);
 }
@@ -952,6 +956,7 @@
 }
 
 void TurnPort::SendRequest(StunRequest* req, int delay) {
+  TurnCustomizerMaybeModifyOutgoingStunMessage(req->mutable_msg());
   request_manager_.SendDelayed(req, delay);
 }
 
@@ -1142,6 +1147,24 @@
   return url.str();
 }
 
+void TurnPort::TurnCustomizerMaybeModifyOutgoingStunMessage(
+    StunMessage* message) {
+  // |turn_customizer_| is optional (may be null). When present it may
+  // rewrite |message|, e.g. by adding attributes, before it goes out.
+  if (turn_customizer_ != nullptr) {
+    turn_customizer_->MaybeModifyOutgoingStunMessage(this, message);
+  }
+}
+
+bool TurnPort::TurnCustomizerAllowChannelData(const void* data,
+                                              size_t size,
+                                              bool payload) {
+  // ChannelData is allowed unless an installed customizer vetoes it, in
+  // which case TurnEntry::Send falls back to a Send indication.
+  return turn_customizer_ == nullptr ||
+         turn_customizer_->AllowChannelData(this, data, size, payload);
+}
+
 TurnAllocateRequest::TurnAllocateRequest(TurnPort* port)
     : StunRequest(new TurnMessage()),
       port_(port) {
@@ -1533,8 +1556,10 @@
 int TurnEntry::Send(const void* data, size_t size, bool payload,
                     const rtc::PacketOptions& options) {
   rtc::ByteBufferWriter buf;
-  if (state_ != STATE_BOUND) {
+  if (state_ != STATE_BOUND ||
+      !port_->TurnCustomizerAllowChannelData(data, size, payload)) {
     // If we haven't bound the channel yet, we have to use a Send Indication.
+    // The turn_customizer_ can also make us use Send Indication.
     TurnMessage msg;
     msg.SetType(TURN_SEND_INDICATION);
     msg.SetTransactionID(
@@ -1543,6 +1568,9 @@
         STUN_ATTR_XOR_PEER_ADDRESS, ext_addr_));
     msg.AddAttribute(
         rtc::MakeUnique<StunByteStringAttribute>(STUN_ATTR_DATA, data, size));
+
+    port_->TurnCustomizerMaybeModifyOutgoingStunMessage(&msg);
+
     const bool success = msg.Write(&buf);
     RTC_DCHECK(success);
 
diff --git a/p2p/base/turnport.h b/p2p/base/turnport.h
index f891983..1d1ffe2 100644
--- a/p2p/base/turnport.h
+++ b/p2p/base/turnport.h
@@ -26,6 +26,10 @@
 class SignalThread;
 }
 
+namespace webrtc {
+class TurnCustomizer;
+}
+
 namespace cricket {
 
 extern const char TURN_PORT_TYPE[];
@@ -52,9 +56,11 @@
                           const ProtocolAddress& server_address,
                           const RelayCredentials& credentials,
                           int server_priority,
-                          const std::string& origin) {
+                          const std::string& origin,
+                          webrtc::TurnCustomizer* customizer) {
     return new TurnPort(thread, factory, network, socket, username, password,
-                        server_address, credentials, server_priority, origin);
+                        server_address, credentials, server_priority, origin,
+                        customizer);
   }
 
   // Create a TURN port that will use a new socket, bound to |network| and
@@ -71,10 +77,12 @@
                           int server_priority,
                           const std::string& origin,
                           const std::vector<std::string>& tls_alpn_protocols,
-                          const std::vector<std::string>& tls_elliptic_curves) {
+                          const std::vector<std::string>& tls_elliptic_curves,
+                          webrtc::TurnCustomizer* customizer) {
     return new TurnPort(thread, factory, network, min_port, max_port, username,
                         password, server_address, credentials, server_priority,
-                        origin, tls_alpn_protocols, tls_elliptic_curves);
+                        origin, tls_alpn_protocols, tls_elliptic_curves,
+                        customizer);
   }
 
   virtual ~TurnPort();
@@ -184,7 +192,8 @@
            const ProtocolAddress& server_address,
            const RelayCredentials& credentials,
            int server_priority,
-           const std::string& origin);
+           const std::string& origin,
+           webrtc::TurnCustomizer* customizer);
 
   TurnPort(rtc::Thread* thread,
            rtc::PacketSocketFactory* factory,
@@ -198,7 +207,8 @@
            int server_priority,
            const std::string& origin,
            const std::vector<std::string>& tls_alpn_protocols,
-           const std::vector<std::string>& tls_elliptic_curves);
+           const std::vector<std::string>& tls_elliptic_curves,
+           webrtc::TurnCustomizer* customizer);
 
  private:
   enum {
@@ -275,6 +285,10 @@
   // Reconstruct the URL of the server which the candidate is gathered from.
   std::string ReconstructedServerUrl();
 
+  void TurnCustomizerMaybeModifyOutgoingStunMessage(StunMessage* message);
+  bool TurnCustomizerAllowChannelData(const void* data,
+                                      size_t size, bool payload);
+
   ProtocolAddress server_address_;
   TlsCertPolicy tls_cert_policy_ = TlsCertPolicy::TLS_CERT_POLICY_SECURE;
   std::vector<std::string> tls_alpn_protocols_;
@@ -305,6 +319,9 @@
 
   rtc::AsyncInvoker invoker_;
 
+  // Optional TurnCustomizer that can modify outgoing messages.
+  webrtc::TurnCustomizer *turn_customizer_ = nullptr;
+
   friend class TurnEntry;
   friend class TurnAllocateRequest;
   friend class TurnRefreshRequest;
diff --git a/p2p/base/turnport_unittest.cc b/p2p/base/turnport_unittest.cc
index 0bc4b31..4abc025 100644
--- a/p2p/base/turnport_unittest.cc
+++ b/p2p/base/turnport_unittest.cc
@@ -18,6 +18,7 @@
 #include "p2p/base/p2pconstants.h"
 #include "p2p/base/portallocator.h"
 #include "p2p/base/tcpport.h"
+#include "p2p/base/testturncustomizer.h"
 #include "p2p/base/testturnserver.h"
 #include "p2p/base/turnport.h"
 #include "p2p/base/udpport.h"
@@ -30,6 +31,7 @@
 #include "rtc_base/gunit.h"
 #include "rtc_base/helpers.h"
 #include "rtc_base/logging.h"
+#include "rtc_base/ptr_util.h"
 #include "rtc_base/socketadapters.h"
 #include "rtc_base/socketaddress.h"
 #include "rtc_base/ssladapter.h"
@@ -264,7 +266,7 @@
     turn_port_.reset(TurnPort::Create(
         &main_, &socket_factory_, network, 0, 0, kIceUfrag1, kIcePwd1,
         server_address, credentials, 0, origin, std::vector<std::string>(),
-        std::vector<std::string>()));
+        std::vector<std::string>(), turn_customizer_.get()));
     // This TURN port will be the controlling.
     turn_port_->SetIceRole(ICEROLE_CONTROLLING);
     ConnectSignals();
@@ -294,7 +296,8 @@
     RelayCredentials credentials(username, password);
     turn_port_.reset(TurnPort::Create(
         &main_, &socket_factory_, MakeNetwork(kLocalAddr1), socket_.get(),
-        kIceUfrag1, kIcePwd1, server_address, credentials, 0, std::string()));
+        kIceUfrag1, kIcePwd1, server_address, credentials, 0, std::string(),
+        nullptr));
     // This TURN port will be the controlling.
     turn_port_->SetIceRole(ICEROLE_CONTROLLING);
     ConnectSignals();
@@ -695,6 +698,7 @@
   std::vector<rtc::Buffer> turn_packets_;
   std::vector<rtc::Buffer> udp_packets_;
   rtc::PacketOptions options;
+  std::unique_ptr<webrtc::TurnCustomizer> turn_customizer_;
 };
 
 TEST_F(TurnPortTest, TestTurnPortType) {
@@ -1458,4 +1462,146 @@
 }
 #endif
 
+class MessageObserver : public StunMessageObserver {
+ public:
+  MessageObserver(unsigned int* message_counter,
+                  unsigned int* channel_data_counter,
+                  unsigned int* attr_counter)
+      : message_counter_(message_counter),
+        channel_data_counter_(channel_data_counter),
+        attr_counter_(attr_counter) {}
+  ~MessageObserver() override {}
+  void ReceivedMessage(const TurnMessage* msg) override {
+    if (message_counter_ != nullptr) {
+      (*message_counter_)++;
+    }
+    // Implementation defined attributes are returned as ByteString.
+    const StunByteStringAttribute* attr = msg->GetByteString(
+        TestTurnCustomizer::STUN_ATTR_COUNTER);
+    if (attr != nullptr && attr_counter_ != nullptr) {
+      // Only count the attribute if it decodes as the expected uint32.
+      rtc::ByteBufferReader buf(attr->bytes(), attr->length());
+      uint32_t val = 0;
+      if (buf.ReadUInt32(&val)) {
+        (*attr_counter_)++;
+      }
+    }
+  }
+
+  void ReceivedChannelData(const char* data, size_t size) override {
+    if (channel_data_counter_ != nullptr) {
+      (*channel_data_counter_)++;
+    }
+  }
+
+  // Number of TurnMessages observed.
+  unsigned int* message_counter_ = nullptr;
+  // Number of channel data observed.
+  unsigned int* channel_data_counter_ = nullptr;
+  // Number of TurnMessages that had a well-formed STUN_ATTR_COUNTER.
+  unsigned int* attr_counter_ = nullptr;
+};
+
+// Do a TURN allocation, establish a TLS connection, and send some data.
+// Add customizer and check that it gets called.
+TEST_F(TurnPortTest, TestTurnCustomizerCount) {
+  unsigned int observer_message_counter = 0;
+  unsigned int observer_channel_data_counter = 0;
+  unsigned int observer_attr_counter = 0;
+  TestTurnCustomizer* customizer = new TestTurnCustomizer();
+  auto validator = rtc::MakeUnique<MessageObserver>(
+      &observer_message_counter,
+      &observer_channel_data_counter,
+      &observer_attr_counter);
+
+  turn_server_.AddInternalSocket(kTurnTcpIntAddr, PROTO_TLS);
+  turn_customizer_.reset(customizer);
+  turn_server_.server()->SetStunMessageObserver(std::move(validator));
+
+  CreateTurnPort(kTurnUsername, kTurnPassword, kTurnTlsProtoAddr);
+  TestTurnSendData(PROTO_TLS);
+  EXPECT_EQ(TLS_PROTOCOL_NAME, turn_port_->Candidates()[0].relay_protocol());
+
+  // There should have been at least turn_packets_.size() calls to |customizer|.
+  EXPECT_GE(customizer->modify_cnt_ + customizer->allow_channel_data_cnt_,
+            turn_packets_.size());
+
+  // Once the channel is bound, data should arrive as ChannelData.
+  EXPECT_GT(observer_channel_data_counter, 0u);
+
+  // Need to release TURN port before the customizer.
+  turn_port_.reset(nullptr);
+}
+
+// Do a TURN allocation, establish a TLS connection, and send some data.
+// Add customizer and check that it can prevent usage of channel data.
+TEST_F(TurnPortTest, TestTurnCustomizerDisallowChannelData) {
+  unsigned int observer_message_counter = 0;
+  unsigned int observer_channel_data_counter = 0;
+  unsigned int observer_attr_counter = 0;
+  TestTurnCustomizer* customizer = new TestTurnCustomizer();
+  auto validator = rtc::MakeUnique<MessageObserver>(
+      &observer_message_counter,
+      &observer_channel_data_counter,
+      &observer_attr_counter);
+  customizer->allow_channel_data_ = false;
+  turn_server_.AddInternalSocket(kTurnTcpIntAddr, PROTO_TLS);
+  turn_customizer_.reset(customizer);
+  turn_server_.server()->SetStunMessageObserver(std::move(validator));
+
+  CreateTurnPort(kTurnUsername, kTurnPassword, kTurnTlsProtoAddr);
+  TestTurnSendData(PROTO_TLS);
+  EXPECT_EQ(TLS_PROTOCOL_NAME, turn_port_->Candidates()[0].relay_protocol());
+
+  // There should have been at least turn_packets_.size() calls to |customizer|.
+  EXPECT_GE(customizer->modify_cnt_, turn_packets_.size());
+
+  // No channel data should be received.
+  EXPECT_EQ(observer_channel_data_counter, 0u);
+
+  // Need to release TURN port before the customizer.
+  turn_port_.reset(nullptr);
+}
+
+// Do a TURN allocation, establish a TLS connection, and send some data.
+// Add customizer and check that it can add attribute to messages.
+TEST_F(TurnPortTest, TestTurnCustomizerAddAttribute) {
+  unsigned int observer_message_counter = 0;
+  unsigned int observer_channel_data_counter = 0;
+  unsigned int observer_attr_counter = 0;
+  TestTurnCustomizer* customizer = new TestTurnCustomizer();
+  auto validator = rtc::MakeUnique<MessageObserver>(
+      &observer_message_counter,
+      &observer_channel_data_counter,
+      &observer_attr_counter);
+  customizer->allow_channel_data_ = false;
+  customizer->add_counter_ = true;
+  turn_server_.AddInternalSocket(kTurnTcpIntAddr, PROTO_TLS);
+  turn_customizer_.reset(customizer);
+  turn_server_.server()->SetStunMessageObserver(std::move(validator));
+
+  CreateTurnPort(kTurnUsername, kTurnPassword, kTurnTlsProtoAddr);
+  TestTurnSendData(PROTO_TLS);
+  EXPECT_EQ(TLS_PROTOCOL_NAME, turn_port_->Candidates()[0].relay_protocol());
+
+  // There should have been at least turn_packets_.size() calls to |customizer|.
+  EXPECT_GE(customizer->modify_cnt_, turn_packets_.size());
+
+  // Everything will be sent as messages since channel data is disallowed.
+  EXPECT_GE(customizer->modify_cnt_, observer_message_counter);
+
+  // All messages should have attribute.
+  EXPECT_EQ(observer_message_counter, observer_attr_counter);
+
+  // The customizer was consulted, and each refusal became a Send indication.
+  EXPECT_GE(customizer->modify_cnt_, customizer->allow_channel_data_cnt_);
+  EXPECT_GT(customizer->allow_channel_data_cnt_, 0u);
+
+  // No channel data should be received.
+  EXPECT_EQ(observer_channel_data_counter, 0u);
+
+  // Need to release TURN port before the customizer.
+  turn_port_.reset(nullptr);
+}
+
 }  // namespace cricket
diff --git a/p2p/base/turnserver.cc b/p2p/base/turnserver.cc
index 7e4c436..0f7d815 100644
--- a/p2p/base/turnserver.cc
+++ b/p2p/base/turnserver.cc
@@ -211,6 +211,9 @@
     if (allocation) {
       allocation->HandleChannelData(data, size);
     }
+    if (stun_message_observer_ != nullptr) {
+      stun_message_observer_->ReceivedChannelData(data, size);
+    }
   }
 }
 
@@ -223,6 +226,10 @@
     return;
   }
 
+  if (stun_message_observer_ != nullptr) {
+    stun_message_observer_->ReceivedMessage(&msg);
+  }
+
   // If it's a STUN binding request, handle that specially.
   if (msg.type() == STUN_BINDING_REQUEST) {
     HandleBindingRequest(conn, &msg);
diff --git a/p2p/base/turnserver.h b/p2p/base/turnserver.h
index 1e06a93..f694912 100644
--- a/p2p/base/turnserver.h
+++ b/p2p/base/turnserver.h
@@ -157,6 +157,13 @@
   virtual ~TurnRedirectInterface() {}
 };
 
+class StunMessageObserver {  // Test-only; see SetStunMessageObserver().
+ public:
+  virtual ~StunMessageObserver() {}
+  virtual void ReceivedMessage(const TurnMessage* msg) = 0;
+  virtual void ReceivedChannelData(const char* data, size_t size) = 0;
+};
+
 // The core TURN server class. Give it a socket to listen on via
 // AddInternalServerSocket, and a factory to create external sockets via
 // SetExternalSocketFactory, and it's ready to go.
@@ -214,6 +221,11 @@
     return GenerateNonce(timestamp);
   }
 
+  // Test-only: takes ownership of |observer|, replacing any previous one.
+  void SetStunMessageObserver(std::unique_ptr<StunMessageObserver> observer) {
+    stun_message_observer_ = std::move(observer);
+  }
+
  private:
   std::string GenerateNonce(int64_t now) const;
   void OnInternalPacket(rtc::AsyncPacketSocket* socket, const char* data,
@@ -296,6 +308,9 @@
   // from this value, and it will be reset to 0 after generating the NONCE.
   int64_t ts_for_next_nonce_ = 0;
 
+  // For testing only. Used to observe STUN messages received.
+  std::unique_ptr<StunMessageObserver> stun_message_observer_;
+
   friend class TurnServerAllocation;
 };
 
diff --git a/p2p/client/basicportallocator.cc b/p2p/client/basicportallocator.cc
index f4c1fbd..4717484 100644
--- a/p2p/client/basicportallocator.cc
+++ b/p2p/client/basicportallocator.cc
@@ -96,11 +96,15 @@
     PORTALLOCATOR_DISABLE_STUN | PORTALLOCATOR_DISABLE_RELAY;
 
 // BasicPortAllocator
-BasicPortAllocator::BasicPortAllocator(rtc::NetworkManager* network_manager,
-                                       rtc::PacketSocketFactory* socket_factory)
+BasicPortAllocator::BasicPortAllocator(
+    rtc::NetworkManager* network_manager,
+    rtc::PacketSocketFactory* socket_factory,
+    webrtc::TurnCustomizer* customizer)
     : network_manager_(network_manager), socket_factory_(socket_factory) {
   RTC_DCHECK(network_manager_ != nullptr);
   RTC_DCHECK(socket_factory_ != nullptr);
+  SetConfiguration(ServerAddresses(), std::vector<RelayServerConfig>(),
+                   0, false, customizer);
   Construct();
 }
 
@@ -115,7 +119,8 @@
                                        const ServerAddresses& stun_servers)
     : network_manager_(network_manager), socket_factory_(socket_factory) {
   RTC_DCHECK(socket_factory_ != NULL);
-  SetConfiguration(stun_servers, std::vector<RelayServerConfig>(), 0, false);
+  SetConfiguration(stun_servers, std::vector<RelayServerConfig>(), 0, false,
+                   nullptr);
   Construct();
 }
 
@@ -142,7 +147,7 @@
     turn_servers.push_back(config);
   }
 
-  SetConfiguration(stun_servers, turn_servers, 0, false);
+  SetConfiguration(stun_servers, turn_servers, 0, false, nullptr);
   Construct();
 }
 
@@ -188,7 +193,7 @@
   std::vector<RelayServerConfig> new_turn_servers = turn_servers();
   new_turn_servers.push_back(turn_server);
   SetConfiguration(stun_servers(), new_turn_servers, candidate_pool_size(),
-                   prune_turn_ports());
+                   prune_turn_ports(), turn_customizer());
 }
 
 // BasicPortAllocatorSession
@@ -1374,7 +1379,6 @@
       continue;
     }
 
-
     // Shared socket mode must be enabled only for UDP based ports. Hence
     // don't pass shared socket for ports which will create TCP sockets.
     // TODO(mallinath) - Enable shared socket mode for TURN ports. Disabled
@@ -1386,7 +1390,8 @@
                               network_, udp_socket_.get(),
                               session_->username(), session_->password(),
                               *relay_port, config.credentials, config.priority,
-                              session_->allocator()->origin());
+                              session_->allocator()->origin(),
+                              session_->allocator()->turn_customizer());
       turn_ports_.push_back(port);
       // Listen to the port destroyed signal, to allow AllocationSequence to
       // remove entrt from it's map.
@@ -1397,7 +1402,8 @@
           session_->allocator()->min_port(), session_->allocator()->max_port(),
           session_->username(), session_->password(), *relay_port,
           config.credentials, config.priority, session_->allocator()->origin(),
-          config.tls_alpn_protocols, config.tls_elliptic_curves);
+          config.tls_alpn_protocols, config.tls_elliptic_curves,
+          session_->allocator()->turn_customizer());
     }
     RTC_DCHECK(port != NULL);
     port->SetTlsCertPolicy(config.tls_cert_policy);
diff --git a/p2p/client/basicportallocator.h b/p2p/client/basicportallocator.h
index d0c1de1..5ec721b 100644
--- a/p2p/client/basicportallocator.h
+++ b/p2p/client/basicportallocator.h
@@ -15,6 +15,7 @@
 #include <string>
 #include <vector>
 
+#include "api/turncustomizer.h"
 #include "p2p/base/portallocator.h"
 #include "rtc_base/checks.h"
 #include "rtc_base/messagequeue.h"
@@ -26,7 +27,8 @@
 class BasicPortAllocator : public PortAllocator {
  public:
   BasicPortAllocator(rtc::NetworkManager* network_manager,
-                     rtc::PacketSocketFactory* socket_factory);
+                     rtc::PacketSocketFactory* socket_factory,
+                     webrtc::TurnCustomizer* customizer = nullptr);
   explicit BasicPortAllocator(rtc::NetworkManager* network_manager);
   BasicPortAllocator(rtc::NetworkManager* network_manager,
                      rtc::PacketSocketFactory* socket_factory,