Improve the OkHttp client transport: handle GOAWAY and add a unit test.

-------------
Created by MOE: http://code.google.com/p/moe-java
MOE_MIGRATED_REVID=72155172
diff --git a/core/src/main/java/com/google/net/stubby/newtransport/okhttp/OkHttpClientTransport.java b/core/src/main/java/com/google/net/stubby/newtransport/okhttp/OkHttpClientTransport.java
index fc974e0..b4e9f68 100644
--- a/core/src/main/java/com/google/net/stubby/newtransport/okhttp/OkHttpClientTransport.java
+++ b/core/src/main/java/com/google/net/stubby/newtransport/okhttp/OkHttpClientTransport.java
@@ -1,5 +1,6 @@
 package com.google.net.stubby.newtransport.okhttp;
 
+import com.google.common.annotations.VisibleForTesting;
 import com.google.common.base.Preconditions;
 import com.google.common.collect.ImmutableMap;
 import com.google.common.io.ByteBuffers;
@@ -33,9 +34,10 @@
 import java.io.IOException;
 import java.net.Socket;
 import java.nio.ByteBuffer;
-import java.util.Collection;
+import java.util.ArrayList;
 import java.util.Collections;
 import java.util.HashMap;
+import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
 import java.util.concurrent.ExecutorService;
