shill: Adds OnNewFamilyMessage tests to netlink_manager_unittest.
This required extending some mocks.
BUG=chromium:222486
TEST=unittest
Change-Id: I737be51b288cb2fe26a71708e077e9c3c917910e
Reviewed-on: https://gerrit.chromium.org/gerrit/55571
Commit-Queue: Wade Guthrie <wdg@chromium.org>
Reviewed-by: Wade Guthrie <wdg@chromium.org>
Tested-by: Wade Guthrie <wdg@chromium.org>
diff --git a/mock_netlink_manager.h b/mock_netlink_manager.h
index e88d33e..27697af 100644
--- a/mock_netlink_manager.h
+++ b/mock_netlink_manager.h
@@ -19,6 +19,11 @@
public:
MockNetlinkManager();
virtual ~MockNetlinkManager();
+
+ MOCK_METHOD2(
+ GetFamily,
+ uint16_t(const std::string &family_name,
+ const NetlinkMessageFactory::FactoryMethod &message_factory));
MOCK_METHOD1(RemoveBroadcastHandler,
bool(const NetlinkMessageHandler &message_handler));
MOCK_METHOD1(AddBroadcastHandler,
diff --git a/mock_netlink_socket.h b/mock_netlink_socket.h
index ceeb888..d5740bb 100644
--- a/mock_netlink_socket.h
+++ b/mock_netlink_socket.h
@@ -5,13 +5,12 @@
#ifndef SHILL_MOCK_NETLINK_SOCKET_H_
#define SHILL_MOCK_NETLINK_SOCKET_H_
+#include "shill/netlink_socket.h"
#include <base/basictypes.h>
#include <gmock/gmock.h>
-#include "shill/netlink_socket.h"
-
namespace shill {
class ByteString;
@@ -21,9 +20,11 @@
MockNetlinkSocket() {}
MOCK_METHOD0(Init, bool());
- virtual bool SendMessage(const ByteString &out_string);
uint32 GetLastSequenceNumber() const { return sequence_number_; }
+ MOCK_CONST_METHOD0(file_descriptor, int());
+ MOCK_METHOD1(SendMessage, bool(const ByteString &out_string));
MOCK_METHOD1(SubscribeToEvents, bool(uint32_t group_id));
+ MOCK_METHOD1(RecvMessage, bool(ByteString *message));
private:
DISALLOW_COPY_AND_ASSIGN(MockNetlinkSocket);
diff --git a/mock_sockets.h b/mock_sockets.h
index 919e1b2..896859b 100644
--- a/mock_sockets.h
+++ b/mock_sockets.h
@@ -5,11 +5,11 @@
#ifndef SHILL_MOCK_SOCKETS_H_
#define SHILL_MOCK_SOCKETS_H_
+#include "shill/sockets.h"
+
#include <base/basictypes.h>
#include <gmock/gmock.h>
-#include "shill/sockets.h"
-
namespace shill {
class MockSockets : public Sockets {
@@ -38,6 +38,11 @@
int flags,
struct sockaddr *src_addr,
socklen_t *addrlen));
+ MOCK_CONST_METHOD5(Select, int(int nfds,
+ fd_set *readfds,
+ fd_set *writefds,
+ fd_set *exceptfds,
+ struct timeval *timeout));
MOCK_CONST_METHOD4(Send, ssize_t(int sockfd, const void *buf, size_t len,
int flags));
MOCK_CONST_METHOD6(SendTo, ssize_t(int sockfd,
diff --git a/netlink_manager.cc b/netlink_manager.cc
index 2c2d99d..0172611 100644
--- a/netlink_manager.cc
+++ b/netlink_manager.cc
@@ -22,6 +22,7 @@
#include "shill/netlink_socket.h"
#include "shill/scope_logger.h"
#include "shill/shill_time.h"
+#include "shill/sockets.h"
using base::Bind;
using base::LazyInstance;
@@ -50,7 +51,8 @@
weak_ptr_factory_(this),
dispatcher_callback_(Bind(&NetlinkManager::OnRawNlMessageReceived,
weak_ptr_factory_.GetWeakPtr())),
- sock_(NULL) {}
+ sock_(NULL),
+ time_(Time::GetInstance()) {}
NetlinkManager *NetlinkManager::GetInstance() {
return g_netlink_manager.Pointer();
@@ -58,6 +60,7 @@
void NetlinkManager::Reset(bool full) {
ClearBroadcastHandlers();
+ message_handlers_.clear();
message_types_.clear();
if (full) {
dispatcher_ = NULL;
@@ -171,7 +174,7 @@
return (sock_ ? sock_->file_descriptor() : -1);
}
-uint16_t NetlinkManager::GetFamily(string name,
+uint16_t NetlinkManager::GetFamily(const string &name,
const NetlinkMessageFactory::FactoryMethod &message_factory) {
MessageType &message_type = message_types_[name];
if (message_type.family_id != NetlinkMessage::kIllegalMessageType) {
@@ -207,8 +210,7 @@
struct timeval start_time, now, end_time;
struct timeval maximum_wait_duration = {kMaximumNewFamilyWaitSeconds,
kMaximumNewFamilyWaitMicroSeconds};
- Time *time = Time::GetInstance();
- time->GetTimeMonotonic(&start_time);
+ time_->GetTimeMonotonic(&start_time);
now = start_time;
timeradd(&start_time, &maximum_wait_duration, &end_time);
@@ -219,8 +221,11 @@
FD_SET(file_descriptor(), &read_fds);
struct timeval wait_duration;
timersub(&end_time, &now, &wait_duration);
- int result = select(file_descriptor() + 1, &read_fds, NULL, NULL,
- &wait_duration);
+ int result = sock_->sockets()->Select(file_descriptor() + 1,
+ &read_fds,
+ NULL,
+ NULL,
+ &wait_duration);
if (result < 0) {
PLOG(ERROR) << "Select failed";
return NetlinkMessage::kIllegalMessageType;
@@ -241,7 +246,7 @@
if (family_id != NetlinkMessage::kIllegalMessageType) {
message_factory_.AddFactoryMethod(family_id, message_factory);
}
- time->GetTimeMonotonic(&now);
+ time_->GetTimeMonotonic(&now);
timersub(&now, &start_time, &wait_duration);
SLOG(WiFi, 5) << "Found id " << message_type.family_id
<< " for name '" << name << "' in "
@@ -249,7 +254,7 @@
<< wait_duration.tv_usec << " usec.";
return message_type.family_id;
}
- time->GetTimeMonotonic(&now);
+ time_->GetTimeMonotonic(&now);
} while (timercmp(&now, &end_time, <));
LOG(ERROR) << "Timed out waiting for family_id for family '" << name << "'.";
diff --git a/netlink_manager.h b/netlink_manager.h
index a5312cc..fc53984 100644
--- a/netlink_manager.h
+++ b/netlink_manager.h
@@ -72,6 +72,7 @@
#include "shill/io_handler.h"
#include "shill/netlink_message.h"
+#include "shill/shill_time.h"
struct nlmsghdr;
@@ -138,7 +139,7 @@
// |NetlinkMessage::kIllegalMessageType| if the message type could not be
// determined. May block so |GetFamily| should be called before entering the
// event loop.
- uint16_t GetFamily(std::string family_name,
+ virtual uint16_t GetFamily(const std::string &family_name,
const NetlinkMessageFactory::FactoryMethod &message_factory);
// Install a NetlinkManager NetlinkMessageHandler. The handler is a
@@ -189,6 +190,8 @@
friend class ShillDaemonTest;
FRIEND_TEST(NetlinkManagerTest, AddLinkTest);
FRIEND_TEST(NetlinkManagerTest, BroadcastHandlerTest);
+ FRIEND_TEST(NetlinkManagerTest, GetFamilyOneInterstitialMessage);
+ FRIEND_TEST(NetlinkManagerTest, GetFamilyTimeout);
FRIEND_TEST(NetlinkManagerTest, MessageHandlerTest);
FRIEND_TEST(NetlinkMessageTest, Parse_NL80211_CMD_TRIGGER_SCAN);
FRIEND_TEST(NetlinkMessageTest, Parse_NL80211_CMD_NEW_SCAN_RESULTS);
@@ -247,6 +250,7 @@
NetlinkSocket *sock_;
std::map<const std::string, MessageType> message_types_;
NetlinkMessageFactory message_factory_;
+ Time *time_;
DISALLOW_COPY_AND_ASSIGN(NetlinkManager);
};
diff --git a/netlink_manager_unittest.cc b/netlink_manager_unittest.cc
index 6154a5a..2f5b0a6 100644
--- a/netlink_manager_unittest.cc
+++ b/netlink_manager_unittest.cc
@@ -10,14 +10,20 @@
// This file tests the public interface to NetlinkManager.
#include "shill/netlink_manager.h"
+#include <string>
+
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "shill/mock_netlink_socket.h"
+#include "shill/mock_sockets.h"
+#include "shill/mock_time.h"
+#include "shill/netlink_attribute.h"
#include "shill/nl80211_message.h"
using base::Bind;
using base::Unretained;
+using std::string;
using testing::_;
using testing::Invoke;
using testing::Return;
@@ -53,18 +59,23 @@
} // namespace
-bool MockNetlinkSocket::SendMessage(const ByteString &out_string) {
- return true;
-}
-
class NetlinkManagerTest : public Test {
public:
- NetlinkManagerTest() : netlink_manager_(NetlinkManager::GetInstance()) {
+ NetlinkManagerTest()
+ : netlink_manager_(NetlinkManager::GetInstance()),
+ sockets_(new MockSockets),
+ saved_sequence_number_(0) {
netlink_manager_->message_types_[Nl80211Message::kMessageTypeString]
.family_id = kNl80211FamilyId;
netlink_manager_->message_factory_.AddFactoryMethod(
kNl80211FamilyId, Bind(&Nl80211Message::CreateMessage));
Nl80211Message::SetMessageType(kNl80211FamilyId);
+ // Passes ownership.
+ netlink_socket_.sockets_.reset(sockets_);
+
+ EXPECT_NE(reinterpret_cast<NetlinkManager *>(NULL), netlink_manager_);
+ netlink_manager_->sock_ = &netlink_socket_;
+ EXPECT_TRUE(netlink_manager_->Init());
}
~NetlinkManagerTest() {
@@ -74,10 +85,57 @@
netlink_manager_->sock_ = NULL;
}
- void SetupNetlinkManagerObject() {
- EXPECT_NE(reinterpret_cast<NetlinkManager *>(NULL), netlink_manager_);
- netlink_manager_->sock_ = &socket_;
- EXPECT_TRUE(netlink_manager_->Init());
+ // |SaveReply|, |SendMessage|, and |ReplyToSentMessage| work together to
+ // enable a test to get a response to a sent message. They must be called
+ // in the order, above, so that a) a reply message is available to b) have
+ // its sequence number replaced, and then c) sent back to the code.
+ void SaveReply(const ByteString &message) {
+ saved_message_ = message;
+ }
+
+ // Replaces the |saved_message_|'s sequence number with the sent value.
+ bool SendMessage(const ByteString &outgoing_message) {
+ if (outgoing_message.GetLength() < sizeof(nlmsghdr)) {
+ LOG(ERROR) << "Outgoing message is too short";
+ return false;
+ }
+ const nlmsghdr *outgoing_header =
+ reinterpret_cast<const nlmsghdr *>(outgoing_message.GetConstData());
+
+ if (saved_message_.GetLength() < sizeof(nlmsghdr)) {
+ LOG(ERROR) << "Saved message is too short; have you called |SaveReply|?";
+ return false;
+ }
+ nlmsghdr *reply_header =
+ reinterpret_cast<nlmsghdr *>(saved_message_.GetData());
+
+ reply_header->nlmsg_seq = outgoing_header->nlmsg_seq;
+ saved_sequence_number_ = reply_header->nlmsg_seq;
+ return true;
+ }
+
+ bool ReplyToSentMessage(ByteString *message) {
+ if (!message) {
+ return false;
+ }
+ *message = saved_message_;
+ return true;
+ }
+
+ bool ReplyWithRandomMessage(ByteString *message) {
+ GetFamilyMessage get_family_message;
+ // Any number that's not 0 or 1 is acceptable, here. Zero is bad because
+ // we want to make sure that this message is different than the main
+ // send/receive pair. One is bad becasue the default for
+ // |saved_sequence_number_| is zero and the likely default value for the
+ // first sequence number generated from the code is 1.
+ const uint32_t kRandomOffset = 1003;
+ if (!message) {
+ return false;
+ }
+ *message = get_family_message.Encode(saved_sequence_number_ +
+ kRandomOffset);
+ return true;
}
protected:
@@ -95,20 +153,153 @@
DISALLOW_COPY_AND_ASSIGN(MockHandler80211);
};
+ void Reset() {
+ netlink_manager_->Reset(false);
+ }
+
NetlinkManager *netlink_manager_;
- MockNetlinkSocket socket_;
+ MockNetlinkSocket netlink_socket_;
+ MockSockets *sockets_; // Owned by |netlink_socket_|.
+ ByteString saved_message_;
+ uint32_t saved_sequence_number_;
};
+namespace {
+
+class TimeFunctor {
+ public:
+ TimeFunctor(time_t tv_sec, suseconds_t tv_usec) {
+ return_value_.tv_sec = tv_sec;
+ return_value_.tv_usec = tv_usec;
+ }
+
+ TimeFunctor() {
+ return_value_.tv_sec = 0;
+ return_value_.tv_usec = 0;
+ }
+
+ TimeFunctor(const TimeFunctor &other) {
+ return_value_.tv_sec = other.return_value_.tv_sec;
+ return_value_.tv_usec = other.return_value_.tv_usec;
+ }
+
+ TimeFunctor &operator=(const TimeFunctor &rhs) {
+ return_value_.tv_sec = rhs.return_value_.tv_sec;
+ return_value_.tv_usec = rhs.return_value_.tv_usec;
+ return *this;
+ }
+
+ // Replaces GetTimeMonotonic.
+ int operator()(struct timeval *answer) {
+ if (answer) {
+ *answer = return_value_;
+ }
+ return 0;
+ }
+
+ private:
+ struct timeval return_value_;
+
+ // No DISALLOW_COPY_AND_ASSIGN since testing::Invoke uses copy.
+};
+
+} // namespace
+
// TODO(wdg): Add a test for multi-part messages. crbug.com/224652
-// TODO(wdg): Add a test for GetFaimily. crbug.com/224649
-// TODO(wdg): Add a test for OnNewFamilyMessage. crbug.com/222486
// TODO(wdg): Add a test for SubscribeToEvents (verify that it handles bad input
// appropriately, and that it calls NetlinkSocket::SubscribeToEvents if input
// is good.)
-TEST_F(NetlinkManagerTest, BroadcastHandlerTest) {
- SetupNetlinkManagerObject();
+TEST_F(NetlinkManagerTest, GetFamily) {
+ const uint16_t kSampleMessageType = 42;
+ const string kSampleMessageName("SampleMessageName");
+ const uint32_t kRandomSequenceNumber = 3;
+ NewFamilyMessage new_family_message;
+ new_family_message.attributes()->CreateAttribute(
+ CTRL_ATTR_FAMILY_ID,
+ base::Bind(&NetlinkAttribute::NewControlAttributeFromId));
+ new_family_message.attributes()->SetU16AttributeValue(
+ CTRL_ATTR_FAMILY_ID, kSampleMessageType);
+ new_family_message.attributes()->CreateAttribute(
+ CTRL_ATTR_FAMILY_NAME,
+ base::Bind(&NetlinkAttribute::NewControlAttributeFromId));
+ new_family_message.attributes()->SetStringAttributeValue(
+ CTRL_ATTR_FAMILY_NAME, kSampleMessageName);
+
+ // The sequence number is immaterial since it'll be overwritten.
+ SaveReply(new_family_message.Encode(kRandomSequenceNumber));
+ EXPECT_CALL(netlink_socket_, SendMessage(_)).
+ WillOnce(Invoke(this, &NetlinkManagerTest::SendMessage));
+ EXPECT_CALL(netlink_socket_, file_descriptor()).WillRepeatedly(Return(0));
+ EXPECT_CALL(*sockets_, Select(_, _, _, _, _)).WillOnce(Return(1));
+ EXPECT_CALL(netlink_socket_, RecvMessage(_)).
+ WillOnce(Invoke(this, &NetlinkManagerTest::ReplyToSentMessage));
+ NetlinkMessageFactory::FactoryMethod null_factory;
+ EXPECT_EQ(kSampleMessageType, netlink_manager_->GetFamily(kSampleMessageName,
+ null_factory));
+}
+
+TEST_F(NetlinkManagerTest, GetFamilyOneInterstitialMessage) {
+ Reset();
+
+ const uint16_t kSampleMessageType = 42;
+ const string kSampleMessageName("SampleMessageName");
+ const uint32_t kRandomSequenceNumber = 3;
+
+ NewFamilyMessage new_family_message;
+ new_family_message.attributes()->CreateAttribute(
+ CTRL_ATTR_FAMILY_ID,
+ base::Bind(&NetlinkAttribute::NewControlAttributeFromId));
+ new_family_message.attributes()->SetU16AttributeValue(
+ CTRL_ATTR_FAMILY_ID, kSampleMessageType);
+ new_family_message.attributes()->CreateAttribute(
+ CTRL_ATTR_FAMILY_NAME,
+ base::Bind(&NetlinkAttribute::NewControlAttributeFromId));
+ new_family_message.attributes()->SetStringAttributeValue(
+ CTRL_ATTR_FAMILY_NAME, kSampleMessageName);
+
+ // The sequence number is immaterial since it'll be overwritten.
+ SaveReply(new_family_message.Encode(kRandomSequenceNumber));
+ EXPECT_CALL(netlink_socket_, SendMessage(_)).
+ WillOnce(Invoke(this, &NetlinkManagerTest::SendMessage));
+ EXPECT_CALL(netlink_socket_, file_descriptor()).WillRepeatedly(Return(0));
+ EXPECT_CALL(*sockets_, Select(_, _, _, _, _)).WillRepeatedly(Return(1));
+ EXPECT_CALL(netlink_socket_, RecvMessage(_)).
+ WillOnce(Invoke(this, &NetlinkManagerTest::ReplyWithRandomMessage)).
+ WillOnce(Invoke(this, &NetlinkManagerTest::ReplyToSentMessage));
+ NetlinkMessageFactory::FactoryMethod null_factory;
+ EXPECT_EQ(kSampleMessageType, netlink_manager_->GetFamily(kSampleMessageName,
+ null_factory));
+}
+
+TEST_F(NetlinkManagerTest, GetFamilyTimeout) {
+ Reset();
+ MockTime time;
+ netlink_manager_->time_ = &time;
+
+ EXPECT_CALL(netlink_socket_, SendMessage(_)).WillOnce(Return(true));
+ time_t kStartSeconds = 1234; // Arbitrary.
+ suseconds_t kSmallUsec = 100;
+ EXPECT_CALL(time, GetTimeMonotonic(_)).
+ WillOnce(Invoke(TimeFunctor(kStartSeconds, 0))). // Initial time.
+ WillOnce(Invoke(TimeFunctor(kStartSeconds, kSmallUsec))).
+ WillOnce(Invoke(TimeFunctor(kStartSeconds, 2 * kSmallUsec))).
+ WillOnce(Invoke(TimeFunctor(
+ kStartSeconds + NetlinkManager::kMaximumNewFamilyWaitSeconds + 1,
+ NetlinkManager::kMaximumNewFamilyWaitMicroSeconds)));
+ EXPECT_CALL(netlink_socket_, file_descriptor()).WillRepeatedly(Return(0));
+ EXPECT_CALL(*sockets_, Select(_, _, _, _, _)).WillRepeatedly(Return(1));
+ EXPECT_CALL(netlink_socket_, RecvMessage(_)).
+ WillRepeatedly(Invoke(this, &NetlinkManagerTest::ReplyWithRandomMessage));
+ NetlinkMessageFactory::FactoryMethod null_factory;
+
+ const string kSampleMessageName("SampleMessageName");
+ EXPECT_EQ(NetlinkMessage::kIllegalMessageType,
+ netlink_manager_->GetFamily(kSampleMessageName, null_factory));
+}
+
+TEST_F(NetlinkManagerTest, BroadcastHandlerTest) {
nlmsghdr *message = const_cast<nlmsghdr *>(
reinterpret_cast<const nlmsghdr *>(kNL80211_CMD_DISCONNECT));
@@ -157,9 +348,7 @@
}
TEST_F(NetlinkManagerTest, MessageHandlerTest) {
- // Setup.
- SetupNetlinkManagerObject();
-
+ Reset();
MockHandler80211 handler_broadcast;
EXPECT_TRUE(netlink_manager_->AddBroadcastHandler(
handler_broadcast.on_netlink_message()));
@@ -186,10 +375,11 @@
netlink_manager_->OnNlMessageReceived(received_message);
// Send the message and give our handler. Verify that we get called back.
+ EXPECT_CALL(netlink_socket_, SendMessage(_)).WillOnce(Return(true));
EXPECT_TRUE(netlink_manager_->SendMessage(
&sent_message_1, handler_sent_1.on_netlink_message()));
// Make it appear that this message is in response to our sent message.
- received_message->nlmsg_seq = socket_.GetLastSequenceNumber();
+ received_message->nlmsg_seq = netlink_socket_.GetLastSequenceNumber();
EXPECT_CALL(handler_sent_1, OnNetlinkMessage(_)).Times(1);
netlink_manager_->OnNlMessageReceived(received_message);
@@ -200,15 +390,17 @@
// Install and then uninstall message-specific handler; verify broadcast
// handler is called on message receipt.
+ EXPECT_CALL(netlink_socket_, SendMessage(_)).WillOnce(Return(true));
EXPECT_TRUE(netlink_manager_->SendMessage(
&sent_message_1, handler_sent_1.on_netlink_message()));
- received_message->nlmsg_seq = socket_.GetLastSequenceNumber();
+ received_message->nlmsg_seq = netlink_socket_.GetLastSequenceNumber();
EXPECT_TRUE(netlink_manager_->RemoveMessageHandler(sent_message_1));
EXPECT_CALL(handler_broadcast, OnNetlinkMessage(_)).Times(1);
netlink_manager_->OnNlMessageReceived(received_message);
// Install handler for different message; verify that broadcast handler is
// called for _this_ message.
+ EXPECT_CALL(netlink_socket_, SendMessage(_)).WillOnce(Return(true));
EXPECT_TRUE(netlink_manager_->SendMessage(
&sent_message_2, handler_sent_2.on_netlink_message()));
EXPECT_CALL(handler_broadcast, OnNetlinkMessage(_)).Times(1);
@@ -216,7 +408,7 @@
// Change the ID for the message to that of the second handler; verify that
// the appropriate handler is called for _that_ message.
- received_message->nlmsg_seq = socket_.GetLastSequenceNumber();
+ received_message->nlmsg_seq = netlink_socket_.GetLastSequenceNumber();
EXPECT_CALL(handler_sent_2, OnNetlinkMessage(_)).Times(1);
netlink_manager_->OnNlMessageReceived(received_message);
}
diff --git a/netlink_socket.h b/netlink_socket.h
index e15245c..f9ddcae 100644
--- a/netlink_socket.h
+++ b/netlink_socket.h
@@ -52,7 +52,7 @@
bool Init();
// Returns the file descriptor used by the socket.
- int file_descriptor() const { return file_descriptor_; }
+ virtual int file_descriptor() const { return file_descriptor_; }
// Get the next message sequence number for this socket.
// |GetSequenceNumber| won't return zero because that is the 'broadcast'
@@ -62,7 +62,7 @@
// Reads data from the socket into |message| and returns true if successful.
// The |message| parameter will be resized to hold the entirety of the read
// message (and any data in |message| will be overwritten).
- bool RecvMessage(ByteString *message);
+ virtual bool RecvMessage(ByteString *message);
// Sends a message, returns true if successful.
virtual bool SendMessage(const ByteString &message);
@@ -70,10 +70,13 @@
// Subscribes to netlink broadcast events.
virtual bool SubscribeToEvents(uint32_t group_id);
+ virtual const Sockets *sockets() const { return sockets_.get(); }
+
protected:
uint32_t sequence_number_;
private:
+ friend class NetlinkManagerTest;
friend class NetlinkSocketTest;
FRIEND_TEST(NetlinkSocketTest, SequenceNumberTest);
diff --git a/shill_unittest.cc b/shill_unittest.cc
index 8c78409..19eb37d 100644
--- a/shill_unittest.cc
+++ b/shill_unittest.cc
@@ -18,9 +18,11 @@
#include "shill/mock_dhcp_provider.h"
#include "shill/mock_manager.h"
#include "shill/mock_metrics.h"
+#include "shill/mock_netlink_manager.h"
#include "shill/mock_proxy_factory.h"
-#include "shill/mock_rtnl_handler.h"
#include "shill/mock_routing_table.h"
+#include "shill/mock_rtnl_handler.h"
+#include "shill/nl80211_message.h"
#include "shill/shill_daemon.h"
#include "shill/shill_test_config.h"
@@ -202,7 +204,12 @@
daemon_.dhcp_provider_ = &dhcp_provider_;
daemon_.metrics_.reset(metrics_); // Passes ownership
daemon_.manager_.reset(manager_); // Passes ownership
+ daemon_.netlink_manager_ = &netlink_manager_;
dispatcher_test_.ScheduleFailSafe();
+
+ const uint16_t kNl80211MessageType = 42; // Arbitrary.
+ ON_CALL(netlink_manager_, GetFamily(Nl80211Message::kMessageTypeString, _)).
+ WillByDefault(Return(kNl80211MessageType));
}
void StartDaemon() {
daemon_.Start();
@@ -227,6 +234,7 @@
MockDHCPProvider dhcp_provider_;
MockMetrics *metrics_;
MockManager *manager_;
+ MockNetlinkManager netlink_manager_;
DeviceInfo device_info_;
EventDispatcher *dispatcher_;
StrictMock<MockEventDispatchTester> dispatcher_test_;
diff --git a/sockets.cc b/sockets.cc
index 4cb7730..bb7ddae 100644
--- a/sockets.cc
+++ b/sockets.cc
@@ -18,6 +18,8 @@
namespace shill {
+Sockets::Sockets() {}
+
Sockets::~Sockets() {}
// Some system calls can be interrupted and return EINTR, but will succeed on
@@ -103,6 +105,14 @@
return HANDLE_EINTR(recvfrom(sockfd, buf, len, flags, src_addr, addrlen));
}
+int Sockets::Select(int nfds,
+ fd_set *readfds,
+ fd_set *writefds,
+ fd_set *exceptfds,
+ struct timeval *timeout) const {
+ return HANDLE_EINTR(select(nfds, readfds, writefds, exceptfds, timeout));
+}
+
ssize_t Sockets::Send(int sockfd,
const void *buf,
size_t len,
diff --git a/sockets.h b/sockets.h
index 19bb18f..17e9ff8 100644
--- a/sockets.h
+++ b/sockets.h
@@ -5,12 +5,12 @@
#ifndef SHILL_SOCKETS_H_
#define SHILL_SOCKETS_H_
-#include <string>
-
#include <linux/filter.h>
#include <sys/socket.h>
#include <sys/types.h>
+#include <string>
+
#include <base/basictypes.h>
namespace shill {
@@ -18,6 +18,7 @@
// A "sys/socket.h" abstraction allowing mocking in tests.
class Sockets {
public:
+ Sockets();
virtual ~Sockets();
// accept
@@ -68,6 +69,13 @@
virtual ssize_t RecvFrom(int sockfd, void *buf, size_t len, int flags,
struct sockaddr *src_addr, socklen_t *addrlen) const;
+ // select
+ virtual int Select(int nfds,
+ fd_set *readfds,
+ fd_set *writefds,
+ fd_set *exceptfds,
+ struct timeval *timeout) const;
+
// send
virtual ssize_t Send(int sockfd,
const void *buf,
@@ -93,6 +101,9 @@
// socket
virtual int Socket(int domain, int type, int protocol) const;
+
+ private:
+ DISALLOW_COPY_AND_ASSIGN(Sockets);
};
class ScopedSocketCloser {