blob: ba1961dd0218658a3ab5d8cb9829263083490c10 [file] [log] [blame]
// Copyright (c) 2011 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "shill/dns_client.h"
#include <arpa/inet.h>
#include <netdb.h>
#include <netinet/in.h>
#include <sys/socket.h>
#include <map>
#include <set>
#include <string>
#include <tr1/memory>
#include <vector>
#include <base/stl_util-inl.h>
#include <shill/shill_ares.h>
#include <shill/shill_time.h>
using std::map;
using std::set;
using std::string;
using std::vector;
namespace shill {
// Default timeout for a DNS request, in milliseconds.
const int DNSClient::kDefaultTimeoutMS = 2000;

// Human-readable failure descriptions. One of these is stored in error_
// whenever a request completes unsuccessfully.
const char DNSClient::kErrorNoData[] =
    "The query response contains no answers";
const char DNSClient::kErrorFormErr[] =
    "The server says the query is bad";
const char DNSClient::kErrorServerFail[] =
    "The server says it had a failure";
const char DNSClient::kErrorNotFound[] =
    "The queried-for domain was not found";
const char DNSClient::kErrorNotImp[] =
    "The server doesn't implement operation";
const char DNSClient::kErrorRefused[] =
    "The server replied, refused the query";
const char DNSClient::kErrorBadQuery[] =
    "Locally we could not format a query";
const char DNSClient::kErrorNetRefused[] =
    "The network connection was refused";
const char DNSClient::kErrorTimedOut[] =
    "The network connection was timed out";
const char DNSClient::kErrorUnknown[] =
    "DNS Resolver unknown internal error";
// Resolver state kept private to this file so that users of DNSClient never
// have to include the ARES headers.
struct DNSClientState {
  // Maps an ARES socket to the IOHandler currently watching it.
  typedef map< ares_socket_t, std::tr1::shared_ptr<IOHandler> > HandlerMap;

