blob: ba1961dd0218658a3ab5d8cb9829263083490c10 [file] [log] [blame]
Paul Stewartc2350ee2011-10-19 12:28:40 -07001// Copyright (c) 2011 The Chromium OS Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "shill/dns_client.h"
6
7#include <arpa/inet.h>
8#include <netdb.h>
9#include <netinet/in.h>
10#include <sys/socket.h>
11
12#include <map>
13#include <set>
14#include <string>
15#include <tr1/memory>
16#include <vector>
17
18#include <base/stl_util-inl.h>
19
20#include <shill/shill_ares.h>
21#include <shill/shill_time.h>
22
23using std::map;
24using std::set;
25using std::string;
26using std::vector;
27
28namespace shill {
29
30const int DNSClient::kDefaultTimeoutMS = 2000;
31const char DNSClient::kErrorNoData[] = "The query response contains no answers";
32const char DNSClient::kErrorFormErr[] = "The server says the query is bad";
33const char DNSClient::kErrorServerFail[] = "The server says it had a failure";
34const char DNSClient::kErrorNotFound[] = "The queried-for domain was not found";
35const char DNSClient::kErrorNotImp[] = "The server doesn't implement operation";
36const char DNSClient::kErrorRefused[] = "The server replied, refused the query";
37const char DNSClient::kErrorBadQuery[] = "Locally we could not format a query";
38const char DNSClient::kErrorNetRefused[] = "The network connection was refused";
39const char DNSClient::kErrorTimedOut[] = "The network connection was timed out";
40const char DNSClient::kErrorUnknown[] = "DNS Resolver unknown internal error";
41
42// Private to the implementation of resolver so callers don't include ares.h
43struct DNSClientState {
44 ares_channel channel;
45 map< ares_socket_t, std::tr1::shared_ptr<IOHandler> > read_handlers;
46 map< ares_socket_t, std::tr1::shared_ptr<IOHandler> > write_handlers;
47 struct timeval start_time_;
48};
49
50DNSClient::DNSClient(IPAddress::Family family,
51 const string &interface_name,
52 const vector<string> &dns_servers,
53 int timeout_ms,
54 EventDispatcher *dispatcher,
55 Callback1<bool>::Type *callback)
56 : address_(IPAddress(family)),
57 interface_name_(interface_name),
58 dns_servers_(dns_servers),
59 dispatcher_(dispatcher),
60 callback_(callback),
61 timeout_ms_(timeout_ms),
62 running_(false),
63 resolver_state_(NULL),
64 read_callback_(NewCallback(this, &DNSClient::HandleDNSRead)),
65 write_callback_(NewCallback(this, &DNSClient::HandleDNSWrite)),
66 task_factory_(this),
67 ares_(Ares::GetInstance()),
68 time_(Time::GetInstance()) {}
69
70DNSClient::~DNSClient() {
71 Stop();
72}
73
74bool DNSClient::Start(const string &hostname) {
75 if (running_) {
76 LOG(ERROR) << "Only one DNS request is allowed at a time";
77 return false;
78 }
79
80 if (!resolver_state_.get()) {
81 struct ares_options options;
82 memset(&options, 0, sizeof(options));
83
84 vector<struct in_addr> server_addresses;
85 for (vector<string>::iterator it = dns_servers_.begin();
86 it != dns_servers_.end();
87 ++it) {
88 struct in_addr addr;
89 if (inet_aton(it->c_str(), &addr) != 0) {
90 server_addresses.push_back(addr);
91 }
92 }
93
94 if (server_addresses.empty()) {
95 LOG(ERROR) << "No valid DNS server addresses";
96 return false;
97 }
98
99 options.servers = server_addresses.data();
100 options.nservers = server_addresses.size();
101 options.timeout = timeout_ms_;
102
103 resolver_state_.reset(new DNSClientState);
104 int status = ares_->InitOptions(&resolver_state_->channel,
105 &options,
106 ARES_OPT_SERVERS | ARES_OPT_TIMEOUTMS);
107 if (status != ARES_SUCCESS) {
108 LOG(ERROR) << "ARES initialization returns error code: " << status;
109 resolver_state_.reset();
110 return false;
111 }
112
113 ares_->SetLocalDev(resolver_state_->channel, interface_name_.c_str());
114 }
115
116 running_ = true;
117 time_->GetTimeOfDay(&resolver_state_->start_time_, NULL);
118 error_.clear();
119 ares_->GetHostByName(resolver_state_->channel, hostname.c_str(),
120 address_.family(), ReceiveDNSReplyCB, this);
121
122 if (!RefreshHandles()) {
123 LOG(ERROR) << "Impossibly short timeout.";
124 Stop();
125 return false;
126 }
127
128 return true;
129}
130
131void DNSClient::Stop() {
132 if (!resolver_state_.get()) {
133 return;
134 }
135
136 running_ = false;
137 task_factory_.RevokeAll();
138 ares_->Destroy(resolver_state_->channel);
139 resolver_state_.reset();
140}
141
142void DNSClient::HandleDNSRead(int fd) {
143 ares_->ProcessFd(resolver_state_->channel, fd, ARES_SOCKET_BAD);
144 RefreshHandles();
145}
146
147void DNSClient::HandleDNSWrite(int fd) {
148 ares_->ProcessFd(resolver_state_->channel, ARES_SOCKET_BAD, fd);
149 RefreshHandles();
150}
151
152void DNSClient::HandleTimeout() {
153 ares_->ProcessFd(resolver_state_->channel, ARES_SOCKET_BAD, ARES_SOCKET_BAD);
154 if (!RefreshHandles()) {
155 // If we have timed out, ARES might still have sockets open.
156 // Force them closed by doing an explicit shutdown. This is
157 // different from HandleDNSRead and HandleDNSWrite where any
158 // change in our running_ state would be as a result of ARES
159 // itself and therefore properly synchronized with it: if a
160 // search completes during the course of ares_->ProcessFd(),
161 // the ARES fds and other state is guaranteed to have cleaned
162 // up and ready for a new request. Since this timeout is
163 // genererated outside of the library it is best to completely
164 // shutdown ARES and start with fresh state for a new request.
165 Stop();
166 }
167}
168
// Completion handler for a GetHostByName() request, reached through
// ReceiveDNSReplyCB().
//
// On a usable answer (success status and a non-empty address list whose
// family and length match address_), the first address is stored in address_
// and the callback is run with true.  Otherwise error_ is set from |status|
// and the callback is run with false.
void DNSClient::ReceiveDNSReply(int status, struct hostent *hostent) {
  // Reports arriving while not running (e.g. during ARES shutdown in Stop())
  // are dropped.
  if (!running_) {
    return;
  }
  running_ = false;

  const bool have_usable_address =
      status == ARES_SUCCESS &&
      hostent != NULL &&
      hostent->h_addrtype == address_.family() &&
      hostent->h_length == IPAddress::GetAddressLength(address_.family()) &&
      hostent->h_addr_list != NULL &&
      hostent->h_addr_list[0] != NULL;

  if (have_usable_address) {
    unsigned char *first_address =
        reinterpret_cast<unsigned char *>(hostent->h_addr_list[0]);
    address_ = IPAddress(address_.family(),
                         ByteString(first_address, hostent->h_length));
    callback_->Run(true);
    return;
  }

  // Translate the ARES status into one of our human-readable errors.
  switch (status) {
    case ARES_ENODATA:
      error_ = kErrorNoData;
      break;
    case ARES_EFORMERR:
      error_ = kErrorFormErr;
      break;
    case ARES_ESERVFAIL:
      error_ = kErrorServerFail;
      break;
    case ARES_ENOTFOUND:
      error_ = kErrorNotFound;
      break;
    case ARES_ENOTIMP:
      error_ = kErrorNotImp;
      break;
    case ARES_EREFUSED:
      error_ = kErrorRefused;
      break;
    case ARES_EBADQUERY:
    case ARES_EBADNAME:
    case ARES_EBADFAMILY:
    case ARES_EBADRESP:
      error_ = kErrorBadQuery;
      break;
    case ARES_ECONNREFUSED:
      error_ = kErrorNetRefused;
      break;
    case ARES_ETIMEOUT:
      error_ = kErrorTimedOut;
      break;
    default:
      // Includes ARES_SUCCESS with an unusable hostent.
      error_ = kErrorUnknown;
      if (status != ARES_SUCCESS) {
        LOG(ERROR) << "ARES returned unhandled error status " << status;
      } else {
        LOG(ERROR) << "ARES returned success but hostent was invalid!";
      }
      break;
  }
  callback_->Run(false);
}
230
231void DNSClient::ReceiveDNSReplyCB(void *arg, int status,
232 int /*timeouts*/,
233 struct hostent *hostent) {
234 DNSClient *res = static_cast<DNSClient *>(arg);
235 res->ReceiveDNSReply(status, hostent);
236}
237
238bool DNSClient::RefreshHandles() {
239 map< ares_socket_t, std::tr1::shared_ptr<IOHandler> > old_read =
240 resolver_state_->read_handlers;
241 map< ares_socket_t, std::tr1::shared_ptr<IOHandler> > old_write =
242 resolver_state_->write_handlers;
243
244 resolver_state_->read_handlers.clear();
245 resolver_state_->write_handlers.clear();
246
247 ares_socket_t sockets[ARES_GETSOCK_MAXNUM];
248 int action_bits = ares_->GetSock(resolver_state_->channel, sockets,
249 ARES_GETSOCK_MAXNUM);
250
251 for (int i = 0; i < ARES_GETSOCK_MAXNUM; i++) {
252 if (ARES_GETSOCK_READABLE(action_bits, i)) {
253 if (ContainsKey(old_read, sockets[i])) {
254 resolver_state_->read_handlers[sockets[i]] = old_read[sockets[i]];
255 } else {
256 resolver_state_->read_handlers[sockets[i]] =
257 std::tr1::shared_ptr<IOHandler> (
258 dispatcher_->CreateReadyHandler(sockets[i],
259 IOHandler::kModeInput,
260 read_callback_.get()));
261 }
262 }
263 if (ARES_GETSOCK_WRITABLE(action_bits, i)) {
264 if (ContainsKey(old_write, sockets[i])) {
265 resolver_state_->write_handlers[sockets[i]] = old_write[sockets[i]];
266 } else {
267 resolver_state_->write_handlers[sockets[i]] =
268 std::tr1::shared_ptr<IOHandler> (
269 dispatcher_->CreateReadyHandler(sockets[i],
270 IOHandler::kModeOutput,
271 write_callback_.get()));
272 }
273 }
274 }
275
276 if (!running_) {
277 // We are here just to clean up socket and timer handles, and the
278 // ARES state was cleaned up during the last call to ares_process_fd().
279 task_factory_.RevokeAll();
280 return false;
281 }
282
283 // Schedule timer event for the earlier of our timeout or one requested by
284 // the resolver library.
285 struct timeval now, elapsed_time, timeout_tv;
286 time_->GetTimeOfDay(&now, NULL);
287 timersub(&now, &resolver_state_->start_time_, &elapsed_time);
288 timeout_tv.tv_sec = timeout_ms_ / 1000;
289 timeout_tv.tv_usec = (timeout_ms_ % 1000) * 1000;
290 if (timercmp(&elapsed_time, &timeout_tv, >=)) {
291 // There are 3 cases of interest:
292 // - If we got here from Start(), we will have the side-effect of
293 // both invoking the callback and returning False in Start().
294 // Start() will call Stop() which will shut down ARES.
295 // - If we got here from the tail of an IO event (racing with the
296 // timer, we can't call Stop() since that will blow away the
297 // IOHandler we are running in, however we will soon be called
298 // again by the timeout proc so we can clean up the ARES state
299 // then.
300 // - If we got here from a timeout handler, it will safely call
301 // Stop() when we return false.
302 error_ = kErrorTimedOut;
303 callback_->Run(false);
304 running_ = false;
305 return false;
306 } else {
307 struct timeval max, ret_tv;
308 timersub(&timeout_tv, &elapsed_time, &max);
309 struct timeval *tv = ares_->Timeout(resolver_state_->channel,
310 &max, &ret_tv);
311 task_factory_.RevokeAll();
312 dispatcher_->PostDelayedTask(
313 task_factory_.NewRunnableMethod(&DNSClient::HandleTimeout),
314 tv->tv_sec * 1000 + tv->tv_usec / 1000);
315 }
316
317 return true;
318}
319
320} // namespace shill