Add tests for strict mode private DNS validation.

Test successful and failed validation, and updating the config.
In order to do this, add a FakeDns class so we can change DNS
responses dynamically while the test is running.

Also a couple of minor fixes:
1. Make sure the DNS timeout is set. Before this CL, it was
   always 0. Not sure why. It does seem to be set to the default
   value (12500) when actually running on device. We didn't
   catch this because the only tests that use the timeout set it
   explicitly.
2. Make runNetworkTest a bit more realistic: always send
   NetworkCapabilities *before* calling notifyNetworkConnected.
   This is what ConnectivityService does.

Bug: 122652057
Test: atest FrameworksNetTests NetworkStackTests
Test: atest --generate-new-metrics 50 NetworkStackTests:com.android.server.connectivity.NetworkMonitorTest
Change-Id: Ifd6694262501874f3261c864a049cb35c6afb9c8
Merged-In: Ifd6694262501874f3261c864a049cb35c6afb9c8
(cherry picked from commit 89909befd2f4e990d21e27c99857c635b3c6709e)
diff --git a/packages/NetworkStack/tests/src/com/android/server/connectivity/NetworkMonitorTest.java b/packages/NetworkStack/tests/src/com/android/server/connectivity/NetworkMonitorTest.java
index 0dc1cbf..a24ed5a 100644
--- a/packages/NetworkStack/tests/src/com/android/server/connectivity/NetworkMonitorTest.java
+++ b/packages/NetworkStack/tests/src/com/android/server/connectivity/NetworkMonitorTest.java
@@ -42,10 +42,10 @@
 import static org.mockito.Mockito.any;
 import static org.mockito.Mockito.anyInt;
 import static org.mockito.Mockito.doAnswer;
-import static org.mockito.Mockito.doNothing;
 import static org.mockito.Mockito.doReturn;
 import static org.mockito.Mockito.doThrow;
 import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.reset;
 import static org.mockito.Mockito.timeout;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
@@ -66,6 +66,7 @@
 import android.net.NetworkInfo;
 import android.net.captiveportal.CaptivePortalProbeResult;
 import android.net.metrics.IpConnectivityLog;
+import android.net.shared.PrivateDnsConfig;
 import android.net.util.SharedLog;
 import android.net.wifi.WifiInfo;
 import android.net.wifi.WifiManager;
@@ -73,6 +74,7 @@
 import android.os.ConditionVariable;
 import android.os.Handler;
 import android.os.Looper;
+import android.os.Process;
 import android.os.RemoteException;
 import android.os.SystemClock;
 import android.provider.Settings;
@@ -132,6 +134,7 @@
     private @Mock NetworkMonitor.Dependencies mDependencies;
     private @Mock INetworkMonitorCallbacks mCallbacks;
     private @Spy Network mNetwork = new Network(TEST_NETID);
+    private @Mock Network mNonPrivateDnsBypassNetwork;
     private @Mock DataStallStatsUtils mDataStallStatsUtils;
     private @Mock WifiInfo mWifiInfo;
     private @Captor ArgumentCaptor<String> mNetworkTestedRedirectUrlCaptor;
@@ -166,31 +169,93 @@
     private static final NetworkCapabilities NO_INTERNET_CAPABILITIES = new NetworkCapabilities()
             .addTransportType(NetworkCapabilities.TRANSPORT_CELLULAR);
 
