Add a timeout for the DNS probe.

Bug: 129312219
Test: atest NetworkStackTests
Test: atest --generate-new-metrics 50 NetworkStackTests:com.android.server.connectivity.NetworkMonitorTest
Change-Id: Ib3ab9105d7ae39f551b51e8d5a04b9ec5e549655
diff --git a/packages/NetworkStack/res/values/config.xml b/packages/NetworkStack/res/values/config.xml
index 704788d..478ed6b 100644
--- a/packages/NetworkStack/res/values/config.xml
+++ b/packages/NetworkStack/res/values/config.xml
@@ -7,6 +7,9 @@
     values are meant to be the default when no other configuration is specified.
     -->
 
+    <!-- DNS probe timeout for network validation. Enough for 3 DNS queries 5 seconds apart. -->
+    <integer name="default_captive_portal_dns_probe_timeout">12500</integer>
+
     <!-- HTTP URL for network validation, to use for detecting captive portals. -->
     <string name="default_captive_portal_http_url" translatable="false">http://connectivitycheck.gstatic.com/generate_204</string>
 
@@ -27,6 +30,7 @@
 
     <!-- Configuration hooks for the above settings.
          Empty by default but may be overridden by RROs. -->
+    <integer name="config_captive_portal_dns_probe_timeout"></integer>
     <!--suppress CheckTagEmptyBody: overlayable resource to use as configuration hook -->
     <string name="config_captive_portal_http_url" translatable="false"></string>
     <!--suppress CheckTagEmptyBody: overlayable resource to use as configuration hook -->
diff --git a/packages/NetworkStack/src/com/android/server/connectivity/NetworkMonitor.java b/packages/NetworkStack/src/com/android/server/connectivity/NetworkMonitor.java
index 093235e..50eb5d4 100644
--- a/packages/NetworkStack/src/com/android/server/connectivity/NetworkMonitor.java
+++ b/packages/NetworkStack/src/com/android/server/connectivity/NetworkMonitor.java
@@ -59,6 +59,7 @@
 import android.content.IntentFilter;
 import android.content.res.Resources;
 import android.net.ConnectivityManager;
+import android.net.DnsResolver;
 import android.net.INetworkMonitor;
 import android.net.INetworkMonitorCallbacks;
 import android.net.LinkProperties;
