Enable tf.image.ssim to run on tensor-like inputs and add SSIM name space.

PiperOrigin-RevId: 317676190
Change-Id: I08408eaf397ace235b5f50513096cbd9ba46d5a8
This commit is contained in:
A. Unique TensorFlower 2020-06-22 10:01:26 -07:00 committed by TensorFlower Gardener
parent 251923169d
commit 834f2bd726
2 changed files with 41 additions and 13 deletions

View File

@ -3642,6 +3642,11 @@ def ssim(img1,
values are in range (-1, 1], when pixel values are non-negative. Returns values are in range (-1, 1], when pixel values are non-negative. Returns
a tensor with shape: broadcast(img1.shape[:-3], img2.shape[:-3]). a tensor with shape: broadcast(img1.shape[:-3], img2.shape[:-3]).
""" """
with ops.name_scope(None, 'SSIM', [img1, img2]):
# Convert to tensor if needed.
img1 = ops.convert_to_tensor(img1, name='img1')
img2 = ops.convert_to_tensor(img2, name='img2')
# Shape checking.
_, _, checks = _verify_compatible_image_shapes(img1, img2) _, _, checks = _verify_compatible_image_shapes(img1, img2)
with ops.control_dependencies(checks): with ops.control_dependencies(checks):
img1 = array_ops.identity(img1) img1 = array_ops.identity(img1)

View File

@ -4865,6 +4865,29 @@ class SSIMTest(test_util.TensorFlowTestCase):
with self.cached_session(use_gpu=True): with self.cached_session(use_gpu=True):
self.assertAllClose(expected, self.evaluate(ssim), atol=1e-4) self.assertAllClose(expected, self.evaluate(ssim), atol=1e-4)
def testBatchNumpyInputs(self):
img = self._LoadTestImages()
expected = self._ssim[np.triu_indices(3, k=1)]
img1, img2 = zip(*itertools.combinations(img, 2))
img1 = np.concatenate(img1)
img2 = np.concatenate(img2)
with self.cached_session(use_gpu=True):
img1 = self.evaluate(constant_op.constant(img1))
img2 = self.evaluate(constant_op.constant(img2))
ssim = image_ops.ssim(
img1,
img2,
1.0,
filter_size=11,
filter_sigma=1.5,
k1=0.01,
k2=0.03)
with self.cached_session(use_gpu=True):
self.assertAllClose(expected, self.evaluate(ssim), atol=1e-4)
def testBroadcast(self): def testBroadcast(self):
img = self._LoadTestImages()[:2] img = self._LoadTestImages()[:2]
expected = self._ssim[:2, :2] expected = self._ssim[:2, :2]