Improve okhttp client transport, handles go away and add unit test.

-------------
Created by MOE: http://code.google.com/p/moe-java
MOE_MIGRATED_REVID=72155172
diff --git a/core/src/main/java/com/google/net/stubby/newtransport/okhttp/AsyncFrameWriter.java b/core/src/main/java/com/google/net/stubby/newtransport/okhttp/AsyncFrameWriter.java
index d5ac86c..6affe1f 100644
--- a/core/src/main/java/com/google/net/stubby/newtransport/okhttp/AsyncFrameWriter.java
+++ b/core/src/main/java/com/google/net/stubby/newtransport/okhttp/AsyncFrameWriter.java
@@ -1,7 +1,8 @@
 package com.google.net.stubby.newtransport.okhttp;
 
 import com.google.common.util.concurrent.SerializingExecutor;
-import com.google.common.util.concurrent.Service;
+import com.google.net.stubby.Status;
+import com.google.net.stubby.transport.Transport.Code;
 
 import com.squareup.okhttp.internal.spdy.ErrorCode;
 import com.squareup.okhttp.internal.spdy.FrameWriter;
@@ -17,9 +18,10 @@
 class AsyncFrameWriter implements FrameWriter {
   private final FrameWriter frameWriter;
   private final Executor executor;
-  private final Service transport;
+  private final OkHttpClientTransport transport;
 
-  public AsyncFrameWriter(FrameWriter frameWriter, Service transport, Executor executor) {
+  public AsyncFrameWriter(FrameWriter frameWriter, OkHttpClientTransport transport,
+      Executor executor) {
     this.frameWriter = frameWriter;
     this.transport = transport;
     // Although writes are thread-safe, we serialize them to prevent consuming many Threads that are
@@ -158,6 +160,8 @@
       @Override
       public void doRun() throws IOException {
         frameWriter.goAway(lastGoodStreamId, errorCode, debugData);
+        // Flush it since after goAway, we are likely to close this writer.
+        frameWriter.flush();
       }
     });
   }
@@ -188,7 +192,7 @@
       try {
         doRun();
       } catch (IOException ex) {
-        transport.stopAsync();
+        transport.abort(Status.fromThrowable(ex));
         throw new RuntimeException(ex);
       }
     }
diff --git a/core/src/main/java/com/google/net/stubby/newtransport/okhttp/OkHttpClientTransport.java b/core/src/main/java/com/google/net/stubby/newtransport/okhttp/OkHttpClientTransport.java
index fc974e0..b4e9f68 100644
--- a/core/src/main/java/com/google/net/stubby/newtransport/okhttp/OkHttpClientTransport.java
+++ b/core/src/main/java/com/google/net/stubby/newtransport/okhttp/OkHttpClientTransport.java
@@ -1,5 +1,6 @@
 package com.google.net.stubby.newtransport.okhttp;
 
+import com.google.common.annotations.VisibleForTesting;
 import com.google.common.base.Preconditions;
 import com.google.common.collect.ImmutableMap;
 import com.google.common.io.ByteBuffers;
@@ -33,9 +34,10 @@
 import java.io.IOException;
 import java.net.Socket;
 import java.nio.ByteBuffer;
-import java.util.Collection;
+import java.util.ArrayList;
 import java.util.Collections;
 import java.util.HashMap;
+import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
 import java.util.concurrent.ExecutorService;
