diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index 113446680bc..57a915a17c9 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -53,6 +53,7 @@ from tensorflow.python.framework import func_graph as func_graph_module from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_spec +from tensorflow.python.framework import tensor_util from tensorflow.python.framework import type_spec from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops @@ -107,32 +108,36 @@ def _make_input_signature_hashable(elem, variable_map=None): return tuple(map(lambda e: _make_input_signature_hashable(e, variable_map), elem)) - # If the element is not hashable, assume it is a weakref to a variable - # and return the dtype & shape. Else, simply return the element try: hash(elem) except TypeError: + # TFE_Py_EncodeArg weakrefs arguments it does not recognize, and we expect + # all recognized types to be hashable. assert isinstance(elem, weakref.ReferenceType) v = elem() - # Check if v is a Variable. Note that we can't use isinstance to check if - # it's a variable, since not all variable types are subclass of Variable. - # TODO(mdan) Update this to use a generic "Variable" superclass once we - # create one. - if not (hasattr(v, "shape") and hasattr(v, "dtype")): - raise ValueError("Arguments to a tf.function must be Tensors, Variables, " - "or hashable Python objects (or nested structures of " - "these types).\nGot type: %s" % type(v).__name__) + if resource_variable_ops.is_resource_variable(v): + idx = variable_map.get(id(v)) + if idx is None: + idx = len(variable_map) + variable_map[id(v)] = idx - idx = variable_map.get(id(v)) - if idx is None: - idx = len(variable_map) - variable_map[id(v)] = idx + # We include the class name to avoid having different types of variables + # having the same hash. We Also include the variable index which allows + # us to return a different hash if variables have been aliased in a call. + return v.__class__, tensor_spec.TensorSpec(v.shape, v.dtype), idx - # We include the class name to avoid having different types of variables - # having the same hash. We Also include the variable index which allows - # us to return a different hash if variables have been aliased in a call. - return v.__class__, tensor_spec.TensorSpec(v.shape, v.dtype), idx + if _is_ndarray(v): + # Numpy arrays are not hashable, but when calling functions we treat them + # in the same way as tf.Tensors. + if not hasattr(v, "shape") or not hasattr(v, "dtype"): + # TODO(tomhennigan) De-dup with _as_ndarray in _convert_numpy_inputs. + v = _as_ndarray(v) + return tensor_spec.TensorSpec(v.shape, v.dtype) + + raise ValueError("Arguments to a tf.function must be Tensors, Variables, " + "or hashable Python objects (or nested structures of " + "these types).\nGot type: %s" % type(v).__name__) return elem @@ -2668,6 +2673,24 @@ class FunctionSpec(object): return inputs, {} +def _as_ndarray(value): + """Converts value to an ndarray, assumes _is_ndarray(value).""" + # TODO(tomhennigan) Support __array_interface__ too. + return value.__array__() + + +def _is_ndarray(value): + """Tests whether the given value is an ndarray (and not a TF tensor/var).""" + # TODO(tomhennigan) Support __array_interface__ too. + return hasattr(value, "__array__") and not ( + resource_variable_ops.is_resource_variable(value) + or tensor_util.is_tensor(value) + # For legacy reasons we do not automatically promote Numpy strings. + or isinstance(value, np.str_) + # NumPy dtypes have __array__ as unbound methods. + or isinstance(value, type)) + + def _convert_numpy_inputs(inputs): """Convert numpy array inputs to tensors.""" # We assume that any CompositeTensors have already converted their components @@ -2680,8 +2703,12 @@ def _convert_numpy_inputs(inputs): # possible since ndarrays are not hashable). need_packing = False for index, value in enumerate(flat_inputs): - if type(value) == np.ndarray: - flat_inputs[index] = constant_op.constant(value) + if _is_ndarray(value): + a = _as_ndarray(value) + if not isinstance(a, np.ndarray): + raise TypeError("The output of __array__ must be an np.ndarray " + "(got {} from {}).".format(type(a), type(value))) + flat_inputs[index] = constant_op.constant(a) need_packing = True if need_packing: return nest.pack_sequence_as( diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py index c13a1fd4794..fd668716236 100644 --- a/tensorflow/python/eager/function_test.py +++ b/tensorflow/python/eager/function_test.py @@ -775,11 +775,44 @@ class FunctionTest(test.TestCase, parameterized.TestCase): # shouldn't trigger another function definition. self.assertLen(total_function_cache(defined), 1) + np_ones = numpy.ones([], numpy.float32) + np_zeros = numpy.zeros([], numpy.float32) + tf_ones = array_ops.ones([]) + tf_zeros = array_ops.zeros([]) + # Test that the numpy array is properly an argument to the graph function. - self.assertEqual(1., defined(numpy.ones([])).numpy()) - self.assertEqual(0., defined(numpy.zeros([])).numpy()) - self.assertEqual(1., defined(array_ops.ones([])).numpy()) - self.assertEqual(0., defined(array_ops.zeros([])).numpy()) + self.assertEqual(1., defined(np_ones).numpy()) + self.assertLen(total_function_cache(defined), 2) + self.assertEqual(0., defined(np_zeros).numpy()) + self.assertEqual(1., defined(tf_ones).numpy()) + self.assertEqual(0., defined(tf_zeros).numpy()) + self.assertLen(total_function_cache(defined), 2) + + # Test that mutable inputs are supported. + mutable = numpy.ones([], numpy.float32) + self.assertEqual(1., defined(mutable).numpy()) + mutable.fill(0) + self.assertEqual(0., defined(mutable).numpy()) + + class MyNdarray(numpy.ndarray): + pass + + # Test that the subclasses of ndarray are converted too. + self.assertEqual(1., defined(np_ones.view(MyNdarray)).numpy()) + self.assertEqual(0., defined(np_zeros.view(MyNdarray)).numpy()) + + # We should not have triggered any re-tracing of the python function. + self.assertLen(total_function_cache(defined), 2) + + def testNumpyDtypeInputSupported(self): + @function.defun + def f(x, dtype): + return constant_op.constant(dtype(x)) + + self.assertEqual(f(1, numpy.float32).numpy(), numpy.float32(1)) + self.assertEqual(f(2, numpy.float32).numpy(), numpy.float32(2)) + self.assertEqual(f(1, numpy.int32).numpy(), numpy.int32(1)) + self.assertEqual(f(2, numpy.int32).numpy(), numpy.int32(2)) def testDefunNumpyArraysConvertedToTensorsInKwargs(self):