Add protobuf_helpers.field_mask to calculate a field mask from two messages (#5320)

diff --git a/google/api_core/protobuf_helpers.py b/google/api_core/protobuf_helpers.py
index a3f3227..ab6d5a2 100644
--- a/google/api_core/protobuf_helpers.py
+++ b/google/api_core/protobuf_helpers.py
@@ -15,11 +15,25 @@
 """Helpers for :mod:`protobuf`."""
 
 import collections
+import copy
 import inspect
 
-from google.protobuf.message import Message
+from google.protobuf import field_mask_pb2
+from google.protobuf import message
+from google.protobuf import wrappers_pb2
 
 _SENTINEL = object()
+_WRAPPER_TYPES = (
+    wrappers_pb2.BoolValue,
+    wrappers_pb2.BytesValue,
+    wrappers_pb2.DoubleValue,
+    wrappers_pb2.FloatValue,
+    wrappers_pb2.Int32Value,
+    wrappers_pb2.Int64Value,
+    wrappers_pb2.StringValue,
+    wrappers_pb2.UInt32Value,
+    wrappers_pb2.UInt64Value,
+)
 
 
 def from_any_pb(pb_type, any_pb):
@@ -73,13 +87,15 @@
             module to find Message subclasses.
 
     Returns:
-        dict[str, Message]: A dictionary with the Message class names as
-            keys, and the Message subclasses themselves as values.
+        dict[str, google.protobuf.message.Message]: A dictionary with the
+            Message class names as keys, and the Message subclasses themselves
+            as values.
     """
     answer = collections.OrderedDict()
     for name in dir(module):
         candidate = getattr(module, name)
-        if inspect.isclass(candidate) and issubclass(candidate, Message):
+        if (inspect.isclass(candidate) and
+                issubclass(candidate, message.Message)):
             answer[name] = candidate
     return answer
 
@@ -143,7 +159,7 @@
 
     # Attempt to get the value from the two types of objects we know about.
     # If we get something else, complain.
-    if isinstance(msg_or_dict, Message):
+    if isinstance(msg_or_dict, message.Message):
         answer = getattr(msg_or_dict, key, default)
     elif isinstance(msg_or_dict, collections.Mapping):
         answer = msg_or_dict.get(key, default)
@@ -186,7 +202,7 @@
         # Assign the dictionary values to the protobuf message.
         for item_key, item_value in value.items():
             set(getattr(msg, key), item_key, item_value)
-    elif isinstance(value, Message):
+    elif isinstance(value, message.Message):
         getattr(msg, key).CopyFrom(value)
     else:
         setattr(msg, key, value)
@@ -205,7 +221,8 @@
         TypeError: If ``msg_or_dict`` is not a Message or dictionary.
     """
     # Sanity check: Is our target object valid?
-    if not isinstance(msg_or_dict, (collections.MutableMapping, Message)):
+    if (not isinstance(msg_or_dict,
+                       (collections.MutableMapping, message.Message))):
         raise TypeError(
             'set() expected a dict or protobuf message, got {!r}.'.format(
                 type(msg_or_dict)))
@@ -247,3 +264,84 @@
     """
     if not get(msg_or_dict, key, default=None):
         set(msg_or_dict, key, value)
+
+
+def field_mask(original, modified):
+    """Create a field mask by comparing two messages.
+
+    Args:
+        original (~google.protobuf.message.Message): the original message.
+            If set to None, this field will be interpretted as an empty
+            message.
+        modified (~google.protobuf.message.Message): the modified message.
+            If set to None, this field will be interpretted as an empty
+            message.
+
+    Returns:
+        google.protobuf.field_mask_pb2.FieldMask: field mask that contains
+        the list of field names that have different values between the two
+        messages. If the messages are equivalent, then the field mask is empty.
+
+    Raises:
+        ValueError: If the ``original`` or ``modified`` are not the same type.
+    """
+    if original is None and modified is None:
+        return field_mask_pb2.FieldMask()
+
+    if original is None and modified is not None:
+        original = copy.deepcopy(modified)
+        original.Clear()
+
+    if modified is None and original is not None:
+        modified = copy.deepcopy(original)
+        modified.Clear()
+
+    if type(original) != type(modified):
+        raise ValueError(
+                'expected that both original and modified should be of the '
+                'same type, received "{!r}" and "{!r}".'.
+                format(type(original), type(modified)))
+
+    return field_mask_pb2.FieldMask(
+        paths=_field_mask_helper(original, modified))
+
+
+def _field_mask_helper(original, modified, current=''):
+    answer = []
+
+    for name in original.DESCRIPTOR.fields_by_name:
+        field_path = _get_path(current, name)
+
+        original_val = getattr(original, name)
+        modified_val = getattr(modified, name)
+
+        if _is_message(original_val) or _is_message(modified_val):
+            if original_val != modified_val:
+                # Wrapper types do not need to include the .value part of the
+                # path.
+                if _is_wrapper(original_val) or _is_wrapper(modified_val):
+                    answer.append(field_path)
+                elif not modified_val.ListFields():
+                    answer.append(field_path)
+                else:
+                    answer.extend(_field_mask_helper(original_val,
+                                                     modified_val, field_path))
+        else:
+            if original_val != modified_val:
+                answer.append(field_path)
+
+    return answer
+
+
+def _get_path(current, name):
+    if not current:
+        return name
+    return '%s.%s' % (current, name)
+
+
+def _is_message(value):
+    return isinstance(value, message.Message)
+
+
+def _is_wrapper(value):
+    return type(value) in _WRAPPER_TYPES
diff --git a/tests/unit/test_protobuf_helpers.py b/tests/unit/test_protobuf_helpers.py
index 5ccc4b8..83e078b 100644
--- a/tests/unit/test_protobuf_helpers.py
+++ b/tests/unit/test_protobuf_helpers.py
@@ -18,8 +18,13 @@
 from google.api_core import protobuf_helpers
 from google.longrunning import operations_pb2
 from google.protobuf import any_pb2
+from google.protobuf import message
+from google.protobuf import source_context_pb2
+from google.protobuf import struct_pb2
 from google.protobuf import timestamp_pb2
-from google.protobuf.message import Message
+from google.protobuf import type_pb2
+from google.protobuf import wrappers_pb2
+from google.type import color_pb2
 from google.type import date_pb2
 from google.type import timeofday_pb2
 
@@ -67,7 +72,7 @@
 
     # Ensure that no non-Message objects were exported.
     for value in answer.values():
-        assert issubclass(value, Message)
+        assert issubclass(value, message.Message)
 
 
 def test_get_dict_absent():
@@ -230,3 +235,225 @@
     operation = operations_pb2.Operation(name='bar')
     protobuf_helpers.setdefault(operation, 'name', 'foo')
     assert operation.name == 'bar'
+
+
+def test_field_mask_invalid_args():
+    with pytest.raises(ValueError):
+        protobuf_helpers.field_mask('foo', any_pb2.Any())
+    with pytest.raises(ValueError):
+        protobuf_helpers.field_mask(any_pb2.Any(), 'bar')
+    with pytest.raises(ValueError):
+        protobuf_helpers.field_mask(any_pb2.Any(), operations_pb2.Operation())
+
+
+def test_field_mask_equal_values():
+    assert protobuf_helpers.field_mask(None, None).paths == []
+
+    original = struct_pb2.Value(number_value=1.0)
+    modified = struct_pb2.Value(number_value=1.0)
+    assert protobuf_helpers.field_mask(original, modified).paths == []
+
+    original = color_pb2.Color(alpha=wrappers_pb2.FloatValue(value=1.0))
+    modified = color_pb2.Color(alpha=wrappers_pb2.FloatValue(value=1.0))
+    assert protobuf_helpers.field_mask(original, modified).paths == []
+
+    original = struct_pb2.ListValue(
+            values=[struct_pb2.Value(number_value=1.0)])
+    modified = struct_pb2.ListValue(
+            values=[struct_pb2.Value(number_value=1.0)])
+    assert protobuf_helpers.field_mask(original, modified).paths == []
+
+    original = struct_pb2.Struct(
+            fields={'bar': struct_pb2.Value(number_value=1.0)})
+    modified = struct_pb2.Struct(
+            fields={'bar': struct_pb2.Value(number_value=1.0)})
+    assert protobuf_helpers.field_mask(original, modified).paths == []
+
+
+def test_field_mask_zero_values():
+    # Singular Values
+    original = color_pb2.Color(red=0.0)
+    modified = None
+    assert protobuf_helpers.field_mask(original, modified).paths == []
+
+    original = None
+    modified = color_pb2.Color(red=0.0)
+    assert protobuf_helpers.field_mask(original, modified).paths == []
+
+    # Repeated Values
+    original = struct_pb2.ListValue(values=[])
+    modified = None
+    assert protobuf_helpers.field_mask(original, modified).paths == []
+
+    original = None
+    modified = struct_pb2.ListValue(values=[])
+    assert protobuf_helpers.field_mask(original, modified).paths == []
+
+    # Maps
+    original = struct_pb2.Struct(fields={})
+    modified = None
+    assert protobuf_helpers.field_mask(original, modified).paths == []
+
+    original = None
+    modified = struct_pb2.Struct(fields={})
+    assert protobuf_helpers.field_mask(original, modified).paths == []
+
+    # Oneofs
+    original = struct_pb2.Value(number_value=0.0)
+    modified = None
+    assert protobuf_helpers.field_mask(original, modified).paths == []
+
+    original = None
+    modified = struct_pb2.Value(number_value=0.0)
+    assert protobuf_helpers.field_mask(original, modified).paths == []
+
+
+def test_field_mask_singular_field_diffs():
+    original = type_pb2.Type(name='name')
+    modified = type_pb2.Type()
+    assert (protobuf_helpers.field_mask(original, modified).paths ==
+            ['name'])
+
+    original = type_pb2.Type(name='name')
+    modified = type_pb2.Type()
+    assert (protobuf_helpers.field_mask(original, modified).paths ==
+            ['name'])
+
+    original = None
+    modified = type_pb2.Type(name='name')
+    assert (protobuf_helpers.field_mask(original, modified).paths ==
+            ['name'])
+
+    original = type_pb2.Type(name='name')
+    modified = None
+    assert (protobuf_helpers.field_mask(original, modified).paths ==
+            ['name'])
+
+
+def test_field_mask_message_diffs():
+    original = type_pb2.Type()
+    modified = type_pb2.Type(source_context=source_context_pb2.SourceContext(
+                            file_name='name'))
+    assert (protobuf_helpers.field_mask(original, modified).paths ==
+            ['source_context.file_name'])
+
+    original = type_pb2.Type(source_context=source_context_pb2.SourceContext(
+                             file_name='name'))
+    modified = type_pb2.Type()
+    assert (protobuf_helpers.field_mask(original, modified).paths ==
+            ['source_context'])
+
+    original = type_pb2.Type(source_context=source_context_pb2.SourceContext(
+                             file_name='name'))
+    modified = type_pb2.Type(source_context=source_context_pb2.SourceContext(
+                             file_name='other_name'))
+    assert (protobuf_helpers.field_mask(original, modified).paths ==
+            ['source_context.file_name'])
+
+    original = None
+    modified = type_pb2.Type(source_context=source_context_pb2.SourceContext(
+                             file_name='name'))
+    assert (protobuf_helpers.field_mask(original, modified).paths ==
+            ['source_context.file_name'])
+
+    original = type_pb2.Type(source_context=source_context_pb2.SourceContext(
+                             file_name='name'))
+    modified = None
+    assert (protobuf_helpers.field_mask(original, modified).paths ==
+            ['source_context'])
+
+
+def test_field_mask_wrapper_type_diffs():
+    original = color_pb2.Color()
+    modified = color_pb2.Color(alpha=wrappers_pb2.FloatValue(value=1.0))
+    assert protobuf_helpers.field_mask(original, modified).paths == ['alpha']
+
+    original = color_pb2.Color(alpha=wrappers_pb2.FloatValue(value=1.0))
+    modified = color_pb2.Color()
+    assert (protobuf_helpers.field_mask(original, modified).paths ==
+            ['alpha'])
+
+    original = color_pb2.Color(alpha=wrappers_pb2.FloatValue(value=1.0))
+    modified = color_pb2.Color(alpha=wrappers_pb2.FloatValue(value=2.0))
+    assert (protobuf_helpers.field_mask(original, modified).paths ==
+            ['alpha'])
+
+    original = None
+    modified = color_pb2.Color(alpha=wrappers_pb2.FloatValue(value=2.0))
+    assert (protobuf_helpers.field_mask(original, modified).paths ==
+            ['alpha'])
+
+    original = color_pb2.Color(alpha=wrappers_pb2.FloatValue(value=1.0))
+    modified = None
+    assert (protobuf_helpers.field_mask(original, modified).paths ==
+            ['alpha'])
+
+
+def test_field_mask_repeated_diffs():
+    original = struct_pb2.ListValue()
+    modified = struct_pb2.ListValue(values=[struct_pb2.Value(number_value=1.0),
+                                    struct_pb2.Value(number_value=2.0)])
+    assert protobuf_helpers.field_mask(original, modified).paths == ['values']
+
+    original = struct_pb2.ListValue(values=[struct_pb2.Value(number_value=1.0),
+                                    struct_pb2.Value(number_value=2.0)])
+    modified = struct_pb2.ListValue()
+    assert protobuf_helpers.field_mask(original, modified).paths == ['values']
+
+    original = None
+    modified = struct_pb2.ListValue(values=[struct_pb2.Value(number_value=1.0),
+                                    struct_pb2.Value(number_value=2.0)])
+    assert protobuf_helpers.field_mask(original, modified).paths == ['values']
+
+    original = struct_pb2.ListValue(values=[struct_pb2.Value(number_value=1.0),
+                                    struct_pb2.Value(number_value=2.0)])
+    modified = None
+    assert protobuf_helpers.field_mask(original, modified).paths == ['values']
+
+    original = struct_pb2.ListValue(values=[struct_pb2.Value(number_value=1.0),
+                                    struct_pb2.Value(number_value=2.0)])
+    modified = struct_pb2.ListValue(values=[struct_pb2.Value(number_value=2.0),
+                                    struct_pb2.Value(number_value=1.0)])
+    assert protobuf_helpers.field_mask(original, modified).paths == ['values']
+
+
+def test_field_mask_map_diffs():
+    original = struct_pb2.Struct()
+    modified = struct_pb2.Struct(
+            fields={'foo': struct_pb2.Value(number_value=1.0)})
+    assert protobuf_helpers.field_mask(original, modified).paths == ['fields']
+
+    original = struct_pb2.Struct(
+            fields={'foo': struct_pb2.Value(number_value=1.0)})
+    modified = struct_pb2.Struct()
+    assert protobuf_helpers.field_mask(original, modified).paths == ['fields']
+
+    original = None
+    modified = struct_pb2.Struct(
+            fields={'foo': struct_pb2.Value(number_value=1.0)})
+    assert protobuf_helpers.field_mask(original, modified).paths == ['fields']
+
+    original = struct_pb2.Struct(
+            fields={'foo': struct_pb2.Value(number_value=1.0)})
+    modified = None
+    assert protobuf_helpers.field_mask(original, modified).paths == ['fields']
+
+    original = struct_pb2.Struct(
+            fields={'foo': struct_pb2.Value(number_value=1.0)})
+    modified = struct_pb2.Struct(
+            fields={'foo': struct_pb2.Value(number_value=2.0)})
+    assert protobuf_helpers.field_mask(original, modified).paths == ['fields']
+
+    original = struct_pb2.Struct(
+            fields={'foo': struct_pb2.Value(number_value=1.0)})
+    modified = struct_pb2.Struct(
+            fields={'bar': struct_pb2.Value(number_value=1.0)})
+    assert protobuf_helpers.field_mask(original, modified).paths == ['fields']
+
+
+def test_field_mask_different_level_diffs():
+    original = color_pb2.Color(alpha=wrappers_pb2.FloatValue(value=1.0))
+    modified = color_pb2.Color(
+            alpha=wrappers_pb2.FloatValue(value=2.0), red=1.0)
+    assert (sorted(protobuf_helpers.field_mask(original, modified).paths) ==
+            ['alpha', 'red'])