From 3331c574bcfd85787d7a4f3d1b1b139239a6595b Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Wed, 13 Sep 2017 09:08:03 -0700
Subject: [PATCH] Implementing gradients for tf.image.resize_bicubic.

PiperOrigin-RevId: 168547412
---
 tensorflow/cc/ops/op_gen_overrides.pbtxt     |   1 +
 tensorflow/core/kernels/resize_bicubic_op.cc | 122 +++++++++++++++++++
 tensorflow/core/ops/image_ops.cc             |  25 ++++
 tensorflow/core/ops/ops.pbtxt                |  37 ++++++
 tensorflow/python/ops/hidden_ops.txt         |   1 +
 tensorflow/python/ops/image_grad.py          |  21 ++++
 tensorflow/python/ops/image_grad_test.py     |  61 ++++++++++
 7 files changed, 268 insertions(+)

diff --git a/tensorflow/cc/ops/op_gen_overrides.pbtxt b/tensorflow/cc/ops/op_gen_overrides.pbtxt
index 777e54d342f..0184c82c5af 100644
--- a/tensorflow/cc/ops/op_gen_overrides.pbtxt
+++ b/tensorflow/cc/ops/op_gen_overrides.pbtxt
@@ -134,6 +134,7 @@ op { name: "WholeFileReaderV2" rename_to: "WholeFileReader" }
 # image_ops
 op { name: "AdjustContrastv2" rename_to: "AdjustContrast" }
 op { name: "ResizeBilinearGrad" hide: true }
+op { name: "ResizeBicubicGrad" hide: true }
 op { name: "ResizeNearestNeighborGrad" hide: true }
 
 # io_ops
diff --git a/tensorflow/core/kernels/resize_bicubic_op.cc b/tensorflow/core/kernels/resize_bicubic_op.cc
index 5131bce448e..1c43e77e7c2 100644
--- a/tensorflow/core/kernels/resize_bicubic_op.cc
+++ b/tensorflow/core/kernels/resize_bicubic_op.cc
@@ -180,6 +180,21 @@ static void ComputeXWeightsAndIndices(const ImageResizerState& resizer_state,
   }
 }
 
+static void ComputeGradientXWeightsAndIndices(
+    const ImageResizerGradientState& resizer_state,
+    std::vector<WeightsAndIndices>* x_wais) {
+  CachedInterpolationCalculator calc;
+  for (int64 x = 0; x < resizer_state.resized_width; ++x) {
+    GetWeightsAndIndices(resizer_state.width_scale, x,
+                         resizer_state.original_width, &(*x_wais)[x]);
+    auto& x_wai = (*x_wais)[x];
+    x_wai.advance = calc.Advance(x_wai.index_0, x_wai.index_1, x_wai.index_2,
+                                 x_wai.index_3);
+  }
+  // Do not scale, as we will be using these directly as tensor indices on the
+  // gradient pass.
+}
+
 template <typename T>
 static EIGEN_ALWAYS_INLINE float ComputeYInterpolation(
     int which, int channel_num, const WeightsAndIndices& y_wai,
@@ -365,6 +380,73 @@ inline void interpolate_with_caching(
   }
 }
 
+template <typename T>
+inline void ResizeBicubicGrad(typename TTypes<float, 4>::ConstTensor input_grad,
+                              const ImageResizerGradientState& resizer_state,
+                              typename TTypes<T, 4>::Tensor output_grad) {
+  // This function computes gradients for the ResizeBicubic op by iterating over
+  // the input_grad Tensor and using WeightsAndIndices to appropriately update
+  // the output gradient.
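+  // Each value of input_grad(b, y, x, c) is scattered into the 4x4 window of
+  // output_grad selected by y_wai.index_0..3 and x_wai.index_0..3, scaled by
+  // the outer product of the cubic weights in each dimension; this is the
+  // transpose of the forward interpolation.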
+  const float height_scale = resizer_state.height_scale;
+  const int64 original_height = resizer_state.original_height;
+  const int channels = resizer_state.channels;
+  const int64 resized_width = resizer_state.resized_width;
+  const int64 resized_height = resizer_state.resized_height;
+
+  output_grad.setZero();
+
+  std::vector<WeightsAndIndices> x_wais(resizer_state.resized_width);
+  ComputeGradientXWeightsAndIndices(resizer_state, &x_wais);
+  for (int64 b = 0; b < resizer_state.batch_size; ++b) {
+    for (int64 y = 0; y < resized_height; ++y) {
+      WeightsAndIndices y_wai;
+      GetWeightsAndIndices(height_scale, y, original_height, &y_wai);
+      for (int64 x = 0; x < resized_width; ++x) {
+        const WeightsAndIndices& x_wai = x_wais[x];
+        for (int64 c = 0; c < channels; ++c) {
+          T curr_input_grad = input_grad(b, y, x, c);
+          // row 0 of 0, 1, 2, 3
+          output_grad(b, y_wai.index_0, x_wai.index_0, c) +=
+              T(curr_input_grad * y_wai.weight_0 * x_wai.weight_0);
+          output_grad(b, y_wai.index_0, x_wai.index_1, c) +=
+              T(curr_input_grad * y_wai.weight_0 * x_wai.weight_1);
+          output_grad(b, y_wai.index_0, x_wai.index_2, c) +=
+              T(curr_input_grad * y_wai.weight_0 * x_wai.weight_2);
+          output_grad(b, y_wai.index_0, x_wai.index_3, c) +=
+              T(curr_input_grad * y_wai.weight_0 * x_wai.weight_3);
+          // row 1 of 0, 1, 2, 3
+          output_grad(b, y_wai.index_1, x_wai.index_0, c) +=
+              T(curr_input_grad * y_wai.weight_1 * x_wai.weight_0);
+          output_grad(b, y_wai.index_1, x_wai.index_1, c) +=
+              T(curr_input_grad * y_wai.weight_1 * x_wai.weight_1);
+          output_grad(b, y_wai.index_1, x_wai.index_2, c) +=
+              T(curr_input_grad * y_wai.weight_1 * x_wai.weight_2);
+          output_grad(b, y_wai.index_1, x_wai.index_3, c) +=
+              T(curr_input_grad * y_wai.weight_1 * x_wai.weight_3);
+          // row 2 of 0, 1, 2, 3
+          output_grad(b, y_wai.index_2, x_wai.index_0, c) +=
+              T(curr_input_grad * y_wai.weight_2 * x_wai.weight_0);
+          output_grad(b, y_wai.index_2, x_wai.index_1, c) +=
+              T(curr_input_grad * y_wai.weight_2 * x_wai.weight_1);
+          output_grad(b, y_wai.index_2, x_wai.index_2, c) +=
+              T(curr_input_grad * y_wai.weight_2 * x_wai.weight_2);
+          output_grad(b, y_wai.index_2, x_wai.index_3, c) +=
+              T(curr_input_grad * y_wai.weight_2 * x_wai.weight_3);
+          // row 3 of 0, 1, 2, 3
+          output_grad(b, y_wai.index_3, x_wai.index_0, c) +=
+              T(curr_input_grad * y_wai.weight_3 * x_wai.weight_0);
+          output_grad(b, y_wai.index_3, x_wai.index_1, c) +=
+              T(curr_input_grad * y_wai.weight_3 * x_wai.weight_1);
+          output_grad(b, y_wai.index_3, x_wai.index_2, c) +=
+              T(curr_input_grad * y_wai.weight_3 * x_wai.weight_2);
+          output_grad(b, y_wai.index_3, x_wai.index_3, c) +=
+              T(curr_input_grad * y_wai.weight_3 * x_wai.weight_3);
+        }
+      }
+    }
+  }
+}
+
 }  // namespace
 
 typedef Eigen::ThreadPoolDevice CPUDevice;
@@ -394,6 +476,36 @@ class ResizeBicubicOp : public OpKernel {
   bool align_corners_;
 };
 
+template <typename Device, typename T>
+class ResizeBicubicOpGrad : public OpKernel {
+ public:
+  explicit ResizeBicubicOpGrad(OpKernelConstruction* context)
+      : OpKernel(context) {
+    OP_REQUIRES_OK(context, context->GetAttr("align_corners", &align_corners_));
+  }
+
+  void Compute(OpKernelContext* context) override {
+    // Validate input.
+    // First argument is gradient with respect to resized image.
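+    // Second argument is the original input image; its shape is used to
+    // validate the op inputs and to allocate the output gradient tensor.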
+    const Tensor& input = context->input(0);
+    const Tensor& original_image = context->input(1);
+
+    ImageResizerGradientState st(align_corners_);
+    st.ValidateAndCreateOutput(context, input, original_image);
+
+    if (!context->status().ok()) return;
+
+    typename TTypes<float, 4>::ConstTensor input_grad =
+        input.tensor<float, 4>();
+    typename TTypes<T, 4>::Tensor output_grad = st.output->tensor<T, 4>();
+
+    ResizeBicubicGrad<T>(input_grad, st, output_grad);
+  }
+
+ private:
+  bool align_corners_;
+};
+
 #define REGISTER_KERNEL(T)                            \
   REGISTER_KERNEL_BUILDER(Name("ResizeBicubic")       \
                               .Device(DEVICE_CPU)     \
@@ -405,4 +517,14 @@ TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNEL);
 
 #undef REGISTER_KERNEL
 
+#define REGISTER_GRAD_KERNEL(T)                                              \
+  REGISTER_KERNEL_BUILDER(                                                   \
+      Name("ResizeBicubicGrad").Device(DEVICE_CPU).TypeConstraint<T>("T"),   \
+      ResizeBicubicOpGrad<CPUDevice, T>);
+
+TF_CALL_float(REGISTER_GRAD_KERNEL);
+TF_CALL_double(REGISTER_GRAD_KERNEL);
+
+#undef REGISTER_GRAD_KERNEL
+
 }  // namespace tensorflow
diff --git a/tensorflow/core/ops/image_ops.cc b/tensorflow/core/ops/image_ops.cc
index 30a6e92b874..1453943d78d 100644
--- a/tensorflow/core/ops/image_ops.cc
+++ b/tensorflow/core/ops/image_ops.cc
@@ -197,6 +197,31 @@ resized_images: 4-D with shape
   `[batch, new_height, new_width, channels]`.
 )doc");
 
+// --------------------------------------------------------------------------
+REGISTER_OP("ResizeBicubicGrad")
+    .Input("grads: float")
+    .Input("original_image: T")
+    .Output("output: T")
+    .Attr("T: {float, double}")
+    .Attr("align_corners: bool = false")
+    .SetShapeFn([](InferenceContext* c) {
+      c->set_output(0, c->input(1));
+      return Status::OK();
+    })
+    .Doc(R"doc(
+Computes the gradient of bicubic interpolation.
+
+grads: 4-D with shape `[batch, height, width, channels]`.
+original_image: 4-D with shape `[batch, orig_height, orig_width, channels]`,
+  The image tensor that was resized.
+align_corners: If true, rescale grads by (orig_height - 1) / (height - 1), which
+  exactly aligns the 4 corners of grads and original_image. If false, rescale by
+  orig_height / height. Treat similarly the width dimension.
+output: 4-D with shape `[batch, orig_height, orig_width, channels]`.
+  Gradients with respect to the input image. Input image must have been
+  float or double.
+)doc");
+
 // --------------------------------------------------------------------------
 REGISTER_OP("ResizeBilinear")
     .Input("images: T")
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index 082809088f2..3f75472bca9 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -20176,6 +20176,43 @@ op {
   summary: "Resize `images` to `size` using bicubic interpolation."
   description: "Input images can be of different types but output images are always float."
 }
+op {
+  name: "ResizeBicubicGrad"
+  input_arg {
+    name: "grads"
+    description: "4-D with shape `[batch, height, width, channels]`."
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "original_image"
+    description: "4-D with shape `[batch, orig_height, orig_width, channels]`,\nThe image tensor that was resized."
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    description: "4-D with shape `[batch, orig_height, orig_width, channels]`.\nGradients with respect to the input image. Input image must have been\nfloat or double."
+ type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_FLOAT + type: DT_DOUBLE + } + } + } + attr { + name: "align_corners" + type: "bool" + default_value { + b: false + } + description: "If true, rescale grads by (orig_height - 1) / (height - 1), which\nexactly aligns the 4 corners of grads and original_image. If false, rescale by\norig_height / height. Treat similarly the width dimension." + } + summary: "Computes the gradient of bicubic interpolation." +} op { name: "ResizeBilinear" input_arg { diff --git a/tensorflow/python/ops/hidden_ops.txt b/tensorflow/python/ops/hidden_ops.txt index e02d8a7099a..1678282ced6 100644 --- a/tensorflow/python/ops/hidden_ops.txt +++ b/tensorflow/python/ops/hidden_ops.txt @@ -167,6 +167,7 @@ NonMaxSuppression NonMaxSuppressionV2 RandomCrop ResizeBilinearGrad +ResizeBicubicGrad ResizeNearestNeighborGrad SampleDistortedBoundingBox SampleDistortedBoundingBoxV2 diff --git a/tensorflow/python/ops/image_grad.py b/tensorflow/python/ops/image_grad.py index b6b61ab92ce..d17f1a87d97 100644 --- a/tensorflow/python/ops/image_grad.py +++ b/tensorflow/python/ops/image_grad.py @@ -73,6 +73,27 @@ def _ResizeBilinearGrad(op, grad): return [grad0, None] +@ops.RegisterGradient("ResizeBicubic") +def _ResizeBicubicGrad(op, grad): + """The derivatives for bicubic resizing. + + Args: + op: The ResizeBicubic op. + grad: The tensor representing the gradient w.r.t. the output. + + Returns: + The gradients w.r.t. the input. + """ + allowed_types = [dtypes.float32, dtypes.float64] + grad0 = None + if op.inputs[0].dtype in allowed_types: + # pylint: disable=protected-access + grad0 = gen_image_ops._resize_bicubic_grad( + grad, op.inputs[0], align_corners=op.get_attr("align_corners")) + # pylint: enable=protected-access + return [grad0, None] + + @ops.RegisterGradient("CropAndResize") def _CropAndResizeGrad(op, grad): """The derivatives for crop_and_resize. 
diff --git a/tensorflow/python/ops/image_grad_test.py b/tensorflow/python/ops/image_grad_test.py index ea65ea10c4d..05e8fa1d728 100644 --- a/tensorflow/python/ops/image_grad_test.py +++ b/tensorflow/python/ops/image_grad_test.py @@ -173,6 +173,67 @@ class ResizeBilinearOpTest(test.TestCase): self.assertAllClose(grad[False], grad[True], rtol=1e-4, atol=1e-4) +class ResizeBicubicOpTest(test.TestCase): + + def testShapeIsCorrectAfterOp(self): + in_shape = [1, 2, 2, 1] + out_shape = [1, 4, 6, 1] + + x = np.arange(0, 4).reshape(in_shape).astype(np.float32) + + for align_corners in [True, False]: + with self.test_session() as sess: + input_tensor = constant_op.constant(x, shape=in_shape) + resize_out = image_ops.resize_bicubic(input_tensor, out_shape[1:3], + align_corners=align_corners) + self.assertEqual(out_shape, list(resize_out.get_shape())) + + resize_out = sess.run(resize_out) + self.assertEqual(out_shape, list(resize_out.shape)) + + def testGradFromResizeToLargerInBothDims(self): + in_shape = [1, 2, 3, 1] + out_shape = [1, 4, 6, 1] + + x = np.arange(0, 6).reshape(in_shape).astype(np.float32) + + for align_corners in [True, False]: + with self.test_session(): + input_tensor = constant_op.constant(x, shape=in_shape) + resize_out = image_ops.resize_bicubic(input_tensor, out_shape[1:3], + align_corners=align_corners) + err = gradient_checker.compute_gradient_error( + input_tensor, in_shape, resize_out, out_shape, x_init_value=x) + self.assertLess(err, 1e-3) + + def testGradFromResizeToSmallerInBothDims(self): + in_shape = [1, 4, 6, 1] + out_shape = [1, 2, 3, 1] + + x = np.arange(0, 24).reshape(in_shape).astype(np.float32) + + for align_corners in [True, False]: + with self.test_session(): + input_tensor = constant_op.constant(x, shape=in_shape) + resize_out = image_ops.resize_bicubic(input_tensor, out_shape[1:3], + align_corners=align_corners) + err = gradient_checker.compute_gradient_error( + input_tensor, in_shape, resize_out, out_shape, x_init_value=x) + self.assertLess(err, 1e-3) + + def testGradOnUnsupportedType(self): + in_shape = [1, 4, 6, 1] + out_shape = [1, 2, 3, 1] + + x = np.arange(0, 24).reshape(in_shape).astype(np.uint8) + + with self.test_session(): + input_tensor = constant_op.constant(x, shape=in_shape) + resize_out = image_ops.resize_bicubic(input_tensor, out_shape[1:3]) + grad = gradients_impl.gradients(input_tensor, [resize_out]) + self.assertEqual([None], grad) + + class CropAndResizeOpTest(test.TestCase): def testShapeIsCorrectAfterOp(self):