Merge "Test changes for the HttpURLConnection fallback strategy"
diff --git a/luni/src/test/java/libcore/java/net/URLConnectionTest.java b/luni/src/test/java/libcore/java/net/URLConnectionTest.java
index a94f0be..3db7c8f 100644
--- a/luni/src/test/java/libcore/java/net/URLConnectionTest.java
+++ b/luni/src/test/java/libcore/java/net/URLConnectionTest.java
@@ -74,7 +74,6 @@
 import javax.net.ssl.SSLSocketFactory;
 import javax.net.ssl.TrustManager;
 import javax.net.ssl.X509TrustManager;
-import libcore.java.security.StandardNames;
 import libcore.java.security.TestKeyStore;
 import libcore.java.util.AbstractResourceLeakageDetectorTestCase;
 import libcore.javax.net.ssl.TestSSLContext;
@@ -2189,16 +2188,19 @@
         urlConnection.getInputStream();
     }
 
-    public void testSslFallback() throws Exception {
+    public void testSslFallback_allSupportedProtocols() throws Exception {
         TestSSLContext testSSLContext = TestSSLContext.create();
 
+        String[] allSupportedProtocols = { "TLSv1.2", "TLSv1.1", "TLSv1", "SSLv3" };
         SSLSocketFactory serverSocketFactory =
                 new LimitedProtocolsSocketFactory(
                         testSSLContext.serverContext.getSocketFactory(),
-                        "TLSv1", "SSLv3");
+                        allSupportedProtocols);
         server.useHttps(serverSocketFactory, false);
         server.enqueue(new MockResponse().setSocketPolicy(FAIL_HANDSHAKE));
-        server.enqueue(new MockResponse().setBody("This required a 2nd handshake"));
+        server.enqueue(new MockResponse().setSocketPolicy(FAIL_HANDSHAKE));
+        server.enqueue(new MockResponse().setSocketPolicy(FAIL_HANDSHAKE));
+        server.enqueue(new MockResponse().setBody("This required fallbacks"));
         server.play();
 
         HttpsURLConnection connection = (HttpsURLConnection) server.getUrl("/").openConnection();
@@ -2206,32 +2208,87 @@
         final boolean disableFallbackScsv = true;
         FallbackTestClientSocketFactory clientSocketFactory = new FallbackTestClientSocketFactory(
                 new LimitedProtocolsSocketFactory(
-                        testSSLContext.clientContext.getSocketFactory(), "TLSv1", "SSLv3"),
+                        testSSLContext.clientContext.getSocketFactory(), allSupportedProtocols),
                 disableFallbackScsv);
         connection.setSSLSocketFactory(clientSocketFactory);
