Merge libnetddns into libnetd_resolv
libnetddns is the DNS-over-TLS library and is statically linked
into netd. Deprecate it and move its sources into libnetd_resolv,
turning that into a more general DNS library for netd.
This change comprises:
[1] Clean up netd/server/dns/*. Move all DnsTls* files to
netd/resolv/ so that they become part of libnetd_resolv.
[2] Export the DnsTls* classes so that they are visible to netd.
This is only a temporary measure.
[3] Remove the libssl dependency from netd. The relevant code now
lives in libnetd_resolv.
Note that DnsProxyListener and ResolverController still need the
DnsTls* classes to manage private DNS servers, even after this
change.
Bug: 113628807
Test: as follows
- built, flashed, booted
- system/netd/tests/runtests.sh
- DNS-over-TLS in live network passed
Change-Id: Ieac5889b4ebe737f876b3dcbe1a8da2b2b1b629d
diff --git a/resolv/Android.bp b/resolv/Android.bp
index 61560bd..f26d820 100644
--- a/resolv/Android.bp
+++ b/resolv/Android.bp
@@ -16,6 +16,12 @@
"res_send.cpp",
"res_state.cpp",
"res_stats.cpp",
+ "DnsTlsDispatcher.cpp",
+ "DnsTlsQueryMap.cpp",
+ "DnsTlsTransport.cpp",
+ "DnsTlsServer.cpp",
+ "DnsTlsSessionCache.cpp",
+ "DnsTlsSocket.cpp",
],
// Link everything statically (except for libc) to minimize our dependence
// on system ABIs
@@ -23,6 +29,14 @@
static_libs: [
"libbase",
"liblog",
+ "libnetdutils",
+ "libssl",
+ "libcrypto",
+ ],
+ // TODO: Remove this after we find an alternative way to substitute Fwmark
+ // used in DnsTls* classes.
+ include_dirs: [
+ "system/netd/include",
],
export_include_dirs: ["include"],
// TODO: pie in the sky: make this code clang-tidy clean
@@ -48,10 +62,8 @@
shared_libs: [
"libbase",
"liblog",
- "libssl",
- ],
- static_libs: [
- "libnetddns",
"libnetdutils",
+ "libnetd_resolv",
+ "libssl",
],
}
diff --git a/resolv/DnsTlsDispatcher.cpp b/resolv/DnsTlsDispatcher.cpp
new file mode 100644
index 0000000..9d5d3d5
--- /dev/null
+++ b/resolv/DnsTlsDispatcher.cpp
@@ -0,0 +1,180 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#define LOG_TAG "DnsTlsDispatcher"
+//#define LOG_NDEBUG 0
+
+#include "netd_resolv/DnsTlsDispatcher.h"
+#include "netd_resolv/DnsTlsSocketFactory.h"
+
+#include "log/log.h"
+
+namespace android {
+namespace net {
+
+using netdutils::Slice;
+
+// static: one lock shared by all dispatchers; every mStore access below holds it.
+std::mutex DnsTlsDispatcher::sLock;
+
+DnsTlsDispatcher::DnsTlsDispatcher() {
+    mFactory = std::make_unique<DnsTlsSocketFactory>();  // creates real TLS sockets
+}
+
+std::list<DnsTlsServer> DnsTlsDispatcher::getOrderedServerList(
+        const std::list<DnsTlsServer> &tlsServers, unsigned mark) const {
+    // Servers are returned in this order of preference:
+    //   1) IPv6 servers for which a transport already exists (reuse the connection)
+    //   2) IPv4 servers for which a transport already exists
+    //   3) IPv6 servers that need a new connection
+    //   4) IPv4 servers that need a new connection
+    // Within each group the order of tlsServers is kept. Servers of any other
+    // address family are dropped.
+    std::list<DnsTlsServer> reuse6;
+    std::list<DnsTlsServer> reuse4;
+    std::list<DnsTlsServer> fresh6;
+    std::list<DnsTlsServer> fresh4;
+
+    {
+        // mStore is also modified by query() and cleanup(), so read it under sLock.
+        std::lock_guard guard(sLock);
+
+        for (const auto& tlsServer : tlsServers) {
+            // "Existing" means mStore already holds a transport for (mark, server).
+            const Key key = std::make_pair(mark, tlsServer);
+            const bool existing = mStore.find(key) != mStore.end();
+
+            // Choose the bucket for this server; unsupported families get none.
+            std::list<DnsTlsServer>* bucket = nullptr;
+            if (tlsServer.ss.ss_family == AF_INET6) {
+                bucket = existing ? &reuse6 : &fresh6;
+            } else if (tlsServer.ss.ss_family == AF_INET) {
+                bucket = existing ? &reuse4 : &fresh4;
+            }
+            if (bucket != nullptr) {
+                bucket->push_back(tlsServer);
+            }
+        }
+    }
+
+    // Concatenate the buckets in preference order. splice() relinks list nodes
+    // instead of copying DnsTlsServer objects.
+    std::list<DnsTlsServer> ordered;
+    ordered.splice(ordered.end(), reuse6);
+    ordered.splice(ordered.end(), reuse4);
+    ordered.splice(ordered.end(), fresh6);
+    ordered.splice(ordered.end(), fresh4);
+
+    // All buckets are now empty; ordered holds every server of a supported family.
+    return ordered;
+}
+
+DnsTlsTransport::Response DnsTlsDispatcher::query(
+        const std::list<DnsTlsServer> &tlsServers, unsigned mark,
+        const Slice query, const Slice ans, int *resplen) {
+    // Try servers in preference order until one gives a definitive answer.
+    const std::list<DnsTlsServer> candidates(getOrderedServerList(tlsServers, mark));
+    if (candidates.empty()) {
+        ALOGW("Empty DnsTlsServer list");
+    }
+    // Reported as-is when there is no candidate to try.
+    DnsTlsTransport::Response lastCode = DnsTlsTransport::Response::internal_error;
+    for (const auto& candidate : candidates) {
+        lastCode = this->query(candidate, mark, query, ans, resplen);
+        // No "default" label on purpose, so the compiler can warn (-Wswitch) about any
+        // Response value that is not classified here.
+        switch (lastCode) {
+            case DnsTlsTransport::Response::success:
+            case DnsTlsTransport::Response::limit_error:
+                // Valid responses that are not expected to change with another server.
+                return lastCode;
+            case DnsTlsTransport::Response::network_error:
+            case DnsTlsTransport::Response::internal_error:
+                // Another server might do better, so keep iterating.
+                break;
+        }
+    }
+
+    // Every server failed; report the result of the last attempt.
+    return lastCode;
+}
+
+DnsTlsTransport::Response DnsTlsDispatcher::query(const DnsTlsServer& server, unsigned mark,
+                                                  const Slice query,
+                                                  const Slice ans, int *resplen) {
+    // Find or create the transport for (mark, server); useCount pins it against cleanup().
+    const Key key = std::make_pair(mark, server);
+    Transport* xport;
+    {
+        std::lock_guard guard(sLock);
+        auto& slot = mStore[key];
+        if (!slot) {
+            slot = std::make_unique<Transport>(server, mark, mFactory.get());
+        }
+        xport = slot.get();
+        ++xport->useCount;
+    }
+    ALOGV("Sending query of length %zu", query.size());
+    auto res = xport->transport.query(query);
+    ALOGV("Awaiting response");
+    const auto& result = res.get();
+    DnsTlsTransport::Response code = result.code;
+    if (code == DnsTlsTransport::Response::success) {
+        // Never overflow the caller's buffer: an oversized answer becomes limit_error.
+        if (result.response.size() > ans.size()) {
+            ALOGV("Response too large: %zu > %zu", result.response.size(), ans.size());
+            code = DnsTlsTransport::Response::limit_error;
+        } else {
+            ALOGV("Got response successfully");
+            *resplen = static_cast<int>(result.response.size());
+            netdutils::copy(ans, netdutils::makeSlice(result.response));
+        }
+    } else {
+        ALOGV("Query failed: %u", static_cast<unsigned int>(code));
+    }
+
+    // Unpin the transport and let cleanup() evict idle transports, all under sLock.
+    auto now = std::chrono::steady_clock::now();
+    {
+        std::lock_guard guard(sLock);
+        --xport->useCount;
+        xport->lastUsed = now;
+        cleanup(now);
+    }
+    return code;
+}
+
+// This timeout effectively controls how long to keep SSL session tickets: each transport's
+// DnsTlsTransport keeps its own session cache, and idle transports are dropped after it.
+static constexpr std::chrono::minutes IDLE_TIMEOUT(5);
+void DnsTlsDispatcher::cleanup(std::chrono::time_point<std::chrono::steady_clock> now) {
+    // Erases transports that are unused and idle for longer than IDLE_TIMEOUT. The scan runs
+    // at most once per IDLE_TIMEOUT so that not every query pays for a pass over mStore.
+    if (now - mLastCleanup < IDLE_TIMEOUT) return;
+    for (auto it = mStore.begin(); it != mStore.end();) {
+        const Transport& entry = *it->second;
+        const bool idle = entry.useCount == 0 && now - entry.lastUsed > IDLE_TIMEOUT;
+        if (idle) {
+            it = mStore.erase(it);
+        } else {
+            ++it;
+        }
+    }
+    mLastCleanup = now;
+}
+
+} // end of namespace net
+} // end of namespace android
diff --git a/resolv/DnsTlsQueryMap.cpp b/resolv/DnsTlsQueryMap.cpp
new file mode 100644
index 0000000..97f4eb6
--- /dev/null
+++ b/resolv/DnsTlsQueryMap.cpp
@@ -0,0 +1,149 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#define LOG_TAG "DnsTlsQueryMap"
+//#define LOG_NDEBUG 0
+
+#include "netd_resolv/DnsTlsQueryMap.h"
+
+#include "log/log.h"
+
+namespace android {
+namespace net {
+
+std::unique_ptr<DnsTlsQueryMap::QueryFuture> DnsTlsQueryMap::recordQuery(const Slice query) {
+    std::lock_guard guard(mLock);
+
+    // The query must at least contain the 2-byte DNS ID that gets replaced on the wire.
+    if (query.size() < 2) {
+        ALOGW("Query is too short");
+        return nullptr;
+    }
+    const int32_t freeId = getFreeId();
+    if (freeId < 0) {
+        ALOGW("All query IDs are in use");
+        return nullptr;
+    }
+    // Keep the query so that the response can be matched to it, or it can be reissued.
+    const Query pending = { .newId = static_cast<uint16_t>(freeId), .query = query };
+    const auto [slot, inserted] = mQueries.emplace(freeId, pending);
+    if (!inserted) {
+        ALOGE("Failed to store pending query");
+        return nullptr;
+    }
+    // The promise is fulfilled later by onResponse() or expire().
+    return std::make_unique<QueryFuture>(pending, slot->second.result.get_future());
+}
+
+void DnsTlsQueryMap::expire(QueryPromise* p) {
+    // Fails the waiting caller with network_error; erasing the entry is the caller's job.
+    p->result.set_value(Result{.code = Response::network_error});
+}
+
+void DnsTlsQueryMap::markTried(uint16_t newId) {
+    std::lock_guard guard(mLock);
+    // IDs that are no longer pending (e.g. already answered) are ignored.
+    if (auto it = mQueries.find(newId); it != mQueries.end()) {
+        ++it->second.tries;
+    }
+}
+
+void DnsTlsQueryMap::cleanup() {
+    std::lock_guard guard(mLock);
+    // Fail and forget every query whose tries count has reached kMaxTries.
+    for (auto it = mQueries.begin(); it != mQueries.end();) {
+        if (it->second.tries < kMaxTries) {
+            ++it;
+            continue;
+        }
+        expire(&it->second);
+        it = mQueries.erase(it);
+    }
+}
+
+int32_t DnsTlsQueryMap::getFreeId() {
+    // Fast path: use 0 for an empty map, otherwise one past the current maximum.
+    if (mQueries.empty()) {
+        return 0;
+    }
+    const uint16_t highest = mQueries.rbegin()->first;
+    if (highest != UINT16_MAX) {
+        return highest + 1;
+    }
+    // All 65536 16-bit IDs are taken.
+    if (mQueries.size() > UINT16_MAX) {
+        return -1;
+    }
+    // Slow path: keys are sorted and unique, so the first key that differs from its
+    // position marks a gap, and that position is free.
+    int32_t candidate = 0;
+    for (const auto& entry : mQueries) {
+        if (entry.first != candidate) {
+            return candidate;
+        }
+        ++candidate;
+    }
+    // Unreachable: a map ending at UINT16_MAX with fewer than 65536 keys has a gap.
+    return -1;
+}
+
+std::vector<DnsTlsQueryMap::Query> DnsTlsQueryMap::getAll() {
+    std::lock_guard guard(mLock);
+    // Copy out every pending query, in ascending new-ID order (the map's key order).
+    std::vector<DnsTlsQueryMap::Query> snapshot;
+    snapshot.reserve(mQueries.size());
+    for (const auto& entry : mQueries) snapshot.push_back(entry.second.query);
+    return snapshot;
+}
+
+bool DnsTlsQueryMap::empty() {
+    std::lock_guard guard(mLock);
+    return mQueries.size() == 0;  // true when no query is pending
+}
+
+void DnsTlsQueryMap::clear() {
+    std::lock_guard guard(mLock);
+    // Fail every pending query with network_error (via expire()), then forget
+    // them all at once.
+    for (auto it = mQueries.begin(); it != mQueries.end(); ++it) expire(&it->second);
+    mQueries.clear();
+}
+
+void DnsTlsQueryMap::onResponse(std::vector<uint8_t> response) {
+    ALOGV("Got response of size %zu", response.size());
+    // The first two bytes of a DNS message are its big-endian ID.
+    if (response.size() < 2) {
+        ALOGW("Response is too short");
+        return;
+    }
+    const uint16_t id = static_cast<uint16_t>((response[0] << 8) | response[1]);
+    std::lock_guard guard(mLock);
+    const auto match = mQueries.find(id);
+    if (match == mQueries.end()) {
+        ALOGW("Discarding response: unknown ID %d", id);
+        return;
+    }
+    // The wire carried our new ID; put back the ID from the caller's original query.
+    const uint8_t* originalId = match->second.query.query.base();
+    response[0] = originalId[0];
+    response[1] = originalId[1];
+    ALOGV("Sending result to dispatcher");
+    match->second.result.set_value(Result{.code = Response::success, .response = std::move(response)});
+    mQueries.erase(match);
+}
+
+} // end of namespace net
+} // end of namespace android
diff --git a/resolv/DnsTlsServer.cpp b/resolv/DnsTlsServer.cpp
new file mode 100644
index 0000000..b1ee481
--- /dev/null
+++ b/resolv/DnsTlsServer.cpp
@@ -0,0 +1,133 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "netd_resolv/DnsTlsServer.h"
+
+#include <algorithm>
+
+namespace {
+
+// Returns a tuple of references to the fields that identify an IPv4 endpoint
+// (port, then address), so that sockaddr_in values can be compared with the
+// std::tuple relational operators.
+auto make_tie(const sockaddr_in& a) {
+    return std::tie(a.sin_port, a.sin_addr.s_addr);
+}
+
+// IPv6 counterpart: port, address, then scope ID. sin6_flowinfo is skipped
+// because it is not relevant here.
+auto make_tie(const sockaddr_in6& a) {
+    return std::tie(a.sin6_port,
+                    a.sin6_addr,
+                    a.sin6_scope_id);
+}
+
+}  // namespace
+
+// These binary operators make sockaddr_storage comparable. They need to be in the
+// global namespace so that the std::tuple < and == operators can find them.
+// in6_addr values are ordered byte by byte over s6_addr.
+static bool operator <(const in6_addr& x, const in6_addr& y) {
+    const auto* xb = std::begin(x.s6_addr);
+    const auto* yb = std::begin(y.s6_addr);
+    return std::lexicographical_compare(xb, std::end(x.s6_addr), yb, std::end(y.s6_addr));
+}
+
+// Two IPv6 addresses are equal when every byte of s6_addr matches.
+static bool operator ==(const in6_addr& x, const in6_addr& y) {
+    return std::equal(std::begin(x.s6_addr), std::end(x.s6_addr), std::begin(y.s6_addr));
+}
+
+static bool operator <(const sockaddr_storage& x, const sockaddr_storage& y) {
+    // Order by address family first, then by the family-specific fields.
+    if (x.ss_family != y.ss_family) return x.ss_family < y.ss_family;
+    switch (x.ss_family) {
+        case AF_INET:
+            return make_tie(reinterpret_cast<const sockaddr_in&>(x)) <
+                   make_tie(reinterpret_cast<const sockaddr_in&>(y));
+        case AF_INET6:
+            return make_tie(reinterpret_cast<const sockaddr_in6&>(x)) <
+                   make_tie(reinterpret_cast<const sockaddr_in6&>(y));
+        default:
+            // Unknown address type. This is an error; such addresses never compare
+            // less than one another.
+            return false;
+    }
+}
+
+static bool operator ==(const sockaddr_storage& x, const sockaddr_storage& y) {
+    // Addresses of different families are never equal.
+    if (x.ss_family != y.ss_family) return false;
+    switch (x.ss_family) {
+        case AF_INET:
+            return make_tie(reinterpret_cast<const sockaddr_in&>(x)) ==
+                   make_tie(reinterpret_cast<const sockaddr_in&>(y));
+        case AF_INET6:
+            return make_tie(reinterpret_cast<const sockaddr_in6&>(x)) ==
+                   make_tie(reinterpret_cast<const sockaddr_in6&>(y));
+        default:
+            // Unknown address type. This is an error; such addresses are never
+            // considered equal, not even to themselves.
+            return false;
+    }
+}
+
+namespace android {
+namespace net {
+
+// Orders servers by address family, then IP address (plus scope ID for IPv6). Ports, names,
+// fingerprints and protocol are ignored; IPv4 compares raw network-order s_addr values.
+bool AddressComparator::operator() (const DnsTlsServer& x, const DnsTlsServer& y) const {
+    if (x.ss.ss_family != y.ss.ss_family) return x.ss.ss_family < y.ss.ss_family;
+    switch (x.ss.ss_family) {
+        case AF_INET: {
+            const auto& a = reinterpret_cast<const sockaddr_in&>(x.ss);
+            const auto& b = reinterpret_cast<const sockaddr_in&>(y.ss);
+            return a.sin_addr.s_addr < b.sin_addr.s_addr;
+        }
+        case AF_INET6: {
+            const auto& a = reinterpret_cast<const sockaddr_in6&>(x.ss);
+            const auto& b = reinterpret_cast<const sockaddr_in6&>(y.ss);
+            return std::tie(a.sin6_addr, a.sin6_scope_id) < std::tie(b.sin6_addr, b.sin6_scope_id);
+        }
+        default: return false;  // Unknown address type. This is an error.
+    }
+}
+
+// Returns a tuple of references to the fields that make up a server's identity, so that
+// operator< and operator== agree. static: a file-local helper needs no external linkage.
+static auto make_tie(const DnsTlsServer& s) {
+    return std::tie(
+        s.ss,
+        s.name,
+        s.fingerprints,
+        s.protocol);
+}
+
+bool DnsTlsServer::operator <(const DnsTlsServer& other) const {
+ return make_tie(*this) < make_tie(other); // lexicographic over address, name, fingerprints, protocol
+}
+// Equal only when address, name, fingerprints and protocol all match.
+bool DnsTlsServer::operator ==(const DnsTlsServer& other) const {
+ return make_tie(*this) == make_tie(other);
+}
+
+bool DnsTlsServer::wasExplicitlyConfigured() const {
+    return !(name.empty() && fingerprints.empty());  // a hostname or a pinned fingerprint was given
+}
+
+} // namespace net
+} // namespace android
diff --git a/resolv/DnsTlsSessionCache.cpp b/resolv/DnsTlsSessionCache.cpp
new file mode 100644
index 0000000..54c1296
--- /dev/null
+++ b/resolv/DnsTlsSessionCache.cpp
@@ -0,0 +1,77 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "netd_resolv/DnsTlsSessionCache.h"
+
+#define LOG_TAG "DnsTlsSessionCache"
+//#define LOG_NDEBUG 0
+
+#include "log/log.h"
+
+namespace android {
+namespace net {
+
+bool DnsTlsSessionCache::prepareSsl(SSL* ssl) {
+    // Stores this cache at ex_data index 0 of ssl, so that the static
+    // newSessionCallback() can find this instance again.
+    // Returns true if SSL_set_ex_data() reported success.
+    return SSL_set_ex_data(ssl, 0, this) == 1;
+}
+
+void DnsTlsSessionCache::prepareSslContext(SSL_CTX* ssl_ctx) {
+    SSL_CTX_sess_set_new_cb(ssl_ctx, &DnsTlsSessionCache::newSessionCallback);  // new sessions -> newSessionCallback()
+    SSL_CTX_set_session_cache_mode(ssl_ctx, SSL_SESS_CACHE_CLIENT);  // enable client-side session caching
+}
+
+// static: BoringSSL new-session callback. Returning 1 means this code keeps the session reference.
+int DnsTlsSessionCache::newSessionCallback(SSL* ssl, SSL_SESSION* session) {
+    if (!ssl || !session) {
+        ALOGE("Null SSL object in new session callback");
+        return 0;
+    }
+    // prepareSsl() stored the owning cache at ex_data index 0; void* -> T* is a static_cast.
+    auto* cache = static_cast<DnsTlsSessionCache*>(SSL_get_ex_data(ssl, 0));
+    if (!cache) {
+        ALOGE("null transport in new session callback");
+        return 0;
+    }
+    ALOGV("Recording session");
+    cache->recordSession(session);  // The cache takes ownership of this reference.
+    return 1;
+}
+
+void DnsTlsSessionCache::recordSession(SSL_SESSION* session) {
+    std::lock_guard guard(mLock);
+    // Newest first; once kMaxSize is exceeded the oldest session is evicted.
+    mSessions.emplace_front(session);
+    if (mSessions.size() <= kMaxSize) return;
+    ALOGV("Too many sessions; trimming");
+    mSessions.pop_back();
+}
+
+bssl::UniquePtr<SSL_SESSION> DnsTlsSessionCache::getSession() {
+    std::lock_guard guard(mLock);
+    if (mSessions.empty()) {
+        ALOGV("No known sessions");
+        return nullptr;
+    }
+    auto newest = std::move(mSessions.front());  // most recent; each session is handed out once
+    mSessions.pop_front();
+    return newest;
+}
+
+} // end of namespace net
+} // end of namespace android
diff --git a/resolv/DnsTlsSocket.cpp b/resolv/DnsTlsSocket.cpp
new file mode 100644
index 0000000..425aa17
--- /dev/null
+++ b/resolv/DnsTlsSocket.cpp
@@ -0,0 +1,527 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#define LOG_TAG "DnsTlsSocket"
+//#define LOG_NDEBUG 0
+
+#include "netd_resolv/DnsTlsSocket.h"
+
+#include <algorithm>
+#include <arpa/inet.h>
+#include <arpa/nameser.h>
+#include <errno.h>
+#include <linux/tcp.h>
+#include <openssl/err.h>
+#include <openssl/sha.h>
+#include <sys/poll.h>
+
+#include "netd_resolv/DnsTlsSessionCache.h"
+#include "netd_resolv/IDnsTlsSocketObserver.h"
+
+#include "log/log.h"
+#include "netdutils/SocketOption.h"
+#include "Permission.h"
+
+namespace android {
+
+using netdutils::enableSockopt;
+using netdutils::enableTcpKeepAlives;
+using netdutils::isOk;
+using netdutils::Status;
+
+namespace net {
+namespace {
+
+constexpr const char kCaCertDir[] = "/system/etc/security/cacerts";
+constexpr size_t SHA256_SIZE = SHA256_DIGEST_LENGTH;
+
+// Blocks (no timeout, retrying on EINTR) until fd has one of `events` pending.
+// Returns poll()'s result: 1 once fd has an event (possibly POLLERR/POLLHUP), -1 on error.
+int waitForEvents(int fd, short events) {
+    struct pollfd fds = { .fd = fd, .events = events };
+    return TEMP_FAILURE_RETRY(poll(&fds, 1, -1));
+}
+
+// Waits until fd is readable; see waitForEvents() for the return value.
+int waitForReading(int fd) { return waitForEvents(fd, POLLIN); }
+
+// Waits until fd is writable; see waitForEvents() for the return value.
+int waitForWriting(int fd) { return waitForEvents(fd, POLLOUT); }
+}  // namespace
+
+Status DnsTlsSocket::tcpConnect() {
+    ALOGV("%u connecting TCP socket", mMark);
+    int type = SOCK_NONBLOCK | SOCK_CLOEXEC;
+    switch (mServer.protocol) {
+        case IPPROTO_TCP:
+            type |= SOCK_STREAM;
+            break;
+        default:
+            return Status(EPROTONOSUPPORT);
+    }
+    // Each failure path saves errno right after the failing call, before logging or
+    // closing the socket can overwrite it.
+    mSslFd.reset(socket(mServer.ss.ss_family, type, mServer.protocol));
+    if (mSslFd.get() == -1) {
+        const int err = errno;
+        ALOGE("Failed to create socket");
+        return Status(err);
+    }
+    const socklen_t len = sizeof(mMark);
+    if (setsockopt(mSslFd.get(), SOL_SOCKET, SO_MARK, &mMark, len) == -1) {
+        const int err = errno;
+        ALOGE("Failed to set socket mark");
+        mSslFd.reset();
+        return Status(err);
+    }
+    // TCP Fast Open is best-effort; ENOPROTOOPT (option unsupported) is not logged.
+    const Status tfo = enableSockopt(mSslFd.get(), SOL_TCP, TCP_FASTOPEN_CONNECT);
+    if (!isOk(tfo) && tfo.code() != ENOPROTOOPT) {
+        ALOGI("Failed to enable TFO: %s", tfo.msg().c_str());
+    }
+    // Send 5 keepalives, 3 seconds apart, after 15 seconds of inactivity.
+    enableTcpKeepAlives(mSslFd.get(), 15U, 5U, 3U).ignoreError();
+    // Non-blocking connect: EINPROGRESS means the connection is still being set up.
+    if (connect(mSslFd.get(), reinterpret_cast<const struct sockaddr *>(&mServer.ss),
+                sizeof(mServer.ss)) != 0 && errno != EINPROGRESS) {
+        const int err = errno;
+        ALOGV("Socket failed to connect");
+        mSslFd.reset();
+        return Status(err);
+    }
+    return netdutils::status::ok;
+}
+
+// SHA-256 of cert's DER-encoded SubjectPublicKeyInfo, the value compared with fingerprints.
+bool getSPKIDigest(const X509* cert, std::vector<uint8_t>* out) {
+    const int spki_len = i2d_X509_PUBKEY(X509_get_X509_PUBKEY(cert), nullptr);
+    std::vector<unsigned char> spki(std::max(spki_len, 0));  // no VLA; a length < 0 means failure
+    unsigned char* temp = spki.data();
+    if (spki_len <= 0 || spki_len != i2d_X509_PUBKEY(X509_get_X509_PUBKEY(cert), &temp)) {
+        ALOGW("SPKI length mismatch");
+        return false;
+    }
+    out->resize(SHA256_SIZE);
+    unsigned int digest_len = 0;
+    if (EVP_Digest(spki.data(), spki.size(), out->data(), &digest_len, EVP_sha256(), nullptr) != 1) {
+        ALOGW("Server cert digest extraction failed");
+        return false;
+    }
+    if (digest_len != out->size()) {
+        ALOGW("Wrong digest length: %u", digest_len);
+        return false;
+    }
+    return true;
+}
+
+bool DnsTlsSocket::initialize() {
+    // Sets up the TLS context, connects, and starts the I/O loop; returns false on failure.
+    // Must only be called once; the lock merely helps catch callers that violate this.
+    std::lock_guard guard(mLock);
+    if (mSslCtx) {
+        // This is a bug in the caller.
+        return false;
+    }
+    mSslCtx.reset(SSL_CTX_new(TLS_method()));
+    if (!mSslCtx) {
+        return false;
+    }
+
+    // Load system CA certs for hostname verification.
+    // For discussion of alternative, sustainable approaches see b/71909242.
+    if (SSL_CTX_load_verify_locations(mSslCtx.get(), nullptr, kCaCertDir) != 1) {
+        ALOGE("Failed to load CA cert dir: %s", kCaCertDir);
+        return false;
+    }
+
+    // Enable TLS false start
+    SSL_CTX_set_false_start_allowed_without_alpn(mSslCtx.get(), 1);
+    SSL_CTX_set_mode(mSslCtx.get(), SSL_MODE_ENABLE_FALSE_START);
+
+    // Enable session cache
+    mCache->prepareSslContext(mSslCtx.get());
+
+    // Connect
+    Status status = tcpConnect();
+    if (!status.ok()) {
+        return false;
+    }
+    mSsl = sslConnect(mSslFd.get());
+    if (!mSsl) {
+        return false;
+    }
+    // SOCK_CLOEXEC, like the TCP socket: the IPC pair must not leak into exec'd children.
+    int sv[2];
+    if (socketpair(AF_LOCAL, SOCK_SEQPACKET | SOCK_CLOEXEC, 0, sv)) {
+        return false;
+    }
+    // The two sockets are perfectly symmetrical, so the choice of which one is
+    // "in" and which one is "out" is arbitrary.
+    mIpcInFd.reset(sv[0]);
+    mIpcOutFd.reset(sv[1]);
+
+    // Start the I/O loop.
+    mLoopThread = std::make_unique<std::thread>(&DnsTlsSocket::loop, this);
+
+    return true;
+}
+
+// Runs the TLS handshake over fd and, if fingerprints are configured, checks them.
+// Returns the connected SSL object, or null on any failure.
+bssl::UniquePtr<SSL> DnsTlsSocket::sslConnect(int fd) {
+    if (!mSslCtx) {
+        ALOGE("Internal error: context is null in sslConnect");
+        return nullptr;
+    }
+    if (!SSL_CTX_set_min_proto_version(mSslCtx.get(), TLS1_2_VERSION)) {
+        ALOGE("Failed to set minimum TLS version");
+        return nullptr;
+    }
+
+    bssl::UniquePtr<SSL> ssl(SSL_new(mSslCtx.get()));
+    // This file descriptor is owned by mSslFd, so don't let libssl close it.
+    bssl::UniquePtr<BIO> bio(BIO_new_socket(fd, BIO_NOCLOSE));
+    if (!ssl || !bio) {
+        // Allocation failed; continuing would hand a null SSL* to the calls below.
+        ALOGE("Failed to allocate SSL or BIO");
+        return nullptr;
+    }
+    SSL_set_bio(ssl.get(), bio.get(), bio.get());
+    bio.release();  // ssl now owns the BIO.
+    if (!mCache->prepareSsl(ssl.get())) {
+        return nullptr;
+    }
+
+    if (!mServer.name.empty()) {
+        if (SSL_set_tlsext_host_name(ssl.get(), mServer.name.c_str()) != 1) {
+            ALOGE("Failed to set SNI to %s", mServer.name.c_str());
+            return nullptr;
+        }
+        X509_VERIFY_PARAM* param = SSL_get0_param(ssl.get());
+        if (X509_VERIFY_PARAM_set1_host(param, mServer.name.data(), mServer.name.size()) != 1) {
+            ALOGE("Failed to set verify host param to %s", mServer.name.c_str());
+            return nullptr;
+        }
+        // This will cause the handshake to fail if certificate verification fails.
+        SSL_set_verify(ssl.get(), SSL_VERIFY_PEER, nullptr);
+    }
+
+    bssl::UniquePtr<SSL_SESSION> session = mCache->getSession();
+    if (session) {
+        ALOGV("Setting session");
+        SSL_set_session(ssl.get(), session.get());
+    } else {
+        ALOGV("No session available");
+    }
+
+    // fd is expected to be non-blocking (tcpConnect() uses SOCK_NONBLOCK): poll and retry.
+    for (;;) {
+        ALOGV("%u Calling SSL_connect", mMark);
+        int ret = SSL_connect(ssl.get());
+        ALOGV("%u SSL_connect returned %d", mMark, ret);
+        if (ret == 1) break;  // SSL handshake complete;
+
+        const int ssl_err = SSL_get_error(ssl.get(), ret);
+        switch (ssl_err) {
+            case SSL_ERROR_WANT_READ:
+                if (waitForReading(fd) != 1) {
+                    ALOGW("SSL_connect read error: %d", errno);
+                    return nullptr;
+                }
+                break;
+            case SSL_ERROR_WANT_WRITE:
+                if (waitForWriting(fd) != 1) {
+                    ALOGW("SSL_connect write error");
+                    return nullptr;
+                }
+                break;
+            default:
+                ALOGW("SSL_connect error %d, errno=%d", ssl_err, errno);
+                return nullptr;
+        }
+    }
+
+    // TODO: Call SSL_shutdown before discarding the session if validation fails.
+    if (!mServer.fingerprints.empty()) {
+        ALOGV("Checking DNS over TLS fingerprint");
+        // We only care that the chain is internally self-consistent, not that it chains
+        // to a trusted root, so we can ignore some kinds of errors.
+        // TODO: Add a CA root verification mode that respects these errors.
+        int verify_result = SSL_get_verify_result(ssl.get());
+        switch (verify_result) {
+            case X509_V_OK:
+            case X509_V_ERR_DEPTH_ZERO_SELF_SIGNED_CERT:
+            case X509_V_ERR_SELF_SIGNED_CERT_IN_CHAIN:
+            case X509_V_ERR_CERT_UNTRUSTED:
+                break;
+            default:
+                ALOGW("Invalid certificate chain, error %d", verify_result);
+                return nullptr;
+        }
+
+        STACK_OF(X509) *chain = SSL_get_peer_cert_chain(ssl.get());
+        if (!chain) {
+            ALOGW("Server has null certificate");
+            return nullptr;
+        }
+        // Chain and its contents are owned by ssl, so we don't need to free explicitly.
+        // The server passes if any certificate in the chain has a pinned SPKI digest.
+        bool matched = false;
+        for (size_t i = 0; i < sk_X509_num(chain); ++i) {
+            X509* cert = sk_X509_value(chain, i);
+            std::vector<uint8_t> digest;
+            if (!getSPKIDigest(cert, &digest)) {
+                ALOGE("Digest computation failed");
+                return nullptr;
+            }
+            if (mServer.fingerprints.count(digest) > 0) {
+                matched = true;
+                break;
+            }
+        }
+        if (!matched) {
+            ALOGW("No matching fingerprint");
+            return nullptr;
+        }
+        ALOGV("DNS over TLS fingerprint is correct");
+    }
+    ALOGV("%u handshake complete", mMark);
+    return ssl;
+}
+
+void DnsTlsSocket::sslDisconnect() {
+    // Best-effort TLS shutdown (its result is ignored), then free the SSL object
+    // and close the underlying socket.
+    if (mSsl) SSL_shutdown(mSsl.get());
+    mSsl.reset();
+    mSslFd.reset();
+}
+
+// Writes all of buffer to the TLS connection, polling for writability as needed.
+bool DnsTlsSocket::sslWrite(const Slice buffer) {
+    ALOGV("%u Writing %zu bytes", mMark, buffer.size());
+    const uint8_t* data = buffer.base();
+    size_t remaining = buffer.size();
+    while (remaining > 0) {
+        const int ret = SSL_write(mSsl.get(), data, static_cast<int>(remaining));
+        if (ret > 0) {
+            data += ret;  // all of it, unless SSL partial writes are enabled
+            remaining -= ret;
+            continue;
+        }
+        // Anything other than WANT_WRITE (including a bogus SSL_ERROR_NONE) is fatal.
+        const int ssl_err = SSL_get_error(mSsl.get(), ret);
+        if (ssl_err != SSL_ERROR_WANT_WRITE) {
+            ALOGV("SSL_write error %d", ssl_err);
+            return false;
+        }
+        if (waitForWriting(mSslFd.get()) != 1) {
+            ALOGV("SSL_write error");
+            return false;
+        }
+    }
+    ALOGV("%u Wrote %zu bytes", mMark, buffer.size());
+    return true;
+}
+
+void DnsTlsSocket::loop() {
+ std::lock_guard guard(mLock); // Held for the whole loop, including the onClosed() call below.
+ // Buffer at most one query: while q is non-empty, no new query is read from the IPC socket.
+ Query q;
+
+ const int timeout_msecs = DnsTlsSocket::kIdleTimeout.count() * 1000; // assumes kIdleTimeout is in seconds - TODO confirm
+ while (true) {
+ // poll() ignores negative fds, so a slot that is not needed this round stays at -1.
+ struct pollfd fds[2] = { { .fd = -1 }, { .fd = -1 } };
+ enum { SSLFD = 0, IPCFD = 1 };
+
+ // Always listen for a response from server.
+ fds[SSLFD].fd = mSslFd.get();
+ fds[SSLFD].events = POLLIN;
+
+ // If we have a pending query, also wait for space
+ // to write it, otherwise listen for a new query.
+ if (!q.query.empty()) {
+ fds[SSLFD].events |= POLLOUT;
+ } else {
+ fds[IPCFD].fd = mIpcOutFd.get();
+ fds[IPCFD].events = POLLIN;
+ }
+
+ const int s = TEMP_FAILURE_RETRY(poll(fds, std::size(fds), timeout_msecs)); // 0 = idle timeout
+ if (s == 0) {
+ ALOGV("Idle timeout");
+ break;
+ }
+ if (s < 0) {
+ ALOGV("Poll failed: %d", errno);
+ break;
+ }
+ if (fds[SSLFD].revents & (POLLIN | POLLERR)) {
+ if (!readResponse()) {
+ ALOGV("SSL remote close or read error.");
+ break;
+ }
+ }
+ if (fds[IPCFD].revents & (POLLIN | POLLERR)) { // only polled while q is empty
+ int res = read(mIpcOutFd.get(), &q, sizeof(q)); // one Query struct, as written by query()
+ if (res < 0) {
+ ALOGW("Error during IPC read");
+ break;
+ } else if (res == 0) { // the peer (mIpcInFd) was closed, e.g. by the destructor
+ ALOGV("IPC channel closed; disconnecting");
+ break;
+ } else if (res != sizeof(q)) {
+ ALOGE("Struct size mismatch: %d != %zu", res, sizeof(q));
+ break;
+ }
+ } else if (fds[SSLFD].revents & POLLOUT) {
+ // q.query is non-empty here: POLLOUT is only requested while a query is pending.
+ if (!sendQuery(q)) {
+ break;
+ }
+ q = Query(); // Reset q to empty
+ }
+ }
+ ALOGV("Closing IPC read FD"); // every exit from the loop tears down both channels
+ mIpcOutFd.reset();
+ ALOGV("Disconnecting");
+ sslDisconnect();
+ ALOGV("Calling onClosed");
+ mObserver->onClosed(); // NOTE(review): runs with mLock held; onClosed() must not destroy this socket on this thread - confirm
+ ALOGV("Ending loop");
+}
+
+DnsTlsSocket::~DnsTlsSocket() { // Must not run on the loop thread (checked below).
+ ALOGV("Destructor");
+ // Closing the write end makes loop()'s IPC read return 0, which triggers an orderly shutdown.
+ mIpcInFd.reset();
+ {
+ // loop() holds mLock while it runs, so taking it here normally waits for the orderly shutdown.
+ std::lock_guard guard(mLock);
+ if (mLoopThread && std::this_thread::get_id() == mLoopThread->get_id()) {
+ ALOGE("Violation of re-entrance precondition");
+ return; // NOTE(review): leaves mLoopThread joinable (std::terminate if it owns a std::thread), and on the loop thread the lock above would already re-lock mLock - confirm unreachable
+ }
+ }
+ if (mLoopThread) {
+ ALOGV("Waiting for loop thread to terminate");
+ mLoopThread->join();
+ mLoopThread.reset();
+ }
+ ALOGV("Destructor completed");
+}
+
+// Hands (id, query) to the loop thread over the IPC socket pair. Returns false if the
+// IPC socket is not open or the Query struct was not written in full (EINTR is retried).
+bool DnsTlsSocket::query(uint16_t id, const Slice query) {
+    if (mIpcInFd.get() < 0) return false;
+    const Query q = { .id = id, .query = query };
+    const ssize_t written = TEMP_FAILURE_RETRY(write(mIpcInFd.get(), &q, sizeof(q)));
+    return written == static_cast<ssize_t>(sizeof(q));
+}
+
+// Reads exactly buffer.size() bytes into buffer. Returns SSL_ERROR_NONE on success,
+// otherwise an SSL error code. If wait is false and nothing can be read yet, returns
+// SSL_ERROR_WANT_READ at once; once any byte has arrived it blocks until done.
+int DnsTlsSocket::sslRead(const Slice buffer, bool wait) {
+    size_t remaining = buffer.size();
+    while (remaining != 0) {
+        const int ret = SSL_read(mSsl.get(), buffer.limit() - remaining, remaining);
+        if (ret > 0) {
+            remaining -= ret;
+            wait = true;  // Once a read is started, try to finish.
+            continue;
+        }
+        if (ret == 0) {
+            ALOGW_IF(remaining < buffer.size(), "SSL closed with %zu of %zu bytes remaining",
+                     remaining, buffer.size());
+            return SSL_ERROR_ZERO_RETURN;
+        }
+        // ret < 0: only WANT_READ in blocking mode is retried, after polling.
+        const int ssl_err = SSL_get_error(mSsl.get(), ret);
+        if (!wait || ssl_err != SSL_ERROR_WANT_READ) {
+            ALOGV("SSL_read error %d", ssl_err);
+            return ssl_err;
+        }
+        if (waitForReading(mSslFd.get()) != 1) {
+            ALOGV("Poll failed in sslRead: %d", errno);
+            return SSL_ERROR_SYSCALL;
+        }
+    }
+    return SSL_ERROR_NONE;
+}
+
+bool DnsTlsSocket::sendQuery(const Query& q) {
+    ALOGV("sending query");
+    // Compose the entire message in a single buffer, so that it can be sent as a
+    // single TLS record. Layout: 2-byte big-endian length, 2-byte big-endian ID,
+    // then the query body (q.query does not include an ID of its own).
+    const uint16_t len = q.query.size() + 2;  // + 2 for the ID.
+    std::vector<uint8_t> buf;
+    buf.reserve(q.query.size() + 4);
+    buf.push_back(len >> 8);
+    buf.push_back(len & 0xff);
+    buf.push_back(q.id >> 8);
+    buf.push_back(q.id & 0xff);
+    buf.insert(buf.end(), q.query.base(), q.query.limit());
+
+    if (!sslWrite(netdutils::makeSlice(buf))) {
+        return false;
+    }
+    ALOGV("%u SSL_write complete", mMark);
+    return true;
+}
+
+// Reads one length-prefixed DNS response and hands it to the observer. Returns false if
+// the connection failed or closed; true otherwise, including for a spurious wakeup.
+bool DnsTlsSocket::readResponse() {
+    ALOGV("reading response");
+    uint8_t responseHeader[2];
+    const int headerErr = sslRead(Slice(responseHeader, 2), false);
+    if (headerErr == SSL_ERROR_WANT_READ) {
+        ALOGV("Ignoring spurious wakeup from server");
+        return true;
+    }
+    if (headerErr != SSL_ERROR_NONE) return false;
+
+    // Truncate responses larger than MAX_SIZE. This is safe because a DNS packet is
+    // always invalid when truncated, so the response will be treated as an error.
+    constexpr uint16_t MAX_SIZE = 8192;
+    const uint16_t responseSize = (responseHeader[0] << 8) | responseHeader[1];
+    ALOGV("%u Expecting response of size %i", mMark, responseSize);
+    std::vector<uint8_t> response(std::min(responseSize, MAX_SIZE));
+    if (sslRead(netdutils::makeSlice(response), true) != SSL_ERROR_NONE) {
+        ALOGV("%u Failed to read %zu bytes", mMark, response.size());
+        return false;
+    }
+    // Drain the truncated remainder in bounded chunks so that the stream stays in sync.
+    constexpr uint16_t CHUNK_SIZE = 2048;
+    std::vector<uint8_t> discard;
+    for (uint16_t left = responseSize - response.size(); left > 0; left -= discard.size()) {
+        discard.resize(std::min(left, CHUNK_SIZE));
+        if (sslRead(netdutils::makeSlice(discard), true) != SSL_ERROR_NONE) {
+            ALOGV("%u Failed to discard %zu bytes", mMark, discard.size());
+            return false;
+        }
+    }
+    ALOGV("%u SSL_read complete", mMark);
+    mObserver->onResponse(std::move(response));
+    return true;
+}
+
+} // end of namespace net
+} // end of namespace android
diff --git a/resolv/DnsTlsTransport.cpp b/resolv/DnsTlsTransport.cpp
new file mode 100644
index 0000000..b4294e2
--- /dev/null
+++ b/resolv/DnsTlsTransport.cpp
@@ -0,0 +1,224 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#define LOG_TAG "DnsTlsTransport"
+//#define LOG_NDEBUG 0
+
+#include "netd_resolv/DnsTlsTransport.h"
+
+#include <arpa/inet.h>
+#include <arpa/nameser.h>
+
+#include "netd_resolv/DnsTlsSocketFactory.h"
+#include "netd_resolv/IDnsTlsSocketFactory.h"
+
+#include "log/log.h"
+#include "Fwmark.h"
+#include "Permission.h"
+
+namespace android {
+namespace net {
+
+std::future<DnsTlsTransport::Result> DnsTlsTransport::query(const netdutils::Slice query) {
+    std::lock_guard guard(mLock);
+    // Register the query (assigning it a fresh ID) before touching the socket.
+    auto record = mQueries.recordQuery(query);
+    if (!record) {
+        return std::async(std::launch::deferred, [] {
+            return Result{Response::internal_error, {}};  // {code, empty response}
+        });
+    }
+    // Without a socket, doConnect() opens one and (re)sends everything in mQueries, this included.
+    if (!mSocket) {
+        ALOGV("No socket for query. Opening socket and sending.");
+        doConnect();
+    } else {
+        sendQuery(record->query);
+    }
+    // Resolved by DnsTlsQueryMap: on a response, on expiry, or with failure on clear().
+    return std::move(record->result);
+}
+
+bool DnsTlsTransport::sendQuery(const DnsTlsQueryMap::Query q) {
+    // The stored query still starts with the caller's 2-byte ID; drop it so the socket can
+    // prepend the renumbered ID (q.newId) instead.
+    const bool ok = mSocket->query(q.newId, netdutils::drop(q.query, 2));
+    if (!ok) return false;
+    mQueries.markTried(q.newId);
+    return true;
+}
+
+void DnsTlsTransport::doConnect() { // Called with mLock held, from query() and doReconnect().
+ ALOGV("Constructing new socket");
+ mSocket = mFactory->createDnsTlsSocket(mServer, mMark, this, &mCache); // Null on failure.
+ // On success, (re)send every pending query, including any sent on a previous socket.
+ if (mSocket) {
+ auto queries = mQueries.getAll();
+ ALOGV("Initialization succeeded. Reissuing %zu queries.", queries.size());
+ for(auto& q : queries) {
+ if (!sendQuery(q)) {
+ break; // A failed send means the socket is closing; it notifies its observer (onClosed).
+ }
+ }
+ } else {
+ ALOGV("Initialization failed.");
+ mSocket.reset(); // NOTE(review): mSocket is already null in this branch, so this is a no-op.
+ ALOGV("Failing all pending queries.");
+ mQueries.clear(); // Resolves every pending query's future with a failure result.
+ }
+}
+
+void DnsTlsTransport::onResponse(std::vector<uint8_t> response) {
+    mQueries.onResponse(std::move(response));  // The query map pairs it with a pending query by ID.
+}
+
+void DnsTlsTransport::onClosed() {
+    std::lock_guard guard(mLock);
+    if (mClosing) {
+        return;  // The destructor has started; it handles the remaining teardown.
+    }
+    // Move remaining operations to a new thread. This is necessary because
+    // 1. onClosed is currently running on a thread that blocks mSocket's destructor, and
+    // 2. doReconnect will call that destructor (by replacing or resetting mSocket).
+    if (mReconnectThread) {
+        // Complete cleanup of a previous reconnect thread, if present.
+        mReconnectThread->join();
+        // Joining a thread that is trying to acquire mLock, while holding mLock,
+        // looks like it risks a deadlock. However, a deadlock will not occur because
+        // once onClosed is called, it cannot be called again until after doReconnect
+        // acquires mLock.
+    }
+    // This thread is joined on the next close (above) or by the destructor.
+    mReconnectThread = std::make_unique<std::thread>(&DnsTlsTransport::doReconnect, this);
+}
+
+void DnsTlsTransport::doReconnect() { // Runs on mReconnectThread, which onClosed() spawns.
+ std::lock_guard guard(mLock);
+ if (mClosing) { // The destructor has begun; leave teardown to it.
+ return;
+ }
+ mQueries.cleanup(); // Expire queries that have reached the retry limit.
+ if (!mQueries.empty()) {
+ ALOGV("Fast reconnect to retry remaining queries");
+ doConnect(); // Opens a replacement socket and resends the remaining queries.
+ } else {
+ ALOGV("No pending queries. Going idle.");
+ mSocket.reset(); // Drop the closed socket; the next query() opens a fresh one.
+ }
+}
+
+DnsTlsTransport::~DnsTlsTransport() {
+ ALOGV("Destructor");
+ { // Under mLock: fail pending queries and stop any further reconnect attempts.
+ std::lock_guard guard(mLock);
+ ALOGV("Locked destruction procedure");
+ mQueries.clear(); // Resolves every outstanding future with a failure result.
+ mClosing = true; // Makes any later onClosed()/doReconnect() return immediately.
+ }
+ // It's possible that a reconnect thread was spawned and is waiting for mLock.
+ // It's safe for that thread to run now because mClosing is true (and mQueries is empty),
+ // but we need to wait for it to finish before allowing destruction to proceed.
+ if (mReconnectThread) {
+ ALOGV("Waiting for reconnect thread to terminate");
+ mReconnectThread->join();
+ mReconnectThread.reset();
+ }
+ // Ensure that the socket is destroyed, and can clean up its callback threads,
+ // before any of this object's fields become invalid.
+ mSocket.reset();
+ ALOGV("Destructor completed");
+}
+
+// static
+// Sends a random AAAA probe query to |server| over TLS on network |netid|; returns true if the
+// query succeeds and the reply has a complete DNS header with QDCOUNT == 1.
+// TODO: Use this function to preheat the session cache.
+// That may require moving it to DnsTlsDispatcher.
+bool DnsTlsTransport::validate(const DnsTlsServer& server, unsigned netid) {
+    ALOGV("Beginning validation on %u", netid);
+    // Look up "<random>-dnsotls-ds.metric.gstatic.com" via |server| to prove it serves DoT.
+    static const char kDnsSafeChars[] =
+            "abcdefghijklmnopqrstuvwxyz"
+            "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
+            "0123456789";
+    const auto c = [](uint8_t rnd) -> uint8_t {
+        // std::size() counts the literal's trailing NUL; exclude it so only alphanumerics result.
+        return kDnsSafeChars[rnd % (std::size(kDnsSafeChars) - 1)];
+    };
+    uint8_t rnd[8];
+    arc4random_buf(rnd, std::size(rnd));
+    // We could try to use res_mkquery() here, but it's basically the same.
+    uint8_t query[] = {
+        rnd[6], rnd[7],  // [0-1]   query ID
+        1, 0,  // [2-3]   flags; query[2] = 1 for recursion desired (RD).
+        0, 1,  // [4-5]   QDCOUNT (number of queries)
+        0, 0,  // [6-7]   ANCOUNT (number of answers)
+        0, 0,  // [8-9]   NSCOUNT (number of name server records)
+        0, 0,  // [10-11] ARCOUNT (number of additional records)
+        17, c(rnd[0]), c(rnd[1]), c(rnd[2]), c(rnd[3]), c(rnd[4]), c(rnd[5]),
+            '-', 'd', 'n', 's', 'o', 't', 'l', 's', '-', 'd', 's',
+        6, 'm', 'e', 't', 'r', 'i', 'c',
+        7, 'g', 's', 't', 'a', 't', 'i', 'c',
+        3, 'c', 'o', 'm',
+        0,  // null terminator of FQDN (root TLD)
+        0, ns_t_aaaa,  // QTYPE
+        0, ns_c_in  // QCLASS
+    };
+    const int qlen = std::size(query);
+
+    // At validation time, we only know the netId, so we have to guess/compute the
+    // corresponding socket mark.
+    Fwmark fwmark;
+    fwmark.permission = PERMISSION_SYSTEM;
+    fwmark.explicitlySelected = true;
+    fwmark.protectedFromVpn = true;
+    fwmark.netId = netid;
+    unsigned mark = fwmark.intValue;
+    DnsTlsSocketFactory factory;
+    DnsTlsTransport transport(server, mark, &factory);
+    auto r = transport.query(Slice(query, qlen)).get();
+    if (r.code != Response::success) {
+        ALOGV("query failed");
+        return false;
+    }
+
+    const std::vector<uint8_t>& recvbuf = r.response;
+    if (recvbuf.size() < NS_HFIXEDSZ) {
+        ALOGW("short response: %zu", recvbuf.size());
+        return false;
+    }
+    const int qdcount = (recvbuf[4] << 8) | recvbuf[5];
+    if (qdcount != 1) {
+        ALOGW("reply query count != 1: %d", qdcount);
+        return false;
+    }
+
+    const int ancount = (recvbuf[6] << 8) | recvbuf[7];
+    ALOGV("%u answer count: %d", netid, ancount);
+
+    // TODO: Further validate the response contents (check for valid AAAA record, ...).
+    // Note that currently, integration tests rely on this function accepting a
+    // response with zero records.
+#if 0
+    for (size_t i = 0; i < recvbuf.size(); i++) {
+        ALOGD("recvbuf[%zu] = %d %c", i, recvbuf[i], recvbuf[i]);
+    }
+#endif
+    return true;
+}
+
+} // end of namespace net
+} // end of namespace android
diff --git a/resolv/README.md b/resolv/README.md
new file mode 100644
index 0000000..8fe4c89
--- /dev/null
+++ b/resolv/README.md
@@ -0,0 +1,134 @@
+# DNS-over-TLS query forwarder design
+
+## Overview
+
+The DNS-over-TLS query forwarder consists of five classes:
+ * `DnsTlsDispatcher`
+ * `DnsTlsTransport`
+ * `DnsTlsQueryMap`
+ * `DnsTlsSessionCache`
+ * `DnsTlsSocket`
+
+`DnsTlsDispatcher` is a singleton class whose `query` method is the DnsTls
+module's only public interface. `DnsTlsDispatcher` is just a table holding the
+`DnsTlsTransport` for each server (represented by a `DnsTlsServer` struct) and
+network. `DnsTlsDispatcher` also blocks each query thread, waiting on a
+`std::future` returned by `DnsTlsTransport` that represents the response.
+
+`DnsTlsTransport` sends each query over a `DnsTlsSocket`, opening a
+new one if necessary. It also has to listen for responses from the
+`DnsTlsSocket`, which happen on a different thread.
+`IDnsTlsSocketObserver` is an interface defining how `DnsTlsSocket` returns
+responses to `DnsTlsTransport`.
+
+`DnsTlsQueryMap` and `DnsTlsSessionCache` are helper classes owned by `DnsTlsTransport`.
+`DnsTlsQueryMap` handles ID renumbering and query-response pairing.
+`DnsTlsSessionCache` allows TLS session resumption.
+
+`DnsTlsSocket` interleaves all queries onto a single socket, and reports all
+responses to `DnsTlsTransport` (through the `IDnsTlsSocketObserver` interface). It doesn't
+know anything about which queries correspond to which responses, and does not retain
+state to indicate whether there is an outstanding query.
+
+## Threading
+
+### Overall patterns
+
+For clarity, each of the five classes in this design is thread-safe and holds one lock.
+Classes that spawn a helper thread call `thread::join()` in their destructor to ensure
+that it is cleaned up appropriately.
+
+All the classes here make full use of Clang thread annotations (and also null-pointer
+annotations) to minimize the likelihood of a latent threading bug. The unit tests are
+also heavily threaded to exercise this functionality.
+
+This code creates O(1) threads per socket, and does not create a new thread for each
+query or response. However, DnsProxyListener does create a thread for each query.
+
+### Threading in `DnsTlsSocket`
+
+`DnsTlsSocket` can receive queries on any thread, and send them over a
+"reliable datagram pipe" (`socketpair()` in `SOCK_SEQPACKET` mode).
+The query method writes a struct (containing a pointer to the query) to the pipe
+from its thread, and the loop thread (which owns the SSL socket)
+reads off the other end of the pipe. The pipe doesn't actually have a queue "inside";
+instead, any queueing happens by blocking the query thread until the
+socket thread can read the datagram off the other end.
+
+We need to pass messages between threads using a pipe, and not a condition variable
+or a thread-safe queue, because the socket thread has to be blocked
+in `poll()` waiting for data from the server, but also has to be woken
+up on inputs from the query threads. Therefore, inputs from the query
+threads have to arrive on a socket, so that `poll()` can listen for them.
+(Reads and writes must share that single socket thread because [you can't use different
+threads to read and write in OpenSSL](https://www.openssl.org/blog/blog/2017/02/21/threads/).)
+
+## ID renumbering
+
+`DnsTlsDispatcher` accepts queries that have colliding ID numbers and still sends them on
+a single socket. To avoid confusion at the server, `DnsTlsQueryMap` assigns each
+query a new ID for transmission, records the mapping from input IDs to sent IDs, and
+applies the inverse mapping to responses before returning them to the caller.
+
+`DnsTlsQueryMap` assigns each new query the ID number one greater than the largest
+ID number of an outstanding query. This means that ID numbers are initially sequential
+and usually small. If the largest possible ID number is already in use,
+`DnsTlsQueryMap` will scan the ID space to find an available ID, or fail the query
+if there are no available IDs. Queries will not block waiting for an ID number to
+become available.
+
+## Time constants
+
+`DnsTlsSocket` imposes a 20-second inactivity timeout. A socket that has been idle for
+20 seconds will be closed. This sets the limit of tolerance for slow replies,
+which could happen as a result of malfunctioning authoritative DNS servers.
+If there are any pending queries, `DnsTlsTransport` will retry them.
+
+`DnsTlsQueryMap` imposes a retry limit of 3. `DnsTlsTransport` will retry the query up
+to 3 times before reporting failure to `DnsTlsDispatcher`.
+This limit helps to ensure proper functioning in the case of a recursive resolver that
+is malfunctioning or is flooded with requests that are stalled due to malfunctioning
+authoritative servers.
+
+`DnsTlsDispatcher` maintains a 5-minute timeout. Any `DnsTlsTransport` that has had no
+outstanding queries for 5 minutes will be destroyed at the next query on a different
+transport.
+This sets the limit on how long session tickets will be preserved during idle periods,
+because each `DnsTlsTransport` owns a `DnsTlsSessionCache`. Imposing this timeout
+increases latency on the first query after an idle period, but also helps to avoid
+unbounded memory usage.
+
+`DnsTlsSessionCache` sets a limit of 5 sessions in each cache, expiring the oldest one
+when the limit is reached. However, because the client code does not currently
+reuse sessions more than once, it should not be possible to hit this limit.
+
+## Testing
+
+Unit tests are in `dns_tls_test.cpp`. They cover all the classes except
+`DnsTlsSocket` (which requires `CAP_NET_ADMIN` because it uses `setsockopt(SO_MARK)`) and
+`DnsTlsSessionCache` (which requires integration with libssl). These classes are
+exercised by the integration tests in `../tests/resolv_test.cpp`.
+
+### Dependency Injection
+
+For unit testing, we would like to be able to mock out `DnsTlsSocket`. This is
+particularly required for unit testing of `DnsTlsDispatcher` and `DnsTlsTransport`.
+To make these unit tests possible, this code uses a dependency injection pattern:
+`DnsTlsSocket` is produced by a `DnsTlsSocketFactory`, and both of these implement a
+defined interface (`IDnsTlsSocket` and `IDnsTlsSocketFactory`, respectively).
+
+`DnsTlsDispatcher`'s constructor takes an `IDnsTlsSocketFactory`,
+which in production is a `DnsTlsSocketFactory`. However, in unit tests, we can
+substitute a test factory that returns a fake socket, so that the unit tests can
+run without actually connecting over TLS to a test server. (The integration tests
+do actual TLS.)
+
+## Logging
+
+This code uses `ALOGV` throughout for low-priority logging, and does not use
+`ALOGD`. `ALOGV` is disabled by default, unless activated by `#define LOG_NDEBUG 0`.
+(`ALOGD` is not disabled by default, requiring extra measures to avoid spamming the
+system log in production builds.)
+
+## Reference
+ * [BoringSSL API docs](https://commondatastorage.googleapis.com/chromium-boringssl-docs/headers.html)
diff --git a/resolv/dns_tls_test.cpp b/resolv/dns_tls_test.cpp
index cd0a2a8..4597389 100644
--- a/resolv/dns_tls_test.cpp
+++ b/resolv/dns_tls_test.cpp
@@ -19,15 +19,15 @@
#include <gtest/gtest.h>
-#include "dns/DnsTlsDispatcher.h"
-#include "dns/DnsTlsQueryMap.h"
-#include "dns/DnsTlsServer.h"
-#include "dns/DnsTlsSessionCache.h"
-#include "dns/DnsTlsSocket.h"
-#include "dns/DnsTlsTransport.h"
-#include "dns/IDnsTlsSocket.h"
-#include "dns/IDnsTlsSocketFactory.h"
-#include "dns/IDnsTlsSocketObserver.h"
+#include "netd_resolv/DnsTlsDispatcher.h"
+#include "netd_resolv/DnsTlsQueryMap.h"
+#include "netd_resolv/DnsTlsServer.h"
+#include "netd_resolv/DnsTlsSessionCache.h"
+#include "netd_resolv/DnsTlsSocket.h"
+#include "netd_resolv/DnsTlsTransport.h"
+#include "netd_resolv/IDnsTlsSocket.h"
+#include "netd_resolv/IDnsTlsSocketFactory.h"
+#include "netd_resolv/IDnsTlsSocketObserver.h"
#include <chrono>
#include <arpa/inet.h>
diff --git a/resolv/include/netd_resolv/DnsTlsDispatcher.h b/resolv/include/netd_resolv/DnsTlsDispatcher.h
new file mode 100644
index 0000000..0bb19f2
--- /dev/null
+++ b/resolv/include/netd_resolv/DnsTlsDispatcher.h
@@ -0,0 +1,112 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef _DNS_DNSTLSDISPATCHER_H
+#define _DNS_DNSTLSDISPATCHER_H
+
+#include <list>
+#include <map>
+#include <memory>
+#include <mutex>
+
+#include <android-base/thread_annotations.h>
+
+#include <netdutils/Slice.h>
+
+#include "DnsTlsServer.h"
+#include "DnsTlsTransport.h"
+#include "IDnsTlsSocketFactory.h"
+#include "params.h"
+
+namespace android {
+namespace net {
+
+using netdutils::Slice;
+
+// Manages the collection of active DnsTlsTransports, keyed by (mark, server); it is used as a
+// singleton in practice. Each query is dispatched to an existing or newly built transport.
+class LIBNETD_RESOLV_TLS_EXPORT DnsTlsDispatcher {
+public:
+ // Default constructor, for production use.
+ DnsTlsDispatcher();
+
+ // Constructor with dependency injection for testing: |factory| supplies the sockets.
+ explicit DnsTlsDispatcher(std::unique_ptr<IDnsTlsSocketFactory> factory) :
+ mFactory(std::move(factory)) {}
+
+ // Enqueues |query| for resolution via the given |tlsServers| on the
+ // network indicated by |mark|; writes the response into |ans|, and stores
+ // the count of bytes written in |resplen|. Returns a success or error code.
+ // The order in which servers from |tlsServers| are queried may not be the
+ // order passed in by the caller.
+ DnsTlsTransport::Response query(const std::list<DnsTlsServer> &tlsServers, unsigned mark,
+ const Slice query, const Slice ans, int * _Nonnull resplen);
+
+ // Sends |query| to |server| on the network indicated by |mark|, writes the
+ // response into |ans|, and stores the number of bytes written in |resplen|.
+ // Returns a success or error code.
+ DnsTlsTransport::Response query(const DnsTlsServer& server, unsigned mark,
+ const Slice query, const Slice ans, int * _Nonnull resplen);
+
+private:
+ // This lock is static so that it can be used to annotate the Transport struct.
+ // DnsTlsDispatcher is a singleton in practice, so making this static does not change
+ // the locking behavior.
+ static std::mutex sLock;
+
+ // Key = <mark, server>
+ typedef std::pair<unsigned, const DnsTlsServer> Key;
+
+ // Transport is a thin wrapper around DnsTlsTransport, adding reference counting and
+ // usage monitoring so we can expire idle sessions from the cache.
+ struct Transport {
+ Transport(const DnsTlsServer& server, unsigned mark,
+ IDnsTlsSocketFactory* _Nonnull factory) :
+ transport(server, mark, factory) {}
+ // DnsTlsTransport is thread-safe, so it doesn't need to be guarded.
+ DnsTlsTransport transport;
+ // This use counter and timestamp are used to ensure that only idle sessions are
+ // destroyed.
+ int useCount GUARDED_BY(sLock) = 0;
+ // lastUsed is only guaranteed to be meaningful after useCount is decremented to zero.
+ std::chrono::time_point<std::chrono::steady_clock> lastUsed GUARDED_BY(sLock);
+ };
+
+ // Cache of reusable DnsTlsTransports. Transports stay in cache as long as
+ // they are in use and for a few minutes after.
+ // The key is a (mark, server) pair. The mark is first for lexicographic comparison speed.
+ std::map<Key, std::unique_ptr<Transport>> mStore GUARDED_BY(sLock);
+
+ // The last time we did a cleanup. For efficiency, we only perform a cleanup once every
+ // few minutes.
+ std::chrono::time_point<std::chrono::steady_clock> mLastCleanup GUARDED_BY(sLock);
+
+ // Drop any cache entries whose useCount is zero and which have not been used recently.
+ // This function performs a linear scan of mStore.
+ void cleanup(std::chrono::time_point<std::chrono::steady_clock> now) REQUIRES(sLock);
+
+ // Return a sorted list of DnsTlsServers in preference order.
+ std::list<DnsTlsServer> getOrderedServerList(
+ const std::list<DnsTlsServer> &tlsServers, unsigned mark) const;
+
+ // Trivial factory for DnsTlsSockets. Dependency injection is only used for testing.
+ std::unique_ptr<IDnsTlsSocketFactory> mFactory;
+};
+
+} // end of namespace net
+} // end of namespace android
+
+#endif // _DNS_DNSTLSDISPATCHER_H
diff --git a/resolv/include/netd_resolv/DnsTlsQueryMap.h b/resolv/include/netd_resolv/DnsTlsQueryMap.h
new file mode 100644
index 0000000..4c8010c
--- /dev/null
+++ b/resolv/include/netd_resolv/DnsTlsQueryMap.h
@@ -0,0 +1,109 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef _DNS_DNSTLSQUERYMAP_H
+#define _DNS_DNSTLSQUERYMAP_H
+
+#include <future>
+#include <map>
+#include <mutex>
+#include <vector>
+
+#include <android-base/thread_annotations.h>
+#include <netdutils/Slice.h>
+
+#include "DnsTlsServer.h"
+#include "params.h"
+
+namespace android {
+namespace net {
+
+using netdutils::Slice;
+
+// Keeps track of queries and responses. This class matches responses with queries.
+// All methods are thread-safe and non-blocking.
+class LIBNETD_RESOLV_TLS_EXPORT DnsTlsQueryMap {
+public:
+ struct Query {
+ // The new ID number assigned to this query.
+ uint16_t newId;
+ // A query that has been passed to recordQuery(), with its original ID number.
+ const Slice query;
+ };
+
+ typedef DnsTlsServer::Response Response;
+ typedef DnsTlsServer::Result Result;
+
+ struct QueryFuture {
+ QueryFuture(Query query, std::future<Result> result) :
+ query(query), result(std::move(result)) {}
+ Query query;
+ // A future which will resolve to the result of this query.
+ std::future<Result> result;
+ };
+
+ // Returns an object containing everything needed to complete processing of
+ // this query, or null if the query could not be recorded.
+ std::unique_ptr<QueryFuture> recordQuery(const Slice query);
+
+ // Process a response, which carries the new ID of its query. If the response
+ // is not recognized as matching any query, it will be ignored.
+ void onResponse(std::vector<uint8_t> response);
+
+ // Clear all map contents. This causes all pending queries to resolve with failure.
+ void clear();
+
+ // Get all pending queries. This returns a shallow copy, mostly for thread-safety.
+ std::vector<Query> getAll();
+
+ // Mark a query as having been tried once more. If the query hits the retry limit, it
+ // will be expired at the next call to cleanup.
+ void markTried(uint16_t newId);
+ void cleanup(); // Expires queries that have hit the retry limit (kMaxTries).
+
+ // Returns true if there are no pending queries.
+ bool empty();
+
+private:
+ std::mutex mLock; // Guards mQueries.
+
+ struct QueryPromise {
+ QueryPromise(Query query) : query(query) {} // NOTE(review): implicit; consider explicit.
+ Query query;
+ // Number of times the query has been tried. Limited to kMaxTries.
+ int tries = 0;
+ // A promise whose future is returned by recordQuery(). It is fulfilled by
+ // onResponse(), or with an error code by expire().
+ std::promise<Result> result;
+ };
+
+ // The maximum number of times we will send a query before abandoning it.
+ static constexpr int kMaxTries = 3;
+
+ // Outstanding queries by newId.
+ std::map<uint16_t, QueryPromise> mQueries GUARDED_BY(mLock);
+
+ // Get a "newId" number that is not currently in use. Returns -1 if there are none.
+ int32_t getFreeId() REQUIRES(mLock);
+
+ // Fulfill the result with an error code.
+ static void expire(QueryPromise* _Nonnull p);
+};
+
+} // end of namespace net
+} // end of namespace android
+
+#endif // _DNS_DNSTLSQUERYMAP_H
diff --git a/resolv/include/netd_resolv/DnsTlsServer.h b/resolv/include/netd_resolv/DnsTlsServer.h
new file mode 100644
index 0000000..752dc5f
--- /dev/null
+++ b/resolv/include/netd_resolv/DnsTlsServer.h
@@ -0,0 +1,78 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef _DNS_DNSTLSSERVER_H
+#define _DNS_DNSTLSSERVER_H
+
+#include <set>
+#include <string>
+#include <vector>
+
+#include <netinet/in.h>
+
+#include "params.h"
+
+namespace android {
+namespace net {
+
+// DnsTlsServer represents a recursive resolver that supports, or may support, a
+// secure protocol.
+struct LIBNETD_RESOLV_TLS_EXPORT DnsTlsServer {
+ // Default constructor: every field takes its in-class default.
+ DnsTlsServer() {}
+
+ // Intentionally implicit: allows sockaddr_storage to be promoted to DnsTlsServer.
+ DnsTlsServer(const sockaddr_storage& ss) : ss(ss) {}
+ // Status code of a query result.
+ enum class Response : uint8_t { success, network_error, limit_error, internal_error };
+ // Result of one query: a status code plus the raw response bytes.
+ struct Result {
+ Response code;
+ std::vector<uint8_t> response;
+ };
+
+ // The server location, including IP and port.
+ sockaddr_storage ss = {};
+
+ // A set of SHA256 public key fingerprints. If this set is nonempty, the server
+ // must present a self-consistent certificate chain that contains a certificate
+ // whose public key matches one of these fingerprints. Otherwise, the client will
+ // terminate the connection.
+ std::set<std::vector<uint8_t>> fingerprints;
+
+ // The server's hostname. If this string is nonempty, the server must present a
+ // certificate that indicates this name and has a valid chain to a trusted root CA.
+ std::string name;
+
+ // Placeholder. More protocols might be defined in the future.
+ int protocol = IPPROTO_TCP;
+
+ // Exact comparison of DnsTlsServer objects
+ bool operator <(const DnsTlsServer& other) const;
+ bool operator ==(const DnsTlsServer& other) const;
+ // NOTE(review): defined elsewhere; presumably true when |name| or |fingerprints| is set.
+ bool wasExplicitlyConfigured() const;
+};
+
+// Compares DnsTlsServers by IP address only; ports, names, and fingerprints are ignored.
+struct LIBNETD_RESOLV_TLS_EXPORT AddressComparator {
+ bool operator() (const DnsTlsServer& x, const DnsTlsServer& y) const; // Presumably less-than.
+};
+
+} // namespace net
+} // namespace android
+
+#endif // _DNS_DNSTLSSERVER_H
diff --git a/resolv/include/netd_resolv/DnsTlsSessionCache.h b/resolv/include/netd_resolv/DnsTlsSessionCache.h
new file mode 100644
index 0000000..8d0fc1d
--- /dev/null
+++ b/resolv/include/netd_resolv/DnsTlsSessionCache.h
@@ -0,0 +1,63 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef _DNS_DNSTLSSESSIONCACHE_H
+#define _DNS_DNSTLSSESSIONCACHE_H
+
+#include <deque>
+#include <mutex>
+
+#include <openssl/ssl.h>
+
+#include <android-base/thread_annotations.h>
+#include <android-base/unique_fd.h>
+
+#include "params.h"
+
+namespace android {
+namespace net {
+
+// Cache of recently seen SSL_SESSIONs, used for TLS session resumption (session tickets).
+// This class is thread-safe.
+class DnsTlsSessionCache {
+public:
+ // Prepare SSL objects to use this session cache. These methods must be called
+ // before making use of either object.
+ void prepareSslContext(SSL_CTX* _Nonnull ssl_ctx);
+ bool prepareSsl(SSL* _Nonnull ssl); // NOTE(review): presumably false on failure; confirm.
+
+ // Get the most recently discovered session. For TLS 1.3 compatibility and
+ // maximum privacy, each session will only be returned once, so the caller
+ // gains ownership of the session. (Here and throughout,
+ // bssl::UniquePtr<SSL_SESSION> is actually serving as a reference counted
+ // pointer.)
+ bssl::UniquePtr<SSL_SESSION> getSession() EXCLUDES(mLock);
+
+private:
+ static constexpr size_t kMaxSize = 5; // Maximum number of sessions held in mSessions.
+ static int newSessionCallback(SSL* _Nullable ssl, SSL_SESSION* _Nullable session); // Presumably an SSL_CTX new-session hook.
+
+ std::mutex mLock; // Guards mSessions.
+ void recordSession(SSL_SESSION* _Nullable session) EXCLUDES(mLock); // Presumably appends to mSessions.
+
+ // Queue of sessions, from least recently added to most recently.
+ std::deque<bssl::UniquePtr<SSL_SESSION>> mSessions GUARDED_BY(mLock);
+};
+
+} // end of namespace net
+} // end of namespace android
+
+#endif // _DNS_DNSTLSSESSIONCACHE_H
diff --git a/resolv/include/netd_resolv/DnsTlsSocket.h b/resolv/include/netd_resolv/DnsTlsSocket.h
new file mode 100644
index 0000000..7190dd2
--- /dev/null
+++ b/resolv/include/netd_resolv/DnsTlsSocket.h
@@ -0,0 +1,129 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef _DNS_DNSTLSSOCKET_H
+#define _DNS_DNSTLSSOCKET_H
+
+#include <future>
+#include <mutex>
+#include <openssl/ssl.h>
+
+#include <android-base/thread_annotations.h>
+#include <android-base/unique_fd.h>
+#include <netdutils/Slice.h>
+#include <netdutils/Status.h>
+
+#include "DnsTlsServer.h"
+#include "IDnsTlsSocket.h"
+#include "params.h"
+
+namespace android {
+namespace net {
+
+class IDnsTlsSocketObserver;
+class DnsTlsSessionCache;
+
+using netdutils::Slice;
+
+// A class for managing a TLS socket that sends and receives messages in
+// [length][value] format, with a 2-byte length (i.e. DNS-over-TCP format).
+// This class is not aware of query-response pairing or anything else about DNS.
+// For the observer:
+// This class is not re-entrant: the observer is not permitted to wait for a call to query()
+// or the destructor in a callback. Doing so will result in deadlocks.
+// This class may call the observer at any time after initialize(), until the destructor
+// returns (but not after).
+class LIBNETD_RESOLV_TLS_EXPORT DnsTlsSocket : public IDnsTlsSocket {
+public:
+ DnsTlsSocket(const DnsTlsServer& server, unsigned mark,
+ IDnsTlsSocketObserver* _Nonnull observer,
+ DnsTlsSessionCache* _Nonnull cache) :
+ mMark(mark), mServer(server), mObserver(observer), mCache(cache) {}
+ ~DnsTlsSocket(); // Once this returns, the observer is no longer called.
+
+ // Creates the SSL context for this session and connects. Returns false on failure.
+ // This method should be called after construction and before use of a DnsTlsSocket.
+ // Only call this method once per DnsTlsSocket.
+ bool initialize() EXCLUDES(mLock);
+
+ // Send a query on the provided SSL socket. |query| contains
+ // the body of a query, not including the ID header. This function will typically return before
+ // the query is actually sent. If this function fails, DnsTlsSocketObserver will be
+ // notified that the socket is closed.
+ // Note that success here indicates successful sending, not receipt of a response.
+ // Thread-safe.
+ bool query(uint16_t id, const Slice query) override;
+
+private:
+ // Lock to be held by the SSL event loop thread. This is not normally in contention.
+ std::mutex mLock;
+
+ // Forwards queries and receives responses. Blocks until the idle timeout.
+ void loop() EXCLUDES(mLock);
+ std::unique_ptr<std::thread> mLoopThread GUARDED_BY(mLock);
+
+ // On success, sets mSslFd to a socket connected to mServer.ss (the connection will
+ // likely still be in progress if mServer.protocol is IPPROTO_TCP).
+ // On error, returns a non-ok Status (presumably carrying the errno).
+ netdutils::Status tcpConnect() REQUIRES(mLock);
+
+ // Connect an SSL session on the provided socket. If connection fails, closing the
+ // socket remains the caller's responsibility.
+ bssl::UniquePtr<SSL> sslConnect(int fd) REQUIRES(mLock);
+
+ // Disconnect the SSL session and close the socket.
+ void sslDisconnect() REQUIRES(mLock);
+
+ // Writes a buffer to the socket.
+ bool sslWrite(const Slice buffer) REQUIRES(mLock);
+
+ // Reads exactly the specified number of bytes from the socket, or fails.
+ // Returns SSL_ERROR_NONE on success.
+ // If |wait| is true, then this function always blocks. Otherwise, it
+ // will return SSL_ERROR_WANT_READ if there is no data from the server to read.
+ int sslRead(const Slice buffer, bool wait) REQUIRES(mLock);
+
+ struct Query {
+ uint16_t id;
+ Slice query;
+ };
+ // Loop-thread helpers: write one query to, or read one response from, the TLS stream.
+ bool sendQuery(const Query& q) REQUIRES(mLock);
+ bool readResponse() REQUIRES(mLock);
+
+ // SOCK_SEQPACKET socket pair used for sending queries from myriad query
+ // threads to the SSL thread. EOF indicates a close request.
+ // We have to use a socket pair (i.e. a pipe) because the SSL thread needs
+ // to wait in poll() for input from either a remote server or a query thread.
+ base::unique_fd mIpcInFd;
+ base::unique_fd mIpcOutFd GUARDED_BY(mLock);
+
+ // SSL Socket fields.
+ bssl::UniquePtr<SSL_CTX> mSslCtx GUARDED_BY(mLock);
+ base::unique_fd mSslFd GUARDED_BY(mLock);
+ bssl::UniquePtr<SSL> mSsl GUARDED_BY(mLock);
+ static constexpr std::chrono::seconds kIdleTimeout = std::chrono::seconds(20); // Idle connections close after this.
+
+ const unsigned mMark; // Socket mark
+ const DnsTlsServer mServer;
+ IDnsTlsSocketObserver* _Nonnull const mObserver;
+ DnsTlsSessionCache* _Nonnull const mCache;
+};
+
+} // end of namespace net
+} // end of namespace android
+
+#endif // _DNS_DNSTLSSOCKET_H
diff --git a/resolv/include/netd_resolv/DnsTlsSocketFactory.h b/resolv/include/netd_resolv/DnsTlsSocketFactory.h
new file mode 100644
index 0000000..1c58535
--- /dev/null
+++ b/resolv/include/netd_resolv/DnsTlsSocketFactory.h
@@ -0,0 +1,54 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef _DNS_DNSTLSSOCKETFACTORY_H
+#define _DNS_DNSTLSSOCKETFACTORY_H
+
+#include <memory>
+
+#include "DnsTlsSocket.h"
+#include "IDnsTlsSocketFactory.h"
+
+namespace android {
+namespace net {
+
+class IDnsTlsSocketObserver;
+class DnsTlsSessionCache;
+struct DnsTlsServer;
+
+// Trivial RAII factory for DnsTlsSocket. This is owned by DnsTlsDispatcher.
+class DnsTlsSocketFactory : public IDnsTlsSocketFactory {
+public:
+    // Creates and initializes a DnsTlsSocket for |server| with socket mark |mark|, reporting
+    // to |observer| and using |cache|. The caller owns the returned socket.
+    // Returns nullptr if DnsTlsSocket::initialize() fails.
+    std::unique_ptr<IDnsTlsSocket> createDnsTlsSocket(const DnsTlsServer& server, unsigned mark,
+                                                      IDnsTlsSocketObserver* _Nonnull observer,
+                                                      DnsTlsSessionCache* _Nonnull cache) override {
+        auto socket = std::make_unique<DnsTlsSocket>(server, mark, observer, cache);
+        if (!socket->initialize()) {
+            return nullptr;
+        }
+        // No std::move() needed: returning a local implicitly moves it, including into the
+        // converting std::unique_ptr<IDnsTlsSocket> return type (C++14, CWG 1579).
+        return socket;
+    }
+};
+
+} // end of namespace net
+} // end of namespace android
+
+#endif // _DNS_DNSTLSSOCKETFACTORY_H
diff --git a/resolv/include/netd_resolv/DnsTlsTransport.h b/resolv/include/netd_resolv/DnsTlsTransport.h
new file mode 100644
index 0000000..0d314ad
--- /dev/null
+++ b/resolv/include/netd_resolv/DnsTlsTransport.h
@@ -0,0 +1,101 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef _DNS_DNSTLSTRANSPORT_H
+#define _DNS_DNSTLSTRANSPORT_H
+
+#include <cstdint>
+#include <future>
+#include <map>
+#include <memory>
+#include <mutex>
+#include <thread>
+#include <vector>
+
+#include <android-base/thread_annotations.h>
+#include <android-base/unique_fd.h>
+
+#include "DnsTlsQueryMap.h"
+#include "DnsTlsServer.h"
+#include "DnsTlsSessionCache.h"
+#include "IDnsTlsSocket.h"
+#include "IDnsTlsSocketObserver.h"
+#include "params.h"
+
+#include <netdutils/Slice.h>
+
+namespace android {
+namespace net {
+
+class IDnsTlsSocketFactory;
+
+// Manages at most one DnsTlsSocket at a time. This class handles socket lifetime issues,
+// such as reopening the socket and reissuing pending queries.
+class LIBNETD_RESOLV_TLS_EXPORT DnsTlsTransport : public IDnsTlsSocketObserver {
+public:
+    // |mark| is the socket mark for this transport's sockets, which are created through
+    // |factory|. |factory| is held as a non-owning pointer, so it must outlive this object.
+    DnsTlsTransport(const DnsTlsServer& server, unsigned mark,
+                    IDnsTlsSocketFactory* _Nonnull factory)
+        : mMark(mark), mServer(server), mFactory(factory) {}
+    ~DnsTlsTransport() override;
+
+    typedef DnsTlsServer::Response Response;
+    typedef DnsTlsServer::Result Result;
+
+    // Given a |query|, this method sends it to the server and returns the result asynchronously.
+    std::future<Result> query(const netdutils::Slice query) EXCLUDES(mLock);
+
+    // Check that a given TLS server is fully working on the specified netid, and has the
+    // provided SHA-256 fingerprint (if nonempty). This function is used in ResolverController
+    // to ensure that we don't enable DNS over TLS on networks where it doesn't actually work.
+    static bool validate(const DnsTlsServer& server, unsigned netid);
+
+    // Implement IDnsTlsSocketObserver
+    void onResponse(std::vector<uint8_t> response) override;
+    void onClosed() override EXCLUDES(mLock);
+
+private:
+    // Protects every member annotated GUARDED_BY(mLock) below.
+    std::mutex mLock;
+
+    DnsTlsSessionCache mCache;
+    DnsTlsQueryMap mQueries;
+
+    const unsigned mMark;  // Socket mark
+    const DnsTlsServer mServer;
+    IDnsTlsSocketFactory* _Nonnull const mFactory;
+
+    void doConnect() REQUIRES(mLock);
+
+    // doReconnect is used by onClosed. It runs on the reconnect thread.
+    void doReconnect() EXCLUDES(mLock);
+    std::unique_ptr<std::thread> mReconnectThread GUARDED_BY(mLock);
+
+    // Used to prevent onClosed from starting a reconnect during the destructor.
+    bool mClosing GUARDED_BY(mLock) = false;
+
+    // Sending queries on the socket is thread-safe, but construction/destruction is not.
+    std::unique_ptr<IDnsTlsSocket> mSocket GUARDED_BY(mLock);
+
+    // Send a query to the socket.
+    bool sendQuery(const DnsTlsQueryMap::Query q) REQUIRES(mLock);
+};
+
+} // end of namespace net
+} // end of namespace android
+
+#endif // _DNS_DNSTLSTRANSPORT_H
diff --git a/resolv/include/netd_resolv/IDnsTlsSocket.h b/resolv/include/netd_resolv/IDnsTlsSocket.h
new file mode 100644
index 0000000..4f21bbd
--- /dev/null
+++ b/resolv/include/netd_resolv/IDnsTlsSocket.h
@@ -0,0 +1,50 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef _DNS_IDNSTLSSOCKET_H
+#define _DNS_IDNSTLSSOCKET_H
+
+#include <cstddef>
+#include <cstdint>
+
+#include <netdutils/Slice.h>
+
+namespace android {
+namespace net {
+
+class IDnsTlsSocketObserver;
+class DnsTlsSessionCache;
+
+// A class for managing a TLS socket that sends and receives messages in
+// [length][value] format, with a 2-byte length (i.e. DNS-over-TCP format).
+// This interface is not aware of query-response pairing or anything else about DNS.
+class IDnsTlsSocket {
+public:
+    virtual ~IDnsTlsSocket() = default;
+
+    // Send a query on this socket. |id| is the query ID, and |query| contains the body of
+    // the query, not including the ID bytes. This function will typically return before
+    // the query is actually sent. If this function fails, the observer will be
+    // notified that the socket is closed.
+    // Note that a true return value means the query was accepted for sending; it does not
+    // indicate that the query was delivered or that a response was received.
+    virtual bool query(uint16_t id, const netdutils::Slice query) = 0;
+};
+
+} // end of namespace net
+} // end of namespace android
+
+#endif // _DNS_IDNSTLSSOCKET_H
diff --git a/resolv/include/netd_resolv/IDnsTlsSocketFactory.h b/resolv/include/netd_resolv/IDnsTlsSocketFactory.h
new file mode 100644
index 0000000..52164b6
--- /dev/null
+++ b/resolv/include/netd_resolv/IDnsTlsSocketFactory.h
@@ -0,0 +1,50 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef _DNS_IDNSTLSSOCKETFACTORY_H
+#define _DNS_IDNSTLSSOCKETFACTORY_H
+
+#include <memory>
+
+#include "IDnsTlsSocket.h"
+
+namespace android {
+namespace net {
+
+class IDnsTlsSocketObserver;
+class DnsTlsSessionCache;
+struct DnsTlsServer;
+
+// Dependency injection interface for DnsTlsSocketFactory.
+// This pattern allows mocking of DnsTlsSocket for tests.
+class IDnsTlsSocketFactory {
+public:
+    virtual ~IDnsTlsSocketFactory() = default;
+
+    // Creates a socket for |server| that uses |mark| as its socket mark, reports to
+    // |observer| and uses |cache|. Implementations may return nullptr on failure
+    // (DnsTlsSocketFactory does so when socket initialization fails).
+    virtual std::unique_ptr<IDnsTlsSocket> createDnsTlsSocket(
+            const DnsTlsServer& server,
+            unsigned mark,
+            IDnsTlsSocketObserver* _Nonnull observer,
+            DnsTlsSessionCache* _Nonnull cache) = 0;
+};
+
+} // end of namespace net
+} // end of namespace android
+
+#endif // _DNS_IDNSTLSSOCKETFACTORY_H
diff --git a/resolv/include/netd_resolv/IDnsTlsSocketObserver.h b/resolv/include/netd_resolv/IDnsTlsSocketObserver.h
new file mode 100644
index 0000000..7ae364c
--- /dev/null
+++ b/resolv/include/netd_resolv/IDnsTlsSocketObserver.h
@@ -0,0 +1,44 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef _DNS_IDNSTLSSOCKETOBSERVER_H
+#define _DNS_IDNSTLSSOCKETOBSERVER_H
+
+#include <cstdint>
+#include <vector>
+
+namespace android {
+namespace net {
+
+// Interface to listen for DNS query responses on a socket, and to be notified
+// when the socket is closed by the remote peer. This is only implemented by
+// DnsTlsTransport, but it is a separate interface for clarity and to avoid a
+// circular dependency with DnsTlsSocket.
+class IDnsTlsSocketObserver {
+public:
+    virtual ~IDnsTlsSocketObserver() = default;
+
+    // Called with the bytes of each response received on the socket.
+    virtual void onResponse(std::vector<uint8_t> response) = 0;
+
+    // Called when the socket is closed (by the remote peer, or after a failed query).
+    virtual void onClosed() = 0;
+};
+
+} // end of namespace net
+} // end of namespace android
+
+#endif // _DNS_IDNSTLSSOCKETOBSERVER_H
diff --git a/resolv/include/netd_resolv/params.h b/resolv/include/netd_resolv/params.h
index 60b9b54..c3f3363 100644
--- a/resolv/include/netd_resolv/params.h
+++ b/resolv/include/netd_resolv/params.h
@@ -46,4 +46,7 @@
#define LIBNETD_RESOLV_PUBLIC extern "C" [[gnu::visibility("default")]]
+// TODO: Remove this macro after we move PrivateDnsConfiguration and qhook() into libnetd_resolv.
+#define LIBNETD_RESOLV_TLS_EXPORT [[gnu::visibility("default")]]  // DnsTls* classes used by netd
+
#endif // NETD_RESOLV_PARAMS_H
diff --git a/resolv/libnetd_resolv.map.txt b/resolv/libnetd_resolv.map.txt
index 63e0cb8..db81593 100644
--- a/resolv/libnetd_resolv.map.txt
+++ b/resolv/libnetd_resolv.map.txt
@@ -28,6 +28,53 @@
android_net_res_stats_get_usable_servers;
resolv_delete_cache_for_net;
resolv_set_nameservers_for_net;
+
+ # These symbol names were generated by running 'nm -D libnetd_resolv.so' on a build with
+ # commit aosp/801454 reverted. They are temporarily needed so that netd can access the DnsTls*
+ # classes until qhook() and PrivateDnsConfiguration are moved into the libnetd_resolv library.
+ # TODO: Remove them once the DnsTls* classes are hidden and netd no longer needs to access them.
+ _ZN7android3net12DnsTlsSocket10initializeEv;
+ _ZN7android3net12DnsTlsSocket10sslConnectEi;
+ _ZN7android3net12DnsTlsSocket10tcpConnectEv;
+ _ZN7android3net12DnsTlsSocket12readResponseEv;
+ _ZN7android3net12DnsTlsSocket13sslDisconnectEv;
+ _ZN7android3net12DnsTlsSocket4loopEv;
+ _ZN7android3net12DnsTlsSocket5queryEtNS_9netdutils5SliceE;
+ _ZN7android3net12DnsTlsSocket7sslReadENS_9netdutils5SliceEb;
+ _ZN7android3net12DnsTlsSocket8sslWriteENS_9netdutils5SliceE;
+ _ZN7android3net12DnsTlsSocket9sendQueryERKNS1_5QueryE;
+ _ZN7android3net12DnsTlsSocketD0Ev;
+ _ZN7android3net12DnsTlsSocketD1Ev;
+ _ZN7android3net12DnsTlsSocketD2Ev;
+ _ZN7android3net14DnsTlsQueryMap10onResponseENSt3__16vectorIhNS2_9allocatorIhEEEE;
+ _ZN7android3net14DnsTlsQueryMap11recordQueryENS_9netdutils5SliceE;
+ _ZN7android3net14DnsTlsQueryMap5clearEv;
+ _ZN7android3net14DnsTlsQueryMap5emptyEv;
+ _ZN7android3net14DnsTlsQueryMap6expireEPNS1_12QueryPromiseE;
+ _ZN7android3net14DnsTlsQueryMap6getAllEv;
+ _ZN7android3net14DnsTlsQueryMap7cleanupEv;
+ _ZN7android3net14DnsTlsQueryMap9getFreeIdEv;
+ _ZN7android3net14DnsTlsQueryMap9markTriedEt;
+ _ZN7android3net15DnsTlsTransport10onResponseENSt3__16vectorIhNS2_9allocatorIhEEEE;
+ _ZN7android3net15DnsTlsTransport11doReconnectEv;
+ _ZN7android3net15DnsTlsTransport5queryENS_9netdutils5SliceE;
+ _ZN7android3net15DnsTlsTransport8onClosedEv;
+ _ZN7android3net15DnsTlsTransport8validateERKNS0_12DnsTlsServerEj;
+ _ZN7android3net15DnsTlsTransport9doConnectEv;
+ _ZN7android3net15DnsTlsTransport9sendQueryENS0_14DnsTlsQueryMap5QueryE;
+ _ZN7android3net15DnsTlsTransportD0Ev;
+ _ZN7android3net15DnsTlsTransportD1Ev;
+ _ZN7android3net15DnsTlsTransportD2Ev;
+ _ZN7android3net16DnsTlsDispatcher5queryERKNS0_12DnsTlsServerEjNS_9netdutils5SliceES6_Pi;
+ _ZN7android3net16DnsTlsDispatcher5queryERKNSt3__14listINS0_12DnsTlsServerENS2_9allocatorIS4_EEEEjNS_9netdutils5SliceESB_Pi;
+ _ZN7android3net16DnsTlsDispatcherC1Ev;
+ _ZN7android3net16DnsTlsDispatcherC2Ev;
+ _ZNK7android3net12DnsTlsServer23wasExplicitlyConfiguredEv;
+ _ZNK7android3net12DnsTlsServereqERKS1_;
+ _ZNK7android3net12DnsTlsServerltERKS1_;
+ _ZNK7android3net16DnsTlsDispatcher20getOrderedServerListERKNSt3__14listINS0_12DnsTlsServerENS2_9allocatorIS4_EEEEj;
+ _ZNK7android3net17AddressComparatorclERKNS0_12DnsTlsServerES4_;
+ _ZTVN7android3net15DnsTlsTransportE;
local:
*;
};