blob: b854a38dc59d49b1aaad8dc01400f044bca3c16f [file] [log] [blame]
Mike Yua46fae72018-11-01 20:07:00 +08001/*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#define LOG_TAG "PrivateDnsConfiguration"
18#define DBG 0
19
#include <log/log.h>
#include <netdb.h>
#include <sys/socket.h>

#include <iterator>

#include "netd_resolv/DnsTlsTransport.h"
#include "netd_resolv/PrivateDnsConfiguration.h"
#include "netdutils/BackoffSequence.h"
27
28int resolv_set_private_dns_for_net(unsigned netid, uint32_t mark, const char** servers,
29 const unsigned numServers, const char* tlsName,
30 const uint8_t** fingerprints, const unsigned numFingerprint) {
31 std::vector<std::string> tlsServers;
32 for (unsigned i = 0; i < numServers; i++) {
33 tlsServers.push_back(std::string(servers[i]));
34 }
35
36 std::set<std::vector<uint8_t>> tlsFingerprints;
37 for (unsigned i = 0; i < numFingerprint; i++) {
38 // Each fingerprint stored are 32(SHA256_SIZE) bytes long.
39 tlsFingerprints.emplace(std::vector<uint8_t>(fingerprints[i], fingerprints[i] + 32));
40 }
41
42 return android::net::gPrivateDnsConfiguration.set(netid, mark, tlsServers, std::string(tlsName),
43 tlsFingerprints);
44}
45
46void resolv_delete_private_dns_for_net(unsigned netid) {
47 android::net::gPrivateDnsConfiguration.clear(netid);
48}
49
50void resolv_get_private_dns_status_for_net(unsigned netid, ExternalPrivateDnsStatus* status) {
51 android::net::gPrivateDnsConfiguration.getStatus(netid, status);
52}
53
54void resolv_register_private_dns_callback(private_dns_validated_callback callback) {
55 android::net::gPrivateDnsConfiguration.setCallback(callback);
56}
57
58namespace android {
59
60using android::netdutils::BackoffSequence;
61
62namespace net {
63
// Formats the IP address held in |addr| as numeric text, e.g. "192.0.2.1" or
// "2001:db8::1". The port is not included.
// Returns an empty string if |addr| is null, is not AF_INET/AF_INET6, or if
// getnameinfo() fails.
std::string addrToString(const sockaddr_storage* addr) {
    if (addr == nullptr) {
        return "";
    }

    // Pass the exact socket-address length for the family rather than
    // sizeof(sockaddr_storage), so we do not depend on how a particular
    // getnameinfo() implementation treats an oversized length.
    socklen_t addrLen = 0;
    switch (addr->ss_family) {
        case AF_INET:
            addrLen = sizeof(sockaddr_in);
            break;
        case AF_INET6:
            addrLen = sizeof(sockaddr_in6);
            break;
        default:
            return "";
    }

    // NI_MAXHOST leaves room for an IPv6 scope suffix (e.g. "%wlan0"), which can
    // push the text past INET6_ADDRSTRLEN and make getnameinfo() fail.
    char out[NI_MAXHOST] = {};
    const int err = getnameinfo(reinterpret_cast<const sockaddr*>(addr), addrLen, out,
                                sizeof(out), nullptr, 0, NI_NUMERICHOST);
    if (err != 0) {
        return "";
    }
    return std::string(out);
}
70
71bool parseServer(const char* server, sockaddr_storage* parsed) {
72 addrinfo hints = {.ai_family = AF_UNSPEC, .ai_flags = AI_NUMERICHOST | AI_NUMERICSERV};
73 addrinfo* res;
74
75 int err = getaddrinfo(server, "853", &hints, &res);
76 if (err != 0) {
77 ALOGW("Failed to parse server address (%s): %s", server, gai_strerror(err));
78 return false;
79 }
80
81 memcpy(parsed, res->ai_addr, res->ai_addrlen);
82 freeaddrinfo(res);
83 return true;
84}
85
86void PrivateDnsConfiguration::setCallback(private_dns_validated_callback callback) {
87 if (mCallback == nullptr) {
88 mCallback = callback;
89 }
90}
91
92int PrivateDnsConfiguration::set(int32_t netId, uint32_t mark,
93 const std::vector<std::string>& servers, const std::string& name,
94 const std::set<std::vector<uint8_t>>& fingerprints) {
95 if (DBG) {
96 ALOGD("PrivateDnsConfiguration::set(%u, %zu, %s, %zu)", netId, servers.size(), name.c_str(),
97 fingerprints.size());
98 }
99
100 const bool explicitlyConfigured = !name.empty() || !fingerprints.empty();
101
102 // Parse the list of servers that has been passed in
103 std::set<DnsTlsServer> tlsServers;
104 for (size_t i = 0; i < servers.size(); ++i) {
105 sockaddr_storage parsed;
106 if (!parseServer(servers[i].c_str(), &parsed)) {
107 return -EINVAL;
108 }
109 DnsTlsServer server(parsed);
110 server.name = name;
111 server.fingerprints = fingerprints;
112 tlsServers.insert(server);
113 }
114
115 std::lock_guard guard(mPrivateDnsLock);
116 if (explicitlyConfigured) {
117 mPrivateDnsModes[netId] = PrivateDnsMode::STRICT;
118 } else if (!tlsServers.empty()) {
119 mPrivateDnsModes[netId] = PrivateDnsMode::OPPORTUNISTIC;
120 } else {
121 mPrivateDnsModes[netId] = PrivateDnsMode::OFF;
122 mPrivateDnsTransports.erase(netId);
123 return 0;
124 }
125
126 // Create the tracker if it was not present
127 auto netPair = mPrivateDnsTransports.find(netId);
128 if (netPair == mPrivateDnsTransports.end()) {
129 // No TLS tracker yet for this netId.
130 bool added;
131 std::tie(netPair, added) = mPrivateDnsTransports.emplace(netId, PrivateDnsTracker());
132 if (!added) {
133 ALOGE("Memory error while recording private DNS for netId %d", netId);
134 return -ENOMEM;
135 }
136 }
137 auto& tracker = netPair->second;
138
139 // Remove any servers from the tracker that are not in |servers| exactly.
140 for (auto it = tracker.begin(); it != tracker.end();) {
141 if (tlsServers.count(it->first) == 0) {
142 it = tracker.erase(it);
143 } else {
144 ++it;
145 }
146 }
147
148 // Add any new or changed servers to the tracker, and initiate async checks for them.
149 for (const auto& server : tlsServers) {
150 if (needsValidation(tracker, server)) {
151 validatePrivateDnsProvider(server, tracker, netId, mark);
152 }
153 }
154 return 0;
155}
156
157PrivateDnsStatus PrivateDnsConfiguration::getStatus(unsigned netId) {
158 PrivateDnsStatus status{PrivateDnsMode::OFF, {}};
159
160 // This mutex is on the critical path of every DNS lookup.
161 //
162 // If the overhead of mutex acquisition proves too high, we could reduce
163 // it by maintaining an atomic_int32_t counter of TLS-enabled netids, or
164 // by using an RWLock.
165 std::lock_guard guard(mPrivateDnsLock);
166
167 const auto mode = mPrivateDnsModes.find(netId);
168 if (mode == mPrivateDnsModes.end()) return status;
169 status.mode = mode->second;
170
171 const auto netPair = mPrivateDnsTransports.find(netId);
172 if (netPair != mPrivateDnsTransports.end()) {
173 for (const auto& serverPair : netPair->second) {
174 if (serverPair.second == Validation::success) {
175 status.validatedServers.push_back(serverPair.first);
176 }
177 }
178 }
179
180 return status;
181}
182
183void PrivateDnsConfiguration::getStatus(unsigned netId, ExternalPrivateDnsStatus* status) {
184 // This mutex is on the critical path of every DNS lookup.
185 //
186 // If the overhead of mutex acquisition proves too high, we could reduce
187 // it by maintaining an atomic_int32_t counter of TLS-enabled netids, or
188 // by using an RWLock.
189 std::lock_guard guard(mPrivateDnsLock);
190
191 const auto mode = mPrivateDnsModes.find(netId);
192 if (mode == mPrivateDnsModes.end()) return;
193 status->mode = mode->second;
194
195 const auto netPair = mPrivateDnsTransports.find(netId);
196 if (netPair != mPrivateDnsTransports.end()) {
197 status->numServers = netPair->second.size();
198 int count = 0;
199 for (const auto& serverPair : netPair->second) {
200 status->serverStatus[count].ss = serverPair.first.ss;
201 status->serverStatus[count].hostname =
202 serverPair.first.name.empty() ? "" : serverPair.first.name.c_str();
203 status->serverStatus[count].validation = serverPair.second;
204 /*
205 unsigned numFingerprint = 0;
206 for (const auto& fp : serverPair.first.fingerprints) {
207 std::copy(
208 fp.begin(), fp.end(),
209 status->serverStatus[count].fingerprints.fingerprint[numFingerprint].data);
210 numFingerprint++;
211 }
212 status->serverStatus[count].fingerprints.num = numFingerprint;
213 */
214 count++;
215 }
216 }
217}
218
219void PrivateDnsConfiguration::clear(unsigned netId) {
220 if (DBG) {
221 ALOGD("PrivateDnsConfiguration::clear(%u)", netId);
222 }
223 std::lock_guard guard(mPrivateDnsLock);
224 mPrivateDnsModes.erase(netId);
225 mPrivateDnsTransports.erase(netId);
226}
227
228void PrivateDnsConfiguration::validatePrivateDnsProvider(const DnsTlsServer& server,
229 PrivateDnsTracker& tracker, unsigned netId,
230 uint32_t mark) REQUIRES(mPrivateDnsLock) {
231 if (DBG) {
232 ALOGD("validatePrivateDnsProvider(%s, %u)", addrToString(&(server.ss)).c_str(), netId);
233 }
234
235 tracker[server] = Validation::in_process;
236 if (DBG) {
237 ALOGD("Server %s marked as in_process. Tracker now has size %zu",
238 addrToString(&(server.ss)).c_str(), tracker.size());
239 }
240 // Note that capturing |server| and |netId| in this lambda create copies.
241 std::thread validate_thread([this, server, netId, mark] {
242 // cat /proc/sys/net/ipv4/tcp_syn_retries yields "6".
243 //
244 // Start with a 1 minute delay and backoff to once per hour.
245 //
246 // Assumptions:
247 // [1] Each TLS validation is ~10KB of certs+handshake+payload.
248 // [2] Network typically provision clients with <=4 nameservers.
249 // [3] Average month has 30 days.
250 //
251 // Each validation pass in a given hour is ~1.2MB of data. And 24
252 // such validation passes per day is about ~30MB per month, in the
253 // worst case. Otherwise, this will cost ~600 SYNs per month
254 // (6 SYNs per ip, 4 ips per validation pass, 24 passes per day).
255 auto backoff = BackoffSequence<>::Builder()
256 .withInitialRetransmissionTime(std::chrono::seconds(60))
257 .withMaximumRetransmissionTime(std::chrono::seconds(3600))
258 .build();
259
260 while (true) {
261 // ::validate() is a blocking call that performs network operations.
262 // It can take milliseconds to minutes, up to the SYN retry limit.
263 const bool success = DnsTlsTransport::validate(server, netId, mark);
264 if (DBG) {
265 ALOGD("validateDnsTlsServer returned %d for %s", success,
266 addrToString(&(server.ss)).c_str());
267 }
268
269 const bool needs_reeval = this->recordPrivateDnsValidation(server, netId, success);
270 if (!needs_reeval) {
271 break;
272 }
273
274 if (backoff.hasNextTimeout()) {
275 std::this_thread::sleep_for(backoff.getNextTimeout());
276 } else {
277 break;
278 }
279 }
280 });
281 validate_thread.detach();
282}
283
284bool PrivateDnsConfiguration::recordPrivateDnsValidation(const DnsTlsServer& server, unsigned netId,
285 bool success) {
286 constexpr bool NEEDS_REEVALUATION = true;
287 constexpr bool DONT_REEVALUATE = false;
288
289 std::lock_guard guard(mPrivateDnsLock);
290
291 auto netPair = mPrivateDnsTransports.find(netId);
292 if (netPair == mPrivateDnsTransports.end()) {
293 ALOGW("netId %u was erased during private DNS validation", netId);
294 return DONT_REEVALUATE;
295 }
296
297 const auto mode = mPrivateDnsModes.find(netId);
298 if (mode == mPrivateDnsModes.end()) {
299 ALOGW("netId %u has no private DNS validation mode", netId);
300 return DONT_REEVALUATE;
301 }
302 const bool modeDoesReevaluation = (mode->second == PrivateDnsMode::STRICT);
303
304 bool reevaluationStatus =
305 (success || !modeDoesReevaluation) ? DONT_REEVALUATE : NEEDS_REEVALUATION;
306
307 auto& tracker = netPair->second;
308 auto serverPair = tracker.find(server);
309 if (serverPair == tracker.end()) {
310 ALOGW("Server %s was removed during private DNS validation",
311 addrToString(&(server.ss)).c_str());
312 success = false;
313 reevaluationStatus = DONT_REEVALUATE;
314 } else if (!(serverPair->first == server)) {
315 // TODO: It doesn't seem correct to overwrite the tracker entry for
316 // |server| down below in this circumstance... Fix this.
317 ALOGW("Server %s was changed during private DNS validation",
318 addrToString(&(server.ss)).c_str());
319 success = false;
320 reevaluationStatus = DONT_REEVALUATE;
321 }
322
323 // Invoke the callback to send a validation event to NetdEventListenerService.
324 if (mCallback != nullptr) {
325 const char* ipLiteral = addrToString(&(server.ss)).c_str();
326 const char* hostname = server.name.empty() ? "" : server.name.c_str();
327 mCallback(netId, ipLiteral, hostname, success);
328 }
329
330 if (success) {
331 tracker[server] = Validation::success;
332 if (DBG) {
333 ALOGD("Validation succeeded for %s! Tracker now has %zu entries.",
334 addrToString(&(server.ss)).c_str(), tracker.size());
335 }
336 } else {
337 // Validation failure is expected if a user is on a captive portal.
338 // TODO: Trigger a second validation attempt after captive portal login
339 // succeeds.
340 tracker[server] = (reevaluationStatus == NEEDS_REEVALUATION) ? Validation::in_process
341 : Validation::fail;
342 if (DBG) {
343 ALOGD("Validation failed for %s!", addrToString(&(server.ss)).c_str());
344 }
345 }
346
347 return reevaluationStatus;
348}
349
350// Start validation for newly added servers as well as any servers that have
351// landed in Validation::fail state. Note that servers that have failed
352// multiple validation attempts but for which there is still a validating
353// thread running are marked as being Validation::in_process.
354bool PrivateDnsConfiguration::needsValidation(const PrivateDnsTracker& tracker,
355 const DnsTlsServer& server) {
356 const auto& iter = tracker.find(server);
357 return (iter == tracker.end()) || (iter->second == Validation::fail);
358}
359
360PrivateDnsConfiguration gPrivateDnsConfiguration;
361
362} // namespace net
363} // namespace android