Fix disappearing channel dimension in static shapes of images from image_dataset_from_directory
.
PiperOrigin-RevId: 315965108 Change-Id: Ic2a41855c9013862502a034565ee19a1c4ead354
This commit is contained in:
parent
c3067cb521
commit
47a391a364
@ -228,4 +228,6 @@ def path_to_image(path, image_size, num_channels, interpolation):
|
||||
img = io_ops.read_file(path)
|
||||
img = image_ops.decode_image(
|
||||
img, channels=num_channels, expand_animations=False)
|
||||
return image_ops.resize_images_v2(img, image_size, method=interpolation)
|
||||
img = image_ops.resize_images_v2(img, image_size, method=interpolation)
|
||||
img.set_shape((image_size[0], image_size[1], num_channels))
|
||||
return img
|
||||
|
@ -24,6 +24,7 @@ import shutil
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.compat import v2_compat
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.keras import keras_parameterized
|
||||
from tensorflow.python.keras.preprocessing import image as image_preproc
|
||||
from tensorflow.python.keras.preprocessing import image_dataset
|
||||
@ -123,6 +124,22 @@ class ImageDatasetFromDirectoryTest(keras_parameterized.TestCase):
|
||||
self.assertEqual(batch[1].shape, (8, 2))
|
||||
self.assertEqual(batch[1].dtype.name, 'float32')
|
||||
|
||||
def test_static_shape_in_graph(self):
|
||||
if PIL is None:
|
||||
return # Skip test if PIL is not available.
|
||||
|
||||
directory = self._prepare_directory(num_classes=2)
|
||||
dataset = image_dataset.image_dataset_from_directory(
|
||||
directory, batch_size=8, image_size=(18, 18), label_mode='int')
|
||||
test_case = self
|
||||
|
||||
@def_function.function
|
||||
def symbolic_fn(ds):
|
||||
for x, _ in ds.take(1):
|
||||
test_case.assertListEqual(x.shape.as_list(), [None, 18, 18, 3])
|
||||
|
||||
symbolic_fn(dataset)
|
||||
|
||||
def test_sample_count(self):
|
||||
if PIL is None:
|
||||
return # Skip test if PIL is not available.
|
||||
|
Loading…
x
Reference in New Issue
Block a user