blob: 3ab50e4ece258ce0a9cd37ae08a57e8bd3dcbdb7 [file] [log] [blame]
/*
* Copyright 2014, Google Inc. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are
* met:
*
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above
* copyright notice, this list of conditions and the following disclaimer
* in the documentation and/or other materials provided with the
* distribution.
*
* * Neither the name of Google Inc. nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
* "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
* LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
* A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
* OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
* THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
package io.grpc;
import static com.google.common.collect.Iterables.getOnlyElement;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertSame;
import static org.junit.Assert.assertTrue;
import static org.mockito.Matchers.same;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.verifyZeroInteractions;
import io.grpc.MethodDescriptor.Marshaller;
import io.grpc.MethodDescriptor.MethodType;
import io.grpc.ServerCall.Listener;
import io.grpc.testing.NoopServerCall;
import java.io.ByteArrayInputStream;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.MockitoAnnotations;
/**
 * Unit tests for {@link ServerInterceptors}.
 *
 * <p>Each test starts from a single-method service ({@code basic/flow}) whose handler is a mock
 * that returns {@link #listener}. After every test, {@link #makeSureExpectedMocksUnused} checks
 * that neither the marshallers nor that listener were touched.
 */
@RunWith(JUnit4.class)
public class ServerInterceptorsTest {
  @Mock
  private Marshaller<String> requestMarshaller;

  @Mock
  private Marshaller<Integer> responseMarshaller;

  @Mock
  private ServerCallHandler<String, Integer> handler;

  @Mock
  private ServerCall.Listener<String> listener;

  private MethodDescriptor<String, Integer> flowMethod;

  private ServerCall<String, Integer> call = new NoopServerCall<String, Integer>();

  private ServerServiceDefinition serviceDefinition;

  private final Metadata headers = new Metadata();

  /** Set up for test. */
  @Before
  public void setUp() {
    MockitoAnnotations.initMocks(this);
    flowMethod = MethodDescriptor.<String, Integer>newBuilder()
        .setType(MethodType.UNKNOWN)
        .setFullMethodName("basic/flow")
        .setRequestMarshaller(requestMarshaller)
        .setResponseMarshaller(responseMarshaller)
        .build();

    Mockito.when(handler.startCall(
        Mockito.<ServerCall<String, Integer>>any(), Mockito.<Metadata>any()))
        .thenReturn(listener);

    serviceDefinition = ServerServiceDefinition.builder(new ServiceDescriptor("basic", flowMethod))
        .addMethod(flowMethod, handler).build();
  }

  /** Final checks for all tests. */
  @After
  public void makeSureExpectedMocksUnused() {
    verifyZeroInteractions(requestMarshaller);
    verifyZeroInteractions(responseMarshaller);
    verifyZeroInteractions(listener);
  }

  @Test(expected = NullPointerException.class)
  public void npeForNullServiceDefinition() {
    ServerServiceDefinition serviceDef = null;
    ServerInterceptors.intercept(serviceDef, Arrays.<ServerInterceptor>asList());
  }

  @Test(expected = NullPointerException.class)
  public void npeForNullInterceptorList() {
    ServerInterceptors.intercept(serviceDefinition, (List<ServerInterceptor>) null);
  }

  @Test(expected = NullPointerException.class)
  public void npeForNullInterceptor() {
    ServerInterceptors.intercept(serviceDefinition, Arrays.asList((ServerInterceptor) null));
  }

  /** Intercepting with an empty interceptor list returns the very same definition. */
  @Test
  public void noop() {
    assertSame(serviceDefinition,
        ServerInterceptors.intercept(serviceDefinition, Arrays.<ServerInterceptor>asList()));
  }

  /** Each call through the intercepted handler invokes the interceptor and handler once more. */
  @Test
  public void multipleInvocationsOfHandler() {
    ServerInterceptor interceptor = Mockito.spy(new NoopInterceptor());
    ServerServiceDefinition intercepted
        = ServerInterceptors.intercept(serviceDefinition, Arrays.asList(interceptor));
    assertSame(listener,
        getSoleMethod(intercepted).getServerCallHandler().startCall(call, headers));
    verify(interceptor).interceptCall(same(call), same(headers), anyCallHandler());
    verify(handler).startCall(call, headers);
    verifyNoMoreInteractions(interceptor, handler);

    assertSame(listener,
        getSoleMethod(intercepted).getServerCallHandler().startCall(call, headers));
    verify(interceptor, times(2))
        .interceptCall(same(call), same(headers), anyCallHandler());
    verify(handler, times(2)).startCall(call, headers);
    verifyNoMoreInteractions(interceptor, handler);
  }

