shill: Adds OnNewFamilyMessage tests to netlink_manager_unittest.

This required extending some mocks.

BUG=chromium:222486
TEST=unittest

Change-Id: I737be51b288cb2fe26a71708e077e9c3c917910e
Reviewed-on: https://gerrit.chromium.org/gerrit/55571
Commit-Queue: Wade Guthrie <wdg@chromium.org>
Reviewed-by: Wade Guthrie <wdg@chromium.org>
Tested-by: Wade Guthrie <wdg@chromium.org>
diff --git a/mock_netlink_manager.h b/mock_netlink_manager.h
index e88d33e..27697af 100644
--- a/mock_netlink_manager.h
+++ b/mock_netlink_manager.h
@@ -19,6 +19,11 @@
  public:
   MockNetlinkManager();
   virtual ~MockNetlinkManager();
+
+  MOCK_METHOD2(
+      GetFamily,
+      uint16_t(const std::string &family_name,
+               const NetlinkMessageFactory::FactoryMethod &message_factory));
   MOCK_METHOD1(RemoveBroadcastHandler,
                bool(const NetlinkMessageHandler &message_handler));
   MOCK_METHOD1(AddBroadcastHandler,
diff --git a/mock_netlink_socket.h b/mock_netlink_socket.h
index ceeb888..d5740bb 100644
--- a/mock_netlink_socket.h
+++ b/mock_netlink_socket.h
@@ -5,13 +5,12 @@
 #ifndef SHILL_MOCK_NETLINK_SOCKET_H_
 #define SHILL_MOCK_NETLINK_SOCKET_H_
 
+#include "shill/netlink_socket.h"
 
 #include <base/basictypes.h>
 
 #include <gmock/gmock.h>
 
-#include "shill/netlink_socket.h"
-
 namespace shill {
 
 class ByteString;
@@ -21,9 +20,11 @@
   MockNetlinkSocket() {}
   MOCK_METHOD0(Init, bool());
 
-  virtual bool SendMessage(const ByteString &out_string);
   uint32 GetLastSequenceNumber() const { return sequence_number_; }
+  MOCK_CONST_METHOD0(file_descriptor, int());
+  MOCK_METHOD1(SendMessage, bool(const ByteString &out_string));
   MOCK_METHOD1(SubscribeToEvents, bool(uint32_t group_id));
+  MOCK_METHOD1(RecvMessage, bool(ByteString *message));
 
  private:
   DISALLOW_COPY_AND_ASSIGN(MockNetlinkSocket);
diff --git a/mock_sockets.h b/mock_sockets.h
index 919e1b2..896859b 100644
--- a/mock_sockets.h
+++ b/mock_sockets.h
@@ -5,11 +5,11 @@
 #ifndef SHILL_MOCK_SOCKETS_H_
 #define SHILL_MOCK_SOCKETS_H_
 
+#include "shill/sockets.h"
+
 #include <base/basictypes.h>
 #include <gmock/gmock.h>
 
-#include "shill/sockets.h"
-
 namespace shill {
 
 class MockSockets : public Sockets {
@@ -38,6 +38,11 @@
                                        int flags,
                                        struct sockaddr *src_addr,
                                        socklen_t *addrlen));
+  MOCK_CONST_METHOD5(Select, int(int nfds,
+                                 fd_set *readfds,
+                                 fd_set *writefds,
+                                 fd_set *exceptfds,
+                                 struct timeval *timeout));
   MOCK_CONST_METHOD4(Send, ssize_t(int sockfd, const void *buf, size_t len,
                                    int flags));
   MOCK_CONST_METHOD6(SendTo, ssize_t(int sockfd,
diff --git a/netlink_manager.cc b/netlink_manager.cc
index 2c2d99d..0172611 100644
--- a/netlink_manager.cc
+++ b/netlink_manager.cc
@@ -22,6 +22,7 @@
 #include "shill/netlink_socket.h"
 #include "shill/scope_logger.h"
 #include "shill/shill_time.h"
+#include "shill/sockets.h"
 
 using base::Bind;
 using base::LazyInstance;
@@ -50,7 +51,8 @@
       weak_ptr_factory_(this),
       dispatcher_callback_(Bind(&NetlinkManager::OnRawNlMessageReceived,
                                 weak_ptr_factory_.GetWeakPtr())),
-      sock_(NULL) {}
+      sock_(NULL),
+      time_(Time::GetInstance()) {}
 
 NetlinkManager *NetlinkManager::GetInstance() {
   return g_netlink_manager.Pointer();
@@ -58,6 +60,7 @@
 
 void NetlinkManager::Reset(bool full) {
   ClearBroadcastHandlers();
+  message_handlers_.clear();
   message_types_.clear();
   if (full) {
     dispatcher_ = NULL;
@@ -171,7 +174,7 @@
   return (sock_ ? sock_->file_descriptor() : -1);
 }
 
-uint16_t NetlinkManager::GetFamily(string name,
+uint16_t NetlinkManager::GetFamily(const string &name,
     const NetlinkMessageFactory::FactoryMethod &message_factory) {
   MessageType &message_type = message_types_[name];
   if (message_type.family_id != NetlinkMessage::kIllegalMessageType) {
@@ -207,8 +210,7 @@
   struct timeval start_time, now, end_time;
   struct timeval maximum_wait_duration = {kMaximumNewFamilyWaitSeconds,
                                           kMaximumNewFamilyWaitMicroSeconds};
-  Time *time = Time::GetInstance();
-  time->GetTimeMonotonic(&start_time);
+  time_->GetTimeMonotonic(&start_time);
   now = start_time;
   timeradd(&start_time, &maximum_wait_duration, &end_time);
 
@@ -219,8 +221,11 @@
     FD_SET(file_descriptor(), &read_fds);
     struct timeval wait_duration;
     timersub(&end_time, &now, &wait_duration);
-    int result = select(file_descriptor() + 1, &read_fds, NULL, NULL,
-                        &wait_duration);
+    int result = sock_->sockets()->Select(file_descriptor() + 1,
+                                          &read_fds,
+                                          NULL,
+                                          NULL,
+                                          &wait_duration);
     if (result < 0) {
       PLOG(ERROR) << "Select failed";
       return NetlinkMessage::kIllegalMessageType;
@@ -241,7 +246,7 @@
       if (family_id != NetlinkMessage::kIllegalMessageType) {
         message_factory_.AddFactoryMethod(family_id, message_factory);
       }
-      time->GetTimeMonotonic(&now);
+      time_->GetTimeMonotonic(&now);
       timersub(&now, &start_time, &wait_duration);
       SLOG(WiFi, 5) << "Found id " << message_type.family_id
                     << " for name '" << name << "' in "
@@ -249,7 +254,7 @@
                     << wait_duration.tv_usec << " usec.";
       return message_type.family_id;
     }
-    time->GetTimeMonotonic(&now);
+    time_->GetTimeMonotonic(&now);
   } while (timercmp(&now, &end_time, <));
 
   LOG(ERROR) << "Timed out waiting for family_id for family '" << name << "'.";
diff --git a/netlink_manager.h b/netlink_manager.h
index a5312cc..fc53984 100644
--- a/netlink_manager.h
+++ b/netlink_manager.h
@@ -72,6 +72,7 @@
 
 #include "shill/io_handler.h"
 #include "shill/netlink_message.h"
+#include "shill/shill_time.h"
 
 struct nlmsghdr;
 
@@ -138,7 +139,7 @@
   // |NetlinkMessage::kIllegalMessageType| if the message type could not be
   // determined.  May block so |GetFamily| should be called before entering the
   // event loop.
-  uint16_t GetFamily(std::string family_name,
+  virtual uint16_t GetFamily(const std::string &family_name,
       const NetlinkMessageFactory::FactoryMethod &message_factory);
 
   // Install a NetlinkManager NetlinkMessageHandler.  The handler is a
@@ -189,6 +190,8 @@
   friend class ShillDaemonTest;
   FRIEND_TEST(NetlinkManagerTest, AddLinkTest);
   FRIEND_TEST(NetlinkManagerTest, BroadcastHandlerTest);
+  FRIEND_TEST(NetlinkManagerTest, GetFamilyOneInterstitialMessage);
+  FRIEND_TEST(NetlinkManagerTest, GetFamilyTimeout);
   FRIEND_TEST(NetlinkManagerTest, MessageHandlerTest);
   FRIEND_TEST(NetlinkMessageTest, Parse_NL80211_CMD_TRIGGER_SCAN);
   FRIEND_TEST(NetlinkMessageTest, Parse_NL80211_CMD_NEW_SCAN_RESULTS);
@@ -247,6 +250,7 @@
   NetlinkSocket *sock_;
   std::map<const std::string, MessageType> message_types_;
   NetlinkMessageFactory message_factory_;
+  Time *time_;
 
   DISALLOW_COPY_AND_ASSIGN(NetlinkManager);
 };
diff --git a/netlink_manager_unittest.cc b/netlink_manager_unittest.cc
index 6154a5a..2f5b0a6 100644
--- a/netlink_manager_unittest.cc
+++ b/netlink_manager_unittest.cc
@@ -10,14 +10,20 @@
 // This file tests the public interface to NetlinkManager.
 #include "shill/netlink_manager.h"
 
+#include <string>
+
 #include <gmock/gmock.h>
 #include <gtest/gtest.h>
 
 #include "shill/mock_netlink_socket.h"
+#include "shill/mock_sockets.h"
+#include "shill/mock_time.h"
+#include "shill/netlink_attribute.h"
 #include "shill/nl80211_message.h"
 
 using base::Bind;
 using base::Unretained;
+using std::string;
 using testing::_;
 using testing::Invoke;
 using testing::Return;
@@ -53,18 +59,23 @@
 
 }  // namespace
 
-bool MockNetlinkSocket::SendMessage(const ByteString &out_string) {
-  return true;
-}
-
 class NetlinkManagerTest : public Test {
  public:
-  NetlinkManagerTest() : netlink_manager_(NetlinkManager::GetInstance()) {
+  NetlinkManagerTest()
+      : netlink_manager_(NetlinkManager::GetInstance()),
+        sockets_(new MockSockets),
+        saved_sequence_number_(0) {
     netlink_manager_->message_types_[Nl80211Message::kMessageTypeString]
         .family_id = kNl80211FamilyId;
     netlink_manager_->message_factory_.AddFactoryMethod(
         kNl80211FamilyId, Bind(&Nl80211Message::CreateMessage));
     Nl80211Message::SetMessageType(kNl80211FamilyId);
+    // Passes ownership.
+    netlink_socket_.sockets_.reset(sockets_);
+
+    EXPECT_NE(reinterpret_cast<NetlinkManager *>(NULL), netlink_manager_);
+    netlink_manager_->sock_ = &netlink_socket_;
+    EXPECT_TRUE(netlink_manager_->Init());
   }
 
   ~NetlinkManagerTest() {
@@ -74,10 +85,57 @@
     netlink_manager_->sock_ = NULL;
   }
 
-  void SetupNetlinkManagerObject() {
-    EXPECT_NE(reinterpret_cast<NetlinkManager *>(NULL), netlink_manager_);
-    netlink_manager_->sock_ = &socket_;
-    EXPECT_TRUE(netlink_manager_->Init());
+  // |SaveReply|, |SendMessage|, and |ReplyToSentMessage| work together to
+  // enable a test to get a response to a sent message.  They must be called
+  // in the order, above, so that a) a reply message is available to b) have
+  // its sequence number replaced, and then c) sent back to the code.
+  void SaveReply(const ByteString &message) {
+    saved_message_ = message;
+  }
+
+  // Replaces the |saved_message_|'s sequence number with the sent value.
+  bool SendMessage(const ByteString &outgoing_message) {
+    if (outgoing_message.GetLength() < sizeof(nlmsghdr)) {
+      LOG(ERROR) << "Outgoing message is too short";
+      return false;
+    }
+    const nlmsghdr *outgoing_header =
+        reinterpret_cast<const nlmsghdr *>(outgoing_message.GetConstData());
+
+    if (saved_message_.GetLength() < sizeof(nlmsghdr)) {
+      LOG(ERROR) << "Saved message is too short; have you called |SaveReply|?";
+      return false;
+    }
+    nlmsghdr *reply_header =
+        reinterpret_cast<nlmsghdr *>(saved_message_.GetData());
+
+    reply_header->nlmsg_seq = outgoing_header->nlmsg_seq;
+    saved_sequence_number_ = reply_header->nlmsg_seq;
+    return true;
+  }
+
+  bool ReplyToSentMessage(ByteString *message) {
+    if (!message) {
+      return false;
+    }
+    *message = saved_message_;
+    return true;
+  }
+
+  bool ReplyWithRandomMessage(ByteString *message) {
+    GetFamilyMessage get_family_message;
+    // Any number that's not 0 or 1 is acceptable, here.  Zero is bad because
+    // we want to make sure that this message is different than the main
+    // send/receive pair.  One is bad becasue the default for
+    // |saved_sequence_number_| is zero and the likely default value for the
+    // first sequence number generated from the code is 1.
+    const uint32_t kRandomOffset = 1003;
+    if (!message) {
+      return false;
+    }
+    *message = get_family_message.Encode(saved_sequence_number_ +
+                                         kRandomOffset);
+    return true;
   }
 
  protected:
@@ -95,20 +153,153 @@
     DISALLOW_COPY_AND_ASSIGN(MockHandler80211);
   };
 
+  void Reset() {
+    netlink_manager_->Reset(false);
+  }
+
   NetlinkManager *netlink_manager_;
-  MockNetlinkSocket socket_;
+  MockNetlinkSocket netlink_socket_;
+  MockSockets *sockets_;  // Owned by |netlink_socket_|.
+  ByteString saved_message_;
+  uint32_t saved_sequence_number_;
 };
 
+namespace {
+
+class TimeFunctor {
+ public:
+  TimeFunctor(time_t tv_sec, suseconds_t tv_usec) {
+    return_value_.tv_sec = tv_sec;
+    return_value_.tv_usec = tv_usec;
+  }
+
+  TimeFunctor() {
+    return_value_.tv_sec = 0;
+    return_value_.tv_usec = 0;
+  }
+
+  TimeFunctor(const TimeFunctor &other) {
+    return_value_.tv_sec = other.return_value_.tv_sec;
+    return_value_.tv_usec = other.return_value_.tv_usec;
+  }
+
+  TimeFunctor &operator=(const TimeFunctor &rhs) {
+    return_value_.tv_sec = rhs.return_value_.tv_sec;
+    return_value_.tv_usec = rhs.return_value_.tv_usec;
+    return *this;
+  }
+
+  // Replaces GetTimeMonotonic.
+  int operator()(struct timeval *answer) {
+    if (answer) {
+      *answer = return_value_;
+    }
+    return 0;
+  }
+
+ private:
+  struct timeval return_value_;
+
+  // No DISALLOW_COPY_AND_ASSIGN since testing::Invoke uses copy.
+};
+
+}  // namespace
+
 // TODO(wdg): Add a test for multi-part messages.  crbug.com/224652
-// TODO(wdg): Add a test for GetFaimily.  crbug.com/224649
-// TODO(wdg): Add a test for OnNewFamilyMessage.  crbug.com/222486
 // TODO(wdg): Add a test for SubscribeToEvents (verify that it handles bad input
 // appropriately, and that it calls NetlinkSocket::SubscribeToEvents if input
 // is good.)
 
-TEST_F(NetlinkManagerTest, BroadcastHandlerTest) {
-  SetupNetlinkManagerObject();
+TEST_F(NetlinkManagerTest, GetFamily) {
+  const uint16_t kSampleMessageType = 42;
+  const string kSampleMessageName("SampleMessageName");
+  const uint32_t kRandomSequenceNumber = 3;
 
+  NewFamilyMessage new_family_message;
+  new_family_message.attributes()->CreateAttribute(
+      CTRL_ATTR_FAMILY_ID,
+      base::Bind(&NetlinkAttribute::NewControlAttributeFromId));
+  new_family_message.attributes()->SetU16AttributeValue(
+      CTRL_ATTR_FAMILY_ID, kSampleMessageType);
+  new_family_message.attributes()->CreateAttribute(
+      CTRL_ATTR_FAMILY_NAME,
+      base::Bind(&NetlinkAttribute::NewControlAttributeFromId));
+  new_family_message.attributes()->SetStringAttributeValue(
+      CTRL_ATTR_FAMILY_NAME, kSampleMessageName);
+
+  // The sequence number is immaterial since it'll be overwritten.
+  SaveReply(new_family_message.Encode(kRandomSequenceNumber));
+  EXPECT_CALL(netlink_socket_, SendMessage(_)).
+      WillOnce(Invoke(this, &NetlinkManagerTest::SendMessage));
+  EXPECT_CALL(netlink_socket_, file_descriptor()).WillRepeatedly(Return(0));
+  EXPECT_CALL(*sockets_, Select(_, _, _, _, _)).WillOnce(Return(1));
+  EXPECT_CALL(netlink_socket_, RecvMessage(_)).
+      WillOnce(Invoke(this, &NetlinkManagerTest::ReplyToSentMessage));
+  NetlinkMessageFactory::FactoryMethod null_factory;
+  EXPECT_EQ(kSampleMessageType, netlink_manager_->GetFamily(kSampleMessageName,
+                                                            null_factory));
+}
+
+TEST_F(NetlinkManagerTest, GetFamilyOneInterstitialMessage) {
+  Reset();
+
+  const uint16_t kSampleMessageType = 42;
+  const string kSampleMessageName("SampleMessageName");
+  const uint32_t kRandomSequenceNumber = 3;
+
+  NewFamilyMessage new_family_message;
+  new_family_message.attributes()->CreateAttribute(
+      CTRL_ATTR_FAMILY_ID,
+      base::Bind(&NetlinkAttribute::NewControlAttributeFromId));
+  new_family_message.attributes()->SetU16AttributeValue(
+      CTRL_ATTR_FAMILY_ID, kSampleMessageType);
+  new_family_message.attributes()->CreateAttribute(
+      CTRL_ATTR_FAMILY_NAME,
+      base::Bind(&NetlinkAttribute::NewControlAttributeFromId));
+  new_family_message.attributes()->SetStringAttributeValue(
+      CTRL_ATTR_FAMILY_NAME, kSampleMessageName);
+
+  // The sequence number is immaterial since it'll be overwritten.
+  SaveReply(new_family_message.Encode(kRandomSequenceNumber));
+  EXPECT_CALL(netlink_socket_, SendMessage(_)).
+      WillOnce(Invoke(this, &NetlinkManagerTest::SendMessage));
+  EXPECT_CALL(netlink_socket_, file_descriptor()).WillRepeatedly(Return(0));
+  EXPECT_CALL(*sockets_, Select(_, _, _, _, _)).WillRepeatedly(Return(1));
+  EXPECT_CALL(netlink_socket_, RecvMessage(_)).
+      WillOnce(Invoke(this, &NetlinkManagerTest::ReplyWithRandomMessage)).
+      WillOnce(Invoke(this, &NetlinkManagerTest::ReplyToSentMessage));
+  NetlinkMessageFactory::FactoryMethod null_factory;
+  EXPECT_EQ(kSampleMessageType, netlink_manager_->GetFamily(kSampleMessageName,
+                                                            null_factory));
+}
+
+TEST_F(NetlinkManagerTest, GetFamilyTimeout) {
+  Reset();
+  MockTime time;
+  netlink_manager_->time_ = &time;
+
+  EXPECT_CALL(netlink_socket_, SendMessage(_)).WillOnce(Return(true));
+  time_t kStartSeconds = 1234;  // Arbitrary.
+  suseconds_t kSmallUsec = 100;
+  EXPECT_CALL(time, GetTimeMonotonic(_)).
+      WillOnce(Invoke(TimeFunctor(kStartSeconds, 0))).  // Initial time.
+      WillOnce(Invoke(TimeFunctor(kStartSeconds, kSmallUsec))).
+      WillOnce(Invoke(TimeFunctor(kStartSeconds, 2 * kSmallUsec))).
+      WillOnce(Invoke(TimeFunctor(
+          kStartSeconds + NetlinkManager::kMaximumNewFamilyWaitSeconds + 1,
+          NetlinkManager::kMaximumNewFamilyWaitMicroSeconds)));
+  EXPECT_CALL(netlink_socket_, file_descriptor()).WillRepeatedly(Return(0));
+  EXPECT_CALL(*sockets_, Select(_, _, _, _, _)).WillRepeatedly(Return(1));
+  EXPECT_CALL(netlink_socket_, RecvMessage(_)).
+      WillRepeatedly(Invoke(this, &NetlinkManagerTest::ReplyWithRandomMessage));
+  NetlinkMessageFactory::FactoryMethod null_factory;
+
+  const string kSampleMessageName("SampleMessageName");
+  EXPECT_EQ(NetlinkMessage::kIllegalMessageType,
+            netlink_manager_->GetFamily(kSampleMessageName, null_factory));
+}
+
+TEST_F(NetlinkManagerTest, BroadcastHandlerTest) {
   nlmsghdr *message = const_cast<nlmsghdr *>(
         reinterpret_cast<const nlmsghdr *>(kNL80211_CMD_DISCONNECT));
 
@@ -157,9 +348,7 @@
 }
 
 TEST_F(NetlinkManagerTest, MessageHandlerTest) {
-  // Setup.
-  SetupNetlinkManagerObject();
-
+  Reset();
   MockHandler80211 handler_broadcast;
   EXPECT_TRUE(netlink_manager_->AddBroadcastHandler(
       handler_broadcast.on_netlink_message()));
@@ -186,10 +375,11 @@
   netlink_manager_->OnNlMessageReceived(received_message);
 
   // Send the message and give our handler.  Verify that we get called back.
+  EXPECT_CALL(netlink_socket_, SendMessage(_)).WillOnce(Return(true));
   EXPECT_TRUE(netlink_manager_->SendMessage(
       &sent_message_1, handler_sent_1.on_netlink_message()));
   // Make it appear that this message is in response to our sent message.
-  received_message->nlmsg_seq = socket_.GetLastSequenceNumber();
+  received_message->nlmsg_seq = netlink_socket_.GetLastSequenceNumber();
   EXPECT_CALL(handler_sent_1, OnNetlinkMessage(_)).Times(1);
   netlink_manager_->OnNlMessageReceived(received_message);
 
@@ -200,15 +390,17 @@
 
   // Install and then uninstall message-specific handler; verify broadcast
   // handler is called on message receipt.
+  EXPECT_CALL(netlink_socket_, SendMessage(_)).WillOnce(Return(true));
   EXPECT_TRUE(netlink_manager_->SendMessage(
       &sent_message_1, handler_sent_1.on_netlink_message()));
-  received_message->nlmsg_seq = socket_.GetLastSequenceNumber();
+  received_message->nlmsg_seq = netlink_socket_.GetLastSequenceNumber();
   EXPECT_TRUE(netlink_manager_->RemoveMessageHandler(sent_message_1));
   EXPECT_CALL(handler_broadcast, OnNetlinkMessage(_)).Times(1);
   netlink_manager_->OnNlMessageReceived(received_message);
 
   // Install handler for different message; verify that broadcast handler is
   // called for _this_ message.
+  EXPECT_CALL(netlink_socket_, SendMessage(_)).WillOnce(Return(true));
   EXPECT_TRUE(netlink_manager_->SendMessage(
       &sent_message_2, handler_sent_2.on_netlink_message()));
   EXPECT_CALL(handler_broadcast, OnNetlinkMessage(_)).Times(1);
@@ -216,7 +408,7 @@
 
   // Change the ID for the message to that of the second handler; verify that
   // the appropriate handler is called for _that_ message.
-  received_message->nlmsg_seq = socket_.GetLastSequenceNumber();
+  received_message->nlmsg_seq = netlink_socket_.GetLastSequenceNumber();
   EXPECT_CALL(handler_sent_2, OnNetlinkMessage(_)).Times(1);
   netlink_manager_->OnNlMessageReceived(received_message);
 }
diff --git a/netlink_socket.h b/netlink_socket.h
index e15245c..f9ddcae 100644
--- a/netlink_socket.h
+++ b/netlink_socket.h
@@ -52,7 +52,7 @@
   bool Init();
 
   // Returns the file descriptor used by the socket.
-  int file_descriptor() const { return file_descriptor_; }
+  virtual int file_descriptor() const { return file_descriptor_; }
 
   // Get the next message sequence number for this socket.
   // |GetSequenceNumber| won't return zero because that is the 'broadcast'
@@ -62,7 +62,7 @@
   // Reads data from the socket into |message| and returns true if successful.
   // The |message| parameter will be resized to hold the entirety of the read
   // message (and any data in |message| will be overwritten).
-  bool RecvMessage(ByteString *message);
+  virtual bool RecvMessage(ByteString *message);
 
   // Sends a message, returns true if successful.
   virtual bool SendMessage(const ByteString &message);
@@ -70,10 +70,13 @@
   // Subscribes to netlink broadcast events.
   virtual bool SubscribeToEvents(uint32_t group_id);
 
+  virtual const Sockets *sockets() const { return sockets_.get(); }
+
  protected:
   uint32_t sequence_number_;
 
  private:
+  friend class NetlinkManagerTest;
   friend class NetlinkSocketTest;
   FRIEND_TEST(NetlinkSocketTest, SequenceNumberTest);
 
diff --git a/shill_unittest.cc b/shill_unittest.cc
index 8c78409..19eb37d 100644
--- a/shill_unittest.cc
+++ b/shill_unittest.cc
@@ -18,9 +18,11 @@
 #include "shill/mock_dhcp_provider.h"
 #include "shill/mock_manager.h"
 #include "shill/mock_metrics.h"
+#include "shill/mock_netlink_manager.h"
 #include "shill/mock_proxy_factory.h"
-#include "shill/mock_rtnl_handler.h"
 #include "shill/mock_routing_table.h"
+#include "shill/mock_rtnl_handler.h"
+#include "shill/nl80211_message.h"
 #include "shill/shill_daemon.h"
 #include "shill/shill_test_config.h"
 
@@ -202,7 +204,12 @@
     daemon_.dhcp_provider_ = &dhcp_provider_;
     daemon_.metrics_.reset(metrics_);  // Passes ownership
     daemon_.manager_.reset(manager_);  // Passes ownership
+    daemon_.netlink_manager_ = &netlink_manager_;
     dispatcher_test_.ScheduleFailSafe();
+
+    const uint16_t kNl80211MessageType = 42;  // Arbitrary.
+    ON_CALL(netlink_manager_, GetFamily(Nl80211Message::kMessageTypeString, _)).
+            WillByDefault(Return(kNl80211MessageType));
   }
   void StartDaemon() {
     daemon_.Start();
@@ -227,6 +234,7 @@
   MockDHCPProvider dhcp_provider_;
   MockMetrics *metrics_;
   MockManager *manager_;
+  MockNetlinkManager netlink_manager_;
   DeviceInfo device_info_;
   EventDispatcher *dispatcher_;
   StrictMock<MockEventDispatchTester> dispatcher_test_;
diff --git a/sockets.cc b/sockets.cc
index 4cb7730..bb7ddae 100644
--- a/sockets.cc
+++ b/sockets.cc
@@ -18,6 +18,8 @@
 
 namespace shill {
 
+Sockets::Sockets() {}
+
 Sockets::~Sockets() {}
 
 // Some system calls can be interrupted and return EINTR, but will succeed on
@@ -103,6 +105,14 @@
   return HANDLE_EINTR(recvfrom(sockfd, buf, len, flags, src_addr, addrlen));
 }
 
+int Sockets::Select(int nfds,
+                    fd_set *readfds,
+                    fd_set *writefds,
+                    fd_set *exceptfds,
+                    struct timeval *timeout) const {
+  return HANDLE_EINTR(select(nfds, readfds, writefds, exceptfds, timeout));
+}
+
 ssize_t Sockets::Send(int sockfd,
                       const void *buf,
                       size_t len,
diff --git a/sockets.h b/sockets.h
index 19bb18f..17e9ff8 100644
--- a/sockets.h
+++ b/sockets.h
@@ -5,12 +5,12 @@
 #ifndef SHILL_SOCKETS_H_
 #define SHILL_SOCKETS_H_
 
-#include <string>
-
 #include <linux/filter.h>
 #include <sys/socket.h>
 #include <sys/types.h>
 
+#include <string>
+
 #include <base/basictypes.h>
 
 namespace shill {
@@ -18,6 +18,7 @@
 // A "sys/socket.h" abstraction allowing mocking in tests.
 class Sockets {
  public:
+  Sockets();
   virtual ~Sockets();
 
   // accept
@@ -68,6 +69,13 @@
   virtual ssize_t RecvFrom(int sockfd, void *buf, size_t len, int flags,
                            struct sockaddr *src_addr, socklen_t *addrlen) const;
 
+  // select
+  virtual int Select(int nfds,
+                     fd_set *readfds,
+                     fd_set *writefds,
+                     fd_set *exceptfds,
+                     struct timeval *timeout) const;
+
   // send
   virtual ssize_t Send(int sockfd,
                        const void *buf,
@@ -93,6 +101,9 @@
 
   // socket
   virtual int Socket(int domain, int type, int protocol) const;
+
+ private:
+  DISALLOW_COPY_AND_ASSIGN(Sockets);
 };
 
 class ScopedSocketCloser {