Improve okhttp client transport, handles go away and add unit test.
-------------
Created by MOE: http://code.google.com/p/moe-java
MOE_MIGRATED_REVID=72155172
diff --git a/core/src/main/java/com/google/net/stubby/newtransport/okhttp/AsyncFrameWriter.java b/core/src/main/java/com/google/net/stubby/newtransport/okhttp/AsyncFrameWriter.java
index d5ac86c..6affe1f 100644
--- a/core/src/main/java/com/google/net/stubby/newtransport/okhttp/AsyncFrameWriter.java
+++ b/core/src/main/java/com/google/net/stubby/newtransport/okhttp/AsyncFrameWriter.java
@@ -1,7 +1,8 @@
package com.google.net.stubby.newtransport.okhttp;
import com.google.common.util.concurrent.SerializingExecutor;
-import com.google.common.util.concurrent.Service;
+import com.google.net.stubby.Status;
+import com.google.net.stubby.transport.Transport.Code;
import com.squareup.okhttp.internal.spdy.ErrorCode;
import com.squareup.okhttp.internal.spdy.FrameWriter;
@@ -17,9 +18,10 @@
class AsyncFrameWriter implements FrameWriter {
private final FrameWriter frameWriter;
private final Executor executor;
- private final Service transport;
+ private final OkHttpClientTransport transport;
- public AsyncFrameWriter(FrameWriter frameWriter, Service transport, Executor executor) {
+ public AsyncFrameWriter(FrameWriter frameWriter, OkHttpClientTransport transport,
+ Executor executor) {
this.frameWriter = frameWriter;
this.transport = transport;
// Although writes are thread-safe, we serialize them to prevent consuming many Threads that are
@@ -158,6 +160,8 @@
@Override
public void doRun() throws IOException {
frameWriter.goAway(lastGoodStreamId, errorCode, debugData);
+ // Flush it since after goAway, we are likely to close this writer.
+ frameWriter.flush();
}
});
}
@@ -188,7 +192,7 @@
try {
doRun();
} catch (IOException ex) {
- transport.stopAsync();
+ transport.abort(Status.fromThrowable(ex));
throw new RuntimeException(ex);
}
}
diff --git a/core/src/main/java/com/google/net/stubby/newtransport/okhttp/OkHttpClientTransport.java b/core/src/main/java/com/google/net/stubby/newtransport/okhttp/OkHttpClientTransport.java
index fc974e0..b4e9f68 100644
--- a/core/src/main/java/com/google/net/stubby/newtransport/okhttp/OkHttpClientTransport.java
+++ b/core/src/main/java/com/google/net/stubby/newtransport/okhttp/OkHttpClientTransport.java
@@ -1,5 +1,6 @@
package com.google.net.stubby.newtransport.okhttp;
+import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableMap;
import com.google.common.io.ByteBuffers;
@@ -33,9 +34,10 @@
import java.io.IOException;
import java.net.Socket;
import java.nio.ByteBuffer;
-import java.util.Collection;
+import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
+import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutorService;
@@ -48,6 +50,7 @@
*/
public class OkHttpClientTransport extends AbstractClientTransport {
/** The default initial window size in HTTP/2 is 64 KiB for the stream and connection. */
+ @VisibleForTesting
static final int DEFAULT_INITIAL_WINDOW_SIZE = 64 * 1024;
private static final ImmutableMap<ErrorCode, Status> ERROR_CODE_TO_STATUS = ImmutableMap
@@ -75,21 +78,40 @@
private final int port;
private FrameReader frameReader;
private AsyncFrameWriter frameWriter;
- @GuardedBy("this")
+ private Object lock = new Object();
+ @GuardedBy("lock")
private int nextStreamId;
private final Map<Integer, OkHttpClientStream> streams =
Collections.synchronizedMap(new HashMap<Integer, OkHttpClientStream>());
private final ExecutorService executor = Executors.newCachedThreadPool();
private int unacknowledgedBytesRead;
+ private ClientFrameHandler clientFrameHandler;
+ // The status used to finish all active streams when the transport is closed.
+ @GuardedBy("lock")
+ private boolean goAway;
+ @GuardedBy("lock")
+ private Status goAwayStatus;
public OkHttpClientTransport(String host, int port) {
- this.host = host;
+ this.host = Preconditions.checkNotNull(host);
this.port = port;
// Client initiated streams are odd, server initiated ones are even. Server should not need to
// use it. We start clients at 3 to avoid conflicting with HTTP negotiation.
nextStreamId = 3;
}
+ /**
+ * Create a transport connected to a fake peer for test.
+ */
+ @VisibleForTesting
+ OkHttpClientTransport(FrameReader frameReader, AsyncFrameWriter frameWriter, int nextStreamId) {
+ host = null;
+ port = -1;
+ this.nextStreamId = nextStreamId;
+ this.frameReader = frameReader;
+ this.frameWriter = frameWriter;
+ }
+
@Override
protected ClientStream newStreamInternal(MethodDescriptor<?, ?> method, StreamListener listener) {
return new OkHttpClientStream(method, listener);
@@ -97,53 +119,85 @@
@Override
protected void doStart() {
- BufferedSource source;
- BufferedSink sink;
- try {
- Socket socket = new Socket(host, port);
- // TODO(user): use SpdyConnection.
- source = Okio.buffer(Okio.source(socket));
- sink = Okio.buffer(Okio.sink(socket));
- } catch (IOException e) {
- throw new RuntimeException(e);
+ // We set host to null for test.
+ if (host != null) {
+ BufferedSource source;
+ BufferedSink sink;
+ try {
+ Socket socket = new Socket(host, port);
+ source = Okio.buffer(Okio.source(socket));
+ sink = Okio.buffer(Okio.sink(socket));
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ Variant variant = new Http20Draft12();
+ frameReader = variant.newReader(source, true);
+ frameWriter = new AsyncFrameWriter(variant.newWriter(sink, true), this, executor);
}
- Variant variant = new Http20Draft12();
- frameReader = variant.newReader(source, true);
- frameWriter = new AsyncFrameWriter(variant.newWriter(sink, true), this, executor);
- executor.execute(new ClientFrameHandler());
notifyStarted();
+ clientFrameHandler = new ClientFrameHandler();
+ executor.execute(clientFrameHandler);
}
@Override
protected void doStop() {
- closeAllStreams(new Status(Code.INTERNAL, "Transport stopped"));
- frameWriter.close();
- try {
- frameReader.close();
- } catch (IOException e) {
- throw new RuntimeException(e);
+ boolean normalClose;
+ synchronized (lock) {
+ normalClose = !goAway;
}
- executor.shutdown();
- notifyStopped();
+ if (normalClose) {
+ abort(new Status(Code.INTERNAL, "Transport stopped"));
+ // Send GOAWAY with lastGoodStreamId of 0, since we don't expect any server-initiated streams.
+ // The GOAWAY is part of graceful shutdown.
+ frameWriter.goAway(0, ErrorCode.NO_ERROR, null);
+ }
+ stopIfNecessary();
+ }
+
+ @VisibleForTesting
+ ClientFrameHandler getHandler() {
+ return clientFrameHandler;
+ }
+
+ @VisibleForTesting
+ Map<Integer, OkHttpClientStream> getStreams() {
+ return streams;
}
/**
- * Close and remove all streams.
+ * Finish all active streams with given status, then close the transport.
*/
- private void closeAllStreams(Status status) {
- Collection<OkHttpClientStream> streamsCopy;
- synchronized (streams) {
- streamsCopy = streams.values();
- streams.clear();
+ void abort(Status status) {
+ onGoAway(-1, status);
+ }
+
+ private void onGoAway(int lastKnownStreamId, Status status) {
+ ArrayList<OkHttpClientStream> goAwayStreams = new ArrayList<OkHttpClientStream>();
+ synchronized (lock) {
+ goAway = true;
+ goAwayStatus = status;
+ Iterator<Map.Entry<Integer, OkHttpClientStream>> it = streams.entrySet().iterator();
+ while (it.hasNext()) {
+ Map.Entry<Integer, OkHttpClientStream> entry = it.next();
+ if (entry.getKey() > lastKnownStreamId) {
+ goAwayStreams.add(entry.getValue());
+ it.remove();
+ }
+ }
}
- for (OkHttpClientStream stream : streamsCopy) {
+
+ // Starting stop, go into STOPPING state so that Channel know this Transport should not be used
+ // further, will become STOPPED once all streams are complete.
+ stopAsync();
+
+ for (OkHttpClientStream stream : goAwayStreams) {
stream.setStatus(status);
}
}
/**
- * Called when a HTTP2 stream is closed.
+ * Called when a stream is closed.
*
* <p> Return false if the stream has already finished.
*/
@@ -159,10 +213,39 @@
}
/**
+ * When the transport is in goAway states, we should stop it once all active streams finish.
+ */
+ private void stopIfNecessary() {
+ boolean shouldStop;
+ synchronized (lock) {
+ shouldStop = (goAway && streams.size() == 0);
+ }
+ if (shouldStop) {
+ frameWriter.close();
+ try {
+ frameReader.close();
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ executor.shutdown();
+ notifyStopped();
+ }
+ }
+
+ /**
+ * Returns a Grpc status corresponding to the given ErrorCode.
+ */
+ @VisibleForTesting
+ static Status toGrpcStatus(ErrorCode code) {
+ return ERROR_CODE_TO_STATUS.get(code);
+ }
+
+ /**
* Runnable which reads frames and dispatches them to in flight calls
*/
- private class ClientFrameHandler implements FrameReader.Handler, Runnable {
- private ClientFrameHandler() {}
+ @VisibleForTesting
+ class ClientFrameHandler implements FrameReader.Handler, Runnable {
+ ClientFrameHandler() {}
@Override
public void run() {
@@ -173,8 +256,7 @@
while (frameReader.nextFrame(this)) {
}
} catch (IOException ioe) {
- ioe.printStackTrace();
- closeAllStreams(new Status(Code.INTERNAL, ioe.getMessage()));
+ abort(Status.fromThrowable(ioe));
} finally {
// Restore the original thread name.
Thread.currentThread().setName(threadName);
@@ -210,7 +292,9 @@
stream.unacknowledgedBytesRead = 0;
}
if (inFinished) {
- finishStream(streamId, Status.OK);
+ if (finishStream(streamId, Status.OK)) {
+ stopIfNecessary();
+ }
}
}
@@ -229,7 +313,9 @@
@Override
public void rstStream(int streamId, ErrorCode errorCode) {
- finishStream(streamId, ERROR_CODE_TO_STATUS.get(errorCode));
+ if (finishStream(streamId, toGrpcStatus(errorCode))) {
+ stopIfNecessary();
+ }
}
@Override
@@ -252,18 +338,14 @@
@Override
public void goAway(int lastGoodStreamId, ErrorCode errorCode, ByteString debugData) {
- // TODO(user): Log here and implement the real Go away behavior: streams have
- // id <= lastGoodStreamId should not be closed.
- closeAllStreams(new Status(Code.UNAVAILABLE, "Go away"));
- stopAsync();
+ onGoAway(lastGoodStreamId, new Status(Code.UNAVAILABLE, "Go away"));
}
@Override
public void pushPromise(int streamId, int promisedStreamId, List<Header> requestHeaders)
throws IOException {
- // TODO(user): should send SETTINGS_ENABLE_PUSH=0, then here we should reset it with
- // PROTOCOL_ERROR.
- frameWriter.rstStream(streamId, ErrorCode.REFUSED_STREAM);
+ // We don't accept server initiated stream.
+ frameWriter.rstStream(streamId, ErrorCode.PROTOCOL_ERROR);
}
@Override
@@ -284,28 +366,42 @@
}
}
+ @GuardedBy("lock")
+ private void assignStreamId(OkHttpClientStream stream) {
+ Preconditions.checkState(stream.streamId == 0, "StreamId already assigned");
+ stream.streamId = nextStreamId;
+ streams.put(stream.streamId, stream);
+ if (nextStreamId >= Integer.MAX_VALUE - 2) {
+ onGoAway(Integer.MAX_VALUE, new Status(Code.INTERNAL, "Stream id exhaust"));
+ } else {
+ nextStreamId += 2;
+ }
+ }
+
/**
* Client stream for the okhttp transport.
*/
- private class OkHttpClientStream extends AbstractStream implements ClientStream {
+ @VisibleForTesting
+ class OkHttpClientStream extends AbstractStream implements ClientStream {
int streamId;
final InputStreamDeframer deframer;
int unacknowledgedBytesRead;
- public OkHttpClientStream(MethodDescriptor<?, ?> method, StreamListener listener) {
+ OkHttpClientStream(MethodDescriptor<?, ?> method, StreamListener listener) {
super(listener);
- Preconditions.checkState(streamId == 0, "StreamId should be 0");
- synchronized (OkHttpClientTransport.this) {
- streamId = nextStreamId;
- nextStreamId += 2;
- streams.put(streamId, this);
- frameWriter.synStream(false, false, streamId, 0,
- Headers.createRequestHeaders(method.getName()));
- }
deframer = new InputStreamDeframer(inboundMessageHandler());
+ synchronized (lock) {
+ if (goAway) {
+ setStatus(goAwayStatus);
+ return;
+ }
+ assignStreamId(this);
+ }
+ frameWriter.synStream(false, false, streamId, 0,
+ Headers.createRequestHeaders(method.getName()));
}
- public InputStreamDeframer getDeframer() {
+ InputStreamDeframer getDeframer() {
return deframer;
}
@@ -330,8 +426,9 @@
public void cancel() {
Preconditions.checkState(streamId != 0, "streamId should be set");
outboundPhase = Phase.STATUS;
- if (finishStream(streamId, ERROR_CODE_TO_STATUS.get(ErrorCode.CANCEL))) {
+ if (finishStream(streamId, toGrpcStatus(ErrorCode.CANCEL))) {
frameWriter.rstStream(streamId, ErrorCode.CANCEL);
+ stopIfNecessary();
}
}
}
diff --git a/core/src/test/java/com/google/net/stubby/newtransport/okhttp/OkHttpClientTransportTest.java b/core/src/test/java/com/google/net/stubby/newtransport/okhttp/OkHttpClientTransportTest.java
new file mode 100644
index 0000000..7c3d7fd
--- /dev/null
+++ b/core/src/test/java/com/google/net/stubby/newtransport/okhttp/OkHttpClientTransportTest.java
@@ -0,0 +1,555 @@
+package com.google.net.stubby.newtransport.okhttp;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+import static org.mockito.Mockito.any;
+import static org.mockito.Mockito.eq;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+import com.google.common.util.concurrent.ListenableFuture;
+import com.google.common.util.concurrent.Service;
+import com.google.net.stubby.MethodDescriptor;
+import com.google.net.stubby.Status;
+import com.google.net.stubby.newtransport.StreamListener;
+import com.google.net.stubby.newtransport.okhttp.OkHttpClientTransport.ClientFrameHandler;
+import com.google.net.stubby.newtransport.okhttp.OkHttpClientTransport.OkHttpClientStream;
+import com.google.net.stubby.transport.Transport;
+import com.google.net.stubby.transport.Transport.Code;
+import com.google.net.stubby.transport.Transport.ContextValue;
+import com.google.protobuf.ByteString;
+
+import com.squareup.okhttp.internal.spdy.ErrorCode;
+import com.squareup.okhttp.internal.spdy.FrameReader;
+
+import okio.Buffer;
+import okio.BufferedSource;
+
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+import org.mockito.ArgumentCaptor;
+import org.mockito.Mock;
+import org.mockito.MockitoAnnotations;
+
+import java.io.BufferedReader;
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.DataOutputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.InputStreamReader;
+import java.nio.charset.StandardCharsets;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.TimeUnit;
+
+/**
+ * Tests for {@link OkHttpClientTransport}.
+ */
+@RunWith(JUnit4.class)
+public class OkHttpClientTransportTest {
+ private static final int TIME_OUT_MS = 5000000;
+ private static final String NETWORK_ISSUE_MESSAGE = "network issue";
+
+ // Flags
+ private static final byte PAYLOAD_FRAME = 0x0;
+ public static final byte CONTEXT_VALUE_FRAME = 0x1;
+ public static final byte STATUS_FRAME = 0x3;
+
+ @Mock
+ private AsyncFrameWriter frameWriter;
+ @Mock
+ MethodDescriptor<?, ?> method;
+ private OkHttpClientTransport clientTransport;
+ private MockFrameReader frameReader;
+ private Map<Integer, OkHttpClientStream> streams;
+ private ClientFrameHandler frameHandler;
+
+ @Before
+ public void setup() {
+ MockitoAnnotations.initMocks(this);
+ streams = new HashMap<Integer, OkHttpClientStream>();
+ frameReader = new MockFrameReader();
+ clientTransport = new OkHttpClientTransport(frameReader, frameWriter, 3);
+ clientTransport.startAsync();
+ frameHandler = clientTransport.getHandler();
+ streams = clientTransport.getStreams();
+ when(method.getName()).thenReturn("fakemethod");
+ }
+
+ @After
+ public void tearDown() {
+ clientTransport.stopAsync();
+ assertTrue(frameReader.closed);
+ verify(frameWriter).close();
+ }
+
+ /**
+ * When nextFrame throws IOException, the transport should be aborted.
+ */
+ @Test
+ public void nextFrameThrowIOException() throws Exception {
+ MockStreamListener listener1 = new MockStreamListener();
+ MockStreamListener listener2 = new MockStreamListener();
+ clientTransport.newStream(method, listener1);
+ clientTransport.newStream(method, listener2);
+ assertEquals(2, streams.size());
+ assertTrue(streams.containsKey(3));
+ assertTrue(streams.containsKey(5));
+ frameReader.throwIOExceptionForNextFrame();
+ listener1.waitUntilStreamClosed();
+ listener2.waitUntilStreamClosed();
+ assertEquals(0, streams.size());
+ assertEquals(Code.INTERNAL, listener1.status.getCode());
+ assertEquals(NETWORK_ISSUE_MESSAGE, listener2.status.getCause().getMessage());
+ assertEquals(Code.INTERNAL, listener1.status.getCode());
+ assertEquals(NETWORK_ISSUE_MESSAGE, listener2.status.getCause().getMessage());
+ assertTrue("Service state: " + clientTransport.state(),
+ Service.State.TERMINATED == clientTransport.state());
+ }
+
+ @Test
+ public void readMessages() throws Exception {
+ final int numMessages = 10;
+ final String message = "Hello Client";
+ MockStreamListener listener = new MockStreamListener();
+ clientTransport.newStream(method, listener);
+ assertTrue(streams.containsKey(3));
+ for (int i = 0; i < numMessages; i++) {
+ BufferedSource source = mock(BufferedSource.class);
+ InputStream inputStream = createMessageFrame(message + i);
+ when(source.inputStream()).thenReturn(inputStream);
+ frameHandler.data(i == numMessages - 1 ? true : false, 3, source, inputStream.available());
+ }
+ listener.waitUntilStreamClosed();
+ assertEquals(Status.OK, listener.status);
+ assertEquals(numMessages, listener.messages.size());
+ for (int i = 0; i < numMessages; i++) {
+ assertEquals(message + i, listener.messages.get(i));
+ }
+ }
+
+ @Test
+ public void readContexts() throws Exception {
+ final int numContexts = 10;
+ final String key = "KEY";
+ final String value = "value";
+ MockStreamListener listener = new MockStreamListener();
+ clientTransport.newStream(method, listener);
+ assertTrue(streams.containsKey(3));
+ for (int i = 0; i < numContexts; i++) {
+ BufferedSource source = mock(BufferedSource.class);
+ InputStream inputStream = createContextFrame(key + i, value + i);
+ when(source.inputStream()).thenReturn(inputStream);
+ frameHandler.data(i == numContexts - 1 ? true : false, 3, source, inputStream.available());
+ }
+ listener.waitUntilStreamClosed();
+ assertEquals(Status.OK, listener.status);
+ assertEquals(numContexts, listener.contexts.size());
+ for (int i = 0; i < numContexts; i++) {
+ String val = listener.contexts.get(key + i);
+ assertNotNull(val);
+ assertEquals(value + i, val);
+ }
+ }
+
+ @Test
+ public void readStatus() throws Exception {
+ MockStreamListener listener = new MockStreamListener();
+ clientTransport.newStream(method, listener);
+ assertTrue(streams.containsKey(3));
+ BufferedSource source = mock(BufferedSource.class);
+ InputStream inputStream = createStatusFrame((short) Transport.Code.UNAVAILABLE.getNumber());
+ when(source.inputStream()).thenReturn(inputStream);
+ frameHandler.data(true, 3, source, inputStream.available());
+ listener.waitUntilStreamClosed();
+ assertEquals(Transport.Code.UNAVAILABLE, listener.status.getCode());
+ }
+
+ @Test
+ public void receiveReset() throws Exception {
+ MockStreamListener listener = new MockStreamListener();
+ clientTransport.newStream(method, listener);
+ assertTrue(streams.containsKey(3));
+ frameHandler.rstStream(3, ErrorCode.PROTOCOL_ERROR);
+ listener.waitUntilStreamClosed();
+ assertEquals(OkHttpClientTransport.toGrpcStatus(ErrorCode.PROTOCOL_ERROR), listener.status);
+ }
+
+ @Test
+ public void cancelStream() throws Exception {
+ MockStreamListener listener = new MockStreamListener();
+ clientTransport.newStream(method, listener);
+ OkHttpClientStream stream = streams.get(3);
+ assertNotNull(stream);
+ stream.cancel();
+ verify(frameWriter).rstStream(eq(3), eq(ErrorCode.CANCEL));
+ listener.waitUntilStreamClosed();
+ assertEquals(OkHttpClientTransport.toGrpcStatus(ErrorCode.CANCEL), listener.status);
+ }
+
+ @Test
+ public void writeMessage() throws Exception {
+ final String message = "Hello Server";
+ MockStreamListener listener = new MockStreamListener();
+ clientTransport.newStream(method, listener);
+ OkHttpClientStream stream = streams.get(3);
+ InputStream input = new ByteArrayInputStream(message.getBytes(StandardCharsets.UTF_8));
+ stream.writeMessage(input, input.available(), null);
+ stream.flush();
+ ArgumentCaptor<Buffer> captor =
+ ArgumentCaptor.forClass(Buffer.class);
+ verify(frameWriter).data(eq(false), eq(3), captor.capture());
+ Buffer sentFrame = captor.getValue();
+ checkSameInputStream(createMessageFrame(message), sentFrame.inputStream());
+ }
+
+ @Test
+ public void writeContext() throws Exception {
+ final String key = "KEY";
+ final String value = "VALUE";
+ MockStreamListener listener = new MockStreamListener();
+ clientTransport.newStream(method, listener);
+ OkHttpClientStream stream = streams.get(3);
+ InputStream input = new ByteArrayInputStream(value.getBytes(StandardCharsets.UTF_8));
+ stream.writeContext(key, input, input.available(), null);
+ stream.flush();
+ ArgumentCaptor<Buffer> captor =
+ ArgumentCaptor.forClass(Buffer.class);
+ verify(frameWriter).data(eq(false), eq(3), captor.capture());
+ stream.cancel();
+ verify(frameWriter).rstStream(eq(3), eq(ErrorCode.CANCEL));
+ listener.waitUntilStreamClosed();
+ assertEquals(OkHttpClientTransport.toGrpcStatus(ErrorCode.CANCEL), listener.status);
+ }
+
+ @Test
+ public void windowUpdate() throws Exception {
+ MockStreamListener listener1 = new MockStreamListener();
+ MockStreamListener listener2 = new MockStreamListener();
+ clientTransport.newStream(method, listener1);
+ clientTransport.newStream(method, listener2);
+ assertEquals(2, streams.size());
+ OkHttpClientStream stream1 = streams.get(3);
+ OkHttpClientStream stream2 = streams.get(5);
+
+ int messageLength = OkHttpClientTransport.DEFAULT_INITIAL_WINDOW_SIZE / 4;
+ byte[] fakeMessage = new byte[messageLength];
+ byte[] contextBody = ContextValue
+ .newBuilder()
+ .setKey("KEY")
+ .setValue(ByteString.copyFrom(fakeMessage))
+ .build()
+ .toByteArray();
+
+ // Stream 1 receives context
+ InputStream contextFrame = createContextFrame(contextBody);
+ int contextFrameLength = contextFrame.available();
+ BufferedSource source = mock(BufferedSource.class);
+ when(source.inputStream()).thenReturn(contextFrame);
+ frameHandler.data(false, 3, source, contextFrame.available());
+
+ // Stream 2 receives context
+ contextFrame = createContextFrame(contextBody);
+ when(source.inputStream()).thenReturn(contextFrame);
+ frameHandler.data(false, 5, source, contextFrame.available());
+
+ verify(frameWriter).windowUpdate(eq(0), eq((long) 2 * contextFrameLength));
+
+ // Stream 1 receives a message
+ InputStream messageFrame = createMessageFrame(fakeMessage);
+ int messageFrameLength = messageFrame.available();
+ when(source.inputStream()).thenReturn(messageFrame);
+ frameHandler.data(false, 3, source, messageFrame.available());
+
+ verify(frameWriter).windowUpdate(eq(3), eq((long) contextFrameLength + messageFrameLength));
+
+ // Stream 2 receives a message
+ messageFrame = createMessageFrame(fakeMessage);
+ when(source.inputStream()).thenReturn(messageFrame);
+ frameHandler.data(false, 5, source, messageFrame.available());
+
+ verify(frameWriter).windowUpdate(eq(5), eq((long) contextFrameLength + messageFrameLength));
+ verify(frameWriter).windowUpdate(eq(0), eq((long) 2 * messageFrameLength));
+
+ stream1.cancel();
+ verify(frameWriter).rstStream(eq(3), eq(ErrorCode.CANCEL));
+ listener1.waitUntilStreamClosed();
+ assertEquals(OkHttpClientTransport.toGrpcStatus(ErrorCode.CANCEL), listener1.status);
+
+ stream2.cancel();
+ verify(frameWriter).rstStream(eq(3), eq(ErrorCode.CANCEL));
+ listener2.waitUntilStreamClosed();
+ assertEquals(OkHttpClientTransport.toGrpcStatus(ErrorCode.CANCEL), listener2.status);
+ }
+
+ @Test
+ public void stopNormally() throws Exception {
+ MockStreamListener listener1 = new MockStreamListener();
+ MockStreamListener listener2 = new MockStreamListener();
+ clientTransport.newStream(method, listener1);
+ clientTransport.newStream(method, listener2);
+ assertEquals(2, streams.size());
+ clientTransport.stopAsync();
+ listener1.waitUntilStreamClosed();
+ listener2.waitUntilStreamClosed();
+ verify(frameWriter).goAway(eq(0), eq(ErrorCode.NO_ERROR), (byte[]) any());
+ assertEquals(0, streams.size());
+ assertEquals(Code.INTERNAL, listener1.status.getCode());
+ assertEquals(Code.INTERNAL, listener2.status.getCode());
+ assertEquals(Service.State.TERMINATED, clientTransport.state());
+ }
+
+ @Test
+ public void receiveGoAway() throws Exception {
+ // start 2 streams.
+ MockStreamListener listener1 = new MockStreamListener();
+ MockStreamListener listener2 = new MockStreamListener();
+ clientTransport.newStream(method, listener1);
+ clientTransport.newStream(method, listener2);
+ assertEquals(2, streams.size());
+
+ // Receive goAway, max good id is 3.
+ frameHandler.goAway(3, ErrorCode.CANCEL, null);
+
+ // Transport should be in STOPPING state.
+ assertEquals(Service.State.STOPPING, clientTransport.state());
+
+ // Stream 2 should be closed.
+ listener2.waitUntilStreamClosed();
+ assertEquals(1, streams.size());
+ assertEquals(Code.UNAVAILABLE, listener2.status.getCode());
+
+ // New stream should be failed.
+ MockStreamListener listener3 = new MockStreamListener();
+ try {
+ clientTransport.newStream(method, listener3);
+ fail("new stream should no be accepted by a go-away transport.");
+ } catch (IllegalStateException ex) {
+ // expected.
+ }
+
+ // But stream 1 should be able to send.
+ final String sentMessage = "Should I also go away?";
+ OkHttpClientStream stream = streams.get(3);
+ InputStream input =
+ new ByteArrayInputStream(sentMessage.getBytes(StandardCharsets.UTF_8));
+ stream.writeMessage(input, input.available(), null);
+ stream.flush();
+ ArgumentCaptor<Buffer> captor =
+ ArgumentCaptor.forClass(Buffer.class);
+ verify(frameWriter).data(eq(false), eq(3), captor.capture());
+ Buffer sentFrame = captor.getValue();
+ checkSameInputStream(createMessageFrame(sentMessage), sentFrame.inputStream());
+
+ // And read.
+ final String receivedMessage = "No, you are fine.";
+ BufferedSource source = mock(BufferedSource.class);
+ InputStream inputStream = createMessageFrame(receivedMessage);
+ when(source.inputStream()).thenReturn(inputStream);
+ frameHandler.data(true, 3, source, inputStream.available());
+ listener1.waitUntilStreamClosed();
+ assertEquals(1, listener1.messages.size());
+ assertEquals(receivedMessage, listener1.messages.get(0));
+
+ // The transport should be stopped after all active streams finished.
+ assertTrue("Service state: " + clientTransport.state(),
+ Service.State.TERMINATED == clientTransport.state());
+ }
+
+ @Test
+ public void streamIdExhaust() throws Exception {
+ int startId = Integer.MAX_VALUE - 2;
+ AsyncFrameWriter writer = mock(AsyncFrameWriter.class);
+ OkHttpClientTransport transport =
+ new OkHttpClientTransport(frameReader, writer, startId);
+ transport.startAsync();
+ streams = transport.getStreams();
+
+ MockStreamListener listener1 = new MockStreamListener();
+ transport.newStream(method, listener1);
+
+ try {
+ transport.newStream(method, new MockStreamListener());
+ fail("new stream should not be accepted by a go-away transport.");
+ } catch (IllegalStateException ex) {
+ // expected.
+ }
+
+ streams.get(startId).cancel();
+ listener1.waitUntilStreamClosed();
+ verify(writer).rstStream(eq(startId), eq(ErrorCode.CANCEL));
+ assertEquals(Service.State.TERMINATED, transport.state());
+ }
+
+ private static void checkSameInputStream(InputStream in1, InputStream in2) throws IOException {
+ assertEquals(in1.available(), in2.available());
+ byte[] b1 = new byte[in1.available()];
+ in1.read(b1);
+ byte[] b2 = new byte[in2.available()];
+ in2.read(b2);
+ for (int i = 0; i < b1.length; i++) {
+ if (b1[i] != b2[i]) {
+ fail("Different InputStream.");
+ }
+ }
+ }
+
+ private static InputStream createMessageFrame(String message) throws IOException {
+ return createMessageFrame(message.getBytes(StandardCharsets.UTF_8));
+ }
+
+ private static InputStream createMessageFrame(byte[] message) throws IOException {
+ ByteArrayOutputStream os = new ByteArrayOutputStream();
+ DataOutputStream dos = new DataOutputStream(os);
+ dos.write(PAYLOAD_FRAME);
+ dos.writeInt(message.length);
+ dos.write(message);
+ dos.close();
+ byte[] messageFrame = os.toByteArray();
+
+ // Write the compression header followed by the message frame.
+ return addCompressionHeader(messageFrame);
+ }
+
+ private static InputStream createContextFrame(String key, String value) throws IOException {
+ byte[] body = ContextValue
+ .newBuilder()
+ .setKey(key)
+ .setValue(ByteString.copyFromUtf8(value))
+ .build()
+ .toByteArray();
+ return createContextFrame(body);
+ }
+
+ private static InputStream createContextFrame(byte[] body) throws IOException {
+ ByteArrayOutputStream os = new ByteArrayOutputStream();
+ DataOutputStream dos = new DataOutputStream(os);
+ dos.write(CONTEXT_VALUE_FRAME);
+ dos.writeInt(body.length);
+ dos.write(body);
+ dos.close();
+ byte[] contextFrame = os.toByteArray();
+
+ // Write the compression header followed by the context frame.
+ return addCompressionHeader(contextFrame);
+ }
+
+ private static InputStream createStatusFrame(short code) throws IOException {
+ ByteArrayOutputStream os = new ByteArrayOutputStream();
+ DataOutputStream dos = new DataOutputStream(os);
+ dos.write(STATUS_FRAME);
+ int length = 2;
+ dos.writeInt(length);
+ dos.writeShort(code);
+ dos.close();
+ byte[] statusFrame = os.toByteArray();
+
+ // Write the compression header followed by the status frame.
+ return addCompressionHeader(statusFrame);
+ }
+
+ private static InputStream addCompressionHeader(byte[] raw) throws IOException {
+ ByteArrayOutputStream os = new ByteArrayOutputStream();
+ DataOutputStream dos = new DataOutputStream(os);
+ dos.writeInt(raw.length);
+ dos.write(raw);
+ dos.close();
+ return new ByteArrayInputStream(os.toByteArray());
+ }
+
+ private static class MockFrameReader implements FrameReader {
+ boolean closed;
+ boolean throwExceptionForNextFrame;
+
+ @Override
+ public void close() throws IOException {
+ closed = true;
+ }
+
+ @Override
+ public boolean nextFrame(Handler handler) throws IOException {
+ if (throwExceptionForNextFrame) {
+ throw new IOException(NETWORK_ISSUE_MESSAGE);
+ }
+ synchronized (this) {
+ try {
+ wait();
+ } catch (InterruptedException e) {
+ throw new IOException(e);
+ }
+ }
+ if (throwExceptionForNextFrame) {
+ throw new IOException(NETWORK_ISSUE_MESSAGE);
+ }
+ return true;
+ }
+
+ synchronized void throwIOExceptionForNextFrame() {
+ throwExceptionForNextFrame = true;
+ notifyAll();
+ }
+
+ @Override
+ public void readConnectionPreface() throws IOException {
+ // not used.
+ }
+ }
+
+ private static class MockStreamListener implements StreamListener {
+ Status status;
+ CountDownLatch closed = new CountDownLatch(1);
+ ArrayList<String> messages = new ArrayList<String>();
+ Map<String, String> contexts = new HashMap<String, String>();
+
+ @Override
+ public ListenableFuture<Void> contextRead(String name, InputStream value, int length) {
+ String valueStr = getContent(value);
+ if (valueStr != null) {
+ // We assume only one context for each name.
+ contexts.put(name, valueStr);
+ }
+ return null;
+ }
+
+ @Override
+ public ListenableFuture<Void> messageRead(InputStream message, int length) {
+ String msg = getContent(message);
+ if (msg != null) {
+ messages.add(msg);
+ }
+ return null;
+ }
+
+ @Override
+ public void closed(Status status) {
+ this.status = status;
+ closed.countDown();
+ }
+
+ void waitUntilStreamClosed() throws InterruptedException {
+ if (!closed.await(TIME_OUT_MS, TimeUnit.MILLISECONDS)) {
+ fail("Failed waiting stream to be closed.");
+ }
+ }
+
+ static String getContent(InputStream message) {
+ BufferedReader br =
+ new BufferedReader(new InputStreamReader(message, StandardCharsets.UTF_8));
+ try {
+ // Only one line message is used in this test.
+ return br.readLine();
+ } catch (IOException e) {
+ return null;
+ }
+ }
+ }
+}