Dataset utilities UX improvements
- Display number of files used for training/validation when validation_split is used - Refuse to perform validation split if the data is shuffled and not seeded PiperOrigin-RevId: 308750122 Change-Id: I07f9090e714d1290532c7b7b7f51417f7193c797
This commit is contained in:
parent
2a72ad4071
commit
bcdbfb8a9c
@ -170,19 +170,16 @@ def get_training_or_validation_split(samples, labels, validation_split, subset):
|
||||
Returns:
|
||||
tuple (samples, labels), potentially restricted to the specified subset.
|
||||
"""
|
||||
if validation_split:
|
||||
if not 0 < validation_split < 1:
|
||||
raise ValueError(
|
||||
'`validation_split` must be between 0 and 1, received: %s' %
|
||||
(validation_split,))
|
||||
if subset is None:
|
||||
if not validation_split:
|
||||
return samples, labels
|
||||
|
||||
num_val_samples = int(validation_split * len(samples))
|
||||
if subset == 'training':
|
||||
print('Using %d files for training.' % (len(samples) - num_val_samples,))
|
||||
samples = samples[:-num_val_samples]
|
||||
labels = labels[:-num_val_samples]
|
||||
elif subset == 'validation':
|
||||
print('Using %d files for validation.' % (num_val_samples,))
|
||||
samples = samples[-num_val_samples:]
|
||||
labels = labels[-num_val_samples:]
|
||||
else:
|
||||
@ -199,3 +196,22 @@ def labels_to_dataset(labels, label_mode, num_classes):
|
||||
elif label_mode == 'categorical':
|
||||
label_ds = label_ds.map(lambda x: array_ops.one_hot(x, num_classes))
|
||||
return label_ds
|
||||
|
||||
|
||||
def check_validation_split_arg(validation_split, subset, shuffle, seed):
|
||||
"""Raise errors in case of invalid argument values."""
|
||||
if validation_split and not 0 < validation_split < 1:
|
||||
raise ValueError(
|
||||
'`validation_split` must be between 0 and 1, received: %s' %
|
||||
(validation_split,))
|
||||
if (validation_split or subset) and not (validation_split and subset):
|
||||
raise ValueError(
|
||||
'If `subset` is set, `validation_split` must be set, and inversely.')
|
||||
if subset not in ('training', 'validation', None):
|
||||
raise ValueError('`subset` must be either "training" '
|
||||
'or "validation", received: %s' % (subset,))
|
||||
if validation_split and shuffle and seed is None:
|
||||
raise ValueError(
|
||||
'If using `validation_split` and shuffling the data, you must provide '
|
||||
'a `seed` argument, to make sure that there is no overlap between the '
|
||||
'training and validation subset.')
|
||||
|
@ -167,6 +167,8 @@ def image_dataset_from_directory(directory,
|
||||
'`color_mode` must be one of {"rbg", "rgba", "grayscale"}. '
|
||||
'Received: %s' % (color_mode,))
|
||||
interpolation = image_preprocessing.get_interpolation(interpolation)
|
||||
dataset_utils.check_validation_split_arg(
|
||||
validation_split, subset, shuffle, seed)
|
||||
|
||||
if seed is None:
|
||||
seed = np.random.randint(1e6)
|
||||
|
@ -199,13 +199,13 @@ class ImageDatasetFromDirectoryTest(keras_parameterized.TestCase):
|
||||
directory = self._prepare_directory(num_classes=2, count=10)
|
||||
dataset = image_dataset.image_dataset_from_directory(
|
||||
directory, batch_size=10, image_size=(18, 18),
|
||||
validation_split=0.2, subset='training')
|
||||
validation_split=0.2, subset='training', seed=1337)
|
||||
batch = next(iter(dataset))
|
||||
self.assertLen(batch, 2)
|
||||
self.assertEqual(batch[0].shape, (8, 18, 18, 3))
|
||||
dataset = image_dataset.image_dataset_from_directory(
|
||||
directory, batch_size=10, image_size=(18, 18),
|
||||
validation_split=0.2, subset='validation')
|
||||
validation_split=0.2, subset='validation', seed=1337)
|
||||
batch = next(iter(dataset))
|
||||
self.assertLen(batch, 2)
|
||||
self.assertEqual(batch[0].shape, (2, 18, 18, 3))
|
||||
@ -285,6 +285,14 @@ class ImageDatasetFromDirectoryTest(keras_parameterized.TestCase):
|
||||
_ = image_dataset.image_dataset_from_directory(
|
||||
directory, validation_split=0.2, subset='other')
|
||||
|
||||
with self.assertRaisesRegex(ValueError, '`validation_split` must be set'):
|
||||
_ = image_dataset.image_dataset_from_directory(
|
||||
directory, validation_split=0, subset='training')
|
||||
|
||||
with self.assertRaisesRegex(ValueError, 'must provide a `seed`'):
|
||||
_ = image_dataset.image_dataset_from_directory(
|
||||
directory, validation_split=0.2, subset='training')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
v2_compat.enable_v2_behavior()
|
||||
|
@ -131,6 +131,8 @@ def text_dataset_from_directory(directory,
|
||||
raise ValueError(
|
||||
'`label_mode` argument must be one of "int", "categorical", "binary", '
|
||||
'or None. Received: %s' % (label_mode,))
|
||||
dataset_utils.check_validation_split_arg(
|
||||
validation_split, subset, shuffle, seed)
|
||||
|
||||
if seed is None:
|
||||
seed = np.random.randint(1e6)
|
||||
|
@ -142,12 +142,14 @@ class TextDatasetFromDirectoryTest(keras_parameterized.TestCase):
|
||||
def test_text_dataset_from_directory_validation_split(self):
|
||||
directory = self._prepare_directory(num_classes=2, count=10)
|
||||
dataset = text_dataset.text_dataset_from_directory(
|
||||
directory, batch_size=10, validation_split=0.2, subset='training')
|
||||
directory, batch_size=10, validation_split=0.2, subset='training',
|
||||
seed=1337)
|
||||
batch = next(iter(dataset))
|
||||
self.assertLen(batch, 2)
|
||||
self.assertEqual(batch[0].shape, (8,))
|
||||
dataset = text_dataset.text_dataset_from_directory(
|
||||
directory, batch_size=10, validation_split=0.2, subset='validation')
|
||||
directory, batch_size=10, validation_split=0.2, subset='validation',
|
||||
seed=1337)
|
||||
batch = next(iter(dataset))
|
||||
self.assertLen(batch, 2)
|
||||
self.assertEqual(batch[0].shape, (2,))
|
||||
@ -212,6 +214,14 @@ class TextDatasetFromDirectoryTest(keras_parameterized.TestCase):
|
||||
_ = text_dataset.text_dataset_from_directory(
|
||||
directory, validation_split=0.2, subset='other')
|
||||
|
||||
with self.assertRaisesRegex(ValueError, '`validation_split` must be set'):
|
||||
_ = text_dataset.text_dataset_from_directory(
|
||||
directory, validation_split=0, subset='training')
|
||||
|
||||
with self.assertRaisesRegex(ValueError, 'must provide a `seed`'):
|
||||
_ = text_dataset.text_dataset_from_directory(
|
||||
directory, validation_split=0.2, subset='training')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
v2_compat.enable_v2_behavior()
|
||||
|
Loading…
Reference in New Issue
Block a user