Add smart_resize option to image_dataset_from_directory.
PiperOrigin-RevId: 351188455 Change-Id: Ifb24fabc529c266f5c99d9486f191b40ebfd76cf
This commit is contained in:
parent
4544864a34
commit
9349a096eb
@ -23,6 +23,7 @@ import numpy as np
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.keras.layers.preprocessing import image_preprocessing
|
||||
from tensorflow.python.keras.preprocessing import dataset_utils
|
||||
from tensorflow.python.keras.preprocessing import image as keras_image_ops
|
||||
from tensorflow.python.ops import image_ops
|
||||
from tensorflow.python.ops import io_ops
|
||||
from tensorflow.python.util.tf_export import keras_export
|
||||
@ -44,7 +45,8 @@ def image_dataset_from_directory(directory,
|
||||
validation_split=None,
|
||||
subset=None,
|
||||
interpolation='bilinear',
|
||||
follow_links=False):
|
||||
follow_links=False,
|
||||
smart_resize=False):
|
||||
"""Generates a `tf.data.Dataset` from image files in a directory.
|
||||
|
||||
If your directory structure is:
|
||||
@ -113,6 +115,11 @@ def image_dataset_from_directory(directory,
|
||||
`area`, `lanczos3`, `lanczos5`, `gaussian`, `mitchellcubic`.
|
||||
follow_links: Whether to visit subdirectories pointed to by symlinks.
|
||||
Defaults to False.
|
||||
smart_resize: If True, the resizing function used will be
|
||||
`tf.keras.preprocessing.image.smart_resize`, which preserves the aspect
|
||||
ratio of the original image by using a mixture of resizing and cropping.
|
||||
If False (default), the resizing function is `tf.image.resize`, which
|
||||
does not preserve aspect ratio.
|
||||
|
||||
Returns:
|
||||
A `tf.data.Dataset` object.
|
||||
@ -202,7 +209,8 @@ def image_dataset_from_directory(directory,
|
||||
labels=labels,
|
||||
label_mode=label_mode,
|
||||
num_classes=len(class_names),
|
||||
interpolation=interpolation)
|
||||
interpolation=interpolation,
|
||||
smart_resize=smart_resize)
|
||||
if shuffle:
|
||||
# Shuffle locally at each iteration
|
||||
dataset = dataset.shuffle(buffer_size=batch_size * 8, seed=seed)
|
||||
@ -220,22 +228,30 @@ def paths_and_labels_to_dataset(image_paths,
|
||||
labels,
|
||||
label_mode,
|
||||
num_classes,
|
||||
interpolation):
|
||||
interpolation,
|
||||
smart_resize=False):
|
||||
"""Constructs a dataset of images and labels."""
|
||||
# TODO(fchollet): consider making num_parallel_calls settable
|
||||
path_ds = dataset_ops.Dataset.from_tensor_slices(image_paths)
|
||||
args = (image_size, num_channels, interpolation, smart_resize)
|
||||
img_ds = path_ds.map(
|
||||
lambda x: path_to_image(x, image_size, num_channels, interpolation))
|
||||
lambda x: load_image(x, *args))
|
||||
if label_mode:
|
||||
label_ds = dataset_utils.labels_to_dataset(labels, label_mode, num_classes)
|
||||
img_ds = dataset_ops.Dataset.zip((img_ds, label_ds))
|
||||
return img_ds
|
||||
|
||||
|
||||
def path_to_image(path, image_size, num_channels, interpolation):
|
||||
def load_image(path, image_size, num_channels, interpolation,
|
||||
smart_resize=False):
|
||||
"""Load an image from a path and resize it."""
|
||||
img = io_ops.read_file(path)
|
||||
img = image_ops.decode_image(
|
||||
img, channels=num_channels, expand_animations=False)
|
||||
img = image_ops.resize_images_v2(img, image_size, method=interpolation)
|
||||
if smart_resize:
|
||||
img = keras_image_ops.smart_resize(img, image_size,
|
||||
interpolation=interpolation)
|
||||
else:
|
||||
img = image_ops.resize_images_v2(img, image_size, method=interpolation)
|
||||
img.set_shape((image_size[0], image_size[1], num_channels))
|
||||
return img
|
||||
|
@ -284,6 +284,17 @@ class ImageDatasetFromDirectoryTest(keras_parameterized.TestCase):
|
||||
with self.assertRaisesRegex(ValueError, 'No images found.'):
|
||||
_ = image_dataset.image_dataset_from_directory(directory)
|
||||
|
||||
def test_image_dataset_from_directory_smart_resize(self):
|
||||
if PIL is None:
|
||||
return # Skip test if PIL is not available.
|
||||
|
||||
directory = self._prepare_directory(num_classes=2, count=5)
|
||||
dataset = image_dataset.image_dataset_from_directory(
|
||||
directory, batch_size=5, image_size=(18, 18), smart_resize=True)
|
||||
batch = next(iter(dataset))
|
||||
self.assertLen(batch, 2)
|
||||
self.assertEqual(batch[0].shape, (5, 18, 18, 3))
|
||||
|
||||
def test_image_dataset_from_directory_errors(self):
|
||||
if PIL is None:
|
||||
return # Skip test if PIL is not available.
|
||||
|
@ -14,7 +14,7 @@ tf_module {
|
||||
}
|
||||
member_method {
|
||||
name: "image_dataset_from_directory"
|
||||
argspec: "args=[\'directory\', \'labels\', \'label_mode\', \'class_names\', \'color_mode\', \'batch_size\', \'image_size\', \'shuffle\', \'seed\', \'validation_split\', \'subset\', \'interpolation\', \'follow_links\'], varargs=None, keywords=None, defaults=[\'inferred\', \'int\', \'None\', \'rgb\', \'32\', \'(256, 256)\', \'True\', \'None\', \'None\', \'None\', \'bilinear\', \'False\'], "
|
||||
argspec: "args=[\'directory\', \'labels\', \'label_mode\', \'class_names\', \'color_mode\', \'batch_size\', \'image_size\', \'shuffle\', \'seed\', \'validation_split\', \'subset\', \'interpolation\', \'follow_links\', \'smart_resize\'], varargs=None, keywords=None, defaults=[\'inferred\', \'int\', \'None\', \'rgb\', \'32\', \'(256, 256)\', \'True\', \'None\', \'None\', \'None\', \'bilinear\', \'False\', \'False\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "text_dataset_from_directory"
|
||||
|
Loading…
x
Reference in New Issue
Block a user