Fix MailTransport#open() behavior

Use designated network and create socket with hostname while using SSL
VVM.

VVM currently ignores designated network if SSL is activated. Although
all carriers supported now which requires designated network don't use
SSL, it might be a problem in the future.

SSL socket created without hostname caused NPE in b/26144676. The bug is
fixed on security side but it is still recommended to have the hostname
passed in.

+ Static Java Library mockito-target and class MockitoHelper for testing

Bug:26268339
Bug:26270166
Change-Id: If9c82381916393427e932cc28089e1efa1e3b040
diff --git a/src/com/android/phone/common/mail/MailTransport.java b/src/com/android/phone/common/mail/MailTransport.java
index 172d1a9..a303036 100644
--- a/src/com/android/phone/common/mail/MailTransport.java
+++ b/src/com/android/phone/common/mail/MailTransport.java
@@ -18,20 +18,10 @@
 import android.content.Context;
 import android.net.Network;
 
+import com.android.internal.annotations.VisibleForTesting;
 import com.android.phone.common.mail.store.ImapStore;
 import com.android.phone.common.mail.utils.LogUtils;
 
-import java.net.SocketAddress;
-import java.util.ArrayList;
-import java.util.List;
-
-import javax.net.ssl.HostnameVerifier;
-import javax.net.ssl.HttpsURLConnection;
-import javax.net.ssl.SSLException;
-import javax.net.ssl.SSLPeerUnverifiedException;
-import javax.net.ssl.SSLSession;
-import javax.net.ssl.SSLSocket;
-
 import java.io.BufferedInputStream;
 import java.io.BufferedOutputStream;
 import java.io.IOException;
@@ -40,6 +30,15 @@
 import java.net.InetAddress;
 import java.net.InetSocketAddress;
 import java.net.Socket;
+import java.util.ArrayList;
+import java.util.List;
+
+import javax.net.ssl.HostnameVerifier;
+import javax.net.ssl.HttpsURLConnection;
+import javax.net.ssl.SSLException;
+import javax.net.ssl.SSLPeerUnverifiedException;
+import javax.net.ssl.SSLSession;
+import javax.net.ssl.SSLSocket;
 
 /**
  * Make connection and perform operations on mail server by reading and writing lines.
@@ -62,6 +61,7 @@
     private BufferedInputStream mIn;
     private BufferedOutputStream mOut;
     private int mFlags;
+    private SocketCreator mSocketCreator;
 
     public MailTransport(Context context, Network network, String address, int port, int flags) {
         mContext = context;
@@ -92,44 +92,57 @@
      * Attempts to open a connection using the Uri supplied for connection parameters.  Will attempt
      * an SSL connection if indicated.
      */
