Server side cancellation receive support
diff --git a/include/grpc++/completion_queue.h b/include/grpc++/completion_queue.h
index 9a4fa9f..cec0ef0 100644
--- a/include/grpc++/completion_queue.h
+++ b/include/grpc++/completion_queue.h
@@ -114,7 +114,7 @@
bool Pluck(CompletionQueueTag *tag);
// Does a single polling pluck on tag
- void TryPluck(CompletionQueueTag *tag);
+ void TryPluck(CompletionQueueTag *tag, bool forever);
grpc_completion_queue *cq_; // owned
};
diff --git a/include/grpc++/server_context.h b/include/grpc++/server_context.h
index e2e14d9..81dcb21 100644
--- a/include/grpc++/server_context.h
+++ b/include/grpc++/server_context.h
@@ -34,8 +34,6 @@
#ifndef __GRPCPP_SERVER_CONTEXT_H_
#define __GRPCPP_SERVER_CONTEXT_H_
-#include <grpc++/completion_queue.h>
-
#include <chrono>
#include <map>
#include <mutex>
@@ -63,7 +61,9 @@
template <class R, class W>
class ServerReaderWriter;
+class Call;
class CallOpBuffer;
+class CompletionQueue;
class Server;
// Interface of server side rpc context.
@@ -79,7 +79,7 @@
void AddInitialMetadata(const grpc::string& key, const grpc::string& value);
void AddTrailingMetadata(const grpc::string& key, const grpc::string& value);
- bool IsCancelled() { return completion_op_.CheckCancelled(cq_); }
+ bool IsCancelled();
std::multimap<grpc::string, grpc::string> client_metadata() {
return client_metadata_;
@@ -102,22 +102,14 @@
template <class R, class W>
friend class ::grpc::ServerReaderWriter;
- class CompletionOp final : public CompletionQueueTag {
- public:
- bool FinalizeResult(void** tag, bool* status) override;
+ class CompletionOp;
- bool CheckCancelled(CompletionQueue* cq);
-
- private:
- std::mutex mu_;
- bool finalized_ = false;
- int cancelled_ = 0;
- };
+ void BeginCompletionOp(Call* call);
ServerContext(gpr_timespec deadline, grpc_metadata* metadata,
size_t metadata_count);
- CompletionOp completion_op_;
+ CompletionOp* completion_op_ = nullptr;
std::chrono::system_clock::time_point deadline_;
grpc_call* call_ = nullptr;
diff --git a/include/grpc++/stream.h b/include/grpc++/stream.h
index 20ba3fb..a37062b 100644
--- a/include/grpc++/stream.h
+++ b/include/grpc++/stream.h
@@ -576,8 +576,6 @@
if (status.IsOk()) {
finish_buf_.AddSendMessage(msg);
}
- bool cancelled = false;
- finish_buf_.AddServerRecvClose(&cancelled);
finish_buf_.AddServerSendStatus(&ctx_->trailing_metadata_, status);
call_.PerformOps(&finish_buf_);
}
@@ -589,8 +587,6 @@
finish_buf_.AddSendInitialMetadata(&ctx_->initial_metadata_);
ctx_->sent_initial_metadata_ = true;
}
- bool cancelled = false;
- finish_buf_.AddServerRecvClose(&cancelled);
finish_buf_.AddServerSendStatus(&ctx_->trailing_metadata_, status);
call_.PerformOps(&finish_buf_);
}
@@ -636,8 +632,6 @@
if (status.IsOk()) {
finish_buf_.AddSendMessage(msg);
}
- bool cancelled = false;
- finish_buf_.AddServerRecvClose(&cancelled);
finish_buf_.AddServerSendStatus(&ctx_->trailing_metadata_, status);
call_.PerformOps(&finish_buf_);
}
@@ -649,8 +643,6 @@
finish_buf_.AddSendInitialMetadata(&ctx_->initial_metadata_);
ctx_->sent_initial_metadata_ = true;
}
- bool cancelled = false;
- finish_buf_.AddServerRecvClose(&cancelled);
finish_buf_.AddServerSendStatus(&ctx_->trailing_metadata_, status);
call_.PerformOps(&finish_buf_);
}
@@ -697,8 +689,6 @@
finish_buf_.AddSendInitialMetadata(&ctx_->initial_metadata_);
ctx_->sent_initial_metadata_ = true;
}
- bool cancelled = false;
- finish_buf_.AddServerRecvClose(&cancelled);
finish_buf_.AddServerSendStatus(&ctx_->trailing_metadata_, status);
call_.PerformOps(&finish_buf_);
}
@@ -753,8 +743,6 @@
finish_buf_.AddSendInitialMetadata(&ctx_->initial_metadata_);
ctx_->sent_initial_metadata_ = true;
}
- bool cancelled = false;
- finish_buf_.AddServerRecvClose(&cancelled);
finish_buf_.AddServerSendStatus(&ctx_->trailing_metadata_, status);
call_.PerformOps(&finish_buf_);
}
diff --git a/src/cpp/common/completion_queue.cc b/src/cpp/common/completion_queue.cc
index c330d21..f9bb868 100644
--- a/src/cpp/common/completion_queue.cc
+++ b/src/cpp/common/completion_queue.cc
@@ -88,10 +88,11 @@
}
}
-void CompletionQueue::TryPluck(CompletionQueueTag* tag) {
+void CompletionQueue::TryPluck(CompletionQueueTag* tag, bool forever) {
std::unique_ptr<grpc_event, EventDeleter> ev;
- ev.reset(grpc_completion_queue_pluck(cq_, tag, gpr_inf_past));
+ ev.reset(grpc_completion_queue_pluck(
+ cq_, tag, forever ? gpr_inf_future : gpr_inf_past));
if (!ev) return;
bool ok = ev->data.op_complete == GRPC_OP_OK;
void* ignored = tag;
diff --git a/src/cpp/server/server.cc b/src/cpp/server/server.cc
index 8fffea6..cf6b022 100644
--- a/src/cpp/server/server.cc
+++ b/src/cpp/server/server.cc
@@ -205,6 +205,7 @@
if (has_response_payload_) {
res.reset(method_->AllocateResponseProto());
}
+ ctx_.BeginCompletionOp(&call_);
auto status = method_->handler()->RunHandler(
MethodHandler::HandlerParameter(&call_, &ctx_, req.get(), res.get()));
CallOpBuffer buf;
@@ -215,10 +216,12 @@
buf.AddSendMessage(*res);
}
buf.AddServerSendStatus(&ctx_.trailing_metadata_, status);
- bool cancelled;
- buf.AddServerRecvClose(&cancelled);
call_.PerformOps(&buf);
GPR_ASSERT(cq_.Pluck(&buf));
+ void* ignored_tag;
+ bool ignored_ok;
+ cq_.Shutdown();
+ GPR_ASSERT(cq_.Next(&ignored_tag, &ignored_ok) == false);
}
private:
@@ -332,6 +335,7 @@
}
ctx_->call_ = call_;
Call call(call_, server_, cq_);
+ ctx_->BeginCompletionOp(&call);
// just the pointers inside call are copied here
stream_->BindCall(&call);
delete this;
diff --git a/src/cpp/server/server_context.cc b/src/cpp/server/server_context.cc
index b9d85b9..9412f27 100644
--- a/src/cpp/server/server_context.cc
+++ b/src/cpp/server/server_context.cc
@@ -34,10 +34,59 @@
#include <grpc++/server_context.h>
#include <grpc++/impl/call.h>
#include <grpc/grpc.h>
+#include <grpc/support/log.h>
#include "src/cpp/util/time.h"
namespace grpc {
+// CompletionOp
+
+class ServerContext::CompletionOp final : public CallOpBuffer {
+ public:
+ CompletionOp();
+ bool FinalizeResult(void** tag, bool* status) override;
+
+ bool CheckCancelled(CompletionQueue* cq);
+
+ void Unref();
+
+ private:
+ std::mutex mu_;
+ int refs_ = 2; // initial refs: one in the server context, one in the cq
+ bool finalized_ = false;
+ bool cancelled_ = false;
+};
+
+ServerContext::CompletionOp::CompletionOp() { AddServerRecvClose(&cancelled_); }
+
+void ServerContext::CompletionOp::Unref() {
+ std::unique_lock<std::mutex> lock(mu_);
+ if (--refs_ == 0) {
+ lock.unlock();
+ delete this;
+ }
+}
+
+bool ServerContext::CompletionOp::CheckCancelled(CompletionQueue* cq) {
+ cq->TryPluck(this, false);
+ std::lock_guard<std::mutex> g(mu_);
+ return finalized_ ? cancelled_ : false;
+}
+
+bool ServerContext::CompletionOp::FinalizeResult(void** tag, bool* status) {
+ GPR_ASSERT(CallOpBuffer::FinalizeResult(tag, status));
+ std::unique_lock<std::mutex> lock(mu_);
+ finalized_ = true;
+ if (!*status) cancelled_ = true;
+ if (--refs_ == 0) {
+ lock.unlock();
+ delete this;
+ }
+ return false;
+}
+
+// ServerContext body
+
ServerContext::ServerContext() {}
ServerContext::ServerContext(gpr_timespec deadline, grpc_metadata* metadata,
@@ -55,6 +104,15 @@
if (call_) {
grpc_call_destroy(call_);
}
+ if (completion_op_) {
+ completion_op_->Unref();
+ }
+}
+
+void ServerContext::BeginCompletionOp(Call* call) {
+ GPR_ASSERT(!completion_op_);
+ completion_op_ = new CompletionOp();
+ call->PerformOps(completion_op_);
}
void ServerContext::AddInitialMetadata(const grpc::string& key,
@@ -67,16 +125,8 @@
trailing_metadata_.insert(std::make_pair(key, value));
}
-bool ServerContext::CompletionOp::CheckCancelled(CompletionQueue* cq) {
- cq->TryPluck(this);
- std::lock_guard<std::mutex> g(mu_);
- return finalized_ ? cancelled_ != 0 : false;
-}
-
-bool ServerContext::CompletionOp::FinalizeResult(void** tag, bool* status) {
- std::lock_guard<std::mutex> g(mu_);
- finalized_ = true;
- return false;
+bool ServerContext::IsCancelled() {
+ return completion_op_ && completion_op_->CheckCancelled(cq_);
}
} // namespace grpc
diff --git a/test/cpp/end2end/async_end2end_test.cc b/test/cpp/end2end/async_end2end_test.cc
index 7e827cb..2e28a86 100644
--- a/test/cpp/end2end/async_end2end_test.cc
+++ b/test/cpp/end2end/async_end2end_test.cc
@@ -90,7 +90,17 @@
server_ = builder.BuildAndStart();
}
- void TearDown() override { server_->Shutdown(); }
+ void TearDown() override {
+ server_->Shutdown();
+ void* ignored_tag;
+ bool ignored_ok;
+ cli_cq_.Shutdown();
+ srv_cq_.Shutdown();
+ while (cli_cq_.Next(&ignored_tag, &ignored_ok))
+ ;
+ while (srv_cq_.Next(&ignored_tag, &ignored_ok))
+ ;
+ }
void ResetStub() {
std::shared_ptr<ChannelInterface> channel =