blob: f6736f8dfd7d253cd6dd5f0402354ae7ed41397e [file] [log] [blame]
Mike Yubab3daa2018-10-19 22:11:43 +08001/*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#ifndef _DNS_DNSTLSSOCKET_H
18#define _DNS_DNSTLSSOCKET_H
19
#include <chrono>
#include <future>
#include <memory>
#include <mutex>
#include <thread>
#include <vector>

#include <android-base/thread_annotations.h>
#include <android-base/unique_fd.h>
#include <netdutils/Slice.h>
#include <netdutils/Status.h>
#include <openssl/ssl.h>

#include "DnsTlsServer.h"
#include "IDnsTlsSocket.h"
#include "LockedQueue.h"
Mike Yubab3daa2018-10-19 22:11:43 +080032
33namespace android {
34namespace net {
35
36class IDnsTlsSocketObserver;
37class DnsTlsSessionCache;
38
Mike Yubab3daa2018-10-19 22:11:43 +080039// A class for managing a TLS socket that sends and receives messages in
40// [length][value] format, with a 2-byte length (i.e. DNS-over-TCP format).
41// This class is not aware of query-response pairing or anything else about DNS.
42// For the observer:
43// This class is not re-entrant: the observer is not permitted to wait for a call to query()
44// or the destructor in a callback. Doing so will result in deadlocks.
45// This class may call the observer at any time after initialize(), until the destructor
46// returns (but not after).
Mike Yu441d9372020-07-15 17:06:22 +080047//
48// Calls to IDnsTlsSocketObserver in a DnsTlsSocket life cycle:
49//
Mike Yu0c8e4522020-08-24 14:41:32 +080050// UNINITIALIZED
Mike Yu441d9372020-07-15 17:06:22 +080051// |
52// v
Mike Yu0c8e4522020-08-24 14:41:32 +080053// INITIALIZED
54// |
55// v
56// +----CONNECTING------+
Mike Yu441d9372020-07-15 17:06:22 +080057// Handshake fails | | Handshake succeeds
Mike Yue93d9ae2020-08-25 19:09:51 +080058// (onClose() when | |
59// mAsyncHandshake is set) | v
Mike Yu0c8e4522020-08-24 14:41:32 +080060// | +---> CONNECTED --+
Mike Yu441d9372020-07-15 17:06:22 +080061// | | | |
62// | +-----------+ | Idle timeout
63// | Send/Recv queries | onClose()
64// | onResponse() |
65// | |
66// | |
Mike Yu0c8e4522020-08-24 14:41:32 +080067// +--> WAIT_FOR_DELETE <-----+
Mike Yu441d9372020-07-15 17:06:22 +080068//
69//
70// TODO: Add onHandshakeFinished() for handshake results.
Mike Yub601ff72018-11-01 20:07:00 +080071class DnsTlsSocket : public IDnsTlsSocket {
72 public:
Mike Yu0c8e4522020-08-24 14:41:32 +080073 enum class State {
74 UNINITIALIZED,
75 INITIALIZED,
76 CONNECTING,
77 CONNECTED,
78 WAIT_FOR_DELETE,
79 };
80
Mike Yubab3daa2018-10-19 22:11:43 +080081 DnsTlsSocket(const DnsTlsServer& server, unsigned mark,
Bernie Innocentiec4219b2019-01-30 11:16:36 +090082 IDnsTlsSocketObserver* _Nonnull observer, DnsTlsSessionCache* _Nonnull cache)
83 : mMark(mark), mServer(server), mObserver(observer), mCache(cache) {}
Mike Yubab3daa2018-10-19 22:11:43 +080084 ~DnsTlsSocket();
85
Mike Yu441d9372020-07-15 17:06:22 +080086 // Creates the SSL context for this session. Returns false on failure.
Mike Yubab3daa2018-10-19 22:11:43 +080087 // This method should be called after construction and before use of a DnsTlsSocket.
88 // Only call this method once per DnsTlsSocket.
89 bool initialize() EXCLUDES(mLock);
90
Mike Yue93d9ae2020-08-25 19:09:51 +080091 // If async handshake is enabled, this function simply signals a handshake request, and the
92 // handshake will be performed in the loop thread; otherwise, if async handshake is disabled,
93 // this function performs the handshake and returns after the handshake finishes.
Mike Yu441d9372020-07-15 17:06:22 +080094 bool startHandshake() EXCLUDES(mLock);
95
Mike Yubab3daa2018-10-19 22:11:43 +080096 // Send a query on the provided SSL socket. |query| contains
97 // the body of a query, not including the ID header. This function will typically return before
98 // the query is actually sent. If this function fails, DnsTlsSocketObserver will be
99 // notified that the socket is closed.
100 // Note that success here indicates successful sending, not receipt of a response.
101 // Thread-safe.
Ben Schwartz62176fd2019-01-22 17:32:17 -0500102 bool query(uint16_t id, const netdutils::Slice query) override EXCLUDES(mLock);
Mike Yubab3daa2018-10-19 22:11:43 +0800103
Bernie Innocentiec4219b2019-01-30 11:16:36 +0900104 private:
Mike Yubab3daa2018-10-19 22:11:43 +0800105 // Lock to be held by the SSL event loop thread. This is not normally in contention.
106 std::mutex mLock;
107
108 // Forwards queries and receives responses. Blocks until the idle timeout.
109 void loop() EXCLUDES(mLock);
110 std::unique_ptr<std::thread> mLoopThread GUARDED_BY(mLock);
111
112 // On success, sets mSslFd to a socket connected to mAddr (the
113 // connection will likely be in progress if mProtocol is IPPROTO_TCP).
114 // On error, returns the errno.
115 netdutils::Status tcpConnect() REQUIRES(mLock);
116
Mike Yue93d9ae2020-08-25 19:09:51 +0800117 bssl::UniquePtr<SSL> prepareForSslConnect(int fd) REQUIRES(mLock);
118
Mike Yubab3daa2018-10-19 22:11:43 +0800119 // Connect an SSL session on the provided socket. If connection fails, closing the
120 // socket remains the caller's responsibility.
121 bssl::UniquePtr<SSL> sslConnect(int fd) REQUIRES(mLock);
122
Mike Yue93d9ae2020-08-25 19:09:51 +0800123 // Connect an SSL session on the provided socket. This is an interruptible version
124 // which allows to terminate connection handshake any time.
125 bssl::UniquePtr<SSL> sslConnectV2(int fd) REQUIRES(mLock);
126
Mike Yubab3daa2018-10-19 22:11:43 +0800127 // Disconnect the SSL session and close the socket.
128 void sslDisconnect() REQUIRES(mLock);
129
130 // Writes a buffer to the socket.
Bernie Innocentiec4219b2019-01-30 11:16:36 +0900131 bool sslWrite(const netdutils::Slice buffer) REQUIRES(mLock);
Mike Yubab3daa2018-10-19 22:11:43 +0800132
133 // Reads exactly the specified number of bytes from the socket, or fails.
134 // Returns SSL_ERROR_NONE on success.
135 // If |wait| is true, then this function always blocks. Otherwise, it
136 // will return SSL_ERROR_WANT_READ if there is no data from the server to read.
Bernie Innocentiec4219b2019-01-30 11:16:36 +0900137 int sslRead(const netdutils::Slice buffer, bool wait) REQUIRES(mLock);
Mike Yubab3daa2018-10-19 22:11:43 +0800138
Ben Schwartz2187abe2019-01-10 14:30:46 -0500139 bool sendQuery(const std::vector<uint8_t>& buf) REQUIRES(mLock);
Mike Yu5e1b9912020-11-10 16:50:13 +0800140
141 // Read one DNS response. It can potentially block until reading the exact bytes of
142 // the response.
Mike Yubab3daa2018-10-19 22:11:43 +0800143 bool readResponse() REQUIRES(mLock);
144
waynema0e73c2e2019-07-31 15:04:08 +0800145 // It is only used for DNS-OVER-TLS internal test.
146 bool setTestCaCertificate() REQUIRES(mLock);
147
Ben Schwartz62176fd2019-01-22 17:32:17 -0500148 // Similar to query(), this function uses incrementEventFd to send a message to the
149 // loop thread. However, instead of incrementing the counter by one (indicating a
150 // new query), it wraps the counter to negative, which we use to indicate a shutdown
151 // request.
152 void requestLoopShutdown() EXCLUDES(mLock);
153
154 // This function sends a message to the loop thread by incrementing mEventFd.
155 bool incrementEventFd(int64_t count) EXCLUDES(mLock);
156
Mike Yue93d9ae2020-08-25 19:09:51 +0800157 // Transition the state from expected state |from| to new state |to|.
158 void transitionState(State from, State to) REQUIRES(mLock);
159
Ben Schwartz2187abe2019-01-10 14:30:46 -0500160 // Queue of pending queries. query() pushes items onto the queue and notifies
161 // the loop thread by incrementing mEventFd. loop() reads items off the queue.
162 LockedQueue<std::vector<uint8_t>> mQueue;
163
164 // eventfd socket used for notifying the SSL thread when queries are ready to send.
165 // This socket acts similarly to an atomic counter, incremented by query() and cleared
166 // by loop(). We have to use a socket because the SSL thread needs to wait in poll()
Ben Schwartz62176fd2019-01-22 17:32:17 -0500167 // for input from either a remote server or a query thread. Since eventfd does not have
168 // EOF, we indicate a close request by setting the counter to a negative number.
169 // This file descriptor is opened by initialize(), and closed implicitly after
170 // destruction.
Mike Yue93d9ae2020-08-25 19:09:51 +0800171 // Note that: data starts being read from the eventfd when the state is CONNECTED.
Ben Schwartz2187abe2019-01-10 14:30:46 -0500172 base::unique_fd mEventFd;
Mike Yubab3daa2018-10-19 22:11:43 +0800173
Mike Yue93d9ae2020-08-25 19:09:51 +0800174 // An eventfd used to listen to shutdown requests when the state is CONNECTING.
175 // TODO: let |mEventFd| exclusively handle query requests, and let |mShutdownEvent| exclusively
176 // handle shutdown requests.
177 base::unique_fd mShutdownEvent;
178
Mike Yubab3daa2018-10-19 22:11:43 +0800179 // SSL Socket fields.
180 bssl::UniquePtr<SSL_CTX> mSslCtx GUARDED_BY(mLock);
181 base::unique_fd mSslFd GUARDED_BY(mLock);
182 bssl::UniquePtr<SSL> mSsl GUARDED_BY(mLock);
183 static constexpr std::chrono::seconds kIdleTimeout = std::chrono::seconds(20);
184
185 const unsigned mMark; // Socket mark
186 const DnsTlsServer mServer;
187 IDnsTlsSocketObserver* _Nonnull const mObserver;
188 DnsTlsSessionCache* _Nonnull const mCache;
Mike Yu0c8e4522020-08-24 14:41:32 +0800189 State mState GUARDED_BY(mLock) = State::UNINITIALIZED;
Mike Yue93d9ae2020-08-25 19:09:51 +0800190
191 // If true, defer the handshake to the loop thread; otherwise, run the handshake on caller's
192 // thread (the call to startHandshake()).
193 bool mAsyncHandshake GUARDED_BY(mLock) = false;
194
Mike Yu19192712020-08-28 11:56:31 +0800195 // The time to wait for the attempt on connecting to the server.
196 // Set the default value 127 seconds to be consistent with TCP connect timeout.
197 // (presume net.ipv4.tcp_syn_retries = 6)
198 static constexpr int kDotConnectTimeoutMs = 127 * 1000;
199 int mConnectTimeoutMs;
200
Mike Yue93d9ae2020-08-25 19:09:51 +0800201 // For testing.
202 friend class DnsTlsSocketTest;
Mike Yubab3daa2018-10-19 22:11:43 +0800203};
204
205} // end of namespace net
206} // end of namespace android
207
208#endif // _DNS_DNSTLSSOCKET_H