[tf.data] Exposing dataset / iterator element type specification in the public API as element_spec
.
This CL renames the pre-existing `_element_structure` property of tf.data datasets and iterators to `element_spec`, thus exposing it in the public API. PiperOrigin-RevId: 256201202
This commit is contained in:
parent
ff67edeac1
commit
9db44da931
tensorflow
contrib
bigtable/python/ops
data/python/ops
hadoop/python/ops
ignite/python/ops
kafka/python/ops
kinesis/python/ops
python
data
experimental
kernel_tests
ops
kernel_tests
ops
distribute
tools/api/golden
v1
tensorflow.data.-dataset.pbtxttensorflow.data.-fixed-length-record-dataset.pbtxttensorflow.data.-iterator.pbtxttensorflow.data.-t-f-record-dataset.pbtxttensorflow.data.-text-line-dataset.pbtxttensorflow.data.experimental.-csv-dataset.pbtxttensorflow.data.experimental.-random-dataset.pbtxttensorflow.data.experimental.-sql-dataset.pbtxt
v2
tensorflow.data.-dataset.pbtxttensorflow.data.-fixed-length-record-dataset.pbtxttensorflow.data.-t-f-record-dataset.pbtxttensorflow.data.-text-line-dataset.pbtxttensorflow.data.experimental.-csv-dataset.pbtxttensorflow.data.experimental.-random-dataset.pbtxttensorflow.data.experimental.-sql-dataset.pbtxt
@ -591,7 +591,7 @@ class _BigtableKeyDataset(dataset_ops.DatasetSource):
|
||||
self._table = table
|
||||
|
||||
@property
|
||||
def _element_structure(self):
|
||||
def element_spec(self):
|
||||
return structure.TensorStructure(dtypes.string, [])
|
||||
|
||||
|
||||
@ -652,7 +652,7 @@ class _BigtableLookupDataset(dataset_ops.DatasetSource):
|
||||
super(_BigtableLookupDataset, self).__init__(variant_tensor)
|
||||
|
||||
@property
|
||||
def _element_structure(self):
|
||||
def element_spec(self):
|
||||
return tuple([structure.TensorStructure(dtypes.string, [])] *
|
||||
self._num_outputs)
|
||||
|
||||
@ -681,7 +681,7 @@ class _BigtableScanDataset(dataset_ops.DatasetSource):
|
||||
super(_BigtableScanDataset, self).__init__(variant_tensor)
|
||||
|
||||
@property
|
||||
def _element_structure(self):
|
||||
def element_spec(self):
|
||||
return tuple([structure.TensorStructure(dtypes.string, [])] *
|
||||
self._num_outputs)
|
||||
|
||||
@ -703,6 +703,6 @@ class _BigtableSampleKeyPairsDataset(dataset_ops.DatasetSource):
|
||||
super(_BigtableSampleKeyPairsDataset, self).__init__(variant_tensor)
|
||||
|
||||
@property
|
||||
def _element_structure(self):
|
||||
def element_spec(self):
|
||||
return (structure.TensorStructure(dtypes.string, []),
|
||||
structure.TensorStructure(dtypes.string, []))
|
||||
|
@ -346,13 +346,13 @@ class _RestructuredDataset(dataset_ops.UnaryDataset):
|
||||
output_classes = nest.pack_sequence_as(
|
||||
output_types, nest.flatten(input_classes))
|
||||
|
||||
self._structure = structure.convert_legacy_structure(
|
||||
self._element_spec = structure.convert_legacy_structure(
|
||||
output_types, output_shapes, output_classes)
|
||||
variant_tensor = self._input_dataset._variant_tensor # pylint: disable=protected-access
|
||||
super(_RestructuredDataset, self).__init__(dataset, variant_tensor)
|
||||
|
||||
@property
|
||||
def _element_structure(self):
|
||||
return self._structure
|
||||
def element_spec(self):
|
||||
return self._element_spec
|
||||
|
||||
|
||||
|
@ -395,6 +395,6 @@ class LMDBDataset(dataset_ops.DatasetSource):
|
||||
super(LMDBDataset, self).__init__(variant_tensor)
|
||||
|
||||
@property
|
||||
def _element_structure(self):
|
||||
def element_spec(self):
|
||||
return (structure.TensorStructure(dtypes.string, []),
|
||||
structure.TensorStructure(dtypes.string, []))
|
||||
|
@ -39,7 +39,7 @@ class _SlideDataset(dataset_ops.UnaryDataset):
|
||||
window_shift, dtype=dtypes.int64, name="window_shift")
|
||||
|
||||
input_structure = dataset_ops.get_structure(input_dataset)
|
||||
self._structure = nest.map_structure(
|
||||
self._element_spec = nest.map_structure(
|
||||
lambda component_spec: component_spec._batch(None), input_structure) # pylint: disable=protected-access
|
||||
variant_tensor = ged_ops.experimental_sliding_window_dataset(
|
||||
self._input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
@ -50,8 +50,8 @@ class _SlideDataset(dataset_ops.UnaryDataset):
|
||||
super(_SlideDataset, self).__init__(input_dataset, variant_tensor)
|
||||
|
||||
@property
|
||||
def _element_structure(self):
|
||||
return self._structure
|
||||
def element_spec(self):
|
||||
return self._element_spec
|
||||
|
||||
|
||||
@deprecation.deprecated_args(
|
||||
|
@ -58,11 +58,10 @@ class SequenceFileDataset(dataset_ops.DatasetSource):
|
||||
self._filenames = ops.convert_to_tensor(
|
||||
filenames, dtype=dtypes.string, name="filenames")
|
||||
variant_tensor = gen_dataset_ops.sequence_file_dataset(
|
||||
self._filenames,
|
||||
structure.get_flat_tensor_types(self._element_structure))
|
||||
self._filenames, self._flat_types)
|
||||
super(SequenceFileDataset, self).__init__(variant_tensor)
|
||||
|
||||
@property
|
||||
def _element_structure(self):
|
||||
def element_spec(self):
|
||||
return (structure.TensorStructure(dtypes.string, []),
|
||||
structure.TensorStructure(dtypes.string, []))
|
||||
|
@ -754,7 +754,7 @@ class IgniteDataset(dataset_ops.DatasetSource):
|
||||
self.cache_type.to_permutation(),
|
||||
dtype=dtypes.int32,
|
||||
name="permutation")
|
||||
self._structure = structure.convert_legacy_structure(
|
||||
self._element_spec = structure.convert_legacy_structure(
|
||||
self.cache_type.to_output_types(), self.cache_type.to_output_shapes(),
|
||||
self.cache_type.to_output_classes())
|
||||
|
||||
@ -766,5 +766,5 @@ class IgniteDataset(dataset_ops.DatasetSource):
|
||||
self.schema, self.permutation)
|
||||
|
||||
@property
|
||||
def _element_structure(self):
|
||||
return self._structure
|
||||
def element_spec(self):
|
||||
return self._element_spec
|
||||
|
@ -69,5 +69,5 @@ class KafkaDataset(dataset_ops.DatasetSource):
|
||||
self._group, self._eof, self._timeout)
|
||||
|
||||
@property
|
||||
def _element_structure(self):
|
||||
def element_spec(self):
|
||||
return structure.TensorStructure(dtypes.string, [])
|
||||
|
@ -86,5 +86,5 @@ class KinesisDataset(dataset_ops.DatasetSource):
|
||||
self._stream, self._shard, self._read_indefinitely, self._interval)
|
||||
|
||||
@property
|
||||
def _element_structure(self):
|
||||
def element_spec(self):
|
||||
return structure.TensorStructure(dtypes.string, [])
|
||||
|
@ -37,7 +37,7 @@ class WrapDatasetVariantTest(test_base.DatasetTestBase):
|
||||
unwrapped_variant = gen_dataset_ops.unwrap_dataset_variant(wrapped_variant)
|
||||
|
||||
variant_ds = dataset_ops._VariantDataset(unwrapped_variant,
|
||||
ds._element_structure)
|
||||
ds.element_spec)
|
||||
get_next = self.getNext(variant_ds, requires_initialization=True)
|
||||
for i in range(100):
|
||||
self.assertEqual(i, self.evaluate(get_next()))
|
||||
@ -54,7 +54,7 @@ class WrapDatasetVariantTest(test_base.DatasetTestBase):
|
||||
unwrapped_variant = gen_dataset_ops.unwrap_dataset_variant(
|
||||
gpu_wrapped_variant)
|
||||
variant_ds = dataset_ops._VariantDataset(unwrapped_variant,
|
||||
ds._element_structure)
|
||||
ds.element_spec)
|
||||
iterator = dataset_ops.make_initializable_iterator(variant_ds)
|
||||
get_next = iterator.get_next()
|
||||
|
||||
|
@ -242,7 +242,7 @@ class _DenseToSparseBatchDataset(dataset_ops.UnaryDataset):
|
||||
self._input_dataset = input_dataset
|
||||
self._batch_size = batch_size
|
||||
self._row_shape = row_shape
|
||||
self._structure = structure.SparseTensorStructure(
|
||||
self._element_spec = structure.SparseTensorStructure(
|
||||
dataset_ops.get_legacy_output_types(input_dataset),
|
||||
tensor_shape.vector(None).concatenate(self._row_shape))
|
||||
|
||||
@ -255,8 +255,8 @@ class _DenseToSparseBatchDataset(dataset_ops.UnaryDataset):
|
||||
variant_tensor)
|
||||
|
||||
@property
|
||||
def _element_structure(self):
|
||||
return self._structure
|
||||
def element_spec(self):
|
||||
return self._element_spec
|
||||
|
||||
|
||||
class _MapAndBatchDataset(dataset_ops.UnaryDataset):
|
||||
@ -285,12 +285,12 @@ class _MapAndBatchDataset(dataset_ops.UnaryDataset):
|
||||
# NOTE(mrry): `constant_drop_remainder` may be `None` (unknown statically)
|
||||
# or `False` (explicitly retaining the remainder).
|
||||
# pylint: disable=g-long-lambda
|
||||
self._structure = nest.map_structure(
|
||||
self._element_spec = nest.map_structure(
|
||||
lambda component_spec: component_spec._batch(
|
||||
tensor_util.constant_value(self._batch_size_t)),
|
||||
self._map_func.output_structure)
|
||||
else:
|
||||
self._structure = nest.map_structure(
|
||||
self._element_spec = nest.map_structure(
|
||||
lambda component_spec: component_spec._batch(None),
|
||||
self._map_func.output_structure)
|
||||
# pylint: enable=protected-access
|
||||
@ -309,5 +309,5 @@ class _MapAndBatchDataset(dataset_ops.UnaryDataset):
|
||||
return [self._map_func]
|
||||
|
||||
@property
|
||||
def _element_structure(self):
|
||||
return self._structure
|
||||
def element_spec(self):
|
||||
return self._element_spec
|
||||
|
@ -46,7 +46,7 @@ class _AutoShardDataset(dataset_ops.UnaryDataset):
|
||||
def __init__(self, input_dataset, num_workers, index):
|
||||
self._input_dataset = input_dataset
|
||||
|
||||
self._structure = input_dataset._element_structure # pylint: disable=protected-access
|
||||
self._element_spec = input_dataset.element_spec
|
||||
variant_tensor = ged_ops.experimental_auto_shard_dataset(
|
||||
self._input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
num_workers=num_workers,
|
||||
@ -55,8 +55,8 @@ class _AutoShardDataset(dataset_ops.UnaryDataset):
|
||||
super(_AutoShardDataset, self).__init__(input_dataset, variant_tensor)
|
||||
|
||||
@property
|
||||
def _element_structure(self):
|
||||
return self._structure
|
||||
def element_spec(self):
|
||||
return self._element_spec
|
||||
|
||||
|
||||
def _AutoShardDatasetV1(input_dataset, num_workers, index):
|
||||
@ -85,7 +85,7 @@ class _RebatchDataset(dataset_ops.UnaryDataset):
|
||||
input_classes = dataset_ops.get_legacy_output_classes(self._input_dataset)
|
||||
output_shapes = nest.map_structure(recalculate_output_shapes, input_shapes)
|
||||
|
||||
self._structure = structure.convert_legacy_structure(
|
||||
self._element_spec = structure.convert_legacy_structure(
|
||||
input_types, output_shapes, input_classes)
|
||||
variant_tensor = ged_ops.experimental_rebatch_dataset(
|
||||
self._input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
@ -94,8 +94,8 @@ class _RebatchDataset(dataset_ops.UnaryDataset):
|
||||
super(_RebatchDataset, self).__init__(input_dataset, variant_tensor)
|
||||
|
||||
@property
|
||||
def _element_structure(self):
|
||||
return self._structure
|
||||
def element_spec(self):
|
||||
return self._element_spec
|
||||
|
||||
|
||||
_AutoShardDatasetV1.__doc__ = _AutoShardDataset.__doc__
|
||||
|
@ -65,6 +65,6 @@ def get_single_element(dataset):
|
||||
|
||||
# pylint: disable=protected-access
|
||||
return structure.from_compatible_tensor_list(
|
||||
dataset._element_structure,
|
||||
dataset.element_spec,
|
||||
gen_dataset_ops.dataset_to_single_element(
|
||||
dataset._variant_tensor, **dataset._flat_structure)) # pylint: disable=protected-access
|
||||
|
@ -299,8 +299,7 @@ class _GroupByReducerDataset(dataset_ops.UnaryDataset):
|
||||
wrapped_func = dataset_ops.StructuredFunctionWrapper(
|
||||
reduce_func,
|
||||
self._transformation_name(),
|
||||
input_structure=(self._state_structure,
|
||||
input_dataset._element_structure), # pylint: disable=protected-access
|
||||
input_structure=(self._state_structure, input_dataset.element_spec),
|
||||
add_to_graph=False)
|
||||
|
||||
# Extract and validate class information from the returned values.
|
||||
@ -355,7 +354,7 @@ class _GroupByReducerDataset(dataset_ops.UnaryDataset):
|
||||
input_structure=self._state_structure)
|
||||
|
||||
@property
|
||||
def _element_structure(self):
|
||||
def element_spec(self):
|
||||
return self._finalize_func.output_structure
|
||||
|
||||
def _functions(self):
|
||||
@ -416,7 +415,7 @@ class _GroupByWindowDataset(dataset_ops.UnaryDataset):
|
||||
def _make_reduce_func(self, reduce_func, input_dataset):
|
||||
"""Make wrapping defun for reduce_func."""
|
||||
nested_dataset = dataset_ops.DatasetStructure(
|
||||
input_dataset._element_structure) # pylint: disable=protected-access
|
||||
input_dataset.element_spec)
|
||||
input_structure = (structure.TensorStructure(dtypes.int64,
|
||||
[]), nested_dataset)
|
||||
self._reduce_func = dataset_ops.StructuredFunctionWrapper(
|
||||
@ -426,12 +425,12 @@ class _GroupByWindowDataset(dataset_ops.UnaryDataset):
|
||||
self._reduce_func.output_structure, dataset_ops.DatasetStructure):
|
||||
raise TypeError("`reduce_func` must return a `Dataset` object.")
|
||||
# pylint: disable=protected-access
|
||||
self._structure = (
|
||||
self._reduce_func.output_structure._element_structure)
|
||||
self._element_spec = (
|
||||
self._reduce_func.output_structure._element_spec)
|
||||
|
||||
@property
|
||||
def _element_structure(self):
|
||||
return self._structure
|
||||
def element_spec(self):
|
||||
return self._element_spec
|
||||
|
||||
def _functions(self):
|
||||
return [self._key_func, self._reduce_func, self._window_size_func]
|
||||
|
@ -119,7 +119,7 @@ class _DirectedInterleaveDataset(dataset_ops.Dataset):
|
||||
nest.flatten(dataset_ops.get_legacy_output_shapes(data_input)))
|
||||
])
|
||||
|
||||
self._structure = structure.convert_legacy_structure(
|
||||
self._element_spec = structure.convert_legacy_structure(
|
||||
first_output_types, output_shapes, first_output_classes)
|
||||
super(_DirectedInterleaveDataset, self).__init__()
|
||||
|
||||
@ -136,8 +136,8 @@ class _DirectedInterleaveDataset(dataset_ops.Dataset):
|
||||
return [self._selector_input] + self._data_inputs
|
||||
|
||||
@property
|
||||
def _element_structure(self):
|
||||
return self._structure
|
||||
def element_spec(self):
|
||||
return self._element_spec
|
||||
|
||||
|
||||
@tf_export("data.experimental.sample_from_datasets", v1=[])
|
||||
@ -267,8 +267,8 @@ def choose_from_datasets_v2(datasets, choice_dataset):
|
||||
TypeError: If the `datasets` or `choice_dataset` arguments have the wrong
|
||||
type.
|
||||
"""
|
||||
if not dataset_ops.get_structure(choice_dataset).is_compatible_with(
|
||||
structure.TensorStructure(dtypes.int64, [])):
|
||||
if not structure.are_compatible(choice_dataset.element_spec,
|
||||
structure.TensorStructure(dtypes.int64, [])):
|
||||
raise TypeError("`choice_dataset` must be a dataset of scalar "
|
||||
"`tf.int64` tensors.")
|
||||
return _DirectedInterleaveDataset(choice_dataset, datasets)
|
||||
|
@ -35,5 +35,5 @@ class MatchingFilesDataset(dataset_ops.DatasetSource):
|
||||
super(MatchingFilesDataset, self).__init__(variant_tensor)
|
||||
|
||||
@property
|
||||
def _element_structure(self):
|
||||
def element_spec(self):
|
||||
return structure.TensorStructure(dtypes.string, [])
|
||||
|
@ -156,7 +156,7 @@ class _ChooseFastestDataset(dataset_ops.DatasetV2):
|
||||
A `Dataset` that has the same elements the inputs.
|
||||
"""
|
||||
self._datasets = list(datasets)
|
||||
self._structure = self._datasets[0]._element_structure # pylint: disable=protected-access
|
||||
self._element_spec = self._datasets[0].element_spec
|
||||
variant_tensor = (
|
||||
gen_experimental_dataset_ops.experimental_choose_fastest_dataset(
|
||||
[dataset._variant_tensor for dataset in self._datasets], # pylint: disable=protected-access
|
||||
@ -168,8 +168,8 @@ class _ChooseFastestDataset(dataset_ops.DatasetV2):
|
||||
return self._datasets
|
||||
|
||||
@property
|
||||
def _element_structure(self):
|
||||
return self._datasets[0]._element_structure # pylint: disable=protected-access
|
||||
def element_spec(self):
|
||||
return self._element_spec
|
||||
|
||||
|
||||
class _ChooseFastestBranchDataset(dataset_ops.UnaryDataset):
|
||||
@ -242,14 +242,13 @@ class _ChooseFastestBranchDataset(dataset_ops.UnaryDataset):
|
||||
Returns:
|
||||
A `Dataset` that has the same elements the inputs.
|
||||
"""
|
||||
input_structure = dataset_ops.DatasetStructure(
|
||||
dataset_ops.get_structure(input_dataset))
|
||||
input_structure = dataset_ops.DatasetStructure(input_dataset.element_spec)
|
||||
self._funcs = [
|
||||
dataset_ops.StructuredFunctionWrapper(
|
||||
f, "ChooseFastestV2", input_structure=input_structure)
|
||||
for f in functions
|
||||
]
|
||||
self._structure = self._funcs[0].output_structure._element_structure # pylint: disable=protected-access
|
||||
self._element_spec = self._funcs[0].output_structure._element_spec # pylint: disable=protected-access
|
||||
|
||||
self._captured_arguments = []
|
||||
for f in self._funcs:
|
||||
@ -279,5 +278,5 @@ class _ChooseFastestBranchDataset(dataset_ops.UnaryDataset):
|
||||
variant_tensor)
|
||||
|
||||
@property
|
||||
def _element_structure(self):
|
||||
return self._structure
|
||||
def element_spec(self):
|
||||
return self._element_spec
|
||||
|
@ -32,7 +32,8 @@ class _ParseExampleDataset(dataset_ops.UnaryDataset):
|
||||
|
||||
def __init__(self, input_dataset, features, num_parallel_calls):
|
||||
self._input_dataset = input_dataset
|
||||
if not input_dataset._element_structure.is_compatible_with( # pylint: disable=protected-access
|
||||
if not structure.are_compatible(
|
||||
input_dataset.element_spec,
|
||||
structure.TensorStructure(dtypes.string, [None])):
|
||||
raise TypeError("Input dataset should be a dataset of vectors of strings")
|
||||
self._num_parallel_calls = num_parallel_calls
|
||||
@ -75,7 +76,7 @@ class _ParseExampleDataset(dataset_ops.UnaryDataset):
|
||||
[ops.Tensor for _ in range(len(self._dense_defaults))] +
|
||||
[sparse_tensor.SparseTensor for _ in range(len(self._sparse_keys))
|
||||
]))
|
||||
self._structure = structure.convert_legacy_structure(
|
||||
self._element_spec = structure.convert_legacy_structure(
|
||||
output_types, output_shapes, output_classes)
|
||||
|
||||
variant_tensor = (
|
||||
@ -91,8 +92,8 @@ class _ParseExampleDataset(dataset_ops.UnaryDataset):
|
||||
super(_ParseExampleDataset, self).__init__(input_dataset, variant_tensor)
|
||||
|
||||
@property
|
||||
def _element_structure(self):
|
||||
return self._structure
|
||||
def element_spec(self):
|
||||
return self._element_spec
|
||||
|
||||
|
||||
# TODO(b/111553342): add arguments names and example names as well.
|
||||
|
@ -146,8 +146,7 @@ class _CopyToDeviceDataset(dataset_ops.UnaryUnchangedStructureDataset):
|
||||
dataset_ops.get_legacy_output_types(self),
|
||||
dataset_ops.get_legacy_output_shapes(self),
|
||||
dataset_ops.get_legacy_output_classes(self))
|
||||
return structure.to_tensor_list(self._element_structure,
|
||||
iterator.get_next())
|
||||
return structure.to_tensor_list(self.element_spec, iterator.get_next())
|
||||
|
||||
next_func_concrete = _next_func._get_concrete_function_internal() # pylint: disable=protected-access
|
||||
|
||||
@ -251,7 +250,7 @@ class _MapOnGpuDataset(dataset_ops.UnaryDataset):
|
||||
return [self._map_func]
|
||||
|
||||
@property
|
||||
def _element_structure(self):
|
||||
def element_spec(self):
|
||||
return self._map_func.output_structure
|
||||
|
||||
def _transformation_name(self):
|
||||
|
@ -39,7 +39,7 @@ class RandomDatasetV2(dataset_ops.DatasetSource):
|
||||
super(RandomDatasetV2, self).__init__(variant_tensor)
|
||||
|
||||
@property
|
||||
def _element_structure(self):
|
||||
def element_spec(self):
|
||||
return structure.TensorStructure(dtypes.int64, [])
|
||||
|
||||
|
||||
|
@ -662,7 +662,7 @@ class CsvDatasetV2(dataset_ops.DatasetSource):
|
||||
argument_default=[],
|
||||
argument_dtype=dtypes.int64,
|
||||
)
|
||||
self._structure = tuple(
|
||||
self._element_spec = tuple(
|
||||
structure.TensorStructure(d.dtype, []) for d in self._record_defaults)
|
||||
variant_tensor = gen_experimental_dataset_ops.experimental_csv_dataset(
|
||||
filenames=self._filenames,
|
||||
@ -678,8 +678,8 @@ class CsvDatasetV2(dataset_ops.DatasetSource):
|
||||
super(CsvDatasetV2, self).__init__(variant_tensor)
|
||||
|
||||
@property
|
||||
def _element_structure(self):
|
||||
return self._structure
|
||||
def element_spec(self):
|
||||
return self._element_spec
|
||||
|
||||
|
||||
@tf_export(v1=["data.experimental.CsvDataset"])
|
||||
@ -955,7 +955,7 @@ class SqlDatasetV2(dataset_ops.DatasetSource):
|
||||
data_source_name, dtype=dtypes.string, name="data_source_name")
|
||||
self._query = ops.convert_to_tensor(
|
||||
query, dtype=dtypes.string, name="query")
|
||||
self._structure = nest.map_structure(
|
||||
self._element_spec = nest.map_structure(
|
||||
lambda dtype: structure.TensorStructure(dtype, []), output_types)
|
||||
variant_tensor = gen_experimental_dataset_ops.experimental_sql_dataset(
|
||||
self._driver_name, self._data_source_name, self._query,
|
||||
@ -963,8 +963,8 @@ class SqlDatasetV2(dataset_ops.DatasetSource):
|
||||
super(SqlDatasetV2, self).__init__(variant_tensor)
|
||||
|
||||
@property
|
||||
def _element_structure(self):
|
||||
return self._structure
|
||||
def element_spec(self):
|
||||
return self._element_spec
|
||||
|
||||
|
||||
@tf_export(v1=["data.experimental.SqlDataset"])
|
||||
|
@ -49,7 +49,7 @@ class _ScanDataset(dataset_ops.UnaryDataset):
|
||||
scan_func,
|
||||
self._transformation_name(),
|
||||
input_structure=(self._state_structure,
|
||||
input_dataset._element_structure), # pylint: disable=protected-access
|
||||
input_dataset.element_spec),
|
||||
add_to_graph=False)
|
||||
if not (
|
||||
isinstance(wrapped_func.output_types, collections.Sequence) and
|
||||
@ -91,7 +91,7 @@ class _ScanDataset(dataset_ops.UnaryDataset):
|
||||
old_state_shapes = nest.map_structure(
|
||||
lambda component_spec: component_spec._to_legacy_output_shapes(), # pylint: disable=protected-access
|
||||
self._state_structure)
|
||||
self._structure = structure.convert_legacy_structure(
|
||||
self._element_spec = structure.convert_legacy_structure(
|
||||
output_types, output_shapes, output_classes)
|
||||
|
||||
flat_state_shapes = nest.flatten(old_state_shapes)
|
||||
@ -135,8 +135,8 @@ class _ScanDataset(dataset_ops.UnaryDataset):
|
||||
return [self._scan_func]
|
||||
|
||||
@property
|
||||
def _element_structure(self):
|
||||
return self._structure
|
||||
def element_spec(self):
|
||||
return self._element_spec
|
||||
|
||||
def _transformation_name(self):
|
||||
return "tf.data.experimental.scan()"
|
||||
|
@ -43,21 +43,6 @@ from tensorflow.python.platform import test
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
|
||||
|
||||
# For testing deserialization of Datasets represented as functions
|
||||
class _RevivedDataset(dataset_ops.DatasetV2):
|
||||
|
||||
def __init__(self, variant, element_structure):
|
||||
self._structure = element_structure
|
||||
super(_RevivedDataset, self).__init__(variant)
|
||||
|
||||
def _inputs(self):
|
||||
return []
|
||||
|
||||
@property
|
||||
def _element_structure(self):
|
||||
return self._structure
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
@ -75,8 +60,8 @@ class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
fn = original_dataset._trace_variant_creation()
|
||||
variant = fn()
|
||||
|
||||
revived_dataset = _RevivedDataset(
|
||||
variant, original_dataset._element_structure)
|
||||
revived_dataset = dataset_ops._VariantDataset(
|
||||
variant, original_dataset.element_spec)
|
||||
self.assertDatasetProduces(revived_dataset, range(0, 10, 2))
|
||||
|
||||
def testAsFunctionWithMapInFlatMap(self):
|
||||
@ -88,8 +73,8 @@ class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
fn = original_dataset._trace_variant_creation()
|
||||
variant = fn()
|
||||
|
||||
revived_dataset = _RevivedDataset(
|
||||
variant, original_dataset._element_structure)
|
||||
revived_dataset = dataset_ops._VariantDataset(
|
||||
variant, original_dataset.element_spec)
|
||||
self.assertDatasetProduces(revived_dataset, list(original_dataset))
|
||||
|
||||
@staticmethod
|
||||
|
@ -315,14 +315,14 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor):
|
||||
"or when eager execution is enabled.")
|
||||
|
||||
@abc.abstractproperty
|
||||
def _element_structure(self):
|
||||
"""The structure of an element of this dataset.
|
||||
def element_spec(self):
|
||||
"""The type specification of an element of this dataset.
|
||||
|
||||
Returns:
|
||||
A nested structure of `tf.TypeSpec` objects matching the structure of an
|
||||
element of this dataset and specifying the type of individual components.
|
||||
"""
|
||||
raise NotImplementedError("Dataset._element_structure")
|
||||
raise NotImplementedError("Dataset.element_spec")
|
||||
|
||||
def __repr__(self):
|
||||
output_shapes = nest.map_structure(str, get_legacy_output_shapes(self))
|
||||
@ -339,7 +339,7 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor):
|
||||
Returns:
|
||||
A list `tf.TensorShapes`s for the element tensor representation.
|
||||
"""
|
||||
return structure.get_flat_tensor_shapes(self._element_structure)
|
||||
return structure.get_flat_tensor_shapes(self.element_spec)
|
||||
|
||||
@property
|
||||
def _flat_types(self):
|
||||
@ -348,7 +348,7 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor):
|
||||
Returns:
|
||||
A list `tf.DType`s for the element tensor representation.
|
||||
"""
|
||||
return structure.get_flat_tensor_types(self._element_structure)
|
||||
return structure.get_flat_tensor_types(self.element_spec)
|
||||
|
||||
@property
|
||||
def _flat_structure(self):
|
||||
@ -371,7 +371,7 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor):
|
||||
|
||||
@property
|
||||
def _type_spec(self):
|
||||
return DatasetStructure(self._element_structure)
|
||||
return DatasetStructure(self.element_spec)
|
||||
|
||||
@staticmethod
|
||||
def from_tensors(tensors):
|
||||
@ -1442,7 +1442,7 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor):
|
||||
wrapped_func = StructuredFunctionWrapper(
|
||||
reduce_func,
|
||||
"reduce()",
|
||||
input_structure=(state_structure, self._element_structure),
|
||||
input_structure=(state_structure, self.element_spec),
|
||||
add_to_graph=False)
|
||||
|
||||
# Extract and validate class information from the returned values.
|
||||
@ -1546,17 +1546,17 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor):
|
||||
def normalize(arg, *rest):
|
||||
# pylint: disable=protected-access
|
||||
if rest:
|
||||
return structure.to_batched_tensor_list(self._element_structure,
|
||||
return structure.to_batched_tensor_list(self.element_spec,
|
||||
(arg,) + rest)
|
||||
else:
|
||||
return structure.to_batched_tensor_list(self._element_structure, arg)
|
||||
return structure.to_batched_tensor_list(self.element_spec, arg)
|
||||
|
||||
normalized_dataset = self.map(normalize)
|
||||
|
||||
# NOTE(mrry): Our `map()` has lost information about the structure of
|
||||
# non-tensor components, so re-apply the structure of the original dataset.
|
||||
restructured_dataset = _RestructuredDataset(normalized_dataset,
|
||||
self._element_structure)
|
||||
self.element_spec)
|
||||
return _UnbatchDataset(restructured_dataset)
|
||||
|
||||
def with_options(self, options):
|
||||
@ -1750,7 +1750,7 @@ class DatasetV1(DatasetV2):
|
||||
"""
|
||||
return nest.map_structure(
|
||||
lambda component_spec: component_spec._to_legacy_output_classes(), # pylint: disable=protected-access
|
||||
self._element_structure)
|
||||
self.element_spec)
|
||||
|
||||
@property
|
||||
@deprecation.deprecated(
|
||||
@ -1764,7 +1764,7 @@ class DatasetV1(DatasetV2):
|
||||
"""
|
||||
return nest.map_structure(
|
||||
lambda component_spec: component_spec._to_legacy_output_shapes(), # pylint: disable=protected-access
|
||||
self._element_structure)
|
||||
self.element_spec)
|
||||
|
||||
@property
|
||||
@deprecation.deprecated(
|
||||
@ -1778,10 +1778,10 @@ class DatasetV1(DatasetV2):
|
||||
"""
|
||||
return nest.map_structure(
|
||||
lambda component_spec: component_spec._to_legacy_output_types(), # pylint: disable=protected-access
|
||||
self._element_structure)
|
||||
self.element_spec)
|
||||
|
||||
@property
|
||||
def _element_structure(self):
|
||||
def element_spec(self):
|
||||
# TODO(b/110122868): Remove this override once all `Dataset` instances
|
||||
# implement `element_structure`.
|
||||
return structure.convert_legacy_structure(
|
||||
@ -2003,8 +2003,8 @@ class DatasetV1Adapter(DatasetV1):
|
||||
return self._dataset.options()
|
||||
|
||||
@property
|
||||
def _element_structure(self):
|
||||
return self._dataset._element_structure # pylint: disable=protected-access
|
||||
def element_spec(self):
|
||||
return self._dataset.element_spec # pylint: disable=protected-access
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self._dataset)
|
||||
@ -2107,7 +2107,7 @@ def get_structure(dataset_or_iterator):
|
||||
TypeError: If `dataset_or_iterator` is not a `Dataset` or `Iterator` object.
|
||||
"""
|
||||
try:
|
||||
return dataset_or_iterator._element_structure # pylint: disable=protected-access
|
||||
return dataset_or_iterator.element_spec # pylint: disable=protected-access
|
||||
except AttributeError:
|
||||
raise TypeError("`dataset_or_iterator` must be a Dataset or Iterator "
|
||||
"object, but got %s." % type(dataset_or_iterator))
|
||||
@ -2306,8 +2306,8 @@ class UnaryUnchangedStructureDataset(UnaryDataset):
|
||||
input_dataset, variant_tensor)
|
||||
|
||||
@property
|
||||
def _element_structure(self):
|
||||
return self._input_dataset._element_structure # pylint: disable=protected-access
|
||||
def element_spec(self):
|
||||
return self._input_dataset.element_spec
|
||||
|
||||
|
||||
class TensorDataset(DatasetSource):
|
||||
@ -2325,7 +2325,7 @@ class TensorDataset(DatasetSource):
|
||||
super(TensorDataset, self).__init__(variant_tensor)
|
||||
|
||||
@property
|
||||
def _element_structure(self):
|
||||
def element_spec(self):
|
||||
return self._structure
|
||||
|
||||
|
||||
@ -2352,7 +2352,7 @@ class TensorSliceDataset(DatasetSource):
|
||||
super(TensorSliceDataset, self).__init__(variant_tensor)
|
||||
|
||||
@property
|
||||
def _element_structure(self):
|
||||
def element_spec(self):
|
||||
return self._structure
|
||||
|
||||
|
||||
@ -2381,7 +2381,7 @@ class SparseTensorSliceDataset(DatasetSource):
|
||||
super(SparseTensorSliceDataset, self).__init__(variant_tensor)
|
||||
|
||||
@property
|
||||
def _element_structure(self):
|
||||
def element_spec(self):
|
||||
return self._structure
|
||||
|
||||
|
||||
@ -2396,20 +2396,20 @@ class _VariantDataset(DatasetV2):
|
||||
return []
|
||||
|
||||
@property
|
||||
def _element_structure(self):
|
||||
def element_spec(self):
|
||||
return self._structure
|
||||
|
||||
|
||||
class _NestedVariant(composite_tensor.CompositeTensor):
|
||||
|
||||
def __init__(self, variant_tensor, element_structure, dataset_shape):
|
||||
def __init__(self, variant_tensor, element_spec, dataset_shape):
|
||||
self._variant_tensor = variant_tensor
|
||||
self._element_structure = element_structure
|
||||
self._element_spec = element_spec
|
||||
self._dataset_shape = dataset_shape
|
||||
|
||||
@property
|
||||
def _type_spec(self):
|
||||
return DatasetStructure(self._element_structure, self._dataset_shape)
|
||||
return DatasetStructure(self._element_spec, self._dataset_shape)
|
||||
|
||||
|
||||
@tf_export("data.experimental.from_variant")
|
||||
@ -2445,10 +2445,10 @@ def to_variant(dataset):
|
||||
class DatasetStructure(type_spec.BatchableTypeSpec):
|
||||
"""Type specification for `tf.data.Dataset`."""
|
||||
|
||||
__slots__ = ["_element_structure", "_dataset_shape"]
|
||||
__slots__ = ["_element_spec", "_dataset_shape"]
|
||||
|
||||
def __init__(self, element_spec, dataset_shape=None):
|
||||
self._element_structure = element_spec
|
||||
self._element_spec = element_spec
|
||||
if dataset_shape:
|
||||
self._dataset_shape = dataset_shape
|
||||
else:
|
||||
@ -2459,7 +2459,7 @@ class DatasetStructure(type_spec.BatchableTypeSpec):
|
||||
return _VariantDataset
|
||||
|
||||
def _serialize(self):
|
||||
return (self._element_structure, self._dataset_shape)
|
||||
return (self._element_spec, self._dataset_shape)
|
||||
|
||||
@property
|
||||
def _component_specs(self):
|
||||
@ -2471,10 +2471,9 @@ class DatasetStructure(type_spec.BatchableTypeSpec):
|
||||
def _from_components(self, components):
|
||||
# pylint: disable=protected-access
|
||||
if self._dataset_shape.ndims == 0:
|
||||
return _VariantDataset(components, self._element_structure)
|
||||
return _VariantDataset(components, self._element_spec)
|
||||
else:
|
||||
return _NestedVariant(components, self._element_structure,
|
||||
self._dataset_shape)
|
||||
return _NestedVariant(components, self._element_spec, self._dataset_shape)
|
||||
|
||||
def _to_tensor_list(self, value):
|
||||
return [
|
||||
@ -2484,17 +2483,17 @@ class DatasetStructure(type_spec.BatchableTypeSpec):
|
||||
|
||||
@staticmethod
|
||||
def from_value(value):
|
||||
return DatasetStructure(value._element_structure) # pylint: disable=protected-access
|
||||
return DatasetStructure(value.element_spec) # pylint: disable=protected-access
|
||||
|
||||
def _batch(self, batch_size):
|
||||
return DatasetStructure(
|
||||
self._element_structure,
|
||||
self._element_spec,
|
||||
tensor_shape.TensorShape([batch_size]).concatenate(self._dataset_shape))
|
||||
|
||||
def _unbatch(self):
|
||||
if self._dataset_shape.ndims == 0:
|
||||
raise ValueError("Unbatching a dataset is only supported for rank >= 1")
|
||||
return DatasetStructure(self._element_structure, self._dataset_shape[1:])
|
||||
return DatasetStructure(self._element_spec, self._dataset_shape[1:])
|
||||
|
||||
def _to_batched_tensor_list(self, value):
|
||||
if self._dataset_shape.ndims == 0:
|
||||
@ -2571,7 +2570,7 @@ class StructuredFunctionWrapper(object):
|
||||
raise ValueError("Either `dataset`, `input_structure` or all of "
|
||||
"`input_classes`, `input_shapes`, and `input_types` "
|
||||
"must be specified.")
|
||||
self._input_structure = dataset._element_structure
|
||||
self._input_structure = dataset.element_spec
|
||||
else:
|
||||
if not (dataset is None and input_classes is None and input_shapes is None
|
||||
and input_types is None):
|
||||
@ -2767,7 +2766,7 @@ class _GeneratorDataset(DatasetSource):
|
||||
super(_GeneratorDataset, self).__init__(variant_tensor)
|
||||
|
||||
@property
|
||||
def _element_structure(self):
|
||||
def element_spec(self):
|
||||
return self._next_func.output_structure
|
||||
|
||||
def _transformation_name(self):
|
||||
@ -2792,7 +2791,7 @@ class ZipDataset(DatasetV2):
|
||||
self._datasets = datasets
|
||||
self._structure = nest.pack_sequence_as(
|
||||
self._datasets,
|
||||
[ds._element_structure for ds in nest.flatten(self._datasets)]) # pylint: disable=protected-access
|
||||
[ds.element_spec for ds in nest.flatten(self._datasets)])
|
||||
variant_tensor = gen_dataset_ops.zip_dataset(
|
||||
[ds._variant_tensor for ds in nest.flatten(self._datasets)],
|
||||
**self._flat_structure)
|
||||
@ -2802,7 +2801,7 @@ class ZipDataset(DatasetV2):
|
||||
return nest.flatten(self._datasets)
|
||||
|
||||
@property
|
||||
def _element_structure(self):
|
||||
def element_spec(self):
|
||||
return self._structure
|
||||
|
||||
|
||||
@ -2850,7 +2849,7 @@ class ConcatenateDataset(DatasetV2):
|
||||
return self._input_datasets
|
||||
|
||||
@property
|
||||
def _element_structure(self):
|
||||
def element_spec(self):
|
||||
return self._structure
|
||||
|
||||
|
||||
@ -2907,7 +2906,7 @@ class RangeDataset(DatasetSource):
|
||||
return ops.convert_to_tensor(int64_value, dtype=dtypes.int64, name=name)
|
||||
|
||||
@property
|
||||
def _element_structure(self):
|
||||
def element_spec(self):
|
||||
return self._structure
|
||||
|
||||
|
||||
@ -3037,11 +3036,11 @@ class BatchDataset(UnaryDataset):
|
||||
self._structure = nest.map_structure(
|
||||
lambda component_spec: component_spec._batch(
|
||||
tensor_util.constant_value(self._batch_size)),
|
||||
input_dataset._element_structure)
|
||||
input_dataset.element_spec)
|
||||
else:
|
||||
self._structure = nest.map_structure(
|
||||
lambda component_spec: component_spec._batch(None),
|
||||
input_dataset._element_structure)
|
||||
input_dataset.element_spec)
|
||||
variant_tensor = gen_dataset_ops.batch_dataset_v2(
|
||||
input_dataset._variant_tensor,
|
||||
batch_size=self._batch_size,
|
||||
@ -3050,7 +3049,7 @@ class BatchDataset(UnaryDataset):
|
||||
super(BatchDataset, self).__init__(input_dataset, variant_tensor)
|
||||
|
||||
@property
|
||||
def _element_structure(self):
|
||||
def element_spec(self):
|
||||
return self._structure
|
||||
|
||||
|
||||
@ -3268,7 +3267,7 @@ class PaddedBatchDataset(UnaryDataset):
|
||||
super(PaddedBatchDataset, self).__init__(input_dataset, variant_tensor)
|
||||
|
||||
@property
|
||||
def _element_structure(self):
|
||||
def element_spec(self):
|
||||
return self._structure
|
||||
|
||||
|
||||
@ -3308,7 +3307,7 @@ class MapDataset(UnaryDataset):
|
||||
return [self._map_func]
|
||||
|
||||
@property
|
||||
def _element_structure(self):
|
||||
def element_spec(self):
|
||||
return self._map_func.output_structure
|
||||
|
||||
def _transformation_name(self):
|
||||
@ -3350,7 +3349,7 @@ class ParallelMapDataset(UnaryDataset):
|
||||
return [self._map_func]
|
||||
|
||||
@property
|
||||
def _element_structure(self):
|
||||
def element_spec(self):
|
||||
return self._map_func.output_structure
|
||||
|
||||
def _transformation_name(self):
|
||||
@ -3369,7 +3368,7 @@ class FlatMapDataset(UnaryDataset):
|
||||
raise TypeError(
|
||||
"`map_func` must return a `Dataset` object. Got {}".format(
|
||||
type(self._map_func.output_structure)))
|
||||
self._structure = self._map_func.output_structure._element_structure # pylint: disable=protected-access
|
||||
self._structure = self._map_func.output_structure._element_spec # pylint: disable=protected-access
|
||||
variant_tensor = gen_dataset_ops.flat_map_dataset(
|
||||
input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
self._map_func.function.captured_inputs,
|
||||
@ -3381,7 +3380,7 @@ class FlatMapDataset(UnaryDataset):
|
||||
return [self._map_func]
|
||||
|
||||
@property
|
||||
def _element_structure(self):
|
||||
def element_spec(self):
|
||||
return self._structure
|
||||
|
||||
def _transformation_name(self):
|
||||
@ -3400,7 +3399,7 @@ class InterleaveDataset(UnaryDataset):
|
||||
raise TypeError(
|
||||
"`map_func` must return a `Dataset` object. Got {}".format(
|
||||
type(self._map_func.output_structure)))
|
||||
self._structure = self._map_func.output_structure._element_structure # pylint: disable=protected-access
|
||||
self._structure = self._map_func.output_structure._element_spec # pylint: disable=protected-access
|
||||
self._cycle_length = ops.convert_to_tensor(
|
||||
cycle_length, dtype=dtypes.int64, name="cycle_length")
|
||||
self._block_length = ops.convert_to_tensor(
|
||||
@ -3419,7 +3418,7 @@ class InterleaveDataset(UnaryDataset):
|
||||
return [self._map_func]
|
||||
|
||||
@property
|
||||
def _element_structure(self):
|
||||
def element_spec(self):
|
||||
return self._structure
|
||||
|
||||
def _transformation_name(self):
|
||||
@ -3439,7 +3438,7 @@ class ParallelInterleaveDataset(UnaryDataset):
|
||||
raise TypeError(
|
||||
"`map_func` must return a `Dataset` object. Got {}".format(
|
||||
type(self._map_func.output_structure)))
|
||||
self._structure = self._map_func.output_structure._element_structure # pylint: disable=protected-access
|
||||
self._structure = self._map_func.output_structure._element_spec # pylint: disable=protected-access
|
||||
self._cycle_length = ops.convert_to_tensor(
|
||||
cycle_length, dtype=dtypes.int64, name="cycle_length")
|
||||
self._block_length = ops.convert_to_tensor(
|
||||
@ -3461,7 +3460,7 @@ class ParallelInterleaveDataset(UnaryDataset):
|
||||
return [self._map_func]
|
||||
|
||||
@property
|
||||
def _element_structure(self):
|
||||
def element_spec(self):
|
||||
return self._structure
|
||||
|
||||
def _transformation_name(self):
|
||||
@ -3561,7 +3560,7 @@ class WindowDataset(UnaryDataset):
|
||||
super(WindowDataset, self).__init__(input_dataset, variant_tensor)
|
||||
|
||||
@property
|
||||
def _element_structure(self):
|
||||
def element_spec(self):
|
||||
return self._structure
|
||||
|
||||
|
||||
@ -3684,7 +3683,7 @@ class _RestructuredDataset(UnaryDataset):
|
||||
super(_RestructuredDataset, self).__init__(dataset, variant_tensor)
|
||||
|
||||
@property
|
||||
def _element_structure(self):
|
||||
def element_spec(self):
|
||||
return self._structure
|
||||
|
||||
|
||||
@ -3713,5 +3712,5 @@ class _UnbatchDataset(UnaryDataset):
|
||||
super(_UnbatchDataset, self).__init__(input_dataset, variant_tensor)
|
||||
|
||||
@property
|
||||
def _element_structure(self):
|
||||
def element_spec(self):
|
||||
return self._structure
|
||||
|
@ -104,10 +104,12 @@ class Iterator(trackable.Trackable):
|
||||
raise ValueError("If `structure` is not specified, all of "
|
||||
"`output_types`, `output_shapes`, and `output_classes`"
|
||||
" must be specified.")
|
||||
self._structure = structure.convert_legacy_structure(
|
||||
self._element_spec = structure.convert_legacy_structure(
|
||||
output_types, output_shapes, output_classes)
|
||||
self._flat_tensor_shapes = structure.get_flat_tensor_shapes(self._structure)
|
||||
self._flat_tensor_types = structure.get_flat_tensor_types(self._structure)
|
||||
self._flat_tensor_shapes = structure.get_flat_tensor_shapes(
|
||||
self._element_spec)
|
||||
self._flat_tensor_types = structure.get_flat_tensor_types(
|
||||
self._element_spec)
|
||||
|
||||
self._string_handle = gen_dataset_ops.iterator_to_string_handle(
|
||||
self._iterator_resource)
|
||||
@ -347,13 +349,13 @@ class Iterator(trackable.Trackable):
|
||||
# pylint: disable=protected-access
|
||||
dataset_output_types = nest.map_structure(
|
||||
lambda component_spec: component_spec._to_legacy_output_types(),
|
||||
dataset._element_structure)
|
||||
dataset.element_spec)
|
||||
dataset_output_shapes = nest.map_structure(
|
||||
lambda component_spec: component_spec._to_legacy_output_shapes(),
|
||||
dataset._element_structure)
|
||||
dataset.element_spec)
|
||||
dataset_output_classes = nest.map_structure(
|
||||
lambda component_spec: component_spec._to_legacy_output_classes(),
|
||||
dataset._element_structure)
|
||||
dataset.element_spec)
|
||||
# pylint: enable=protected-access
|
||||
|
||||
nest.assert_same_structure(self.output_types, dataset_output_types)
|
||||
@ -436,7 +438,7 @@ class Iterator(trackable.Trackable):
|
||||
output_types=self._flat_tensor_types,
|
||||
output_shapes=self._flat_tensor_shapes,
|
||||
name=name)
|
||||
return structure.from_tensor_list(self._structure, flat_ret)
|
||||
return structure.from_tensor_list(self._element_spec, flat_ret)
|
||||
|
||||
def string_handle(self, name=None):
|
||||
"""Returns a string-valued `tf.Tensor` that represents this iterator.
|
||||
@ -467,7 +469,7 @@ class Iterator(trackable.Trackable):
|
||||
"""
|
||||
return nest.map_structure(
|
||||
lambda component_spec: component_spec._to_legacy_output_classes(), # pylint: disable=protected-access
|
||||
self._structure)
|
||||
self._element_spec)
|
||||
|
||||
@property
|
||||
@deprecation.deprecated(
|
||||
@ -481,7 +483,7 @@ class Iterator(trackable.Trackable):
|
||||
"""
|
||||
return nest.map_structure(
|
||||
lambda component_spec: component_spec._to_legacy_output_shapes(), # pylint: disable=protected-access
|
||||
self._structure)
|
||||
self._element_spec)
|
||||
|
||||
@property
|
||||
@deprecation.deprecated(
|
||||
@ -495,17 +497,17 @@ class Iterator(trackable.Trackable):
|
||||
"""
|
||||
return nest.map_structure(
|
||||
lambda component_spec: component_spec._to_legacy_output_types(), # pylint: disable=protected-access
|
||||
self._structure)
|
||||
self._element_spec)
|
||||
|
||||
@property
|
||||
def _element_structure(self):
|
||||
"""The structure of an element of this iterator.
|
||||
def element_spec(self):
|
||||
"""The type specification of an element of this iterator.
|
||||
|
||||
Returns:
|
||||
A `Structure` object representing the structure of the components of this
|
||||
optional.
|
||||
A nested structure of `tf.TypeSpec` objects matching the structure of an
|
||||
element of this iterator and specifying the type of individual components.
|
||||
"""
|
||||
return self._structure
|
||||
return self._element_spec
|
||||
|
||||
def _gather_saveables_for_checkpoint(self):
|
||||
|
||||
@ -556,7 +558,7 @@ class IteratorResourceDeleter(object):
|
||||
class IteratorV2(trackable.Trackable, composite_tensor.CompositeTensor):
|
||||
"""An iterator producing tf.Tensor objects from a tf.data.Dataset."""
|
||||
|
||||
def __init__(self, dataset=None, components=None, element_structure=None):
|
||||
def __init__(self, dataset=None, components=None, element_spec=None):
|
||||
"""Creates a new iterator from the given dataset.
|
||||
|
||||
If `dataset` is not specified, the iterator will be created from the given
|
||||
@ -567,29 +569,29 @@ class IteratorV2(trackable.Trackable, composite_tensor.CompositeTensor):
|
||||
Args:
|
||||
dataset: A `tf.data.Dataset` object.
|
||||
components: Tensor components to construct the iterator from.
|
||||
element_structure: A nested structure of `TypeSpec` objects that
|
||||
represents the type specification elements of the iterator.
|
||||
element_spec: A nested structure of `TypeSpec` objects that
|
||||
represents the type specification of elements of the iterator.
|
||||
|
||||
Raises:
|
||||
ValueError: If `dataset` is not provided and either `components` or
|
||||
`element_structure` is not provided. Or `dataset` is provided and either
|
||||
`components` and `element_structure` is provided.
|
||||
`element_spec` is not provided. Or `dataset` is provided and either
|
||||
`components` and `element_spec` is provided.
|
||||
"""
|
||||
|
||||
error_message = "Either `dataset` or both `components` and "
|
||||
"`element_structure` need to be provided."
|
||||
"`element_spec` need to be provided."
|
||||
|
||||
self._device = context.context().device_name
|
||||
|
||||
if dataset is None:
|
||||
if (components is None or element_structure is None):
|
||||
if (components is None or element_spec is None):
|
||||
raise ValueError(error_message)
|
||||
# pylint: disable=protected-access
|
||||
self._structure = element_structure
|
||||
self._element_spec = element_spec
|
||||
self._flat_output_types = structure.get_flat_tensor_types(
|
||||
self._structure)
|
||||
self._element_spec)
|
||||
self._flat_output_shapes = structure.get_flat_tensor_shapes(
|
||||
self._structure)
|
||||
self._element_spec)
|
||||
self._iterator_resource, self._deleter = components
|
||||
# Delete the resource when this object is deleted
|
||||
self._resource_deleter = IteratorResourceDeleter(
|
||||
@ -597,7 +599,7 @@ class IteratorV2(trackable.Trackable, composite_tensor.CompositeTensor):
|
||||
device=self._device,
|
||||
deleter=self._deleter)
|
||||
else:
|
||||
if (components is not None or element_structure is not None):
|
||||
if (components is not None or element_spec is not None):
|
||||
raise ValueError(error_message)
|
||||
if (_device_stack_is_empty() or
|
||||
context.context().device_spec.device_type != "CPU"):
|
||||
@ -610,11 +612,11 @@ class IteratorV2(trackable.Trackable, composite_tensor.CompositeTensor):
|
||||
# pylint: disable=protected-access
|
||||
dataset = dataset._apply_options()
|
||||
ds_variant = dataset._variant_tensor
|
||||
self._structure = dataset._element_structure
|
||||
self._element_spec = dataset.element_spec
|
||||
self._flat_output_types = structure.get_flat_tensor_types(
|
||||
self._structure)
|
||||
self._element_spec)
|
||||
self._flat_output_shapes = structure.get_flat_tensor_shapes(
|
||||
self._structure)
|
||||
self._element_spec)
|
||||
with ops.colocate_with(ds_variant):
|
||||
self._iterator_resource, self._deleter = (
|
||||
gen_dataset_ops.anonymous_iterator_v2(
|
||||
@ -642,7 +644,7 @@ class IteratorV2(trackable.Trackable, composite_tensor.CompositeTensor):
|
||||
self._iterator_resource,
|
||||
output_types=self._flat_output_types,
|
||||
output_shapes=self._flat_output_shapes)
|
||||
return structure.from_compatible_tensor_list(self._structure, ret)
|
||||
return structure.from_compatible_tensor_list(self._element_spec, ret)
|
||||
|
||||
# This runs in sync mode as iterators use an error status to communicate
|
||||
# that there is no more data to iterate over.
|
||||
@ -664,13 +666,13 @@ class IteratorV2(trackable.Trackable, composite_tensor.CompositeTensor):
|
||||
|
||||
try:
|
||||
# Fast path for the case `self._structure` is not a nested structure.
|
||||
return self._structure._from_compatible_tensor_list(ret) # pylint: disable=protected-access
|
||||
return self._element_spec._from_compatible_tensor_list(ret) # pylint: disable=protected-access
|
||||
except AttributeError:
|
||||
return structure.from_compatible_tensor_list(self._structure, ret)
|
||||
return structure.from_compatible_tensor_list(self._element_spec, ret)
|
||||
|
||||
@property
|
||||
def _type_spec(self):
|
||||
return IteratorSpec(self._element_structure)
|
||||
return IteratorSpec(self.element_spec)
|
||||
|
||||
def next(self):
|
||||
"""Returns a nested structure of `Tensor`s containing the next element."""
|
||||
@ -693,7 +695,7 @@ class IteratorV2(trackable.Trackable, composite_tensor.CompositeTensor):
|
||||
"""
|
||||
return nest.map_structure(
|
||||
lambda component_spec: component_spec._to_legacy_output_classes(), # pylint: disable=protected-access
|
||||
self._structure)
|
||||
self._element_spec)
|
||||
|
||||
@property
|
||||
@deprecation.deprecated(
|
||||
@ -707,7 +709,7 @@ class IteratorV2(trackable.Trackable, composite_tensor.CompositeTensor):
|
||||
"""
|
||||
return nest.map_structure(
|
||||
lambda component_spec: component_spec._to_legacy_output_shapes(), # pylint: disable=protected-access
|
||||
self._structure)
|
||||
self._element_spec)
|
||||
|
||||
@property
|
||||
@deprecation.deprecated(
|
||||
@ -721,17 +723,17 @@ class IteratorV2(trackable.Trackable, composite_tensor.CompositeTensor):
|
||||
"""
|
||||
return nest.map_structure(
|
||||
lambda component_spec: component_spec._to_legacy_output_types(), # pylint: disable=protected-access
|
||||
self._structure)
|
||||
self._element_spec)
|
||||
|
||||
@property
|
||||
def _element_structure(self):
|
||||
"""The structure of an element of this iterator.
|
||||
def element_spec(self):
|
||||
"""The type specification of an element of this iterator.
|
||||
|
||||
Returns:
|
||||
A `Structure` object representing the structure of the components of this
|
||||
optional.
|
||||
A nested structure of `tf.TypeSpec` objects matching the structure of an
|
||||
element of this iterator and specifying the type of individual components.
|
||||
"""
|
||||
return self._structure
|
||||
return self._element_spec
|
||||
|
||||
def get_next(self, name=None):
|
||||
"""Returns a nested structure of `tf.Tensor`s containing the next element.
|
||||
@ -786,11 +788,11 @@ class IteratorSpec(type_spec.TypeSpec):
|
||||
return IteratorV2(
|
||||
dataset=None,
|
||||
components=components,
|
||||
element_structure=self._element_spec)
|
||||
element_spec=self._element_spec)
|
||||
|
||||
@staticmethod
|
||||
def from_value(value):
|
||||
return IteratorSpec(value._element_structure) # pylint: disable=protected-access
|
||||
return IteratorSpec(value.element_spec) # pylint: disable=protected-access
|
||||
|
||||
|
||||
# TODO(b/71645805): Expose trackable stateful objects from dataset
|
||||
@ -828,7 +830,6 @@ def get_next_as_optional(iterator):
|
||||
return optional_ops._OptionalImpl(
|
||||
gen_dataset_ops.iterator_get_next_as_optional(
|
||||
iterator._iterator_resource,
|
||||
output_types=structure.get_flat_tensor_types(
|
||||
iterator._element_structure),
|
||||
output_types=structure.get_flat_tensor_types(iterator.element_spec),
|
||||
output_shapes=structure.get_flat_tensor_shapes(
|
||||
iterator._element_structure)), iterator._element_structure)
|
||||
iterator.element_spec)), iterator.element_spec)
|
||||
|
@ -36,8 +36,8 @@ class _PerDeviceGenerator(dataset_ops.DatasetV2):
|
||||
"""A `dummy` generator dataset."""
|
||||
|
||||
def __init__(self, shard_num, multi_device_iterator_resource, incarnation_id,
|
||||
source_device, element_structure):
|
||||
self._structure = element_structure
|
||||
source_device, element_spec):
|
||||
self._element_spec = element_spec
|
||||
|
||||
multi_device_iterator_string_handle = (
|
||||
gen_dataset_ops.multi_device_iterator_to_string_handle(
|
||||
@ -71,14 +71,15 @@ class _PerDeviceGenerator(dataset_ops.DatasetV2):
|
||||
multi_device_iterator = (
|
||||
gen_dataset_ops.multi_device_iterator_from_string_handle(
|
||||
string_handle=string_handle,
|
||||
output_types=structure.get_flat_tensor_types(self._structure),
|
||||
output_shapes=structure.get_flat_tensor_shapes(self._structure)))
|
||||
output_types=structure.get_flat_tensor_types(self._element_spec),
|
||||
output_shapes=structure.get_flat_tensor_shapes(
|
||||
self._element_spec)))
|
||||
return gen_dataset_ops.multi_device_iterator_get_next_from_shard(
|
||||
multi_device_iterator=multi_device_iterator,
|
||||
shard_num=shard_num,
|
||||
incarnation_id=incarnation_id,
|
||||
output_types=structure.get_flat_tensor_types(self._structure),
|
||||
output_shapes=structure.get_flat_tensor_shapes(self._structure))
|
||||
output_types=structure.get_flat_tensor_types(self._element_spec),
|
||||
output_shapes=structure.get_flat_tensor_shapes(self._element_spec))
|
||||
|
||||
next_func_concrete = _next_func._get_concrete_function_internal() # pylint: disable=protected-access
|
||||
|
||||
@ -91,7 +92,7 @@ class _PerDeviceGenerator(dataset_ops.DatasetV2):
|
||||
return functional_ops.remote_call(
|
||||
target=source_device,
|
||||
args=[string_handle] + next_func_concrete.captured_inputs,
|
||||
Tout=structure.get_flat_tensor_types(self._structure),
|
||||
Tout=structure.get_flat_tensor_types(self._element_spec),
|
||||
f=next_func_concrete)
|
||||
|
||||
self._next_func = _remote_next_func._get_concrete_function_internal() # pylint: disable=protected-access
|
||||
@ -141,8 +142,8 @@ class _PerDeviceGenerator(dataset_ops.DatasetV2):
|
||||
return []
|
||||
|
||||
@property
|
||||
def _element_structure(self):
|
||||
return self._structure
|
||||
def element_spec(self):
|
||||
return self._element_spec
|
||||
|
||||
|
||||
class _ReincarnatedPerDeviceGenerator(dataset_ops.DatasetV2):
|
||||
@ -154,7 +155,7 @@ class _ReincarnatedPerDeviceGenerator(dataset_ops.DatasetV2):
|
||||
|
||||
def __init__(self, per_device_dataset, incarnation_id):
|
||||
# pylint: disable=protected-access
|
||||
self._structure = per_device_dataset._element_structure
|
||||
self._element_spec = per_device_dataset.element_spec
|
||||
self._init_func = per_device_dataset._init_func
|
||||
self._init_captured_args = self._init_func.captured_inputs
|
||||
|
||||
@ -183,8 +184,8 @@ class _ReincarnatedPerDeviceGenerator(dataset_ops.DatasetV2):
|
||||
return []
|
||||
|
||||
@property
|
||||
def _element_structure(self):
|
||||
return self._structure
|
||||
def element_spec(self):
|
||||
return self._element_spec
|
||||
|
||||
|
||||
class MultiDeviceIterator(object):
|
||||
@ -253,7 +254,7 @@ class MultiDeviceIterator(object):
|
||||
ds = _PerDeviceGenerator(i, self._multi_device_iterator_resource,
|
||||
self._incarnation_id,
|
||||
self._source_device_tensor,
|
||||
self._dataset._element_structure) # pylint: disable=protected-access
|
||||
self._dataset.element_spec)
|
||||
self._prototype_device_datasets.append(ds)
|
||||
|
||||
# TODO(rohanj): Explore the possibility of the MultiDeviceIterator to
|
||||
@ -339,5 +340,5 @@ class MultiDeviceIterator(object):
|
||||
ds_variant, self._device_iterators[i]._iterator_resource)
|
||||
|
||||
@property
|
||||
def _element_structure(self):
|
||||
return dataset_ops.get_structure(self._dataset)
|
||||
def element_spec(self):
|
||||
return self._dataset.element_spec
|
||||
|
@ -115,7 +115,7 @@ class _TextLineDataset(dataset_ops.DatasetSource):
|
||||
super(_TextLineDataset, self).__init__(variant_tensor)
|
||||
|
||||
@property
|
||||
def _element_structure(self):
|
||||
def element_spec(self):
|
||||
return structure.TensorStructure(dtypes.string, [])
|
||||
|
||||
|
||||
@ -157,7 +157,7 @@ class TextLineDatasetV2(dataset_ops.DatasetSource):
|
||||
super(TextLineDatasetV2, self).__init__(variant_tensor)
|
||||
|
||||
@property
|
||||
def _element_structure(self):
|
||||
def element_spec(self):
|
||||
return structure.TensorStructure(dtypes.string, [])
|
||||
|
||||
|
||||
@ -209,7 +209,7 @@ class _TFRecordDataset(dataset_ops.DatasetSource):
|
||||
super(_TFRecordDataset, self).__init__(variant_tensor)
|
||||
|
||||
@property
|
||||
def _element_structure(self):
|
||||
def element_spec(self):
|
||||
return structure.TensorStructure(dtypes.string, [])
|
||||
|
||||
|
||||
@ -225,7 +225,7 @@ class ParallelInterleaveDataset(dataset_ops.UnaryDataset):
|
||||
if not isinstance(self._map_func.output_structure,
|
||||
dataset_ops.DatasetStructure):
|
||||
raise TypeError("`map_func` must return a `Dataset` object.")
|
||||
self._structure = self._map_func.output_structure._element_structure # pylint: disable=protected-access
|
||||
self._element_spec = self._map_func.output_structure._element_spec # pylint: disable=protected-access
|
||||
self._cycle_length = ops.convert_to_tensor(
|
||||
cycle_length, dtype=dtypes.int64, name="cycle_length")
|
||||
self._block_length = ops.convert_to_tensor(
|
||||
@ -257,8 +257,8 @@ class ParallelInterleaveDataset(dataset_ops.UnaryDataset):
|
||||
return [self._map_func]
|
||||
|
||||
@property
|
||||
def _element_structure(self):
|
||||
return self._structure
|
||||
def element_spec(self):
|
||||
return self._element_spec
|
||||
|
||||
def _transformation_name(self):
|
||||
return "tf.data.experimental.parallel_interleave()"
|
||||
@ -321,7 +321,7 @@ class TFRecordDatasetV2(dataset_ops.DatasetV2):
|
||||
return self._impl._inputs() # pylint: disable=protected-access
|
||||
|
||||
@property
|
||||
def _element_structure(self):
|
||||
def element_spec(self):
|
||||
return structure.TensorStructure(dtypes.string, [])
|
||||
|
||||
|
||||
@ -408,7 +408,7 @@ class _FixedLengthRecordDataset(dataset_ops.DatasetSource):
|
||||
super(_FixedLengthRecordDataset, self).__init__(variant_tensor)
|
||||
|
||||
@property
|
||||
def _element_structure(self):
|
||||
def element_spec(self):
|
||||
return structure.TensorStructure(dtypes.string, [])
|
||||
|
||||
|
||||
@ -466,7 +466,7 @@ class FixedLengthRecordDatasetV2(dataset_ops.DatasetSource):
|
||||
super(FixedLengthRecordDatasetV2, self).__init__(variant_tensor)
|
||||
|
||||
@property
|
||||
def _element_structure(self):
|
||||
def element_spec(self):
|
||||
return structure.TensorStructure(dtypes.string, [])
|
||||
|
||||
|
||||
|
@ -480,7 +480,7 @@ class DistributedDataset(_IterableInput):
|
||||
self._input_workers = input_workers
|
||||
# TODO(anjalisridhar): Identify if we need to set this property on the
|
||||
# iterator.
|
||||
self._element_structure = dataset._element_structure # pylint: disable=protected-access
|
||||
self._element_spec = dataset.element_spec
|
||||
self._strategy = strategy
|
||||
|
||||
def __iter__(self):
|
||||
@ -490,7 +490,7 @@ class DistributedDataset(_IterableInput):
|
||||
self._input_workers)
|
||||
iterator = DistributedIterator(self._input_workers, worker_iterators,
|
||||
self._strategy)
|
||||
iterator._element_structure = self._element_structure # pylint: disable=protected-access
|
||||
iterator.element_spec = self._element_spec
|
||||
return iterator
|
||||
raise RuntimeError("__iter__() is only supported inside of tf.function "
|
||||
"or when eager execution is enabled.")
|
||||
@ -537,7 +537,7 @@ class DistributedDatasetV1(DistributedDataset):
|
||||
self._input_workers)
|
||||
iterator = DistributedIteratorV1(self._input_workers, worker_iterators,
|
||||
self._strategy)
|
||||
iterator._element_structure = self._element_structure # pylint: disable=protected-access
|
||||
iterator.element_spec = self._element_spec
|
||||
return iterator
|
||||
|
||||
|
||||
@ -670,9 +670,9 @@ class DatasetIterator(DistributedIteratorV1):
|
||||
dist_dataset._cloned_datasets, input_workers) # pylint: disable=protected-access
|
||||
super(DatasetIterator, self).__init__(
|
||||
input_workers,
|
||||
worker_iterators, # pylint: disable=protected-access
|
||||
worker_iterators,
|
||||
strategy)
|
||||
self._element_structure = dist_dataset._element_structure # pylint: disable=protected-access
|
||||
self._element_spec = dist_dataset._element_spec # pylint: disable=protected-access
|
||||
|
||||
|
||||
def _dummy_tensor_fn(value_structure):
|
||||
|
@ -56,8 +56,7 @@ def _clone_dataset(dataset):
|
||||
variant_tensor_ops = traverse.obtain_all_variant_tensor_ops(dataset)
|
||||
remap_dict = _clone_helper(dataset._variant_tensor.op, variant_tensor_ops)
|
||||
new_variant_tensor = remap_dict[dataset._variant_tensor.op].outputs[0]
|
||||
return dataset_ops._VariantDataset(new_variant_tensor,
|
||||
dataset._element_structure)
|
||||
return dataset_ops._VariantDataset(new_variant_tensor, dataset.element_spec)
|
||||
|
||||
|
||||
def _get_op_def(op):
|
||||
|
@ -276,8 +276,7 @@ class CloneDatasetTest(test.TestCase):
|
||||
def _assert_datasets_equal(self, ds1, ds2):
|
||||
# First lets assert the structure is the same.
|
||||
self.assertTrue(
|
||||
structure.are_compatible(ds1._element_structure,
|
||||
ds2._element_structure))
|
||||
structure.are_compatible(ds1.element_spec, ds2.element_spec))
|
||||
|
||||
# Now create iterators on both and assert they produce the same values.
|
||||
it1 = dataset_ops.make_initializable_iterator(ds1)
|
||||
|
@ -5,6 +5,10 @@ tf_class {
|
||||
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
|
||||
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
name: "element_spec"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "output_classes"
|
||||
mtype: "<type \'property\'>"
|
||||
|
@ -7,6 +7,10 @@ tf_class {
|
||||
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
|
||||
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
name: "element_spec"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "output_classes"
|
||||
mtype: "<type \'property\'>"
|
||||
|
@ -3,6 +3,10 @@ tf_class {
|
||||
is_instance: "<class \'tensorflow.python.data.ops.iterator_ops.Iterator\'>"
|
||||
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
name: "element_spec"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "initializer"
|
||||
mtype: "<type \'property\'>"
|
||||
|
@ -7,6 +7,10 @@ tf_class {
|
||||
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
|
||||
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
name: "element_spec"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "output_classes"
|
||||
mtype: "<type \'property\'>"
|
||||
|
@ -7,6 +7,10 @@ tf_class {
|
||||
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
|
||||
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
name: "element_spec"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "output_classes"
|
||||
mtype: "<type \'property\'>"
|
||||
|
@ -7,6 +7,10 @@ tf_class {
|
||||
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
|
||||
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
name: "element_spec"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "output_classes"
|
||||
mtype: "<type \'property\'>"
|
||||
|
@ -7,6 +7,10 @@ tf_class {
|
||||
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
|
||||
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
name: "element_spec"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "output_classes"
|
||||
mtype: "<type \'property\'>"
|
||||
|
@ -7,6 +7,10 @@ tf_class {
|
||||
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
|
||||
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
name: "element_spec"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "output_classes"
|
||||
mtype: "<type \'property\'>"
|
||||
|
@ -4,6 +4,10 @@ tf_class {
|
||||
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
|
||||
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
name: "element_spec"
|
||||
mtype: "<class \'abc.abstractproperty\'>"
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'variant_tensor\'], varargs=None, keywords=None, defaults=None"
|
||||
|
@ -6,6 +6,10 @@ tf_class {
|
||||
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
|
||||
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
name: "element_spec"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'filenames\', \'record_bytes\', \'header_bytes\', \'footer_bytes\', \'buffer_size\', \'compression_type\', \'num_parallel_reads\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], "
|
||||
|
@ -5,6 +5,10 @@ tf_class {
|
||||
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
|
||||
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
name: "element_spec"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'filenames\', \'compression_type\', \'buffer_size\', \'num_parallel_reads\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
|
||||
|
@ -6,6 +6,10 @@ tf_class {
|
||||
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
|
||||
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
name: "element_spec"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'filenames\', \'compression_type\', \'buffer_size\', \'num_parallel_reads\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
|
||||
|
@ -6,6 +6,10 @@ tf_class {
|
||||
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
|
||||
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
name: "element_spec"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'filenames\', \'record_defaults\', \'compression_type\', \'buffer_size\', \'header\', \'field_delim\', \'use_quote_delim\', \'na_value\', \'select_cols\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\', \',\', \'True\', \'\', \'None\'], "
|
||||
|
@ -6,6 +6,10 @@ tf_class {
|
||||
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
|
||||
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
name: "element_spec"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
|
@ -6,6 +6,10 @@ tf_class {
|
||||
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
|
||||
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
name: "element_spec"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'driver_name\', \'data_source_name\', \'query\', \'output_types\'], varargs=None, keywords=None, defaults=None"
|
||||
|
Loading…
Reference in New Issue
Block a user