[tf.data] Exposing dataset / iterator element type specification in the public API as element_spec
.
This CL renames the pre-existing `_element_structure` property of tf.data datasets and iterators to `element_spec`, thus exposing it in the public API. PiperOrigin-RevId: 256201202
This commit is contained in:
parent
ff67edeac1
commit
9db44da931
@ -591,7 +591,7 @@ class _BigtableKeyDataset(dataset_ops.DatasetSource):
|
|||||||
self._table = table
|
self._table = table
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _element_structure(self):
|
def element_spec(self):
|
||||||
return structure.TensorStructure(dtypes.string, [])
|
return structure.TensorStructure(dtypes.string, [])
|
||||||
|
|
||||||
|
|
||||||
@ -652,7 +652,7 @@ class _BigtableLookupDataset(dataset_ops.DatasetSource):
|
|||||||
super(_BigtableLookupDataset, self).__init__(variant_tensor)
|
super(_BigtableLookupDataset, self).__init__(variant_tensor)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _element_structure(self):
|
def element_spec(self):
|
||||||
return tuple([structure.TensorStructure(dtypes.string, [])] *
|
return tuple([structure.TensorStructure(dtypes.string, [])] *
|
||||||
self._num_outputs)
|
self._num_outputs)
|
||||||
|
|
||||||
@ -681,7 +681,7 @@ class _BigtableScanDataset(dataset_ops.DatasetSource):
|
|||||||
super(_BigtableScanDataset, self).__init__(variant_tensor)
|
super(_BigtableScanDataset, self).__init__(variant_tensor)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _element_structure(self):
|
def element_spec(self):
|
||||||
return tuple([structure.TensorStructure(dtypes.string, [])] *
|
return tuple([structure.TensorStructure(dtypes.string, [])] *
|
||||||
self._num_outputs)
|
self._num_outputs)
|
||||||
|
|
||||||
@ -703,6 +703,6 @@ class _BigtableSampleKeyPairsDataset(dataset_ops.DatasetSource):
|
|||||||
super(_BigtableSampleKeyPairsDataset, self).__init__(variant_tensor)
|
super(_BigtableSampleKeyPairsDataset, self).__init__(variant_tensor)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _element_structure(self):
|
def element_spec(self):
|
||||||
return (structure.TensorStructure(dtypes.string, []),
|
return (structure.TensorStructure(dtypes.string, []),
|
||||||
structure.TensorStructure(dtypes.string, []))
|
structure.TensorStructure(dtypes.string, []))
|
||||||
|
@ -346,13 +346,13 @@ class _RestructuredDataset(dataset_ops.UnaryDataset):
|
|||||||
output_classes = nest.pack_sequence_as(
|
output_classes = nest.pack_sequence_as(
|
||||||
output_types, nest.flatten(input_classes))
|
output_types, nest.flatten(input_classes))
|
||||||
|
|
||||||
self._structure = structure.convert_legacy_structure(
|
self._element_spec = structure.convert_legacy_structure(
|
||||||
output_types, output_shapes, output_classes)
|
output_types, output_shapes, output_classes)
|
||||||
variant_tensor = self._input_dataset._variant_tensor # pylint: disable=protected-access
|
variant_tensor = self._input_dataset._variant_tensor # pylint: disable=protected-access
|
||||||
super(_RestructuredDataset, self).__init__(dataset, variant_tensor)
|
super(_RestructuredDataset, self).__init__(dataset, variant_tensor)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _element_structure(self):
|
def element_spec(self):
|
||||||
return self._structure
|
return self._element_spec
|
||||||
|
|
||||||
|
|
||||||
|
@ -395,6 +395,6 @@ class LMDBDataset(dataset_ops.DatasetSource):
|
|||||||
super(LMDBDataset, self).__init__(variant_tensor)
|
super(LMDBDataset, self).__init__(variant_tensor)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _element_structure(self):
|
def element_spec(self):
|
||||||
return (structure.TensorStructure(dtypes.string, []),
|
return (structure.TensorStructure(dtypes.string, []),
|
||||||
structure.TensorStructure(dtypes.string, []))
|
structure.TensorStructure(dtypes.string, []))
|
||||||
|
@ -39,7 +39,7 @@ class _SlideDataset(dataset_ops.UnaryDataset):
|
|||||||
window_shift, dtype=dtypes.int64, name="window_shift")
|
window_shift, dtype=dtypes.int64, name="window_shift")
|
||||||
|
|
||||||
input_structure = dataset_ops.get_structure(input_dataset)
|
input_structure = dataset_ops.get_structure(input_dataset)
|
||||||
self._structure = nest.map_structure(
|
self._element_spec = nest.map_structure(
|
||||||
lambda component_spec: component_spec._batch(None), input_structure) # pylint: disable=protected-access
|
lambda component_spec: component_spec._batch(None), input_structure) # pylint: disable=protected-access
|
||||||
variant_tensor = ged_ops.experimental_sliding_window_dataset(
|
variant_tensor = ged_ops.experimental_sliding_window_dataset(
|
||||||
self._input_dataset._variant_tensor, # pylint: disable=protected-access
|
self._input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||||
@ -50,8 +50,8 @@ class _SlideDataset(dataset_ops.UnaryDataset):
|
|||||||
super(_SlideDataset, self).__init__(input_dataset, variant_tensor)
|
super(_SlideDataset, self).__init__(input_dataset, variant_tensor)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _element_structure(self):
|
def element_spec(self):
|
||||||
return self._structure
|
return self._element_spec
|
||||||
|
|
||||||
|
|
||||||
@deprecation.deprecated_args(
|
@deprecation.deprecated_args(
|
||||||
|
@ -58,11 +58,10 @@ class SequenceFileDataset(dataset_ops.DatasetSource):
|
|||||||
self._filenames = ops.convert_to_tensor(
|
self._filenames = ops.convert_to_tensor(
|
||||||
filenames, dtype=dtypes.string, name="filenames")
|
filenames, dtype=dtypes.string, name="filenames")
|
||||||
variant_tensor = gen_dataset_ops.sequence_file_dataset(
|
variant_tensor = gen_dataset_ops.sequence_file_dataset(
|
||||||
self._filenames,
|
self._filenames, self._flat_types)
|
||||||
structure.get_flat_tensor_types(self._element_structure))
|
|
||||||
super(SequenceFileDataset, self).__init__(variant_tensor)
|
super(SequenceFileDataset, self).__init__(variant_tensor)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _element_structure(self):
|
def element_spec(self):
|
||||||
return (structure.TensorStructure(dtypes.string, []),
|
return (structure.TensorStructure(dtypes.string, []),
|
||||||
structure.TensorStructure(dtypes.string, []))
|
structure.TensorStructure(dtypes.string, []))
|
||||||
|
@ -754,7 +754,7 @@ class IgniteDataset(dataset_ops.DatasetSource):
|
|||||||
self.cache_type.to_permutation(),
|
self.cache_type.to_permutation(),
|
||||||
dtype=dtypes.int32,
|
dtype=dtypes.int32,
|
||||||
name="permutation")
|
name="permutation")
|
||||||
self._structure = structure.convert_legacy_structure(
|
self._element_spec = structure.convert_legacy_structure(
|
||||||
self.cache_type.to_output_types(), self.cache_type.to_output_shapes(),
|
self.cache_type.to_output_types(), self.cache_type.to_output_shapes(),
|
||||||
self.cache_type.to_output_classes())
|
self.cache_type.to_output_classes())
|
||||||
|
|
||||||
@ -766,5 +766,5 @@ class IgniteDataset(dataset_ops.DatasetSource):
|
|||||||
self.schema, self.permutation)
|
self.schema, self.permutation)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _element_structure(self):
|
def element_spec(self):
|
||||||
return self._structure
|
return self._element_spec
|
||||||
|
@ -69,5 +69,5 @@ class KafkaDataset(dataset_ops.DatasetSource):
|
|||||||
self._group, self._eof, self._timeout)
|
self._group, self._eof, self._timeout)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _element_structure(self):
|
def element_spec(self):
|
||||||
return structure.TensorStructure(dtypes.string, [])
|
return structure.TensorStructure(dtypes.string, [])
|
||||||
|
@ -86,5 +86,5 @@ class KinesisDataset(dataset_ops.DatasetSource):
|
|||||||
self._stream, self._shard, self._read_indefinitely, self._interval)
|
self._stream, self._shard, self._read_indefinitely, self._interval)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _element_structure(self):
|
def element_spec(self):
|
||||||
return structure.TensorStructure(dtypes.string, [])
|
return structure.TensorStructure(dtypes.string, [])
|
||||||
|
@ -37,7 +37,7 @@ class WrapDatasetVariantTest(test_base.DatasetTestBase):
|
|||||||
unwrapped_variant = gen_dataset_ops.unwrap_dataset_variant(wrapped_variant)
|
unwrapped_variant = gen_dataset_ops.unwrap_dataset_variant(wrapped_variant)
|
||||||
|
|
||||||
variant_ds = dataset_ops._VariantDataset(unwrapped_variant,
|
variant_ds = dataset_ops._VariantDataset(unwrapped_variant,
|
||||||
ds._element_structure)
|
ds.element_spec)
|
||||||
get_next = self.getNext(variant_ds, requires_initialization=True)
|
get_next = self.getNext(variant_ds, requires_initialization=True)
|
||||||
for i in range(100):
|
for i in range(100):
|
||||||
self.assertEqual(i, self.evaluate(get_next()))
|
self.assertEqual(i, self.evaluate(get_next()))
|
||||||
@ -54,7 +54,7 @@ class WrapDatasetVariantTest(test_base.DatasetTestBase):
|
|||||||
unwrapped_variant = gen_dataset_ops.unwrap_dataset_variant(
|
unwrapped_variant = gen_dataset_ops.unwrap_dataset_variant(
|
||||||
gpu_wrapped_variant)
|
gpu_wrapped_variant)
|
||||||
variant_ds = dataset_ops._VariantDataset(unwrapped_variant,
|
variant_ds = dataset_ops._VariantDataset(unwrapped_variant,
|
||||||
ds._element_structure)
|
ds.element_spec)
|
||||||
iterator = dataset_ops.make_initializable_iterator(variant_ds)
|
iterator = dataset_ops.make_initializable_iterator(variant_ds)
|
||||||
get_next = iterator.get_next()
|
get_next = iterator.get_next()
|
||||||
|
|
||||||
|
@ -242,7 +242,7 @@ class _DenseToSparseBatchDataset(dataset_ops.UnaryDataset):
|
|||||||
self._input_dataset = input_dataset
|
self._input_dataset = input_dataset
|
||||||
self._batch_size = batch_size
|
self._batch_size = batch_size
|
||||||
self._row_shape = row_shape
|
self._row_shape = row_shape
|
||||||
self._structure = structure.SparseTensorStructure(
|
self._element_spec = structure.SparseTensorStructure(
|
||||||
dataset_ops.get_legacy_output_types(input_dataset),
|
dataset_ops.get_legacy_output_types(input_dataset),
|
||||||
tensor_shape.vector(None).concatenate(self._row_shape))
|
tensor_shape.vector(None).concatenate(self._row_shape))
|
||||||
|
|
||||||
@ -255,8 +255,8 @@ class _DenseToSparseBatchDataset(dataset_ops.UnaryDataset):
|
|||||||
variant_tensor)
|
variant_tensor)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _element_structure(self):
|
def element_spec(self):
|
||||||
return self._structure
|
return self._element_spec
|
||||||
|
|
||||||
|
|
||||||
class _MapAndBatchDataset(dataset_ops.UnaryDataset):
|
class _MapAndBatchDataset(dataset_ops.UnaryDataset):
|
||||||
@ -285,12 +285,12 @@ class _MapAndBatchDataset(dataset_ops.UnaryDataset):
|
|||||||
# NOTE(mrry): `constant_drop_remainder` may be `None` (unknown statically)
|
# NOTE(mrry): `constant_drop_remainder` may be `None` (unknown statically)
|
||||||
# or `False` (explicitly retaining the remainder).
|
# or `False` (explicitly retaining the remainder).
|
||||||
# pylint: disable=g-long-lambda
|
# pylint: disable=g-long-lambda
|
||||||
self._structure = nest.map_structure(
|
self._element_spec = nest.map_structure(
|
||||||
lambda component_spec: component_spec._batch(
|
lambda component_spec: component_spec._batch(
|
||||||
tensor_util.constant_value(self._batch_size_t)),
|
tensor_util.constant_value(self._batch_size_t)),
|
||||||
self._map_func.output_structure)
|
self._map_func.output_structure)
|
||||||
else:
|
else:
|
||||||
self._structure = nest.map_structure(
|
self._element_spec = nest.map_structure(
|
||||||
lambda component_spec: component_spec._batch(None),
|
lambda component_spec: component_spec._batch(None),
|
||||||
self._map_func.output_structure)
|
self._map_func.output_structure)
|
||||||
# pylint: enable=protected-access
|
# pylint: enable=protected-access
|
||||||
@ -309,5 +309,5 @@ class _MapAndBatchDataset(dataset_ops.UnaryDataset):
|
|||||||
return [self._map_func]
|
return [self._map_func]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _element_structure(self):
|
def element_spec(self):
|
||||||
return self._structure
|
return self._element_spec
|
||||||
|
@ -46,7 +46,7 @@ class _AutoShardDataset(dataset_ops.UnaryDataset):
|
|||||||
def __init__(self, input_dataset, num_workers, index):
|
def __init__(self, input_dataset, num_workers, index):
|
||||||
self._input_dataset = input_dataset
|
self._input_dataset = input_dataset
|
||||||
|
|
||||||
self._structure = input_dataset._element_structure # pylint: disable=protected-access
|
self._element_spec = input_dataset.element_spec
|
||||||
variant_tensor = ged_ops.experimental_auto_shard_dataset(
|
variant_tensor = ged_ops.experimental_auto_shard_dataset(
|
||||||
self._input_dataset._variant_tensor, # pylint: disable=protected-access
|
self._input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||||
num_workers=num_workers,
|
num_workers=num_workers,
|
||||||
@ -55,8 +55,8 @@ class _AutoShardDataset(dataset_ops.UnaryDataset):
|
|||||||
super(_AutoShardDataset, self).__init__(input_dataset, variant_tensor)
|
super(_AutoShardDataset, self).__init__(input_dataset, variant_tensor)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _element_structure(self):
|
def element_spec(self):
|
||||||
return self._structure
|
return self._element_spec
|
||||||
|
|
||||||
|
|
||||||
def _AutoShardDatasetV1(input_dataset, num_workers, index):
|
def _AutoShardDatasetV1(input_dataset, num_workers, index):
|
||||||
@ -85,7 +85,7 @@ class _RebatchDataset(dataset_ops.UnaryDataset):
|
|||||||
input_classes = dataset_ops.get_legacy_output_classes(self._input_dataset)
|
input_classes = dataset_ops.get_legacy_output_classes(self._input_dataset)
|
||||||
output_shapes = nest.map_structure(recalculate_output_shapes, input_shapes)
|
output_shapes = nest.map_structure(recalculate_output_shapes, input_shapes)
|
||||||
|
|
||||||
self._structure = structure.convert_legacy_structure(
|
self._element_spec = structure.convert_legacy_structure(
|
||||||
input_types, output_shapes, input_classes)
|
input_types, output_shapes, input_classes)
|
||||||
variant_tensor = ged_ops.experimental_rebatch_dataset(
|
variant_tensor = ged_ops.experimental_rebatch_dataset(
|
||||||
self._input_dataset._variant_tensor, # pylint: disable=protected-access
|
self._input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||||
@ -94,8 +94,8 @@ class _RebatchDataset(dataset_ops.UnaryDataset):
|
|||||||
super(_RebatchDataset, self).__init__(input_dataset, variant_tensor)
|
super(_RebatchDataset, self).__init__(input_dataset, variant_tensor)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _element_structure(self):
|
def element_spec(self):
|
||||||
return self._structure
|
return self._element_spec
|
||||||
|
|
||||||
|
|
||||||
_AutoShardDatasetV1.__doc__ = _AutoShardDataset.__doc__
|
_AutoShardDatasetV1.__doc__ = _AutoShardDataset.__doc__
|
||||||
|
@ -65,6 +65,6 @@ def get_single_element(dataset):
|
|||||||
|
|
||||||
# pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
return structure.from_compatible_tensor_list(
|
return structure.from_compatible_tensor_list(
|
||||||
dataset._element_structure,
|
dataset.element_spec,
|
||||||
gen_dataset_ops.dataset_to_single_element(
|
gen_dataset_ops.dataset_to_single_element(
|
||||||
dataset._variant_tensor, **dataset._flat_structure)) # pylint: disable=protected-access
|
dataset._variant_tensor, **dataset._flat_structure)) # pylint: disable=protected-access
|
||||||
|
@ -299,8 +299,7 @@ class _GroupByReducerDataset(dataset_ops.UnaryDataset):
|
|||||||
wrapped_func = dataset_ops.StructuredFunctionWrapper(
|
wrapped_func = dataset_ops.StructuredFunctionWrapper(
|
||||||
reduce_func,
|
reduce_func,
|
||||||
self._transformation_name(),
|
self._transformation_name(),
|
||||||
input_structure=(self._state_structure,
|
input_structure=(self._state_structure, input_dataset.element_spec),
|
||||||
input_dataset._element_structure), # pylint: disable=protected-access
|
|
||||||
add_to_graph=False)
|
add_to_graph=False)
|
||||||
|
|
||||||
# Extract and validate class information from the returned values.
|
# Extract and validate class information from the returned values.
|
||||||
@ -355,7 +354,7 @@ class _GroupByReducerDataset(dataset_ops.UnaryDataset):
|
|||||||
input_structure=self._state_structure)
|
input_structure=self._state_structure)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _element_structure(self):
|
def element_spec(self):
|
||||||
return self._finalize_func.output_structure
|
return self._finalize_func.output_structure
|
||||||
|
|
||||||
def _functions(self):
|
def _functions(self):
|
||||||
@ -416,7 +415,7 @@ class _GroupByWindowDataset(dataset_ops.UnaryDataset):
|
|||||||
def _make_reduce_func(self, reduce_func, input_dataset):
|
def _make_reduce_func(self, reduce_func, input_dataset):
|
||||||
"""Make wrapping defun for reduce_func."""
|
"""Make wrapping defun for reduce_func."""
|
||||||
nested_dataset = dataset_ops.DatasetStructure(
|
nested_dataset = dataset_ops.DatasetStructure(
|
||||||
input_dataset._element_structure) # pylint: disable=protected-access
|
input_dataset.element_spec)
|
||||||
input_structure = (structure.TensorStructure(dtypes.int64,
|
input_structure = (structure.TensorStructure(dtypes.int64,
|
||||||
[]), nested_dataset)
|
[]), nested_dataset)
|
||||||
self._reduce_func = dataset_ops.StructuredFunctionWrapper(
|
self._reduce_func = dataset_ops.StructuredFunctionWrapper(
|
||||||
@ -426,12 +425,12 @@ class _GroupByWindowDataset(dataset_ops.UnaryDataset):
|
|||||||
self._reduce_func.output_structure, dataset_ops.DatasetStructure):
|
self._reduce_func.output_structure, dataset_ops.DatasetStructure):
|
||||||
raise TypeError("`reduce_func` must return a `Dataset` object.")
|
raise TypeError("`reduce_func` must return a `Dataset` object.")
|
||||||
# pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
self._structure = (
|
self._element_spec = (
|
||||||
self._reduce_func.output_structure._element_structure)
|
self._reduce_func.output_structure._element_spec)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _element_structure(self):
|
def element_spec(self):
|
||||||
return self._structure
|
return self._element_spec
|
||||||
|
|
||||||
def _functions(self):
|
def _functions(self):
|
||||||
return [self._key_func, self._reduce_func, self._window_size_func]
|
return [self._key_func, self._reduce_func, self._window_size_func]
|
||||||
|
@ -119,7 +119,7 @@ class _DirectedInterleaveDataset(dataset_ops.Dataset):
|
|||||||
nest.flatten(dataset_ops.get_legacy_output_shapes(data_input)))
|
nest.flatten(dataset_ops.get_legacy_output_shapes(data_input)))
|
||||||
])
|
])
|
||||||
|
|
||||||
self._structure = structure.convert_legacy_structure(
|
self._element_spec = structure.convert_legacy_structure(
|
||||||
first_output_types, output_shapes, first_output_classes)
|
first_output_types, output_shapes, first_output_classes)
|
||||||
super(_DirectedInterleaveDataset, self).__init__()
|
super(_DirectedInterleaveDataset, self).__init__()
|
||||||
|
|
||||||
@ -136,8 +136,8 @@ class _DirectedInterleaveDataset(dataset_ops.Dataset):
|
|||||||
return [self._selector_input] + self._data_inputs
|
return [self._selector_input] + self._data_inputs
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _element_structure(self):
|
def element_spec(self):
|
||||||
return self._structure
|
return self._element_spec
|
||||||
|
|
||||||
|
|
||||||
@tf_export("data.experimental.sample_from_datasets", v1=[])
|
@tf_export("data.experimental.sample_from_datasets", v1=[])
|
||||||
@ -267,8 +267,8 @@ def choose_from_datasets_v2(datasets, choice_dataset):
|
|||||||
TypeError: If the `datasets` or `choice_dataset` arguments have the wrong
|
TypeError: If the `datasets` or `choice_dataset` arguments have the wrong
|
||||||
type.
|
type.
|
||||||
"""
|
"""
|
||||||
if not dataset_ops.get_structure(choice_dataset).is_compatible_with(
|
if not structure.are_compatible(choice_dataset.element_spec,
|
||||||
structure.TensorStructure(dtypes.int64, [])):
|
structure.TensorStructure(dtypes.int64, [])):
|
||||||
raise TypeError("`choice_dataset` must be a dataset of scalar "
|
raise TypeError("`choice_dataset` must be a dataset of scalar "
|
||||||
"`tf.int64` tensors.")
|
"`tf.int64` tensors.")
|
||||||
return _DirectedInterleaveDataset(choice_dataset, datasets)
|
return _DirectedInterleaveDataset(choice_dataset, datasets)
|
||||||
|
@ -35,5 +35,5 @@ class MatchingFilesDataset(dataset_ops.DatasetSource):
|
|||||||
super(MatchingFilesDataset, self).__init__(variant_tensor)
|
super(MatchingFilesDataset, self).__init__(variant_tensor)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _element_structure(self):
|
def element_spec(self):
|
||||||
return structure.TensorStructure(dtypes.string, [])
|
return structure.TensorStructure(dtypes.string, [])
|
||||||
|
@ -156,7 +156,7 @@ class _ChooseFastestDataset(dataset_ops.DatasetV2):
|
|||||||
A `Dataset` that has the same elements the inputs.
|
A `Dataset` that has the same elements the inputs.
|
||||||
"""
|
"""
|
||||||
self._datasets = list(datasets)
|
self._datasets = list(datasets)
|
||||||
self._structure = self._datasets[0]._element_structure # pylint: disable=protected-access
|
self._element_spec = self._datasets[0].element_spec
|
||||||
variant_tensor = (
|
variant_tensor = (
|
||||||
gen_experimental_dataset_ops.experimental_choose_fastest_dataset(
|
gen_experimental_dataset_ops.experimental_choose_fastest_dataset(
|
||||||
[dataset._variant_tensor for dataset in self._datasets], # pylint: disable=protected-access
|
[dataset._variant_tensor for dataset in self._datasets], # pylint: disable=protected-access
|
||||||
@ -168,8 +168,8 @@ class _ChooseFastestDataset(dataset_ops.DatasetV2):
|
|||||||
return self._datasets
|
return self._datasets
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _element_structure(self):
|
def element_spec(self):
|
||||||
return self._datasets[0]._element_structure # pylint: disable=protected-access
|
return self._element_spec
|
||||||
|
|
||||||
|
|
||||||
class _ChooseFastestBranchDataset(dataset_ops.UnaryDataset):
|
class _ChooseFastestBranchDataset(dataset_ops.UnaryDataset):
|
||||||
@ -242,14 +242,13 @@ class _ChooseFastestBranchDataset(dataset_ops.UnaryDataset):
|
|||||||
Returns:
|
Returns:
|
||||||
A `Dataset` that has the same elements the inputs.
|
A `Dataset` that has the same elements the inputs.
|
||||||
"""
|
"""
|
||||||
input_structure = dataset_ops.DatasetStructure(
|
input_structure = dataset_ops.DatasetStructure(input_dataset.element_spec)
|
||||||
dataset_ops.get_structure(input_dataset))
|
|
||||||
self._funcs = [
|
self._funcs = [
|
||||||
dataset_ops.StructuredFunctionWrapper(
|
dataset_ops.StructuredFunctionWrapper(
|
||||||
f, "ChooseFastestV2", input_structure=input_structure)
|
f, "ChooseFastestV2", input_structure=input_structure)
|
||||||
for f in functions
|
for f in functions
|
||||||
]
|
]
|
||||||
self._structure = self._funcs[0].output_structure._element_structure # pylint: disable=protected-access
|
self._element_spec = self._funcs[0].output_structure._element_spec # pylint: disable=protected-access
|
||||||
|
|
||||||
self._captured_arguments = []
|
self._captured_arguments = []
|
||||||
for f in self._funcs:
|
for f in self._funcs:
|
||||||
@ -279,5 +278,5 @@ class _ChooseFastestBranchDataset(dataset_ops.UnaryDataset):
|
|||||||
variant_tensor)
|
variant_tensor)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _element_structure(self):
|
def element_spec(self):
|
||||||
return self._structure
|
return self._element_spec
|
||||||
|
@ -32,7 +32,8 @@ class _ParseExampleDataset(dataset_ops.UnaryDataset):
|
|||||||
|
|
||||||
def __init__(self, input_dataset, features, num_parallel_calls):
|
def __init__(self, input_dataset, features, num_parallel_calls):
|
||||||
self._input_dataset = input_dataset
|
self._input_dataset = input_dataset
|
||||||
if not input_dataset._element_structure.is_compatible_with( # pylint: disable=protected-access
|
if not structure.are_compatible(
|
||||||
|
input_dataset.element_spec,
|
||||||
structure.TensorStructure(dtypes.string, [None])):
|
structure.TensorStructure(dtypes.string, [None])):
|
||||||
raise TypeError("Input dataset should be a dataset of vectors of strings")
|
raise TypeError("Input dataset should be a dataset of vectors of strings")
|
||||||
self._num_parallel_calls = num_parallel_calls
|
self._num_parallel_calls = num_parallel_calls
|
||||||
@ -75,7 +76,7 @@ class _ParseExampleDataset(dataset_ops.UnaryDataset):
|
|||||||
[ops.Tensor for _ in range(len(self._dense_defaults))] +
|
[ops.Tensor for _ in range(len(self._dense_defaults))] +
|
||||||
[sparse_tensor.SparseTensor for _ in range(len(self._sparse_keys))
|
[sparse_tensor.SparseTensor for _ in range(len(self._sparse_keys))
|
||||||
]))
|
]))
|
||||||
self._structure = structure.convert_legacy_structure(
|
self._element_spec = structure.convert_legacy_structure(
|
||||||
output_types, output_shapes, output_classes)
|
output_types, output_shapes, output_classes)
|
||||||
|
|
||||||
variant_tensor = (
|
variant_tensor = (
|
||||||
@ -91,8 +92,8 @@ class _ParseExampleDataset(dataset_ops.UnaryDataset):
|
|||||||
super(_ParseExampleDataset, self).__init__(input_dataset, variant_tensor)
|
super(_ParseExampleDataset, self).__init__(input_dataset, variant_tensor)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _element_structure(self):
|
def element_spec(self):
|
||||||
return self._structure
|
return self._element_spec
|
||||||
|
|
||||||
|
|
||||||
# TODO(b/111553342): add arguments names and example names as well.
|
# TODO(b/111553342): add arguments names and example names as well.
|
||||||
|
@ -146,8 +146,7 @@ class _CopyToDeviceDataset(dataset_ops.UnaryUnchangedStructureDataset):
|
|||||||
dataset_ops.get_legacy_output_types(self),
|
dataset_ops.get_legacy_output_types(self),
|
||||||
dataset_ops.get_legacy_output_shapes(self),
|
dataset_ops.get_legacy_output_shapes(self),
|
||||||
dataset_ops.get_legacy_output_classes(self))
|
dataset_ops.get_legacy_output_classes(self))
|
||||||
return structure.to_tensor_list(self._element_structure,
|
return structure.to_tensor_list(self.element_spec, iterator.get_next())
|
||||||
iterator.get_next())
|
|
||||||
|
|
||||||
next_func_concrete = _next_func._get_concrete_function_internal() # pylint: disable=protected-access
|
next_func_concrete = _next_func._get_concrete_function_internal() # pylint: disable=protected-access
|
||||||
|
|
||||||
@ -251,7 +250,7 @@ class _MapOnGpuDataset(dataset_ops.UnaryDataset):
|
|||||||
return [self._map_func]
|
return [self._map_func]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _element_structure(self):
|
def element_spec(self):
|
||||||
return self._map_func.output_structure
|
return self._map_func.output_structure
|
||||||
|
|
||||||
def _transformation_name(self):
|
def _transformation_name(self):
|
||||||
|
@ -39,7 +39,7 @@ class RandomDatasetV2(dataset_ops.DatasetSource):
|
|||||||
super(RandomDatasetV2, self).__init__(variant_tensor)
|
super(RandomDatasetV2, self).__init__(variant_tensor)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _element_structure(self):
|
def element_spec(self):
|
||||||
return structure.TensorStructure(dtypes.int64, [])
|
return structure.TensorStructure(dtypes.int64, [])
|
||||||
|
|
||||||
|
|
||||||
|
@ -662,7 +662,7 @@ class CsvDatasetV2(dataset_ops.DatasetSource):
|
|||||||
argument_default=[],
|
argument_default=[],
|
||||||
argument_dtype=dtypes.int64,
|
argument_dtype=dtypes.int64,
|
||||||
)
|
)
|
||||||
self._structure = tuple(
|
self._element_spec = tuple(
|
||||||
structure.TensorStructure(d.dtype, []) for d in self._record_defaults)
|
structure.TensorStructure(d.dtype, []) for d in self._record_defaults)
|
||||||
variant_tensor = gen_experimental_dataset_ops.experimental_csv_dataset(
|
variant_tensor = gen_experimental_dataset_ops.experimental_csv_dataset(
|
||||||
filenames=self._filenames,
|
filenames=self._filenames,
|
||||||
@ -678,8 +678,8 @@ class CsvDatasetV2(dataset_ops.DatasetSource):
|
|||||||
super(CsvDatasetV2, self).__init__(variant_tensor)
|
super(CsvDatasetV2, self).__init__(variant_tensor)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _element_structure(self):
|
def element_spec(self):
|
||||||
return self._structure
|
return self._element_spec
|
||||||
|
|
||||||
|
|
||||||
@tf_export(v1=["data.experimental.CsvDataset"])
|
@tf_export(v1=["data.experimental.CsvDataset"])
|
||||||
@ -955,7 +955,7 @@ class SqlDatasetV2(dataset_ops.DatasetSource):
|
|||||||
data_source_name, dtype=dtypes.string, name="data_source_name")
|
data_source_name, dtype=dtypes.string, name="data_source_name")
|
||||||
self._query = ops.convert_to_tensor(
|
self._query = ops.convert_to_tensor(
|
||||||
query, dtype=dtypes.string, name="query")
|
query, dtype=dtypes.string, name="query")
|
||||||
self._structure = nest.map_structure(
|
self._element_spec = nest.map_structure(
|
||||||
lambda dtype: structure.TensorStructure(dtype, []), output_types)
|
lambda dtype: structure.TensorStructure(dtype, []), output_types)
|
||||||
variant_tensor = gen_experimental_dataset_ops.experimental_sql_dataset(
|
variant_tensor = gen_experimental_dataset_ops.experimental_sql_dataset(
|
||||||
self._driver_name, self._data_source_name, self._query,
|
self._driver_name, self._data_source_name, self._query,
|
||||||
@ -963,8 +963,8 @@ class SqlDatasetV2(dataset_ops.DatasetSource):
|
|||||||
super(SqlDatasetV2, self).__init__(variant_tensor)
|
super(SqlDatasetV2, self).__init__(variant_tensor)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _element_structure(self):
|
def element_spec(self):
|
||||||
return self._structure
|
return self._element_spec
|
||||||
|
|
||||||
|
|
||||||
@tf_export(v1=["data.experimental.SqlDataset"])
|
@tf_export(v1=["data.experimental.SqlDataset"])
|
||||||
|
@ -49,7 +49,7 @@ class _ScanDataset(dataset_ops.UnaryDataset):
|
|||||||
scan_func,
|
scan_func,
|
||||||
self._transformation_name(),
|
self._transformation_name(),
|
||||||
input_structure=(self._state_structure,
|
input_structure=(self._state_structure,
|
||||||
input_dataset._element_structure), # pylint: disable=protected-access
|
input_dataset.element_spec),
|
||||||
add_to_graph=False)
|
add_to_graph=False)
|
||||||
if not (
|
if not (
|
||||||
isinstance(wrapped_func.output_types, collections.Sequence) and
|
isinstance(wrapped_func.output_types, collections.Sequence) and
|
||||||
@ -91,7 +91,7 @@ class _ScanDataset(dataset_ops.UnaryDataset):
|
|||||||
old_state_shapes = nest.map_structure(
|
old_state_shapes = nest.map_structure(
|
||||||
lambda component_spec: component_spec._to_legacy_output_shapes(), # pylint: disable=protected-access
|
lambda component_spec: component_spec._to_legacy_output_shapes(), # pylint: disable=protected-access
|
||||||
self._state_structure)
|
self._state_structure)
|
||||||
self._structure = structure.convert_legacy_structure(
|
self._element_spec = structure.convert_legacy_structure(
|
||||||
output_types, output_shapes, output_classes)
|
output_types, output_shapes, output_classes)
|
||||||
|
|
||||||
flat_state_shapes = nest.flatten(old_state_shapes)
|
flat_state_shapes = nest.flatten(old_state_shapes)
|
||||||
@ -135,8 +135,8 @@ class _ScanDataset(dataset_ops.UnaryDataset):
|
|||||||
return [self._scan_func]
|
return [self._scan_func]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _element_structure(self):
|
def element_spec(self):
|
||||||
return self._structure
|
return self._element_spec
|
||||||
|
|
||||||
def _transformation_name(self):
|
def _transformation_name(self):
|
||||||
return "tf.data.experimental.scan()"
|
return "tf.data.experimental.scan()"
|
||||||
|
@ -43,21 +43,6 @@ from tensorflow.python.platform import test
|
|||||||
from tensorflow.python.platform import tf_logging as logging
|
from tensorflow.python.platform import tf_logging as logging
|
||||||
|
|
||||||
|
|
||||||
# For testing deserialization of Datasets represented as functions
|
|
||||||
class _RevivedDataset(dataset_ops.DatasetV2):
|
|
||||||
|
|
||||||
def __init__(self, variant, element_structure):
|
|
||||||
self._structure = element_structure
|
|
||||||
super(_RevivedDataset, self).__init__(variant)
|
|
||||||
|
|
||||||
def _inputs(self):
|
|
||||||
return []
|
|
||||||
|
|
||||||
@property
|
|
||||||
def _element_structure(self):
|
|
||||||
return self._structure
|
|
||||||
|
|
||||||
|
|
||||||
@test_util.run_all_in_graph_and_eager_modes
|
@test_util.run_all_in_graph_and_eager_modes
|
||||||
class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||||
|
|
||||||
@ -75,8 +60,8 @@ class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
fn = original_dataset._trace_variant_creation()
|
fn = original_dataset._trace_variant_creation()
|
||||||
variant = fn()
|
variant = fn()
|
||||||
|
|
||||||
revived_dataset = _RevivedDataset(
|
revived_dataset = dataset_ops._VariantDataset(
|
||||||
variant, original_dataset._element_structure)
|
variant, original_dataset.element_spec)
|
||||||
self.assertDatasetProduces(revived_dataset, range(0, 10, 2))
|
self.assertDatasetProduces(revived_dataset, range(0, 10, 2))
|
||||||
|
|
||||||
def testAsFunctionWithMapInFlatMap(self):
|
def testAsFunctionWithMapInFlatMap(self):
|
||||||
@ -88,8 +73,8 @@ class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
fn = original_dataset._trace_variant_creation()
|
fn = original_dataset._trace_variant_creation()
|
||||||
variant = fn()
|
variant = fn()
|
||||||
|
|
||||||
revived_dataset = _RevivedDataset(
|
revived_dataset = dataset_ops._VariantDataset(
|
||||||
variant, original_dataset._element_structure)
|
variant, original_dataset.element_spec)
|
||||||
self.assertDatasetProduces(revived_dataset, list(original_dataset))
|
self.assertDatasetProduces(revived_dataset, list(original_dataset))
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@ -315,14 +315,14 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor):
|
|||||||
"or when eager execution is enabled.")
|
"or when eager execution is enabled.")
|
||||||
|
|
||||||
@abc.abstractproperty
|
@abc.abstractproperty
|
||||||
def _element_structure(self):
|
def element_spec(self):
|
||||||
"""The structure of an element of this dataset.
|
"""The type specification of an element of this dataset.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A nested structure of `tf.TypeSpec` objects matching the structure of an
|
A nested structure of `tf.TypeSpec` objects matching the structure of an
|
||||||
element of this dataset and specifying the type of individual components.
|
element of this dataset and specifying the type of individual components.
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError("Dataset._element_structure")
|
raise NotImplementedError("Dataset.element_spec")
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
output_shapes = nest.map_structure(str, get_legacy_output_shapes(self))
|
output_shapes = nest.map_structure(str, get_legacy_output_shapes(self))
|
||||||
@ -339,7 +339,7 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor):
|
|||||||
Returns:
|
Returns:
|
||||||
A list `tf.TensorShapes`s for the element tensor representation.
|
A list `tf.TensorShapes`s for the element tensor representation.
|
||||||
"""
|
"""
|
||||||
return structure.get_flat_tensor_shapes(self._element_structure)
|
return structure.get_flat_tensor_shapes(self.element_spec)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _flat_types(self):
|
def _flat_types(self):
|
||||||
@ -348,7 +348,7 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor):
|
|||||||
Returns:
|
Returns:
|
||||||
A list `tf.DType`s for the element tensor representation.
|
A list `tf.DType`s for the element tensor representation.
|
||||||
"""
|
"""
|
||||||
return structure.get_flat_tensor_types(self._element_structure)
|
return structure.get_flat_tensor_types(self.element_spec)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _flat_structure(self):
|
def _flat_structure(self):
|
||||||
@ -371,7 +371,7 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def _type_spec(self):
|
def _type_spec(self):
|
||||||
return DatasetStructure(self._element_structure)
|
return DatasetStructure(self.element_spec)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_tensors(tensors):
|
def from_tensors(tensors):
|
||||||
@ -1442,7 +1442,7 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor):
|
|||||||
wrapped_func = StructuredFunctionWrapper(
|
wrapped_func = StructuredFunctionWrapper(
|
||||||
reduce_func,
|
reduce_func,
|
||||||
"reduce()",
|
"reduce()",
|
||||||
input_structure=(state_structure, self._element_structure),
|
input_structure=(state_structure, self.element_spec),
|
||||||
add_to_graph=False)
|
add_to_graph=False)
|
||||||
|
|
||||||
# Extract and validate class information from the returned values.
|
# Extract and validate class information from the returned values.
|
||||||
@ -1546,17 +1546,17 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor):
|
|||||||
def normalize(arg, *rest):
|
def normalize(arg, *rest):
|
||||||
# pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
if rest:
|
if rest:
|
||||||
return structure.to_batched_tensor_list(self._element_structure,
|
return structure.to_batched_tensor_list(self.element_spec,
|
||||||
(arg,) + rest)
|
(arg,) + rest)
|
||||||
else:
|
else:
|
||||||
return structure.to_batched_tensor_list(self._element_structure, arg)
|
return structure.to_batched_tensor_list(self.element_spec, arg)
|
||||||
|
|
||||||
normalized_dataset = self.map(normalize)
|
normalized_dataset = self.map(normalize)
|
||||||
|
|
||||||
# NOTE(mrry): Our `map()` has lost information about the structure of
|
# NOTE(mrry): Our `map()` has lost information about the structure of
|
||||||
# non-tensor components, so re-apply the structure of the original dataset.
|
# non-tensor components, so re-apply the structure of the original dataset.
|
||||||
restructured_dataset = _RestructuredDataset(normalized_dataset,
|
restructured_dataset = _RestructuredDataset(normalized_dataset,
|
||||||
self._element_structure)
|
self.element_spec)
|
||||||
return _UnbatchDataset(restructured_dataset)
|
return _UnbatchDataset(restructured_dataset)
|
||||||
|
|
||||||
def with_options(self, options):
|
def with_options(self, options):
|
||||||
@ -1750,7 +1750,7 @@ class DatasetV1(DatasetV2):
|
|||||||
"""
|
"""
|
||||||
return nest.map_structure(
|
return nest.map_structure(
|
||||||
lambda component_spec: component_spec._to_legacy_output_classes(), # pylint: disable=protected-access
|
lambda component_spec: component_spec._to_legacy_output_classes(), # pylint: disable=protected-access
|
||||||
self._element_structure)
|
self.element_spec)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@deprecation.deprecated(
|
@deprecation.deprecated(
|
||||||
@ -1764,7 +1764,7 @@ class DatasetV1(DatasetV2):
|
|||||||
"""
|
"""
|
||||||
return nest.map_structure(
|
return nest.map_structure(
|
||||||
lambda component_spec: component_spec._to_legacy_output_shapes(), # pylint: disable=protected-access
|
lambda component_spec: component_spec._to_legacy_output_shapes(), # pylint: disable=protected-access
|
||||||
self._element_structure)
|
self.element_spec)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@deprecation.deprecated(
|
@deprecation.deprecated(
|
||||||
@ -1778,10 +1778,10 @@ class DatasetV1(DatasetV2):
|
|||||||
"""
|
"""
|
||||||
return nest.map_structure(
|
return nest.map_structure(
|
||||||
lambda component_spec: component_spec._to_legacy_output_types(), # pylint: disable=protected-access
|
lambda component_spec: component_spec._to_legacy_output_types(), # pylint: disable=protected-access
|
||||||
self._element_structure)
|
self.element_spec)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _element_structure(self):
|
def element_spec(self):
|
||||||
# TODO(b/110122868): Remove this override once all `Dataset` instances
|
# TODO(b/110122868): Remove this override once all `Dataset` instances
|
||||||
# implement `element_structure`.
|
# implement `element_structure`.
|
||||||
return structure.convert_legacy_structure(
|
return structure.convert_legacy_structure(
|
||||||
@ -2003,8 +2003,8 @@ class DatasetV1Adapter(DatasetV1):
|
|||||||
return self._dataset.options()
|
return self._dataset.options()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _element_structure(self):
|
def element_spec(self):
|
||||||
return self._dataset._element_structure # pylint: disable=protected-access
|
return self._dataset.element_spec # pylint: disable=protected-access
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
return iter(self._dataset)
|
return iter(self._dataset)
|
||||||
@ -2107,7 +2107,7 @@ def get_structure(dataset_or_iterator):
|
|||||||
TypeError: If `dataset_or_iterator` is not a `Dataset` or `Iterator` object.
|
TypeError: If `dataset_or_iterator` is not a `Dataset` or `Iterator` object.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
return dataset_or_iterator._element_structure # pylint: disable=protected-access
|
return dataset_or_iterator.element_spec # pylint: disable=protected-access
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
raise TypeError("`dataset_or_iterator` must be a Dataset or Iterator "
|
raise TypeError("`dataset_or_iterator` must be a Dataset or Iterator "
|
||||||
"object, but got %s." % type(dataset_or_iterator))
|
"object, but got %s." % type(dataset_or_iterator))
|
||||||
@ -2306,8 +2306,8 @@ class UnaryUnchangedStructureDataset(UnaryDataset):
|
|||||||
input_dataset, variant_tensor)
|
input_dataset, variant_tensor)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _element_structure(self):
|
def element_spec(self):
|
||||||
return self._input_dataset._element_structure # pylint: disable=protected-access
|
return self._input_dataset.element_spec
|
||||||
|
|
||||||
|
|
||||||
class TensorDataset(DatasetSource):
|
class TensorDataset(DatasetSource):
|
||||||
@ -2325,7 +2325,7 @@ class TensorDataset(DatasetSource):
|
|||||||
super(TensorDataset, self).__init__(variant_tensor)
|
super(TensorDataset, self).__init__(variant_tensor)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _element_structure(self):
|
def element_spec(self):
|
||||||
return self._structure
|
return self._structure
|
||||||
|
|
||||||
|
|
||||||
@ -2352,7 +2352,7 @@ class TensorSliceDataset(DatasetSource):
|
|||||||
super(TensorSliceDataset, self).__init__(variant_tensor)
|
super(TensorSliceDataset, self).__init__(variant_tensor)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _element_structure(self):
|
def element_spec(self):
|
||||||
return self._structure
|
return self._structure
|
||||||
|
|
||||||
|
|
||||||
@ -2381,7 +2381,7 @@ class SparseTensorSliceDataset(DatasetSource):
|
|||||||
super(SparseTensorSliceDataset, self).__init__(variant_tensor)
|
super(SparseTensorSliceDataset, self).__init__(variant_tensor)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _element_structure(self):
|
def element_spec(self):
|
||||||
return self._structure
|
return self._structure
|
||||||
|
|
||||||
|
|
||||||
@ -2396,20 +2396,20 @@ class _VariantDataset(DatasetV2):
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _element_structure(self):
|
def element_spec(self):
|
||||||
return self._structure
|
return self._structure
|
||||||
|
|
||||||
|
|
||||||
class _NestedVariant(composite_tensor.CompositeTensor):
|
class _NestedVariant(composite_tensor.CompositeTensor):
|
||||||
|
|
||||||
def __init__(self, variant_tensor, element_structure, dataset_shape):
|
def __init__(self, variant_tensor, element_spec, dataset_shape):
|
||||||
self._variant_tensor = variant_tensor
|
self._variant_tensor = variant_tensor
|
||||||
self._element_structure = element_structure
|
self._element_spec = element_spec
|
||||||
self._dataset_shape = dataset_shape
|
self._dataset_shape = dataset_shape
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _type_spec(self):
|
def _type_spec(self):
|
||||||
return DatasetStructure(self._element_structure, self._dataset_shape)
|
return DatasetStructure(self._element_spec, self._dataset_shape)
|
||||||
|
|
||||||
|
|
||||||
@tf_export("data.experimental.from_variant")
|
@tf_export("data.experimental.from_variant")
|
||||||
@ -2445,10 +2445,10 @@ def to_variant(dataset):
|
|||||||
class DatasetStructure(type_spec.BatchableTypeSpec):
|
class DatasetStructure(type_spec.BatchableTypeSpec):
|
||||||
"""Type specification for `tf.data.Dataset`."""
|
"""Type specification for `tf.data.Dataset`."""
|
||||||
|
|
||||||
__slots__ = ["_element_structure", "_dataset_shape"]
|
__slots__ = ["_element_spec", "_dataset_shape"]
|
||||||
|
|
||||||
def __init__(self, element_spec, dataset_shape=None):
|
def __init__(self, element_spec, dataset_shape=None):
|
||||||
self._element_structure = element_spec
|
self._element_spec = element_spec
|
||||||
if dataset_shape:
|
if dataset_shape:
|
||||||
self._dataset_shape = dataset_shape
|
self._dataset_shape = dataset_shape
|
||||||
else:
|
else:
|
||||||
@ -2459,7 +2459,7 @@ class DatasetStructure(type_spec.BatchableTypeSpec):
|
|||||||
return _VariantDataset
|
return _VariantDataset
|
||||||
|
|
||||||
def _serialize(self):
|
def _serialize(self):
|
||||||
return (self._element_structure, self._dataset_shape)
|
return (self._element_spec, self._dataset_shape)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _component_specs(self):
|
def _component_specs(self):
|
||||||
@ -2471,10 +2471,9 @@ class DatasetStructure(type_spec.BatchableTypeSpec):
|
|||||||
def _from_components(self, components):
|
def _from_components(self, components):
|
||||||
# pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
if self._dataset_shape.ndims == 0:
|
if self._dataset_shape.ndims == 0:
|
||||||
return _VariantDataset(components, self._element_structure)
|
return _VariantDataset(components, self._element_spec)
|
||||||
else:
|
else:
|
||||||
return _NestedVariant(components, self._element_structure,
|
return _NestedVariant(components, self._element_spec, self._dataset_shape)
|
||||||
self._dataset_shape)
|
|
||||||
|
|
||||||
def _to_tensor_list(self, value):
|
def _to_tensor_list(self, value):
|
||||||
return [
|
return [
|
||||||
@ -2484,17 +2483,17 @@ class DatasetStructure(type_spec.BatchableTypeSpec):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_value(value):
|
def from_value(value):
|
||||||
return DatasetStructure(value._element_structure) # pylint: disable=protected-access
|
return DatasetStructure(value.element_spec) # pylint: disable=protected-access
|
||||||
|
|
||||||
def _batch(self, batch_size):
|
def _batch(self, batch_size):
|
||||||
return DatasetStructure(
|
return DatasetStructure(
|
||||||
self._element_structure,
|
self._element_spec,
|
||||||
tensor_shape.TensorShape([batch_size]).concatenate(self._dataset_shape))
|
tensor_shape.TensorShape([batch_size]).concatenate(self._dataset_shape))
|
||||||
|
|
||||||
def _unbatch(self):
|
def _unbatch(self):
|
||||||
if self._dataset_shape.ndims == 0:
|
if self._dataset_shape.ndims == 0:
|
||||||
raise ValueError("Unbatching a dataset is only supported for rank >= 1")
|
raise ValueError("Unbatching a dataset is only supported for rank >= 1")
|
||||||
return DatasetStructure(self._element_structure, self._dataset_shape[1:])
|
return DatasetStructure(self._element_spec, self._dataset_shape[1:])
|
||||||
|
|
||||||
def _to_batched_tensor_list(self, value):
|
def _to_batched_tensor_list(self, value):
|
||||||
if self._dataset_shape.ndims == 0:
|
if self._dataset_shape.ndims == 0:
|
||||||
@ -2571,7 +2570,7 @@ class StructuredFunctionWrapper(object):
|
|||||||
raise ValueError("Either `dataset`, `input_structure` or all of "
|
raise ValueError("Either `dataset`, `input_structure` or all of "
|
||||||
"`input_classes`, `input_shapes`, and `input_types` "
|
"`input_classes`, `input_shapes`, and `input_types` "
|
||||||
"must be specified.")
|
"must be specified.")
|
||||||
self._input_structure = dataset._element_structure
|
self._input_structure = dataset.element_spec
|
||||||
else:
|
else:
|
||||||
if not (dataset is None and input_classes is None and input_shapes is None
|
if not (dataset is None and input_classes is None and input_shapes is None
|
||||||
and input_types is None):
|
and input_types is None):
|
||||||
@ -2767,7 +2766,7 @@ class _GeneratorDataset(DatasetSource):
|
|||||||
super(_GeneratorDataset, self).__init__(variant_tensor)
|
super(_GeneratorDataset, self).__init__(variant_tensor)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _element_structure(self):
|
def element_spec(self):
|
||||||
return self._next_func.output_structure
|
return self._next_func.output_structure
|
||||||
|
|
||||||
def _transformation_name(self):
|
def _transformation_name(self):
|
||||||
@ -2792,7 +2791,7 @@ class ZipDataset(DatasetV2):
|
|||||||
self._datasets = datasets
|
self._datasets = datasets
|
||||||
self._structure = nest.pack_sequence_as(
|
self._structure = nest.pack_sequence_as(
|
||||||
self._datasets,
|
self._datasets,
|
||||||
[ds._element_structure for ds in nest.flatten(self._datasets)]) # pylint: disable=protected-access
|
[ds.element_spec for ds in nest.flatten(self._datasets)])
|
||||||
variant_tensor = gen_dataset_ops.zip_dataset(
|
variant_tensor = gen_dataset_ops.zip_dataset(
|
||||||
[ds._variant_tensor for ds in nest.flatten(self._datasets)],
|
[ds._variant_tensor for ds in nest.flatten(self._datasets)],
|
||||||
**self._flat_structure)
|
**self._flat_structure)
|
||||||
@ -2802,7 +2801,7 @@ class ZipDataset(DatasetV2):
|
|||||||
return nest.flatten(self._datasets)
|
return nest.flatten(self._datasets)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _element_structure(self):
|
def element_spec(self):
|
||||||
return self._structure
|
return self._structure
|
||||||
|
|
||||||
|
|
||||||
@ -2850,7 +2849,7 @@ class ConcatenateDataset(DatasetV2):
|
|||||||
return self._input_datasets
|
return self._input_datasets
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _element_structure(self):
|
def element_spec(self):
|
||||||
return self._structure
|
return self._structure
|
||||||
|
|
||||||
|
|
||||||
@ -2907,7 +2906,7 @@ class RangeDataset(DatasetSource):
|
|||||||
return ops.convert_to_tensor(int64_value, dtype=dtypes.int64, name=name)
|
return ops.convert_to_tensor(int64_value, dtype=dtypes.int64, name=name)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _element_structure(self):
|
def element_spec(self):
|
||||||
return self._structure
|
return self._structure
|
||||||
|
|
||||||
|
|
||||||
@ -3037,11 +3036,11 @@ class BatchDataset(UnaryDataset):
|
|||||||
self._structure = nest.map_structure(
|
self._structure = nest.map_structure(
|
||||||
lambda component_spec: component_spec._batch(
|
lambda component_spec: component_spec._batch(
|
||||||
tensor_util.constant_value(self._batch_size)),
|
tensor_util.constant_value(self._batch_size)),
|
||||||
input_dataset._element_structure)
|
input_dataset.element_spec)
|
||||||
else:
|
else:
|
||||||
self._structure = nest.map_structure(
|
self._structure = nest.map_structure(
|
||||||
lambda component_spec: component_spec._batch(None),
|
lambda component_spec: component_spec._batch(None),
|
||||||
input_dataset._element_structure)
|
input_dataset.element_spec)
|
||||||
variant_tensor = gen_dataset_ops.batch_dataset_v2(
|
variant_tensor = gen_dataset_ops.batch_dataset_v2(
|
||||||
input_dataset._variant_tensor,
|
input_dataset._variant_tensor,
|
||||||
batch_size=self._batch_size,
|
batch_size=self._batch_size,
|
||||||
@ -3050,7 +3049,7 @@ class BatchDataset(UnaryDataset):
|
|||||||
super(BatchDataset, self).__init__(input_dataset, variant_tensor)
|
super(BatchDataset, self).__init__(input_dataset, variant_tensor)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _element_structure(self):
|
def element_spec(self):
|
||||||
return self._structure
|
return self._structure
|
||||||
|
|
||||||
|
|
||||||
@ -3268,7 +3267,7 @@ class PaddedBatchDataset(UnaryDataset):
|
|||||||
super(PaddedBatchDataset, self).__init__(input_dataset, variant_tensor)
|
super(PaddedBatchDataset, self).__init__(input_dataset, variant_tensor)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _element_structure(self):
|
def element_spec(self):
|
||||||
return self._structure
|
return self._structure
|
||||||
|
|
||||||
|
|
||||||
@ -3308,7 +3307,7 @@ class MapDataset(UnaryDataset):
|
|||||||
return [self._map_func]
|
return [self._map_func]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _element_structure(self):
|
def element_spec(self):
|
||||||
return self._map_func.output_structure
|
return self._map_func.output_structure
|
||||||
|
|
||||||
def _transformation_name(self):
|
def _transformation_name(self):
|
||||||
@ -3350,7 +3349,7 @@ class ParallelMapDataset(UnaryDataset):
|
|||||||
return [self._map_func]
|
return [self._map_func]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _element_structure(self):
|
def element_spec(self):
|
||||||
return self._map_func.output_structure
|
return self._map_func.output_structure
|
||||||
|
|
||||||
def _transformation_name(self):
|
def _transformation_name(self):
|
||||||
@ -3369,7 +3368,7 @@ class FlatMapDataset(UnaryDataset):
|
|||||||
raise TypeError(
|
raise TypeError(
|
||||||
"`map_func` must return a `Dataset` object. Got {}".format(
|
"`map_func` must return a `Dataset` object. Got {}".format(
|
||||||
type(self._map_func.output_structure)))
|
type(self._map_func.output_structure)))
|
||||||
self._structure = self._map_func.output_structure._element_structure # pylint: disable=protected-access
|
self._structure = self._map_func.output_structure._element_spec # pylint: disable=protected-access
|
||||||
variant_tensor = gen_dataset_ops.flat_map_dataset(
|
variant_tensor = gen_dataset_ops.flat_map_dataset(
|
||||||
input_dataset._variant_tensor, # pylint: disable=protected-access
|
input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||||
self._map_func.function.captured_inputs,
|
self._map_func.function.captured_inputs,
|
||||||
@ -3381,7 +3380,7 @@ class FlatMapDataset(UnaryDataset):
|
|||||||
return [self._map_func]
|
return [self._map_func]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _element_structure(self):
|
def element_spec(self):
|
||||||
return self._structure
|
return self._structure
|
||||||
|
|
||||||
def _transformation_name(self):
|
def _transformation_name(self):
|
||||||
@ -3400,7 +3399,7 @@ class InterleaveDataset(UnaryDataset):
|
|||||||
raise TypeError(
|
raise TypeError(
|
||||||
"`map_func` must return a `Dataset` object. Got {}".format(
|
"`map_func` must return a `Dataset` object. Got {}".format(
|
||||||
type(self._map_func.output_structure)))
|
type(self._map_func.output_structure)))
|
||||||
self._structure = self._map_func.output_structure._element_structure # pylint: disable=protected-access
|
self._structure = self._map_func.output_structure._element_spec # pylint: disable=protected-access
|
||||||
self._cycle_length = ops.convert_to_tensor(
|
self._cycle_length = ops.convert_to_tensor(
|
||||||
cycle_length, dtype=dtypes.int64, name="cycle_length")
|
cycle_length, dtype=dtypes.int64, name="cycle_length")
|
||||||
self._block_length = ops.convert_to_tensor(
|
self._block_length = ops.convert_to_tensor(
|
||||||
@ -3419,7 +3418,7 @@ class InterleaveDataset(UnaryDataset):
|
|||||||
return [self._map_func]
|
return [self._map_func]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _element_structure(self):
|
def element_spec(self):
|
||||||
return self._structure
|
return self._structure
|
||||||
|
|
||||||
def _transformation_name(self):
|
def _transformation_name(self):
|
||||||
@ -3439,7 +3438,7 @@ class ParallelInterleaveDataset(UnaryDataset):
|
|||||||
raise TypeError(
|
raise TypeError(
|
||||||
"`map_func` must return a `Dataset` object. Got {}".format(
|
"`map_func` must return a `Dataset` object. Got {}".format(
|
||||||
type(self._map_func.output_structure)))
|
type(self._map_func.output_structure)))
|
||||||
self._structure = self._map_func.output_structure._element_structure # pylint: disable=protected-access
|
self._structure = self._map_func.output_structure._element_spec # pylint: disable=protected-access
|
||||||
self._cycle_length = ops.convert_to_tensor(
|
self._cycle_length = ops.convert_to_tensor(
|
||||||
cycle_length, dtype=dtypes.int64, name="cycle_length")
|
cycle_length, dtype=dtypes.int64, name="cycle_length")
|
||||||
self._block_length = ops.convert_to_tensor(
|
self._block_length = ops.convert_to_tensor(
|
||||||
@ -3461,7 +3460,7 @@ class ParallelInterleaveDataset(UnaryDataset):
|
|||||||
return [self._map_func]
|
return [self._map_func]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _element_structure(self):
|
def element_spec(self):
|
||||||
return self._structure
|
return self._structure
|
||||||
|
|
||||||
def _transformation_name(self):
|
def _transformation_name(self):
|
||||||
@ -3561,7 +3560,7 @@ class WindowDataset(UnaryDataset):
|
|||||||
super(WindowDataset, self).__init__(input_dataset, variant_tensor)
|
super(WindowDataset, self).__init__(input_dataset, variant_tensor)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _element_structure(self):
|
def element_spec(self):
|
||||||
return self._structure
|
return self._structure
|
||||||
|
|
||||||
|
|
||||||
@ -3684,7 +3683,7 @@ class _RestructuredDataset(UnaryDataset):
|
|||||||
super(_RestructuredDataset, self).__init__(dataset, variant_tensor)
|
super(_RestructuredDataset, self).__init__(dataset, variant_tensor)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _element_structure(self):
|
def element_spec(self):
|
||||||
return self._structure
|
return self._structure
|
||||||
|
|
||||||
|
|
||||||
@ -3713,5 +3712,5 @@ class _UnbatchDataset(UnaryDataset):
|
|||||||
super(_UnbatchDataset, self).__init__(input_dataset, variant_tensor)
|
super(_UnbatchDataset, self).__init__(input_dataset, variant_tensor)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _element_structure(self):
|
def element_spec(self):
|
||||||
return self._structure
|
return self._structure
|
||||||
|
@ -104,10 +104,12 @@ class Iterator(trackable.Trackable):
|
|||||||
raise ValueError("If `structure` is not specified, all of "
|
raise ValueError("If `structure` is not specified, all of "
|
||||||
"`output_types`, `output_shapes`, and `output_classes`"
|
"`output_types`, `output_shapes`, and `output_classes`"
|
||||||
" must be specified.")
|
" must be specified.")
|
||||||
self._structure = structure.convert_legacy_structure(
|
self._element_spec = structure.convert_legacy_structure(
|
||||||
output_types, output_shapes, output_classes)
|
output_types, output_shapes, output_classes)
|
||||||
self._flat_tensor_shapes = structure.get_flat_tensor_shapes(self._structure)
|
self._flat_tensor_shapes = structure.get_flat_tensor_shapes(
|
||||||
self._flat_tensor_types = structure.get_flat_tensor_types(self._structure)
|
self._element_spec)
|
||||||
|
self._flat_tensor_types = structure.get_flat_tensor_types(
|
||||||
|
self._element_spec)
|
||||||
|
|
||||||
self._string_handle = gen_dataset_ops.iterator_to_string_handle(
|
self._string_handle = gen_dataset_ops.iterator_to_string_handle(
|
||||||
self._iterator_resource)
|
self._iterator_resource)
|
||||||
@ -347,13 +349,13 @@ class Iterator(trackable.Trackable):
|
|||||||
# pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
dataset_output_types = nest.map_structure(
|
dataset_output_types = nest.map_structure(
|
||||||
lambda component_spec: component_spec._to_legacy_output_types(),
|
lambda component_spec: component_spec._to_legacy_output_types(),
|
||||||
dataset._element_structure)
|
dataset.element_spec)
|
||||||
dataset_output_shapes = nest.map_structure(
|
dataset_output_shapes = nest.map_structure(
|
||||||
lambda component_spec: component_spec._to_legacy_output_shapes(),
|
lambda component_spec: component_spec._to_legacy_output_shapes(),
|
||||||
dataset._element_structure)
|
dataset.element_spec)
|
||||||
dataset_output_classes = nest.map_structure(
|
dataset_output_classes = nest.map_structure(
|
||||||
lambda component_spec: component_spec._to_legacy_output_classes(),
|
lambda component_spec: component_spec._to_legacy_output_classes(),
|
||||||
dataset._element_structure)
|
dataset.element_spec)
|
||||||
# pylint: enable=protected-access
|
# pylint: enable=protected-access
|
||||||
|
|
||||||
nest.assert_same_structure(self.output_types, dataset_output_types)
|
nest.assert_same_structure(self.output_types, dataset_output_types)
|
||||||
@ -436,7 +438,7 @@ class Iterator(trackable.Trackable):
|
|||||||
output_types=self._flat_tensor_types,
|
output_types=self._flat_tensor_types,
|
||||||
output_shapes=self._flat_tensor_shapes,
|
output_shapes=self._flat_tensor_shapes,
|
||||||
name=name)
|
name=name)
|
||||||
return structure.from_tensor_list(self._structure, flat_ret)
|
return structure.from_tensor_list(self._element_spec, flat_ret)
|
||||||
|
|
||||||
def string_handle(self, name=None):
|
def string_handle(self, name=None):
|
||||||
"""Returns a string-valued `tf.Tensor` that represents this iterator.
|
"""Returns a string-valued `tf.Tensor` that represents this iterator.
|
||||||
@ -467,7 +469,7 @@ class Iterator(trackable.Trackable):
|
|||||||
"""
|
"""
|
||||||
return nest.map_structure(
|
return nest.map_structure(
|
||||||
lambda component_spec: component_spec._to_legacy_output_classes(), # pylint: disable=protected-access
|
lambda component_spec: component_spec._to_legacy_output_classes(), # pylint: disable=protected-access
|
||||||
self._structure)
|
self._element_spec)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@deprecation.deprecated(
|
@deprecation.deprecated(
|
||||||
@ -481,7 +483,7 @@ class Iterator(trackable.Trackable):
|
|||||||
"""
|
"""
|
||||||
return nest.map_structure(
|
return nest.map_structure(
|
||||||
lambda component_spec: component_spec._to_legacy_output_shapes(), # pylint: disable=protected-access
|
lambda component_spec: component_spec._to_legacy_output_shapes(), # pylint: disable=protected-access
|
||||||
self._structure)
|
self._element_spec)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@deprecation.deprecated(
|
@deprecation.deprecated(
|
||||||
@ -495,17 +497,17 @@ class Iterator(trackable.Trackable):
|
|||||||
"""
|
"""
|
||||||
return nest.map_structure(
|
return nest.map_structure(
|
||||||
lambda component_spec: component_spec._to_legacy_output_types(), # pylint: disable=protected-access
|
lambda component_spec: component_spec._to_legacy_output_types(), # pylint: disable=protected-access
|
||||||
self._structure)
|
self._element_spec)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _element_structure(self):
|
def element_spec(self):
|
||||||
"""The structure of an element of this iterator.
|
"""The type specification of an element of this iterator.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A `Structure` object representing the structure of the components of this
|
A nested structure of `tf.TypeSpec` objects matching the structure of an
|
||||||
optional.
|
element of this iterator and specifying the type of individual components.
|
||||||
"""
|
"""
|
||||||
return self._structure
|
return self._element_spec
|
||||||
|
|
||||||
def _gather_saveables_for_checkpoint(self):
|
def _gather_saveables_for_checkpoint(self):
|
||||||
|
|
||||||
@ -556,7 +558,7 @@ class IteratorResourceDeleter(object):
|
|||||||
class IteratorV2(trackable.Trackable, composite_tensor.CompositeTensor):
|
class IteratorV2(trackable.Trackable, composite_tensor.CompositeTensor):
|
||||||
"""An iterator producing tf.Tensor objects from a tf.data.Dataset."""
|
"""An iterator producing tf.Tensor objects from a tf.data.Dataset."""
|
||||||
|
|
||||||
def __init__(self, dataset=None, components=None, element_structure=None):
|
def __init__(self, dataset=None, components=None, element_spec=None):
|
||||||
"""Creates a new iterator from the given dataset.
|
"""Creates a new iterator from the given dataset.
|
||||||
|
|
||||||
If `dataset` is not specified, the iterator will be created from the given
|
If `dataset` is not specified, the iterator will be created from the given
|
||||||
@ -567,29 +569,29 @@ class IteratorV2(trackable.Trackable, composite_tensor.CompositeTensor):
|
|||||||
Args:
|
Args:
|
||||||
dataset: A `tf.data.Dataset` object.
|
dataset: A `tf.data.Dataset` object.
|
||||||
components: Tensor components to construct the iterator from.
|
components: Tensor components to construct the iterator from.
|
||||||
element_structure: A nested structure of `TypeSpec` objects that
|
element_spec: A nested structure of `TypeSpec` objects that
|
||||||
represents the type specification elements of the iterator.
|
represents the type specification of elements of the iterator.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If `dataset` is not provided and either `components` or
|
ValueError: If `dataset` is not provided and either `components` or
|
||||||
`element_structure` is not provided. Or `dataset` is provided and either
|
`element_spec` is not provided. Or `dataset` is provided and either
|
||||||
`components` and `element_structure` is provided.
|
`components` and `element_spec` is provided.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
error_message = "Either `dataset` or both `components` and "
|
error_message = "Either `dataset` or both `components` and "
|
||||||
"`element_structure` need to be provided."
|
"`element_spec` need to be provided."
|
||||||
|
|
||||||
self._device = context.context().device_name
|
self._device = context.context().device_name
|
||||||
|
|
||||||
if dataset is None:
|
if dataset is None:
|
||||||
if (components is None or element_structure is None):
|
if (components is None or element_spec is None):
|
||||||
raise ValueError(error_message)
|
raise ValueError(error_message)
|
||||||
# pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
self._structure = element_structure
|
self._element_spec = element_spec
|
||||||
self._flat_output_types = structure.get_flat_tensor_types(
|
self._flat_output_types = structure.get_flat_tensor_types(
|
||||||
self._structure)
|
self._element_spec)
|
||||||
self._flat_output_shapes = structure.get_flat_tensor_shapes(
|
self._flat_output_shapes = structure.get_flat_tensor_shapes(
|
||||||
self._structure)
|
self._element_spec)
|
||||||
self._iterator_resource, self._deleter = components
|
self._iterator_resource, self._deleter = components
|
||||||
# Delete the resource when this object is deleted
|
# Delete the resource when this object is deleted
|
||||||
self._resource_deleter = IteratorResourceDeleter(
|
self._resource_deleter = IteratorResourceDeleter(
|
||||||
@ -597,7 +599,7 @@ class IteratorV2(trackable.Trackable, composite_tensor.CompositeTensor):
|
|||||||
device=self._device,
|
device=self._device,
|
||||||
deleter=self._deleter)
|
deleter=self._deleter)
|
||||||
else:
|
else:
|
||||||
if (components is not None or element_structure is not None):
|
if (components is not None or element_spec is not None):
|
||||||
raise ValueError(error_message)
|
raise ValueError(error_message)
|
||||||
if (_device_stack_is_empty() or
|
if (_device_stack_is_empty() or
|
||||||
context.context().device_spec.device_type != "CPU"):
|
context.context().device_spec.device_type != "CPU"):
|
||||||
@ -610,11 +612,11 @@ class IteratorV2(trackable.Trackable, composite_tensor.CompositeTensor):
|
|||||||
# pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
dataset = dataset._apply_options()
|
dataset = dataset._apply_options()
|
||||||
ds_variant = dataset._variant_tensor
|
ds_variant = dataset._variant_tensor
|
||||||
self._structure = dataset._element_structure
|
self._element_spec = dataset.element_spec
|
||||||
self._flat_output_types = structure.get_flat_tensor_types(
|
self._flat_output_types = structure.get_flat_tensor_types(
|
||||||
self._structure)
|
self._element_spec)
|
||||||
self._flat_output_shapes = structure.get_flat_tensor_shapes(
|
self._flat_output_shapes = structure.get_flat_tensor_shapes(
|
||||||
self._structure)
|
self._element_spec)
|
||||||
with ops.colocate_with(ds_variant):
|
with ops.colocate_with(ds_variant):
|
||||||
self._iterator_resource, self._deleter = (
|
self._iterator_resource, self._deleter = (
|
||||||
gen_dataset_ops.anonymous_iterator_v2(
|
gen_dataset_ops.anonymous_iterator_v2(
|
||||||
@ -642,7 +644,7 @@ class IteratorV2(trackable.Trackable, composite_tensor.CompositeTensor):
|
|||||||
self._iterator_resource,
|
self._iterator_resource,
|
||||||
output_types=self._flat_output_types,
|
output_types=self._flat_output_types,
|
||||||
output_shapes=self._flat_output_shapes)
|
output_shapes=self._flat_output_shapes)
|
||||||
return structure.from_compatible_tensor_list(self._structure, ret)
|
return structure.from_compatible_tensor_list(self._element_spec, ret)
|
||||||
|
|
||||||
# This runs in sync mode as iterators use an error status to communicate
|
# This runs in sync mode as iterators use an error status to communicate
|
||||||
# that there is no more data to iterate over.
|
# that there is no more data to iterate over.
|
||||||
@ -664,13 +666,13 @@ class IteratorV2(trackable.Trackable, composite_tensor.CompositeTensor):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# Fast path for the case `self._structure` is not a nested structure.
|
# Fast path for the case `self._structure` is not a nested structure.
|
||||||
return self._structure._from_compatible_tensor_list(ret) # pylint: disable=protected-access
|
return self._element_spec._from_compatible_tensor_list(ret) # pylint: disable=protected-access
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
return structure.from_compatible_tensor_list(self._structure, ret)
|
return structure.from_compatible_tensor_list(self._element_spec, ret)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _type_spec(self):
|
def _type_spec(self):
|
||||||
return IteratorSpec(self._element_structure)
|
return IteratorSpec(self.element_spec)
|
||||||
|
|
||||||
def next(self):
|
def next(self):
|
||||||
"""Returns a nested structure of `Tensor`s containing the next element."""
|
"""Returns a nested structure of `Tensor`s containing the next element."""
|
||||||
@ -693,7 +695,7 @@ class IteratorV2(trackable.Trackable, composite_tensor.CompositeTensor):
|
|||||||
"""
|
"""
|
||||||
return nest.map_structure(
|
return nest.map_structure(
|
||||||
lambda component_spec: component_spec._to_legacy_output_classes(), # pylint: disable=protected-access
|
lambda component_spec: component_spec._to_legacy_output_classes(), # pylint: disable=protected-access
|
||||||
self._structure)
|
self._element_spec)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@deprecation.deprecated(
|
@deprecation.deprecated(
|
||||||
@ -707,7 +709,7 @@ class IteratorV2(trackable.Trackable, composite_tensor.CompositeTensor):
|
|||||||
"""
|
"""
|
||||||
return nest.map_structure(
|
return nest.map_structure(
|
||||||
lambda component_spec: component_spec._to_legacy_output_shapes(), # pylint: disable=protected-access
|
lambda component_spec: component_spec._to_legacy_output_shapes(), # pylint: disable=protected-access
|
||||||
self._structure)
|
self._element_spec)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@deprecation.deprecated(
|
@deprecation.deprecated(
|
||||||
@ -721,17 +723,17 @@ class IteratorV2(trackable.Trackable, composite_tensor.CompositeTensor):
|
|||||||
"""
|
"""
|
||||||
return nest.map_structure(
|
return nest.map_structure(
|
||||||
lambda component_spec: component_spec._to_legacy_output_types(), # pylint: disable=protected-access
|
lambda component_spec: component_spec._to_legacy_output_types(), # pylint: disable=protected-access
|
||||||
self._structure)
|
self._element_spec)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _element_structure(self):
|
def element_spec(self):
|
||||||
"""The structure of an element of this iterator.
|
"""The type specification of an element of this iterator.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A `Structure` object representing the structure of the components of this
|
A nested structure of `tf.TypeSpec` objects matching the structure of an
|
||||||
optional.
|
element of this iterator and specifying the type of individual components.
|
||||||
"""
|
"""
|
||||||
return self._structure
|
return self._element_spec
|
||||||
|
|
||||||
def get_next(self, name=None):
|
def get_next(self, name=None):
|
||||||
"""Returns a nested structure of `tf.Tensor`s containing the next element.
|
"""Returns a nested structure of `tf.Tensor`s containing the next element.
|
||||||
@ -786,11 +788,11 @@ class IteratorSpec(type_spec.TypeSpec):
|
|||||||
return IteratorV2(
|
return IteratorV2(
|
||||||
dataset=None,
|
dataset=None,
|
||||||
components=components,
|
components=components,
|
||||||
element_structure=self._element_spec)
|
element_spec=self._element_spec)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_value(value):
|
def from_value(value):
|
||||||
return IteratorSpec(value._element_structure) # pylint: disable=protected-access
|
return IteratorSpec(value.element_spec) # pylint: disable=protected-access
|
||||||
|
|
||||||
|
|
||||||
# TODO(b/71645805): Expose trackable stateful objects from dataset
|
# TODO(b/71645805): Expose trackable stateful objects from dataset
|
||||||
@ -828,7 +830,6 @@ def get_next_as_optional(iterator):
|
|||||||
return optional_ops._OptionalImpl(
|
return optional_ops._OptionalImpl(
|
||||||
gen_dataset_ops.iterator_get_next_as_optional(
|
gen_dataset_ops.iterator_get_next_as_optional(
|
||||||
iterator._iterator_resource,
|
iterator._iterator_resource,
|
||||||
output_types=structure.get_flat_tensor_types(
|
output_types=structure.get_flat_tensor_types(iterator.element_spec),
|
||||||
iterator._element_structure),
|
|
||||||
output_shapes=structure.get_flat_tensor_shapes(
|
output_shapes=structure.get_flat_tensor_shapes(
|
||||||
iterator._element_structure)), iterator._element_structure)
|
iterator.element_spec)), iterator.element_spec)
|
||||||
|
@ -36,8 +36,8 @@ class _PerDeviceGenerator(dataset_ops.DatasetV2):
|
|||||||
"""A `dummy` generator dataset."""
|
"""A `dummy` generator dataset."""
|
||||||
|
|
||||||
def __init__(self, shard_num, multi_device_iterator_resource, incarnation_id,
|
def __init__(self, shard_num, multi_device_iterator_resource, incarnation_id,
|
||||||
source_device, element_structure):
|
source_device, element_spec):
|
||||||
self._structure = element_structure
|
self._element_spec = element_spec
|
||||||
|
|
||||||
multi_device_iterator_string_handle = (
|
multi_device_iterator_string_handle = (
|
||||||
gen_dataset_ops.multi_device_iterator_to_string_handle(
|
gen_dataset_ops.multi_device_iterator_to_string_handle(
|
||||||
@ -71,14 +71,15 @@ class _PerDeviceGenerator(dataset_ops.DatasetV2):
|
|||||||
multi_device_iterator = (
|
multi_device_iterator = (
|
||||||
gen_dataset_ops.multi_device_iterator_from_string_handle(
|
gen_dataset_ops.multi_device_iterator_from_string_handle(
|
||||||
string_handle=string_handle,
|
string_handle=string_handle,
|
||||||
output_types=structure.get_flat_tensor_types(self._structure),
|
output_types=structure.get_flat_tensor_types(self._element_spec),
|
||||||
output_shapes=structure.get_flat_tensor_shapes(self._structure)))
|
output_shapes=structure.get_flat_tensor_shapes(
|
||||||
|
self._element_spec)))
|
||||||
return gen_dataset_ops.multi_device_iterator_get_next_from_shard(
|
return gen_dataset_ops.multi_device_iterator_get_next_from_shard(
|
||||||
multi_device_iterator=multi_device_iterator,
|
multi_device_iterator=multi_device_iterator,
|
||||||
shard_num=shard_num,
|
shard_num=shard_num,
|
||||||
incarnation_id=incarnation_id,
|
incarnation_id=incarnation_id,
|
||||||
output_types=structure.get_flat_tensor_types(self._structure),
|
output_types=structure.get_flat_tensor_types(self._element_spec),
|
||||||
output_shapes=structure.get_flat_tensor_shapes(self._structure))
|
output_shapes=structure.get_flat_tensor_shapes(self._element_spec))
|
||||||
|
|
||||||
next_func_concrete = _next_func._get_concrete_function_internal() # pylint: disable=protected-access
|
next_func_concrete = _next_func._get_concrete_function_internal() # pylint: disable=protected-access
|
||||||
|
|
||||||
@ -91,7 +92,7 @@ class _PerDeviceGenerator(dataset_ops.DatasetV2):
|
|||||||
return functional_ops.remote_call(
|
return functional_ops.remote_call(
|
||||||
target=source_device,
|
target=source_device,
|
||||||
args=[string_handle] + next_func_concrete.captured_inputs,
|
args=[string_handle] + next_func_concrete.captured_inputs,
|
||||||
Tout=structure.get_flat_tensor_types(self._structure),
|
Tout=structure.get_flat_tensor_types(self._element_spec),
|
||||||
f=next_func_concrete)
|
f=next_func_concrete)
|
||||||
|
|
||||||
self._next_func = _remote_next_func._get_concrete_function_internal() # pylint: disable=protected-access
|
self._next_func = _remote_next_func._get_concrete_function_internal() # pylint: disable=protected-access
|
||||||
@ -141,8 +142,8 @@ class _PerDeviceGenerator(dataset_ops.DatasetV2):
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _element_structure(self):
|
def element_spec(self):
|
||||||
return self._structure
|
return self._element_spec
|
||||||
|
|
||||||
|
|
||||||
class _ReincarnatedPerDeviceGenerator(dataset_ops.DatasetV2):
|
class _ReincarnatedPerDeviceGenerator(dataset_ops.DatasetV2):
|
||||||
@ -154,7 +155,7 @@ class _ReincarnatedPerDeviceGenerator(dataset_ops.DatasetV2):
|
|||||||
|
|
||||||
def __init__(self, per_device_dataset, incarnation_id):
|
def __init__(self, per_device_dataset, incarnation_id):
|
||||||
# pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
self._structure = per_device_dataset._element_structure
|
self._element_spec = per_device_dataset.element_spec
|
||||||
self._init_func = per_device_dataset._init_func
|
self._init_func = per_device_dataset._init_func
|
||||||
self._init_captured_args = self._init_func.captured_inputs
|
self._init_captured_args = self._init_func.captured_inputs
|
||||||
|
|
||||||
@ -183,8 +184,8 @@ class _ReincarnatedPerDeviceGenerator(dataset_ops.DatasetV2):
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _element_structure(self):
|
def element_spec(self):
|
||||||
return self._structure
|
return self._element_spec
|
||||||
|
|
||||||
|
|
||||||
class MultiDeviceIterator(object):
|
class MultiDeviceIterator(object):
|
||||||
@ -253,7 +254,7 @@ class MultiDeviceIterator(object):
|
|||||||
ds = _PerDeviceGenerator(i, self._multi_device_iterator_resource,
|
ds = _PerDeviceGenerator(i, self._multi_device_iterator_resource,
|
||||||
self._incarnation_id,
|
self._incarnation_id,
|
||||||
self._source_device_tensor,
|
self._source_device_tensor,
|
||||||
self._dataset._element_structure) # pylint: disable=protected-access
|
self._dataset.element_spec)
|
||||||
self._prototype_device_datasets.append(ds)
|
self._prototype_device_datasets.append(ds)
|
||||||
|
|
||||||
# TODO(rohanj): Explore the possibility of the MultiDeviceIterator to
|
# TODO(rohanj): Explore the possibility of the MultiDeviceIterator to
|
||||||
@ -339,5 +340,5 @@ class MultiDeviceIterator(object):
|
|||||||
ds_variant, self._device_iterators[i]._iterator_resource)
|
ds_variant, self._device_iterators[i]._iterator_resource)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _element_structure(self):
|
def element_spec(self):
|
||||||
return dataset_ops.get_structure(self._dataset)
|
return self._dataset.element_spec
|
||||||
|
@ -115,7 +115,7 @@ class _TextLineDataset(dataset_ops.DatasetSource):
|
|||||||
super(_TextLineDataset, self).__init__(variant_tensor)
|
super(_TextLineDataset, self).__init__(variant_tensor)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _element_structure(self):
|
def element_spec(self):
|
||||||
return structure.TensorStructure(dtypes.string, [])
|
return structure.TensorStructure(dtypes.string, [])
|
||||||
|
|
||||||
|
|
||||||
@ -157,7 +157,7 @@ class TextLineDatasetV2(dataset_ops.DatasetSource):
|
|||||||
super(TextLineDatasetV2, self).__init__(variant_tensor)
|
super(TextLineDatasetV2, self).__init__(variant_tensor)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _element_structure(self):
|
def element_spec(self):
|
||||||
return structure.TensorStructure(dtypes.string, [])
|
return structure.TensorStructure(dtypes.string, [])
|
||||||
|
|
||||||
|
|
||||||
@ -209,7 +209,7 @@ class _TFRecordDataset(dataset_ops.DatasetSource):
|
|||||||
super(_TFRecordDataset, self).__init__(variant_tensor)
|
super(_TFRecordDataset, self).__init__(variant_tensor)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _element_structure(self):
|
def element_spec(self):
|
||||||
return structure.TensorStructure(dtypes.string, [])
|
return structure.TensorStructure(dtypes.string, [])
|
||||||
|
|
||||||
|
|
||||||
@ -225,7 +225,7 @@ class ParallelInterleaveDataset(dataset_ops.UnaryDataset):
|
|||||||
if not isinstance(self._map_func.output_structure,
|
if not isinstance(self._map_func.output_structure,
|
||||||
dataset_ops.DatasetStructure):
|
dataset_ops.DatasetStructure):
|
||||||
raise TypeError("`map_func` must return a `Dataset` object.")
|
raise TypeError("`map_func` must return a `Dataset` object.")
|
||||||
self._structure = self._map_func.output_structure._element_structure # pylint: disable=protected-access
|
self._element_spec = self._map_func.output_structure._element_spec # pylint: disable=protected-access
|
||||||
self._cycle_length = ops.convert_to_tensor(
|
self._cycle_length = ops.convert_to_tensor(
|
||||||
cycle_length, dtype=dtypes.int64, name="cycle_length")
|
cycle_length, dtype=dtypes.int64, name="cycle_length")
|
||||||
self._block_length = ops.convert_to_tensor(
|
self._block_length = ops.convert_to_tensor(
|
||||||
@ -257,8 +257,8 @@ class ParallelInterleaveDataset(dataset_ops.UnaryDataset):
|
|||||||
return [self._map_func]
|
return [self._map_func]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _element_structure(self):
|
def element_spec(self):
|
||||||
return self._structure
|
return self._element_spec
|
||||||
|
|
||||||
def _transformation_name(self):
|
def _transformation_name(self):
|
||||||
return "tf.data.experimental.parallel_interleave()"
|
return "tf.data.experimental.parallel_interleave()"
|
||||||
@ -321,7 +321,7 @@ class TFRecordDatasetV2(dataset_ops.DatasetV2):
|
|||||||
return self._impl._inputs() # pylint: disable=protected-access
|
return self._impl._inputs() # pylint: disable=protected-access
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _element_structure(self):
|
def element_spec(self):
|
||||||
return structure.TensorStructure(dtypes.string, [])
|
return structure.TensorStructure(dtypes.string, [])
|
||||||
|
|
||||||
|
|
||||||
@ -408,7 +408,7 @@ class _FixedLengthRecordDataset(dataset_ops.DatasetSource):
|
|||||||
super(_FixedLengthRecordDataset, self).__init__(variant_tensor)
|
super(_FixedLengthRecordDataset, self).__init__(variant_tensor)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _element_structure(self):
|
def element_spec(self):
|
||||||
return structure.TensorStructure(dtypes.string, [])
|
return structure.TensorStructure(dtypes.string, [])
|
||||||
|
|
||||||
|
|
||||||
@ -466,7 +466,7 @@ class FixedLengthRecordDatasetV2(dataset_ops.DatasetSource):
|
|||||||
super(FixedLengthRecordDatasetV2, self).__init__(variant_tensor)
|
super(FixedLengthRecordDatasetV2, self).__init__(variant_tensor)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _element_structure(self):
|
def element_spec(self):
|
||||||
return structure.TensorStructure(dtypes.string, [])
|
return structure.TensorStructure(dtypes.string, [])
|
||||||
|
|
||||||
|
|
||||||
|
@ -480,7 +480,7 @@ class DistributedDataset(_IterableInput):
|
|||||||
self._input_workers = input_workers
|
self._input_workers = input_workers
|
||||||
# TODO(anjalisridhar): Identify if we need to set this property on the
|
# TODO(anjalisridhar): Identify if we need to set this property on the
|
||||||
# iterator.
|
# iterator.
|
||||||
self._element_structure = dataset._element_structure # pylint: disable=protected-access
|
self._element_spec = dataset.element_spec
|
||||||
self._strategy = strategy
|
self._strategy = strategy
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
@ -490,7 +490,7 @@ class DistributedDataset(_IterableInput):
|
|||||||
self._input_workers)
|
self._input_workers)
|
||||||
iterator = DistributedIterator(self._input_workers, worker_iterators,
|
iterator = DistributedIterator(self._input_workers, worker_iterators,
|
||||||
self._strategy)
|
self._strategy)
|
||||||
iterator._element_structure = self._element_structure # pylint: disable=protected-access
|
iterator.element_spec = self._element_spec
|
||||||
return iterator
|
return iterator
|
||||||
raise RuntimeError("__iter__() is only supported inside of tf.function "
|
raise RuntimeError("__iter__() is only supported inside of tf.function "
|
||||||
"or when eager execution is enabled.")
|
"or when eager execution is enabled.")
|
||||||
@ -537,7 +537,7 @@ class DistributedDatasetV1(DistributedDataset):
|
|||||||
self._input_workers)
|
self._input_workers)
|
||||||
iterator = DistributedIteratorV1(self._input_workers, worker_iterators,
|
iterator = DistributedIteratorV1(self._input_workers, worker_iterators,
|
||||||
self._strategy)
|
self._strategy)
|
||||||
iterator._element_structure = self._element_structure # pylint: disable=protected-access
|
iterator.element_spec = self._element_spec
|
||||||
return iterator
|
return iterator
|
||||||
|
|
||||||
|
|
||||||
@ -670,9 +670,9 @@ class DatasetIterator(DistributedIteratorV1):
|
|||||||
dist_dataset._cloned_datasets, input_workers) # pylint: disable=protected-access
|
dist_dataset._cloned_datasets, input_workers) # pylint: disable=protected-access
|
||||||
super(DatasetIterator, self).__init__(
|
super(DatasetIterator, self).__init__(
|
||||||
input_workers,
|
input_workers,
|
||||||
worker_iterators, # pylint: disable=protected-access
|
worker_iterators,
|
||||||
strategy)
|
strategy)
|
||||||
self._element_structure = dist_dataset._element_structure # pylint: disable=protected-access
|
self._element_spec = dist_dataset._element_spec # pylint: disable=protected-access
|
||||||
|
|
||||||
|
|
||||||
def _dummy_tensor_fn(value_structure):
|
def _dummy_tensor_fn(value_structure):
|
||||||
|
@ -56,8 +56,7 @@ def _clone_dataset(dataset):
|
|||||||
variant_tensor_ops = traverse.obtain_all_variant_tensor_ops(dataset)
|
variant_tensor_ops = traverse.obtain_all_variant_tensor_ops(dataset)
|
||||||
remap_dict = _clone_helper(dataset._variant_tensor.op, variant_tensor_ops)
|
remap_dict = _clone_helper(dataset._variant_tensor.op, variant_tensor_ops)
|
||||||
new_variant_tensor = remap_dict[dataset._variant_tensor.op].outputs[0]
|
new_variant_tensor = remap_dict[dataset._variant_tensor.op].outputs[0]
|
||||||
return dataset_ops._VariantDataset(new_variant_tensor,
|
return dataset_ops._VariantDataset(new_variant_tensor, dataset.element_spec)
|
||||||
dataset._element_structure)
|
|
||||||
|
|
||||||
|
|
||||||
def _get_op_def(op):
|
def _get_op_def(op):
|
||||||
|
@ -276,8 +276,7 @@ class CloneDatasetTest(test.TestCase):
|
|||||||
def _assert_datasets_equal(self, ds1, ds2):
|
def _assert_datasets_equal(self, ds1, ds2):
|
||||||
# First lets assert the structure is the same.
|
# First lets assert the structure is the same.
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
structure.are_compatible(ds1._element_structure,
|
structure.are_compatible(ds1.element_spec, ds2.element_spec))
|
||||||
ds2._element_structure))
|
|
||||||
|
|
||||||
# Now create iterators on both and assert they produce the same values.
|
# Now create iterators on both and assert they produce the same values.
|
||||||
it1 = dataset_ops.make_initializable_iterator(ds1)
|
it1 = dataset_ops.make_initializable_iterator(ds1)
|
||||||
|
@ -5,6 +5,10 @@ tf_class {
|
|||||||
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
|
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
|
||||||
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
|
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
|
||||||
is_instance: "<type \'object\'>"
|
is_instance: "<type \'object\'>"
|
||||||
|
member {
|
||||||
|
name: "element_spec"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
member {
|
member {
|
||||||
name: "output_classes"
|
name: "output_classes"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
|
@ -7,6 +7,10 @@ tf_class {
|
|||||||
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
|
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
|
||||||
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
|
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
|
||||||
is_instance: "<type \'object\'>"
|
is_instance: "<type \'object\'>"
|
||||||
|
member {
|
||||||
|
name: "element_spec"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
member {
|
member {
|
||||||
name: "output_classes"
|
name: "output_classes"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
|
@ -3,6 +3,10 @@ tf_class {
|
|||||||
is_instance: "<class \'tensorflow.python.data.ops.iterator_ops.Iterator\'>"
|
is_instance: "<class \'tensorflow.python.data.ops.iterator_ops.Iterator\'>"
|
||||||
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
|
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
|
||||||
is_instance: "<type \'object\'>"
|
is_instance: "<type \'object\'>"
|
||||||
|
member {
|
||||||
|
name: "element_spec"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
member {
|
member {
|
||||||
name: "initializer"
|
name: "initializer"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
|
@ -7,6 +7,10 @@ tf_class {
|
|||||||
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
|
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
|
||||||
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
|
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
|
||||||
is_instance: "<type \'object\'>"
|
is_instance: "<type \'object\'>"
|
||||||
|
member {
|
||||||
|
name: "element_spec"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
member {
|
member {
|
||||||
name: "output_classes"
|
name: "output_classes"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
|
@ -7,6 +7,10 @@ tf_class {
|
|||||||
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
|
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
|
||||||
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
|
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
|
||||||
is_instance: "<type \'object\'>"
|
is_instance: "<type \'object\'>"
|
||||||
|
member {
|
||||||
|
name: "element_spec"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
member {
|
member {
|
||||||
name: "output_classes"
|
name: "output_classes"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
|
@ -7,6 +7,10 @@ tf_class {
|
|||||||
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
|
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
|
||||||
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
|
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
|
||||||
is_instance: "<type \'object\'>"
|
is_instance: "<type \'object\'>"
|
||||||
|
member {
|
||||||
|
name: "element_spec"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
member {
|
member {
|
||||||
name: "output_classes"
|
name: "output_classes"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
|
@ -7,6 +7,10 @@ tf_class {
|
|||||||
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
|
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
|
||||||
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
|
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
|
||||||
is_instance: "<type \'object\'>"
|
is_instance: "<type \'object\'>"
|
||||||
|
member {
|
||||||
|
name: "element_spec"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
member {
|
member {
|
||||||
name: "output_classes"
|
name: "output_classes"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
|
@ -7,6 +7,10 @@ tf_class {
|
|||||||
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
|
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
|
||||||
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
|
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
|
||||||
is_instance: "<type \'object\'>"
|
is_instance: "<type \'object\'>"
|
||||||
|
member {
|
||||||
|
name: "element_spec"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
member {
|
member {
|
||||||
name: "output_classes"
|
name: "output_classes"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
|
@ -4,6 +4,10 @@ tf_class {
|
|||||||
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
|
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
|
||||||
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
|
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
|
||||||
is_instance: "<type \'object\'>"
|
is_instance: "<type \'object\'>"
|
||||||
|
member {
|
||||||
|
name: "element_spec"
|
||||||
|
mtype: "<class \'abc.abstractproperty\'>"
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "__init__"
|
name: "__init__"
|
||||||
argspec: "args=[\'self\', \'variant_tensor\'], varargs=None, keywords=None, defaults=None"
|
argspec: "args=[\'self\', \'variant_tensor\'], varargs=None, keywords=None, defaults=None"
|
||||||
|
@ -6,6 +6,10 @@ tf_class {
|
|||||||
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
|
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
|
||||||
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
|
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
|
||||||
is_instance: "<type \'object\'>"
|
is_instance: "<type \'object\'>"
|
||||||
|
member {
|
||||||
|
name: "element_spec"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "__init__"
|
name: "__init__"
|
||||||
argspec: "args=[\'self\', \'filenames\', \'record_bytes\', \'header_bytes\', \'footer_bytes\', \'buffer_size\', \'compression_type\', \'num_parallel_reads\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], "
|
argspec: "args=[\'self\', \'filenames\', \'record_bytes\', \'header_bytes\', \'footer_bytes\', \'buffer_size\', \'compression_type\', \'num_parallel_reads\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], "
|
||||||
|
@ -5,6 +5,10 @@ tf_class {
|
|||||||
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
|
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
|
||||||
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
|
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
|
||||||
is_instance: "<type \'object\'>"
|
is_instance: "<type \'object\'>"
|
||||||
|
member {
|
||||||
|
name: "element_spec"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "__init__"
|
name: "__init__"
|
||||||
argspec: "args=[\'self\', \'filenames\', \'compression_type\', \'buffer_size\', \'num_parallel_reads\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
|
argspec: "args=[\'self\', \'filenames\', \'compression_type\', \'buffer_size\', \'num_parallel_reads\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
|
||||||
|
@ -6,6 +6,10 @@ tf_class {
|
|||||||
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
|
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
|
||||||
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
|
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
|
||||||
is_instance: "<type \'object\'>"
|
is_instance: "<type \'object\'>"
|
||||||
|
member {
|
||||||
|
name: "element_spec"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "__init__"
|
name: "__init__"
|
||||||
argspec: "args=[\'self\', \'filenames\', \'compression_type\', \'buffer_size\', \'num_parallel_reads\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
|
argspec: "args=[\'self\', \'filenames\', \'compression_type\', \'buffer_size\', \'num_parallel_reads\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
|
||||||
|
@ -6,6 +6,10 @@ tf_class {
|
|||||||
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
|
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
|
||||||
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
|
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
|
||||||
is_instance: "<type \'object\'>"
|
is_instance: "<type \'object\'>"
|
||||||
|
member {
|
||||||
|
name: "element_spec"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "__init__"
|
name: "__init__"
|
||||||
argspec: "args=[\'self\', \'filenames\', \'record_defaults\', \'compression_type\', \'buffer_size\', \'header\', \'field_delim\', \'use_quote_delim\', \'na_value\', \'select_cols\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\', \',\', \'True\', \'\', \'None\'], "
|
argspec: "args=[\'self\', \'filenames\', \'record_defaults\', \'compression_type\', \'buffer_size\', \'header\', \'field_delim\', \'use_quote_delim\', \'na_value\', \'select_cols\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\', \',\', \'True\', \'\', \'None\'], "
|
||||||
|
@ -6,6 +6,10 @@ tf_class {
|
|||||||
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
|
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
|
||||||
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
|
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
|
||||||
is_instance: "<type \'object\'>"
|
is_instance: "<type \'object\'>"
|
||||||
|
member {
|
||||||
|
name: "element_spec"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "__init__"
|
name: "__init__"
|
||||||
argspec: "args=[\'self\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[\'self\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
|
@ -6,6 +6,10 @@ tf_class {
|
|||||||
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
|
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
|
||||||
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
|
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
|
||||||
is_instance: "<type \'object\'>"
|
is_instance: "<type \'object\'>"
|
||||||
|
member {
|
||||||
|
name: "element_spec"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "__init__"
|
name: "__init__"
|
||||||
argspec: "args=[\'self\', \'driver_name\', \'data_source_name\', \'query\', \'output_types\'], varargs=None, keywords=None, defaults=None"
|
argspec: "args=[\'self\', \'driver_name\', \'data_source_name\', \'query\', \'output_types\'], varargs=None, keywords=None, defaults=None"
|
||||||
|
Loading…
x
Reference in New Issue
Block a user