am 3af56629: Merge "Improve tests for TLS fallback."

* commit '3af56629930e754e8b80c6964da1d9de83510b9e':
  Improve tests for TLS fallback.
diff --git a/luni/src/test/java/libcore/java/net/URLConnectionTest.java b/luni/src/test/java/libcore/java/net/URLConnectionTest.java
index 8319882..714df2c 100644
--- a/luni/src/test/java/libcore/java/net/URLConnectionTest.java
+++ b/luni/src/test/java/libcore/java/net/URLConnectionTest.java
@@ -17,6 +17,7 @@
 package libcore.java.net;
 
 import com.android.okhttp.HttpResponseCache;
+
 import com.google.mockwebserver.MockResponse;
 import com.google.mockwebserver.MockWebServer;
 import com.google.mockwebserver.RecordedRequest;
@@ -37,11 +38,14 @@
 import java.net.Proxy;
 import java.net.ResponseCache;
 import java.net.Socket;
+import java.net.SocketAddress;
+import java.net.SocketException;
 import java.net.SocketTimeoutException;
 import java.net.URI;
 import java.net.URL;
 import java.net.URLConnection;
 import java.net.UnknownHostException;
+import java.nio.channels.SocketChannel;
 import java.security.cert.CertificateException;
 import java.security.cert.X509Certificate;
 import java.util.ArrayList;
@@ -58,12 +62,13 @@
 import java.util.concurrent.atomic.AtomicReference;
 import java.util.zip.GZIPInputStream;
 import java.util.zip.GZIPOutputStream;
-import javax.net.SocketFactory;
+import javax.net.ssl.HandshakeCompletedListener;
 import javax.net.ssl.HostnameVerifier;
 import javax.net.ssl.HttpsURLConnection;
 import javax.net.ssl.SSLContext;
 import javax.net.ssl.SSLException;
 import javax.net.ssl.SSLHandshakeException;
+import javax.net.ssl.SSLParameters;
 import javax.net.ssl.SSLSession;
 import javax.net.ssl.SSLSocket;
 import javax.net.ssl.SSLSocketFactory;
@@ -2187,12 +2192,12 @@
     public void testSslFallback() throws Exception {
         TestSSLContext testSSLContext = TestSSLContext.create();
 
-        // This server socket factory only supports SSLv3. This is to avoid issues due to SCSV
-        // checks. See https://tools.ietf.org/html/draft-ietf-tls-downgrade-scsv-00
+        // Android now disables SSLv3 by default. To test fallback we re-enable it for the server.
+        // This can be removed once OkHttp is updated to support other fallback protocols.
         SSLSocketFactory serverSocketFactory =
                 new LimitedProtocolsSocketFactory(
                         testSSLContext.serverContext.getSocketFactory(),
-                        "SSLv3");
+                        "TLSv1", "SSLv3");
 
         server.useHttps(serverSocketFactory, false);
         server.enqueue(new MockResponse().setSocketPolicy(FAIL_HANDSHAKE));
@@ -2200,9 +2205,10 @@
         server.play();
 
         HttpsURLConnection connection = (HttpsURLConnection) server.getUrl("/").openConnection();
-        // Keep track of the client sockets created so that we can interrogate them.
-        RecordingSocketFactory clientSocketFactory =
-                new RecordingSocketFactory(testSSLContext.clientContext.getSocketFactory());
+        // Records the client sockets it creates and hides TLS_FALLBACK_SCSV from the server.
+        final boolean disableFallbackScsv = true;
+        FallbackTestClientSocketFactory clientSocketFactory = new FallbackTestClientSocketFactory(
+                testSSLContext.clientContext.getSocketFactory(), disableFallbackScsv);
         connection.setSSLSocketFactory(clientSocketFactory);
         assertEquals("This required a 2nd handshake",
                 readAscii(connection.getInputStream(), Integer.MAX_VALUE));
@@ -2213,23 +2219,20 @@
 
         // Confirm the client fallback looks ok.
         List<SSLSocket> createdSockets = clientSocketFactory.getCreatedSockets();
