[tf.data] API changes.

This CL makes the following tf.data API-related changes:
1) `tf.data.Iterator` and `tf.data.IteratorSpec` are exposed in the v2 API
2) `tf.experimental.Optional` is exposed in the API (previously exposed as `tf.data.experimental.Optional`)
3) `tf.experimental.Optional.none_from_structure` and `tf.experimental.Optional.value_structure` is renamed to and `tf.experimental.Optional.empty` and `tf.experimental.Optional.element_spec` respectively
4) `tf.OptionalSpec.value_structure` is renamed to `tf.OptionalSpec.element_spec`
5) reflects these changes in documentation and code
6) adds testable docstring for newly exposed APIs

PiperOrigin-RevId: 316003328
Change-Id: I7b7e79942308b3d2f94b988c31729980fb69d961
This commit is contained in:
Jiri Simsa 2020-06-11 16:41:20 -07:00 committed by TensorFlower Gardener
parent b6d13bb0a8
commit cfe037e3fe
32 changed files with 482 additions and 240 deletions

View File

@ -21,7 +21,7 @@ from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export from tensorflow.python.util.tf_export import tf_export
@deprecation.deprecated(None, "Use `tf.data.Dataset.enumerate()") @deprecation.deprecated(None, "Use `tf.data.Dataset.enumerate()`.")
@tf_export("data.experimental.enumerate_dataset") @tf_export("data.experimental.enumerate_dataset")
def enumerate_dataset(start=0): def enumerate_dataset(start=0):
"""A transformation that enumerates the elements of a dataset. """A transformation that enumerates the elements of a dataset.

View File