@@ -122,6 +123,7 @@
 import java.util.UUID;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicReference;
 import java.util.function.Function;
 
 /**
@@ -136,8 +138,16 @@
                                                       + "AppleWebKit/537.36 (KHTML, like Gecko) "
                                                       + "Chrome/60.0.3112.32 Safari/537.36";
 
+    @VisibleForTesting
+    static final String CONFIG_CAPTIVE_PORTAL_DNS_PROBE_TIMEOUT =
+            "captive_portal_dns_probe_timeout";
+
     private static final int SOCKET_TIMEOUT_MS = 10000;
     private static final int PROBE_TIMEOUT_MS  = 3000;
+    // Enough for 3 DNS queries 5 seconds apart.
+    // TODO: get this from resources and DeviceConfig instead.
+    private static final int DNS_TIMEOUT_MS = 12500;
+
     enum EvaluationResult {
         VALIDATED(true),
         CAPTIVE_PORTAL(false);
@@ -1185,6 +1195,33 @@
                 Settings.Global.CAPTIVE_PORTAL_HTTPS_URL);
     }
 
+    private int getDnsProbeTimeout() {
+        return getIntSetting(mContext, R.integer.config_captive_portal_dns_probe_timeout,
+                CONFIG_CAPTIVE_PORTAL_DNS_PROBE_TIMEOUT,
+                R.integer.default_captive_portal_dns_probe_timeout);
+    }
+
+    /**
+     * Gets an integer setting from resources or device config
+     *
+     * configResource is used if set, followed by device config if set, followed by defaultResource.
+     * If none of these are set then an exception is thrown.
+     *
+     * TODO: move to a common location such as a ConfigUtils class.
+     * TODO(b/130324939): test that the resources can be overlayed by an RRO package.
+     */
+    @VisibleForTesting
+    int getIntSetting(@NonNull final Context context, @StringRes int configResource,
+            @NonNull String symbol, @StringRes int defaultResource) {
+        final Resources res = context.getResources();
+        try {
+            return res.getInteger(configResource);
+        } catch (Resources.NotFoundException e) {
+            return mDependencies.getDeviceConfigPropertyInt(NAMESPACE_CONNECTIVITY,
+                    symbol, res.getInteger(defaultResource));
+        }
+    }
+
     /**
      * Get the captive portal server HTTP URL that is configured on the device.
      *
@@ -1446,6 +1483,45 @@
         return sendHttpProbe(url, probeType, null);
     }
 
+    /** Do a DNS lookup for the given server, or throw UnknownHostException after timeoutMs */
+    @VisibleForTesting
+    protected InetAddress[] sendDnsProbeWithTimeout(String host, int timeoutMs)
+                throws UnknownHostException {
+        final CountDownLatch latch = new CountDownLatch(1);
+        final AtomicReference<List<InetAddress>> resultRef = new AtomicReference<>();
+        final DnsResolver.Callback<List<InetAddress>> callback =
+                    new DnsResolver.Callback<List<InetAddress>>() {
+            public void onAnswer(List<InetAddress> answer, int rcode) {
+                if (rcode == 0) {
+                    resultRef.set(answer);
+                }
+                latch.countDown();
+            }
+            public void onError(@NonNull DnsResolver.DnsException e) {
+                validationLog("DNS error resolving " + host + ": " + e.getMessage());
+                latch.countDown();
+            }
+        };
+
+        final int oldTag = TrafficStats.getAndSetThreadStatsTag(
+                TrafficStatsConstants.TAG_SYSTEM_PROBE);
+        mDependencies.getDnsResolver().query(mNetwork, host, DnsResolver.FLAG_EMPTY,
+                r -> r.run() /* executor */, null /* cancellationSignal */, callback);
+        TrafficStats.setThreadStatsTag(oldTag);
+
+        try {
+            latch.await(timeoutMs, TimeUnit.MILLISECONDS);
+        } catch (InterruptedException e) {
+        }
+
+        List<InetAddress> result = resultRef.get();
+        if (result == null || result.size() == 0) {
+            throw new UnknownHostException(host);
+        }
+
+        return result.toArray(new InetAddress[0]);
+    }
+
     /** Do a DNS resolution of the given server. */
     private void sendDnsProbe(String host) {
         if (TextUtils.isEmpty(host)) {
@@ -1457,7 +1533,7 @@
         int result;
         String connectInfo;
         try {
-            InetAddress[] addresses = mNetwork.getAllByName(host);
+            InetAddress[] addresses = sendDnsProbeWithTimeout(host, getDnsProbeTimeout());
             StringBuffer buffer = new StringBuffer();
             for (InetAddress address : addresses) {
                 buffer.append(',').append(address.getHostAddress());
@@ -1782,6 +1858,10 @@
             return new OneAddressPerFamilyNetwork(network);
         }
 
+        public DnsResolver getDnsResolver() {
+            return DnsResolver.getInstance();
+        }
+
         public Random getRandom() {
             return new Random();
         }
diff --git a/packages/NetworkStack/tests/src/com/android/server/connectivity/NetworkMonitorTest.java b/packages/NetworkStack/tests/src/com/android/server/connectivity/NetworkMonitorTest.java
index 594f2ca..0dc1cbf 100644
--- a/packages/NetworkStack/tests/src/com/android/server/connectivity/NetworkMonitorTest.java
+++ b/packages/NetworkStack/tests/src/com/android/server/connectivity/NetworkMonitorTest.java
@@ -33,6 +33,7 @@
 import static junit.framework.Assert.assertEquals;
 import static junit.framework.Assert.assertFalse;
 
+import static org.junit.Assert.assertArrayEquals;
 import static org.junit.Assert.assertNotNull;
 import static org.junit.Assert.assertNull;
 import static org.junit.Assert.assertTrue;
@@ -41,6 +42,7 @@
 import static org.mockito.Mockito.any;
 import static org.mockito.Mockito.anyInt;
 import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.doNothing;
 import static org.mockito.Mockito.doReturn;
 import static org.mockito.Mockito.doThrow;
 import static org.mockito.Mockito.never;
@@ -55,6 +57,7 @@
 import android.content.Intent;
 import android.content.res.Resources;
 import android.net.ConnectivityManager;
+import android.net.DnsResolver;
 import android.net.INetworkMonitorCallbacks;
 import android.net.InetAddresses;
 import android.net.LinkProperties;
@@ -69,6 +72,7 @@
 import android.os.Bundle;
 import android.os.ConditionVariable;
 import android.os.Handler;
+import android.os.Looper;
 import android.os.RemoteException;
 import android.os.SystemClock;
 import android.provider.Settings;
@@ -79,6 +83,7 @@
 import androidx.test.filters.SmallTest;
 import androidx.test.runner.AndroidJUnit4;
 
+import com.android.networkstack.R;
 import com.android.networkstack.metrics.DataStallDetectionStats;
 import com.android.networkstack.metrics.DataStallStatsUtils;
 
@@ -96,8 +101,12 @@
 import java.net.HttpURLConnection;
 import java.net.InetAddress;
 import java.net.URL;
+import java.net.UnknownHostException;
+import java.util.ArrayList;
 import java.util.HashSet;
+import java.util.List;
 import java.util.Random;
+import java.util.concurrent.Executor;
 
 import javax.net.ssl.SSLHandshakeException;
 
@@ -111,6 +120,7 @@
     private @Mock IpConnectivityLog mLogger;
     private @Mock SharedLog mValidationLogger;
     private @Mock NetworkInfo mNetworkInfo;
+    private @Mock DnsResolver mDnsResolver;
     private @Mock ConnectivityManager mCm;
     private @Mock TelephonyManager mTelephony;
     private @Mock WifiManager mWifi;
@@ -156,10 +166,36 @@
     private static final NetworkCapabilities NO_INTERNET_CAPABILITIES = new NetworkCapabilities()
             .addTransportType(NetworkCapabilities.TRANSPORT_CELLULAR);
 
+    private void setDnsAnswers(String[] answers) throws UnknownHostException {
+        if (answers == null) {
+            doThrow(new UnknownHostException()).when(mNetwork).getAllByName(any());
+            doNothing().when(mDnsResolver).query(any(), any(), anyInt(), any(), any(), any());
+            return;
+        }
+
+        List<InetAddress> answerList = new ArrayList<>();
+        for (String answer : answers) {
+            answerList.add(InetAddresses.parseNumericAddress(answer));
+        }
+        InetAddress[] answerArray = answerList.toArray(new InetAddress[0]);
+
+        doReturn(answerArray).when(mNetwork).getAllByName(any());
+
+        doAnswer((invocation) -> {
+            Executor executor = (Executor) invocation.getArgument(3);
+            DnsResolver.Callback<List<InetAddress>> callback = invocation.getArgument(5);
+            new Handler(Looper.getMainLooper()).post(() -> {
+                executor.execute(() -> callback.onAnswer(answerList, 0));
+            });
+            return null;
+        }).when(mDnsResolver).query(eq(mNetwork), any(), anyInt(), any(), any(), any());
+    }
+
     @Before
     public void setUp() throws IOException {
         MockitoAnnotations.initMocks(this);
         when(mDependencies.getPrivateDnsBypassNetwork(any())).thenReturn(mNetwork);
+        when(mDependencies.getDnsResolver()).thenReturn(mDnsResolver);
         when(mDependencies.getRandom()).thenReturn(mRandom);
         when(mDependencies.getSetting(any(), eq(Settings.Global.CAPTIVE_PORTAL_MODE), anyInt()))
                 .thenReturn(Settings.Global.CAPTIVE_PORTAL_MODE_PROMPT);
@@ -204,9 +240,8 @@
         }).when(mNetwork).openConnection(any());
         when(mHttpConnection.getRequestProperties()).thenReturn(new ArrayMap<>());
         when(mHttpsConnection.getRequestProperties()).thenReturn(new ArrayMap<>());
-        doReturn(new InetAddress[] {
-                InetAddresses.parseNumericAddress("192.168.0.0")
-        }).when(mNetwork).getAllByName(any());
+
+        setDnsAnswers(new String[]{"2001:db8::1", "192.0.2.2"});
 
         when(mContext.registerReceiver(any(BroadcastReceiver.class), any())).then((invocation) -> {
             mRegisteredReceivers.add(invocation.getArgument(0));
@@ -313,6 +348,44 @@
     }
 
     @Test
+    public void testGetIntSetting() throws Exception {
+        WrappedNetworkMonitor wnm = makeNotMeteredNetworkMonitor();
+
+        // No config resource, no device config. Expect to get default resource.
+        doThrow(new Resources.NotFoundException())
+                .when(mResources).getInteger(eq(R.integer.config_captive_portal_dns_probe_timeout));
+        doAnswer(invocation -> {
+            int defaultValue = invocation.getArgument(2);
+            return defaultValue;
+        }).when(mDependencies).getDeviceConfigPropertyInt(any(),
+                eq(NetworkMonitor.CONFIG_CAPTIVE_PORTAL_DNS_PROBE_TIMEOUT),
+                anyInt());
+        when(mResources.getInteger(eq(R.integer.default_captive_portal_dns_probe_timeout)))
+                .thenReturn(42);
+        assertEquals(42, wnm.getIntSetting(mContext,
+                R.integer.config_captive_portal_dns_probe_timeout,
+                NetworkMonitor.CONFIG_CAPTIVE_PORTAL_DNS_PROBE_TIMEOUT,
+                R.integer.default_captive_portal_dns_probe_timeout));
+
+        // Set device config. Expect to get device config.
+        when(mDependencies.getDeviceConfigPropertyInt(any(),
+                eq(NetworkMonitor.CONFIG_CAPTIVE_PORTAL_DNS_PROBE_TIMEOUT), anyInt()))
+                        .thenReturn(1234);
+        assertEquals(1234, wnm.getIntSetting(mContext,
+                R.integer.config_captive_portal_dns_probe_timeout,
+                NetworkMonitor.CONFIG_CAPTIVE_PORTAL_DNS_PROBE_TIMEOUT,
+                R.integer.default_captive_portal_dns_probe_timeout));
+
+        // Set config resource. Expect to get config resource.
+        when(mResources.getInteger(eq(R.integer.config_captive_portal_dns_probe_timeout)))
+                .thenReturn(5678);
+        assertEquals(5678, wnm.getIntSetting(mContext,
+                R.integer.config_captive_portal_dns_probe_timeout,
+                NetworkMonitor.CONFIG_CAPTIVE_PORTAL_DNS_PROBE_TIMEOUT,
+                R.integer.default_captive_portal_dns_probe_timeout));
+    }
+
+    @Test
     public void testIsCaptivePortal_HttpProbeIsPortal() throws IOException {
         setSslException(mHttpsConnection);
         setPortal302(mHttpConnection);
@@ -642,6 +715,45 @@
         runPartialConnectivityNetworkTest();
     }
 
+    private void assertIpAddressArrayEquals(String[] expected, InetAddress[] actual) {
+        String[] actualStrings = new String[actual.length];
+        for (int i = 0; i < actual.length; i++) {
+            actualStrings[i] = actual[i].getHostAddress();
+        }
+        assertArrayEquals("Array of IP addresses differs", expected, actualStrings);
+    }
+
+    @Test
+    public void testSendDnsProbeWithTimeout() throws Exception {
+        WrappedNetworkMonitor wnm = makeNotMeteredNetworkMonitor();
+        final int shortTimeoutMs = 200;
+
+        String[] expected = new String[]{"2001:db8::"};
+        setDnsAnswers(expected);
+        InetAddress[] actual = wnm.sendDnsProbeWithTimeout("www.google.com", shortTimeoutMs);
+        assertIpAddressArrayEquals(expected, actual);
+
+        expected = new String[]{"2001:db8::", "192.0.2.1"};
+        setDnsAnswers(expected);
+        actual = wnm.sendDnsProbeWithTimeout("www.google.com", shortTimeoutMs);
+        assertIpAddressArrayEquals(expected, actual);
+
+        expected = new String[0];
+        setDnsAnswers(expected);
+        try {
+            wnm.sendDnsProbeWithTimeout("www.google.com", shortTimeoutMs);
+            fail("No DNS results, expected UnknownHostException");
+        } catch (UnknownHostException e) {
+        }
+
+        setDnsAnswers(null);
+        try {
+            wnm.sendDnsProbeWithTimeout("www.google.com", shortTimeoutMs);
+            fail("DNS query timed out, expected UnknownHostException");
+        } catch (UnknownHostException e) {
+        }
+    }
+
     private void makeDnsTimeoutEvent(WrappedNetworkMonitor wrappedMonitor, int count) {
         for (int i = 0; i < count; i++) {
             wrappedMonitor.getDnsStallDetector().accumulateConsecutiveDnsTimeoutCount(