Explicitly setting shape in grayscale_to_rgb and rgb_to_grayscale.

Fixes shape inference following a previous change that generalized these two members of image_ops. The more general implementation, however, does not allow shape inference. This change explicitly fixes the shape with set_shape, following array_grad._TileGrad.

Adding tests for expected shape inference behavior.
Change: 111459347
This commit is contained in:
A. Unique TensorFlower 2016-01-06 12:51:57 -08:00 committed by Vijay Vasudevan
parent 96689166e9
commit be39348393
2 changed files with 29 additions and 2 deletions

View File

@ -966,7 +966,7 @@ def rgb_to_grayscale(images):
gray_float = math_ops.reduce_sum(flt_image * rgb_weights,
rank_1,
keep_dims=True)
gray_float.set_shape(images.get_shape()[:-1].concatenate([1]))
return convert_image_dtype(gray_float, orig_dtype)
@ -988,7 +988,9 @@ def grayscale_to_rgb(images):
[array_ops.ones(rank_1,
dtype=dtypes.int32)] + [array_ops.expand_dims(3, 0)])
multiples = array_ops.concat(0, shape_list)
return array_ops.tile(images, multiples)
rgb = array_ops.tile(images, multiples)
rgb.set_shape(images.get_shape()[:-1].concatenate([3]))
return rgb
# pylint: disable=invalid-name

View File

@ -133,6 +133,31 @@ class GrayscaleToRGBTest(test_util.TensorFlowTestCase):
y_tf = y.eval()
self.assertAllEqual(y_tf, y_np)
def testShapeInference(self):
# Shape inference works and produces expected output where possible
rgb_shape = [7, None, 19, 3]
gray_shape = rgb_shape[:-1] + [1]
with self.test_session():
rgb_tf = array_ops.placeholder(dtypes.uint8, shape=rgb_shape)
gray = image_ops.rgb_to_grayscale(rgb_tf)
self.assertEqual(gray_shape, gray.get_shape().as_list())
with self.test_session():
gray_tf = array_ops.placeholder(dtypes.uint8, shape=gray_shape)
rgb = image_ops.grayscale_to_rgb(gray_tf)
self.assertEqual(rgb_shape, rgb.get_shape().as_list())
# Shape inference does not break for unknown shapes
with self.test_session():
rgb_tf_unknown = array_ops.placeholder(dtypes.uint8)
gray_unknown = image_ops.rgb_to_grayscale(rgb_tf_unknown)
self.assertFalse(gray_unknown.get_shape())
with self.test_session():
gray_tf_unknown = array_ops.placeholder(dtypes.uint8)
rgb_unknown = image_ops.grayscale_to_rgb(gray_tf_unknown)
self.assertFalse(rgb_unknown.get_shape())
class AdjustHueTest(test_util.TensorFlowTestCase):