Add final set of protobuf helpers to api_core (#4259)
diff --git a/google/api_core/protobuf_helpers.py b/google/api_core/protobuf_helpers.py
index 6031ff0..fbe1a82 100644
--- a/google/api_core/protobuf_helpers.py
+++ b/google/api_core/protobuf_helpers.py
@@ -19,6 +19,8 @@
from google.protobuf.message import Message
+_SENTINEL = object()
+
def from_any_pb(pb_type, any_pb):
"""Converts an ``Any`` protobuf to the specified message type.
@@ -44,11 +46,13 @@
def check_oneof(**kwargs):
- """Raise ValueError if more than one keyword argument is not none.
+ """Raise ValueError if more than one keyword argument is not ``None``.
+
Args:
kwargs (dict): The keyword arguments sent to the function.
+
Raises:
- ValueError: If more than one entry in kwargs is not none.
+ ValueError: If more than one entry in ``kwargs`` is not ``None``.
"""
# Sanity check: If no keyword arguments were sent, this is fine.
if not kwargs:
@@ -62,10 +66,12 @@
def get_messages(module):
- """Return a dictionary of message names and objects.
+ """Discovers all protobuf Message classes in a given import module.
+
Args:
- module (module): A Python module; dir() will be run against this
+ module (module): A Python module; :func:`dir` will be run against this
module to find Message subclasses.
+
Returns:
dict[str, Message]: A dictionary with the Message class names as
keys, and the Message subclasses themselves as values.
@@ -76,3 +82,168 @@
if inspect.isclass(candidate) and issubclass(candidate, Message):
answer[name] = candidate
return answer
+
+
+def _resolve_subkeys(key, separator='.'):
+ """Resolve a potentially nested key.
+
+ If the key contains the ``separator`` (e.g. ``.``) then the key will be
+ split on the first instance of the subkey::
+
+ >>> _resolve_subkeys('a.b.c')
+ ('a', 'b.c')
+ >>> _resolve_subkeys('d|e|f', separator='|')
+ ('d', 'e|f')
+
+ If not, the subkey will be :data:`None`::
+
+ >>> _resolve_subkeys('foo')
+ ('foo', None)
+
+ Args:
+ key (str): A string that may or may not contain the separator.
+ separator (str): The namespace separator. Defaults to `.`.
+
+ Returns:
+ Tuple[str, str]: The key and subkey(s).
+ """
+ parts = key.split(separator, 1)
+
+ if len(parts) > 1:
+ return parts
+ else:
+ return parts[0], None
+
+
+def get(msg_or_dict, key, default=_SENTINEL):
+ """Retrieve a key's value from a protobuf Message or dictionary.
+
+ Args:
+ mdg_or_dict (Union[~google.protobuf.message.Message, Mapping]): the
+ object.
+ key (str): The key to retrieve from the object.
+ default (Any): If the key is not present on the object, and a default
+ is set, returns that default instead. A type-appropriate falsy
+ default is generally recommended, as protobuf messages almost
+ always have default values for unset values and it is not always
+ possible to tell the difference between a falsy value and an
+ unset one. If no default is set then :class:`KeyError` will be
+ raised if the key is not present in the object.
+
+ Returns:
+ Any: The return value from the underlying Message or dict.
+
+ Raises:
+ KeyError: If the key is not found. Note that, for unset values,
+ messages and dictionaries may not have consistent behavior.
+ TypeError: If ``msg_or_dict`` is not a Message or Mapping.
+ """
+ # We may need to get a nested key. Resolve this.
+ key, subkey = _resolve_subkeys(key)
+
+ # Attempt to get the value from the two types of objects we know about.
+ # If we get something else, complain.
+ if isinstance(msg_or_dict, Message):
+ answer = getattr(msg_or_dict, key, default)
+ elif isinstance(msg_or_dict, collections.Mapping):
+ answer = msg_or_dict.get(key, default)
+ else:
+ raise TypeError(
+ 'get() expected a dict or protobuf message, got {!r}.'.format(
+ type(msg_or_dict)))
+
+ # If the object we got back is our sentinel, raise KeyError; this is
+ # a "not found" case.
+ if answer is _SENTINEL:
+ raise KeyError(key)
+
+ # If a subkey exists, call this method recursively against the answer.
+ if subkey is not None and answer is not default:
+ return get(answer, subkey, default=default)
+
+ return answer
+
+
+def _set_field_on_message(msg, key, value):
+ """Set helper for protobuf Messages."""
+ # Attempt to set the value on the types of objects we know how to deal
+ # with.
+ if isinstance(value, (collections.MutableSequence, tuple)):
+ # Clear the existing repeated protobuf message of any elements
+ # currently inside it.
+ while getattr(msg, key):
+ getattr(msg, key).pop()
+
+ # Write our new elements to the repeated field.
+ for item in value:
+ if isinstance(item, collections.Mapping):
+ getattr(msg, key).add(**item)
+ else:
+ # protobuf's RepeatedCompositeContainer doesn't support
+ # append.
+ getattr(msg, key).extend([item])
+ elif isinstance(value, collections.Mapping):
+ # Assign the dictionary values to the protobuf message.
+ for item_key, item_value in value.items():
+ set(getattr(msg, key), item_key, item_value)
+ elif isinstance(value, Message):
+ getattr(msg, key).CopyFrom(value)
+ else:
+ setattr(msg, key, value)
+
+
+def set(msg_or_dict, key, value):
+ """Set a key's value on a protobuf Message or dictionary.
+
+ Args:
+ msg_or_dict (Union[~google.protobuf.message.Message, Mapping]): the
+ object.
+ key (str): The key to set.
+ value (Any): The value to set.
+
+ Raises:
+ TypeError: If ``msg_or_dict`` is not a Message or dictionary.
+ """
+ # Sanity check: Is our target object valid?
+ if not isinstance(msg_or_dict, (collections.MutableMapping, Message)):
+ raise TypeError(
+ 'set() expected a dict or protobuf message, got {!r}.'.format(
+ type(msg_or_dict)))
+
+ # We may be setting a nested key. Resolve this.
+ basekey, subkey = _resolve_subkeys(key)
+
+ # If a subkey exists, then get that object and call this method
+ # recursively against it using the subkey.
+ if subkey is not None:
+ if isinstance(msg_or_dict, collections.MutableMapping):
+ msg_or_dict.setdefault(basekey, {})
+ set(get(msg_or_dict, basekey), subkey, value)
+ return
+
+ if isinstance(msg_or_dict, collections.MutableMapping):
+ msg_or_dict[key] = value
+ else:
+ _set_field_on_message(msg_or_dict, key, value)
+
+
+def setdefault(msg_or_dict, key, value):
+ """Set the key on a protobuf Message or dictionary to a given value if the
+ current value is falsy.
+
+ Because protobuf Messages do not distinguish between unset values and
+ falsy ones particularly well (by design), this method treats any falsy
+ value (e.g. 0, empty list) as a target to be overwritten, on both Messages
+ and dictionaries.
+
+ Args:
+ msg_or_dict (Union[~google.protobuf.message.Message, Mapping]): the
+ object.
+ key (str): The key on the object in question.
+ value (Any): The value to set.
+
+ Raises:
+ TypeError: If ``msg_or_dict`` is not a Message or dictionary.
+ """
+ if not get(msg_or_dict, key, default=None):
+ set(msg_or_dict, key, value)
diff --git a/tests/unit/test_protobuf_helpers.py b/tests/unit/test_protobuf_helpers.py
index b9aca76..8f86aa4 100644
--- a/tests/unit/test_protobuf_helpers.py
+++ b/tests/unit/test_protobuf_helpers.py
@@ -14,8 +14,11 @@
import pytest
+from google.api import http_pb2
from google.api_core import protobuf_helpers
+from google.longrunning import operations_pb2
from google.protobuf import any_pb2
+from google.protobuf import timestamp_pb2
from google.protobuf.message import Message
from google.type import date_pb2
from google.type import timeofday_pb2
@@ -65,3 +68,165 @@
# Ensure that no non-Message objects were exported.
for value in answer.values():
assert issubclass(value, Message)
+
+
+def test_get_dict_absent():
+ with pytest.raises(KeyError):
+ assert protobuf_helpers.get({}, 'foo')
+
+
+def test_get_dict_present():
+ assert protobuf_helpers.get({'foo': 'bar'}, 'foo') == 'bar'
+
+
+def test_get_dict_default():
+ assert protobuf_helpers.get({}, 'foo', default='bar') == 'bar'
+
+
+def test_get_dict_nested():
+ assert protobuf_helpers.get({'foo': {'bar': 'baz'}}, 'foo.bar') == 'baz'
+
+
+def test_get_dict_nested_default():
+ assert protobuf_helpers.get({}, 'foo.baz', default='bacon') == 'bacon'
+ assert (
+ protobuf_helpers.get({'foo': {}}, 'foo.baz', default='bacon') ==
+ 'bacon')
+
+
+def test_get_msg_sentinel():
+ msg = timestamp_pb2.Timestamp()
+ with pytest.raises(KeyError):
+ assert protobuf_helpers.get(msg, 'foo')
+
+
+def test_get_msg_present():
+ msg = timestamp_pb2.Timestamp(seconds=42)
+ assert protobuf_helpers.get(msg, 'seconds') == 42
+
+
+def test_get_msg_default():
+ msg = timestamp_pb2.Timestamp()
+ assert protobuf_helpers.get(msg, 'foo', default='bar') == 'bar'
+
+
+def test_invalid_object():
+ with pytest.raises(TypeError):
+ protobuf_helpers.get(object(), 'foo', 'bar')
+
+
+def test_set_dict():
+ mapping = {}
+ protobuf_helpers.set(mapping, 'foo', 'bar')
+ assert mapping == {'foo': 'bar'}
+
+
+def test_set_msg():
+ msg = timestamp_pb2.Timestamp()
+ protobuf_helpers.set(msg, 'seconds', 42)
+ assert msg.seconds == 42
+
+
+def test_set_dict_nested():
+ mapping = {}
+ protobuf_helpers.set(mapping, 'foo.bar', 'baz')
+ assert mapping == {'foo': {'bar': 'baz'}}
+
+
+def test_set_invalid_object():
+ with pytest.raises(TypeError):
+ protobuf_helpers.set(object(), 'foo', 'bar')
+
+
+def test_set_list():
+ list_ops_response = operations_pb2.ListOperationsResponse()
+
+ protobuf_helpers.set(list_ops_response, 'operations', [
+ {'name': 'foo'},
+ operations_pb2.Operation(name='bar'),
+ ])
+
+ assert len(list_ops_response.operations) == 2
+
+ for operation in list_ops_response.operations:
+ assert isinstance(operation, operations_pb2.Operation)
+
+ assert list_ops_response.operations[0].name == 'foo'
+ assert list_ops_response.operations[1].name == 'bar'
+
+
+def test_set_list_clear_existing():
+ list_ops_response = operations_pb2.ListOperationsResponse(
+ operations=[{'name': 'baz'}],
+ )
+
+ protobuf_helpers.set(list_ops_response, 'operations', [
+ {'name': 'foo'},
+ operations_pb2.Operation(name='bar'),
+ ])
+
+ assert len(list_ops_response.operations) == 2
+ for operation in list_ops_response.operations:
+ assert isinstance(operation, operations_pb2.Operation)
+ assert list_ops_response.operations[0].name == 'foo'
+ assert list_ops_response.operations[1].name == 'bar'
+
+
+def test_set_msg_with_msg_field():
+ rule = http_pb2.HttpRule()
+ pattern = http_pb2.CustomHttpPattern(kind='foo', path='bar')
+
+ protobuf_helpers.set(rule, 'custom', pattern)
+
+ assert rule.custom.kind == 'foo'
+ assert rule.custom.path == 'bar'
+
+
+def test_set_msg_with_dict_field():
+ rule = http_pb2.HttpRule()
+ pattern = {'kind': 'foo', 'path': 'bar'}
+
+ protobuf_helpers.set(rule, 'custom', pattern)
+
+ assert rule.custom.kind == 'foo'
+ assert rule.custom.path == 'bar'
+
+
+def test_set_msg_nested_key():
+ rule = http_pb2.HttpRule(
+ custom=http_pb2.CustomHttpPattern(kind='foo', path='bar'))
+
+ protobuf_helpers.set(rule, 'custom.kind', 'baz')
+
+ assert rule.custom.kind == 'baz'
+ assert rule.custom.path == 'bar'
+
+
+def test_setdefault_dict_unset():
+ mapping = {}
+ protobuf_helpers.setdefault(mapping, 'foo', 'bar')
+ assert mapping == {'foo': 'bar'}
+
+
+def test_setdefault_dict_falsy():
+ mapping = {'foo': None}
+ protobuf_helpers.setdefault(mapping, 'foo', 'bar')
+ assert mapping == {'foo': 'bar'}
+
+
+def test_setdefault_dict_truthy():
+ mapping = {'foo': 'bar'}
+ protobuf_helpers.setdefault(mapping, 'foo', 'baz')
+ assert mapping == {'foo': 'bar'}
+
+
+def test_setdefault_pb2_falsy():
+ operation = operations_pb2.Operation()
+ protobuf_helpers.setdefault(operation, 'name', 'foo')
+ assert operation.name == 'foo'
+
+
+def test_setdefault_pb2_truthy():
+ operation = operations_pb2.Operation(name='bar')
+ protobuf_helpers.setdefault(operation, 'name', 'foo')
+ assert operation.name == 'bar'