Refactoring Deframing code for Netty.
- New Decompressor interface with NettyDecompressor impl. This is responsible for unpackaging and uncompressing the GRPC compression frame.
- New class GrpcDeframer. This is a transport-agnostic class that uses a Decompressor to unpackage the compression frame, and then reads one complete GRPC frame and notifies the listener.
-------------
Created by MOE: http://code.google.com/p/moe-java
MOE_MIGRATED_REVID=73068835
diff --git a/core/src/main/java/com/google/net/stubby/newtransport/Decompressor.java b/core/src/main/java/com/google/net/stubby/newtransport/Decompressor.java
new file mode 100644
index 0000000..cd955b6
--- /dev/null
+++ b/core/src/main/java/com/google/net/stubby/newtransport/Decompressor.java
@@ -0,0 +1,47 @@
+package com.google.net.stubby.newtransport;
+
+import java.io.Closeable;
+
+import javax.annotation.Nullable;
+
+/**
+ * An object responsible for reading GRPC compression frames for a single stream.
+ */
+public interface Decompressor extends Closeable {
+
+ /**
+ * Adds the given chunk of a GRPC compression frame to the internal buffers. If the data is
+ * compressed, it is uncompressed whenever possible (which may only be after the entire
+ * compression frame has been received).
+ *
+ * <p>Some or all of the given {@code data} chunk may not be made immediately available via
+ * {@link #readBytes} due to internal buffering.
+ *
+ * @param data a received chunk of a GRPC compression frame. Control over the life cycle for this
+ * buffer is given to this {@link Decompressor}. Only this {@link Decompressor} should call
+ * {@link Buffer#close} after this point.
+ */
+ void decompress(Buffer data);
+
+ /**
+ * Reads up to the given number of bytes. Ownership of the returned {@link Buffer} is transferred
+ * to the caller who is responsible for calling {@link Buffer#close}.
+ *
+ * <p>The length of the returned {@link Buffer} may be less than {@code maxLength}, but will never
+ * be 0. If no data is available, {@code null} is returned. To ensure that all available data is
+ * read, the caller should repeatedly call {@link #readBytes} until it returns {@code null}.
+ *
+ * @param maxLength the maximum number of bytes to read. This value must be > 0, otherwise throws
+ * an {@link IllegalArgumentException}.
+ * @return a {@link Buffer} containing the number of bytes read or {@code null} if no data is
+ * currently available.
+ */
+ @Nullable
+ Buffer readBytes(int maxLength);
+
+ /**
+ * Closes this decompressor and frees any resources.
+ */
+ @Override
+ void close();
+}
diff --git a/core/src/main/java/com/google/net/stubby/newtransport/GrpcDeframer.java b/core/src/main/java/com/google/net/stubby/newtransport/GrpcDeframer.java
new file mode 100644
index 0000000..b5c24e2
--- /dev/null
+++ b/core/src/main/java/com/google/net/stubby/newtransport/GrpcDeframer.java
@@ -0,0 +1,203 @@
+package com.google.net.stubby.newtransport;
+
+import static com.google.net.stubby.GrpcFramingUtil.CONTEXT_VALUE_FRAME;
+import static com.google.net.stubby.GrpcFramingUtil.FRAME_LENGTH;
+import static com.google.net.stubby.GrpcFramingUtil.FRAME_TYPE_LENGTH;
+import static com.google.net.stubby.GrpcFramingUtil.FRAME_TYPE_MASK;
+import static com.google.net.stubby.GrpcFramingUtil.PAYLOAD_FRAME;
+import static com.google.net.stubby.GrpcFramingUtil.STATUS_FRAME;
+
+import com.google.common.base.Preconditions;
+import com.google.net.stubby.Status;
+import com.google.net.stubby.transport.Transport;
+
+import java.io.Closeable;
+import java.io.IOException;
+
+/**
+ * Deframer for GRPC frames. Delegates deframing/decompression of the GRPC compression frame to a
+ * {@link Decompressor}.
+ */
+public class GrpcDeframer implements Closeable {
+ private enum State {
+ HEADER, BODY
+ }
+
+ private static final int HEADER_LENGTH = FRAME_TYPE_LENGTH + FRAME_LENGTH;
+ private final Decompressor decompressor;
+ private State state = State.HEADER;
+ private int requiredLength = HEADER_LENGTH;
+ private int frameType;
+ private boolean statusNotified;
+ private GrpcMessageListener listener;
+ private CompositeBuffer nextFrame;
+
+ public GrpcDeframer(Decompressor decompressor, GrpcMessageListener listener) {
+ this.decompressor = Preconditions.checkNotNull(decompressor, "decompressor");
+ this.listener = Preconditions.checkNotNull(listener, "listener");
+ }
+
+ public void deframe(Buffer data, boolean endOfStream) {
+ Preconditions.checkNotNull(data, "data");
+
+ // Add the data to the decompression buffer.
+ decompressor.decompress(data);
+
+ // Process the uncompressed bytes.
+ while (readRequiredBytes()) {
+ if (statusNotified) {
+ throw new IllegalStateException("Inbound data after receiving status frame");
+ }
+
+ switch (state) {
+ case HEADER:
+ processHeader();
+ break;
+ case BODY:
+ processBody();
+ break;
+ default:
+ throw new AssertionError("Invalid state: " + state);
+ }
+ }
+
+ // If reached the end of stream without reading a status frame, fabricate one
+ // and deliver to the target.
+ if (!statusNotified && endOfStream) {
+ notifyStatus(Status.OK);
+ }
+ }
+
+
+ @Override
+ public void close() {
+ decompressor.close();
+ if (nextFrame != null) {
+ nextFrame.close();
+ }
+ }
+
+ /**
+ * Attempts to read the required bytes into nextFrame.
+ *
+ * @returns {@code true} if all of the required bytes have been read.
+ */
+ private boolean readRequiredBytes() {
+ if (nextFrame == null) {
+ nextFrame = new CompositeBuffer();
+ }
+
+ // Read until the buffer contains all the required bytes.
+ int missingBytes;
+ while ((missingBytes = requiredLength - nextFrame.readableBytes()) > 0) {
+ Buffer buffer = decompressor.readBytes(missingBytes);
+ if (buffer == null) {
+ // No more data is available.
+ break;
+ }
+ // Add it to the composite buffer for the next frame.
+ nextFrame.addBuffer(buffer);
+ }
+
+ // Return whether or not all of the required bytes are now in the frame.
+ return nextFrame.readableBytes() == requiredLength;
+ }
+
+ /**
+ * Processes the GRPC compression header which is composed of the compression flag and the outer
+ * frame length.
+ */
+ private void processHeader() {
+ // Peek, but do not read the header.
+ frameType = nextFrame.readUnsignedByte() & FRAME_TYPE_MASK;
+
+ // Update the required length to include the length of the frame.
+ requiredLength = nextFrame.readInt();
+
+ // Continue reading the frame body.
+ state = State.BODY;
+ }
+
+ /**
+ * Processes the body of the GRPC compression frame. A single compression frame may contain
+ * several GRPC messages within it.
+ */
+ private void processBody() {
+ switch (frameType) {
+ case CONTEXT_VALUE_FRAME:
+ processContext();
+ break;
+ case PAYLOAD_FRAME:
+ processMessage();
+ break;
+ case STATUS_FRAME:
+ processStatus();
+ break;
+ default:
+ throw new AssertionError("Invalid frameType: " + frameType);
+ }
+
+ // Done with this frame, begin processing the next header.
+ state = State.HEADER;
+ requiredLength = HEADER_LENGTH;
+ }
+
+ /**
+ * Processes the payload of a context frame.
+ */
+ private void processContext() {
+ Transport.ContextValue ctx;
+ try {
+ // Not clear if using proto encoding here is of any benefit.
+ // Using ContextValue.parseFrom requires copying out of the framed chunk
+ // Writing a custom parser would have to do varint handling and potentially
+ // deal with out-of-order tags etc.
+ ctx = Transport.ContextValue.parseFrom(Buffers.openStream(nextFrame, false));
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ } finally {
+ nextFrame.close();
+ nextFrame = null;
+ }
+
+ // Call the handler.
+ Buffer ctxBuffer = Buffers.wrap(ctx.getValue());
+ listener.onContext(ctx.getKey(), Buffers.openStream(ctxBuffer, true),
+ ctxBuffer.readableBytes());
+ }
+
+ /**
+ * Processes the payload of a message frame.
+ */
+ private void processMessage() {
+ try {
+ listener.onPayload(Buffers.openStream(nextFrame, true), nextFrame.readableBytes());
+ } finally {
+ // Don't close the frame, since the listener is now responsible for the life-cycle.
+ nextFrame = null;
+ }
+ }
+
+ /**
+ * Processes the payload of a status frame.
+ */
+ private void processStatus() {
+ try {
+ int statusCode = nextFrame.readUnsignedShort();
+ Transport.Code code = Transport.Code.valueOf(statusCode);
+ notifyStatus(code != null ? new Status(code)
+ : new Status(Transport.Code.UNKNOWN, "Unknown status code " + statusCode));
+ } finally {
+ nextFrame.close();
+ nextFrame = null;
+ }
+ }
+
+ /**
+ * Delivers the status notification to the listener.
+ */
+ private void notifyStatus(Status status) {
+ statusNotified = true;
+ listener.onStatus(status);
+ }
+}
diff --git a/core/src/main/java/com/google/net/stubby/newtransport/TransportFrameUtil.java b/core/src/main/java/com/google/net/stubby/newtransport/TransportFrameUtil.java
index 9c402b8..8ccc12c 100644
--- a/core/src/main/java/com/google/net/stubby/newtransport/TransportFrameUtil.java
+++ b/core/src/main/java/com/google/net/stubby/newtransport/TransportFrameUtil.java
@@ -22,5 +22,55 @@
return ((b & COMPRESSION_FLAG_MASK) == FLATE_FLAG);
}
+ /**
+ * Length of the compression type field.
+ */
+ public static final int COMPRESSION_TYPE_LENGTH = 1;
+
+ /**
+ * Length of the compression frame length field.
+ */
+ public static final int COMPRESSION_FRAME_LENGTH = 3;
+
+ /**
+ * Full length of the compression header.
+ */
+ public static final int COMPRESSION_HEADER_LENGTH =
+ COMPRESSION_TYPE_LENGTH + COMPRESSION_FRAME_LENGTH;
+
+ /**
+ * Length of flags block in bytes
+ */
+ public static final int FRAME_TYPE_LENGTH = 1;
+
+ // Flags
+ public static final byte PAYLOAD_FRAME = 0x0;
+ public static final byte CONTEXT_VALUE_FRAME = 0x1;
+ public static final byte CALL_HEADER_FRAME = 0x2;
+ public static final byte STATUS_FRAME = 0x3;
+ public static final byte FRAME_TYPE_MASK = 0x3;
+
+ /**
+ * Number of bytes for the length field within a frame
+ */
+ public static final int FRAME_LENGTH = 4;
+
+ /**
+ * Full length of the GRPC frame header.
+ */
+ public static final int FRAME_HEADER_LENGTH = FRAME_TYPE_LENGTH + FRAME_LENGTH;
+
+ public static boolean isContextValueFrame(int flags) {
+ return (flags & FRAME_TYPE_MASK) == CONTEXT_VALUE_FRAME;
+ }
+
+ public static boolean isPayloadFrame(byte flags) {
+ return (flags & FRAME_TYPE_MASK) == PAYLOAD_FRAME;
+ }
+
+ public static boolean isStatusFrame(byte flags) {
+ return (flags & FRAME_TYPE_MASK) == STATUS_FRAME;
+ }
+
private TransportFrameUtil() {}
}
diff --git a/core/src/main/java/com/google/net/stubby/newtransport/netty/NettyClientStream.java b/core/src/main/java/com/google/net/stubby/newtransport/netty/NettyClientStream.java
index c9c3c06..5372133 100644
--- a/core/src/main/java/com/google/net/stubby/newtransport/netty/NettyClientStream.java
+++ b/core/src/main/java/com/google/net/stubby/newtransport/netty/NettyClientStream.java
@@ -7,16 +7,16 @@
import com.google.net.stubby.Status;
import com.google.net.stubby.newtransport.AbstractStream;
import com.google.net.stubby.newtransport.ClientStream;
-import com.google.net.stubby.newtransport.Deframer;
+import com.google.net.stubby.newtransport.GrpcDeframer;
import com.google.net.stubby.newtransport.HttpUtil;
import com.google.net.stubby.newtransport.StreamListener;
import com.google.net.stubby.transport.Transport;
-import io.netty.handler.codec.http.HttpResponseStatus;
-import io.netty.handler.codec.http2.Http2Headers;
import io.netty.buffer.ByteBuf;
import io.netty.channel.Channel;
import io.netty.channel.ChannelPromise;
+import io.netty.handler.codec.http.HttpResponseStatus;
+import io.netty.handler.codec.http2.Http2Headers;
import java.nio.ByteBuffer;
@@ -28,7 +28,7 @@
private volatile int id = PENDING_STREAM_ID;
private final Channel channel;
- private final Deframer<ByteBuf> deframer;
+ private final GrpcDeframer deframer;
private Transport.Code responseCode = Transport.Code.UNKNOWN;
private boolean isGrpcResponse;
private StringBuilder nonGrpcErrorMessage = new StringBuilder();
@@ -36,7 +36,8 @@
NettyClientStream(StreamListener listener, Channel channel) {
super(listener);
this.channel = Preconditions.checkNotNull(channel, "channel");
- this.deframer = new ByteBufDeframer(channel.alloc(), inboundMessageHandler());
+ this.deframer =
+ new GrpcDeframer(new NettyDecompressor(channel.alloc()), inboundMessageHandler());
}
/**
@@ -86,7 +87,7 @@
if (isGrpcResponse) {
// Retain the ByteBuf until it is released by the deframer.
- deframer.deliverFrame(frame.retain(), endOfStream);
+ deframer.deframe(new NettyBuffer(frame.retain()), endOfStream);
// TODO(user): add flow control.
promise.setSuccess();
diff --git a/core/src/main/java/com/google/net/stubby/newtransport/netty/NettyDecompressor.java b/core/src/main/java/com/google/net/stubby/newtransport/netty/NettyDecompressor.java
new file mode 100644
index 0000000..0331031
--- /dev/null
+++ b/core/src/main/java/com/google/net/stubby/newtransport/netty/NettyDecompressor.java
@@ -0,0 +1,282 @@
+package com.google.net.stubby.newtransport.netty;
+
+import static com.google.net.stubby.newtransport.TransportFrameUtil.COMPRESSION_HEADER_LENGTH;
+import static com.google.net.stubby.newtransport.TransportFrameUtil.isFlateCompressed;
+
+import com.google.common.base.Preconditions;
+import com.google.common.base.Throwables;
+import com.google.common.io.Closeables;
+import com.google.net.stubby.newtransport.Buffer;
+import com.google.net.stubby.newtransport.Decompressor;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.ByteBufAllocator;
+import io.netty.buffer.ByteBufOutputStream;
+import io.netty.buffer.CompositeByteBuf;
+
+import java.io.Closeable;
+import java.io.EOFException;
+import java.io.IOException;
+import java.io.InputStream;
+import java.util.zip.InflaterInputStream;
+
+import javax.annotation.Nullable;
+
+/**
+ * A {@link Decompressor} implementation based on Netty {@link CompositeByteBuf}s.
+ */
+public class NettyDecompressor implements Decompressor {
+
+ private final CompositeByteBuf buffer;
+ private final ByteBufAllocator alloc;
+ private Frame frame;
+
+ public NettyDecompressor(ByteBufAllocator alloc) {
+ this.alloc = Preconditions.checkNotNull(alloc, "alloc");
+ buffer = alloc.compositeBuffer();
+ }
+
+ @Override
+ public void decompress(Buffer data) {
+ ByteBuf buf = toByteBuf(data);
+
+ // Add it to the compression frame buffer.
+ buffer.addComponent(buf);
+ buffer.writerIndex(buffer.writerIndex() + buf.readableBytes());
+ }
+
+ @Override
+ public Buffer readBytes(final int maxLength) {
+ Preconditions.checkArgument(maxLength > 0, "maxLength must be > 0");
+ try {
+ // Read the next frame if we don't already have one.
+ if (frame == null) {
+ frame = nextFrame();
+ }
+
+ ByteBuf byteBuf = null;
+ if (frame != null) {
+ // Read as many bytes as we can from the frame.
+ byteBuf = frame.readBytes(maxLength);
+
+ // If we reached the end of the frame, close it.
+ if (frame.complete()) {
+ frame.close();
+ frame = null;
+ }
+ }
+
+ if (byteBuf == null) {
+ // No data was available.
+ return null;
+ }
+
+ return new NettyBuffer(byteBuf);
+ } finally {
+ // Discard any component buffers that have been fully read.
+ buffer.discardReadComponents();
+ }
+ }
+
+ @Override
+ public void close() {
+ // Release the CompositeByteBuf. This will automatically release any components as well.
+ buffer.release();
+ if (frame != null) {
+ frame.close();
+ }
+ }
+
+ /**
+ * Returns the next compression frame object, or {@code null} if the next frame header is
+ * incomplete.
+ */
+ @SuppressWarnings("resource")
+ @Nullable
+ private Frame nextFrame() {
+ if (buffer.readableBytes() < COMPRESSION_HEADER_LENGTH) {
+ // Don't have all the required bytes for the frame header yet.
+ return null;
+ }
+
+ // Read the header and create the frame object.
+ boolean compressed = isFlateCompressed(buffer.readUnsignedByte());
+ int frameLength = buffer.readUnsignedMedium();
+ if (frameLength == 0) {
+ return nextFrame();
+ }
+
+ return compressed ? new CompressedFrame(frameLength) : new UncompressedFrame(frameLength);
+ }
+
+ /**
+ * Converts the given buffer into a {@link ByteBuf}.
+ */
+ private ByteBuf toByteBuf(Buffer data) {
+ if (data instanceof NettyBuffer) {
+ // Just return the contained ByteBuf.
+ return ((NettyBuffer) data).buffer();
+ }
+
+ // Create a new ByteBuf and copy the content to it.
+ try {
+ int length = data.readableBytes();
+ ByteBuf buf = alloc.buffer(length);
+ data.readBytes(new ByteBufOutputStream(buf), length);
+ return buf;
+ } catch (IOException e) {
+ throw Throwables.propagate(e);
+ } finally {
+ data.close();
+ }
+ }
+
+ /**
+ * A wrapper around the body of a compression frame. Provides a generic method for reading bytes
+ * from any frame.
+ */
+ private interface Frame extends Closeable {
+ @Nullable
+ ByteBuf readBytes(int maxLength);
+
+ boolean complete();
+
+ @Override
+ void close();
+ }
+
+ /**
+ * An uncompressed frame. Just writes bytes directly from the compression frame.
+ */
+ private class UncompressedFrame implements Frame {
+ int bytesRemainingInFrame;
+
+ public UncompressedFrame(int frameLength) {
+ this.bytesRemainingInFrame = frameLength;
+ }
+
+ @Override
+ @Nullable
+ public ByteBuf readBytes(int maxLength) {
+ Preconditions.checkState(!complete(), "Must not call readBytes on a completed frame");
+ int available = buffer.readableBytes();
+ if (available == 0) {
+ return null;
+ }
+
+ int bytesToRead = Math.min(available, Math.min(maxLength, bytesRemainingInFrame));
+ bytesRemainingInFrame -= bytesToRead;
+
+ return buffer.readBytes(bytesToRead);
+ }
+
+ @Override
+ public boolean complete() {
+ return bytesRemainingInFrame == 0;
+ }
+
+ @Override
+ public void close() {
+ // Do nothing.
+ }
+ }
+
+ /**
+ * A compressed frame that inflates the data as it reads from the frame.
+ */
+ private class CompressedFrame implements Frame {
+ private final InputStream in;
+ private ByteBuf nextBuf;
+
+ public CompressedFrame(int frameLength) {
+ // Limit the stream by the frameLength.
+ in = new InflaterInputStream(new GrowableByteBufInputStream(frameLength));
+ }
+
+ @Override
+ @Nullable
+ public ByteBuf readBytes(int maxLength) {
+
+ // If the pre-existing nextBuf is too small, release it.
+ if (nextBuf != null && nextBuf.capacity() < maxLength) {
+ nextBuf.release();
+ nextBuf = null;
+ }
+
+ if (nextBuf == null) {
+ nextBuf = alloc.buffer();
+ }
+
+ try {
+ int bytesToWrite = Math.min(maxLength, nextBuf.writableBytes());
+ nextBuf.writeBytes(in, bytesToWrite);
+ } catch (EOFException e) {
+ // The next compressed block is unavailable at the moment. Nothing to return.
+ return null;
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+
+ if (!nextBuf.isReadable()) {
+ throw new AssertionError("Read zero bytes from the compression frame");
+ }
+
+ ByteBuf ret = nextBuf;
+ nextBuf = null;
+ return ret;
+ }
+
+ @Override
+ public boolean complete() {
+ try {
+ return in.available() <= 0;
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ @Override
+ public void close() {
+ Closeables.closeQuietly(in);
+ }
+ }
+
+ /**
+ * A stream backed by the {@link #buffer}, which allows for additional reading as the buffer
+ * grows. Not using Netty's stream class since it doesn't handle growth of the underlying buffer.
+ */
+ private class GrowableByteBufInputStream extends InputStream {
+ final int startIndex;
+ final int endIndex;
+
+ GrowableByteBufInputStream(int length) {
+ startIndex = buffer.readerIndex();
+ endIndex = startIndex + length;
+ }
+
+ @Override
+ public int read() throws IOException {
+ if (available() == 0) {
+ return -1;
+ }
+ return buffer.readByte() & 0xff;
+ }
+
+ @Override
+ public int read(byte[] b, int off, int len) throws IOException {
+ int available = available();
+ if (available == 0) {
+ return -1;
+ }
+
+ len = Math.min(available, len);
+ buffer.readBytes(b, off, len);
+ return len;
+ }
+
+ @Override
+ public int available() throws IOException {
+ return Math.min(endIndex - buffer.readerIndex(), buffer.readableBytes());
+ }
+ }
+}
diff --git a/core/src/test/java/com/google/net/stubby/newtransport/GrpcDeframerTest.java b/core/src/test/java/com/google/net/stubby/newtransport/GrpcDeframerTest.java
new file mode 100644
index 0000000..81562cf
--- /dev/null
+++ b/core/src/test/java/com/google/net/stubby/newtransport/GrpcDeframerTest.java
@@ -0,0 +1,280 @@
+package com.google.net.stubby.newtransport;
+
+import static com.google.net.stubby.newtransport.TransportFrameUtil.CONTEXT_VALUE_FRAME;
+import static com.google.net.stubby.newtransport.TransportFrameUtil.PAYLOAD_FRAME;
+import static com.google.net.stubby.newtransport.TransportFrameUtil.STATUS_FRAME;
+import static java.nio.charset.StandardCharsets.UTF_8;
+import static org.junit.Assert.assertEquals;
+import static org.mockito.Matchers.any;
+import static org.mockito.Matchers.anyInt;
+import static org.mockito.Matchers.eq;
+import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.verify;
+
+import com.google.common.io.ByteStreams;
+import com.google.net.stubby.Status;
+import com.google.net.stubby.transport.Transport;
+import com.google.protobuf.ByteString;
+
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+import org.mockito.ArgumentCaptor;
+import org.mockito.Mock;
+import org.mockito.MockitoAnnotations;
+
+import java.io.ByteArrayOutputStream;
+import java.io.DataOutputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.util.Arrays;
+
+import javax.annotation.Nullable;
+
+/**
+ * Tests for {@link GrpcDeframer}.
+ */
+@RunWith(JUnit4.class)
+public class GrpcDeframerTest {
+ private static final String KEY = "key";
+ private static final String MESSAGE = "hello world";
+ private static final ByteString MESSAGE_BSTR = ByteString.copyFromUtf8(MESSAGE);
+ private static final Transport.Code STATUS_CODE = Transport.Code.CANCELLED;
+
+ private GrpcDeframer reader;
+
+ private Transport.ContextValue contextProto;
+
+ private StubDecompressor decompressor;
+
+ @Mock
+ private GrpcMessageListener listener;
+
+ @Before
+ public void setup() {
+ MockitoAnnotations.initMocks(this);
+ decompressor = new StubDecompressor();
+ reader = new GrpcDeframer(decompressor, listener);
+
+ contextProto = Transport.ContextValue.newBuilder().setKey(KEY).setValue(MESSAGE_BSTR).build();
+ }
+
+ @Test
+ public void contextShouldCallTarget() throws Exception {
+ decompressor.init(contextFrame());
+ reader.deframe(Buffers.empty(), false);
+ verifyContext();
+ verifyNoPayload();
+ verifyNoStatus();
+ }
+
+ @Test
+ public void contextWithEndOfStreamShouldNotifyStatus() throws Exception {
+ decompressor.init(contextFrame());
+ reader.deframe(Buffers.empty(), true);
+ verifyContext();
+ verifyNoPayload();
+ verifyStatus(Transport.Code.OK);
+ }
+
+ @Test
+ public void payloadShouldCallTarget() throws Exception {
+ decompressor.init(payloadFrame());
+ reader.deframe(Buffers.empty(), false);
+ verifyNoContext();
+ verifyPayload();
+ verifyNoStatus();
+ }
+
+ @Test
+ public void payloadWithEndOfStreamShouldNotifyStatus() throws Exception {
+ decompressor.init(payloadFrame());
+ reader.deframe(Buffers.empty(), true);
+ verifyNoContext();
+ verifyPayload();
+ verifyStatus(Transport.Code.OK);
+ }
+
+ @Test
+ public void statusShouldCallTarget() throws Exception {
+ decompressor.init(statusFrame());
+ reader.deframe(Buffers.empty(), false);
+ verifyNoContext();
+ verifyNoPayload();
+ verifyStatus();
+ }
+
+ @Test
+ public void statusWithEndOfStreamShouldNotifyStatusOnce() throws Exception {
+ decompressor.init(statusFrame());
+ reader.deframe(Buffers.empty(), true);
+ verifyNoContext();
+ verifyNoPayload();
+ verifyStatus();
+ }
+
+ @Test
+ public void multipleFramesShouldCallTarget() throws Exception {
+ ByteArrayOutputStream os = new ByteArrayOutputStream();
+ DataOutputStream dos = new DataOutputStream(os);
+
+ // Write a context frame.
+ writeFrame(CONTEXT_VALUE_FRAME, contextProto.toByteArray(), dos);
+
+ // Write a payload frame.
+ writeFrame(PAYLOAD_FRAME, MESSAGE_BSTR.toByteArray(), dos);
+
+ // Write a status frame.
+ byte[] statusBytes = new byte[] {0, (byte) STATUS_CODE.getNumber()};
+ writeFrame(STATUS_FRAME, statusBytes, dos);
+
+ // Now write the complete frame: compression header followed by the 3 message frames.
+ dos.close();
+ byte[] bodyBytes = os.toByteArray();
+
+ decompressor.init(bodyBytes);
+ reader.deframe(Buffers.empty(), false);
+
+ // Verify that all callbacks were called.
+ verifyContext();
+ verifyPayload();
+ verifyStatus();
+ }
+
+ @Test
+ public void partialFrameShouldSucceed() throws Exception {
+ byte[] frame = payloadFrame();
+
+ // Create a buffer that contains 2 payload frames.
+ byte[] fullBuffer = Arrays.copyOf(frame, frame.length * 2);
+ System.arraycopy(frame, 0, fullBuffer, frame.length, frame.length);
+
+ // Use only a portion of the frame. Should not call the listener.
+ int startIx = 0;
+ int endIx = 10;
+ byte[] chunk = Arrays.copyOfRange(fullBuffer, startIx, endIx);
+ decompressor.init(chunk);
+ reader.deframe(Buffers.empty(), false);
+ verifyNoContext();
+ verifyNoPayload();
+ verifyNoStatus();
+
+ // Supply the rest of the frame and a portion of a second frame. Should call the listener.
+ startIx = endIx;
+ endIx = startIx + frame.length;
+ chunk = Arrays.copyOfRange(fullBuffer, startIx, endIx);
+ decompressor.init(chunk);
+ reader.deframe(Buffers.empty(), false);
+ verifyNoContext();
+ verifyPayload();
+ verifyNoStatus();
+ }
+
+ private void verifyContext() {
+ ArgumentCaptor<InputStream> captor = ArgumentCaptor.forClass(InputStream.class);
+ verify(listener).onContext(eq(KEY), captor.capture(), eq(MESSAGE.length()));
+ assertEquals(MESSAGE, readString(captor.getValue(), MESSAGE.length()));
+ }
+
+ private void verifyPayload() {
+ ArgumentCaptor<InputStream> captor = ArgumentCaptor.forClass(InputStream.class);
+ verify(listener).onPayload(captor.capture(), eq(MESSAGE.length()));
+ assertEquals(MESSAGE, readString(captor.getValue(), MESSAGE.length()));
+ }
+
+ private String readString(InputStream in, int length) {
+ try {
+ byte[] bytes = new byte[length];
+ ByteStreams.readFully(in, bytes);
+ return new String(bytes, UTF_8);
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ private void verifyStatus() {
+ verifyStatus(Transport.Code.CANCELLED);
+ }
+
+ private void verifyStatus(Transport.Code code) {
+ ArgumentCaptor<Status> captor = ArgumentCaptor.forClass(Status.class);
+ verify(listener).onStatus(captor.capture());
+ assertEquals(code, captor.getValue().getCode());
+ }
+
+ private void verifyNoContext() {
+ verify(listener, never()).onContext(any(String.class), any(InputStream.class), anyInt());
+ }
+
+ private void verifyNoPayload() {
+ verify(listener, never()).onPayload(any(InputStream.class), anyInt());
+ }
+
+ private void verifyNoStatus() {
+ verify(listener, never()).onStatus(any(Status.class));
+ }
+
+ private byte[] contextFrame() throws IOException {
+ return frame(CONTEXT_VALUE_FRAME, contextProto.toByteArray());
+ }
+
+ private static byte[] payloadFrame() throws IOException {
+ return frame(PAYLOAD_FRAME, MESSAGE_BSTR.toByteArray());
+ }
+
+ private static byte[] statusFrame() throws IOException {
+ byte[] bytes = new byte[] {0, (byte) STATUS_CODE.getNumber()};
+ return frame(STATUS_FRAME, bytes);
+ }
+
+ private static byte[] frame(int frameType, byte[] data) throws IOException {
+ ByteArrayOutputStream bos = new ByteArrayOutputStream();
+ OutputStream os = bos;
+ DataOutputStream dos = new DataOutputStream(os);
+ writeFrame(frameType, data, dos);
+ dos.close();
+ return bos.toByteArray();
+ }
+
+ private static void writeFrame(int frameType, byte[] data, DataOutputStream out)
+ throws IOException {
+ out.write(frameType);
+ out.writeInt(data.length);
+ out.write(data);
+ }
+
+ private static final class StubDecompressor implements Decompressor {
+ byte[] bytes;
+ int offset;
+
+ void init(byte[] bytes) {
+ this.bytes = bytes;
+ this.offset = 0;
+ }
+
+ @Override
+ public void decompress(Buffer data) {
+ // Do nothing.
+ }
+
+ @Override
+ public void close() {
+ // Do nothing.
+ }
+
+ @Override
+ @Nullable
+ public Buffer readBytes(int length) {
+ length = Math.min(length, bytes.length - offset);
+ if (length == 0) {
+ return null;
+ }
+
+ Buffer buffer = Buffers.wrap(ByteString.copyFrom(bytes, offset, length));
+ offset += length;
+ return buffer;
+ }
+ }
+}
diff --git a/core/src/test/java/com/google/net/stubby/newtransport/netty/NettyDecompressorTest.java b/core/src/test/java/com/google/net/stubby/newtransport/netty/NettyDecompressorTest.java
new file mode 100644
index 0000000..1edede0
--- /dev/null
+++ b/core/src/test/java/com/google/net/stubby/newtransport/netty/NettyDecompressorTest.java
@@ -0,0 +1,170 @@
+package com.google.net.stubby.newtransport.netty;
+
+import static com.google.net.stubby.newtransport.Buffers.readAsStringUtf8;
+import static com.google.net.stubby.newtransport.TransportFrameUtil.COMPRESSION_HEADER_LENGTH;
+import static com.google.net.stubby.newtransport.TransportFrameUtil.FLATE_FLAG;
+import static com.google.net.stubby.newtransport.TransportFrameUtil.NO_COMPRESS_FLAG;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
+
+import com.google.net.stubby.newtransport.Buffer;
+import com.google.net.stubby.newtransport.Buffers;
+import com.google.net.stubby.newtransport.CompositeBuffer;
+
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+import io.netty.buffer.Unpooled;
+import io.netty.buffer.UnpooledByteBufAllocator;
+
+import java.io.ByteArrayOutputStream;
+import java.io.DataOutputStream;
+import java.nio.charset.StandardCharsets;
+import java.util.Arrays;
+import java.util.zip.DeflaterOutputStream;
+
+/**
+ * Tests for {@link NettyDecompressor}.
+ */
+@RunWith(JUnit4.class)
+public class NettyDecompressorTest {
+ private static final String MESSAGE = "hello world";
+
+ private NettyDecompressor decompressor;
+
+ @Before
+ public void setup() {
+ decompressor = new NettyDecompressor(UnpooledByteBufAllocator.DEFAULT);
+ }
+
+ @Test
+ public void uncompressedDataShouldSucceed() throws Exception {
+ fullMessageShouldSucceed(false);
+ }
+
+ @Test
+ public void compressedDataShouldSucceed() throws Exception {
+ fullMessageShouldSucceed(true);
+ }
+
+ @Test
+ public void uncompressedFrameShouldNotBeReadableUntilComplete() throws Exception {
+ byte[] frame = frame(false);
+ byte[] chunk1 = Arrays.copyOf(frame, 5);
+ byte[] chunk2 = Arrays.copyOfRange(frame, 5, frame.length);
+
+ // Decompress the first chunk and verify it's not readable yet.
+ decompressor.decompress(Buffers.wrap(chunk1));
+
+ CompositeBuffer composite = new CompositeBuffer();
+ Buffer buffer = decompressor.readBytes(2);
+ assertEquals(1, buffer.readableBytes());
+ composite.addBuffer(buffer);
+
+ // Decompress the rest of the frame and verify it's readable.
+ decompressor.decompress(Buffers.wrap(chunk2));
+ composite.addBuffer(decompressor.readBytes(MESSAGE.length() - 1));
+ assertEquals(MESSAGE, readAsStringUtf8(composite));
+ }
+
+ @Test
+ public void compressedFrameShouldNotBeReadableUntilComplete() throws Exception {
+ byte[] frame = frame(true);
+ byte[] chunk1 = Arrays.copyOf(frame, 5);
+ byte[] chunk2 = Arrays.copyOfRange(frame, 5, frame.length);
+
+ // Decompress the first chunk and verify it's not readable yet.
+ decompressor.decompress(Buffers.wrap(chunk1));
+ Buffer buffer = decompressor.readBytes(2);
+ assertNull(buffer);
+
+ // Decompress the rest of the frame and verify it's readable.
+ decompressor.decompress(Buffers.wrap(chunk2));
+ CompositeBuffer composite = new CompositeBuffer();
+ for(int remaining = MESSAGE.length(); remaining > 0; ) {
+ Buffer buf = decompressor.readBytes(remaining);
+ if (buf == null) {
+ break;
+ }
+ composite.addBuffer(buf);
+ remaining -= buf.readableBytes();
+ }
+ assertEquals(MESSAGE, readAsStringUtf8(composite));
+ }
+
+ @Test
+ public void nettyBufferShouldBeReleasedAfterRead() throws Exception {
+ byte[] frame = frame(false);
+ byte[] chunk1 = Arrays.copyOf(frame, 5);
+ byte[] chunk2 = Arrays.copyOfRange(frame, 5, frame.length);
+ NettyBuffer buffer1 = new NettyBuffer(Unpooled.wrappedBuffer(chunk1));
+ NettyBuffer buffer2 = new NettyBuffer(Unpooled.wrappedBuffer(chunk2));
+ // CompositeByteBuf always keeps at least one buffer internally, so we add a second so
+ // that it will release the first after it is read.
+ decompressor.decompress(buffer1);
+ decompressor.decompress(buffer2);
+ NettyBuffer readBuffer = (NettyBuffer) decompressor.readBytes(buffer1.readableBytes());
+ assertEquals(0, buffer1.buffer().refCnt());
+ assertEquals(1, readBuffer.buffer().refCnt());
+ }
+
+ @Test
+ public void closeShouldReleasedBuffers() throws Exception {
+ byte[] frame = frame(false);
+ byte[] chunk1 = Arrays.copyOf(frame, 5);
+ NettyBuffer buffer1 = new NettyBuffer(Unpooled.wrappedBuffer(chunk1));
+ decompressor.decompress(buffer1);
+ assertEquals(1, buffer1.buffer().refCnt());
+ decompressor.close();
+ assertEquals(0, buffer1.buffer().refCnt());
+ }
+
+ private void fullMessageShouldSucceed(boolean compress) throws Exception {
+ // Decompress the entire frame all at once.
+ byte[] frame = frame(compress);
+ decompressor.decompress(Buffers.wrap(frame));
+
+ // Read some bytes and verify.
+ int chunkSize = MESSAGE.length() / 2;
+ assertEquals(MESSAGE.substring(0, chunkSize),
+ readAsStringUtf8(decompressor.readBytes(chunkSize)));
+
+ // Read the rest and verify.
+ assertEquals(MESSAGE.substring(chunkSize),
+ readAsStringUtf8(decompressor.readBytes(MESSAGE.length() - chunkSize)));
+ }
+
+ /**
+ * Creates a compression frame from {@link #MESSAGE}, applying compression if requested.
+ */
+ private byte[] frame(boolean compress) throws Exception {
+ byte[] msgBytes = bytes(MESSAGE);
+ if (compress) {
+ msgBytes = compress(msgBytes);
+ }
+ ByteArrayOutputStream os =
+ new ByteArrayOutputStream(msgBytes.length + COMPRESSION_HEADER_LENGTH);
+ DataOutputStream dos = new DataOutputStream(os);
+ int frameFlag = compress ? FLATE_FLAG : NO_COMPRESS_FLAG;
+ // Header = 1b flag | 3b length of GRPC frame
+ int header = (frameFlag << 24) | msgBytes.length;
+ dos.writeInt(header);
+ dos.write(msgBytes);
+ dos.close();
+ return os.toByteArray();
+ }
+
+ private byte[] bytes(String str) {
+ return str.getBytes(StandardCharsets.UTF_8);
+ }
+
+ private byte[] compress(byte[] data) throws Exception {
+ ByteArrayOutputStream out = new ByteArrayOutputStream();
+ DeflaterOutputStream dos = new DeflaterOutputStream(out);
+ dos.write(data);
+ dos.close();
+ return out.toByteArray();
+ }
+}