[tf.data] Exposing dataset / iterator element type specification in the public API as element_spec.

This CL renames the pre-existing `_element_structure` property of tf.data datasets and iterators to `element_spec`, thus exposing it in the public API.

PiperOrigin-RevId: 256201202
This commit is contained in:
Jiri Simsa 2019-07-02 11:10:43 -07:00 committed by TensorFlower Gardener
parent ff67edeac1
commit 9db44da931
44 changed files with 270 additions and 229 deletions

View File

@ -591,7 +591,7 @@ class _BigtableKeyDataset(dataset_ops.DatasetSource):
self._table = table
@property
def _element_structure(self):
def element_spec(self):
return structure.TensorStructure(dtypes.string, [])
@ -652,7 +652,7 @@ class _BigtableLookupDataset(dataset_ops.DatasetSource):
super(_BigtableLookupDataset, self).__init__(variant_tensor)
@property
def _element_structure(self):
def element_spec(self):
return tuple([structure.TensorStructure(dtypes.string, [])] *
self._num_outputs)
@ -681,7 +681,7 @@ class _BigtableScanDataset(dataset_ops.DatasetSource):
super(_BigtableScanDataset, self).__init__(variant_tensor)
@property
def _element_structure(self):
def element_spec(self):
return tuple([structure.TensorStructure(dtypes.string, [])] *
self._num_outputs)
@ -703,6 +703,6 @@ class _BigtableSampleKeyPairsDataset(dataset_ops.DatasetSource):
super(_BigtableSampleKeyPairsDataset, self).__init__(variant_tensor)
@property
def _element_structure(self):
def element_spec(self):
return (structure.TensorStructure(dtypes.string, []),
structure.TensorStructure(dtypes.string, []))

View File

@ -346,13 +346,13 @@ class _RestructuredDataset(dataset_ops.UnaryDataset):
output_classes = nest.pack_sequence_as(
output_types, nest.flatten(input_classes))
self._structure = structure.convert_legacy_structure(
self._element_spec = structure.convert_legacy_structure(
output_types, output_shapes, output_classes)
variant_tensor = self._input_dataset._variant_tensor # pylint: disable=protected-access
super(_RestructuredDataset, self).__init__(dataset, variant_tensor)
@property
def _element_structure(self):
return self._structure
def element_spec(self):
return self._element_spec

View File

@ -395,6 +395,6 @@ class LMDBDataset(dataset_ops.DatasetSource):
super(LMDBDataset, self).__init__(variant_tensor)
@property
def _element_structure(self):
def element_spec(self):
return (structure.TensorStructure(dtypes.string, []),
structure.TensorStructure(dtypes.string, []))

View File

@ -39,7 +39,7 @@ class _SlideDataset(dataset_ops.UnaryDataset):
window_shift, dtype=dtypes.int64, name="window_shift")
input_structure = dataset_ops.get_structure(input_dataset)
self._structure = nest.map_structure(
self._element_spec = nest.map_structure(
lambda component_spec: component_spec._batch(None), input_structure) # pylint: disable=protected-access
variant_tensor = ged_ops.experimental_sliding_window_dataset(
self._input_dataset._variant_tensor, # pylint: disable=protected-access
@ -50,8 +50,8 @@ class _SlideDataset(dataset_ops.UnaryDataset):
super(_SlideDataset, self).__init__(input_dataset, variant_tensor)
@property
def _element_structure(self):
return self._structure
def element_spec(self):
return self._element_spec
@deprecation.deprecated_args(

View File

@ -58,11 +58,10 @@ class SequenceFileDataset(dataset_ops.DatasetSource):
self._filenames = ops.convert_to_tensor(
filenames, dtype=dtypes.string, name="filenames")
variant_tensor = gen_dataset_ops.sequence_file_dataset(
self._filenames,
structure.get_flat_tensor_types(self._element_structure))
self._filenames, self._flat_types)
super(SequenceFileDataset, self).__init__(variant_tensor)
@property
def _element_structure(self):
def element_spec(self):
return (structure.TensorStructure(dtypes.string, []),
structure.TensorStructure(dtypes.string, []))

View File

@ -754,7 +754,7 @@ class IgniteDataset(dataset_ops.DatasetSource):
self.cache_type.to_permutation(),
dtype=dtypes.int32,
name="permutation")
self._structure = structure.convert_legacy_structure(
self._element_spec = structure.convert_legacy_structure(
self.cache_type.to_output_types(), self.cache_type.to_output_shapes(),
self.cache_type.to_output_classes())
@ -766,5 +766,5 @@ class IgniteDataset(dataset_ops.DatasetSource):
self.schema, self.permutation)
@property
def _element_structure(self):
return self._structure
def element_spec(self):
return self._element_spec

View File

@ -69,5 +69,5 @@ class KafkaDataset(dataset_ops.DatasetSource):
self._group, self._eof, self._timeout)
@property
def _element_structure(self):
def element_spec(self):
return structure.TensorStructure(dtypes.string, [])

View File

@ -86,5 +86,5 @@ class KinesisDataset(dataset_ops.DatasetSource):
self._stream, self._shard, self._read_indefinitely, self._interval)
@property
def _element_structure(self):
def element_spec(self):
return structure.TensorStructure(dtypes.string, [])

View File

@ -37,7 +37,7 @@ class WrapDatasetVariantTest(test_base.DatasetTestBase):
unwrapped_variant = gen_dataset_ops.unwrap_dataset_variant(wrapped_variant)
variant_ds = dataset_ops._VariantDataset(unwrapped_variant,
ds._element_structure)
ds.element_spec)
get_next = self.getNext(variant_ds, requires_initialization=True)
for i in range(100):
self.assertEqual(i, self.evaluate(get_next()))
@ -54,7 +54,7 @@ class WrapDatasetVariantTest(test_base.DatasetTestBase):
unwrapped_variant = gen_dataset_ops.unwrap_dataset_variant(
gpu_wrapped_variant)
variant_ds = dataset_ops._VariantDataset(unwrapped_variant,
ds._element_structure)
ds.element_spec)
iterator = dataset_ops.make_initializable_iterator(variant_ds)
get_next = iterator.get_next()

View File

