Fix disappearing channel dimension in static shapes of images from image_dataset_from_directory.

PiperOrigin-RevId: 315965108
Change-Id: Ic2a41855c9013862502a034565ee19a1c4ead354
This commit is contained in:
Francois Chollet 2020-06-11 13:23:37 -07:00 committed by TensorFlower Gardener
parent c3067cb521
commit 47a391a364
2 changed files with 20 additions and 1 deletions

View File

@ -228,4 +228,6 @@ def path_to_image(path, image_size, num_channels, interpolation):
img = io_ops.read_file(path)
img = image_ops.decode_image(
img, channels=num_channels, expand_animations=False)
return image_ops.resize_images_v2(img, image_size, method=interpolation)
img = image_ops.resize_images_v2(img, image_size, method=interpolation)
img.set_shape((image_size[0], image_size[1], num_channels))
return img

View File

@ -24,6 +24,7 @@ import shutil
import numpy as np
from tensorflow.python.compat import v2_compat
from tensorflow.python.eager import def_function
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras.preprocessing import image as image_preproc
from tensorflow.python.keras.preprocessing import image_dataset
@ -123,6 +124,22 @@ class ImageDatasetFromDirectoryTest(keras_parameterized.TestCase):
self.assertEqual(batch[1].shape, (8, 2))
self.assertEqual(batch[1].dtype.name, 'float32')
def test_static_shape_in_graph(self):
if PIL is None:
return # Skip test if PIL is not available.
directory = self._prepare_directory(num_classes=2)
dataset = image_dataset.image_dataset_from_directory(
directory, batch_size=8, image_size=(18, 18), label_mode='int')
test_case = self
@def_function.function
def symbolic_fn(ds):
for x, _ in ds.take(1):
test_case.assertListEqual(x.shape.as_list(), [None, 18, 18, 3])
symbolic_fn(dataset)
def test_sample_count(self):
if PIL is None:
return # Skip test if PIL is not available.