Merge pull request #41981 from lithuak:iss35342-take-two
PiperOrigin-RevId: 329749527 Change-Id: I4e29264e4f7eb97f47856e8eb527a66336a78c1f
This commit is contained in:
commit
a2c542a0d8
@ -98,6 +98,9 @@
|
||||
the `experimental_optimization.reorder_data_discarding_ops` dataset
|
||||
option.
|
||||
* `tf.data.Options` were previously immutable and can now be overriden.
|
||||
* `tf.data.Dataset.from_generator` now supports Ragged and Sparse tensors
|
||||
with a new `output_signature` argument, which allows `from_generator` to
|
||||
produce any type describable by a `tf.TypeSpec`.
|
||||
* `tf.image`:
|
||||
* Added deterministic `tf.image.stateless_random_*` functions for each
|
||||
`tf.image.random_*` function. Added a new op
|
||||
@ -183,13 +186,11 @@
|
||||
checkpoint saved in the `variables/` folder in the SavedModel.
|
||||
* When restoring, `save_path` can be a path to a SavedModel. The function
|
||||
will automatically find the checkpoint in the SavedModel.
|
||||
* `tf.nn`:
|
||||
* `tf.nn.max_pool2d` now supports explicit padding.
|
||||
* Other:
|
||||
* We have replaced uses of "whitelist" and "blacklist" with "allowlist"
|
||||
and "denylist" where possible. Please see
|
||||
https://developers.google.com/style/word-list#blacklist for more context.
|
||||
<ADD RELEASE NOTES HERE>
|
||||
* <ADD RELEASE NOTES HERE>
|
||||
|
||||
## Thanks to our Contributors
|
||||
|
||||
|
@ -61,6 +61,7 @@ class UniqueTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
([], []),
|
||||
([1], [1]),
|
||||
([1, 1, 1, 1, 1, 1, 1], [1]),
|
||||
([1, 1, 1, 1, 0], [1, 0]),
|
||||
([1, 2, 3, 4], [1, 2, 3, 4]),
|
||||
([1, 2, 4, 3, 2, 1, 2, 3, 4], [1, 2, 4, 3]),
|
||||
([[1], [1, 1], [1, 1, 1]], [[1], [1, 1], [1, 1, 1]]),
|
||||
|
@ -28,7 +28,12 @@ from tensorflow.python.framework import combinations
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.ops import script_ops
|
||||
from tensorflow.python.ops import sparse_ops
|
||||
from tensorflow.python.ops.ragged import ragged_factory_ops
|
||||
from tensorflow.python.ops.ragged import ragged_tensor
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
@ -259,7 +264,7 @@ class FromGeneratorTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
self.assertAllEqual([1, 2, 3], self.evaluate(get_next()))
|
||||
self.assertAllEqual([4, 5, 6], self.evaluate(get_next()))
|
||||
with self.assertRaisesOpError("The expected type was int64"):
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
self.evaluate(get_next())
|
||||
self.assertAllEqual([7, 8, 9], self.evaluate(get_next()))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
@ -279,7 +284,7 @@ class FromGeneratorTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
self.assertAllEqual([1, 2, 3], self.evaluate(get_next()))
|
||||
self.assertAllEqual([4, 5, 6], self.evaluate(get_next()))
|
||||
with self.assertRaisesOpError(r"element of shape \(3,\) was expected"):
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
self.evaluate(get_next())
|
||||
self.assertAllEqual([11, 12, 13], self.evaluate(get_next()))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
@ -300,11 +305,9 @@ class FromGeneratorTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
self.assertEqual((1, 2), self.evaluate(get_next()))
|
||||
self.assertEqual((3, 4), self.evaluate(get_next()))
|
||||
with self.assertRaisesOpError(
|
||||
r"The expected structure was \(tf\.int64, tf\.int64\)"):
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
self.evaluate(get_next())
|
||||
with self.assertRaisesOpError(
|
||||
r"The expected structure was \(tf\.int64, tf\.int64\)"):
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
self.evaluate(get_next())
|
||||
self.assertEqual((9, 10), self.evaluate(get_next()))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
@ -423,8 +426,12 @@ class FromGeneratorTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
stateful=True)
|
||||
|
||||
dummy = constant_op.constant(37)
|
||||
dataset = dataset_ops._GeneratorDataset(dummy, lambda x: x, lambda x: x,
|
||||
finalize_fn).take(2)
|
||||
|
||||
dataset = dataset_ops._GeneratorDataset(
|
||||
dummy, lambda x: x, lambda x: x, finalize_fn,
|
||||
tensor_spec.TensorSpec((), dtypes.int32))
|
||||
|
||||
dataset = dataset.take(2)
|
||||
get_next = self.getNext(dataset)
|
||||
|
||||
self.assertAllEqual(37, self.evaluate(get_next()))
|
||||
@ -446,6 +453,44 @@ class FromGeneratorTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
self.assertAllEqual([20], self.evaluate(get_next()))
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testFromGeneratorRaggedTensor(self):
|
||||
|
||||
def generator():
|
||||
yield ragged_factory_ops.constant([[1, 2], [3]])
|
||||
|
||||
dataset = dataset_ops.Dataset.from_generator(
|
||||
generator,
|
||||
output_signature=ragged_tensor.RaggedTensorSpec(
|
||||
shape=(2, None), dtype=dtypes.int32))
|
||||
get_next = self.getNext(dataset)
|
||||
|
||||
ret = get_next()
|
||||
|
||||
self.assertIsInstance(ret, ragged_tensor.RaggedTensor)
|
||||
self.assertAllEqual([[1, 2], [3]], ret)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testFromGeneratorSparseTensor(self):
|
||||
|
||||
def generator():
|
||||
yield sparse_tensor.SparseTensor(
|
||||
indices=[[0, 0], [1, 2]],
|
||||
values=constant_op.constant([1, 2], dtype=dtypes.int64),
|
||||
dense_shape=[3, 4])
|
||||
|
||||
dataset = dataset_ops.Dataset.from_generator(
|
||||
generator,
|
||||
output_signature=sparse_tensor.SparseTensorSpec([3, 4], dtypes.int64))
|
||||
|
||||
get_next = self.getNext(dataset)
|
||||
|
||||
ret = get_next()
|
||||
|
||||
self.assertIsInstance(ret, sparse_tensor.SparseTensor)
|
||||
self.assertAllEqual([[1, 0, 0, 0], [0, 0, 2, 0], [0, 0, 0, 0]],
|
||||
sparse_ops.sparse_tensor_to_dense(ret))
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testTypeIsListError(self):
|
||||
|
||||
|
@ -946,7 +946,9 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
@def_function.function
|
||||
def fn():
|
||||
dataset = dataset_ops._GeneratorDataset(1, init_fn, next_fn, finalize_fn)
|
||||
output_signature = tensor_spec.TensorSpec((), dtypes.int64)
|
||||
dataset = dataset_ops._GeneratorDataset(1, init_fn, next_fn, finalize_fn,
|
||||
output_signature)
|
||||
iterator = iter(dataset)
|
||||
next(iterator)
|
||||
|
||||
|
@ -725,27 +725,46 @@ class DatasetV2(collections_abc.Iterable, tracking_base.Trackable,
|
||||
del self._iterators[iterator_id]
|
||||
|
||||
@staticmethod
|
||||
def from_generator(generator, output_types, output_shapes=None, args=None):
|
||||
@deprecation.deprecated_args(None, "Use output_signature instead",
|
||||
"output_types", "output_shapes")
|
||||
def from_generator(generator,
|
||||
output_types=None,
|
||||
output_shapes=None,
|
||||
args=None,
|
||||
output_signature=None):
|
||||
"""Creates a `Dataset` whose elements are generated by `generator`.
|
||||
|
||||
The `generator` argument must be a callable object that returns
|
||||
an object that supports the `iter()` protocol (e.g. a generator function).
|
||||
The elements generated by `generator` must be compatible with the given
|
||||
`output_types` and (optional) `output_shapes` arguments.
|
||||
|
||||
>>> import itertools
|
||||
>>>
|
||||
The elements generated by `generator` must be compatible with either the
|
||||
given `output_signature` argument or with the given `output_types` and
|
||||
(optionally) `output_shapes` arguments, whichiver was specified.
|
||||
|
||||
The recommended way to call `from_generator` is to use the
|
||||
`output_signature` argument. In this case the output will be assumed to
|
||||
consist of objects with the classes, shapes and types defined by
|
||||
`tf.TypeSpec` objects from `output_signature` argument:
|
||||
|
||||
>>> def gen():
|
||||
... for i in itertools.count(1):
|
||||
... yield (i, [1] * i)
|
||||
... ragged_tensor = tf.ragged.constant([[1, 2], [3]])
|
||||
... yield 42, ragged_tensor
|
||||
>>>
|
||||
>>> dataset = tf.data.Dataset.from_generator(
|
||||
... gen,
|
||||
... (tf.int64, tf.int64),
|
||||
... (tf.TensorShape([]), tf.TensorShape([None])))
|
||||
... output_signature=(
|
||||
... tf.TensorSpec(shape=(), dtype=tf.int32),
|
||||
... tf.RaggedTensorSpec(shape=(2, None), dtype=tf.int32)))
|
||||
>>>
|
||||
>>> list(dataset.take(3).as_numpy_iterator())
|
||||
[(1, array([1])), (2, array([1, 1])), (3, array([1, 1, 1]))]
|
||||
>>> list(dataset.take(1))
|
||||
[(<tf.Tensor: shape=(), dtype=int32, numpy=42>,
|
||||
<tf.RaggedTensor [[1, 2], [3]]>)]
|
||||
|
||||
There is also a deprecated way to call `from_generator` by either with
|
||||
`output_types` argument alone or together with `output_shapes` argument.
|
||||
In this case the output of the function will be assumed to consist of
|
||||
`tf.Tensor` objects with with the types defined by `output_types` and with
|
||||
the shapes which are either unknown or defined by `output_shapes`.
|
||||
|
||||
Note: The current implementation of `Dataset.from_generator()` uses
|
||||
`tf.numpy_function` and inherits the same constraints. In particular, it
|
||||
@ -769,31 +788,56 @@ class DatasetV2(collections_abc.Iterable, tracking_base.Trackable,
|
||||
`iter()` protocol. If `args` is not specified, `generator` must take no
|
||||
arguments; otherwise it must take as many arguments as there are values
|
||||
in `args`.
|
||||
output_types: A nested structure of `tf.DType` objects corresponding to
|
||||
each component of an element yielded by `generator`.
|
||||
output_types: (Optional.) A nested structure of `tf.DType` objects
|
||||
corresponding to each component of an element yielded by `generator`.
|
||||
output_shapes: (Optional.) A nested structure of `tf.TensorShape` objects
|
||||
corresponding to each component of an element yielded by `generator`.
|
||||
args: (Optional.) A tuple of `tf.Tensor` objects that will be evaluated
|
||||
and passed to `generator` as NumPy-array arguments.
|
||||
output_signature: (Optional.) A nested structure of `tf.TypeSpec` objects
|
||||
corresponding to each component of an element yielded by `generator`.
|
||||
|
||||
Returns:
|
||||
Dataset: A `Dataset`.
|
||||
"""
|
||||
if not callable(generator):
|
||||
raise TypeError("`generator` must be callable.")
|
||||
if output_shapes is None:
|
||||
output_shapes = nest.map_structure(
|
||||
lambda _: tensor_shape.TensorShape(None), output_types)
|
||||
|
||||
if output_signature is not None:
|
||||
if output_types is not None:
|
||||
raise TypeError("`output_types` can not be used together with "
|
||||
"`output_signature`")
|
||||
if output_shapes is not None:
|
||||
raise TypeError("`output_shapes` can not be used together with "
|
||||
"`output_signature`")
|
||||
if not all(
|
||||
isinstance(_, type_spec.TypeSpec)
|
||||
for _ in nest.flatten(output_signature)):
|
||||
raise TypeError("All the elements of `output_signature` must be "
|
||||
"`tf.TypeSpec` objects.")
|
||||
else:
|
||||
output_shapes = nest.map_structure_up_to(
|
||||
output_types, tensor_shape.as_shape, output_shapes)
|
||||
if output_types is None:
|
||||
raise TypeError("Either `output_signature` or `output_types` must "
|
||||
"be specified")
|
||||
|
||||
if output_signature is None:
|
||||
if output_shapes is None:
|
||||
output_shapes = nest.map_structure(
|
||||
lambda _: tensor_shape.TensorShape(None), output_types)
|
||||
else:
|
||||
output_shapes = nest.map_structure_up_to(output_types,
|
||||
tensor_shape.as_shape,
|
||||
output_shapes)
|
||||
output_signature = nest.map_structure_up_to(output_types,
|
||||
tensor_spec.TensorSpec,
|
||||
output_shapes, output_types)
|
||||
|
||||
if args is None:
|
||||
args = ()
|
||||
else:
|
||||
args = tuple(ops.convert_n_to_tensor(args, name="args"))
|
||||
|
||||
flattened_types = [dtypes.as_dtype(dt) for dt in nest.flatten(output_types)]
|
||||
flattened_shapes = nest.flatten(output_shapes)
|
||||
flat_output_types = structure.get_flat_tensor_types(output_signature)
|
||||
|
||||
generator_state = DatasetV2._GeneratorState(generator)
|
||||
|
||||
@ -831,56 +875,33 @@ class DatasetV2(collections_abc.Iterable, tracking_base.Trackable,
|
||||
"""A `py_func` that will be called to invoke the iterator."""
|
||||
# `next()` raises `StopIteration` when there are no more
|
||||
# elements remaining to be generated.
|
||||
values = next(generator_state.get_iterator(iterator_id))
|
||||
values = next(generator_state.get_iterator(iterator_id.numpy()))
|
||||
|
||||
# Use the same _convert function from the py_func() implementation to
|
||||
# convert the returned values to arrays early, so that we can inspect
|
||||
# their values.
|
||||
try:
|
||||
flattened_values = nest.flatten_up_to(output_types, values)
|
||||
values = structure.normalize_element(values, output_signature)
|
||||
except (TypeError, ValueError):
|
||||
six.reraise(TypeError, TypeError(
|
||||
"`generator` yielded an element that did not match the expected "
|
||||
"structure. The expected structure was %s, but the yielded "
|
||||
"element was %s." % (output_types, values)), sys.exc_info()[2])
|
||||
ret_arrays = []
|
||||
for ret, dtype in zip(flattened_values, flattened_types):
|
||||
try:
|
||||
ret_arrays.append(script_ops.FuncRegistry._convert( # pylint: disable=protected-access
|
||||
ret, dtype=dtype.as_numpy_dtype))
|
||||
except (TypeError, ValueError):
|
||||
six.reraise(TypeError, TypeError(
|
||||
"`generator` yielded an element that could not be converted to "
|
||||
"the expected type. The expected type was %s, but the yielded "
|
||||
"element was %s." % (dtype.name, ret)), sys.exc_info()[2])
|
||||
six.reraise(
|
||||
TypeError,
|
||||
TypeError(
|
||||
"`generator` yielded an element that did not match the "
|
||||
"expected structure. The expected structure was %s, but the "
|
||||
"yielded element was %s." % (output_signature, values)),
|
||||
sys.exc_info()[2])
|
||||
|
||||
# Additional type and shape checking to ensure that the components
|
||||
# of the generated element match the `output_types` and `output_shapes`
|
||||
# arguments.
|
||||
for (ret_array, expected_dtype, expected_shape) in zip(
|
||||
ret_arrays, flattened_types, flattened_shapes):
|
||||
if ret_array.dtype != expected_dtype.as_numpy_dtype:
|
||||
raise TypeError(
|
||||
"`generator` yielded an element of type %s where an element "
|
||||
"of type %s was expected." % (ret_array.dtype,
|
||||
expected_dtype.as_numpy_dtype))
|
||||
if not expected_shape.is_compatible_with(ret_array.shape):
|
||||
raise ValueError(
|
||||
"`generator` yielded an element of shape %s where an element "
|
||||
"of shape %s was expected." % (ret_array.shape, expected_shape))
|
||||
values_spec = structure.type_spec_from_value(values)
|
||||
|
||||
return ret_arrays
|
||||
if not structure.are_compatible(values_spec, output_signature):
|
||||
raise TypeError(
|
||||
"`generator` yielded an element of %s where an element "
|
||||
"of %s was expected." % (values_spec, output_signature))
|
||||
|
||||
flat_values = script_ops.numpy_function(generator_py_func,
|
||||
[iterator_id_t], flattened_types)
|
||||
return structure.to_tensor_list(output_signature, values)
|
||||
|
||||
# The `py_func()` op drops the inferred shapes, so we add them back in
|
||||
# here.
|
||||
if output_shapes is not None:
|
||||
for ret_t, shape in zip(flat_values, flattened_shapes):
|
||||
ret_t.set_shape(shape)
|
||||
|
||||
return nest.pack_sequence_as(output_types, flat_values)
|
||||
return script_ops._eager_py_func( # pylint: disable=protected-access
|
||||
generator_py_func,
|
||||
inp=[iterator_id_t],
|
||||
Tout=flat_output_types,
|
||||
use_tape_cache=False)
|
||||
|
||||
def finalize_fn(iterator_id_t):
|
||||
"""Releases host-side state for the iterator with ID `iterator_id_t`."""
|
||||
@ -906,7 +927,7 @@ class DatasetV2(collections_abc.Iterable, tracking_base.Trackable,
|
||||
# given ID, and raises StopIteration when that iterator contains no
|
||||
# more elements.
|
||||
return _GeneratorDataset(dummy_arg, get_iterator_id_fn, generator_next_fn,
|
||||
finalize_fn)
|
||||
finalize_fn, output_signature)
|
||||
|
||||
# A single-element dataset that, each time it is evaluated, contains a
|
||||
# freshly-generated and unique (for the returned dataset) int64
|
||||
@ -2148,21 +2169,8 @@ name=None))
|
||||
|
||||
`cardinality` may return `tf.data.INFINITE_CARDINALITY` if the dataset
|
||||
contains an infinite number of elements or `tf.data.UNKNOWN_CARDINALITY` if
|
||||
the analysis fails to determine the number of elements in the dataset.
|
||||
|
||||
`cardinality` only reports known cardinality (finite or infinite), if it can
|
||||
be inferred statically. In particular, the implementation does not iterate
|
||||
through the dataset or evaluate user-defined functions. As a consequence,
|
||||
the statically inferred cardinality may often be unknown. For example, if
|
||||
the dataset reads from file(s), the cardinality will be unknown. The
|
||||
cardinality will also be unknown if the dataset contains user-defined
|
||||
functions which could affect the cardinality (such as the functions in
|
||||
`filter`, `flat_map`, `interleave`, or `from_generator`).
|
||||
|
||||
When constructing a dataset, you can apply the
|
||||
`tf.data.experimental.assert_cardinality` transformation to inform the
|
||||
dataset of its expected cardinality, so that `cardinality` can produce a
|
||||
known cardinality.
|
||||
the analysis fails to determine the number of elements in the dataset
|
||||
(e.g. when the dataset source is a file).
|
||||
|
||||
>>> dataset = tf.data.Dataset.range(42)
|
||||
>>> print(dataset.cardinality().numpy())
|
||||
@ -2171,13 +2179,10 @@ name=None))
|
||||
>>> cardinality = dataset.cardinality()
|
||||
>>> print((cardinality == tf.data.INFINITE_CARDINALITY).numpy())
|
||||
True
|
||||
>>> dataset = dataset.filter(lambda x: False)
|
||||
>>> dataset = dataset.filter(lambda x: True)
|
||||
>>> cardinality = dataset.cardinality()
|
||||
>>> print((cardinality == tf.data.UNKNOWN_CARDINALITY).numpy())
|
||||
True
|
||||
>>> dataset = dataset.apply(tf.data.experimental.assert_cardinality(0))
|
||||
>>> print(dataset.cardinality().numpy())
|
||||
0
|
||||
|
||||
Returns:
|
||||
A scalar `tf.int64` `Tensor` representing the cardinality of the dataset.
|
||||
@ -2458,9 +2463,14 @@ class DatasetV1(DatasetV2):
|
||||
|
||||
@staticmethod
|
||||
@functools.wraps(DatasetV2.from_generator)
|
||||
def from_generator(generator, output_types, output_shapes=None, args=None):
|
||||
return DatasetV1Adapter(DatasetV2.from_generator(
|
||||
generator, output_types, output_shapes, args))
|
||||
def from_generator(generator,
|
||||
output_types=None,
|
||||
output_shapes=None,
|
||||
args=None,
|
||||
output_signature=None):
|
||||
return DatasetV1Adapter(
|
||||
DatasetV2.from_generator(generator, output_types, output_shapes, args,
|
||||
output_signature))
|
||||
|
||||
@staticmethod
|
||||
@functools.wraps(DatasetV2.range)
|
||||
@ -3469,7 +3479,8 @@ class StructuredFunctionWrapper(object):
|
||||
class _GeneratorDataset(DatasetSource):
|
||||
"""A `Dataset` that generates elements by invoking a function."""
|
||||
|
||||
def __init__(self, init_args, init_func, next_func, finalize_func):
|
||||
def __init__(self, init_args, init_func, next_func, finalize_func,
|
||||
output_signature):
|
||||
"""Constructs a `_GeneratorDataset`.
|
||||
|
||||
Args:
|
||||
@ -3483,6 +3494,8 @@ class _GeneratorDataset(DatasetSource):
|
||||
finalize_func: A TensorFlow function that will be called on the result of
|
||||
`init_func` immediately before a C++ iterator over this dataset is
|
||||
destroyed. The return value is ignored.
|
||||
output_signature: A nested structure of `tf.TypeSpec` objects describing
|
||||
the output of `next_func`.
|
||||
"""
|
||||
self._init_args = init_args
|
||||
|
||||
@ -3502,6 +3515,9 @@ class _GeneratorDataset(DatasetSource):
|
||||
finalize_func,
|
||||
self._transformation_name(),
|
||||
input_structure=self._init_func.output_structure)
|
||||
|
||||
self._output_signature = output_signature
|
||||
|
||||
variant_tensor = gen_dataset_ops.generator_dataset(
|
||||
structure.to_tensor_list(self._init_structure, self._init_args) +
|
||||
self._init_func.function.captured_inputs,
|
||||
@ -3515,7 +3531,7 @@ class _GeneratorDataset(DatasetSource):
|
||||
|
||||
@property
|
||||
def element_spec(self):
|
||||
return self._next_func.output_structure
|
||||
return self._output_signature
|
||||
|
||||
def _transformation_name(self):
|
||||
return "Dataset.from_generator()"
|
||||
|
@ -67,7 +67,7 @@ def _RaggedTensorStructure(dtype, shape, ragged_rank):
|
||||
|
||||
# TODO(jsimsa): Remove the special-case for `TensorArray` pass-through once
|
||||
# it is a subclass of `CompositeTensor`.
|
||||
def normalize_element(element):
|
||||
def normalize_element(element, element_signature=None):
|
||||
"""Normalizes a nested structure of element components.
|
||||
|
||||
* Components matching `SparseTensorSpec` are converted to `SparseTensor`.
|
||||
@ -78,19 +78,32 @@ def normalize_element(element):
|
||||
|
||||
Args:
|
||||
element: A nested structure of individual components.
|
||||
element_signature: (Optional.) A nested structure of `tf.DType` objects
|
||||
corresponding to each component of `element`. If specified, it will be
|
||||
used to set the exact type of output tensor when converting input
|
||||
components which are not tensors themselves (e.g. numpy arrays, native
|
||||
python types, etc.)
|
||||
|
||||
Returns:
|
||||
A nested structure of `Tensor`, `Dataset`, `SparseTensor`, `RaggedTensor`,
|
||||
or `TensorArray` objects.
|
||||
"""
|
||||
components = nest.flatten(element)
|
||||
normalized_components = []
|
||||
if element_signature is None:
|
||||
components = nest.flatten(element)
|
||||
flattened_signature = [None] * len(components)
|
||||
pack_as = element
|
||||
else:
|
||||
flattened_signature = nest.flatten(element_signature)
|
||||
components = nest.flatten_up_to(element_signature, element)
|
||||
pack_as = element_signature
|
||||
with ops.name_scope("normalize_element"):
|
||||
# Imported here to avoid circular dependency.
|
||||
from tensorflow.python.data.ops import dataset_ops # pylint: disable=g-import-not-at-top
|
||||
for i, t in enumerate(components):
|
||||
for i, (t, spec) in enumerate(zip(components, flattened_signature)):
|
||||
try:
|
||||
spec = type_spec_from_value(t, use_fallback=False)
|
||||
if spec is None:
|
||||
spec = type_spec_from_value(t, use_fallback=False)
|
||||
except TypeError:
|
||||
# TypeError indicates it was not possible to compute a `TypeSpec` for
|
||||
# the value. As a fallback try converting the value to a tensor.
|
||||
@ -111,9 +124,10 @@ def normalize_element(element):
|
||||
elif isinstance(t, composite_tensor.CompositeTensor):
|
||||
normalized_components.append(t)
|
||||
else:
|
||||
dtype = getattr(spec, "dtype", None)
|
||||
normalized_components.append(
|
||||
ops.convert_to_tensor(t, name="component_%d" % i))
|
||||
return nest.pack_sequence_as(element, normalized_components)
|
||||
ops.convert_to_tensor(t, name="component_%d" % i, dtype=dtype))
|
||||
return nest.pack_sequence_as(pack_as, normalized_components)
|
||||
|
||||
|
||||
def convert_legacy_structure(output_types, output_shapes, output_classes):
|
||||
|
@ -71,7 +71,7 @@ def _maybe_copy_to_context_device(tensor, device_name):
|
||||
class EagerFunc(object):
|
||||
"""A wrapper for a function owned by an EagerPyFunc."""
|
||||
|
||||
def __init__(self, func, Tout, is_grad_func):
|
||||
def __init__(self, func, Tout, is_grad_func, use_tape_cache=True):
|
||||
"""Constructs an EagerFunc.
|
||||
|
||||
Args:
|
||||
@ -80,10 +80,14 @@ class EagerFunc(object):
|
||||
None.
|
||||
is_grad_func: Whether this EagerFunc is the gradient of another
|
||||
EagerPyFunc.
|
||||
use_tape_cache: (Optional.) Whether to cache `func` in the `tape_cache`.
|
||||
For additional information, see description of `_eager_py_func`.
|
||||
This parameter should be removed once the #35084 issue is fixed.
|
||||
"""
|
||||
self._func = func
|
||||
self._out_dtypes = Tout
|
||||
self._is_grad_func = is_grad_func
|
||||
self._use_tape_cache = use_tape_cache
|
||||
|
||||
def _convert(self, value, dtype):
|
||||
"""Converts `value` to a tensor of type `dtype`, with error checking.
|
||||
@ -147,7 +151,8 @@ class EagerFunc(object):
|
||||
else:
|
||||
outputs = _maybe_copy_to_context_device(
|
||||
self._convert(ret, dtype=self._out_dtypes[0]), device_name)
|
||||
tape_cache[compat.as_bytes(token)] = (tape, args, outputs)
|
||||
if self._use_tape_cache:
|
||||
tape_cache[compat.as_bytes(token)] = (tape, args, outputs)
|
||||
return outputs
|
||||
|
||||
|
||||
@ -277,7 +282,8 @@ def _internal_py_func(func,
|
||||
stateful=None,
|
||||
eager=False,
|
||||
is_grad_func=False,
|
||||
name=None):
|
||||
name=None,
|
||||
use_tape_cache=True):
|
||||
"""See documentation for py_func and eager_py_func."""
|
||||
if not callable(func):
|
||||
raise ValueError("Expected func to be callable, got func of type {}".format(
|
||||
@ -293,7 +299,7 @@ def _internal_py_func(func,
|
||||
Tout = [Tout]
|
||||
|
||||
if eager:
|
||||
func = EagerFunc(func, Tout, is_grad_func)
|
||||
func = EagerFunc(func, Tout, is_grad_func, use_tape_cache=use_tape_cache)
|
||||
|
||||
# Tying the registered function's lifetime with the current default graph is
|
||||
# not reliable. For example, Estimator-based binaries may switch graphs in
|
||||
@ -370,6 +376,58 @@ def _EagerPyFuncGrad(op, *dy):
|
||||
is_grad_func=True)
|
||||
|
||||
|
||||
def _eager_py_func(func, inp, Tout, name=None, use_tape_cache=True):
|
||||
"""Wraps a python function into a TensorFlow op that executes it eagerly.
|
||||
|
||||
This function is the internal implementation for `eager_py_func`, see the
|
||||
`eager_py_func` docstring for the full description.
|
||||
|
||||
Note: this function as a layer of indirection was added with one
|
||||
specific purpose: as a workaround for github issue #35084.
|
||||
It does all the same as `eager_py_func` used to do with one difference:
|
||||
it can be used to instruct underlying EagerFunc not to use `tape_cache`
|
||||
to avoid memory leak. When the issue #35084 is fixed - this function should
|
||||
be removed, its body should be moved back to become the body of
|
||||
`eager_py_func` and all the call sites should be reverted to
|
||||
using `eager_py_func` without `use_tape_cache` argument of any value.
|
||||
|
||||
Args:
|
||||
func: A Python function which accepts a list of `Tensor` objects having
|
||||
element types that match the corresponding `tf.Tensor` objects in `inp`
|
||||
and returns a list of `Tensor` objects (or a single `Tensor`, or `None`)
|
||||
having element types that match the corresponding values in `Tout`.
|
||||
inp: A list of `Tensor` objects.
|
||||
Tout: A list or tuple of tensorflow data types or a single tensorflow data
|
||||
type if there is only one, indicating what `func` returns; an empty list
|
||||
if no value is returned (i.e., if the return value is `None`).
|
||||
name: A name for the operation (optional).
|
||||
use_tape_cache: (Optional.) Whether to cache `func` in the `tape_cache`.
|
||||
For additional information, see description of `_eager_py_func`.
|
||||
This parameter should be removed once the #35084 issue is fixed.
|
||||
|
||||
Returns:
|
||||
A list of `Tensor` or a single `Tensor` which `func` computes; an empty list
|
||||
if `func` returns None.
|
||||
"""
|
||||
if ops.executing_eagerly_outside_functions():
|
||||
with ops.device(context.context().host_address_space()):
|
||||
return _internal_py_func(
|
||||
func=func,
|
||||
inp=inp,
|
||||
Tout=Tout,
|
||||
eager=True,
|
||||
name=name,
|
||||
use_tape_cache=use_tape_cache)
|
||||
|
||||
return _internal_py_func(
|
||||
func=func,
|
||||
inp=inp,
|
||||
Tout=Tout,
|
||||
eager=True,
|
||||
name=name,
|
||||
use_tape_cache=use_tape_cache)
|
||||
|
||||
|
||||
@tf_export("py_function")
|
||||
@dispatch.add_dispatch_support
|
||||
def eager_py_func(func, inp, Tout, name=None):
|
||||
@ -451,12 +509,8 @@ def eager_py_func(func, inp, Tout, name=None):
|
||||
A list of `Tensor` or a single `Tensor` which `func` computes; an empty list
|
||||
if `func` returns None.
|
||||
"""
|
||||
if ops.executing_eagerly_outside_functions():
|
||||
with ops.device(context.context().host_address_space()):
|
||||
return _internal_py_func(
|
||||
func=func, inp=inp, Tout=Tout, eager=True, name=name)
|
||||
|
||||
return _internal_py_func(func=func, inp=inp, Tout=Tout, eager=True, name=name)
|
||||
return _eager_py_func(
|
||||
func=func, inp=inp, Tout=Tout, name=name, use_tape_cache=True)
|
||||
|
||||
|
||||
def py_func_common(func, inp, Tout, stateful=True, name=None):
|
||||
|
@ -65,7 +65,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "from_generator"
|
||||
argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
||||
argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\', \'output_signature\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "from_sparse_tensor_slices"
|
||||
|
@ -67,7 +67,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "from_generator"
|
||||
argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
||||
argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\', \'output_signature\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "from_sparse_tensor_slices"
|
||||
|
@ -67,7 +67,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "from_generator"
|
||||
argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
||||
argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\', \'output_signature\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "from_sparse_tensor_slices"
|
||||
|
@ -67,7 +67,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "from_generator"
|
||||
argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
||||
argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\', \'output_signature\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "from_sparse_tensor_slices"
|
||||
|
@ -67,7 +67,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "from_generator"
|
||||
argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
||||
argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\', \'output_signature\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "from_sparse_tensor_slices"
|
||||
|
@ -67,7 +67,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "from_generator"
|
||||
argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
||||
argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\', \'output_signature\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "from_sparse_tensor_slices"
|
||||
|
@ -67,7 +67,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "from_generator"
|
||||
argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
||||
argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\', \'output_signature\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "from_sparse_tensor_slices"
|
||||
|
@ -48,7 +48,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "from_generator"
|
||||
argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
||||
argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\', \'output_signature\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "from_tensor_slices"
|
||||
|
@ -50,7 +50,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "from_generator"
|
||||
argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
||||
argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\', \'output_signature\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "from_tensor_slices"
|
||||
|
@ -49,7 +49,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "from_generator"
|
||||
argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
||||
argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\', \'output_signature\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "from_tensor_slices"
|
||||
|
@ -50,7 +50,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "from_generator"
|
||||
argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
||||
argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\', \'output_signature\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "from_tensor_slices"
|
||||
|
@ -50,7 +50,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "from_generator"
|
||||
argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
||||
argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\', \'output_signature\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "from_tensor_slices"
|
||||
|
@ -50,7 +50,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "from_generator"
|
||||
argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
||||
argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\', \'output_signature\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "from_tensor_slices"
|
||||
|
@ -50,7 +50,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "from_generator"
|
||||
argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
||||
argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\', \'output_signature\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "from_tensor_slices"
|
||||
|
Loading…
Reference in New Issue
Block a user