blob: 624396823b2fa0cac7697699db3c02a398cb6e23 [file] [log] [blame]
Ben Schwartzded1b702017-10-25 14:41:02 -04001/*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#define LOG_TAG "DnsTlsSocket"
18
#include "dns/DnsTlsSocket.h"

#include <algorithm>
#include <arpa/inet.h>
#include <arpa/nameser.h>
#include <cstring>
#include <errno.h>
#include <openssl/err.h>
#include <poll.h>
#include <sys/select.h>
#include <utility>

#include "dns/DnsTlsSessionCache.h"

//#define LOG_NDEBUG 0

#include "log/log.h"
#include "Fwmark.h"
#undef ADD // already defined in nameser.h
#include "NetdConstants.h"
#include "Permission.h"
37
38
39namespace android {
40namespace net {
41
42using netdutils::Status;
43
44namespace {
45
46constexpr const char kCaCertDir[] = "/system/etc/security/cacerts";
47
48int waitForReading(int fd) {
49 fd_set fds;
50 FD_ZERO(&fds);
51 FD_SET(fd, &fds);
52 const int ret = TEMP_FAILURE_RETRY(select(fd + 1, &fds, nullptr, nullptr, nullptr));
53 ALOGV_IF(ret <= 0, "select failed during read");
54 return ret;
55}
56
57int waitForWriting(int fd) {
58 fd_set fds;
59 FD_ZERO(&fds);
60 FD_SET(fd, &fds);
61 const int ret = TEMP_FAILURE_RETRY(select(fd + 1, nullptr, &fds, nullptr, nullptr));
62 ALOGV_IF(ret <= 0, "select failed during write");
63 return ret;
64}
65
66} // namespace
67
68Status DnsTlsSocket::tcpConnect() {
69 ALOGV("%u connecting TCP socket", mMark);
70 int type = SOCK_NONBLOCK | SOCK_CLOEXEC;
71 switch (mServer.protocol) {
72 case IPPROTO_TCP:
73 type |= SOCK_STREAM;
74 break;
75 default:
76 return Status(EPROTONOSUPPORT);
77 }
78
79 mSslFd.reset(socket(mServer.ss.ss_family, type, mServer.protocol));
80 if (mSslFd.get() == -1) {
81 ALOGE("Failed to create socket");
82 return Status(errno);
83 }
84
85 const socklen_t len = sizeof(mMark);
86 if (setsockopt(mSslFd.get(), SOL_SOCKET, SO_MARK, &mMark, len) == -1) {
87 ALOGE("Failed to set socket mark");
88 mSslFd.reset();
89 return Status(errno);
90 }
91 if (connect(mSslFd.get(), reinterpret_cast<const struct sockaddr *>(&mServer.ss),
92 sizeof(mServer.ss)) != 0 &&
93 errno != EINPROGRESS) {
94 ALOGV("Socket failed to connect");
95 mSslFd.reset();
96 return Status(errno);
97 }
98
99 return netdutils::status::ok;
100}
101
102bool getSPKIDigest(const X509* cert, std::vector<uint8_t>* out) {
103 int spki_len = i2d_X509_PUBKEY(X509_get_X509_PUBKEY(cert), NULL);
104 unsigned char spki[spki_len];
105 unsigned char* temp = spki;
106 if (spki_len != i2d_X509_PUBKEY(X509_get_X509_PUBKEY(cert), &temp)) {
107 ALOGW("SPKI length mismatch");
108 return false;
109 }
110 out->resize(SHA256_SIZE);
111 unsigned int digest_len = 0;
112 int ret = EVP_Digest(spki, spki_len, out->data(), &digest_len, EVP_sha256(), NULL);
113 if (ret != 1) {
114 ALOGW("Server cert digest extraction failed");
115 return false;
116 }
117 if (digest_len != out->size()) {
118 ALOGW("Wrong digest length: %d", digest_len);
119 return false;
120 }
121 return true;
122}
123
124bool DnsTlsSocket::initialize() {
125 // This method should only be called once, at the beginning, so locking should be
126 // unnecessary. This lock only serves to help catch bugs in code that calls this method.
127 std::lock_guard<std::mutex> guard(mLock);
128 if (mSslCtx) {
129 // This is a bug in the caller.
130 return false;
131 }
132 mSslCtx.reset(SSL_CTX_new(TLS_method()));
133 if (!mSslCtx) {
134 return false;
135 }
136
137 // Load system CA certs for hostname verification.
138 //
139 // For discussion of alternative, sustainable approaches see b/71909242.
140 if (SSL_CTX_load_verify_locations(mSslCtx.get(), nullptr, kCaCertDir) != 1) {
141 ALOGE("Failed to load CA cert dir: %s", kCaCertDir);
142 return false;
143 }
144
145 // Enable TLS false start
146 SSL_CTX_set_false_start_allowed_without_alpn(mSslCtx.get(), 1);
147 SSL_CTX_set_mode(mSslCtx.get(), SSL_MODE_ENABLE_FALSE_START);
148
149 // Enable session cache
150 mCache->prepareSslContext(mSslCtx.get());
151
152 // Connect
153 Status status = tcpConnect();
154 if (!status.ok()) {
155 return false;
156 }
157 mSsl = sslConnect(mSslFd.get());
158 if (!mSsl) {
159 return false;
160 }
161
162 return true;
163}
164
165bssl::UniquePtr<SSL> DnsTlsSocket::sslConnect(int fd) {
166 if (!mSslCtx) {
167 ALOGE("Internal error: context is null in sslConnect");
168 return nullptr;
169 }
170 if (!SSL_CTX_set_min_proto_version(mSslCtx.get(), TLS1_2_VERSION)) {
171 ALOGE("Failed to set minimum TLS version");
172 return nullptr;
173 }
174
175 bssl::UniquePtr<SSL> ssl(SSL_new(mSslCtx.get()));
176 // This file descriptor is owned by mSslFd, so don't let libssl close it.
177 bssl::UniquePtr<BIO> bio(BIO_new_socket(fd, BIO_NOCLOSE));
178 SSL_set_bio(ssl.get(), bio.get(), bio.get());
179 bio.release();
180
181 if (!mCache->prepareSsl(ssl.get())) {
182 return nullptr;
183 }
184
185 if (!mServer.name.empty()) {
186 if (SSL_set_tlsext_host_name(ssl.get(), mServer.name.c_str()) != 1) {
187 ALOGE("Failed to set SNI to %s", mServer.name.c_str());
188 return nullptr;
189 }
190 X509_VERIFY_PARAM* param = SSL_get0_param(ssl.get());
191 X509_VERIFY_PARAM_set1_host(param, mServer.name.c_str(), 0);
192 // This will cause the handshake to fail if certificate verification fails.
193 SSL_set_verify(ssl.get(), SSL_VERIFY_PEER, nullptr);
194 }
195
196 bssl::UniquePtr<SSL_SESSION> session = mCache->getSession();
197 if (session) {
198 ALOGV("Setting session");
199 SSL_set_session(ssl.get(), session.get());
200 } else {
201 ALOGV("No session available");
202 }
203
204 for (;;) {
205 ALOGV("%u Calling SSL_connect", mMark);
206 int ret = SSL_connect(ssl.get());
207 ALOGV("%u SSL_connect returned %d", mMark, ret);
208 if (ret == 1) break; // SSL handshake complete;
209
210 const int ssl_err = SSL_get_error(ssl.get(), ret);
211 switch (ssl_err) {
212 case SSL_ERROR_WANT_READ:
213 if (waitForReading(fd) != 1) {
214 ALOGW("SSL_connect read error");
215 return nullptr;
216 }
217 break;
218 case SSL_ERROR_WANT_WRITE:
219 if (waitForWriting(fd) != 1) {
220 ALOGW("SSL_connect write error");
221 return nullptr;
222 }
223 break;
224 default:
225 ALOGW("SSL_connect error %d, errno=%d", ssl_err, errno);
226 return nullptr;
227 }
228 }
229
230 // TODO: Call SSL_shutdown before discarding the session if validation fails.
231 if (!mServer.fingerprints.empty()) {
232 ALOGV("Checking DNS over TLS fingerprint");
233
234 // We only care that the chain is internally self-consistent, not that
235 // it chains to a trusted root, so we can ignore some kinds of errors.
236 // TODO: Add a CA root verification mode that respects these errors.
237 int verify_result = SSL_get_verify_result(ssl.get());
238 switch (verify_result) {
239 case X509_V_OK:
240 case X509_V_ERR_DEPTH_ZERO_SELF_SIGNED_CERT:
241 case X509_V_ERR_SELF_SIGNED_CERT_IN_CHAIN:
242 case X509_V_ERR_CERT_UNTRUSTED:
243 break;
244 default:
245 ALOGW("Invalid certificate chain, error %d", verify_result);
246 return nullptr;
247 }
248
249 STACK_OF(X509) *chain = SSL_get_peer_cert_chain(ssl.get());
250 if (!chain) {
251 ALOGW("Server has null certificate");
252 return nullptr;
253 }
254 // Chain and its contents are owned by ssl, so we don't need to free explicitly.
255 bool matched = false;
256 for (size_t i = 0; i < sk_X509_num(chain); ++i) {
257 // This appears to be O(N^2), but there doesn't seem to be a straightforward
258 // way to walk a STACK_OF nondestructively in linear time.
259 X509* cert = sk_X509_value(chain, i);
260 std::vector<uint8_t> digest;
261 if (!getSPKIDigest(cert, &digest)) {
262 ALOGE("Digest computation failed");
263 return nullptr;
264 }
265
266 if (mServer.fingerprints.count(digest) > 0) {
267 matched = true;
268 break;
269 }
270 }
271
272 if (!matched) {
273 ALOGW("No matching fingerprint");
274 return nullptr;
275 }
276
277 ALOGV("DNS over TLS fingerprint is correct");
278 }
279
280 ALOGV("%u handshake complete", mMark);
281
282 return ssl;
283}
284
285void DnsTlsSocket::sslDisconnect() {
286 if (mSsl) {
287 SSL_shutdown(mSsl.get());
288 mSsl.reset();
289 }
290 mSslFd.reset();
291}
292
293bool DnsTlsSocket::sslWrite(const Slice buffer) {
294 ALOGV("%u Writing %zu bytes", mMark, buffer.size());
295 for (;;) {
296 int ret = SSL_write(mSsl.get(), buffer.base(), buffer.size());
297 if (ret == int(buffer.size())) break; // SSL write complete;
298
299 if (ret < 1) {
300 const int ssl_err = SSL_get_error(mSsl.get(), ret);
301 switch (ssl_err) {
302 case SSL_ERROR_WANT_WRITE:
303 if (waitForWriting(mSslFd.get()) != 1) {
304 ALOGV("SSL_write error");
305 return false;
306 }
307 continue;
308 case 0:
309 break; // SSL write complete;
310 default:
311 ALOGV("SSL_write error %d", ssl_err);
312 return false;
313 }
314 }
315 }
316 ALOGV("%u Wrote %zu bytes", mMark, buffer.size());
317 return true;
318}
319
320DnsTlsSocket::~DnsTlsSocket() {
321 sslDisconnect();
322}
323
324DnsTlsServer::Result DnsTlsSocket::query(uint16_t id, const Slice query) {
325 std::lock_guard<std::mutex> guard(mLock);
326 const Query q = { .id = id, .query = query };
327 if (!sendQuery(q)) {
328 return { .code = DnsTlsServer::Response::network_error };
329 }
330 return readResponse();
331}
332
333// Read exactly len bytes into buffer or fail
334bool DnsTlsSocket::sslRead(const Slice buffer) {
335 size_t remaining = buffer.size();
336 while (remaining > 0) {
337 int ret = SSL_read(mSsl.get(), buffer.limit() - remaining, remaining);
338 if (ret == 0) {
339 ALOGW_IF(remaining < buffer.size(), "SSL closed with %zu of %zu bytes remaining",
340 remaining, buffer.size());
341 return false;
342 }
343
344 if (ret < 0) {
345 const int ssl_err = SSL_get_error(mSsl.get(), ret);
346 if (ssl_err == SSL_ERROR_WANT_READ) {
347 if (waitForReading(mSslFd.get()) != 1) {
348 ALOGV("SSL_read error");
349 return false;
350 }
351 continue;
352 } else {
353 ALOGV("SSL_read error %d", ssl_err);
354 return false;
355 }
356 }
357
358 remaining -= ret;
359 }
360 return true;
361}
362
363bool DnsTlsSocket::sendQuery(const Query& q) {
364 ALOGV("sending query");
365 // Compose the entire message in a single buffer, so that it can be
366 // sent as a single TLS record.
367 std::vector<uint8_t> buf(q.query.size() + 4);
368 // Write 2-byte length
369 uint16_t len = q.query.size() + 2; // + 2 for the ID.
370 buf[0] = len >> 8;
371 buf[1] = len;
372 // Write 2-byte ID
373 buf[2] = q.id >> 8;
374 buf[3] = q.id;
375 // Copy body
376 std::memcpy(buf.data() + 4, q.query.base(), q.query.size());
377 if (!sslWrite(netdutils::makeSlice(buf))) {
378 return false;
379 }
380 ALOGV("%u SSL_write complete", mMark);
381 return true;
382}
383
384DnsTlsServer::Result DnsTlsSocket::readResponse() {
385 ALOGV("reading response");
386 uint8_t responseHeader[2];
387 const DnsTlsServer::Result failed = { .code = DnsTlsServer::Response::network_error };
388 if (!sslRead(Slice(responseHeader, 2))) {
389 return failed;
390 }
391 // Truncate responses larger than MAX_SIZE. This is safe because a DNS packet is
392 // always invalid when truncated, so the response will be treated as an error.
393 constexpr uint16_t MAX_SIZE = 8192;
394 const uint16_t responseSize = (responseHeader[0] << 8) | responseHeader[1];
395 ALOGV("%u Expecting response of size %i", mMark, responseSize);
396 std::vector<uint8_t> response(std::min(responseSize, MAX_SIZE));
397 if (!sslRead(netdutils::makeSlice(response))) {
398 ALOGV("%u Failed to read %zu bytes", mMark, response.size());
399 return failed;
400 }
401 uint16_t remainingBytes = responseSize - response.size();
402 while (remainingBytes > 0) {
403 constexpr uint16_t CHUNK_SIZE = 2048;
404 std::vector<uint8_t> discard(std::min(remainingBytes, CHUNK_SIZE));
405 if (!sslRead(netdutils::makeSlice(discard))) {
406 ALOGV("%u Failed to discard %zu bytes", mMark, discard.size());
407 return failed;
408 }
409 remainingBytes -= discard.size();
410 }
411 ALOGV("%u SSL_read complete", mMark);
412
413 return { .code = DnsTlsServer::Response::success, .response = response };
414}
415
416} // end of namespace net
417} // end of namespace android