From 5d178b62b68d1508336697bd2f59de8caf847f69 Mon Sep 17 00:00:00 2001 From: a6802739 Date: Sun, 3 Feb 2019 15:01:54 +0800 Subject: [PATCH] Add CropAndResize gradients for C++ image gradient operators --- tensorflow/cc/gradients/image_grad.cc | 19 +++++++++ tensorflow/cc/gradients/image_grad_test.cc | 46 ++++++++++++++++++++++ 2 files changed, 65 insertions(+) diff --git a/tensorflow/cc/gradients/image_grad.cc b/tensorflow/cc/gradients/image_grad.cc index 05c287bdc62..12faee66b85 100644 --- a/tensorflow/cc/gradients/image_grad.cc +++ b/tensorflow/cc/gradients/image_grad.cc @@ -86,6 +86,25 @@ Status ScaleAndTranslateGradHelper(const Scope& scope, const Operation& op, } REGISTER_GRADIENT_OP("ScaleAndTranslate", ScaleAndTranslateGradHelper); +Status CropAndResizeGradHelper(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + DataType input_type; + string method; + TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "method", &method)); + TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "T", &input_type)); + auto image_shape = Shape(scope, op.input(0)); + grad_outputs->push_back(CropAndResizeGradImage( + scope, grad_inputs[0], op.input(1), op.input(2), image_shape, input_type, + CropAndResizeGradImage::Method(method))); + grad_outputs->push_back(CropAndResizeGradBoxes( + scope, grad_inputs[0], op.input(0), op.input(1), op.input(2))); + grad_outputs->push_back(NoGradient()); + grad_outputs->push_back(NoGradient()); + return scope.status(); +} + +REGISTER_GRADIENT_OP("CropAndResize", CropAndResizeGradHelper); } // anonymous namespace } // namespace ops } // namespace tensorflow diff --git a/tensorflow/cc/gradients/image_grad_test.cc b/tensorflow/cc/gradients/image_grad_test.cc index 1d150226538..1288d7e1e2c 100644 --- a/tensorflow/cc/gradients/image_grad_test.cc +++ b/tensorflow/cc/gradients/image_grad_test.cc @@ -27,6 +27,7 @@ namespace tensorflow { namespace { using ops::Const; +using ops::CropAndResize; using ops::ResizeBicubic; using ops::ResizeBilinear; using ops::ResizeNearestNeighbor; @@ -194,5 +195,50 @@ class ScaleAndTranslateGradTest : public ::testing::Test { TEST_F(ScaleAndTranslateGradTest, Works) { TestResize(); } +class CropAndResizeGradTest : public ::testing::Test { + protected: + CropAndResizeGradTest() : scope_(Scope::NewRootScope()) {} + + template + Tensor MakeData(const TensorShape& data_shape) { + DataType data_type = DataTypeToEnum::v(); + Tensor data(data_type, data_shape); + auto data_flat = data.flat(); + for (int i = 0; i < data_flat.size(); ++i) { + data_flat(i) = T(i); + } + return data; + } + + template + void MakeOp(const Tensor& x_data, const Input& boxes, const Input& box_ind, + const Input& crop_szie, Output* x, Output* y) { + *x = Const(scope_, x_data); + *y = CropAndResize(scope_, *x, boxes, box_ind, crop_szie, + CropAndResize::Method("bilinear")); + TF_ASSERT_OK(scope_.status()); + } + + template + void TestCropAndResize() { + TensorShape x_shape({1, 4, 2, 1}); + Tensor x_data = MakeData(x_shape); + TensorShape box_shape({1, 4}); + Tensor boxes = MakeData(box_shape); + Output x, y; + MakeOp(x_data, boxes, {0}, {1, 1}, &x, &y); + JAC_T max_error; + TF_ASSERT_OK((ComputeGradientError( + scope_, x, x_data, y, {1, 1, 1, 1}, &max_error))); + EXPECT_LT(max_error, 1e-3); + } + + Scope scope_; +}; + +TEST_F(CropAndResizeGradTest, TestCrop) { + TestCropAndResize(); +} + } // namespace } // namespace tensorflow