Merge pull request #28274 from yongtang:28241-ssim_multiscale
PiperOrigin-RevId: 245969013
This commit is contained in:
commit
3375fdf82e
@ -3098,12 +3098,11 @@ def ssim_multiscale(img1,
|
|||||||
are in range [0, 1]. Returns a tensor with shape:
|
are in range [0, 1]. Returns a tensor with shape:
|
||||||
broadcast(img1.shape[:-3], img2.shape[:-3]).
|
broadcast(img1.shape[:-3], img2.shape[:-3]).
|
||||||
"""
|
"""
|
||||||
# Shape checking.
|
|
||||||
shape1 = img1.get_shape().with_rank_at_least(3)
|
|
||||||
shape2 = img2.get_shape().with_rank_at_least(3)
|
|
||||||
shape1[-3:].merge_with(shape2[-3:])
|
|
||||||
|
|
||||||
with ops.name_scope(None, 'MS-SSIM', [img1, img2]):
|
with ops.name_scope(None, 'MS-SSIM', [img1, img2]):
|
||||||
|
# Convert to tensor if needed.
|
||||||
|
img1 = ops.convert_to_tensor(img1, name='img1')
|
||||||
|
img2 = ops.convert_to_tensor(img2, name='img2')
|
||||||
|
# Shape checking.
|
||||||
shape1, shape2, checks = _verify_compatible_image_shapes(img1, img2)
|
shape1, shape2, checks = _verify_compatible_image_shapes(img1, img2)
|
||||||
with ops.control_dependencies(checks):
|
with ops.control_dependencies(checks):
|
||||||
img1 = array_ops.identity(img1)
|
img1 = array_ops.identity(img1)
|
||||||
|
@ -4895,6 +4895,13 @@ class MultiscaleSSIMTest(test_util.TensorFlowTestCase):
|
|||||||
self.assertAllClose(
|
self.assertAllClose(
|
||||||
ssim_uint8.eval(), self.evaluate(ssim_float32), atol=0.001)
|
ssim_uint8.eval(), self.evaluate(ssim_float32), atol=0.001)
|
||||||
|
|
||||||
|
def testNumpyInput(self):
|
||||||
|
"""Test case for GitHub issue 28241."""
|
||||||
|
image = np.random.random([512, 512, 1])
|
||||||
|
score_tensor = image_ops.ssim_multiscale(image, image, max_val=1.0)
|
||||||
|
with self.cached_session(use_gpu=True):
|
||||||
|
_ = self.evaluate(score_tensor)
|
||||||
|
|
||||||
|
|
||||||
class ImageGradientsTest(test_util.TensorFlowTestCase):
|
class ImageGradientsTest(test_util.TensorFlowTestCase):
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user