[tf.data] API changes.
This CL makes the following tf.data API-related changes: 1) `tf.data.Iterator` and `tf.data.IteratorSpec` are exposed in the v2 API 2) `tf.experimental.Optional` is exposed in the API (previously exposed as `tf.data.experimental.Optional`) 3) `tf.experimental.Optional.none_from_structure` and `tf.experimental.Optional.value_structure` is renamed to and `tf.experimental.Optional.empty` and `tf.experimental.Optional.element_spec` respectively 4) `tf.OptionalSpec.value_structure` is renamed to `tf.OptionalSpec.element_spec` 5) reflects these changes in documentation and code 6) adds testable docstring for newly exposed APIs PiperOrigin-RevId: 316003328 Change-Id: I7b7e79942308b3d2f94b988c31729980fb69d961
This commit is contained in:
parent
b6d13bb0a8
commit
cfe037e3fe
tensorflow
python
data
experimental/ops
kernel_tests
ops
distribute
tools/api/golden
v1
tensorflow.-optional-spec.pbtxttensorflow.data.-dataset.pbtxttensorflow.data.-fixed-length-record-dataset.pbtxttensorflow.data.-t-f-record-dataset.pbtxttensorflow.data.-text-line-dataset.pbtxttensorflow.data.experimental.-csv-dataset.pbtxttensorflow.data.experimental.-optional-structure.pbtxttensorflow.data.experimental.-optional.pbtxttensorflow.data.experimental.-random-dataset.pbtxttensorflow.data.experimental.-sql-dataset.pbtxttensorflow.experimental.-optional.pbtxttensorflow.experimental.pbtxt
v2
tensorflow.-optional-spec.pbtxttensorflow.data.-dataset.pbtxttensorflow.data.-fixed-length-record-dataset.pbtxttensorflow.data.-iterator-spec.pbtxttensorflow.data.-iterator.pbtxttensorflow.data.-t-f-record-dataset.pbtxttensorflow.data.-text-line-dataset.pbtxttensorflow.data.experimental.-csv-dataset.pbtxttensorflow.data.experimental.-optional.pbtxttensorflow.data.experimental.-random-dataset.pbtxttensorflow.data.experimental.-sql-dataset.pbtxttensorflow.data.pbtxttensorflow.experimental.-optional.pbtxttensorflow.experimental.pbtxt
@ -21,7 +21,7 @@ from tensorflow.python.util import deprecation
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
|
||||
@deprecation.deprecated(None, "Use `tf.data.Dataset.enumerate()")
|
||||
@deprecation.deprecated(None, "Use `tf.data.Dataset.enumerate()`.")
|
||||
@tf_export("data.experimental.enumerate_dataset")
|
||||
def enumerate_dataset(start=0):
|
||||
"""A transformation that enumerates the elements of a dataset.
|
||||
|
@ -112,7 +112,8 @@ def _get_next_as_optional_test_combinations():
|
||||
def reduce_fn(x, y):
|
||||
name, value, value_fn, gpu_compatible = y
|
||||
return x + combinations.combine(
|
||||
np_value=value, tf_value_fn=combinations.NamedObject(name, value_fn),
|
||||
np_value=value,
|
||||
tf_value_fn=combinations.NamedObject(name, value_fn),
|
||||
gpu_compatible=gpu_compatible)
|
||||
|
||||
return functools.reduce(reduce_fn, cases, [])
|
||||
@ -160,13 +161,13 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testFromNone(self):
|
||||
value_structure = tensor_spec.TensorSpec([], dtypes.float32)
|
||||
opt = optional_ops.Optional.none_from_structure(value_structure)
|
||||
self.assertTrue(opt.value_structure.is_compatible_with(value_structure))
|
||||
opt = optional_ops.Optional.empty(value_structure)
|
||||
self.assertTrue(opt.element_spec.is_compatible_with(value_structure))
|
||||
self.assertFalse(
|
||||
opt.value_structure.is_compatible_with(
|
||||
opt.element_spec.is_compatible_with(
|
||||
tensor_spec.TensorSpec([1], dtypes.float32)))
|
||||
self.assertFalse(
|
||||
opt.value_structure.is_compatible_with(
|
||||
opt.element_spec.is_compatible_with(
|
||||
tensor_spec.TensorSpec([], dtypes.int32)))
|
||||
self.assertFalse(self.evaluate(opt.has_value()))
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
@ -183,20 +184,17 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
opt1 = optional_ops.Optional.from_value((1.0, 2.0))
|
||||
opt2 = optional_ops.Optional.from_value((3.0, 4.0))
|
||||
|
||||
add_tensor = math_ops.add_n([opt1._variant_tensor,
|
||||
opt2._variant_tensor])
|
||||
add_opt = optional_ops._OptionalImpl(add_tensor, opt1.value_structure)
|
||||
add_tensor = math_ops.add_n(
|
||||
[opt1._variant_tensor, opt2._variant_tensor])
|
||||
add_opt = optional_ops._OptionalImpl(add_tensor, opt1.element_spec)
|
||||
self.assertAllEqual(self.evaluate(add_opt.get_value()), (4.0, 6.0))
|
||||
|
||||
# Without value
|
||||
opt_none1 = optional_ops.Optional.none_from_structure(
|
||||
opt1.value_structure)
|
||||
opt_none2 = optional_ops.Optional.none_from_structure(
|
||||
opt2.value_structure)
|
||||
add_tensor = math_ops.add_n([opt_none1._variant_tensor,
|
||||
opt_none2._variant_tensor])
|
||||
add_opt = optional_ops._OptionalImpl(add_tensor,
|
||||
opt_none1.value_structure)
|
||||
opt_none1 = optional_ops.Optional.empty(opt1.element_spec)
|
||||
opt_none2 = optional_ops.Optional.empty(opt2.element_spec)
|
||||
add_tensor = math_ops.add_n(
|
||||
[opt_none1._variant_tensor, opt_none2._variant_tensor])
|
||||
add_opt = optional_ops._OptionalImpl(add_tensor, opt_none1.element_spec)
|
||||
self.assertFalse(self.evaluate(add_opt.has_value()))
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
@ -211,13 +209,13 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
opt3 = optional_ops.Optional.from_value((5.0, opt1._variant_tensor))
|
||||
opt4 = optional_ops.Optional.from_value((6.0, opt2._variant_tensor))
|
||||
|
||||
add_tensor = math_ops.add_n([opt3._variant_tensor,
|
||||
opt4._variant_tensor])
|
||||
add_opt = optional_ops._OptionalImpl(add_tensor, opt3.value_structure)
|
||||
add_tensor = math_ops.add_n(
|
||||
[opt3._variant_tensor, opt4._variant_tensor])
|
||||
add_opt = optional_ops._OptionalImpl(add_tensor, opt3.element_spec)
|
||||
self.assertEqual(self.evaluate(add_opt.get_value()[0]), 11.0)
|
||||
|
||||
inner_add_opt = optional_ops._OptionalImpl(add_opt.get_value()[1],
|
||||
opt1.value_structure)
|
||||
opt1.element_spec)
|
||||
self.assertAllEqual(inner_add_opt.get_value(), [4, 6.0])
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
@ -230,17 +228,14 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
# With value
|
||||
opt = optional_ops.Optional.from_value((1.0, 2.0))
|
||||
zeros_tensor = array_ops.zeros_like(opt._variant_tensor)
|
||||
zeros_opt = optional_ops._OptionalImpl(zeros_tensor,
|
||||
opt.value_structure)
|
||||
self.assertAllEqual(self.evaluate(zeros_opt.get_value()),
|
||||
(0.0, 0.0))
|
||||
zeros_opt = optional_ops._OptionalImpl(zeros_tensor, opt.element_spec)
|
||||
self.assertAllEqual(self.evaluate(zeros_opt.get_value()), (0.0, 0.0))
|
||||
|
||||
# Without value
|
||||
opt_none = optional_ops.Optional.none_from_structure(
|
||||
opt.value_structure)
|
||||
opt_none = optional_ops.Optional.empty(opt.element_spec)
|
||||
zeros_tensor = array_ops.zeros_like(opt_none._variant_tensor)
|
||||
zeros_opt = optional_ops._OptionalImpl(zeros_tensor,
|
||||
opt_none.value_structure)
|
||||
opt_none.element_spec)
|
||||
self.assertFalse(self.evaluate(zeros_opt.has_value()))
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
@ -254,10 +249,9 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
opt2 = optional_ops.Optional.from_value(opt1._variant_tensor)
|
||||
|
||||
zeros_tensor = array_ops.zeros_like(opt2._variant_tensor)
|
||||
zeros_opt = optional_ops._OptionalImpl(zeros_tensor,
|
||||
opt2.value_structure)
|
||||
zeros_opt = optional_ops._OptionalImpl(zeros_tensor, opt2.element_spec)
|
||||
inner_zeros_opt = optional_ops._OptionalImpl(zeros_opt.get_value(),
|
||||
opt1.value_structure)
|
||||
opt1.element_spec)
|
||||
self.assertEqual(self.evaluate(inner_zeros_opt.get_value()), 0.0)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
@ -269,16 +263,16 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
optional_with_value = optional_ops.Optional.from_value(
|
||||
(constant_op.constant(37.0), constant_op.constant("Foo"),
|
||||
constant_op.constant(42)))
|
||||
optional_none = optional_ops.Optional.none_from_structure(
|
||||
optional_none = optional_ops.Optional.empty(
|
||||
tensor_spec.TensorSpec([], dtypes.float32))
|
||||
|
||||
with ops.device("/gpu:0"):
|
||||
gpu_optional_with_value = optional_ops._OptionalImpl(
|
||||
array_ops.identity(optional_with_value._variant_tensor),
|
||||
optional_with_value.value_structure)
|
||||
optional_with_value.element_spec)
|
||||
gpu_optional_none = optional_ops._OptionalImpl(
|
||||
array_ops.identity(optional_none._variant_tensor),
|
||||
optional_none.value_structure)
|
||||
optional_none.element_spec)
|
||||
|
||||
gpu_optional_with_value_has_value = gpu_optional_with_value.has_value()
|
||||
gpu_optional_with_value_values = gpu_optional_with_value.get_value()
|
||||
@ -299,7 +293,7 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
optional_with_value = optional_ops.Optional.from_value(
|
||||
(constant_op.constant(37.0), constant_op.constant("Foo"),
|
||||
constant_op.constant(42)))
|
||||
optional_none = optional_ops.Optional.none_from_structure(
|
||||
optional_none = optional_ops.Optional.empty(
|
||||
tensor_spec.TensorSpec([], dtypes.float32))
|
||||
nested_optional = optional_ops.Optional.from_value(
|
||||
(optional_with_value._variant_tensor, optional_none._variant_tensor,
|
||||
@ -308,7 +302,7 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
with ops.device("/gpu:0"):
|
||||
gpu_nested_optional = optional_ops._OptionalImpl(
|
||||
array_ops.identity(nested_optional._variant_tensor),
|
||||
nested_optional.value_structure)
|
||||
nested_optional.element_spec)
|
||||
|
||||
gpu_nested_optional_has_value = gpu_nested_optional.has_value()
|
||||
gpu_nested_optional_values = gpu_nested_optional.get_value()
|
||||
@ -316,10 +310,10 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
self.assertTrue(self.evaluate(gpu_nested_optional_has_value))
|
||||
|
||||
inner_with_value = optional_ops._OptionalImpl(
|
||||
gpu_nested_optional_values[0], optional_with_value.value_structure)
|
||||
gpu_nested_optional_values[0], optional_with_value.element_spec)
|
||||
|
||||
inner_none = optional_ops._OptionalImpl(
|
||||
gpu_nested_optional_values[1], optional_none.value_structure)
|
||||
inner_none = optional_ops._OptionalImpl(gpu_nested_optional_values[1],
|
||||
optional_none.element_spec)
|
||||
|
||||
self.assertEqual((37.0, b"Foo", 42),
|
||||
self.evaluate(inner_with_value.get_value()))
|
||||
@ -327,21 +321,20 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
self.assertEqual(1.0, self.evaluate(gpu_nested_optional_values[2]))
|
||||
|
||||
@combinations.generate(
|
||||
combinations.times(
|
||||
test_base.default_test_combinations(),
|
||||
_optional_spec_test_combinations()))
|
||||
combinations.times(test_base.default_test_combinations(),
|
||||
_optional_spec_test_combinations()))
|
||||
def testOptionalSpec(self, tf_value_fn, expected_value_structure):
|
||||
tf_value = tf_value_fn()
|
||||
opt = optional_ops.Optional.from_value(tf_value)
|
||||
|
||||
self.assertTrue(
|
||||
structure.are_compatible(opt.value_structure, expected_value_structure))
|
||||
structure.are_compatible(opt.element_spec, expected_value_structure))
|
||||
|
||||
opt_structure = structure.type_spec_from_value(opt)
|
||||
self.assertIsInstance(opt_structure, optional_ops.OptionalSpec)
|
||||
self.assertTrue(structure.are_compatible(opt_structure, opt_structure))
|
||||
self.assertTrue(
|
||||
structure.are_compatible(opt_structure._value_structure,
|
||||
structure.are_compatible(opt_structure._element_spec,
|
||||
expected_value_structure))
|
||||
self.assertEqual([dtypes.variant],
|
||||
structure.get_flat_tensor_types(opt_structure))
|
||||
@ -364,13 +357,11 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
self.evaluate(round_trip_opt.get_value().get_value()))
|
||||
else:
|
||||
self.assertValuesEqual(
|
||||
self.evaluate(tf_value),
|
||||
self.evaluate(round_trip_opt.get_value()))
|
||||
self.evaluate(tf_value), self.evaluate(round_trip_opt.get_value()))
|
||||
|
||||
@combinations.generate(
|
||||
combinations.times(
|
||||
test_base.default_test_combinations(),
|
||||
_get_next_as_optional_test_combinations()))
|
||||
combinations.times(test_base.default_test_combinations(),
|
||||
_get_next_as_optional_test_combinations()))
|
||||
def testIteratorGetNextAsOptional(self, np_value, tf_value_fn,
|
||||
gpu_compatible):
|
||||
if not gpu_compatible and test.is_gpu_available():
|
||||
@ -384,9 +375,10 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
for _ in range(3):
|
||||
next_elem = iterator_ops.get_next_as_optional(iterator)
|
||||
self.assertIsInstance(next_elem, optional_ops.Optional)
|
||||
self.assertTrue(structure.are_compatible(
|
||||
next_elem.value_structure,
|
||||
structure.type_spec_from_value(tf_value_fn())))
|
||||
self.assertTrue(
|
||||
structure.are_compatible(
|
||||
next_elem.element_spec,
|
||||
structure.type_spec_from_value(tf_value_fn())))
|
||||
self.assertTrue(next_elem.has_value())
|
||||
self.assertValuesEqual(np_value, next_elem.get_value())
|
||||
# After exhausting the iterator, `next_elem.has_value()` will evaluate to
|
||||
@ -400,9 +392,10 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
iterator = dataset_ops.make_initializable_iterator(ds)
|
||||
next_elem = iterator_ops.get_next_as_optional(iterator)
|
||||
self.assertIsInstance(next_elem, optional_ops.Optional)
|
||||
self.assertTrue(structure.are_compatible(
|
||||
next_elem.value_structure,
|
||||
structure.type_spec_from_value(tf_value_fn())))
|
||||
self.assertTrue(
|
||||
structure.are_compatible(
|
||||
next_elem.element_spec,
|
||||
structure.type_spec_from_value(tf_value_fn())))
|
||||
# Before initializing the iterator, evaluating the optional fails with
|
||||
# a FailedPreconditionError. This is only relevant in graph mode.
|
||||
elem_has_value_t = next_elem.has_value()
|
||||
@ -430,6 +423,7 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testFunctionBoundaries(self):
|
||||
|
||||
@def_function.function
|
||||
def get_optional():
|
||||
x = constant_op.constant(1.0)
|
||||
|
@ -18,6 +18,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import abc
|
||||
import collections
|
||||
import functools
|
||||
import sys
|
||||
import threading
|
||||
@ -102,7 +103,8 @@ tf_export("data.UNKNOWN_CARDINALITY").export_constant(__name__, "UNKNOWN")
|
||||
|
||||
@tf_export("data.Dataset", v1=[])
|
||||
@six.add_metaclass(abc.ABCMeta)
|
||||
class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor):
|
||||
class DatasetV2(collections.Iterable, tracking_base.Trackable,
|
||||
composite_tensor.CompositeTensor):
|
||||
"""Represents a potentially large set of elements.
|
||||
|
||||
The `tf.data.Dataset` API supports writing descriptive and efficient input
|
||||
@ -399,12 +401,12 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor):
|
||||
return dataset
|
||||
|
||||
def __iter__(self):
|
||||
"""Creates an `Iterator` for enumerating the elements of this dataset.
|
||||
"""Creates an iterator for elements of this dataset.
|
||||
|
||||
The returned iterator implements the Python iterator protocol.
|
||||
The returned iterator implements the Python Iterator protocol.
|
||||
|
||||
Returns:
|
||||
An `Iterator` over the elements of this dataset.
|
||||
An `tf.data.Iterator` for the elements of this dataset.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If not inside of tf.function and not executing eagerly.
|
||||
@ -740,7 +742,7 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor):
|
||||
|
||||
Note: The current implementation of `Dataset.from_generator()` uses
|
||||
`tf.numpy_function` and inherits the same constraints. In particular, it
|
||||
requires the `Dataset`- and `Iterator`-related operations to be placed
|
||||
requires the dataset and iterator related operations to be placed
|
||||
on a device in the same process as the Python program that called
|
||||
`Dataset.from_generator()`. The body of `generator` will not be
|
||||
serialized in a `GraphDef`, and you should not use this method if you
|
||||
@ -2208,7 +2210,7 @@ class DatasetV1(DatasetV2):
|
||||
"code base as there are in general no guarantees about the "
|
||||
"interoperability of TF 1 and TF 2 code.")
|
||||
def make_one_shot_iterator(self):
|
||||
"""Creates an `Iterator` for enumerating the elements of this dataset.
|
||||
"""Creates an iterator for elements of this dataset.
|
||||
|
||||
Note: The returned iterator will be initialized automatically.
|
||||
A "one-shot" iterator does not currently support re-initialization. For
|
||||
@ -2231,7 +2233,7 @@ class DatasetV1(DatasetV2):
|
||||
```
|
||||
|
||||
Returns:
|
||||
An `Iterator` over the elements of this dataset.
|
||||
An `tf.data.Iterator` for elements of this dataset.
|
||||
"""
|
||||
return self._make_one_shot_iterator()
|
||||
|
||||
@ -2301,7 +2303,7 @@ class DatasetV1(DatasetV2):
|
||||
"are in general no guarantees about the interoperability of TF 1 and TF "
|
||||
"2 code.")
|
||||
def make_initializable_iterator(self, shared_name=None):
|
||||
"""Creates an `Iterator` for enumerating the elements of this dataset.
|
||||
"""Creates an iterator for elements of this dataset.
|
||||
|
||||
Note: The returned iterator will be in an uninitialized state,
|
||||
and you must run the `iterator.initializer` operation before using it:
|
||||
@ -2328,7 +2330,7 @@ class DatasetV1(DatasetV2):
|
||||
devices (e.g. when using a remote server).
|
||||
|
||||
Returns:
|
||||
An `Iterator` over the elements of this dataset.
|
||||
A `tf.data.Iterator` for elements of this dataset.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If eager execution is enabled.
|
||||
@ -2676,7 +2678,7 @@ def _ensure_same_dataset_graph(dataset):
|
||||
|
||||
@tf_export(v1=["data.make_one_shot_iterator"])
|
||||
def make_one_shot_iterator(dataset):
|
||||
"""Creates a `tf.compat.v1.data.Iterator` for enumerating dataset elements.
|
||||
"""Creates an iterator for elements of `dataset`.
|
||||
|
||||
Note: The returned iterator will be initialized automatically.
|
||||
A "one-shot" iterator does not support re-initialization.
|
||||
@ -2685,7 +2687,7 @@ def make_one_shot_iterator(dataset):
|
||||
dataset: A `tf.data.Dataset`.
|
||||
|
||||
Returns:
|
||||
A `tf.compat.v1.data.Iterator` over the elements of this dataset.
|
||||
A `tf.data.Iterator` for elements of `dataset`.
|
||||
"""
|
||||
try:
|
||||
# Call the defined `_make_one_shot_iterator()` if there is one, because some
|
||||
@ -2697,7 +2699,7 @@ def make_one_shot_iterator(dataset):
|
||||
|
||||
@tf_export(v1=["data.make_initializable_iterator"])
|
||||
def make_initializable_iterator(dataset, shared_name=None):
|
||||
"""Creates a `tf.compat.v1.data.Iterator` for enumerating the elements of a dataset.
|
||||
"""Creates an iterator for elements of `dataset`.
|
||||
|
||||
Note: The returned iterator will be in an uninitialized state,
|
||||
and you must run the `iterator.initializer` operation before using it:
|
||||
@ -2716,7 +2718,7 @@ def make_initializable_iterator(dataset, shared_name=None):
|
||||
(e.g. when using a remote server).
|
||||
|
||||
Returns:
|
||||
A `tf.compat.v1.data.Iterator` over the elements of `dataset`.
|
||||
A `tf.data.Iterator` for elements of `dataset`.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If eager execution is enabled.
|
||||
@ -2731,10 +2733,10 @@ def make_initializable_iterator(dataset, shared_name=None):
|
||||
|
||||
@tf_export("data.experimental.get_structure")
|
||||
def get_structure(dataset_or_iterator):
|
||||
"""Returns the type specification of an element of a `Dataset` or `Iterator`.
|
||||
"""Returns the type signature for elements of the input dataset / iterator.
|
||||
|
||||
Args:
|
||||
dataset_or_iterator: A `tf.data.Dataset` or `tf.data.Iterator`.
|
||||
dataset_or_iterator: A `tf.data.Dataset` or an `tf.data.Iterator`.
|
||||
|
||||
Returns:
|
||||
A nested structure of `tf.TypeSpec` objects matching the structure of an
|
||||
@ -2742,21 +2744,20 @@ def get_structure(dataset_or_iterator):
|
||||
components.
|
||||
|
||||
Raises:
|
||||
TypeError: If `dataset_or_iterator` is not a `Dataset` or `Iterator` object.
|
||||
TypeError: If input is not a `tf.data.Dataset` or an `tf.data.Iterator`
|
||||
object.
|
||||
"""
|
||||
try:
|
||||
return dataset_or_iterator.element_spec # pylint: disable=protected-access
|
||||
except AttributeError:
|
||||
raise TypeError("`dataset_or_iterator` must be a Dataset or Iterator "
|
||||
"object, but got %s." % type(dataset_or_iterator))
|
||||
raise TypeError("`dataset_or_iterator` must be a `tf.data.Dataset` or "
|
||||
"tf.data.Iterator object, but got %s." %
|
||||
type(dataset_or_iterator))
|
||||
|
||||
|
||||
@tf_export(v1=["data.get_output_classes"])
|
||||
def get_legacy_output_classes(dataset_or_iterator):
|
||||
"""Returns the output classes of a `Dataset` or `Iterator` elements.
|
||||
|
||||
This utility method replaces the deprecated-in-V2
|
||||
`tf.compat.v1.Dataset.output_classes` property.
|
||||
"""Returns the output classes for elements of the input dataset / iterator.
|
||||
|
||||
Args:
|
||||
dataset_or_iterator: A `tf.data.Dataset` or `tf.data.Iterator`.
|
||||
@ -2773,10 +2774,7 @@ def get_legacy_output_classes(dataset_or_iterator):
|
||||
|
||||
@tf_export(v1=["data.get_output_shapes"])
|
||||
def get_legacy_output_shapes(dataset_or_iterator):
|
||||
"""Returns the output shapes of a `Dataset` or `Iterator` elements.
|
||||
|
||||
This utility method replaces the deprecated-in-V2
|
||||
`tf.compat.v1.Dataset.output_shapes` property.
|
||||
"""Returns the output shapes for elements of the input dataset / iterator.
|
||||
|
||||
Args:
|
||||
dataset_or_iterator: A `tf.data.Dataset` or `tf.data.Iterator`.
|
||||
@ -2793,10 +2791,7 @@ def get_legacy_output_shapes(dataset_or_iterator):
|
||||
|
||||
@tf_export(v1=["data.get_output_types"])
|
||||
def get_legacy_output_types(dataset_or_iterator):
|
||||
"""Returns the output shapes of a `Dataset` or `Iterator` elements.
|
||||
|
||||
This utility method replaces the deprecated-in-V2
|
||||
`tf.compat.v1.Dataset.output_types` property.
|
||||
"""Returns the output shapes for elements of the input dataset / iterator.
|
||||
|
||||
Args:
|
||||
dataset_or_iterator: A `tf.data.Dataset` or `tf.data.Iterator`.
|
||||
|
@ -17,9 +17,13 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import abc
|
||||
import collections
|
||||
import threading
|
||||
import warnings
|
||||
|
||||
import six
|
||||
|
||||
from tensorflow.python.data.experimental.ops import distribute_options
|
||||
from tensorflow.python.data.ops import optional_ops
|
||||
from tensorflow.python.data.util import nest
|
||||
@ -489,12 +493,6 @@ class Iterator(trackable.Trackable):
|
||||
|
||||
@property
|
||||
def element_spec(self):
|
||||
"""The type specification of an element of this iterator.
|
||||
|
||||
Returns:
|
||||
A nested structure of `tf.TypeSpec` objects matching the structure of an
|
||||
element of this iterator and specifying the type of individual components.
|
||||
"""
|
||||
return self._element_spec
|
||||
|
||||
def _gather_saveables_for_checkpoint(self):
|
||||
@ -543,7 +541,102 @@ class IteratorResourceDeleter(object):
|
||||
handle=self._handle, deleter=self._deleter)
|
||||
|
||||
|
||||
class OwnedIterator(trackable.Trackable, composite_tensor.CompositeTensor):
|
||||
@tf_export("data.Iterator", v1=[])
|
||||
@six.add_metaclass(abc.ABCMeta)
|
||||
class IteratorBase(collections.Iterator, trackable.Trackable,
|
||||
composite_tensor.CompositeTensor):
|
||||
"""Represents an iterator of a `tf.data.Dataset`.
|
||||
|
||||
`tf.data.Iterator` is the primary mechanism for enumerating elements of a
|
||||
`tf.data.Dataset`. It supports the Python Iterator protocol, which means
|
||||
it can be iterated over using a for-loop:
|
||||
|
||||
>>> dataset = tf.data.Dataset.range(2)
|
||||
>>> for element in dataset:
|
||||
... print(element)
|
||||
tf.Tensor(0, shape=(), dtype=int64)
|
||||
tf.Tensor(1, shape=(), dtype=int64)
|
||||
|
||||
or by fetching individual elements explicitly via `get_next()`:
|
||||
|
||||
>>> dataset = tf.data.Dataset.range(2)
|
||||
>>> iterator = iter(dataset)
|
||||
>>> print(iterator.get_next())
|
||||
tf.Tensor(0, shape=(), dtype=int64)
|
||||
>>> print(iterator.get_next())
|
||||
tf.Tensor(1, shape=(), dtype=int64)
|
||||
|
||||
In addition, non-raising iteration is supported via `get_next_as_optional()`,
|
||||
which returns the next element (if available) wrapped in a
|
||||
`tf.experimental.Optional`.
|
||||
|
||||
>>> dataset = tf.data.Dataset.from_tensors(42)
|
||||
>>> iterator = iter(dataset)
|
||||
>>> optional = iterator.get_next_as_optional()
|
||||
>>> print(optional.has_value())
|
||||
tf.Tensor(True, shape=(), dtype=bool)
|
||||
>>> optional = iterator.get_next_as_optional()
|
||||
>>> print(optional.has_value())
|
||||
tf.Tensor(False, shape=(), dtype=bool)
|
||||
"""
|
||||
|
||||
@abc.abstractproperty
|
||||
def element_spec(self):
|
||||
"""The type specification of an element of this iterator.
|
||||
|
||||
>>> dataset = tf.data.Dataset.from_tensors(42)
|
||||
>>> iterator = iter(dataset)
|
||||
>>> iterator.element_spec
|
||||
tf.TensorSpec(shape=(), dtype=tf.int32, name=None)
|
||||
|
||||
Returns:
|
||||
A nested structure of `tf.TypeSpec` objects matching the structure of an
|
||||
element of this iterator, specifying the type of individual components.
|
||||
"""
|
||||
raise NotImplementedError("Iterator.element_spec")
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_next(self):
|
||||
"""Returns a nested structure of `tf.Tensor`s containing the next element.
|
||||
|
||||
>>> dataset = tf.data.Dataset.from_tensors(42)
|
||||
>>> iterator = iter(dataset)
|
||||
>>> print(iterator.get_next())
|
||||
tf.Tensor(42, shape=(), dtype=int32)
|
||||
|
||||
Returns:
|
||||
A nested structure of `tf.Tensor` objects.
|
||||
|
||||
Raises:
|
||||
`tf.errors.OutOfRangeError`: If the end of the iterator has been reached.
|
||||
"""
|
||||
raise NotImplementedError("Iterator.get_next()")
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_next_as_optional(self):
|
||||
"""Returns a `tf.experimental.Optional` which contains the next element.
|
||||
|
||||
If the iterator has reached the end of the sequence, the returned
|
||||
`tf.experimental.Optional` will have no value.
|
||||
|
||||
>>> dataset = tf.data.Dataset.from_tensors(42)
|
||||
>>> iterator = iter(dataset)
|
||||
>>> optional = iterator.get_next_as_optional()
|
||||
>>> print(optional.has_value())
|
||||
tf.Tensor(True, shape=(), dtype=bool)
|
||||
>>> print(optional.get_value())
|
||||
tf.Tensor(42, shape=(), dtype=int32)
|
||||
>>> optional = iterator.get_next_as_optional()
|
||||
>>> print(optional.has_value())
|
||||
tf.Tensor(False, shape=(), dtype=bool)
|
||||
|
||||
Returns:
|
||||
A `tf.experimental.Optional` object representing the next element.
|
||||
"""
|
||||
raise NotImplementedError("Iterator.get_next_as_optional()")
|
||||
|
||||
|
||||
class OwnedIterator(IteratorBase):
|
||||
"""An iterator producing tf.Tensor objects from a tf.data.Dataset.
|
||||
|
||||
The iterator resource created through `OwnedIterator` is owned by the Python
|
||||
@ -578,7 +671,6 @@ class OwnedIterator(trackable.Trackable, composite_tensor.CompositeTensor):
|
||||
`element_spec` is not provided. Or `dataset` is provided and either
|
||||
`components` and `element_spec` is provided.
|
||||
"""
|
||||
|
||||
error_message = ("Either `dataset` or both `components` and "
|
||||
"`element_spec` need to be provided.")
|
||||
|
||||
@ -644,8 +736,6 @@ class OwnedIterator(trackable.Trackable, composite_tensor.CompositeTensor):
|
||||
return self.next()
|
||||
|
||||
def _next_internal(self):
|
||||
"""Returns a nested structure of `tf.Tensor`s containing the next element.
|
||||
"""
|
||||
if not context.executing_eagerly():
|
||||
with ops.device(self._device):
|
||||
ret = gen_dataset_ops.iterator_get_next(
|
||||
@ -659,7 +749,7 @@ class OwnedIterator(trackable.Trackable, composite_tensor.CompositeTensor):
|
||||
# TODO(b/77291417): Fix
|
||||
with context.execution_mode(context.SYNC):
|
||||
with ops.device(self._device):
|
||||
# TODO(ashankar): Consider removing this ops.device() contextmanager
|
||||
# TODO(ashankar): Consider removing this ops.device() context manager
|
||||
# and instead mimic ops placement in graphs: Operations on resource
|
||||
# handles execute on the same device as where the resource is placed.
|
||||
ret = gen_dataset_ops.iterator_get_next(
|
||||
@ -678,7 +768,6 @@ class OwnedIterator(trackable.Trackable, composite_tensor.CompositeTensor):
|
||||
return IteratorSpec(self.element_spec)
|
||||
|
||||
def next(self):
|
||||
"""Returns a nested structure of `Tensor`s containing the next element."""
|
||||
try:
|
||||
return self._next_internal()
|
||||
except errors.OutOfRangeError:
|
||||
@ -730,29 +819,20 @@ class OwnedIterator(trackable.Trackable, composite_tensor.CompositeTensor):
|
||||
|
||||
@property
|
||||
def element_spec(self):
|
||||
"""The type specification of an element of this iterator.
|
||||
|
||||
Returns:
|
||||
A nested structure of `tf.TypeSpec` objects matching the structure of an
|
||||
element of this iterator and specifying the type of individual components.
|
||||
"""
|
||||
return self._element_spec
|
||||
|
||||
def get_next(self, name=None):
|
||||
"""Returns a nested structure of `tf.Tensor`s containing the next element.
|
||||
|
||||
Args:
|
||||
name: (Optional.) A name for the created operation. Currently unused.
|
||||
|
||||
Returns:
|
||||
A nested structure of `tf.Tensor` objects.
|
||||
|
||||
Raises:
|
||||
`tf.errors.OutOfRangeError`: If the end of the dataset has been reached.
|
||||
"""
|
||||
del name
|
||||
def get_next(self):
|
||||
return self._next_internal()
|
||||
|
||||
def get_next_as_optional(self):
|
||||
# pylint: disable=protected-access
|
||||
return optional_ops._OptionalImpl(
|
||||
gen_dataset_ops.iterator_get_next_as_optional(
|
||||
self._iterator_resource,
|
||||
output_types=structure.get_flat_tensor_types(self.element_spec),
|
||||
output_shapes=structure.get_flat_tensor_shapes(
|
||||
self.element_spec)), self.element_spec)
|
||||
|
||||
def _gather_saveables_for_checkpoint(self):
|
||||
|
||||
def _saveable_factory(name):
|
||||
@ -771,9 +851,27 @@ class OwnedIterator(trackable.Trackable, composite_tensor.CompositeTensor):
|
||||
return {"ITERATOR": _saveable_factory}
|
||||
|
||||
|
||||
# TODO(jsimsa): Export this as "tf.data.IteratorSpec".
|
||||
@tf_export("data.IteratorSpec", v1=[])
|
||||
class IteratorSpec(type_spec.TypeSpec):
|
||||
"""Type specification for `OwnedIterator`."""
|
||||
"""Type specification for `tf.data.Iterator`.
|
||||
|
||||
For instance, `tf.data.IteratorSpec` can be used to define a tf.function that
|
||||
takes `tf.data.Iterator` as an input argument:
|
||||
|
||||
>>> @tf.function(input_signature=[tf.data.IteratorSpec(
|
||||
... tf.TensorSpec(shape=(), dtype=tf.int32, name=None))])
|
||||
... def square(iterator):
|
||||
... x = iterator.get_next()
|
||||
... return x * x
|
||||
>>> dataset = tf.data.Dataset.from_tensors(5)
|
||||
>>> iterator = iter(dataset)
|
||||
>>> print(square(iterator))
|
||||
tf.Tensor(25, shape=(), dtype=int32)
|
||||
|
||||
Attributes:
|
||||
element_spec: A nested structure of `TypeSpec` objects that represents the
|
||||
type specification of the iterator elements.
|
||||
"""
|
||||
|
||||
__slots__ = ["_element_spec"]
|
||||
|
||||
@ -833,19 +931,21 @@ class _IteratorSaveable(BaseSaverBuilder.SaveableObject):
|
||||
return gen_dataset_ops.deserialize_iterator(self.op, restored_tensors[0])
|
||||
|
||||
|
||||
@deprecation.deprecated(
|
||||
None, "Use `tf.data.Iterator.get_next_as_optional()` instead.")
|
||||
@tf_export("data.experimental.get_next_as_optional")
|
||||
def get_next_as_optional(iterator):
|
||||
"""Returns an `Optional` that contains the next value from the iterator.
|
||||
"""Returns a `tf.experimental.Optional` with the next element of the iterator.
|
||||
|
||||
If `iterator` has reached the end of the sequence, the returned `Optional`
|
||||
will have no value.
|
||||
If the iterator has reached the end of the sequence, the returned
|
||||
`tf.experimental.Optional` will have no value.
|
||||
|
||||
Args:
|
||||
iterator: An iterator for an instance of `tf.data.Dataset`.
|
||||
iterator: A `tf.data.Iterator`.
|
||||
|
||||
Returns:
|
||||
An `Optional` object representing the next value from the iterator (if it
|
||||
has one) or no value.
|
||||
A `tf.experimental.Optional` object which either contains the next element
|
||||
of the iterator (if it exists) or no value.
|
||||
"""
|
||||
# pylint: disable=protected-access
|
||||
return optional_ops._OptionalImpl(
|
||||
|
@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""An Optional type for representing potentially missing values."""
|
||||
"""A type for representing values that may or may not exist."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
@ -28,28 +28,51 @@ from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.framework import type_spec
|
||||
from tensorflow.python.ops import gen_dataset_ops
|
||||
from tensorflow.python.util import deprecation
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
|
||||
@tf_export("data.experimental.Optional")
|
||||
@tf_export("experimental.Optional", "data.experimental.Optional")
|
||||
@deprecation.deprecated_endpoints("data.experimental.Optional")
|
||||
@six.add_metaclass(abc.ABCMeta)
|
||||
class Optional(composite_tensor.CompositeTensor):
|
||||
"""Wraps a value that may/may not be present at runtime.
|
||||
"""Represents a value that may or may not be present.
|
||||
|
||||
An `Optional` can represent the result of an operation that may fail as a
|
||||
value, rather than raising an exception and halting execution. For example,
|
||||
`tf.data.experimental.get_next_as_optional` returns an `Optional` that either
|
||||
contains the next value of an iterator if one exists, or a "none" value that
|
||||
indicates the end of the sequence has been reached.
|
||||
A `tf.experimental.Optional` can represent the result of an operation that may
|
||||
fail as a value, rather than raising an exception and halting execution. For
|
||||
example, `tf.data.Iterator.get_next_as_optional()` returns a
|
||||
`tf.experimental.Optional` that either contains the next element of an
|
||||
iterator if one exists, or an "empty" value that indicates the end of the
|
||||
sequence has been reached.
|
||||
|
||||
`Optional` can only be used by values that are convertible to `Tensor` or
|
||||
`CompositeTensor`.
|
||||
`tf.experimental.Optional` can only be used with values that are convertible
|
||||
to `tf.Tensor` or `tf.CompositeTensor`.
|
||||
|
||||
One can create a `tf.experimental.Optional` from a value using the
|
||||
`from_value()` method:
|
||||
|
||||
>>> optional = tf.experimental.Optional.from_value(42)
|
||||
>>> print(optional.has_value())
|
||||
tf.Tensor(True, shape=(), dtype=bool)
|
||||
>>> print(optional.get_value())
|
||||
tf.Tensor(42, shape=(), dtype=int32)
|
||||
|
||||
or without a value using the `empty()` method:
|
||||
|
||||
>>> optional = tf.experimental.Optional.empty(
|
||||
... tf.TensorSpec(shape=(), dtype=tf.int32, name=None))
|
||||
>>> print(optional.has_value())
|
||||
tf.Tensor(False, shape=(), dtype=bool)
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def has_value(self, name=None):
|
||||
"""Returns a tensor that evaluates to `True` if this optional has a value.
|
||||
|
||||
>>> optional = tf.experimental.Optional.from_value(42)
|
||||
>>> print(optional.has_value())
|
||||
tf.Tensor(True, shape=(), dtype=bool)
|
||||
|
||||
Args:
|
||||
name: (Optional.) A name for the created operation.
|
||||
|
||||
@ -62,9 +85,13 @@ class Optional(composite_tensor.CompositeTensor):
|
||||
def get_value(self, name=None):
|
||||
"""Returns the value wrapped by this optional.
|
||||
|
||||
If this optional does not have a value (i.e. `self.has_value()` evaluates
|
||||
to `False`), this operation will raise `tf.errors.InvalidArgumentError`
|
||||
at runtime.
|
||||
If this optional does not have a value (i.e. `self.has_value()` evaluates to
|
||||
`False`), this operation will raise `tf.errors.InvalidArgumentError` at
|
||||
runtime.
|
||||
|
||||
>>> optional = tf.experimental.Optional.from_value(42)
|
||||
>>> print(optional.get_value())
|
||||
tf.Tensor(42, shape=(), dtype=int32)
|
||||
|
||||
Args:
|
||||
name: (Optional.) A name for the created operation.
|
||||
@ -75,62 +102,77 @@ class Optional(composite_tensor.CompositeTensor):
|
||||
raise NotImplementedError("Optional.get_value()")
|
||||
|
||||
@abc.abstractproperty
|
||||
def value_structure(self):
|
||||
"""The structure of the components of this optional.
|
||||
def element_spec(self):
|
||||
"""The type specification of an element of this optional.
|
||||
|
||||
>>> optional = tf.experimental.Optional.from_value(42)
|
||||
>>> print(optional.element_spec)
|
||||
tf.TensorSpec(shape=(), dtype=tf.int32, name=None)
|
||||
|
||||
Returns:
|
||||
A `Structure` object representing the structure of the components of this
|
||||
optional.
|
||||
A nested structure of `tf.TypeSpec` objects matching the structure of an
|
||||
element of this optional, specifying the type of individual components.
|
||||
"""
|
||||
raise NotImplementedError("Optional.value_structure")
|
||||
raise NotImplementedError("Optional.element_spec")
|
||||
|
||||
@staticmethod
|
||||
def from_value(value):
|
||||
"""Returns an `Optional` that wraps the given value.
|
||||
|
||||
Args:
|
||||
value: A value to wrap. The value must be convertible to `Tensor` or
|
||||
`CompositeTensor`.
|
||||
|
||||
Returns:
|
||||
An `Optional` that wraps `value`.
|
||||
"""
|
||||
with ops.name_scope("optional") as scope:
|
||||
with ops.name_scope("value"):
|
||||
value_structure = structure.type_spec_from_value(value)
|
||||
encoded_value = structure.to_tensor_list(value_structure, value)
|
||||
|
||||
return _OptionalImpl(
|
||||
gen_dataset_ops.optional_from_value(encoded_value, name=scope),
|
||||
value_structure)
|
||||
|
||||
@staticmethod
|
||||
def none_from_structure(value_structure):
|
||||
def empty(element_spec):
|
||||
"""Returns an `Optional` that has no value.
|
||||
|
||||
NOTE: This method takes an argument that defines the structure of the value
|
||||
that would be contained in the returned `Optional` if it had a value.
|
||||
|
||||
>>> optional = tf.experimental.Optional.empty(
|
||||
... tf.TensorSpec(shape=(), dtype=tf.int32, name=None))
|
||||
>>> print(optional.has_value())
|
||||
tf.Tensor(False, shape=(), dtype=bool)
|
||||
|
||||
Args:
|
||||
value_structure: A `Structure` object representing the structure of the
|
||||
components of this optional.
|
||||
element_spec: A nested structure of `tf.TypeSpec` objects matching the
|
||||
structure of an element of this optional.
|
||||
|
||||
Returns:
|
||||
An `Optional` that has no value.
|
||||
A `tf.experimental.Optional` with no value.
|
||||
"""
|
||||
return _OptionalImpl(gen_dataset_ops.optional_none(), value_structure)
|
||||
return _OptionalImpl(gen_dataset_ops.optional_none(), element_spec)
|
||||
|
||||
@staticmethod
|
||||
def from_value(value):
|
||||
"""Returns a `tf.experimental.Optional` that wraps the given value.
|
||||
|
||||
>>> optional = tf.experimental.Optional.from_value(42)
|
||||
>>> print(optional.has_value())
|
||||
tf.Tensor(True, shape=(), dtype=bool)
|
||||
>>> print(optional.get_value())
|
||||
tf.Tensor(42, shape=(), dtype=int32)
|
||||
|
||||
Args:
|
||||
value: A value to wrap. The value must be convertible to `tf.Tensor` or
|
||||
`tf.CompositeTensor`.
|
||||
|
||||
Returns:
|
||||
A `tf.experimental.Optional` that wraps `value`.
|
||||
"""
|
||||
with ops.name_scope("optional") as scope:
|
||||
with ops.name_scope("value"):
|
||||
element_spec = structure.type_spec_from_value(value)
|
||||
encoded_value = structure.to_tensor_list(element_spec, value)
|
||||
|
||||
return _OptionalImpl(
|
||||
gen_dataset_ops.optional_from_value(encoded_value, name=scope),
|
||||
element_spec)
|
||||
|
||||
|
||||
class _OptionalImpl(Optional):
|
||||
"""Concrete implementation of `tf.data.experimental.Optional`.
|
||||
"""Concrete implementation of `tf.experimental.Optional`.
|
||||
|
||||
NOTE(mrry): This implementation is kept private, to avoid defining
|
||||
`Optional.__init__()` in the public API.
|
||||
"""
|
||||
|
||||
def __init__(self, variant_tensor, value_structure):
|
||||
def __init__(self, variant_tensor, element_spec):
|
||||
self._variant_tensor = variant_tensor
|
||||
self._value_structure = value_structure
|
||||
self._element_spec = element_spec
|
||||
|
||||
def has_value(self, name=None):
|
||||
return gen_dataset_ops.optional_has_value(self._variant_tensor, name=name)
|
||||
@ -141,18 +183,18 @@ class _OptionalImpl(Optional):
|
||||
with ops.name_scope(name, "OptionalGetValue",
|
||||
[self._variant_tensor]) as scope:
|
||||
return structure.from_tensor_list(
|
||||
self._value_structure,
|
||||
self._element_spec,
|
||||
gen_dataset_ops.optional_get_value(
|
||||
self._variant_tensor,
|
||||
name=scope,
|
||||
output_types=structure.get_flat_tensor_types(
|
||||
self._value_structure),
|
||||
self._element_spec),
|
||||
output_shapes=structure.get_flat_tensor_shapes(
|
||||
self._value_structure)))
|
||||
self._element_spec)))
|
||||
|
||||
@property
|
||||
def value_structure(self):
|
||||
return self._value_structure
|
||||
def element_spec(self):
|
||||
return self._element_spec
|
||||
|
||||
@property
|
||||
def _type_spec(self):
|
||||
@ -162,19 +204,38 @@ class _OptionalImpl(Optional):
|
||||
@tf_export(
|
||||
"OptionalSpec", v1=["OptionalSpec", "data.experimental.OptionalStructure"])
|
||||
class OptionalSpec(type_spec.TypeSpec):
|
||||
"""Represents an optional potentially containing a structured value."""
|
||||
"""Type specification for `tf.experimental.Optional`.
|
||||
|
||||
__slots__ = ["_value_structure"]
|
||||
For instance, `tf.OptionalSpec` can be used to define a tf.function that takes
|
||||
`tf.experimental.Optional` as an input argument:
|
||||
|
||||
def __init__(self, value_structure):
|
||||
self._value_structure = value_structure
|
||||
>>> @tf.function(input_signature=[tf.OptionalSpec(
|
||||
... tf.TensorSpec(shape=(), dtype=tf.int32, name=None))])
|
||||
... def maybe_square(optional):
|
||||
... if optional.has_value():
|
||||
... x = optional.get_value()
|
||||
... return x * x
|
||||
... return -1
|
||||
>>> optional = tf.experimental.Optional.from_value(5)
|
||||
>>> print(maybe_square(optional))
|
||||
tf.Tensor(25, shape=(), dtype=int32)
|
||||
|
||||
Attributes:
|
||||
element_spec: A nested structure of `TypeSpec` objects that represents the
|
||||
type specification of the optional element.
|
||||
"""
|
||||
|
||||
__slots__ = ["_element_spec"]
|
||||
|
||||
def __init__(self, element_spec):
|
||||
self._element_spec = element_spec
|
||||
|
||||
@property
|
||||
def value_type(self):
|
||||
return _OptionalImpl
|
||||
|
||||
def _serialize(self):
|
||||
return (self._value_structure,)
|
||||
return (self._element_spec,)
|
||||
|
||||
@property
|
||||
def _component_specs(self):
|
||||
@ -185,11 +246,11 @@ class OptionalSpec(type_spec.TypeSpec):
|
||||
|
||||
def _from_components(self, flat_value):
|
||||
# pylint: disable=protected-access
|
||||
return _OptionalImpl(flat_value[0], self._value_structure)
|
||||
return _OptionalImpl(flat_value[0], self._element_spec)
|
||||
|
||||
@staticmethod
|
||||
def from_value(value):
|
||||
return OptionalSpec(value.value_structure)
|
||||
return OptionalSpec(value.element_spec)
|
||||
|
||||
def _to_legacy_output_types(self):
|
||||
return self
|
||||
|
@ -1130,16 +1130,16 @@ class _SingleWorkerDatasetIteratorBase(object):
|
||||
real_data = control_flow_ops.cond(
|
||||
data.has_value(),
|
||||
lambda: data.get_value(),
|
||||
lambda: _dummy_tensor_fn(data.value_structure),
|
||||
lambda: _dummy_tensor_fn(data.element_spec),
|
||||
strict=True,
|
||||
)
|
||||
# Some dimensions in `replicas` will become unknown after we
|
||||
# conditionally return the real tensors or the dummy tensors. Recover
|
||||
# the shapes from `data.value_structure`. We only need to do this in
|
||||
# the shapes from `data.element_spec`. We only need to do this in
|
||||
# non eager mode because we always know the runtime shape of the
|
||||
# tensors in eager mode.
|
||||
if not context.executing_eagerly():
|
||||
real_data = _recover_shape_fn(real_data, data.value_structure)
|
||||
real_data = _recover_shape_fn(real_data, data.element_spec)
|
||||
result.append(real_data)
|
||||
# pylint: enable=cell-var-from-loop
|
||||
# pylint: enable=unnecessary-lambda
|
||||
|
@ -9,7 +9,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'value_structure\'], varargs=None, keywords=None, defaults=None"
|
||||
argspec: "args=[\'self\', \'element_spec\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "from_value"
|
||||
|
@ -2,9 +2,7 @@ path: "tensorflow.data.Dataset"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetV1\'>"
|
||||
is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetV2\'>"
|
||||
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
|
||||
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
is_instance: "<class \'collections.abc.Iterable\'>"
|
||||
member {
|
||||
name: "element_spec"
|
||||
mtype: "<type \'property\'>"
|
||||
|
@ -4,9 +4,7 @@ tf_class {
|
||||
is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetV1Adapter\'>"
|
||||
is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetV1\'>"
|
||||
is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetV2\'>"
|
||||
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
|
||||
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
is_instance: "<class \'collections.abc.Iterable\'>"
|
||||
member {
|
||||
name: "element_spec"
|
||||
mtype: "<type \'property\'>"
|
||||
|
@ -4,9 +4,7 @@ tf_class {
|
||||
is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetV1Adapter\'>"
|
||||
is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetV1\'>"
|
||||
is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetV2\'>"
|
||||
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
|
||||
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
is_instance: "<class \'collections.abc.Iterable\'>"
|
||||
member {
|
||||
name: "element_spec"
|
||||
mtype: "<type \'property\'>"
|
||||
|
@ -4,9 +4,7 @@ tf_class {
|
||||
is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetV1Adapter\'>"
|
||||
is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetV1\'>"
|
||||
is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetV2\'>"
|
||||
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
|
||||
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
is_instance: "<class \'collections.abc.Iterable\'>"
|
||||
member {
|
||||
name: "element_spec"
|
||||
mtype: "<type \'property\'>"
|
||||
|
@ -4,9 +4,7 @@ tf_class {
|
||||
is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetV1Adapter\'>"
|
||||
is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetV1\'>"
|
||||
is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetV2\'>"
|
||||
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
|
||||
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
is_instance: "<class \'collections.abc.Iterable\'>"
|
||||
member {
|
||||
name: "element_spec"
|
||||
mtype: "<type \'property\'>"
|
||||
|
@ -9,7 +9,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'value_structure\'], varargs=None, keywords=None, defaults=None"
|
||||
argspec: "args=[\'self\', \'element_spec\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "from_value"
|
||||
|
@ -4,12 +4,16 @@ tf_class {
|
||||
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
name: "value_structure"
|
||||
name: "element_spec"
|
||||
mtype: "<class \'abc.abstractproperty\'>"
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
}
|
||||
member_method {
|
||||
name: "empty"
|
||||
argspec: "args=[\'element_spec\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "from_value"
|
||||
argspec: "args=[\'value\'], varargs=None, keywords=None, defaults=None"
|
||||
@ -22,8 +26,4 @@ tf_class {
|
||||
name: "has_value"
|
||||
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "none_from_structure"
|
||||
argspec: "args=[\'value_structure\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
}
|
||||
|
@ -4,9 +4,7 @@ tf_class {
|
||||
is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetV1Adapter\'>"
|
||||
is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetV1\'>"
|
||||
is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetV2\'>"
|
||||
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
|
||||
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
is_instance: "<class \'collections.abc.Iterable\'>"
|
||||
member {
|
||||
name: "element_spec"
|
||||
mtype: "<type \'property\'>"
|
||||
|
@ -4,9 +4,7 @@ tf_class {
|
||||
is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetV1Adapter\'>"
|
||||
is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetV1\'>"
|
||||
is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetV2\'>"
|
||||
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
|
||||
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
is_instance: "<class \'collections.abc.Iterable\'>"
|
||||
member {
|
||||
name: "element_spec"
|
||||
mtype: "<type \'property\'>"
|
||||
|
@ -0,0 +1,29 @@
|
||||
path: "tensorflow.experimental.Optional"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.data.ops.optional_ops.Optional\'>"
|
||||
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
name: "element_spec"
|
||||
mtype: "<class \'abc.abstractproperty\'>"
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
}
|
||||
member_method {
|
||||
name: "empty"
|
||||
argspec: "args=[\'element_spec\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "from_value"
|
||||
argspec: "args=[\'value\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_value"
|
||||
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "has_value"
|
||||
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
}
|
@ -1,5 +1,9 @@
|
||||
path: "tensorflow.experimental"
|
||||
tf_module {
|
||||
member {
|
||||
name: "Optional"
|
||||
mtype: "<type \'type\'>"
|
||||
}
|
||||
member_method {
|
||||
name: "async_clear_error"
|
||||
argspec: "args=[], varargs=None, keywords=None, defaults=None"
|
||||
|
@ -9,7 +9,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'value_structure\'], varargs=None, keywords=None, defaults=None"
|
||||
argspec: "args=[\'self\', \'element_spec\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "from_value"
|
||||
|
@ -1,9 +1,7 @@
|
||||
path: "tensorflow.data.Dataset"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetV2\'>"
|
||||
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
|
||||
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
is_instance: "<class \'collections.abc.Iterable\'>"
|
||||
member {
|
||||
name: "element_spec"
|
||||
mtype: "<class \'abc.abstractproperty\'>"
|
||||
|
@ -3,9 +3,7 @@ tf_class {
|
||||
is_instance: "<class \'tensorflow.python.data.ops.readers.FixedLengthRecordDatasetV2\'>"
|
||||
is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetSource\'>"
|
||||
is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetV2\'>"
|
||||
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
|
||||
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
is_instance: "<class \'collections.abc.Iterable\'>"
|
||||
member {
|
||||
name: "element_spec"
|
||||
mtype: "<type \'property\'>"
|
||||
|
@ -0,0 +1,26 @@
|
||||
path: "tensorflow.data.IteratorSpec"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.data.ops.iterator_ops.IteratorSpec\'>"
|
||||
is_instance: "<class \'tensorflow.python.framework.type_spec.TypeSpec\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
name: "value_type"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'element_spec\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "from_value"
|
||||
argspec: "args=[\'value\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "is_compatible_with"
|
||||
argspec: "args=[\'self\', \'spec_or_value\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "most_specific_compatible_type"
|
||||
argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
}
|
@ -0,0 +1,20 @@
|
||||
path: "tensorflow.data.Iterator"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.data.ops.iterator_ops.IteratorBase\'>"
|
||||
is_instance: "<class \'collections.abc.Iterator\'>"
|
||||
member {
|
||||
name: "element_spec"
|
||||
mtype: "<class \'abc.abstractproperty\'>"
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
}
|
||||
member_method {
|
||||
name: "get_next"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_next_as_optional"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
}
|
@ -2,9 +2,7 @@ path: "tensorflow.data.TFRecordDataset"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.data.ops.readers.TFRecordDatasetV2\'>"
|
||||
is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetV2\'>"
|
||||
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
|
||||
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
is_instance: "<class \'collections.abc.Iterable\'>"
|
||||
member {
|
||||
name: "element_spec"
|
||||
mtype: "<type \'property\'>"
|
||||
|
@ -3,9 +3,7 @@ tf_class {
|
||||
is_instance: "<class \'tensorflow.python.data.ops.readers.TextLineDatasetV2\'>"
|
||||
is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetSource\'>"
|
||||
is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetV2\'>"
|
||||
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
|
||||
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
is_instance: "<class \'collections.abc.Iterable\'>"
|
||||
member {
|
||||
name: "element_spec"
|
||||
mtype: "<type \'property\'>"
|
||||
|
@ -3,9 +3,7 @@ tf_class {
|
||||
is_instance: "<class \'tensorflow.python.data.experimental.ops.readers.CsvDatasetV2\'>"
|
||||
is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetSource\'>"
|
||||
is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetV2\'>"
|
||||
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
|
||||
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
is_instance: "<class \'collections.abc.Iterable\'>"
|
||||
member {
|
||||
name: "element_spec"
|
||||
mtype: "<type \'property\'>"
|
||||
|
@ -4,12 +4,16 @@ tf_class {
|
||||
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
name: "value_structure"
|
||||
name: "element_spec"
|
||||
mtype: "<class \'abc.abstractproperty\'>"
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
}
|
||||
member_method {
|
||||
name: "empty"
|
||||
argspec: "args=[\'element_spec\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "from_value"
|
||||
argspec: "args=[\'value\'], varargs=None, keywords=None, defaults=None"
|
||||
@ -22,8 +26,4 @@ tf_class {
|
||||
name: "has_value"
|
||||
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "none_from_structure"
|
||||
argspec: "args=[\'value_structure\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
}
|
||||
|
@ -3,9 +3,7 @@ tf_class {
|
||||
is_instance: "<class \'tensorflow.python.data.experimental.ops.random_ops.RandomDatasetV2\'>"
|
||||
is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetSource\'>"
|
||||
is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetV2\'>"
|
||||
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
|
||||
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
is_instance: "<class \'collections.abc.Iterable\'>"
|
||||
member {
|
||||
name: "element_spec"
|
||||
mtype: "<type \'property\'>"
|
||||
|
@ -3,9 +3,7 @@ tf_class {
|
||||
is_instance: "<class \'tensorflow.python.data.experimental.ops.readers.SqlDatasetV2\'>"
|
||||
is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetSource\'>"
|
||||
is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetV2\'>"
|
||||
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
|
||||
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
is_instance: "<class \'collections.abc.Iterable\'>"
|
||||
member {
|
||||
name: "element_spec"
|
||||
mtype: "<type \'property\'>"
|
||||
|
@ -16,6 +16,14 @@ tf_module {
|
||||
name: "INFINITE_CARDINALITY"
|
||||
mtype: "<type \'int\'>"
|
||||
}
|
||||
member {
|
||||
name: "Iterator"
|
||||
mtype: "<type \'type\'>"
|
||||
}
|
||||
member {
|
||||
name: "IteratorSpec"
|
||||
mtype: "<type \'type\'>"
|
||||
}
|
||||
member {
|
||||
name: "Options"
|
||||
mtype: "<type \'type\'>"
|
||||
|
@ -0,0 +1,29 @@
|
||||
path: "tensorflow.experimental.Optional"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.data.ops.optional_ops.Optional\'>"
|
||||
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
name: "element_spec"
|
||||
mtype: "<class \'abc.abstractproperty\'>"
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
}
|
||||
member_method {
|
||||
name: "empty"
|
||||
argspec: "args=[\'element_spec\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "from_value"
|
||||
argspec: "args=[\'value\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_value"
|
||||
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "has_value"
|
||||
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
}
|
@ -1,5 +1,9 @@
|
||||
path: "tensorflow.experimental"
|
||||
tf_module {
|
||||
member {
|
||||
name: "Optional"
|
||||
mtype: "<type \'type\'>"
|
||||
}
|
||||
member {
|
||||
name: "dlpack"
|
||||
mtype: "<type \'module\'>"
|
||||
|
Loading…
Reference in New Issue
Block a user