Merge pull request #25467 from a6802739:crop_and_resize_gradient

PiperOrigin-RevId: 236007758
This commit is contained in:
TensorFlower Gardener 2019-02-27 16:47:02 -08:00
commit 746397a4ed
2 changed files with 65 additions and 0 deletions

View File

@ -99,6 +99,25 @@ Status ScaleAndTranslateGradHelper(const Scope& scope, const Operation& op,
}
REGISTER_GRADIENT_OP("ScaleAndTranslate", ScaleAndTranslateGradHelper);
Status CropAndResizeGradHelper(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs,
std::vector<Output>* grad_outputs) {
DataType input_type;
string method;
TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "method", &method));
TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "T", &input_type));
auto image_shape = Shape(scope, op.input(0));
grad_outputs->push_back(CropAndResizeGradImage(
scope, grad_inputs[0], op.input(1), op.input(2), image_shape, input_type,
CropAndResizeGradImage::Method(method)));
grad_outputs->push_back(CropAndResizeGradBoxes(
scope, grad_inputs[0], op.input(0), op.input(1), op.input(2)));
grad_outputs->push_back(NoGradient());
grad_outputs->push_back(NoGradient());
return scope.status();
}
REGISTER_GRADIENT_OP("CropAndResize", CropAndResizeGradHelper);
} // anonymous namespace
} // namespace ops
} // namespace tensorflow

View File

@ -27,6 +27,7 @@ namespace tensorflow {
namespace {
using ops::Const;
using ops::CropAndResize;
using ops::ResizeBicubic;
using ops::ResizeBilinear;
using ops::ResizeNearestNeighbor;
@ -219,5 +220,50 @@ class ScaleAndTranslateGradTest : public ::testing::Test {
TEST_F(ScaleAndTranslateGradTest, Works) { TestResize<float, float, float>(); }
class CropAndResizeGradTest : public ::testing::Test {
protected:
CropAndResizeGradTest() : scope_(Scope::NewRootScope()) {}
template <typename T>
Tensor MakeData(const TensorShape& data_shape) {
DataType data_type = DataTypeToEnum<T>::v();
Tensor data(data_type, data_shape);
auto data_flat = data.flat<T>();
for (int i = 0; i < data_flat.size(); ++i) {
data_flat(i) = T(i);
}
return data;
}
template <typename T>
void MakeOp(const Tensor& x_data, const Input& boxes, const Input& box_ind,
const Input& crop_szie, Output* x, Output* y) {
*x = Const<T>(scope_, x_data);
*y = CropAndResize(scope_, *x, boxes, box_ind, crop_szie,
CropAndResize::Method("bilinear"));
TF_ASSERT_OK(scope_.status());
}
template <typename X_T, typename Y_T, typename JAC_T>
void TestCropAndResize() {
TensorShape x_shape({1, 4, 2, 1});
Tensor x_data = MakeData<X_T>(x_shape);
TensorShape box_shape({1, 4});
Tensor boxes = MakeData<X_T>(box_shape);
Output x, y;
MakeOp<X_T>(x_data, boxes, {0}, {1, 1}, &x, &y);
JAC_T max_error;
TF_ASSERT_OK((ComputeGradientError<X_T, Y_T, JAC_T>(
scope_, x, x_data, y, {1, 1, 1, 1}, &max_error)));
EXPECT_LT(max_error, 1e-3);
}
Scope scope_;
};
TEST_F(CropAndResizeGradTest, TestCrop) {
TestCropAndResize<float, float, float>();
}
} // namespace
} // namespace tensorflow