Expose max message size at the server side
diff --git a/include/grpc++/config.h b/include/grpc++/config.h
index 0f3d692..b6c1705 100644
--- a/include/grpc++/config.h
+++ b/include/grpc++/config.h
@@ -93,13 +93,17 @@
#endif
#ifndef GRPC_CUSTOM_ZEROCOPYOUTPUTSTREAM
+#include <google/protobuf/io/coded_stream.h>
#include <google/protobuf/io/zero_copy_stream.h>
#define GRPC_CUSTOM_ZEROCOPYOUTPUTSTREAM \
::google::protobuf::io::ZeroCopyOutputStream
#define GRPC_CUSTOM_ZEROCOPYINPUTSTREAM \
::google::protobuf::io::ZeroCopyInputStream
+#define GRPC_CUSTOM_CODEDINPUTSTREAM \
+ ::google::protobuf::io::CodedInputStream
#endif
+
#ifdef GRPC_CXX0X_NO_NULLPTR
#include <memory>
const class {
@@ -126,6 +130,7 @@
namespace io {
typedef GRPC_CUSTOM_ZEROCOPYOUTPUTSTREAM ZeroCopyOutputStream;
typedef GRPC_CUSTOM_ZEROCOPYINPUTSTREAM ZeroCopyInputStream;
+typedef GRPC_CUSTOM_CODEDINPUTSTREAM CodedInputStream;
} // namespace io
} // namespace protobuf
diff --git a/include/grpc++/impl/call.h b/include/grpc++/impl/call.h
index b14c41d..d76ef61 100644
--- a/include/grpc++/impl/call.h
+++ b/include/grpc++/impl/call.h
@@ -80,6 +80,10 @@
// Called by completion queue just prior to returning from Next() or Pluck()
bool FinalizeResult(void** tag, bool* status) GRPC_OVERRIDE;
+ void set_max_message_size(int max_message_size) {
+ max_message_size_ = max_message_size;
+ }
+
bool got_message;
private:
@@ -99,6 +103,7 @@
grpc::protobuf::Message* recv_message_;
ByteBuffer* recv_message_buffer_;
grpc_byte_buffer* recv_buf_;
+ int max_message_size_;
// Client send close
bool client_send_close_;
// Client recv status
@@ -130,16 +135,21 @@
public:
/* call is owned by the caller */
Call(grpc_call* call, CallHook* call_hook_, CompletionQueue* cq);
+ Call(grpc_call* call, CallHook* call_hook_, CompletionQueue* cq,
+ int max_message_size);
void PerformOps(CallOpBuffer* buffer);
grpc_call* call() { return call_; }
CompletionQueue* cq() { return cq_; }
+ int max_message_size() { return max_message_size_; }
+
private:
CallHook* call_hook_;
CompletionQueue* cq_;
grpc_call* call_;
+ int max_message_size_;
};
} // namespace grpc
diff --git a/include/grpc++/server.h b/include/grpc++/server.h
index c686474..b2b9044 100644
--- a/include/grpc++/server.h
+++ b/include/grpc++/server.h
@@ -79,7 +79,8 @@
class AsyncRequest;
// ServerBuilder use only
- Server(ThreadPoolInterface* thread_pool, bool thread_pool_owned);
+ Server(ThreadPoolInterface* thread_pool, bool thread_pool_owned,
+ int max_message_size);
// Register a service. This call does not take ownership of the service.
// The service must exist for the lifetime of the Server instance.
bool RegisterService(RpcService* service);
@@ -106,6 +107,8 @@
ServerAsyncStreamingInterface* stream,
CompletionQueue* cq, void* tag);
+ const int max_message_size_;
+
// Completion queue.
CompletionQueue cq_;
@@ -126,7 +129,7 @@
// Whether the thread pool is created and owned by the server.
bool thread_pool_owned_;
private:
- Server() : server_(NULL) { abort(); }
+ Server() : max_message_size_(-1), server_(NULL) { abort(); }
};
} // namespace grpc
diff --git a/include/grpc++/server_builder.h b/include/grpc++/server_builder.h
index 9a9932e..7155c7f 100644
--- a/include/grpc++/server_builder.h
+++ b/include/grpc++/server_builder.h
@@ -68,6 +68,11 @@
// Register a generic service.
void RegisterAsyncGenericService(AsyncGenericService* service);
+ // Set max message size in bytes.
+ void SetMaxMessageSize(int max_message_size) {
+ max_message_size_ = max_message_size;
+ }
+
// Add a listening port. Can be called multiple times.
void AddListeningPort(const grpc::string& addr,
std::shared_ptr<ServerCredentials> creds,
@@ -87,6 +92,7 @@
int* selected_port;
};
+ int max_message_size_;
std::vector<RpcService*> services_;
std::vector<AsynchronousService*> async_services_;
std::vector<Port> ports_;
diff --git a/src/cpp/common/call.cc b/src/cpp/common/call.cc
index 9878133..25609a7 100644
--- a/src/cpp/common/call.cc
+++ b/src/cpp/common/call.cc
@@ -55,6 +55,7 @@
recv_message_(nullptr),
recv_message_buffer_(nullptr),
recv_buf_(nullptr),
+ max_message_size_(-1),
client_send_close_(false),
recv_trailing_metadata_(nullptr),
recv_status_(nullptr),
@@ -311,7 +312,7 @@
got_message = *status;
if (recv_message_) {
GRPC_TIMER_MARK(DESER_PROTO_BEGIN, 0);
- *status = *status && DeserializeProto(recv_buf_, recv_message_);
+ *status = *status && DeserializeProto(recv_buf_, recv_message_, max_message_size_);
grpc_byte_buffer_destroy(recv_buf_);
GRPC_TIMER_MARK(DESER_PROTO_END, 0);
} else {
@@ -338,9 +339,17 @@
}
Call::Call(grpc_call* call, CallHook* call_hook, CompletionQueue* cq)
- : call_hook_(call_hook), cq_(cq), call_(call) {}
+ : call_hook_(call_hook), cq_(cq), call_(call), max_message_size_(-1) {}
+
+Call::Call(grpc_call* call, CallHook* call_hook, CompletionQueue* cq,
+ int max_message_size)
+ : call_hook_(call_hook), cq_(cq), call_(call),
+ max_message_size_(max_message_size) {}
void Call::PerformOps(CallOpBuffer* buffer) {
+ if (max_message_size_ > 0) {
+ buffer->set_max_message_size(max_message_size_);
+ }
call_hook_->PerformOpsOnCall(buffer, this);
}
diff --git a/src/cpp/proto/proto_utils.cc b/src/cpp/proto/proto_utils.cc
index b8de2ea..8ab536a 100644
--- a/src/cpp/proto/proto_utils.cc
+++ b/src/cpp/proto/proto_utils.cc
@@ -158,9 +158,14 @@
return msg.SerializeToZeroCopyStream(&writer);
}
-bool DeserializeProto(grpc_byte_buffer* buffer, grpc::protobuf::Message* msg) {
+bool DeserializeProto(grpc_byte_buffer* buffer, grpc::protobuf::Message* msg,
+ int max_message_size) {
GrpcBufferReader reader(buffer);
- return msg->ParseFromZeroCopyStream(&reader);
+ ::grpc::protobuf::io::CodedInputStream decoder(&reader);
+ if (max_message_size > 0) {
+ decoder.SetTotalBytesLimit(max_message_size, max_message_size);
+ }
+ return msg->ParseFromCodedStream(&decoder) && decoder.ConsumedEntireMessage();
}
} // namespace grpc
diff --git a/src/cpp/proto/proto_utils.h b/src/cpp/proto/proto_utils.h
index bc60dc9..67a775b 100644
--- a/src/cpp/proto/proto_utils.h
+++ b/src/cpp/proto/proto_utils.h
@@ -47,7 +47,8 @@
grpc_byte_buffer** buffer);
// The caller keeps ownership of buffer and msg.
-bool DeserializeProto(grpc_byte_buffer* buffer, grpc::protobuf::Message* msg);
+bool DeserializeProto(grpc_byte_buffer* buffer, grpc::protobuf::Message* msg,
+ int max_message_size);
} // namespace grpc
diff --git a/src/cpp/server/server.cc b/src/cpp/server/server.cc
index 4694a3a..d8f8ab4 100644
--- a/src/cpp/server/server.cc
+++ b/src/cpp/server/server.cc
@@ -100,7 +100,7 @@
public:
explicit CallData(Server* server, SyncRequest* mrd)
: cq_(mrd->cq_),
- call_(mrd->call_, server, &cq_),
+ call_(mrd->call_, server, &cq_, server->max_message_size_),
ctx_(mrd->deadline_, mrd->request_metadata_.metadata,
mrd->request_metadata_.count),
has_request_payload_(mrd->has_request_payload_),
@@ -126,7 +126,7 @@
if (has_request_payload_) {
GRPC_TIMER_MARK(DESER_PROTO_BEGIN, call_.call());
req.reset(method_->AllocateRequestProto());
- if (!DeserializeProto(request_payload_, req.get())) {
+ if (!DeserializeProto(request_payload_, req.get(), call_.max_message_size())) {
abort(); // for now
}
GRPC_TIMER_MARK(DESER_PROTO_END, call_.call());
@@ -176,12 +176,27 @@
grpc_completion_queue* cq_;
};
-Server::Server(ThreadPoolInterface* thread_pool, bool thread_pool_owned)
- : started_(false),
+grpc_server* CreateServer(grpc_completion_queue* cq, int max_message_size) {
+ if (max_message_size > 0) {
+ grpc_arg arg;
+ arg.type = GRPC_ARG_INTEGER;
+ arg.key = const_cast<char*>(GRPC_ARG_MAX_MESSAGE_LENGTH);
+ arg.value.integer = max_message_size;
+ grpc_channel_args args = {1, &arg};
+ return grpc_server_create(cq, &args);
+ } else {
+ return grpc_server_create(cq, nullptr);
+ }
+}
+
+Server::Server(ThreadPoolInterface* thread_pool, bool thread_pool_owned,
+ int max_message_size)
+ : max_message_size_(max_message_size),
+ started_(false),
shutdown_(false),
num_running_cb_(0),
sync_methods_(new std::list<SyncRequest>),
- server_(grpc_server_create(cq_.cq(), nullptr)),
+ server_(CreateServer(cq_.cq(), max_message_size)),
thread_pool_(thread_pool),
thread_pool_owned_(thread_pool_owned) {}
@@ -347,7 +362,8 @@
if (*status && request_) {
if (payload_) {
GRPC_TIMER_MARK(DESER_PROTO_BEGIN, call_);
- *status = DeserializeProto(payload_, request_);
+ *status = DeserializeProto(payload_, request_,
+ server_->max_message_size_);
GRPC_TIMER_MARK(DESER_PROTO_END, call_);
} else {
*status = false;
@@ -374,7 +390,7 @@
}
ctx->call_ = call_;
ctx->cq_ = cq_;
- Call call(call_, server_, cq_);
+ Call call(call_, server_, cq_, server_->max_message_size_);
if (orig_status && call_) {
ctx->BeginCompletionOp(&call);
}
diff --git a/src/cpp/server/server_builder.cc b/src/cpp/server/server_builder.cc
index 81cb0e6..e48d1ee 100644
--- a/src/cpp/server/server_builder.cc
+++ b/src/cpp/server/server_builder.cc
@@ -42,7 +42,7 @@
namespace grpc {
ServerBuilder::ServerBuilder()
- : generic_service_(nullptr), thread_pool_(nullptr) {}
+ : max_message_size_(-1), generic_service_(nullptr), thread_pool_(nullptr) {}
void ServerBuilder::RegisterService(SynchronousService* service) {
services_.push_back(service->service());
@@ -86,7 +86,8 @@
thread_pool_ = new ThreadPool(cores);
thread_pool_owned = true;
}
- std::unique_ptr<Server> server(new Server(thread_pool_, thread_pool_owned));
+ std::unique_ptr<Server> server(
+ new Server(thread_pool_, thread_pool_owned, max_message_size_));
for (auto service = services_.begin(); service != services_.end();
service++) {
if (!server->RegisterService(*service)) {
diff --git a/test/cpp/end2end/end2end_test.cc b/test/cpp/end2end/end2end_test.cc
index 5e89490..77c45d0 100644
--- a/test/cpp/end2end/end2end_test.cc
+++ b/test/cpp/end2end/end2end_test.cc
@@ -172,7 +172,7 @@
class End2endTest : public ::testing::Test {
protected:
- End2endTest() : thread_pool_(2) {}
+ End2endTest() : kMaxMessageSize_(8192), thread_pool_(2) {}
void SetUp() GRPC_OVERRIDE {
int port = grpc_pick_unused_port_or_die();
@@ -182,6 +182,7 @@
builder.AddListeningPort(server_address_.str(),
InsecureServerCredentials());
builder.RegisterService(&service_);
+ builder.SetMaxMessageSize(kMaxMessageSize_); // For testing max message size.
builder.RegisterService(&dup_pkg_service_);
builder.SetThreadPool(&thread_pool_);
server_ = builder.BuildAndStart();
@@ -198,11 +199,13 @@
std::unique_ptr<grpc::cpp::test::util::TestService::Stub> stub_;
std::unique_ptr<Server> server_;
std::ostringstream server_address_;
+ const int kMaxMessageSize_;
TestServiceImpl service_;
TestServiceImplDupPkg dup_pkg_service_;
ThreadPool thread_pool_;
};
+/*
static void SendRpc(grpc::cpp::test::util::TestService::Stub* stub,
int num_rpcs) {
EchoRequest request;
@@ -575,7 +578,18 @@
Status s = stream->Finish();
EXPECT_EQ(grpc::StatusCode::CANCELLED, s.code());
}
+*/
+TEST_F(End2endTest, RpcMaxMessageSize) {
+ ResetStub();
+ EchoRequest request;
+ EchoResponse response;
+ request.set_message(string(kMaxMessageSize_*2, 'a'));
+
+ ClientContext context;
+ Status s = stub_->Echo(&context, request, &response);
+ EXPECT_FALSE(s.IsOk());
+}
} // namespace testing
} // namespace grpc