@ -112,7 +112,8 @@ def _get_next_as_optional_test_combinations():
def reduce_fn(x, y): def reduce_fn(x, y):
name, value, value_fn, gpu_compatible = y name, value, value_fn, gpu_compatible = y
return x + combinations.combine( return x + combinations.combine(
np_value=value, tf_value_fn=combinations.NamedObject(name, value_fn), np_value=value,
tf_value_fn=combinations.NamedObject(name, value_fn),
gpu_compatible=gpu_compatible) gpu_compatible=gpu_compatible)
return functools.reduce(reduce_fn, cases, []) return functools.reduce(reduce_fn, cases, [])
@ -160,13 +161,13 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
@combinations.generate(test_base.default_test_combinations()) @combinations.generate(test_base.default_test_combinations())
def testFromNone(self): def testFromNone(self):
value_structure = tensor_spec.TensorSpec([], dtypes.float32) value_structure = tensor_spec.TensorSpec([], dtypes.float32)
opt = optional_ops.Optional.none_from_structure(value_structure) opt = optional_ops.Optional.empty(value_structure)
self.assertTrue(opt.value_structure.is_compatible_with(value_structure)) self.assertTrue(opt.element_spec.is_compatible_with(value_structure))
self.assertFalse( self.assertFalse(
opt.value_structure.is_compatible_with( opt.element_spec.is_compatible_with(
tensor_spec.TensorSpec([1], dtypes.float32))) tensor_spec.TensorSpec([1], dtypes.float32)))
self.assertFalse( self.assertFalse(
opt.value_structure.is_compatible_with( opt.element_spec.is_compatible_with(
tensor_spec.TensorSpec([], dtypes.int32))) tensor_spec.TensorSpec([], dtypes.int32)))
self.assertFalse(self.evaluate(opt.has_value())) self.assertFalse(self.evaluate(opt.has_value()))
with self.assertRaises(errors.InvalidArgumentError): with self.assertRaises(errors.InvalidArgumentError):
@ -183,20 +184,17 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
opt1 = optional_ops.Optional.from_value((1.0, 2.0)) opt1 = optional_ops.Optional.from_value((1.0, 2.0))
opt2 = optional_ops.Optional.from_value((3.0, 4.0)) opt2 = optional_ops.Optional.from_value((3.0, 4.0))
add_tensor = math_ops.add_n([opt1._variant_tensor, add_tensor = math_ops.add_n(
opt2._variant_tensor]) [opt1._variant_tensor, opt2._variant_tensor])
add_opt = optional_ops._OptionalImpl(add_tensor, opt1.value_structure) add_opt = optional_ops._OptionalImpl(add_tensor, opt1.element_spec)
self.assertAllEqual(self.evaluate(add_opt.get_value()), (4.0, 6.0)) self.assertAllEqual(self.evaluate(add_opt.get_value()), (4.0, 6.0))
# Without value # Without value
opt_none1 = optional_ops.Optional.none_from_structure( opt_none1 = optional_ops.Optional.empty(opt1.element_spec)
opt1.value_structure) opt_none2 = optional_ops.Optional.empty(opt2.element_spec)
opt_none2 = optional_ops.Optional.none_from_structure( add_tensor = math_ops.add_n(
opt2.value_structure) [opt_none1._variant_tensor, opt_none2._variant_tensor])
add_tensor = math_ops.add_n([opt_none1._variant_tensor, add_opt = optional_ops._OptionalImpl(add_tensor, opt_none1.element_spec)
opt_none2._variant_tensor])
add_opt = optional_ops._OptionalImpl(add_tensor,
opt_none1.value_structure)
self.assertFalse(self.evaluate(add_opt.has_value())) self.assertFalse(self.evaluate(add_opt.has_value()))
@combinations.generate(test_base.default_test_combinations()) @combinations.generate(test_base.default_test_combinations())
@ -211,13 +209,13 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
opt3 = optional_ops.Optional.from_value((5.0, opt1._variant_tensor)) opt3 = optional_ops.Optional.from_value((5.0, opt1._variant_tensor))
opt4 = optional_ops.Optional.from_value((6.0, opt2._variant_tensor)) opt4 = optional_ops.Optional.from_value((6.0, opt2._variant_tensor))
add_tensor = math_ops.add_n([opt3._variant_tensor, add_tensor = math_ops.add_n(
opt4._variant_tensor]) [opt3._variant_tensor, opt4._variant_tensor])
add_opt = optional_ops._OptionalImpl(add_tensor, opt3.value_structure) add_opt = optional_ops._OptionalImpl(add_tensor, opt3.element_spec)
self.assertEqual(self.evaluate(add_opt.get_value()[0]), 11.0) self.assertEqual(self.evaluate(add_opt.get_value()[0]), 11.0)
inner_add_opt = optional_ops._OptionalImpl(add_opt.get_value()[1], inner_add_opt = optional_ops._OptionalImpl(add_opt.get_value()[1],
opt1.value_structure) opt1.element_spec)
self.assertAllEqual(inner_add_opt.get_value(), [4, 6.0]) self.assertAllEqual(inner_add_opt.get_value(), [4, 6.0])
@combinations.generate(test_base.default_test_combinations()) @combinations.generate(test_base.default_test_combinations())
@ -230,17 +228,14 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
# With value # With value
opt = optional_ops.Optional.from_value((1.0, 2.0)) opt = optional_ops.Optional.from_value((1.0, 2.0))
zeros_tensor = array_ops.zeros_like(opt._variant_tensor) zeros_tensor = array_ops.zeros_like(opt._variant_tensor)
zeros_opt = optional_ops._OptionalImpl(zeros_tensor, zeros_opt = optional_ops._OptionalImpl(zeros_tensor, opt.element_spec)
opt.value_structure) self.assertAllEqual(self.evaluate(zeros_opt.get_value()), (0.0, 0.0))
self.assertAllEqual(self.evaluate(zeros_opt.get_value()),
(0.0, 0.0))
# Without value # Without value
opt_none = optional_ops.Optional.none_from_structure( opt_none = optional_ops.Optional.empty(opt.element_spec)
opt.value_structure)
zeros_tensor = array_ops.zeros_like(opt_none._variant_tensor) zeros_tensor = array_ops.zeros_like(opt_none._variant_tensor)
zeros_opt = optional_ops._OptionalImpl(zeros_tensor, zeros_opt = optional_ops._OptionalImpl(zeros_tensor,
opt_none.value_structure) opt_none.element_spec)
self.assertFalse(self.evaluate(zeros_opt.has_value())) self.assertFalse(self.evaluate(zeros_opt.has_value()))
@combinations.generate(test_base.default_test_combinations()) @combinations.generate(test_base.default_test_combinations())
@ -254,10 +249,9 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
opt2 = optional_ops.Optional.from_value(opt1._variant_tensor) opt2 = optional_ops.Optional.from_value(opt1._variant_tensor)
zeros_tensor = array_ops.zeros_like(opt2._variant_tensor) zeros_tensor = array_ops.zeros_like(opt2._variant_tensor)
zeros_opt = optional_ops._OptionalImpl(zeros_tensor, zeros_opt = optional_ops._OptionalImpl(zeros_tensor, opt2.element_spec)
opt2.value_structure)
inner_zeros_opt = optional_ops._OptionalImpl(zeros_opt.get_value(), inner_zeros_opt = optional_ops._OptionalImpl(zeros_opt.get_value(),
opt1.value_structure) opt1.element_spec)
self.assertEqual(self.evaluate(inner_zeros_opt.get_value()), 0.0) self.assertEqual(self.evaluate(inner_zeros_opt.get_value()), 0.0)
@combinations.generate(test_base.default_test_combinations()) @combinations.generate(test_base.default_test_combinations())
@ -269,16 +263,16 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
optional_with_value = optional_ops.Optional.from_value( optional_with_value = optional_ops.Optional.from_value(
(constant_op.constant(37.0), constant_op.constant("Foo"), (constant_op.constant(37.0), constant_op.constant("Foo"),
constant_op.constant(42))) constant_op.constant(42)))
optional_none = optional_ops.Optional.none_from_structure( optional_none = optional_ops.Optional.empty(
tensor_spec.TensorSpec([], dtypes.float32)) tensor_spec.TensorSpec([], dtypes.float32))
with ops.device("/gpu:0"): with ops.device("/gpu:0"):
gpu_optional_with_value = optional_ops._OptionalImpl( gpu_optional_with_value = optional_ops._OptionalImpl(
array_ops.identity(optional_with_value._variant_tensor), array_ops.identity(optional_with_value._variant_tensor),
optional_with_value.value_structure) optional_with_value.element_spec)
gpu_optional_none = optional_ops._OptionalImpl( gpu_optional_none = optional_ops._OptionalImpl(
array_ops.identity(optional_none._variant_tensor), array_ops.identity(optional_none._variant_tensor),
optional_none.value_structure) optional_none.element_spec)
gpu_optional_with_value_has_value = gpu_optional_with_value.has_value() gpu_optional_with_value_has_value = gpu_optional_with_value.has_value()
gpu_optional_with_value_values = gpu_optional_with_value.get_value() gpu_optional_with_value_values = gpu_optional_with_value.get_value()
@ -299,7 +293,7 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
optional_with_value = optional_ops.Optional.from_value( optional_with_value = optional_ops.Optional.from_value(
(constant_op.constant(37.0), constant_op.constant("Foo"), (constant_op.constant(37.0), constant_op.constant("Foo"),
constant_op.constant(42))) constant_op.constant(42)))
optional_none = optional_ops.Optional.none_from_structure( optional_none = optional_ops.Optional.empty(
tensor_spec.TensorSpec([], dtypes.float32)) tensor_spec.TensorSpec([], dtypes.float32))
nested_optional = optional_ops.Optional.from_value( nested_optional = optional_ops.Optional.from_value(
(optional_with_value._variant_tensor, optional_none._variant_tensor, (optional_with_value._variant_tensor, optional_none._variant_tensor,
@ -308,7 +302,7 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
with ops.device("/gpu:0"): with ops.device("/gpu:0"):
gpu_nested_optional = optional_ops._OptionalImpl( gpu_nested_optional = optional_ops._OptionalImpl(
array_ops.identity(nested_optional._variant_tensor), array_ops.identity(nested_optional._variant_tensor),
nested_optional.value_structure) nested_optional.element_spec)
gpu_nested_optional_has_value = gpu_nested_optional.has_value() gpu_nested_optional_has_value = gpu_nested_optional.has_value()
gpu_nested_optional_values = gpu_nested_optional.get_value() gpu_nested_optional_values = gpu_nested_optional.get_value()
@ -316,10 +310,10 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
self.assertTrue(self.evaluate(gpu_nested_optional_has_value)) self.assertTrue(self.evaluate(gpu_nested_optional_has_value))
inner_with_value = optional_ops._OptionalImpl( inner_with_value = optional_ops._OptionalImpl(
gpu_nested_optional_values[0], optional_with_value.value_structure) gpu_nested_optional_values[0], optional_with_value.element_spec)
inner_none = optional_ops._OptionalImpl( inner_none = optional_ops._OptionalImpl(gpu_nested_optional_values[1],
gpu_nested_optional_values[1], optional_none.value_structure) optional_none.element_spec)
self.assertEqual((37.0, b"Foo", 42), self.assertEqual((37.0, b"Foo", 42),
self.evaluate(inner_with_value.get_value())) self.evaluate(inner_with_value.get_value()))
@ -327,21 +321,20 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
self.assertEqual(1.0, self.evaluate(gpu_nested_optional_values[2])) self.assertEqual(1.0, self.evaluate(gpu_nested_optional_values[2]))
@combinations.generate( @combinations.generate(
combinations.times( combinations.times(test_base.default_test_combinations(),
test_base.default_test_combinations(), _optional_spec_test_combinations()))
_optional_spec_test_combinations()))
def testOptionalSpec(self, tf_value_fn, expected_value_structure): def testOptionalSpec(self, tf_value_fn, expected_value_structure):
tf_value = tf_value_fn() tf_value = tf_value_fn()
opt = optional_ops.Optional.from_value(tf_value) opt = optional_ops.Optional.from_value(tf_value)
self.assertTrue( self.assertTrue(
structure.are_compatible(opt.value_structure, expected_value_structure)) structure.are_compatible(opt.element_spec, expected_value_structure))
opt_structure = structure.type_spec_from_value(opt) opt_structure = structure.type_spec_from_value(opt)
self.assertIsInstance(opt_structure, optional_ops.OptionalSpec) self.assertIsInstance(opt_structure, optional_ops.OptionalSpec)
self.assertTrue(structure.are_compatible(opt_structure, opt_structure)) self.assertTrue(structure.are_compatible(opt_structure, opt_structure))
self.assertTrue( self.assertTrue(
structure.are_compatible(opt_structure._value_structure, structure.are_compatible(opt_structure._element_spec,
expected_value_structure)) expected_value_structure))
self.assertEqual([dtypes.variant], self.assertEqual([dtypes.variant],
structure.get_flat_tensor_types(opt_structure)) structure.get_flat_tensor_types(opt_structure))
@ -364,13 +357,11 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
self.evaluate(round_trip_opt.get_value().get_value())) self.evaluate(round_trip_opt.get_value().get_value()))
else: else:
self.assertValuesEqual( self.assertValuesEqual(
self.evaluate(tf_value), self.evaluate(tf_value), self.evaluate(round_trip_opt.get_value()))
self.evaluate(round_trip_opt.get_value()))
@combinations.generate( @combinations.generate(
combinations.times( combinations.times(test_base.default_test_combinations(),
test_base.default_test_combinations(), _get_next_as_optional_test_combinations()))
_get_next_as_optional_test_combinations()))
def testIteratorGetNextAsOptional(self, np_value, tf_value_fn, def testIteratorGetNextAsOptional(self, np_value, tf_value_fn,
gpu_compatible): gpu_compatible):
if not gpu_compatible and test.is_gpu_available(): if not gpu_compatible and test.is_gpu_available():
@ -384,9 +375,10 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
for _ in range(3): for _ in range(3):
next_elem = iterator_ops.get_next_as_optional(iterator) next_elem = iterator_ops.get_next_as_optional(iterator)
self.assertIsInstance(next_elem, optional_ops.Optional) self.assertIsInstance(next_elem, optional_ops.Optional)
self.assertTrue(structure.are_compatible( self.assertTrue(
next_elem.value_structure, structure.are_compatible(
structure.type_spec_from_value(tf_value_fn()))) next_elem.element_spec,
structure.type_spec_from_value(tf_value_fn())))
self.assertTrue(next_elem.has_value()) self.assertTrue(next_elem.has_value())
self.assertValuesEqual(np_value, next_elem.get_value()) self.assertValuesEqual(np_value, next_elem.get_value())
# After exhausting the iterator, `next_elem.has_value()` will evaluate to # After exhausting the iterator, `next_elem.has_value()` will evaluate to
@ -400,9 +392,10 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
iterator = dataset_ops.make_initializable_iterator(ds) iterator = dataset_ops.make_initializable_iterator(ds)
next_elem = iterator_ops.get_next_as_optional(iterator) next_elem = iterator_ops.get_next_as_optional(iterator)
self.assertIsInstance(next_elem, optional_ops.Optional) self.assertIsInstance(next_elem, optional_ops.Optional)
self.assertTrue(structure.are_compatible( self.assertTrue(
next_elem.value_structure, structure.are_compatible(
structure.type_spec_from_value(tf_value_fn()))) next_elem.element_spec,
structure.type_spec_from_value(tf_value_fn())))
# Before initializing the iterator, evaluating the optional fails with # Before initializing the iterator, evaluating the optional fails with
# a FailedPreconditionError. This is only relevant in graph mode. # a FailedPreconditionError. This is only relevant in graph mode.
elem_has_value_t = next_elem.has_value() elem_has_value_t = next_elem.has_value()
@ -430,6 +423,7 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
@combinations.generate(test_base.default_test_combinations()) @combinations.generate(test_base.default_test_combinations())
def testFunctionBoundaries(self): def testFunctionBoundaries(self):
@def_function.function @def_function.function
def get_optional(): def get_optional():
x = constant_op.constant(1.0) x = constant_op.constant(1.0)

View File

@ -18,6 +18,7 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import abc import abc
import collections
import functools import functools
import sys import sys
import threading import threading
@ -102,7 +103,8 @@ tf_export("data.UNKNOWN_CARDINALITY").export_constant(__name__, "UNKNOWN")
@tf_export("data.Dataset", v1=[]) @tf_export("data.Dataset", v1=[])
@six.add_metaclass(abc.ABCMeta) @six.add_metaclass(abc.ABCMeta)
class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor): class DatasetV2(collections.Iterable, tracking_base.Trackable,
composite_tensor.CompositeTensor):
"""Represents a potentially large set of elements. """Represents a potentially large set of elements.
The `tf.data.Dataset` API supports writing descriptive and efficient input The `tf.data.Dataset` API supports writing descriptive and efficient input
@ -399,12 +401,12 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor):
return dataset return dataset
def __iter__(self): def __iter__(self):
"""Creates an `Iterator` for enumerating the elements of this dataset. """Creates an iterator for elements of this dataset.
The returned iterator implements the Python iterator protocol. The returned iterator implements the Python Iterator protocol.
Returns: Returns:
An `Iterator` over the elements of this dataset. An `tf.data.Iterator` for the elements of this dataset.
Raises: Raises:
RuntimeError: If not inside of tf.function and not executing eagerly. RuntimeError: If not inside of tf.function and not executing eagerly.
@ -740,7 +742,7 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor):
Note: The current implementation of `Dataset.from_generator()` uses Note: The current implementation of `Dataset.from_generator()` uses
`tf.numpy_function` and inherits the same constraints. In particular, it `tf.numpy_function` and inherits the same constraints. In particular, it
requires the `Dataset`- and `Iterator`-related operations to be placed requires the dataset and iterator related operations to be placed
on a device in the same process as the Python program that called on a device in the same process as the Python program that called
`Dataset.from_generator()`. The body of `generator` will not be `Dataset.from_generator()`. The body of `generator` will not be
serialized in a `GraphDef`, and you should not use this method if you serialized in a `GraphDef`, and you should not use this method if you
@ -2208,7 +2210,7 @@ class DatasetV1(DatasetV2):
"code base as there are in general no guarantees about the " "code base as there are in general no guarantees about the "
"interoperability of TF 1 and TF 2 code.") "interoperability of TF 1 and TF 2 code.")
def make_one_shot_iterator(self): def make_one_shot_iterator(self):
"""Creates an `Iterator` for enumerating the elements of this dataset. """Creates an iterator for elements of this dataset.
Note: The returned iterator will be initialized automatically. Note: The returned iterator will be initialized automatically.
A "one-shot" iterator does not currently support re-initialization. For A "one-shot" iterator does not currently support re-initialization. For
@ -2231,7 +2233,7 @@ class DatasetV1(DatasetV2):
``` ```
Returns: Returns:
An `Iterator` over the elements of this dataset. An `tf.data.Iterator` for elements of this dataset.
""" """
return self._make_one_shot_iterator() return self._make_one_shot_iterator()
@ -2301,7 +2303,7 @@ class DatasetV1(DatasetV2):
"are in general no guarantees about the interoperability of TF 1 and TF " "are in general no guarantees about the interoperability of TF 1 and TF "
"2 code.") "2 code.")
def make_initializable_iterator(self, shared_name=None): def make_initializable_iterator(self, shared_name=None):
"""Creates an `Iterator` for enumerating the elements of this dataset. """Creates an iterator for elements of this dataset.
Note: The returned iterator will be in an uninitialized state, Note: The returned iterator will be in an uninitialized state,
and you must run the `iterator.initializer` operation before using it: and you must run the `iterator.initializer` operation before using it:
@ -2328,7 +2330,7 @@ class DatasetV1(DatasetV2):
devices (e.g. when using a remote server). devices (e.g. when using a remote server).
Returns: Returns:
An `Iterator` over the elements of this dataset. A `tf.data.Iterator` for elements of this dataset.
Raises: Raises:
RuntimeError: If eager execution is enabled. RuntimeError: If eager execution is enabled.
@ -2676,7 +2678,7 @@ def _ensure_same_dataset_graph(dataset):
@tf_export(v1=["data.make_one_shot_iterator"]) @tf_export(v1=["data.make_one_shot_iterator"])
def make_one_shot_iterator(dataset): def make_one_shot_iterator(dataset):
"""Creates a `tf.compat.v1.data.Iterator` for enumerating dataset elements. """Creates an iterator for elements of `dataset`.
Note: The returned iterator will be initialized automatically. Note: The returned iterator will be initialized automatically.
A "one-shot" iterator does not support re-initialization. A "one-shot" iterator does not support re-initialization.
@ -2685,7 +2687,7 @@ def make_one_shot_iterator(dataset):
dataset: A `tf.data.Dataset`. dataset: A `tf.data.Dataset`.
Returns: Returns:
A `tf.compat.v1.data.Iterator` over the elements of this dataset. A `tf.data.Iterator` for elements of `dataset`.
""" """
try: try:
# Call the defined `_make_one_shot_iterator()` if there is one, because some # Call the defined `_make_one_shot_iterator()` if there is one, because some
@ -2697,7 +2699,7 @@ def make_one_shot_iterator(dataset):
@tf_export(v1=["data.make_initializable_iterator"]) @tf_export(v1=["data.make_initializable_iterator"])
def make_initializable_iterator(dataset, shared_name=None): def make_initializable_iterator(dataset, shared_name=None):
"""Creates a `tf.compat.v1.data.Iterator` for enumerating the elements of a dataset. """Creates an iterator for elements of `dataset`.
Note: The returned iterator will be in an uninitialized state, Note: The returned iterator will be in an uninitialized state,
and you must run the `iterator.initializer` operation before using it: and you must run the `iterator.initializer` operation before using it:
@ -2716,7 +2718,7 @@ def make_initializable_iterator(dataset, shared_name=None):
(e.g. when using a remote server). (e.g. when using a remote server).
Returns: Returns:
A `tf.compat.v1.data.Iterator` over the elements of `dataset`. A `tf.data.Iterator` for elements of `dataset`.
Raises: Raises:
RuntimeError: If eager execution is enabled. RuntimeError: If eager execution is enabled.
@ -2731,10 +2733,10 @@ def make_initializable_iterator(dataset, shared_name=None):
@tf_export("data.experimental.get_structure") @tf_export("data.experimental.get_structure")
def get_structure(dataset_or_iterator): def get_structure(dataset_or_iterator):
"""Returns the type specification of an element of a `Dataset` or `Iterator`. """Returns the type signature for elements of the input dataset / iterator.
Args: Args:
dataset_or_iterator: A `tf.data.Dataset` or `tf.data.Iterator`. dataset_or_iterator: A `tf.data.Dataset` or an `tf.data.Iterator`.
Returns: Returns:
A nested structure of `tf.TypeSpec` objects matching the structure of an A nested structure of `tf.TypeSpec` objects matching the structure of an
@ -2742,21 +2744,20 @@ def get_structure(dataset_or_iterator):
components. components.
Raises: Raises:
TypeError: If `dataset_or_iterator` is not a `Dataset` or `Iterator` object. TypeError: If input is not a `tf.data.Dataset` or an `tf.data.Iterator`
object.
""" """
try: try:
return dataset_or_iterator.element_spec # pylint: disable=protected-access return dataset_or_iterator.element_spec # pylint: disable=protected-access
except AttributeError: except AttributeError:
raise TypeError("`dataset_or_iterator` must be a Dataset or Iterator " raise TypeError("`dataset_or_iterator` must be a `tf.data.Dataset` or "
"object, but got %s." % type(dataset_or_iterator)) "tf.data.Iterator object, but got %s." %
type(dataset_or_iterator))
@tf_export(v1=["data.get_output_classes"]) @tf_export(v1=["data.get_output_classes"])
def get_legacy_output_classes(dataset_or_iterator): def get_legacy_output_classes(dataset_or_iterator):
"""Returns the output classes of a `Dataset` or `Iterator` elements. """Returns the output classes for elements of the input dataset / iterator.
This utility method replaces the deprecated-in-V2
`tf.compat.v1.Dataset.output_classes` property.
Args: Args:
dataset_or_iterator: A `tf.data.Dataset` or `tf.data.Iterator`. dataset_or_iterator: A `tf.data.Dataset` or `tf.data.Iterator`.
@ -2773,10 +2774,7 @@ def get_legacy_output_classes(dataset_or_iterator):
@tf_export(v1=["data.get_output_shapes"]) @tf_export(v1=["data.get_output_shapes"])
def get_legacy_output_shapes(dataset_or_iterator): def get_legacy_output_shapes(dataset_or_iterator):
"""Returns the output shapes of a `Dataset` or `Iterator` elements. """Returns the output shapes for elements of the input dataset / iterator.
This utility method replaces the deprecated-in-V2
`tf.compat.v1.Dataset.output_shapes` property.
Args: Args:
dataset_or_iterator: A `tf.data.Dataset` or `tf.data.Iterator`. dataset_or_iterator: A `tf.data.Dataset` or `tf.data.Iterator`.
@ -2793,10 +2791,7 @@ def get_legacy_output_shapes(dataset_or_iterator):
@tf_export(v1=["data.get_output_types"]) @tf_export(v1=["data.get_output_types"])
def get_legacy_output_types(dataset_or_iterator): def get_legacy_output_types(dataset_or_iterator):
"""Returns the output shapes of a `Dataset` or `Iterator` elements. """Returns the output shapes for elements of the input dataset / iterator.
This utility method replaces the deprecated-in-V2
`tf.compat.v1.Dataset.output_types` property.
Args: Args:
dataset_or_iterator: A `tf.data.Dataset` or `tf.data.Iterator`. dataset_or_iterator: A `tf.data.Dataset` or `tf.data.Iterator`.

