Enable tf.image.ssim to run on tensor-like inputs and add an SSIM name scope.
PiperOrigin-RevId: 317676190 Change-Id: I08408eaf397ace235b5f50513096cbd9ba46d5a8
This commit is contained in:
parent
251923169d
commit
834f2bd726
|
@ -3642,20 +3642,25 @@ def ssim(img1,
|
|||
values are in range (-1, 1], when pixel values are non-negative. Returns
|
||||
a tensor with shape: broadcast(img1.shape[:-3], img2.shape[:-3]).
|
||||
"""
|
||||
_, _, checks = _verify_compatible_image_shapes(img1, img2)
|
||||
with ops.control_dependencies(checks):
|
||||
img1 = array_ops.identity(img1)
|
||||
with ops.name_scope(None, 'SSIM', [img1, img2]):
|
||||
# Convert to tensor if needed.
|
||||
img1 = ops.convert_to_tensor(img1, name='img1')
|
||||
img2 = ops.convert_to_tensor(img2, name='img2')
|
||||
# Shape checking.
|
||||
_, _, checks = _verify_compatible_image_shapes(img1, img2)
|
||||
with ops.control_dependencies(checks):
|
||||
img1 = array_ops.identity(img1)
|
||||
|
||||
# Need to convert the images to float32. Scale max_val accordingly so that
|
||||
# SSIM is computed correctly.
|
||||
max_val = math_ops.cast(max_val, img1.dtype)
|
||||
max_val = convert_image_dtype(max_val, dtypes.float32)
|
||||
img1 = convert_image_dtype(img1, dtypes.float32)
|
||||
img2 = convert_image_dtype(img2, dtypes.float32)
|
||||
ssim_per_channel, _ = _ssim_per_channel(img1, img2, max_val, filter_size,
|
||||
filter_sigma, k1, k2)
|
||||
# Compute average over color channels.
|
||||
return math_ops.reduce_mean(ssim_per_channel, [-1])
|
||||
# Need to convert the images to float32. Scale max_val accordingly so that
|
||||
# SSIM is computed correctly.
|
||||
max_val = math_ops.cast(max_val, img1.dtype)
|
||||
max_val = convert_image_dtype(max_val, dtypes.float32)
|
||||
img1 = convert_image_dtype(img1, dtypes.float32)
|
||||
img2 = convert_image_dtype(img2, dtypes.float32)
|
||||
ssim_per_channel, _ = _ssim_per_channel(img1, img2, max_val, filter_size,
|
||||
filter_sigma, k1, k2)
|
||||
# Compute average over color channels.
|
||||
return math_ops.reduce_mean(ssim_per_channel, [-1])
|
||||
|
||||
|
||||
# Default values obtained by Wang et al.
|
||||
|
|
|
@ -4865,6 +4865,29 @@ class SSIMTest(test_util.TensorFlowTestCase):
|
|||
with self.cached_session(use_gpu=True):
|
||||
self.assertAllClose(expected, self.evaluate(ssim), atol=1e-4)
|
||||
|
||||
def testBatchNumpyInputs(self):
  """Checks image_ops.ssim on batched inputs that are not graph tensors.

  Every unordered pair of the loaded test images is stacked into two
  batches, the batches are round-tripped through constant/evaluate, and the
  resulting SSIM values are compared against the reference entries in
  self._ssim.
  """
  images = self._LoadTestImages()
  # Strict upper triangle of the reference table; the hard-coded 3 assumes
  # three test images are loaded -- TODO confirm against _LoadTestImages.
  want = self._ssim[np.triu_indices(3, k=1)]

  firsts, seconds = zip(*itertools.combinations(images, 2))
  batch1 = np.concatenate(firsts)
  batch2 = np.concatenate(seconds)

  # Materialize the batches via the session so that ssim is handed evaluated
  # values (presumably NumPy arrays) rather than tensors.
  with self.cached_session(use_gpu=True):
    batch1 = self.evaluate(constant_op.constant(batch1))
    batch2 = self.evaluate(constant_op.constant(batch2))

  ssim_options = dict(filter_size=11, filter_sigma=1.5, k1=0.01, k2=0.03)
  result = image_ops.ssim(batch1, batch2, 1.0, **ssim_options)
  with self.cached_session(use_gpu=True):
    self.assertAllClose(want, self.evaluate(result), atol=1e-4)
|
||||
|
||||
def testBroadcast(self):
|
||||
img = self._LoadTestImages()[:2]
|
||||
expected = self._ssim[:2, :2]
|
||||
|
|
Loading…
Reference in New Issue