diff --git a/tensorflow/cc/gradients/image_grad.cc b/tensorflow/cc/gradients/image_grad.cc index 7d0f63efbcc..056eea7eb5a 100644 --- a/tensorflow/cc/gradients/image_grad.cc +++ b/tensorflow/cc/gradients/image_grad.cc @@ -88,15 +88,19 @@ Status ScaleAndTranslateGradHelper(const Scope& scope, const Operation& op, string kernel_type; TF_RETURN_IF_ERROR( GetNodeAttr(op.node()->attrs(), "kernel_type", &kernel_type)); + bool antialias; + TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "antialias", &antialias)); grad_outputs->push_back(internal::ScaleAndTranslateGrad( scope, grad_inputs[0], op.input(0), op.input(2), op.input(3), - internal::ScaleAndTranslateGrad::KernelType(kernel_type))); + internal::ScaleAndTranslateGrad::KernelType(kernel_type) + .Antialias(antialias))); grad_outputs->push_back(NoGradient()); grad_outputs->push_back(NoGradient()); grad_outputs->push_back(NoGradient()); return scope.status(); } + REGISTER_GRADIENT_OP("ScaleAndTranslate", ScaleAndTranslateGradHelper); Status CropAndResizeGradHelper(const Scope& scope, const Operation& op, diff --git a/tensorflow/cc/gradients/image_grad_test.cc b/tensorflow/cc/gradients/image_grad_test.cc index 3bd52c80bd9..d50f4f5750a 100644 --- a/tensorflow/cc/gradients/image_grad_test.cc +++ b/tensorflow/cc/gradients/image_grad_test.cc @@ -196,29 +196,106 @@ class ScaleAndTranslateGradTest : public ::testing::Test { } template - void MakeOp(const Tensor& x_data, const Input& y_shape, Output* x, - Output* y) { + void MakeOp(const Tensor& x_data, const Input& y_shape, Input scale, + Input translation, const string& kernel_type, bool antialias, + Output* x, Output* y) { *x = Const(scope_, x_data); - *y = ScaleAndTranslate(scope_, *x, y_shape, {1.8f, 2.1f}, {0.5f, 0.7f}); + *y = ScaleAndTranslate(scope_, *x, y_shape, scale, translation, + ScaleAndTranslate::KernelType(kernel_type) + .Antialias(antialias) + .Antialias(antialias)); TF_ASSERT_OK(scope_.status()); } template - void TestResize() { - TensorShape x_shape({1, 2, 3, 1}); + void TestScaleAndTranslate(const TensorShape x_shape, const int out_height, + const int out_width, Input scale, + Input translation, const string& kernel_type, + bool antialias) { Tensor x_data = MakeData(x_shape); Output x, y; - MakeOp(x_data, {4, 6}, &x, &y); + MakeOp(x_data, {out_height, out_width}, scale, translation, + kernel_type, antialias, &x, &y); JAC_T max_error; TF_ASSERT_OK((ComputeGradientError( - scope_, x, x_data, y, {1, 4, 6, 1}, &max_error))); - EXPECT_LT(max_error, 1e-3); + scope_, x, x_data, y, {1, out_height, out_width, 1}, &max_error))); + EXPECT_LT(max_error, 2e-3); } + const std::vector kScales = {Input{1.0f, 1.0f}, Input{0.37f, 0.47f}, + Input{2.1f, 2.1f}}; + const std::vector kTranslations = { + Input{0.0f, 0.0f}, Input{3.14f, 1.19f}, Input{2.1f, 3.1f}, + Input{100.0f, 200.0f}}; Scope scope_; }; -TEST_F(ScaleAndTranslateGradTest, Works) { TestResize(); } +TEST_F(ScaleAndTranslateGradTest, TestGrads) { + const std::vector kKernelTypes = {"lanczos1", "lanczos3", + "lanczos5", "gaussian"}; + constexpr int kOutHeight = 4; + constexpr int kOutWidth = 6; + + const TensorShape kXShape = TensorShape({1, 2, 3, 1}); + for (const Input scale : kScales) { + for (const Input translation : kTranslations) { + for (const std::string& kernel_type : kKernelTypes) { + TestScaleAndTranslate( + kXShape, kOutHeight, kOutWidth, scale, translation, kernel_type, + true); + } + } + } +} + +TEST_F(ScaleAndTranslateGradTest, TestGradsWithoutAntialias) { + constexpr int kOutHeight = 4; + constexpr int kOutWidth = 6; + + const TensorShape kXShape = TensorShape({1, 2, 3, 1}); + for (const Input scale : kScales) { + for (const Input translation : kTranslations) { + TestScaleAndTranslate(kXShape, kOutHeight, kOutWidth, + scale, translation, "lanczos3", + false); + } + } +} + +TEST_F(ScaleAndTranslateGradTest, TestGradsWithSameShape) { + const std::vector kKernelTypes = {"lanczos3", "gaussian"}; + + constexpr int kOutHeight = 2; + constexpr int kOutWidth = 3; + + const TensorShape kXShape = TensorShape({1, 2, 3, 1}); + for (const Input scale : kScales) { + for (const Input translation : kTranslations) { + for (const std::string& kernel_type : kKernelTypes) { + TestScaleAndTranslate( + kXShape, kOutHeight, kOutWidth, scale, translation, kernel_type, + true); + } + } + } +} + +TEST_F(ScaleAndTranslateGradTest, TestGradsWithSmallerShape) { + const std::vector kKernelTypes = {"lanczos3", "gaussian"}; + constexpr int kOutHeight = 2; + constexpr int kOutWidth = 3; + + const TensorShape kXShape = TensorShape({1, 4, 6, 1}); + for (const Input scale : kScales) { + for (const Input translation : kTranslations) { + for (const std::string& kernel_type : kKernelTypes) { + TestScaleAndTranslate( + kXShape, kOutHeight, kOutWidth, scale, translation, kernel_type, + true); + } + } + } +} class CropAndResizeGradTest : public ::testing::Test { protected: @@ -237,9 +314,9 @@ class CropAndResizeGradTest : public ::testing::Test { template void MakeOp(const Tensor& x_data, const Input& boxes, const Input& box_ind, - const Input& crop_szie, Output* x, Output* y) { + const Input& crop_size, Output* x, Output* y) { *x = Const(scope_, x_data); - *y = CropAndResize(scope_, *x, boxes, box_ind, crop_szie, + *y = CropAndResize(scope_, *x, boxes, box_ind, crop_size, CropAndResize::Method("bilinear")); TF_ASSERT_OK(scope_.status()); } diff --git a/tensorflow/core/kernels/scale_and_translate_op.cc b/tensorflow/core/kernels/scale_and_translate_op.cc index 92b458f2e75..fff457e55c7 100644 --- a/tensorflow/core/kernels/scale_and_translate_op.cc +++ b/tensorflow/core/kernels/scale_and_translate_op.cc @@ -82,10 +82,8 @@ Status ComputeSpansCore(OpKernelContext* context, const Kernel& kernel, const float col_f = x + 0.5f; const float sample_f = col_f * inv_scale + inv_translate; - // Don't sample when the sampling *kernel* is completely outside the - // source image. - if (sample_f < 0 - kernel.Radius() * kernel_scale || - sample_f > input_size + kernel.Radius() * kernel_scale) { + // Don't sample when the sampling location is outside the source image. + if (sample_f < 0 || sample_f > input_size) { // Add an empty span. starts_vec(x) = 0; continue; @@ -169,11 +167,15 @@ Status ComputeGradSpansCore(OpKernelContext* context, const Spans& spans, auto grad_weights_vec = grad_spans->weights.vec(); grad_weights_vec.setZero(); for (int input_index = 0; input_index < forward_input_size; ++input_index) { - const int start_span = grad_components[input_index].front().index; - grad_starts_vec(input_index) = start_span; - for (const GradComponent& gc : grad_components[input_index]) { - grad_weights_vec(input_index * grad_spans->span_size + gc.index - - start_span) += gc.weight; + if (!grad_components[input_index].empty()) { + const int start_span = grad_components[input_index].front().index; + grad_starts_vec(input_index) = start_span; + for (const GradComponent& gc : grad_components[input_index]) { + grad_weights_vec(input_index * grad_spans->span_size + gc.index - + start_span) += gc.weight; + } + } else { + grad_starts_vec(input_index) = 0; } } return Status::OK(); diff --git a/tensorflow/core/kernels/scale_and_translate_op_test.cc b/tensorflow/core/kernels/scale_and_translate_op_test.cc index 127f1641554..a17e3d83963 100644 --- a/tensorflow/core/kernels/scale_and_translate_op_test.cc +++ b/tensorflow/core/kernels/scale_and_translate_op_test.cc @@ -120,7 +120,8 @@ void Sample(const DynamicKernel& kernel, const bool antialias, 1; std::fill(dest, dest + channels, 0.0f); - if (y_span_end <= y_span_start || x_span_end <= x_span_start) { + if (sample_f.x() < 0.0f || sample_f.y() < 0.0f || sample_f.x() > in_width || + sample_f.y() > in_height) { return; } const Vector2f one_over_kernel_scale(1.0f / kernel_scale.x(), @@ -170,6 +171,8 @@ void ScaleAndTranslateBaseline(const DynamicKernel& kernel, const int64 out_height = output.dimension(1); const int64 out_width = output.dimension(2); + const int64 in_height = images.dimension(1); + const int64 in_width = images.dimension(2); for (int b = 0; b < batch; ++b) { for (int64 y = 0; y < out_height; ++y) { @@ -178,8 +181,13 @@ void ScaleAndTranslateBaseline(const DynamicKernel& kernel, for (int64 x = 0; x < out_width; ++x) { const float out_x_f = static_cast(x) + 0.5; const float in_x_f = out_x_f * scale.x() + translate.x(); - Sample(kernel, antialias, images, b, scale, Vector2f(in_x_f, in_y_f), - &output(b, y, x, 0)); + if (in_x_f < 0.0f || in_y_f < 0.0f || in_x_f > in_width || + in_y_f > in_height) { + std::fill(&output(b, y, x, 0), &output(b, y, x + 1, 0), 0.0f); + } else { + Sample(kernel, antialias, images, b, scale, Vector2f(in_x_f, in_y_f), + &output(b, y, x, 0)); + } } } } diff --git a/tensorflow/python/ops/image_grad.py b/tensorflow/python/ops/image_grad.py index 4925209b7d1..7d240dc6b63 100644 --- a/tensorflow/python/ops/image_grad.py +++ b/tensorflow/python/ops/image_grad.py @@ -68,6 +68,28 @@ def _ResizeBilinearGrad(op, grad): return [grad0, None] +@ops.RegisterGradient("ScaleAndTranslate") +def _ScaleAndTranslateGrad(op, grad): + """The derivatives for ScaleAndTranslate transformation op. + + Args: + op: The ScaleAndTranslate op. + grad: The tensor representing the gradient w.r.t. the output. + + Returns: + The gradients w.r.t. the input. + """ + + grad0 = gen_image_ops.scale_and_translate_grad( + grad, + op.inputs[0], + op.inputs[2], + op.inputs[3], + kernel_type=op.get_attr("kernel_type"), + antialias=op.get_attr("antialias")) + return [grad0, None, None, None] + + @ops.RegisterGradient("ResizeBicubic") def _ResizeBicubicGrad(op, grad): """The derivatives for bicubic resizing. diff --git a/tensorflow/python/ops/image_grad_test.py b/tensorflow/python/ops/image_grad_test.py index e4bec8e2551..ea41ea39f98 100644 --- a/tensorflow/python/ops/image_grad_test.py +++ b/tensorflow/python/ops/image_grad_test.py @@ -20,6 +20,7 @@ from __future__ import print_function import numpy as np +from tensorflow.python.eager import backprop from tensorflow.python.framework import constant_op from tensorflow.python.framework import test_util from tensorflow.python.ops import gradient_checker @@ -40,7 +41,7 @@ class ResizeNearestNeighborOpTest(test.TestCase): for nptype in self.TYPES: x = np.arange(0, 4).reshape(in_shape).astype(nptype) - with self.cached_session(use_gpu=True) as sess: + with self.cached_session(use_gpu=True): input_tensor = constant_op.constant(x, shape=in_shape) resize_out = image_ops.resize_nearest_neighbor(input_tensor, out_shape[1:3]) @@ -113,7 +114,7 @@ class ResizeBilinearOpTest(test.TestCase): x = np.arange(0, 4).reshape(in_shape).astype(np.float32) - with self.cached_session() as sess: + with self.cached_session(): input_tensor = constant_op.constant(x, shape=in_shape) resize_out = image_ops.resize_bilinear(input_tensor, out_shape[1:3]) self.assertEqual(out_shape, list(resize_out.get_shape())) @@ -204,7 +205,7 @@ class ResizeBicubicOpTest(test.TestCase): x = np.arange(0, 4).reshape(in_shape).astype(np.float32) for align_corners in [True, False]: - with self.cached_session() as sess: + with self.cached_session(): input_tensor = constant_op.constant(x, shape=in_shape) resize_out = image_ops.resize_bicubic(input_tensor, out_shape[1:3], align_corners=align_corners) @@ -259,6 +260,70 @@ class ResizeBicubicOpTest(test.TestCase): self.assertEqual([None], grad) +class ScaleAndTranslateOpTest(test.TestCase): + + @test_util.run_deprecated_v1 + def testGrads(self): + in_shape = [1, 2, 3, 1] + out_shape = [1, 4, 6, 1] + + x = np.arange(0, 6).reshape(in_shape).astype(np.float32) + + kernel_types = [ + 'lanczos1', 'lanczos3', 'lanczos5', 'gaussian', 'box', 'triangle', + 'keyscubic', 'mitchellcubic' + ] + scales = [(1.0, 1.0), (0.37, 0.47), (2.1, 2.1)] + translations = [(0.0, 0.0), (3.14, 1.19), (2.1, 3.1), (100.0, 200.0)] + for scale in scales: + for translation in translations: + for kernel_type in kernel_types: + for antialias in [True, False]: + with self.cached_session(): + input_tensor = constant_op.constant(x, shape=in_shape) + scale_and_translate_out = image_ops.scale_and_translate( + input_tensor, + out_shape[1:3], + scale=constant_op.constant(scale), + translation=constant_op.constant(translation), + kernel_type=kernel_type, + antialias=antialias) + err = gradient_checker.compute_gradient_error( + input_tensor, + in_shape, + scale_and_translate_out, + out_shape, + x_init_value=x) + self.assertLess(err, 1e-3) + + def testIdentityGrads(self): + """Tests that Gradients for 1.0 scale should be ones for some kernels.""" + in_shape = [1, 2, 3, 1] + out_shape = [1, 4, 6, 1] + + x = np.arange(0, 6).reshape(in_shape).astype(np.float32) + + kernel_types = ['lanczos1', 'lanczos3', 'lanczos5', 'triangle', 'keyscubic'] + scale = (1.0, 1.0) + translation = (0.0, 0.0) + antialias = True + for kernel_type in kernel_types: + with self.cached_session(): + input_tensor = constant_op.constant(x, shape=in_shape) + with backprop.GradientTape() as tape: + tape.watch(input_tensor) + scale_and_translate_out = image_ops.scale_and_translate( + input_tensor, + out_shape[1:3], + scale=constant_op.constant(scale), + translation=constant_op.constant(translation), + kernel_type=kernel_type, + antialias=antialias) + grad = tape.gradient(scale_and_translate_out, input_tensor)[0] + grad_v = self.evaluate(grad) + self.assertAllClose(np.ones_like(grad_v), grad_v) + + class CropAndResizeOpTest(test.TestCase): def testShapeIsCorrectAfterOp(self):