From 9349a096eb2c629dc622d75c6c738e80e0d408e7 Mon Sep 17 00:00:00 2001 From: Francois Chollet Date: Mon, 11 Jan 2021 10:48:22 -0800 Subject: [PATCH] Add smart_resize option to image_dataset_from_directory. PiperOrigin-RevId: 351188455 Change-Id: Ifb24fabc529c266f5c99d9486f191b40ebfd76cf --- .../keras/preprocessing/image_dataset.py | 28 +++++++++++++++---- .../keras/preprocessing/image_dataset_test.py | 11 ++++++++ .../v2/tensorflow.keras.preprocessing.pbtxt | 2 +- 3 files changed, 34 insertions(+), 7 deletions(-) diff --git a/tensorflow/python/keras/preprocessing/image_dataset.py b/tensorflow/python/keras/preprocessing/image_dataset.py index 73c5d99b01e..a25735e599a 100644 --- a/tensorflow/python/keras/preprocessing/image_dataset.py +++ b/tensorflow/python/keras/preprocessing/image_dataset.py @@ -23,6 +23,7 @@ import numpy as np from tensorflow.python.data.ops import dataset_ops from tensorflow.python.keras.layers.preprocessing import image_preprocessing from tensorflow.python.keras.preprocessing import dataset_utils +from tensorflow.python.keras.preprocessing import image as keras_image_ops from tensorflow.python.ops import image_ops from tensorflow.python.ops import io_ops from tensorflow.python.util.tf_export import keras_export @@ -44,7 +45,8 @@ def image_dataset_from_directory(directory, validation_split=None, subset=None, interpolation='bilinear', - follow_links=False): + follow_links=False, + smart_resize=False): """Generates a `tf.data.Dataset` from image files in a directory. If your directory structure is: @@ -113,6 +115,11 @@ def image_dataset_from_directory(directory, `area`, `lanczos3`, `lanczos5`, `gaussian`, `mitchellcubic`. follow_links: Whether to visits subdirectories pointed to by symlinks. Defaults to False. + smart_resize: If True, the resizing function used will be + `tf.keras.preprocessing.image.smart_resize`, which preserves the aspect + ratio of the original image by using a mixture of resizing and cropping. + If False (default), the resizing function is `tf.image.resize`, which + does not preserve aspect ratio. Returns: A `tf.data.Dataset` object. @@ -202,7 +209,8 @@ def image_dataset_from_directory(directory, labels=labels, label_mode=label_mode, num_classes=len(class_names), - interpolation=interpolation) + interpolation=interpolation, + smart_resize=smart_resize) if shuffle: # Shuffle locally at each iteration dataset = dataset.shuffle(buffer_size=batch_size * 8, seed=seed) @@ -220,22 +228,30 @@ def paths_and_labels_to_dataset(image_paths, labels, label_mode, num_classes, - interpolation): + interpolation, + smart_resize=False): """Constructs a dataset of images and labels.""" # TODO(fchollet): consider making num_parallel_calls settable path_ds = dataset_ops.Dataset.from_tensor_slices(image_paths) + args = (image_size, num_channels, interpolation, smart_resize) img_ds = path_ds.map( - lambda x: path_to_image(x, image_size, num_channels, interpolation)) + lambda x: load_image(x, *args)) if label_mode: label_ds = dataset_utils.labels_to_dataset(labels, label_mode, num_classes) img_ds = dataset_ops.Dataset.zip((img_ds, label_ds)) return img_ds -def path_to_image(path, image_size, num_channels, interpolation): +def load_image(path, image_size, num_channels, interpolation, + smart_resize=False): + """Load an image from a path and resize it.""" img = io_ops.read_file(path) img = image_ops.decode_image( img, channels=num_channels, expand_animations=False) - img = image_ops.resize_images_v2(img, image_size, method=interpolation) + if smart_resize: + img = keras_image_ops.smart_resize(img, image_size, + interpolation=interpolation) + else: + img = image_ops.resize_images_v2(img, image_size, method=interpolation) img.set_shape((image_size[0], image_size[1], num_channels)) return img diff --git a/tensorflow/python/keras/preprocessing/image_dataset_test.py b/tensorflow/python/keras/preprocessing/image_dataset_test.py index cdb79df2a68..4892bc5da9d 100644 --- a/tensorflow/python/keras/preprocessing/image_dataset_test.py +++ b/tensorflow/python/keras/preprocessing/image_dataset_test.py @@ -284,6 +284,17 @@ class ImageDatasetFromDirectoryTest(keras_parameterized.TestCase): with self.assertRaisesRegex(ValueError, 'No images found.'): _ = image_dataset.image_dataset_from_directory(directory) + def test_image_dataset_from_directory_smart_resize(self): + if PIL is None: + return # Skip test if PIL is not available. + + directory = self._prepare_directory(num_classes=2, count=5) + dataset = image_dataset.image_dataset_from_directory( + directory, batch_size=5, image_size=(18, 18), smart_resize=True) + batch = next(iter(dataset)) + self.assertLen(batch, 2) + self.assertEqual(batch[0].shape, (5, 18, 18, 3)) + def test_image_dataset_from_directory_errors(self): if PIL is None: return # Skip test if PIL is not available. diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.preprocessing.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.preprocessing.pbtxt index 3189c502774..61c1a776b23 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.preprocessing.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.preprocessing.pbtxt @@ -14,7 +14,7 @@ tf_module { } member_method { name: "image_dataset_from_directory" - argspec: "args=[\'directory\', \'labels\', \'label_mode\', \'class_names\', \'color_mode\', \'batch_size\', \'image_size\', \'shuffle\', \'seed\', \'validation_split\', \'subset\', \'interpolation\', \'follow_links\'], varargs=None, keywords=None, defaults=[\'inferred\', \'int\', \'None\', \'rgb\', \'32\', \'(256, 256)\', \'True\', \'None\', \'None\', \'None\', \'bilinear\', \'False\'], " + argspec: "args=[\'directory\', \'labels\', \'label_mode\', \'class_names\', \'color_mode\', \'batch_size\', \'image_size\', \'shuffle\', \'seed\', \'validation_split\', \'subset\', \'interpolation\', \'follow_links\', \'smart_resize\'], varargs=None, keywords=None, defaults=[\'inferred\', \'int\', \'None\', \'rgb\', \'32\', \'(256, 256)\', \'True\', \'None\', \'None\', \'None\', \'bilinear\', \'False\', \'False\'], " } member_method { name: "text_dataset_from_directory"