@@ -48,6 +50,7 @@
  */
 public class OkHttpClientTransport extends AbstractClientTransport {
   /** The default initial window size in HTTP/2 is 64 KiB for the stream and connection. */
+  @VisibleForTesting
   static final int DEFAULT_INITIAL_WINDOW_SIZE = 64 * 1024;
 
   private static final ImmutableMap<ErrorCode, Status> ERROR_CODE_TO_STATUS = ImmutableMap
@@ -75,21 +78,40 @@
   private final int port;
   private FrameReader frameReader;
   private AsyncFrameWriter frameWriter;
-  @GuardedBy("this")
+  private final Object lock = new Object();
+  @GuardedBy("lock")
   private int nextStreamId;
   private final Map<Integer, OkHttpClientStream> streams =
       Collections.synchronizedMap(new HashMap<Integer, OkHttpClientStream>());
   private final ExecutorService executor = Executors.newCachedThreadPool();
   private int unacknowledgedBytesRead;
+  private ClientFrameHandler clientFrameHandler;
+  // Set once the transport starts shutting down; streams created afterwards fail immediately.
+  @GuardedBy("lock")
+  private boolean goAway;
+  @GuardedBy("lock")
+  private Status goAwayStatus; // Status reported to streams that are failed by the shutdown.
 
   public OkHttpClientTransport(String host, int port) {
-    this.host = host;
+    this.host = Preconditions.checkNotNull(host);
     this.port = port;
     // Client initiated streams are odd, server initiated ones are even. Server should not need to
     // use it. We start clients at 3 to avoid conflicting with HTTP negotiation.
     nextStreamId = 3;
   }
 
+  /**
+   * Creates a transport wired to the given frame reader/writer instead of a socket; for tests.
+   */
+  @VisibleForTesting
+  OkHttpClientTransport(FrameReader frameReader, AsyncFrameWriter frameWriter, int nextStreamId) {
+    host = null;
+    port = -1;
+    this.nextStreamId = nextStreamId;
+    this.frameReader = frameReader;
+    this.frameWriter = frameWriter;
+  }
+
   @Override
   protected ClientStream newStreamInternal(MethodDescriptor<?, ?> method, StreamListener listener) {
     return new OkHttpClientStream(method, listener);
@@ -97,53 +119,85 @@
 
   @Override
   protected void doStart() {
-    BufferedSource source;
-    BufferedSink sink;
-    try {
-      Socket socket = new Socket(host, port);
-      // TODO(user): use SpdyConnection.
-      source = Okio.buffer(Okio.source(socket));
-      sink = Okio.buffer(Okio.sink(socket));
-    } catch (IOException e) {
-      throw new RuntimeException(e);
+    // host is null only for the test constructor, which supplies its own frame reader/writer.
+    if (host != null) {
+      BufferedSource source;
+      BufferedSink sink;
+      try {
+        Socket socket = new Socket(host, port);
+        source = Okio.buffer(Okio.source(socket));
+        sink = Okio.buffer(Okio.sink(socket));
+      } catch (IOException e) {
+        throw new RuntimeException(e);
+      }
+      Variant variant = new Http20Draft12();
+      frameReader = variant.newReader(source, true);
+      frameWriter = new AsyncFrameWriter(variant.newWriter(sink, true), this, executor);
     }
-    Variant variant = new Http20Draft12();
-    frameReader = variant.newReader(source, true);
-    frameWriter = new AsyncFrameWriter(variant.newWriter(sink, true), this, executor);
 
-    executor.execute(new ClientFrameHandler());
     notifyStarted();
+    clientFrameHandler = new ClientFrameHandler();
+    executor.execute(clientFrameHandler);
   }
 
   @Override
   protected void doStop() {
-    closeAllStreams(new Status(Code.INTERNAL, "Transport stopped"));
-    frameWriter.close();
-    try {
-      frameReader.close();
-    } catch (IOException e) {
-      throw new RuntimeException(e);
+    boolean normalClose;
+    synchronized (lock) {
+      normalClose = !goAway;
     }
-    executor.shutdown();
-    notifyStopped();
+    if (normalClose) {
+      abort(new Status(Code.INTERNAL, "Transport stopped"));
+      // Send GOAWAY with lastGoodStreamId of 0, since we don't expect any server-initiated streams.
+      // The GOAWAY is part of graceful shutdown.
+      frameWriter.goAway(0, ErrorCode.NO_ERROR, null);
+    }
+    stopIfNecessary();
+  }
+
+  @VisibleForTesting
+  ClientFrameHandler getHandler() {
+    return clientFrameHandler;
+  }
+
+  @VisibleForTesting
+  Map<Integer, OkHttpClientStream> getStreams() {
+    return streams;
   }
 
   /**
-   * Close and remove all streams.
+   * Finishes all active streams with the given status, then stops the transport.
    */
-  private void closeAllStreams(Status status) {
-    Collection<OkHttpClientStream> streamsCopy;
-    synchronized (streams) {
-      streamsCopy = streams.values();
-      streams.clear();
+  void abort(Status status) {
+    onGoAway(-1, status);
+  }
+
+  private void onGoAway(int lastKnownStreamId, Status status) {
+    ArrayList<OkHttpClientStream> goAwayStreams = new ArrayList<OkHttpClientStream>();
+    synchronized (lock) {
+      goAway = true;
+      goAwayStatus = status;
+      synchronized (streams) { // Iterating a synchronizedMap's view requires holding its monitor.
+        Iterator<Map.Entry<Integer, OkHttpClientStream>> it = streams.entrySet().iterator();
+        while (it.hasNext()) {
+          Map.Entry<Integer, OkHttpClientStream> entry = it.next();
+          if (entry.getKey() > lastKnownStreamId) {
+            goAwayStreams.add(entry.getValue());
+            it.remove();
+          }
+        }
+      }
     }
-    for (OkHttpClientStream stream : streamsCopy) {
+    // Enter STOPPING so the Channel stops using this transport; STOPPED once all streams end.
+    stopAsync();
+
+    for (OkHttpClientStream stream : goAwayStreams) {
       stream.setStatus(status);
     }
   }
 
   /**
-   * Called when a HTTP2 stream is closed.
+   * Called when a stream is closed.
    *
    * <p> Return false if the stream has already finished.
    */
@@ -159,10 +213,39 @@
   }
 
+  @GuardedBy("lock") private boolean stopped;
+
   /**
+   * Stops the transport, at most once, when in the goAway state with no active streams left.
+   */
+  private void stopIfNecessary() {
+    synchronized (lock) {
+      if (!goAway || stopped || !streams.isEmpty()) {
+        return;
+      }
+      stopped = true;
+    }
+    frameWriter.close();
+    try {
+      frameReader.close();
+    } catch (IOException e) {
+      throw new RuntimeException(e);
+    }
+    executor.shutdown();
+    notifyStopped();
+  }
+
+  /** Returns the gRPC status mapped to the given ErrorCode, or null if it has no mapping. */
+  @VisibleForTesting
+  static Status toGrpcStatus(ErrorCode code) {
+    return ERROR_CODE_TO_STATUS.get(code);
+  }
+
+  /**
    * Runnable which reads frames and dispatches them to in flight calls
    */
-  private class ClientFrameHandler implements FrameReader.Handler, Runnable {
-    private ClientFrameHandler() {}
+  @VisibleForTesting
+  class ClientFrameHandler implements FrameReader.Handler, Runnable {
+    ClientFrameHandler() {}
 
     @Override
     public void run() {
@@ -173,8 +256,7 @@
         while (frameReader.nextFrame(this)) {
         }
       } catch (IOException ioe) {
-        ioe.printStackTrace();
-        closeAllStreams(new Status(Code.INTERNAL, ioe.getMessage()));
+        abort(Status.fromThrowable(ioe));
       } finally {
         // Restore the original thread name.
         Thread.currentThread().setName(threadName);
@@ -210,7 +292,9 @@
         stream.unacknowledgedBytesRead = 0;
       }
       if (inFinished) {
-        finishStream(streamId, Status.OK);
+        if (finishStream(streamId, Status.OK)) {
+          stopIfNecessary();
+        }
       }
     }
 
@@ -229,7 +313,9 @@
 
     @Override
     public void rstStream(int streamId, ErrorCode errorCode) {
-      finishStream(streamId, ERROR_CODE_TO_STATUS.get(errorCode));
+      if (finishStream(streamId, toGrpcStatus(errorCode))) {
+        stopIfNecessary();
+      }
     }
 
     @Override
@@ -252,18 +338,14 @@
 
     @Override
     public void goAway(int lastGoodStreamId, ErrorCode errorCode, ByteString debugData) {
-      // TODO(user): Log here and implement the real Go away behavior: streams have
-      // id <= lastGoodStreamId should not be closed.
-      closeAllStreams(new Status(Code.UNAVAILABLE, "Go away"));
-      stopAsync();
+      onGoAway(lastGoodStreamId, new Status(Code.UNAVAILABLE, "Go away"));
     }
 
     @Override
     public void pushPromise(int streamId, int promisedStreamId, List<Header> requestHeaders)
         throws IOException {
-      // TODO(user): should send SETTINGS_ENABLE_PUSH=0, then here we should reset it with
-      // PROTOCOL_ERROR.
-      frameWriter.rstStream(streamId, ErrorCode.REFUSED_STREAM);
+      // Server pushes are not accepted; refuse each one by resetting the promised stream.
+      frameWriter.rstStream(promisedStreamId, ErrorCode.REFUSED_STREAM);
     }
 
     @Override
@@ -284,28 +366,55 @@
     }
   }
 
+  /**
+   * Assigns the next client stream id to {@code stream} and registers it in {@code streams}.
+   *
+   * <p>Returns true if the id space is now exhausted. The transport is then marked as going away,
+   * and the caller must invoke {@code stopAsync()} after releasing {@code lock}.
+   */
+  @GuardedBy("lock")
+  private boolean assignStreamId(OkHttpClientStream stream) {
+    Preconditions.checkState(stream.streamId == 0, "StreamId already assigned");
+    stream.streamId = nextStreamId;
+    streams.put(stream.streamId, stream);
+    if (nextStreamId >= Integer.MAX_VALUE - 2) {
+      goAway = true;
+      goAwayStatus = new Status(Code.INTERNAL, "Stream id exhaust");
+      return true;
+    }
+    nextStreamId += 2;
+    return false;
+  }
+
   /**
    * Client stream for the okhttp transport.
    */
-  private class OkHttpClientStream extends AbstractStream implements ClientStream {
+  @VisibleForTesting
+  class OkHttpClientStream extends AbstractStream implements ClientStream {
     int streamId;
     final InputStreamDeframer deframer;
     int unacknowledgedBytesRead;
 
-    public OkHttpClientStream(MethodDescriptor<?, ?> method, StreamListener listener) {
+    OkHttpClientStream(MethodDescriptor<?, ?> method, StreamListener listener) {
       super(listener);
-      Preconditions.checkState(streamId == 0, "StreamId should be 0");
-      synchronized (OkHttpClientTransport.this) {
-        streamId = nextStreamId;
-        nextStreamId += 2;
-        streams.put(streamId, this);
-        frameWriter.synStream(false, false, streamId, 0,
-            Headers.createRequestHeaders(method.getName()));
-      }
       deframer = new InputStreamDeframer(inboundMessageHandler());
+      boolean idsExhausted;
+      synchronized (lock) {
+        if (goAway) {
+          setStatus(goAwayStatus);
+          return;
+        }
+        idsExhausted = assignStreamId(this);
+      }
+      if (idsExhausted) {
+        // Not under lock: stopAsync() may run doStop(), which itself synchronizes on lock.
+        OkHttpClientTransport.this.stopAsync();
+      }
+      frameWriter.synStream(false, false, streamId, 0,
+          Headers.createRequestHeaders(method.getName()));
     }
 
-    public InputStreamDeframer getDeframer() {
+    InputStreamDeframer getDeframer() {
       return deframer;
     }
 
@@ -330,8 +439,9 @@
     public void cancel() {
       Preconditions.checkState(streamId != 0, "streamId should be set");
       outboundPhase = Phase.STATUS;
-      if (finishStream(streamId, ERROR_CODE_TO_STATUS.get(ErrorCode.CANCEL))) {
+      if (finishStream(streamId, toGrpcStatus(ErrorCode.CANCEL))) {
         frameWriter.rstStream(streamId, ErrorCode.CANCEL);
+        stopIfNecessary();
       }
     }
   }