Don't treat type
objects (with __array__) as ndarrays.
PiperOrigin-RevId: 307454154 Change-Id: I6669c41e4dd8256ffd7c4203a1e84ddc2b2f876b
This commit is contained in:
parent
931269353d
commit
6728f85d82
@ -53,6 +53,7 @@ from tensorflow.python.framework import func_graph as func_graph_module
|
|||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import tensor_shape
|
from tensorflow.python.framework import tensor_shape
|
||||||
from tensorflow.python.framework import tensor_spec
|
from tensorflow.python.framework import tensor_spec
|
||||||
|
from tensorflow.python.framework import tensor_util
|
||||||
from tensorflow.python.framework import type_spec
|
from tensorflow.python.framework import type_spec
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import control_flow_ops
|
from tensorflow.python.ops import control_flow_ops
|
||||||
@ -107,23 +108,15 @@ def _make_input_signature_hashable(elem, variable_map=None):
|
|||||||
return tuple(map(lambda e: _make_input_signature_hashable(e, variable_map),
|
return tuple(map(lambda e: _make_input_signature_hashable(e, variable_map),
|
||||||
elem))
|
elem))
|
||||||
|
|
||||||
# If the element is not hashable, assume it is a weakref to a variable
|
|
||||||
# and return the dtype & shape. Else, simply return the element
|
|
||||||
try:
|
try:
|
||||||
hash(elem)
|
hash(elem)
|
||||||
except TypeError:
|
except TypeError:
|
||||||
|
# TFE_Py_EncodeArg weakrefs arguments it does not recognize, and we expect
|
||||||
|
# all recognized types to be hashable.
|
||||||
assert isinstance(elem, weakref.ReferenceType)
|
assert isinstance(elem, weakref.ReferenceType)
|
||||||
v = elem()
|
v = elem()
|
||||||
|
|
||||||
# Check if v is a Variable. Note that we can't use isinstance to check if
|
if resource_variable_ops.is_resource_variable(v):
|
||||||
# it's a variable, since not all variable types are subclass of Variable.
|
|
||||||
# TODO(mdan) Update this to use a generic "Variable" superclass once we
|
|
||||||
# create one.
|
|
||||||
if not (hasattr(v, "shape") and hasattr(v, "dtype")):
|
|
||||||
raise ValueError("Arguments to a tf.function must be Tensors, Variables, "
|
|
||||||
"or hashable Python objects (or nested structures of "
|
|
||||||
"these types).\nGot type: %s" % type(v).__name__)
|
|
||||||
|
|
||||||
idx = variable_map.get(id(v))
|
idx = variable_map.get(id(v))
|
||||||
if idx is None:
|
if idx is None:
|
||||||
idx = len(variable_map)
|
idx = len(variable_map)
|
||||||
@ -134,6 +127,18 @@ def _make_input_signature_hashable(elem, variable_map=None):
|
|||||||
# us to return a different hash if variables have been aliased in a call.
|
# us to return a different hash if variables have been aliased in a call.
|
||||||
return v.__class__, tensor_spec.TensorSpec(v.shape, v.dtype), idx
|
return v.__class__, tensor_spec.TensorSpec(v.shape, v.dtype), idx
|
||||||
|
|
||||||
|
if _is_ndarray(v):
|
||||||
|
# Numpy arrays are not hashable, but when calling functions we treat them
|
||||||
|
# in the same way as tf.Tensors.
|
||||||
|
if not hasattr(v, "shape") or not hasattr(v, "dtype"):
|
||||||
|
# TODO(tomhennigan) De-dup with _as_ndarray in _convert_numpy_inputs.
|
||||||
|
v = _as_ndarray(v)
|
||||||
|
return tensor_spec.TensorSpec(v.shape, v.dtype)
|
||||||
|
|
||||||
|
raise ValueError("Arguments to a tf.function must be Tensors, Variables, "
|
||||||
|
"or hashable Python objects (or nested structures of "
|
||||||
|
"these types).\nGot type: %s" % type(v).__name__)
|
||||||
|
|
||||||
return elem
|
return elem
|
||||||
|
|
||||||
|
|
||||||
@ -2668,6 +2673,24 @@ class FunctionSpec(object):
|
|||||||
return inputs, {}
|
return inputs, {}
|
||||||
|
|
||||||
|
|
||||||
|
def _as_ndarray(value):
|
||||||
|
"""Converts value to an ndarray, assumes _is_ndarray(value)."""
|
||||||
|
# TODO(tomhennigan) Support __array_interface__ too.
|
||||||
|
return value.__array__()
|
||||||
|
|
||||||
|
|
||||||
|
def _is_ndarray(value):
|
||||||
|
"""Tests whether the given value is an ndarray (and not a TF tensor/var)."""
|
||||||
|
# TODO(tomhennigan) Support __array_interface__ too.
|
||||||
|
return hasattr(value, "__array__") and not (
|
||||||
|
resource_variable_ops.is_resource_variable(value)
|
||||||
|
or tensor_util.is_tensor(value)
|
||||||
|
# For legacy reasons we do not automatically promote Numpy strings.
|
||||||
|
or isinstance(value, np.str_)
|
||||||
|
# NumPy dtypes have __array__ as unbound methods.
|
||||||
|
or isinstance(value, type))
|
||||||
|
|
||||||
|
|
||||||
def _convert_numpy_inputs(inputs):
|
def _convert_numpy_inputs(inputs):
|
||||||
"""Convert numpy array inputs to tensors."""
|
"""Convert numpy array inputs to tensors."""
|
||||||
# We assume that any CompositeTensors have already converted their components
|
# We assume that any CompositeTensors have already converted their components
|
||||||
@ -2680,8 +2703,12 @@ def _convert_numpy_inputs(inputs):
|
|||||||
# possible since ndarrays are not hashable).
|
# possible since ndarrays are not hashable).
|
||||||
need_packing = False
|
need_packing = False
|
||||||
for index, value in enumerate(flat_inputs):
|
for index, value in enumerate(flat_inputs):
|
||||||
if type(value) == np.ndarray:
|
if _is_ndarray(value):
|
||||||
flat_inputs[index] = constant_op.constant(value)
|
a = _as_ndarray(value)
|
||||||
|
if not isinstance(a, np.ndarray):
|
||||||
|
raise TypeError("The output of __array__ must be an np.ndarray "
|
||||||
|
"(got {} from {}).".format(type(a), type(value)))
|
||||||
|
flat_inputs[index] = constant_op.constant(a)
|
||||||
need_packing = True
|
need_packing = True
|
||||||
if need_packing:
|
if need_packing:
|
||||||
return nest.pack_sequence_as(
|
return nest.pack_sequence_as(
|
||||||
|
@ -775,11 +775,44 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
|
|||||||
# shouldn't trigger another function definition.
|
# shouldn't trigger another function definition.
|
||||||
self.assertLen(total_function_cache(defined), 1)
|
self.assertLen(total_function_cache(defined), 1)
|
||||||
|
|
||||||
|
np_ones = numpy.ones([], numpy.float32)
|
||||||
|
np_zeros = numpy.zeros([], numpy.float32)
|
||||||
|
tf_ones = array_ops.ones([])
|
||||||
|
tf_zeros = array_ops.zeros([])
|
||||||
|
|
||||||
# Test that the numpy array is properly an argument to the graph function.
|
# Test that the numpy array is properly an argument to the graph function.
|
||||||
self.assertEqual(1., defined(numpy.ones([])).numpy())
|
self.assertEqual(1., defined(np_ones).numpy())
|
||||||
self.assertEqual(0., defined(numpy.zeros([])).numpy())
|
self.assertLen(total_function_cache(defined), 2)
|
||||||
self.assertEqual(1., defined(array_ops.ones([])).numpy())
|
self.assertEqual(0., defined(np_zeros).numpy())
|
||||||
self.assertEqual(0., defined(array_ops.zeros([])).numpy())
|
self.assertEqual(1., defined(tf_ones).numpy())
|
||||||
|
self.assertEqual(0., defined(tf_zeros).numpy())
|
||||||
|
self.assertLen(total_function_cache(defined), 2)
|
||||||
|
|
||||||
|
# Test that mutable inputs are supported.
|
||||||
|
mutable = numpy.ones([], numpy.float32)
|
||||||
|
self.assertEqual(1., defined(mutable).numpy())
|
||||||
|
mutable.fill(0)
|
||||||
|
self.assertEqual(0., defined(mutable).numpy())
|
||||||
|
|
||||||
|
class MyNdarray(numpy.ndarray):
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Test that the subclasses of ndarray are converted too.
|
||||||
|
self.assertEqual(1., defined(np_ones.view(MyNdarray)).numpy())
|
||||||
|
self.assertEqual(0., defined(np_zeros.view(MyNdarray)).numpy())
|
||||||
|
|
||||||
|
# We should not have triggered any re-tracing of the python function.
|
||||||
|
self.assertLen(total_function_cache(defined), 2)
|
||||||
|
|
||||||
|
def testNumpyDtypeInputSupported(self):
|
||||||
|
@function.defun
|
||||||
|
def f(x, dtype):
|
||||||
|
return constant_op.constant(dtype(x))
|
||||||
|
|
||||||
|
self.assertEqual(f(1, numpy.float32).numpy(), numpy.float32(1))
|
||||||
|
self.assertEqual(f(2, numpy.float32).numpy(), numpy.float32(2))
|
||||||
|
self.assertEqual(f(1, numpy.int32).numpy(), numpy.int32(1))
|
||||||
|
self.assertEqual(f(2, numpy.int32).numpy(), numpy.int32(2))
|
||||||
|
|
||||||
def testDefunNumpyArraysConvertedToTensorsInKwargs(self):
|
def testDefunNumpyArraysConvertedToTensorsInKwargs(self):
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user