| /* |
| * Copyright 2014, Google Inc. All rights reserved. |
| * |
| * Redistribution and use in source and binary forms, with or without |
| * modification, are permitted provided that the following conditions are |
| * met: |
| * |
| * * Redistributions of source code must retain the above copyright |
| * notice, this list of conditions and the following disclaimer. |
| * * Redistributions in binary form must reproduce the above |
| * copyright notice, this list of conditions and the following disclaimer |
| * in the documentation and/or other materials provided with the |
| * distribution. |
| * |
| * * Neither the name of Google Inc. nor the names of its |
| * contributors may be used to endorse or promote products derived from |
| * this software without specific prior written permission. |
| * |
| * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS |
| * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT |
| * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR |
| * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT |
| * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, |
| * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT |
| * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, |
| * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY |
| * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT |
| * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE |
| * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. |
| */ |
| |
| package com.google.net.stubby; |
| |
| import static com.google.common.base.Charsets.UTF_8; |
| import static org.junit.Assert.assertEquals; |
| import static org.junit.Assert.assertNotNull; |
| import static org.mockito.Matchers.eq; |
| import static org.mockito.Matchers.isNull; |
| import static org.mockito.Matchers.notNull; |
| import static org.mockito.Matchers.same; |
| import static org.mockito.Mockito.timeout; |
| import static org.mockito.Mockito.times; |
| import static org.mockito.Mockito.verify; |
| import static org.mockito.Mockito.verifyNoMoreInteractions; |
| |
| import com.google.common.io.ByteStreams; |
| import com.google.common.util.concurrent.AbstractService; |
| import com.google.common.util.concurrent.Service; |
| import com.google.net.stubby.transport.ServerStream; |
| import com.google.net.stubby.transport.ServerStreamListener; |
| import com.google.net.stubby.transport.ServerTransportListener; |
| |
| import org.junit.After; |
| import org.junit.Before; |
| import org.junit.Test; |
| import org.junit.runner.RunWith; |
| import org.junit.runners.JUnit4; |
| import org.mockito.ArgumentCaptor; |
| import org.mockito.Mock; |
| import org.mockito.Mockito; |
| import org.mockito.MockitoAnnotations; |
| |
| import java.io.ByteArrayInputStream; |
| import java.io.IOException; |
| import java.io.InputStream; |
| import java.util.concurrent.BrokenBarrierException; |
| import java.util.concurrent.CyclicBarrier; |
| import java.util.concurrent.Executor; |
| import java.util.concurrent.ExecutorService; |
| import java.util.concurrent.Executors; |
| import java.util.concurrent.atomic.AtomicReference; |
| |
| /** Unit tests for {@link ServerImpl}. */ |
| @RunWith(JUnit4.class) |
| public class ServerImplTest { |
| private static final IntegerMarshaller INTEGER_MARSHALLER = new IntegerMarshaller(); |
| private static final StringMarshaller STRING_MARSHALLER = new StringMarshaller(); |
| |
| private ExecutorService executor = Executors.newSingleThreadExecutor(); |
| private MutableHandlerRegistry registry = new MutableHandlerRegistryImpl(); |
| private Service transportServer = new NoopService(); |
| private ServerImpl server = new ServerImpl(executor, registry) |
| .setTransportServer(transportServer); |
| |
| @Mock |
| private ServerStream stream; |
| |
| @Mock |
| private ServerCall.Listener<String> callListener; |
| |
| @Before |
| public void startup() { |
| MockitoAnnotations.initMocks(this); |
| |
| server.startAsync(); |
| server.awaitRunning(); |
| } |
| |
| @After |
| public void teardown() { |
| executor.shutdownNow(); |
| } |
| |
| @Test |
| public void startStopImmediate() { |
| Service transportServer = new NoopService(); |
| Server server = new ServerImpl(executor, registry).setTransportServer(transportServer); |
| assertEquals(Service.State.NEW, server.state()); |
| assertEquals(Service.State.NEW, transportServer.state()); |
| server.startAsync(); |
| server.awaitRunning(); |
| assertEquals(Service.State.RUNNING, server.state()); |
| assertEquals(Service.State.RUNNING, transportServer.state()); |
| server.stopAsync(); |
| server.awaitTerminated(); |
| assertEquals(Service.State.TERMINATED, server.state()); |
| assertEquals(Service.State.TERMINATED, transportServer.state()); |
| } |
| |
| @Test |
| public void transportServerFailureFailsServer() { |
| class FailableService extends NoopService { |
| public void doNotifyFailed(Throwable cause) { |
| notifyFailed(cause); |
| } |
| } |
| FailableService transportServer = new FailableService(); |
| Server server = new ServerImpl(executor, registry).setTransportServer(transportServer); |
| server.startAsync(); |
| server.awaitRunning(); |
| RuntimeException ex = new RuntimeException("force failure"); |
| transportServer.doNotifyFailed(ex); |
| assertEquals(Service.State.FAILED, server.state()); |
| assertEquals(ex, server.failureCause()); |
| } |
| |
| @Test |
| public void transportServerFailsStartup() { |
| class FailingStartupService extends NoopService { |
| @Override |
| public void doStart() { |
| notifyFailed(new RuntimeException()); |
| } |
| } |
| FailingStartupService transportServer = new FailingStartupService(); |
| Server server = new ServerImpl(executor, registry).setTransportServer(transportServer); |
| server.startAsync(); |
| assertEquals(Service.State.FAILED, server.state()); |
| } |
| |
| @Test |
| public void transportServerFirstToShutdown() { |
| class ManualStoppedService extends NoopService { |
| public void doNotifyStopped() { |
| notifyStopped(); |
| } |
| |
| @Override |
| public void doStop() {} // Don't notify. |
| } |
| NoopService transportServer = new NoopService(); |
| ServerImpl server = new ServerImpl(executor, registry).setTransportServer(transportServer); |
| server.startAsync(); |
| server.awaitRunning(); |
| ManualStoppedService transport = new ManualStoppedService(); |
| transport.startAsync(); |
| server.serverListener().transportCreated(transport); |
| server.stopAsync(); |
| assertEquals(Service.State.STOPPING, transport.state()); |
| assertEquals(Service.State.TERMINATED, transportServer.state()); |
| assertEquals(Service.State.STOPPING, server.state()); |
| |
| transport.doNotifyStopped(); |
| assertEquals(Service.State.TERMINATED, transport.state()); |
| assertEquals(Service.State.TERMINATED, server.state()); |
| } |
| |
| @Test |
| public void transportServerLastToShutdown() { |
| class ManualStoppedService extends NoopService { |
| public void doNotifyStopped() { |
| notifyStopped(); |
| } |
| |
| @Override |
| public void doStop() {} // Don't notify. |
| } |
| ManualStoppedService transportServer = new ManualStoppedService(); |
| ServerImpl server = new ServerImpl(executor, registry).setTransportServer(transportServer); |
| server.startAsync(); |
| server.awaitRunning(); |
| Service transport = new NoopService(); |
| transport.startAsync(); |
| server.serverListener().transportCreated(transport); |
| server.stopAsync(); |
| assertEquals(Service.State.TERMINATED, transport.state()); |
| assertEquals(Service.State.STOPPING, transportServer.state()); |
| assertEquals(Service.State.STOPPING, server.state()); |
| |
| transportServer.doNotifyStopped(); |
| assertEquals(Service.State.TERMINATED, transportServer.state()); |
| assertEquals(Service.State.TERMINATED, server.state()); |
| } |
| |
| @Test |
| public void basicExchangeSuccessful() throws Exception { |
| final Metadata.Key<Integer> metadataKey |
| = Metadata.Key.of("inception", Metadata.INTEGER_MARSHALLER); |
| final AtomicReference<ServerCall<Integer>> callReference |
| = new AtomicReference<ServerCall<Integer>>(); |
| registry.addService(ServerServiceDefinition.builder("Waiter") |
| .addMethod("serve", STRING_MARSHALLER, INTEGER_MARSHALLER, |
| new ServerCallHandler<String, Integer>() { |
| @Override |
| public ServerCall.Listener<String> startCall(String fullMethodName, |
| ServerCall<Integer> call, Metadata.Headers headers) { |
| assertEquals("/Waiter/serve", fullMethodName); |
| assertNotNull(call); |
| assertNotNull(headers); |
| assertEquals(0, headers.get(metadataKey).intValue()); |
| callReference.set(call); |
| return callListener; |
| } |
| }).build()); |
| ServerTransportListener transportListener = newTransport(server); |
| |
| Metadata.Headers headers = new Metadata.Headers(); |
| headers.put(metadataKey, 0); |
| ServerStreamListener streamListener |
| = transportListener.streamCreated(stream, "/Waiter/serve", headers); |
| assertNotNull(streamListener); |
| |
| executeBarrier(executor).await(); |
| ServerCall<Integer> call = callReference.get(); |
| assertNotNull(call); |
| |
| String order = "Lots of pizza, please"; |
| streamListener.messageRead(STRING_MARSHALLER.stream(order), 1); |
| verify(callListener, timeout(2000)).onPayload(order); |
| |
| call.sendPayload(314); |
| ArgumentCaptor<InputStream> inputCaptor = ArgumentCaptor.forClass(InputStream.class); |
| verify(stream).writeMessage(inputCaptor.capture(), eq(3), isNull(Runnable.class)); |
| verify(stream).flush(); |
| assertEquals(314, INTEGER_MARSHALLER.parse(inputCaptor.getValue()).intValue()); |
| |
| streamListener.halfClosed(); // All full; no dessert. |
| executeBarrier(executor).await(); |
| verify(callListener).onHalfClose(); |
| |
| call.sendPayload(50); |
| verify(stream).writeMessage(inputCaptor.capture(), eq(2), isNull(Runnable.class)); |
| verify(stream, times(2)).flush(); |
| assertEquals(50, INTEGER_MARSHALLER.parse(inputCaptor.getValue()).intValue()); |
| |
| Metadata.Trailers trailers = new Metadata.Trailers(); |
| trailers.put(metadataKey, 3); |
| Status status = Status.OK.withDescription("A okay"); |
| call.close(status, trailers); |
| verify(stream).close(status, trailers); |
| |
| streamListener.closed(Status.OK); |
| executeBarrier(executor).await(); |
| verify(callListener).onComplete(); |
| |
| verifyNoMoreInteractions(stream); |
| verifyNoMoreInteractions(callListener); |
| } |
| |
| @Test |
| public void exceptionInStartCallPropagatesToStream() throws Exception { |
| CyclicBarrier barrier = executeBarrier(executor); |
| final Status status = Status.ABORTED.withDescription("Oh, no!"); |
| registry.addService(ServerServiceDefinition.builder("Waiter") |
| .addMethod("serve", STRING_MARSHALLER, INTEGER_MARSHALLER, |
| new ServerCallHandler<String, Integer>() { |
| @Override |
| public ServerCall.Listener<String> startCall(String fullMethodName, |
| ServerCall<Integer> call, Metadata.Headers headers) { |
| throw status.asRuntimeException(); |
| } |
| }).build()); |
| ServerTransportListener transportListener = newTransport(server); |
| |
| ServerStreamListener streamListener |
| = transportListener.streamCreated(stream, "/Waiter/serve", new Metadata.Headers()); |
| assertNotNull(streamListener); |
| verifyNoMoreInteractions(stream); |
| |
| barrier.await(); |
| executeBarrier(executor).await(); |
| verify(stream).close(same(status), notNull(Metadata.Trailers.class)); |
| verifyNoMoreInteractions(stream); |
| } |
| |
| private static ServerTransportListener newTransport(ServerImpl server) { |
| Service transport = new NoopService(); |
| transport.startAsync(); |
| return server.serverListener().transportCreated(transport); |
| } |
| |
| /** |
| * Useful for plugging a single-threaded executor from processing tasks, or for waiting until a |
| * single-threaded executor has processed queued tasks. |
| */ |
| private static CyclicBarrier executeBarrier(Executor executor) { |
| final CyclicBarrier barrier = new CyclicBarrier(2); |
| executor.execute(new Runnable() { |
| @Override |
| public void run() { |
| try { |
| barrier.await(); |
| } catch (InterruptedException ex) { |
| throw new RuntimeException(ex); |
| } catch (BrokenBarrierException ex) { |
| throw new RuntimeException(ex); |
| } |
| } |
| }); |
| return barrier; |
| } |
| |
| private static class NoopService extends AbstractService { |
| @Override |
| protected void doStart() { |
| notifyStarted(); |
| } |
| |
| @Override |
| protected void doStop() { |
| notifyStopped(); |
| } |
| } |
| |
| private static class StringMarshaller implements Marshaller<String> { |
| @Override |
| public InputStream stream(String value) { |
| return new ByteArrayInputStream(value.getBytes(UTF_8)); |
| } |
| |
| @Override |
| public String parse(InputStream stream) { |
| try { |
| return new String(ByteStreams.toByteArray(stream), UTF_8); |
| } catch (IOException ex) { |
| throw new RuntimeException(ex); |
| } |
| } |
| } |
| |
| private static class IntegerMarshaller implements Marshaller<Integer> { |
| @Override |
| public InputStream stream(Integer value) { |
| return STRING_MARSHALLER.stream(value.toString()); |
| } |
| |
| @Override |
| public Integer parse(InputStream stream) { |
| return Integer.valueOf(STRING_MARSHALLER.parse(stream)); |
| } |
| } |
| } |