  /** Each intercepted method still dispatches to its own original handler. */
  @Test
  public void correctHandlerCalled() {
    @SuppressWarnings("unchecked")
    ServerCallHandler<String, Integer> handler2 = mock(ServerCallHandler.class);
    MethodDescriptor<String, Integer> flowMethod2 =
        flowMethod.toBuilder().setFullMethodName("basic/flow2").build();
    serviceDefinition = ServerServiceDefinition.builder(
        new ServiceDescriptor("basic", flowMethod, flowMethod2))
        .addMethod(flowMethod, handler)
        .addMethod(flowMethod2, handler2).build();
    ServerServiceDefinition intercepted = ServerInterceptors.intercept(
        serviceDefinition, Arrays.<ServerInterceptor>asList(new NoopInterceptor()));

    getMethod(intercepted, "basic/flow").getServerCallHandler().startCall(call, headers);
    verify(handler).startCall(call, headers);
    verifyNoMoreInteractions(handler);
    verifyNoMoreInteractions(handler2);

    getMethod(intercepted, "basic/flow2").getServerCallHandler().startCall(call, headers);
    verify(handler2).startCall(call, headers);
    verifyNoMoreInteractions(handler);
    verifyNoMoreInteractions(handler2);
  }

  /** An interceptor may invoke {@code next} more than once; each invocation reaches the handler. */
  @Test
  public void callNextTwice() {
    ServerInterceptor interceptor = new ServerInterceptor() {
      @Override
      public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
          ServerCall<ReqT, RespT> call,
          Metadata headers,
          ServerCallHandler<ReqT, RespT> next) {
        // Calling next twice is permitted, although should only rarely be useful.
        assertSame(listener, next.startCall(call, headers));
        return next.startCall(call, headers);
      }
    };
    ServerServiceDefinition intercepted = ServerInterceptors.intercept(serviceDefinition,
        interceptor);
    assertSame(listener,
        getSoleMethod(intercepted).getServerCallHandler().startCall(call, headers));
    verify(handler, times(2)).startCall(same(call), same(headers));
    verifyNoMoreInteractions(handler);
  }

  /** With {@code intercept}, the last interceptor in the list runs first. */
  @Test
  public void ordered() {
    List<String> order = new ArrayList<String>();
    ServerServiceDefinition orderedService = ServerServiceDefinition.builder(
        new ServiceDescriptor("basic", flowMethod))
        .addMethod(flowMethod, recordingHandler(order)).build();
    ServerServiceDefinition intercepted = ServerInterceptors.intercept(
        orderedService,
        Arrays.asList(recordingInterceptor(order, "i1"), recordingInterceptor(order, "i2")));
    assertSame(listener,
        getSoleMethod(intercepted).getServerCallHandler().startCall(call, headers));
    assertEquals(Arrays.asList("i2", "i1", "handler"), order);
  }

  /** With {@code interceptForward}, interceptors run in the order they are given. */
  @Test
  public void orderedForward() {
    List<String> order = new ArrayList<String>();
    ServerServiceDefinition orderedService = ServerServiceDefinition.builder(
        new ServiceDescriptor("basic", flowMethod))
        .addMethod(flowMethod, recordingHandler(order)).build();
    ServerServiceDefinition intercepted = ServerInterceptors.interceptForward(
        orderedService, recordingInterceptor(order, "i1"), recordingInterceptor(order, "i2"));
    assertSame(listener,
        getSoleMethod(intercepted).getServerCallHandler().startCall(call, headers));
    assertEquals(Arrays.asList("i1", "i2", "handler"), order);
  }

  /**
   * The call an interceptor passes to {@code next} is the one the handler receives, and the
   * listener the interceptor returns is the one the caller gets back.
   */
  @Test
  public void argumentsPassed() {
    final ServerCall<String, Integer> call2 = new NoopServerCall<String, Integer>();
    @SuppressWarnings("unchecked")
    final ServerCall.Listener<String> listener2 = mock(ServerCall.Listener.class);

    ServerInterceptor interceptor = new ServerInterceptor() {
      @SuppressWarnings("unchecked") // Lots of casting for no benefit. Not intended use.
      @Override
      public <R1, R2> ServerCall.Listener<R1> interceptCall(
          ServerCall<R1, R2> call,
          Metadata headers,
          ServerCallHandler<R1, R2> next) {
        assertSame(call, ServerInterceptorsTest.this.call);
        assertSame(listener,
            next.startCall((ServerCall<R1, R2>) call2, headers));
        return (ServerCall.Listener<R1>) listener2;
      }
    };
    ServerServiceDefinition intercepted = ServerInterceptors.intercept(
        serviceDefinition, Arrays.asList(interceptor));
    assertSame(listener2,
        getSoleMethod(intercepted).getServerCallHandler().startCall(call, headers));
    verify(handler).startCall(call2, headers);
  }

  /**
   * After {@code useInputStreamMessages}, an interceptor added on the outside sees raw
   * {@link InputStream} messages while the one added before the conversion still sees the
   * service's own {@code Holder} type.
   */
  @Test
  public void typedMarshalledMessages() {
    final List<String> order = new ArrayList<String>();
    Marshaller<Holder> marshaller = new Marshaller<Holder>() {
      @Override
      public InputStream stream(Holder value) {
        return value.get();
      }

      @Override
      public Holder parse(InputStream stream) {
        return new Holder(stream);
      }
    };

    ServerCallHandler<Holder, Holder> handler2 = new ServerCallHandler<Holder, Holder>() {
      @Override
      public Listener<Holder> startCall(final ServerCall<Holder, Holder> call,
                                        final Metadata headers) {
        return new Listener<Holder>() {
          @Override
          public void onMessage(Holder message) {
            order.add("handler");
            call.sendMessage(message);
          }
        };
      }
    };

    MethodDescriptor<Holder, Holder> wrappedMethod = MethodDescriptor.<Holder, Holder>newBuilder()
        .setType(MethodType.UNKNOWN)
        .setFullMethodName("basic/wrapped")
        .setRequestMarshaller(marshaller)
        .setResponseMarshaller(marshaller)
        .build();
    ServerServiceDefinition serviceDef = ServerServiceDefinition.builder(
        new ServiceDescriptor("basic", wrappedMethod))
        .addMethod(wrappedMethod, handler2).build();

    ServerInterceptor interceptor1 = new ServerInterceptor() {
      @Override
      public <ReqT, RespT> Listener<ReqT> interceptCall(ServerCall<ReqT, RespT> call,
                                                        Metadata headers,
                                                        ServerCallHandler<ReqT, RespT> next) {
        ServerCall<ReqT, RespT> interceptedCall = new ForwardingServerCall
            .SimpleForwardingServerCall<ReqT, RespT>(call) {
          @Override
          public void sendMessage(RespT message) {
            order.add("i1sendMessage");
            assertTrue(message instanceof Holder);
            super.sendMessage(message);
          }
        };

        ServerCall.Listener<ReqT> originalListener = next
            .startCall(interceptedCall, headers);
        return new ForwardingServerCallListener
            .SimpleForwardingServerCallListener<ReqT>(originalListener) {
          @Override
          public void onMessage(ReqT message) {
            order.add("i1onMessage");
            assertTrue(message instanceof Holder);
            super.onMessage(message);
          }
        };
      }
    };

    ServerInterceptor interceptor2 = new ServerInterceptor() {
      @Override
      public <ReqT, RespT> Listener<ReqT> interceptCall(ServerCall<ReqT, RespT> call,
                                                        Metadata headers,
                                                        ServerCallHandler<ReqT, RespT> next) {
        ServerCall<ReqT, RespT> interceptedCall = new ForwardingServerCall
            .SimpleForwardingServerCall<ReqT, RespT>(call) {
          @Override
          public void sendMessage(RespT message) {
            order.add("i2sendMessage");
            assertTrue(message instanceof InputStream);
            super.sendMessage(message);
          }
        };

        ServerCall.Listener<ReqT> originalListener = next
            .startCall(interceptedCall, headers);
        return new ForwardingServerCallListener
            .SimpleForwardingServerCallListener<ReqT>(originalListener) {
          @Override
          public void onMessage(ReqT message) {
            order.add("i2onMessage");
            assertTrue(message instanceof InputStream);
            super.onMessage(message);
          }
        };
      }
    };

    ServerServiceDefinition intercepted = ServerInterceptors.intercept(serviceDef, interceptor1);
    ServerServiceDefinition inputStreamMessageService = ServerInterceptors
        .useInputStreamMessages(intercepted);
    ServerServiceDefinition intercepted2 = ServerInterceptors
        .intercept(inputStreamMessageService, interceptor2);
    // getMethod() is not typed by message type; the cast is the only unchecked operation here.
    @SuppressWarnings("unchecked")
    ServerMethodDefinition<InputStream, InputStream> serverMethod =
        (ServerMethodDefinition<InputStream, InputStream>) intercepted2.getMethod("basic/wrapped");
    ServerCall<InputStream, InputStream> call2 = new NoopServerCall<InputStream, InputStream>();
    byte[] bytes = {};
    serverMethod
        .getServerCallHandler()
        .startCall(call2, headers)
        .onMessage(new ByteArrayInputStream(bytes));
    assertEquals(
        Arrays.asList("i2onMessage", "i1onMessage", "handler", "i1sendMessage", "i2sendMessage"),
        order);
  }

  /**
   * Returns the only method of {@code serviceDef}, cast to this test's message types.
   *
   * @throws AssertionError if the definition does not contain exactly one method
   */
  @SuppressWarnings("unchecked")
  private static ServerMethodDefinition<String, Integer> getSoleMethod(
      ServerServiceDefinition serviceDef) {
    if (serviceDef.getMethods().size() != 1) {
      throw new AssertionError("Not exactly one method present");
    }
    return (ServerMethodDefinition<String, Integer>) getOnlyElement(serviceDef.getMethods());
  }

  /** Looks up the method called {@code name}, cast to this test's message types. */
  @SuppressWarnings("unchecked")
  private static ServerMethodDefinition<String, Integer> getMethod(
      ServerServiceDefinition serviceDef, String name) {
    return (ServerMethodDefinition<String, Integer>) serviceDef.getMethod(name);
  }

  /** Typed Mockito matcher for any {@code ServerCallHandler<String, Integer>}. */
  private static ServerCallHandler<String, Integer> anyCallHandler() {
    return Mockito.<ServerCallHandler<String, Integer>>any();
  }

  /**
   * Returns a handler that appends {@code "handler"} to {@code order} and returns the
   * {@link #listener} mock.
   */
  private ServerCallHandler<String, Integer> recordingHandler(final List<String> order) {
    return new ServerCallHandler<String, Integer>() {
      @Override
      public ServerCall.Listener<String> startCall(
          ServerCall<String, Integer> call,
          Metadata headers) {
        order.add("handler");
        return listener;
      }
    };
  }

  /** Returns an interceptor that appends {@code name} to {@code order} and then delegates. */
  private static ServerInterceptor recordingInterceptor(
      final List<String> order, final String name) {
    return new ServerInterceptor() {
      @Override
      public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
          ServerCall<ReqT, RespT> call,
          Metadata headers,
          ServerCallHandler<ReqT, RespT> next) {
        order.add(name);
        return next.startCall(call, headers);
      }
    };
  }

  /** Interceptor that forwards the call to {@code next} unchanged. */
  private static class NoopInterceptor implements ServerInterceptor {
    @Override
    public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
        ServerCall<ReqT, RespT> call,
        Metadata headers,
        ServerCallHandler<ReqT, RespT> next) {
      return next.startCall(call, headers);
    }
  }

  /** Message type that simply wraps the {@link InputStream} it was parsed from. */
  private static class Holder {
    private final InputStream inputStream;

    Holder(InputStream inputStream) {
      this.inputStream = inputStream;
    }

    public InputStream get() {
      return inputStream;
    }
  }
}