Add simple context exchange mechanism by wrapping Channel.

-------------
Created by MOE: http://code.google.com/p/moe-java
MOE_MIGRATED_REVID=71139615
diff --git a/core/src/main/java/com/google/net/stubby/context/ContextExchangeChannel.java b/core/src/main/java/com/google/net/stubby/context/ContextExchangeChannel.java
new file mode 100644
index 0000000..3e24a37
--- /dev/null
+++ b/core/src/main/java/com/google/net/stubby/context/ContextExchangeChannel.java
@@ -0,0 +1,114 @@
+package com.google.net.stubby.context;
+
+import com.google.common.collect.Maps;
+import com.google.common.util.concurrent.ListenableFuture;
+import com.google.net.stubby.Call;
+import com.google.net.stubby.Channel;
+import com.google.net.stubby.Marshaller;
+import com.google.net.stubby.MethodDescriptor;
+
+import java.io.InputStream;
+import java.util.Map;
+
+import javax.annotation.concurrent.NotThreadSafe;
+import javax.inject.Provider;
+
+/**
+ * A channel implementation that sends bound context values and records received context.
+ * Unlike {@Channel} this class is not thread-safe so it is recommended to create an instance
+ * per thread.
+ */
+@NotThreadSafe
+public class ContextExchangeChannel extends ForwardingChannel {
+
+  private Map<String, Object> captured;
+  private Map<String, Provider<InputStream>> provided;
+
+  public ContextExchangeChannel(Channel channel) {
+    super(channel);
+    // builder?
+    captured = Maps.newTreeMap();
+    provided = Maps.newTreeMap();
+  }
+
+  @SuppressWarnings("unchecked")
+  public <T> Provider<T> receive(final String name, final Marshaller<T> m) {
+    synchronized (captured) {
+      captured.put(name, null);
+    }
+    return new Provider<T>() {
+      @Override
+      public T get() {
+        synchronized (captured) {
+          Object o = captured.get(name);
+          if (o instanceof InputStream) {
+            o = m.parse((InputStream) o);
+            captured.put(name, o);
+          }
+          return (T) o;
+        }
+      }
+    };
+  }
+
+  public <T> void send(final String name, final T value, final Marshaller<T> m) {
+    synchronized (provided) {
+      provided.put(name, new Provider<InputStream>() {
+        @Override
+        public InputStream get() {
+          return m.stream(value);
+        }
+      });
+    }
+  }
+
+  /**
+   * Clear all received values and allow another call
+   */
+  public void clearLastReceived() {
+    synchronized (captured) {
+      for (Map.Entry<String, Object> entry : captured.entrySet()) {
+        entry.setValue(null);
+      }
+    }
+  }
+
+
+  @Override
+  public <ReqT, RespT> Call<ReqT, RespT> newCall(MethodDescriptor<ReqT, RespT> method) {
+    return new CallImpl<ReqT, RespT>(delegate.newCall(method));
+  }
+
+  private class CallImpl<ReqT, RespT> extends ForwardingCall<ReqT, RespT> {
+    private CallImpl(Call<ReqT, RespT> delegate) {
+      super(delegate);
+    }
+
+    @Override
+    public void start(Listener<RespT> responseListener) {
+      super.start(new ListenerImpl<RespT>(responseListener));
+      synchronized (provided) {
+        for (Map.Entry<String, Provider<InputStream>> entry : provided.entrySet()) {
+          sendContext(entry.getKey(), entry.getValue().get());
+        }
+      }
+    }
+  }
+
+  private class ListenerImpl<T> extends ForwardingListener<T> {
+    private ListenerImpl(Call.Listener<T> delegate) {
+      super(delegate);
+    }
+
+    @Override
+    public ListenableFuture<Void> onContext(String name, InputStream value) {
+      synchronized (captured) {
+        if (captured.containsKey(name)) {
+          captured.put(name, value);
+          return null;
+        }
+      }
+      return super.onContext(name, value);
+    }
+  }
+}
diff --git a/core/src/main/java/com/google/net/stubby/context/ForwardingChannel.java b/core/src/main/java/com/google/net/stubby/context/ForwardingChannel.java
new file mode 100644
index 0000000..8f299e4
--- /dev/null
+++ b/core/src/main/java/com/google/net/stubby/context/ForwardingChannel.java
@@ -0,0 +1,89 @@
+package com.google.net.stubby.context;
+
+import com.google.common.util.concurrent.ListenableFuture;
+import com.google.common.util.concurrent.SettableFuture;
+import com.google.net.stubby.Call;
+import com.google.net.stubby.Channel;
+import com.google.net.stubby.Status;
+
+import java.io.InputStream;
+
+import javax.annotation.Nullable;
+
+/**
+ * A {@link Channel} which forwards all of it's methods to another {@link Channel}. Implementations
+ * should override methods and make use of {@link ForwardingListener} and {@link ForwardingCall}
+ * to augment the behavior of the underlying {@link Channel}.
+ */
+public abstract class ForwardingChannel implements Channel {
+
+  protected final Channel delegate;
+
+  public ForwardingChannel(Channel channel) {
+    this.delegate = channel;
+  }
+
+  /**
+   * A {@link Call} which forwards all of it's methods to another {@link Call}.
+   */
+  public static class ForwardingCall<RequestT,ResponseT> extends Call<RequestT,ResponseT> {
+
+    protected final Call<RequestT, ResponseT> delegate;
+
+    public ForwardingCall(Call<RequestT, ResponseT> delegate) {
+      this.delegate = delegate;
+    }
+
+    @Override
+    public void start(Listener<ResponseT> responseListener) {
+      this.delegate.start(responseListener);
+    }
+
+    @Override
+    public void cancel() {
+      this.delegate.cancel();
+    }
+
+    @Override
+    public void halfClose() {
+      this.delegate.halfClose();
+    }
+
+    @Override
+    public void sendContext(String name, InputStream value, @Nullable SettableFuture<Void> accepted) {
+      this.delegate.sendContext(name, value, accepted);
+    }
+
+    @Override
+    public void sendPayload(RequestT payload, @Nullable SettableFuture<Void> accepted) {
+      this.delegate.sendPayload(payload, accepted);
+    }
+  }
+
+  /**
+   * A {@link Call.Listener} which forwards all of its methods to another {@link Call.Listener}.
+   */
+  public static class ForwardingListener<T> extends Call.Listener<T> {
+
+    Call.Listener<T> delegate;
+
+    public ForwardingListener(Call.Listener<T> delegate) {
+      this.delegate = delegate;
+    }
+
+    @Override
+    public ListenableFuture<Void> onContext(String name, InputStream value) {
+      return delegate.onContext(name, value);
+    }
+
+    @Override
+    public ListenableFuture<Void> onPayload(T payload) {
+      return delegate.onPayload(payload);
+    }
+
+    @Override
+    public void onClose(Status status) {
+      delegate.onClose(status);
+    }
+  }
+}
diff --git a/core/src/test/java/com/google/net/stubby/context/ContextExchangeChannelTest.java b/core/src/test/java/com/google/net/stubby/context/ContextExchangeChannelTest.java
new file mode 100644
index 0000000..1db3d40
--- /dev/null
+++ b/core/src/test/java/com/google/net/stubby/context/ContextExchangeChannelTest.java
@@ -0,0 +1,133 @@
+package com.google.net.stubby.context;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
+import static org.mockito.Mockito.*;
+
+import com.google.common.util.concurrent.SettableFuture;
+import com.google.net.stubby.Call;
+import com.google.net.stubby.Channel;
+import com.google.net.stubby.MethodDescriptor;
+import com.google.net.stubby.Status;
+import com.google.net.stubby.stub.Marshallers;
+import com.google.net.stubby.testing.integration.Test.Payload;
+import com.google.net.stubby.testing.integration.Test.PayloadType;
+import com.google.net.stubby.testing.integration.Test.SimpleRequest;
+import com.google.net.stubby.testing.integration.Test.SimpleResponse;
+import com.google.net.stubby.testing.integration.grpcapi.TestService;
+import com.google.protobuf.ByteString;
+
+import org.hamcrest.BaseMatcher;
+import org.hamcrest.Description;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+import org.mockito.ArgumentCaptor;
+import org.mockito.Mock;
+import org.mockito.Mockito;
+import org.mockito.MockitoAnnotations;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
+
+import javax.inject.Provider;
+
+/**
+ * Tests for {@link ContextExchangeChannel}
+ */
+@RunWith(JUnit4.class)
+public class ContextExchangeChannelTest {
+
+  private static final SimpleRequest REQ = SimpleRequest.newBuilder().setPayload(
+      Payload.newBuilder().setPayloadCompressable("mary").
+          setPayloadType(PayloadType.COMPRESSABLE).build())
+      .build();
+
+  private static final SimpleResponse RESP = SimpleResponse.newBuilder().setPayload(
+      Payload.newBuilder().setPayloadCompressable("bob").
+          setPayloadType(PayloadType.COMPRESSABLE).build())
+      .build();
+
+  @Mock
+  Channel channel;
+
+  @Mock
+  Call call;
+
+  @Before @SuppressWarnings("unchecked")
+  public void setup() {
+    MockitoAnnotations.initMocks(this);
+    when(channel.newCall(Mockito.any(MethodDescriptor.class))).thenReturn(call);
+  }
+
+  @Test
+  public void testReceive() throws Exception {
+    ContextExchangeChannel exchange = new ContextExchangeChannel(channel);
+    Provider<SimpleResponse> auth =
+        exchange.receive("auth", Marshallers.forProto(SimpleResponse.PARSER));
+    // Should be null, nothing has happened
+    assertNull(auth.get());
+    TestService.TestServiceBlockingStub stub =
+        TestService.blockingClient(exchange);
+    callStub(stub);
+    assertEquals(RESP, auth.get());
+    exchange.clearLastReceived();
+    assertNull(auth.get());
+  }
+
+  @Test @SuppressWarnings("unchecked")
+  public void testSend() throws Exception {
+    ContextExchangeChannel exchange = new ContextExchangeChannel(channel);
+    exchange.send("auth", RESP, Marshallers.forProto(SimpleResponse.PARSER));
+    TestService.TestServiceBlockingStub stub =
+        TestService.blockingClient(exchange);
+    callStub(stub);
+    verify(call).sendContext(eq("auth"),
+        argThat(new BaseMatcher<InputStream>() {
+          @Override
+          public boolean matches(Object o) {
+            try {
+              // Just check the length, consuming the stream will fail the test and Mockito
+              // calls this more than once.
+              return ((InputStream) o).available() == RESP.getSerializedSize();
+            } catch (IOException ioe) {
+              throw new RuntimeException(ioe);
+            }
+          }
+
+          @Override
+          public void describeTo(Description description) {
+          }
+        }), (SettableFuture<Void>) isNull());
+  }
+
+  @SuppressWarnings("unchecked")
+  private void callStub(final TestService.TestServiceBlockingStub stub) throws Exception {
+    when(channel.newCall(Mockito.<MethodDescriptor>any())).thenReturn(call);
+
+    // execute the call in another thread so we don't deadlock waiting for the
+    // listener.onClose
+    Future<?> pending = Executors.newSingleThreadExecutor().submit(new Runnable() {
+      @Override
+      public void run() {
+        stub.unaryCall(REQ);
+      }
+    });
+    ArgumentCaptor<Call.Listener> listenerCapture = ArgumentCaptor.forClass(Call.Listener.class);
+    // Wait for the call to start to capture the listener
+    verify(call, timeout(1000)).start(listenerCapture.capture());
+
+    ByteString response = RESP.toByteString();
+    Call.Listener listener = listenerCapture.getValue();
+    // Respond with a context-value
+    listener.onContext("auth", response.newInput());
+    listener.onContext("something-else", response.newInput());
+    // .. and single payload
+    listener.onPayload(RESP);
+    listener.onClose(Status.OK);
+    pending.get();
+  }
+}