[tf.data] Sort the results of tf.matching_files() to enable Dataset.list_files() to be determinstic.

PiperOrigin-RevId: 193126572
This commit is contained in:
Derek Murray 2018-04-16 17:25:12 -07:00 committed by TensorFlower Gardener
parent 451070ab9e
commit d0345d2d86
7 changed files with 62 additions and 11 deletions

View File

@ -60,6 +60,7 @@ class MatchingFilesOp : public OpKernel {
output(index++) = all_fnames[i][j];
}
}
std::sort(&output(0), &output(0) + num_files);
}
};

View File

@ -69,6 +69,54 @@ class ListFilesDatasetOpTest(test.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(itr.get_next())
def testSimpleDirectoryNotShuffled(self):
filenames = ['b', 'c', 'a']
self._touchTempFiles(filenames)
dataset = dataset_ops.Dataset.list_files(
path.join(self.tmp_dir, '*'), shuffle=False)
with self.test_session() as sess:
itr = dataset.make_one_shot_iterator()
next_element = itr.get_next()
for filename in sorted(filenames):
self.assertEqual(compat.as_bytes(path.join(self.tmp_dir, filename)),
sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(itr.get_next())
def testFixedSeedResultsInRepeatableOrder(self):
filenames = ['a', 'b', 'c']
self._touchTempFiles(filenames)
dataset = dataset_ops.Dataset.list_files(
path.join(self.tmp_dir, '*'), shuffle=True, seed=37)
with self.test_session() as sess:
itr = dataset.make_initializable_iterator()
next_element = itr.get_next()
full_filenames = [compat.as_bytes(path.join(self.tmp_dir, filename))
for filename in filenames]
all_produced_filenames = []
for _ in range(3):
produced_filenames = []
sess.run(itr.initializer)
try:
while True:
produced_filenames.append(sess.run(next_element))
except errors.OutOfRangeError:
pass
all_produced_filenames.append(produced_filenames)
# Each run should produce the same set of filenames, which may be
# different from the order of `full_filenames`.
self.assertItemsEqual(full_filenames, all_produced_filenames[0])
# However, the different runs should produce filenames in the same order
# as each other.
self.assertEqual(all_produced_filenames[0], all_produced_filenames[1])
self.assertEqual(all_produced_filenames[0], all_produced_filenames[2])
def testEmptyDirectoryInitializer(self):
filename_placeholder = array_ops.placeholder(dtypes.string, shape=[])
dataset = dataset_ops.Dataset.list_files(filename_placeholder)

View File

@ -571,9 +571,13 @@ class Dataset(object):
return PrefetchDataset(self, buffer_size)
@staticmethod
def list_files(file_pattern, shuffle=None):
def list_files(file_pattern, shuffle=None, seed=None):
"""A dataset of all files matching a pattern.
NOTE: The default behavior of this method is to return filenames in
a non-deterministic random shuffled order. Pass a `seed` or `shuffle=False`
to get results in a deterministic order.
Example:
If we had the following files on our filesystem:
- /path/to/dir/a.txt
@ -584,20 +588,18 @@ class Dataset(object):
- /path/to/dir/b.py
- /path/to/dir/c.py
NOTE: The order of the file names returned can be non-deterministic even
when `shuffle` is `False`.
Args:
file_pattern: A string or scalar string `tf.Tensor`, representing
the filename pattern that will be matched.
shuffle: (Optional.) If `True`, the file names will be shuffled randomly.
Defaults to `True`.
seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the
random seed that will be used to create the distribution. See
@{tf.set_random_seed} for behavior.
Returns:
Dataset: A `Dataset` of strings corresponding to file names.
"""
# TODO(b/73959787): Add a `seed` argument and make the `shuffle=False`
# behavior deterministic (e.g. by sorting the filenames).
if shuffle is None:
shuffle = True
matching_files = gen_io_ops.matching_files(file_pattern)
@ -607,7 +609,7 @@ class Dataset(object):
# list of files might be empty.
buffer_size = math_ops.maximum(
array_ops.shape(matching_files, out_type=dtypes.int64)[0], 1)
dataset = dataset.shuffle(buffer_size)
dataset = dataset.shuffle(buffer_size, seed=seed)
return dataset
def repeat(self, count=None):

View File

@ -64,7 +64,7 @@ tf_class {
}
member_method {
name: "list_files"
argspec: "args=[\'file_pattern\', \'shuffle\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'file_pattern\', \'shuffle\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
member_method {
name: "make_initializable_iterator"

View File

@ -65,7 +65,7 @@ tf_class {
}
member_method {
name: "list_files"
argspec: "args=[\'file_pattern\', \'shuffle\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'file_pattern\', \'shuffle\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
member_method {
name: "make_initializable_iterator"

View File

@ -65,7 +65,7 @@ tf_class {
}
member_method {
name: "list_files"
argspec: "args=[\'file_pattern\', \'shuffle\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'file_pattern\', \'shuffle\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
member_method {
name: "make_initializable_iterator"

View File

@ -65,7 +65,7 @@ tf_class {
}
member_method {
name: "list_files"
argspec: "args=[\'file_pattern\', \'shuffle\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'file_pattern\', \'shuffle\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
member_method {
name: "make_initializable_iterator"