Fix tf.image.ssim_multiscale execution issue

This fix tries to address the issue raised in 28241 where
tf.image.ssim_multiscale will throw out error if the input
is numpy array instead of tensor.

This fix adds the conversion (also removed duplicated shape checking).

This fix fixes 28241.

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
This commit is contained in:
Yong Tang 2019-04-30 01:08:19 +00:00
parent b6e4bf24eb
commit 6c3cbe4434

View File

@ -3060,12 +3060,11 @@ def ssim_multiscale(img1, img2, max_val, power_factors=_MSSSIM_WEIGHTS):
are in range [0, 1]. Returns a tensor with shape:
broadcast(img1.shape[:-3], img2.shape[:-3]).
"""
# Shape checking.
shape1 = img1.get_shape().with_rank_at_least(3)
shape2 = img2.get_shape().with_rank_at_least(3)
shape1[-3:].merge_with(shape2[-3:])
with ops.name_scope(None, 'MS-SSIM', [img1, img2]):
# Convert to tensor if needed.
img1 = ops.convert_to_tensor(img1, name='img1')
img2 = ops.convert_to_tensor(img2, name='img2')
# Shape checking.
shape1, shape2, checks = _verify_compatible_image_shapes(img1, img2)
with ops.control_dependencies(checks):
img1 = array_ops.identity(img1)