Extend special-casing of NumPy arrays to anything supporting __array__.

If a non-TF class (whether or not it subclasses np.ndarray) provides an
__array__ method, that method is used together with tf.constant to convert
the input array to a tf.Tensor before the user's function or a cached graph
function is called.

Note: to match currently tested behavior, instances of np.str_ are
special-cased so they are *not* automatically converted to tensors.
PiperOrigin-RevId: 307375562
Change-Id: I9a84efc1dfdbb916c9894f138739e9b7e16f757e
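
Illustrative sketch (not part of the patch): with this change, any object
exposing __array__ takes the same conversion path as a plain np.ndarray.
DoubleHolder below is a hypothetical class used only for demonstration.

import numpy as np
import tensorflow as tf


class DoubleHolder(object):
  """Not an np.ndarray subclass, but convertible via __array__."""

  def __init__(self, values):
    self._values = values

  def __array__(self):
    # Called (via _as_ndarray) to obtain an np.ndarray, which is then
    # wrapped with tf.constant before the function is invoked.
    return np.asarray(self._values, dtype=np.float32)


@tf.function
def square(x):
  return x * x


# The input is converted to a tf.Tensor before the traced function runs, so
# `x` inside `square` is a Tensor, and calls with the same shape/dtype reuse
# the cached graph function.
print(square(DoubleHolder([1., 2., 3.])))  # tf.Tensor([1. 4. 9.], ...)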
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index e6802e8..6a3f99b 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -53,6 +53,7 @@
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
+from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import custom_gradient
@@ -106,32 +107,36 @@
    return tuple(map(lambda e: _make_input_signature_hashable(e, variable_map),
                     elem))
-  # If the element is not hashable, assume it is a weakref to a variable
-  # and return the dtype & shape. Else, simply return the element
  try:
    hash(elem)
  except TypeError:
+    # TFE_Py_EncodeArg weakrefs arguments it does not recognize, and we expect
+    # all recognized types to be hashable.
    assert isinstance(elem, weakref.ReferenceType)
    v = elem()
-    # Check if v is a Variable. Note that we can't use isinstance to check if
-    # it's a variable, since not all variable types are subclass of Variable.
-    # TODO(mdan) Update this to use a generic "Variable" superclass once we
-    # create one.
-    if not (hasattr(v, "shape") and hasattr(v, "dtype")):
-      raise ValueError("Arguments to a tf.function must be Tensors, Variables, "
-                       "or hashable Python objects (or nested structures of "
-                       "these types).\nGot type: %s" % type(v).__name__)
+    if resource_variable_ops.is_resource_variable(v):
+      idx = variable_map.get(id(v))
+      if idx is None:
+        idx = len(variable_map)
+        variable_map[id(v)] = idx
-    idx = variable_map.get(id(v))
-    if idx is None:
-      idx = len(variable_map)
-      variable_map[id(v)] = idx
+      # We include the class name to avoid having different types of variables
+      # having the same hash. We also include the variable index, which allows
+      # us to return a different hash if variables have been aliased in a call.
+      return v.__class__, tensor_spec.TensorSpec(v.shape, v.dtype), idx
-    # We include the class name to avoid having different types of variables
-    # having the same hash. We Also include the variable index which allows
-    # us to return a different hash if variables have been aliased in a call.
-    return v.__class__, tensor_spec.TensorSpec(v.shape, v.dtype), idx
+    if _is_ndarray(v):
+      # Numpy arrays are not hashable, but when calling functions we treat them
+      # in the same way as tf.Tensors.
+      if not hasattr(v, "shape") or not hasattr(v, "dtype"):
+        # TODO(tomhennigan) De-dup with _as_ndarray in _convert_numpy_inputs.
+        v = _as_ndarray(v)
+      return tensor_spec.TensorSpec(v.shape, v.dtype)
+
+    raise ValueError("Arguments to a tf.function must be Tensors, Variables, "
+                     "or hashable Python objects (or nested structures of "
+                     "these types).\nGot type: %s" % type(v).__name__)
  return elem
@@ -2240,6 +2245,22 @@
  return inputs, {}
+def _as_ndarray(value):
+  """Converts value to an ndarray, assumes _is_ndarray(value)."""
+  # TODO(tomhennigan) Support __array_interface__ too.
+  return value.__array__()
+
+
+def _is_ndarray(value):
+  """Tests whether the given value is an ndarray (and not a TF tensor/var)."""
+  # TODO(tomhennigan) Support __array_interface__ too.
+  return hasattr(value, "__array__") and not (
+      resource_variable_ops.is_resource_variable(value)
+      or tensor_util.is_tensor(value)
+      # For legacy reasons we do not automatically promote Numpy strings.
+      or isinstance(value, np.str_))
+
+
def _convert_numpy_inputs(inputs):
  """Convert numpy array inputs to tensors."""
  # We assume that any CompositeTensors have already converted their components
@@ -2252,8 +2273,12 @@
  # possible since ndarrays are not hashable).
  need_packing = False
  for index, value in enumerate(flat_inputs):
-    if type(value) == np.ndarray:
-      flat_inputs[index] = constant_op.constant(value)
+    if _is_ndarray(value):
+      a = _as_ndarray(value)
+      if not isinstance(a, np.ndarray):
+        raise TypeError("The output of __array__ must be an np.ndarray "
+                        "(got {} from {}).".format(type(a), type(value)))
+      flat_inputs[index] = constant_op.constant(a)
      need_packing = True
  if need_packing:
    return nest.pack_sequence_as(
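
Two edge cases in the helpers above, shown as an illustrative sketch
(BadArray is a hypothetical class, not part of this change):

import numpy as np

# np.str_ supports __array__ (via np.generic), but _is_ndarray deliberately
# excludes it, so string scalars keep their historical non-Tensor handling.
assert hasattr(np.str_("hello"), "__array__")


class BadArray(object):
  """Hypothetical class whose __array__ breaks the contract."""

  def __array__(self):
    return [1.0, 2.0, 3.0]  # A list, not an np.ndarray.


# Passing BadArray() into a tf.function would raise the TypeError added in
# _convert_numpy_inputs above, since __array__ must return an np.ndarray.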
diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py
index 96166f2..be6524c 100644
--- a/tensorflow/python/eager/function_test.py
+++ b/tensorflow/python/eager/function_test.py
@@ -762,11 +762,34 @@
    # shouldn't trigger another function definition.
    self.assertLen(total_function_cache(defined), 1)
+    np_ones = numpy.ones([], numpy.float32)
+    np_zeros = numpy.zeros([], numpy.float32)
+    tf_ones = array_ops.ones([])
+    tf_zeros = array_ops.zeros([])
+
    # Test that the numpy array is properly an argument to the graph function.
-    self.assertEqual(1., defined(numpy.ones([])).numpy())
-    self.assertEqual(0., defined(numpy.zeros([])).numpy())
-    self.assertEqual(1., defined(array_ops.ones([])).numpy())
-    self.assertEqual(0., defined(array_ops.zeros([])).numpy())
+    self.assertEqual(1., defined(np_ones).numpy())
+    self.assertLen(total_function_cache(defined), 2)
+    self.assertEqual(0., defined(np_zeros).numpy())
+    self.assertEqual(1., defined(tf_ones).numpy())
+    self.assertEqual(0., defined(tf_zeros).numpy())
+    self.assertLen(total_function_cache(defined), 2)
+
+    # Test that mutable inputs are supported.
+    mutable = numpy.ones([], numpy.float32)
+    self.assertEqual(1., defined(mutable).numpy())
+    mutable.fill(0)
+    self.assertEqual(0., defined(mutable).numpy())
+
+    class MyNdarray(numpy.ndarray):
+      pass
+
+    # Test that subclasses of ndarray are converted too.
+    self.assertEqual(1., defined(np_ones.view(MyNdarray)).numpy())
+    self.assertEqual(0., defined(np_zeros.view(MyNdarray)).numpy())
+
+    # We should not have triggered any re-tracing of the python function.
+    self.assertLen(total_function_cache(defined), 2)

  def testDefunNumpyArraysConvertedToTensorsInKwargs(self):