Implements the netty-based server transport.
Interface changes
- Moves halfClose() from Stream to ClientStream, because it's only valid on the
client and ServerStream already has close().
Implementation details
- Splits the cient-specific logic from AbstractStream and forms
AbstractClientStream.
- Creates AbstractServerStream for server-specific logic
- Creates NettyServerHandler which is the server counterpart of NettyClientHandler
- Refactors NettyClientHandlerTest and NettyClientStreamTest to share code with
NettyServerHandlerTest and NettyServerStreamTest
- Updated NettyServer to work with the transport.
-------------
Created by MOE: http://code.google.com/p/moe-java
MOE_MIGRATED_REVID=74347890
diff --git a/core/src/main/java/com/google/net/stubby/MethodDescriptor.java b/core/src/main/java/com/google/net/stubby/MethodDescriptor.java
index 482b42f..377ece0 100644
--- a/core/src/main/java/com/google/net/stubby/MethodDescriptor.java
+++ b/core/src/main/java/com/google/net/stubby/MethodDescriptor.java
@@ -127,4 +127,15 @@
ImmutableMap.<String, Provider<String>>builder().
putAll(headers).put(headerName, headerValueProvider).build());
}
+
+ /**
+ * Creates a new descriptor with additional bound headers.
+ */
+ public MethodDescriptor<RequestT, ResponseT> withHeaders(
+ ImmutableMap<String, Provider<String>> additionalHeaders) {
+ return new MethodDescriptor<RequestT, ResponseT>(type, name, timeoutMicros,
+ requestMarshaller, responseMarshaller,
+ ImmutableMap.<String, Provider<String>>builder().
+ putAll(headers).putAll(additionalHeaders).build());
+ }
}
diff --git a/core/src/main/java/com/google/net/stubby/newtransport/AbstractClientStream.java b/core/src/main/java/com/google/net/stubby/newtransport/AbstractClientStream.java
new file mode 100644
index 0000000..4ed5eee
--- /dev/null
+++ b/core/src/main/java/com/google/net/stubby/newtransport/AbstractClientStream.java
@@ -0,0 +1,97 @@
+package com.google.net.stubby.newtransport;
+
+import static com.google.net.stubby.newtransport.StreamState.CLOSED;
+import static com.google.net.stubby.newtransport.StreamState.OPEN;
+import static com.google.net.stubby.newtransport.StreamState.READ_ONLY;
+
+import com.google.common.base.Preconditions;
+import com.google.net.stubby.Status;
+
+import java.io.InputStream;
+
+/**
+ * The abstract base class for {@link ClientStream} implementations.
+ */
+public abstract class AbstractClientStream extends AbstractStream implements ClientStream {
+
+ private final StreamListener listener;
+
+ private Status status;
+
+ private final Object stateLock = new Object();
+ private volatile StreamState state = StreamState.OPEN;
+
+ protected AbstractClientStream(StreamListener listener) {
+ this.listener = Preconditions.checkNotNull(listener);
+ }
+
+ @Override
+ protected final StreamListener listener() {
+ return listener;
+ }
+
+ @Override
+ protected final GrpcMessageListener inboundMessageHandler() {
+ // Wraps the base handler to get status update.
+ final GrpcMessageListener delegate = super.inboundMessageHandler();
+ return new GrpcMessageListener() {
+ @Override
+ public void onContext(String name, InputStream value, int length) {
+ delegate.onContext(name, value, length);
+ }
+
+ @Override
+ public void onPayload(InputStream input, int length) {
+ delegate.onPayload(input, length);
+ }
+
+ @Override
+ public void onStatus(Status status) {
+ delegate.onStatus(status);
+ setStatus(status);
+ }
+ };
+ }
+
+ /**
+ * Sets the status if not already set and notifies the stream listener that the stream was closed.
+ * This method must be called from the transport thread.
+ *
+ * @param newStatus the new status to set
+ * @return {@code} true if the status was not already set.
+ */
+ public boolean setStatus(final Status newStatus) {
+ Preconditions.checkNotNull(newStatus, "newStatus");
+ synchronized (stateLock) {
+ if (status != null) {
+ // Disallow override of current status.
+ return false;
+ }
+
+ status = newStatus;
+ state = CLOSED;
+ }
+
+ // Invoke the observer callback.
+ listener.closed(newStatus);
+
+ // Free any resources.
+ dispose();
+
+ return true;
+ }
+
+ @Override
+ public final void halfClose() {
+ outboundPhase(Phase.STATUS);
+ synchronized (stateLock) {
+ state = state == OPEN ? READ_ONLY : CLOSED;
+ }
+ closeFramer(null);
+ }
+
+ @Override
+ public StreamState state() {
+ return state;
+ }
+}
diff --git a/core/src/main/java/com/google/net/stubby/newtransport/AbstractServerStream.java b/core/src/main/java/com/google/net/stubby/newtransport/AbstractServerStream.java
new file mode 100644
index 0000000..580f28f
--- /dev/null
+++ b/core/src/main/java/com/google/net/stubby/newtransport/AbstractServerStream.java
@@ -0,0 +1,126 @@
+package com.google.net.stubby.newtransport;
+
+import static com.google.net.stubby.newtransport.StreamState.CLOSED;
+import static com.google.net.stubby.newtransport.StreamState.OPEN;
+import static com.google.net.stubby.newtransport.StreamState.WRITE_ONLY;
+
+import com.google.common.base.Preconditions;
+import com.google.net.stubby.Status;
+import com.google.net.stubby.transport.Transport;
+
+import java.io.InputStream;
+
+/**
+ * Abstract base class for {@link ServerStream} implementations.
+ */
+public abstract class AbstractServerStream extends AbstractStream implements ServerStream {
+
+ private StreamListener listener;
+
+ private final Object stateLock = new Object();
+ private volatile StreamState state = StreamState.OPEN;
+
+ @Override
+ protected final StreamListener listener() {
+ return listener;
+ }
+
+ public final void setListener(StreamListener listener) {
+ this.listener = Preconditions.checkNotNull(listener, "listener");
+ }
+
+ @Override
+ protected final GrpcMessageListener inboundMessageHandler() {
+ // Wraps the base handler to get status update.
+ final GrpcMessageListener delegate = super.inboundMessageHandler();
+ return new GrpcMessageListener() {
+ @Override
+ public void onContext(String name, InputStream value, int length) {
+ delegate.onContext(name, value, length);
+ }
+
+ @Override
+ public void onPayload(InputStream input, int length) {
+ delegate.onPayload(input, length);
+ }
+
+ @Override
+ public void onStatus(Status status) {
+ delegate.onStatus(status);
+ listener.closed(status);
+ }
+ };
+ }
+
+ @Override
+ public final void close(Status status) {
+ synchronized (stateLock) {
+ Preconditions.checkState(!status.isOk() || state == WRITE_ONLY,
+ "Cannot close with OK before client half-closes");
+ state = CLOSED;
+ }
+ outboundPhase(Phase.STATUS);
+ closeFramer(status);
+ dispose();
+ }
+
+ @Override
+ public StreamState state() {
+ return state;
+ }
+
+ /**
+ * Called when the remote end half-closes the stream.
+ */
+ public final void remoteEndClosed() {
+ StreamState previousState;
+ synchronized (stateLock) {
+ previousState = state;
+ if (previousState == OPEN) {
+ state = WRITE_ONLY;
+ }
+ }
+ if (previousState == OPEN) {
+ inboundPhase(Phase.STATUS);
+ listener.closed(Status.OK);
+ } else {
+ abortStream(
+ new Status(Transport.Code.FAILED_PRECONDITION, "Client-end of the stream already closed"),
+ true);
+ }
+ }
+
+ /**
+ * Aborts the stream with an error status, cleans up resources and notifies the listener if
+ * necessary.
+ *
+ * <p>Unlike {@link #close(Status)}, this method is only called from the gRPC framework, so that
+ * we need to call closed() on the listener if it has not been called.
+ *
+ * @param status the error status. Must not be Status.OK.
+ * @param notifyClient true if the stream is still writable and you want to notify the client
+ * about stream closure and send the status
+ */
+ public final void abortStream(Status status, boolean notifyClient) {
+ Preconditions.checkArgument(!status.isOk(), "status must not be OK");
+ StreamState previousState;
+ synchronized (stateLock) {
+ previousState = state;
+ if (state == CLOSED) {
+ return;
+ }
+ state = CLOSED;
+ }
+
+ if (previousState == OPEN) {
+ listener.closed(status);
+ } // Otherwise, previousState is WRITE_ONLY thus closed() has already been called.
+
+ outboundPhase(Phase.STATUS);
+ if (notifyClient) {
+ closeFramer(status);
+ }
+
+ dispose();
+ }
+}
diff --git a/core/src/main/java/com/google/net/stubby/newtransport/AbstractStream.java b/core/src/main/java/com/google/net/stubby/newtransport/AbstractStream.java
index 7321b8e..40d853a 100644
--- a/core/src/main/java/com/google/net/stubby/newtransport/AbstractStream.java
+++ b/core/src/main/java/com/google/net/stubby/newtransport/AbstractStream.java
@@ -1,9 +1,5 @@
package com.google.net.stubby.newtransport;
-import static com.google.net.stubby.newtransport.StreamState.CLOSED;
-import static com.google.net.stubby.newtransport.StreamState.OPEN;
-import static com.google.net.stubby.newtransport.StreamState.READ_ONLY;
-
import com.google.common.base.Preconditions;
import com.google.common.io.Closeables;
import com.google.common.util.concurrent.ListenableFuture;
@@ -27,12 +23,8 @@
CONTEXT, MESSAGE, STATUS
}
- private volatile StreamState state = StreamState.OPEN;
- private Status status;
- private final Object stateLock = new Object();
private final Object writeLock = new Object();
private final MessageFramer framer;
- private final StreamListener listener;
protected Phase inboundPhase = Phase.CONTEXT;
protected Phase outboundPhase = Phase.CONTEXT;
@@ -55,7 +47,7 @@
ListenableFuture<Void> future = null;
try {
inboundPhase(Phase.CONTEXT);
- future = listener.contextRead(name, value, length);
+ future = listener().contextRead(name, value, length);
} finally {
closeWhenDone(future, value);
}
@@ -66,7 +58,7 @@
ListenableFuture<Void> future = null;
try {
inboundPhase(Phase.MESSAGE);
- future = listener.messageRead(input, length);
+ future = listener().messageRead(input, length);
} finally {
closeWhenDone(future, input);
}
@@ -75,34 +67,15 @@
@Override
public void onStatus(Status status) {
inboundPhase(Phase.STATUS);
- setStatus(status);
}
};
- protected AbstractStream(StreamListener listener) {
- this.listener = Preconditions.checkNotNull(listener, "listener");
-
+ protected AbstractStream() {
framer = new MessageFramer(outboundFrameHandler, 4096);
// No compression at the moment.
framer.setAllowCompression(false);
}
- @Override
- public StreamState state() {
- return state;
- }
-
- @Override
- public final void halfClose() {
- outboundPhase(Phase.STATUS);
- synchronized (stateLock) {
- state = state == OPEN ? READ_ONLY : CLOSED;
- }
- synchronized (writeLock) {
- framer.close();
- }
- }
-
/**
* Free any resources associated with this stream. Subclass implementations must call this
* version.
@@ -159,35 +132,7 @@
}
/**
- * Sets the status if not already set and notifies the stream listener that the stream was closed.
- * This method must be called from the transport thread.
- *
- * @param newStatus the new status to set
- * @return {@code} true if the status was not already set.
- */
- public boolean setStatus(final Status newStatus) {
- Preconditions.checkNotNull(newStatus, "newStatus");
- synchronized (stateLock) {
- if (status != null) {
- // Disallow override of current status.
- return false;
- }
-
- status = newStatus;
- state = CLOSED;
- }
-
- // Invoke the observer callback.
- listener.closed(newStatus);
-
- // Free any resources.
- dispose();
-
- return true;
- }
-
- /**
- * Sends an outbound frame to the server.
+ * Sends an outbound frame to the remote end point.
*
* @param frame a buffer containing the chunk of data to be sent.
* @param endOfStream if {@code true} indicates that no more data will be sent on the stream by
@@ -196,10 +141,15 @@
protected abstract void sendFrame(ByteBuffer frame, boolean endOfStream);
/**
+ * Returns the listener associated to this stream.
+ */
+ protected abstract StreamListener listener();
+
+ /**
* Gets the handler for inbound messages. Subclasses must use this as the target for a
* {@link com.google.net.stubby.newtransport.Deframer}.
*/
- protected final GrpcMessageListener inboundMessageHandler() {
+ protected GrpcMessageListener inboundMessageHandler() {
return inboundMessageHandler;
}
@@ -219,6 +169,24 @@
outboundPhase = verifyNextPhase(outboundPhase, nextPhase);
}
+ /**
+ * Closes the underlying framer.
+ *
+ * <p>No-op if the framer has already been closed.
+ *
+ * @param status if not null, will write the status to the framer before closing it
+ */
+ protected final void closeFramer(@Nullable Status status) {
+ synchronized (writeLock) {
+ if (!framer.isClosed()) {
+ if (status != null) {
+ framer.writeStatus(status);
+ }
+ framer.close();
+ }
+ }
+ }
+
private Phase verifyNextPhase(Phase currentPhase, Phase nextPhase) {
if (nextPhase.ordinal() < currentPhase.ordinal() || currentPhase == Phase.STATUS) {
throw new IllegalStateException(
diff --git a/core/src/main/java/com/google/net/stubby/newtransport/ClientStream.java b/core/src/main/java/com/google/net/stubby/newtransport/ClientStream.java
index 5f9df75..c5981f4 100644
--- a/core/src/main/java/com/google/net/stubby/newtransport/ClientStream.java
+++ b/core/src/main/java/com/google/net/stubby/newtransport/ClientStream.java
@@ -11,4 +11,12 @@
* period until {@link StreamListener#closed} is called.
*/
void cancel();
+
+ /**
+ * Closes the local side of this stream and flushes any remaining messages. After this is called,
+ * no further messages may be sent on this stream, but additional messages may be received until
+ * the remote end-point is closed.
+ */
+ void halfClose();
+
}
diff --git a/core/src/main/java/com/google/net/stubby/newtransport/Stream.java b/core/src/main/java/com/google/net/stubby/newtransport/Stream.java
index 5ed9770..77dc1cc 100644
--- a/core/src/main/java/com/google/net/stubby/newtransport/Stream.java
+++ b/core/src/main/java/com/google/net/stubby/newtransport/Stream.java
@@ -17,13 +17,6 @@
StreamState state();
/**
- * Closes the local side of this stream and flushes any remaining messages. After this is called,
- * no further messages may be sent on this stream, but additional messages may be received until
- * the remote end-point is closed.
- */
- void halfClose();
-
- /**
* Writes the context name/value pair to the remote end-point. The bytes from the stream are
* immediate read by the Transport. This method will always return immediately and will not wait
* for the write to complete.
diff --git a/core/src/main/java/com/google/net/stubby/newtransport/TransportFrameUtil.java b/core/src/main/java/com/google/net/stubby/newtransport/TransportFrameUtil.java
index 8ccc12c..b80e34c 100644
--- a/core/src/main/java/com/google/net/stubby/newtransport/TransportFrameUtil.java
+++ b/core/src/main/java/com/google/net/stubby/newtransport/TransportFrameUtil.java
@@ -1,5 +1,6 @@
package com.google.net.stubby.newtransport;
+import javax.annotation.Nullable;
/**
* Utility functions for transport layer framing.
@@ -72,5 +73,19 @@
return (flags & FRAME_TYPE_MASK) == STATUS_FRAME;
}
+ // TODO(user): This needs proper namespacing support, this is currently just a hack
+ /**
+ * Converts the path from the HTTP request to the full qualified method name.
+ *
+ * @return null if the path is malformatted.
+ */
+ @Nullable
+ public static String getFullMethodNameFromPath(String path) {
+ if (!path.startsWith("/")) {
+ return null;
+ }
+ return path.substring(1);
+ }
+
private TransportFrameUtil() {}
}
diff --git a/core/src/main/java/com/google/net/stubby/newtransport/http/HttpClientTransport.java b/core/src/main/java/com/google/net/stubby/newtransport/http/HttpClientTransport.java
index 86fad43..d1b4bc7 100644
--- a/core/src/main/java/com/google/net/stubby/newtransport/http/HttpClientTransport.java
+++ b/core/src/main/java/com/google/net/stubby/newtransport/http/HttpClientTransport.java
@@ -9,8 +9,8 @@
import com.google.common.io.ByteBuffers;
import com.google.net.stubby.MethodDescriptor;
import com.google.net.stubby.Status;
+import com.google.net.stubby.newtransport.AbstractClientStream;
import com.google.net.stubby.newtransport.AbstractClientTransport;
-import com.google.net.stubby.newtransport.AbstractStream;
import com.google.net.stubby.newtransport.ClientStream;
import com.google.net.stubby.newtransport.InputStreamDeframer;
import com.google.net.stubby.newtransport.StreamListener;
@@ -75,7 +75,7 @@
/**
* Client stream implementation for an HTTP transport.
*/
- private class HttpClientStream extends AbstractStream implements ClientStream {
+ private class HttpClientStream extends AbstractClientStream {
final HttpURLConnection connection;
final DataOutputStream outputStream;
boolean connected;
diff --git a/core/src/main/java/com/google/net/stubby/newtransport/netty/NettyClientHandler.java b/core/src/main/java/com/google/net/stubby/newtransport/netty/NettyClientHandler.java
index b3d2977..75a1cbf 100644
--- a/core/src/main/java/com/google/net/stubby/newtransport/netty/NettyClientHandler.java
+++ b/core/src/main/java/com/google/net/stubby/newtransport/netty/NettyClientHandler.java
@@ -234,8 +234,7 @@
*/
private void sendGrpcFrame(ChannelHandlerContext ctx, SendGrpcFrameCommand cmd,
ChannelPromise promise) throws Http2Exception {
- NettyClientStream stream = cmd.stream();
- Http2Stream http2Stream = connection().requireStream(stream.id());
+ Http2Stream http2Stream = connection().requireStream(cmd.streamId());
switch (http2Stream.state()) {
case CLOSED:
case HALF_CLOSED_LOCAL:
@@ -250,7 +249,7 @@
}
// Call the base class to write the HTTP/2 DATA frame.
- writeData(ctx, stream.id(), cmd.content(), 0, cmd.endStream(), promise);
+ writeData(ctx, cmd.streamId(), cmd.content(), 0, cmd.endStream(), promise);
}
/**
diff --git a/core/src/main/java/com/google/net/stubby/newtransport/netty/NettyClientStream.java b/core/src/main/java/com/google/net/stubby/newtransport/netty/NettyClientStream.java
index fc15c8d..888637d 100644
--- a/core/src/main/java/com/google/net/stubby/newtransport/netty/NettyClientStream.java
+++ b/core/src/main/java/com/google/net/stubby/newtransport/netty/NettyClientStream.java
@@ -5,8 +5,7 @@
import com.google.common.base.Preconditions;
import com.google.net.stubby.Status;
-import com.google.net.stubby.newtransport.AbstractStream;
-import com.google.net.stubby.newtransport.ClientStream;
+import com.google.net.stubby.newtransport.AbstractClientStream;
import com.google.net.stubby.newtransport.GrpcDeframer;
import com.google.net.stubby.newtransport.HttpUtil;
import com.google.net.stubby.newtransport.StreamListener;
@@ -23,7 +22,7 @@
/**
* Client stream for a Netty transport.
*/
-class NettyClientStream extends AbstractStream implements ClientStream {
+class NettyClientStream extends AbstractClientStream implements NettyStream {
public static final int PENDING_STREAM_ID = -1;
private volatile int id = PENDING_STREAM_ID;
@@ -43,6 +42,7 @@
/**
* Returns the HTTP/2 ID for this stream.
*/
+ @Override
public int id() {
return id;
}
@@ -70,13 +70,7 @@
}
}
- /**
- * Called in the channel thread to process the content of an inbound DATA frame.
- *
- * @param frame the inbound HTTP/2 DATA frame. If this buffer is not used immediately, it must be
- * retained.
- * @param promise the promise to be set after the application has finished processing the frame.
- */
+ @Override
public void inboundDataReceived(ByteBuf frame, boolean endOfStream, ChannelPromise promise) {
Preconditions.checkNotNull(frame, "frame");
Preconditions.checkNotNull(promise, "promise");
@@ -107,20 +101,12 @@
@Override
protected void sendFrame(ByteBuffer frame, boolean endOfStream) {
- SendGrpcFrameCommand cmd = new SendGrpcFrameCommand(this, toByteBuf(frame), endOfStream);
+ SendGrpcFrameCommand cmd = new SendGrpcFrameCommand(id(),
+ Utils.toByteBuf(channel.alloc(), frame), endOfStream);
channel.writeAndFlush(cmd);
}
/**
- * Copies the content of the given {@link ByteBuffer} to a new {@link ByteBuf} instance.
- */
- private ByteBuf toByteBuf(ByteBuffer source) {
- ByteBuf buf = channel.alloc().buffer(source.remaining());
- buf.writeBytes(source);
- return buf;
- }
-
- /**
* Determines whether or not the response from the server is a GRPC response.
*/
private static boolean isGrpcResponse(Http2Headers headers, Transport.Code code) {
diff --git a/core/src/main/java/com/google/net/stubby/newtransport/netty/NettyServer.java b/core/src/main/java/com/google/net/stubby/newtransport/netty/NettyServer.java
index 630cdd7..a5769d1 100644
--- a/core/src/main/java/com/google/net/stubby/newtransport/netty/NettyServer.java
+++ b/core/src/main/java/com/google/net/stubby/newtransport/netty/NettyServer.java
@@ -5,6 +5,8 @@
import com.google.common.base.Preconditions;
import com.google.common.util.concurrent.AbstractService;
+import com.google.net.stubby.newtransport.ServerListener;
+import com.google.net.stubby.newtransport.ServerTransportListener;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.Channel;
@@ -15,6 +17,8 @@
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioServerSocketChannel;
+import io.netty.handler.codec.http2.DefaultHttp2Connection;
+import io.netty.handler.codec.http2.DefaultHttp2StreamRemovalPolicy;
/**
* Implementation of the {@link com.google.common.util.concurrent.Service} interface for a
@@ -27,18 +31,25 @@
private final EventLoopGroup workerGroup;
private Channel channel;
- public NettyServer(int port, ChannelInitializer<SocketChannel> channelInitializer) {
- this(port, channelInitializer, new NioEventLoopGroup(), new NioEventLoopGroup());
+ public NettyServer(ServerListener serverListener, int port) {
+ this(serverListener, port, new NioEventLoopGroup(), new NioEventLoopGroup());
}
- public NettyServer(int port, ChannelInitializer<SocketChannel> channelInitializer,
- EventLoopGroup bossGroup, EventLoopGroup workerGroup) {
- Preconditions.checkNotNull(channelInitializer, "channelInitializer");
+ public NettyServer(final ServerListener serverListener, int port, EventLoopGroup bossGroup,
+ EventLoopGroup workerGroup) {
Preconditions.checkNotNull(bossGroup, "bossGroup");
Preconditions.checkNotNull(workerGroup, "workerGroup");
Preconditions.checkArgument(port >= 0, "port must be positive");
this.port = port;
- this.channelInitializer = channelInitializer;
+ this.channelInitializer = new ChannelInitializer<SocketChannel>() {
+ @Override
+ public void initChannel(SocketChannel ch) throws Exception {
+ // TODO(user): pass a real transport object
+ ServerTransportListener transportListener = serverListener.transportCreated(null);
+ ch.pipeline().addLast(new NettyServerHandler(transportListener,
+ new DefaultHttp2Connection(true, new DefaultHttp2StreamRemovalPolicy())));
+ }
+ };
this.bossGroup = bossGroup;
this.workerGroup = workerGroup;
}
diff --git a/core/src/main/java/com/google/net/stubby/newtransport/netty/NettyServerHandler.java b/core/src/main/java/com/google/net/stubby/newtransport/netty/NettyServerHandler.java
new file mode 100644
index 0000000..6fbe4fa
--- /dev/null
+++ b/core/src/main/java/com/google/net/stubby/newtransport/netty/NettyServerHandler.java
@@ -0,0 +1,223 @@
+package com.google.net.stubby.newtransport.netty;
+
+import static com.google.net.stubby.newtransport.HttpUtil.CONTENT_TYPE_HEADER;
+import static com.google.net.stubby.newtransport.HttpUtil.CONTENT_TYPE_PROTORPC;
+import static com.google.net.stubby.newtransport.HttpUtil.HTTP_METHOD;
+
+import com.google.common.collect.ImmutableMap;
+import com.google.net.stubby.MethodDescriptor;
+import com.google.net.stubby.Status;
+import com.google.net.stubby.newtransport.ServerTransportListener;
+import com.google.net.stubby.newtransport.StreamListener;
+import com.google.net.stubby.newtransport.TransportFrameUtil;
+import com.google.net.stubby.transport.Transport;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.ChannelPromise;
+import io.netty.handler.codec.http2.AbstractHttp2ConnectionHandler;
+import io.netty.handler.codec.http2.DefaultHttp2Headers;
+import io.netty.handler.codec.http2.Http2Connection;
+import io.netty.handler.codec.http2.Http2ConnectionAdapter;
+import io.netty.handler.codec.http2.Http2Error;
+import io.netty.handler.codec.http2.Http2Exception;
+import io.netty.handler.codec.http2.Http2Headers;
+import io.netty.handler.codec.http2.Http2Stream;
+import io.netty.handler.codec.http2.Http2StreamException;
+import io.netty.util.ReferenceCountUtil;
+
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.TimeUnit;
+import java.util.logging.Level;
+import java.util.logging.Logger;
+
+import javax.inject.Provider;
+
+/**
+ * Server-side Netty handler for GRPC processing. All event handlers are executed entirely within
+ * the context of the Netty Channel thread.
+ */
+class NettyServerHandler extends AbstractHttp2ConnectionHandler {
+
+ private static Logger logger = Logger.getLogger(NettyServerHandler.class.getName());
+
+ private static final Status GOAWAY_STATUS = new Status(Transport.Code.UNAVAILABLE);
+
+ private final ServerTransportListener transportListener;
+
+ NettyServerHandler(ServerTransportListener transportListener, Http2Connection connection) {
+ super(connection);
+ this.transportListener = transportListener;
+
+ // Observe the HTTP/2 connection for events.
+ connection.addListener(new Http2ConnectionAdapter() {
+ @Override
+ public void streamHalfClosed(Http2Stream stream) {
+ if (stream.state() == Http2Stream.State.HALF_CLOSED_REMOTE) {
+ serverStream(stream).remoteEndClosed();
+ }
+ }
+ });
+ }
+
+ @Override
+ public void onHeadersRead(ChannelHandlerContext ctx,
+ int streamId,
+ Http2Headers headers,
+ int streamDependency,
+ short weight,
+ boolean exclusive,
+ int padding,
+ boolean endStream) throws Http2Exception {
+ try {
+ NettyServerStream stream = new NettyServerStream(ctx.channel(), streamId);
+ // The Http2Stream object was put by AbstractHttp2ConnectionHandler before calling this method.
+ Http2Stream http2Stream = connection().requireStream(streamId);
+ http2Stream.data(stream);
+ MethodDescriptor<?, ?> method = createMethod(streamId, headers);
+ StreamListener listener = transportListener.streamCreated(stream, method);
+ stream.setListener(listener);
+ } catch (Http2Exception e) {
+ throw e;
+ } catch (Throwable e) {
+ logger.log(Level.WARNING, "Exception in onHeadersRead()", e);
+ throw new Http2StreamException(streamId, Http2Error.INTERNAL_ERROR, e.toString());
+ }
+ }
+
+ @Override
+ public void onDataRead(ChannelHandlerContext ctx,
+ int streamId,
+ ByteBuf data,
+ int padding,
+ boolean endOfStream) throws Http2Exception {
+ try {
+ NettyServerStream stream = serverStream(connection().requireStream(streamId));
+ // TODO(user): update flow controller to use a promise
+ stream.inboundDataReceived(data, endOfStream, ctx.newPromise());
+ } catch (Http2Exception e) {
+ throw e;
+ } catch (Throwable e) {
+ logger.log(Level.WARNING, "Exception in onDataRead()", e);
+ throw new Http2StreamException(streamId, Http2Error.INTERNAL_ERROR, e.toString());
+ }
+ }
+
+ @Override
+ public void onRstStreamRead(ChannelHandlerContext ctx, int streamId, long errorCode)
+ throws Http2Exception {
+ try {
+ NettyServerStream stream = serverStream(connection().requireStream(streamId));
+ stream.abortStream(Status.CANCELLED, false);
+ } catch (Http2Exception e) {
+ throw e;
+ } catch (Throwable e) {
+ logger.log(Level.WARNING, "Exception in onRstStreamRead()", e);
+ throw new Http2StreamException(streamId, Http2Error.INTERNAL_ERROR, e.toString());
+ }
+ }
+
+ /**
+ * Handler for stream errors that have occurred during HTTP/2 frame processing.
+ *
+ * <p>When a callback method of this class throws an Http2StreamException,
+ * it will be handled by this method. Other types of exceptions will be handled by
+ * {@link #onConnectionError(ChannelHandlerContext, Http2Exception)} from the base class. The
+ * catch-all logic is in {@link #decode(ChannelHandlerContext, ByteBuf, List)} from the base class.
+ */
+ @Override
+ protected void onStreamError(ChannelHandlerContext ctx, Http2StreamException cause) {
+ // Aborts the stream with a status that contains the cause.
+ Http2Stream stream = connection().stream(cause.streamId());
+ if (stream != null) {
+ // Send the error message to the client to help debugging.
+ serverStream(stream).abortStream(Status.fromThrowable(cause), true);
+ } else {
+ // Only call the base class if we cannot anything about it.
+ super.onStreamError(ctx, cause);
+ }
+ }
+
+ /**
+ * Handler for the Channel shutting down
+ */
+ @Override
+ public void channelInactive(ChannelHandlerContext ctx) throws Exception {
+ super.channelInactive(ctx);
+ // Any streams that are still active must be closed
+ for (Http2Stream stream : connection().activeStreams()) {
+ serverStream(stream).abortStream(GOAWAY_STATUS, false);
+ }
+ }
+
+ /**
+ * Handler for commands sent from the stream.
+ */
+ @Override
+ public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) {
+ if (msg instanceof SendGrpcFrameCommand) {
+ SendGrpcFrameCommand cmd = (SendGrpcFrameCommand) msg;
+ // Call the base class to write the HTTP/2 DATA frame.
+ writeData(ctx, cmd.streamId(), cmd.content(), 0, cmd.endStream(), promise);
+ } else if (msg instanceof SendResponseHeadersCommand) {
+ SendResponseHeadersCommand cmd = (SendResponseHeadersCommand) msg;
+ writeHeaders(
+ ctx, cmd.streamId(),
+ DefaultHttp2Headers.newBuilder()
+ .status("200")
+ .set(CONTENT_TYPE_HEADER, CONTENT_TYPE_PROTORPC)
+ .build(),
+ 0, false, promise);
+ } else {
+ AssertionError e = new AssertionError("Write called for unexpected type: "
+ + msg.getClass().getName());
+ ReferenceCountUtil.release(msg);
+ promise.setFailure(e);
+ throw e;
+ }
+ }
+
+ private MethodDescriptor<?, ?> createMethod(int streamId, Http2Headers headers)
+ throws Http2StreamException {
+ if (!HTTP_METHOD.equals(headers.method())) {
+ throw new Http2StreamException(streamId, Http2Error.REFUSED_STREAM,
+ String.format("Method '%s' is not supported", headers.method()));
+ }
+ if (!CONTENT_TYPE_PROTORPC.equals(headers.get(CONTENT_TYPE_HEADER))) {
+ throw new Http2StreamException(streamId, Http2Error.REFUSED_STREAM,
+ String.format("Header '%s'='%s', while '%s' is expected", CONTENT_TYPE_HEADER,
+ headers.get(CONTENT_TYPE_HEADER), CONTENT_TYPE_PROTORPC));
+ }
+ String methodName = TransportFrameUtil.getFullMethodNameFromPath(headers.path());
+ if (methodName == null) {
+ throw new Http2StreamException(streamId, Http2Error.REFUSED_STREAM,
+ String.format("Malformatted path: %s", headers.path()));
+ }
+ // TODO(user): pass the real timeout
+ MethodDescriptor<?, ?> method = MethodDescriptor.create(
+ MethodDescriptor.Type.UNKNOWN, methodName, 1, TimeUnit.SECONDS, null, null);
+ ImmutableMap.Builder<String, Provider<String>> grpcHeaders =
+ new ImmutableMap.Builder<String, Provider<String>>();
+ for (Map.Entry<String, String> header : headers) {
+ if (!header.getKey().startsWith(":")) {
+ final String value = header.getValue();
+ // headers starting with ":" are reserved for HTTP/2 built-in headers
+ grpcHeaders.put(header.getKey(), new Provider<String>() {
+ @Override
+ public String get() {
+ return value;
+ }
+ });
+ }
+ }
+ return method.withHeaders(grpcHeaders.build());
+ }
+
+ /**
+ * Returns the server stream associated to the given HTTP/2 stream object
+ */
+ private NettyServerStream serverStream(Http2Stream stream) {
+ return stream.<NettyServerStream>data();
+ }
+}
diff --git a/core/src/main/java/com/google/net/stubby/newtransport/netty/NettyServerStream.java b/core/src/main/java/com/google/net/stubby/newtransport/netty/NettyServerStream.java
new file mode 100644
index 0000000..99a0ac8
--- /dev/null
+++ b/core/src/main/java/com/google/net/stubby/newtransport/netty/NettyServerStream.java
@@ -0,0 +1,60 @@
+package com.google.net.stubby.newtransport.netty;
+
+import com.google.common.base.Preconditions;
+import com.google.net.stubby.newtransport.AbstractServerStream;
+import com.google.net.stubby.newtransport.GrpcDeframer;
+import com.google.net.stubby.newtransport.StreamState;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.channel.Channel;
+import io.netty.channel.ChannelPromise;
+
+import java.nio.ByteBuffer;
+
+/**
+ * Server stream for a Netty transport
+ */
+class NettyServerStream extends AbstractServerStream implements NettyStream {
+
+ private final GrpcDeframer deframer;
+ private final Channel channel;
+ private final int id;
+
+ private boolean headersSent;
+
+ NettyServerStream(Channel channel, int id) {
+ this.channel = Preconditions.checkNotNull(channel, "channel is null");
+ this.id = id;
+ this.deframer = new GrpcDeframer(new NettyDecompressor(channel.alloc()),
+ inboundMessageHandler());
+ }
+
+ @Override
+ public void inboundDataReceived(ByteBuf frame, boolean endOfStream, ChannelPromise promise) {
+ if (state() == StreamState.CLOSED) {
+ promise.setSuccess();
+ return;
+ }
+ // Retain the ByteBuf until it is released by the deframer.
+ // TODO(user): It sounds sub-optimal to deframe in the network thread. That means
+ // decompression is serialized.
+ deframer.deframe(new NettyBuffer(frame.retain()), endOfStream);
+ promise.setSuccess();
+ }
+
+ @Override
+ protected void sendFrame(ByteBuffer frame, boolean endOfStream) {
+ if (!headersSent) {
+ channel.write(new SendResponseHeadersCommand(id));
+ headersSent = true;
+ }
+ SendGrpcFrameCommand cmd =
+ new SendGrpcFrameCommand(id, Utils.toByteBuf(channel.alloc(), frame), endOfStream);
+ channel.writeAndFlush(cmd);
+ }
+
+ @Override
+ public int id() {
+ return id;
+ }
+}
diff --git a/core/src/main/java/com/google/net/stubby/newtransport/netty/NettyStream.java b/core/src/main/java/com/google/net/stubby/newtransport/netty/NettyStream.java
new file mode 100644
index 0000000..68e36e2
--- /dev/null
+++ b/core/src/main/java/com/google/net/stubby/newtransport/netty/NettyStream.java
@@ -0,0 +1,27 @@
+package com.google.net.stubby.newtransport.netty;
+
+import com.google.net.stubby.newtransport.Stream;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.channel.ChannelPromise;
+
+/**
+ * A common interface shared between NettyClientStream and NettyServerStream.
+ */
+interface NettyStream extends Stream {
+
+ /**
+ * Called in the network thread to process the content of an inbound DATA frame.
+ *
+ * @param frame the inbound HTTP/2 DATA frame. If this buffer is not used immediately, it must
+ * be retained.
+ * @param promise the promise to be set after the application has finished
+ * processing the frame.
+ */
+ void inboundDataReceived(ByteBuf frame, boolean endOfStream, ChannelPromise promise);
+
+ /**
+ * Returns the HTTP/2 stream ID.
+ */
+ int id();
+}
diff --git a/core/src/main/java/com/google/net/stubby/newtransport/netty/SendGrpcFrameCommand.java b/core/src/main/java/com/google/net/stubby/newtransport/netty/SendGrpcFrameCommand.java
index 7c348b8..1d31eef 100644
--- a/core/src/main/java/com/google/net/stubby/newtransport/netty/SendGrpcFrameCommand.java
+++ b/core/src/main/java/com/google/net/stubby/newtransport/netty/SendGrpcFrameCommand.java
@@ -1,7 +1,5 @@
package com.google.net.stubby.newtransport.netty;
-import com.google.common.base.Preconditions;
-
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufHolder;
import io.netty.buffer.DefaultByteBufHolder;
@@ -10,17 +8,17 @@
* Command sent from the transport to the Netty channel to send a GRPC frame to the remote endpoint.
*/
class SendGrpcFrameCommand extends DefaultByteBufHolder {
- private final NettyClientStream stream;
+ private final int streamId;
private final boolean endStream;
- SendGrpcFrameCommand(NettyClientStream stream, ByteBuf content, boolean endStream) {
+ SendGrpcFrameCommand(int streamId, ByteBuf content, boolean endStream) {
super(content);
- this.stream = Preconditions.checkNotNull(stream, "stream");
+ this.streamId = streamId;
this.endStream = endStream;
}
- NettyClientStream stream() {
- return stream;
+ int streamId() {
+ return streamId;
}
boolean endStream() {
@@ -29,12 +27,12 @@
@Override
public ByteBufHolder copy() {
- return new SendGrpcFrameCommand(stream, content().copy(), endStream);
+ return new SendGrpcFrameCommand(streamId, content().copy(), endStream);
}
@Override
public ByteBufHolder duplicate() {
- return new SendGrpcFrameCommand(stream, content().duplicate(), endStream);
+ return new SendGrpcFrameCommand(streamId, content().duplicate(), endStream);
}
@Override
@@ -60,4 +58,31 @@
super.touch(hint);
return this;
}
+
+ @Override
+ public boolean equals(Object that) {
+ if (that == null || !that.getClass().equals(SendGrpcFrameCommand.class)) {
+ return false;
+ }
+ SendGrpcFrameCommand thatCmd = (SendGrpcFrameCommand) that;
+ return thatCmd.streamId == streamId && thatCmd.endStream == endStream
+ && thatCmd.content().equals(content());
+ }
+
+ @Override
+ public String toString() {
+ return getClass().getSimpleName() + "(streamId=" + streamId
+ + ", endStream=" + endStream + ", content=" + content()
+ + ")";
+ }
+
+ @Override
+ public int hashCode() {
+ int hash = content().hashCode();
+ hash = hash * 31 + streamId;
+ if (endStream) {
+ hash = -hash;
+ }
+ return hash;
+ }
}
diff --git a/core/src/main/java/com/google/net/stubby/newtransport/netty/SendResponseHeadersCommand.java b/core/src/main/java/com/google/net/stubby/newtransport/netty/SendResponseHeadersCommand.java
new file mode 100644
index 0000000..2c761b4
--- /dev/null
+++ b/core/src/main/java/com/google/net/stubby/newtransport/netty/SendResponseHeadersCommand.java
@@ -0,0 +1,35 @@
+package com.google.net.stubby.newtransport.netty;
+
+/**
+ * Command sent from the transport to the Netty channel to send response headers to the client.
+ */
+class SendResponseHeadersCommand {
+ private final int streamId;
+
+ SendResponseHeadersCommand(int streamId) {
+ this.streamId = streamId;
+ }
+
+ int streamId() {
+ return streamId;
+ }
+
+ @Override
+ public boolean equals(Object that) {
+ if (that == null || !that.getClass().equals(SendResponseHeadersCommand.class)) {
+ return false;
+ }
+ SendResponseHeadersCommand thatCmd = (SendResponseHeadersCommand) that;
+ return thatCmd.streamId == streamId;
+ }
+
+ @Override
+ public String toString() {
+ return getClass().getSimpleName() + "(streamId=" + streamId + ")";
+ }
+
+ @Override
+ public int hashCode() {
+ return streamId;
+ }
+}
diff --git a/core/src/main/java/com/google/net/stubby/newtransport/netty/Utils.java b/core/src/main/java/com/google/net/stubby/newtransport/netty/Utils.java
new file mode 100644
index 0000000..a06e056
--- /dev/null
+++ b/core/src/main/java/com/google/net/stubby/newtransport/netty/Utils.java
@@ -0,0 +1,25 @@
+package com.google.net.stubby.newtransport.netty;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.ByteBufAllocator;
+
+import java.nio.ByteBuffer;
+
+/**
+ * Common utility methods.
+ */
+class Utils {
+
+ /**
+ * Copies the content of the given {@link ByteBuffer} to a new {@link ByteBuf} instance.
+ */
+ static ByteBuf toByteBuf(ByteBufAllocator alloc, ByteBuffer source) {
+ ByteBuf buf = alloc.buffer(source.remaining());
+ buf.writeBytes(source);
+ return buf;
+ }
+
+ private Utils() {
+ // Prevents instantiation
+ }
+}
diff --git a/core/src/main/java/com/google/net/stubby/newtransport/okhttp/OkHttpClientTransport.java b/core/src/main/java/com/google/net/stubby/newtransport/okhttp/OkHttpClientTransport.java
index d9f3be7..1690a6e 100644
--- a/core/src/main/java/com/google/net/stubby/newtransport/okhttp/OkHttpClientTransport.java
+++ b/core/src/main/java/com/google/net/stubby/newtransport/okhttp/OkHttpClientTransport.java
@@ -7,8 +7,8 @@
import com.google.common.io.ByteStreams;
import com.google.net.stubby.MethodDescriptor;
import com.google.net.stubby.Status;
+import com.google.net.stubby.newtransport.AbstractClientStream;
import com.google.net.stubby.newtransport.AbstractClientTransport;
-import com.google.net.stubby.newtransport.AbstractStream;
import com.google.net.stubby.newtransport.ClientStream;
import com.google.net.stubby.newtransport.ClientTransport;
import com.google.net.stubby.newtransport.InputStreamDeframer;
@@ -25,11 +25,11 @@
import com.squareup.okhttp.internal.spdy.Settings;
import com.squareup.okhttp.internal.spdy.Variant;
-import okio.ByteString;
+import okio.Buffer;
import okio.BufferedSink;
import okio.BufferedSource;
+import okio.ByteString;
import okio.Okio;
-import okio.Buffer;
import java.io.IOException;
import java.net.Socket;
@@ -383,7 +383,7 @@
* Client stream for the okhttp transport.
*/
@VisibleForTesting
- class OkHttpClientStream extends AbstractStream implements ClientStream {
+ class OkHttpClientStream extends AbstractClientStream {
int streamId;
final InputStreamDeframer deframer;
int unacknowledgedBytesRead;
diff --git a/core/src/test/java/com/google/net/stubby/newtransport/netty/NettyClientHandlerTest.java b/core/src/test/java/com/google/net/stubby/newtransport/netty/NettyClientHandlerTest.java
index f5e275c..bb45ded 100644
--- a/core/src/test/java/com/google/net/stubby/newtransport/netty/NettyClientHandlerTest.java
+++ b/core/src/test/java/com/google/net/stubby/newtransport/netty/NettyClientHandlerTest.java
@@ -9,7 +9,6 @@
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.calls;
-import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.inOrder;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
@@ -22,6 +21,17 @@
import com.google.net.stubby.newtransport.StreamState;
import com.google.net.stubby.transport.Transport;
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.handler.codec.http2.DefaultHttp2FrameReader;
+import io.netty.handler.codec.http2.DefaultHttp2FrameWriter;
+import io.netty.handler.codec.http2.DefaultHttp2Headers;
+import io.netty.handler.codec.http2.Http2CodecUtil;
+import io.netty.handler.codec.http2.Http2Error;
+import io.netty.handler.codec.http2.Http2Headers;
+import io.netty.handler.codec.http2.Http2Settings;
+
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
@@ -29,61 +39,23 @@
import org.mockito.ArgumentCaptor;
import org.mockito.InOrder;
import org.mockito.Mock;
-import org.mockito.Mockito;
import org.mockito.MockitoAnnotations;
-import org.mockito.invocation.InvocationOnMock;
-import org.mockito.stubbing.Answer;
-
-import io.netty.handler.codec.http2.DefaultHttp2Headers;
-import io.netty.buffer.ByteBuf;
-import io.netty.buffer.Unpooled;
-import io.netty.buffer.UnpooledByteBufAllocator;
-import io.netty.channel.Channel;
-import io.netty.channel.ChannelFuture;
-import io.netty.channel.ChannelFutureListener;
-import io.netty.channel.ChannelHandlerContext;
-import io.netty.channel.ChannelPromise;
-import io.netty.handler.codec.http2.DefaultHttp2FrameReader;
-import io.netty.handler.codec.http2.DefaultHttp2FrameWriter;
-import io.netty.handler.codec.http2.Http2CodecUtil;
-import io.netty.handler.codec.http2.Http2Error;
-import io.netty.handler.codec.http2.Http2FrameListener;
-import io.netty.handler.codec.http2.Http2FrameReader;
-import io.netty.handler.codec.http2.Http2FrameWriter;
-import io.netty.handler.codec.http2.Http2Headers;
-import io.netty.handler.codec.http2.Http2Settings;
/**
* Tests for {@link NettyClientHandler}.
*/
@RunWith(JUnit4.class)
-public class NettyClientHandlerTest {
+public class NettyClientHandlerTest extends NettyHandlerTestBase {
private NettyClientHandler handler;
- @Mock
- private Channel channel;
-
- @Mock
- private ChannelHandlerContext ctx;
-
- @Mock
- private ChannelFuture future;
-
- @Mock
- private ChannelPromise promise;
-
+ // TODO(user): mocking concrete classes is not safe. Consider making NettyClientStream an
+ // interface.
@Mock
private NettyClientStream stream;
@Mock
private MethodDescriptor<?, ?> method;
-
- @Mock
- private Http2FrameListener frameListener;
-
- private Http2FrameWriter frameWriter;
- private Http2FrameReader frameReader;
private ByteBuf content;
@Before
@@ -166,7 +138,7 @@
createStream();
// Send a frame and verify that it was written.
- handler.write(ctx, new SendGrpcFrameCommand(stream, content, true), promise);
+ handler.write(ctx, new SendGrpcFrameCommand(stream.id(), content, true), promise);
verify(promise, never()).setFailure(any(Throwable.class));
verify(ctx).write(any(ByteBuf.class), eq(promise));
verify(ctx).flush();
@@ -175,7 +147,7 @@
@Test
public void sendForUnknownStreamShouldFail() throws Exception {
when(stream.id()).thenReturn(3);
- handler.write(ctx, new SendGrpcFrameCommand(stream, content, true), promise);
+ handler.write(ctx, new SendGrpcFrameCommand(stream.id(), content, true), promise);
verify(promise).setFailure(any(Throwable.class));
}
@@ -249,12 +221,6 @@
mockContext();
}
- private ByteBuf headersFrame(int streamId, Http2Headers headers) {
- ChannelHandlerContext ctx = newContext();
- frameWriter.writeHeaders(ctx, streamId, headers, 0, false, promise);
- return captureWrite(ctx);
- }
-
private ByteBuf dataFrame(int streamId, boolean endStream) {
// Need to retain the content since the frameWriter releases it.
content.retain();
@@ -263,40 +229,6 @@
return captureWrite(ctx);
}
- private ByteBuf goAwayFrame(int lastStreamId) {
- ChannelHandlerContext ctx = newContext();
- frameWriter.writeGoAway(ctx, lastStreamId, 0, Unpooled.EMPTY_BUFFER, newPromise());
- return captureWrite(ctx);
- }
-
- private ByteBuf rstStreamFrame(int streamId, int errorCode) {
- ChannelHandlerContext ctx = newContext();
- frameWriter.writeRstStream(ctx, streamId, errorCode, newPromise());
- return captureWrite(ctx);
- }
-
- private ByteBuf serializeSettings(Http2Settings settings) {
- ChannelHandlerContext ctx = newContext();
- frameWriter.writeSettings(ctx, settings, newPromise());
- return captureWrite(ctx);
- }
-
- private ChannelHandlerContext newContext() {
- ChannelHandlerContext ctx = Mockito.mock(ChannelHandlerContext.class);
- when(ctx.alloc()).thenReturn(UnpooledByteBufAllocator.DEFAULT);
- return ctx;
- }
-
- private ChannelPromise newPromise() {
- return Mockito.mock(ChannelPromise.class);
- }
-
- private ByteBuf captureWrite(ChannelHandlerContext ctx) {
- ArgumentCaptor<ByteBuf> captor = ArgumentCaptor.forClass(ByteBuf.class);
- verify(ctx).write(captor.capture(), any(ChannelPromise.class));
- return captor.getValue();
- }
-
private void createStream() throws Exception {
// Create the stream.
handler.write(ctx, new CreateStreamCommand(method, stream), promise);
@@ -304,34 +236,4 @@
// Reset the context mock to clear recording of sent headers frame.
mockContext();
}
-
- private void mockContext() {
- Mockito.reset(ctx);
- Mockito.reset(promise);
- when(ctx.alloc()).thenReturn(UnpooledByteBufAllocator.DEFAULT);
- when(ctx.channel()).thenReturn(channel);
- when(ctx.write(any())).thenReturn(future);
- when(ctx.write(any(), eq(promise))).thenReturn(future);
- when(ctx.writeAndFlush(any())).thenReturn(future);
- when(ctx.writeAndFlush(any(), eq(promise))).thenReturn(future);
- when(ctx.newPromise()).thenReturn(promise);
- }
-
- private void mockFuture(boolean succeeded) {
- when(future.isDone()).thenReturn(true);
- when(future.isCancelled()).thenReturn(false);
- when(future.isSuccess()).thenReturn(succeeded);
- if (!succeeded) {
- when(future.cause()).thenReturn(new Exception("fake"));
- }
-
- doAnswer(new Answer<ChannelFuture>() {
- @Override
- public ChannelFuture answer(InvocationOnMock invocation) throws Throwable {
- ChannelFutureListener listener = (ChannelFutureListener) invocation.getArguments()[0];
- listener.operationComplete(future);
- return future;
- }
- }).when(future).addListener(any(ChannelFutureListener.class));
- }
}
diff --git a/core/src/test/java/com/google/net/stubby/newtransport/netty/NettyClientStreamTest.java b/core/src/test/java/com/google/net/stubby/newtransport/netty/NettyClientStreamTest.java
index 54b68a3..d47e104 100644
--- a/core/src/test/java/com/google/net/stubby/newtransport/netty/NettyClientStreamTest.java
+++ b/core/src/test/java/com/google/net/stubby/newtransport/netty/NettyClientStreamTest.java
@@ -1,97 +1,44 @@
package com.google.net.stubby.newtransport.netty;
-import static com.google.net.stubby.GrpcFramingUtil.CONTEXT_VALUE_FRAME;
-import static com.google.net.stubby.GrpcFramingUtil.PAYLOAD_FRAME;
-import static com.google.net.stubby.GrpcFramingUtil.STATUS_FRAME;
import static io.netty.util.CharsetUtil.UTF_8;
import static org.junit.Assert.assertEquals;
import static org.mockito.Matchers.any;
-import static org.mockito.Matchers.anyLong;
import static org.mockito.Matchers.eq;
-import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.verify;
-import static org.mockito.Mockito.when;
-import com.google.common.io.ByteStreams;
import com.google.net.stubby.Status;
import com.google.net.stubby.newtransport.HttpUtil;
-import com.google.net.stubby.newtransport.StreamListener;
import com.google.net.stubby.newtransport.StreamState;
import com.google.net.stubby.transport.Transport;
-import com.google.net.stubby.transport.Transport.ContextValue;
-import com.google.protobuf.ByteString;
+
+import io.netty.buffer.Unpooled;
+import io.netty.handler.codec.http2.DefaultHttp2Headers;
+import io.netty.handler.codec.http2.Http2Headers;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.ArgumentCaptor;
-import org.mockito.Mock;
-import org.mockito.MockitoAnnotations;
-import org.mockito.invocation.InvocationOnMock;
-import org.mockito.stubbing.Answer;
-
-import io.netty.handler.codec.http2.DefaultHttp2Headers;
-import io.netty.handler.codec.http2.Http2Headers;
-import io.netty.buffer.ByteBuf;
-import io.netty.buffer.Unpooled;
-import io.netty.buffer.UnpooledByteBufAllocator;
-import io.netty.channel.Channel;
-import io.netty.channel.ChannelFuture;
-import io.netty.channel.ChannelFutureListener;
-import io.netty.channel.ChannelPromise;
-import io.netty.channel.EventLoop;
-
-import java.io.ByteArrayInputStream;
-import java.io.ByteArrayOutputStream;
-import java.io.DataOutputStream;
-import java.io.InputStream;
-import java.util.concurrent.TimeUnit;
/**
* Tests for {@link NettyClientStream}.
*/
@RunWith(JUnit4.class)
-public class NettyClientStreamTest {
- private static final String CONTEXT_KEY = "key";
- private static final String MESSAGE = "hello world";
-
+public class NettyClientStreamTest extends NettyStreamTestBase {
private NettyClientStream stream;
- @Mock
- private StreamListener listener;
-
- @Mock
- private Channel channel;
-
- @Mock
- private ChannelFuture future;
-
- @Mock
- private ChannelPromise promise;
-
- @Mock
- private EventLoop eventLoop;
-
- private InputStream input;
-
- @Mock
- private Runnable accepted;
-
@Before
public void setup() {
- MockitoAnnotations.initMocks(this);
-
- mockChannelFuture(true);
- when(channel.write(any())).thenReturn(future);
- when(channel.writeAndFlush(any())).thenReturn(future);
- when(channel.alloc()).thenReturn(UnpooledByteBufAllocator.DEFAULT);
- when(channel.eventLoop()).thenReturn(eventLoop);
- when(eventLoop.inEventLoop()).thenReturn(true);
+ init();
stream = new NettyClientStream(listener, channel);
assertEquals(StreamState.OPEN, stream.state());
- input = new ByteArrayInputStream(MESSAGE.getBytes(UTF_8));
+ }
+
+ @Override
+ protected NettyStream stream() {
+ return stream;
}
@Test
@@ -114,10 +61,7 @@
stream.id(1);
stream.writeContext(CONTEXT_KEY, input, input.available(), accepted);
stream.flush();
- ArgumentCaptor<SendGrpcFrameCommand> captor =
- ArgumentCaptor.forClass(SendGrpcFrameCommand.class);
- verify(channel).writeAndFlush(captor.capture());
- assertEquals(contextFrame(), captor.getValue().content());
+ verify(channel).writeAndFlush(new SendGrpcFrameCommand(1, contextFrame(), false));
verify(accepted).run();
}
@@ -127,10 +71,7 @@
stream.id(1);
stream.writeMessage(input, input.available(), accepted);
stream.flush();
- ArgumentCaptor<SendGrpcFrameCommand> captor =
- ArgumentCaptor.forClass(SendGrpcFrameCommand.class);
- verify(channel).writeAndFlush(captor.capture());
- assertEquals(messageFrame(), captor.getValue().content());
+ verify(channel).writeAndFlush(new SendGrpcFrameCommand(1, messageFrame(), false));
verify(accepted).run();
}
@@ -168,28 +109,20 @@
assertEquals(StreamState.CLOSED, stream.state());
}
+ @Override
@Test
public void inboundContextShouldCallListener() throws Exception {
// Receive headers first so that it's a valid GRPC response.
stream.inboundHeadersRecieved(grpcResponseHeaders(), false);
-
- stream.inboundDataReceived(contextFrame(), false, promise);
- ArgumentCaptor<InputStream> captor = ArgumentCaptor.forClass(InputStream.class);
- verify(listener).contextRead(eq(CONTEXT_KEY), captor.capture(), eq(MESSAGE.length()));
- verify(promise).setSuccess();
- assertEquals(MESSAGE, toString(captor.getValue()));
+ super.inboundContextShouldCallListener();
}
+ @Override
@Test
public void inboundMessageShouldCallListener() throws Exception {
// Receive headers first so that it's a valid GRPC response.
stream.inboundHeadersRecieved(grpcResponseHeaders(), false);
-
- stream.inboundDataReceived(messageFrame(), false, promise);
- ArgumentCaptor<InputStream> captor = ArgumentCaptor.forClass(InputStream.class);
- verify(listener).messageRead(captor.capture(), eq(MESSAGE.length()));
- verify(promise).setSuccess();
- assertEquals(MESSAGE, toString(captor.getValue()));
+ super.inboundMessageShouldCallListener();
}
@Test
@@ -199,7 +132,7 @@
// Receive headers first so that it's a valid GRPC response.
stream.inboundHeadersRecieved(grpcResponseHeaders(), false);
- stream.inboundDataReceived(statusFrame(), false, promise);
+ stream.inboundDataReceived(statusFrame(new Status(Transport.Code.INTERNAL)), false, promise);
ArgumentCaptor<Status> captor = ArgumentCaptor.forClass(Status.class);
verify(listener).closed(captor.capture());
assertEquals(Transport.Code.INTERNAL, captor.getValue().getCode());
@@ -215,83 +148,8 @@
assertEquals(MESSAGE, captor.getValue().getDescription());
}
- private String toString(InputStream in) throws Exception {
- byte[] bytes = new byte[in.available()];
- ByteStreams.readFully(in, bytes);
- return new String(bytes, UTF_8);
- }
-
- private ByteBuf contextFrame() throws Exception {
- byte[] body = ContextValue
- .newBuilder()
- .setKey(CONTEXT_KEY)
- .setValue(ByteString.copyFromUtf8(MESSAGE))
- .build()
- .toByteArray();
- ByteArrayOutputStream os = new ByteArrayOutputStream();
- DataOutputStream dos = new DataOutputStream(os);
- dos.write(CONTEXT_VALUE_FRAME);
- dos.writeInt(body.length);
- dos.write(body);
- dos.close();
-
- // Write the compression header followed by the context frame.
- return compressionFrame(os.toByteArray());
- }
-
- private ByteBuf messageFrame() throws Exception {
- ByteArrayOutputStream os = new ByteArrayOutputStream();
- DataOutputStream dos = new DataOutputStream(os);
- dos.write(PAYLOAD_FRAME);
- dos.writeInt(MESSAGE.length());
- dos.write(MESSAGE.getBytes(UTF_8));
- dos.close();
-
- // Write the compression header followed by the context frame.
- return compressionFrame(os.toByteArray());
- }
-
- private ByteBuf statusFrame() throws Exception {
- ByteArrayOutputStream os = new ByteArrayOutputStream();
- DataOutputStream dos = new DataOutputStream(os);
- short code = (short) Transport.Code.INTERNAL.getNumber();
- dos.write(STATUS_FRAME);
- int length = 2;
- dos.writeInt(length);
- dos.writeShort(code);
-
- // Write the compression header followed by the context frame.
- return compressionFrame(os.toByteArray());
- }
-
- private ByteBuf compressionFrame(byte[] data) {
- ByteBuf buf = Unpooled.buffer();
- buf.writeInt(data.length);
- buf.writeBytes(data);
- return buf;
- }
-
private Http2Headers grpcResponseHeaders() {
return DefaultHttp2Headers.newBuilder().status("200")
.set(HttpUtil.CONTENT_TYPE_HEADER, HttpUtil.CONTENT_TYPE_PROTORPC).build();
}
-
- private void mockChannelFuture(boolean succeeded) {
- when(future.isDone()).thenReturn(true);
- when(future.isCancelled()).thenReturn(false);
- when(future.isSuccess()).thenReturn(succeeded);
- when(future.awaitUninterruptibly(anyLong(), any(TimeUnit.class))).thenReturn(true);
- if (!succeeded) {
- when(future.cause()).thenReturn(new Exception("fake"));
- }
-
- doAnswer(new Answer<ChannelFuture>() {
- @Override
- public ChannelFuture answer(InvocationOnMock invocation) throws Throwable {
- ChannelFutureListener listener = (ChannelFutureListener) invocation.getArguments()[0];
- listener.operationComplete(future);
- return future;
- }
- }).when(future).addListener(any(ChannelFutureListener.class));
- }
}
diff --git a/core/src/test/java/com/google/net/stubby/newtransport/netty/NettyHandlerTestBase.java b/core/src/test/java/com/google/net/stubby/newtransport/netty/NettyHandlerTestBase.java
new file mode 100644
index 0000000..1848178
--- /dev/null
+++ b/core/src/test/java/com/google/net/stubby/newtransport/netty/NettyHandlerTestBase.java
@@ -0,0 +1,124 @@
+package com.google.net.stubby.newtransport.netty;
+
+import static org.mockito.Matchers.any;
+import static org.mockito.Matchers.eq;
+import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+import io.netty.buffer.UnpooledByteBufAllocator;
+import io.netty.channel.Channel;
+import io.netty.channel.ChannelFuture;
+import io.netty.channel.ChannelFutureListener;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.ChannelPromise;
+import io.netty.handler.codec.http2.Http2FrameListener;
+import io.netty.handler.codec.http2.Http2FrameReader;
+import io.netty.handler.codec.http2.Http2FrameWriter;
+import io.netty.handler.codec.http2.Http2Headers;
+import io.netty.handler.codec.http2.Http2Settings;
+
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+import org.mockito.ArgumentCaptor;
+import org.mockito.Mock;
+import org.mockito.Mockito;
+import org.mockito.invocation.InvocationOnMock;
+import org.mockito.stubbing.Answer;
+
+/**
+ * Base class for Netty handler unit tests.
+ */
+@RunWith(JUnit4.class)
+public abstract class NettyHandlerTestBase {
+
+ @Mock
+ protected Channel channel;
+
+ @Mock
+ protected ChannelHandlerContext ctx;
+
+ @Mock
+ protected ChannelFuture future;
+
+ @Mock
+ protected ChannelPromise promise;
+
+ @Mock
+ protected Http2FrameListener frameListener;
+
+ protected Http2FrameWriter frameWriter;
+ protected Http2FrameReader frameReader;
+
+ protected final ByteBuf headersFrame(int streamId, Http2Headers headers) {
+ ChannelHandlerContext ctx = newContext();
+ frameWriter.writeHeaders(ctx, streamId, headers, 0, false, promise);
+ return captureWrite(ctx);
+ }
+
+ protected final ByteBuf goAwayFrame(int lastStreamId) {
+ ChannelHandlerContext ctx = newContext();
+ frameWriter.writeGoAway(ctx, lastStreamId, 0, Unpooled.EMPTY_BUFFER, newPromise());
+ return captureWrite(ctx);
+ }
+
+ protected final ByteBuf rstStreamFrame(int streamId, int errorCode) {
+ ChannelHandlerContext ctx = newContext();
+ frameWriter.writeRstStream(ctx, streamId, errorCode, newPromise());
+ return captureWrite(ctx);
+ }
+
+ protected final ByteBuf serializeSettings(Http2Settings settings) {
+ ChannelHandlerContext ctx = newContext();
+ frameWriter.writeSettings(ctx, settings, newPromise());
+ return captureWrite(ctx);
+ }
+
+ protected final ChannelHandlerContext newContext() {
+ ChannelHandlerContext ctx = Mockito.mock(ChannelHandlerContext.class);
+ when(ctx.alloc()).thenReturn(UnpooledByteBufAllocator.DEFAULT);
+ return ctx;
+ }
+
+ protected final ChannelPromise newPromise() {
+ return Mockito.mock(ChannelPromise.class);
+ }
+
+ protected final ByteBuf captureWrite(ChannelHandlerContext ctx) {
+ ArgumentCaptor<ByteBuf> captor = ArgumentCaptor.forClass(ByteBuf.class);
+ verify(ctx).write(captor.capture(), any(ChannelPromise.class));
+ return captor.getValue();
+ }
+
+ protected final void mockContext() {
+ Mockito.reset(ctx);
+ Mockito.reset(promise);
+ when(ctx.alloc()).thenReturn(UnpooledByteBufAllocator.DEFAULT);
+ when(ctx.channel()).thenReturn(channel);
+ when(ctx.write(any())).thenReturn(future);
+ when(ctx.write(any(), eq(promise))).thenReturn(future);
+ when(ctx.writeAndFlush(any())).thenReturn(future);
+ when(ctx.writeAndFlush(any(), eq(promise))).thenReturn(future);
+ when(ctx.newPromise()).thenReturn(promise);
+ }
+
+ protected final void mockFuture(boolean succeeded) {
+ when(future.isDone()).thenReturn(true);
+ when(future.isCancelled()).thenReturn(false);
+ when(future.isSuccess()).thenReturn(succeeded);
+ if (!succeeded) {
+ when(future.cause()).thenReturn(new Exception("fake"));
+ }
+
+ doAnswer(new Answer<ChannelFuture>() {
+ @Override
+ public ChannelFuture answer(InvocationOnMock invocation) throws Throwable {
+ ChannelFutureListener listener = (ChannelFutureListener) invocation.getArguments()[0];
+ listener.operationComplete(future);
+ return future;
+ }
+ }).when(future).addListener(any(ChannelFutureListener.class));
+ }
+}
diff --git a/core/src/test/java/com/google/net/stubby/newtransport/netty/NettyServerHandlerTest.java b/core/src/test/java/com/google/net/stubby/newtransport/netty/NettyServerHandlerTest.java
new file mode 100644
index 0000000..f8f4367
--- /dev/null
+++ b/core/src/test/java/com/google/net/stubby/newtransport/netty/NettyServerHandlerTest.java
@@ -0,0 +1,191 @@
+package com.google.net.stubby.newtransport.netty;
+
+import static java.nio.charset.StandardCharsets.UTF_8;
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.mockito.Matchers.any;
+import static org.mockito.Matchers.anyInt;
+import static org.mockito.Matchers.eq;
+import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.verifyNoMoreInteractions;
+import static org.mockito.Mockito.when;
+
+import com.google.common.io.ByteStreams;
+import com.google.net.stubby.MethodDescriptor;
+import com.google.net.stubby.Status;
+import com.google.net.stubby.newtransport.Framer;
+import com.google.net.stubby.newtransport.HttpUtil;
+import com.google.net.stubby.newtransport.MessageFramer;
+import com.google.net.stubby.newtransport.ServerStream;
+import com.google.net.stubby.newtransport.ServerTransportListener;
+import com.google.net.stubby.newtransport.StreamListener;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+import io.netty.buffer.UnpooledByteBufAllocator;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.handler.codec.http2.DefaultHttp2Connection;
+import io.netty.handler.codec.http2.DefaultHttp2FrameReader;
+import io.netty.handler.codec.http2.DefaultHttp2FrameWriter;
+import io.netty.handler.codec.http2.DefaultHttp2Headers;
+import io.netty.handler.codec.http2.Http2CodecUtil;
+import io.netty.handler.codec.http2.Http2Error;
+import io.netty.handler.codec.http2.Http2Headers;
+import io.netty.handler.codec.http2.Http2Settings;
+
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+import org.mockito.ArgumentCaptor;
+import org.mockito.Mock;
+import org.mockito.MockitoAnnotations;
+
+import java.io.ByteArrayInputStream;
+import java.io.InputStream;
+import java.nio.ByteBuffer;
+
+/** Unit tests for {@link NettyServerHandler}. */
+@RunWith(JUnit4.class)
+public class NettyServerHandlerTest extends NettyHandlerTestBase {
+
+ private static final int STREAM_ID = 3;
+ private static final byte[] CONTENT = "hello world".getBytes(UTF_8);
+
+ @Mock private ServerTransportListener transportListener;
+
+ @Mock private StreamListener streamListener;
+
+ private NettyServerStream stream;
+
+ private NettyServerHandler handler;
+
+ @Before
+ public void setup() throws Exception {
+ MockitoAnnotations.initMocks(this);
+
+ when(transportListener.streamCreated(any(ServerStream.class), any(MethodDescriptor.class)))
+ .thenReturn(streamListener);
+ handler = new NettyServerHandler(transportListener, new DefaultHttp2Connection(true));
+ frameWriter = new DefaultHttp2FrameWriter();
+ frameReader = new DefaultHttp2FrameReader();
+
+ when(channel.isActive()).thenReturn(true);
+ mockContext();
+ mockFuture(true);
+
+ when(channel.alloc()).thenReturn(UnpooledByteBufAllocator.DEFAULT);
+
+ // Simulate activation of the handler to force writing of the initial settings
+ handler.handlerAdded(ctx);
+
+ // Simulate receipt of the connection preface
+ handler.channelRead(ctx, Http2CodecUtil.connectionPrefaceBuf());
+ // Simulate receipt of initial remote settings.
+ ByteBuf serializedSettings = serializeSettings(new Http2Settings());
+ handler.channelRead(ctx, serializedSettings);
+
+ // Reset the context to clear any interactions resulting from the HTTP/2
+ // connection preface handshake.
+ mockContext();
+ }
+
+ @Test
+ public void sendFrameShouldSucceed() throws Exception {
+ createStream();
+ ByteBuf content = Unpooled.copiedBuffer(CONTENT);
+
+ // Send a frame and verify that it was written.
+ handler.write(ctx, new SendGrpcFrameCommand(stream.id(), content, false), promise);
+ verify(promise, never()).setFailure(any(Throwable.class));
+ verify(ctx).write(any(ByteBuf.class), eq(promise));
+ assertEquals(0, content.refCnt());
+ }
+
+ @Test
+ public void inboundDataShouldForwardToStreamListener() throws Exception {
+ inboundDataShouldForwardToStreamListener(false);
+ }
+
+ @Test
+ public void inboundDataWithEndStreamShouldForwardToStreamListener() throws Exception {
+ inboundDataShouldForwardToStreamListener(true);
+ }
+
+ private void inboundDataShouldForwardToStreamListener(boolean endStream) throws Exception {
+ createStream();
+
+ // Create a data frame and then trigger the handler to read it.
+ ByteBuf frame = dataFrame(STREAM_ID, endStream);
+ handler.channelRead(ctx, frame);
+ ArgumentCaptor<InputStream> captor = ArgumentCaptor.forClass(InputStream.class);
+ verify(streamListener).messageRead(captor.capture(), eq(CONTENT.length));
+ assertArrayEquals(CONTENT, ByteStreams.toByteArray(captor.getValue()));
+
+ if (endStream) {
+ verify(streamListener).closed(eq(Status.OK));
+ }
+ verifyNoMoreInteractions(streamListener);
+ }
+
+ @Test
+ public void clientHalfCloseShouldForwardToStreamListener() throws Exception {
+ createStream();
+
+ handler.channelRead(ctx, emptyDataFrame(STREAM_ID, true));
+ verify(streamListener, never()).messageRead(any(InputStream.class), anyInt());
+ verify(streamListener).closed(eq(Status.OK));
+ verifyNoMoreInteractions(streamListener);
+ }
+
+ @Test
+ public void clientCancelShouldForwardToStreamListener() throws Exception {
+ createStream();
+
+ handler.channelRead(ctx, rstStreamFrame(STREAM_ID, Http2Error.CANCEL.code()));
+ verify(streamListener, never()).messageRead(any(InputStream.class), anyInt());
+ verify(streamListener).closed(eq(Status.CANCELLED));
+ verifyNoMoreInteractions(streamListener);
+ }
+
+ private void createStream() throws Exception {
+ Http2Headers headers = DefaultHttp2Headers.newBuilder()
+ .method(HttpUtil.HTTP_METHOD)
+ .set(HttpUtil.CONTENT_TYPE_HEADER, HttpUtil.CONTENT_TYPE_PROTORPC)
+ .path("/foo.bar")
+ .build();
+ ByteBuf headersFrame = headersFrame(STREAM_ID, headers);
+ handler.channelRead(ctx, headersFrame);
+ ArgumentCaptor<NettyServerStream> streamCaptor =
+ ArgumentCaptor.forClass(NettyServerStream.class);
+ @SuppressWarnings("rawtypes")
+ ArgumentCaptor<MethodDescriptor> methodCaptor = ArgumentCaptor.forClass(MethodDescriptor.class);
+ verify(transportListener).streamCreated(streamCaptor.capture(), methodCaptor.capture());
+ stream = streamCaptor.getValue();
+ }
+
+ private ByteBuf dataFrame(int streamId, boolean endStream) {
+ final ByteBuf compressionFrame = Unpooled.buffer(CONTENT.length);
+ MessageFramer framer = new MessageFramer(new Framer.Sink<ByteBuffer>() {
+ @Override
+ public void deliverFrame(ByteBuffer frame, boolean endOfStream) {
+ compressionFrame.writeBytes(frame);
+ }
+ }, 1000);
+ framer.writePayload(new ByteArrayInputStream(CONTENT), CONTENT.length);
+ framer.flush();
+ if (endStream) {
+ framer.close();
+ }
+ ChannelHandlerContext ctx = newContext();
+ frameWriter.writeData(ctx, streamId, compressionFrame, 0, endStream, newPromise());
+ return captureWrite(ctx);
+ }
+
+ private ByteBuf emptyDataFrame(int streamId, boolean endStream) {
+ ChannelHandlerContext ctx = newContext();
+ frameWriter.writeData(ctx, streamId, Unpooled.EMPTY_BUFFER, 0, endStream, newPromise());
+ return captureWrite(ctx);
+ }
+}
diff --git a/core/src/test/java/com/google/net/stubby/newtransport/netty/NettyServerStreamTest.java b/core/src/test/java/com/google/net/stubby/newtransport/netty/NettyServerStreamTest.java
new file mode 100644
index 0000000..425b4c1
--- /dev/null
+++ b/core/src/test/java/com/google/net/stubby/newtransport/netty/NettyServerStreamTest.java
@@ -0,0 +1,143 @@
+package com.google.net.stubby.newtransport.netty;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.fail;
+import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.verifyNoMoreInteractions;
+import static org.mockito.Mockito.verifyZeroInteractions;
+
+import com.google.net.stubby.Status;
+import com.google.net.stubby.newtransport.StreamState;
+import com.google.net.stubby.transport.Transport;
+
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/** Unit tests for {@link NettyServerStream}. */
+@RunWith(JUnit4.class)
+public class NettyServerStreamTest extends NettyStreamTestBase {
+
+ private static final int STREAM_ID = 1;
+ private NettyServerStream stream;
+
+ @Before public void setUp() {
+ init();
+
+ stream = new NettyServerStream(channel, STREAM_ID);
+ stream.setListener(listener);
+ assertEquals(StreamState.OPEN, stream.state());
+ verifyZeroInteractions(listener);
+ }
+
+ @Override
+ protected NettyStream stream() {
+ return stream;
+ }
+
+ @Test
+ public void writeContextShouldSendResponse() throws Exception {
+ stream.writeContext(CONTEXT_KEY, input, input.available(), accepted);
+ stream.flush();
+ verify(channel).write(new SendResponseHeadersCommand(STREAM_ID));
+ verify(channel).writeAndFlush(new SendGrpcFrameCommand(STREAM_ID, contextFrame(), false));
+ verify(accepted).run();
+ }
+
+ @Test
+ public void writeMessageShouldSendResponse() throws Exception {
+ stream.writeMessage(input, input.available(), accepted);
+ stream.flush();
+ verify(channel).write(new SendResponseHeadersCommand(STREAM_ID));
+ verify(channel).writeAndFlush(new SendGrpcFrameCommand(STREAM_ID, messageFrame(), false));
+ verify(accepted).run();
+ }
+
+ @Test
+ public void closeBeforeClientHalfCloseShouldFail() {
+ try {
+ stream.close(Status.OK);
+ fail("Should throw exception");
+ } catch (IllegalStateException expected) { }
+ assertEquals(StreamState.OPEN, stream.state());
+ verifyZeroInteractions(listener);
+ }
+
+ @Test
+ public void closeWithErrorBeforeClientHalfCloseShouldSucceed() throws Exception {
+ stream.close(Status.CANCELLED);
+ assertEquals(StreamState.CLOSED, stream.state());
+ verify(channel).writeAndFlush(
+ new SendGrpcFrameCommand(STREAM_ID, statusFrame(Status.CANCELLED), true));
+ verifyZeroInteractions(listener);
+ }
+
+ @Test
+ public void closeAfterClientHalfCloseShouldSucceed() throws Exception {
+ // Client half-closes. Listener gets closed()
+ stream.remoteEndClosed();
+ assertEquals(StreamState.WRITE_ONLY, stream.state());
+ verify(listener).closed(Status.OK);
+ // Server closes. Status sent.
+ stream.close(Status.OK);
+ assertEquals(StreamState.CLOSED, stream.state());
+ verify(channel).writeAndFlush(
+ new SendGrpcFrameCommand(STREAM_ID, statusFrame(Status.OK), true));
+ verifyNoMoreInteractions(listener);
+ }
+
+ @Test
+ public void clientHalfCloseForTheSecondTimeShouldFail() throws Exception {
+ // Client half-closes. Listener gets closed()
+ stream.remoteEndClosed();
+ assertEquals(StreamState.WRITE_ONLY, stream.state());
+ verify(listener).closed(Status.OK);
+ // Client half-closes again. Stream will be aborted with an error.
+ stream.remoteEndClosed();
+ assertEquals(StreamState.CLOSED, stream.state());
+ verify(channel).writeAndFlush(
+ new SendGrpcFrameCommand(
+ STREAM_ID,
+ statusFrame(new Status(Transport.Code.FAILED_PRECONDITION,
+ "Client-end of the stream already closed")),
+ true));
+ verifyNoMoreInteractions(listener);
+ }
+
+ @Test
+ public void abortStreamAndSendStatus() throws Exception {
+ Status status = new Status(Transport.Code.INTERNAL, new Throwable());
+ stream.abortStream(status, true);
+ assertEquals(StreamState.CLOSED, stream.state());
+ verify(listener).closed(status);
+ verify(channel).writeAndFlush(
+ new SendGrpcFrameCommand(STREAM_ID, statusFrame(status), true));
+ verifyNoMoreInteractions(listener);
+ }
+
+ @Test
+ public void abortStreamAndNotSendStatus() throws Exception {
+ Status status = new Status(Transport.Code.INTERNAL, new Throwable());
+ stream.abortStream(status, false);
+ assertEquals(StreamState.CLOSED, stream.state());
+ verify(listener).closed(status);
+ verify(channel, never()).writeAndFlush(
+ new SendGrpcFrameCommand(STREAM_ID, statusFrame(status), true));
+ verifyNoMoreInteractions(listener);
+ }
+
+ @Test
+ public void abortStreamAfterClientHalfCloseShouldNotCallListenerTwice() {
+ Status status = new Status(Transport.Code.INTERNAL, new Throwable());
+ // Client half-closes. Listener gets closed()
+ stream.remoteEndClosed();
+ assertEquals(StreamState.WRITE_ONLY, stream.state());
+ verify(listener).closed(Status.OK);
+ // Abort
+ stream.abortStream(status, true);
+ assertEquals(StreamState.CLOSED, stream.state());
+ verifyNoMoreInteractions(listener);
+ }
+}
diff --git a/core/src/test/java/com/google/net/stubby/newtransport/netty/NettyStreamTestBase.java b/core/src/test/java/com/google/net/stubby/newtransport/netty/NettyStreamTestBase.java
new file mode 100644
index 0000000..62b81d6
--- /dev/null
+++ b/core/src/test/java/com/google/net/stubby/newtransport/netty/NettyStreamTestBase.java
@@ -0,0 +1,174 @@
+package com.google.net.stubby.newtransport.netty;
+
+import static com.google.net.stubby.GrpcFramingUtil.CONTEXT_VALUE_FRAME;
+import static com.google.net.stubby.GrpcFramingUtil.PAYLOAD_FRAME;
+import static com.google.net.stubby.GrpcFramingUtil.STATUS_FRAME;
+import static io.netty.util.CharsetUtil.UTF_8;
+import static org.junit.Assert.assertEquals;
+import static org.mockito.Matchers.any;
+import static org.mockito.Matchers.anyLong;
+import static org.mockito.Matchers.eq;
+import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+import com.google.common.io.ByteStreams;
+import com.google.net.stubby.Status;
+import com.google.net.stubby.newtransport.StreamListener;
+import com.google.net.stubby.transport.Transport.ContextValue;
+import com.google.protobuf.ByteString;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+import io.netty.buffer.UnpooledByteBufAllocator;
+import io.netty.channel.Channel;
+import io.netty.channel.ChannelFuture;
+import io.netty.channel.ChannelFutureListener;
+import io.netty.channel.ChannelPromise;
+import io.netty.channel.EventLoop;
+
+import org.junit.Test;
+import org.mockito.ArgumentCaptor;
+import org.mockito.Mock;
+import org.mockito.MockitoAnnotations;
+import org.mockito.invocation.InvocationOnMock;
+import org.mockito.stubbing.Answer;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.DataOutputStream;
+import java.io.InputStream;
+import java.util.concurrent.TimeUnit;
+
+/**
+ * Base class for Netty stream unit tests.
+ */
+public abstract class NettyStreamTestBase {
+ protected static final String CONTEXT_KEY = "key";
+ protected static final String MESSAGE = "hello world";
+
+ @Mock protected Channel channel;
+
+ @Mock protected ChannelFuture future;
+
+ @Mock protected StreamListener listener;
+
+ @Mock protected Runnable accepted;
+
+ @Mock protected EventLoop eventLoop;
+
+ @Mock protected ChannelPromise promise;
+
+ protected InputStream input;
+
+ /**
+ * Returns the NettyStream object to be tested.
+ */
+ protected abstract NettyStream stream();
+
+ protected final void init() {
+ MockitoAnnotations.initMocks(this);
+
+ mockChannelFuture(true);
+ when(channel.write(any())).thenReturn(future);
+ when(channel.writeAndFlush(any())).thenReturn(future);
+ when(channel.alloc()).thenReturn(UnpooledByteBufAllocator.DEFAULT);
+ when(channel.eventLoop()).thenReturn(eventLoop);
+ when(eventLoop.inEventLoop()).thenReturn(true);
+
+ input = new ByteArrayInputStream(MESSAGE.getBytes(UTF_8));
+ }
+
+ @Test
+ public void inboundContextShouldCallListener() throws Exception {
+ stream().inboundDataReceived(contextFrame(), false, promise);
+ ArgumentCaptor<InputStream> captor = ArgumentCaptor.forClass(InputStream.class);
+ verify(listener).contextRead(eq(CONTEXT_KEY), captor.capture(), eq(MESSAGE.length()));
+ verify(promise).setSuccess();
+ assertEquals(MESSAGE, toString(captor.getValue()));
+ }
+
+ @Test
+ public void inboundMessageShouldCallListener() throws Exception {
+ stream().inboundDataReceived(messageFrame(), false, promise);
+ ArgumentCaptor<InputStream> captor = ArgumentCaptor.forClass(InputStream.class);
+ verify(listener).messageRead(captor.capture(), eq(MESSAGE.length()));
+ verify(promise).setSuccess();
+ assertEquals(MESSAGE, toString(captor.getValue()));
+ }
+
+ private String toString(InputStream in) throws Exception {
+ byte[] bytes = new byte[in.available()];
+ ByteStreams.readFully(in, bytes);
+ return new String(bytes, UTF_8);
+ }
+
+ protected final ByteBuf contextFrame() throws Exception {
+ byte[] body = ContextValue
+ .newBuilder()
+ .setKey(CONTEXT_KEY)
+ .setValue(ByteString.copyFromUtf8(MESSAGE))
+ .build()
+ .toByteArray();
+ ByteArrayOutputStream os = new ByteArrayOutputStream();
+ DataOutputStream dos = new DataOutputStream(os);
+ dos.write(CONTEXT_VALUE_FRAME);
+ dos.writeInt(body.length);
+ dos.write(body);
+ dos.close();
+
+ // Write the compression header followed by the context frame.
+ return compressionFrame(os.toByteArray());
+ }
+
+ protected final ByteBuf messageFrame() throws Exception {
+ ByteArrayOutputStream os = new ByteArrayOutputStream();
+ DataOutputStream dos = new DataOutputStream(os);
+ dos.write(PAYLOAD_FRAME);
+ dos.writeInt(MESSAGE.length());
+ dos.write(MESSAGE.getBytes(UTF_8));
+ dos.close();
+
+ // Write the compression header followed by the context frame.
+ return compressionFrame(os.toByteArray());
+ }
+
+ protected final ByteBuf statusFrame(Status status) throws Exception {
+ ByteArrayOutputStream os = new ByteArrayOutputStream();
+ DataOutputStream dos = new DataOutputStream(os);
+ short code = (short) status.getCode().getNumber();
+ dos.write(STATUS_FRAME);
+ int length = 2;
+ dos.writeInt(length);
+ dos.writeShort(code);
+
+ // Write the compression header followed by the context frame.
+ return compressionFrame(os.toByteArray());
+ }
+
+ protected final ByteBuf compressionFrame(byte[] data) {
+ ByteBuf buf = Unpooled.buffer();
+ buf.writeInt(data.length);
+ buf.writeBytes(data);
+ return buf;
+ }
+
+ private void mockChannelFuture(boolean succeeded) {
+ when(future.isDone()).thenReturn(true);
+ when(future.isCancelled()).thenReturn(false);
+ when(future.isSuccess()).thenReturn(succeeded);
+ when(future.awaitUninterruptibly(anyLong(), any(TimeUnit.class))).thenReturn(true);
+ if (!succeeded) {
+ when(future.cause()).thenReturn(new Exception("fake"));
+ }
+
+ doAnswer(new Answer<ChannelFuture>() {
+ @Override
+ public ChannelFuture answer(InvocationOnMock invocation) throws Throwable {
+ ChannelFutureListener listener = (ChannelFutureListener) invocation.getArguments()[0];
+ listener.operationComplete(future);
+ return future;
+ }
+ }).when(future).addListener(any(ChannelFutureListener.class));
+ }
+}