Don't treat type objects (with __array__) as ndarrays.

PiperOrigin-RevId: 307454154
Change-Id: I6669c41e4dd8256ffd7c4203a1e84ddc2b2f876b
Author: Tom Hennigan, 2020-04-20 12:33:49 -07:00 (committed by TensorFlower Gardener)
Parent: 931269353d
Commit: 6728f85d82
2 changed files with 84 additions and 24 deletions
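
Background for the change, illustrated with plain NumPy (this sketch is not part of the diff): NumPy dtype classes such as numpy.float32 carry __array__ as an unbound method on the class object itself, so a bare hasattr(value, "__array__") check mistakes the type for an array-like value. Only instances should be treated as ndarrays.

import numpy as np

# The class object exposes __array__ (as an unbound method)...
print(hasattr(np.float32, "__array__"))       # True
# ...but it is a type, not an array-like value, and must be excluded.
print(isinstance(np.float32, type))           # True
# An *instance* is what we actually want to treat like an ndarray.
print(hasattr(np.float32(1.0), "__array__"))  # True
print(isinstance(np.float32(1.0), type))      # False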


@@ -53,6 +53,7 @@ from tensorflow.python.framework import func_graph as func_graph_module
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
@@ -107,32 +108,36 @@ def _make_input_signature_hashable(elem, variable_map=None):
    return tuple(map(lambda e: _make_input_signature_hashable(e, variable_map),
                     elem))

  # If the element is not hashable, assume it is a weakref to a variable
  # and return the dtype & shape. Else, simply return the element
  try:
    hash(elem)
  except TypeError:
    # TFE_Py_EncodeArg weakrefs arguments it does not recognize, and we expect
    # all recognized types to be hashable.
    assert isinstance(elem, weakref.ReferenceType)
    v = elem()

    # Check if v is a Variable. Note that we can't use isinstance to check if
    # it's a variable, since not all variable types are subclass of Variable.
    # TODO(mdan) Update this to use a generic "Variable" superclass once we
    # create one.
    if not (hasattr(v, "shape") and hasattr(v, "dtype")):
      raise ValueError("Arguments to a tf.function must be Tensors, Variables, "
                       "or hashable Python objects (or nested structures of "
                       "these types).\nGot type: %s" % type(v).__name__)

    if resource_variable_ops.is_resource_variable(v):
      idx = variable_map.get(id(v))
      if idx is None:
        idx = len(variable_map)
        variable_map[id(v)] = idx

      # We include the class name to avoid having different types of variables
      # having the same hash. We also include the variable index which allows
      # us to return a different hash if variables have been aliased in a call.
      return v.__class__, tensor_spec.TensorSpec(v.shape, v.dtype), idx

    if _is_ndarray(v):
      # Numpy arrays are not hashable, but when calling functions we treat them
      # in the same way as tf.Tensors.
      if not hasattr(v, "shape") or not hasattr(v, "dtype"):
        # TODO(tomhennigan) De-dup with _as_ndarray in _convert_numpy_inputs.
        v = _as_ndarray(v)
      return tensor_spec.TensorSpec(v.shape, v.dtype)

    raise ValueError("Arguments to a tf.function must be Tensors, Variables, "
                     "or hashable Python objects (or nested structures of "
                     "these types).\nGot type: %s" % type(v).__name__)

  return elem
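
In effect, non-hashable weakref'd arguments now hash to (class, TensorSpec, index) for resource variables or to a plain TensorSpec for ndarray-like values, and anything else hits the ValueError at the end. Hashable objects, including type objects such as numpy.float32, never reach this branch because hash(elem) succeeds. A small standalone illustration (plain NumPy, not part of the diff):

import numpy as np

# Type objects are hashable, so they skip the except-TypeError path above and
# participate in the cache key directly.
hash(np.float32)

# An ndarray is not hashable, which is why it is summarized by shape and dtype.
try:
  hash(np.ones([2, 2]))
except TypeError:
  pass  # expected: "unhashable type: 'numpy.ndarray'"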
@@ -2668,6 +2673,24 @@ class FunctionSpec(object):
    return inputs, {}


def _as_ndarray(value):
  """Converts value to an ndarray, assumes _is_ndarray(value)."""
  # TODO(tomhennigan) Support __array_interface__ too.
  return value.__array__()


def _is_ndarray(value):
  """Tests whether the given value is an ndarray (and not a TF tensor/var)."""
  # TODO(tomhennigan) Support __array_interface__ too.
  return hasattr(value, "__array__") and not (
      resource_variable_ops.is_resource_variable(value)
      or tensor_util.is_tensor(value)
      # For legacy reasons we do not automatically promote Numpy strings.
      or isinstance(value, np.str_)
      # NumPy dtypes have __array__ as unbound methods.
      or isinstance(value, type))


def _convert_numpy_inputs(inputs):
  """Convert numpy array inputs to tensors."""
  # We assume that any CompositeTensors have already converted their components
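
For reference, a hedged sketch of how the new predicate classifies a few representative inputs; the Wrapper class below is hypothetical and used only for illustration (the tests further down use an ndarray subclass instead):

import numpy as np

class Wrapper(object):
  """Hypothetical duck-typed array-like that only exposes __array__."""

  def __array__(self):
    return np.ones([2, 2], np.float32)

# Expected classification under _is_ndarray as defined above:
#   np.ones([2, 2])  -> True   (a real ndarray instance)
#   Wrapper()        -> True   (an instance with __array__)
#   np.float32       -> False  (a type object; __array__ is an unbound method)
#   np.str_("x")     -> False  (legacy: NumPy strings are not auto-promoted)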
@@ -2680,8 +2703,12 @@ def _convert_numpy_inputs(inputs):
  # possible since ndarrays are not hashable).
  need_packing = False
  for index, value in enumerate(flat_inputs):
    if type(value) == np.ndarray:
      flat_inputs[index] = constant_op.constant(value)
    if _is_ndarray(value):
      a = _as_ndarray(value)
      if not isinstance(a, np.ndarray):
        raise TypeError("The output of __array__ must be an np.ndarray "
                        "(got {} from {}).".format(type(a), type(value)))
      flat_inputs[index] = constant_op.constant(a)
      need_packing = True
  if need_packing:
    return nest.pack_sequence_as(
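
The conversion path now accepts any array-like that implements __array__ but validates what it returns. A hedged end-to-end sketch of the guarded failure mode (BadArrayLike and the tf.function wrapper are assumptions for illustration, not code from this change):

import numpy as np
import tensorflow as tf

class BadArrayLike(object):
  """Hypothetical object whose __array__ breaks the contract."""

  def __array__(self):
    return [1.0, 2.0]  # a list, not an np.ndarray

@tf.function
def identity(x):
  return x

# identity(np.ones([2]))    # array-likes are converted via constant_op.constant
# identity(BadArrayLike())  # expected to raise the TypeError added above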


@@ -775,11 +775,44 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
    # shouldn't trigger another function definition.
    self.assertLen(total_function_cache(defined), 1)

    np_ones = numpy.ones([], numpy.float32)
    np_zeros = numpy.zeros([], numpy.float32)
    tf_ones = array_ops.ones([])
    tf_zeros = array_ops.zeros([])

    # Test that the numpy array is properly an argument to the graph function.
    self.assertEqual(1., defined(numpy.ones([])).numpy())
    self.assertEqual(0., defined(numpy.zeros([])).numpy())
    self.assertEqual(1., defined(array_ops.ones([])).numpy())
    self.assertEqual(0., defined(array_ops.zeros([])).numpy())
    self.assertEqual(1., defined(np_ones).numpy())
    self.assertLen(total_function_cache(defined), 2)
    self.assertEqual(0., defined(np_zeros).numpy())
    self.assertEqual(1., defined(tf_ones).numpy())
    self.assertEqual(0., defined(tf_zeros).numpy())
    self.assertLen(total_function_cache(defined), 2)

    # Test that mutable inputs are supported.
    mutable = numpy.ones([], numpy.float32)
    self.assertEqual(1., defined(mutable).numpy())
    mutable.fill(0)
    self.assertEqual(0., defined(mutable).numpy())

    class MyNdarray(numpy.ndarray):
      pass

    # Test that the subclasses of ndarray are converted too.
    self.assertEqual(1., defined(np_ones.view(MyNdarray)).numpy())
    self.assertEqual(0., defined(np_zeros.view(MyNdarray)).numpy())

    # We should not have triggered any re-tracing of the python function.
    self.assertLen(total_function_cache(defined), 2)

  def testNumpyDtypeInputSupported(self):
    @function.defun
    def f(x, dtype):
      return constant_op.constant(dtype(x))

    self.assertEqual(f(1, numpy.float32).numpy(), numpy.float32(1))
    self.assertEqual(f(2, numpy.float32).numpy(), numpy.float32(2))
    self.assertEqual(f(1, numpy.int32).numpy(), numpy.int32(1))
    self.assertEqual(f(2, numpy.int32).numpy(), numpy.int32(2))

  def testDefunNumpyArraysConvertedToTensorsInKwargs(self):