change ServerAsyncReader API and add a simple clientstreaming test, it passes
diff --git a/include/grpc++/server_context.h b/include/grpc++/server_context.h
index 64091a4..423ebf2 100644
--- a/include/grpc++/server_context.h
+++ b/include/grpc++/server_context.h
@@ -45,7 +45,7 @@
namespace grpc {
-template <class R>
+template <class W, class R>
class ServerAsyncReader;
template <class W>
class ServerAsyncWriter;
diff --git a/include/grpc++/stream.h b/include/grpc++/stream.h
index ecc28f6..6ee550b 100644
--- a/include/grpc++/stream.h
+++ b/include/grpc++/stream.h
@@ -615,7 +615,7 @@
CallOpBuffer finish_buf_;
};
-template <class R>
+template <class W, class R>
class ServerAsyncReader : public ServerAsyncStreamingInterface,
public AsyncReaderInterface<R> {
public:
@@ -637,7 +637,24 @@
call_.PerformOps(&read_buf_);
}
- void Finish(const Status& status, void* tag) {
+ void Finish(const W& msg, const Status& status, void* tag) {
+ finish_buf_.Reset(tag);
+ if (!ctx_->sent_initial_metadata_) {
+ finish_buf_.AddSendInitialMetadata(&ctx_->initial_metadata_);
+ ctx_->sent_initial_metadata_ = true;
+ }
+ // The response is dropped if the status is not OK.
+ if (status.IsOk()) {
+ finish_buf_.AddSendMessage(msg);
+ }
+ bool cancelled = false;
+ finish_buf_.AddServerRecvClose(&cancelled);
+ finish_buf_.AddServerSendStatus(&ctx_->trailing_metadata_, status);
+ call_.PerformOps(&finish_buf_);
+ }
+
+ void FinishWithError(const Status& status, void* tag) {
+ GPR_ASSERT(!status.IsOk());
finish_buf_.Reset(tag);
if (!ctx_->sent_initial_metadata_) {
finish_buf_.AddSendInitialMetadata(&ctx_->initial_metadata_);
@@ -649,7 +666,6 @@
call_.PerformOps(&finish_buf_);
}
-
private:
void BindCall(Call *call) override { call_ = *call; }
diff --git a/src/compiler/cpp_generator.cc b/src/compiler/cpp_generator.cc
index a34aa4e..2a9895e 100644
--- a/src/compiler/cpp_generator.cc
+++ b/src/compiler/cpp_generator.cc
@@ -133,7 +133,7 @@
temp.append("template <class OutMessage> class ClientWriter;\n");
temp.append("template <class InMessage> class ServerReader;\n");
temp.append("template <class OutMessage> class ClientAsyncWriter;\n");
- temp.append("template <class InMessage> class ServerAsyncReader;\n");
+ temp.append("template <class OutMessage, class InMessage> class ServerAsyncReader;\n");
}
if (HasServerOnlyStreaming(file)) {
temp.append("template <class InMessage> class ClientReader;\n");
@@ -267,7 +267,7 @@
printer->Print(*vars,
"void Request$Method$("
"::grpc::ServerContext* context, "
- "::grpc::ServerAsyncReader< $Request$>* reader, "
+ "::grpc::ServerAsyncReader< $Response$, $Request$>* reader, "
"::grpc::CompletionQueue* cq, void *tag);\n");
} else if (ServerOnlyStreaming(method)) {
printer->Print(*vars,
@@ -538,7 +538,7 @@
printer->Print(*vars,
"void $Service$::AsyncService::Request$Method$("
"::grpc::ServerContext* context, "
- "::grpc::ServerAsyncReader< $Request$>* reader, "
+ "::grpc::ServerAsyncReader< $Response$, $Request$>* reader, "
"::grpc::CompletionQueue* cq, void* tag) {\n");
printer->Print(
*vars,
diff --git a/test/cpp/end2end/async_end2end_test.cc b/test/cpp/end2end/async_end2end_test.cc
index 62c7e40..b85aabf 100644
--- a/test/cpp/end2end/async_end2end_test.cc
+++ b/test/cpp/end2end/async_end2end_test.cc
@@ -110,6 +110,7 @@
void client_fail(int i) {
verify_ok(&cli_cq_, i, false);
}
+
CompletionQueue cli_cq_;
CompletionQueue srv_cq_;
std::unique_ptr<grpc::cpp::test::util::TestService::Stub> stub_;
@@ -151,6 +152,59 @@
EXPECT_TRUE(recv_status.IsOk());
}
+TEST_F(End2endTest, SimpleClientStreaming) {
+ ResetStub();
+
+ EchoRequest send_request;
+ EchoRequest recv_request;
+ EchoResponse send_response;
+ EchoResponse recv_response;
+ Status recv_status;
+ ClientContext cli_ctx;
+ ServerContext srv_ctx;
+ ServerAsyncReader<EchoResponse, EchoRequest> srv_stream(&srv_ctx);
+
+ send_request.set_message("Hello");
+ ClientAsyncWriter<EchoRequest>* cli_stream =
+ stub_->RequestStream(&cli_ctx, &recv_response, &cli_cq_, tag(1));
+
+ service_.RequestRequestStream(
+ &srv_ctx, &srv_stream, &srv_cq_, tag(2));
+
+ server_ok(2);
+ client_ok(1);
+
+ cli_stream->Write(send_request, tag(3));
+ client_ok(3);
+
+ srv_stream.Read(&recv_request, tag(4));
+ server_ok(4);
+ EXPECT_EQ(send_request.message(), recv_request.message());
+
+ cli_stream->Write(send_request, tag(5));
+ client_ok(5);
+
+ srv_stream.Read(&recv_request, tag(6));
+ server_ok(6);
+
+ EXPECT_EQ(send_request.message(), recv_request.message());
+ cli_stream->WritesDone(tag(7));
+ client_ok(7);
+
+ srv_stream.Read(&recv_request, tag(8));
+ server_fail(8);
+
+ send_response.set_message(recv_request.message());
+ srv_stream.Finish(send_response, Status::OK, tag(9));
+ server_ok(9);
+
+ cli_stream->Finish(&recv_status, tag(10));
+ client_ok(10);
+
+ EXPECT_EQ(send_response.message(), recv_response.message());
+ EXPECT_TRUE(recv_status.IsOk());
+}
+
TEST_F(End2endTest, SimpleBidiStreaming) {
ResetStub();