Enable tf.image.ssim to run on tensor-like inputs and add SSIM name space.
PiperOrigin-RevId: 317676190 Change-Id: I08408eaf397ace235b5f50513096cbd9ba46d5a8
This commit is contained in:
parent
251923169d
commit
834f2bd726
|
@ -3642,6 +3642,11 @@ def ssim(img1,
|
||||||
values are in range (-1, 1], when pixel values are non-negative. Returns
|
values are in range (-1, 1], when pixel values are non-negative. Returns
|
||||||
a tensor with shape: broadcast(img1.shape[:-3], img2.shape[:-3]).
|
a tensor with shape: broadcast(img1.shape[:-3], img2.shape[:-3]).
|
||||||
"""
|
"""
|
||||||
|
with ops.name_scope(None, 'SSIM', [img1, img2]):
|
||||||
|
# Convert to tensor if needed.
|
||||||
|
img1 = ops.convert_to_tensor(img1, name='img1')
|
||||||
|
img2 = ops.convert_to_tensor(img2, name='img2')
|
||||||
|
# Shape checking.
|
||||||
_, _, checks = _verify_compatible_image_shapes(img1, img2)
|
_, _, checks = _verify_compatible_image_shapes(img1, img2)
|
||||||
with ops.control_dependencies(checks):
|
with ops.control_dependencies(checks):
|
||||||
img1 = array_ops.identity(img1)
|
img1 = array_ops.identity(img1)
|
||||||
|
|
|
@ -4865,6 +4865,29 @@ class SSIMTest(test_util.TensorFlowTestCase):
|
||||||
with self.cached_session(use_gpu=True):
|
with self.cached_session(use_gpu=True):
|
||||||
self.assertAllClose(expected, self.evaluate(ssim), atol=1e-4)
|
self.assertAllClose(expected, self.evaluate(ssim), atol=1e-4)
|
||||||
|
|
||||||
|
def testBatchNumpyInputs(self):
|
||||||
|
img = self._LoadTestImages()
|
||||||
|
expected = self._ssim[np.triu_indices(3, k=1)]
|
||||||
|
|
||||||
|
img1, img2 = zip(*itertools.combinations(img, 2))
|
||||||
|
img1 = np.concatenate(img1)
|
||||||
|
img2 = np.concatenate(img2)
|
||||||
|
|
||||||
|
with self.cached_session(use_gpu=True):
|
||||||
|
img1 = self.evaluate(constant_op.constant(img1))
|
||||||
|
img2 = self.evaluate(constant_op.constant(img2))
|
||||||
|
|
||||||
|
ssim = image_ops.ssim(
|
||||||
|
img1,
|
||||||
|
img2,
|
||||||
|
1.0,
|
||||||
|
filter_size=11,
|
||||||
|
filter_sigma=1.5,
|
||||||
|
k1=0.01,
|
||||||
|
k2=0.03)
|
||||||
|
with self.cached_session(use_gpu=True):
|
||||||
|
self.assertAllClose(expected, self.evaluate(ssim), atol=1e-4)
|
||||||
|
|
||||||
def testBroadcast(self):
|
def testBroadcast(self):
|
||||||
img = self._LoadTestImages()[:2]
|
img = self._LoadTestImages()[:2]
|
||||||
expected = self._ssim[:2, :2]
|
expected = self._ssim[:2, :2]
|
||||||
|
|
Loading…
Reference in New Issue