core: ServerBuilder.intercept(). (#3118)
This adds server-wide interceptors that applies to all call handlers.
Because ServerCallHandler is acquired per request, and can be dynamicly
provided by the fallback registry, the interceptors have to be installed
on a per-request basis. This adds a few object allocations per request,
which is acceptable.
diff --git a/core/src/main/java/io/grpc/InternalServerInterceptors.java b/core/src/main/java/io/grpc/InternalServerInterceptors.java
new file mode 100644
index 0000000..e981aa6
--- /dev/null
+++ b/core/src/main/java/io/grpc/InternalServerInterceptors.java
@@ -0,0 +1,31 @@
+/*
+ * Copyright 2017, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.grpc;
+
+/**
+ * Accessor to internal methods of {@link ServerInterceptors}.
+ */
+@Internal
+public final class InternalServerInterceptors {
+ public static <ReqT, RespT> ServerCallHandler<ReqT, RespT> interceptCallHandler(
+ ServerInterceptor interceptor, ServerCallHandler<ReqT, RespT> callHandler) {
+ return ServerInterceptors.InterceptCallHandler.create(interceptor, callHandler);
+ }
+
+ private InternalServerInterceptors() {
+ }
+}
diff --git a/core/src/main/java/io/grpc/ServerBuilder.java b/core/src/main/java/io/grpc/ServerBuilder.java
index 53c851d..87008ac 100644
--- a/core/src/main/java/io/grpc/ServerBuilder.java
+++ b/core/src/main/java/io/grpc/ServerBuilder.java
@@ -89,6 +89,20 @@
public abstract T addService(BindableService bindableService);
/**
+ * Adds a {@link ServerInterceptor} that is run for all services on the server. Interceptors
+ * added through this method always run before per-service interceptors added through {@link
+ * ServerInterceptors}. Interceptors run in the reverse order in which they are added.
+ *
+ * @param interceptor the all-service interceptor
+ * @return this
+ * @since 1.5.0
+ */
+ @ExperimentalApi("https://github.com/grpc/grpc-java/issues/3117")
+ public T intercept(ServerInterceptor interceptor) {
+ throw new UnsupportedOperationException();
+ }
+
+ /**
* Adds a {@link ServerTransportFilter}. The order of filters being added is the order they will
* be executed.
*
diff --git a/core/src/main/java/io/grpc/ServerInterceptors.java b/core/src/main/java/io/grpc/ServerInterceptors.java
index 1c3cd52..7917a7d 100644
--- a/core/src/main/java/io/grpc/ServerInterceptors.java
+++ b/core/src/main/java/io/grpc/ServerInterceptors.java
@@ -207,7 +207,7 @@
serviceDefBuilder.addMethod(method.withServerCallHandler(callHandler));
}
- private static class InterceptCallHandler<ReqT, RespT> implements ServerCallHandler<ReqT, RespT> {
+ static class InterceptCallHandler<ReqT, RespT> implements ServerCallHandler<ReqT, RespT> {
public static <ReqT, RespT> InterceptCallHandler<ReqT, RespT> create(
ServerInterceptor interceptor, ServerCallHandler<ReqT, RespT> callHandler) {
return new InterceptCallHandler<ReqT, RespT>(interceptor, callHandler);
diff --git a/core/src/main/java/io/grpc/internal/AbstractServerImplBuilder.java b/core/src/main/java/io/grpc/internal/AbstractServerImplBuilder.java
index 707f6f4..b1bc50c 100644
--- a/core/src/main/java/io/grpc/internal/AbstractServerImplBuilder.java
+++ b/core/src/main/java/io/grpc/internal/AbstractServerImplBuilder.java
@@ -33,6 +33,7 @@
import io.grpc.InternalNotifyOnServerBuild;
import io.grpc.Server;
import io.grpc.ServerBuilder;
+import io.grpc.ServerInterceptor;
import io.grpc.ServerMethodDefinition;
import io.grpc.ServerServiceDefinition;
import io.grpc.ServerStreamTracer;
@@ -70,6 +71,9 @@
private final ArrayList<ServerTransportFilter> transportFilters =
new ArrayList<ServerTransportFilter>();
+ private final ArrayList<ServerInterceptor> interceptors =
+ new ArrayList<ServerInterceptor>();
+
private final List<InternalNotifyOnServerBuild> notifyOnBuildList =
new ArrayList<InternalNotifyOnServerBuild>();
@@ -123,6 +127,12 @@
}
@Override
+ public final T intercept(ServerInterceptor interceptor) {
+ interceptors.add(interceptor);
+ return thisT();
+ }
+
+ @Override
public final T addStreamTracerFactory(ServerStreamTracer.Factory factory) {
streamTracerFactories.add(checkNotNull(factory, "factory"));
return thisT();
@@ -179,7 +189,7 @@
firstNonNull(fallbackRegistry, EMPTY_FALLBACK_REGISTRY), transportServer,
Context.ROOT, firstNonNull(decompressorRegistry, DecompressorRegistry.getDefaultInstance()),
firstNonNull(compressorRegistry, CompressorRegistry.getDefaultInstance()),
- transportFilters);
+ transportFilters, interceptors);
for (InternalNotifyOnServerBuild notifyTarget : notifyOnBuildList) {
notifyTarget.notifyOnBuild(server);
}
diff --git a/core/src/main/java/io/grpc/internal/ServerImpl.java b/core/src/main/java/io/grpc/internal/ServerImpl.java
index 27d1432..a0de5d6 100644
--- a/core/src/main/java/io/grpc/internal/ServerImpl.java
+++ b/core/src/main/java/io/grpc/internal/ServerImpl.java
@@ -32,8 +32,11 @@
import io.grpc.Decompressor;
import io.grpc.DecompressorRegistry;
import io.grpc.HandlerRegistry;
+import io.grpc.InternalServerInterceptors;
import io.grpc.Metadata;
import io.grpc.ServerCall;
+import io.grpc.ServerCallHandler;
+import io.grpc.ServerInterceptor;
import io.grpc.ServerMethodDefinition;
import io.grpc.ServerServiceDefinition;
import io.grpc.ServerTransportFilter;
@@ -74,6 +77,9 @@
private final InternalHandlerRegistry registry;
private final HandlerRegistry fallbackRegistry;
private final List<ServerTransportFilter> transportFilters;
+ // This is iterated on a per-call basis. Use an array instead of a Collection to avoid iterator
+ // creations.
+ private final ServerInterceptor[] interceptors;
@GuardedBy("lock") private boolean started;
@GuardedBy("lock") private boolean shutdown;
/** non-{@code null} if immediate shutdown has been requested. */
@@ -109,7 +115,7 @@
InternalHandlerRegistry registry, HandlerRegistry fallbackRegistry,
InternalServer transportServer, Context rootContext,
DecompressorRegistry decompressorRegistry, CompressorRegistry compressorRegistry,
- List<ServerTransportFilter> transportFilters) {
+ List<ServerTransportFilter> transportFilters, List<ServerInterceptor> interceptors) {
this.executorPool = Preconditions.checkNotNull(executorPool, "executorPool");
this.timeoutServicePool = Preconditions.checkNotNull(timeoutServicePool, "timeoutServicePool");
this.registry = Preconditions.checkNotNull(registry, "registry");
@@ -122,6 +128,7 @@
this.compressorRegistry = compressorRegistry;
this.transportFilters = Collections.unmodifiableList(
new ArrayList<ServerTransportFilter>(transportFilters));
+ this.interceptors = interceptors.toArray(new ServerInterceptor[interceptors.size()]);
}
/**
@@ -469,9 +476,12 @@
ServerCallImpl<ReqT, RespT> call = new ServerCallImpl<ReqT, RespT>(
stream, methodDef.getMethodDescriptor(), headers, context,
decompressorRegistry, compressorRegistry);
+ ServerCallHandler<ReqT, RespT> callHandler = methodDef.getServerCallHandler();
statsTraceCtx.serverCallStarted(call);
- ServerCall.Listener<ReqT> listener =
- methodDef.getServerCallHandler().startCall(call, headers);
+ for (ServerInterceptor interceptor : interceptors) {
+ callHandler = InternalServerInterceptors.interceptCallHandler(interceptor, callHandler);
+ }
+ ServerCall.Listener<ReqT> listener = callHandler.startCall(call, headers);
if (listener == null) {
throw new NullPointerException(
"startCall() returned a null listener for method " + fullMethodName);
diff --git a/core/src/test/java/io/grpc/internal/ServerImplTest.java b/core/src/test/java/io/grpc/internal/ServerImplTest.java
index b95c7ed..ce28591 100644
--- a/core/src/test/java/io/grpc/internal/ServerImplTest.java
+++ b/core/src/test/java/io/grpc/internal/ServerImplTest.java
@@ -57,6 +57,7 @@
import io.grpc.MethodDescriptor;
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
+import io.grpc.ServerInterceptor;
import io.grpc.ServerServiceDefinition;
import io.grpc.ServerStreamTracer;
import io.grpc.ServerTransportFilter;
@@ -70,6 +71,8 @@
import java.io.InputStream;
import java.net.SocketAddress;
import java.util.Arrays;
+import java.util.Collections;
+import java.util.LinkedList;
import java.util.List;
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.Executor;
@@ -96,6 +99,13 @@
public class ServerImplTest {
private static final IntegerMarshaller INTEGER_MARSHALLER = IntegerMarshaller.INSTANCE;
private static final StringMarshaller STRING_MARSHALLER = StringMarshaller.INSTANCE;
+ private static final MethodDescriptor<String, Integer> METHOD =
+ MethodDescriptor.<String, Integer>newBuilder()
+ .setType(MethodDescriptor.MethodType.UNKNOWN)
+ .setFullMethodName("Waiter/serve")
+ .setRequestMarshaller(STRING_MARSHALLER)
+ .setResponseMarshaller(INTEGER_MARSHALLER)
+ .build();
private static final Context.Key<String> SERVER_ONLY = Context.key("serverOnly");
private static final Context.Key<String> SERVER_TRACER_ADDED_KEY = Context.key("tracer-added");
private static final Context.CancellableContext SERVER_CONTEXT =
@@ -402,16 +412,10 @@
final AtomicReference<ServerCall<String, Integer>> callReference
= new AtomicReference<ServerCall<String, Integer>>();
final AtomicReference<Context> callContextReference = new AtomicReference<Context>();
- MethodDescriptor<String, Integer> method = MethodDescriptor.<String, Integer>newBuilder()
- .setType(MethodDescriptor.MethodType.UNKNOWN)
- .setFullMethodName("Waiter/serve")
- .setRequestMarshaller(STRING_MARSHALLER)
- .setResponseMarshaller(INTEGER_MARSHALLER)
- .build();
mutableFallbackRegistry.addService(ServerServiceDefinition.builder(
- new ServiceDescriptor("Waiter", method))
+ new ServiceDescriptor("Waiter", METHOD))
.addMethod(
- method,
+ METHOD,
new ServerCallHandler<String, Integer>() {
@Override
public ServerCall.Listener<String> startCall(
@@ -570,18 +574,95 @@
}
@Test
+ public void interceptors() throws Exception {
+ final LinkedList<Context> capturedContexts = new LinkedList<Context>();
+ final Context.Key<String> key1 = Context.key("key1");
+ final Context.Key<String> key2 = Context.key("key2");
+ final Context.Key<String> key3 = Context.key("key3");
+ ServerInterceptor intercepter1 = new ServerInterceptor() {
+ @Override
+ public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
+ ServerCall<ReqT, RespT> call,
+ Metadata headers,
+ ServerCallHandler<ReqT, RespT> next) {
+ Context ctx = Context.current().withValue(key1, "value1");
+ Context origCtx = ctx.attach();
+ try {
+ capturedContexts.add(ctx);
+ return next.startCall(call, headers);
+ } finally {
+ ctx.detach(origCtx);
+ }
+ }
+ };
+ ServerInterceptor intercepter2 = new ServerInterceptor() {
+ @Override
+ public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
+ ServerCall<ReqT, RespT> call,
+ Metadata headers,
+ ServerCallHandler<ReqT, RespT> next) {
+ Context ctx = Context.current().withValue(key2, "value2");
+ Context origCtx = ctx.attach();
+ try {
+ capturedContexts.add(ctx);
+ return next.startCall(call, headers);
+ } finally {
+ ctx.detach(origCtx);
+ }
+ }
+ };
+ ServerCallHandler<String, Integer> callHandler = new ServerCallHandler<String, Integer>() {
+ @Override
+ public ServerCall.Listener<String> startCall(
+ ServerCall<String, Integer> call,
+ Metadata headers) {
+ capturedContexts.add(Context.current().withValue(key3, "value3"));
+ return callListener;
+ }
+ };
+
+ mutableFallbackRegistry.addService(
+ ServerServiceDefinition.builder(new ServiceDescriptor("Waiter", METHOD))
+ .addMethod(METHOD, callHandler).build());
+ createServer(NO_FILTERS, Arrays.asList(intercepter2, intercepter1));
+ server.start();
+
+ ServerTransportListener transportListener
+ = transportServer.registerNewServerTransport(new SimpleServerTransport());
+
+ Metadata requestHeaders = new Metadata();
+ StatsTraceContext statsTraceCtx =
+ StatsTraceContext.newServerContext(streamTracerFactories, "Waiter/serve", requestHeaders);
+ when(stream.statsTraceContext()).thenReturn(statsTraceCtx);
+
+ transportListener.streamCreated(stream, "Waiter/serve", requestHeaders);
+ assertEquals(1, executor.runDueTasks());
+
+ Context ctx1 = capturedContexts.poll();
+ assertEquals("value1", key1.get(ctx1));
+ assertNull(key2.get(ctx1));
+ assertNull(key3.get(ctx1));
+
+ Context ctx2 = capturedContexts.poll();
+ assertEquals("value1", key1.get(ctx2));
+ assertEquals("value2", key2.get(ctx2));
+ assertNull(key3.get(ctx2));
+
+ Context ctx3 = capturedContexts.poll();
+ assertEquals("value1", key1.get(ctx3));
+ assertEquals("value2", key2.get(ctx3));
+ assertEquals("value3", key3.get(ctx3));
+
+ assertTrue(capturedContexts.isEmpty());
+ }
+
+ @Test
public void exceptionInStartCallPropagatesToStream() throws Exception {
createAndStartServer(NO_FILTERS);
final Status status = Status.ABORTED.withDescription("Oh, no!");
- MethodDescriptor<String, Integer> method = MethodDescriptor.<String, Integer>newBuilder()
- .setType(MethodDescriptor.MethodType.UNKNOWN)
- .setFullMethodName("Waiter/serve")
- .setRequestMarshaller(STRING_MARSHALLER)
- .setResponseMarshaller(INTEGER_MARSHALLER)
- .build();
mutableFallbackRegistry.addService(ServerServiceDefinition.builder(
- new ServiceDescriptor("Waiter", method))
- .addMethod(method,
+ new ServiceDescriptor("Waiter", METHOD))
+ .addMethod(METHOD,
new ServerCallHandler<String, Integer>() {
@Override
public ServerCall.Listener<String> startCall(
@@ -695,20 +776,14 @@
@Test
public void testCallContextIsBoundInListenerCallbacks() throws Exception {
createAndStartServer(NO_FILTERS);
- MethodDescriptor<String, Integer> method = MethodDescriptor.<String, Integer>newBuilder()
- .setType(MethodDescriptor.MethodType.UNKNOWN)
- .setFullMethodName("Waiter/serve")
- .setRequestMarshaller(STRING_MARSHALLER)
- .setResponseMarshaller(INTEGER_MARSHALLER)
- .build();
final AtomicBoolean onReadyCalled = new AtomicBoolean(false);
final AtomicBoolean onMessageCalled = new AtomicBoolean(false);
final AtomicBoolean onHalfCloseCalled = new AtomicBoolean(false);
final AtomicBoolean onCancelCalled = new AtomicBoolean(false);
mutableFallbackRegistry.addService(ServerServiceDefinition.builder(
- new ServiceDescriptor("Waiter", method))
+ new ServiceDescriptor("Waiter", METHOD))
.addMethod(
- method,
+ METHOD,
new ServerCallHandler<String, Integer>() {
@Override
public ServerCall.Listener<String> startCall(
@@ -809,16 +884,9 @@
}
};
- MethodDescriptor<String, Integer> method = MethodDescriptor.<String, Integer>newBuilder()
- .setType(MethodDescriptor.MethodType.UNKNOWN)
- .setFullMethodName("Waiter/serve")
- .setRequestMarshaller(STRING_MARSHALLER)
- .setResponseMarshaller(INTEGER_MARSHALLER)
- .build();
-
mutableFallbackRegistry.addService(ServerServiceDefinition.builder(
- new ServiceDescriptor("Waiter", method))
- .addMethod(method,
+ new ServiceDescriptor("Waiter", METHOD))
+ .addMethod(METHOD,
new ServerCallHandler<String, Integer>() {
@Override
public ServerCall.Listener<String> startCall(
@@ -928,15 +996,9 @@
@Test
public void handlerRegistryPriorities() throws Exception {
fallbackRegistry = mock(HandlerRegistry.class);
- MethodDescriptor<String, Integer> method1 = MethodDescriptor.<String, Integer>newBuilder()
- .setType(MethodDescriptor.MethodType.UNKNOWN)
- .setFullMethodName("Service1/Method1")
- .setRequestMarshaller(STRING_MARSHALLER)
- .setResponseMarshaller(INTEGER_MARSHALLER)
- .build();
registry = new InternalHandlerRegistry.Builder()
- .addService(ServerServiceDefinition.builder(new ServiceDescriptor("Service1", method1))
- .addMethod(method1, callHandler).build())
+ .addService(ServerServiceDefinition.builder(new ServiceDescriptor("Waiter", METHOD))
+ .addMethod(METHOD, callHandler).build())
.build();
transportServer = new SimpleServer();
createAndStartServer(NO_FILTERS);
@@ -945,11 +1007,11 @@
= transportServer.registerNewServerTransport(new SimpleServerTransport());
Metadata requestHeaders = new Metadata();
StatsTraceContext statsTraceCtx =
- StatsTraceContext.newServerContext(streamTracerFactories, "Waitier/serve", requestHeaders);
+ StatsTraceContext.newServerContext(streamTracerFactories, "Waiter/serve", requestHeaders);
when(stream.statsTraceContext()).thenReturn(statsTraceCtx);
// This call will be handled by callHandler from the internal registry
- transportListener.streamCreated(stream, "Service1/Method1", requestHeaders);
+ transportListener.streamCreated(stream, "Waiter/serve", requestHeaders);
assertEquals(1, executor.runDueTasks());
verify(callHandler).startCall(Matchers.<ServerCall<String, Integer>>anyObject(),
Matchers.<Metadata>anyObject());
@@ -1109,9 +1171,15 @@
}
private void createServer(List<ServerTransportFilter> filters) {
+ createServer(filters, Collections.<ServerInterceptor>emptyList());
+ }
+
+ private void createServer(
+ List<ServerTransportFilter> filters, List<ServerInterceptor> interceptors) {
assertNull(server);
server = new ServerImpl(executorPool, timerPool, registry, fallbackRegistry,
- transportServer, SERVER_CONTEXT, decompressorRegistry, compressorRegistry, filters);
+ transportServer, SERVER_CONTEXT, decompressorRegistry, compressorRegistry, filters,
+ interceptors);
}
private void verifyExecutorsAcquired() {