Merge "Explicitly allocate ResState on the call stack"
diff --git a/DnsTlsServer.cpp b/DnsTlsServer.cpp
index 2b8e417..4e94488 100644
--- a/DnsTlsServer.cpp
+++ b/DnsTlsServer.cpp
@@ -109,11 +109,9 @@
// Returns a tuple of references to the elements of s.
auto make_tie(const DnsTlsServer& s) {
- return std::tie(
- s.ss,
- s.name,
- s.protocol
- );
+ // connectTimeout is deliberately part of the tuple, so servers that differ only in their
+ // connect timeout compare unequal (see ServerTest.Timeout).
+ return std::tie(s.ss, s.name, s.protocol, s.connectTimeout);
}
bool DnsTlsServer::operator <(const DnsTlsServer& other) const {
diff --git a/DnsTlsServer.h b/DnsTlsServer.h
index c3454cb..f036b68 100644
--- a/DnsTlsServer.h
+++ b/DnsTlsServer.h
@@ -17,6 +17,7 @@
#ifndef _DNS_DNSTLSSERVER_H
#define _DNS_DNSTLSSERVER_H
+#include <chrono>
#include <set>
#include <string>
#include <vector>
@@ -58,6 +59,11 @@
// Placeholder. More protocols might be defined in the future.
int protocol = IPPROTO_TCP;
+ // The time to wait for the attempt on connecting to the server.
+ // Set the default value 127 seconds to be consistent with TCP connect timeout.
+ // (presume net.ipv4.tcp_syn_retries = 6)
+ std::chrono::milliseconds connectTimeout = std::chrono::milliseconds(127 * 1000);
+
// Exact comparison of DnsTlsServer objects
bool operator<(const DnsTlsServer& other) const;
bool operator==(const DnsTlsServer& other) const;
diff --git a/DnsTlsSocket.cpp b/DnsTlsSocket.cpp
index 14aadd4..bb43ed9 100644
--- a/DnsTlsSocket.cpp
+++ b/DnsTlsSocket.cpp
@@ -59,16 +59,24 @@
constexpr const char kCaCertDir[] = "/system/etc/security/cacerts";
-int waitForReading(int fd) {
- struct pollfd fds = { .fd = fd, .events = POLLIN };
- const int ret = TEMP_FAILURE_RETRY(poll(&fds, 1, -1));
- return ret;
+// Waits for |fd| to become readable for at most |timeoutMs| milliseconds (indefinitely
+// if negative). Returns poll()'s result: 1 if an event was reported on |fd|, 0 on
+// timeout, -1 on error. poll() leaves errno untouched on timeout, so it is set to
+// ETIMEDOUT here to keep callers' PLOG() output meaningful. Note that an EINTR retry
+// restarts the full timeout.
+int waitForReading(int fd, int timeoutMs = -1) {
+ pollfd fds = {.fd = fd, .events = POLLIN};
+ const int ret = TEMP_FAILURE_RETRY(poll(&fds, 1, timeoutMs));
+ if (ret == 0) errno = ETIMEDOUT;
+ return ret;
}
-int waitForWriting(int fd) {
- struct pollfd fds = { .fd = fd, .events = POLLOUT };
- const int ret = TEMP_FAILURE_RETRY(poll(&fds, 1, -1));
- return ret;
+// Same contract as waitForReading(), but waits for |fd| to become writable.
+int waitForWriting(int fd, int timeoutMs = -1) {
+ pollfd fds = {.fd = fd, .events = POLLOUT};
+ const int ret = TEMP_FAILURE_RETRY(poll(&fds, 1, timeoutMs));
+ if (ret == 0) errno = ETIMEDOUT;
+ return ret;
}
std::string markToFwmarkString(unsigned mMark) {
@@ -250,14 +248,21 @@
const int ssl_err = SSL_get_error(ssl.get(), ret);
switch (ssl_err) {
case SSL_ERROR_WANT_READ:
- if (waitForReading(fd) != 1) {
- PLOG(WARNING) << "SSL_connect read error, " << markToFwmarkString(mMark);
+ // SSL_ERROR_WANT_READ is returned because the application data has been sent during
+ // the TCP connection handshake, the device is waiting for the SSL handshake reply
+ // from the server.
+ if (int err = waitForReading(fd, mServer.connectTimeout.count()); err <= 0) {
+ PLOG(WARNING) << "SSL_connect read error " << err << ", "
+ << markToFwmarkString(mMark);
return nullptr;
}
break;
case SSL_ERROR_WANT_WRITE:
- if (waitForWriting(fd) != 1) {
- PLOG(WARNING) << "SSL_connect write error, " << markToFwmarkString(mMark);
+ // If no application data is sent during the TCP connection handshake, the
+ // device is waiting for the connection established to perform SSL handshake.
+ if (int err = waitForWriting(fd, mServer.connectTimeout.count()); err <= 0) {
+ PLOG(WARNING) << "SSL_connect write error " << err << ", "
+ << markToFwmarkString(mMark);
return nullptr;
}
break;
@@ -291,8 +296,8 @@
const int ssl_err = SSL_get_error(mSsl.get(), ret);
switch (ssl_err) {
case SSL_ERROR_WANT_WRITE:
- if (waitForWriting(mSslFd.get()) != 1) {
- LOG(DEBUG) << "SSL_write error";
+ if (int err = waitForWriting(mSslFd.get()); err <= 0) {
+ PLOG(WARNING) << "Poll failed in sslWrite, error " << err;
return false;
}
continue;
@@ -462,8 +467,8 @@
if (ret < 0) {
const int ssl_err = SSL_get_error(mSsl.get(), ret);
if (wait && ssl_err == SSL_ERROR_WANT_READ) {
- if (waitForReading(mSslFd.get()) != 1) {
- LOG(DEBUG) << "Poll failed in sslRead: " << errno;
+ if (int err = waitForReading(mSslFd.get()); err <= 0) {
+ PLOG(WARNING) << "Poll failed in sslRead, error " << err;
return SSL_ERROR_SYSCALL;
}
continue;
diff --git a/PrivateDnsConfiguration.cpp b/PrivateDnsConfiguration.cpp
index 14bac16..9242222 100644
--- a/PrivateDnsConfiguration.cpp
+++ b/PrivateDnsConfiguration.cpp
@@ -29,6 +29,8 @@
#include "netd_resolv/resolv.h"
#include "netdutils/BackoffSequence.h"
+using std::chrono::milliseconds;
+
namespace android {
namespace net {
@@ -40,7 +42,10 @@
}
bool parseServer(const char* server, sockaddr_storage* parsed) {
- addrinfo hints = {.ai_family = AF_UNSPEC, .ai_flags = AI_NUMERICHOST | AI_NUMERICSERV};
+ addrinfo hints = {
+ .ai_flags = AI_NUMERICHOST | AI_NUMERICSERV,
+ .ai_family = AF_UNSPEC,
+ };
addrinfo* res;
int err = getaddrinfo(server, "853", &hints, &res);
@@ -56,9 +61,9 @@
int PrivateDnsConfiguration::set(int32_t netId, uint32_t mark,
const std::vector<std::string>& servers, const std::string& name,
- const std::string& caCert) {
+ const std::string& caCert, int32_t connectTimeoutMs) {
LOG(DEBUG) << "PrivateDnsConfiguration::set(" << netId << ", 0x" << std::hex << mark << std::dec
- << ", " << servers.size() << ", " << name << ")";
+ << ", " << servers.size() << ", " << name << ", " << connectTimeoutMs << "ms)";
// Parse the list of servers that has been passed in
std::set<DnsTlsServer> tlsServers;
@@ -70,6 +75,15 @@
DnsTlsServer server(parsed);
server.name = name;
server.certificate = caCert;
+
+ // connectTimeoutMs = 0: use the default timeout value.
+ // connectTimeoutMs < 0: invalid timeout value.
+ if (connectTimeoutMs > 0) {
+ // Set a specific timeout value but limit it to be at least 1 second.
+ server.connectTimeout =
+ (connectTimeoutMs < 1000) ? milliseconds(1000) : milliseconds(connectTimeoutMs);
+ }
+
tlsServers.insert(server);
}
diff --git a/PrivateDnsConfiguration.h b/PrivateDnsConfiguration.h
index 5800831..6c50604 100644
--- a/PrivateDnsConfiguration.h
+++ b/PrivateDnsConfiguration.h
@@ -53,7 +53,8 @@
class PrivateDnsConfiguration {
public:
int set(int32_t netId, uint32_t mark, const std::vector<std::string>& servers,
- const std::string& name, const std::string& caCert) EXCLUDES(mPrivateDnsLock);
+ const std::string& name, const std::string& caCert, int32_t connectTimeoutMs)
+ EXCLUDES(mPrivateDnsLock);
PrivateDnsStatus getStatus(unsigned netId) EXCLUDES(mPrivateDnsLock);
diff --git a/ResolverController.cpp b/ResolverController.cpp
index bdc6558..6927e5f 100644
--- a/ResolverController.cpp
+++ b/ResolverController.cpp
@@ -210,9 +210,9 @@
if (tlsServers.size() > MAXNS) {
tlsServers.resize(MAXNS);
}
- const int err =
- gPrivateDnsConfiguration.set(resolverParams.netId, fwmark.intValue, tlsServers,
- resolverParams.tlsName, resolverParams.caCertificate);
+ const int err = gPrivateDnsConfiguration.set(
+ resolverParams.netId, fwmark.intValue, tlsServers, resolverParams.tlsName,
+ resolverParams.caCertificate, resolverParams.tlsConnectTimeoutMs);
if (err != 0) {
return err;
diff --git a/binder/android/net/ResolverParamsParcel.aidl b/binder/android/net/ResolverParamsParcel.aidl
index 4f266d3..94576b2 100644
--- a/binder/android/net/ResolverParamsParcel.aidl
+++ b/binder/android/net/ResolverParamsParcel.aidl
@@ -89,4 +89,10 @@
*
*/
@utf8InCpp String caCertificate;
+
+ /**
+ * The timeout for the connection attempt to a Private DNS server.
+ * It's initialized to 0 to use the predefined default value.
+ */
+ int tlsConnectTimeoutMs = 0;
}
diff --git a/res_cache.cpp b/res_cache.cpp
index bc36bef..d26ad76 100644
--- a/res_cache.cpp
+++ b/res_cache.cpp
@@ -1488,7 +1488,10 @@
for (int i = 0; i < numservers; i++) {
// The addrinfo structures allocated here are freed in free_nameservers_locked().
const addrinfo hints = {
- .ai_family = AF_UNSPEC, .ai_socktype = SOCK_DGRAM, .ai_flags = AI_NUMERICHOST};
+ .ai_flags = AI_NUMERICHOST,
+ .ai_family = AF_UNSPEC,
+ .ai_socktype = SOCK_DGRAM,
+ };
const int rt = getaddrinfo_numeric(nameservers[i].c_str(), "53", hints, &nsaddrinfo[i]);
if (rt != 0) {
for (int j = 0; j < i; j++) {
diff --git a/resolv_integration_test.cpp b/resolv_integration_test.cpp
index 5014983..27e21d2 100644
--- a/resolv_integration_test.cpp
+++ b/resolv_integration_test.cpp
@@ -558,8 +558,8 @@
// TODO: Test other invalid socket types.
const addrinfo hints = {
.ai_family = AF_UNSPEC,
- .ai_protocol = ANY,
.ai_socktype = SOCK_PACKET,
+ .ai_protocol = ANY,
};
addrinfo* result = nullptr;
// This is a valid hint, but the query won't be sent because the socket type is
@@ -2640,10 +2640,10 @@
SCOPED_TRACE(config.asParameters());
addrinfo hints = {
+ .ai_flags = config.flag,
.ai_family = AF_UNSPEC, // any address family
.ai_socktype = 0, // any type
.ai_protocol = 0, // any protocol
- .ai_flags = config.flag,
};
// Assign hostname as null and service as port name.
@@ -3439,3 +3439,58 @@
.isOk());
EXPECT_TRUE(netdService->firewallEnableChildChain(INetd::FIREWALL_CHAIN_STANDBY, false).isOk());
}
+
+// A private DNS server that accepts the TCP connection but never answers the TLS
+// handshake must be given up on once the connect timeout expires, after which the query
+// falls back to cleartext UDP. NOTE(review): assumes the resolver params applied by
+// SetResolversWithTls() carry tlsConnectTimeoutMs = 1000 - confirm if they change.
+TEST_F(ResolverTest, ConnectTlsServerTimeout) {
+ constexpr char listen_addr[] = "127.0.0.3";
+ constexpr char listen_udp[] = "53";
+ constexpr char listen_tls[] = "853";
+ constexpr char host_name[] = "tls.example.com.";
+ const std::vector<std::string> servers = {listen_addr};
+ const std::vector<DnsRecord> records = {
+ {host_name, ns_type::ns_t_a, "1.2.3.4"},
+ };
+
+ test::DNSResponder dns;
+ StartDns(dns, records);
+
+ test::DnsTlsFrontend tls(listen_addr, listen_tls, listen_addr, listen_udp);
+ ASSERT_TRUE(tls.startServer());
+
+ // Opportunistic mode.
+ ASSERT_TRUE(mDnsClient.SetResolversWithTls(servers, kDefaultSearchDomains, kDefaultParams, {}));
+
+ // The server must be validated first so that PrivateDnsStatus::validatedServers()
+ // won't return an empty list; everything below depends on that.
+ ASSERT_TRUE(WaitForPrivateDnsValidation(tls.listen_address(), true));
+ dns.clearQueries();
+ tls.clearQueries();
+
+ // The server becomes unresponsive to the handshake request.
+ tls.setHangOnHandshakeForTesting(true);
+
+ // Expect the things happening in getaddrinfo():
+ // 1. Connect to the private DNS server.
+ // 2. SSL handshake times out.
+ // 3. Fallback to UDP transport, and then get the answer.
+ const auto start = std::chrono::steady_clock::now();
+ addrinfo hints = {.ai_family = AF_INET, .ai_socktype = SOCK_DGRAM};
+ ScopedAddrinfo result = safe_getaddrinfo("tls", nullptr, &hints);
+ const auto end = std::chrono::steady_clock::now();
+ const auto elapsedMs =
+ std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count();
+
+ // Nothing below is meaningful if the lookup itself failed.
+ ASSERT_TRUE(result != nullptr);
+ EXPECT_EQ(0, tls.queries());
+ EXPECT_EQ(1U, GetNumQueries(dns, host_name));
+ EXPECT_EQ("1.2.3.4", ToString(result));
+
+ // The lookup has to sit out the handshake timeout (theoretically a bit more than
+ // 1000ms); 3000ms is a loose upper bound.
+ EXPECT_GE(elapsedMs, 1000);
+ EXPECT_LE(elapsedMs, 3000);
+}
diff --git a/resolv_tls_unit_test.cpp b/resolv_tls_unit_test.cpp
index 72e6897..5d3ca26 100644
--- a/resolv_tls_unit_test.cpp
+++ b/resolv_tls_unit_test.cpp
@@ -791,6 +791,25 @@
EXPECT_TRUE(s2.wasExplicitlyConfigured());
}
+// connectTimeout is part of DnsTlsServer's comparison: two servers that differ only in
+// their connect timeout are unequal, and become equal once the timeouts match.
+TEST_F(ServerTest, Timeout) {
+ constexpr std::chrono::milliseconds kTimeout(4000);
+ DnsTlsServer s1(V4ADDR1);
+ DnsTlsServer s2(V4ADDR1);
+
+ s1.connectTimeout = kTimeout;
+ checkUnequal(s1, s2);
+
+ s2.connectTimeout = kTimeout;
+ EXPECT_EQ(s1, s2);
+ EXPECT_TRUE(isAddressEqual(s1, s2));
+
+ // Setting a timeout does not by itself mark a server as explicitly configured.
+ EXPECT_FALSE(s1.wasExplicitlyConfigured());
+ EXPECT_FALSE(s2.wasExplicitlyConfigured());
+}
+
TEST(QueryMapTest, Basic) {
DnsTlsQueryMap map;
diff --git a/resolv_unit_test.cpp b/resolv_unit_test.cpp
index 0c1de0b..3295cac 100644
--- a/resolv_unit_test.cpp
+++ b/resolv_unit_test.cpp
@@ -239,8 +239,8 @@
for (const auto& socktype : {SOCK_RDM, SOCK_SEQPACKET, SOCK_DCCP, SOCK_PACKET}) {
const addrinfo hints = {
.ai_family = family,
- .ai_protocol = protocol,
.ai_socktype = socktype,
+ .ai_protocol = protocol,
};
for (const char* service : {static_cast<const char*>(nullptr), // service is null
"80",
@@ -298,8 +298,8 @@
addrinfo* result = nullptr;
const addrinfo hints = {
.ai_family = family,
- .ai_protocol = protocol,
.ai_socktype = socktype,
+ .ai_protocol = protocol,
};
NetworkDnsEventReported event;
int rv = resolv_getaddrinfo("localhost", nullptr /*servname*/, &hints, &mNetcontext,
diff --git a/tests/dns_responder/dns_responder.cpp b/tests/dns_responder/dns_responder.cpp
index 4bac4b0..072c2dd 100644
--- a/tests/dns_responder/dns_responder.cpp
+++ b/tests/dns_responder/dns_responder.cpp
@@ -515,7 +515,11 @@
}
// Set up UDP socket.
- addrinfo ai_hints{.ai_family = AF_UNSPEC, .ai_socktype = SOCK_DGRAM, .ai_flags = AI_PASSIVE};
+ addrinfo ai_hints{
+ .ai_flags = AI_PASSIVE,
+ .ai_family = AF_UNSPEC,
+ .ai_socktype = SOCK_DGRAM,
+ };
addrinfo* ai_res = nullptr;
int rv = getaddrinfo(listen_address_.c_str(), listen_service_.c_str(), &ai_hints, &ai_res);
ScopedAddrinfo ai_res_cleanup(ai_res);
diff --git a/tests/dns_responder/dns_responder_client.cpp b/tests/dns_responder/dns_responder_client.cpp
index 75d0e1d..473bfb5 100644
--- a/tests/dns_responder/dns_responder_client.cpp
+++ b/tests/dns_responder/dns_responder_client.cpp
@@ -103,6 +103,7 @@
paramsParcel.tlsServers = tlsServers;
paramsParcel.tlsFingerprints = {};
paramsParcel.caCertificate = caCert;
+ paramsParcel.tlsConnectTimeoutMs = 1000;
return paramsParcel;
}
diff --git a/tests/dns_responder/dns_tls_frontend.cpp b/tests/dns_responder/dns_tls_frontend.cpp
index 3b25324..6399842 100644
--- a/tests/dns_responder/dns_tls_frontend.cpp
+++ b/tests/dns_responder/dns_tls_frontend.cpp
@@ -164,7 +164,10 @@
// Set up TCP server socket for clients.
addrinfo frontend_ai_hints{
- .ai_family = AF_UNSPEC, .ai_socktype = SOCK_STREAM, .ai_flags = AI_PASSIVE};
+ .ai_flags = AI_PASSIVE,
+ .ai_family = AF_UNSPEC,
+ .ai_socktype = SOCK_STREAM,
+ };
addrinfo* frontend_ai_res = nullptr;
int rv = getaddrinfo(listen_address_.c_str(), listen_service_.c_str(), &frontend_ai_hints,
&frontend_ai_res);
@@ -241,6 +244,7 @@
enum { EVENT_FD = 0, LISTEN_FD = 1 };
pollfd fds[2] = {{.fd = event_fd_.get(), .events = POLLIN},
{.fd = socket_.get(), .events = POLLIN}};
+ android::base::unique_fd clientFd;
while (true) {
int poll_code = poll(fds, std::size(fds), -1);
@@ -266,6 +270,14 @@
break;
}
+ if (hangOnHandshake_) {
+ LOG(DEBUG) << "TEST ONLY: unresponsive to SSL handshake";
+
+ // The previous fd already stored in clientFd will be closed automatically.
+ clientFd = std::move(client);
+ continue;
+ }
+
bssl::UniquePtr<SSL> ssl(SSL_new(ctx_.get()));
SSL_set_fd(ssl.get(), client.get());
diff --git a/tests/dns_responder/dns_tls_frontend.h b/tests/dns_responder/dns_tls_frontend.h
index 5af34b4..b4455de 100644
--- a/tests/dns_responder/dns_tls_frontend.h
+++ b/tests/dns_responder/dns_tls_frontend.h
@@ -55,6 +55,11 @@
void clearQueries() { queries_ = 0; }
bool waitForQueries(int number, int timeoutMs) const;
void set_chain_length(int length) { chain_length_ = length; }
+ // TEST ONLY: while enabled, newly accepted connections are never taken through the TLS
+ // handshake. The flag is a std::atomic, so it may be toggled while the server is running.
+ void setHangOnHandshakeForTesting(bool hangOnHandshake) {
+ hangOnHandshake_.store(hangOnHandshake);
+ }
private:
void requestHandler();
@@ -81,6 +82,7 @@
std::thread handler_thread_ GUARDED_BY(update_mutex_);
std::mutex update_mutex_;
int chain_length_ = 1;
+ std::atomic<bool> hangOnHandshake_ = false;
};
} // namespace test