blob: 72e68972de73052c6e3125223f8add1dae798b46 [file] [log] [blame]
Mike Yuc52739e2018-10-19 21:06:32 +08001/*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
Ken Chen5471dca2019-04-15 15:25:35 +080017#define LOG_TAG "resolv"
Mike Yuc52739e2018-10-19 21:06:32 +080018
chenbruceaff85842019-05-31 15:46:42 +080019#include <arpa/inet.h>
20
21#include <chrono>
22
23#include <android-base/logging.h>
24#include <android-base/macros.h>
Mike Yuc52739e2018-10-19 21:06:32 +080025#include <gtest/gtest.h>
chenbruceaff85842019-05-31 15:46:42 +080026#include <netdutils/Slice.h>
Mike Yuc52739e2018-10-19 21:06:32 +080027
Bernie Innocentiec4219b2019-01-30 11:16:36 +090028#include "DnsTlsDispatcher.h"
29#include "DnsTlsQueryMap.h"
30#include "DnsTlsServer.h"
31#include "DnsTlsSessionCache.h"
32#include "DnsTlsSocket.h"
33#include "DnsTlsTransport.h"
34#include "IDnsTlsSocket.h"
35#include "IDnsTlsSocketFactory.h"
36#include "IDnsTlsSocketObserver.h"
chenbruceb43ec752019-07-24 20:19:41 +080037#include "tests/dns_responder/dns_tls_frontend.h"
Ben Schwartz62176fd2019-01-22 17:32:17 -050038
Mike Yuc52739e2018-10-19 21:06:32 +080039namespace android {
40namespace net {
41
42using netdutils::Slice;
43using netdutils::makeSlice;
44
// Raw DNS message bytes (queries and responses) as handled throughout these tests.
using bytevec = std::vector<uint8_t>;
46
47static void parseServer(const char* server, in_port_t port, sockaddr_storage* parsed) {
48 sockaddr_in* sin = reinterpret_cast<sockaddr_in*>(parsed);
49 if (inet_pton(AF_INET, server, &(sin->sin_addr)) == 1) {
50 // IPv4 parse succeeded, so it's IPv4
51 sin->sin_family = AF_INET;
52 sin->sin_port = htons(port);
53 return;
54 }
55 sockaddr_in6* sin6 = reinterpret_cast<sockaddr_in6*>(parsed);
56 if (inet_pton(AF_INET6, server, &(sin6->sin6_addr)) == 1){
57 // IPv6 parse succeeded, so it's IPv6.
58 sin6->sin6_family = AF_INET6;
59 sin6->sin6_port = htons(port);
60 return;
61 }
chenbruceaff85842019-05-31 15:46:42 +080062 LOG(ERROR) << "Failed to parse server address: " << server;
Mike Yuc52739e2018-10-19 21:06:32 +080063}
64
// Names assigned to DnsTlsServer::name by the tests below. They are const so no
// test can mutate shared state; const also gives them internal linkage, so they
// cannot collide with same-named globals in other translation units.
const std::string SERVERNAME1 = "dns.example.com";
const std::string SERVERNAME2 = "dns.example.org";
67
68// BaseTest just provides constants that are useful for the tests.
69class BaseTest : public ::testing::Test {
70 protected:
71 BaseTest() {
72 parseServer("192.0.2.1", 853, &V4ADDR1);
73 parseServer("192.0.2.2", 853, &V4ADDR2);
74 parseServer("2001:db8::1", 853, &V6ADDR1);
75 parseServer("2001:db8::2", 853, &V6ADDR2);
76
77 SERVER1 = DnsTlsServer(V4ADDR1);
Mike Yuc52739e2018-10-19 21:06:32 +080078 SERVER1.name = SERVERNAME1;
79 }
80
81 sockaddr_storage V4ADDR1;
82 sockaddr_storage V4ADDR2;
83 sockaddr_storage V6ADDR1;
84 sockaddr_storage V6ADDR2;
85
86 DnsTlsServer SERVER1;
87};
88
// Builds a fake DNS query of |size| bytes (returned as a bytevec, i.e.
// std::vector<uint8_t>). The first two bytes hold |id| in big-endian order;
// the rest is filled with bytes derived from |id| so different queries are
// distinguishable. If |size| is smaller than the 2-byte header, only as much
// of the ID as fits is written, instead of writing out of bounds.
std::vector<uint8_t> make_query(uint16_t id, size_t size) {
    std::vector<uint8_t> vec(size);
    if (size > 0) vec[0] = static_cast<uint8_t>(id >> 8);
    if (size > 1) vec[1] = static_cast<uint8_t>(id);
    // Arbitrarily fill the query body with unique data.
    for (size_t i = 2; i < size; ++i) {
        vec[i] = static_cast<uint8_t>(id + i);
    }
    return vec;
}
99
100// Query constants
101const unsigned MARK = 123;
102const uint16_t ID = 52;
103const uint16_t SIZE = 22;
104const bytevec QUERY = make_query(ID, SIZE);
105
106template <class T>
107class FakeSocketFactory : public IDnsTlsSocketFactory {
108 public:
109 FakeSocketFactory() {}
110 std::unique_ptr<IDnsTlsSocket> createDnsTlsSocket(
111 const DnsTlsServer& server ATTRIBUTE_UNUSED,
112 unsigned mark ATTRIBUTE_UNUSED,
113 IDnsTlsSocketObserver* observer,
114 DnsTlsSessionCache* cache ATTRIBUTE_UNUSED) override {
115 return std::make_unique<T>(observer);
116 }
117};
118
119bytevec make_echo(uint16_t id, const Slice query) {
120 bytevec response(query.size() + 2);
121 response[0] = id >> 8;
122 response[1] = id;
123 // Echo the query as the fake response.
124 memcpy(response.data() + 2, query.base(), query.size());
125 return response;
126}
127
128// Simplest possible fake server. This just echoes the query as the response.
129class FakeSocketEcho : public IDnsTlsSocket {
130 public:
131 explicit FakeSocketEcho(IDnsTlsSocketObserver* observer) : mObserver(observer) {}
132 bool query(uint16_t id, const Slice query) override {
133 // Return the response immediately (asynchronously).
134 std::thread(&IDnsTlsSocketObserver::onResponse, mObserver, make_echo(id, query)).detach();
135 return true;
136 }
137
138 private:
139 IDnsTlsSocketObserver* const mObserver;
140};
141
142class TransportTest : public BaseTest {};
143
144TEST_F(TransportTest, Query) {
145 FakeSocketFactory<FakeSocketEcho> factory;
146 DnsTlsTransport transport(SERVER1, MARK, &factory);
147 auto r = transport.query(makeSlice(QUERY)).get();
148
149 EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
150 EXPECT_EQ(QUERY, r.response);
151}
152
153// Fake Socket that echoes the observed query ID as the response body.
154class FakeSocketId : public IDnsTlsSocket {
155 public:
156 explicit FakeSocketId(IDnsTlsSocketObserver* observer) : mObserver(observer) {}
157 bool query(uint16_t id, const Slice query ATTRIBUTE_UNUSED) override {
158 // Return the response immediately (asynchronously).
159 bytevec response(4);
160 // Echo the ID in the header to match the response to the query.
161 // This will be overwritten by DnsTlsQueryMap.
162 response[0] = id >> 8;
163 response[1] = id;
164 // Echo the ID in the body, so that the test can verify which ID was used by
165 // DnsTlsQueryMap.
166 response[2] = id >> 8;
167 response[3] = id;
168 std::thread(&IDnsTlsSocketObserver::onResponse, mObserver, response).detach();
169 return true;
170 }
171
172 private:
173 IDnsTlsSocketObserver* const mObserver;
174};
175
176// Test that IDs are properly reused
177TEST_F(TransportTest, IdReuse) {
178 FakeSocketFactory<FakeSocketId> factory;
179 DnsTlsTransport transport(SERVER1, MARK, &factory);
180 for (int i = 0; i < 100; ++i) {
181 // Send a query.
182 std::future<DnsTlsServer::Result> f = transport.query(makeSlice(QUERY));
183 // Wait for the response.
184 DnsTlsServer::Result r = f.get();
185 EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
186
187 // All queries should have an observed ID of zero, because it is returned to the ID pool
188 // after each use.
189 EXPECT_EQ(0, (r.response[2] << 8) | r.response[3]);
190 }
191}
192
193// These queries might be handled in serial or parallel as they race the
194// responses.
195TEST_F(TransportTest, RacingQueries_10000) {
196 FakeSocketFactory<FakeSocketEcho> factory;
197 DnsTlsTransport transport(SERVER1, MARK, &factory);
198 std::vector<std::future<DnsTlsTransport::Result>> results;
199 // Fewer than 65536 queries to avoid ID exhaustion.
200 const int num_queries = 10000;
201 results.reserve(num_queries);
202 for (int i = 0; i < num_queries; ++i) {
203 results.push_back(transport.query(makeSlice(QUERY)));
204 }
205 for (auto& result : results) {
206 auto r = result.get();
207 EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
208 EXPECT_EQ(QUERY, r.response);
209 }
210}
211
212// A server that waits until sDelay queries are queued before responding.
213class FakeSocketDelay : public IDnsTlsSocket {
214 public:
215 explicit FakeSocketDelay(IDnsTlsSocketObserver* observer) : mObserver(observer) {}
216 ~FakeSocketDelay() { std::lock_guard guard(mLock); }
217 static size_t sDelay;
218 static bool sReverse;
219
220 bool query(uint16_t id, const Slice query) override {
chenbruceaff85842019-05-31 15:46:42 +0800221 LOG(DEBUG) << "FakeSocketDelay got query with ID " << int(id);
Mike Yuc52739e2018-10-19 21:06:32 +0800222 std::lock_guard guard(mLock);
223 // Check for duplicate IDs.
224 EXPECT_EQ(0U, mIds.count(id));
225 mIds.insert(id);
226
227 // Store response.
228 mResponses.push_back(make_echo(id, query));
229
chenbruceaff85842019-05-31 15:46:42 +0800230 LOG(DEBUG) << "Up to " << mResponses.size() << " out of " << sDelay << " queries";
Mike Yuc52739e2018-10-19 21:06:32 +0800231 if (mResponses.size() == sDelay) {
232 std::thread(&FakeSocketDelay::sendResponses, this).detach();
233 }
234 return true;
235 }
236
237 private:
238 void sendResponses() {
239 std::lock_guard guard(mLock);
240 if (sReverse) {
241 std::reverse(std::begin(mResponses), std::end(mResponses));
242 }
243 for (auto& response : mResponses) {
244 mObserver->onResponse(response);
245 }
246 mIds.clear();
247 mResponses.clear();
248 }
249
250 std::mutex mLock;
251 IDnsTlsSocketObserver* const mObserver;
252 std::set<uint16_t> mIds GUARDED_BY(mLock);
253 std::vector<bytevec> mResponses GUARDED_BY(mLock);
254};
255
256size_t FakeSocketDelay::sDelay;
257bool FakeSocketDelay::sReverse;
258
259TEST_F(TransportTest, ParallelColliding) {
260 FakeSocketDelay::sDelay = 10;
261 FakeSocketDelay::sReverse = false;
262 FakeSocketFactory<FakeSocketDelay> factory;
263 DnsTlsTransport transport(SERVER1, MARK, &factory);
264 std::vector<std::future<DnsTlsTransport::Result>> results;
265 // Fewer than 65536 queries to avoid ID exhaustion.
266 results.reserve(FakeSocketDelay::sDelay);
267 for (size_t i = 0; i < FakeSocketDelay::sDelay; ++i) {
268 results.push_back(transport.query(makeSlice(QUERY)));
269 }
270 for (auto& result : results) {
271 auto r = result.get();
272 EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
273 EXPECT_EQ(QUERY, r.response);
274 }
275}
276
277TEST_F(TransportTest, ParallelColliding_Max) {
278 FakeSocketDelay::sDelay = 65536;
279 FakeSocketDelay::sReverse = false;
280 FakeSocketFactory<FakeSocketDelay> factory;
281 DnsTlsTransport transport(SERVER1, MARK, &factory);
282 std::vector<std::future<DnsTlsTransport::Result>> results;
283 // Exactly 65536 queries should still be possible in parallel,
284 // even if they all have the same original ID.
285 results.reserve(FakeSocketDelay::sDelay);
286 for (size_t i = 0; i < FakeSocketDelay::sDelay; ++i) {
287 results.push_back(transport.query(makeSlice(QUERY)));
288 }
289 for (auto& result : results) {
290 auto r = result.get();
291 EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
292 EXPECT_EQ(QUERY, r.response);
293 }
294}
295
296TEST_F(TransportTest, ParallelUnique) {
297 FakeSocketDelay::sDelay = 10;
298 FakeSocketDelay::sReverse = false;
299 FakeSocketFactory<FakeSocketDelay> factory;
300 DnsTlsTransport transport(SERVER1, MARK, &factory);
301 std::vector<bytevec> queries(FakeSocketDelay::sDelay);
302 std::vector<std::future<DnsTlsTransport::Result>> results;
303 results.reserve(FakeSocketDelay::sDelay);
304 for (size_t i = 0; i < FakeSocketDelay::sDelay; ++i) {
305 queries[i] = make_query(i, SIZE);
306 results.push_back(transport.query(makeSlice(queries[i])));
307 }
308 for (size_t i = 0 ; i < FakeSocketDelay::sDelay; ++i) {
309 auto r = results[i].get();
310 EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
311 EXPECT_EQ(queries[i], r.response);
312 }
313}
314
315TEST_F(TransportTest, ParallelUnique_Max) {
316 FakeSocketDelay::sDelay = 65536;
317 FakeSocketDelay::sReverse = false;
318 FakeSocketFactory<FakeSocketDelay> factory;
319 DnsTlsTransport transport(SERVER1, MARK, &factory);
320 std::vector<bytevec> queries(FakeSocketDelay::sDelay);
321 std::vector<std::future<DnsTlsTransport::Result>> results;
322 // Exactly 65536 queries should still be possible in parallel,
323 // and they should all be mapped correctly back to the original ID.
324 results.reserve(FakeSocketDelay::sDelay);
325 for (size_t i = 0; i < FakeSocketDelay::sDelay; ++i) {
326 queries[i] = make_query(i, SIZE);
327 results.push_back(transport.query(makeSlice(queries[i])));
328 }
329 for (size_t i = 0 ; i < FakeSocketDelay::sDelay; ++i) {
330 auto r = results[i].get();
331 EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
332 EXPECT_EQ(queries[i], r.response);
333 }
334}
335
336TEST_F(TransportTest, IdExhaustion) {
337 const int num_queries = 65536;
338 // A delay of 65537 is unreachable, because the maximum number
339 // of outstanding queries is 65536.
340 FakeSocketDelay::sDelay = num_queries + 1;
341 FakeSocketDelay::sReverse = false;
342 FakeSocketFactory<FakeSocketDelay> factory;
343 DnsTlsTransport transport(SERVER1, MARK, &factory);
344 std::vector<std::future<DnsTlsTransport::Result>> results;
345 // Issue the maximum number of queries.
346 results.reserve(num_queries);
347 for (int i = 0; i < num_queries; ++i) {
348 results.push_back(transport.query(makeSlice(QUERY)));
349 }
350
351 // The ID space is now full, so subsequent queries should fail immediately.
352 auto r = transport.query(makeSlice(QUERY)).get();
353 EXPECT_EQ(DnsTlsTransport::Response::internal_error, r.code);
354 EXPECT_TRUE(r.response.empty());
355
356 for (auto& result : results) {
357 // All other queries should remain outstanding.
358 EXPECT_EQ(std::future_status::timeout,
359 result.wait_for(std::chrono::duration<int>::zero()));
360 }
361}
362
363// Responses can come back from the server in any order. This should have no
364// effect on Transport's observed behavior.
365TEST_F(TransportTest, ReverseOrder) {
366 FakeSocketDelay::sDelay = 10;
367 FakeSocketDelay::sReverse = true;
368 FakeSocketFactory<FakeSocketDelay> factory;
369 DnsTlsTransport transport(SERVER1, MARK, &factory);
370 std::vector<bytevec> queries(FakeSocketDelay::sDelay);
371 std::vector<std::future<DnsTlsTransport::Result>> results;
372 results.reserve(FakeSocketDelay::sDelay);
373 for (size_t i = 0; i < FakeSocketDelay::sDelay; ++i) {
374 queries[i] = make_query(i, SIZE);
375 results.push_back(transport.query(makeSlice(queries[i])));
376 }
377 for (size_t i = 0 ; i < FakeSocketDelay::sDelay; ++i) {
378 auto r = results[i].get();
379 EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
380 EXPECT_EQ(queries[i], r.response);
381 }
382}
383
384TEST_F(TransportTest, ReverseOrder_Max) {
385 FakeSocketDelay::sDelay = 65536;
386 FakeSocketDelay::sReverse = true;
387 FakeSocketFactory<FakeSocketDelay> factory;
388 DnsTlsTransport transport(SERVER1, MARK, &factory);
389 std::vector<bytevec> queries(FakeSocketDelay::sDelay);
390 std::vector<std::future<DnsTlsTransport::Result>> results;
391 results.reserve(FakeSocketDelay::sDelay);
392 for (size_t i = 0; i < FakeSocketDelay::sDelay; ++i) {
393 queries[i] = make_query(i, SIZE);
394 results.push_back(transport.query(makeSlice(queries[i])));
395 }
396 for (size_t i = 0 ; i < FakeSocketDelay::sDelay; ++i) {
397 auto r = results[i].get();
398 EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
399 EXPECT_EQ(queries[i], r.response);
400 }
401}
402
403// Returning null from the factory indicates a connection failure.
404class NullSocketFactory : public IDnsTlsSocketFactory {
405 public:
406 NullSocketFactory() {}
407 std::unique_ptr<IDnsTlsSocket> createDnsTlsSocket(
408 const DnsTlsServer& server ATTRIBUTE_UNUSED,
409 unsigned mark ATTRIBUTE_UNUSED,
410 IDnsTlsSocketObserver* observer ATTRIBUTE_UNUSED,
411 DnsTlsSessionCache* cache ATTRIBUTE_UNUSED) override {
412 return nullptr;
413 }
414};
415
416TEST_F(TransportTest, ConnectFail) {
417 NullSocketFactory factory;
418 DnsTlsTransport transport(SERVER1, MARK, &factory);
419 auto r = transport.query(makeSlice(QUERY)).get();
420
421 EXPECT_EQ(DnsTlsTransport::Response::network_error, r.code);
422 EXPECT_TRUE(r.response.empty());
423}
424
425// Simulate a socket that connects but then immediately receives a server
426// close notification.
427class FakeSocketClose : public IDnsTlsSocket {
428 public:
429 explicit FakeSocketClose(IDnsTlsSocketObserver* observer)
430 : mCloser(&IDnsTlsSocketObserver::onClosed, observer) {}
431 ~FakeSocketClose() { mCloser.join(); }
432 bool query(uint16_t id ATTRIBUTE_UNUSED,
433 const Slice query ATTRIBUTE_UNUSED) override {
434 return true;
435 }
436
437 private:
438 std::thread mCloser;
439};
440
441TEST_F(TransportTest, CloseRetryFail) {
442 FakeSocketFactory<FakeSocketClose> factory;
443 DnsTlsTransport transport(SERVER1, MARK, &factory);
444 auto r = transport.query(makeSlice(QUERY)).get();
445
446 EXPECT_EQ(DnsTlsTransport::Response::network_error, r.code);
447 EXPECT_TRUE(r.response.empty());
448}
449
450// Simulate a server that occasionally closes the connection and silently
451// drops some queries.
452class FakeSocketLimited : public IDnsTlsSocket {
453 public:
454 static int sLimit; // Number of queries to answer per socket.
455 static size_t sMaxSize; // Silently discard queries greater than this size.
456 explicit FakeSocketLimited(IDnsTlsSocketObserver* observer)
457 : mObserver(observer), mQueries(0) {}
458 ~FakeSocketLimited() {
459 {
chenbruceaff85842019-05-31 15:46:42 +0800460 LOG(DEBUG) << "~FakeSocketLimited acquiring mLock";
Mike Yuc52739e2018-10-19 21:06:32 +0800461 std::lock_guard guard(mLock);
chenbruceaff85842019-05-31 15:46:42 +0800462 LOG(DEBUG) << "~FakeSocketLimited acquired mLock";
Mike Yuc52739e2018-10-19 21:06:32 +0800463 for (auto& thread : mThreads) {
chenbruceaff85842019-05-31 15:46:42 +0800464 LOG(DEBUG) << "~FakeSocketLimited joining response thread";
Mike Yuc52739e2018-10-19 21:06:32 +0800465 thread.join();
chenbruceaff85842019-05-31 15:46:42 +0800466 LOG(DEBUG) << "~FakeSocketLimited joined response thread";
Mike Yuc52739e2018-10-19 21:06:32 +0800467 }
468 mThreads.clear();
469 }
470
471 if (mCloser) {
chenbruceaff85842019-05-31 15:46:42 +0800472 LOG(DEBUG) << "~FakeSocketLimited joining closer thread";
Mike Yuc52739e2018-10-19 21:06:32 +0800473 mCloser->join();
chenbruceaff85842019-05-31 15:46:42 +0800474 LOG(DEBUG) << "~FakeSocketLimited joined closer thread";
Mike Yuc52739e2018-10-19 21:06:32 +0800475 }
476 }
477 bool query(uint16_t id, const Slice query) override {
chenbruceaff85842019-05-31 15:46:42 +0800478 LOG(DEBUG) << "FakeSocketLimited::query acquiring mLock";
Mike Yuc52739e2018-10-19 21:06:32 +0800479 std::lock_guard guard(mLock);
chenbruceaff85842019-05-31 15:46:42 +0800480 LOG(DEBUG) << "FakeSocketLimited::query acquired mLock";
Mike Yuc52739e2018-10-19 21:06:32 +0800481 ++mQueries;
482
483 if (mQueries <= sLimit) {
chenbruceaff85842019-05-31 15:46:42 +0800484 LOG(DEBUG) << "size " << query.size() << " vs. limit of " << sMaxSize;
Mike Yuc52739e2018-10-19 21:06:32 +0800485 if (query.size() <= sMaxSize) {
486 // Return the response immediately (asynchronously).
487 mThreads.emplace_back(&IDnsTlsSocketObserver::onResponse, mObserver, make_echo(id, query));
488 }
489 }
490 if (mQueries == sLimit) {
491 mCloser = std::make_unique<std::thread>(&FakeSocketLimited::sendClose, this);
492 }
493 return mQueries <= sLimit;
494 }
495
496 private:
497 void sendClose() {
498 {
chenbruceaff85842019-05-31 15:46:42 +0800499 LOG(DEBUG) << "FakeSocketLimited::sendClose acquiring mLock";
Mike Yuc52739e2018-10-19 21:06:32 +0800500 std::lock_guard guard(mLock);
chenbruceaff85842019-05-31 15:46:42 +0800501 LOG(DEBUG) << "FakeSocketLimited::sendClose acquired mLock";
Mike Yuc52739e2018-10-19 21:06:32 +0800502 for (auto& thread : mThreads) {
chenbruceaff85842019-05-31 15:46:42 +0800503 LOG(DEBUG) << "FakeSocketLimited::sendClose joining response thread";
Mike Yuc52739e2018-10-19 21:06:32 +0800504 thread.join();
chenbruceaff85842019-05-31 15:46:42 +0800505 LOG(DEBUG) << "FakeSocketLimited::sendClose joined response thread";
Mike Yuc52739e2018-10-19 21:06:32 +0800506 }
507 mThreads.clear();
508 }
509 mObserver->onClosed();
510 }
511 std::mutex mLock;
512 IDnsTlsSocketObserver* const mObserver;
513 int mQueries GUARDED_BY(mLock);
514 std::vector<std::thread> mThreads GUARDED_BY(mLock);
515 std::unique_ptr<std::thread> mCloser GUARDED_BY(mLock);
516};
517
518int FakeSocketLimited::sLimit;
519size_t FakeSocketLimited::sMaxSize;
520
521TEST_F(TransportTest, SilentDrop) {
522 FakeSocketLimited::sLimit = 10; // Close the socket after 10 queries.
523 FakeSocketLimited::sMaxSize = 0; // Silently drop all queries
524 FakeSocketFactory<FakeSocketLimited> factory;
525 DnsTlsTransport transport(SERVER1, MARK, &factory);
526
527 // Queue up 10 queries. They will all be ignored, and after the 10th,
528 // the socket will close. Transport will retry them all, until they
529 // all hit the retry limit and expire.
530 std::vector<std::future<DnsTlsTransport::Result>> results;
531 results.reserve(FakeSocketLimited::sLimit);
532 for (int i = 0; i < FakeSocketLimited::sLimit; ++i) {
533 results.push_back(transport.query(makeSlice(QUERY)));
534 }
535 for (auto& result : results) {
536 auto r = result.get();
537 EXPECT_EQ(DnsTlsTransport::Response::network_error, r.code);
538 EXPECT_TRUE(r.response.empty());
539 }
540}
541
542TEST_F(TransportTest, PartialDrop) {
543 FakeSocketLimited::sLimit = 10; // Close the socket after 10 queries.
544 FakeSocketLimited::sMaxSize = SIZE - 2; // Silently drop "long" queries
545 FakeSocketFactory<FakeSocketLimited> factory;
546 DnsTlsTransport transport(SERVER1, MARK, &factory);
547
548 // Queue up 100 queries, alternating "short" which will be served and "long"
549 // which will be dropped.
550 const int num_queries = 10 * FakeSocketLimited::sLimit;
551 std::vector<bytevec> queries(num_queries);
552 std::vector<std::future<DnsTlsTransport::Result>> results;
553 results.reserve(num_queries);
554 for (int i = 0; i < num_queries; ++i) {
555 queries[i] = make_query(i, SIZE + (i % 2));
556 results.push_back(transport.query(makeSlice(queries[i])));
557 }
558 // Just check the short queries, which are at the even indices.
559 for (int i = 0; i < num_queries; i += 2) {
560 auto r = results[i].get();
561 EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
562 EXPECT_EQ(queries[i], r.response);
563 }
564}
565
566// Simulate a malfunctioning server that injects extra miscellaneous
567// responses to queries that were not asked. This will cause wrong answers but
568// must not crash the Transport.
569class FakeSocketGarbage : public IDnsTlsSocket {
570 public:
571 explicit FakeSocketGarbage(IDnsTlsSocketObserver* observer) : mObserver(observer) {
572 // Inject a garbage event.
573 mThreads.emplace_back(&IDnsTlsSocketObserver::onResponse, mObserver, make_query(ID + 1, SIZE));
574 }
575 ~FakeSocketGarbage() {
576 std::lock_guard guard(mLock);
577 for (auto& thread : mThreads) {
578 thread.join();
579 }
580 }
581 bool query(uint16_t id, const Slice query) override {
582 std::lock_guard guard(mLock);
583 // Return the response twice.
584 auto echo = make_echo(id, query);
585 mThreads.emplace_back(&IDnsTlsSocketObserver::onResponse, mObserver, echo);
586 mThreads.emplace_back(&IDnsTlsSocketObserver::onResponse, mObserver, echo);
587 // Also return some other garbage
588 mThreads.emplace_back(&IDnsTlsSocketObserver::onResponse, mObserver, make_query(id + 1, query.size() + 2));
589 return true;
590 }
591
592 private:
593 std::mutex mLock;
594 std::vector<std::thread> mThreads GUARDED_BY(mLock);
595 IDnsTlsSocketObserver* const mObserver;
596};
597
598TEST_F(TransportTest, IgnoringGarbage) {
599 FakeSocketFactory<FakeSocketGarbage> factory;
600 DnsTlsTransport transport(SERVER1, MARK, &factory);
601 for (int i = 0; i < 10; ++i) {
602 auto r = transport.query(makeSlice(QUERY)).get();
603
604 EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
605 // Don't check the response because this server is malfunctioning.
606 }
607}
608
609// Dispatcher tests
610class DispatcherTest : public BaseTest {};
611
612TEST_F(DispatcherTest, Query) {
613 bytevec ans(4096);
614 int resplen = 0;
615
616 auto factory = std::make_unique<FakeSocketFactory<FakeSocketEcho>>();
617 DnsTlsDispatcher dispatcher(std::move(factory));
618 auto r = dispatcher.query(SERVER1, MARK, makeSlice(QUERY),
619 makeSlice(ans), &resplen);
620
621 EXPECT_EQ(DnsTlsTransport::Response::success, r);
622 EXPECT_EQ(int(QUERY.size()), resplen);
623 ans.resize(resplen);
624 EXPECT_EQ(QUERY, ans);
625}
626
627TEST_F(DispatcherTest, AnswerTooLarge) {
628 bytevec ans(SIZE - 1); // Too small to hold the answer
629 int resplen = 0;
630
631 auto factory = std::make_unique<FakeSocketFactory<FakeSocketEcho>>();
632 DnsTlsDispatcher dispatcher(std::move(factory));
633 auto r = dispatcher.query(SERVER1, MARK, makeSlice(QUERY),
634 makeSlice(ans), &resplen);
635
636 EXPECT_EQ(DnsTlsTransport::Response::limit_error, r);
637}
638
639template<class T>
640class TrackingFakeSocketFactory : public IDnsTlsSocketFactory {
641 public:
642 TrackingFakeSocketFactory() {}
643 std::unique_ptr<IDnsTlsSocket> createDnsTlsSocket(
644 const DnsTlsServer& server,
645 unsigned mark,
646 IDnsTlsSocketObserver* observer,
647 DnsTlsSessionCache* cache ATTRIBUTE_UNUSED) override {
648 std::lock_guard guard(mLock);
649 keys.emplace(mark, server);
650 return std::make_unique<T>(observer);
651 }
652 std::multiset<std::pair<unsigned, DnsTlsServer>> keys;
653
654 private:
655 std::mutex mLock;
656};
657
658TEST_F(DispatcherTest, Dispatching) {
659 FakeSocketDelay::sDelay = 5;
660 FakeSocketDelay::sReverse = true;
661 auto factory = std::make_unique<TrackingFakeSocketFactory<FakeSocketDelay>>();
662 auto* weak_factory = factory.get(); // Valid as long as dispatcher is in scope.
663 DnsTlsDispatcher dispatcher(std::move(factory));
664
665 // Populate a vector of two servers and two socket marks, four combinations
666 // in total.
667 std::vector<std::pair<unsigned, DnsTlsServer>> keys;
668 keys.emplace_back(MARK, SERVER1);
669 keys.emplace_back(MARK + 1, SERVER1);
670 keys.emplace_back(MARK, V4ADDR2);
671 keys.emplace_back(MARK + 1, V4ADDR2);
672
673 // Do several queries on each server. They should all succeed.
674 std::vector<std::thread> threads;
675 for (size_t i = 0; i < FakeSocketDelay::sDelay * keys.size(); ++i) {
676 auto key = keys[i % keys.size()];
677 threads.emplace_back([key, i] (DnsTlsDispatcher* dispatcher) {
678 auto q = make_query(i, SIZE);
679 bytevec ans(4096);
680 int resplen = 0;
681 unsigned mark = key.first;
682 const DnsTlsServer& server = key.second;
683 auto r = dispatcher->query(server, mark, makeSlice(q),
684 makeSlice(ans), &resplen);
685 EXPECT_EQ(DnsTlsTransport::Response::success, r);
686 EXPECT_EQ(int(q.size()), resplen);
687 ans.resize(resplen);
688 EXPECT_EQ(q, ans);
689 }, &dispatcher);
690 }
691 for (auto& thread : threads) {
692 thread.join();
693 }
694 // We expect that the factory created one socket for each key.
695 EXPECT_EQ(keys.size(), weak_factory->keys.size());
696 for (auto& key : keys) {
697 EXPECT_EQ(1U, weak_factory->keys.count(key));
698 }
699}
700
701// Check DnsTlsServer's comparison logic.
702AddressComparator ADDRESS_COMPARATOR;
703bool isAddressEqual(const DnsTlsServer& s1, const DnsTlsServer& s2) {
704 bool cmp1 = ADDRESS_COMPARATOR(s1, s2);
705 bool cmp2 = ADDRESS_COMPARATOR(s2, s1);
706 EXPECT_FALSE(cmp1 && cmp2);
707 return !cmp1 && !cmp2;
708}
709
710void checkUnequal(const DnsTlsServer& s1, const DnsTlsServer& s2) {
711 EXPECT_TRUE(s1 == s1);
712 EXPECT_TRUE(s2 == s2);
713 EXPECT_TRUE(isAddressEqual(s1, s1));
714 EXPECT_TRUE(isAddressEqual(s2, s2));
715
716 EXPECT_TRUE(s1 < s2 ^ s2 < s1);
717 EXPECT_FALSE(s1 == s2);
718 EXPECT_FALSE(s2 == s1);
719}
720
721class ServerTest : public BaseTest {};
722
723TEST_F(ServerTest, IPv4) {
724 checkUnequal(V4ADDR1, V4ADDR2);
725 EXPECT_FALSE(isAddressEqual(V4ADDR1, V4ADDR2));
726}
727
728TEST_F(ServerTest, IPv6) {
729 checkUnequal(V6ADDR1, V6ADDR2);
730 EXPECT_FALSE(isAddressEqual(V6ADDR1, V6ADDR2));
731}
732
733TEST_F(ServerTest, MixedAddressFamily) {
734 checkUnequal(V6ADDR1, V4ADDR1);
735 EXPECT_FALSE(isAddressEqual(V6ADDR1, V4ADDR1));
736}
737
738TEST_F(ServerTest, IPv6ScopeId) {
739 DnsTlsServer s1(V6ADDR1), s2(V6ADDR1);
740 sockaddr_in6* addr1 = reinterpret_cast<sockaddr_in6*>(&s1.ss);
741 addr1->sin6_scope_id = 1;
742 sockaddr_in6* addr2 = reinterpret_cast<sockaddr_in6*>(&s2.ss);
743 addr2->sin6_scope_id = 2;
744 checkUnequal(s1, s2);
745 EXPECT_FALSE(isAddressEqual(s1, s2));
746
747 EXPECT_FALSE(s1.wasExplicitlyConfigured());
748 EXPECT_FALSE(s2.wasExplicitlyConfigured());
749}
750
751TEST_F(ServerTest, IPv6FlowInfo) {
752 DnsTlsServer s1(V6ADDR1), s2(V6ADDR1);
753 sockaddr_in6* addr1 = reinterpret_cast<sockaddr_in6*>(&s1.ss);
754 addr1->sin6_flowinfo = 1;
755 sockaddr_in6* addr2 = reinterpret_cast<sockaddr_in6*>(&s2.ss);
756 addr2->sin6_flowinfo = 2;
757 // All comparisons ignore flowinfo.
758 EXPECT_EQ(s1, s2);
759 EXPECT_TRUE(isAddressEqual(s1, s2));
760
761 EXPECT_FALSE(s1.wasExplicitlyConfigured());
762 EXPECT_FALSE(s2.wasExplicitlyConfigured());
763}
764
765TEST_F(ServerTest, Port) {
766 DnsTlsServer s1, s2;
767 parseServer("192.0.2.1", 853, &s1.ss);
768 parseServer("192.0.2.1", 854, &s2.ss);
769 checkUnequal(s1, s2);
770 EXPECT_TRUE(isAddressEqual(s1, s2));
771
772 DnsTlsServer s3, s4;
773 parseServer("2001:db8::1", 853, &s3.ss);
774 parseServer("2001:db8::1", 852, &s4.ss);
775 checkUnequal(s3, s4);
776 EXPECT_TRUE(isAddressEqual(s3, s4));
777
778 EXPECT_FALSE(s1.wasExplicitlyConfigured());
779 EXPECT_FALSE(s2.wasExplicitlyConfigured());
780}
781
782TEST_F(ServerTest, Name) {
783 DnsTlsServer s1(V4ADDR1), s2(V4ADDR1);
784 s1.name = SERVERNAME1;
785 checkUnequal(s1, s2);
786 s2.name = SERVERNAME2;
787 checkUnequal(s1, s2);
788 EXPECT_TRUE(isAddressEqual(s1, s2));
789
790 EXPECT_TRUE(s1.wasExplicitlyConfigured());
791 EXPECT_TRUE(s2.wasExplicitlyConfigured());
792}
793
Mike Yuc52739e2018-10-19 21:06:32 +0800794TEST(QueryMapTest, Basic) {
795 DnsTlsQueryMap map;
796
797 EXPECT_TRUE(map.empty());
798
799 bytevec q0 = make_query(999, SIZE);
800 bytevec q1 = make_query(888, SIZE);
801 bytevec q2 = make_query(777, SIZE);
802
803 auto f0 = map.recordQuery(makeSlice(q0));
804 auto f1 = map.recordQuery(makeSlice(q1));
805 auto f2 = map.recordQuery(makeSlice(q2));
806
807 // Check return values of recordQuery
808 EXPECT_EQ(0, f0->query.newId);
809 EXPECT_EQ(1, f1->query.newId);
810 EXPECT_EQ(2, f2->query.newId);
811
812 // Check side effects of recordQuery
813 EXPECT_FALSE(map.empty());
814
815 auto all = map.getAll();
816 EXPECT_EQ(3U, all.size());
817
818 EXPECT_EQ(0, all[0].newId);
819 EXPECT_EQ(1, all[1].newId);
820 EXPECT_EQ(2, all[2].newId);
821
822 EXPECT_EQ(makeSlice(q0), all[0].query);
823 EXPECT_EQ(makeSlice(q1), all[1].query);
824 EXPECT_EQ(makeSlice(q2), all[2].query);
825
826 bytevec a0 = make_query(0, SIZE);
827 bytevec a1 = make_query(1, SIZE);
828 bytevec a2 = make_query(2, SIZE);
829
830 // Return responses out of order
831 map.onResponse(a2);
832 map.onResponse(a0);
833 map.onResponse(a1);
834
835 EXPECT_TRUE(map.empty());
836
837 auto r0 = f0->result.get();
838 auto r1 = f1->result.get();
839 auto r2 = f2->result.get();
840
841 EXPECT_EQ(DnsTlsQueryMap::Response::success, r0.code);
842 EXPECT_EQ(DnsTlsQueryMap::Response::success, r1.code);
843 EXPECT_EQ(DnsTlsQueryMap::Response::success, r2.code);
844
845 const bytevec& d0 = r0.response;
846 const bytevec& d1 = r1.response;
847 const bytevec& d2 = r2.response;
848
849 // The ID should match the query
850 EXPECT_EQ(999, d0[0] << 8 | d0[1]);
851 EXPECT_EQ(888, d1[0] << 8 | d1[1]);
852 EXPECT_EQ(777, d2[0] << 8 | d2[1]);
853 // The body should match the answer
854 EXPECT_EQ(bytevec(a0.begin() + 2, a0.end()), bytevec(d0.begin() + 2, d0.end()));
855 EXPECT_EQ(bytevec(a1.begin() + 2, a1.end()), bytevec(d1.begin() + 2, d1.end()));
856 EXPECT_EQ(bytevec(a2.begin() + 2, a2.end()), bytevec(d2.begin() + 2, d2.end()));
857}
858
859TEST(QueryMapTest, FillHole) {
860 DnsTlsQueryMap map;
861 std::vector<std::unique_ptr<DnsTlsQueryMap::QueryFuture>> futures(UINT16_MAX + 1);
862 for (uint32_t i = 0; i <= UINT16_MAX; ++i) {
863 futures[i] = map.recordQuery(makeSlice(QUERY));
864 ASSERT_TRUE(futures[i]); // answers[i] should be nonnull.
865 EXPECT_EQ(i, futures[i]->query.newId);
866 }
867
868 // The map should now be full.
869 EXPECT_EQ(size_t(UINT16_MAX + 1), map.getAll().size());
870
871 // Trying to add another query should fail because the map is full.
872 EXPECT_FALSE(map.recordQuery(makeSlice(QUERY)));
873
874 // Send an answer to query 40000
875 auto answer = make_query(40000, SIZE);
876 map.onResponse(answer);
877 auto result = futures[40000]->result.get();
878 EXPECT_EQ(DnsTlsQueryMap::Response::success, result.code);
879 EXPECT_EQ(ID, result.response[0] << 8 | result.response[1]);
880 EXPECT_EQ(bytevec(answer.begin() + 2, answer.end()),
881 bytevec(result.response.begin() + 2, result.response.end()));
882
883 // There should now be room in the map.
884 EXPECT_EQ(size_t(UINT16_MAX), map.getAll().size());
885 auto f = map.recordQuery(makeSlice(QUERY));
886 ASSERT_TRUE(f);
887 EXPECT_EQ(40000, f->query.newId);
888
889 // The map should now be full again.
890 EXPECT_EQ(size_t(UINT16_MAX + 1), map.getAll().size());
891 EXPECT_FALSE(map.recordQuery(makeSlice(QUERY)));
892}
893
Ben Schwartz62176fd2019-01-22 17:32:17 -0500894class StubObserver : public IDnsTlsSocketObserver {
895 public:
896 bool closed = false;
897 void onResponse(std::vector<uint8_t>) override {}
898
899 void onClosed() override { closed = true; }
900};
901
902TEST(DnsTlsSocketTest, SlowDestructor) {
903 constexpr char tls_addr[] = "127.0.0.3";
904 constexpr char tls_port[] = "8530"; // High-numbered port so root isn't required.
905 // This test doesn't perform any queries, so the backend address can be invalid.
906 constexpr char backend_addr[] = "192.0.2.1";
907 constexpr char backend_port[] = "1";
908
909 test::DnsTlsFrontend tls(tls_addr, tls_port, backend_addr, backend_port);
910 ASSERT_TRUE(tls.startServer());
911
912 DnsTlsServer server;
913 parseServer(tls_addr, 8530, &server.ss);
914
915 StubObserver observer;
916 ASSERT_FALSE(observer.closed);
917 DnsTlsSessionCache cache;
918 auto socket = std::make_unique<DnsTlsSocket>(server, MARK, &observer, &cache);
919 ASSERT_TRUE(socket->initialize());
920
921 // Test: Time the socket destructor. This should be fast.
922 auto before = std::chrono::steady_clock::now();
923 socket.reset();
924 auto after = std::chrono::steady_clock::now();
925 auto delay = after - before;
chenbruceaff85842019-05-31 15:46:42 +0800926 LOG(DEBUG) << "Shutdown took " << delay / std::chrono::nanoseconds{1} << "ns";
Ben Schwartz62176fd2019-01-22 17:32:17 -0500927 EXPECT_TRUE(observer.closed);
928 // Shutdown should complete in milliseconds, but if the shutdown signal is lost
929 // it will wait for the timeout, which is expected to take 20seconds.
930 EXPECT_LT(delay, std::chrono::seconds{5});
931}
932
Mike Yuc52739e2018-10-19 21:06:32 +0800933} // end of namespace net
934} // end of namespace android