blob: 14aadd4a54a19abd732852f98ac68d496d5a7214 [file] [log] [blame]
Mike Yubab3daa2018-10-19 22:11:43 +08001/*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
Ken Chen5471dca2019-04-15 15:25:35 +080017#define LOG_TAG "resolv"
Mike Yubab3daa2018-10-19 22:11:43 +080018
Bernie Innocentiec4219b2019-01-30 11:16:36 +090019#include "DnsTlsSocket.h"
Mike Yubab3daa2018-10-19 22:11:43 +080020
#include <arpa/inet.h>
#include <arpa/nameser.h>
#include <errno.h>
#include <linux/tcp.h>
#include <openssl/err.h>
#include <openssl/sha.h>
#include <sys/eventfd.h>
#include <sys/poll.h>
#include <unistd.h>
#include <algorithm>
#include <chrono>
#include <cstdint>
Mike Yubab3daa2018-10-19 22:11:43 +080031
Bernie Innocentiec4219b2019-01-30 11:16:36 +090032#include "DnsTlsSessionCache.h"
33#include "IDnsTlsSocketObserver.h"
Mike Yubab3daa2018-10-19 22:11:43 +080034
Chiachang Wang32372172019-07-06 13:54:18 +080035#include <Fwmark.h>
chenbruceaff85842019-05-31 15:46:42 +080036#include <android-base/logging.h>
Chiachang Wang32372172019-07-06 13:54:18 +080037#include <android-base/stringprintf.h>
Mike Yu04f1d482019-08-08 11:09:32 +080038#include <netdutils/SocketOption.h>
39#include <netdutils/ThreadUtil.h>
chenbruceaff85842019-05-31 15:46:42 +080040
Sehee Park2c118782019-05-07 13:02:45 +090041#include "private/android_filesystem_config.h" // AID_DNS
Sehee Parkd975bf32019-08-07 13:21:16 +090042#include "resolv_private.h"
Mike Yubab3daa2018-10-19 22:11:43 +080043
waynema0e73c2e2019-07-31 15:04:08 +080044// NOTE: Inject CA certificate for internal testing -- do NOT enable in production builds
45#ifndef RESOLV_INJECT_CA_CERTIFICATE
46#define RESOLV_INJECT_CA_CERTIFICATE 0
47#endif
48
Mike Yubab3daa2018-10-19 22:11:43 +080049namespace android {
50
51using netdutils::enableSockopt;
52using netdutils::enableTcpKeepAlives;
53using netdutils::isOk;
Bernie Innocentiec4219b2019-01-30 11:16:36 +090054using netdutils::Slice;
Mike Yubab3daa2018-10-19 22:11:43 +080055using netdutils::Status;
56
57namespace net {
58namespace {
59
// Directory of system CA certificates used to verify DNS-over-TLS servers.
constexpr char kCaCertDir[] = "/system/etc/security/cacerts";
Mike Yubab3daa2018-10-19 22:11:43 +080061
62int waitForReading(int fd) {
63 struct pollfd fds = { .fd = fd, .events = POLLIN };
64 const int ret = TEMP_FAILURE_RETRY(poll(&fds, 1, -1));
65 return ret;
66}
67
68int waitForWriting(int fd) {
69 struct pollfd fds = { .fd = fd, .events = POLLOUT };
70 const int ret = TEMP_FAILURE_RETRY(poll(&fds, 1, -1));
71 return ret;
72}
73
Chiachang Wang32372172019-07-06 13:54:18 +080074std::string markToFwmarkString(unsigned mMark) {
75 Fwmark mark;
76 mark.intValue = mMark;
77 return android::base::StringPrintf("%d, %d, %d, %d, %d", mark.netId, mark.explicitlySelected,
78 mark.protectedFromVpn, mark.permission, mark.uidBillingDone);
79}
80
Mike Yubab3daa2018-10-19 22:11:43 +080081} // namespace
82
83Status DnsTlsSocket::tcpConnect() {
chenbruceaff85842019-05-31 15:46:42 +080084 LOG(DEBUG) << mMark << " connecting TCP socket";
Mike Yubab3daa2018-10-19 22:11:43 +080085 int type = SOCK_NONBLOCK | SOCK_CLOEXEC;
86 switch (mServer.protocol) {
87 case IPPROTO_TCP:
88 type |= SOCK_STREAM;
89 break;
90 default:
91 return Status(EPROTONOSUPPORT);
92 }
93
94 mSslFd.reset(socket(mServer.ss.ss_family, type, mServer.protocol));
95 if (mSslFd.get() == -1) {
chenbruceaff85842019-05-31 15:46:42 +080096 LOG(ERROR) << "Failed to create socket";
Mike Yubab3daa2018-10-19 22:11:43 +080097 return Status(errno);
98 }
99
Sehee Parkd975bf32019-08-07 13:21:16 +0900100 resolv_tag_socket(mSslFd.get(), AID_DNS);
Sehee Park2c118782019-05-07 13:02:45 +0900101
Mike Yubab3daa2018-10-19 22:11:43 +0800102 const socklen_t len = sizeof(mMark);
103 if (setsockopt(mSslFd.get(), SOL_SOCKET, SO_MARK, &mMark, len) == -1) {
chenbruceaff85842019-05-31 15:46:42 +0800104 LOG(ERROR) << "Failed to set socket mark";
Mike Yubab3daa2018-10-19 22:11:43 +0800105 mSslFd.reset();
106 return Status(errno);
107 }
108
109 const Status tfo = enableSockopt(mSslFd.get(), SOL_TCP, TCP_FASTOPEN_CONNECT);
110 if (!isOk(tfo) && tfo.code() != ENOPROTOOPT) {
chenbruceaff85842019-05-31 15:46:42 +0800111 LOG(WARNING) << "Failed to enable TFO: " << tfo.msg();
Mike Yubab3daa2018-10-19 22:11:43 +0800112 }
113
114 // Send 5 keepalives, 3 seconds apart, after 15 seconds of inactivity.
115 enableTcpKeepAlives(mSslFd.get(), 15U, 5U, 3U).ignoreError();
116
117 if (connect(mSslFd.get(), reinterpret_cast<const struct sockaddr *>(&mServer.ss),
118 sizeof(mServer.ss)) != 0 &&
119 errno != EINPROGRESS) {
chenbruceaff85842019-05-31 15:46:42 +0800120 LOG(DEBUG) << "Socket failed to connect";
Mike Yubab3daa2018-10-19 22:11:43 +0800121 mSslFd.reset();
122 return Status(errno);
123 }
124
125 return netdutils::status::ok;
126}
127
waynema0e73c2e2019-07-31 15:04:08 +0800128bool DnsTlsSocket::setTestCaCertificate() {
129 bssl::UniquePtr<BIO> bio(
130 BIO_new_mem_buf(mServer.certificate.data(), mServer.certificate.size()));
131 bssl::UniquePtr<X509> cert(PEM_read_bio_X509(bio.get(), nullptr, nullptr, nullptr));
132 if (!cert) {
133 LOG(ERROR) << "Failed to read cert";
Mike Yubab3daa2018-10-19 22:11:43 +0800134 return false;
135 }
waynema0e73c2e2019-07-31 15:04:08 +0800136
137 X509_STORE* cert_store = SSL_CTX_get_cert_store(mSslCtx.get());
138 if (!X509_STORE_add_cert(cert_store, cert.get())) {
139 LOG(ERROR) << "Failed to add cert";
Mike Yubab3daa2018-10-19 22:11:43 +0800140 return false;
141 }
142 return true;
143}
144
waynema0e73c2e2019-07-31 15:04:08 +0800145// TODO: Try to use static sSslCtx instead of mSslCtx
Mike Yubab3daa2018-10-19 22:11:43 +0800146bool DnsTlsSocket::initialize() {
waynema0e73c2e2019-07-31 15:04:08 +0800147 // This method is called every time when a new SSL connection is created.
148 // This lock only serves to help catch bugs in code that calls this method.
Mike Yubab3daa2018-10-19 22:11:43 +0800149 std::lock_guard guard(mLock);
150 if (mSslCtx) {
151 // This is a bug in the caller.
152 return false;
153 }
154 mSslCtx.reset(SSL_CTX_new(TLS_method()));
155 if (!mSslCtx) {
156 return false;
157 }
158
waynema0e73c2e2019-07-31 15:04:08 +0800159 // Load system CA certs from CAPath for hostname verification.
Mike Yubab3daa2018-10-19 22:11:43 +0800160 //
161 // For discussion of alternative, sustainable approaches see b/71909242.
waynema0e73c2e2019-07-31 15:04:08 +0800162 if (RESOLV_INJECT_CA_CERTIFICATE && !mServer.certificate.empty()) {
163 // Inject test CA certs from ResolverParamsParcel.caCertificate for internal testing.
164 LOG(WARNING) << "test CA certificate is valid";
165 if (!setTestCaCertificate()) {
166 LOG(ERROR) << "Failed to set test CA certificate";
167 return false;
168 }
169 } else {
170 if (SSL_CTX_load_verify_locations(mSslCtx.get(), nullptr, kCaCertDir) != 1) {
171 LOG(ERROR) << "Failed to load CA cert dir: " << kCaCertDir;
172 return false;
173 }
Mike Yubab3daa2018-10-19 22:11:43 +0800174 }
175
176 // Enable TLS false start
177 SSL_CTX_set_false_start_allowed_without_alpn(mSslCtx.get(), 1);
178 SSL_CTX_set_mode(mSslCtx.get(), SSL_MODE_ENABLE_FALSE_START);
179
180 // Enable session cache
181 mCache->prepareSslContext(mSslCtx.get());
182
183 // Connect
184 Status status = tcpConnect();
185 if (!status.ok()) {
186 return false;
187 }
188 mSsl = sslConnect(mSslFd.get());
189 if (!mSsl) {
190 return false;
191 }
Ben Schwartz2187abe2019-01-10 14:30:46 -0500192
193 mEventFd.reset(eventfd(0, EFD_NONBLOCK | EFD_CLOEXEC));
Mike Yubab3daa2018-10-19 22:11:43 +0800194
195 // Start the I/O loop.
196 mLoopThread.reset(new std::thread(&DnsTlsSocket::loop, this));
197
198 return true;
199}
200
201bssl::UniquePtr<SSL> DnsTlsSocket::sslConnect(int fd) {
202 if (!mSslCtx) {
chenbruceaff85842019-05-31 15:46:42 +0800203 LOG(ERROR) << "Internal error: context is null in sslConnect";
Mike Yubab3daa2018-10-19 22:11:43 +0800204 return nullptr;
205 }
206 if (!SSL_CTX_set_min_proto_version(mSslCtx.get(), TLS1_2_VERSION)) {
chenbruceaff85842019-05-31 15:46:42 +0800207 LOG(ERROR) << "Failed to set minimum TLS version";
Mike Yubab3daa2018-10-19 22:11:43 +0800208 return nullptr;
209 }
210
211 bssl::UniquePtr<SSL> ssl(SSL_new(mSslCtx.get()));
212 // This file descriptor is owned by mSslFd, so don't let libssl close it.
213 bssl::UniquePtr<BIO> bio(BIO_new_socket(fd, BIO_NOCLOSE));
214 SSL_set_bio(ssl.get(), bio.get(), bio.get());
215 bio.release();
216
217 if (!mCache->prepareSsl(ssl.get())) {
218 return nullptr;
219 }
220
221 if (!mServer.name.empty()) {
waynema0e73c2e2019-07-31 15:04:08 +0800222 LOG(VERBOSE) << "Checking DNS over TLS hostname = " << mServer.name.c_str();
Mike Yubab3daa2018-10-19 22:11:43 +0800223 if (SSL_set_tlsext_host_name(ssl.get(), mServer.name.c_str()) != 1) {
waynema0e73c2e2019-07-31 15:04:08 +0800224 LOG(ERROR) << "Failed to set SNI to " << mServer.name;
Mike Yubab3daa2018-10-19 22:11:43 +0800225 return nullptr;
226 }
227 X509_VERIFY_PARAM* param = SSL_get0_param(ssl.get());
228 if (X509_VERIFY_PARAM_set1_host(param, mServer.name.data(), mServer.name.size()) != 1) {
chenbruceaff85842019-05-31 15:46:42 +0800229 LOG(ERROR) << "Failed to set verify host param to " << mServer.name;
Mike Yubab3daa2018-10-19 22:11:43 +0800230 return nullptr;
231 }
232 // This will cause the handshake to fail if certificate verification fails.
233 SSL_set_verify(ssl.get(), SSL_VERIFY_PEER, nullptr);
234 }
235
236 bssl::UniquePtr<SSL_SESSION> session = mCache->getSession();
237 if (session) {
chenbruceaff85842019-05-31 15:46:42 +0800238 LOG(DEBUG) << "Setting session";
Mike Yubab3daa2018-10-19 22:11:43 +0800239 SSL_set_session(ssl.get(), session.get());
240 } else {
chenbruceaff85842019-05-31 15:46:42 +0800241 LOG(DEBUG) << "No session available";
Mike Yubab3daa2018-10-19 22:11:43 +0800242 }
243
244 for (;;) {
Chiachang Wang32372172019-07-06 13:54:18 +0800245 LOG(DEBUG) << " Calling SSL_connect with " << markToFwmarkString(mMark);
Mike Yubab3daa2018-10-19 22:11:43 +0800246 int ret = SSL_connect(ssl.get());
Chiachang Wang32372172019-07-06 13:54:18 +0800247 LOG(DEBUG) << " SSL_connect returned " << ret << " with " << markToFwmarkString(mMark);
Mike Yubab3daa2018-10-19 22:11:43 +0800248 if (ret == 1) break; // SSL handshake complete;
249
250 const int ssl_err = SSL_get_error(ssl.get(), ret);
251 switch (ssl_err) {
252 case SSL_ERROR_WANT_READ:
253 if (waitForReading(fd) != 1) {
Chiachang Wang32372172019-07-06 13:54:18 +0800254 PLOG(WARNING) << "SSL_connect read error, " << markToFwmarkString(mMark);
Mike Yubab3daa2018-10-19 22:11:43 +0800255 return nullptr;
256 }
257 break;
258 case SSL_ERROR_WANT_WRITE:
259 if (waitForWriting(fd) != 1) {
Chiachang Wang32372172019-07-06 13:54:18 +0800260 PLOG(WARNING) << "SSL_connect write error, " << markToFwmarkString(mMark);
Mike Yubab3daa2018-10-19 22:11:43 +0800261 return nullptr;
262 }
263 break;
264 default:
Chiachang Wang32372172019-07-06 13:54:18 +0800265 PLOG(WARNING) << "SSL_connect ssl error =" << ssl_err << ", "
266 << markToFwmarkString(mMark);
Mike Yubab3daa2018-10-19 22:11:43 +0800267 return nullptr;
268 }
269 }
270
chenbruceaff85842019-05-31 15:46:42 +0800271 LOG(DEBUG) << mMark << " handshake complete";
Mike Yubab3daa2018-10-19 22:11:43 +0800272
273 return ssl;
274}
275
276void DnsTlsSocket::sslDisconnect() {
277 if (mSsl) {
278 SSL_shutdown(mSsl.get());
279 mSsl.reset();
280 }
281 mSslFd.reset();
282}
283
284bool DnsTlsSocket::sslWrite(const Slice buffer) {
chenbruceaff85842019-05-31 15:46:42 +0800285 LOG(DEBUG) << mMark << " Writing " << buffer.size() << " bytes";
Mike Yubab3daa2018-10-19 22:11:43 +0800286 for (;;) {
287 int ret = SSL_write(mSsl.get(), buffer.base(), buffer.size());
288 if (ret == int(buffer.size())) break; // SSL write complete;
289
290 if (ret < 1) {
291 const int ssl_err = SSL_get_error(mSsl.get(), ret);
292 switch (ssl_err) {
293 case SSL_ERROR_WANT_WRITE:
294 if (waitForWriting(mSslFd.get()) != 1) {
chenbruceaff85842019-05-31 15:46:42 +0800295 LOG(DEBUG) << "SSL_write error";
Mike Yubab3daa2018-10-19 22:11:43 +0800296 return false;
297 }
298 continue;
299 case 0:
300 break; // SSL write complete;
301 default:
chenbruceaff85842019-05-31 15:46:42 +0800302 LOG(DEBUG) << "SSL_write error " << ssl_err;
Mike Yubab3daa2018-10-19 22:11:43 +0800303 return false;
304 }
305 }
306 }
chenbruceaff85842019-05-31 15:46:42 +0800307 LOG(DEBUG) << mMark << " Wrote " << buffer.size() << " bytes";
Mike Yubab3daa2018-10-19 22:11:43 +0800308 return true;
309}
310
311void DnsTlsSocket::loop() {
312 std::lock_guard guard(mLock);
Ben Schwartz2187abe2019-01-10 14:30:46 -0500313 std::deque<std::vector<uint8_t>> q;
Mike Yubab3daa2018-10-19 22:11:43 +0800314 const int timeout_msecs = DnsTlsSocket::kIdleTimeout.count() * 1000;
Mike Yu04f1d482019-08-08 11:09:32 +0800315
316 Fwmark mark;
317 mark.intValue = mMark;
318 netdutils::setThreadName(android::base::StringPrintf("TlsListen_%u", mark.netId).c_str());
Mike Yubab3daa2018-10-19 22:11:43 +0800319 while (true) {
320 // poll() ignores negative fds
321 struct pollfd fds[2] = { { .fd = -1 }, { .fd = -1 } };
Ben Schwartz2187abe2019-01-10 14:30:46 -0500322 enum { SSLFD = 0, EVENTFD = 1 };
Mike Yubab3daa2018-10-19 22:11:43 +0800323
324 // Always listen for a response from server.
325 fds[SSLFD].fd = mSslFd.get();
326 fds[SSLFD].events = POLLIN;
327
Ben Schwartz2187abe2019-01-10 14:30:46 -0500328 // If we have pending queries, wait for space to write one.
329 // Otherwise, listen for new queries.
Ben Schwartz62176fd2019-01-22 17:32:17 -0500330 // Note: This blocks the destructor until q is empty, i.e. until all pending
331 // queries are sent or have failed to send.
Ben Schwartz2187abe2019-01-10 14:30:46 -0500332 if (!q.empty()) {
Mike Yubab3daa2018-10-19 22:11:43 +0800333 fds[SSLFD].events |= POLLOUT;
334 } else {
Ben Schwartz2187abe2019-01-10 14:30:46 -0500335 fds[EVENTFD].fd = mEventFd.get();
336 fds[EVENTFD].events = POLLIN;
Mike Yubab3daa2018-10-19 22:11:43 +0800337 }
338
339 const int s = TEMP_FAILURE_RETRY(poll(fds, std::size(fds), timeout_msecs));
340 if (s == 0) {
chenbruceaff85842019-05-31 15:46:42 +0800341 LOG(DEBUG) << "Idle timeout";
Mike Yubab3daa2018-10-19 22:11:43 +0800342 break;
343 }
344 if (s < 0) {
chenbruceaff85842019-05-31 15:46:42 +0800345 LOG(DEBUG) << "Poll failed: " << errno;
Mike Yubab3daa2018-10-19 22:11:43 +0800346 break;
347 }
Ben Schwartz62176fd2019-01-22 17:32:17 -0500348 if (fds[SSLFD].revents & (POLLIN | POLLERR | POLLHUP)) {
Mike Yubab3daa2018-10-19 22:11:43 +0800349 if (!readResponse()) {
chenbruceaff85842019-05-31 15:46:42 +0800350 LOG(DEBUG) << "SSL remote close or read error.";
Mike Yubab3daa2018-10-19 22:11:43 +0800351 break;
352 }
353 }
Ben Schwartz2187abe2019-01-10 14:30:46 -0500354 if (fds[EVENTFD].revents & (POLLIN | POLLERR)) {
355 int64_t num_queries;
356 ssize_t res = read(mEventFd.get(), &num_queries, sizeof(num_queries));
Mike Yubab3daa2018-10-19 22:11:43 +0800357 if (res < 0) {
chenbruceaff85842019-05-31 15:46:42 +0800358 LOG(WARNING) << "Error during eventfd read";
Mike Yubab3daa2018-10-19 22:11:43 +0800359 break;
360 } else if (res == 0) {
chenbruceaff85842019-05-31 15:46:42 +0800361 LOG(WARNING) << "eventfd closed; disconnecting";
Mike Yubab3daa2018-10-19 22:11:43 +0800362 break;
Ben Schwartz2187abe2019-01-10 14:30:46 -0500363 } else if (res != sizeof(num_queries)) {
chenbruceaff85842019-05-31 15:46:42 +0800364 LOG(ERROR) << "Int size mismatch: " << res << " != " << sizeof(num_queries);
Ben Schwartz2187abe2019-01-10 14:30:46 -0500365 break;
Ben Schwartz62176fd2019-01-22 17:32:17 -0500366 } else if (num_queries < 0) {
chenbruceaff85842019-05-31 15:46:42 +0800367 LOG(DEBUG) << "Negative eventfd read indicates destructor-initiated shutdown";
Ben Schwartz2187abe2019-01-10 14:30:46 -0500368 break;
369 }
370 // Take ownership of all pending queries. (q is always empty here.)
371 mQueue.swap(q);
Mike Yubab3daa2018-10-19 22:11:43 +0800372 } else if (fds[SSLFD].revents & POLLOUT) {
Ben Schwartz2187abe2019-01-10 14:30:46 -0500373 // q cannot be empty here.
374 // Sending the entire queue here would risk a TCP flow control deadlock, so
375 // we only send a single query on each cycle of this loop.
376 // TODO: Coalesce multiple pending queries if there is enough space in the
377 // write buffer.
378 if (!sendQuery(q.front())) {
Mike Yubab3daa2018-10-19 22:11:43 +0800379 break;
380 }
Ben Schwartz2187abe2019-01-10 14:30:46 -0500381 q.pop_front();
Mike Yubab3daa2018-10-19 22:11:43 +0800382 }
383 }
chenbruceaff85842019-05-31 15:46:42 +0800384 LOG(DEBUG) << "Disconnecting";
Mike Yubab3daa2018-10-19 22:11:43 +0800385 sslDisconnect();
chenbruceaff85842019-05-31 15:46:42 +0800386 LOG(DEBUG) << "Calling onClosed";
Mike Yubab3daa2018-10-19 22:11:43 +0800387 mObserver->onClosed();
chenbruceaff85842019-05-31 15:46:42 +0800388 LOG(DEBUG) << "Ending loop";
Mike Yubab3daa2018-10-19 22:11:43 +0800389}
390
391DnsTlsSocket::~DnsTlsSocket() {
chenbruceaff85842019-05-31 15:46:42 +0800392 LOG(DEBUG) << "Destructor";
Mike Yubab3daa2018-10-19 22:11:43 +0800393 // This will trigger an orderly shutdown in loop().
Ben Schwartz62176fd2019-01-22 17:32:17 -0500394 requestLoopShutdown();
Mike Yubab3daa2018-10-19 22:11:43 +0800395 {
396 // Wait for the orderly shutdown to complete.
397 std::lock_guard guard(mLock);
398 if (mLoopThread && std::this_thread::get_id() == mLoopThread->get_id()) {
chenbruceaff85842019-05-31 15:46:42 +0800399 LOG(ERROR) << "Violation of re-entrance precondition";
Mike Yubab3daa2018-10-19 22:11:43 +0800400 return;
401 }
402 }
403 if (mLoopThread) {
chenbruceaff85842019-05-31 15:46:42 +0800404 LOG(DEBUG) << "Waiting for loop thread to terminate";
Mike Yubab3daa2018-10-19 22:11:43 +0800405 mLoopThread->join();
406 mLoopThread.reset();
407 }
chenbruceaff85842019-05-31 15:46:42 +0800408 LOG(DEBUG) << "Destructor completed";
Mike Yubab3daa2018-10-19 22:11:43 +0800409}
410
411bool DnsTlsSocket::query(uint16_t id, const Slice query) {
Ben Schwartz2187abe2019-01-10 14:30:46 -0500412 // Compose the entire message in a single buffer, so that it can be
413 // sent as a single TLS record.
414 std::vector<uint8_t> buf(query.size() + 4);
415 // Write 2-byte length
416 uint16_t len = query.size() + 2; // + 2 for the ID.
417 buf[0] = len >> 8;
418 buf[1] = len;
419 // Write 2-byte ID
420 buf[2] = id >> 8;
421 buf[3] = id;
422 // Copy body
423 std::memcpy(buf.data() + 4, query.base(), query.size());
424
425 mQueue.push(std::move(buf));
426 // Increment the mEventFd counter by 1.
Ben Schwartz62176fd2019-01-22 17:32:17 -0500427 return incrementEventFd(1);
428}
429
430void DnsTlsSocket::requestLoopShutdown() {
Bernie Innocenti97ee1092019-03-28 15:52:59 +0900431 if (mEventFd != -1) {
432 // Write a negative number to the eventfd. This triggers an immediate shutdown.
433 incrementEventFd(INT64_MIN);
434 }
Ben Schwartz62176fd2019-01-22 17:32:17 -0500435}
436
437bool DnsTlsSocket::incrementEventFd(const int64_t count) {
Bernie Innocenti97ee1092019-03-28 15:52:59 +0900438 if (mEventFd == -1) {
chenbruceaff85842019-05-31 15:46:42 +0800439 LOG(ERROR) << "eventfd is not initialized";
Ben Schwartz62176fd2019-01-22 17:32:17 -0500440 return false;
441 }
Bernie Innocenti97ee1092019-03-28 15:52:59 +0900442 ssize_t written = write(mEventFd.get(), &count, sizeof(count));
Ben Schwartz62176fd2019-01-22 17:32:17 -0500443 if (written != sizeof(count)) {
chenbruceaff85842019-05-31 15:46:42 +0800444 LOG(ERROR) << "Failed to increment eventfd by " << count;
Ben Schwartz62176fd2019-01-22 17:32:17 -0500445 return false;
446 }
447 return true;
Mike Yubab3daa2018-10-19 22:11:43 +0800448}
449
450// Read exactly len bytes into buffer or fail with an SSL error code
451int DnsTlsSocket::sslRead(const Slice buffer, bool wait) {
452 size_t remaining = buffer.size();
453 while (remaining > 0) {
454 int ret = SSL_read(mSsl.get(), buffer.limit() - remaining, remaining);
455 if (ret == 0) {
chenbruceaff85842019-05-31 15:46:42 +0800456 if (remaining < buffer.size())
457 LOG(WARNING) << "SSL closed with " << remaining << " of " << buffer.size()
458 << " bytes remaining";
Mike Yubab3daa2018-10-19 22:11:43 +0800459 return SSL_ERROR_ZERO_RETURN;
460 }
461
462 if (ret < 0) {
463 const int ssl_err = SSL_get_error(mSsl.get(), ret);
464 if (wait && ssl_err == SSL_ERROR_WANT_READ) {
465 if (waitForReading(mSslFd.get()) != 1) {
chenbruceaff85842019-05-31 15:46:42 +0800466 LOG(DEBUG) << "Poll failed in sslRead: " << errno;
Mike Yubab3daa2018-10-19 22:11:43 +0800467 return SSL_ERROR_SYSCALL;
468 }
469 continue;
470 } else {
chenbruceaff85842019-05-31 15:46:42 +0800471 LOG(DEBUG) << "SSL_read error " << ssl_err;
Mike Yubab3daa2018-10-19 22:11:43 +0800472 return ssl_err;
473 }
474 }
475
476 remaining -= ret;
477 wait = true; // Once a read is started, try to finish.
478 }
479 return SSL_ERROR_NONE;
480}
481
Ben Schwartz2187abe2019-01-10 14:30:46 -0500482bool DnsTlsSocket::sendQuery(const std::vector<uint8_t>& buf) {
Mike Yubab3daa2018-10-19 22:11:43 +0800483 if (!sslWrite(netdutils::makeSlice(buf))) {
484 return false;
485 }
chenbruceaff85842019-05-31 15:46:42 +0800486 LOG(DEBUG) << mMark << " SSL_write complete";
Mike Yubab3daa2018-10-19 22:11:43 +0800487 return true;
488}
489
490bool DnsTlsSocket::readResponse() {
chenbruceaff85842019-05-31 15:46:42 +0800491 LOG(DEBUG) << "reading response";
Mike Yubab3daa2018-10-19 22:11:43 +0800492 uint8_t responseHeader[2];
493 int err = sslRead(Slice(responseHeader, 2), false);
494 if (err == SSL_ERROR_WANT_READ) {
chenbruceaff85842019-05-31 15:46:42 +0800495 LOG(DEBUG) << "Ignoring spurious wakeup from server";
Mike Yubab3daa2018-10-19 22:11:43 +0800496 return true;
497 }
498 if (err != SSL_ERROR_NONE) {
499 return false;
500 }
501 // Truncate responses larger than MAX_SIZE. This is safe because a DNS packet is
502 // always invalid when truncated, so the response will be treated as an error.
503 constexpr uint16_t MAX_SIZE = 8192;
504 const uint16_t responseSize = (responseHeader[0] << 8) | responseHeader[1];
chenbruceaff85842019-05-31 15:46:42 +0800505 LOG(DEBUG) << mMark << " Expecting response of size " << responseSize;
Mike Yubab3daa2018-10-19 22:11:43 +0800506 std::vector<uint8_t> response(std::min(responseSize, MAX_SIZE));
507 if (sslRead(netdutils::makeSlice(response), true) != SSL_ERROR_NONE) {
chenbruceaff85842019-05-31 15:46:42 +0800508 LOG(DEBUG) << mMark << " Failed to read " << response.size() << " bytes";
Mike Yubab3daa2018-10-19 22:11:43 +0800509 return false;
510 }
511 uint16_t remainingBytes = responseSize - response.size();
512 while (remainingBytes > 0) {
513 constexpr uint16_t CHUNK_SIZE = 2048;
514 std::vector<uint8_t> discard(std::min(remainingBytes, CHUNK_SIZE));
515 if (sslRead(netdutils::makeSlice(discard), true) != SSL_ERROR_NONE) {
chenbruceaff85842019-05-31 15:46:42 +0800516 LOG(DEBUG) << mMark << " Failed to discard " << discard.size() << " bytes";
Mike Yubab3daa2018-10-19 22:11:43 +0800517 return false;
518 }
519 remainingBytes -= discard.size();
520 }
chenbruceaff85842019-05-31 15:46:42 +0800521 LOG(DEBUG) << mMark << " SSL_read complete";
Mike Yubab3daa2018-10-19 22:11:43 +0800522
523 mObserver->onResponse(std::move(response));
524 return true;
525}
526
527} // end of namespace net
528} // end of namespace android