@ -242,7 +242,7 @@ class _DenseToSparseBatchDataset(dataset_ops.UnaryDataset):
self._input_dataset = input_dataset
self._batch_size = batch_size
self._row_shape = row_shape
self._structure = structure.SparseTensorStructure(
self._element_spec = structure.SparseTensorStructure(
dataset_ops.get_legacy_output_types(input_dataset),
tensor_shape.vector(None).concatenate(self._row_shape))
@ -255,8 +255,8 @@ class _DenseToSparseBatchDataset(dataset_ops.UnaryDataset):
variant_tensor)
@property
def _element_structure(self):
return self._structure
def element_spec(self):
return self._element_spec
class _MapAndBatchDataset(dataset_ops.UnaryDataset):
@ -285,12 +285,12 @@ class _MapAndBatchDataset(dataset_ops.UnaryDataset):
# NOTE(mrry): `constant_drop_remainder` may be `None` (unknown statically)
# or `False` (explicitly retaining the remainder).
# pylint: disable=g-long-lambda
self._structure = nest.map_structure(
self._element_spec = nest.map_structure(
lambda component_spec: component_spec._batch(
tensor_util.constant_value(self._batch_size_t)),
self._map_func.output_structure)
else:
self._structure = nest.map_structure(
self._element_spec = nest.map_structure(
lambda component_spec: component_spec._batch(None),
self._map_func.output_structure)
# pylint: enable=protected-access
@ -309,5 +309,5 @@ class _MapAndBatchDataset(dataset_ops.UnaryDataset):
return [self._map_func]
@property
def _element_structure(self):
return self._structure
def element_spec(self):
return self._element_spec

View File

@ -46,7 +46,7 @@ class _AutoShardDataset(dataset_ops.UnaryDataset):
def __init__(self, input_dataset, num_workers, index):
self._input_dataset = input_dataset
self._structure = input_dataset._element_structure # pylint: disable=protected-access
self._element_spec = input_dataset.element_spec
variant_tensor = ged_ops.experimental_auto_shard_dataset(
self._input_dataset._variant_tensor, # pylint: disable=protected-access
num_workers=num_workers,
@ -55,8 +55,8 @@ class _AutoShardDataset(dataset_ops.UnaryDataset):
super(_AutoShardDataset, self).__init__(input_dataset, variant_tensor)
@property
def _element_structure(self):
return self._structure
def element_spec(self):
return self._element_spec
def _AutoShardDatasetV1(input_dataset, num_workers, index):
@ -85,7 +85,7 @@ class _RebatchDataset(dataset_ops.UnaryDataset):
input_classes = dataset_ops.get_legacy_output_classes(self._input_dataset)
output_shapes = nest.map_structure(recalculate_output_shapes, input_shapes)
self._structure = structure.convert_legacy_structure(
self._element_spec = structure.convert_legacy_structure(
input_types, output_shapes, input_classes)
variant_tensor = ged_ops.experimental_rebatch_dataset(
self._input_dataset._variant_tensor, # pylint: disable=protected-access
@ -94,8 +94,8 @@ class _RebatchDataset(dataset_ops.UnaryDataset):
super(_RebatchDataset, self).__init__(input_dataset, variant_tensor)
@property
def _element_structure(self):
return self._structure
def element_spec(self):
return self._element_spec
_AutoShardDatasetV1.__doc__ = _AutoShardDataset.__doc__

View File

@ -65,6 +65,6 @@ def get_single_element(dataset):
# pylint: disable=protected-access
return structure.from_compatible_tensor_list(
dataset._element_structure,
dataset.element_spec,
gen_dataset_ops.dataset_to_single_element(
dataset._variant_tensor, **dataset._flat_structure)) # pylint: disable=protected-access

View File

@ -299,8 +299,7 @@ class _GroupByReducerDataset(dataset_ops.UnaryDataset):
wrapped_func = dataset_ops.StructuredFunctionWrapper(
reduce_func,
self._transformation_name(),
input_structure=(self._state_structure,
input_dataset._element_structure), # pylint: disable=protected-access
input_structure=(self._state_structure, input_dataset.element_spec),
add_to_graph=False)
# Extract and validate class information from the returned values.
@ -355,7 +354,7 @@ class _GroupByReducerDataset(dataset_ops.UnaryDataset):
input_structure=self._state_structure)
@property
def _element_structure(self):
def element_spec(self):
return self._finalize_func.output_structure
def _functions(self):
@ -416,7 +415,7 @@ class _GroupByWindowDataset(dataset_ops.UnaryDataset):
def _make_reduce_func(self, reduce_func, input_dataset):
"""Make wrapping defun for reduce_func."""
nested_dataset = dataset_ops.DatasetStructure(
input_dataset._element_structure) # pylint: disable=protected-access
input_dataset.element_spec)
input_structure = (structure.TensorStructure(dtypes.int64,
[]), nested_dataset)
self._reduce_func = dataset_ops.StructuredFunctionWrapper(
@ -426,12 +425,12 @@ class _GroupByWindowDataset(dataset_ops.UnaryDataset):
self._reduce_func.output_structure, dataset_ops.DatasetStructure):
raise TypeError("`reduce_func` must return a `Dataset` object.")
# pylint: disable=protected-access
self._structure = (
self._reduce_func.output_structure._element_structure)
self._element_spec = (
self._reduce_func.output_structure._element_spec)
@property
def _element_structure(self):
return self._structure
def element_spec(self):
return self._element_spec
def _functions(self):
return [self._key_func, self._reduce_func, self._window_size_func]

View File

@ -119,7 +119,7 @@ class _DirectedInterleaveDataset(dataset_ops.Dataset):
nest.flatten(dataset_ops.get_legacy_output_shapes(data_input)))
])
self._structure = structure.convert_legacy_structure(
self._element_spec = structure.convert_legacy_structure(
first_output_types, output_shapes, first_output_classes)
super(_DirectedInterleaveDataset, self).__init__()
@ -136,8 +136,8 @@ class _DirectedInterleaveDataset(dataset_ops.Dataset):
return [self._selector_input] + self._data_inputs
@property
def _element_structure(self):
return self._structure
def element_spec(self):
return self._element_spec
@tf_export("data.experimental.sample_from_datasets", v1=[])
@ -267,8 +267,8 @@ def choose_from_datasets_v2(datasets, choice_dataset):
TypeError: If the `datasets` or `choice_dataset` arguments have the wrong
type.
"""
if not dataset_ops.get_structure(choice_dataset).is_compatible_with(
structure.TensorStructure(dtypes.int64, [])):
if not structure.are_compatible(choice_dataset.element_spec,
structure.TensorStructure(dtypes.int64, [])):
raise TypeError("`choice_dataset` must be a dataset of scalar "
"`tf.int64` tensors.")
return _DirectedInterleaveDataset(choice_dataset, datasets)

View File

@ -35,5 +35,5 @@ class MatchingFilesDataset(dataset_ops.DatasetSource):
super(MatchingFilesDataset, self).__init__(variant_tensor)
@property
def _element_structure(self):
def element_spec(self):
return structure.TensorStructure(dtypes.string, [])

View File

@ -156,7 +156,7 @@ class _ChooseFastestDataset(dataset_ops.DatasetV2):
A `Dataset` that has the same elements the inputs.
"""
self._datasets = list(datasets)
self._structure = self._datasets[0]._element_structure # pylint: disable=protected-access
self._element_spec = self._datasets[0].element_spec
variant_tensor = (
gen_experimental_dataset_ops.experimental_choose_fastest_dataset(
[dataset._variant_tensor for dataset in self._datasets], # pylint: disable=protected-access
@ -168,8 +168,8 @@ class _ChooseFastestDataset(dataset_ops.DatasetV2):
return self._datasets
@property
def _element_structure(self):
return self._datasets[0]._element_structure # pylint: disable=protected-access
def element_spec(self):
return self._element_spec
class _ChooseFastestBranchDataset(dataset_ops.UnaryDataset):
@ -242,14 +242,13 @@ class _ChooseFastestBranchDataset(dataset_ops.UnaryDataset):
Returns:
A `Dataset` that has the same elements the inputs.
"""
input_structure = dataset_ops.DatasetStructure(
dataset_ops.get_structure(input_dataset))
input_structure = dataset_ops.DatasetStructure(input_dataset.element_spec)
self._funcs = [
dataset_ops.StructuredFunctionWrapper(
f, "ChooseFastestV2", input_structure=input_structure)
for f in functions
]
self._structure = self._funcs[0].output_structure._element_structure # pylint: disable=protected-access
self._element_spec = self._funcs[0].output_structure._element_spec # pylint: disable=protected-access
self._captured_arguments = []
for f in self._funcs:
@ -279,5 +278,5 @@ class _ChooseFastestBranchDataset(dataset_ops.UnaryDataset):
variant_tensor)
@property
def _element_structure(self):
return self._structure
def element_spec(self):
return self._element_spec

View File

@ -32,7 +32,8 @@ class _ParseExampleDataset(dataset_ops.UnaryDataset):
def __init__(self, input_dataset, features, num_parallel_calls):
self._input_dataset = input_dataset
if not input_dataset._element_structure.is_compatible_with( # pylint: disable=protected-access
if not structure.are_compatible(
input_dataset.element_spec,
structure.TensorStructure(dtypes.string, [None])):
raise TypeError("Input dataset should be a dataset of vectors of strings")
self._num_parallel_calls = num_parallel_calls
@ -75,7 +76,7 @@ class _ParseExampleDataset(dataset_ops.UnaryDataset):
[ops.Tensor for _ in range(len(self._dense_defaults))] +
[sparse_tensor.SparseTensor for _ in range(len(self._sparse_keys))
]))
self._structure = structure.convert_legacy_structure(
self._element_spec = structure.convert_legacy_structure(
output_types, output_shapes, output_classes)
variant_tensor = (
@ -91,8 +92,8 @@ class _ParseExampleDataset(dataset_ops.UnaryDataset):
super(_ParseExampleDataset, self).__init__(input_dataset, variant_tensor)
@property
def _element_structure(self):
return self._structure
def element_spec(self):
return self._element_spec
# TODO(b/111553342): add arguments names and example names as well.

View File

@ -146,8 +146,7 @@ class _CopyToDeviceDataset(dataset_ops.UnaryUnchangedStructureDataset):
dataset_ops.get_legacy_output_types(self),
dataset_ops.get_legacy_output_shapes(self),
dataset_ops.get_legacy_output_classes(self))
return structure.to_tensor_list(self._element_structure,
iterator.get_next())
return structure.to_tensor_list(self.element_spec, iterator.get_next())
next_func_concrete = _next_func._get_concrete_function_internal() # pylint: disable=protected-access
@ -251,7 +250,7 @@ class _MapOnGpuDataset(dataset_ops.UnaryDataset):
return [self._map_func]
@property
def _element_structure(self):
def element_spec(self):
return self._map_func.output_structure
def _transformation_name(self):

View File

@ -39,7 +39,7 @@ class RandomDatasetV2(dataset_ops.DatasetSource):
super(RandomDatasetV2, self).__init__(variant_tensor)
@property
def _element_structure(self):
def element_spec(self):
return structure.TensorStructure(dtypes.int64, [])

View File

@ -662,7 +662,7 @@ class CsvDatasetV2(dataset_ops.DatasetSource):
argument_default=[],
argument_dtype=dtypes.int64,
)
self._structure = tuple(
self._element_spec = tuple(
structure.TensorStructure(d.dtype, []) for d in self._record_defaults)
variant_tensor = gen_experimental_dataset_ops.experimental_csv_dataset(
filenames=self._filenames,
@ -678,8 +678,8 @@ class CsvDatasetV2(dataset_ops.DatasetSource):
super(CsvDatasetV2, self).__init__(variant_tensor)
@property
def _element_structure(self):
return self._structure
def element_spec(self):
return self._element_spec
@tf_export(v1=["data.experimental.CsvDataset"])
@ -955,7 +955,7 @@ class SqlDatasetV2(dataset_ops.DatasetSource):
data_source_name, dtype=dtypes.string, name="data_source_name")
self._query = ops.convert_to_tensor(
query, dtype=dtypes.string, name="query")
self._structure = nest.map_structure(
self._element_spec = nest.map_structure(
lambda dtype: structure.TensorStructure(dtype, []), output_types)
variant_tensor = gen_experimental_dataset_ops.experimental_sql_dataset(
self._driver_name, self._data_source_name, self._query,
@ -963,8 +963,8 @@ class SqlDatasetV2(dataset_ops.DatasetSource):
super(SqlDatasetV2, self).__init__(variant_tensor)
@property
def _element_structure(self):
return self._structure
def element_spec(self):
return self._element_spec
@tf_export(v1=["data.experimental.SqlDataset"])

View File

@ -49,7 +49,7 @@ class _ScanDataset(dataset_ops.UnaryDataset):
scan_func,
self._transformation_name(),
input_structure=(self._state_structure,
input_dataset._element_structure), # pylint: disable=protected-access
input_dataset.element_spec),
add_to_graph=False)
if not (
isinstance(wrapped_func.output_types, collections.Sequence) and
@ -91,7 +91,7 @@ class _ScanDataset(dataset_ops.UnaryDataset):
old_state_shapes = nest.map_structure(
lambda component_spec: component_spec._to_legacy_output_shapes(), # pylint: disable=protected-access
self._state_structure)
self._structure = structure.convert_legacy_structure(
self._element_spec = structure.convert_legacy_structure(
output_types, output_shapes, output_classes)
flat_state_shapes = nest.flatten(old_state_shapes)
@ -135,8 +135,8 @@ class _ScanDataset(dataset_ops.UnaryDataset):
return [self._scan_func]
@property
def _element_structure(self):
return self._structure
def element_spec(self):
return self._element_spec
def _transformation_name(self):
return "tf.data.experimental.scan()"

View File

@ -43,21 +43,6 @@ from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging
# For testing deserialization of Datasets represented as functions
class _RevivedDataset(dataset_ops.DatasetV2):
def __init__(self, variant, element_structure):
self._structure = element_structure
super(_RevivedDataset, self).__init__(variant)
def _inputs(self):
return []
@property
def _element_structure(self):
return self._structure
@test_util.run_all_in_graph_and_eager_modes
class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
@ -75,8 +60,8 @@ class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
fn = original_dataset._trace_variant_creation()
variant = fn()
revived_dataset = _RevivedDataset(
variant, original_dataset._element_structure)
revived_dataset = dataset_ops._VariantDataset(
variant, original_dataset.element_spec)
self.assertDatasetProduces(revived_dataset, range(0, 10, 2))
def testAsFunctionWithMapInFlatMap(self):
@ -88,8 +73,8 @@ class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
fn = original_dataset._trace_variant_creation()
variant = fn()
revived_dataset = _RevivedDataset(
variant, original_dataset._element_structure)
revived_dataset = dataset_ops._VariantDataset(
variant, original_dataset.element_spec)
self.assertDatasetProduces(revived_dataset, list(original_dataset))
@staticmethod

View File

@ -315,14 +315,14 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor):
"or when eager execution is enabled.")
@abc.abstractproperty
def _element_structure(self):
"""The structure of an element of this dataset.
def element_spec(self):
"""The type specification of an element of this dataset.
Returns:
A nested structure of `tf.TypeSpec` objects matching the structure of an
element of this dataset and specifying the type of individual components.
"""
raise NotImplementedError("Dataset._element_structure")
raise NotImplementedError("Dataset.element_spec")
def __repr__(self):
output_shapes = nest.map_structure(str, get_legacy_output_shapes(self))
@ -339,7 +339,7 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor):
Returns:
A list `tf.TensorShapes`s for the element tensor representation.
"""
return structure.get_flat_tensor_shapes(self._element_structure)
return structure.get_flat_tensor_shapes(self.element_spec)
@property
def _flat_types(self):
@ -348,7 +348,7 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor):
Returns:
A list `tf.DType`s for the element tensor representation.
"""
return structure.get_flat_tensor_types(self._element_structure)
return structure.get_flat_tensor_types(self.element_spec)
@property
def _flat_structure(self):
@ -371,7 +371,7 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor):
@property
def _type_spec(self):
return DatasetStructure(self._element_structure)
return DatasetStructure(self.element_spec)
@staticmethod
def from_tensors(tensors):
@ -1442,7 +1442,7 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor):
wrapped_func = StructuredFunctionWrapper(
reduce_func,
"reduce()",
input_structure=(state_structure, self._element_structure),
input_structure=(state_structure, self.element_spec),
add_to_graph=False)
# Extract and validate class information from the returned values.
@ -1546,17 +1546,17 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor):
def normalize(arg, *rest):
# pylint: disable=protected-access
if rest:
return structure.to_batched_tensor_list(self._element_structure,
return structure.to_batched_tensor_list(self.element_spec,
(arg,) + rest)
else:
return structure.to_batched_tensor_list(self._element_structure, arg)
return structure.to_batched_tensor_list(self.element_spec, arg)
normalized_dataset = self.map(normalize)
# NOTE(mrry): Our `map()` has lost information about the structure of
# non-tensor components, so re-apply the structure of the original dataset.
restructured_dataset = _RestructuredDataset(normalized_dataset,
self._element_structure)
self.element_spec)
return _UnbatchDataset(restructured_dataset)
def with_options(self, options):
@ -1750,7 +1750,7 @@ class DatasetV1(DatasetV2):
"""
return nest.map_structure(
lambda component_spec: component_spec._to_legacy_output_classes(), # pylint: disable=protected-access
self._element_structure)
self.element_spec)
@property
@deprecation.deprecated(
@ -1764,7 +1764,7 @@ class DatasetV1(DatasetV2):
"""
return nest.map_structure(
lambda component_spec: component_spec._to_legacy_output_shapes(), # pylint: disable=protected-access
self._element_structure)
self.element_spec)
@property
@deprecation.deprecated(
@ -1778,10 +1778,10 @@ class DatasetV1(DatasetV2):
"""
return nest.map_structure(
lambda component_spec: component_spec._to_legacy_output_types(), # pylint: disable=protected-access
self._element_structure)
self.element_spec)
@property
def _element_structure(self):
def element_spec(self):
# TODO(b/110122868): Remove this override once all `Dataset` instances
# implement `element_structure`.
return structure.convert_legacy_structure(
@ -2003,8 +2003,8 @@ class DatasetV1Adapter(DatasetV1):
return self._dataset.options()
@property
def _element_structure(self):
return self._dataset._element_structure # pylint: disable=protected-access
def element_spec(self):
return self._dataset.element_spec # pylint: disable=protected-access
def __iter__(self):
return iter(self._dataset)
@ -2107,7 +2107,7 @@ def get_structure(dataset_or_iterator):
TypeError: If `dataset_or_iterator` is not a `Dataset` or `Iterator` object.
"""
try:
return dataset_or_iterator._element_structure # pylint: disable=protected-access
return dataset_or_iterator.element_spec # pylint: disable=protected-access
except AttributeError:
raise TypeError("`dataset_or_iterator` must be a Dataset or Iterator "
"object, but got %s." % type(dataset_or_iterator))
@ -2306,8 +2306,8 @@ class UnaryUnchangedStructureDataset(UnaryDataset):
input_dataset, variant_tensor)
@property
def _element_structure(self):
return self._input_dataset._element_structure # pylint: disable=protected-access
def element_spec(self):
return self._input_dataset.element_spec
class TensorDataset(DatasetSource):
@ -2325,7 +2325,7 @@ class TensorDataset(DatasetSource):
super(TensorDataset, self).__init__(variant_tensor)
@property
def _element_structure(self):
def element_spec(self):
return self._structure
@ -2352,7 +2352,7 @@ class TensorSliceDataset(DatasetSource):
super(TensorSliceDataset, self).__init__(variant_tensor)
@property
def _element_structure(self):
def element_spec(self):
return self._structure
@ -2381,7 +2381,7 @@ class SparseTensorSliceDataset(DatasetSource):
super(SparseTensorSliceDataset, self).__init__(variant_tensor)
@property
def _element_structure(self):
def element_spec(self):
return self._structure
@ -2396,20 +2396,20 @@ class _VariantDataset(DatasetV2):
return []
@property
def _element_structure(self):
def element_spec(self):
return self._structure
class _NestedVariant(composite_tensor.CompositeTensor):
def __init__(self, variant_tensor, element_structure, dataset_shape):
def __init__(self, variant_tensor, element_spec, dataset_shape):
self._variant_tensor = variant_tensor
self._element_structure = element_structure
self._element_spec = element_spec
self._dataset_shape = dataset_shape
@property
def _type_spec(self):
return DatasetStructure(self._element_structure, self._dataset_shape)
return DatasetStructure(self._element_spec, self._dataset_shape)
@tf_export("data.experimental.from_variant")
@ -2445,10 +2445,10 @@ def to_variant(dataset):
class DatasetStructure(type_spec.BatchableTypeSpec):
"""Type specification for `tf.data.Dataset`."""
__slots__ = ["_element_structure", "_dataset_shape"]
__slots__ = ["_element_spec", "_dataset_shape"]
def __init__(self, element_spec, dataset_shape=None):
self._element_structure = element_spec
self._element_spec = element_spec
if dataset_shape:
self._dataset_shape = dataset_shape
else:
@ -2459,7 +2459,7 @@ class DatasetStructure(type_spec.BatchableTypeSpec):
return _VariantDataset
def _serialize(self):
return (self._element_structure, self._dataset_shape)
return (self._element_spec, self._dataset_shape)
@property
def _component_specs(self):
@ -2471,10 +2471,9 @@ class DatasetStructure(type_spec.BatchableTypeSpec):
def _from_components(self, components):
# pylint: disable=protected-access
if self._dataset_shape.ndims == 0:
return _VariantDataset(components, self._element_structure)
return _VariantDataset(components, self._element_spec)
else:
return _NestedVariant(components, self._element_structure,
self._dataset_shape)
return _NestedVariant(components, self._element_spec, self._dataset_shape)
def _to_tensor_list(self, value):
return [
@ -2484,17 +2483,17 @@ class DatasetStructure(type_spec.BatchableTypeSpec):
@staticmethod
def from_value(value):
return DatasetStructure(value._element_structure) # pylint: disable=protected-access
return DatasetStructure(value.element_spec) # pylint: disable=protected-access
def _batch(self, batch_size):
return DatasetStructure(
self._element_structure,
self._element_spec,
tensor_shape.TensorShape([batch_size]).concatenate(self._dataset_shape))
def _unbatch(self):
if self._dataset_shape.ndims == 0:
raise ValueError("Unbatching a dataset is only supported for rank >= 1")
return DatasetStructure(self._element_structure, self._dataset_shape[1:])
return DatasetStructure(self._element_spec, self._dataset_shape[1:])
def _to_batched_tensor_list(self, value):
if self._dataset_shape.ndims == 0:
@ -2571,7 +2570,7 @@ class StructuredFunctionWrapper(object):
raise ValueError("Either `dataset`, `input_structure` or all of "
"`input_classes`, `input_shapes`, and `input_types` "
"must be specified.")
self._input_structure = dataset._element_structure
self._input_structure = dataset.element_spec
else:
if not (dataset is None and input_classes is None and input_shapes is None
and input_types is None):
@ -2767,7 +2766,7 @@ class _GeneratorDataset(DatasetSource):
super(_GeneratorDataset, self).__init__(variant_tensor)
@property
def _element_structure(self):
def element_spec(self):
return self._next_func.output_structure
def _transformation_name(self):
@ -2792,7 +2791,7 @@ class ZipDataset(DatasetV2):
self._datasets = datasets
self._structure = nest.pack_sequence_as(
self._datasets,
[ds._element_structure for ds in nest.flatten(self._datasets)]) # pylint: disable=protected-access
[ds.element_spec for ds in nest.flatten(self._datasets)])
variant_tensor = gen_dataset_ops.zip_dataset(
[ds._variant_tensor for ds in nest.flatten(self._datasets)],
**self._flat_structure)
@ -2802,7 +2801,7 @@ class ZipDataset(DatasetV2):
return nest.flatten(self._datasets)
@property
def _element_structure(self):
def element_spec(self):
return self._structure
@ -2850,7 +2849,7 @@ class ConcatenateDataset(DatasetV2):
return self._input_datasets
@property
def _element_structure(self):
def element_spec(self):
return self._structure
@ -2907,7 +2906,7 @@ class RangeDataset(DatasetSource):
return ops.convert_to_tensor(int64_value, dtype=dtypes.int64, name=name)
@property
def _element_structure(self):
def element_spec(self):
return self._structure
@ -3037,11 +3036,11 @@ class BatchDataset(UnaryDataset):
self._structure = nest.map_structure(
lambda component_spec: component_spec._batch(
tensor_util.constant_value(self._batch_size)),
input_dataset._element_structure)
input_dataset.element_spec)
else:
self._structure = nest.map_structure(
lambda component_spec: component_spec._batch(None),
input_dataset._element_structure)
input_dataset.element_spec)
variant_tensor = gen_dataset_ops.batch_dataset_v2(
input_dataset._variant_tensor,
batch_size=self._batch_size,
@ -3050,7 +3049,7 @@ class BatchDataset(UnaryDataset):
super(BatchDataset, self).__init__(input_dataset, variant_tensor)
@property
def _element_structure(self):
def element_spec(self):
return self._structure
@ -3268,7 +3267,7 @@ class PaddedBatchDataset(UnaryDataset):
super(PaddedBatchDataset, self).__init__(input_dataset, variant_tensor)
@property
def _element_structure(self):
def element_spec(self):
return self._structure
@ -3308,7 +3307,7 @@ class MapDataset(UnaryDataset):
return [self._map_func]
@property
def _element_structure(self):
def element_spec(self):
return self._map_func.output_structure
def _transformation_name(self):
@ -3350,7 +3349,7 @@ class ParallelMapDataset(UnaryDataset):
return [self._map_func]
@property
def _element_structure(self):
def element_spec(self):
return self._map_func.output_structure
def _transformation_name(self):
@ -3369,7 +3368,7 @@ class FlatMapDataset(UnaryDataset):
raise TypeError(
"`map_func` must return a `Dataset` object. Got {}".format(
type(self._map_func.output_structure)))
self._structure = self._map_func.output_structure._element_structure # pylint: disable=protected-access
self._structure = self._map_func.output_structure._element_spec # pylint: disable=protected-access
variant_tensor = gen_dataset_ops.flat_map_dataset(
input_dataset._variant_tensor, # pylint: disable=protected-access
self._map_func.function.captured_inputs,
@ -3381,7 +3380,7 @@ class FlatMapDataset(UnaryDataset):
return [self._map_func]
@property
def _element_structure(self):
def element_spec(self):
return self._structure
def _transformation_name(self):
@ -3400,7 +3399,7 @@ class InterleaveDataset(UnaryDataset):
raise TypeError(
"`map_func` must return a `Dataset` object. Got {}".format(
type(self._map_func.output_structure)))
self._structure = self._map_func.output_structure._element_structure # pylint: disable=protected-access
self._structure = self._map_func.output_structure._element_spec # pylint: disable=protected-access
self._cycle_length = ops.convert_to_tensor(
cycle_length, dtype=dtypes.int64, name="cycle_length")
self._block_length = ops.convert_to_tensor(
@ -3419,7 +3418,7 @@ class InterleaveDataset(UnaryDataset):
return [self._map_func]
@property
def _element_structure(self):
def element_spec(self):
return self._structure
def _transformation_name(self):
@ -3439,7 +3438,7 @@ class ParallelInterleaveDataset(UnaryDataset):
raise TypeError(
"`map_func` must return a `Dataset` object. Got {}".format(
type(self._map_func.output_structure)))
self._structure = self._map_func.output_structure._element_structure # pylint: disable=protected-access
self._structure = self._map_func.output_structure._element_spec # pylint: disable=protected-access
self._cycle_length = ops.convert_to_tensor(
cycle_length, dtype=dtypes.int64, name="cycle_length")
self._block_length = ops.convert_to_tensor(
@ -3461,7 +3460,7 @@ class ParallelInterleaveDataset(UnaryDataset):
return [self._map_func]
@property
def _element_structure(self):
def element_spec(self):
return self._structure
def _transformation_name(self):
@ -3561,7 +3560,7 @@ class WindowDataset(UnaryDataset):
super(WindowDataset, self).__init__(input_dataset, variant_tensor)
@property
def _element_structure(self):
def element_spec(self):
return self._structure
@ -3684,7 +3683,7 @@ class _RestructuredDataset(UnaryDataset):
super(_RestructuredDataset, self).__init__(dataset, variant_tensor)
@property
def _element_structure(self):
def element_spec(self):
return self._structure
@ -3713,5 +3712,5 @@ class _UnbatchDataset(UnaryDataset):
super(_UnbatchDataset, self).__init__(input_dataset, variant_tensor)
@property
def _element_structure(self):
def element_spec(self):
return self._structure

View File

@ -104,10 +104,12 @@ class Iterator(trackable.Trackable):
raise ValueError("If `structure` is not specified, all of "
"`output_types`, `output_shapes`, and `output_classes`"
" must be specified.")
self._structure = structure.convert_legacy_structure(
self._element_spec = structure.convert_legacy_structure(
output_types, output_shapes, output_classes)
self._flat_tensor_shapes = structure.get_flat_tensor_shapes(self._structure)
self._flat_tensor_types = structure.get_flat_tensor_types(self._structure)
self._flat_tensor_shapes = structure.get_flat_tensor_shapes(
self._element_spec)
self._flat_tensor_types = structure.get_flat_tensor_types(
self._element_spec)
self._string_handle = gen_dataset_ops.iterator_to_string_handle(
self._iterator_resource)
@ -347,13 +349,13 @@ class Iterator(trackable.Trackable):
# pylint: disable=protected-access
dataset_output_types = nest.map_structure(
lambda component_spec: component_spec._to_legacy_output_types(),
dataset._element_structure)
dataset.element_spec)
dataset_output_shapes = nest.map_structure(
lambda component_spec: component_spec._to_legacy_output_shapes(),
dataset._element_structure)
dataset.element_spec)
dataset_output_classes = nest.map_structure(
lambda component_spec: component_spec._to_legacy_output_classes(),
dataset._element_structure)
dataset.element_spec)
# pylint: enable=protected-access
nest.assert_same_structure(self.output_types, dataset_output_types)
@ -436,7 +438,7 @@ class Iterator(trackable.Trackable):
output_types=self._flat_tensor_types,
output_shapes=self._flat_tensor_shapes,
name=name)
return structure.from_tensor_list(self._structure, flat_ret)
return structure.from_tensor_list(self._element_spec, flat_ret)
def string_handle(self, name=None):
"""Returns a string-valued `tf.Tensor` that represents this iterator.
@ -467,7 +469,7 @@ class Iterator(trackable.Trackable):
"""
return nest.map_structure(
lambda component_spec: component_spec._to_legacy_output_classes(), # pylint: disable=protected-access
self._structure)
self._element_spec)
@property
@deprecation.deprecated(
@ -481,7 +483,7 @@ class Iterator(trackable.Trackable):
"""
return nest.map_structure(
lambda component_spec: component_spec._to_legacy_output_shapes(), # pylint: disable=protected-access
self._structure)
self._element_spec)
@property
@deprecation.deprecated(
@ -495,17 +497,17 @@ class Iterator(trackable.Trackable):
"""
return nest.map_structure(
lambda component_spec: component_spec._to_legacy_output_types(), # pylint: disable=protected-access
self._structure)
self._element_spec)
@property
def _element_structure(self):
"""The structure of an element of this iterator.
def element_spec(self):
"""The type specification of an element of this iterator.
Returns:
A `Structure` object representing the structure of the components of this
optional.
A nested structure of `tf.TypeSpec` objects matching the structure of an
element of this iterator and specifying the type of individual components.
"""
return self._structure
return self._element_spec
def _gather_saveables_for_checkpoint(self):
@ -556,7 +558,7 @@ class IteratorResourceDeleter(object):
class IteratorV2(trackable.Trackable, composite_tensor.CompositeTensor):
"""An iterator producing tf.Tensor objects from a tf.data.Dataset."""
def __init__(self, dataset=None, components=None, element_structure=None):
def __init__(self, dataset=None, components=None, element_spec=None):
"""Creates a new iterator from the given dataset.
If `dataset` is not specified, the iterator will be created from the given
@ -567,29 +569,29 @@ class IteratorV2(trackable.Trackable, composite_tensor.CompositeTensor):
Args:
dataset: A `tf.data.Dataset` object.
components: Tensor components to construct the iterator from.
element_structure: A nested structure of `TypeSpec` objects that
represents the type specification elements of the iterator.
element_spec: A nested structure of `TypeSpec` objects that
represents the type specification of elements of the iterator.
Raises:
ValueError: If `dataset` is not provided and either `components` or
`element_structure` is not provided. Or `dataset` is provided and either
`components` and `element_structure` is provided.
`element_spec` is not provided. Or `dataset` is provided and either
`components` and `element_spec` is provided.
"""
error_message = "Either `dataset` or both `components` and "
"`element_structure` need to be provided."
"`element_spec` need to be provided."
self._device = context.context().device_name
if dataset is None:
if (components is None or element_structure is None):
if (components is None or element_spec is None):
raise ValueError(error_message)
# pylint: disable=protected-access
self._structure = element_structure
self._element_spec = element_spec
self._flat_output_types = structure.get_flat_tensor_types(
self._structure)
self._element_spec)
self._flat_output_shapes = structure.get_flat_tensor_shapes(
self._structure)
self._element_spec)
self._iterator_resource, self._deleter = components
# Delete the resource when this object is deleted
self._resource_deleter = IteratorResourceDeleter(
@ -597,7 +599,7 @@ class IteratorV2(trackable.Trackable, composite_tensor.CompositeTensor):
device=self._device,
deleter=self._deleter)
else:
if (components is not None or element_structure is not None):
if (components is not None or element_spec is not None):
raise ValueError(error_message)
if (_device_stack_is_empty() or
context.context().device_spec.device_type != "CPU"):
@ -610,11 +612,11 @@ class IteratorV2(trackable.Trackable, composite_tensor.CompositeTensor):
# pylint: disable=protected-access
dataset = dataset._apply_options()
ds_variant = dataset._variant_tensor
self._structure = dataset._element_structure
self._element_spec = dataset.element_spec
self._flat_output_types = structure.get_flat_tensor_types(
self._structure)
self._element_spec)
self._flat_output_shapes = structure.get_flat_tensor_shapes(
self._structure)
self._element_spec)
with ops.colocate_with(ds_variant):
self._iterator_resource, self._deleter = (
gen_dataset_ops.anonymous_iterator_v2(
@ -642,7 +644,7 @@ class IteratorV2(trackable.Trackable, composite_tensor.CompositeTensor):
self._iterator_resource,
output_types=self._flat_output_types,
output_shapes=self._flat_output_shapes)
return structure.from_compatible_tensor_list(self._structure, ret)
return structure.from_compatible_tensor_list(self._element_spec, ret)
# This runs in sync mode as iterators use an error status to communicate
# that there is no more data to iterate over.
@ -664,13 +666,13 @@ class IteratorV2(trackable.Trackable, composite_tensor.CompositeTensor):
try:
# Fast path for the case `self._structure` is not a nested structure.
return self._structure._from_compatible_tensor_list(ret) # pylint: disable=protected-access
return self._element_spec._from_compatible_tensor_list(ret) # pylint: disable=protected-access
except AttributeError:
return structure.from_compatible_tensor_list(self._structure, ret)
return structure.from_compatible_tensor_list(self._element_spec, ret)
@property
def _type_spec(self):
return IteratorSpec(self._element_structure)
return IteratorSpec(self.element_spec)
def next(self):
"""Returns a nested structure of `Tensor`s containing the next element."""
@ -693,7 +695,7 @@ class IteratorV2(trackable.Trackable, composite_tensor.CompositeTensor):
"""
return nest.map_structure(
lambda component_spec: component_spec._to_legacy_output_classes(), # pylint: disable=protected-access
self._structure)
self._element_spec)
@property
@deprecation.deprecated(
@ -707,7 +709,7 @@ class IteratorV2(trackable.Trackable, composite_tensor.CompositeTensor):
"""
return nest.map_structure(
lambda component_spec: component_spec._to_legacy_output_shapes(), # pylint: disable=protected-access
self._structure)
self._element_spec)
@property
@deprecation.deprecated(
@ -721,17 +723,17 @@ class IteratorV2(trackable.Trackable, composite_tensor.CompositeTensor):
"""
return nest.map_structure(
lambda component_spec: component_spec._to_legacy_output_types(), # pylint: disable=protected-access
self._structure)
self._element_spec)
@property
def _element_structure(self):
"""The structure of an element of this iterator.
def element_spec(self):
"""The type specification of an element of this iterator.
Returns:
A `Structure` object representing the structure of the components of this
optional.
A nested structure of `tf.TypeSpec` objects matching the structure of an
element of this iterator and specifying the type of individual components.
"""
return self._structure
return self._element_spec
def get_next(self, name=None):
"""Returns a nested structure of `tf.Tensor`s containing the next element.
@ -786,11 +788,11 @@ class IteratorSpec(type_spec.TypeSpec):
return IteratorV2(
dataset=None,
components=components,
element_structure=self._element_spec)
element_spec=self._element_spec)
@staticmethod
def from_value(value):
return IteratorSpec(value._element_structure) # pylint: disable=protected-access
return IteratorSpec(value.element_spec) # pylint: disable=protected-access
# TODO(b/71645805): Expose trackable stateful objects from dataset
@ -828,7 +830,6 @@ def get_next_as_optional(iterator):
return optional_ops._OptionalImpl(
gen_dataset_ops.iterator_get_next_as_optional(
iterator._iterator_resource,
output_types=structure.get_flat_tensor_types(
iterator._element_structure),
output_types=structure.get_flat_tensor_types(iterator.element_spec),
output_shapes=structure.get_flat_tensor_shapes(
iterator._element_structure)), iterator._element_structure)
iterator.element_spec)), iterator.element_spec)

