From 38213428929862120bdb32b963a600792ec6feb2 Mon Sep 17 00:00:00 2001
From: Hye Soo Yang <hyey@google.com>
Date: Tue, 29 Dec 2020 20:44:36 -0800
Subject: [PATCH] Eager execution coverage for
 extract_image_patches_grad_test.py. Removed  run_deprecated_v1 decorators.

PiperOrigin-RevId: 349505498
Change-Id: I4a43b1be685d1fef4a9e7ed91e265617c7187feb
---
 tensorflow/python/kernel_tests/BUILD          |   2 +-
 .../extract_image_patches_grad_test.py        | 124 ++++++++++--------
 2 files changed, 71 insertions(+), 55 deletions(-)

diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index dc00408b9d8..e934d5c4134 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -3229,7 +3229,7 @@ cuda_py_test(
     name = "extract_image_patches_grad_test",
     size = "medium",
     srcs = ["extract_image_patches_grad_test.py"],
-    shard_count = 3,
+    shard_count = 9,
     tags = ["notap"],  # http://b/31080670
     deps = [
         "//tensorflow/python:array_ops",
diff --git a/tensorflow/python/kernel_tests/extract_image_patches_grad_test.py b/tensorflow/python/kernel_tests/extract_image_patches_grad_test.py
index 694bd056037..88bcebde4a7 100644
--- a/tensorflow/python/kernel_tests/extract_image_patches_grad_test.py
+++ b/tensorflow/python/kernel_tests/extract_image_patches_grad_test.py
@@ -18,20 +18,23 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from absl.testing import parameterized
 import numpy as np
 
+from tensorflow.python.eager import context
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
 from tensorflow.python.framework import random_seed as random_seed_lib
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gradient_checker
-from tensorflow.python.ops import gradients_impl
-from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import gradient_checker_v2
+from tensorflow.python.ops import variables
 from tensorflow.python.platform import test
 
 
-class ExtractImagePatchesGradTest(test.TestCase):
+class ExtractImagePatchesGradTest(test.TestCase, parameterized.TestCase):
   """Gradient-checking for ExtractImagePatches op."""
 
   _TEST_CASES = [
@@ -79,7 +82,6 @@ class ExtractImagePatchesGradTest(test.TestCase):
       },
   ]
 
-  @test_util.run_deprecated_v1
   def testGradient(self):
     # Set graph seed for determinism.
     random_seed = 42
@@ -91,80 +93,94 @@ class ExtractImagePatchesGradTest(test.TestCase):
         in_shape = test_case['in_shape']
         in_val = constant_op.constant(
             np.random.random(in_shape), dtype=dtypes.float32)
+        # Avoid `dangerous-default-value` pylint error by creating default
+        # args to `extract` as tuples.
+        ksizes = tuple(test_case['ksizes'])
+        strides = tuple(test_case['strides'])
+        rates = tuple(test_case['rates'])
 
         for padding in ['VALID', 'SAME']:
-          out_val = array_ops.extract_image_patches(in_val, test_case['ksizes'],
-                                                    test_case['strides'],
-                                                    test_case['rates'], padding)
-          out_shape = out_val.get_shape().as_list()
 
-          err = gradient_checker.compute_gradient_error(in_val, in_shape,
-                                                        out_val, out_shape)
+          def extract(in_val,
+                      ksizes=ksizes,
+                      strides=strides,
+                      rates=rates,
+                      padding=padding):
+            return array_ops.extract_image_patches(in_val, ksizes, strides,
+                                                   rates, padding)
+
+          err = gradient_checker_v2.max_error(
+              *gradient_checker_v2.compute_gradient(extract, [in_val]))
           self.assertLess(err, 1e-4)
 
-  @test_util.run_deprecated_v1
-  def testConstructGradientWithLargeImages(self):
-    batch_size = 4
-    height = 1024
-    width = 1024
-    ksize = 5
-    images = variable_scope.get_variable('inputs',
-                                         (batch_size, height, width, 1))
-    patches = array_ops.extract_image_patches(images,
-                                              ksizes=[1, ksize, ksize, 1],
-                                              strides=[1, 1, 1, 1],
-                                              rates=[1, 1, 1, 1],
-                                              padding='SAME')
-    # Github issue: #20146
-    # tf.image.extract_image_patches() gradient very slow at graph construction
-    # time
-    gradients = gradients_impl.gradients(patches, images)
-    # Won't time out.
-    self.assertIsNotNone(gradients)
+  @parameterized.parameters(set((True, context.executing_eagerly())))
+  def testConstructGradientWithLargeImages(self, use_tape):
+    with test_util.AbstractGradientTape(use_tape=use_tape) as tape:
+      batch_size = 4
+      # Prevent OOM by setting reasonably large image size (b/171808681).
+      height = 512
+      width = 512
+      ksize = 5
+      shape = (batch_size, height, width, 1)
+      images = variables.Variable(
+          np.random.uniform(size=np.prod(shape)).reshape(shape), name='inputs')
+      tape.watch(images)
+      patches = array_ops.extract_image_patches(images,
+                                                ksizes=[1, ksize, ksize, 1],
+                                                strides=[1, 1, 1, 1],
+                                                rates=[1, 1, 1, 1],
+                                                padding='SAME')
+      # Github issue: #20146
+      # tf.image.extract_image_patches() gradient very slow at graph
+      # construction time.
+      gradients = tape.gradient(patches, images)
+      # Won't time out.
+      self.assertIsNotNone(gradients)
 
   def _VariableShapeGradient(self, test_shape_pattern):
     """Use test_shape_pattern to infer which dimensions are of
 
     variable size.
     """
-    # Set graph seed for determinism.
-    random_seed = 42
-    random_seed_lib.set_random_seed(random_seed)
+    # Testing shape gradient requires graph mode.
+    with ops.Graph().as_default():
+      # Set graph seed for determinism.
+      random_seed = 42
+      random_seed_lib.set_random_seed(random_seed)
 
-    with self.test_session():
-      for test_case in self._TEST_CASES:
-        np.random.seed(random_seed)
-        in_shape = test_case['in_shape']
-        test_shape = [
-            x if x is None else y for x, y in zip(test_shape_pattern, in_shape)
-        ]
-        in_val = array_ops.placeholder(shape=test_shape, dtype=dtypes.float32)
+      with self.test_session():
+        for test_case in self._TEST_CASES:
+          np.random.seed(random_seed)
+          in_shape = test_case['in_shape']
+          test_shape = [
+              x if x is None else y
+              for x, y in zip(test_shape_pattern, in_shape)
+          ]
+          in_val = array_ops.placeholder(shape=test_shape, dtype=dtypes.float32)
 
-        feed_dict = {in_val: np.random.random(in_shape)}
-        for padding in ['VALID', 'SAME']:
-          out_val = array_ops.extract_image_patches(in_val, test_case['ksizes'],
-                                                    test_case['strides'],
-                                                    test_case['rates'], padding)
-          out_val_tmp = out_val.eval(feed_dict=feed_dict)
-          out_shape = out_val_tmp.shape
+          feed_dict = {in_val: np.random.random(in_shape)}
+          for padding in ['VALID', 'SAME']:
+            out_val = array_ops.extract_image_patches(in_val,
+                                                      test_case['ksizes'],
+                                                      test_case['strides'],
+                                                      test_case['rates'],
+                                                      padding)
+            out_val_tmp = out_val.eval(feed_dict=feed_dict)
+            out_shape = out_val_tmp.shape
 
-          err = gradient_checker.compute_gradient_error(in_val, in_shape,
-                                                        out_val, out_shape)
-          self.assertLess(err, 1e-4)
+            err = gradient_checker.compute_gradient_error(
+                in_val, in_shape, out_val, out_shape)
+            self.assertLess(err, 1e-4)
 
-  @test_util.run_deprecated_v1
   def test_BxxC_Gradient(self):
     self._VariableShapeGradient([-1, None, None, -1])
 
-  @test_util.run_deprecated_v1
   def test_xHWx_Gradient(self):
     self._VariableShapeGradient([None, -1, -1, None])
 
-  @test_util.run_deprecated_v1
   def test_BHWC_Gradient(self):
     self._VariableShapeGradient([-1, -1, -1, -1])
 
-  @test_util.run_deprecated_v1
   def test_AllNone_Gradient(self):
     self._VariableShapeGradient([None, None, None, None])