diff --git a/tensorflow/contrib/bigtable/python/ops/bigtable_api.py b/tensorflow/contrib/bigtable/python/ops/bigtable_api.py index b6cdc7aab03..fa64055dfd6 100644 --- a/tensorflow/contrib/bigtable/python/ops/bigtable_api.py +++ b/tensorflow/contrib/bigtable/python/ops/bigtable_api.py @@ -489,7 +489,7 @@ class BigtableTable(object): "len(dataset.output_types))") return gen_bigtable_ops.dataset_to_bigtable( self._resource, - dataset._as_variant_tensor(), # pylint: disable=protected-access + dataset._variant_tensor, # pylint: disable=protected-access column_families, columns, timestamp) @@ -582,13 +582,14 @@ class _BigtableKeyDataset(dataset_ops.DatasetSource): """_BigtableKeyDataset is an abstract class representing the keys of a table. """ - def __init__(self, table): + def __init__(self, table, variant_tensor): """Constructs a _BigtableKeyDataset. Args: table: a Bigtable class. + variant_tensor: DT_VARIANT representation of the dataset. """ - super(_BigtableKeyDataset, self).__init__() + super(_BigtableKeyDataset, self).__init__(variant_tensor) self._table = table @property @@ -601,13 +602,11 @@ class _BigtablePrefixKeyDataset(_BigtableKeyDataset): """ def __init__(self, table, prefix): - super(_BigtablePrefixKeyDataset, self).__init__(table) self._prefix = prefix - - def _as_variant_tensor(self): - return gen_bigtable_ops.bigtable_prefix_key_dataset( - table=self._table._resource, # pylint: disable=protected-access + variant_tensor = gen_bigtable_ops.bigtable_prefix_key_dataset( + table=table._resource, # pylint: disable=protected-access prefix=self._prefix) + super(_BigtablePrefixKeyDataset, self).__init__(table, variant_tensor) class _BigtableRangeKeyDataset(_BigtableKeyDataset): @@ -615,15 +614,13 @@ class _BigtableRangeKeyDataset(_BigtableKeyDataset): """ def __init__(self, table, start, end): - super(_BigtableRangeKeyDataset, self).__init__(table) self._start = start self._end = end - - def _as_variant_tensor(self): - return gen_bigtable_ops.bigtable_range_key_dataset( - table=self._table._resource, # pylint: disable=protected-access + variant_tensor = gen_bigtable_ops.bigtable_range_key_dataset( + table=table._resource, # pylint: disable=protected-access start_key=self._start, end_key=self._end) + super(_BigtableRangeKeyDataset, self).__init__(table, variant_tensor) class _BigtableSampleKeysDataset(_BigtableKeyDataset): @@ -633,11 +630,9 @@ class _BigtableSampleKeysDataset(_BigtableKeyDataset): # TODO(saeta): Expose the data size offsets into the keys. 
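Throughout this patch, datasets build their DT_VARIANT tensor inside __init__ and hand it to the base-class constructor, and call sites (such as dataset_to_bigtable in the first hunk) switch from the _as_variant_tensor() method to the _variant_tensor property. A minimal sketch of the call-site side, using TF internals as of this revision (illustrative only):

    import tensorflow as tf

    ds = tf.data.Dataset.range(4)
    # Post-refactor, the variant tensor is created during __init__ and cached,
    # so callers read a property instead of invoking a method.
    variant = ds._variant_tensor  # previously: ds._as_variant_tensor()
    print(variant.dtype)          # -> tf.variant
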
def __init__(self, table): - super(_BigtableSampleKeysDataset, self).__init__(table) - - def _as_variant_tensor(self): - return gen_bigtable_ops.bigtable_sample_keys_dataset( - table=self._table._resource) # pylint: disable=protected-access + variant_tensor = gen_bigtable_ops.bigtable_sample_keys_dataset( + table=table._resource) # pylint: disable=protected-access + super(_BigtableSampleKeysDataset, self).__init__(table, variant_tensor) class _BigtableLookupDataset(dataset_ops.DatasetSource): @@ -651,20 +646,18 @@ class _BigtableLookupDataset(dataset_ops.DatasetSource): self._normalized = normalized self._column_families = [i[0] for i in normalized] self._columns = [i[1] for i in normalized] + variant_tensor = gen_bigtable_ops.bigtable_lookup_dataset( + keys_dataset=self._dataset._variant_tensor, # pylint: disable=protected-access + table=self._table._resource, # pylint: disable=protected-access + column_families=self._column_families, + columns=self._columns) + super(_BigtableLookupDataset, self).__init__(variant_tensor) @property def _element_structure(self): return structure.NestedStructure(tuple( [structure.TensorStructure(dtypes.string, [])] * self._num_outputs)) - def _as_variant_tensor(self): - # pylint: disable=protected-access - return gen_bigtable_ops.bigtable_lookup_dataset( - keys_dataset=self._dataset._as_variant_tensor(), - table=self._table._resource, - column_families=self._column_families, - columns=self._columns) - class _BigtableScanDataset(dataset_ops.DatasetSource): """_BigtableScanDataset represents a dataset that retrieves keys and values. @@ -679,14 +672,7 @@ class _BigtableScanDataset(dataset_ops.DatasetSource): self._columns = [i[1] for i in normalized] self._probability = probability self._num_outputs = len(normalized) + 1 # 1 for row key - - @property - def _element_structure(self): - return structure.NestedStructure(tuple( - [structure.TensorStructure(dtypes.string, [])] * self._num_outputs)) - - def _as_variant_tensor(self): - return gen_bigtable_ops.bigtable_scan_dataset( + variant_tensor = gen_bigtable_ops.bigtable_scan_dataset( table=self._table._resource, # pylint: disable=protected-access prefix=self._prefix, start_key=self._start, @@ -694,6 +680,13 @@ class _BigtableScanDataset(dataset_ops.DatasetSource): column_families=self._column_families, columns=self._columns, probability=self._probability) + super(_BigtableScanDataset, self).__init__(variant_tensor) + + @property + def _element_structure(self): + return structure.NestedStructure( + tuple( + [structure.TensorStructure(dtypes.string, [])] * self._num_outputs)) class _BigtableSampleKeyPairsDataset(dataset_ops.DatasetSource): @@ -705,17 +698,15 @@ class _BigtableSampleKeyPairsDataset(dataset_ops.DatasetSource): self._prefix = prefix self._start = start self._end = end + variant_tensor = gen_bigtable_ops.bigtable_sample_key_pairs_dataset( + table=self._table._resource, # pylint: disable=protected-access + prefix=self._prefix, + start_key=self._start, + end_key=self._end) + super(_BigtableSampleKeyPairsDataset, self).__init__(variant_tensor) @property def _element_structure(self): return structure.NestedStructure( (structure.TensorStructure(dtypes.string, []), structure.TensorStructure(dtypes.string, []))) - - def _as_variant_tensor(self): - # pylint: disable=protected-access - return gen_bigtable_ops.bigtable_sample_key_pairs_dataset( - table=self._table._resource, - prefix=self._prefix, - start_key=self._start, - end_key=self._end) diff --git a/tensorflow/contrib/data/python/ops/readers.py 
b/tensorflow/contrib/data/python/ops/readers.py index c0152156a1b..c6bf5215c94 100644 --- a/tensorflow/contrib/data/python/ops/readers.py +++ b/tensorflow/contrib/data/python/ops/readers.py @@ -389,13 +389,11 @@ class LMDBDataset(dataset_ops.DatasetSource): Args: filenames: A `tf.string` tensor containing one or more filenames. """ - super(LMDBDataset, self).__init__() self._filenames = ops.convert_to_tensor( filenames, dtype=dtypes.string, name="filenames") - - def _as_variant_tensor(self): - return gen_experimental_dataset_ops.experimental_lmdb_dataset( + variant_tensor = gen_experimental_dataset_ops.experimental_lmdb_dataset( self._filenames, **dataset_ops.flat_structure(self)) + super(LMDBDataset, self).__init__(variant_tensor) @property def _element_structure(self): diff --git a/tensorflow/contrib/data/python/ops/sliding.py b/tensorflow/contrib/data/python/ops/sliding.py index 5c6ee6bfdc7..6708e01d081 100644 --- a/tensorflow/contrib/data/python/ops/sliding.py +++ b/tensorflow/contrib/data/python/ops/sliding.py @@ -30,7 +30,6 @@ class _SlideDataset(dataset_ops.UnaryDataset): def __init__(self, input_dataset, window_size, window_shift, window_stride): """See `sliding_window_batch` for details.""" - super(_SlideDataset, self).__init__(input_dataset) self._input_dataset = input_dataset self._window_size = ops.convert_to_tensor( window_size, dtype=dtypes.int64, name="window_stride") @@ -43,14 +42,13 @@ class _SlideDataset(dataset_ops.UnaryDataset): input_dataset.output_types, input_dataset.output_shapes, input_dataset.output_classes) self._structure = input_structure._batch(None) # pylint: disable=protected-access - - def _as_variant_tensor(self): - return ged_ops.experimental_sliding_window_dataset( - self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access + variant_tensor = ged_ops.experimental_sliding_window_dataset( + self._input_dataset._variant_tensor, # pylint: disable=protected-access window_size=self._window_size, window_shift=self._window_shift, window_stride=self._window_stride, **dataset_ops.flat_structure(self)) + super(_SlideDataset, self).__init__(input_dataset, variant_tensor) @property def _element_structure(self): diff --git a/tensorflow/contrib/distribute/python/values_test.py b/tensorflow/contrib/distribute/python/values_test.py index f6cb3d6313e..0e8e86f6b96 100644 --- a/tensorflow/contrib/distribute/python/values_test.py +++ b/tensorflow/contrib/distribute/python/values_test.py @@ -472,11 +472,11 @@ class MultiWorkerDatasetTest(multi_worker_test_base.MultiWorkerTestBase): for r in range(len(devices))]) def _test_dataset(self, dataset_fn, worker_devices, devices, - expected_values, auto_shard=True): + expected_values): device_map = values.ReplicaDeviceMap(devices) input_workers = values.InputWorkers(device_map, worker_devices) multi_worker_dataset = values.MultiWorkerDataset( - dataset_fn, input_workers, auto_shard=auto_shard) + dataset_fn, input_workers) multi_worker_iterator = multi_worker_dataset.make_initializable_iterator() with self.cached_session() as sess: sess.run(multi_worker_iterator.initializer) @@ -518,16 +518,9 @@ class MultiWorkerDatasetTest(multi_worker_test_base.MultiWorkerTestBase): worker_devices, devices = self._cpu_devices() with context.graph_mode(): dataset_fn = lambda: dataset_ops.Dataset.range(8) - self._test_dataset(dataset_fn, worker_devices, devices, - [[0, 1], [2, 3], [4, 5], [6, 7]]) - - def testDataDistributionNoAutoShard(self): - worker_devices, devices = self._cpu_devices() - with context.graph_mode(): - dataset_fn 
= lambda: dataset_ops.Dataset.range(4) - self._test_dataset(dataset_fn, worker_devices, devices, - [[0, 0], [1, 1], [2, 2], [3, 3]], - auto_shard=False) + self._test_dataset( + dataset_fn, worker_devices, devices, + [[0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6], [7, 7]]) def testDataDistributionTwoDevicePerWorker(self): if context.num_gpus() < 1: @@ -535,8 +528,9 @@ class MultiWorkerDatasetTest(multi_worker_test_base.MultiWorkerTestBase): worker_devices, devices = self._cpu_and_one_gpu_devices() with context.graph_mode(): dataset_fn = lambda: dataset_ops.Dataset.range(8) - self._test_dataset(dataset_fn, worker_devices, devices, - [[0, 2, 1, 3], [4, 6, 5, 7]]) + self._test_dataset( + dataset_fn, worker_devices, devices, + [[0, 1, 0, 1], [2, 3, 2, 3], [4, 5, 4, 5], [6, 7, 6, 7]]) def testTupleDataset(self): worker_devices, devices = self._cpu_devices() @@ -548,9 +542,7 @@ class MultiWorkerDatasetTest(multi_worker_test_base.MultiWorkerTestBase): dataset2 = dataset_ops.Dataset.range(8).map(lambda x: x**2) return dataset_ops.Dataset.zip((dataset1, dataset2)) - expected_values = [ - [(i, i**2), (i + 1, (i + 1)**2)] for i in range(0, 8, 2) - ] + expected_values = [[(i, i**2), (i, i**2)] for i in range(8)] self._test_dataset(dataset_fn, worker_devices, devices, expected_values) @@ -561,17 +553,19 @@ class MultiWorkerDatasetTest(multi_worker_test_base.MultiWorkerTestBase): device_map = values.ReplicaDeviceMap(devices) input_workers = values.InputWorkers(device_map, worker_devices) multi_worker_dataset = values.MultiWorkerDataset( - dataset_fn, input_workers, auto_shard=True) + dataset_fn, input_workers) multi_worker_iterator = multi_worker_dataset.make_initializable_iterator() sess.run(multi_worker_iterator.initializer) - self._test_iterator(sess, multi_worker_iterator, devices, - [[0, 1], [2, 3], [4, 5], [6, 7]]) + self._test_iterator( + sess, multi_worker_iterator, devices, + [[0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6], [7, 7]]) # After re-initializing the iterator, should be able to iterate again. sess.run(multi_worker_iterator.initializer) - self._test_iterator(sess, multi_worker_iterator, devices, - [[0, 1], [2, 3], [4, 5], [6, 7]]) + self._test_iterator( + sess, multi_worker_iterator, devices, + [[0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6], [7, 7]]) def testValueErrorForIterator(self): # Incompatiable arguments. diff --git a/tensorflow/contrib/hadoop/python/ops/hadoop_dataset_ops.py b/tensorflow/contrib/hadoop/python/ops/hadoop_dataset_ops.py index 77813519c13..71eac729a8a 100644 --- a/tensorflow/contrib/hadoop/python/ops/hadoop_dataset_ops.py +++ b/tensorflow/contrib/hadoop/python/ops/hadoop_dataset_ops.py @@ -55,13 +55,11 @@ class SequenceFileDataset(dataset_ops.DatasetSource): Args: filenames: A `tf.string` tensor containing one or more filenames. 
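The values_test.py updates reflect that values.MultiWorkerDataset no longer accepts an auto_shard argument, so each worker now iterates the full dataset and every device in a step sees the same values, hence the duplicated expectations. If sharding is still wanted it has to be expressed in dataset_fn itself; a sketch using the public Dataset.shard transformation, where the worker count and index are illustrative placeholders:

    import tensorflow as tf

    num_workers, worker_index = 2, 0  # illustrative; derive from the cluster in practice

    def dataset_fn():
      # Shard explicitly instead of relying on the removed auto_shard behaviour.
      return tf.data.Dataset.range(8).shard(num_workers, worker_index)
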
""" - super(SequenceFileDataset, self).__init__() self._filenames = ops.convert_to_tensor( filenames, dtype=dtypes.string, name="filenames") - - def _as_variant_tensor(self): - return gen_dataset_ops.sequence_file_dataset( + variant_tensor = gen_dataset_ops.sequence_file_dataset( self._filenames, self._element_structure._flat_types) # pylint: disable=protected-access + super(SequenceFileDataset, self).__init__(variant_tensor) @property def _element_structure(self): diff --git a/tensorflow/contrib/tpu/python/tpu/datasets_test.py b/tensorflow/contrib/tpu/python/tpu/datasets_test.py index 52d87b80040..8a94f527bb6 100644 --- a/tensorflow/contrib/tpu/python/tpu/datasets_test.py +++ b/tensorflow/contrib/tpu/python/tpu/datasets_test.py @@ -27,6 +27,7 @@ from tensorflow.python.client import session from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import readers from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.lib.io import python_io from tensorflow.python.platform import test @@ -55,6 +56,7 @@ class DatasetsTest(test.TestCase): session_config = config_pb2.ConfigProto(cluster_def=self._cluster_def) self._sess = session.Session(self._worker.target, config=session_config) + self._worker_device = '/job:' + worker_job.name def testTextLineDataset(self): all_contents = [] @@ -70,7 +72,8 @@ class DatasetsTest(test.TestCase): dataset = datasets.StreamingFilesDataset( os.path.join(self.get_temp_dir(), 'text_line.*.txt'), filetype='text') - iterator = dataset_ops.make_initializable_iterator(dataset) + with ops.device(self._worker_device): + iterator = dataset_ops.make_initializable_iterator(dataset) self._sess.run(iterator.initializer) get_next = iterator.get_next() @@ -94,7 +97,8 @@ class DatasetsTest(test.TestCase): dataset = datasets.StreamingFilesDataset( os.path.join(self.get_temp_dir(), 'tf_record*'), filetype='tfrecord') - iterator = dataset_ops.make_initializable_iterator(dataset) + with ops.device(self._worker_device): + iterator = dataset_ops.make_initializable_iterator(dataset) self._sess.run(iterator.initializer) get_next = iterator.get_next() @@ -121,7 +125,8 @@ class DatasetsTest(test.TestCase): dataset = datasets.StreamingFilesDataset(filenames, filetype='tfrecord') - iterator = dataset_ops.make_initializable_iterator(dataset) + with ops.device(self._worker_device): + iterator = dataset_ops.make_initializable_iterator(dataset) self._sess.run(iterator.initializer) get_next = iterator.get_next() @@ -154,7 +159,8 @@ class DatasetsTest(test.TestCase): os.path.join(self.get_temp_dir(), 'fixed_length*'), filetype=FixedLengthFile) - iterator = dataset_ops.make_initializable_iterator(dataset) + with ops.device(self._worker_device): + iterator = dataset_ops.make_initializable_iterator(dataset) self._sess.run(iterator.initializer) get_next = iterator.get_next() @@ -177,7 +183,8 @@ class DatasetsTest(test.TestCase): dataset = datasets.StreamingFilesDataset( dataset_ops.Dataset.range(10), filetype=gen_dataset) - iterator = dataset_ops.make_initializable_iterator(dataset) + with ops.device(self._worker_device): + iterator = dataset_ops.make_initializable_iterator(dataset) self._sess.run(iterator.initializer) get_next = iterator.get_next() diff --git a/tensorflow/core/kernels/data/experimental/group_by_reducer_dataset_op.cc b/tensorflow/core/kernels/data/experimental/group_by_reducer_dataset_op.cc index 1c298cfdd6a..5f0c01be4bc 100644 --- 
a/tensorflow/core/kernels/data/experimental/group_by_reducer_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/group_by_reducer_dataset_op.cc
@@ -119,25 +119,25 @@ class GroupByReducerDatasetOp : public UnaryDatasetOpKernel {
       std::vector<Node*> key_func_other_arguments_node;
       DataTypeVector key_func_other_arguments_types;
       TF_RETURN_IF_ERROR(OtherArgumentsNodeAndType(
-          b, captured_key_func_, &key_func_other_arguments_node,
+          ctx, b, captured_key_func_, &key_func_other_arguments_node,
           &key_func_other_arguments_types));

       std::vector<Node*> init_func_other_arguments_node;
       DataTypeVector init_func_other_arguments_types;
       TF_RETURN_IF_ERROR(OtherArgumentsNodeAndType(
-          b, captured_init_func_, &init_func_other_arguments_node,
+          ctx, b, captured_init_func_, &init_func_other_arguments_node,
           &init_func_other_arguments_types));

       std::vector<Node*> reduce_func_other_arguments_node;
       DataTypeVector reduce_func_other_arguments_types;
       TF_RETURN_IF_ERROR(OtherArgumentsNodeAndType(
-          b, captured_reduce_func_, &reduce_func_other_arguments_node,
+          ctx, b, captured_reduce_func_, &reduce_func_other_arguments_node,
           &reduce_func_other_arguments_types));

       std::vector<Node*> finalize_func_other_arguments_node;
       DataTypeVector finalize_func_other_arguments_types;
       TF_RETURN_IF_ERROR(OtherArgumentsNodeAndType(
-          b, captured_finalize_func_, &finalize_func_other_arguments_node,
+          ctx, b, captured_finalize_func_, &finalize_func_other_arguments_node,
           &finalize_func_other_arguments_types));

       AttrValue key_func;
@@ -406,7 +406,7 @@ class GroupByReducerDatasetOp : public UnaryDatasetOpKernel {
     }

     Status OtherArgumentsNodeAndType(
-        DatasetGraphDefBuilder* b,
+        SerializationContext* ctx, DatasetGraphDefBuilder* b,
         const std::unique_ptr<CapturedFunction>& captured_func,
         std::vector<Node*>* other_arguments_node,
         DataTypeVector* other_arguments_types) const {
@@ -414,7 +414,13 @@ class GroupByReducerDatasetOp : public UnaryDatasetOpKernel {
       other_arguments_types->reserve(captured_func->captured_inputs().size());
       for (const Tensor& t : captured_func->captured_inputs()) {
         Node* node;
-        TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
+        DatasetBase* input;
+        Status s = GetDatasetFromVariantTensor(t, &input);
+        if (s.ok()) {
+          TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input, &node));
+        } else {
+          TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
+        }
         other_arguments_node->emplace_back(node);
         other_arguments_types->emplace_back(t.dtype());
       }
diff --git a/tensorflow/core/kernels/data/experimental/group_by_window_dataset_op.cc b/tensorflow/core/kernels/data/experimental/group_by_window_dataset_op.cc
index 98603d5a732..11491e00db8 100644
--- a/tensorflow/core/kernels/data/experimental/group_by_window_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/group_by_window_dataset_op.cc
@@ -117,20 +117,21 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel {
       std::vector<Node*> key_func_other_arguments_node;
       DataTypeVector key_func_other_arguments_types;
       TF_RETURN_IF_ERROR(OtherArgumentsNodeAndType(
-          b, captured_key_func_, &key_func_other_arguments_node,
+          ctx, b, captured_key_func_, &key_func_other_arguments_node,
           &key_func_other_arguments_types));

       std::vector<Node*> reduce_func_other_arguments_node;
       DataTypeVector reduce_func_other_arguments_types;
       TF_RETURN_IF_ERROR(OtherArgumentsNodeAndType(
-          b, captured_reduce_func_, &reduce_func_other_arguments_node,
+          ctx, b, captured_reduce_func_, &reduce_func_other_arguments_node,
           &reduce_func_other_arguments_types));

       std::vector<Node*> window_size_func_other_arguments_node;
       DataTypeVector window_size_func_other_arguments_types;
-      TF_RETURN_IF_ERROR(OtherArgumentsNodeAndType(
-          b, captured_window_size_func_, &window_size_func_other_arguments_node,
-          &window_size_func_other_arguments_types));
+      TF_RETURN_IF_ERROR(
+          OtherArgumentsNodeAndType(ctx, b, captured_window_size_func_,
+                                    &window_size_func_other_arguments_node,
+                                    &window_size_func_other_arguments_types));

       AttrValue key_func;
       b->BuildAttrValue(key_func_, &key_func);
@@ -490,7 +491,7 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel {
     };

     Status OtherArgumentsNodeAndType(
-        DatasetGraphDefBuilder* b,
+        SerializationContext* ctx, DatasetGraphDefBuilder* b,
         const std::unique_ptr<CapturedFunction>& captured_func,
         std::vector<Node*>* other_arguments_node,
         DataTypeVector* other_arguments_types) const {
@@ -498,7 +499,13 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel {
       other_arguments_types->reserve(captured_func->captured_inputs().size());
       for (const Tensor& t : captured_func->captured_inputs()) {
         Node* node;
-        TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
+        DatasetBase* input;
+        Status s = GetDatasetFromVariantTensor(t, &input);
+        if (s.ok()) {
+          TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input, &node));
+        } else {
+          TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
+        }
         other_arguments_node->emplace_back(node);
         other_arguments_types->emplace_back(t.dtype());
       }
diff --git a/tensorflow/core/kernels/data/experimental/map_and_batch_dataset_op.cc b/tensorflow/core/kernels/data/experimental/map_and_batch_dataset_op.cc
index 3ff31355936..ef75c844565 100644
--- a/tensorflow/core/kernels/data/experimental/map_and_batch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/map_and_batch_dataset_op.cc
@@ -210,7 +210,13 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
       other_arguments.reserve(captured_func_->captured_inputs().size());
       for (const Tensor& t : captured_func_->captured_inputs()) {
         Node* node;
-        TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
+        DatasetBase* input;
+        Status s = GetDatasetFromVariantTensor(t, &input);
+        if (s.ok()) {
+          TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input, &node));
+        } else {
+          TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
+        }
         other_arguments.emplace_back(node);
         other_arguments_types.emplace_back(t.dtype());
       }
diff --git a/tensorflow/core/kernels/data/experimental/numa_map_and_batch_dataset_op.cc b/tensorflow/core/kernels/data/experimental/numa_map_and_batch_dataset_op.cc
index 921f8ad5840..2b1aec358cc 100644
--- a/tensorflow/core/kernels/data/experimental/numa_map_and_batch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/numa_map_and_batch_dataset_op.cc
@@ -169,7 +169,13 @@ class NumaMapAndBatchDatasetOp : public UnaryDatasetOpKernel {
       other_arguments.reserve(captured_func_->captured_inputs().size());
       for (const Tensor& t : captured_func_->captured_inputs()) {
         Node* node;
-        TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
+        DatasetBase* input;
+        Status s = GetDatasetFromVariantTensor(t, &input);
+        if (s.ok()) {
+          TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input, &node));
+        } else {
+          TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
+        }
         other_arguments.emplace_back(node);
         other_arguments_types.emplace_back(t.dtype());
       }
diff --git a/tensorflow/core/kernels/data/experimental/parallel_interleave_dataset_op.cc b/tensorflow/core/kernels/data/experimental/parallel_interleave_dataset_op.cc
index 0230f90aba1..1c19119d88b 100644
--- a/tensorflow/core/kernels/data/experimental/parallel_interleave_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/parallel_interleave_dataset_op.cc
@@ -154,7 +154,13 @@ class
ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { other_arguments.reserve(captured_func_->captured_inputs().size()); for (const Tensor& t : captured_func_->captured_inputs()) { Node* node; - TF_RETURN_IF_ERROR(b->AddTensor(t, &node)); + DatasetBase* input; + Status s = GetDatasetFromVariantTensor(t, &input); + if (s.ok()) { + TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input, &node)); + } else { + TF_RETURN_IF_ERROR(b->AddTensor(t, &node)); + } other_arguments.emplace_back(node); other_arguments_types.emplace_back(t.dtype()); } diff --git a/tensorflow/core/kernels/data/experimental/scan_dataset_op.cc b/tensorflow/core/kernels/data/experimental/scan_dataset_op.cc index 0d9a629a27f..76ab33fe988 100644 --- a/tensorflow/core/kernels/data/experimental/scan_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/scan_dataset_op.cc @@ -119,7 +119,13 @@ class ScanDatasetOp : public UnaryDatasetOpKernel { other_arguments_types.reserve(captured_func_->captured_inputs().size()); for (const Tensor& t : captured_func_->captured_inputs()) { Node* node; - TF_RETURN_IF_ERROR(b->AddTensor(t, &node)); + DatasetBase* input; + Status s = GetDatasetFromVariantTensor(t, &input); + if (s.ok()) { + TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input, &node)); + } else { + TF_RETURN_IF_ERROR(b->AddTensor(t, &node)); + } other_arguments.emplace_back(node); other_arguments_types.emplace_back(t.dtype()); } diff --git a/tensorflow/core/kernels/data/filter_dataset_op.cc b/tensorflow/core/kernels/data/filter_dataset_op.cc index b8b657d3433..30b2fc5db80 100644 --- a/tensorflow/core/kernels/data/filter_dataset_op.cc +++ b/tensorflow/core/kernels/data/filter_dataset_op.cc @@ -137,7 +137,13 @@ class FilterDatasetOp : public UnaryDatasetOpKernel { other_arguments.reserve(captured_func_->captured_inputs().size()); for (const Tensor& t : captured_func_->captured_inputs()) { Node* node; - TF_RETURN_IF_ERROR(b->AddTensor(t, &node)); + DatasetBase* input; + Status s = GetDatasetFromVariantTensor(t, &input); + if (s.ok()) { + TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input, &node)); + } else { + TF_RETURN_IF_ERROR(b->AddTensor(t, &node)); + } other_arguments.emplace_back(node); other_arguments_types.emplace_back(t.dtype()); } diff --git a/tensorflow/core/kernels/data/flat_map_dataset_op.cc b/tensorflow/core/kernels/data/flat_map_dataset_op.cc index 3846334622b..efa76ab34bc 100644 --- a/tensorflow/core/kernels/data/flat_map_dataset_op.cc +++ b/tensorflow/core/kernels/data/flat_map_dataset_op.cc @@ -95,7 +95,13 @@ class FlatMapDatasetOp : public UnaryDatasetOpKernel { other_arguments.reserve(captured_func_->captured_inputs().size()); for (const Tensor& t : captured_func_->captured_inputs()) { Node* node; - TF_RETURN_IF_ERROR(b->AddTensor(t, &node)); + DatasetBase* input; + Status s = GetDatasetFromVariantTensor(t, &input); + if (s.ok()) { + TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input, &node)); + } else { + TF_RETURN_IF_ERROR(b->AddTensor(t, &node)); + } other_arguments.emplace_back(node); other_arguments_types.emplace_back(t.dtype()); } diff --git a/tensorflow/core/kernels/data/interleave_dataset_op.cc b/tensorflow/core/kernels/data/interleave_dataset_op.cc index 54e3645612c..1a5e6edb5b7 100644 --- a/tensorflow/core/kernels/data/interleave_dataset_op.cc +++ b/tensorflow/core/kernels/data/interleave_dataset_op.cc @@ -121,7 +121,13 @@ class InterleaveDatasetOp : public UnaryDatasetOpKernel { other_arguments.reserve(captured_func_->captured_inputs().size()); for (const Tensor& t : captured_func_->captured_inputs()) { 
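The kernel hunks in this stretch (group_by_reducer, group_by_window, map_and_batch, numa_map_and_batch, parallel_interleave, and the core ops below) all add the same fallback: when a captured function input is a DT_VARIANT carrying a dataset, GetDatasetFromVariantTensor succeeds and the tensor is serialized with AddInputDataset instead of being embedded as a constant via AddTensor. A Python-level sketch of a pipeline that produces such a captured dataset (illustrative):

    import tensorflow as tf

    lookup = tf.data.Dataset.range(10)

    # `lookup` is closed over by the lambda, so its variant tensor ends up in the
    # flat_map function's captured inputs; serializing `ds` now routes that
    # captured input through AddInputDataset.
    ds = tf.data.Dataset.range(3).flat_map(lambda _: lookup)
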
Node* node; - TF_RETURN_IF_ERROR(b->AddTensor(t, &node)); + DatasetBase* input; + Status s = GetDatasetFromVariantTensor(t, &input); + if (s.ok()) { + TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input, &node)); + } else { + TF_RETURN_IF_ERROR(b->AddTensor(t, &node)); + } other_arguments.emplace_back(node); other_arguments_types.emplace_back(t.dtype()); } diff --git a/tensorflow/core/kernels/data/map_dataset_op.cc b/tensorflow/core/kernels/data/map_dataset_op.cc index fc6e93a81cb..02c0199a0c5 100644 --- a/tensorflow/core/kernels/data/map_dataset_op.cc +++ b/tensorflow/core/kernels/data/map_dataset_op.cc @@ -149,7 +149,13 @@ class MapDatasetOp : public UnaryDatasetOpKernel { other_arguments.reserve(captured_func_->captured_inputs().size()); for (const Tensor& t : captured_func_->captured_inputs()) { Node* node; - TF_RETURN_IF_ERROR(b->AddTensor(t, &node)); + DatasetBase* input; + Status s = GetDatasetFromVariantTensor(t, &input); + if (s.ok()) { + TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input, &node)); + } else { + TF_RETURN_IF_ERROR(b->AddTensor(t, &node)); + } other_arguments.emplace_back(node); other_arguments_types.emplace_back(t.dtype()); } diff --git a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc index f844a005768..fda7ae0cbba 100644 --- a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc +++ b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc @@ -160,7 +160,13 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { other_arguments.reserve(captured_func_->captured_inputs().size()); for (const Tensor& t : captured_func_->captured_inputs()) { Node* node; - TF_RETURN_IF_ERROR(b->AddTensor(t, &node)); + DatasetBase* input; + Status s = GetDatasetFromVariantTensor(t, &input); + if (s.ok()) { + TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input, &node)); + } else { + TF_RETURN_IF_ERROR(b->AddTensor(t, &node)); + } other_arguments.emplace_back(node); other_arguments_types.emplace_back(t.dtype()); } diff --git a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc index 5c09b2d5dc8..c0002c86d87 100644 --- a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc +++ b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc @@ -141,7 +141,13 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel { other_arguments.reserve(captured_func_->captured_inputs().size()); for (const Tensor& t : captured_func_->captured_inputs()) { Node* node; - TF_RETURN_IF_ERROR(b->AddTensor(t, &node)); + DatasetBase* input; + Status s = GetDatasetFromVariantTensor(t, &input); + if (s.ok()) { + TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input, &node)); + } else { + TF_RETURN_IF_ERROR(b->AddTensor(t, &node)); + } other_arguments.emplace_back(node); other_arguments_types.emplace_back(t.dtype()); } diff --git a/tensorflow/python/data/experimental/kernel_tests/csv_dataset_test.py b/tensorflow/python/data/experimental/kernel_tests/csv_dataset_test.py index b2f1b43ecf6..e523f36639d 100644 --- a/tensorflow/python/data/experimental/kernel_tests/csv_dataset_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/csv_dataset_test.py @@ -89,14 +89,12 @@ class CsvDatasetTest(test_base.DatasetTestBase): with self.assertRaises(errors.OutOfRangeError): self.evaluate(nxt()) else: - # Verify that OpError is produced as expected - with self.assertRaisesOpError(expected_err_re): - nxt = self.getNext(dataset) - while True: - try: - 
self.evaluate(nxt()) - except errors.OutOfRangeError: - break + nxt = self.getNext(dataset) + while True: + try: + self.evaluate(nxt()) + except errors.OutOfRangeError: + break def _test_dataset( self, @@ -110,8 +108,14 @@ class CsvDatasetTest(test_base.DatasetTestBase): # Convert str type because py3 tf strings are bytestrings filenames = self._setup_files(inputs, linebreak, compression_type) kwargs['compression_type'] = compression_type - dataset = readers.CsvDataset(filenames, **kwargs) - self._verify_output_or_err(dataset, expected_output, expected_err_re) + if expected_err_re is not None: + # Verify that OpError is produced as expected + with self.assertRaisesOpError(expected_err_re): + dataset = readers.CsvDataset(filenames, **kwargs) + self._verify_output_or_err(dataset, expected_output, expected_err_re) + else: + dataset = readers.CsvDataset(filenames, **kwargs) + self._verify_output_or_err(dataset, expected_output, expected_err_re) def testCsvDataset_requiredFields(self): record_defaults = [[]] * 4 diff --git a/tensorflow/python/data/experimental/kernel_tests/map_and_batch_test.py b/tensorflow/python/data/experimental/kernel_tests/map_and_batch_test.py index ceadebc5411..c90c5ed306f 100644 --- a/tensorflow/python/data/experimental/kernel_tests/map_and_batch_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/map_and_batch_test.py @@ -120,8 +120,8 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase): self.assertDatasetProduces(dataset_fn(8, 0), expected_output=[]) # Empty batch should be an initialization time error. - self.assertDatasetProduces( - dataset_fn(0, 14), expected_error=(errors.InvalidArgumentError, "")) + with self.assertRaises(errors.InvalidArgumentError): + self.assertDatasetProduces(dataset_fn(0, 14), expected_output=[]) @parameterized.named_parameters( ("Even", False, False), diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/optimize_dataset_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/optimize_dataset_test.py index dd432b8c15d..c111567c1c5 100644 --- a/tensorflow/python/data/experimental/kernel_tests/optimization/optimize_dataset_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/optimization/optimize_dataset_test.py @@ -211,16 +211,15 @@ class OptimizeDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): "v", initializer=0, use_resource=False) assign_op = variable.assign_add(1) - unoptimized_dataset = dataset_fn(variable) - - options = dataset_ops.Options() - options.experimental_optimization.noop_elimination = True - options.experimental_optimization.map_and_batch_fusion = True - optimized_dataset = unoptimized_dataset.with_options(options) - # Check that warning is logged. 
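The test rewrites above (CsvDataset, map_and_batch, and the SQL dataset test further down) follow from the constructor change: the dataset op is now created inside __init__, so invalid arguments can surface when the dataset is constructed rather than only when it is iterated, and the error assertions have to wrap construction as well. A sketch of the new failure point, assuming eager execution and mirroring the zero batch size used by the map_and_batch test:

    import tensorflow as tf

    tf.enable_eager_execution()  # TF 1.x style at this revision

    try:
      # batch_size=0 is invalid; with the op now built in __init__, the error is
      # raised as soon as the dataset is constructed.
      ds = tf.data.Dataset.range(14).apply(
          tf.data.experimental.map_and_batch(lambda x: x, batch_size=0))
    except tf.errors.InvalidArgumentError:
      pass
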
warnings.simplefilter("always") with warnings.catch_warnings(record=True) as w: + unoptimized_dataset = dataset_fn(variable) + + options = dataset_ops.Options() + options.experimental_optimization.noop_elimination = True + options.experimental_optimization.map_and_batch_fusion = True + optimized_dataset = unoptimized_dataset.with_options(options) optimized_it = optimized_dataset.make_initializable_iterator() self.assertGreaterEqual(len(w), 1) diff --git a/tensorflow/python/data/experimental/kernel_tests/sql_dataset_test.py b/tensorflow/python/data/experimental/kernel_tests/sql_dataset_test.py index fd96c0b5213..e97c80627cf 100644 --- a/tensorflow/python/data/experimental/kernel_tests/sql_dataset_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/sql_dataset_test.py @@ -110,13 +110,13 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase): # Test that an error is raised when `driver_name` is invalid. def testReadResultSetWithInvalidDriverName(self): - dataset = self._createSqlDataset( - driver_name="sqlfake", - query="SELECT first_name, last_name, motto FROM students " - "ORDER BY first_name DESC", - output_types=(dtypes.string, dtypes.string, dtypes.string)) - self.assertDatasetProduces( - dataset, expected_error=(errors.InvalidArgumentError, "")) + with self.assertRaises(errors.InvalidArgumentError): + dataset = self._createSqlDataset( + driver_name="sqlfake", + query="SELECT first_name, last_name, motto FROM students " + "ORDER BY first_name DESC", + output_types=(dtypes.string, dtypes.string, dtypes.string)) + self.assertDatasetProduces(dataset, expected_output=[]) # Test that an error is raised when a column name in `query` is nonexistent def testReadResultSetWithInvalidColumnName(self): diff --git a/tensorflow/python/data/experimental/kernel_tests/stats_dataset_ops_test.py b/tensorflow/python/data/experimental/kernel_tests/stats_dataset_ops_test.py index 59d0ebdb37e..8b330559f5f 100644 --- a/tensorflow/python/data/experimental/kernel_tests/stats_dataset_ops_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/stats_dataset_ops_test.py @@ -197,10 +197,13 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase): def testInterleaveAutoTuneBufferUtilization(self, dataset_transformation): def dataset_fn(): - dataset = dataset_ops.Dataset.range(10).map( - lambda x: array_ops.tile([x], ops.convert_to_tensor([x]))) + + def interleave_fn(_): + return dataset_ops.Dataset.range( + 10).map(lambda x: array_ops.tile([x], ops.convert_to_tensor([x]))) + dataset = dataset_ops.Dataset.range(1).interleave( - lambda _: dataset, + interleave_fn, cycle_length=1, num_parallel_calls=optimization.AUTOTUNE) options = dataset_ops.Options() diff --git a/tensorflow/python/data/experimental/ops/batching.py b/tensorflow/python/data/experimental/ops/batching.py index 29df98f4ea4..f0cf7f0a995 100644 --- a/tensorflow/python/data/experimental/ops/batching.py +++ b/tensorflow/python/data/experimental/ops/batching.py @@ -352,7 +352,6 @@ class _UnbatchDataset(dataset_ops.UnaryDataset): def __init__(self, input_dataset): """See `unbatch()` for more details.""" - super(_UnbatchDataset, self).__init__(input_dataset) flat_shapes = nest.flatten(input_dataset.output_shapes) if any(s.ndims == 0 for s in flat_shapes): raise ValueError("Cannot unbatch an input with scalar components.") @@ -370,10 +369,10 @@ class _UnbatchDataset(dataset_ops.UnaryDataset): nest.map_structure(lambda s: s[1:], input_dataset.output_shapes), input_dataset.output_classes) - def _as_variant_tensor(self): - 
return ged_ops.experimental_unbatch_dataset( - self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access + variant_tensor = ged_ops.experimental_unbatch_dataset( + self._input_dataset._variant_tensor, # pylint: disable=protected-access **dataset_ops.flat_structure(self)) + super(_UnbatchDataset, self).__init__(input_dataset, variant_tensor) @property def _element_structure(self): @@ -440,7 +439,6 @@ class _DenseToSparseBatchDataset(dataset_ops.UnaryDataset): def __init__(self, input_dataset, batch_size, row_shape): """See `Dataset.dense_to_sparse_batch()` for more details.""" - super(_DenseToSparseBatchDataset, self).__init__(input_dataset) if not isinstance(input_dataset.output_types, dtypes.DType): raise TypeError("DenseToSparseDataset requires an input whose elements " "have a single component, whereas the input has %r." % @@ -452,12 +450,13 @@ class _DenseToSparseBatchDataset(dataset_ops.UnaryDataset): input_dataset.output_types, tensor_shape.vector(None).concatenate(self._row_shape)) - def _as_variant_tensor(self): - return ged_ops.experimental_dense_to_sparse_batch_dataset( - self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access + variant_tensor = ged_ops.experimental_dense_to_sparse_batch_dataset( + self._input_dataset._variant_tensor, # pylint: disable=protected-access self._batch_size, row_shape=convert.partial_shape_to_tensor(self._row_shape), **dataset_ops.flat_structure(self)) + super(_DenseToSparseBatchDataset, self).__init__(input_dataset, + variant_tensor) @property def _element_structure(self): @@ -499,7 +498,6 @@ class _RestructuredDataset(dataset_ops.UnaryDataset): ValueError: If either `output_types` or `output_shapes` is not compatible with the structure of `dataset`. """ - super(_RestructuredDataset, self).__init__(dataset) self._input_dataset = dataset if not allow_unsafe_cast: @@ -539,9 +537,8 @@ class _RestructuredDataset(dataset_ops.UnaryDataset): self._structure = structure.convert_legacy_structure( output_types, output_shapes, output_classes) - - def _as_variant_tensor(self): - return self._input_dataset._as_variant_tensor() # pylint: disable=protected-access + variant_tensor = self._input_dataset._variant_tensor # pylint: disable=protected-access + super(_RestructuredDataset, self).__init__(dataset, variant_tensor) @property def _element_structure(self): @@ -554,8 +551,8 @@ class _MapAndBatchDataset(dataset_ops.UnaryDataset): def __init__(self, input_dataset, map_func, batch_size, num_parallel_calls, drop_remainder): """See `Dataset.map()` for details.""" - super(_MapAndBatchDataset, self).__init__(input_dataset) self._input_dataset = input_dataset + self._map_func = dataset_ops.StructuredFunctionWrapper( map_func, "tf.data.experimental.map_and_batch()", dataset=input_dataset) self._batch_size_t = ops.convert_to_tensor( @@ -573,14 +570,8 @@ class _MapAndBatchDataset(dataset_ops.UnaryDataset): tensor_util.constant_value(self._batch_size_t)) else: self._structure = self._map_func.output_structure._batch(None) # pylint: disable=protected-access - - def _functions(self): - return [self._map_func] - - def _as_variant_tensor(self): - # pylint: disable=protected-access - return ged_ops.experimental_map_and_batch_dataset( - self._input_dataset._as_variant_tensor(), + variant_tensor = ged_ops.experimental_map_and_batch_dataset( + self._input_dataset._variant_tensor, # pylint: disable=protected-access self._map_func.function.captured_inputs, f=self._map_func.function, batch_size=self._batch_size_t, @@ -588,6 +579,10 @@ class 
_MapAndBatchDataset(dataset_ops.UnaryDataset): drop_remainder=self._drop_remainder_t, preserve_cardinality=True, **dataset_ops.flat_structure(self)) + super(_MapAndBatchDataset, self).__init__(input_dataset, variant_tensor) + + def _functions(self): + return [self._map_func] @property def _element_structure(self): diff --git a/tensorflow/python/data/experimental/ops/cardinality.py b/tensorflow/python/data/experimental/ops/cardinality.py index 9cf0a8801e8..0d596f68dd5 100644 --- a/tensorflow/python/data/experimental/ops/cardinality.py +++ b/tensorflow/python/data/experimental/ops/cardinality.py @@ -47,4 +47,4 @@ def cardinality(dataset): the cardinality is infinite or unknown, the operation returns the named constant `INFINITE_CARDINALITY` and `UNKNOWN_CARDINALITY` respectively. """ - return ged_ops.experimental_dataset_cardinality(dataset._as_variant_tensor()) # pylint: disable=protected-access + return ged_ops.experimental_dataset_cardinality(dataset._variant_tensor) # pylint: disable=protected-access diff --git a/tensorflow/python/data/experimental/ops/error_ops.py b/tensorflow/python/data/experimental/ops/error_ops.py index 879b13ce092..eab29c7d88f 100644 --- a/tensorflow/python/data/experimental/ops/error_ops.py +++ b/tensorflow/python/data/experimental/ops/error_ops.py @@ -57,10 +57,9 @@ class _IgnoreErrorsDataset(dataset_ops.UnaryUnchangedStructureDataset): def __init__(self, input_dataset): """See `Dataset.ignore_errors()` for details.""" - super(_IgnoreErrorsDataset, self).__init__(input_dataset) self._input_dataset = input_dataset - - def _as_variant_tensor(self): - return gen_experimental_dataset_ops.experimental_ignore_errors_dataset( - self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access - **dataset_ops.flat_structure(self)) + variant_tensor = ( + gen_experimental_dataset_ops.experimental_ignore_errors_dataset( + self._input_dataset._variant_tensor, # pylint: disable=protected-access + **dataset_ops.flat_structure(self))) + super(_IgnoreErrorsDataset, self).__init__(input_dataset, variant_tensor) diff --git a/tensorflow/python/data/experimental/ops/get_single_element.py b/tensorflow/python/data/experimental/ops/get_single_element.py index d649a070127..46c215d6850 100644 --- a/tensorflow/python/data/experimental/ops/get_single_element.py +++ b/tensorflow/python/data/experimental/ops/get_single_element.py @@ -64,5 +64,4 @@ def get_single_element(dataset): # pylint: disable=protected-access return dataset._element_structure._from_compatible_tensor_list( gen_dataset_ops.dataset_to_single_element( - dataset._as_variant_tensor(), - **dataset_ops.flat_structure(dataset))) + dataset._variant_tensor, **dataset_ops.flat_structure(dataset))) diff --git a/tensorflow/python/data/experimental/ops/grouping.py b/tensorflow/python/data/experimental/ops/grouping.py index ef6b232429b..2435f0cfdb7 100644 --- a/tensorflow/python/data/experimental/ops/grouping.py +++ b/tensorflow/python/data/experimental/ops/grouping.py @@ -242,14 +242,23 @@ class _GroupByReducerDataset(dataset_ops.UnaryDataset): def __init__(self, input_dataset, key_func, reducer): """See `group_by_reducer()` for details.""" - super(_GroupByReducerDataset, self).__init__(input_dataset) - self._input_dataset = input_dataset - self._make_key_func(key_func, input_dataset) self._make_init_func(reducer.init_func) self._make_reduce_func(reducer.reduce_func, input_dataset) self._make_finalize_func(reducer.finalize_func) + variant_tensor = ged_ops.experimental_group_by_reducer_dataset( + 
self._input_dataset._variant_tensor, # pylint: disable=protected-access + self._key_func.function.captured_inputs, + self._init_func.function.captured_inputs, + self._reduce_func.function.captured_inputs, + self._finalize_func.function.captured_inputs, + key_func=self._key_func.function, + init_func=self._init_func.function, + reduce_func=self._reduce_func.function, + finalize_func=self._finalize_func.function, + **dataset_ops.flat_structure(self)) + super(_GroupByReducerDataset, self).__init__(input_dataset, variant_tensor) def _make_key_func(self, key_func, input_dataset): """Make wrapping defun for key_func.""" @@ -347,19 +356,6 @@ class _GroupByReducerDataset(dataset_ops.UnaryDataset): self._key_func, self._init_func, self._reduce_func, self._finalize_func ] - def _as_variant_tensor(self): - return ged_ops.experimental_group_by_reducer_dataset( - self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access - self._key_func.function.captured_inputs, - self._init_func.function.captured_inputs, - self._reduce_func.function.captured_inputs, - self._finalize_func.function.captured_inputs, - key_func=self._key_func.function, - init_func=self._init_func.function, - reduce_func=self._reduce_func.function, - finalize_func=self._finalize_func.function, - **dataset_ops.flat_structure(self)) - def _transformation_name(self): return "tf.data.experimental.group_by_reducer()" @@ -369,13 +365,20 @@ class _GroupByWindowDataset(dataset_ops.UnaryDataset): def __init__(self, input_dataset, key_func, reduce_func, window_size_func): """See `group_by_window()` for details.""" - super(_GroupByWindowDataset, self).__init__(input_dataset) - self._input_dataset = input_dataset - self._make_key_func(key_func, input_dataset) self._make_reduce_func(reduce_func, input_dataset) self._make_window_size_func(window_size_func) + variant_tensor = ged_ops.experimental_group_by_window_dataset( + self._input_dataset._variant_tensor, # pylint: disable=protected-access + self._key_func.function.captured_inputs, + self._reduce_func.function.captured_inputs, + self._window_size_func.function.captured_inputs, + key_func=self._key_func.function, + reduce_func=self._reduce_func.function, + window_size_func=self._window_size_func.function, + **dataset_ops.flat_structure(self)) + super(_GroupByWindowDataset, self).__init__(input_dataset, variant_tensor) def _make_window_size_func(self, window_size_func): """Make wrapping defun for window_size_func.""" @@ -426,17 +429,6 @@ class _GroupByWindowDataset(dataset_ops.UnaryDataset): def _functions(self): return [self._key_func, self._reduce_func, self._window_size_func] - def _as_variant_tensor(self): - return ged_ops.experimental_group_by_window_dataset( - self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access - self._key_func.function.captured_inputs, - self._reduce_func.function.captured_inputs, - self._window_size_func.function.captured_inputs, - key_func=self._key_func.function, - reduce_func=self._reduce_func.function, - window_size_func=self._window_size_func.function, - **dataset_ops.flat_structure(self)) - def _transformation_name(self): return "tf.data.experimental.group_by_window()" diff --git a/tensorflow/python/data/experimental/ops/interleave_ops.py b/tensorflow/python/data/experimental/ops/interleave_ops.py index 5a719f8ed8f..f4b7123df11 100644 --- a/tensorflow/python/data/experimental/ops/interleave_ops.py +++ b/tensorflow/python/data/experimental/ops/interleave_ops.py @@ -113,15 +113,15 @@ class 
_DirectedInterleaveDataset(dataset_ops.Dataset): self._structure = structure.convert_legacy_structure( data_inputs[0].output_types, output_shapes, data_inputs[0].output_classes) + super(_DirectedInterleaveDataset, self).__init__() def _as_variant_tensor(self): # pylint: disable=protected-access return ( gen_experimental_dataset_ops.experimental_directed_interleave_dataset( - self._selector_input._as_variant_tensor(), [ - data_input._as_variant_tensor() - for data_input in self._data_inputs - ], **dataset_ops.flat_structure(self))) + self._selector_input._variant_tensor, + [data_input._variant_tensor for data_input in self._data_inputs], + **dataset_ops.flat_structure(self))) # pylint: enable=protected-access def _inputs(self): diff --git a/tensorflow/python/data/experimental/ops/matching_files.py b/tensorflow/python/data/experimental/ops/matching_files.py index 63b99cb1e45..29beda9fc3a 100644 --- a/tensorflow/python/data/experimental/ops/matching_files.py +++ b/tensorflow/python/data/experimental/ops/matching_files.py @@ -29,12 +29,10 @@ class MatchingFilesDataset(dataset_ops.DatasetSource): """A `Dataset` that list the files according to the input patterns.""" def __init__(self, patterns): - super(MatchingFilesDataset, self).__init__() self._patterns = ops.convert_to_tensor( patterns, dtype=dtypes.string, name="patterns") - - def _as_variant_tensor(self): - return ged_ops.experimental_matching_files_dataset(self._patterns) + variant_tensor = ged_ops.experimental_matching_files_dataset(self._patterns) + super(MatchingFilesDataset, self).__init__(variant_tensor) @property def _element_structure(self): diff --git a/tensorflow/python/data/experimental/ops/optimization.py b/tensorflow/python/data/experimental/ops/optimization.py index c6c7de9265c..22a36646ea4 100644 --- a/tensorflow/python/data/experimental/ops/optimization.py +++ b/tensorflow/python/data/experimental/ops/optimization.py @@ -105,18 +105,17 @@ class _AssertNextDataset(dataset_ops.UnaryUnchangedStructureDataset): def __init__(self, input_dataset, transformations): """See `assert_next()` for details.""" - super(_AssertNextDataset, self).__init__(input_dataset) self._input_dataset = input_dataset if transformations is None: raise ValueError("At least one transformation should be specified") self._transformations = ops.convert_to_tensor( transformations, dtype=dtypes.string, name="transformations") - - def _as_variant_tensor(self): - return gen_experimental_dataset_ops.experimental_assert_next_dataset( - self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access - self._transformations, - **dataset_ops.flat_structure(self)) + variant_tensor = ( + gen_experimental_dataset_ops.experimental_assert_next_dataset( + self._input_dataset._variant_tensor, # pylint: disable=protected-access + self._transformations, + **dataset_ops.flat_structure(self))) + super(_AssertNextDataset, self).__init__(input_dataset, variant_tensor) class _NonSerializableDataset(dataset_ops.UnaryUnchangedStructureDataset): @@ -124,10 +123,9 @@ class _NonSerializableDataset(dataset_ops.UnaryUnchangedStructureDataset): def __init__(self, input_dataset): """See `non_serializable()` for details.""" - super(_NonSerializableDataset, self).__init__(input_dataset) self._input_dataset = input_dataset - - def _as_variant_tensor(self): - return gen_experimental_dataset_ops.experimental_non_serializable_dataset( - self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access - **dataset_ops.flat_structure(self)) + variant_tensor = ( + 
gen_experimental_dataset_ops.experimental_non_serializable_dataset( + self._input_dataset._variant_tensor, # pylint: disable=protected-access + **dataset_ops.flat_structure(self))) + super(_NonSerializableDataset, self).__init__(input_dataset, variant_tensor) diff --git a/tensorflow/python/data/experimental/ops/parsing_ops.py b/tensorflow/python/data/experimental/ops/parsing_ops.py index deb20d61888..a5ca96e89b5 100644 --- a/tensorflow/python/data/experimental/ops/parsing_ops.py +++ b/tensorflow/python/data/experimental/ops/parsing_ops.py @@ -31,7 +31,6 @@ class _ParseExampleDataset(dataset_ops.UnaryDataset): """A `Dataset` that parses `example` dataset into a `dict` dataset.""" def __init__(self, input_dataset, features, num_parallel_calls): - super(_ParseExampleDataset, self).__init__(input_dataset) self._input_dataset = input_dataset if not input_dataset._element_structure.is_compatible_with( # pylint: disable=protected-access structure.TensorStructure(dtypes.string, [None])): @@ -81,16 +80,17 @@ class _ParseExampleDataset(dataset_ops.UnaryDataset): self._structure = structure.convert_legacy_structure( output_types, output_shapes, output_classes) - def _as_variant_tensor(self): - return gen_experimental_dataset_ops.experimental_parse_example_dataset( - self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access - self._num_parallel_calls, - self._dense_defaults, - self._sparse_keys, - self._dense_keys, - self._sparse_types, - self._dense_shapes, - **dataset_ops.flat_structure(self)) + variant_tensor = ( + gen_experimental_dataset_ops.experimental_parse_example_dataset( + self._input_dataset._variant_tensor, # pylint: disable=protected-access + self._num_parallel_calls, + self._dense_defaults, + self._sparse_keys, + self._dense_keys, + self._sparse_types, + self._dense_shapes, + **dataset_ops.flat_structure(self))) + super(_ParseExampleDataset, self).__init__(input_dataset, variant_tensor) @property def _element_structure(self): diff --git a/tensorflow/python/data/experimental/ops/prefetching_ops.py b/tensorflow/python/data/experimental/ops/prefetching_ops.py index e3a86223933..ef9db2f2d06 100644 --- a/tensorflow/python/data/experimental/ops/prefetching_ops.py +++ b/tensorflow/python/data/experimental/ops/prefetching_ops.py @@ -93,7 +93,6 @@ class _CopyToDeviceDataset(dataset_ops.UnaryUnchangedStructureDataset): target_device: The name of the device to which elements would be copied. source_device: Device where input_dataset would be placed. """ - super(_CopyToDeviceDataset, self).__init__(input_dataset) self._input_dataset = input_dataset self._target_device = target_device spec = framework_device.DeviceSpec().from_string(self._target_device) @@ -101,6 +100,9 @@ class _CopyToDeviceDataset(dataset_ops.UnaryUnchangedStructureDataset): self._source_device_string = source_device self._source_device = ops.convert_to_tensor(source_device) + wrap_ds_variant = gen_dataset_ops.wrap_dataset_variant( + self._input_dataset._variant_tensor) # pylint: disable=protected-access + @function.defun() def _init_func(): """Creates an iterator for the input dataset. @@ -108,8 +110,7 @@ class _CopyToDeviceDataset(dataset_ops.UnaryUnchangedStructureDataset): Returns: A `string` tensor that encapsulates the iterator created. 
""" - # pylint: disable=protected-access - ds_variant = self._input_dataset._as_variant_tensor() + ds_variant = gen_dataset_ops.unwrap_dataset_variant(wrap_ds_variant) resource = gen_dataset_ops.anonymous_iterator( **dataset_ops.flat_structure(self._input_dataset)) with ops.control_dependencies( @@ -195,6 +196,17 @@ class _CopyToDeviceDataset(dataset_ops.UnaryUnchangedStructureDataset): self._finalize_func.add_to_graph(g) # pylint: enable=protected-scope + with ops.device(self._target_device): + variant_tensor = gen_dataset_ops.generator_dataset( + self._init_captured_args, + self._next_captured_args, + self._finalize_captured_args, + init_func=self._init_func, + next_func=self._next_func, + finalize_func=self._finalize_func, + **dataset_ops.flat_structure(self._input_dataset)) + super(_CopyToDeviceDataset, self).__init__(input_dataset, variant_tensor) + # The one_shot_iterator implementation needs a 0 arg _make_dataset function # that thereby captures all the inputs required to create the dataset. Since # there are strings that are inputs to the GeneratorDataset which can't be @@ -208,24 +220,12 @@ class _CopyToDeviceDataset(dataset_ops.UnaryUnchangedStructureDataset): else: return super(_CopyToDeviceDataset, self).make_one_shot_iterator() - def _as_variant_tensor(self): - with ops.device(self._target_device): - return gen_dataset_ops.generator_dataset( - self._init_captured_args, - self._next_captured_args, - self._finalize_captured_args, - init_func=self._init_func, - next_func=self._next_func, - finalize_func=self._finalize_func, - **dataset_ops.flat_structure(self._input_dataset)) - class _MapOnGpuDataset(dataset_ops.UnaryDataset): """A `Dataset` that maps a function over elements in its using a GPU.""" def __init__(self, input_dataset, map_func, use_inter_op_parallelism=True): """See `Dataset.map()` for details.""" - super(_MapOnGpuDataset, self).__init__(input_dataset) self._input_dataset = input_dataset self._use_inter_op_parallelism = use_inter_op_parallelism @@ -234,18 +234,16 @@ class _MapOnGpuDataset(dataset_ops.UnaryDataset): self._transformation_name(), dataset=input_dataset, defun_kwargs={"experimental_ints_on_device": True}) - - def _functions(self): - return [self._map_func] - - def _as_variant_tensor(self): - input_t = self._input_dataset._as_variant_tensor() # pylint: disable=protected-access - return ged_ops.experimental_map_dataset( - input_t, + variant_tensor = ged_ops.experimental_map_dataset( + self._input_dataset._variant_tensor, # pylint: disable=protected-access self._map_func.function.captured_inputs, f=self._map_func.function, use_inter_op_parallelism=self._use_inter_op_parallelism, **dataset_ops.flat_structure(self)) + super(_MapOnGpuDataset, self).__init__(input_dataset, variant_tensor) + + def _functions(self): + return [self._map_func] @property def _element_structure(self): diff --git a/tensorflow/python/data/experimental/ops/random_ops.py b/tensorflow/python/data/experimental/ops/random_ops.py index cbdf367db6b..f96e4a84b4a 100644 --- a/tensorflow/python/data/experimental/ops/random_ops.py +++ b/tensorflow/python/data/experimental/ops/random_ops.py @@ -33,14 +33,10 @@ class RandomDatasetV2(dataset_ops.DatasetSource): def __init__(self, seed=None): """A `Dataset` of pseudorandom values.""" - super(RandomDatasetV2, self).__init__() self._seed, self._seed2 = random_seed.get_seed(seed) - - def _as_variant_tensor(self): - return gen_experimental_dataset_ops.experimental_random_dataset( - seed=self._seed, - seed2=self._seed2, - 
**dataset_ops.flat_structure(self)) + variant_tensor = gen_experimental_dataset_ops.experimental_random_dataset( + seed=self._seed, seed2=self._seed2, **dataset_ops.flat_structure(self)) + super(RandomDatasetV2, self).__init__(variant_tensor) @property def _element_structure(self): diff --git a/tensorflow/python/data/experimental/ops/readers.py b/tensorflow/python/data/experimental/ops/readers.py index c2d82aeb591..177886e64be 100644 --- a/tensorflow/python/data/experimental/ops/readers.py +++ b/tensorflow/python/data/experimental/ops/readers.py @@ -622,7 +622,6 @@ class CsvDatasetV2(dataset_ops.DatasetSource): the input data. If specified, only this subset of columns will be parsed. Defaults to parsing all columns. """ - super(CsvDatasetV2, self).__init__() self._filenames = ops.convert_to_tensor( filenames, dtype=dtypes.string, name="filenames") self._compression_type = convert.optional_param_to_tensor( @@ -655,10 +654,7 @@ class CsvDatasetV2(dataset_ops.DatasetSource): self._structure = structure.NestedStructure( tuple(structure.TensorStructure(d.dtype, []) for d in self._record_defaults)) - - def _as_variant_tensor(self): - # Constructs graph node for the dataset op. - return gen_experimental_dataset_ops.experimental_csv_dataset( + variant_tensor = gen_experimental_dataset_ops.experimental_csv_dataset( filenames=self._filenames, record_defaults=self._record_defaults, buffer_size=self._buffer_size, @@ -668,8 +664,8 @@ class CsvDatasetV2(dataset_ops.DatasetSource): use_quote_delim=self._use_quote_delim, na_value=self._na_value, select_cols=self._select_cols, - compression_type=self._compression_type, - ) + compression_type=self._compression_type) + super(CsvDatasetV2, self).__init__(variant_tensor) @property def _element_structure(self): @@ -944,7 +940,6 @@ class SqlDatasetV2(dataset_ops.DatasetSource): output_types: A tuple of `tf.DType` objects representing the types of the columns returned by `query`. 
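Editor's note on the source-dataset hunks above (RandomDatasetV2, CsvDatasetV2, SqlDatasetV2): since the op is now created inside __init__ rather than in a deferred _as_variant_tensor(), every attribute the op needs must be assigned before the super() call. The sketch below models that ordering constraint with plain-Python stand-ins; _DatasetBase, _fake_random_dataset_op and the field names are hypothetical and are not TensorFlow symbols.

# Minimal, self-contained sketch of the "build the variant in __init__" contract.
# All names here are illustrative stand-ins, not TensorFlow APIs.

def _fake_random_dataset_op(seed, seed2):
  # Stands in for a generated dataset-op constructor.
  return ("variant:random_dataset", seed, seed2)


class _DatasetBase(object):
  """Stand-in for the base class: the constructor receives the finished variant."""

  def __init__(self, variant_tensor):
    self._variant_tensor = variant_tensor


class _RandomLikeDataset(_DatasetBase):
  """Stand-in for a source dataset such as RandomDatasetV2."""

  def __init__(self, seed):
    # Attributes used by the op must exist *before* super().__init__ runs,
    # because the op is constructed here, not in a later _as_variant_tensor().
    self._seed, self._seed2 = seed, seed + 1
    variant_tensor = _fake_random_dataset_op(self._seed, self._seed2)
    super(_RandomLikeDataset, self).__init__(variant_tensor)


print(_RandomLikeDataset(42)._variant_tensor)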
""" - super(SqlDatasetV2, self).__init__() self._driver_name = ops.convert_to_tensor( driver_name, dtype=dtypes.string, name="driver_name") self._data_source_name = ops.convert_to_tensor( @@ -954,11 +949,10 @@ class SqlDatasetV2(dataset_ops.DatasetSource): self._structure = structure.NestedStructure( nest.map_structure( lambda dtype: structure.TensorStructure(dtype, []), output_types)) - - def _as_variant_tensor(self): - return gen_experimental_dataset_ops.experimental_sql_dataset( + variant_tensor = gen_experimental_dataset_ops.experimental_sql_dataset( self._driver_name, self._data_source_name, self._query, nest.flatten(self.output_types), nest.flatten(self.output_shapes)) + super(SqlDatasetV2, self).__init__(variant_tensor) @property def _element_structure(self): diff --git a/tensorflow/python/data/experimental/ops/scan_ops.py b/tensorflow/python/data/experimental/ops/scan_ops.py index 5c77ad73434..7662626c3a0 100644 --- a/tensorflow/python/data/experimental/ops/scan_ops.py +++ b/tensorflow/python/data/experimental/ops/scan_ops.py @@ -33,7 +33,6 @@ class _ScanDataset(dataset_ops.UnaryDataset): def __init__(self, input_dataset, initial_state, scan_func): """See `scan()` for details.""" - super(_ScanDataset, self).__init__(input_dataset) self._input_dataset = input_dataset with ops.name_scope("initial_state"): @@ -126,20 +125,18 @@ class _ScanDataset(dataset_ops.UnaryDataset): self._scan_func = wrapped_func self._scan_func.function.add_to_graph(ops.get_default_graph()) - - def _functions(self): - return [self._scan_func] - - def _as_variant_tensor(self): # pylint: disable=protected-access - input_t = self._input_dataset._as_variant_tensor() - return gen_experimental_dataset_ops.experimental_scan_dataset( - input_t, + variant_tensor = gen_experimental_dataset_ops.experimental_scan_dataset( + self._input_dataset._variant_tensor, self._state_structure._to_tensor_list(self._initial_state), self._scan_func.function.captured_inputs, f=self._scan_func.function, preserve_cardinality=True, **dataset_ops.flat_structure(self)) + super(_ScanDataset, self).__init__(input_dataset, variant_tensor) + + def _functions(self): + return [self._scan_func] @property def _element_structure(self): diff --git a/tensorflow/python/data/experimental/ops/shuffle_ops.py b/tensorflow/python/data/experimental/ops/shuffle_ops.py index d12328a7145..86a615d5240 100644 --- a/tensorflow/python/data/experimental/ops/shuffle_ops.py +++ b/tensorflow/python/data/experimental/ops/shuffle_ops.py @@ -30,7 +30,6 @@ class _ShuffleAndRepeatDataset(dataset_ops.UnaryUnchangedStructureDataset): """A `Dataset` that fuses `shuffle` and `repeat`.""" def __init__(self, input_dataset, buffer_size, count=None, seed=None): - super(_ShuffleAndRepeatDataset, self).__init__(input_dataset) self._input_dataset = input_dataset self._buffer_size = ops.convert_to_tensor( buffer_size, dtype=dtypes.int64, name="buffer_size") @@ -40,18 +39,15 @@ class _ShuffleAndRepeatDataset(dataset_ops.UnaryUnchangedStructureDataset): self._count = ops.convert_to_tensor( count, dtype=dtypes.int64, name="count") self._seed, self._seed2 = random_seed.get_seed(seed) - - def _as_variant_tensor(self): - # pylint: disable=protected-access - input_resource = self._input_dataset._as_variant_tensor() - return gen_dataset_ops.shuffle_and_repeat_dataset( - input_resource, + variant_tensor = gen_dataset_ops.shuffle_and_repeat_dataset( + self._input_dataset._variant_tensor, # pylint: disable=protected-access buffer_size=self._buffer_size, count=self._count, seed=self._seed, 
seed2=self._seed2, **dataset_ops.flat_structure(self)) - # pylint: enable=protected-access + super(_ShuffleAndRepeatDataset, self).__init__(input_dataset, + variant_tensor) @tf_export("data.experimental.shuffle_and_repeat") diff --git a/tensorflow/python/data/experimental/ops/sleep.py b/tensorflow/python/data/experimental/ops/sleep.py index 2da832395b2..b66edc7a194 100644 --- a/tensorflow/python/data/experimental/ops/sleep.py +++ b/tensorflow/python/data/experimental/ops/sleep.py @@ -25,15 +25,13 @@ class _SleepDataset(dataset_ops.UnaryUnchangedStructureDataset): """A `Dataset` that sleeps before producing each upstream element.""" def __init__(self, input_dataset, sleep_microseconds): - super(_SleepDataset, self).__init__(input_dataset) self._input_dataset = input_dataset self._sleep_microseconds = sleep_microseconds - - def _as_variant_tensor(self): - return gen_experimental_dataset_ops.experimental_sleep_dataset( - self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access + variant_tensor = gen_experimental_dataset_ops.experimental_sleep_dataset( + self._input_dataset._variant_tensor, # pylint: disable=protected-access self._sleep_microseconds, **dataset_ops.flat_structure(self)) + super(_SleepDataset, self).__init__(input_dataset, variant_tensor) def sleep(sleep_microseconds): diff --git a/tensorflow/python/data/experimental/ops/stats_ops.py b/tensorflow/python/data/experimental/ops/stats_ops.py index 15a9d24546e..13dcb92fa06 100644 --- a/tensorflow/python/data/experimental/ops/stats_ops.py +++ b/tensorflow/python/data/experimental/ops/stats_ops.py @@ -102,13 +102,11 @@ class _StatsDataset(dataset_ops.UnaryUnchangedStructureDataset): """A `Dataset` that acts as an identity, and also records statistics.""" def __init__(self, input_dataset, op_function, tag): - super(_StatsDataset, self).__init__(input_dataset) self._input_dataset = input_dataset self._op_function = op_function self._tag = ops.convert_to_tensor(tag, dtype=dtypes.string) - - def _as_variant_tensor(self): - return self._op_function( - self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access + variant_tensor = self._op_function( + self._input_dataset._variant_tensor, # pylint: disable=protected-access self._tag, **dataset_ops.flat_structure(self)) + super(_StatsDataset, self).__init__(input_dataset, variant_tensor) diff --git a/tensorflow/python/data/experimental/ops/threadpool.py b/tensorflow/python/data/experimental/ops/threadpool.py index 69e8829d687..bc2c726822a 100644 --- a/tensorflow/python/data/experimental/ops/threadpool.py +++ b/tensorflow/python/data/experimental/ops/threadpool.py @@ -64,15 +64,13 @@ class _ThreadPoolDataset(dataset_ops.UnaryUnchangedStructureDataset): """A `Dataset` that acts as an identity, and sets a custom threadpool.""" def __init__(self, input_dataset, thread_pool): - super(_ThreadPoolDataset, self).__init__(input_dataset) self._input_dataset = input_dataset self._thread_pool = thread_pool - - def _as_variant_tensor(self): - return ged_ops.experimental_thread_pool_dataset( - self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access + variant_tensor = ged_ops.experimental_thread_pool_dataset( + self._input_dataset._variant_tensor, # pylint: disable=protected-access self._thread_pool._resource, # pylint: disable=protected-access **dataset_ops.flat_structure(self)) + super(_ThreadPoolDataset, self).__init__(input_dataset, variant_tensor) # TODO(b/73383364): Properly export in the `tf.data.experimental` API when diff --git 
a/tensorflow/python/data/experimental/ops/unique.py b/tensorflow/python/data/experimental/ops/unique.py index 55ed98d8542..dd26cfa4ee9 100644 --- a/tensorflow/python/data/experimental/ops/unique.py +++ b/tensorflow/python/data/experimental/ops/unique.py @@ -53,15 +53,13 @@ class _UniqueDataset(dataset_ops.UnaryUnchangedStructureDataset): def __init__(self, input_dataset): """See `unique()` for details.""" - super(_UniqueDataset, self).__init__(input_dataset) self._input_dataset = input_dataset if input_dataset.output_types not in (dtypes.int32, dtypes.int64, dtypes.string): raise TypeError( "`tf.data.experimental.unique()` only supports inputs with a single " "`tf.int32`, `tf.int64`, or `tf.string` component.") - - def _as_variant_tensor(self): - return gen_experimental_dataset_ops.experimental_unique_dataset( - self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access + variant_tensor = gen_experimental_dataset_ops.experimental_unique_dataset( + self._input_dataset._variant_tensor, # pylint: disable=protected-access **dataset_ops.flat_structure(self)) + super(_UniqueDataset, self).__init__(input_dataset, variant_tensor) diff --git a/tensorflow/python/data/experimental/ops/writers.py b/tensorflow/python/data/experimental/ops/writers.py index aef6da51409..49eae146523 100644 --- a/tensorflow/python/data/experimental/ops/writers.py +++ b/tensorflow/python/data/experimental/ops/writers.py @@ -57,4 +57,4 @@ class TFRecordWriter(object): "produces shape {0} and types {1}".format(dataset.output_shapes, dataset.output_types)) return gen_experimental_dataset_ops.experimental_dataset_to_tf_record( - dataset._as_variant_tensor(), self._filename, self._compression_type) # pylint: disable=protected-access + dataset._variant_tensor, self._filename, self._compression_type) # pylint: disable=protected-access diff --git a/tensorflow/python/data/kernel_tests/batch_test.py b/tensorflow/python/data/kernel_tests/batch_test.py index 5b035e59173..25512503467 100644 --- a/tensorflow/python/data/kernel_tests/batch_test.py +++ b/tensorflow/python/data/kernel_tests/batch_test.py @@ -91,9 +91,9 @@ class BatchTest(test_base.DatasetTestBase, parameterized.TestCase): result = self.evaluate(get_next()) def testBatchDatasetInvalidBatchSize(self): - dataset = (dataset_ops.Dataset.range(10).batch(0)) - self.assertDatasetProduces( - dataset, expected_error=(errors.InvalidArgumentError, '')) + with self.assertRaises(errors.InvalidArgumentError): + dataset = (dataset_ops.Dataset.range(10).batch(0)) + self.evaluate(dataset._variant_tensor) def testBatchSparse(self): diff --git a/tensorflow/python/data/kernel_tests/cache_test.py b/tensorflow/python/data/kernel_tests/cache_test.py index b561cd58baf..4806101d8c7 100644 --- a/tensorflow/python/data/kernel_tests/cache_test.py +++ b/tensorflow/python/data/kernel_tests/cache_test.py @@ -139,8 +139,8 @@ class FileCacheTest(test_base.DatasetTestBase): self.evaluate(get_next1()) # Re-initialize - get_next1 = self.getNext(cache_dataset1) - get_next2 = self.getNext(cache_dataset2) + get_next1 = self.getNext(cache_dataset1, requires_initialization=True) + get_next2 = self.getNext(cache_dataset2, requires_initialization=True) # Reading concurrently should succeed. 
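Editor's note on the kernel-test updates above (batch_test, prefetch/range/window below, and the cache_test initialization change): because the dataset op is now built when the Python object is constructed, invalid arguments surface at construction or when the variant is evaluated, not when the first element is requested. A rough, non-TensorFlow sketch of that shift, using a hypothetical _BatchLike class and plain unittest:

# Illustrative only: eager validation at object-construction time, mirroring
# why these tests now wrap dataset construction in assertRaises.
import unittest


class _BatchLike(object):
  def __init__(self, batch_size):
    # The underlying op is built here, so bad arguments fail immediately.
    if batch_size <= 0:
      raise ValueError("batch_size must be positive, got %d" % batch_size)
    self._variant_tensor = ("variant:batch", batch_size)


class EagerValidationTest(unittest.TestCase):

  def test_invalid_batch_size_fails_at_construction(self):
    with self.assertRaises(ValueError):
      _BatchLike(0)


if __name__ == "__main__":
  unittest.main()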
elements_itr1 = [] diff --git a/tensorflow/python/data/kernel_tests/dataset_test.py b/tensorflow/python/data/kernel_tests/dataset_test.py index 3926be95505..8193dffc7d2 100644 --- a/tensorflow/python/data/kernel_tests/dataset_test.py +++ b/tensorflow/python/data/kernel_tests/dataset_test.py @@ -272,12 +272,8 @@ class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase): def testSkipEagerSameGraphErrorOneShot(self): dataset = dataset_ops.Dataset.range(10) with ops.Graph().as_default(): - dataset = dataset.batch(2) - with test.mock.patch.object(logging, "warning") as mock_log: - _ = dataset.make_one_shot_iterator() - self.assertRegexpMatches( - str(mock_log.call_args), "Please ensure that all datasets in the " - "pipeline are created in the same graph as the iterator.") + with self.assertRaisesRegexp(ValueError, "must be from the same graph"): + dataset = dataset.batch(2) @test_util.run_deprecated_v1 def testSkipEagerSameGraphErrorOneShotSimple(self): @@ -293,9 +289,8 @@ class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase): def testSkipEagerSameGraphErrorInitializable(self): dataset = dataset_ops.Dataset.range(10) with ops.Graph().as_default(): - dataset = dataset.batch(2) with self.assertRaisesRegexp(ValueError, "must be from the same graph"): - _ = dataset.make_initializable_iterator() + dataset = dataset.batch(2) if __name__ == "__main__": diff --git a/tensorflow/python/data/kernel_tests/prefetch_test.py b/tensorflow/python/data/kernel_tests/prefetch_test.py index a143ba0ac63..8d076f6e685 100644 --- a/tensorflow/python/data/kernel_tests/prefetch_test.py +++ b/tensorflow/python/data/kernel_tests/prefetch_test.py @@ -36,9 +36,10 @@ class PrefetchTest(test_base.DatasetTestBase, parameterized.TestCase): @parameterized.parameters((-2), (-42)) def testInvalidBufferSize(self, buffer_size): - dataset = dataset_ops.Dataset.range(10).prefetch(buffer_size=buffer_size) - self.assertDatasetProduces( - dataset, expected_error=(errors.InvalidArgumentError, "buffer_size")) + with self.assertRaises(errors.InvalidArgumentError): + dataset = dataset_ops.Dataset.range(10).prefetch(buffer_size=buffer_size) + self.evaluate(dataset._variant_tensor) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/data/kernel_tests/range_test.py b/tensorflow/python/data/kernel_tests/range_test.py index 3f5d25e7f39..b7ac60c3fff 100644 --- a/tensorflow/python/data/kernel_tests/range_test.py +++ b/tensorflow/python/data/kernel_tests/range_test.py @@ -43,9 +43,9 @@ class RangeTest(test_base.DatasetTestBase): def testZeroStep(self): start, stop, step = 2, 10, 0 - dataset = dataset_ops.Dataset.range(start, stop, step) - self.assertDatasetProduces( - dataset, expected_error=(errors.InvalidArgumentError, "")) + with self.assertRaises(errors.InvalidArgumentError): + dataset = dataset_ops.Dataset.range(start, stop, step) + self.evaluate(dataset._variant_tensor) def testNegativeStep(self): start, stop, step = 2, 10, -1 diff --git a/tensorflow/python/data/kernel_tests/window_test.py b/tensorflow/python/data/kernel_tests/window_test.py index d083142ab6a..a7b4d86fcf9 100644 --- a/tensorflow/python/data/kernel_tests/window_test.py +++ b/tensorflow/python/data/kernel_tests/window_test.py @@ -116,12 +116,11 @@ class WindowTest(test_base.DatasetTestBase, parameterized.TestCase): ("3", 14, 3, 3, 0), ) def testWindowDatasetInvalid(self, count, size, shift, stride): - dataset = dataset_ops.Dataset.range(10).map(lambda x: x).repeat( - count).window( - size=size, shift=shift, - 
stride=stride).flat_map(lambda x: x.batch(batch_size=size)) - self.assertDatasetProduces( - dataset, expected_error=(errors.InvalidArgumentError, "")) + with self.assertRaises(errors.InvalidArgumentError): + ds = dataset_ops.Dataset.range(10).map(lambda x: x).repeat(count).window( + size=size, shift=shift, + stride=stride).flat_map(lambda x: x.batch(batch_size=size)) + self.evaluate(ds._variant_tensor) def testWindowSparse(self): diff --git a/tensorflow/python/data/ops/BUILD b/tensorflow/python/data/ops/BUILD index fbff7df9c37..112aa926ae5 100644 --- a/tensorflow/python/data/ops/BUILD +++ b/tensorflow/python/data/ops/BUILD @@ -35,6 +35,7 @@ py_library( "//tensorflow/python/data/util:random_seed", "//tensorflow/python/data/util:sparse", "//tensorflow/python/data/util:structure", + "//tensorflow/python/data/util:traverse", "//third_party/py/numpy", ], ) diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py index 2c1f69de608..7fa9ea59e88 100644 --- a/tensorflow/python/data/ops/dataset_ops.py +++ b/tensorflow/python/data/ops/dataset_ops.py @@ -38,6 +38,7 @@ from tensorflow.python.data.util import options as options_lib from tensorflow.python.data.util import random_seed from tensorflow.python.data.util import sparse from tensorflow.python.data.util import structure as structure_lib +from tensorflow.python.data.util import traverse from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -75,9 +76,27 @@ class DatasetV2(object): plan" of transformations that act on those elements. """ - def __init__(self): + def __init__(self, variant_tensor): + """Creates a DatasetV2 object. + + This is a difference between DatasetV1 and DatasetV2. DatasetV1 does not + take anything in its constructor whereas in the DatasetV2, we expect + subclasses to create a variant_tensor and pass it in to the super() call. + + Args: + variant_tensor: A DT_VARIANT tensor that represents the dataset. + """ + self._dataset_variant_tensor = variant_tensor self._graph_attr = ops.get_default_graph() + @property + def _variant_tensor(self): + return self._dataset_variant_tensor + + @_variant_tensor.setter + def _variant_tensor(self, _): + raise ValueError("The _variant_tensor property is read-only") + def _as_serialized_graph(self): """Produces serialized graph representation of the dataset. @@ -85,16 +104,7 @@ class DatasetV2(object): A scalar `tf.Tensor` of `tf.string` type, representing this dataset as a serialized graph. """ - return gen_dataset_ops.dataset_to_graph(self._as_variant_tensor()) - - @abc.abstractmethod - def _as_variant_tensor(self): - """Creates a scalar `tf.Tensor` of `tf.variant` representing this dataset. - - Returns: - A scalar `tf.Tensor` of `tf.variant` type, which represents this dataset. - """ - raise NotImplementedError("Dataset._as_variant_tensor") + return gen_dataset_ops.dataset_to_graph(self._variant_tensor) @abc.abstractmethod def _inputs(self): @@ -1279,7 +1289,7 @@ class DatasetV2(object): # pylint: disable=protected-access return state_structure._from_compatible_tensor_list( gen_dataset_ops.reduce_dataset( - self._as_variant_tensor(), + self._variant_tensor, state_structure._to_tensor_list(initial_state), reduce_func.captured_inputs, f=reduce_func, @@ -1314,8 +1324,31 @@ class DatasetV1(DatasetV2): plan" of transformations that act on those elements. 
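Editor's note on the DatasetV2 constructor change above: the new _variant_tensor property is deliberately read-only, with a setter that raises. The following stand-alone sketch shows the same property/setter idiom; _Dataset here is a stand-in class, not tf.data.Dataset.

# Sketch of the read-only property idiom used for _variant_tensor above.

class _Dataset(object):

  def __init__(self, variant_tensor):
    self._dataset_variant_tensor = variant_tensor

  @property
  def _variant_tensor(self):
    return self._dataset_variant_tensor

  @_variant_tensor.setter
  def _variant_tensor(self, _):
    raise ValueError("The _variant_tensor property is read-only")


ds = _Dataset("variant:range")
print(ds._variant_tensor)        # reads fine
try:
  ds._variant_tensor = "other"   # any assignment raises
except ValueError as e:
  print(e)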
""" - def __init__(self): # pylint: disable=useless-super-delegation - super(DatasetV1, self).__init__() + def __init__(self): + try: + variant_tensor = self._as_variant_tensor() + except AttributeError as e: + if "_as_variant_tensor" in str(e): + raise AttributeError("Please use _variant_tensor instead of " + "_as_variant_tensor() to obtain the variant " + "associated with a dataset") + raise AttributeError("A likely cause of this error is that the super " + "call for this dataset is not the last line of the " + "__init__ method. The base class causes the " + "_as_variant_tensor call in its constructor and " + "if that uses attributes defined in the __init__ " + "method, those attrs need to be defined before the " + "super call.") + super(DatasetV1, self).__init__(variant_tensor) + + @abc.abstractmethod + def _as_variant_tensor(self): + """Creates a scalar `tf.Tensor` of `tf.variant` representing this dataset. + + Returns: + A scalar `tf.Tensor` of `tf.variant` type, which represents this dataset. + """ + raise NotImplementedError("Dataset._as_variant_tensor") @deprecation.deprecated( None, "Use `for ... in dataset:` to iterate over a dataset. If using " @@ -1335,11 +1368,19 @@ class DatasetV1(DatasetV2): return iterator_ops.EagerIterator(self) _ensure_same_dataset_graph(self) + # Now that we create datasets at python object creation time, the capture + # by value _make_dataset() function would try to capture these variant + # tensor dataset inputs, which are marked as stateful ops and would throw + # an error if we try and capture them. We therefore traverse the graph + # to find all these ops and whitelist them so that the capturing + # logic instead of throwing an error recreates these ops which is what was + # happening before. + all_ds_ops = traverse.obtain_all_variant_tensor_ops(self) graph_level_seed, op_level_seed = core_random_seed.get_seed(None) # NOTE(mrry): We capture by value here to ensure that `_make_dataset()` is # a 0-argument function. 
- @function.Defun(capture_by_value=True) + @function.Defun(capture_by_value=True, whitelisted_stateful_ops=all_ds_ops) def _make_dataset(): """Factory function for a dataset.""" # NOTE(mrry): `Defun` does not capture the graph-level seed from the @@ -1351,7 +1392,7 @@ class DatasetV1(DatasetV2): (graph_level_seed + 87654321 * op_level_seed) % (2 ** 63 - 1)) dataset = self._apply_options() - return dataset._as_variant_tensor() # pylint: disable=protected-access + return dataset._variant_tensor # pylint: disable=protected-access try: _make_dataset.add_to_graph(ops.get_default_graph()) @@ -1416,7 +1457,7 @@ class DatasetV1(DatasetV2): container="", shared_name=shared_name, **flat_structure(self)) with ops.colocate_with(iterator_resource): initializer = gen_dataset_ops.make_iterator( - dataset._as_variant_tensor(), # pylint: disable=protected-access + dataset._variant_tensor, # pylint: disable=protected-access iterator_resource) return iterator_ops.Iterator(iterator_resource, initializer, dataset.output_types, dataset.output_shapes, @@ -1621,11 +1662,11 @@ class DatasetV1Adapter(DatasetV1): """Wraps a V2 `Dataset` object in the `tf.compat.v1.data.Dataset` API.""" def __init__(self, dataset): - super(DatasetV1Adapter, self).__init__() self._dataset = dataset + super(DatasetV1Adapter, self).__init__() def _as_variant_tensor(self): - return self._dataset._as_variant_tensor() # pylint: disable=protected-access + return self._dataset._variant_tensor # pylint: disable=protected-access def _has_captured_ref(self): return self._dataset._has_captured_ref() # pylint: disable=protected-access @@ -1657,14 +1698,14 @@ def _ensure_same_dataset_graph(dataset): if current_graph != ds_graph: logging.warning("The graph (" + str(current_graph) + ") of the iterator " "is different from the graph (" + str(ds_graph) + ") " - "the dataset: " + str(ds) + " was created in. " - "If you are using the Estimator API, make sure that no " - "part of the dataset returned by the `input_fn` function " - "is defined outside the `input_fn` function." - "Please ensure that all datasets in the pipeline are " - "created in the same graph as the iterator. NOTE: This " - "warning will become an error in future versions of " - "TensorFlow.") + "the dataset: " + str(ds._variant_tensor) + " was " # pylint: disable=protected-access + "created in. If you are using the Estimator API, " + "make sure that no part of the dataset returned by the " + "`input_fn` function is defined outside the `input_fn` " + "function. Please ensure that all datasets in the " + "pipeline are created in the same graph as the iterator. 
" + "NOTE: This warning will become an error in future " + "versions of TensorFlow.") for input_ds in ds._inputs(): # pylint: disable=protected-access if input_ds not in visited: bfs_q.put(input_ds) @@ -1820,9 +1861,9 @@ class DatasetSource(DatasetV2): class UnaryDataset(DatasetV2): """Abstract class representing a dataset with one input.""" - def __init__(self, input_dataset): - super(UnaryDataset, self).__init__() + def __init__(self, input_dataset, variant_tensor): self._input_dataset = input_dataset + super(UnaryDataset, self).__init__(variant_tensor) def _inputs(self): return [self._input_dataset] @@ -1831,6 +1872,11 @@ class UnaryDataset(DatasetV2): class UnaryUnchangedStructureDataset(UnaryDataset): """Represents a unary dataset with the same input and output structure.""" + def __init__(self, input_dataset, variant_tensor): + self._input_dataset = input_dataset + super(UnaryUnchangedStructureDataset, self).__init__( + input_dataset, variant_tensor) + @property def _element_structure(self): return self._input_dataset._element_structure # pylint: disable=protected-access @@ -1841,7 +1887,6 @@ class TensorDataset(DatasetSource): def __init__(self, tensors): """See `Dataset.from_tensors()` for details.""" - super(TensorDataset, self).__init__() with ops.name_scope("tensors"): tensors = nest.pack_sequence_as(tensors, [ sparse_tensor_lib.SparseTensor.from_value(t) @@ -1852,9 +1897,9 @@ class TensorDataset(DatasetSource): self._structure = structure_lib.Structure.from_value(tensors) self._tensors = self._structure._to_tensor_list(tensors) # pylint: disable=protected-access - def _as_variant_tensor(self): - return gen_dataset_ops.tensor_dataset( + variant_tensor = gen_dataset_ops.tensor_dataset( self._tensors, output_shapes=self._structure._flat_shapes) # pylint: disable=protected-access + super(TensorDataset, self).__init__(variant_tensor) @property def _element_structure(self): @@ -1866,7 +1911,6 @@ class TensorSliceDataset(DatasetSource): def __init__(self, tensors): """See `Dataset.from_tensor_slices()` for details.""" - super(TensorSliceDataset, self).__init__() with ops.name_scope("tensors"): tensors = nest.pack_sequence_as(tensors, [ sparse_tensor_lib.SparseTensor.from_value(t) @@ -1887,9 +1931,9 @@ class TensorSliceDataset(DatasetSource): batch_dim.assert_is_compatible_with(tensor_shape.Dimension( tensor_shape.dimension_value(t.get_shape()[0]))) - def _as_variant_tensor(self): - return gen_dataset_ops.tensor_slice_dataset( + variant_tensor = gen_dataset_ops.tensor_slice_dataset( self._tensors, output_shapes=self._structure._flat_shapes) # pylint: disable=protected-access + super(TensorSliceDataset, self).__init__(variant_tensor) @property def _element_structure(self): @@ -1901,7 +1945,6 @@ class SparseTensorSliceDataset(DatasetSource): def __init__(self, sparse_tensor): """See `Dataset.from_sparse_tensor_slices()` for details.""" - super(SparseTensorSliceDataset, self).__init__() if not isinstance(sparse_tensor, sparse_tensor_lib.SparseTensor): raise TypeError("`sparse_tensor` must be a `tf.SparseTensor` object.") self._sparse_tensor = sparse_tensor @@ -1914,10 +1957,10 @@ class SparseTensorSliceDataset(DatasetSource): structure_lib.TensorStructure(self._sparse_tensor.dtype, [None]), structure_lib.TensorStructure(dtypes.int64, [rank]))) - def _as_variant_tensor(self): - return gen_dataset_ops.sparse_tensor_slice_dataset( + variant_tensor = gen_dataset_ops.sparse_tensor_slice_dataset( self._sparse_tensor.indices, self._sparse_tensor.values, self._sparse_tensor.dense_shape) + 
super(SparseTensorSliceDataset, self).__init__(variant_tensor) @property def _element_structure(self): @@ -1928,12 +1971,8 @@ class _VariantDataset(DatasetV2): """A Dataset wrapper around a `tf.variant`-typed function argument.""" def __init__(self, dataset_variant, structure): - super(_VariantDataset, self).__init__() - self._dataset_variant = dataset_variant self._structure = structure - - def _as_variant_tensor(self): - return self._dataset_variant + super(_VariantDataset, self).__init__(dataset_variant) def _inputs(self): return [] @@ -1965,7 +2004,7 @@ class DatasetStructure(structure_lib.Structure): other._element_structure)) def _to_tensor_list(self, value): - return [value._as_variant_tensor()] # pylint: disable=protected-access + return [value._variant_tensor] # pylint: disable=protected-access def _to_batched_tensor_list(self, value): raise NotImplementedError("Unbatching for `tf.data.Dataset` objects.") @@ -2153,7 +2192,7 @@ def flat_structure(dataset): Most Dataset op constructors expect `output_shapes` and `output_types` arguments that represent the flattened structure of an element. This helper function generates these attrs as a keyword argument dictionary, allowing - `Dataset._as_variant_tensor()` implementations to pass + `Dataset._variant_tensor` implementations to pass `**flat_structure(self)` to the op constructor. Args: @@ -2189,7 +2228,6 @@ class _GeneratorDataset(DatasetSource): `init_func` immediately before a C++ iterator over this dataset is destroyed. The return value is ignored. """ - super(_GeneratorDataset, self).__init__() self._init_args = init_args self._init_structure = structure_lib.Structure.from_value(init_args) @@ -2208,9 +2246,7 @@ class _GeneratorDataset(DatasetSource): finalize_func, self._transformation_name(), input_structure=self._init_func.output_structure) - - def _as_variant_tensor(self): - return gen_dataset_ops.generator_dataset( + variant_tensor = gen_dataset_ops.generator_dataset( self._init_structure._to_tensor_list(self._init_args) # pylint: disable=protected-access + self._init_func.function.captured_inputs, self._next_func.function.captured_inputs, @@ -2219,6 +2255,7 @@ class _GeneratorDataset(DatasetSource): next_func=self._next_func.function, finalize_func=self._finalize_func.function, **flat_structure(self)) + super(_GeneratorDataset, self).__init__(variant_tensor) @property def _element_structure(self): @@ -2233,7 +2270,6 @@ class ZipDataset(DatasetV2): def __init__(self, datasets): """See `Dataset.zip()` for details.""" - super(ZipDataset, self).__init__() for ds in nest.flatten(datasets): if not isinstance(ds, DatasetV2): if isinstance(ds, list): @@ -2250,12 +2286,12 @@ class ZipDataset(DatasetV2): self._datasets, [ds._element_structure for ds in nest.flatten(self._datasets)])) # pylint: disable=protected-access - def _as_variant_tensor(self): # pylint: disable=protected-access - return gen_dataset_ops.zip_dataset( - [ds._as_variant_tensor() for ds in nest.flatten(self._datasets)], + variant_tensor = gen_dataset_ops.zip_dataset( + [ds._variant_tensor for ds in nest.flatten(self._datasets)], **flat_structure(self)) # pylint: enable=protected-access + super(ZipDataset, self).__init__(variant_tensor) def _inputs(self): return nest.flatten(self._datasets) @@ -2270,7 +2306,6 @@ class ConcatenateDataset(DatasetV2): def __init__(self, input_dataset, dataset_to_concatenate): """See `Dataset.concatenate()` for details.""" - super(ConcatenateDataset, self).__init__() self._input_dataset = input_dataset self._dataset_to_concatenate = 
dataset_to_concatenate @@ -2298,17 +2333,15 @@ class ConcatenateDataset(DatasetV2): output_types, output_shapes, output_classes) self._input_datasets = [input_dataset, dataset_to_concatenate] - - def _as_variant_tensor(self): # pylint: disable=protected-access - return gen_dataset_ops.concatenate_dataset( - self._input_dataset._as_variant_tensor(), - self._dataset_to_concatenate._as_variant_tensor(), + variant_tensor = gen_dataset_ops.concatenate_dataset( + input_dataset._variant_tensor, dataset_to_concatenate._variant_tensor, **flat_structure(self)) # pylint: enable=protected-access + super(ConcatenateDataset, self).__init__(variant_tensor) def _inputs(self): - return [self._input_dataset, self._dataset_to_concatenate] + return self._input_datasets @property def _element_structure(self): @@ -2320,19 +2353,17 @@ class RepeatDataset(UnaryUnchangedStructureDataset): def __init__(self, input_dataset, count): """See `Dataset.repeat()` for details.""" - super(RepeatDataset, self).__init__(input_dataset) self._input_dataset = input_dataset if count is None: self._count = constant_op.constant(-1, dtype=dtypes.int64, name="count") else: self._count = ops.convert_to_tensor( count, dtype=dtypes.int64, name="count") - - def _as_variant_tensor(self): - return gen_dataset_ops.repeat_dataset( - self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access + variant_tensor = gen_dataset_ops.repeat_dataset( + input_dataset._variant_tensor, # pylint: disable=protected-access count=self._count, **flat_structure(self)) + super(RepeatDataset, self).__init__(input_dataset, variant_tensor) class RangeDataset(DatasetSource): @@ -2340,8 +2371,13 @@ class RangeDataset(DatasetSource): def __init__(self, *args): """See `Dataset.range()` for details.""" - super(RangeDataset, self).__init__() self._parse_args(*args) + variant_tensor = gen_dataset_ops.range_dataset( + start=self._start, + stop=self._stop, + step=self._step, + **flat_structure(self)) + super(RangeDataset, self).__init__(variant_tensor) def _parse_args(self, *args): """Parse arguments according to the same rules as the `range()` builtin.""" @@ -2363,13 +2399,6 @@ class RangeDataset(DatasetSource): def _build_tensor(self, int64_value, name): return ops.convert_to_tensor(int64_value, dtype=dtypes.int64, name=name) - def _as_variant_tensor(self): - return gen_dataset_ops.range_dataset( - start=self._start, - stop=self._stop, - step=self._step, - **flat_structure(self)) - @property def _element_structure(self): return structure_lib.TensorStructure(dtypes.int64, []) @@ -2380,16 +2409,14 @@ class CacheDataset(UnaryUnchangedStructureDataset): def __init__(self, input_dataset, filename): """See `Dataset.cache()` for details.""" - super(CacheDataset, self).__init__(input_dataset) self._input_dataset = input_dataset self._filename = ops.convert_to_tensor( filename, dtype=dtypes.string, name="filename") - - def _as_variant_tensor(self): - return gen_dataset_ops.cache_dataset( - self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access + variant_tensor = gen_dataset_ops.cache_dataset( + input_dataset._variant_tensor, # pylint: disable=protected-access filename=self._filename, **flat_structure(self)) + super(CacheDataset, self).__init__(input_dataset, variant_tensor) class ShuffleDataset(UnaryUnchangedStructureDataset): @@ -2420,7 +2447,6 @@ class ShuffleDataset(UnaryUnchangedStructureDataset): Raises: ValueError: if invalid arguments are provided. 
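Editor's note on the unary transformations above (Concatenate, Repeat, Range, Cache, Shuffle): each now stashes its input, converts its scalar arguments, builds the op from input_dataset._variant_tensor, and finishes with super(..., input_dataset, variant_tensor). The toy pipeline below shows how a chain wires together under that contract; every name is hypothetical and no TensorFlow op is involved.

# Toy pipeline: each unary transformation consumes the upstream variant at
# construction time. Names are illustrative, not TensorFlow ops.

class _Dataset(object):
  def __init__(self, variant_tensor):
    self._variant_tensor = variant_tensor

  def _inputs(self):
    return []


class _UnaryDataset(_Dataset):
  def __init__(self, input_dataset, variant_tensor):
    self._input_dataset = input_dataset
    super(_UnaryDataset, self).__init__(variant_tensor)

  def _inputs(self):
    return [self._input_dataset]


class _RangeLike(_Dataset):
  def __init__(self, stop):
    self._stop = stop
    super(_RangeLike, self).__init__(("range", stop))


class _RepeatLike(_UnaryDataset):
  def __init__(self, input_dataset, count):
    self._count = count
    variant = ("repeat", input_dataset._variant_tensor, count)
    super(_RepeatLike, self).__init__(input_dataset, variant)


class _TakeLike(_UnaryDataset):
  def __init__(self, input_dataset, count):
    self._count = count
    variant = ("take", input_dataset._variant_tensor, count)
    super(_TakeLike, self).__init__(input_dataset, variant)


ds = _TakeLike(_RepeatLike(_RangeLike(10), 2), 5)
print(ds._variant_tensor)   # nested tuples mirror the op graph built eagerly
print(ds._inputs())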
""" - super(ShuffleDataset, self).__init__(input_dataset) self._input_dataset = input_dataset self._buffer_size = ops.convert_to_tensor( buffer_size, dtype=dtypes.int64, name="buffer_size") @@ -2430,15 +2456,14 @@ class ShuffleDataset(UnaryUnchangedStructureDataset): self._reshuffle_each_iteration = True else: self._reshuffle_each_iteration = reshuffle_each_iteration - - def _as_variant_tensor(self): - return gen_dataset_ops.shuffle_dataset( - self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access + variant_tensor = gen_dataset_ops.shuffle_dataset( + input_dataset._variant_tensor, # pylint: disable=protected-access buffer_size=self._buffer_size, seed=self._seed, seed2=self._seed2, reshuffle_each_iteration=self._reshuffle_each_iteration, **flat_structure(self)) + super(ShuffleDataset, self).__init__(input_dataset, variant_tensor) class TakeDataset(UnaryUnchangedStructureDataset): @@ -2446,15 +2471,13 @@ class TakeDataset(UnaryUnchangedStructureDataset): def __init__(self, input_dataset, count): """See `Dataset.take()` for details.""" - super(TakeDataset, self).__init__(input_dataset) self._input_dataset = input_dataset self._count = ops.convert_to_tensor(count, dtype=dtypes.int64, name="count") - - def _as_variant_tensor(self): - return gen_dataset_ops.take_dataset( - self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access + variant_tensor = gen_dataset_ops.take_dataset( + input_dataset._variant_tensor, # pylint: disable=protected-access count=self._count, **flat_structure(self)) + super(TakeDataset, self).__init__(input_dataset, variant_tensor) class SkipDataset(UnaryUnchangedStructureDataset): @@ -2462,15 +2485,13 @@ class SkipDataset(UnaryUnchangedStructureDataset): def __init__(self, input_dataset, count): """See `Dataset.skip()` for details.""" - super(SkipDataset, self).__init__(input_dataset) self._input_dataset = input_dataset self._count = ops.convert_to_tensor(count, dtype=dtypes.int64, name="count") - - def _as_variant_tensor(self): - return gen_dataset_ops.skip_dataset( - self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access + variant_tensor = gen_dataset_ops.skip_dataset( + input_dataset._variant_tensor, # pylint: disable=protected-access count=self._count, **flat_structure(self)) + super(SkipDataset, self).__init__(input_dataset, variant_tensor) class BatchDataset(UnaryDataset): @@ -2478,7 +2499,6 @@ class BatchDataset(UnaryDataset): def __init__(self, input_dataset, batch_size, drop_remainder): """See `Dataset.batch()` for details.""" - super(BatchDataset, self).__init__(input_dataset) self._input_dataset = input_dataset self._batch_size = ops.convert_to_tensor( batch_size, dtype=dtypes.int64, name="batch_size") @@ -2494,13 +2514,12 @@ class BatchDataset(UnaryDataset): tensor_util.constant_value(self._batch_size)) else: self._structure = input_dataset._element_structure._batch(None) - - def _as_variant_tensor(self): - return gen_dataset_ops.batch_dataset_v2( - self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access + variant_tensor = gen_dataset_ops.batch_dataset_v2( + input_dataset._variant_tensor, # pylint: disable=protected-access batch_size=self._batch_size, drop_remainder=self._drop_remainder, **flat_structure(self)) + super(BatchDataset, self).__init__(input_dataset, variant_tensor) @property def _element_structure(self): @@ -2622,7 +2641,7 @@ class PaddedBatchDataset(UnaryDataset): def __init__(self, input_dataset, batch_size, padded_shapes, padding_values, drop_remainder): """See 
`Dataset.batch()` for details.""" - super(PaddedBatchDataset, self).__init__(input_dataset) + self._input_dataset = input_dataset if sparse.any_sparse(input_dataset.output_classes): # TODO(b/63669786): support batching of sparse tensors raise TypeError( @@ -2665,12 +2684,11 @@ class PaddedBatchDataset(UnaryDataset): self._input_dataset.output_types, output_shapes, self._input_dataset.output_classes) - def _as_variant_tensor(self): # pylint: disable=protected-access # TODO(jsimsa): Switch to using v2 only any time after 6/30/2018. if smart_cond.smart_constant_value(self._drop_remainder) is False: - return gen_dataset_ops.padded_batch_dataset( - self._input_dataset._as_variant_tensor(), + variant_tensor = gen_dataset_ops.padded_batch_dataset( + input_dataset._variant_tensor, # pylint: disable=protected-access batch_size=self._batch_size, padded_shapes=[ ops.convert_to_tensor(s, dtype=dtypes.int64) @@ -2679,8 +2697,8 @@ class PaddedBatchDataset(UnaryDataset): padding_values=nest.flatten(self._padding_values), output_shapes=self._structure._flat_shapes) else: - return gen_dataset_ops.padded_batch_dataset_v2( - self._input_dataset._as_variant_tensor(), + variant_tensor = gen_dataset_ops.padded_batch_dataset_v2( + input_dataset._variant_tensor, # pylint: disable=protected-access batch_size=self._batch_size, padded_shapes=[ ops.convert_to_tensor(s, dtype=dtypes.int64) @@ -2689,6 +2707,7 @@ class PaddedBatchDataset(UnaryDataset): padding_values=nest.flatten(self._padding_values), drop_remainder=self._drop_remainder, output_shapes=self._structure._flat_shapes) + super(PaddedBatchDataset, self).__init__(input_dataset, variant_tensor) @property def _element_structure(self): @@ -2727,22 +2746,19 @@ class MapDataset(UnaryDataset): use_inter_op_parallelism=True, preserve_cardinality=False): """See `Dataset.map()` for details.""" - super(MapDataset, self).__init__(input_dataset) self._input_dataset = input_dataset self._use_inter_op_parallelism = use_inter_op_parallelism self._preserve_cardinality = preserve_cardinality self._map_func = StructuredFunctionWrapper( map_func, self._transformation_name(), dataset=input_dataset) - - def _as_variant_tensor(self): - input_t = self._input_dataset._as_variant_tensor() # pylint: disable=protected-access - return gen_dataset_ops.map_dataset( - input_t, + variant_tensor = gen_dataset_ops.map_dataset( + input_dataset._variant_tensor, # pylint: disable=protected-access self._map_func.function.captured_inputs, f=self._map_func.function, use_inter_op_parallelism=self._use_inter_op_parallelism, preserve_cardinality=self._preserve_cardinality, **flat_structure(self)) + super(MapDataset, self).__init__(input_dataset, variant_tensor) def _functions(self): return [self._map_func] @@ -2755,7 +2771,7 @@ class MapDataset(UnaryDataset): return "Dataset.map()" -class ParallelMapDataset(MapDataset): +class ParallelMapDataset(UnaryDataset): """A `Dataset` that maps a function over elements in its input in parallel.""" def __init__(self, @@ -2765,23 +2781,32 @@ class ParallelMapDataset(MapDataset): use_inter_op_parallelism=True, preserve_cardinality=False): """See `Dataset.map()` for details.""" - super(ParallelMapDataset, self).__init__( - input_dataset, map_func, use_inter_op_parallelism, preserve_cardinality) - + self._input_dataset = input_dataset + self._use_inter_op_parallelism = use_inter_op_parallelism + self._map_func = StructuredFunctionWrapper( + map_func, self._transformation_name(), dataset=input_dataset) self._num_parallel_calls = ops.convert_to_tensor( 
num_parallel_calls, dtype=dtypes.int32, name="num_parallel_calls") - - def _as_variant_tensor(self): - # pylint: disable=protected-access - input_t = self._input_dataset._as_variant_tensor() - return gen_dataset_ops.parallel_map_dataset( - input_t, + self._preserve_cardinality = preserve_cardinality + variant_tensor = gen_dataset_ops.parallel_map_dataset( + input_dataset._variant_tensor, # pylint: disable=protected-access self._map_func.function.captured_inputs, f=self._map_func.function, num_parallel_calls=self._num_parallel_calls, use_inter_op_parallelism=self._use_inter_op_parallelism, preserve_cardinality=self._preserve_cardinality, **flat_structure(self)) + super(ParallelMapDataset, self).__init__(input_dataset, variant_tensor) + + def _functions(self): + return [self._map_func] + + @property + def _element_structure(self): + return self._map_func.output_structure + + def _transformation_name(self): + return "Dataset.map()" class FlatMapDataset(UnaryDataset): @@ -2789,24 +2814,21 @@ class FlatMapDataset(UnaryDataset): def __init__(self, input_dataset, map_func): """See `Dataset.flat_map()` for details.""" - super(FlatMapDataset, self).__init__(input_dataset) self._input_dataset = input_dataset - self._map_func = StructuredFunctionWrapper( map_func, self._transformation_name(), dataset=input_dataset) if not isinstance(self._map_func.output_structure, DatasetStructure): raise TypeError("`map_func` must return a `Dataset` object.") self._structure = self._map_func.output_structure._element_structure # pylint: disable=protected-access - - def _functions(self): - return [self._map_func] - - def _as_variant_tensor(self): - return gen_dataset_ops.flat_map_dataset( - self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access + variant_tensor = gen_dataset_ops.flat_map_dataset( + input_dataset._variant_tensor, # pylint: disable=protected-access self._map_func.function.captured_inputs, f=self._map_func.function, **flat_structure(self)) + super(FlatMapDataset, self).__init__(input_dataset, variant_tensor) + + def _functions(self): + return [self._map_func] @property def _element_structure(self): @@ -2816,58 +2838,79 @@ class FlatMapDataset(UnaryDataset): return "Dataset.flat_map()" -class InterleaveDataset(FlatMapDataset): +class InterleaveDataset(UnaryDataset): """A `Dataset` that maps a function over its input and interleaves the result. 
""" def __init__(self, input_dataset, map_func, cycle_length, block_length): """See `Dataset.interleave()` for details.""" - super(InterleaveDataset, self).__init__(input_dataset, map_func) + self._input_dataset = input_dataset + self._map_func = StructuredFunctionWrapper( + map_func, self._transformation_name(), dataset=input_dataset) + if not isinstance(self._map_func.output_structure, DatasetStructure): + raise TypeError("`map_func` must return a `Dataset` object.") + self._structure = self._map_func.output_structure._element_structure # pylint: disable=protected-access self._cycle_length = ops.convert_to_tensor( cycle_length, dtype=dtypes.int64, name="cycle_length") self._block_length = ops.convert_to_tensor( block_length, dtype=dtypes.int64, name="block_length") - def _as_variant_tensor(self): - # pylint: disable=protected-access - return gen_dataset_ops.interleave_dataset( - self._input_dataset._as_variant_tensor(), - self._map_func.function.captured_inputs, + variant_tensor = gen_dataset_ops.interleave_dataset( + input_dataset._variant_tensor, # pylint: disable=protected-access + self._map_func.function.captured_inputs, # pylint: disable=protected-access self._cycle_length, self._block_length, f=self._map_func.function, **flat_structure(self)) + super(InterleaveDataset, self).__init__(input_dataset, variant_tensor) + + def _functions(self): + return [self._map_func] + + @property + def _element_structure(self): + return self._structure def _transformation_name(self): return "Dataset.interleave()" -class ParallelInterleaveDataset(FlatMapDataset): +class ParallelInterleaveDataset(UnaryDataset): """A `Dataset` that maps a function over its input and interleaves the result. - """ def __init__(self, input_dataset, map_func, cycle_length, block_length, num_parallel_calls): """See `Dataset.interleave()` for details.""" - super(ParallelInterleaveDataset, self).__init__(input_dataset, map_func) + self._input_dataset = input_dataset + self._map_func = StructuredFunctionWrapper( + map_func, self._transformation_name(), dataset=input_dataset) + if not isinstance(self._map_func.output_structure, DatasetStructure): + raise TypeError("`map_func` must return a `Dataset` object.") + self._structure = self._map_func.output_structure._element_structure # pylint: disable=protected-access self._cycle_length = ops.convert_to_tensor( cycle_length, dtype=dtypes.int64, name="cycle_length") self._block_length = ops.convert_to_tensor( block_length, dtype=dtypes.int64, name="block_length") self._num_parallel_calls = ops.convert_to_tensor( num_parallel_calls, dtype=dtypes.int64, name="num_parallel_calls") - - def _as_variant_tensor(self): - # pylint: disable=protected-access - return gen_dataset_ops.parallel_interleave_dataset_v2( - self._input_dataset._as_variant_tensor(), - self._map_func.function.captured_inputs, + variant_tensor = gen_dataset_ops.parallel_interleave_dataset_v2( + input_dataset._variant_tensor, # pylint: disable=protected-access + self._map_func.function.captured_inputs, # pylint: disable=protected-access self._cycle_length, self._block_length, self._num_parallel_calls, f=self._map_func.function, **flat_structure(self)) + super(ParallelInterleaveDataset, self).__init__(input_dataset, + variant_tensor) + + def _functions(self): + return [self._map_func] + + @property + def _element_structure(self): + return self._structure def _transformation_name(self): return "Dataset.interleave()" @@ -2878,7 +2921,6 @@ class FilterDataset(UnaryUnchangedStructureDataset): def __init__(self, input_dataset, 
predicate): """See `Dataset.filter()` for details.""" - super(FilterDataset, self).__init__(input_dataset) self._input_dataset = input_dataset wrapped_func = StructuredFunctionWrapper( predicate, self._transformation_name(), dataset=input_dataset) @@ -2886,16 +2928,15 @@ class FilterDataset(UnaryUnchangedStructureDataset): structure_lib.TensorStructure(dtypes.bool, [])): raise ValueError("`predicate` must return a scalar boolean tensor.") self._predicate = wrapped_func - - def _functions(self): - return [self._predicate] - - def _as_variant_tensor(self): - return gen_dataset_ops.filter_dataset( - self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access + variant_tensor = gen_dataset_ops.filter_dataset( + input_dataset._variant_tensor, # pylint: disable=protected-access other_arguments=self._predicate.function.captured_inputs, predicate=self._predicate.function, **flat_structure(self)) + super(FilterDataset, self).__init__(input_dataset, variant_tensor) + + def _functions(self): + return [self._predicate] def _transformation_name(self): return "Dataset.filter()" @@ -2906,18 +2947,16 @@ class PrefetchDataset(UnaryUnchangedStructureDataset): def __init__(self, input_dataset, buffer_size): """See `Dataset.prefetch()` for details.""" - super(PrefetchDataset, self).__init__(input_dataset) self._input_dataset = input_dataset if buffer_size is None: buffer_size = -1 # This is the sentinel for auto-tuning. self._buffer_size = ops.convert_to_tensor( buffer_size, dtype=dtypes.int64, name="buffer_size") - - def _as_variant_tensor(self): - return gen_dataset_ops.prefetch_dataset( - self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access + variant_tensor = gen_dataset_ops.prefetch_dataset( + input_dataset._variant_tensor, # pylint: disable=protected-access buffer_size=self._buffer_size, **flat_structure(self)) + super(PrefetchDataset, self).__init__(input_dataset, variant_tensor) class WindowDataset(UnaryDataset): @@ -2925,7 +2964,6 @@ class WindowDataset(UnaryDataset): def __init__(self, input_dataset, size, shift, stride, drop_remainder): """See `window_dataset()` for more details.""" - super(WindowDataset, self).__init__(input_dataset) self._input_dataset = input_dataset self._size = ops.convert_to_tensor(size, dtype=dtypes.int64, name="size") self._shift = ops.convert_to_tensor(shift, dtype=dtypes.int64, name="shift") @@ -2944,15 +2982,14 @@ class WindowDataset(UnaryDataset): nest.flatten(input_dataset.output_types)) ]) self._structure = structure_lib.NestedStructure(nest_of_structures) - - def _as_variant_tensor(self): - return gen_dataset_ops.window_dataset( - self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access + variant_tensor = gen_dataset_ops.window_dataset( + input_dataset._variant_tensor, # pylint: disable=protected-access self._size, self._shift, self._stride, self._drop_remainder, **flat_structure(self)) + super(WindowDataset, self).__init__(input_dataset, variant_tensor) @property def _element_structure(self): @@ -2963,16 +3000,14 @@ class _OptionsDataset(UnaryUnchangedStructureDataset): """An identity `Dataset` that stores options.""" def __init__(self, input_dataset, options): - super(_OptionsDataset, self).__init__(input_dataset) self._input_dataset = input_dataset self._options = input_dataset.options() if self._options: self._options = self._options.merge(options) else: self._options = options - - def _as_variant_tensor(self): - return self._input_dataset._as_variant_tensor() # pylint: disable=protected-access + 
variant_tensor = input_dataset._variant_tensor # pylint: disable=protected-access + super(_OptionsDataset, self).__init__(input_dataset, variant_tensor) def options(self): return self._options @@ -2983,13 +3018,11 @@ class _ModelDataset(UnaryUnchangedStructureDataset): def __init__(self, input_dataset): """See `optimize()` for details.""" - super(_ModelDataset, self).__init__(input_dataset) self._input_dataset = input_dataset - - def _as_variant_tensor(self): - return gen_dataset_ops.model_dataset( - self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access + variant_tensor = gen_dataset_ops.model_dataset( + input_dataset._variant_tensor, # pylint: disable=protected-access **flat_structure(self)) + super(_ModelDataset, self).__init__(input_dataset, variant_tensor) class _OptimizeDataset(UnaryUnchangedStructureDataset): @@ -2997,68 +3030,63 @@ class _OptimizeDataset(UnaryUnchangedStructureDataset): def __init__(self, input_dataset, optimizations): """See `optimize()` for details.""" - super(_OptimizeDataset, self).__init__(input_dataset) self._input_dataset = input_dataset if optimizations is None: optimizations = [] self._optimizations = ops.convert_to_tensor( optimizations, dtype=dtypes.string, name="optimizations") - - def _as_variant_tensor(self): - return gen_dataset_ops.optimize_dataset( - self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access + variant_tensor = gen_dataset_ops.optimize_dataset( + input_dataset._variant_tensor, # pylint: disable=protected-access self._optimizations, **flat_structure(self)) + super(_OptimizeDataset, self).__init__(input_dataset, variant_tensor) class _SetStatsAggregatorDataset(UnaryUnchangedStructureDataset): """A `Dataset` that acts as an identity, and sets a stats aggregator.""" def __init__(self, input_dataset, aggregator, prefix, counter_prefix): - super(_SetStatsAggregatorDataset, self).__init__(input_dataset) self._input_dataset = input_dataset self._stats_aggregator = aggregator self._prefix = prefix self._counter_prefix = counter_prefix - - def _as_variant_tensor(self): - return ged_ops.experimental_set_stats_aggregator_dataset( - self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access + variant_tensor = ged_ops.experimental_set_stats_aggregator_dataset( + input_dataset._variant_tensor, # pylint: disable=protected-access self._stats_aggregator._resource, # pylint: disable=protected-access self._prefix, self._counter_prefix, **flat_structure(self)) + super(_SetStatsAggregatorDataset, self).__init__(input_dataset, + variant_tensor) class _MaxIntraOpParallelismDataset(UnaryUnchangedStructureDataset): """A `Dataset` that acts as an identity, overriding intra-op parallelism.""" def __init__(self, input_dataset, max_intra_op_parallelism): - super(_MaxIntraOpParallelismDataset, self).__init__(input_dataset) self._input_dataset = input_dataset self._max_intra_op_parallelism = ops.convert_to_tensor( max_intra_op_parallelism, dtype=dtypes.int64, name="max_intra_op_parallelism") - - def _as_variant_tensor(self): - return ged_ops.experimental_max_intra_op_parallelism_dataset( - self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access + variant_tensor = ged_ops.experimental_max_intra_op_parallelism_dataset( + input_dataset._variant_tensor, # pylint: disable=protected-access self._max_intra_op_parallelism, **flat_structure(self)) + super(_MaxIntraOpParallelismDataset, self).__init__(input_dataset, + variant_tensor) class 
_PrivateThreadPoolDataset(UnaryUnchangedStructureDataset): """A `Dataset` that acts as an identity, setting a private threadpool.""" def __init__(self, input_dataset, num_threads): - super(_PrivateThreadPoolDataset, self).__init__(input_dataset) self._input_dataset = input_dataset self._num_threads = ops.convert_to_tensor( num_threads, dtype=dtypes.int64, name="num_threads") - - def _as_variant_tensor(self): - return ged_ops.experimental_private_thread_pool_dataset( - self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access + variant_tensor = ged_ops.experimental_private_thread_pool_dataset( + input_dataset._variant_tensor, # pylint: disable=protected-access self._num_threads, **flat_structure(self)) + super(_PrivateThreadPoolDataset, self).__init__(input_dataset, + variant_tensor) diff --git a/tensorflow/python/data/ops/iterator_ops.py b/tensorflow/python/data/ops/iterator_ops.py index d0e91b01f91..bfa256f8d77 100644 --- a/tensorflow/python/data/ops/iterator_ops.py +++ b/tensorflow/python/data/ops/iterator_ops.py @@ -357,7 +357,7 @@ class Iterator(checkpointable.CheckpointableBase): (self.output_shapes, dataset.output_shapes)) with ops.colocate_with(self._iterator_resource): return gen_dataset_ops.make_iterator( - dataset._as_variant_tensor(), self._iterator_resource, name=name) # pylint: disable=protected-access + dataset._variant_tensor, self._iterator_resource, name=name) # pylint: disable=protected-access def get_next(self, name=None): """Returns a nested structure of `tf.Tensor`s representing the next element. @@ -524,7 +524,7 @@ class EagerIterator(checkpointable.CheckpointableBase): with ops.device("/cpu:0"): # pylint: disable=protected-access dataset = dataset._apply_options() - ds_variant = dataset._as_variant_tensor() + ds_variant = dataset._variant_tensor self._structure = structure_lib.convert_legacy_structure( dataset.output_types, dataset.output_shapes, dataset.output_classes) self._flat_output_types = self._structure._flat_types diff --git a/tensorflow/python/data/ops/multi_device_iterator_ops.py b/tensorflow/python/data/ops/multi_device_iterator_ops.py index 876b77b8537..8192d538917 100644 --- a/tensorflow/python/data/ops/multi_device_iterator_ops.py +++ b/tensorflow/python/data/ops/multi_device_iterator_ops.py @@ -30,12 +30,11 @@ from tensorflow.python.ops import functional_ops from tensorflow.python.ops import gen_dataset_ops -class _PerDeviceGenerator(dataset_ops.Dataset): +class _PerDeviceGenerator(dataset_ops.DatasetV2): """A `dummy` generator dataset.""" def __init__(self, shard_num, multi_device_iterator_resource, incarnation_id, source_device, target_device, element_structure): - super(_PerDeviceGenerator, self).__init__() self._target_device = target_device self._structure = element_structure @@ -108,9 +107,8 @@ class _PerDeviceGenerator(dataset_ops.Dataset): ) self._finalize_captured_args = self._finalize_func.captured_inputs - def _as_variant_tensor(self): with ops.device(self._target_device): - return gen_dataset_ops.generator_dataset( + variant_tensor = gen_dataset_ops.generator_dataset( self._init_captured_args, self._next_captured_args, self._finalize_captured_args, @@ -118,6 +116,7 @@ class _PerDeviceGenerator(dataset_ops.Dataset): next_func=self._next_func, finalize_func=self._finalize_func, **dataset_ops.flat_structure(self)) + super(_PerDeviceGenerator, self).__init__(variant_tensor) def _inputs(self): # TODO(b/116506223): Determine which datasets should be used as inputs here. 
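Editor's note on the _PerDeviceGenerator hunk above: generator-style datasets follow the same reordering, preparing the init/next/finalize functions and their captured arguments first, creating the generator op inside the target-device scope, and only then calling super(). A rough stand-in sketch of that flow; the _device context manager and all other names are made up for illustration and are not TensorFlow APIs.

# Rough stand-in for a generator-style dataset whose op must be created
# inside a device scope during __init__.
import contextlib


@contextlib.contextmanager
def _device(name):
  # Pretend device scope; the real code uses an ops.device(...) context.
  yield name


class _DatasetBase(object):
  def __init__(self, variant_tensor):
    self._variant_tensor = variant_tensor


class _PerDeviceGeneratorLike(_DatasetBase):
  def __init__(self, target_device, init_fn, next_fn, finalize_fn):
    self._target_device = target_device
    # The three callables (and anything they capture) are wired up before
    # the op exists, mirroring _init_func / _next_func / _finalize_func.
    self._init_fn, self._next_fn, self._finalize_fn = init_fn, next_fn, finalize_fn
    with _device(self._target_device) as dev:
      variant_tensor = ("generator_dataset", dev, init_fn.__name__,
                        next_fn.__name__, finalize_fn.__name__)
    super(_PerDeviceGeneratorLike, self).__init__(variant_tensor)


def init_fn(): return "iterator-handle"
def next_fn(handle): return handle
def finalize_fn(handle): return None

gen = _PerDeviceGeneratorLike("/device:GPU:0", init_fn, next_fn, finalize_fn)
print(gen._variant_tensor)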
@@ -177,7 +176,7 @@ class MultiDeviceIterator(object): # The incarnation ID is used to ensure consistency between the per-device # iterators and the multi-device iterator. self._incarnation_id = gen_dataset_ops.multi_device_iterator_init( - self._dataset._as_variant_tensor(), # pylint: disable=protected-access + self._dataset._variant_tensor, # pylint: disable=protected-access self._multi_device_iterator_resource, max_buffer_size=max_buffer_size) @@ -200,7 +199,8 @@ class MultiDeviceIterator(object): options.experimental_optimization.apply_default_optimizations = False ds = ds.with_options(options) with ops.device(device): - self._device_iterators.append(ds.make_initializable_iterator()) + self._device_iterators.append( + dataset_ops.make_initializable_iterator(ds)) device_iterator_initializers = [ iterator.initializer for iterator in self._device_iterators diff --git a/tensorflow/python/data/ops/readers.py b/tensorflow/python/data/ops/readers.py index 0d6023dea28..5e61bcf6be0 100644 --- a/tensorflow/python/data/ops/readers.py +++ b/tensorflow/python/data/ops/readers.py @@ -49,7 +49,6 @@ class TextLineDatasetV2(dataset_ops.DatasetSource): to buffer. A value of 0 results in the default buffering values chosen based on the compression type. """ - super(TextLineDatasetV2, self).__init__() self._filenames = ops.convert_to_tensor( filenames, dtype=dtypes.string, name="filenames") self._compression_type = convert.optional_param_to_tensor( @@ -59,10 +58,9 @@ class TextLineDatasetV2(dataset_ops.DatasetSource): argument_dtype=dtypes.string) self._buffer_size = convert.optional_param_to_tensor( "buffer_size", buffer_size, _DEFAULT_READER_BUFFER_SIZE_BYTES) - - def _as_variant_tensor(self): - return gen_dataset_ops.text_line_dataset( + variant_tensor = gen_dataset_ops.text_line_dataset( self._filenames, self._compression_type, self._buffer_size) + super(TextLineDatasetV2, self).__init__(variant_tensor) @property def _element_structure(self): @@ -100,7 +98,6 @@ class _TFRecordDataset(dataset_ops.DatasetSource): buffer_size: (Optional.) A `tf.int64` scalar representing the number of bytes in the read buffer. 0 means no buffering. """ - super(_TFRecordDataset, self).__init__() # Force the type to string even if filenames is an empty list. 
self._filenames = ops.convert_to_tensor( filenames, dtypes.string, name="filenames") @@ -113,24 +110,32 @@ class _TFRecordDataset(dataset_ops.DatasetSource): "buffer_size", buffer_size, argument_default=_DEFAULT_READER_BUFFER_SIZE_BYTES) - - def _as_variant_tensor(self): - return gen_dataset_ops.tf_record_dataset( + variant_tensor = gen_dataset_ops.tf_record_dataset( self._filenames, self._compression_type, self._buffer_size) + super(_TFRecordDataset, self).__init__(variant_tensor) @property def _element_structure(self): return structure.TensorStructure(dtypes.string, []) -class ParallelInterleaveDataset(dataset_ops.InterleaveDataset): +class ParallelInterleaveDataset(dataset_ops.UnaryDataset): """A `Dataset` that maps a function over its input and flattens the result.""" def __init__(self, input_dataset, map_func, cycle_length, block_length, sloppy, buffer_output_elements, prefetch_input_elements): """See `tf.data.experimental.parallel_interleave()` for details.""" - super(ParallelInterleaveDataset, self).__init__(input_dataset, map_func, - cycle_length, block_length) + self._input_dataset = input_dataset + self._map_func = dataset_ops.StructuredFunctionWrapper( + map_func, self._transformation_name(), dataset=input_dataset) + if not isinstance(self._map_func.output_structure, + dataset_ops.DatasetStructure): + raise TypeError("`map_func` must return a `Dataset` object.") + self._structure = self._map_func.output_structure._element_structure # pylint: disable=protected-access + self._cycle_length = ops.convert_to_tensor( + cycle_length, dtype=dtypes.int64, name="cycle_length") + self._block_length = ops.convert_to_tensor( + block_length, dtype=dtypes.int64, name="block_length") self._sloppy = ops.convert_to_tensor( sloppy, dtype=dtypes.bool, name="sloppy") self._buffer_output_elements = convert.optional_param_to_tensor( @@ -141,11 +146,8 @@ class ParallelInterleaveDataset(dataset_ops.InterleaveDataset): "prefetch_input_elements", prefetch_input_elements, argument_default=2 * cycle_length) - - def _as_variant_tensor(self): - # pylint: disable=protected-access - return ged_ops.experimental_parallel_interleave_dataset( - self._input_dataset._as_variant_tensor(), + variant_tensor = ged_ops.experimental_parallel_interleave_dataset( + self._input_dataset._variant_tensor, # pylint: disable=protected-access self._map_func.function.captured_inputs, self._cycle_length, self._block_length, @@ -154,7 +156,15 @@ class ParallelInterleaveDataset(dataset_ops.InterleaveDataset): self._prefetch_input_elements, f=self._map_func.function, **dataset_ops.flat_structure(self)) - # pylint: enable=protected-access + super(ParallelInterleaveDataset, self).__init__(input_dataset, + variant_tensor) + + def _functions(self): + return [self._map_func] + + @property + def _element_structure(self): + return self._structure def _transformation_name(self): return "tf.data.experimental.parallel_interleave()" @@ -186,7 +196,6 @@ class TFRecordDatasetV2(dataset_ops.DatasetV2): TypeError: If any argument does not have the expected type. ValueError: If any argument does not have the expected shape. 
""" - super(TFRecordDatasetV2, self).__init__() if isinstance(filenames, dataset_ops.DatasetV2): if filenames.output_types != dtypes.string: raise TypeError( @@ -215,6 +224,8 @@ class TFRecordDatasetV2(dataset_ops.DatasetV2): filenames, read_one_file, cycle_length=num_parallel_reads, block_length=1, sloppy=False, buffer_output_elements=None, prefetch_input_elements=None) + variant_tensor = self._impl._variant_tensor # pylint: disable=protected-access + super(TFRecordDatasetV2, self).__init__(variant_tensor) def _clone(self, filenames=None, @@ -226,9 +237,6 @@ class TFRecordDatasetV2(dataset_ops.DatasetV2): buffer_size or self._buffer_size, num_parallel_reads or self._num_parallel_reads) - def _as_variant_tensor(self): - return self._impl._as_variant_tensor() # pylint: disable=protected-access - def _inputs(self): return self._impl._inputs() # pylint: disable=protected-access @@ -295,7 +303,6 @@ class FixedLengthRecordDatasetV2(dataset_ops.DatasetSource): compression_type: (Optional.) A `tf.string` scalar evaluating to one of `""` (no compression), `"ZLIB"`, or `"GZIP"`. """ - super(FixedLengthRecordDatasetV2, self).__init__() self._filenames = ops.convert_to_tensor( filenames, dtype=dtypes.string, name="filenames") self._record_bytes = ops.convert_to_tensor( @@ -312,17 +319,16 @@ class FixedLengthRecordDatasetV2(dataset_ops.DatasetSource): compression_type, argument_default="", argument_dtype=dtypes.string) - - def _as_variant_tensor(self): if (self._compression_type is not None or compat.forward_compatible(2018, 11, 30)): - return gen_dataset_ops.fixed_length_record_dataset_v2( + variant_tensor = gen_dataset_ops.fixed_length_record_dataset_v2( self._filenames, self._header_bytes, self._record_bytes, self._footer_bytes, self._buffer_size, self._compression_type) else: - return gen_dataset_ops.fixed_length_record_dataset( + variant_tensor = gen_dataset_ops.fixed_length_record_dataset( self._filenames, self._header_bytes, self._record_bytes, self._footer_bytes, self._buffer_size) + super(FixedLengthRecordDatasetV2, self).__init__(variant_tensor) @property def _element_structure(self): diff --git a/tensorflow/python/data/util/BUILD b/tensorflow/python/data/util/BUILD index 04e80299e0d..c98b1f17293 100644 --- a/tensorflow/python/data/util/BUILD +++ b/tensorflow/python/data/util/BUILD @@ -163,3 +163,24 @@ py_test( "//tensorflow/python:util", ], ) + +py_library( + name = "traverse", + srcs = ["traverse.py"], + srcs_version = "PY2AND3", + deps = [ + ], +) + +py_test( + name = "traverse_test", + size = "small", + srcs = ["traverse_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":traverse", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python/data/ops:dataset_ops", + ], +) diff --git a/tensorflow/python/data/util/traverse.py b/tensorflow/python/data/util/traverse.py new file mode 100644 index 00000000000..12e576fb414 --- /dev/null +++ b/tensorflow/python/data/util/traverse.py @@ -0,0 +1,56 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Helpers to traverse the Dataset dependency structure.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from six.moves import queue as Queue # pylint: disable=redefined-builtin + +from tensorflow.python.framework import dtypes + + +def obtain_all_variant_tensor_ops(dataset): + """Given an input dataset, finds all dataset ops used for construction. + + A series of transformations would have created this dataset with each + transformation including zero or more Dataset ops, each producing a dataset + variant tensor. This method outputs all of them. + + Args: + dataset: Dataset to find variant tensors for. + + Returns: + A list of variant_tensor producing dataset ops used to construct this + dataset. + """ + all_variant_tensor_ops = [] + bfs_q = Queue.Queue() + bfs_q.put(dataset._variant_tensor.op) # pylint: disable=protected-access + visited = [] + while not bfs_q.empty(): + op = bfs_q.get() + visited.append(op) + # We look for all ops that produce variant tensors as output. This is a bit + # of overkill but the other dataset _inputs() traversal strategies can't + # cover the case of function inputs that capture dataset variants. + # TODO(b/120873778): Make this more efficient. + if op.outputs[0].dtype == dtypes.variant: + all_variant_tensor_ops.append(op) + for i in op.inputs: + input_op = i.op + if input_op not in visited: + bfs_q.put(input_op) + return all_variant_tensor_ops diff --git a/tensorflow/python/data/util/traverse_test.py b/tensorflow/python/data/util/traverse_test.py new file mode 100644 index 00000000000..53de1be897a --- /dev/null +++ b/tensorflow/python/data/util/traverse_test.py @@ -0,0 +1,109 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for utilities for traversing the dataset construction graph.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.util import traverse +from tensorflow.python.framework import test_util +from tensorflow.python.ops import gen_dataset_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test + + +class _TestDataset(dataset_ops.UnaryUnchangedStructureDataset): + + def __init__(self, input_dataset): + self._input_dataset = input_dataset + temp_variant_tensor = gen_dataset_ops.prefetch_dataset( + input_dataset._variant_tensor, + buffer_size=1, + **dataset_ops.flat_structure(self)) + variant_tensor = gen_dataset_ops.model_dataset( + temp_variant_tensor, **dataset_ops.flat_structure(self)) + super(_TestDataset, self).__init__(input_dataset, variant_tensor) + + +class TraverseTest(test.TestCase): + + @test_util.run_deprecated_v1 + def testOnlySource(self): + ds = dataset_ops.Dataset.range(10) + variant_tensor_ops = traverse.obtain_all_variant_tensor_ops(ds) + self.assertAllEqual(["RangeDataset"], [x.name for x in variant_tensor_ops]) + + @test_util.run_deprecated_v1 + def testSimplePipeline(self): + ds = dataset_ops.Dataset.range(10).map(math_ops.square) + variant_tensor_ops = traverse.obtain_all_variant_tensor_ops(ds) + self.assertSetEqual( + set(["MapDataset", "RangeDataset"]), + set([x.name for x in variant_tensor_ops])) + + @test_util.run_deprecated_v1 + def testConcat(self): + ds1 = dataset_ops.Dataset.range(10) + ds2 = dataset_ops.Dataset.range(10) + ds = ds1.concatenate(ds2) + variant_tensor_ops = traverse.obtain_all_variant_tensor_ops(ds) + self.assertSetEqual( + set(["ConcatenateDataset", "RangeDataset", "RangeDataset_1"]), + set([x.name for x in variant_tensor_ops])) + + @test_util.run_deprecated_v1 + def testZip(self): + ds1 = dataset_ops.Dataset.range(10) + ds2 = dataset_ops.Dataset.range(10) + ds = dataset_ops.Dataset.zip((ds1, ds2)) + variant_tensor_ops = traverse.obtain_all_variant_tensor_ops(ds) + self.assertSetEqual( + set(["ZipDataset", "RangeDataset", "RangeDataset_1"]), + set([x.name for x in variant_tensor_ops])) + + @test_util.run_deprecated_v1 + def testMultipleVariantTensors(self): + ds = dataset_ops.Dataset.range(10) + ds = _TestDataset(ds) + variant_tensor_ops = traverse.obtain_all_variant_tensor_ops(ds) + self.assertSetEqual( + set(["RangeDataset", "ModelDataset", "PrefetchDataset"]), + set([x.name for x in variant_tensor_ops])) + + @test_util.run_deprecated_v1 + def testFlatMap(self): + ds1 = dataset_ops.Dataset.range(10).repeat(10) + + def map_fn(ds): + + def _map(x): + return ds.batch(x) + + return _map + + ds2 = dataset_ops.Dataset.range(20).prefetch(1) + ds2 = ds2.flat_map(map_fn(ds1)) + variant_tensor_ops = traverse.obtain_all_variant_tensor_ops(ds2) + self.assertSetEqual( + set([ + "FlatMapDataset", "PrefetchDataset", "RepeatDataset", + "RangeDataset", "RangeDataset_1" + ]), set([x.name for x in variant_tensor_ops])) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD index 887c61cb8fd..02957b2fefb 100644 --- a/tensorflow/python/distribute/BUILD +++ b/tensorflow/python/distribute/BUILD @@ -270,6 +270,7 @@ cuda_py_test( ":input_ops", "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/data/ops:readers", + 
"//tensorflow/python/data/util:traverse", "//tensorflow/python:errors", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_ops", diff --git a/tensorflow/python/distribute/input_ops.py b/tensorflow/python/distribute/input_ops.py index 2ded209701e..d9e833b6bc6 100644 --- a/tensorflow/python/distribute/input_ops.py +++ b/tensorflow/python/distribute/input_ops.py @@ -18,15 +18,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.data.experimental.ops import filter_for_shard_ops from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.data.ops import readers -from tensorflow.python.data.util import nest +from tensorflow.python.data.util import traverse +from tensorflow.python.framework import op_def_registry from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import math_ops from tensorflow.python.platform import tf_logging + # TODO(priyag): Any other reader datasets to consider here? _READER_DATASET_OPS = [ "TextLineDataset", "TFRecordDataset", "FixedLengthRecordDataset", @@ -53,100 +51,57 @@ def auto_shard_dataset(dataset, num_shards, index): determine a good way to shard the input dataset. """ - # TODO(priyag): Clone datasets instead of updating in place, similar to the - # clone method for TFRecordDataset. - def _auto_shard_impl(dataset, found_reader_op): - """Recursive implementation of auto sharding.""" + # TODO(rohanj): b/120673685 to track re-enabling auto sharding. + tf_logging.warn("Autosharding is currently disabled. Please shard your input " + "manually.") + del num_shards, index + return dataset - if not found_reader_op: - # TODO(priyag): Make this check more robust by enforcing some common - # property on reader datasets. - if (isinstance(dataset, readers.TextLineDataset) or - isinstance(dataset, readers.FixedLengthRecordDataset)): - filenames_tensor = dataset._filenames - num_files = array_ops.size(filenames_tensor) - sharded_filenames_tensor = array_ops.gather( - filenames_tensor, math_ops.range(index, num_files, num_shards)) - dataset._filenames = sharded_filenames_tensor - return dataset - elif isinstance(dataset, readers.TFRecordDataset): - # `TFRecordDataset` needs to be handled separately than other readers - # because it converts filenames to a dataset first. Also, we clone it - # instead of updating in place because it has special logic in the - # constructor. Eventually we will change all cases to clone datasets - # instead of updating in-place. - return dataset._clone( - filenames=dataset._filenames.apply( - filter_for_shard_ops.filter_for_shard(num_shards, index))) - elif isinstance(dataset, dataset_ops.RangeDataset): - return dataset.apply( - filter_for_shard_ops.filter_for_shard(num_shards, index)) - elif hasattr(dataset, "_map_func"): - # TODO(priyag): Make this check more robust by enforcing some common - # property on all map/flatmap/interleave datasets. - map_func_def = dataset._map_func.function.definition - for node in map_func_def.node_def: - if node.op in _READER_DATASET_OPS: - found_reader_op = True - break - elif node.op == "FlatMapDataset": - # TODO(priyag): Should this check for other map datasets? Should it - # be recursive? It is too specific to implementation of - # TFRecordDataset right now. 
- nested_func_name = node.attr["f"].func.name - nested_func = ops.get_default_graph()._functions[nested_func_name] - for nested_node in nested_func.definition.node_def: - if nested_node.op in _READER_DATASET_OPS: - found_reader_op = True - break - if found_reader_op: - break - if found_reader_op: - dataset._input_dataset = _auto_shard_impl( - dataset._input_dataset, found_reader_op) - return dataset - if isinstance(dataset, dataset_ops.DatasetV1Adapter): - dataset._dataset = _auto_shard_impl( - dataset._dataset, found_reader_op) - return dataset +def _clone_dataset(dataset): + """Returns a cloned version of `dataset`.""" + variant_tensor_ops = traverse.obtain_all_variant_tensor_ops(dataset) + remap_dict = _clone_helper(dataset._variant_tensor.op, variant_tensor_ops) + new_variant_tensor = remap_dict[dataset._variant_tensor.op].outputs[0] + return dataset_ops._VariantDataset(new_variant_tensor, + dataset._element_structure) - # TODO(priyag): Make _input_dataset(s) a common property of all datasets to - # make this check more robust. - if hasattr(dataset, "_input_dataset"): - dataset._input_dataset = _auto_shard_impl( - dataset._input_dataset, found_reader_op) - if hasattr(dataset, "_dataset_to_concatenate"): - # Special case for `ConcatentateDataset`. We want to shard all input - # datasets. - dataset._dataset_to_concatenate = _auto_shard_impl( - dataset._dataset_to_concatenate, found_reader_op) - return dataset - if hasattr(dataset, "_datasets"): - # Special case for `ZipDataset`. - dataset._datasets = nest.pack_sequence_as(dataset._datasets, [ - _auto_shard_impl(ds, found_reader_op) - for ds in nest.flatten(dataset._datasets) - ]) - return dataset +def _get_op_def(op): + return op.op_def or op_def_registry.get_registered_ops()[op.type] - if not found_reader_op: - tf_logging.warn( - "Could not find a standard reader in the input pipeline" - "(one of TextLineDataset, TFRecordDataset, FixedLengthRecordDataset)." - "So auto-sharding is not done. Please verify correctness of " - "auto-sharding for your input.") - # TODO(yuefengz): maybe still shard it? - return dataset - # TODO(priyag): What do we want to do if the number of filenames is - # uneven in the number of shards? By default, this will just return as - # many items it can before throwing OutOfRangeError. - # TODO(priyag): This will shard the filenames before any shuffling of the - # filename dataset. It might be desirable to shard after shuffling - # filenames? If so, how do we achieve that? - return dataset.apply( - filter_for_shard_ops.filter_for_shard(num_shards, index)) +def _clone_helper(op_to_clone, variant_tensor_ops): + """Helper method that recursively clones `op_to_clone`. - return _auto_shard_impl(dataset=dataset, found_reader_op=False) + Args: + op_to_clone: The op we want to clone. + variant_tensor_ops: A list of ops that we have to clone along the way. + + Returns: + A dictionary mapping old_ops to new_ops created. Includes op_to_clone + as a key. 
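+
+  Cloning happens in the current default graph: inputs that are themselves
+  variant-producing dataset ops (i.e. that appear in `variant_tensor_ops`)
+  are cloned first and the new op is wired to their cloned outputs, while
+  every other input is reused as-is.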
+ """ + remap_dict = {} + for input_tensor in op_to_clone.inputs: + input_tensor_op = input_tensor.op + if input_tensor_op in variant_tensor_ops: + recursive_map = _clone_helper(input_tensor_op, variant_tensor_ops) + remap_dict.update(recursive_map) + inputs_list = [] + for input_tensor in op_to_clone.inputs: + input_tensor_op = input_tensor.op + if input_tensor_op in remap_dict: + remapped_input = remap_dict[input_tensor_op].outputs[0] + inputs_list.append(remapped_input) + else: + inputs_list.append(input_tensor_op.outputs[input_tensor.value_index]) + g = ops.get_default_graph() + new_op = g.create_op( + op_to_clone.type, + inputs_list, [o.dtype for o in op_to_clone.outputs], + name=op_to_clone.name, + attrs=op_to_clone.node_def.attr, + op_def=_get_op_def(op_to_clone)) + remap_dict[op_to_clone] = new_op + return remap_dict diff --git a/tensorflow/python/distribute/input_ops_test.py b/tensorflow/python/distribute/input_ops_test.py index dcf946ba477..7db75163ed3 100644 --- a/tensorflow/python/distribute/input_ops_test.py +++ b/tensorflow/python/distribute/input_ops_test.py @@ -26,6 +26,8 @@ from tensorflow.python.distribute import input_ops from tensorflow.python.framework import errors from tensorflow.python.framework import test_util from tensorflow.python.lib.io import python_io +from tensorflow.python.ops import gen_dataset_ops +from tensorflow.python.ops import math_ops from tensorflow.python.platform import test from tensorflow.python.util import compat @@ -90,7 +92,7 @@ class AutoShardDatasetTest(test.TestCase): def _verifySimpleShardingOutput(self, dataset, record_fn): iterator = dataset.make_one_shot_iterator() next_element = iterator.get_next() - with self.cached_session() as sess: + with self.cached_session(): for f in range(self._shard_index, self._num_files, self._num_shards): for r in range(self._num_records): self.assertAllEqual(record_fn(r, f), self.evaluate(next_element)) @@ -98,7 +100,7 @@ class AutoShardDatasetTest(test.TestCase): self.evaluate(next_element) @test_util.run_deprecated_v1 - def testTFRecordDataset(self): + def DISABLED_testTFRecordDataset(self): dataset = readers.TFRecordDataset(self._createTFRecordFiles()) dataset = input_ops.auto_shard_dataset( dataset, self._num_shards, self._shard_index) @@ -106,7 +108,7 @@ class AutoShardDatasetTest(test.TestCase): self._verifySimpleShardingOutput(dataset, self._record) @test_util.run_deprecated_v1 - def testFlatMap(self): + def DISABLED_testFlatMap(self): dataset = dataset_ops.Dataset.from_tensor_slices( self._createTFRecordFiles()) dataset = dataset.flat_map(readers.TFRecordDataset) @@ -116,7 +118,7 @@ class AutoShardDatasetTest(test.TestCase): self._verifySimpleShardingOutput(dataset, self._record) @test_util.run_deprecated_v1 - def testInterleave(self): + def DISABLED_testInterleave(self): dataset = dataset_ops.Dataset.from_tensor_slices( self._createTFRecordFiles()) dataset = dataset.interleave( @@ -129,7 +131,7 @@ class AutoShardDatasetTest(test.TestCase): self._verifySimpleShardingOutput(dataset, self._record) @test_util.run_deprecated_v1 - def testListfiles(self): + def DISABLED_testListfiles(self): filenames = self._createTFRecordFiles() file_pattern = filenames[0].rsplit(os.sep, 1)[0] + "/tf_record.*.txt" dataset = dataset_ops.Dataset.list_files(file_pattern, shuffle=False) @@ -139,7 +141,7 @@ class AutoShardDatasetTest(test.TestCase): iterator = dataset.make_one_shot_iterator() next_element = iterator.get_next() - with self.cached_session() as sess: + with self.cached_session(): actual, expected = [], [] 
for f in range(self._shard_index, self._num_files, self._num_shards): for r in range(self._num_records): @@ -150,7 +152,7 @@ class AutoShardDatasetTest(test.TestCase): self.assertAllEqual(expected, actual) @test_util.run_deprecated_v1 - def testComplexPipeline(self): + def DISABLED_testComplexPipeline(self): # Setup a complex input pipeline. batch_size = 2 num_epochs = 5 @@ -172,7 +174,7 @@ class AutoShardDatasetTest(test.TestCase): # Verify output. iterator = dataset.make_one_shot_iterator() next_element = iterator.get_next() - with self.cached_session() as sess: + with self.cached_session(): actual = [] num_iterations = (self._num_files * self._num_records * num_epochs) // ( self._num_shards * batch_size) @@ -190,7 +192,7 @@ class AutoShardDatasetTest(test.TestCase): self.assertAllEqual(sorted(expected), sorted(actual)) @test_util.run_deprecated_v1 - def testZip(self): + def DISABLED_testZip(self): dataset1 = readers.TFRecordDataset(self._createTFRecordFiles()) dataset2 = readers.TextLineDataset(self._createTextFiles()) dataset = dataset_ops.Dataset.zip((dataset1, dataset2)) @@ -201,7 +203,7 @@ class AutoShardDatasetTest(test.TestCase): self._verifySimpleShardingOutput(dataset, record_fn) @test_util.run_deprecated_v1 - def testConcat(self): + def DISABLED_testConcat(self): dataset1 = readers.TFRecordDataset(self._createTFRecordFiles()) dataset2 = readers.TextLineDataset(self._createTextFiles()) dataset = dataset1.concatenate(dataset2) @@ -222,7 +224,7 @@ class AutoShardDatasetTest(test.TestCase): self.evaluate(next_element) @test_util.run_deprecated_v1 - def testTextLineReader(self): + def DISABLED_testTextLineReader(self): dataset = readers.TextLineDataset(self._createTextFiles()) dataset = input_ops.auto_shard_dataset( dataset, self._num_shards, self._shard_index) @@ -230,7 +232,7 @@ class AutoShardDatasetTest(test.TestCase): self._verifySimpleShardingOutput(dataset, self._text_line) @test_util.run_deprecated_v1 - def testTextLineReaderWithFlatMap(self): + def DISABLED_testTextLineReaderWithFlatMap(self): dataset = dataset_ops.Dataset.from_tensor_slices(self._createTextFiles()) dataset = dataset.flat_map(readers.TextLineDataset) dataset = input_ops.auto_shard_dataset( @@ -239,7 +241,7 @@ class AutoShardDatasetTest(test.TestCase): self._verifySimpleShardingOutput(dataset, self._text_line) @test_util.run_deprecated_v1 - def testFixedLengthReader(self): + def DISABLED_testFixedLengthReader(self): dataset = readers.FixedLengthRecordDataset( self._createFixedLengthRecordFiles(), self._record_bytes) dataset = input_ops.auto_shard_dataset( @@ -248,7 +250,7 @@ class AutoShardDatasetTest(test.TestCase): self._verifySimpleShardingOutput(dataset, self._fixed_length_record) @test_util.run_deprecated_v1 - def testFixedLengthReaderWithFlatMap(self): + def DISABLED_testFixedLengthReaderWithFlatMap(self): dataset = dataset_ops.Dataset.from_tensor_slices( self._createFixedLengthRecordFiles()) dataset = dataset.flat_map( @@ -258,5 +260,77 @@ class AutoShardDatasetTest(test.TestCase): self._verifySimpleShardingOutput(dataset, self._fixed_length_record) + +# A dataset that creates two variant tensors. 
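+# It builds a pipeline whose terminal op (ModelDataset) is fed by another
+# variant-producing op (PrefetchDataset), so cloning must walk past the
+# terminal op instead of copying it alone.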
+class _TestDataset(dataset_ops.UnaryUnchangedStructureDataset): + + def __init__(self, input_dataset): + self._input_dataset = input_dataset + temp_variant_tensor = gen_dataset_ops.prefetch_dataset( + input_dataset._variant_tensor, + buffer_size=1, + **dataset_ops.flat_structure(self)) + variant_tensor = gen_dataset_ops.model_dataset( + temp_variant_tensor, **dataset_ops.flat_structure(self)) + super(_TestDataset, self).__init__(input_dataset, variant_tensor) + + +class CloneDatasetTest(test.TestCase): + + def _assert_datasets_equal(self, ds1, ds2): + # First lets assert the structure is the same. + self.assertTrue( + ds1._element_structure.is_compatible_with(ds2._element_structure)) + self.assertTrue( + ds2._element_structure.is_compatible_with(ds1._element_structure)) + + # Now create iterators on both and assert they produce the same values. + it1 = dataset_ops.make_initializable_iterator(ds1) + it2 = dataset_ops.make_initializable_iterator(ds2) + + get_next1 = it1.get_next() + get_next2 = it2.get_next() + + with self.cached_session(): + self.evaluate([it1.initializer, it2.initializer]) + val1, val2 = self.evaluate([get_next1, get_next2]) + self.assertEqual(val1, val2) + + @test_util.run_deprecated_v1 + def testOnlySource(self): + ds = dataset_ops.Dataset.range(10) + cloned_ds = input_ops._clone_dataset(ds) + self._assert_datasets_equal(ds, cloned_ds) + + @test_util.run_deprecated_v1 + def testSimplePipeline(self): + ds = dataset_ops.Dataset.range(10).map(math_ops.square) + cloned_ds = input_ops._clone_dataset(ds) + self._assert_datasets_equal(ds, cloned_ds) + + @test_util.run_deprecated_v1 + def testConcat(self): + ds1 = dataset_ops.Dataset.range(10) + ds2 = dataset_ops.Dataset.range(10) + ds = ds1.concatenate(ds2) + cloned_ds = input_ops._clone_dataset(ds) + self._assert_datasets_equal(ds, cloned_ds) + + @test_util.run_deprecated_v1 + def testZip(self): + ds1 = dataset_ops.Dataset.range(10) + ds2 = dataset_ops.Dataset.range(10) + ds = dataset_ops.Dataset.zip((ds1, ds2)) + cloned_ds = input_ops._clone_dataset(ds) + self._assert_datasets_equal(ds, cloned_ds) + + @test_util.run_deprecated_v1 + def testMultipleVariantTensors(self): + ds = dataset_ops.Dataset.range(10) + ds = _TestDataset(ds) + cloned_ds = input_ops._clone_dataset(ds) + self._assert_datasets_equal(ds, cloned_ds) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/distribute/values.py b/tensorflow/python/distribute/values.py index 55fc9c9e1f0..e0c575b01cf 100644 --- a/tensorflow/python/distribute/values.py +++ b/tensorflow/python/distribute/values.py @@ -1600,9 +1600,9 @@ class MultiWorkerDataset(object): if len(dataset_fn) != input_workers.num_workers: raise ValueError("If `dataset_fn` is a list, it must have one entry " "per worker") - if auto_shard: - raise ValueError( - "If `dataset_fn` is a list, `auto_shard` is not supported.") + # TODO(rohanj): b/120673685 to track re-enabling auto sharding. 
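+      # Input auto-sharding is not applied per worker below anymore; an
+      # explicit auto_shard request is rejected here instead.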
+ if auto_shard: + raise ValueError("Currently autosharding is not supported.") self._input_workers = input_workers self._datasets = [] # TODO(yuefengz, priyag): support different set of jobs for input @@ -1613,9 +1613,6 @@ class MultiWorkerDataset(object): worker_input = dataset_fn[i]() else: worker_input = dataset_fn() - if auto_shard: - worker_input = input_ops.auto_shard_dataset( - worker_input, input_workers.num_workers, i) dataset = PerReplicaDataset(worker_input, input_workers, i, prefetch_on_device=prefetch_on_device) self._datasets.append((worker, dataset)) @@ -1805,7 +1802,11 @@ class DatasetIterator(InputIteratorImpl): for i, worker in enumerate(input_workers.worker_devices): with ops.device(worker): worker_devices = input_workers.compute_devices_for_worker(i) - iterator = _SingleWorkerDatasetIterator(dataset, worker, worker_devices) + cloned_dataset = dataset + if not context.executing_eagerly(): + cloned_dataset = input_ops._clone_dataset(dataset) # pylint: disable=protected-access + iterator = _SingleWorkerDatasetIterator(cloned_dataset, worker, + worker_devices) iterators.append(iterator) super(DatasetIterator, self).__init__(input_workers, iterators) diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt index 3ecac329aa2..951b2df05ac 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt @@ -16,7 +16,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'variant_tensor\'], varargs=None, keywords=None, defaults=None" } member_method { name: "apply"
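Taken together, the new pieces compose as follows: `traverse.obtain_all_variant_tensor_ops()` walks the op graph behind a dataset's `_variant_tensor`, `_clone_dataset()` uses that list to rebuild an identical pipeline on freshly created ops, and `DatasetIterator` clones the dataset once per worker in graph mode. A minimal graph-mode usage sketch, mirroring the tests above rather than introducing any API beyond this patch:

    from tensorflow.python.data.ops import dataset_ops
    from tensorflow.python.data.util import traverse
    from tensorflow.python.distribute import input_ops

    ds = dataset_ops.Dataset.range(10).map(lambda x: x * 2)

    # Every op that produced a DT_VARIANT tensor for this pipeline, e.g.
    # MapDataset and RangeDataset.
    variant_ops = traverse.obtain_all_variant_tensor_ops(ds)

    # A structurally identical copy backed by freshly created ops, so each
    # worker can drive its own iterator without sharing graph state.
    cloned = input_ops._clone_dataset(ds)  # pylint: disable=protected-access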