Enforce sending headers before messages on server
ServerCall already had "headers must be sent before any messages, which
must be sent before closing," but the implementation did not enforce it
and our async server handler didn't obey.
The benefit of forcing sending headers first is that it removes the only
implicit call in our API and interceptors dealing just with metadata
don't need to override sendMessage. The implicit behavior was bug-prone
since it wasn't obvious you were forgetting that headers may not be
sent.
diff --git a/core/src/main/java/io/grpc/internal/AbstractServerStream.java b/core/src/main/java/io/grpc/internal/AbstractServerStream.java
index cfafa94..7503647 100644
--- a/core/src/main/java/io/grpc/internal/AbstractServerStream.java
+++ b/core/src/main/java/io/grpc/internal/AbstractServerStream.java
@@ -102,9 +102,8 @@
@Override
public final void writeMessage(InputStream message) {
- if (!headersSent) {
- writeHeaders(new Metadata());
- headersSent = true;
+ if (outboundPhase() != Phase.MESSAGE) {
+ throw new IllegalStateException("Messages are only permitted after headers and before close");
}
super.writeMessage(message);
}
diff --git a/core/src/test/java/io/grpc/internal/AbstractServerStreamTest.java b/core/src/test/java/io/grpc/internal/AbstractServerStreamTest.java
index 1286276..2443082 100644
--- a/core/src/test/java/io/grpc/internal/AbstractServerStreamTest.java
+++ b/core/src/test/java/io/grpc/internal/AbstractServerStreamTest.java
@@ -170,6 +170,7 @@
capturedHeaders.set(captured);
}
};
+ stream.writeHeaders(new Metadata());
stream.writeMessage(new ByteArrayInputStream(new byte[]{}));
@@ -204,6 +205,7 @@
sendCalled.set(true);
}
};
+ stream.writeHeaders(new Metadata());
stream.closeFramer();
stream.writeMessage(new ByteArrayInputStream(new byte[]{}));
@@ -220,6 +222,7 @@
sendCalled.set(true);
}
};
+ stream.writeHeaders(new Metadata());
stream.writeMessage(new ByteArrayInputStream(new byte[]{}));
// Force the message to be flushed
diff --git a/examples/src/main/java/io/grpc/examples/header/HeaderServerInterceptor.java b/examples/src/main/java/io/grpc/examples/header/HeaderServerInterceptor.java
index 6c8d07c..97a2b5a 100644
--- a/examples/src/main/java/io/grpc/examples/header/HeaderServerInterceptor.java
+++ b/examples/src/main/java/io/grpc/examples/header/HeaderServerInterceptor.java
@@ -37,7 +37,6 @@
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
-import io.grpc.Status;
import java.util.logging.Logger;
@@ -60,26 +59,10 @@
ServerCallHandler<ReqT, RespT> next) {
logger.info("header received from client:" + requestHeaders.toString());
return next.startCall(method, new SimpleForwardingServerCall<RespT>(call) {
- boolean sentHeaders = false;
-
@Override
public void sendHeaders(Metadata responseHeaders) {
responseHeaders.put(customHeadKey, "customRespondValue");
super.sendHeaders(responseHeaders);
- sentHeaders = true;
- }
-
- @Override
- public void sendMessage(RespT message) {
- if (!sentHeaders) {
- sendHeaders(new Metadata());
- }
- super.sendMessage(message);
- }
-
- @Override
- public void close(Status status, Metadata trailers) {
- super.close(status, trailers);
}
}, requestHeaders);
}
diff --git a/netty/src/test/java/io/grpc/netty/NettyClientStreamTest.java b/netty/src/test/java/io/grpc/netty/NettyClientStreamTest.java
index c866c18..ad0964f 100644
--- a/netty/src/test/java/io/grpc/netty/NettyClientStreamTest.java
+++ b/netty/src/test/java/io/grpc/netty/NettyClientStreamTest.java
@@ -80,7 +80,7 @@
* Tests for {@link NettyClientStream}.
*/
@RunWith(JUnit4.class)
-public class NettyClientStreamTest extends NettyStreamTestBase {
+public class NettyClientStreamTest extends NettyStreamTestBase<NettyClientStream> {
@Mock
protected ClientStreamListener listener;
@@ -372,6 +372,9 @@
}
@Override
+ protected void sendHeadersIfServer() {}
+
+ @Override
protected void closeStream() {
stream().cancel(Status.CANCELLED);
}
diff --git a/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java b/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java
index 83f634b..40de685 100644
--- a/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java
+++ b/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java
@@ -339,6 +339,7 @@
this.stream = stream;
this.method = method;
this.headers = headers;
+ stream.writeHeaders(new Metadata());
stream.request(1);
}
diff --git a/netty/src/test/java/io/grpc/netty/NettyServerStreamTest.java b/netty/src/test/java/io/grpc/netty/NettyServerStreamTest.java
index 55ac31e..2128ada 100644
--- a/netty/src/test/java/io/grpc/netty/NettyServerStreamTest.java
+++ b/netty/src/test/java/io/grpc/netty/NettyServerStreamTest.java
@@ -71,7 +71,7 @@
/** Unit tests for {@link NettyServerStream}. */
@RunWith(JUnit4.class)
-public class NettyServerStreamTest extends NettyStreamTestBase {
+public class NettyServerStreamTest extends NettyStreamTestBase<NettyServerStream> {
@Mock
protected ServerStreamListener serverListener;
@@ -92,13 +92,14 @@
@Test
public void writeMessageShouldSendResponse() throws Exception {
- byte[] msg = smallMessage();
- stream.writeMessage(new ByteArrayInputStream(msg));
- stream.flush();
+ stream.writeHeaders(new Metadata());
Http2Headers headers = new DefaultHttp2Headers()
.status(Utils.STATUS_OK)
.set(Utils.CONTENT_TYPE_HEADER, Utils.CONTENT_TYPE_GRPC);
verify(writeQueue).enqueue(new SendResponseHeadersCommand(STREAM_ID, headers, false), true);
+ byte[] msg = smallMessage();
+ stream.writeMessage(new ByteArrayInputStream(msg));
+ stream.flush();
verify(writeQueue).enqueue(eq(new SendGrpcFrameCommand(stream, messageFrame(MESSAGE), false)),
any(ChannelPromise.class),
eq(true));
@@ -267,6 +268,11 @@
}
@Override
+ protected void sendHeadersIfServer() {
+ stream.writeHeaders(new Metadata());
+ }
+
+ @Override
protected void closeStream() {
stream().close(Status.ABORTED, new Metadata());
}
diff --git a/netty/src/test/java/io/grpc/netty/NettyStreamTestBase.java b/netty/src/test/java/io/grpc/netty/NettyStreamTestBase.java
index 90099af..23d47ac 100644
--- a/netty/src/test/java/io/grpc/netty/NettyStreamTestBase.java
+++ b/netty/src/test/java/io/grpc/netty/NettyStreamTestBase.java
@@ -71,7 +71,7 @@
/**
* Base class for Netty stream unit tests.
*/
-public abstract class NettyStreamTestBase {
+public abstract class NettyStreamTestBase<T extends AbstractStream<Integer>> {
protected static final String MESSAGE = "hello world";
protected static final int STREAM_ID = 1;
@@ -99,7 +99,7 @@
@Mock
protected WriteQueue writeQueue;
- protected AbstractStream<Integer> stream;
+ protected T stream;
/** Set up for test. */
@Before
@@ -160,6 +160,7 @@
@Test
public void notifiedOnReadyAfterWriteCompletes() throws IOException {
+ sendHeadersIfServer();
assertTrue(stream.isReady());
byte[] msg = largeMessage();
// The future is set up to automatically complete, indicating that the write is done.
@@ -171,6 +172,7 @@
@Test
public void shouldBeReadyForDataAfterWritingSmallMessage() throws IOException {
+ sendHeadersIfServer();
// Make sure the writes don't complete so we "back up"
reset(future);
@@ -184,6 +186,7 @@
@Test
public void shouldNotBeReadyForDataAfterWritingLargeMessage() throws IOException {
+ sendHeadersIfServer();
// Make sure the writes don't complete so we "back up"
reset(future);
@@ -209,7 +212,9 @@
return largeMessage;
}
- protected abstract AbstractStream<Integer> createStream();
+ protected abstract T createStream();
+
+ protected abstract void sendHeadersIfServer();
protected abstract StreamListener listener();
diff --git a/stub/src/main/java/io/grpc/stub/ServerCalls.java b/stub/src/main/java/io/grpc/stub/ServerCalls.java
index 9870e8d..9cf4450 100644
--- a/stub/src/main/java/io/grpc/stub/ServerCalls.java
+++ b/stub/src/main/java/io/grpc/stub/ServerCalls.java
@@ -223,6 +223,7 @@
private static class ResponseObserver<RespT> implements StreamObserver<RespT> {
final ServerCall<RespT> call;
volatile boolean cancelled;
+ private boolean sentHeaders;
ResponseObserver(ServerCall<RespT> call) {
this.call = call;
@@ -233,6 +234,10 @@
if (cancelled) {
throw Status.CANCELLED.asRuntimeException();
}
+ if (!sentHeaders) {
+ call.sendHeaders(new Metadata());
+ sentHeaders = true;
+ }
call.sendMessage(response);
}
diff --git a/testing/src/main/java/io/grpc/testing/TestUtils.java b/testing/src/main/java/io/grpc/testing/TestUtils.java
index 53e3c39..1127cfb 100644
--- a/testing/src/main/java/io/grpc/testing/TestUtils.java
+++ b/testing/src/main/java/io/grpc/testing/TestUtils.java
@@ -90,21 +90,10 @@
ServerCallHandler<ReqT, RespT> next) {
return next.startCall(method,
new SimpleForwardingServerCall<RespT>(call) {
- boolean sentHeaders;
-
@Override
public void sendHeaders(Metadata responseHeaders) {
responseHeaders.merge(requestHeaders, keySet);
super.sendHeaders(responseHeaders);
- sentHeaders = true;
- }
-
- @Override
- public void sendMessage(RespT message) {
- if (!sentHeaders) {
- sendHeaders(new Metadata());
- }
- super.sendMessage(message);
}
@Override