Add support for session tickets

This change adds a session ticket cache to DnsTlsTransport,
which should reduce the startup time for new sockets by 1 RTT.
This change is complex because queries may be issued on any
thread, all threads must share the cache, and we need some way
of expiring old entries from the cache.

Test: Integration tests pass.  Wireshark shows tickets in use.
Bug: 63447621
Change-Id: I756119a9d82bcba0b309a2dd84414e2fb3d4956f
diff --git a/server/dns/DnsTlsTransport.cpp b/server/dns/DnsTlsTransport.cpp
index b369022..542b4a9 100644
--- a/server/dns/DnsTlsTransport.cpp
+++ b/server/dns/DnsTlsTransport.cpp
@@ -22,7 +22,6 @@
 #include <arpa/nameser.h>
 #include <errno.h>
 #include <openssl/err.h>
-#include <openssl/ssl.h>
 #include <stdlib.h>
 
 #define LOG_TAG "DnsTlsTransport"
@@ -143,6 +142,9 @@
 }  // namespace
 
 android::base::unique_fd DnsTlsTransport::makeConnectedSocket() const {
+    if (DBG) {
+        ALOGD("%u connecting TCP socket", mMark);
+    }
     android::base::unique_fd fd;
     int type = SOCK_NONBLOCK | SOCK_CLOEXEC;
     switch (mServer.protocol) {
@@ -168,6 +170,11 @@
         fd.reset();
     }
 
+    if (!setNonBlocking(fd, false)) {
+        ALOGE("Failed to disable nonblocking status on DNS-over-TLS fd");
+        fd.reset();
+    }
+
     return fd;
 }
 
@@ -231,28 +238,45 @@
     return make_tie(*this) == make_tie(other);
 }
 
