blob: 5ab95bded94c8a4fa9469a2f1c74c0916132a480 [file] [log] [blame]
Ben Schwartzded1b702017-10-25 14:41:02 -04001/*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#define LOG_TAG "dns_tls_test"
18
19#include <gtest/gtest.h>
20
21#include "dns/DnsTlsDispatcher.h"
22#include "dns/DnsTlsServer.h"
23#include "dns/DnsTlsSessionCache.h"
24#include "dns/DnsTlsSocket.h"
25#include "dns/DnsTlsTransport.h"
26#include "dns/IDnsTlsSocket.h"
27#include "dns/IDnsTlsSocketFactory.h"
28
29#include <chrono>
30#include <arpa/inet.h>
31#include <android-base/macros.h>
32#include <netdutils/Slice.h>
33
34#include "log/log.h"
35
36namespace android {
37namespace net {
38
39using netdutils::Slice;
40using netdutils::makeSlice;
41
// Byte buffer used for queries, answers and fingerprints throughout these tests.
using bytevec = std::vector<uint8_t>;
43
44static void parseServer(const char* server, in_port_t port, sockaddr_storage* parsed) {
45 sockaddr_in* sin = reinterpret_cast<sockaddr_in*>(parsed);
46 if (inet_pton(AF_INET, server, &(sin->sin_addr)) == 1) {
47 // IPv4 parse succeeded, so it's IPv4
48 sin->sin_family = AF_INET;
49 sin->sin_port = htons(port);
50 return;
51 }
52 sockaddr_in6* sin6 = reinterpret_cast<sockaddr_in6*>(parsed);
53 if (inet_pton(AF_INET6, server, &(sin6->sin6_addr)) == 1){
54 // IPv6 parse succeeded, so it's IPv6.
55 sin6->sin6_family = AF_INET6;
56 sin6->sin6_port = htons(port);
57 return;
58 }
59 ALOGE("Failed to parse server address: %s", server);
60}
61
62bytevec FINGERPRINT1 = { 1 };
63
64std::string SERVERNAME1 = "dns.example.com";
65
66// BaseTest just provides constants that are useful for the tests.
67class BaseTest : public ::testing::Test {
68protected:
69 BaseTest() {
70 parseServer("192.0.2.1", 853, &V4ADDR1);
71 parseServer("192.0.2.2", 853, &V4ADDR2);
72
73 SERVER1 = DnsTlsServer(V4ADDR1);
74 SERVER1.fingerprints.insert(FINGERPRINT1);
75 SERVER1.name = SERVERNAME1;
76 }
77
78 sockaddr_storage V4ADDR1;
79 sockaddr_storage V4ADDR2;
80
81 DnsTlsServer SERVER1;
82};
83
84bytevec make_query(uint16_t id, size_t size) {
85 bytevec vec(size);
86 vec[0] = id >> 8;
87 vec[1] = id;
88 // Arbitrarily fill the query body with unique data.
89 for (size_t i = 2; i < size; ++i) {
90 vec[i] = id + i;
91 }
92 return vec;
93}
94
95// Query constants
96const unsigned MARK = 123;
97const uint16_t ID = 52;
98const uint16_t SIZE = 22;
99const bytevec QUERY = make_query(ID, SIZE);
100
101template <class T>
102class FakeSocketFactory : public IDnsTlsSocketFactory {
103public:
104 FakeSocketFactory() {}
105 std::unique_ptr<IDnsTlsSocket> createDnsTlsSocket(
106 const DnsTlsServer& server ATTRIBUTE_UNUSED,
107 unsigned mark ATTRIBUTE_UNUSED,
108 DnsTlsSessionCache* cache ATTRIBUTE_UNUSED) override {
109 return std::make_unique<T>();
110 }
111};
112
113bytevec make_echo(uint16_t id, const Slice query) {
114 bytevec response(query.size() + 2);
115 response[0] = id >> 8;
116 response[1] = id;
117 // Echo the query as the fake response.
118 memcpy(response.data() + 2, query.base(), query.size());
119 return response;
120}
121
122// Simplest possible fake server. This just echoes the query as the response.
123class FakeSocketEcho : public IDnsTlsSocket {
124public:
125 FakeSocketEcho() {}
126 DnsTlsServer::Result query(uint16_t id, const Slice query) override {
127 // Return the response immediately.
128 return { .code = DnsTlsServer::Response::success, .response = make_echo(id, query) };
129 }
130};
131
132class TransportTest : public BaseTest {};
133
134TEST_F(TransportTest, Query) {
135 FakeSocketFactory<FakeSocketEcho> factory;
136 DnsTlsTransport transport(SERVER1, MARK, &factory);
137 auto r = transport.query(makeSlice(QUERY));
138
139 EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
140 EXPECT_EQ(QUERY, r.response);
141}
142
143TEST_F(TransportTest, SerialQueries) {
144 FakeSocketFactory<FakeSocketEcho> factory;
145 DnsTlsTransport transport(SERVER1, MARK, &factory);
146 // Send more than 65536 queries serially.
147 for (int i = 0; i < 100000; ++i) {
148 auto r = transport.query(makeSlice(QUERY));
149
150 EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
151 EXPECT_EQ(QUERY, r.response);
152 }
153}
154
155// Returning null from the factory indicates a connection failure.
156class NullSocketFactory : public IDnsTlsSocketFactory {
157public:
158 NullSocketFactory() {}
159 std::unique_ptr<IDnsTlsSocket> createDnsTlsSocket(
160 const DnsTlsServer& server ATTRIBUTE_UNUSED,
161 unsigned mark ATTRIBUTE_UNUSED,
162 DnsTlsSessionCache* cache ATTRIBUTE_UNUSED) override {
163 return nullptr;
164 }
165};
166
167TEST_F(TransportTest, ConnectFail) {
168 NullSocketFactory factory;
169 DnsTlsTransport transport(SERVER1, MARK, &factory);
170 auto r = transport.query(makeSlice(QUERY));
171
172 EXPECT_EQ(DnsTlsTransport::Response::network_error, r.code);
173 EXPECT_TRUE(r.response.empty());
174}
175
176// Dispatcher tests
177class DispatcherTest : public BaseTest {};
178
179TEST_F(DispatcherTest, Query) {
180 bytevec ans(4096);
181 int resplen = 0;
182
183 auto factory = std::make_unique<FakeSocketFactory<FakeSocketEcho>>();
184 DnsTlsDispatcher dispatcher(std::move(factory));
185 auto r = dispatcher.query(SERVER1, MARK, makeSlice(QUERY),
186 makeSlice(ans), &resplen);
187
188 EXPECT_EQ(DnsTlsTransport::Response::success, r);
189 EXPECT_EQ(int(QUERY.size()), resplen);
190 ans.resize(resplen);
191 EXPECT_EQ(QUERY, ans);
192}
193
194TEST_F(DispatcherTest, AnswerTooLarge) {
195 bytevec ans(SIZE - 1); // Too small to hold the answer
196 int resplen = 0;
197
198 auto factory = std::make_unique<FakeSocketFactory<FakeSocketEcho>>();
199 DnsTlsDispatcher dispatcher(std::move(factory));
200 auto r = dispatcher.query(SERVER1, MARK, makeSlice(QUERY),
201 makeSlice(ans), &resplen);
202
203 EXPECT_EQ(DnsTlsTransport::Response::limit_error, r);
204}
205
206template<class T>
207class TrackingFakeSocketFactory : public IDnsTlsSocketFactory {
208public:
209 TrackingFakeSocketFactory() {}
210 std::unique_ptr<IDnsTlsSocket> createDnsTlsSocket(
211 const DnsTlsServer& server,
212 unsigned mark,
213 DnsTlsSessionCache* cache ATTRIBUTE_UNUSED) override {
214 std::lock_guard<std::mutex> guard(mLock);
215 keys.emplace(mark, server);
216 return std::make_unique<T>();
217 }
218 std::multiset<std::pair<unsigned, DnsTlsServer>> keys;
219private:
220 std::mutex mLock;
221};
222
223TEST_F(DispatcherTest, Dispatching) {
224 auto factory = std::make_unique<TrackingFakeSocketFactory<FakeSocketEcho>>();
225 auto* weak_factory = factory.get(); // Valid as long as dispatcher is in scope.
226 DnsTlsDispatcher dispatcher(std::move(factory));
227
228 // Populate a vector of two servers and two socket marks, four combinations
229 // in total.
230 std::vector<std::pair<unsigned, DnsTlsServer>> keys;
231 keys.emplace_back(MARK, SERVER1);
232 keys.emplace_back(MARK + 1, SERVER1);
233 keys.emplace_back(MARK, V4ADDR2);
234 keys.emplace_back(MARK + 1, V4ADDR2);
235
236 // Do one query on each server. They should all succeed.
237 std::vector<std::thread> threads;
238 for (size_t i = 0; i < keys.size(); ++i) {
239 auto key = keys[i % keys.size()];
240 threads.emplace_back([key, i] (DnsTlsDispatcher* dispatcher) {
241 auto q = make_query(i, SIZE);
242 bytevec ans(4096);
243 int resplen = 0;
244 unsigned mark = key.first;
245 const DnsTlsServer& server = key.second;
246 auto r = dispatcher->query(server, mark, makeSlice(q),
247 makeSlice(ans), &resplen);
248 EXPECT_EQ(DnsTlsTransport::Response::success, r);
249 EXPECT_EQ(int(q.size()), resplen);
250 ans.resize(resplen);
251 EXPECT_EQ(q, ans);
252 }, &dispatcher);
253 }
254 for (auto& thread : threads) {
255 thread.join();
256 }
257 // We expect that the factory created one socket for each key.
258 EXPECT_EQ(keys.size(), weak_factory->keys.size());
259 for (auto& key : keys) {
260 EXPECT_EQ(1U, weak_factory->keys.count(key));
261 }
262}
263
264} // end of namespace net
265} // end of namespace android