blob: 4cb4d501d146d99b099a250a43f51a516adc29b5 [file] [log] [blame]
Ben Schwartzded1b702017-10-25 14:41:02 -04001/*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#define LOG_TAG "dns_tls_test"
18
19#include <gtest/gtest.h>
20
21#include "dns/DnsTlsDispatcher.h"
22#include "dns/DnsTlsServer.h"
23#include "dns/DnsTlsSessionCache.h"
24#include "dns/DnsTlsSocket.h"
25#include "dns/DnsTlsTransport.h"
26#include "dns/IDnsTlsSocket.h"
27#include "dns/IDnsTlsSocketFactory.h"
28
29#include <chrono>
30#include <arpa/inet.h>
31#include <android-base/macros.h>
32#include <netdutils/Slice.h>
33
34#include "log/log.h"
35
36namespace android {
37namespace net {
38
39using netdutils::Slice;
40using netdutils::makeSlice;
41
// Raw DNS message bytes; used for queries, responses and fingerprints below.
using bytevec = std::vector<uint8_t>;
43
44static void parseServer(const char* server, in_port_t port, sockaddr_storage* parsed) {
45 sockaddr_in* sin = reinterpret_cast<sockaddr_in*>(parsed);
46 if (inet_pton(AF_INET, server, &(sin->sin_addr)) == 1) {
47 // IPv4 parse succeeded, so it's IPv4
48 sin->sin_family = AF_INET;
49 sin->sin_port = htons(port);
50 return;
51 }
52 sockaddr_in6* sin6 = reinterpret_cast<sockaddr_in6*>(parsed);
53 if (inet_pton(AF_INET6, server, &(sin6->sin6_addr)) == 1){
54 // IPv6 parse succeeded, so it's IPv6.
55 sin6->sin6_family = AF_INET6;
56 sin6->sin6_port = htons(port);
57 return;
58 }
59 ALOGE("Failed to parse server address: %s", server);
60}
61
62bytevec FINGERPRINT1 = { 1 };
Ben Schwartze5595152017-10-25 14:41:02 -040063bytevec FINGERPRINT2 = { 2 };
Ben Schwartzded1b702017-10-25 14:41:02 -040064
65std::string SERVERNAME1 = "dns.example.com";
Ben Schwartze5595152017-10-25 14:41:02 -040066std::string SERVERNAME2 = "dns.example.org";
Ben Schwartzded1b702017-10-25 14:41:02 -040067
68// BaseTest just provides constants that are useful for the tests.
69class BaseTest : public ::testing::Test {
70protected:
71 BaseTest() {
72 parseServer("192.0.2.1", 853, &V4ADDR1);
73 parseServer("192.0.2.2", 853, &V4ADDR2);
Ben Schwartze5595152017-10-25 14:41:02 -040074 parseServer("2001:db8::1", 853, &V6ADDR1);
75 parseServer("2001:db8::2", 853, &V6ADDR2);
Ben Schwartzded1b702017-10-25 14:41:02 -040076
77 SERVER1 = DnsTlsServer(V4ADDR1);
78 SERVER1.fingerprints.insert(FINGERPRINT1);
79 SERVER1.name = SERVERNAME1;
80 }
81
82 sockaddr_storage V4ADDR1;
83 sockaddr_storage V4ADDR2;
Ben Schwartze5595152017-10-25 14:41:02 -040084 sockaddr_storage V6ADDR1;
85 sockaddr_storage V6ADDR2;
Ben Schwartzded1b702017-10-25 14:41:02 -040086
87 DnsTlsServer SERVER1;
88};
89
90bytevec make_query(uint16_t id, size_t size) {
91 bytevec vec(size);
92 vec[0] = id >> 8;
93 vec[1] = id;
94 // Arbitrarily fill the query body with unique data.
95 for (size_t i = 2; i < size; ++i) {
96 vec[i] = id + i;
97 }
98 return vec;
99}
100
101// Query constants
102const unsigned MARK = 123;
103const uint16_t ID = 52;
104const uint16_t SIZE = 22;
105const bytevec QUERY = make_query(ID, SIZE);
106
107template <class T>
108class FakeSocketFactory : public IDnsTlsSocketFactory {
109public:
110 FakeSocketFactory() {}
111 std::unique_ptr<IDnsTlsSocket> createDnsTlsSocket(
112 const DnsTlsServer& server ATTRIBUTE_UNUSED,
113 unsigned mark ATTRIBUTE_UNUSED,
114 DnsTlsSessionCache* cache ATTRIBUTE_UNUSED) override {
115 return std::make_unique<T>();
116 }
117};
118
119bytevec make_echo(uint16_t id, const Slice query) {
120 bytevec response(query.size() + 2);
121 response[0] = id >> 8;
122 response[1] = id;
123 // Echo the query as the fake response.
124 memcpy(response.data() + 2, query.base(), query.size());
125 return response;
126}
127
128// Simplest possible fake server. This just echoes the query as the response.
129class FakeSocketEcho : public IDnsTlsSocket {
130public:
131 FakeSocketEcho() {}
132 DnsTlsServer::Result query(uint16_t id, const Slice query) override {
133 // Return the response immediately.
134 return { .code = DnsTlsServer::Response::success, .response = make_echo(id, query) };
135 }
136};
137
138class TransportTest : public BaseTest {};
139
140TEST_F(TransportTest, Query) {
141 FakeSocketFactory<FakeSocketEcho> factory;
142 DnsTlsTransport transport(SERVER1, MARK, &factory);
143 auto r = transport.query(makeSlice(QUERY));
144
145 EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
146 EXPECT_EQ(QUERY, r.response);
147}
148
149TEST_F(TransportTest, SerialQueries) {
150 FakeSocketFactory<FakeSocketEcho> factory;
151 DnsTlsTransport transport(SERVER1, MARK, &factory);
152 // Send more than 65536 queries serially.
153 for (int i = 0; i < 100000; ++i) {
154 auto r = transport.query(makeSlice(QUERY));
155
156 EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
157 EXPECT_EQ(QUERY, r.response);
158 }
159}
160
161// Returning null from the factory indicates a connection failure.
162class NullSocketFactory : public IDnsTlsSocketFactory {
163public:
164 NullSocketFactory() {}
165 std::unique_ptr<IDnsTlsSocket> createDnsTlsSocket(
166 const DnsTlsServer& server ATTRIBUTE_UNUSED,
167 unsigned mark ATTRIBUTE_UNUSED,
168 DnsTlsSessionCache* cache ATTRIBUTE_UNUSED) override {
169 return nullptr;
170 }
171};
172
173TEST_F(TransportTest, ConnectFail) {
174 NullSocketFactory factory;
175 DnsTlsTransport transport(SERVER1, MARK, &factory);
176 auto r = transport.query(makeSlice(QUERY));
177
178 EXPECT_EQ(DnsTlsTransport::Response::network_error, r.code);
179 EXPECT_TRUE(r.response.empty());
180}
181
182// Dispatcher tests
183class DispatcherTest : public BaseTest {};
184
185TEST_F(DispatcherTest, Query) {
186 bytevec ans(4096);
187 int resplen = 0;
188
189 auto factory = std::make_unique<FakeSocketFactory<FakeSocketEcho>>();
190 DnsTlsDispatcher dispatcher(std::move(factory));
191 auto r = dispatcher.query(SERVER1, MARK, makeSlice(QUERY),
192 makeSlice(ans), &resplen);
193
194 EXPECT_EQ(DnsTlsTransport::Response::success, r);
195 EXPECT_EQ(int(QUERY.size()), resplen);
196 ans.resize(resplen);
197 EXPECT_EQ(QUERY, ans);
198}
199
200TEST_F(DispatcherTest, AnswerTooLarge) {
201 bytevec ans(SIZE - 1); // Too small to hold the answer
202 int resplen = 0;
203
204 auto factory = std::make_unique<FakeSocketFactory<FakeSocketEcho>>();
205 DnsTlsDispatcher dispatcher(std::move(factory));
206 auto r = dispatcher.query(SERVER1, MARK, makeSlice(QUERY),
207 makeSlice(ans), &resplen);
208
209 EXPECT_EQ(DnsTlsTransport::Response::limit_error, r);
210}
211
212template<class T>
213class TrackingFakeSocketFactory : public IDnsTlsSocketFactory {
214public:
215 TrackingFakeSocketFactory() {}
216 std::unique_ptr<IDnsTlsSocket> createDnsTlsSocket(
217 const DnsTlsServer& server,
218 unsigned mark,
219 DnsTlsSessionCache* cache ATTRIBUTE_UNUSED) override {
220 std::lock_guard<std::mutex> guard(mLock);
221 keys.emplace(mark, server);
222 return std::make_unique<T>();
223 }
224 std::multiset<std::pair<unsigned, DnsTlsServer>> keys;
225private:
226 std::mutex mLock;
227};
228
229TEST_F(DispatcherTest, Dispatching) {
230 auto factory = std::make_unique<TrackingFakeSocketFactory<FakeSocketEcho>>();
231 auto* weak_factory = factory.get(); // Valid as long as dispatcher is in scope.
232 DnsTlsDispatcher dispatcher(std::move(factory));
233
234 // Populate a vector of two servers and two socket marks, four combinations
235 // in total.
236 std::vector<std::pair<unsigned, DnsTlsServer>> keys;
237 keys.emplace_back(MARK, SERVER1);
238 keys.emplace_back(MARK + 1, SERVER1);
239 keys.emplace_back(MARK, V4ADDR2);
240 keys.emplace_back(MARK + 1, V4ADDR2);
241
242 // Do one query on each server. They should all succeed.
243 std::vector<std::thread> threads;
244 for (size_t i = 0; i < keys.size(); ++i) {
245 auto key = keys[i % keys.size()];
246 threads.emplace_back([key, i] (DnsTlsDispatcher* dispatcher) {
247 auto q = make_query(i, SIZE);
248 bytevec ans(4096);
249 int resplen = 0;
250 unsigned mark = key.first;
251 const DnsTlsServer& server = key.second;
252 auto r = dispatcher->query(server, mark, makeSlice(q),
253 makeSlice(ans), &resplen);
254 EXPECT_EQ(DnsTlsTransport::Response::success, r);
255 EXPECT_EQ(int(q.size()), resplen);
256 ans.resize(resplen);
257 EXPECT_EQ(q, ans);
258 }, &dispatcher);
259 }
260 for (auto& thread : threads) {
261 thread.join();
262 }
263 // We expect that the factory created one socket for each key.
264 EXPECT_EQ(keys.size(), weak_factory->keys.size());
265 for (auto& key : keys) {
266 EXPECT_EQ(1U, weak_factory->keys.count(key));
267 }
268}
269
Ben Schwartze5595152017-10-25 14:41:02 -0400270// Check DnsTlsServer's comparison logic.
271AddressComparator ADDRESS_COMPARATOR;
272bool isAddressEqual(const DnsTlsServer& s1, const DnsTlsServer& s2) {
273 bool cmp1 = ADDRESS_COMPARATOR(s1, s2);
274 bool cmp2 = ADDRESS_COMPARATOR(s2, s1);
275 EXPECT_FALSE(cmp1 && cmp2);
276 return !cmp1 && !cmp2;
277}
278
279void checkUnequal(const DnsTlsServer& s1, const DnsTlsServer& s2) {
280 EXPECT_TRUE(s1 == s1);
281 EXPECT_TRUE(s2 == s2);
282 EXPECT_TRUE(isAddressEqual(s1, s1));
283 EXPECT_TRUE(isAddressEqual(s2, s2));
284
285 EXPECT_TRUE(s1 < s2 ^ s2 < s1);
286 EXPECT_FALSE(s1 == s2);
287 EXPECT_FALSE(s2 == s1);
288}
289
290class ServerTest : public BaseTest {};
291
292TEST_F(ServerTest, IPv4) {
293 checkUnequal(V4ADDR1, V4ADDR2);
294 EXPECT_FALSE(isAddressEqual(V4ADDR1, V4ADDR2));
295}
296
297TEST_F(ServerTest, IPv6) {
298 checkUnequal(V6ADDR1, V6ADDR2);
299 EXPECT_FALSE(isAddressEqual(V6ADDR1, V6ADDR2));
300}
301
302TEST_F(ServerTest, MixedAddressFamily) {
303 checkUnequal(V6ADDR1, V4ADDR1);
304 EXPECT_FALSE(isAddressEqual(V6ADDR1, V4ADDR1));
305}
306
307TEST_F(ServerTest, IPv6ScopeId) {
308 DnsTlsServer s1(V6ADDR1), s2(V6ADDR1);
309 sockaddr_in6* addr1 = reinterpret_cast<sockaddr_in6*>(&s1.ss);
310 addr1->sin6_scope_id = 1;
311 sockaddr_in6* addr2 = reinterpret_cast<sockaddr_in6*>(&s2.ss);
312 addr2->sin6_scope_id = 2;
313 checkUnequal(s1, s2);
314 EXPECT_FALSE(isAddressEqual(s1, s2));
315}
316
317TEST_F(ServerTest, IPv6FlowInfo) {
318 DnsTlsServer s1(V6ADDR1), s2(V6ADDR1);
319 sockaddr_in6* addr1 = reinterpret_cast<sockaddr_in6*>(&s1.ss);
320 addr1->sin6_flowinfo = 1;
321 sockaddr_in6* addr2 = reinterpret_cast<sockaddr_in6*>(&s2.ss);
322 addr2->sin6_flowinfo = 2;
323 // All comparisons ignore flowinfo.
324 EXPECT_EQ(s1, s2);
325 EXPECT_TRUE(isAddressEqual(s1, s2));
326}
327
328TEST_F(ServerTest, Port) {
329 DnsTlsServer s1, s2;
330 parseServer("192.0.2.1", 853, &s1.ss);
331 parseServer("192.0.2.1", 854, &s2.ss);
332 checkUnequal(s1, s2);
333 EXPECT_TRUE(isAddressEqual(s1, s2));
334
335 DnsTlsServer s3, s4;
336 parseServer("2001:db8::1", 853, &s3.ss);
337 parseServer("2001:db8::1", 852, &s4.ss);
338 checkUnequal(s3, s4);
339 EXPECT_TRUE(isAddressEqual(s3, s4));
340}
341
342TEST_F(ServerTest, Name) {
343 DnsTlsServer s1(V4ADDR1), s2(V4ADDR1);
344 s1.name = SERVERNAME1;
345 checkUnequal(s1, s2);
346 s2.name = SERVERNAME2;
347 checkUnequal(s1, s2);
348 EXPECT_TRUE(isAddressEqual(s1, s2));
349}
350
351TEST_F(ServerTest, Fingerprint) {
352 DnsTlsServer s1(V4ADDR1), s2(V4ADDR1);
353
354 s1.fingerprints.insert(FINGERPRINT1);
355 checkUnequal(s1, s2);
356 EXPECT_TRUE(isAddressEqual(s1, s2));
357
358 s2.fingerprints.insert(FINGERPRINT2);
359 checkUnequal(s1, s2);
360 EXPECT_TRUE(isAddressEqual(s1, s2));
361
362 s2.fingerprints.insert(FINGERPRINT1);
363 checkUnequal(s1, s2);
364 EXPECT_TRUE(isAddressEqual(s1, s2));
365
366 s1.fingerprints.insert(FINGERPRINT2);
367 EXPECT_EQ(s1, s2);
368 EXPECT_TRUE(isAddressEqual(s1, s2));
369}
370
Ben Schwartzded1b702017-10-25 14:41:02 -0400371} // end of namespace net
372} // end of namespace android