Add protobuf_helpers.field_mask to calculate a field mask from two messages (#5320)
diff --git a/google/api_core/protobuf_helpers.py b/google/api_core/protobuf_helpers.py
index a3f3227..ab6d5a2 100644
--- a/google/api_core/protobuf_helpers.py
+++ b/google/api_core/protobuf_helpers.py
@@ -15,11 +15,25 @@
"""Helpers for :mod:`protobuf`."""
import collections
+import copy
import inspect
-from google.protobuf.message import Message
+from google.protobuf import field_mask_pb2
+from google.protobuf import message
+from google.protobuf import wrappers_pb2
_SENTINEL = object()
+_WRAPPER_TYPES = (
+ wrappers_pb2.BoolValue,
+ wrappers_pb2.BytesValue,
+ wrappers_pb2.DoubleValue,
+ wrappers_pb2.FloatValue,
+ wrappers_pb2.Int32Value,
+ wrappers_pb2.Int64Value,
+ wrappers_pb2.StringValue,
+ wrappers_pb2.UInt32Value,
+ wrappers_pb2.UInt64Value,
+)
def from_any_pb(pb_type, any_pb):
@@ -73,13 +87,15 @@
module to find Message subclasses.
Returns:
- dict[str, Message]: A dictionary with the Message class names as
- keys, and the Message subclasses themselves as values.
+ dict[str, google.protobuf.message.Message]: A dictionary with the
+ Message class names as keys, and the Message subclasses themselves
+ as values.
"""
answer = collections.OrderedDict()
for name in dir(module):
candidate = getattr(module, name)
- if inspect.isclass(candidate) and issubclass(candidate, Message):
+ if (inspect.isclass(candidate) and
+ issubclass(candidate, message.Message)):
answer[name] = candidate
return answer
@@ -143,7 +159,7 @@
# Attempt to get the value from the two types of objects we know about.
# If we get something else, complain.
- if isinstance(msg_or_dict, Message):
+ if isinstance(msg_or_dict, message.Message):
answer = getattr(msg_or_dict, key, default)
elif isinstance(msg_or_dict, collections.Mapping):
answer = msg_or_dict.get(key, default)
@@ -186,7 +202,7 @@
# Assign the dictionary values to the protobuf message.
for item_key, item_value in value.items():
set(getattr(msg, key), item_key, item_value)
- elif isinstance(value, Message):
+ elif isinstance(value, message.Message):
getattr(msg, key).CopyFrom(value)
else:
setattr(msg, key, value)
@@ -205,7 +221,8 @@
TypeError: If ``msg_or_dict`` is not a Message or dictionary.
"""
# Sanity check: Is our target object valid?
- if not isinstance(msg_or_dict, (collections.MutableMapping, Message)):
+ if (not isinstance(msg_or_dict,
+ (collections.MutableMapping, message.Message))):
raise TypeError(
'set() expected a dict or protobuf message, got {!r}.'.format(
type(msg_or_dict)))
@@ -247,3 +264,84 @@
"""
if not get(msg_or_dict, key, default=None):
set(msg_or_dict, key, value)
+
+
+def field_mask(original, modified):
+ """Create a field mask by comparing two messages.
+
+ Args:
+ original (~google.protobuf.message.Message): the original message.
+ If set to None, this field will be interpretted as an empty
+ message.
+ modified (~google.protobuf.message.Message): the modified message.
+ If set to None, this field will be interpretted as an empty
+ message.
+
+ Returns:
+ google.protobuf.field_mask_pb2.FieldMask: field mask that contains
+ the list of field names that have different values between the two
+ messages. If the messages are equivalent, then the field mask is empty.
+
+ Raises:
+ ValueError: If the ``original`` or ``modified`` are not the same type.
+ """
+ if original is None and modified is None:
+ return field_mask_pb2.FieldMask()
+
+ if original is None and modified is not None:
+ original = copy.deepcopy(modified)
+ original.Clear()
+
+ if modified is None and original is not None:
+ modified = copy.deepcopy(original)
+ modified.Clear()
+
+ if type(original) != type(modified):
+ raise ValueError(
+ 'expected that both original and modified should be of the '
+ 'same type, received "{!r}" and "{!r}".'.
+ format(type(original), type(modified)))
+
+ return field_mask_pb2.FieldMask(
+ paths=_field_mask_helper(original, modified))
+
+
+def _field_mask_helper(original, modified, current=''):
+ answer = []
+
+ for name in original.DESCRIPTOR.fields_by_name:
+ field_path = _get_path(current, name)
+
+ original_val = getattr(original, name)
+ modified_val = getattr(modified, name)
+
+ if _is_message(original_val) or _is_message(modified_val):
+ if original_val != modified_val:
+ # Wrapper types do not need to include the .value part of the
+ # path.
+ if _is_wrapper(original_val) or _is_wrapper(modified_val):
+ answer.append(field_path)
+ elif not modified_val.ListFields():
+ answer.append(field_path)
+ else:
+ answer.extend(_field_mask_helper(original_val,
+ modified_val, field_path))
+ else:
+ if original_val != modified_val:
+ answer.append(field_path)
+
+ return answer
+
+
+def _get_path(current, name):
+ if not current:
+ return name
+ return '%s.%s' % (current, name)
+
+
+def _is_message(value):
+ return isinstance(value, message.Message)
+
+
+def _is_wrapper(value):
+ return type(value) in _WRAPPER_TYPES
diff --git a/tests/unit/test_protobuf_helpers.py b/tests/unit/test_protobuf_helpers.py
index 5ccc4b8..83e078b 100644
--- a/tests/unit/test_protobuf_helpers.py
+++ b/tests/unit/test_protobuf_helpers.py
@@ -18,8 +18,13 @@
from google.api_core import protobuf_helpers
from google.longrunning import operations_pb2
from google.protobuf import any_pb2
+from google.protobuf import message
+from google.protobuf import source_context_pb2
+from google.protobuf import struct_pb2
from google.protobuf import timestamp_pb2
-from google.protobuf.message import Message
+from google.protobuf import type_pb2
+from google.protobuf import wrappers_pb2
+from google.type import color_pb2
from google.type import date_pb2
from google.type import timeofday_pb2
@@ -67,7 +72,7 @@
# Ensure that no non-Message objects were exported.
for value in answer.values():
- assert issubclass(value, Message)
+ assert issubclass(value, message.Message)
def test_get_dict_absent():
@@ -230,3 +235,225 @@
operation = operations_pb2.Operation(name='bar')
protobuf_helpers.setdefault(operation, 'name', 'foo')
assert operation.name == 'bar'
+
+
+def test_field_mask_invalid_args():
+ with pytest.raises(ValueError):
+ protobuf_helpers.field_mask('foo', any_pb2.Any())
+ with pytest.raises(ValueError):
+ protobuf_helpers.field_mask(any_pb2.Any(), 'bar')
+ with pytest.raises(ValueError):
+ protobuf_helpers.field_mask(any_pb2.Any(), operations_pb2.Operation())
+
+
+def test_field_mask_equal_values():
+ assert protobuf_helpers.field_mask(None, None).paths == []
+
+ original = struct_pb2.Value(number_value=1.0)
+ modified = struct_pb2.Value(number_value=1.0)
+ assert protobuf_helpers.field_mask(original, modified).paths == []
+
+ original = color_pb2.Color(alpha=wrappers_pb2.FloatValue(value=1.0))
+ modified = color_pb2.Color(alpha=wrappers_pb2.FloatValue(value=1.0))
+ assert protobuf_helpers.field_mask(original, modified).paths == []
+
+ original = struct_pb2.ListValue(
+ values=[struct_pb2.Value(number_value=1.0)])
+ modified = struct_pb2.ListValue(
+ values=[struct_pb2.Value(number_value=1.0)])
+ assert protobuf_helpers.field_mask(original, modified).paths == []
+
+ original = struct_pb2.Struct(
+ fields={'bar': struct_pb2.Value(number_value=1.0)})
+ modified = struct_pb2.Struct(
+ fields={'bar': struct_pb2.Value(number_value=1.0)})
+ assert protobuf_helpers.field_mask(original, modified).paths == []
+
+
+def test_field_mask_zero_values():
+ # Singular Values
+ original = color_pb2.Color(red=0.0)
+ modified = None
+ assert protobuf_helpers.field_mask(original, modified).paths == []
+
+ original = None
+ modified = color_pb2.Color(red=0.0)
+ assert protobuf_helpers.field_mask(original, modified).paths == []
+
+ # Repeated Values
+ original = struct_pb2.ListValue(values=[])
+ modified = None
+ assert protobuf_helpers.field_mask(original, modified).paths == []
+
+ original = None
+ modified = struct_pb2.ListValue(values=[])
+ assert protobuf_helpers.field_mask(original, modified).paths == []
+
+ # Maps
+ original = struct_pb2.Struct(fields={})
+ modified = None
+ assert protobuf_helpers.field_mask(original, modified).paths == []
+
+ original = None
+ modified = struct_pb2.Struct(fields={})
+ assert protobuf_helpers.field_mask(original, modified).paths == []
+
+ # Oneofs
+ original = struct_pb2.Value(number_value=0.0)
+ modified = None
+ assert protobuf_helpers.field_mask(original, modified).paths == []
+
+ original = None
+ modified = struct_pb2.Value(number_value=0.0)
+ assert protobuf_helpers.field_mask(original, modified).paths == []
+
+
+def test_field_mask_singular_field_diffs():
+ original = type_pb2.Type(name='name')
+ modified = type_pb2.Type()
+ assert (protobuf_helpers.field_mask(original, modified).paths ==
+ ['name'])
+
+ original = type_pb2.Type(name='name')
+ modified = type_pb2.Type()
+ assert (protobuf_helpers.field_mask(original, modified).paths ==
+ ['name'])
+
+ original = None
+ modified = type_pb2.Type(name='name')
+ assert (protobuf_helpers.field_mask(original, modified).paths ==
+ ['name'])
+
+ original = type_pb2.Type(name='name')
+ modified = None
+ assert (protobuf_helpers.field_mask(original, modified).paths ==
+ ['name'])
+
+
+def test_field_mask_message_diffs():
+ original = type_pb2.Type()
+ modified = type_pb2.Type(source_context=source_context_pb2.SourceContext(
+ file_name='name'))
+ assert (protobuf_helpers.field_mask(original, modified).paths ==
+ ['source_context.file_name'])
+
+ original = type_pb2.Type(source_context=source_context_pb2.SourceContext(
+ file_name='name'))
+ modified = type_pb2.Type()
+ assert (protobuf_helpers.field_mask(original, modified).paths ==
+ ['source_context'])
+
+ original = type_pb2.Type(source_context=source_context_pb2.SourceContext(
+ file_name='name'))
+ modified = type_pb2.Type(source_context=source_context_pb2.SourceContext(
+ file_name='other_name'))
+ assert (protobuf_helpers.field_mask(original, modified).paths ==
+ ['source_context.file_name'])
+
+ original = None
+ modified = type_pb2.Type(source_context=source_context_pb2.SourceContext(
+ file_name='name'))
+ assert (protobuf_helpers.field_mask(original, modified).paths ==
+ ['source_context.file_name'])
+
+ original = type_pb2.Type(source_context=source_context_pb2.SourceContext(
+ file_name='name'))
+ modified = None
+ assert (protobuf_helpers.field_mask(original, modified).paths ==
+ ['source_context'])
+
+
+def test_field_mask_wrapper_type_diffs():
+ original = color_pb2.Color()
+ modified = color_pb2.Color(alpha=wrappers_pb2.FloatValue(value=1.0))
+ assert protobuf_helpers.field_mask(original, modified).paths == ['alpha']
+
+ original = color_pb2.Color(alpha=wrappers_pb2.FloatValue(value=1.0))
+ modified = color_pb2.Color()
+ assert (protobuf_helpers.field_mask(original, modified).paths ==
+ ['alpha'])
+
+ original = color_pb2.Color(alpha=wrappers_pb2.FloatValue(value=1.0))
+ modified = color_pb2.Color(alpha=wrappers_pb2.FloatValue(value=2.0))
+ assert (protobuf_helpers.field_mask(original, modified).paths ==
+ ['alpha'])
+
+ original = None
+ modified = color_pb2.Color(alpha=wrappers_pb2.FloatValue(value=2.0))
+ assert (protobuf_helpers.field_mask(original, modified).paths ==
+ ['alpha'])
+
+ original = color_pb2.Color(alpha=wrappers_pb2.FloatValue(value=1.0))
+ modified = None
+ assert (protobuf_helpers.field_mask(original, modified).paths ==
+ ['alpha'])
+
+
+def test_field_mask_repeated_diffs():
+ original = struct_pb2.ListValue()
+ modified = struct_pb2.ListValue(values=[struct_pb2.Value(number_value=1.0),
+ struct_pb2.Value(number_value=2.0)])
+ assert protobuf_helpers.field_mask(original, modified).paths == ['values']
+
+ original = struct_pb2.ListValue(values=[struct_pb2.Value(number_value=1.0),
+ struct_pb2.Value(number_value=2.0)])
+ modified = struct_pb2.ListValue()
+ assert protobuf_helpers.field_mask(original, modified).paths == ['values']
+
+ original = None
+ modified = struct_pb2.ListValue(values=[struct_pb2.Value(number_value=1.0),
+ struct_pb2.Value(number_value=2.0)])
+ assert protobuf_helpers.field_mask(original, modified).paths == ['values']
+
+ original = struct_pb2.ListValue(values=[struct_pb2.Value(number_value=1.0),
+ struct_pb2.Value(number_value=2.0)])
+ modified = None
+ assert protobuf_helpers.field_mask(original, modified).paths == ['values']
+
+ original = struct_pb2.ListValue(values=[struct_pb2.Value(number_value=1.0),
+ struct_pb2.Value(number_value=2.0)])
+ modified = struct_pb2.ListValue(values=[struct_pb2.Value(number_value=2.0),
+ struct_pb2.Value(number_value=1.0)])
+ assert protobuf_helpers.field_mask(original, modified).paths == ['values']
+
+
+def test_field_mask_map_diffs():
+ original = struct_pb2.Struct()
+ modified = struct_pb2.Struct(
+ fields={'foo': struct_pb2.Value(number_value=1.0)})
+ assert protobuf_helpers.field_mask(original, modified).paths == ['fields']
+
+ original = struct_pb2.Struct(
+ fields={'foo': struct_pb2.Value(number_value=1.0)})
+ modified = struct_pb2.Struct()
+ assert protobuf_helpers.field_mask(original, modified).paths == ['fields']
+
+ original = None
+ modified = struct_pb2.Struct(
+ fields={'foo': struct_pb2.Value(number_value=1.0)})
+ assert protobuf_helpers.field_mask(original, modified).paths == ['fields']
+
+ original = struct_pb2.Struct(
+ fields={'foo': struct_pb2.Value(number_value=1.0)})
+ modified = None
+ assert protobuf_helpers.field_mask(original, modified).paths == ['fields']
+
+ original = struct_pb2.Struct(
+ fields={'foo': struct_pb2.Value(number_value=1.0)})
+ modified = struct_pb2.Struct(
+ fields={'foo': struct_pb2.Value(number_value=2.0)})
+ assert protobuf_helpers.field_mask(original, modified).paths == ['fields']
+
+ original = struct_pb2.Struct(
+ fields={'foo': struct_pb2.Value(number_value=1.0)})
+ modified = struct_pb2.Struct(
+ fields={'bar': struct_pb2.Value(number_value=1.0)})
+ assert protobuf_helpers.field_mask(original, modified).paths == ['fields']
+
+
+def test_field_mask_different_level_diffs():
+ original = color_pb2.Color(alpha=wrappers_pb2.FloatValue(value=1.0))
+ modified = color_pb2.Color(
+ alpha=wrappers_pb2.FloatValue(value=2.0), red=1.0)
+ assert (sorted(protobuf_helpers.field_mask(original, modified).paths) ==
+ ['alpha', 'red'])