View File

@ -36,8 +36,8 @@ class _PerDeviceGenerator(dataset_ops.DatasetV2):
"""A `dummy` generator dataset."""
def __init__(self, shard_num, multi_device_iterator_resource, incarnation_id,
source_device, element_structure):
self._structure = element_structure
source_device, element_spec):
self._element_spec = element_spec
multi_device_iterator_string_handle = (
gen_dataset_ops.multi_device_iterator_to_string_handle(
@ -71,14 +71,15 @@ class _PerDeviceGenerator(dataset_ops.DatasetV2):
multi_device_iterator = (
gen_dataset_ops.multi_device_iterator_from_string_handle(
string_handle=string_handle,
output_types=structure.get_flat_tensor_types(self._structure),
output_shapes=structure.get_flat_tensor_shapes(self._structure)))
output_types=structure.get_flat_tensor_types(self._element_spec),
output_shapes=structure.get_flat_tensor_shapes(
self._element_spec)))
return gen_dataset_ops.multi_device_iterator_get_next_from_shard(
multi_device_iterator=multi_device_iterator,
shard_num=shard_num,
incarnation_id=incarnation_id,
output_types=structure.get_flat_tensor_types(self._structure),
output_shapes=structure.get_flat_tensor_shapes(self._structure))
output_types=structure.get_flat_tensor_types(self._element_spec),
output_shapes=structure.get_flat_tensor_shapes(self._element_spec))
next_func_concrete = _next_func._get_concrete_function_internal() # pylint: disable=protected-access
@ -91,7 +92,7 @@ class _PerDeviceGenerator(dataset_ops.DatasetV2):
return functional_ops.remote_call(
target=source_device,
args=[string_handle] + next_func_concrete.captured_inputs,
Tout=structure.get_flat_tensor_types(self._structure),
Tout=structure.get_flat_tensor_types(self._element_spec),
f=next_func_concrete)
self._next_func = _remote_next_func._get_concrete_function_internal() # pylint: disable=protected-access
@ -141,8 +142,8 @@ class _PerDeviceGenerator(dataset_ops.DatasetV2):
return []
@property
def _element_structure(self):
return self._structure
def element_spec(self):
return self._element_spec
class _ReincarnatedPerDeviceGenerator(dataset_ops.DatasetV2):
@ -154,7 +155,7 @@ class _ReincarnatedPerDeviceGenerator(dataset_ops.DatasetV2):
def __init__(self, per_device_dataset, incarnation_id):
# pylint: disable=protected-access
self._structure = per_device_dataset._element_structure
self._element_spec = per_device_dataset.element_spec
self._init_func = per_device_dataset._init_func
self._init_captured_args = self._init_func.captured_inputs
@ -183,8 +184,8 @@ class _ReincarnatedPerDeviceGenerator(dataset_ops.DatasetV2):
return []
@property
def _element_structure(self):
return self._structure
def element_spec(self):
return self._element_spec
class MultiDeviceIterator(object):
@ -253,7 +254,7 @@ class MultiDeviceIterator(object):
ds = _PerDeviceGenerator(i, self._multi_device_iterator_resource,
self._incarnation_id,
self._source_device_tensor,
self._dataset._element_structure) # pylint: disable=protected-access
self._dataset.element_spec)
self._prototype_device_datasets.append(ds)
# TODO(rohanj): Explore the possibility of the MultiDeviceIterator to
@ -339,5 +340,5 @@ class MultiDeviceIterator(object):
ds_variant, self._device_iterators[i]._iterator_resource)
@property
def _element_structure(self):
return dataset_ops.get_structure(self._dataset)
def element_spec(self):
return self._dataset.element_spec

