blob: cadb10a67b163e4a205568c9d4eebb2b12a6b071 [file] [log] [blame]
/*
* Copyright 2014, Google Inc. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are
* met:
*
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above
* copyright notice, this list of conditions and the following disclaimer
* in the documentation and/or other materials provided with the
* distribution.
*
* * Neither the name of Google Inc. nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
* "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
* LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
* A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
* OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
* THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
package com.google.net.stubby.testing.integration;
import static com.google.net.stubby.testing.integration.Messages.PayloadType.COMPRESSABLE;
import static com.google.net.stubby.testing.integration.Util.assertEquals;
import static org.junit.Assert.assertEquals;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.timeout;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import com.google.common.util.concurrent.SettableFuture;
import com.google.common.util.concurrent.Uninterruptibles;
import com.google.net.stubby.AbstractServerBuilder;
import com.google.net.stubby.Call;
import com.google.net.stubby.ChannelImpl;
import com.google.net.stubby.Metadata;
import com.google.net.stubby.ServerImpl;
import com.google.net.stubby.ServerInterceptors;
import com.google.net.stubby.Status;
import com.google.net.stubby.proto.ProtoUtils;
import com.google.net.stubby.stub.MetadataUtils;
import com.google.net.stubby.stub.StreamObserver;
import com.google.net.stubby.stub.StreamRecorder;
import com.google.net.stubby.testing.TestUtils;
import com.google.net.stubby.testing.integration.Messages.Payload;
import com.google.net.stubby.testing.integration.Messages.PayloadType;
import com.google.net.stubby.testing.integration.Messages.ResponseParameters;
import com.google.net.stubby.testing.integration.Messages.SimpleRequest;
import com.google.net.stubby.testing.integration.Messages.SimpleResponse;
import com.google.net.stubby.testing.integration.Messages.StreamingInputCallRequest;
import com.google.net.stubby.testing.integration.Messages.StreamingInputCallResponse;
import com.google.net.stubby.testing.integration.Messages.StreamingOutputCallRequest;
import com.google.net.stubby.testing.integration.Messages.StreamingOutputCallResponse;
import com.google.protobuf.ByteString;
import com.google.protobuf.EmptyProtos.Empty;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
/**
* Abstract base class for all GRPC transport tests.
*/
public abstract class AbstractTransportTest {
public static final Metadata.Key<Messages.SimpleContext> METADATA_KEY =
ProtoUtils.keyForProto(Messages.SimpleContext.getDefaultInstance());
private static ScheduledExecutorService testServiceExecutor;
private static ServerImpl server;
protected static void startStaticServer(AbstractServerBuilder<?> builder) {
testServiceExecutor = Executors.newScheduledThreadPool(2);
server = builder
.addService(ServerInterceptors.intercept(
TestServiceGrpc.bindService(new TestServiceImpl(testServiceExecutor)),
TestUtils.echoRequestHeadersInterceptor(Util.METADATA_KEY)))
.buildAndWaitForRunning();
}
protected static void stopStaticServer() {
server.stopAsync();
testServiceExecutor.shutdown();
}
protected ChannelImpl channel;
protected TestServiceGrpc.TestServiceBlockingStub blockingStub;
protected TestServiceGrpc.TestService asyncStub;
/**
* Must be called by the subclass setup method.
*/
@Before
public void setup() {
channel = createChannel();
blockingStub = TestServiceGrpc.newBlockingStub(channel);
asyncStub = TestServiceGrpc.newStub(channel);
}
@After
public void teardown() throws Exception {
if (channel != null) {
channel.shutdown();
}
}
protected abstract ChannelImpl createChannel();
@Test
public void emptyUnary() throws Exception {
assertEquals(Empty.getDefaultInstance(), blockingStub.emptyCall(Empty.getDefaultInstance()));
}
@Test
public void largeUnary() throws Exception {
final SimpleRequest request = SimpleRequest.newBuilder()
.setResponseSize(314159)
.setResponseType(PayloadType.COMPRESSABLE)
.setPayload(Payload.newBuilder()
.setBody(ByteString.copyFrom(new byte[271828])))
.build();
final SimpleResponse goldenResponse = SimpleResponse.newBuilder()
.setPayload(Payload.newBuilder()
.setType(PayloadType.COMPRESSABLE)
.setBody(ByteString.copyFrom(new byte[314159])))
.build();
assertEquals(goldenResponse, blockingStub.unaryCall(request));
}
@Test
public void serverStreaming() throws Exception {
final StreamingOutputCallRequest request = StreamingOutputCallRequest.newBuilder()
.setResponseType(PayloadType.COMPRESSABLE)
.addResponseParameters(ResponseParameters.newBuilder()
.setSize(31415))
.addResponseParameters(ResponseParameters.newBuilder()
.setSize(9))
.addResponseParameters(ResponseParameters.newBuilder()
.setSize(2653))
.addResponseParameters(ResponseParameters.newBuilder()
.setSize(58979))
.build();
final List<StreamingOutputCallResponse> goldenResponses = Arrays.asList(
StreamingOutputCallResponse.newBuilder()
.setPayload(Payload.newBuilder()
.setType(PayloadType.COMPRESSABLE)
.setBody(ByteString.copyFrom(new byte[31415])))
.build(),
StreamingOutputCallResponse.newBuilder()
.setPayload(Payload.newBuilder()
.setType(PayloadType.COMPRESSABLE)
.setBody(ByteString.copyFrom(new byte[9])))
.build(),
StreamingOutputCallResponse.newBuilder()
.setPayload(Payload.newBuilder()
.setType(PayloadType.COMPRESSABLE)
.setBody(ByteString.copyFrom(new byte[2653])))
.build(),
StreamingOutputCallResponse.newBuilder()
.setPayload(Payload.newBuilder()
.setType(PayloadType.COMPRESSABLE)
.setBody(ByteString.copyFrom(new byte[58979])))
.build());
StreamRecorder<StreamingOutputCallResponse> recorder = StreamRecorder.create();
asyncStub.streamingOutputCall(request, recorder);
recorder.awaitCompletion();
assertSuccess(recorder);
assertEquals(goldenResponses, recorder.getValues());
}
@Test
public void clientStreaming() throws Exception {
final List<StreamingInputCallRequest> requests = Arrays.asList(
StreamingInputCallRequest.newBuilder()
.setPayload(Payload.newBuilder()
.setBody(ByteString.copyFrom(new byte[27182])))
.build(),
StreamingInputCallRequest.newBuilder()
.setPayload(Payload.newBuilder()
.setBody(ByteString.copyFrom(new byte[8])))
.build(),
StreamingInputCallRequest.newBuilder()
.setPayload(Payload.newBuilder()
.setBody(ByteString.copyFrom(new byte[1828])))
.build(),
StreamingInputCallRequest.newBuilder()
.setPayload(Payload.newBuilder()
.setBody(ByteString.copyFrom(new byte[45904])))
.build());
final StreamingInputCallResponse goldenResponse = StreamingInputCallResponse.newBuilder()
.setAggregatedPayloadSize(74922)
.build();
StreamRecorder<StreamingInputCallResponse> responseObserver = StreamRecorder.create();
StreamObserver<StreamingInputCallRequest> requestObserver =
asyncStub.streamingInputCall(responseObserver);
for (StreamingInputCallRequest request : requests) {
requestObserver.onValue(request);
}
requestObserver.onCompleted();
assertEquals(goldenResponse, responseObserver.firstValue().get());
}
@Test(timeout=5000)
public void pingPong() throws Exception {
final List<StreamingOutputCallRequest> requests = Arrays.asList(
StreamingOutputCallRequest.newBuilder()
.addResponseParameters(ResponseParameters.newBuilder()
.setSize(31415))
.setPayload(Payload.newBuilder()
.setBody(ByteString.copyFrom(new byte[27182])))
.build(),
StreamingOutputCallRequest.newBuilder()
.addResponseParameters(ResponseParameters.newBuilder()
.setSize(9))
.setPayload(Payload.newBuilder()
.setBody(ByteString.copyFrom(new byte[8])))
.build(),
StreamingOutputCallRequest.newBuilder()
.addResponseParameters(ResponseParameters.newBuilder()
.setSize(2653))
.setPayload(Payload.newBuilder()
.setBody(ByteString.copyFrom(new byte[1828])))
.build(),
StreamingOutputCallRequest.newBuilder()
.addResponseParameters(ResponseParameters.newBuilder()
.setSize(58979))
.setPayload(Payload.newBuilder()
.setBody(ByteString.copyFrom(new byte[45904])))
.build());
final List<StreamingOutputCallResponse> goldenResponses = Arrays.asList(
StreamingOutputCallResponse.newBuilder()
.setPayload(Payload.newBuilder()
.setType(PayloadType.COMPRESSABLE)
.setBody(ByteString.copyFrom(new byte[31415])))
.build(),
StreamingOutputCallResponse.newBuilder()
.setPayload(Payload.newBuilder()
.setType(PayloadType.COMPRESSABLE)
.setBody(ByteString.copyFrom(new byte[9])))
.build(),
StreamingOutputCallResponse.newBuilder()
.setPayload(Payload.newBuilder()
.setType(PayloadType.COMPRESSABLE)
.setBody(ByteString.copyFrom(new byte[2653])))
.build(),
StreamingOutputCallResponse.newBuilder()
.setPayload(Payload.newBuilder()
.setType(PayloadType.COMPRESSABLE)
.setBody(ByteString.copyFrom(new byte[58979])))
.build());
@SuppressWarnings("unchecked")
StreamObserver<StreamingOutputCallResponse> responseObserver = mock(StreamObserver.class);
StreamObserver<StreamingOutputCallRequest> requestObserver
= asyncStub.fullDuplexCall(responseObserver);
for (int i = 0; i < requests.size(); i++) {
requestObserver.onValue(requests.get(i));
verify(responseObserver, timeout(1000)).onValue(goldenResponses.get(i));
verifyNoMoreInteractions(responseObserver);
}
requestObserver.onCompleted();
verify(responseObserver, timeout(1000)).onCompleted();
verifyNoMoreInteractions(responseObserver);
}
@Test
public void emptyStream() throws Exception {
@SuppressWarnings("unchecked")
StreamObserver<StreamingOutputCallResponse> responseObserver = mock(StreamObserver.class);
StreamObserver<StreamingOutputCallRequest> requestObserver
= asyncStub.fullDuplexCall(responseObserver);
requestObserver.onCompleted();
verify(responseObserver, timeout(1000)).onCompleted();
verifyNoMoreInteractions(responseObserver);
}
@Test
public void cancelAfterBegin() throws Exception {
StreamRecorder<StreamingInputCallResponse> responseObserver = StreamRecorder.create();
StreamObserver<StreamingInputCallRequest> requestObserver =
asyncStub.streamingInputCall(responseObserver);
requestObserver.onError(new RuntimeException());
responseObserver.awaitCompletion();
assertEquals(Arrays.<StreamingInputCallResponse>asList(), responseObserver.getValues());
assertEquals(Status.CANCELLED, Status.fromThrowable(responseObserver.getError()));
}
@Test
public void cancelAfterFirstResponse() throws Exception {
final StreamingOutputCallRequest request = StreamingOutputCallRequest.newBuilder()
.addResponseParameters(ResponseParameters.newBuilder()
.setSize(31415))
.setPayload(Payload.newBuilder()
.setBody(ByteString.copyFrom(new byte[27182])))
.build();
final StreamingOutputCallResponse goldenResponse = StreamingOutputCallResponse.newBuilder()
.setPayload(Payload.newBuilder()
.setType(PayloadType.COMPRESSABLE)
.setBody(ByteString.copyFrom(new byte[31415])))
.build();
@SuppressWarnings("unchecked")
StreamObserver<StreamingOutputCallResponse> responseObserver = mock(StreamObserver.class);
StreamObserver<StreamingOutputCallRequest> requestObserver
= asyncStub.fullDuplexCall(responseObserver);
requestObserver.onValue(request);
verify(responseObserver, timeout(1000)).onValue(goldenResponse);
verifyNoMoreInteractions(responseObserver);
requestObserver.onError(new RuntimeException());
ArgumentCaptor<Throwable> captor = ArgumentCaptor.forClass(Throwable.class);
verify(responseObserver, timeout(1000)).onError(captor.capture());
assertEquals(Status.CANCELLED, Status.fromThrowable(captor.getValue()));
verifyNoMoreInteractions(responseObserver);
}
@Test
public void fullDuplexCallShouldSucceed() throws Exception {
// Build the request.
List<Integer> responseSizes = Arrays.asList(50, 100, 150, 200);
StreamingOutputCallRequest.Builder streamingOutputBuilder =
StreamingOutputCallRequest.newBuilder();
streamingOutputBuilder.setResponseType(COMPRESSABLE);
for (Integer size : responseSizes) {
streamingOutputBuilder.addResponseParametersBuilder().setSize(size).setIntervalUs(0);
}
final StreamingOutputCallRequest request = streamingOutputBuilder.build();
StreamRecorder<StreamingOutputCallResponse> recorder = StreamRecorder.create();
StreamObserver<StreamingOutputCallRequest> requestStream =
asyncStub.fullDuplexCall(recorder);
final int numRequests = 10;
for (int ix = numRequests; ix > 0; --ix) {
requestStream.onValue(request);
}
requestStream.onCompleted();
recorder.awaitCompletion();
assertSuccess(recorder);
assertEquals(responseSizes.size() * numRequests, recorder.getValues().size());
for (int ix = 0; ix < recorder.getValues().size(); ++ix) {
StreamingOutputCallResponse response = recorder.getValues().get(ix);
assertEquals(COMPRESSABLE, response.getPayload().getType());
int length = response.getPayload().getBody().size();
int expectedSize = responseSizes.get(ix % responseSizes.size());
assertEquals("comparison failed at index " + ix, expectedSize, length);
}
}
@Test
public void halfDuplexCallShouldSucceed() throws Exception {
// Build the request.
List<Integer> responseSizes = Arrays.asList(50, 100, 150, 200);
StreamingOutputCallRequest.Builder streamingOutputBuilder =
StreamingOutputCallRequest.newBuilder();
streamingOutputBuilder.setResponseType(COMPRESSABLE);
for (Integer size : responseSizes) {
streamingOutputBuilder.addResponseParametersBuilder().setSize(size).setIntervalUs(0);
}
final StreamingOutputCallRequest request = streamingOutputBuilder.build();
StreamRecorder<StreamingOutputCallResponse> recorder = StreamRecorder.create();
StreamObserver<StreamingOutputCallRequest> requestStream = asyncStub.halfDuplexCall(recorder);
final int numRequests = 10;
for (int ix = numRequests; ix > 0; --ix) {
requestStream.onValue(request);
}
requestStream.onCompleted();
recorder.awaitCompletion();
assertSuccess(recorder);
assertEquals(responseSizes.size() * numRequests, recorder.getValues().size());
for (int ix = 0; ix < recorder.getValues().size(); ++ix) {
StreamingOutputCallResponse response = recorder.getValues().get(ix);
assertEquals(COMPRESSABLE, response.getPayload().getType());
int length = response.getPayload().getBody().size();
int expectedSize = responseSizes.get(ix % responseSizes.size());
assertEquals("comparison failed at index " + ix, expectedSize, length);
}
}
@Test
public void streamingOutputShouldBeFlowControlled() throws Exception {
// Create the call object.
Call<StreamingOutputCallRequest, StreamingOutputCallResponse> call =
channel.newCall(TestServiceGrpc.CONFIG.streamingOutputCall);
// Build the request.
List<Integer> responseSizes = Arrays.asList(50, 100, 150, 200);
StreamingOutputCallRequest.Builder streamingOutputBuilder =
StreamingOutputCallRequest.newBuilder();
streamingOutputBuilder.setResponseType(COMPRESSABLE);
for (Integer size : responseSizes) {
streamingOutputBuilder.addResponseParametersBuilder().setSize(size).setIntervalUs(0);
}
StreamingOutputCallRequest request = streamingOutputBuilder.build();
// Start the call and prepare capture of results.
final List<StreamingOutputCallResponse> results =
Collections.synchronizedList(new ArrayList<StreamingOutputCallResponse>());
final SettableFuture<Void> completionFuture = SettableFuture.create();
final AtomicInteger count = new AtomicInteger();
call.start(new Call.Listener<StreamingOutputCallResponse>() {
@Override
public void onHeaders(Metadata.Headers headers) {
}
@Override
public void onPayload(final StreamingOutputCallResponse payload) {
results.add(payload);
count.incrementAndGet();
}
@Override
public void onClose(Status status, Metadata.Trailers trailers) {
if (status.isOk()) {
completionFuture.set(null);
} else {
completionFuture.setException(status.asException());
}
}
}, new Metadata.Headers());
// Send the request.
call.sendPayload(request);
call.halfClose();
// Slowly set completion on all of the futures.
int expectedResults = responseSizes.size();
while (count.get() < expectedResults) {
// Allow one more inbound message to be delivered to the application.
call.request(1);
// Sleep a bit.
Uninterruptibles.sleepUninterruptibly(1, TimeUnit.SECONDS);
}
// Wait for successful completion of the response.
completionFuture.get();
assertEquals(responseSizes.size(), results.size());
for (int ix = 0; ix < results.size(); ++ix) {
StreamingOutputCallResponse response = results.get(ix);
assertEquals(COMPRESSABLE, response.getPayload().getType());
int length = response.getPayload().getBody().size();
assertEquals("comparison failed at index " + ix, responseSizes.get(ix).intValue(), length);
}
}
@org.junit.Test
public void exchangeContextUnaryCall() throws Exception {
TestServiceGrpc.TestServiceBlockingStub stub =
TestServiceGrpc.newBlockingStub(channel);
// Capture the context exchange
Metadata.Headers fixedHeaders = new Metadata.Headers();
// Send a context proto (as it's in the default extension registry)
Messages.SimpleContext contextValue =
Messages.SimpleContext.newBuilder().setValue("dog").build();
fixedHeaders.put(METADATA_KEY, contextValue);
stub = MetadataUtils.attachHeaders(stub, fixedHeaders);
// .. and expect it to be echoed back in trailers
AtomicReference<Metadata.Trailers> trailersCapture = new AtomicReference<Metadata.Trailers>();
AtomicReference<Metadata.Headers> headersCapture = new AtomicReference<Metadata.Headers>();
stub = MetadataUtils.captureMetadata(stub, headersCapture, trailersCapture);
Assert.assertNotNull(stub.unaryCall(unaryRequest()));
// Assert that our side channel object is echoed back in both headers and trailers
Assert.assertEquals(contextValue, headersCapture.get().get(METADATA_KEY));
Assert.assertEquals(contextValue, trailersCapture.get().get(METADATA_KEY));
}
@org.junit.Test
public void exchangeContextStreamingCall() throws Exception {
TestServiceGrpc.TestServiceStub stub = TestServiceGrpc.newStub(channel);
// Capture the context exchange
Metadata.Headers fixedHeaders = new Metadata.Headers();
// Send a context proto (as it's in the default extension registry)
Messages.SimpleContext contextValue =
Messages.SimpleContext.newBuilder().setValue("dog").build();
fixedHeaders.put(METADATA_KEY, contextValue);
stub = MetadataUtils.attachHeaders(stub, fixedHeaders);
// .. and expect it to be echoed back in trailers
AtomicReference<Metadata.Trailers> trailersCapture = new AtomicReference<Metadata.Trailers>();
AtomicReference<Metadata.Headers> headersCapture = new AtomicReference<Metadata.Headers>();
stub = MetadataUtils.captureMetadata(stub, headersCapture, trailersCapture);
List<Integer> responseSizes = Arrays.asList(50, 100, 150, 200);
Messages.StreamingOutputCallRequest.Builder streamingOutputBuilder =
Messages.StreamingOutputCallRequest.newBuilder();
streamingOutputBuilder.setResponseType(COMPRESSABLE);
for (Integer size : responseSizes) {
streamingOutputBuilder.addResponseParametersBuilder().setSize(size).setIntervalUs(0);
}
final Messages.StreamingOutputCallRequest request = streamingOutputBuilder.build();
StreamRecorder<Messages.StreamingOutputCallResponse> recorder = StreamRecorder.create();
StreamObserver<Messages.StreamingOutputCallRequest> requestStream =
stub.fullDuplexCall(recorder);
final int numRequests = 10;
for (int ix = numRequests; ix > 0; --ix) {
requestStream.onValue(request);
}
requestStream.onCompleted();
recorder.awaitCompletion();
assertSuccess(recorder);
org.junit.Assert.assertEquals(responseSizes.size() * numRequests, recorder.getValues().size());
// Assert that our side channel object is echoed back in both headers and trailers
Assert.assertEquals(contextValue, headersCapture.get().get(METADATA_KEY));
Assert.assertEquals(contextValue, trailersCapture.get().get(METADATA_KEY));
}
protected int unaryPayloadLength() {
// 10MiB.
return 10485760;
}
protected SimpleRequest unaryRequest() {
SimpleRequest.Builder unaryBuilder = SimpleRequest.newBuilder();
unaryBuilder.getPayloadBuilder().setType(PayloadType.COMPRESSABLE);
byte[] data = new byte[unaryPayloadLength()];
new Random().nextBytes(data);
unaryBuilder.getPayloadBuilder().setBody(ByteString.copyFrom(data));
unaryBuilder.setResponseSize(10).setResponseType(PayloadType.COMPRESSABLE);
return unaryBuilder.build();
}
protected static void assertSuccess(StreamRecorder<?> recorder) {
if (recorder.getError() != null) {
throw new AssertionError(recorder.getError());
}
}
}