-    public void open() throws MessagingException, CertificateValidationException {
+    public void open() throws MessagingException {
         LogUtils.d(TAG, "*** IMAP open " + mHost + ":" + String.valueOf(mPort));
 
-        List<SocketAddress> socketAddresses = new ArrayList<SocketAddress>();
-        try {
-            if (canTrySslSecurity()) {
-                mSocket = HttpsURLConnection.getDefaultSSLSocketFactory().createSocket();
-                socketAddresses.add(new InetSocketAddress(mHost, mPort));
-            } else {
-                if (mNetwork == null) {
-                    mSocket = new Socket();
-                    socketAddresses.add(new InetSocketAddress(mHost, mPort));
-                } else {
-                    InetAddress[] inetAddresses = mNetwork.getAllByName(mHost);
-                    for (int i = 0; i < inetAddresses.length; i++) {
-                        socketAddresses.add(new InetSocketAddress(inetAddresses[i], mPort));
-                    }
-                    mSocket = mNetwork.getSocketFactory().createSocket();
+        List<InetSocketAddress> socketAddresses = new ArrayList<InetSocketAddress>();
+
+        if (mNetwork == null) {
+            socketAddresses.add(new InetSocketAddress(mHost, mPort));
+        } else {
+            try {
+                InetAddress[] inetAddresses = mNetwork.getAllByName(mHost);
+                if (inetAddresses.length == 0) {
+                    throw new MessagingException(MessagingException.IOERROR,
+                            "Host name " + mHost + "cannot be resolved on designated network");
                 }
+                for (int i = 0; i < inetAddresses.length; i++) {
+                    socketAddresses.add(new InetSocketAddress(inetAddresses[i], mPort));
+                }
+            } catch (IOException ioe) {
+                LogUtils.d(TAG, ioe.toString());
+                throw new MessagingException(MessagingException.IOERROR, ioe.toString());
             }
-        } catch (IOException ioe) {
-            LogUtils.d(TAG, ioe.toString());
-            throw new MessagingException(MessagingException.IOERROR, ioe.toString());
         }
 
+        boolean success = false;
         while (socketAddresses.size() > 0) {
+            mSocket = createSocket();
             try {
-                mSocket.connect(socketAddresses.remove(0), SOCKET_CONNECT_TIMEOUT);
+                InetSocketAddress address = socketAddresses.remove(0);
+                mSocket.connect(address, SOCKET_CONNECT_TIMEOUT);
 
-                // After the socket connects to an SSL server, confirm that the hostname is as
-                // expected
-                if (canTrySslSecurity() && !canTrustAllCertificates()) {
-                    verifyHostname(mSocket, mHost);
+                if (canTrySslSecurity()) {
+                    /**
+                     * {@link SSLSocket} must connect in its constructor, or create through a
+                     * already connected socket. Since we need to use
+                     * {@link Socket#connect(SocketAddress, int) } to set timeout, we can only
+                     * create it here.
+                     */
+                    LogUtils.d(TAG, "open: converting to SSL socket");
+                    mSocket = HttpsURLConnection.getDefaultSSLSocketFactory()
+                            .createSocket(mSocket, address.getHostName(), address.getPort(), true);
+                    // After the socket connects to an SSL server, confirm that the hostname is as
+                    // expected
+                    if (!canTrustAllCertificates()) {
+                        verifyHostname(mSocket, mHost);
+                    }
                 }
 
                 mIn = new BufferedInputStream(mSocket.getInputStream(), 1024);
                 mOut = new BufferedOutputStream(mSocket.getOutputStream(), 512);
                 mSocket.setSoTimeout(SOCKET_READ_TIMEOUT);
+                success = true;
                 return;
             } catch (IOException ioe) {
                 LogUtils.d(TAG, ioe.toString());
@@ -137,10 +150,50 @@
                     // Only throw an error when there are no more sockets to try.
                     throw new MessagingException(MessagingException.IOERROR, ioe.toString());
                 }
+            } finally {
+                if (!success) {
+                    try {
+                        mSocket.close();
+                        mSocket = null;
+                    } catch (IOException ioe) {
+                        throw new MessagingException(MessagingException.IOERROR, ioe.toString());
+                    }
+
+                }
             }
         }
     }
 
+    // For testing. We need something that can replace the behavior of "new Socket()"
+    @VisibleForTesting
+    interface SocketCreator {
+        Socket createSocket() throws MessagingException;
+    }
+
+    @VisibleForTesting
+    void setSocketCreator(SocketCreator creator) {
+        mSocketCreator = creator;
+    }
+
+    protected Socket createSocket() throws MessagingException {
+        if (mSocketCreator != null) {
+            return mSocketCreator.createSocket();
+        }
+
+        if (mNetwork == null) {
+            LogUtils.v(TAG, "createSocket: network not specified");
+            return new Socket();
+        }
+
+        try {
+            LogUtils.v(TAG, "createSocket: network specified");
+            return mNetwork.getSocketFactory().createSocket();
+        } catch (IOException ioe) {
+            LogUtils.d(TAG, ioe.toString());
+            throw new MessagingException(MessagingException.IOERROR, ioe.toString());
+        }
+    }
+
     /**
      * Lightweight version of SSLCertificateSocketFactory.verifyHostname, which provides this
      * service but is not in the public API.
@@ -173,8 +226,8 @@
         // in the verifier code and is not available in the verifier API, and extracting the
         // CN & alts is beyond the scope of this patch.
         if (!HOSTNAME_VERIFIER.verify(hostname, session)) {
-            throw new SSLPeerUnverifiedException(
-                    "Certificate hostname not useable for server: " + hostname);
+            throw new SSLPeerUnverifiedException("Certificate hostname not useable for server: "
+                    + session.getPeerPrincipal());
         }
     }
 
diff --git a/tests/Android.mk b/tests/Android.mk
index a3a657b..6cc0355 100644
--- a/tests/Android.mk
+++ b/tests/Android.mk
@@ -27,4 +27,8 @@
 
 LOCAL_INSTRUMENTATION_FOR := TeleService
 
+LOCAL_STATIC_JAVA_LIBRARIES := \
+        android-support-test \
+        mockito-target
+
 include $(BUILD_PACKAGE)
diff --git a/tests/src/com/android/phone/MockitoHelper.java b/tests/src/com/android/phone/MockitoHelper.java
new file mode 100644
index 0000000..3da5d6e
--- /dev/null
+++ b/tests/src/com/android/phone/MockitoHelper.java
@@ -0,0 +1,54 @@
+/*
+ * Copyright (C) 2016 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.phone;
+
+import com.android.services.telephony.Log;
+
+/**
+ * Helper for Mockito-based test cases.
+ */
+public final class MockitoHelper {
+
+    private static final String TAG = "MockitoHelper";
+
+    private ClassLoader mOriginalClassLoader;
+    private Thread mContextThread;
+
+    /**
+     * Creates a new helper, which in turn will set the context classloader so it can load Mockito
+     * resources.
+     *
+     * @param packageClass test case class
+     */
+    public void setUp(Class<?> packageClass) throws Exception {
+        // makes a copy of the context classloader
+        mContextThread = Thread.currentThread();
+        mOriginalClassLoader = mContextThread.getContextClassLoader();
+        ClassLoader newClassLoader = packageClass.getClassLoader();
+        Log.v(TAG, "Changing context classloader from " + mOriginalClassLoader
+                + " to " + newClassLoader);
+        mContextThread.setContextClassLoader(newClassLoader);
+    }
+
+    /**
+     * Restores the context classloader to the previous value.
+     */
+    public void tearDown() throws Exception {
+        Log.v(TAG, "Restoring context classloader to " + mOriginalClassLoader);
+        mContextThread.setContextClassLoader(mOriginalClassLoader);
+    }
+}
\ No newline at end of file
diff --git a/tests/src/com/android/phone/common/mail/MailTransportTest.java b/tests/src/com/android/phone/common/mail/MailTransportTest.java
new file mode 100644
index 0000000..1fd6596
--- /dev/null
+++ b/tests/src/com/android/phone/common/mail/MailTransportTest.java
@@ -0,0 +1,389 @@
+/*
+ * Copyright (C) 2016 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License
+ */
+
+package com.android.phone.common.mail;
+
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+import android.net.Network;
+import android.test.AndroidTestCase;
+
+import com.android.phone.MockitoHelper;
+import com.android.phone.common.mail.MailTransport.SocketCreator;
+import com.android.phone.common.mail.store.ImapStore;
+
+import junit.framework.AssertionFailedError;
+
+import org.mockito.MockitoAnnotations;
+
+import java.io.IOException;
+import java.net.InetAddress;
+import java.net.InetSocketAddress;
+import java.net.Socket;
+import java.net.SocketAddress;
+import java.net.SocketException;
+import java.net.UnknownHostException;
+
+import javax.net.SocketFactory;
+
+public class MailTransportTest extends AndroidTestCase {
+
+    private static final String HOST_ADDRESS = "127.0.0.1";
+    private static final String INVALID_HOST_ADDRESS = "255.255.255.255";
+    private static final int HOST_PORT = 80;
+    private static final int HOST_FLAGS = 0;
+    // bypass verifyHostname() in open() by setting ImapStore.FLAG_TRUST_ALL
+    private static final int HOST_FLAGS_SSL = ImapStore.FLAG_SSL & ImapStore.FLAG_TRUST_ALL;
+    private static final InetAddress VALID_INET_ADDRESS = createInetAddress(HOST_ADDRESS);
+    private static final InetAddress INVALID_INET_ADDRESS = createInetAddress(INVALID_HOST_ADDRESS);
+
+    // ClassLoader need to be replaced for mockito to work.
+    private MockitoHelper mMokitoHelper = new MockitoHelper();
+
+    @Override
+    public void setUp() throws Exception {
+        super.setUp();
+        mMokitoHelper.setUp(getClass());
+        MockitoAnnotations.initMocks(this);
+    }
+
+    @Override
+    public void tearDown() throws Exception {
+        mMokitoHelper.tearDown();
+        super.tearDown();
+    }
+
+    public void testCreateSocket_anyNetwork() throws MessagingException {
+        // With no network, Socket#Socket() should be called.
+        MailTransport transport =
+                new MailTransport(getContext(), null, HOST_ADDRESS, HOST_PORT, HOST_FLAGS);
+        Socket socket = transport.createSocket();
+        assertTrue(socket != null);
+    }
+
+    public void testCreateSocket_networkSpecified() throws MessagingException, IOException {
+        // Network#getSocketFactory should be used to create socket.
+        Network mockNetwork = createMockNetwork();
+        MailTransport transport =
+                new MailTransport(getContext(), mockNetwork, HOST_ADDRESS, HOST_PORT, HOST_FLAGS);
+        Socket socket = transport.createSocket();
+        assertTrue(socket != null);
+        verify(mockNetwork).getSocketFactory();
+    }
+
+    public void testCreateSocket_socketCreator() throws MessagingException, IOException {
+        // For testing purposes, how sockets are created can be overridden.
+        SocketCreator socketCreator = new SocketCreator() {
+
+            private final Socket mSocket = new Socket();
+
+            @Override
+            public Socket createSocket() {
+                return mSocket;
+            }
+        };
+
+        MailTransport transport = new
+                MailTransport(getContext(), null, HOST_ADDRESS, HOST_PORT, HOST_FLAGS);
+
+        transport.setSocketCreator(socketCreator);
+
+        Socket socket = transport.createSocket();
+        assertTrue(socket == socketCreator.createSocket());
+    }
+
+    public void testOpen() throws MessagingException {
+        MailTransport transport = new MailTransport(getContext(), null, HOST_ADDRESS,
+                HOST_PORT, HOST_FLAGS);
+        transport.setSocketCreator(new TestSocketCreator());
+        transport.open();
+        assertTrue(transport.isOpen());
+
+    }
+
+    public void testOpen_Ssl() throws MessagingException {
+        //opening with ssl support.
+        MailTransport transport = new MailTransport(getContext(), null, HOST_ADDRESS,
+                HOST_PORT, HOST_FLAGS_SSL);
+        transport.setSocketCreator(new TestSocketCreator());
+        transport.open();
+        assertTrue(transport.isOpen());
+
+    }
+
+    public void testOpen_MultiIp() throws MessagingException {
+        //In case of round robin DNS, try all resolved address until one succeeded.
+        Network network = createMultiIpMockNetwork();
+        MailTransport transport = new MailTransport(getContext(), network, HOST_ADDRESS,
+                HOST_PORT, HOST_FLAGS);
+        transport.setSocketCreator(new TestSocketCreator());
+        transport.open();
+        assertTrue(transport.isOpen());
+    }
+
+    public void testOpen_MultiIp_SSL() throws MessagingException {
+        Network network = createMultiIpMockNetwork();
+
+        MailTransport transport = new MailTransport(getContext(), network, HOST_ADDRESS,
+                HOST_PORT, HOST_FLAGS_SSL);
+        transport.setSocketCreator(new TestSocketCreator());
+        transport.open();
+        assertTrue(transport.isOpen());
+    }
+
+    public void testOpen_network_hostResolutionFailed() {
+        // Couldn't resolve host on the network. Open() should fail.
+        Network network = createMockNetwork();
+        try {
+            when(network.getAllByName(HOST_ADDRESS))
+                    .thenThrow(new UnknownHostException("host resolution failed"));
+        } catch (IOException e) {
+            //ignored
+        }
+
+        MailTransport transport = new MailTransport(getContext(), network, HOST_ADDRESS,
+                HOST_PORT, HOST_FLAGS);
+        try {
+            transport.open();
+            throw new AssertionFailedError("Should throw MessagingException");
+        } catch (MessagingException e) {
+            //expected
+        }
+        assertFalse(transport.isOpen());
+    }
+
+    public void testOpen_createSocketFailed() {
+        // Unable to create socket. Open() should fail.
+        MailTransport transport = new MailTransport(getContext(), null, HOST_ADDRESS,
+                HOST_PORT, HOST_FLAGS);
+        transport.setSocketCreator(new SocketCreator() {
+            @Override
+            public Socket createSocket() throws MessagingException {
+                throw new MessagingException("createSocket failed");
+            }
+        });
+        try {
+            transport.open();
+            throw new AssertionFailedError("Should throw MessagingException");
+        } catch (MessagingException e) {
+            //expected
+        }
+        assertFalse(transport.isOpen());
+    }
+
+    public void testOpen_network_createSocketFailed() {
+        // Unable to create socket. Open() should fail.
+
+        Network network = createOneIpMockNetwork();
+        SocketFactory mockSocketFactory = mock(SocketFactory.class);
+        try {
+            when(mockSocketFactory.createSocket())
+                    .thenThrow(new IOException("unable to create socket"));
+        } catch (IOException e) {
+            //ignored
+        }
+        when(network.getSocketFactory()).thenReturn(mockSocketFactory);
+
+        MailTransport transport = new MailTransport(getContext(), network, HOST_ADDRESS,
+                HOST_PORT, HOST_FLAGS);
+
+        try {
+            transport.open();
+            throw new AssertionFailedError("Should throw MessagingException");
+        } catch (MessagingException e) {
+            //expected
+        }
+        assertFalse(transport.isOpen());
+    }
+
+    public void testOpen_connectFailed_one() {
+        // There is only one IP for this host, and we failed to connect to it. Open() should fail.
+
+        MailTransport transport = new MailTransport(getContext(), null, HOST_ADDRESS,
+                HOST_PORT, HOST_FLAGS);
+        transport.setSocketCreator(new SocketCreator() {
+            @Override
+            public Socket createSocket() throws MessagingException {
+                return new Socket() {
+                    @Override
+                    public void connect(SocketAddress address, int timeout) throws IOException {
+                        throw new IOException("connect failed");
+                    }
+                };
+            }
+        });
+        try {
+            transport.open();
+            throw new AssertionFailedError("Should throw MessagingException");
+        } catch (MessagingException e) {
+            //expected
+        }
+        assertFalse(transport.isOpen());
+    }
+
+    public void testOpen_connectFailed_multi() {
+        // There are multiple IP for this host, and we failed to connect to any of it.
+        // Open() should fail.
+        MailTransport transport = new MailTransport(getContext(), createMultiIpMockNetwork(),
+                HOST_ADDRESS,
+                HOST_PORT, HOST_FLAGS);
+        transport.setSocketCreator(new SocketCreator() {
+            @Override
+            public Socket createSocket() throws MessagingException {
+                return new Socket() {
+                    @Override
+                    public void connect(SocketAddress address, int timeout) throws IOException {
+                        throw new IOException("connect failed");
+                    }
+                };
+            }
+        });
+        try {
+            transport.open();
+            throw new AssertionFailedError("Should throw MessagingException");
+        } catch (MessagingException e) {
+            //expected
+        }
+        assertFalse(transport.isOpen());
+    }
+
+    private class TestSocket extends Socket {
+
+        boolean mConnected = false;
+
+
+        /**
+         * A make a mock connection to the address.
+         *
+         * @param address Only address equivalent to VALID_INET_ADDRESS or INVALID_INET_ADDRESS is
+         * accepted
+         * @param timeout Ignored but should >= 0.
+         */
+        @Override
+        public void connect(SocketAddress address, int timeout) throws IOException {
+            // copied from Socket#connect
+            if (isClosed()) {
+                throw new SocketException("Socket is closed");
+            }
+            if (timeout < 0) {
+                throw new IllegalArgumentException("timeout < 0");
+            }
+            if (isConnected()) {
+                throw new SocketException("Already connected");
+            }
+            if (address == null) {
+                throw new IllegalArgumentException("remoteAddr == null");
+            }
+
+            if (!(address instanceof InetSocketAddress)) {
+                throw new AssertionError("address should be InetSocketAddress");
+            }
+
+            InetSocketAddress inetSocketAddress = (InetSocketAddress) address;
+            if (inetSocketAddress.getAddress().equals(INVALID_INET_ADDRESS)) {
+                throw new IOException("invalid address");
+            } else if (inetSocketAddress.getAddress().equals(VALID_INET_ADDRESS)) {
+                mConnected = true;
+            } else {
+                throw new AssertionError("Only INVALID_ADDRESS or VALID_ADDRESS are allowed");
+            }
+        }
+
+        @Override
+        public boolean isConnected() {
+            return mConnected;
+        }
+
+    }
+
+
+    private class TestSocketCreator implements MailTransport.SocketCreator {
+
+        @Override
+        public Socket createSocket() throws MessagingException {
+            Socket socket = new TestSocket();
+            return socket;
+        }
+
+    }
+
+    /**
+     * @return a mock Network that can create a TestSocket with {@code getSocketFactory()
+     * .createSocket()}
+     */
+    private Network createMockNetwork() {
+        Network network = mock(Network.class);
+        SocketFactory mockSocketFactory = mock(SocketFactory.class);
+        try {
+            when(mockSocketFactory.createSocket()).thenReturn(new TestSocket());
+        } catch (IOException e) {
+            //ignored
+        }
+        when(network.getSocketFactory()).thenReturn(mockSocketFactory);
+        return network;
+    }
+
+    /**
+     * @return a mock Network like {@link MailTransportTest#createMockNetwork()}, but also supports
+     * {@link Network#getAllByName(String)} with one valid result.
+     */
+    private Network createOneIpMockNetwork() {
+        Network network = createMockNetwork();
+        try {
+            when(network.getAllByName(HOST_ADDRESS))
+                    .thenReturn(new InetAddress[] {VALID_INET_ADDRESS});
+        } catch (UnknownHostException e) {
+            //ignored
+        }
+
+        return network;
+    }
+
+    /**
+     * @return a mock Network like {@link MailTransportTest#createMockNetwork()}, but also supports
+     * {@link Network#getAllByName(String)}, which will return 2 address with the first one
+     * invalid.
+     */
+    private Network createMultiIpMockNetwork() {
+        Network network = createMockNetwork();
+        try {
+            when(network.getAllByName(HOST_ADDRESS))
+                    .thenReturn(new InetAddress[] {INVALID_INET_ADDRESS, VALID_INET_ADDRESS});
+        } catch (UnknownHostException e) {
+            //ignored
+        }
+
+        return network;
+    }
+
+    /**
+     * helper method to translate{@code host} into a InetAddress.
+     *
+     * @param host IP address of the host. Domain name should not be used as this method should not
+     * access the internet.
+     */
+    private static InetAddress createInetAddress(String host) {
+        try {
+            return InetAddress.getByName(host);
+        } catch (UnknownHostException e) {
+            return null;
+        }
+    }
+
+
+}