Merge pull request #25467 from a6802739:crop_and_resize_gradient
PiperOrigin-RevId: 236007758
This commit is contained in:
commit
746397a4ed
@ -99,6 +99,25 @@ Status ScaleAndTranslateGradHelper(const Scope& scope, const Operation& op,
|
||||
}
|
||||
REGISTER_GRADIENT_OP("ScaleAndTranslate", ScaleAndTranslateGradHelper);
|
||||
|
||||
Status CropAndResizeGradHelper(const Scope& scope, const Operation& op,
|
||||
const std::vector<Output>& grad_inputs,
|
||||
std::vector<Output>* grad_outputs) {
|
||||
DataType input_type;
|
||||
string method;
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "method", &method));
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "T", &input_type));
|
||||
auto image_shape = Shape(scope, op.input(0));
|
||||
grad_outputs->push_back(CropAndResizeGradImage(
|
||||
scope, grad_inputs[0], op.input(1), op.input(2), image_shape, input_type,
|
||||
CropAndResizeGradImage::Method(method)));
|
||||
grad_outputs->push_back(CropAndResizeGradBoxes(
|
||||
scope, grad_inputs[0], op.input(0), op.input(1), op.input(2)));
|
||||
grad_outputs->push_back(NoGradient());
|
||||
grad_outputs->push_back(NoGradient());
|
||||
return scope.status();
|
||||
}
|
||||
|
||||
REGISTER_GRADIENT_OP("CropAndResize", CropAndResizeGradHelper);
|
||||
} // anonymous namespace
|
||||
} // namespace ops
|
||||
} // namespace tensorflow
|
||||
|
@ -27,6 +27,7 @@ namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
using ops::Const;
|
||||
using ops::CropAndResize;
|
||||
using ops::ResizeBicubic;
|
||||
using ops::ResizeBilinear;
|
||||
using ops::ResizeNearestNeighbor;
|
||||
@ -219,5 +220,50 @@ class ScaleAndTranslateGradTest : public ::testing::Test {
|
||||
|
||||
TEST_F(ScaleAndTranslateGradTest, Works) { TestResize<float, float, float>(); }
|
||||
|
||||
class CropAndResizeGradTest : public ::testing::Test {
|
||||
protected:
|
||||
CropAndResizeGradTest() : scope_(Scope::NewRootScope()) {}
|
||||
|
||||
template <typename T>
|
||||
Tensor MakeData(const TensorShape& data_shape) {
|
||||
DataType data_type = DataTypeToEnum<T>::v();
|
||||
Tensor data(data_type, data_shape);
|
||||
auto data_flat = data.flat<T>();
|
||||
for (int i = 0; i < data_flat.size(); ++i) {
|
||||
data_flat(i) = T(i);
|
||||
}
|
||||
return data;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void MakeOp(const Tensor& x_data, const Input& boxes, const Input& box_ind,
|
||||
const Input& crop_szie, Output* x, Output* y) {
|
||||
*x = Const<T>(scope_, x_data);
|
||||
*y = CropAndResize(scope_, *x, boxes, box_ind, crop_szie,
|
||||
CropAndResize::Method("bilinear"));
|
||||
TF_ASSERT_OK(scope_.status());
|
||||
}
|
||||
|
||||
template <typename X_T, typename Y_T, typename JAC_T>
|
||||
void TestCropAndResize() {
|
||||
TensorShape x_shape({1, 4, 2, 1});
|
||||
Tensor x_data = MakeData<X_T>(x_shape);
|
||||
TensorShape box_shape({1, 4});
|
||||
Tensor boxes = MakeData<X_T>(box_shape);
|
||||
Output x, y;
|
||||
MakeOp<X_T>(x_data, boxes, {0}, {1, 1}, &x, &y);
|
||||
JAC_T max_error;
|
||||
TF_ASSERT_OK((ComputeGradientError<X_T, Y_T, JAC_T>(
|
||||
scope_, x, x_data, y, {1, 1, 1, 1}, &max_error)));
|
||||
EXPECT_LT(max_error, 1e-3);
|
||||
}
|
||||
|
||||
Scope scope_;
|
||||
};
|
||||
|
||||
TEST_F(CropAndResizeGradTest, TestCrop) {
|
||||
TestCropAndResize<float, float, float>();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
Loading…
x
Reference in New Issue
Block a user