Internal change
PiperOrigin-RevId: 301680486
Change-Id: If79ff677cfe72fa72ac02996b71047f3caa8b7d2

parent 4f315c18bc
commit 8b30b82190
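The hunks below appear to remove the `output_signature` argument from `tf.data.Dataset.from_generator` (together with the related `use_tape_cache` plumbing in `script_ops`), falling back to the long-standing `output_types`/`output_shapes` calling convention. As a rough illustration of the two conventions this diff moves between, not part of the diff itself; `gen` is a made-up placeholder, and the `output_signature` form only exists on the side of the diff that is being removed:

import tensorflow as tf

def gen():
  for i in range(3):
    yield i, [1] * (i + 1)

# Calling convention kept by this change: per-component dtypes and shapes.
ds = tf.data.Dataset.from_generator(
    gen,
    output_types=(tf.int64, tf.int64),
    output_shapes=(tf.TensorShape([]), tf.TensorShape([None])))

# Calling convention on the removed side of the diff: a single nested
# structure of tf.TypeSpec objects (commented out because it only works on
# builds that still accept `output_signature`).
# ds = tf.data.Dataset.from_generator(
#     gen,
#     output_signature=(tf.TensorSpec(shape=(), dtype=tf.int64),
#                       tf.TensorSpec(shape=(None,), dtype=tf.int64)))

print(list(ds.take(2).as_numpy_iterator()))  # roughly [(0, array([1])), (1, array([1, 1]))]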
@@ -48,7 +48,7 @@ def _format_record(array, sparse):
    return {
        "values": array,
        "indices": [[i] for i in range(len(array))],
        "dense_shape": [len(array),]
        "dense_shape": (len(array),)
    }
  return array

@@ -402,16 +402,13 @@ class BucketBySequenceLengthTest(test_base.DatasetTestBase,
    bucket_size = 10

    def _build_dataset():
      input_data = [list(range(i + 1)) for i in range(min_len, max_len)]

      input_data = [range(i+1) for i in range(min_len, max_len)]
      def generator_fn():
        for record in input_data:
          yield _format_record(record, sparse=True)

      dataset = dataset_ops.Dataset.from_generator(
          generator=generator_fn,
          output_types=_get_record_type(sparse=True))

      dataset = dataset.map(_to_sparse_tensor)
      return dataset

@@ -28,12 +28,7 @@ from tensorflow.python.framework import combinations
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import script_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.platform import test


@@ -246,7 +241,7 @@ class FromGeneratorTest(test_base.DatasetTestBase, parameterized.TestCase):

    self.assertAllEqual([1, 2, 3], self.evaluate(get_next()))
    self.assertAllEqual([4, 5, 6], self.evaluate(get_next()))
    with self.assertRaises(errors.InvalidArgumentError):
    with self.assertRaisesOpError("The expected type was int64"):
      self.evaluate(get_next())
    self.assertAllEqual([7, 8, 9], self.evaluate(get_next()))
    with self.assertRaises(errors.OutOfRangeError):

@@ -266,7 +261,7 @@ class FromGeneratorTest(test_base.DatasetTestBase, parameterized.TestCase):

    self.assertAllEqual([1, 2, 3], self.evaluate(get_next()))
    self.assertAllEqual([4, 5, 6], self.evaluate(get_next()))
    with self.assertRaises(errors.InvalidArgumentError):
    with self.assertRaisesOpError(r"element of shape \(3,\) was expected"):
      self.evaluate(get_next())
    self.assertAllEqual([11, 12, 13], self.evaluate(get_next()))
    with self.assertRaises(errors.OutOfRangeError):

@@ -287,9 +282,11 @@ class FromGeneratorTest(test_base.DatasetTestBase, parameterized.TestCase):

    self.assertEqual((1, 2), self.evaluate(get_next()))
    self.assertEqual((3, 4), self.evaluate(get_next()))
    with self.assertRaises(errors.InvalidArgumentError):
    with self.assertRaisesOpError(
        r"The expected structure was \(tf\.int64, tf\.int64\)"):
      self.evaluate(get_next())
    with self.assertRaises(errors.InvalidArgumentError):
    with self.assertRaisesOpError(
        r"The expected structure was \(tf\.int64, tf\.int64\)"):
      self.evaluate(get_next())
    self.assertEqual((9, 10), self.evaluate(get_next()))
    with self.assertRaises(errors.OutOfRangeError):

@@ -408,12 +405,8 @@ class FromGeneratorTest(test_base.DatasetTestBase, parameterized.TestCase):
        stateful=True)

    dummy = constant_op.constant(37)

    dataset = dataset_ops._GeneratorDataset(
        dummy, lambda x: x, lambda x: x, finalize_fn,
        tensor_spec.TensorSpec((), dtypes.int32))

    dataset = dataset.take(2)
    dataset = dataset_ops._GeneratorDataset(dummy, lambda x: x, lambda x: x,
                                            finalize_fn).take(2)
    get_next = self.getNext(dataset)

    self.assertAllEqual(37, self.evaluate(get_next()))

@@ -435,46 +428,6 @@ class FromGeneratorTest(test_base.DatasetTestBase, parameterized.TestCase):

    self.assertAllEqual([20], self.evaluate(get_next()))

  @combinations.generate(test_base.default_test_combinations())
  def testFromGeneratorRaggedTensor(self):

    def generator():
      yield ragged_factory_ops.constant([[1, 2], [3]],
                                        dtype=dtypes.int64,
                                        ragged_rank=1)

    dataset = dataset_ops.Dataset.from_generator(
        generator,
        output_signature=ragged_tensor.RaggedTensorSpec(
            shape=(2, None), dtype=dtypes.int64))
    get_next = self.getNext(dataset)

    ret = get_next()

    self.assertIsInstance(ret, ragged_tensor.RaggedTensor)
    self.assertAllEqual([1, 2, 3], ret.values)

  @combinations.generate(test_base.default_test_combinations())
  def testFromGeneratorSparseTensor(self):

    def generator():
      yield sparse_tensor.SparseTensor(
          indices=[[0, 0], [1, 2]],
          values=constant_op.constant([1, 2], dtype=dtypes.int64),
          dense_shape=[3, 4])

    dataset = dataset_ops.Dataset.from_generator(
        generator,
        output_signature=sparse_tensor.SparseTensorSpec([3, 4], dtypes.int64))

    get_next = self.getNext(dataset)

    ret = get_next()

    self.assertIsInstance(ret, sparse_tensor.SparseTensor)
    self.assertAllEqual([[1, 0, 0, 0], [0, 0, 2, 0], [0, 0, 0, 0]],
                        sparse_ops.sparse_tensor_to_dense(ret))


if __name__ == "__main__":
  test.main()

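The two generator tests above yield composite tensors (ragged and sparse) directly, which is only possible with `output_signature`. A rough public-API equivalent of the sparse case, not part of the diff, assuming a build that still accepts `output_signature`:

import tensorflow as tf

def sparse_gen():
  yield tf.SparseTensor(indices=[[0, 0], [1, 2]],
                        values=tf.constant([1, 2], dtype=tf.int64),
                        dense_shape=[3, 4])

ds = tf.data.Dataset.from_generator(
    sparse_gen,
    output_signature=tf.SparseTensorSpec([3, 4], tf.int64))

for st in ds:
  # Densify just to make the element easy to inspect.
  print(tf.sparse.to_dense(st))  # [[1 0 0 0] [0 0 2 0] [0 0 0 0]]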
@@ -946,9 +946,7 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase):

    @def_function.function
    def fn():
      output_spec = tensor_spec.TensorSpec((), dtypes.int64)
      dataset = dataset_ops._GeneratorDataset(1, init_fn, next_fn, finalize_fn,
                                              output_spec)
      dataset = dataset_ops._GeneratorDataset(1, init_fn, next_fn, finalize_fn)
      iterator = iter(dataset)
      next(iterator)


@@ -408,7 +408,8 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor):
  def element_spec(self):
    """The type specification of an element of this dataset.

    >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3]).element_spec
    >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
    >>> dataset.element_spec
    TensorSpec(shape=(), dtype=tf.int32, name=None)

    Returns:
@@ -674,48 +675,27 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor):
      del self._iterators[iterator_id]

  @staticmethod
  @deprecation.deprecated_args(None, "Use output_signature instead",
                               "output_types", "output_shapes")
  def from_generator(generator,
                     output_types=None,
                     output_shapes=None,
                     args=None,
                     output_signature=None):
  def from_generator(generator, output_types, output_shapes=None, args=None):
    """Creates a `Dataset` whose elements are generated by `generator`.

    The `generator` argument must be a callable object that returns
    an object that supports the `iter()` protocol (e.g. a generator function).
    The elements generated by `generator` must be compatible with the given
    `output_types` and (optional) `output_shapes` arguments.

    The elements generated by `generator` must be compatible with either the
    given `output_signature` argument or with the given `output_types` and
    (optionally) `output_shapes` arguments, whichever was specified.

    The recommended way to call `from_generator` is to use the
    `output_signature` argument. In this case the output will be assumed to
    consist of objects with the classes, shapes and types defined by
    `tf.TypeSpec` objects from the `output_signature` argument:

    >>> import itertools
    >>>
    >>> def gen():
    ...   ragged_tensor = tf.ragged.constant([[1, 2], [3]],
    ...                                       ragged_rank=1,
    ...                                       dtype=tf.int64)
    ...   yield 42, ragged_tensor
    ...   for i in itertools.count(1):
    ...     yield (i, [1] * i)
    >>>
    >>> dataset = tf.data.Dataset.from_generator(
    ...     gen,
    ...     output_signature=(
    ...         tf.TensorSpec(shape=(), dtype=tf.int64),
    ...         tf.RaggedTensorSpec(shape=(2, None), dtype=tf.int64)))
    ...     (tf.int64, tf.int64),
    ...     (tf.TensorShape([]), tf.TensorShape([None])))
    >>>
    >>> list(dataset.take(1))
    [(<tf.Tensor: shape=(), dtype=int64, numpy=42>,
     <tf.RaggedTensor [[1, 2], [3]]>)]

    There is also a deprecated way to call `from_generator`, either with the
    `output_types` argument alone or together with the `output_shapes` argument.
    In this case the output of the function will be assumed to consist of
    `tf.Tensor` objects with the types defined by `output_types` and with
    the shapes which are either unknown or defined by `output_shapes`.
    >>> list(dataset.take(3).as_numpy_iterator())
    [(1, array([1])), (2, array([1, 1])), (3, array([1, 1, 1]))]

    Note: The current implementation of `Dataset.from_generator()` uses
    `tf.numpy_function` and inherits the same constraints. In particular, it
@@ -739,56 +719,31 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor):
        `iter()` protocol. If `args` is not specified, `generator` must take no
        arguments; otherwise it must take as many arguments as there are values
        in `args`.
      output_types: (Optional.) A nested structure of `tf.DType` objects
        corresponding to each component of an element yielded by `generator`.
      output_types: A nested structure of `tf.DType` objects corresponding to
        each component of an element yielded by `generator`.
      output_shapes: (Optional.) A nested structure of `tf.TensorShape` objects
        corresponding to each component of an element yielded by `generator`.
      args: (Optional.) A tuple of `tf.Tensor` objects that will be evaluated
        and passed to `generator` as NumPy-array arguments.
      output_signature: (Optional.) A nested structure of `tf.TypeSpec` objects
        corresponding to each component of an element yielded by `generator`.

    Returns:
      Dataset: A `Dataset`.
    """
    if not callable(generator):
      raise TypeError("`generator` must be callable.")

    if output_signature is not None:
      if output_types is not None:
        raise TypeError("`output_types` can not be used together with "
                        "`output_signature`")
      if output_shapes is not None:
        raise TypeError("`output_shapes` can not be used together with "
                        "`output_signature`")
      if not all(
          isinstance(_, type_spec.TypeSpec)
          for _ in nest.flatten(output_signature)):
        raise TypeError("All the elements of `output_signature` must be "
                        "`tf.TypeSpec` objects.")
    if output_shapes is None:
      output_shapes = nest.map_structure(
          lambda _: tensor_shape.TensorShape(None), output_types)
    else:
      if output_types is None and output_shapes is not None:
        raise TypeError("`output_shapes` can not be used alone without "
                        "`output_types`")

    if output_signature is None:
      if output_shapes is None:
        output_shapes = nest.map_structure(
            lambda _: tensor_shape.TensorShape(None), output_types)
      else:
        output_shapes = nest.map_structure_up_to(output_types,
                                                 tensor_shape.as_shape,
                                                 output_shapes)
      output_signature = nest.map_structure_up_to(output_types,
                                                  tensor_spec.TensorSpec,
                                                  output_shapes, output_types)

      output_shapes = nest.map_structure_up_to(
          output_types, tensor_shape.as_shape, output_shapes)
    if args is None:
      args = ()
    else:
      args = tuple(ops.convert_n_to_tensor(args, name="args"))

    flat_output_types = structure.get_flat_tensor_types(output_signature)
    flattened_types = [dtypes.as_dtype(dt) for dt in nest.flatten(output_types)]
    flattened_shapes = nest.flatten(output_shapes)

    generator_state = DatasetV2._GeneratorState(generator)

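On the removed side of the hunk above, `from_generator` rejects mixing the two conventions and requires every leaf of `output_signature` to be a `tf.TypeSpec`. A small sketch of that behavior, not part of the diff; it assumes a build with `output_signature` support, and `gen` is a placeholder:

import tensorflow as tf

def gen():
  yield 1

try:
  tf.data.Dataset.from_generator(
      gen,
      output_types=tf.int64,
      output_signature=tf.TensorSpec(shape=(), dtype=tf.int64))
except TypeError as e:
  # Per the removed check: "`output_types` can not be used together with
  # `output_signature`"
  print(e)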
@@ -826,41 +781,56 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor):
      """A `py_func` that will be called to invoke the iterator."""
      # `next()` raises `StopIteration` when there are no more
      # elements remaining to be generated.
      values = next(generator_state.get_iterator(iterator_id.numpy()))

      def serialize_structure(s):
        return nest.map_structure(lambda ts: ts._serialize(), s)  # pylint: disable=protected-access
      values = next(generator_state.get_iterator(iterator_id))

      # Use the same _convert function from the py_func() implementation to
      # convert the returned values to arrays early, so that we can inspect
      # their values.
      try:
        output_dtypes = nest.map_structure(lambda t: t.dtype,
                                           output_signature)
        values = structure.normalize_element(values, dtypes=output_dtypes)
        flattened_values = nest.flatten_up_to(output_types, values)
      except (TypeError, ValueError):
        six.reraise(
            TypeError,
            TypeError(
                "`generator` yielded an element that did not match the "
                "expected structure. The expected structure was %s, but the "
                "yielded element was %s." %
                (serialize_structure(output_signature), values)),
            sys.exc_info()[2])
        six.reraise(TypeError, TypeError(
            "`generator` yielded an element that did not match the expected "
            "structure. The expected structure was %s, but the yielded "
            "element was %s." % (output_types, values)), sys.exc_info()[2])
      ret_arrays = []
      for ret, dtype in zip(flattened_values, flattened_types):
        try:
          ret_arrays.append(script_ops.FuncRegistry._convert(  # pylint: disable=protected-access
              ret, dtype=dtype.as_numpy_dtype))
        except (TypeError, ValueError):
          six.reraise(TypeError, TypeError(
              "`generator` yielded an element that could not be converted to "
              "the expected type. The expected type was %s, but the yielded "
              "element was %s." % (dtype.name, ret)), sys.exc_info()[2])

      values_spec = structure.type_spec_from_value(values)
      # Additional type and shape checking to ensure that the components
      # of the generated element match the `output_types` and `output_shapes`
      # arguments.
      for (ret_array, expected_dtype, expected_shape) in zip(
          ret_arrays, flattened_types, flattened_shapes):
        if ret_array.dtype != expected_dtype.as_numpy_dtype:
          raise TypeError(
              "`generator` yielded an element of type %s where an element "
              "of type %s was expected." % (ret_array.dtype,
                                            expected_dtype.as_numpy_dtype))
        if not expected_shape.is_compatible_with(ret_array.shape):
          raise ValueError(
              "`generator` yielded an element of shape %s where an element "
              "of shape %s was expected." % (ret_array.shape, expected_shape))

      if not structure.are_compatible(values_spec, output_signature):
        raise TypeError(
            "`generator` yielded an element of TypeSpec%s where an element "
            "of TypeSpec%s was expected." %
            (serialize_structure(values_spec),
             serialize_structure(output_signature)))
      return ret_arrays

      return structure.to_tensor_list(output_signature, values)
    flat_values = script_ops.numpy_function(generator_py_func,
                                            [iterator_id_t], flattened_types)

    return script_ops._eager_py_func(  # pylint: disable=protected-access
        generator_py_func,
        inp=[iterator_id_t],
        Tout=flat_output_types,
        use_tape_cache=False)
    # The `py_func()` op drops the inferred shapes, so we add them back in
    # here.
    if output_shapes is not None:
      for ret_t, shape in zip(flat_values, flattened_shapes):
        ret_t.set_shape(shape)

    return nest.pack_sequence_as(output_types, flat_values)

    def finalize_fn(iterator_id_t):
      """Releases host-side state for the iterator with ID `iterator_id_t`."""
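Both sides of the hunk above validate each yielded element against the declared signature before handing it to the runtime; only the mechanism differs (a `structure.are_compatible` check against the `TypeSpec`s versus per-component dtype and shape checks). A sketch, not part of the diff, of the kind of mismatch those checks catch; the exact error class and message differ between the two sides, as the updated tests earlier in this diff show:

import tensorflow as tf

def bad_gen():
  yield [1, 2, 3]
  yield ["a", "b", "c"]  # cannot be converted to the declared int64 dtype

ds = tf.data.Dataset.from_generator(bad_gen, output_types=tf.int64)
it = iter(ds)
print(next(it))  # tf.Tensor([1 2 3], shape=(3,), dtype=int64)
try:
  next(it)
except Exception as e:  # error class depends on which side of the diff is built
  print(type(e).__name__)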
@@ -886,7 +856,7 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor):
      # given ID, and raises StopIteration when that iterator contains no
      # more elements.
      return _GeneratorDataset(dummy_arg, get_iterator_id_fn, generator_next_fn,
                               finalize_fn, output_signature)
                               finalize_fn)

    # A single-element dataset that, each time it is evaluated, contains a
    # freshly-generated and unique (for the returned dataset) int64

@@ -2308,14 +2278,9 @@ class DatasetV1(DatasetV2):

  @staticmethod
  @functools.wraps(DatasetV2.from_generator)
  def from_generator(generator,
                     output_types=None,
                     output_shapes=None,
                     args=None,
                     output_signature=None):
    return DatasetV1Adapter(
        DatasetV2.from_generator(generator, output_types, output_shapes, args,
                                 output_signature))
  def from_generator(generator, output_types, output_shapes=None, args=None):
    return DatasetV1Adapter(DatasetV2.from_generator(
        generator, output_types, output_shapes, args))

  @staticmethod
  @functools.wraps(DatasetV2.range)

@@ -3296,8 +3261,7 @@ class StructuredFunctionWrapper(object):
class _GeneratorDataset(DatasetSource):
  """A `Dataset` that generates elements by invoking a function."""

  def __init__(self, init_args, init_func, next_func, finalize_func,
               output_signature):
  def __init__(self, init_args, init_func, next_func, finalize_func):
    """Constructs a `_GeneratorDataset`.

    Args:

@@ -3311,8 +3275,6 @@ class _GeneratorDataset(DatasetSource):
      finalize_func: A TensorFlow function that will be called on the result of
        `init_func` immediately before a C++ iterator over this dataset is
        destroyed. The return value is ignored.
      output_signature: A nested structure of `tf.TypeSpec` objects describing
        the output of `next_func`.
    """
    self._init_args = init_args

@@ -3332,9 +3294,6 @@ class _GeneratorDataset(DatasetSource):
        finalize_func,
        self._transformation_name(),
        input_structure=self._init_func.output_structure)

    self._output_signature = output_signature

    variant_tensor = gen_dataset_ops.generator_dataset(
        structure.to_tensor_list(self._init_structure, self._init_args) +
        self._init_func.function.captured_inputs,

@@ -3348,7 +3307,7 @@ class _GeneratorDataset(DatasetSource):

  @property
  def element_spec(self):
    return self._output_signature
    return self._next_func.output_structure

  def _transformation_name(self):
    return "Dataset.from_generator()"

@@ -67,7 +67,7 @@ def _RaggedTensorStructure(dtype, shape, ragged_rank):

# TODO(jsimsa): Remove the special-case for `TensorArray` pass-through once
# it is a subclass of `CompositeTensor`.
def normalize_element(element, dtypes=None):
def normalize_element(element):
  """Normalizes a nested structure of element components.

  * Components matching `SparseTensorSpec` are converted to `SparseTensor`.

@@ -78,10 +78,6 @@ def normalize_element(element, dtypes=None):

  Args:
    element: A nested structure of individual components.
    dtypes: (Optional.) A nested structure of `tf.DType` objects corresponding
      to each component of `element`. If specified, it will be used to set the
      exact type of output tensor when converting input components which
      are not tensors themselves (e.g. numpy arrays, native python types, etc.)

  Returns:
    A nested structure of `Tensor`, `Dataset`, `SparseTensor`, `RaggedTensor`,

@@ -89,21 +85,17 @@ def normalize_element(element, dtypes=None):
  """
  components = nest.flatten(element)
  normalized_components = []
  if dtypes is None:
    flattened_dtypes = [None] * len(components)
  else:
    flattened_dtypes = nest.flatten(dtypes)
  with ops.name_scope("normalize_element"):
    # Imported here to avoid circular dependency.
    from tensorflow.python.data.ops import dataset_ops  # pylint: disable=g-import-not-at-top
    for i, (t, dtype) in enumerate(zip(components, flattened_dtypes)):
    for i, t in enumerate(components):
      try:
        spec = type_spec_from_value(t, use_fallback=False)
      except TypeError:
        # TypeError indicates it was not possible to compute a `TypeSpec` for
        # the value. As a fallback try converting the value to a tensor.
        normalized_components.append(
            ops.convert_to_tensor(t, name="component_%d" % i, dtype=dtype))
            ops.convert_to_tensor(t, name="component_%d" % i))
      else:
        if isinstance(spec, sparse_tensor.SparseTensorSpec):
          normalized_components.append(sparse_tensor.SparseTensor.from_value(t))

@@ -120,7 +112,7 @@ def normalize_element(element, dtypes=None):
          normalized_components.append(t)
        else:
          normalized_components.append(
              ops.convert_to_tensor(t, name="component_%d" % i, dtype=dtype))
              ops.convert_to_tensor(t, name="component_%d" % i))
  return nest.pack_sequence_as(element, normalized_components)

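The `dtypes` parameter removed from `normalize_element` above exists so that non-tensor components are converted with an explicit dtype instead of TensorFlow's default inference. The public-API analogue of that difference, shown here as an illustration outside the diff, is the `dtype` argument of `tf.convert_to_tensor`:

import tensorflow as tf

# Default inference: Python ints become int32.
print(tf.convert_to_tensor([1, 2, 3]).dtype)                  # tf.int32
# With an explicit dtype (what the removed `dtypes` plumbing forwarded):
print(tf.convert_to_tensor([1, 2, 3], dtype=tf.int64).dtype)  # tf.int64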
@@ -2085,11 +2085,6 @@ class RaggedTensorSpec(type_spec.BatchableTypeSpec):
    else:
      return super(RaggedTensorSpec, self).is_compatible_with(spec_or_value)

  @property
  def dtype(self):
    """The `tf.dtypes.DType` specified by this type for the RaggedTensor."""
    return self._dtype

  def _serialize(self):
    return (self._shape, self._dtype, self._ragged_rank, self._row_splits_dtype)

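The hunk above drops a `dtype` property from `RaggedTensorSpec` (it is also removed from the golden API files near the end of this diff). On a build that exposes the property, usage is simply the following sketch, which is not part of the diff:

import tensorflow as tf

spec = tf.RaggedTensorSpec(shape=(2, None), dtype=tf.int64)
print(spec.dtype)  # tf.int64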
@@ -70,7 +70,7 @@ def _maybe_copy_to_context_device(tensor, device_name):
class EagerFunc(object):
  """A wrapper for a function owned by an EagerPyFunc."""

  def __init__(self, func, Tout, is_grad_func, use_tape_cache=True):
  def __init__(self, func, Tout, is_grad_func):
    """Constructs an EagerFunc.

    Args:

@@ -79,12 +79,10 @@ class EagerFunc(object):
        None.
      is_grad_func: Whether this EagerFunc is the gradient of another
        EagerPyFunc.
      use_tape_cache: (Optional.) Whether to cache `func` in the `tape_cache`.
    """
    self._func = func
    self._out_dtypes = Tout
    self._is_grad_func = is_grad_func
    self._use_tape_cache = use_tape_cache

  def _convert(self, value, dtype):
    """Converts `value` to a tensor of type `dtype`, with error checking.

@@ -148,8 +146,7 @@ class EagerFunc(object):
      else:
        outputs = _maybe_copy_to_context_device(
            self._convert(ret, dtype=self._out_dtypes[0]), device_name)
      if self._use_tape_cache:
        tape_cache[compat.as_bytes(token)] = (tape, args, outputs)
      tape_cache[compat.as_bytes(token)] = (tape, args, outputs)
      return outputs


@@ -279,8 +276,7 @@ def _internal_py_func(func,
                      stateful=None,
                      eager=False,
                      is_grad_func=False,
                      name=None,
                      use_tape_cache=True):
                      name=None):
  """See documentation for py_func and eager_py_func."""
  if not callable(func):
    raise ValueError("Expected func to be callable, got func of type {}".format(

@@ -296,7 +292,7 @@ def _internal_py_func(func,
    Tout = [Tout]

  if eager:
    func = EagerFunc(func, Tout, is_grad_func, use_tape_cache=use_tape_cache)
    func = EagerFunc(func, Tout, is_grad_func)

  # Tying the registered function's lifetime with the current default graph is
  # not reliable. For example, Estimator-based binaries may switch graphs in

@@ -373,35 +369,6 @@ def _EagerPyFuncGrad(op, *dy):
      is_grad_func=True)


# NOTE(lithuak): this function as a layer of indirection was added with one
# specific purpose: as a workaround for github issue #35084.
# It does all the same as `eager_py_func` used to do with one difference:
# it can be used to instruct underlying EagerFunc not to use `tape_cache`
# to avoid memory leak. When the issue #35084 is fixed - this function should
# be removed, its body should be moved back to become the body of
# `eager_py_func` and all the call sites should be reverted to
# using `eager_py_func` without `use_tape_cache` argument of any value.
def _eager_py_func(func, inp, Tout, name=None, use_tape_cache=True):
  """Wraps a python function into a TensorFlow op that executes it eagerly."""
  if ops.executing_eagerly_outside_functions():
    with ops.device(context.context().host_address_space()):
      return _internal_py_func(
          func=func,
          inp=inp,
          Tout=Tout,
          eager=True,
          name=name,
          use_tape_cache=use_tape_cache)

  return _internal_py_func(
      func=func,
      inp=inp,
      Tout=Tout,
      eager=True,
      name=name,
      use_tape_cache=use_tape_cache)


@tf_export("py_function")
def eager_py_func(func, inp, Tout, name=None):
  """Wraps a python function into a TensorFlow op that executes it eagerly.

@@ -482,8 +449,12 @@ def eager_py_func(func, inp, Tout, name=None):
    A list of `Tensor` or a single `Tensor` which `func` computes; an empty list
    if `func` returns None.
  """
  return _eager_py_func(
      func=func, inp=inp, Tout=Tout, name=name, use_tape_cache=True)
  if ops.executing_eagerly_outside_functions():
    with ops.device(context.context().host_address_space()):
      return _internal_py_func(
          func=func, inp=inp, Tout=Tout, eager=True, name=name)

  return _internal_py_func(func=func, inp=inp, Tout=Tout, eager=True, name=name)


def py_func_common(func, inp, Tout, stateful=True, name=None):

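The `script_ops` hunks above fold `_eager_py_func` back into `eager_py_func` and drop the `use_tape_cache` escape hatch that the github issue #35084 workaround introduced. The public entry point, `tf.py_function`, keeps the same observable behavior either way; a minimal usage sketch, not part of the diff:

import tensorflow as tf

def add_one(x):
  # Arbitrary eager Python runs inside the op.
  return x + 1

y = tf.py_function(func=add_one, inp=[tf.constant(41.0)], Tout=tf.float32)
print(y)  # tf.Tensor(42.0, shape=(), dtype=float32)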
@@ -4,10 +4,6 @@ tf_class {
  is_instance: "<class \'tensorflow.python.framework.type_spec.BatchableTypeSpec\'>"
  is_instance: "<class \'tensorflow.python.framework.type_spec.TypeSpec\'>"
  is_instance: "<type \'object\'>"
  member {
    name: "dtype"
    mtype: "<type \'property\'>"
  }
  member {
    name: "value_type"
    mtype: "<type \'property\'>"

@@ -63,7 +63,7 @@ tf_class {
  }
  member_method {
    name: "from_generator"
    argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\', \'output_signature\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
    argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
  }
  member_method {
    name: "from_sparse_tensor_slices"

@@ -65,7 +65,7 @@ tf_class {
  }
  member_method {
    name: "from_generator"
    argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\', \'output_signature\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
    argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
  }
  member_method {
    name: "from_sparse_tensor_slices"

@@ -65,7 +65,7 @@ tf_class {
  }
  member_method {
    name: "from_generator"
    argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\', \'output_signature\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
    argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
  }
  member_method {
    name: "from_sparse_tensor_slices"

@@ -65,7 +65,7 @@ tf_class {
  }
  member_method {
    name: "from_generator"
    argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\', \'output_signature\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
    argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
  }
  member_method {
    name: "from_sparse_tensor_slices"

@@ -65,7 +65,7 @@ tf_class {
  }
  member_method {
    name: "from_generator"
    argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\', \'output_signature\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
    argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
  }
  member_method {
    name: "from_sparse_tensor_slices"

@@ -65,7 +65,7 @@ tf_class {
  }
  member_method {
    name: "from_generator"
    argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\', \'output_signature\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
    argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
  }
  member_method {
    name: "from_sparse_tensor_slices"

@@ -65,7 +65,7 @@ tf_class {
  }
  member_method {
    name: "from_generator"
    argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\', \'output_signature\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
    argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
  }
  member_method {
    name: "from_sparse_tensor_slices"

@@ -4,10 +4,6 @@ tf_class {
  is_instance: "<class \'tensorflow.python.framework.type_spec.BatchableTypeSpec\'>"
  is_instance: "<class \'tensorflow.python.framework.type_spec.TypeSpec\'>"
  is_instance: "<type \'object\'>"
  member {
    name: "dtype"
    mtype: "<type \'property\'>"
  }
  member {
    name: "value_type"
    mtype: "<type \'property\'>"

@@ -46,7 +46,7 @@ tf_class {
  }
  member_method {
    name: "from_generator"
    argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\', \'output_signature\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
    argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
  }
  member_method {
    name: "from_tensor_slices"

@@ -48,7 +48,7 @@ tf_class {
  }
  member_method {
    name: "from_generator"
    argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\', \'output_signature\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
    argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
  }
  member_method {
    name: "from_tensor_slices"

@@ -47,7 +47,7 @@ tf_class {
  }
  member_method {
    name: "from_generator"
    argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\', \'output_signature\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
    argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
  }
  member_method {
    name: "from_tensor_slices"

@@ -48,7 +48,7 @@ tf_class {
  }
  member_method {
    name: "from_generator"
    argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\', \'output_signature\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
    argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
  }
  member_method {
    name: "from_tensor_slices"

@@ -48,7 +48,7 @@ tf_class {
  }
  member_method {
    name: "from_generator"
    argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\', \'output_signature\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
    argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
  }
  member_method {
    name: "from_tensor_slices"

@@ -48,7 +48,7 @@ tf_class {
  }
  member_method {
    name: "from_generator"
    argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\', \'output_signature\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
    argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
  }
  member_method {
    name: "from_tensor_slices"

@@ -48,7 +48,7 @@ tf_class {
  }
  member_method {
    name: "from_generator"
    argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\', \'output_signature\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
    argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
  }
  member_method {
    name: "from_tensor_slices"