blob: cbedc34d29afb394fd3062e8424a498901745561 [file] [log] [blame]
/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
17#define LOG_TAG "dns_tls_test"
18
#include <arpa/inet.h>
#include <netinet/in.h>
#include <sys/socket.h>

#include <algorithm>
#include <chrono>
#include <cstring>
#include <future>
#include <memory>
#include <mutex>
#include <set>
#include <string>
#include <thread>
#include <vector>

#include <android-base/macros.h>
#include <gtest/gtest.h>
#include <netdutils/Slice.h>

#include "dns/DnsTlsDispatcher.h"
#include "dns/DnsTlsQueryMap.h"
#include "dns/DnsTlsServer.h"
#include "dns/DnsTlsSessionCache.h"
#include "dns/DnsTlsSocket.h"
#include "dns/DnsTlsTransport.h"
#include "dns/IDnsTlsSocket.h"
#include "dns/IDnsTlsSocketFactory.h"
#include "dns/IDnsTlsSocketObserver.h"

#include "log/log.h"
37
38namespace android {
39namespace net {
40
41using netdutils::Slice;
42using netdutils::makeSlice;
43
// Raw byte buffer used throughout these tests for queries and responses.
using bytevec = std::vector<uint8_t>;
45
46static void parseServer(const char* server, in_port_t port, sockaddr_storage* parsed) {
47 sockaddr_in* sin = reinterpret_cast<sockaddr_in*>(parsed);
48 if (inet_pton(AF_INET, server, &(sin->sin_addr)) == 1) {
49 // IPv4 parse succeeded, so it's IPv4
50 sin->sin_family = AF_INET;
51 sin->sin_port = htons(port);
52 return;
53 }
54 sockaddr_in6* sin6 = reinterpret_cast<sockaddr_in6*>(parsed);
55 if (inet_pton(AF_INET6, server, &(sin6->sin6_addr)) == 1){
56 // IPv6 parse succeeded, so it's IPv6.
57 sin6->sin6_family = AF_INET6;
58 sin6->sin6_port = htons(port);
59 return;
60 }
61 ALOGE("Failed to parse server address: %s", server);
62}
63
64bytevec FINGERPRINT1 = { 1 };
Ben Schwartze5595152017-10-25 14:41:02 -040065bytevec FINGERPRINT2 = { 2 };
Ben Schwartzded1b702017-10-25 14:41:02 -040066
67std::string SERVERNAME1 = "dns.example.com";
Ben Schwartze5595152017-10-25 14:41:02 -040068std::string SERVERNAME2 = "dns.example.org";
Ben Schwartzded1b702017-10-25 14:41:02 -040069
70// BaseTest just provides constants that are useful for the tests.
71class BaseTest : public ::testing::Test {
Erik Klineab999f12018-07-04 11:29:31 +090072 protected:
Ben Schwartzded1b702017-10-25 14:41:02 -040073 BaseTest() {
74 parseServer("192.0.2.1", 853, &V4ADDR1);
75 parseServer("192.0.2.2", 853, &V4ADDR2);
Ben Schwartze5595152017-10-25 14:41:02 -040076 parseServer("2001:db8::1", 853, &V6ADDR1);
77 parseServer("2001:db8::2", 853, &V6ADDR2);
Ben Schwartzded1b702017-10-25 14:41:02 -040078
79 SERVER1 = DnsTlsServer(V4ADDR1);
80 SERVER1.fingerprints.insert(FINGERPRINT1);
81 SERVER1.name = SERVERNAME1;
82 }
83
84 sockaddr_storage V4ADDR1;
85 sockaddr_storage V4ADDR2;
Ben Schwartze5595152017-10-25 14:41:02 -040086 sockaddr_storage V6ADDR1;
87 sockaddr_storage V6ADDR2;
Ben Schwartzded1b702017-10-25 14:41:02 -040088
89 DnsTlsServer SERVER1;
90};
91
// Builds a fake DNS query of |size| bytes. The first two bytes hold |id| in
// big-endian order. Each remaining byte i is (id + i) truncated to 8 bits.
// If |size| is smaller than 2, only the ID bytes that fit are written, instead
// of indexing past the end of the buffer.
// (The return type is spelled out; it is the same type as bytevec.)
std::vector<uint8_t> make_query(uint16_t id, size_t size) {
    std::vector<uint8_t> vec(size);
    if (size > 0) {
        vec[0] = static_cast<uint8_t>(id >> 8);
    }
    if (size > 1) {
        vec[1] = static_cast<uint8_t>(id);
    }
    // Arbitrarily fill the query body with unique data.
    for (size_t i = 2; i < size; ++i) {
        vec[i] = static_cast<uint8_t>(id + i);
    }
    return vec;
}
102
103// Query constants
104const unsigned MARK = 123;
105const uint16_t ID = 52;
106const uint16_t SIZE = 22;
107const bytevec QUERY = make_query(ID, SIZE);
108
109template <class T>
110class FakeSocketFactory : public IDnsTlsSocketFactory {
Erik Klineab999f12018-07-04 11:29:31 +0900111 public:
Ben Schwartzded1b702017-10-25 14:41:02 -0400112 FakeSocketFactory() {}
113 std::unique_ptr<IDnsTlsSocket> createDnsTlsSocket(
114 const DnsTlsServer& server ATTRIBUTE_UNUSED,
115 unsigned mark ATTRIBUTE_UNUSED,
Ben Schwartz33860762017-10-25 14:41:02 -0400116 IDnsTlsSocketObserver* observer,
Ben Schwartzded1b702017-10-25 14:41:02 -0400117 DnsTlsSessionCache* cache ATTRIBUTE_UNUSED) override {
Ben Schwartz33860762017-10-25 14:41:02 -0400118 return std::make_unique<T>(observer);
Ben Schwartzded1b702017-10-25 14:41:02 -0400119 }
120};
121
122bytevec make_echo(uint16_t id, const Slice query) {
123 bytevec response(query.size() + 2);
124 response[0] = id >> 8;
125 response[1] = id;
126 // Echo the query as the fake response.
127 memcpy(response.data() + 2, query.base(), query.size());
128 return response;
129}
130
131// Simplest possible fake server. This just echoes the query as the response.
132class FakeSocketEcho : public IDnsTlsSocket {
Erik Klineab999f12018-07-04 11:29:31 +0900133 public:
134 explicit FakeSocketEcho(IDnsTlsSocketObserver* observer) : mObserver(observer) {}
Ben Schwartz33860762017-10-25 14:41:02 -0400135 bool query(uint16_t id, const Slice query) override {
136 // Return the response immediately (asynchronously).
137 std::thread(&IDnsTlsSocketObserver::onResponse, mObserver, make_echo(id, query)).detach();
Bernie Innocenti0f167432018-05-17 22:25:54 +0900138 return true;
Ben Schwartzded1b702017-10-25 14:41:02 -0400139 }
Erik Klineab999f12018-07-04 11:29:31 +0900140
141 private:
Ben Schwartz33860762017-10-25 14:41:02 -0400142 IDnsTlsSocketObserver* const mObserver;
Ben Schwartzded1b702017-10-25 14:41:02 -0400143};
144
145class TransportTest : public BaseTest {};
146
147TEST_F(TransportTest, Query) {
148 FakeSocketFactory<FakeSocketEcho> factory;
149 DnsTlsTransport transport(SERVER1, MARK, &factory);
Ben Schwartz33860762017-10-25 14:41:02 -0400150 auto r = transport.query(makeSlice(QUERY)).get();
Ben Schwartzded1b702017-10-25 14:41:02 -0400151
152 EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
153 EXPECT_EQ(QUERY, r.response);
154}
155
Ben Schwartz33860762017-10-25 14:41:02 -0400156TEST_F(TransportTest, SerialQueries_100000) {
Ben Schwartzded1b702017-10-25 14:41:02 -0400157 FakeSocketFactory<FakeSocketEcho> factory;
158 DnsTlsTransport transport(SERVER1, MARK, &factory);
159 // Send more than 65536 queries serially.
160 for (int i = 0; i < 100000; ++i) {
Ben Schwartz33860762017-10-25 14:41:02 -0400161 auto r = transport.query(makeSlice(QUERY)).get();
Ben Schwartzded1b702017-10-25 14:41:02 -0400162
163 EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
164 EXPECT_EQ(QUERY, r.response);
165 }
166}
167
Ben Schwartz33860762017-10-25 14:41:02 -0400168// These queries might be handled in serial or parallel as they race the
169// responses.
170TEST_F(TransportTest, RacingQueries_10000) {
171 FakeSocketFactory<FakeSocketEcho> factory;
172 DnsTlsTransport transport(SERVER1, MARK, &factory);
173 std::vector<std::future<DnsTlsTransport::Result>> results;
174 // Fewer than 65536 queries to avoid ID exhaustion.
175 for (int i = 0; i < 10000; ++i) {
176 results.push_back(transport.query(makeSlice(QUERY)));
177 }
178 for (auto& result : results) {
179 auto r = result.get();
180 EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
181 EXPECT_EQ(QUERY, r.response);
182 }
183}
184
185// A server that waits until sDelay queries are queued before responding.
186class FakeSocketDelay : public IDnsTlsSocket {
Erik Klineab999f12018-07-04 11:29:31 +0900187 public:
188 explicit FakeSocketDelay(IDnsTlsSocketObserver* observer) : mObserver(observer) {}
Bernie Innocentiabf8a342018-08-10 15:17:16 +0900189 ~FakeSocketDelay() { std::lock_guard guard(mLock); }
Ben Schwartz33860762017-10-25 14:41:02 -0400190 static size_t sDelay;
191 static bool sReverse;
192
193 bool query(uint16_t id, const Slice query) override {
194 ALOGD("FakeSocketDelay got query with ID %d", int(id));
Bernie Innocentiabf8a342018-08-10 15:17:16 +0900195 std::lock_guard guard(mLock);
Ben Schwartz33860762017-10-25 14:41:02 -0400196 // Check for duplicate IDs.
197 EXPECT_EQ(0U, mIds.count(id));
198 mIds.insert(id);
199
200 // Store response.
201 mResponses.push_back(make_echo(id, query));
202
203 ALOGD("Up to %zu out of %zu queries", mResponses.size(), sDelay);
204 if (mResponses.size() == sDelay) {
205 std::thread(&FakeSocketDelay::sendResponses, this).detach();
206 }
207 return true;
208 }
Erik Klineab999f12018-07-04 11:29:31 +0900209
210 private:
Ben Schwartz33860762017-10-25 14:41:02 -0400211 void sendResponses() {
Bernie Innocentiabf8a342018-08-10 15:17:16 +0900212 std::lock_guard guard(mLock);
Ben Schwartz33860762017-10-25 14:41:02 -0400213 if (sReverse) {
214 std::reverse(std::begin(mResponses), std::end(mResponses));
215 }
216 for (auto& response : mResponses) {
217 mObserver->onResponse(response);
218 }
219 mIds.clear();
220 mResponses.clear();
221 }
222
223 std::mutex mLock;
224 IDnsTlsSocketObserver* const mObserver;
225 std::set<uint16_t> mIds GUARDED_BY(mLock);
226 std::vector<bytevec> mResponses GUARDED_BY(mLock);
227};
228
229size_t FakeSocketDelay::sDelay;
230bool FakeSocketDelay::sReverse;
231
232TEST_F(TransportTest, ParallelColliding) {
233 FakeSocketDelay::sDelay = 10;
234 FakeSocketDelay::sReverse = false;
235 FakeSocketFactory<FakeSocketDelay> factory;
236 DnsTlsTransport transport(SERVER1, MARK, &factory);
237 std::vector<std::future<DnsTlsTransport::Result>> results;
238 // Fewer than 65536 queries to avoid ID exhaustion.
239 for (size_t i = 0; i < FakeSocketDelay::sDelay; ++i) {
240 results.push_back(transport.query(makeSlice(QUERY)));
241 }
242 for (auto& result : results) {
243 auto r = result.get();
244 EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
245 EXPECT_EQ(QUERY, r.response);
246 }
247}
248
249TEST_F(TransportTest, ParallelColliding_Max) {
250 FakeSocketDelay::sDelay = 65536;
251 FakeSocketDelay::sReverse = false;
252 FakeSocketFactory<FakeSocketDelay> factory;
253 DnsTlsTransport transport(SERVER1, MARK, &factory);
254 std::vector<std::future<DnsTlsTransport::Result>> results;
255 // Exactly 65536 queries should still be possible in parallel,
256 // even if they all have the same original ID.
257 for (size_t i = 0; i < FakeSocketDelay::sDelay; ++i) {
258 results.push_back(transport.query(makeSlice(QUERY)));
259 }
260 for (auto& result : results) {
261 auto r = result.get();
262 EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
263 EXPECT_EQ(QUERY, r.response);
264 }
265}
266
267TEST_F(TransportTest, ParallelUnique) {
268 FakeSocketDelay::sDelay = 10;
269 FakeSocketDelay::sReverse = false;
270 FakeSocketFactory<FakeSocketDelay> factory;
271 DnsTlsTransport transport(SERVER1, MARK, &factory);
272 std::vector<bytevec> queries(FakeSocketDelay::sDelay);
273 std::vector<std::future<DnsTlsTransport::Result>> results;
274 for (size_t i = 0; i < FakeSocketDelay::sDelay; ++i) {
275 queries[i] = make_query(i, SIZE);
276 results.push_back(transport.query(makeSlice(queries[i])));
277 }
278 for (size_t i = 0 ; i < FakeSocketDelay::sDelay; ++i) {
279 auto r = results[i].get();
280 EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
281 EXPECT_EQ(queries[i], r.response);
282 }
283}
284
285TEST_F(TransportTest, ParallelUnique_Max) {
286 FakeSocketDelay::sDelay = 65536;
287 FakeSocketDelay::sReverse = false;
288 FakeSocketFactory<FakeSocketDelay> factory;
289 DnsTlsTransport transport(SERVER1, MARK, &factory);
290 std::vector<bytevec> queries(FakeSocketDelay::sDelay);
291 std::vector<std::future<DnsTlsTransport::Result>> results;
292 // Exactly 65536 queries should still be possible in parallel,
293 // and they should all be mapped correctly back to the original ID.
294 for (size_t i = 0; i < FakeSocketDelay::sDelay; ++i) {
295 queries[i] = make_query(i, SIZE);
296 results.push_back(transport.query(makeSlice(queries[i])));
297 }
298 for (size_t i = 0 ; i < FakeSocketDelay::sDelay; ++i) {
299 auto r = results[i].get();
300 EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
301 EXPECT_EQ(queries[i], r.response);
302 }
303}
304
305TEST_F(TransportTest, IdExhaustion) {
306 // A delay of 65537 is unreachable, because the maximum number
307 // of outstanding queries is 65536.
308 FakeSocketDelay::sDelay = 65537;
309 FakeSocketDelay::sReverse = false;
310 FakeSocketFactory<FakeSocketDelay> factory;
311 DnsTlsTransport transport(SERVER1, MARK, &factory);
312 std::vector<std::future<DnsTlsTransport::Result>> results;
313 // Issue the maximum number of queries.
314 for (int i = 0; i < 65536; ++i) {
315 results.push_back(transport.query(makeSlice(QUERY)));
316 }
317
318 // The ID space is now full, so subsequent queries should fail immediately.
319 auto r = transport.query(makeSlice(QUERY)).get();
320 EXPECT_EQ(DnsTlsTransport::Response::internal_error, r.code);
321 EXPECT_TRUE(r.response.empty());
322
323 for (auto& result : results) {
324 // All other queries should remain outstanding.
325 EXPECT_EQ(std::future_status::timeout,
326 result.wait_for(std::chrono::duration<int>::zero()));
327 }
328}
329
330// Responses can come back from the server in any order. This should have no
331// effect on Transport's observed behavior.
332TEST_F(TransportTest, ReverseOrder) {
333 FakeSocketDelay::sDelay = 10;
334 FakeSocketDelay::sReverse = true;
335 FakeSocketFactory<FakeSocketDelay> factory;
336 DnsTlsTransport transport(SERVER1, MARK, &factory);
337 std::vector<bytevec> queries(FakeSocketDelay::sDelay);
338 std::vector<std::future<DnsTlsTransport::Result>> results;
339 for (size_t i = 0; i < FakeSocketDelay::sDelay; ++i) {
340 queries[i] = make_query(i, SIZE);
341 results.push_back(transport.query(makeSlice(queries[i])));
342 }
343 for (size_t i = 0 ; i < FakeSocketDelay::sDelay; ++i) {
344 auto r = results[i].get();
345 EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
346 EXPECT_EQ(queries[i], r.response);
347 }
348}
349
350TEST_F(TransportTest, ReverseOrder_Max) {
351 FakeSocketDelay::sDelay = 65536;
352 FakeSocketDelay::sReverse = true;
353 FakeSocketFactory<FakeSocketDelay> factory;
354 DnsTlsTransport transport(SERVER1, MARK, &factory);
355 std::vector<bytevec> queries(FakeSocketDelay::sDelay);
356 std::vector<std::future<DnsTlsTransport::Result>> results;
357 for (size_t i = 0; i < FakeSocketDelay::sDelay; ++i) {
358 queries[i] = make_query(i, SIZE);
359 results.push_back(transport.query(makeSlice(queries[i])));
360 }
361 for (size_t i = 0 ; i < FakeSocketDelay::sDelay; ++i) {
362 auto r = results[i].get();
363 EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
364 EXPECT_EQ(queries[i], r.response);
365 }
366}
367
Ben Schwartzded1b702017-10-25 14:41:02 -0400368// Returning null from the factory indicates a connection failure.
369class NullSocketFactory : public IDnsTlsSocketFactory {
Erik Klineab999f12018-07-04 11:29:31 +0900370 public:
Ben Schwartzded1b702017-10-25 14:41:02 -0400371 NullSocketFactory() {}
372 std::unique_ptr<IDnsTlsSocket> createDnsTlsSocket(
373 const DnsTlsServer& server ATTRIBUTE_UNUSED,
374 unsigned mark ATTRIBUTE_UNUSED,
Ben Schwartz33860762017-10-25 14:41:02 -0400375 IDnsTlsSocketObserver* observer ATTRIBUTE_UNUSED,
Ben Schwartzded1b702017-10-25 14:41:02 -0400376 DnsTlsSessionCache* cache ATTRIBUTE_UNUSED) override {
377 return nullptr;
378 }
379};
380
381TEST_F(TransportTest, ConnectFail) {
382 NullSocketFactory factory;
383 DnsTlsTransport transport(SERVER1, MARK, &factory);
Ben Schwartz33860762017-10-25 14:41:02 -0400384 auto r = transport.query(makeSlice(QUERY)).get();
Ben Schwartzded1b702017-10-25 14:41:02 -0400385
386 EXPECT_EQ(DnsTlsTransport::Response::network_error, r.code);
387 EXPECT_TRUE(r.response.empty());
388}
389
Ben Schwartz33860762017-10-25 14:41:02 -0400390// Simulate a socket that connects but then immediately receives a server
391// close notification.
392class FakeSocketClose : public IDnsTlsSocket {
Erik Klineab999f12018-07-04 11:29:31 +0900393 public:
394 explicit FakeSocketClose(IDnsTlsSocketObserver* observer)
395 : mCloser(&IDnsTlsSocketObserver::onClosed, observer) {}
Ben Schwartz33860762017-10-25 14:41:02 -0400396 ~FakeSocketClose() { mCloser.join(); }
397 bool query(uint16_t id ATTRIBUTE_UNUSED,
398 const Slice query ATTRIBUTE_UNUSED) override {
399 return true;
400 }
Erik Klineab999f12018-07-04 11:29:31 +0900401
402 private:
Ben Schwartz33860762017-10-25 14:41:02 -0400403 std::thread mCloser;
404};
405
406TEST_F(TransportTest, CloseRetryFail) {
407 FakeSocketFactory<FakeSocketClose> factory;
408 DnsTlsTransport transport(SERVER1, MARK, &factory);
409 auto r = transport.query(makeSlice(QUERY)).get();
410
411 EXPECT_EQ(DnsTlsTransport::Response::network_error, r.code);
412 EXPECT_TRUE(r.response.empty());
413}
414
415// Simulate a server that occasionally closes the connection and silently
416// drops some queries.
417class FakeSocketLimited : public IDnsTlsSocket {
Erik Klineab999f12018-07-04 11:29:31 +0900418 public:
Ben Schwartz33860762017-10-25 14:41:02 -0400419 static int sLimit; // Number of queries to answer per socket.
420 static size_t sMaxSize; // Silently discard queries greater than this size.
Erik Klineab999f12018-07-04 11:29:31 +0900421 explicit FakeSocketLimited(IDnsTlsSocketObserver* observer)
422 : mObserver(observer), mQueries(0) {}
Ben Schwartz33860762017-10-25 14:41:02 -0400423 ~FakeSocketLimited() {
424 {
425 ALOGD("~FakeSocketLimited acquiring mLock");
Bernie Innocentiabf8a342018-08-10 15:17:16 +0900426 std::lock_guard guard(mLock);
Ben Schwartz33860762017-10-25 14:41:02 -0400427 ALOGD("~FakeSocketLimited acquired mLock");
428 for (auto& thread : mThreads) {
429 ALOGD("~FakeSocketLimited joining response thread");
430 thread.join();
431 ALOGD("~FakeSocketLimited joined response thread");
432 }
433 mThreads.clear();
434 }
435
436 if (mCloser) {
437 ALOGD("~FakeSocketLimited joining closer thread");
438 mCloser->join();
439 ALOGD("~FakeSocketLimited joined closer thread");
440 }
441 }
442 bool query(uint16_t id, const Slice query) override {
443 ALOGD("FakeSocketLimited::query acquiring mLock");
Bernie Innocentiabf8a342018-08-10 15:17:16 +0900444 std::lock_guard guard(mLock);
Ben Schwartz33860762017-10-25 14:41:02 -0400445 ALOGD("FakeSocketLimited::query acquired mLock");
446 ++mQueries;
447
448 if (mQueries <= sLimit) {
449 ALOGD("size %zu vs. limit of %zu", query.size(), sMaxSize);
450 if (query.size() <= sMaxSize) {
451 // Return the response immediately (asynchronously).
452 mThreads.emplace_back(&IDnsTlsSocketObserver::onResponse, mObserver, make_echo(id, query));
453 }
454 }
455 if (mQueries == sLimit) {
456 mCloser = std::make_unique<std::thread>(&FakeSocketLimited::sendClose, this);
457 }
458 return mQueries <= sLimit;
459 }
Erik Klineab999f12018-07-04 11:29:31 +0900460
461 private:
Ben Schwartz33860762017-10-25 14:41:02 -0400462 void sendClose() {
463 {
464 ALOGD("FakeSocketLimited::sendClose acquiring mLock");
Bernie Innocentiabf8a342018-08-10 15:17:16 +0900465 std::lock_guard guard(mLock);
Ben Schwartz33860762017-10-25 14:41:02 -0400466 ALOGD("FakeSocketLimited::sendClose acquired mLock");
467 for (auto& thread : mThreads) {
468 ALOGD("FakeSocketLimited::sendClose joining response thread");
469 thread.join();
470 ALOGD("FakeSocketLimited::sendClose joined response thread");
471 }
472 mThreads.clear();
473 }
474 mObserver->onClosed();
475 }
476 std::mutex mLock;
477 IDnsTlsSocketObserver* const mObserver;
478 int mQueries GUARDED_BY(mLock);
479 std::vector<std::thread> mThreads GUARDED_BY(mLock);
480 std::unique_ptr<std::thread> mCloser GUARDED_BY(mLock);
481};
482
483int FakeSocketLimited::sLimit;
484size_t FakeSocketLimited::sMaxSize;
485
486TEST_F(TransportTest, SilentDrop) {
487 FakeSocketLimited::sLimit = 10; // Close the socket after 10 queries.
488 FakeSocketLimited::sMaxSize = 0; // Silently drop all queries
489 FakeSocketFactory<FakeSocketLimited> factory;
490 DnsTlsTransport transport(SERVER1, MARK, &factory);
491
492 // Queue up 10 queries. They will all be ignored, and after the 10th,
493 // the socket will close. Transport will retry them all, until they
494 // all hit the retry limit and expire.
495 std::vector<std::future<DnsTlsTransport::Result>> results;
496 for (int i = 0; i < FakeSocketLimited::sLimit; ++i) {
497 results.push_back(transport.query(makeSlice(QUERY)));
498 }
499 for (auto& result : results) {
500 auto r = result.get();
501 EXPECT_EQ(DnsTlsTransport::Response::network_error, r.code);
502 EXPECT_TRUE(r.response.empty());
503 }
504}
505
506TEST_F(TransportTest, PartialDrop) {
507 FakeSocketLimited::sLimit = 10; // Close the socket after 10 queries.
508 FakeSocketLimited::sMaxSize = SIZE - 2; // Silently drop "long" queries
509 FakeSocketFactory<FakeSocketLimited> factory;
510 DnsTlsTransport transport(SERVER1, MARK, &factory);
511
512 // Queue up 100 queries, alternating "short" which will be served and "long"
513 // which will be dropped.
514 int num_queries = 10 * FakeSocketLimited::sLimit;
515 std::vector<bytevec> queries(num_queries);
516 std::vector<std::future<DnsTlsTransport::Result>> results;
517 for (int i = 0; i < num_queries; ++i) {
518 queries[i] = make_query(i, SIZE + (i % 2));
519 results.push_back(transport.query(makeSlice(queries[i])));
520 }
521 // Just check the short queries, which are at the even indices.
522 for (int i = 0; i < num_queries; i += 2) {
523 auto r = results[i].get();
524 EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
525 EXPECT_EQ(queries[i], r.response);
526 }
527}
528
529// Simulate a malfunctioning server that injects extra miscellaneous
530// responses to queries that were not asked. This will cause wrong answers but
531// must not crash the Transport.
532class FakeSocketGarbage : public IDnsTlsSocket {
Erik Klineab999f12018-07-04 11:29:31 +0900533 public:
534 explicit FakeSocketGarbage(IDnsTlsSocketObserver* observer) : mObserver(observer) {
Ben Schwartz33860762017-10-25 14:41:02 -0400535 // Inject a garbage event.
536 mThreads.emplace_back(&IDnsTlsSocketObserver::onResponse, mObserver, make_query(ID + 1, SIZE));
537 }
538 ~FakeSocketGarbage() {
Bernie Innocentiabf8a342018-08-10 15:17:16 +0900539 std::lock_guard guard(mLock);
Ben Schwartz33860762017-10-25 14:41:02 -0400540 for (auto& thread : mThreads) {
541 thread.join();
542 }
543 }
544 bool query(uint16_t id, const Slice query) override {
Bernie Innocentiabf8a342018-08-10 15:17:16 +0900545 std::lock_guard guard(mLock);
Ben Schwartz33860762017-10-25 14:41:02 -0400546 // Return the response twice.
547 auto echo = make_echo(id, query);
548 mThreads.emplace_back(&IDnsTlsSocketObserver::onResponse, mObserver, echo);
549 mThreads.emplace_back(&IDnsTlsSocketObserver::onResponse, mObserver, echo);
550 // Also return some other garbage
551 mThreads.emplace_back(&IDnsTlsSocketObserver::onResponse, mObserver, make_query(id + 1, query.size() + 2));
552 return true;
553 }
Erik Klineab999f12018-07-04 11:29:31 +0900554
555 private:
Ben Schwartz33860762017-10-25 14:41:02 -0400556 std::mutex mLock;
557 std::vector<std::thread> mThreads GUARDED_BY(mLock);
558 IDnsTlsSocketObserver* const mObserver;
559};
560
561TEST_F(TransportTest, IgnoringGarbage) {
562 FakeSocketFactory<FakeSocketGarbage> factory;
563 DnsTlsTransport transport(SERVER1, MARK, &factory);
564 for (int i = 0; i < 10; ++i) {
565 auto r = transport.query(makeSlice(QUERY)).get();
566
567 EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
568 // Don't check the response because this server is malfunctioning.
569 }
570}
571
Ben Schwartzded1b702017-10-25 14:41:02 -0400572// Dispatcher tests
573class DispatcherTest : public BaseTest {};
574
575TEST_F(DispatcherTest, Query) {
576 bytevec ans(4096);
577 int resplen = 0;
578
579 auto factory = std::make_unique<FakeSocketFactory<FakeSocketEcho>>();
580 DnsTlsDispatcher dispatcher(std::move(factory));
581 auto r = dispatcher.query(SERVER1, MARK, makeSlice(QUERY),
582 makeSlice(ans), &resplen);
583
584 EXPECT_EQ(DnsTlsTransport::Response::success, r);
585 EXPECT_EQ(int(QUERY.size()), resplen);
586 ans.resize(resplen);
587 EXPECT_EQ(QUERY, ans);
588}
589
590TEST_F(DispatcherTest, AnswerTooLarge) {
591 bytevec ans(SIZE - 1); // Too small to hold the answer
592 int resplen = 0;
593
594 auto factory = std::make_unique<FakeSocketFactory<FakeSocketEcho>>();
595 DnsTlsDispatcher dispatcher(std::move(factory));
596 auto r = dispatcher.query(SERVER1, MARK, makeSlice(QUERY),
597 makeSlice(ans), &resplen);
598
599 EXPECT_EQ(DnsTlsTransport::Response::limit_error, r);
600}
601
602template<class T>
603class TrackingFakeSocketFactory : public IDnsTlsSocketFactory {
Erik Klineab999f12018-07-04 11:29:31 +0900604 public:
Ben Schwartzded1b702017-10-25 14:41:02 -0400605 TrackingFakeSocketFactory() {}
606 std::unique_ptr<IDnsTlsSocket> createDnsTlsSocket(
607 const DnsTlsServer& server,
608 unsigned mark,
Ben Schwartz33860762017-10-25 14:41:02 -0400609 IDnsTlsSocketObserver* observer,
Ben Schwartzded1b702017-10-25 14:41:02 -0400610 DnsTlsSessionCache* cache ATTRIBUTE_UNUSED) override {
Bernie Innocentiabf8a342018-08-10 15:17:16 +0900611 std::lock_guard guard(mLock);
Ben Schwartzded1b702017-10-25 14:41:02 -0400612 keys.emplace(mark, server);
Ben Schwartz33860762017-10-25 14:41:02 -0400613 return std::make_unique<T>(observer);
Ben Schwartzded1b702017-10-25 14:41:02 -0400614 }
615 std::multiset<std::pair<unsigned, DnsTlsServer>> keys;
Erik Klineab999f12018-07-04 11:29:31 +0900616
617 private:
Ben Schwartzded1b702017-10-25 14:41:02 -0400618 std::mutex mLock;
619};
620
621TEST_F(DispatcherTest, Dispatching) {
Ben Schwartz33860762017-10-25 14:41:02 -0400622 FakeSocketDelay::sDelay = 5;
623 FakeSocketDelay::sReverse = true;
624 auto factory = std::make_unique<TrackingFakeSocketFactory<FakeSocketDelay>>();
Ben Schwartzded1b702017-10-25 14:41:02 -0400625 auto* weak_factory = factory.get(); // Valid as long as dispatcher is in scope.
626 DnsTlsDispatcher dispatcher(std::move(factory));
627
628 // Populate a vector of two servers and two socket marks, four combinations
629 // in total.
630 std::vector<std::pair<unsigned, DnsTlsServer>> keys;
631 keys.emplace_back(MARK, SERVER1);
632 keys.emplace_back(MARK + 1, SERVER1);
633 keys.emplace_back(MARK, V4ADDR2);
634 keys.emplace_back(MARK + 1, V4ADDR2);
635
Ben Schwartz33860762017-10-25 14:41:02 -0400636 // Do several queries on each server. They should all succeed.
Ben Schwartzded1b702017-10-25 14:41:02 -0400637 std::vector<std::thread> threads;
Ben Schwartz33860762017-10-25 14:41:02 -0400638 for (size_t i = 0; i < FakeSocketDelay::sDelay * keys.size(); ++i) {
Ben Schwartzded1b702017-10-25 14:41:02 -0400639 auto key = keys[i % keys.size()];
640 threads.emplace_back([key, i] (DnsTlsDispatcher* dispatcher) {
641 auto q = make_query(i, SIZE);
642 bytevec ans(4096);
643 int resplen = 0;
644 unsigned mark = key.first;
645 const DnsTlsServer& server = key.second;
646 auto r = dispatcher->query(server, mark, makeSlice(q),
647 makeSlice(ans), &resplen);
648 EXPECT_EQ(DnsTlsTransport::Response::success, r);
649 EXPECT_EQ(int(q.size()), resplen);
650 ans.resize(resplen);
651 EXPECT_EQ(q, ans);
652 }, &dispatcher);
653 }
654 for (auto& thread : threads) {
655 thread.join();
656 }
657 // We expect that the factory created one socket for each key.
658 EXPECT_EQ(keys.size(), weak_factory->keys.size());
659 for (auto& key : keys) {
660 EXPECT_EQ(1U, weak_factory->keys.count(key));
661 }
662}
663
Ben Schwartze5595152017-10-25 14:41:02 -0400664// Check DnsTlsServer's comparison logic.
665AddressComparator ADDRESS_COMPARATOR;
666bool isAddressEqual(const DnsTlsServer& s1, const DnsTlsServer& s2) {
667 bool cmp1 = ADDRESS_COMPARATOR(s1, s2);
668 bool cmp2 = ADDRESS_COMPARATOR(s2, s1);
669 EXPECT_FALSE(cmp1 && cmp2);
670 return !cmp1 && !cmp2;
671}
672
673void checkUnequal(const DnsTlsServer& s1, const DnsTlsServer& s2) {
674 EXPECT_TRUE(s1 == s1);
675 EXPECT_TRUE(s2 == s2);
676 EXPECT_TRUE(isAddressEqual(s1, s1));
677 EXPECT_TRUE(isAddressEqual(s2, s2));
678
679 EXPECT_TRUE(s1 < s2 ^ s2 < s1);
680 EXPECT_FALSE(s1 == s2);
681 EXPECT_FALSE(s2 == s1);
682}
683
684class ServerTest : public BaseTest {};
685
686TEST_F(ServerTest, IPv4) {
687 checkUnequal(V4ADDR1, V4ADDR2);
688 EXPECT_FALSE(isAddressEqual(V4ADDR1, V4ADDR2));
689}
690
691TEST_F(ServerTest, IPv6) {
692 checkUnequal(V6ADDR1, V6ADDR2);
693 EXPECT_FALSE(isAddressEqual(V6ADDR1, V6ADDR2));
694}
695
696TEST_F(ServerTest, MixedAddressFamily) {
697 checkUnequal(V6ADDR1, V4ADDR1);
698 EXPECT_FALSE(isAddressEqual(V6ADDR1, V4ADDR1));
699}
700
701TEST_F(ServerTest, IPv6ScopeId) {
702 DnsTlsServer s1(V6ADDR1), s2(V6ADDR1);
703 sockaddr_in6* addr1 = reinterpret_cast<sockaddr_in6*>(&s1.ss);
704 addr1->sin6_scope_id = 1;
705 sockaddr_in6* addr2 = reinterpret_cast<sockaddr_in6*>(&s2.ss);
706 addr2->sin6_scope_id = 2;
707 checkUnequal(s1, s2);
708 EXPECT_FALSE(isAddressEqual(s1, s2));
Erik Kline1564d482018-03-07 17:09:35 +0900709
710 EXPECT_FALSE(s1.wasExplicitlyConfigured());
711 EXPECT_FALSE(s2.wasExplicitlyConfigured());
Ben Schwartze5595152017-10-25 14:41:02 -0400712}
713
714TEST_F(ServerTest, IPv6FlowInfo) {
715 DnsTlsServer s1(V6ADDR1), s2(V6ADDR1);
716 sockaddr_in6* addr1 = reinterpret_cast<sockaddr_in6*>(&s1.ss);
717 addr1->sin6_flowinfo = 1;
718 sockaddr_in6* addr2 = reinterpret_cast<sockaddr_in6*>(&s2.ss);
719 addr2->sin6_flowinfo = 2;
720 // All comparisons ignore flowinfo.
721 EXPECT_EQ(s1, s2);
722 EXPECT_TRUE(isAddressEqual(s1, s2));
Erik Kline1564d482018-03-07 17:09:35 +0900723
724 EXPECT_FALSE(s1.wasExplicitlyConfigured());
725 EXPECT_FALSE(s2.wasExplicitlyConfigured());
Ben Schwartze5595152017-10-25 14:41:02 -0400726}
727
728TEST_F(ServerTest, Port) {
729 DnsTlsServer s1, s2;
730 parseServer("192.0.2.1", 853, &s1.ss);
731 parseServer("192.0.2.1", 854, &s2.ss);
732 checkUnequal(s1, s2);
733 EXPECT_TRUE(isAddressEqual(s1, s2));
734
735 DnsTlsServer s3, s4;
736 parseServer("2001:db8::1", 853, &s3.ss);
737 parseServer("2001:db8::1", 852, &s4.ss);
738 checkUnequal(s3, s4);
739 EXPECT_TRUE(isAddressEqual(s3, s4));
Erik Kline1564d482018-03-07 17:09:35 +0900740
741 EXPECT_FALSE(s1.wasExplicitlyConfigured());
742 EXPECT_FALSE(s2.wasExplicitlyConfigured());
Ben Schwartze5595152017-10-25 14:41:02 -0400743}
744
745TEST_F(ServerTest, Name) {
746 DnsTlsServer s1(V4ADDR1), s2(V4ADDR1);
747 s1.name = SERVERNAME1;
748 checkUnequal(s1, s2);
749 s2.name = SERVERNAME2;
750 checkUnequal(s1, s2);
751 EXPECT_TRUE(isAddressEqual(s1, s2));
Erik Kline1564d482018-03-07 17:09:35 +0900752
753 EXPECT_TRUE(s1.wasExplicitlyConfigured());
754 EXPECT_TRUE(s2.wasExplicitlyConfigured());
Ben Schwartze5595152017-10-25 14:41:02 -0400755}
756
757TEST_F(ServerTest, Fingerprint) {
758 DnsTlsServer s1(V4ADDR1), s2(V4ADDR1);
759
760 s1.fingerprints.insert(FINGERPRINT1);
761 checkUnequal(s1, s2);
762 EXPECT_TRUE(isAddressEqual(s1, s2));
763
764 s2.fingerprints.insert(FINGERPRINT2);
765 checkUnequal(s1, s2);
766 EXPECT_TRUE(isAddressEqual(s1, s2));
767
768 s2.fingerprints.insert(FINGERPRINT1);
769 checkUnequal(s1, s2);
770 EXPECT_TRUE(isAddressEqual(s1, s2));
771
772 s1.fingerprints.insert(FINGERPRINT2);
773 EXPECT_EQ(s1, s2);
774 EXPECT_TRUE(isAddressEqual(s1, s2));
Erik Kline1564d482018-03-07 17:09:35 +0900775
776 EXPECT_TRUE(s1.wasExplicitlyConfigured());
777 EXPECT_TRUE(s2.wasExplicitlyConfigured());
Ben Schwartze5595152017-10-25 14:41:02 -0400778}
779
Ben Schwartz33860762017-10-25 14:41:02 -0400780TEST(QueryMapTest, Basic) {
781 DnsTlsQueryMap map;
782
783 EXPECT_TRUE(map.empty());
784
785 bytevec q0 = make_query(999, SIZE);
786 bytevec q1 = make_query(888, SIZE);
787 bytevec q2 = make_query(777, SIZE);
788
789 auto f0 = map.recordQuery(makeSlice(q0));
790 auto f1 = map.recordQuery(makeSlice(q1));
791 auto f2 = map.recordQuery(makeSlice(q2));
792
793 // Check return values of recordQuery
794 EXPECT_EQ(0, f0->query.newId);
795 EXPECT_EQ(1, f1->query.newId);
796 EXPECT_EQ(2, f2->query.newId);
797
798 // Check side effects of recordQuery
799 EXPECT_FALSE(map.empty());
800
801 auto all = map.getAll();
802 EXPECT_EQ(3U, all.size());
803
804 EXPECT_EQ(0, all[0].newId);
805 EXPECT_EQ(1, all[1].newId);
806 EXPECT_EQ(2, all[2].newId);
807
808 EXPECT_EQ(makeSlice(q0), all[0].query);
809 EXPECT_EQ(makeSlice(q1), all[1].query);
810 EXPECT_EQ(makeSlice(q2), all[2].query);
811
812 bytevec a0 = make_query(0, SIZE);
813 bytevec a1 = make_query(1, SIZE);
814 bytevec a2 = make_query(2, SIZE);
815
816 // Return responses out of order
817 map.onResponse(a2);
818 map.onResponse(a0);
819 map.onResponse(a1);
820
821 EXPECT_TRUE(map.empty());
822
823 auto r0 = f0->result.get();
824 auto r1 = f1->result.get();
825 auto r2 = f2->result.get();
826
827 EXPECT_EQ(DnsTlsQueryMap::Response::success, r0.code);
828 EXPECT_EQ(DnsTlsQueryMap::Response::success, r1.code);
829 EXPECT_EQ(DnsTlsQueryMap::Response::success, r2.code);
830
831 const bytevec& d0 = r0.response;
832 const bytevec& d1 = r1.response;
833 const bytevec& d2 = r2.response;
834
835 // The ID should match the query
836 EXPECT_EQ(999, d0[0] << 8 | d0[1]);
837 EXPECT_EQ(888, d1[0] << 8 | d1[1]);
838 EXPECT_EQ(777, d2[0] << 8 | d2[1]);
839 // The body should match the answer
840 EXPECT_EQ(bytevec(a0.begin() + 2, a0.end()), bytevec(d0.begin() + 2, d0.end()));
841 EXPECT_EQ(bytevec(a1.begin() + 2, a1.end()), bytevec(d1.begin() + 2, d1.end()));
842 EXPECT_EQ(bytevec(a2.begin() + 2, a2.end()), bytevec(d2.begin() + 2, d2.end()));
843}
844
845TEST(QueryMapTest, FillHole) {
846 DnsTlsQueryMap map;
847 std::vector<std::unique_ptr<DnsTlsQueryMap::QueryFuture>> futures(UINT16_MAX + 1);
848 for (uint32_t i = 0; i <= UINT16_MAX; ++i) {
849 futures[i] = map.recordQuery(makeSlice(QUERY));
850 ASSERT_TRUE(futures[i]); // answers[i] should be nonnull.
851 EXPECT_EQ(i, futures[i]->query.newId);
852 }
853
854 // The map should now be full.
855 EXPECT_EQ(size_t(UINT16_MAX + 1), map.getAll().size());
856
857 // Trying to add another query should fail because the map is full.
858 EXPECT_FALSE(map.recordQuery(makeSlice(QUERY)));
859
860 // Send an answer to query 40000
861 auto answer = make_query(40000, SIZE);
862 map.onResponse(answer);
863 auto result = futures[40000]->result.get();
864 EXPECT_EQ(DnsTlsQueryMap::Response::success, result.code);
865 EXPECT_EQ(ID, result.response[0] << 8 | result.response[1]);
866 EXPECT_EQ(bytevec(answer.begin() + 2, answer.end()),
867 bytevec(result.response.begin() + 2, result.response.end()));
868
869 // There should now be room in the map.
870 EXPECT_EQ(size_t(UINT16_MAX), map.getAll().size());
871 auto f = map.recordQuery(makeSlice(QUERY));
872 ASSERT_TRUE(f);
873 EXPECT_EQ(40000, f->query.newId);
874
875 // The map should now be full again.
876 EXPECT_EQ(size_t(UINT16_MAX + 1), map.getAll().size());
877 EXPECT_FALSE(map.recordQuery(makeSlice(QUERY)));
878}
879
Ben Schwartzded1b702017-10-25 14:41:02 -0400880} // end of namespace net
881} // end of namespace android