blob: bb5bfe5680cab4224a913bb809c8a752c7fb296f [file] [log] [blame]
/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
17#define LOG_TAG "dns_tls_test"
18
19#include <gtest/gtest.h>
20
21#include "dns/DnsTlsDispatcher.h"
Ben Schwartz33860762017-10-25 14:41:02 -040022#include "dns/DnsTlsQueryMap.h"
Ben Schwartzded1b702017-10-25 14:41:02 -040023#include "dns/DnsTlsServer.h"
24#include "dns/DnsTlsSessionCache.h"
25#include "dns/DnsTlsSocket.h"
26#include "dns/DnsTlsTransport.h"
27#include "dns/IDnsTlsSocket.h"
28#include "dns/IDnsTlsSocketFactory.h"
Ben Schwartz33860762017-10-25 14:41:02 -040029#include "dns/IDnsTlsSocketObserver.h"
Ben Schwartzded1b702017-10-25 14:41:02 -040030
31#include <chrono>
32#include <arpa/inet.h>
33#include <android-base/macros.h>
34#include <netdutils/Slice.h>
35
36#include "log/log.h"
37
38namespace android {
39namespace net {
40
41using netdutils::Slice;
42using netdutils::makeSlice;
43
44typedef std::vector<uint8_t> bytevec;
45
46static void parseServer(const char* server, in_port_t port, sockaddr_storage* parsed) {
47 sockaddr_in* sin = reinterpret_cast<sockaddr_in*>(parsed);
48 if (inet_pton(AF_INET, server, &(sin->sin_addr)) == 1) {
49 // IPv4 parse succeeded, so it's IPv4
50 sin->sin_family = AF_INET;
51 sin->sin_port = htons(port);
52 return;
53 }
54 sockaddr_in6* sin6 = reinterpret_cast<sockaddr_in6*>(parsed);
55 if (inet_pton(AF_INET6, server, &(sin6->sin6_addr)) == 1){
56 // IPv6 parse succeeded, so it's IPv6.
57 sin6->sin6_family = AF_INET6;
58 sin6->sin6_port = htons(port);
59 return;
60 }
61 ALOGE("Failed to parse server address: %s", server);
62}
63
64bytevec FINGERPRINT1 = { 1 };
Ben Schwartze5595152017-10-25 14:41:02 -040065bytevec FINGERPRINT2 = { 2 };
Ben Schwartzded1b702017-10-25 14:41:02 -040066
67std::string SERVERNAME1 = "dns.example.com";
Ben Schwartze5595152017-10-25 14:41:02 -040068std::string SERVERNAME2 = "dns.example.org";
Ben Schwartzded1b702017-10-25 14:41:02 -040069
70// BaseTest just provides constants that are useful for the tests.
71class BaseTest : public ::testing::Test {
72protected:
73 BaseTest() {
74 parseServer("192.0.2.1", 853, &V4ADDR1);
75 parseServer("192.0.2.2", 853, &V4ADDR2);
Ben Schwartze5595152017-10-25 14:41:02 -040076 parseServer("2001:db8::1", 853, &V6ADDR1);
77 parseServer("2001:db8::2", 853, &V6ADDR2);
Ben Schwartzded1b702017-10-25 14:41:02 -040078
79 SERVER1 = DnsTlsServer(V4ADDR1);
80 SERVER1.fingerprints.insert(FINGERPRINT1);
81 SERVER1.name = SERVERNAME1;
82 }
83
84 sockaddr_storage V4ADDR1;
85 sockaddr_storage V4ADDR2;
Ben Schwartze5595152017-10-25 14:41:02 -040086 sockaddr_storage V6ADDR1;
87 sockaddr_storage V6ADDR2;
Ben Schwartzded1b702017-10-25 14:41:02 -040088
89 DnsTlsServer SERVER1;
90};
91
92bytevec make_query(uint16_t id, size_t size) {
93 bytevec vec(size);
94 vec[0] = id >> 8;
95 vec[1] = id;
96 // Arbitrarily fill the query body with unique data.
97 for (size_t i = 2; i < size; ++i) {
98 vec[i] = id + i;
99 }
100 return vec;
101}
102
103// Query constants
104const unsigned MARK = 123;
105const uint16_t ID = 52;
106const uint16_t SIZE = 22;
107const bytevec QUERY = make_query(ID, SIZE);
108
109template <class T>
110class FakeSocketFactory : public IDnsTlsSocketFactory {
111public:
112 FakeSocketFactory() {}
113 std::unique_ptr<IDnsTlsSocket> createDnsTlsSocket(
114 const DnsTlsServer& server ATTRIBUTE_UNUSED,
115 unsigned mark ATTRIBUTE_UNUSED,
Ben Schwartz33860762017-10-25 14:41:02 -0400116 IDnsTlsSocketObserver* observer,
Ben Schwartzded1b702017-10-25 14:41:02 -0400117 DnsTlsSessionCache* cache ATTRIBUTE_UNUSED) override {
Ben Schwartz33860762017-10-25 14:41:02 -0400118 return std::make_unique<T>(observer);
Ben Schwartzded1b702017-10-25 14:41:02 -0400119 }
120};
121
122bytevec make_echo(uint16_t id, const Slice query) {
123 bytevec response(query.size() + 2);
124 response[0] = id >> 8;
125 response[1] = id;
126 // Echo the query as the fake response.
127 memcpy(response.data() + 2, query.base(), query.size());
128 return response;
129}
130
131// Simplest possible fake server. This just echoes the query as the response.
132class FakeSocketEcho : public IDnsTlsSocket {
133public:
Ben Schwartz33860762017-10-25 14:41:02 -0400134 FakeSocketEcho(IDnsTlsSocketObserver* observer) : mObserver(observer) {}
135 bool query(uint16_t id, const Slice query) override {
136 // Return the response immediately (asynchronously).
137 std::thread(&IDnsTlsSocketObserver::onResponse, mObserver, make_echo(id, query)).detach();
Bernie Innocenti0f167432018-05-17 22:25:54 +0900138 return true;
Ben Schwartzded1b702017-10-25 14:41:02 -0400139 }
Ben Schwartz33860762017-10-25 14:41:02 -0400140private:
141 IDnsTlsSocketObserver* const mObserver;
Ben Schwartzded1b702017-10-25 14:41:02 -0400142};
143
144class TransportTest : public BaseTest {};
145
146TEST_F(TransportTest, Query) {
147 FakeSocketFactory<FakeSocketEcho> factory;
148 DnsTlsTransport transport(SERVER1, MARK, &factory);
Ben Schwartz33860762017-10-25 14:41:02 -0400149 auto r = transport.query(makeSlice(QUERY)).get();
Ben Schwartzded1b702017-10-25 14:41:02 -0400150
151 EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
152 EXPECT_EQ(QUERY, r.response);
153}
154
Ben Schwartz33860762017-10-25 14:41:02 -0400155TEST_F(TransportTest, SerialQueries_100000) {
Ben Schwartzded1b702017-10-25 14:41:02 -0400156 FakeSocketFactory<FakeSocketEcho> factory;
157 DnsTlsTransport transport(SERVER1, MARK, &factory);
158 // Send more than 65536 queries serially.
159 for (int i = 0; i < 100000; ++i) {
Ben Schwartz33860762017-10-25 14:41:02 -0400160 auto r = transport.query(makeSlice(QUERY)).get();
Ben Schwartzded1b702017-10-25 14:41:02 -0400161
162 EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
163 EXPECT_EQ(QUERY, r.response);
164 }
165}
166
Ben Schwartz33860762017-10-25 14:41:02 -0400167// These queries might be handled in serial or parallel as they race the
168// responses.
169TEST_F(TransportTest, RacingQueries_10000) {
170 FakeSocketFactory<FakeSocketEcho> factory;
171 DnsTlsTransport transport(SERVER1, MARK, &factory);
172 std::vector<std::future<DnsTlsTransport::Result>> results;
173 // Fewer than 65536 queries to avoid ID exhaustion.
174 for (int i = 0; i < 10000; ++i) {
175 results.push_back(transport.query(makeSlice(QUERY)));
176 }
177 for (auto& result : results) {
178 auto r = result.get();
179 EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
180 EXPECT_EQ(QUERY, r.response);
181 }
182}
183
184// A server that waits until sDelay queries are queued before responding.
185class FakeSocketDelay : public IDnsTlsSocket {
186public:
187 FakeSocketDelay(IDnsTlsSocketObserver* observer) : mObserver(observer) {}
188 ~FakeSocketDelay() { std::lock_guard<std::mutex> guard(mLock); }
189 static size_t sDelay;
190 static bool sReverse;
191
192 bool query(uint16_t id, const Slice query) override {
193 ALOGD("FakeSocketDelay got query with ID %d", int(id));
194 std::lock_guard<std::mutex> guard(mLock);
195 // Check for duplicate IDs.
196 EXPECT_EQ(0U, mIds.count(id));
197 mIds.insert(id);
198
199 // Store response.
200 mResponses.push_back(make_echo(id, query));
201
202 ALOGD("Up to %zu out of %zu queries", mResponses.size(), sDelay);
203 if (mResponses.size() == sDelay) {
204 std::thread(&FakeSocketDelay::sendResponses, this).detach();
205 }
206 return true;
207 }
208private:
209 void sendResponses() {
210 std::lock_guard<std::mutex> guard(mLock);
211 if (sReverse) {
212 std::reverse(std::begin(mResponses), std::end(mResponses));
213 }
214 for (auto& response : mResponses) {
215 mObserver->onResponse(response);
216 }
217 mIds.clear();
218 mResponses.clear();
219 }
220
221 std::mutex mLock;
222 IDnsTlsSocketObserver* const mObserver;
223 std::set<uint16_t> mIds GUARDED_BY(mLock);
224 std::vector<bytevec> mResponses GUARDED_BY(mLock);
225};
226
227size_t FakeSocketDelay::sDelay;
228bool FakeSocketDelay::sReverse;
229
230TEST_F(TransportTest, ParallelColliding) {
231 FakeSocketDelay::sDelay = 10;
232 FakeSocketDelay::sReverse = false;
233 FakeSocketFactory<FakeSocketDelay> factory;
234 DnsTlsTransport transport(SERVER1, MARK, &factory);
235 std::vector<std::future<DnsTlsTransport::Result>> results;
236 // Fewer than 65536 queries to avoid ID exhaustion.
237 for (size_t i = 0; i < FakeSocketDelay::sDelay; ++i) {
238 results.push_back(transport.query(makeSlice(QUERY)));
239 }
240 for (auto& result : results) {
241 auto r = result.get();
242 EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
243 EXPECT_EQ(QUERY, r.response);
244 }
245}
246
247TEST_F(TransportTest, ParallelColliding_Max) {
248 FakeSocketDelay::sDelay = 65536;
249 FakeSocketDelay::sReverse = false;
250 FakeSocketFactory<FakeSocketDelay> factory;
251 DnsTlsTransport transport(SERVER1, MARK, &factory);
252 std::vector<std::future<DnsTlsTransport::Result>> results;
253 // Exactly 65536 queries should still be possible in parallel,
254 // even if they all have the same original ID.
255 for (size_t i = 0; i < FakeSocketDelay::sDelay; ++i) {
256 results.push_back(transport.query(makeSlice(QUERY)));
257 }
258 for (auto& result : results) {
259 auto r = result.get();
260 EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
261 EXPECT_EQ(QUERY, r.response);
262 }
263}
264
265TEST_F(TransportTest, ParallelUnique) {
266 FakeSocketDelay::sDelay = 10;
267 FakeSocketDelay::sReverse = false;
268 FakeSocketFactory<FakeSocketDelay> factory;
269 DnsTlsTransport transport(SERVER1, MARK, &factory);
270 std::vector<bytevec> queries(FakeSocketDelay::sDelay);
271 std::vector<std::future<DnsTlsTransport::Result>> results;
272 for (size_t i = 0; i < FakeSocketDelay::sDelay; ++i) {
273 queries[i] = make_query(i, SIZE);
274 results.push_back(transport.query(makeSlice(queries[i])));
275 }
276 for (size_t i = 0 ; i < FakeSocketDelay::sDelay; ++i) {
277 auto r = results[i].get();
278 EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
279 EXPECT_EQ(queries[i], r.response);
280 }
281}
282
283TEST_F(TransportTest, ParallelUnique_Max) {
284 FakeSocketDelay::sDelay = 65536;
285 FakeSocketDelay::sReverse = false;
286 FakeSocketFactory<FakeSocketDelay> factory;
287 DnsTlsTransport transport(SERVER1, MARK, &factory);
288 std::vector<bytevec> queries(FakeSocketDelay::sDelay);
289 std::vector<std::future<DnsTlsTransport::Result>> results;
290 // Exactly 65536 queries should still be possible in parallel,
291 // and they should all be mapped correctly back to the original ID.
292 for (size_t i = 0; i < FakeSocketDelay::sDelay; ++i) {
293 queries[i] = make_query(i, SIZE);
294 results.push_back(transport.query(makeSlice(queries[i])));
295 }
296 for (size_t i = 0 ; i < FakeSocketDelay::sDelay; ++i) {
297 auto r = results[i].get();
298 EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
299 EXPECT_EQ(queries[i], r.response);
300 }
301}
302
303TEST_F(TransportTest, IdExhaustion) {
304 // A delay of 65537 is unreachable, because the maximum number
305 // of outstanding queries is 65536.
306 FakeSocketDelay::sDelay = 65537;
307 FakeSocketDelay::sReverse = false;
308 FakeSocketFactory<FakeSocketDelay> factory;
309 DnsTlsTransport transport(SERVER1, MARK, &factory);
310 std::vector<std::future<DnsTlsTransport::Result>> results;
311 // Issue the maximum number of queries.
312 for (int i = 0; i < 65536; ++i) {
313 results.push_back(transport.query(makeSlice(QUERY)));
314 }
315
316 // The ID space is now full, so subsequent queries should fail immediately.
317 auto r = transport.query(makeSlice(QUERY)).get();
318 EXPECT_EQ(DnsTlsTransport::Response::internal_error, r.code);
319 EXPECT_TRUE(r.response.empty());
320
321 for (auto& result : results) {
322 // All other queries should remain outstanding.
323 EXPECT_EQ(std::future_status::timeout,
324 result.wait_for(std::chrono::duration<int>::zero()));
325 }
326}
327
328// Responses can come back from the server in any order. This should have no
329// effect on Transport's observed behavior.
330TEST_F(TransportTest, ReverseOrder) {
331 FakeSocketDelay::sDelay = 10;
332 FakeSocketDelay::sReverse = true;
333 FakeSocketFactory<FakeSocketDelay> factory;
334 DnsTlsTransport transport(SERVER1, MARK, &factory);
335 std::vector<bytevec> queries(FakeSocketDelay::sDelay);
336 std::vector<std::future<DnsTlsTransport::Result>> results;
337 for (size_t i = 0; i < FakeSocketDelay::sDelay; ++i) {
338 queries[i] = make_query(i, SIZE);
339 results.push_back(transport.query(makeSlice(queries[i])));
340 }
341 for (size_t i = 0 ; i < FakeSocketDelay::sDelay; ++i) {
342 auto r = results[i].get();
343 EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
344 EXPECT_EQ(queries[i], r.response);
345 }
346}
347
348TEST_F(TransportTest, ReverseOrder_Max) {
349 FakeSocketDelay::sDelay = 65536;
350 FakeSocketDelay::sReverse = true;
351 FakeSocketFactory<FakeSocketDelay> factory;
352 DnsTlsTransport transport(SERVER1, MARK, &factory);
353 std::vector<bytevec> queries(FakeSocketDelay::sDelay);
354 std::vector<std::future<DnsTlsTransport::Result>> results;
355 for (size_t i = 0; i < FakeSocketDelay::sDelay; ++i) {
356 queries[i] = make_query(i, SIZE);
357 results.push_back(transport.query(makeSlice(queries[i])));
358 }
359 for (size_t i = 0 ; i < FakeSocketDelay::sDelay; ++i) {
360 auto r = results[i].get();
361 EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
362 EXPECT_EQ(queries[i], r.response);
363 }
364}
365
Ben Schwartzded1b702017-10-25 14:41:02 -0400366// Returning null from the factory indicates a connection failure.
367class NullSocketFactory : public IDnsTlsSocketFactory {
368public:
369 NullSocketFactory() {}
370 std::unique_ptr<IDnsTlsSocket> createDnsTlsSocket(
371 const DnsTlsServer& server ATTRIBUTE_UNUSED,
372 unsigned mark ATTRIBUTE_UNUSED,
Ben Schwartz33860762017-10-25 14:41:02 -0400373 IDnsTlsSocketObserver* observer ATTRIBUTE_UNUSED,
Ben Schwartzded1b702017-10-25 14:41:02 -0400374 DnsTlsSessionCache* cache ATTRIBUTE_UNUSED) override {
375 return nullptr;
376 }
377};
378
379TEST_F(TransportTest, ConnectFail) {
380 NullSocketFactory factory;
381 DnsTlsTransport transport(SERVER1, MARK, &factory);
Ben Schwartz33860762017-10-25 14:41:02 -0400382 auto r = transport.query(makeSlice(QUERY)).get();
Ben Schwartzded1b702017-10-25 14:41:02 -0400383
384 EXPECT_EQ(DnsTlsTransport::Response::network_error, r.code);
385 EXPECT_TRUE(r.response.empty());
386}
387
Ben Schwartz33860762017-10-25 14:41:02 -0400388// Simulate a socket that connects but then immediately receives a server
389// close notification.
390class FakeSocketClose : public IDnsTlsSocket {
391public:
392 FakeSocketClose(IDnsTlsSocketObserver* observer) :
393 mCloser(&IDnsTlsSocketObserver::onClosed, observer) {}
394 ~FakeSocketClose() { mCloser.join(); }
395 bool query(uint16_t id ATTRIBUTE_UNUSED,
396 const Slice query ATTRIBUTE_UNUSED) override {
397 return true;
398 }
399private:
400 std::thread mCloser;
401};
402
403TEST_F(TransportTest, CloseRetryFail) {
404 FakeSocketFactory<FakeSocketClose> factory;
405 DnsTlsTransport transport(SERVER1, MARK, &factory);
406 auto r = transport.query(makeSlice(QUERY)).get();
407
408 EXPECT_EQ(DnsTlsTransport::Response::network_error, r.code);
409 EXPECT_TRUE(r.response.empty());
410}
411
412// Simulate a server that occasionally closes the connection and silently
413// drops some queries.
414class FakeSocketLimited : public IDnsTlsSocket {
415public:
416 static int sLimit; // Number of queries to answer per socket.
417 static size_t sMaxSize; // Silently discard queries greater than this size.
418 FakeSocketLimited(IDnsTlsSocketObserver* observer) :
419 mObserver(observer), mQueries(0) {}
420 ~FakeSocketLimited() {
421 {
422 ALOGD("~FakeSocketLimited acquiring mLock");
423 std::lock_guard<std::mutex> guard(mLock);
424 ALOGD("~FakeSocketLimited acquired mLock");
425 for (auto& thread : mThreads) {
426 ALOGD("~FakeSocketLimited joining response thread");
427 thread.join();
428 ALOGD("~FakeSocketLimited joined response thread");
429 }
430 mThreads.clear();
431 }
432
433 if (mCloser) {
434 ALOGD("~FakeSocketLimited joining closer thread");
435 mCloser->join();
436 ALOGD("~FakeSocketLimited joined closer thread");
437 }
438 }
439 bool query(uint16_t id, const Slice query) override {
440 ALOGD("FakeSocketLimited::query acquiring mLock");
441 std::lock_guard<std::mutex> guard(mLock);
442 ALOGD("FakeSocketLimited::query acquired mLock");
443 ++mQueries;
444
445 if (mQueries <= sLimit) {
446 ALOGD("size %zu vs. limit of %zu", query.size(), sMaxSize);
447 if (query.size() <= sMaxSize) {
448 // Return the response immediately (asynchronously).
449 mThreads.emplace_back(&IDnsTlsSocketObserver::onResponse, mObserver, make_echo(id, query));
450 }
451 }
452 if (mQueries == sLimit) {
453 mCloser = std::make_unique<std::thread>(&FakeSocketLimited::sendClose, this);
454 }
455 return mQueries <= sLimit;
456 }
457private:
458 void sendClose() {
459 {
460 ALOGD("FakeSocketLimited::sendClose acquiring mLock");
461 std::lock_guard<std::mutex> guard(mLock);
462 ALOGD("FakeSocketLimited::sendClose acquired mLock");
463 for (auto& thread : mThreads) {
464 ALOGD("FakeSocketLimited::sendClose joining response thread");
465 thread.join();
466 ALOGD("FakeSocketLimited::sendClose joined response thread");
467 }
468 mThreads.clear();
469 }
470 mObserver->onClosed();
471 }
472 std::mutex mLock;
473 IDnsTlsSocketObserver* const mObserver;
474 int mQueries GUARDED_BY(mLock);
475 std::vector<std::thread> mThreads GUARDED_BY(mLock);
476 std::unique_ptr<std::thread> mCloser GUARDED_BY(mLock);
477};
478
479int FakeSocketLimited::sLimit;
480size_t FakeSocketLimited::sMaxSize;
481
482TEST_F(TransportTest, SilentDrop) {
483 FakeSocketLimited::sLimit = 10; // Close the socket after 10 queries.
484 FakeSocketLimited::sMaxSize = 0; // Silently drop all queries
485 FakeSocketFactory<FakeSocketLimited> factory;
486 DnsTlsTransport transport(SERVER1, MARK, &factory);
487
488 // Queue up 10 queries. They will all be ignored, and after the 10th,
489 // the socket will close. Transport will retry them all, until they
490 // all hit the retry limit and expire.
491 std::vector<std::future<DnsTlsTransport::Result>> results;
492 for (int i = 0; i < FakeSocketLimited::sLimit; ++i) {
493 results.push_back(transport.query(makeSlice(QUERY)));
494 }
495 for (auto& result : results) {
496 auto r = result.get();
497 EXPECT_EQ(DnsTlsTransport::Response::network_error, r.code);
498 EXPECT_TRUE(r.response.empty());
499 }
500}
501
502TEST_F(TransportTest, PartialDrop) {
503 FakeSocketLimited::sLimit = 10; // Close the socket after 10 queries.
504 FakeSocketLimited::sMaxSize = SIZE - 2; // Silently drop "long" queries
505 FakeSocketFactory<FakeSocketLimited> factory;
506 DnsTlsTransport transport(SERVER1, MARK, &factory);
507
508 // Queue up 100 queries, alternating "short" which will be served and "long"
509 // which will be dropped.
510 int num_queries = 10 * FakeSocketLimited::sLimit;
511 std::vector<bytevec> queries(num_queries);
512 std::vector<std::future<DnsTlsTransport::Result>> results;
513 for (int i = 0; i < num_queries; ++i) {
514 queries[i] = make_query(i, SIZE + (i % 2));
515 results.push_back(transport.query(makeSlice(queries[i])));
516 }
517 // Just check the short queries, which are at the even indices.
518 for (int i = 0; i < num_queries; i += 2) {
519 auto r = results[i].get();
520 EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
521 EXPECT_EQ(queries[i], r.response);
522 }
523}
524
525// Simulate a malfunctioning server that injects extra miscellaneous
526// responses to queries that were not asked. This will cause wrong answers but
527// must not crash the Transport.
528class FakeSocketGarbage : public IDnsTlsSocket {
529public:
530 FakeSocketGarbage(IDnsTlsSocketObserver* observer) : mObserver(observer) {
531 // Inject a garbage event.
532 mThreads.emplace_back(&IDnsTlsSocketObserver::onResponse, mObserver, make_query(ID + 1, SIZE));
533 }
534 ~FakeSocketGarbage() {
535 std::lock_guard<std::mutex> guard(mLock);
536 for (auto& thread : mThreads) {
537 thread.join();
538 }
539 }
540 bool query(uint16_t id, const Slice query) override {
541 std::lock_guard<std::mutex> guard(mLock);
542 // Return the response twice.
543 auto echo = make_echo(id, query);
544 mThreads.emplace_back(&IDnsTlsSocketObserver::onResponse, mObserver, echo);
545 mThreads.emplace_back(&IDnsTlsSocketObserver::onResponse, mObserver, echo);
546 // Also return some other garbage
547 mThreads.emplace_back(&IDnsTlsSocketObserver::onResponse, mObserver, make_query(id + 1, query.size() + 2));
548 return true;
549 }
550private:
551 std::mutex mLock;
552 std::vector<std::thread> mThreads GUARDED_BY(mLock);
553 IDnsTlsSocketObserver* const mObserver;
554};
555
556TEST_F(TransportTest, IgnoringGarbage) {
557 FakeSocketFactory<FakeSocketGarbage> factory;
558 DnsTlsTransport transport(SERVER1, MARK, &factory);
559 for (int i = 0; i < 10; ++i) {
560 auto r = transport.query(makeSlice(QUERY)).get();
561
562 EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
563 // Don't check the response because this server is malfunctioning.
564 }
565}
566
Ben Schwartzded1b702017-10-25 14:41:02 -0400567// Dispatcher tests
568class DispatcherTest : public BaseTest {};
569
570TEST_F(DispatcherTest, Query) {
571 bytevec ans(4096);
572 int resplen = 0;
573
574 auto factory = std::make_unique<FakeSocketFactory<FakeSocketEcho>>();
575 DnsTlsDispatcher dispatcher(std::move(factory));
576 auto r = dispatcher.query(SERVER1, MARK, makeSlice(QUERY),
577 makeSlice(ans), &resplen);
578
579 EXPECT_EQ(DnsTlsTransport::Response::success, r);
580 EXPECT_EQ(int(QUERY.size()), resplen);
581 ans.resize(resplen);
582 EXPECT_EQ(QUERY, ans);
583}
584
585TEST_F(DispatcherTest, AnswerTooLarge) {
586 bytevec ans(SIZE - 1); // Too small to hold the answer
587 int resplen = 0;
588
589 auto factory = std::make_unique<FakeSocketFactory<FakeSocketEcho>>();
590 DnsTlsDispatcher dispatcher(std::move(factory));
591 auto r = dispatcher.query(SERVER1, MARK, makeSlice(QUERY),
592 makeSlice(ans), &resplen);
593
594 EXPECT_EQ(DnsTlsTransport::Response::limit_error, r);
595}
596
597template<class T>
598class TrackingFakeSocketFactory : public IDnsTlsSocketFactory {
599public:
600 TrackingFakeSocketFactory() {}
601 std::unique_ptr<IDnsTlsSocket> createDnsTlsSocket(
602 const DnsTlsServer& server,
603 unsigned mark,
Ben Schwartz33860762017-10-25 14:41:02 -0400604 IDnsTlsSocketObserver* observer,
Ben Schwartzded1b702017-10-25 14:41:02 -0400605 DnsTlsSessionCache* cache ATTRIBUTE_UNUSED) override {
606 std::lock_guard<std::mutex> guard(mLock);
607 keys.emplace(mark, server);
Ben Schwartz33860762017-10-25 14:41:02 -0400608 return std::make_unique<T>(observer);
Ben Schwartzded1b702017-10-25 14:41:02 -0400609 }
610 std::multiset<std::pair<unsigned, DnsTlsServer>> keys;
611private:
612 std::mutex mLock;
613};
614
615TEST_F(DispatcherTest, Dispatching) {
Ben Schwartz33860762017-10-25 14:41:02 -0400616 FakeSocketDelay::sDelay = 5;
617 FakeSocketDelay::sReverse = true;
618 auto factory = std::make_unique<TrackingFakeSocketFactory<FakeSocketDelay>>();
Ben Schwartzded1b702017-10-25 14:41:02 -0400619 auto* weak_factory = factory.get(); // Valid as long as dispatcher is in scope.
620 DnsTlsDispatcher dispatcher(std::move(factory));
621
622 // Populate a vector of two servers and two socket marks, four combinations
623 // in total.
624 std::vector<std::pair<unsigned, DnsTlsServer>> keys;
625 keys.emplace_back(MARK, SERVER1);
626 keys.emplace_back(MARK + 1, SERVER1);
627 keys.emplace_back(MARK, V4ADDR2);
628 keys.emplace_back(MARK + 1, V4ADDR2);
629
Ben Schwartz33860762017-10-25 14:41:02 -0400630 // Do several queries on each server. They should all succeed.
Ben Schwartzded1b702017-10-25 14:41:02 -0400631 std::vector<std::thread> threads;
Ben Schwartz33860762017-10-25 14:41:02 -0400632 for (size_t i = 0; i < FakeSocketDelay::sDelay * keys.size(); ++i) {
Ben Schwartzded1b702017-10-25 14:41:02 -0400633 auto key = keys[i % keys.size()];
634 threads.emplace_back([key, i] (DnsTlsDispatcher* dispatcher) {
635 auto q = make_query(i, SIZE);
636 bytevec ans(4096);
637 int resplen = 0;
638 unsigned mark = key.first;
639 const DnsTlsServer& server = key.second;
640 auto r = dispatcher->query(server, mark, makeSlice(q),
641 makeSlice(ans), &resplen);
642 EXPECT_EQ(DnsTlsTransport::Response::success, r);
643 EXPECT_EQ(int(q.size()), resplen);
644 ans.resize(resplen);
645 EXPECT_EQ(q, ans);
646 }, &dispatcher);
647 }
648 for (auto& thread : threads) {
649 thread.join();
650 }
651 // We expect that the factory created one socket for each key.
652 EXPECT_EQ(keys.size(), weak_factory->keys.size());
653 for (auto& key : keys) {
654 EXPECT_EQ(1U, weak_factory->keys.count(key));
655 }
656}
657
Ben Schwartze5595152017-10-25 14:41:02 -0400658// Check DnsTlsServer's comparison logic.
659AddressComparator ADDRESS_COMPARATOR;
660bool isAddressEqual(const DnsTlsServer& s1, const DnsTlsServer& s2) {
661 bool cmp1 = ADDRESS_COMPARATOR(s1, s2);
662 bool cmp2 = ADDRESS_COMPARATOR(s2, s1);
663 EXPECT_FALSE(cmp1 && cmp2);
664 return !cmp1 && !cmp2;
665}
666
667void checkUnequal(const DnsTlsServer& s1, const DnsTlsServer& s2) {
668 EXPECT_TRUE(s1 == s1);
669 EXPECT_TRUE(s2 == s2);
670 EXPECT_TRUE(isAddressEqual(s1, s1));
671 EXPECT_TRUE(isAddressEqual(s2, s2));
672
673 EXPECT_TRUE(s1 < s2 ^ s2 < s1);
674 EXPECT_FALSE(s1 == s2);
675 EXPECT_FALSE(s2 == s1);
676}
677
678class ServerTest : public BaseTest {};
679
680TEST_F(ServerTest, IPv4) {
681 checkUnequal(V4ADDR1, V4ADDR2);
682 EXPECT_FALSE(isAddressEqual(V4ADDR1, V4ADDR2));
683}
684
685TEST_F(ServerTest, IPv6) {
686 checkUnequal(V6ADDR1, V6ADDR2);
687 EXPECT_FALSE(isAddressEqual(V6ADDR1, V6ADDR2));
688}
689
690TEST_F(ServerTest, MixedAddressFamily) {
691 checkUnequal(V6ADDR1, V4ADDR1);
692 EXPECT_FALSE(isAddressEqual(V6ADDR1, V4ADDR1));
693}
694
695TEST_F(ServerTest, IPv6ScopeId) {
696 DnsTlsServer s1(V6ADDR1), s2(V6ADDR1);
697 sockaddr_in6* addr1 = reinterpret_cast<sockaddr_in6*>(&s1.ss);
698 addr1->sin6_scope_id = 1;
699 sockaddr_in6* addr2 = reinterpret_cast<sockaddr_in6*>(&s2.ss);
700 addr2->sin6_scope_id = 2;
701 checkUnequal(s1, s2);
702 EXPECT_FALSE(isAddressEqual(s1, s2));
Erik Kline1564d482018-03-07 17:09:35 +0900703
704 EXPECT_FALSE(s1.wasExplicitlyConfigured());
705 EXPECT_FALSE(s2.wasExplicitlyConfigured());
Ben Schwartze5595152017-10-25 14:41:02 -0400706}
707
708TEST_F(ServerTest, IPv6FlowInfo) {
709 DnsTlsServer s1(V6ADDR1), s2(V6ADDR1);
710 sockaddr_in6* addr1 = reinterpret_cast<sockaddr_in6*>(&s1.ss);
711 addr1->sin6_flowinfo = 1;
712 sockaddr_in6* addr2 = reinterpret_cast<sockaddr_in6*>(&s2.ss);
713 addr2->sin6_flowinfo = 2;
714 // All comparisons ignore flowinfo.
715 EXPECT_EQ(s1, s2);
716 EXPECT_TRUE(isAddressEqual(s1, s2));
Erik Kline1564d482018-03-07 17:09:35 +0900717
718 EXPECT_FALSE(s1.wasExplicitlyConfigured());
719 EXPECT_FALSE(s2.wasExplicitlyConfigured());
Ben Schwartze5595152017-10-25 14:41:02 -0400720}
721
722TEST_F(ServerTest, Port) {
723 DnsTlsServer s1, s2;
724 parseServer("192.0.2.1", 853, &s1.ss);
725 parseServer("192.0.2.1", 854, &s2.ss);
726 checkUnequal(s1, s2);
727 EXPECT_TRUE(isAddressEqual(s1, s2));
728
729 DnsTlsServer s3, s4;
730 parseServer("2001:db8::1", 853, &s3.ss);
731 parseServer("2001:db8::1", 852, &s4.ss);
732 checkUnequal(s3, s4);
733 EXPECT_TRUE(isAddressEqual(s3, s4));
Erik Kline1564d482018-03-07 17:09:35 +0900734
735 EXPECT_FALSE(s1.wasExplicitlyConfigured());
736 EXPECT_FALSE(s2.wasExplicitlyConfigured());
Ben Schwartze5595152017-10-25 14:41:02 -0400737}
738
739TEST_F(ServerTest, Name) {
740 DnsTlsServer s1(V4ADDR1), s2(V4ADDR1);
741 s1.name = SERVERNAME1;
742 checkUnequal(s1, s2);
743 s2.name = SERVERNAME2;
744 checkUnequal(s1, s2);
745 EXPECT_TRUE(isAddressEqual(s1, s2));
Erik Kline1564d482018-03-07 17:09:35 +0900746
747 EXPECT_TRUE(s1.wasExplicitlyConfigured());
748 EXPECT_TRUE(s2.wasExplicitlyConfigured());
Ben Schwartze5595152017-10-25 14:41:02 -0400749}
750
751TEST_F(ServerTest, Fingerprint) {
752 DnsTlsServer s1(V4ADDR1), s2(V4ADDR1);
753
754 s1.fingerprints.insert(FINGERPRINT1);
755 checkUnequal(s1, s2);
756 EXPECT_TRUE(isAddressEqual(s1, s2));
757
758 s2.fingerprints.insert(FINGERPRINT2);
759 checkUnequal(s1, s2);
760 EXPECT_TRUE(isAddressEqual(s1, s2));
761
762 s2.fingerprints.insert(FINGERPRINT1);
763 checkUnequal(s1, s2);
764 EXPECT_TRUE(isAddressEqual(s1, s2));
765
766 s1.fingerprints.insert(FINGERPRINT2);
767 EXPECT_EQ(s1, s2);
768 EXPECT_TRUE(isAddressEqual(s1, s2));
Erik Kline1564d482018-03-07 17:09:35 +0900769
770 EXPECT_TRUE(s1.wasExplicitlyConfigured());
771 EXPECT_TRUE(s2.wasExplicitlyConfigured());
Ben Schwartze5595152017-10-25 14:41:02 -0400772}
773
Ben Schwartz33860762017-10-25 14:41:02 -0400774TEST(QueryMapTest, Basic) {
775 DnsTlsQueryMap map;
776
777 EXPECT_TRUE(map.empty());
778
779 bytevec q0 = make_query(999, SIZE);
780 bytevec q1 = make_query(888, SIZE);
781 bytevec q2 = make_query(777, SIZE);
782
783 auto f0 = map.recordQuery(makeSlice(q0));
784 auto f1 = map.recordQuery(makeSlice(q1));
785 auto f2 = map.recordQuery(makeSlice(q2));
786
787 // Check return values of recordQuery
788 EXPECT_EQ(0, f0->query.newId);
789 EXPECT_EQ(1, f1->query.newId);
790 EXPECT_EQ(2, f2->query.newId);
791
792 // Check side effects of recordQuery
793 EXPECT_FALSE(map.empty());
794
795 auto all = map.getAll();
796 EXPECT_EQ(3U, all.size());
797
798 EXPECT_EQ(0, all[0].newId);
799 EXPECT_EQ(1, all[1].newId);
800 EXPECT_EQ(2, all[2].newId);
801
802 EXPECT_EQ(makeSlice(q0), all[0].query);
803 EXPECT_EQ(makeSlice(q1), all[1].query);
804 EXPECT_EQ(makeSlice(q2), all[2].query);
805
806 bytevec a0 = make_query(0, SIZE);
807 bytevec a1 = make_query(1, SIZE);
808 bytevec a2 = make_query(2, SIZE);
809
810 // Return responses out of order
811 map.onResponse(a2);
812 map.onResponse(a0);
813 map.onResponse(a1);
814
815 EXPECT_TRUE(map.empty());
816
817 auto r0 = f0->result.get();
818 auto r1 = f1->result.get();
819 auto r2 = f2->result.get();
820
821 EXPECT_EQ(DnsTlsQueryMap::Response::success, r0.code);
822 EXPECT_EQ(DnsTlsQueryMap::Response::success, r1.code);
823 EXPECT_EQ(DnsTlsQueryMap::Response::success, r2.code);
824
825 const bytevec& d0 = r0.response;
826 const bytevec& d1 = r1.response;
827 const bytevec& d2 = r2.response;
828
829 // The ID should match the query
830 EXPECT_EQ(999, d0[0] << 8 | d0[1]);
831 EXPECT_EQ(888, d1[0] << 8 | d1[1]);
832 EXPECT_EQ(777, d2[0] << 8 | d2[1]);
833 // The body should match the answer
834 EXPECT_EQ(bytevec(a0.begin() + 2, a0.end()), bytevec(d0.begin() + 2, d0.end()));
835 EXPECT_EQ(bytevec(a1.begin() + 2, a1.end()), bytevec(d1.begin() + 2, d1.end()));
836 EXPECT_EQ(bytevec(a2.begin() + 2, a2.end()), bytevec(d2.begin() + 2, d2.end()));
837}
838
839TEST(QueryMapTest, FillHole) {
840 DnsTlsQueryMap map;
841 std::vector<std::unique_ptr<DnsTlsQueryMap::QueryFuture>> futures(UINT16_MAX + 1);
842 for (uint32_t i = 0; i <= UINT16_MAX; ++i) {
843 futures[i] = map.recordQuery(makeSlice(QUERY));
844 ASSERT_TRUE(futures[i]); // answers[i] should be nonnull.
845 EXPECT_EQ(i, futures[i]->query.newId);
846 }
847
848 // The map should now be full.
849 EXPECT_EQ(size_t(UINT16_MAX + 1), map.getAll().size());
850
851 // Trying to add another query should fail because the map is full.
852 EXPECT_FALSE(map.recordQuery(makeSlice(QUERY)));
853
854 // Send an answer to query 40000
855 auto answer = make_query(40000, SIZE);
856 map.onResponse(answer);
857 auto result = futures[40000]->result.get();
858 EXPECT_EQ(DnsTlsQueryMap::Response::success, result.code);
859 EXPECT_EQ(ID, result.response[0] << 8 | result.response[1]);
860 EXPECT_EQ(bytevec(answer.begin() + 2, answer.end()),
861 bytevec(result.response.begin() + 2, result.response.end()));
862
863 // There should now be room in the map.
864 EXPECT_EQ(size_t(UINT16_MAX), map.getAll().size());
865 auto f = map.recordQuery(makeSlice(QUERY));
866 ASSERT_TRUE(f);
867 EXPECT_EQ(40000, f->query.newId);
868
869 // The map should now be full again.
870 EXPECT_EQ(size_t(UINT16_MAX + 1), map.getAll().size());
871 EXPECT_FALSE(map.recordQuery(makeSlice(QUERY)));
872}
873
Ben Schwartzded1b702017-10-25 14:41:02 -0400874} // end of namespace net
875} // end of namespace android