Allow reserved metadata to be propagated between calls
diff --git a/src/ruby/ext/grpc/rb_channel.c b/src/ruby/ext/grpc/rb_channel.c
index ac591f1..8bf2bf2 100644
--- a/src/ruby/ext/grpc/rb_channel.c
+++ b/src/ruby/ext/grpc/rb_channel.c
@@ -195,18 +195,28 @@
/* Create a call given a grpc_channel, in order to call method. The request
is not sent until grpc_call_invoke is called. */
-static VALUE grpc_rb_channel_create_call(VALUE self, VALUE cqueue, VALUE method,
- VALUE host, VALUE deadline) {
+static VALUE grpc_rb_channel_create_call(VALUE self, VALUE cqueue,
+ VALUE parent, VALUE mask,
+ VALUE method, VALUE host,
+ VALUE deadline) {
VALUE res = Qnil;
grpc_rb_channel *wrapper = NULL;
grpc_call *call = NULL;
+ grpc_call *parent_call = NULL;
grpc_channel *ch = NULL;
grpc_completion_queue *cq = NULL;
+ int flags = GRPC_PROPAGATE_DEFAULTS;
char *method_chars = StringValueCStr(method);
char *host_chars = NULL;
if (host != Qnil) {
host_chars = StringValueCStr(host);
}
+ if (mask != Qnil) {
+ flags = NUM2UINT(mask);
+ }
+ if (parent != Qnil) {
+ parent_call = grpc_rb_get_wrapped_call(parent);
+ }
cq = grpc_rb_get_wrapped_completion_queue(cqueue);
TypedData_Get_Struct(self, grpc_rb_channel, &grpc_channel_data_type, wrapper);
@@ -216,10 +226,10 @@
return Qnil;
}
- call = grpc_channel_create_call(ch, NULL, GRPC_PROPAGATE_DEFAULTS, cq,
- method_chars, host_chars,
- grpc_rb_time_timeval(deadline,
- /* absolute time */ 0));
+ call = grpc_channel_create_call(ch, parent_call, flags, cq, method_chars,
+ host_chars, grpc_rb_time_timeval(
+ deadline,
+ /* absolute time */ 0));
if (call == NULL) {
rb_raise(rb_eRuntimeError, "cannot create call with method %s",
method_chars);
@@ -237,6 +247,7 @@
return res;
}
+
/* Closes the channel, calling it's destroy method */
static VALUE grpc_rb_channel_destroy(VALUE self) {
grpc_rb_channel *wrapper = NULL;
@@ -283,7 +294,7 @@
/* Add ruby analogues of the Channel methods. */
rb_define_method(grpc_rb_cChannel, "create_call",
- grpc_rb_channel_create_call, 4);
+ grpc_rb_channel_create_call, 6);
rb_define_method(grpc_rb_cChannel, "target", grpc_rb_channel_get_target, 0);
rb_define_method(grpc_rb_cChannel, "destroy", grpc_rb_channel_destroy, 0);
rb_define_alias(grpc_rb_cChannel, "close", "destroy");
diff --git a/src/ruby/lib/grpc/generic/client_stub.rb b/src/ruby/lib/grpc/generic/client_stub.rb
index a2f1ec6..cce7185 100644
--- a/src/ruby/lib/grpc/generic/client_stub.rb
+++ b/src/ruby/lib/grpc/generic/client_stub.rb
@@ -32,6 +32,8 @@
# GRPC contains the General RPC module.
module GRPC
+ # rubocop:disable Metrics/ParameterLists
+
# ClientStub represents an endpoint used to send requests to GRPC servers.
class ClientStub
include Core::StatusCodes
@@ -68,6 +70,12 @@
update_metadata
end
+ # Allows users of the stub to modify the propagate mask.
+ #
+ # This is an advanced feature for use when making calls to another gRPC
+ # server whilst running in the handler of an existing one.
+ attr_writer :propagate_mask
+
# Creates a new ClientStub.
#
# Minimally, a stub is created with the just the host of the gRPC service
@@ -91,8 +99,8 @@
#
# - :update_metadata
# when present, this a func that takes a hash and returns a hash
- # it can be used to update metadata, i.e, remove, change or update
- # amend metadata values.
+ # it can be used to update metadata, i.e, remove, or amend
+ # metadata values.
#
# @param host [String] the host the stub connects to
# @param q [Core::CompletionQueue] used to wait for events
@@ -105,6 +113,7 @@
channel_override: nil,
timeout: nil,
creds: nil,
+ propagate_mask: nil,
update_metadata: nil,
**kw)
fail(TypeError, '!CompletionQueue') unless q.is_a?(Core::CompletionQueue)
@@ -113,6 +122,7 @@
@update_metadata = ClientStub.check_update_metadata(update_metadata)
alt_host = kw[Core::Channel::SSL_TARGET]
@host = alt_host.nil? ? host : alt_host
+ @propagate_mask = propagate_mask
@timeout = timeout.nil? ? DEFAULT_TIMEOUT : timeout
end
@@ -151,11 +161,15 @@
# @param marshal [Function] f(obj)->string that marshals requests
# @param unmarshal [Function] f(string)->obj that unmarshals responses
# @param timeout [Numeric] (optional) the max completion time in seconds
+ # @param parent [Core::Call] a prior call whose reserved metadata
+ # will be propagated by this one.
# @param return_op [true|false] return an Operation if true
# @return [Object] the response received from the server
def request_response(method, req, marshal, unmarshal, timeout = nil,
- return_op: false, **kw)
- c = new_active_call(method, marshal, unmarshal, timeout)
+ return_op: false,
+ parent: parent,
+ **kw)
+ c = new_active_call(method, marshal, unmarshal, timeout, parent: parent)
kw_with_jwt_uri = self.class.update_with_jwt_aud_uri(kw, @host, method)
md = @update_metadata.nil? ? kw : @update_metadata.call(kw_with_jwt_uri)
return c.request_response(req, **md) unless return_op
@@ -210,10 +224,14 @@
# @param unmarshal [Function] f(string)->obj that unmarshals responses
# @param timeout [Numeric] the max completion time in seconds
# @param return_op [true|false] return an Operation if true
+ # @param parent [Core::Call] a prior call whose reserved metadata
+ # will be propagated by this one.
# @return [Object|Operation] the response received from the server
def client_streamer(method, requests, marshal, unmarshal, timeout = nil,
- return_op: false, **kw)
- c = new_active_call(method, marshal, unmarshal, timeout)
+ return_op: false,
+ parent: nil,
+ **kw)
+ c = new_active_call(method, marshal, unmarshal, timeout, parent: parent)
kw_with_jwt_uri = self.class.update_with_jwt_aud_uri(kw, @host, method)
md = @update_metadata.nil? ? kw : @update_metadata.call(kw_with_jwt_uri)
return c.client_streamer(requests, **md) unless return_op
@@ -276,11 +294,16 @@
# @param unmarshal [Function] f(string)->obj that unmarshals responses
# @param timeout [Numeric] the max completion time in seconds
# @param return_op [true|false]return an Operation if true
+ # @param parent [Core::Call] a prior call whose reserved metadata
+ # will be propagated by this one.
# @param blk [Block] when provided, is executed for each response
# @return [Enumerator|Operation|nil] as discussed above
def server_streamer(method, req, marshal, unmarshal, timeout = nil,
- return_op: false, **kw, &blk)
- c = new_active_call(method, marshal, unmarshal, timeout)
+ return_op: false,
+ parent: nil,
+ **kw,
+ &blk)
+ c = new_active_call(method, marshal, unmarshal, timeout, parent: parent)
kw_with_jwt_uri = self.class.update_with_jwt_aud_uri(kw, @host, method)
md = @update_metadata.nil? ? kw : @update_metadata.call(kw_with_jwt_uri)
return c.server_streamer(req, **md, &blk) unless return_op
@@ -381,12 +404,17 @@
# @param marshal [Function] f(obj)->string that marshals requests
# @param unmarshal [Function] f(string)->obj that unmarshals responses
# @param timeout [Numeric] (optional) the max completion time in seconds
- # @param blk [Block] when provided, is executed for each response
+ # @param parent [Core::Call] a prior call whose reserved metadata
+ # will be propagated by this one.
# @param return_op [true|false] return an Operation if true
+ # @param blk [Block] when provided, is executed for each response
# @return [Enumerator|nil|Operation] as discussed above
def bidi_streamer(method, requests, marshal, unmarshal, timeout = nil,
- return_op: false, **kw, &blk)
- c = new_active_call(method, marshal, unmarshal, timeout)
+ return_op: false,
+ parent: nil,
+ **kw,
+ &blk)
+ c = new_active_call(method, marshal, unmarshal, timeout, parent: parent)
kw_with_jwt_uri = self.class.update_with_jwt_aud_uri(kw, @host, method)
md = @update_metadata.nil? ? kw : @update_metadata.call(kw_with_jwt_uri)
return c.bidi_streamer(requests, **md, &blk) unless return_op
@@ -407,10 +435,17 @@
# @param method [string] the method being called.
# @param marshal [Function] f(obj)->string that marshals requests
# @param unmarshal [Function] f(string)->obj that unmarshals responses
+ # @param parent [Grpc::Call] a parent call, available when calls are
+ # made from server
# @param timeout [TimeConst]
- def new_active_call(method, marshal, unmarshal, timeout = nil)
+ def new_active_call(method, marshal, unmarshal, timeout = nil, parent: nil)
deadline = from_relative_time(timeout.nil? ? @timeout : timeout)
- call = @ch.create_call(@queue, method, nil, deadline)
+ call = @ch.create_call(@queue,
+ parent, # parent call
+ @propagate_mask, # propagation options
+ method,
+ nil, # host use nil,
+ deadline)
ActiveCall.new(call, @queue, marshal, unmarshal, deadline, started: false)
end
end
diff --git a/src/ruby/spec/call_spec.rb b/src/ruby/spec/call_spec.rb
index 36a442f..3c5d33f 100644
--- a/src/ruby/spec/call_spec.rb
+++ b/src/ruby/spec/call_spec.rb
@@ -137,7 +137,7 @@
end
def make_test_call
- @ch.create_call(client_queue, 'dummy_method', nil, deadline)
+ @ch.create_call(client_queue, nil, nil, 'dummy_method', nil, deadline)
end
def deadline
diff --git a/src/ruby/spec/channel_spec.rb b/src/ruby/spec/channel_spec.rb
index 9081f0e..25cefcd 100644
--- a/src/ruby/spec/channel_spec.rb
+++ b/src/ruby/spec/channel_spec.rb
@@ -117,7 +117,7 @@
deadline = Time.now + 5
blk = proc do
- ch.create_call(cq, 'dummy_method', nil, deadline)
+ ch.create_call(cq, nil, nil, 'dummy_method', nil, deadline)
end
expect(&blk).to_not raise_error
end
@@ -128,7 +128,7 @@
deadline = Time.now + 5
blk = proc do
- ch.create_call(cq, 'dummy_method', nil, deadline)
+ ch.create_call(cq, nil, nil, 'dummy_method', nil, deadline)
end
expect(&blk).to raise_error(RuntimeError)
end
diff --git a/src/ruby/spec/client_server_spec.rb b/src/ruby/spec/client_server_spec.rb
index 57c9a8d..2e673ff 100644
--- a/src/ruby/spec/client_server_spec.rb
+++ b/src/ruby/spec/client_server_spec.rb
@@ -61,7 +61,7 @@
end
def new_client_call
- @ch.create_call(@client_queue, '/method', nil, deadline)
+ @ch.create_call(@client_queue, nil, nil, '/method', nil, deadline)
end
end
diff --git a/src/ruby/spec/generic/active_call_spec.rb b/src/ruby/spec/generic/active_call_spec.rb
index 424b2db..0bf65ba 100644
--- a/src/ruby/spec/generic/active_call_spec.rb
+++ b/src/ruby/spec/generic/active_call_spec.rb
@@ -338,7 +338,7 @@
end
def make_test_call
- @ch.create_call(@client_queue, '/method', nil, deadline)
+ @ch.create_call(@client_queue, nil, nil, '/method', nil, deadline)
end
def deadline