Implementing gradients for tf.image.resize_bicubic.

PiperOrigin-RevId: 168547412
This commit is contained in:
A. Unique TensorFlower 2017-09-13 09:08:03 -07:00 committed by TensorFlower Gardener
parent 4982ef0fa4
commit 3331c574bc
7 changed files with 268 additions and 0 deletions

View File

@ -134,6 +134,7 @@ op { name: "WholeFileReaderV2" rename_to: "WholeFileReader" }
# image_ops
op { name: "AdjustContrastv2" rename_to: "AdjustContrast" }
op { name: "ResizeBilinearGrad" hide: true }
op { name: "ResizeBicubicGrad" hide: true }
op { name: "ResizeNearestNeighborGrad" hide: true }
# io_ops

View File

@ -180,6 +180,21 @@ static void ComputeXWeightsAndIndices(const ImageResizerState& resizer_state,
}
}
static void ComputeGradientXWeightsAndIndices(
const ImageResizerGradientState& resizer_state,
std::vector<WeightsAndIndices>* x_wais) {
CachedInterpolationCalculator calc;
for (int64 x = 0; x < resizer_state.resized_width; ++x) {
GetWeightsAndIndices(resizer_state.width_scale, x,
resizer_state.original_width, &(*x_wais)[x]);
auto& x_wai = (*x_wais)[x];
x_wai.advance = calc.Advance(x_wai.index_0, x_wai.index_1, x_wai.index_2,
x_wai.index_3);
}
// Do not scale, as we will be using these directly as tensor indices on the
// gradient pass.
}
template <typename T>
static EIGEN_ALWAYS_INLINE float ComputeYInterpolation(
int which, int channel_num, const WeightsAndIndices& y_wai,
@ -365,6 +380,73 @@ inline void interpolate_with_caching(
}
}
template <typename T>
inline void ResizeBicubicGrad(typename TTypes<float, 4>::ConstTensor input_grad,
const ImageResizerGradientState& resizer_state,
typename TTypes<T, 4>::Tensor output_grad) {
// This function computes gradients for the ResizeBicubic op by iterating over
// the input_grad Tensor and using WeightsAndIndices to appropriately update
// the output gradient.
const float height_scale = resizer_state.height_scale;
const int64 original_height = resizer_state.original_height;
const int channels = resizer_state.channels;
const int64 resized_width = resizer_state.resized_width;
const int64 resized_height = resizer_state.resized_height;
output_grad.setZero();
std::vector<WeightsAndIndices> x_wais(resizer_state.resized_width);
ComputeGradientXWeightsAndIndices(resizer_state, &x_wais);
for (int64 b = 0; b < resizer_state.batch_size; ++b) {
for (int64 y = 0; y < resized_height; ++y) {
WeightsAndIndices y_wai;
GetWeightsAndIndices(height_scale, y, original_height, &y_wai);
for (int64 x = 0; x < resized_width; ++x) {
const WeightsAndIndices& x_wai = x_wais[x];
for (int64 c = 0; c < channels; ++c) {
T curr_input_grad = input_grad(b, y, x, c);
// row 0 of 0, 1, 2, 3
output_grad(b, y_wai.index_0, x_wai.index_0, c) +=
T(curr_input_grad * y_wai.weight_0 * x_wai.weight_0);
output_grad(b, y_wai.index_0, x_wai.index_1, c) +=
T(curr_input_grad * y_wai.weight_0 * x_wai.weight_1);
output_grad(b, y_wai.index_0, x_wai.index_2, c) +=
T(curr_input_grad * y_wai.weight_0 * x_wai.weight_2);
output_grad(b, y_wai.index_0, x_wai.index_3, c) +=
T(curr_input_grad * y_wai.weight_0 * x_wai.weight_3);
// row 1 of 0, 1, 2, 3
output_grad(b, y_wai.index_1, x_wai.index_0, c) +=
T(curr_input_grad * y_wai.weight_1 * x_wai.weight_0);
output_grad(b, y_wai.index_1, x_wai.index_1, c) +=
T(curr_input_grad * y_wai.weight_1 * x_wai.weight_1);
output_grad(b, y_wai.index_1, x_wai.index_2, c) +=
T(curr_input_grad * y_wai.weight_1 * x_wai.weight_2);
output_grad(b, y_wai.index_1, x_wai.index_3, c) +=
T(curr_input_grad * y_wai.weight_1 * x_wai.weight_3);
// row 2 of 0, 1, 2, 3
output_grad(b, y_wai.index_2, x_wai.index_0, c) +=
T(curr_input_grad * y_wai.weight_2 * x_wai.weight_0);
output_grad(b, y_wai.index_2, x_wai.index_1, c) +=
T(curr_input_grad * y_wai.weight_2 * x_wai.weight_1);
output_grad(b, y_wai.index_2, x_wai.index_2, c) +=
T(curr_input_grad * y_wai.weight_2 * x_wai.weight_2);
output_grad(b, y_wai.index_2, x_wai.index_3, c) +=
T(curr_input_grad * y_wai.weight_2 * x_wai.weight_3);
// row 3 of 0, 1, 2, 3
output_grad(b, y_wai.index_3, x_wai.index_0, c) +=
T(curr_input_grad * y_wai.weight_3 * x_wai.weight_0);
output_grad(b, y_wai.index_3, x_wai.index_1, c) +=
T(curr_input_grad * y_wai.weight_3 * x_wai.weight_1);
output_grad(b, y_wai.index_3, x_wai.index_2, c) +=
T(curr_input_grad * y_wai.weight_3 * x_wai.weight_2);
output_grad(b, y_wai.index_3, x_wai.index_3, c) +=
T(curr_input_grad * y_wai.weight_3 * x_wai.weight_3);
}
}
}
}
}
} // namespace
typedef Eigen::ThreadPoolDevice CPUDevice;
@ -394,6 +476,36 @@ class ResizeBicubicOp : public OpKernel {
bool align_corners_;
};
template <typename Device, typename T>
class ResizeBicubicOpGrad : public OpKernel {
public:
explicit ResizeBicubicOpGrad(OpKernelConstruction* context)
: OpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("align_corners", &align_corners_));
}
void Compute(OpKernelContext* context) override {
// Validate input.
// First argument is gradient with respect to resized image.
const Tensor& input = context->input(0);
const Tensor& original_image = context->input(1);
ImageResizerGradientState st(align_corners_);
st.ValidateAndCreateOutput(context, input, original_image);
if (!context->status().ok()) return;
typename TTypes<float, 4>::ConstTensor input_grad =
input.tensor<float, 4>();
typename TTypes<T, 4>::Tensor output_grad = st.output->tensor<T, 4>();
ResizeBicubicGrad<T>(input_grad, st, output_grad);
}
private:
bool align_corners_;
};
#define REGISTER_KERNEL(T) \
REGISTER_KERNEL_BUILDER(Name("ResizeBicubic") \
.Device(DEVICE_CPU) \
@ -405,4 +517,14 @@ TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNEL);
#undef REGISTER_KERNEL
#define REGISTER_GRAD_KERNEL(T) \
REGISTER_KERNEL_BUILDER( \
Name("ResizeBicubicGrad").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
ResizeBicubicOpGrad<CPUDevice, T>);
TF_CALL_float(REGISTER_GRAD_KERNEL);
TF_CALL_double(REGISTER_GRAD_KERNEL);
#undef REGISTER_GRAD_KERNEL
} // namespace tensorflow

