Adding crude outbound flow control to OkHttp transport.
-------------
Created by MOE: http://code.google.com/p/moe-java
MOE_MIGRATED_REVID=80390743
diff --git a/integration-testing/src/main/java/com/google/net/stubby/testing/integration/AbstractTransportTest.java b/integration-testing/src/main/java/com/google/net/stubby/testing/integration/AbstractTransportTest.java
index 6239ca2..effdaf9 100644
--- a/integration-testing/src/main/java/com/google/net/stubby/testing/integration/AbstractTransportTest.java
+++ b/integration-testing/src/main/java/com/google/net/stubby/testing/integration/AbstractTransportTest.java
@@ -3,7 +3,6 @@
import static com.google.net.stubby.testing.integration.Messages.PayloadType.COMPRESSABLE;
import static com.google.net.stubby.testing.integration.Util.assertEquals;
import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.fail;
import com.google.common.base.Throwables;
@@ -41,8 +40,8 @@
import org.junit.Before;
import org.junit.Test;
-import java.util.Arrays;
import java.util.ArrayList;
+import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
diff --git a/integration-testing/src/test/java/com/google/net/stubby/testing/integration/Http2OkHttpTest.java b/integration-testing/src/test/java/com/google/net/stubby/testing/integration/Http2OkHttpTest.java
index de8978d..c0230d9 100644
--- a/integration-testing/src/test/java/com/google/net/stubby/testing/integration/Http2OkHttpTest.java
+++ b/integration-testing/src/test/java/com/google/net/stubby/testing/integration/Http2OkHttpTest.java
@@ -1,7 +1,5 @@
package com.google.net.stubby.testing.integration;
-import static org.junit.Assume.assumeTrue;
-
import com.google.net.stubby.ChannelImpl;
import com.google.net.stubby.transport.AbstractStream;
import com.google.net.stubby.transport.netty.NettyServerBuilder;
@@ -35,10 +33,4 @@
protected ChannelImpl createChannel() {
return OkHttpChannelBuilder.forAddress("127.0.0.1", serverPort).build();
}
-
- @Override
- public void clientStreaming() {
- // TODO(user): Broken. We assume due to flow control bugs.
- assumeTrue(false);
- }
}
diff --git a/okhttp/src/main/java/com/google/net/stubby/transport/okhttp/OkHttpClientStream.java b/okhttp/src/main/java/com/google/net/stubby/transport/okhttp/OkHttpClientStream.java
index 29b373c..3ca82ff 100644
--- a/okhttp/src/main/java/com/google/net/stubby/transport/okhttp/OkHttpClientStream.java
+++ b/okhttp/src/main/java/com/google/net/stubby/transport/okhttp/OkHttpClientStream.java
@@ -28,7 +28,8 @@
*/
static OkHttpClientStream newStream(final Executor executor, ClientStreamListener listener,
AsyncFrameWriter frameWriter,
- OkHttpClientTransport transport) {
+ OkHttpClientTransport transport,
+ OutboundFlowController outboundFlow) {
// Create a lock object that can be used by both the executor and methods in the stream
// to ensure consistent locking behavior.
final Object executorLock = new Object();
@@ -46,7 +47,7 @@
}
};
return new OkHttpClientStream(synchronizingExecutor, listener, frameWriter, transport,
- executorLock);
+ executorLock, outboundFlow);
}
@GuardedBy("executorLock")
@@ -54,15 +55,18 @@
@GuardedBy("executorLock")
private int processedWindow = OkHttpClientTransport.DEFAULT_INITIAL_WINDOW_SIZE;
private final AsyncFrameWriter frameWriter;
+ private final OutboundFlowController outboundFlow;
private final OkHttpClientTransport transport;
// Lock used to synchronize with work done on the executor.
private final Object executorLock;
+ private Object outboundFlowState;
private OkHttpClientStream(final Executor executor,
final ClientStreamListener listener,
AsyncFrameWriter frameWriter,
OkHttpClientTransport transport,
- Object executorLock) {
+ Object executorLock,
+ OutboundFlowController outboundFlow) {
super(listener, null, executor);
if (!GRPC_V2_PROTOCOL) {
throw new RuntimeException("okhttp transport can only work with V2 protocol!");
@@ -70,6 +74,7 @@
this.frameWriter = frameWriter;
this.transport = transport;
this.executorLock = executorLock;
+ this.outboundFlow = outboundFlow;
}
public void transportHeadersReceived(List<Header> headers, boolean endOfStream) {
@@ -105,8 +110,7 @@
// Per http2 SPEC, the max data length should be larger than 64K, while our frame size is
// only 4K.
Preconditions.checkState(buffer.size() < frameWriter.maxDataLength());
- frameWriter.data(endOfStream, id(), buffer, (int) buffer.size());
- frameWriter.flush();
+ outboundFlow.data(endOfStream, id(), buffer);
}
@Override
@@ -144,4 +148,12 @@
transport.stopIfNecessary();
}
}
+
+ void setOutboundFlowState(Object outboundFlowState) {
+ this.outboundFlowState = outboundFlowState;
+ }
+
+ Object getOutboundFlowState() {
+ return outboundFlowState;
+ }
}
diff --git a/okhttp/src/main/java/com/google/net/stubby/transport/okhttp/OkHttpClientTransport.java b/okhttp/src/main/java/com/google/net/stubby/transport/okhttp/OkHttpClientTransport.java
index be7bb27..d0a9050 100644
--- a/okhttp/src/main/java/com/google/net/stubby/transport/okhttp/OkHttpClientTransport.java
+++ b/okhttp/src/main/java/com/google/net/stubby/transport/okhttp/OkHttpClientTransport.java
@@ -82,6 +82,7 @@
private final String defaultAuthority;
private FrameReader frameReader;
private AsyncFrameWriter frameWriter;
+ private OutboundFlowController outboundFlow;
private final Object lock = new Object();
@GuardedBy("lock")
private int nextStreamId;
@@ -118,6 +119,7 @@
this.executor = Preconditions.checkNotNull(executor);
this.frameReader = Preconditions.checkNotNull(frameReader);
this.frameWriter = Preconditions.checkNotNull(frameWriter);
+ this.outboundFlow = new OutboundFlowController(this, frameWriter);
this.nextStreamId = nextStreamId;
}
@@ -126,7 +128,7 @@
Metadata.Headers headers,
ClientStreamListener listener) {
OkHttpClientStream clientStream = OkHttpClientStream.newStream(executor, listener,
- frameWriter, this);
+ frameWriter, this, outboundFlow);
if (goAway) {
clientStream.setStatus(goAwayStatus, new Metadata.Trailers());
} else {
@@ -154,6 +156,7 @@
Variant variant = new Http20Draft14();
frameReader = variant.newReader(source, true);
frameWriter = new AsyncFrameWriter(variant.newWriter(sink, true), this, executor);
+ outboundFlow = new OutboundFlowController(this, frameWriter);
frameWriter.connectionPreface();
Settings settings = new Settings();
frameWriter.settings(settings);
@@ -185,7 +188,6 @@
return clientFrameHandler;
}
- @VisibleForTesting
Map<Integer, OkHttpClientStream> getStreams() {
return streams;
}
@@ -395,8 +397,8 @@
}
@Override
- public void windowUpdate(int arg0, long arg1) {
- // TODO(user): outbound flow control.
+ public void windowUpdate(int streamId, long delta) {
+ outboundFlow.windowUpdate(streamId, (int) delta);
}
@Override
diff --git a/okhttp/src/main/java/com/google/net/stubby/transport/okhttp/OutboundFlowController.java b/okhttp/src/main/java/com/google/net/stubby/transport/okhttp/OutboundFlowController.java
new file mode 100644
index 0000000..a1881d0
--- /dev/null
+++ b/okhttp/src/main/java/com/google/net/stubby/transport/okhttp/OutboundFlowController.java
@@ -0,0 +1,399 @@
+package com.google.net.stubby.transport.okhttp;
+
+import static com.google.net.stubby.transport.okhttp.Utils.CONNECTION_STREAM_ID;
+import static com.google.net.stubby.transport.okhttp.Utils.DEFAULT_WINDOW_SIZE;
+import static com.google.net.stubby.transport.okhttp.Utils.MAX_FRAME_SIZE;
+import static java.lang.Math.ceil;
+import static java.lang.Math.max;
+import static java.lang.Math.min;
+
+import com.google.common.base.Preconditions;
+
+import com.squareup.okhttp.internal.spdy.FrameWriter;
+
+import okio.Buffer;
+
+import java.io.IOException;
+import java.util.ArrayDeque;
+import java.util.Queue;
+
+/**
+ * Simple outbound flow controller that evenly splits the connection window across all existing
+ * streams.
+ */
+class OutboundFlowController {
+ private static final OkHttpClientStream[] EMPTY_STREAM_ARRAY = new OkHttpClientStream[0];
+ private final OkHttpClientTransport transport;
+ private final FrameWriter frameWriter;
+ private int initialWindowSize = DEFAULT_WINDOW_SIZE;
+ private final OutboundFlowState connectionState = new OutboundFlowState(CONNECTION_STREAM_ID);
+
+ OutboundFlowController(OkHttpClientTransport transport, FrameWriter frameWriter) {
+ this.transport = Preconditions.checkNotNull(transport, "transport");
+ this.frameWriter = Preconditions.checkNotNull(frameWriter, "frameWriter");
+ }
+
+ synchronized void initialOutboundWindowSize(int newWindowSize) {
+ if (newWindowSize < 0) {
+ throw new IllegalArgumentException("Invalid initial window size: " + newWindowSize);
+ }
+
+ int delta = newWindowSize - initialWindowSize;
+ initialWindowSize = newWindowSize;
+ for (OkHttpClientStream stream : getActiveStreams()) {
+ // Verify that the maximum value is not exceeded by this change.
+ OutboundFlowState state = state(stream);
+ state.incrementStreamWindow(delta);
+ }
+
+ if (delta > 0) {
+ // The window size increased, send any pending frames for all streams.
+ writeStreams();
+ }
+ }
+
+ synchronized void windowUpdate(int streamId, int delta) {
+ if (streamId == CONNECTION_STREAM_ID) {
+ // Update the connection window and write any pending frames for all streams.
+ connectionState.incrementStreamWindow(delta);
+ writeStreams();
+ } else {
+ // Update the stream window and write any pending frames for the stream.
+ OutboundFlowState state = stateOrFail(streamId);
+ state.incrementStreamWindow(delta);
+
+ WriteStatus writeStatus = new WriteStatus();
+ state.writeBytes(state.writableWindow(), writeStatus);
+ if (writeStatus.hasWritten()) {
+ flush();
+ }
+ }
+ }
+
+ synchronized void data(boolean outFinished, int streamId, Buffer source) {
+ Preconditions.checkNotNull(source, "source");
+ if (streamId <= 0) {
+ throw new IllegalArgumentException("streamId must be > 0");
+ }
+
+ OutboundFlowState state = stateOrFail(streamId);
+ int window = state.writableWindow();
+ boolean framesAlreadyQueued = state.hasFrame();
+
+ OutboundFlowState.Frame frame = state.newFrame(source, outFinished);
+ if (!framesAlreadyQueued && window >= frame.size()) {
+ // Window size is large enough to send entire data frame
+ frame.write();
+ flush();
+ return;
+ }
+
+ // Enqueue the frame to be written when the window size permits.
+ frame.enqueue();
+
+ if (framesAlreadyQueued || window <= 0) {
+ // Stream already has frames pending or is stalled, don't send anything now.
+ return;
+ }
+
+ // Create and send a partial frame up to the window size.
+ frame.split(window).write();
+ flush();
+ }
+
+ private void flush() {
+ try {
+ frameWriter.flush();
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ private OutboundFlowState state(OkHttpClientStream stream) {
+ OutboundFlowState state = (OutboundFlowState) stream.getOutboundFlowState();
+ if (state == null) {
+ state = new OutboundFlowState(stream.id());
+ stream.setOutboundFlowState(state);
+ }
+ return state;
+ }
+
+ private OutboundFlowState state(int streamId) {
+ OkHttpClientStream stream = transport.getStreams().get(streamId);
+ return stream != null ? state(stream) : null;
+ }
+
+ private OutboundFlowState stateOrFail(int streamId) {
+ OutboundFlowState state = state(streamId);
+ if (state == null) {
+ throw new RuntimeException("Missing flow control window for stream: " + streamId);
+ }
+ return state;
+ }
+
+ /**
+ * Gets all active streams as an array.
+ */
+ private OkHttpClientStream[] getActiveStreams() {
+ return transport.getStreams().values().toArray(EMPTY_STREAM_ARRAY);
+ }
+
+ /**
+ * Writes as much data for all the streams as possible given the current flow control windows.
+ */
+ private void writeStreams() {
+ OkHttpClientStream[] streams = getActiveStreams();
+ int connectionWindow = connectionState.window();
+ for (int numStreams = streams.length; numStreams > 0 && connectionWindow > 0;) {
+ int nextNumStreams = 0;
+ int windowSlice = (int) ceil(connectionWindow / (float) numStreams);
+ for (int index = 0; index < numStreams && connectionWindow > 0; ++index) {
+ OkHttpClientStream stream = streams[index];
+ OutboundFlowState state = state(stream);
+
+ int bytesForStream = min(connectionWindow, min(state.unallocatedBytes(), windowSlice));
+ if (bytesForStream > 0) {
+ state.allocateBytes(bytesForStream);
+ connectionWindow -= bytesForStream;
+ }
+
+ if (state.unallocatedBytes() > 0) {
+ // There is more data to process for this stream. Add it to the next
+ // pass.
+ streams[nextNumStreams++] = stream;
+ }
+ }
+ numStreams = nextNumStreams;
+ }
+
+ // Now take one last pass through all of the streams and write any allocated bytes.
+ WriteStatus writeStatus = new WriteStatus();
+ for (OkHttpClientStream stream : getActiveStreams()) {
+ OutboundFlowState state = state(stream);
+ state.writeBytes(state.allocatedBytes(), writeStatus);
+ state.clearAllocatedBytes();
+ }
+
+ if (writeStatus.hasWritten()) {
+ flush();
+ }
+ }
+
+ /**
+ * Simple status that keeps track of the number of writes performed.
+ */
+ private final class WriteStatus {
+ int numWrites;
+
+ void incrementNumWrites() {
+ numWrites++;
+ }
+
+ boolean hasWritten() {
+ return numWrites > 0;
+ }
+ }
+
+ /**
+ * The outbound flow control state for a single stream.
+ */
+ private final class OutboundFlowState {
+ final Queue<Frame> pendingWriteQueue;
+ final int streamId;
+ int queuedBytes;
+ int window = initialWindowSize;
+ int allocatedBytes;
+
+ OutboundFlowState(int streamId) {
+ this.streamId = streamId;
+ pendingWriteQueue = new ArrayDeque<Frame>(2);
+ }
+
+ int window() {
+ return window;
+ }
+
+ void allocateBytes(int bytes) {
+ allocatedBytes += bytes;
+ }
+
+ int allocatedBytes() {
+ return allocatedBytes;
+ }
+
+ int unallocatedBytes() {
+ return streamableBytes() - allocatedBytes;
+ }
+
+ void clearAllocatedBytes() {
+ allocatedBytes = 0;
+ }
+
+ /**
+ * Increments the flow control window for this stream by the given delta and returns the new
+ * value.
+ */
+ int incrementStreamWindow(int delta) {
+ if (delta > 0 && Integer.MAX_VALUE - delta < window) {
+ throw new IllegalArgumentException("Window size overflow for stream: " + streamId);
+ }
+ window += delta;
+
+ return window;
+ }
+
+ /**
+ * Returns the maximum writable window (minimum of the stream and connection windows).
+ */
+ int writableWindow() {
+ return min(window, connectionState.window());
+ }
+
+ int streamableBytes() {
+ return max(0, min(window, queuedBytes));
+ }
+
+ /**
+ * Creates a new frame with the given values but does not add it to the pending queue.
+ */
+ Frame newFrame(Buffer data, boolean endStream) {
+ return new Frame(data, endStream);
+ }
+
+ /**
+ * Indicates whether or not there are frames in the pending queue.
+ */
+ boolean hasFrame() {
+ return !pendingWriteQueue.isEmpty();
+ }
+
+ /**
+ * Returns the the head of the pending queue, or {@code null} if empty.
+ */
+ private Frame peek() {
+ return pendingWriteQueue.peek();
+ }
+
+ /**
+ * Writes up to the number of bytes from the pending queue.
+ */
+ int writeBytes(int bytes, WriteStatus writeStatus) {
+ int bytesAttempted = 0;
+ int maxBytes = min(bytes, writableWindow());
+ while (hasFrame()) {
+ Frame pendingWrite = peek();
+ if (maxBytes >= pendingWrite.size()) {
+ // Window size is large enough to send entire data frame
+ writeStatus.incrementNumWrites();
+ bytesAttempted += pendingWrite.size();
+ pendingWrite.write();
+ } else if (maxBytes <= 0) {
+ // No data from the current frame can be written - we're done.
+ // We purposely check this after first testing the size of the
+ // pending frame to properly handle zero-length frame.
+ break;
+ } else {
+ // We can send a partial frame
+ Frame partialFrame = pendingWrite.split(maxBytes);
+ writeStatus.incrementNumWrites();
+ bytesAttempted += partialFrame.size();
+ partialFrame.write();
+ }
+
+ // Update the threshold.
+ maxBytes = min(bytes - bytesAttempted, writableWindow());
+ }
+ return bytesAttempted;
+ }
+
+ /**
+ * A wrapper class around the content of a data frame.
+ */
+ private final class Frame {
+ final Buffer data;
+ final boolean endStream;
+ boolean enqueued;
+
+ Frame(Buffer data, boolean endStream) {
+ this.data = data;
+ this.endStream = endStream;
+ }
+
+ /**
+ * Gets the total size (in bytes) of this frame including the data and padding.
+ */
+ int size() {
+ return (int) data.size();
+ }
+
+ void enqueue() {
+ if (!enqueued) {
+ enqueued = true;
+ pendingWriteQueue.offer(this);
+
+ // Increment the number of pending bytes for this stream.
+ queuedBytes += size();
+ }
+ }
+
+ /**
+ * Writes the frame and decrements the stream and connection window sizes. If the frame is in
+ * the pending queue, the written bytes are removed from this branch of the priority tree.
+ */
+ void write() {
+ // Using a do/while loop because if the buffer is empty we still need to call
+ // the writer once to send the empty frame.
+ do {
+ int bytesToWrite = size();
+ int frameBytes = min(bytesToWrite, MAX_FRAME_SIZE);
+ if (frameBytes == bytesToWrite) {
+ // All the bytes fit into a single HTTP/2 frame, just send it all.
+ connectionState.incrementStreamWindow(-bytesToWrite);
+ incrementStreamWindow(-bytesToWrite);
+ try {
+ frameWriter.data(endStream, streamId, data, bytesToWrite);
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ if (enqueued) {
+ // It's enqueued - remove it from the head of the pending write queue.
+ queuedBytes -= bytesToWrite;
+ pendingWriteQueue.remove(this);
+ }
+ return;
+ }
+
+ // Split a chunk that will fit into a single HTTP/2 frame and write it.
+ Frame frame = split(frameBytes);
+ frame.write();
+ } while (size() > 0);
+ }
+
+ /**
+ * Creates a new frame that is a view of this frame's data. The {@code maxBytes} are first
+ * split from the data buffer. If not all the requested bytes are available, the remaining
+ * bytes are then split from the padding (if available).
+ *
+ * @param maxBytes the maximum number of bytes that is allowed in the created frame.
+ * @return the partial frame.
+ */
+ Frame split(int maxBytes) {
+ // The requested maxBytes should always be less than the size of this frame.
+ assert maxBytes < size() : "Attempting to split a frame for the full size.";
+
+ // Get the portion of the data buffer to be split. Limit to the readable bytes.
+ int dataSplit = min(maxBytes, (int) data.size());
+
+ Buffer splitSlice = new Buffer();
+ splitSlice.write(data, dataSplit);
+
+ Frame frame = new Frame(splitSlice, false);
+
+ if (enqueued) {
+ queuedBytes -= dataSplit;
+ }
+ return frame;
+ }
+ }
+ }
+}
diff --git a/okhttp/src/main/java/com/google/net/stubby/transport/okhttp/Utils.java b/okhttp/src/main/java/com/google/net/stubby/transport/okhttp/Utils.java
index 7293b97..6a42970 100644
--- a/okhttp/src/main/java/com/google/net/stubby/transport/okhttp/Utils.java
+++ b/okhttp/src/main/java/com/google/net/stubby/transport/okhttp/Utils.java
@@ -10,6 +10,10 @@
* Common utility methods for OkHttp transport.
*/
class Utils {
+ static final int DEFAULT_WINDOW_SIZE = 65535;
+ static final int CONNECTION_STREAM_ID = 0;
+ static final int MAX_FRAME_SIZE = 16384;
+
public static Metadata.Headers convertHeaders(List<Header> http2Headers) {
return new Metadata.Headers(convertHeadersToArray(http2Headers));
}