Merge pull request #2921 from yang-g/unknown_service

Use a sync service to handle requests to unknown services
diff --git a/include/grpc++/completion_queue.h b/include/grpc++/completion_queue.h
index 0523ab6..2f30211 100644
--- a/include/grpc++/completion_queue.h
+++ b/include/grpc++/completion_queue.h
@@ -63,6 +63,7 @@
 class ServerStreamingHandler;
 template <class ServiceType, class RequestType, class ResponseType>
 class BidiStreamingHandler;
+class UnknownMethodHandler;
 
 class ChannelInterface;
 class ClientContext;
@@ -138,6 +139,7 @@
   friend class ServerStreamingHandler;
   template <class ServiceType, class RequestType, class ResponseType>
   friend class BidiStreamingHandler;
+  friend class UnknownMethodHandler;
   friend class ::grpc::Server;
   friend class ::grpc::ServerContext;
   template <class InputMessage, class OutputMessage>
diff --git a/include/grpc++/impl/rpc_service_method.h b/include/grpc++/impl/rpc_service_method.h
index 3cfbef7..925801e 100644
--- a/include/grpc++/impl/rpc_service_method.h
+++ b/include/grpc++/impl/rpc_service_method.h
@@ -208,6 +208,21 @@
   ServiceType* service_;
 };
 
