blob: 4cedac99c85f6b4d456954abd28ccd25d9f49552 [file] [log] [blame]
/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
17#include "dns/DnsTlsTransport.h"
18
19#include <arpa/inet.h>
20#include <arpa/nameser.h>
21#include <errno.h>
22#include <openssl/err.h>
23#include <openssl/ssl.h>
24#include <stdlib.h>
25
26#define LOG_TAG "DnsTlsTransport"
27#define DBG 0
28
29#include "log/log.h"
30#include "Fwmark.h"
31#undef ADD // already defined in nameser.h
32#include "NetdConstants.h"
33#include "Permission.h"
34
35
36namespace android {
37namespace net {
38
39namespace {
40
// Switches O_NONBLOCK on (|enabled| == true) or off for |fd|, leaving every
// other file status flag as it was. Returns false if the current flags cannot
// be read or the updated flags cannot be stored.
bool setNonBlocking(int fd, bool enabled) {
    const int current = fcntl(fd, F_GETFL);
    if (current < 0) {
        return false;
    }

    const int updated = enabled ? (current | O_NONBLOCK) : (current & ~O_NONBLOCK);
    const int rc = fcntl(fd, F_SETFL, updated);
    return rc == 0;
}
52
53int waitForReading(int fd) {
54 fd_set fds;
55 FD_ZERO(&fds);
56 FD_SET(fd, &fds);
57 const int ret = TEMP_FAILURE_RETRY(select(fd + 1, &fds, nullptr, nullptr, nullptr));
58 if (DBG && ret <= 0) {
59 ALOGD("select");
60 }
61 return ret;
62}
63
64int waitForWriting(int fd) {
65 fd_set fds;
66 FD_ZERO(&fds);
67 FD_SET(fd, &fds);
68 const int ret = TEMP_FAILURE_RETRY(select(fd + 1, nullptr, &fds, nullptr, nullptr));
69 if (DBG && ret <= 0) {
70 ALOGD("select");
71 }
72 return ret;
73}
74
75} // namespace
76
77android::base::unique_fd DnsTlsTransport::makeConnectedSocket() const {
78 android::base::unique_fd fd;
79 int type = SOCK_NONBLOCK | SOCK_CLOEXEC;
Ben Schwartz52504622017-07-11 12:21:13 -040080 switch (mServer.protocol) {
Ben Schwartze7601812017-04-28 16:38:29 -040081 case IPPROTO_TCP:
82 type |= SOCK_STREAM;
83 break;
84 default:
85 errno = EPROTONOSUPPORT;
86 return fd;
87 }
88
Ben Schwartz52504622017-07-11 12:21:13 -040089 fd.reset(socket(mServer.ss.ss_family, type, mServer.protocol));
Ben Schwartze7601812017-04-28 16:38:29 -040090 if (fd.get() == -1) {
91 return fd;
92 }
93
94 const socklen_t len = sizeof(mMark);
95 if (setsockopt(fd.get(), SOL_SOCKET, SO_MARK, &mMark, len) == -1) {
96 fd.reset();
97 } else if (connect(fd.get(),
Ben Schwartz52504622017-07-11 12:21:13 -040098 reinterpret_cast<const struct sockaddr *>(&mServer.ss), sizeof(mServer.ss)) != 0
Ben Schwartze7601812017-04-28 16:38:29 -040099 && errno != EINPROGRESS) {
100 fd.reset();
101 }
102
103 return fd;
104}
105
106bool getSPKIDigest(const X509* cert, std::vector<uint8_t>* out) {
107 int spki_len = i2d_X509_PUBKEY(X509_get_X509_PUBKEY(cert), NULL);
108 unsigned char spki[spki_len];
109 unsigned char* temp = spki;
110 if (spki_len != i2d_X509_PUBKEY(X509_get_X509_PUBKEY(cert), &temp)) {
111 ALOGW("SPKI length mismatch");
112 return false;
113 }
114 out->resize(SHA256_SIZE);
115 unsigned int digest_len = 0;
116 int ret = EVP_Digest(spki, spki_len, out->data(), &digest_len, EVP_sha256(), NULL);
117 if (ret != 1) {
118 ALOGW("Server cert digest extraction failed");
119 return false;
120 }
121 if (digest_len != out->size()) {
122 ALOGW("Wrong digest length: %d", digest_len);
123 return false;
124 }
125 return true;
126}
127
128SSL* DnsTlsTransport::sslConnect(int fd) {
129 if (fd < 0) {
130 ALOGD("%u makeConnectedSocket() failed with: %s", mMark, strerror(errno));
131 return nullptr;
132 }
133
134 // Set up TLS context.
135 bssl::UniquePtr<SSL_CTX> ssl_ctx(SSL_CTX_new(TLS_method()));
136 if (!SSL_CTX_set_max_proto_version(ssl_ctx.get(), TLS1_3_VERSION) ||
137 !SSL_CTX_set_min_proto_version(ssl_ctx.get(), TLS1_1_VERSION)) {
138 ALOGD("failed to min/max TLS versions");
139 return nullptr;
140 }
141
142 bssl::UniquePtr<SSL> ssl(SSL_new(ssl_ctx.get()));
143 bssl::UniquePtr<BIO> bio(BIO_new_socket(fd, BIO_CLOSE));
144 SSL_set_bio(ssl.get(), bio.get(), bio.get());
145 bio.release();
146
147 if (!setNonBlocking(fd, false)) {
148 ALOGE("Failed to disable nonblocking status on DNS-over-TLS fd");
149 return nullptr;
150 }
151
Ben Schwartz1691bc42017-08-16 12:53:09 -0400152 if (!mServer.name.empty()) {
153 if (SSL_set_tlsext_host_name(ssl.get(), mServer.name.c_str()) != 1) {
154 ALOGE("Failed to set SNI to %s", mServer.name.c_str());
155 return nullptr;
156 }
157 X509_VERIFY_PARAM* param = SSL_get0_param(ssl.get());
158 X509_VERIFY_PARAM_set1_host(param, mServer.name.c_str(), 0);
159 // This will cause the handshake to fail if certificate verification fails.
160 SSL_set_verify(ssl.get(), SSL_VERIFY_PEER, nullptr);
161 }
162
Ben Schwartze7601812017-04-28 16:38:29 -0400163 for (;;) {
164 if (DBG) {
165 ALOGD("%u Calling SSL_connect", mMark);
166 }
167 int ret = SSL_connect(ssl.get());
168 if (DBG) {
169 ALOGD("%u SSL_connect returned %d", mMark, ret);
170 }
171 if (ret == 1) break; // SSL handshake complete;
172
173 const int ssl_err = SSL_get_error(ssl.get(), ret);
174 switch (ssl_err) {
175 case SSL_ERROR_WANT_READ:
176 if (waitForReading(fd) != 1) {
177 ALOGW("SSL_connect read error");
178 return nullptr;
179 }
180 break;
181 case SSL_ERROR_WANT_WRITE:
182 if (waitForWriting(fd) != 1) {
183 ALOGW("SSL_connect write error");
184 return nullptr;
185 }
186 break;
187 default:
188 ALOGW("SSL_connect error %d, errno=%d", ssl_err, errno);
189 return nullptr;
190 }
191 }
192
Ben Schwartz52504622017-07-11 12:21:13 -0400193 if (!mServer.fingerprints.empty()) {
Ben Schwartze7601812017-04-28 16:38:29 -0400194 if (DBG) {
195 ALOGD("Checking DNS over TLS fingerprint");
196 }
Ben Schwartzf028d392017-07-10 15:07:12 -0400197
198 // We only care that the chain is internally self-consistent, not that
199 // it chains to a trusted root, so we can ignore some kinds of errors.
200 // TODO: Add a CA root verification mode that respects these errors.
201 int verify_result = SSL_get_verify_result(ssl.get());
202 switch (verify_result) {
203 case X509_V_OK:
204 case X509_V_ERR_DEPTH_ZERO_SELF_SIGNED_CERT:
205 case X509_V_ERR_SELF_SIGNED_CERT_IN_CHAIN:
206 case X509_V_ERR_CERT_UNTRUSTED:
207 break;
208 default:
209 ALOGW("Invalid certificate chain, error %d", verify_result);
210 return nullptr;
211 }
212
213 STACK_OF(X509) *chain = SSL_get_peer_cert_chain(ssl.get());
214 if (!chain) {
Ben Schwartze7601812017-04-28 16:38:29 -0400215 ALOGW("Server has null certificate");
216 return nullptr;
217 }
Ben Schwartzf028d392017-07-10 15:07:12 -0400218 // Chain and its contents are owned by ssl, so we don't need to free explicitly.
219 bool matched = false;
220 for (size_t i = 0; i < sk_X509_num(chain); ++i) {
221 // This appears to be O(N^2), but there doesn't seem to be a straightforward
222 // way to walk a STACK_OF nondestructively in linear time.
223 X509* cert = sk_X509_value(chain, i);
224 std::vector<uint8_t> digest;
225 if (!getSPKIDigest(cert, &digest)) {
226 ALOGE("Digest computation failed");
227 return nullptr;
228 }
229
230 if (mServer.fingerprints.count(digest) > 0) {
231 matched = true;
232 break;
233 }
Ben Schwartze7601812017-04-28 16:38:29 -0400234 }
235
Ben Schwartzf028d392017-07-10 15:07:12 -0400236 if (!matched) {
Ben Schwartze7601812017-04-28 16:38:29 -0400237 ALOGW("No matching fingerprint");
238 return nullptr;
239 }
Ben Schwartzf028d392017-07-10 15:07:12 -0400240
Ben Schwartze7601812017-04-28 16:38:29 -0400241 if (DBG) {
242 ALOGD("DNS over TLS fingerprint is correct");
243 }
244 }
245
246 if (DBG) {
247 ALOGD("%u handshake complete", mMark);
248 }
249 return ssl.release();
250}
251
252bool DnsTlsTransport::sslWrite(int fd, SSL *ssl, const uint8_t *buffer, int len) {
253 if (DBG) {
254 ALOGD("%u Writing %d bytes", mMark, len);
255 }
256 for (;;) {
257 int ret = SSL_write(ssl, buffer, len);
258 if (ret == len) break; // SSL write complete;
259
260 if (ret < 1) {
261 const int ssl_err = SSL_get_error(ssl, ret);
262 switch (ssl_err) {
263 case SSL_ERROR_WANT_WRITE:
264 if (waitForWriting(fd) != 1) {
265 if (DBG) {
266 ALOGW("SSL_write error");
267 }
268 return false;
269 }
270 continue;
271 case 0:
272 break; // SSL write complete;
273 default:
274 if (DBG) {
275 ALOGW("SSL_write error %d", ssl_err);
276 }
277 return false;
278 }
279 }
280 }
281 if (DBG) {
282 ALOGD("%u Wrote %d bytes", mMark, len);
283 }
284 return true;
285}
286
287// Read exactly len bytes into buffer or fail
288bool DnsTlsTransport::sslRead(int fd, SSL *ssl, uint8_t *buffer, int len) {
289 int remaining = len;
290 while (remaining > 0) {
291 int ret = SSL_read(ssl, buffer + (len - remaining), remaining);
292 if (ret == 0) {
293 ALOGE("SSL socket closed with %i of %i bytes remaining", remaining, len);
294 return false;
295 }
296
297 if (ret < 0) {
298 const int ssl_err = SSL_get_error(ssl, ret);
299 if (ssl_err == SSL_ERROR_WANT_READ) {
300 if (waitForReading(fd) != 1) {
301 if (DBG) {
302 ALOGW("SSL_read error");
303 }
304 return false;
305 }
306 continue;
307 } else {
308 if (DBG) {
309 ALOGW("SSL_read error %d", ssl_err);
310 }
311 return false;
312 }
313 }
314
315 remaining -= ret;
316 }
317 return true;
318}
319
Ben Schwartz52504622017-07-11 12:21:13 -0400320// static
321DnsTlsTransport::Response DnsTlsTransport::query(const Server& server, unsigned mark,
322 const uint8_t *query, size_t qlen, uint8_t *response, size_t limit, int *resplen) {
323 // TODO: Keep a static container of transports instead of constructing a new one
324 // for every query.
325 DnsTlsTransport xport(server, mark);
326 return xport.doQuery(query, qlen, response, limit, resplen);
327}
328
Ben Schwartze7601812017-04-28 16:38:29 -0400329DnsTlsTransport::Response DnsTlsTransport::doQuery(const uint8_t *query, size_t qlen,
330 uint8_t *response, size_t limit, int *resplen) {
331 *resplen = 0; // Zero indicates an error.
332
333 if (DBG) {
334 ALOGD("%u connecting TCP socket", mMark);
335 }
336 android::base::unique_fd fd(makeConnectedSocket());
337 if (DBG) {
338 ALOGD("%u connecting SSL", mMark);
339 }
340 bssl::UniquePtr<SSL> ssl(sslConnect(fd));
341 if (ssl == nullptr) {
342 if (DBG) {
343 ALOGW("%u SSL connection failed", mMark);
344 }
345 return Response::network_error;
346 }
347
348 uint8_t queryHeader[2];
349 queryHeader[0] = qlen >> 8;
350 queryHeader[1] = qlen;
351 if (!sslWrite(fd.get(), ssl.get(), queryHeader, 2)) {
352 return Response::network_error;
353 }
354 if (!sslWrite(fd.get(), ssl.get(), query, qlen)) {
355 return Response::network_error;
356 }
357 if (DBG) {
358 ALOGD("%u SSL_write complete", mMark);
359 }
360
361 uint8_t responseHeader[2];
362 if (!sslRead(fd.get(), ssl.get(), responseHeader, 2)) {
363 if (DBG) {
364 ALOGW("%u Failed to read 2-byte length header", mMark);
365 }
366 return Response::network_error;
367 }
368 const uint16_t responseSize = (responseHeader[0] << 8) | responseHeader[1];
369 if (DBG) {
370 ALOGD("%u Expecting response of size %i", mMark, responseSize);
371 }
372 if (responseSize > limit) {
373 ALOGE("%u Response doesn't fit in output buffer: %i", mMark, responseSize);
374 return Response::limit_error;
375 }
376 if (!sslRead(fd.get(), ssl.get(), response, responseSize)) {
377 if (DBG) {
378 ALOGW("%u Failed to read %i bytes", mMark, responseSize);
379 }
380 return Response::network_error;
381 }
382 if (DBG) {
383 ALOGD("%u SSL_read complete", mMark);
384 }
385
386 if (response[0] != query[0] || response[1] != query[1]) {
387 ALOGE("reply query ID != query ID");
388 return Response::internal_error;
389 }
390
391 SSL_shutdown(ssl.get());
392
393 *resplen = responseSize;
394 return Response::success;
395}
396
Ben Schwartz52504622017-07-11 12:21:13 -0400397// static
398bool DnsTlsTransport::validate(const Server& server, unsigned netid) {
Ben Schwartze7601812017-04-28 16:38:29 -0400399 if (DBG) {
400 ALOGD("Beginning validation on %u", netid);
401 }
402 // Generate "<random>-dnsotls-ds.metric.gstatic.com", which we will lookup through |ss| in
403 // order to prove that it is actually a working DNS over TLS server.
404 static const char kDnsSafeChars[] =
405 "abcdefhijklmnopqrstuvwxyz"
406 "ABCDEFHIJKLMNOPQRSTUVWXYZ"
407 "0123456789";
408 const auto c = [](uint8_t rnd) -> uint8_t {
409 return kDnsSafeChars[(rnd % ARRAY_SIZE(kDnsSafeChars))];
410 };
411 uint8_t rnd[8];
412 arc4random_buf(rnd, ARRAY_SIZE(rnd));
413 // We could try to use res_mkquery() here, but it's basically the same.
414 uint8_t query[] = {
415 rnd[6], rnd[7], // [0-1] query ID
416 1, 0, // [2-3] flags; query[2] = 1 for recursion desired (RD).
417 0, 1, // [4-5] QDCOUNT (number of queries)
418 0, 0, // [6-7] ANCOUNT (number of answers)
419 0, 0, // [8-9] NSCOUNT (number of name server records)
420 0, 0, // [10-11] ARCOUNT (number of additional records)
421 17, c(rnd[0]), c(rnd[1]), c(rnd[2]), c(rnd[3]), c(rnd[4]), c(rnd[5]),
422 '-', 'd', 'n', 's', 'o', 't', 'l', 's', '-', 'd', 's',
423 6, 'm', 'e', 't', 'r', 'i', 'c',
424 7, 'g', 's', 't', 'a', 't', 'i', 'c',
425 3, 'c', 'o', 'm',
426 0, // null terminator of FQDN (root TLD)
427 0, ns_t_aaaa, // QTYPE
428 0, ns_c_in // QCLASS
429 };
430 const int qlen = ARRAY_SIZE(query);
431
432 const int kRecvBufSize = 4 * 1024;
433 uint8_t recvbuf[kRecvBufSize];
434
435 // At validation time, we only know the netId, so we have to guess/compute the
436 // corresponding socket mark.
437 Fwmark fwmark;
438 fwmark.permission = PERMISSION_SYSTEM;
439 fwmark.explicitlySelected = true;
440 fwmark.protectedFromVpn = true;
441 fwmark.netId = netid;
442 unsigned mark = fwmark.intValue;
Ben Schwartze7601812017-04-28 16:38:29 -0400443 int replylen = 0;
Ben Schwartz52504622017-07-11 12:21:13 -0400444 DnsTlsTransport::query(server, mark, query, qlen, recvbuf, kRecvBufSize, &replylen);
Ben Schwartze7601812017-04-28 16:38:29 -0400445 if (replylen == 0) {
446 if (DBG) {
Ben Schwartz52504622017-07-11 12:21:13 -0400447 ALOGD("query failed");
Ben Schwartze7601812017-04-28 16:38:29 -0400448 }
449 return false;
450 }
451
452 if (replylen < NS_HFIXEDSZ) {
453 if (DBG) {
454 ALOGW("short response: %d", replylen);
455 }
456 return false;
457 }
458
459 const int qdcount = (recvbuf[4] << 8) | recvbuf[5];
460 if (qdcount != 1) {
461 ALOGW("reply query count != 1: %d", qdcount);
462 return false;
463 }
464
465 const int ancount = (recvbuf[6] << 8) | recvbuf[7];
466 if (DBG) {
467 ALOGD("%u answer count: %d", netid, ancount);
468 }
469
470 // TODO: Further validate the response contents (check for valid AAAA record, ...).
471 // Note that currently, integration tests rely on this function accepting a
472 // response with zero records.
473#if 0
474 for (int i = 0; i < resplen; i++) {
475 ALOGD("recvbuf[%d] = %d %c", i, recvbuf[i], recvbuf[i]);
476 }
477#endif
478 return true;
479}
480
481} // namespace net
482} // namespace android