blob: b369022c5480526cbcd7aa2f642beaf54be1d689 [file] [log] [blame]
Ben Schwartze7601812017-04-28 16:38:29 -04001/*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "dns/DnsTlsTransport.h"
18
#include <algorithm>
#include <iterator>
#include <limits>
#include <vector>

#include <arpa/inet.h>
#include <arpa/nameser.h>
#include <errno.h>
#include <openssl/err.h>
#include <openssl/ssl.h>
#include <poll.h>
#include <stdlib.h>
27
28#define LOG_TAG "DnsTlsTransport"
29#define DBG 0
30
31#include "log/log.h"
32#include "Fwmark.h"
33#undef ADD // already defined in nameser.h
34#include "NetdConstants.h"
35#include "Permission.h"
36
namespace {

// Bundles const references to the IPv4 socket-address fields that take part
// in comparisons: the port first, then the address.
auto make_tie(const sockaddr_in& a) {
    return std::forward_as_tuple(a.sin_port, a.sin_addr.s_addr);
}

// Bundles const references to the IPv6 socket-address fields that take part
// in comparisons: port, address, then scope ID. sin6_flowinfo is deliberately
// omitted because it is not relevant for comparison.
auto make_tie(const sockaddr_in6& a) {
    return std::forward_as_tuple(a.sin6_port, a.sin6_addr, a.sin6_scope_id);
}

}  // namespace
55
// Comparison operators for socket addresses (and for in6_addr, which the IPv6
// comparisons depend on). They must live in the global namespace so that the
// std::tuple < and == operators can see them.

// Byte-wise lexicographic ordering of two IPv6 addresses.
static bool operator <(const in6_addr& x, const in6_addr& y) {
    // Both arrays have the same length, so the first differing byte decides.
    const auto firstDiff = std::mismatch(std::begin(x.s6_addr), std::end(x.s6_addr),
                                         std::begin(y.s6_addr));
    if (firstDiff.first == std::end(x.s6_addr)) {
        return false;  // Identical addresses.
    }
    return *firstDiff.first < *firstDiff.second;
}

// Byte-wise equality of two IPv6 addresses.
static bool operator ==(const in6_addr& x, const in6_addr& y) {
    return std::equal(std::begin(x.s6_addr), std::end(x.s6_addr), std::begin(y.s6_addr));
}
69
70static bool operator <(const sockaddr_storage& x, const sockaddr_storage& y) {
71 if (x.ss_family != y.ss_family) {
72 return x.ss_family < y.ss_family;
73 }
74 // Same address family.
75 if (x.ss_family == AF_INET) {
76 const sockaddr_in& x_sin = reinterpret_cast<const sockaddr_in&>(x);
77 const sockaddr_in& y_sin = reinterpret_cast<const sockaddr_in&>(y);
78 return make_tie(x_sin) < make_tie(y_sin);
79 } else if (x.ss_family == AF_INET6) {
80 const sockaddr_in6& x_sin6 = reinterpret_cast<const sockaddr_in6&>(x);
81 const sockaddr_in6& y_sin6 = reinterpret_cast<const sockaddr_in6&>(y);
82 return make_tie(x_sin6) < make_tie(y_sin6);
83 }
84 return false; // Unknown address type. This is an error.
85}
86
87static bool operator ==(const sockaddr_storage& x, const sockaddr_storage& y) {
88 if (x.ss_family != y.ss_family) {
89 return false;
90 }
91 // Same address family.
92 if (x.ss_family == AF_INET) {
93 const sockaddr_in& x_sin = reinterpret_cast<const sockaddr_in&>(x);
94 const sockaddr_in& y_sin = reinterpret_cast<const sockaddr_in&>(y);
95 return make_tie(x_sin) == make_tie(y_sin);
96 } else if (x.ss_family == AF_INET6) {
97 const sockaddr_in6& x_sin6 = reinterpret_cast<const sockaddr_in6&>(x);
98 const sockaddr_in6& y_sin6 = reinterpret_cast<const sockaddr_in6&>(y);
99 return make_tie(x_sin6) == make_tie(y_sin6);
100 }
101 return false; // Unknown address type. This is an error.
102}
Ben Schwartze7601812017-04-28 16:38:29 -0400103
104namespace android {
105namespace net {
106
107namespace {
108
// Sets (enabled == true) or clears O_NONBLOCK on |fd|, leaving its other file
// status flags unchanged. Returns false if either fcntl() call fails.
bool setNonBlocking(int fd, bool enabled) {
    const int current = fcntl(fd, F_GETFL);
    if (current < 0) {
        return false;
    }
    const int desired = enabled ? (current | O_NONBLOCK) : (current & ~O_NONBLOCK);
    return fcntl(fd, F_SETFL, desired) == 0;
}
120
121int waitForReading(int fd) {
122 fd_set fds;
123 FD_ZERO(&fds);
124 FD_SET(fd, &fds);
125 const int ret = TEMP_FAILURE_RETRY(select(fd + 1, &fds, nullptr, nullptr, nullptr));
126 if (DBG && ret <= 0) {
127 ALOGD("select");
128 }
129 return ret;
130}
131
132int waitForWriting(int fd) {
133 fd_set fds;
134 FD_ZERO(&fds);
135 FD_SET(fd, &fds);
136 const int ret = TEMP_FAILURE_RETRY(select(fd + 1, nullptr, &fds, nullptr, nullptr));
137 if (DBG && ret <= 0) {
138 ALOGD("select");
139 }
140 return ret;
141}
142
143} // namespace
144
145android::base::unique_fd DnsTlsTransport::makeConnectedSocket() const {
146 android::base::unique_fd fd;
147 int type = SOCK_NONBLOCK | SOCK_CLOEXEC;
Ben Schwartz52504622017-07-11 12:21:13 -0400148 switch (mServer.protocol) {
Ben Schwartze7601812017-04-28 16:38:29 -0400149 case IPPROTO_TCP:
150 type |= SOCK_STREAM;
151 break;
152 default:
153 errno = EPROTONOSUPPORT;
154 return fd;
155 }
156
Ben Schwartz52504622017-07-11 12:21:13 -0400157 fd.reset(socket(mServer.ss.ss_family, type, mServer.protocol));
Ben Schwartze7601812017-04-28 16:38:29 -0400158 if (fd.get() == -1) {
159 return fd;
160 }
161
162 const socklen_t len = sizeof(mMark);
163 if (setsockopt(fd.get(), SOL_SOCKET, SO_MARK, &mMark, len) == -1) {
164 fd.reset();
165 } else if (connect(fd.get(),
Ben Schwartz52504622017-07-11 12:21:13 -0400166 reinterpret_cast<const struct sockaddr *>(&mServer.ss), sizeof(mServer.ss)) != 0
Ben Schwartze7601812017-04-28 16:38:29 -0400167 && errno != EINPROGRESS) {
168 fd.reset();
169 }
170
171 return fd;
172}
173
174bool getSPKIDigest(const X509* cert, std::vector<uint8_t>* out) {
175 int spki_len = i2d_X509_PUBKEY(X509_get_X509_PUBKEY(cert), NULL);
176 unsigned char spki[spki_len];
177 unsigned char* temp = spki;
178 if (spki_len != i2d_X509_PUBKEY(X509_get_X509_PUBKEY(cert), &temp)) {
179 ALOGW("SPKI length mismatch");
180 return false;
181 }
182 out->resize(SHA256_SIZE);
183 unsigned int digest_len = 0;
184 int ret = EVP_Digest(spki, spki_len, out->data(), &digest_len, EVP_sha256(), NULL);
185 if (ret != 1) {
186 ALOGW("Server cert digest extraction failed");
187 return false;
188 }
189 if (digest_len != out->size()) {
190 ALOGW("Wrong digest length: %d", digest_len);
191 return false;
192 }
193 return true;
194}
195
Ben Schwartz4204ecf2017-10-02 12:35:48 -0400196// This comparison ignores ports and fingerprints.
197// TODO: respect IPv6 scope id (e.g. link-local addresses).
198bool AddressComparator::operator() (const DnsTlsTransport::Server& x,
199 const DnsTlsTransport::Server& y) const {
200 if (x.ss.ss_family != y.ss.ss_family) {
201 return x.ss.ss_family < y.ss.ss_family;
202 }
203 // Same address family.
204 if (x.ss.ss_family == AF_INET) {
205 const sockaddr_in& x_sin = reinterpret_cast<const sockaddr_in&>(x.ss);
206 const sockaddr_in& y_sin = reinterpret_cast<const sockaddr_in&>(y.ss);
207 return x_sin.sin_addr.s_addr < y_sin.sin_addr.s_addr;
208 } else if (x.ss.ss_family == AF_INET6) {
209 const sockaddr_in6& x_sin6 = reinterpret_cast<const sockaddr_in6&>(x.ss);
210 const sockaddr_in6& y_sin6 = reinterpret_cast<const sockaddr_in6&>(y.ss);
211 return x_sin6.sin6_addr < y_sin6.sin6_addr;
212 }
213 return false; // Unknown address type. This is an error.
214}
215
216// Returns a tuple of references to the elements of s.
217auto make_tie(const DnsTlsTransport::Server& s) {
218 return std::tie(
219 s.ss,
220 s.name,
221 s.fingerprints,
222 s.protocol
223 );
224}
225
226bool DnsTlsTransport::Server::operator <(const DnsTlsTransport::Server& other) const {
227 return make_tie(*this) < make_tie(other);
228}
229
230bool DnsTlsTransport::Server::operator ==(const DnsTlsTransport::Server& other) const {
231 return make_tie(*this) == make_tie(other);
232}
233
Ben Schwartze7601812017-04-28 16:38:29 -0400234SSL* DnsTlsTransport::sslConnect(int fd) {
235 if (fd < 0) {
236 ALOGD("%u makeConnectedSocket() failed with: %s", mMark, strerror(errno));
237 return nullptr;
238 }
239
240 // Set up TLS context.
241 bssl::UniquePtr<SSL_CTX> ssl_ctx(SSL_CTX_new(TLS_method()));
242 if (!SSL_CTX_set_max_proto_version(ssl_ctx.get(), TLS1_3_VERSION) ||
243 !SSL_CTX_set_min_proto_version(ssl_ctx.get(), TLS1_1_VERSION)) {
244 ALOGD("failed to min/max TLS versions");
245 return nullptr;
246 }
247
248 bssl::UniquePtr<SSL> ssl(SSL_new(ssl_ctx.get()));
Ben Schwartz4204ecf2017-10-02 12:35:48 -0400249 // This file descriptor is owned by a unique_fd, so don't let libssl close it.
250 bssl::UniquePtr<BIO> bio(BIO_new_socket(fd, BIO_NOCLOSE));
Ben Schwartze7601812017-04-28 16:38:29 -0400251 SSL_set_bio(ssl.get(), bio.get(), bio.get());
252 bio.release();
253
254 if (!setNonBlocking(fd, false)) {
255 ALOGE("Failed to disable nonblocking status on DNS-over-TLS fd");
256 return nullptr;
257 }
258
Ben Schwartz1691bc42017-08-16 12:53:09 -0400259 if (!mServer.name.empty()) {
260 if (SSL_set_tlsext_host_name(ssl.get(), mServer.name.c_str()) != 1) {
261 ALOGE("Failed to set SNI to %s", mServer.name.c_str());
262 return nullptr;
263 }
264 X509_VERIFY_PARAM* param = SSL_get0_param(ssl.get());
265 X509_VERIFY_PARAM_set1_host(param, mServer.name.c_str(), 0);
266 // This will cause the handshake to fail if certificate verification fails.
267 SSL_set_verify(ssl.get(), SSL_VERIFY_PEER, nullptr);
268 }
269
Ben Schwartze7601812017-04-28 16:38:29 -0400270 for (;;) {
271 if (DBG) {
272 ALOGD("%u Calling SSL_connect", mMark);
273 }
274 int ret = SSL_connect(ssl.get());
275 if (DBG) {
276 ALOGD("%u SSL_connect returned %d", mMark, ret);
277 }
278 if (ret == 1) break; // SSL handshake complete;
279
280 const int ssl_err = SSL_get_error(ssl.get(), ret);
281 switch (ssl_err) {
282 case SSL_ERROR_WANT_READ:
283 if (waitForReading(fd) != 1) {
284 ALOGW("SSL_connect read error");
285 return nullptr;
286 }
287 break;
288 case SSL_ERROR_WANT_WRITE:
289 if (waitForWriting(fd) != 1) {
290 ALOGW("SSL_connect write error");
291 return nullptr;
292 }
293 break;
294 default:
295 ALOGW("SSL_connect error %d, errno=%d", ssl_err, errno);
296 return nullptr;
297 }
298 }
299
Ben Schwartz52504622017-07-11 12:21:13 -0400300 if (!mServer.fingerprints.empty()) {
Ben Schwartze7601812017-04-28 16:38:29 -0400301 if (DBG) {
302 ALOGD("Checking DNS over TLS fingerprint");
303 }
Ben Schwartzf028d392017-07-10 15:07:12 -0400304
305 // We only care that the chain is internally self-consistent, not that
306 // it chains to a trusted root, so we can ignore some kinds of errors.
307 // TODO: Add a CA root verification mode that respects these errors.
308 int verify_result = SSL_get_verify_result(ssl.get());
309 switch (verify_result) {
310 case X509_V_OK:
311 case X509_V_ERR_DEPTH_ZERO_SELF_SIGNED_CERT:
312 case X509_V_ERR_SELF_SIGNED_CERT_IN_CHAIN:
313 case X509_V_ERR_CERT_UNTRUSTED:
314 break;
315 default:
316 ALOGW("Invalid certificate chain, error %d", verify_result);
317 return nullptr;
318 }
319
320 STACK_OF(X509) *chain = SSL_get_peer_cert_chain(ssl.get());
321 if (!chain) {
Ben Schwartze7601812017-04-28 16:38:29 -0400322 ALOGW("Server has null certificate");
323 return nullptr;
324 }
Ben Schwartzf028d392017-07-10 15:07:12 -0400325 // Chain and its contents are owned by ssl, so we don't need to free explicitly.
326 bool matched = false;
327 for (size_t i = 0; i < sk_X509_num(chain); ++i) {
328 // This appears to be O(N^2), but there doesn't seem to be a straightforward
329 // way to walk a STACK_OF nondestructively in linear time.
330 X509* cert = sk_X509_value(chain, i);
331 std::vector<uint8_t> digest;
332 if (!getSPKIDigest(cert, &digest)) {
333 ALOGE("Digest computation failed");
334 return nullptr;
335 }
336
337 if (mServer.fingerprints.count(digest) > 0) {
338 matched = true;
339 break;
340 }
Ben Schwartze7601812017-04-28 16:38:29 -0400341 }
342
Ben Schwartzf028d392017-07-10 15:07:12 -0400343 if (!matched) {
Ben Schwartze7601812017-04-28 16:38:29 -0400344 ALOGW("No matching fingerprint");
345 return nullptr;
346 }
Ben Schwartzf028d392017-07-10 15:07:12 -0400347
Ben Schwartze7601812017-04-28 16:38:29 -0400348 if (DBG) {
349 ALOGD("DNS over TLS fingerprint is correct");
350 }
351 }
352
353 if (DBG) {
354 ALOGD("%u handshake complete", mMark);
355 }
356 return ssl.release();
357}
358
359bool DnsTlsTransport::sslWrite(int fd, SSL *ssl, const uint8_t *buffer, int len) {
360 if (DBG) {
361 ALOGD("%u Writing %d bytes", mMark, len);
362 }
363 for (;;) {
364 int ret = SSL_write(ssl, buffer, len);
365 if (ret == len) break; // SSL write complete;
366
367 if (ret < 1) {
368 const int ssl_err = SSL_get_error(ssl, ret);
369 switch (ssl_err) {
370 case SSL_ERROR_WANT_WRITE:
371 if (waitForWriting(fd) != 1) {
372 if (DBG) {
373 ALOGW("SSL_write error");
374 }
375 return false;
376 }
377 continue;
378 case 0:
379 break; // SSL write complete;
380 default:
381 if (DBG) {
382 ALOGW("SSL_write error %d", ssl_err);
383 }
384 return false;
385 }
386 }
387 }
388 if (DBG) {
389 ALOGD("%u Wrote %d bytes", mMark, len);
390 }
391 return true;
392}
393
394// Read exactly len bytes into buffer or fail
395bool DnsTlsTransport::sslRead(int fd, SSL *ssl, uint8_t *buffer, int len) {
396 int remaining = len;
397 while (remaining > 0) {
398 int ret = SSL_read(ssl, buffer + (len - remaining), remaining);
399 if (ret == 0) {
400 ALOGE("SSL socket closed with %i of %i bytes remaining", remaining, len);
401 return false;
402 }
403
404 if (ret < 0) {
405 const int ssl_err = SSL_get_error(ssl, ret);
406 if (ssl_err == SSL_ERROR_WANT_READ) {
407 if (waitForReading(fd) != 1) {
408 if (DBG) {
409 ALOGW("SSL_read error");
410 }
411 return false;
412 }
413 continue;
414 } else {
415 if (DBG) {
416 ALOGW("SSL_read error %d", ssl_err);
417 }
418 return false;
419 }
420 }
421
422 remaining -= ret;
423 }
424 return true;
425}
426
Ben Schwartz52504622017-07-11 12:21:13 -0400427// static
428DnsTlsTransport::Response DnsTlsTransport::query(const Server& server, unsigned mark,
429 const uint8_t *query, size_t qlen, uint8_t *response, size_t limit, int *resplen) {
430 // TODO: Keep a static container of transports instead of constructing a new one
431 // for every query.
432 DnsTlsTransport xport(server, mark);
433 return xport.doQuery(query, qlen, response, limit, resplen);
434}
435
Ben Schwartze7601812017-04-28 16:38:29 -0400436DnsTlsTransport::Response DnsTlsTransport::doQuery(const uint8_t *query, size_t qlen,
437 uint8_t *response, size_t limit, int *resplen) {
438 *resplen = 0; // Zero indicates an error.
439
440 if (DBG) {
441 ALOGD("%u connecting TCP socket", mMark);
442 }
443 android::base::unique_fd fd(makeConnectedSocket());
444 if (DBG) {
445 ALOGD("%u connecting SSL", mMark);
446 }
447 bssl::UniquePtr<SSL> ssl(sslConnect(fd));
448 if (ssl == nullptr) {
449 if (DBG) {
450 ALOGW("%u SSL connection failed", mMark);
451 }
452 return Response::network_error;
453 }
454
455 uint8_t queryHeader[2];
456 queryHeader[0] = qlen >> 8;
457 queryHeader[1] = qlen;
458 if (!sslWrite(fd.get(), ssl.get(), queryHeader, 2)) {
459 return Response::network_error;
460 }
461 if (!sslWrite(fd.get(), ssl.get(), query, qlen)) {
462 return Response::network_error;
463 }
464 if (DBG) {
465 ALOGD("%u SSL_write complete", mMark);
466 }
467
468 uint8_t responseHeader[2];
469 if (!sslRead(fd.get(), ssl.get(), responseHeader, 2)) {
470 if (DBG) {
471 ALOGW("%u Failed to read 2-byte length header", mMark);
472 }
473 return Response::network_error;
474 }
475 const uint16_t responseSize = (responseHeader[0] << 8) | responseHeader[1];
476 if (DBG) {
477 ALOGD("%u Expecting response of size %i", mMark, responseSize);
478 }
479 if (responseSize > limit) {
480 ALOGE("%u Response doesn't fit in output buffer: %i", mMark, responseSize);
481 return Response::limit_error;
482 }
483 if (!sslRead(fd.get(), ssl.get(), response, responseSize)) {
484 if (DBG) {
485 ALOGW("%u Failed to read %i bytes", mMark, responseSize);
486 }
487 return Response::network_error;
488 }
489 if (DBG) {
490 ALOGD("%u SSL_read complete", mMark);
491 }
492
493 if (response[0] != query[0] || response[1] != query[1]) {
494 ALOGE("reply query ID != query ID");
495 return Response::internal_error;
496 }
497
498 SSL_shutdown(ssl.get());
499
500 *resplen = responseSize;
501 return Response::success;
502}
503
Ben Schwartz52504622017-07-11 12:21:13 -0400504// static
505bool DnsTlsTransport::validate(const Server& server, unsigned netid) {
Ben Schwartze7601812017-04-28 16:38:29 -0400506 if (DBG) {
507 ALOGD("Beginning validation on %u", netid);
508 }
509 // Generate "<random>-dnsotls-ds.metric.gstatic.com", which we will lookup through |ss| in
510 // order to prove that it is actually a working DNS over TLS server.
511 static const char kDnsSafeChars[] =
512 "abcdefhijklmnopqrstuvwxyz"
513 "ABCDEFHIJKLMNOPQRSTUVWXYZ"
514 "0123456789";
515 const auto c = [](uint8_t rnd) -> uint8_t {
516 return kDnsSafeChars[(rnd % ARRAY_SIZE(kDnsSafeChars))];
517 };
518 uint8_t rnd[8];
519 arc4random_buf(rnd, ARRAY_SIZE(rnd));
520 // We could try to use res_mkquery() here, but it's basically the same.
521 uint8_t query[] = {
522 rnd[6], rnd[7], // [0-1] query ID
523 1, 0, // [2-3] flags; query[2] = 1 for recursion desired (RD).
524 0, 1, // [4-5] QDCOUNT (number of queries)
525 0, 0, // [6-7] ANCOUNT (number of answers)
526 0, 0, // [8-9] NSCOUNT (number of name server records)
527 0, 0, // [10-11] ARCOUNT (number of additional records)
528 17, c(rnd[0]), c(rnd[1]), c(rnd[2]), c(rnd[3]), c(rnd[4]), c(rnd[5]),
529 '-', 'd', 'n', 's', 'o', 't', 'l', 's', '-', 'd', 's',
530 6, 'm', 'e', 't', 'r', 'i', 'c',
531 7, 'g', 's', 't', 'a', 't', 'i', 'c',
532 3, 'c', 'o', 'm',
533 0, // null terminator of FQDN (root TLD)
534 0, ns_t_aaaa, // QTYPE
535 0, ns_c_in // QCLASS
536 };
537 const int qlen = ARRAY_SIZE(query);
538
539 const int kRecvBufSize = 4 * 1024;
540 uint8_t recvbuf[kRecvBufSize];
541
542 // At validation time, we only know the netId, so we have to guess/compute the
543 // corresponding socket mark.
544 Fwmark fwmark;
545 fwmark.permission = PERMISSION_SYSTEM;
546 fwmark.explicitlySelected = true;
547 fwmark.protectedFromVpn = true;
548 fwmark.netId = netid;
549 unsigned mark = fwmark.intValue;
Ben Schwartze7601812017-04-28 16:38:29 -0400550 int replylen = 0;
Ben Schwartz52504622017-07-11 12:21:13 -0400551 DnsTlsTransport::query(server, mark, query, qlen, recvbuf, kRecvBufSize, &replylen);
Ben Schwartze7601812017-04-28 16:38:29 -0400552 if (replylen == 0) {
553 if (DBG) {
Ben Schwartz52504622017-07-11 12:21:13 -0400554 ALOGD("query failed");
Ben Schwartze7601812017-04-28 16:38:29 -0400555 }
556 return false;
557 }
558
559 if (replylen < NS_HFIXEDSZ) {
560 if (DBG) {
561 ALOGW("short response: %d", replylen);
562 }
563 return false;
564 }
565
566 const int qdcount = (recvbuf[4] << 8) | recvbuf[5];
567 if (qdcount != 1) {
568 ALOGW("reply query count != 1: %d", qdcount);
569 return false;
570 }
571
572 const int ancount = (recvbuf[6] << 8) | recvbuf[7];
573 if (DBG) {
574 ALOGD("%u answer count: %d", netid, ancount);
575 }
576
577 // TODO: Further validate the response contents (check for valid AAAA record, ...).
578 // Note that currently, integration tests rely on this function accepting a
579 // response with zero records.
580#if 0
581 for (int i = 0; i < resplen; i++) {
582 ALOGD("recvbuf[%d] = %d %c", i, recvbuf[i], recvbuf[i]);
583 }
584#endif
585 return true;
586}
587
588} // namespace net
589} // namespace android