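Add ClientInterceptors "inside" ChannelImpl

Channel builders can now register ClientInterceptors that ChannelImpl itself applies to each new
call, so callers keep direct access to the ChannelImpl instead of only a wrapping Channel.

Fixes #538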
diff --git a/core/src/main/java/io/grpc/AbstractChannelBuilder.java b/core/src/main/java/io/grpc/AbstractChannelBuilder.java
index 4b142cc..17ff0d0 100644
--- a/core/src/main/java/io/grpc/AbstractChannelBuilder.java
+++ b/core/src/main/java/io/grpc/AbstractChannelBuilder.java
@@ -37,6 +37,9 @@
import io.grpc.SharedResourceHolder.Resource;
import io.grpc.transport.ClientTransportFactory;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
@@ -70,6 +73,7 @@
@Nullable
private ExecutorService userExecutor;
+ private final List<ClientInterceptor> interceptors = new ArrayList<ClientInterceptor>();
@Nullable
private String userAgent;
@@ -83,10 +87,34 @@
* <p>The channel won't take ownership of the given executor. It's caller's responsibility to
* shut down the executor when it's desired.
*/
- @SuppressWarnings("unchecked")
public final BuilderT executor(ExecutorService executor) {
userExecutor = executor;
- return (BuilderT) this;
+ return thisT();
+ }
+
+ /**
+ * Registers {@code interceptors} to run before the channel does its real work for a call. The
+ * effect matches wrapping the built channel with {@link ClientInterceptors#intercept(Channel,
+ * List)}, except that the built {@code ChannelImpl} itself remains directly accessible.
+ */
+ public final BuilderT intercept(List<ClientInterceptor> interceptors) {
+ for (ClientInterceptor each : interceptors) {
+ this.interceptors.add(each);
+ }
+ return thisT();
+ }
+
+ /**
+ * Varargs convenience form of {@link #intercept(List)}. The effect matches
+ * {@link ClientInterceptors#intercept(Channel, ClientInterceptor...)}, except that the built
+ * {@code ChannelImpl} itself remains directly accessible.
+ */
+ public final BuilderT intercept(ClientInterceptor... interceptors) {
+ List<ClientInterceptor> interceptorList = Arrays.asList(interceptors);
+ return intercept(interceptorList);
+ }
+
+ /**
+ * Returns {@code this} cast to {@code BuilderT}, so fluent setters such as {@link #executor} can
+ * return the caller's builder type.
+ */
+ private BuilderT thisT() {
+ // Suppression is kept on this single local declaration rather than widened to the method.
+ @SuppressWarnings("unchecked")
+ BuilderT self = (BuilderT) this;
+ return self;
}
/**
@@ -116,7 +144,8 @@
}
final ChannelEssentials essentials = buildEssentials();
- ChannelImpl channel = new ChannelImpl(essentials.transportFactory, executor, userAgent);
+ ChannelImpl channel = new ChannelImpl(essentials.transportFactory, executor, userAgent,
+ interceptors);
channel.setTerminationRunnable(new Runnable() {
@Override
public void run() {
diff --git a/core/src/main/java/io/grpc/ChannelImpl.java b/core/src/main/java/io/grpc/ChannelImpl.java
index f2c40df..704fd78 100644
--- a/core/src/main/java/io/grpc/ChannelImpl.java
+++ b/core/src/main/java/io/grpc/ChannelImpl.java
@@ -46,6 +46,7 @@
import java.io.InputStream;
import java.util.ArrayList;
import java.util.Collection;
+import java.util.List;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
@@ -94,6 +95,11 @@
*/
private ScheduledExecutorService deadlineCancellationExecutor;
/**
+ * We delegate to this channel, so that we can have interceptors as necessary. If there aren't
+ * any interceptors this will just be {@link RealChannel}.
+ */
+ private final Channel interceptorChannel;
+ /**
* All transports that are not stopped. At the very least {@link #activeTransport} will be
* present, but previously used transports that still have streams or are stopping may also be
* present.
@@ -112,10 +118,11 @@
private Runnable terminationRunnable;
ChannelImpl(ClientTransportFactory transportFactory, ExecutorService executor,
- @Nullable String userAgent) {
+ @Nullable String userAgent, List<ClientInterceptor> interceptors) {
this.transportFactory = transportFactory;
this.executor = executor;
this.userAgent = userAgent;
+ // newCall() delegates to this chain, so calls are routed through the interceptors before
+ // reaching RealChannel, which creates the actual CallImpl.
+ this.interceptorChannel = ClientInterceptors.intercept(new RealChannel(), interceptors);
deadlineCancellationExecutor = SharedResourceHolder.get(TIMER_SERVICE);
}
@@ -227,9 +234,9 @@
* Creates a new outgoing call on the channel.
*/
@Override
public <ReqT, RespT> ClientCall<ReqT, RespT> newCall(
MethodDescriptor<ReqT, RespT> method, CallOptions callOptions) {
- return new CallImpl<ReqT, RespT>(method, new SerializingExecutor(executor), callOptions);
+ return interceptorChannel.newCall(method, callOptions); // interceptors, then RealChannel
}
private ClientTransport obtainActiveTransport() {
@@ -267,6 +274,14 @@
}
}
+ /**
+ * Innermost channel of the interceptor chain. It creates the {@link CallImpl} that the
+ * interceptors ultimately delegate to, giving each call its own {@link SerializingExecutor}
+ * over this channel's {@code executor}.
+ */
+ private class RealChannel extends Channel {
+ @Override
+ public <ReqT, RespT> ClientCall<ReqT, RespT> newCall(MethodDescriptor<ReqT, RespT> method,
+ CallOptions callOptions) {
+ SerializingExecutor callExecutor = new SerializingExecutor(executor);
+ return new CallImpl<ReqT, RespT>(method, callExecutor, callOptions);
+ }
+ }
+
private class TransportListener implements ClientTransport.Listener {
private final ClientTransport transport;
diff --git a/core/src/test/java/io/grpc/ChannelImplTest.java b/core/src/test/java/io/grpc/ChannelImplTest.java
index 4bd2cb3..f6ee1c5 100644
--- a/core/src/test/java/io/grpc/ChannelImplTest.java
+++ b/core/src/test/java/io/grpc/ChannelImplTest.java
@@ -33,6 +33,7 @@
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertSame;
import static org.junit.Assert.assertTrue;
import static org.mockito.Matchers.any;
@@ -62,8 +63,11 @@
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
+import java.util.Arrays;
+import java.util.Collections;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
+import java.util.concurrent.atomic.AtomicLong;
/** Unit tests for {@link ChannelImpl}. */
@RunWith(JUnit4.class)
@@ -94,7 +98,8 @@
@Before
public void setUp() {
MockitoAnnotations.initMocks(this);
- channel = new ChannelImpl(mockTransportFactory, executor, null);
+ // Shared fixture uses no interceptors; interceptor() builds its own channel with one.
+ channel = new ChannelImpl(mockTransportFactory, executor, null,
+ Collections.<ClientInterceptor>emptyList());
when(mockTransportFactory.newClientTransport()).thenReturn(mockTransport);
}
@@ -283,4 +288,21 @@
verifyNoMoreInteractions(mockStream2);
verifyNoMoreInteractions(mockStream3);
}
+
+ @Test
+ public void interceptor() {
+ // Count invocations rather than setting a flag, so a duplicated interception also fails.
+ final AtomicLong interceptCount = new AtomicLong();
+ ClientInterceptor interceptor = new ClientInterceptor() {
+ @Override
+ public <RequestT, ResponseT> ClientCall<RequestT, ResponseT> interceptCall(
+ MethodDescriptor<RequestT, ResponseT> method, CallOptions callOptions,
+ Channel next) {
+ interceptCount.incrementAndGet();
+ // The interceptor must see the exact descriptor the caller handed to newCall().
+ assertSame(ChannelImplTest.this.method, method);
+ return next.newCall(method, callOptions);
+ }
+ };
+ channel = new ChannelImpl(mockTransportFactory, executor, null, Arrays.asList(interceptor));
+ assertNotNull(channel.newCall(method, CallOptions.DEFAULT));
+ assertEquals(1, interceptCount.get());
+ }
}