Merge pull request #15 from yang-g/c++api

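Now the async end-to-end tests pass :) CallOpBuffer now records whether a message was received in its public got_message field, and ServerAsyncReader takes the response type as a template parameter so that Finish() can send the response message.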
diff --git a/include/grpc++/impl/call.h b/include/grpc++/impl/call.h
index 7aa22ee..af1c710 100644
--- a/include/grpc++/impl/call.h
+++ b/include/grpc++/impl/call.h
@@ -68,7 +68,7 @@
   void AddRecvInitialMetadata(
       std::multimap<grpc::string, grpc::string> *metadata);
   void AddSendMessage(const google::protobuf::Message &message);
-  void AddRecvMessage(google::protobuf::Message *message, bool* got_message);
+  void AddRecvMessage(google::protobuf::Message *message);  // Outcome is reported in got_message.
   void AddClientSendClose();
   void AddClientRecvStatus(std::multimap<grpc::string, grpc::string> *metadata,
                            Status *status);
@@ -84,6 +84,7 @@
   // Called by completion queue just prior to returning from Next() or Pluck()
   void FinalizeResult(void **tag, bool *status) override;
 
+  bool got_message = false;  // true iff a recv-message op in this buffer received a message.
  private:
   void *return_tag_ = nullptr;
   // Send initial metadata
@@ -98,7 +99,6 @@
   grpc_byte_buffer* send_message_buf_ = nullptr;
   // Recv message
   google::protobuf::Message* recv_message_ = nullptr;
-  bool* got_message_ = nullptr;
   grpc_byte_buffer* recv_message_buf_ = nullptr;
   // Client send close
   bool client_send_close_ = false;
diff --git a/include/grpc++/server_context.h b/include/grpc++/server_context.h
index 64091a4..423ebf2 100644
--- a/include/grpc++/server_context.h
+++ b/include/grpc++/server_context.h
@@ -45,7 +45,7 @@
 
 namespace grpc {
 
-template <class R>
+template <class W, class R>  // W: response type written by Finish(); R: request type read.
 class ServerAsyncReader;
 template <class W>
 class ServerAsyncWriter;
diff --git a/include/grpc++/stream.h b/include/grpc++/stream.h
index 359a272..6ee550b 100644
--- a/include/grpc++/stream.h
+++ b/include/grpc++/stream.h
@@ -119,10 +119,9 @@
       buf.AddRecvInitialMetadata(&context_->recv_initial_metadata_);
       context_->initial_metadata_received_ = true;
     }
-    bool got_message;
-    buf.AddRecvMessage(msg, &got_message);
+    buf.AddRecvMessage(msg);
     call_.PerformOps(&buf);
-    return cq_.Pluck(&buf) && got_message;
+    return cq_.Pluck(&buf) && buf.got_message;
   }
 
   virtual Status Finish() override {
@@ -174,11 +173,10 @@
   virtual Status Finish() override {
     CallOpBuffer buf;
     Status status;
-    bool got_message;
-    buf.AddRecvMessage(response_, &got_message);
+    buf.AddRecvMessage(response_);  // A failed RPC may carry no response message.
     buf.AddClientRecvStatus(&context_->trailing_metadata_, &status);
     call_.PerformOps(&buf);
-    GPR_ASSERT(cq_.Pluck(&buf) && got_message);
+    GPR_ASSERT((cq_.Pluck(&buf) && buf.got_message) || !status.IsOk());
     return status;
   }
 
@@ -225,10 +223,9 @@
       buf.AddRecvInitialMetadata(&context_->recv_initial_metadata_);
       context_->initial_metadata_received_ = true;
     }
-    bool got_message;
-    buf.AddRecvMessage(msg, &got_message);
+    buf.AddRecvMessage(msg);
     call_.PerformOps(&buf);
-    return cq_.Pluck(&buf) && got_message;
+    return cq_.Pluck(&buf) && buf.got_message;
   }
 
   virtual bool Write(const W& msg) override {
@@ -277,10 +274,9 @@
 
   virtual bool Read(R* msg) override {
     CallOpBuffer buf;
-    bool got_message;
-    buf.AddRecvMessage(msg, &got_message);
+    buf.AddRecvMessage(msg);
     call_->PerformOps(&buf);
-    return call_->cq()->Pluck(&buf) && got_message;
+    return call_->cq()->Pluck(&buf) && buf.got_message;
   }
 
  private:
@@ -338,10 +334,9 @@
 
   virtual bool Read(R* msg) override {
     CallOpBuffer buf;
-    bool got_message;
-    buf.AddRecvMessage(msg, &got_message);
+    buf.AddRecvMessage(msg);
     call_->PerformOps(&buf);
-    return call_->cq()->Pluck(&buf) && got_message;
+    return call_->cq()->Pluck(&buf) && buf.got_message;
   }
 
   virtual bool Write(const W& msg) override {
@@ -420,8 +415,7 @@
       read_buf_.AddRecvInitialMetadata(&context_->recv_initial_metadata_);
       context_->initial_metadata_received_ = true;
     }
-    bool ignore;
-    read_buf_.AddRecvMessage(msg, &ignore);
+    read_buf_.AddRecvMessage(msg);
     call_.PerformOps(&read_buf_);
   }
 
@@ -485,8 +479,7 @@
       finish_buf_.AddRecvInitialMetadata(&context_->recv_initial_metadata_);
       context_->initial_metadata_received_ = true;
     }
-    bool ignore;
-    finish_buf_.AddRecvMessage(response_, &ignore);
+    finish_buf_.AddRecvMessage(response_);
     finish_buf_.AddClientRecvStatus(&context_->trailing_metadata_, status);
     call_.PerformOps(&finish_buf_);
   }
@@ -494,7 +487,6 @@
  private:
   ClientContext* context_ = nullptr;
   google::protobuf::Message *const response_;
-  bool got_message_;
   Call call_;
   CallOpBuffer init_buf_;
   CallOpBuffer meta_buf_;
@@ -532,8 +524,7 @@
       read_buf_.AddRecvInitialMetadata(&context_->recv_initial_metadata_);
       context_->initial_metadata_received_ = true;
     }
-    bool ignore;
-    read_buf_.AddRecvMessage(msg, &ignore);
+    read_buf_.AddRecvMessage(msg);
     call_.PerformOps(&read_buf_);
   }
 
@@ -624,7 +615,7 @@
   CallOpBuffer finish_buf_;
 };
 
-template <class R>
+template <class W, class R>
 class ServerAsyncReader : public ServerAsyncStreamingInterface,
                           public AsyncReaderInterface<R> {
  public:
@@ -646,7 +637,24 @@
     call_.PerformOps(&read_buf_);
   }
 
-  void Finish(const Status& status, void* tag) {
+  void Finish(const W& msg, const Status& status, void* tag) {
+    finish_buf_.Reset(tag);
+    if (!ctx_->sent_initial_metadata_) {
+      finish_buf_.AddSendInitialMetadata(&ctx_->initial_metadata_);
+      ctx_->sent_initial_metadata_ = true;
+    }
+    // The response is dropped if the status is not OK.
+    if (status.IsOk()) {
+      finish_buf_.AddSendMessage(msg);
+    }
+    bool cancelled = false;  // NOTE(review): may dangle if AddServerRecvClose() retains &cancelled past return; confirm.
+    finish_buf_.AddServerRecvClose(&cancelled);
+    finish_buf_.AddServerSendStatus(&ctx_->trailing_metadata_, status);
+    call_.PerformOps(&finish_buf_);
+  }
+
+  void FinishWithError(const Status& status, void* tag) {
+    GPR_ASSERT(!status.IsOk());
     finish_buf_.Reset(tag);
     if (!ctx_->sent_initial_metadata_) {
       finish_buf_.AddSendInitialMetadata(&ctx_->initial_metadata_);
@@ -658,7 +666,6 @@
     call_.PerformOps(&finish_buf_);
   }
 
-
  private:
   void BindCall(Call *call) override { call_ = *call; }
 
diff --git a/src/compiler/cpp_generator.cc b/src/compiler/cpp_generator.cc
index a34aa4e..2a9895e 100644
--- a/src/compiler/cpp_generator.cc
+++ b/src/compiler/cpp_generator.cc
@@ -133,7 +133,7 @@
     temp.append("template <class OutMessage> class ClientWriter;\n");
     temp.append("template <class InMessage> class ServerReader;\n");
     temp.append("template <class OutMessage> class ClientAsyncWriter;\n");
-    temp.append("template <class InMessage> class ServerAsyncReader;\n");
+    temp.append("template <class OutMessage, class InMessage> class ServerAsyncReader;\n");  // Arity must match template <class W, class R> in grpc++/stream.h.
   }
   if (HasServerOnlyStreaming(file)) {
     temp.append("template <class InMessage> class ClientReader;\n");
@@ -267,7 +267,7 @@
     printer->Print(*vars,
                    "void Request$Method$("
                    "::grpc::ServerContext* context, "
-                   "::grpc::ServerAsyncReader< $Request$>* reader, "
+                   "::grpc::ServerAsyncReader< $Response$, $Request$>* reader, "
                    "::grpc::CompletionQueue* cq, void *tag);\n");
   } else if (ServerOnlyStreaming(method)) {
     printer->Print(*vars,
@@ -538,7 +538,7 @@
     printer->Print(*vars,
                    "void $Service$::AsyncService::Request$Method$("
                    "::grpc::ServerContext* context, "
-                   "::grpc::ServerAsyncReader< $Request$>* reader, "
+                   "::grpc::ServerAsyncReader< $Response$, $Request$>* reader, "
                    "::grpc::CompletionQueue* cq, void* tag) {\n");
     printer->Print(
         *vars,
diff --git a/src/cpp/client/client_unary_call.cc b/src/cpp/client/client_unary_call.cc
index b6bd81d..d68d7a9 100644
--- a/src/cpp/client/client_unary_call.cc
+++ b/src/cpp/client/client_unary_call.cc
@@ -53,21 +53,18 @@
   buf.AddSendInitialMetadata(context);
   buf.AddSendMessage(request);
   buf.AddRecvInitialMetadata(&context->recv_initial_metadata_);
-  bool got_message;
-  buf.AddRecvMessage(result, &got_message);
+  buf.AddRecvMessage(result);  // Whether a response arrived is reported in buf.got_message.
   buf.AddClientSendClose();
   buf.AddClientRecvStatus(&context->trailing_metadata_, &status);
   call.PerformOps(&buf);
-  GPR_ASSERT(cq.Pluck(&buf) && (got_message || !status.IsOk()));
+  GPR_ASSERT((cq.Pluck(&buf) && buf.got_message) || !status.IsOk());  // A failed RPC may carry no response.
   return status;
 }
 
 class ClientAsyncRequest final : public CallOpBuffer {
  public:
-  bool got_message = false;
   void FinalizeResult(void** tag, bool* status) override {
     CallOpBuffer::FinalizeResult(tag, status);
-    *status &= got_message;
     delete this;
   }
 };
@@ -83,7 +80,7 @@
   buf->AddSendInitialMetadata(context);
   buf->AddSendMessage(request);
   buf->AddRecvInitialMetadata(&context->recv_initial_metadata_);
-  buf->AddRecvMessage(result, &buf->got_message);
+  buf->AddRecvMessage(result);
   buf->AddClientSendClose();
   buf->AddClientRecvStatus(&context->trailing_metadata_, status);
   call.PerformOps(buf);
diff --git a/src/cpp/common/call.cc b/src/cpp/common/call.cc
index d706ec4..fe8859d 100644
--- a/src/cpp/common/call.cc
+++ b/src/cpp/common/call.cc
@@ -57,7 +57,7 @@
   }
 
   recv_message_ = nullptr;
-  got_message_ = nullptr;
+  got_message = false;
   if (recv_message_buf_) {
     grpc_byte_buffer_destroy(recv_message_buf_);
     recv_message_buf_ = nullptr;
@@ -142,9 +142,8 @@
   send_message_ = &message;
 }
 
-void CallOpBuffer::AddRecvMessage(google::protobuf::Message *message, bool* got_message) {
+void CallOpBuffer::AddRecvMessage(google::protobuf::Message *message) {
   recv_message_ = message;
-  got_message_ = got_message;
 }
 
 void CallOpBuffer::AddClientSendClose() {
@@ -256,12 +255,14 @@
   // Parse received message if any.
   if (recv_message_) {
     if (recv_message_buf_) {
-      *got_message_ = true;
+      got_message = true;
       *status = DeserializeProto(recv_message_buf_, recv_message_);
       grpc_byte_buffer_destroy(recv_message_buf_);
       recv_message_buf_ = nullptr;
     } else {
-      *got_message_ = false;
+      // No message buffer arrived: report the read as failed.
+      got_message = false;
+      *status = false;
     }
   }
   // Parse received status.
diff --git a/test/cpp/end2end/async_end2end_test.cc b/test/cpp/end2end/async_end2end_test.cc
index 52fb80e..b85aabf 100644
--- a/test/cpp/end2end/async_end2end_test.cc
+++ b/test/cpp/end2end/async_end2end_test.cc
@@ -64,9 +64,21 @@
 
 namespace {
 
+void* tag(int i) {
+  return (void*)(gpr_intptr)i;
+}
+
+void verify_ok(CompletionQueue* cq, int i, bool expect_ok) {  // Expects the next event on cq to be tag(i) with ok == expect_ok.
+  bool ok;
+  void* got_tag;
+  EXPECT_TRUE(cq->Next(&got_tag, &ok));
+  EXPECT_EQ(expect_ok, ok);
+  EXPECT_EQ(tag(i), got_tag);
+}
+
 class End2endTest : public ::testing::Test {
  protected:
-  End2endTest() : service_(&cq_) {}
+  End2endTest() : service_(&srv_cq_) {}
 
   void SetUp() override {
     int port = grpc_pick_unused_port_or_die();
@@ -86,20 +98,30 @@
     stub_.reset(grpc::cpp::test::util::TestService::NewStub(channel));
   }
 
-  CompletionQueue cq_;
+  void server_ok(int i) {
+    verify_ok(&srv_cq_, i, true);
+  }
+  void client_ok(int i) {
+    verify_ok(&cli_cq_, i , true);
+  }
+  void server_fail(int i) {
+    verify_ok(&srv_cq_, i, false);
+  }
+  void client_fail(int i) {
+    verify_ok(&cli_cq_, i, false);
+  }
+
+  CompletionQueue cli_cq_;  // Completions for client-side (stub) operations.
+  CompletionQueue srv_cq_;  // Completions for server-side operations; also passed to service_.
   std::unique_ptr<grpc::cpp::test::util::TestService::Stub> stub_;
   std::unique_ptr<Server> server_;
   grpc::cpp::test::util::TestService::AsyncService service_;
   std::ostringstream server_address_;
 };
 
-void* tag(int i) {
-  return (void*)(gpr_intptr)i;
-}
-
 TEST_F(End2endTest, SimpleRpc) {
   ResetStub();
-  
+
   EchoRequest send_request;
   EchoRequest recv_request;
   EchoResponse send_response;
@@ -110,34 +132,128 @@
   grpc::ServerAsyncResponseWriter<EchoResponse> response_writer(&srv_ctx);
 
   send_request.set_message("Hello");
-  stub_->Echo(&cli_ctx, send_request, &recv_response, &recv_status, &cq_, tag(1));
+  stub_->Echo(
+      &cli_ctx, send_request, &recv_response, &recv_status, &cli_cq_, tag(1));
 
-  service_.RequestEcho(&srv_ctx, &recv_request, &response_writer, &cq_, tag(2));
+  service_.RequestEcho(
+      &srv_ctx, &recv_request, &response_writer, &srv_cq_, tag(2));
 
-  void *got_tag;
-  bool ok;
-  EXPECT_TRUE(cq_.Next(&got_tag, &ok));
-  EXPECT_TRUE(ok);
-  EXPECT_EQ(got_tag, tag(2));
-  EXPECT_EQ(recv_request.message(), "Hello");
+  server_ok(2);
+  EXPECT_EQ(send_request.message(), recv_request.message());
 
   send_response.set_message(recv_request.message());
   response_writer.Finish(send_response, Status::OK, tag(3));
 
-  EXPECT_TRUE(cq_.Next(&got_tag, &ok));
-  EXPECT_TRUE(ok);
-  if (got_tag == tag(3)) {
-    EXPECT_TRUE(cq_.Next(&got_tag, &ok));
-    EXPECT_TRUE(ok);
-    EXPECT_EQ(got_tag, tag(1));
-  } else {
-    EXPECT_EQ(got_tag, tag(1));
-    EXPECT_TRUE(cq_.Next(&got_tag, &ok));
-    EXPECT_TRUE(ok);
-    EXPECT_EQ(got_tag, tag(3));
-  }
+  server_ok(3);
 
-  EXPECT_EQ(recv_response.message(), "Hello");
+  client_ok(1);
+
+  EXPECT_EQ(send_response.message(), recv_response.message());
+  EXPECT_TRUE(recv_status.IsOk());
+}
+
+TEST_F(End2endTest, SimpleClientStreaming) {
+  ResetStub();
+
+  EchoRequest send_request;
+  EchoRequest recv_request;
+  EchoResponse send_response;
+  EchoResponse recv_response;
+  Status recv_status;
+  ClientContext cli_ctx;
+  ServerContext srv_ctx;
+  ServerAsyncReader<EchoResponse, EchoRequest> srv_stream(&srv_ctx);
+
+  send_request.set_message("Hello");
+  ClientAsyncWriter<EchoRequest>* cli_stream =
+      stub_->RequestStream(&cli_ctx, &recv_response, &cli_cq_, tag(1));
+
+  service_.RequestRequestStream(
+      &srv_ctx, &srv_stream, &srv_cq_, tag(2));
+
+  server_ok(2);
+  client_ok(1);
+
+  cli_stream->Write(send_request, tag(3));
+  client_ok(3);
+
+  srv_stream.Read(&recv_request, tag(4));
+  server_ok(4);
+  EXPECT_EQ(send_request.message(), recv_request.message());
+
+  cli_stream->Write(send_request, tag(5));
+  client_ok(5);
+
+  srv_stream.Read(&recv_request, tag(6));
+  server_ok(6);
+
+  EXPECT_EQ(send_request.message(), recv_request.message());
+  cli_stream->WritesDone(tag(7));
+  client_ok(7);
+
+  srv_stream.Read(&recv_request, tag(8));
+  server_fail(8);  // The client already called WritesDone(), so this read is expected to fail.
+
+  send_response.set_message(recv_request.message());
+  srv_stream.Finish(send_response, Status::OK, tag(9));
+  server_ok(9);
+
+  cli_stream->Finish(&recv_status, tag(10));
+  client_ok(10);
+
+  EXPECT_EQ(send_response.message(), recv_response.message());
+  EXPECT_TRUE(recv_status.IsOk());
+}
+
+TEST_F(End2endTest, SimpleBidiStreaming) {
+  ResetStub();
+
+  EchoRequest send_request;
+  EchoRequest recv_request;
+  EchoResponse send_response;
+  EchoResponse recv_response;
+  Status recv_status;
+  ClientContext cli_ctx;
+  ServerContext srv_ctx;
+  ServerAsyncReaderWriter<EchoResponse, EchoRequest> srv_stream(&srv_ctx);
+
+  send_request.set_message("Hello");
+  ClientAsyncReaderWriter<EchoRequest, EchoResponse>* cli_stream =
+      stub_->BidiStream(&cli_ctx, &cli_cq_, tag(1));
+
+  service_.RequestBidiStream(
+      &srv_ctx, &srv_stream, &srv_cq_, tag(2));
+
+  server_ok(2);
+  client_ok(1);
+
+  cli_stream->Write(send_request, tag(3));
+  client_ok(3);
+
+  srv_stream.Read(&recv_request, tag(4));
+  server_ok(4);
+  EXPECT_EQ(send_request.message(), recv_request.message());
+
+  send_response.set_message(recv_request.message());
+  srv_stream.Write(send_response, tag(5));
+  server_ok(5);
+
+  cli_stream->Read(&recv_response, tag(6));
+  client_ok(6);
+  EXPECT_EQ(send_response.message(), recv_response.message());
+
+  cli_stream->WritesDone(tag(7));
+  client_ok(7);
+
+  srv_stream.Read(&recv_request, tag(8));
+  server_fail(8);  // The client already called WritesDone(), so this read is expected to fail.
+
+  srv_stream.Finish(Status::OK, tag(9));
+  server_ok(9);
+
+  cli_stream->Finish(&recv_status, tag(10));
+  client_ok(10);
+
   EXPECT_TRUE(recv_status.IsOk());
 }