-        assertEquals("This required a 2nd handshake",
+        assertEquals("This required fallbacks",
                 readAscii(connection.getInputStream(), Integer.MAX_VALUE));
 
+        // Three handshakes were failed; the request should arrive on the 4th, over SSLv3.
         RecordedRequest retry = server.takeRequest();
         assertEquals(0, retry.getSequenceNumber());
         assertEquals("SSLv3", retry.getSslProtocol());
 
         // Confirm the client fallback looks ok.
         List<SSLSocket> createdSockets = clientSocketFactory.getCreatedSockets();
-        assertTrue(createdSockets.size() > 1);
+        assertEquals(4, createdSockets.size()); // Initial attempt + 3 fallbacks.
         TlsFallbackDisabledScsvSSLSocket clientSocket1 =
                 (TlsFallbackDisabledScsvSSLSocket) createdSockets.get(0);
-        List<String> clientSocket1EnabledProtocols = Arrays.asList(
-                clientSocket1.getEnabledProtocols());
-        assertContains(clientSocket1EnabledProtocols, "TLSv1");
-        assertFalse(clientSocket1.wasTlsFallbackScsvSet());
+        assertSslSocket(clientSocket1,
+                false /* expectedWasFallbackScsvSet */, "TLSv1.2", "TLSv1.1", "TLSv1", "SSLv3");
 
         TlsFallbackDisabledScsvSSLSocket clientSocket2 =
                 (TlsFallbackDisabledScsvSSLSocket) createdSockets.get(1);
-        List<String> clientSocket2EnabledProtocols =
-                Arrays.asList(clientSocket2.getEnabledProtocols());
-        assertContainsNoneMatching(clientSocket2EnabledProtocols, "TLSv1");
-        assertTrue(clientSocket2.wasTlsFallbackScsvSet());
+        assertSslSocket(clientSocket2,
+                true /* expectedWasFallbackScsvSet */, "TLSv1.1", "TLSv1", "SSLv3");
+
+        TlsFallbackDisabledScsvSSLSocket clientSocket3 =
+                (TlsFallbackDisabledScsvSSLSocket) createdSockets.get(2);
+        assertSslSocket(clientSocket3, true /* expectedWasFallbackScsvSet */, "TLSv1", "SSLv3");
+
+        TlsFallbackDisabledScsvSSLSocket clientSocket4 =
+                (TlsFallbackDisabledScsvSSLSocket) createdSockets.get(3);
+        assertSslSocket(clientSocket4, true /* expectedWasFallbackScsvSet */, "SSLv3");
+    }
+
+    public void testSslFallback_defaultProtocols() throws Exception {
+        TestSSLContext testSSLContext = TestSSLContext.create();
+
+        server.useHttps(testSSLContext.serverContext.getSocketFactory(), false);
+        server.enqueue(new MockResponse().setSocketPolicy(FAIL_HANDSHAKE));
+        server.enqueue(new MockResponse().setSocketPolicy(FAIL_HANDSHAKE));
+        server.enqueue(new MockResponse().setBody("This required fallbacks"));
+        server.play();
+
+        HttpsURLConnection connection = (HttpsURLConnection) server.getUrl("/").openConnection();
+        // Keeps track of the client sockets created so that we can interrogate them.
+        final boolean disableFallbackScsv = true;
+        FallbackTestClientSocketFactory clientSocketFactory = new FallbackTestClientSocketFactory(
+                testSSLContext.clientContext.getSocketFactory(),
+                disableFallbackScsv);
+        connection.setSSLSocketFactory(clientSocketFactory);
+        assertEquals("This required fallbacks",
+                readAscii(connection.getInputStream(), Integer.MAX_VALUE));
+
+        // Confirm the server accepted a single connection.
+        RecordedRequest retry = server.takeRequest();
+        assertEquals(0, retry.getSequenceNumber());
+        assertEquals("TLSv1", retry.getSslProtocol());
+
+        // Confirm the client fallback looks ok.
+        List<SSLSocket> createdSockets = clientSocketFactory.getCreatedSockets();
+        assertEquals(3, createdSockets.size());
+        TlsFallbackDisabledScsvSSLSocket clientSocket1 =
+                (TlsFallbackDisabledScsvSSLSocket) createdSockets.get(0);
+        assertSslSocket(clientSocket1,
+                false /* expectedWasFallbackScsvSet */, "TLSv1.2", "TLSv1.1", "TLSv1");
+
+        TlsFallbackDisabledScsvSSLSocket clientSocket2 =
+                (TlsFallbackDisabledScsvSSLSocket) createdSockets.get(1);
+        assertSslSocket(clientSocket2, true /* expectedWasFallbackScsvSet */, "TLSv1.1", "TLSv1");
+
+        TlsFallbackDisabledScsvSSLSocket clientSocket3 =
+                (TlsFallbackDisabledScsvSSLSocket) createdSockets.get(2);
+        assertSslSocket(clientSocket3, true /* expectedWasFallbackScsvSet */, "TLSv1");
+    }
+
+    /** Checks enabled protocols (order-insensitive) and {@code wasTlsFallbackScsvSet()}. */
+    private static void assertSslSocket(TlsFallbackDisabledScsvSSLSocket socket,
+            boolean expectedWasFallbackScsvSet, String... expectedEnabledProtocols) {
+        Set<String> expected = new HashSet<String>(Arrays.asList(expectedEnabledProtocols));
+        Set<String> actual = new HashSet<String>(Arrays.asList(socket.getEnabledProtocols()));
+        assertEquals(expected, actual);
+        assertEquals(expectedWasFallbackScsvSet, socket.wasTlsFallbackScsvSet());
     }
 
     public void testInspectSslBeforeConnect() throws Exception {