Eager execution coverage for extract_image_patches_grad_test.py. Removed run_deprecated_v1 decorators.
PiperOrigin-RevId: 349505498 Change-Id: I4a43b1be685d1fef4a9e7ed91e265617c7187feb
This commit is contained in:
parent
d3bba8c715
commit
3821342892
tensorflow/python/kernel_tests
@ -3229,7 +3229,7 @@ cuda_py_test(
|
||||
name = "extract_image_patches_grad_test",
|
||||
size = "medium",
|
||||
srcs = ["extract_image_patches_grad_test.py"],
|
||||
shard_count = 3,
|
||||
shard_count = 9,
|
||||
tags = ["notap"], # http://b/31080670
|
||||
deps = [
|
||||
"//tensorflow/python:array_ops",
|
||||
|
@ -18,20 +18,23 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import random_seed as random_seed_lib
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import gradient_checker
|
||||
from tensorflow.python.ops import gradients_impl
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.ops import gradient_checker_v2
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class ExtractImagePatchesGradTest(test.TestCase):
|
||||
class ExtractImagePatchesGradTest(test.TestCase, parameterized.TestCase):
|
||||
"""Gradient-checking for ExtractImagePatches op."""
|
||||
|
||||
_TEST_CASES = [
|
||||
@ -79,7 +82,6 @@ class ExtractImagePatchesGradTest(test.TestCase):
|
||||
},
|
||||
]
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testGradient(self):
|
||||
# Set graph seed for determinism.
|
||||
random_seed = 42
|
||||
@ -91,80 +93,94 @@ class ExtractImagePatchesGradTest(test.TestCase):
|
||||
in_shape = test_case['in_shape']
|
||||
in_val = constant_op.constant(
|
||||
np.random.random(in_shape), dtype=dtypes.float32)
|
||||
# Avoid `dangerous-default-value` pylint error by creating default
|
||||
# args to `extract` as tuples.
|
||||
ksizes = tuple(test_case['ksizes'])
|
||||
strides = tuple(test_case['strides'])
|
||||
rates = tuple(test_case['rates'])
|
||||
|
||||
for padding in ['VALID', 'SAME']:
|
||||
out_val = array_ops.extract_image_patches(in_val, test_case['ksizes'],
|
||||
test_case['strides'],
|
||||
test_case['rates'], padding)
|
||||
out_shape = out_val.get_shape().as_list()
|
||||
|
||||
err = gradient_checker.compute_gradient_error(in_val, in_shape,
|
||||
out_val, out_shape)
|
||||
def extract(in_val,
|
||||
ksizes=ksizes,
|
||||
strides=strides,
|
||||
rates=rates,
|
||||
padding=padding):
|
||||
return array_ops.extract_image_patches(in_val, ksizes, strides,
|
||||
rates, padding)
|
||||
|
||||
err = gradient_checker_v2.max_error(
|
||||
*gradient_checker_v2.compute_gradient(extract, [in_val]))
|
||||
self.assertLess(err, 1e-4)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testConstructGradientWithLargeImages(self):
|
||||
batch_size = 4
|
||||
height = 1024
|
||||
width = 1024
|
||||
ksize = 5
|
||||
images = variable_scope.get_variable('inputs',
|
||||
(batch_size, height, width, 1))
|
||||
patches = array_ops.extract_image_patches(images,
|
||||
ksizes=[1, ksize, ksize, 1],
|
||||
strides=[1, 1, 1, 1],
|
||||
rates=[1, 1, 1, 1],
|
||||
padding='SAME')
|
||||
# Github issue: #20146
|
||||
# tf.image.extract_image_patches() gradient very slow at graph construction
|
||||
# time
|
||||
gradients = gradients_impl.gradients(patches, images)
|
||||
# Won't time out.
|
||||
self.assertIsNotNone(gradients)
|
||||
@parameterized.parameters(set((True, context.executing_eagerly())))
|
||||
def testConstructGradientWithLargeImages(self, use_tape):
|
||||
with test_util.AbstractGradientTape(use_tape=use_tape) as tape:
|
||||
batch_size = 4
|
||||
# Prevent OOM by setting reasonably large image size (b/171808681).
|
||||
height = 512
|
||||
width = 512
|
||||
ksize = 5
|
||||
shape = (batch_size, height, width, 1)
|
||||
images = variables.Variable(
|
||||
np.random.uniform(size=np.prod(shape)).reshape(shape), name='inputs')
|
||||
tape.watch(images)
|
||||
patches = array_ops.extract_image_patches(images,
|
||||
ksizes=[1, ksize, ksize, 1],
|
||||
strides=[1, 1, 1, 1],
|
||||
rates=[1, 1, 1, 1],
|
||||
padding='SAME')
|
||||
# Github issue: #20146
|
||||
# tf.image.extract_image_patches() gradient very slow at graph
|
||||
# construction time.
|
||||
gradients = tape.gradient(patches, images)
|
||||
# Won't time out.
|
||||
self.assertIsNotNone(gradients)
|
||||
|
||||
def _VariableShapeGradient(self, test_shape_pattern):
|
||||
"""Use test_shape_pattern to infer which dimensions are of
|
||||
|
||||
variable size.
|
||||
"""
|
||||
# Set graph seed for determinism.
|
||||
random_seed = 42
|
||||
random_seed_lib.set_random_seed(random_seed)
|
||||
# Testing shape gradient requires graph mode.
|
||||
with ops.Graph().as_default():
|
||||
# Set graph seed for determinism.
|
||||
random_seed = 42
|
||||
random_seed_lib.set_random_seed(random_seed)
|
||||
|
||||
with self.test_session():
|
||||
for test_case in self._TEST_CASES:
|
||||
np.random.seed(random_seed)
|
||||
in_shape = test_case['in_shape']
|
||||
test_shape = [
|
||||
x if x is None else y for x, y in zip(test_shape_pattern, in_shape)
|
||||
]
|
||||
in_val = array_ops.placeholder(shape=test_shape, dtype=dtypes.float32)
|
||||
with self.test_session():
|
||||
for test_case in self._TEST_CASES:
|
||||
np.random.seed(random_seed)
|
||||
in_shape = test_case['in_shape']
|
||||
test_shape = [
|
||||
x if x is None else y
|
||||
for x, y in zip(test_shape_pattern, in_shape)
|
||||
]
|
||||
in_val = array_ops.placeholder(shape=test_shape, dtype=dtypes.float32)
|
||||
|
||||
feed_dict = {in_val: np.random.random(in_shape)}
|
||||
for padding in ['VALID', 'SAME']:
|
||||
out_val = array_ops.extract_image_patches(in_val, test_case['ksizes'],
|
||||
test_case['strides'],
|
||||
test_case['rates'], padding)
|
||||
out_val_tmp = out_val.eval(feed_dict=feed_dict)
|
||||
out_shape = out_val_tmp.shape
|
||||
feed_dict = {in_val: np.random.random(in_shape)}
|
||||
for padding in ['VALID', 'SAME']:
|
||||
out_val = array_ops.extract_image_patches(in_val,
|
||||
test_case['ksizes'],
|
||||
test_case['strides'],
|
||||
test_case['rates'],
|
||||
padding)
|
||||
out_val_tmp = out_val.eval(feed_dict=feed_dict)
|
||||
out_shape = out_val_tmp.shape
|
||||
|
||||
err = gradient_checker.compute_gradient_error(in_val, in_shape,
|
||||
out_val, out_shape)
|
||||
self.assertLess(err, 1e-4)
|
||||
err = gradient_checker.compute_gradient_error(
|
||||
in_val, in_shape, out_val, out_shape)
|
||||
self.assertLess(err, 1e-4)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def test_BxxC_Gradient(self):
|
||||
self._VariableShapeGradient([-1, None, None, -1])
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def test_xHWx_Gradient(self):
|
||||
self._VariableShapeGradient([None, -1, -1, None])
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def test_BHWC_Gradient(self):
|
||||
self._VariableShapeGradient([-1, -1, -1, -1])
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def test_AllNone_Gradient(self):
|
||||
self._VariableShapeGradient([None, None, None, None])
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user