Merge changes I92c05118,Ica83c11b
* changes:
Fix cgroup bpf program direction
Add dump function for trafficController
diff --git a/client/NetdClient.cpp b/client/NetdClient.cpp
index 821f488..fbbc9e7 100644
--- a/client/NetdClient.cpp
+++ b/client/NetdClient.cpp
@@ -122,6 +122,10 @@
if (netId != NETID_UNSET) {
return netId;
}
+ // Special case for DNS-over-TLS bypass; b/72345192 .
+ if ((netIdForResolv & ~NETID_USE_LOCAL_NAMESERVERS) != NETID_UNSET) {
+ return netIdForResolv;
+ }
netId = netIdForProcess;
if (netId != NETID_UNSET) {
return netId;
@@ -130,6 +134,9 @@
}
int setNetworkForTarget(unsigned netId, std::atomic_uint* target) {
+ const unsigned requestedNetId = netId;
+ netId &= ~NETID_USE_LOCAL_NAMESERVERS;
+
if (netId == NETID_UNSET) {
*target = netId;
return 0;
@@ -148,7 +155,7 @@
}
int error = setNetworkForSocket(netId, socketFd);
if (!error) {
- *target = netId;
+ *target = (target == &netIdForResolv) ? requestedNetId : netId;
}
close(socketFd);
return error;
diff --git a/include/NetdClient.h b/include/NetdClient.h
index d439daa..2c6edd0 100644
--- a/include/NetdClient.h
+++ b/include/NetdClient.h
@@ -21,6 +21,8 @@
#include <sys/cdefs.h>
#include <sys/types.h>
+// Bit ORed into a netId by callers that want DNS lookups to bypass Private DNS
+// (DNS-over-TLS) and use the network's own nameservers. The resolver strips the
+// bit before using the netId, and honours it only for callers with system permission.
+#define NETID_USE_LOCAL_NAMESERVERS 0x80000000
+
__BEGIN_DECLS
// All functions below that return an int return 0 on success or a negative errno value on failure.
diff --git a/server/DnsProxyListener.cpp b/server/DnsProxyListener.cpp
index be8cc19..18ffe1c 100644
--- a/server/DnsProxyListener.cpp
+++ b/server/DnsProxyListener.cpp
@@ -33,7 +33,9 @@
#define DBG 0
#define VDBG 0
+#include <algorithm>
#include <chrono>
+#include <list>
+#include <thread>
#include <vector>
#include <cutils/log.h>
@@ -47,6 +49,7 @@
#include "dns/DnsTlsDispatcher.h"
#include "dns/DnsTlsTransport.h"
#include "dns/DnsTlsServer.h"
+#include "NetdClient.h"
#include "NetdConstants.h"
#include "NetworkController.h"
#include "ResponseCode.h"
@@ -87,76 +90,131 @@
cli->decRef();
}
+// Reports whether *netid carries NETID_USE_LOCAL_NAMESERVERS and, if it does,
+// clears that bit in place. A null |netid| is reported as "not set".
+bool checkAndClearUseLocalNameserversFlag(unsigned* netid) {
+    if (netid == nullptr) {
+        return false;
+    }
+    const bool requested = (*netid & NETID_USE_LOCAL_NAMESERVERS) != 0;
+    if (requested) {
+        *netid &= ~NETID_USE_LOCAL_NAMESERVERS;
+    }
+    return requested;
+}
+
thread_local android_net_context thread_netcontext = {};
DnsTlsDispatcher dnsTlsDispatcher;
-res_sendhookact qhook(sockaddr* const * nsap, const u_char** buf, int* buflen,
+// Sleeps briefly. qhook calls this in strict mode while no Private DNS server
+// has validated yet, just before returning res_modified.
+// Needs <thread> for std::this_thread::sleep_for.
+void catnap() {
+    constexpr std::chrono::milliseconds kNapDuration{100};
+    std::this_thread::sleep_for(kNapDuration);
+}
+
+// Query hook (res_sendhookact) that maybeFixupNetContext() installs on each lookup's
+// android_net_context. Return values:
+// - res_goahead: let the normal cleartext path handle the query.
+// - res_done: the query was answered over DNS-over-TLS.
+// - res_modified: strict mode has no validated server yet; returned after a short sleep.
+// - res_error: a TLS query failed and cleartext fallback is not allowed.
+res_sendhookact qhook(sockaddr* const * /*ns*/, const u_char** buf, int* buflen,
u_char* ans, int anssiz, int* resplen) {
if (!thread_netcontext.qhook) {
ALOGE("qhook abort: thread qhook is null");
return res_goahead;
}
+ if (thread_netcontext.flags & NET_CONTEXT_FLAG_USE_LOCAL_NAMESERVERS) {
+ return res_goahead;
+ }
if (!net::gCtls) {
ALOGE("qhook abort: gCtls is null");
return res_goahead;
}
- // Safely read the data from nsap without violating strict-aliasing.
- sockaddr_storage insecureResolver;
- if ((*nsap)->sa_family == AF_INET) {
- std::memcpy(&insecureResolver, *nsap, sizeof(sockaddr_in));
- } else if ((*nsap)->sa_family == AF_INET6) {
- std::memcpy(&insecureResolver, *nsap, sizeof(sockaddr_in6));
- } else {
- ALOGE("qhook abort: unknown address family");
- return res_goahead;
+
+ const auto privateDnsStatus =
+ net::gCtls->resolverCtrl.getPrivateDnsStatus(thread_netcontext.dns_netid);
+
+ if (privateDnsStatus.mode == PrivateDnsMode::OFF) return res_goahead;
+
+ if (privateDnsStatus.validatedServers.empty()) {
+ if (privateDnsStatus.mode == PrivateDnsMode::OPPORTUNISTIC) {
+ return res_goahead;
+ } else {
+ // Sleep and iterate some small number of times checking for the
+ // arrival of resolved and validated server IP addresses, instead
+ // of returning an immediate error.
+ catnap();
+ return res_modified;
+ }
}
- DnsTlsServer tlsServer;
- auto tlsStatus = net::gCtls->resolverCtrl.getTlsStatus(thread_netcontext.dns_netid,
- insecureResolver, &tlsServer);
- if (tlsStatus == ResolverController::Validation::unknown_netid) {
- if (DBG) {
- ALOGD("No TLS for netid %u", thread_netcontext.dns_netid);
- }
- return res_goahead;
- } else if (tlsStatus == ResolverController::Validation::unknown_server) {
- if (DBG) {
- ALOGW("Skipping unexpected server in TLS mode");
- }
- return res_nextns;
- } else {
- if (tlsStatus != ResolverController::Validation::success) {
- if (DBG) {
- ALOGD("Server is not ready");
- }
- // TLS validation has not completed. In opportunistic mode, fall back to UDP.
- // In strict mode, try a different server.
- bool opportunistic = tlsServer.name.empty() && tlsServer.fingerprints.empty();
- return opportunistic ? res_goahead : res_nextns;
- }
- if (DBG) {
- ALOGD("Performing query over TLS");
- }
- Slice query = netdutils::Slice(const_cast<u_char*>(*buf), *buflen);
- Slice answer = netdutils::Slice(const_cast<u_char*>(ans), anssiz);
- auto response = dnsTlsDispatcher.query(tlsServer, thread_netcontext.dns_mark,
- query, answer, resplen);
- if (response == DnsTlsTransport::Response::success) {
- if (DBG) {
- ALOGD("qhook success");
- }
- return res_done;
- }
- if (DBG) {
- ALOGW("qhook abort: doQuery failed: %d", (int)response);
- }
- // If there was a network error on a validated server, try a different name server.
- if (response == DnsTlsTransport::Response::network_error) {
- return res_nextns;
- }
- // There was an internal error. Fail hard.
- return res_error;
+
+ if (DBG) ALOGD("Performing query over TLS");
+
+ Slice query = netdutils::Slice(const_cast<u_char*>(*buf), *buflen);
+ Slice answer = netdutils::Slice(const_cast<u_char*>(ans), anssiz);
+ const auto response = dnsTlsDispatcher.query(
+ privateDnsStatus.validatedServers, thread_netcontext.dns_mark,
+ query, answer, resplen);
+ if (response == DnsTlsTransport::Response::success) {
+ if (DBG) ALOGD("qhook success");
+ return res_done;
}
+
+ if (DBG) {
+ ALOGW("qhook abort: TLS query failed: %d", static_cast<int>(response));
+ }
+
+ if (privateDnsStatus.mode == PrivateDnsMode::OPPORTUNISTIC) {
+ // In opportunistic mode, handle falling back to cleartext in some
+ // cases (DNS shouldn't fail if a validated opportunistic mode server
+ // becomes unreachable for some reason).
+ switch (response) {
+ case DnsTlsTransport::Response::network_error:
+ case DnsTlsTransport::Response::internal_error:
+ // Note: this will cause cleartext queries to be emitted, with
+ // all of the EDNS0 goodness enabled. Fingers crossed. :-/
+ return res_goahead;
+ default:
+ break;
+ }
+ }
+
+ // Strict mode never falls back to cleartext. Opportunistic mode does so only for
+ // the errors handled above. Any other failure fails hard.
+ return res_error;
+}
+
+// True when |flags| asks to bypass Private DNS and query the network's own nameservers.
+constexpr bool requestingUseLocalNameservers(unsigned flags) {
+    return 0 != (flags & NET_CONTEXT_FLAG_USE_LOCAL_NAMESERVERS);
+}
+
+// Whether lookups on |dns_netid| will currently be sent over DNS-over-TLS.
+// This is always true in strict mode. In opportunistic mode it is true only once
+// at least one server has validated. Otherwise it is false.
+inline bool queryingViaTls(unsigned dns_netid) {
+    const auto status = net::gCtls->resolverCtrl.getPrivateDnsStatus(dns_netid);
+    if (status.mode == PrivateDnsMode::STRICT) {
+        return true;
+    }
+    if (status.mode == PrivateDnsMode::OPPORTUNISTIC) {
+        return !status.validatedServers.empty();
+    }
+    return false;
+}
+
+// Prepares |ctx| for a lookup on the current thread:
+// - drops a local-nameserver (Private DNS bypass) request from callers without
+//   PERMISSION_SYSTEM;
+// - turns on EDNS when queries will go over DNS-over-TLS;
+// - installs qhook;
+// - copies the result into thread_netcontext.
+void maybeFixupNetContext(android_net_context* ctx) {
+    const bool unauthorizedBypass = requestingUseLocalNameservers(ctx->flags) &&
+            net::gCtls->netCtrl.getPermissionForUser(ctx->uid) != Permission::PERMISSION_SYSTEM;
+    if (unauthorizedBypass) {
+        ctx->flags &= ~NET_CONTEXT_FLAG_USE_LOCAL_NAMESERVERS;
+    }
+
+    // Unless the caller is (still) explicitly bypassing DNS-over-TLS, treat "queries go
+    // over TLS" as the signal to enable more modern resolution mechanics such as EDNS.
+    if (!requestingUseLocalNameservers(ctx->flags) && queryingViaTls(ctx->dns_netid)) {
+        ctx->flags |= NET_CONTEXT_FLAG_USE_EDNS;
+    }
+
+    // The qhook is installed unconditionally because an opportunistic-mode server may
+    // finish validating before the hook runs. This races with the queryingViaTls()
+    // check above, so a query can end up going over TLS without EDNS; c'est la guerre.
+    ctx->qhook = &qhook;
+
+    // qhook is a plain function, so it reads the remaining context fields from this
+    // thread_local copy.
+    thread_netcontext = *ctx;
}
} // namespace
@@ -264,15 +322,15 @@
void DnsProxyListener::GetAddrInfoHandler::run() {
if (DBG) {
- ALOGD("GetAddrInfoHandler, now for %s / %s / {%u,%u,%u,%u,%u}", mHost, mService,
+ ALOGD("GetAddrInfoHandler, now for %s / %s / {%u,%u,%u,%u,%u,%u}", mHost, mService,
mNetContext.app_netid, mNetContext.app_mark,
mNetContext.dns_netid, mNetContext.dns_mark,
- mNetContext.uid);
+ mNetContext.uid, mNetContext.flags);
}
struct addrinfo* result = NULL;
Stopwatch s;
- thread_netcontext = mNetContext;
+ maybeFixupNetContext(&mNetContext);
uint32_t rv = android_getaddrinfofornetcontext(mHost, mService, mHints, &mNetContext, &result);
const int latencyMs = lround(s.timeTaken());
@@ -382,11 +440,14 @@
int ai_socktype = atoi(argv[5]);
int ai_protocol = atoi(argv[6]);
unsigned netId = strtoul(argv[7], NULL, 10);
+ const bool useLocalNameservers = checkAndClearUseLocalNameserversFlag(&netId);
uid_t uid = cli->getUid();
android_net_context netcontext;
mDnsProxyListener->mNetCtrl->getNetworkContext(netId, uid, &netcontext);
- netcontext.qhook = &qhook;
+ if (useLocalNameservers) {
+ netcontext.flags |= NET_CONTEXT_FLAG_USE_LOCAL_NAMESERVERS;
+ }
if (ai_flags != -1 || ai_family != -1 ||
ai_socktype != -1 || ai_protocol != -1) {
@@ -438,6 +499,7 @@
uid_t uid = cli->getUid();
unsigned netId = strtoul(argv[1], NULL, 10);
+ const bool useLocalNameservers = checkAndClearUseLocalNameserversFlag(&netId);
char* name = argv[2];
int af = atoi(argv[3]);
@@ -449,7 +511,9 @@
android_net_context netcontext;
mDnsProxyListener->mNetCtrl->getNetworkContext(netId, uid, &netcontext);
- netcontext.qhook = &qhook;
+ if (useLocalNameservers) {
+ netcontext.flags |= NET_CONTEXT_FLAG_USE_LOCAL_NAMESERVERS;
+ }
const int metricsLevel = mDnsProxyListener->mEventReporter->getMetricsReportingLevel();
@@ -481,7 +545,7 @@
}
Stopwatch s;
- thread_netcontext = mNetContext;
+ maybeFixupNetContext(&mNetContext);
struct hostent* hp = android_gethostbynamefornetcontext(mName, mAf, &mNetContext);
const int latencyMs = lround(s.timeTaken());
@@ -574,6 +638,7 @@
int addrFamily = atoi(argv[3]);
uid_t uid = cli->getUid();
unsigned netId = strtoul(argv[4], NULL, 10);
+ const bool useLocalNameservers = checkAndClearUseLocalNameserversFlag(&netId);
void* addr = malloc(sizeof(struct in6_addr));
errno = 0;
@@ -590,7 +655,9 @@
android_net_context netcontext;
mDnsProxyListener->mNetCtrl->getNetworkContext(netId, uid, &netcontext);
- netcontext.qhook = &qhook;
+ if (useLocalNameservers) {
+ netcontext.flags |= NET_CONTEXT_FLAG_USE_LOCAL_NAMESERVERS;
+ }
DnsProxyListener::GetHostByAddrHandler* handler =
new DnsProxyListener::GetHostByAddrHandler(cli, addr, addrLen, addrFamily, netcontext);
@@ -619,12 +686,11 @@
if (DBG) {
ALOGD("DnsProxyListener::GetHostByAddrHandler::run\n");
}
- struct hostent* hp;
// NOTE gethostbyaddr should take a void* but bionic thinks it should be char*
- thread_netcontext = mNetContext;
- hp = android_gethostbyaddrfornetcontext(
- (char*)mAddress, mAddressLen, mAddressFamily, &mNetContext);
+ maybeFixupNetContext(&mNetContext);
+ struct hostent* hp = android_gethostbyaddrfornetcontext(
+ (char*) mAddress, mAddressLen, mAddressFamily, &mNetContext);
if (DBG) {
ALOGD("GetHostByAddrHandler::run gethostbyaddr errno: %s hp->h_name = %s, name_len = %zu\n",
diff --git a/server/ResolverController.cpp b/server/ResolverController.cpp
index 081eb2a..86bfcb2 100644
--- a/server/ResolverController.cpp
+++ b/server/ResolverController.cpp
@@ -82,11 +82,20 @@
return true;
}
+// Human-readable name of |mode|, used when dumping Private DNS state.
+const char* getPrivateDnsModeString(PrivateDnsMode mode) {
+    switch (mode) {
+        case PrivateDnsMode::OFF: return "OFF";
+        case PrivateDnsMode::OPPORTUNISTIC: return "OPPORTUNISTIC";
+        case PrivateDnsMode::STRICT: return "STRICT";
+    }
+    // Unreachable for valid enumerators. An out-of-range value (e.g. from a cast)
+    // must still not fall off the end of a non-void function.
+    return "UNKNOWN";
+}
+
+// Guards the Private DNS maps annotated GUARDED_BY(privateDnsLock), including
+// privateDnsModes and privateDnsTransports.
+std::mutex privateDnsLock;
+// Per-netId Private DNS mode. A netId with no entry is treated as OFF.
+std::map<unsigned, PrivateDnsMode> privateDnsModes GUARDED_BY(privateDnsLock);
// Structure for tracking the validation status of servers on a specific netId.
// Using the AddressComparator ensures at most one entry per IP address.
typedef std::map<DnsTlsServer, ResolverController::Validation,
AddressComparator> PrivateDnsTracker;
-std::mutex privateDnsLock;
std::map<unsigned, PrivateDnsTracker> privateDnsTransports GUARDED_BY(privateDnsLock);
EventReporter eventReporter;
android::sp<android::net::metrics::INetdEventListener> netdEventListener;
@@ -165,15 +174,18 @@
validate_thread.detach();
}
-int setPrivateDnsProviders(int32_t netId,
+int setPrivateDnsConfiguration(int32_t netId,
const std::vector<std::string>& servers, const std::string& name,
const std::set<std::vector<uint8_t>>& fingerprints) {
if (DBG) {
- ALOGD("setPrivateDnsProviders(%u, %zu, %s, %zu)",
+ ALOGD("setPrivateDnsConfiguration(%u, %zu, %s, %zu)",
netId, servers.size(), name.c_str(), fingerprints.size());
}
+
+ const bool explicitlyConfigured = !name.empty() || !fingerprints.empty();
+
// Parse the list of servers that has been passed in
- std::set<DnsTlsServer> set;
+ std::set<DnsTlsServer> tlsServers;
for (size_t i = 0; i < servers.size(); ++i) {
sockaddr_storage parsed;
if (!parseServer(servers[i].c_str(), &parsed)) {
@@ -182,10 +194,20 @@
DnsTlsServer server(parsed);
server.name = name;
server.fingerprints = fingerprints;
- set.insert(server);
+ tlsServers.insert(server);
}
std::lock_guard<std::mutex> guard(privateDnsLock);
+ if (explicitlyConfigured) {
+ privateDnsModes[netId] = PrivateDnsMode::STRICT;
+ } else if (!tlsServers.empty()) {
+ privateDnsModes[netId] = PrivateDnsMode::OPPORTUNISTIC;
+ } else {
+ privateDnsModes[netId] = PrivateDnsMode::OFF;
+ privateDnsTransports.erase(netId);
+ return 0;
+ }
+
// Create the tracker if it was not present
auto netPair = privateDnsTransports.find(netId);
if (netPair == privateDnsTransports.end()) {
@@ -201,7 +223,7 @@
// Remove any servers from the tracker that are not in |servers| exactly.
for (auto it = tracker.begin(); it != tracker.end();) {
- if (set.count(it->first) == 0) {
+ if (tlsServers.count(it->first) == 0) {
it = tracker.erase(it);
} else {
++it;
@@ -209,7 +231,7 @@
}
// Add any new or changed servers to the tracker, and initiate async checks for them.
- for (const auto& server : set) {
+ for (const auto& server : tlsServers) {
// Don't probe a server more than once. This means that the only way to
// re-check a failed server is to remove it and re-add it from the netId.
if (tracker.count(server) == 0) {
@@ -224,6 +246,7 @@
ALOGD("clearPrivateDnsProviders(%u)", netId);
}
std::lock_guard<std::mutex> guard(privateDnsLock);
+ privateDnsModes.erase(netId);
privateDnsTransports.erase(netId);
}
@@ -248,37 +271,30 @@
return -_resolv_set_nameservers_for_net(netId, servers, numservers, searchDomains, params);
}
-ResolverController::Validation ResolverController::getTlsStatus(unsigned netId,
- const sockaddr_storage& insecureServer,
- DnsTlsServer* secureServer) {
- // This mutex is on the critical path of every DNS lookup that doesn't hit a local cache.
- // If the overhead of mutex acquisition proves too high, we could reduce it by maintaining
- // an atomic_int32_t counter of validated connections, and returning early if it's zero.
- if (DBG) {
- ALOGD("getTlsStatus(%u, %s)?", netId, addrToString(&insecureServer).c_str());
- }
+// Returns a snapshot, taken under privateDnsLock, of |netid|'s Private DNS mode
+// and of its servers whose validation status is Validation::success.
+ResolverController::PrivateDnsStatus
+ResolverController::getPrivateDnsStatus(unsigned netid) const {
+ PrivateDnsStatus status{PrivateDnsMode::OFF, {}};
+
+ // This mutex is on the critical path of every DNS lookup.
+ //
+ // If the overhead of mutex acquisition proves too high, we could reduce it
+ // by maintaining an atomic_int32_t counter of TLS-enabled netids, or by
+ // using an RWLock.
std::lock_guard<std::mutex> guard(privateDnsLock);
- const auto netPair = privateDnsTransports.find(netId);
- if (netPair == privateDnsTransports.end()) {
- if (DBG) {
- ALOGD("Not using TLS: no tracked servers for netId %u", netId);
+
+ // No mode entry: report OFF with no servers, even if transports are tracked.
+ const auto mode = privateDnsModes.find(netid);
+ if (mode == privateDnsModes.end()) return status;
+ status.mode = mode->second;
+
+ // Collect only the servers that passed validation.
+ const auto netPair = privateDnsTransports.find(netid);
+ if (netPair != privateDnsTransports.end()) {
+ for (const auto& serverPair : netPair->second) {
+ if (serverPair.second == Validation::success) {
+ status.validatedServers.push_back(serverPair.first);
+ }
}
- return Validation::unknown_netid;
}
- const auto& tracker = netPair->second;
- const auto serverPair = tracker.find(insecureServer);
- if (serverPair == tracker.end()) {
- if (DBG) {
- ALOGD("Server is not in the tracker (size %zu) for netid %u", tracker.size(), netId);
- }
- return Validation::unknown_server;
- }
- const auto& validatedServer = serverPair->first;
- Validation status = serverPair->second;
- if (DBG) {
- ALOGD("Server %s has status %d", addrToString(&(validatedServer.ss)).c_str(), (int)status);
- }
- *secureServer = validatedServer;
+
return status;
}
@@ -383,30 +399,16 @@
return -EINVAL;
}
- if (!tlsServers.empty()) {
- const int err = setPrivateDnsProviders(netId, tlsServers, tlsName, tlsFingerprints);
- if (err != 0) {
- return err;
- }
- } else {
- clearPrivateDnsProviders(netId);
+ const int err = setPrivateDnsConfiguration(netId, tlsServers, tlsName, tlsFingerprints);
+ if (err != 0) {
+ return err;
}
- // TODO: separate out configuring TLS servers and locally-assigned servers.
- // We should always program bionic with locally-assigned servers, so we can
- // make TLS-bypass simple by not setting .qhook in the right circumstances.
- // Relatedly, shunting queries to DNS-over-TLS should not be based on
- // matching resolver IPs in the qhook but rather purely a function of the
- // current state of DNS-over-TLS as known only within the dispatcher.
- const std::vector<std::string>& nameservers = (!tlsServers.empty())
- ? tlsServers // Strict mode or Opportunistic
- : servers; // off
-
- // Convert server list to bionic's format.
- auto server_count = std::min<size_t>(MAXNS, nameservers.size());
+ // Convert network-assigned server list to bionic's format.
+ auto server_count = std::min<size_t>(MAXNS, servers.size());
std::vector<const char*> server_ptrs;
for (size_t i = 0 ; i < server_count ; ++i) {
- server_ptrs.push_back(nameservers[i].c_str());
+ server_ptrs.push_back(servers[i].c_str());
}
std::string domains_str;
@@ -502,9 +504,12 @@
}
{
std::lock_guard<std::mutex> guard(privateDnsLock);
+ const auto& mode = privateDnsModes.find(netId);
+ dw.println("Private DNS mode: %s", getPrivateDnsModeString(
+ mode != privateDnsModes.end() ? mode->second : PrivateDnsMode::OFF));
const auto& netPair = privateDnsTransports.find(netId);
if (netPair == privateDnsTransports.end()) {
- dw.println("No Private DNS configured");
+ dw.println("No Private DNS servers configured");
} else {
const auto& tracker = netPair->second;
dw.println("Private DNS configuration (%zu entries)", tracker.size());
diff --git a/server/ResolverController.h b/server/ResolverController.h
index b67481f..287e199 100644
--- a/server/ResolverController.h
+++ b/server/ResolverController.h
@@ -17,6 +17,7 @@
#ifndef _RESOLVER_CONTROLLER_H_
#define _RESOLVER_CONTROLLER_H_
+#include <list>
#include <vector>
struct __res_params;
@@ -29,6 +30,13 @@
class DumpWriter;
struct ResolverStats;
+// Private DNS (DNS-over-TLS) mode of a netId. setPrivateDnsConfiguration() in
+// ResolverController.cpp selects it, and a netId without an entry is treated as OFF.
+enum class PrivateDnsMode {
+ // No DNS-over-TLS servers configured: queries go out in cleartext.
+ OFF,
+ // TLS servers configured, but no name or fingerprints. TLS is used once a server
+ // validates. Otherwise, and on network/internal TLS errors, queries fall back to cleartext.
+ OPPORTUNISTIC,
+ // A name or fingerprints were configured: queries never fall back to cleartext.
+ STRICT,
+};
+
+
class ResolverController {
public:
ResolverController() {};
@@ -42,12 +50,16 @@
// Validation status of a DNS over TLS server (on a specific netId).
enum class Validation : uint8_t { in_process, success, fail, unknown_server, unknown_netid };
- // Given a netId and the address of an insecure (i.e. normal) DNS server, this method checks
- // if there is a known secure DNS server with the same IP address that has been validated as
- // accessible on this netId. It returns the validation status, and provides the secure server
- // (including port, name, and fingerprints) in the output parameter.
- Validation getTlsStatus(unsigned netId, const sockaddr_storage& insecureServer,
- DnsTlsServer* secureServer);
+    struct PrivateDnsStatus {
+        // Defaults to OFF so a default-constructed status never carries an
+        // indeterminate mode.
+        PrivateDnsMode mode = PrivateDnsMode::OFF;
+        // Servers whose validation succeeded; empty if none have (yet).
+        std::list<DnsTlsServer> validatedServers;
+    };
+
+    // Retrieve the Private DNS status for the given |netid|.
+    //
+    // If the requested |netid| is not known, the PrivateDnsStatus's mode has a
+    // default value of PrivateDnsMode::OFF, and validatedServers is empty.
+    PrivateDnsStatus getPrivateDnsStatus(unsigned netid) const;
int clearDnsServers(unsigned netid);
diff --git a/server/TetherController.cpp b/server/TetherController.cpp
index c1d7308..6e805f5 100644
--- a/server/TetherController.cpp
+++ b/server/TetherController.cpp
@@ -138,7 +138,7 @@
mInterfaces.clear();
mDnsForwarders.clear();
mForwardingRequests.clear();
- ifacePairList.clear();
+ mFwdIfaces.clear();
}
bool TetherController::setIpFwdEnabled() {
@@ -439,7 +439,7 @@
return res;
}
- ifacePairList.clear();
+ mFwdIfaces.clear();
return 0;
}
@@ -472,8 +472,6 @@
return res;
}
- natCount = 0;
-
return 0;
}
@@ -492,48 +490,127 @@
return -1;
}
- // add this if we are the first added nat
- if (natCount == 0) {
+ if (isForwardingPairEnabled(intIface, extIface)) {
+ return 0;
+ }
+
+ // add this if we are the first enabled nat for this upstream
+ if (!isAnyForwardingEnabledOnUpstream(extIface)) {
std::vector<std::string> v4Cmds = {
"*nat",
StringPrintf("-A %s -o %s -j MASQUERADE", LOCAL_NAT_POSTROUTING, extIface),
"COMMIT\n"
};
- /*
- * IPv6 tethering doesn't need the state-based conntrack rules, so
- * it unconditionally jumps to the tether counters chain all the time.
- */
- std::vector<std::string> v6Cmds = {
- "*filter",
- StringPrintf("-A %s -g %s", LOCAL_FORWARD, LOCAL_TETHER_COUNTERS_CHAIN),
- "COMMIT\n"
- };
-
if (iptablesRestoreFunction(V4, Join(v4Cmds, '\n'), nullptr) ||
- iptablesRestoreFunction(V6, Join(v6Cmds, '\n'), nullptr)) {
+ setupIPv6CountersChain()) {
ALOGE("Error setting postroute rule: iface=%s", extIface);
- // unwind what's been done, but don't care about success - what more could we do?
- setDefaults();
+ if (!isAnyForwardingPairEnabled()) {
+ // unwind what's been done, but don't care about success - what more could we do?
+ setDefaults();
+ }
return -1;
}
}
if (setForwardRules(true, intIface, extIface) != 0) {
ALOGE("Error setting forward rules");
- if (natCount == 0) {
+ if (!isAnyForwardingPairEnabled()) {
setDefaults();
}
errno = ENODEV;
return -1;
}
- natCount++;
return 0;
}
-bool TetherController::checkTetherCountingRuleExist(const std::string& pair_name) {
- return std::find(ifacePairList.begin(), ifacePairList.end(), pair_name) != ifacePairList.end();
+// Adds the IPv6 jump from LOCAL_FORWARD to the tether counters chain, but only while
+// no forwarding pair is enabled yet (i.e. for the first NAT). Returns 0 when there
+// is nothing to do. Otherwise it returns the result of iptablesRestoreFunction().
+int TetherController::setupIPv6CountersChain() {
+    if (!isAnyForwardingPairEnabled()) {
+        // IPv6 tethering needs no state-based conntrack rules, so forwarded traffic
+        // unconditionally jumps to the tether counters chain.
+        const std::string forwardJump =
+                StringPrintf("-A %s -g %s", LOCAL_FORWARD, LOCAL_TETHER_COUNTERS_CHAIN);
+        const std::vector<std::string> v6Cmds = {"*filter", forwardJump, "COMMIT\n"};
+        return iptablesRestoreFunction(V6, Join(v6Cmds, '\n'), nullptr);
+    }
+    return 0;
+}
+
+// Returns the mFwdIfaces entry for downstream |intIface| under upstream |extIface|,
+// or nullptr if that pair is not in the map.
+TetherController::ForwardingDownstream* TetherController::findForwardingDownstream(
+        const std::string& intIface, const std::string& extIface) {
+    const auto candidates = mFwdIfaces.equal_range(extIface);
+    auto entry = candidates.first;
+    while (entry != candidates.second && entry->second.iface != intIface) {
+        ++entry;
+    }
+    return entry == candidates.second ? nullptr : &entry->second;
+}
+
+// Marks |intIface| -> |extIface| forwarding as active. A pair seen before is
+// reactivated in place, so each pair appears at most once in mFwdIfaces.
+void TetherController::addForwardingPair(const std::string& intIface, const std::string& extIface) {
+    ForwardingDownstream* existingEntry = findForwardingDownstream(intIface, extIface);
+    if (existingEntry != nullptr) {
+        existingEntry->active = true;
+        return;
+    }
+
+    // Plain aggregate initialization: designated initializers are only standard from C++20.
+    mFwdIfaces.emplace(extIface, ForwardingDownstream{intIface, true});
+}
+
+// Marks the pair inactive but keeps it in mFwdIfaces, so its counting rules are
+// still considered installed. Unknown pairs are ignored.
+void TetherController::markForwardingPairDisabled(
+        const std::string& intIface, const std::string& extIface) {
+    if (ForwardingDownstream* entry = findForwardingDownstream(intIface, extIface)) {
+        entry->active = false;
+    }
+}
+
+// True if the pair is in mFwdIfaces and currently marked active.
+bool TetherController::isForwardingPairEnabled(
+        const std::string& intIface, const std::string& extIface) {
+    const ForwardingDownstream* entry = findForwardingDownstream(intIface, extIface);
+    if (entry == nullptr) {
+        return false;
+    }
+    return entry->active;
+}
+
+// True if at least one downstream of upstream |extIface| is currently active.
+bool TetherController::isAnyForwardingEnabledOnUpstream(const std::string& extIface) {
+    const auto downstreams = mFwdIfaces.equal_range(extIface);
+    bool found = false;
+    for (auto entry = downstreams.first; !found && entry != downstreams.second; ++entry) {
+        found = entry->second.active;
+    }
+    return found;
+}
+
+// True if any forwarding pair, on any upstream, is currently active.
+bool TetherController::isAnyForwardingPairEnabled() {
+    // Read-only scan: bind each map element by const reference.
+    for (const auto& upstreamAndDownstream : mFwdIfaces) {
+        if (upstreamAndDownstream.second.active) {
+            return true;
+        }
+    }
+    return false;
+}
+
+// Counting rules are added in both directions when NAT is first enabled for a pair.
+// They are kept when the pair is disabled. So they exist iff the pair, in either
+// order, is present in mFwdIfaces, whether or not it is still active.
+bool TetherController::tetherCountingRuleExists(
+        const std::string& iface1, const std::string& iface2) {
+    if (findForwardingDownstream(iface1, iface2) != nullptr) {
+        return true;
+    }
+    return findForwardingDownstream(iface2, iface1) != nullptr;
}
/* static */
@@ -566,15 +643,11 @@
"*filter",
};
- /* We only ever add tethering quota rules so that they stick. */
- std::string pair1 = StringPrintf("%s_%s", intIface, extIface);
- if (add && !checkTetherCountingRuleExist(pair1)) {
+ // We only ever add tethering quota rules so that they stick.
+ if (add && !tetherCountingRuleExists(intIface, extIface)) {
v4.push_back(makeTetherCountingRule(intIface, extIface));
- v6.push_back(makeTetherCountingRule(intIface, extIface));
- }
- std::string pair2 = StringPrintf("%s_%s", extIface, intIface);
- if (add && !checkTetherCountingRuleExist(pair2)) {
v4.push_back(makeTetherCountingRule(extIface, intIface));
+ v6.push_back(makeTetherCountingRule(intIface, extIface));
v6.push_back(makeTetherCountingRule(extIface, intIface));
}
@@ -599,11 +672,10 @@
return -1;
}
- if (add && !checkTetherCountingRuleExist(pair1)) {
- ifacePairList.push_front(pair1);
- }
- if (add && !checkTetherCountingRuleExist(pair2)) {
- ifacePairList.push_front(pair2);
+ if (add) {
+ addForwardingPair(intIface, extIface);
+ } else {
+ markForwardingPairDisabled(intIface, extIface);
}
return 0;
@@ -616,8 +688,7 @@
}
setForwardRules(false, intIface, extIface);
- if (--natCount <= 0) {
- // handle decrement to 0 case (do reset to defaults) and erroneous dec below 0
+ if (!isAnyForwardingPairEnabled()) {
setDefaults();
}
return 0;
diff --git a/server/TetherController.h b/server/TetherController.h
index df43a7b..a34b5b7 100644
--- a/server/TetherController.h
+++ b/server/TetherController.h
@@ -33,8 +33,17 @@
class TetherController {
private:
+ // A downstream interface forwarded through an upstream (the upstream is the
+ // mFwdIfaces key).
+ struct ForwardingDownstream {
+ // Downstream (internal) interface name.
+ std::string iface;
+ // Whether forwarding for this pair is currently enabled.
+ bool active;
+ };
+
std::list<std::string> mInterfaces;
+ // Map upstream iface -> downstream iface. Once forwarding has been enabled for a pair,
+ // the pair stays in the map, active or not, until the map is cleared.
+ std::multimap<std::string, ForwardingDownstream> mFwdIfaces;
+
// NetId to use for forwarded DNS queries. This may not be the default
// network, e.g., in the case where we are tethering to a DUN APN.
unsigned mDnsNetId;
@@ -48,10 +57,6 @@
TetherController();
virtual ~TetherController();
- // List of strings of interface pairs. Public because it's used by CommandListener.
- // TODO: merge with mInterfaces, and make private.
- std::list<std::string> ifacePairList;
-
bool enableForwarding(const char* requester);
bool disableForwarding(const char* requester);
size_t forwardingRequestCount();
@@ -126,10 +131,17 @@
private:
bool setIpFwdEnabled();
- int natCount;
-
+ int setupIPv6CountersChain();
static std::string makeTetherCountingRule(const char *if1, const char *if2);
- bool checkTetherCountingRuleExist(const std::string& pair_name);
+ // Returns the mFwdIfaces entry for the pair, or nullptr if it is absent.
+ ForwardingDownstream* findForwardingDownstream(const std::string& intIface,
+ const std::string& extIface);
+ // Mark a pair active (adding it if needed), or inactive (keeping it in the map).
+ void addForwardingPair(const std::string& intIface, const std::string& extIface);
+ void markForwardingPairDisabled(const std::string& intIface, const std::string& extIface);
+
+ // Queries over the active flags stored in mFwdIfaces.
+ bool isForwardingPairEnabled(const std::string& intIface, const std::string& extIface);
+ bool isAnyForwardingEnabledOnUpstream(const std::string& extIface);
+ bool isAnyForwardingPairEnabled();
+ // True if counting rules exist for the pair, in either direction.
+ bool tetherCountingRuleExists(const std::string& iface1, const std::string& iface2);
int setDefaults();
int setForwardRules(bool set, const char *intIface, const char *extIface);
diff --git a/server/TetherControllerTest.cpp b/server/TetherControllerTest.cpp
index ba101fd..bcdb106 100644
--- a/server/TetherControllerTest.cpp
+++ b/server/TetherControllerTest.cpp
@@ -116,9 +116,7 @@
+// Appends every element of |appendCmds| to the end of |cmds|, preserving order.
template<typename T>
void appendAll(std::vector<T>& cmds, const std::vector<T>& appendCmds) {
- for (auto& cmd : appendCmds) {
- cmds.push_back(cmd);
- }
+ cmds.insert(cmds.end(), appendCmds.begin(), appendCmds.end());
}
ExpectedIptablesCommands startNatCommands(const char *intIf, const char *extIf,
@@ -167,6 +165,29 @@
};
}
+    // Named flags for allNewNatCommands() (constexpr already implies const).
+    static constexpr bool WITH_COUNTERS = true;
+    static constexpr bool NO_COUNTERS = false;
+    static constexpr bool WITH_IPV6 = true;
+    static constexpr bool NO_IPV6 = false;
+
+    // Expected iptables-restore commands when NAT is first enabled on upstream |extIf|, in order:
+    // 1. the IPv4 upstream setup;
+    // 2. optionally, the IPv6 upstream setup;
+    // 3. the per-pair NAT rules (optionally with the counters-chain rules).
+    ExpectedIptablesCommands allNewNatCommands(
+            const char *intIf, const char *extIf, bool withCounterChainRules,
+            bool withIPv6Upstream) {
+        ExpectedIptablesCommands commands = firstIPv4UpstreamCommands(extIf);
+        const ExpectedIptablesCommands natCommands =
+                startNatCommands(intIf, extIf, withCounterChainRules);
+        if (withIPv6Upstream) {
+            appendAll(commands, firstIPv6UpstreamCommands());
+        }
+        appendAll(commands, natCommands);
+        return commands;
+    }
+
ExpectedIptablesCommands stopNatCommands(const char *intIf, const char *extIf) {
std::string rpfilterCmd = StringPrintf(
"*raw\n"
@@ -204,18 +225,14 @@
TEST_F(TetherControllerTest, TestAddAndRemoveNat) {
// Start first NAT on first upstream interface. Expect the upstream and NAT rules to be created.
- ExpectedIptablesCommands firstNat;
- ExpectedIptablesCommands setupFirstIPv4Commands = firstIPv4UpstreamCommands("rmnet0");
- ExpectedIptablesCommands setupFirstIPv6Commands = firstIPv6UpstreamCommands();
- ExpectedIptablesCommands startFirstNatCommands = startNatCommands("wlan0", "rmnet0", true);
- appendAll(firstNat, setupFirstIPv4Commands);
- appendAll(firstNat, setupFirstIPv6Commands);
- appendAll(firstNat, startFirstNatCommands);
+ ExpectedIptablesCommands firstNat = allNewNatCommands(
+ "wlan0", "rmnet0", WITH_COUNTERS, WITH_IPV6);
mTetherCtrl.enableNat("wlan0", "rmnet0");
expectIptablesRestoreCommands(firstNat);
// Start second NAT on same upstream. Expect only the counter rules to be created.
- ExpectedIptablesCommands startOtherNatOnSameUpstream = startNatCommands("usb0", "rmnet0", true);
+ ExpectedIptablesCommands startOtherNatOnSameUpstream = startNatCommands(
+ "usb0", "rmnet0", WITH_COUNTERS);
mTetherCtrl.enableNat("usb0", "rmnet0");
expectIptablesRestoreCommands(startOtherNatOnSameUpstream);
@@ -231,13 +248,8 @@
mTetherCtrl.disableNat("usb0", "rmnet0");
expectIptablesRestoreCommands(stopLastNat);
- // Re-add a NAT removed previously
- firstNat = {};
- // tetherctrl_counters chain rules are not re-added
- startFirstNatCommands = startNatCommands("wlan0", "rmnet0", false);
- appendAll(firstNat, setupFirstIPv4Commands);
- appendAll(firstNat, setupFirstIPv6Commands);
- appendAll(firstNat, startFirstNatCommands);
+ // Re-add a NAT removed previously: tetherctrl_counters chain rules are not re-added
+ firstNat = allNewNatCommands("wlan0", "rmnet0", NO_COUNTERS, WITH_IPV6);
mTetherCtrl.enableNat("wlan0", "rmnet0");
expectIptablesRestoreCommands(firstNat);
@@ -248,6 +260,38 @@
expectIptablesRestoreCommands(stopLastNat);
}
+TEST_F(TetherControllerTest, TestMultipleUpstreams) {
+    // rmnet0 is the first upstream: upstream, NAT and counter rules are all
+    // expected, including the IPv6 upstream rules.
+    ExpectedIptablesCommands expectRmnet0Start =
+            allNewNatCommands("wlan0", "rmnet0", WITH_COUNTERS, WITH_IPV6);
+    mTetherCtrl.enableNat("wlan0", "rmnet0");
+    expectIptablesRestoreCommands(expectRmnet0Start);
+
+    // v4-rmnet0 is an additional upstream: its IPv4 upstream and NAT rules are
+    // expected, but no IPv6 counter rules (NO_IPV6).
+    ExpectedIptablesCommands expectV4Rmnet0Start =
+            allNewNatCommands("wlan0", "v4-rmnet0", WITH_COUNTERS, NO_IPV6);
+    mTetherCtrl.enableNat("wlan0", "v4-rmnet0");
+    expectIptablesRestoreCommands(expectV4Rmnet0Start);
+
+    // Enabling the same NAT again (a caller that lost track of its own state)
+    // must be a no-op: no iptables commands at all.
+    const ExpectedIptablesCommands noCommands = {};
+    mTetherCtrl.enableNat("wlan0", "v4-rmnet0");
+    expectIptablesRestoreCommands(noCommands);
+
+    // Tearing down v4-rmnet0 issues only that NAT's stop commands.
+    ExpectedIptablesCommands expectV4Rmnet0Stop = stopNatCommands("wlan0", "v4-rmnet0");
+    mTetherCtrl.disableNat("wlan0", "v4-rmnet0");
+    expectIptablesRestoreCommands(expectV4Rmnet0Stop);
+
+    // Tearing down the last NAT (rmnet0) is additionally expected to issue
+    // FLUSH_COMMANDS, i.e. the rules are cleared.
+    ExpectedIptablesCommands expectRmnet0Stop = stopNatCommands("wlan0", "rmnet0");
+    appendAll(expectRmnet0Stop, FLUSH_COMMANDS);
+    mTetherCtrl.disableNat("wlan0", "rmnet0");
+    expectIptablesRestoreCommands(expectRmnet0Stop);
+}
+
std::string kTetherCounterHeaders = Join(std::vector<std::string> {
"Chain tetherctrl_counters (4 references)",
" pkts bytes target prot opt in out source destination",
diff --git a/server/XfrmController.cpp b/server/XfrmController.cpp
index cc2c305..744728a 100644
--- a/server/XfrmController.cpp
+++ b/server/XfrmController.cpp
@@ -78,7 +78,7 @@
namespace {
-constexpr uint32_t RAND_SPI_MIN = 1;
+// NOTE(review): SPI values 1-255 are reserved by IANA (RFC 4303, section 2.1),
+// which is presumably why the minimum random SPI is 256 rather than 1.
+constexpr uint32_t RAND_SPI_MIN = 256;
constexpr uint32_t RAND_SPI_MAX = 0xFFFFFFFE;
constexpr uint32_t INVALID_SPI = 0;
@@ -146,7 +146,7 @@
offset += 8;
for (uint32_t j = 0; j < (uint32_t)len; j++) {
- sprintf(&printBuf[j * 2 + offset], "%0.2x", buf[j]);
+ sprintf(&printBuf[j * 2 + offset], "%0.2x", (unsigned char)buf[j]);
}
ALOGD("%s", printBuf);
delete[] printBuf;
diff --git a/server/dns/DnsTlsDispatcher.cpp b/server/dns/DnsTlsDispatcher.cpp
index d5c0fb2..95fbb9a 100644
--- a/server/dns/DnsTlsDispatcher.cpp
+++ b/server/dns/DnsTlsDispatcher.cpp
@@ -28,6 +28,85 @@
// static
std::mutex DnsTlsDispatcher::sLock;
+
+std::list<DnsTlsServer> DnsTlsDispatcher::getOrderedServerList(
+        const std::list<DnsTlsServer> &tlsServers, unsigned mark) const {
+    // Our preferred DnsTlsServer order is:
+    //     1) reuse existing IPv6 connections
+    //     2) reuse existing IPv4 connections
+    //     3) establish new IPv6 connections
+    //     4) establish new IPv4 connections
+    // Servers whose address family is neither AF_INET nor AF_INET6 fall
+    // through both switches below and are not returned at all.
+    std::list<DnsTlsServer> existing6;
+    std::list<DnsTlsServer> existing4;
+    std::list<DnsTlsServer> new6;
+    std::list<DnsTlsServer> new4;
+
+    // Pull out any servers for which we might have existing connections and
+    // place them at the front of the list of servers to try.
+    {
+        std::lock_guard<std::mutex> guard(sLock);
+
+        for (const auto& tlsServer : tlsServers) {
+            const Key key = std::make_pair(mark, tlsServer);
+            if (mStore.find(key) != mStore.end()) {
+                switch (tlsServer.ss.ss_family) {
+                    case AF_INET:
+                        existing4.push_back(tlsServer);
+                        break;
+                    case AF_INET6:
+                        existing6.push_back(tlsServer);
+                        break;
+                }
+            } else {
+                switch (tlsServer.ss.ss_family) {
+                    case AF_INET:
+                        new4.push_back(tlsServer);
+                        break;
+                    case AF_INET6:
+                        new6.push_back(tlsServer);
+                        break;
+                }
+            }
+        }
+    }
+
+    // Concatenate the buckets in preference order. splice() only relinks list
+    // nodes, and returning the local list by name lets it be moved (or the
+    // copy elided) instead of copied element by element.
+    existing6.splice(existing6.cend(), existing4);
+    existing6.splice(existing6.cend(), new6);
+    existing6.splice(existing6.cend(), new4);
+    return existing6;
+}
+
+DnsTlsTransport::Response DnsTlsDispatcher::query(
+        const std::list<DnsTlsServer> &tlsServers, unsigned mark,
+        const Slice query, const Slice ans, int *resplen) {
+    const std::list<DnsTlsServer> orderedServers(getOrderedServerList(tlsServers, mark));
+
+    if (orderedServers.empty()) ALOGW("Empty DnsTlsServer list");
+
+    // Returned unchanged when there is no server to try.
+    DnsTlsTransport::Response code = DnsTlsTransport::Response::internal_error;
+    for (const auto& server : orderedServers) {
+        code = this->query(server, mark, query, ans, resplen);
+        switch (code) {
+            // These response codes are valid responses and not expected to
+            // change if another server is queried.
+            case DnsTlsTransport::Response::success:
+            case DnsTlsTransport::Response::limit_error:
+                return code;
+            // These response codes might differ when trying other servers, so
+            // keep iterating to see if we can get a different (better) result.
+            case DnsTlsTransport::Response::network_error:
+            case DnsTlsTransport::Response::internal_error:
+                continue;
+            // No "default" statement, presumably so that -Wswitch flags any
+            // newly added Response value here.
+        }
+    }
+
+    // Every server failed: report the code from the last server tried.
+    return code;
+}
+
DnsTlsTransport::Response DnsTlsDispatcher::query(const DnsTlsServer& server, unsigned mark,
const Slice query,
const Slice ans, int *resplen) {
@@ -60,7 +139,7 @@
netdutils::copy(ans, netdutils::makeSlice(result.response));
}
} else {
- ALOGV("Query failed: %u", (unsigned int)code);
+ ALOGV("Query failed: %u", (unsigned int) code);
}
auto now = std::chrono::steady_clock::now();
diff --git a/server/dns/DnsTlsDispatcher.h b/server/dns/DnsTlsDispatcher.h
index 9487b51..5ba057a 100644
--- a/server/dns/DnsTlsDispatcher.h
+++ b/server/dns/DnsTlsDispatcher.h
@@ -17,6 +17,7 @@
#ifndef _DNS_DNSTLSDISPATCHER_H
#define _DNS_DNSTLSDISPATCHER_H
+#include <list>
#include <memory>
#include <map>
#include <mutex>
@@ -48,6 +49,14 @@
DnsTlsDispatcher(std::unique_ptr<IDnsTlsSocketFactory> factory) :
mFactory(std::move(factory)) {}
+ // Enqueues |query| for resolution via the given |tlsServers| on the
+ // network indicated by |mark|; writes the response into |ans|, and stores
+ // the count of bytes written in |resplen|. Returns a success or error code.
+ // The order in which servers from |tlsServers| are queried may not be the
+ // order passed in by the caller.
+ // Servers are tried one at a time: the first success or limit_error response
+ // is returned immediately; otherwise the code from the last server tried is
+ // returned, or internal_error if there was no server to try.
+ DnsTlsTransport::Response query(const std::list<DnsTlsServer> &tlsServers, unsigned mark,
+ const Slice query, const Slice ans, int * _Nonnull resplen);
+
// Given a |query|, sends it to the server on the network indicated by |mark|,
// and writes the response into |ans|, and indicates
// the number of bytes written in |resplen|. Returns a success or error code.
@@ -91,6 +100,10 @@
// This function performs a linear scan of mStore.
void cleanup(std::chrono::time_point<std::chrono::steady_clock> now) REQUIRES(sLock);
+ // Returns the servers in |tlsServers| in preference order: servers that
+ // already have an mStore entry for |mark| (i.e. possibly an existing
+ // connection) come first, and IPv6 precedes IPv4 within each group.
+ // Servers that are neither IPv4 nor IPv6 are omitted.
+ // Locks sLock internally, so callers must not already hold it.
+ std::list<DnsTlsServer> getOrderedServerList(
+ const std::list<DnsTlsServer> &tlsServers, unsigned mark) const;
+
// Trivial factory for DnsTlsSockets. Dependency injection is only used for testing.
std::unique_ptr<IDnsTlsSocketFactory> mFactory;
};
diff --git a/server/dns/DnsTlsServer.cpp b/server/dns/DnsTlsServer.cpp
index 9ac9893..dbb38df 100644
--- a/server/dns/DnsTlsServer.cpp
+++ b/server/dns/DnsTlsServer.cpp
@@ -125,5 +125,9 @@
return make_tie(*this) == make_tie(other);
}
+bool DnsTlsServer::wasExplicitlyConfigured() const {
+    // A hostname alone is enough; otherwise at least one pinned fingerprint
+    // is required.
+    if (!name.empty()) return true;
+    return !fingerprints.empty();
+}
+
} // namespace net
} // namespace android
diff --git a/server/dns/DnsTlsServer.h b/server/dns/DnsTlsServer.h
index c9cbd46..1c1cffe 100644
--- a/server/dns/DnsTlsServer.h
+++ b/server/dns/DnsTlsServer.h
@@ -61,6 +61,8 @@
// Exact comparison of DnsTlsServer objects
bool operator <(const DnsTlsServer& other) const;
bool operator ==(const DnsTlsServer& other) const;
+
+ // True if a hostname or at least one pinned fingerprint was set, i.e. the
+ // server is not identified solely by its socket address.
+ bool wasExplicitlyConfigured() const;
};
// This comparison only checks the IP address. It ignores ports, names, and fingerprints.
diff --git a/tests/dns_responder/dns_responder_client.cpp b/tests/dns_responder/dns_responder_client.cpp
index eedb849..90581fa 100644
--- a/tests/dns_responder/dns_responder_client.cpp
+++ b/tests/dns_responder/dns_responder_client.cpp
@@ -14,9 +14,11 @@
* limitations under the License.
*/
+#define LOG_TAG "dns_responder_client"
#include "dns_responder_client.h"
#include <android-base/stringprintf.h>
+#include <utils/Log.h>
// TODO: make this dynamic and stop depending on implementation details.
#define TEST_OEM_NETWORK "oem29"
@@ -52,12 +54,11 @@
bool DnsResponderClient::SetResolversWithTls(const std::vector<std::string>& servers,
const std::vector<std::string>& domains, const std::vector<int>& params,
- const std::string& name,
- const std::vector<std::string>& fingerprints) {
- // Pass servers as both network-assigned and TLS servers. Tests can
- // determine on which server and by which protocol queries arrived.
+ const std::vector<std::string>& tlsServers,
+ const std::string& name, const std::vector<std::string>& fingerprints) {
+ // |servers| and |tlsServers| are passed separately, so the TLS servers may
+ // differ from the network-assigned nameservers.
const auto rv = mNetdSrv->setResolverConfiguration(TEST_NETID, servers, domains, params,
- name, servers, fingerprints);
+ name, tlsServers, fingerprints);
+ // Log the failing binder status so a rejected configuration is diagnosable.
+ if (!rv.isOk()) ALOGI("SetResolversWithTls() -> %s", rv.toString8().c_str());
return rv.isOk();
}
diff --git a/tests/dns_responder/dns_responder_client.h b/tests/dns_responder/dns_responder_client.h
index 0eb07d7..1981569 100644
--- a/tests/dns_responder/dns_responder_client.h
+++ b/tests/dns_responder/dns_responder_client.h
@@ -23,7 +23,7 @@
virtual ~DnsResponderClient() = default;
- void SetupMappings(unsigned num_hosts, const std::vector<std::string>& domains,
+ // Static: callable without a DnsResponderClient instance.
+ static void SetupMappings(unsigned num_hosts, const std::vector<std::string>& domains,
std::vector<Mapping>* mappings);
bool SetResolversForNetwork(const std::vector<std::string>& servers,
@@ -37,6 +37,18 @@
const std::vector<std::string>& searchDomains,
const std::vector<int>& params,
const std::string& name,
+ const std::vector<std::string>& fingerprints) {
+ // Pass servers as both network-assigned and TLS servers. Tests can
+ // determine on which server and by which protocol queries arrived.
+ return SetResolversWithTls(servers, searchDomains, params,
+ servers, name, fingerprints);
+ }
+
+ // Configures |servers| as the network-assigned nameservers and |tlsServers|
+ // as the DNS-over-TLS servers; unlike the 5-argument overload, the two
+ // lists need not be the same.
+ bool SetResolversWithTls(const std::vector<std::string>& servers,
+ const std::vector<std::string>& searchDomains,
+ const std::vector<int>& params,
+ const std::vector<std::string>& tlsServers,
+ const std::string& name,
const std::vector<std::string>& fingerprints);
static void SetupDNSServers(unsigned num_servers, const std::vector<Mapping>& mappings,
diff --git a/tests/dns_responder/dns_tls_frontend.cpp b/tests/dns_responder/dns_tls_frontend.cpp
index fea04c5..f6321ec 100644
--- a/tests/dns_responder/dns_tls_frontend.cpp
+++ b/tests/dns_responder/dns_tls_frontend.cpp
@@ -36,8 +36,6 @@
namespace {
-const int SHA256_SIZE = 32;
-
// Copied from DnsTlsTransport.
bool getSPKIDigest(const X509* cert, std::vector<uint8_t>* out) {
int spki_len = i2d_X509_PUBKEY(X509_get_X509_PUBKEY(cert), NULL);
@@ -47,7 +45,7 @@
ALOGE("SPKI length mismatch");
return false;
}
- out->resize(SHA256_SIZE);
+ out->resize(test::SHA256_SIZE);
unsigned int digest_len = 0;
int ret = EVP_Digest(spki, spki_len, out->data(), &digest_len, EVP_sha256(), NULL);
if (ret != 1) {
diff --git a/tests/dns_responder/dns_tls_frontend.h b/tests/dns_responder/dns_tls_frontend.h
index b4630cf..bf6bb80 100644
--- a/tests/dns_responder/dns_tls_frontend.h
+++ b/tests/dns_responder/dns_tls_frontend.h
@@ -32,6 +32,8 @@
namespace test {
+// Length in bytes of a SHA-256 digest (e.g. an SPKI fingerprint).
+constexpr int SHA256_SIZE = 32;
+
/*
* Simple DNS over TLS reverse proxy that forwards to a UDP backend.
* Only handles a single request at a time.
diff --git a/tests/dns_tls_test.cpp b/tests/dns_tls_test.cpp
index fdd7902..7820338 100644
--- a/tests/dns_tls_test.cpp
+++ b/tests/dns_tls_test.cpp
@@ -700,6 +700,9 @@
addr2->sin6_scope_id = 2;
checkUnequal(s1, s2);
EXPECT_FALSE(isAddressEqual(s1, s2));
+
+ EXPECT_FALSE(s1.wasExplicitlyConfigured());
+ EXPECT_FALSE(s2.wasExplicitlyConfigured());
}
TEST_F(ServerTest, IPv6FlowInfo) {
@@ -711,6 +714,9 @@
// All comparisons ignore flowinfo.
EXPECT_EQ(s1, s2);
EXPECT_TRUE(isAddressEqual(s1, s2));
+
+ EXPECT_FALSE(s1.wasExplicitlyConfigured());
+ EXPECT_FALSE(s2.wasExplicitlyConfigured());
}
TEST_F(ServerTest, Port) {
@@ -725,6 +731,9 @@
parseServer("2001:db8::1", 852, &s4.ss);
checkUnequal(s3, s4);
EXPECT_TRUE(isAddressEqual(s3, s4));
+
+ EXPECT_FALSE(s1.wasExplicitlyConfigured());
+ EXPECT_FALSE(s2.wasExplicitlyConfigured());
}
TEST_F(ServerTest, Name) {
@@ -734,6 +743,9 @@
s2.name = SERVERNAME2;
checkUnequal(s1, s2);
EXPECT_TRUE(isAddressEqual(s1, s2));
+
+ EXPECT_TRUE(s1.wasExplicitlyConfigured());
+ EXPECT_TRUE(s2.wasExplicitlyConfigured());
}
TEST_F(ServerTest, Fingerprint) {
@@ -754,6 +766,9 @@
s1.fingerprints.insert(FINGERPRINT2);
EXPECT_EQ(s1, s2);
EXPECT_TRUE(isAddressEqual(s1, s2));
+
+ EXPECT_TRUE(s1.wasExplicitlyConfigured());
+ EXPECT_TRUE(s2.wasExplicitlyConfigured());
}
TEST(QueryMapTest, Basic) {
diff --git a/tests/netd_test.cpp b/tests/netd_test.cpp
index 3b77833..8d1f6a8 100644
--- a/tests/netd_test.cpp
+++ b/tests/netd_test.cpp
@@ -39,6 +39,7 @@
// TODO: make this dynamic and stop depending on implementation details.
#define TEST_NETID 30
+#include "resolv_netid.h"
#include "NetdClient.h"
#include <gtest/gtest.h>
@@ -858,16 +859,21 @@
// Wait for query to get counted.
EXPECT_TRUE(tls.waitForQueries(2, 5000));
- // Stop the TLS server. Since it's already been validated, queries will
- // continue to be routed to it.
+ // Stop the TLS server. Since we're in opportunistic mode, queries will
+ // fall back to the locally-assigned (clear text) nameservers.
tls.stopServer();
+ dns.clearQueries();
result = gethostbyname("tls2");
- EXPECT_TRUE(result == nullptr);
- EXPECT_EQ(HOST_NOT_FOUND, h_errno);
+ EXPECT_FALSE(result == nullptr);
+ EXPECT_EQ("1.2.3.2", ToString(result));
+ const auto queries = dns.queries();
+ EXPECT_EQ(1U, queries.size());
+ EXPECT_EQ("tls2.example.com.", queries[0].first);
+ EXPECT_EQ(ns_t_a, queries[0].second);
- // Reset the resolvers without enabling TLS. Queries should now be routed to the
- // UDP endpoint.
+ // Reset the resolvers without enabling TLS. Queries should still be routed
+ // to the UDP endpoint.
ASSERT_TRUE(SetResolversForNetwork(servers, mDefaultSearchDomains, mDefaultParams_Binder));
result = gethostbyname("tls3");
@@ -1154,3 +1160,167 @@
tls.stopServer();
dns.stopServer();
}
+
+TEST_F(ResolverTest, TlsBypass) {
+    const char OFF[] = "off";
+    const char OPPORTUNISTIC[] = "opportunistic";
+    const char STRICT[] = "strict";
+
+    const char GETHOSTBYNAME[] = "gethostbyname";
+    const char GETADDRINFO[] = "getaddrinfo";
+    const char GETADDRINFOFORNET[] = "getaddrinfofornet";
+
+    const unsigned BYPASS_NETID = NETID_USE_LOCAL_NAMESERVERS | TEST_NETID;
+
+    const std::vector<uint8_t> NOOP_FINGERPRINT(test::SHA256_SIZE, 0U);
+
+    const char ADDR4[] = "192.0.2.1";
+    const char ADDR6[] = "2001:db8::1";
+
+    const char cleartext_addr[] = "127.0.0.53";
+    const char cleartext_port[] = "53";
+    const char tls_port[] = "853";
+    const std::vector<std::string> servers = { cleartext_addr };
+
+    test::DNSResponder dns(cleartext_addr, cleartext_port, 250, ns_rcode::ns_r_servfail, 1.0);
+    ASSERT_TRUE(dns.startServer());
+
+    test::DnsTlsFrontend tls(cleartext_addr, tls_port, cleartext_addr, cleartext_port);
+
+    struct TestConfig {
+        const std::string mode;
+        const bool withWorkingTLS;
+        const std::string method;
+
+        std::string asHostName() const {
+            return StringPrintf("%s.%s.%s.",
+                                mode.c_str(),
+                                withWorkingTLS ? "tlsOn" : "tlsOff",
+                                method.c_str());
+        }
+    } testConfigs[]{
+        {OFF,           false, GETHOSTBYNAME},
+        {OPPORTUNISTIC, false, GETHOSTBYNAME},
+        {STRICT,        false, GETHOSTBYNAME},
+        {OFF,           true,  GETHOSTBYNAME},
+        {OPPORTUNISTIC, true,  GETHOSTBYNAME},
+        {STRICT,        true,  GETHOSTBYNAME},
+        {OFF,           false, GETADDRINFO},
+        {OPPORTUNISTIC, false, GETADDRINFO},
+        {STRICT,        false, GETADDRINFO},
+        {OFF,           true,  GETADDRINFO},
+        {OPPORTUNISTIC, true,  GETADDRINFO},
+        {STRICT,        true,  GETADDRINFO},
+        {OFF,           false, GETADDRINFOFORNET},
+        {OPPORTUNISTIC, false, GETADDRINFOFORNET},
+        {STRICT,        false, GETADDRINFOFORNET},
+        {OFF,           true,  GETADDRINFOFORNET},
+        {OPPORTUNISTIC, true,  GETADDRINFOFORNET},
+        {STRICT,        true,  GETADDRINFOFORNET},
+    };
+
+    // A failed ASSERT_* returns from this test immediately, skipping the
+    // explicit cleanup at the bottom of the loop. This guard repeats that
+    // cleanup whenever an iteration's scope is left, so a failure cannot leave
+    // BYPASS_NETID set as the per-process resolver netid (affecting later
+    // tests in this process) or leave the TLS frontend running. The loop
+    // already calls stopServer() on frontends that were never started, so the
+    // repeated calls on the normal path are harmless.
+    struct IterationCleanup {
+        test::DnsTlsFrontend& frontend;
+        ~IterationCleanup() {
+            setNetworkForResolv(NETID_UNSET);
+            frontend.stopServer();
+        }
+    };
+
+    for (const auto& config : testConfigs) {
+        const std::string testHostName = config.asHostName();
+        SCOPED_TRACE(testHostName);
+        IterationCleanup cleanup{tls};
+
+        // Don't tempt test bugs due to caching.
+        const char* host_name = testHostName.c_str();
+        dns.addMapping(host_name, ns_type::ns_t_a, ADDR4);
+        dns.addMapping(host_name, ns_type::ns_t_aaaa, ADDR6);
+
+        if (config.withWorkingTLS) ASSERT_TRUE(tls.startServer());
+
+        // tls.queries() may still include queries from earlier iterations, so
+        // wait for validation relative to the current count rather than for an
+        // absolute total of 1.
+        const int tlsQueriesBeforeValidation = tls.queries();
+
+        if (config.mode == OFF) {
+            ASSERT_TRUE(SetResolversForNetwork(
+                    servers, mDefaultSearchDomains, mDefaultParams_Binder));
+        } else if (config.mode == OPPORTUNISTIC) {
+            ASSERT_TRUE(SetResolversWithTls(
+                    servers, mDefaultSearchDomains, mDefaultParams_Binder, "", {}));
+            // Wait for validation to complete.
+            if (config.withWorkingTLS) {
+                EXPECT_TRUE(tls.waitForQueries(tlsQueriesBeforeValidation + 1, 5000));
+            }
+        } else if (config.mode == STRICT) {
+            // We use the existence of fingerprints to trigger strict mode,
+            // rather than hostname validation.
+            const auto& fingerprint =
+                    (config.withWorkingTLS) ? tls.fingerprint() : NOOP_FINGERPRINT;
+            ASSERT_TRUE(SetResolversWithTls(
+                    servers, mDefaultSearchDomains, mDefaultParams_Binder, "",
+                    { base64Encode(fingerprint) }));
+            // Wait for validation to complete.
+            if (config.withWorkingTLS) {
+                EXPECT_TRUE(tls.waitForQueries(tlsQueriesBeforeValidation + 1, 5000));
+            }
+        } else {
+            FAIL() << "Unsupported Private DNS mode: " << config.mode;
+        }
+
+        const int tlsQueriesBefore = tls.queries();
+
+        const hostent* h_result = nullptr;
+        addrinfo* ai_result = nullptr;
+
+        if (config.method == GETHOSTBYNAME) {
+            ASSERT_EQ(0, setNetworkForResolv(BYPASS_NETID));
+            h_result = gethostbyname(host_name);
+
+            EXPECT_EQ(1U, GetNumQueriesForType(dns, ns_type::ns_t_a, host_name));
+            ASSERT_FALSE(h_result == nullptr);
+            ASSERT_EQ(4, h_result->h_length);
+            ASSERT_FALSE(h_result->h_addr_list[0] == nullptr);
+            EXPECT_EQ(ADDR4, ToString(h_result));
+            EXPECT_TRUE(h_result->h_addr_list[1] == nullptr);
+        } else if (config.method == GETADDRINFO) {
+            ASSERT_EQ(0, setNetworkForResolv(BYPASS_NETID));
+            EXPECT_EQ(0, getaddrinfo(host_name, nullptr, nullptr, &ai_result));
+
+            EXPECT_LE(1U, GetNumQueries(dns, host_name));
+            // Could be A or AAAA
+            const std::string result_str = ToString(ai_result);
+            EXPECT_TRUE(result_str == ADDR4 || result_str == ADDR6)
+                    << ", result_str='" << result_str << "'";
+        } else if (config.method == GETADDRINFOFORNET) {
+            EXPECT_EQ(0, android_getaddrinfofornet(
+                    host_name, nullptr, nullptr, BYPASS_NETID, MARK_UNSET, &ai_result));
+
+            EXPECT_LE(1U, GetNumQueries(dns, host_name));
+            // Could be A or AAAA
+            const std::string result_str = ToString(ai_result);
+            EXPECT_TRUE(result_str == ADDR4 || result_str == ADDR6)
+                    << ", result_str='" << result_str << "'";
+        } else {
+            FAIL() << "Unsupported query method: " << config.method;
+        }
+
+        const int tlsQueriesAfter = tls.queries();
+        EXPECT_EQ(0, tlsQueriesAfter - tlsQueriesBefore);
+
+        // TODO: Use ScopedAddrinfo or similar once it is available in a common header file.
+        if (ai_result != nullptr) freeaddrinfo(ai_result);
+
+        // Clear per-process resolv netid.
+        ASSERT_EQ(0, setNetworkForResolv(NETID_UNSET));
+        tls.stopServer();
+        dns.clearQueries();
+    }
+
+    dns.stopServer();
+}
+
+TEST_F(ResolverTest, StrictMode_NoTlsServers) {
+    const std::vector<uint8_t> NOOP_FINGERPRINT(test::SHA256_SIZE, 0U);
+    const char cleartext_addr[] = "127.0.0.53";
+    const char cleartext_port[] = "53";
+    const std::vector<std::string> servers = { cleartext_addr };
+
+    test::DNSResponder dns(cleartext_addr, cleartext_port, 250, ns_rcode::ns_r_servfail, 1.0);
+    const char* host_name = "strictmode.notlsips.example.com.";
+    dns.addMapping(host_name, ns_type::ns_t_a, "1.2.3.4");
+    dns.addMapping(host_name, ns_type::ns_t_aaaa, "::1.2.3.4");
+    ASSERT_TRUE(dns.startServer());
+
+    // A fingerprint (which triggers strict mode) with an empty TLS server list
+    // leaves no usable server: the lookup must fail, and no query may fall
+    // back to the cleartext nameserver.
+    ASSERT_TRUE(SetResolversWithTls(
+            servers, mDefaultSearchDomains, mDefaultParams_Binder,
+            {}, "", { base64Encode(NOOP_FINGERPRINT) }));
+
+    addrinfo* ai_result = nullptr;
+    EXPECT_NE(0, getaddrinfo(host_name, nullptr, nullptr, &ai_result));
+    EXPECT_EQ(0U, GetNumQueries(dns, host_name));
+
+    // If the lookup unexpectedly succeeded, don't leak the result.
+    if (ai_result != nullptr) freeaddrinfo(ai_result);
+
+    dns.stopServer();
+}