Eager execution coverage for image_grad_test.py. Removed run_deprecated_v1 decorators. (Part 1)
PiperOrigin-RevId: 339358538 Change-Id: I70d776bfa1d78ada4f7e046b189fcb7434a33e8d
This commit is contained in:
parent
64300f9afe
commit
8f15dd3430
@ -18,21 +18,21 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
||||
from absl.testing import parameterized
|
||||
|
||||
from tensorflow.python.eager import backprop
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import gen_image_ops
|
||||
from tensorflow.python.ops import gradient_checker
|
||||
from tensorflow.python.ops import gradient_checker_v2
|
||||
from tensorflow.python.ops import gradients_impl
|
||||
from tensorflow.python.ops import image_ops
|
||||
from tensorflow.python.ops import gen_image_ops
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
@test_util.for_all_test_methods(test_util.disable_xla,
|
||||
@ -221,17 +221,18 @@ class ResizeBilinearOpTestBase(test.TestCase, parameterized.TestCase):
|
||||
threshold = 1e-5
|
||||
self.assertAllClose(jacob_a, jacob_n, threshold, threshold)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testGradOnUnsupportedType(self):
|
||||
@parameterized.parameters(set((True, context.executing_eagerly())))
|
||||
def testGradOnUnsupportedType(self, use_tape):
|
||||
in_shape = [1, 4, 6, 1]
|
||||
out_shape = [1, 2, 3, 1]
|
||||
|
||||
with test_util.AbstractGradientTape(use_tape=use_tape) as tape:
|
||||
x = np.arange(0, 24).reshape(in_shape).astype(np.uint8)
|
||||
|
||||
input_tensor = constant_op.constant(x, shape=in_shape)
|
||||
tape.watch(input_tensor)
|
||||
resize_out = image_ops.resize_bilinear(input_tensor, out_shape[1:3])
|
||||
with self.cached_session():
|
||||
grad = gradients_impl.gradients(resize_out, [input_tensor])
|
||||
grad = tape.gradient(resize_out, [input_tensor])
|
||||
self.assertEqual([None], grad)
|
||||
|
||||
def _gpuVsCpuCase(self, in_shape, out_shape, align_corners,
|
||||
@ -279,7 +280,8 @@ class ResizeBilinearOpTestBase(test.TestCase, parameterized.TestCase):
|
||||
dtype=np.float64)
|
||||
|
||||
|
||||
class ResizeBicubicOpTestBase(test.TestCase):
|
||||
class ResizeBicubicOpTestBase(test.TestCase, parameterized.TestCase):
|
||||
"""Tests resize bicubic ops."""
|
||||
|
||||
def testShapeIsCorrectAfterOp(self):
|
||||
in_shape = [1, 2, 2, 1]
|
||||
@ -296,55 +298,63 @@ class ResizeBicubicOpTestBase(test.TestCase):
|
||||
resize_out = self.evaluate(resize_out)
|
||||
self.assertEqual(out_shape, list(resize_out.shape))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testGradFromResizeToLargerInBothDims(self):
|
||||
in_shape = [1, 2, 3, 1]
|
||||
out_shape = [1, 4, 6, 1]
|
||||
|
||||
x = np.arange(0, 6).reshape(in_shape).astype(np.float32)
|
||||
input_tensor = constant_op.constant(x, shape=in_shape)
|
||||
|
||||
for align_corners in [True, False]:
|
||||
with self.cached_session():
|
||||
input_tensor = constant_op.constant(x, shape=in_shape)
|
||||
resize_out = image_ops.resize_bicubic(
|
||||
|
||||
def func(input_tensor, align_corners=align_corners):
|
||||
return image_ops.resize_bicubic(
|
||||
input_tensor, out_shape[1:3], align_corners=align_corners)
|
||||
err = gradient_checker.compute_gradient_error(
|
||||
input_tensor, in_shape, resize_out, out_shape, x_init_value=x)
|
||||
|
||||
with self.cached_session():
|
||||
err = gradient_checker_v2.max_error(
|
||||
*gradient_checker_v2.compute_gradient(func, [input_tensor]))
|
||||
|
||||
self.assertLess(err, 1e-3)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testGradFromResizeToSmallerInBothDims(self):
|
||||
in_shape = [1, 4, 6, 1]
|
||||
out_shape = [1, 2, 3, 1]
|
||||
|
||||
x = np.arange(0, 24).reshape(in_shape).astype(np.float32)
|
||||
input_tensor = constant_op.constant(x, shape=in_shape)
|
||||
|
||||
for align_corners in [True, False]:
|
||||
input_tensor = constant_op.constant(x, shape=in_shape)
|
||||
resize_out = image_ops.resize_bicubic(
|
||||
|
||||
def func(input_tensor, align_corners=align_corners):
|
||||
return image_ops.resize_bicubic(
|
||||
input_tensor, out_shape[1:3], align_corners=align_corners)
|
||||
|
||||
with self.cached_session():
|
||||
err = gradient_checker.compute_gradient_error(
|
||||
input_tensor, in_shape, resize_out, out_shape, x_init_value=x)
|
||||
err = gradient_checker_v2.max_error(
|
||||
*gradient_checker_v2.compute_gradient(func, [input_tensor]))
|
||||
|
||||
self.assertLess(err, 1e-3)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testGradOnUnsupportedType(self):
|
||||
@parameterized.parameters(set((True, context.executing_eagerly())))
|
||||
def testGradOnUnsupportedType(self, use_tape):
|
||||
with test_util.AbstractGradientTape(use_tape=use_tape) as tape:
|
||||
in_shape = [1, 4, 6, 1]
|
||||
out_shape = [1, 2, 3, 1]
|
||||
|
||||
x = np.arange(0, 24).reshape(in_shape).astype(np.uint8)
|
||||
|
||||
input_tensor = constant_op.constant(x, shape=in_shape)
|
||||
tape.watch(input_tensor)
|
||||
|
||||
resize_out = image_ops.resize_bicubic(input_tensor, out_shape[1:3])
|
||||
with self.cached_session():
|
||||
grad = gradients_impl.gradients(resize_out, [input_tensor])
|
||||
grad = tape.gradient(resize_out, [input_tensor])
|
||||
self.assertEqual([None], grad)
|
||||
|
||||
|
||||
class ScaleAndTranslateOpTestBase(test.TestCase):
|
||||
"""Tests scale and translate op."""
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testGrads(self):
|
||||
in_shape = [1, 2, 3, 1]
|
||||
out_shape = [1, 4, 6, 1]
|
||||
@ -363,19 +373,25 @@ class ScaleAndTranslateOpTestBase(test.TestCase):
|
||||
for antialias in [True, False]:
|
||||
with self.cached_session():
|
||||
input_tensor = constant_op.constant(x, shape=in_shape)
|
||||
scale_and_translate_out = image_ops.scale_and_translate(
|
||||
|
||||
def scale_trans(input_tensor,
|
||||
scale=scale,
|
||||
translation=translation,
|
||||
kernel_type=kernel_type,
|
||||
antialias=antialias):
|
||||
# pylint: disable=cell-var-from-loop
|
||||
return image_ops.scale_and_translate(
|
||||
input_tensor,
|
||||
out_shape[1:3],
|
||||
scale=constant_op.constant(scale),
|
||||
translation=constant_op.constant(translation),
|
||||
kernel_type=kernel_type,
|
||||
antialias=antialias)
|
||||
err = gradient_checker.compute_gradient_error(
|
||||
input_tensor,
|
||||
in_shape,
|
||||
scale_and_translate_out,
|
||||
out_shape,
|
||||
x_init_value=x)
|
||||
|
||||
err = gradient_checker_v2.max_error(
|
||||
*gradient_checker_v2.compute_gradient(scale_trans,
|
||||
[input_tensor]))
|
||||
|
||||
self.assertLess(err, 1e-3)
|
||||
|
||||
def testIdentityGrads(self):
|
||||
@ -466,7 +482,6 @@ class CropAndResizeOpTestBase(test.TestCase):
|
||||
samples.append(sample)
|
||||
return samples
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testGradRandomBoxes(self):
|
||||
"""Test that the gradient is correct for randomly generated boxes.
|
||||
|
||||
@ -494,8 +509,6 @@ class CropAndResizeOpTestBase(test.TestCase):
|
||||
batch = num_boxes
|
||||
image_shape = [batch, image_height, image_width, depth]
|
||||
crop_size = [crop_height, crop_width]
|
||||
crops_shape = [num_boxes, crop_height, crop_width, depth]
|
||||
boxes_shape = [num_boxes, 4]
|
||||
|
||||
image = np.arange(0, batch * image_height * image_width *
|
||||
depth).reshape(image_shape).astype(np.float32)
|
||||
@ -512,21 +525,28 @@ class CropAndResizeOpTestBase(test.TestCase):
|
||||
boxes = np.array(boxes, dtype=np.float32)
|
||||
box_ind = np.arange(batch, dtype=np.int32)
|
||||
|
||||
with self.cached_session(use_gpu=True):
|
||||
image_tensor = constant_op.constant(image, shape=image_shape)
|
||||
boxes_tensor = constant_op.constant(boxes, shape=[num_boxes, 4])
|
||||
box_ind_tensor = constant_op.constant(
|
||||
box_ind, shape=[num_boxes])
|
||||
crops = image_ops.crop_and_resize(
|
||||
box_ind_tensor = constant_op.constant(box_ind, shape=[num_boxes])
|
||||
|
||||
def crop_resize(image_tensor, boxes_tensor):
|
||||
# pylint: disable=cell-var-from-loop
|
||||
return image_ops.crop_and_resize(
|
||||
image_tensor, boxes_tensor, box_ind_tensor,
|
||||
constant_op.constant(crop_size, shape=[2]))
|
||||
|
||||
err = gradient_checker.compute_gradient_error(
|
||||
[image_tensor, boxes_tensor], [image_shape, boxes_shape],
|
||||
crops,
|
||||
crops_shape,
|
||||
delta=delta,
|
||||
x_init_value=[image, boxes])
|
||||
with test_util.device(use_gpu=True):
|
||||
with self.cached_session():
|
||||
# pylint: disable=cell-var-from-loop
|
||||
err1 = gradient_checker_v2.max_error(
|
||||
*gradient_checker_v2.compute_gradient(
|
||||
lambda x: crop_resize(x, boxes_tensor),
|
||||
[image_tensor]))
|
||||
err2 = gradient_checker_v2.max_error(
|
||||
*gradient_checker_v2.compute_gradient(
|
||||
lambda x: crop_resize(image_tensor, x),
|
||||
[boxes_tensor]))
|
||||
err = max(err1, err2)
|
||||
|
||||
self.assertLess(err, 2e-3)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user