-    private void setDnsAnswers(String[] answers) throws UnknownHostException {
-        if (answers == null) {
-            doThrow(new UnknownHostException()).when(mNetwork).getAllByName(any());
-            doNothing().when(mDnsResolver).query(any(), any(), anyInt(), any(), any(), any());
-            return;
+    /**
+     * Fakes DNS responses.
+     *
+     * Allows test methods to configure the IP addresses that will be resolved by
+     * Network#getAllByName and by DnsResolver#query.
+     */
+    class FakeDns {
+        private final ArrayMap<String, List<InetAddress>> mAnswers = new ArrayMap<>();
+        private boolean mNonBypassPrivateDnsWorking = true;
+
+        /** Whether DNS queries on mNonBypassPrivateDnsWorking should succeed. */
+        private void setNonBypassPrivateDnsWorking(boolean working) {
+            mNonBypassPrivateDnsWorking = working;
         }
 
-        List<InetAddress> answerList = new ArrayList<>();
-        for (String answer : answers) {
-            answerList.add(InetAddresses.parseNumericAddress(answer));
+        /** Clears all DNS entries. */
+        private synchronized void clearAll() {
+            mAnswers.clear();
         }
-        InetAddress[] answerArray = answerList.toArray(new InetAddress[0]);
 
-        doReturn(answerArray).when(mNetwork).getAllByName(any());
+        /** Returns the answer for a given name on the given mock network. */
+        private synchronized List<InetAddress> getAnswer(Object mock, String hostname) {
+            if (mock == mNonPrivateDnsBypassNetwork && !mNonBypassPrivateDnsWorking) {
+                return null;
+            }
+            if (mAnswers.containsKey(hostname)) {
+                return mAnswers.get(hostname);
+            }
+            return mAnswers.get("*");
+        }
 
-        doAnswer((invocation) -> {
-            Executor executor = (Executor) invocation.getArgument(3);
-            DnsResolver.Callback<List<InetAddress>> callback = invocation.getArgument(5);
-            new Handler(Looper.getMainLooper()).post(() -> {
-                executor.execute(() -> callback.onAnswer(answerList, 0));
-            });
-            return null;
-        }).when(mDnsResolver).query(eq(mNetwork), any(), anyInt(), any(), any(), any());
+        /** Sets the answer for a given name. */
+        private synchronized void setAnswer(String hostname, String[] answer)
+                throws UnknownHostException {
+            if (answer == null) {
+                mAnswers.remove(hostname);
+            } else {
+                List<InetAddress> answerList = new ArrayList<>();
+                for (String addr : answer) {
+                    answerList.add(InetAddresses.parseNumericAddress(addr));
+                }
+                mAnswers.put(hostname, answerList);
+            }
+        }
+
+        /** Simulates a getAllByName call for the specified name on the specified mock network. */
+        private InetAddress[] getAllByName(Object mock, String hostname)
+                throws UnknownHostException {
+            List<InetAddress> answer = getAnswer(mock, hostname);
+            if (answer == null || answer.size() == 0) {
+                throw new UnknownHostException(hostname);
+            }
+            return answer.toArray(new InetAddress[0]);
+        }
+
+        /** Starts mocking DNS queries. */
+        private void startMocking() throws UnknownHostException {
+            // Queries on mNetwork (i.e., bypassing private DNS) using getAllByName.
+            doAnswer(invocation -> {
+                return getAllByName(invocation.getMock(), invocation.getArgument(0));
+            }).when(mNetwork).getAllByName(any());
+
+            // Queries on mNonBypassPrivateDnsNetwork using getAllByName.
+            doAnswer(invocation -> {
+                return getAllByName(invocation.getMock(), invocation.getArgument(0));
+            }).when(mNonPrivateDnsBypassNetwork).getAllByName(any());
+
+            // Queries on mNetwork (i.e., bypassing private DNS) using DnsResolver#query.
+            doAnswer(invocation -> {
+                String hostname = (String) invocation.getArgument(1);
+                Executor executor = (Executor) invocation.getArgument(3);
+                DnsResolver.Callback<List<InetAddress>> callback = invocation.getArgument(5);
+
+                List<InetAddress> answer = getAnswer(invocation.getMock(), hostname);
+                if (answer != null && answer.size() > 0) {
+                    new Handler(Looper.getMainLooper()).post(() -> {
+                        executor.execute(() -> callback.onAnswer(answer, 0));
+                    });
+                }
+                // If no answers, do nothing. sendDnsProbeWithTimeout will time out and throw UHE.
+                return null;
+            }).when(mDnsResolver).query(any(), any(), anyInt(), any(), any(), any());
+        }
     }
 
+    private FakeDns mFakeDns;
+
     @Before
     public void setUp() throws IOException {
         MockitoAnnotations.initMocks(this);
@@ -206,7 +271,7 @@
         when(mDependencies.getSetting(any(), eq(Settings.Global.CAPTIVE_PORTAL_HTTPS_URL), any()))
                 .thenReturn(TEST_HTTPS_URL);
 
-        doReturn(mNetwork).when(mNetwork).getPrivateDnsBypassingCopy();
+        doReturn(mNetwork).when(mNonPrivateDnsBypassNetwork).getPrivateDnsBypassingCopy();
 
         when(mContext.getSystemService(Context.CONNECTIVITY_SERVICE)).thenReturn(mCm);
         when(mContext.getSystemService(Context.TELEPHONY_SERVICE)).thenReturn(mTelephony);
@@ -222,6 +287,9 @@
         setFallbackSpecs(null); // Test with no fallback spec by default
         when(mRandom.nextInt()).thenReturn(0);
 
+        when(mResources.getInteger(eq(R.integer.config_captive_portal_dns_probe_timeout)))
+                .thenReturn(500);
+
         doAnswer((invocation) -> {
             URL url = invocation.getArgument(0);
             switch(url.toString()) {
@@ -241,7 +309,9 @@
         when(mHttpConnection.getRequestProperties()).thenReturn(new ArrayMap<>());
         when(mHttpsConnection.getRequestProperties()).thenReturn(new ArrayMap<>());
 
-        setDnsAnswers(new String[]{"2001:db8::1", "192.0.2.2"});
+        mFakeDns = new FakeDns();
+        mFakeDns.startMocking();
+        mFakeDns.setAnswer("*", new String[]{"2001:db8::1", "192.0.2.2"});
 
         when(mContext.registerReceiver(any(BroadcastReceiver.class), any())).then((invocation) -> {
             mRegisteredReceivers.add(invocation.getArgument(0));
@@ -264,6 +334,7 @@
 
     @After
     public void tearDown() {
+        mFakeDns.clearAll();
         assertTrue(mCreatedNetworkMonitors.size() > 0);
         // Make a local copy of mCreatedNetworkMonitors because during the iteration below,
         // WrappedNetworkMonitor#onQuitting will delete elements from it on the handler threads.
@@ -284,8 +355,8 @@
         private final ConditionVariable mQuitCv = new ConditionVariable(false);
 
         WrappedNetworkMonitor() {
-                super(mContext, mCallbacks, mNetwork, mLogger, mValidationLogger, mDependencies,
-                        mDataStallStatsUtils);
+            super(mContext, mCallbacks, mNonPrivateDnsBypassNetwork, mLogger, mValidationLogger,
+                    mDependencies, mDataStallStatsUtils);
         }
 
         @Override
@@ -314,23 +385,22 @@
         }
     }
 
-    private WrappedNetworkMonitor makeMonitor() {
+    private WrappedNetworkMonitor makeMonitor(NetworkCapabilities nc) {
         final WrappedNetworkMonitor nm = new WrappedNetworkMonitor();
         nm.start();
+        setNetworkCapabilities(nm, nc);
         waitForIdle(nm.getHandler());
         mCreatedNetworkMonitors.add(nm);
         return nm;
     }
 
     private WrappedNetworkMonitor makeMeteredNetworkMonitor() {
-        final WrappedNetworkMonitor nm = makeMonitor();
-        setNetworkCapabilities(nm, METERED_CAPABILITIES);
+        final WrappedNetworkMonitor nm = makeMonitor(METERED_CAPABILITIES);
         return nm;
     }
 
     private WrappedNetworkMonitor makeNotMeteredNetworkMonitor() {
-        final WrappedNetworkMonitor nm = makeMonitor();
-        setNetworkCapabilities(nm, NOT_METERED_CAPABILITIES);
+        final WrappedNetworkMonitor nm = makeMonitor(NOT_METERED_CAPABILITIES);
         return nm;
     }
 
@@ -603,7 +673,7 @@
         setSslException(mHttpsConnection);
         setPortal302(mHttpConnection);
 
-        final NetworkMonitor nm = makeMonitor();
+        final NetworkMonitor nm = makeMonitor(METERED_CAPABILITIES);
         nm.notifyNetworkConnected(TEST_LINK_PROPERTIES, METERED_CAPABILITIES);
 
         verify(mCallbacks, timeout(HANDLER_TIMEOUT_MS).times(1))
@@ -638,6 +708,63 @@
     }
 
     @Test
+    public void testPrivateDnsSuccess() throws Exception {
+        setStatus(mHttpsConnection, 204);
+        setStatus(mHttpConnection, 204);
+        mFakeDns.setAnswer("dns.google", new String[]{"2001:db8::53"});
+
+        WrappedNetworkMonitor wnm = makeNotMeteredNetworkMonitor();
+        wnm.notifyPrivateDnsSettingsChanged(new PrivateDnsConfig("dns.google", new InetAddress[0]));
+        wnm.notifyNetworkConnected(TEST_LINK_PROPERTIES, NOT_METERED_CAPABILITIES);
+        verify(mCallbacks, timeout(HANDLER_TIMEOUT_MS).times(1))
+                .notifyNetworkTested(eq(NETWORK_TEST_RESULT_VALID), eq(null));
+    }
+
+    @Test
+    public void testPrivateDnsResolutionRetryUpdate() throws Exception {
+        // Set a private DNS hostname that doesn't resolve and expect validation to fail.
+        mFakeDns.setAnswer("dns.google", new String[0]);
+        setStatus(mHttpsConnection, 204);
+        setStatus(mHttpConnection, 204);
+
+        WrappedNetworkMonitor wnm = makeNotMeteredNetworkMonitor();
+        wnm.notifyPrivateDnsSettingsChanged(new PrivateDnsConfig("dns.google", new InetAddress[0]));
+        wnm.notifyNetworkConnected(TEST_LINK_PROPERTIES, NOT_METERED_CAPABILITIES);
+        verify(mCallbacks, timeout(HANDLER_TIMEOUT_MS).times(1))
+                .notifyNetworkTested(eq(NETWORK_TEST_RESULT_INVALID), eq(null));
+
+        // Fix DNS and retry, expect validation to succeed.
+        reset(mCallbacks);
+        mFakeDns.setAnswer("dns.google", new String[]{"2001:db8::1"});
+
+        wnm.forceReevaluation(Process.myUid());
+        verify(mCallbacks, timeout(HANDLER_TIMEOUT_MS).times(1))
+                .notifyNetworkTested(eq(NETWORK_TEST_RESULT_VALID), eq(null));
+
+        // Change configuration to an invalid DNS name, expect validation to fail.
+        reset(mCallbacks);
+        mFakeDns.setAnswer("dns.bad", new String[0]);
+        wnm.notifyPrivateDnsSettingsChanged(new PrivateDnsConfig("dns.bad", new InetAddress[0]));
+        verify(mCallbacks, timeout(HANDLER_TIMEOUT_MS).times(1))
+                .notifyNetworkTested(eq(NETWORK_TEST_RESULT_INVALID), eq(null));
+
+        // Change configuration back to working again, but make private DNS not work.
+        // Expect validation to fail.
+        reset(mCallbacks);
+        mFakeDns.setNonBypassPrivateDnsWorking(false);
+        wnm.notifyPrivateDnsSettingsChanged(new PrivateDnsConfig("dns.google", new InetAddress[0]));
+        verify(mCallbacks, timeout(HANDLER_TIMEOUT_MS).times(1))
+                .notifyNetworkTested(eq(NETWORK_TEST_RESULT_INVALID), eq(null));
+
+        // Make private DNS work again. Expect validation to succeed.
+        reset(mCallbacks);
+        mFakeDns.setNonBypassPrivateDnsWorking(true);
+        wnm.forceReevaluation(Process.myUid());
+        verify(mCallbacks, timeout(HANDLER_TIMEOUT_MS).times(1))
+                .notifyNetworkTested(eq(NETWORK_TEST_RESULT_VALID), eq(null));
+    }
+
+    @Test
     public void testDataStall_StallSuspectedAndSendMetrics() throws IOException {
         WrappedNetworkMonitor wrappedMonitor = makeNotMeteredNetworkMonitor();
         wrappedMonitor.setLastProbeTime(SystemClock.elapsedRealtime() - 1000);
@@ -728,25 +855,27 @@
         WrappedNetworkMonitor wnm = makeNotMeteredNetworkMonitor();
         final int shortTimeoutMs = 200;
 
+        // Clear the wildcard DNS response created in setUp.
+        mFakeDns.setAnswer("*", null);
+
         String[] expected = new String[]{"2001:db8::"};
-        setDnsAnswers(expected);
+        mFakeDns.setAnswer("www.google.com", expected);
         InetAddress[] actual = wnm.sendDnsProbeWithTimeout("www.google.com", shortTimeoutMs);
         assertIpAddressArrayEquals(expected, actual);
 
         expected = new String[]{"2001:db8::", "192.0.2.1"};
-        setDnsAnswers(expected);
-        actual = wnm.sendDnsProbeWithTimeout("www.google.com", shortTimeoutMs);
+        mFakeDns.setAnswer("www.googleapis.com", expected);
+        actual = wnm.sendDnsProbeWithTimeout("www.googleapis.com", shortTimeoutMs);
         assertIpAddressArrayEquals(expected, actual);
 
-        expected = new String[0];
-        setDnsAnswers(expected);
+        mFakeDns.setAnswer("www.google.com", new String[0]);
         try {
             wnm.sendDnsProbeWithTimeout("www.google.com", shortTimeoutMs);
             fail("No DNS results, expected UnknownHostException");
         } catch (UnknownHostException e) {
         }
 
-        setDnsAnswers(null);
+        mFakeDns.setAnswer("www.google.com", null);
         try {
             wnm.sendDnsProbeWithTimeout("www.google.com", shortTimeoutMs);
             fail("DNS query timed out, expected UnknownHostException");
@@ -841,7 +970,7 @@
     }
 
     private NetworkMonitor runNetworkTest(NetworkCapabilities nc, int testResult) {
-        final NetworkMonitor monitor = makeMonitor();
+        final NetworkMonitor monitor = makeMonitor(nc);
         monitor.notifyNetworkConnected(TEST_LINK_PROPERTIES, nc);
         try {
             verify(mCallbacks, timeout(HANDLER_TIMEOUT_MS).times(1))