blob: 919ff6f79a0370ce87f91a068cd95b244acfb61d [file] [log] [blame]
Mike Yuc52739e2018-10-19 21:06:32 +08001/*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
Ken Chen5471dca2019-04-15 15:25:35 +080017#define LOG_TAG "resolv"
Mike Yuc52739e2018-10-19 21:06:32 +080018
chenbruceaff85842019-05-31 15:46:42 +080019#include <arpa/inet.h>
20
21#include <chrono>
22
23#include <android-base/logging.h>
24#include <android-base/macros.h>
Mike Yu1b9069c2020-08-25 15:17:29 +080025#include <gmock/gmock.h>
Mike Yuc52739e2018-10-19 21:06:32 +080026#include <gtest/gtest.h>
chenbruceaff85842019-05-31 15:46:42 +080027#include <netdutils/Slice.h>
Mike Yuc52739e2018-10-19 21:06:32 +080028
Bernie Innocentiec4219b2019-01-30 11:16:36 +090029#include "DnsTlsDispatcher.h"
30#include "DnsTlsQueryMap.h"
31#include "DnsTlsServer.h"
32#include "DnsTlsSessionCache.h"
33#include "DnsTlsSocket.h"
34#include "DnsTlsTransport.h"
Mike Yubb499092020-08-28 19:18:42 +080035#include "Experiments.h"
Bernie Innocentiec4219b2019-01-30 11:16:36 +090036#include "IDnsTlsSocket.h"
37#include "IDnsTlsSocketFactory.h"
38#include "IDnsTlsSocketObserver.h"
chenbruceb43ec752019-07-24 20:19:41 +080039#include "tests/dns_responder/dns_tls_frontend.h"
Ben Schwartz62176fd2019-01-22 17:32:17 -050040
Mike Yuc52739e2018-10-19 21:06:32 +080041namespace android {
42namespace net {
43
Mike Yuc52739e2018-10-19 21:06:32 +080044using netdutils::makeSlice;
Mike Yu1b9069c2020-08-25 15:17:29 +080045using netdutils::Slice;
Mike Yuc52739e2018-10-19 21:06:32 +080046
// Experiment flag name; the retry tests look it up (defaulting to DnsTlsQueryMap::kMaxTries)
// to learn how many connection attempts to expect.
static const std::string DOT_MAXTRIES_FLAG("dot_maxtries");
48
// Byte buffer type used for the fake queries and responses in these tests.
using bytevec = std::vector<uint8_t>;
50
51static void parseServer(const char* server, in_port_t port, sockaddr_storage* parsed) {
52 sockaddr_in* sin = reinterpret_cast<sockaddr_in*>(parsed);
53 if (inet_pton(AF_INET, server, &(sin->sin_addr)) == 1) {
54 // IPv4 parse succeeded, so it's IPv4
55 sin->sin_family = AF_INET;
56 sin->sin_port = htons(port);
57 return;
58 }
59 sockaddr_in6* sin6 = reinterpret_cast<sockaddr_in6*>(parsed);
60 if (inet_pton(AF_INET6, server, &(sin6->sin6_addr)) == 1){
61 // IPv6 parse succeeded, so it's IPv6.
62 sin6->sin6_family = AF_INET6;
63 sin6->sin6_port = htons(port);
64 return;
65 }
chenbruceaff85842019-05-31 15:46:42 +080066 LOG(ERROR) << "Failed to parse server address: " << server;
Mike Yuc52739e2018-10-19 21:06:32 +080067}
68
// Server names used by the tests (SERVERNAME1 is assigned to BaseTest::SERVER1.name).
std::string SERVERNAME1{"dns.example.com"};
std::string SERVERNAME2{"dns.example.org"};
71
72// BaseTest just provides constants that are useful for the tests.
73class BaseTest : public ::testing::Test {
74 protected:
75 BaseTest() {
76 parseServer("192.0.2.1", 853, &V4ADDR1);
77 parseServer("192.0.2.2", 853, &V4ADDR2);
78 parseServer("2001:db8::1", 853, &V6ADDR1);
79 parseServer("2001:db8::2", 853, &V6ADDR2);
80
81 SERVER1 = DnsTlsServer(V4ADDR1);
Mike Yuc52739e2018-10-19 21:06:32 +080082 SERVER1.name = SERVERNAME1;
83 }
84
85 sockaddr_storage V4ADDR1;
86 sockaddr_storage V4ADDR2;
87 sockaddr_storage V6ADDR1;
88 sockaddr_storage V6ADDR2;
89
90 DnsTlsServer SERVER1;
91};
92
93bytevec make_query(uint16_t id, size_t size) {
94 bytevec vec(size);
95 vec[0] = id >> 8;
96 vec[1] = id;
97 // Arbitrarily fill the query body with unique data.
98 for (size_t i = 2; i < size; ++i) {
99 vec[i] = id + i;
100 }
101 return vec;
102}
103
104// Query constants
105const unsigned MARK = 123;
106const uint16_t ID = 52;
107const uint16_t SIZE = 22;
108const bytevec QUERY = make_query(ID, SIZE);
109
110template <class T>
111class FakeSocketFactory : public IDnsTlsSocketFactory {
112 public:
113 FakeSocketFactory() {}
114 std::unique_ptr<IDnsTlsSocket> createDnsTlsSocket(
115 const DnsTlsServer& server ATTRIBUTE_UNUSED,
116 unsigned mark ATTRIBUTE_UNUSED,
117 IDnsTlsSocketObserver* observer,
118 DnsTlsSessionCache* cache ATTRIBUTE_UNUSED) override {
119 return std::make_unique<T>(observer);
120 }
121};
122
123bytevec make_echo(uint16_t id, const Slice query) {
124 bytevec response(query.size() + 2);
125 response[0] = id >> 8;
126 response[1] = id;
127 // Echo the query as the fake response.
128 memcpy(response.data() + 2, query.base(), query.size());
129 return response;
130}
131
132// Simplest possible fake server. This just echoes the query as the response.
133class FakeSocketEcho : public IDnsTlsSocket {
134 public:
135 explicit FakeSocketEcho(IDnsTlsSocketObserver* observer) : mObserver(observer) {}
136 bool query(uint16_t id, const Slice query) override {
137 // Return the response immediately (asynchronously).
138 std::thread(&IDnsTlsSocketObserver::onResponse, mObserver, make_echo(id, query)).detach();
139 return true;
140 }
Mike Yu441d9372020-07-15 17:06:22 +0800141 bool startHandshake() override { return true; }
Mike Yuc52739e2018-10-19 21:06:32 +0800142
143 private:
144 IDnsTlsSocketObserver* const mObserver;
145};
146
147class TransportTest : public BaseTest {};
148
149TEST_F(TransportTest, Query) {
150 FakeSocketFactory<FakeSocketEcho> factory;
151 DnsTlsTransport transport(SERVER1, MARK, &factory);
152 auto r = transport.query(makeSlice(QUERY)).get();
153
154 EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
155 EXPECT_EQ(QUERY, r.response);
Mike Yu1fea18c2019-12-06 10:59:17 +0800156 EXPECT_EQ(transport.getConnectCounter(), 1);
Mike Yuc52739e2018-10-19 21:06:32 +0800157}
158
159// Fake Socket that echoes the observed query ID as the response body.
160class FakeSocketId : public IDnsTlsSocket {
161 public:
162 explicit FakeSocketId(IDnsTlsSocketObserver* observer) : mObserver(observer) {}
163 bool query(uint16_t id, const Slice query ATTRIBUTE_UNUSED) override {
164 // Return the response immediately (asynchronously).
165 bytevec response(4);
166 // Echo the ID in the header to match the response to the query.
167 // This will be overwritten by DnsTlsQueryMap.
168 response[0] = id >> 8;
169 response[1] = id;
170 // Echo the ID in the body, so that the test can verify which ID was used by
171 // DnsTlsQueryMap.
172 response[2] = id >> 8;
173 response[3] = id;
174 std::thread(&IDnsTlsSocketObserver::onResponse, mObserver, response).detach();
175 return true;
176 }
Mike Yu441d9372020-07-15 17:06:22 +0800177 bool startHandshake() override { return true; }
Mike Yuc52739e2018-10-19 21:06:32 +0800178
179 private:
180 IDnsTlsSocketObserver* const mObserver;
181};
182
183// Test that IDs are properly reused
184TEST_F(TransportTest, IdReuse) {
185 FakeSocketFactory<FakeSocketId> factory;
186 DnsTlsTransport transport(SERVER1, MARK, &factory);
187 for (int i = 0; i < 100; ++i) {
188 // Send a query.
Mike Yubd136992019-12-04 15:01:07 +0800189 std::future<DnsTlsTransport::Result> f = transport.query(makeSlice(QUERY));
Mike Yuc52739e2018-10-19 21:06:32 +0800190 // Wait for the response.
Mike Yubd136992019-12-04 15:01:07 +0800191 DnsTlsTransport::Result r = f.get();
Mike Yuc52739e2018-10-19 21:06:32 +0800192 EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
193
194 // All queries should have an observed ID of zero, because it is returned to the ID pool
195 // after each use.
196 EXPECT_EQ(0, (r.response[2] << 8) | r.response[3]);
197 }
Mike Yu1fea18c2019-12-06 10:59:17 +0800198 EXPECT_EQ(transport.getConnectCounter(), 1);
Mike Yuc52739e2018-10-19 21:06:32 +0800199}
200
201// These queries might be handled in serial or parallel as they race the
202// responses.
203TEST_F(TransportTest, RacingQueries_10000) {
204 FakeSocketFactory<FakeSocketEcho> factory;
205 DnsTlsTransport transport(SERVER1, MARK, &factory);
206 std::vector<std::future<DnsTlsTransport::Result>> results;
207 // Fewer than 65536 queries to avoid ID exhaustion.
208 const int num_queries = 10000;
209 results.reserve(num_queries);
210 for (int i = 0; i < num_queries; ++i) {
211 results.push_back(transport.query(makeSlice(QUERY)));
212 }
213 for (auto& result : results) {
214 auto r = result.get();
215 EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
216 EXPECT_EQ(QUERY, r.response);
217 }
Mike Yu1fea18c2019-12-06 10:59:17 +0800218 EXPECT_EQ(transport.getConnectCounter(), 1);
Mike Yuc52739e2018-10-19 21:06:32 +0800219}
220
221// A server that waits until sDelay queries are queued before responding.
222class FakeSocketDelay : public IDnsTlsSocket {
223 public:
224 explicit FakeSocketDelay(IDnsTlsSocketObserver* observer) : mObserver(observer) {}
Mike Yu441d9372020-07-15 17:06:22 +0800225 ~FakeSocketDelay() {
226 std::lock_guard guard(mLock);
227 sDelay = 1;
228 sReverse = false;
229 sConnectable = true;
230 }
231 inline static size_t sDelay = 1;
232 inline static bool sReverse = false;
233 inline static bool sConnectable = true;
Mike Yuc52739e2018-10-19 21:06:32 +0800234
235 bool query(uint16_t id, const Slice query) override {
chenbruceaff85842019-05-31 15:46:42 +0800236 LOG(DEBUG) << "FakeSocketDelay got query with ID " << int(id);
Mike Yuc52739e2018-10-19 21:06:32 +0800237 std::lock_guard guard(mLock);
238 // Check for duplicate IDs.
239 EXPECT_EQ(0U, mIds.count(id));
240 mIds.insert(id);
241
242 // Store response.
243 mResponses.push_back(make_echo(id, query));
244
chenbruceaff85842019-05-31 15:46:42 +0800245 LOG(DEBUG) << "Up to " << mResponses.size() << " out of " << sDelay << " queries";
Mike Yuc52739e2018-10-19 21:06:32 +0800246 if (mResponses.size() == sDelay) {
247 std::thread(&FakeSocketDelay::sendResponses, this).detach();
248 }
249 return true;
250 }
Mike Yu441d9372020-07-15 17:06:22 +0800251 bool startHandshake() override { return sConnectable; }
Mike Yuc52739e2018-10-19 21:06:32 +0800252
253 private:
254 void sendResponses() {
255 std::lock_guard guard(mLock);
256 if (sReverse) {
257 std::reverse(std::begin(mResponses), std::end(mResponses));
258 }
259 for (auto& response : mResponses) {
260 mObserver->onResponse(response);
261 }
262 mIds.clear();
263 mResponses.clear();
264 }
265
266 std::mutex mLock;
267 IDnsTlsSocketObserver* const mObserver;
268 std::set<uint16_t> mIds GUARDED_BY(mLock);
269 std::vector<bytevec> mResponses GUARDED_BY(mLock);
270};
271
Mike Yuc52739e2018-10-19 21:06:32 +0800272TEST_F(TransportTest, ParallelColliding) {
273 FakeSocketDelay::sDelay = 10;
274 FakeSocketDelay::sReverse = false;
275 FakeSocketFactory<FakeSocketDelay> factory;
276 DnsTlsTransport transport(SERVER1, MARK, &factory);
277 std::vector<std::future<DnsTlsTransport::Result>> results;
278 // Fewer than 65536 queries to avoid ID exhaustion.
279 results.reserve(FakeSocketDelay::sDelay);
280 for (size_t i = 0; i < FakeSocketDelay::sDelay; ++i) {
281 results.push_back(transport.query(makeSlice(QUERY)));
282 }
283 for (auto& result : results) {
284 auto r = result.get();
285 EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
286 EXPECT_EQ(QUERY, r.response);
287 }
Mike Yu1fea18c2019-12-06 10:59:17 +0800288 EXPECT_EQ(transport.getConnectCounter(), 1);
Mike Yuc52739e2018-10-19 21:06:32 +0800289}
290
291TEST_F(TransportTest, ParallelColliding_Max) {
292 FakeSocketDelay::sDelay = 65536;
293 FakeSocketDelay::sReverse = false;
294 FakeSocketFactory<FakeSocketDelay> factory;
295 DnsTlsTransport transport(SERVER1, MARK, &factory);
296 std::vector<std::future<DnsTlsTransport::Result>> results;
297 // Exactly 65536 queries should still be possible in parallel,
298 // even if they all have the same original ID.
299 results.reserve(FakeSocketDelay::sDelay);
300 for (size_t i = 0; i < FakeSocketDelay::sDelay; ++i) {
301 results.push_back(transport.query(makeSlice(QUERY)));
302 }
303 for (auto& result : results) {
304 auto r = result.get();
305 EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
306 EXPECT_EQ(QUERY, r.response);
307 }
Mike Yu1fea18c2019-12-06 10:59:17 +0800308 EXPECT_EQ(transport.getConnectCounter(), 1);
Mike Yuc52739e2018-10-19 21:06:32 +0800309}
310
311TEST_F(TransportTest, ParallelUnique) {
312 FakeSocketDelay::sDelay = 10;
313 FakeSocketDelay::sReverse = false;
314 FakeSocketFactory<FakeSocketDelay> factory;
315 DnsTlsTransport transport(SERVER1, MARK, &factory);
316 std::vector<bytevec> queries(FakeSocketDelay::sDelay);
317 std::vector<std::future<DnsTlsTransport::Result>> results;
318 results.reserve(FakeSocketDelay::sDelay);
319 for (size_t i = 0; i < FakeSocketDelay::sDelay; ++i) {
320 queries[i] = make_query(i, SIZE);
321 results.push_back(transport.query(makeSlice(queries[i])));
322 }
323 for (size_t i = 0 ; i < FakeSocketDelay::sDelay; ++i) {
324 auto r = results[i].get();
325 EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
326 EXPECT_EQ(queries[i], r.response);
327 }
Mike Yu1fea18c2019-12-06 10:59:17 +0800328 EXPECT_EQ(transport.getConnectCounter(), 1);
Mike Yuc52739e2018-10-19 21:06:32 +0800329}
330
331TEST_F(TransportTest, ParallelUnique_Max) {
332 FakeSocketDelay::sDelay = 65536;
333 FakeSocketDelay::sReverse = false;
334 FakeSocketFactory<FakeSocketDelay> factory;
335 DnsTlsTransport transport(SERVER1, MARK, &factory);
336 std::vector<bytevec> queries(FakeSocketDelay::sDelay);
337 std::vector<std::future<DnsTlsTransport::Result>> results;
338 // Exactly 65536 queries should still be possible in parallel,
339 // and they should all be mapped correctly back to the original ID.
340 results.reserve(FakeSocketDelay::sDelay);
341 for (size_t i = 0; i < FakeSocketDelay::sDelay; ++i) {
342 queries[i] = make_query(i, SIZE);
343 results.push_back(transport.query(makeSlice(queries[i])));
344 }
345 for (size_t i = 0 ; i < FakeSocketDelay::sDelay; ++i) {
346 auto r = results[i].get();
347 EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
348 EXPECT_EQ(queries[i], r.response);
349 }
Mike Yu1fea18c2019-12-06 10:59:17 +0800350 EXPECT_EQ(transport.getConnectCounter(), 1);
Mike Yuc52739e2018-10-19 21:06:32 +0800351}
352
353TEST_F(TransportTest, IdExhaustion) {
354 const int num_queries = 65536;
355 // A delay of 65537 is unreachable, because the maximum number
356 // of outstanding queries is 65536.
357 FakeSocketDelay::sDelay = num_queries + 1;
358 FakeSocketDelay::sReverse = false;
359 FakeSocketFactory<FakeSocketDelay> factory;
360 DnsTlsTransport transport(SERVER1, MARK, &factory);
361 std::vector<std::future<DnsTlsTransport::Result>> results;
362 // Issue the maximum number of queries.
363 results.reserve(num_queries);
364 for (int i = 0; i < num_queries; ++i) {
365 results.push_back(transport.query(makeSlice(QUERY)));
366 }
367
368 // The ID space is now full, so subsequent queries should fail immediately.
369 auto r = transport.query(makeSlice(QUERY)).get();
370 EXPECT_EQ(DnsTlsTransport::Response::internal_error, r.code);
371 EXPECT_TRUE(r.response.empty());
372
373 for (auto& result : results) {
374 // All other queries should remain outstanding.
375 EXPECT_EQ(std::future_status::timeout,
376 result.wait_for(std::chrono::duration<int>::zero()));
377 }
Mike Yu1fea18c2019-12-06 10:59:17 +0800378 EXPECT_EQ(transport.getConnectCounter(), 1);
Mike Yuc52739e2018-10-19 21:06:32 +0800379}
380
381// Responses can come back from the server in any order. This should have no
382// effect on Transport's observed behavior.
383TEST_F(TransportTest, ReverseOrder) {
384 FakeSocketDelay::sDelay = 10;
385 FakeSocketDelay::sReverse = true;
386 FakeSocketFactory<FakeSocketDelay> factory;
387 DnsTlsTransport transport(SERVER1, MARK, &factory);
388 std::vector<bytevec> queries(FakeSocketDelay::sDelay);
389 std::vector<std::future<DnsTlsTransport::Result>> results;
390 results.reserve(FakeSocketDelay::sDelay);
391 for (size_t i = 0; i < FakeSocketDelay::sDelay; ++i) {
392 queries[i] = make_query(i, SIZE);
393 results.push_back(transport.query(makeSlice(queries[i])));
394 }
395 for (size_t i = 0 ; i < FakeSocketDelay::sDelay; ++i) {
396 auto r = results[i].get();
397 EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
398 EXPECT_EQ(queries[i], r.response);
399 }
Mike Yu1fea18c2019-12-06 10:59:17 +0800400 EXPECT_EQ(transport.getConnectCounter(), 1);
Mike Yuc52739e2018-10-19 21:06:32 +0800401}
402
403TEST_F(TransportTest, ReverseOrder_Max) {
404 FakeSocketDelay::sDelay = 65536;
405 FakeSocketDelay::sReverse = true;
406 FakeSocketFactory<FakeSocketDelay> factory;
407 DnsTlsTransport transport(SERVER1, MARK, &factory);
408 std::vector<bytevec> queries(FakeSocketDelay::sDelay);
409 std::vector<std::future<DnsTlsTransport::Result>> results;
410 results.reserve(FakeSocketDelay::sDelay);
411 for (size_t i = 0; i < FakeSocketDelay::sDelay; ++i) {
412 queries[i] = make_query(i, SIZE);
413 results.push_back(transport.query(makeSlice(queries[i])));
414 }
415 for (size_t i = 0 ; i < FakeSocketDelay::sDelay; ++i) {
416 auto r = results[i].get();
417 EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
418 EXPECT_EQ(queries[i], r.response);
419 }
Mike Yu1fea18c2019-12-06 10:59:17 +0800420 EXPECT_EQ(transport.getConnectCounter(), 1);
Mike Yuc52739e2018-10-19 21:06:32 +0800421}
422
423// Returning null from the factory indicates a connection failure.
424class NullSocketFactory : public IDnsTlsSocketFactory {
425 public:
426 NullSocketFactory() {}
427 std::unique_ptr<IDnsTlsSocket> createDnsTlsSocket(
428 const DnsTlsServer& server ATTRIBUTE_UNUSED,
429 unsigned mark ATTRIBUTE_UNUSED,
430 IDnsTlsSocketObserver* observer ATTRIBUTE_UNUSED,
431 DnsTlsSessionCache* cache ATTRIBUTE_UNUSED) override {
432 return nullptr;
433 }
434};
435
436TEST_F(TransportTest, ConnectFail) {
Mike Yu441d9372020-07-15 17:06:22 +0800437 // Failure on creating socket.
438 NullSocketFactory factory1;
439 DnsTlsTransport transport1(SERVER1, MARK, &factory1);
440 auto r = transport1.query(makeSlice(QUERY)).get();
Mike Yuc52739e2018-10-19 21:06:32 +0800441
442 EXPECT_EQ(DnsTlsTransport::Response::network_error, r.code);
443 EXPECT_TRUE(r.response.empty());
Mike Yu441d9372020-07-15 17:06:22 +0800444 EXPECT_EQ(transport1.getConnectCounter(), 1);
445
446 // Failure on handshaking.
447 FakeSocketDelay::sConnectable = false;
448 FakeSocketFactory<FakeSocketDelay> factory2;
449 DnsTlsTransport transport2(SERVER1, MARK, &factory2);
450 r = transport2.query(makeSlice(QUERY)).get();
451
452 EXPECT_EQ(DnsTlsTransport::Response::network_error, r.code);
453 EXPECT_TRUE(r.response.empty());
454 EXPECT_EQ(transport2.getConnectCounter(), 1);
Mike Yuc52739e2018-10-19 21:06:32 +0800455}
456
457// Simulate a socket that connects but then immediately receives a server
458// close notification.
459class FakeSocketClose : public IDnsTlsSocket {
460 public:
461 explicit FakeSocketClose(IDnsTlsSocketObserver* observer)
462 : mCloser(&IDnsTlsSocketObserver::onClosed, observer) {}
463 ~FakeSocketClose() { mCloser.join(); }
464 bool query(uint16_t id ATTRIBUTE_UNUSED,
465 const Slice query ATTRIBUTE_UNUSED) override {
466 return true;
467 }
Mike Yu441d9372020-07-15 17:06:22 +0800468 bool startHandshake() override { return true; }
Mike Yuc52739e2018-10-19 21:06:32 +0800469
470 private:
471 std::thread mCloser;
472};
473
474TEST_F(TransportTest, CloseRetryFail) {
475 FakeSocketFactory<FakeSocketClose> factory;
476 DnsTlsTransport transport(SERVER1, MARK, &factory);
477 auto r = transport.query(makeSlice(QUERY)).get();
478
479 EXPECT_EQ(DnsTlsTransport::Response::network_error, r.code);
480 EXPECT_TRUE(r.response.empty());
Mike Yucb2bb7c2019-11-22 20:42:13 +0800481
Mike Yubb499092020-08-28 19:18:42 +0800482 // Reconnections might be triggered depending on the flag.
483 EXPECT_EQ(transport.getConnectCounter(),
484 Experiments::getInstance()->getFlag(DOT_MAXTRIES_FLAG, DnsTlsQueryMap::kMaxTries));
Mike Yuc52739e2018-10-19 21:06:32 +0800485}
486
487// Simulate a server that occasionally closes the connection and silently
488// drops some queries.
489class FakeSocketLimited : public IDnsTlsSocket {
490 public:
491 static int sLimit; // Number of queries to answer per socket.
492 static size_t sMaxSize; // Silently discard queries greater than this size.
493 explicit FakeSocketLimited(IDnsTlsSocketObserver* observer)
494 : mObserver(observer), mQueries(0) {}
495 ~FakeSocketLimited() {
496 {
chenbruceaff85842019-05-31 15:46:42 +0800497 LOG(DEBUG) << "~FakeSocketLimited acquiring mLock";
Mike Yuc52739e2018-10-19 21:06:32 +0800498 std::lock_guard guard(mLock);
chenbruceaff85842019-05-31 15:46:42 +0800499 LOG(DEBUG) << "~FakeSocketLimited acquired mLock";
Mike Yuc52739e2018-10-19 21:06:32 +0800500 for (auto& thread : mThreads) {
chenbruceaff85842019-05-31 15:46:42 +0800501 LOG(DEBUG) << "~FakeSocketLimited joining response thread";
Mike Yuc52739e2018-10-19 21:06:32 +0800502 thread.join();
chenbruceaff85842019-05-31 15:46:42 +0800503 LOG(DEBUG) << "~FakeSocketLimited joined response thread";
Mike Yuc52739e2018-10-19 21:06:32 +0800504 }
505 mThreads.clear();
506 }
507
508 if (mCloser) {
chenbruceaff85842019-05-31 15:46:42 +0800509 LOG(DEBUG) << "~FakeSocketLimited joining closer thread";
Mike Yuc52739e2018-10-19 21:06:32 +0800510 mCloser->join();
chenbruceaff85842019-05-31 15:46:42 +0800511 LOG(DEBUG) << "~FakeSocketLimited joined closer thread";
Mike Yuc52739e2018-10-19 21:06:32 +0800512 }
513 }
514 bool query(uint16_t id, const Slice query) override {
chenbruceaff85842019-05-31 15:46:42 +0800515 LOG(DEBUG) << "FakeSocketLimited::query acquiring mLock";
Mike Yuc52739e2018-10-19 21:06:32 +0800516 std::lock_guard guard(mLock);
chenbruceaff85842019-05-31 15:46:42 +0800517 LOG(DEBUG) << "FakeSocketLimited::query acquired mLock";
Mike Yuc52739e2018-10-19 21:06:32 +0800518 ++mQueries;
519
520 if (mQueries <= sLimit) {
chenbruceaff85842019-05-31 15:46:42 +0800521 LOG(DEBUG) << "size " << query.size() << " vs. limit of " << sMaxSize;
Mike Yuc52739e2018-10-19 21:06:32 +0800522 if (query.size() <= sMaxSize) {
523 // Return the response immediately (asynchronously).
524 mThreads.emplace_back(&IDnsTlsSocketObserver::onResponse, mObserver, make_echo(id, query));
525 }
526 }
527 if (mQueries == sLimit) {
528 mCloser = std::make_unique<std::thread>(&FakeSocketLimited::sendClose, this);
529 }
530 return mQueries <= sLimit;
531 }
Mike Yu441d9372020-07-15 17:06:22 +0800532 bool startHandshake() override { return true; }
Mike Yuc52739e2018-10-19 21:06:32 +0800533
534 private:
535 void sendClose() {
536 {
chenbruceaff85842019-05-31 15:46:42 +0800537 LOG(DEBUG) << "FakeSocketLimited::sendClose acquiring mLock";
Mike Yuc52739e2018-10-19 21:06:32 +0800538 std::lock_guard guard(mLock);
chenbruceaff85842019-05-31 15:46:42 +0800539 LOG(DEBUG) << "FakeSocketLimited::sendClose acquired mLock";
Mike Yuc52739e2018-10-19 21:06:32 +0800540 for (auto& thread : mThreads) {
chenbruceaff85842019-05-31 15:46:42 +0800541 LOG(DEBUG) << "FakeSocketLimited::sendClose joining response thread";
Mike Yuc52739e2018-10-19 21:06:32 +0800542 thread.join();
chenbruceaff85842019-05-31 15:46:42 +0800543 LOG(DEBUG) << "FakeSocketLimited::sendClose joined response thread";
Mike Yuc52739e2018-10-19 21:06:32 +0800544 }
545 mThreads.clear();
546 }
547 mObserver->onClosed();
548 }
549 std::mutex mLock;
550 IDnsTlsSocketObserver* const mObserver;
551 int mQueries GUARDED_BY(mLock);
552 std::vector<std::thread> mThreads GUARDED_BY(mLock);
553 std::unique_ptr<std::thread> mCloser GUARDED_BY(mLock);
554};
555
556int FakeSocketLimited::sLimit;
557size_t FakeSocketLimited::sMaxSize;
558
559TEST_F(TransportTest, SilentDrop) {
560 FakeSocketLimited::sLimit = 10; // Close the socket after 10 queries.
561 FakeSocketLimited::sMaxSize = 0; // Silently drop all queries
562 FakeSocketFactory<FakeSocketLimited> factory;
563 DnsTlsTransport transport(SERVER1, MARK, &factory);
564
565 // Queue up 10 queries. They will all be ignored, and after the 10th,
566 // the socket will close. Transport will retry them all, until they
567 // all hit the retry limit and expire.
568 std::vector<std::future<DnsTlsTransport::Result>> results;
569 results.reserve(FakeSocketLimited::sLimit);
570 for (int i = 0; i < FakeSocketLimited::sLimit; ++i) {
571 results.push_back(transport.query(makeSlice(QUERY)));
572 }
573 for (auto& result : results) {
574 auto r = result.get();
575 EXPECT_EQ(DnsTlsTransport::Response::network_error, r.code);
576 EXPECT_TRUE(r.response.empty());
577 }
Mike Yucb2bb7c2019-11-22 20:42:13 +0800578
Mike Yubb499092020-08-28 19:18:42 +0800579 // Reconnections might be triggered depending on the flag.
580 EXPECT_EQ(transport.getConnectCounter(),
581 Experiments::getInstance()->getFlag(DOT_MAXTRIES_FLAG, DnsTlsQueryMap::kMaxTries));
Mike Yuc52739e2018-10-19 21:06:32 +0800582}
583
584TEST_F(TransportTest, PartialDrop) {
585 FakeSocketLimited::sLimit = 10; // Close the socket after 10 queries.
586 FakeSocketLimited::sMaxSize = SIZE - 2; // Silently drop "long" queries
587 FakeSocketFactory<FakeSocketLimited> factory;
588 DnsTlsTransport transport(SERVER1, MARK, &factory);
589
590 // Queue up 100 queries, alternating "short" which will be served and "long"
591 // which will be dropped.
592 const int num_queries = 10 * FakeSocketLimited::sLimit;
593 std::vector<bytevec> queries(num_queries);
594 std::vector<std::future<DnsTlsTransport::Result>> results;
595 results.reserve(num_queries);
596 for (int i = 0; i < num_queries; ++i) {
597 queries[i] = make_query(i, SIZE + (i % 2));
598 results.push_back(transport.query(makeSlice(queries[i])));
599 }
600 // Just check the short queries, which are at the even indices.
601 for (int i = 0; i < num_queries; i += 2) {
602 auto r = results[i].get();
603 EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
604 EXPECT_EQ(queries[i], r.response);
605 }
Mike Yucb2bb7c2019-11-22 20:42:13 +0800606
607 // TODO: transport.getConnectCounter() seems not stable in this test. Find how to check the
608 // connect attempts for this test.
609}
610
611TEST_F(TransportTest, ConnectCounter) {
612 FakeSocketLimited::sLimit = 2; // Close the socket after 2 queries.
613 FakeSocketLimited::sMaxSize = SIZE; // No query drops.
614 FakeSocketFactory<FakeSocketLimited> factory;
615 DnsTlsTransport transport(SERVER1, MARK, &factory);
616
617 // Connecting on demand.
Mike Yu1fea18c2019-12-06 10:59:17 +0800618 EXPECT_EQ(transport.getConnectCounter(), 0);
Mike Yucb2bb7c2019-11-22 20:42:13 +0800619
620 const int num_queries = 10;
621 std::vector<std::future<DnsTlsTransport::Result>> results;
622 results.reserve(num_queries);
623 for (int i = 0; i < num_queries; i++) {
624 // Reconnections take place every two queries.
625 results.push_back(transport.query(makeSlice(QUERY)));
626 }
627 for (int i = 0; i < num_queries; i++) {
628 auto r = results[i].get();
629 EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
630 }
631
Mike Yu1fea18c2019-12-06 10:59:17 +0800632 EXPECT_EQ(transport.getConnectCounter(), num_queries / FakeSocketLimited::sLimit);
Mike Yuc52739e2018-10-19 21:06:32 +0800633}
634
635// Simulate a malfunctioning server that injects extra miscellaneous
636// responses to queries that were not asked. This will cause wrong answers but
637// must not crash the Transport.
638class FakeSocketGarbage : public IDnsTlsSocket {
639 public:
640 explicit FakeSocketGarbage(IDnsTlsSocketObserver* observer) : mObserver(observer) {
641 // Inject a garbage event.
642 mThreads.emplace_back(&IDnsTlsSocketObserver::onResponse, mObserver, make_query(ID + 1, SIZE));
643 }
644 ~FakeSocketGarbage() {
645 std::lock_guard guard(mLock);
646 for (auto& thread : mThreads) {
647 thread.join();
648 }
649 }
650 bool query(uint16_t id, const Slice query) override {
651 std::lock_guard guard(mLock);
652 // Return the response twice.
653 auto echo = make_echo(id, query);
654 mThreads.emplace_back(&IDnsTlsSocketObserver::onResponse, mObserver, echo);
655 mThreads.emplace_back(&IDnsTlsSocketObserver::onResponse, mObserver, echo);
656 // Also return some other garbage
657 mThreads.emplace_back(&IDnsTlsSocketObserver::onResponse, mObserver, make_query(id + 1, query.size() + 2));
658 return true;
659 }
Mike Yu441d9372020-07-15 17:06:22 +0800660 bool startHandshake() override { return true; }
Mike Yuc52739e2018-10-19 21:06:32 +0800661
662 private:
663 std::mutex mLock;
664 std::vector<std::thread> mThreads GUARDED_BY(mLock);
665 IDnsTlsSocketObserver* const mObserver;
666};
667
668TEST_F(TransportTest, IgnoringGarbage) {
669 FakeSocketFactory<FakeSocketGarbage> factory;
670 DnsTlsTransport transport(SERVER1, MARK, &factory);
671 for (int i = 0; i < 10; ++i) {
672 auto r = transport.query(makeSlice(QUERY)).get();
673
674 EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
675 // Don't check the response because this server is malfunctioning.
676 }
Mike Yu1fea18c2019-12-06 10:59:17 +0800677 EXPECT_EQ(transport.getConnectCounter(), 1);
Mike Yuc52739e2018-10-19 21:06:32 +0800678}
679
680// Dispatcher tests
681class DispatcherTest : public BaseTest {};
682
683TEST_F(DispatcherTest, Query) {
684 bytevec ans(4096);
685 int resplen = 0;
Mike Yucb2bb7c2019-11-22 20:42:13 +0800686 bool connectTriggered = false;
Mike Yuc52739e2018-10-19 21:06:32 +0800687
688 auto factory = std::make_unique<FakeSocketFactory<FakeSocketEcho>>();
689 DnsTlsDispatcher dispatcher(std::move(factory));
Mike Yucb2bb7c2019-11-22 20:42:13 +0800690 auto r = dispatcher.query(SERVER1, MARK, makeSlice(QUERY), makeSlice(ans), &resplen,
691 &connectTriggered);
Mike Yuc52739e2018-10-19 21:06:32 +0800692
693 EXPECT_EQ(DnsTlsTransport::Response::success, r);
694 EXPECT_EQ(int(QUERY.size()), resplen);
Mike Yucb2bb7c2019-11-22 20:42:13 +0800695 EXPECT_TRUE(connectTriggered);
Mike Yuc52739e2018-10-19 21:06:32 +0800696 ans.resize(resplen);
697 EXPECT_EQ(QUERY, ans);
Mike Yucb2bb7c2019-11-22 20:42:13 +0800698
699 // Expect to reuse the connection.
700 r = dispatcher.query(SERVER1, MARK, makeSlice(QUERY), makeSlice(ans), &resplen,
701 &connectTriggered);
702 EXPECT_EQ(DnsTlsTransport::Response::success, r);
703 EXPECT_FALSE(connectTriggered);
Mike Yuc52739e2018-10-19 21:06:32 +0800704}
705
706TEST_F(DispatcherTest, AnswerTooLarge) {
707 bytevec ans(SIZE - 1); // Too small to hold the answer
708 int resplen = 0;
Mike Yucb2bb7c2019-11-22 20:42:13 +0800709 bool connectTriggered = false;
Mike Yuc52739e2018-10-19 21:06:32 +0800710
711 auto factory = std::make_unique<FakeSocketFactory<FakeSocketEcho>>();
712 DnsTlsDispatcher dispatcher(std::move(factory));
Mike Yucb2bb7c2019-11-22 20:42:13 +0800713 auto r = dispatcher.query(SERVER1, MARK, makeSlice(QUERY), makeSlice(ans), &resplen,
714 &connectTriggered);
Mike Yuc52739e2018-10-19 21:06:32 +0800715
716 EXPECT_EQ(DnsTlsTransport::Response::limit_error, r);
Mike Yucb2bb7c2019-11-22 20:42:13 +0800717 EXPECT_TRUE(connectTriggered);
Mike Yuc52739e2018-10-19 21:06:32 +0800718}
719
720template<class T>
721class TrackingFakeSocketFactory : public IDnsTlsSocketFactory {
722 public:
723 TrackingFakeSocketFactory() {}
724 std::unique_ptr<IDnsTlsSocket> createDnsTlsSocket(
725 const DnsTlsServer& server,
726 unsigned mark,
727 IDnsTlsSocketObserver* observer,
728 DnsTlsSessionCache* cache ATTRIBUTE_UNUSED) override {
729 std::lock_guard guard(mLock);
730 keys.emplace(mark, server);
731 return std::make_unique<T>(observer);
732 }
733 std::multiset<std::pair<unsigned, DnsTlsServer>> keys;
734
735 private:
736 std::mutex mLock;
737};
738
739TEST_F(DispatcherTest, Dispatching) {
740 FakeSocketDelay::sDelay = 5;
741 FakeSocketDelay::sReverse = true;
742 auto factory = std::make_unique<TrackingFakeSocketFactory<FakeSocketDelay>>();
743 auto* weak_factory = factory.get(); // Valid as long as dispatcher is in scope.
744 DnsTlsDispatcher dispatcher(std::move(factory));
745
746 // Populate a vector of two servers and two socket marks, four combinations
747 // in total.
748 std::vector<std::pair<unsigned, DnsTlsServer>> keys;
749 keys.emplace_back(MARK, SERVER1);
750 keys.emplace_back(MARK + 1, SERVER1);
751 keys.emplace_back(MARK, V4ADDR2);
752 keys.emplace_back(MARK + 1, V4ADDR2);
753
754 // Do several queries on each server. They should all succeed.
755 std::vector<std::thread> threads;
756 for (size_t i = 0; i < FakeSocketDelay::sDelay * keys.size(); ++i) {
757 auto key = keys[i % keys.size()];
758 threads.emplace_back([key, i] (DnsTlsDispatcher* dispatcher) {
759 auto q = make_query(i, SIZE);
760 bytevec ans(4096);
761 int resplen = 0;
Mike Yucb2bb7c2019-11-22 20:42:13 +0800762 bool connectTriggered = false;
Mike Yuc52739e2018-10-19 21:06:32 +0800763 unsigned mark = key.first;
764 const DnsTlsServer& server = key.second;
Mike Yucb2bb7c2019-11-22 20:42:13 +0800765 auto r = dispatcher->query(server, mark, makeSlice(q), makeSlice(ans), &resplen,
766 &connectTriggered);
Mike Yuc52739e2018-10-19 21:06:32 +0800767 EXPECT_EQ(DnsTlsTransport::Response::success, r);
768 EXPECT_EQ(int(q.size()), resplen);
769 ans.resize(resplen);
770 EXPECT_EQ(q, ans);
771 }, &dispatcher);
772 }
773 for (auto& thread : threads) {
774 thread.join();
775 }
776 // We expect that the factory created one socket for each key.
777 EXPECT_EQ(keys.size(), weak_factory->keys.size());
778 for (auto& key : keys) {
779 EXPECT_EQ(1U, weak_factory->keys.count(key));
780 }
781}
782
783// Check DnsTlsServer's comparison logic.
784AddressComparator ADDRESS_COMPARATOR;
785bool isAddressEqual(const DnsTlsServer& s1, const DnsTlsServer& s2) {
786 bool cmp1 = ADDRESS_COMPARATOR(s1, s2);
787 bool cmp2 = ADDRESS_COMPARATOR(s2, s1);
788 EXPECT_FALSE(cmp1 && cmp2);
789 return !cmp1 && !cmp2;
790}
791
792void checkUnequal(const DnsTlsServer& s1, const DnsTlsServer& s2) {
793 EXPECT_TRUE(s1 == s1);
794 EXPECT_TRUE(s2 == s2);
795 EXPECT_TRUE(isAddressEqual(s1, s1));
796 EXPECT_TRUE(isAddressEqual(s2, s2));
797
798 EXPECT_TRUE(s1 < s2 ^ s2 < s1);
799 EXPECT_FALSE(s1 == s2);
800 EXPECT_FALSE(s2 == s1);
801}
802
Mike Yufa985f72020-11-23 20:24:21 +0800803void checkEqual(const DnsTlsServer& s1, const DnsTlsServer& s2) {
804 EXPECT_TRUE(s1 == s1);
805 EXPECT_TRUE(s2 == s2);
806 EXPECT_TRUE(isAddressEqual(s1, s1));
807 EXPECT_TRUE(isAddressEqual(s2, s2));
808
809 EXPECT_FALSE(s1 < s2);
810 EXPECT_FALSE(s2 < s1);
811 EXPECT_TRUE(s1 == s2);
812 EXPECT_TRUE(s2 == s1);
813}
814
Mike Yuc52739e2018-10-19 21:06:32 +0800815class ServerTest : public BaseTest {};
816
817TEST_F(ServerTest, IPv4) {
818 checkUnequal(V4ADDR1, V4ADDR2);
819 EXPECT_FALSE(isAddressEqual(V4ADDR1, V4ADDR2));
820}
821
822TEST_F(ServerTest, IPv6) {
823 checkUnequal(V6ADDR1, V6ADDR2);
824 EXPECT_FALSE(isAddressEqual(V6ADDR1, V6ADDR2));
825}
826
827TEST_F(ServerTest, MixedAddressFamily) {
828 checkUnequal(V6ADDR1, V4ADDR1);
829 EXPECT_FALSE(isAddressEqual(V6ADDR1, V4ADDR1));
830}
831
832TEST_F(ServerTest, IPv6ScopeId) {
833 DnsTlsServer s1(V6ADDR1), s2(V6ADDR1);
834 sockaddr_in6* addr1 = reinterpret_cast<sockaddr_in6*>(&s1.ss);
835 addr1->sin6_scope_id = 1;
836 sockaddr_in6* addr2 = reinterpret_cast<sockaddr_in6*>(&s2.ss);
837 addr2->sin6_scope_id = 2;
838 checkUnequal(s1, s2);
839 EXPECT_FALSE(isAddressEqual(s1, s2));
840
841 EXPECT_FALSE(s1.wasExplicitlyConfigured());
842 EXPECT_FALSE(s2.wasExplicitlyConfigured());
843}
844
845TEST_F(ServerTest, IPv6FlowInfo) {
846 DnsTlsServer s1(V6ADDR1), s2(V6ADDR1);
847 sockaddr_in6* addr1 = reinterpret_cast<sockaddr_in6*>(&s1.ss);
848 addr1->sin6_flowinfo = 1;
849 sockaddr_in6* addr2 = reinterpret_cast<sockaddr_in6*>(&s2.ss);
850 addr2->sin6_flowinfo = 2;
851 // All comparisons ignore flowinfo.
852 EXPECT_EQ(s1, s2);
853 EXPECT_TRUE(isAddressEqual(s1, s2));
854
855 EXPECT_FALSE(s1.wasExplicitlyConfigured());
856 EXPECT_FALSE(s2.wasExplicitlyConfigured());
857}
858
859TEST_F(ServerTest, Port) {
860 DnsTlsServer s1, s2;
861 parseServer("192.0.2.1", 853, &s1.ss);
862 parseServer("192.0.2.1", 854, &s2.ss);
863 checkUnequal(s1, s2);
864 EXPECT_TRUE(isAddressEqual(s1, s2));
865
866 DnsTlsServer s3, s4;
867 parseServer("2001:db8::1", 853, &s3.ss);
868 parseServer("2001:db8::1", 852, &s4.ss);
869 checkUnequal(s3, s4);
870 EXPECT_TRUE(isAddressEqual(s3, s4));
871
872 EXPECT_FALSE(s1.wasExplicitlyConfigured());
873 EXPECT_FALSE(s2.wasExplicitlyConfigured());
874}
875
876TEST_F(ServerTest, Name) {
877 DnsTlsServer s1(V4ADDR1), s2(V4ADDR1);
878 s1.name = SERVERNAME1;
879 checkUnequal(s1, s2);
880 s2.name = SERVERNAME2;
881 checkUnequal(s1, s2);
882 EXPECT_TRUE(isAddressEqual(s1, s2));
883
884 EXPECT_TRUE(s1.wasExplicitlyConfigured());
885 EXPECT_TRUE(s2.wasExplicitlyConfigured());
886}
887
Mike Yufa985f72020-11-23 20:24:21 +0800888TEST_F(ServerTest, State) {
889 DnsTlsServer s1(V4ADDR1), s2(V4ADDR1);
890 checkEqual(s1, s2);
891 s1.setValidationState(Validation::success);
892 checkEqual(s1, s2);
893 s2.setValidationState(Validation::fail);
894 checkEqual(s1, s2);
895
896 EXPECT_EQ(s1.validationState(), Validation::success);
897 EXPECT_EQ(s2.validationState(), Validation::fail);
898}
899
Mike Yuc52739e2018-10-19 21:06:32 +0800900TEST(QueryMapTest, Basic) {
901 DnsTlsQueryMap map;
902
903 EXPECT_TRUE(map.empty());
904
905 bytevec q0 = make_query(999, SIZE);
906 bytevec q1 = make_query(888, SIZE);
907 bytevec q2 = make_query(777, SIZE);
908
909 auto f0 = map.recordQuery(makeSlice(q0));
910 auto f1 = map.recordQuery(makeSlice(q1));
911 auto f2 = map.recordQuery(makeSlice(q2));
912
913 // Check return values of recordQuery
914 EXPECT_EQ(0, f0->query.newId);
915 EXPECT_EQ(1, f1->query.newId);
916 EXPECT_EQ(2, f2->query.newId);
917
918 // Check side effects of recordQuery
919 EXPECT_FALSE(map.empty());
920
921 auto all = map.getAll();
922 EXPECT_EQ(3U, all.size());
923
924 EXPECT_EQ(0, all[0].newId);
925 EXPECT_EQ(1, all[1].newId);
926 EXPECT_EQ(2, all[2].newId);
927
Mike Yu7e08b852019-10-18 18:27:43 +0800928 EXPECT_EQ(q0, all[0].query);
929 EXPECT_EQ(q1, all[1].query);
930 EXPECT_EQ(q2, all[2].query);
Mike Yuc52739e2018-10-19 21:06:32 +0800931
932 bytevec a0 = make_query(0, SIZE);
933 bytevec a1 = make_query(1, SIZE);
934 bytevec a2 = make_query(2, SIZE);
935
936 // Return responses out of order
937 map.onResponse(a2);
938 map.onResponse(a0);
939 map.onResponse(a1);
940
941 EXPECT_TRUE(map.empty());
942
943 auto r0 = f0->result.get();
944 auto r1 = f1->result.get();
945 auto r2 = f2->result.get();
946
947 EXPECT_EQ(DnsTlsQueryMap::Response::success, r0.code);
948 EXPECT_EQ(DnsTlsQueryMap::Response::success, r1.code);
949 EXPECT_EQ(DnsTlsQueryMap::Response::success, r2.code);
950
951 const bytevec& d0 = r0.response;
952 const bytevec& d1 = r1.response;
953 const bytevec& d2 = r2.response;
954
955 // The ID should match the query
956 EXPECT_EQ(999, d0[0] << 8 | d0[1]);
957 EXPECT_EQ(888, d1[0] << 8 | d1[1]);
958 EXPECT_EQ(777, d2[0] << 8 | d2[1]);
959 // The body should match the answer
960 EXPECT_EQ(bytevec(a0.begin() + 2, a0.end()), bytevec(d0.begin() + 2, d0.end()));
961 EXPECT_EQ(bytevec(a1.begin() + 2, a1.end()), bytevec(d1.begin() + 2, d1.end()));
962 EXPECT_EQ(bytevec(a2.begin() + 2, a2.end()), bytevec(d2.begin() + 2, d2.end()));
963}
964
965TEST(QueryMapTest, FillHole) {
966 DnsTlsQueryMap map;
967 std::vector<std::unique_ptr<DnsTlsQueryMap::QueryFuture>> futures(UINT16_MAX + 1);
968 for (uint32_t i = 0; i <= UINT16_MAX; ++i) {
969 futures[i] = map.recordQuery(makeSlice(QUERY));
970 ASSERT_TRUE(futures[i]); // answers[i] should be nonnull.
971 EXPECT_EQ(i, futures[i]->query.newId);
972 }
973
974 // The map should now be full.
975 EXPECT_EQ(size_t(UINT16_MAX + 1), map.getAll().size());
976
977 // Trying to add another query should fail because the map is full.
978 EXPECT_FALSE(map.recordQuery(makeSlice(QUERY)));
979
980 // Send an answer to query 40000
981 auto answer = make_query(40000, SIZE);
982 map.onResponse(answer);
983 auto result = futures[40000]->result.get();
984 EXPECT_EQ(DnsTlsQueryMap::Response::success, result.code);
985 EXPECT_EQ(ID, result.response[0] << 8 | result.response[1]);
986 EXPECT_EQ(bytevec(answer.begin() + 2, answer.end()),
987 bytevec(result.response.begin() + 2, result.response.end()));
988
989 // There should now be room in the map.
990 EXPECT_EQ(size_t(UINT16_MAX), map.getAll().size());
991 auto f = map.recordQuery(makeSlice(QUERY));
992 ASSERT_TRUE(f);
993 EXPECT_EQ(40000, f->query.newId);
994
995 // The map should now be full again.
996 EXPECT_EQ(size_t(UINT16_MAX + 1), map.getAll().size());
997 EXPECT_FALSE(map.recordQuery(makeSlice(QUERY)));
998}
999
Mike Yue93d9ae2020-08-25 19:09:51 +08001000class DnsTlsSocketTest : public ::testing::Test {
1001 protected:
1002 class MockDnsTlsSocketObserver : public IDnsTlsSocketObserver {
1003 public:
1004 MOCK_METHOD(void, onClosed, (), (override));
1005 MOCK_METHOD(void, onResponse, (std::vector<uint8_t>), (override));
1006 };
Ben Schwartz62176fd2019-01-22 17:32:17 -05001007
Mike Yue93d9ae2020-08-25 19:09:51 +08001008 DnsTlsSocketTest() { parseServer(kTlsAddr, std::stoi(kTlsPort), &server.ss); }
Ben Schwartz62176fd2019-01-22 17:32:17 -05001009
Mike Yue93d9ae2020-08-25 19:09:51 +08001010 std::unique_ptr<DnsTlsSocket> makeDnsTlsSocket(IDnsTlsSocketObserver* observer) {
1011 return std::make_unique<DnsTlsSocket>(this->server, MARK, observer, &this->cache);
1012 }
1013
1014 void enableAsyncHandshake(const std::unique_ptr<DnsTlsSocket>& socket) {
1015 ASSERT_TRUE(socket);
1016 DnsTlsSocket* delegate = socket.get();
1017 std::lock_guard guard(delegate->mLock);
1018 delegate->mAsyncHandshake = true;
1019 }
1020
1021 static constexpr char kTlsAddr[] = "127.0.0.3";
1022 static constexpr char kTlsPort[] = "8530"; // High-numbered port so root isn't required.
1023 static constexpr char kBackendAddr[] = "192.0.2.1";
1024 static constexpr char kBackendPort[] = "8531"; // High-numbered port so root isn't required.
1025
1026 test::DnsTlsFrontend tls{kTlsAddr, kTlsPort, kBackendAddr, kBackendPort};
Ben Schwartz62176fd2019-01-22 17:32:17 -05001027
1028 DnsTlsServer server;
Mike Yue93d9ae2020-08-25 19:09:51 +08001029 DnsTlsSessionCache cache;
1030};
1031
1032TEST_F(DnsTlsSocketTest, SlowDestructor) {
1033 ASSERT_TRUE(tls.startServer());
Ben Schwartz62176fd2019-01-22 17:32:17 -05001034
Mike Yu1b9069c2020-08-25 15:17:29 +08001035 MockDnsTlsSocketObserver observer;
Mike Yue93d9ae2020-08-25 19:09:51 +08001036 auto socket = makeDnsTlsSocket(&observer);
1037
Ben Schwartz62176fd2019-01-22 17:32:17 -05001038 ASSERT_TRUE(socket->initialize());
Mike Yu441d9372020-07-15 17:06:22 +08001039 ASSERT_TRUE(socket->startHandshake());
Ben Schwartz62176fd2019-01-22 17:32:17 -05001040
1041 // Test: Time the socket destructor. This should be fast.
1042 auto before = std::chrono::steady_clock::now();
Mike Yu1b9069c2020-08-25 15:17:29 +08001043 EXPECT_CALL(observer, onClosed);
Ben Schwartz62176fd2019-01-22 17:32:17 -05001044 socket.reset();
1045 auto after = std::chrono::steady_clock::now();
1046 auto delay = after - before;
chenbruceaff85842019-05-31 15:46:42 +08001047 LOG(DEBUG) << "Shutdown took " << delay / std::chrono::nanoseconds{1} << "ns";
Ben Schwartz62176fd2019-01-22 17:32:17 -05001048 // Shutdown should complete in milliseconds, but if the shutdown signal is lost
1049 // it will wait for the timeout, which is expected to take 20seconds.
1050 EXPECT_LT(delay, std::chrono::seconds{5});
1051}
1052
Mike Yue93d9ae2020-08-25 19:09:51 +08001053TEST_F(DnsTlsSocketTest, StartHandshake) {
Mike Yu441d9372020-07-15 17:06:22 +08001054 ASSERT_TRUE(tls.startServer());
1055
Mike Yue93d9ae2020-08-25 19:09:51 +08001056 MockDnsTlsSocketObserver observer;
1057 auto socket = makeDnsTlsSocket(&observer);
Mike Yu441d9372020-07-15 17:06:22 +08001058
1059 // Call the function before the call to initialize().
1060 EXPECT_FALSE(socket->startHandshake());
1061
1062 // Call the function after the call to initialize().
1063 EXPECT_TRUE(socket->initialize());
1064 EXPECT_TRUE(socket->startHandshake());
1065
1066 // Call both of them again.
1067 EXPECT_FALSE(socket->initialize());
1068 EXPECT_FALSE(socket->startHandshake());
Mike Yue93d9ae2020-08-25 19:09:51 +08001069
1070 // Should happen when joining the loop thread in |socket| destruction.
1071 EXPECT_CALL(observer, onClosed);
1072}
1073
1074TEST_F(DnsTlsSocketTest, ShutdownSignal) {
1075 ASSERT_TRUE(tls.startServer());
1076
1077 MockDnsTlsSocketObserver observer;
1078 std::unique_ptr<DnsTlsSocket> socket;
1079
1080 const auto setupAndStartHandshake = [&]() {
1081 socket = makeDnsTlsSocket(&observer);
1082 EXPECT_TRUE(socket->initialize());
1083 enableAsyncHandshake(socket);
1084 EXPECT_TRUE(socket->startHandshake());
1085 };
1086 const auto triggerShutdown = [&](const std::string& traceLog) {
1087 SCOPED_TRACE(traceLog);
1088 auto before = std::chrono::steady_clock::now();
1089 EXPECT_CALL(observer, onClosed);
1090 socket.reset();
1091 auto after = std::chrono::steady_clock::now();
1092 auto delay = after - before;
1093 LOG(INFO) << "Shutdown took " << delay / std::chrono::nanoseconds{1} << "ns";
1094 EXPECT_LT(delay, std::chrono::seconds{1});
1095 };
1096
1097 tls.setHangOnHandshakeForTesting(true);
1098
1099 // Test 1: Reset the DnsTlsSocket which is doing the handshake.
1100 setupAndStartHandshake();
1101 triggerShutdown("Shutdown handshake w/o query requests");
1102
1103 // Test 2: Reset the DnsTlsSocket which is doing the handshake with some query requests.
1104 setupAndStartHandshake();
1105
1106 // DnsTlsSocket doesn't report the status of pending queries. The decision whether to mark
1107 // a query request as failed or not is made in DnsTlsTransport.
1108 EXPECT_CALL(observer, onResponse).Times(0);
1109 EXPECT_TRUE(socket->query(1, makeSlice(QUERY)));
1110 EXPECT_TRUE(socket->query(2, makeSlice(QUERY)));
1111 triggerShutdown("Shutdown handshake w/ query requests");
Mike Yu441d9372020-07-15 17:06:22 +08001112}
1113
Mike Yuc52739e2018-10-19 21:06:32 +08001114} // end of namespace net
1115} // end of namespace android