blob: 9cb6cf4b889cb51631724ba0010fbbb190ac2253 [file] [log] [blame]
/*
* Copyright 2014, Google Inc. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are
* met:
*
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above
* copyright notice, this list of conditions and the following disclaimer
* in the documentation and/or other materials provided with the
* distribution.
*
* * Neither the name of Google Inc. nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
* "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
* LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
* A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
* OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
* THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
package com.google.net.stubby;
import static com.google.common.base.Charsets.UTF_8;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.mockito.Matchers.eq;
import static org.mockito.Matchers.isNull;
import static org.mockito.Matchers.notNull;
import static org.mockito.Matchers.same;
import static org.mockito.Mockito.timeout;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import com.google.common.io.ByteStreams;
import com.google.common.util.concurrent.AbstractService;
import com.google.common.util.concurrent.Service;
import com.google.net.stubby.transport.ServerStream;
import com.google.net.stubby.transport.ServerStreamListener;
import com.google.net.stubby.transport.ServerTransportListener;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.MockitoAnnotations;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.concurrent.BrokenBarrierException;
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicReference;
/**
 * Unit tests for {@link ServerImpl}.
 *
 * <p>Covers how the server's lifecycle is tied to its transport server and to individual
 * transports, plus basic call dispatch through a single-threaded executor.
 */
@RunWith(JUnit4.class)
public class ServerImplTest {
  private static final IntegerMarshaller INTEGER_MARSHALLER = new IntegerMarshaller();
  private static final StringMarshaller STRING_MARSHALLER = new StringMarshaller();

  // Single-threaded so that executeBarrier() can plug or drain the executor deterministically.
  private final ExecutorService executor = Executors.newSingleThreadExecutor();
  private final MutableHandlerRegistry registry = new MutableHandlerRegistryImpl();
  private final Service transportServer = new NoopService();
  private final ServerImpl server = new ServerImpl(executor, registry)
      .setTransportServer(transportServer);

  @Mock
  private ServerStream stream;
  @Mock
  private ServerCall.Listener<String> callListener;

  /** Initializes the {@code @Mock} fields and brings the shared {@link #server} to RUNNING. */
  @Before
  public void startup() {
    MockitoAnnotations.initMocks(this);
    server.startAsync();
    server.awaitRunning();
  }

  /**
   * Shuts down the shared executor. {@code shutdownNow()} interrupts the worker thread, which
   * releases any task still blocked in {@link #executeBarrier}.
   */
  @After
  public void teardown() {
    executor.shutdownNow();
  }

  /** Starting and stopping the server starts and stops its transport server in lockstep. */
  @Test
  public void startStopImmediate() {
    Service transportServer = new NoopService();
    Server server = new ServerImpl(executor, registry).setTransportServer(transportServer);
    assertEquals(Service.State.NEW, server.state());
    assertEquals(Service.State.NEW, transportServer.state());
    server.startAsync();
    server.awaitRunning();
    assertEquals(Service.State.RUNNING, server.state());
    assertEquals(Service.State.RUNNING, transportServer.state());
    server.stopAsync();
    server.awaitTerminated();
    assertEquals(Service.State.TERMINATED, server.state());
    assertEquals(Service.State.TERMINATED, transportServer.state());
  }

  /** A transport-server failure after startup moves the server to FAILED with the same cause. */
  @Test
  public void transportServerFailureFailsServer() {
    class FailableService extends NoopService {
      public void doNotifyFailed(Throwable cause) {
        notifyFailed(cause);
      }
    }
    FailableService transportServer = new FailableService();
    Server server = new ServerImpl(executor, registry).setTransportServer(transportServer);
    server.startAsync();
    server.awaitRunning();
    RuntimeException ex = new RuntimeException("force failure");
    transportServer.doNotifyFailed(ex);
    assertEquals(Service.State.FAILED, server.state());
    assertEquals(ex, server.failureCause());
  }

  /**
   * If the transport server fails while starting, the server ends up FAILED. The state is
   * checked without awaiting, so the failure must already be visible when startAsync() returns.
   */
  @Test
  public void transportServerFailsStartup() {
    class FailingStartupService extends NoopService {
      @Override
      public void doStart() {
        notifyFailed(new RuntimeException());
      }
    }
    FailingStartupService transportServer = new FailingStartupService();
    Server server = new ServerImpl(executor, registry).setTransportServer(transportServer);
    server.startAsync();
    assertEquals(Service.State.FAILED, server.state());
  }