-        assertEquals(2, createdSockets.size());
-        SSLSocket clientSocket1 = createdSockets.get(0);
+        assertTrue("sockets created: " + createdSockets.size(), createdSockets.size() > 1);
+        TlsFallbackDisabledScsvSSLSocket clientSocket1 =
+                (TlsFallbackDisabledScsvSSLSocket) createdSockets.get(0);
         List<String> clientSocket1EnabledProtocols = Arrays.asList(
                 clientSocket1.getEnabledProtocols());
         assertContains(clientSocket1EnabledProtocols, "TLSv1.2");
-        List<String> clientSocket1EnabledCiphers =
-                Arrays.asList(clientSocket1.getEnabledCipherSuites());
-        assertContainsNoneMatching(
-                clientSocket1EnabledCiphers, StandardNames.CIPHER_SUITE_FALLBACK);
+        assertFalse(clientSocket1.wasTlsFallbackScsvSet());
 
-        SSLSocket clientSocket2 = createdSockets.get(1);
+        TlsFallbackDisabledScsvSSLSocket clientSocket2 =
+                (TlsFallbackDisabledScsvSSLSocket) createdSockets.get(1);
         List<String> clientSocket2EnabledProtocols =
                 Arrays.asList(clientSocket2.getEnabledProtocols());
         assertContainsNoneMatching(clientSocket2EnabledProtocols, "TLSv1.2");
-        List<String> clientSocket2EnabledCiphers =
-                Arrays.asList(clientSocket2.getEnabledCipherSuites());
-        assertContains(clientSocket2EnabledCiphers, StandardNames.CIPHER_SUITE_FALLBACK);
+        assertTrue(clientSocket2.wasTlsFallbackScsvSet());
     }
 
     public void testInspectSslBeforeConnect() throws Exception {
@@ -2490,36 +2493,40 @@
         }
 
+        // Every override narrows its return type to SSLSocket (a covariant override) so callers
+        // using this type get an SSLSocket without casting. The casts assume the delegate only
+        // ever returns SSLSocket instances; anything else fails with ClassCastException.
         @Override
-        public Socket createSocket(Socket s, String host, int port, boolean autoClose)
+        public SSLSocket createSocket(Socket s, String host, int port, boolean autoClose)
                 throws IOException {
-            return delegate.createSocket(s, host, port, autoClose);
+            return (SSLSocket) delegate.createSocket(s, host, port, autoClose);
         }
 
         @Override
