diff --git a/tensorflow/python/ops/image_ops_impl.py b/tensorflow/python/ops/image_ops_impl.py index a934639d524..683681b5c98 100644 --- a/tensorflow/python/ops/image_ops_impl.py +++ b/tensorflow/python/ops/image_ops_impl.py @@ -3642,20 +3642,25 @@ def ssim(img1, values are in range (-1, 1], when pixel values are non-negative. Returns a tensor with shape: broadcast(img1.shape[:-3], img2.shape[:-3]). """ - _, _, checks = _verify_compatible_image_shapes(img1, img2) - with ops.control_dependencies(checks): - img1 = array_ops.identity(img1) + with ops.name_scope(None, 'SSIM', [img1, img2]): + # Convert to tensor if needed. + img1 = ops.convert_to_tensor(img1, name='img1') + img2 = ops.convert_to_tensor(img2, name='img2') + # Shape checking. + _, _, checks = _verify_compatible_image_shapes(img1, img2) + with ops.control_dependencies(checks): + img1 = array_ops.identity(img1) - # Need to convert the images to float32. Scale max_val accordingly so that - # SSIM is computed correctly. - max_val = math_ops.cast(max_val, img1.dtype) - max_val = convert_image_dtype(max_val, dtypes.float32) - img1 = convert_image_dtype(img1, dtypes.float32) - img2 = convert_image_dtype(img2, dtypes.float32) - ssim_per_channel, _ = _ssim_per_channel(img1, img2, max_val, filter_size, - filter_sigma, k1, k2) - # Compute average over color channels. - return math_ops.reduce_mean(ssim_per_channel, [-1]) + # Need to convert the images to float32. Scale max_val accordingly so that + # SSIM is computed correctly. + max_val = math_ops.cast(max_val, img1.dtype) + max_val = convert_image_dtype(max_val, dtypes.float32) + img1 = convert_image_dtype(img1, dtypes.float32) + img2 = convert_image_dtype(img2, dtypes.float32) + ssim_per_channel, _ = _ssim_per_channel(img1, img2, max_val, filter_size, + filter_sigma, k1, k2) + # Compute average over color channels. + return math_ops.reduce_mean(ssim_per_channel, [-1]) # Default values obtained by Wang et al. diff --git a/tensorflow/python/ops/image_ops_test.py b/tensorflow/python/ops/image_ops_test.py index 3530885fe07..0206ccf9b33 100644 --- a/tensorflow/python/ops/image_ops_test.py +++ b/tensorflow/python/ops/image_ops_test.py @@ -4865,6 +4865,29 @@ class SSIMTest(test_util.TensorFlowTestCase): with self.cached_session(use_gpu=True): self.assertAllClose(expected, self.evaluate(ssim), atol=1e-4) + def testBatchNumpyInputs(self): + img = self._LoadTestImages() + expected = self._ssim[np.triu_indices(3, k=1)] + + img1, img2 = zip(*itertools.combinations(img, 2)) + img1 = np.concatenate(img1) + img2 = np.concatenate(img2) + + with self.cached_session(use_gpu=True): + img1 = self.evaluate(constant_op.constant(img1)) + img2 = self.evaluate(constant_op.constant(img2)) + + ssim = image_ops.ssim( + img1, + img2, + 1.0, + filter_size=11, + filter_sigma=1.5, + k1=0.01, + k2=0.03) + with self.cached_session(use_gpu=True): + self.assertAllClose(expected, self.evaluate(ssim), atol=1e-4) + def testBroadcast(self): img = self._LoadTestImages()[:2] expected = self._ssim[:2, :2]