@@ -48,6 +50,7 @@
  */
 public class OkHttpClientTransport extends AbstractClientTransport {
   /** The default initial window size in HTTP/2 is 64 KiB for the stream and connection. */
+  @VisibleForTesting
   static final int DEFAULT_INITIAL_WINDOW_SIZE = 64 * 1024;
 
   private static final ImmutableMap<ErrorCode, Status> ERROR_CODE_TO_STATUS = ImmutableMap
@@ -75,21 +78,40 @@
   private final int port;
   private FrameReader frameReader;
   private AsyncFrameWriter frameWriter;
-  @GuardedBy("this")
+  private Object lock = new Object();
+  @GuardedBy("lock")
   private int nextStreamId;
   private final Map<Integer, OkHttpClientStream> streams =
       Collections.synchronizedMap(new HashMap<Integer, OkHttpClientStream>());
   private final ExecutorService executor = Executors.newCachedThreadPool();
   private int unacknowledgedBytesRead;
+  private ClientFrameHandler clientFrameHandler;
+  // The status used to finish all active streams when the transport is closed.
+  @GuardedBy("lock")
+  private boolean goAway;
+  @GuardedBy("lock")
+  private Status goAwayStatus;
 
   public OkHttpClientTransport(String host, int port) {
-    this.host = host;
+    this.host = Preconditions.checkNotNull(host);
     this.port = port;
     // Client initiated streams are odd, server initiated ones are even. Server should not need to
     // use it. We start clients at 3 to avoid conflicting with HTTP negotiation.
     nextStreamId = 3;
   }
 
+  /**
+   * Create a transport connected to a fake peer for test.
+   */
+  @VisibleForTesting
+  OkHttpClientTransport(FrameReader frameReader, AsyncFrameWriter frameWriter, int nextStreamId) {
+    host = null;
+    port = -1;
+    this.nextStreamId = nextStreamId;
+    this.frameReader = frameReader;
+    this.frameWriter = frameWriter;
+  }
+
   @Override
   protected ClientStream newStreamInternal(MethodDescriptor<?, ?> method, StreamListener listener) {
     return new OkHttpClientStream(method, listener);
@@ -97,53 +119,85 @@
 
   @Override
   protected void doStart() {
-    BufferedSource source;
-    BufferedSink sink;
-    try {
-      Socket socket = new Socket(host, port);
-      // TODO(user): use SpdyConnection.
-      source = Okio.buffer(Okio.source(socket));
-      sink = Okio.buffer(Okio.sink(socket));
-    } catch (IOException e) {
-      throw new RuntimeException(e);
+    // We set host to null for test.
+    if (host != null) {
+      BufferedSource source;
+      BufferedSink sink;
+      try {
+        Socket socket = new Socket(host, port);
+        source = Okio.buffer(Okio.source(socket));
+        sink = Okio.buffer(Okio.sink(socket));
+      } catch (IOException e) {
+        throw new RuntimeException(e);
+      }
+      Variant variant = new Http20Draft12();
+      frameReader = variant.newReader(source, true);
+      frameWriter = new AsyncFrameWriter(variant.newWriter(sink, true), this, executor);
     }
-    Variant variant = new Http20Draft12();
-    frameReader = variant.newReader(source, true);
-    frameWriter = new AsyncFrameWriter(variant.newWriter(sink, true), this, executor);
 
-    executor.execute(new ClientFrameHandler());
     notifyStarted();
+    clientFrameHandler = new ClientFrameHandler();
+    executor.execute(clientFrameHandler);
   }
 
   @Override
   protected void doStop() {
-    closeAllStreams(new Status(Code.INTERNAL, "Transport stopped"));
-    frameWriter.close();
-    try {
-      frameReader.close();
-    } catch (IOException e) {
-      throw new RuntimeException(e);
+    boolean normalClose;
+    synchronized (lock) {
+      normalClose = !goAway;
     }
-    executor.shutdown();
-    notifyStopped();
+    if (normalClose) {
+      abort(new Status(Code.INTERNAL, "Transport stopped"));
+      // Send GOAWAY with lastGoodStreamId of 0, since we don't expect any server-initiated streams.
+      // The GOAWAY is part of graceful shutdown.
+      frameWriter.goAway(0, ErrorCode.NO_ERROR, null);
+    }
+    stopIfNecessary();
+  }
+
+  @VisibleForTesting
+  ClientFrameHandler getHandler() {
+    return clientFrameHandler;
+  }
+
+  @VisibleForTesting
+  Map<Integer, OkHttpClientStream> getStreams() {
+    return streams;
   }
 
   /**
-   * Close and remove all streams.
+   * Finish all active streams with given status, then close the transport.
    */
-  private void closeAllStreams(Status status) {
-    Collection<OkHttpClientStream> streamsCopy;
-    synchronized (streams) {
-      streamsCopy = streams.values();
-      streams.clear();
+  void abort(Status status) {
+    onGoAway(-1, status);
+  }
+
+  private void onGoAway(int lastKnownStreamId, Status status) {
+    ArrayList<OkHttpClientStream> goAwayStreams = new ArrayList<OkHttpClientStream>();
+    synchronized (lock) {
+      goAway = true;
+      goAwayStatus = status;
+      Iterator<Map.Entry<Integer, OkHttpClientStream>> it = streams.entrySet().iterator();
+      while (it.hasNext()) {
+        Map.Entry<Integer, OkHttpClientStream> entry = it.next();
+        if (entry.getKey() > lastKnownStreamId) {
+          goAwayStreams.add(entry.getValue());
+          it.remove();
+        }
+      }
     }
-    for (OkHttpClientStream stream : streamsCopy) {
+
+    // Starting stop, go into STOPPING state so that Channel know this Transport should not be used
+    // further, will become STOPPED once all streams are complete.
+    stopAsync();
+
+    for (OkHttpClientStream stream : goAwayStreams) {
       stream.setStatus(status);
     }
   }
 
   /**
-   * Called when a HTTP2 stream is closed.
+   * Called when a stream is closed.
    *
    * <p> Return false if the stream has already finished.
    */
@@ -159,10 +213,39 @@
   }
 
   /**
+   * When the transport is in goAway states, we should stop it once all active streams finish.
+   */
+  private void stopIfNecessary() {
+    boolean shouldStop;
+    synchronized (lock) {
+      shouldStop = (goAway && streams.size() == 0);
+    }
+    if (shouldStop) {
+      frameWriter.close();
+      try {
+        frameReader.close();
+      } catch (IOException e) {
+        throw new RuntimeException(e);
+      }
+      executor.shutdown();
+      notifyStopped();
+    }
+  }
+
+  /**
+   * Returns a Grpc status corresponding to the given ErrorCode.
+   */
+  @VisibleForTesting
+  static Status toGrpcStatus(ErrorCode code) {
+    return ERROR_CODE_TO_STATUS.get(code);
+  }
+
+  /**
    * Runnable which reads frames and dispatches them to in flight calls
    */
-  private class ClientFrameHandler implements FrameReader.Handler, Runnable {
-    private ClientFrameHandler() {}
+  @VisibleForTesting
+  class ClientFrameHandler implements FrameReader.Handler, Runnable {
+    ClientFrameHandler() {}
 
     @Override
     public void run() {
@@ -173,8 +256,7 @@
         while (frameReader.nextFrame(this)) {
         }
       } catch (IOException ioe) {
-        ioe.printStackTrace();
-        closeAllStreams(new Status(Code.INTERNAL, ioe.getMessage()));
+        abort(Status.fromThrowable(ioe));
       } finally {
         // Restore the original thread name.
         Thread.currentThread().setName(threadName);
@@ -210,7 +292,9 @@
         stream.unacknowledgedBytesRead = 0;
       }
       if (inFinished) {
-        finishStream(streamId, Status.OK);
+        if (finishStream(streamId, Status.OK)) {
+          stopIfNecessary();
+        }
       }
     }
 
@@ -229,7 +313,9 @@
 
     @Override
     public void rstStream(int streamId, ErrorCode errorCode) {
-      finishStream(streamId, ERROR_CODE_TO_STATUS.get(errorCode));
+      if (finishStream(streamId, toGrpcStatus(errorCode))) {
+        stopIfNecessary();
+      }
     }
 
     @Override
@@ -252,18 +338,14 @@
 
     @Override
     public void goAway(int lastGoodStreamId, ErrorCode errorCode, ByteString debugData) {
-      // TODO(user): Log here and implement the real Go away behavior: streams have
-      // id <= lastGoodStreamId should not be closed.
-      closeAllStreams(new Status(Code.UNAVAILABLE, "Go away"));
-      stopAsync();
+      onGoAway(lastGoodStreamId, new Status(Code.UNAVAILABLE, "Go away"));
     }
 
     @Override
     public void pushPromise(int streamId, int promisedStreamId, List<Header> requestHeaders)
         throws IOException {
-      // TODO(user): should send SETTINGS_ENABLE_PUSH=0, then here we should reset it with
-      // PROTOCOL_ERROR.
-      frameWriter.rstStream(streamId, ErrorCode.REFUSED_STREAM);
+      // We don't accept server initiated stream.
+      frameWriter.rstStream(streamId, ErrorCode.PROTOCOL_ERROR);
     }
 
     @Override
@@ -284,28 +366,42 @@
     }
   }
 
+  @GuardedBy("lock")
+  private void assignStreamId(OkHttpClientStream stream) {
+    Preconditions.checkState(stream.streamId == 0, "StreamId already assigned");
+    stream.streamId = nextStreamId;
+    streams.put(stream.streamId, stream);
+    if (nextStreamId >= Integer.MAX_VALUE - 2) {
+      onGoAway(Integer.MAX_VALUE, new Status(Code.INTERNAL, "Stream id exhaust"));
+    } else {
+      nextStreamId += 2;
+    }
+  }
+
   /**
    * Client stream for the okhttp transport.
    */
-  private class OkHttpClientStream extends AbstractStream implements ClientStream {
+  @VisibleForTesting
+  class OkHttpClientStream extends AbstractStream implements ClientStream {
     int streamId;
     final InputStreamDeframer deframer;
     int unacknowledgedBytesRead;
 
-    public OkHttpClientStream(MethodDescriptor<?, ?> method, StreamListener listener) {
+    OkHttpClientStream(MethodDescriptor<?, ?> method, StreamListener listener) {
       super(listener);
-      Preconditions.checkState(streamId == 0, "StreamId should be 0");
-      synchronized (OkHttpClientTransport.this) {
-        streamId = nextStreamId;
-        nextStreamId += 2;
-        streams.put(streamId, this);
-        frameWriter.synStream(false, false, streamId, 0,
-            Headers.createRequestHeaders(method.getName()));
-      }
       deframer = new InputStreamDeframer(inboundMessageHandler());
+      synchronized (lock) {
+        if (goAway) {
+          setStatus(goAwayStatus);
+          return;
+        }
+        assignStreamId(this);
+      }
+      frameWriter.synStream(false, false, streamId, 0,
+          Headers.createRequestHeaders(method.getName()));
     }
 
-    public InputStreamDeframer getDeframer() {
+    InputStreamDeframer getDeframer() {
       return deframer;
     }
 
@@ -330,8 +426,9 @@
     public void cancel() {
       Preconditions.checkState(streamId != 0, "streamId should be set");
       outboundPhase = Phase.STATUS;
-      if (finishStream(streamId, ERROR_CODE_TO_STATUS.get(ErrorCode.CANCEL))) {
+      if (finishStream(streamId, toGrpcStatus(ErrorCode.CANCEL))) {
         frameWriter.rstStream(streamId, ErrorCode.CANCEL);
+        stopIfNecessary();
       }
     }
   }
diff --git a/core/src/test/java/com/google/net/stubby/newtransport/okhttp/OkHttpClientTransportTest.java b/core/src/test/java/com/google/net/stubby/newtransport/okhttp/OkHttpClientTransportTest.java
new file mode 100644
index 0000000..7c3d7fd
--- /dev/null
+++ b/core/src/test/java/com/google/net/stubby/newtransport/okhttp/OkHttpClientTransportTest.java
@@ -0,0 +1,555 @@
+package com.google.net.stubby.newtransport.okhttp;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+import static org.mockito.Mockito.any;
+import static org.mockito.Mockito.eq;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+import com.google.common.util.concurrent.ListenableFuture;
+import com.google.common.util.concurrent.Service;
+import com.google.net.stubby.MethodDescriptor;
+import com.google.net.stubby.Status;
+import com.google.net.stubby.newtransport.StreamListener;
+import com.google.net.stubby.newtransport.okhttp.OkHttpClientTransport.ClientFrameHandler;
+import com.google.net.stubby.newtransport.okhttp.OkHttpClientTransport.OkHttpClientStream;
+import com.google.net.stubby.transport.Transport;
+import com.google.net.stubby.transport.Transport.Code;
+import com.google.net.stubby.transport.Transport.ContextValue;
+import com.google.protobuf.ByteString;
+
+import com.squareup.okhttp.internal.spdy.ErrorCode;
+import com.squareup.okhttp.internal.spdy.FrameReader;
+
+import okio.Buffer;
+import okio.BufferedSource;
+
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+import org.mockito.ArgumentCaptor;
+import org.mockito.Mock;
+import org.mockito.MockitoAnnotations;
+
+import java.io.BufferedReader;
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.DataOutputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.InputStreamReader;
+import java.nio.charset.StandardCharsets;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.TimeUnit;
+
+/**
+ * Tests for {@link OkHttpClientTransport}.
+ */
+@RunWith(JUnit4.class)
+public class OkHttpClientTransportTest {
+  private static final int TIME_OUT_MS = 5000000;
+  private static final String NETWORK_ISSUE_MESSAGE = "network issue";
+
+  // Flags
+  private static final byte PAYLOAD_FRAME = 0x0;
+  public static final byte CONTEXT_VALUE_FRAME = 0x1;
+  public static final byte STATUS_FRAME = 0x3;
+
+  @Mock
+  private AsyncFrameWriter frameWriter;
+  @Mock
+  MethodDescriptor<?, ?> method;
+  private OkHttpClientTransport clientTransport;
+  private MockFrameReader frameReader;
+  private Map<Integer, OkHttpClientStream> streams;
+  private ClientFrameHandler frameHandler;
+
+  @Before
+  public void setup() {
+    MockitoAnnotations.initMocks(this);
+    streams = new HashMap<Integer, OkHttpClientStream>();
+    frameReader = new MockFrameReader();
+    clientTransport = new OkHttpClientTransport(frameReader, frameWriter, 3);
+    clientTransport.startAsync();
+    frameHandler = clientTransport.getHandler();
+    streams = clientTransport.getStreams();
+    when(method.getName()).thenReturn("fakemethod");
+  }
+
+  @After
+  public void tearDown() {
+    clientTransport.stopAsync();
+    assertTrue(frameReader.closed);
+    verify(frameWriter).close();
+  }
+
+  /**
+   * When nextFrame throws IOException, the transport should be aborted.
+   */
+  @Test
+  public void nextFrameThrowIOException() throws Exception {
+    MockStreamListener listener1 = new MockStreamListener();
+    MockStreamListener listener2 = new MockStreamListener();
+    clientTransport.newStream(method, listener1);
+    clientTransport.newStream(method, listener2);
+    assertEquals(2, streams.size());
+    assertTrue(streams.containsKey(3));
+    assertTrue(streams.containsKey(5));
+    frameReader.throwIOExceptionForNextFrame();
+    listener1.waitUntilStreamClosed();
+    listener2.waitUntilStreamClosed();
+    assertEquals(0, streams.size());
+    assertEquals(Code.INTERNAL, listener1.status.getCode());
+    assertEquals(NETWORK_ISSUE_MESSAGE, listener2.status.getCause().getMessage());
+    assertEquals(Code.INTERNAL, listener1.status.getCode());
+    assertEquals(NETWORK_ISSUE_MESSAGE, listener2.status.getCause().getMessage());
+    assertTrue("Service state: " + clientTransport.state(),
+        Service.State.TERMINATED == clientTransport.state());
+  }
+
+  @Test
+  public void readMessages() throws Exception {
+    final int numMessages = 10;
+    final String message = "Hello Client";
+    MockStreamListener listener = new MockStreamListener();
+    clientTransport.newStream(method, listener);
+    assertTrue(streams.containsKey(3));
+    for (int i = 0; i < numMessages; i++) {
+      BufferedSource source = mock(BufferedSource.class);
+      InputStream inputStream = createMessageFrame(message + i);
+      when(source.inputStream()).thenReturn(inputStream);
+      frameHandler.data(i == numMessages - 1 ? true : false, 3, source, inputStream.available());
+    }
+    listener.waitUntilStreamClosed();
+    assertEquals(Status.OK, listener.status);
+    assertEquals(numMessages, listener.messages.size());
+    for (int i = 0; i < numMessages; i++) {
+      assertEquals(message + i, listener.messages.get(i));
+    }
+  }
+
+  @Test
+  public void readContexts() throws Exception {
+    final int numContexts = 10;
+    final String key = "KEY";
+    final String value = "value";
+    MockStreamListener listener = new MockStreamListener();
+    clientTransport.newStream(method, listener);
+    assertTrue(streams.containsKey(3));
+    for (int i = 0; i < numContexts; i++) {
+      BufferedSource source = mock(BufferedSource.class);
+      InputStream inputStream = createContextFrame(key + i, value + i);
+      when(source.inputStream()).thenReturn(inputStream);
+      frameHandler.data(i == numContexts - 1 ? true : false, 3, source, inputStream.available());
+    }
+    listener.waitUntilStreamClosed();
+    assertEquals(Status.OK, listener.status);
+    assertEquals(numContexts, listener.contexts.size());
+    for (int i = 0; i < numContexts; i++) {
+      String val = listener.contexts.get(key + i);
+      assertNotNull(val);
+      assertEquals(value + i, val);
+    }
+  }
+
+  @Test
+  public void readStatus() throws Exception {
+    MockStreamListener listener = new MockStreamListener();
+    clientTransport.newStream(method, listener);
+    assertTrue(streams.containsKey(3));
+    BufferedSource source = mock(BufferedSource.class);
+    InputStream inputStream = createStatusFrame((short) Transport.Code.UNAVAILABLE.getNumber());
+    when(source.inputStream()).thenReturn(inputStream);
+    frameHandler.data(true, 3, source, inputStream.available());
+    listener.waitUntilStreamClosed();
+    assertEquals(Transport.Code.UNAVAILABLE, listener.status.getCode());
+  }
+
+  @Test
+  public void receiveReset() throws Exception {
+    MockStreamListener listener = new MockStreamListener();
+    clientTransport.newStream(method, listener);
+    assertTrue(streams.containsKey(3));
+    frameHandler.rstStream(3, ErrorCode.PROTOCOL_ERROR);
+    listener.waitUntilStreamClosed();
+    assertEquals(OkHttpClientTransport.toGrpcStatus(ErrorCode.PROTOCOL_ERROR), listener.status);
+  }
+
+  @Test
+  public void cancelStream() throws Exception {
+    MockStreamListener listener = new MockStreamListener();
+    clientTransport.newStream(method, listener);
+    OkHttpClientStream stream = streams.get(3);
+    assertNotNull(stream);
+    stream.cancel();
+    verify(frameWriter).rstStream(eq(3), eq(ErrorCode.CANCEL));
+    listener.waitUntilStreamClosed();
+    assertEquals(OkHttpClientTransport.toGrpcStatus(ErrorCode.CANCEL), listener.status);
+  }
+
+  @Test
+  public void writeMessage() throws Exception {
+    final String message = "Hello Server";
+    MockStreamListener listener = new MockStreamListener();
+    clientTransport.newStream(method, listener);
+    OkHttpClientStream stream = streams.get(3);
+    InputStream input = new ByteArrayInputStream(message.getBytes(StandardCharsets.UTF_8));
+    stream.writeMessage(input, input.available(), null);
+    stream.flush();
+    ArgumentCaptor<Buffer> captor =
+        ArgumentCaptor.forClass(Buffer.class);
+    verify(frameWriter).data(eq(false), eq(3), captor.capture());
+    Buffer sentFrame = captor.getValue();
+    checkSameInputStream(createMessageFrame(message), sentFrame.inputStream());
+  }
+
+  @Test
+  public void writeContext() throws Exception {
+    final String key = "KEY";
+    final String value = "VALUE";
+    MockStreamListener listener = new MockStreamListener();
+    clientTransport.newStream(method, listener);
+    OkHttpClientStream stream = streams.get(3);
+    InputStream input = new ByteArrayInputStream(value.getBytes(StandardCharsets.UTF_8));
+    stream.writeContext(key, input, input.available(), null);
+    stream.flush();
+    ArgumentCaptor<Buffer> captor =
+        ArgumentCaptor.forClass(Buffer.class);
+    verify(frameWriter).data(eq(false), eq(3), captor.capture());
+    stream.cancel();
+    verify(frameWriter).rstStream(eq(3), eq(ErrorCode.CANCEL));
+    listener.waitUntilStreamClosed();
+    assertEquals(OkHttpClientTransport.toGrpcStatus(ErrorCode.CANCEL), listener.status);
+  }
+
+  @Test
+  public void windowUpdate() throws Exception {
+    MockStreamListener listener1 = new MockStreamListener();
+    MockStreamListener listener2 = new MockStreamListener();
+    clientTransport.newStream(method, listener1);
+    clientTransport.newStream(method, listener2);
+    assertEquals(2, streams.size());
+    OkHttpClientStream stream1 = streams.get(3);
+    OkHttpClientStream stream2 = streams.get(5);
+
+    int messageLength = OkHttpClientTransport.DEFAULT_INITIAL_WINDOW_SIZE / 4;
+    byte[] fakeMessage = new byte[messageLength];
+    byte[] contextBody = ContextValue
+        .newBuilder()
+        .setKey("KEY")
+        .setValue(ByteString.copyFrom(fakeMessage))
+        .build()
+        .toByteArray();
+
+    // Stream 1 receives context
+    InputStream contextFrame = createContextFrame(contextBody);
+    int contextFrameLength = contextFrame.available();
+    BufferedSource source = mock(BufferedSource.class);
+    when(source.inputStream()).thenReturn(contextFrame);
+    frameHandler.data(false, 3, source, contextFrame.available());
+
+    // Stream 2 receives context
+    contextFrame = createContextFrame(contextBody);
+    when(source.inputStream()).thenReturn(contextFrame);
+    frameHandler.data(false, 5, source, contextFrame.available());
+
+    verify(frameWriter).windowUpdate(eq(0), eq((long) 2 * contextFrameLength));
+
+    // Stream 1 receives a message
+    InputStream messageFrame = createMessageFrame(fakeMessage);
+    int messageFrameLength = messageFrame.available();
+    when(source.inputStream()).thenReturn(messageFrame);
+    frameHandler.data(false, 3, source, messageFrame.available());
+
+    verify(frameWriter).windowUpdate(eq(3), eq((long) contextFrameLength + messageFrameLength));
+
+    // Stream 2 receives a message
+    messageFrame = createMessageFrame(fakeMessage);
+    when(source.inputStream()).thenReturn(messageFrame);
+    frameHandler.data(false, 5, source, messageFrame.available());
+
+    verify(frameWriter).windowUpdate(eq(5), eq((long) contextFrameLength + messageFrameLength));
+    verify(frameWriter).windowUpdate(eq(0), eq((long) 2 * messageFrameLength));
+
+    stream1.cancel();
+    verify(frameWriter).rstStream(eq(3), eq(ErrorCode.CANCEL));
+    listener1.waitUntilStreamClosed();
+    assertEquals(OkHttpClientTransport.toGrpcStatus(ErrorCode.CANCEL), listener1.status);
+
+    stream2.cancel();
+    verify(frameWriter).rstStream(eq(3), eq(ErrorCode.CANCEL));
+    listener2.waitUntilStreamClosed();
+    assertEquals(OkHttpClientTransport.toGrpcStatus(ErrorCode.CANCEL), listener2.status);
+  }
+
+  @Test
+  public void stopNormally() throws Exception {
+    MockStreamListener listener1 = new MockStreamListener();
+    MockStreamListener listener2 = new MockStreamListener();
+    clientTransport.newStream(method, listener1);
+    clientTransport.newStream(method, listener2);
+    assertEquals(2, streams.size());
+    clientTransport.stopAsync();
+    listener1.waitUntilStreamClosed();
+    listener2.waitUntilStreamClosed();
+    verify(frameWriter).goAway(eq(0), eq(ErrorCode.NO_ERROR), (byte[]) any());
+    assertEquals(0, streams.size());
+    assertEquals(Code.INTERNAL, listener1.status.getCode());
+    assertEquals(Code.INTERNAL, listener2.status.getCode());
+    assertEquals(Service.State.TERMINATED, clientTransport.state());
+  }
+
+  @Test
+  public void receiveGoAway() throws Exception {
+    // start 2 streams.
+    MockStreamListener listener1 = new MockStreamListener();
+    MockStreamListener listener2 = new MockStreamListener();
+    clientTransport.newStream(method, listener1);
+    clientTransport.newStream(method, listener2);
+    assertEquals(2, streams.size());
+
+    // Receive goAway, max good id is 3.
+    frameHandler.goAway(3, ErrorCode.CANCEL, null);
+
+    // Transport should be in STOPPING state.
+    assertEquals(Service.State.STOPPING, clientTransport.state());
+
+    // Stream 2 should be closed.
+    listener2.waitUntilStreamClosed();
+    assertEquals(1, streams.size());
+    assertEquals(Code.UNAVAILABLE, listener2.status.getCode());
+
+    // New stream should be failed.
+    MockStreamListener listener3 = new MockStreamListener();
+    try {
+      clientTransport.newStream(method, listener3);
+      fail("new stream should no be accepted by a go-away transport.");
+    } catch (IllegalStateException ex) {
+      // expected.
+    }
+
+    // But stream 1 should be able to send.
+    final String sentMessage = "Should I also go away?";
+    OkHttpClientStream stream = streams.get(3);
+    InputStream input =
+        new ByteArrayInputStream(sentMessage.getBytes(StandardCharsets.UTF_8));
+    stream.writeMessage(input, input.available(), null);
+    stream.flush();
+    ArgumentCaptor<Buffer> captor =
+        ArgumentCaptor.forClass(Buffer.class);
+    verify(frameWriter).data(eq(false), eq(3), captor.capture());
+    Buffer sentFrame = captor.getValue();
+    checkSameInputStream(createMessageFrame(sentMessage), sentFrame.inputStream());
+
+    // And read.
+    final String receivedMessage = "No, you are fine.";
+    BufferedSource source = mock(BufferedSource.class);
+    InputStream inputStream = createMessageFrame(receivedMessage);
+    when(source.inputStream()).thenReturn(inputStream);
+    frameHandler.data(true, 3, source, inputStream.available());
+    listener1.waitUntilStreamClosed();
+    assertEquals(1, listener1.messages.size());
+    assertEquals(receivedMessage, listener1.messages.get(0));
+
+    // The transport should be stopped after all active streams finished.
+    assertTrue("Service state: " + clientTransport.state(),
+        Service.State.TERMINATED == clientTransport.state());
+  }
+
+  @Test
+  public void streamIdExhaust() throws Exception {
+    int startId = Integer.MAX_VALUE - 2;
+    AsyncFrameWriter writer =  mock(AsyncFrameWriter.class);
+    OkHttpClientTransport transport =
+        new OkHttpClientTransport(frameReader, writer, startId);
+    transport.startAsync();
+    streams = transport.getStreams();
+
+    MockStreamListener listener1 = new MockStreamListener();
+    transport.newStream(method, listener1);
+
+    try {
+      transport.newStream(method, new MockStreamListener());
+      fail("new stream should not be accepted by a go-away transport.");
+    } catch (IllegalStateException ex) {
+      // expected.
+    }
+
+    streams.get(startId).cancel();
+    listener1.waitUntilStreamClosed();
+    verify(writer).rstStream(eq(startId), eq(ErrorCode.CANCEL));
+    assertEquals(Service.State.TERMINATED, transport.state());
+  }
+
+  private static void checkSameInputStream(InputStream in1, InputStream in2) throws IOException {
+    assertEquals(in1.available(), in2.available());
+    byte[] b1 = new byte[in1.available()];
+    in1.read(b1);
+    byte[] b2 = new byte[in2.available()];
+    in2.read(b2);
+    for (int i = 0; i < b1.length; i++) {
+      if (b1[i] != b2[i]) {
+        fail("Different InputStream.");
+      }
+    }
+  }
+
+  private static InputStream createMessageFrame(String message) throws IOException {
+    return createMessageFrame(message.getBytes(StandardCharsets.UTF_8));
+  }
+
+  private static InputStream createMessageFrame(byte[] message) throws IOException {
+    ByteArrayOutputStream os = new ByteArrayOutputStream();
+    DataOutputStream dos = new DataOutputStream(os);
+    dos.write(PAYLOAD_FRAME);
+    dos.writeInt(message.length);
+    dos.write(message);
+    dos.close();
+    byte[] messageFrame = os.toByteArray();
+
+    // Write the compression header followed by the message frame.
+    return addCompressionHeader(messageFrame);
+  }
+
+  private static InputStream createContextFrame(String key, String value) throws IOException {
+    byte[] body = ContextValue
+        .newBuilder()
+        .setKey(key)
+        .setValue(ByteString.copyFromUtf8(value))
+        .build()
+        .toByteArray();
+    return createContextFrame(body);
+  }
+
+  private static InputStream createContextFrame(byte[] body) throws IOException {
+    ByteArrayOutputStream os = new ByteArrayOutputStream();
+    DataOutputStream dos = new DataOutputStream(os);
+    dos.write(CONTEXT_VALUE_FRAME);
+    dos.writeInt(body.length);
+    dos.write(body);
+    dos.close();
+    byte[] contextFrame = os.toByteArray();
+
+    // Write the compression header followed by the context frame.
+    return addCompressionHeader(contextFrame);
+  }
+
+  private static InputStream createStatusFrame(short code) throws IOException {
+    ByteArrayOutputStream os = new ByteArrayOutputStream();
+    DataOutputStream dos = new DataOutputStream(os);
+    dos.write(STATUS_FRAME);
+    int length = 2;
+    dos.writeInt(length);
+    dos.writeShort(code);
+    dos.close();
+    byte[] statusFrame = os.toByteArray();
+
+    // Write the compression header followed by the status frame.
+    return addCompressionHeader(statusFrame);
+  }
+
+  private static InputStream addCompressionHeader(byte[] raw) throws IOException {
+    ByteArrayOutputStream os = new ByteArrayOutputStream();
+    DataOutputStream dos = new DataOutputStream(os);
+    dos.writeInt(raw.length);
+    dos.write(raw);
+    dos.close();
+    return new ByteArrayInputStream(os.toByteArray());
+  }
+
+  private static class MockFrameReader implements FrameReader {
+    boolean closed;
+    boolean throwExceptionForNextFrame;
+
+    @Override
+    public void close() throws IOException {
+      closed = true;
+    }
+
+    @Override
+    public boolean nextFrame(Handler handler) throws IOException {
+      if (throwExceptionForNextFrame) {
+        throw new IOException(NETWORK_ISSUE_MESSAGE);
+      }
+      synchronized (this) {
+        try {
+          wait();
+        } catch (InterruptedException e) {
+          throw new IOException(e);
+        }
+      }
+      if (throwExceptionForNextFrame) {
+        throw new IOException(NETWORK_ISSUE_MESSAGE);
+      }
+      return true;
+    }
+
+    synchronized void throwIOExceptionForNextFrame() {
+      throwExceptionForNextFrame = true;
+      notifyAll();
+    }
+
+    @Override
+    public void readConnectionPreface() throws IOException {
+      // not used.
+    }
+  }
+
+  private static class MockStreamListener implements StreamListener {
+    Status status;
+    CountDownLatch closed = new CountDownLatch(1);
+    ArrayList<String> messages = new ArrayList<String>();
+    Map<String, String> contexts = new HashMap<String, String>();
+
+    @Override
+    public ListenableFuture<Void> contextRead(String name, InputStream value, int length) {
+      String valueStr = getContent(value);
+      if (valueStr != null) {
+        // We assume only one context for each name.
+        contexts.put(name, valueStr);
+      }
+      return null;
+    }
+
+    @Override
+    public ListenableFuture<Void> messageRead(InputStream message, int length) {
+      String msg = getContent(message);
+      if (msg != null) {
+        messages.add(msg);
+      }
+      return null;
+    }
+
+    @Override
+    public void closed(Status status) {
+      this.status = status;
+      closed.countDown();
+    }
+
+    void waitUntilStreamClosed() throws InterruptedException {
+      if (!closed.await(TIME_OUT_MS, TimeUnit.MILLISECONDS)) {
+        fail("Failed waiting stream to be closed.");
+      }
+    }
+
+    static String getContent(InputStream message) {
+      BufferedReader br =
+          new BufferedReader(new InputStreamReader(message, StandardCharsets.UTF_8));
+      try {
+        // Only one line message is used in this test.
+        return br.readLine();
+      } catch (IOException e) {
+        return null;
+      }
+    }
+  }
+}