  /**
   * The transport server terminates immediately while a live transport is still stopping. The
   * server must stay STOPPING until that transport finishes.
   */
  @Test
  public void transportServerFirstToShutdown() {
    NoopService transportServer = new NoopService();
    ServerImpl server = new ServerImpl(executor, registry).setTransportServer(transportServer);
    server.startAsync();
    server.awaitRunning();
    ManualStoppedService transport = new ManualStoppedService();
    transport.startAsync();
    server.serverListener().transportCreated(transport);
    server.stopAsync();
    assertEquals(Service.State.STOPPING, transport.state());
    assertEquals(Service.State.TERMINATED, transportServer.state());
    assertEquals(Service.State.STOPPING, server.state());
    transport.doNotifyStopped();
    assertEquals(Service.State.TERMINATED, transport.state());
    assertEquals(Service.State.TERMINATED, server.state());
  }

  /**
   * A live transport terminates immediately while the transport server is still stopping. The
   * server must stay STOPPING until the transport server finishes.
   */
  @Test
  public void transportServerLastToShutdown() {
    ManualStoppedService transportServer = new ManualStoppedService();
    ServerImpl server = new ServerImpl(executor, registry).setTransportServer(transportServer);
    server.startAsync();
    server.awaitRunning();
    Service transport = new NoopService();
    transport.startAsync();
    server.serverListener().transportCreated(transport);
    server.stopAsync();
    assertEquals(Service.State.TERMINATED, transport.state());
    assertEquals(Service.State.STOPPING, transportServer.state());
    assertEquals(Service.State.STOPPING, server.state());
    transportServer.doNotifyStopped();
    assertEquals(Service.State.TERMINATED, transportServer.state());
    assertEquals(Service.State.TERMINATED, server.state());
  }

  /**
   * Drives one full call through a registered handler.
   *
   * <p>Covers the following steps:
   * <ul>
   *   <li>header delivery,</li>
   *   <li>an inbound message,</li>
   *   <li>two outbound payloads,</li>
   *   <li>half-close,</li>
   *   <li>server-side close with trailers,</li>
   *   <li>transport-side close.</li>
   * </ul>
   * The test also checks that the stream and listener see exactly the expected interactions.
   */
  @Test
  public void basicExchangeSuccessful() throws Exception {
    final Metadata.Key<Integer> metadataKey
        = Metadata.Key.of("inception", Metadata.INTEGER_MARSHALLER);
    final AtomicReference<ServerCall<Integer>> callReference
        = new AtomicReference<ServerCall<Integer>>();
    registry.addService(ServerServiceDefinition.builder("Waiter")
        .addMethod("serve", STRING_MARSHALLER, INTEGER_MARSHALLER,
          new ServerCallHandler<String, Integer>() {
            @Override
            public ServerCall.Listener<String> startCall(String fullMethodName,
                ServerCall<Integer> call, Metadata.Headers headers) {
              assertEquals("/Waiter/serve", fullMethodName);
              assertNotNull(call);
              assertNotNull(headers);
              assertEquals(0, headers.get(metadataKey).intValue());
              callReference.set(call);
              return callListener;
            }
          }).build());
    ServerTransportListener transportListener = newTransport(server);

    Metadata.Headers headers = new Metadata.Headers();
    headers.put(metadataKey, 0);
    ServerStreamListener streamListener
        = transportListener.streamCreated(stream, "/Waiter/serve", headers);
    assertNotNull(streamListener);

    // Drain the executor so startCall() has run and recorded the call.
    executeBarrier(executor).await();
    ServerCall<Integer> call = callReference.get();
    assertNotNull(call);

    String order = "Lots of pizza, please";
    streamListener.messageRead(STRING_MARSHALLER.stream(order), 1);
    verify(callListener, timeout(2000)).onPayload(order);

    call.sendPayload(314);
    ArgumentCaptor<InputStream> inputCaptor = ArgumentCaptor.forClass(InputStream.class);
    // 3 matches the byte length of "314" as serialized by IntegerMarshaller; presumably the
    // int argument is the message length.
    verify(stream).writeMessage(inputCaptor.capture(), eq(3), isNull(Runnable.class));
    verify(stream).flush();
    assertEquals(314, INTEGER_MARSHALLER.parse(inputCaptor.getValue()).intValue());

    streamListener.halfClosed(); // All full; no dessert.
    executeBarrier(executor).await();
    verify(callListener).onHalfClose();

    call.sendPayload(50);
    // 2 matches the byte length of "50".
    verify(stream).writeMessage(inputCaptor.capture(), eq(2), isNull(Runnable.class));
    verify(stream, times(2)).flush();
    assertEquals(50, INTEGER_MARSHALLER.parse(inputCaptor.getValue()).intValue());

    Metadata.Trailers trailers = new Metadata.Trailers();
    trailers.put(metadataKey, 3);
    Status status = Status.OK.withDescription("A okay");
    call.close(status, trailers);
    verify(stream).close(status, trailers);

    streamListener.closed(Status.OK);
    executeBarrier(executor).await();
    verify(callListener).onComplete();

    verifyNoMoreInteractions(stream);
    verifyNoMoreInteractions(callListener);
  }

