Don't treat type
objects (with __array__) as ndarrays.
PiperOrigin-RevId: 307454154 Change-Id: I6669c41e4dd8256ffd7c4203a1e84ddc2b2f876b
This commit is contained in:
parent
931269353d
commit
6728f85d82
@ -53,6 +53,7 @@ from tensorflow.python.framework import func_graph as func_graph_module
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.framework import tensor_util
|
||||
from tensorflow.python.framework import type_spec
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
@ -107,32 +108,36 @@ def _make_input_signature_hashable(elem, variable_map=None):
|
||||
return tuple(map(lambda e: _make_input_signature_hashable(e, variable_map),
|
||||
elem))
|
||||
|
||||
# If the element is not hashable, assume it is a weakref to a variable
|
||||
# and return the dtype & shape. Else, simply return the element
|
||||
try:
|
||||
hash(elem)
|
||||
except TypeError:
|
||||
# TFE_Py_EncodeArg weakrefs arguments it does not recognize, and we expect
|
||||
# all recognized types to be hashable.
|
||||
assert isinstance(elem, weakref.ReferenceType)
|
||||
v = elem()
|
||||
|
||||
# Check if v is a Variable. Note that we can't use isinstance to check if
|
||||
# it's a variable, since not all variable types are subclass of Variable.
|
||||
# TODO(mdan) Update this to use a generic "Variable" superclass once we
|
||||
# create one.
|
||||
if not (hasattr(v, "shape") and hasattr(v, "dtype")):
|
||||
raise ValueError("Arguments to a tf.function must be Tensors, Variables, "
|
||||
"or hashable Python objects (or nested structures of "
|
||||
"these types).\nGot type: %s" % type(v).__name__)
|
||||
if resource_variable_ops.is_resource_variable(v):
|
||||
idx = variable_map.get(id(v))
|
||||
if idx is None:
|
||||
idx = len(variable_map)
|
||||
variable_map[id(v)] = idx
|
||||
|
||||
idx = variable_map.get(id(v))
|
||||
if idx is None:
|
||||
idx = len(variable_map)
|
||||
variable_map[id(v)] = idx
|
||||
# We include the class name to avoid having different types of variables
|
||||
# having the same hash. We Also include the variable index which allows
|
||||
# us to return a different hash if variables have been aliased in a call.
|
||||
return v.__class__, tensor_spec.TensorSpec(v.shape, v.dtype), idx
|
||||
|
||||
# We include the class name to avoid having different types of variables
|
||||
# having the same hash. We Also include the variable index which allows
|
||||
# us to return a different hash if variables have been aliased in a call.
|
||||
return v.__class__, tensor_spec.TensorSpec(v.shape, v.dtype), idx
|
||||
if _is_ndarray(v):
|
||||
# Numpy arrays are not hashable, but when calling functions we treat them
|
||||
# in the same way as tf.Tensors.
|
||||
if not hasattr(v, "shape") or not hasattr(v, "dtype"):
|
||||
# TODO(tomhennigan) De-dup with _as_ndarray in _convert_numpy_inputs.
|
||||
v = _as_ndarray(v)
|
||||
return tensor_spec.TensorSpec(v.shape, v.dtype)
|
||||
|
||||
raise ValueError("Arguments to a tf.function must be Tensors, Variables, "
|
||||
"or hashable Python objects (or nested structures of "
|
||||
"these types).\nGot type: %s" % type(v).__name__)
|
||||
|
||||
return elem
|
||||
|
||||
@ -2668,6 +2673,24 @@ class FunctionSpec(object):
|
||||
return inputs, {}
|
||||
|
||||
|
||||
def _as_ndarray(value):
|
||||
"""Converts value to an ndarray, assumes _is_ndarray(value)."""
|
||||
# TODO(tomhennigan) Support __array_interface__ too.
|
||||
return value.__array__()
|
||||
|
||||
|
||||
def _is_ndarray(value):
|
||||
"""Tests whether the given value is an ndarray (and not a TF tensor/var)."""
|
||||
# TODO(tomhennigan) Support __array_interface__ too.
|
||||
return hasattr(value, "__array__") and not (
|
||||
resource_variable_ops.is_resource_variable(value)
|
||||
or tensor_util.is_tensor(value)
|
||||
# For legacy reasons we do not automatically promote Numpy strings.
|
||||
or isinstance(value, np.str_)
|
||||
# NumPy dtypes have __array__ as unbound methods.
|
||||
or isinstance(value, type))
|
||||
|
||||
|
||||
def _convert_numpy_inputs(inputs):
|
||||
"""Convert numpy array inputs to tensors."""
|
||||
# We assume that any CompositeTensors have already converted their components
|
||||
@ -2680,8 +2703,12 @@ def _convert_numpy_inputs(inputs):
|
||||
# possible since ndarrays are not hashable).
|
||||
need_packing = False
|
||||
for index, value in enumerate(flat_inputs):
|
||||
if type(value) == np.ndarray:
|
||||
flat_inputs[index] = constant_op.constant(value)
|
||||
if _is_ndarray(value):
|
||||
a = _as_ndarray(value)
|
||||
if not isinstance(a, np.ndarray):
|
||||
raise TypeError("The output of __array__ must be an np.ndarray "
|
||||
"(got {} from {}).".format(type(a), type(value)))
|
||||
flat_inputs[index] = constant_op.constant(a)
|
||||
need_packing = True
|
||||
if need_packing:
|
||||
return nest.pack_sequence_as(
|
||||
|
@ -775,11 +775,44 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
|
||||
# shouldn't trigger another function definition.
|
||||
self.assertLen(total_function_cache(defined), 1)
|
||||
|
||||
np_ones = numpy.ones([], numpy.float32)
|
||||
np_zeros = numpy.zeros([], numpy.float32)
|
||||
tf_ones = array_ops.ones([])
|
||||
tf_zeros = array_ops.zeros([])
|
||||
|
||||
# Test that the numpy array is properly an argument to the graph function.
|
||||
self.assertEqual(1., defined(numpy.ones([])).numpy())
|
||||
self.assertEqual(0., defined(numpy.zeros([])).numpy())
|
||||
self.assertEqual(1., defined(array_ops.ones([])).numpy())
|
||||
self.assertEqual(0., defined(array_ops.zeros([])).numpy())
|
||||
self.assertEqual(1., defined(np_ones).numpy())
|
||||
self.assertLen(total_function_cache(defined), 2)
|
||||
self.assertEqual(0., defined(np_zeros).numpy())
|
||||
self.assertEqual(1., defined(tf_ones).numpy())
|
||||
self.assertEqual(0., defined(tf_zeros).numpy())
|
||||
self.assertLen(total_function_cache(defined), 2)
|
||||
|
||||
# Test that mutable inputs are supported.
|
||||
mutable = numpy.ones([], numpy.float32)
|
||||
self.assertEqual(1., defined(mutable).numpy())
|
||||
mutable.fill(0)
|
||||
self.assertEqual(0., defined(mutable).numpy())
|
||||
|
||||
class MyNdarray(numpy.ndarray):
|
||||
pass
|
||||
|
||||
# Test that the subclasses of ndarray are converted too.
|
||||
self.assertEqual(1., defined(np_ones.view(MyNdarray)).numpy())
|
||||
self.assertEqual(0., defined(np_zeros.view(MyNdarray)).numpy())
|
||||
|
||||
# We should not have triggered any re-tracing of the python function.
|
||||
self.assertLen(total_function_cache(defined), 2)
|
||||
|
||||
def testNumpyDtypeInputSupported(self):
|
||||
@function.defun
|
||||
def f(x, dtype):
|
||||
return constant_op.constant(dtype(x))
|
||||
|
||||
self.assertEqual(f(1, numpy.float32).numpy(), numpy.float32(1))
|
||||
self.assertEqual(f(2, numpy.float32).numpy(), numpy.float32(2))
|
||||
self.assertEqual(f(1, numpy.int32).numpy(), numpy.int32(1))
|
||||
self.assertEqual(f(2, numpy.int32).numpy(), numpy.int32(2))
|
||||
|
||||
def testDefunNumpyArraysConvertedToTensorsInKwargs(self):
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user