View File

@ -17,9 +17,13 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import abc
import collections
import threading import threading
import warnings import warnings
import six
from tensorflow.python.data.experimental.ops import distribute_options from tensorflow.python.data.experimental.ops import distribute_options
from tensorflow.python.data.ops import optional_ops from tensorflow.python.data.ops import optional_ops
from tensorflow.python.data.util import nest from tensorflow.python.data.util import nest
@ -489,12 +493,6 @@ class Iterator(trackable.Trackable):
@property @property
def element_spec(self): def element_spec(self):
"""The type specification of an element of this iterator.
Returns:
A nested structure of `tf.TypeSpec` objects matching the structure of an
element of this iterator and specifying the type of individual components.
"""
return self._element_spec return self._element_spec
def _gather_saveables_for_checkpoint(self): def _gather_saveables_for_checkpoint(self):
@ -543,7 +541,102 @@ class IteratorResourceDeleter(object):
handle=self._handle, deleter=self._deleter) handle=self._handle, deleter=self._deleter)
class OwnedIterator(trackable.Trackable, composite_tensor.CompositeTensor): @tf_export("data.Iterator", v1=[])
@six.add_metaclass(abc.ABCMeta)
class IteratorBase(collections.Iterator, trackable.Trackable,
composite_tensor.CompositeTensor):
"""Represents an iterator of a `tf.data.Dataset`.
`tf.data.Iterator` is the primary mechanism for enumerating elements of a
`tf.data.Dataset`. It supports the Python Iterator protocol, which means
it can be iterated over using a for-loop:
>>> dataset = tf.data.Dataset.range(2)
>>> for element in dataset:
... print(element)
tf.Tensor(0, shape=(), dtype=int64)
tf.Tensor(1, shape=(), dtype=int64)
or by fetching individual elements explicitly via `get_next()`:
>>> dataset = tf.data.Dataset.range(2)
>>> iterator = iter(dataset)
>>> print(iterator.get_next())
tf.Tensor(0, shape=(), dtype=int64)
>>> print(iterator.get_next())
tf.Tensor(1, shape=(), dtype=int64)
In addition, non-raising iteration is supported via `get_next_as_optional()`,
which returns the next element (if available) wrapped in a
`tf.experimental.Optional`.
>>> dataset = tf.data.Dataset.from_tensors(42)
>>> iterator = iter(dataset)
>>> optional = iterator.get_next_as_optional()
>>> print(optional.has_value())
tf.Tensor(True, shape=(), dtype=bool)
>>> optional = iterator.get_next_as_optional()
>>> print(optional.has_value())
tf.Tensor(False, shape=(), dtype=bool)
"""
@abc.abstractproperty
def element_spec(self):
"""The type specification of an element of this iterator.
>>> dataset = tf.data.Dataset.from_tensors(42)
>>> iterator = iter(dataset)
>>> iterator.element_spec
tf.TensorSpec(shape=(), dtype=tf.int32, name=None)
Returns:
A nested structure of `tf.TypeSpec` objects matching the structure of an
element of this iterator, specifying the type of individual components.
"""
raise NotImplementedError("Iterator.element_spec")
@abc.abstractmethod
def get_next(self):
"""Returns a nested structure of `tf.Tensor`s containing the next element.
>>> dataset = tf.data.Dataset.from_tensors(42)
>>> iterator = iter(dataset)
>>> print(iterator.get_next())
tf.Tensor(42, shape=(), dtype=int32)
Returns:
A nested structure of `tf.Tensor` objects.
Raises:
`tf.errors.OutOfRangeError`: If the end of the iterator has been reached.
"""
raise NotImplementedError("Iterator.get_next()")
@abc.abstractmethod
def get_next_as_optional(self):
"""Returns a `tf.experimental.Optional` which contains the next element.
If the iterator has reached the end of the sequence, the returned
`tf.experimental.Optional` will have no value.
>>> dataset = tf.data.Dataset.from_tensors(42)
>>> iterator = iter(dataset)
>>> optional = iterator.get_next_as_optional()
>>> print(optional.has_value())
tf.Tensor(True, shape=(), dtype=bool)
>>> print(optional.get_value())
tf.Tensor(42, shape=(), dtype=int32)
>>> optional = iterator.get_next_as_optional()
>>> print(optional.has_value())
tf.Tensor(False, shape=(), dtype=bool)
Returns:
A `tf.experimental.Optional` object representing the next element.
"""
raise NotImplementedError("Iterator.get_next_as_optional()")
class OwnedIterator(IteratorBase):
"""An iterator producing tf.Tensor objects from a tf.data.Dataset. """An iterator producing tf.Tensor objects from a tf.data.Dataset.
The iterator resource created through `OwnedIterator` is owned by the Python The iterator resource created through `OwnedIterator` is owned by the Python
@ -578,7 +671,6 @@ class OwnedIterator(trackable.Trackable, composite_tensor.CompositeTensor):
`element_spec` is not provided. Or `dataset` is provided and either `element_spec` is not provided. Or `dataset` is provided and either
`components` and `element_spec` is provided. `components` and `element_spec` is provided.
""" """
error_message = ("Either `dataset` or both `components` and " error_message = ("Either `dataset` or both `components` and "
"`element_spec` need to be provided.") "`element_spec` need to be provided.")
@ -644,8 +736,6 @@ class OwnedIterator(trackable.Trackable, composite_tensor.CompositeTensor):
return self.next() return self.next()
def _next_internal(self): def _next_internal(self):
"""Returns a nested structure of `tf.Tensor`s containing the next element.
"""
if not context.executing_eagerly(): if not context.executing_eagerly():
with ops.device(self._device): with ops.device(self._device):
ret = gen_dataset_ops.iterator_get_next( ret = gen_dataset_ops.iterator_get_next(
@ -659,7 +749,7 @@ class OwnedIterator(trackable.Trackable, composite_tensor.CompositeTensor):
# TODO(b/77291417): Fix # TODO(b/77291417): Fix
with context.execution_mode(context.SYNC): with context.execution_mode(context.SYNC):
with ops.device(self._device): with ops.device(self._device):
# TODO(ashankar): Consider removing this ops.device() contextmanager # TODO(ashankar): Consider removing this ops.device() context manager
# and instead mimic ops placement in graphs: Operations on resource # and instead mimic ops placement in graphs: Operations on resource
# handles execute on the same device as where the resource is placed. # handles execute on the same device as where the resource is placed.
ret = gen_dataset_ops.iterator_get_next( ret = gen_dataset_ops.iterator_get_next(
@ -678,7 +768,6 @@ class OwnedIterator(trackable.Trackable, composite_tensor.CompositeTensor):
return IteratorSpec(self.element_spec) return IteratorSpec(self.element_spec)
def next(self): def next(self):
"""Returns a nested structure of `Tensor`s containing the next element."""
try: try:
return self._next_internal() return self._next_internal()
except errors.OutOfRangeError: except errors.OutOfRangeError:
@ -730,29 +819,20 @@ class OwnedIterator(trackable.Trackable, composite_tensor.CompositeTensor):
@property @property
def element_spec(self): def element_spec(self):
"""The type specification of an element of this iterator.
Returns:
A nested structure of `tf.TypeSpec` objects matching the structure of an
element of this iterator and specifying the type of individual components.
"""
return self._element_spec return self._element_spec
def get_next(self, name=None): def get_next(self):
"""Returns a nested structure of `tf.Tensor`s containing the next element.
Args:
name: (Optional.) A name for the created operation. Currently unused.
Returns:
A nested structure of `tf.Tensor` objects.
Raises:
`tf.errors.OutOfRangeError`: If the end of the dataset has been reached.
"""
del name
return self._next_internal() return self._next_internal()
def get_next_as_optional(self):
# pylint: disable=protected-access
return optional_ops._OptionalImpl(
gen_dataset_ops.iterator_get_next_as_optional(
self._iterator_resource,
output_types=structure.get_flat_tensor_types(self.element_spec),
output_shapes=structure.get_flat_tensor_shapes(
self.element_spec)), self.element_spec)
def _gather_saveables_for_checkpoint(self): def _gather_saveables_for_checkpoint(self):
def _saveable_factory(name): def _saveable_factory(name):
@ -771,9 +851,27 @@ class OwnedIterator(trackable.Trackable, composite_tensor.CompositeTensor):
return {"ITERATOR": _saveable_factory} return {"ITERATOR": _saveable_factory}
# TODO(jsimsa): Export this as "tf.data.IteratorSpec". @tf_export("data.IteratorSpec", v1=[])
class IteratorSpec(type_spec.TypeSpec): class IteratorSpec(type_spec.TypeSpec):
"""Type specification for `OwnedIterator`.""" """Type specification for `tf.data.Iterator`.
For instance, `tf.data.IteratorSpec` can be used to define a tf.function that
takes `tf.data.Iterator` as an input argument:
>>> @tf.function(input_signature=[tf.data.IteratorSpec(
... tf.TensorSpec(shape=(), dtype=tf.int32, name=None))])
... def square(iterator):
... x = iterator.get_next()
... return x * x
>>> dataset = tf.data.Dataset.from_tensors(5)
>>> iterator = iter(dataset)
>>> print(square(iterator))
tf.Tensor(25, shape=(), dtype=int32)
Attributes:
element_spec: A nested structure of `TypeSpec` objects that represents the
type specification of the iterator elements.
"""
__slots__ = ["_element_spec"] __slots__ = ["_element_spec"]
@ -833,19 +931,21 @@ class _IteratorSaveable(BaseSaverBuilder.SaveableObject):
return gen_dataset_ops.deserialize_iterator(self.op, restored_tensors[0]) return gen_dataset_ops.deserialize_iterator(self.op, restored_tensors[0])
@deprecation.deprecated(
None, "Use `tf.data.Iterator.get_next_as_optional()` instead.")
@tf_export("data.experimental.get_next_as_optional") @tf_export("data.experimental.get_next_as_optional")
def get_next_as_optional(iterator): def get_next_as_optional(iterator):
"""Returns an `Optional` that contains the next value from the iterator. """Returns a `tf.experimental.Optional` with the next element of the iterator.
If `iterator` has reached the end of the sequence, the returned `Optional` If the iterator has reached the end of the sequence, the returned
will have no value. `tf.experimental.Optional` will have no value.
Args: Args:
iterator: An iterator for an instance of `tf.data.Dataset`. iterator: A `tf.data.Iterator`.
Returns: Returns:
An `Optional` object representing the next value from the iterator (if it A `tf.experimental.Optional` object which either contains the next element
has one) or no value. of the iterator (if it exists) or no value.
""" """
# pylint: disable=protected-access # pylint: disable=protected-access
return optional_ops._OptionalImpl( return optional_ops._OptionalImpl(

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""An Optional type for representing potentially missing values.""" """A type for representing values that may or may not exist."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
@ -28,28 +28,51 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import type_spec from tensorflow.python.framework import type_spec
from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export from tensorflow.python.util.tf_export import tf_export
@tf_export("data.experimental.Optional") @tf_export("experimental.Optional", "data.experimental.Optional")
@deprecation.deprecated_endpoints("data.experimental.Optional")
@six.add_metaclass(abc.ABCMeta) @six.add_metaclass(abc.ABCMeta)
class Optional(composite_tensor.CompositeTensor): class Optional(composite_tensor.CompositeTensor):
"""Wraps a value that may/may not be present at runtime. """Represents a value that may or may not be present.
An `Optional` can represent the result of an operation that may fail as a A `tf.experimental.Optional` can represent the result of an operation that may
value, rather than raising an exception and halting execution. For example, fail as a value, rather than raising an exception and halting execution. For
`tf.data.experimental.get_next_as_optional` returns an `Optional` that either example, `tf.data.Iterator.get_next_as_optional()` returns a
contains the next value of an iterator if one exists, or a "none" value that `tf.experimental.Optional` that either contains the next element of an
indicates the end of the sequence has been reached. iterator if one exists, or an "empty" value that indicates the end of the
sequence has been reached.
`Optional` can only be used by values that are convertible to `Tensor` or `tf.experimental.Optional` can only be used with values that are convertible
`CompositeTensor`. to `tf.Tensor` or `tf.CompositeTensor`.
One can create a `tf.experimental.Optional` from a value using the
`from_value()` method:
>>> optional = tf.experimental.Optional.from_value(42)
>>> print(optional.has_value())
tf.Tensor(True, shape=(), dtype=bool)
>>> print(optional.get_value())
tf.Tensor(42, shape=(), dtype=int32)
or without a value using the `empty()` method:
>>> optional = tf.experimental.Optional.empty(
... tf.TensorSpec(shape=(), dtype=tf.int32, name=None))
>>> print(optional.has_value())
tf.Tensor(False, shape=(), dtype=bool)
""" """
@abc.abstractmethod @abc.abstractmethod
def has_value(self, name=None): def has_value(self, name=None):
"""Returns a tensor that evaluates to `True` if this optional has a value. """Returns a tensor that evaluates to `True` if this optional has a value.
>>> optional = tf.experimental.Optional.from_value(42)
>>> print(optional.has_value())
tf.Tensor(True, shape=(), dtype=bool)
Args: Args:
name: (Optional.) A name for the created operation. name: (Optional.) A name for the created operation.
@ -62,9 +85,13 @@ class Optional(composite_tensor.CompositeTensor):
def get_value(self, name=None): def get_value(self, name=None):
"""Returns the value wrapped by this optional. """Returns the value wrapped by this optional.
If this optional does not have a value (i.e. `self.has_value()` evaluates If this optional does not have a value (i.e. `self.has_value()` evaluates to
to `False`), this operation will raise `tf.errors.InvalidArgumentError` `False`), this operation will raise `tf.errors.InvalidArgumentError` at
at runtime. runtime.
>>> optional = tf.experimental.Optional.from_value(42)
>>> print(optional.get_value())
tf.Tensor(42, shape=(), dtype=int32)
Args: Args:
name: (Optional.) A name for the created operation. name: (Optional.) A name for the created operation.
@ -75,62 +102,77 @@ class Optional(composite_tensor.CompositeTensor):
raise NotImplementedError("Optional.get_value()") raise NotImplementedError("Optional.get_value()")
@abc.abstractproperty @abc.abstractproperty
def value_structure(self): def element_spec(self):
"""The structure of the components of this optional. """The type specification of an element of this optional.
>>> optional = tf.experimental.Optional.from_value(42)
>>> print(optional.element_spec)
tf.TensorSpec(shape=(), dtype=tf.int32, name=None)
Returns: Returns:
A `Structure` object representing the structure of the components of this A nested structure of `tf.TypeSpec` objects matching the structure of an
optional. element of this optional, specifying the type of individual components.
""" """
raise NotImplementedError("Optional.value_structure") raise NotImplementedError("Optional.element_spec")
@staticmethod @staticmethod
def from_value(value): def empty(element_spec):
"""Returns an `Optional` that wraps the given value.
Args:
value: A value to wrap. The value must be convertible to `Tensor` or
`CompositeTensor`.
Returns:
An `Optional` that wraps `value`.
"""
with ops.name_scope("optional") as scope:
with ops.name_scope("value"):
value_structure = structure.type_spec_from_value(value)
encoded_value = structure.to_tensor_list(value_structure, value)
return _OptionalImpl(
gen_dataset_ops.optional_from_value(encoded_value, name=scope),
value_structure)
@staticmethod
def none_from_structure(value_structure):
"""Returns an `Optional` that has no value. """Returns an `Optional` that has no value.
NOTE: This method takes an argument that defines the structure of the value NOTE: This method takes an argument that defines the structure of the value
that would be contained in the returned `Optional` if it had a value. that would be contained in the returned `Optional` if it had a value.
>>> optional = tf.experimental.Optional.empty(
... tf.TensorSpec(shape=(), dtype=tf.int32, name=None))
>>> print(optional.has_value())
tf.Tensor(False, shape=(), dtype=bool)
Args: Args:
value_structure: A `Structure` object representing the structure of the element_spec: A nested structure of `tf.TypeSpec` objects matching the
components of this optional. structure of an element of this optional.
Returns: Returns:
An `Optional` that has no value. A `tf.experimental.Optional` with no value.
""" """
return _OptionalImpl(gen_dataset_ops.optional_none(), value_structure) return _OptionalImpl(gen_dataset_ops.optional_none(), element_spec)
@staticmethod
def from_value(value):
"""Returns a `tf.experimental.Optional` that wraps the given value.
>>> optional = tf.experimental.Optional.from_value(42)
>>> print(optional.has_value())
tf.Tensor(True, shape=(), dtype=bool)
>>> print(optional.get_value())
tf.Tensor(42, shape=(), dtype=int32)
Args:
value: A value to wrap. The value must be convertible to `tf.Tensor` or
`tf.CompositeTensor`.
Returns:
A `tf.experimental.Optional` that wraps `value`.
"""
with ops.name_scope("optional") as scope:
with ops.name_scope("value"):
element_spec = structure.type_spec_from_value(value)
encoded_value = structure.to_tensor_list(element_spec, value)
return _OptionalImpl(
gen_dataset_ops.optional_from_value(encoded_value, name=scope),
element_spec)
class _OptionalImpl(Optional): class _OptionalImpl(Optional):
"""Concrete implementation of `tf.data.experimental.Optional`. """Concrete implementation of `tf.experimental.Optional`.
NOTE(mrry): This implementation is kept private, to avoid defining NOTE(mrry): This implementation is kept private, to avoid defining
`Optional.__init__()` in the public API. `Optional.__init__()` in the public API.
""" """
def __init__(self, variant_tensor, value_structure): def __init__(self, variant_tensor, element_spec):
self._variant_tensor = variant_tensor self._variant_tensor = variant_tensor
self._value_structure = value_structure self._element_spec = element_spec
def has_value(self, name=None): def has_value(self, name=None):
return gen_dataset_ops.optional_has_value(self._variant_tensor, name=name) return gen_dataset_ops.optional_has_value(self._variant_tensor, name=name)
@ -141,18 +183,18 @@ class _OptionalImpl(Optional):
with ops.name_scope(name, "OptionalGetValue", with ops.name_scope(name, "OptionalGetValue",
[self._variant_tensor]) as scope: [self._variant_tensor]) as scope:
return structure.from_tensor_list( return structure.from_tensor_list(
self._value_structure, self._element_spec,
gen_dataset_ops.optional_get_value( gen_dataset_ops.optional_get_value(
self._variant_tensor, self._variant_tensor,
name=scope, name=scope,
output_types=structure.get_flat_tensor_types( output_types=structure.get_flat_tensor_types(
self._value_structure), self._element_spec),
output_shapes=structure.get_flat_tensor_shapes( output_shapes=structure.get_flat_tensor_shapes(
self._value_structure))) self._element_spec)))
@property @property
def value_structure(self): def element_spec(self):
return self._value_structure return self._element_spec
@property @property
def _type_spec(self): def _type_spec(self):
@ -162,19 +204,38 @@ class _OptionalImpl(Optional):
@tf_export( @tf_export(
"OptionalSpec", v1=["OptionalSpec", "data.experimental.OptionalStructure"]) "OptionalSpec", v1=["OptionalSpec", "data.experimental.OptionalStructure"])
class OptionalSpec(type_spec.TypeSpec): class OptionalSpec(type_spec.TypeSpec):
"""Represents an optional potentially containing a structured value.""" """Type specification for `tf.experimental.Optional`.
__slots__ = ["_value_structure"] For instance, `tf.OptionalSpec` can be used to define a tf.function that takes
`tf.experimental.Optional` as an input argument:
def __init__(self, value_structure): >>> @tf.function(input_signature=[tf.OptionalSpec(
self._value_structure = value_structure ... tf.TensorSpec(shape=(), dtype=tf.int32, name=None))])
... def maybe_square(optional):
... if optional.has_value():
... x = optional.get_value()
... return x * x
... return -1
>>> optional = tf.experimental.Optional.from_value(5)
>>> print(maybe_square(optional))
tf.Tensor(25, shape=(), dtype=int32)
Attributes:
element_spec: A nested structure of `TypeSpec` objects that represents the
type specification of the optional element.
"""
__slots__ = ["_element_spec"]
def __init__(self, element_spec):
self._element_spec = element_spec
@property @property
def value_type(self): def value_type(self):
return _OptionalImpl return _OptionalImpl
def _serialize(self): def _serialize(self):
return (self._value_structure,) return (self._element_spec,)
@property @property
def _component_specs(self): def _component_specs(self):
@ -185,11 +246,11 @@ class OptionalSpec(type_spec.TypeSpec):
def _from_components(self, flat_value): def _from_components(self, flat_value):
# pylint: disable=protected-access # pylint: disable=protected-access
return _OptionalImpl(flat_value[0], self._value_structure) return _OptionalImpl(flat_value[0], self._element_spec)
@staticmethod @staticmethod
def from_value(value): def from_value(value):
return OptionalSpec(value.value_structure) return OptionalSpec(value.element_spec)
def _to_legacy_output_types(self): def _to_legacy_output_types(self):
return self return self

View File

@ -1130,16 +1130,16 @@ class _SingleWorkerDatasetIteratorBase(object):
real_data = control_flow_ops.cond( real_data = control_flow_ops.cond(
data.has_value(), data.has_value(),
lambda: data.get_value(), lambda: data.get_value(),
lambda: _dummy_tensor_fn(data.value_structure), lambda: _dummy_tensor_fn(data.element_spec),
strict=True, strict=True,
) )
# Some dimensions in `replicas` will become unknown after we # Some dimensions in `replicas` will become unknown after we
# conditionally return the real tensors or the dummy tensors. Recover # conditionally return the real tensors or the dummy tensors. Recover
# the shapes from `data.value_structure`. We only need to do this in # the shapes from `data.element_spec`. We only need to do this in
# non eager mode because we always know the runtime shape of the # non eager mode because we always know the runtime shape of the
# tensors in eager mode. # tensors in eager mode.
if not context.executing_eagerly(): if not context.executing_eagerly():
real_data = _recover_shape_fn(real_data, data.value_structure) real_data = _recover_shape_fn(real_data, data.element_spec)
result.append(real_data) result.append(real_data)
# pylint: enable=cell-var-from-loop # pylint: enable=cell-var-from-loop
# pylint: enable=unnecessary-lambda # pylint: enable=unnecessary-lambda

View File

@ -9,7 +9,7 @@ tf_class {
} }
member_method { member_method {
name: "__init__" name: "__init__"
argspec: "args=[\'self\', \'value_structure\'], varargs=None, keywords=None, defaults=None" argspec: "args=[\'self\', \'element_spec\'], varargs=None, keywords=None, defaults=None"
} }
member_method { member_method {
name: "from_value" name: "from_value"

View File

@ -2,9 +2,7 @@ path: "tensorflow.data.Dataset"
tf_class { tf_class {
is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetV1\'>" is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetV1\'>"
is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetV2\'>" is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetV2\'>"
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>" is_instance: "<class \'collections.abc.Iterable\'>"
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
is_instance: "<type \'object\'>"
member { member {
name: "element_spec" name: "element_spec"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -4,9 +4,7 @@ tf_class {
is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetV1Adapter\'>" is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetV1Adapter\'>"
is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetV1\'>" is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetV1\'>"
is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetV2\'>" is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetV2\'>"
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>" is_instance: "<class \'collections.abc.Iterable\'>"
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
is_instance: "<type \'object\'>"
member { member {
name: "element_spec" name: "element_spec"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -4,9 +4,7 @@ tf_class {
is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetV1Adapter\'>" is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetV1Adapter\'>"
is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetV1\'>" is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetV1\'>"
is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetV2\'>" is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetV2\'>"
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>" is_instance: "<class \'collections.abc.Iterable\'>"
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
is_instance: "<type \'object\'>"
member { member {
name: "element_spec" name: "element_spec"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -4,9 +4,7 @@ tf_class {
is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetV1Adapter\'>" is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetV1Adapter\'>"
is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetV1\'>" is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetV1\'>"
is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetV2\'>" is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetV2\'>"
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>" is_instance: "<class \'collections.abc.Iterable\'>"
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
is_instance: "<type \'object\'>"
member { member {
name: "element_spec" name: "element_spec"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -4,9 +4,7 @@ tf_class {
is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetV1Adapter\'>" is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetV1Adapter\'>"
is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetV1\'>" is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetV1\'>"
is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetV2\'>" is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetV2\'>"
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>" is_instance: "<class \'collections.abc.Iterable\'>"
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
is_instance: "<type \'object\'>"
member { member {
name: "element_spec" name: "element_spec"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -9,7 +9,7 @@ tf_class {
} }
member_method { member_method {
name: "__init__" name: "__init__"
argspec: "args=[\'self\', \'value_structure\'], varargs=None, keywords=None, defaults=None" argspec: "args=[\'self\', \'element_spec\'], varargs=None, keywords=None, defaults=None"
} }
member_method { member_method {
name: "from_value" name: "from_value"

View File

@ -4,12 +4,16 @@ tf_class {
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>" is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
is_instance: "<type \'object\'>" is_instance: "<type \'object\'>"
member { member {
name: "value_structure" name: "element_spec"
mtype: "<class \'abc.abstractproperty\'>" mtype: "<class \'abc.abstractproperty\'>"
} }
member_method { member_method {
name: "__init__" name: "__init__"
} }
member_method {
name: "empty"
argspec: "args=[\'element_spec\'], varargs=None, keywords=None, defaults=None"
}
member_method { member_method {
name: "from_value" name: "from_value"
argspec: "args=[\'value\'], varargs=None, keywords=None, defaults=None" argspec: "args=[\'value\'], varargs=None, keywords=None, defaults=None"
@ -22,8 +26,4 @@ tf_class {
name: "has_value" name: "has_value"
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
} }
member_method {
name: "none_from_structure"
argspec: "args=[\'value_structure\'], varargs=None, keywords=None, defaults=None"
}
} }

View File

@ -4,9 +4,7 @@ tf_class {
is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetV1Adapter\'>" is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetV1Adapter\'>"
is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetV1\'>" is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetV1\'>"
is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetV2\'>" is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetV2\'>"
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>" is_instance: "<class \'collections.abc.Iterable\'>"
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
is_instance: "<type \'object\'>"
member { member {
name: "element_spec" name: "element_spec"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -4,9 +4,7 @@ tf_class {
is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetV1Adapter\'>" is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetV1Adapter\'>"
is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetV1\'>" is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetV1\'>"
is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetV2\'>" is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetV2\'>"
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>" is_instance: "<class \'collections.abc.Iterable\'>"
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
is_instance: "<type \'object\'>"
member { member {
name: "element_spec" name: "element_spec"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -0,0 +1,29 @@
path: "tensorflow.experimental.Optional"
tf_class {
is_instance: "<class \'tensorflow.python.data.ops.optional_ops.Optional\'>"
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
is_instance: "<type \'object\'>"
member {
name: "element_spec"
mtype: "<class \'abc.abstractproperty\'>"
}
member_method {
name: "__init__"
}
member_method {
name: "empty"
argspec: "args=[\'element_spec\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "from_value"
argspec: "args=[\'value\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_value"
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "has_value"
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
}

View File

@ -1,5 +1,9 @@
path: "tensorflow.experimental" path: "tensorflow.experimental"
tf_module { tf_module {
member {
name: "Optional"
mtype: "<type \'type\'>"
}
member_method { member_method {
name: "async_clear_error" name: "async_clear_error"
argspec: "args=[], varargs=None, keywords=None, defaults=None" argspec: "args=[], varargs=None, keywords=None, defaults=None"

View File

@ -9,7 +9,7 @@ tf_class {
} }
member_method { member_method {
name: "__init__" name: "__init__"
argspec: "args=[\'self\', \'value_structure\'], varargs=None, keywords=None, defaults=None" argspec: "args=[\'self\', \'element_spec\'], varargs=None, keywords=None, defaults=None"
} }
member_method { member_method {
name: "from_value" name: "from_value"

View File

@ -1,9 +1,7 @@
path: "tensorflow.data.Dataset" path: "tensorflow.data.Dataset"
tf_class { tf_class {
is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetV2\'>" is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetV2\'>"
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>" is_instance: "<class \'collections.abc.Iterable\'>"
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
is_instance: "<type \'object\'>"
member { member {
name: "element_spec" name: "element_spec"
mtype: "<class \'abc.abstractproperty\'>" mtype: "<class \'abc.abstractproperty\'>"

View File

@ -3,9 +3,7 @@ tf_class {
is_instance: "<class \'tensorflow.python.data.ops.readers.FixedLengthRecordDatasetV2\'>" is_instance: "<class \'tensorflow.python.data.ops.readers.FixedLengthRecordDatasetV2\'>"
is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetSource\'>" is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetSource\'>"
is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetV2\'>" is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetV2\'>"
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>" is_instance: "<class \'collections.abc.Iterable\'>"
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
is_instance: "<type \'object\'>"
member { member {
name: "element_spec" name: "element_spec"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -0,0 +1,26 @@
path: "tensorflow.data.IteratorSpec"
tf_class {
is_instance: "<class \'tensorflow.python.data.ops.iterator_ops.IteratorSpec\'>"
is_instance: "<class \'tensorflow.python.framework.type_spec.TypeSpec\'>"
is_instance: "<type \'object\'>"
member {
name: "value_type"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'element_spec\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "from_value"
argspec: "args=[\'value\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "is_compatible_with"
argspec: "args=[\'self\', \'spec_or_value\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "most_specific_compatible_type"
argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -0,0 +1,20 @@
path: "tensorflow.data.Iterator"
tf_class {
is_instance: "<class \'tensorflow.python.data.ops.iterator_ops.IteratorBase\'>"
is_instance: "<class \'collections.abc.Iterator\'>"
member {
name: "element_spec"
mtype: "<class \'abc.abstractproperty\'>"
}
member_method {
name: "__init__"
}
member_method {
name: "get_next"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_next_as_optional"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -2,9 +2,7 @@ path: "tensorflow.data.TFRecordDataset"
tf_class { tf_class {
is_instance: "<class \'tensorflow.python.data.ops.readers.TFRecordDatasetV2\'>" is_instance: "<class \'tensorflow.python.data.ops.readers.TFRecordDatasetV2\'>"
is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetV2\'>" is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetV2\'>"
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>" is_instance: "<class \'collections.abc.Iterable\'>"
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
is_instance: "<type \'object\'>"
member { member {
name: "element_spec" name: "element_spec"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -3,9 +3,7 @@ tf_class {
is_instance: "<class \'tensorflow.python.data.ops.readers.TextLineDatasetV2\'>" is_instance: "<class \'tensorflow.python.data.ops.readers.TextLineDatasetV2\'>"
is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetSource\'>" is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetSource\'>"
is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetV2\'>" is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetV2\'>"
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>" is_instance: "<class \'collections.abc.Iterable\'>"
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
is_instance: "<type \'object\'>"
member { member {
name: "element_spec" name: "element_spec"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -3,9 +3,7 @@ tf_class {
is_instance: "<class \'tensorflow.python.data.experimental.ops.readers.CsvDatasetV2\'>" is_instance: "<class \'tensorflow.python.data.experimental.ops.readers.CsvDatasetV2\'>"
is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetSource\'>" is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetSource\'>"
is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetV2\'>" is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetV2\'>"
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>" is_instance: "<class \'collections.abc.Iterable\'>"
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
is_instance: "<type \'object\'>"
member { member {
name: "element_spec" name: "element_spec"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -4,12 +4,16 @@ tf_class {
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>" is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
is_instance: "<type \'object\'>" is_instance: "<type \'object\'>"
member { member {
name: "value_structure" name: "element_spec"
mtype: "<class \'abc.abstractproperty\'>" mtype: "<class \'abc.abstractproperty\'>"
} }
member_method { member_method {
name: "__init__" name: "__init__"
} }
member_method {
name: "empty"
argspec: "args=[\'element_spec\'], varargs=None, keywords=None, defaults=None"
}
member_method { member_method {
name: "from_value" name: "from_value"
argspec: "args=[\'value\'], varargs=None, keywords=None, defaults=None" argspec: "args=[\'value\'], varargs=None, keywords=None, defaults=None"
@ -22,8 +26,4 @@ tf_class {
name: "has_value" name: "has_value"
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
} }
member_method {
name: "none_from_structure"
argspec: "args=[\'value_structure\'], varargs=None, keywords=None, defaults=None"
}
} }

View File

@ -3,9 +3,7 @@ tf_class {
is_instance: "<class \'tensorflow.python.data.experimental.ops.random_ops.RandomDatasetV2\'>" is_instance: "<class \'tensorflow.python.data.experimental.ops.random_ops.RandomDatasetV2\'>"
is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetSource\'>" is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetSource\'>"
is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetV2\'>" is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetV2\'>"
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>" is_instance: "<class \'collections.abc.Iterable\'>"
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
is_instance: "<type \'object\'>"
member { member {
name: "element_spec" name: "element_spec"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -3,9 +3,7 @@ tf_class {
is_instance: "<class \'tensorflow.python.data.experimental.ops.readers.SqlDatasetV2\'>" is_instance: "<class \'tensorflow.python.data.experimental.ops.readers.SqlDatasetV2\'>"
is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetSource\'>" is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetSource\'>"
is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetV2\'>" is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetV2\'>"
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>" is_instance: "<class \'collections.abc.Iterable\'>"
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
is_instance: "<type \'object\'>"
member { member {
name: "element_spec" name: "element_spec"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -16,6 +16,14 @@ tf_module {
name: "INFINITE_CARDINALITY" name: "INFINITE_CARDINALITY"
mtype: "<type \'int\'>" mtype: "<type \'int\'>"
} }
member {
name: "Iterator"
mtype: "<type \'type\'>"
}
member {
name: "IteratorSpec"
mtype: "<type \'type\'>"
}
member { member {
name: "Options" name: "Options"
mtype: "<type \'type\'>" mtype: "<type \'type\'>"

View File

@ -0,0 +1,29 @@
path: "tensorflow.experimental.Optional"
tf_class {
is_instance: "<class \'tensorflow.python.data.ops.optional_ops.Optional\'>"
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
is_instance: "<type \'object\'>"
member {
name: "element_spec"
mtype: "<class \'abc.abstractproperty\'>"
}
member_method {
name: "__init__"
}
member_method {
name: "empty"
argspec: "args=[\'element_spec\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "from_value"
argspec: "args=[\'value\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_value"
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "has_value"
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
}

View File

@ -1,5 +1,9 @@
path: "tensorflow.experimental" path: "tensorflow.experimental"
tf_module { tf_module {
member {
name: "Optional"
mtype: "<type \'type\'>"
}
member { member {
name: "dlpack" name: "dlpack"
mtype: "<type \'module\'>" mtype: "<type \'module\'>"