This CL renames the pre-existing `_element_structure` property of tf.data
datasets and iterators to `element_spec`, thus exposing it in the public API.

PiperOrigin-RevId: 256201202
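For illustration, a minimal sketch of the renamed surface (not part of the
test file below; assumes a TF 2.x-style eager environment and the public
tf.data wrappers):

    import tensorflow as tf

    dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
    # Previously reachable only via the private `_element_structure`:
    print(dataset.element_spec)
    # TensorSpec(shape=(), dtype=tf.int32, name=None)

    # Iterators expose the same property.
    iterator = iter(dataset)
    print(iterator.element_spec)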
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for input pipeline modifications for distribution strategies."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import readers
from tensorflow.python.data.util import structure
from tensorflow.python.distribute import input_ops
from tensorflow.python.eager import context
from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util
from tensorflow.python.lib.io import python_io
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
from tensorflow.python.util import compat


class AutoShardDatasetTest(test.TestCase):

  def setUp(self):
    super(AutoShardDatasetTest, self).setUp()
    self._num_files = 10
    self._num_records = 4
    self._num_shards = 2
    self._shard_index = 0
    self._record_bytes = 10

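  # Returns a callable that yields the next element of `dataset`, hiding the
  # difference between eager iteration and a graph-mode one-shot iterator.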
  def _getNext(self, dataset):
    if context.executing_eagerly():
      iterator = iter(dataset)
      return iterator._next_internal  # pylint: disable=protected-access
    else:
      iterator = dataset_ops.make_one_shot_iterator(dataset)
      get_next = iterator.get_next()
      return lambda: get_next

  def _record(self, r, f):
    return compat.as_bytes("Record %d of file %d" % (r, f))

  def _text_line(self, r, f):
    return compat.as_bytes("Text line %d of file %d" % (r, f))

  def _fixed_length_record(self, r, f):
    return compat.as_bytes(str((r * f) % 10) * self._record_bytes)

  def _createTFRecordFiles(self):
    filenames = []
    for i in range(self._num_files):
      fn = os.path.join(self.get_temp_dir(), "tf_record.%d.txt" % i)
      filenames.append(fn)
      writer = python_io.TFRecordWriter(fn)
      for j in range(self._num_records):
        record = self._record(j, i)
        writer.write(record)
      writer.close()
    return filenames

  def _createTextFiles(self):
    filenames = []
    for i in range(self._num_files):
      fn = os.path.join(self.get_temp_dir(), "text_line.%d.txt" % i)
      filenames.append(fn)
      contents = []
      for j in range(self._num_records):
        contents.append(self._text_line(j, i))
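        # Append a CRLF after every record except the last record of a file;
        # file 0 alone keeps its trailing newline.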
        if j + 1 != self._num_records or i == 0:
          contents.append(b"\r\n")
      contents = b"".join(contents)

      with open(fn, "wb") as f:
        f.write(contents)
    return filenames

  def _createFixedLengthRecordFiles(self):
    filenames = []
    for i in range(self._num_files):
      fn = os.path.join(self.get_temp_dir(), "fixed_length_record.%d.txt" % i)
      filenames.append(fn)
      with open(fn, "wb") as f:
        for j in range(self._num_records):
          f.write(self._fixed_length_record(j, i))
    return filenames

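  # Asserts that `dataset` yields exactly this shard's records, in order, and
  # then raises OutOfRangeError.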
  def _verifySimpleShardingOutput(self, dataset, record_fn):
    next_element_fn = self._getNext(dataset)
    with self.cached_session():
      for f in range(self._shard_index, self._num_files, self._num_shards):
        for r in range(self._num_records):
          self.assertAllEqual(record_fn(r, f), self.evaluate(next_element_fn()))
      with self.assertRaises(errors.OutOfRangeError):
        self.evaluate(next_element_fn())

  @test_util.run_in_graph_and_eager_modes
  def testTFRecordDataset(self):
    dataset = readers.TFRecordDataset(self._createTFRecordFiles())
    dataset = input_ops.auto_shard_dataset(
        dataset, self._num_shards, self._shard_index)

    self._verifySimpleShardingOutput(dataset, self._record)

  @test_util.run_in_graph_and_eager_modes
  def testFlatMap(self):
    dataset = dataset_ops.Dataset.from_tensor_slices(
        self._createTFRecordFiles())
    dataset = dataset.flat_map(readers.TFRecordDataset)
    dataset = input_ops.auto_shard_dataset(
        dataset, self._num_shards, self._shard_index)

    self._verifySimpleShardingOutput(dataset, self._record)

  @test_util.run_in_graph_and_eager_modes
  def testInterleave(self):
    dataset = dataset_ops.Dataset.from_tensor_slices(
        self._createTFRecordFiles())
    dataset = dataset.interleave(
        readers.TFRecordDataset, cycle_length=4, block_length=self._num_records)
    dataset = input_ops.auto_shard_dataset(
        dataset, self._num_shards, self._shard_index)

    # Since block_length == num records in each file, the output will still
    # contain records in order of files.
    self._verifySimpleShardingOutput(dataset, self._record)

  @test_util.run_in_graph_and_eager_modes
  def testListfiles(self):
    filenames = self._createTFRecordFiles()
    file_pattern = filenames[0].rsplit(os.sep, 1)[0] + "/tf_record.*.txt"
    dataset = dataset_ops.Dataset.list_files(file_pattern, shuffle=False)
    dataset = dataset.flat_map(readers.TFRecordDataset)
    dataset = input_ops.auto_shard_dataset(
        dataset, self._num_shards, self._shard_index)

    next_element_fn = self._getNext(dataset)
    actual, expected = [], []
    for f in range(self._shard_index, self._num_files, self._num_shards):
      for r in range(self._num_records):
        actual.append(self.evaluate(next_element_fn()))
        expected.append(self._record(r, f))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(next_element_fn())
    self.assertAllEqual(expected, actual)

  @test_util.run_in_graph_and_eager_modes
  def testComplexPipeline(self):
    # Set up a complex input pipeline.
    batch_size = 2
    num_epochs = 5
    dataset = dataset_ops.Dataset.from_tensor_slices(
        self._createTFRecordFiles())
    dataset = dataset.shuffle(buffer_size=self._num_files)
    dataset = dataset.flat_map(readers.TFRecordDataset)
    dataset = dataset.prefetch(buffer_size=batch_size)
    dataset = dataset.shuffle(2 * self._num_files * self._num_records)
    dataset = dataset.repeat(num_epochs)
    dataset = dataset.map(lambda x: x)
    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(buffer_size=None)

    # Auto shard.
    dataset = input_ops.auto_shard_dataset(
        dataset, self._num_shards, self._shard_index)

    # Verify output.
    next_element_fn = self._getNext(dataset)
    actual = []
    num_iterations = (self._num_files * self._num_records * num_epochs) // (
        self._num_shards * batch_size)
    for _ in range(num_iterations):
      actual.extend(self.evaluate(next_element_fn()))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(next_element_fn())

    expected = []
    for f in range(0, self._num_files, self._num_shards):
      for r in range(self._num_records):
        expected.append(self._record(r, f))
    expected *= num_epochs

    self.assertAllEqual(sorted(expected), sorted(actual))

  @test_util.run_in_graph_and_eager_modes
  def testZip(self):
    dataset1 = readers.TFRecordDataset(self._createTFRecordFiles())
    dataset2 = readers.TextLineDataset(self._createTextFiles())

    dataset = dataset_ops.Dataset.zip((dataset1, dataset2))
    dataset = input_ops.auto_shard_dataset(
        dataset, self._num_shards, self._shard_index)

    record_fn = lambda r, f: (self._record(r, f), self._text_line(r, f))
    self._verifySimpleShardingOutput(dataset, record_fn)

  @test_util.run_in_graph_and_eager_modes
  def testConcat(self):
    dataset1 = readers.TFRecordDataset(self._createTFRecordFiles())
    dataset2 = readers.TextLineDataset(self._createTextFiles())

    dataset = dataset1.concatenate(dataset2)
    dataset = input_ops.auto_shard_dataset(
        dataset, self._num_shards, self._shard_index)

    next_element_fn = self._getNext(dataset)
    for f in range(self._shard_index, self._num_files, self._num_shards):
      for r in range(self._num_records):
        self.assertAllEqual(
            self._record(r, f), self.evaluate(next_element_fn()))
    for f in range(self._shard_index, self._num_files, self._num_shards):
      for r in range(self._num_records):
        self.assertAllEqual(
            self._text_line(r, f), self.evaluate(next_element_fn()))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(next_element_fn())

  @test_util.run_in_graph_and_eager_modes
  def testTextLineReader(self):
    dataset = readers.TextLineDataset(self._createTextFiles())

    dataset = input_ops.auto_shard_dataset(
        dataset, self._num_shards, self._shard_index)

    self._verifySimpleShardingOutput(dataset, self._text_line)

  @test_util.run_in_graph_and_eager_modes
  def testTextLineReaderWithFlatMap(self):
    dataset = readers.TextLineDataset(self._createTextFiles())
    dataset = input_ops.auto_shard_dataset(
        dataset, self._num_shards, self._shard_index)

    self._verifySimpleShardingOutput(dataset, self._text_line)

  @test_util.run_in_graph_and_eager_modes
  def testFixedLengthReaderWithFlatMap(self):
    dataset = readers.FixedLengthRecordDataset(
        self._createFixedLengthRecordFiles(), self._record_bytes)
    dataset = input_ops.auto_shard_dataset(
        dataset, self._num_shards, self._shard_index)

    self._verifySimpleShardingOutput(dataset, self._fixed_length_record)


# A dataset that creates two variant tensors.
class _TestDataset(dataset_ops.UnaryUnchangedStructureDataset):

  def __init__(self, input_dataset):
    self._input_dataset = input_dataset
    temp_variant_tensor = gen_dataset_ops.prefetch_dataset(
        input_dataset._variant_tensor,
        buffer_size=1,
        **self._flat_structure)
    variant_tensor = gen_dataset_ops.model_dataset(
        temp_variant_tensor, **self._flat_structure)
    super(_TestDataset, self).__init__(input_dataset, variant_tensor)


class CloneDatasetTest(test.TestCase):

  def _assert_datasets_equal(self, ds1, ds2):
    # First, assert that the element structures are compatible.
    self.assertTrue(
        structure.are_compatible(ds1.element_spec, ds2.element_spec))

    # Now create iterators on both and assert they produce the same values.
    it1 = dataset_ops.make_initializable_iterator(ds1)
    it2 = dataset_ops.make_initializable_iterator(ds2)

    get_next1 = it1.get_next()
    get_next2 = it2.get_next()

    with self.cached_session():
      self.evaluate([it1.initializer, it2.initializer])
      val1, val2 = self.evaluate([get_next1, get_next2])
      self.assertEqual(val1, val2)

  @test_util.run_deprecated_v1
  def testOnlySource(self):
    ds = dataset_ops.Dataset.range(10)
    cloned_ds = input_ops._clone_dataset(ds)
    self._assert_datasets_equal(ds, cloned_ds)

  @test_util.run_deprecated_v1
  def testSimplePipeline(self):
    ds = dataset_ops.Dataset.range(10).map(math_ops.square)
    cloned_ds = input_ops._clone_dataset(ds)
    self._assert_datasets_equal(ds, cloned_ds)

  @test_util.run_deprecated_v1
  def testConcat(self):
    ds1 = dataset_ops.Dataset.range(10)
    ds2 = dataset_ops.Dataset.range(10)
    ds = ds1.concatenate(ds2)
    cloned_ds = input_ops._clone_dataset(ds)
    self._assert_datasets_equal(ds, cloned_ds)

  @test_util.run_deprecated_v1
  def testZip(self):
    ds1 = dataset_ops.Dataset.range(10)
    ds2 = dataset_ops.Dataset.range(10)
    ds = dataset_ops.Dataset.zip((ds1, ds2))
    cloned_ds = input_ops._clone_dataset(ds)
    self._assert_datasets_equal(ds, cloned_ds)

  @test_util.run_deprecated_v1
  def testMultipleVariantTensors(self):
    ds = dataset_ops.Dataset.range(10)
    ds = _TestDataset(ds)
    cloned_ds = input_ops._clone_dataset(ds)
    self._assert_datasets_equal(ds, cloned_ds)


if __name__ == "__main__":
  test.main()