View File

@ -197,6 +197,31 @@ resized_images: 4-D with shape
`[batch, new_height, new_width, channels]`.
)doc");
// --------------------------------------------------------------------------
REGISTER_OP("ResizeBicubicGrad")
.Input("grads: float")
.Input("original_image: T")
.Output("output: T")
.Attr("T: {float, double}")
.Attr("align_corners: bool = false")
.SetShapeFn([](InferenceContext* c) {
c->set_output(0, c->input(1));
return Status::OK();
})
.Doc(R"doc(
Computes the gradient of bicubic interpolation.
grads: 4-D with shape `[batch, height, width, channels]`.
original_image: 4-D with shape `[batch, orig_height, orig_width, channels]`,
The image tensor that was resized.
align_corners: If true, rescale grads by (orig_height - 1) / (height - 1), which
exactly aligns the 4 corners of grads and original_image. If false, rescale by
orig_height / height. Treat similarly the width dimension.
output: 4-D with shape `[batch, orig_height, orig_width, channels]`.
Gradients with respect to the input image. Input image must have been
float or double.
)doc");
// --------------------------------------------------------------------------
REGISTER_OP("ResizeBilinear")
.Input("images: T")

View File

@ -20176,6 +20176,43 @@ op {
summary: "Resize `images` to `size` using bicubic interpolation."
description: "Input images can be of different types but output images are always float."
}
op {
name: "ResizeBicubicGrad"
input_arg {
name: "grads"
description: "4-D with shape `[batch, height, width, channels]`."
type: DT_FLOAT
}
input_arg {
name: "original_image"
description: "4-D with shape `[batch, orig_height, orig_width, channels]`,\nThe image tensor that was resized."
type_attr: "T"
}
output_arg {
name: "output"
description: "4-D with shape `[batch, orig_height, orig_width, channels]`.\nGradients with respect to the input image. Input image must have been\nfloat or double."
type_attr: "T"
}
attr {
name: "T"
type: "type"
allowed_values {
list {
type: DT_FLOAT
type: DT_DOUBLE
}
}
}
attr {
name: "align_corners"
type: "bool"
default_value {
b: false
}
description: "If true, rescale grads by (orig_height - 1) / (height - 1), which\nexactly aligns the 4 corners of grads and original_image. If false, rescale by\norig_height / height. Treat similarly the width dimension."
}
summary: "Computes the gradient of bicubic interpolation."
}
op {
name: "ResizeBilinear"
input_arg {

View File

@ -167,6 +167,7 @@ NonMaxSuppression
NonMaxSuppressionV2
RandomCrop
ResizeBilinearGrad
ResizeBicubicGrad
ResizeNearestNeighborGrad
SampleDistortedBoundingBox
SampleDistortedBoundingBoxV2

View File

@ -73,6 +73,27 @@ def _ResizeBilinearGrad(op, grad):
return [grad0, None]
@ops.RegisterGradient("ResizeBicubic")
def _ResizeBicubicGrad(op, grad):
"""The derivatives for bicubic resizing.
Args:
op: The ResizeBicubic op.
grad: The tensor representing the gradient w.r.t. the output.
Returns:
The gradients w.r.t. the input.
"""
allowed_types = [dtypes.float32, dtypes.float64]
grad0 = None
if op.inputs[0].dtype in allowed_types:
# pylint: disable=protected-access
grad0 = gen_image_ops._resize_bicubic_grad(
grad, op.inputs[0], align_corners=op.get_attr("align_corners"))
# pylint: enable=protected-access
return [grad0, None]
@ops.RegisterGradient("CropAndResize")
def _CropAndResizeGrad(op, grad):
"""The derivatives for crop_and_resize.

View File

@ -173,6 +173,67 @@ class ResizeBilinearOpTest(test.TestCase):
self.assertAllClose(grad[False], grad[True], rtol=1e-4, atol=1e-4)
class ResizeBicubicOpTest(test.TestCase):
def testShapeIsCorrectAfterOp(self):
in_shape = [1, 2, 2, 1]
out_shape = [1, 4, 6, 1]
x = np.arange(0, 4).reshape(in_shape).astype(np.float32)
for align_corners in [True, False]:
with self.test_session() as sess:
input_tensor = constant_op.constant(x, shape=in_shape)
resize_out = image_ops.resize_bicubic(input_tensor, out_shape[1:3],
align_corners=align_corners)
self.assertEqual(out_shape, list(resize_out.get_shape()))
resize_out = sess.run(resize_out)
self.assertEqual(out_shape, list(resize_out.shape))
def testGradFromResizeToLargerInBothDims(self):
in_shape = [1, 2, 3, 1]
out_shape = [1, 4, 6, 1]
x = np.arange(0, 6).reshape(in_shape).astype(np.float32)
for align_corners in [True, False]:
with self.test_session():
input_tensor = constant_op.constant(x, shape=in_shape)
resize_out = image_ops.resize_bicubic(input_tensor, out_shape[1:3],
align_corners=align_corners)
err = gradient_checker.compute_gradient_error(
input_tensor, in_shape, resize_out, out_shape, x_init_value=x)
self.assertLess(err, 1e-3)
def testGradFromResizeToSmallerInBothDims(self):
in_shape = [1, 4, 6, 1]
out_shape = [1, 2, 3, 1]
x = np.arange(0, 24).reshape(in_shape).astype(np.float32)
for align_corners in [True, False]:
with self.test_session():
input_tensor = constant_op.constant(x, shape=in_shape)
resize_out = image_ops.resize_bicubic(input_tensor, out_shape[1:3],
align_corners=align_corners)
err = gradient_checker.compute_gradient_error(
input_tensor, in_shape, resize_out, out_shape, x_init_value=x)
self.assertLess(err, 1e-3)
def testGradOnUnsupportedType(self):
in_shape = [1, 4, 6, 1]
out_shape = [1, 2, 3, 1]
x = np.arange(0, 24).reshape(in_shape).astype(np.uint8)
with self.test_session():
input_tensor = constant_op.constant(x, shape=in_shape)
resize_out = image_ops.resize_bicubic(input_tensor, out_shape[1:3])
grad = gradients_impl.gradients(input_tensor, [resize_out])
self.assertEqual([None], grad)
class CropAndResizeOpTest(test.TestCase):
def testShapeIsCorrectAfterOp(self):