Add smart_resize option to image_dataset_from_directory.

PiperOrigin-RevId: 351188455
Change-Id: Ifb24fabc529c266f5c99d9486f191b40ebfd76cf
This commit is contained in:
Francois Chollet 2021-01-11 10:48:22 -08:00 committed by TensorFlower Gardener
parent 4544864a34
commit 9349a096eb
3 changed files with 34 additions and 7 deletions

View File

@ -23,6 +23,7 @@ import numpy as np
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.keras.layers.preprocessing import image_preprocessing
from tensorflow.python.keras.preprocessing import dataset_utils
from tensorflow.python.keras.preprocessing import image as keras_image_ops
from tensorflow.python.ops import image_ops
from tensorflow.python.ops import io_ops
from tensorflow.python.util.tf_export import keras_export
@ -44,7 +45,8 @@ def image_dataset_from_directory(directory,
validation_split=None,
subset=None,
interpolation='bilinear',
follow_links=False):
follow_links=False,
smart_resize=False):
"""Generates a `tf.data.Dataset` from image files in a directory.
If your directory structure is:
@ -113,6 +115,11 @@ def image_dataset_from_directory(directory,
`area`, `lanczos3`, `lanczos5`, `gaussian`, `mitchellcubic`.
follow_links: Whether to visit subdirectories pointed to by symlinks.
Defaults to False.
smart_resize: If True, the resizing function used will be
`tf.keras.preprocessing.image.smart_resize`, which preserves the aspect
ratio of the original image by using a mixture of resizing and cropping.
If False (default), the resizing function is `tf.image.resize`, which
does not preserve aspect ratio.
Returns:
A `tf.data.Dataset` object.
@ -202,7 +209,8 @@ def image_dataset_from_directory(directory,
labels=labels,
label_mode=label_mode,
num_classes=len(class_names),
interpolation=interpolation)
interpolation=interpolation,
smart_resize=smart_resize)
if shuffle:
# Shuffle locally at each iteration
dataset = dataset.shuffle(buffer_size=batch_size * 8, seed=seed)
@ -220,22 +228,30 @@ def paths_and_labels_to_dataset(image_paths,
labels,
label_mode,
num_classes,
interpolation):
interpolation,
smart_resize=False):
"""Constructs a dataset of images and labels."""
# TODO(fchollet): consider making num_parallel_calls settable
path_ds = dataset_ops.Dataset.from_tensor_slices(image_paths)
args = (image_size, num_channels, interpolation, smart_resize)
img_ds = path_ds.map(
lambda x: path_to_image(x, image_size, num_channels, interpolation))
lambda x: load_image(x, *args))
if label_mode:
label_ds = dataset_utils.labels_to_dataset(labels, label_mode, num_classes)
img_ds = dataset_ops.Dataset.zip((img_ds, label_ds))
return img_ds
def path_to_image(path, image_size, num_channels, interpolation):
def load_image(path, image_size, num_channels, interpolation,
smart_resize=False):
"""Load an image from a path and resize it."""
img = io_ops.read_file(path)
img = image_ops.decode_image(
img, channels=num_channels, expand_animations=False)
img = image_ops.resize_images_v2(img, image_size, method=interpolation)
if smart_resize:
img = keras_image_ops.smart_resize(img, image_size,
interpolation=interpolation)
else:
img = image_ops.resize_images_v2(img, image_size, method=interpolation)
img.set_shape((image_size[0], image_size[1], num_channels))
return img

View File

@ -284,6 +284,17 @@ class ImageDatasetFromDirectoryTest(keras_parameterized.TestCase):
with self.assertRaisesRegex(ValueError, 'No images found.'):
_ = image_dataset.image_dataset_from_directory(directory)
def test_image_dataset_from_directory_smart_resize(self):
if PIL is None:
return # Skip test if PIL is not available.
directory = self._prepare_directory(num_classes=2, count=5)
dataset = image_dataset.image_dataset_from_directory(
directory, batch_size=5, image_size=(18, 18), smart_resize=True)
batch = next(iter(dataset))
self.assertLen(batch, 2)
self.assertEqual(batch[0].shape, (5, 18, 18, 3))
def test_image_dataset_from_directory_errors(self):
if PIL is None:
return # Skip test if PIL is not available.

View File

@ -14,7 +14,7 @@ tf_module {
}
member_method {
name: "image_dataset_from_directory"
argspec: "args=[\'directory\', \'labels\', \'label_mode\', \'class_names\', \'color_mode\', \'batch_size\', \'image_size\', \'shuffle\', \'seed\', \'validation_split\', \'subset\', \'interpolation\', \'follow_links\'], varargs=None, keywords=None, defaults=[\'inferred\', \'int\', \'None\', \'rgb\', \'32\', \'(256, 256)\', \'True\', \'None\', \'None\', \'None\', \'bilinear\', \'False\'], "
argspec: "args=[\'directory\', \'labels\', \'label_mode\', \'class_names\', \'color_mode\', \'batch_size\', \'image_size\', \'shuffle\', \'seed\', \'validation_split\', \'subset\', \'interpolation\', \'follow_links\', \'smart_resize\'], varargs=None, keywords=None, defaults=[\'inferred\', \'int\', \'None\', \'rgb\', \'32\', \'(256, 256)\', \'True\', \'None\', \'None\', \'None\', \'bilinear\', \'False\', \'False\'], "
}
member_method {
name: "text_dataset_from_directory"