  /**
   * An exception produced by {@code status.asRuntimeException()} and thrown from
   * {@code startCall} closes the stream with that same status.
   *
   * <p>The executor is plugged with a barrier before the stream is created. This lets the test
   * verify that {@code streamCreated} itself did not touch the stream. Only after the executor is
   * released and drained is the close expected.
   */
  @Test
  public void exceptionInStartCallPropagatesToStream() throws Exception {
    CyclicBarrier barrier = executeBarrier(executor);
    final Status status = Status.ABORTED.withDescription("Oh, no!");
    registry.addService(ServerServiceDefinition.builder("Waiter")
        .addMethod("serve", STRING_MARSHALLER, INTEGER_MARSHALLER,
          new ServerCallHandler<String, Integer>() {
            @Override
            public ServerCall.Listener<String> startCall(String fullMethodName,
                ServerCall<Integer> call, Metadata.Headers headers) {
              throw status.asRuntimeException();
            }
          }).build());
    ServerTransportListener transportListener = newTransport(server);

    ServerStreamListener streamListener
        = transportListener.streamCreated(stream, "/Waiter/serve", new Metadata.Headers());
    assertNotNull(streamListener);
    verifyNoMoreInteractions(stream);

    barrier.await(); // Unplug the executor.
    executeBarrier(executor).await(); // Wait for the queued startCall() work to finish.
    verify(stream).close(same(status), notNull(Metadata.Trailers.class));
    verifyNoMoreInteractions(stream);
  }

  /**
   * Starts a no-op transport and registers it with {@code server}.
   *
   * @return the listener the server returned for the new transport
   */
  private static ServerTransportListener newTransport(ServerImpl server) {
    Service transport = new NoopService();
    transport.startAsync();
    return server.serverListener().transportCreated(transport);
  }

  /**
   * Queues a task on {@code executor} that blocks on a two-party barrier, and returns that
   * barrier.
   *
   * <p>With a single-threaded executor this serves two purposes:
   * <ul>
   *   <li>Holding the barrier keeps the executor from processing later tasks until the test
   *       calls {@code await()}.</li>
   *   <li>Calling {@code await()} immediately waits until every task queued before the barrier
   *       has run.</li>
   * </ul>
   */
  private static CyclicBarrier executeBarrier(Executor executor) {
    final CyclicBarrier barrier = new CyclicBarrier(2);
    executor.execute(new Runnable() {
      @Override
      public void run() {
        try {
          barrier.await();
        } catch (InterruptedException ex) {
          // Restore the interrupt (e.g. from shutdownNow()) so the worker thread can observe it.
          Thread.currentThread().interrupt();
          throw new RuntimeException(ex);
        } catch (BrokenBarrierException ex) {
          throw new RuntimeException(ex);
        }
      }
    });
    return barrier;
  }

  /** Service that transitions to RUNNING and TERMINATED as soon as it is asked to. */
  private static class NoopService extends AbstractService {
    @Override
    protected void doStart() {
      notifyStarted();
    }

    @Override
    protected void doStop() {
      notifyStopped();
    }
  }

  /**
   * Service that starts immediately but stays STOPPING after {@code stopAsync()}. It terminates
   * only when the test calls {@link #doNotifyStopped()}.
   */
  private static class ManualStoppedService extends NoopService {
    public void doNotifyStopped() {
      notifyStopped();
    }

    @Override
    protected void doStop() {} // Don't notify; the test completes the stop manually.
  }

  /** Marshals strings as their UTF-8 bytes. */
  private static final class StringMarshaller implements Marshaller<String> {
    @Override
    public InputStream stream(String value) {
      return new ByteArrayInputStream(value.getBytes(UTF_8));
    }

    @Override
    public String parse(InputStream stream) {
      try {
        return new String(ByteStreams.toByteArray(stream), UTF_8);
      } catch (IOException ex) {
        throw new RuntimeException(ex);
      }
    }
  }

  /** Marshals integers as the UTF-8 bytes of their decimal string form. */
  private static final class IntegerMarshaller implements Marshaller<Integer> {
    @Override
    public InputStream stream(Integer value) {
      return STRING_MARSHALLER.stream(value.toString());
    }

    @Override
    public Integer parse(InputStream stream) {
      return Integer.valueOf(STRING_MARSHALLER.parse(stream));
    }
  }
}