merge with head
diff --git a/include/grpc++/impl/method_handler_impl.h b/include/grpc++/impl/method_handler_impl.h
new file mode 100644
index 0000000..8f7121b
--- /dev/null
+++ b/include/grpc++/impl/method_handler_impl.h
@@ -0,0 +1,203 @@
+/*
+ *
+ * Copyright 2015, Google Inc.
+ * All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are
+ * met:
+ *
+ * * Redistributions of source code must retain the above copyright
+ * notice, this list of conditions and the following disclaimer.
+ * * Redistributions in binary form must reproduce the above
+ * copyright notice, this list of conditions and the following disclaimer
+ * in the documentation and/or other materials provided with the
+ * distribution.
+ * * Neither the name of Google Inc. nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+ * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+ * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+ * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+ * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+ * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ */
+
+#ifndef GRPCXX_IMPL_METHOD_HANDLER_IMPL_H
+#define GRPCXX_IMPL_METHOD_HANDLER_IMPL_H
+
+#include <grpc++/impl/rpc_service_method.h>
+#include <grpc++/support/sync_stream.h>
+
+namespace grpc {
+
+// Adapts an application-provided unary handler (one request in, one response out) to MethodHandler.
+template <class ServiceType, class RequestType, class ResponseType>
+class RpcMethodHandler : public MethodHandler {
+ public:
+ RpcMethodHandler(
+ std::function<Status(ServiceType*, ServerContext*, const RequestType*,
+ ResponseType*)> func,
+ ServiceType* service)
+ : func_(func), service_(service) {}
+ // Deserializes the request, runs func_, then replies with metadata, response and status.
+ void RunHandler(const HandlerParameter& param) GRPC_FINAL {
+ RequestType req;
+ Status status = SerializationTraits<RequestType>::Deserialize(  // carries any deserialization error
+ param.request, &req, param.max_message_size);
+ ResponseType rsp;
+ if (status.ok()) {  // func_ only runs on a successfully deserialized request
+ status = func_(service_, param.server_context, &req, &rsp);
+ }
+ // One batch: initial metadata, the response (only when status is OK), then the final status.
+ GPR_ASSERT(!param.server_context->sent_initial_metadata_);  // initial metadata must still be unsent here
+ CallOpSet<CallOpSendInitialMetadata, CallOpSendMessage,
+ CallOpServerSendStatus> ops;
+ ops.SendInitialMetadata(param.server_context->initial_metadata_);
+ if (status.ok()) {
+ status = ops.SendMessage(rsp);  // a non-OK result here becomes the final status
+ }
+ ops.ServerSendStatus(param.server_context->trailing_metadata_, status);
+ param.call->PerformOps(&ops);  // start the batch
+ param.call->cq()->Pluck(&ops);  // wait on the call's completion queue for the &ops batch
+ }
+
+ private:
+ // Application-provided handler, invoked as func_(service_, context, &request, &response).
+ std::function<Status(ServiceType*, ServerContext*, const RequestType*,
+ ResponseType*)> func_;
+ // Instance passed as func_'s first argument; only stored, never deleted by this class.
+ ServiceType* service_;
+};
+
+// Adapts an application-provided client-streaming handler (request stream in, one response out).
+template <class ServiceType, class RequestType, class ResponseType>
+class ClientStreamingHandler : public MethodHandler {
+ public:
+ ClientStreamingHandler(
+ std::function<Status(ServiceType*, ServerContext*,
+ ServerReader<RequestType>*, ResponseType*)> func,
+ ServiceType* service)
+ : func_(func), service_(service) {}
+ // Lets func_ consume the request stream, then replies with metadata, response and status.
+ void RunHandler(const HandlerParameter& param) GRPC_FINAL {
+ ServerReader<RequestType> reader(param.call, param.server_context);  // func_ reads requests through this
+ ResponseType rsp;
+ Status status = func_(service_, param.server_context, &reader, &rsp);  // no up-front deserialization here
+ // One batch: initial metadata, the response (only when status is OK), then the final status.
+ GPR_ASSERT(!param.server_context->sent_initial_metadata_);  // initial metadata must still be unsent here
+ CallOpSet<CallOpSendInitialMetadata, CallOpSendMessage,
+ CallOpServerSendStatus> ops;
+ ops.SendInitialMetadata(param.server_context->initial_metadata_);
+ if (status.ok()) {
+ status = ops.SendMessage(rsp);  // a non-OK result here becomes the final status
+ }
+ ops.ServerSendStatus(param.server_context->trailing_metadata_, status);
+ param.call->PerformOps(&ops);  // start the batch
+ param.call->cq()->Pluck(&ops);  // wait on the call's completion queue for the &ops batch
+ }
+
+ private:
+ std::function<Status(ServiceType*, ServerContext*, ServerReader<RequestType>*,
+ ResponseType*)> func_;  // application-provided handler
+ ServiceType* service_;  // passed as func_'s first argument; never deleted by this class
+};
+
+// Adapts an application-provided server-streaming handler (one request in, response stream out).
+template <class ServiceType, class RequestType, class ResponseType>
+class ServerStreamingHandler : public MethodHandler {
+ public:
+ ServerStreamingHandler(
+ std::function<Status(ServiceType*, ServerContext*, const RequestType*,
+ ServerWriter<ResponseType>*)> func,
+ ServiceType* service)
+ : func_(func), service_(service) {}
+ // Deserializes the request, lets func_ stream responses, then sends the final status.
+ void RunHandler(const HandlerParameter& param) GRPC_FINAL {
+ RequestType req;
+ Status status = SerializationTraits<RequestType>::Deserialize(  // carries any deserialization error
+ param.request, &req, param.max_message_size);
+
+ if (status.ok()) {  // func_ is skipped entirely if the request failed to deserialize
+ ServerWriter<ResponseType> writer(param.call, param.server_context);  // scoped: gone before the status is sent
+ status = func_(service_, param.server_context, &req, &writer);
+ }
+ // Finish the call: no message op here, since responses were written through the writer.
+ CallOpSet<CallOpSendInitialMetadata, CallOpServerSendStatus> ops;
+ if (!param.server_context->sent_initial_metadata_) {  // only if not already sent during the handler
+ ops.SendInitialMetadata(param.server_context->initial_metadata_);
+ }
+ ops.ServerSendStatus(param.server_context->trailing_metadata_, status);
+ param.call->PerformOps(&ops);  // start the batch
+ param.call->cq()->Pluck(&ops);  // wait on the call's completion queue for the &ops batch
+ }
+
+ private:
+ std::function<Status(ServiceType*, ServerContext*, const RequestType*,
+ ServerWriter<ResponseType>*)> func_;  // application-provided handler
+ ServiceType* service_;  // passed as func_'s first argument; never deleted by this class
+};
+
+// Adapts an application-provided bidirectional-streaming handler to MethodHandler.
+template <class ServiceType, class RequestType, class ResponseType>
+class BidiStreamingHandler : public MethodHandler {
+ public:
+ BidiStreamingHandler(
+ std::function<Status(ServiceType*, ServerContext*,
+ ServerReaderWriter<ResponseType, RequestType>*)>
+ func,
+ ServiceType* service)
+ : func_(func), service_(service) {}
+ // Lets func_ read and write the stream, then sends the final status.
+ void RunHandler(const HandlerParameter& param) GRPC_FINAL {
+ ServerReaderWriter<ResponseType, RequestType> stream(param.call,  // func_ both reads and writes through this
+ param.server_context);
+ Status status = func_(service_, param.server_context, &stream);
+ // Finish the call: no message op here, since responses were written through the stream.
+ CallOpSet<CallOpSendInitialMetadata, CallOpServerSendStatus> ops;
+ if (!param.server_context->sent_initial_metadata_) {  // only if not already sent during the handler
+ ops.SendInitialMetadata(param.server_context->initial_metadata_);
+ }
+ ops.ServerSendStatus(param.server_context->trailing_metadata_, status);
+ param.call->PerformOps(&ops);  // start the batch
+ param.call->cq()->Pluck(&ops);  // wait on the call's completion queue for the &ops batch
+ }
+
+ private:
+ std::function<Status(ServiceType*, ServerContext*,
+ ServerReaderWriter<ResponseType, RequestType>*)> func_;  // application-provided handler
+ ServiceType* service_;  // passed as func_'s first argument; never deleted by this class
+};
+
+// Handles a call to an unknown method by replying with an UNIMPLEMENTED status (empty message).
+class UnknownMethodHandler : public MethodHandler {
+ public:
+ template <class T>  // T: any op set providing SendInitialMetadata() and ServerSendStatus()
+ static void FillOps(ServerContext* context, T* ops) {  // only fills *ops; the caller performs them
+ Status status(StatusCode::UNIMPLEMENTED, "");
+ if (!context->sent_initial_metadata_) {
+ ops->SendInitialMetadata(context->initial_metadata_);
+ context->sent_initial_metadata_ = true;  // record that initial metadata is now queued
+ }
+ ops->ServerSendStatus(context->trailing_metadata_, status);
+ }
+ // Synchronous path: queue the UNIMPLEMENTED reply, perform it and wait for completion.
+ void RunHandler(const HandlerParameter& param) GRPC_FINAL {
+ CallOpSet<CallOpSendInitialMetadata, CallOpServerSendStatus> ops;
+ FillOps(param.server_context, &ops);
+ param.call->PerformOps(&ops);  // start the batch
+ param.call->cq()->Pluck(&ops);  // wait on the call's completion queue for the &ops batch
+ }
+};
+
+} // namespace grpc
+
+#endif // GRPCXX_IMPL_METHOD_HANDLER_IMPL_H
diff --git a/include/grpc++/impl/rpc_service_method.h b/include/grpc++/impl/rpc_service_method.h
index b203c8f..3b47a4d 100644
--- a/include/grpc++/impl/rpc_service_method.h
+++ b/include/grpc++/impl/rpc_service_method.h
@@ -43,14 +43,11 @@
#include <grpc++/impl/rpc_method.h>
#include <grpc++/support/config.h>
#include <grpc++/support/status.h>
-#include <grpc++/support/sync_stream.h>
namespace grpc {
class ServerContext;
class StreamContextInterface;
-// TODO(rocking): we might need to split this file into multiple ones.
-
// Base class for running an RPC handler.
class MethodHandler {
public:
@@ -71,197 +68,25 @@
virtual void RunHandler(const HandlerParameter& param) = 0;
};
-// A wrapper class of an application provided rpc method handler.
-template <class ServiceType, class RequestType, class ResponseType>
-class RpcMethodHandler : public MethodHandler {
- public:
- RpcMethodHandler(
- std::function<Status(ServiceType*, ServerContext*, const RequestType*,
- ResponseType*)> func,
- ServiceType* service)
- : func_(func), service_(service) {}
-
- void RunHandler(const HandlerParameter& param) GRPC_FINAL {
- RequestType req;
- Status status = SerializationTraits<RequestType>::Deserialize(
- param.request, &req, param.max_message_size);
- ResponseType rsp;
- if (status.ok()) {
- status = func_(service_, param.server_context, &req, &rsp);
- }
-
- GPR_ASSERT(!param.server_context->sent_initial_metadata_);
- CallOpSet<CallOpSendInitialMetadata, CallOpSendMessage,
- CallOpServerSendStatus> ops;
- ops.SendInitialMetadata(param.server_context->initial_metadata_);
- if (status.ok()) {
- status = ops.SendMessage(rsp);
- }
- ops.ServerSendStatus(param.server_context->trailing_metadata_, status);
- param.call->PerformOps(&ops);
- param.call->cq()->Pluck(&ops);
- }
-
- private:
- // Application provided rpc handler function.
- std::function<Status(ServiceType*, ServerContext*, const RequestType*,
- ResponseType*)> func_;
- // The class the above handler function lives in.
- ServiceType* service_;
-};
-
-// A wrapper class of an application provided client streaming handler.
-template <class ServiceType, class RequestType, class ResponseType>
-class ClientStreamingHandler : public MethodHandler {
- public:
- ClientStreamingHandler(
- std::function<Status(ServiceType*, ServerContext*,
- ServerReader<RequestType>*, ResponseType*)> func,
- ServiceType* service)
- : func_(func), service_(service) {}
-
- void RunHandler(const HandlerParameter& param) GRPC_FINAL {
- ServerReader<RequestType> reader(param.call, param.server_context);
- ResponseType rsp;
- Status status = func_(service_, param.server_context, &reader, &rsp);
-
- GPR_ASSERT(!param.server_context->sent_initial_metadata_);
- CallOpSet<CallOpSendInitialMetadata, CallOpSendMessage,
- CallOpServerSendStatus> ops;
- ops.SendInitialMetadata(param.server_context->initial_metadata_);
- if (status.ok()) {
- status = ops.SendMessage(rsp);
- }
- ops.ServerSendStatus(param.server_context->trailing_metadata_, status);
- param.call->PerformOps(&ops);
- param.call->cq()->Pluck(&ops);
- }
-
- private:
- std::function<Status(ServiceType*, ServerContext*, ServerReader<RequestType>*,
- ResponseType*)> func_;
- ServiceType* service_;
-};
-
-// A wrapper class of an application provided server streaming handler.
-template <class ServiceType, class RequestType, class ResponseType>
-class ServerStreamingHandler : public MethodHandler {
- public:
- ServerStreamingHandler(
- std::function<Status(ServiceType*, ServerContext*, const RequestType*,
- ServerWriter<ResponseType>*)> func,
- ServiceType* service)
- : func_(func), service_(service) {}
-
- void RunHandler(const HandlerParameter& param) GRPC_FINAL {
- RequestType req;
- Status status = SerializationTraits<RequestType>::Deserialize(
- param.request, &req, param.max_message_size);
-
- if (status.ok()) {
- ServerWriter<ResponseType> writer(param.call, param.server_context);
- status = func_(service_, param.server_context, &req, &writer);
- }
-
- CallOpSet<CallOpSendInitialMetadata, CallOpServerSendStatus> ops;
- if (!param.server_context->sent_initial_metadata_) {
- ops.SendInitialMetadata(param.server_context->initial_metadata_);
- }
- ops.ServerSendStatus(param.server_context->trailing_metadata_, status);
- param.call->PerformOps(&ops);
- param.call->cq()->Pluck(&ops);
- }
-
- private:
- std::function<Status(ServiceType*, ServerContext*, const RequestType*,
- ServerWriter<ResponseType>*)> func_;
- ServiceType* service_;
-};
-
-// A wrapper class of an application provided bidi-streaming handler.
-template <class ServiceType, class RequestType, class ResponseType>
-class BidiStreamingHandler : public MethodHandler {
- public:
- BidiStreamingHandler(
- std::function<Status(ServiceType*, ServerContext*,
- ServerReaderWriter<ResponseType, RequestType>*)>
- func,
- ServiceType* service)
- : func_(func), service_(service) {}
-
- void RunHandler(const HandlerParameter& param) GRPC_FINAL {
- ServerReaderWriter<ResponseType, RequestType> stream(param.call,
- param.server_context);
- Status status = func_(service_, param.server_context, &stream);
-
- CallOpSet<CallOpSendInitialMetadata, CallOpServerSendStatus> ops;
- if (!param.server_context->sent_initial_metadata_) {
- ops.SendInitialMetadata(param.server_context->initial_metadata_);
- }
- ops.ServerSendStatus(param.server_context->trailing_metadata_, status);
- param.call->PerformOps(&ops);
- param.call->cq()->Pluck(&ops);
- }
-
- private:
- std::function<Status(ServiceType*, ServerContext*,
- ServerReaderWriter<ResponseType, RequestType>*)> func_;
- ServiceType* service_;
-};
-
-// Handle unknown method by returning UNIMPLEMENTED error.
-class UnknownMethodHandler : public MethodHandler {
- public:
- template <class T>
- static void FillOps(ServerContext* context, T* ops) {
- Status status(StatusCode::UNIMPLEMENTED, "");
- if (!context->sent_initial_metadata_) {
- ops->SendInitialMetadata(context->initial_metadata_);
- context->sent_initial_metadata_ = true;
- }
- ops->ServerSendStatus(context->trailing_metadata_, status);
- }
-
- void RunHandler(const HandlerParameter& param) GRPC_FINAL {
- CallOpSet<CallOpSendInitialMetadata, CallOpServerSendStatus> ops;
- FillOps(param.server_context, &ops);
- param.call->PerformOps(&ops);
- param.call->cq()->Pluck(&ops);
- }
-};
-
// Server side rpc method class
class RpcServiceMethod : public RpcMethod {
public:
// Takes ownership of the handler
RpcServiceMethod(const char* name, RpcMethod::RpcType type,
MethodHandler* handler)
- : RpcMethod(name, type), handler_(handler) {}
+ : RpcMethod(name, type), server_tag_(nullptr), handler_(handler) {}  // no server tag until set_server_tag()
- MethodHandler* handler() { return handler_.get(); }
+ void set_server_tag(void* tag) { server_tag_ = tag; }  // opaque tag; stored as-is, never freed here
+ void* server_tag() const { return server_tag_; }  // nullptr until set_server_tag() is called
+ // A null handler() means this is an async method (e.g. after ResetHandler()).
+ MethodHandler* handler() const { return handler_.get(); }
+ void ResetHandler() { handler_.reset(); }  // destroys the sync handler, leaving the method async-only
private:
+ void* server_tag_;  // see set_server_tag()
std::unique_ptr<MethodHandler> handler_;
};
-// This class contains all the method information for an rpc service. It is
-// used for registering a service on a grpc server.
-class RpcService {
- public:
- // Takes ownership.
- void AddMethod(RpcServiceMethod* method) { methods_.emplace_back(method); }
-
- RpcServiceMethod* GetMethod(int i) { return methods_[i].get(); }
- int GetMethodCount() const {
- // On win x64, int is only 32bit
- GPR_ASSERT(methods_.size() <= INT_MAX);
- return (int)methods_.size();
- }
-
- private:
- std::vector<std::unique_ptr<RpcServiceMethod>> methods_;
-};
-
} // namespace grpc
#endif // GRPCXX_IMPL_RPC_SERVICE_METHOD_H
diff --git a/include/grpc++/impl/service_type.h b/include/grpc++/impl/service_type.h
index 3b6ac1d..655aa91 100644
--- a/include/grpc++/impl/service_type.h
+++ b/include/grpc++/impl/service_type.h
@@ -34,6 +34,7 @@
#ifndef GRPCXX_IMPL_SERVICE_TYPE_H
#define GRPCXX_IMPL_SERVICE_TYPE_H
+#include <grpc++/impl/rpc_service_method.h>
#include <grpc++/impl/serialization_traits.h>
#include <grpc++/server.h>
#include <grpc++/support/config.h>
@@ -43,17 +44,10 @@
class Call;
class CompletionQueue;
-class RpcService;
class Server;
class ServerCompletionQueue;
class ServerContext;
-class SynchronousService {
- public:
- virtual ~SynchronousService() {}
- virtual RpcService* service() = 0;
-};
-
class ServerAsyncStreamingInterface {
public:
virtual ~ServerAsyncStreamingInterface() {}
@@ -65,15 +59,28 @@
virtual void BindCall(Call* call) = 0;
};
-class AsynchronousService {
+class Service {
public:
- AsynchronousService(const char** method_names, size_t method_count)
- : server_(nullptr),
- method_names_(method_names),
- method_count_(method_count),
- request_args_(nullptr) {}
+ Service() : server_(nullptr) {}  // server_ is set later by Server (a friend) — presumably at registration; confirm
+ virtual ~Service() {}  // owned RpcServiceMethods in methods_ are released with the vector
- ~AsynchronousService() { delete[] request_args_; }
+ bool has_async_methods() const {  // true if at least one method has no sync handler
+ for (auto it = methods_.begin(); it != methods_.end(); ++it) {
+ if ((*it)->handler() == nullptr) {  // null handler == async method (see MarkMethodAsync)
+ return true;
+ }
+ }
+ return false;
+ }
+
+ bool has_synchronous_methods() const {  // true if at least one method still has a sync handler
+ for (auto it = methods_.begin(); it != methods_.end(); ++it) {
+ if ((*it)->handler() != nullptr) {  // non-null handler == synchronous method
+ return true;
+ }
+ }
+ return false;
+ }
protected:
template <class Message>
@@ -81,41 +88,53 @@
ServerAsyncStreamingInterface* stream,
CompletionQueue* call_cq,
ServerCompletionQueue* notification_cq, void* tag) {
- server_->RequestAsyncCall(request_args_[index], context, stream, call_cq,
+ server_->RequestAsyncCall(methods_[index].get(), context, stream, call_cq,
notification_cq, tag, request);
}
- void RequestClientStreaming(int index, ServerContext* context,
- ServerAsyncStreamingInterface* stream,
- CompletionQueue* call_cq,
- ServerCompletionQueue* notification_cq,
- void* tag) {
- server_->RequestAsyncCall(request_args_[index], context, stream, call_cq,
+ void RequestAsyncClientStreaming(int index, ServerContext* context,  // index: position in methods_, not bounds-checked
+ ServerAsyncStreamingInterface* stream,
+ CompletionQueue* call_cq,
+ ServerCompletionQueue* notification_cq,
+ void* tag) {
+ server_->RequestAsyncCall(methods_[index].get(), context, stream, call_cq,  // server_ must be non-null (ctor leaves it null)
notification_cq, tag);
}
template <class Message>
- void RequestServerStreaming(int index, ServerContext* context,
- Message* request,
- ServerAsyncStreamingInterface* stream,
- CompletionQueue* call_cq,
- ServerCompletionQueue* notification_cq,
- void* tag) {
- server_->RequestAsyncCall(request_args_[index], context, stream, call_cq,
+ void RequestAsyncServerStreaming(int index, ServerContext* context,  // index: position in methods_, not bounds-checked
+ Message* request,  // filled with the incoming request message
+ ServerAsyncStreamingInterface* stream,
+ CompletionQueue* call_cq,
+ ServerCompletionQueue* notification_cq,
+ void* tag) {
+ server_->RequestAsyncCall(methods_[index].get(), context, stream, call_cq,  // server_ must be non-null (ctor leaves it null)
notification_cq, tag, request);
}
- void RequestBidiStreaming(int index, ServerContext* context,
- ServerAsyncStreamingInterface* stream,
- CompletionQueue* call_cq,
- ServerCompletionQueue* notification_cq, void* tag) {
- server_->RequestAsyncCall(request_args_[index], context, stream, call_cq,
+ void RequestAsyncBidiStreaming(int index, ServerContext* context,  // index: position in methods_, not bounds-checked
+ ServerAsyncStreamingInterface* stream,
+ CompletionQueue* call_cq,
+ ServerCompletionQueue* notification_cq,
+ void* tag) {
+ server_->RequestAsyncCall(methods_[index].get(), context, stream, call_cq,  // server_ must be non-null (ctor leaves it null)
notification_cq, tag);
}
+ void AddMethod(RpcServiceMethod* method) { methods_.emplace_back(method); }  // takes ownership; its index in methods_ is the method's Idx
+
+ void MarkMethodAsync(const grpc::string& method_name) {  // drops the named method's sync handler so it is served async
+ for (auto it = methods_.begin(); it != methods_.end(); ++it) {
+ if ((*it)->name() == method_name) {  // exact match; generated code passes the full "/pkg.Service/Method" path
+ (*it)->ResetHandler();
+ return;
+ }
+ }
+ abort();  // no such method: treated as a programming error, nothing is logged first
+ }
+
private:
friend class Server;
+
Server* server_;
- const char** const method_names_;
- size_t method_count_;
- void** request_args_;
+ std::vector<std::unique_ptr<RpcServiceMethod>> methods_;
};
} // namespace grpc
diff --git a/include/grpc++/server.h b/include/grpc++/server.h
index 644e66e..92d7a4b 100644
--- a/include/grpc++/server.h
+++ b/include/grpc++/server.h
@@ -40,6 +40,7 @@
#include <grpc++/completion_queue.h>
#include <grpc++/impl/call.h>
#include <grpc++/impl/grpc_library.h>
+#include <grpc++/impl/rpc_service_method.h>
#include <grpc++/impl/sync.h>
#include <grpc++/security/server_credentials.h>
#include <grpc++/support/channel_arguments.h>
@@ -51,13 +52,11 @@
namespace grpc {
-class AsynchronousService;
class GenericServerContext;
class AsyncGenericService;
-class RpcService;
-class RpcServiceMethod;
class ServerAsyncStreamingInterface;
class ServerContext;
+class Service;
class ThreadPoolInterface;
/// Models a gRPC server.
@@ -105,7 +104,7 @@
private:
friend class AsyncGenericService;
- friend class AsynchronousService;
+ friend class Service;
friend class ServerBuilder;
class SyncRequest;
@@ -123,12 +122,7 @@
/// Register a service. This call does not take ownership of the service.
/// The service must exist for the lifetime of the Server instance.
- bool RegisterService(const grpc::string* host, RpcService* service);
-
- /// Register an asynchronous service. This call does not take ownership of the
- /// service. The service must exist for the lifetime of the Server instance.
- bool RegisterAsyncService(const grpc::string* host,
- AsynchronousService* service);
+ bool RegisterService(const grpc::string* host, Service* service);
/// Register a generic service. This call does not take ownership of the
/// service. The service must exist for the lifetime of the Server instance.
@@ -265,21 +259,22 @@
class UnimplementedAsyncResponse;
template <class Message>
- void RequestAsyncCall(void* registered_method, ServerContext* context,
+ void RequestAsyncCall(RpcServiceMethod* method, ServerContext* context,  // assumes method->server_tag() was set at registration — confirm
ServerAsyncStreamingInterface* stream,
CompletionQueue* call_cq,
ServerCompletionQueue* notification_cq, void* tag,
Message* message) {
- new PayloadAsyncRequest<Message>(registered_method, this, context, stream,
- call_cq, notification_cq, tag, message);
+ new PayloadAsyncRequest<Message>(method->server_tag(), this, context,  // result not stored: request object presumably manages its own lifetime
+ stream, call_cq, notification_cq, tag,
+ message);
}
- void RequestAsyncCall(void* registered_method, ServerContext* context,
+ void RequestAsyncCall(RpcServiceMethod* method, ServerContext* context,  // assumes method->server_tag() was set at registration — confirm
ServerAsyncStreamingInterface* stream,
CompletionQueue* call_cq,
ServerCompletionQueue* notification_cq, void* tag) {
- new NoPayloadAsyncRequest(registered_method, this, context, stream, call_cq,
- notification_cq, tag);
+ new NoPayloadAsyncRequest(method->server_tag(), this, context, stream,  // result not stored: request object presumably manages its own lifetime
+ call_cq, notification_cq, tag);
}
void RequestAsyncGenericCall(GenericServerContext* context,
diff --git a/include/grpc++/server_builder.h b/include/grpc++/server_builder.h
index b324deb..86c7fec 100644
--- a/include/grpc++/server_builder.h
+++ b/include/grpc++/server_builder.h
@@ -44,14 +44,12 @@
namespace grpc {
class AsyncGenericService;
-class AsynchronousService;
class CompletionQueue;
class RpcService;
class Server;
class ServerCompletionQueue;
class ServerCredentials;
-class SynchronousService;
-class ThreadPoolInterface;
+class Service;
/// A builder class for the creation and startup of \a grpc::Server instances.
class ServerBuilder {
@@ -62,14 +60,7 @@
/// The service must exist for the lifetime of the \a Server instance returned
/// by \a BuildAndStart().
/// Matches requests with any :authority
- void RegisterService(SynchronousService* service);
-
- /// Register an asynchronous service.
- /// This call does not take ownership of the service or completion queue.
- /// The service and completion queuemust exist for the lifetime of the \a
- /// Server instance returned by \a BuildAndStart().
- /// Matches requests with any :authority
- void RegisterAsyncService(AsynchronousService* service);
+ void RegisterService(Service* service);
/// Register a generic service.
/// Matches requests with any :authority
@@ -79,15 +70,7 @@
/// The service must exist for the lifetime of the \a Server instance returned
/// by BuildAndStart().
/// Only matches requests with :authority \a host
- void RegisterService(const grpc::string& host, SynchronousService* service);
-
- /// Register an asynchronous service.
- /// This call does not take ownership of the service or completion queue.
- /// The service and completion queuemust exist for the lifetime of the \a
- /// Server instance returned by \a BuildAndStart().
- /// Only matches requests with :authority equal to \a host
- void RegisterAsyncService(const grpc::string& host,
- AsynchronousService* service);
+ void RegisterService(const grpc::string& host, Service* service);
/// Set max message size in bytes.
void SetMaxMessageSize(int max_message_size) {
@@ -132,26 +115,22 @@
};
typedef std::unique_ptr<grpc::string> HostString;
- template <class T>
struct NamedService {
- explicit NamedService(T* s) : service(s) {}
- NamedService(const grpc::string& h, T* s)
+ explicit NamedService(Service* s) : service(s) {}  // host stays null — presumably "match any :authority"
+ NamedService(const grpc::string& h, Service* s)  // restricts the service to :authority h
: host(new grpc::string(h)), service(s) {}
HostString host;
- T* service;
+ Service* service;  // not owned; must outlive the built Server (see RegisterService docs)
};
int max_message_size_;
grpc_compression_options compression_options_;
std::vector<std::unique_ptr<ServerBuilderOption>> options_;
- std::vector<std::unique_ptr<NamedService<RpcService>>> services_;
- std::vector<std::unique_ptr<NamedService<AsynchronousService>>>
- async_services_;
+ std::vector<std::unique_ptr<NamedService>> services_;
std::vector<Port> ports_;
std::vector<ServerCompletionQueue*> cqs_;
std::shared_ptr<ServerCredentials> creds_;
AsyncGenericService* generic_service_;
- ThreadPoolInterface* thread_pool_;
};
} // namespace grpc
diff --git a/src/compiler/cpp_generator.cc b/src/compiler/cpp_generator.cc
index 3c8ca8a..9d0d7eb 100644
--- a/src/compiler/cpp_generator.cc
+++ b/src/compiler/cpp_generator.cc
@@ -491,39 +491,114 @@
grpc_cpp_generator::ClassName(method->input_type(), true);
(*vars)["Response"] =
grpc_cpp_generator::ClassName(method->output_type(), true);
+ printer->Print(*vars, "template <class BaseClass>\n");
+ printer->Print(*vars,
+ "class WithAsyncMethod_$Method$ : public BaseClass {\n");
+ printer->Print(
+ " private:\n"
+ " void BaseClassMustBeDerivedFromService(Service *service) {}\n");
+ printer->Print(" public:\n");
+ printer->Indent();
+ printer->Print(*vars,
+ "WithAsyncMethod_$Method$() {\n"
+ " ::grpc::Service::MarkMethodAsync("
+ "\"/$Package$$Service$/$Method$\");\n"
+ "}\n");
+ printer->Print(*vars,
+ "~WithAsyncMethod_$Method$() GRPC_OVERRIDE {\n"
+ " BaseClassMustBeDerivedFromService(this);\n"
+ "}\n");
if (NoStreaming(method)) {
printer->Print(
*vars,
+ "// disable synchronous version of this method\n"
+ "::grpc::Status $Method$("
+ "::grpc::ServerContext* context, const $Request$* request, "
+ "$Response$* response) GRPC_FINAL GRPC_OVERRIDE {\n"
+ " abort();\n"
+ " return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, \"\");\n"
+ "}\n");
+ printer->Print(
+ *vars,
"void Request$Method$("
"::grpc::ServerContext* context, $Request$* request, "
"::grpc::ServerAsyncResponseWriter< $Response$>* response, "
"::grpc::CompletionQueue* new_call_cq, "
- "::grpc::ServerCompletionQueue* notification_cq, void *tag);\n");
+ "::grpc::ServerCompletionQueue* notification_cq, void *tag) {\n");
+ printer->Print(*vars,
+ " ::grpc::Service::RequestAsyncUnary($Idx$, context, "
+ "request, response, new_call_cq, notification_cq, tag);\n");
+ printer->Print("}\n");
} else if (ClientOnlyStreaming(method)) {
printer->Print(
*vars,
+ "// disable synchronous version of this method\n"
+ "::grpc::Status $Method$("
+ "::grpc::ServerContext* context, "
+ "::grpc::ServerReader< $Request$>* reader, "
+ "$Response$* response) GRPC_FINAL GRPC_OVERRIDE {\n"
+ " abort();\n"
+ " return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, \"\");\n"
+ "}\n");
+ printer->Print(
+ *vars,
"void Request$Method$("
"::grpc::ServerContext* context, "
"::grpc::ServerAsyncReader< $Response$, $Request$>* reader, "
"::grpc::CompletionQueue* new_call_cq, "
- "::grpc::ServerCompletionQueue* notification_cq, void *tag);\n");
+ "::grpc::ServerCompletionQueue* notification_cq, void *tag) {\n");
+ printer->Print(*vars,
+ " ::grpc::Service::RequestAsyncClientStreaming($Idx$, "
+ "context, reader, new_call_cq, notification_cq, tag);\n");
+ printer->Print("}\n");
} else if (ServerOnlyStreaming(method)) {
printer->Print(
*vars,
+ "// disable synchronous version of this method\n"
+ "::grpc::Status $Method$("
+ "::grpc::ServerContext* context, const $Request$* request, "
+ "::grpc::ServerWriter< $Response$>* writer) GRPC_FINAL GRPC_OVERRIDE "
+ "{\n"
+ " abort();\n"
+ " return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, \"\");\n"
+ "}\n");
+ printer->Print(
+ *vars,
"void Request$Method$("
"::grpc::ServerContext* context, $Request$* request, "
"::grpc::ServerAsyncWriter< $Response$>* writer, "
"::grpc::CompletionQueue* new_call_cq, "
- "::grpc::ServerCompletionQueue* notification_cq, void *tag);\n");
+ "::grpc::ServerCompletionQueue* notification_cq, void *tag) {\n");
+ printer->Print(
+ *vars,
+ " ::grpc::Service::RequestAsyncServerStreaming($Idx$, "
+ "context, request, writer, new_call_cq, notification_cq, tag);\n");
+ printer->Print("}\n");
} else if (BidiStreaming(method)) {
printer->Print(
*vars,
+ "// disable synchronous version of this method\n"
+ "::grpc::Status $Method$("
+ "::grpc::ServerContext* context, "
+ "::grpc::ServerReaderWriter< $Response$, $Request$>* stream) "
+ "GRPC_FINAL GRPC_OVERRIDE {\n"
+ " abort();\n"
+ " return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, \"\");\n"
+ "}\n");
+ printer->Print(
+ *vars,
"void Request$Method$("
"::grpc::ServerContext* context, "
"::grpc::ServerAsyncReaderWriter< $Response$, $Request$>* stream, "
"::grpc::CompletionQueue* new_call_cq, "
- "::grpc::ServerCompletionQueue* notification_cq, void *tag);\n");
+ "::grpc::ServerCompletionQueue* notification_cq, void *tag) {\n");
+ printer->Print(*vars,
+ " ::grpc::Service::RequestAsyncBidiStreaming($Idx$, "
+ "context, stream, new_call_cq, notification_cq, tag);\n");
+ printer->Print("}\n");
}
+ printer->Outdent();
+ printer->Print(*vars, "};\n");
}
void PrintHeaderService(grpc::protobuf::io::Printer *printer,
@@ -580,9 +655,9 @@
printer->Print("\n");
- // Server side - Synchronous
+ // Server side - base
printer->Print(
- "class Service : public ::grpc::SynchronousService {\n"
+ "class Service : public ::grpc::Service {\n"
" public:\n");
printer->Indent();
printer->Print("Service();\n");
@@ -590,26 +665,26 @@
for (int i = 0; i < service->method_count(); ++i) {
PrintHeaderServerMethodSync(printer, service->method(i), vars);
}
- printer->Print("::grpc::RpcService* service() GRPC_OVERRIDE GRPC_FINAL;\n");
printer->Outdent();
- printer->Print(
- " private:\n"
- " std::unique_ptr< ::grpc::RpcService> service_;\n");
printer->Print("};\n");
// Server side - Asynchronous
- printer->Print(
- "class AsyncService GRPC_FINAL : public ::grpc::AsynchronousService {\n"
- " public:\n");
- printer->Indent();
- (*vars)["MethodCount"] = as_string(service->method_count());
- printer->Print("explicit AsyncService();\n");
- printer->Print("~AsyncService() {};\n");
for (int i = 0; i < service->method_count(); ++i) {
+ (*vars)["Idx"] = as_string(i);
PrintHeaderServerMethodAsync(printer, service->method(i), vars);
}
- printer->Outdent();
- printer->Print("};\n");
+
+ printer->Print("typedef ");
+
+ for (int i = 0; i < service->method_count(); ++i) {
+ (*vars)["method_name"] = service->method(i)->name();
+ printer->Print(*vars, "WithAsyncMethod_$method_name$<");
+ }
+ printer->Print("Service");
+ for (int i = 0; i < service->method_count(); ++i) {
+ printer->Print(" >");
+ }
+ printer->Print(" AsyncService;\n");
printer->Outdent();
printer->Print("};\n");
@@ -623,6 +698,12 @@
grpc::protobuf::io::StringOutputStream output_stream(&output);
grpc::protobuf::io::Printer printer(&output_stream, '$');
std::map<grpc::string, grpc::string> vars;
+ // Package string is empty or ends with a dot. It is used to fully qualify
+ // method names.
+ vars["Package"] = file->package();
+ if (!file->package().empty()) {
+ vars["Package"].append(".");
+ }
if (!params.services_namespace.empty()) {
vars["services_namespace"] = params.services_namespace;
@@ -704,6 +785,7 @@
printer.Print(vars, "#include <grpc++/channel.h>\n");
printer.Print(vars, "#include <grpc++/impl/client_unary_call.h>\n");
+ printer.Print(vars, "#include <grpc++/impl/method_handler_impl.h>\n");
printer.Print(vars, "#include <grpc++/impl/rpc_service_method.h>\n");
printer.Print(vars, "#include <grpc++/impl/service_type.h>\n");
printer.Print(vars, "#include <grpc++/support/async_unary_call.h>\n");
@@ -889,69 +971,6 @@
}
}
-void PrintSourceServerAsyncMethod(
- grpc::protobuf::io::Printer *printer,
- const grpc::protobuf::MethodDescriptor *method,
- std::map<grpc::string, grpc::string> *vars) {
- (*vars)["Method"] = method->name();
- (*vars)["Request"] =
- grpc_cpp_generator::ClassName(method->input_type(), true);
- (*vars)["Response"] =
- grpc_cpp_generator::ClassName(method->output_type(), true);
- if (NoStreaming(method)) {
- printer->Print(
- *vars,
- "void $ns$$Service$::AsyncService::Request$Method$("
- "::grpc::ServerContext* context, "
- "$Request$* request, "
- "::grpc::ServerAsyncResponseWriter< $Response$>* response, "
- "::grpc::CompletionQueue* new_call_cq, "
- "::grpc::ServerCompletionQueue* notification_cq, void *tag) {\n");
- printer->Print(*vars,
- " AsynchronousService::RequestAsyncUnary($Idx$, context, "
- "request, response, new_call_cq, notification_cq, tag);\n");
- printer->Print("}\n\n");
- } else if (ClientOnlyStreaming(method)) {
- printer->Print(
- *vars,
- "void $ns$$Service$::AsyncService::Request$Method$("
- "::grpc::ServerContext* context, "
- "::grpc::ServerAsyncReader< $Response$, $Request$>* reader, "
- "::grpc::CompletionQueue* new_call_cq, "
- "::grpc::ServerCompletionQueue* notification_cq, void *tag) {\n");
- printer->Print(*vars,
- " AsynchronousService::RequestClientStreaming($Idx$, "
- "context, reader, new_call_cq, notification_cq, tag);\n");
- printer->Print("}\n\n");
- } else if (ServerOnlyStreaming(method)) {
- printer->Print(
- *vars,
- "void $ns$$Service$::AsyncService::Request$Method$("
- "::grpc::ServerContext* context, "
- "$Request$* request, "
- "::grpc::ServerAsyncWriter< $Response$>* writer, "
- "::grpc::CompletionQueue* new_call_cq, "
- "::grpc::ServerCompletionQueue* notification_cq, void *tag) {\n");
- printer->Print(
- *vars,
- " AsynchronousService::RequestServerStreaming($Idx$, "
- "context, request, writer, new_call_cq, notification_cq, tag);\n");
- printer->Print("}\n\n");
- } else if (BidiStreaming(method)) {
- printer->Print(
- *vars,
- "void $ns$$Service$::AsyncService::Request$Method$("
- "::grpc::ServerContext* context, "
- "::grpc::ServerAsyncReaderWriter< $Response$, $Request$>* stream, "
- "::grpc::CompletionQueue* new_call_cq, "
- "::grpc::ServerCompletionQueue* notification_cq, void *tag) {\n");
- printer->Print(*vars,
- " AsynchronousService::RequestBidiStreaming($Idx$, "
- "context, stream, new_call_cq, notification_cq, tag);\n");
- printer->Print("}\n\n");
- }
-}
-
void PrintSourceService(grpc::protobuf::io::Printer *printer,
const grpc::protobuf::ServiceDescriptor *service,
std::map<grpc::string, grpc::string> *vars) {
@@ -1006,32 +1025,8 @@
PrintSourceClientMethod(printer, service->method(i), vars);
}
- (*vars)["MethodCount"] = as_string(service->method_count());
- printer->Print(*vars,
- "$ns$$Service$::AsyncService::AsyncService() : "
- "::grpc::AsynchronousService("
- "$prefix$$Service$_method_names, $MethodCount$) "
- "{}\n\n");
-
- printer->Print(*vars,
- "$ns$$Service$::Service::Service() {\n"
- "}\n\n");
- printer->Print(*vars,
- "$ns$$Service$::Service::~Service() {\n"
- "}\n\n");
- for (int i = 0; i < service->method_count(); ++i) {
- (*vars)["Idx"] = as_string(i);
- PrintSourceServerMethod(printer, service->method(i), vars);
- PrintSourceServerAsyncMethod(printer, service->method(i), vars);
- }
- printer->Print(*vars,
- "::grpc::RpcService* $ns$$Service$::Service::service() {\n");
+ printer->Print(*vars, "$ns$$Service$::Service::Service() {\n");
printer->Indent();
- printer->Print(
- "if (service_) {\n"
- " return service_.get();\n"
- "}\n");
- printer->Print("service_ = std::unique_ptr< ::grpc::RpcService>(new ::grpc::RpcService());\n");
for (int i = 0; i < service->method_count(); ++i) {
const grpc::protobuf::MethodDescriptor *method = service->method(i);
(*vars)["Idx"] = as_string(i);
@@ -1043,7 +1038,7 @@
if (NoStreaming(method)) {
printer->Print(
*vars,
- "service_->AddMethod(new ::grpc::RpcServiceMethod(\n"
+ "AddMethod(new ::grpc::RpcServiceMethod(\n"
" $prefix$$Service$_method_names[$Idx$],\n"
" ::grpc::RpcMethod::NORMAL_RPC,\n"
" new ::grpc::RpcMethodHandler< $ns$$Service$::Service, "
@@ -1053,7 +1048,7 @@
} else if (ClientOnlyStreaming(method)) {
printer->Print(
*vars,
- "service_->AddMethod(new ::grpc::RpcServiceMethod(\n"
+ "AddMethod(new ::grpc::RpcServiceMethod(\n"
" $prefix$$Service$_method_names[$Idx$],\n"
" ::grpc::RpcMethod::CLIENT_STREAMING,\n"
" new ::grpc::ClientStreamingHandler< "
@@ -1062,7 +1057,7 @@
} else if (ServerOnlyStreaming(method)) {
printer->Print(
*vars,
- "service_->AddMethod(new ::grpc::RpcServiceMethod(\n"
+ "AddMethod(new ::grpc::RpcServiceMethod(\n"
" $prefix$$Service$_method_names[$Idx$],\n"
" ::grpc::RpcMethod::SERVER_STREAMING,\n"
" new ::grpc::ServerStreamingHandler< "
@@ -1071,7 +1066,7 @@
} else if (BidiStreaming(method)) {
printer->Print(
*vars,
- "service_->AddMethod(new ::grpc::RpcServiceMethod(\n"
+ "AddMethod(new ::grpc::RpcServiceMethod(\n"
" $prefix$$Service$_method_names[$Idx$],\n"
" ::grpc::RpcMethod::BIDI_STREAMING,\n"
" new ::grpc::BidiStreamingHandler< "
@@ -1079,9 +1074,15 @@
" std::mem_fn(&$ns$$Service$::Service::$Method$), this)));\n");
}
}
- printer->Print("return service_.get();\n");
printer->Outdent();
- printer->Print("}\n\n");
+ printer->Print(*vars, "}\n\n");
+ printer->Print(*vars,
+ "$ns$$Service$::Service::~Service() {\n"
+ "}\n\n");
+ for (int i = 0; i < service->method_count(); ++i) {
+ (*vars)["Idx"] = as_string(i);
+ PrintSourceServerMethod(printer, service->method(i), vars);
+ }
}
grpc::string GetSourceServices(const grpc::protobuf::FileDescriptor *file,
diff --git a/src/cpp/server/server.cc b/src/cpp/server/server.cc
index 878775b..898f68f 100644
--- a/src/cpp/server/server.cc
+++ b/src/cpp/server/server.cc
@@ -40,6 +40,7 @@
#include <grpc/support/log.h>
#include <grpc++/completion_queue.h>
#include <grpc++/generic/async_generic_service.h>
+#include <grpc++/impl/method_handler_impl.h>
#include <grpc++/impl/rpc_service_method.h>
#include <grpc++/impl/service_type.h>
#include <grpc++/server_context.h>
@@ -314,36 +315,28 @@
g_callbacks = callbacks;
}
-bool Server::RegisterService(const grpc::string* host, RpcService* service) {
- for (int i = 0; i < service->GetMethodCount(); ++i) {
- RpcServiceMethod* method = service->GetMethod(i);
+bool Server::RegisterService(const grpc::string* host, Service* service) {
+ bool has_async_methods = service->has_async_methods();
+ if (has_async_methods) {
+ GPR_ASSERT(service->server_ == nullptr &&
+ "Can only register an asynchronous service against one server.");
+ service->server_ = this;
+ }
+ for (auto it = service->methods_.begin(); it != service->methods_.end();
+ ++it) {
+ RpcServiceMethod* method = it->get();
void* tag = grpc_server_register_method(server_, method->name(),
host ? host->c_str() : nullptr);
- if (!tag) {
+ if (tag == nullptr) {
gpr_log(GPR_DEBUG, "Attempt to register %s multiple times",
method->name());
return false;
}
- sync_methods_->emplace_back(method, tag);
- }
- return true;
-}
-
-bool Server::RegisterAsyncService(const grpc::string* host,
- AsynchronousService* service) {
- GPR_ASSERT(service->server_ == nullptr &&
- "Can only register an asynchronous service against one server.");
- service->server_ = this;
- service->request_args_ = new void* [service->method_count_];
- for (size_t i = 0; i < service->method_count_; ++i) {
- void* tag = grpc_server_register_method(server_, service->method_names_[i],
- host ? host->c_str() : nullptr);
- if (!tag) {
- gpr_log(GPR_DEBUG, "Attempt to register %s multiple times",
- service->method_names_[i]);
- return false;
+ if (method->handler() == nullptr) {
+ method->set_server_tag(tag);
+ } else {
+ sync_methods_->emplace_back(method, tag);
}
- service->request_args_[i] = tag;
}
return true;
}
diff --git a/src/cpp/server/server_builder.cc b/src/cpp/server/server_builder.cc
index 26c0724..bd7dd76 100644
--- a/src/cpp/server/server_builder.cc
+++ b/src/cpp/server/server_builder.cc
@@ -43,7 +43,7 @@
namespace grpc {
ServerBuilder::ServerBuilder()
- : max_message_size_(-1), generic_service_(nullptr), thread_pool_(nullptr) {
+ : max_message_size_(-1), generic_service_(nullptr) {
grpc_compression_options_init(&compression_options_);
}
@@ -53,24 +53,13 @@
return std::unique_ptr<ServerCompletionQueue>(cq);
}
-void ServerBuilder::RegisterService(SynchronousService* service) {
- services_.emplace_back(new NamedService<RpcService>(service->service()));
-}
-
-void ServerBuilder::RegisterAsyncService(AsynchronousService* service) {
- async_services_.emplace_back(new NamedService<AsynchronousService>(service));
+void ServerBuilder::RegisterService(Service* service) {
+ services_.emplace_back(new NamedService(service));
}
void ServerBuilder::RegisterService(const grpc::string& addr,
- SynchronousService* service) {
- services_.emplace_back(
- new NamedService<RpcService>(addr, service->service()));
-}
-
-void ServerBuilder::RegisterAsyncService(const grpc::string& addr,
- AsynchronousService* service) {
- async_services_.emplace_back(
- new NamedService<AsynchronousService>(addr, service));
+ Service* service) {
+ services_.emplace_back(new NamedService(addr, service));
}
void ServerBuilder::RegisterAsyncGenericService(AsyncGenericService* service) {
@@ -96,14 +85,14 @@
}
std::unique_ptr<Server> ServerBuilder::BuildAndStart() {
- bool thread_pool_owned = false;
- if (!async_services_.empty() && !services_.empty()) {
- gpr_log(GPR_ERROR, "Mixing async and sync services is unsupported for now");
- return nullptr;
- }
- if (!thread_pool_ && !services_.empty()) {
- thread_pool_ = CreateDefaultThreadPool();
- thread_pool_owned = true;
+ std::unique_ptr<ThreadPoolInterface> thread_pool;
+ for (auto it = services_.begin(); it != services_.end(); ++it) {
+ if ((*it)->service->has_synchronous_methods()) {
+ if (thread_pool == nullptr) {
+ thread_pool.reset(CreateDefaultThreadPool());
+ break;
+ }
+ }
}
ChannelArguments args;
for (auto option = options_.begin(); option != options_.end(); ++option) {
@@ -115,7 +104,7 @@
args.SetInt(GRPC_COMPRESSION_ALGORITHM_STATE_ARG,
compression_options_.enabled_algorithms_bitset);
std::unique_ptr<Server> server(
- new Server(thread_pool_, thread_pool_owned, max_message_size_, args));
+ new Server(thread_pool.release(), true, max_message_size_, args));
for (auto cq = cqs_.begin(); cq != cqs_.end(); ++cq) {
grpc_server_register_completion_queue(server->server_, (*cq)->cq(),
nullptr);
@@ -126,13 +115,6 @@
return nullptr;
}
}
- for (auto service = async_services_.begin(); service != async_services_.end();
- service++) {
- if (!server->RegisterAsyncService((*service)->host.get(),
- (*service)->service)) {
- return nullptr;
- }
- }
if (generic_service_) {
server->RegisterAsyncGenericService(generic_service_);
}
diff --git a/test/cpp/end2end/async_end2end_test.cc b/test/cpp/end2end/async_end2end_test.cc
index cfda571..0616cc0 100644
--- a/test/cpp/end2end/async_end2end_test.cc
+++ b/test/cpp/end2end/async_end2end_test.cc
@@ -180,21 +180,11 @@
int port = grpc_pick_unused_port_or_die();
server_address_ << "localhost:" << port;
- // It is currently unsupported to mix sync and async services
- // in the same server, so first test that (for coverage)
- ServerBuilder build_bad;
- build_bad.AddListeningPort(server_address_.str(),
- grpc::InsecureServerCredentials());
- build_bad.RegisterAsyncService(&service_);
- grpc::testing::EchoTestService::Service sync_service;
- build_bad.RegisterService(&sync_service);
- GPR_ASSERT(build_bad.BuildAndStart() == nullptr);
-
// Setup server
ServerBuilder builder;
builder.AddListeningPort(server_address_.str(),
grpc::InsecureServerCredentials());
- builder.RegisterAsyncService(&service_);
+ builder.RegisterService(&service_);
cq_ = builder.AddCompletionQueue();
server_ = builder.BuildAndStart();
}
diff --git a/test/cpp/end2end/hybrid_end2end_test.cc b/test/cpp/end2end/hybrid_end2end_test.cc
new file mode 100644
index 0000000..24de363
--- /dev/null
+++ b/test/cpp/end2end/hybrid_end2end_test.cc
@@ -0,0 +1,228 @@
+/*
+ *
+ * Copyright 2015-2016, Google Inc.
+ * All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are
+ * met:
+ *
+ *     * Redistributions of source code must retain the above copyright
+ * notice, this list of conditions and the following disclaimer.
+ *     * Redistributions in binary form must reproduce the above
+ * copyright notice, this list of conditions and the following disclaimer
+ * in the documentation and/or other materials provided with the
+ * distribution.
+ *     * Neither the name of Google Inc. nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+ * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+ * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+ * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+ * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+ * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ */
+
+#include <memory>
+#include <sstream>
+#include <thread>
+#include <vector>
+
+#include <grpc++/channel.h>
+#include <grpc++/client_context.h>
+#include <grpc++/create_channel.h>
+#include <grpc++/server.h>
+#include <grpc++/server_builder.h>
+#include <grpc++/server_context.h>
+#include <grpc/grpc.h>
+#include <grpc/support/thd.h>
+#include <grpc/support/time.h>
+#include <gtest/gtest.h>
+
+#include "src/proto/grpc/testing/duplicate/echo_duplicate.grpc.pb.h"
+#include "src/proto/grpc/testing/echo.grpc.pb.h"
+#include "test/core/util/port.h"
+#include "test/core/util/test_config.h"
+#include "test/cpp/util/string_ref_helper.h"
+
+using grpc::testing::EchoRequest;
+using grpc::testing::EchoResponse;
+
+namespace grpc {
+namespace testing {
+
+namespace {
+
+// Encodes a small integer as an opaque completion-queue tag.
+void* tag(int i) { return reinterpret_cast<void*>(static_cast<intptr_t>(i)); }
+
+// Blocks for the next event on |cq|, expects it to carry tag(i) and returns
+// the event's ok bit. The calling thread must be the only one polling |cq|.
+bool VerifyReturnSuccess(CompletionQueue* cq, int i) {
+  void* got_tag = nullptr;
+  bool ok = false;
+  EXPECT_TRUE(cq->Next(&got_tag, &ok));
+  EXPECT_EQ(tag(i), got_tag);
+  return ok;
+}
+
+// Like VerifyReturnSuccess, and additionally checks the ok bit.
+void Verify(CompletionQueue* cq, int i, bool expect_ok) {
+  EXPECT_EQ(expect_ok, VerifyReturnSuccess(cq, i));
+}
+
+// Serves exactly one async Echo call on |cq|, echoing the request message.
+// Meant to run on its own thread; |cq| must not be shared with any other
+// handler, otherwise one thread could consume the other's events.
+template <class Service>
+void HandleEcho(Service* service, ServerCompletionQueue* cq) {
+  ServerContext srv_ctx;
+  grpc::ServerAsyncResponseWriter<EchoResponse> response_writer(&srv_ctx);
+  EchoRequest recv_request;
+  EchoResponse send_response;
+  service->RequestEcho(&srv_ctx, &recv_request, &response_writer, cq, cq,
+                       tag(1));
+  Verify(cq, 1, true);
+  send_response.set_message(recv_request.message());
+  response_writer.Finish(send_response, Status::OK, tag(2));
+  Verify(cq, 2, true);
+}
+
+// Serves exactly one async RequestStream call on |cq|: reads until the client
+// half-closes, then replies with the concatenation of all messages read.
+// Same threading constraints as HandleEcho.
+template <class Service>
+void HandleClientStreaming(Service* service, ServerCompletionQueue* cq) {
+  ServerContext srv_ctx;
+  EchoRequest recv_request;
+  EchoResponse send_response;
+  ServerAsyncReader<EchoResponse, EchoRequest> srv_stream(&srv_ctx);
+  service->RequestRequestStream(&srv_ctx, &srv_stream, cq, cq, tag(1));
+  Verify(cq, 1, true);
+  // Read() completes with ok == false once the client has half-closed.
+  srv_stream.Read(&recv_request, tag(2));
+  while (VerifyReturnSuccess(cq, 2)) {
+    send_response.mutable_message()->append(recv_request.message());
+    srv_stream.Read(&recv_request, tag(2));
+  }
+  srv_stream.Finish(send_response, Status::OK, tag(3));
+  Verify(cq, 3, true);
+}
+
+class HybridEnd2endTest : public ::testing::Test {
+ protected:
+  HybridEnd2endTest() {}
+
+  // Starts a server for |service| with kNumCqs completion queues, one per
+  // async handler thread, so no two threads ever poll the same queue.
+  void SetUpServer(::grpc::Service* service) {
+    int port = grpc_pick_unused_port_or_die();
+    server_address_ << "localhost:" << port;
+
+    // Setup server
+    ServerBuilder builder;
+    builder.AddListeningPort(server_address_.str(),
+                             grpc::InsecureServerCredentials());
+    builder.RegisterService(service);
+    for (int i = 0; i < kNumCqs; i++) {
+      cqs_.push_back(builder.AddCompletionQueue());
+    }
+    server_ = builder.BuildAndStart();
+  }
+
+  void TearDown() GRPC_OVERRIDE {
+    if (server_) {
+      server_->Shutdown();
+    }
+    void* ignored_tag;
+    bool ignored_ok;
+    for (auto it = cqs_.begin(); it != cqs_.end(); ++it) {
+      (*it)->Shutdown();
+      while ((*it)->Next(&ignored_tag, &ignored_ok))
+        ;
+    }
+  }
+
+  void ResetStub() {
+    std::shared_ptr<Channel> channel =
+        CreateChannel(server_address_.str(), InsecureChannelCredentials());
+    stub_ = grpc::testing::EchoTestService::NewStub(channel);
+  }
+
+  void TestAllMethods() {
+    SendEcho();
+    SendSimpleClientStreaming();
+  }
+
+  void SendEcho() {
+    EchoRequest send_request;
+    EchoResponse recv_response;
+    ClientContext cli_ctx;
+    send_request.set_message("Hello");
+    Status recv_status = stub_->Echo(&cli_ctx, send_request, &recv_response);
+    EXPECT_EQ(send_request.message(), recv_response.message());
+    EXPECT_TRUE(recv_status.ok());
+  }
+
+  // Streams five requests, half-closes, and expects the server to reply with
+  // the concatenation of everything it read.
+  void SendSimpleClientStreaming() {
+    EchoRequest send_request;
+    EchoResponse recv_response;
+    grpc::string expected_message;
+    ClientContext cli_ctx;
+    send_request.set_message("Hello");
+    auto stream = stub_->RequestStream(&cli_ctx, &recv_response);
+    for (int i = 0; i < 5; i++) {
+      EXPECT_TRUE(stream->Write(send_request));
+      expected_message.append(send_request.message());
+    }
+    EXPECT_TRUE(stream->WritesDone());
+    Status recv_status = stream->Finish();
+    EXPECT_EQ(expected_message, recv_response.message());
+    EXPECT_TRUE(recv_status.ok());
+  }
+
+  // One queue per async handler thread: [0] Echo, [1] RequestStream.
+  static const int kNumCqs = 2;
+  std::vector<std::unique_ptr<ServerCompletionQueue> > cqs_;
+  std::unique_ptr<grpc::testing::EchoTestService::Stub> stub_;
+  std::unique_ptr<Server> server_;
+  std::ostringstream server_address_;
+};
+
+// Echo and RequestStream are served asynchronously by dedicated handler
+// threads; all other methods keep the synchronous handlers inherited from
+// EchoTestService::Service.
+TEST_F(HybridEnd2endTest, AsyncEchoRequestStream) {
+  EchoTestService::WithAsyncMethod_Echo<
+      EchoTestService::WithAsyncMethod_RequestStream<EchoTestService::Service> >
+      service;
+  SetUpServer(&service);
+  ResetStub();
+  std::thread echo_handler_thread(
+      [this, &service] { HandleEcho(&service, cqs_[0].get()); });
+  std::thread request_stream_thread(
+      [this, &service] { HandleClientStreaming(&service, cqs_[1].get()); });
+  TestAllMethods();
+  echo_handler_thread.join();
+  request_stream_thread.join();
+}
+
+}  // namespace
+}  // namespace testing
+}  // namespace grpc
+
+int main(int argc, char** argv) {
+  grpc_test_init(argc, argv);
+  ::testing::InitGoogleTest(&argc, argv);
+  return RUN_ALL_TESTS();
+}
diff --git a/test/cpp/end2end/mixed_handlers_end2end_test.cc b/test/cpp/end2end/mixed_handlers_end2end_test.cc
new file mode 100644
index 0000000..a896ad2
--- /dev/null
+++ b/test/cpp/end2end/mixed_handlers_end2end_test.cc
@@ -0,0 +1,747 @@
+/*
+ *
+ * Copyright 2015-2016, Google Inc.
+ * All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are
+ * met:
+ *
+ * * Redistributions of source code must retain the above copyright
+ * notice, this list of conditions and the following disclaimer.
+ * * Redistributions in binary form must reproduce the above
+ * copyright notice, this list of conditions and the following disclaimer
+ * in the documentation and/or other materials provided with the
+ * distribution.
+ * * Neither the name of Google Inc. nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+ * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+ * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+ * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+ * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+ * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ */
+
+#include <memory>
+
+#include <grpc/grpc.h>
+#include <grpc/support/thd.h>
+#include <grpc/support/time.h>
+#include <grpc++/channel.h>
+#include <grpc++/client_context.h>
+#include <grpc++/create_channel.h>
+#include <grpc++/server.h>
+#include <grpc++/server_builder.h>
+#include <grpc++/server_context.h>
+#include <gtest/gtest.h>
+
+#include "test/core/util/port.h"
+#include "test/core/util/test_config.h"
+#include "src/proto/grpc/testing/duplicate/echo_duplicate.grpc.pb.h"
+#include "src/proto/grpc/testing/echo.grpc.pb.h"
+#include "test/cpp/util/string_ref_helper.h"
+
+using grpc::testing::EchoRequest;
+using grpc::testing::EchoResponse;
+using std::chrono::system_clock;
+
+namespace grpc {
+namespace testing {
+
+namespace {
+// Encodes a small integer as an opaque completion-queue tag.
+void* tag(int i) { return reinterpret_cast<void*>(static_cast<intptr_t>(i)); }
+// Collects expected (tag, ok) completions, then drains a queue to check them.
+class Verifier {
+ public:
+  explicit Verifier(bool spin) : spin_(spin) {}
+  Verifier& Expect(int i, bool expect_ok) {  // chainable
+    expectations_[tag(i)] = expect_ok;
+    return *this;
+  }
+  void Verify(CompletionQueue* cq) {  // waits until every expectation is seen
+    GPR_ASSERT(!expectations_.empty());
+    while (!expectations_.empty()) {
+      bool ok = false;
+      void* got_tag = nullptr;
+      if (spin_) {  // poll with a zero timeout instead of blocking
+        for (;;) {
+          auto r = cq->AsyncNext(&got_tag, &ok, gpr_time_0(GPR_CLOCK_REALTIME));
+          if (r == CompletionQueue::TIMEOUT) continue;
+          if (r == CompletionQueue::GOT_EVENT) break;
+          gpr_log(GPR_ERROR, "unexpected result from AsyncNext");
+          abort();
+        }
+      } else {
+        ASSERT_TRUE(cq->Next(&got_tag, &ok));  // otherwise got_tag is invalid
+      }
+      auto it = expectations_.find(got_tag);
+      ASSERT_TRUE(it != expectations_.end());  // never dereference end()
+      EXPECT_EQ(it->second, ok);
+      expectations_.erase(it);
+    }
+  }
+  void Verify(CompletionQueue* cq,
+              std::chrono::system_clock::time_point deadline) {
+    if (expectations_.empty()) {  // nothing may complete before deadline
+      bool ok = false;
+      void* got_tag = nullptr;
+      if (spin_) {
+        while (std::chrono::system_clock::now() < deadline) {
+          EXPECT_EQ(
+              cq->AsyncNext(&got_tag, &ok, gpr_time_0(GPR_CLOCK_REALTIME)),
+              CompletionQueue::TIMEOUT);
+        }
+      } else {
+        EXPECT_EQ(cq->AsyncNext(&got_tag, &ok, deadline),
+                  CompletionQueue::TIMEOUT);
+      }
+    } else {  // every expectation must be met before deadline
+      while (!expectations_.empty()) {
+        bool ok = false;
+        void* got_tag = nullptr;
+        if (spin_) {
+          for (;;) {
+            GPR_ASSERT(std::chrono::system_clock::now() < deadline);
+            auto r =
+                cq->AsyncNext(&got_tag, &ok, gpr_time_0(GPR_CLOCK_REALTIME));
+            if (r == CompletionQueue::TIMEOUT) continue;
+            if (r == CompletionQueue::GOT_EVENT) break;
+            gpr_log(GPR_ERROR, "unexpected result from AsyncNext");
+            abort();
+          }
+        } else {
+          ASSERT_EQ(cq->AsyncNext(&got_tag, &ok, deadline),
+                    CompletionQueue::GOT_EVENT);
+        }
+        auto it = expectations_.find(got_tag);
+        ASSERT_TRUE(it != expectations_.end());  // never dereference end()
+        EXPECT_EQ(it->second, ok);
+        expectations_.erase(it);
+      }
+    }
+  }
+
+ private:
+  std::map<void*, bool> expectations_;  // tag -> expected ok bit
+  bool spin_;  // poll AsyncNext with a zero timeout instead of blocking
+};
+// Fixture running an async EchoTestService server on one completion queue.
+class AsyncEnd2endTest : public ::testing::TestWithParam<bool> {
+ protected:
+  AsyncEnd2endTest() {}
+  // NOTE(review): only the fully async service is registered; nothing mixed.
+  void SetUp() GRPC_OVERRIDE {
+    int port = grpc_pick_unused_port_or_die();
+    server_address_ << "localhost:" << port;
+
+    // Setup server
+    ServerBuilder builder;
+    builder.AddListeningPort(server_address_.str(),
+                             grpc::InsecureServerCredentials());
+    builder.RegisterService(&service_);
+    cq_ = builder.AddCompletionQueue();
+    server_ = builder.BuildAndStart();
+  }
+  // Shuts the server down, then drains cq_ until Next() reports shutdown.
+  void TearDown() GRPC_OVERRIDE {
+    server_->Shutdown();
+    void* ignored_tag;
+    bool ignored_ok;
+    cq_->Shutdown();
+    while (cq_->Next(&ignored_tag, &ignored_ok))
+      ;
+  }
+  // (Re)creates the client stub for the server started in SetUp().
+  void ResetStub() {
+    std::shared_ptr<Channel> channel =
+        CreateChannel(server_address_.str(), InsecureChannelCredentials());
+    stub_ = grpc::testing::EchoTestService::NewStub(channel);
+  }
+  // Runs |num_rpcs| sequential async unary Echo calls, checking each reply.
+  void SendRpc(int num_rpcs) {
+    for (int i = 0; i < num_rpcs; i++) {
+      EchoRequest send_request;
+      EchoRequest recv_request;
+      EchoResponse send_response;
+      EchoResponse recv_response;
+      Status recv_status;
+
+      ClientContext cli_ctx;
+      ServerContext srv_ctx;
+      grpc::ServerAsyncResponseWriter<EchoResponse> response_writer(&srv_ctx);
+
+      send_request.set_message("Hello");
+      std::unique_ptr<ClientAsyncResponseReader<EchoResponse> > response_reader(
+          stub_->AsyncEcho(&cli_ctx, send_request, cq_.get()));
+
+      service_.RequestEcho(&srv_ctx, &recv_request, &response_writer, cq_.get(),
+                           cq_.get(), tag(2));
+
+      Verifier(GetParam()).Expect(2, true).Verify(cq_.get());
+      EXPECT_EQ(send_request.message(), recv_request.message());
+
+      send_response.set_message(recv_request.message());
+      response_writer.Finish(send_response, Status::OK, tag(3));
+      Verifier(GetParam()).Expect(3, true).Verify(cq_.get());
+
+      response_reader->Finish(&recv_response, &recv_status, tag(4));
+      Verifier(GetParam()).Expect(4, true).Verify(cq_.get());
+
+      EXPECT_EQ(send_response.message(), recv_response.message());
+      EXPECT_TRUE(recv_status.ok());
+    }
+  }
+  // GetParam() == true makes every Verifier poll instead of block.
+  std::unique_ptr<ServerCompletionQueue> cq_;
+  std::unique_ptr<grpc::testing::EchoTestService::Stub> stub_;
+  std::unique_ptr<Server> server_;
+  grpc::testing::EchoTestService::AsyncService service_;
+  std::ostringstream server_address_;
+};
+// One async unary Echo round trip.
+TEST_P(AsyncEnd2endTest, SimpleRpc) {
+  ResetStub();
+  SendRpc(1);
+}
+// Ten async unary Echo round trips, one after another, on the same stub.
+TEST_P(AsyncEnd2endTest, SequentialRpcs) {
+  ResetStub();
+  SendRpc(10);
+}
+
+// Unary RPC checked through Verifier's deadline-based (AsyncNext) overload.
+TEST_P(AsyncEnd2endTest, AsyncNextRpc) {
+  ResetStub();
+
+  EchoRequest send_request;
+  EchoRequest recv_request;
+  EchoResponse send_response;
+  EchoResponse recv_response;
+  Status recv_status;
+
+  ClientContext cli_ctx;
+  ServerContext srv_ctx;
+  grpc::ServerAsyncResponseWriter<EchoResponse> response_writer(&srv_ctx);
+
+  send_request.set_message("Hello");
+  std::unique_ptr<ClientAsyncResponseReader<EchoResponse> > response_reader(
+      stub_->AsyncEcho(&cli_ctx, send_request, cq_.get()));
+  // No server call has been requested yet, so polling "now" should time out.
+  std::chrono::system_clock::time_point time_now(
+      std::chrono::system_clock::now());
+  std::chrono::system_clock::time_point time_limit(
+      std::chrono::system_clock::now() + std::chrono::seconds(10));
+  Verifier(GetParam()).Verify(cq_.get(), time_now);
+  Verifier(GetParam()).Verify(cq_.get(), time_now);
+
+  service_.RequestEcho(&srv_ctx, &recv_request, &response_writer, cq_.get(),
+                       cq_.get(), tag(2));
+  // The incoming call must be delivered before time_limit (10s from start).
+  Verifier(GetParam()).Expect(2, true).Verify(cq_.get(), time_limit);
+  EXPECT_EQ(send_request.message(), recv_request.message());
+
+  send_response.set_message(recv_request.message());
+  response_writer.Finish(send_response, Status::OK, tag(3));
+  Verifier(GetParam())
+      .Expect(3, true)
+      .Verify(cq_.get(), std::chrono::system_clock::time_point::max());
+
+  response_reader->Finish(&recv_response, &recv_status, tag(4));
+  Verifier(GetParam())
+      .Expect(4, true)
+      .Verify(cq_.get(), std::chrono::system_clock::time_point::max());
+
+  EXPECT_EQ(send_response.message(), recv_response.message());
+  EXPECT_TRUE(recv_status.ok());
+}
+
+// Client streaming: two pings and a final pong.
+TEST_P(AsyncEnd2endTest, SimpleClientStreaming) {
+  ResetStub();
+
+  EchoRequest send_request;
+  EchoRequest recv_request;
+  EchoResponse send_response;
+  EchoResponse recv_response;
+  Status recv_status;
+  ClientContext cli_ctx;
+  ServerContext srv_ctx;
+  ServerAsyncReader<EchoResponse, EchoRequest> srv_stream(&srv_ctx);
+
+  send_request.set_message("Hello");
+  std::unique_ptr<ClientAsyncWriter<EchoRequest> > cli_stream(
+      stub_->AsyncRequestStream(&cli_ctx, &recv_response, cq_.get(), tag(1)));
+
+  service_.RequestRequestStream(&srv_ctx, &srv_stream, cq_.get(), cq_.get(),
+                                tag(2));
+  // tag(1): client stream started; tag(2): server received the call.
+  Verifier(GetParam()).Expect(2, true).Expect(1, true).Verify(cq_.get());
+
+  cli_stream->Write(send_request, tag(3));
+  Verifier(GetParam()).Expect(3, true).Verify(cq_.get());
+
+  srv_stream.Read(&recv_request, tag(4));
+  Verifier(GetParam()).Expect(4, true).Verify(cq_.get());
+  EXPECT_EQ(send_request.message(), recv_request.message());
+
+  cli_stream->Write(send_request, tag(5));
+  Verifier(GetParam()).Expect(5, true).Verify(cq_.get());
+
+  srv_stream.Read(&recv_request, tag(6));
+  Verifier(GetParam()).Expect(6, true).Verify(cq_.get());
+
+  EXPECT_EQ(send_request.message(), recv_request.message());
+  cli_stream->WritesDone(tag(7));
+  Verifier(GetParam()).Expect(7, true).Verify(cq_.get());
+  // After WritesDone(), the server's next Read() completes with ok == false.
+  srv_stream.Read(&recv_request, tag(8));
+  Verifier(GetParam()).Expect(8, false).Verify(cq_.get());
+
+  send_response.set_message(recv_request.message());
+  srv_stream.Finish(send_response, Status::OK, tag(9));
+  Verifier(GetParam()).Expect(9, true).Verify(cq_.get());
+
+  cli_stream->Finish(&recv_status, tag(10));
+  Verifier(GetParam()).Expect(10, true).Verify(cq_.get());
+
+  EXPECT_EQ(send_response.message(), recv_response.message());
+  EXPECT_TRUE(recv_status.ok());
+}
+
+// Server streaming: one ping, two pongs.
+TEST_P(AsyncEnd2endTest, SimpleServerStreaming) {
+  ResetStub();
+
+  EchoRequest send_request;
+  EchoRequest recv_request;
+  EchoResponse send_response;
+  EchoResponse recv_response;
+  Status recv_status;
+  ClientContext cli_ctx;
+  ServerContext srv_ctx;
+  ServerAsyncWriter<EchoResponse> srv_stream(&srv_ctx);
+
+  send_request.set_message("Hello");
+  std::unique_ptr<ClientAsyncReader<EchoResponse> > cli_stream(
+      stub_->AsyncResponseStream(&cli_ctx, send_request, cq_.get(), tag(1)));
+
+  service_.RequestResponseStream(&srv_ctx, &recv_request, &srv_stream,
+                                 cq_.get(), cq_.get(), tag(2));
+
+  Verifier(GetParam()).Expect(1, true).Expect(2, true).Verify(cq_.get());
+  EXPECT_EQ(send_request.message(), recv_request.message());
+
+  send_response.set_message(recv_request.message());
+  srv_stream.Write(send_response, tag(3));
+  Verifier(GetParam()).Expect(3, true).Verify(cq_.get());
+
+  cli_stream->Read(&recv_response, tag(4));
+  Verifier(GetParam()).Expect(4, true).Verify(cq_.get());
+  EXPECT_EQ(send_response.message(), recv_response.message());
+
+  srv_stream.Write(send_response, tag(5));
+  Verifier(GetParam()).Expect(5, true).Verify(cq_.get());
+
+  cli_stream->Read(&recv_response, tag(6));
+  Verifier(GetParam()).Expect(6, true).Verify(cq_.get());
+  EXPECT_EQ(send_response.message(), recv_response.message());
+
+  srv_stream.Finish(Status::OK, tag(7));
+  Verifier(GetParam()).Expect(7, true).Verify(cq_.get());
+  // Once the server has finished, the client's next Read() has ok == false.
+  cli_stream->Read(&recv_response, tag(8));
+  Verifier(GetParam()).Expect(8, false).Verify(cq_.get());
+
+  cli_stream->Finish(&recv_status, tag(9));
+  Verifier(GetParam()).Expect(9, true).Verify(cq_.get());
+
+  EXPECT_TRUE(recv_status.ok());
+}
+
+// Bidi streaming: one ping, one pong, then both sides close.
+TEST_P(AsyncEnd2endTest, SimpleBidiStreaming) {
+  ResetStub();
+
+  EchoRequest send_request;
+  EchoRequest recv_request;
+  EchoResponse send_response;
+  EchoResponse recv_response;
+  Status recv_status;
+  ClientContext cli_ctx;
+  ServerContext srv_ctx;
+  ServerAsyncReaderWriter<EchoResponse, EchoRequest> srv_stream(&srv_ctx);
+
+  send_request.set_message("Hello");
+  std::unique_ptr<ClientAsyncReaderWriter<EchoRequest, EchoResponse> >
+      cli_stream(stub_->AsyncBidiStream(&cli_ctx, cq_.get(), tag(1)));
+
+  service_.RequestBidiStream(&srv_ctx, &srv_stream, cq_.get(), cq_.get(),
+                             tag(2));
+
+  Verifier(GetParam()).Expect(1, true).Expect(2, true).Verify(cq_.get());
+
+  cli_stream->Write(send_request, tag(3));
+  Verifier(GetParam()).Expect(3, true).Verify(cq_.get());
+
+  srv_stream.Read(&recv_request, tag(4));
+  Verifier(GetParam()).Expect(4, true).Verify(cq_.get());
+  EXPECT_EQ(send_request.message(), recv_request.message());
+
+  send_response.set_message(recv_request.message());
+  srv_stream.Write(send_response, tag(5));
+  Verifier(GetParam()).Expect(5, true).Verify(cq_.get());
+
+  cli_stream->Read(&recv_response, tag(6));
+  Verifier(GetParam()).Expect(6, true).Verify(cq_.get());
+  EXPECT_EQ(send_response.message(), recv_response.message());
+
+  cli_stream->WritesDone(tag(7));
+  Verifier(GetParam()).Expect(7, true).Verify(cq_.get());
+  // The client's WritesDone() makes the server's next Read() fail (ok false).
+  srv_stream.Read(&recv_request, tag(8));
+  Verifier(GetParam()).Expect(8, false).Verify(cq_.get());
+
+  srv_stream.Finish(Status::OK, tag(9));
+  Verifier(GetParam()).Expect(9, true).Verify(cq_.get());
+
+  cli_stream->Finish(&recv_status, tag(10));
+  Verifier(GetParam()).Expect(10, true).Verify(cq_.get());
+
+  EXPECT_TRUE(recv_status.ok());
+}
+
+// Metadata tests: client initial, server initial, server trailing, and all three combined.
+TEST_P(AsyncEnd2endTest, ClientInitialMetadataRpc) {  // metadata added on the client is visible to the server
+ ResetStub();
+
+ EchoRequest send_request;
+ EchoRequest recv_request;
+ EchoResponse send_response;
+ EchoResponse recv_response;
+ Status recv_status;
+
+ ClientContext cli_ctx;
+ ServerContext srv_ctx;
+ grpc::ServerAsyncResponseWriter<EchoResponse> response_writer(&srv_ctx);
+
+ send_request.set_message("Hello");
+ std::pair<grpc::string, grpc::string> meta1("key1", "val1");
+ std::pair<grpc::string, grpc::string> meta2("key2", "val2");
+ cli_ctx.AddMetadata(meta1.first, meta1.second);  // added before AsyncEcho; checked via srv_ctx below
+ cli_ctx.AddMetadata(meta2.first, meta2.second);
+
+ std::unique_ptr<ClientAsyncResponseReader<EchoResponse> > response_reader(
+ stub_->AsyncEcho(&cli_ctx, send_request, cq_.get()));  // unary call; no start tag is queued
+
+ service_.RequestEcho(&srv_ctx, &recv_request, &response_writer, cq_.get(),
+ cq_.get(), tag(2));
+ Verifier(GetParam()).Expect(2, true).Verify(cq_.get());
+ EXPECT_EQ(send_request.message(), recv_request.message());
+ auto client_initial_metadata = srv_ctx.client_metadata();  // plain `auto`: copies the returned container
+ EXPECT_EQ(meta1.second,
+ ToString(client_initial_metadata.find(meta1.first)->second));  // NOTE(review): UB if the key is missing (find() == end())
+ EXPECT_EQ(meta2.second,
+ ToString(client_initial_metadata.find(meta2.first)->second));
+ EXPECT_GE(client_initial_metadata.size(), static_cast<size_t>(2));  // GE: presumably other entries may accompany ours -- confirm
+
+ send_response.set_message(recv_request.message());
+ response_writer.Finish(send_response, Status::OK, tag(3));
+
+ Verifier(GetParam()).Expect(3, true).Verify(cq_.get());
+
+ response_reader->Finish(&recv_response, &recv_status, tag(4));
+ Verifier(GetParam()).Expect(4, true).Verify(cq_.get());
+
+ EXPECT_EQ(send_response.message(), recv_response.message());
+ EXPECT_TRUE(recv_status.ok());
+}
+
+TEST_P(AsyncEnd2endTest, ServerInitialMetadataRpc) {  // server initial metadata is visible to the client
+ ResetStub();
+
+ EchoRequest send_request;
+ EchoRequest recv_request;
+ EchoResponse send_response;
+ EchoResponse recv_response;
+ Status recv_status;
+
+ ClientContext cli_ctx;
+ ServerContext srv_ctx;
+ grpc::ServerAsyncResponseWriter<EchoResponse> response_writer(&srv_ctx);
+
+ send_request.set_message("Hello");
+ std::pair<grpc::string, grpc::string> meta1("key1", "val1");
+ std::pair<grpc::string, grpc::string> meta2("key2", "val2");
+
+ std::unique_ptr<ClientAsyncResponseReader<EchoResponse> > response_reader(
+ stub_->AsyncEcho(&cli_ctx, send_request, cq_.get()));
+
+ service_.RequestEcho(&srv_ctx, &recv_request, &response_writer, cq_.get(),
+ cq_.get(), tag(2));
+ Verifier(GetParam()).Expect(2, true).Verify(cq_.get());
+ EXPECT_EQ(send_request.message(), recv_request.message());
+ srv_ctx.AddInitialMetadata(meta1.first, meta1.second);
+ srv_ctx.AddInitialMetadata(meta2.first, meta2.second);
+ response_writer.SendInitialMetadata(tag(3));  // sent explicitly, before the response/Finish below
+ Verifier(GetParam()).Expect(3, true).Verify(cq_.get());
+
+ response_reader->ReadInitialMetadata(tag(4));  // client-side read of the server's initial metadata
+ Verifier(GetParam()).Expect(4, true).Verify(cq_.get());
+ auto server_initial_metadata = cli_ctx.GetServerInitialMetadata();
+ EXPECT_EQ(meta1.second,
+ ToString(server_initial_metadata.find(meta1.first)->second));  // NOTE(review): UB if the key is missing (find() == end())
+ EXPECT_EQ(meta2.second,
+ ToString(server_initial_metadata.find(meta2.first)->second));
+ EXPECT_EQ(static_cast<size_t>(2), server_initial_metadata.size());  // exactly the two entries added above
+
+ send_response.set_message(recv_request.message());
+ response_writer.Finish(send_response, Status::OK, tag(5));
+ Verifier(GetParam()).Expect(5, true).Verify(cq_.get());
+
+ response_reader->Finish(&recv_response, &recv_status, tag(6));
+ Verifier(GetParam()).Expect(6, true).Verify(cq_.get());
+
+ EXPECT_EQ(send_response.message(), recv_response.message());
+ EXPECT_TRUE(recv_status.ok());
+}
+
+TEST_P(AsyncEnd2endTest, ServerTrailingMetadataRpc) {  // server trailing metadata is visible to the client
+ ResetStub();
+
+ EchoRequest send_request;
+ EchoRequest recv_request;
+ EchoResponse send_response;
+ EchoResponse recv_response;
+ Status recv_status;
+
+ ClientContext cli_ctx;
+ ServerContext srv_ctx;
+ grpc::ServerAsyncResponseWriter<EchoResponse> response_writer(&srv_ctx);
+
+ send_request.set_message("Hello");
+ std::pair<grpc::string, grpc::string> meta1("key1", "val1");
+ std::pair<grpc::string, grpc::string> meta2("key2", "val2");
+
+ std::unique_ptr<ClientAsyncResponseReader<EchoResponse> > response_reader(
+ stub_->AsyncEcho(&cli_ctx, send_request, cq_.get()));
+
+ service_.RequestEcho(&srv_ctx, &recv_request, &response_writer, cq_.get(),
+ cq_.get(), tag(2));
+ Verifier(GetParam()).Expect(2, true).Verify(cq_.get());
+ EXPECT_EQ(send_request.message(), recv_request.message());
+ response_writer.SendInitialMetadata(tag(3));  // no initial metadata is added by this test
+ Verifier(GetParam()).Expect(3, true).Verify(cq_.get());
+
+ send_response.set_message(recv_request.message());
+ srv_ctx.AddTrailingMetadata(meta1.first, meta1.second);  // added before Finish; read by the client after its Finish
+ srv_ctx.AddTrailingMetadata(meta2.first, meta2.second);
+ response_writer.Finish(send_response, Status::OK, tag(4));
+
+ Verifier(GetParam()).Expect(4, true).Verify(cq_.get());
+
+ response_reader->Finish(&recv_response, &recv_status, tag(5));
+ Verifier(GetParam()).Expect(5, true).Verify(cq_.get());
+ EXPECT_EQ(send_response.message(), recv_response.message());
+ EXPECT_TRUE(recv_status.ok());
+ auto server_trailing_metadata = cli_ctx.GetServerTrailingMetadata();  // read only after the client Finish completed
+ EXPECT_EQ(meta1.second,
+ ToString(server_trailing_metadata.find(meta1.first)->second));  // NOTE(review): UB if the key is missing (find() == end())
+ EXPECT_EQ(meta2.second,
+ ToString(server_trailing_metadata.find(meta2.first)->second));
+ EXPECT_EQ(static_cast<size_t>(2), server_trailing_metadata.size());  // exactly the two entries added above
+}
+
+TEST_P(AsyncEnd2endTest, MetadataRpc) {  // client initial + server initial + server trailing, mixing text and binary values
+ ResetStub();
+
+ EchoRequest send_request;
+ EchoRequest recv_request;
+ EchoResponse send_response;
+ EchoResponse recv_response;
+ Status recv_status;
+
+ ClientContext cli_ctx;
+ ServerContext srv_ctx;
+ grpc::ServerAsyncResponseWriter<EchoResponse> response_writer(&srv_ctx);
+
+ send_request.set_message("Hello");
+ std::pair<grpc::string, grpc::string> meta1("key1", "val1");
+ std::pair<grpc::string, grpc::string> meta2(
+ "key2-bin",  // "-bin" suffix: value is arbitrary bytes (gRPC binary metadata)
+ grpc::string("\xc0\xc1\xc2\xc3\xc4\xc5\xc6\xc7\xc8\xc9\xca\xcb\xcc", 13));  // explicit length: 13 raw bytes
+ std::pair<grpc::string, grpc::string> meta3("key3", "val3");
+ std::pair<grpc::string, grpc::string> meta6(  // NOTE(review): meta6 holds "key4-bin" while meta4 holds "key6-bin" -- names look swapped
+ "key4-bin",
+ grpc::string("\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d",
+ 14));
+ std::pair<grpc::string, grpc::string> meta5("key5", "val5");
+ std::pair<grpc::string, grpc::string> meta4(
+ "key6-bin",
+ grpc::string(
+ "\xe0\xe1\xe2\xe3\xe4\xe5\xe6\xe7\xe8\xe9\xea\xeb\xec\xed\xee", 15));
+
+ cli_ctx.AddMetadata(meta1.first, meta1.second);  // client initial metadata: one text and one binary entry
+ cli_ctx.AddMetadata(meta2.first, meta2.second);
+
+ std::unique_ptr<ClientAsyncResponseReader<EchoResponse> > response_reader(
+ stub_->AsyncEcho(&cli_ctx, send_request, cq_.get()));
+
+ service_.RequestEcho(&srv_ctx, &recv_request, &response_writer, cq_.get(),
+ cq_.get(), tag(2));
+ Verifier(GetParam()).Expect(2, true).Verify(cq_.get());
+ EXPECT_EQ(send_request.message(), recv_request.message());
+ auto client_initial_metadata = srv_ctx.client_metadata();
+ EXPECT_EQ(meta1.second,
+ ToString(client_initial_metadata.find(meta1.first)->second));  // NOTE(review): UB if the key is missing (find() == end())
+ EXPECT_EQ(meta2.second,
+ ToString(client_initial_metadata.find(meta2.first)->second));
+ EXPECT_GE(client_initial_metadata.size(), static_cast<size_t>(2));
+
+ srv_ctx.AddInitialMetadata(meta3.first, meta3.second);  // server initial metadata: meta3 (text) + meta4 (binary)
+ srv_ctx.AddInitialMetadata(meta4.first, meta4.second);
+ response_writer.SendInitialMetadata(tag(3));
+ Verifier(GetParam()).Expect(3, true).Verify(cq_.get());
+ response_reader->ReadInitialMetadata(tag(4));
+ Verifier(GetParam()).Expect(4, true).Verify(cq_.get());
+ auto server_initial_metadata = cli_ctx.GetServerInitialMetadata();
+ EXPECT_EQ(meta3.second,
+ ToString(server_initial_metadata.find(meta3.first)->second));
+ EXPECT_EQ(meta4.second,
+ ToString(server_initial_metadata.find(meta4.first)->second));
+ EXPECT_GE(server_initial_metadata.size(), static_cast<size_t>(2));  // NOTE(review): ServerInitialMetadataRpc asserts EQ 2; confirm why GE here
+
+ send_response.set_message(recv_request.message());
+ srv_ctx.AddTrailingMetadata(meta5.first, meta5.second);  // server trailing metadata: meta5 (text) + meta6 (binary)
+ srv_ctx.AddTrailingMetadata(meta6.first, meta6.second);
+ response_writer.Finish(send_response, Status::OK, tag(5));
+
+ Verifier(GetParam()).Expect(5, true).Verify(cq_.get());
+
+ response_reader->Finish(&recv_response, &recv_status, tag(6));
+ Verifier(GetParam()).Expect(6, true).Verify(cq_.get());
+ EXPECT_EQ(send_response.message(), recv_response.message());
+ EXPECT_TRUE(recv_status.ok());
+ auto server_trailing_metadata = cli_ctx.GetServerTrailingMetadata();
+ EXPECT_EQ(meta5.second,
+ ToString(server_trailing_metadata.find(meta5.first)->second));
+ EXPECT_EQ(meta6.second,
+ ToString(server_trailing_metadata.find(meta6.first)->second));
+ EXPECT_GE(server_trailing_metadata.size(), static_cast<size_t>(2));
+}
+
+// Server uses the AsyncNotifyWhenDone API to observe a client-side cancellation (TryCancel).
+TEST_P(AsyncEnd2endTest, ServerCheckCancellation) {
+ ResetStub();
+
+ EchoRequest send_request;
+ EchoRequest recv_request;
+ EchoResponse send_response;  // NOTE(review): unused in this test
+ EchoResponse recv_response;
+ Status recv_status;
+
+ ClientContext cli_ctx;
+ ServerContext srv_ctx;
+ grpc::ServerAsyncResponseWriter<EchoResponse> response_writer(&srv_ctx);
+
+ send_request.set_message("Hello");
+ std::unique_ptr<ClientAsyncResponseReader<EchoResponse> > response_reader(
+ stub_->AsyncEcho(&cli_ctx, send_request, cq_.get()));
+
+ srv_ctx.AsyncNotifyWhenDone(tag(5));  // registered before RequestEcho; tag 5 completes once the call is over
+ service_.RequestEcho(&srv_ctx, &recv_request, &response_writer, cq_.get(),
+ cq_.get(), tag(2));
+
+ Verifier(GetParam()).Expect(2, true).Verify(cq_.get());
+ EXPECT_EQ(send_request.message(), recv_request.message());
+
+ cli_ctx.TryCancel();  // the server never calls Finish in this test
+ Verifier(GetParam()).Expect(5, true).Verify(cq_.get());
+ EXPECT_TRUE(srv_ctx.IsCancelled());
+
+ response_reader->Finish(&recv_response, &recv_status, tag(4));
+ Verifier(GetParam()).Expect(4, false).Verify(cq_.get());  // cancelled call: Finish completes with ok=false
+
+ EXPECT_EQ(StatusCode::CANCELLED, recv_status.error_code());
+}
+
+// Server uses the AsyncNotifyWhenDone API to observe a normal (non-cancelled) finish.
+TEST_P(AsyncEnd2endTest, ServerCheckDone) {
+ ResetStub();
+
+ EchoRequest send_request;
+ EchoRequest recv_request;
+ EchoResponse send_response;
+ EchoResponse recv_response;
+ Status recv_status;
+
+ ClientContext cli_ctx;
+ ServerContext srv_ctx;
+ grpc::ServerAsyncResponseWriter<EchoResponse> response_writer(&srv_ctx);
+
+ send_request.set_message("Hello");
+ std::unique_ptr<ClientAsyncResponseReader<EchoResponse> > response_reader(
+ stub_->AsyncEcho(&cli_ctx, send_request, cq_.get()));
+
+ srv_ctx.AsyncNotifyWhenDone(tag(5));  // registered before RequestEcho, as in ServerCheckCancellation
+ service_.RequestEcho(&srv_ctx, &recv_request, &response_writer, cq_.get(),
+ cq_.get(), tag(2));
+
+ Verifier(GetParam()).Expect(2, true).Verify(cq_.get());
+ EXPECT_EQ(send_request.message(), recv_request.message());
+
+ send_response.set_message(recv_request.message());
+ response_writer.Finish(send_response, Status::OK, tag(3));
+ Verifier(GetParam()).Expect(3, true).Verify(cq_.get());
+ Verifier(GetParam()).Expect(5, true).Verify(cq_.get());  // done-notification is awaited after the server Finish
+ EXPECT_FALSE(srv_ctx.IsCancelled());  // normal completion must not be reported as cancelled
+
+ response_reader->Finish(&recv_response, &recv_status, tag(4));
+ Verifier(GetParam()).Expect(4, true).Verify(cq_.get());
+
+ EXPECT_EQ(send_response.message(), recv_response.message());
+ EXPECT_TRUE(recv_status.ok());
+}
+
+TEST_P(AsyncEnd2endTest, UnimplementedRpc) {  // calling a method with no server-side handler yields UNIMPLEMENTED
+ std::shared_ptr<Channel> channel =
+ CreateChannel(server_address_.str(), InsecureChannelCredentials());  // own channel; this test does not call ResetStub()
+ std::unique_ptr<grpc::testing::UnimplementedService::Stub> stub;
+ stub = grpc::testing::UnimplementedService::NewStub(channel);
+ EchoRequest send_request;
+ EchoResponse recv_response;
+ Status recv_status;
+
+ ClientContext cli_ctx;
+ send_request.set_message("Hello");
+ std::unique_ptr<ClientAsyncResponseReader<EchoResponse> > response_reader(
+ stub->AsyncUnimplemented(&cli_ctx, send_request, cq_.get()));
+
+ response_reader->Finish(&recv_response, &recv_status, tag(4));
+ Verifier(GetParam()).Expect(4, false).Verify(cq_.get());  // the failed call completes with ok=false
+
+ EXPECT_EQ(StatusCode::UNIMPLEMENTED, recv_status.error_code());
+ EXPECT_EQ("", recv_status.error_message());  // no error detail text is expected
+}
+
+INSTANTIATE_TEST_CASE_P(AsyncEnd2end, AsyncEnd2endTest,
+ ::testing::Values(false, true));  // every TEST_P runs twice; the bool reaches Verifier(GetParam()) (presumably a CQ polling mode -- confirm)
+
+} // namespace
+} // namespace testing
+} // namespace grpc
+
+int main(int argc, char** argv) {  // test entry point
+ grpc_test_init(argc, argv);  // gRPC test-environment setup (defined outside this file)
+ ::testing::InitGoogleTest(&argc, argv);  // consumes gtest's own command-line flags
+ return RUN_ALL_TESTS();
+}
diff --git a/test/cpp/qps/client_async.cc b/test/cpp/qps/client_async.cc
index 3e2317c..47e902c 100644
--- a/test/cpp/qps/client_async.cc
+++ b/test/cpp/qps/client_async.cc
@@ -49,10 +49,10 @@
#include <grpc/support/histogram.h>
#include <grpc/support/log.h>
+#include "src/proto/grpc/testing/services.grpc.pb.h"
#include "test/cpp/qps/client.h"
#include "test/cpp/qps/timer.h"
#include "test/cpp/util/create_test_channel.h"
-#include "src/proto/grpc/testing/services.grpc.pb.h"
namespace grpc {
namespace testing {
diff --git a/test/cpp/qps/server_async.cc b/test/cpp/qps/server_async.cc
index 1ae88d7..50be679 100644
--- a/test/cpp/qps/server_async.cc
+++ b/test/cpp/qps/server_async.cc
@@ -350,7 +350,7 @@
static void RegisterBenchmarkService(ServerBuilder *builder,
BenchmarkService::AsyncService *service) {
- builder->RegisterAsyncService(service);
+ builder->RegisterService(service);  // the AsyncService is now registered through RegisterService
}
static void RegisterGenericService(ServerBuilder *builder,
grpc::AsyncGenericService *service) {