Adding engine support to server socket. (#81)
* Adding engine support to server socket.
Also adding simple unit test for server socket. This required moving some more test classes to the testing module.
* fixing import.
diff --git a/build.gradle b/build.gradle
index 3ee0373..9fb3a1f 100644
--- a/build.gradle
+++ b/build.gradle
@@ -70,7 +70,7 @@
// Test dependencies.
guava : 'com.google.guava:guava:19.0',
- junit : 'junit:junit:4.11',
+ junit : 'junit:junit:4.12',
mockito: 'org.mockito:mockito-core:1.9.5',
truth : 'com.google.truth:truth:0.28',
bouncycastle_provider: 'org.bouncycastle:bcprov-jdk15on:1.55',
diff --git a/common/src/main/java/org/conscrypt/OpenSSLEngineSocketImpl.java b/common/src/main/java/org/conscrypt/OpenSSLEngineSocketImpl.java
index 0ad0816..dcf27f1 100644
--- a/common/src/main/java/org/conscrypt/OpenSSLEngineSocketImpl.java
+++ b/common/src/main/java/org/conscrypt/OpenSSLEngineSocketImpl.java
@@ -178,17 +178,17 @@
@Override
public void setChannelIdEnabled(boolean enabled) {
- throw new UnsupportedOperationException("Not supported");
+ super.setChannelIdEnabled(enabled);
}
@Override
public byte[] getChannelId() throws SSLException {
- throw new UnsupportedOperationException("Not supported");
+ return super.getChannelId();
}
@Override
public void setChannelIdPrivateKey(PrivateKey privateKey) {
- throw new UnsupportedOperationException("FIXME");
+ super.setChannelIdPrivateKey(privateKey);
}
@Override
@@ -239,7 +239,6 @@
@Override
public int getSoWriteTimeout() throws SocketException {
return 0;
- //throw new UnsupportedOperationException("Not supported");
}
@Override
@@ -519,6 +518,10 @@
needMoreData = false;
break;
}
+ case CLOSED: {
+ // EOF
+ return -1;
+ }
default: {
// Anything else is an error.
throw new SSLException(
diff --git a/common/src/main/java/org/conscrypt/OpenSSLServerSocketFactoryImpl.java b/common/src/main/java/org/conscrypt/OpenSSLServerSocketFactoryImpl.java
index 25819ed..74444b8 100644
--- a/common/src/main/java/org/conscrypt/OpenSSLServerSocketFactoryImpl.java
+++ b/common/src/main/java/org/conscrypt/OpenSSLServerSocketFactoryImpl.java
@@ -22,9 +22,11 @@
import java.security.KeyManagementException;
public class OpenSSLServerSocketFactoryImpl extends javax.net.ssl.SSLServerSocketFactory {
+ private static boolean useEngineSocketByDefault = SSLUtils.USE_ENGINE_SOCKET_BY_DEFAULT;
private SSLParametersImpl sslParameters;
private IOException instantiationException;
+ private boolean useEngineSocket = useEngineSocketByDefault;
public OpenSSLServerSocketFactoryImpl() {
try {
@@ -42,6 +44,21 @@
this.sslParameters.setUseClientMode(false);
}
+ /**
+ * Configures the default socket to be created for all instances.
+ */
+ public static void setUseEngineSocketByDefault(boolean useEngineSocket) {
+ useEngineSocketByDefault = useEngineSocket;
+ }
+
+ /**
+ * Configures the socket to be created for this instance. If not called,
+ * {@link #useEngineSocketByDefault} will be used.
+ */
+ public void setUseEngineSocket(boolean useEngineSocket) {
+ this.useEngineSocket = useEngineSocket;
+ }
+
@Override
public String[] getDefaultCipherSuites() {
return sslParameters.getEnabledCipherSuites();
@@ -54,29 +71,27 @@
@Override
public ServerSocket createServerSocket() throws IOException {
- return new OpenSSLServerSocketImpl((SSLParametersImpl) sslParameters.clone());
+ return new OpenSSLServerSocketImpl((SSLParametersImpl) sslParameters.clone())
+ .setUseEngineSocket(useEngineSocket);
}
@Override
public ServerSocket createServerSocket(int port) throws IOException {
- return new OpenSSLServerSocketImpl(port, (SSLParametersImpl) sslParameters.clone());
+ return new OpenSSLServerSocketImpl(port, (SSLParametersImpl) sslParameters.clone())
+ .setUseEngineSocket(useEngineSocket);
}
@Override
- public ServerSocket createServerSocket(int port, int backlog)
+ public ServerSocket createServerSocket(int port, int backlog) throws IOException {
+ return new OpenSSLServerSocketImpl(port, backlog, (SSLParametersImpl) sslParameters.clone())
+ .setUseEngineSocket(useEngineSocket);
+ }
+
+ @Override
+ public ServerSocket createServerSocket(int port, int backlog, InetAddress iAddress)
throws IOException {
- return new OpenSSLServerSocketImpl(port,
- backlog,
- (SSLParametersImpl) sslParameters.clone());
- }
-
- @Override
- public ServerSocket createServerSocket(int port,
- int backlog,
- InetAddress iAddress) throws IOException {
- return new OpenSSLServerSocketImpl(port,
- backlog,
- iAddress,
- (SSLParametersImpl) sslParameters.clone());
+ return new OpenSSLServerSocketImpl(
+ port, backlog, iAddress, (SSLParametersImpl) sslParameters.clone())
+ .setUseEngineSocket(useEngineSocket);
}
}
diff --git a/common/src/main/java/org/conscrypt/OpenSSLServerSocketImpl.java b/common/src/main/java/org/conscrypt/OpenSSLServerSocketImpl.java
index 215da55..f031c6c 100644
--- a/common/src/main/java/org/conscrypt/OpenSSLServerSocketImpl.java
+++ b/common/src/main/java/org/conscrypt/OpenSSLServerSocketImpl.java
@@ -26,6 +26,7 @@
public class OpenSSLServerSocketImpl extends javax.net.ssl.SSLServerSocket {
private final SSLParametersImpl sslParameters;
private boolean channelIdEnabled;
+ private boolean useEngineSocket;
protected OpenSSLServerSocketImpl(SSLParametersImpl sslParameters) throws IOException {
this.sslParameters = sslParameters;
@@ -52,6 +53,14 @@
this.sslParameters = sslParameters;
}
+ /**
+ * Configures the socket to be created for this instance.
+ */
+ public OpenSSLServerSocketImpl setUseEngineSocket(boolean useEngineSocket) {
+ this.useEngineSocket = useEngineSocket;
+ return this;
+ }
+
@Override
public boolean getEnableSessionCreation() {
return sslParameters.getEnableSessionCreation();
@@ -165,9 +174,21 @@
@Override
public Socket accept() throws IOException {
- OpenSSLSocketImpl socket = new OpenSSLSocketImpl(sslParameters);
- socket.setChannelIdEnabled(channelIdEnabled);
- implAccept(socket);
- return socket;
+ if (useEngineSocket) {
+ Socket rawSocket = new Socket();
+ implAccept(rawSocket);
+
+ // Enable channel ID.
+ OpenSSLEngineSocketImpl socket =
+ new OpenSSLEngineSocketImpl(rawSocket, null, -1, true, sslParameters);
+ socket.setChannelIdEnabled(channelIdEnabled);
+ socket.startHandshake();
+ return socket;
+ } else {
+ OpenSSLSocketImpl socket = new OpenSSLSocketImpl(sslParameters);
+ socket.setChannelIdEnabled(channelIdEnabled);
+ implAccept(socket);
+ return socket;
+ }
}
}
diff --git a/common/src/main/java/org/conscrypt/OpenSSLSocketFactoryImpl.java b/common/src/main/java/org/conscrypt/OpenSSLSocketFactoryImpl.java
index c2272d5..35c178e 100644
--- a/common/src/main/java/org/conscrypt/OpenSSLSocketFactoryImpl.java
+++ b/common/src/main/java/org/conscrypt/OpenSSLSocketFactoryImpl.java
@@ -23,8 +23,7 @@
import java.security.KeyManagementException;
public class OpenSSLSocketFactoryImpl extends javax.net.ssl.SSLSocketFactory {
- private static boolean useEngineSocketByDefault =
- Boolean.parseBoolean(System.getProperty("org.conscrypt.useEngineSocketByDefault"));
+ private static boolean useEngineSocketByDefault = SSLUtils.USE_ENGINE_SOCKET_BY_DEFAULT;
private final SSLParametersImpl sslParameters;
private final IOException instantiationException;
diff --git a/common/src/main/java/org/conscrypt/SSLUtils.java b/common/src/main/java/org/conscrypt/SSLUtils.java
index 88fb4f0..b39a4fe 100644
--- a/common/src/main/java/org/conscrypt/SSLUtils.java
+++ b/common/src/main/java/org/conscrypt/SSLUtils.java
@@ -44,18 +44,21 @@
* Utility methods for SSL packet processing. Copied from the Netty project.
*/
final class SSLUtils {
+ static final boolean USE_ENGINE_SOCKET_BY_DEFAULT =
+ Boolean.parseBoolean(System.getProperty("org.conscrypt.useEngineSocketByDefault"));
+
/**
* Return how much bytes can be read out of the encrypted data. Be aware that this method will
* not
* increase the readerIndex of the given {@link ByteBuffer}.
*
* @param buffers The {@link ByteBuffer}s to read from. Be aware that they must have at least
- * {@link #SSL3_RT_HEADER_LENGTH} bytes to read, otherwise it will throw an {@link
- * IllegalArgumentException}.
+ * {@link org.conscrypt.NativeConstants#SSL3_RT_HEADER_LENGTH} bytes to read, otherwise it will
+ * throw an {@link IllegalArgumentException}.
* @return length The length of the encrypted packet that is included in the buffer. This will
* return {@code -1} if the given {@link ByteBuffer} is not encrypted at all.
* @throws IllegalArgumentException Is thrown if the given {@link ByteBuffer} has not at least
- * {@link #SSL3_RT_HEADER_LENGTH} bytes to read.
+ * {@link org.conscrypt.NativeConstants#SSL3_RT_HEADER_LENGTH} bytes to read.
*/
static int getEncryptedPacketLength(ByteBuffer[] buffers, int offset) {
ByteBuffer buffer = buffers[offset];
diff --git a/openjdk-benchmarks/src/jmh/java/org/conscrypt/benchmarks/ClientSocketBenchmark.java b/openjdk-benchmarks/src/jmh/java/org/conscrypt/benchmarks/ClientSocketBenchmark.java
index 2beb562..edb0e57 100644
--- a/openjdk-benchmarks/src/jmh/java/org/conscrypt/benchmarks/ClientSocketBenchmark.java
+++ b/openjdk-benchmarks/src/jmh/java/org/conscrypt/benchmarks/ClientSocketBenchmark.java
@@ -28,6 +28,8 @@
import javax.net.SocketFactory;
import javax.net.ssl.SSLSocket;
import javax.net.ssl.SSLSocketFactory;
+import org.conscrypt.testing.NettyEchoServer;
+import org.conscrypt.testing.TestClient;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.Param;
import org.openjdk.jmh.annotations.Scope;
@@ -44,7 +46,7 @@
/**
* Various factories for SSL sockets.
*/
- public enum SslSocketType {
+ public enum SslProvider {
JDK {
private final SSLSocketFactory socketFactory = getJdkSocketFactory();
@Override
@@ -94,7 +96,7 @@
abstract SSLSocket newSslSocket(String host, int port);
}
- @Param public SslSocketType sslSocketType;
+ @Param public SslProvider sslProvider;
@Param({"64", "128", "512", "1024", "4096"}) public int messageSize;
@@ -114,7 +116,7 @@
server = new NettyEchoServer(port, messageSize, cipher);
server.start();
- client = new TestClient(sslSocketType.newSslSocket(LOCALHOST, port, cipher));
+ client = new TestClient(sslProvider.newSslSocket(LOCALHOST, port, cipher));
client.start();
}
@@ -130,22 +132,4 @@
int numBytes = client.readMessage(response);
assertEquals(messageSize, numBytes);
}
-
- public static void main(String[] args) throws Exception {
- ClientSocketBenchmark bm = new ClientSocketBenchmark();
- bm.sslSocketType = SslSocketType.CONSCRYPT_ENGINE;
- bm.messageSize = 1024;
- bm.cipher = "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256";
- bm.setup();
- try {
- while (true) {
- if (Thread.interrupted()) {
- break;
- }
- bm.pingPong();
- }
- } finally {
- bm.teardown();
- }
- }
}
diff --git a/openjdk-benchmarks/src/jmh/java/org/conscrypt/benchmarks/ServerSocketBenchmark.java b/openjdk-benchmarks/src/jmh/java/org/conscrypt/benchmarks/ServerSocketBenchmark.java
index f96d685..f22489a 100644
--- a/openjdk-benchmarks/src/jmh/java/org/conscrypt/benchmarks/ServerSocketBenchmark.java
+++ b/openjdk-benchmarks/src/jmh/java/org/conscrypt/benchmarks/ServerSocketBenchmark.java
@@ -32,6 +32,8 @@
import javax.net.ssl.SSLServerSocketFactory;
import javax.net.ssl.SSLSocket;
import javax.net.ssl.SSLSocketFactory;
+import org.conscrypt.testing.EchoServer;
+import org.conscrypt.testing.TestClient;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.Param;
import org.openjdk.jmh.annotations.Scope;
@@ -48,13 +50,14 @@
/**
* Various factories for SSL server sockets.
*/
- public enum SslSocketType {
+ public enum SslProvider {
JDK(getJdkServerSocketFactory()),
- CONSCRYPT(getConscryptServerSocketFactory());
+ CONSCRYPT(getConscryptServerSocketFactory(false)),
+ CONSCRYPT_ENGINE(getConscryptServerSocketFactory(true));
private final SSLServerSocketFactory serverSocketFactory;
- SslSocketType(SSLServerSocketFactory serverSocketFactory) {
+ SslProvider(SSLServerSocketFactory serverSocketFactory) {
this.serverSocketFactory = serverSocketFactory;
}
@@ -72,7 +75,7 @@
}
}
- @Param public SslSocketType sslSocketType;
+ @Param public SslProvider sslProvider;
@Param({"64", "128", "512", "1024", "4096"}) public int messageSize;
@@ -88,7 +91,7 @@
message = newTextMessage(messageSize);
response = new byte[messageSize];
- server = new EchoServer(sslSocketType.newServerSocket(cipher), messageSize);
+ server = new EchoServer(sslProvider.newServerSocket(cipher), messageSize);
Future connectedFuture = server.start();
diff --git a/openjdk-benchmarks/src/main/java/org/conscrypt/benchmarks/EchoServer.java b/openjdk-testing/src/main/java/org/conscrypt/testing/EchoServer.java
similarity index 75%
rename from openjdk-benchmarks/src/main/java/org/conscrypt/benchmarks/EchoServer.java
rename to openjdk-testing/src/main/java/org/conscrypt/testing/EchoServer.java
index 25edaf6..771c3b4 100644
--- a/openjdk-benchmarks/src/main/java/org/conscrypt/benchmarks/EchoServer.java
+++ b/openjdk-testing/src/main/java/org/conscrypt/testing/EchoServer.java
@@ -14,7 +14,7 @@
* limitations under the License.
*/
-package org.conscrypt.benchmarks;
+package org.conscrypt.testing;
import java.io.IOException;
import java.util.Arrays;
@@ -28,7 +28,7 @@
/**
* Simple echo server that responds with an identical message to the one received.
*/
-final class EchoServer {
+public final class EchoServer {
private final ExecutorService executor = Executors.newSingleThreadExecutor();
private final SSLServerSocket serverSocket;
private final int messageSize;
@@ -36,31 +36,33 @@
private SSLSocket socket;
private volatile boolean stopping;
- EchoServer(SSLServerSocket serverSocket, int messageSize) {
+ public EchoServer(SSLServerSocket serverSocket, int messageSize) {
this.serverSocket = serverSocket;
this.messageSize = messageSize;
buffer = new byte[messageSize];
}
- Future<?> start() {
+ public Future<?> start() {
return executor.submit(new AcceptTask());
}
- void stop() {
+ public void stop() {
try {
stopping = true;
+ executor.shutdown();
+
if (socket != null) {
socket.close();
}
serverSocket.close();
- executor.shutdown();
+
executor.awaitTermination(5, TimeUnit.SECONDS);
} catch (IOException | InterruptedException e) {
throw new RuntimeException(e);
}
}
- int port() {
+ public int port() {
return serverSocket.getLocalPort();
}
@@ -76,36 +78,29 @@
if (stopping) {
return;
}
- executor.execute(new ReadTask());
+ executor.execute(new EchoTask());
} catch (IOException e) {
throw new RuntimeException(e);
}
}
}
- private final class ReadTask implements Runnable {
+ private final class EchoTask implements Runnable {
@Override
public void run() {
try {
- if (stopping) {
- return;
+ while (!stopping) {
+ byte[] output = readMessage();
+ if (!stopping) {
+ sendMessage(output);
+ }
}
- byte[] output = readMessage();
- sendMessage(output);
-
- if (stopping) {
- return;
- }
- // Keep running the task until it's being shut down.
- executor.execute(this);
} catch (Throwable e) {
throw new RuntimeException(e);
}
}
- }
- private byte[] readMessage() {
- try {
+ private byte[] readMessage() throws IOException {
int totalBytesRead = 0;
while (totalBytesRead < messageSize) {
int remaining = messageSize - totalBytesRead;
@@ -116,17 +111,15 @@
totalBytesRead += bytesRead;
}
return Arrays.copyOfRange(buffer, 0, totalBytesRead);
- } catch (Throwable e) {
- throw new RuntimeException(e);
}
- }
- private void sendMessage(byte[] data) {
- try {
- socket.getOutputStream().write(data);
- socket.getOutputStream().flush();
- } catch (IOException e) {
- throw new RuntimeException(e);
+ private void sendMessage(byte[] data) {
+ try {
+ socket.getOutputStream().write(data);
+ socket.getOutputStream().flush();
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
}
}
}
diff --git a/openjdk-benchmarks/src/main/java/org/conscrypt/benchmarks/NettyEchoServer.java b/openjdk-testing/src/main/java/org/conscrypt/testing/NettyEchoServer.java
similarity index 94%
rename from openjdk-benchmarks/src/main/java/org/conscrypt/benchmarks/NettyEchoServer.java
rename to openjdk-testing/src/main/java/org/conscrypt/testing/NettyEchoServer.java
index 7963696..352bd1b 100644
--- a/openjdk-benchmarks/src/main/java/org/conscrypt/benchmarks/NettyEchoServer.java
+++ b/openjdk-testing/src/main/java/org/conscrypt/testing/NettyEchoServer.java
@@ -14,7 +14,7 @@
* limitations under the License.
*/
-package org.conscrypt.benchmarks;
+package org.conscrypt.testing;
import static io.netty.channel.ChannelOption.SO_BACKLOG;
import static io.netty.channel.ChannelOption.SO_KEEPALIVE;
@@ -34,26 +34,25 @@
import java.util.List;
import java.util.concurrent.TimeUnit;
import javax.net.ssl.SSLEngine;
-import org.conscrypt.testing.TestUtil;
/**
* A test server based on Netty and Netty-tcnative that auto-replies with every message
* it receives.
*/
-final class NettyEchoServer {
+public final class NettyEchoServer {
private final EventLoopGroup group = new NioEventLoopGroup();
private final int port;
private final int messageSize;
private Channel channel;
private String cipher;
- NettyEchoServer(int port, int messageSize, String cipher) {
+ public NettyEchoServer(int port, int messageSize, String cipher) {
this.port = port;
this.messageSize = messageSize;
this.cipher = cipher;
}
- void start() {
+ public void start() {
ServerBootstrap b = new ServerBootstrap();
b.group(group);
b.channel(NioServerSocketChannel.class);
@@ -82,7 +81,7 @@
channel = future.channel();
}
- void stop() {
+ public void stop() {
if (channel != null) {
channel.close().awaitUninterruptibly();
group.shutdownGracefully(1, 5, TimeUnit.SECONDS);
diff --git a/openjdk-benchmarks/src/main/java/org/conscrypt/benchmarks/TestClient.java b/openjdk-testing/src/main/java/org/conscrypt/testing/TestClient.java
similarity index 78%
rename from openjdk-benchmarks/src/main/java/org/conscrypt/benchmarks/TestClient.java
rename to openjdk-testing/src/main/java/org/conscrypt/testing/TestClient.java
index bbeeae4..5da3c51 100644
--- a/openjdk-benchmarks/src/main/java/org/conscrypt/benchmarks/TestClient.java
+++ b/openjdk-testing/src/main/java/org/conscrypt/testing/TestClient.java
@@ -14,31 +14,38 @@
* limitations under the License.
*/
-package org.conscrypt.benchmarks;
+package org.conscrypt.testing;
+import java.io.BufferedInputStream;
import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
import javax.net.ssl.SSLSocket;
/**
* Client-side endpoint. Provides basic services for sending/receiving messages from the client
* socket.
*/
-final class TestClient {
+public final class TestClient {
private final SSLSocket socket;
+ private InputStream input;
+ private OutputStream output;
- TestClient(SSLSocket socket) {
+ public TestClient(SSLSocket socket) {
this.socket = socket;
}
- void start() {
+ public void start() {
try {
socket.startHandshake();
+ input = new BufferedInputStream(socket.getInputStream());
+ output = socket.getOutputStream();
} catch (IOException e) {
throw new RuntimeException(e);
}
}
- void stop() {
+ public void stop() {
try {
socket.close();
} catch (IOException e) {
@@ -46,7 +53,7 @@
}
}
- int readMessage(byte[] buffer) {
+ public int readMessage(byte[] buffer) {
try {
int totalBytesRead = 0;
while (totalBytesRead < buffer.length) {
@@ -63,7 +70,7 @@
}
}
- void sendMessage(byte[] data) {
+ public void sendMessage(byte[] data) {
try {
socket.getOutputStream().write(data);
socket.getOutputStream().flush();
diff --git a/openjdk-testing/src/main/java/org/conscrypt/testing/TestUtil.java b/openjdk-testing/src/main/java/org/conscrypt/testing/TestUtil.java
index e325694..dfd3749 100644
--- a/openjdk-testing/src/main/java/org/conscrypt/testing/TestUtil.java
+++ b/openjdk-testing/src/main/java/org/conscrypt/testing/TestUtil.java
@@ -108,8 +108,17 @@
}
}
- public static SSLServerSocketFactory getConscryptServerSocketFactory() {
- return getServerSocketFactory(CONSCRYPT_PROVIDER);
+ public static SSLServerSocketFactory getConscryptServerSocketFactory(boolean useEngineSocket) {
+ try {
+ Class<?> clazz = Class.forName("org.conscrypt.OpenSSLServerSocketFactoryImpl");
+ Method method = clazz.getMethod("setUseEngineSocket", boolean.class);
+
+ SSLServerSocketFactory socketFactory = getServerSocketFactory(CONSCRYPT_PROVIDER);
+ method.invoke(socketFactory, useEngineSocket);
+ return socketFactory;
+ } catch (Exception e) {
+ throw new RuntimeException(e);
+ }
}
private static SSLSocketFactory getSocketFactory(Provider provider) {
diff --git a/openjdk/src/test/java/org/conscrypt/OpenSSLServerSocketImplTest.java b/openjdk/src/test/java/org/conscrypt/OpenSSLServerSocketImplTest.java
new file mode 100644
index 0000000..193ac0e
--- /dev/null
+++ b/openjdk/src/test/java/org/conscrypt/OpenSSLServerSocketImplTest.java
@@ -0,0 +1,120 @@
+/*
+ * Copyright 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.conscrypt;
+
+import static org.conscrypt.testing.TestUtil.LOCALHOST;
+import static org.conscrypt.testing.TestUtil.getConscryptServerSocketFactory;
+import static org.conscrypt.testing.TestUtil.getJdkSocketFactory;
+import static org.conscrypt.testing.TestUtil.getProtocols;
+import static org.conscrypt.testing.TestUtil.newTextMessage;
+import static org.conscrypt.testing.TestUtil.pickUnusedPort;
+import static org.junit.Assert.assertArrayEquals;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.concurrent.Future;
+import java.util.concurrent.TimeUnit;
+import javax.net.ssl.SSLServerSocket;
+import javax.net.ssl.SSLServerSocketFactory;
+import javax.net.ssl.SSLSocket;
+import javax.net.ssl.SSLSocketFactory;
+import org.conscrypt.testing.EchoServer;
+import org.conscrypt.testing.TestClient;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameter;
+import org.junit.runners.Parameterized.Parameters;
+
+@RunWith(Parameterized.class)
+public class OpenSSLServerSocketImplTest {
+ private static final String CIPHER = "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256";
+ private static final int MESSAGE_SIZE = 4096;
+
+ /**
+ * Various factories for SSL server sockets.
+ */
+ public enum SocketType {
+ DEFAULT(getConscryptServerSocketFactory(false)),
+ ENGINE(getConscryptServerSocketFactory(true));
+
+ private final SSLServerSocketFactory serverSocketFactory;
+
+ SocketType(SSLServerSocketFactory serverSocketFactory) {
+ this.serverSocketFactory = serverSocketFactory;
+ }
+
+ final SSLServerSocket newServerSocket(String cipher) {
+ try {
+ int port = pickUnusedPort();
+ SSLServerSocket sslSocket =
+ (SSLServerSocket) serverSocketFactory.createServerSocket(port);
+ sslSocket.setEnabledProtocols(getProtocols());
+ sslSocket.setEnabledCipherSuites(new String[] {cipher});
+ return sslSocket;
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ }
+ }
+
+ @Parameters(name = "{0}")
+ public static Iterable<Object> data() {
+ return Arrays.asList(SocketType.DEFAULT, SocketType.ENGINE);
+ }
+
+ @Parameter public SocketType socketType;
+
+ private TestClient client;
+ private EchoServer server;
+
+ @Before
+ public void setup() throws Exception {
+ // Create and start the server.
+ server = new EchoServer(socketType.newServerSocket(CIPHER), MESSAGE_SIZE);
+ Future connectedFuture = server.start();
+
+ // Create and start the client.
+ SSLSocketFactory socketFactory = getJdkSocketFactory();
+ SSLSocket socket = (SSLSocket) socketFactory.createSocket(LOCALHOST, server.port());
+ socket.setEnabledProtocols(getProtocols());
+ socket.setEnabledCipherSuites(new String[] {CIPHER});
+ client = new TestClient(socket);
+ client.start();
+
+ // Wait for the initial connection to complete.
+ connectedFuture.get(5, TimeUnit.SECONDS);
+ }
+
+ @After
+ public void teardown() throws Exception {
+ client.stop();
+ server.stop();
+ }
+
+ @Test
+ public void pingPong() throws IOException {
+ byte[] request = newTextMessage(MESSAGE_SIZE);
+ byte[] responseBuffer = new byte[MESSAGE_SIZE];
+ client.sendMessage(request);
+ int numBytes = client.readMessage(responseBuffer);
+ byte[] response = Arrays.copyOfRange(responseBuffer, 0, numBytes);
+ assertArrayEquals(request, response);
+ }
+}