blob: 41a2c371964e3c2f32233d157d09bd117e15f8a7 [file] [log] [blame]
Mike Yuc52739e2018-10-19 21:06:32 +08001/*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
Ken Chen5471dca2019-04-15 15:25:35 +080017#define LOG_TAG "resolv"
Mike Yuc52739e2018-10-19 21:06:32 +080018#define LOG_NDEBUG 1 // Set to 0 to enable verbose debug logging
19
20#include <gtest/gtest.h>
21
Bernie Innocentiec4219b2019-01-30 11:16:36 +090022#include "DnsTlsDispatcher.h"
23#include "DnsTlsQueryMap.h"
24#include "DnsTlsServer.h"
25#include "DnsTlsSessionCache.h"
26#include "DnsTlsSocket.h"
27#include "DnsTlsTransport.h"
28#include "IDnsTlsSocket.h"
29#include "IDnsTlsSocketFactory.h"
30#include "IDnsTlsSocketObserver.h"
Mike Yuc52739e2018-10-19 21:06:32 +080031
Ben Schwartz62176fd2019-01-22 17:32:17 -050032#include "dns_responder/dns_tls_frontend.h"
33
Mike Yuc52739e2018-10-19 21:06:32 +080034#include <chrono>
35#include <arpa/inet.h>
36#include <android-base/macros.h>
37#include <netdutils/Slice.h>
38
39#include "log/log.h"
40
41namespace android {
42namespace net {
43
44using netdutils::Slice;
45using netdutils::makeSlice;
46
// Byte buffer used for every fake DNS query and response in these tests.
using bytevec = std::vector<uint8_t>;
48
49static void parseServer(const char* server, in_port_t port, sockaddr_storage* parsed) {
50 sockaddr_in* sin = reinterpret_cast<sockaddr_in*>(parsed);
51 if (inet_pton(AF_INET, server, &(sin->sin_addr)) == 1) {
52 // IPv4 parse succeeded, so it's IPv4
53 sin->sin_family = AF_INET;
54 sin->sin_port = htons(port);
55 return;
56 }
57 sockaddr_in6* sin6 = reinterpret_cast<sockaddr_in6*>(parsed);
58 if (inet_pton(AF_INET6, server, &(sin6->sin6_addr)) == 1){
59 // IPv6 parse succeeded, so it's IPv6.
60 sin6->sin6_family = AF_INET6;
61 sin6->sin6_port = htons(port);
62 return;
63 }
64 ALOGE("Failed to parse server address: %s", server);
65}
66
// Certificate fingerprints and TLS hostnames used to build DnsTlsServer
// fixtures. They are const so that no test can mutate state another test
// relies on. Being const at namespace scope also gives them internal linkage,
// so they cannot clash with same-named globals in other test sources linked
// into the same binary.
const std::vector<uint8_t> FINGERPRINT1 = { 1 };
const std::vector<uint8_t> FINGERPRINT2 = { 2 };

const std::string SERVERNAME1 = "dns.example.com";
const std::string SERVERNAME2 = "dns.example.org";
72
73// BaseTest just provides constants that are useful for the tests.
74class BaseTest : public ::testing::Test {
75 protected:
76 BaseTest() {
77 parseServer("192.0.2.1", 853, &V4ADDR1);
78 parseServer("192.0.2.2", 853, &V4ADDR2);
79 parseServer("2001:db8::1", 853, &V6ADDR1);
80 parseServer("2001:db8::2", 853, &V6ADDR2);
81
82 SERVER1 = DnsTlsServer(V4ADDR1);
83 SERVER1.fingerprints.insert(FINGERPRINT1);
84 SERVER1.name = SERVERNAME1;
85 }
86
87 sockaddr_storage V4ADDR1;
88 sockaddr_storage V4ADDR2;
89 sockaddr_storage V6ADDR1;
90 sockaddr_storage V6ADDR2;
91
92 DnsTlsServer SERVER1;
93};
94
// Builds a fake DNS query of exactly |size| bytes.
// - Bytes 0-1 hold |id| in big-endian order.
// - Each later byte i holds (id + i) truncated to 8 bits, so queries with
//   different IDs also have different bodies.
// A |size| smaller than the 2-byte ID header is tolerated: only the header
// bytes that fit are written, instead of indexing past the end of the vector.
std::vector<uint8_t> make_query(uint16_t id, size_t size) {
    std::vector<uint8_t> vec(size);
    if (size > 0) vec[0] = static_cast<uint8_t>(id >> 8);
    if (size > 1) vec[1] = static_cast<uint8_t>(id);
    // Arbitrarily fill the query body with unique data.
    for (size_t i = 2; i < size; ++i) {
        vec[i] = static_cast<uint8_t>(id + i);
    }
    return vec;
}
105
106// Query constants
107const unsigned MARK = 123;
108const uint16_t ID = 52;
109const uint16_t SIZE = 22;
110const bytevec QUERY = make_query(ID, SIZE);
111
112template <class T>
113class FakeSocketFactory : public IDnsTlsSocketFactory {
114 public:
115 FakeSocketFactory() {}
116 std::unique_ptr<IDnsTlsSocket> createDnsTlsSocket(
117 const DnsTlsServer& server ATTRIBUTE_UNUSED,
118 unsigned mark ATTRIBUTE_UNUSED,
119 IDnsTlsSocketObserver* observer,
120 DnsTlsSessionCache* cache ATTRIBUTE_UNUSED) override {
121 return std::make_unique<T>(observer);
122 }
123};
124
125bytevec make_echo(uint16_t id, const Slice query) {
126 bytevec response(query.size() + 2);
127 response[0] = id >> 8;
128 response[1] = id;
129 // Echo the query as the fake response.
130 memcpy(response.data() + 2, query.base(), query.size());
131 return response;
132}
133
134// Simplest possible fake server. This just echoes the query as the response.
135class FakeSocketEcho : public IDnsTlsSocket {
136 public:
137 explicit FakeSocketEcho(IDnsTlsSocketObserver* observer) : mObserver(observer) {}
138 bool query(uint16_t id, const Slice query) override {
139 // Return the response immediately (asynchronously).
140 std::thread(&IDnsTlsSocketObserver::onResponse, mObserver, make_echo(id, query)).detach();
141 return true;
142 }
143
144 private:
145 IDnsTlsSocketObserver* const mObserver;
146};
147
148class TransportTest : public BaseTest {};
149
150TEST_F(TransportTest, Query) {
151 FakeSocketFactory<FakeSocketEcho> factory;
152 DnsTlsTransport transport(SERVER1, MARK, &factory);
153 auto r = transport.query(makeSlice(QUERY)).get();
154
155 EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
156 EXPECT_EQ(QUERY, r.response);
157}
158
159// Fake Socket that echoes the observed query ID as the response body.
160class FakeSocketId : public IDnsTlsSocket {
161 public:
162 explicit FakeSocketId(IDnsTlsSocketObserver* observer) : mObserver(observer) {}
163 bool query(uint16_t id, const Slice query ATTRIBUTE_UNUSED) override {
164 // Return the response immediately (asynchronously).
165 bytevec response(4);
166 // Echo the ID in the header to match the response to the query.
167 // This will be overwritten by DnsTlsQueryMap.
168 response[0] = id >> 8;
169 response[1] = id;
170 // Echo the ID in the body, so that the test can verify which ID was used by
171 // DnsTlsQueryMap.
172 response[2] = id >> 8;
173 response[3] = id;
174 std::thread(&IDnsTlsSocketObserver::onResponse, mObserver, response).detach();
175 return true;
176 }
177
178 private:
179 IDnsTlsSocketObserver* const mObserver;
180};
181
182// Test that IDs are properly reused
183TEST_F(TransportTest, IdReuse) {
184 FakeSocketFactory<FakeSocketId> factory;
185 DnsTlsTransport transport(SERVER1, MARK, &factory);
186 for (int i = 0; i < 100; ++i) {
187 // Send a query.
188 std::future<DnsTlsServer::Result> f = transport.query(makeSlice(QUERY));
189 // Wait for the response.
190 DnsTlsServer::Result r = f.get();
191 EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
192
193 // All queries should have an observed ID of zero, because it is returned to the ID pool
194 // after each use.
195 EXPECT_EQ(0, (r.response[2] << 8) | r.response[3]);
196 }
197}
198
199// These queries might be handled in serial or parallel as they race the
200// responses.
201TEST_F(TransportTest, RacingQueries_10000) {
202 FakeSocketFactory<FakeSocketEcho> factory;
203 DnsTlsTransport transport(SERVER1, MARK, &factory);
204 std::vector<std::future<DnsTlsTransport::Result>> results;
205 // Fewer than 65536 queries to avoid ID exhaustion.
206 const int num_queries = 10000;
207 results.reserve(num_queries);
208 for (int i = 0; i < num_queries; ++i) {
209 results.push_back(transport.query(makeSlice(QUERY)));
210 }
211 for (auto& result : results) {
212 auto r = result.get();
213 EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
214 EXPECT_EQ(QUERY, r.response);
215 }
216}
217
218// A server that waits until sDelay queries are queued before responding.
219class FakeSocketDelay : public IDnsTlsSocket {
220 public:
221 explicit FakeSocketDelay(IDnsTlsSocketObserver* observer) : mObserver(observer) {}
222 ~FakeSocketDelay() { std::lock_guard guard(mLock); }
223 static size_t sDelay;
224 static bool sReverse;
225
226 bool query(uint16_t id, const Slice query) override {
227 ALOGV("FakeSocketDelay got query with ID %d", int(id));
228 std::lock_guard guard(mLock);
229 // Check for duplicate IDs.
230 EXPECT_EQ(0U, mIds.count(id));
231 mIds.insert(id);
232
233 // Store response.
234 mResponses.push_back(make_echo(id, query));
235
236 ALOGV("Up to %zu out of %zu queries", mResponses.size(), sDelay);
237 if (mResponses.size() == sDelay) {
238 std::thread(&FakeSocketDelay::sendResponses, this).detach();
239 }
240 return true;
241 }
242
243 private:
244 void sendResponses() {
245 std::lock_guard guard(mLock);
246 if (sReverse) {
247 std::reverse(std::begin(mResponses), std::end(mResponses));
248 }
249 for (auto& response : mResponses) {
250 mObserver->onResponse(response);
251 }
252 mIds.clear();
253 mResponses.clear();
254 }
255
256 std::mutex mLock;
257 IDnsTlsSocketObserver* const mObserver;
258 std::set<uint16_t> mIds GUARDED_BY(mLock);
259 std::vector<bytevec> mResponses GUARDED_BY(mLock);
260};
261
262size_t FakeSocketDelay::sDelay;
263bool FakeSocketDelay::sReverse;
264
265TEST_F(TransportTest, ParallelColliding) {
266 FakeSocketDelay::sDelay = 10;
267 FakeSocketDelay::sReverse = false;
268 FakeSocketFactory<FakeSocketDelay> factory;
269 DnsTlsTransport transport(SERVER1, MARK, &factory);
270 std::vector<std::future<DnsTlsTransport::Result>> results;
271 // Fewer than 65536 queries to avoid ID exhaustion.
272 results.reserve(FakeSocketDelay::sDelay);
273 for (size_t i = 0; i < FakeSocketDelay::sDelay; ++i) {
274 results.push_back(transport.query(makeSlice(QUERY)));
275 }
276 for (auto& result : results) {
277 auto r = result.get();
278 EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
279 EXPECT_EQ(QUERY, r.response);
280 }
281}
282
283TEST_F(TransportTest, ParallelColliding_Max) {
284 FakeSocketDelay::sDelay = 65536;
285 FakeSocketDelay::sReverse = false;
286 FakeSocketFactory<FakeSocketDelay> factory;
287 DnsTlsTransport transport(SERVER1, MARK, &factory);
288 std::vector<std::future<DnsTlsTransport::Result>> results;
289 // Exactly 65536 queries should still be possible in parallel,
290 // even if they all have the same original ID.
291 results.reserve(FakeSocketDelay::sDelay);
292 for (size_t i = 0; i < FakeSocketDelay::sDelay; ++i) {
293 results.push_back(transport.query(makeSlice(QUERY)));
294 }
295 for (auto& result : results) {
296 auto r = result.get();
297 EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
298 EXPECT_EQ(QUERY, r.response);
299 }
300}
301
302TEST_F(TransportTest, ParallelUnique) {
303 FakeSocketDelay::sDelay = 10;
304 FakeSocketDelay::sReverse = false;
305 FakeSocketFactory<FakeSocketDelay> factory;
306 DnsTlsTransport transport(SERVER1, MARK, &factory);
307 std::vector<bytevec> queries(FakeSocketDelay::sDelay);
308 std::vector<std::future<DnsTlsTransport::Result>> results;
309 results.reserve(FakeSocketDelay::sDelay);
310 for (size_t i = 0; i < FakeSocketDelay::sDelay; ++i) {
311 queries[i] = make_query(i, SIZE);
312 results.push_back(transport.query(makeSlice(queries[i])));
313 }
314 for (size_t i = 0 ; i < FakeSocketDelay::sDelay; ++i) {
315 auto r = results[i].get();
316 EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
317 EXPECT_EQ(queries[i], r.response);
318 }
319}
320
321TEST_F(TransportTest, ParallelUnique_Max) {
322 FakeSocketDelay::sDelay = 65536;
323 FakeSocketDelay::sReverse = false;
324 FakeSocketFactory<FakeSocketDelay> factory;
325 DnsTlsTransport transport(SERVER1, MARK, &factory);
326 std::vector<bytevec> queries(FakeSocketDelay::sDelay);
327 std::vector<std::future<DnsTlsTransport::Result>> results;
328 // Exactly 65536 queries should still be possible in parallel,
329 // and they should all be mapped correctly back to the original ID.
330 results.reserve(FakeSocketDelay::sDelay);
331 for (size_t i = 0; i < FakeSocketDelay::sDelay; ++i) {
332 queries[i] = make_query(i, SIZE);
333 results.push_back(transport.query(makeSlice(queries[i])));
334 }
335 for (size_t i = 0 ; i < FakeSocketDelay::sDelay; ++i) {
336 auto r = results[i].get();
337 EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
338 EXPECT_EQ(queries[i], r.response);
339 }
340}
341
342TEST_F(TransportTest, IdExhaustion) {
343 const int num_queries = 65536;
344 // A delay of 65537 is unreachable, because the maximum number
345 // of outstanding queries is 65536.
346 FakeSocketDelay::sDelay = num_queries + 1;
347 FakeSocketDelay::sReverse = false;
348 FakeSocketFactory<FakeSocketDelay> factory;
349 DnsTlsTransport transport(SERVER1, MARK, &factory);
350 std::vector<std::future<DnsTlsTransport::Result>> results;
351 // Issue the maximum number of queries.
352 results.reserve(num_queries);
353 for (int i = 0; i < num_queries; ++i) {
354 results.push_back(transport.query(makeSlice(QUERY)));
355 }
356
357 // The ID space is now full, so subsequent queries should fail immediately.
358 auto r = transport.query(makeSlice(QUERY)).get();
359 EXPECT_EQ(DnsTlsTransport::Response::internal_error, r.code);
360 EXPECT_TRUE(r.response.empty());
361
362 for (auto& result : results) {
363 // All other queries should remain outstanding.
364 EXPECT_EQ(std::future_status::timeout,
365 result.wait_for(std::chrono::duration<int>::zero()));
366 }
367}
368
369// Responses can come back from the server in any order. This should have no
370// effect on Transport's observed behavior.
371TEST_F(TransportTest, ReverseOrder) {
372 FakeSocketDelay::sDelay = 10;
373 FakeSocketDelay::sReverse = true;
374 FakeSocketFactory<FakeSocketDelay> factory;
375 DnsTlsTransport transport(SERVER1, MARK, &factory);
376 std::vector<bytevec> queries(FakeSocketDelay::sDelay);
377 std::vector<std::future<DnsTlsTransport::Result>> results;
378 results.reserve(FakeSocketDelay::sDelay);
379 for (size_t i = 0; i < FakeSocketDelay::sDelay; ++i) {
380 queries[i] = make_query(i, SIZE);
381 results.push_back(transport.query(makeSlice(queries[i])));
382 }
383 for (size_t i = 0 ; i < FakeSocketDelay::sDelay; ++i) {
384 auto r = results[i].get();
385 EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
386 EXPECT_EQ(queries[i], r.response);
387 }
388}
389
390TEST_F(TransportTest, ReverseOrder_Max) {
391 FakeSocketDelay::sDelay = 65536;
392 FakeSocketDelay::sReverse = true;
393 FakeSocketFactory<FakeSocketDelay> factory;
394 DnsTlsTransport transport(SERVER1, MARK, &factory);
395 std::vector<bytevec> queries(FakeSocketDelay::sDelay);
396 std::vector<std::future<DnsTlsTransport::Result>> results;
397 results.reserve(FakeSocketDelay::sDelay);
398 for (size_t i = 0; i < FakeSocketDelay::sDelay; ++i) {
399 queries[i] = make_query(i, SIZE);
400 results.push_back(transport.query(makeSlice(queries[i])));
401 }
402 for (size_t i = 0 ; i < FakeSocketDelay::sDelay; ++i) {
403 auto r = results[i].get();
404 EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
405 EXPECT_EQ(queries[i], r.response);
406 }
407}
408
409// Returning null from the factory indicates a connection failure.
410class NullSocketFactory : public IDnsTlsSocketFactory {
411 public:
412 NullSocketFactory() {}
413 std::unique_ptr<IDnsTlsSocket> createDnsTlsSocket(
414 const DnsTlsServer& server ATTRIBUTE_UNUSED,
415 unsigned mark ATTRIBUTE_UNUSED,
416 IDnsTlsSocketObserver* observer ATTRIBUTE_UNUSED,
417 DnsTlsSessionCache* cache ATTRIBUTE_UNUSED) override {
418 return nullptr;
419 }
420};
421
422TEST_F(TransportTest, ConnectFail) {
423 NullSocketFactory factory;
424 DnsTlsTransport transport(SERVER1, MARK, &factory);
425 auto r = transport.query(makeSlice(QUERY)).get();
426
427 EXPECT_EQ(DnsTlsTransport::Response::network_error, r.code);
428 EXPECT_TRUE(r.response.empty());
429}
430
431// Simulate a socket that connects but then immediately receives a server
432// close notification.
433class FakeSocketClose : public IDnsTlsSocket {
434 public:
435 explicit FakeSocketClose(IDnsTlsSocketObserver* observer)
436 : mCloser(&IDnsTlsSocketObserver::onClosed, observer) {}
437 ~FakeSocketClose() { mCloser.join(); }
438 bool query(uint16_t id ATTRIBUTE_UNUSED,
439 const Slice query ATTRIBUTE_UNUSED) override {
440 return true;
441 }
442
443 private:
444 std::thread mCloser;
445};
446
447TEST_F(TransportTest, CloseRetryFail) {
448 FakeSocketFactory<FakeSocketClose> factory;
449 DnsTlsTransport transport(SERVER1, MARK, &factory);
450 auto r = transport.query(makeSlice(QUERY)).get();
451
452 EXPECT_EQ(DnsTlsTransport::Response::network_error, r.code);
453 EXPECT_TRUE(r.response.empty());
454}
455
456// Simulate a server that occasionally closes the connection and silently
457// drops some queries.
458class FakeSocketLimited : public IDnsTlsSocket {
459 public:
460 static int sLimit; // Number of queries to answer per socket.
461 static size_t sMaxSize; // Silently discard queries greater than this size.
462 explicit FakeSocketLimited(IDnsTlsSocketObserver* observer)
463 : mObserver(observer), mQueries(0) {}
464 ~FakeSocketLimited() {
465 {
466 ALOGV("~FakeSocketLimited acquiring mLock");
467 std::lock_guard guard(mLock);
468 ALOGV("~FakeSocketLimited acquired mLock");
469 for (auto& thread : mThreads) {
470 ALOGV("~FakeSocketLimited joining response thread");
471 thread.join();
472 ALOGV("~FakeSocketLimited joined response thread");
473 }
474 mThreads.clear();
475 }
476
477 if (mCloser) {
478 ALOGV("~FakeSocketLimited joining closer thread");
479 mCloser->join();
480 ALOGV("~FakeSocketLimited joined closer thread");
481 }
482 }
483 bool query(uint16_t id, const Slice query) override {
484 ALOGV("FakeSocketLimited::query acquiring mLock");
485 std::lock_guard guard(mLock);
486 ALOGV("FakeSocketLimited::query acquired mLock");
487 ++mQueries;
488
489 if (mQueries <= sLimit) {
490 ALOGV("size %zu vs. limit of %zu", query.size(), sMaxSize);
491 if (query.size() <= sMaxSize) {
492 // Return the response immediately (asynchronously).
493 mThreads.emplace_back(&IDnsTlsSocketObserver::onResponse, mObserver, make_echo(id, query));
494 }
495 }
496 if (mQueries == sLimit) {
497 mCloser = std::make_unique<std::thread>(&FakeSocketLimited::sendClose, this);
498 }
499 return mQueries <= sLimit;
500 }
501
502 private:
503 void sendClose() {
504 {
505 ALOGV("FakeSocketLimited::sendClose acquiring mLock");
506 std::lock_guard guard(mLock);
507 ALOGV("FakeSocketLimited::sendClose acquired mLock");
508 for (auto& thread : mThreads) {
509 ALOGV("FakeSocketLimited::sendClose joining response thread");
510 thread.join();
511 ALOGV("FakeSocketLimited::sendClose joined response thread");
512 }
513 mThreads.clear();
514 }
515 mObserver->onClosed();
516 }
517 std::mutex mLock;
518 IDnsTlsSocketObserver* const mObserver;
519 int mQueries GUARDED_BY(mLock);
520 std::vector<std::thread> mThreads GUARDED_BY(mLock);
521 std::unique_ptr<std::thread> mCloser GUARDED_BY(mLock);
522};
523
524int FakeSocketLimited::sLimit;
525size_t FakeSocketLimited::sMaxSize;
526
527TEST_F(TransportTest, SilentDrop) {
528 FakeSocketLimited::sLimit = 10; // Close the socket after 10 queries.
529 FakeSocketLimited::sMaxSize = 0; // Silently drop all queries
530 FakeSocketFactory<FakeSocketLimited> factory;
531 DnsTlsTransport transport(SERVER1, MARK, &factory);
532
533 // Queue up 10 queries. They will all be ignored, and after the 10th,
534 // the socket will close. Transport will retry them all, until they
535 // all hit the retry limit and expire.
536 std::vector<std::future<DnsTlsTransport::Result>> results;
537 results.reserve(FakeSocketLimited::sLimit);
538 for (int i = 0; i < FakeSocketLimited::sLimit; ++i) {
539 results.push_back(transport.query(makeSlice(QUERY)));
540 }
541 for (auto& result : results) {
542 auto r = result.get();
543 EXPECT_EQ(DnsTlsTransport::Response::network_error, r.code);
544 EXPECT_TRUE(r.response.empty());
545 }
546}
547
548TEST_F(TransportTest, PartialDrop) {
549 FakeSocketLimited::sLimit = 10; // Close the socket after 10 queries.
550 FakeSocketLimited::sMaxSize = SIZE - 2; // Silently drop "long" queries
551 FakeSocketFactory<FakeSocketLimited> factory;
552 DnsTlsTransport transport(SERVER1, MARK, &factory);
553
554 // Queue up 100 queries, alternating "short" which will be served and "long"
555 // which will be dropped.
556 const int num_queries = 10 * FakeSocketLimited::sLimit;
557 std::vector<bytevec> queries(num_queries);
558 std::vector<std::future<DnsTlsTransport::Result>> results;
559 results.reserve(num_queries);
560 for (int i = 0; i < num_queries; ++i) {
561 queries[i] = make_query(i, SIZE + (i % 2));
562 results.push_back(transport.query(makeSlice(queries[i])));
563 }
564 // Just check the short queries, which are at the even indices.
565 for (int i = 0; i < num_queries; i += 2) {
566 auto r = results[i].get();
567 EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
568 EXPECT_EQ(queries[i], r.response);
569 }
570}
571
572// Simulate a malfunctioning server that injects extra miscellaneous
573// responses to queries that were not asked. This will cause wrong answers but
574// must not crash the Transport.
575class FakeSocketGarbage : public IDnsTlsSocket {
576 public:
577 explicit FakeSocketGarbage(IDnsTlsSocketObserver* observer) : mObserver(observer) {
578 // Inject a garbage event.
579 mThreads.emplace_back(&IDnsTlsSocketObserver::onResponse, mObserver, make_query(ID + 1, SIZE));
580 }
581 ~FakeSocketGarbage() {
582 std::lock_guard guard(mLock);
583 for (auto& thread : mThreads) {
584 thread.join();
585 }
586 }
587 bool query(uint16_t id, const Slice query) override {
588 std::lock_guard guard(mLock);
589 // Return the response twice.
590 auto echo = make_echo(id, query);
591 mThreads.emplace_back(&IDnsTlsSocketObserver::onResponse, mObserver, echo);
592 mThreads.emplace_back(&IDnsTlsSocketObserver::onResponse, mObserver, echo);
593 // Also return some other garbage
594 mThreads.emplace_back(&IDnsTlsSocketObserver::onResponse, mObserver, make_query(id + 1, query.size() + 2));
595 return true;
596 }
597
598 private:
599 std::mutex mLock;
600 std::vector<std::thread> mThreads GUARDED_BY(mLock);
601 IDnsTlsSocketObserver* const mObserver;
602};
603
604TEST_F(TransportTest, IgnoringGarbage) {
605 FakeSocketFactory<FakeSocketGarbage> factory;
606 DnsTlsTransport transport(SERVER1, MARK, &factory);
607 for (int i = 0; i < 10; ++i) {
608 auto r = transport.query(makeSlice(QUERY)).get();
609
610 EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
611 // Don't check the response because this server is malfunctioning.
612 }
613}
614
615// Dispatcher tests
616class DispatcherTest : public BaseTest {};
617
618TEST_F(DispatcherTest, Query) {
619 bytevec ans(4096);
620 int resplen = 0;
621
622 auto factory = std::make_unique<FakeSocketFactory<FakeSocketEcho>>();
623 DnsTlsDispatcher dispatcher(std::move(factory));
624 auto r = dispatcher.query(SERVER1, MARK, makeSlice(QUERY),
625 makeSlice(ans), &resplen);
626
627 EXPECT_EQ(DnsTlsTransport::Response::success, r);
628 EXPECT_EQ(int(QUERY.size()), resplen);
629 ans.resize(resplen);
630 EXPECT_EQ(QUERY, ans);
631}
632
633TEST_F(DispatcherTest, AnswerTooLarge) {
634 bytevec ans(SIZE - 1); // Too small to hold the answer
635 int resplen = 0;
636
637 auto factory = std::make_unique<FakeSocketFactory<FakeSocketEcho>>();
638 DnsTlsDispatcher dispatcher(std::move(factory));
639 auto r = dispatcher.query(SERVER1, MARK, makeSlice(QUERY),
640 makeSlice(ans), &resplen);
641
642 EXPECT_EQ(DnsTlsTransport::Response::limit_error, r);
643}
644
645template<class T>
646class TrackingFakeSocketFactory : public IDnsTlsSocketFactory {
647 public:
648 TrackingFakeSocketFactory() {}
649 std::unique_ptr<IDnsTlsSocket> createDnsTlsSocket(
650 const DnsTlsServer& server,
651 unsigned mark,
652 IDnsTlsSocketObserver* observer,
653 DnsTlsSessionCache* cache ATTRIBUTE_UNUSED) override {
654 std::lock_guard guard(mLock);
655 keys.emplace(mark, server);
656 return std::make_unique<T>(observer);
657 }
658 std::multiset<std::pair<unsigned, DnsTlsServer>> keys;
659
660 private:
661 std::mutex mLock;
662};
663
664TEST_F(DispatcherTest, Dispatching) {
665 FakeSocketDelay::sDelay = 5;
666 FakeSocketDelay::sReverse = true;
667 auto factory = std::make_unique<TrackingFakeSocketFactory<FakeSocketDelay>>();
668 auto* weak_factory = factory.get(); // Valid as long as dispatcher is in scope.
669 DnsTlsDispatcher dispatcher(std::move(factory));
670
671 // Populate a vector of two servers and two socket marks, four combinations
672 // in total.
673 std::vector<std::pair<unsigned, DnsTlsServer>> keys;
674 keys.emplace_back(MARK, SERVER1);
675 keys.emplace_back(MARK + 1, SERVER1);
676 keys.emplace_back(MARK, V4ADDR2);
677 keys.emplace_back(MARK + 1, V4ADDR2);
678
679 // Do several queries on each server. They should all succeed.
680 std::vector<std::thread> threads;
681 for (size_t i = 0; i < FakeSocketDelay::sDelay * keys.size(); ++i) {
682 auto key = keys[i % keys.size()];
683 threads.emplace_back([key, i] (DnsTlsDispatcher* dispatcher) {
684 auto q = make_query(i, SIZE);
685 bytevec ans(4096);
686 int resplen = 0;
687 unsigned mark = key.first;
688 const DnsTlsServer& server = key.second;
689 auto r = dispatcher->query(server, mark, makeSlice(q),
690 makeSlice(ans), &resplen);
691 EXPECT_EQ(DnsTlsTransport::Response::success, r);
692 EXPECT_EQ(int(q.size()), resplen);
693 ans.resize(resplen);
694 EXPECT_EQ(q, ans);
695 }, &dispatcher);
696 }
697 for (auto& thread : threads) {
698 thread.join();
699 }
700 // We expect that the factory created one socket for each key.
701 EXPECT_EQ(keys.size(), weak_factory->keys.size());
702 for (auto& key : keys) {
703 EXPECT_EQ(1U, weak_factory->keys.count(key));
704 }
705}
706
707// Check DnsTlsServer's comparison logic.
708AddressComparator ADDRESS_COMPARATOR;
709bool isAddressEqual(const DnsTlsServer& s1, const DnsTlsServer& s2) {
710 bool cmp1 = ADDRESS_COMPARATOR(s1, s2);
711 bool cmp2 = ADDRESS_COMPARATOR(s2, s1);
712 EXPECT_FALSE(cmp1 && cmp2);
713 return !cmp1 && !cmp2;
714}
715
716void checkUnequal(const DnsTlsServer& s1, const DnsTlsServer& s2) {
717 EXPECT_TRUE(s1 == s1);
718 EXPECT_TRUE(s2 == s2);
719 EXPECT_TRUE(isAddressEqual(s1, s1));
720 EXPECT_TRUE(isAddressEqual(s2, s2));
721
722 EXPECT_TRUE(s1 < s2 ^ s2 < s1);
723 EXPECT_FALSE(s1 == s2);
724 EXPECT_FALSE(s2 == s1);
725}
726
727class ServerTest : public BaseTest {};
728
729TEST_F(ServerTest, IPv4) {
730 checkUnequal(V4ADDR1, V4ADDR2);
731 EXPECT_FALSE(isAddressEqual(V4ADDR1, V4ADDR2));
732}
733
734TEST_F(ServerTest, IPv6) {
735 checkUnequal(V6ADDR1, V6ADDR2);
736 EXPECT_FALSE(isAddressEqual(V6ADDR1, V6ADDR2));
737}
738
739TEST_F(ServerTest, MixedAddressFamily) {
740 checkUnequal(V6ADDR1, V4ADDR1);
741 EXPECT_FALSE(isAddressEqual(V6ADDR1, V4ADDR1));
742}
743
744TEST_F(ServerTest, IPv6ScopeId) {
745 DnsTlsServer s1(V6ADDR1), s2(V6ADDR1);
746 sockaddr_in6* addr1 = reinterpret_cast<sockaddr_in6*>(&s1.ss);
747 addr1->sin6_scope_id = 1;
748 sockaddr_in6* addr2 = reinterpret_cast<sockaddr_in6*>(&s2.ss);
749 addr2->sin6_scope_id = 2;
750 checkUnequal(s1, s2);
751 EXPECT_FALSE(isAddressEqual(s1, s2));
752
753 EXPECT_FALSE(s1.wasExplicitlyConfigured());
754 EXPECT_FALSE(s2.wasExplicitlyConfigured());
755}
756
757TEST_F(ServerTest, IPv6FlowInfo) {
758 DnsTlsServer s1(V6ADDR1), s2(V6ADDR1);
759 sockaddr_in6* addr1 = reinterpret_cast<sockaddr_in6*>(&s1.ss);
760 addr1->sin6_flowinfo = 1;
761 sockaddr_in6* addr2 = reinterpret_cast<sockaddr_in6*>(&s2.ss);
762 addr2->sin6_flowinfo = 2;
763 // All comparisons ignore flowinfo.
764 EXPECT_EQ(s1, s2);
765 EXPECT_TRUE(isAddressEqual(s1, s2));
766
767 EXPECT_FALSE(s1.wasExplicitlyConfigured());
768 EXPECT_FALSE(s2.wasExplicitlyConfigured());
769}
770
771TEST_F(ServerTest, Port) {
772 DnsTlsServer s1, s2;
773 parseServer("192.0.2.1", 853, &s1.ss);
774 parseServer("192.0.2.1", 854, &s2.ss);
775 checkUnequal(s1, s2);
776 EXPECT_TRUE(isAddressEqual(s1, s2));
777
778 DnsTlsServer s3, s4;
779 parseServer("2001:db8::1", 853, &s3.ss);
780 parseServer("2001:db8::1", 852, &s4.ss);
781 checkUnequal(s3, s4);
782 EXPECT_TRUE(isAddressEqual(s3, s4));
783
784 EXPECT_FALSE(s1.wasExplicitlyConfigured());
785 EXPECT_FALSE(s2.wasExplicitlyConfigured());
786}
787
788TEST_F(ServerTest, Name) {
789 DnsTlsServer s1(V4ADDR1), s2(V4ADDR1);
790 s1.name = SERVERNAME1;
791 checkUnequal(s1, s2);
792 s2.name = SERVERNAME2;
793 checkUnequal(s1, s2);
794 EXPECT_TRUE(isAddressEqual(s1, s2));
795
796 EXPECT_TRUE(s1.wasExplicitlyConfigured());
797 EXPECT_TRUE(s2.wasExplicitlyConfigured());
798}
799
800TEST_F(ServerTest, Fingerprint) {
801 DnsTlsServer s1(V4ADDR1), s2(V4ADDR1);
802
803 s1.fingerprints.insert(FINGERPRINT1);
804 checkUnequal(s1, s2);
805 EXPECT_TRUE(isAddressEqual(s1, s2));
806
807 s2.fingerprints.insert(FINGERPRINT2);
808 checkUnequal(s1, s2);
809 EXPECT_TRUE(isAddressEqual(s1, s2));
810
811 s2.fingerprints.insert(FINGERPRINT1);
812 checkUnequal(s1, s2);
813 EXPECT_TRUE(isAddressEqual(s1, s2));
814
815 s1.fingerprints.insert(FINGERPRINT2);
816 EXPECT_EQ(s1, s2);
817 EXPECT_TRUE(isAddressEqual(s1, s2));
818
819 EXPECT_TRUE(s1.wasExplicitlyConfigured());
820 EXPECT_TRUE(s2.wasExplicitlyConfigured());
821}
822
823TEST(QueryMapTest, Basic) {
824 DnsTlsQueryMap map;
825
826 EXPECT_TRUE(map.empty());
827
828 bytevec q0 = make_query(999, SIZE);
829 bytevec q1 = make_query(888, SIZE);
830 bytevec q2 = make_query(777, SIZE);
831
832 auto f0 = map.recordQuery(makeSlice(q0));
833 auto f1 = map.recordQuery(makeSlice(q1));
834 auto f2 = map.recordQuery(makeSlice(q2));
835
836 // Check return values of recordQuery
837 EXPECT_EQ(0, f0->query.newId);
838 EXPECT_EQ(1, f1->query.newId);
839 EXPECT_EQ(2, f2->query.newId);
840
841 // Check side effects of recordQuery
842 EXPECT_FALSE(map.empty());
843
844 auto all = map.getAll();
845 EXPECT_EQ(3U, all.size());
846
847 EXPECT_EQ(0, all[0].newId);
848 EXPECT_EQ(1, all[1].newId);
849 EXPECT_EQ(2, all[2].newId);
850
851 EXPECT_EQ(makeSlice(q0), all[0].query);
852 EXPECT_EQ(makeSlice(q1), all[1].query);
853 EXPECT_EQ(makeSlice(q2), all[2].query);
854
855 bytevec a0 = make_query(0, SIZE);
856 bytevec a1 = make_query(1, SIZE);
857 bytevec a2 = make_query(2, SIZE);
858
859 // Return responses out of order
860 map.onResponse(a2);
861 map.onResponse(a0);
862 map.onResponse(a1);
863
864 EXPECT_TRUE(map.empty());
865
866 auto r0 = f0->result.get();
867 auto r1 = f1->result.get();
868 auto r2 = f2->result.get();
869
870 EXPECT_EQ(DnsTlsQueryMap::Response::success, r0.code);
871 EXPECT_EQ(DnsTlsQueryMap::Response::success, r1.code);
872 EXPECT_EQ(DnsTlsQueryMap::Response::success, r2.code);
873
874 const bytevec& d0 = r0.response;
875 const bytevec& d1 = r1.response;
876 const bytevec& d2 = r2.response;
877
878 // The ID should match the query
879 EXPECT_EQ(999, d0[0] << 8 | d0[1]);
880 EXPECT_EQ(888, d1[0] << 8 | d1[1]);
881 EXPECT_EQ(777, d2[0] << 8 | d2[1]);
882 // The body should match the answer
883 EXPECT_EQ(bytevec(a0.begin() + 2, a0.end()), bytevec(d0.begin() + 2, d0.end()));
884 EXPECT_EQ(bytevec(a1.begin() + 2, a1.end()), bytevec(d1.begin() + 2, d1.end()));
885 EXPECT_EQ(bytevec(a2.begin() + 2, a2.end()), bytevec(d2.begin() + 2, d2.end()));
886}
887
888TEST(QueryMapTest, FillHole) {
889 DnsTlsQueryMap map;
890 std::vector<std::unique_ptr<DnsTlsQueryMap::QueryFuture>> futures(UINT16_MAX + 1);
891 for (uint32_t i = 0; i <= UINT16_MAX; ++i) {
892 futures[i] = map.recordQuery(makeSlice(QUERY));
893 ASSERT_TRUE(futures[i]); // answers[i] should be nonnull.
894 EXPECT_EQ(i, futures[i]->query.newId);
895 }
896
897 // The map should now be full.
898 EXPECT_EQ(size_t(UINT16_MAX + 1), map.getAll().size());
899
900 // Trying to add another query should fail because the map is full.
901 EXPECT_FALSE(map.recordQuery(makeSlice(QUERY)));
902
903 // Send an answer to query 40000
904 auto answer = make_query(40000, SIZE);
905 map.onResponse(answer);
906 auto result = futures[40000]->result.get();
907 EXPECT_EQ(DnsTlsQueryMap::Response::success, result.code);
908 EXPECT_EQ(ID, result.response[0] << 8 | result.response[1]);
909 EXPECT_EQ(bytevec(answer.begin() + 2, answer.end()),
910 bytevec(result.response.begin() + 2, result.response.end()));
911
912 // There should now be room in the map.
913 EXPECT_EQ(size_t(UINT16_MAX), map.getAll().size());
914 auto f = map.recordQuery(makeSlice(QUERY));
915 ASSERT_TRUE(f);
916 EXPECT_EQ(40000, f->query.newId);
917
918 // The map should now be full again.
919 EXPECT_EQ(size_t(UINT16_MAX + 1), map.getAll().size());
920 EXPECT_FALSE(map.recordQuery(makeSlice(QUERY)));
921}
922
Ben Schwartz62176fd2019-01-22 17:32:17 -0500923class StubObserver : public IDnsTlsSocketObserver {
924 public:
925 bool closed = false;
926 void onResponse(std::vector<uint8_t>) override {}
927
928 void onClosed() override { closed = true; }
929};
930
931TEST(DnsTlsSocketTest, SlowDestructor) {
932 constexpr char tls_addr[] = "127.0.0.3";
933 constexpr char tls_port[] = "8530"; // High-numbered port so root isn't required.
934 // This test doesn't perform any queries, so the backend address can be invalid.
935 constexpr char backend_addr[] = "192.0.2.1";
936 constexpr char backend_port[] = "1";
937
938 test::DnsTlsFrontend tls(tls_addr, tls_port, backend_addr, backend_port);
939 ASSERT_TRUE(tls.startServer());
940
941 DnsTlsServer server;
942 parseServer(tls_addr, 8530, &server.ss);
943
944 StubObserver observer;
945 ASSERT_FALSE(observer.closed);
946 DnsTlsSessionCache cache;
947 auto socket = std::make_unique<DnsTlsSocket>(server, MARK, &observer, &cache);
948 ASSERT_TRUE(socket->initialize());
949
950 // Test: Time the socket destructor. This should be fast.
951 auto before = std::chrono::steady_clock::now();
952 socket.reset();
953 auto after = std::chrono::steady_clock::now();
954 auto delay = after - before;
955 ALOGV("Shutdown took %lld ns", delay / std::chrono::nanoseconds{1});
956 EXPECT_TRUE(observer.closed);
957 // Shutdown should complete in milliseconds, but if the shutdown signal is lost
958 // it will wait for the timeout, which is expected to take 20seconds.
959 EXPECT_LT(delay, std::chrono::seconds{5});
960}
961
Mike Yuc52739e2018-10-19 21:06:32 +0800962} // end of namespace net
963} // end of namespace android