Eager execution coverage for image_grad_test.py. Removed run_deprecated_v1 decorators. (Part 1)

PiperOrigin-RevId: 339358538
Change-Id: I70d776bfa1d78ada4f7e046b189fcb7434a33e8d
This commit is contained in:
Hye Soo Yang 2020-10-27 16:47:35 -07:00 committed by TensorFlower Gardener
parent 64300f9afe
commit 8f15dd3430

View File

@ -18,21 +18,21 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
import numpy as np
from absl.testing import parameterized
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_image_ops
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import gradient_checker_v2
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import image_ops
from tensorflow.python.ops import gen_image_ops
from tensorflow.python.platform import test
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
@test_util.for_all_test_methods(test_util.disable_xla,
@ -221,17 +221,18 @@ class ResizeBilinearOpTestBase(test.TestCase, parameterized.TestCase):
threshold = 1e-5
self.assertAllClose(jacob_a, jacob_n, threshold, threshold)
@test_util.run_deprecated_v1
def testGradOnUnsupportedType(self):
@parameterized.parameters(set((True, context.executing_eagerly())))
def testGradOnUnsupportedType(self, use_tape):
in_shape = [1, 4, 6, 1]
out_shape = [1, 2, 3, 1]
with test_util.AbstractGradientTape(use_tape=use_tape) as tape:
x = np.arange(0, 24).reshape(in_shape).astype(np.uint8)
input_tensor = constant_op.constant(x, shape=in_shape)
tape.watch(input_tensor)
resize_out = image_ops.resize_bilinear(input_tensor, out_shape[1:3])
with self.cached_session():
grad = gradients_impl.gradients(resize_out, [input_tensor])
grad = tape.gradient(resize_out, [input_tensor])
self.assertEqual([None], grad)
def _gpuVsCpuCase(self, in_shape, out_shape, align_corners,
@ -279,7 +280,8 @@ class ResizeBilinearOpTestBase(test.TestCase, parameterized.TestCase):
dtype=np.float64)
class ResizeBicubicOpTestBase(test.TestCase):
class ResizeBicubicOpTestBase(test.TestCase, parameterized.TestCase):
"""Tests resize bicubic ops."""
def testShapeIsCorrectAfterOp(self):
in_shape = [1, 2, 2, 1]
@ -296,55 +298,63 @@ class ResizeBicubicOpTestBase(test.TestCase):
resize_out = self.evaluate(resize_out)
self.assertEqual(out_shape, list(resize_out.shape))
@test_util.run_deprecated_v1
def testGradFromResizeToLargerInBothDims(self):
in_shape = [1, 2, 3, 1]
out_shape = [1, 4, 6, 1]
x = np.arange(0, 6).reshape(in_shape).astype(np.float32)
input_tensor = constant_op.constant(x, shape=in_shape)
for align_corners in [True, False]:
with self.cached_session():
input_tensor = constant_op.constant(x, shape=in_shape)
resize_out = image_ops.resize_bicubic(
def func(input_tensor, align_corners=align_corners):
return image_ops.resize_bicubic(
input_tensor, out_shape[1:3], align_corners=align_corners)
err = gradient_checker.compute_gradient_error(
input_tensor, in_shape, resize_out, out_shape, x_init_value=x)
with self.cached_session():
err = gradient_checker_v2.max_error(
*gradient_checker_v2.compute_gradient(func, [input_tensor]))
self.assertLess(err, 1e-3)
@test_util.run_deprecated_v1
def testGradFromResizeToSmallerInBothDims(self):
in_shape = [1, 4, 6, 1]
out_shape = [1, 2, 3, 1]
x = np.arange(0, 24).reshape(in_shape).astype(np.float32)
input_tensor = constant_op.constant(x, shape=in_shape)
for align_corners in [True, False]:
input_tensor = constant_op.constant(x, shape=in_shape)
resize_out = image_ops.resize_bicubic(
def func(input_tensor, align_corners=align_corners):
return image_ops.resize_bicubic(
input_tensor, out_shape[1:3], align_corners=align_corners)
with self.cached_session():
err = gradient_checker.compute_gradient_error(
input_tensor, in_shape, resize_out, out_shape, x_init_value=x)
err = gradient_checker_v2.max_error(
*gradient_checker_v2.compute_gradient(func, [input_tensor]))
self.assertLess(err, 1e-3)
@test_util.run_deprecated_v1
def testGradOnUnsupportedType(self):
@parameterized.parameters(set((True, context.executing_eagerly())))
def testGradOnUnsupportedType(self, use_tape):
with test_util.AbstractGradientTape(use_tape=use_tape) as tape:
in_shape = [1, 4, 6, 1]
out_shape = [1, 2, 3, 1]
x = np.arange(0, 24).reshape(in_shape).astype(np.uint8)
input_tensor = constant_op.constant(x, shape=in_shape)
tape.watch(input_tensor)
resize_out = image_ops.resize_bicubic(input_tensor, out_shape[1:3])
with self.cached_session():
grad = gradients_impl.gradients(resize_out, [input_tensor])
grad = tape.gradient(resize_out, [input_tensor])
self.assertEqual([None], grad)
class ScaleAndTranslateOpTestBase(test.TestCase):
"""Tests scale and translate op."""
@test_util.run_deprecated_v1
def testGrads(self):
in_shape = [1, 2, 3, 1]
out_shape = [1, 4, 6, 1]
@ -363,19 +373,25 @@ class ScaleAndTranslateOpTestBase(test.TestCase):
for antialias in [True, False]:
with self.cached_session():
input_tensor = constant_op.constant(x, shape=in_shape)
scale_and_translate_out = image_ops.scale_and_translate(
def scale_trans(input_tensor,
scale=scale,
translation=translation,
kernel_type=kernel_type,
antialias=antialias):
# pylint: disable=cell-var-from-loop
return image_ops.scale_and_translate(
input_tensor,
out_shape[1:3],
scale=constant_op.constant(scale),
translation=constant_op.constant(translation),
kernel_type=kernel_type,
antialias=antialias)
err = gradient_checker.compute_gradient_error(
input_tensor,
in_shape,
scale_and_translate_out,
out_shape,
x_init_value=x)
err = gradient_checker_v2.max_error(
*gradient_checker_v2.compute_gradient(scale_trans,
[input_tensor]))
self.assertLess(err, 1e-3)
def testIdentityGrads(self):
@ -466,7 +482,6 @@ class CropAndResizeOpTestBase(test.TestCase):
samples.append(sample)
return samples
@test_util.run_deprecated_v1
def testGradRandomBoxes(self):
"""Test that the gradient is correct for randomly generated boxes.
@ -494,8 +509,6 @@ class CropAndResizeOpTestBase(test.TestCase):
batch = num_boxes
image_shape = [batch, image_height, image_width, depth]
crop_size = [crop_height, crop_width]
crops_shape = [num_boxes, crop_height, crop_width, depth]
boxes_shape = [num_boxes, 4]
image = np.arange(0, batch * image_height * image_width *
depth).reshape(image_shape).astype(np.float32)
@ -512,21 +525,28 @@ class CropAndResizeOpTestBase(test.TestCase):
boxes = np.array(boxes, dtype=np.float32)
box_ind = np.arange(batch, dtype=np.int32)
with self.cached_session(use_gpu=True):
image_tensor = constant_op.constant(image, shape=image_shape)
boxes_tensor = constant_op.constant(boxes, shape=[num_boxes, 4])
box_ind_tensor = constant_op.constant(
box_ind, shape=[num_boxes])
crops = image_ops.crop_and_resize(
box_ind_tensor = constant_op.constant(box_ind, shape=[num_boxes])
def crop_resize(image_tensor, boxes_tensor):
# pylint: disable=cell-var-from-loop
return image_ops.crop_and_resize(
image_tensor, boxes_tensor, box_ind_tensor,
constant_op.constant(crop_size, shape=[2]))
err = gradient_checker.compute_gradient_error(
[image_tensor, boxes_tensor], [image_shape, boxes_shape],
crops,
crops_shape,
delta=delta,
x_init_value=[image, boxes])
with test_util.device(use_gpu=True):
with self.cached_session():
# pylint: disable=cell-var-from-loop
err1 = gradient_checker_v2.max_error(
*gradient_checker_v2.compute_gradient(
lambda x: crop_resize(x, boxes_tensor),
[image_tensor]))
err2 = gradient_checker_v2.max_error(
*gradient_checker_v2.compute_gradient(
lambda x: crop_resize(image_tensor, x),
[boxes_tensor]))
err = max(err1, err2)
self.assertLess(err, 2e-3)