shill: Create an asynchronous resolver object
Resolve DNS requests using the c-ares library but
using the shill event loop to handle events.
BUG=chromium-os:21664
TEST=New unit test
Change-Id: I99776b6cc74977d31198c67357c42a75f4047942
Reviewed-on: https://gerrit.chromium.org/gerrit/10328
Reviewed-by: Chris Masone <cmasone@chromium.org>
Tested-by: Paul Stewart <pstew@chromium.org>
diff --git a/Makefile b/Makefile
index b20ad66..d83ea97 100644
--- a/Makefile
+++ b/Makefile
@@ -12,7 +12,7 @@
# libevent, gdk and gtk-2.0 are needed to leverage chrome's MessageLoop
# TODO(cmasone): explore if newer versions of libbase let us avoid this.
-BASE_LIBS = -lbase -lchromeos -levent -lpthread -lrt
+BASE_LIBS = -lbase -lchromeos -levent -lpthread -lrt -lcares
BASE_INCLUDE_DIRS = -I..
BASE_LIB_DIRS =
@@ -88,6 +88,7 @@
dhcp_config.o \
dhcp_provider.o \
dhcpcd_proxy.o \
+ dns_client.o \
endpoint.o \
ephemeral_profile.o \
error.o \
@@ -124,9 +125,11 @@
rtnl_message.o \
service.o \
service_dbus_adaptor.o \
+ shill_ares.o \
shill_config.o \
shill_daemon.o \
shill_test_config.o \
+ shill_time.o \
sockets.o \
supplicant_interface_proxy.o \
supplicant_process_proxy.o \
@@ -154,12 +157,14 @@
device_unittest.o \
dhcp_config_unittest.o \
dhcp_provider_unittest.o \
+ dns_client_unittest.o \
error_unittest.o \
ip_address_unittest.o \
ipconfig_unittest.o \
key_file_store_unittest.o \
manager_unittest.o \
mock_adaptors.o \
+ mock_ares.o \
mock_control.o \
mock_dbus_properties_proxy.o \
mock_device.o \
@@ -167,6 +172,7 @@
mock_dhcp_config.o \
mock_dhcp_provider.o \
mock_dhcp_proxy.o \
+ mock_event_dispatcher.o \
mock_glib.o \
mock_ipconfig.o \
mock_manager.o \
@@ -186,6 +192,7 @@
mock_store.o \
mock_supplicant_interface_proxy.o \
mock_supplicant_process_proxy.o \
+ mock_time.o \
mock_wifi.o \
modem_info_unittest.o \
modem_manager_unittest.o \
diff --git a/dns_client.cc b/dns_client.cc
new file mode 100644
index 0000000..ba1961d
--- /dev/null
+++ b/dns_client.cc
@@ -0,0 +1,320 @@
+// Copyright (c) 2011 The Chromium OS Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "shill/dns_client.h"
+
+#include <arpa/inet.h>
+#include <netdb.h>
+#include <netinet/in.h>
+#include <sys/socket.h>
+
+#include <map>
+#include <set>
+#include <string>
+#include <tr1/memory>
+#include <vector>
+
+#include <base/stl_util-inl.h>
+
+#include <shill/shill_ares.h>
+#include <shill/shill_time.h>
+
+using std::map;
+using std::set;
+using std::string;
+using std::vector;
+
+namespace shill {
+
+const int DNSClient::kDefaultTimeoutMS = 2000;
+const char DNSClient::kErrorNoData[] = "The query response contains no answers";
+const char DNSClient::kErrorFormErr[] = "The server says the query is bad";
+const char DNSClient::kErrorServerFail[] = "The server says it had a failure";
+const char DNSClient::kErrorNotFound[] = "The queried-for domain was not found";
+const char DNSClient::kErrorNotImp[] = "The server doesn't implement operation";
+const char DNSClient::kErrorRefused[] = "The server replied, refused the query";
+const char DNSClient::kErrorBadQuery[] = "Locally we could not format a query";
+const char DNSClient::kErrorNetRefused[] = "The network connection was refused";
+const char DNSClient::kErrorTimedOut[] = "The network connection was timed out";
+const char DNSClient::kErrorUnknown[] = "DNS Resolver unknown internal error";
+
+// Private to the implementation of resolver so callers don't include ares.h
+struct DNSClientState {
+ ares_channel channel;
+ map< ares_socket_t, std::tr1::shared_ptr<IOHandler> > read_handlers;
+ map< ares_socket_t, std::tr1::shared_ptr<IOHandler> > write_handlers;
+ struct timeval start_time_;
+};
+
+DNSClient::DNSClient(IPAddress::Family family,
+ const string &interface_name,
+ const vector<string> &dns_servers,
+ int timeout_ms,
+ EventDispatcher *dispatcher,
+ Callback1<bool>::Type *callback)
+ : address_(IPAddress(family)),
+ interface_name_(interface_name),
+ dns_servers_(dns_servers),
+ dispatcher_(dispatcher),
+ callback_(callback),
+ timeout_ms_(timeout_ms),
+ running_(false),
+ resolver_state_(NULL),
+ read_callback_(NewCallback(this, &DNSClient::HandleDNSRead)),
+ write_callback_(NewCallback(this, &DNSClient::HandleDNSWrite)),
+ task_factory_(this),
+ ares_(Ares::GetInstance()),
+ time_(Time::GetInstance()) {}
+
+DNSClient::~DNSClient() {
+ Stop();
+}
+
+bool DNSClient::Start(const string &hostname) {
+ if (running_) {
+ LOG(ERROR) << "Only one DNS request is allowed at a time";
+ return false;
+ }
+
+ if (!resolver_state_.get()) {
+ struct ares_options options;
+ memset(&options, 0, sizeof(options));
+
+ vector<struct in_addr> server_addresses;
+ for (vector<string>::iterator it = dns_servers_.begin();
+ it != dns_servers_.end();
+ ++it) {
+ struct in_addr addr;
+ if (inet_aton(it->c_str(), &addr) != 0) {
+ server_addresses.push_back(addr);
+ }
+ }
+
+ if (server_addresses.empty()) {
+ LOG(ERROR) << "No valid DNS server addresses";
+ return false;
+ }
+
+ options.servers = server_addresses.data();
+ options.nservers = server_addresses.size();
+ options.timeout = timeout_ms_;
+
+ resolver_state_.reset(new DNSClientState);
+ int status = ares_->InitOptions(&resolver_state_->channel,
+ &options,
+ ARES_OPT_SERVERS | ARES_OPT_TIMEOUTMS);
+ if (status != ARES_SUCCESS) {
+ LOG(ERROR) << "ARES initialization returns error code: " << status;
+ resolver_state_.reset();
+ return false;
+ }
+
+ ares_->SetLocalDev(resolver_state_->channel, interface_name_.c_str());
+ }
+
+ running_ = true;
+ time_->GetTimeOfDay(&resolver_state_->start_time_, NULL);
+ error_.clear();
+ ares_->GetHostByName(resolver_state_->channel, hostname.c_str(),
+ address_.family(), ReceiveDNSReplyCB, this);
+
+ if (!RefreshHandles()) {
+ LOG(ERROR) << "Impossibly short timeout.";
+ Stop();
+ return false;
+ }
+
+ return true;
+}
+
+void DNSClient::Stop() {
+ if (!resolver_state_.get()) {
+ return;
+ }
+
+ running_ = false;
+ task_factory_.RevokeAll();
+ ares_->Destroy(resolver_state_->channel);
+ resolver_state_.reset();
+}
+
+void DNSClient::HandleDNSRead(int fd) {
+ ares_->ProcessFd(resolver_state_->channel, fd, ARES_SOCKET_BAD);
+ RefreshHandles();
+}
+
+void DNSClient::HandleDNSWrite(int fd) {
+ ares_->ProcessFd(resolver_state_->channel, ARES_SOCKET_BAD, fd);
+ RefreshHandles();
+}
+
+void DNSClient::HandleTimeout() {
+ ares_->ProcessFd(resolver_state_->channel, ARES_SOCKET_BAD, ARES_SOCKET_BAD);
+ if (!RefreshHandles()) {
+ // If we have timed out, ARES might still have sockets open.
+ // Force them closed by doing an explicit shutdown. This is
+ // different from HandleDNSRead and HandleDNSWrite where any
+ // change in our running_ state would be as a result of ARES
+ // itself and therefore properly synchronized with it: if a
+ // search completes during the course of ares_->ProcessFd(),
+ // the ARES fds and other state is guaranteed to have cleaned
+ // up and ready for a new request. Since this timeout is
+ // genererated outside of the library it is best to completely
+ // shutdown ARES and start with fresh state for a new request.
+ Stop();
+ }
+}
+
+void DNSClient::ReceiveDNSReply(int status, struct hostent *hostent) {
+ if (!running_) {
+ // We can be called during ARES shutdown -- ignore these events.
+ return;
+ }
+ running_ = false;
+
+ if (status == ARES_SUCCESS &&
+ hostent != NULL &&
+ hostent->h_addrtype == address_.family() &&
+ hostent->h_length == IPAddress::GetAddressLength(address_.family()) &&
+ hostent->h_addr_list != NULL &&
+ hostent->h_addr_list[0] != NULL) {
+ address_ = IPAddress(address_.family(),
+ ByteString(reinterpret_cast<unsigned char *>(
+ hostent->h_addr_list[0]), hostent->h_length));
+ callback_->Run(true);
+ } else {
+ switch (status) {
+ case ARES_ENODATA:
+ error_ = kErrorNoData;
+ break;
+ case ARES_EFORMERR:
+ error_ = kErrorFormErr;
+ break;
+ case ARES_ESERVFAIL:
+ error_ = kErrorServerFail;
+ break;
+ case ARES_ENOTFOUND:
+ error_ = kErrorNotFound;
+ break;
+ case ARES_ENOTIMP:
+ error_ = kErrorNotImp;
+ break;
+ case ARES_EREFUSED:
+ error_ = kErrorRefused;
+ break;
+ case ARES_EBADQUERY:
+ case ARES_EBADNAME:
+ case ARES_EBADFAMILY:
+ case ARES_EBADRESP:
+ error_ = kErrorBadQuery;
+ break;
+ case ARES_ECONNREFUSED:
+ error_ = kErrorNetRefused;
+ break;
+ case ARES_ETIMEOUT:
+ error_ = kErrorTimedOut;
+ break;
+ default:
+ error_ = kErrorUnknown;
+ if (status == ARES_SUCCESS) {
+ LOG(ERROR) << "ARES returned success but hostent was invalid!";
+ } else {
+ LOG(ERROR) << "ARES returned unhandled error status " << status;
+ }
+ break;
+ }
+ callback_->Run(false);
+ }
+}
+
+void DNSClient::ReceiveDNSReplyCB(void *arg, int status,
+ int /*timeouts*/,
+ struct hostent *hostent) {
+ DNSClient *res = static_cast<DNSClient *>(arg);
+ res->ReceiveDNSReply(status, hostent);
+}
+
+bool DNSClient::RefreshHandles() {
+ map< ares_socket_t, std::tr1::shared_ptr<IOHandler> > old_read =
+ resolver_state_->read_handlers;
+ map< ares_socket_t, std::tr1::shared_ptr<IOHandler> > old_write =
+ resolver_state_->write_handlers;
+
+ resolver_state_->read_handlers.clear();
+ resolver_state_->write_handlers.clear();
+
+ ares_socket_t sockets[ARES_GETSOCK_MAXNUM];
+ int action_bits = ares_->GetSock(resolver_state_->channel, sockets,
+ ARES_GETSOCK_MAXNUM);
+
+ for (int i = 0; i < ARES_GETSOCK_MAXNUM; i++) {
+ if (ARES_GETSOCK_READABLE(action_bits, i)) {
+ if (ContainsKey(old_read, sockets[i])) {
+ resolver_state_->read_handlers[sockets[i]] = old_read[sockets[i]];
+ } else {
+ resolver_state_->read_handlers[sockets[i]] =
+ std::tr1::shared_ptr<IOHandler> (
+ dispatcher_->CreateReadyHandler(sockets[i],
+ IOHandler::kModeInput,
+ read_callback_.get()));
+ }
+ }
+ if (ARES_GETSOCK_WRITABLE(action_bits, i)) {
+ if (ContainsKey(old_write, sockets[i])) {
+ resolver_state_->write_handlers[sockets[i]] = old_write[sockets[i]];
+ } else {
+ resolver_state_->write_handlers[sockets[i]] =
+ std::tr1::shared_ptr<IOHandler> (
+ dispatcher_->CreateReadyHandler(sockets[i],
+ IOHandler::kModeOutput,
+ write_callback_.get()));
+ }
+ }
+ }
+
+ if (!running_) {
+ // We are here just to clean up socket and timer handles, and the
+ // ARES state was cleaned up during the last call to ares_process_fd().
+ task_factory_.RevokeAll();
+ return false;
+ }
+
+ // Schedule timer event for the earlier of our timeout or one requested by
+ // the resolver library.
+ struct timeval now, elapsed_time, timeout_tv;
+ time_->GetTimeOfDay(&now, NULL);
+ timersub(&now, &resolver_state_->start_time_, &elapsed_time);
+ timeout_tv.tv_sec = timeout_ms_ / 1000;
+ timeout_tv.tv_usec = (timeout_ms_ % 1000) * 1000;
+ if (timercmp(&elapsed_time, &timeout_tv, >=)) {
+ // There are 3 cases of interest:
+ // - If we got here from Start(), we will have the side-effect of
+ // both invoking the callback and returning False in Start().
+ // Start() will call Stop() which will shut down ARES.
+ // - If we got here from the tail of an IO event (racing with the
+ // timer, we can't call Stop() since that will blow away the
+ // IOHandler we are running in, however we will soon be called
+ // again by the timeout proc so we can clean up the ARES state
+ // then.
+ // - If we got here from a timeout handler, it will safely call
+ // Stop() when we return false.
+ error_ = kErrorTimedOut;
+ callback_->Run(false);
+ running_ = false;
+ return false;
+ } else {
+ struct timeval max, ret_tv;
+ timersub(&timeout_tv, &elapsed_time, &max);
+ struct timeval *tv = ares_->Timeout(resolver_state_->channel,
+ &max, &ret_tv);
+ task_factory_.RevokeAll();
+ dispatcher_->PostDelayedTask(
+ task_factory_.NewRunnableMethod(&DNSClient::HandleTimeout),
+ tv->tv_sec * 1000 + tv->tv_usec / 1000);
+ }
+
+ return true;
+}
+
+} // namespace shill
diff --git a/dns_client.h b/dns_client.h
new file mode 100644
index 0000000..dd47784
--- /dev/null
+++ b/dns_client.h
@@ -0,0 +1,87 @@
+// Copyright (c) 2011 The Chromium OS Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef SHILL_DNS_CLIENT_
+#define SHILL_DNS_CLIENT_
+
+#include <string>
+#include <vector>
+
+#include <base/callback_old.h>
+#include <base/memory/scoped_ptr.h>
+#include <base/task.h>
+#include <gtest/gtest_prod.h> // for FRIEND_TEST
+
+#include "shill/event_dispatcher.h"
+#include "shill/ip_address.h"
+#include "shill/refptr_types.h"
+
+struct hostent;
+
+namespace shill {
+
+class Ares;
+class Time;
+struct DNSClientState;
+
+// Implements a DNS resolution client that can run asynchronously
+class DNSClient {
+ public:
+ static const int kDefaultTimeoutMS;
+ static const char kErrorNoData[];
+ static const char kErrorFormErr[];
+ static const char kErrorServerFail[];
+ static const char kErrorNotFound[];
+ static const char kErrorNotImp[];
+ static const char kErrorRefused[];
+ static const char kErrorBadQuery[];
+ static const char kErrorNetRefused[];
+ static const char kErrorTimedOut[];
+ static const char kErrorUnknown[];
+
+ DNSClient(IPAddress::Family family,
+ const std::string &interface_name,
+ const std::vector<std::string> &dns_servers,
+ int timeout_ms,
+ EventDispatcher *dispatcher,
+ Callback1<bool>::Type *callback);
+ ~DNSClient();
+
+ bool Start(const std::string &hostname);
+ void Stop();
+ const IPAddress &address() const { return address_; }
+ const std::string &error() const { return error_; }
+
+ private:
+ friend class DNSClientTest;
+
+ void HandleDNSRead(int fd);
+ void HandleDNSWrite(int fd);
+ void HandleTimeout();
+ void ReceiveDNSReply(int status, struct hostent *hostent);
+ static void ReceiveDNSReplyCB(void *arg, int status, int timeouts,
+ struct hostent *hostent);
+ bool RefreshHandles();
+
+ IPAddress address_;
+ std::string interface_name_;
+ std::vector<std::string> dns_servers_;
+ EventDispatcher *dispatcher_;
+ Callback1<bool>::Type *callback_;
+ int timeout_ms_;
+ bool running_;
+ std::string error_;
+ scoped_ptr<DNSClientState> resolver_state_;
+ scoped_ptr<Callback1<int>::Type> read_callback_;
+ scoped_ptr<Callback1<int>::Type> write_callback_;
+ ScopedRunnableMethodFactory<DNSClient> task_factory_;
+ Ares *ares_;
+ Time *time_;
+
+ DISALLOW_COPY_AND_ASSIGN(DNSClient);
+};
+
+} // namespace shill
+
+#endif // SHILL_DNS_CLIENT_
diff --git a/dns_client_unittest.cc b/dns_client_unittest.cc
new file mode 100644
index 0000000..2df51cc
--- /dev/null
+++ b/dns_client_unittest.cc
@@ -0,0 +1,356 @@
+// Copyright (c) 2011 The Chromium OS Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "shill/dns_client.h"
+
+#include <netdb.h>
+
+#include <string>
+#include <vector>
+
+#include <base/memory/scoped_ptr.h>
+#include <gtest/gtest.h>
+#include <gmock/gmock.h>
+
+#include "shill/event_dispatcher.h"
+#include "shill/io_handler.h"
+#include "shill/mock_ares.h"
+#include "shill/mock_control.h"
+#include "shill/mock_event_dispatcher.h"
+#include "shill/mock_time.h"
+
+using std::string;
+using std::vector;
+using testing::_;
+using testing::DoAll;
+using testing::Return;
+using testing::ReturnArg;
+using testing::ReturnNew;
+using testing::Test;
+using testing::SetArgumentPointee;
+using testing::StrEq;
+using testing::StrictMock;
+
+namespace shill {
+
+namespace {
+const char kGoodName[] = "all-systems.mcast.net";
+const char kResult[] = "224.0.0.1";
+const char kGoodServer[] = "8.8.8.8";
+const char kBadServer[] = "10.9xx8.7";
+const char kNetworkInterface[] = "eth0";
+char kReturnAddressList0[] = { 224, 0, 0, 1 };
+char *kReturnAddressList[] = { kReturnAddressList0, NULL };
+char kFakeAresChannelData = 0;
+const ares_channel kAresChannel =
+ reinterpret_cast<ares_channel>(&kFakeAresChannelData);
+const int kAresFd = 10203;
+const int kAresTimeoutMS = 2000; // ARES transaction timeout
+const int kAresWaitMS = 1000; // Time period ARES asks caller to wait
+} // namespace {}
+
+class DNSClientTest : public Test {
+ public:
+ DNSClientTest() : ares_result_(ARES_SUCCESS) {
+ time_val_.tv_sec = 0;
+ time_val_.tv_usec = 0;
+ ares_timeout_.tv_sec = kAresWaitMS / 1000;
+ ares_timeout_.tv_usec = (kAresWaitMS % 1000) * 1000;
+ hostent_.h_addrtype = IPAddress::kFamilyIPv4;
+ hostent_.h_length = sizeof(kReturnAddressList0);
+ hostent_.h_addr_list = kReturnAddressList;
+ }
+
+ virtual void SetUp() {
+ EXPECT_CALL(time_, GetTimeOfDay(_, _))
+ .WillRepeatedly(DoAll(SetArgumentPointee<0>(time_val_), Return(0)));
+ SetInActive();
+ }
+
+ virtual void TearDown() {
+ // We need to make sure the dns_client instance releases ares_
+ // before the destructor for DNSClientTest deletes ares_.
+ if (dns_client_.get()) {
+ dns_client_->Stop();
+ }
+ }
+
+ void AdvanceTime(int time_ms) {
+ struct timeval adv_time = { time_ms/1000, (time_ms % 1000) * 1000 };
+ timeradd(&time_val_, &adv_time, &time_val_);
+ EXPECT_CALL(time_, GetTimeOfDay(_, _))
+ .WillRepeatedly(DoAll(SetArgumentPointee<0>(time_val_), Return(0)));
+ }
+
+ void CallReplyCB() {
+ dns_client_->ReceiveDNSReplyCB(dns_client_.get(), ares_result_, 0,
+ &hostent_);
+ }
+
+ void CallDNSRead() {
+ dns_client_->HandleDNSRead(kAresFd);
+ }
+
+ void CallDNSWrite() {
+ dns_client_->HandleDNSWrite(kAresFd);
+ }
+
+ void CallTimeout() {
+ dns_client_->HandleTimeout();
+ }
+
+ void CreateClient(const vector<string> &dns_servers, int timeout_ms) {
+ dns_client_.reset(new DNSClient(IPAddress::kFamilyIPv4,
+ kNetworkInterface,
+ dns_servers,
+ timeout_ms,
+ &dispatcher_,
+ callback_target_.callback()));
+ dns_client_->ares_ = &ares_;
+ dns_client_->time_ = &time_;
+ }
+
+ void SetActive() {
+ // Returns that socket kAresFd is readable.
+ EXPECT_CALL(ares_, GetSock(_, _, _))
+ .WillRepeatedly(DoAll(SetArgumentPointee<1>(kAresFd), Return(1)));
+ EXPECT_CALL(ares_, Timeout(_, _, _))
+ .WillRepeatedly(
+ DoAll(SetArgumentPointee<2>(ares_timeout_), ReturnArg<2>()));
+ }
+
+ void SetInActive() {
+ EXPECT_CALL(ares_, GetSock(_, _, _))
+ .WillRepeatedly(Return(0));
+ EXPECT_CALL(ares_, Timeout(_, _, _))
+ .WillRepeatedly(ReturnArg<1>());
+ }
+
+ void SetupRequest(const string &name, const string &server) {
+ vector<string> dns_servers;
+ dns_servers.push_back(server);
+ CreateClient(dns_servers, kAresTimeoutMS);
+ // These expectations are fulfilled when dns_client_->Start() is called.
+ EXPECT_CALL(ares_, InitOptions(_, _, _))
+ .WillOnce(DoAll(SetArgumentPointee<0>(kAresChannel),
+ Return(ARES_SUCCESS)));
+ EXPECT_CALL(ares_, SetLocalDev(kAresChannel, StrEq(kNetworkInterface)))
+ .Times(1);
+ EXPECT_CALL(ares_, GetHostByName(kAresChannel, StrEq(name), _, _, _));
+ }
+
+ void StartValidRequest() {
+ SetupRequest(kGoodName, kGoodServer);
+ EXPECT_CALL(dispatcher_,
+ CreateReadyHandler(kAresFd, IOHandler::kModeInput, _))
+ .WillOnce(ReturnNew<IOHandler>());
+ SetActive();
+ EXPECT_CALL(dispatcher_, PostDelayedTask(_, kAresWaitMS));
+ ASSERT_TRUE(dns_client_->Start(kGoodName));
+ EXPECT_CALL(ares_, Destroy(kAresChannel));
+ }
+
+ void TestValidCompletion() {
+ EXPECT_CALL(callback_target_, CallTarget(true));
+ EXPECT_CALL(ares_, ProcessFd(kAresChannel, kAresFd, ARES_SOCKET_BAD))
+ .WillOnce(InvokeWithoutArgs(this, &DNSClientTest::CallReplyCB));
+ CallDNSRead();
+ ASSERT_TRUE(dns_client_->address().IsValid());
+ IPAddress ipaddr(dns_client_->address().family());
+ ASSERT_TRUE(ipaddr.SetAddressFromString(kResult));
+ EXPECT_TRUE(ipaddr.Equals(dns_client_->address()));
+ }
+
+ protected:
+ class DNSCallbackTarget {
+ public:
+ DNSCallbackTarget()
+ : callback_(NewCallback(this, &DNSCallbackTarget::CallTarget)) {}
+
+ MOCK_METHOD1(CallTarget, void(bool success));
+ Callback1<bool>::Type *callback() { return callback_.get(); }
+
+ private:
+ scoped_ptr<Callback1<bool>::Type> callback_;
+ };
+
+ scoped_ptr<DNSClient> dns_client_;
+ MockEventDispatcher dispatcher_;
+ string queued_request_;
+ StrictMock<DNSCallbackTarget> callback_target_;
+ StrictMock<MockAres> ares_;
+ StrictMock<MockTime> time_;
+ struct timeval time_val_;
+ struct timeval ares_timeout_;
+ struct hostent hostent_;
+ int ares_result_;
+};
+
+class SentinelIOHandler : public IOHandler {
+ public:
+ MOCK_METHOD0(Die, void());
+ virtual ~SentinelIOHandler() { Die(); }
+};
+
+TEST_F(DNSClientTest, Constructor) {
+ vector<string> dns_servers;
+ dns_servers.push_back(kGoodServer);
+ CreateClient(dns_servers, kAresTimeoutMS);
+ EXPECT_TRUE(dns_client_->address().family() == IPAddress::kFamilyIPv4);
+ EXPECT_TRUE(dns_client_->address().IsDefault());
+}
+
+// Receive error because no DNS servers were specified.
+TEST_F(DNSClientTest, NoServers) {
+ CreateClient(vector<string>(), kAresTimeoutMS);
+ EXPECT_FALSE(dns_client_->Start(kGoodName));
+}
+
+// Receive error because the DNS server IP address is invalid.
+TEST_F(DNSClientTest, TimeoutInvalidServer) {
+ vector<string> dns_servers;
+ dns_servers.push_back(kBadServer);
+ CreateClient(dns_servers, kAresTimeoutMS);
+ ASSERT_FALSE(dns_client_->Start(kGoodName));
+}
+
+// Setup error because InitOptions failed.
+TEST_F(DNSClientTest, InitOptionsFailure) {
+ vector<string> dns_servers;
+ dns_servers.push_back(kGoodServer);
+ CreateClient(dns_servers, kAresTimeoutMS);
+ EXPECT_CALL(ares_, InitOptions(_, _, _))
+ .WillOnce(Return(ARES_EBADFLAGS));
+ EXPECT_FALSE(dns_client_->Start(kGoodName));
+}
+
+// Fail a second request because one is already in progress.
+TEST_F(DNSClientTest, MultipleRequest) {
+ StartValidRequest();
+ ASSERT_FALSE(dns_client_->Start(kGoodName));
+}
+
+TEST_F(DNSClientTest, GoodRequest) {
+ StartValidRequest();
+ TestValidCompletion();
+}
+
+TEST_F(DNSClientTest, GoodRequestWithTimeout) {
+ StartValidRequest();
+ // Insert an intermediate HandleTimeout callback.
+ AdvanceTime(kAresWaitMS);
+ EXPECT_CALL(ares_, ProcessFd(kAresChannel, ARES_SOCKET_BAD, ARES_SOCKET_BAD));
+ EXPECT_CALL(dispatcher_, PostDelayedTask(_, kAresWaitMS));
+ CallTimeout();
+ AdvanceTime(kAresWaitMS);
+ TestValidCompletion();
+}
+
+TEST_F(DNSClientTest, GoodRequestWithDNSRead) {
+ StartValidRequest();
+ // Insert an intermediate HandleDNSRead callback.
+ AdvanceTime(kAresWaitMS);
+ EXPECT_CALL(ares_, ProcessFd(kAresChannel, kAresFd, ARES_SOCKET_BAD));
+ EXPECT_CALL(dispatcher_, PostDelayedTask(_, kAresWaitMS));
+ CallDNSRead();
+ AdvanceTime(kAresWaitMS);
+ TestValidCompletion();
+}
+
+TEST_F(DNSClientTest, GoodRequestWithDNSWrite) {
+ StartValidRequest();
+ // Insert an intermediate HandleDNSWrite callback.
+ AdvanceTime(kAresWaitMS);
+ EXPECT_CALL(ares_, ProcessFd(kAresChannel, ARES_SOCKET_BAD, kAresFd));
+ EXPECT_CALL(dispatcher_, PostDelayedTask(_, kAresWaitMS));
+ CallDNSWrite();
+ AdvanceTime(kAresWaitMS);
+ TestValidCompletion();
+}
+
+// Failure due to the timeout occurring during first call to RefreshHandles.
+TEST_F(DNSClientTest, TimeoutFirstRefresh) {
+ SetupRequest(kGoodName, kGoodServer);
+ struct timeval init_time_val = time_val_;
+ AdvanceTime(kAresTimeoutMS);
+ EXPECT_CALL(time_, GetTimeOfDay(_, _))
+ .WillOnce(DoAll(SetArgumentPointee<0>(init_time_val), Return(0)))
+ .WillRepeatedly(DoAll(SetArgumentPointee<0>(time_val_), Return(0)));
+ EXPECT_CALL(callback_target_, CallTarget(false));
+ EXPECT_CALL(ares_, Destroy(kAresChannel));
+ ASSERT_FALSE(dns_client_->Start(kGoodName));
+ EXPECT_EQ(string(DNSClient::kErrorTimedOut), dns_client_->error());
+}
+
+// Failed request due to timeout within the dns_client.
+TEST_F(DNSClientTest, TimeoutDispatcherEvent) {
+ StartValidRequest();
+ EXPECT_CALL(ares_, ProcessFd(kAresChannel,
+ ARES_SOCKET_BAD, ARES_SOCKET_BAD));
+ AdvanceTime(kAresTimeoutMS);
+ EXPECT_CALL(callback_target_, CallTarget(false));
+ CallTimeout();
+}
+
+// Failed request due to timeout reported by ARES.
+TEST_F(DNSClientTest, TimeoutFromARES) {
+ StartValidRequest();
+ AdvanceTime(kAresWaitMS);
+ ares_result_ = ARES_ETIMEOUT;
+ EXPECT_CALL(ares_, ProcessFd(kAresChannel, ARES_SOCKET_BAD, ARES_SOCKET_BAD))
+ .WillOnce(InvokeWithoutArgs(this, &DNSClientTest::CallReplyCB));
+ EXPECT_CALL(callback_target_, CallTarget(false));
+ CallTimeout();
+ EXPECT_EQ(string(DNSClient::kErrorTimedOut), dns_client_->error());
+}
+
+// Failed request due to "host not found" reported by ARES.
+TEST_F(DNSClientTest, HostNotFound) {
+ StartValidRequest();
+ AdvanceTime(kAresWaitMS);
+ ares_result_ = ARES_ENOTFOUND;
+ EXPECT_CALL(ares_, ProcessFd(kAresChannel, kAresFd, ARES_SOCKET_BAD))
+ .WillOnce(InvokeWithoutArgs(this, &DNSClientTest::CallReplyCB));
+ EXPECT_CALL(callback_target_, CallTarget(false));
+ CallDNSRead();
+ EXPECT_EQ(string(DNSClient::kErrorNotFound), dns_client_->error());
+}
+
+// Make sure IOHandles are deallocated when GetSock() reports them gone.
+TEST_F(DNSClientTest, IOHandleDeallocGetSock) {
+ SetupRequest(kGoodName, kGoodServer);
+ // This isn't any kind of scoped/ref pointer because we are tracking dealloc.
+ SentinelIOHandler *io_handler = new SentinelIOHandler();
+ EXPECT_CALL(dispatcher_,
+ CreateReadyHandler(kAresFd, IOHandler::kModeInput, _))
+ .WillOnce(Return(io_handler));
+ EXPECT_CALL(dispatcher_, PostDelayedTask(_, kAresWaitMS));
+ SetActive();
+ ASSERT_TRUE(dns_client_->Start(kGoodName));
+ AdvanceTime(kAresWaitMS);
+ SetInActive();
+ EXPECT_CALL(*io_handler, Die());
+ EXPECT_CALL(ares_, ProcessFd(kAresChannel, kAresFd, ARES_SOCKET_BAD));
+ EXPECT_CALL(dispatcher_, PostDelayedTask(_, kAresWaitMS));
+ CallDNSRead();
+ EXPECT_CALL(ares_, Destroy(kAresChannel));
+}
+
+// Make sure IOHandles are deallocated when Stop() is called.
+TEST_F(DNSClientTest, IOHandleDeallocStop) {
+ SetupRequest(kGoodName, kGoodServer);
+ // This isn't any kind of scoped/ref pointer because we are tracking dealloc.
+ SentinelIOHandler *io_handler = new SentinelIOHandler();
+ EXPECT_CALL(dispatcher_,
+ CreateReadyHandler(kAresFd, IOHandler::kModeInput, _))
+ .WillOnce(Return(io_handler));
+ EXPECT_CALL(dispatcher_, PostDelayedTask(_, kAresWaitMS));
+ SetActive();
+ ASSERT_TRUE(dns_client_->Start(kGoodName));
+ EXPECT_CALL(*io_handler, Die());
+ EXPECT_CALL(ares_, Destroy(kAresChannel));
+ dns_client_->Stop();
+}
+
+} // namespace shill
diff --git a/event_dispatcher.h b/event_dispatcher.h
index 5ca886e..14724bc 100644
--- a/event_dispatcher.h
+++ b/event_dispatcher.h
@@ -29,15 +29,15 @@
EventDispatcher();
virtual ~EventDispatcher();
- void DispatchForever();
+ virtual void DispatchForever();
// Processes all pending events that can run and returns.
- void DispatchPendingEvents();
+ virtual void DispatchPendingEvents();
// These are thin wrappers around calls of the same name in
// <base/message_loop_proxy.h>
- bool PostTask(Task *task);
- bool PostDelayedTask(Task *task, int64 delay_ms);
+ virtual bool PostTask(Task *task);
+ virtual bool PostDelayedTask(Task *task, int64 delay_ms);
virtual IOHandler *CreateInputHandler(
int fd,
diff --git a/glib_io_ready_handler.cc b/glib_io_ready_handler.cc
index 3109fc2..bb4fa7c 100644
--- a/glib_io_ready_handler.cc
+++ b/glib_io_ready_handler.cc
@@ -46,7 +46,7 @@
GlibIOReadyHandler::~GlibIOReadyHandler() {
g_source_remove(source_id_);
- g_io_channel_shutdown(channel_, TRUE, NULL);
+ // NB: We don't shut down the channel since we don't own it
g_io_channel_unref(channel_);
}
diff --git a/glib_io_ready_handler.h b/glib_io_ready_handler.h
index 919c8c3..3a3d310 100644
--- a/glib_io_ready_handler.h
+++ b/glib_io_ready_handler.h
@@ -20,9 +20,9 @@
// sockets and effort to working with peripheral libraries.
class GlibIOReadyHandler : public IOHandler {
public:
- GlibIOReadyHandler(int fd,
- IOHandler::ReadyMode mode,
- Callback1<int>::Type *callback);
+ GlibIOReadyHandler(int fd,
+ IOHandler::ReadyMode mode,
+ Callback1<int>::Type *callback);
~GlibIOReadyHandler();
virtual void Start();
diff --git a/mock_ares.cc b/mock_ares.cc
new file mode 100644
index 0000000..e36e071
--- /dev/null
+++ b/mock_ares.cc
@@ -0,0 +1,13 @@
+// Copyright (c) 2011 The Chromium OS Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "shill/mock_ares.h"
+
+namespace shill {
+
+MockAres::MockAres() {}
+
+MockAres::~MockAres() {}
+
+} // namespace shill
diff --git a/mock_ares.h b/mock_ares.h
new file mode 100644
index 0000000..0a89de2
--- /dev/null
+++ b/mock_ares.h
@@ -0,0 +1,47 @@
+// Copyright (c) 2011 The Chromium OS Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef SHILL_MOCK_ARES_H_
+#define SHILL_MOCK_ARES_H_
+
+#include <base/basictypes.h>
+#include <gmock/gmock.h>
+
+#include "shill/shill_ares.h"
+
+namespace shill {
+
+class MockAres : public Ares {
+ public:
+ MockAres();
+ virtual ~MockAres();
+
+ MOCK_METHOD1(Destroy, void(ares_channel channel));
+ MOCK_METHOD5(GetHostByName, void(ares_channel channel,
+ const char *hostname,
+ int family,
+ ares_host_callback callback,
+ void *arg));
+ MOCK_METHOD3(GetSock, int(ares_channel channel,
+ ares_socket_t *socks,
+ int numsocks));
+ MOCK_METHOD3(InitOptions, int(ares_channel *channelptr,
+ struct ares_options *options,
+ int optmask));
+ MOCK_METHOD3(ProcessFd, void(ares_channel channel,
+ ares_socket_t read_fd,
+ ares_socket_t write_fd));
+ MOCK_METHOD2(SetLocalDev, void(ares_channel channel,
+ const char *local_dev_name));
+ MOCK_METHOD3(Timeout, struct timeval *(ares_channel channel,
+ struct timeval *maxtv,
+ struct timeval *tv));
+
+ private:
+ DISALLOW_COPY_AND_ASSIGN(MockAres);
+};
+
+} // namespace shill
+
+#endif // SHILL_MOCK_ARES_H_
diff --git a/mock_event_dispatcher.cc b/mock_event_dispatcher.cc
new file mode 100644
index 0000000..b9c8fcf
--- /dev/null
+++ b/mock_event_dispatcher.cc
@@ -0,0 +1,13 @@
+// Copyright (c) 2011 The Chromium OS Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "shill/mock_event_dispatcher.h"
+
+namespace shill {
+
+MockEventDispatcher::MockEventDispatcher() {}
+
+MockEventDispatcher::~MockEventDispatcher() {}
+
+} // namespace shill
diff --git a/mock_event_dispatcher.h b/mock_event_dispatcher.h
new file mode 100644
index 0000000..f30c37d
--- /dev/null
+++ b/mock_event_dispatcher.h
@@ -0,0 +1,38 @@
+// Copyright (c) 2011 The Chromium OS Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef SHILL_MOCK_EVENT_DISPATCHER_H_
+#define SHILL_MOCK_EVENT_DISPATCHER_H_
+
+#include <base/basictypes.h>
+#include <gmock/gmock.h>
+
+#include "shill/event_dispatcher.h"
+
+namespace shill {
+
+class MockEventDispatcher : public EventDispatcher {
+ public:
+ MockEventDispatcher();
+ virtual ~MockEventDispatcher();
+
+ MOCK_METHOD0(DispatchForever, void());
+ MOCK_METHOD0(DispatchPendingEvents, void());
+ MOCK_METHOD1(PostTask, bool(Task *task));
+ MOCK_METHOD2(PostDelayedTask, bool(Task *task, int64 delay_ms));
+ MOCK_METHOD2(CreateInputHandler,
+ IOHandler *(int fd, Callback1<InputData *>::Type *callback));
+
+ MOCK_METHOD3(CreateReadyHandler,
+ IOHandler *(int fd,
+ IOHandler::ReadyMode mode,
+ Callback1<int>::Type *callback));
+
+ private:
+ DISALLOW_COPY_AND_ASSIGN(MockEventDispatcher);
+};
+
+} // namespace shill
+
+#endif // SHILL_MOCK_EVENT_DISPATCHER_H_
diff --git a/mock_time.cc b/mock_time.cc
new file mode 100644
index 0000000..754a1c7
--- /dev/null
+++ b/mock_time.cc
@@ -0,0 +1,13 @@
+// Copyright (c) 2011 The Chromium OS Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "shill/mock_time.h"
+
+namespace shill {
+
+MockTime::MockTime() {}
+
+MockTime::~MockTime() {}
+
+} // namespace shill
diff --git a/mock_time.h b/mock_time.h
new file mode 100644
index 0000000..7199822
--- /dev/null
+++ b/mock_time.h
@@ -0,0 +1,28 @@
+// Copyright (c) 2011 The Chromium OS Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef SHILL_MOCK_TIME_H_
+#define SHILL_MOCK_TIME_H_
+
+#include <base/basictypes.h>
+#include <gmock/gmock.h>
+
+#include "shill/shill_time.h"
+
+namespace shill {
+
+class MockTime : public Time {
+ public:
+ MockTime();
+ virtual ~MockTime();
+
+ MOCK_METHOD2(GetTimeOfDay, int(struct timeval *tv, struct timezone *tz));
+
+ private:
+ DISALLOW_COPY_AND_ASSIGN(MockTime);
+};
+
+} // namespace shill
+
+#endif // SHILL_MOCK_TIME_H_
diff --git a/shill_ares.cc b/shill_ares.cc
new file mode 100644
index 0000000..cdede2c
--- /dev/null
+++ b/shill_ares.cc
@@ -0,0 +1,60 @@
+// Copyright (c) 2011 The Chromium OS Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "shill/shill_ares.h"
+
+namespace shill {
+
+static base::LazyInstance<Ares> g_ares(base::LINKER_INITIALIZED);
+
+Ares::Ares() { }
+
+Ares::~Ares() { }
+
+Ares* Ares::GetInstance() {
+ return g_ares.Pointer();
+}
+
+void Ares::Destroy(ares_channel channel) {
+ ares_destroy(channel);
+}
+
+void Ares::GetHostByName(ares_channel channel,
+ const char *hostname,
+ int family,
+ ares_host_callback callback,
+ void *arg) {
+ ares_gethostbyname(channel, hostname, family, callback, arg);
+}
+
+int Ares::GetSock(ares_channel channel,
+ ares_socket_t *socks,
+ int numsocks) {
+ return ares_getsock(channel, socks, numsocks);
+}
+
+int Ares::InitOptions(ares_channel *channelptr,
+ struct ares_options *options,
+ int optmask) {
+ return ares_init_options(channelptr, options, optmask);
+}
+
+
+void Ares::ProcessFd(ares_channel channel,
+ ares_socket_t read_fd,
+ ares_socket_t write_fd) {
+ return ares_process_fd(channel, read_fd, write_fd);
+}
+
+void Ares::SetLocalDev(ares_channel channel, const char *local_dev_name) {
+ ares_set_local_dev(channel, local_dev_name);
+}
+
+struct timeval *Ares::Timeout(ares_channel channel,
+ struct timeval *maxtv,
+ struct timeval *tv) {
+ return ares_timeout(channel, maxtv, tv);
+}
+
+} // namespace shill
diff --git a/shill_ares.h b/shill_ares.h
new file mode 100644
index 0000000..33c10fb
--- /dev/null
+++ b/shill_ares.h
@@ -0,0 +1,65 @@
+// Copyright (c) 2011 The Chromium OS Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef SHILL_ARES_H_
+#define SHILL_ARES_H_
+
+#include <ares.h>
+
+#include <base/lazy_instance.h>
+
+namespace shill {
+
+// A "ares.h" abstraction allowing mocking in tests.
+class Ares {
+ public:
+ virtual ~Ares();
+
+ static Ares *GetInstance();
+
+ // ares_destroy
+ virtual void Destroy(ares_channel channel);
+
+ // ares_gethostbyname
+ virtual void GetHostByName(ares_channel channel,
+ const char *hostname,
+ int family,
+ ares_host_callback callback,
+ void *arg);
+
+ // ares_getsock
+ virtual int GetSock(ares_channel channel,
+ ares_socket_t *socks,
+ int numsocks);
+
+ // ares_init_options
+ virtual int InitOptions(ares_channel *channelptr,
+ struct ares_options *options,
+ int optmask);
+
+ // ares_process_fd
+ virtual void ProcessFd(ares_channel channel,
+ ares_socket_t read_fd,
+ ares_socket_t write_fd);
+
+ // ares_set_local_dev
+ virtual void SetLocalDev(ares_channel channel, const char *local_dev_name);
+
+ // ares_timeout
+ virtual struct timeval *Timeout(ares_channel channel,
+ struct timeval *maxtv,
+ struct timeval *tv);
+
+ protected:
+ Ares();
+
+ private:
+ friend struct base::DefaultLazyInstanceTraits<Ares>;
+
+ DISALLOW_COPY_AND_ASSIGN(Ares);
+};
+
+} // namespace shill
+
+#endif // SHILL_ARES_H_
diff --git a/shill_time.cc b/shill_time.cc
new file mode 100644
index 0000000..7bae785
--- /dev/null
+++ b/shill_time.cc
@@ -0,0 +1,23 @@
+// Copyright (c) 2011 The Chromium OS Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "shill/shill_time.h"
+
+namespace shill {
+
+static base::LazyInstance<Time> g_time(base::LINKER_INITIALIZED);
+
+Time::Time() { }
+
+Time::~Time() { }
+
+Time* Time::GetInstance() {
+ return g_time.Pointer();
+}
+
+int Time::GetTimeOfDay(struct timeval *tv, struct timezone *tz) {
+ return gettimeofday(tv, tz);
+}
+
+} // namespace shill
diff --git a/shill_time.h b/shill_time.h
new file mode 100644
index 0000000..846acd4
--- /dev/null
+++ b/shill_time.h
@@ -0,0 +1,35 @@
+// Copyright (c) 2011 The Chromium OS Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef SHILL_TIME_H_
+#define SHILL_TIME_H_
+
+#include <sys/time.h>
+
+#include <base/lazy_instance.h>
+
+namespace shill {
+
+// A "sys/time.h" abstraction allowing mocking in tests.
+class Time {
+ public:
+ virtual ~Time();
+
+ static Time *GetInstance();
+
+ // gettimeofday
+ virtual int GetTimeOfDay(struct timeval *tv, struct timezone *tz);
+
+ protected:
+ Time();
+
+ private:
+ friend struct base::DefaultLazyInstanceTraits<Time>;
+
+ DISALLOW_COPY_AND_ASSIGN(Time);
+};
+
+} // namespace shill
+
+#endif // SHILL_TIME_H_