  // ARES channel used to issue queries; created in DNSClient::Start() and
  // destroyed in DNSClient::Stop().
  ares_channel channel;
  // Handlers for sockets ARES wants to read from.
  HandlerMap read_handlers;
  // Handlers for sockets ARES wants to write to.
  HandlerMap write_handlers;
  // Time (from Time::GetTimeOfDay) at which the current request started.
  struct timeval start_time_;
};
// Creates an idle client. No ARES state is allocated until Start().
//
// |family| is the address family requested from ARES, and a reply is only
// accepted if it matches. |interface_name| is handed to ARES (SetLocalDev)
// when the channel is created. |dns_servers| are parsed with inet_aton() at
// that same point. |timeout_ms| bounds each request (enforced in
// RefreshHandles()) and is also passed to ARES via ARES_OPT_TIMEOUTMS.
// |dispatcher| creates the IO handlers and runs the delayed timeout task.
// |callback| is run with true on success and false on failure.
//
// NOTE(review): initializers are presumably in member-declaration order
// (dns_client.h); keep them in sync to avoid -Wreorder warnings.
DNSClient::DNSClient(IPAddress::Family family,
const string &interface_name,
const vector<string> &dns_servers,
int timeout_ms,
EventDispatcher *dispatcher,
Callback1<bool>::Type *callback)
: address_(IPAddress(family)),
interface_name_(interface_name),
dns_servers_(dns_servers),
dispatcher_(dispatcher),
callback_(callback),
timeout_ms_(timeout_ms),
running_(false),
resolver_state_(NULL),
read_callback_(NewCallback(this, &DNSClient::HandleDNSRead)),
write_callback_(NewCallback(this, &DNSClient::HandleDNSWrite)),
task_factory_(this),
// Singletons captured as pointers, presumably so tests can substitute mocks.
ares_(Ares::GetInstance()),
time_(Time::GetInstance()) {}
// Destroying the client cancels any outstanding request and releases the
// ARES channel and IO handlers via Stop().
DNSClient::~DNSClient() { Stop(); }
bool DNSClient::Start(const string &hostname) {
if (running_) {
LOG(ERROR) << "Only one DNS request is allowed at a time";
return false;
}
if (!resolver_state_.get()) {
struct ares_options options;
memset(&options, 0, sizeof(options));
vector<struct in_addr> server_addresses;
for (vector<string>::iterator it = dns_servers_.begin();
it != dns_servers_.end();
++it) {
struct in_addr addr;
if (inet_aton(it->c_str(), &addr) != 0) {
server_addresses.push_back(addr);
}
}
if (server_addresses.empty()) {
LOG(ERROR) << "No valid DNS server addresses";
return false;
}
options.servers = server_addresses.data();
options.nservers = server_addresses.size();
options.timeout = timeout_ms_;
resolver_state_.reset(new DNSClientState);
int status = ares_->InitOptions(&resolver_state_->channel,
&options,
ARES_OPT_SERVERS | ARES_OPT_TIMEOUTMS);
if (status != ARES_SUCCESS) {
LOG(ERROR) << "ARES initialization returns error code: " << status;
resolver_state_.reset();
return false;
}
ares_->SetLocalDev(resolver_state_->channel, interface_name_.c_str());
}
running_ = true;
time_->GetTimeOfDay(&resolver_state_->start_time_, NULL);
error_.clear();
ares_->GetHostByName(resolver_state_->channel, hostname.c_str(),
address_.family(), ReceiveDNSReplyCB, this);
if (!RefreshHandles()) {
LOG(ERROR) << "Impossibly short timeout.";
Stop();
return false;
}
return true;
}
// Cancels any in-flight request and discards all resolver state: the pending
// timeout task, the ARES channel, and every IOHandler held in
// |resolver_state_|. Does nothing if no resolver state exists.
void DNSClient::Stop() {
  if (resolver_state_.get()) {
    // running_ must be cleared before the channel is destroyed so that any
    // replies ARES delivers during shutdown are ignored by ReceiveDNSReply().
    running_ = false;
    task_factory_.RevokeAll();
    ares_->Destroy(resolver_state_->channel);
    resolver_state_.reset();
  }
}
void DNSClient::HandleDNSRead(int fd) {
ares_->ProcessFd(resolver_state_->channel, fd, ARES_SOCKET_BAD);
RefreshHandles();
}
void DNSClient::HandleDNSWrite(int fd) {
ares_->ProcessFd(resolver_state_->channel, ARES_SOCKET_BAD, fd);
RefreshHandles();
}
// Body of the delayed task posted by RefreshHandles(). Passing
// ARES_SOCKET_BAD for both sides lets ARES service its own timeouts without
// any socket activity. The handles and timer are then refreshed; if the
// request is no longer running, ARES is shut down completely.
void DNSClient::HandleTimeout() {
  ares_->ProcessFd(resolver_state_->channel, ARES_SOCKET_BAD, ARES_SOCKET_BAD);
  if (RefreshHandles()) {
    return;
  }
  // The request has ended (or timed out), but ARES may still hold open
  // sockets. In HandleDNSRead()/HandleDNSWrite(), running_ only changes as a
  // result of ARES itself completing a search inside ProcessFd(), and ARES
  // guarantees its fds and state are already cleaned up by then. This timeout,
  // however, is generated outside the library. The safest course is therefore
  // to shut ARES down entirely so the next request starts from fresh state.
  Stop();
}
// Handles completion of an ARES query.
//
// If the reply carries an address of the requested family and length, the
// first address is stored in address_ and the callback is run with true.
// Otherwise |status| is mapped to one of the kError* strings in error_ and the
// callback is run with false. Replies arriving while no request is running
// (e.g. during ARES shutdown in Stop()) are ignored.
void DNSClient::ReceiveDNSReply(int status, struct hostent *hostent) {
  // We can be called during ARES shutdown -- ignore these events.
  if (!running_) {
    return;
  }
  running_ = false;

  const bool usable_answer =
      status == ARES_SUCCESS &&
      hostent != NULL &&
      hostent->h_addrtype == address_.family() &&
      hostent->h_length == IPAddress::GetAddressLength(address_.family()) &&
      hostent->h_addr_list != NULL &&
      hostent->h_addr_list[0] != NULL;
  if (usable_answer) {
    unsigned char *first_address =
        reinterpret_cast<unsigned char *>(hostent->h_addr_list[0]);
    address_ = IPAddress(address_.family(),
                         ByteString(first_address, hostent->h_length));
    callback_->Run(true);
    return;
  }

  switch (status) {
    case ARES_ENODATA:
      error_ = kErrorNoData;
      break;
    case ARES_EFORMERR:
      error_ = kErrorFormErr;
      break;
    case ARES_ESERVFAIL:
      error_ = kErrorServerFail;
      break;
    case ARES_ENOTFOUND:
      error_ = kErrorNotFound;
      break;
    case ARES_ENOTIMP:
      error_ = kErrorNotImp;
      break;
    case ARES_EREFUSED:
      error_ = kErrorRefused;
      break;
    // All locally-detected malformed request/response cases share one message.
    case ARES_EBADQUERY:
    case ARES_EBADNAME:
    case ARES_EBADFAMILY:
    case ARES_EBADRESP:
      error_ = kErrorBadQuery;
      break;
    case ARES_ECONNREFUSED:
      error_ = kErrorNetRefused;
      break;
    case ARES_ETIMEOUT:
      error_ = kErrorTimedOut;
      break;
    default:
      error_ = kErrorUnknown;
      // ARES_SUCCESS lands here when the hostent failed validation above.
      if (status == ARES_SUCCESS) {
        LOG(ERROR) << "ARES returned success but hostent was invalid!";
      } else {
        LOG(ERROR) << "ARES returned unhandled error status " << status;
      }
      break;
  }
  callback_->Run(false);
}
void DNSClient::ReceiveDNSReplyCB(void *arg, int status,
int /*timeouts*/,
struct hostent *hostent) {
DNSClient *res = static_cast<DNSClient *>(arg);
res->ReceiveDNSReply(status, hostent);
}
bool DNSClient::RefreshHandles() {
map< ares_socket_t, std::tr1::shared_ptr<IOHandler> > old_read =
resolver_state_->read_handlers;
map< ares_socket_t, std::tr1::shared_ptr<IOHandler> > old_write =
resolver_state_->write_handlers;
resolver_state_->read_handlers.clear();
resolver_state_->write_handlers.clear();
ares_socket_t sockets[ARES_GETSOCK_MAXNUM];
int action_bits = ares_->GetSock(resolver_state_->channel, sockets,
ARES_GETSOCK_MAXNUM);
for (int i = 0; i < ARES_GETSOCK_MAXNUM; i++) {
if (ARES_GETSOCK_READABLE(action_bits, i)) {
if (ContainsKey(old_read, sockets[i])) {
resolver_state_->read_handlers[sockets[i]] = old_read[sockets[i]];
} else {
resolver_state_->read_handlers[sockets[i]] =
std::tr1::shared_ptr<IOHandler> (
dispatcher_->CreateReadyHandler(sockets[i],
IOHandler::kModeInput,
read_callback_.get()));
}
}
if (ARES_GETSOCK_WRITABLE(action_bits, i)) {
if (ContainsKey(old_write, sockets[i])) {
resolver_state_->write_handlers[sockets[i]] = old_write[sockets[i]];
} else {
resolver_state_->write_handlers[sockets[i]] =
std::tr1::shared_ptr<IOHandler> (
dispatcher_->CreateReadyHandler(sockets[i],
IOHandler::kModeOutput,
write_callback_.get()));
}
}
}
if (!running_) {
// We are here just to clean up socket and timer handles, and the
// ARES state was cleaned up during the last call to ares_process_fd().
task_factory_.RevokeAll();
return false;
}
// Schedule timer event for the earlier of our timeout or one requested by
// the resolver library.
struct timeval now, elapsed_time, timeout_tv;
time_->GetTimeOfDay(&now, NULL);
timersub(&now, &resolver_state_->start_time_, &elapsed_time);
timeout_tv.tv_sec = timeout_ms_ / 1000;
timeout_tv.tv_usec = (timeout_ms_ % 1000) * 1000;
if (timercmp(&elapsed_time, &timeout_tv, >=)) {
// There are 3 cases of interest:
// - If we got here from Start(), we will have the side-effect of
// both invoking the callback and returning False in Start().
// Start() will call Stop() which will shut down ARES.
// - If we got here from the tail of an IO event (racing with the
// timer, we can't call Stop() since that will blow away the
// IOHandler we are running in, however we will soon be called
// again by the timeout proc so we can clean up the ARES state
// then.
// - If we got here from a timeout handler, it will safely call
// Stop() when we return false.
error_ = kErrorTimedOut;
callback_->Run(false);
running_ = false;
return false;
} else {
struct timeval max, ret_tv;
timersub(&timeout_tv, &elapsed_time, &max);
struct timeval *tv = ares_->Timeout(resolver_state_->channel,
&max, &ret_tv);
task_factory_.RevokeAll();
dispatcher_->PostDelayedTask(
task_factory_.NewRunnableMethod(&DNSClient::HandleTimeout),
tv->tv_sec * 1000 + tv->tv_usec / 1000);
}
return true;
}
} // namespace shill