blob: 5e1cd64eee35f84c0682b1cc482500810c8fa3a7 [file] [log] [blame]
Ben Schwartzded1b702017-10-25 14:41:02 -04001/*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#define LOG_TAG "DnsTlsSocket"
Ben Schwartz33860762017-10-25 14:41:02 -040018//#define LOG_NDEBUG 0
Ben Schwartzded1b702017-10-25 14:41:02 -040019
#include "dns/DnsTlsSocket.h"

#include <algorithm>
#include <arpa/inet.h>
#include <arpa/nameser.h>
#include <chrono>
#include <errno.h>
#include <linux/tcp.h>
#include <openssl/err.h>
#include <poll.h>
#include <sys/select.h>
#include <sys/socket.h>

#include "dns/DnsTlsSessionCache.h"
#include "dns/IDnsTlsSocketObserver.h"

#include "log/log.h"
#include "netdutils/SocketOption.h"
#include "Fwmark.h"
#undef ADD // already defined in nameser.h
#include "NetdConstants.h"
#include "Permission.h"
39
40
41namespace android {
Ben Schwartzded1b702017-10-25 14:41:02 -040042
Erik Klined1503072018-02-22 23:55:40 -080043using netdutils::enableSockopt;
44using netdutils::enableTcpKeepAlives;
45using netdutils::isOk;
Ben Schwartzded1b702017-10-25 14:41:02 -040046using netdutils::Status;
47
Erik Klined1503072018-02-22 23:55:40 -080048namespace net {
Ben Schwartzded1b702017-10-25 14:41:02 -040049namespace {
50
51constexpr const char kCaCertDir[] = "/system/etc/security/cacerts";
52
53int waitForReading(int fd) {
54 fd_set fds;
55 FD_ZERO(&fds);
56 FD_SET(fd, &fds);
57 const int ret = TEMP_FAILURE_RETRY(select(fd + 1, &fds, nullptr, nullptr, nullptr));
58 ALOGV_IF(ret <= 0, "select failed during read");
59 return ret;
60}
61
62int waitForWriting(int fd) {
63 fd_set fds;
64 FD_ZERO(&fds);
65 FD_SET(fd, &fds);
66 const int ret = TEMP_FAILURE_RETRY(select(fd + 1, nullptr, &fds, nullptr, nullptr));
67 ALOGV_IF(ret <= 0, "select failed during write");
68 return ret;
69}
70
71} // namespace
72
73Status DnsTlsSocket::tcpConnect() {
74 ALOGV("%u connecting TCP socket", mMark);
75 int type = SOCK_NONBLOCK | SOCK_CLOEXEC;
76 switch (mServer.protocol) {
77 case IPPROTO_TCP:
78 type |= SOCK_STREAM;
79 break;
80 default:
81 return Status(EPROTONOSUPPORT);
82 }
83
84 mSslFd.reset(socket(mServer.ss.ss_family, type, mServer.protocol));
85 if (mSslFd.get() == -1) {
86 ALOGE("Failed to create socket");
87 return Status(errno);
88 }
89
90 const socklen_t len = sizeof(mMark);
91 if (setsockopt(mSslFd.get(), SOL_SOCKET, SO_MARK, &mMark, len) == -1) {
92 ALOGE("Failed to set socket mark");
93 mSslFd.reset();
94 return Status(errno);
95 }
Erik Klined1503072018-02-22 23:55:40 -080096
97 const Status tfo = enableSockopt(mSslFd.get(), SOL_TCP, TCP_FASTOPEN_CONNECT);
98 if (!isOk(tfo) && tfo.code() != ENOPROTOOPT) {
99 ALOGI("Failed to enable TFO: %s", tfo.msg().c_str());
100 }
101
102 // Send 5 keepalives, 3 seconds apart, after 15 seconds of inactivity.
103 enableTcpKeepAlives(mSslFd.get(), 15U, 5U, 3U);
104
Ben Schwartzded1b702017-10-25 14:41:02 -0400105 if (connect(mSslFd.get(), reinterpret_cast<const struct sockaddr *>(&mServer.ss),
106 sizeof(mServer.ss)) != 0 &&
107 errno != EINPROGRESS) {
108 ALOGV("Socket failed to connect");
109 mSslFd.reset();
110 return Status(errno);
111 }
112
113 return netdutils::status::ok;
114}
115
116bool getSPKIDigest(const X509* cert, std::vector<uint8_t>* out) {
117 int spki_len = i2d_X509_PUBKEY(X509_get_X509_PUBKEY(cert), NULL);
118 unsigned char spki[spki_len];
119 unsigned char* temp = spki;
120 if (spki_len != i2d_X509_PUBKEY(X509_get_X509_PUBKEY(cert), &temp)) {
121 ALOGW("SPKI length mismatch");
122 return false;
123 }
124 out->resize(SHA256_SIZE);
125 unsigned int digest_len = 0;
126 int ret = EVP_Digest(spki, spki_len, out->data(), &digest_len, EVP_sha256(), NULL);
127 if (ret != 1) {
128 ALOGW("Server cert digest extraction failed");
129 return false;
130 }
131 if (digest_len != out->size()) {
132 ALOGW("Wrong digest length: %d", digest_len);
133 return false;
134 }
135 return true;
136}
137
138bool DnsTlsSocket::initialize() {
139 // This method should only be called once, at the beginning, so locking should be
140 // unnecessary. This lock only serves to help catch bugs in code that calls this method.
141 std::lock_guard<std::mutex> guard(mLock);
142 if (mSslCtx) {
143 // This is a bug in the caller.
144 return false;
145 }
146 mSslCtx.reset(SSL_CTX_new(TLS_method()));
147 if (!mSslCtx) {
148 return false;
149 }
150
151 // Load system CA certs for hostname verification.
152 //
153 // For discussion of alternative, sustainable approaches see b/71909242.
154 if (SSL_CTX_load_verify_locations(mSslCtx.get(), nullptr, kCaCertDir) != 1) {
155 ALOGE("Failed to load CA cert dir: %s", kCaCertDir);
156 return false;
157 }
158
159 // Enable TLS false start
160 SSL_CTX_set_false_start_allowed_without_alpn(mSslCtx.get(), 1);
161 SSL_CTX_set_mode(mSslCtx.get(), SSL_MODE_ENABLE_FALSE_START);
162
163 // Enable session cache
164 mCache->prepareSslContext(mSslCtx.get());
165
166 // Connect
167 Status status = tcpConnect();
168 if (!status.ok()) {
169 return false;
170 }
171 mSsl = sslConnect(mSslFd.get());
172 if (!mSsl) {
173 return false;
174 }
Ben Schwartz33860762017-10-25 14:41:02 -0400175 int sv[2];
176 if (socketpair(AF_LOCAL, SOCK_SEQPACKET, 0, sv)) {
177 return false;
178 }
179 // The two sockets are perfectly symmetrical, so the choice of which one is
180 // "in" and which one is "out" is arbitrary.
181 mIpcInFd.reset(sv[0]);
182 mIpcOutFd.reset(sv[1]);
183
184 // Start the I/O loop.
185 mLoopThread.reset(new std::thread(&DnsTlsSocket::loop, this));
Ben Schwartzded1b702017-10-25 14:41:02 -0400186
187 return true;
188}
189
190bssl::UniquePtr<SSL> DnsTlsSocket::sslConnect(int fd) {
191 if (!mSslCtx) {
192 ALOGE("Internal error: context is null in sslConnect");
193 return nullptr;
194 }
195 if (!SSL_CTX_set_min_proto_version(mSslCtx.get(), TLS1_2_VERSION)) {
196 ALOGE("Failed to set minimum TLS version");
197 return nullptr;
198 }
199
200 bssl::UniquePtr<SSL> ssl(SSL_new(mSslCtx.get()));
201 // This file descriptor is owned by mSslFd, so don't let libssl close it.
202 bssl::UniquePtr<BIO> bio(BIO_new_socket(fd, BIO_NOCLOSE));
203 SSL_set_bio(ssl.get(), bio.get(), bio.get());
204 bio.release();
205
206 if (!mCache->prepareSsl(ssl.get())) {
207 return nullptr;
208 }
209
210 if (!mServer.name.empty()) {
211 if (SSL_set_tlsext_host_name(ssl.get(), mServer.name.c_str()) != 1) {
212 ALOGE("Failed to set SNI to %s", mServer.name.c_str());
213 return nullptr;
214 }
215 X509_VERIFY_PARAM* param = SSL_get0_param(ssl.get());
216 X509_VERIFY_PARAM_set1_host(param, mServer.name.c_str(), 0);
217 // This will cause the handshake to fail if certificate verification fails.
218 SSL_set_verify(ssl.get(), SSL_VERIFY_PEER, nullptr);
219 }
220
221 bssl::UniquePtr<SSL_SESSION> session = mCache->getSession();
222 if (session) {
223 ALOGV("Setting session");
224 SSL_set_session(ssl.get(), session.get());
225 } else {
226 ALOGV("No session available");
227 }
228
229 for (;;) {
230 ALOGV("%u Calling SSL_connect", mMark);
231 int ret = SSL_connect(ssl.get());
232 ALOGV("%u SSL_connect returned %d", mMark, ret);
233 if (ret == 1) break; // SSL handshake complete;
234
235 const int ssl_err = SSL_get_error(ssl.get(), ret);
236 switch (ssl_err) {
237 case SSL_ERROR_WANT_READ:
238 if (waitForReading(fd) != 1) {
239 ALOGW("SSL_connect read error");
240 return nullptr;
241 }
242 break;
243 case SSL_ERROR_WANT_WRITE:
244 if (waitForWriting(fd) != 1) {
245 ALOGW("SSL_connect write error");
246 return nullptr;
247 }
248 break;
249 default:
250 ALOGW("SSL_connect error %d, errno=%d", ssl_err, errno);
251 return nullptr;
252 }
253 }
254
255 // TODO: Call SSL_shutdown before discarding the session if validation fails.
256 if (!mServer.fingerprints.empty()) {
257 ALOGV("Checking DNS over TLS fingerprint");
258
259 // We only care that the chain is internally self-consistent, not that
260 // it chains to a trusted root, so we can ignore some kinds of errors.
261 // TODO: Add a CA root verification mode that respects these errors.
262 int verify_result = SSL_get_verify_result(ssl.get());
263 switch (verify_result) {
264 case X509_V_OK:
265 case X509_V_ERR_DEPTH_ZERO_SELF_SIGNED_CERT:
266 case X509_V_ERR_SELF_SIGNED_CERT_IN_CHAIN:
267 case X509_V_ERR_CERT_UNTRUSTED:
268 break;
269 default:
270 ALOGW("Invalid certificate chain, error %d", verify_result);
271 return nullptr;
272 }
273
274 STACK_OF(X509) *chain = SSL_get_peer_cert_chain(ssl.get());
275 if (!chain) {
276 ALOGW("Server has null certificate");
277 return nullptr;
278 }
279 // Chain and its contents are owned by ssl, so we don't need to free explicitly.
280 bool matched = false;
281 for (size_t i = 0; i < sk_X509_num(chain); ++i) {
282 // This appears to be O(N^2), but there doesn't seem to be a straightforward
283 // way to walk a STACK_OF nondestructively in linear time.
284 X509* cert = sk_X509_value(chain, i);
285 std::vector<uint8_t> digest;
286 if (!getSPKIDigest(cert, &digest)) {
287 ALOGE("Digest computation failed");
288 return nullptr;
289 }
290
291 if (mServer.fingerprints.count(digest) > 0) {
292 matched = true;
293 break;
294 }
295 }
296
297 if (!matched) {
298 ALOGW("No matching fingerprint");
299 return nullptr;
300 }
301
302 ALOGV("DNS over TLS fingerprint is correct");
303 }
304
305 ALOGV("%u handshake complete", mMark);
306
307 return ssl;
308}
309
310void DnsTlsSocket::sslDisconnect() {
311 if (mSsl) {
312 SSL_shutdown(mSsl.get());
313 mSsl.reset();
314 }
315 mSslFd.reset();
316}
317
318bool DnsTlsSocket::sslWrite(const Slice buffer) {
319 ALOGV("%u Writing %zu bytes", mMark, buffer.size());
320 for (;;) {
321 int ret = SSL_write(mSsl.get(), buffer.base(), buffer.size());
322 if (ret == int(buffer.size())) break; // SSL write complete;
323
324 if (ret < 1) {
325 const int ssl_err = SSL_get_error(mSsl.get(), ret);
326 switch (ssl_err) {
327 case SSL_ERROR_WANT_WRITE:
328 if (waitForWriting(mSslFd.get()) != 1) {
329 ALOGV("SSL_write error");
330 return false;
331 }
332 continue;
333 case 0:
334 break; // SSL write complete;
335 default:
336 ALOGV("SSL_write error %d", ssl_err);
337 return false;
338 }
339 }
340 }
341 ALOGV("%u Wrote %zu bytes", mMark, buffer.size());
342 return true;
343}
344
Ben Schwartz33860762017-10-25 14:41:02 -0400345void DnsTlsSocket::loop() {
Ben Schwartzded1b702017-10-25 14:41:02 -0400346 std::lock_guard<std::mutex> guard(mLock);
Ben Schwartz33860762017-10-25 14:41:02 -0400347 // Buffer at most one query.
348 Query q;
349
350 fd_set readFds, writeFds;
351 FD_ZERO(&readFds);
352 FD_ZERO(&writeFds);
353 const int maxFd = std::max(mSslFd.get(), mIpcOutFd.get());
354 while (true) {
355 timeval timeout = { .tv_sec = DnsTlsSocket::kIdleTimeout.count() };
356 // Always listen for a response from server.
357 FD_SET(mSslFd.get(), &readFds);
358 // If we have a pending query, also wait for space
359 // to write it, otherwise listen for a new query.
360 if (!q.query.empty()) {
361 FD_SET(mSslFd.get(), &writeFds);
362 FD_CLR(mIpcOutFd.get(), &readFds);
363 } else {
364 FD_CLR(mSslFd.get(), &writeFds);
365 FD_SET(mIpcOutFd.get(), &readFds);
366 }
367 // Deviating from POSIX, Linux will decrement the timeout on each retry.
368 // Either behavior is OK here.
369 const int s = TEMP_FAILURE_RETRY(select(maxFd + 1, &readFds, &writeFds, nullptr, &timeout));
370 if (s == 0) {
371 ALOGV("Idle timeout");
372 break;
373 }
374 if (s < 0) {
375 ALOGV("Select failed: %d", errno);
376 break;
377 }
378 if (FD_ISSET(mSslFd.get(), &readFds)) {
379 if (!readResponse()) {
380 ALOGV("SSL remote close or read error.");
381 break;
382 }
383 }
384 if (FD_ISSET(mIpcOutFd.get(), &readFds)) {
385 int res = read(mIpcOutFd.get(), &q, sizeof(q));
386 if (res < 0) {
387 ALOGW("Error during IPC read");
388 break;
389 } else if (res == 0) {
390 ALOGV("IPC channel closed; disconnecting");
391 break;
392 } else if (res != sizeof(q)) {
393 ALOGE("Struct size mismatch: %d != %zu", res, sizeof(q));
394 break;
395 }
396 } else if (FD_ISSET(mSslFd.get(), &writeFds)) {
397 // query cannot be null here.
398 if (!sendQuery(q)) {
399 break;
400 }
401 q = Query(); // Reset q to empty
402 }
Ben Schwartzded1b702017-10-25 14:41:02 -0400403 }
Ben Schwartz33860762017-10-25 14:41:02 -0400404 ALOGV("Closing IPC read FD");
405 mIpcOutFd.reset();
406 ALOGV("Disconnecting");
407 sslDisconnect();
408 ALOGV("Calling onClosed");
409 mObserver->onClosed();
410 ALOGV("Ending loop");
Ben Schwartzded1b702017-10-25 14:41:02 -0400411}
412
Ben Schwartz33860762017-10-25 14:41:02 -0400413DnsTlsSocket::~DnsTlsSocket() {
414 ALOGV("Destructor");
415 // This will trigger an orderly shutdown in loop().
416 mIpcInFd.reset();
417 {
418 // Wait for the orderly shutdown to complete.
419 std::lock_guard<std::mutex> guard(mLock);
420 if (mLoopThread && std::this_thread::get_id() == mLoopThread->get_id()) {
421 ALOGE("Violation of re-entrance precondition");
422 return;
423 }
424 }
425 if (mLoopThread) {
426 ALOGV("Waiting for loop thread to terminate");
427 mLoopThread->join();
428 mLoopThread.reset();
429 }
430 ALOGV("Destructor completed");
431}
432
433bool DnsTlsSocket::query(uint16_t id, const Slice query) {
434 const Query q = { .id = id, .query = query };
435 if (!mIpcInFd) {
436 return false;
437 }
438 int written = write(mIpcInFd.get(), &q, sizeof(q));
439 return written == sizeof(q);
440}
441
442// Read exactly len bytes into buffer or fail with an SSL error code
443int DnsTlsSocket::sslRead(const Slice buffer, bool wait) {
Ben Schwartzded1b702017-10-25 14:41:02 -0400444 size_t remaining = buffer.size();
445 while (remaining > 0) {
446 int ret = SSL_read(mSsl.get(), buffer.limit() - remaining, remaining);
447 if (ret == 0) {
448 ALOGW_IF(remaining < buffer.size(), "SSL closed with %zu of %zu bytes remaining",
449 remaining, buffer.size());
Ben Schwartz33860762017-10-25 14:41:02 -0400450 return SSL_ERROR_ZERO_RETURN;
Ben Schwartzded1b702017-10-25 14:41:02 -0400451 }
452
453 if (ret < 0) {
454 const int ssl_err = SSL_get_error(mSsl.get(), ret);
Ben Schwartz33860762017-10-25 14:41:02 -0400455 if (wait && ssl_err == SSL_ERROR_WANT_READ) {
Ben Schwartzded1b702017-10-25 14:41:02 -0400456 if (waitForReading(mSslFd.get()) != 1) {
Ben Schwartz33860762017-10-25 14:41:02 -0400457 ALOGV("Select failed in sslRead");
458 return SSL_ERROR_SYSCALL;
Ben Schwartzded1b702017-10-25 14:41:02 -0400459 }
460 continue;
461 } else {
462 ALOGV("SSL_read error %d", ssl_err);
Ben Schwartz33860762017-10-25 14:41:02 -0400463 return ssl_err;
Ben Schwartzded1b702017-10-25 14:41:02 -0400464 }
465 }
466
467 remaining -= ret;
Ben Schwartz33860762017-10-25 14:41:02 -0400468 wait = true; // Once a read is started, try to finish.
Ben Schwartzded1b702017-10-25 14:41:02 -0400469 }
Ben Schwartz33860762017-10-25 14:41:02 -0400470 return SSL_ERROR_NONE;
Ben Schwartzded1b702017-10-25 14:41:02 -0400471}
472
473bool DnsTlsSocket::sendQuery(const Query& q) {
474 ALOGV("sending query");
475 // Compose the entire message in a single buffer, so that it can be
476 // sent as a single TLS record.
477 std::vector<uint8_t> buf(q.query.size() + 4);
478 // Write 2-byte length
479 uint16_t len = q.query.size() + 2; // + 2 for the ID.
480 buf[0] = len >> 8;
481 buf[1] = len;
482 // Write 2-byte ID
483 buf[2] = q.id >> 8;
484 buf[3] = q.id;
485 // Copy body
486 std::memcpy(buf.data() + 4, q.query.base(), q.query.size());
487 if (!sslWrite(netdutils::makeSlice(buf))) {
488 return false;
489 }
490 ALOGV("%u SSL_write complete", mMark);
491 return true;
492}
493
Ben Schwartz33860762017-10-25 14:41:02 -0400494bool DnsTlsSocket::readResponse() {
Ben Schwartzded1b702017-10-25 14:41:02 -0400495 ALOGV("reading response");
496 uint8_t responseHeader[2];
Ben Schwartz33860762017-10-25 14:41:02 -0400497 int err = sslRead(Slice(responseHeader, 2), false);
498 if (err == SSL_ERROR_WANT_READ) {
499 ALOGV("Ignoring spurious wakeup from server");
500 return true;
501 }
502 if (err != SSL_ERROR_NONE) {
503 return false;
Ben Schwartzded1b702017-10-25 14:41:02 -0400504 }
505 // Truncate responses larger than MAX_SIZE. This is safe because a DNS packet is
506 // always invalid when truncated, so the response will be treated as an error.
507 constexpr uint16_t MAX_SIZE = 8192;
508 const uint16_t responseSize = (responseHeader[0] << 8) | responseHeader[1];
509 ALOGV("%u Expecting response of size %i", mMark, responseSize);
510 std::vector<uint8_t> response(std::min(responseSize, MAX_SIZE));
Ben Schwartz33860762017-10-25 14:41:02 -0400511 if (sslRead(netdutils::makeSlice(response), true) != SSL_ERROR_NONE) {
Ben Schwartzded1b702017-10-25 14:41:02 -0400512 ALOGV("%u Failed to read %zu bytes", mMark, response.size());
Ben Schwartz33860762017-10-25 14:41:02 -0400513 return false;
Ben Schwartzded1b702017-10-25 14:41:02 -0400514 }
515 uint16_t remainingBytes = responseSize - response.size();
516 while (remainingBytes > 0) {
517 constexpr uint16_t CHUNK_SIZE = 2048;
518 std::vector<uint8_t> discard(std::min(remainingBytes, CHUNK_SIZE));
Ben Schwartz33860762017-10-25 14:41:02 -0400519 if (sslRead(netdutils::makeSlice(discard), true) != SSL_ERROR_NONE) {
Ben Schwartzded1b702017-10-25 14:41:02 -0400520 ALOGV("%u Failed to discard %zu bytes", mMark, discard.size());
Ben Schwartz33860762017-10-25 14:41:02 -0400521 return false;
Ben Schwartzded1b702017-10-25 14:41:02 -0400522 }
523 remainingBytes -= discard.size();
524 }
525 ALOGV("%u SSL_read complete", mMark);
526
Ben Schwartz33860762017-10-25 14:41:02 -0400527 mObserver->onResponse(std::move(response));
528 return true;
Ben Schwartzded1b702017-10-25 14:41:02 -0400529}
530
531} // end of namespace net
532} // end of namespace android