-SSL* DnsTlsTransport::sslConnect(int fd) {
-    if (fd < 0) {
-        ALOGD("%u makeConnectedSocket() failed with: %s", mMark, strerror(errno));
+// Creates mSslCtx and turns on client session caching, without which the callbacks never fire.
+bool DnsTlsTransport::initialize() {
+    mSslCtx.reset(SSL_CTX_new(TLS_method()));
+    if (!mSslCtx) return false;  // Context allocation failed.
+    SSL_CTX_set_session_cache_mode(mSslCtx.get(), SSL_SESS_CACHE_CLIENT);
+    SSL_CTX_sess_set_new_cb(mSslCtx.get(), DnsTlsTransport::newSessionCallback);
+    SSL_CTX_sess_set_remove_cb(mSslCtx.get(), DnsTlsTransport::removeSessionCallback);
+    return true;
+}
+
+bssl::UniquePtr<SSL> DnsTlsTransport::sslConnect(int fd) {
+    // Check TLS context.
+    if (!mSslCtx) {
+        ALOGE("Internal error: context is null in ssl connect");
+        return nullptr;
+    }
+    if (!SSL_CTX_set_max_proto_version(mSslCtx.get(), TLS1_3_VERSION) ||
+        !SSL_CTX_set_min_proto_version(mSslCtx.get(), TLS1_2_VERSION)) {
+        ALOGE("failed to min/max TLS versions");
         return nullptr;
     }
 
-    // Set up TLS context.
-    bssl::UniquePtr<SSL_CTX> ssl_ctx(SSL_CTX_new(TLS_method()));
-    if (!SSL_CTX_set_max_proto_version(ssl_ctx.get(), TLS1_3_VERSION) ||
-        !SSL_CTX_set_min_proto_version(ssl_ctx.get(), TLS1_1_VERSION)) {
-        ALOGD("failed to min/max TLS versions");
-        return nullptr;
-    }
-
-    bssl::UniquePtr<SSL> ssl(SSL_new(ssl_ctx.get()));
+    bssl::UniquePtr<SSL> ssl(SSL_new(mSslCtx.get()));
     // This file descriptor is owned by a unique_fd, so don't let libssl close it.
     bssl::UniquePtr<BIO> bio(BIO_new_socket(fd, BIO_NOCLOSE));
     SSL_set_bio(ssl.get(), bio.get(), bio.get());
     bio.release();
 
-    if (!setNonBlocking(fd, false)) {
-        ALOGE("Failed to disable nonblocking status on DNS-over-TLS fd");
+    // Add this transport as the 0-index extra data for the socket.
+    // This is used by newSessionCallback.
+    if (SSL_set_ex_data(ssl.get(), 0, this) != 1) {
+        ALOGE("failed to associate SSL socket to transport");
+        return nullptr;
+    }
+
+    // Add this transport as the 0-index extra data for the context.
+    // This is used by removeSessionCallback.
+    if (SSL_CTX_set_ex_data(mSslCtx.get(), 0, this) != 1) {
+        ALOGE("failed to associate SSL context to transport");
         return nullptr;
     }
 
@@ -267,6 +291,20 @@
         SSL_set_verify(ssl.get(), SSL_VERIFY_PEER, nullptr);
     }
 
+    bssl::UniquePtr<SSL_SESSION> session;
+    {
+        std::lock_guard<std::mutex> guard(sLock);
+        if (!mSessions.empty()) {
+            session = std::move(mSessions.front());
+            mSessions.pop_front();
+        } else if (DBG) {
+            ALOGD("Starting without session ticket.");
+        }
+    }
+    if (session) {
+        SSL_set_session(ssl.get(), session.get());
+    }
+
     for (;;) {
         if (DBG) {
             ALOGD("%u Calling SSL_connect", mMark);
@@ -353,7 +391,66 @@
     if (DBG) {
         ALOGD("%u handshake complete", mMark);
     }
-    return ssl.release();
+
+    return ssl;
+}
+
+// static. Returns 1 after handing |session| to the owning transport's cache, otherwise 0.
+int DnsTlsTransport::newSessionCallback(SSL* ssl, SSL_SESSION* session) {
+    if (session == nullptr) {
+        return 0;
+    }
+    if (DBG) {
+        ALOGD("Recording session ticket");
+    }
+    // sslConnect() stores the transport at ex-data index 0 of every SSL it creates.
+    auto* const transport = static_cast<DnsTlsTransport*>(SSL_get_ex_data(ssl, 0));
+    if (transport != nullptr) {
+        transport->recordSession(session);
+        return 1;
+    }
+    ALOGE("null transport in new session callback");
+    return 0;
+}
+
+// static. Forwards a session removal to the transport stored at the context's ex-data index 0.
+void DnsTlsTransport::removeSessionCallback(SSL_CTX* ssl_ctx, SSL_SESSION* session) {
+    if (DBG) {
+        ALOGD("Removing session ticket");
+    }
+    auto* const transport = static_cast<DnsTlsTransport*>(SSL_CTX_get_ex_data(ssl_ctx, 0));
+    if (transport != nullptr) {
+        transport->removeSession(session);
+        return;
+    }
+    ALOGE("null transport in remove session callback");
+}
+
+// Caches |session| (taking ownership) at the front; drops the back entry once past five.
+void DnsTlsTransport::recordSession(SSL_SESSION* session) {
+    std::lock_guard<std::mutex> lock(sLock);
+    mSessions.emplace_front(session);
+    if (mSessions.size() <= 5) {
+        return;
+    }
+    if (DBG) ALOGD("Too many sessions; trimming");
+    mSessions.pop_back();
+}
+
+// Empties the whole ticket cache whenever a non-null session is reported as removed.
+void DnsTlsTransport::removeSession(SSL_SESSION* session) {
+    std::lock_guard<std::mutex> lock(sLock);
+    if (session != nullptr) {
+        mSessions.clear();  // TODO: Consider implementing targeted removal.
+    }
+}
+
+void DnsTlsTransport::sslDisconnect(bssl::UniquePtr<SSL> ssl, base::unique_fd fd) {
+    if (ssl) {  // Both arguments are taken by value, so this call consumes them.
+        SSL_shutdown(ssl.get());
+        ssl.reset();  // Release the SSL object before the socket is closed below.
+    }
+    fd.reset();
 }
 
 bool DnsTlsTransport::sslWrite(int fd, SSL *ssl, const uint8_t *buffer, int len) {
@@ -425,48 +522,101 @@
 }
 
 // static
+std::mutex DnsTlsTransport::sLock;
+std::map<DnsTlsTransport::Key, std::unique_ptr<DnsTlsTransport>> DnsTlsTransport::sStore;
 DnsTlsTransport::Response DnsTlsTransport::query(const Server& server, unsigned mark,
         const uint8_t *query, size_t qlen, uint8_t *response, size_t limit, int *resplen) {
-    // TODO: Keep a static container of transports instead of constructing a new one
-    // for every query.
-    DnsTlsTransport xport(server, mark);
-    return xport.doQuery(query, qlen, response, limit, resplen);
+    const Key key = std::make_pair(mark, server);
+    DnsTlsTransport* xport;
+    {
+        std::lock_guard<std::mutex> guard(sLock);
+        auto it = sStore.find(key);
+        if (it == sStore.end()) {
+            xport = new DnsTlsTransport(server, mark);
+            if (!xport->initialize()) {
+                return DnsTlsTransport::Response::internal_error;
+            }
+            sStore[key].reset(xport);
+        } else {
+            xport = it->second.get();
+        }
+        ++xport->mUseCount;
+    }
+
+    Response res = xport->doQuery(query, qlen, response, limit, resplen);
+    auto now = std::chrono::steady_clock::now();
+    {
+        std::lock_guard<std::mutex> guard(sLock);
+        --xport->mUseCount;
+        xport->mLastUsed = now;
+        cleanup(now);
+    }
+    return res;
+}
+
+static constexpr std::chrono::minutes IDLE_TIMEOUT(5);
+std::chrono::time_point<std::chrono::steady_clock> DnsTlsTransport::sLastCleanup;
+// Erases unused transports idle for over IDLE_TIMEOUT; sweeps at most once per IDLE_TIMEOUT.
+void DnsTlsTransport::cleanup(std::chrono::time_point<std::chrono::steady_clock> now) {
+    if (now - sLastCleanup >= IDLE_TIMEOUT) {
+        for (auto entry = sStore.begin(); entry != sStore.end(); ) {
+            const DnsTlsTransport& transport = *entry->second;
+            if (transport.mUseCount != 0 || now - transport.mLastUsed <= IDLE_TIMEOUT) {
+                ++entry;  // In use or recently used: keep it.
+            } else {
+                entry = sStore.erase(entry);
+            }
+        }
+        sLastCleanup = now;
+    }
 }
 
 DnsTlsTransport::Response DnsTlsTransport::doQuery(const uint8_t *query, size_t qlen,
         uint8_t *response, size_t limit, int *resplen) {
     *resplen = 0;  // Zero indicates an error.
-
-    if (DBG) {
-        ALOGD("%u connecting TCP socket", mMark);
+    android::base::unique_fd fd = makeConnectedSocket();
+    if (fd.get() < 0) {
+        ALOGD("%u makeConnectedSocket() failed with: %s", mMark, strerror(errno));
+        return Response::network_error;
     }
-    android::base::unique_fd fd(makeConnectedSocket());
-    if (DBG) {
-        ALOGD("%u connecting SSL", mMark);
-    }
-    bssl::UniquePtr<SSL> ssl(sslConnect(fd));
-    if (ssl == nullptr) {
-        if (DBG) {
-            ALOGW("%u SSL connection failed", mMark);
-        }
+    bssl::UniquePtr<SSL> ssl = sslConnect(fd.get());
+    if (!ssl) {
         return Response::network_error;
     }
 
+    Response res = sendQuery(fd.get(), ssl.get(), query, qlen);
+    if (res == Response::success) {  // Only wait for a reply if the query went out.
+        res = readResponse(fd.get(), ssl.get(), query, response, limit, resplen);
+    }
+    sslDisconnect(std::move(ssl), std::move(fd));  // Runs on success and failure alike.
+    return res;
+}
+
+DnsTlsTransport::Response DnsTlsTransport::sendQuery(int fd, SSL* ssl, const uint8_t *query, size_t qlen) {
+    if (DBG) {
+        ALOGD("sending query");
+    }
     uint8_t queryHeader[2];
     queryHeader[0] = qlen >> 8;
     queryHeader[1] = qlen;
-    if (!sslWrite(fd.get(), ssl.get(), queryHeader, 2)) {
+    if (!sslWrite(fd, ssl, queryHeader, 2)) {  // Length prefix (high byte first) goes out first.
         return Response::network_error;
     }
-    if (!sslWrite(fd.get(), ssl.get(), query, qlen)) {
+    if (!sslWrite(fd, ssl, query, qlen)) {  // Then the query bytes themselves.
         return Response::network_error;
     }
     if (DBG) {
         ALOGD("%u SSL_write complete", mMark);
     }
+    return Response::success;  // NOTE(review): assumes qlen fits in 16 bits; confirm callers.
+}
 
+DnsTlsTransport::Response DnsTlsTransport::readResponse(int fd, SSL* ssl, const uint8_t *query, uint8_t *response, size_t limit, int *resplen) {
+    if (DBG) {
+        ALOGD("reading response");
+    }
     uint8_t responseHeader[2];
-    if (!sslRead(fd.get(), ssl.get(), responseHeader, 2)) {
+    if (!sslRead(fd, ssl, responseHeader, 2)) {
         if (DBG) {
             ALOGW("%u Failed to read 2-byte length header", mMark);
         }
@@ -480,7 +630,7 @@
         ALOGE("%u Response doesn't fit in output buffer: %i", mMark, responseSize);
         return Response::limit_error;
     }
-    if (!sslRead(fd.get(), ssl.get(), response, responseSize)) {
+    if (!sslRead(fd, ssl, response, responseSize)) {
         if (DBG) {
             ALOGW("%u Failed to read %i bytes", mMark, responseSize);
         }
@@ -495,8 +645,6 @@
         return Response::internal_error;
     }
 
-    SSL_shutdown(ssl.get());
-
     *resplen = responseSize;
     return Response::success;
 }