Improve the OkHttp client transport: handle GOAWAY and add a unit test.
-------------
Created by MOE: http://code.google.com/p/moe-java
MOE_MIGRATED_REVID=72155172
diff --git a/core/src/main/java/com/google/net/stubby/newtransport/okhttp/OkHttpClientTransport.java b/core/src/main/java/com/google/net/stubby/newtransport/okhttp/OkHttpClientTransport.java
index fc974e0..b4e9f68 100644
--- a/core/src/main/java/com/google/net/stubby/newtransport/okhttp/OkHttpClientTransport.java
+++ b/core/src/main/java/com/google/net/stubby/newtransport/okhttp/OkHttpClientTransport.java
@@ -1,5 +1,6 @@
package com.google.net.stubby.newtransport.okhttp;
+import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableMap;
import com.google.common.io.ByteBuffers;
@@ -33,9 +34,10 @@
import java.io.IOException;
import java.net.Socket;
import java.nio.ByteBuffer;
-import java.util.Collection;
+import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
+import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutorService;
@@ -48,6 +50,7 @@
*/
public class OkHttpClientTransport extends AbstractClientTransport {
/** The default initial window size in HTTP/2 is 64 KiB for the stream and connection. */
+ // Package-private and @VisibleForTesting so tests in this package can reference it.
+ @VisibleForTesting
static final int DEFAULT_INITIAL_WINDOW_SIZE = 64 * 1024;
private static final ImmutableMap<ErrorCode, Status> ERROR_CODE_TO_STATUS = ImmutableMap
@@ -75,21 +78,40 @@
private final int port;
private FrameReader frameReader;
private AsyncFrameWriter frameWriter;
- @GuardedBy("this")
+ // Monitor for the @GuardedBy("lock") fields below. Final so every thread always synchronizes on
+ // the same object.
+ private final Object lock = new Object();
+ @GuardedBy("lock")
private int nextStreamId;
private final Map<Integer, OkHttpClientStream> streams =
Collections.synchronizedMap(new HashMap<Integer, OkHttpClientStream>());
private final ExecutorService executor = Executors.newCachedThreadPool();
private int unacknowledgedBytesRead;
+ // Created in doStart(); exposed to tests through getHandler().
+ private ClientFrameHandler clientFrameHandler;
+ // Set once the transport starts going away (peer GOAWAY, abort, or stream id exhaustion); new
+ // streams are not started afterwards.
+ @GuardedBy("lock")
+ private boolean goAway;
+ // The status used to finish streams that are removed, or newly created, once goAway is set.
+ @GuardedBy("lock")
+ private Status goAwayStatus;
+ /**
+ * Creates a transport that opens a socket to {@code host}:{@code port} when it is started.
+ *
+ * @throws NullPointerException if {@code host} is null
+ */
public OkHttpClientTransport(String host, int port) {
- this.host = host;
+ this.host = Preconditions.checkNotNull(host);
this.port = port;
// Client initiated streams are odd, server initiated ones are even. Server should not need to
// use it. We start clients at 3 to avoid conflicting with HTTP negotiation.
nextStreamId = 3;
}
+ /**
+ * Creates a transport that talks to a fake peer through the given reader and writer, for tests.
+ *
+ * <p>{@code host} is left null, which makes {@code doStart} skip opening a socket.
+ *
+ * @param frameReader source of frames from the fake peer
+ * @param frameWriter sink for frames sent to the fake peer
+ * @param nextStreamId the id assigned to the first client-initiated stream
+ * @throws NullPointerException if {@code frameReader} or {@code frameWriter} is null
+ */
+ @VisibleForTesting
+ OkHttpClientTransport(FrameReader frameReader, AsyncFrameWriter frameWriter, int nextStreamId) {
+ host = null;
+ port = -1;
+ this.nextStreamId = nextStreamId;
+ // Fail fast here rather than with an NPE later on the frame-handler thread.
+ this.frameReader = Preconditions.checkNotNull(frameReader);
+ this.frameWriter = Preconditions.checkNotNull(frameWriter);
+ }
+
@Override
protected ClientStream newStreamInternal(MethodDescriptor<?, ?> method, StreamListener listener) {
return new OkHttpClientStream(method, listener);
@@ -97,53 +119,85 @@
@Override
protected void doStart() {
- BufferedSource source;
- BufferedSink sink;
- try {
- Socket socket = new Socket(host, port);
- // TODO(user): use SpdyConnection.
- source = Okio.buffer(Okio.source(socket));
- sink = Okio.buffer(Okio.sink(socket));
- } catch (IOException e) {
- throw new RuntimeException(e);
+ // host is null only when the test constructor was used; that constructor has already supplied
+ // frameReader and frameWriter, so no socket is opened.
+ if (host != null) {
+ BufferedSource source;
+ BufferedSink sink;
+ try {
+ Socket socket = new Socket(host, port);
+ source = Okio.buffer(Okio.source(socket));
+ sink = Okio.buffer(Okio.sink(socket));
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ Variant variant = new Http20Draft12();
+ frameReader = variant.newReader(source, true);
+ frameWriter = new AsyncFrameWriter(variant.newWriter(sink, true), this, executor);
}
- Variant variant = new Http20Draft12();
- frameReader = variant.newReader(source, true);
- frameWriter = new AsyncFrameWriter(variant.newWriter(sink, true), this, executor);
- executor.execute(new ClientFrameHandler());
notifyStarted();
+ // Keep a reference to the handler so tests can reach it through getHandler().
+ clientFrameHandler = new ClientFrameHandler();
+ executor.execute(clientFrameHandler);
}
@Override
protected void doStop() {
- closeAllStreams(new Status(Code.INTERNAL, "Transport stopped"));
- frameWriter.close();
- try {
- frameReader.close();
- } catch (IOException e) {
- throw new RuntimeException(e);
+ // If goAway is already set (peer GOAWAY, abort, or stream id exhaustion), onGoAway has already
+ // failed the streams it had to and the rest may complete. Otherwise this is a locally initiated
+ // stop: fail every stream and tell the peer.
+ boolean normalClose;
+ synchronized (lock) {
+ normalClose = !goAway;
}
- executor.shutdown();
- notifyStopped();
+ if (normalClose) {
+ abort(new Status(Code.INTERNAL, "Transport stopped"));
+ // Send GOAWAY with lastGoodStreamId of 0, since we don't expect any server-initiated streams.
+ // The GOAWAY is part of graceful shutdown.
+ frameWriter.goAway(0, ErrorCode.NO_ERROR, null);
+ }
+ // Tears the connection down only if no active streams remain; otherwise the last stream to
+ // finish does it.
+ stopIfNecessary();
+ }
+
+ /** Exposes the frame handler created by doStart, for tests; null until doStart has run. */
+ @VisibleForTesting
+ ClientFrameHandler getHandler() {
+ return this.clientFrameHandler;
+ }
+
+ /**
+ * Returns the live (not copied) map of active streams keyed by stream id, for tests.
+ */
+ @VisibleForTesting
+ Map<Integer, OkHttpClientStream> getStreams() {
+ return this.streams;
}
/**
- * Close and remove all streams.
+ * Finishes all active streams with the given status, then starts stopping the transport.
*/
- private void closeAllStreams(Status status) {
- Collection<OkHttpClientStream> streamsCopy;
- synchronized (streams) {
- streamsCopy = streams.values();
- streams.clear();
+ void abort(Status status) {
+ // -1 is lower than any stream id, so every active stream is finished.
+ onGoAway(-1, status);
+ }
+
+ /**
+ * Enters the go-away state. Streams with an id greater than {@code lastKnownStreamId} are removed
+ * and finished with {@code status}; streams at or below it may still complete; streams created
+ * from now on are finished immediately with the current go-away status.
+ */
+ private void onGoAway(int lastKnownStreamId, Status status) {
+ ArrayList<OkHttpClientStream> goAwayStreams = new ArrayList<OkHttpClientStream>();
+ synchronized (lock) {
+ goAway = true;
+ goAwayStatus = status;
+ // streams is a Collections.synchronizedMap: iterating one of its views requires holding the
+ // map's own monitor, otherwise modifications made without `lock` can break the iteration.
+ synchronized (streams) {
+ Iterator<Map.Entry<Integer, OkHttpClientStream>> it = streams.entrySet().iterator();
+ while (it.hasNext()) {
+ Map.Entry<Integer, OkHttpClientStream> entry = it.next();
+ if (entry.getKey() > lastKnownStreamId) {
+ goAwayStreams.add(entry.getValue());
+ it.remove();
+ }
+ }
+ }
}
- for (OkHttpClientStream stream : streamsCopy) {
+
+ // Enter the STOPPING state so the Channel knows this transport must not be used any further;
+ // it becomes STOPPED once all remaining streams are complete.
+ stopAsync();
+
+ // Finish the removed streams only after releasing the lock.
+ for (OkHttpClientStream stream : goAwayStreams) {
stream.setStatus(status);
}
}
/**
- * Called when a HTTP2 stream is closed.
+ * Called when a stream is closed.
*
* <p> Return false if the stream has already finished.
*/
@@ -159,10 +213,39 @@
}
/**
+ * Stops the transport once it is going away and no active streams remain. It closes the frame
+ * writer and reader, shuts down the executor and reports the transport as stopped.
+ *
+ * <p>NOTE(review): two threads can both observe the final condition (e.g. doStop racing with the
+ * last stream finishing) and both run the teardown; consider recording that teardown has begun.
+ */
+ private void stopIfNecessary() {
+ boolean shouldStop;
+ synchronized (lock) {
+ shouldStop = goAway && streams.isEmpty();
+ }
+ if (shouldStop) {
+ try {
+ frameWriter.close();
+ frameReader.close();
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ } finally {
+ // Release the executor's threads even if closing the reader fails.
+ executor.shutdown();
+ }
+ notifyStopped();
+ }
+ }
+
+ /**
+ * Returns the gRPC status corresponding to the given HTTP/2 ErrorCode.
+ *
+ * <p>Codes without an explicit mapping (including null) yield an INTERNAL status instead of null.
+ * This means callers such as rstStream never hand finishStream a null status.
+ */
+ @VisibleForTesting
+ static Status toGrpcStatus(ErrorCode code) {
+ Status status = ERROR_CODE_TO_STATUS.get(code);
+ if (status == null) {
+ status = new Status(Code.INTERNAL, "Unknown http2 error code: " + code);
+ }
+ return status;
+ }
+
+ /**
* Runnable which reads frames and dispatches them to in flight calls
*/
- private class ClientFrameHandler implements FrameReader.Handler, Runnable {
- private ClientFrameHandler() {}
+ @VisibleForTesting
+ class ClientFrameHandler implements FrameReader.Handler, Runnable {
+ ClientFrameHandler() {}
@Override
public void run() {
@@ -173,8 +256,7 @@
while (frameReader.nextFrame(this)) {
}
} catch (IOException ioe) {
- ioe.printStackTrace();
- closeAllStreams(new Status(Code.INTERNAL, ioe.getMessage()));
+ abort(Status.fromThrowable(ioe));
} finally {
// Restore the original thread name.
Thread.currentThread().setName(threadName);
@@ -210,7 +292,9 @@
stream.unacknowledgedBytesRead = 0;
}
if (inFinished) {
- finishStream(streamId, Status.OK);
+ if (finishStream(streamId, Status.OK)) {
+ stopIfNecessary();
+ }
}
}
@@ -229,7 +313,9 @@
@Override
public void rstStream(int streamId, ErrorCode errorCode) {
- finishStream(streamId, ERROR_CODE_TO_STATUS.get(errorCode));
+ // The peer reset the stream: finish it with the mapped status. If that removed a stream (false
+ // means it had already finished), the transport may now be able to stop.
+ if (finishStream(streamId, toGrpcStatus(errorCode))) {
+ stopIfNecessary();
+ }
}
@Override
@@ -252,18 +338,14 @@
@Override
public void goAway(int lastGoodStreamId, ErrorCode errorCode, ByteString debugData) {
- // TODO(user): Log here and implement the real Go away behavior: streams have
- // id <= lastGoodStreamId should not be closed.
- closeAllStreams(new Status(Code.UNAVAILABLE, "Go away"));
- stopAsync();
+ // Streams with id <= lastGoodStreamId may still complete; higher ones are failed with
+ // UNAVAILABLE. NOTE(review): errorCode and debugData are currently not reported anywhere.
+ onGoAway(lastGoodStreamId, new Status(Code.UNAVAILABLE, "Go away"));
}
@Override
public void pushPromise(int streamId, int promisedStreamId, List<Header> requestHeaders)
throws IOException {
- // TODO(user): should send SETTINGS_ENABLE_PUSH=0, then here we should reset it with
- // PROTOCOL_ERROR.
- frameWriter.rstStream(streamId, ErrorCode.REFUSED_STREAM);
+ // Server push is not supported, so reset the promised (server-initiated) stream. streamId is the
+ // client's own associated stream and must stay open. Resetting it would kill the in-flight call
+ // on the wire while its local stream is never finished.
+ // NOTE(review): RFC 7540 section 8.2.2 suggests CANCEL or REFUSED_STREAM for refusing a push
+ // unless SETTINGS_ENABLE_PUSH=0 has been sent; confirm the intended error code.
+ frameWriter.rstStream(promisedStreamId, ErrorCode.PROTOCOL_ERROR);
}
@Override
@@ -284,28 +366,42 @@
}
}
+ /**
+ * Gives {@code stream} the next client stream id and registers it as active. When the id space is
+ * nearly used up, the transport enters the go-away state right after this assignment.
+ */
+ @GuardedBy("lock")
+ private void assignStreamId(OkHttpClientStream stream) {
+ Preconditions.checkState(stream.streamId == 0, "StreamId already assigned");
+ int id = nextStreamId;
+ stream.streamId = id;
+ streams.put(id, stream);
+ if (id < Integer.MAX_VALUE - 2) {
+ // Client-initiated ids are odd, so step over the even (server) ids.
+ nextStreamId = id + 2;
+ return;
+ }
+ // Stepping further would approach int overflow. No stream id exceeds Integer.MAX_VALUE, so this
+ // stream is kept while new streams are refused.
+ onGoAway(Integer.MAX_VALUE, new Status(Code.INTERNAL, "Stream id exhaust"));
+ }
+
/**
* Client stream for the okhttp transport.
*/
- private class OkHttpClientStream extends AbstractStream implements ClientStream {
+ @VisibleForTesting
+ class OkHttpClientStream extends AbstractStream implements ClientStream {
+ // 0 until assignStreamId runs; stays 0 for streams created after the transport went away.
int streamId;
final InputStreamDeframer deframer;
+ // Bytes read on this stream that have not been acknowledged yet (reset in the data handler).
int unacknowledgedBytesRead;
- public OkHttpClientStream(MethodDescriptor<?, ?> method, StreamListener listener) {
+ OkHttpClientStream(MethodDescriptor<?, ?> method, StreamListener listener) {
super(listener);
- Preconditions.checkState(streamId == 0, "StreamId should be 0");
- synchronized (OkHttpClientTransport.this) {
- streamId = nextStreamId;
- nextStreamId += 2;
- streams.put(streamId, this);
- frameWriter.synStream(false, false, streamId, 0,
- Headers.createRequestHeaders(method.getName()));
- }
deframer = new InputStreamDeframer(inboundMessageHandler());
+ synchronized (lock) {
+ if (goAway) {
+ // The transport is going away: fail the stream at once. It never gets an id.
+ setStatus(goAwayStatus);
+ return;
+ }
+ assignStreamId(this);
+ // Issue SYN_STREAM while still holding the lock so streams are opened in increasing id
+ // order. HTTP/2 requires each new stream id to exceed all ids already used.
+ // NOTE(review): AsyncFrameWriter must also preserve submission order for this to hold on the wire.
+ frameWriter.synStream(false, false, streamId, 0,
+ Headers.createRequestHeaders(method.getName()));
+ }
}
- public InputStreamDeframer getDeframer() {
+ /** Returns the deframer that feeds this stream's inbound message handler. */
+ InputStreamDeframer getDeframer() {
return deframer;
}
@@ -330,8 +426,9 @@
public void cancel() {
- Preconditions.checkState(streamId != 0, "streamId should be set");
outboundPhase = Phase.STATUS;
- if (finishStream(streamId, ERROR_CODE_TO_STATUS.get(ErrorCode.CANCEL))) {
+ if (streamId == 0) {
+ // The stream was created after the transport went away. It was never assigned an id or sent
+ // to the peer, and its status has already been set, so there is nothing to cancel.
+ return;
+ }
+ // finishStream returns false if the stream already finished; only then is RST skipped.
+ if (finishStream(streamId, toGrpcStatus(ErrorCode.CANCEL))) {
frameWriter.rstStream(streamId, ErrorCode.CANCEL);
+ stopIfNecessary();
}
}
}