Enable tf.image.ssim to run on tensor-like inputs and add SSIM name space.

PiperOrigin-RevId: 317676190
Change-Id: I08408eaf397ace235b5f50513096cbd9ba46d5a8
This commit is contained in:
A. Unique TensorFlower 2020-06-22 10:01:26 -07:00 committed by TensorFlower Gardener
parent 251923169d
commit 834f2bd726
2 changed files with 41 additions and 13 deletions

View File

@ -3642,20 +3642,25 @@ def ssim(img1,
values are in range (-1, 1], when pixel values are non-negative. Returns
a tensor with shape: broadcast(img1.shape[:-3], img2.shape[:-3]).
"""
_, _, checks = _verify_compatible_image_shapes(img1, img2)
with ops.control_dependencies(checks):
img1 = array_ops.identity(img1)
with ops.name_scope(None, 'SSIM', [img1, img2]):
# Convert to tensor if needed.
img1 = ops.convert_to_tensor(img1, name='img1')
img2 = ops.convert_to_tensor(img2, name='img2')
# Shape checking.
_, _, checks = _verify_compatible_image_shapes(img1, img2)
with ops.control_dependencies(checks):
img1 = array_ops.identity(img1)
# Need to convert the images to float32. Scale max_val accordingly so that
# SSIM is computed correctly.
max_val = math_ops.cast(max_val, img1.dtype)
max_val = convert_image_dtype(max_val, dtypes.float32)
img1 = convert_image_dtype(img1, dtypes.float32)
img2 = convert_image_dtype(img2, dtypes.float32)
ssim_per_channel, _ = _ssim_per_channel(img1, img2, max_val, filter_size,
filter_sigma, k1, k2)
# Compute average over color channels.
return math_ops.reduce_mean(ssim_per_channel, [-1])
# Need to convert the images to float32. Scale max_val accordingly so that
# SSIM is computed correctly.
max_val = math_ops.cast(max_val, img1.dtype)
max_val = convert_image_dtype(max_val, dtypes.float32)
img1 = convert_image_dtype(img1, dtypes.float32)
img2 = convert_image_dtype(img2, dtypes.float32)
ssim_per_channel, _ = _ssim_per_channel(img1, img2, max_val, filter_size,
filter_sigma, k1, k2)
# Compute average over color channels.
return math_ops.reduce_mean(ssim_per_channel, [-1])
# Default values obtained by Wang et al.

View File

@ -4865,6 +4865,29 @@ class SSIMTest(test_util.TensorFlowTestCase):
with self.cached_session(use_gpu=True):
self.assertAllClose(expected, self.evaluate(ssim), atol=1e-4)
def testBatchNumpyInputs(self):
img = self._LoadTestImages()
expected = self._ssim[np.triu_indices(3, k=1)]
img1, img2 = zip(*itertools.combinations(img, 2))
img1 = np.concatenate(img1)
img2 = np.concatenate(img2)
with self.cached_session(use_gpu=True):
img1 = self.evaluate(constant_op.constant(img1))
img2 = self.evaluate(constant_op.constant(img2))
ssim = image_ops.ssim(
img1,
img2,
1.0,
filter_size=11,
filter_sigma=1.5,
k1=0.01,
k2=0.03)
with self.cached_session(use_gpu=True):
self.assertAllClose(expected, self.evaluate(ssim), atol=1e-4)
def testBroadcast(self):
img = self._LoadTestImages()[:2]
expected = self._ssim[:2, :2]