-        public Socket createSocket() throws IOException {
-            return delegate.createSocket();
+        public SSLSocket createSocket() throws IOException {
+            return (SSLSocket) delegate.createSocket();
         }
 
         @Override
-        public Socket createSocket(String host, int port) throws IOException, UnknownHostException {
-            return delegate.createSocket(host, port);
+        public SSLSocket createSocket(String host, int port)
+                throws IOException, UnknownHostException {
+            return (SSLSocket) delegate.createSocket(host, port);
         }
 
         @Override
-        public Socket createSocket(String host, int port, InetAddress localHost,
+        public SSLSocket createSocket(String host, int port, InetAddress localHost,
                 int localPort) throws IOException, UnknownHostException {
-            return delegate.createSocket(host, port, localHost, localPort);
+            return (SSLSocket) delegate.createSocket(host, port, localHost, localPort);
         }
 
         @Override
-        public Socket createSocket(InetAddress host, int port) throws IOException {
-            return delegate.createSocket(host, port);
+        public SSLSocket createSocket(InetAddress host, int port) throws IOException {
+            return (SSLSocket) delegate.createSocket(host, port);
         }
 
         @Override
-        public Socket createSocket(InetAddress address, int port,
+        public SSLSocket createSocket(InetAddress address, int port,
                 InetAddress localAddress, int localPort) throws IOException {
-            return delegate.createSocket(address, port, localAddress, localPort);
+            return (SSLSocket) delegate.createSocket(address, port, localAddress, localPort);
         }
 
     }
@@ -2538,7 +2542,9 @@
         }
 
+        // Each override below casts the delegate's socket to SSLSocket and sets its enabled
+        // protocols to `protocols` before handing it back.
         @Override
-        public Socket createSocket(Socket s, String host, int port, boolean autoClose)
+        public SSLSocket createSocket(Socket s, String host, int port, boolean autoClose)
                 throws IOException {
             SSLSocket socket = (SSLSocket) delegate.createSocket(s, host, port, autoClose);
             socket.setEnabledProtocols(protocols);
@@ -2546,21 +2550,22 @@
         }
 
         @Override
-        public Socket createSocket() throws IOException {
+        public SSLSocket createSocket() throws IOException {
             SSLSocket socket = (SSLSocket) delegate.createSocket();
             socket.setEnabledProtocols(protocols);
             return socket;
         }
 
         @Override
-        public Socket createSocket(String host, int port) throws IOException, UnknownHostException {
+        public SSLSocket createSocket(String host, int port)
+                throws IOException, UnknownHostException {
             SSLSocket socket = (SSLSocket) delegate.createSocket(host, port);
             socket.setEnabledProtocols(protocols);
             return socket;
         }
 
         @Override
-        public Socket createSocket(String host, int port, InetAddress localHost,
+        public SSLSocket createSocket(String host, int port, InetAddress localHost,
                 int localPort) throws IOException, UnknownHostException {
             SSLSocket socket = (SSLSocket) delegate.createSocket(host, port, localHost, localPort);
             socket.setEnabledProtocols(protocols);
@@ -2568,14 +2573,14 @@
         }
 
         @Override
-        public Socket createSocket(InetAddress host, int port) throws IOException {
+        public SSLSocket createSocket(InetAddress host, int port) throws IOException {
             SSLSocket socket = (SSLSocket) delegate.createSocket(host, port);
             socket.setEnabledProtocols(protocols);
             return socket;
         }
 
         @Override
-        public Socket createSocket(InetAddress address, int port,
+        public SSLSocket createSocket(InetAddress address, int port,
                 InetAddress localAddress, int localPort) throws IOException {
             SSLSocket socket =
                     (SSLSocket) delegate.createSocket(address, port, localAddress, localPort);
@@ -2585,58 +2590,341 @@
     }
 
     /**
-     * An SSLSocketFactory that delegates calls and keeps a record of any sockets created.
+     * An {@link javax.net.ssl.SSLSocket} that forwards every method it overrides to
+     * {@link #delegate}. Subclasses override individual methods to intercept them.
+     *
+     * <p>Methods that are not overridden here run against this wrapper's own unconnected socket
+     * state rather than the delegate's, so add an override if a test starts relying on one.
      */
-    private static class RecordingSocketFactory extends DelegatingSSLSocketFactory {
+    private static abstract class DelegatingSSLSocket extends SSLSocket {
+        protected final SSLSocket delegate;
 
+        public DelegatingSSLSocket(SSLSocket delegate) {
+            this.delegate = delegate;
+        }
+
+        @Override public void shutdownInput() throws IOException {
+            delegate.shutdownInput();
+        }
+
+        @Override public void shutdownOutput() throws IOException {
+            delegate.shutdownOutput();
+        }
+
+        @Override public String[] getSupportedCipherSuites() {
+            return delegate.getSupportedCipherSuites();
+        }
+
+        @Override public String[] getEnabledCipherSuites() {
+            return delegate.getEnabledCipherSuites();
+        }
+
+        @Override public void setEnabledCipherSuites(String[] suites) {
+            delegate.setEnabledCipherSuites(suites);
+        }
+
+        @Override public String[] getSupportedProtocols() {
+            return delegate.getSupportedProtocols();
+        }
+
+        @Override public String[] getEnabledProtocols() {
+            return delegate.getEnabledProtocols();
+        }
+
+        @Override public void setEnabledProtocols(String[] protocols) {
+            delegate.setEnabledProtocols(protocols);
+        }
+
+        @Override public SSLSession getSession() {
+            return delegate.getSession();
+        }
+
+        @Override public void addHandshakeCompletedListener(HandshakeCompletedListener listener) {
+            delegate.addHandshakeCompletedListener(listener);
+        }
+
+        @Override public void removeHandshakeCompletedListener(HandshakeCompletedListener listener) {
+            delegate.removeHandshakeCompletedListener(listener);
+        }
+
+        @Override public void startHandshake() throws IOException {
+            delegate.startHandshake();
+        }
+
+        @Override public void setUseClientMode(boolean mode) {
+            delegate.setUseClientMode(mode);
+        }
+
+        @Override public boolean getUseClientMode() {
+            return delegate.getUseClientMode();
+        }
+
+        @Override public void setNeedClientAuth(boolean need) {
+            delegate.setNeedClientAuth(need);
+        }
+
+        @Override public void setWantClientAuth(boolean want) {
+            delegate.setWantClientAuth(want);
+        }
+
+        @Override public boolean getNeedClientAuth() {
+            return delegate.getNeedClientAuth();
+        }
+
+        @Override public boolean getWantClientAuth() {
+            return delegate.getWantClientAuth();
+        }
+
+        @Override public void setEnableSessionCreation(boolean flag) {
+            delegate.setEnableSessionCreation(flag);
+        }
+
+        @Override public boolean getEnableSessionCreation() {
+            return delegate.getEnableSessionCreation();
+        }
+
+        @Override public SSLParameters getSSLParameters() {
+            return delegate.getSSLParameters();
+        }
+
+        @Override public void setSSLParameters(SSLParameters p) {
+            delegate.setSSLParameters(p);
+        }
+
+        @Override public void close() throws IOException {
+            delegate.close();
+        }
+
+        @Override public InetAddress getInetAddress() {
+            return delegate.getInetAddress();
+        }
+
+        @Override public InputStream getInputStream() throws IOException {
+            return delegate.getInputStream();
+        }
+
+        @Override public boolean getKeepAlive() throws SocketException {
+            return delegate.getKeepAlive();
+        }
+
+        @Override public InetAddress getLocalAddress() {
+            return delegate.getLocalAddress();
+        }
+
+        @Override public int getLocalPort() {
+            return delegate.getLocalPort();
+        }
+
+        @Override public OutputStream getOutputStream() throws IOException {
+            return delegate.getOutputStream();
+        }
+
+        @Override public int getPort() {
+            return delegate.getPort();
+        }
+
+        @Override public int getSoLinger() throws SocketException {
+            return delegate.getSoLinger();
+        }
+
+        @Override public int getReceiveBufferSize() throws SocketException {
+            return delegate.getReceiveBufferSize();
+        }
+
+        @Override public int getSendBufferSize() throws SocketException {
+            return delegate.getSendBufferSize();
+        }
+
+        @Override public int getSoTimeout() throws SocketException {
+            return delegate.getSoTimeout();
+        }
+
+        @Override public boolean getTcpNoDelay() throws SocketException {
+            return delegate.getTcpNoDelay();
+        }
+
+        @Override public void setKeepAlive(boolean keepAlive) throws SocketException {
+            delegate.setKeepAlive(keepAlive);
+        }
+
+        @Override public void setSendBufferSize(int size) throws SocketException {
+            delegate.setSendBufferSize(size);
+        }
+
+        @Override public void setReceiveBufferSize(int size) throws SocketException {
+            delegate.setReceiveBufferSize(size);
+        }
+
+        @Override public void setSoLinger(boolean on, int timeout) throws SocketException {
+            delegate.setSoLinger(on, timeout);
+        }
+
+        @Override public void setSoTimeout(int timeout) throws SocketException {
+            delegate.setSoTimeout(timeout);
+        }
+
+        @Override public void setTcpNoDelay(boolean on) throws SocketException {
+            delegate.setTcpNoDelay(on);
+        }
+
+        @Override public String toString() {
+            return delegate.toString();
+        }
+
+        @Override public SocketAddress getLocalSocketAddress() {
+            return delegate.getLocalSocketAddress();
+        }
+
+        @Override public SocketAddress getRemoteSocketAddress() {
+            return delegate.getRemoteSocketAddress();
+        }
+
+        @Override public boolean isBound() {
+            return delegate.isBound();
+        }
+
+        @Override public boolean isConnected() {
+            return delegate.isConnected();
+        }
+
+        @Override public boolean isClosed() {
+            return delegate.isClosed();
+        }
+
+        @Override public void bind(SocketAddress localAddr) throws IOException {
+            delegate.bind(localAddr);
+        }
+
+        @Override public void connect(SocketAddress remoteAddr) throws IOException {
+            delegate.connect(remoteAddr);
+        }
+
+        @Override public void connect(SocketAddress remoteAddr, int timeout) throws IOException {
+            delegate.connect(remoteAddr, timeout);
+        }
+
+        @Override public boolean isInputShutdown() {
+            return delegate.isInputShutdown();
+        }
+
+        @Override public boolean isOutputShutdown() {
+            return delegate.isOutputShutdown();
+        }
+
+        @Override public void setReuseAddress(boolean reuse) throws SocketException {
+            delegate.setReuseAddress(reuse);
+        }
+
+        @Override public boolean getReuseAddress() throws SocketException {
+            return delegate.getReuseAddress();
+        }
+
+        @Override public void setOOBInline(boolean oobinline) throws SocketException {
+            delegate.setOOBInline(oobinline);
+        }
+
+        @Override public boolean getOOBInline() throws SocketException {
+            return delegate.getOOBInline();
+        }
+
+        @Override public void setTrafficClass(int value) throws SocketException {
+            delegate.setTrafficClass(value);
+        }
+
+        @Override public int getTrafficClass() throws SocketException {
+            return delegate.getTrafficClass();
+        }
+
+        @Override public void sendUrgentData(int value) throws IOException {
+            delegate.sendUrgentData(value);
+        }
+
+        @Override public SocketChannel getChannel() {
+            return delegate.getChannel();
+        }
+
+        @Override public void setPerformancePreferences(int connectionTime, int latency,
+                int bandwidth) {
+            delegate.setPerformancePreferences(connectionTime, latency, bandwidth);
+        }
+    }
+
+    /**
+     * An SSLSocketFactory that delegates calls. It keeps a record of any sockets created.
+     * If {@code disableTlsFallbackScsv} is {@code true}, sockets created by the delegate are
+     * wrapped so that requests to enable the {@link #TLS_FALLBACK_SCSV} cipher suite are recorded
+     * but not applied, bypassing server-side fallback checks on platforms that support it.
+     * Unfortunately the wrapping disables any reflection-based calls to SSLSocket from Platform.
+     */
+    private static class FallbackTestClientSocketFactory extends DelegatingSSLSocketFactory {
+        /**
+         * The cipher suite used during TLS connection fallback to indicate a fallback.
+         * See https://tools.ietf.org/html/draft-ietf-tls-downgrade-scsv-00
+         */
+        public static final String TLS_FALLBACK_SCSV = "TLS_FALLBACK_SCSV";
+
+        private final boolean disableTlsFallbackScsv;
         private final List<SSLSocket> createdSockets = new ArrayList<SSLSocket>();
 
-        private RecordingSocketFactory(SSLSocketFactory delegate) {
+        public FallbackTestClientSocketFactory(SSLSocketFactory delegate,
+                boolean disableTlsFallbackScsv) {
             super(delegate);
+            this.disableTlsFallbackScsv = disableTlsFallbackScsv;
         }
 
-        @Override
-        public Socket createSocket(Socket s, String host, int port, boolean autoClose)
+        @Override public SSLSocket createSocket(Socket s, String host, int port, boolean autoClose)
                 throws IOException {
-            SSLSocket socket = (SSLSocket) delegate.createSocket(s, host, port, autoClose);
+            SSLSocket socket = super.createSocket(s, host, port, autoClose);
+            if (disableTlsFallbackScsv) {
+                socket = new TlsFallbackDisabledScsvSSLSocket(socket);
+            }
             createdSockets.add(socket);
             return socket;
         }
 
-        @Override
-        public Socket createSocket() throws IOException {
-            SSLSocket socket = (SSLSocket) delegate.createSocket();
+        @Override public SSLSocket createSocket() throws IOException {
+            SSLSocket socket = super.createSocket();
+            if (disableTlsFallbackScsv) {
+                socket = new TlsFallbackDisabledScsvSSLSocket(socket);
+            }
             createdSockets.add(socket);
             return socket;
         }
 
-        @Override
-        public Socket createSocket(String host, int port) throws IOException, UnknownHostException {
-            SSLSocket socket = (SSLSocket) delegate.createSocket(host, port);
+        @Override public SSLSocket createSocket(String host, int port) throws IOException {
+            SSLSocket socket = super.createSocket(host, port);
+            if (disableTlsFallbackScsv) {
+                socket = new TlsFallbackDisabledScsvSSLSocket(socket);
+            }
             createdSockets.add(socket);
             return socket;
         }
 
-        @Override
-        public Socket createSocket(String host, int port, InetAddress localHost,
-                int localPort) throws IOException, UnknownHostException {
-            SSLSocket socket = (SSLSocket) delegate.createSocket(host, port, localHost, localPort);
+        @Override public SSLSocket createSocket(String host, int port, InetAddress localHost,
+                int localPort) throws IOException {
+            SSLSocket socket = super.createSocket(host, port, localHost, localPort);
+            if (disableTlsFallbackScsv) {
+                socket = new TlsFallbackDisabledScsvSSLSocket(socket);
+            }
             createdSockets.add(socket);
             return socket;
         }
 
-        @Override
-        public Socket createSocket(InetAddress host, int port) throws IOException {
-            SSLSocket socket = (SSLSocket) delegate.createSocket(host, port);
+        @Override public SSLSocket createSocket(InetAddress host, int port) throws IOException {
+            SSLSocket socket = super.createSocket(host, port);
+            if (disableTlsFallbackScsv) {
+                socket = new TlsFallbackDisabledScsvSSLSocket(socket);
+            }
             createdSockets.add(socket);
             return socket;
         }
 
-        @Override
-        public Socket createSocket(InetAddress address, int port,
+        @Override public SSLSocket createSocket(InetAddress address, int port,
                 InetAddress localAddress, int localPort) throws IOException {
-            SSLSocket socket =
-                    (SSLSocket) delegate.createSocket(address, port, localAddress, localPort);
+            SSLSocket socket = super.createSocket(address, port, localAddress, localPort);
+            if (disableTlsFallbackScsv) {
+                socket = new TlsFallbackDisabledScsvSSLSocket(socket);
+            }
             createdSockets.add(socket);
             return socket;
         }
@@ -2646,4 +2930,49 @@
         }
     }
 
+    /**
+     * An {@link SSLSocket} wrapper that never enables
+     * {@link FallbackTestClientSocketFactory#TLS_FALLBACK_SCSV} on its delegate, but records
+     * whether a caller asked for it, so a test can tell that fallback signalling was requested
+     * without the server's downgrade check rejecting the connection.
+     */
+    private static class TlsFallbackDisabledScsvSSLSocket extends DelegatingSSLSocket {
+
+        /** Set once any setEnabledCipherSuites call includes TLS_FALLBACK_SCSV; never reset. */
+        private boolean tlsFallbackScsvSet;
+
+        public TlsFallbackDisabledScsvSSLSocket(SSLSocket socket) {
+            super(socket);
+        }
+
+        /**
+         * Enables {@code suites} on the delegate, except that TLS_FALLBACK_SCSV is filtered out
+         * and only recorded.
+         */
+        @Override public void setEnabledCipherSuites(String[] suites) {
+            if (suites == null) {
+                // Let the delegate reject null as SSLSocket specifies (IllegalArgumentException)
+                // instead of failing here with a NullPointerException.
+                delegate.setEnabledCipherSuites(null);
+                return;
+            }
+            List<String> enabledCipherSuites = new ArrayList<String>(suites.length);
+            for (String suite : suites) {
+                if (FallbackTestClientSocketFactory.TLS_FALLBACK_SCSV.equals(suite)) {
+                    // Record that an attempt was made to set TLS_FALLBACK_SCSV, but don't actually
+                    // set it.
+                    tlsFallbackScsvSet = true;
+                } else {
+                    enabledCipherSuites.add(suite);
+                }
+            }
+            delegate.setEnabledCipherSuites(
+                    enabledCipherSuites.toArray(new String[enabledCipherSuites.size()]));
+        }
+
+        /** Returns whether any call to setEnabledCipherSuites has included TLS_FALLBACK_SCSV. */
+        public boolean wasTlsFallbackScsvSet() {
+            return tlsFallbackScsvSet;
+        }
+    }
 }