[tf.data] Sort the results of tf.matching_files()
to enable Dataset.list_files()
to be deterministic.
PiperOrigin-RevId: 193126572
This commit is contained in:
parent
451070ab9e
commit
d0345d2d86
@ -60,6 +60,7 @@ class MatchingFilesOp : public OpKernel {
|
||||
output(index++) = all_fnames[i][j];
|
||||
}
|
||||
}
|
||||
std::sort(&output(0), &output(0) + num_files);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -69,6 +69,54 @@ class ListFilesDatasetOpTest(test.TestCase):
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(itr.get_next())
|
||||
|
||||
def testSimpleDirectoryNotShuffled(self):
|
||||
filenames = ['b', 'c', 'a']
|
||||
self._touchTempFiles(filenames)
|
||||
|
||||
dataset = dataset_ops.Dataset.list_files(
|
||||
path.join(self.tmp_dir, '*'), shuffle=False)
|
||||
with self.test_session() as sess:
|
||||
itr = dataset.make_one_shot_iterator()
|
||||
next_element = itr.get_next()
|
||||
|
||||
for filename in sorted(filenames):
|
||||
self.assertEqual(compat.as_bytes(path.join(self.tmp_dir, filename)),
|
||||
sess.run(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(itr.get_next())
|
||||
|
||||
def testFixedSeedResultsInRepeatableOrder(self):
|
||||
filenames = ['a', 'b', 'c']
|
||||
self._touchTempFiles(filenames)
|
||||
|
||||
dataset = dataset_ops.Dataset.list_files(
|
||||
path.join(self.tmp_dir, '*'), shuffle=True, seed=37)
|
||||
with self.test_session() as sess:
|
||||
itr = dataset.make_initializable_iterator()
|
||||
next_element = itr.get_next()
|
||||
|
||||
full_filenames = [compat.as_bytes(path.join(self.tmp_dir, filename))
|
||||
for filename in filenames]
|
||||
|
||||
all_produced_filenames = []
|
||||
for _ in range(3):
|
||||
produced_filenames = []
|
||||
sess.run(itr.initializer)
|
||||
try:
|
||||
while True:
|
||||
produced_filenames.append(sess.run(next_element))
|
||||
except errors.OutOfRangeError:
|
||||
pass
|
||||
all_produced_filenames.append(produced_filenames)
|
||||
|
||||
# Each run should produce the same set of filenames, which may be
|
||||
# different from the order of `full_filenames`.
|
||||
self.assertItemsEqual(full_filenames, all_produced_filenames[0])
|
||||
# However, the different runs should produce filenames in the same order
|
||||
# as each other.
|
||||
self.assertEqual(all_produced_filenames[0], all_produced_filenames[1])
|
||||
self.assertEqual(all_produced_filenames[0], all_produced_filenames[2])
|
||||
|
||||
def testEmptyDirectoryInitializer(self):
|
||||
filename_placeholder = array_ops.placeholder(dtypes.string, shape=[])
|
||||
dataset = dataset_ops.Dataset.list_files(filename_placeholder)
|
||||
|
@ -571,9 +571,13 @@ class Dataset(object):
|
||||
return PrefetchDataset(self, buffer_size)
|
||||
|
||||
@staticmethod
|
||||
def list_files(file_pattern, shuffle=None):
|
||||
def list_files(file_pattern, shuffle=None, seed=None):
|
||||
"""A dataset of all files matching a pattern.
|
||||
|
||||
NOTE: The default behavior of this method is to return filenames in
|
||||
a non-deterministic random shuffled order. Pass a `seed` or `shuffle=False`
|
||||
to get results in a deterministic order.
|
||||
|
||||
Example:
|
||||
If we had the following files on our filesystem:
|
||||
- /path/to/dir/a.txt
|
||||
@ -584,20 +588,18 @@ class Dataset(object):
|
||||
- /path/to/dir/b.py
|
||||
- /path/to/dir/c.py
|
||||
|
||||
NOTE: The order of the file names returned can be non-deterministic even
|
||||
when `shuffle` is `False`.
|
||||
|
||||
Args:
|
||||
file_pattern: A string or scalar string `tf.Tensor`, representing
|
||||
the filename pattern that will be matched.
|
||||
shuffle: (Optional.) If `True`, the file names will be shuffled randomly.
|
||||
Defaults to `True`.
|
||||
seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the
|
||||
random seed that will be used to create the distribution. See
|
||||
@{tf.set_random_seed} for behavior.
|
||||
|
||||
Returns:
|
||||
Dataset: A `Dataset` of strings corresponding to file names.
|
||||
"""
|
||||
# TODO(b/73959787): Add a `seed` argument and make the `shuffle=False`
|
||||
# behavior deterministic (e.g. by sorting the filenames).
|
||||
if shuffle is None:
|
||||
shuffle = True
|
||||
matching_files = gen_io_ops.matching_files(file_pattern)
|
||||
@ -607,7 +609,7 @@ class Dataset(object):
|
||||
# list of files might be empty.
|
||||
buffer_size = math_ops.maximum(
|
||||
array_ops.shape(matching_files, out_type=dtypes.int64)[0], 1)
|
||||
dataset = dataset.shuffle(buffer_size)
|
||||
dataset = dataset.shuffle(buffer_size, seed=seed)
|
||||
return dataset
|
||||
|
||||
def repeat(self, count=None):
|
||||
|
@ -64,7 +64,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "list_files"
|
||||
argspec: "args=[\'file_pattern\', \'shuffle\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
argspec: "args=[\'file_pattern\', \'shuffle\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "make_initializable_iterator"
|
||||
|
@ -65,7 +65,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "list_files"
|
||||
argspec: "args=[\'file_pattern\', \'shuffle\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
argspec: "args=[\'file_pattern\', \'shuffle\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "make_initializable_iterator"
|
||||
|
@ -65,7 +65,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "list_files"
|
||||
argspec: "args=[\'file_pattern\', \'shuffle\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
argspec: "args=[\'file_pattern\', \'shuffle\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "make_initializable_iterator"
|
||||
|
@ -65,7 +65,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "list_files"
|
||||
argspec: "args=[\'file_pattern\', \'shuffle\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
argspec: "args=[\'file_pattern\', \'shuffle\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "make_initializable_iterator"
|
||||
|
Loading…
Reference in New Issue
Block a user