Eager execution coverage for extract_image_patches_grad_test.py. Removes the run_deprecated_v1 decorators.
PiperOrigin-RevId: 349505498
Change-Id: I4a43b1be685d1fef4a9e7ed91e265617c7187feb
parent d3bba8c715
commit 3821342892
@@ -3229,7 +3229,7 @@ cuda_py_test(
     name = "extract_image_patches_grad_test",
     size = "medium",
     srcs = ["extract_image_patches_grad_test.py"],
-    shard_count = 3,
+    shard_count = 9,
    tags = ["notap"],  # http://b/31080670
    deps = [
        "//tensorflow/python:array_ops",
@@ -18,20 +18,23 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from absl.testing import parameterized
 import numpy as np
 
+from tensorflow.python.eager import context
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
 from tensorflow.python.framework import random_seed as random_seed_lib
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gradient_checker
-from tensorflow.python.ops import gradients_impl
-from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import gradient_checker_v2
+from tensorflow.python.ops import variables
 from tensorflow.python.platform import test
 
 
-class ExtractImagePatchesGradTest(test.TestCase):
+class ExtractImagePatchesGradTest(test.TestCase, parameterized.TestCase):
   """Gradient-checking for ExtractImagePatches op."""
 
   _TEST_CASES = [
@@ -79,7 +82,6 @@ class ExtractImagePatchesGradTest(test.TestCase):
       },
   ]
 
-  @test_util.run_deprecated_v1
   def testGradient(self):
     # Set graph seed for determinism.
     random_seed = 42
@@ -91,80 +93,94 @@ class ExtractImagePatchesGradTest(test.TestCase):
       in_shape = test_case['in_shape']
       in_val = constant_op.constant(
           np.random.random(in_shape), dtype=dtypes.float32)
+      # Avoid `dangerous-default-value` pylint error by creating default
+      # args to `extract` as tuples.
+      ksizes = tuple(test_case['ksizes'])
+      strides = tuple(test_case['strides'])
+      rates = tuple(test_case['rates'])
 
       for padding in ['VALID', 'SAME']:
-        out_val = array_ops.extract_image_patches(in_val, test_case['ksizes'],
-                                                   test_case['strides'],
-                                                   test_case['rates'], padding)
-        out_shape = out_val.get_shape().as_list()
-
-        err = gradient_checker.compute_gradient_error(in_val, in_shape,
-                                                      out_val, out_shape)
+
+        def extract(in_val,
+                    ksizes=ksizes,
+                    strides=strides,
+                    rates=rates,
+                    padding=padding):
+          return array_ops.extract_image_patches(in_val, ksizes, strides,
+                                                 rates, padding)
+
+        err = gradient_checker_v2.max_error(
+            *gradient_checker_v2.compute_gradient(extract, [in_val]))
         self.assertLess(err, 1e-4)
 
-  @test_util.run_deprecated_v1
-  def testConstructGradientWithLargeImages(self):
-    batch_size = 4
-    height = 1024
-    width = 1024
-    ksize = 5
-    images = variable_scope.get_variable('inputs',
-                                         (batch_size, height, width, 1))
-    patches = array_ops.extract_image_patches(images,
-                                              ksizes=[1, ksize, ksize, 1],
-                                              strides=[1, 1, 1, 1],
-                                              rates=[1, 1, 1, 1],
-                                              padding='SAME')
-    # Github issue: #20146
-    # tf.image.extract_image_patches() gradient very slow at graph construction
-    # time
-    gradients = gradients_impl.gradients(patches, images)
-    # Won't time out.
-    self.assertIsNotNone(gradients)
+  @parameterized.parameters(set((True, context.executing_eagerly())))
+  def testConstructGradientWithLargeImages(self, use_tape):
+    with test_util.AbstractGradientTape(use_tape=use_tape) as tape:
+      batch_size = 4
+      # Prevent OOM by setting reasonably large image size (b/171808681).
+      height = 512
+      width = 512
+      ksize = 5
+      shape = (batch_size, height, width, 1)
+      images = variables.Variable(
+          np.random.uniform(size=np.prod(shape)).reshape(shape), name='inputs')
+      tape.watch(images)
+      patches = array_ops.extract_image_patches(images,
+                                                ksizes=[1, ksize, ksize, 1],
+                                                strides=[1, 1, 1, 1],
+                                                rates=[1, 1, 1, 1],
+                                                padding='SAME')
+      # Github issue: #20146
+      # tf.image.extract_image_patches() gradient very slow at graph
+      # construction time.
+      gradients = tape.gradient(patches, images)
+      # Won't time out.
+      self.assertIsNotNone(gradients)
 
   def _VariableShapeGradient(self, test_shape_pattern):
     """Use test_shape_pattern to infer which dimensions are of
 
     variable size.
     """
-    # Set graph seed for determinism.
-    random_seed = 42
-    random_seed_lib.set_random_seed(random_seed)
-
-    with self.test_session():
-      for test_case in self._TEST_CASES:
-        np.random.seed(random_seed)
-        in_shape = test_case['in_shape']
-        test_shape = [
-            x if x is None else y for x, y in zip(test_shape_pattern, in_shape)
-        ]
-        in_val = array_ops.placeholder(shape=test_shape, dtype=dtypes.float32)
-
-        feed_dict = {in_val: np.random.random(in_shape)}
-        for padding in ['VALID', 'SAME']:
-          out_val = array_ops.extract_image_patches(in_val, test_case['ksizes'],
-                                                    test_case['strides'],
-                                                    test_case['rates'], padding)
-          out_val_tmp = out_val.eval(feed_dict=feed_dict)
-          out_shape = out_val_tmp.shape
-
-          err = gradient_checker.compute_gradient_error(in_val, in_shape,
-                                                        out_val, out_shape)
-          self.assertLess(err, 1e-4)
+    # Testing shape gradient requires graph mode.
+    with ops.Graph().as_default():
+      # Set graph seed for determinism.
+      random_seed = 42
+      random_seed_lib.set_random_seed(random_seed)
+
+      with self.test_session():
+        for test_case in self._TEST_CASES:
+          np.random.seed(random_seed)
+          in_shape = test_case['in_shape']
+          test_shape = [
+              x if x is None else y
+              for x, y in zip(test_shape_pattern, in_shape)
+          ]
+          in_val = array_ops.placeholder(shape=test_shape, dtype=dtypes.float32)
+
+          feed_dict = {in_val: np.random.random(in_shape)}
+          for padding in ['VALID', 'SAME']:
+            out_val = array_ops.extract_image_patches(in_val,
                                                      test_case['ksizes'],
+                                                      test_case['strides'],
+                                                      test_case['rates'],
+                                                      padding)
+            out_val_tmp = out_val.eval(feed_dict=feed_dict)
+            out_shape = out_val_tmp.shape
+
+            err = gradient_checker.compute_gradient_error(
+                in_val, in_shape, out_val, out_shape)
+            self.assertLess(err, 1e-4)
 
-  @test_util.run_deprecated_v1
   def test_BxxC_Gradient(self):
     self._VariableShapeGradient([-1, None, None, -1])
 
-  @test_util.run_deprecated_v1
   def test_xHWx_Gradient(self):
     self._VariableShapeGradient([None, -1, -1, None])
 
-  @test_util.run_deprecated_v1
   def test_BHWC_Gradient(self):
     self._VariableShapeGradient([-1, -1, -1, -1])
 
-  @test_util.run_deprecated_v1
   def test_AllNone_Gradient(self):
     self._VariableShapeGradient([None, None, None, None])
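One part of the file stays graph-only: _VariableShapeGradient feeds placeholders, which do not exist under eager execution, hence the new ops.Graph().as_default() wrapper in the hunk above rather than a gradient tape. A minimal sketch of that constraint, with illustrative shape and patch parameters:

    from tensorflow.python.framework import dtypes
    from tensorflow.python.framework import ops
    from tensorflow.python.ops import array_ops

    # Placeholders are a graph-mode construct; creating one while executing
    # eagerly raises an error, so partially unknown input shapes are
    # exercised inside an explicit graph.
    with ops.Graph().as_default():
      images = array_ops.placeholder(shape=[None, None, None, 1],
                                     dtype=dtypes.float32)
      patches = array_ops.extract_image_patches(images, [1, 2, 2, 1],
                                                [1, 1, 1, 1], [1, 1, 1, 1],
                                                'SAME')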