blob: de2a45e840fa57909599a271c4b6e818603fa6af [file] [log] [blame]
Mike Yuc52739e2018-10-19 21:06:32 +08001/*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
Ken Chen5471dca2019-04-15 15:25:35 +080017#define LOG_TAG "resolv"
Mike Yuc52739e2018-10-19 21:06:32 +080018
chenbruceaff85842019-05-31 15:46:42 +080019#include <arpa/inet.h>
20
21#include <chrono>
22
23#include <android-base/logging.h>
24#include <android-base/macros.h>
Mike Yu1b9069c2020-08-25 15:17:29 +080025#include <gmock/gmock.h>
Mike Yuc52739e2018-10-19 21:06:32 +080026#include <gtest/gtest.h>
chenbruceaff85842019-05-31 15:46:42 +080027#include <netdutils/Slice.h>
Mike Yuc52739e2018-10-19 21:06:32 +080028
Bernie Innocentiec4219b2019-01-30 11:16:36 +090029#include "DnsTlsDispatcher.h"
30#include "DnsTlsQueryMap.h"
31#include "DnsTlsServer.h"
32#include "DnsTlsSessionCache.h"
33#include "DnsTlsSocket.h"
34#include "DnsTlsTransport.h"
35#include "IDnsTlsSocket.h"
36#include "IDnsTlsSocketFactory.h"
37#include "IDnsTlsSocketObserver.h"
chenbruceb43ec752019-07-24 20:19:41 +080038#include "tests/dns_responder/dns_tls_frontend.h"
Ben Schwartz62176fd2019-01-22 17:32:17 -050039
Mike Yuc52739e2018-10-19 21:06:32 +080040namespace android {
41namespace net {
42
Mike Yuc52739e2018-10-19 21:06:32 +080043using netdutils::makeSlice;
Mike Yu1b9069c2020-08-25 15:17:29 +080044using netdutils::Slice;
Mike Yuc52739e2018-10-19 21:06:32 +080045
46typedef std::vector<uint8_t> bytevec;
47
48static void parseServer(const char* server, in_port_t port, sockaddr_storage* parsed) {
49 sockaddr_in* sin = reinterpret_cast<sockaddr_in*>(parsed);
50 if (inet_pton(AF_INET, server, &(sin->sin_addr)) == 1) {
51 // IPv4 parse succeeded, so it's IPv4
52 sin->sin_family = AF_INET;
53 sin->sin_port = htons(port);
54 return;
55 }
56 sockaddr_in6* sin6 = reinterpret_cast<sockaddr_in6*>(parsed);
57 if (inet_pton(AF_INET6, server, &(sin6->sin6_addr)) == 1){
58 // IPv6 parse succeeded, so it's IPv6.
59 sin6->sin6_family = AF_INET6;
60 sin6->sin6_port = htons(port);
61 return;
62 }
chenbruceaff85842019-05-31 15:46:42 +080063 LOG(ERROR) << "Failed to parse server address: " << server;
Mike Yuc52739e2018-10-19 21:06:32 +080064}
65
// Server hostnames available to the tests.
std::string SERVERNAME1{"dns.example.com"};
std::string SERVERNAME2{"dns.example.org"};
68
69// BaseTest just provides constants that are useful for the tests.
70class BaseTest : public ::testing::Test {
71 protected:
72 BaseTest() {
73 parseServer("192.0.2.1", 853, &V4ADDR1);
74 parseServer("192.0.2.2", 853, &V4ADDR2);
75 parseServer("2001:db8::1", 853, &V6ADDR1);
76 parseServer("2001:db8::2", 853, &V6ADDR2);
77
78 SERVER1 = DnsTlsServer(V4ADDR1);
Mike Yuc52739e2018-10-19 21:06:32 +080079 SERVER1.name = SERVERNAME1;
80 }
81
82 sockaddr_storage V4ADDR1;
83 sockaddr_storage V4ADDR2;
84 sockaddr_storage V6ADDR1;
85 sockaddr_storage V6ADDR2;
86
87 DnsTlsServer SERVER1;
88};
89
90bytevec make_query(uint16_t id, size_t size) {
91 bytevec vec(size);
92 vec[0] = id >> 8;
93 vec[1] = id;
94 // Arbitrarily fill the query body with unique data.
95 for (size_t i = 2; i < size; ++i) {
96 vec[i] = id + i;
97 }
98 return vec;
99}
100
101// Query constants
102const unsigned MARK = 123;
103const uint16_t ID = 52;
104const uint16_t SIZE = 22;
105const bytevec QUERY = make_query(ID, SIZE);
106
107template <class T>
108class FakeSocketFactory : public IDnsTlsSocketFactory {
109 public:
110 FakeSocketFactory() {}
111 std::unique_ptr<IDnsTlsSocket> createDnsTlsSocket(
112 const DnsTlsServer& server ATTRIBUTE_UNUSED,
113 unsigned mark ATTRIBUTE_UNUSED,
114 IDnsTlsSocketObserver* observer,
115 DnsTlsSessionCache* cache ATTRIBUTE_UNUSED) override {
116 return std::make_unique<T>(observer);
117 }
118};
119
120bytevec make_echo(uint16_t id, const Slice query) {
121 bytevec response(query.size() + 2);
122 response[0] = id >> 8;
123 response[1] = id;
124 // Echo the query as the fake response.
125 memcpy(response.data() + 2, query.base(), query.size());
126 return response;
127}
128
129// Simplest possible fake server. This just echoes the query as the response.
130class FakeSocketEcho : public IDnsTlsSocket {
131 public:
132 explicit FakeSocketEcho(IDnsTlsSocketObserver* observer) : mObserver(observer) {}
133 bool query(uint16_t id, const Slice query) override {
134 // Return the response immediately (asynchronously).
135 std::thread(&IDnsTlsSocketObserver::onResponse, mObserver, make_echo(id, query)).detach();
136 return true;
137 }
Mike Yu441d9372020-07-15 17:06:22 +0800138 bool startHandshake() override { return true; }
Mike Yuc52739e2018-10-19 21:06:32 +0800139
140 private:
141 IDnsTlsSocketObserver* const mObserver;
142};
143
144class TransportTest : public BaseTest {};
145
146TEST_F(TransportTest, Query) {
147 FakeSocketFactory<FakeSocketEcho> factory;
148 DnsTlsTransport transport(SERVER1, MARK, &factory);
149 auto r = transport.query(makeSlice(QUERY)).get();
150
151 EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
152 EXPECT_EQ(QUERY, r.response);
Mike Yu1fea18c2019-12-06 10:59:17 +0800153 EXPECT_EQ(transport.getConnectCounter(), 1);
Mike Yuc52739e2018-10-19 21:06:32 +0800154}
155
156// Fake Socket that echoes the observed query ID as the response body.
157class FakeSocketId : public IDnsTlsSocket {
158 public:
159 explicit FakeSocketId(IDnsTlsSocketObserver* observer) : mObserver(observer) {}
160 bool query(uint16_t id, const Slice query ATTRIBUTE_UNUSED) override {
161 // Return the response immediately (asynchronously).
162 bytevec response(4);
163 // Echo the ID in the header to match the response to the query.
164 // This will be overwritten by DnsTlsQueryMap.
165 response[0] = id >> 8;
166 response[1] = id;
167 // Echo the ID in the body, so that the test can verify which ID was used by
168 // DnsTlsQueryMap.
169 response[2] = id >> 8;
170 response[3] = id;
171 std::thread(&IDnsTlsSocketObserver::onResponse, mObserver, response).detach();
172 return true;
173 }
Mike Yu441d9372020-07-15 17:06:22 +0800174 bool startHandshake() override { return true; }
Mike Yuc52739e2018-10-19 21:06:32 +0800175
176 private:
177 IDnsTlsSocketObserver* const mObserver;
178};
179
180// Test that IDs are properly reused
181TEST_F(TransportTest, IdReuse) {
182 FakeSocketFactory<FakeSocketId> factory;
183 DnsTlsTransport transport(SERVER1, MARK, &factory);
184 for (int i = 0; i < 100; ++i) {
185 // Send a query.
Mike Yubd136992019-12-04 15:01:07 +0800186 std::future<DnsTlsTransport::Result> f = transport.query(makeSlice(QUERY));
Mike Yuc52739e2018-10-19 21:06:32 +0800187 // Wait for the response.
Mike Yubd136992019-12-04 15:01:07 +0800188 DnsTlsTransport::Result r = f.get();
Mike Yuc52739e2018-10-19 21:06:32 +0800189 EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
190
191 // All queries should have an observed ID of zero, because it is returned to the ID pool
192 // after each use.
193 EXPECT_EQ(0, (r.response[2] << 8) | r.response[3]);
194 }
Mike Yu1fea18c2019-12-06 10:59:17 +0800195 EXPECT_EQ(transport.getConnectCounter(), 1);
Mike Yuc52739e2018-10-19 21:06:32 +0800196}
197
198// These queries might be handled in serial or parallel as they race the
199// responses.
200TEST_F(TransportTest, RacingQueries_10000) {
201 FakeSocketFactory<FakeSocketEcho> factory;
202 DnsTlsTransport transport(SERVER1, MARK, &factory);
203 std::vector<std::future<DnsTlsTransport::Result>> results;
204 // Fewer than 65536 queries to avoid ID exhaustion.
205 const int num_queries = 10000;
206 results.reserve(num_queries);
207 for (int i = 0; i < num_queries; ++i) {
208 results.push_back(transport.query(makeSlice(QUERY)));
209 }
210 for (auto& result : results) {
211 auto r = result.get();
212 EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
213 EXPECT_EQ(QUERY, r.response);
214 }
Mike Yu1fea18c2019-12-06 10:59:17 +0800215 EXPECT_EQ(transport.getConnectCounter(), 1);
Mike Yuc52739e2018-10-19 21:06:32 +0800216}
217
218// A server that waits until sDelay queries are queued before responding.
219class FakeSocketDelay : public IDnsTlsSocket {
220 public:
221 explicit FakeSocketDelay(IDnsTlsSocketObserver* observer) : mObserver(observer) {}
Mike Yu441d9372020-07-15 17:06:22 +0800222 ~FakeSocketDelay() {
223 std::lock_guard guard(mLock);
224 sDelay = 1;
225 sReverse = false;
226 sConnectable = true;
227 }
228 inline static size_t sDelay = 1;
229 inline static bool sReverse = false;
230 inline static bool sConnectable = true;
Mike Yuc52739e2018-10-19 21:06:32 +0800231
232 bool query(uint16_t id, const Slice query) override {
chenbruceaff85842019-05-31 15:46:42 +0800233 LOG(DEBUG) << "FakeSocketDelay got query with ID " << int(id);
Mike Yuc52739e2018-10-19 21:06:32 +0800234 std::lock_guard guard(mLock);
235 // Check for duplicate IDs.
236 EXPECT_EQ(0U, mIds.count(id));
237 mIds.insert(id);
238
239 // Store response.
240 mResponses.push_back(make_echo(id, query));
241
chenbruceaff85842019-05-31 15:46:42 +0800242 LOG(DEBUG) << "Up to " << mResponses.size() << " out of " << sDelay << " queries";
Mike Yuc52739e2018-10-19 21:06:32 +0800243 if (mResponses.size() == sDelay) {
244 std::thread(&FakeSocketDelay::sendResponses, this).detach();
245 }
246 return true;
247 }
Mike Yu441d9372020-07-15 17:06:22 +0800248 bool startHandshake() override { return sConnectable; }
Mike Yuc52739e2018-10-19 21:06:32 +0800249
250 private:
251 void sendResponses() {
252 std::lock_guard guard(mLock);
253 if (sReverse) {
254 std::reverse(std::begin(mResponses), std::end(mResponses));
255 }
256 for (auto& response : mResponses) {
257 mObserver->onResponse(response);
258 }
259 mIds.clear();
260 mResponses.clear();
261 }
262
263 std::mutex mLock;
264 IDnsTlsSocketObserver* const mObserver;
265 std::set<uint16_t> mIds GUARDED_BY(mLock);
266 std::vector<bytevec> mResponses GUARDED_BY(mLock);
267};
268
Mike Yuc52739e2018-10-19 21:06:32 +0800269TEST_F(TransportTest, ParallelColliding) {
270 FakeSocketDelay::sDelay = 10;
271 FakeSocketDelay::sReverse = false;
272 FakeSocketFactory<FakeSocketDelay> factory;
273 DnsTlsTransport transport(SERVER1, MARK, &factory);
274 std::vector<std::future<DnsTlsTransport::Result>> results;
275 // Fewer than 65536 queries to avoid ID exhaustion.
276 results.reserve(FakeSocketDelay::sDelay);
277 for (size_t i = 0; i < FakeSocketDelay::sDelay; ++i) {
278 results.push_back(transport.query(makeSlice(QUERY)));
279 }
280 for (auto& result : results) {
281 auto r = result.get();
282 EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
283 EXPECT_EQ(QUERY, r.response);
284 }
Mike Yu1fea18c2019-12-06 10:59:17 +0800285 EXPECT_EQ(transport.getConnectCounter(), 1);
Mike Yuc52739e2018-10-19 21:06:32 +0800286}
287
288TEST_F(TransportTest, ParallelColliding_Max) {
289 FakeSocketDelay::sDelay = 65536;
290 FakeSocketDelay::sReverse = false;
291 FakeSocketFactory<FakeSocketDelay> factory;
292 DnsTlsTransport transport(SERVER1, MARK, &factory);
293 std::vector<std::future<DnsTlsTransport::Result>> results;
294 // Exactly 65536 queries should still be possible in parallel,
295 // even if they all have the same original ID.
296 results.reserve(FakeSocketDelay::sDelay);
297 for (size_t i = 0; i < FakeSocketDelay::sDelay; ++i) {
298 results.push_back(transport.query(makeSlice(QUERY)));
299 }
300 for (auto& result : results) {
301 auto r = result.get();
302 EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
303 EXPECT_EQ(QUERY, r.response);
304 }
Mike Yu1fea18c2019-12-06 10:59:17 +0800305 EXPECT_EQ(transport.getConnectCounter(), 1);
Mike Yuc52739e2018-10-19 21:06:32 +0800306}
307
308TEST_F(TransportTest, ParallelUnique) {
309 FakeSocketDelay::sDelay = 10;
310 FakeSocketDelay::sReverse = false;
311 FakeSocketFactory<FakeSocketDelay> factory;
312 DnsTlsTransport transport(SERVER1, MARK, &factory);
313 std::vector<bytevec> queries(FakeSocketDelay::sDelay);
314 std::vector<std::future<DnsTlsTransport::Result>> results;
315 results.reserve(FakeSocketDelay::sDelay);
316 for (size_t i = 0; i < FakeSocketDelay::sDelay; ++i) {
317 queries[i] = make_query(i, SIZE);
318 results.push_back(transport.query(makeSlice(queries[i])));
319 }
320 for (size_t i = 0 ; i < FakeSocketDelay::sDelay; ++i) {
321 auto r = results[i].get();
322 EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
323 EXPECT_EQ(queries[i], r.response);
324 }
Mike Yu1fea18c2019-12-06 10:59:17 +0800325 EXPECT_EQ(transport.getConnectCounter(), 1);
Mike Yuc52739e2018-10-19 21:06:32 +0800326}
327
328TEST_F(TransportTest, ParallelUnique_Max) {
329 FakeSocketDelay::sDelay = 65536;
330 FakeSocketDelay::sReverse = false;
331 FakeSocketFactory<FakeSocketDelay> factory;
332 DnsTlsTransport transport(SERVER1, MARK, &factory);
333 std::vector<bytevec> queries(FakeSocketDelay::sDelay);
334 std::vector<std::future<DnsTlsTransport::Result>> results;
335 // Exactly 65536 queries should still be possible in parallel,
336 // and they should all be mapped correctly back to the original ID.
337 results.reserve(FakeSocketDelay::sDelay);
338 for (size_t i = 0; i < FakeSocketDelay::sDelay; ++i) {
339 queries[i] = make_query(i, SIZE);
340 results.push_back(transport.query(makeSlice(queries[i])));
341 }
342 for (size_t i = 0 ; i < FakeSocketDelay::sDelay; ++i) {
343 auto r = results[i].get();
344 EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
345 EXPECT_EQ(queries[i], r.response);
346 }
Mike Yu1fea18c2019-12-06 10:59:17 +0800347 EXPECT_EQ(transport.getConnectCounter(), 1);
Mike Yuc52739e2018-10-19 21:06:32 +0800348}
349
350TEST_F(TransportTest, IdExhaustion) {
351 const int num_queries = 65536;
352 // A delay of 65537 is unreachable, because the maximum number
353 // of outstanding queries is 65536.
354 FakeSocketDelay::sDelay = num_queries + 1;
355 FakeSocketDelay::sReverse = false;
356 FakeSocketFactory<FakeSocketDelay> factory;
357 DnsTlsTransport transport(SERVER1, MARK, &factory);
358 std::vector<std::future<DnsTlsTransport::Result>> results;
359 // Issue the maximum number of queries.
360 results.reserve(num_queries);
361 for (int i = 0; i < num_queries; ++i) {
362 results.push_back(transport.query(makeSlice(QUERY)));
363 }
364
365 // The ID space is now full, so subsequent queries should fail immediately.
366 auto r = transport.query(makeSlice(QUERY)).get();
367 EXPECT_EQ(DnsTlsTransport::Response::internal_error, r.code);
368 EXPECT_TRUE(r.response.empty());
369
370 for (auto& result : results) {
371 // All other queries should remain outstanding.
372 EXPECT_EQ(std::future_status::timeout,
373 result.wait_for(std::chrono::duration<int>::zero()));
374 }
Mike Yu1fea18c2019-12-06 10:59:17 +0800375 EXPECT_EQ(transport.getConnectCounter(), 1);
Mike Yuc52739e2018-10-19 21:06:32 +0800376}
377
378// Responses can come back from the server in any order. This should have no
379// effect on Transport's observed behavior.
380TEST_F(TransportTest, ReverseOrder) {
381 FakeSocketDelay::sDelay = 10;
382 FakeSocketDelay::sReverse = true;
383 FakeSocketFactory<FakeSocketDelay> factory;
384 DnsTlsTransport transport(SERVER1, MARK, &factory);
385 std::vector<bytevec> queries(FakeSocketDelay::sDelay);
386 std::vector<std::future<DnsTlsTransport::Result>> results;
387 results.reserve(FakeSocketDelay::sDelay);
388 for (size_t i = 0; i < FakeSocketDelay::sDelay; ++i) {
389 queries[i] = make_query(i, SIZE);
390 results.push_back(transport.query(makeSlice(queries[i])));
391 }
392 for (size_t i = 0 ; i < FakeSocketDelay::sDelay; ++i) {
393 auto r = results[i].get();
394 EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
395 EXPECT_EQ(queries[i], r.response);
396 }
Mike Yu1fea18c2019-12-06 10:59:17 +0800397 EXPECT_EQ(transport.getConnectCounter(), 1);
Mike Yuc52739e2018-10-19 21:06:32 +0800398}
399
400TEST_F(TransportTest, ReverseOrder_Max) {
401 FakeSocketDelay::sDelay = 65536;
402 FakeSocketDelay::sReverse = true;
403 FakeSocketFactory<FakeSocketDelay> factory;
404 DnsTlsTransport transport(SERVER1, MARK, &factory);
405 std::vector<bytevec> queries(FakeSocketDelay::sDelay);
406 std::vector<std::future<DnsTlsTransport::Result>> results;
407 results.reserve(FakeSocketDelay::sDelay);
408 for (size_t i = 0; i < FakeSocketDelay::sDelay; ++i) {
409 queries[i] = make_query(i, SIZE);
410 results.push_back(transport.query(makeSlice(queries[i])));
411 }
412 for (size_t i = 0 ; i < FakeSocketDelay::sDelay; ++i) {
413 auto r = results[i].get();
414 EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
415 EXPECT_EQ(queries[i], r.response);
416 }
Mike Yu1fea18c2019-12-06 10:59:17 +0800417 EXPECT_EQ(transport.getConnectCounter(), 1);
Mike Yuc52739e2018-10-19 21:06:32 +0800418}
419
420// Returning null from the factory indicates a connection failure.
421class NullSocketFactory : public IDnsTlsSocketFactory {
422 public:
423 NullSocketFactory() {}
424 std::unique_ptr<IDnsTlsSocket> createDnsTlsSocket(
425 const DnsTlsServer& server ATTRIBUTE_UNUSED,
426 unsigned mark ATTRIBUTE_UNUSED,
427 IDnsTlsSocketObserver* observer ATTRIBUTE_UNUSED,
428 DnsTlsSessionCache* cache ATTRIBUTE_UNUSED) override {
429 return nullptr;
430 }
431};
432
433TEST_F(TransportTest, ConnectFail) {
Mike Yu441d9372020-07-15 17:06:22 +0800434 // Failure on creating socket.
435 NullSocketFactory factory1;
436 DnsTlsTransport transport1(SERVER1, MARK, &factory1);
437 auto r = transport1.query(makeSlice(QUERY)).get();
Mike Yuc52739e2018-10-19 21:06:32 +0800438
439 EXPECT_EQ(DnsTlsTransport::Response::network_error, r.code);
440 EXPECT_TRUE(r.response.empty());
Mike Yu441d9372020-07-15 17:06:22 +0800441 EXPECT_EQ(transport1.getConnectCounter(), 1);
442
443 // Failure on handshaking.
444 FakeSocketDelay::sConnectable = false;
445 FakeSocketFactory<FakeSocketDelay> factory2;
446 DnsTlsTransport transport2(SERVER1, MARK, &factory2);
447 r = transport2.query(makeSlice(QUERY)).get();
448
449 EXPECT_EQ(DnsTlsTransport::Response::network_error, r.code);
450 EXPECT_TRUE(r.response.empty());
451 EXPECT_EQ(transport2.getConnectCounter(), 1);
Mike Yuc52739e2018-10-19 21:06:32 +0800452}
453
454// Simulate a socket that connects but then immediately receives a server
455// close notification.
456class FakeSocketClose : public IDnsTlsSocket {
457 public:
458 explicit FakeSocketClose(IDnsTlsSocketObserver* observer)
459 : mCloser(&IDnsTlsSocketObserver::onClosed, observer) {}
460 ~FakeSocketClose() { mCloser.join(); }
461 bool query(uint16_t id ATTRIBUTE_UNUSED,
462 const Slice query ATTRIBUTE_UNUSED) override {
463 return true;
464 }
Mike Yu441d9372020-07-15 17:06:22 +0800465 bool startHandshake() override { return true; }
Mike Yuc52739e2018-10-19 21:06:32 +0800466
467 private:
468 std::thread mCloser;
469};
470
471TEST_F(TransportTest, CloseRetryFail) {
472 FakeSocketFactory<FakeSocketClose> factory;
473 DnsTlsTransport transport(SERVER1, MARK, &factory);
474 auto r = transport.query(makeSlice(QUERY)).get();
475
476 EXPECT_EQ(DnsTlsTransport::Response::network_error, r.code);
477 EXPECT_TRUE(r.response.empty());
Mike Yucb2bb7c2019-11-22 20:42:13 +0800478
479 // Reconnections are triggered since DnsTlsQueryMap is not empty.
Mike Yu1fea18c2019-12-06 10:59:17 +0800480 EXPECT_EQ(transport.getConnectCounter(), DnsTlsQueryMap::kMaxTries);
Mike Yuc52739e2018-10-19 21:06:32 +0800481}
482
483// Simulate a server that occasionally closes the connection and silently
484// drops some queries.
485class FakeSocketLimited : public IDnsTlsSocket {
486 public:
487 static int sLimit; // Number of queries to answer per socket.
488 static size_t sMaxSize; // Silently discard queries greater than this size.
489 explicit FakeSocketLimited(IDnsTlsSocketObserver* observer)
490 : mObserver(observer), mQueries(0) {}
491 ~FakeSocketLimited() {
492 {
chenbruceaff85842019-05-31 15:46:42 +0800493 LOG(DEBUG) << "~FakeSocketLimited acquiring mLock";
Mike Yuc52739e2018-10-19 21:06:32 +0800494 std::lock_guard guard(mLock);
chenbruceaff85842019-05-31 15:46:42 +0800495 LOG(DEBUG) << "~FakeSocketLimited acquired mLock";
Mike Yuc52739e2018-10-19 21:06:32 +0800496 for (auto& thread : mThreads) {
chenbruceaff85842019-05-31 15:46:42 +0800497 LOG(DEBUG) << "~FakeSocketLimited joining response thread";
Mike Yuc52739e2018-10-19 21:06:32 +0800498 thread.join();
chenbruceaff85842019-05-31 15:46:42 +0800499 LOG(DEBUG) << "~FakeSocketLimited joined response thread";
Mike Yuc52739e2018-10-19 21:06:32 +0800500 }
501 mThreads.clear();
502 }
503
504 if (mCloser) {
chenbruceaff85842019-05-31 15:46:42 +0800505 LOG(DEBUG) << "~FakeSocketLimited joining closer thread";
Mike Yuc52739e2018-10-19 21:06:32 +0800506 mCloser->join();
chenbruceaff85842019-05-31 15:46:42 +0800507 LOG(DEBUG) << "~FakeSocketLimited joined closer thread";
Mike Yuc52739e2018-10-19 21:06:32 +0800508 }
509 }
510 bool query(uint16_t id, const Slice query) override {
chenbruceaff85842019-05-31 15:46:42 +0800511 LOG(DEBUG) << "FakeSocketLimited::query acquiring mLock";
Mike Yuc52739e2018-10-19 21:06:32 +0800512 std::lock_guard guard(mLock);
chenbruceaff85842019-05-31 15:46:42 +0800513 LOG(DEBUG) << "FakeSocketLimited::query acquired mLock";
Mike Yuc52739e2018-10-19 21:06:32 +0800514 ++mQueries;
515
516 if (mQueries <= sLimit) {
chenbruceaff85842019-05-31 15:46:42 +0800517 LOG(DEBUG) << "size " << query.size() << " vs. limit of " << sMaxSize;
Mike Yuc52739e2018-10-19 21:06:32 +0800518 if (query.size() <= sMaxSize) {
519 // Return the response immediately (asynchronously).
520 mThreads.emplace_back(&IDnsTlsSocketObserver::onResponse, mObserver, make_echo(id, query));
521 }
522 }
523 if (mQueries == sLimit) {
524 mCloser = std::make_unique<std::thread>(&FakeSocketLimited::sendClose, this);
525 }
526 return mQueries <= sLimit;
527 }
Mike Yu441d9372020-07-15 17:06:22 +0800528 bool startHandshake() override { return true; }
Mike Yuc52739e2018-10-19 21:06:32 +0800529
530 private:
531 void sendClose() {
532 {
chenbruceaff85842019-05-31 15:46:42 +0800533 LOG(DEBUG) << "FakeSocketLimited::sendClose acquiring mLock";
Mike Yuc52739e2018-10-19 21:06:32 +0800534 std::lock_guard guard(mLock);
chenbruceaff85842019-05-31 15:46:42 +0800535 LOG(DEBUG) << "FakeSocketLimited::sendClose acquired mLock";
Mike Yuc52739e2018-10-19 21:06:32 +0800536 for (auto& thread : mThreads) {
chenbruceaff85842019-05-31 15:46:42 +0800537 LOG(DEBUG) << "FakeSocketLimited::sendClose joining response thread";
Mike Yuc52739e2018-10-19 21:06:32 +0800538 thread.join();
chenbruceaff85842019-05-31 15:46:42 +0800539 LOG(DEBUG) << "FakeSocketLimited::sendClose joined response thread";
Mike Yuc52739e2018-10-19 21:06:32 +0800540 }
541 mThreads.clear();
542 }
543 mObserver->onClosed();
544 }
545 std::mutex mLock;
546 IDnsTlsSocketObserver* const mObserver;
547 int mQueries GUARDED_BY(mLock);
548 std::vector<std::thread> mThreads GUARDED_BY(mLock);
549 std::unique_ptr<std::thread> mCloser GUARDED_BY(mLock);
550};
551
552int FakeSocketLimited::sLimit;
553size_t FakeSocketLimited::sMaxSize;
554
555TEST_F(TransportTest, SilentDrop) {
556 FakeSocketLimited::sLimit = 10; // Close the socket after 10 queries.
557 FakeSocketLimited::sMaxSize = 0; // Silently drop all queries
558 FakeSocketFactory<FakeSocketLimited> factory;
559 DnsTlsTransport transport(SERVER1, MARK, &factory);
560
561 // Queue up 10 queries. They will all be ignored, and after the 10th,
562 // the socket will close. Transport will retry them all, until they
563 // all hit the retry limit and expire.
564 std::vector<std::future<DnsTlsTransport::Result>> results;
565 results.reserve(FakeSocketLimited::sLimit);
566 for (int i = 0; i < FakeSocketLimited::sLimit; ++i) {
567 results.push_back(transport.query(makeSlice(QUERY)));
568 }
569 for (auto& result : results) {
570 auto r = result.get();
571 EXPECT_EQ(DnsTlsTransport::Response::network_error, r.code);
572 EXPECT_TRUE(r.response.empty());
573 }
Mike Yucb2bb7c2019-11-22 20:42:13 +0800574
575 // Reconnections are triggered since DnsTlsQueryMap is not empty.
Mike Yu1fea18c2019-12-06 10:59:17 +0800576 EXPECT_EQ(transport.getConnectCounter(), DnsTlsQueryMap::kMaxTries);
Mike Yuc52739e2018-10-19 21:06:32 +0800577}
578
579TEST_F(TransportTest, PartialDrop) {
580 FakeSocketLimited::sLimit = 10; // Close the socket after 10 queries.
581 FakeSocketLimited::sMaxSize = SIZE - 2; // Silently drop "long" queries
582 FakeSocketFactory<FakeSocketLimited> factory;
583 DnsTlsTransport transport(SERVER1, MARK, &factory);
584
585 // Queue up 100 queries, alternating "short" which will be served and "long"
586 // which will be dropped.
587 const int num_queries = 10 * FakeSocketLimited::sLimit;
588 std::vector<bytevec> queries(num_queries);
589 std::vector<std::future<DnsTlsTransport::Result>> results;
590 results.reserve(num_queries);
591 for (int i = 0; i < num_queries; ++i) {
592 queries[i] = make_query(i, SIZE + (i % 2));
593 results.push_back(transport.query(makeSlice(queries[i])));
594 }
595 // Just check the short queries, which are at the even indices.
596 for (int i = 0; i < num_queries; i += 2) {
597 auto r = results[i].get();
598 EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
599 EXPECT_EQ(queries[i], r.response);
600 }
Mike Yucb2bb7c2019-11-22 20:42:13 +0800601
602 // TODO: transport.getConnectCounter() seems not stable in this test. Find how to check the
603 // connect attempts for this test.
604}
605
606TEST_F(TransportTest, ConnectCounter) {
607 FakeSocketLimited::sLimit = 2; // Close the socket after 2 queries.
608 FakeSocketLimited::sMaxSize = SIZE; // No query drops.
609 FakeSocketFactory<FakeSocketLimited> factory;
610 DnsTlsTransport transport(SERVER1, MARK, &factory);
611
612 // Connecting on demand.
Mike Yu1fea18c2019-12-06 10:59:17 +0800613 EXPECT_EQ(transport.getConnectCounter(), 0);
Mike Yucb2bb7c2019-11-22 20:42:13 +0800614
615 const int num_queries = 10;
616 std::vector<std::future<DnsTlsTransport::Result>> results;
617 results.reserve(num_queries);
618 for (int i = 0; i < num_queries; i++) {
619 // Reconnections take place every two queries.
620 results.push_back(transport.query(makeSlice(QUERY)));
621 }
622 for (int i = 0; i < num_queries; i++) {
623 auto r = results[i].get();
624 EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
625 }
626
Mike Yu1fea18c2019-12-06 10:59:17 +0800627 EXPECT_EQ(transport.getConnectCounter(), num_queries / FakeSocketLimited::sLimit);
Mike Yuc52739e2018-10-19 21:06:32 +0800628}
629
630// Simulate a malfunctioning server that injects extra miscellaneous
631// responses to queries that were not asked. This will cause wrong answers but
632// must not crash the Transport.
633class FakeSocketGarbage : public IDnsTlsSocket {
634 public:
635 explicit FakeSocketGarbage(IDnsTlsSocketObserver* observer) : mObserver(observer) {
636 // Inject a garbage event.
637 mThreads.emplace_back(&IDnsTlsSocketObserver::onResponse, mObserver, make_query(ID + 1, SIZE));
638 }
639 ~FakeSocketGarbage() {
640 std::lock_guard guard(mLock);
641 for (auto& thread : mThreads) {
642 thread.join();
643 }
644 }
645 bool query(uint16_t id, const Slice query) override {
646 std::lock_guard guard(mLock);
647 // Return the response twice.
648 auto echo = make_echo(id, query);
649 mThreads.emplace_back(&IDnsTlsSocketObserver::onResponse, mObserver, echo);
650 mThreads.emplace_back(&IDnsTlsSocketObserver::onResponse, mObserver, echo);
651 // Also return some other garbage
652 mThreads.emplace_back(&IDnsTlsSocketObserver::onResponse, mObserver, make_query(id + 1, query.size() + 2));
653 return true;
654 }
Mike Yu441d9372020-07-15 17:06:22 +0800655 bool startHandshake() override { return true; }
Mike Yuc52739e2018-10-19 21:06:32 +0800656
657 private:
658 std::mutex mLock;
659 std::vector<std::thread> mThreads GUARDED_BY(mLock);
660 IDnsTlsSocketObserver* const mObserver;
661};
662
663TEST_F(TransportTest, IgnoringGarbage) {
664 FakeSocketFactory<FakeSocketGarbage> factory;
665 DnsTlsTransport transport(SERVER1, MARK, &factory);
666 for (int i = 0; i < 10; ++i) {
667 auto r = transport.query(makeSlice(QUERY)).get();
668
669 EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
670 // Don't check the response because this server is malfunctioning.
671 }
Mike Yu1fea18c2019-12-06 10:59:17 +0800672 EXPECT_EQ(transport.getConnectCounter(), 1);
Mike Yuc52739e2018-10-19 21:06:32 +0800673}
674
675// Dispatcher tests
676class DispatcherTest : public BaseTest {};
677
678TEST_F(DispatcherTest, Query) {
679 bytevec ans(4096);
680 int resplen = 0;
Mike Yucb2bb7c2019-11-22 20:42:13 +0800681 bool connectTriggered = false;
Mike Yuc52739e2018-10-19 21:06:32 +0800682
683 auto factory = std::make_unique<FakeSocketFactory<FakeSocketEcho>>();
684 DnsTlsDispatcher dispatcher(std::move(factory));
Mike Yucb2bb7c2019-11-22 20:42:13 +0800685 auto r = dispatcher.query(SERVER1, MARK, makeSlice(QUERY), makeSlice(ans), &resplen,
686 &connectTriggered);
Mike Yuc52739e2018-10-19 21:06:32 +0800687
688 EXPECT_EQ(DnsTlsTransport::Response::success, r);
689 EXPECT_EQ(int(QUERY.size()), resplen);
Mike Yucb2bb7c2019-11-22 20:42:13 +0800690 EXPECT_TRUE(connectTriggered);
Mike Yuc52739e2018-10-19 21:06:32 +0800691 ans.resize(resplen);
692 EXPECT_EQ(QUERY, ans);
Mike Yucb2bb7c2019-11-22 20:42:13 +0800693
694 // Expect to reuse the connection.
695 r = dispatcher.query(SERVER1, MARK, makeSlice(QUERY), makeSlice(ans), &resplen,
696 &connectTriggered);
697 EXPECT_EQ(DnsTlsTransport::Response::success, r);
698 EXPECT_FALSE(connectTriggered);
Mike Yuc52739e2018-10-19 21:06:32 +0800699}
700
701TEST_F(DispatcherTest, AnswerTooLarge) {
702 bytevec ans(SIZE - 1); // Too small to hold the answer
703 int resplen = 0;
Mike Yucb2bb7c2019-11-22 20:42:13 +0800704 bool connectTriggered = false;
Mike Yuc52739e2018-10-19 21:06:32 +0800705
706 auto factory = std::make_unique<FakeSocketFactory<FakeSocketEcho>>();
707 DnsTlsDispatcher dispatcher(std::move(factory));
Mike Yucb2bb7c2019-11-22 20:42:13 +0800708 auto r = dispatcher.query(SERVER1, MARK, makeSlice(QUERY), makeSlice(ans), &resplen,
709 &connectTriggered);
Mike Yuc52739e2018-10-19 21:06:32 +0800710
711 EXPECT_EQ(DnsTlsTransport::Response::limit_error, r);
Mike Yucb2bb7c2019-11-22 20:42:13 +0800712 EXPECT_TRUE(connectTriggered);
Mike Yuc52739e2018-10-19 21:06:32 +0800713}
714
715template<class T>
716class TrackingFakeSocketFactory : public IDnsTlsSocketFactory {
717 public:
718 TrackingFakeSocketFactory() {}
719 std::unique_ptr<IDnsTlsSocket> createDnsTlsSocket(
720 const DnsTlsServer& server,
721 unsigned mark,
722 IDnsTlsSocketObserver* observer,
723 DnsTlsSessionCache* cache ATTRIBUTE_UNUSED) override {
724 std::lock_guard guard(mLock);
725 keys.emplace(mark, server);
726 return std::make_unique<T>(observer);
727 }
728 std::multiset<std::pair<unsigned, DnsTlsServer>> keys;
729
730 private:
731 std::mutex mLock;
732};
733
734TEST_F(DispatcherTest, Dispatching) {
735 FakeSocketDelay::sDelay = 5;
736 FakeSocketDelay::sReverse = true;
737 auto factory = std::make_unique<TrackingFakeSocketFactory<FakeSocketDelay>>();
738 auto* weak_factory = factory.get(); // Valid as long as dispatcher is in scope.
739 DnsTlsDispatcher dispatcher(std::move(factory));
740
741 // Populate a vector of two servers and two socket marks, four combinations
742 // in total.
743 std::vector<std::pair<unsigned, DnsTlsServer>> keys;
744 keys.emplace_back(MARK, SERVER1);
745 keys.emplace_back(MARK + 1, SERVER1);
746 keys.emplace_back(MARK, V4ADDR2);
747 keys.emplace_back(MARK + 1, V4ADDR2);
748
749 // Do several queries on each server. They should all succeed.
750 std::vector<std::thread> threads;
751 for (size_t i = 0; i < FakeSocketDelay::sDelay * keys.size(); ++i) {
752 auto key = keys[i % keys.size()];
753 threads.emplace_back([key, i] (DnsTlsDispatcher* dispatcher) {
754 auto q = make_query(i, SIZE);
755 bytevec ans(4096);
756 int resplen = 0;
Mike Yucb2bb7c2019-11-22 20:42:13 +0800757 bool connectTriggered = false;
Mike Yuc52739e2018-10-19 21:06:32 +0800758 unsigned mark = key.first;
759 const DnsTlsServer& server = key.second;
Mike Yucb2bb7c2019-11-22 20:42:13 +0800760 auto r = dispatcher->query(server, mark, makeSlice(q), makeSlice(ans), &resplen,
761 &connectTriggered);
Mike Yuc52739e2018-10-19 21:06:32 +0800762 EXPECT_EQ(DnsTlsTransport::Response::success, r);
763 EXPECT_EQ(int(q.size()), resplen);
764 ans.resize(resplen);
765 EXPECT_EQ(q, ans);
766 }, &dispatcher);
767 }
768 for (auto& thread : threads) {
769 thread.join();
770 }
771 // We expect that the factory created one socket for each key.
772 EXPECT_EQ(keys.size(), weak_factory->keys.size());
773 for (auto& key : keys) {
774 EXPECT_EQ(1U, weak_factory->keys.count(key));
775 }
776}
777
778// Check DnsTlsServer's comparison logic.
779AddressComparator ADDRESS_COMPARATOR;
780bool isAddressEqual(const DnsTlsServer& s1, const DnsTlsServer& s2) {
781 bool cmp1 = ADDRESS_COMPARATOR(s1, s2);
782 bool cmp2 = ADDRESS_COMPARATOR(s2, s1);
783 EXPECT_FALSE(cmp1 && cmp2);
784 return !cmp1 && !cmp2;
785}
786
787void checkUnequal(const DnsTlsServer& s1, const DnsTlsServer& s2) {
788 EXPECT_TRUE(s1 == s1);
789 EXPECT_TRUE(s2 == s2);
790 EXPECT_TRUE(isAddressEqual(s1, s1));
791 EXPECT_TRUE(isAddressEqual(s2, s2));
792
793 EXPECT_TRUE(s1 < s2 ^ s2 < s1);
794 EXPECT_FALSE(s1 == s2);
795 EXPECT_FALSE(s2 == s1);
796}
797
798class ServerTest : public BaseTest {};
799
800TEST_F(ServerTest, IPv4) {
801 checkUnequal(V4ADDR1, V4ADDR2);
802 EXPECT_FALSE(isAddressEqual(V4ADDR1, V4ADDR2));
803}
804
805TEST_F(ServerTest, IPv6) {
806 checkUnequal(V6ADDR1, V6ADDR2);
807 EXPECT_FALSE(isAddressEqual(V6ADDR1, V6ADDR2));
808}
809
810TEST_F(ServerTest, MixedAddressFamily) {
811 checkUnequal(V6ADDR1, V4ADDR1);
812 EXPECT_FALSE(isAddressEqual(V6ADDR1, V4ADDR1));
813}
814
815TEST_F(ServerTest, IPv6ScopeId) {
816 DnsTlsServer s1(V6ADDR1), s2(V6ADDR1);
817 sockaddr_in6* addr1 = reinterpret_cast<sockaddr_in6*>(&s1.ss);
818 addr1->sin6_scope_id = 1;
819 sockaddr_in6* addr2 = reinterpret_cast<sockaddr_in6*>(&s2.ss);
820 addr2->sin6_scope_id = 2;
821 checkUnequal(s1, s2);
822 EXPECT_FALSE(isAddressEqual(s1, s2));
823
824 EXPECT_FALSE(s1.wasExplicitlyConfigured());
825 EXPECT_FALSE(s2.wasExplicitlyConfigured());
826}
827
828TEST_F(ServerTest, IPv6FlowInfo) {
829 DnsTlsServer s1(V6ADDR1), s2(V6ADDR1);
830 sockaddr_in6* addr1 = reinterpret_cast<sockaddr_in6*>(&s1.ss);
831 addr1->sin6_flowinfo = 1;
832 sockaddr_in6* addr2 = reinterpret_cast<sockaddr_in6*>(&s2.ss);
833 addr2->sin6_flowinfo = 2;
834 // All comparisons ignore flowinfo.
835 EXPECT_EQ(s1, s2);
836 EXPECT_TRUE(isAddressEqual(s1, s2));
837
838 EXPECT_FALSE(s1.wasExplicitlyConfigured());
839 EXPECT_FALSE(s2.wasExplicitlyConfigured());
840}
841
842TEST_F(ServerTest, Port) {
843 DnsTlsServer s1, s2;
844 parseServer("192.0.2.1", 853, &s1.ss);
845 parseServer("192.0.2.1", 854, &s2.ss);
846 checkUnequal(s1, s2);
847 EXPECT_TRUE(isAddressEqual(s1, s2));
848
849 DnsTlsServer s3, s4;
850 parseServer("2001:db8::1", 853, &s3.ss);
851 parseServer("2001:db8::1", 852, &s4.ss);
852 checkUnequal(s3, s4);
853 EXPECT_TRUE(isAddressEqual(s3, s4));
854
855 EXPECT_FALSE(s1.wasExplicitlyConfigured());
856 EXPECT_FALSE(s2.wasExplicitlyConfigured());
857}
858
859TEST_F(ServerTest, Name) {
860 DnsTlsServer s1(V4ADDR1), s2(V4ADDR1);
861 s1.name = SERVERNAME1;
862 checkUnequal(s1, s2);
863 s2.name = SERVERNAME2;
864 checkUnequal(s1, s2);
865 EXPECT_TRUE(isAddressEqual(s1, s2));
866
867 EXPECT_TRUE(s1.wasExplicitlyConfigured());
868 EXPECT_TRUE(s2.wasExplicitlyConfigured());
869}
870
Mike Yua772c202019-09-23 17:47:21 +0800871TEST_F(ServerTest, Timeout) {
872 DnsTlsServer s1(V4ADDR1), s2(V4ADDR1);
873 s1.connectTimeout = std::chrono::milliseconds(4000);
874 checkUnequal(s1, s2);
875 s2.connectTimeout = std::chrono::milliseconds(4000);
876 EXPECT_EQ(s1, s2);
877 EXPECT_TRUE(isAddressEqual(s1, s2));
878
879 EXPECT_FALSE(s1.wasExplicitlyConfigured());
880 EXPECT_FALSE(s2.wasExplicitlyConfigured());
881}
882
Mike Yuc52739e2018-10-19 21:06:32 +0800883TEST(QueryMapTest, Basic) {
884 DnsTlsQueryMap map;
885
886 EXPECT_TRUE(map.empty());
887
888 bytevec q0 = make_query(999, SIZE);
889 bytevec q1 = make_query(888, SIZE);
890 bytevec q2 = make_query(777, SIZE);
891
892 auto f0 = map.recordQuery(makeSlice(q0));
893 auto f1 = map.recordQuery(makeSlice(q1));
894 auto f2 = map.recordQuery(makeSlice(q2));
895
896 // Check return values of recordQuery
897 EXPECT_EQ(0, f0->query.newId);
898 EXPECT_EQ(1, f1->query.newId);
899 EXPECT_EQ(2, f2->query.newId);
900
901 // Check side effects of recordQuery
902 EXPECT_FALSE(map.empty());
903
904 auto all = map.getAll();
905 EXPECT_EQ(3U, all.size());
906
907 EXPECT_EQ(0, all[0].newId);
908 EXPECT_EQ(1, all[1].newId);
909 EXPECT_EQ(2, all[2].newId);
910
Mike Yu7e08b852019-10-18 18:27:43 +0800911 EXPECT_EQ(q0, all[0].query);
912 EXPECT_EQ(q1, all[1].query);
913 EXPECT_EQ(q2, all[2].query);
Mike Yuc52739e2018-10-19 21:06:32 +0800914
915 bytevec a0 = make_query(0, SIZE);
916 bytevec a1 = make_query(1, SIZE);
917 bytevec a2 = make_query(2, SIZE);
918
919 // Return responses out of order
920 map.onResponse(a2);
921 map.onResponse(a0);
922 map.onResponse(a1);
923
924 EXPECT_TRUE(map.empty());
925
926 auto r0 = f0->result.get();
927 auto r1 = f1->result.get();
928 auto r2 = f2->result.get();
929
930 EXPECT_EQ(DnsTlsQueryMap::Response::success, r0.code);
931 EXPECT_EQ(DnsTlsQueryMap::Response::success, r1.code);
932 EXPECT_EQ(DnsTlsQueryMap::Response::success, r2.code);
933
934 const bytevec& d0 = r0.response;
935 const bytevec& d1 = r1.response;
936 const bytevec& d2 = r2.response;
937
938 // The ID should match the query
939 EXPECT_EQ(999, d0[0] << 8 | d0[1]);
940 EXPECT_EQ(888, d1[0] << 8 | d1[1]);
941 EXPECT_EQ(777, d2[0] << 8 | d2[1]);
942 // The body should match the answer
943 EXPECT_EQ(bytevec(a0.begin() + 2, a0.end()), bytevec(d0.begin() + 2, d0.end()));
944 EXPECT_EQ(bytevec(a1.begin() + 2, a1.end()), bytevec(d1.begin() + 2, d1.end()));
945 EXPECT_EQ(bytevec(a2.begin() + 2, a2.end()), bytevec(d2.begin() + 2, d2.end()));
946}
947
948TEST(QueryMapTest, FillHole) {
949 DnsTlsQueryMap map;
950 std::vector<std::unique_ptr<DnsTlsQueryMap::QueryFuture>> futures(UINT16_MAX + 1);
951 for (uint32_t i = 0; i <= UINT16_MAX; ++i) {
952 futures[i] = map.recordQuery(makeSlice(QUERY));
953 ASSERT_TRUE(futures[i]); // answers[i] should be nonnull.
954 EXPECT_EQ(i, futures[i]->query.newId);
955 }
956
957 // The map should now be full.
958 EXPECT_EQ(size_t(UINT16_MAX + 1), map.getAll().size());
959
960 // Trying to add another query should fail because the map is full.
961 EXPECT_FALSE(map.recordQuery(makeSlice(QUERY)));
962
963 // Send an answer to query 40000
964 auto answer = make_query(40000, SIZE);
965 map.onResponse(answer);
966 auto result = futures[40000]->result.get();
967 EXPECT_EQ(DnsTlsQueryMap::Response::success, result.code);
968 EXPECT_EQ(ID, result.response[0] << 8 | result.response[1]);
969 EXPECT_EQ(bytevec(answer.begin() + 2, answer.end()),
970 bytevec(result.response.begin() + 2, result.response.end()));
971
972 // There should now be room in the map.
973 EXPECT_EQ(size_t(UINT16_MAX), map.getAll().size());
974 auto f = map.recordQuery(makeSlice(QUERY));
975 ASSERT_TRUE(f);
976 EXPECT_EQ(40000, f->query.newId);
977
978 // The map should now be full again.
979 EXPECT_EQ(size_t(UINT16_MAX + 1), map.getAll().size());
980 EXPECT_FALSE(map.recordQuery(makeSlice(QUERY)));
981}
982
Mike Yue93d9ae2020-08-25 19:09:51 +0800983class DnsTlsSocketTest : public ::testing::Test {
984 protected:
985 class MockDnsTlsSocketObserver : public IDnsTlsSocketObserver {
986 public:
987 MOCK_METHOD(void, onClosed, (), (override));
988 MOCK_METHOD(void, onResponse, (std::vector<uint8_t>), (override));
989 };
Ben Schwartz62176fd2019-01-22 17:32:17 -0500990
Mike Yue93d9ae2020-08-25 19:09:51 +0800991 DnsTlsSocketTest() { parseServer(kTlsAddr, std::stoi(kTlsPort), &server.ss); }
Ben Schwartz62176fd2019-01-22 17:32:17 -0500992
Mike Yue93d9ae2020-08-25 19:09:51 +0800993 std::unique_ptr<DnsTlsSocket> makeDnsTlsSocket(IDnsTlsSocketObserver* observer) {
994 return std::make_unique<DnsTlsSocket>(this->server, MARK, observer, &this->cache);
995 }
996
997 void enableAsyncHandshake(const std::unique_ptr<DnsTlsSocket>& socket) {
998 ASSERT_TRUE(socket);
999 DnsTlsSocket* delegate = socket.get();
1000 std::lock_guard guard(delegate->mLock);
1001 delegate->mAsyncHandshake = true;
1002 }
1003
1004 static constexpr char kTlsAddr[] = "127.0.0.3";
1005 static constexpr char kTlsPort[] = "8530"; // High-numbered port so root isn't required.
1006 static constexpr char kBackendAddr[] = "192.0.2.1";
1007 static constexpr char kBackendPort[] = "8531"; // High-numbered port so root isn't required.
1008
1009 test::DnsTlsFrontend tls{kTlsAddr, kTlsPort, kBackendAddr, kBackendPort};
Ben Schwartz62176fd2019-01-22 17:32:17 -05001010
1011 DnsTlsServer server;
Mike Yue93d9ae2020-08-25 19:09:51 +08001012 DnsTlsSessionCache cache;
1013};
1014
1015TEST_F(DnsTlsSocketTest, SlowDestructor) {
1016 ASSERT_TRUE(tls.startServer());
Ben Schwartz62176fd2019-01-22 17:32:17 -05001017
Mike Yu1b9069c2020-08-25 15:17:29 +08001018 MockDnsTlsSocketObserver observer;
Mike Yue93d9ae2020-08-25 19:09:51 +08001019 auto socket = makeDnsTlsSocket(&observer);
1020
Ben Schwartz62176fd2019-01-22 17:32:17 -05001021 ASSERT_TRUE(socket->initialize());
Mike Yu441d9372020-07-15 17:06:22 +08001022 ASSERT_TRUE(socket->startHandshake());
Ben Schwartz62176fd2019-01-22 17:32:17 -05001023
1024 // Test: Time the socket destructor. This should be fast.
1025 auto before = std::chrono::steady_clock::now();
Mike Yu1b9069c2020-08-25 15:17:29 +08001026 EXPECT_CALL(observer, onClosed);
Ben Schwartz62176fd2019-01-22 17:32:17 -05001027 socket.reset();
1028 auto after = std::chrono::steady_clock::now();
1029 auto delay = after - before;
chenbruceaff85842019-05-31 15:46:42 +08001030 LOG(DEBUG) << "Shutdown took " << delay / std::chrono::nanoseconds{1} << "ns";
Ben Schwartz62176fd2019-01-22 17:32:17 -05001031 // Shutdown should complete in milliseconds, but if the shutdown signal is lost
1032 // it will wait for the timeout, which is expected to take 20seconds.
1033 EXPECT_LT(delay, std::chrono::seconds{5});
1034}
1035
Mike Yue93d9ae2020-08-25 19:09:51 +08001036TEST_F(DnsTlsSocketTest, StartHandshake) {
Mike Yu441d9372020-07-15 17:06:22 +08001037 ASSERT_TRUE(tls.startServer());
1038
Mike Yue93d9ae2020-08-25 19:09:51 +08001039 MockDnsTlsSocketObserver observer;
1040 auto socket = makeDnsTlsSocket(&observer);
Mike Yu441d9372020-07-15 17:06:22 +08001041
1042 // Call the function before the call to initialize().
1043 EXPECT_FALSE(socket->startHandshake());
1044
1045 // Call the function after the call to initialize().
1046 EXPECT_TRUE(socket->initialize());
1047 EXPECT_TRUE(socket->startHandshake());
1048
1049 // Call both of them again.
1050 EXPECT_FALSE(socket->initialize());
1051 EXPECT_FALSE(socket->startHandshake());
Mike Yue93d9ae2020-08-25 19:09:51 +08001052
1053 // Should happen when joining the loop thread in |socket| destruction.
1054 EXPECT_CALL(observer, onClosed);
1055}
1056
1057TEST_F(DnsTlsSocketTest, ShutdownSignal) {
1058 ASSERT_TRUE(tls.startServer());
1059
1060 MockDnsTlsSocketObserver observer;
1061 std::unique_ptr<DnsTlsSocket> socket;
1062
1063 const auto setupAndStartHandshake = [&]() {
1064 socket = makeDnsTlsSocket(&observer);
1065 EXPECT_TRUE(socket->initialize());
1066 enableAsyncHandshake(socket);
1067 EXPECT_TRUE(socket->startHandshake());
1068 };
1069 const auto triggerShutdown = [&](const std::string& traceLog) {
1070 SCOPED_TRACE(traceLog);
1071 auto before = std::chrono::steady_clock::now();
1072 EXPECT_CALL(observer, onClosed);
1073 socket.reset();
1074 auto after = std::chrono::steady_clock::now();
1075 auto delay = after - before;
1076 LOG(INFO) << "Shutdown took " << delay / std::chrono::nanoseconds{1} << "ns";
1077 EXPECT_LT(delay, std::chrono::seconds{1});
1078 };
1079
1080 tls.setHangOnHandshakeForTesting(true);
1081
1082 // Test 1: Reset the DnsTlsSocket which is doing the handshake.
1083 setupAndStartHandshake();
1084 triggerShutdown("Shutdown handshake w/o query requests");
1085
1086 // Test 2: Reset the DnsTlsSocket which is doing the handshake with some query requests.
1087 setupAndStartHandshake();
1088
1089 // DnsTlsSocket doesn't report the status of pending queries. The decision whether to mark
1090 // a query request as failed or not is made in DnsTlsTransport.
1091 EXPECT_CALL(observer, onResponse).Times(0);
1092 EXPECT_TRUE(socket->query(1, makeSlice(QUERY)));
1093 EXPECT_TRUE(socket->query(2, makeSlice(QUERY)));
1094 triggerShutdown("Shutdown handshake w/ query requests");
Mike Yu441d9372020-07-15 17:06:22 +08001095}
1096
Mike Yuc52739e2018-10-19 21:06:32 +08001097} // end of namespace net
1098} // end of namespace android