Merge pull request #28274 from yongtang:28241-ssim_multiscale
PiperOrigin-RevId: 245969013
This commit is contained in:
commit
3375fdf82e
@ -3098,12 +3098,11 @@ def ssim_multiscale(img1,
|
||||
are in range [0, 1]. Returns a tensor with shape:
|
||||
broadcast(img1.shape[:-3], img2.shape[:-3]).
|
||||
"""
|
||||
# Shape checking.
|
||||
shape1 = img1.get_shape().with_rank_at_least(3)
|
||||
shape2 = img2.get_shape().with_rank_at_least(3)
|
||||
shape1[-3:].merge_with(shape2[-3:])
|
||||
|
||||
with ops.name_scope(None, 'MS-SSIM', [img1, img2]):
|
||||
# Convert to tensor if needed.
|
||||
img1 = ops.convert_to_tensor(img1, name='img1')
|
||||
img2 = ops.convert_to_tensor(img2, name='img2')
|
||||
# Shape checking.
|
||||
shape1, shape2, checks = _verify_compatible_image_shapes(img1, img2)
|
||||
with ops.control_dependencies(checks):
|
||||
img1 = array_ops.identity(img1)
|
||||
|
@ -4895,6 +4895,13 @@ class MultiscaleSSIMTest(test_util.TensorFlowTestCase):
|
||||
self.assertAllClose(
|
||||
ssim_uint8.eval(), self.evaluate(ssim_float32), atol=0.001)
|
||||
|
||||
def testNumpyInput(self):
|
||||
"""Test case for GitHub issue 28241."""
|
||||
image = np.random.random([512, 512, 1])
|
||||
score_tensor = image_ops.ssim_multiscale(image, image, max_val=1.0)
|
||||
with self.cached_session(use_gpu=True):
|
||||
_ = self.evaluate(score_tensor)
|
||||
|
||||
|
||||
class ImageGradientsTest(test_util.TensorFlowTestCase):
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user