View File

@ -115,7 +115,7 @@ class _TextLineDataset(dataset_ops.DatasetSource):
super(_TextLineDataset, self).__init__(variant_tensor)
@property
def _element_structure(self):
def element_spec(self):
return structure.TensorStructure(dtypes.string, [])
@ -157,7 +157,7 @@ class TextLineDatasetV2(dataset_ops.DatasetSource):
super(TextLineDatasetV2, self).__init__(variant_tensor)
@property
def _element_structure(self):
def element_spec(self):
return structure.TensorStructure(dtypes.string, [])
@ -209,7 +209,7 @@ class _TFRecordDataset(dataset_ops.DatasetSource):
super(_TFRecordDataset, self).__init__(variant_tensor)
@property
def _element_structure(self):
def element_spec(self):
return structure.TensorStructure(dtypes.string, [])
@ -225,7 +225,7 @@ class ParallelInterleaveDataset(dataset_ops.UnaryDataset):
if not isinstance(self._map_func.output_structure,
dataset_ops.DatasetStructure):
raise TypeError("`map_func` must return a `Dataset` object.")
self._structure = self._map_func.output_structure._element_structure # pylint: disable=protected-access
self._element_spec = self._map_func.output_structure._element_spec # pylint: disable=protected-access
self._cycle_length = ops.convert_to_tensor(
cycle_length, dtype=dtypes.int64, name="cycle_length")
self._block_length = ops.convert_to_tensor(
@ -257,8 +257,8 @@ class ParallelInterleaveDataset(dataset_ops.UnaryDataset):
return [self._map_func]
@property
def _element_structure(self):
return self._structure
def element_spec(self):
return self._element_spec
def _transformation_name(self):
return "tf.data.experimental.parallel_interleave()"
@ -321,7 +321,7 @@ class TFRecordDatasetV2(dataset_ops.DatasetV2):
return self._impl._inputs() # pylint: disable=protected-access
@property
def _element_structure(self):
def element_spec(self):
return structure.TensorStructure(dtypes.string, [])
@ -408,7 +408,7 @@ class _FixedLengthRecordDataset(dataset_ops.DatasetSource):
super(_FixedLengthRecordDataset, self).__init__(variant_tensor)
@property
def _element_structure(self):
def element_spec(self):
return structure.TensorStructure(dtypes.string, [])
@ -466,7 +466,7 @@ class FixedLengthRecordDatasetV2(dataset_ops.DatasetSource):
super(FixedLengthRecordDatasetV2, self).__init__(variant_tensor)
@property
def _element_structure(self):
def element_spec(self):
return structure.TensorStructure(dtypes.string, [])

View File

@ -480,7 +480,7 @@ class DistributedDataset(_IterableInput):
self._input_workers = input_workers
# TODO(anjalisridhar): Identify if we need to set this property on the
# iterator.
self._element_structure = dataset._element_structure # pylint: disable=protected-access
self._element_spec = dataset.element_spec
self._strategy = strategy
def __iter__(self):
@ -490,7 +490,7 @@ class DistributedDataset(_IterableInput):
self._input_workers)
iterator = DistributedIterator(self._input_workers, worker_iterators,
self._strategy)
iterator._element_structure = self._element_structure # pylint: disable=protected-access
iterator.element_spec = self._element_spec
return iterator
raise RuntimeError("__iter__() is only supported inside of tf.function "
"or when eager execution is enabled.")
@ -537,7 +537,7 @@ class DistributedDatasetV1(DistributedDataset):
self._input_workers)
iterator = DistributedIteratorV1(self._input_workers, worker_iterators,
self._strategy)
iterator._element_structure = self._element_structure # pylint: disable=protected-access
iterator.element_spec = self._element_spec
return iterator
@ -670,9 +670,9 @@ class DatasetIterator(DistributedIteratorV1):
dist_dataset._cloned_datasets, input_workers) # pylint: disable=protected-access
super(DatasetIterator, self).__init__(
input_workers,
worker_iterators, # pylint: disable=protected-access
worker_iterators,
strategy)
self._element_structure = dist_dataset._element_structure # pylint: disable=protected-access
self._element_spec = dist_dataset._element_spec # pylint: disable=protected-access
def _dummy_tensor_fn(value_structure):

