Merge pull request #28274 from yongtang:28241-ssim_multiscale

PiperOrigin-RevId: 245969013
This commit is contained in:
TensorFlower Gardener 2019-04-30 10:14:54 -07:00
commit 3375fdf82e
2 changed files with 11 additions and 5 deletions

View File

@ -3098,12 +3098,11 @@ def ssim_multiscale(img1,
are in range [0, 1]. Returns a tensor with shape:
broadcast(img1.shape[:-3], img2.shape[:-3]).
"""
# Shape checking.
shape1 = img1.get_shape().with_rank_at_least(3)
shape2 = img2.get_shape().with_rank_at_least(3)
shape1[-3:].merge_with(shape2[-3:])
with ops.name_scope(None, 'MS-SSIM', [img1, img2]):
# Convert to tensor if needed.
img1 = ops.convert_to_tensor(img1, name='img1')
img2 = ops.convert_to_tensor(img2, name='img2')
# Shape checking.
shape1, shape2, checks = _verify_compatible_image_shapes(img1, img2)
with ops.control_dependencies(checks):
img1 = array_ops.identity(img1)

View File

@ -4895,6 +4895,13 @@ class MultiscaleSSIMTest(test_util.TensorFlowTestCase):
self.assertAllClose(
ssim_uint8.eval(), self.evaluate(ssim_float32), atol=0.001)
def testNumpyInput(self):
"""Test case for GitHub issue 28241."""
image = np.random.random([512, 512, 1])
score_tensor = image_ops.ssim_multiscale(image, image, max_val=1.0)
with self.cached_session(use_gpu=True):
_ = self.evaluate(score_tensor)
class ImageGradientsTest(test_util.TensorFlowTestCase):