Adds python gradients for ScaleAndTranslate ops, which are now used internal by tf.resize_images_v2 op.

Also changes behavior of ScaleAndTranslateOp when sampling location is near borders, before if any part of the scaled sampling kernel was inside the input image a result would be returned, now a result will be returned only if the sampling location is inside the input image. This avoids complicated edge cases where only sampling contributions with weight zero are inside the image, which are then normalized causing large gradients.
Fixes bug in ScaleAndTranslateOpGrad when output pixel has no contribution from input.
This change won't affect any current usage in resize.
Correctly pass antialias attribute to ScaleAndTranslateOpGrad.

PiperOrigin-RevId: 238576377
This commit is contained in:
A. Unique TensorFlower 2019-03-14 20:57:03 -07:00 committed by TensorFlower Gardener
parent ce6a60a058
commit 7a056bf825
6 changed files with 205 additions and 27 deletions

View File

@ -88,15 +88,19 @@ Status ScaleAndTranslateGradHelper(const Scope& scope, const Operation& op,
string kernel_type;
TF_RETURN_IF_ERROR(
GetNodeAttr(op.node()->attrs(), "kernel_type", &kernel_type));
bool antialias;
TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "antialias", &antialias));
grad_outputs->push_back(internal::ScaleAndTranslateGrad(
scope, grad_inputs[0], op.input(0), op.input(2), op.input(3),
internal::ScaleAndTranslateGrad::KernelType(kernel_type)));
internal::ScaleAndTranslateGrad::KernelType(kernel_type)
.Antialias(antialias)));
grad_outputs->push_back(NoGradient());
grad_outputs->push_back(NoGradient());
grad_outputs->push_back(NoGradient());
return scope.status();
}
REGISTER_GRADIENT_OP("ScaleAndTranslate", ScaleAndTranslateGradHelper);
Status CropAndResizeGradHelper(const Scope& scope, const Operation& op,

View File

@ -196,29 +196,106 @@ class ScaleAndTranslateGradTest : public ::testing::Test {
}
template <typename T>
void MakeOp(const Tensor& x_data, const Input& y_shape, Output* x,
Output* y) {
void MakeOp(const Tensor& x_data, const Input& y_shape, Input scale,
Input translation, const string& kernel_type, bool antialias,
Output* x, Output* y) {
*x = Const<T>(scope_, x_data);
*y = ScaleAndTranslate(scope_, *x, y_shape, {1.8f, 2.1f}, {0.5f, 0.7f});
*y = ScaleAndTranslate(scope_, *x, y_shape, scale, translation,
ScaleAndTranslate::KernelType(kernel_type)
.Antialias(antialias)
.Antialias(antialias));
TF_ASSERT_OK(scope_.status());
}
template <typename X_T, typename Y_T, typename JAC_T>
void TestResize() {
TensorShape x_shape({1, 2, 3, 1});
void TestScaleAndTranslate(const TensorShape x_shape, const int out_height,
const int out_width, Input scale,
Input translation, const string& kernel_type,
bool antialias) {
Tensor x_data = MakeData<X_T>(x_shape);
Output x, y;
MakeOp<X_T>(x_data, {4, 6}, &x, &y);
MakeOp<X_T>(x_data, {out_height, out_width}, scale, translation,
kernel_type, antialias, &x, &y);
JAC_T max_error;
TF_ASSERT_OK((ComputeGradientError<X_T, Y_T, JAC_T>(
scope_, x, x_data, y, {1, 4, 6, 1}, &max_error)));
EXPECT_LT(max_error, 1e-3);
scope_, x, x_data, y, {1, out_height, out_width, 1}, &max_error)));
EXPECT_LT(max_error, 2e-3);
}
const std::vector<Input> kScales = {Input{1.0f, 1.0f}, Input{0.37f, 0.47f},
Input{2.1f, 2.1f}};
const std::vector<Input> kTranslations = {
Input{0.0f, 0.0f}, Input{3.14f, 1.19f}, Input{2.1f, 3.1f},
Input{100.0f, 200.0f}};
Scope scope_;
};
TEST_F(ScaleAndTranslateGradTest, Works) { TestResize<float, float, float>(); }
TEST_F(ScaleAndTranslateGradTest, TestGrads) {
const std::vector<std::string> kKernelTypes = {"lanczos1", "lanczos3",
"lanczos5", "gaussian"};
constexpr int kOutHeight = 4;
constexpr int kOutWidth = 6;
const TensorShape kXShape = TensorShape({1, 2, 3, 1});
for (const Input scale : kScales) {
for (const Input translation : kTranslations) {
for (const std::string& kernel_type : kKernelTypes) {
TestScaleAndTranslate<float, float, float>(
kXShape, kOutHeight, kOutWidth, scale, translation, kernel_type,
true);
}
}
}
}
TEST_F(ScaleAndTranslateGradTest, TestGradsWithoutAntialias) {
constexpr int kOutHeight = 4;
constexpr int kOutWidth = 6;
const TensorShape kXShape = TensorShape({1, 2, 3, 1});
for (const Input scale : kScales) {
for (const Input translation : kTranslations) {
TestScaleAndTranslate<float, float, float>(kXShape, kOutHeight, kOutWidth,
scale, translation, "lanczos3",
false);
}
}
}
TEST_F(ScaleAndTranslateGradTest, TestGradsWithSameShape) {
const std::vector<std::string> kKernelTypes = {"lanczos3", "gaussian"};
constexpr int kOutHeight = 2;
constexpr int kOutWidth = 3;
const TensorShape kXShape = TensorShape({1, 2, 3, 1});
for (const Input scale : kScales) {
for (const Input translation : kTranslations) {
for (const std::string& kernel_type : kKernelTypes) {
TestScaleAndTranslate<float, float, float>(
kXShape, kOutHeight, kOutWidth, scale, translation, kernel_type,
true);
}
}
}
}
TEST_F(ScaleAndTranslateGradTest, TestGradsWithSmallerShape) {
const std::vector<std::string> kKernelTypes = {"lanczos3", "gaussian"};
constexpr int kOutHeight = 2;
constexpr int kOutWidth = 3;
const TensorShape kXShape = TensorShape({1, 4, 6, 1});
for (const Input scale : kScales) {
for (const Input translation : kTranslations) {
for (const std::string& kernel_type : kKernelTypes) {
TestScaleAndTranslate<float, float, float>(
kXShape, kOutHeight, kOutWidth, scale, translation, kernel_type,
true);
}
}
}
}
class CropAndResizeGradTest : public ::testing::Test {
protected:
@ -237,9 +314,9 @@ class CropAndResizeGradTest : public ::testing::Test {
template <typename T>
void MakeOp(const Tensor& x_data, const Input& boxes, const Input& box_ind,
const Input& crop_szie, Output* x, Output* y) {
const Input& crop_size, Output* x, Output* y) {
*x = Const<T>(scope_, x_data);
*y = CropAndResize(scope_, *x, boxes, box_ind, crop_szie,
*y = CropAndResize(scope_, *x, boxes, box_ind, crop_size,
CropAndResize::Method("bilinear"));
TF_ASSERT_OK(scope_.status());
}

View File

@ -82,10 +82,8 @@ Status ComputeSpansCore(OpKernelContext* context, const Kernel& kernel,
const float col_f = x + 0.5f;
const float sample_f = col_f * inv_scale + inv_translate;
// Don't sample when the sampling *kernel* is completely outside the
// source image.
if (sample_f < 0 - kernel.Radius() * kernel_scale ||
sample_f > input_size + kernel.Radius() * kernel_scale) {
// Don't sample when the sampling location is outside the source image.
if (sample_f < 0 || sample_f > input_size) {
// Add an empty span.
starts_vec(x) = 0;
continue;
@ -169,11 +167,15 @@ Status ComputeGradSpansCore(OpKernelContext* context, const Spans& spans,
auto grad_weights_vec = grad_spans->weights.vec<float>();
grad_weights_vec.setZero();
for (int input_index = 0; input_index < forward_input_size; ++input_index) {
const int start_span = grad_components[input_index].front().index;
grad_starts_vec(input_index) = start_span;
for (const GradComponent& gc : grad_components[input_index]) {
grad_weights_vec(input_index * grad_spans->span_size + gc.index -
start_span) += gc.weight;
if (!grad_components[input_index].empty()) {
const int start_span = grad_components[input_index].front().index;
grad_starts_vec(input_index) = start_span;
for (const GradComponent& gc : grad_components[input_index]) {
grad_weights_vec(input_index * grad_spans->span_size + gc.index -
start_span) += gc.weight;
}
} else {
grad_starts_vec(input_index) = 0;
}
}
return Status::OK();

View File

@ -120,7 +120,8 @@ void Sample(const DynamicKernel& kernel, const bool antialias,
1;
std::fill(dest, dest + channels, 0.0f);
if (y_span_end <= y_span_start || x_span_end <= x_span_start) {
if (sample_f.x() < 0.0f || sample_f.y() < 0.0f || sample_f.x() > in_width ||
sample_f.y() > in_height) {
return;
}
const Vector2f one_over_kernel_scale(1.0f / kernel_scale.x(),
@ -170,6 +171,8 @@ void ScaleAndTranslateBaseline(const DynamicKernel& kernel,
const int64 out_height = output.dimension(1);
const int64 out_width = output.dimension(2);
const int64 in_height = images.dimension(1);
const int64 in_width = images.dimension(2);
for (int b = 0; b < batch; ++b) {
for (int64 y = 0; y < out_height; ++y) {
@ -178,8 +181,13 @@ void ScaleAndTranslateBaseline(const DynamicKernel& kernel,
for (int64 x = 0; x < out_width; ++x) {
const float out_x_f = static_cast<float>(x) + 0.5;
const float in_x_f = out_x_f * scale.x() + translate.x();
Sample(kernel, antialias, images, b, scale, Vector2f(in_x_f, in_y_f),
&output(b, y, x, 0));
if (in_x_f < 0.0f || in_y_f < 0.0f || in_x_f > in_width ||
in_y_f > in_height) {
std::fill(&output(b, y, x, 0), &output(b, y, x + 1, 0), 0.0f);
} else {
Sample(kernel, antialias, images, b, scale, Vector2f(in_x_f, in_y_f),
&output(b, y, x, 0));
}
}
}
}

View File

@ -68,6 +68,28 @@ def _ResizeBilinearGrad(op, grad):
return [grad0, None]
@ops.RegisterGradient("ScaleAndTranslate")
def _ScaleAndTranslateGrad(op, grad):
"""The derivatives for ScaleAndTranslate transformation op.
Args:
op: The ScaleAndTranslate op.
grad: The tensor representing the gradient w.r.t. the output.
Returns:
The gradients w.r.t. the input.
"""
grad0 = gen_image_ops.scale_and_translate_grad(
grad,
op.inputs[0],
op.inputs[2],
op.inputs[3],
kernel_type=op.get_attr("kernel_type"),
antialias=op.get_attr("antialias"))
return [grad0, None, None, None]
@ops.RegisterGradient("ResizeBicubic")
def _ResizeBicubicGrad(op, grad):
"""The derivatives for bicubic resizing.

View File

@ -20,6 +20,7 @@ from __future__ import print_function
import numpy as np
from tensorflow.python.eager import backprop
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import test_util
from tensorflow.python.ops import gradient_checker
@ -40,7 +41,7 @@ class ResizeNearestNeighborOpTest(test.TestCase):
for nptype in self.TYPES:
x = np.arange(0, 4).reshape(in_shape).astype(nptype)
with self.cached_session(use_gpu=True) as sess:
with self.cached_session(use_gpu=True):
input_tensor = constant_op.constant(x, shape=in_shape)
resize_out = image_ops.resize_nearest_neighbor(input_tensor,
out_shape[1:3])
@ -113,7 +114,7 @@ class ResizeBilinearOpTest(test.TestCase):
x = np.arange(0, 4).reshape(in_shape).astype(np.float32)
with self.cached_session() as sess:
with self.cached_session():
input_tensor = constant_op.constant(x, shape=in_shape)
resize_out = image_ops.resize_bilinear(input_tensor, out_shape[1:3])
self.assertEqual(out_shape, list(resize_out.get_shape()))
@ -204,7 +205,7 @@ class ResizeBicubicOpTest(test.TestCase):
x = np.arange(0, 4).reshape(in_shape).astype(np.float32)
for align_corners in [True, False]:
with self.cached_session() as sess:
with self.cached_session():
input_tensor = constant_op.constant(x, shape=in_shape)
resize_out = image_ops.resize_bicubic(input_tensor, out_shape[1:3],
align_corners=align_corners)
@ -259,6 +260,70 @@ class ResizeBicubicOpTest(test.TestCase):
self.assertEqual([None], grad)
class ScaleAndTranslateOpTest(test.TestCase):
@test_util.run_deprecated_v1
def testGrads(self):
in_shape = [1, 2, 3, 1]
out_shape = [1, 4, 6, 1]
x = np.arange(0, 6).reshape(in_shape).astype(np.float32)
kernel_types = [
'lanczos1', 'lanczos3', 'lanczos5', 'gaussian', 'box', 'triangle',
'keyscubic', 'mitchellcubic'
]
scales = [(1.0, 1.0), (0.37, 0.47), (2.1, 2.1)]
translations = [(0.0, 0.0), (3.14, 1.19), (2.1, 3.1), (100.0, 200.0)]
for scale in scales:
for translation in translations:
for kernel_type in kernel_types:
for antialias in [True, False]:
with self.cached_session():
input_tensor = constant_op.constant(x, shape=in_shape)
scale_and_translate_out = image_ops.scale_and_translate(
input_tensor,
out_shape[1:3],
scale=constant_op.constant(scale),
translation=constant_op.constant(translation),
kernel_type=kernel_type,
antialias=antialias)
err = gradient_checker.compute_gradient_error(
input_tensor,
in_shape,
scale_and_translate_out,
out_shape,
x_init_value=x)
self.assertLess(err, 1e-3)
def testIdentityGrads(self):
"""Tests that Gradients for 1.0 scale should be ones for some kernels."""
in_shape = [1, 2, 3, 1]
out_shape = [1, 4, 6, 1]
x = np.arange(0, 6).reshape(in_shape).astype(np.float32)
kernel_types = ['lanczos1', 'lanczos3', 'lanczos5', 'triangle', 'keyscubic']
scale = (1.0, 1.0)
translation = (0.0, 0.0)
antialias = True
for kernel_type in kernel_types:
with self.cached_session():
input_tensor = constant_op.constant(x, shape=in_shape)
with backprop.GradientTape() as tape:
tape.watch(input_tensor)
scale_and_translate_out = image_ops.scale_and_translate(
input_tensor,
out_shape[1:3],
scale=constant_op.constant(scale),
translation=constant_op.constant(translation),
kernel_type=kernel_type,
antialias=antialias)
grad = tape.gradient(scale_and_translate_out, input_tensor)[0]
grad_v = self.evaluate(grad)
self.assertAllClose(np.ones_like(grad_v), grad_v)
class CropAndResizeOpTest(test.TestCase):
def testShapeIsCorrectAfterOp(self):