Merge branch 'qps-stream' of https://github.com/vjpai/grpc into vjpai-qps-stream
diff --git a/test/cpp/qps/server_async.cc b/test/cpp/qps/server_async.cc
index 586b6e7..1ec1928 100644
--- a/test/cpp/qps/server_async.cc
+++ b/test/cpp/qps/server_async.cc
@@ -33,6 +33,7 @@
 
 #include <forward_list>
 #include <functional>
+#include <mutex>
 #include <sys/time.h>
 #include <sys/resource.h>
 #include <sys/signal.h>
@@ -48,6 +49,7 @@
 #include <grpc++/server_context.h>
 #include <grpc++/server_credentials.h>
 #include <grpc++/status.h>
+#include <grpc++/stream.h>
 #include <gtest/gtest.h>
 #include "src/cpp/server/thread_pool.h"
 #include "test/core/util/grpc_profiler.h"
@@ -63,7 +65,8 @@
 class AsyncQpsServerTest : public Server {
  public:
   AsyncQpsServerTest(const ServerConfig& config, int port)
-      : srv_cq_(), async_service_(&srv_cq_), server_(nullptr) {
+      : srv_cq_(), async_service_(&srv_cq_), server_(nullptr),
+        shutdown_(false) {
     char* server_address = NULL;
     gpr_join_host_port(&server_address, "::", port);
 
@@ -78,10 +81,16 @@
     using namespace std::placeholders;
     request_unary_ = std::bind(&TestService::AsyncService::RequestUnaryCall,
                                &async_service_, _1, _2, _3, &srv_cq_, _4);
+    request_streaming_ =
+      std::bind(&TestService::AsyncService::RequestStreamingCall,
+		&async_service_, _1, _2, &srv_cq_, _3);
     for (int i = 0; i < 100; i++) {
       contexts_.push_front(
           new ServerRpcContextUnaryImpl<SimpleRequest, SimpleResponse>(
-              request_unary_, UnaryCall));
+              request_unary_, ProcessRPC));
+      contexts_.push_front(
+          new ServerRpcContextStreamingImpl<SimpleRequest, SimpleResponse>(
+              request_streaming_, ProcessRPC));
     }
     for (int i = 0; i < config.threads(); i++) {
       threads_.push_back(std::thread([=]() {
@@ -89,14 +98,15 @@
         bool ok;
         void* got_tag;
         while (srv_cq_.Next(&got_tag, &ok)) {
-          if (ok) {
-            ServerRpcContext* ctx = detag(got_tag);
-            // The tag is a pointer to an RPC context to invoke
-            if (ctx->RunNextState() == false) {
-              // this RPC context is done, so refresh it
+	  ServerRpcContext* ctx = detag(got_tag);
+	  // The tag is a pointer to an RPC context to invoke
+	  if (ctx->RunNextState(ok) == false) {
+	    // this RPC context is done, so refresh it
+            std::lock_guard<std::mutex> g(shutdown_mutex_);
+            if (!shutdown_) {
               ctx->Reset();
             }
-          }
+	  }
         }
         return;
       }));
@@ -104,7 +114,11 @@
   }
   ~AsyncQpsServerTest() {
     server_->Shutdown();
-    srv_cq_.Shutdown();
+    {
+      std::lock_guard<std::mutex> g(shutdown_mutex_);
+      shutdown_ = true;
+      srv_cq_.Shutdown();
+    }
     for (auto& thr : threads_) {
       thr.join();
     }
@@ -119,7 +133,7 @@
    public:
     ServerRpcContext() {}
     virtual ~ServerRpcContext(){};
-    virtual bool RunNextState() = 0;  // do next state, return false if all done
+    virtual bool RunNextState(bool) = 0;  // next state, return false if done
     virtual void Reset() = 0;         // start this back at a clean state
   };
   static void* tag(ServerRpcContext* func) {
@@ -130,7 +144,7 @@
   }
 
   template <class RequestType, class ResponseType>
-  class ServerRpcContextUnaryImpl : public ServerRpcContext {
+  class ServerRpcContextUnaryImpl GRPC_FINAL : public ServerRpcContext {
    public:
     ServerRpcContextUnaryImpl(
         std::function<void(ServerContext*, RequestType*,
@@ -146,7 +160,7 @@
                       AsyncQpsServerTest::tag(this));
     }
     ~ServerRpcContextUnaryImpl() GRPC_OVERRIDE {}
-    bool RunNextState() GRPC_OVERRIDE { return (this->*next_state_)(); }
+    bool RunNextState(bool ok) GRPC_OVERRIDE {return (this->*next_state_)(ok);}
     void Reset() GRPC_OVERRIDE {
       srv_ctx_ = ServerContext();
       req_ = RequestType();
@@ -160,8 +174,11 @@
     }
 
    private:
-    bool finisher() { return false; }
-    bool invoker() {
+    bool finisher(bool) { return false; }
+    bool invoker(bool ok) {
+      if (!ok)
+	return false;
+
       ResponseType response;
 
       // Call the RPC processing function
@@ -174,7 +191,7 @@
     }
     ServerContext srv_ctx_;
     RequestType req_;
-    bool (ServerRpcContextUnaryImpl::*next_state_)();
+    bool (ServerRpcContextUnaryImpl::*next_state_)(bool);
     std::function<void(ServerContext*, RequestType*,
                        grpc::ServerAsyncResponseWriter<ResponseType>*, void*)>
         request_method_;
@@ -183,9 +200,88 @@
     grpc::ServerAsyncResponseWriter<ResponseType> response_writer_;
   };
 
-  static Status UnaryCall(const SimpleRequest* request,
-                          SimpleResponse* response) {
-    if (request->has_response_size() && request->response_size() > 0) {
+  // Server-side state machine for one bidirectional-streaming RPC.
+  // Completions advance request_done -> read_done -> write_done ->
+  // read_done ... until a read/write fails or the handler errors; then
+  // finish() is issued and finish_done ends the call.
+  template <class RequestType, class ResponseType>
+  class ServerRpcContextStreamingImpl GRPC_FINAL : public ServerRpcContext {
+   public:
+    // request_method requests the next incoming call into this context;
+    // invoke_method computes one response per request message read.
+    ServerRpcContextStreamingImpl(
+        std::function<void(ServerContext*, grpc::ServerAsyncReaderWriter<
+            ResponseType, RequestType>*, void*)> request_method,
+        std::function<grpc::Status(const RequestType*, ResponseType*)>
+            invoke_method)
+        : next_state_(&ServerRpcContextStreamingImpl::request_done),
+          request_method_(request_method),
+          invoke_method_(invoke_method),
+          stream_(&srv_ctx_) {
+      request_method_(&srv_ctx_, &stream_, AsyncQpsServerTest::tag(this));
+    }
+    ~ServerRpcContextStreamingImpl() GRPC_OVERRIDE {}
+    bool RunNextState(bool ok) GRPC_OVERRIDE {return (this->*next_state_)(ok);}
+    // Clears per-call state and requests a new call into this same context.
+    void Reset() GRPC_OVERRIDE {
+      srv_ctx_ = ServerContext();
+      req_ = RequestType();
+      stream_ =
+          grpc::ServerAsyncReaderWriter<ResponseType, RequestType>(&srv_ctx_);
+      next_state_ = &ServerRpcContextStreamingImpl::request_done;
+      request_method_(&srv_ctx_, &stream_, AsyncQpsServerTest::tag(this));
+    }
+
+   private:
+    // ok: a call was matched to this context; !ok: no call (presumably shutdown).
+    bool request_done(bool ok) {
+      if (!ok) return false;
+      stream_.Read(&req_, AsyncQpsServerTest::tag(this));
+      next_state_ = &ServerRpcContextStreamingImpl::read_done;
+      return true;
+    }
+    // ok: a request message is in req_; !ok: no more messages (e.g. WritesDone).
+    bool read_done(bool ok) {
+      if (!ok) return finish(Status::OK);
+      ResponseType response;
+      grpc::Status status = invoke_method_(&req_, &response);
+      // Report handler failures to the client instead of writing a reply.
+      if (status.code() != grpc::StatusCode::OK) return finish(status);
+      stream_.Write(response, AsyncQpsServerTest::tag(this));
+      next_state_ = &ServerRpcContextStreamingImpl::write_done;
+      return true;
+    }
+    // After a write: on success go back for the next request message,
+    // otherwise end the call.
+    bool write_done(bool ok) {
+      if (!ok) return finish(Status::OK);
+      stream_.Read(&req_, AsyncQpsServerTest::tag(this));
+      next_state_ = &ServerRpcContextStreamingImpl::read_done;
+      return true;
+    }
+    // Issues Finish() with the given status; its completion is delivered to
+    // finish_done. Always returns true (the context is still in use).
+    bool finish(const grpc::Status& status) {
+      stream_.Finish(status, AsyncQpsServerTest::tag(this));
+      next_state_ = &ServerRpcContextStreamingImpl::finish_done;
+      return true;
+    }
+    // Call complete; returning false lets the polling thread Reset() us.
+    bool finish_done(bool /*ok*/) { return false; }
+
+    ServerContext srv_ctx_;
+    RequestType req_;
+    bool (ServerRpcContextStreamingImpl::*next_state_)(bool);
+    std::function<void(ServerContext*, grpc::ServerAsyncReaderWriter<
+        ResponseType, RequestType>*, void*)> request_method_;
+    std::function<grpc::Status(const RequestType*, ResponseType*)>
+        invoke_method_;
+    grpc::ServerAsyncReaderWriter<ResponseType, RequestType> stream_;
+  };
+
+  static Status ProcessRPC(const SimpleRequest* request,
+			   SimpleResponse* response) {
+    if (request->response_size() > 0) {
       if (!SetPayload(request->response_type(), request->response_size(),
                       response->mutable_payload())) {
         return Status(grpc::StatusCode::INTERNAL, "Error creating payload.");
@@ -200,7 +296,13 @@
   std::function<void(ServerContext*, SimpleRequest*,
                      grpc::ServerAsyncResponseWriter<SimpleResponse>*, void*)>
       request_unary_;
+  std::function<void(ServerContext*, grpc::ServerAsyncReaderWriter<
+		     SimpleResponse,SimpleRequest>*, void*)>
+      request_streaming_;
   std::forward_list<ServerRpcContext*> contexts_;
+
+  std::mutex shutdown_mutex_;
+  bool shutdown_;
 };
 
 std::unique_ptr<Server> CreateAsyncServer(const ServerConfig& config,