Add simple context exchange mechanism by wrapping Channel.
-------------
Created by MOE: http://code.google.com/p/moe-java
MOE_MIGRATED_REVID=71139615
diff --git a/core/src/main/java/com/google/net/stubby/context/ContextExchangeChannel.java b/core/src/main/java/com/google/net/stubby/context/ContextExchangeChannel.java
new file mode 100644
index 0000000..3e24a37
--- /dev/null
+++ b/core/src/main/java/com/google/net/stubby/context/ContextExchangeChannel.java
@@ -0,0 +1,114 @@
+package com.google.net.stubby.context;
+
+import com.google.common.collect.Maps;
+import com.google.common.util.concurrent.ListenableFuture;
+import com.google.net.stubby.Call;
+import com.google.net.stubby.Channel;
+import com.google.net.stubby.Marshaller;
+import com.google.net.stubby.MethodDescriptor;
+
+import java.io.InputStream;
+import java.util.Map;
+
+import javax.annotation.concurrent.NotThreadSafe;
+import javax.inject.Provider;
+
+/**
+ * A channel implementation that sends bound context values and records received context.
+ * Unlike {@Channel} this class is not thread-safe so it is recommended to create an instance
+ * per thread.
+ */
+@NotThreadSafe
+public class ContextExchangeChannel extends ForwardingChannel {
+
+ private Map<String, Object> captured;
+ private Map<String, Provider<InputStream>> provided;
+
+ public ContextExchangeChannel(Channel channel) {
+ super(channel);
+ // builder?
+ captured = Maps.newTreeMap();
+ provided = Maps.newTreeMap();
+ }
+
+ @SuppressWarnings("unchecked")
+ public <T> Provider<T> receive(final String name, final Marshaller<T> m) {
+ synchronized (captured) {
+ captured.put(name, null);
+ }
+ return new Provider<T>() {
+ @Override
+ public T get() {
+ synchronized (captured) {
+ Object o = captured.get(name);
+ if (o instanceof InputStream) {
+ o = m.parse((InputStream) o);
+ captured.put(name, o);
+ }
+ return (T) o;
+ }
+ }
+ };
+ }
+
+ public <T> void send(final String name, final T value, final Marshaller<T> m) {
+ synchronized (provided) {
+ provided.put(name, new Provider<InputStream>() {
+ @Override
+ public InputStream get() {
+ return m.stream(value);
+ }
+ });
+ }
+ }
+
+ /**
+ * Clear all received values and allow another call
+ */
+ public void clearLastReceived() {
+ synchronized (captured) {
+ for (Map.Entry<String, Object> entry : captured.entrySet()) {
+ entry.setValue(null);
+ }
+ }
+ }
+
+
+ @Override
+ public <ReqT, RespT> Call<ReqT, RespT> newCall(MethodDescriptor<ReqT, RespT> method) {
+ return new CallImpl<ReqT, RespT>(delegate.newCall(method));
+ }
+
+ private class CallImpl<ReqT, RespT> extends ForwardingCall<ReqT, RespT> {
+ private CallImpl(Call<ReqT, RespT> delegate) {
+ super(delegate);
+ }
+
+ @Override
+ public void start(Listener<RespT> responseListener) {
+ super.start(new ListenerImpl<RespT>(responseListener));
+ synchronized (provided) {
+ for (Map.Entry<String, Provider<InputStream>> entry : provided.entrySet()) {
+ sendContext(entry.getKey(), entry.getValue().get());
+ }
+ }
+ }
+ }
+
+ private class ListenerImpl<T> extends ForwardingListener<T> {
+ private ListenerImpl(Call.Listener<T> delegate) {
+ super(delegate);
+ }
+
+ @Override
+ public ListenableFuture<Void> onContext(String name, InputStream value) {
+ synchronized (captured) {
+ if (captured.containsKey(name)) {
+ captured.put(name, value);
+ return null;
+ }
+ }
+ return super.onContext(name, value);
+ }
+ }
+}
diff --git a/core/src/main/java/com/google/net/stubby/context/ForwardingChannel.java b/core/src/main/java/com/google/net/stubby/context/ForwardingChannel.java
new file mode 100644
index 0000000..8f299e4
--- /dev/null
+++ b/core/src/main/java/com/google/net/stubby/context/ForwardingChannel.java
@@ -0,0 +1,89 @@
+package com.google.net.stubby.context;
+
+import com.google.common.util.concurrent.ListenableFuture;
+import com.google.common.util.concurrent.SettableFuture;
+import com.google.net.stubby.Call;
+import com.google.net.stubby.Channel;
+import com.google.net.stubby.Status;
+
+import java.io.InputStream;
+
+import javax.annotation.Nullable;
+
+/**
+ * A {@link Channel} which forwards all of it's methods to another {@link Channel}. Implementations
+ * should override methods and make use of {@link ForwardingListener} and {@link ForwardingCall}
+ * to augment the behavior of the underlying {@link Channel}.
+ */
+public abstract class ForwardingChannel implements Channel {
+
+ protected final Channel delegate;
+
+ public ForwardingChannel(Channel channel) {
+ this.delegate = channel;
+ }
+
+ /**
+ * A {@link Call} which forwards all of it's methods to another {@link Call}.
+ */
+ public static class ForwardingCall<RequestT,ResponseT> extends Call<RequestT,ResponseT> {
+
+ protected final Call<RequestT, ResponseT> delegate;
+
+ public ForwardingCall(Call<RequestT, ResponseT> delegate) {
+ this.delegate = delegate;
+ }
+
+ @Override
+ public void start(Listener<ResponseT> responseListener) {
+ this.delegate.start(responseListener);
+ }
+
+ @Override
+ public void cancel() {
+ this.delegate.cancel();
+ }
+
+ @Override
+ public void halfClose() {
+ this.delegate.halfClose();
+ }
+
+ @Override
+ public void sendContext(String name, InputStream value, @Nullable SettableFuture<Void> accepted) {
+ this.delegate.sendContext(name, value, accepted);
+ }
+
+ @Override
+ public void sendPayload(RequestT payload, @Nullable SettableFuture<Void> accepted) {
+ this.delegate.sendPayload(payload, accepted);
+ }
+ }
+
+ /**
+ * A {@link Call.Listener} which forwards all of its methods to another {@link Call.Listener}.
+ */
+ public static class ForwardingListener<T> extends Call.Listener<T> {
+
+ Call.Listener<T> delegate;
+
+ public ForwardingListener(Call.Listener<T> delegate) {
+ this.delegate = delegate;
+ }
+
+ @Override
+ public ListenableFuture<Void> onContext(String name, InputStream value) {
+ return delegate.onContext(name, value);
+ }
+
+ @Override
+ public ListenableFuture<Void> onPayload(T payload) {
+ return delegate.onPayload(payload);
+ }
+
+ @Override
+ public void onClose(Status status) {
+ delegate.onClose(status);
+ }
+ }
+}
diff --git a/core/src/test/java/com/google/net/stubby/context/ContextExchangeChannelTest.java b/core/src/test/java/com/google/net/stubby/context/ContextExchangeChannelTest.java
new file mode 100644
index 0000000..1db3d40
--- /dev/null
+++ b/core/src/test/java/com/google/net/stubby/context/ContextExchangeChannelTest.java
@@ -0,0 +1,133 @@
+package com.google.net.stubby.context;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
+import static org.mockito.Mockito.*;
+
+import com.google.common.util.concurrent.SettableFuture;
+import com.google.net.stubby.Call;
+import com.google.net.stubby.Channel;
+import com.google.net.stubby.MethodDescriptor;
+import com.google.net.stubby.Status;
+import com.google.net.stubby.stub.Marshallers;
+import com.google.net.stubby.testing.integration.Test.Payload;
+import com.google.net.stubby.testing.integration.Test.PayloadType;
+import com.google.net.stubby.testing.integration.Test.SimpleRequest;
+import com.google.net.stubby.testing.integration.Test.SimpleResponse;
+import com.google.net.stubby.testing.integration.grpcapi.TestService;
+import com.google.protobuf.ByteString;
+
+import org.hamcrest.BaseMatcher;
+import org.hamcrest.Description;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+import org.mockito.ArgumentCaptor;
+import org.mockito.Mock;
+import org.mockito.Mockito;
+import org.mockito.MockitoAnnotations;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
+
+import javax.inject.Provider;
+
+/**
+ * Tests for {@link ContextExchangeChannel}
+ */
+@RunWith(JUnit4.class)
+public class ContextExchangeChannelTest {
+
+ private static final SimpleRequest REQ = SimpleRequest.newBuilder().setPayload(
+ Payload.newBuilder().setPayloadCompressable("mary").
+ setPayloadType(PayloadType.COMPRESSABLE).build())
+ .build();
+
+ private static final SimpleResponse RESP = SimpleResponse.newBuilder().setPayload(
+ Payload.newBuilder().setPayloadCompressable("bob").
+ setPayloadType(PayloadType.COMPRESSABLE).build())
+ .build();
+
+ @Mock
+ Channel channel;
+
+ @Mock
+ Call call;
+
+ @Before @SuppressWarnings("unchecked")
+ public void setup() {
+ MockitoAnnotations.initMocks(this);
+ when(channel.newCall(Mockito.any(MethodDescriptor.class))).thenReturn(call);
+ }
+
+ @Test
+ public void testReceive() throws Exception {
+ ContextExchangeChannel exchange = new ContextExchangeChannel(channel);
+ Provider<SimpleResponse> auth =
+ exchange.receive("auth", Marshallers.forProto(SimpleResponse.PARSER));
+ // Should be null, nothing has happened
+ assertNull(auth.get());
+ TestService.TestServiceBlockingStub stub =
+ TestService.blockingClient(exchange);
+ callStub(stub);
+ assertEquals(RESP, auth.get());
+ exchange.clearLastReceived();
+ assertNull(auth.get());
+ }
+
+ @Test @SuppressWarnings("unchecked")
+ public void testSend() throws Exception {
+ ContextExchangeChannel exchange = new ContextExchangeChannel(channel);
+ exchange.send("auth", RESP, Marshallers.forProto(SimpleResponse.PARSER));
+ TestService.TestServiceBlockingStub stub =
+ TestService.blockingClient(exchange);
+ callStub(stub);
+ verify(call).sendContext(eq("auth"),
+ argThat(new BaseMatcher<InputStream>() {
+ @Override
+ public boolean matches(Object o) {
+ try {
+ // Just check the length, consuming the stream will fail the test and Mockito
+ // calls this more than once.
+ return ((InputStream) o).available() == RESP.getSerializedSize();
+ } catch (IOException ioe) {
+ throw new RuntimeException(ioe);
+ }
+ }
+
+ @Override
+ public void describeTo(Description description) {
+ }
+ }), (SettableFuture<Void>) isNull());
+ }
+
+ @SuppressWarnings("unchecked")
+ private void callStub(final TestService.TestServiceBlockingStub stub) throws Exception {
+ when(channel.newCall(Mockito.<MethodDescriptor>any())).thenReturn(call);
+
+ // execute the call in another thread so we don't deadlock waiting for the
+ // listener.onClose
+ Future<?> pending = Executors.newSingleThreadExecutor().submit(new Runnable() {
+ @Override
+ public void run() {
+ stub.unaryCall(REQ);
+ }
+ });
+ ArgumentCaptor<Call.Listener> listenerCapture = ArgumentCaptor.forClass(Call.Listener.class);
+ // Wait for the call to start to capture the listener
+ verify(call, timeout(1000)).start(listenerCapture.capture());
+
+ ByteString response = RESP.toByteString();
+ Call.Listener listener = listenerCapture.getValue();
+ // Respond with a context-value
+ listener.onContext("auth", response.newInput());
+ listener.onContext("something-else", response.newInput());
+ // .. and single payload
+ listener.onPayload(RESP);
+ listener.onClose(Status.OK);
+ pending.get();
+ }
+}