Harden TLS fingerprint test against flakiness

This change might reduce the flakiness of this test, and
also might help debug any future issues with it.

Bug: 64779303
Test: Tests continue to pass for me
Change-Id: I3a71b12f4ee3749c70210f1c19ca924473756fd5
diff --git a/server/ResolverController.cpp b/server/ResolverController.cpp
index 238eef7..823afbf 100644
--- a/server/ResolverController.cpp
+++ b/server/ResolverController.cpp
@@ -40,6 +40,7 @@
 #include <resolv_stats.h>
 
 #include <android-base/strings.h>
+#include <android-base/thread_annotations.h>
 #include <android/net/INetd.h>
 
 #include "DumpWriter.h"
@@ -97,13 +98,13 @@
 // Structure for tracking the entire set of known Private DNS servers.
 std::mutex privateDnsLock;
 typedef std::set<DnsTlsTransport::Server, AddressComparator> PrivateDnsSet;
-PrivateDnsSet privateDnsServers;
+PrivateDnsSet privateDnsServers GUARDED_BY(privateDnsLock);
 
 // Structure for tracking the validation status of servers on a specific netid.
 // Servers that fail validation are removed from the tracker, and can be retried.
 enum class Validation : bool { in_process, success };
 typedef std::map<DnsTlsTransport::Server, Validation, AddressComparator> PrivateDnsTracker;
-std::map<unsigned, PrivateDnsTracker> privateDnsTransports;
+std::map<unsigned, PrivateDnsTracker> privateDnsTransports GUARDED_BY(privateDnsLock);
 
 PrivateDnsSet parseServers(const char** servers, int numservers, in_port_t port) {
     PrivateDnsSet set;
diff --git a/tests/dns_responder/dns_responder.cpp b/tests/dns_responder/dns_responder.cpp
index fab2924..8704bdb 100644
--- a/tests/dns_responder/dns_responder.cpp
+++ b/tests/dns_responder/dns_responder.cpp
@@ -588,6 +588,7 @@
         if (s < 0) continue;
         const int one = 1;
         setsockopt(s, SOL_SOCKET, SO_REUSEPORT, &one, sizeof(one));
+        setsockopt(s, SOL_SOCKET, SO_REUSEADDR, &one, sizeof(one));
         if (bind(s, ai->ai_addr, ai->ai_addrlen)) {
             APLOGI("bind failed for socket %d", s);
             close(s);
diff --git a/tests/dns_responder/dns_tls_frontend.cpp b/tests/dns_responder/dns_tls_frontend.cpp
index 02b01f2..b360060 100644
--- a/tests/dns_responder/dns_tls_frontend.cpp
+++ b/tests/dns_responder/dns_tls_frontend.cpp
@@ -206,6 +206,7 @@
         if (s < 0) continue;
         const int one = 1;
         setsockopt(s, SOL_SOCKET, SO_REUSEPORT, &one, sizeof(one));
+        setsockopt(s, SOL_SOCKET, SO_REUSEADDR, &one, sizeof(one));
         if (bind(s, ai->ai_addr, ai->ai_addrlen)) {
             APLOGI("bind failed for socket %d", s);
             close(s);
diff --git a/tests/netd_test.cpp b/tests/netd_test.cpp
index 0217f5b..6759f3f 100644
--- a/tests/netd_test.cpp
+++ b/tests/netd_test.cpp
@@ -743,6 +743,9 @@
         .sin_port = htons(853),
     };
     ASSERT_TRUE(inet_pton(AF_INET, listen_addr, &tlsServer.sin_addr));
+    const int one = 1;
+    setsockopt(s, SOL_SOCKET, SO_REUSEPORT, &one, sizeof(one));
+    setsockopt(s, SOL_SOCKET, SO_REUSEADDR, &one, sizeof(one));
     ASSERT_FALSE(bind(s, reinterpret_cast<struct sockaddr*>(&tlsServer), sizeof(tlsServer)));
     ASSERT_FALSE(listen(s, 1));
 
@@ -832,19 +835,23 @@
 TEST_F(ResolverTest, GetHostByName_TlsFingerprint) {
     const char* listen_addr = "127.0.0.3";
     const char* listen_udp = "53";
-    const char* listen_tls = "853";
+    test::DNSResponder dns(listen_addr, listen_udp, 250, ns_rcode::ns_r_servfail, 1.0);
+    ASSERT_TRUE(dns.startServer());
     for (int chain_length = 1; chain_length <= 3; ++chain_length) {
         const char* host_name = StringPrintf("tlsfingerprint%d.example.com.", chain_length).c_str();
-        test::DNSResponder dns(listen_addr, listen_udp, 250, ns_rcode::ns_r_servfail, 1.0);
         dns.addMapping(host_name, ns_type::ns_t_a, "1.2.3.1");
-        ASSERT_TRUE(dns.startServer());
         std::vector<std::string> servers = { listen_addr };
 
+        // Run each TLS server on a new port to avoid any possible races related to reopening
+        // sockets that were just closed.
+        int tls_port = 853 + chain_length;
+        const char* listen_tls = std::to_string(tls_port).c_str();
         test::DnsTlsFrontend tls(listen_addr, listen_tls, listen_addr, listen_udp);
         tls.set_chain_length(chain_length);
         ASSERT_TRUE(tls.startServer());
-        auto rv = mNetdSrv->addPrivateDnsServer(listen_addr, 853, "SHA-256",
+        auto rv = mNetdSrv->addPrivateDnsServer(listen_addr, tls_port, "SHA-256",
                 { base64Encode(tls.fingerprint()) });
+        EXPECT_EQ(0, rv.exceptionCode());
         ASSERT_TRUE(SetResolversForNetwork(mDefaultSearchDomains, servers, mDefaultParams));
 
         const hostent* result;
@@ -862,9 +869,10 @@
         }
 
         rv = mNetdSrv->removePrivateDnsServer(listen_addr);
+        EXPECT_EQ(0, rv.exceptionCode());
         tls.stopServer();
-        dns.stopServer();
     }
+    dns.stopServer();
 }
 
 TEST_F(ResolverTest, GetHostByName_BadTlsFingerprint) {