Explicitly setting shape in grayscale_to_rgb and rgb_to_grayscale.
Fixes shape inference following a previous change that generalized these two members of image_ops. The more general implementation, however, does not allow shape inference. This change explicitly fixes the shape with set_shape, following array_grad._TileGrad. Adding tests for expected shape inference behavior. Change: 111459347
This commit is contained in:
parent
96689166e9
commit
be39348393
@ -966,7 +966,7 @@ def rgb_to_grayscale(images):
|
||||
gray_float = math_ops.reduce_sum(flt_image * rgb_weights,
|
||||
rank_1,
|
||||
keep_dims=True)
|
||||
|
||||
gray_float.set_shape(images.get_shape()[:-1].concatenate([1]))
|
||||
return convert_image_dtype(gray_float, orig_dtype)
|
||||
|
||||
|
||||
@ -988,7 +988,9 @@ def grayscale_to_rgb(images):
|
||||
[array_ops.ones(rank_1,
|
||||
dtype=dtypes.int32)] + [array_ops.expand_dims(3, 0)])
|
||||
multiples = array_ops.concat(0, shape_list)
|
||||
return array_ops.tile(images, multiples)
|
||||
rgb = array_ops.tile(images, multiples)
|
||||
rgb.set_shape(images.get_shape()[:-1].concatenate([3]))
|
||||
return rgb
|
||||
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
|
@ -133,6 +133,31 @@ class GrayscaleToRGBTest(test_util.TensorFlowTestCase):
|
||||
y_tf = y.eval()
|
||||
self.assertAllEqual(y_tf, y_np)
|
||||
|
||||
def testShapeInference(self):
|
||||
# Shape inference works and produces expected output where possible
|
||||
rgb_shape = [7, None, 19, 3]
|
||||
gray_shape = rgb_shape[:-1] + [1]
|
||||
with self.test_session():
|
||||
rgb_tf = array_ops.placeholder(dtypes.uint8, shape=rgb_shape)
|
||||
gray = image_ops.rgb_to_grayscale(rgb_tf)
|
||||
self.assertEqual(gray_shape, gray.get_shape().as_list())
|
||||
|
||||
with self.test_session():
|
||||
gray_tf = array_ops.placeholder(dtypes.uint8, shape=gray_shape)
|
||||
rgb = image_ops.grayscale_to_rgb(gray_tf)
|
||||
self.assertEqual(rgb_shape, rgb.get_shape().as_list())
|
||||
|
||||
# Shape inference does not break for unknown shapes
|
||||
with self.test_session():
|
||||
rgb_tf_unknown = array_ops.placeholder(dtypes.uint8)
|
||||
gray_unknown = image_ops.rgb_to_grayscale(rgb_tf_unknown)
|
||||
self.assertFalse(gray_unknown.get_shape())
|
||||
|
||||
with self.test_session():
|
||||
gray_tf_unknown = array_ops.placeholder(dtypes.uint8)
|
||||
rgb_unknown = image_ops.grayscale_to_rgb(gray_tf_unknown)
|
||||
self.assertFalse(rgb_unknown.get_shape())
|
||||
|
||||
|
||||
class AdjustHueTest(test_util.TensorFlowTestCase):
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user