diff --git a/tensorflow/python/ops/image_ops_impl.py b/tensorflow/python/ops/image_ops_impl.py index 6aeb6f40f1c..c2ebb870f2c 100644 --- a/tensorflow/python/ops/image_ops_impl.py +++ b/tensorflow/python/ops/image_ops_impl.py @@ -3060,12 +3060,11 @@ def ssim_multiscale(img1, img2, max_val, power_factors=_MSSSIM_WEIGHTS): are in range [0, 1]. Returns a tensor with shape: broadcast(img1.shape[:-3], img2.shape[:-3]). """ - # Shape checking. - shape1 = img1.get_shape().with_rank_at_least(3) - shape2 = img2.get_shape().with_rank_at_least(3) - shape1[-3:].merge_with(shape2[-3:]) - with ops.name_scope(None, 'MS-SSIM', [img1, img2]): + # Convert to tensor if needed. + img1 = ops.convert_to_tensor(img1, name='img1') + img2 = ops.convert_to_tensor(img2, name='img2') + # Shape checking. shape1, shape2, checks = _verify_compatible_image_shapes(img1, img2) with ops.control_dependencies(checks): img1 = array_ops.identity(img1)