[tf.data] Exposing dataset / iterator element type specification in the public API as element_spec.

This CL renames the pre-existing `_element_structure` property of tf.data datasets and iterators to `element_spec`, thus exposing it in the public API.
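
For illustration, a minimal sketch of the renamed API from user code (a hypothetical usage example, not part of this CL; it assumes eager execution on a build that includes this change, and the printed spec is approximate):

import tensorflow as tf

# Build a simple dataset of scalar int64 elements.
dataset = tf.data.Dataset.range(5)

# `element_spec` is now public API; this information was previously only
# reachable through the private `_element_structure` property.
print(dataset.element_spec)   # e.g. TensorSpec(shape=(), dtype=tf.int64)

# Iterators report the same specification as the dataset they come from.
iterator = iter(dataset)
print(iterator.element_spec)  # matches dataset.element_spec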

PiperOrigin-RevId: 256201202
commit 9db44da931 (parent ff67edeac1)
Jiri Simsa, 2019-07-02 11:10:43 -07:00, committed by TensorFlower Gardener
44 changed files with 270 additions and 229 deletions
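
Most of the diff below applies one mechanical pattern: rename the overridden property, and where a dataset caches its spec, rename the backing attribute as well. Roughly (a hypothetical minimal subclass, for orientation only; `_MyDataset` is not part of this change):

from tensorflow.python.data.ops import dataset_ops

class _MyDataset(dataset_ops.DatasetSource):  # hypothetical example
  def __init__(self, variant_tensor, element_spec):
    self._element_spec = element_spec  # was commonly named self._structure
    super(_MyDataset, self).__init__(variant_tensor)

  @property
  def element_spec(self):  # was: def _element_structure(self)
    return self._element_spec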


@@ -591,7 +591,7 @@ class _BigtableKeyDataset(dataset_ops.DatasetSource):
     self._table = table

   @property
-  def _element_structure(self):
+  def element_spec(self):
     return structure.TensorStructure(dtypes.string, [])
@@ -652,7 +652,7 @@ class _BigtableLookupDataset(dataset_ops.DatasetSource):
     super(_BigtableLookupDataset, self).__init__(variant_tensor)

   @property
-  def _element_structure(self):
+  def element_spec(self):
     return tuple([structure.TensorStructure(dtypes.string, [])] *
                  self._num_outputs)
@@ -681,7 +681,7 @@ class _BigtableScanDataset(dataset_ops.DatasetSource):
     super(_BigtableScanDataset, self).__init__(variant_tensor)

   @property
-  def _element_structure(self):
+  def element_spec(self):
     return tuple([structure.TensorStructure(dtypes.string, [])] *
                  self._num_outputs)
@@ -703,6 +703,6 @@ class _BigtableSampleKeyPairsDataset(dataset_ops.DatasetSource):
     super(_BigtableSampleKeyPairsDataset, self).__init__(variant_tensor)

   @property
-  def _element_structure(self):
+  def element_spec(self):
     return (structure.TensorStructure(dtypes.string, []),
             structure.TensorStructure(dtypes.string, []))


@@ -346,13 +346,13 @@ class _RestructuredDataset(dataset_ops.UnaryDataset):
     output_classes = nest.pack_sequence_as(
         output_types, nest.flatten(input_classes))
-    self._structure = structure.convert_legacy_structure(
+    self._element_spec = structure.convert_legacy_structure(
         output_types, output_shapes, output_classes)
     variant_tensor = self._input_dataset._variant_tensor  # pylint: disable=protected-access
     super(_RestructuredDataset, self).__init__(dataset, variant_tensor)

   @property
-  def _element_structure(self):
-    return self._structure
+  def element_spec(self):
+    return self._element_spec


@@ -395,6 +395,6 @@ class LMDBDataset(dataset_ops.DatasetSource):
     super(LMDBDataset, self).__init__(variant_tensor)

   @property
-  def _element_structure(self):
+  def element_spec(self):
     return (structure.TensorStructure(dtypes.string, []),
             structure.TensorStructure(dtypes.string, []))


@@ -39,7 +39,7 @@ class _SlideDataset(dataset_ops.UnaryDataset):
         window_shift, dtype=dtypes.int64, name="window_shift")
     input_structure = dataset_ops.get_structure(input_dataset)
-    self._structure = nest.map_structure(
+    self._element_spec = nest.map_structure(
         lambda component_spec: component_spec._batch(None), input_structure)  # pylint: disable=protected-access
     variant_tensor = ged_ops.experimental_sliding_window_dataset(
         self._input_dataset._variant_tensor,  # pylint: disable=protected-access
@@ -50,8 +50,8 @@ class _SlideDataset(dataset_ops.UnaryDataset):
     super(_SlideDataset, self).__init__(input_dataset, variant_tensor)

   @property
-  def _element_structure(self):
-    return self._structure
+  def element_spec(self):
+    return self._element_spec

 @deprecation.deprecated_args(


@@ -58,11 +58,10 @@ class SequenceFileDataset(dataset_ops.DatasetSource):
     self._filenames = ops.convert_to_tensor(
         filenames, dtype=dtypes.string, name="filenames")
     variant_tensor = gen_dataset_ops.sequence_file_dataset(
-        self._filenames,
-        structure.get_flat_tensor_types(self._element_structure))
+        self._filenames, self._flat_types)
     super(SequenceFileDataset, self).__init__(variant_tensor)

   @property
-  def _element_structure(self):
+  def element_spec(self):
     return (structure.TensorStructure(dtypes.string, []),
             structure.TensorStructure(dtypes.string, []))


@@ -754,7 +754,7 @@ class IgniteDataset(dataset_ops.DatasetSource):
         self.cache_type.to_permutation(),
         dtype=dtypes.int32,
         name="permutation")
-    self._structure = structure.convert_legacy_structure(
+    self._element_spec = structure.convert_legacy_structure(
         self.cache_type.to_output_types(), self.cache_type.to_output_shapes(),
         self.cache_type.to_output_classes())
@@ -766,5 +766,5 @@ class IgniteDataset(dataset_ops.DatasetSource):
         self.schema, self.permutation)

   @property
-  def _element_structure(self):
-    return self._structure
+  def element_spec(self):
+    return self._element_spec


@@ -69,5 +69,5 @@ class KafkaDataset(dataset_ops.DatasetSource):
         self._group, self._eof, self._timeout)

   @property
-  def _element_structure(self):
+  def element_spec(self):
     return structure.TensorStructure(dtypes.string, [])


@@ -86,5 +86,5 @@ class KinesisDataset(dataset_ops.DatasetSource):
         self._stream, self._shard, self._read_indefinitely, self._interval)

   @property
-  def _element_structure(self):
+  def element_spec(self):
     return structure.TensorStructure(dtypes.string, [])


@@ -37,7 +37,7 @@ class WrapDatasetVariantTest(test_base.DatasetTestBase):
     unwrapped_variant = gen_dataset_ops.unwrap_dataset_variant(wrapped_variant)
     variant_ds = dataset_ops._VariantDataset(unwrapped_variant,
-                                             ds._element_structure)
+                                             ds.element_spec)
     get_next = self.getNext(variant_ds, requires_initialization=True)
     for i in range(100):
       self.assertEqual(i, self.evaluate(get_next()))
@@ -54,7 +54,7 @@ class WrapDatasetVariantTest(test_base.DatasetTestBase):
     unwrapped_variant = gen_dataset_ops.unwrap_dataset_variant(
         gpu_wrapped_variant)
     variant_ds = dataset_ops._VariantDataset(unwrapped_variant,
-                                             ds._element_structure)
+                                             ds.element_spec)
     iterator = dataset_ops.make_initializable_iterator(variant_ds)
     get_next = iterator.get_next()


@@ -242,7 +242,7 @@ class _DenseToSparseBatchDataset(dataset_ops.UnaryDataset):
     self._input_dataset = input_dataset
     self._batch_size = batch_size
     self._row_shape = row_shape
-    self._structure = structure.SparseTensorStructure(
+    self._element_spec = structure.SparseTensorStructure(
         dataset_ops.get_legacy_output_types(input_dataset),
         tensor_shape.vector(None).concatenate(self._row_shape))
@@ -255,8 +255,8 @@ class _DenseToSparseBatchDataset(dataset_ops.UnaryDataset):
         variant_tensor)

   @property
-  def _element_structure(self):
-    return self._structure
+  def element_spec(self):
+    return self._element_spec

 class _MapAndBatchDataset(dataset_ops.UnaryDataset):
@@ -285,12 +285,12 @@ class _MapAndBatchDataset(dataset_ops.UnaryDataset):
       # NOTE(mrry): `constant_drop_remainder` may be `None` (unknown statically)
       # or `False` (explicitly retaining the remainder).
       # pylint: disable=g-long-lambda
-      self._structure = nest.map_structure(
+      self._element_spec = nest.map_structure(
           lambda component_spec: component_spec._batch(
               tensor_util.constant_value(self._batch_size_t)),
           self._map_func.output_structure)
     else:
-      self._structure = nest.map_structure(
+      self._element_spec = nest.map_structure(
           lambda component_spec: component_spec._batch(None),
           self._map_func.output_structure)
     # pylint: enable=protected-access
@@ -309,5 +309,5 @@ class _MapAndBatchDataset(dataset_ops.UnaryDataset):
     return [self._map_func]

   @property
-  def _element_structure(self):
-    return self._structure
+  def element_spec(self):
+    return self._element_spec


@@ -46,7 +46,7 @@ class _AutoShardDataset(dataset_ops.UnaryDataset):
   def __init__(self, input_dataset, num_workers, index):
     self._input_dataset = input_dataset
-    self._structure = input_dataset._element_structure  # pylint: disable=protected-access
+    self._element_spec = input_dataset.element_spec
     variant_tensor = ged_ops.experimental_auto_shard_dataset(
         self._input_dataset._variant_tensor,  # pylint: disable=protected-access
         num_workers=num_workers,
@@ -55,8 +55,8 @@ class _AutoShardDataset(dataset_ops.UnaryDataset):
     super(_AutoShardDataset, self).__init__(input_dataset, variant_tensor)

   @property
-  def _element_structure(self):
-    return self._structure
+  def element_spec(self):
+    return self._element_spec

 def _AutoShardDatasetV1(input_dataset, num_workers, index):
@@ -85,7 +85,7 @@ class _RebatchDataset(dataset_ops.UnaryDataset):
     input_classes = dataset_ops.get_legacy_output_classes(self._input_dataset)
     output_shapes = nest.map_structure(recalculate_output_shapes, input_shapes)
-    self._structure = structure.convert_legacy_structure(
+    self._element_spec = structure.convert_legacy_structure(
         input_types, output_shapes, input_classes)
     variant_tensor = ged_ops.experimental_rebatch_dataset(
         self._input_dataset._variant_tensor,  # pylint: disable=protected-access
@@ -94,8 +94,8 @@ class _RebatchDataset(dataset_ops.UnaryDataset):
     super(_RebatchDataset, self).__init__(input_dataset, variant_tensor)

   @property
-  def _element_structure(self):
-    return self._structure
+  def element_spec(self):
+    return self._element_spec

 _AutoShardDatasetV1.__doc__ = _AutoShardDataset.__doc__


@@ -65,6 +65,6 @@ def get_single_element(dataset):
   # pylint: disable=protected-access
   return structure.from_compatible_tensor_list(
-      dataset._element_structure,
+      dataset.element_spec,
       gen_dataset_ops.dataset_to_single_element(
           dataset._variant_tensor, **dataset._flat_structure))  # pylint: disable=protected-access


@@ -299,8 +299,7 @@ class _GroupByReducerDataset(dataset_ops.UnaryDataset):
     wrapped_func = dataset_ops.StructuredFunctionWrapper(
         reduce_func,
         self._transformation_name(),
-        input_structure=(self._state_structure,
-                         input_dataset._element_structure),  # pylint: disable=protected-access
+        input_structure=(self._state_structure, input_dataset.element_spec),
         add_to_graph=False)

     # Extract and validate class information from the returned values.
@@ -355,7 +354,7 @@ class _GroupByReducerDataset(dataset_ops.UnaryDataset):
         input_structure=self._state_structure)

   @property
-  def _element_structure(self):
+  def element_spec(self):
     return self._finalize_func.output_structure

   def _functions(self):
@@ -416,7 +415,7 @@ class _GroupByWindowDataset(dataset_ops.UnaryDataset):
   def _make_reduce_func(self, reduce_func, input_dataset):
     """Make wrapping defun for reduce_func."""
     nested_dataset = dataset_ops.DatasetStructure(
-        input_dataset._element_structure)  # pylint: disable=protected-access
+        input_dataset.element_spec)
     input_structure = (structure.TensorStructure(dtypes.int64,
                                                  []), nested_dataset)
     self._reduce_func = dataset_ops.StructuredFunctionWrapper(
@@ -426,12 +425,12 @@ class _GroupByWindowDataset(dataset_ops.UnaryDataset):
         self._reduce_func.output_structure, dataset_ops.DatasetStructure):
       raise TypeError("`reduce_func` must return a `Dataset` object.")
     # pylint: disable=protected-access
-    self._structure = (
-        self._reduce_func.output_structure._element_structure)
+    self._element_spec = (
+        self._reduce_func.output_structure._element_spec)

   @property
-  def _element_structure(self):
-    return self._structure
+  def element_spec(self):
+    return self._element_spec

   def _functions(self):
     return [self._key_func, self._reduce_func, self._window_size_func]


@@ -119,7 +119,7 @@ class _DirectedInterleaveDataset(dataset_ops.Dataset):
             nest.flatten(dataset_ops.get_legacy_output_shapes(data_input)))
     ])
-    self._structure = structure.convert_legacy_structure(
+    self._element_spec = structure.convert_legacy_structure(
         first_output_types, output_shapes, first_output_classes)
     super(_DirectedInterleaveDataset, self).__init__()
@@ -136,8 +136,8 @@ class _DirectedInterleaveDataset(dataset_ops.Dataset):
     return [self._selector_input] + self._data_inputs

   @property
-  def _element_structure(self):
-    return self._structure
+  def element_spec(self):
+    return self._element_spec

 @tf_export("data.experimental.sample_from_datasets", v1=[])
@@ -267,8 +267,8 @@ def choose_from_datasets_v2(datasets, choice_dataset):
     TypeError: If the `datasets` or `choice_dataset` arguments have the wrong
       type.
   """
-  if not dataset_ops.get_structure(choice_dataset).is_compatible_with(
-      structure.TensorStructure(dtypes.int64, [])):
+  if not structure.are_compatible(choice_dataset.element_spec,
+                                  structure.TensorStructure(dtypes.int64, [])):
     raise TypeError("`choice_dataset` must be a dataset of scalar "
                     "`tf.int64` tensors.")
   return _DirectedInterleaveDataset(choice_dataset, datasets)


@@ -35,5 +35,5 @@ class MatchingFilesDataset(dataset_ops.DatasetSource):
     super(MatchingFilesDataset, self).__init__(variant_tensor)

   @property
-  def _element_structure(self):
+  def element_spec(self):
     return structure.TensorStructure(dtypes.string, [])


@@ -156,7 +156,7 @@ class _ChooseFastestDataset(dataset_ops.DatasetV2):
       A `Dataset` that has the same elements the inputs.
     """
     self._datasets = list(datasets)
-    self._structure = self._datasets[0]._element_structure  # pylint: disable=protected-access
+    self._element_spec = self._datasets[0].element_spec
     variant_tensor = (
         gen_experimental_dataset_ops.experimental_choose_fastest_dataset(
             [dataset._variant_tensor for dataset in self._datasets],  # pylint: disable=protected-access
@@ -168,8 +168,8 @@ class _ChooseFastestDataset(dataset_ops.DatasetV2):
     return self._datasets

   @property
-  def _element_structure(self):
-    return self._datasets[0]._element_structure  # pylint: disable=protected-access
+  def element_spec(self):
+    return self._element_spec

 class _ChooseFastestBranchDataset(dataset_ops.UnaryDataset):
@@ -242,14 +242,13 @@ class _ChooseFastestBranchDataset(dataset_ops.UnaryDataset):
     Returns:
       A `Dataset` that has the same elements the inputs.
     """
-    input_structure = dataset_ops.DatasetStructure(
-        dataset_ops.get_structure(input_dataset))
+    input_structure = dataset_ops.DatasetStructure(input_dataset.element_spec)
     self._funcs = [
         dataset_ops.StructuredFunctionWrapper(
             f, "ChooseFastestV2", input_structure=input_structure)
         for f in functions
     ]
-    self._structure = self._funcs[0].output_structure._element_structure  # pylint: disable=protected-access
+    self._element_spec = self._funcs[0].output_structure._element_spec  # pylint: disable=protected-access
     self._captured_arguments = []
     for f in self._funcs:
@@ -279,5 +278,5 @@ class _ChooseFastestBranchDataset(dataset_ops.UnaryDataset):
         variant_tensor)

   @property
-  def _element_structure(self):
-    return self._structure
+  def element_spec(self):
+    return self._element_spec


@@ -32,7 +32,8 @@ class _ParseExampleDataset(dataset_ops.UnaryDataset):
   def __init__(self, input_dataset, features, num_parallel_calls):
     self._input_dataset = input_dataset
-    if not input_dataset._element_structure.is_compatible_with(  # pylint: disable=protected-access
+    if not structure.are_compatible(
+        input_dataset.element_spec,
         structure.TensorStructure(dtypes.string, [None])):
       raise TypeError("Input dataset should be a dataset of vectors of strings")
     self._num_parallel_calls = num_parallel_calls
@@ -75,7 +76,7 @@ class _ParseExampleDataset(dataset_ops.UnaryDataset):
             [ops.Tensor for _ in range(len(self._dense_defaults))] +
             [sparse_tensor.SparseTensor for _ in range(len(self._sparse_keys))
             ]))
-    self._structure = structure.convert_legacy_structure(
+    self._element_spec = structure.convert_legacy_structure(
         output_types, output_shapes, output_classes)

     variant_tensor = (
@@ -91,8 +92,8 @@ class _ParseExampleDataset(dataset_ops.UnaryDataset):
     super(_ParseExampleDataset, self).__init__(input_dataset, variant_tensor)

   @property
-  def _element_structure(self):
-    return self._structure
+  def element_spec(self):
+    return self._element_spec

 # TODO(b/111553342): add arguments names and example names as well.


@@ -146,8 +146,7 @@ class _CopyToDeviceDataset(dataset_ops.UnaryUnchangedStructureDataset):
           dataset_ops.get_legacy_output_types(self),
           dataset_ops.get_legacy_output_shapes(self),
           dataset_ops.get_legacy_output_classes(self))
-      return structure.to_tensor_list(self._element_structure,
-                                      iterator.get_next())
+      return structure.to_tensor_list(self.element_spec, iterator.get_next())

     next_func_concrete = _next_func._get_concrete_function_internal()  # pylint: disable=protected-access
@@ -251,7 +250,7 @@ class _MapOnGpuDataset(dataset_ops.UnaryDataset):
     return [self._map_func]

   @property
-  def _element_structure(self):
+  def element_spec(self):
     return self._map_func.output_structure

   def _transformation_name(self):


@@ -39,7 +39,7 @@ class RandomDatasetV2(dataset_ops.DatasetSource):
     super(RandomDatasetV2, self).__init__(variant_tensor)

   @property
-  def _element_structure(self):
+  def element_spec(self):
     return structure.TensorStructure(dtypes.int64, [])


@@ -662,7 +662,7 @@ class CsvDatasetV2(dataset_ops.DatasetSource):
         argument_default=[],
         argument_dtype=dtypes.int64,
     )
-    self._structure = tuple(
+    self._element_spec = tuple(
         structure.TensorStructure(d.dtype, []) for d in self._record_defaults)
     variant_tensor = gen_experimental_dataset_ops.experimental_csv_dataset(
         filenames=self._filenames,
@@ -678,8 +678,8 @@ class CsvDatasetV2(dataset_ops.DatasetSource):
     super(CsvDatasetV2, self).__init__(variant_tensor)

   @property
-  def _element_structure(self):
-    return self._structure
+  def element_spec(self):
+    return self._element_spec

 @tf_export(v1=["data.experimental.CsvDataset"])
@@ -955,7 +955,7 @@ class SqlDatasetV2(dataset_ops.DatasetSource):
         data_source_name, dtype=dtypes.string, name="data_source_name")
     self._query = ops.convert_to_tensor(
         query, dtype=dtypes.string, name="query")
-    self._structure = nest.map_structure(
+    self._element_spec = nest.map_structure(
         lambda dtype: structure.TensorStructure(dtype, []), output_types)
     variant_tensor = gen_experimental_dataset_ops.experimental_sql_dataset(
         self._driver_name, self._data_source_name, self._query,
@@ -963,8 +963,8 @@ class SqlDatasetV2(dataset_ops.DatasetSource):
     super(SqlDatasetV2, self).__init__(variant_tensor)

   @property
-  def _element_structure(self):
-    return self._structure
+  def element_spec(self):
+    return self._element_spec

 @tf_export(v1=["data.experimental.SqlDataset"])


@@ -49,7 +49,7 @@ class _ScanDataset(dataset_ops.UnaryDataset):
         scan_func,
         self._transformation_name(),
         input_structure=(self._state_structure,
-                         input_dataset._element_structure),  # pylint: disable=protected-access
+                         input_dataset.element_spec),
         add_to_graph=False)
     if not (
         isinstance(wrapped_func.output_types, collections.Sequence) and
@@ -91,7 +91,7 @@ class _ScanDataset(dataset_ops.UnaryDataset):
     old_state_shapes = nest.map_structure(
         lambda component_spec: component_spec._to_legacy_output_shapes(),  # pylint: disable=protected-access
         self._state_structure)
-    self._structure = structure.convert_legacy_structure(
+    self._element_spec = structure.convert_legacy_structure(
         output_types, output_shapes, output_classes)

     flat_state_shapes = nest.flatten(old_state_shapes)
@@ -135,8 +135,8 @@ class _ScanDataset(dataset_ops.UnaryDataset):
     return [self._scan_func]

   @property
-  def _element_structure(self):
-    return self._structure
+  def element_spec(self):
+    return self._element_spec

   def _transformation_name(self):
     return "tf.data.experimental.scan()"


@@ -43,21 +43,6 @@ from tensorflow.python.platform import test
 from tensorflow.python.platform import tf_logging as logging

-# For testing deserialization of Datasets represented as functions
-class _RevivedDataset(dataset_ops.DatasetV2):
-
-  def __init__(self, variant, element_structure):
-    self._structure = element_structure
-    super(_RevivedDataset, self).__init__(variant)
-
-  def _inputs(self):
-    return []
-
-  @property
-  def _element_structure(self):
-    return self._structure
-

 @test_util.run_all_in_graph_and_eager_modes
 class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
@@ -75,8 +60,8 @@ class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
     fn = original_dataset._trace_variant_creation()
     variant = fn()

-    revived_dataset = _RevivedDataset(
-        variant, original_dataset._element_structure)
+    revived_dataset = dataset_ops._VariantDataset(
+        variant, original_dataset.element_spec)
     self.assertDatasetProduces(revived_dataset, range(0, 10, 2))

   def testAsFunctionWithMapInFlatMap(self):
@@ -88,8 +73,8 @@ class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
     fn = original_dataset._trace_variant_creation()
     variant = fn()

-    revived_dataset = _RevivedDataset(
-        variant, original_dataset._element_structure)
+    revived_dataset = dataset_ops._VariantDataset(
+        variant, original_dataset.element_spec)
     self.assertDatasetProduces(revived_dataset, list(original_dataset))

   @staticmethod


@@ -315,14 +315,14 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor):
                        "or when eager execution is enabled.")

   @abc.abstractproperty
-  def _element_structure(self):
-    """The structure of an element of this dataset.
+  def element_spec(self):
+    """The type specification of an element of this dataset.

     Returns:
       A nested structure of `tf.TypeSpec` objects matching the structure of an
       element of this dataset and specifying the type of individual components.
     """
-    raise NotImplementedError("Dataset._element_structure")
+    raise NotImplementedError("Dataset.element_spec")

   def __repr__(self):
     output_shapes = nest.map_structure(str, get_legacy_output_shapes(self))
@@ -339,7 +339,7 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor):
     Returns:
       A list `tf.TensorShapes`s for the element tensor representation.
     """
-    return structure.get_flat_tensor_shapes(self._element_structure)
+    return structure.get_flat_tensor_shapes(self.element_spec)

   @property
   def _flat_types(self):
@@ -348,7 +348,7 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor):
     Returns:
       A list `tf.DType`s for the element tensor representation.
     """
-    return structure.get_flat_tensor_types(self._element_structure)
+    return structure.get_flat_tensor_types(self.element_spec)

   @property
   def _flat_structure(self):
@@ -371,7 +371,7 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor):
   @property
   def _type_spec(self):
-    return DatasetStructure(self._element_structure)
+    return DatasetStructure(self.element_spec)

   @staticmethod
   def from_tensors(tensors):
@@ -1442,7 +1442,7 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor):
     wrapped_func = StructuredFunctionWrapper(
         reduce_func,
         "reduce()",
-        input_structure=(state_structure, self._element_structure),
+        input_structure=(state_structure, self.element_spec),
         add_to_graph=False)

     # Extract and validate class information from the returned values.
@@ -1546,17 +1546,17 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor):
     def normalize(arg, *rest):
       # pylint: disable=protected-access
       if rest:
-        return structure.to_batched_tensor_list(self._element_structure,
-                                                (arg,) + rest)
+        return structure.to_batched_tensor_list(self.element_spec,
+                                                (arg,) + rest)
       else:
-        return structure.to_batched_tensor_list(self._element_structure, arg)
+        return structure.to_batched_tensor_list(self.element_spec, arg)

     normalized_dataset = self.map(normalize)

     # NOTE(mrry): Our `map()` has lost information about the structure of
     # non-tensor components, so re-apply the structure of the original dataset.
     restructured_dataset = _RestructuredDataset(normalized_dataset,
-                                                self._element_structure)
+                                                self.element_spec)
     return _UnbatchDataset(restructured_dataset)

   def with_options(self, options):
@@ -1750,7 +1750,7 @@ class DatasetV1(DatasetV2):
     """
     return nest.map_structure(
         lambda component_spec: component_spec._to_legacy_output_classes(),  # pylint: disable=protected-access
-        self._element_structure)
+        self.element_spec)

   @property
   @deprecation.deprecated(
@@ -1764,7 +1764,7 @@ class DatasetV1(DatasetV2):
     """
     return nest.map_structure(
         lambda component_spec: component_spec._to_legacy_output_shapes(),  # pylint: disable=protected-access
-        self._element_structure)
+        self.element_spec)

   @property
   @deprecation.deprecated(
@@ -1778,10 +1778,10 @@ class DatasetV1(DatasetV2):
     """
     return nest.map_structure(
         lambda component_spec: component_spec._to_legacy_output_types(),  # pylint: disable=protected-access
-        self._element_structure)
+        self.element_spec)

   @property
-  def _element_structure(self):
+  def element_spec(self):
     # TODO(b/110122868): Remove this override once all `Dataset` instances
     # implement `element_structure`.
     return structure.convert_legacy_structure(
@@ -2003,8 +2003,8 @@ class DatasetV1Adapter(DatasetV1):
     return self._dataset.options()

   @property
-  def _element_structure(self):
-    return self._dataset._element_structure  # pylint: disable=protected-access
+  def element_spec(self):
+    return self._dataset.element_spec  # pylint: disable=protected-access

   def __iter__(self):
     return iter(self._dataset)
@@ -2107,7 +2107,7 @@ def get_structure(dataset_or_iterator):
     TypeError: If `dataset_or_iterator` is not a `Dataset` or `Iterator` object.
   """
   try:
-    return dataset_or_iterator._element_structure  # pylint: disable=protected-access
+    return dataset_or_iterator.element_spec  # pylint: disable=protected-access
   except AttributeError:
     raise TypeError("`dataset_or_iterator` must be a Dataset or Iterator "
                     "object, but got %s." % type(dataset_or_iterator))
@@ -2306,8 +2306,8 @@ class UnaryUnchangedStructureDataset(UnaryDataset):
         input_dataset, variant_tensor)

   @property
-  def _element_structure(self):
-    return self._input_dataset._element_structure  # pylint: disable=protected-access
+  def element_spec(self):
+    return self._input_dataset.element_spec

 class TensorDataset(DatasetSource):
@@ -2325,7 +2325,7 @@ class TensorDataset(DatasetSource):
     super(TensorDataset, self).__init__(variant_tensor)

   @property
-  def _element_structure(self):
+  def element_spec(self):
     return self._structure
@@ -2352,7 +2352,7 @@ class TensorSliceDataset(DatasetSource):
     super(TensorSliceDataset, self).__init__(variant_tensor)

   @property
-  def _element_structure(self):
+  def element_spec(self):
     return self._structure
@@ -2381,7 +2381,7 @@ class SparseTensorSliceDataset(DatasetSource):
     super(SparseTensorSliceDataset, self).__init__(variant_tensor)

   @property
-  def _element_structure(self):
+  def element_spec(self):
     return self._structure
@@ -2396,20 +2396,20 @@ class _VariantDataset(DatasetV2):
     return []

   @property
-  def _element_structure(self):
+  def element_spec(self):
     return self._structure

 class _NestedVariant(composite_tensor.CompositeTensor):

-  def __init__(self, variant_tensor, element_structure, dataset_shape):
+  def __init__(self, variant_tensor, element_spec, dataset_shape):
     self._variant_tensor = variant_tensor
-    self._element_structure = element_structure
+    self._element_spec = element_spec
     self._dataset_shape = dataset_shape

   @property
   def _type_spec(self):
-    return DatasetStructure(self._element_structure, self._dataset_shape)
+    return DatasetStructure(self._element_spec, self._dataset_shape)

 @tf_export("data.experimental.from_variant")
@@ -2445,10 +2445,10 @@ def to_variant(dataset):
 class DatasetStructure(type_spec.BatchableTypeSpec):
   """Type specification for `tf.data.Dataset`."""

-  __slots__ = ["_element_structure", "_dataset_shape"]
+  __slots__ = ["_element_spec", "_dataset_shape"]

   def __init__(self, element_spec, dataset_shape=None):
-    self._element_structure = element_spec
+    self._element_spec = element_spec
     if dataset_shape:
       self._dataset_shape = dataset_shape
     else:
@@ -2459,7 +2459,7 @@ class DatasetStructure(type_spec.BatchableTypeSpec):
     return _VariantDataset

   def _serialize(self):
-    return (self._element_structure, self._dataset_shape)
+    return (self._element_spec, self._dataset_shape)

   @property
   def _component_specs(self):
@@ -2471,10 +2471,9 @@ class DatasetStructure(type_spec.BatchableTypeSpec):
   def _from_components(self, components):
     # pylint: disable=protected-access
     if self._dataset_shape.ndims == 0:
-      return _VariantDataset(components, self._element_structure)
+      return _VariantDataset(components, self._element_spec)
     else:
-      return _NestedVariant(components, self._element_structure,
-                            self._dataset_shape)
+      return _NestedVariant(components, self._element_spec, self._dataset_shape)

   def _to_tensor_list(self, value):
     return [
@@ -2484,17 +2483,17 @@ class DatasetStructure(type_spec.BatchableTypeSpec):

   @staticmethod
   def from_value(value):
-    return DatasetStructure(value._element_structure)  # pylint: disable=protected-access
+    return DatasetStructure(value.element_spec)  # pylint: disable=protected-access

   def _batch(self, batch_size):
     return DatasetStructure(
-        self._element_structure,
+        self._element_spec,
         tensor_shape.TensorShape([batch_size]).concatenate(self._dataset_shape))

   def _unbatch(self):
     if self._dataset_shape.ndims == 0:
       raise ValueError("Unbatching a dataset is only supported for rank >= 1")
-    return DatasetStructure(self._element_structure, self._dataset_shape[1:])
+    return DatasetStructure(self._element_spec, self._dataset_shape[1:])

   def _to_batched_tensor_list(self, value):
     if self._dataset_shape.ndims == 0:
@@ -2571,7 +2570,7 @@ class StructuredFunctionWrapper(object):
         raise ValueError("Either `dataset`, `input_structure` or all of "
                          "`input_classes`, `input_shapes`, and `input_types` "
                          "must be specified.")
-      self._input_structure = dataset._element_structure
+      self._input_structure = dataset.element_spec
     else:
       if not (dataset is None and input_classes is None and input_shapes is None
               and input_types is None):
@@ -2767,7 +2766,7 @@ class _GeneratorDataset(DatasetSource):
     super(_GeneratorDataset, self).__init__(variant_tensor)

   @property
-  def _element_structure(self):
+  def element_spec(self):
     return self._next_func.output_structure

   def _transformation_name(self):
@@ -2792,7 +2791,7 @@ class ZipDataset(DatasetV2):
     self._datasets = datasets
     self._structure = nest.pack_sequence_as(
         self._datasets,
-        [ds._element_structure for ds in nest.flatten(self._datasets)])  # pylint: disable=protected-access
+        [ds.element_spec for ds in nest.flatten(self._datasets)])
     variant_tensor = gen_dataset_ops.zip_dataset(
         [ds._variant_tensor for ds in nest.flatten(self._datasets)],
         **self._flat_structure)
@@ -2802,7 +2801,7 @@ class ZipDataset(DatasetV2):
     return nest.flatten(self._datasets)

   @property
-  def _element_structure(self):
+  def element_spec(self):
     return self._structure
@@ -2850,7 +2849,7 @@ class ConcatenateDataset(DatasetV2):
     return self._input_datasets

   @property
-  def _element_structure(self):
+  def element_spec(self):
     return self._structure
@@ -2907,7 +2906,7 @@ class RangeDataset(DatasetSource):
     return ops.convert_to_tensor(int64_value, dtype=dtypes.int64, name=name)

   @property
-  def _element_structure(self):
+  def element_spec(self):
     return self._structure
@@ -3037,11 +3036,11 @@ class BatchDataset(UnaryDataset):
       self._structure = nest.map_structure(
           lambda component_spec: component_spec._batch(
               tensor_util.constant_value(self._batch_size)),
-          input_dataset._element_structure)
+          input_dataset.element_spec)
     else:
       self._structure = nest.map_structure(
           lambda component_spec: component_spec._batch(None),
-          input_dataset._element_structure)
+          input_dataset.element_spec)
     variant_tensor = gen_dataset_ops.batch_dataset_v2(
         input_dataset._variant_tensor,
         batch_size=self._batch_size,
@@ -3050,7 +3049,7 @@ class BatchDataset(UnaryDataset):
     super(BatchDataset, self).__init__(input_dataset, variant_tensor)

   @property
-  def _element_structure(self):
+  def element_spec(self):
     return self._structure
@@ -3268,7 +3267,7 @@ class PaddedBatchDataset(UnaryDataset):
     super(PaddedBatchDataset, self).__init__(input_dataset, variant_tensor)

   @property
-  def _element_structure(self):
+  def element_spec(self):
     return self._structure
@@ -3308,7 +3307,7 @@ class MapDataset(UnaryDataset):
     return [self._map_func]

   @property
-  def _element_structure(self):
+  def element_spec(self):
     return self._map_func.output_structure

   def _transformation_name(self):
@@ -3350,7 +3349,7 @@ class ParallelMapDataset(UnaryDataset):
     return [self._map_func]

   @property
-  def _element_structure(self):
+  def element_spec(self):
     return self._map_func.output_structure

   def _transformation_name(self):
@@ -3369,7 +3368,7 @@ class FlatMapDataset(UnaryDataset):
       raise TypeError(
           "`map_func` must return a `Dataset` object. Got {}".format(
               type(self._map_func.output_structure)))
-    self._structure = self._map_func.output_structure._element_structure  # pylint: disable=protected-access
+    self._structure = self._map_func.output_structure._element_spec  # pylint: disable=protected-access
     variant_tensor = gen_dataset_ops.flat_map_dataset(
         input_dataset._variant_tensor,  # pylint: disable=protected-access
         self._map_func.function.captured_inputs,
@@ -3381,7 +3380,7 @@ class FlatMapDataset(UnaryDataset):
     return [self._map_func]

   @property
-  def _element_structure(self):
+  def element_spec(self):
     return self._structure

   def _transformation_name(self):
@@ -3400,7 +3399,7 @@ class InterleaveDataset(UnaryDataset):
       raise TypeError(
           "`map_func` must return a `Dataset` object. Got {}".format(
               type(self._map_func.output_structure)))
-    self._structure = self._map_func.output_structure._element_structure  # pylint: disable=protected-access
+    self._structure = self._map_func.output_structure._element_spec  # pylint: disable=protected-access
     self._cycle_length = ops.convert_to_tensor(
         cycle_length, dtype=dtypes.int64, name="cycle_length")
     self._block_length = ops.convert_to_tensor(
@@ -3419,7 +3418,7 @@ class InterleaveDataset(UnaryDataset):
     return [self._map_func]

   @property
-  def _element_structure(self):
+  def element_spec(self):
     return self._structure

   def _transformation_name(self):
@@ -3439,7 +3438,7 @@ class ParallelInterleaveDataset(UnaryDataset):
       raise TypeError(
           "`map_func` must return a `Dataset` object. Got {}".format(
               type(self._map_func.output_structure)))
-    self._structure = self._map_func.output_structure._element_structure  # pylint: disable=protected-access
+    self._structure = self._map_func.output_structure._element_spec  # pylint: disable=protected-access
     self._cycle_length = ops.convert_to_tensor(
         cycle_length, dtype=dtypes.int64, name="cycle_length")
     self._block_length = ops.convert_to_tensor(
@@ -3461,7 +3460,7 @@ class ParallelInterleaveDataset(UnaryDataset):
     return [self._map_func]

   @property
-  def _element_structure(self):
+  def element_spec(self):
     return self._structure

   def _transformation_name(self):
@@ -3561,7 +3560,7 @@ class WindowDataset(UnaryDataset):
     super(WindowDataset, self).__init__(input_dataset, variant_tensor)

   @property
-  def _element_structure(self):
+  def element_spec(self):
     return self._structure
@@ -3684,7 +3683,7 @@ class _RestructuredDataset(UnaryDataset):
     super(_RestructuredDataset, self).__init__(dataset, variant_tensor)

   @property
-  def _element_structure(self):
+  def element_spec(self):
     return self._structure
@@ -3713,5 +3712,5 @@ class _UnbatchDataset(UnaryDataset):
     super(_UnbatchDataset, self).__init__(input_dataset, variant_tensor)

   @property
-  def _element_structure(self):
+  def element_spec(self):
     return self._structure


@@ -104,10 +104,12 @@ class Iterator(trackable.Trackable):
       raise ValueError("If `structure` is not specified, all of "
                        "`output_types`, `output_shapes`, and `output_classes`"
                        " must be specified.")
-    self._structure = structure.convert_legacy_structure(
+    self._element_spec = structure.convert_legacy_structure(
         output_types, output_shapes, output_classes)
-    self._flat_tensor_shapes = structure.get_flat_tensor_shapes(self._structure)
-    self._flat_tensor_types = structure.get_flat_tensor_types(self._structure)
+    self._flat_tensor_shapes = structure.get_flat_tensor_shapes(
+        self._element_spec)
+    self._flat_tensor_types = structure.get_flat_tensor_types(
+        self._element_spec)

     self._string_handle = gen_dataset_ops.iterator_to_string_handle(
         self._iterator_resource)
@@ -347,13 +349,13 @@ class Iterator(trackable.Trackable):
     # pylint: disable=protected-access
     dataset_output_types = nest.map_structure(
         lambda component_spec: component_spec._to_legacy_output_types(),
-        dataset._element_structure)
+        dataset.element_spec)
     dataset_output_shapes = nest.map_structure(
         lambda component_spec: component_spec._to_legacy_output_shapes(),
-        dataset._element_structure)
+        dataset.element_spec)
     dataset_output_classes = nest.map_structure(
         lambda component_spec: component_spec._to_legacy_output_classes(),
-        dataset._element_structure)
+        dataset.element_spec)
     # pylint: enable=protected-access

     nest.assert_same_structure(self.output_types, dataset_output_types)
@@ -436,7 +438,7 @@ class Iterator(trackable.Trackable):
         output_types=self._flat_tensor_types,
         output_shapes=self._flat_tensor_shapes,
         name=name)
-    return structure.from_tensor_list(self._structure, flat_ret)
+    return structure.from_tensor_list(self._element_spec, flat_ret)

   def string_handle(self, name=None):
     """Returns a string-valued `tf.Tensor` that represents this iterator.
@@ -467,7 +469,7 @@ class Iterator(trackable.Trackable):
     """
     return nest.map_structure(
         lambda component_spec: component_spec._to_legacy_output_classes(),  # pylint: disable=protected-access
-        self._structure)
+        self._element_spec)

   @property
   @deprecation.deprecated(
@@ -481,7 +483,7 @@ class Iterator(trackable.Trackable):
     """
     return nest.map_structure(
         lambda component_spec: component_spec._to_legacy_output_shapes(),  # pylint: disable=protected-access
-        self._structure)
+        self._element_spec)

   @property
   @deprecation.deprecated(
@@ -495,17 +497,17 @@ class Iterator(trackable.Trackable):
     """
     return nest.map_structure(
         lambda component_spec: component_spec._to_legacy_output_types(),  # pylint: disable=protected-access
-        self._structure)
+        self._element_spec)

   @property
-  def _element_structure(self):
-    """The structure of an element of this iterator.
+  def element_spec(self):
+    """The type specification of an element of this iterator.

     Returns:
-      A `Structure` object representing the structure of the components of this
-      optional.
+      A nested structure of `tf.TypeSpec` objects matching the structure of an
+      element of this iterator and specifying the type of individual components.
     """
-    return self._structure
+    return self._element_spec

   def _gather_saveables_for_checkpoint(self):
@@ -556,7 +558,7 @@ class IteratorResourceDeleter(object):
 class IteratorV2(trackable.Trackable, composite_tensor.CompositeTensor):
   """An iterator producing tf.Tensor objects from a tf.data.Dataset."""

-  def __init__(self, dataset=None, components=None, element_structure=None):
+  def __init__(self, dataset=None, components=None, element_spec=None):
     """Creates a new iterator from the given dataset.

     If `dataset` is not specified, the iterator will be created from the given
@@ -567,29 +569,29 @@ class IteratorV2(trackable.Trackable, composite_tensor.CompositeTensor):
     Args:
       dataset: A `tf.data.Dataset` object.
      components: Tensor components to construct the iterator from.
-      element_structure: A nested structure of `TypeSpec` objects that
-        represents the type specification elements of the iterator.
+      element_spec: A nested structure of `TypeSpec` objects that
+        represents the type specification of elements of the iterator.

     Raises:
       ValueError: If `dataset` is not provided and either `components` or
-        `element_structure` is not provided. Or `dataset` is provided and either
-        `components` and `element_structure` is provided.
+        `element_spec` is not provided. Or `dataset` is provided and either
+        `components` and `element_spec` is provided.
     """
     error_message = "Either `dataset` or both `components` and "
-    "`element_structure` need to be provided."
+    "`element_spec` need to be provided."

     self._device = context.context().device_name
     if dataset is None:
-      if (components is None or element_structure is None):
+      if (components is None or element_spec is None):
         raise ValueError(error_message)
       # pylint: disable=protected-access
-      self._structure = element_structure
+      self._element_spec = element_spec
       self._flat_output_types = structure.get_flat_tensor_types(
-          self._structure)
+          self._element_spec)
       self._flat_output_shapes = structure.get_flat_tensor_shapes(
-          self._structure)
+          self._element_spec)
       self._iterator_resource, self._deleter = components
       # Delete the resource when this object is deleted
       self._resource_deleter = IteratorResourceDeleter(
@@ -597,7 +599,7 @@ class IteratorV2(trackable.Trackable, composite_tensor.CompositeTensor):
           device=self._device,
          deleter=self._deleter)
     else:
-      if (components is not None or element_structure is not None):
+      if (components is not None or element_spec is not None):
         raise ValueError(error_message)
       if (_device_stack_is_empty() or
           context.context().device_spec.device_type != "CPU"):
@ -610,11 +612,11 @@ class IteratorV2(trackable.Trackable, composite_tensor.CompositeTensor):
# pylint: disable=protected-access # pylint: disable=protected-access
dataset = dataset._apply_options() dataset = dataset._apply_options()
ds_variant = dataset._variant_tensor ds_variant = dataset._variant_tensor
self._structure = dataset._element_structure self._element_spec = dataset.element_spec
self._flat_output_types = structure.get_flat_tensor_types( self._flat_output_types = structure.get_flat_tensor_types(
self._structure) self._element_spec)
self._flat_output_shapes = structure.get_flat_tensor_shapes( self._flat_output_shapes = structure.get_flat_tensor_shapes(
self._structure) self._element_spec)
with ops.colocate_with(ds_variant): with ops.colocate_with(ds_variant):
self._iterator_resource, self._deleter = ( self._iterator_resource, self._deleter = (
gen_dataset_ops.anonymous_iterator_v2( gen_dataset_ops.anonymous_iterator_v2(
@ -642,7 +644,7 @@ class IteratorV2(trackable.Trackable, composite_tensor.CompositeTensor):
self._iterator_resource, self._iterator_resource,
output_types=self._flat_output_types, output_types=self._flat_output_types,
output_shapes=self._flat_output_shapes) output_shapes=self._flat_output_shapes)
return structure.from_compatible_tensor_list(self._structure, ret) return structure.from_compatible_tensor_list(self._element_spec, ret)
# This runs in sync mode as iterators use an error status to communicate # This runs in sync mode as iterators use an error status to communicate
# that there is no more data to iterate over. # that there is no more data to iterate over.
@ -664,13 +666,13 @@ class IteratorV2(trackable.Trackable, composite_tensor.CompositeTensor):
try: try:
# Fast path for the case `self._structure` is not a nested structure. # Fast path for the case `self._element_spec` is not a nested structure.
return self._structure._from_compatible_tensor_list(ret) # pylint: disable=protected-access return self._element_spec._from_compatible_tensor_list(ret) # pylint: disable=protected-access
except AttributeError: except AttributeError:
return structure.from_compatible_tensor_list(self._structure, ret) return structure.from_compatible_tensor_list(self._element_spec, ret)
@property @property
def _type_spec(self): def _type_spec(self):
return IteratorSpec(self._element_structure) return IteratorSpec(self.element_spec)
def next(self): def next(self):
"""Returns a nested structure of `Tensor`s containing the next element.""" """Returns a nested structure of `Tensor`s containing the next element."""
@ -693,7 +695,7 @@ class IteratorV2(trackable.Trackable, composite_tensor.CompositeTensor):
""" """
return nest.map_structure( return nest.map_structure(
lambda component_spec: component_spec._to_legacy_output_classes(), # pylint: disable=protected-access lambda component_spec: component_spec._to_legacy_output_classes(), # pylint: disable=protected-access
self._structure) self._element_spec)
@property @property
@deprecation.deprecated( @deprecation.deprecated(
@ -707,7 +709,7 @@ class IteratorV2(trackable.Trackable, composite_tensor.CompositeTensor):
""" """
return nest.map_structure( return nest.map_structure(
lambda component_spec: component_spec._to_legacy_output_shapes(), # pylint: disable=protected-access lambda component_spec: component_spec._to_legacy_output_shapes(), # pylint: disable=protected-access
self._structure) self._element_spec)
@property @property
@deprecation.deprecated( @deprecation.deprecated(
@ -721,17 +723,17 @@ class IteratorV2(trackable.Trackable, composite_tensor.CompositeTensor):
""" """
return nest.map_structure( return nest.map_structure(
lambda component_spec: component_spec._to_legacy_output_types(), # pylint: disable=protected-access lambda component_spec: component_spec._to_legacy_output_types(), # pylint: disable=protected-access
self._structure) self._element_spec)
@property @property
def _element_structure(self): def element_spec(self):
"""The structure of an element of this iterator. """The type specification of an element of this iterator.
Returns: Returns:
A `Structure` object representing the structure of the components of this A nested structure of `tf.TypeSpec` objects matching the structure of an
optional. element of this iterator and specifying the type of individual components.
""" """
return self._structure return self._element_spec
def get_next(self, name=None): def get_next(self, name=None):
"""Returns a nested structure of `tf.Tensor`s containing the next element. """Returns a nested structure of `tf.Tensor`s containing the next element.
@ -786,11 +788,11 @@ class IteratorSpec(type_spec.TypeSpec):
return IteratorV2( return IteratorV2(
dataset=None, dataset=None,
components=components, components=components,
element_structure=self._element_spec) element_spec=self._element_spec)
@staticmethod @staticmethod
def from_value(value): def from_value(value):
return IteratorSpec(value._element_structure) # pylint: disable=protected-access return IteratorSpec(value.element_spec) # pylint: disable=protected-access
# TODO(b/71645805): Expose trackable stateful objects from dataset # TODO(b/71645805): Expose trackable stateful objects from dataset
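`from_value` now goes through the public property as well, so `IteratorSpec(it.element_spec)` and `IteratorSpec.from_value(it)` describe the same iterator. A hedged sketch of what this enables; that iterators are accepted as `tf.function` arguments at this revision is an assumption based on the `CompositeTensor` machinery above:

@tf.function
def consume(it):
  # Traced against IteratorSpec(it.element_spec).
  return it.get_next()

consume(iter(tf.data.Dataset.range(2)))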
@ -828,7 +830,6 @@ def get_next_as_optional(iterator):
return optional_ops._OptionalImpl( return optional_ops._OptionalImpl(
gen_dataset_ops.iterator_get_next_as_optional( gen_dataset_ops.iterator_get_next_as_optional(
iterator._iterator_resource, iterator._iterator_resource,
output_types=structure.get_flat_tensor_types( output_types=structure.get_flat_tensor_types(iterator.element_spec),
iterator._element_structure),
output_shapes=structure.get_flat_tensor_shapes( output_shapes=structure.get_flat_tensor_shapes(
iterator._element_structure)), iterator._element_structure) iterator.element_spec)), iterator.element_spec)
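`get_next_as_optional` likewise builds its flat output types and shapes from `iterator.element_spec`. A hedged usage sketch, assuming the function is exported as `tf.data.experimental.get_next_as_optional`:

it = iter(tf.data.Dataset.range(1))
opt = tf.data.experimental.get_next_as_optional(it)
if opt.has_value():        # scalar boolean Tensor; truthy in eager mode
  value = opt.get_value()  # structure matches it.element_spec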


@ -36,8 +36,8 @@ class _PerDeviceGenerator(dataset_ops.DatasetV2):
"""A `dummy` generator dataset.""" """A `dummy` generator dataset."""
def __init__(self, shard_num, multi_device_iterator_resource, incarnation_id, def __init__(self, shard_num, multi_device_iterator_resource, incarnation_id,
source_device, element_structure): source_device, element_spec):
self._structure = element_structure self._element_spec = element_spec
multi_device_iterator_string_handle = ( multi_device_iterator_string_handle = (
gen_dataset_ops.multi_device_iterator_to_string_handle( gen_dataset_ops.multi_device_iterator_to_string_handle(
@ -71,14 +71,15 @@ class _PerDeviceGenerator(dataset_ops.DatasetV2):
multi_device_iterator = ( multi_device_iterator = (
gen_dataset_ops.multi_device_iterator_from_string_handle( gen_dataset_ops.multi_device_iterator_from_string_handle(
string_handle=string_handle, string_handle=string_handle,
output_types=structure.get_flat_tensor_types(self._structure), output_types=structure.get_flat_tensor_types(self._element_spec),
output_shapes=structure.get_flat_tensor_shapes(self._structure))) output_shapes=structure.get_flat_tensor_shapes(
self._element_spec)))
return gen_dataset_ops.multi_device_iterator_get_next_from_shard( return gen_dataset_ops.multi_device_iterator_get_next_from_shard(
multi_device_iterator=multi_device_iterator, multi_device_iterator=multi_device_iterator,
shard_num=shard_num, shard_num=shard_num,
incarnation_id=incarnation_id, incarnation_id=incarnation_id,
output_types=structure.get_flat_tensor_types(self._structure), output_types=structure.get_flat_tensor_types(self._element_spec),
output_shapes=structure.get_flat_tensor_shapes(self._structure)) output_shapes=structure.get_flat_tensor_shapes(self._element_spec))
next_func_concrete = _next_func._get_concrete_function_internal() # pylint: disable=protected-access next_func_concrete = _next_func._get_concrete_function_internal() # pylint: disable=protected-access
@ -91,7 +92,7 @@ class _PerDeviceGenerator(dataset_ops.DatasetV2):
return functional_ops.remote_call( return functional_ops.remote_call(
target=source_device, target=source_device,
args=[string_handle] + next_func_concrete.captured_inputs, args=[string_handle] + next_func_concrete.captured_inputs,
Tout=structure.get_flat_tensor_types(self._structure), Tout=structure.get_flat_tensor_types(self._element_spec),
f=next_func_concrete) f=next_func_concrete)
self._next_func = _remote_next_func._get_concrete_function_internal() # pylint: disable=protected-access self._next_func = _remote_next_func._get_concrete_function_internal() # pylint: disable=protected-access
@ -141,8 +142,8 @@ class _PerDeviceGenerator(dataset_ops.DatasetV2):
return [] return []
@property @property
def _element_structure(self): def element_spec(self):
return self._structure return self._element_spec
class _ReincarnatedPerDeviceGenerator(dataset_ops.DatasetV2): class _ReincarnatedPerDeviceGenerator(dataset_ops.DatasetV2):
@ -154,7 +155,7 @@ class _ReincarnatedPerDeviceGenerator(dataset_ops.DatasetV2):
def __init__(self, per_device_dataset, incarnation_id): def __init__(self, per_device_dataset, incarnation_id):
# pylint: disable=protected-access # pylint: disable=protected-access
self._structure = per_device_dataset._element_structure self._element_spec = per_device_dataset.element_spec
self._init_func = per_device_dataset._init_func self._init_func = per_device_dataset._init_func
self._init_captured_args = self._init_func.captured_inputs self._init_captured_args = self._init_func.captured_inputs
@ -183,8 +184,8 @@ class _ReincarnatedPerDeviceGenerator(dataset_ops.DatasetV2):
return [] return []
@property @property
def _element_structure(self): def element_spec(self):
return self._structure return self._element_spec
class MultiDeviceIterator(object): class MultiDeviceIterator(object):
@ -253,7 +254,7 @@ class MultiDeviceIterator(object):
ds = _PerDeviceGenerator(i, self._multi_device_iterator_resource, ds = _PerDeviceGenerator(i, self._multi_device_iterator_resource,
self._incarnation_id, self._incarnation_id,
self._source_device_tensor, self._source_device_tensor,
self._dataset._element_structure) # pylint: disable=protected-access self._dataset.element_spec)
self._prototype_device_datasets.append(ds) self._prototype_device_datasets.append(ds)
# TODO(rohanj): Explore the possibility of the MultiDeviceIterator to # TODO(rohanj): Explore the possibility of the MultiDeviceIterator to
@ -339,5 +340,5 @@ class MultiDeviceIterator(object):
ds_variant, self._device_iterators[i]._iterator_resource) ds_variant, self._device_iterators[i]._iterator_resource)
@property @property
def _element_structure(self): def element_spec(self):
return dataset_ops.get_structure(self._dataset) return self._dataset.element_spec
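The `MultiDeviceIterator` property now forwards the wrapped dataset's spec instead of calling `dataset_ops.get_structure`. A hedged sketch; the device strings are placeholders:

ds = tf.data.Dataset.range(8).batch(2)
mdi = tf.data.experimental.MultiDeviceIterator(ds, devices=["/cpu:0", "/gpu:0"])
print(mdi.element_spec)  # same spec as ds.element_spec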


@ -115,7 +115,7 @@ class _TextLineDataset(dataset_ops.DatasetSource):
super(_TextLineDataset, self).__init__(variant_tensor) super(_TextLineDataset, self).__init__(variant_tensor)
@property @property
def _element_structure(self): def element_spec(self):
return structure.TensorStructure(dtypes.string, []) return structure.TensorStructure(dtypes.string, [])
@ -157,7 +157,7 @@ class TextLineDatasetV2(dataset_ops.DatasetSource):
super(TextLineDatasetV2, self).__init__(variant_tensor) super(TextLineDatasetV2, self).__init__(variant_tensor)
@property @property
def _element_structure(self): def element_spec(self):
return structure.TensorStructure(dtypes.string, []) return structure.TensorStructure(dtypes.string, [])
@ -209,7 +209,7 @@ class _TFRecordDataset(dataset_ops.DatasetSource):
super(_TFRecordDataset, self).__init__(variant_tensor) super(_TFRecordDataset, self).__init__(variant_tensor)
@property @property
def _element_structure(self): def element_spec(self):
return structure.TensorStructure(dtypes.string, []) return structure.TensorStructure(dtypes.string, [])
@ -225,7 +225,7 @@ class ParallelInterleaveDataset(dataset_ops.UnaryDataset):
if not isinstance(self._map_func.output_structure, if not isinstance(self._map_func.output_structure,
dataset_ops.DatasetStructure): dataset_ops.DatasetStructure):
raise TypeError("`map_func` must return a `Dataset` object.") raise TypeError("`map_func` must return a `Dataset` object.")
self._structure = self._map_func.output_structure._element_structure # pylint: disable=protected-access self._element_spec = self._map_func.output_structure._element_spec # pylint: disable=protected-access
self._cycle_length = ops.convert_to_tensor( self._cycle_length = ops.convert_to_tensor(
cycle_length, dtype=dtypes.int64, name="cycle_length") cycle_length, dtype=dtypes.int64, name="cycle_length")
self._block_length = ops.convert_to_tensor( self._block_length = ops.convert_to_tensor(
@ -257,8 +257,8 @@ class ParallelInterleaveDataset(dataset_ops.UnaryDataset):
return [self._map_func] return [self._map_func]
@property @property
def _element_structure(self): def element_spec(self):
return self._structure return self._element_spec
def _transformation_name(self): def _transformation_name(self):
return "tf.data.experimental.parallel_interleave()" return "tf.data.experimental.parallel_interleave()"
@ -321,7 +321,7 @@ class TFRecordDatasetV2(dataset_ops.DatasetV2):
return self._impl._inputs() # pylint: disable=protected-access return self._impl._inputs() # pylint: disable=protected-access
@property @property
def _element_structure(self): def element_spec(self):
return structure.TensorStructure(dtypes.string, []) return structure.TensorStructure(dtypes.string, [])
@ -408,7 +408,7 @@ class _FixedLengthRecordDataset(dataset_ops.DatasetSource):
super(_FixedLengthRecordDataset, self).__init__(variant_tensor) super(_FixedLengthRecordDataset, self).__init__(variant_tensor)
@property @property
def _element_structure(self): def element_spec(self):
return structure.TensorStructure(dtypes.string, []) return structure.TensorStructure(dtypes.string, [])
@ -466,7 +466,7 @@ class FixedLengthRecordDatasetV2(dataset_ops.DatasetSource):
super(FixedLengthRecordDatasetV2, self).__init__(variant_tensor) super(FixedLengthRecordDatasetV2, self).__init__(variant_tensor)
@property @property
def _element_structure(self): def element_spec(self):
return structure.TensorStructure(dtypes.string, []) return structure.TensorStructure(dtypes.string, [])
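Every reader in this file yields scalar strings, so each `element_spec` is the same `TensorStructure(dtypes.string, [])`. For illustration (the file name is a placeholder):

ds = tf.data.TFRecordDataset(["records.tfrecord"])
print(ds.element_spec)  # scalar tf.string spec: shape=(), dtype=tf.string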


@ -480,7 +480,7 @@ class DistributedDataset(_IterableInput):
self._input_workers = input_workers self._input_workers = input_workers
# TODO(anjalisridhar): Identify if we need to set this property on the # TODO(anjalisridhar): Identify if we need to set this property on the
# iterator. # iterator.
self._element_structure = dataset._element_structure # pylint: disable=protected-access self._element_spec = dataset.element_spec
self._strategy = strategy self._strategy = strategy
def __iter__(self): def __iter__(self):
@ -490,7 +490,7 @@ class DistributedDataset(_IterableInput):
self._input_workers) self._input_workers)
iterator = DistributedIterator(self._input_workers, worker_iterators, iterator = DistributedIterator(self._input_workers, worker_iterators,
self._strategy) self._strategy)
iterator._element_structure = self._element_structure # pylint: disable=protected-access iterator.element_spec = self._element_spec
return iterator return iterator
raise RuntimeError("__iter__() is only supported inside of tf.function " raise RuntimeError("__iter__() is only supported inside of tf.function "
"or when eager execution is enabled.") "or when eager execution is enabled.")
@ -537,7 +537,7 @@ class DistributedDatasetV1(DistributedDataset):
self._input_workers) self._input_workers)
iterator = DistributedIteratorV1(self._input_workers, worker_iterators, iterator = DistributedIteratorV1(self._input_workers, worker_iterators,
self._strategy) self._strategy)
iterator._element_structure = self._element_structure # pylint: disable=protected-access iterator.element_spec = self._element_spec
return iterator return iterator
@ -670,9 +670,9 @@ class DatasetIterator(DistributedIteratorV1):
dist_dataset._cloned_datasets, input_workers) # pylint: disable=protected-access dist_dataset._cloned_datasets, input_workers) # pylint: disable=protected-access
super(DatasetIterator, self).__init__( super(DatasetIterator, self).__init__(
input_workers, input_workers,
worker_iterators, # pylint: disable=protected-access worker_iterators,
strategy) strategy)
self._element_structure = dist_dataset._element_structure # pylint: disable=protected-access self._element_spec = dist_dataset._element_spec # pylint: disable=protected-access
def _dummy_tensor_fn(value_structure): def _dummy_tensor_fn(value_structure):


@ -56,8 +56,7 @@ def _clone_dataset(dataset):
variant_tensor_ops = traverse.obtain_all_variant_tensor_ops(dataset) variant_tensor_ops = traverse.obtain_all_variant_tensor_ops(dataset)
remap_dict = _clone_helper(dataset._variant_tensor.op, variant_tensor_ops) remap_dict = _clone_helper(dataset._variant_tensor.op, variant_tensor_ops)
new_variant_tensor = remap_dict[dataset._variant_tensor.op].outputs[0] new_variant_tensor = remap_dict[dataset._variant_tensor.op].outputs[0]
return dataset_ops._VariantDataset(new_variant_tensor, return dataset_ops._VariantDataset(new_variant_tensor, dataset.element_spec)
dataset._element_structure)
def _get_op_def(op): def _get_op_def(op):


@ -276,8 +276,7 @@ class CloneDatasetTest(test.TestCase):
def _assert_datasets_equal(self, ds1, ds2): def _assert_datasets_equal(self, ds1, ds2):
# First lets assert the structure is the same. # First lets assert the structure is the same.
self.assertTrue( self.assertTrue(
structure.are_compatible(ds1._element_structure, structure.are_compatible(ds1.element_spec, ds2.element_spec))
ds2._element_structure))
# Now create iterators on both and assert they produce the same values. # Now create iterators on both and assert they produce the same values.
it1 = dataset_ops.make_initializable_iterator(ds1) it1 = dataset_ops.make_initializable_iterator(ds1)
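The same compatibility check works outside the test; a minimal sketch using the internal `structure` utility the test imports:

from tensorflow.python.data.util import structure

ds1 = tf.data.Dataset.range(3)
ds2 = tf.data.Dataset.range(5)
assert structure.are_compatible(ds1.element_spec, ds2.element_spec)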


@ -5,6 +5,10 @@ tf_class {
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>" is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>" is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
is_instance: "<type \'object\'>" is_instance: "<type \'object\'>"
member {
name: "element_spec"
mtype: "<type \'property\'>"
}
member { member {
name: "output_classes" name: "output_classes"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"


@ -7,6 +7,10 @@ tf_class {
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>" is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>" is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
is_instance: "<type \'object\'>" is_instance: "<type \'object\'>"
member {
name: "element_spec"
mtype: "<type \'property\'>"
}
member { member {
name: "output_classes" name: "output_classes"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"


@ -3,6 +3,10 @@ tf_class {
is_instance: "<class \'tensorflow.python.data.ops.iterator_ops.Iterator\'>" is_instance: "<class \'tensorflow.python.data.ops.iterator_ops.Iterator\'>"
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>" is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
is_instance: "<type \'object\'>" is_instance: "<type \'object\'>"
member {
name: "element_spec"
mtype: "<type \'property\'>"
}
member { member {
name: "initializer" name: "initializer"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"


@ -7,6 +7,10 @@ tf_class {
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>" is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>" is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
is_instance: "<type \'object\'>" is_instance: "<type \'object\'>"
member {
name: "element_spec"
mtype: "<type \'property\'>"
}
member { member {
name: "output_classes" name: "output_classes"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"


@ -7,6 +7,10 @@ tf_class {
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>" is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>" is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
is_instance: "<type \'object\'>" is_instance: "<type \'object\'>"
member {
name: "element_spec"
mtype: "<type \'property\'>"
}
member { member {
name: "output_classes" name: "output_classes"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"


@ -7,6 +7,10 @@ tf_class {
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>" is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>" is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
is_instance: "<type \'object\'>" is_instance: "<type \'object\'>"
member {
name: "element_spec"
mtype: "<type \'property\'>"
}
member { member {
name: "output_classes" name: "output_classes"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"


@ -7,6 +7,10 @@ tf_class {
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>" is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>" is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
is_instance: "<type \'object\'>" is_instance: "<type \'object\'>"
member {
name: "element_spec"
mtype: "<type \'property\'>"
}
member { member {
name: "output_classes" name: "output_classes"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"


@ -7,6 +7,10 @@ tf_class {
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>" is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>" is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
is_instance: "<type \'object\'>" is_instance: "<type \'object\'>"
member {
name: "element_spec"
mtype: "<type \'property\'>"
}
member { member {
name: "output_classes" name: "output_classes"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"


@ -4,6 +4,10 @@ tf_class {
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>" is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>" is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
is_instance: "<type \'object\'>" is_instance: "<type \'object\'>"
member {
name: "element_spec"
mtype: "<class \'abc.abstractproperty\'>"
}
member_method { member_method {
name: "__init__" name: "__init__"
argspec: "args=[\'self\', \'variant_tensor\'], varargs=None, keywords=None, defaults=None" argspec: "args=[\'self\', \'variant_tensor\'], varargs=None, keywords=None, defaults=None"


@ -6,6 +6,10 @@ tf_class {
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>" is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>" is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
is_instance: "<type \'object\'>" is_instance: "<type \'object\'>"
member {
name: "element_spec"
mtype: "<type \'property\'>"
}
member_method { member_method {
name: "__init__" name: "__init__"
argspec: "args=[\'self\', \'filenames\', \'record_bytes\', \'header_bytes\', \'footer_bytes\', \'buffer_size\', \'compression_type\', \'num_parallel_reads\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], " argspec: "args=[\'self\', \'filenames\', \'record_bytes\', \'header_bytes\', \'footer_bytes\', \'buffer_size\', \'compression_type\', \'num_parallel_reads\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], "


@ -5,6 +5,10 @@ tf_class {
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>" is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>" is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
is_instance: "<type \'object\'>" is_instance: "<type \'object\'>"
member {
name: "element_spec"
mtype: "<type \'property\'>"
}
member_method { member_method {
name: "__init__" name: "__init__"
argspec: "args=[\'self\', \'filenames\', \'compression_type\', \'buffer_size\', \'num_parallel_reads\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " argspec: "args=[\'self\', \'filenames\', \'compression_type\', \'buffer_size\', \'num_parallel_reads\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "


@ -6,6 +6,10 @@ tf_class {
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>" is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>" is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
is_instance: "<type \'object\'>" is_instance: "<type \'object\'>"
member {
name: "element_spec"
mtype: "<type \'property\'>"
}
member_method { member_method {
name: "__init__" name: "__init__"
argspec: "args=[\'self\', \'filenames\', \'compression_type\', \'buffer_size\', \'num_parallel_reads\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " argspec: "args=[\'self\', \'filenames\', \'compression_type\', \'buffer_size\', \'num_parallel_reads\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "


@ -6,6 +6,10 @@ tf_class {
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>" is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>" is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
is_instance: "<type \'object\'>" is_instance: "<type \'object\'>"
member {
name: "element_spec"
mtype: "<type \'property\'>"
}
member_method { member_method {
name: "__init__" name: "__init__"
argspec: "args=[\'self\', \'filenames\', \'record_defaults\', \'compression_type\', \'buffer_size\', \'header\', \'field_delim\', \'use_quote_delim\', \'na_value\', \'select_cols\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\', \',\', \'True\', \'\', \'None\'], " argspec: "args=[\'self\', \'filenames\', \'record_defaults\', \'compression_type\', \'buffer_size\', \'header\', \'field_delim\', \'use_quote_delim\', \'na_value\', \'select_cols\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\', \',\', \'True\', \'\', \'None\'], "


@ -6,6 +6,10 @@ tf_class {
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>" is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>" is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
is_instance: "<type \'object\'>" is_instance: "<type \'object\'>"
member {
name: "element_spec"
mtype: "<type \'property\'>"
}
member_method { member_method {
name: "__init__" name: "__init__"
argspec: "args=[\'self\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\'], " argspec: "args=[\'self\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\'], "


@ -6,6 +6,10 @@ tf_class {
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>" is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>" is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
is_instance: "<type \'object\'>" is_instance: "<type \'object\'>"
member {
name: "element_spec"
mtype: "<type \'property\'>"
}
member_method { member_method {
name: "__init__" name: "__init__"
argspec: "args=[\'self\', \'driver_name\', \'data_source_name\', \'query\', \'output_types\'], varargs=None, keywords=None, defaults=None" argspec: "args=[\'self\', \'driver_name\', \'data_source_name\', \'query\', \'output_types\'], varargs=None, keywords=None, defaults=None"