Eager execution coverage for image_grad_test.py. Removed run_deprecated_v1 decorators. (Part 2)
PiperOrigin-RevId: 349505411
Change-Id: Ifeaf7b77d0150339d88eda47984752bb7e3efa98
parent 8ef5acc36d
commit d3bba8c715
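The migration pattern behind this change, sketched below with public TF 2.x APIs for illustration (this sketch is not part of the diff): instead of running under @test_util.run_deprecated_v1 and feeding graph tensors to the v1 gradient_checker, the test wraps the op in a callable and passes it to gradient_checker_v2.compute_gradient, which works eagerly. tf.test.compute_gradient is the public wrapper for the v2 checker; the shapes and tolerance used here are example values, not taken from the test.

import numpy as np
import tensorflow as tf


def check_resize_bilinear_grad(in_shape=(1, 2, 3, 1), out_size=(4, 6)):
  # Input values should not influence gradients, so an arange ramp is enough.
  x = tf.constant(
      np.arange(np.prod(in_shape), dtype=np.float32).reshape(in_shape))

  def func(in_tensor):
    return tf.compat.v1.image.resize_bilinear(in_tensor, out_size)

  # Returns (theoretical, numerical) Jacobians, one entry per input tensor.
  theoretical, numerical = tf.test.compute_gradient(func, [x])
  np.testing.assert_allclose(
      theoretical[0], numerical[0], rtol=5e-3, atol=5e-3)


check_resize_bilinear_grad()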
@@ -27,9 +27,7 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_image_ops
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import gradient_checker_v2
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import image_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
@@ -137,47 +135,45 @@ class ResizeBilinearOpTestBase(test.TestCase, parameterized.TestCase):
                     dtype=np.float32,
                     use_gpu=False,
                     force_gpu=False):
-    with self.cached_session(use_gpu=use_gpu, force_gpu=force_gpu) as sess:
+    with self.cached_session(use_gpu=use_gpu, force_gpu=force_gpu):
       # Input values should not influence gradients
       x = np.arange(np.prod(in_shape)).reshape(in_shape).astype(dtype)
       input_tensor = constant_op.constant(x, shape=in_shape)
-      resized_tensor = image_ops.resize_bilinear(
-          input_tensor,
-          out_shape[1:3],
-          align_corners=align_corners,
-          half_pixel_centers=half_pixel_centers)
-      # compute_gradient will use a random tensor as the init value
-      return gradient_checker.compute_gradient(input_tensor, in_shape,
-                                               resized_tensor, out_shape)
-
-  @parameterized.parameters({
-      'batch_size': 1,
-      'channel_count': 1
-  }, {
-      'batch_size': 2,
-      'channel_count': 3
-  }, {
-      'batch_size': 5,
-      'channel_count': 4
-  })
-  @test_util.run_deprecated_v1
-  def testShapes(self, batch_size, channel_count):
-    smaller_shape = [batch_size, 2, 3, channel_count]
-    larger_shape = [batch_size, 4, 6, channel_count]
-    for in_shape, out_shape, align_corners, half_pixel_centers in \
-      self._itGen(smaller_shape, larger_shape):
-      # Input values should not influence shapes
-      x = np.arange(np.prod(in_shape)).reshape(in_shape).astype(np.float32)
-      input_tensor = constant_op.constant(x, shape=in_shape)
-      resized_tensor = image_ops.resize_bilinear(input_tensor, out_shape[1:3])
-      self.assertEqual(out_shape, list(resized_tensor.get_shape()))
-      grad_tensor = gradients_impl.gradients(resized_tensor, input_tensor)[0]
-      self.assertEqual(in_shape, list(grad_tensor.get_shape()))
-      with self.cached_session():
-        resized_values = self.evaluate(resized_tensor)
-        self.assertEqual(out_shape, list(resized_values.shape))
-        grad_values = self.evaluate(grad_tensor)
-        self.assertEqual(in_shape, list(grad_values.shape))
+
+      def func(in_tensor):
+        return image_ops.resize_bilinear(
+            in_tensor,
+            out_shape[1:3],
+            align_corners=align_corners,
+            half_pixel_centers=half_pixel_centers)
+
+      return gradient_checker_v2.compute_gradient(func, [input_tensor])
+
+  @parameterized.parameters(set((True, context.executing_eagerly())))
+  def _testShapesParameterized(self, use_tape):
+
+    TEST_CASES = [[1, 1], [2, 3], [5, 4]]  # pylint: disable=invalid-name
+
+    for batch_size, channel_count in TEST_CASES:
+      smaller_shape = [batch_size, 2, 3, channel_count]
+      larger_shape = [batch_size, 4, 6, channel_count]
+      for in_shape, out_shape, _, _ in self._itGen(smaller_shape, larger_shape):
+        with test_util.AbstractGradientTape(use_tape=use_tape) as tape:
+          # Input values should not influence shapes
+          x = np.arange(np.prod(in_shape)).reshape(in_shape).astype(np.float32)
+          input_tensor = constant_op.constant(x, shape=in_shape)
+          tape.watch(input_tensor)
+          resized_tensor = image_ops.resize_bilinear(input_tensor,
+                                                     out_shape[1:3])
+          self.assertEqual(out_shape, list(resized_tensor.get_shape()))
+
+        grad_tensor = tape.gradient(resized_tensor, input_tensor)
+        self.assertEqual(in_shape, list(grad_tensor.get_shape()))
+        with self.cached_session():
+          resized_values = self.evaluate(resized_tensor)
+          self.assertEqual(out_shape, list(resized_values.shape))
+          grad_values = self.evaluate(grad_tensor)
+          self.assertEqual(in_shape, list(grad_values.shape))
 
   @parameterized.parameters({
       'batch_size': 1,
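For readers outside the TensorFlow test tree: test_util.AbstractGradientTape(use_tape=...) is an internal helper that runs the block either under a real tf.GradientTape or through v1-style gradients, so the same test body covers both modes. A minimal public-API sketch of the eager path added above, with illustrative shapes:

import numpy as np
import tensorflow as tf

in_shape = [2, 2, 3, 3]   # example: batch of 2, 2x3 images, 3 channels
out_shape = [2, 4, 6, 3]

x = tf.constant(
    np.arange(np.prod(in_shape), dtype=np.float32).reshape(in_shape))
with tf.GradientTape() as tape:
  tape.watch(x)  # constants are not watched automatically
  resized = tf.compat.v1.image.resize_bilinear(x, out_shape[1:3])

grad = tape.gradient(resized, x)
assert list(resized.shape) == out_shape
assert list(grad.shape) == in_shape  # the resize gradient keeps the input shape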
@@ -189,7 +185,6 @@ class ResizeBilinearOpTestBase(test.TestCase, parameterized.TestCase):
       'batch_size': 3,
       'channel_count': 2
   })
-  @test_util.run_deprecated_v1
   def testGradients(self, batch_size, channel_count):
     smaller_shape = [batch_size, 2, 3, channel_count]
     larger_shape = [batch_size, 5, 6, channel_count]
@@ -197,10 +192,9 @@ class ResizeBilinearOpTestBase(test.TestCase, parameterized.TestCase):
         self._itGen(smaller_shape, larger_shape):
       jacob_a, jacob_n = self._getJacobians(in_shape, out_shape, align_corners,
                                             half_pixel_centers)
-      threshold = 1e-4
+      threshold = 5e-3
       self.assertAllClose(jacob_a, jacob_n, threshold, threshold)
 
-  @test_util.run_deprecated_v1
   def testTypes(self):
     in_shape = [1, 4, 6, 1]
     out_shape = [1, 2, 3, 1]
@@ -260,14 +254,12 @@ class ResizeBilinearOpTestBase(test.TestCase, parameterized.TestCase):
       'batch_size': 5,
       'channel_count': 4
   })
-  @test_util.run_deprecated_v1
   def testCompareGpuVsCpu(self, batch_size, channel_count):
     smaller_shape = [batch_size, 4, 6, channel_count]
     larger_shape = [batch_size, 8, 16, channel_count]
     for params in self._itGen(smaller_shape, larger_shape):
       self._gpuVsCpuCase(*params, dtype=np.float32)
 
-  @test_util.run_deprecated_v1
   def testCompareGpuVsCpuFloat64(self):
     in_shape = [1, 5, 7, 1]
     out_shape = [1, 9, 11, 1]
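testCompareGpuVsCpu and testCompareGpuVsCpuFloat64 delegate to the _gpuVsCpuCase helper defined elsewhere in this file. The underlying idea, sketched here with public APIs and hypothetical shapes (the helper itself is not shown in this diff), is to compute the same Jacobian on each device and assert the results agree:

import numpy as np
import tensorflow as tf


def resize_jacobian(device, in_shape=(1, 4, 6, 1), out_size=(8, 16)):
  # Compute the theoretical Jacobian of resize_bilinear on the given device.
  with tf.device(device):
    x = tf.constant(
        np.arange(np.prod(in_shape), dtype=np.float32).reshape(in_shape))
    func = lambda t: tf.compat.v1.image.resize_bilinear(t, out_size)
    theoretical, _ = tf.test.compute_gradient(func, [x])
    return theoretical[0]


if tf.config.list_physical_devices('GPU'):
  np.testing.assert_allclose(
      resize_jacobian('/CPU:0'), resize_jacobian('/GPU:0'),
      rtol=1e-4, atol=1e-4)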