Support client-streaming RPCs in the CLI tool
diff --git a/test/cpp/util/cli_call.cc b/test/cpp/util/cli_call.cc
index a02a8b2..d9232ec 100644
--- a/test/cpp/util/cli_call.cc
+++ b/test/cpp/util/cli_call.cc
@@ -37,8 +37,6 @@
#include <grpc++/channel.h>
#include <grpc++/client_context.h>
-#include <grpc++/completion_queue.h>
-#include <grpc++/generic/generic_stub.h>
#include <grpc++/support/byte_buffer.h>
#include <grpc/grpc.h>
#include <grpc/slice.h>
@@ -50,49 +48,61 @@
void* tag(int i) { return (void*)(intptr_t)i; }
} // namespace
+// Performs a unary call on top of a CliCall: writes `request`, signals that
+// no more requests follow, reads one response into `response` and returns
+// the final status. Either metadata output pointer may be null.
Status CliCall::Call(std::shared_ptr<grpc::Channel> channel,
const grpc::string& method, const grpc::string& request,
grpc::string* response,
const OutgoingMetadataContainer& metadata,
IncomingMetadataContainer* server_initial_metadata,
IncomingMetadataContainer* server_trailing_metadata) {
- std::unique_ptr<grpc::GenericStub> stub(new grpc::GenericStub(channel));
- grpc::ClientContext ctx;
+ CliCall call(channel, method, metadata);
+ call.Write(request);
+ call.WritesDone();
+ call.Read(response, server_initial_metadata);
+ return call.Finish(server_trailing_metadata);
+}
+
+// Opens a generic streaming call to `method` on `channel`. Every entry of
+// `metadata` is added to the call's ClientContext, the call is started, and
+// the constructor blocks on the completion queue until the start tag comes
+// back, aborting (GPR_ASSERT) if that operation reports failure.
+CliCall::CliCall(std::shared_ptr<grpc::Channel> channel,
+ const grpc::string& method,
+ const OutgoingMetadataContainer& metadata)
+ : stub_(new grpc::GenericStub(channel)) {
if (!metadata.empty()) {
for (OutgoingMetadataContainer::const_iterator iter = metadata.begin();
iter != metadata.end(); ++iter) {
- ctx.AddMetadata(iter->first, iter->second);
+ ctx_.AddMetadata(iter->first, iter->second);
}
}
- grpc::CompletionQueue cq;
- std::unique_ptr<grpc::GenericClientAsyncReaderWriter> call(
- stub->Call(&ctx, method, &cq, tag(1)));
+ call_ = stub_->Call(&ctx_, method, &cq_, tag(1));
void* got_tag;
bool ok;
- cq.Next(&got_tag, &ok);
+ cq_.Next(&got_tag, &ok);
GPR_ASSERT(ok);
+}
+
+// Sends `request` as a single message and blocks until the write completes;
+// aborts (GPR_ASSERT) if the write fails.
+void CliCall::Write(const grpc::string& request) {
+ void* got_tag;
+ bool ok;
- grpc_slice s = grpc_slice_from_copied_string(request.c_str());
+ // Copy by explicit length: serialized protos routinely contain NUL bytes,
+ // which the strlen()-based grpc_slice_from_copied_string() would cut off.
+ grpc_slice s = grpc_slice_from_copied_buffer(request.data(), request.size());
grpc::Slice req_slice(s, grpc::Slice::STEAL_REF);
grpc::ByteBuffer send_buffer(&req_slice, 1);
- call->Write(send_buffer, tag(2));
- cq.Next(&got_tag, &ok);
+ call_->Write(send_buffer, tag(2));
+ cq_.Next(&got_tag, &ok);
GPR_ASSERT(ok);
- call->WritesDone(tag(3));
- cq.Next(&got_tag, &ok);
- GPR_ASSERT(ok);
- grpc::ByteBuffer recv_buffer;
- call->Read(&recv_buffer, tag(4));
- cq.Next(&got_tag, &ok);
- if (!ok) {
- std::cout << "Failed to read response." << std::endl;
- }
- grpc::Status status;
- call->Finish(&status, tag(5));
- cq.Next(&got_tag, &ok);
- GPR_ASSERT(ok);
+}
- if (status.ok()) {
+
+// Reads one message from the stream into `response`. A failed read is
+// reported on stderr and leaves `response` untouched. If non-null,
+// `server_initial_metadata` is filled from the call context in both cases.
+void CliCall::Read(grpc::string* response,
+ IncomingMetadataContainer* server_initial_metadata) {
+ void* got_tag;
+ bool ok;
+
+ grpc::ByteBuffer recv_buffer;
+ call_->Read(&recv_buffer, tag(4));
+ cq_.Next(&got_tag, &ok);
+ if (!ok) {
+ fprintf(stderr, "Failed to read response.\n");
+ } else {
std::vector<grpc::Slice> slices;
(void)recv_buffer.Dump(&slices);
@@ -101,10 +111,33 @@
response->append(reinterpret_cast<const char*>(slices[i].begin()),
slices[i].size());
}
+ }
+ // Report initial metadata even when the read failed, so callers can still
+ // print it.
+ if (server_initial_metadata) {
+ *server_initial_metadata = ctx_.GetServerInitialMetadata();
+ }
+}
+
+// Tells the server that no further requests will be written, blocking until
+// the operation completes; aborts (GPR_ASSERT) if it fails.
+void CliCall::WritesDone() {
+ void* completed_tag;
+ bool succeeded;
+ call_->WritesDone(tag(3));
+ cq_.Next(&completed_tag, &succeeded);
+ GPR_ASSERT(succeeded);
+}
+
+// Blocks until the call completes and returns its final status. If non-null,
+// `server_trailing_metadata` is filled from the call context. Aborts
+// (GPR_ASSERT) if the finish operation reports failure.
+Status CliCall::Finish(IncomingMetadataContainer* server_trailing_metadata) {
+ void* got_tag;
+ bool ok;
+ grpc::Status status;
+
+ call_->Finish(&status, tag(5));
+ cq_.Next(&got_tag, &ok);
+ GPR_ASSERT(ok);
+ if (server_trailing_metadata) {
+ *server_trailing_metadata = ctx_.GetServerTrailingMetadata();
}
- *server_initial_metadata = ctx.GetServerInitialMetadata();
- *server_trailing_metadata = ctx.GetServerTrailingMetadata();
return status;
}
diff --git a/test/cpp/util/cli_call.h b/test/cpp/util/cli_call.h
index 65da86b..3f32830 100644
--- a/test/cpp/util/cli_call.h
+++ b/test/cpp/util/cli_call.h
@@ -37,10 +37,15 @@
#include <map>
+#include <memory>
#include <grpc++/channel.h>
+#include <grpc++/client_context.h>
+#include <grpc++/completion_queue.h>
+#include <grpc++/generic/generic_stub.h>
#include <grpc++/support/status.h>
#include <grpc++/support/string_ref.h>
namespace grpc {
+
namespace testing {
+// Blocking wrapper around a generic (byte-oriented) gRPC call used by the CLI
+// tool. Every operation waits on the private completion queue before
+// returning.
class CliCall final {
@@ -48,12 +53,32 @@
typedef std::multimap<grpc::string, grpc::string> OutgoingMetadataContainer;
typedef std::multimap<grpc::string_ref, grpc::string_ref>
IncomingMetadataContainer;
+
+ // Starts a call to `method` on `channel`, sending `metadata` as client
+ // initial metadata.
+ CliCall(std::shared_ptr<grpc::Channel> channel, const grpc::string& method,
+ const OutgoingMetadataContainer& metadata);
+
+ // One-shot unary call: Write(), WritesDone(), Read(), then Finish().
static Status Call(std::shared_ptr<grpc::Channel> channel,
const grpc::string& method, const grpc::string& request,
grpc::string* response,
const OutgoingMetadataContainer& metadata,
IncomingMetadataContainer* server_initial_metadata,
IncomingMetadataContainer* server_trailing_metadata);
+
+ // Sends one message and blocks until the write completes.
+ void Write(const grpc::string& request);
+
+ // Signals that no more messages will be written; blocks until done.
+ void WritesDone();
+
+ // Reads one message into `response`. `server_initial_metadata` may be null.
+ void Read(grpc::string* response,
+ IncomingMetadataContainer* server_initial_metadata);
+
+ // Blocks until the call completes and returns its status.
+ // `server_trailing_metadata` may be null.
+ Status Finish(IncomingMetadataContainer* server_trailing_metadata);
+
+ private:
+ // NOTE(review): not referenced by any member; candidate for removal.
+ enum CallStatus : intptr_t;
+ std::unique_ptr<grpc::GenericStub> stub_;
+ // Held by value, so <grpc++/client_context.h> must be included (a forward
+ // declaration is not enough).
+ grpc::ClientContext ctx_;
+ std::unique_ptr<grpc::GenericClientAsyncReaderWriter> call_;
+ grpc::CompletionQueue cq_;
};
} // namespace testing
diff --git a/test/cpp/util/grpc_tool.cc b/test/cpp/util/grpc_tool.cc
index b9900ca..8082d60 100644
--- a/test/cpp/util/grpc_tool.cc
+++ b/test/cpp/util/grpc_tool.cc
@@ -419,79 +419,180 @@
std::unique_ptr<grpc::testing::ProtoFileParser> parser;
grpc::string serialized_request_proto;
- if (argc == 3) {
- request_text = argv[2];
- if (!FLAGS_infile.empty()) {
- fprintf(stderr, "warning: request given in argv, ignoring --infile\n");
+ // The channel is always needed for the RPC itself; it is passed to the
+ // parser only when --remotedb is set.
+ std::shared_ptr<grpc::Channel> channel =
+ grpc::CreateChannel(server_address, cred.GetCredentials());
+
+ // The parser is created even with --binary_input/--binary_output because it
+ // is needed to tell whether the method is client-streaming.
+ parser.reset(new grpc::testing::ProtoFileParser(
+ FLAGS_remotedb ? channel : nullptr, FLAGS_proto_path, FLAGS_protofiles));
+ if (parser->HasError()) {
+ return false;
+ }
+
+ formatted_method_name = parser->GetFormattedMethodName(method_name);
+ if (parser->HasError()) {
+ return false;
+ }
+
+ const bool is_client_streaming =
+ parser->IsStreaming(method_name, true /* is_request */);
+ if (parser->HasError()) {
+ return false;
+ }
+
+ if (is_client_streaming) {
+ std::multimap<grpc::string, grpc::string> client_metadata;
+ ParseMetadataFlag(&client_metadata);
+ PrintMetadata(client_metadata, "Sending client initial metadata:");
+
+ fprintf(stderr, "connecting to %s\n", server_address.c_str());
+ CliCall call(channel, formatted_method_name, client_metadata);
+
+ // Serializes one request (text format, or raw bytes with --binary_input)
+ // and writes it to the stream. Returns false if the text failed to parse.
+ auto send_request = [&](const grpc::string& request) -> bool {
+ if (FLAGS_binary_input) {
+ serialized_request_proto = request;
+ } else {
+ serialized_request_proto = parser->GetSerializedProtoFromMethod(
+ method_name, request, true /* is_request */);
+ if (parser->HasError()) {
+ return false;
+ }
+ }
+ call.Write(serialized_request_proto);
+ return true;
+ };
+
+ if (argc == 3) {
+ // A request given on the command line is sent as the only message.
+ request_text = argv[2];
+ if (!FLAGS_infile.empty()) {
+ fprintf(stderr, "warning: request given in argv, ignoring --infile\n");
}
+ if (!send_request(request_text)) {
+ return false;
+ }
+ } else {
- } else {
- std::stringstream input_stream;
+ std::istream* input_stream = &std::cin;
+ std::ifstream input_file;
if (FLAGS_infile.empty()) {
if (isatty(STDIN_FILENO)) {
fprintf(stderr, "reading request message from stdin...\n");
}
- input_stream << std::cin.rdbuf();
} else {
- std::ifstream input_file(FLAGS_infile, std::ios::in | std::ios::binary);
- input_stream << input_file.rdbuf();
+ input_file.open(FLAGS_infile, std::ios::in | std::ios::binary);
+ input_stream = &input_file;
+ }
+
+ // Requests are separated by empty lines; the lines of one request are
+ // joined with spaces before parsing.
+ // NOTE(review): the same line-based framing is applied with
+ // --binary_input, which can alter binary payloads; confirm intended.
+ std::stringstream request_ss;
+ grpc::string line;
+ while (std::getline(*input_stream, line)) {
+ if (!line.empty()) {
+ request_ss << line << ' ';
+ continue;
+ }
+ if (!send_request(request_ss.str())) {
+ return false;
+ }
+ request_ss.str(grpc::string());
+ request_ss.clear();
+ }
+ // Also send a final request that is not followed by an empty line.
+ const grpc::string last_request = request_ss.str();
+ if (!last_request.empty() && !send_request(last_request)) {
+ return false;
+ }
+ if (input_file.is_open()) {
input_file.close();
}
- request_text = input_stream.str();
- }
- std::shared_ptr<grpc::Channel> channel =
- grpc::CreateChannel(server_address, cred.GetCredentials());
- if (!FLAGS_binary_input || !FLAGS_binary_output) {
- parser.reset(
- new grpc::testing::ProtoFileParser(FLAGS_remotedb ? channel : nullptr,
- FLAGS_proto_path, FLAGS_protofiles));
- if (parser->HasError()) {
- return false;
- }
- }
+ }
+ call.WritesDone();
- if (FLAGS_binary_input) {
- serialized_request_proto = request_text;
- formatted_method_name = method_name;
- } else {
- formatted_method_name = parser->GetFormattedMethodName(method_name);
- serialized_request_proto = parser->GetSerializedProtoFromMethod(
- method_name, request_text, true /* is_request */);
- if (parser->HasError()) {
- return false;
- }
- }
- fprintf(stderr, "connecting to %s\n", server_address.c_str());
+ grpc::string serialized_response_proto;
+ std::multimap<grpc::string_ref, grpc::string_ref> server_initial_metadata,
+ server_trailing_metadata;
+ call.Read(&serialized_response_proto, &server_initial_metadata);
+ Status status = call.Finish(&server_trailing_metadata);
- grpc::string serialized_response_proto;
- std::multimap<grpc::string, grpc::string> client_metadata;
- std::multimap<grpc::string_ref, grpc::string_ref> server_initial_metadata,
- server_trailing_metadata;
- ParseMetadataFlag(&client_metadata);
- PrintMetadata(client_metadata, "Sending client initial metadata:");
- grpc::Status status = grpc::testing::CliCall::Call(
- channel, formatted_method_name, serialized_request_proto,
- &serialized_response_proto, client_metadata, &server_initial_metadata,
- &server_trailing_metadata);
- PrintMetadata(server_initial_metadata,
- "Received initial metadata from server:");
- PrintMetadata(server_trailing_metadata,
- "Received trailing metadata from server:");
- if (status.ok()) {
- fprintf(stderr, "Rpc succeeded with OK status\n");
- if (FLAGS_binary_output) {
- output_ss << serialized_response_proto;
+ PrintMetadata(server_initial_metadata,
+ "Received initial metadata from server:");
+ PrintMetadata(server_trailing_metadata,
+ "Received trailing metadata from server:");
+ if (status.ok()) {
+ fprintf(stderr, "Stream RPC succeeded with OK status\n");
+ if (FLAGS_binary_output) {
+ output_ss << serialized_response_proto;
+ } else {
+ grpc::string response_text = parser->GetTextFormatFromMethod(
+ method_name, serialized_response_proto, false /* is_request */);
+ if (parser->HasError()) {
+ return false;
+ }
+ output_ss << "Response: \n " << response_text << std::endl;
+ }
} else {
- grpc::string response_text = parser->GetTextFormatFromMethod(
- method_name, serialized_response_proto, false /* is_request */);
+ fprintf(stderr, "Rpc failed with status code %d, error message: %s\n",
+ status.error_code(), status.error_message().c_str());
+ }
+
+ } else { // !is_client_streaming: unary request
+ if (argc == 3) {
+ request_text = argv[2];
+ if (!FLAGS_infile.empty()) {
+ fprintf(stderr, "warning: request given in argv, ignoring --infile\n");
+ }
+ } else {
+ std::stringstream input_stream;
+ if (FLAGS_infile.empty()) {
+ if (isatty(STDIN_FILENO)) {
+ fprintf(stderr, "reading request message from stdin...\n");
+ }
+ input_stream << std::cin.rdbuf();
+ } else {
+ std::ifstream input_file(FLAGS_infile, std::ios::in | std::ios::binary);
+ input_stream << input_file.rdbuf();
+ input_file.close();
+ }
+ request_text = input_stream.str();
+ }
+
+ if (FLAGS_binary_input) {
+ serialized_request_proto = request_text;
+ } else {
+ serialized_request_proto = parser->GetSerializedProtoFromMethod(
+ method_name, request_text, true /* is_request */);
if (parser->HasError()) {
return false;
}
- output_ss << "Response: \n " << response_text << std::endl;
}
- } else {
- fprintf(stderr, "Rpc failed with status code %d, error message: %s\n",
- status.error_code(), status.error_message().c_str());
+ fprintf(stderr, "connecting to %s\n", server_address.c_str());
+
+ grpc::string serialized_response_proto;
+ std::multimap<grpc::string, grpc::string> client_metadata;
+ std::multimap<grpc::string_ref, grpc::string_ref> server_initial_metadata,
+ server_trailing_metadata;
+ ParseMetadataFlag(&client_metadata);
+ PrintMetadata(client_metadata, "Sending client initial metadata:");
+ grpc::Status status = grpc::testing::CliCall::Call(
+ channel, formatted_method_name, serialized_request_proto,
+ &serialized_response_proto, client_metadata, &server_initial_metadata,
+ &server_trailing_metadata);
+ PrintMetadata(server_initial_metadata,
+ "Received initial metadata from server:");
+ PrintMetadata(server_trailing_metadata,
+ "Received trailing metadata from server:");
+ if (status.ok()) {
+ fprintf(stderr, "Rpc succeeded with OK status\n");
+ if (FLAGS_binary_output) {
+ output_ss << serialized_response_proto;
+ } else {
+ grpc::string response_text = parser->GetTextFormatFromMethod(
+ method_name, serialized_response_proto, false /* is_request */);
+ if (parser->HasError()) {
+ return false;
+ }
+ output_ss << "Response: \n " << response_text << std::endl;
+ }
+ } else {
+ fprintf(stderr, "Rpc failed with status code %d, error message: %s\n",
+ status.error_code(), status.error_message().c_str());
+ }
}
return callback(output_ss.str());
diff --git a/test/cpp/util/proto_file_parser.cc b/test/cpp/util/proto_file_parser.cc
index bc8a608..41bf88c 100644
--- a/test/cpp/util/proto_file_parser.cc
+++ b/test/cpp/util/proto_file_parser.cc
@@ -144,12 +144,18 @@
+// Resolves `method` (full or partial name) to the fully-qualified name of the
+// matching method in service_desc_list_. Clears has_error_ on entry, returns
+// "" when no unique match is found, and caches only successful resolutions
+// in known_methods_.
grpc::string ProtoFileParser::GetFullMethodName(const grpc::string& method) {
has_error_ = false;
+
+ // Serve repeated lookups from the cache with a single find().
+ const auto cached = known_methods_.find(method);
+ if (cached != known_methods_.end()) {
+ return cached->second;
+ }
+
const protobuf::MethodDescriptor* method_descriptor = nullptr;
for (auto it = service_desc_list_.begin(); it != service_desc_list_.end();
it++) {
const auto* service_desc = *it;
for (int j = 0; j < service_desc->method_count(); j++) {
const auto* method_desc = service_desc->method(j);
if (MethodNameMatch(method_desc->full_name(), method)) {
if (method_descriptor) {
std::ostringstream error_stream;
@@ -169,6 +175,8 @@
return "";
}
+ known_methods_[method] = method_descriptor->full_name();
+
return method_descriptor->full_name();
}
@@ -205,6 +213,25 @@
: method_desc->output_type()->full_name();
}
+// Reports whether `method` streams on the request side (is_request == true,
+// client streaming) or on the response side (is_request == false, server
+// streaming). Returns false if the name cannot be resolved by
+// GetFullMethodName() or is missing from the descriptor pool (LogError is
+// called in the latter case).
+bool ProtoFileParser::IsStreaming(const grpc::string& method, bool is_request) {
+ has_error_ = false;
+
+ const grpc::string resolved_name = GetFullMethodName(method);
+ if (has_error_) {
+ return false;
+ }
+
+ const protobuf::MethodDescriptor* descriptor =
+ desc_pool_->FindMethodByName(resolved_name);
+ if (descriptor == nullptr) {
+ LogError("Method not found");
+ return false;
+ }
+
+ if (is_request) {
+ return descriptor->client_streaming();
+ }
+ return descriptor->server_streaming();
+}
+
grpc::string ProtoFileParser::GetSerializedProtoFromMethod(
const grpc::string& method, const grpc::string& text_format_proto,
bool is_request) {
diff --git a/test/cpp/util/proto_file_parser.h b/test/cpp/util/proto_file_parser.h
index c1070a3..23d311e 100644
--- a/test/cpp/util/proto_file_parser.h
+++ b/test/cpp/util/proto_file_parser.h
@@ -84,6 +84,8 @@
const grpc::string& message_type_name,
const grpc::string& serialized_proto);
+ // Returns whether `method` is client-streaming (is_request == true) or
+ // server-streaming (is_request == false). Returns false if the method
+ // cannot be resolved.
+ bool IsStreaming(const grpc::string& method, bool is_request);
+
bool HasError() const { return has_error_; }
void LogError(const grpc::string& error_msg);
@@ -104,6 +106,7 @@
std::unique_ptr<protobuf::DynamicMessageFactory> dynamic_factory_;
std::unique_ptr<grpc::protobuf::Message> request_prototype_;
std::unique_ptr<grpc::protobuf::Message> response_prototype_;
+ // Cache: method name as passed to GetFullMethodName() -> fully-qualified
+ // name; entries are added only after a successful lookup.
+ // NOTE(review): needs <unordered_map>; confirm this header includes it.
+ std::unordered_map<grpc::string, grpc::string> known_methods_;
std::vector<const protobuf::ServiceDescriptor*> service_desc_list_;
};