Dataset utilities UX improvements

- Display number of files used for training/validation when validation_split is used
- Refuse to perform validation split if the data is shuffled and not seeded

PiperOrigin-RevId: 308750122
Change-Id: I07f9090e714d1290532c7b7b7f51417f7193c797
This commit is contained in:
Francois Chollet 2020-04-27 20:12:22 -07:00 committed by TensorFlower Gardener
parent 2a72ad4071
commit bcdbfb8a9c
5 changed files with 48 additions and 10 deletions

View File

@ -170,19 +170,16 @@ def get_training_or_validation_split(samples, labels, validation_split, subset):
Returns:
tuple (samples, labels), potentially restricted to the specified subset.
"""
if validation_split:
if not 0 < validation_split < 1:
raise ValueError(
'`validation_split` must be between 0 and 1, received: %s' %
(validation_split,))
if subset is None:
if not validation_split:
return samples, labels
num_val_samples = int(validation_split * len(samples))
if subset == 'training':
print('Using %d files for training.' % (len(samples) - num_val_samples,))
samples = samples[:-num_val_samples]
labels = labels[:-num_val_samples]
elif subset == 'validation':
print('Using %d files for validation.' % (num_val_samples,))
samples = samples[-num_val_samples:]
labels = labels[-num_val_samples:]
else:
@ -199,3 +196,22 @@ def labels_to_dataset(labels, label_mode, num_classes):
elif label_mode == 'categorical':
label_ds = label_ds.map(lambda x: array_ops.one_hot(x, num_classes))
return label_ds
def check_validation_split_arg(validation_split, subset, shuffle, seed):
    """Validate the arguments that control a training/validation split.

    The checks run in a fixed order and the first failing one raises.

    Args:
        validation_split: Fraction of the data reserved for validation. A
            falsy value (e.g. `None` or `0`) means no split is requested;
            otherwise it must lie strictly between 0 and 1.
        subset: Which part of the split is wanted. One of `'training'`,
            `'validation'` or `None`.
        shuffle: Whether the data is going to be shuffled.
        seed: Random seed used for shuffling, or `None` if not provided.

    Raises:
        ValueError: If `validation_split` is set but not strictly between
            0 and 1; if exactly one of `validation_split` and `subset` is
            set; if `subset` is not one of the accepted values; or if a
            split is requested on shuffled data without a `seed`.
    """
    has_split = bool(validation_split)
    if has_split:
        within_bounds = 0 < validation_split < 1
        if not within_bounds:
            raise ValueError(
                '`validation_split` must be between 0 and 1, received: %s' %
                (validation_split,))

    # Both arguments must be provided together, or both left unset.
    has_subset = bool(subset)
    if has_split != has_subset:
        raise ValueError(
            'If `subset` is set, `validation_split` must be set, and inversely.')

    accepted_subsets = ('training', 'validation', None)
    if subset not in accepted_subsets:
        raise ValueError('`subset` must be either "training" '
                         'or "validation", received: %s' % (subset,))

    # An unseeded shuffle would produce a different ordering on each call, so
    # the training and validation subsets could end up overlapping.
    if has_split and shuffle and seed is None:
        raise ValueError(
            'If using `validation_split` and shuffling the data, you must provide '
            'a `seed` argument, to make sure that there is no overlap between the '
            'training and validation subset.')

View File

@ -167,6 +167,8 @@ def image_dataset_from_directory(directory,
'`color_mode` must be one of {"rbg", "rgba", "grayscale"}. '
'Received: %s' % (color_mode,))
interpolation = image_preprocessing.get_interpolation(interpolation)
dataset_utils.check_validation_split_arg(
validation_split, subset, shuffle, seed)
if seed is None:
seed = np.random.randint(1e6)

View File

@ -199,13 +199,13 @@ class ImageDatasetFromDirectoryTest(keras_parameterized.TestCase):
directory = self._prepare_directory(num_classes=2, count=10)
dataset = image_dataset.image_dataset_from_directory(
directory, batch_size=10, image_size=(18, 18),
validation_split=0.2, subset='training')
validation_split=0.2, subset='training', seed=1337)
batch = next(iter(dataset))
self.assertLen(batch, 2)
self.assertEqual(batch[0].shape, (8, 18, 18, 3))
dataset = image_dataset.image_dataset_from_directory(
directory, batch_size=10, image_size=(18, 18),
validation_split=0.2, subset='validation')
validation_split=0.2, subset='validation', seed=1337)
batch = next(iter(dataset))
self.assertLen(batch, 2)
self.assertEqual(batch[0].shape, (2, 18, 18, 3))
@ -285,6 +285,14 @@ class ImageDatasetFromDirectoryTest(keras_parameterized.TestCase):
_ = image_dataset.image_dataset_from_directory(
directory, validation_split=0.2, subset='other')
with self.assertRaisesRegex(ValueError, '`validation_split` must be set'):
_ = image_dataset.image_dataset_from_directory(
directory, validation_split=0, subset='training')
with self.assertRaisesRegex(ValueError, 'must provide a `seed`'):
_ = image_dataset.image_dataset_from_directory(
directory, validation_split=0.2, subset='training')
if __name__ == '__main__':
v2_compat.enable_v2_behavior()

View File

@ -131,6 +131,8 @@ def text_dataset_from_directory(directory,
raise ValueError(
'`label_mode` argument must be one of "int", "categorical", "binary", '
'or None. Received: %s' % (label_mode,))
dataset_utils.check_validation_split_arg(
validation_split, subset, shuffle, seed)
if seed is None:
seed = np.random.randint(1e6)

View File

@ -142,12 +142,14 @@ class TextDatasetFromDirectoryTest(keras_parameterized.TestCase):
def test_text_dataset_from_directory_validation_split(self):
directory = self._prepare_directory(num_classes=2, count=10)
dataset = text_dataset.text_dataset_from_directory(
directory, batch_size=10, validation_split=0.2, subset='training')
directory, batch_size=10, validation_split=0.2, subset='training',
seed=1337)
batch = next(iter(dataset))
self.assertLen(batch, 2)
self.assertEqual(batch[0].shape, (8,))
dataset = text_dataset.text_dataset_from_directory(
directory, batch_size=10, validation_split=0.2, subset='validation')
directory, batch_size=10, validation_split=0.2, subset='validation',
seed=1337)
batch = next(iter(dataset))
self.assertLen(batch, 2)
self.assertEqual(batch[0].shape, (2,))
@ -212,6 +214,14 @@ class TextDatasetFromDirectoryTest(keras_parameterized.TestCase):
_ = text_dataset.text_dataset_from_directory(
directory, validation_split=0.2, subset='other')
with self.assertRaisesRegex(ValueError, '`validation_split` must be set'):
_ = text_dataset.text_dataset_from_directory(
directory, validation_split=0, subset='training')
with self.assertRaisesRegex(ValueError, 'must provide a `seed`'):
_ = text_dataset.text_dataset_from_directory(
directory, validation_split=0.2, subset='training')
if __name__ == '__main__':
v2_compat.enable_v2_behavior()