| package com.google.net.stubby.newtransport.netty; |
| |
| import static java.nio.charset.StandardCharsets.UTF_8; |
| import static org.junit.Assert.assertArrayEquals; |
| import static org.junit.Assert.assertEquals; |
| import static org.mockito.Matchers.any; |
| import static org.mockito.Matchers.anyInt; |
| import static org.mockito.Matchers.eq; |
| import static org.mockito.Mockito.never; |
| import static org.mockito.Mockito.verify; |
| import static org.mockito.Mockito.verifyNoMoreInteractions; |
| import static org.mockito.Mockito.when; |
| |
| import com.google.common.io.ByteStreams; |
| import com.google.net.stubby.MethodDescriptor; |
| import com.google.net.stubby.Status; |
| import com.google.net.stubby.newtransport.Framer; |
| import com.google.net.stubby.newtransport.HttpUtil; |
| import com.google.net.stubby.newtransport.MessageFramer; |
| import com.google.net.stubby.newtransport.ServerStream; |
| import com.google.net.stubby.newtransport.ServerTransportListener; |
| import com.google.net.stubby.newtransport.StreamListener; |
| |
| import io.netty.buffer.ByteBuf; |
| import io.netty.buffer.Unpooled; |
| import io.netty.buffer.UnpooledByteBufAllocator; |
| import io.netty.channel.ChannelHandlerContext; |
| import io.netty.handler.codec.http2.DefaultHttp2Connection; |
| import io.netty.handler.codec.http2.DefaultHttp2FrameReader; |
| import io.netty.handler.codec.http2.DefaultHttp2FrameWriter; |
| import io.netty.handler.codec.http2.DefaultHttp2Headers; |
| import io.netty.handler.codec.http2.Http2CodecUtil; |
| import io.netty.handler.codec.http2.Http2Error; |
| import io.netty.handler.codec.http2.Http2Headers; |
| import io.netty.handler.codec.http2.Http2Settings; |
| |
| import org.junit.Before; |
| import org.junit.Test; |
| import org.junit.runner.RunWith; |
| import org.junit.runners.JUnit4; |
| import org.mockito.ArgumentCaptor; |
| import org.mockito.Mock; |
| import org.mockito.MockitoAnnotations; |
| |
| import java.io.ByteArrayInputStream; |
| import java.io.InputStream; |
| import java.nio.ByteBuffer; |
| |
| /** Unit tests for {@link NettyServerHandler}. */ |
| @RunWith(JUnit4.class) |
| public class NettyServerHandlerTest extends NettyHandlerTestBase { |
| |
| private static final int STREAM_ID = 3; |
| private static final byte[] CONTENT = "hello world".getBytes(UTF_8); |
| |
| @Mock private ServerTransportListener transportListener; |
| |
| @Mock private StreamListener streamListener; |
| |
| private NettyServerStream stream; |
| |
| private NettyServerHandler handler; |
| |
| @Before |
| public void setup() throws Exception { |
| MockitoAnnotations.initMocks(this); |
| |
| when(transportListener.streamCreated(any(ServerStream.class), any(MethodDescriptor.class))) |
| .thenReturn(streamListener); |
| handler = new NettyServerHandler(transportListener, new DefaultHttp2Connection(true)); |
| frameWriter = new DefaultHttp2FrameWriter(); |
| frameReader = new DefaultHttp2FrameReader(); |
| |
| when(channel.isActive()).thenReturn(true); |
| mockContext(); |
| mockFuture(true); |
| |
| when(channel.alloc()).thenReturn(UnpooledByteBufAllocator.DEFAULT); |
| |
| // Simulate activation of the handler to force writing of the initial settings |
| handler.handlerAdded(ctx); |
| |
| // Simulate receipt of the connection preface |
| handler.channelRead(ctx, Http2CodecUtil.connectionPrefaceBuf()); |
| // Simulate receipt of initial remote settings. |
| ByteBuf serializedSettings = serializeSettings(new Http2Settings()); |
| handler.channelRead(ctx, serializedSettings); |
| |
| // Reset the context to clear any interactions resulting from the HTTP/2 |
| // connection preface handshake. |
| mockContext(); |
| } |
| |
| @Test |
| public void sendFrameShouldSucceed() throws Exception { |
| createStream(); |
| ByteBuf content = Unpooled.copiedBuffer(CONTENT); |
| |
| // Send a frame and verify that it was written. |
| handler.write(ctx, new SendGrpcFrameCommand(stream.id(), content, false), promise); |
| verify(promise, never()).setFailure(any(Throwable.class)); |
| verify(ctx).write(any(ByteBuf.class), eq(promise)); |
| assertEquals(0, content.refCnt()); |
| } |
| |
| @Test |
| public void inboundDataShouldForwardToStreamListener() throws Exception { |
| inboundDataShouldForwardToStreamListener(false); |
| } |
| |
| @Test |
| public void inboundDataWithEndStreamShouldForwardToStreamListener() throws Exception { |
| inboundDataShouldForwardToStreamListener(true); |
| } |
| |
| private void inboundDataShouldForwardToStreamListener(boolean endStream) throws Exception { |
| createStream(); |
| |
| // Create a data frame and then trigger the handler to read it. |
| ByteBuf frame = dataFrame(STREAM_ID, endStream); |
| handler.channelRead(ctx, frame); |
| ArgumentCaptor<InputStream> captor = ArgumentCaptor.forClass(InputStream.class); |
| verify(streamListener).messageRead(captor.capture(), eq(CONTENT.length)); |
| assertArrayEquals(CONTENT, ByteStreams.toByteArray(captor.getValue())); |
| |
| if (endStream) { |
| verify(streamListener).closed(eq(Status.OK)); |
| } |
| verifyNoMoreInteractions(streamListener); |
| } |
| |
| @Test |
| public void clientHalfCloseShouldForwardToStreamListener() throws Exception { |
| createStream(); |
| |
| handler.channelRead(ctx, emptyDataFrame(STREAM_ID, true)); |
| verify(streamListener, never()).messageRead(any(InputStream.class), anyInt()); |
| verify(streamListener).closed(eq(Status.OK)); |
| verifyNoMoreInteractions(streamListener); |
| } |
| |
| @Test |
| public void clientCancelShouldForwardToStreamListener() throws Exception { |
| createStream(); |
| |
| handler.channelRead(ctx, rstStreamFrame(STREAM_ID, Http2Error.CANCEL.code())); |
| verify(streamListener, never()).messageRead(any(InputStream.class), anyInt()); |
| verify(streamListener).closed(eq(Status.CANCELLED)); |
| verifyNoMoreInteractions(streamListener); |
| } |
| |
| private void createStream() throws Exception { |
| Http2Headers headers = DefaultHttp2Headers.newBuilder() |
| .method(HttpUtil.HTTP_METHOD) |
| .set(HttpUtil.CONTENT_TYPE_HEADER, HttpUtil.CONTENT_TYPE_PROTORPC) |
| .path("/foo.bar") |
| .build(); |
| ByteBuf headersFrame = headersFrame(STREAM_ID, headers); |
| handler.channelRead(ctx, headersFrame); |
| ArgumentCaptor<NettyServerStream> streamCaptor = |
| ArgumentCaptor.forClass(NettyServerStream.class); |
| @SuppressWarnings("rawtypes") |
| ArgumentCaptor<MethodDescriptor> methodCaptor = ArgumentCaptor.forClass(MethodDescriptor.class); |
| verify(transportListener).streamCreated(streamCaptor.capture(), methodCaptor.capture()); |
| stream = streamCaptor.getValue(); |
| } |
| |
| private ByteBuf dataFrame(int streamId, boolean endStream) { |
| final ByteBuf compressionFrame = Unpooled.buffer(CONTENT.length); |
| MessageFramer framer = new MessageFramer(new Framer.Sink<ByteBuffer>() { |
| @Override |
| public void deliverFrame(ByteBuffer frame, boolean endOfStream) { |
| compressionFrame.writeBytes(frame); |
| } |
| }, 1000); |
| framer.writePayload(new ByteArrayInputStream(CONTENT), CONTENT.length); |
| framer.flush(); |
| if (endStream) { |
| framer.close(); |
| } |
| ChannelHandlerContext ctx = newContext(); |
| frameWriter.writeData(ctx, streamId, compressionFrame, 0, endStream, newPromise()); |
| return captureWrite(ctx); |
| } |
| |
| private ByteBuf emptyDataFrame(int streamId, boolean endStream) { |
| ChannelHandlerContext ctx = newContext(); |
| frameWriter.writeData(ctx, streamId, Unpooled.EMPTY_BUFFER, 0, endStream, newPromise()); |
| return captureWrite(ctx); |
| } |
| } |