diff --git a/tensorflow/python/keras/preprocessing/dataset_utils.py b/tensorflow/python/keras/preprocessing/dataset_utils.py index 70d9566889f..bc65c7b9b99 100644 --- a/tensorflow/python/keras/preprocessing/dataset_utils.py +++ b/tensorflow/python/keras/preprocessing/dataset_utils.py @@ -170,19 +170,16 @@ def get_training_or_validation_split(samples, labels, validation_split, subset): Returns: tuple (samples, labels), potentially restricted to the specified subset. """ - if validation_split: - if not 0 < validation_split < 1: - raise ValueError( - '`validation_split` must be between 0 and 1, received: %s' % - (validation_split,)) - if subset is None: + if not validation_split: return samples, labels num_val_samples = int(validation_split * len(samples)) if subset == 'training': + print('Using %d files for training.' % (len(samples) - num_val_samples,)) samples = samples[:-num_val_samples] labels = labels[:-num_val_samples] elif subset == 'validation': + print('Using %d files for validation.' % (num_val_samples,)) samples = samples[-num_val_samples:] labels = labels[-num_val_samples:] else: @@ -199,3 +196,22 @@ def labels_to_dataset(labels, label_mode, num_classes): elif label_mode == 'categorical': label_ds = label_ds.map(lambda x: array_ops.one_hot(x, num_classes)) return label_ds + + +def check_validation_split_arg(validation_split, subset, shuffle, seed): + """Raise errors in case of invalid argument values.""" + if validation_split and not 0 < validation_split < 1: + raise ValueError( + '`validation_split` must be between 0 and 1, received: %s' % + (validation_split,)) + if (validation_split or subset) and not (validation_split and subset): + raise ValueError( + 'If `subset` is set, `validation_split` must be set, and inversely.') + if subset not in ('training', 'validation', None): + raise ValueError('`subset` must be either "training" ' + 'or "validation", received: %s' % (subset,)) + if validation_split and shuffle and seed is None: + raise ValueError( + 'If using `validation_split` and shuffling the data, you must provide ' + 'a `seed` argument, to make sure that there is no overlap between the ' + 'training and validation subset.') diff --git a/tensorflow/python/keras/preprocessing/image_dataset.py b/tensorflow/python/keras/preprocessing/image_dataset.py index 2e24ef887ae..a438c429c40 100644 --- a/tensorflow/python/keras/preprocessing/image_dataset.py +++ b/tensorflow/python/keras/preprocessing/image_dataset.py @@ -167,6 +167,8 @@ def image_dataset_from_directory(directory, '`color_mode` must be one of {"rbg", "rgba", "grayscale"}. ' 'Received: %s' % (color_mode,)) interpolation = image_preprocessing.get_interpolation(interpolation) + dataset_utils.check_validation_split_arg( + validation_split, subset, shuffle, seed) if seed is None: seed = np.random.randint(1e6) diff --git a/tensorflow/python/keras/preprocessing/image_dataset_test.py b/tensorflow/python/keras/preprocessing/image_dataset_test.py index aa10c1c7bac..b196d8249ff 100644 --- a/tensorflow/python/keras/preprocessing/image_dataset_test.py +++ b/tensorflow/python/keras/preprocessing/image_dataset_test.py @@ -199,13 +199,13 @@ class ImageDatasetFromDirectoryTest(keras_parameterized.TestCase): directory = self._prepare_directory(num_classes=2, count=10) dataset = image_dataset.image_dataset_from_directory( directory, batch_size=10, image_size=(18, 18), - validation_split=0.2, subset='training') + validation_split=0.2, subset='training', seed=1337) batch = next(iter(dataset)) self.assertLen(batch, 2) self.assertEqual(batch[0].shape, (8, 18, 18, 3)) dataset = image_dataset.image_dataset_from_directory( directory, batch_size=10, image_size=(18, 18), - validation_split=0.2, subset='validation') + validation_split=0.2, subset='validation', seed=1337) batch = next(iter(dataset)) self.assertLen(batch, 2) self.assertEqual(batch[0].shape, (2, 18, 18, 3)) @@ -285,6 +285,14 @@ class ImageDatasetFromDirectoryTest(keras_parameterized.TestCase): _ = image_dataset.image_dataset_from_directory( directory, validation_split=0.2, subset='other') + with self.assertRaisesRegex(ValueError, '`validation_split` must be set'): + _ = image_dataset.image_dataset_from_directory( + directory, validation_split=0, subset='training') + + with self.assertRaisesRegex(ValueError, 'must provide a `seed`'): + _ = image_dataset.image_dataset_from_directory( + directory, validation_split=0.2, subset='training') + if __name__ == '__main__': v2_compat.enable_v2_behavior() diff --git a/tensorflow/python/keras/preprocessing/text_dataset.py b/tensorflow/python/keras/preprocessing/text_dataset.py index 6a57e993ce0..c634df86edd 100644 --- a/tensorflow/python/keras/preprocessing/text_dataset.py +++ b/tensorflow/python/keras/preprocessing/text_dataset.py @@ -131,6 +131,8 @@ def text_dataset_from_directory(directory, raise ValueError( '`label_mode` argument must be one of "int", "categorical", "binary", ' 'or None. Received: %s' % (label_mode,)) + dataset_utils.check_validation_split_arg( + validation_split, subset, shuffle, seed) if seed is None: seed = np.random.randint(1e6) diff --git a/tensorflow/python/keras/preprocessing/text_dataset_test.py b/tensorflow/python/keras/preprocessing/text_dataset_test.py index c0e231e69a9..f36bd9d89ad 100644 --- a/tensorflow/python/keras/preprocessing/text_dataset_test.py +++ b/tensorflow/python/keras/preprocessing/text_dataset_test.py @@ -142,12 +142,14 @@ class TextDatasetFromDirectoryTest(keras_parameterized.TestCase): def test_text_dataset_from_directory_validation_split(self): directory = self._prepare_directory(num_classes=2, count=10) dataset = text_dataset.text_dataset_from_directory( - directory, batch_size=10, validation_split=0.2, subset='training') + directory, batch_size=10, validation_split=0.2, subset='training', + seed=1337) batch = next(iter(dataset)) self.assertLen(batch, 2) self.assertEqual(batch[0].shape, (8,)) dataset = text_dataset.text_dataset_from_directory( - directory, batch_size=10, validation_split=0.2, subset='validation') + directory, batch_size=10, validation_split=0.2, subset='validation', + seed=1337) batch = next(iter(dataset)) self.assertLen(batch, 2) self.assertEqual(batch[0].shape, (2,)) @@ -212,6 +214,14 @@ class TextDatasetFromDirectoryTest(keras_parameterized.TestCase): _ = text_dataset.text_dataset_from_directory( directory, validation_split=0.2, subset='other') + with self.assertRaisesRegex(ValueError, '`validation_split` must be set'): + _ = text_dataset.text_dataset_from_directory( + directory, validation_split=0, subset='training') + + with self.assertRaisesRegex(ValueError, 'must provide a `seed`'): + _ = text_dataset.text_dataset_from_directory( + directory, validation_split=0.2, subset='training') + if __name__ == '__main__': v2_compat.enable_v2_behavior()