blob: 542b4a909ecf62dcd1690a7790e65aa189f53e62 [file] [log] [blame]
Ben Schwartze7601812017-04-28 16:38:29 -04001/*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "dns/DnsTlsTransport.h"
18
#include <arpa/inet.h>
#include <arpa/nameser.h>
#include <errno.h>
#include <poll.h>
#include <stdint.h>
#include <stdlib.h>

#include <algorithm>
#include <iterator>

#include <openssl/err.h>
26
27#define LOG_TAG "DnsTlsTransport"
28#define DBG 0
29
30#include "log/log.h"
31#include "Fwmark.h"
32#undef ADD // already defined in nameser.h
33#include "NetdConstants.h"
34#include "Permission.h"
35
namespace {

// Bundles, as references, the IPv4 socket-address fields that participate in
// comparisons: port first, then address.
auto make_tie(const sockaddr_in& a) {
    return std::tie(a.sin_port, a.sin_addr.s_addr);
}

// Bundles, as references, the IPv6 socket-address fields that participate in
// comparisons: port, address, then scope id. sin6_flowinfo is skipped because
// it is not relevant.
auto make_tie(const sockaddr_in6& a) {
    return std::tie(a.sin6_port, a.sin6_addr, a.sin6_scope_id);
}

}  // namespace

// These binary operators make sockaddr_storage comparable. They need to be
// in the global namespace so that the std::tuple < and == operators can see them.

// Byte-wise lexicographic ordering of the 16 address bytes.
static bool operator <(const in6_addr& x, const in6_addr& y) {
    const auto& lhs = x.s6_addr;
    const auto& rhs = y.s6_addr;
    return std::lexicographical_compare(std::begin(lhs), std::end(lhs),
                                        std::begin(rhs), std::end(rhs));
}

// True when all 16 address bytes match.
static bool operator ==(const in6_addr& x, const in6_addr& y) {
    return std::equal(std::begin(x.s6_addr), std::end(x.s6_addr), std::begin(y.s6_addr));
}

// Orders by address family first; within AF_INET / AF_INET6 compares the
// fields selected by make_tie(). Unknown families never compare less.
static bool operator <(const sockaddr_storage& x, const sockaddr_storage& y) {
    if (x.ss_family == y.ss_family) {
        switch (x.ss_family) {
            case AF_INET:
                return make_tie(reinterpret_cast<const sockaddr_in&>(x)) <
                       make_tie(reinterpret_cast<const sockaddr_in&>(y));
            case AF_INET6:
                return make_tie(reinterpret_cast<const sockaddr_in6&>(x)) <
                       make_tie(reinterpret_cast<const sockaddr_in6&>(y));
            default:
                return false;  // Unknown address type. This is an error.
        }
    }
    return x.ss_family < y.ss_family;
}

// Equal only when the families match, the family is AF_INET or AF_INET6, and
// the fields selected by make_tie() are equal.
static bool operator ==(const sockaddr_storage& x, const sockaddr_storage& y) {
    if (x.ss_family == y.ss_family) {
        switch (x.ss_family) {
            case AF_INET:
                return make_tie(reinterpret_cast<const sockaddr_in&>(x)) ==
                       make_tie(reinterpret_cast<const sockaddr_in&>(y));
            case AF_INET6:
                return make_tie(reinterpret_cast<const sockaddr_in6&>(x)) ==
                       make_tie(reinterpret_cast<const sockaddr_in6&>(y));
            default:
                return false;  // Unknown address type. This is an error.
        }
    }
    return false;
}
Ben Schwartze7601812017-04-28 16:38:29 -0400102
103namespace android {
104namespace net {
105
106namespace {
107
108bool setNonBlocking(int fd, bool enabled) {
109 int flags = fcntl(fd, F_GETFL);
110 if (flags < 0) return false;
111
112 if (enabled) {
113 flags |= O_NONBLOCK;
114 } else {
115 flags &= ~O_NONBLOCK;
116 }
117 return (fcntl(fd, F_SETFL, flags) == 0);
118}
119
120int waitForReading(int fd) {
121 fd_set fds;
122 FD_ZERO(&fds);
123 FD_SET(fd, &fds);
124 const int ret = TEMP_FAILURE_RETRY(select(fd + 1, &fds, nullptr, nullptr, nullptr));
125 if (DBG && ret <= 0) {
126 ALOGD("select");
127 }
128 return ret;
129}
130
131int waitForWriting(int fd) {
132 fd_set fds;
133 FD_ZERO(&fds);
134 FD_SET(fd, &fds);
135 const int ret = TEMP_FAILURE_RETRY(select(fd + 1, nullptr, &fds, nullptr, nullptr));
136 if (DBG && ret <= 0) {
137 ALOGD("select");
138 }
139 return ret;
140}
141
142} // namespace
143
144android::base::unique_fd DnsTlsTransport::makeConnectedSocket() const {
Ben Schwartza13c23a2017-10-02 12:06:21 -0400145 if (DBG) {
146 ALOGD("%u connecting TCP socket", mMark);
147 }
Ben Schwartze7601812017-04-28 16:38:29 -0400148 android::base::unique_fd fd;
149 int type = SOCK_NONBLOCK | SOCK_CLOEXEC;
Ben Schwartz52504622017-07-11 12:21:13 -0400150 switch (mServer.protocol) {
Ben Schwartze7601812017-04-28 16:38:29 -0400151 case IPPROTO_TCP:
152 type |= SOCK_STREAM;
153 break;
154 default:
155 errno = EPROTONOSUPPORT;
156 return fd;
157 }
158
Ben Schwartz52504622017-07-11 12:21:13 -0400159 fd.reset(socket(mServer.ss.ss_family, type, mServer.protocol));
Ben Schwartze7601812017-04-28 16:38:29 -0400160 if (fd.get() == -1) {
161 return fd;
162 }
163
164 const socklen_t len = sizeof(mMark);
165 if (setsockopt(fd.get(), SOL_SOCKET, SO_MARK, &mMark, len) == -1) {
166 fd.reset();
167 } else if (connect(fd.get(),
Ben Schwartz52504622017-07-11 12:21:13 -0400168 reinterpret_cast<const struct sockaddr *>(&mServer.ss), sizeof(mServer.ss)) != 0
Ben Schwartze7601812017-04-28 16:38:29 -0400169 && errno != EINPROGRESS) {
170 fd.reset();
171 }
172
Ben Schwartza13c23a2017-10-02 12:06:21 -0400173 if (!setNonBlocking(fd, false)) {
174 ALOGE("Failed to disable nonblocking status on DNS-over-TLS fd");
175 fd.reset();
176 }
177
Ben Schwartze7601812017-04-28 16:38:29 -0400178 return fd;
179}
180
181bool getSPKIDigest(const X509* cert, std::vector<uint8_t>* out) {
182 int spki_len = i2d_X509_PUBKEY(X509_get_X509_PUBKEY(cert), NULL);
183 unsigned char spki[spki_len];
184 unsigned char* temp = spki;
185 if (spki_len != i2d_X509_PUBKEY(X509_get_X509_PUBKEY(cert), &temp)) {
186 ALOGW("SPKI length mismatch");
187 return false;
188 }
189 out->resize(SHA256_SIZE);
190 unsigned int digest_len = 0;
191 int ret = EVP_Digest(spki, spki_len, out->data(), &digest_len, EVP_sha256(), NULL);
192 if (ret != 1) {
193 ALOGW("Server cert digest extraction failed");
194 return false;
195 }
196 if (digest_len != out->size()) {
197 ALOGW("Wrong digest length: %d", digest_len);
198 return false;
199 }
200 return true;
201}
202
Ben Schwartz4204ecf2017-10-02 12:35:48 -0400203// This comparison ignores ports and fingerprints.
204// TODO: respect IPv6 scope id (e.g. link-local addresses).
205bool AddressComparator::operator() (const DnsTlsTransport::Server& x,
206 const DnsTlsTransport::Server& y) const {
207 if (x.ss.ss_family != y.ss.ss_family) {
208 return x.ss.ss_family < y.ss.ss_family;
209 }
210 // Same address family.
211 if (x.ss.ss_family == AF_INET) {
212 const sockaddr_in& x_sin = reinterpret_cast<const sockaddr_in&>(x.ss);
213 const sockaddr_in& y_sin = reinterpret_cast<const sockaddr_in&>(y.ss);
214 return x_sin.sin_addr.s_addr < y_sin.sin_addr.s_addr;
215 } else if (x.ss.ss_family == AF_INET6) {
216 const sockaddr_in6& x_sin6 = reinterpret_cast<const sockaddr_in6&>(x.ss);
217 const sockaddr_in6& y_sin6 = reinterpret_cast<const sockaddr_in6&>(y.ss);
218 return x_sin6.sin6_addr < y_sin6.sin6_addr;
219 }
220 return false; // Unknown address type. This is an error.
221}
222
223// Returns a tuple of references to the elements of s.
224auto make_tie(const DnsTlsTransport::Server& s) {
225 return std::tie(
226 s.ss,
227 s.name,
228 s.fingerprints,
229 s.protocol
230 );
231}
232
233bool DnsTlsTransport::Server::operator <(const DnsTlsTransport::Server& other) const {
234 return make_tie(*this) < make_tie(other);
235}
236
237bool DnsTlsTransport::Server::operator ==(const DnsTlsTransport::Server& other) const {
238 return make_tie(*this) == make_tie(other);
239}
240
Ben Schwartza13c23a2017-10-02 12:06:21 -0400241bool DnsTlsTransport::initialize() {
242 mSslCtx.reset(SSL_CTX_new(TLS_method()));
243 if (!mSslCtx) {
244 return false;
245 }
246 SSL_CTX_sess_set_new_cb(mSslCtx.get(), DnsTlsTransport::newSessionCallback);
247 SSL_CTX_sess_set_remove_cb(mSslCtx.get(), DnsTlsTransport::removeSessionCallback);
248 return true;
249}
250
251bssl::UniquePtr<SSL> DnsTlsTransport::sslConnect(int fd) {
252 // Check TLS context.
253 if (!mSslCtx) {
254 ALOGE("Internal error: context is null in ssl connect");
255 return nullptr;
256 }
257 if (!SSL_CTX_set_max_proto_version(mSslCtx.get(), TLS1_3_VERSION) ||
258 !SSL_CTX_set_min_proto_version(mSslCtx.get(), TLS1_2_VERSION)) {
259 ALOGE("failed to min/max TLS versions");
Ben Schwartze7601812017-04-28 16:38:29 -0400260 return nullptr;
261 }
262
Ben Schwartza13c23a2017-10-02 12:06:21 -0400263 bssl::UniquePtr<SSL> ssl(SSL_new(mSslCtx.get()));
Ben Schwartz4204ecf2017-10-02 12:35:48 -0400264 // This file descriptor is owned by a unique_fd, so don't let libssl close it.
265 bssl::UniquePtr<BIO> bio(BIO_new_socket(fd, BIO_NOCLOSE));
Ben Schwartze7601812017-04-28 16:38:29 -0400266 SSL_set_bio(ssl.get(), bio.get(), bio.get());
267 bio.release();
268
Ben Schwartza13c23a2017-10-02 12:06:21 -0400269 // Add this transport as the 0-index extra data for the socket.
270 // This is used by newSessionCallback.
271 if (SSL_set_ex_data(ssl.get(), 0, this) != 1) {
272 ALOGE("failed to associate SSL socket to transport");
273 return nullptr;
274 }
275
276 // Add this transport as the 0-index extra data for the context.
277 // This is used by removeSessionCallback.
278 if (SSL_CTX_set_ex_data(mSslCtx.get(), 0, this) != 1) {
279 ALOGE("failed to associate SSL context to transport");
Ben Schwartze7601812017-04-28 16:38:29 -0400280 return nullptr;
281 }
282
Ben Schwartz1691bc42017-08-16 12:53:09 -0400283 if (!mServer.name.empty()) {
284 if (SSL_set_tlsext_host_name(ssl.get(), mServer.name.c_str()) != 1) {
285 ALOGE("Failed to set SNI to %s", mServer.name.c_str());
286 return nullptr;
287 }
288 X509_VERIFY_PARAM* param = SSL_get0_param(ssl.get());
289 X509_VERIFY_PARAM_set1_host(param, mServer.name.c_str(), 0);
290 // This will cause the handshake to fail if certificate verification fails.
291 SSL_set_verify(ssl.get(), SSL_VERIFY_PEER, nullptr);
292 }
293
Ben Schwartza13c23a2017-10-02 12:06:21 -0400294 bssl::UniquePtr<SSL_SESSION> session;
295 {
296 std::lock_guard<std::mutex> guard(sLock);
297 if (!mSessions.empty()) {
298 session = std::move(mSessions.front());
299 mSessions.pop_front();
300 } else if (DBG) {
301 ALOGD("Starting without session ticket.");
302 }
303 }
304 if (session) {
305 SSL_set_session(ssl.get(), session.get());
306 }
307
Ben Schwartze7601812017-04-28 16:38:29 -0400308 for (;;) {
309 if (DBG) {
310 ALOGD("%u Calling SSL_connect", mMark);
311 }
312 int ret = SSL_connect(ssl.get());
313 if (DBG) {
314 ALOGD("%u SSL_connect returned %d", mMark, ret);
315 }
316 if (ret == 1) break; // SSL handshake complete;
317
318 const int ssl_err = SSL_get_error(ssl.get(), ret);
319 switch (ssl_err) {
320 case SSL_ERROR_WANT_READ:
321 if (waitForReading(fd) != 1) {
322 ALOGW("SSL_connect read error");
323 return nullptr;
324 }
325 break;
326 case SSL_ERROR_WANT_WRITE:
327 if (waitForWriting(fd) != 1) {
328 ALOGW("SSL_connect write error");
329 return nullptr;
330 }
331 break;
332 default:
333 ALOGW("SSL_connect error %d, errno=%d", ssl_err, errno);
334 return nullptr;
335 }
336 }
337
Ben Schwartz52504622017-07-11 12:21:13 -0400338 if (!mServer.fingerprints.empty()) {
Ben Schwartze7601812017-04-28 16:38:29 -0400339 if (DBG) {
340 ALOGD("Checking DNS over TLS fingerprint");
341 }
Ben Schwartzf028d392017-07-10 15:07:12 -0400342
343 // We only care that the chain is internally self-consistent, not that
344 // it chains to a trusted root, so we can ignore some kinds of errors.
345 // TODO: Add a CA root verification mode that respects these errors.
346 int verify_result = SSL_get_verify_result(ssl.get());
347 switch (verify_result) {
348 case X509_V_OK:
349 case X509_V_ERR_DEPTH_ZERO_SELF_SIGNED_CERT:
350 case X509_V_ERR_SELF_SIGNED_CERT_IN_CHAIN:
351 case X509_V_ERR_CERT_UNTRUSTED:
352 break;
353 default:
354 ALOGW("Invalid certificate chain, error %d", verify_result);
355 return nullptr;
356 }
357
358 STACK_OF(X509) *chain = SSL_get_peer_cert_chain(ssl.get());
359 if (!chain) {
Ben Schwartze7601812017-04-28 16:38:29 -0400360 ALOGW("Server has null certificate");
361 return nullptr;
362 }
Ben Schwartzf028d392017-07-10 15:07:12 -0400363 // Chain and its contents are owned by ssl, so we don't need to free explicitly.
364 bool matched = false;
365 for (size_t i = 0; i < sk_X509_num(chain); ++i) {
366 // This appears to be O(N^2), but there doesn't seem to be a straightforward
367 // way to walk a STACK_OF nondestructively in linear time.
368 X509* cert = sk_X509_value(chain, i);
369 std::vector<uint8_t> digest;
370 if (!getSPKIDigest(cert, &digest)) {
371 ALOGE("Digest computation failed");
372 return nullptr;
373 }
374
375 if (mServer.fingerprints.count(digest) > 0) {
376 matched = true;
377 break;
378 }
Ben Schwartze7601812017-04-28 16:38:29 -0400379 }
380
Ben Schwartzf028d392017-07-10 15:07:12 -0400381 if (!matched) {
Ben Schwartze7601812017-04-28 16:38:29 -0400382 ALOGW("No matching fingerprint");
383 return nullptr;
384 }
Ben Schwartzf028d392017-07-10 15:07:12 -0400385
Ben Schwartze7601812017-04-28 16:38:29 -0400386 if (DBG) {
387 ALOGD("DNS over TLS fingerprint is correct");
388 }
389 }
390
391 if (DBG) {
392 ALOGD("%u handshake complete", mMark);
393 }
Ben Schwartza13c23a2017-10-02 12:06:21 -0400394
395 return ssl;
396}
397
398// static
399int DnsTlsTransport::newSessionCallback(SSL* ssl, SSL_SESSION* session) {
400 if (!session) {
401 return 0;
402 }
403 if (DBG) {
404 ALOGD("Recording session ticket");
405 }
406 DnsTlsTransport* xport = reinterpret_cast<DnsTlsTransport*>(
407 SSL_get_ex_data(ssl, 0));
408 if (!xport) {
409 ALOGE("null transport in new session callback");
410 return 0;
411 }
412 xport->recordSession(session);
413 return 1;
414}
415
416void DnsTlsTransport::removeSessionCallback(SSL_CTX* ssl_ctx, SSL_SESSION* session) {
417 if (DBG) {
418 ALOGD("Removing session ticket");
419 }
420 DnsTlsTransport* xport = reinterpret_cast<DnsTlsTransport*>(
421 SSL_CTX_get_ex_data(ssl_ctx, 0));
422 if (!xport) {
423 ALOGE("null transport in remove session callback");
424 return;
425 }
426 xport->removeSession(session);
427}
428
429void DnsTlsTransport::recordSession(SSL_SESSION* session) {
430 std::lock_guard<std::mutex> guard(sLock);
431 mSessions.emplace_front(session);
432 if (mSessions.size() > 5) {
433 if (DBG) {
434 ALOGD("Too many sessions; trimming");
435 }
436 mSessions.pop_back();
437 }
438}
439
440void DnsTlsTransport::removeSession(SSL_SESSION* session) {
441 std::lock_guard<std::mutex> guard(sLock);
442 if (session) {
443 // TODO: Consider implementing targeted removal.
444 mSessions.clear();
445 }
446}
447
448void DnsTlsTransport::sslDisconnect(bssl::UniquePtr<SSL> ssl, base::unique_fd fd) {
449 if (ssl) {
450 SSL_shutdown(ssl.get());
451 ssl.reset();
452 }
453 fd.reset();
Ben Schwartze7601812017-04-28 16:38:29 -0400454}
455
456bool DnsTlsTransport::sslWrite(int fd, SSL *ssl, const uint8_t *buffer, int len) {
457 if (DBG) {
458 ALOGD("%u Writing %d bytes", mMark, len);
459 }
460 for (;;) {
461 int ret = SSL_write(ssl, buffer, len);
462 if (ret == len) break; // SSL write complete;
463
464 if (ret < 1) {
465 const int ssl_err = SSL_get_error(ssl, ret);
466 switch (ssl_err) {
467 case SSL_ERROR_WANT_WRITE:
468 if (waitForWriting(fd) != 1) {
469 if (DBG) {
470 ALOGW("SSL_write error");
471 }
472 return false;
473 }
474 continue;
475 case 0:
476 break; // SSL write complete;
477 default:
478 if (DBG) {
479 ALOGW("SSL_write error %d", ssl_err);
480 }
481 return false;
482 }
483 }
484 }
485 if (DBG) {
486 ALOGD("%u Wrote %d bytes", mMark, len);
487 }
488 return true;
489}
490
491// Read exactly len bytes into buffer or fail
492bool DnsTlsTransport::sslRead(int fd, SSL *ssl, uint8_t *buffer, int len) {
493 int remaining = len;
494 while (remaining > 0) {
495 int ret = SSL_read(ssl, buffer + (len - remaining), remaining);
496 if (ret == 0) {
497 ALOGE("SSL socket closed with %i of %i bytes remaining", remaining, len);
498 return false;
499 }
500
501 if (ret < 0) {
502 const int ssl_err = SSL_get_error(ssl, ret);
503 if (ssl_err == SSL_ERROR_WANT_READ) {
504 if (waitForReading(fd) != 1) {
505 if (DBG) {
506 ALOGW("SSL_read error");
507 }
508 return false;
509 }
510 continue;
511 } else {
512 if (DBG) {
513 ALOGW("SSL_read error %d", ssl_err);
514 }
515 return false;
516 }
517 }
518
519 remaining -= ret;
520 }
521 return true;
522}
523
Ben Schwartz52504622017-07-11 12:21:13 -0400524// static
Ben Schwartza13c23a2017-10-02 12:06:21 -0400525std::mutex DnsTlsTransport::sLock;
526std::map<DnsTlsTransport::Key, std::unique_ptr<DnsTlsTransport>> DnsTlsTransport::sStore;
Ben Schwartz52504622017-07-11 12:21:13 -0400527DnsTlsTransport::Response DnsTlsTransport::query(const Server& server, unsigned mark,
528 const uint8_t *query, size_t qlen, uint8_t *response, size_t limit, int *resplen) {
Ben Schwartza13c23a2017-10-02 12:06:21 -0400529 const Key key = std::make_pair(mark, server);
530 DnsTlsTransport* xport;
531 {
532 std::lock_guard<std::mutex> guard(sLock);
533 auto it = sStore.find(key);
534 if (it == sStore.end()) {
535 xport = new DnsTlsTransport(server, mark);
536 if (!xport->initialize()) {
537 return DnsTlsTransport::Response::internal_error;
538 }
539 sStore[key].reset(xport);
540 } else {
541 xport = it->second.get();
542 }
543 ++xport->mUseCount;
544 }
545
546 Response res = xport->doQuery(query, qlen, response, limit, resplen);
547 auto now = std::chrono::steady_clock::now();
548 {
549 std::lock_guard<std::mutex> guard(sLock);
550 --xport->mUseCount;
551 xport->mLastUsed = now;
552 cleanup(now);
553 }
554 return res;
555}
556
557static constexpr std::chrono::minutes IDLE_TIMEOUT(5);
558std::chrono::time_point<std::chrono::steady_clock> DnsTlsTransport::sLastCleanup;
559void DnsTlsTransport::cleanup(std::chrono::time_point<std::chrono::steady_clock> now) {
560 if (now - sLastCleanup < IDLE_TIMEOUT) {
561 return;
562 }
563 for (auto it = sStore.begin(); it != sStore.end(); ) {
564 auto& xport = it->second;
565 if (xport->mUseCount == 0 && now - xport->mLastUsed > IDLE_TIMEOUT) {
566 it = sStore.erase(it);
567 } else {
568 ++it;
569 }
570 }
571 sLastCleanup = now;
Ben Schwartz52504622017-07-11 12:21:13 -0400572}
573
Ben Schwartze7601812017-04-28 16:38:29 -0400574DnsTlsTransport::Response DnsTlsTransport::doQuery(const uint8_t *query, size_t qlen,
575 uint8_t *response, size_t limit, int *resplen) {
Ben Schwartza13c23a2017-10-02 12:06:21 -0400576 android::base::unique_fd fd = makeConnectedSocket();
577 if (fd.get() < 0) {
578 ALOGD("%u makeConnectedSocket() failed with: %s", mMark, strerror(errno));
579 return Response::network_error;
Ben Schwartze7601812017-04-28 16:38:29 -0400580 }
Ben Schwartza13c23a2017-10-02 12:06:21 -0400581 bssl::UniquePtr<SSL> ssl = sslConnect(fd.get());
582 if (!ssl) {
Ben Schwartze7601812017-04-28 16:38:29 -0400583 return Response::network_error;
584 }
585
Ben Schwartza13c23a2017-10-02 12:06:21 -0400586 Response res = sendQuery(fd.get(), ssl.get(), query, qlen);
587 if (res == Response::success) {
588 res = readResponse(fd.get(), ssl.get(), query, response, limit, resplen);
589 }
590
591 sslDisconnect(std::move(ssl), std::move(fd));
592 return res;
593}
594
595DnsTlsTransport::Response DnsTlsTransport::sendQuery(int fd, SSL* ssl, const uint8_t *query, size_t qlen) {
596 if (DBG) {
597 ALOGD("sending query");
598 }
Ben Schwartze7601812017-04-28 16:38:29 -0400599 uint8_t queryHeader[2];
600 queryHeader[0] = qlen >> 8;
601 queryHeader[1] = qlen;
Ben Schwartza13c23a2017-10-02 12:06:21 -0400602 if (!sslWrite(fd, ssl, queryHeader, 2)) {
Ben Schwartze7601812017-04-28 16:38:29 -0400603 return Response::network_error;
604 }
Ben Schwartza13c23a2017-10-02 12:06:21 -0400605 if (!sslWrite(fd, ssl, query, qlen)) {
Ben Schwartze7601812017-04-28 16:38:29 -0400606 return Response::network_error;
607 }
608 if (DBG) {
609 ALOGD("%u SSL_write complete", mMark);
610 }
Ben Schwartza13c23a2017-10-02 12:06:21 -0400611 return Response::success;
612}
Ben Schwartze7601812017-04-28 16:38:29 -0400613
Ben Schwartza13c23a2017-10-02 12:06:21 -0400614DnsTlsTransport::Response DnsTlsTransport::readResponse(int fd, SSL* ssl, const uint8_t *query, uint8_t *response, size_t limit, int *resplen) {
615 if (DBG) {
616 ALOGD("reading response");
617 }
Ben Schwartze7601812017-04-28 16:38:29 -0400618 uint8_t responseHeader[2];
Ben Schwartza13c23a2017-10-02 12:06:21 -0400619 if (!sslRead(fd, ssl, responseHeader, 2)) {
Ben Schwartze7601812017-04-28 16:38:29 -0400620 if (DBG) {
621 ALOGW("%u Failed to read 2-byte length header", mMark);
622 }
623 return Response::network_error;
624 }
625 const uint16_t responseSize = (responseHeader[0] << 8) | responseHeader[1];
626 if (DBG) {
627 ALOGD("%u Expecting response of size %i", mMark, responseSize);
628 }
629 if (responseSize > limit) {
630 ALOGE("%u Response doesn't fit in output buffer: %i", mMark, responseSize);
631 return Response::limit_error;
632 }
Ben Schwartza13c23a2017-10-02 12:06:21 -0400633 if (!sslRead(fd, ssl, response, responseSize)) {
Ben Schwartze7601812017-04-28 16:38:29 -0400634 if (DBG) {
635 ALOGW("%u Failed to read %i bytes", mMark, responseSize);
636 }
637 return Response::network_error;
638 }
639 if (DBG) {
640 ALOGD("%u SSL_read complete", mMark);
641 }
642
643 if (response[0] != query[0] || response[1] != query[1]) {
644 ALOGE("reply query ID != query ID");
645 return Response::internal_error;
646 }
647
Ben Schwartze7601812017-04-28 16:38:29 -0400648 *resplen = responseSize;
649 return Response::success;
650}
651
Ben Schwartz52504622017-07-11 12:21:13 -0400652// static
653bool DnsTlsTransport::validate(const Server& server, unsigned netid) {
Ben Schwartze7601812017-04-28 16:38:29 -0400654 if (DBG) {
655 ALOGD("Beginning validation on %u", netid);
656 }
657 // Generate "<random>-dnsotls-ds.metric.gstatic.com", which we will lookup through |ss| in
658 // order to prove that it is actually a working DNS over TLS server.
659 static const char kDnsSafeChars[] =
660 "abcdefhijklmnopqrstuvwxyz"
661 "ABCDEFHIJKLMNOPQRSTUVWXYZ"
662 "0123456789";
663 const auto c = [](uint8_t rnd) -> uint8_t {
664 return kDnsSafeChars[(rnd % ARRAY_SIZE(kDnsSafeChars))];
665 };
666 uint8_t rnd[8];
667 arc4random_buf(rnd, ARRAY_SIZE(rnd));
668 // We could try to use res_mkquery() here, but it's basically the same.
669 uint8_t query[] = {
670 rnd[6], rnd[7], // [0-1] query ID
671 1, 0, // [2-3] flags; query[2] = 1 for recursion desired (RD).
672 0, 1, // [4-5] QDCOUNT (number of queries)
673 0, 0, // [6-7] ANCOUNT (number of answers)
674 0, 0, // [8-9] NSCOUNT (number of name server records)
675 0, 0, // [10-11] ARCOUNT (number of additional records)
676 17, c(rnd[0]), c(rnd[1]), c(rnd[2]), c(rnd[3]), c(rnd[4]), c(rnd[5]),
677 '-', 'd', 'n', 's', 'o', 't', 'l', 's', '-', 'd', 's',
678 6, 'm', 'e', 't', 'r', 'i', 'c',
679 7, 'g', 's', 't', 'a', 't', 'i', 'c',
680 3, 'c', 'o', 'm',
681 0, // null terminator of FQDN (root TLD)
682 0, ns_t_aaaa, // QTYPE
683 0, ns_c_in // QCLASS
684 };
685 const int qlen = ARRAY_SIZE(query);
686
687 const int kRecvBufSize = 4 * 1024;
688 uint8_t recvbuf[kRecvBufSize];
689
690 // At validation time, we only know the netId, so we have to guess/compute the
691 // corresponding socket mark.
692 Fwmark fwmark;
693 fwmark.permission = PERMISSION_SYSTEM;
694 fwmark.explicitlySelected = true;
695 fwmark.protectedFromVpn = true;
696 fwmark.netId = netid;
697 unsigned mark = fwmark.intValue;
Ben Schwartze7601812017-04-28 16:38:29 -0400698 int replylen = 0;
Ben Schwartz52504622017-07-11 12:21:13 -0400699 DnsTlsTransport::query(server, mark, query, qlen, recvbuf, kRecvBufSize, &replylen);
Ben Schwartze7601812017-04-28 16:38:29 -0400700 if (replylen == 0) {
701 if (DBG) {
Ben Schwartz52504622017-07-11 12:21:13 -0400702 ALOGD("query failed");
Ben Schwartze7601812017-04-28 16:38:29 -0400703 }
704 return false;
705 }
706
707 if (replylen < NS_HFIXEDSZ) {
708 if (DBG) {
709 ALOGW("short response: %d", replylen);
710 }
711 return false;
712 }
713
714 const int qdcount = (recvbuf[4] << 8) | recvbuf[5];
715 if (qdcount != 1) {
716 ALOGW("reply query count != 1: %d", qdcount);
717 return false;
718 }
719
720 const int ancount = (recvbuf[6] << 8) | recvbuf[7];
721 if (DBG) {
722 ALOGD("%u answer count: %d", netid, ancount);
723 }
724
725 // TODO: Further validate the response contents (check for valid AAAA record, ...).
726 // Note that currently, integration tests rely on this function accepting a
727 // response with zero records.
728#if 0
729 for (int i = 0; i < resplen; i++) {
730 ALOGD("recvbuf[%d] = %d %c", i, recvbuf[i], recvbuf[i]);
731 }
732#endif
733 return true;
734}
735
736} // namespace net
737} // namespace android