+// Handles a call to an unknown method by replying with UNIMPLEMENTED status.
+class UnknownMethodHandler : public MethodHandler {
+ public:
+  void RunHandler(const HandlerParameter& param) GRPC_FINAL {
+    Status status(StatusCode::UNIMPLEMENTED, "");  // no error message
+    CallOpSet<CallOpSendInitialMetadata, CallOpServerSendStatus> ops;
+    if (!param.server_context->sent_initial_metadata_) {
+      ops.SendInitialMetadata(param.server_context->initial_metadata_);
+    }
+    ops.ServerSendStatus(param.server_context->trailing_metadata_, status);
+    param.call->PerformOps(&ops);
+    param.call->cq()->Pluck(&ops);  // presumably waits for ops to complete
+  }
+};
+
 // Server side rpc method class
 class RpcServiceMethod : public RpcMethod {
  public:
diff --git a/include/grpc++/server.h b/include/grpc++/server.h
index 94ee0b6..8755b4b 100644
--- a/include/grpc++/server.h
+++ b/include/grpc++/server.h
@@ -228,6 +228,8 @@
   grpc::condition_variable callback_cv_;
 
   std::list<SyncRequest>* sync_methods_;
+  std::unique_ptr<RpcServiceMethod> unknown_method_;
+  bool has_generic_service_;
 
   // Pointer to the c grpc server.
   grpc_server* const server_;
diff --git a/include/grpc++/server_context.h b/include/grpc++/server_context.h
index 4f7fc54..8262dee 100644
--- a/include/grpc++/server_context.h
+++ b/include/grpc++/server_context.h
@@ -73,6 +73,7 @@
 class ServerStreamingHandler;
 template <class ServiceType, class RequestType, class ResponseType>
 class BidiStreamingHandler;
+class UnknownMethodHandler;
 
 class Call;
 class CallOpBuffer;
@@ -159,6 +160,7 @@
   friend class ServerStreamingHandler;
   template <class ServiceType, class RequestType, class ResponseType>
   friend class BidiStreamingHandler;
+  friend class UnknownMethodHandler;
   friend class ::grpc::ClientContext;
 
   // Prevent copying.
diff --git a/src/cpp/server/server.cc b/src/cpp/server/server.cc
index da2bc94..90f3854 100644
--- a/src/cpp/server/server.cc
+++ b/src/cpp/server/server.cc
@@ -67,11 +67,17 @@
         has_request_payload_(method->method_type() == RpcMethod::NORMAL_RPC ||
                              method->method_type() ==
                                  RpcMethod::SERVER_STREAMING),
+        call_details_(nullptr),
         cq_(nullptr) {
     grpc_metadata_array_init(&request_metadata_);
   }
 
-  ~SyncRequest() { grpc_metadata_array_destroy(&request_metadata_); }
+  ~SyncRequest() {
+    if (call_details_) {  // only allocated by Request() when tag_ is null
+      delete call_details_;
+    }
+    grpc_metadata_array_destroy(&request_metadata_);
+  }
 
   static SyncRequest* Wait(CompletionQueue* cq, bool* ok) {
     void* tag = nullptr;
@@ -94,17 +100,32 @@
   void Request(grpc_server* server, grpc_completion_queue* notify_cq) {
     GPR_ASSERT(cq_ && !in_flight_);
     in_flight_ = true;
-    GPR_ASSERT(GRPC_CALL_OK ==
-               grpc_server_request_registered_call(
-                   server, tag_, &call_, &deadline_, &request_metadata_,
-                   has_request_payload_ ? &request_payload_ : nullptr, cq_,
-                   notify_cq, this));
+    if (tag_) {  // registered method: request a call for that tag
+      GPR_ASSERT(GRPC_CALL_OK ==
+                 grpc_server_request_registered_call(
+                     server, tag_, &call_, &deadline_, &request_metadata_,
+                     has_request_payload_ ? &request_payload_ : nullptr, cq_,
+                     notify_cq, this));
+    } else {  // no tag: fall back to grpc_server_request_call
+      if (!call_details_) {  // allocate once, then reuse across requests
+        call_details_ = new grpc_call_details;
+        grpc_call_details_init(call_details_);
+      }
+      GPR_ASSERT(GRPC_CALL_OK == grpc_server_request_call(
+                                     server, &call_, call_details_,
+                                     &request_metadata_, cq_, notify_cq, this));
+    }
   }
 
   bool FinalizeResult(void** tag, bool* status) GRPC_OVERRIDE {
     if (!*status) {
       grpc_completion_queue_destroy(cq_);
     }
+    if (call_details_) {  // set only on the grpc_server_request_call path
+      deadline_ = call_details_->deadline;
+      grpc_call_details_destroy(call_details_);
+      grpc_call_details_init(call_details_);  // leave it ready for reuse
+    }
     return true;
   }
 
@@ -157,6 +178,7 @@
   bool in_flight_;
   const bool has_request_payload_;
   grpc_call* call_;
+  grpc_call_details* call_details_;
   gpr_timespec deadline_;
   grpc_metadata_array request_metadata_;
   grpc_byte_buffer* request_payload_;
@@ -183,6 +205,7 @@
       shutdown_(false),
       num_running_cb_(0),
       sync_methods_(new std::list<SyncRequest>),
+      has_generic_service_(false),
       server_(CreateServer(max_message_size)),
       thread_pool_(thread_pool),
       thread_pool_owned_(thread_pool_owned) {
@@ -223,7 +246,8 @@
   return true;
 }
 
-bool Server::RegisterAsyncService(const grpc::string *host, AsynchronousService* service) {
+bool Server::RegisterAsyncService(const grpc::string* host,
+                                  AsynchronousService* service) {
   GPR_ASSERT(service->server_ == nullptr &&
              "Can only register an asynchronous service against one server.");
   service->server_ = this;
@@ -245,6 +269,7 @@
   GPR_ASSERT(service->server_ == nullptr &&
              "Can only register an async generic service against one server.");
   service->server_ = this;
+  has_generic_service_ = true;
 }
 
 int Server::AddListeningPort(const grpc::string& addr,
@@ -258,6 +283,11 @@
   started_ = true;
   grpc_server_start(server_);
 
+  if (!has_generic_service_) {
+    unknown_method_.reset(new RpcServiceMethod(
+        "unknown", RpcMethod::BIDI_STREAMING, new UnknownMethodHandler));
+    sync_methods_->emplace_back(unknown_method_.get(), nullptr);
+  }
   // Start processing rpcs.
   if (!sync_methods_->empty()) {
     for (auto m = sync_methods_->begin(); m != sync_methods_->end(); m++) {
diff --git a/src/cpp/server/server_builder.cc b/src/cpp/server/server_builder.cc
index ce7682e..0911887 100644
--- a/src/cpp/server/server_builder.cc
+++ b/src/cpp/server/server_builder.cc
@@ -38,6 +38,7 @@
 #include <grpc++/impl/service_type.h>
 #include <grpc++/server.h>
 #include <grpc++/thread_pool_interface.h>
+#include <grpc++/fixed_size_thread_pool.h>
 
 namespace grpc {
 
@@ -100,6 +101,12 @@
     thread_pool_ = CreateDefaultThreadPool();
     thread_pool_owned = true;
   }
+  // Only async services are registered and no thread pool exists yet: create
+  // one so that requests to unknown services can still be handled.
+  if (!thread_pool_ && !generic_service_ && !async_services_.empty()) {
+    thread_pool_ = new FixedSizeThreadPool(1);
+    thread_pool_owned = true;
+  }
   std::unique_ptr<Server> server(
       new Server(thread_pool_, thread_pool_owned, max_message_size_));
   for (auto cq = cqs_.begin(); cq != cqs_.end(); ++cq) {
diff --git a/test/cpp/end2end/async_end2end_test.cc b/test/cpp/end2end/async_end2end_test.cc
index 9b53bdc..f00d19e 100644
--- a/test/cpp/end2end/async_end2end_test.cc
+++ b/test/cpp/end2end/async_end2end_test.cc
@@ -666,6 +666,28 @@
   EXPECT_TRUE(recv_status.ok());
 }
 
+TEST_F(AsyncEnd2endTest, UnimplementedRpc) {  // expects UNIMPLEMENTED status
+  std::shared_ptr<ChannelInterface> channel = CreateChannel(
+      server_address_.str(), InsecureCredentials(), ChannelArguments());
+  std::unique_ptr<grpc::cpp::test::util::UnimplementedService::Stub> stub;
+  stub =
+      std::move(grpc::cpp::test::util::UnimplementedService::NewStub(channel));
+  EchoRequest send_request;
+  EchoResponse recv_response;
+  Status recv_status;
+
+  ClientContext cli_ctx;
+  send_request.set_message("Hello");
+  std::unique_ptr<ClientAsyncResponseReader<EchoResponse> > response_reader(
+      stub->AsyncUnimplemented(&cli_ctx, send_request, cq_.get()));
+
+  response_reader->Finish(&recv_response, &recv_status, tag(4));
+  Verifier().Expect(4, false).Verify(cq_.get());  // presumably ok == false
+
+  EXPECT_EQ(StatusCode::UNIMPLEMENTED, recv_status.error_code());
+  EXPECT_EQ("", recv_status.error_message());  // no message is attached
+}
+
 }  // namespace
 }  // namespace testing
 }  // namespace grpc
diff --git a/test/cpp/end2end/end2end_test.cc b/test/cpp/end2end/end2end_test.cc
index 3766981..3827cdf 100644
--- a/test/cpp/end2end/end2end_test.cc
+++ b/test/cpp/end2end/end2end_test.cc
@@ -290,13 +290,17 @@
     if (proxy_server_) proxy_server_->Shutdown();
   }
 
-  void ResetStub(bool use_proxy) {
+  void ResetChannel() {  // recreates channel_ with SSL credentials
     SslCredentialsOptions ssl_opts = {test_root_cert, "", ""};
     ChannelArguments args;
     args.SetSslTargetNameOverride("foo.test.google.fr");
     args.SetString(GRPC_ARG_SECONDARY_USER_AGENT_STRING, "end2end_test");
     channel_ = CreateChannel(server_address_.str(), SslCredentials(ssl_opts),
                              args);
+  }
+
+  void ResetStub(bool use_proxy) {
+    ResetChannel();
     if (use_proxy) {
       proxy_service_.reset(new Proxy(channel_));
       int port = grpc_pick_unused_port_or_die();
@@ -930,6 +934,23 @@
   EXPECT_EQ(GRPC_CHANNEL_CONNECTING, channel_->GetState(false));
 }
 
+// A call to a service the server does not have must fail with UNIMPLEMENTED.
+TEST_F(End2endTest, NonExistingService) {
+  ResetChannel();
+  std::unique_ptr<grpc::cpp::test::util::UnimplementedService::Stub> stub;
+  stub =
+      std::move(grpc::cpp::test::util::UnimplementedService::NewStub(channel_));
+
+  EchoRequest request;
+  EchoResponse response;
+  request.set_message("Hello");
+
+  ClientContext context;
+  Status s = stub->Unimplemented(&context, request, &response);
+  EXPECT_EQ(StatusCode::UNIMPLEMENTED, s.error_code());
+  EXPECT_EQ("", s.error_message());
+}
+
 INSTANTIATE_TEST_CASE_P(End2end, End2endTest, ::testing::Values(false, true));
 
 }  // namespace testing
diff --git a/test/cpp/util/echo.proto b/test/cpp/util/echo.proto
index 6bb0931..8ea2f59 100644
--- a/test/cpp/util/echo.proto
+++ b/test/cpp/util/echo.proto
@@ -41,3 +41,7 @@
   rpc BidiStream(stream EchoRequest) returns (stream EchoResponse);
   rpc Unimplemented(EchoRequest) returns (EchoResponse);
 }
+
+service UnimplementedService {  // used by tests as a non-existing service
+  rpc Unimplemented(EchoRequest) returns (EchoResponse);
+}