Extend special casing of NumPy arrays to anything supporting __array__.

If a non-tf class (np.ndarray subclass or not) has the __array__ method it is
used in combination with tf.constant to convert the input array to a tf.Tensor
before the users function or a cached graph function is called.

Note to match current tested behavior I have special cased instances of np.str_
to *not* be automatically converted to tensor.

PiperOrigin-RevId: 307375562
Change-Id: I9a84efc1dfdbb916c9894f138739e9b7e16f757e
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index e6802e8..6a3f99b 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -53,6 +53,7 @@
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import tensor_spec
+from tensorflow.python.framework import tensor_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import custom_gradient
@@ -106,32 +107,36 @@
     return tuple(map(lambda e: _make_input_signature_hashable(e, variable_map),
                      elem))
 
-  # If the element is not hashable, assume it is a weakref to a variable
-  # and return the dtype & shape. Else, simply return the element
   try:
     hash(elem)
   except TypeError:
+    # TFE_Py_EncodeArg weakrefs arguments it does not recognize, and we expect
+    # all recognized types to be hashable.
     assert isinstance(elem, weakref.ReferenceType)
     v = elem()
 
-    # Check if v is a Variable.  Note that we can't use isinstance to check if
-    # it's a variable, since not all variable types are subclass of Variable.
-    # TODO(mdan) Update this to use a generic "Variable" superclass once we
-    # create one.
-    if not (hasattr(v, "shape") and hasattr(v, "dtype")):
-      raise ValueError("Arguments to a tf.function must be Tensors, Variables, "
-                       "or hashable Python objects (or nested structures of "
-                       "these types).\nGot type: %s" % type(v).__name__)
+    if resource_variable_ops.is_resource_variable(v):
+      idx = variable_map.get(id(v))
+      if idx is None:
+        idx = len(variable_map)
+        variable_map[id(v)] = idx
 
-    idx = variable_map.get(id(v))
-    if idx is None:
-      idx = len(variable_map)
-      variable_map[id(v)] = idx
+      # We include the class name to avoid having different types of variables
+      # having the same hash. We Also include the variable index which allows
+      # us to return a different hash if variables have been aliased in a call.
+      return v.__class__, tensor_spec.TensorSpec(v.shape, v.dtype), idx
 
-    # We include the class name to avoid having different types of variables
-    # having the same hash. We Also include the variable index which allows
-    # us to return a different hash if variables have been aliased in a call.
-    return v.__class__, tensor_spec.TensorSpec(v.shape, v.dtype), idx
+    if _is_ndarray(v):
+      # Numpy arrays are not hashable, but when calling functions we treat them
+      # in the same way as tf.Tensors.
+      if not hasattr(v, "shape") or not hasattr(v, "dtype"):
+        # TODO(tomhennigan) De-dup with _as_ndarray in _convert_numpy_inputs.
+        v = _as_ndarray(v)
+      return tensor_spec.TensorSpec(v.shape, v.dtype)
+
+    raise ValueError("Arguments to a tf.function must be Tensors, Variables, "
+                     "or hashable Python objects (or nested structures of "
+                     "these types).\nGot type: %s" % type(v).__name__)
 
   return elem
 
@@ -2240,6 +2245,22 @@
       return inputs, {}
 
 
+def _as_ndarray(value):
+  """Converts value to an ndarray, assumes _is_ndarray(value)."""
+  # TODO(tomhennigan) Support __array_interface__ too.
+  return value.__array__()
+
+
+def _is_ndarray(value):
+  """Tests whether the given value is an ndarray (and not a TF tensor/var)."""
+  # TODO(tomhennigan) Support __array_interface__ too.
+  return hasattr(value, "__array__") and not (
+      resource_variable_ops.is_resource_variable(value)
+      or tensor_util.is_tensor(value)
+      # For legacy reasons we do not automatically promote Numpy strings.
+      or isinstance(value, np.str_))
+
+
 def _convert_numpy_inputs(inputs):
   """Convert numpy array inputs to tensors."""
   # We assume that any CompositeTensors have already converted their components
@@ -2252,8 +2273,12 @@
   # possible since ndarrays are not hashable).
   need_packing = False
   for index, value in enumerate(flat_inputs):
-    if type(value) == np.ndarray:
-      flat_inputs[index] = constant_op.constant(value)
+    if _is_ndarray(value):
+      a = _as_ndarray(value)
+      if not isinstance(a, np.ndarray):
+        raise TypeError("The output of __array__ must be an np.ndarray "
+                        "(got {} from {}).".format(type(a), type(value)))
+      flat_inputs[index] = constant_op.constant(a)
       need_packing = True
   if need_packing:
     return nest.pack_sequence_as(
diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py
index 96166f2..be6524c 100644
--- a/tensorflow/python/eager/function_test.py
+++ b/tensorflow/python/eager/function_test.py
@@ -762,11 +762,34 @@
     # shouldn't trigger another function definition.
     self.assertLen(total_function_cache(defined), 1)
 
+    np_ones = numpy.ones([], numpy.float32)
+    np_zeros = numpy.zeros([], numpy.float32)
+    tf_ones = array_ops.ones([])
+    tf_zeros = array_ops.zeros([])
+
     # Test that the numpy array is properly an argument to the graph function.
-    self.assertEqual(1., defined(numpy.ones([])).numpy())
-    self.assertEqual(0., defined(numpy.zeros([])).numpy())
-    self.assertEqual(1., defined(array_ops.ones([])).numpy())
-    self.assertEqual(0., defined(array_ops.zeros([])).numpy())
+    self.assertEqual(1., defined(np_ones).numpy())
+    self.assertLen(total_function_cache(defined), 2)
+    self.assertEqual(0., defined(np_zeros).numpy())
+    self.assertEqual(1., defined(tf_ones).numpy())
+    self.assertEqual(0., defined(tf_zeros).numpy())
+    self.assertLen(total_function_cache(defined), 2)
+
+    # Test that mutable inputs are supported.
+    mutable = numpy.ones([], numpy.float32)
+    self.assertEqual(1., defined(mutable).numpy())
+    mutable.fill(0)
+    self.assertEqual(0., defined(mutable).numpy())
+
+    class MyNdarray(numpy.ndarray):
+      pass
+
+     # Test that the subclasses of ndarray are converted too.
+    self.assertEqual(1., defined(np_ones.view(MyNdarray)).numpy())
+    self.assertEqual(0., defined(np_zeros.view(MyNdarray)).numpy())
+
+    # We should not have triggered any re-tracing of the python function.
+    self.assertLen(total_function_cache(defined), 2)
 
   def testDefunNumpyArraysConvertedToTensorsInKwargs(self):