diff --git a/tensorflow/python/keras/preprocessing/image_dataset.py b/tensorflow/python/keras/preprocessing/image_dataset.py index a438c429c40..d287c4ef372 100644 --- a/tensorflow/python/keras/preprocessing/image_dataset.py +++ b/tensorflow/python/keras/preprocessing/image_dataset.py @@ -228,4 +228,6 @@ def path_to_image(path, image_size, num_channels, interpolation): img = io_ops.read_file(path) img = image_ops.decode_image( img, channels=num_channels, expand_animations=False) - return image_ops.resize_images_v2(img, image_size, method=interpolation) + img = image_ops.resize_images_v2(img, image_size, method=interpolation) + img.set_shape((image_size[0], image_size[1], num_channels)) + return img diff --git a/tensorflow/python/keras/preprocessing/image_dataset_test.py b/tensorflow/python/keras/preprocessing/image_dataset_test.py index b196d8249ff..efc3faed0cc 100644 --- a/tensorflow/python/keras/preprocessing/image_dataset_test.py +++ b/tensorflow/python/keras/preprocessing/image_dataset_test.py @@ -24,6 +24,7 @@ import shutil import numpy as np from tensorflow.python.compat import v2_compat +from tensorflow.python.eager import def_function from tensorflow.python.keras import keras_parameterized from tensorflow.python.keras.preprocessing import image as image_preproc from tensorflow.python.keras.preprocessing import image_dataset @@ -123,6 +124,22 @@ class ImageDatasetFromDirectoryTest(keras_parameterized.TestCase): self.assertEqual(batch[1].shape, (8, 2)) self.assertEqual(batch[1].dtype.name, 'float32') + def test_static_shape_in_graph(self): + if PIL is None: + return # Skip test if PIL is not available. + + directory = self._prepare_directory(num_classes=2) + dataset = image_dataset.image_dataset_from_directory( + directory, batch_size=8, image_size=(18, 18), label_mode='int') + test_case = self + + @def_function.function + def symbolic_fn(ds): + for x, _ in ds.take(1): + test_case.assertListEqual(x.shape.as_list(), [None, 18, 18, 3]) + + symbolic_fn(dataset) + def test_sample_count(self): if PIL is None: return # Skip test if PIL is not available.