View File

@ -56,8 +56,7 @@ def _clone_dataset(dataset):
variant_tensor_ops = traverse.obtain_all_variant_tensor_ops(dataset)
remap_dict = _clone_helper(dataset._variant_tensor.op, variant_tensor_ops)
new_variant_tensor = remap_dict[dataset._variant_tensor.op].outputs[0]
return dataset_ops._VariantDataset(new_variant_tensor,
dataset._element_structure)
return dataset_ops._VariantDataset(new_variant_tensor, dataset.element_spec)
def _get_op_def(op):

View File

@ -276,8 +276,7 @@ class CloneDatasetTest(test.TestCase):
def _assert_datasets_equal(self, ds1, ds2):
# First lets assert the structure is the same.
self.assertTrue(
structure.are_compatible(ds1._element_structure,
ds2._element_structure))
structure.are_compatible(ds1.element_spec, ds2.element_spec))
# Now create iterators on both and assert they produce the same values.
it1 = dataset_ops.make_initializable_iterator(ds1)

View File

@ -5,6 +5,10 @@ tf_class {
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
is_instance: "<type \'object\'>"
member {
name: "element_spec"
mtype: "<type \'property\'>"
}
member {
name: "output_classes"
mtype: "<type \'property\'>"

View File

@ -7,6 +7,10 @@ tf_class {
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
is_instance: "<type \'object\'>"
member {
name: "element_spec"
mtype: "<type \'property\'>"
}
member {
name: "output_classes"
mtype: "<type \'property\'>"

View File

@ -3,6 +3,10 @@ tf_class {
is_instance: "<class \'tensorflow.python.data.ops.iterator_ops.Iterator\'>"
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
is_instance: "<type \'object\'>"
member {
name: "element_spec"
mtype: "<type \'property\'>"
}
member {
name: "initializer"
mtype: "<type \'property\'>"

View File

@ -7,6 +7,10 @@ tf_class {
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
is_instance: "<type \'object\'>"
member {
name: "element_spec"
mtype: "<type \'property\'>"
}
member {
name: "output_classes"
mtype: "<type \'property\'>"

View File

@ -7,6 +7,10 @@ tf_class {
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
is_instance: "<type \'object\'>"
member {
name: "element_spec"
mtype: "<type \'property\'>"
}
member {
name: "output_classes"
mtype: "<type \'property\'>"

View File

@ -7,6 +7,10 @@ tf_class {
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
is_instance: "<type \'object\'>"
member {
name: "element_spec"
mtype: "<type \'property\'>"
}
member {
name: "output_classes"
mtype: "<type \'property\'>"

View File

@ -7,6 +7,10 @@ tf_class {
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
is_instance: "<type \'object\'>"
member {
name: "element_spec"
mtype: "<type \'property\'>"
}
member {
name: "output_classes"
mtype: "<type \'property\'>"

View File

@ -7,6 +7,10 @@ tf_class {
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
is_instance: "<type \'object\'>"
member {
name: "element_spec"
mtype: "<type \'property\'>"
}
member {
name: "output_classes"
mtype: "<type \'property\'>"

View File

@ -4,6 +4,10 @@ tf_class {
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
is_instance: "<type \'object\'>"
member {
name: "element_spec"
mtype: "<class \'abc.abstractproperty\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'variant_tensor\'], varargs=None, keywords=None, defaults=None"

View File

@ -6,6 +6,10 @@ tf_class {
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
is_instance: "<type \'object\'>"
member {
name: "element_spec"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'filenames\', \'record_bytes\', \'header_bytes\', \'footer_bytes\', \'buffer_size\', \'compression_type\', \'num_parallel_reads\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], "

View File

@ -5,6 +5,10 @@ tf_class {
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
is_instance: "<type \'object\'>"
member {
name: "element_spec"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'filenames\', \'compression_type\', \'buffer_size\', \'num_parallel_reads\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "

View File

@ -6,6 +6,10 @@ tf_class {
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
is_instance: "<type \'object\'>"
member {
name: "element_spec"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'filenames\', \'compression_type\', \'buffer_size\', \'num_parallel_reads\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "

View File

@ -6,6 +6,10 @@ tf_class {
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
is_instance: "<type \'object\'>"
member {
name: "element_spec"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'filenames\', \'record_defaults\', \'compression_type\', \'buffer_size\', \'header\', \'field_delim\', \'use_quote_delim\', \'na_value\', \'select_cols\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\', \',\', \'True\', \'\', \'None\'], "

View File

@ -6,6 +6,10 @@ tf_class {
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
is_instance: "<type \'object\'>"
member {
name: "element_spec"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\'], "

View File

@ -6,6 +6,10 @@ tf_class {
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
is_instance: "<type \'object\'>"
member {
name: "element_spec"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'driver_name\', \'data_source_name\', \'query\', \'output_types\'], varargs=None, keywords=None, defaults=None"