Dataset utilities UX improvements
- Display number of files used for training/validation when validation_split is used
- Refuse to perform validation split if the data is shuffled and not seeded

PiperOrigin-RevId: 308750122
Change-Id: I07f9090e714d1290532c7b7b7f51417f7193c797
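A rough caller-facing sketch of the behavior this change targets (hypothetical values; `directory` is assumed to already contain class subfolders, and the public export path shown is an assumption, not part of this commit): with a validation_split and a shared seed, each call reports how many files it uses, and omitting the seed while shuffling now raises an error.

    import tensorflow as tf

    # Same seed for both calls, so the shuffled file order matches and the
    # training/validation slices cannot overlap.
    train_ds = tf.keras.preprocessing.image_dataset_from_directory(
        directory, validation_split=0.2, subset='training', seed=1337)
    # Prints something like: "Using 8 files for training."

    val_ds = tf.keras.preprocessing.image_dataset_from_directory(
        directory, validation_split=0.2, subset='validation', seed=1337)
    # Prints something like: "Using 2 files for validation."

    # With shuffle=True (the default) and no `seed`, a ValueError is now raised.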
@@ -170,19 +170,16 @@ def get_training_or_validation_split(samples, labels, validation_split, subset):
   Returns:
     tuple (samples, labels), potentially restricted to the specified subset.
   """
-  if validation_split:
-    if not 0 < validation_split < 1:
-      raise ValueError(
-          '`validation_split` must be between 0 and 1, received: %s' %
-          (validation_split,))
-    if subset is None:
-      return samples, labels
+  if not validation_split:
+    return samples, labels

   num_val_samples = int(validation_split * len(samples))
   if subset == 'training':
+    print('Using %d files for training.' % (len(samples) - num_val_samples,))
     samples = samples[:-num_val_samples]
     labels = labels[:-num_val_samples]
   elif subset == 'validation':
+    print('Using %d files for validation.' % (num_val_samples,))
     samples = samples[-num_val_samples:]
     labels = labels[-num_val_samples:]
   else:
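As a minimal sketch of what the split above computes, using a toy list instead of file paths (illustrative only; the real function also slices labels and prints the counts shown in the diff):

    samples = list(range(10))
    validation_split = 0.2
    num_val_samples = int(validation_split * len(samples))  # 2

    training_samples = samples[:-num_val_samples]    # first 8 items
    validation_samples = samples[-num_val_samples:]  # last 2 items
    assert len(training_samples) == 8
    assert len(validation_samples) == 2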
@@ -199,3 +196,22 @@ def labels_to_dataset(labels, label_mode, num_classes):
   elif label_mode == 'categorical':
     label_ds = label_ds.map(lambda x: array_ops.one_hot(x, num_classes))
   return label_ds
+
+
+def check_validation_split_arg(validation_split, subset, shuffle, seed):
+  """Raise errors in case of invalid argument values."""
+  if validation_split and not 0 < validation_split < 1:
+    raise ValueError(
+        '`validation_split` must be between 0 and 1, received: %s' %
+        (validation_split,))
+  if (validation_split or subset) and not (validation_split and subset):
+    raise ValueError(
+        'If `subset` is set, `validation_split` must be set, and inversely.')
+  if subset not in ('training', 'validation', None):
+    raise ValueError('`subset` must be either "training" '
+                     'or "validation", received: %s' % (subset,))
+  if validation_split and shuffle and seed is None:
+    raise ValueError(
+        'If using `validation_split` and shuffling the data, you must provide '
+        'a `seed` argument, to make sure that there is no overlap between the '
+        'training and validation subset.')
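A hedged summary of the combinations the new helper accepts and rejects, mirroring the checks above (a sketch, not an exhaustive contract):

    # Signature: check_validation_split_arg(validation_split, subset, shuffle, seed)
    check_validation_split_arg(0.2, 'training', shuffle=True, seed=1337)  # OK
    check_validation_split_arg(None, None, shuffle=True, seed=None)       # OK: no split requested

    # Each of the following raises ValueError:
    #   check_validation_split_arg(1.5, 'training', True, 1337)  # split not strictly between 0 and 1
    #   check_validation_split_arg(0.2, None, True, 1337)        # split without subset (and vice versa)
    #   check_validation_split_arg(0.2, 'other', True, 1337)     # subset not 'training'/'validation'/None
    #   check_validation_split_arg(0.2, 'training', True, None)  # shuffled data but no seed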
@@ -167,6 +167,8 @@ def image_dataset_from_directory(directory,
         '`color_mode` must be one of {"rbg", "rgba", "grayscale"}. '
         'Received: %s' % (color_mode,))
   interpolation = image_preprocessing.get_interpolation(interpolation)
+  dataset_utils.check_validation_split_arg(
+      validation_split, subset, shuffle, seed)

   if seed is None:
     seed = np.random.randint(1e6)
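The shared seed is what keeps the two subsets disjoint: both image_dataset_from_directory calls shuffle the full file list before slicing, so identical seeds give identical orderings. A rough plain-Python illustration (the actual implementation shuffles file paths, not integers):

    import random

    files = list(range(10))
    order_a = files[:]
    order_b = files[:]
    random.Random(1337).shuffle(order_a)  # the 'training' call's shuffle
    random.Random(1337).shuffle(order_b)  # the 'validation' call's shuffle
    assert order_a == order_b  # same seed -> same order -> head/tail slices are disjoint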
@@ -199,13 +199,13 @@ class ImageDatasetFromDirectoryTest(keras_parameterized.TestCase):
     directory = self._prepare_directory(num_classes=2, count=10)
     dataset = image_dataset.image_dataset_from_directory(
         directory, batch_size=10, image_size=(18, 18),
-        validation_split=0.2, subset='training')
+        validation_split=0.2, subset='training', seed=1337)
     batch = next(iter(dataset))
     self.assertLen(batch, 2)
     self.assertEqual(batch[0].shape, (8, 18, 18, 3))
     dataset = image_dataset.image_dataset_from_directory(
         directory, batch_size=10, image_size=(18, 18),
-        validation_split=0.2, subset='validation')
+        validation_split=0.2, subset='validation', seed=1337)
     batch = next(iter(dataset))
     self.assertLen(batch, 2)
     self.assertEqual(batch[0].shape, (2, 18, 18, 3))
@@ -285,6 +285,14 @@ class ImageDatasetFromDirectoryTest(keras_parameterized.TestCase):
       _ = image_dataset.image_dataset_from_directory(
           directory, validation_split=0.2, subset='other')

+    with self.assertRaisesRegex(ValueError, '`validation_split` must be set'):
+      _ = image_dataset.image_dataset_from_directory(
+          directory, validation_split=0, subset='training')
+
+    with self.assertRaisesRegex(ValueError, 'must provide a `seed`'):
+      _ = image_dataset.image_dataset_from_directory(
+          directory, validation_split=0.2, subset='training')
+

 if __name__ == '__main__':
   v2_compat.enable_v2_behavior()
@@ -131,6 +131,8 @@ def text_dataset_from_directory(directory,
     raise ValueError(
         '`label_mode` argument must be one of "int", "categorical", "binary", '
         'or None. Received: %s' % (label_mode,))
+  dataset_utils.check_validation_split_arg(
+      validation_split, subset, shuffle, seed)

   if seed is None:
     seed = np.random.randint(1e6)
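text_dataset_from_directory gets the same validation wiring; a brief hypothetical usage sketch mirroring the updated tests below (directory layout and counts are illustrative):

    train_ds = text_dataset.text_dataset_from_directory(
        directory, batch_size=10, validation_split=0.2, subset='training',
        seed=1337)
    val_ds = text_dataset.text_dataset_from_directory(
        directory, batch_size=10, validation_split=0.2, subset='validation',
        seed=1337)
    # Each call prints how many files it will use for its subset.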
@@ -142,12 +142,14 @@ class TextDatasetFromDirectoryTest(keras_parameterized.TestCase):
   def test_text_dataset_from_directory_validation_split(self):
     directory = self._prepare_directory(num_classes=2, count=10)
     dataset = text_dataset.text_dataset_from_directory(
-        directory, batch_size=10, validation_split=0.2, subset='training')
+        directory, batch_size=10, validation_split=0.2, subset='training',
+        seed=1337)
     batch = next(iter(dataset))
     self.assertLen(batch, 2)
     self.assertEqual(batch[0].shape, (8,))
     dataset = text_dataset.text_dataset_from_directory(
-        directory, batch_size=10, validation_split=0.2, subset='validation')
+        directory, batch_size=10, validation_split=0.2, subset='validation',
+        seed=1337)
     batch = next(iter(dataset))
     self.assertLen(batch, 2)
     self.assertEqual(batch[0].shape, (2,))
@@ -212,6 +214,14 @@ class TextDatasetFromDirectoryTest(keras_parameterized.TestCase):
       _ = text_dataset.text_dataset_from_directory(
           directory, validation_split=0.2, subset='other')

+    with self.assertRaisesRegex(ValueError, '`validation_split` must be set'):
+      _ = text_dataset.text_dataset_from_directory(
+          directory, validation_split=0, subset='training')
+
+    with self.assertRaisesRegex(ValueError, 'must provide a `seed`'):
+      _ = text_dataset.text_dataset_from_directory(
+          directory, validation_split=0.2, subset='training')
+

 if __name__ == '__main__':
   v2_compat.enable_v2_behavior()