core: ServerBuilder.intercept(). (#3118)

This adds server-wide interceptors that applies to all call handlers.
Because ServerCallHandler is acquired per request, and can be dynamicly
provided by the fallback registry, the interceptors have to be installed
on a per-request basis. This adds a few object allocations per request,
which is acceptable.
diff --git a/core/src/main/java/io/grpc/InternalServerInterceptors.java b/core/src/main/java/io/grpc/InternalServerInterceptors.java
new file mode 100644
index 0000000..e981aa6
--- /dev/null
+++ b/core/src/main/java/io/grpc/InternalServerInterceptors.java
@@ -0,0 +1,31 @@
+/*
+ * Copyright 2017, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.grpc;
+
+/**
+ * Accessor to internal methods of {@link ServerInterceptors}.
+ */
+@Internal
+public final class InternalServerInterceptors {
+  public static <ReqT, RespT> ServerCallHandler<ReqT, RespT> interceptCallHandler(
+      ServerInterceptor interceptor, ServerCallHandler<ReqT, RespT> callHandler) {
+    return ServerInterceptors.InterceptCallHandler.create(interceptor, callHandler);
+  }
+
+  private InternalServerInterceptors() {
+  }
+}
diff --git a/core/src/main/java/io/grpc/ServerBuilder.java b/core/src/main/java/io/grpc/ServerBuilder.java
index 53c851d..87008ac 100644
--- a/core/src/main/java/io/grpc/ServerBuilder.java
+++ b/core/src/main/java/io/grpc/ServerBuilder.java
@@ -89,6 +89,20 @@
   public abstract T addService(BindableService bindableService);
 
   /**
+   * Adds a {@link ServerInterceptor} that is run for all services on the server.  Interceptors
+   * added through this method always run before per-service interceptors added through {@link
+   * ServerInterceptors}.  Interceptors run in the reverse order in which they are added.
+   *
+   * @param interceptor the all-service interceptor
+   * @return this
+   * @since 1.5.0
+   */
+  @ExperimentalApi("https://github.com/grpc/grpc-java/issues/3117")
+  public T intercept(ServerInterceptor interceptor) {
+    throw new UnsupportedOperationException();
+  }
+
+  /**
    * Adds a {@link ServerTransportFilter}. The order of filters being added is the order they will
    * be executed.
    *
diff --git a/core/src/main/java/io/grpc/ServerInterceptors.java b/core/src/main/java/io/grpc/ServerInterceptors.java
index 1c3cd52..7917a7d 100644
--- a/core/src/main/java/io/grpc/ServerInterceptors.java
+++ b/core/src/main/java/io/grpc/ServerInterceptors.java
@@ -207,7 +207,7 @@
     serviceDefBuilder.addMethod(method.withServerCallHandler(callHandler));
   }
 
-  private static class InterceptCallHandler<ReqT, RespT> implements ServerCallHandler<ReqT, RespT> {
+  static class InterceptCallHandler<ReqT, RespT> implements ServerCallHandler<ReqT, RespT> {
     public static <ReqT, RespT> InterceptCallHandler<ReqT, RespT> create(
         ServerInterceptor interceptor, ServerCallHandler<ReqT, RespT> callHandler) {
       return new InterceptCallHandler<ReqT, RespT>(interceptor, callHandler);
diff --git a/core/src/main/java/io/grpc/internal/AbstractServerImplBuilder.java b/core/src/main/java/io/grpc/internal/AbstractServerImplBuilder.java
index 707f6f4..b1bc50c 100644
--- a/core/src/main/java/io/grpc/internal/AbstractServerImplBuilder.java
+++ b/core/src/main/java/io/grpc/internal/AbstractServerImplBuilder.java
@@ -33,6 +33,7 @@
 import io.grpc.InternalNotifyOnServerBuild;
 import io.grpc.Server;
 import io.grpc.ServerBuilder;
+import io.grpc.ServerInterceptor;
 import io.grpc.ServerMethodDefinition;
 import io.grpc.ServerServiceDefinition;
 import io.grpc.ServerStreamTracer;
@@ -70,6 +71,9 @@
   private final ArrayList<ServerTransportFilter> transportFilters =
       new ArrayList<ServerTransportFilter>();
 
+  private final ArrayList<ServerInterceptor> interceptors =
+      new ArrayList<ServerInterceptor>();
+
   private final List<InternalNotifyOnServerBuild> notifyOnBuildList =
       new ArrayList<InternalNotifyOnServerBuild>();
 
@@ -123,6 +127,12 @@
   }
 
   @Override
+  public final T intercept(ServerInterceptor interceptor) {
+    interceptors.add(interceptor);
+    return thisT();
+  }
+
+  @Override
   public final T addStreamTracerFactory(ServerStreamTracer.Factory factory) {
     streamTracerFactories.add(checkNotNull(factory, "factory"));
     return thisT();
@@ -179,7 +189,7 @@
         firstNonNull(fallbackRegistry, EMPTY_FALLBACK_REGISTRY), transportServer,
         Context.ROOT, firstNonNull(decompressorRegistry, DecompressorRegistry.getDefaultInstance()),
         firstNonNull(compressorRegistry, CompressorRegistry.getDefaultInstance()),
-        transportFilters);
+        transportFilters, interceptors);
     for (InternalNotifyOnServerBuild notifyTarget : notifyOnBuildList) {
       notifyTarget.notifyOnBuild(server);
     }
diff --git a/core/src/main/java/io/grpc/internal/ServerImpl.java b/core/src/main/java/io/grpc/internal/ServerImpl.java
index 27d1432..a0de5d6 100644
--- a/core/src/main/java/io/grpc/internal/ServerImpl.java
+++ b/core/src/main/java/io/grpc/internal/ServerImpl.java
@@ -32,8 +32,11 @@
 import io.grpc.Decompressor;
 import io.grpc.DecompressorRegistry;
 import io.grpc.HandlerRegistry;
+import io.grpc.InternalServerInterceptors;
 import io.grpc.Metadata;
 import io.grpc.ServerCall;
+import io.grpc.ServerCallHandler;
+import io.grpc.ServerInterceptor;
 import io.grpc.ServerMethodDefinition;
 import io.grpc.ServerServiceDefinition;
 import io.grpc.ServerTransportFilter;
@@ -74,6 +77,9 @@
   private final InternalHandlerRegistry registry;
   private final HandlerRegistry fallbackRegistry;
   private final List<ServerTransportFilter> transportFilters;
+  // This is iterated on a per-call basis.  Use an array instead of a Collection to avoid iterator
+  // creations.
+  private final ServerInterceptor[] interceptors;
   @GuardedBy("lock") private boolean started;
   @GuardedBy("lock") private boolean shutdown;
   /** non-{@code null} if immediate shutdown has been requested. */
@@ -109,7 +115,7 @@
       InternalHandlerRegistry registry, HandlerRegistry fallbackRegistry,
       InternalServer transportServer, Context rootContext,
       DecompressorRegistry decompressorRegistry, CompressorRegistry compressorRegistry,
-      List<ServerTransportFilter> transportFilters) {
+      List<ServerTransportFilter> transportFilters, List<ServerInterceptor> interceptors) {
     this.executorPool = Preconditions.checkNotNull(executorPool, "executorPool");
     this.timeoutServicePool = Preconditions.checkNotNull(timeoutServicePool, "timeoutServicePool");
     this.registry = Preconditions.checkNotNull(registry, "registry");
@@ -122,6 +128,7 @@
     this.compressorRegistry = compressorRegistry;
     this.transportFilters = Collections.unmodifiableList(
         new ArrayList<ServerTransportFilter>(transportFilters));
+    this.interceptors = interceptors.toArray(new ServerInterceptor[interceptors.size()]);
   }
 
   /**
@@ -469,9 +476,12 @@
       ServerCallImpl<ReqT, RespT> call = new ServerCallImpl<ReqT, RespT>(
           stream, methodDef.getMethodDescriptor(), headers, context,
           decompressorRegistry, compressorRegistry);
+      ServerCallHandler<ReqT, RespT> callHandler = methodDef.getServerCallHandler();
       statsTraceCtx.serverCallStarted(call);
-      ServerCall.Listener<ReqT> listener =
-          methodDef.getServerCallHandler().startCall(call, headers);
+      for (ServerInterceptor interceptor : interceptors) {
+        callHandler = InternalServerInterceptors.interceptCallHandler(interceptor, callHandler);
+      }
+      ServerCall.Listener<ReqT> listener = callHandler.startCall(call, headers);
       if (listener == null) {
         throw new NullPointerException(
             "startCall() returned a null listener for method " + fullMethodName);
diff --git a/core/src/test/java/io/grpc/internal/ServerImplTest.java b/core/src/test/java/io/grpc/internal/ServerImplTest.java
index b95c7ed..ce28591 100644
--- a/core/src/test/java/io/grpc/internal/ServerImplTest.java
+++ b/core/src/test/java/io/grpc/internal/ServerImplTest.java
@@ -57,6 +57,7 @@
 import io.grpc.MethodDescriptor;
 import io.grpc.ServerCall;
 import io.grpc.ServerCallHandler;
+import io.grpc.ServerInterceptor;
 import io.grpc.ServerServiceDefinition;
 import io.grpc.ServerStreamTracer;
 import io.grpc.ServerTransportFilter;
@@ -70,6 +71,8 @@
 import java.io.InputStream;
 import java.net.SocketAddress;
 import java.util.Arrays;
+import java.util.Collections;
+import java.util.LinkedList;
 import java.util.List;
 import java.util.concurrent.CyclicBarrier;
 import java.util.concurrent.Executor;
@@ -96,6 +99,13 @@
 public class ServerImplTest {
   private static final IntegerMarshaller INTEGER_MARSHALLER = IntegerMarshaller.INSTANCE;
   private static final StringMarshaller STRING_MARSHALLER = StringMarshaller.INSTANCE;
+  private static final MethodDescriptor<String, Integer> METHOD =
+      MethodDescriptor.<String, Integer>newBuilder()
+          .setType(MethodDescriptor.MethodType.UNKNOWN)
+          .setFullMethodName("Waiter/serve")
+          .setRequestMarshaller(STRING_MARSHALLER)
+          .setResponseMarshaller(INTEGER_MARSHALLER)
+          .build();
   private static final Context.Key<String> SERVER_ONLY = Context.key("serverOnly");
   private static final Context.Key<String> SERVER_TRACER_ADDED_KEY = Context.key("tracer-added");
   private static final Context.CancellableContext SERVER_CONTEXT =
@@ -402,16 +412,10 @@
     final AtomicReference<ServerCall<String, Integer>> callReference
         = new AtomicReference<ServerCall<String, Integer>>();
     final AtomicReference<Context> callContextReference = new AtomicReference<Context>();
-    MethodDescriptor<String, Integer> method = MethodDescriptor.<String, Integer>newBuilder()
-        .setType(MethodDescriptor.MethodType.UNKNOWN)
-        .setFullMethodName("Waiter/serve")
-        .setRequestMarshaller(STRING_MARSHALLER)
-        .setResponseMarshaller(INTEGER_MARSHALLER)
-        .build();
     mutableFallbackRegistry.addService(ServerServiceDefinition.builder(
-        new ServiceDescriptor("Waiter", method))
+        new ServiceDescriptor("Waiter", METHOD))
         .addMethod(
-            method,
+            METHOD,
             new ServerCallHandler<String, Integer>() {
               @Override
               public ServerCall.Listener<String> startCall(
@@ -570,18 +574,95 @@
   }
 
   @Test
+  public void interceptors() throws Exception {
+    final LinkedList<Context> capturedContexts = new LinkedList<Context>();
+    final Context.Key<String> key1 = Context.key("key1");
+    final Context.Key<String> key2 = Context.key("key2");
+    final Context.Key<String> key3 = Context.key("key3");
+    ServerInterceptor intercepter1 = new ServerInterceptor() {
+        @Override
+        public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
+            ServerCall<ReqT, RespT> call,
+            Metadata headers,
+            ServerCallHandler<ReqT, RespT> next) {
+          Context ctx = Context.current().withValue(key1, "value1");
+          Context origCtx = ctx.attach();
+          try {
+            capturedContexts.add(ctx);
+            return next.startCall(call, headers);
+          } finally {
+            ctx.detach(origCtx);
+          }
+        }
+      };
+    ServerInterceptor intercepter2 = new ServerInterceptor() {
+        @Override
+        public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
+            ServerCall<ReqT, RespT> call,
+            Metadata headers,
+            ServerCallHandler<ReqT, RespT> next) {
+          Context ctx = Context.current().withValue(key2, "value2");
+          Context origCtx = ctx.attach();
+          try {
+            capturedContexts.add(ctx);
+            return next.startCall(call, headers);
+          } finally {
+            ctx.detach(origCtx);
+          }
+        }
+      };
+    ServerCallHandler<String, Integer> callHandler = new ServerCallHandler<String, Integer>() {
+        @Override
+        public ServerCall.Listener<String> startCall(
+            ServerCall<String, Integer> call,
+            Metadata headers) {
+          capturedContexts.add(Context.current().withValue(key3, "value3"));
+          return callListener;
+        }
+      };
+
+    mutableFallbackRegistry.addService(
+        ServerServiceDefinition.builder(new ServiceDescriptor("Waiter", METHOD))
+            .addMethod(METHOD, callHandler).build());
+    createServer(NO_FILTERS, Arrays.asList(intercepter2, intercepter1));
+    server.start();
+
+    ServerTransportListener transportListener
+        = transportServer.registerNewServerTransport(new SimpleServerTransport());
+
+    Metadata requestHeaders = new Metadata();
+    StatsTraceContext statsTraceCtx =
+        StatsTraceContext.newServerContext(streamTracerFactories, "Waiter/serve", requestHeaders);
+    when(stream.statsTraceContext()).thenReturn(statsTraceCtx);
+
+    transportListener.streamCreated(stream, "Waiter/serve", requestHeaders);
+    assertEquals(1, executor.runDueTasks());
+
+    Context ctx1 = capturedContexts.poll();
+    assertEquals("value1", key1.get(ctx1));
+    assertNull(key2.get(ctx1));
+    assertNull(key3.get(ctx1));
+
+    Context ctx2 = capturedContexts.poll();
+    assertEquals("value1", key1.get(ctx2));
+    assertEquals("value2", key2.get(ctx2));
+    assertNull(key3.get(ctx2));
+
+    Context ctx3 = capturedContexts.poll();
+    assertEquals("value1", key1.get(ctx3));
+    assertEquals("value2", key2.get(ctx3));
+    assertEquals("value3", key3.get(ctx3));
+
+    assertTrue(capturedContexts.isEmpty());
+  }
+
+  @Test
   public void exceptionInStartCallPropagatesToStream() throws Exception {
     createAndStartServer(NO_FILTERS);
     final Status status = Status.ABORTED.withDescription("Oh, no!");
-    MethodDescriptor<String, Integer> method = MethodDescriptor.<String, Integer>newBuilder()
-        .setType(MethodDescriptor.MethodType.UNKNOWN)
-        .setFullMethodName("Waiter/serve")
-        .setRequestMarshaller(STRING_MARSHALLER)
-        .setResponseMarshaller(INTEGER_MARSHALLER)
-        .build();
     mutableFallbackRegistry.addService(ServerServiceDefinition.builder(
-        new ServiceDescriptor("Waiter", method))
-        .addMethod(method,
+        new ServiceDescriptor("Waiter", METHOD))
+        .addMethod(METHOD,
             new ServerCallHandler<String, Integer>() {
               @Override
               public ServerCall.Listener<String> startCall(
@@ -695,20 +776,14 @@
   @Test
   public void testCallContextIsBoundInListenerCallbacks() throws Exception {
     createAndStartServer(NO_FILTERS);
-    MethodDescriptor<String, Integer> method = MethodDescriptor.<String, Integer>newBuilder()
-        .setType(MethodDescriptor.MethodType.UNKNOWN)
-        .setFullMethodName("Waiter/serve")
-        .setRequestMarshaller(STRING_MARSHALLER)
-        .setResponseMarshaller(INTEGER_MARSHALLER)
-        .build();
     final AtomicBoolean  onReadyCalled = new AtomicBoolean(false);
     final AtomicBoolean onMessageCalled = new AtomicBoolean(false);
     final AtomicBoolean onHalfCloseCalled = new AtomicBoolean(false);
     final AtomicBoolean onCancelCalled = new AtomicBoolean(false);
     mutableFallbackRegistry.addService(ServerServiceDefinition.builder(
-        new ServiceDescriptor("Waiter", method))
+        new ServiceDescriptor("Waiter", METHOD))
         .addMethod(
-            method,
+            METHOD,
             new ServerCallHandler<String, Integer>() {
               @Override
               public ServerCall.Listener<String> startCall(
@@ -809,16 +884,9 @@
       }
     };
 
-    MethodDescriptor<String, Integer> method = MethodDescriptor.<String, Integer>newBuilder()
-        .setType(MethodDescriptor.MethodType.UNKNOWN)
-        .setFullMethodName("Waiter/serve")
-        .setRequestMarshaller(STRING_MARSHALLER)
-        .setResponseMarshaller(INTEGER_MARSHALLER)
-        .build();
-
     mutableFallbackRegistry.addService(ServerServiceDefinition.builder(
-        new ServiceDescriptor("Waiter", method))
-        .addMethod(method,
+        new ServiceDescriptor("Waiter", METHOD))
+        .addMethod(METHOD,
             new ServerCallHandler<String, Integer>() {
               @Override
               public ServerCall.Listener<String> startCall(
@@ -928,15 +996,9 @@
   @Test
   public void handlerRegistryPriorities() throws Exception {
     fallbackRegistry = mock(HandlerRegistry.class);
-    MethodDescriptor<String, Integer> method1 = MethodDescriptor.<String, Integer>newBuilder()
-        .setType(MethodDescriptor.MethodType.UNKNOWN)
-        .setFullMethodName("Service1/Method1")
-        .setRequestMarshaller(STRING_MARSHALLER)
-        .setResponseMarshaller(INTEGER_MARSHALLER)
-        .build();
     registry = new InternalHandlerRegistry.Builder()
-        .addService(ServerServiceDefinition.builder(new ServiceDescriptor("Service1", method1))
-            .addMethod(method1, callHandler).build())
+        .addService(ServerServiceDefinition.builder(new ServiceDescriptor("Waiter", METHOD))
+            .addMethod(METHOD, callHandler).build())
         .build();
     transportServer = new SimpleServer();
     createAndStartServer(NO_FILTERS);
@@ -945,11 +1007,11 @@
         = transportServer.registerNewServerTransport(new SimpleServerTransport());
     Metadata requestHeaders = new Metadata();
     StatsTraceContext statsTraceCtx =
-        StatsTraceContext.newServerContext(streamTracerFactories, "Waitier/serve", requestHeaders);
+        StatsTraceContext.newServerContext(streamTracerFactories, "Waiter/serve", requestHeaders);
     when(stream.statsTraceContext()).thenReturn(statsTraceCtx);
 
     // This call will be handled by callHandler from the internal registry
-    transportListener.streamCreated(stream, "Service1/Method1", requestHeaders);
+    transportListener.streamCreated(stream, "Waiter/serve", requestHeaders);
     assertEquals(1, executor.runDueTasks());
     verify(callHandler).startCall(Matchers.<ServerCall<String, Integer>>anyObject(),
         Matchers.<Metadata>anyObject());
@@ -1109,9 +1171,15 @@
   }
 
   private void createServer(List<ServerTransportFilter> filters) {
+    createServer(filters, Collections.<ServerInterceptor>emptyList());
+  }
+
+  private void createServer(
+      List<ServerTransportFilter> filters, List<ServerInterceptor> interceptors) {
     assertNull(server);
     server = new ServerImpl(executorPool, timerPool, registry, fallbackRegistry,
-        transportServer, SERVER_CONTEXT, decompressorRegistry, compressorRegistry, filters);
+        transportServer, SERVER_CONTEXT, decompressorRegistry, compressorRegistry, filters,
+        interceptors);
   }
 
   private void verifyExecutorsAcquired() {